fooder-api/fooder/repository/base.py

59 lines
2 KiB
Python

from typing import TypeVar, Generic, Type, Any, Sequence
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update as sa_update, delete as sa_delete
from sqlalchemy.sql import Select
# Instance type of the model class that a repository manages.
T = TypeVar("T")


class RepositoryBase(Generic[T]):
    """Generic async CRUD helpers for a single model class.

    Filters are given as ``attribute=value`` pairs. Each pair becomes an
    equality comparison against the model attribute of the same name, and
    multiple pairs are combined by chaining ``.where()`` (i.e. AND). Every
    method resolves filter names through the same helper, so an unknown
    attribute consistently raises ``ValueError``.

    No method here commits: statements are executed (and ``create`` flushes)
    on the session passed in, leaving transaction control to the caller.
    """

    def __init__(self, model: Type[T], session: AsyncSession):
        """Bind the repository to *model* and the *session* used for queries."""
        self.model = model
        self.session = session

    def _column(self, field: str) -> Any:
        """Return the model attribute named *field* for use in a WHERE clause.

        Any class attribute is accepted, not only mapped columns.
        NOTE(review): if filter names can ever come from untrusted input,
        consider restricting this to the model's mapped attributes.

        Raises:
            ValueError: if the model has no attribute called *field*.
        """
        column = getattr(self.model, field, None)
        if column is None:
            raise ValueError(f"{self.model.__name__} has no attribute '{field}'")
        return column

    def _apply_filters(self, stmt: Any, filters: dict[str, Any]) -> Any:
        """Return *stmt* with one ``column == value`` criterion per filter.

        Works for SELECT, UPDATE and DELETE statements alike. This keeps
        name validation identical across all public methods.
        """
        for field, value in filters.items():
            stmt = stmt.where(self._column(field) == value)
        return stmt

    @staticmethod
    def _affected(result: Any) -> int:
        """Return the statement's row count, mapping the DBAPI "unknown" (-1) to 0."""
        return result.rowcount if result.rowcount != -1 else 0

    def _build_select(self, **filters: Any) -> Select[tuple[T]]:
        """Build ``SELECT <model> WHERE <filters>`` (no WHERE when no filters).

        Raises:
            ValueError: if a filter names an attribute the model lacks.
        """
        return self._apply_filters(select(self.model), filters)

    async def get(self, **filters: Any) -> T | None:
        """Return the single row matching *filters*, or ``None`` if none match.

        Uses ``scalar_one_or_none``, so more than one matching row raises
        SQLAlchemy's ``MultipleResultsFound`` rather than picking one.

        Raises:
            ValueError: if a filter names an attribute the model lacks.
        """
        stmt = self._build_select(**filters)
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def list(self, **filters: Any) -> Sequence[T]:
        """Return every row matching *filters* (all rows when none are given).

        Raises:
            ValueError: if a filter names an attribute the model lacks.
        """
        stmt = self._build_select(**filters)
        result = await self.session.execute(stmt)
        return result.scalars().all()

    async def create(self, obj: T) -> T:
        """Add *obj* to the session, flush it, and reload its state.

        The flush emits the INSERT without committing. The refresh then
        re-reads the row so that values produced by the database are
        present on the returned object.
        """
        self.session.add(obj)
        await self.session.flush()
        await self.session.refresh(obj)
        return obj

    async def delete(self, **filters: Any) -> int:
        """Delete rows matching *filters* and return how many were removed.

        Warning: called with no filters, this issues an unconditional DELETE
        and removes every row of the model's table.

        Returns 0 when the driver cannot report a row count.

        Raises:
            ValueError: if a filter names an attribute the model lacks.
        """
        stmt = self._apply_filters(sa_delete(self.model), filters)
        result = await self.session.execute(stmt)
        return self._affected(result)

    async def update(self, filters: dict[str, Any], values: dict[str, Any]) -> int:
        """Set *values* on rows matching *filters* and return the affected count.

        Warning: an empty *filters* dict updates every row of the table.

        Returns 0 when the driver cannot report a row count.

        Raises:
            ValueError: if a filter names an attribute the model lacks.
        """
        stmt = self._apply_filters(sa_update(self.model), filters)
        stmt = stmt.values(**values)
        result = await self.session.execute(stmt)
        return self._affected(result)