from typing import TypeVar, Generic, Type, Sequence, cast
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import (
    ColumnElement,
    select,
)
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.exc import StaleDataError
from sqlalchemy.sql import Select
from fooder.domain.base import Base, SoftDeletable
from fooder.exc import Conflict, NotFound
from fooder.utils.datetime import utc_now

T = TypeVar("T", bound=Base)

# Page size applied by ``RepositoryBase._list`` unless the caller overrides it.
DEFAULT_LIMIT = 20


class RepositoryBase(Generic[T]):
    """Generic async repository for a single ORM model.

    Read helpers automatically hide soft-deleted rows. Write helpers flush
    immediately and translate database conflicts into the domain
    ``Conflict`` exception.

    A model counts as soft-deletable when it exposes a ``deleted_at``
    attribute. Such models are:

    * filtered on ``deleted_at IS NULL`` when read;
    * given a ``deleted_at`` timestamp instead of being removed on ``delete``.

    NOTE: write methods only ``flush``. Committing is left to the caller.
    So is rolling back after a failed flush, which SQLAlchemy requires
    before the session can be used again.
    """

    def __init__(self, model: Type[T], session: AsyncSession):
        self.model = model
        self.session = session
        # Duck-typed detection: only a ``deleted_at`` attribute is required,
        # not inheritance from SoftDeletable.
        self._is_soft_delete = hasattr(self.model, "deleted_at")

    def _build_select(
        self, *expressions: ColumnElement, stmt: Select | None = None
    ) -> Select[tuple[T]]:
        """Build a SELECT for ``self.model`` restricted by *expressions*.

        Args:
            *expressions: Extra WHERE criteria, combined with AND.
            stmt: Optional base statement to extend instead of
                ``select(self.model)`` (e.g. one carrying joins or loader
                options).

        Returns:
            The statement with *expressions* applied. For soft-deletable
            models, ``deleted_at IS NULL`` is also applied (checked on
            ``self.model``).
        """
        if stmt is None:
            stmt = select(self.model)
        if self._is_soft_delete:
            cls = cast(Type[SoftDeletable], self.model)
            stmt = stmt.where(cls.deleted_at.is_(None))
        if expressions:
            stmt = stmt.where(*expressions)
        return stmt

    async def _first_or_not_found(self, stmt: Select[tuple[T]]) -> T:
        """Execute *stmt* and return the first scalar result.

        Raises:
            NotFound: if the statement matched no row.
        """
        # ``session.scalar`` returns only the first row. Filters should
        # identify a single row; any additional matches are discarded.
        obj = await self.session.scalar(stmt)
        if obj is None:
            raise NotFound()
        return obj

    async def _get(self, *expressions: ColumnElement, stmt: Select | None = None) -> T:
        """Return the (non soft-deleted) row matching *expressions*.

        Args:
            *expressions: WHERE criteria identifying the row.
            stmt: Optional base statement (see ``_build_select``).

        Raises:
            NotFound: if no row matches.
        """
        return await self._first_or_not_found(
            self._build_select(*expressions, stmt=stmt)
        )

    async def _get_for_update(
        self, *expressions: ColumnElement, stmt: Select | None = None
    ) -> T:
        """Like ``_get``, but issues ``SELECT ... FOR UPDATE`` to lock the row.

        Raises:
            NotFound: if no row matches.
        """
        return await self._first_or_not_found(
            self._build_select(*expressions, stmt=stmt).with_for_update()
        )

    async def _list(
        self,
        *expressions: ColumnElement,
        offset: int = 0,
        limit: int | None = DEFAULT_LIMIT,
        stmt: Select | None = None,
    ) -> Sequence[T]:
        """Return the (non soft-deleted) rows matching *expressions*.

        Args:
            *expressions: WHERE criteria.
            offset: Number of rows to skip. ``0`` adds no OFFSET clause.
            limit: Maximum number of rows. ``None`` disables the limit.
                Defaults to ``DEFAULT_LIMIT``.
            stmt: Optional base statement (see ``_build_select``).
        """
        stmt = self._build_select(*expressions, stmt=stmt)
        if offset:
            stmt = stmt.offset(offset)
        if limit is not None:
            stmt = stmt.limit(limit)
        result = await self.session.execute(stmt)
        return result.scalars().all()

    async def _flush_or_conflict(
        self, conflicts: tuple[type[Exception], ...] = (IntegrityError,)
    ) -> None:
        """Flush the session, translating any of *conflicts* into ``Conflict``.

        The original database error is chained as ``__cause__``, so it stays
        visible in tracebacks and logs.

        Raises:
            Conflict: if the flush raised one of *conflicts*.
        """
        try:
            await self.session.flush()
        except conflicts as err:
            raise Conflict() from err

    async def create(self, obj: T) -> T:
        """Add *obj*, flush it, and reload database-generated state.

        Returns:
            The same *obj*, refreshed from the database.

        Raises:
            Conflict: on an integrity violation (e.g. a duplicate key).
        """
        self.session.add(obj)
        await self._flush_or_conflict()
        await self.session.refresh(obj)
        return obj

    async def update(self, obj: T) -> T:
        """Flush pending changes to *obj* and return it.

        *obj* is not added to the session here. It is expected to already
        be attached to this session and modified by the caller.

        Raises:
            Conflict: on an integrity violation, or when the row was changed
                or removed concurrently (``StaleDataError``).
        """
        await self._flush_or_conflict((IntegrityError, StaleDataError))
        return obj

    async def delete(self, obj: T) -> None:
        """Delete *obj*.

        Soft-deletable models are marked with the current ``utc_now()``
        timestamp. All other models are removed from the database.

        Raises:
            Conflict: on an integrity violation, or on a concurrent
                modification of the row.
        """
        if self._is_soft_delete:
            soft_obj = cast(SoftDeletable, obj)
            soft_obj.deleted_at = utc_now()
        else:
            await self.session.delete(obj)
        # A soft delete is an UPDATE and a hard delete is a DELETE. Either
        # can hit the same optimistic-concurrency StaleDataError as
        # ``update``, so map it to Conflict in the same way.
        await self._flush_or_conflict((IntegrityError, StaleDataError))