[repo] encapsulate sqlalchemy completely

This commit is contained in:
Piotr Domański 2026-04-07 14:51:39 +02:00
parent 6129712efe
commit 446850ee12
9 changed files with 40 additions and 99 deletions

View file

@ -65,7 +65,7 @@ class AuthContextDependency:
ctx = Context(repo=Repository(session))
user_id = AccessToken.decode(token).sub
try:
user = await ctx.repo.user.get(User.id == user_id)
user = await ctx.repo.user.get_by_id(user_id)
except NotFound:
raise Unauthorized()
ctx.set_user(user)

View file

@ -13,7 +13,7 @@ class UserController(ModelController[User]):
password: str,
) -> "UserController":
try:
obj = await ctx.repo.user.get(User.username == username)
obj = await ctx.repo.user.get_by_username(username)
except NotFound:
raise Unauthorized()

View file

@ -4,6 +4,5 @@ from .entry import Entry # noqa
from .meal import Meal # noqa
from .product import Product # noqa
from .user import User # noqa
from .token import RefreshToken # noqa
from .preset import Preset # noqa
from .preset_entry import PresetEntry # noqa

View file

@ -1,71 +0,0 @@
from typing import Optional
from sqlalchemy import ForeignKey, Integer, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, mapped_column
from .base import Base, CommonMixin
class RefreshToken(Base, CommonMixin):
"""Diary represents user diary for given day"""
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
token: Mapped[str]
@classmethod
async def get_token(
cls,
session: AsyncSession,
user_id: int,
token: str,
) -> "Optional[RefreshToken]":
"""get_token.
:param session:
:type session: AsyncSession
:param user_id:
:type user_id: int
:param token:
:type token: str
:rtype: "Optional[RefreshToken]"
"""
query = select(cls).where(cls.user_id == user_id).where(cls.token == token)
return await session.scalar(query)
@classmethod
async def create(
cls, session: AsyncSession, user_id: int, token: str
) -> "RefreshToken":
"""create.
:param session:
:type session: AsyncSession
:param user_id:
:type user_id: int
:param token:
:type token: str
:rtype: "RefreshToken"
"""
db_token = cls(
user_id=user_id,
token=token,
)
session.add(db_token)
try:
await session.flush()
except Exception:
raise AssertionError("invalid token")
return db_token
async def delete(self, session: AsyncSession) -> None:
"""delete.
:param session:
:type session: AsyncSession
:rtype: None
"""
await session.delete(self)
await session.flush()

View file

@ -25,7 +25,7 @@ class RepositoryBase(Generic[T]):
return stmt
async def get(self, *expressions: ColumnElement) -> T:
async def _get(self, *expressions: ColumnElement) -> T:
stmt = self._build_select(*expressions)
result = await self.session.execute(stmt)
obj = result.scalar_one_or_none()
@ -35,14 +35,14 @@ class RepositoryBase(Generic[T]):
return obj
async def list(self, *expressions: ColumnElement, offset: int = 0, limit: int | None = DEFAULT_LIMIT) -> Sequence[T]:
async def _list(self, *expressions: ColumnElement, offset: int = 0, limit: int | None = DEFAULT_LIMIT) -> Sequence[T]:
stmt = self._build_select(*expressions)
if offset:
result = stmt.offset(offset)
stmt = stmt.offset(offset)
if limit is not None:
result = stmt.limit(limit)
stmt = stmt.limit(limit)
result = await self.session.execute(stmt)
return result.scalars().all()

View file

@ -1,6 +1,22 @@
from .base import RepositoryBase
from ..domain import Product
from typing import Sequence
from fooder.domain import Product
from fooder.repository.expression import fuzzy_match
from .base import RepositoryBase, DEFAULT_LIMIT
class ProductRepository(RepositoryBase[Product]):
pass
async def get_by_id(self, product_id: int) -> Product:
return await self._get(Product.id == product_id)
async def get_by_barcode(self, barcode: str) -> Product:
return await self._get(Product.barcode == barcode)
async def list(
self,
q: str | None = None,
offset: int = 0,
limit: int | None = DEFAULT_LIMIT,
) -> Sequence[Product]:
expressions = (fuzzy_match(Product.name, q),) if q else ()
return await self._list(*expressions, offset=offset, limit=limit)

View file

@ -1,6 +1,10 @@
from fooder.domain import User
from .base import RepositoryBase
from ..domain import User
class UserRepository(RepositoryBase[User]):
pass
async def get_by_id(self, user_id: int) -> User:
return await self._get(User.id == user_id)
async def get_by_username(self, username: str) -> User:
return await self._get(User.username == username)

View file

@ -15,19 +15,19 @@ async def test_create_returns_object_with_id(test_repo, test_model):
async def test_get_returns_existing_record(test_repo, test_model):
created = await test_repo.create(test_model)
found = await test_repo.get(TestModel.id == created.id)
found = await test_repo._get(TestModel.id == created.id)
assert found is not None
assert found.id == created.id
async def test_get_raises_not_found_for_missing_record(test_repo):
with pytest.raises(NotFound):
await test_repo.get(TestModel.id == 1)
await test_repo._get(TestModel.id == 1)
async def test_get_by_field(test_repo, test_model_factory):
await test_repo.create(test_model_factory(1, "value"))
found = await test_repo.get(TestModel.property == "value")
found = await test_repo._get(TestModel.property == "value")
assert found is not None
assert found.id == 1
@ -38,7 +38,7 @@ async def test_get_by_field(test_repo, test_model_factory):
async def test_list_returns_all_matching(test_repo, test_model_factory):
await test_repo.create(test_model_factory(1))
await test_repo.create(test_model_factory(2))
results = await test_repo.list()
results = await test_repo._list()
modelnames = {u.id for u in results}
assert {1, 2}.issubset(modelnames)
@ -46,14 +46,14 @@ async def test_list_returns_all_matching(test_repo, test_model_factory):
async def test_list_with_filter(test_repo, test_model_factory):
await test_repo.create(test_model_factory(1, "value"))
await test_repo.create(test_model_factory(2, "value2"))
results = await test_repo.list(TestModel.property == "value")
results = await test_repo._list(TestModel.property == "value")
assert len(results) == 1
assert results[0].id == 1
async def test_list_returns_empty_when_no_match(test_repo, test_model_factory):
await test_repo.create(test_model_factory(1, "value"))
results = await test_repo.list(TestModel.property == "value2")
results = await test_repo._list(TestModel.property == "value2")
assert results == []
@ -64,4 +64,4 @@ async def test_delete_removes_record(test_repo, test_model):
model = await test_repo.create(test_model)
await test_repo.delete(TestModel.id == model.id)
with pytest.raises(NotFound):
await test_repo.get(TestModel.id == model.id)
await test_repo._get(TestModel.id == model.id)

View file

@ -1,7 +1,5 @@
from fastapi import APIRouter, Depends
from fooder.repository.expression import fuzzy_match
from fooder.domain import Product
from fooder.model.product import ProductModel, ProductCreateModel, ProductUpdateModel
from fooder.controller.product import ProductController
from fooder.context import Context, AuthContextDependency
@ -16,12 +14,7 @@ async def list_products(
offset: int = 0,
q: str | None = None,
):
expressions = []
if q:
expressions.append(
fuzzy_match(Product.name, q)
)
return await ctx.repo.product.list(*expressions, limit=limit, offset=offset)
return await ctx.repo.product.list(q=q, limit=limit, offset=offset)
@router.patch("/{product_id}", response_model=ProductModel)
@ -30,7 +23,7 @@ async def update_product(
data: ProductUpdateModel,
ctx: Context = Depends(AuthContextDependency()),
):
obj = await ctx.repo.product.get(Product.id == product_id)
obj = await ctx.repo.product.get_by_id(product_id)
async with ctx.repo.transaction():
await ProductController(ctx, obj).update(data)
return obj