From 446850ee12ee29910c8c1aa338fd2fa144f5a3b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20Doma=C5=84ski?= Date: Tue, 7 Apr 2026 14:51:39 +0200 Subject: [PATCH] [repo] encapsulate sqlalchemy completely --- fooder/context.py | 2 +- fooder/controller/user.py | 2 +- fooder/domain/__init__.py | 1 - fooder/domain/token.py | 71 ----------------------------- fooder/repository/base.py | 8 ++-- fooder/repository/product.py | 22 +++++++-- fooder/repository/user.py | 8 +++- fooder/test/repository/test_base.py | 14 +++--- fooder/view/product.py | 11 +---- 9 files changed, 40 insertions(+), 99 deletions(-) delete mode 100644 fooder/domain/token.py diff --git a/fooder/context.py b/fooder/context.py index 83530fd..779bbf1 100644 --- a/fooder/context.py +++ b/fooder/context.py @@ -65,7 +65,7 @@ class AuthContextDependency: ctx = Context(repo=Repository(session)) user_id = AccessToken.decode(token).sub try: - user = await ctx.repo.user.get(User.id == user_id) + user = await ctx.repo.user.get_by_id(user_id) except NotFound: raise Unauthorized() ctx.set_user(user) diff --git a/fooder/controller/user.py b/fooder/controller/user.py index 26764c7..46333dd 100644 --- a/fooder/controller/user.py +++ b/fooder/controller/user.py @@ -13,7 +13,7 @@ class UserController(ModelController[User]): password: str, ) -> "UserController": try: - obj = await ctx.repo.user.get(User.username == username) + obj = await ctx.repo.user.get_by_username(username) except NotFound: raise Unauthorized() diff --git a/fooder/domain/__init__.py b/fooder/domain/__init__.py index d237f93..be490a2 100644 --- a/fooder/domain/__init__.py +++ b/fooder/domain/__init__.py @@ -4,6 +4,5 @@ from .entry import Entry # noqa from .meal import Meal # noqa from .product import Product # noqa from .user import User # noqa -from .token import RefreshToken # noqa from .preset import Preset # noqa from .preset_entry import PresetEntry # noqa diff --git a/fooder/domain/token.py b/fooder/domain/token.py deleted file mode 100644 index cf73ed5..0000000 --- a/fooder/domain/token.py +++ /dev/null @@ -1,71 +0,0 @@ -from typing import Optional - -from sqlalchemy import ForeignKey, Integer, select -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Mapped, mapped_column - -from .base import Base, CommonMixin - - -class RefreshToken(Base, CommonMixin): - """Diary represents user diary for given day""" - - user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id")) - token: Mapped[str] - - @classmethod - async def get_token( - cls, - session: AsyncSession, - user_id: int, - token: str, - ) -> "Optional[RefreshToken]": - """get_token. - - :param session: - :type session: AsyncSession - :param user_id: - :type user_id: int - :param token: - :type token: str - :rtype: "Optional[RefreshToken]" - """ - query = select(cls).where(cls.user_id == user_id).where(cls.token == token) - return await session.scalar(query) - - @classmethod - async def create( - cls, session: AsyncSession, user_id: int, token: str - ) -> "RefreshToken": - """create. - - :param session: - :type session: AsyncSession - :param user_id: - :type user_id: int - :param token: - :type token: str - :rtype: "RefreshToken" - """ - db_token = cls( - user_id=user_id, - token=token, - ) - session.add(db_token) - - try: - await session.flush() - except Exception: - raise AssertionError("invalid token") - - return db_token - - async def delete(self, session: AsyncSession) -> None: - """delete. - - :param session: - :type session: AsyncSession - :rtype: None - """ - await session.delete(self) - await session.flush() diff --git a/fooder/repository/base.py b/fooder/repository/base.py index ef6e8a0..11d4733 100644 --- a/fooder/repository/base.py +++ b/fooder/repository/base.py @@ -25,7 +25,7 @@ class RepositoryBase(Generic[T]): return stmt - async def get(self, *expressions: ColumnElement) -> T: + async def _get(self, *expressions: ColumnElement) -> T: stmt = self._build_select(*expressions) result = await self.session.execute(stmt) obj = result.scalar_one_or_none() @@ -35,14 +35,14 @@ class RepositoryBase(Generic[T]): return obj - async def list(self, *expressions: ColumnElement, offset: int = 0, limit: int | None = DEFAULT_LIMIT) -> Sequence[T]: + async def _list(self, *expressions: ColumnElement, offset: int = 0, limit: int | None = DEFAULT_LIMIT) -> Sequence[T]: stmt = self._build_select(*expressions) if offset: - result = stmt.offset(offset) + stmt = stmt.offset(offset) if limit is not None: - result = stmt.limit(limit) + stmt = stmt.limit(limit) result = await self.session.execute(stmt) return result.scalars().all() diff --git a/fooder/repository/product.py b/fooder/repository/product.py index e0a1e7b..3a13f6a 100644 --- a/fooder/repository/product.py +++ b/fooder/repository/product.py @@ -1,6 +1,22 @@ -from .base import RepositoryBase -from ..domain import Product +from typing import Sequence + +from fooder.domain import Product +from fooder.repository.expression import fuzzy_match +from .base import RepositoryBase, DEFAULT_LIMIT class ProductRepository(RepositoryBase[Product]): - pass + async def get_by_id(self, product_id: int) -> Product: + return await self._get(Product.id == product_id) + + async def get_by_barcode(self, barcode: str) -> Product: + return await self._get(Product.barcode == barcode) + + async def list( + self, + q: str | None = None, + offset: int = 0, + limit: int | None = DEFAULT_LIMIT, + ) -> Sequence[Product]: + expressions = (fuzzy_match(Product.name, q),) if q else () + return await self._list(*expressions, offset=offset, limit=limit) diff --git a/fooder/repository/user.py b/fooder/repository/user.py index ca35abe..4cba5a0 100644 --- a/fooder/repository/user.py +++ b/fooder/repository/user.py @@ -1,6 +1,10 @@ +from fooder.domain import User from .base import RepositoryBase -from ..domain import User class UserRepository(RepositoryBase[User]): - pass + async def get_by_id(self, user_id: int) -> User: + return await self._get(User.id == user_id) + + async def get_by_username(self, username: str) -> User: + return await self._get(User.username == username) diff --git a/fooder/test/repository/test_base.py b/fooder/test/repository/test_base.py index 6a29b67..e47c243 100644 --- a/fooder/test/repository/test_base.py +++ b/fooder/test/repository/test_base.py @@ -15,19 +15,19 @@ async def test_create_returns_object_with_id(test_repo, test_model): async def test_get_returns_existing_record(test_repo, test_model): created = await test_repo.create(test_model) - found = await test_repo.get(TestModel.id == created.id) + found = await test_repo._get(TestModel.id == created.id) assert found is not None assert found.id == created.id async def test_get_raises_not_found_for_missing_record(test_repo): with pytest.raises(NotFound): - await test_repo.get(TestModel.id == 1) + await test_repo._get(TestModel.id == 1) async def test_get_by_field(test_repo, test_model_factory): await test_repo.create(test_model_factory(1, "value")) - found = await test_repo.get(TestModel.property == "value") + found = await test_repo._get(TestModel.property == "value") assert found is not None assert found.id == 1 @@ -38,7 +38,7 @@ async def test_get_by_field(test_repo, test_model_factory): async def test_list_returns_all_matching(test_repo, test_model_factory): await test_repo.create(test_model_factory(1)) await test_repo.create(test_model_factory(2)) - results = await test_repo.list() + results = await test_repo._list() modelnames = {u.id for u in results} assert {1, 2}.issubset(modelnames) @@ -46,14 +46,14 @@ async def test_list_returns_all_matching(test_repo, test_model_factory): async def test_list_with_filter(test_repo, test_model_factory): await test_repo.create(test_model_factory(1, "value")) await test_repo.create(test_model_factory(2, "value2")) - results = await test_repo.list(TestModel.property == "value") + results = await test_repo._list(TestModel.property == "value") assert len(results) == 1 assert results[0].id == 1 async def test_list_returns_empty_when_no_match(test_repo, test_model_factory): await test_repo.create(test_model_factory(1, "value")) - results = await test_repo.list(TestModel.property == "value2") + results = await test_repo._list(TestModel.property == "value2") assert results == [] @@ -64,4 +64,4 @@ async def test_delete_removes_record(test_repo, test_model): model = await test_repo.create(test_model) await test_repo.delete(TestModel.id == model.id) with pytest.raises(NotFound): - await test_repo.get(TestModel.id == model.id) + await test_repo._get(TestModel.id == model.id) diff --git a/fooder/view/product.py b/fooder/view/product.py index 22c45a4..e02d60e 100644 --- a/fooder/view/product.py +++ b/fooder/view/product.py @@ -1,7 +1,5 @@ from fastapi import APIRouter, Depends -from fooder.repository.expression import fuzzy_match -from fooder.domain import Product from fooder.model.product import ProductModel, ProductCreateModel, ProductUpdateModel from fooder.controller.product import ProductController from fooder.context import Context, AuthContextDependency @@ -16,12 +14,7 @@ async def list_products( offset: int = 0, q: str | None = None, ): - expressions = [] - if q: - expressions.append( - fuzzy_match(Product.name, q) - ) - return await ctx.repo.product.list(*expressions, limit=limit, offset=offset) + return await ctx.repo.product.list(q=q, limit=limit, offset=offset) @router.patch("/{product_id}", response_model=ProductModel) @@ -30,7 +23,7 @@ async def update_product( data: ProductUpdateModel, ctx: Context = Depends(AuthContextDependency()), ): - obj = await ctx.repo.product.get(Product.id == product_id) + obj = await ctx.repo.product.get_by_id(product_id) async with ctx.repo.transaction(): await ProductController(ctx, obj).update(data) return obj