From 56270beeaf178b8f344014f5b3ec125e66b337d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20Doma=C5=84ski?= Date: Tue, 7 Apr 2026 16:39:14 +0200 Subject: [PATCH] [domain] final touches, adding optimistic and pesimistic locks --- fooder/domain/base.py | 3 +++ fooder/repository/base.py | 13 +++++++++++++ fooder/repository/repository.py | 4 +++- fooder/repository/user_product_usage.py | 21 +++++++++++++++++++++ 4 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 fooder/repository/user_product_usage.py diff --git a/fooder/domain/base.py b/fooder/domain/base.py index b724030..21dd73f 100644 --- a/fooder/domain/base.py +++ b/fooder/domain/base.py @@ -17,9 +17,12 @@ class CommonMixin: return cls.__name__.lower() # type: ignore id: Mapped[int] = mapped_column(primary_key=True) + version: Mapped[int] = mapped_column(default=0) created_at: Mapped[datetime] = mapped_column(DateTime, default=utc_now) last_changed: Mapped[datetime] = mapped_column(DateTime, default=utc_now, onupdate=utc_now) + __mapper_args__ = {"version_id_col": version} + class SoftDeleteMixin: deleted_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) diff --git a/fooder/repository/base.py b/fooder/repository/base.py index 101c676..47a28d4 100644 --- a/fooder/repository/base.py +++ b/fooder/repository/base.py @@ -2,6 +2,7 @@ from typing import TypeVar, Generic, Type, Sequence from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, delete as sa_delete, update as sa_update, ColumnElement from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm.exc import StaleDataError from sqlalchemy.sql import Select from fooder.domain import Base from fooder.exc import Conflict, NotFound @@ -39,6 +40,16 @@ class RepositoryBase(Generic[T]): return obj + async def _get_for_update(self, *expressions: ColumnElement) -> T: + stmt = self._build_select(*expressions).with_for_update() + result = await self.session.execute(stmt) + obj = result.scalar_one_or_none() + + if obj is None: + raise NotFound() + + return obj + async def _list(self, *expressions: ColumnElement, offset: int = 0, limit: int | None = DEFAULT_LIMIT) -> Sequence[T]: stmt = self._build_select(*expressions) @@ -65,6 +76,8 @@ class RepositoryBase(Generic[T]): await self.session.flush() except IntegrityError: raise Conflict() + except StaleDataError: + raise Conflict() return obj async def _delete(self, *expressions: ColumnElement): diff --git a/fooder/repository/repository.py b/fooder/repository/repository.py index 31df8d7..a3b5b05 100644 --- a/fooder/repository/repository.py +++ b/fooder/repository/repository.py @@ -5,7 +5,8 @@ from sqlalchemy.exc import IntegrityError from fooder.repository.user import UserRepository from fooder.repository.product import ProductRepository -from fooder.domain import User, Product +from fooder.repository.user_product_usage import UserProductUsageRepository +from fooder.domain import User, Product, UserProductUsage from fooder.exc import Conflict @@ -14,6 +15,7 @@ class Repository: self.session = session self.user = UserRepository(User, session) self.product = ProductRepository(Product, session) + self.user_product_usage = UserProductUsageRepository(UserProductUsage, session) async def commit(self) -> None: try: diff --git a/fooder/repository/user_product_usage.py b/fooder/repository/user_product_usage.py new file mode 100644 index 0000000..19bf84c --- /dev/null +++ b/fooder/repository/user_product_usage.py @@ -0,0 +1,21 @@ +from sqlalchemy import update as sa_update + +from fooder.domain.user_product_usage import UserProductUsage +from fooder.repository.base import RepositoryBase + + +class UserProductUsageRepository(RepositoryBase[UserProductUsage]): + async def increment(self, user_id: int, product_id: int, count: int = 1) -> None: + stmt = ( + sa_update(UserProductUsage) + .where( + UserProductUsage.user_id == user_id, + UserProductUsage.product_id == product_id, + ) + .values(count=UserProductUsage.count + count) + ) + result = await self.session.execute(stmt) + + if result.rowcount == 0: + obj = UserProductUsage(user_id=user_id, product_id=product_id, count=count) + await self.create(obj)