[domain] rebuilt

This commit is contained in:
Piotr Domański 2026-04-07 15:59:08 +02:00
parent e5565dbf87
commit 7d2df880c7
13 changed files with 179 additions and 608 deletions

View file

@ -4,5 +4,7 @@ from .entry import Entry # noqa
from .meal import Meal # noqa from .meal import Meal # noqa
from .product import Product # noqa from .product import Product # noqa
from .user import User # noqa from .user import User # noqa
from .user_product_usage import UserProductUsage # noqa
from .user_settings import UserSettings # noqa
from .preset import Preset # noqa from .preset import Preset # noqa
from .preset_entry import PresetEntry # noqa from .preset_entry import PresetEntry # noqa

View file

@ -1,34 +1,31 @@
from datetime import datetime
from sqlalchemy import DateTime
from sqlalchemy.orm import DeclarativeBase, Mapped, declared_attr, mapped_column from sqlalchemy.orm import DeclarativeBase, Mapped, declared_attr, mapped_column
from fooder.utils.datetime import utc_now
from fooder.utils.password_helper import password_helper from fooder.utils.password_helper import password_helper
class Base(DeclarativeBase): class Base(DeclarativeBase):
"""Base from DeclarativeBase"""
pass pass
class CommonMixin: class CommonMixin:
"""
CommonMixin for all common fields in projetc
"""
@declared_attr.directive @declared_attr.directive
def __tablename__(cls) -> str: def __tablename__(cls) -> str:
"""__tablename__.
:rtype: str
"""
return cls.__name__.lower() # type: ignore return cls.__name__.lower() # type: ignore
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=utc_now)
last_changed: Mapped[datetime] = mapped_column(DateTime, default=utc_now, onupdate=utc_now)
class SoftDeleteMixin:
deleted_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
class PasswordMixin: class PasswordMixin:
"""
PasswordMixin for entities with password
"""
hashed_password: Mapped[str] hashed_password: Mapped[str]
def set_password(self, password: str) -> None: def set_password(self, password: str) -> None:
@ -36,3 +33,55 @@ class PasswordMixin:
def verify_password(self, password: str) -> bool: def verify_password(self, password: str) -> bool:
return password_helper.verify(password, self.hashed_password) return password_helper.verify(password, self.hashed_password)
class EntryMacrosMixin:
"""Computed macros for entry-like models that scale product macros by grams."""
@property
def amount(self) -> float:
return self.grams / 100 # type: ignore[attr-defined]
@property
def calories(self) -> float:
return self.amount * self.product.calories # type: ignore[attr-defined]
@property
def protein(self) -> float:
return self.amount * self.product.protein # type: ignore[attr-defined]
@property
def carb(self) -> float:
return self.amount * self.product.carb # type: ignore[attr-defined]
@property
def fat(self) -> float:
return self.amount * self.product.fat # type: ignore[attr-defined]
@property
def fiber(self) -> float:
return self.amount * self.product.fiber # type: ignore[attr-defined]
class AggregateMacrosMixin:
"""Computed macros for models that sum macros across child entries."""
@property
def calories(self) -> float:
return sum(e.calories for e in self.entries) # type: ignore[attr-defined]
@property
def protein(self) -> float:
return sum(e.protein for e in self.entries) # type: ignore[attr-defined]
@property
def carb(self) -> float:
return sum(e.carb for e in self.entries) # type: ignore[attr-defined]
@property
def fat(self) -> float:
return sum(e.fat for e in self.entries) # type: ignore[attr-defined]
@property
def fiber(self) -> float:
return sum(e.fiber for e in self.entries) # type: ignore[attr-defined]

View file

@ -1,119 +1,44 @@
import datetime import datetime
from typing import Optional
from sqlalchemy import Date, ForeignKey, Integer, select from sqlalchemy import Date, ForeignKey, Integer, UniqueConstraint
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.orm import Mapped, joinedload, mapped_column, relationship
from sqlalchemy.sql.selectable import Select
from .base import Base, CommonMixin from fooder.domain.base import Base, CommonMixin
from .entry import Entry from fooder.domain.meal import Meal
from .meal import Meal
class Diary(Base, CommonMixin): class Diary(Base, CommonMixin):
"""Diary represents user diary for given day""" """Diary represents user diary for given day."""
meals: Mapped[list[Meal]] = relationship( __table_args__ = (UniqueConstraint("user_id", "date"),)
lazy="selectin", order_by=Meal.order.desc()
) meals: Mapped[list[Meal]] = relationship(lazy="selectin", order_by=Meal.order.desc())
date: Mapped[datetime.date] = mapped_column(Date) date: Mapped[datetime.date] = mapped_column(Date)
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id")) user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
# snapshot of user settings at diary creation time — intentionally decoupled
# from UserSettings so historical goals don't change when settings are updated
protein_goal: Mapped[float]
carb_goal: Mapped[float]
fat_goal: Mapped[float]
fiber_goal: Mapped[float]
calories_goal: Mapped[float]
@property @property
def calories(self) -> float: def calories(self) -> float:
"""calories.
:rtype: float
"""
return sum(meal.calories for meal in self.meals) return sum(meal.calories for meal in self.meals)
@property @property
def protein(self) -> float: def protein(self) -> float:
"""protein.
:rtype: float
"""
return sum(meal.protein for meal in self.meals) return sum(meal.protein for meal in self.meals)
@property @property
def carb(self) -> float: def carb(self) -> float:
"""carb.
:rtype: float
"""
return sum(meal.carb for meal in self.meals) return sum(meal.carb for meal in self.meals)
@property @property
def fat(self) -> float: def fat(self) -> float:
"""fat.
:rtype: float
"""
return sum(meal.fat for meal in self.meals) return sum(meal.fat for meal in self.meals)
@property @property
def fiber(self) -> float: def fiber(self) -> float:
"""fiber.
:rtype: float
"""
return sum(meal.fiber for meal in self.meals) return sum(meal.fiber for meal in self.meals)
@classmethod
def query(cls, user_id: int) -> Select:
"""get_all."""
query = (
select(cls)
.where(cls.user_id == user_id)
.options(
joinedload(cls.meals).joinedload(Meal.entries).joinedload(Entry.product)
)
)
return query
@classmethod
async def get_diary(
cls, session: AsyncSession, user_id: int, date: datetime.date
) -> "Optional[Diary]":
"""get_diary."""
query = cls.query(user_id).where(cls.date == date)
return await session.scalar(query)
@classmethod
async def create(
cls, session: AsyncSession, user_id: int, date: datetime.date
) -> "Diary":
diary = Diary(
date=date,
user_id=user_id,
)
session.add(diary)
try:
await session.flush()
except Exception:
raise RuntimeError()
db_diary = await cls.get_by_id(session, user_id, diary.id)
if not db_diary:
raise RuntimeError()
await Meal.create(session, db_diary.id)
return db_diary
@classmethod
async def get_by_id(
cls, session: AsyncSession, user_id: int, id: int
) -> "Optional[Diary]":
"""get_by_id."""
query = cls.query(user_id).where(cls.id == id)
return await session.scalar(query)
@classmethod
async def has_permission(cls, session: AsyncSession, user_id: int, id: int) -> bool:
"""has_permission."""
query = select(cls.id).where(cls.user_id == user_id).where(cls.id == id)
obj = await session.scalar(query)
return obj is not None

View file

@ -1,164 +1,15 @@
from datetime import datetime from sqlalchemy import Boolean, ForeignKey, Integer
from typing import Optional from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, select, update from fooder.domain.base import Base, CommonMixin, EntryMacrosMixin
from sqlalchemy.exc import IntegrityError from fooder.domain.product import Product
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, joinedload, mapped_column, relationship
from .base import Base, CommonMixin
from .product import Product
class Entry(Base, CommonMixin): class Entry(Base, CommonMixin, EntryMacrosMixin):
"""Entry.""" """Entry."""
grams: Mapped[float] grams: Mapped[float]
product_id: Mapped[int] = mapped_column(Integer, ForeignKey("product.id")) product_id: Mapped[int] = mapped_column(Integer, ForeignKey("product.id"))
product: Mapped[Product] = relationship(lazy="selectin") product: Mapped[Product] = relationship(lazy="selectin")
meal_id: Mapped[int] = mapped_column(Integer, ForeignKey("meal.id")) meal_id: Mapped[int] = mapped_column(Integer, ForeignKey("meal.id"))
last_changed: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
)
processed: Mapped[bool] = mapped_column(Boolean, default=False) processed: Mapped[bool] = mapped_column(Boolean, default=False)
@property
def amount(self) -> float:
"""amount.
:rtype: float
"""
return self.grams / 100
@property
def calories(self) -> float:
"""calories.
:rtype: float
"""
return self.amount * self.product.calories
@property
def protein(self) -> float:
"""protein.
:rtype: float
"""
return self.amount * self.product.protein
@property
def carb(self) -> float:
"""carb.
:rtype: float
"""
return self.amount * self.product.carb
@property
def fat(self) -> float:
"""fat.
:rtype: float
"""
return self.amount * self.product.fat
@property
def fiber(self) -> float:
"""fiber.
:rtype: float
"""
return self.amount * self.product.fiber
@classmethod
async def create(
cls, session: AsyncSession, meal_id: int, product_id: int, grams: float
) -> "Entry":
"""create."""
assert grams > 0, "grams must be greater than 0"
entry = Entry(
meal_id=meal_id,
product_id=product_id,
grams=grams,
)
session.add(entry)
try:
await session.flush()
except IntegrityError:
raise AssertionError("meal or product does not exist")
db_entry = await cls._get_by_id(session, entry.id)
if not db_entry:
raise RuntimeError()
return db_entry
async def update(
self,
session: AsyncSession,
meal_id: Optional[int],
product_id: Optional[int],
grams: Optional[float],
) -> None:
"""update."""
if grams is not None:
assert grams > 0, "grams must be greater than 0"
self.grams = grams
if meal_id is not None:
self.meal_id = meal_id
try:
await session.flush()
except IntegrityError:
raise AssertionError("meal does not exist")
if product_id is not None:
self.product_id = product_id
try:
await session.flush()
except IntegrityError:
raise AssertionError("product does not exist")
@classmethod
async def _get_by_id(cls, session: AsyncSession, id: int) -> "Optional[Entry]":
"""get_by_id."""
query = select(cls).where(cls.id == id).options(joinedload(cls.product))
return await session.scalar(query.order_by(cls.id))
@classmethod
async def get_by_id(
cls, session: AsyncSession, user_id: int, id: int
) -> "Optional[Entry]":
"""get_by_id."""
from .diary import Diary
from .meal import Meal
query = (
select(cls)
.where(cls.id == id)
.join(
Meal,
)
.join(
Diary,
)
.where(
Diary.user_id == user_id,
)
.options(joinedload(cls.product))
)
return await session.scalar(query.order_by(cls.id))
async def delete(self, session) -> None:
"""delete."""
await session.delete(self)
await session.flush()
@classmethod
async def mark_processed(
cls,
session: AsyncSession,
) -> None:
stmt = update(cls).where(cls.processed == False).values(processed=True)
await session.execute(stmt)

View file

@ -1,153 +1,14 @@
from typing import Optional from sqlalchemy import ForeignKey, Integer
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy import ForeignKey, Integer, select from fooder.domain.base import Base, AggregateMacrosMixin, CommonMixin
from sqlalchemy.exc import IntegrityError from fooder.domain.entry import Entry
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, joinedload, mapped_column, relationship
from .base import Base, CommonMixin
from .entry import Entry
from .preset import Preset
class Meal(Base, CommonMixin): class Meal(Base, CommonMixin, AggregateMacrosMixin):
"""Meal.""" """Meal."""
name: Mapped[str] name: Mapped[str]
order: Mapped[int] order: Mapped[int]
diary_id: Mapped[int] = mapped_column(Integer, ForeignKey("diary.id")) diary_id: Mapped[int] = mapped_column(Integer, ForeignKey("diary.id"))
entries: Mapped[list[Entry]] = relationship( entries: Mapped[list[Entry]] = relationship(lazy="selectin", order_by=Entry.last_changed)
lazy="selectin", order_by=Entry.last_changed
)
@property
def calories(self) -> float:
"""calories.
:rtype: float
"""
return sum(entry.calories for entry in self.entries)
@property
def protein(self) -> float:
"""protein.
:rtype: float
"""
return sum(entry.protein for entry in self.entries)
@property
def carb(self) -> float:
"""carb.
:rtype: float
"""
return sum(entry.carb for entry in self.entries)
@property
def fat(self) -> float:
"""fat.
:rtype: float
"""
return sum(entry.fat for entry in self.entries)
@property
def fiber(self) -> float:
"""fiber.
:rtype: float
"""
return sum(entry.fiber for entry in self.entries)
@classmethod
async def create(
cls,
session: AsyncSession,
diary_id: int,
name: Optional[str] = None,
) -> "Meal":
# check if order already exists in diary
query = (
select(cls.order).where(cls.diary_id == diary_id).order_by(cls.order.desc())
)
existing_meal = await session.scalar(query)
order = existing_meal + 1 if existing_meal else 1
if name is None:
name = f"Meal {order}"
meal = Meal(diary_id=diary_id, name=name, order=order)
session.add(meal)
try:
await session.flush()
except IntegrityError:
raise AssertionError("diary does not exist")
db_meal = await cls._get_by_id(session, meal.id)
if not db_meal:
raise RuntimeError()
return db_meal
@classmethod
async def create_from_preset(
cls,
session: AsyncSession,
diary_id: int,
name: Optional[str],
preset: Preset,
) -> "Meal":
# check if order already exists in diary
query = (
select(cls.order).where(cls.diary_id == diary_id).order_by(cls.order.desc())
)
existing_meal = await session.scalar(query)
order = existing_meal + 1 if existing_meal else 1
if name is None:
name = preset.name or f"Meal {order}"
meal = Meal(diary_id=diary_id, name=name, order=order)
session.add(meal)
try:
await session.flush()
except IntegrityError:
raise AssertionError("diary does not exist")
for entry in preset.entries:
await Entry.create(session, meal.id, entry.product_id, entry.grams)
db_meal = await cls._get_by_id(session, meal.id)
if not db_meal:
raise RuntimeError()
return db_meal
@classmethod
async def _get_by_id(cls, session: AsyncSession, id: int) -> "Optional[Meal]":
"""get_by_id."""
query = select(cls).where(cls.id == id).options(joinedload(cls.entries))
return await session.scalar(query.order_by(cls.id))
@classmethod
async def get_by_id(
cls, session: AsyncSession, user_id: int, id: int
) -> "Optional[Meal]":
"""get_by_id."""
from .diary import Diary
query = (
select(cls)
.where(cls.id == id)
.join(Diary)
.where(Diary.user_id == user_id)
.options(joinedload(cls.entries))
)
return await session.scalar(query.order_by(cls.id))
async def delete(self, session: AsyncSession) -> None:
"""delete."""
for entry in self.entries:
await session.delete(entry)
await session.delete(self)
await session.flush()

View file

@ -1,17 +1,11 @@
from typing import TYPE_CHECKING, AsyncIterator, Optional from sqlalchemy import ForeignKey, Integer
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy import ForeignKey, Integer, select from fooder.domain.base import Base, AggregateMacrosMixin, CommonMixin
from sqlalchemy.ext.asyncio import AsyncSession from fooder.domain.preset_entry import PresetEntry
from sqlalchemy.orm import Mapped, joinedload, mapped_column, relationship
from .base import Base, CommonMixin
from .preset_entry import PresetEntry
if TYPE_CHECKING:
from .meal import Meal
class Preset(Base, CommonMixin): class Preset(Base, CommonMixin, AggregateMacrosMixin):
"""Preset.""" """Preset."""
name: Mapped[str] name: Mapped[str]
@ -19,104 +13,3 @@ class Preset(Base, CommonMixin):
entries: Mapped[list[PresetEntry]] = relationship( entries: Mapped[list[PresetEntry]] = relationship(
lazy="selectin", order_by=PresetEntry.last_changed lazy="selectin", order_by=PresetEntry.last_changed
) )
@property
def calories(self) -> float:
"""calories.
:rtype: float
"""
return sum(entry.calories for entry in self.entries)
@property
def protein(self) -> float:
"""protein.
:rtype: float
"""
return sum(entry.protein for entry in self.entries)
@property
def carb(self) -> float:
"""carb.
:rtype: float
"""
return sum(entry.carb for entry in self.entries)
@property
def fat(self) -> float:
"""fat.
:rtype: float
"""
return sum(entry.fat for entry in self.entries)
@property
def fiber(self) -> float:
"""fiber.
:rtype: float
"""
return sum(entry.fiber for entry in self.entries)
@classmethod
async def create(
cls, session: AsyncSession, user_id: int, name: str, meal: "Meal"
) -> "Preset":
preset = Preset(user_id=user_id, name=name)
session.add(preset)
try:
await session.flush()
except Exception:
raise RuntimeError()
for entry in meal.entries:
await PresetEntry.create(session, preset.id, entry)
db_preset = await cls.get(session, user_id, preset.id)
if not db_preset:
raise RuntimeError()
return db_preset
@classmethod
async def list_all(
cls,
session: AsyncSession,
user_id: int,
offset: int,
limit: int,
q: Optional[str] = None,
) -> AsyncIterator["Preset"]:
query = select(cls).filter(cls.user_id == user_id)
if q:
query = query.filter(cls.name.ilike(f"%{q.lower()}%"))
query = query.offset(offset).limit(limit)
stream = await session.stream_scalars(query.order_by(cls.id))
async for row in stream:
yield row
@classmethod
async def get(
cls, session: AsyncSession, user_id: int, preset_id: int
) -> "Optional[Preset]":
"""get."""
query = (
select(cls)
.where(cls.id == preset_id)
.where(cls.user_id == user_id)
.options(joinedload(cls.entries).joinedload(PresetEntry.product))
)
return await session.scalar(query)
async def delete(self, session: AsyncSession) -> None:
for entry in self.entries:
await session.delete(entry)
await session.delete(self)
await session.flush()

View file

@ -1,89 +1,14 @@
from datetime import datetime from sqlalchemy import ForeignKey, Integer
from sqlalchemy import DateTime, ForeignKey, Integer
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from .base import Base, CommonMixin from fooder.domain.base import Base, CommonMixin, EntryMacrosMixin
from .entry import Entry from fooder.domain.product import Product
from .product import Product
class PresetEntry(Base, CommonMixin): class PresetEntry(Base, CommonMixin, EntryMacrosMixin):
"""Entry.""" """PresetEntry."""
grams: Mapped[float] grams: Mapped[float]
product_id: Mapped[int] = mapped_column(Integer, ForeignKey("product.id")) product_id: Mapped[int] = mapped_column(Integer, ForeignKey("product.id"))
product: Mapped[Product] = relationship(lazy="selectin") product: Mapped[Product] = relationship(lazy="selectin")
preset_id: Mapped[int] = mapped_column(Integer, ForeignKey("preset.id")) preset_id: Mapped[int] = mapped_column(Integer, ForeignKey("preset.id"))
last_changed: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
)
@property
def amount(self) -> float:
"""amount.
:rtype: float
"""
return self.grams / 100
@property
def calories(self) -> float:
"""calories.
:rtype: float
"""
return self.amount * self.product.calories
@property
def protein(self) -> float:
"""protein.
:rtype: float
"""
return self.amount * self.product.protein
@property
def carb(self) -> float:
"""carb.
:rtype: float
"""
return self.amount * self.product.carb
@property
def fat(self) -> float:
"""fat.
:rtype: float
"""
return self.amount * self.product.fat
@property
def fiber(self) -> float:
"""fiber.
:rtype: float
"""
return self.amount * self.product.fiber
@classmethod
async def create(
self,
session: AsyncSession,
preset_id: int,
entry: Entry,
) -> None:
pentry = PresetEntry(
preset_id=preset_id,
product_id=entry.product_id,
grams=entry.grams,
)
session.add(pentry)
try:
await session.flush()
except IntegrityError:
raise AssertionError("preset or product does not exist")

View file

@ -1,16 +1,22 @@
from sqlalchemy import Index, text
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from fooder.domain.base import Base, CommonMixin from fooder.domain.base import Base, CommonMixin, SoftDeleteMixin
class Product(Base, CommonMixin): class Product(Base, CommonMixin, SoftDeleteMixin):
"""Product.""" """Product."""
name: Mapped[str] __table_args__ = (
Index("ix_product_barcode", "barcode", unique=True,
postgresql_where=text("deleted_at IS NULL"),
sqlite_where=text("deleted_at IS NULL")),
)
name: Mapped[str]
protein: Mapped[float] protein: Mapped[float]
carb: Mapped[float] carb: Mapped[float]
fat: Mapped[float] fat: Mapped[float]
fiber: Mapped[float] fiber: Mapped[float]
calories: Mapped[float] calories: Mapped[float]
barcode: Mapped[str | None] = mapped_column(unique=True) barcode: Mapped[str | None] = mapped_column(default=None)

View file

@ -1,9 +1,19 @@
from sqlalchemy.orm import Mapped from __future__ import annotations
from fooder.domain.base import Base, CommonMixin, PasswordMixin from typing import TYPE_CHECKING
from sqlalchemy.orm import Mapped, relationship
from fooder.domain.base import Base, CommonMixin, PasswordMixin, SoftDeleteMixin
if TYPE_CHECKING:
from fooder.domain.user_settings import UserSettings
class User(Base, CommonMixin, PasswordMixin): class User(Base, CommonMixin, PasswordMixin, SoftDeleteMixin):
"""Product.""" """User."""
username: Mapped[str] username: Mapped[str]
settings: Mapped[UserSettings] = relationship(
back_populates="user", lazy="selectin", uselist=False
)

View file

@ -0,0 +1,19 @@
from sqlalchemy import ForeignKey, Integer, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from fooder.domain.base import Base, CommonMixin
from fooder.domain.product import Product
from fooder.domain.user import User
class UserProductUsage(Base, CommonMixin):
"""Counts how many processed entries a user has for a product.
Used to sort products by usage frequency."""
__table_args__ = (UniqueConstraint("user_id", "product_id"),)
product_id: Mapped[int] = mapped_column(Integer, ForeignKey("product.id"))
product: Mapped[Product] = relationship(lazy="selectin")
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
user: Mapped[User] = relationship(lazy="selectin")
count: Mapped[int]

View file

@ -0,0 +1,24 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from sqlalchemy import ForeignKey, Integer
from sqlalchemy.orm import Mapped, mapped_column, relationship
from fooder.domain.base import Base, CommonMixin
if TYPE_CHECKING:
from fooder.domain.user import User
class UserSettings(Base, CommonMixin):
"""UserSettings."""
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"), unique=True)
user: Mapped[User] = relationship(back_populates="settings")
# meals_name_convention: Mapped[list[str]] # json column, format TBD
protein_goal: Mapped[float]
carb_goal: Mapped[float]
fat_goal: Mapped[float]
fiber_goal: Mapped[float]
calories_goal: Mapped[float]

View file

@ -1,10 +1,11 @@
from typing import TypeVar, Generic, Type, Sequence from typing import TypeVar, Generic, Type, Sequence
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete as sa_delete, ColumnElement from sqlalchemy import select, delete as sa_delete, update as sa_update, ColumnElement
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.sql import Select from sqlalchemy.sql import Select
from fooder.domain import Base from fooder.domain import Base
from fooder.exc import Conflict, NotFound from fooder.exc import Conflict, NotFound
from fooder.utils.datetime import utc_now
T = TypeVar("T", bound=Base) T = TypeVar("T", bound=Base)
@ -20,6 +21,9 @@ class RepositoryBase(Generic[T]):
def _build_select(self, *expressions: ColumnElement) -> Select[tuple[T]]: def _build_select(self, *expressions: ColumnElement) -> Select[tuple[T]]:
stmt = select(self.model) stmt = select(self.model)
if hasattr(self.model, "deleted_at"):
stmt = stmt.where(self.model.deleted_at.is_(None)) # type: ignore[attr-defined]
if expressions: if expressions:
stmt = stmt.where(*expressions) stmt = stmt.where(*expressions)
@ -64,7 +68,10 @@ class RepositoryBase(Generic[T]):
return obj return obj
async def _delete(self, *expressions: ColumnElement): async def _delete(self, *expressions: ColumnElement):
stmt = sa_delete(self.model) if hasattr(self.model, "deleted_at"):
stmt = sa_update(self.model).values(deleted_at=utc_now()) # type: ignore[attr-defined]
else:
stmt = sa_delete(self.model)
if expressions: if expressions:
stmt = stmt.where(*expressions) stmt = stmt.where(*expressions)

View file

@ -5,14 +5,13 @@ from sqlalchemy import event
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from fooder.db import DatabaseSessionManager from fooder.db import DatabaseSessionManager
from fooder.domain import Base from fooder.domain import Base
from fooder.domain.base import CommonMixin
from fooder.settings import settings from fooder.settings import settings
from fooder.repository.base import RepositoryBase from fooder.repository.base import RepositoryBase
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
class TestModel(Base): class TestModel(Base, CommonMixin):
__tablename__ = "test"
id: Mapped[int] = mapped_column(primary_key=True)
property: Mapped[str] property: Mapped[str]