This commit is contained in:
Piotr Domański 2026-04-07 17:07:38 +02:00
parent 724c350e99
commit e34208a91b
6 changed files with 41 additions and 29 deletions

View file

@ -4,6 +4,7 @@ from typing import AsyncIterator, AsyncGenerator
from fooder.settings import Settings, settings
from sqlalchemy.ext.asyncio import (
AsyncConnection,
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
@ -12,7 +13,7 @@ from sqlalchemy.ext.asyncio import (
class DatabaseSessionManager:
def __init__(self, settings: Settings) -> None:
self._engine = create_async_engine(
self._engine: AsyncEngine | None = create_async_engine(
settings.DB_URI,
pool_pre_ping=True,
echo=settings.ECHO_SQL,
@ -22,7 +23,7 @@ class DatabaseSessionManager:
else {}
),
)
self._sessionmaker = async_sessionmaker(
self._sessionmaker: async_sessionmaker[AsyncSession] | None = async_sessionmaker(
autocommit=False,
autoflush=False,
bind=self._engine,

View file

@ -1,4 +1,11 @@
from pydantic import ConfigDict
from typing import Annotated
from pydantic import ConfigDict, Field
Macronutrient = Annotated[float, Field(ge=0, le=100)]
OptionalMacronutrient = Annotated[float | None, Field(default=None, ge=0, le=100)]
Calories = Annotated[float, Field(ge=0)]
OptionalCalories = Annotated[float | None, Field(default=None, ge=0)]
class ObjModelMixin:

View file

@ -1,24 +1,25 @@
from .base import ObjModelMixin
from pydantic import BaseModel, Field
from .base import ObjModelMixin, Macronutrient, OptionalMacronutrient, Calories, OptionalCalories
from pydantic import BaseModel
from fooder.utils.calories import calculate_calories
class ProductModelBase(BaseModel):
class ProductModel(ObjModelMixin, BaseModel):
name: str
protein: float = Field(ge=0, le=100)
carb: float = Field(ge=0, le=100)
fat: float = Field(ge=0, le=100)
fiber: float = Field(ge=0, le=100)
calories: float = Field(ge=0)
barcode: str | None
protein: Macronutrient
carb: Macronutrient
fat: Macronutrient
fiber: Macronutrient
calories: Calories
barcode: str | None = None
class ProductModel(ObjModelMixin, ProductModelBase):
pass
class ProductCreateModel(ProductModelBase):
calories: float | None = None
class ProductCreateModel(BaseModel):
name: str
protein: Macronutrient
carb: Macronutrient
fat: Macronutrient
fiber: Macronutrient
calories: OptionalCalories = None
barcode: str | None = None
@property
@ -35,11 +36,11 @@ class ProductCreateModel(ProductModelBase):
)
class ProductUpdateModel(ProductModelBase):
class ProductUpdateModel(BaseModel):
name: str | None = None
protein: float | None = None
carb: float | None = None
fat: float | None = None
fiber: float | None = None
calories: float | None = None
protein: OptionalMacronutrient = None
carb: OptionalMacronutrient = None
fat: OptionalMacronutrient = None
fiber: OptionalMacronutrient = None
calories: OptionalCalories = None
barcode: str | None = None

View file

@ -1,6 +1,6 @@
from typing import TypeVar, Generic, Type, Sequence
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete as sa_delete, update as sa_update, ColumnElement
from sqlalchemy import Delete, Update, select, delete as sa_delete, update as sa_update, ColumnElement
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.exc import StaleDataError
from sqlalchemy.sql import Select
@ -86,12 +86,12 @@ class RepositoryBase(Generic[T]):
return obj
async def _delete(self, *expressions: ColumnElement):
stmt: Update | Delete
if hasattr(self.model, "deleted_at"):
stmt = sa_update(self.model).values(deleted_at=utc_now()) # type: ignore[attr-defined]
stmt = sa_update(self.model).values(deleted_at=utc_now())
else:
stmt = sa_delete(self.model)
if expressions:
stmt = stmt.where(*expressions)
await self.session.execute(stmt)

View file

@ -1,4 +1,7 @@
from typing import cast
from sqlalchemy import update as sa_update
from sqlalchemy.engine import CursorResult
from fooder.domain.user_product_usage import UserProductUsage
from fooder.repository.base import RepositoryBase
@ -14,7 +17,7 @@ class UserProductUsageRepository(RepositoryBase[UserProductUsage]):
)
.values(count=UserProductUsage.count + count)
)
result = await self.session.execute(stmt)
result = cast(CursorResult, await self.session.execute(stmt))
if result.rowcount == 0:
obj = UserProductUsage(user_id=user_id, product_id=product_id, count=count)

View file

@ -2,7 +2,7 @@
[mypy]
plugins = sqlalchemy.ext.mypy.plugin,pydantic.mypy
exclude = .*/test/.*
exclude = .*/test/.*|fooder/alembic/.*
pretty = True