[repo] encapsulate sqlalchemy completely

This commit is contained in:
Piotr Domański 2026-04-07 14:51:39 +02:00
parent 6129712efe
commit 446850ee12
9 changed files with 40 additions and 99 deletions

View file

@ -65,7 +65,7 @@ class AuthContextDependency:
ctx = Context(repo=Repository(session)) ctx = Context(repo=Repository(session))
user_id = AccessToken.decode(token).sub user_id = AccessToken.decode(token).sub
try: try:
user = await ctx.repo.user.get(User.id == user_id) user = await ctx.repo.user.get_by_id(user_id)
except NotFound: except NotFound:
raise Unauthorized() raise Unauthorized()
ctx.set_user(user) ctx.set_user(user)

View file

@ -13,7 +13,7 @@ class UserController(ModelController[User]):
password: str, password: str,
) -> "UserController": ) -> "UserController":
try: try:
obj = await ctx.repo.user.get(User.username == username) obj = await ctx.repo.user.get_by_username(username)
except NotFound: except NotFound:
raise Unauthorized() raise Unauthorized()

View file

@ -4,6 +4,5 @@ from .entry import Entry # noqa
from .meal import Meal # noqa from .meal import Meal # noqa
from .product import Product # noqa from .product import Product # noqa
from .user import User # noqa from .user import User # noqa
from .token import RefreshToken # noqa
from .preset import Preset # noqa from .preset import Preset # noqa
from .preset_entry import PresetEntry # noqa from .preset_entry import PresetEntry # noqa

View file

@ -1,71 +0,0 @@
from typing import Optional
from sqlalchemy import ForeignKey, Integer, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, mapped_column
from .base import Base, CommonMixin
class RefreshToken(Base, CommonMixin):
"""Diary represents user diary for given day"""
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
token: Mapped[str]
@classmethod
async def get_token(
cls,
session: AsyncSession,
user_id: int,
token: str,
) -> "Optional[RefreshToken]":
"""get_token.
:param session:
:type session: AsyncSession
:param user_id:
:type user_id: int
:param token:
:type token: str
:rtype: "Optional[RefreshToken]"
"""
query = select(cls).where(cls.user_id == user_id).where(cls.token == token)
return await session.scalar(query)
@classmethod
async def create(
cls, session: AsyncSession, user_id: int, token: str
) -> "RefreshToken":
"""create.
:param session:
:type session: AsyncSession
:param user_id:
:type user_id: int
:param token:
:type token: str
:rtype: "RefreshToken"
"""
db_token = cls(
user_id=user_id,
token=token,
)
session.add(db_token)
try:
await session.flush()
except Exception:
raise AssertionError("invalid token")
return db_token
async def delete(self, session: AsyncSession) -> None:
"""delete.
:param session:
:type session: AsyncSession
:rtype: None
"""
await session.delete(self)
await session.flush()

View file

@ -25,7 +25,7 @@ class RepositoryBase(Generic[T]):
return stmt return stmt
async def get(self, *expressions: ColumnElement) -> T: async def _get(self, *expressions: ColumnElement) -> T:
stmt = self._build_select(*expressions) stmt = self._build_select(*expressions)
result = await self.session.execute(stmt) result = await self.session.execute(stmt)
obj = result.scalar_one_or_none() obj = result.scalar_one_or_none()
@ -35,14 +35,14 @@ class RepositoryBase(Generic[T]):
return obj return obj
async def list(self, *expressions: ColumnElement, offset: int = 0, limit: int | None = DEFAULT_LIMIT) -> Sequence[T]: async def _list(self, *expressions: ColumnElement, offset: int = 0, limit: int | None = DEFAULT_LIMIT) -> Sequence[T]:
stmt = self._build_select(*expressions) stmt = self._build_select(*expressions)
if offset: if offset:
result = stmt.offset(offset) stmt = stmt.offset(offset)
if limit is not None: if limit is not None:
result = stmt.limit(limit) stmt = stmt.limit(limit)
result = await self.session.execute(stmt) result = await self.session.execute(stmt)
return result.scalars().all() return result.scalars().all()

View file

@ -1,6 +1,22 @@
from .base import RepositoryBase from typing import Sequence
from ..domain import Product
from fooder.domain import Product
from fooder.repository.expression import fuzzy_match
from .base import RepositoryBase, DEFAULT_LIMIT
class ProductRepository(RepositoryBase[Product]): class ProductRepository(RepositoryBase[Product]):
pass async def get_by_id(self, product_id: int) -> Product:
return await self._get(Product.id == product_id)
async def get_by_barcode(self, barcode: str) -> Product:
return await self._get(Product.barcode == barcode)
async def list(
self,
q: str | None = None,
offset: int = 0,
limit: int | None = DEFAULT_LIMIT,
) -> Sequence[Product]:
expressions = (fuzzy_match(Product.name, q),) if q else ()
return await self._list(*expressions, offset=offset, limit=limit)

View file

@ -1,6 +1,10 @@
from fooder.domain import User
from .base import RepositoryBase from .base import RepositoryBase
from ..domain import User
class UserRepository(RepositoryBase[User]): class UserRepository(RepositoryBase[User]):
pass async def get_by_id(self, user_id: int) -> User:
return await self._get(User.id == user_id)
async def get_by_username(self, username: str) -> User:
return await self._get(User.username == username)

View file

@ -15,19 +15,19 @@ async def test_create_returns_object_with_id(test_repo, test_model):
async def test_get_returns_existing_record(test_repo, test_model): async def test_get_returns_existing_record(test_repo, test_model):
created = await test_repo.create(test_model) created = await test_repo.create(test_model)
found = await test_repo.get(TestModel.id == created.id) found = await test_repo._get(TestModel.id == created.id)
assert found is not None assert found is not None
assert found.id == created.id assert found.id == created.id
async def test_get_raises_not_found_for_missing_record(test_repo): async def test_get_raises_not_found_for_missing_record(test_repo):
with pytest.raises(NotFound): with pytest.raises(NotFound):
await test_repo.get(TestModel.id == 1) await test_repo._get(TestModel.id == 1)
async def test_get_by_field(test_repo, test_model_factory): async def test_get_by_field(test_repo, test_model_factory):
await test_repo.create(test_model_factory(1, "value")) await test_repo.create(test_model_factory(1, "value"))
found = await test_repo.get(TestModel.property == "value") found = await test_repo._get(TestModel.property == "value")
assert found is not None assert found is not None
assert found.id == 1 assert found.id == 1
@ -38,7 +38,7 @@ async def test_get_by_field(test_repo, test_model_factory):
async def test_list_returns_all_matching(test_repo, test_model_factory): async def test_list_returns_all_matching(test_repo, test_model_factory):
await test_repo.create(test_model_factory(1)) await test_repo.create(test_model_factory(1))
await test_repo.create(test_model_factory(2)) await test_repo.create(test_model_factory(2))
results = await test_repo.list() results = await test_repo._list()
modelnames = {u.id for u in results} modelnames = {u.id for u in results}
assert {1, 2}.issubset(modelnames) assert {1, 2}.issubset(modelnames)
@ -46,14 +46,14 @@ async def test_list_returns_all_matching(test_repo, test_model_factory):
async def test_list_with_filter(test_repo, test_model_factory): async def test_list_with_filter(test_repo, test_model_factory):
await test_repo.create(test_model_factory(1, "value")) await test_repo.create(test_model_factory(1, "value"))
await test_repo.create(test_model_factory(2, "value2")) await test_repo.create(test_model_factory(2, "value2"))
results = await test_repo.list(TestModel.property == "value") results = await test_repo._list(TestModel.property == "value")
assert len(results) == 1 assert len(results) == 1
assert results[0].id == 1 assert results[0].id == 1
async def test_list_returns_empty_when_no_match(test_repo, test_model_factory): async def test_list_returns_empty_when_no_match(test_repo, test_model_factory):
await test_repo.create(test_model_factory(1, "value")) await test_repo.create(test_model_factory(1, "value"))
results = await test_repo.list(TestModel.property == "value2") results = await test_repo._list(TestModel.property == "value2")
assert results == [] assert results == []
@ -64,4 +64,4 @@ async def test_delete_removes_record(test_repo, test_model):
model = await test_repo.create(test_model) model = await test_repo.create(test_model)
await test_repo.delete(TestModel.id == model.id) await test_repo.delete(TestModel.id == model.id)
with pytest.raises(NotFound): with pytest.raises(NotFound):
await test_repo.get(TestModel.id == model.id) await test_repo._get(TestModel.id == model.id)

View file

@ -1,7 +1,5 @@
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from fooder.repository.expression import fuzzy_match
from fooder.domain import Product
from fooder.model.product import ProductModel, ProductCreateModel, ProductUpdateModel from fooder.model.product import ProductModel, ProductCreateModel, ProductUpdateModel
from fooder.controller.product import ProductController from fooder.controller.product import ProductController
from fooder.context import Context, AuthContextDependency from fooder.context import Context, AuthContextDependency
@ -16,12 +14,7 @@ async def list_products(
offset: int = 0, offset: int = 0,
q: str | None = None, q: str | None = None,
): ):
expressions = [] return await ctx.repo.product.list(q=q, limit=limit, offset=offset)
if q:
expressions.append(
fuzzy_match(Product.name, q)
)
return await ctx.repo.product.list(*expressions, limit=limit, offset=offset)
@router.patch("/{product_id}", response_model=ProductModel) @router.patch("/{product_id}", response_model=ProductModel)
@ -30,7 +23,7 @@ async def update_product(
data: ProductUpdateModel, data: ProductUpdateModel,
ctx: Context = Depends(AuthContextDependency()), ctx: Context = Depends(AuthContextDependency()),
): ):
obj = await ctx.repo.product.get(Product.id == product_id) obj = await ctx.repo.product.get_by_id(product_id)
async with ctx.repo.transaction(): async with ctx.repo.transaction():
await ProductController(ctx, obj).update(data) await ProductController(ctx, obj).update(data)
return obj return obj