[product] list for user

This commit is contained in:
Piotr Domański 2026-04-07 18:42:09 +02:00
parent f2dd9bfea4
commit c26247cca6
4 changed files with 52 additions and 19 deletions

View file

@ -26,8 +26,11 @@ class RepositoryBase(Generic[T]):
self.model = model self.model = model
self.session = session self.session = session
def _build_select(self, *expressions: ColumnElement) -> Select[tuple[T]]: def _build_select(
stmt = select(self.model) self, *expressions: ColumnElement, stmt: Select | None = None
) -> Select[tuple[T]]:
if stmt is None:
stmt = select(self.model)
if hasattr(self.model, "deleted_at"): if hasattr(self.model, "deleted_at"):
stmt = stmt.where(self.model.deleted_at.is_(None)) # type: ignore[attr-defined] stmt = stmt.where(self.model.deleted_at.is_(None)) # type: ignore[attr-defined]
@ -37,8 +40,8 @@ class RepositoryBase(Generic[T]):
return stmt return stmt
async def _get(self, *expressions: ColumnElement) -> T: async def _get(self, *expressions: ColumnElement, stmt: Select | None = None) -> T:
stmt = self._build_select(*expressions) stmt = self._build_select(*expressions, stmt=stmt)
result = await self.session.execute(stmt) result = await self.session.execute(stmt)
obj = result.scalar_one_or_none() obj = result.scalar_one_or_none()
@ -47,8 +50,10 @@ class RepositoryBase(Generic[T]):
return obj return obj
async def _get_for_update(self, *expressions: ColumnElement) -> T: async def _get_for_update(
stmt = self._build_select(*expressions).with_for_update() self, *expressions: ColumnElement, stmt: Select | None = None
) -> T:
stmt = self._build_select(*expressions, stmt=stmt).with_for_update()
result = await self.session.execute(stmt) result = await self.session.execute(stmt)
obj = result.scalar_one_or_none() obj = result.scalar_one_or_none()
@ -61,9 +66,10 @@ class RepositoryBase(Generic[T]):
self, self,
*expressions: ColumnElement, *expressions: ColumnElement,
offset: int = 0, offset: int = 0,
limit: int | None = DEFAULT_LIMIT limit: int | None = DEFAULT_LIMIT,
stmt: Select | None = None,
) -> Sequence[T]: ) -> Sequence[T]:
stmt = self._build_select(*expressions) stmt = self._build_select(*expressions, stmt=stmt)
if offset: if offset:
stmt = stmt.offset(offset) stmt = stmt.offset(offset)
@ -92,12 +98,14 @@ class RepositoryBase(Generic[T]):
raise Conflict() raise Conflict()
return obj return obj
async def _delete(self, *expressions: ColumnElement): async def _delete(
stmt: Update | Delete self, *expressions: ColumnElement, stmt: Update | Delete | None = None
if hasattr(self.model, "deleted_at"): ):
stmt = sa_update(self.model).values(deleted_at=utc_now()) if stmt is None:
else: if hasattr(self.model, "deleted_at"):
stmt = sa_delete(self.model) stmt = sa_update(self.model).values(deleted_at=utc_now())
else:
stmt = sa_delete(self.model)
if expressions: if expressions:
stmt = stmt.where(*expressions) stmt = stmt.where(*expressions)

View file

@ -1,6 +1,8 @@
from typing import Sequence from typing import Sequence
from fooder.domain import Product from sqlalchemy import desc, func, select
from fooder.domain import Product, UserProductUsage
from fooder.repository.expression import fuzzy_match from fooder.repository.expression import fuzzy_match
from fooder.repository.base import RepositoryBase, DEFAULT_LIMIT from fooder.repository.base import RepositoryBase, DEFAULT_LIMIT
@ -20,3 +22,25 @@ class ProductRepository(RepositoryBase[Product]):
) -> Sequence[Product]: ) -> Sequence[Product]:
expressions = (fuzzy_match(Product.name, q),) if q else () expressions = (fuzzy_match(Product.name, q),) if q else ()
return await self._list(*expressions, offset=offset, limit=limit) return await self._list(*expressions, offset=offset, limit=limit)
async def list_for_user(
self,
user_id: int,
q: str | None = None,
offset: int = 0,
limit: int | None = DEFAULT_LIMIT,
):
usage_label = "usage_count"
usage_count = func.coalesce(UserProductUsage.count, 0).label(usage_label)
expressions = (fuzzy_match(Product.name, q),) if q else ()
stmt = (
select(Product, usage_count)
.outerjoin(
UserProductUsage,
(Product.id == UserProductUsage.product_id)
& (UserProductUsage.user_id == user_id),
)
.order_by(desc(usage_label), Product.name)
)
return await self._list(*expressions, offset=offset, limit=limit, stmt=stmt)

View file

@ -19,7 +19,6 @@ def test_wrong_password_doesnt_verify(faker):
def test_wrong_hash_breaks(faker): def test_wrong_hash_breaks(faker):
password = faker.password() password = faker.password()
hash = password_helper.hash(password)
with pytest.raises(Exception, match="malformed"): with pytest.raises(Exception, match="hash"):
assert not password_helper.verify(password, hash + password) assert not password_helper.verify(password, "invalid_hash")

View file

@ -17,7 +17,9 @@ async def list_products(
offset: int = 0, offset: int = 0,
q: str | None = None, q: str | None = None,
): ):
return await ctx.repo.product.list(q=q, limit=limit, offset=offset) return await ctx.repo.product.list_for_user(
q=q, limit=limit, offset=offset, user_id=ctx.user.id
)
@router.patch("/{product_id}", response_model=ProductModel) @router.patch("/{product_id}", response_model=ProductModel)