diff --git a/fooder/repository/base.py b/fooder/repository/base.py index a8a8111..d56d591 100644 --- a/fooder/repository/base.py +++ b/fooder/repository/base.py @@ -1,6 +1,11 @@ from typing import TypeVar, Generic, Type, Any, Sequence from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import select, update as sa_update, delete as sa_delete +from sqlalchemy import ( + select, + update as sa_update, + delete as sa_delete, + BinaryExpression, +) from sqlalchemy.sql import Select T = TypeVar("T") @@ -11,24 +16,21 @@ class RepositoryBase(Generic[T]): self.model = model self.session = session - def _build_select(self, **filters: Any) -> Select[tuple[T]]: + def _build_select(self, *expressions: BinaryExpression) -> Select[tuple[T]]: stmt = select(self.model) - for field, value in filters.items(): - column = getattr(self.model, field, None) - if column is None: - raise ValueError(f"{self.model.__name__} has no attribute '{field}'") - stmt = stmt.where(column == value) + if expressions: + stmt = stmt.where(*expressions) return stmt - async def get(self, **filters: Any) -> T | None: - stmt = self._build_select(**filters) + async def get(self, *expressions: BinaryExpression) -> T | None: + stmt = self._build_select(*expressions) result = await self.session.execute(stmt) return result.scalar_one_or_none() - async def list(self, **filters: Any) -> Sequence[T]: - stmt = self._build_select(**filters) + async def list(self, *expressions: BinaryExpression) -> Sequence[T]: + stmt = self._build_select(*expressions) result = await self.session.execute(stmt) return result.scalars().all() @@ -38,22 +40,10 @@ class RepositoryBase(Generic[T]): await self.session.refresh(obj) return obj - async def delete(self, **filters: Any) -> int: + async def delete(self, *expressions: BinaryExpression): stmt = sa_delete(self.model) - for field, value in filters.items(): - column = getattr(self.model, field) - stmt = stmt.where(column == value) + if expressions: + stmt = stmt.where(*expressions) - result = await self.session.execute(stmt) - return result.rowcount if result.rowcount != -1 else 0 - - async def update(self, filters: dict[str, Any], values: dict[str, Any]) -> int: - stmt = sa_update(self.model) - - for field, value in filters.items(): - stmt = stmt.where(getattr(self.model, field) == value) - - stmt = stmt.values(**values) - result = await self.session.execute(stmt) - return result.rowcount if result.rowcount != -1 else 0 + await self.session.execute(stmt) diff --git a/fooder/test/repository/test_base.py b/fooder/test/repository/test_base.py index f1b8d97..5973c36 100644 --- a/fooder/test/repository/test_base.py +++ b/fooder/test/repository/test_base.py @@ -1,5 +1,5 @@ import pytest -import faker +from ..fixtures.db import TestModel # ------------------------------------------------------------------ create --- @@ -14,19 +14,19 @@ async def test_create_returns_object_with_id(test_repo, test_model): async def test_get_returns_existing_record(test_repo, test_model): created = await test_repo.create(test_model) - found = await test_repo.get(id=created.id) + found = await test_repo.get(TestModel.id == created.id) assert found is not None assert found.id == created.id async def test_get_returns_none_for_missing_record(test_repo): - result = await test_repo.get(id=1) + result = await test_repo.get(TestModel.id == 1) assert result is None async def test_get_by_field(test_repo, test_model_factory): await test_repo.create(test_model_factory(1, "value")) - found = await test_repo.get(property="value") + found = await test_repo.get(TestModel.property == "value") assert found is not None assert found.id == 1 @@ -45,52 +45,21 @@ async def test_list_returns_all_matching(test_repo, test_model_factory): async def test_list_with_filter(test_repo, test_model_factory): await test_repo.create(test_model_factory(1, "value")) await test_repo.create(test_model_factory(2, "value2")) - results = await test_repo.list(property="value") + results = await test_repo.list(TestModel.property == "value") assert len(results) == 1 assert results[0].id == 1 async def test_list_returns_empty_when_no_match(test_repo, test_model_factory): await test_repo.create(test_model_factory(1, "value")) - results = await test_repo.list(property="value2") + results = await test_repo.list(TestModel.property == "value2") assert results == [] # ------------------------------------------------------------------ delete --- -async def test_delete_removes_record_and_returns_count(test_repo, test_model): +async def test_delete_removes_record(test_repo, test_model): model = await test_repo.create(test_model) - count = await test_repo.delete(id=model.id) - assert count == 1 - assert await test_repo.get(id=model.id) is None - - -async def test_delete_returns_zero_when_nothing_matched(test_repo, test_model_factory): - count = await test_repo.delete(id=999999) - assert count == 0 - - -# ------------------------------------------------------------------ update --- - - -async def test_update_modifies_record_and_returns_count(test_repo, test_model): - model = await test_repo.create(test_model) - count = await test_repo.update({"id": model.id}, {"property": "value2"}) - assert count == 1 - refreshed = await test_repo.get(id=model.id) - assert refreshed is not None - assert refreshed.property == "value2" - - -async def test_update_returns_zero_when_nothing_matched(test_repo, test_model_factory): - count = await test_repo.update({"id": 999999}, {"property": "value"}) - assert count == 0 - - -# ------------------------------------------------------- _build_select ------ - - -def test_build_select_raises_for_unknown_field(test_repo): - with pytest.raises(ValueError, match="has no attribute 'nonexistent'"): - test_repo._build_select(nonexistent="value") + await test_repo.delete(TestModel.id == model.id) + assert await test_repo.get(TestModel.id == model.id) is None diff --git a/fooder/view/token.py b/fooder/view/token.py index d6475a0..f5f86af 100644 --- a/fooder/view/token.py +++ b/fooder/view/token.py @@ -7,6 +7,7 @@ from fastapi.security import OAuth2PasswordRequestForm from ..model.token import RefreshTokenPayload, Token from ..context import ContextDependency, Context from ..utils.jwt import AccessToken, RefreshToken +from ..domain import User router = APIRouter(tags=["token"]) @@ -16,7 +17,7 @@ async def create_token( data: Annotated[OAuth2PasswordRequestForm, Depends()], ctx: Context = Depends(ContextDependency()), ): - user = await ctx.repo.user.get(username=data.username) + user = await ctx.repo.user.get(User.username == data.username) if user is None or not user.verify_password(data.password): raise HTTPException(status_code=401, detail="Unathorized")