diff --git a/fooder/repository/base.py b/fooder/repository/base.py index a68b8e4..17f431d 100644 --- a/fooder/repository/base.py +++ b/fooder/repository/base.py @@ -26,8 +26,11 @@ class RepositoryBase(Generic[T]): self.model = model self.session = session - def _build_select(self, *expressions: ColumnElement) -> Select[tuple[T]]: - stmt = select(self.model) + def _build_select( + self, *expressions: ColumnElement, stmt: Select | None = None + ) -> Select[tuple[T]]: + if stmt is None: + stmt = select(self.model) if hasattr(self.model, "deleted_at"): stmt = stmt.where(self.model.deleted_at.is_(None)) # type: ignore[attr-defined] @@ -37,8 +40,8 @@ class RepositoryBase(Generic[T]): return stmt - async def _get(self, *expressions: ColumnElement) -> T: - stmt = self._build_select(*expressions) + async def _get(self, *expressions: ColumnElement, stmt: Select | None = None) -> T: + stmt = self._build_select(*expressions, stmt=stmt) result = await self.session.execute(stmt) obj = result.scalar_one_or_none() @@ -47,8 +50,10 @@ class RepositoryBase(Generic[T]): return obj - async def _get_for_update(self, *expressions: ColumnElement) -> T: - stmt = self._build_select(*expressions).with_for_update() + async def _get_for_update( + self, *expressions: ColumnElement, stmt: Select | None = None + ) -> T: + stmt = self._build_select(*expressions, stmt=stmt).with_for_update() result = await self.session.execute(stmt) obj = result.scalar_one_or_none() @@ -61,9 +66,10 @@ class RepositoryBase(Generic[T]): self, *expressions: ColumnElement, offset: int = 0, - limit: int | None = DEFAULT_LIMIT + limit: int | None = DEFAULT_LIMIT, + stmt: Select | None = None, ) -> Sequence[T]: - stmt = self._build_select(*expressions) + stmt = self._build_select(*expressions, stmt=stmt) if offset: stmt = stmt.offset(offset) @@ -92,12 +98,14 @@ class RepositoryBase(Generic[T]): raise Conflict() return obj - async def _delete(self, *expressions: ColumnElement): - stmt: Update | Delete - if hasattr(self.model, "deleted_at"): - stmt = sa_update(self.model).values(deleted_at=utc_now()) - else: - stmt = sa_delete(self.model) + async def _delete( + self, *expressions: ColumnElement, stmt: Update | Delete | None = None + ): + if stmt is None: + if hasattr(self.model, "deleted_at"): + stmt = sa_update(self.model).values(deleted_at=utc_now()) + else: + stmt = sa_delete(self.model) if expressions: stmt = stmt.where(*expressions) diff --git a/fooder/repository/product.py b/fooder/repository/product.py index 078a3f4..124a155 100644 --- a/fooder/repository/product.py +++ b/fooder/repository/product.py @@ -1,6 +1,8 @@ from typing import Sequence -from fooder.domain import Product +from sqlalchemy import desc, func, select + +from fooder.domain import Product, UserProductUsage from fooder.repository.expression import fuzzy_match from fooder.repository.base import RepositoryBase, DEFAULT_LIMIT @@ -20,3 +22,25 @@ class ProductRepository(RepositoryBase[Product]): ) -> Sequence[Product]: expressions = (fuzzy_match(Product.name, q),) if q else () return await self._list(*expressions, offset=offset, limit=limit) + + async def list_for_user( + self, + user_id: int, + q: str | None = None, + offset: int = 0, + limit: int | None = DEFAULT_LIMIT, + ): + usage_label = "usage_count" + usage_count = func.coalesce(UserProductUsage.count, 0).label(usage_label) + expressions = (fuzzy_match(Product.name, q),) if q else () + + stmt = ( + select(Product, usage_count) + .outerjoin( + UserProductUsage, + (Product.id == UserProductUsage.product_id) + & (UserProductUsage.user_id == user_id), + ) + .order_by(desc(usage_label), Product.name) + ) + return await self._list(*expressions, offset=offset, limit=limit, stmt=stmt) diff --git a/fooder/test/utils/test_password_helper.py b/fooder/test/utils/test_password_helper.py index 59bac85..0f4d15b 100644 --- a/fooder/test/utils/test_password_helper.py +++ b/fooder/test/utils/test_password_helper.py @@ -19,7 +19,6 @@ def test_wrong_password_doesnt_verify(faker): def test_wrong_hash_breaks(faker): password = faker.password() - hash = password_helper.hash(password) - with pytest.raises(Exception, match="malformed"): - assert not password_helper.verify(password, hash + password) + with pytest.raises(Exception, match="hash"): + assert not password_helper.verify(password, "invalid_hash") diff --git a/fooder/view/product.py b/fooder/view/product.py index 0b4e2cb..b379b78 100644 --- a/fooder/view/product.py +++ b/fooder/view/product.py @@ -17,7 +17,9 @@ async def list_products( offset: int = 0, q: str | None = None, ): - return await ctx.repo.product.list(q=q, limit=limit, offset=offset) + return await ctx.repo.product.list_for_user( + q=q, limit=limit, offset=offset, user_id=ctx.user.id + ) @router.patch("/{product_id}", response_model=ProductModel)