[mypy] fixed domain

This commit is contained in:
Piotr Domański 2024-05-20 13:45:21 +02:00
parent d9fd48a50e
commit 6ee8c68746
7 changed files with 33 additions and 25 deletions

View file

@ -17,6 +17,6 @@ class CommonMixin:
:rtype: str :rtype: str
""" """
return cls.__name__.lower() return cls.__name__.lower() # type: ignore
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)

View file

@ -3,7 +3,7 @@ from sqlalchemy import ForeignKey, Integer, Date
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.sql.selectable import Select from sqlalchemy.sql.selectable import Select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from datetime import date import datetime
from typing import Optional from typing import Optional
from .base import Base, CommonMixin from .base import Base, CommonMixin
@ -17,7 +17,7 @@ class Diary(Base, CommonMixin):
meals: Mapped[list[Meal]] = relationship( meals: Mapped[list[Meal]] = relationship(
lazy="selectin", order_by=Meal.order.desc() lazy="selectin", order_by=Meal.order.desc()
) )
date: Mapped[date] = mapped_column(Date) date: Mapped[datetime.date] = mapped_column(Date)
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id")) user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
@property @property
@ -74,14 +74,16 @@ class Diary(Base, CommonMixin):
@classmethod @classmethod
async def get_diary( async def get_diary(
cls, session: AsyncSession, user_id: int, date: date cls, session: AsyncSession, user_id: int, date: datetime.date
) -> "Optional[Diary]": ) -> "Optional[Diary]":
"""get_diary.""" """get_diary."""
query = cls.query(user_id).where(cls.date == date) query = cls.query(user_id).where(cls.date == date)
return await session.scalar(query) return await session.scalar(query)
@classmethod @classmethod
async def create(cls, session: AsyncSession, user_id: int, date: date) -> "Diary": async def create(
cls, session: AsyncSession, user_id: int, date: datetime.date
) -> "Diary":
diary = Diary( diary = Diary(
date=date, date=date,
user_id=user_id, user_id=user_id,
@ -93,12 +95,13 @@ class Diary(Base, CommonMixin):
except Exception: except Exception:
raise RuntimeError() raise RuntimeError()
diary = await cls.get_by_id(session, user_id, diary.id) db_diary = await cls.get_by_id(session, user_id, diary.id)
if not diary: if not db_diary:
raise RuntimeError() raise RuntimeError()
await Meal.create(session, diary.id)
return diary await Meal.create(session, db_diary.id)
return db_diary
@classmethod @classmethod
async def get_by_id( async def get_by_id(

View file

@ -87,10 +87,10 @@ class Entry(Base, CommonMixin):
except IntegrityError: except IntegrityError:
raise AssertionError("meal or product does not exist") raise AssertionError("meal or product does not exist")
entry = await cls._get_by_id(session, entry.id) db_entry = await cls._get_by_id(session, entry.id)
if not entry: if not db_entry:
raise RuntimeError() raise RuntimeError()
return entry return db_entry
async def update( async def update(
self, self,

View file

@ -84,10 +84,10 @@ class Meal(Base, CommonMixin):
except IntegrityError: except IntegrityError:
raise AssertionError("diary does not exist") raise AssertionError("diary does not exist")
meal = await cls._get_by_id(session, meal.id) db_meal = await cls._get_by_id(session, meal.id)
if not meal: if not db_meal:
raise RuntimeError() raise RuntimeError()
return meal return db_meal
@classmethod @classmethod
async def create_from_preset( async def create_from_preset(
@ -118,10 +118,10 @@ class Meal(Base, CommonMixin):
for entry in preset.entries: for entry in preset.entries:
await Entry.create(session, meal.id, entry.product_id, entry.grams) await Entry.create(session, meal.id, entry.product_id, entry.grams)
meal = await cls._get_by_id(session, meal.id) db_meal = await cls._get_by_id(session, meal.id)
if not meal: if not db_meal:
raise RuntimeError() raise RuntimeError()
return meal return db_meal
@classmethod @classmethod
async def _get_by_id(cls, session: AsyncSession, id: int) -> "Optional[Meal]": async def _get_by_id(cls, session: AsyncSession, id: int) -> "Optional[Meal]":

View file

@ -63,7 +63,7 @@ class Preset(Base, CommonMixin):
@classmethod @classmethod
async def create( async def create(
cls, session: AsyncSession, user_id: int, name: str, meal: "Meal" cls, session: AsyncSession, user_id: int, name: str, meal: "Meal"
) -> None: ) -> "Preset":
preset = Preset(user_id=user_id, name=name) preset = Preset(user_id=user_id, name=name)
session.add(preset) session.add(preset)
@ -76,7 +76,12 @@ class Preset(Base, CommonMixin):
for entry in meal.entries: for entry in meal.entries:
await PresetEntry.create(session, preset.id, entry) await PresetEntry.create(session, preset.id, entry)
return await cls.get(session, user_id, preset.id) db_preset = await cls.get(session, user_id, preset.id)
if not db_preset:
raise RuntimeError()
return db_preset
@classmethod @classmethod
async def list_all( async def list_all(

View file

@ -15,8 +15,8 @@ class Product(Base, CommonMixin):
carb: Mapped[float] carb: Mapped[float]
fat: Mapped[float] fat: Mapped[float]
fiber: Mapped[float] fiber: Mapped[float]
hard_coded_calories: Mapped[Optional[float]] = None hard_coded_calories: Mapped[Optional[float]]
barcode: Mapped[Optional[str]] = None barcode: Mapped[Optional[str]]
@property @property
def calories(self) -> float: def calories(self) -> float:

View file

@ -47,18 +47,18 @@ class RefreshToken(Base, CommonMixin):
:type token: str :type token: str
:rtype: "RefreshToken" :rtype: "RefreshToken"
""" """
token = cls( db_token = cls(
user_id=user_id, user_id=user_id,
token=token, token=token,
) )
session.add(token) session.add(db_token)
try: try:
await session.flush() await session.flush()
except Exception: except Exception:
raise AssertionError("invalid token") raise AssertionError("invalid token")
return token return db_token
async def delete(self, session: AsyncSession) -> None: async def delete(self, session: AsyncSession) -> None:
"""delete. """delete.