[repository] changed my mind about building statement, dont want Any
This commit is contained in:
parent
10ef646d93
commit
4182072be2
3 changed files with 28 additions and 68 deletions
|
|
@ -1,6 +1,11 @@
|
||||||
from typing import TypeVar, Generic, Type, Any, Sequence
|
from typing import TypeVar, Generic, Type, Any, Sequence
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy import select, update as sa_update, delete as sa_delete
|
from sqlalchemy import (
|
||||||
|
select,
|
||||||
|
update as sa_update,
|
||||||
|
delete as sa_delete,
|
||||||
|
BinaryExpression,
|
||||||
|
)
|
||||||
from sqlalchemy.sql import Select
|
from sqlalchemy.sql import Select
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
@ -11,24 +16,21 @@ class RepositoryBase(Generic[T]):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.session = session
|
self.session = session
|
||||||
|
|
||||||
def _build_select(self, **filters: Any) -> Select[tuple[T]]:
|
def _build_select(self, *expressions: BinaryExpression) -> Select[tuple[T]]:
|
||||||
stmt = select(self.model)
|
stmt = select(self.model)
|
||||||
|
|
||||||
for field, value in filters.items():
|
if expressions:
|
||||||
column = getattr(self.model, field, None)
|
stmt = stmt.where(*expressions)
|
||||||
if column is None:
|
|
||||||
raise ValueError(f"{self.model.__name__} has no attribute '{field}'")
|
|
||||||
stmt = stmt.where(column == value)
|
|
||||||
|
|
||||||
return stmt
|
return stmt
|
||||||
|
|
||||||
async def get(self, **filters: Any) -> T | None:
|
async def get(self, *expressions: BinaryExpression) -> T | None:
|
||||||
stmt = self._build_select(**filters)
|
stmt = self._build_select(*expressions)
|
||||||
result = await self.session.execute(stmt)
|
result = await self.session.execute(stmt)
|
||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
async def list(self, **filters: Any) -> Sequence[T]:
|
async def list(self, *expressions: BinaryExpression) -> Sequence[T]:
|
||||||
stmt = self._build_select(**filters)
|
stmt = self._build_select(*expressions)
|
||||||
result = await self.session.execute(stmt)
|
result = await self.session.execute(stmt)
|
||||||
return result.scalars().all()
|
return result.scalars().all()
|
||||||
|
|
||||||
|
|
@ -38,22 +40,10 @@ class RepositoryBase(Generic[T]):
|
||||||
await self.session.refresh(obj)
|
await self.session.refresh(obj)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
async def delete(self, **filters: Any) -> int:
|
async def delete(self, *expressions: BinaryExpression):
|
||||||
stmt = sa_delete(self.model)
|
stmt = sa_delete(self.model)
|
||||||
|
|
||||||
for field, value in filters.items():
|
if expressions:
|
||||||
column = getattr(self.model, field)
|
stmt = stmt.where(*expressions)
|
||||||
stmt = stmt.where(column == value)
|
|
||||||
|
|
||||||
result = await self.session.execute(stmt)
|
await self.session.execute(stmt)
|
||||||
return result.rowcount if result.rowcount != -1 else 0
|
|
||||||
|
|
||||||
async def update(self, filters: dict[str, Any], values: dict[str, Any]) -> int:
|
|
||||||
stmt = sa_update(self.model)
|
|
||||||
|
|
||||||
for field, value in filters.items():
|
|
||||||
stmt = stmt.where(getattr(self.model, field) == value)
|
|
||||||
|
|
||||||
stmt = stmt.values(**values)
|
|
||||||
result = await self.session.execute(stmt)
|
|
||||||
return result.rowcount if result.rowcount != -1 else 0
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import pytest
|
import pytest
|
||||||
import faker
|
from ..fixtures.db import TestModel
|
||||||
# ------------------------------------------------------------------ create ---
|
# ------------------------------------------------------------------ create ---
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -14,19 +14,19 @@ async def test_create_returns_object_with_id(test_repo, test_model):
|
||||||
|
|
||||||
async def test_get_returns_existing_record(test_repo, test_model):
|
async def test_get_returns_existing_record(test_repo, test_model):
|
||||||
created = await test_repo.create(test_model)
|
created = await test_repo.create(test_model)
|
||||||
found = await test_repo.get(id=created.id)
|
found = await test_repo.get(TestModel.id == created.id)
|
||||||
assert found is not None
|
assert found is not None
|
||||||
assert found.id == created.id
|
assert found.id == created.id
|
||||||
|
|
||||||
|
|
||||||
async def test_get_returns_none_for_missing_record(test_repo):
|
async def test_get_returns_none_for_missing_record(test_repo):
|
||||||
result = await test_repo.get(id=1)
|
result = await test_repo.get(TestModel.id == 1)
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
async def test_get_by_field(test_repo, test_model_factory):
|
async def test_get_by_field(test_repo, test_model_factory):
|
||||||
await test_repo.create(test_model_factory(1, "value"))
|
await test_repo.create(test_model_factory(1, "value"))
|
||||||
found = await test_repo.get(property="value")
|
found = await test_repo.get(TestModel.property == "value")
|
||||||
assert found is not None
|
assert found is not None
|
||||||
assert found.id == 1
|
assert found.id == 1
|
||||||
|
|
||||||
|
|
@ -45,52 +45,21 @@ async def test_list_returns_all_matching(test_repo, test_model_factory):
|
||||||
async def test_list_with_filter(test_repo, test_model_factory):
|
async def test_list_with_filter(test_repo, test_model_factory):
|
||||||
await test_repo.create(test_model_factory(1, "value"))
|
await test_repo.create(test_model_factory(1, "value"))
|
||||||
await test_repo.create(test_model_factory(2, "value2"))
|
await test_repo.create(test_model_factory(2, "value2"))
|
||||||
results = await test_repo.list(property="value")
|
results = await test_repo.list(TestModel.property == "value")
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
assert results[0].id == 1
|
assert results[0].id == 1
|
||||||
|
|
||||||
|
|
||||||
async def test_list_returns_empty_when_no_match(test_repo, test_model_factory):
|
async def test_list_returns_empty_when_no_match(test_repo, test_model_factory):
|
||||||
await test_repo.create(test_model_factory(1, "value"))
|
await test_repo.create(test_model_factory(1, "value"))
|
||||||
results = await test_repo.list(property="value2")
|
results = await test_repo.list(TestModel.property == "value2")
|
||||||
assert results == []
|
assert results == []
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------ delete ---
|
# ------------------------------------------------------------------ delete ---
|
||||||
|
|
||||||
|
|
||||||
async def test_delete_removes_record_and_returns_count(test_repo, test_model):
|
async def test_delete_removes_record(test_repo, test_model):
|
||||||
model = await test_repo.create(test_model)
|
model = await test_repo.create(test_model)
|
||||||
count = await test_repo.delete(id=model.id)
|
await test_repo.delete(TestModel.id == model.id)
|
||||||
assert count == 1
|
assert await test_repo.get(TestModel.id == model.id) is None
|
||||||
assert await test_repo.get(id=model.id) is None
|
|
||||||
|
|
||||||
|
|
||||||
async def test_delete_returns_zero_when_nothing_matched(test_repo, test_model_factory):
|
|
||||||
count = await test_repo.delete(id=999999)
|
|
||||||
assert count == 0
|
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------ update ---
|
|
||||||
|
|
||||||
|
|
||||||
async def test_update_modifies_record_and_returns_count(test_repo, test_model):
|
|
||||||
model = await test_repo.create(test_model)
|
|
||||||
count = await test_repo.update({"id": model.id}, {"property": "value2"})
|
|
||||||
assert count == 1
|
|
||||||
refreshed = await test_repo.get(id=model.id)
|
|
||||||
assert refreshed is not None
|
|
||||||
assert refreshed.property == "value2"
|
|
||||||
|
|
||||||
|
|
||||||
async def test_update_returns_zero_when_nothing_matched(test_repo, test_model_factory):
|
|
||||||
count = await test_repo.update({"id": 999999}, {"property": "value"})
|
|
||||||
assert count == 0
|
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------- _build_select ------
|
|
||||||
|
|
||||||
|
|
||||||
def test_build_select_raises_for_unknown_field(test_repo):
|
|
||||||
with pytest.raises(ValueError, match="has no attribute 'nonexistent'"):
|
|
||||||
test_repo._build_select(nonexistent="value")
|
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from ..model.token import RefreshTokenPayload, Token
|
from ..model.token import RefreshTokenPayload, Token
|
||||||
from ..context import ContextDependency, Context
|
from ..context import ContextDependency, Context
|
||||||
from ..utils.jwt import AccessToken, RefreshToken
|
from ..utils.jwt import AccessToken, RefreshToken
|
||||||
|
from ..domain import User
|
||||||
|
|
||||||
router = APIRouter(tags=["token"])
|
router = APIRouter(tags=["token"])
|
||||||
|
|
||||||
|
|
@ -16,7 +17,7 @@ async def create_token(
|
||||||
data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||||
ctx: Context = Depends(ContextDependency()),
|
ctx: Context = Depends(ContextDependency()),
|
||||||
):
|
):
|
||||||
user = await ctx.repo.user.get(username=data.username)
|
user = await ctx.repo.user.get(User.username == data.username)
|
||||||
|
|
||||||
if user is None or not user.verify_password(data.password):
|
if user is None or not user.verify_password(data.password):
|
||||||
raise HTTPException(status_code=401, detail="Unathorized")
|
raise HTTPException(status_code=401, detail="Unathorized")
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue