[repository] changed my mind about building statement, dont want Any

This commit is contained in:
Piotr Domański 2026-04-02 23:35:20 +02:00
parent 10ef646d93
commit 4182072be2
3 changed files with 28 additions and 68 deletions

View file

@ -1,6 +1,11 @@
from typing import TypeVar, Generic, Type, Any, Sequence from typing import TypeVar, Generic, Type, Any, Sequence
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update as sa_update, delete as sa_delete from sqlalchemy import (
select,
update as sa_update,
delete as sa_delete,
BinaryExpression,
)
from sqlalchemy.sql import Select from sqlalchemy.sql import Select
T = TypeVar("T") T = TypeVar("T")
@ -11,24 +16,21 @@ class RepositoryBase(Generic[T]):
self.model = model self.model = model
self.session = session self.session = session
def _build_select(self, **filters: Any) -> Select[tuple[T]]: def _build_select(self, *expressions: BinaryExpression) -> Select[tuple[T]]:
stmt = select(self.model) stmt = select(self.model)
for field, value in filters.items(): if expressions:
column = getattr(self.model, field, None) stmt = stmt.where(*expressions)
if column is None:
raise ValueError(f"{self.model.__name__} has no attribute '{field}'")
stmt = stmt.where(column == value)
return stmt return stmt
async def get(self, **filters: Any) -> T | None: async def get(self, *expressions: BinaryExpression) -> T | None:
stmt = self._build_select(**filters) stmt = self._build_select(*expressions)
result = await self.session.execute(stmt) result = await self.session.execute(stmt)
return result.scalar_one_or_none() return result.scalar_one_or_none()
async def list(self, **filters: Any) -> Sequence[T]: async def list(self, *expressions: BinaryExpression) -> Sequence[T]:
stmt = self._build_select(**filters) stmt = self._build_select(*expressions)
result = await self.session.execute(stmt) result = await self.session.execute(stmt)
return result.scalars().all() return result.scalars().all()
@ -38,22 +40,10 @@ class RepositoryBase(Generic[T]):
await self.session.refresh(obj) await self.session.refresh(obj)
return obj return obj
async def delete(self, **filters: Any) -> int: async def delete(self, *expressions: BinaryExpression):
stmt = sa_delete(self.model) stmt = sa_delete(self.model)
for field, value in filters.items(): if expressions:
column = getattr(self.model, field) stmt = stmt.where(*expressions)
stmt = stmt.where(column == value)
result = await self.session.execute(stmt) await self.session.execute(stmt)
return result.rowcount if result.rowcount != -1 else 0
async def update(self, filters: dict[str, Any], values: dict[str, Any]) -> int:
stmt = sa_update(self.model)
for field, value in filters.items():
stmt = stmt.where(getattr(self.model, field) == value)
stmt = stmt.values(**values)
result = await self.session.execute(stmt)
return result.rowcount if result.rowcount != -1 else 0

View file

@ -1,5 +1,5 @@
import pytest import pytest
import faker from ..fixtures.db import TestModel
# ------------------------------------------------------------------ create --- # ------------------------------------------------------------------ create ---
@ -14,19 +14,19 @@ async def test_create_returns_object_with_id(test_repo, test_model):
async def test_get_returns_existing_record(test_repo, test_model): async def test_get_returns_existing_record(test_repo, test_model):
created = await test_repo.create(test_model) created = await test_repo.create(test_model)
found = await test_repo.get(id=created.id) found = await test_repo.get(TestModel.id == created.id)
assert found is not None assert found is not None
assert found.id == created.id assert found.id == created.id
async def test_get_returns_none_for_missing_record(test_repo): async def test_get_returns_none_for_missing_record(test_repo):
result = await test_repo.get(id=1) result = await test_repo.get(TestModel.id == 1)
assert result is None assert result is None
async def test_get_by_field(test_repo, test_model_factory): async def test_get_by_field(test_repo, test_model_factory):
await test_repo.create(test_model_factory(1, "value")) await test_repo.create(test_model_factory(1, "value"))
found = await test_repo.get(property="value") found = await test_repo.get(TestModel.property == "value")
assert found is not None assert found is not None
assert found.id == 1 assert found.id == 1
@ -45,52 +45,21 @@ async def test_list_returns_all_matching(test_repo, test_model_factory):
async def test_list_with_filter(test_repo, test_model_factory): async def test_list_with_filter(test_repo, test_model_factory):
await test_repo.create(test_model_factory(1, "value")) await test_repo.create(test_model_factory(1, "value"))
await test_repo.create(test_model_factory(2, "value2")) await test_repo.create(test_model_factory(2, "value2"))
results = await test_repo.list(property="value") results = await test_repo.list(TestModel.property == "value")
assert len(results) == 1 assert len(results) == 1
assert results[0].id == 1 assert results[0].id == 1
async def test_list_returns_empty_when_no_match(test_repo, test_model_factory): async def test_list_returns_empty_when_no_match(test_repo, test_model_factory):
await test_repo.create(test_model_factory(1, "value")) await test_repo.create(test_model_factory(1, "value"))
results = await test_repo.list(property="value2") results = await test_repo.list(TestModel.property == "value2")
assert results == [] assert results == []
# ------------------------------------------------------------------ delete --- # ------------------------------------------------------------------ delete ---
async def test_delete_removes_record_and_returns_count(test_repo, test_model): async def test_delete_removes_record(test_repo, test_model):
model = await test_repo.create(test_model) model = await test_repo.create(test_model)
count = await test_repo.delete(id=model.id) await test_repo.delete(TestModel.id == model.id)
assert count == 1 assert await test_repo.get(TestModel.id == model.id) is None
assert await test_repo.get(id=model.id) is None
async def test_delete_returns_zero_when_nothing_matched(test_repo, test_model_factory):
count = await test_repo.delete(id=999999)
assert count == 0
# ------------------------------------------------------------------ update ---
async def test_update_modifies_record_and_returns_count(test_repo, test_model):
model = await test_repo.create(test_model)
count = await test_repo.update({"id": model.id}, {"property": "value2"})
assert count == 1
refreshed = await test_repo.get(id=model.id)
assert refreshed is not None
assert refreshed.property == "value2"
async def test_update_returns_zero_when_nothing_matched(test_repo, test_model_factory):
count = await test_repo.update({"id": 999999}, {"property": "value"})
assert count == 0
# ------------------------------------------------------- _build_select ------
def test_build_select_raises_for_unknown_field(test_repo):
with pytest.raises(ValueError, match="has no attribute 'nonexistent'"):
test_repo._build_select(nonexistent="value")

View file

@ -7,6 +7,7 @@ from fastapi.security import OAuth2PasswordRequestForm
from ..model.token import RefreshTokenPayload, Token from ..model.token import RefreshTokenPayload, Token
from ..context import ContextDependency, Context from ..context import ContextDependency, Context
from ..utils.jwt import AccessToken, RefreshToken from ..utils.jwt import AccessToken, RefreshToken
from ..domain import User
router = APIRouter(tags=["token"]) router = APIRouter(tags=["token"])
@ -16,7 +17,7 @@ async def create_token(
data: Annotated[OAuth2PasswordRequestForm, Depends()], data: Annotated[OAuth2PasswordRequestForm, Depends()],
ctx: Context = Depends(ContextDependency()), ctx: Context = Depends(ContextDependency()),
): ):
user = await ctx.repo.user.get(username=data.username) user = await ctx.repo.user.get(User.username == data.username)
if user is None or not user.verify_password(data.password): if user is None or not user.verify_password(data.password):
raise HTTPException(status_code=401, detail="Unathorized") raise HTTPException(status_code=401, detail="Unathorized")