[mypy] fixed domain
This commit is contained in:
parent
d9fd48a50e
commit
6ee8c68746
7 changed files with 33 additions and 25 deletions
|
@ -17,6 +17,6 @@ class CommonMixin:
|
|||
|
||||
:rtype: str
|
||||
"""
|
||||
return cls.__name__.lower()
|
||||
return cls.__name__.lower() # type: ignore
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
|
|
|
@ -3,7 +3,7 @@ from sqlalchemy import ForeignKey, Integer, Date
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.sql.selectable import Select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from datetime import date
|
||||
import datetime
|
||||
from typing import Optional
|
||||
|
||||
from .base import Base, CommonMixin
|
||||
|
@ -17,7 +17,7 @@ class Diary(Base, CommonMixin):
|
|||
meals: Mapped[list[Meal]] = relationship(
|
||||
lazy="selectin", order_by=Meal.order.desc()
|
||||
)
|
||||
date: Mapped[date] = mapped_column(Date)
|
||||
date: Mapped[datetime.date] = mapped_column(Date)
|
||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
|
||||
|
||||
@property
|
||||
|
@ -74,14 +74,16 @@ class Diary(Base, CommonMixin):
|
|||
|
||||
@classmethod
|
||||
async def get_diary(
|
||||
cls, session: AsyncSession, user_id: int, date: date
|
||||
cls, session: AsyncSession, user_id: int, date: datetime.date
|
||||
) -> "Optional[Diary]":
|
||||
"""get_diary."""
|
||||
query = cls.query(user_id).where(cls.date == date)
|
||||
return await session.scalar(query)
|
||||
|
||||
@classmethod
|
||||
async def create(cls, session: AsyncSession, user_id: int, date: date) -> "Diary":
|
||||
async def create(
|
||||
cls, session: AsyncSession, user_id: int, date: datetime.date
|
||||
) -> "Diary":
|
||||
diary = Diary(
|
||||
date=date,
|
||||
user_id=user_id,
|
||||
|
@ -93,12 +95,13 @@ class Diary(Base, CommonMixin):
|
|||
except Exception:
|
||||
raise RuntimeError()
|
||||
|
||||
diary = await cls.get_by_id(session, user_id, diary.id)
|
||||
db_diary = await cls.get_by_id(session, user_id, diary.id)
|
||||
|
||||
if not diary:
|
||||
if not db_diary:
|
||||
raise RuntimeError()
|
||||
await Meal.create(session, diary.id)
|
||||
return diary
|
||||
|
||||
await Meal.create(session, db_diary.id)
|
||||
return db_diary
|
||||
|
||||
@classmethod
|
||||
async def get_by_id(
|
||||
|
|
|
@ -87,10 +87,10 @@ class Entry(Base, CommonMixin):
|
|||
except IntegrityError:
|
||||
raise AssertionError("meal or product does not exist")
|
||||
|
||||
entry = await cls._get_by_id(session, entry.id)
|
||||
if not entry:
|
||||
db_entry = await cls._get_by_id(session, entry.id)
|
||||
if not db_entry:
|
||||
raise RuntimeError()
|
||||
return entry
|
||||
return db_entry
|
||||
|
||||
async def update(
|
||||
self,
|
||||
|
|
|
@ -84,10 +84,10 @@ class Meal(Base, CommonMixin):
|
|||
except IntegrityError:
|
||||
raise AssertionError("diary does not exist")
|
||||
|
||||
meal = await cls._get_by_id(session, meal.id)
|
||||
if not meal:
|
||||
db_meal = await cls._get_by_id(session, meal.id)
|
||||
if not db_meal:
|
||||
raise RuntimeError()
|
||||
return meal
|
||||
return db_meal
|
||||
|
||||
@classmethod
|
||||
async def create_from_preset(
|
||||
|
@ -118,10 +118,10 @@ class Meal(Base, CommonMixin):
|
|||
for entry in preset.entries:
|
||||
await Entry.create(session, meal.id, entry.product_id, entry.grams)
|
||||
|
||||
meal = await cls._get_by_id(session, meal.id)
|
||||
if not meal:
|
||||
db_meal = await cls._get_by_id(session, meal.id)
|
||||
if not db_meal:
|
||||
raise RuntimeError()
|
||||
return meal
|
||||
return db_meal
|
||||
|
||||
@classmethod
|
||||
async def _get_by_id(cls, session: AsyncSession, id: int) -> "Optional[Meal]":
|
||||
|
|
|
@ -63,7 +63,7 @@ class Preset(Base, CommonMixin):
|
|||
@classmethod
|
||||
async def create(
|
||||
cls, session: AsyncSession, user_id: int, name: str, meal: "Meal"
|
||||
) -> None:
|
||||
) -> "Preset":
|
||||
preset = Preset(user_id=user_id, name=name)
|
||||
|
||||
session.add(preset)
|
||||
|
@ -76,7 +76,12 @@ class Preset(Base, CommonMixin):
|
|||
for entry in meal.entries:
|
||||
await PresetEntry.create(session, preset.id, entry)
|
||||
|
||||
return await cls.get(session, user_id, preset.id)
|
||||
db_preset = await cls.get(session, user_id, preset.id)
|
||||
|
||||
if not db_preset:
|
||||
raise RuntimeError()
|
||||
|
||||
return db_preset
|
||||
|
||||
@classmethod
|
||||
async def list_all(
|
||||
|
|
|
@ -15,8 +15,8 @@ class Product(Base, CommonMixin):
|
|||
carb: Mapped[float]
|
||||
fat: Mapped[float]
|
||||
fiber: Mapped[float]
|
||||
hard_coded_calories: Mapped[Optional[float]] = None
|
||||
barcode: Mapped[Optional[str]] = None
|
||||
hard_coded_calories: Mapped[Optional[float]]
|
||||
barcode: Mapped[Optional[str]]
|
||||
|
||||
@property
|
||||
def calories(self) -> float:
|
||||
|
|
|
@ -47,18 +47,18 @@ class RefreshToken(Base, CommonMixin):
|
|||
:type token: str
|
||||
:rtype: "RefreshToken"
|
||||
"""
|
||||
token = cls(
|
||||
db_token = cls(
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
)
|
||||
session.add(token)
|
||||
session.add(db_token)
|
||||
|
||||
try:
|
||||
await session.flush()
|
||||
except Exception:
|
||||
raise AssertionError("invalid token")
|
||||
|
||||
return token
|
||||
return db_token
|
||||
|
||||
async def delete(self, session: AsyncSession) -> None:
|
||||
"""delete.
|
||||
|
|
Loading…
Reference in a new issue