from typing import Generic, Sequence, Type, TypeVar

from sqlalchemy import ColumnElement, delete as sa_delete, select, update as sa_update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.exc import StaleDataError
from sqlalchemy.sql import Select

from fooder.domain import Base
from fooder.exc import Conflict, NotFound
from fooder.utils.datetime import utc_now

T = TypeVar("T", bound=Base)

# Page size used by ``_list`` when the caller does not pass ``limit``.
DEFAULT_LIMIT = 20


class RepositoryBase(Generic[T]):
    """Generic async repository for a single mapped model.

    Models that expose a ``deleted_at`` attribute are treated as
    soft-deletable: reads skip rows whose ``deleted_at`` is set, and
    ``_delete`` stamps ``deleted_at`` instead of removing the row.
    """

    def __init__(self, model: Type[T], session: AsyncSession):
        """Bind the repository to *model* and the *session* it queries through."""
        self.model = model
        self.session = session

    def _build_select(self, *expressions: ColumnElement) -> Select[tuple[T]]:
        """Return a SELECT for the model filtered by *expressions*.

        Soft-deleted rows are excluded when the model has ``deleted_at``.
        """
        stmt = select(self.model)
        if hasattr(self.model, "deleted_at"):
            stmt = stmt.where(self.model.deleted_at.is_(None))  # type: ignore[attr-defined]
        if expressions:
            stmt = stmt.where(*expressions)
        return stmt

    async def _one_or_not_found(self, stmt: Select[tuple[T]]) -> T:
        """Execute *stmt* and return its single entity, raising ``NotFound`` if none."""
        result = await self.session.execute(stmt)
        obj = result.scalar_one_or_none()
        if obj is None:
            raise NotFound()
        return obj

    async def _get(self, *expressions: ColumnElement) -> T:
        """Return the single row matching *expressions*.

        Raises:
            NotFound: no (non-soft-deleted) row matches.
        """
        return await self._one_or_not_found(self._build_select(*expressions))

    async def _get_for_update(self, *expressions: ColumnElement) -> T:
        """Like ``_get`` but the SELECT is issued with ``FOR UPDATE``.

        Raises:
            NotFound: no (non-soft-deleted) row matches.
        """
        stmt = self._build_select(*expressions).with_for_update()
        return await self._one_or_not_found(stmt)

    async def _list(
        self,
        *expressions: ColumnElement,
        offset: int = 0,
        limit: int | None = DEFAULT_LIMIT,
    ) -> Sequence[T]:
        """Return rows matching *expressions*, paginated by *offset* / *limit*.

        ``limit=None`` disables the limit. No ORDER BY is applied here, so
        the row order is whatever the database returns.
        """
        stmt = self._build_select(*expressions)
        if offset:
            stmt = stmt.offset(offset)
        if limit is not None:
            stmt = stmt.limit(limit)
        result = await self.session.execute(stmt)
        return result.scalars().all()

    async def create(self, obj: T) -> T:
        """Add *obj* to the session, flush it, refresh it and return it.

        Raises:
            Conflict: the flush raised ``IntegrityError``.
        """
        self.session.add(obj)
        try:
            await self.session.flush()
        except IntegrityError as exc:
            # Chain the cause so the underlying DB error stays visible.
            raise Conflict() from exc
        await self.session.refresh(obj)
        return obj

    async def update(self, obj: T) -> T:
        """Flush the session and return *obj* unchanged.

        Raises:
            Conflict: the flush raised ``IntegrityError`` or ``StaleDataError``.
        """
        try:
            await self.session.flush()
        except (IntegrityError, StaleDataError) as exc:
            raise Conflict() from exc
        return obj

    async def _delete(self, *expressions: ColumnElement) -> None:
        """Delete rows matching *expressions*.

        For models with ``deleted_at`` this is a soft delete: the column is
        set to ``utc_now()``. Otherwise the rows are removed with DELETE.
        When no expressions are given the statement is not narrowed further.
        """
        if hasattr(self.model, "deleted_at"):
            # Only stamp rows that are still live, mirroring _build_select;
            # otherwise a repeated delete would overwrite the original
            # deletion time of rows that are already soft-deleted.
            stmt = (
                sa_update(self.model)
                .where(self.model.deleted_at.is_(None))  # type: ignore[attr-defined]
                .values(deleted_at=utc_now())
            )
        else:
            stmt = sa_delete(self.model)
        if expressions:
            stmt = stmt.where(*expressions)
        await self.session.execute(stmt)