[mypy] fixed domain

This commit is contained in:
Piotr Domański 2024-05-20 13:45:21 +02:00
parent d9fd48a50e
commit 6ee8c68746
7 changed files with 33 additions and 25 deletions

View file

@ -17,6 +17,6 @@ class CommonMixin:
:rtype: str
"""
return cls.__name__.lower()
return cls.__name__.lower() # type: ignore
id: Mapped[int] = mapped_column(primary_key=True)

View file

@ -3,7 +3,7 @@ from sqlalchemy import ForeignKey, Integer, Date
from sqlalchemy import select
from sqlalchemy.sql.selectable import Select
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import date
import datetime
from typing import Optional
from .base import Base, CommonMixin
@ -17,7 +17,7 @@ class Diary(Base, CommonMixin):
meals: Mapped[list[Meal]] = relationship(
lazy="selectin", order_by=Meal.order.desc()
)
date: Mapped[date] = mapped_column(Date)
date: Mapped[datetime.date] = mapped_column(Date)
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
@property
@ -74,14 +74,16 @@ class Diary(Base, CommonMixin):
@classmethod
async def get_diary(
cls, session: AsyncSession, user_id: int, date: date
cls, session: AsyncSession, user_id: int, date: datetime.date
) -> "Optional[Diary]":
"""get_diary."""
query = cls.query(user_id).where(cls.date == date)
return await session.scalar(query)
@classmethod
async def create(cls, session: AsyncSession, user_id: int, date: date) -> "Diary":
async def create(
cls, session: AsyncSession, user_id: int, date: datetime.date
) -> "Diary":
diary = Diary(
date=date,
user_id=user_id,
@ -93,12 +95,13 @@ class Diary(Base, CommonMixin):
except Exception:
raise RuntimeError()
diary = await cls.get_by_id(session, user_id, diary.id)
db_diary = await cls.get_by_id(session, user_id, diary.id)
if not diary:
if not db_diary:
raise RuntimeError()
await Meal.create(session, diary.id)
return diary
await Meal.create(session, db_diary.id)
return db_diary
@classmethod
async def get_by_id(

View file

@ -87,10 +87,10 @@ class Entry(Base, CommonMixin):
except IntegrityError:
raise AssertionError("meal or product does not exist")
entry = await cls._get_by_id(session, entry.id)
if not entry:
db_entry = await cls._get_by_id(session, entry.id)
if not db_entry:
raise RuntimeError()
return entry
return db_entry
async def update(
self,

View file

@ -84,10 +84,10 @@ class Meal(Base, CommonMixin):
except IntegrityError:
raise AssertionError("diary does not exist")
meal = await cls._get_by_id(session, meal.id)
if not meal:
db_meal = await cls._get_by_id(session, meal.id)
if not db_meal:
raise RuntimeError()
return meal
return db_meal
@classmethod
async def create_from_preset(
@ -118,10 +118,10 @@ class Meal(Base, CommonMixin):
for entry in preset.entries:
await Entry.create(session, meal.id, entry.product_id, entry.grams)
meal = await cls._get_by_id(session, meal.id)
if not meal:
db_meal = await cls._get_by_id(session, meal.id)
if not db_meal:
raise RuntimeError()
return meal
return db_meal
@classmethod
async def _get_by_id(cls, session: AsyncSession, id: int) -> "Optional[Meal]":

View file

@ -63,7 +63,7 @@ class Preset(Base, CommonMixin):
@classmethod
async def create(
cls, session: AsyncSession, user_id: int, name: str, meal: "Meal"
) -> None:
) -> "Preset":
preset = Preset(user_id=user_id, name=name)
session.add(preset)
@ -76,7 +76,12 @@ class Preset(Base, CommonMixin):
for entry in meal.entries:
await PresetEntry.create(session, preset.id, entry)
return await cls.get(session, user_id, preset.id)
db_preset = await cls.get(session, user_id, preset.id)
if not db_preset:
raise RuntimeError()
return db_preset
@classmethod
async def list_all(

View file

@ -15,8 +15,8 @@ class Product(Base, CommonMixin):
carb: Mapped[float]
fat: Mapped[float]
fiber: Mapped[float]
hard_coded_calories: Mapped[Optional[float]] = None
barcode: Mapped[Optional[str]] = None
hard_coded_calories: Mapped[Optional[float]]
barcode: Mapped[Optional[str]]
@property
def calories(self) -> float:

View file

@ -47,18 +47,18 @@ class RefreshToken(Base, CommonMixin):
:type token: str
:rtype: "RefreshToken"
"""
token = cls(
db_token = cls(
user_id=user_id,
token=token,
)
session.add(token)
session.add(db_token)
try:
await session.flush()
except Exception:
raise AssertionError("invalid token")
return token
return db_token
async def delete(self, session: AsyncSession) -> None:
"""delete.