[controller/view] CQS way + remove token controller

This commit is contained in:
Piotr Domański 2026-04-07 14:34:03 +02:00
parent bd102360ad
commit 6129712efe
24 changed files with 355 additions and 165 deletions

View file

@ -7,7 +7,8 @@ from fooder.db import get_db_session
from fooder.domain import User
from fooder.repository import Repository
from fooder.utils.datetime import utc_now
from fooder.exc import Unauthorized
from fooder.exc import Unauthorized, NotFound
from fooder.utils.jwt import AccessToken
class Context:
@ -62,10 +63,10 @@ class AuthContextDependency:
session: AsyncSession = Depends(get_db_session),
) -> Context:
ctx = Context(repo=Repository(session))
from fooder.controller.token import TokenController
token_ctrl = TokenController.from_access_token(ctx, token)
user = await ctx.repo.user.get(User.id == token_ctrl.entity_id)
if user is None:
user_id = AccessToken.decode(token).sub
try:
user = await ctx.repo.user.get(User.id == user_id)
except NotFound:
raise Unauthorized()
ctx.set_user(user)
return ctx

View file

@ -1,2 +1,2 @@
from .user import UserController
from .token import TokenController
from .product import ProductController

View file

@ -0,0 +1,53 @@
from fooder.controller.base import ModelController
from fooder.domain import Product
from fooder.context import Context
from fooder.model.product import ProductCreateModel, ProductUpdateModel
from fooder.utils.calories import calculate_calories
class ProductController(ModelController[Product]):
@classmethod
async def create(
cls,
ctx: Context,
data: ProductCreateModel,
) -> "ProductController":
obj = Product(
name=data.name,
protein=data.protein,
carb=data.carb,
fat=data.fat,
fiber=data.fiber,
calories=data.resolved_calories,
barcode=data.barcode,
)
await ctx.repo.product.create(obj)
return cls(ctx, obj)
async def update(self, data: ProductUpdateModel) -> None:
if data.name is not None:
self.obj.name = data.name
if data.protein is not None:
self.obj.protein = data.protein
if data.carb is not None:
self.obj.carb = data.carb
if data.fat is not None:
self.obj.fat = data.fat
if data.fiber is not None:
self.obj.fiber = data.fiber
if data.barcode is not None:
self.obj.barcode = data.barcode
self.obj.calories = data.calories if data.calories is not None else calculate_calories(
protein=self.obj.protein,
carb=self.obj.carb,
fat=self.obj.fat,
fiber=self.obj.fiber,
)
await self.ctx.repo.product.update(self.obj)

View file

@ -1,38 +0,0 @@
from fooder.controller.base import ControllerBase
from fooder.context import Context
from fooder.utils.jwt import Token, AccessToken, RefreshToken
from typing import Type, TypeVar
from datetime import datetime
T = TypeVar("T", bound=Token)
class TokenController(ControllerBase):
def __init__(self, ctx: Context, entity_id: int) -> None:
super().__init__(ctx)
self.entity_id = entity_id
@classmethod
def from_token(cls, ctx: Context, token_str: str, token_cls: Type[T]) -> "TokenController":
token = token_cls.decode(token_str)
return cls(ctx, token.sub)
@classmethod
def from_refresh_token(cls, ctx: Context, token_str: str) -> "TokenController":
return cls.from_token(ctx, token_str, RefreshToken)
@classmethod
def from_access_token(cls, ctx: Context, token_str: str) -> "TokenController":
return cls.from_token(ctx, token_str, AccessToken)
def generate_token(self, token_cls: Type[T], now: datetime) -> T:
return token_cls(exp=token_cls.calculate_exp(now), sub=self.entity_id)
def generate_refresh_token(self, now: datetime) -> RefreshToken:
return self.generate_token(RefreshToken, now)
def generate_access_token(self, now: datetime) -> AccessToken:
return self.generate_token(AccessToken, now)
def generate_token_pair(self, now: datetime) -> tuple[AccessToken, RefreshToken]:
return (self.generate_access_token(now), self.generate_refresh_token(now))

View file

@ -1,8 +1,7 @@
from fooder.controller.base import ModelController
from fooder.controller.token import TokenController
from fooder.domain import User
from fooder.context import Context
from fooder.exc import Unauthorized
from fooder.exc import Unauthorized, NotFound
class UserController(ModelController[User]):
@ -13,12 +12,12 @@ class UserController(ModelController[User]):
username: str,
password: str,
) -> "UserController":
try:
obj = await ctx.repo.user.get(User.username == username)
except NotFound:
raise Unauthorized()
if obj is None or not obj.verify_password(password):
if not obj.verify_password(password):
raise Unauthorized()
return cls(ctx, obj)
def token_ctrl(self) -> TokenController:
return TokenController(ctx=self.ctx, entity_id=self.obj.id)

View file

@ -23,7 +23,8 @@ class DatabaseSessionManager:
),
)
self._sessionmaker = async_sessionmaker(
autocommit=False, autoflush=False, future=True, bind=self._engine
autocommit=False, autoflush=False, future=True, bind=self._engine,
expire_on_commit=False,
)
async def close(self) -> None:

View file

@ -1,4 +1,4 @@
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import Mapped, mapped_column
from fooder.domain.base import Base, CommonMixin
@ -12,16 +12,5 @@ class Product(Base, CommonMixin):
carb: Mapped[float]
fat: Mapped[float]
fiber: Mapped[float]
hard_coded_calories: Mapped[float | None]
barcode: Mapped[str | None]
@property
def calories(self) -> float:
"""calories.
:rtype: float
"""
if self.hard_coded_calories:
return self.hard_coded_calories
return self.protein * 4 + self.carb * 4 + self.fat * 9 + self.fiber * 2
calories: Mapped[float]
barcode: Mapped[str | None] = mapped_column(unique=True)

View file

@ -17,3 +17,13 @@ class NotFound(ApiException):
class Unauthorized(ApiException):
HTTP_CODE = 401
MESSAGE = "Unathorized"
class InvalidValue(ApiException):
HTTP_CODE = 400
MESSAGE = "Invalid value"
class Conflict(ApiException):
HTTP_CODE = 409
MESSAGE = "Conflict"

View file

@ -1,14 +1,15 @@
from .base import ObjModelMixin
from pydantic import BaseModel
from pydantic import BaseModel, Field
from fooder.utils.calories import calculate_calories
class ProductModelBase(BaseModel):
name: str
protein: float
carb: float
fat: float
fiber: float
calories: float
protein: float = Field(ge=0, le=100)
carb: float = Field(ge=0, le=100)
fat: float = Field(ge=0, le=100)
fiber: float = Field(ge=0, le=100)
calories: float = Field(ge=0)
barcode: str | None
@ -17,7 +18,17 @@ class ProductModel(ObjModelMixin, ProductModelBase):
class ProductCreateModel(ProductModelBase):
pass
calories: float | None = None
barcode: str | None = None
@property
def resolved_calories(self) -> float:
return self.calories or calculate_calories(
protein=self.protein,
carb=self.carb,
fat=self.fat,
fiber=self.fiber,
)
class ProductUpdateModel(ProductModelBase):

View file

@ -1,8 +1,10 @@
from typing import TypeVar, Generic, Type, Sequence
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete as sa_delete, ColumnElement
from sqlalchemy.exc import IntegrityError
from sqlalchemy.sql import Select
from fooder.domain import Base
from fooder.exc import Conflict, NotFound
T = TypeVar("T", bound=Base)
@ -23,10 +25,15 @@ class RepositoryBase(Generic[T]):
return stmt
async def get(self, *expressions: ColumnElement) -> T | None:
async def get(self, *expressions: ColumnElement) -> T:
stmt = self._build_select(*expressions)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
obj = result.scalar_one_or_none()
if obj is None:
raise NotFound()
return obj
async def list(self, *expressions: ColumnElement, offset: int = 0, limit: int | None = DEFAULT_LIMIT) -> Sequence[T]:
stmt = self._build_select(*expressions)
@ -42,10 +49,20 @@ class RepositoryBase(Generic[T]):
async def create(self, obj: T) -> T:
self.session.add(obj)
try:
await self.session.flush()
except IntegrityError:
raise Conflict()
await self.session.refresh(obj)
return obj
async def update(self, obj: T) -> T:
try:
await self.session.flush()
except IntegrityError:
raise Conflict()
return obj
async def delete(self, *expressions: ColumnElement):
stmt = sa_delete(self.model)

View file

@ -4,5 +4,5 @@ from sqlalchemy import ColumnElement
def fuzzy_match(attr: InstrumentedAttribute[str], q: str) -> ColumnElement:
q_list = q.split()
qq = "%" + "%".join(q_list) + "%"
return attr.ilike(f"%{qq.lower()}%")
qq = "%".join(q_list)
return attr.ilike(f"%{qq}%")

View file

@ -1,8 +1,12 @@
from contextlib import asynccontextmanager
from typing import AsyncIterator
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import IntegrityError
from .user import UserRepository
from .product import ProductRepository
from ..domain import User, Product
from fooder.repository.user import UserRepository
from fooder.repository.product import ProductRepository
from fooder.domain import User, Product
from fooder.exc import Conflict
class Repository:
@ -12,7 +16,19 @@ class Repository:
self.product = ProductRepository(Product, session)
async def commit(self) -> None:
try:
await self.session.commit()
except IntegrityError:
raise Conflict()
async def rollback(self) -> None:
await self.session.rollback()
@asynccontextmanager
async def transaction(self) -> AsyncIterator["Repository"]:
try:
yield self
await self.commit()
except Exception:
await self.rollback()
raise

View file

@ -1,6 +1,8 @@
from fastapi import APIRouter
from fooder.view.token import router as token_router
from fooder.view.product import router as product_router
router = APIRouter(prefix="/api")
router.include_router(token_router, prefix="/token", tags=["token"])
router.include_router(product_router, prefix="/product", tags=["product"])

View file

@ -1,46 +0,0 @@
from datetime import datetime, timezone
import pytest
from fooder.controller.token import TokenController
from fooder.exc import Unauthorized
from fooder.utils.jwt import AccessToken, RefreshToken
NOW = datetime.now(timezone.utc)
def test_token_ctrl_generates_token(ctx):
token_ctrl = TokenController(ctx, 1)
token_ctrl.generate_token_pair(ctx.clock())
class TestFromRefreshToken:
def test_returns_controller_with_correct_entity_id(self, ctx):
token = RefreshToken(exp=RefreshToken.calculate_exp(NOW), sub=42)
ctrl = TokenController.from_refresh_token(ctx, token.encode())
assert ctrl.entity_id == 42
def test_invalid_string_raises(self, ctx):
with pytest.raises(Unauthorized):
TokenController.from_refresh_token(ctx, "bad-token")
def test_access_token_raises(self, ctx):
token = AccessToken(exp=AccessToken.calculate_exp(NOW), sub=1)
with pytest.raises(Unauthorized):
TokenController.from_refresh_token(ctx, token.encode())
class TestFromAccessToken:
def test_returns_controller_with_correct_entity_id(self, ctx):
token = AccessToken(exp=AccessToken.calculate_exp(NOW), sub=7)
ctrl = TokenController.from_access_token(ctx, token.encode())
assert ctrl.entity_id == 7
def test_invalid_string_raises(self, ctx):
with pytest.raises(Unauthorized):
TokenController.from_access_token(ctx, "bad-token")
def test_refresh_token_raises(self, ctx):
token = RefreshToken(exp=RefreshToken.calculate_exp(NOW), sub=1)
with pytest.raises(Unauthorized):
TokenController.from_access_token(ctx, token.encode())

View file

@ -4,6 +4,7 @@ from .faker import *
from .user import *
from .client import *
from .context import *
from .product import *
@pytest.fixture

View file

@ -14,3 +14,15 @@ async def client(db_session):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c:
yield c
app.dependency_overrides.clear()
@pytest_asyncio.fixture
async def auth_client(client, user, user_password):
response = await client.post(
"/api/token",
data={"username": user.username, "password": user_password},
)
token = response.json()["access_token"]
client.headers["Authorization"] = f"Bearer {token}"
yield client
del client.headers["Authorization"]

View file

@ -47,7 +47,7 @@ def db_manager() -> DatabaseSessionManager:
async def db_session(db_manager):
async with db_manager._engine.connect() as conn:
trans = await conn.begin()
session = AsyncSession(bind=conn)
session = AsyncSession(bind=conn, expire_on_commit=False)
nested = await conn.begin_nested()

26
fooder/test/fixtures/product.py vendored Normal file
View file

@ -0,0 +1,26 @@
import pytest
import pytest_asyncio
from fooder.controller.product import ProductController
from fooder.model.product import ProductCreateModel
@pytest.fixture
def product_payload():
return {
"name": "Chicken Breast",
"protein": 31.0,
"carb": 0.0,
"fat": 3.6,
"fiber": 0.0,
}
@pytest_asyncio.fixture
async def product(ctx):
data = ProductCreateModel(name="Chicken Breast", protein=31.0, carb=0.0, fat=3.6, fiber=0.0)
async with ctx.repo.transaction():
ctrl = await ProductController.create(ctx, data)
return ctrl.obj

View file

@ -1,5 +1,6 @@
import pytest
from ..fixtures.db import TestModel
from fooder.exc import NotFound
# ------------------------------------------------------------------ create ---
@ -19,9 +20,9 @@ async def test_get_returns_existing_record(test_repo, test_model):
assert found.id == created.id
async def test_get_returns_none_for_missing_record(test_repo):
result = await test_repo.get(TestModel.id == 1)
assert result is None
async def test_get_raises_not_found_for_missing_record(test_repo):
with pytest.raises(NotFound):
await test_repo.get(TestModel.id == 1)
async def test_get_by_field(test_repo, test_model_factory):
@ -62,4 +63,5 @@ async def test_list_returns_empty_when_no_match(test_repo, test_model_factory):
async def test_delete_removes_record(test_repo, test_model):
model = await test_repo.create(test_model)
await test_repo.delete(TestModel.id == model.id)
assert await test_repo.get(TestModel.id == model.id) is None
with pytest.raises(NotFound):
await test_repo.get(TestModel.id == model.id)

View file

@ -0,0 +1,106 @@
import pytest
async def test_update_product_returns_200(auth_client, product):
response = await auth_client.patch(f"/api/product/{product.id}", json={"name": "Updated Name"})
assert response.status_code == 200
assert response.json()["name"] == "Updated Name"
async def test_update_product_recalculates_calories_when_macros_change(auth_client, product):
response = await auth_client.patch(f"/api/product/{product.id}", json={"protein": 10.0})
# protein=10, carb=0, fat=3.6, fiber=0 → 10*4 + 3.6*9 = 40 + 32.4 = 72.4
assert response.json()["calories"] == pytest.approx(72.4)
async def test_update_product_uses_explicit_calories(auth_client, product):
response = await auth_client.patch(
f"/api/product/{product.id}", json={"protein": 10.0, "calories": 99.0}
)
assert response.json()["calories"] == 99.0
async def test_update_product_not_found_returns_404(auth_client):
response = await auth_client.patch("/api/product/99999", json={"name": "Ghost"})
assert response.status_code == 404
async def test_update_product_duplicate_barcode_returns_409(auth_client, product, product_payload):
await auth_client.post("/api/product", json={**product_payload, "name": "Other", "barcode": "AAA"})
response = await auth_client.patch(f"/api/product/{product.id}", json={"barcode": "AAA"})
assert response.status_code == 409
async def test_update_product_without_auth_returns_401(client, product):
response = await client.patch(f"/api/product/{product.id}", json={"name": "x"})
assert response.status_code == 401
async def test_create_product_returns_201(auth_client, product_payload):
response = await auth_client.post("/api/product", json=product_payload)
assert response.status_code == 201
async def test_create_product_returns_correct_fields(auth_client, product_payload):
response = await auth_client.post("/api/product", json=product_payload)
body = response.json()
assert body["name"] == product_payload["name"]
assert body["protein"] == product_payload["protein"]
assert body["carb"] == product_payload["carb"]
assert body["fat"] == product_payload["fat"]
assert body["fiber"] == product_payload["fiber"]
assert "id" in body
async def test_create_product_calculates_calories(auth_client, product_payload):
response = await auth_client.post("/api/product", json=product_payload)
# 31*4 + 0*4 + 3.6*9 + 0*2 = 124 + 32.4 = 156.4
assert response.json()["calories"] == pytest.approx(156.4)
async def test_create_product_uses_explicit_calories(auth_client, product_payload):
response = await auth_client.post("/api/product", json={**product_payload, "calories": 50.0})
assert response.json()["calories"] == 50.0
async def test_create_product_duplicate_barcode_returns_409(auth_client, product_payload):
await auth_client.post("/api/product", json={**product_payload, "barcode": "123456"})
response = await auth_client.post("/api/product", json={**product_payload, "barcode": "123456"})
assert response.status_code == 409
async def test_create_product_invalid_protein_returns_422(auth_client, product_payload):
response = await auth_client.post("/api/product", json={**product_payload, "protein": -1.0})
assert response.status_code == 422
async def test_create_product_protein_over_100_returns_422(auth_client, product_payload):
response = await auth_client.post("/api/product", json={**product_payload, "protein": 101.0})
assert response.status_code == 422
async def test_create_product_without_auth_returns_401(client, product_payload):
response = await client.post("/api/product", json=product_payload)
assert response.status_code == 401
async def test_list_products_returns_200(auth_client, product):
response = await auth_client.get("/api/product")
assert response.status_code == 200
assert isinstance(response.json(), list)
async def test_list_products_contains_created(auth_client, product):
response = await auth_client.get("/api/product")
ids = [p["id"] for p in response.json()]
assert product.id in ids
async def test_list_products_filters_by_name(auth_client, product):
response = await auth_client.get("/api/product", params={"q": product.name})
assert all(product.name.lower() in p["name"].lower() for p in response.json())
async def test_list_products_without_auth_returns_401(client):
response = await client.get("/api/product")
assert response.status_code == 401

20
fooder/utils/calories.py Normal file
View file

@ -0,0 +1,20 @@
PROTEIN_KCAL = 4
CARB_KCAL = 4
FAT_KCAL = 9
FIBER_KCAL = 2
def calculate_calories(
protein: float,
carb: float,
fat: float,
fiber: float,
) -> float:
return sum(
(
PROTEIN_KCAL * protein,
CARB_KCAL * carb,
FAT_KCAL * fat,
FIBER_KCAL * fiber,
)
)

View file

@ -43,3 +43,10 @@ class AccessToken(Token):
class RefreshToken(Token):
secret_key = settings.REFRESH_SECRET_KEY
expire_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
def generate_token_pair(entity_id: int, now: datetime) -> tuple[AccessToken, RefreshToken]:
return (
AccessToken(exp=AccessToken.calculate_exp(now), sub=entity_id),
RefreshToken(exp=RefreshToken.calculate_exp(now), sub=entity_id),
)

View file

@ -1,45 +1,46 @@
from typing import Optional
from fastapi import APIRouter, Depends
from fooder.repository.expression import fuzzy_match
from fooder.domain import Product
from fooder.model.product import ProductModel
from fooder.model.product import ProductModel, ProductCreateModel, ProductUpdateModel
from fooder.controller.product import ProductController
from fooder.context import Context, AuthContextDependency
router = APIRouter(tags=["product"])
@router.get("/", response_model=list[ProductModel])
@router.get("", response_model=list[ProductModel])
async def list_products(
ctx: Context = Depends(AuthContextDependency()),
limit: int = 10,
offset: int = 0,
q: Optional[str] = None,
q: str | None = None,
):
expressions = []
if q:
expressions.append(
fuzzy_match(Product.name, q)
)
objs = await ctx.repo.product.list(*expressions, limit=limit, offset=offset)
return [ProductModel.model_validate(obj) for obj in objs]
return await ctx.repo.product.list(*expressions, limit=limit, offset=offset)
# @router.post("/", response_model=Product)
# async def create_product(
# request: Request,
# data: CreateProductPayload,
# contoller: CreateProduct = Depends(CreateProduct),
# ):
# return await contoller.call(data)
#
#
# @router.get("/by_barcode", response_model=Product)
# async def get_by_bar_code(
# request: Request,
# barcode: str,
# contoller: GetProductByBarCode = Depends(GetProductByBarCode),
# ):
# return await contoller.call(barcode)
@router.patch("/{product_id}", response_model=ProductModel)
async def update_product(
product_id: int,
data: ProductUpdateModel,
ctx: Context = Depends(AuthContextDependency()),
):
obj = await ctx.repo.product.get(Product.id == product_id)
async with ctx.repo.transaction():
await ProductController(ctx, obj).update(data)
return obj
@router.post("", response_model=ProductModel, status_code=201)
async def create_product(
data: ProductCreateModel,
ctx: Context = Depends(AuthContextDependency()),
):
async with ctx.repo.transaction():
ctrl = await ProductController.create(ctx, data)
return ctrl.obj

View file

@ -6,14 +6,14 @@ from datetime import datetime
from fooder.model.token import TokenResponse
from fooder.context import ContextDependency, Context
from fooder.controller import UserController
from fooder.controller.token import TokenController
from fooder.controller.user import UserController
from fooder.utils.jwt import RefreshToken, generate_token_pair
router = APIRouter(tags=["token"])
def gen_token_response(token_ctrl: TokenController, now: datetime) -> TokenResponse:
access_token, refresh_token = token_ctrl.generate_token_pair(now)
def gen_token_response(entity_id: int, now: datetime) -> TokenResponse:
access_token, refresh_token = generate_token_pair(entity_id, now)
return TokenResponse(
access_token=access_token.encode(),
refresh_token=refresh_token.encode(),
@ -27,7 +27,7 @@ async def token_create(
) -> TokenResponse:
now = ctx.clock()
user_ctrl = await UserController.session_start(ctx, data.username, data.password)
return gen_token_response(user_ctrl.token_ctrl(), now)
return gen_token_response(user_ctrl.obj.id, now)
@router.post("/refresh", response_model=TokenResponse)
@ -36,5 +36,5 @@ async def token_refresh(
ctx: Context = Depends(ContextDependency()),
) -> TokenResponse:
now = ctx.clock()
token_ctrl = TokenController.from_refresh_token(ctx, refresh_token)
return gen_token_response(token_ctrl, now)
token = RefreshToken.decode(refresh_token)
return gen_token_response(token.sub, now)