from typing import TypeVar, Generic, Type, Sequence

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete as sa_delete, ColumnElement
from sqlalchemy.exc import IntegrityError
from sqlalchemy.sql import Select

from fooder.domain import Base
from fooder.exc import Conflict, NotFound

T = TypeVar("T", bound=Base)

# Default page size used by RepositoryBase._list when no limit is passed.
DEFAULT_LIMIT = 20


class RepositoryBase(Generic[T]):
    """Generic async repository providing basic CRUD helpers for one model.

    ``T`` is bound to ``fooder.domain.Base``; subclasses are expected to
    build their public query methods on top of the underscore-prefixed
    helpers below.
    """

    def __init__(self, model: Type[T], session: AsyncSession):
        """Store the mapped class to query and the session to run statements on.

        :param model: Mapped class this repository operates on.
        :param session: Async session used for every statement. Nothing in
            this class commits or rolls it back; transaction control is left
            to the caller.
        """
        self.model = model
        self.session = session

    def _build_select(self, *expressions: ColumnElement) -> Select[tuple[T]]:
        """Return ``SELECT model`` filtered by ``expressions`` (AND-ed together).

        With no expressions the statement is unfiltered.
        """
        stmt = select(self.model)
        if expressions:
            stmt = stmt.where(*expressions)
        return stmt

    async def _get(self, *expressions: ColumnElement) -> T:
        """Fetch the single row matching ``expressions``.

        :raises NotFound: if no row matches.
        :raises sqlalchemy.exc.MultipleResultsFound: if more than one row
            matches (raised by ``scalar_one_or_none``).
        """
        stmt = self._build_select(*expressions)
        result = await self.session.execute(stmt)
        obj = result.scalar_one_or_none()
        if obj is None:
            raise NotFound()
        return obj

    async def _list(
        self,
        *expressions: ColumnElement,
        offset: int = 0,
        limit: int | None = DEFAULT_LIMIT,
    ) -> Sequence[T]:
        """Fetch one page of rows matching ``expressions``.

        :param offset: Number of rows to skip; ``0`` adds no OFFSET clause.
        :param limit: Maximum number of rows to return; ``None`` omits the
            LIMIT clause entirely.
        :returns: The matched model instances.
        """
        # NOTE(review): no ORDER BY is applied here, so OFFSET/LIMIT paging
        # may be unstable unless callers or the database guarantee an
        # order — confirm.
        stmt = self._build_select(*expressions)
        if offset:
            stmt = stmt.offset(offset)
        if limit is not None:
            stmt = stmt.limit(limit)
        result = await self.session.execute(stmt)
        return result.scalars().all()

    async def _flush_or_conflict(self) -> None:
        """Flush the session, translating ``IntegrityError`` into ``Conflict``.

        The original ``IntegrityError`` is chained as ``__cause__`` so the
        violated constraint stays visible in tracebacks and logs. No
        rollback is performed here. After a failed flush, SQLAlchemy
        requires the caller to roll the session back before reusing it.
        """
        try:
            await self.session.flush()
        except IntegrityError as exc:
            raise Conflict() from exc

    async def create(self, obj: T) -> T:
        """Add ``obj`` to the session, flush it, and reload it from the database.

        The refresh re-reads the row so database-populated values (presumably
        generated keys or server defaults) are present on the returned object.

        :raises Conflict: if the flush violates an integrity constraint.
        """
        self.session.add(obj)
        await self._flush_or_conflict()
        await self.session.refresh(obj)
        return obj

    async def update(self, obj: T) -> T:
        """Flush pending session changes and return ``obj``.

        ``obj`` is not added to the session here; it is assumed to already
        be attached (TODO confirm with callers). ``flush()`` writes *all*
        pending changes in the session, not only those of ``obj``.

        :raises Conflict: if the flush violates an integrity constraint.
        """
        await self._flush_or_conflict()
        return obj

    async def _delete(self, *expressions: ColumnElement) -> None:
        """Execute a single bulk ``DELETE`` for rows matching ``expressions``.

        With no expressions, every row of the model's table is deleted.
        """
        stmt = sa_delete(self.model)
        if expressions:
            stmt = stmt.where(*expressions)
        await self.session.execute(stmt)