diff --git a/fooder/domain/__init__.py b/fooder/domain/__init__.py index be490a2..39fc3a5 100644 --- a/fooder/domain/__init__.py +++ b/fooder/domain/__init__.py @@ -4,5 +4,7 @@ from .entry import Entry # noqa from .meal import Meal # noqa from .product import Product # noqa from .user import User # noqa +from .user_product_usage import UserProductUsage # noqa +from .user_settings import UserSettings # noqa from .preset import Preset # noqa from .preset_entry import PresetEntry # noqa diff --git a/fooder/domain/base.py b/fooder/domain/base.py index fb4b556..b724030 100644 --- a/fooder/domain/base.py +++ b/fooder/domain/base.py @@ -1,34 +1,31 @@ +from datetime import datetime + +from sqlalchemy import DateTime from sqlalchemy.orm import DeclarativeBase, Mapped, declared_attr, mapped_column + +from fooder.utils.datetime import utc_now from fooder.utils.password_helper import password_helper class Base(DeclarativeBase): - """Base from DeclarativeBase""" - pass class CommonMixin: - """ - CommonMixin for all common fields in projetc - """ - @declared_attr.directive def __tablename__(cls) -> str: - """__tablename__. - - :rtype: str - """ return cls.__name__.lower() # type: ignore id: Mapped[int] = mapped_column(primary_key=True) + created_at: Mapped[datetime] = mapped_column(DateTime, default=utc_now) + last_changed: Mapped[datetime] = mapped_column(DateTime, default=utc_now, onupdate=utc_now) + + +class SoftDeleteMixin: + deleted_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) class PasswordMixin: - """ - PasswordMixin for entities with password - """ - hashed_password: Mapped[str] def set_password(self, password: str) -> None: @@ -36,3 +33,55 @@ class PasswordMixin: def verify_password(self, password: str) -> bool: return password_helper.verify(password, self.hashed_password) + + +class EntryMacrosMixin: + """Computed macros for entry-like models that scale product macros by grams.""" + + @property + def amount(self) -> float: + return self.grams / 100 # type: ignore[attr-defined] + + @property + def calories(self) -> float: + return self.amount * self.product.calories # type: ignore[attr-defined] + + @property + def protein(self) -> float: + return self.amount * self.product.protein # type: ignore[attr-defined] + + @property + def carb(self) -> float: + return self.amount * self.product.carb # type: ignore[attr-defined] + + @property + def fat(self) -> float: + return self.amount * self.product.fat # type: ignore[attr-defined] + + @property + def fiber(self) -> float: + return self.amount * self.product.fiber # type: ignore[attr-defined] + + +class AggregateMacrosMixin: + """Computed macros for models that sum macros across child entries.""" + + @property + def calories(self) -> float: + return sum(e.calories for e in self.entries) # type: ignore[attr-defined] + + @property + def protein(self) -> float: + return sum(e.protein for e in self.entries) # type: ignore[attr-defined] + + @property + def carb(self) -> float: + return sum(e.carb for e in self.entries) # type: ignore[attr-defined] + + @property + def fat(self) -> float: + return sum(e.fat for e in self.entries) # type: ignore[attr-defined] + + @property + def fiber(self) -> float: + return sum(e.fiber for e in self.entries) # type: ignore[attr-defined] diff --git a/fooder/domain/diary.py b/fooder/domain/diary.py index 900cf72..c4db9ee 100644 --- a/fooder/domain/diary.py +++ b/fooder/domain/diary.py @@ -1,119 +1,44 @@ import datetime -from typing import Optional -from sqlalchemy import Date, ForeignKey, Integer, select -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Mapped, joinedload, mapped_column, relationship -from sqlalchemy.sql.selectable import Select +from sqlalchemy import Date, ForeignKey, Integer, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column, relationship -from .base import Base, CommonMixin -from .entry import Entry -from .meal import Meal +from fooder.domain.base import Base, CommonMixin +from fooder.domain.meal import Meal class Diary(Base, CommonMixin): - """Diary represents user diary for given day""" + """Diary represents user diary for given day.""" - meals: Mapped[list[Meal]] = relationship( - lazy="selectin", order_by=Meal.order.desc() - ) + __table_args__ = (UniqueConstraint("user_id", "date"),) + + meals: Mapped[list[Meal]] = relationship(lazy="selectin", order_by=Meal.order.desc()) date: Mapped[datetime.date] = mapped_column(Date) user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id")) + # snapshot of user settings at diary creation time — intentionally decoupled + # from UserSettings so historical goals don't change when settings are updated + protein_goal: Mapped[float] + carb_goal: Mapped[float] + fat_goal: Mapped[float] + fiber_goal: Mapped[float] + calories_goal: Mapped[float] @property def calories(self) -> float: - """calories. - - :rtype: float - """ return sum(meal.calories for meal in self.meals) @property def protein(self) -> float: - """protein. - - :rtype: float - """ return sum(meal.protein for meal in self.meals) @property def carb(self) -> float: - """carb. - - :rtype: float - """ return sum(meal.carb for meal in self.meals) @property def fat(self) -> float: - """fat. - - :rtype: float - """ return sum(meal.fat for meal in self.meals) @property def fiber(self) -> float: - """fiber. - - :rtype: float - """ return sum(meal.fiber for meal in self.meals) - - @classmethod - def query(cls, user_id: int) -> Select: - """get_all.""" - query = ( - select(cls) - .where(cls.user_id == user_id) - .options( - joinedload(cls.meals).joinedload(Meal.entries).joinedload(Entry.product) - ) - ) - return query - - @classmethod - async def get_diary( - cls, session: AsyncSession, user_id: int, date: datetime.date - ) -> "Optional[Diary]": - """get_diary.""" - query = cls.query(user_id).where(cls.date == date) - return await session.scalar(query) - - @classmethod - async def create( - cls, session: AsyncSession, user_id: int, date: datetime.date - ) -> "Diary": - diary = Diary( - date=date, - user_id=user_id, - ) - session.add(diary) - - try: - await session.flush() - except Exception: - raise RuntimeError() - - db_diary = await cls.get_by_id(session, user_id, diary.id) - - if not db_diary: - raise RuntimeError() - - await Meal.create(session, db_diary.id) - return db_diary - - @classmethod - async def get_by_id( - cls, session: AsyncSession, user_id: int, id: int - ) -> "Optional[Diary]": - """get_by_id.""" - query = cls.query(user_id).where(cls.id == id) - return await session.scalar(query) - - @classmethod - async def has_permission(cls, session: AsyncSession, user_id: int, id: int) -> bool: - """has_permission.""" - query = select(cls.id).where(cls.user_id == user_id).where(cls.id == id) - obj = await session.scalar(query) - return obj is not None diff --git a/fooder/domain/entry.py b/fooder/domain/entry.py index 37e0003..bc3b5c3 100644 --- a/fooder/domain/entry.py +++ b/fooder/domain/entry.py @@ -1,164 +1,15 @@ -from datetime import datetime -from typing import Optional +from sqlalchemy import Boolean, ForeignKey, Integer +from sqlalchemy.orm import Mapped, mapped_column, relationship -from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, select, update -from sqlalchemy.exc import IntegrityError -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Mapped, joinedload, mapped_column, relationship - -from .base import Base, CommonMixin -from .product import Product +from fooder.domain.base import Base, CommonMixin, EntryMacrosMixin +from fooder.domain.product import Product -class Entry(Base, CommonMixin): +class Entry(Base, CommonMixin, EntryMacrosMixin): """Entry.""" grams: Mapped[float] product_id: Mapped[int] = mapped_column(Integer, ForeignKey("product.id")) product: Mapped[Product] = relationship(lazy="selectin") meal_id: Mapped[int] = mapped_column(Integer, ForeignKey("meal.id")) - last_changed: Mapped[datetime] = mapped_column( - DateTime, default=datetime.utcnow, onupdate=datetime.utcnow - ) processed: Mapped[bool] = mapped_column(Boolean, default=False) - - @property - def amount(self) -> float: - """amount. - - :rtype: float - """ - return self.grams / 100 - - @property - def calories(self) -> float: - """calories. - - :rtype: float - """ - return self.amount * self.product.calories - - @property - def protein(self) -> float: - """protein. - - :rtype: float - """ - return self.amount * self.product.protein - - @property - def carb(self) -> float: - """carb. - - :rtype: float - """ - return self.amount * self.product.carb - - @property - def fat(self) -> float: - """fat. - - :rtype: float - """ - return self.amount * self.product.fat - - @property - def fiber(self) -> float: - """fiber. - - :rtype: float - """ - return self.amount * self.product.fiber - - @classmethod - async def create( - cls, session: AsyncSession, meal_id: int, product_id: int, grams: float - ) -> "Entry": - """create.""" - assert grams > 0, "grams must be greater than 0" - entry = Entry( - meal_id=meal_id, - product_id=product_id, - grams=grams, - ) - session.add(entry) - - try: - await session.flush() - except IntegrityError: - raise AssertionError("meal or product does not exist") - - db_entry = await cls._get_by_id(session, entry.id) - if not db_entry: - raise RuntimeError() - return db_entry - - async def update( - self, - session: AsyncSession, - meal_id: Optional[int], - product_id: Optional[int], - grams: Optional[float], - ) -> None: - """update.""" - if grams is not None: - assert grams > 0, "grams must be greater than 0" - self.grams = grams - - if meal_id is not None: - self.meal_id = meal_id - try: - await session.flush() - except IntegrityError: - raise AssertionError("meal does not exist") - - if product_id is not None: - self.product_id = product_id - try: - await session.flush() - except IntegrityError: - raise AssertionError("product does not exist") - - @classmethod - async def _get_by_id(cls, session: AsyncSession, id: int) -> "Optional[Entry]": - """get_by_id.""" - query = select(cls).where(cls.id == id).options(joinedload(cls.product)) - return await session.scalar(query.order_by(cls.id)) - - @classmethod - async def get_by_id( - cls, session: AsyncSession, user_id: int, id: int - ) -> "Optional[Entry]": - """get_by_id.""" - from .diary import Diary - from .meal import Meal - - query = ( - select(cls) - .where(cls.id == id) - .join( - Meal, - ) - .join( - Diary, - ) - .where( - Diary.user_id == user_id, - ) - .options(joinedload(cls.product)) - ) - return await session.scalar(query.order_by(cls.id)) - - async def delete(self, session) -> None: - """delete.""" - await session.delete(self) - await session.flush() - - @classmethod - async def mark_processed( - cls, - session: AsyncSession, - ) -> None: - stmt = update(cls).where(cls.processed == False).values(processed=True) - - await session.execute(stmt) diff --git a/fooder/domain/meal.py b/fooder/domain/meal.py index 4066afd..ca4d5b9 100644 --- a/fooder/domain/meal.py +++ b/fooder/domain/meal.py @@ -1,153 +1,14 @@ -from typing import Optional +from sqlalchemy import ForeignKey, Integer +from sqlalchemy.orm import Mapped, mapped_column, relationship -from sqlalchemy import ForeignKey, Integer, select -from sqlalchemy.exc import IntegrityError -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Mapped, joinedload, mapped_column, relationship - -from .base import Base, CommonMixin -from .entry import Entry -from .preset import Preset +from fooder.domain.base import Base, AggregateMacrosMixin, CommonMixin +from fooder.domain.entry import Entry -class Meal(Base, CommonMixin): +class Meal(Base, CommonMixin, AggregateMacrosMixin): """Meal.""" name: Mapped[str] order: Mapped[int] diary_id: Mapped[int] = mapped_column(Integer, ForeignKey("diary.id")) - entries: Mapped[list[Entry]] = relationship( - lazy="selectin", order_by=Entry.last_changed - ) - - @property - def calories(self) -> float: - """calories. - - :rtype: float - """ - return sum(entry.calories for entry in self.entries) - - @property - def protein(self) -> float: - """protein. - - :rtype: float - """ - return sum(entry.protein for entry in self.entries) - - @property - def carb(self) -> float: - """carb. - - :rtype: float - """ - return sum(entry.carb for entry in self.entries) - - @property - def fat(self) -> float: - """fat. - - :rtype: float - """ - return sum(entry.fat for entry in self.entries) - - @property - def fiber(self) -> float: - """fiber. - - :rtype: float - """ - return sum(entry.fiber for entry in self.entries) - - @classmethod - async def create( - cls, - session: AsyncSession, - diary_id: int, - name: Optional[str] = None, - ) -> "Meal": - # check if order already exists in diary - query = ( - select(cls.order).where(cls.diary_id == diary_id).order_by(cls.order.desc()) - ) - existing_meal = await session.scalar(query) - order = existing_meal + 1 if existing_meal else 1 - - if name is None: - name = f"Meal {order}" - meal = Meal(diary_id=diary_id, name=name, order=order) - session.add(meal) - - try: - await session.flush() - except IntegrityError: - raise AssertionError("diary does not exist") - - db_meal = await cls._get_by_id(session, meal.id) - if not db_meal: - raise RuntimeError() - return db_meal - - @classmethod - async def create_from_preset( - cls, - session: AsyncSession, - diary_id: int, - name: Optional[str], - preset: Preset, - ) -> "Meal": - # check if order already exists in diary - query = ( - select(cls.order).where(cls.diary_id == diary_id).order_by(cls.order.desc()) - ) - existing_meal = await session.scalar(query) - order = existing_meal + 1 if existing_meal else 1 - - if name is None: - name = preset.name or f"Meal {order}" - - meal = Meal(diary_id=diary_id, name=name, order=order) - session.add(meal) - - try: - await session.flush() - except IntegrityError: - raise AssertionError("diary does not exist") - - for entry in preset.entries: - await Entry.create(session, meal.id, entry.product_id, entry.grams) - - db_meal = await cls._get_by_id(session, meal.id) - if not db_meal: - raise RuntimeError() - return db_meal - - @classmethod - async def _get_by_id(cls, session: AsyncSession, id: int) -> "Optional[Meal]": - """get_by_id.""" - query = select(cls).where(cls.id == id).options(joinedload(cls.entries)) - return await session.scalar(query.order_by(cls.id)) - - @classmethod - async def get_by_id( - cls, session: AsyncSession, user_id: int, id: int - ) -> "Optional[Meal]": - """get_by_id.""" - from .diary import Diary - - query = ( - select(cls) - .where(cls.id == id) - .join(Diary) - .where(Diary.user_id == user_id) - .options(joinedload(cls.entries)) - ) - return await session.scalar(query.order_by(cls.id)) - - async def delete(self, session: AsyncSession) -> None: - """delete.""" - for entry in self.entries: - await session.delete(entry) - await session.delete(self) - await session.flush() + entries: Mapped[list[Entry]] = relationship(lazy="selectin", order_by=Entry.last_changed) diff --git a/fooder/domain/preset.py b/fooder/domain/preset.py index 445405f..b4731f7 100644 --- a/fooder/domain/preset.py +++ b/fooder/domain/preset.py @@ -1,17 +1,11 @@ -from typing import TYPE_CHECKING, AsyncIterator, Optional +from sqlalchemy import ForeignKey, Integer +from sqlalchemy.orm import Mapped, mapped_column, relationship -from sqlalchemy import ForeignKey, Integer, select -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Mapped, joinedload, mapped_column, relationship - -from .base import Base, CommonMixin -from .preset_entry import PresetEntry - -if TYPE_CHECKING: - from .meal import Meal +from fooder.domain.base import Base, AggregateMacrosMixin, CommonMixin +from fooder.domain.preset_entry import PresetEntry -class Preset(Base, CommonMixin): +class Preset(Base, CommonMixin, AggregateMacrosMixin): """Preset.""" name: Mapped[str] @@ -19,104 +13,3 @@ class Preset(Base, CommonMixin): entries: Mapped[list[PresetEntry]] = relationship( lazy="selectin", order_by=PresetEntry.last_changed ) - - @property - def calories(self) -> float: - """calories. - - :rtype: float - """ - return sum(entry.calories for entry in self.entries) - - @property - def protein(self) -> float: - """protein. - - :rtype: float - """ - return sum(entry.protein for entry in self.entries) - - @property - def carb(self) -> float: - """carb. - - :rtype: float - """ - return sum(entry.carb for entry in self.entries) - - @property - def fat(self) -> float: - """fat. - - :rtype: float - """ - return sum(entry.fat for entry in self.entries) - - @property - def fiber(self) -> float: - """fiber. - - :rtype: float - """ - return sum(entry.fiber for entry in self.entries) - - @classmethod - async def create( - cls, session: AsyncSession, user_id: int, name: str, meal: "Meal" - ) -> "Preset": - preset = Preset(user_id=user_id, name=name) - - session.add(preset) - - try: - await session.flush() - except Exception: - raise RuntimeError() - - for entry in meal.entries: - await PresetEntry.create(session, preset.id, entry) - - db_preset = await cls.get(session, user_id, preset.id) - - if not db_preset: - raise RuntimeError() - - return db_preset - - @classmethod - async def list_all( - cls, - session: AsyncSession, - user_id: int, - offset: int, - limit: int, - q: Optional[str] = None, - ) -> AsyncIterator["Preset"]: - query = select(cls).filter(cls.user_id == user_id) - - if q: - query = query.filter(cls.name.ilike(f"%{q.lower()}%")) - - query = query.offset(offset).limit(limit) - stream = await session.stream_scalars(query.order_by(cls.id)) - async for row in stream: - yield row - - @classmethod - async def get( - cls, session: AsyncSession, user_id: int, preset_id: int - ) -> "Optional[Preset]": - """get.""" - query = ( - select(cls) - .where(cls.id == preset_id) - .where(cls.user_id == user_id) - .options(joinedload(cls.entries).joinedload(PresetEntry.product)) - ) - return await session.scalar(query) - - async def delete(self, session: AsyncSession) -> None: - for entry in self.entries: - await session.delete(entry) - await session.delete(self) - await session.flush() diff --git a/fooder/domain/preset_entry.py b/fooder/domain/preset_entry.py index fdd40c3..6d2b71a 100644 --- a/fooder/domain/preset_entry.py +++ b/fooder/domain/preset_entry.py @@ -1,89 +1,14 @@ -from datetime import datetime - -from sqlalchemy import DateTime, ForeignKey, Integer -from sqlalchemy.exc import IntegrityError -from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import ForeignKey, Integer from sqlalchemy.orm import Mapped, mapped_column, relationship -from .base import Base, CommonMixin -from .entry import Entry -from .product import Product +from fooder.domain.base import Base, CommonMixin, EntryMacrosMixin +from fooder.domain.product import Product -class PresetEntry(Base, CommonMixin): - """Entry.""" +class PresetEntry(Base, CommonMixin, EntryMacrosMixin): + """PresetEntry.""" grams: Mapped[float] product_id: Mapped[int] = mapped_column(Integer, ForeignKey("product.id")) product: Mapped[Product] = relationship(lazy="selectin") preset_id: Mapped[int] = mapped_column(Integer, ForeignKey("preset.id")) - last_changed: Mapped[datetime] = mapped_column( - DateTime, default=datetime.utcnow, onupdate=datetime.utcnow - ) - - @property - def amount(self) -> float: - """amount. - - :rtype: float - """ - return self.grams / 100 - - @property - def calories(self) -> float: - """calories. - - :rtype: float - """ - return self.amount * self.product.calories - - @property - def protein(self) -> float: - """protein. - - :rtype: float - """ - return self.amount * self.product.protein - - @property - def carb(self) -> float: - """carb. - - :rtype: float - """ - return self.amount * self.product.carb - - @property - def fat(self) -> float: - """fat. - - :rtype: float - """ - return self.amount * self.product.fat - - @property - def fiber(self) -> float: - """fiber. - - :rtype: float - """ - return self.amount * self.product.fiber - - @classmethod - async def create( - self, - session: AsyncSession, - preset_id: int, - entry: Entry, - ) -> None: - pentry = PresetEntry( - preset_id=preset_id, - product_id=entry.product_id, - grams=entry.grams, - ) - session.add(pentry) - - try: - await session.flush() - except IntegrityError: - raise AssertionError("preset or product does not exist") diff --git a/fooder/domain/product.py b/fooder/domain/product.py index 9cb72e1..47a3385 100644 --- a/fooder/domain/product.py +++ b/fooder/domain/product.py @@ -1,16 +1,22 @@ +from sqlalchemy import Index, text from sqlalchemy.orm import Mapped, mapped_column -from fooder.domain.base import Base, CommonMixin +from fooder.domain.base import Base, CommonMixin, SoftDeleteMixin -class Product(Base, CommonMixin): +class Product(Base, CommonMixin, SoftDeleteMixin): """Product.""" - name: Mapped[str] + __table_args__ = ( + Index("ix_product_barcode", "barcode", unique=True, + postgresql_where=text("deleted_at IS NULL"), + sqlite_where=text("deleted_at IS NULL")), + ) + name: Mapped[str] protein: Mapped[float] carb: Mapped[float] fat: Mapped[float] fiber: Mapped[float] calories: Mapped[float] - barcode: Mapped[str | None] = mapped_column(unique=True) + barcode: Mapped[str | None] = mapped_column(default=None) diff --git a/fooder/domain/user.py b/fooder/domain/user.py index c6c3ca6..4e78c0b 100644 --- a/fooder/domain/user.py +++ b/fooder/domain/user.py @@ -1,9 +1,19 @@ -from sqlalchemy.orm import Mapped +from __future__ import annotations -from fooder.domain.base import Base, CommonMixin, PasswordMixin +from typing import TYPE_CHECKING + +from sqlalchemy.orm import Mapped, relationship + +from fooder.domain.base import Base, CommonMixin, PasswordMixin, SoftDeleteMixin + +if TYPE_CHECKING: + from fooder.domain.user_settings import UserSettings -class User(Base, CommonMixin, PasswordMixin): - """Product.""" +class User(Base, CommonMixin, PasswordMixin, SoftDeleteMixin): + """User.""" username: Mapped[str] + settings: Mapped[UserSettings] = relationship( + back_populates="user", lazy="selectin", uselist=False + ) diff --git a/fooder/domain/user_product_usage.py b/fooder/domain/user_product_usage.py new file mode 100644 index 0000000..07d3a22 --- /dev/null +++ b/fooder/domain/user_product_usage.py @@ -0,0 +1,19 @@ +from sqlalchemy import ForeignKey, Integer, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from fooder.domain.base import Base, CommonMixin +from fooder.domain.product import Product +from fooder.domain.user import User + + +class UserProductUsage(Base, CommonMixin): + """Counts how many processed entries a user has for a product. + Used to sort products by usage frequency.""" + + __table_args__ = (UniqueConstraint("user_id", "product_id"),) + + product_id: Mapped[int] = mapped_column(Integer, ForeignKey("product.id")) + product: Mapped[Product] = relationship(lazy="selectin") + user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id")) + user: Mapped[User] = relationship(lazy="selectin") + count: Mapped[int] diff --git a/fooder/domain/user_settings.py b/fooder/domain/user_settings.py new file mode 100644 index 0000000..543fe9a --- /dev/null +++ b/fooder/domain/user_settings.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from sqlalchemy import ForeignKey, Integer +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from fooder.domain.base import Base, CommonMixin + +if TYPE_CHECKING: + from fooder.domain.user import User + + +class UserSettings(Base, CommonMixin): + """UserSettings.""" + + user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"), unique=True) + user: Mapped[User] = relationship(back_populates="settings") + # meals_name_convention: Mapped[list[str]] # json column, format TBD + protein_goal: Mapped[float] + carb_goal: Mapped[float] + fat_goal: Mapped[float] + fiber_goal: Mapped[float] + calories_goal: Mapped[float] diff --git a/fooder/repository/base.py b/fooder/repository/base.py index f7c15ba..101c676 100644 --- a/fooder/repository/base.py +++ b/fooder/repository/base.py @@ -1,10 +1,11 @@ from typing import TypeVar, Generic, Type, Sequence from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import select, delete as sa_delete, ColumnElement +from sqlalchemy import select, delete as sa_delete, update as sa_update, ColumnElement from sqlalchemy.exc import IntegrityError from sqlalchemy.sql import Select from fooder.domain import Base from fooder.exc import Conflict, NotFound +from fooder.utils.datetime import utc_now T = TypeVar("T", bound=Base) @@ -20,6 +21,9 @@ class RepositoryBase(Generic[T]): def _build_select(self, *expressions: ColumnElement) -> Select[tuple[T]]: stmt = select(self.model) + if hasattr(self.model, "deleted_at"): + stmt = stmt.where(self.model.deleted_at.is_(None)) # type: ignore[attr-defined] + if expressions: stmt = stmt.where(*expressions) @@ -64,7 +68,10 @@ class RepositoryBase(Generic[T]): return obj async def _delete(self, *expressions: ColumnElement): - stmt = sa_delete(self.model) + if hasattr(self.model, "deleted_at"): + stmt = sa_update(self.model).values(deleted_at=utc_now()) # type: ignore[attr-defined] + else: + stmt = sa_delete(self.model) if expressions: stmt = stmt.where(*expressions) diff --git a/fooder/test/fixtures/db.py b/fooder/test/fixtures/db.py index 33543b2..b6ebd30 100644 --- a/fooder/test/fixtures/db.py +++ b/fooder/test/fixtures/db.py @@ -5,14 +5,13 @@ from sqlalchemy import event from sqlalchemy.ext.asyncio import AsyncSession from fooder.db import DatabaseSessionManager from fooder.domain import Base +from fooder.domain.base import CommonMixin from fooder.settings import settings from fooder.repository.base import RepositoryBase from sqlalchemy.orm import Mapped, mapped_column -class TestModel(Base): - __tablename__ = "test" - id: Mapped[int] = mapped_column(primary_key=True) +class TestModel(Base, CommonMixin): property: Mapped[str]