[mypy] fixed domain
This commit is contained in:
		
							parent
							
								
									d9fd48a50e
								
							
						
					
					
						commit
						6ee8c68746
					
				
					 7 changed files with 33 additions and 25 deletions
				
			
		| 
						 | 
					@ -17,6 +17,6 @@ class CommonMixin:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        :rtype: str
 | 
					        :rtype: str
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        return cls.__name__.lower()
 | 
					        return cls.__name__.lower()  # type: ignore
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    id: Mapped[int] = mapped_column(primary_key=True)
 | 
					    id: Mapped[int] = mapped_column(primary_key=True)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -3,7 +3,7 @@ from sqlalchemy import ForeignKey, Integer, Date
 | 
				
			||||||
from sqlalchemy import select
 | 
					from sqlalchemy import select
 | 
				
			||||||
from sqlalchemy.sql.selectable import Select
 | 
					from sqlalchemy.sql.selectable import Select
 | 
				
			||||||
from sqlalchemy.ext.asyncio import AsyncSession
 | 
					from sqlalchemy.ext.asyncio import AsyncSession
 | 
				
			||||||
from datetime import date
 | 
					import datetime
 | 
				
			||||||
from typing import Optional
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .base import Base, CommonMixin
 | 
					from .base import Base, CommonMixin
 | 
				
			||||||
| 
						 | 
					@ -17,7 +17,7 @@ class Diary(Base, CommonMixin):
 | 
				
			||||||
    meals: Mapped[list[Meal]] = relationship(
 | 
					    meals: Mapped[list[Meal]] = relationship(
 | 
				
			||||||
        lazy="selectin", order_by=Meal.order.desc()
 | 
					        lazy="selectin", order_by=Meal.order.desc()
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    date: Mapped[date] = mapped_column(Date)
 | 
					    date: Mapped[datetime.date] = mapped_column(Date)
 | 
				
			||||||
    user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
 | 
					    user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
| 
						 | 
					@ -74,14 +74,16 @@ class Diary(Base, CommonMixin):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    async def get_diary(
 | 
					    async def get_diary(
 | 
				
			||||||
        cls, session: AsyncSession, user_id: int, date: date
 | 
					        cls, session: AsyncSession, user_id: int, date: datetime.date
 | 
				
			||||||
    ) -> "Optional[Diary]":
 | 
					    ) -> "Optional[Diary]":
 | 
				
			||||||
        """get_diary."""
 | 
					        """get_diary."""
 | 
				
			||||||
        query = cls.query(user_id).where(cls.date == date)
 | 
					        query = cls.query(user_id).where(cls.date == date)
 | 
				
			||||||
        return await session.scalar(query)
 | 
					        return await session.scalar(query)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    async def create(cls, session: AsyncSession, user_id: int, date: date) -> "Diary":
 | 
					    async def create(
 | 
				
			||||||
 | 
					        cls, session: AsyncSession, user_id: int, date: datetime.date
 | 
				
			||||||
 | 
					    ) -> "Diary":
 | 
				
			||||||
        diary = Diary(
 | 
					        diary = Diary(
 | 
				
			||||||
            date=date,
 | 
					            date=date,
 | 
				
			||||||
            user_id=user_id,
 | 
					            user_id=user_id,
 | 
				
			||||||
| 
						 | 
					@ -93,12 +95,13 @@ class Diary(Base, CommonMixin):
 | 
				
			||||||
        except Exception:
 | 
					        except Exception:
 | 
				
			||||||
            raise RuntimeError()
 | 
					            raise RuntimeError()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        diary = await cls.get_by_id(session, user_id, diary.id)
 | 
					        db_diary = await cls.get_by_id(session, user_id, diary.id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if not diary:
 | 
					        if not db_diary:
 | 
				
			||||||
            raise RuntimeError()
 | 
					            raise RuntimeError()
 | 
				
			||||||
        await Meal.create(session, diary.id)
 | 
					
 | 
				
			||||||
        return diary
 | 
					        await Meal.create(session, db_diary.id)
 | 
				
			||||||
 | 
					        return db_diary
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    async def get_by_id(
 | 
					    async def get_by_id(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -87,10 +87,10 @@ class Entry(Base, CommonMixin):
 | 
				
			||||||
        except IntegrityError:
 | 
					        except IntegrityError:
 | 
				
			||||||
            raise AssertionError("meal or product does not exist")
 | 
					            raise AssertionError("meal or product does not exist")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        entry = await cls._get_by_id(session, entry.id)
 | 
					        db_entry = await cls._get_by_id(session, entry.id)
 | 
				
			||||||
        if not entry:
 | 
					        if not db_entry:
 | 
				
			||||||
            raise RuntimeError()
 | 
					            raise RuntimeError()
 | 
				
			||||||
        return entry
 | 
					        return db_entry
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def update(
 | 
					    async def update(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -84,10 +84,10 @@ class Meal(Base, CommonMixin):
 | 
				
			||||||
        except IntegrityError:
 | 
					        except IntegrityError:
 | 
				
			||||||
            raise AssertionError("diary does not exist")
 | 
					            raise AssertionError("diary does not exist")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        meal = await cls._get_by_id(session, meal.id)
 | 
					        db_meal = await cls._get_by_id(session, meal.id)
 | 
				
			||||||
        if not meal:
 | 
					        if not db_meal:
 | 
				
			||||||
            raise RuntimeError()
 | 
					            raise RuntimeError()
 | 
				
			||||||
        return meal
 | 
					        return db_meal
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    async def create_from_preset(
 | 
					    async def create_from_preset(
 | 
				
			||||||
| 
						 | 
					@ -118,10 +118,10 @@ class Meal(Base, CommonMixin):
 | 
				
			||||||
        for entry in preset.entries:
 | 
					        for entry in preset.entries:
 | 
				
			||||||
            await Entry.create(session, meal.id, entry.product_id, entry.grams)
 | 
					            await Entry.create(session, meal.id, entry.product_id, entry.grams)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        meal = await cls._get_by_id(session, meal.id)
 | 
					        db_meal = await cls._get_by_id(session, meal.id)
 | 
				
			||||||
        if not meal:
 | 
					        if not db_meal:
 | 
				
			||||||
            raise RuntimeError()
 | 
					            raise RuntimeError()
 | 
				
			||||||
        return meal
 | 
					        return db_meal
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    async def _get_by_id(cls, session: AsyncSession, id: int) -> "Optional[Meal]":
 | 
					    async def _get_by_id(cls, session: AsyncSession, id: int) -> "Optional[Meal]":
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -63,7 +63,7 @@ class Preset(Base, CommonMixin):
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    async def create(
 | 
					    async def create(
 | 
				
			||||||
        cls, session: AsyncSession, user_id: int, name: str, meal: "Meal"
 | 
					        cls, session: AsyncSession, user_id: int, name: str, meal: "Meal"
 | 
				
			||||||
    ) -> None:
 | 
					    ) -> "Preset":
 | 
				
			||||||
        preset = Preset(user_id=user_id, name=name)
 | 
					        preset = Preset(user_id=user_id, name=name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        session.add(preset)
 | 
					        session.add(preset)
 | 
				
			||||||
| 
						 | 
					@ -76,7 +76,12 @@ class Preset(Base, CommonMixin):
 | 
				
			||||||
        for entry in meal.entries:
 | 
					        for entry in meal.entries:
 | 
				
			||||||
            await PresetEntry.create(session, preset.id, entry)
 | 
					            await PresetEntry.create(session, preset.id, entry)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return await cls.get(session, user_id, preset.id)
 | 
					        db_preset = await cls.get(session, user_id, preset.id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if not db_preset:
 | 
				
			||||||
 | 
					            raise RuntimeError()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return db_preset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    async def list_all(
 | 
					    async def list_all(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -15,8 +15,8 @@ class Product(Base, CommonMixin):
 | 
				
			||||||
    carb: Mapped[float]
 | 
					    carb: Mapped[float]
 | 
				
			||||||
    fat: Mapped[float]
 | 
					    fat: Mapped[float]
 | 
				
			||||||
    fiber: Mapped[float]
 | 
					    fiber: Mapped[float]
 | 
				
			||||||
    hard_coded_calories: Mapped[Optional[float]] = None
 | 
					    hard_coded_calories: Mapped[Optional[float]]
 | 
				
			||||||
    barcode: Mapped[Optional[str]] = None
 | 
					    barcode: Mapped[Optional[str]]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def calories(self) -> float:
 | 
					    def calories(self) -> float:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -47,18 +47,18 @@ class RefreshToken(Base, CommonMixin):
 | 
				
			||||||
        :type token: str
 | 
					        :type token: str
 | 
				
			||||||
        :rtype: "RefreshToken"
 | 
					        :rtype: "RefreshToken"
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        token = cls(
 | 
					        db_token = cls(
 | 
				
			||||||
            user_id=user_id,
 | 
					            user_id=user_id,
 | 
				
			||||||
            token=token,
 | 
					            token=token,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        session.add(token)
 | 
					        session.add(db_token)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            await session.flush()
 | 
					            await session.flush()
 | 
				
			||||||
        except Exception:
 | 
					        except Exception:
 | 
				
			||||||
            raise AssertionError("invalid token")
 | 
					            raise AssertionError("invalid token")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return token
 | 
					        return db_token
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def delete(self, session: AsyncSession) -> None:
 | 
					    async def delete(self, session: AsyncSession) -> None:
 | 
				
			||||||
        """delete.
 | 
					        """delete.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue