manage permission

This commit is contained in:
Piotr Domański 2023-04-03 13:46:03 +02:00
parent ed004e99a4
commit 157ee4ef55
6 changed files with 72 additions and 16 deletions

View file

@ -4,6 +4,5 @@ Simple API for food diary application.
## TODO
- [ ] Add access restriction on each endpoint
- [ ] Add tests
- [ ] Add default servings

View file

@ -3,12 +3,17 @@ from fastapi import HTTPException
from ..model.entry import Entry, CreateEntryPayload, UpdateEntryPayload
from ..domain.entry import Entry as DBEntry
from ..domain.meal import Meal as DBMeal
from .base import AuthorizedController
class CreateEntry(AuthorizedController):
async def call(self, content: CreateEntryPayload) -> Entry:
async with self.async_session.begin() as session:
meal = await DBMeal.get_by_id(session, self.user.id, content.meal_id)
if meal is None:
raise HTTPException(status_code=404, detail="meal not found")
try:
entry = await DBEntry.create(
session, content.meal_id, content.product_id, content.grams
@ -21,7 +26,7 @@ class CreateEntry(AuthorizedController):
class UpdateEntry(AuthorizedController):
async def call(self, entry_id: int, content: UpdateEntryPayload) -> Entry:
async with self.async_session.begin() as session:
entry = await DBEntry.get_by_id(session, entry_id)
entry = await DBEntry.get_by_id(session, self.user.id, entry_id)
if entry is None:
raise HTTPException(status_code=404, detail="entry not found")
@ -37,7 +42,7 @@ class UpdateEntry(AuthorizedController):
class DeleteEntry(AuthorizedController):
async def call(self, entry_id: int) -> Entry:
async with self.async_session.begin() as session:
entry = await DBEntry.get_by_id(session, entry_id)
entry = await DBEntry.get_by_id(session, self.user.id, entry_id)
if entry is None:
raise HTTPException(status_code=404, detail="entry not found")

View file

@ -3,12 +3,18 @@ from fastapi import HTTPException
from ..model.meal import Meal, CreateMealPayload
from ..domain.meal import Meal as DBMeal
from ..domain.diary import Diary as DBDiary
from .base import AuthorizedController
class CreateMeal(AuthorizedController):
async def call(self, content: CreateMealPayload) -> Meal:
async with self.async_session.begin() as session:
if not await DBDiary.has_permission(
session, self.user.id, content.diary_id
):
raise HTTPException(status_code=404, detail="not found")
try:
meal = await DBMeal.create(
session, content.diary_id, content.order, content.name

View file

@ -14,7 +14,9 @@ from .entry import Entry
class Diary(Base, CommonMixin):
"""Diary represents user diary for given day"""
meals: Mapped[list[Meal]] = relationship(lazy="selectin", order_by=Meal.order)
meals: Mapped[list[Meal]] = relationship(
lazy="selectin", order_by=Meal.order.desc()
)
date: Mapped[date] = mapped_column(Date)
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
@ -67,7 +69,7 @@ class Diary(Base, CommonMixin):
cls, session: AsyncSession, user_id: int, date: date
) -> "Optional[Diary]":
"""get_diary."""
query = select(cls).where(cls.user_id == user_id).where(cls.date == date)
query = cls.query(user_id).where(cls.date == date)
return await session.scalar(query)
@classmethod
@ -95,10 +97,12 @@ class Diary(Base, CommonMixin):
cls, session: AsyncSession, user_id: int, id: int
) -> "Optional[Diary]":
"""get_by_id."""
query = (
select(cls)
.where(cls.user_id == user_id)
.where(cls.id == id)
.options(joinedload(cls.meals))
)
query = cls.query(user_id).where(cls.id == id)
return await session.scalar(query)
@classmethod
async def has_permission(cls, session: AsyncSession, user_id: int, id: int) -> bool:
"""has_permission."""
query = select(cls.id).where(cls.user_id == user_id).where(cls.id == id)
obj = await session.scalar(query)
return obj is not None

View file

@ -79,7 +79,7 @@ class Entry(Base, CommonMixin):
except IntegrityError:
raise AssertionError("meal or product does not exist")
entry = await cls.get_by_id(session, entry.id)
entry = await cls._get_by_id(session, entry.id)
if not entry:
raise RuntimeError()
return entry
@ -111,11 +111,35 @@ class Entry(Base, CommonMixin):
raise AssertionError("product does not exist")
@classmethod
async def get_by_id(cls, session: AsyncSession, id: int) -> "Optional[Entry]":
async def _get_by_id(cls, session: AsyncSession, id: int) -> "Optional[Entry]":
"""get_by_id."""
query = select(cls).where(cls.id == id).options(joinedload(cls.product))
return await session.scalar(query.order_by(cls.id))
@classmethod
async def get_by_id(
cls, session: AsyncSession, user_id: int, id: int
) -> "Optional[Entry]":
"""get_by_id."""
from .diary import Diary
from .meal import Meal
query = (
select(cls)
.where(cls.id == id)
.join(
Meal,
)
.join(
Diary,
)
.where(
Diary.user_id == user_id,
)
.options(joinedload(cls.product))
)
return await session.scalar(query.order_by(cls.id))
async def delete(self, session) -> None:
"""delete."""
await session.delete(self)

View file

@ -15,7 +15,9 @@ class Meal(Base, CommonMixin):
name: Mapped[str]
order: Mapped[int]
diary_id: Mapped[int] = mapped_column(Integer, ForeignKey("diary.id"))
entries: Mapped[list[Entry]] = relationship(lazy="selectin", order_by=Entry.last_changed)
entries: Mapped[list[Entry]] = relationship(
lazy="selectin", order_by=Entry.last_changed
)
@property
def calories(self) -> float:
@ -72,13 +74,29 @@ class Meal(Base, CommonMixin):
except IntegrityError:
raise AssertionError("diary does not exist")
meal = await cls.get_by_id(session, meal.id)
meal = await cls._get_by_id(session, meal.id)
if not meal:
raise RuntimeError()
return meal
@classmethod
async def get_by_id(cls, session: AsyncSession, id: int) -> "Optional[Meal]":
async def _get_by_id(cls, session: AsyncSession, id: int) -> "Optional[Meal]":
"""get_by_id."""
query = select(cls).where(cls.id == id).options(joinedload(cls.entries))
return await session.scalar(query.order_by(cls.id))
@classmethod
async def get_by_id(
cls, session: AsyncSession, user_id: int, id: int
) -> "Optional[Meal]":
"""get_by_id."""
from .diary import Diary
query = (
select(cls)
.where(cls.id == id)
.join(Diary)
.where(Diary.user_id == user_id)
.options(joinedload(cls.entries))
)
return await session.scalar(query.order_by(cls.id))