[repo] encapsulate sqlalchemy completely
This commit is contained in:
parent
6129712efe
commit
446850ee12
9 changed files with 40 additions and 99 deletions
|
|
@ -65,7 +65,7 @@ class AuthContextDependency:
|
|||
ctx = Context(repo=Repository(session))
|
||||
user_id = AccessToken.decode(token).sub
|
||||
try:
|
||||
user = await ctx.repo.user.get(User.id == user_id)
|
||||
user = await ctx.repo.user.get_by_id(user_id)
|
||||
except NotFound:
|
||||
raise Unauthorized()
|
||||
ctx.set_user(user)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ class UserController(ModelController[User]):
|
|||
password: str,
|
||||
) -> "UserController":
|
||||
try:
|
||||
obj = await ctx.repo.user.get(User.username == username)
|
||||
obj = await ctx.repo.user.get_by_username(username)
|
||||
except NotFound:
|
||||
raise Unauthorized()
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,5 @@ from .entry import Entry # noqa
|
|||
from .meal import Meal # noqa
|
||||
from .product import Product # noqa
|
||||
from .user import User # noqa
|
||||
from .token import RefreshToken # noqa
|
||||
from .preset import Preset # noqa
|
||||
from .preset_entry import PresetEntry # noqa
|
||||
|
|
|
|||
|
|
@ -1,71 +0,0 @@
|
|||
from typing import Optional
|
||||
|
||||
from sqlalchemy import ForeignKey, Integer, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import Base, CommonMixin
|
||||
|
||||
|
||||
class RefreshToken(Base, CommonMixin):
|
||||
"""Diary represents user diary for given day"""
|
||||
|
||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
|
||||
token: Mapped[str]
|
||||
|
||||
@classmethod
|
||||
async def get_token(
|
||||
cls,
|
||||
session: AsyncSession,
|
||||
user_id: int,
|
||||
token: str,
|
||||
) -> "Optional[RefreshToken]":
|
||||
"""get_token.
|
||||
|
||||
:param session:
|
||||
:type session: AsyncSession
|
||||
:param user_id:
|
||||
:type user_id: int
|
||||
:param token:
|
||||
:type token: str
|
||||
:rtype: "Optional[RefreshToken]"
|
||||
"""
|
||||
query = select(cls).where(cls.user_id == user_id).where(cls.token == token)
|
||||
return await session.scalar(query)
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls, session: AsyncSession, user_id: int, token: str
|
||||
) -> "RefreshToken":
|
||||
"""create.
|
||||
|
||||
:param session:
|
||||
:type session: AsyncSession
|
||||
:param user_id:
|
||||
:type user_id: int
|
||||
:param token:
|
||||
:type token: str
|
||||
:rtype: "RefreshToken"
|
||||
"""
|
||||
db_token = cls(
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
)
|
||||
session.add(db_token)
|
||||
|
||||
try:
|
||||
await session.flush()
|
||||
except Exception:
|
||||
raise AssertionError("invalid token")
|
||||
|
||||
return db_token
|
||||
|
||||
async def delete(self, session: AsyncSession) -> None:
|
||||
"""delete.
|
||||
|
||||
:param session:
|
||||
:type session: AsyncSession
|
||||
:rtype: None
|
||||
"""
|
||||
await session.delete(self)
|
||||
await session.flush()
|
||||
|
|
@ -25,7 +25,7 @@ class RepositoryBase(Generic[T]):
|
|||
|
||||
return stmt
|
||||
|
||||
async def get(self, *expressions: ColumnElement) -> T:
|
||||
async def _get(self, *expressions: ColumnElement) -> T:
|
||||
stmt = self._build_select(*expressions)
|
||||
result = await self.session.execute(stmt)
|
||||
obj = result.scalar_one_or_none()
|
||||
|
|
@ -35,14 +35,14 @@ class RepositoryBase(Generic[T]):
|
|||
|
||||
return obj
|
||||
|
||||
async def list(self, *expressions: ColumnElement, offset: int = 0, limit: int | None = DEFAULT_LIMIT) -> Sequence[T]:
|
||||
async def _list(self, *expressions: ColumnElement, offset: int = 0, limit: int | None = DEFAULT_LIMIT) -> Sequence[T]:
|
||||
stmt = self._build_select(*expressions)
|
||||
|
||||
if offset:
|
||||
result = stmt.offset(offset)
|
||||
stmt = stmt.offset(offset)
|
||||
|
||||
if limit is not None:
|
||||
result = stmt.limit(limit)
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,22 @@
|
|||
from .base import RepositoryBase
|
||||
from ..domain import Product
|
||||
from typing import Sequence
|
||||
|
||||
from fooder.domain import Product
|
||||
from fooder.repository.expression import fuzzy_match
|
||||
from .base import RepositoryBase, DEFAULT_LIMIT
|
||||
|
||||
|
||||
class ProductRepository(RepositoryBase[Product]):
|
||||
pass
|
||||
async def get_by_id(self, product_id: int) -> Product:
|
||||
return await self._get(Product.id == product_id)
|
||||
|
||||
async def get_by_barcode(self, barcode: str) -> Product:
|
||||
return await self._get(Product.barcode == barcode)
|
||||
|
||||
async def list(
|
||||
self,
|
||||
q: str | None = None,
|
||||
offset: int = 0,
|
||||
limit: int | None = DEFAULT_LIMIT,
|
||||
) -> Sequence[Product]:
|
||||
expressions = (fuzzy_match(Product.name, q),) if q else ()
|
||||
return await self._list(*expressions, offset=offset, limit=limit)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,10 @@
|
|||
from fooder.domain import User
|
||||
from .base import RepositoryBase
|
||||
from ..domain import User
|
||||
|
||||
|
||||
class UserRepository(RepositoryBase[User]):
|
||||
pass
|
||||
async def get_by_id(self, user_id: int) -> User:
|
||||
return await self._get(User.id == user_id)
|
||||
|
||||
async def get_by_username(self, username: str) -> User:
|
||||
return await self._get(User.username == username)
|
||||
|
|
|
|||
|
|
@ -15,19 +15,19 @@ async def test_create_returns_object_with_id(test_repo, test_model):
|
|||
|
||||
async def test_get_returns_existing_record(test_repo, test_model):
|
||||
created = await test_repo.create(test_model)
|
||||
found = await test_repo.get(TestModel.id == created.id)
|
||||
found = await test_repo._get(TestModel.id == created.id)
|
||||
assert found is not None
|
||||
assert found.id == created.id
|
||||
|
||||
|
||||
async def test_get_raises_not_found_for_missing_record(test_repo):
|
||||
with pytest.raises(NotFound):
|
||||
await test_repo.get(TestModel.id == 1)
|
||||
await test_repo._get(TestModel.id == 1)
|
||||
|
||||
|
||||
async def test_get_by_field(test_repo, test_model_factory):
|
||||
await test_repo.create(test_model_factory(1, "value"))
|
||||
found = await test_repo.get(TestModel.property == "value")
|
||||
found = await test_repo._get(TestModel.property == "value")
|
||||
assert found is not None
|
||||
assert found.id == 1
|
||||
|
||||
|
|
@ -38,7 +38,7 @@ async def test_get_by_field(test_repo, test_model_factory):
|
|||
async def test_list_returns_all_matching(test_repo, test_model_factory):
|
||||
await test_repo.create(test_model_factory(1))
|
||||
await test_repo.create(test_model_factory(2))
|
||||
results = await test_repo.list()
|
||||
results = await test_repo._list()
|
||||
modelnames = {u.id for u in results}
|
||||
assert {1, 2}.issubset(modelnames)
|
||||
|
||||
|
|
@ -46,14 +46,14 @@ async def test_list_returns_all_matching(test_repo, test_model_factory):
|
|||
async def test_list_with_filter(test_repo, test_model_factory):
|
||||
await test_repo.create(test_model_factory(1, "value"))
|
||||
await test_repo.create(test_model_factory(2, "value2"))
|
||||
results = await test_repo.list(TestModel.property == "value")
|
||||
results = await test_repo._list(TestModel.property == "value")
|
||||
assert len(results) == 1
|
||||
assert results[0].id == 1
|
||||
|
||||
|
||||
async def test_list_returns_empty_when_no_match(test_repo, test_model_factory):
|
||||
await test_repo.create(test_model_factory(1, "value"))
|
||||
results = await test_repo.list(TestModel.property == "value2")
|
||||
results = await test_repo._list(TestModel.property == "value2")
|
||||
assert results == []
|
||||
|
||||
|
||||
|
|
@ -64,4 +64,4 @@ async def test_delete_removes_record(test_repo, test_model):
|
|||
model = await test_repo.create(test_model)
|
||||
await test_repo.delete(TestModel.id == model.id)
|
||||
with pytest.raises(NotFound):
|
||||
await test_repo.get(TestModel.id == model.id)
|
||||
await test_repo._get(TestModel.id == model.id)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
from fastapi import APIRouter, Depends
|
||||
|
||||
from fooder.repository.expression import fuzzy_match
|
||||
from fooder.domain import Product
|
||||
from fooder.model.product import ProductModel, ProductCreateModel, ProductUpdateModel
|
||||
from fooder.controller.product import ProductController
|
||||
from fooder.context import Context, AuthContextDependency
|
||||
|
|
@ -16,12 +14,7 @@ async def list_products(
|
|||
offset: int = 0,
|
||||
q: str | None = None,
|
||||
):
|
||||
expressions = []
|
||||
if q:
|
||||
expressions.append(
|
||||
fuzzy_match(Product.name, q)
|
||||
)
|
||||
return await ctx.repo.product.list(*expressions, limit=limit, offset=offset)
|
||||
return await ctx.repo.product.list(q=q, limit=limit, offset=offset)
|
||||
|
||||
|
||||
@router.patch("/{product_id}", response_model=ProductModel)
|
||||
|
|
@ -30,7 +23,7 @@ async def update_product(
|
|||
data: ProductUpdateModel,
|
||||
ctx: Context = Depends(AuthContextDependency()),
|
||||
):
|
||||
obj = await ctx.repo.product.get(Product.id == product_id)
|
||||
obj = await ctx.repo.product.get_by_id(product_id)
|
||||
async with ctx.repo.transaction():
|
||||
await ProductController(ctx, obj).update(data)
|
||||
return obj
|
||||
|
|
|
|||
Loading…
Reference in a new issue