[repo] encapsulate sqlalchemy completely
This commit is contained in:
parent
6129712efe
commit
446850ee12
9 changed files with 40 additions and 99 deletions
|
|
@ -65,7 +65,7 @@ class AuthContextDependency:
|
||||||
ctx = Context(repo=Repository(session))
|
ctx = Context(repo=Repository(session))
|
||||||
user_id = AccessToken.decode(token).sub
|
user_id = AccessToken.decode(token).sub
|
||||||
try:
|
try:
|
||||||
user = await ctx.repo.user.get(User.id == user_id)
|
user = await ctx.repo.user.get_by_id(user_id)
|
||||||
except NotFound:
|
except NotFound:
|
||||||
raise Unauthorized()
|
raise Unauthorized()
|
||||||
ctx.set_user(user)
|
ctx.set_user(user)
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ class UserController(ModelController[User]):
|
||||||
password: str,
|
password: str,
|
||||||
) -> "UserController":
|
) -> "UserController":
|
||||||
try:
|
try:
|
||||||
obj = await ctx.repo.user.get(User.username == username)
|
obj = await ctx.repo.user.get_by_username(username)
|
||||||
except NotFound:
|
except NotFound:
|
||||||
raise Unauthorized()
|
raise Unauthorized()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,5 @@ from .entry import Entry # noqa
|
||||||
from .meal import Meal # noqa
|
from .meal import Meal # noqa
|
||||||
from .product import Product # noqa
|
from .product import Product # noqa
|
||||||
from .user import User # noqa
|
from .user import User # noqa
|
||||||
from .token import RefreshToken # noqa
|
|
||||||
from .preset import Preset # noqa
|
from .preset import Preset # noqa
|
||||||
from .preset_entry import PresetEntry # noqa
|
from .preset_entry import PresetEntry # noqa
|
||||||
|
|
|
||||||
|
|
@ -1,71 +0,0 @@
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from sqlalchemy import ForeignKey, Integer, select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
|
||||||
|
|
||||||
from .base import Base, CommonMixin
|
|
||||||
|
|
||||||
|
|
||||||
class RefreshToken(Base, CommonMixin):
|
|
||||||
"""Diary represents user diary for given day"""
|
|
||||||
|
|
||||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
|
|
||||||
token: Mapped[str]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get_token(
|
|
||||||
cls,
|
|
||||||
session: AsyncSession,
|
|
||||||
user_id: int,
|
|
||||||
token: str,
|
|
||||||
) -> "Optional[RefreshToken]":
|
|
||||||
"""get_token.
|
|
||||||
|
|
||||||
:param session:
|
|
||||||
:type session: AsyncSession
|
|
||||||
:param user_id:
|
|
||||||
:type user_id: int
|
|
||||||
:param token:
|
|
||||||
:type token: str
|
|
||||||
:rtype: "Optional[RefreshToken]"
|
|
||||||
"""
|
|
||||||
query = select(cls).where(cls.user_id == user_id).where(cls.token == token)
|
|
||||||
return await session.scalar(query)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def create(
|
|
||||||
cls, session: AsyncSession, user_id: int, token: str
|
|
||||||
) -> "RefreshToken":
|
|
||||||
"""create.
|
|
||||||
|
|
||||||
:param session:
|
|
||||||
:type session: AsyncSession
|
|
||||||
:param user_id:
|
|
||||||
:type user_id: int
|
|
||||||
:param token:
|
|
||||||
:type token: str
|
|
||||||
:rtype: "RefreshToken"
|
|
||||||
"""
|
|
||||||
db_token = cls(
|
|
||||||
user_id=user_id,
|
|
||||||
token=token,
|
|
||||||
)
|
|
||||||
session.add(db_token)
|
|
||||||
|
|
||||||
try:
|
|
||||||
await session.flush()
|
|
||||||
except Exception:
|
|
||||||
raise AssertionError("invalid token")
|
|
||||||
|
|
||||||
return db_token
|
|
||||||
|
|
||||||
async def delete(self, session: AsyncSession) -> None:
|
|
||||||
"""delete.
|
|
||||||
|
|
||||||
:param session:
|
|
||||||
:type session: AsyncSession
|
|
||||||
:rtype: None
|
|
||||||
"""
|
|
||||||
await session.delete(self)
|
|
||||||
await session.flush()
|
|
||||||
|
|
@ -25,7 +25,7 @@ class RepositoryBase(Generic[T]):
|
||||||
|
|
||||||
return stmt
|
return stmt
|
||||||
|
|
||||||
async def get(self, *expressions: ColumnElement) -> T:
|
async def _get(self, *expressions: ColumnElement) -> T:
|
||||||
stmt = self._build_select(*expressions)
|
stmt = self._build_select(*expressions)
|
||||||
result = await self.session.execute(stmt)
|
result = await self.session.execute(stmt)
|
||||||
obj = result.scalar_one_or_none()
|
obj = result.scalar_one_or_none()
|
||||||
|
|
@ -35,14 +35,14 @@ class RepositoryBase(Generic[T]):
|
||||||
|
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
async def list(self, *expressions: ColumnElement, offset: int = 0, limit: int | None = DEFAULT_LIMIT) -> Sequence[T]:
|
async def _list(self, *expressions: ColumnElement, offset: int = 0, limit: int | None = DEFAULT_LIMIT) -> Sequence[T]:
|
||||||
stmt = self._build_select(*expressions)
|
stmt = self._build_select(*expressions)
|
||||||
|
|
||||||
if offset:
|
if offset:
|
||||||
result = stmt.offset(offset)
|
stmt = stmt.offset(offset)
|
||||||
|
|
||||||
if limit is not None:
|
if limit is not None:
|
||||||
result = stmt.limit(limit)
|
stmt = stmt.limit(limit)
|
||||||
|
|
||||||
result = await self.session.execute(stmt)
|
result = await self.session.execute(stmt)
|
||||||
return result.scalars().all()
|
return result.scalars().all()
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,22 @@
|
||||||
from .base import RepositoryBase
|
from typing import Sequence
|
||||||
from ..domain import Product
|
|
||||||
|
from fooder.domain import Product
|
||||||
|
from fooder.repository.expression import fuzzy_match
|
||||||
|
from .base import RepositoryBase, DEFAULT_LIMIT
|
||||||
|
|
||||||
|
|
||||||
class ProductRepository(RepositoryBase[Product]):
|
class ProductRepository(RepositoryBase[Product]):
|
||||||
pass
|
async def get_by_id(self, product_id: int) -> Product:
|
||||||
|
return await self._get(Product.id == product_id)
|
||||||
|
|
||||||
|
async def get_by_barcode(self, barcode: str) -> Product:
|
||||||
|
return await self._get(Product.barcode == barcode)
|
||||||
|
|
||||||
|
async def list(
|
||||||
|
self,
|
||||||
|
q: str | None = None,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int | None = DEFAULT_LIMIT,
|
||||||
|
) -> Sequence[Product]:
|
||||||
|
expressions = (fuzzy_match(Product.name, q),) if q else ()
|
||||||
|
return await self._list(*expressions, offset=offset, limit=limit)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,10 @@
|
||||||
|
from fooder.domain import User
|
||||||
from .base import RepositoryBase
|
from .base import RepositoryBase
|
||||||
from ..domain import User
|
|
||||||
|
|
||||||
|
|
||||||
class UserRepository(RepositoryBase[User]):
|
class UserRepository(RepositoryBase[User]):
|
||||||
pass
|
async def get_by_id(self, user_id: int) -> User:
|
||||||
|
return await self._get(User.id == user_id)
|
||||||
|
|
||||||
|
async def get_by_username(self, username: str) -> User:
|
||||||
|
return await self._get(User.username == username)
|
||||||
|
|
|
||||||
|
|
@ -15,19 +15,19 @@ async def test_create_returns_object_with_id(test_repo, test_model):
|
||||||
|
|
||||||
async def test_get_returns_existing_record(test_repo, test_model):
|
async def test_get_returns_existing_record(test_repo, test_model):
|
||||||
created = await test_repo.create(test_model)
|
created = await test_repo.create(test_model)
|
||||||
found = await test_repo.get(TestModel.id == created.id)
|
found = await test_repo._get(TestModel.id == created.id)
|
||||||
assert found is not None
|
assert found is not None
|
||||||
assert found.id == created.id
|
assert found.id == created.id
|
||||||
|
|
||||||
|
|
||||||
async def test_get_raises_not_found_for_missing_record(test_repo):
|
async def test_get_raises_not_found_for_missing_record(test_repo):
|
||||||
with pytest.raises(NotFound):
|
with pytest.raises(NotFound):
|
||||||
await test_repo.get(TestModel.id == 1)
|
await test_repo._get(TestModel.id == 1)
|
||||||
|
|
||||||
|
|
||||||
async def test_get_by_field(test_repo, test_model_factory):
|
async def test_get_by_field(test_repo, test_model_factory):
|
||||||
await test_repo.create(test_model_factory(1, "value"))
|
await test_repo.create(test_model_factory(1, "value"))
|
||||||
found = await test_repo.get(TestModel.property == "value")
|
found = await test_repo._get(TestModel.property == "value")
|
||||||
assert found is not None
|
assert found is not None
|
||||||
assert found.id == 1
|
assert found.id == 1
|
||||||
|
|
||||||
|
|
@ -38,7 +38,7 @@ async def test_get_by_field(test_repo, test_model_factory):
|
||||||
async def test_list_returns_all_matching(test_repo, test_model_factory):
|
async def test_list_returns_all_matching(test_repo, test_model_factory):
|
||||||
await test_repo.create(test_model_factory(1))
|
await test_repo.create(test_model_factory(1))
|
||||||
await test_repo.create(test_model_factory(2))
|
await test_repo.create(test_model_factory(2))
|
||||||
results = await test_repo.list()
|
results = await test_repo._list()
|
||||||
modelnames = {u.id for u in results}
|
modelnames = {u.id for u in results}
|
||||||
assert {1, 2}.issubset(modelnames)
|
assert {1, 2}.issubset(modelnames)
|
||||||
|
|
||||||
|
|
@ -46,14 +46,14 @@ async def test_list_returns_all_matching(test_repo, test_model_factory):
|
||||||
async def test_list_with_filter(test_repo, test_model_factory):
|
async def test_list_with_filter(test_repo, test_model_factory):
|
||||||
await test_repo.create(test_model_factory(1, "value"))
|
await test_repo.create(test_model_factory(1, "value"))
|
||||||
await test_repo.create(test_model_factory(2, "value2"))
|
await test_repo.create(test_model_factory(2, "value2"))
|
||||||
results = await test_repo.list(TestModel.property == "value")
|
results = await test_repo._list(TestModel.property == "value")
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
assert results[0].id == 1
|
assert results[0].id == 1
|
||||||
|
|
||||||
|
|
||||||
async def test_list_returns_empty_when_no_match(test_repo, test_model_factory):
|
async def test_list_returns_empty_when_no_match(test_repo, test_model_factory):
|
||||||
await test_repo.create(test_model_factory(1, "value"))
|
await test_repo.create(test_model_factory(1, "value"))
|
||||||
results = await test_repo.list(TestModel.property == "value2")
|
results = await test_repo._list(TestModel.property == "value2")
|
||||||
assert results == []
|
assert results == []
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -64,4 +64,4 @@ async def test_delete_removes_record(test_repo, test_model):
|
||||||
model = await test_repo.create(test_model)
|
model = await test_repo.create(test_model)
|
||||||
await test_repo.delete(TestModel.id == model.id)
|
await test_repo.delete(TestModel.id == model.id)
|
||||||
with pytest.raises(NotFound):
|
with pytest.raises(NotFound):
|
||||||
await test_repo.get(TestModel.id == model.id)
|
await test_repo._get(TestModel.id == model.id)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,5 @@
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
from fooder.repository.expression import fuzzy_match
|
|
||||||
from fooder.domain import Product
|
|
||||||
from fooder.model.product import ProductModel, ProductCreateModel, ProductUpdateModel
|
from fooder.model.product import ProductModel, ProductCreateModel, ProductUpdateModel
|
||||||
from fooder.controller.product import ProductController
|
from fooder.controller.product import ProductController
|
||||||
from fooder.context import Context, AuthContextDependency
|
from fooder.context import Context, AuthContextDependency
|
||||||
|
|
@ -16,12 +14,7 @@ async def list_products(
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
q: str | None = None,
|
q: str | None = None,
|
||||||
):
|
):
|
||||||
expressions = []
|
return await ctx.repo.product.list(q=q, limit=limit, offset=offset)
|
||||||
if q:
|
|
||||||
expressions.append(
|
|
||||||
fuzzy_match(Product.name, q)
|
|
||||||
)
|
|
||||||
return await ctx.repo.product.list(*expressions, limit=limit, offset=offset)
|
|
||||||
|
|
||||||
|
|
||||||
@router.patch("/{product_id}", response_model=ProductModel)
|
@router.patch("/{product_id}", response_model=ProductModel)
|
||||||
|
|
@ -30,7 +23,7 @@ async def update_product(
|
||||||
data: ProductUpdateModel,
|
data: ProductUpdateModel,
|
||||||
ctx: Context = Depends(AuthContextDependency()),
|
ctx: Context = Depends(AuthContextDependency()),
|
||||||
):
|
):
|
||||||
obj = await ctx.repo.product.get(Product.id == product_id)
|
obj = await ctx.repo.product.get_by_id(product_id)
|
||||||
async with ctx.repo.transaction():
|
async with ctx.repo.transaction():
|
||||||
await ProductController(ctx, obj).update(data)
|
await ProductController(ctx, obj).update(data)
|
||||||
return obj
|
return obj
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue