diff --git a/fooder/db.py b/fooder/db.py index 87ce60b..767387f 100644 --- a/fooder/db.py +++ b/fooder/db.py @@ -4,6 +4,7 @@ from typing import AsyncIterator, AsyncGenerator from fooder.settings import Settings, settings from sqlalchemy.ext.asyncio import ( AsyncConnection, + AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine, @@ -12,7 +13,7 @@ from sqlalchemy.ext.asyncio import ( class DatabaseSessionManager: def __init__(self, settings: Settings) -> None: - self._engine = create_async_engine( + self._engine: AsyncEngine | None = create_async_engine( settings.DB_URI, pool_pre_ping=True, echo=settings.ECHO_SQL, @@ -22,7 +23,7 @@ class DatabaseSessionManager: else {} ), ) - self._sessionmaker = async_sessionmaker( + self._sessionmaker: async_sessionmaker[AsyncSession] | None = async_sessionmaker( autocommit=False, autoflush=False, bind=self._engine, diff --git a/fooder/model/base.py b/fooder/model/base.py index a0d3554..4ec948f 100644 --- a/fooder/model/base.py +++ b/fooder/model/base.py @@ -1,4 +1,11 @@ -from pydantic import ConfigDict +from typing import Annotated + +from pydantic import ConfigDict, Field + +Macronutrient = Annotated[float, Field(ge=0, le=100)] +OptionalMacronutrient = Annotated[float | None, Field(default=None, ge=0, le=100)] +Calories = Annotated[float, Field(ge=0)] +OptionalCalories = Annotated[float | None, Field(default=None, ge=0)] class ObjModelMixin: diff --git a/fooder/model/product.py b/fooder/model/product.py index dfc0a29..337b7a0 100644 --- a/fooder/model/product.py +++ b/fooder/model/product.py @@ -1,24 +1,25 @@ -from .base import ObjModelMixin -from pydantic import BaseModel, Field +from .base import ObjModelMixin, Macronutrient, OptionalMacronutrient, Calories, OptionalCalories +from pydantic import BaseModel from fooder.utils.calories import calculate_calories -class ProductModelBase(BaseModel): +class ProductModel(ObjModelMixin, BaseModel): name: str - protein: float = Field(ge=0, le=100) - carb: float = Field(ge=0, le=100) - fat: float = Field(ge=0, le=100) - fiber: float = Field(ge=0, le=100) - calories: float = Field(ge=0) - barcode: str | None + protein: Macronutrient + carb: Macronutrient + fat: Macronutrient + fiber: Macronutrient + calories: Calories + barcode: str | None = None -class ProductModel(ObjModelMixin, ProductModelBase): - pass - - -class ProductCreateModel(ProductModelBase): - calories: float | None = None +class ProductCreateModel(BaseModel): + name: str + protein: Macronutrient + carb: Macronutrient + fat: Macronutrient + fiber: Macronutrient + calories: OptionalCalories = None barcode: str | None = None @property @@ -35,11 +36,11 @@ class ProductCreateModel(ProductModelBase): ) -class ProductUpdateModel(ProductModelBase): +class ProductUpdateModel(BaseModel): name: str | None = None - protein: float | None = None - carb: float | None = None - fat: float | None = None - fiber: float | None = None - calories: float | None = None + protein: OptionalMacronutrient = None + carb: OptionalMacronutrient = None + fat: OptionalMacronutrient = None + fiber: OptionalMacronutrient = None + calories: OptionalCalories = None barcode: str | None = None diff --git a/fooder/repository/base.py b/fooder/repository/base.py index dc40b6b..f5d6b64 100644 --- a/fooder/repository/base.py +++ b/fooder/repository/base.py @@ -1,6 +1,6 @@ from typing import TypeVar, Generic, Type, Sequence from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import select, delete as sa_delete, update as sa_update, ColumnElement +from sqlalchemy import Delete, Update, select, delete as sa_delete, update as sa_update, ColumnElement from sqlalchemy.exc import IntegrityError from sqlalchemy.orm.exc import StaleDataError from sqlalchemy.sql import Select @@ -86,12 +86,12 @@ class RepositoryBase(Generic[T]): return obj async def _delete(self, *expressions: ColumnElement): + stmt: Update | Delete if hasattr(self.model, "deleted_at"): - stmt = sa_update(self.model).values(deleted_at=utc_now()) # type: ignore[attr-defined] + stmt = sa_update(self.model).values(deleted_at=utc_now()) else: stmt = sa_delete(self.model) if expressions: stmt = stmt.where(*expressions) - await self.session.execute(stmt) diff --git a/fooder/repository/user_product_usage.py b/fooder/repository/user_product_usage.py index 19bf84c..0d4b284 100644 --- a/fooder/repository/user_product_usage.py +++ b/fooder/repository/user_product_usage.py @@ -1,4 +1,7 @@ +from typing import cast + from sqlalchemy import update as sa_update +from sqlalchemy.engine import CursorResult from fooder.domain.user_product_usage import UserProductUsage from fooder.repository.base import RepositoryBase @@ -14,7 +17,7 @@ class UserProductUsageRepository(RepositoryBase[UserProductUsage]): ) .values(count=UserProductUsage.count + count) ) - result = await self.session.execute(stmt) + result = cast(CursorResult, await self.session.execute(stmt)) if result.rowcount == 0: obj = UserProductUsage(user_id=user_id, product_id=product_id, count=count) diff --git a/mypy.ini b/mypy.ini index 8e19889..5d2114e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -2,7 +2,7 @@ [mypy] plugins = sqlalchemy.ext.mypy.plugin,pydantic.mypy -exclude = .*/test/.* +exclude = .*/test/.*|fooder/alembic/.* pretty = True