[mypy]
This commit is contained in:
parent
724c350e99
commit
e34208a91b
6 changed files with 41 additions and 29 deletions
|
|
@ -4,6 +4,7 @@ from typing import AsyncIterator, AsyncGenerator
|
||||||
from fooder.settings import Settings, settings
|
from fooder.settings import Settings, settings
|
||||||
from sqlalchemy.ext.asyncio import (
|
from sqlalchemy.ext.asyncio import (
|
||||||
AsyncConnection,
|
AsyncConnection,
|
||||||
|
AsyncEngine,
|
||||||
AsyncSession,
|
AsyncSession,
|
||||||
async_sessionmaker,
|
async_sessionmaker,
|
||||||
create_async_engine,
|
create_async_engine,
|
||||||
|
|
@ -12,7 +13,7 @@ from sqlalchemy.ext.asyncio import (
|
||||||
|
|
||||||
class DatabaseSessionManager:
|
class DatabaseSessionManager:
|
||||||
def __init__(self, settings: Settings) -> None:
|
def __init__(self, settings: Settings) -> None:
|
||||||
self._engine = create_async_engine(
|
self._engine: AsyncEngine | None = create_async_engine(
|
||||||
settings.DB_URI,
|
settings.DB_URI,
|
||||||
pool_pre_ping=True,
|
pool_pre_ping=True,
|
||||||
echo=settings.ECHO_SQL,
|
echo=settings.ECHO_SQL,
|
||||||
|
|
@ -22,7 +23,7 @@ class DatabaseSessionManager:
|
||||||
else {}
|
else {}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self._sessionmaker = async_sessionmaker(
|
self._sessionmaker: async_sessionmaker[AsyncSession] | None = async_sessionmaker(
|
||||||
autocommit=False,
|
autocommit=False,
|
||||||
autoflush=False,
|
autoflush=False,
|
||||||
bind=self._engine,
|
bind=self._engine,
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,11 @@
|
||||||
from pydantic import ConfigDict
|
from typing import Annotated
|
||||||
|
|
||||||
|
from pydantic import ConfigDict, Field
|
||||||
|
|
||||||
|
Macronutrient = Annotated[float, Field(ge=0, le=100)]
|
||||||
|
OptionalMacronutrient = Annotated[float | None, Field(default=None, ge=0, le=100)]
|
||||||
|
Calories = Annotated[float, Field(ge=0)]
|
||||||
|
OptionalCalories = Annotated[float | None, Field(default=None, ge=0)]
|
||||||
|
|
||||||
|
|
||||||
class ObjModelMixin:
|
class ObjModelMixin:
|
||||||
|
|
|
||||||
|
|
@ -1,24 +1,25 @@
|
||||||
from .base import ObjModelMixin
|
from .base import ObjModelMixin, Macronutrient, OptionalMacronutrient, Calories, OptionalCalories
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel
|
||||||
from fooder.utils.calories import calculate_calories
|
from fooder.utils.calories import calculate_calories
|
||||||
|
|
||||||
|
|
||||||
class ProductModelBase(BaseModel):
|
class ProductModel(ObjModelMixin, BaseModel):
|
||||||
name: str
|
name: str
|
||||||
protein: float = Field(ge=0, le=100)
|
protein: Macronutrient
|
||||||
carb: float = Field(ge=0, le=100)
|
carb: Macronutrient
|
||||||
fat: float = Field(ge=0, le=100)
|
fat: Macronutrient
|
||||||
fiber: float = Field(ge=0, le=100)
|
fiber: Macronutrient
|
||||||
calories: float = Field(ge=0)
|
calories: Calories
|
||||||
barcode: str | None
|
barcode: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class ProductModel(ObjModelMixin, ProductModelBase):
|
class ProductCreateModel(BaseModel):
|
||||||
pass
|
name: str
|
||||||
|
protein: Macronutrient
|
||||||
|
carb: Macronutrient
|
||||||
class ProductCreateModel(ProductModelBase):
|
fat: Macronutrient
|
||||||
calories: float | None = None
|
fiber: Macronutrient
|
||||||
|
calories: OptionalCalories = None
|
||||||
barcode: str | None = None
|
barcode: str | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -35,11 +36,11 @@ class ProductCreateModel(ProductModelBase):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ProductUpdateModel(ProductModelBase):
|
class ProductUpdateModel(BaseModel):
|
||||||
name: str | None = None
|
name: str | None = None
|
||||||
protein: float | None = None
|
protein: OptionalMacronutrient = None
|
||||||
carb: float | None = None
|
carb: OptionalMacronutrient = None
|
||||||
fat: float | None = None
|
fat: OptionalMacronutrient = None
|
||||||
fiber: float | None = None
|
fiber: OptionalMacronutrient = None
|
||||||
calories: float | None = None
|
calories: OptionalCalories = None
|
||||||
barcode: str | None = None
|
barcode: str | None = None
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from typing import TypeVar, Generic, Type, Sequence
|
from typing import TypeVar, Generic, Type, Sequence
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy import select, delete as sa_delete, update as sa_update, ColumnElement
|
from sqlalchemy import Delete, Update, select, delete as sa_delete, update as sa_update, ColumnElement
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.orm.exc import StaleDataError
|
from sqlalchemy.orm.exc import StaleDataError
|
||||||
from sqlalchemy.sql import Select
|
from sqlalchemy.sql import Select
|
||||||
|
|
@ -86,12 +86,12 @@ class RepositoryBase(Generic[T]):
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
async def _delete(self, *expressions: ColumnElement):
|
async def _delete(self, *expressions: ColumnElement):
|
||||||
|
stmt: Update | Delete
|
||||||
if hasattr(self.model, "deleted_at"):
|
if hasattr(self.model, "deleted_at"):
|
||||||
stmt = sa_update(self.model).values(deleted_at=utc_now()) # type: ignore[attr-defined]
|
stmt = sa_update(self.model).values(deleted_at=utc_now())
|
||||||
else:
|
else:
|
||||||
stmt = sa_delete(self.model)
|
stmt = sa_delete(self.model)
|
||||||
|
|
||||||
if expressions:
|
if expressions:
|
||||||
stmt = stmt.where(*expressions)
|
stmt = stmt.where(*expressions)
|
||||||
|
|
||||||
await self.session.execute(stmt)
|
await self.session.execute(stmt)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,7 @@
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
from sqlalchemy import update as sa_update
|
from sqlalchemy import update as sa_update
|
||||||
|
from sqlalchemy.engine import CursorResult
|
||||||
|
|
||||||
from fooder.domain.user_product_usage import UserProductUsage
|
from fooder.domain.user_product_usage import UserProductUsage
|
||||||
from fooder.repository.base import RepositoryBase
|
from fooder.repository.base import RepositoryBase
|
||||||
|
|
@ -14,7 +17,7 @@ class UserProductUsageRepository(RepositoryBase[UserProductUsage]):
|
||||||
)
|
)
|
||||||
.values(count=UserProductUsage.count + count)
|
.values(count=UserProductUsage.count + count)
|
||||||
)
|
)
|
||||||
result = await self.session.execute(stmt)
|
result = cast(CursorResult, await self.session.execute(stmt))
|
||||||
|
|
||||||
if result.rowcount == 0:
|
if result.rowcount == 0:
|
||||||
obj = UserProductUsage(user_id=user_id, product_id=product_id, count=count)
|
obj = UserProductUsage(user_id=user_id, product_id=product_id, count=count)
|
||||||
|
|
|
||||||
2
mypy.ini
2
mypy.ini
|
|
@ -2,7 +2,7 @@
|
||||||
[mypy]
|
[mypy]
|
||||||
plugins = sqlalchemy.ext.mypy.plugin,pydantic.mypy
|
plugins = sqlalchemy.ext.mypy.plugin,pydantic.mypy
|
||||||
|
|
||||||
exclude = .*/test/.*
|
exclude = .*/test/.*|fooder/alembic/.*
|
||||||
|
|
||||||
pretty = True
|
pretty = True
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue