[domain] final touches, adding optimistic and pesimistic locks

This commit is contained in:
Piotr Domański 2026-04-07 16:39:14 +02:00
parent 7d2df880c7
commit 56270beeaf
4 changed files with 40 additions and 1 deletions

View file

@ -17,9 +17,12 @@ class CommonMixin:
return cls.__name__.lower() # type: ignore return cls.__name__.lower() # type: ignore
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
version: Mapped[int] = mapped_column(default=0)
created_at: Mapped[datetime] = mapped_column(DateTime, default=utc_now) created_at: Mapped[datetime] = mapped_column(DateTime, default=utc_now)
last_changed: Mapped[datetime] = mapped_column(DateTime, default=utc_now, onupdate=utc_now) last_changed: Mapped[datetime] = mapped_column(DateTime, default=utc_now, onupdate=utc_now)
__mapper_args__ = {"version_id_col": version}
class SoftDeleteMixin: class SoftDeleteMixin:
deleted_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) deleted_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)

View file

@ -2,6 +2,7 @@ from typing import TypeVar, Generic, Type, Sequence
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete as sa_delete, update as sa_update, ColumnElement from sqlalchemy import select, delete as sa_delete, update as sa_update, ColumnElement
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.exc import StaleDataError
from sqlalchemy.sql import Select from sqlalchemy.sql import Select
from fooder.domain import Base from fooder.domain import Base
from fooder.exc import Conflict, NotFound from fooder.exc import Conflict, NotFound
@ -39,6 +40,16 @@ class RepositoryBase(Generic[T]):
return obj return obj
async def _get_for_update(self, *expressions: ColumnElement) -> T:
stmt = self._build_select(*expressions).with_for_update()
result = await self.session.execute(stmt)
obj = result.scalar_one_or_none()
if obj is None:
raise NotFound()
return obj
async def _list(self, *expressions: ColumnElement, offset: int = 0, limit: int | None = DEFAULT_LIMIT) -> Sequence[T]: async def _list(self, *expressions: ColumnElement, offset: int = 0, limit: int | None = DEFAULT_LIMIT) -> Sequence[T]:
stmt = self._build_select(*expressions) stmt = self._build_select(*expressions)
@ -65,6 +76,8 @@ class RepositoryBase(Generic[T]):
await self.session.flush() await self.session.flush()
except IntegrityError: except IntegrityError:
raise Conflict() raise Conflict()
except StaleDataError:
raise Conflict()
return obj return obj
async def _delete(self, *expressions: ColumnElement): async def _delete(self, *expressions: ColumnElement):

View file

@ -5,7 +5,8 @@ from sqlalchemy.exc import IntegrityError
from fooder.repository.user import UserRepository from fooder.repository.user import UserRepository
from fooder.repository.product import ProductRepository from fooder.repository.product import ProductRepository
from fooder.domain import User, Product from fooder.repository.user_product_usage import UserProductUsageRepository
from fooder.domain import User, Product, UserProductUsage
from fooder.exc import Conflict from fooder.exc import Conflict
@ -14,6 +15,7 @@ class Repository:
self.session = session self.session = session
self.user = UserRepository(User, session) self.user = UserRepository(User, session)
self.product = ProductRepository(Product, session) self.product = ProductRepository(Product, session)
self.user_product_usage = UserProductUsageRepository(UserProductUsage, session)
async def commit(self) -> None: async def commit(self) -> None:
try: try:

View file

@ -0,0 +1,21 @@
from sqlalchemy import update as sa_update
from fooder.domain.user_product_usage import UserProductUsage
from fooder.repository.base import RepositoryBase
class UserProductUsageRepository(RepositoryBase[UserProductUsage]):
async def increment(self, user_id: int, product_id: int, count: int = 1) -> None:
stmt = (
sa_update(UserProductUsage)
.where(
UserProductUsage.user_id == user_id,
UserProductUsage.product_id == product_id,
)
.values(count=UserProductUsage.count + count)
)
result = await self.session.execute(stmt)
if result.rowcount == 0:
obj = UserProductUsage(user_id=user_id, product_id=product_id, count=count)
await self.create(obj)