from typing import TypeVar, Generic, Type, Any, Sequence

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update as sa_update, delete as sa_delete
from sqlalchemy.sql import Select

T = TypeVar("T")


class UnknownFieldError(ValueError, AttributeError):
    """Raised when a filter names an attribute the model does not define.

    It subclasses ``ValueError`` because ``get``/``list`` have always raised
    that. It also subclasses ``AttributeError`` because ``update``/``delete``
    previously leaked one from a bare ``getattr``. Existing ``except`` clauses
    for either type therefore keep working.
    """


class RepositoryBase(Generic[T]):
    """Generic async data-access helper bound to one model and one session.

    Filters are passed as ``field=value`` pairs. Each becomes a
    ``column == value`` criterion, and all criteria are ANDed together.

    No method here commits. ``create`` flushes; the other methods only execute
    statements on ``self.session``. Transaction control is left to the caller.
    """

    def __init__(self, model: Type[T], session: AsyncSession):
        self.model = model
        self.session = session

    def _column(self, field: str) -> Any:
        """Return ``self.model.<field>``.

        Raises:
            UnknownFieldError: if the model has no attribute named *field*.
        """
        column = getattr(self.model, field, None)
        if column is None:
            raise UnknownFieldError(
                f"{self.model.__name__} has no attribute '{field}'"
            )
        return column

    def _where(self, stmt: Any, filters: dict[str, Any]) -> Any:
        """Return *stmt* with one ``column == value`` criterion per filter.

        Works for select, update and delete statements alike, since each
        exposes ``.where()``. With no filters, *stmt* is returned unchanged.
        """
        for field, value in filters.items():
            stmt = stmt.where(self._column(field) == value)
        return stmt

    def _build_select(self, **filters: Any) -> Select[tuple[T]]:
        """Build ``SELECT <model>`` restricted by the equality *filters*."""
        return self._where(select(self.model), filters)

    @staticmethod
    def _rowcount(result: Any) -> int:
        """Return the affected-row count of *result*, or 0 if it is unknown.

        Per PEP 249, a DBAPI cursor reports ``rowcount == -1`` when the count
        cannot be determined. That case is normalised to 0.
        """
        return result.rowcount if result.rowcount != -1 else 0

    async def get(self, **filters: Any) -> T | None:
        """Return the single row matching *filters*, or ``None`` if none match.

        Raises:
            UnknownFieldError: if a filter names a non-existent attribute.
            sqlalchemy.exc.MultipleResultsFound: if more than one row matches.
        """
        stmt = self._build_select(**filters)
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    # NOTE: from here to the end of the class body, the bare name ``list``
    # refers to this method, not the builtin. Annotations evaluated at class
    # creation (e.g. ``list[int]`` on a later method) would fail, so use
    # ``Sequence``/``tuple`` there instead.
    async def list(self, **filters: Any) -> Sequence[T]:
        """Return every row matching *filters* (all rows if none are given).

        Raises:
            UnknownFieldError: if a filter names a non-existent attribute.
        """
        stmt = self._build_select(**filters)
        result = await self.session.execute(stmt)
        return result.scalars().all()

    async def create(self, obj: T) -> T:
        """Add *obj* to the session, flush it, refresh it, and return it.

        Flushing emits the INSERT without committing. Refreshing reloads the
        object's attributes from the database, e.g. server-generated keys and
        defaults.
        """
        self.session.add(obj)
        await self.session.flush()
        await self.session.refresh(obj)
        return obj

    async def delete(self, **filters: Any) -> int:
        """Delete rows matching *filters* and return how many were removed.

        WARNING: with no filters this deletes every row of the model's table.

        Raises:
            UnknownFieldError: if a filter names a non-existent attribute
                (also catchable as ``AttributeError``, as before).
        """
        stmt = self._where(sa_delete(self.model), filters)
        result = await self.session.execute(stmt)
        return self._rowcount(result)

    async def update(self, filters: dict[str, Any], values: dict[str, Any]) -> int:
        """Set *values* on rows matching *filters* and return the count.

        WARNING: an empty *filters* dict updates every row of the table.
        NOTE(review): behaviour with an empty *values* dict is left to
        SQLAlchemy — confirm whether callers ever pass one.

        Raises:
            UnknownFieldError: if a filter names a non-existent attribute
                (also catchable as ``AttributeError``, as before).
        """
        stmt = self._where(sa_update(self.model), filters)
        stmt = stmt.values(**values)
        result = await self.session.execute(stmt)
        return self._rowcount(result)