diff --git a/fooder/domain/base.py b/fooder/domain/base.py index 112aadc..7f11661 100644 --- a/fooder/domain/base.py +++ b/fooder/domain/base.py @@ -17,6 +17,6 @@ class CommonMixin: :rtype: str """ - return cls.__name__.lower() + return cls.__name__.lower() # type: ignore id: Mapped[int] = mapped_column(primary_key=True) diff --git a/fooder/domain/diary.py b/fooder/domain/diary.py index 81094ca..ee7592f 100644 --- a/fooder/domain/diary.py +++ b/fooder/domain/diary.py @@ -3,7 +3,7 @@ from sqlalchemy import ForeignKey, Integer, Date from sqlalchemy import select from sqlalchemy.sql.selectable import Select from sqlalchemy.ext.asyncio import AsyncSession -from datetime import date +import datetime from typing import Optional from .base import Base, CommonMixin @@ -17,7 +17,7 @@ class Diary(Base, CommonMixin): meals: Mapped[list[Meal]] = relationship( lazy="selectin", order_by=Meal.order.desc() ) - date: Mapped[date] = mapped_column(Date) + date: Mapped[datetime.date] = mapped_column(Date) user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id")) @property @@ -74,14 +74,16 @@ class Diary(Base, CommonMixin): @classmethod async def get_diary( - cls, session: AsyncSession, user_id: int, date: date + cls, session: AsyncSession, user_id: int, date: datetime.date ) -> "Optional[Diary]": """get_diary.""" query = cls.query(user_id).where(cls.date == date) return await session.scalar(query) @classmethod - async def create(cls, session: AsyncSession, user_id: int, date: date) -> "Diary": + async def create( + cls, session: AsyncSession, user_id: int, date: datetime.date + ) -> "Diary": diary = Diary( date=date, user_id=user_id, @@ -93,12 +95,13 @@ class Diary(Base, CommonMixin): except Exception: raise RuntimeError() - diary = await cls.get_by_id(session, user_id, diary.id) + db_diary = await cls.get_by_id(session, user_id, diary.id) - if not diary: + if not db_diary: raise RuntimeError() - await Meal.create(session, diary.id) - return diary + + await Meal.create(session, db_diary.id) + return db_diary @classmethod async def get_by_id( diff --git a/fooder/domain/entry.py b/fooder/domain/entry.py index ef877c3..afbc3a5 100644 --- a/fooder/domain/entry.py +++ b/fooder/domain/entry.py @@ -87,10 +87,10 @@ class Entry(Base, CommonMixin): except IntegrityError: raise AssertionError("meal or product does not exist") - entry = await cls._get_by_id(session, entry.id) - if not entry: + db_entry = await cls._get_by_id(session, entry.id) + if not db_entry: raise RuntimeError() - return entry + return db_entry async def update( self, diff --git a/fooder/domain/meal.py b/fooder/domain/meal.py index 7db038d..917fff2 100644 --- a/fooder/domain/meal.py +++ b/fooder/domain/meal.py @@ -84,10 +84,10 @@ class Meal(Base, CommonMixin): except IntegrityError: raise AssertionError("diary does not exist") - meal = await cls._get_by_id(session, meal.id) - if not meal: + db_meal = await cls._get_by_id(session, meal.id) + if not db_meal: raise RuntimeError() - return meal + return db_meal @classmethod async def create_from_preset( @@ -118,10 +118,10 @@ class Meal(Base, CommonMixin): for entry in preset.entries: await Entry.create(session, meal.id, entry.product_id, entry.grams) - meal = await cls._get_by_id(session, meal.id) - if not meal: + db_meal = await cls._get_by_id(session, meal.id) + if not db_meal: raise RuntimeError() - return meal + return db_meal @classmethod async def _get_by_id(cls, session: AsyncSession, id: int) -> "Optional[Meal]": diff --git a/fooder/domain/preset.py b/fooder/domain/preset.py index 7da5c2a..6242978 100644 --- a/fooder/domain/preset.py +++ b/fooder/domain/preset.py @@ -63,7 +63,7 @@ class Preset(Base, CommonMixin): @classmethod async def create( cls, session: AsyncSession, user_id: int, name: str, meal: "Meal" - ) -> None: + ) -> "Preset": preset = Preset(user_id=user_id, name=name) session.add(preset) @@ -76,7 +76,12 @@ class Preset(Base, CommonMixin): for entry in meal.entries: await PresetEntry.create(session, preset.id, entry) - return await cls.get(session, user_id, preset.id) + db_preset = await cls.get(session, user_id, preset.id) + + if not db_preset: + raise RuntimeError() + + return db_preset @classmethod async def list_all( diff --git a/fooder/domain/product.py b/fooder/domain/product.py index 8011c8e..b97ff29 100644 --- a/fooder/domain/product.py +++ b/fooder/domain/product.py @@ -15,8 +15,8 @@ class Product(Base, CommonMixin): carb: Mapped[float] fat: Mapped[float] fiber: Mapped[float] - hard_coded_calories: Mapped[Optional[float]] = None - barcode: Mapped[Optional[str]] = None + hard_coded_calories: Mapped[Optional[float]] + barcode: Mapped[Optional[str]] @property def calories(self) -> float: diff --git a/fooder/domain/token.py b/fooder/domain/token.py index 2ca5a25..43b2e9c 100644 --- a/fooder/domain/token.py +++ b/fooder/domain/token.py @@ -47,18 +47,18 @@ class RefreshToken(Base, CommonMixin): :type token: str :rtype: "RefreshToken" """ - token = cls( + db_token = cls( user_id=user_id, token=token, ) - session.add(token) + session.add(db_token) try: await session.flush() except Exception: raise AssertionError("invalid token") - return token + return db_token async def delete(self, session: AsyncSession) -> None: """delete.