[controller/view] CQS way + remove token controller
This commit is contained in:
parent
bd102360ad
commit
6129712efe
24 changed files with 355 additions and 165 deletions
|
|
@ -7,7 +7,8 @@ from fooder.db import get_db_session
|
|||
from fooder.domain import User
|
||||
from fooder.repository import Repository
|
||||
from fooder.utils.datetime import utc_now
|
||||
from fooder.exc import Unauthorized
|
||||
from fooder.exc import Unauthorized, NotFound
|
||||
from fooder.utils.jwt import AccessToken
|
||||
|
||||
|
||||
class Context:
|
||||
|
|
@ -61,11 +62,11 @@ class AuthContextDependency:
|
|||
token: str = Depends(OAuth2PasswordBearer(tokenUrl="/token")),
|
||||
session: AsyncSession = Depends(get_db_session),
|
||||
) -> Context:
|
||||
ctx = Context(repo = Repository(session))
|
||||
from fooder.controller.token import TokenController
|
||||
token_ctrl = TokenController.from_access_token(ctx, token)
|
||||
user = await ctx.repo.user.get(User.id == token_ctrl.entity_id)
|
||||
if user is None:
|
||||
ctx = Context(repo=Repository(session))
|
||||
user_id = AccessToken.decode(token).sub
|
||||
try:
|
||||
user = await ctx.repo.user.get(User.id == user_id)
|
||||
except NotFound:
|
||||
raise Unauthorized()
|
||||
ctx.set_user(user)
|
||||
return ctx
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
from .user import UserController
|
||||
from .token import TokenController
|
||||
from .product import ProductController
|
||||
|
|
|
|||
53
fooder/controller/product.py
Normal file
53
fooder/controller/product.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
from fooder.controller.base import ModelController
|
||||
from fooder.domain import Product
|
||||
from fooder.context import Context
|
||||
from fooder.model.product import ProductCreateModel, ProductUpdateModel
|
||||
from fooder.utils.calories import calculate_calories
|
||||
|
||||
|
||||
class ProductController(ModelController[Product]):
|
||||
@classmethod
|
||||
async def create(
|
||||
cls,
|
||||
ctx: Context,
|
||||
data: ProductCreateModel,
|
||||
) -> "ProductController":
|
||||
obj = Product(
|
||||
name=data.name,
|
||||
protein=data.protein,
|
||||
carb=data.carb,
|
||||
fat=data.fat,
|
||||
fiber=data.fiber,
|
||||
calories=data.resolved_calories,
|
||||
barcode=data.barcode,
|
||||
)
|
||||
await ctx.repo.product.create(obj)
|
||||
return cls(ctx, obj)
|
||||
|
||||
async def update(self, data: ProductUpdateModel) -> None:
|
||||
if data.name is not None:
|
||||
self.obj.name = data.name
|
||||
|
||||
if data.protein is not None:
|
||||
self.obj.protein = data.protein
|
||||
|
||||
if data.carb is not None:
|
||||
self.obj.carb = data.carb
|
||||
|
||||
if data.fat is not None:
|
||||
self.obj.fat = data.fat
|
||||
|
||||
if data.fiber is not None:
|
||||
self.obj.fiber = data.fiber
|
||||
|
||||
if data.barcode is not None:
|
||||
self.obj.barcode = data.barcode
|
||||
|
||||
self.obj.calories = data.calories if data.calories is not None else calculate_calories(
|
||||
protein=self.obj.protein,
|
||||
carb=self.obj.carb,
|
||||
fat=self.obj.fat,
|
||||
fiber=self.obj.fiber,
|
||||
)
|
||||
|
||||
await self.ctx.repo.product.update(self.obj)
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
from fooder.controller.base import ControllerBase
|
||||
from fooder.context import Context
|
||||
from fooder.utils.jwt import Token, AccessToken, RefreshToken
|
||||
from typing import Type, TypeVar
|
||||
from datetime import datetime
|
||||
|
||||
T = TypeVar("T", bound=Token)
|
||||
|
||||
|
||||
class TokenController(ControllerBase):
|
||||
def __init__(self, ctx: Context, entity_id: int) -> None:
|
||||
super().__init__(ctx)
|
||||
self.entity_id = entity_id
|
||||
|
||||
@classmethod
|
||||
def from_token(cls, ctx: Context, token_str: str, token_cls: Type[T]) -> "TokenController":
|
||||
token = token_cls.decode(token_str)
|
||||
return cls(ctx, token.sub)
|
||||
|
||||
@classmethod
|
||||
def from_refresh_token(cls, ctx: Context, token_str: str) -> "TokenController":
|
||||
return cls.from_token(ctx, token_str, RefreshToken)
|
||||
|
||||
@classmethod
|
||||
def from_access_token(cls, ctx: Context, token_str: str) -> "TokenController":
|
||||
return cls.from_token(ctx, token_str, AccessToken)
|
||||
|
||||
def generate_token(self, token_cls: Type[T], now: datetime) -> T:
|
||||
return token_cls(exp=token_cls.calculate_exp(now), sub=self.entity_id)
|
||||
|
||||
def generate_refresh_token(self, now: datetime) -> RefreshToken:
|
||||
return self.generate_token(RefreshToken, now)
|
||||
|
||||
def generate_access_token(self, now: datetime) -> AccessToken:
|
||||
return self.generate_token(AccessToken, now)
|
||||
|
||||
def generate_token_pair(self, now: datetime) -> tuple[AccessToken, RefreshToken]:
|
||||
return (self.generate_access_token(now), self.generate_refresh_token(now))
|
||||
|
|
@ -1,8 +1,7 @@
|
|||
from fooder.controller.base import ModelController
|
||||
from fooder.controller.token import TokenController
|
||||
from fooder.domain import User
|
||||
from fooder.context import Context
|
||||
from fooder.exc import Unauthorized
|
||||
from fooder.exc import Unauthorized, NotFound
|
||||
|
||||
|
||||
class UserController(ModelController[User]):
|
||||
|
|
@ -13,12 +12,12 @@ class UserController(ModelController[User]):
|
|||
username: str,
|
||||
password: str,
|
||||
) -> "UserController":
|
||||
try:
|
||||
obj = await ctx.repo.user.get(User.username == username)
|
||||
except NotFound:
|
||||
raise Unauthorized()
|
||||
|
||||
if obj is None or not obj.verify_password(password):
|
||||
if not obj.verify_password(password):
|
||||
raise Unauthorized()
|
||||
|
||||
return cls(ctx, obj)
|
||||
|
||||
def token_ctrl(self) -> TokenController:
|
||||
return TokenController(ctx=self.ctx, entity_id=self.obj.id)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,8 @@ class DatabaseSessionManager:
|
|||
),
|
||||
)
|
||||
self._sessionmaker = async_sessionmaker(
|
||||
autocommit=False, autoflush=False, future=True, bind=self._engine
|
||||
autocommit=False, autoflush=False, future=True, bind=self._engine,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from sqlalchemy.orm import Mapped
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from fooder.domain.base import Base, CommonMixin
|
||||
|
||||
|
|
@ -12,16 +12,5 @@ class Product(Base, CommonMixin):
|
|||
carb: Mapped[float]
|
||||
fat: Mapped[float]
|
||||
fiber: Mapped[float]
|
||||
hard_coded_calories: Mapped[float | None]
|
||||
barcode: Mapped[str | None]
|
||||
|
||||
@property
|
||||
def calories(self) -> float:
|
||||
"""calories.
|
||||
|
||||
:rtype: float
|
||||
"""
|
||||
if self.hard_coded_calories:
|
||||
return self.hard_coded_calories
|
||||
|
||||
return self.protein * 4 + self.carb * 4 + self.fat * 9 + self.fiber * 2
|
||||
calories: Mapped[float]
|
||||
barcode: Mapped[str | None] = mapped_column(unique=True)
|
||||
|
|
|
|||
|
|
@ -17,3 +17,13 @@ class NotFound(ApiException):
|
|||
class Unauthorized(ApiException):
|
||||
HTTP_CODE = 401
|
||||
MESSAGE = "Unathorized"
|
||||
|
||||
|
||||
class InvalidValue(ApiException):
|
||||
HTTP_CODE = 400
|
||||
MESSAGE = "Invalid value"
|
||||
|
||||
|
||||
class Conflict(ApiException):
|
||||
HTTP_CODE = 409
|
||||
MESSAGE = "Conflict"
|
||||
|
|
|
|||
|
|
@ -1,14 +1,15 @@
|
|||
from .base import ObjModelMixin
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from fooder.utils.calories import calculate_calories
|
||||
|
||||
|
||||
class ProductModelBase(BaseModel):
|
||||
name: str
|
||||
protein: float
|
||||
carb: float
|
||||
fat: float
|
||||
fiber: float
|
||||
calories: float
|
||||
protein: float = Field(ge=0, le=100)
|
||||
carb: float = Field(ge=0, le=100)
|
||||
fat: float = Field(ge=0, le=100)
|
||||
fiber: float = Field(ge=0, le=100)
|
||||
calories: float = Field(ge=0)
|
||||
barcode: str | None
|
||||
|
||||
|
||||
|
|
@ -17,7 +18,17 @@ class ProductModel(ObjModelMixin, ProductModelBase):
|
|||
|
||||
|
||||
class ProductCreateModel(ProductModelBase):
|
||||
pass
|
||||
calories: float | None = None
|
||||
barcode: str | None = None
|
||||
|
||||
@property
|
||||
def resolved_calories(self) -> float:
|
||||
return self.calories or calculate_calories(
|
||||
protein=self.protein,
|
||||
carb=self.carb,
|
||||
fat=self.fat,
|
||||
fiber=self.fiber,
|
||||
)
|
||||
|
||||
|
||||
class ProductUpdateModel(ProductModelBase):
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from typing import TypeVar, Generic, Type, Sequence
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, delete as sa_delete, ColumnElement
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.sql import Select
|
||||
from fooder.domain import Base
|
||||
from fooder.exc import Conflict, NotFound
|
||||
|
||||
T = TypeVar("T", bound=Base)
|
||||
|
||||
|
|
@ -23,10 +25,15 @@ class RepositoryBase(Generic[T]):
|
|||
|
||||
return stmt
|
||||
|
||||
async def get(self, *expressions: ColumnElement) -> T | None:
|
||||
async def get(self, *expressions: ColumnElement) -> T:
|
||||
stmt = self._build_select(*expressions)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
obj = result.scalar_one_or_none()
|
||||
|
||||
if obj is None:
|
||||
raise NotFound()
|
||||
|
||||
return obj
|
||||
|
||||
async def list(self, *expressions: ColumnElement, offset: int = 0, limit: int | None = DEFAULT_LIMIT) -> Sequence[T]:
|
||||
stmt = self._build_select(*expressions)
|
||||
|
|
@ -42,10 +49,20 @@ class RepositoryBase(Generic[T]):
|
|||
|
||||
async def create(self, obj: T) -> T:
|
||||
self.session.add(obj)
|
||||
try:
|
||||
await self.session.flush()
|
||||
except IntegrityError:
|
||||
raise Conflict()
|
||||
await self.session.refresh(obj)
|
||||
return obj
|
||||
|
||||
async def update(self, obj: T) -> T:
|
||||
try:
|
||||
await self.session.flush()
|
||||
except IntegrityError:
|
||||
raise Conflict()
|
||||
return obj
|
||||
|
||||
async def delete(self, *expressions: ColumnElement):
|
||||
stmt = sa_delete(self.model)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,5 +4,5 @@ from sqlalchemy import ColumnElement
|
|||
|
||||
def fuzzy_match(attr: InstrumentedAttribute[str], q: str) -> ColumnElement:
|
||||
q_list = q.split()
|
||||
qq = "%" + "%".join(q_list) + "%"
|
||||
return attr.ilike(f"%{qq.lower()}%")
|
||||
qq = "%".join(q_list)
|
||||
return attr.ilike(f"%{qq}%")
|
||||
|
|
|
|||
|
|
@ -1,8 +1,12 @@
|
|||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncIterator
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from .user import UserRepository
|
||||
from .product import ProductRepository
|
||||
from ..domain import User, Product
|
||||
from fooder.repository.user import UserRepository
|
||||
from fooder.repository.product import ProductRepository
|
||||
from fooder.domain import User, Product
|
||||
from fooder.exc import Conflict
|
||||
|
||||
|
||||
class Repository:
|
||||
|
|
@ -12,7 +16,19 @@ class Repository:
|
|||
self.product = ProductRepository(Product, session)
|
||||
|
||||
async def commit(self) -> None:
|
||||
try:
|
||||
await self.session.commit()
|
||||
except IntegrityError:
|
||||
raise Conflict()
|
||||
|
||||
async def rollback(self) -> None:
|
||||
await self.session.rollback()
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction(self) -> AsyncIterator["Repository"]:
|
||||
try:
|
||||
yield self
|
||||
await self.commit()
|
||||
except Exception:
|
||||
await self.rollback()
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from fastapi import APIRouter
|
||||
|
||||
from fooder.view.token import router as token_router
|
||||
from fooder.view.product import router as product_router
|
||||
|
||||
router = APIRouter(prefix="/api")
|
||||
router.include_router(token_router, prefix="/token", tags=["token"])
|
||||
router.include_router(product_router, prefix="/product", tags=["product"])
|
||||
|
|
|
|||
|
|
@ -1,46 +0,0 @@
|
|||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from fooder.controller.token import TokenController
|
||||
from fooder.exc import Unauthorized
|
||||
from fooder.utils.jwt import AccessToken, RefreshToken
|
||||
|
||||
NOW = datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def test_token_ctrl_generates_token(ctx):
|
||||
token_ctrl = TokenController(ctx, 1)
|
||||
token_ctrl.generate_token_pair(ctx.clock())
|
||||
|
||||
|
||||
class TestFromRefreshToken:
|
||||
def test_returns_controller_with_correct_entity_id(self, ctx):
|
||||
token = RefreshToken(exp=RefreshToken.calculate_exp(NOW), sub=42)
|
||||
ctrl = TokenController.from_refresh_token(ctx, token.encode())
|
||||
assert ctrl.entity_id == 42
|
||||
|
||||
def test_invalid_string_raises(self, ctx):
|
||||
with pytest.raises(Unauthorized):
|
||||
TokenController.from_refresh_token(ctx, "bad-token")
|
||||
|
||||
def test_access_token_raises(self, ctx):
|
||||
token = AccessToken(exp=AccessToken.calculate_exp(NOW), sub=1)
|
||||
with pytest.raises(Unauthorized):
|
||||
TokenController.from_refresh_token(ctx, token.encode())
|
||||
|
||||
|
||||
class TestFromAccessToken:
|
||||
def test_returns_controller_with_correct_entity_id(self, ctx):
|
||||
token = AccessToken(exp=AccessToken.calculate_exp(NOW), sub=7)
|
||||
ctrl = TokenController.from_access_token(ctx, token.encode())
|
||||
assert ctrl.entity_id == 7
|
||||
|
||||
def test_invalid_string_raises(self, ctx):
|
||||
with pytest.raises(Unauthorized):
|
||||
TokenController.from_access_token(ctx, "bad-token")
|
||||
|
||||
def test_refresh_token_raises(self, ctx):
|
||||
token = RefreshToken(exp=RefreshToken.calculate_exp(NOW), sub=1)
|
||||
with pytest.raises(Unauthorized):
|
||||
TokenController.from_access_token(ctx, token.encode())
|
||||
1
fooder/test/fixtures/__init__.py
vendored
1
fooder/test/fixtures/__init__.py
vendored
|
|
@ -4,6 +4,7 @@ from .faker import *
|
|||
from .user import *
|
||||
from .client import *
|
||||
from .context import *
|
||||
from .product import *
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
12
fooder/test/fixtures/client.py
vendored
12
fooder/test/fixtures/client.py
vendored
|
|
@ -14,3 +14,15 @@ async def client(db_session):
|
|||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c:
|
||||
yield c
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def auth_client(client, user, user_password):
|
||||
response = await client.post(
|
||||
"/api/token",
|
||||
data={"username": user.username, "password": user_password},
|
||||
)
|
||||
token = response.json()["access_token"]
|
||||
client.headers["Authorization"] = f"Bearer {token}"
|
||||
yield client
|
||||
del client.headers["Authorization"]
|
||||
|
|
|
|||
2
fooder/test/fixtures/db.py
vendored
2
fooder/test/fixtures/db.py
vendored
|
|
@ -47,7 +47,7 @@ def db_manager() -> DatabaseSessionManager:
|
|||
async def db_session(db_manager):
|
||||
async with db_manager._engine.connect() as conn:
|
||||
trans = await conn.begin()
|
||||
session = AsyncSession(bind=conn)
|
||||
session = AsyncSession(bind=conn, expire_on_commit=False)
|
||||
|
||||
nested = await conn.begin_nested()
|
||||
|
||||
|
|
|
|||
26
fooder/test/fixtures/product.py
vendored
Normal file
26
fooder/test/fixtures/product.py
vendored
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from fooder.controller.product import ProductController
|
||||
from fooder.model.product import ProductCreateModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def product_payload():
|
||||
return {
|
||||
"name": "Chicken Breast",
|
||||
"protein": 31.0,
|
||||
"carb": 0.0,
|
||||
"fat": 3.6,
|
||||
"fiber": 0.0,
|
||||
}
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def product(ctx):
|
||||
data = ProductCreateModel(name="Chicken Breast", protein=31.0, carb=0.0, fat=3.6, fiber=0.0)
|
||||
async with ctx.repo.transaction():
|
||||
ctrl = await ProductController.create(ctx, data)
|
||||
return ctrl.obj
|
||||
|
||||
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
import pytest
|
||||
from ..fixtures.db import TestModel
|
||||
from fooder.exc import NotFound
|
||||
# ------------------------------------------------------------------ create ---
|
||||
|
||||
|
||||
|
|
@ -19,9 +20,9 @@ async def test_get_returns_existing_record(test_repo, test_model):
|
|||
assert found.id == created.id
|
||||
|
||||
|
||||
async def test_get_returns_none_for_missing_record(test_repo):
|
||||
result = await test_repo.get(TestModel.id == 1)
|
||||
assert result is None
|
||||
async def test_get_raises_not_found_for_missing_record(test_repo):
|
||||
with pytest.raises(NotFound):
|
||||
await test_repo.get(TestModel.id == 1)
|
||||
|
||||
|
||||
async def test_get_by_field(test_repo, test_model_factory):
|
||||
|
|
@ -62,4 +63,5 @@ async def test_list_returns_empty_when_no_match(test_repo, test_model_factory):
|
|||
async def test_delete_removes_record(test_repo, test_model):
|
||||
model = await test_repo.create(test_model)
|
||||
await test_repo.delete(TestModel.id == model.id)
|
||||
assert await test_repo.get(TestModel.id == model.id) is None
|
||||
with pytest.raises(NotFound):
|
||||
await test_repo.get(TestModel.id == model.id)
|
||||
|
|
|
|||
106
fooder/test/view/test_product.py
Normal file
106
fooder/test/view/test_product.py
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
import pytest
|
||||
|
||||
|
||||
async def test_update_product_returns_200(auth_client, product):
|
||||
response = await auth_client.patch(f"/api/product/{product.id}", json={"name": "Updated Name"})
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "Updated Name"
|
||||
|
||||
|
||||
async def test_update_product_recalculates_calories_when_macros_change(auth_client, product):
|
||||
response = await auth_client.patch(f"/api/product/{product.id}", json={"protein": 10.0})
|
||||
# protein=10, carb=0, fat=3.6, fiber=0 → 10*4 + 3.6*9 = 40 + 32.4 = 72.4
|
||||
assert response.json()["calories"] == pytest.approx(72.4)
|
||||
|
||||
|
||||
async def test_update_product_uses_explicit_calories(auth_client, product):
|
||||
response = await auth_client.patch(
|
||||
f"/api/product/{product.id}", json={"protein": 10.0, "calories": 99.0}
|
||||
)
|
||||
assert response.json()["calories"] == 99.0
|
||||
|
||||
|
||||
async def test_update_product_not_found_returns_404(auth_client):
|
||||
response = await auth_client.patch("/api/product/99999", json={"name": "Ghost"})
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
async def test_update_product_duplicate_barcode_returns_409(auth_client, product, product_payload):
|
||||
await auth_client.post("/api/product", json={**product_payload, "name": "Other", "barcode": "AAA"})
|
||||
response = await auth_client.patch(f"/api/product/{product.id}", json={"barcode": "AAA"})
|
||||
assert response.status_code == 409
|
||||
|
||||
|
||||
async def test_update_product_without_auth_returns_401(client, product):
|
||||
response = await client.patch(f"/api/product/{product.id}", json={"name": "x"})
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
async def test_create_product_returns_201(auth_client, product_payload):
|
||||
response = await auth_client.post("/api/product", json=product_payload)
|
||||
assert response.status_code == 201
|
||||
|
||||
|
||||
async def test_create_product_returns_correct_fields(auth_client, product_payload):
|
||||
response = await auth_client.post("/api/product", json=product_payload)
|
||||
body = response.json()
|
||||
assert body["name"] == product_payload["name"]
|
||||
assert body["protein"] == product_payload["protein"]
|
||||
assert body["carb"] == product_payload["carb"]
|
||||
assert body["fat"] == product_payload["fat"]
|
||||
assert body["fiber"] == product_payload["fiber"]
|
||||
assert "id" in body
|
||||
|
||||
|
||||
async def test_create_product_calculates_calories(auth_client, product_payload):
|
||||
response = await auth_client.post("/api/product", json=product_payload)
|
||||
# 31*4 + 0*4 + 3.6*9 + 0*2 = 124 + 32.4 = 156.4
|
||||
assert response.json()["calories"] == pytest.approx(156.4)
|
||||
|
||||
|
||||
async def test_create_product_uses_explicit_calories(auth_client, product_payload):
|
||||
response = await auth_client.post("/api/product", json={**product_payload, "calories": 50.0})
|
||||
assert response.json()["calories"] == 50.0
|
||||
|
||||
|
||||
async def test_create_product_duplicate_barcode_returns_409(auth_client, product_payload):
|
||||
await auth_client.post("/api/product", json={**product_payload, "barcode": "123456"})
|
||||
response = await auth_client.post("/api/product", json={**product_payload, "barcode": "123456"})
|
||||
assert response.status_code == 409
|
||||
|
||||
|
||||
async def test_create_product_invalid_protein_returns_422(auth_client, product_payload):
|
||||
response = await auth_client.post("/api/product", json={**product_payload, "protein": -1.0})
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
async def test_create_product_protein_over_100_returns_422(auth_client, product_payload):
|
||||
response = await auth_client.post("/api/product", json={**product_payload, "protein": 101.0})
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
async def test_create_product_without_auth_returns_401(client, product_payload):
|
||||
response = await client.post("/api/product", json=product_payload)
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
async def test_list_products_returns_200(auth_client, product):
|
||||
response = await auth_client.get("/api/product")
|
||||
assert response.status_code == 200
|
||||
assert isinstance(response.json(), list)
|
||||
|
||||
|
||||
async def test_list_products_contains_created(auth_client, product):
|
||||
response = await auth_client.get("/api/product")
|
||||
ids = [p["id"] for p in response.json()]
|
||||
assert product.id in ids
|
||||
|
||||
|
||||
async def test_list_products_filters_by_name(auth_client, product):
|
||||
response = await auth_client.get("/api/product", params={"q": product.name})
|
||||
assert all(product.name.lower() in p["name"].lower() for p in response.json())
|
||||
|
||||
|
||||
async def test_list_products_without_auth_returns_401(client):
|
||||
response = await client.get("/api/product")
|
||||
assert response.status_code == 401
|
||||
20
fooder/utils/calories.py
Normal file
20
fooder/utils/calories.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
PROTEIN_KCAL = 4
|
||||
CARB_KCAL = 4
|
||||
FAT_KCAL = 9
|
||||
FIBER_KCAL = 2
|
||||
|
||||
|
||||
def calculate_calories(
|
||||
protein: float,
|
||||
carb: float,
|
||||
fat: float,
|
||||
fiber: float,
|
||||
) -> float:
|
||||
return sum(
|
||||
(
|
||||
PROTEIN_KCAL * protein,
|
||||
CARB_KCAL * carb,
|
||||
FAT_KCAL * fat,
|
||||
FIBER_KCAL * fiber,
|
||||
)
|
||||
)
|
||||
|
|
@ -43,3 +43,10 @@ class AccessToken(Token):
|
|||
class RefreshToken(Token):
|
||||
secret_key = settings.REFRESH_SECRET_KEY
|
||||
expire_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
|
||||
def generate_token_pair(entity_id: int, now: datetime) -> tuple[AccessToken, RefreshToken]:
|
||||
return (
|
||||
AccessToken(exp=AccessToken.calculate_exp(now), sub=entity_id),
|
||||
RefreshToken(exp=RefreshToken.calculate_exp(now), sub=entity_id),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,45 +1,46 @@
|
|||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from fooder.repository.expression import fuzzy_match
|
||||
from fooder.domain import Product
|
||||
from fooder.model.product import ProductModel
|
||||
from fooder.model.product import ProductModel, ProductCreateModel, ProductUpdateModel
|
||||
from fooder.controller.product import ProductController
|
||||
from fooder.context import Context, AuthContextDependency
|
||||
|
||||
router = APIRouter(tags=["product"])
|
||||
|
||||
|
||||
@router.get("/", response_model=list[ProductModel])
|
||||
@router.get("", response_model=list[ProductModel])
|
||||
async def list_products(
|
||||
ctx: Context = Depends(AuthContextDependency()),
|
||||
limit: int = 10,
|
||||
offset: int = 0,
|
||||
q: Optional[str] = None,
|
||||
q: str | None = None,
|
||||
):
|
||||
expressions = []
|
||||
if q:
|
||||
expressions.append(
|
||||
fuzzy_match(Product.name, q)
|
||||
)
|
||||
|
||||
objs = await ctx.repo.product.list(*expressions, limit=limit, offset=offset)
|
||||
return [ProductModel.model_validate(obj) for obj in objs]
|
||||
return await ctx.repo.product.list(*expressions, limit=limit, offset=offset)
|
||||
|
||||
|
||||
# @router.post("/", response_model=Product)
|
||||
# async def create_product(
|
||||
# request: Request,
|
||||
# data: CreateProductPayload,
|
||||
# contoller: CreateProduct = Depends(CreateProduct),
|
||||
# ):
|
||||
# return await contoller.call(data)
|
||||
#
|
||||
#
|
||||
# @router.get("/by_barcode", response_model=Product)
|
||||
# async def get_by_bar_code(
|
||||
# request: Request,
|
||||
# barcode: str,
|
||||
# contoller: GetProductByBarCode = Depends(GetProductByBarCode),
|
||||
# ):
|
||||
# return await contoller.call(barcode)
|
||||
@router.patch("/{product_id}", response_model=ProductModel)
|
||||
async def update_product(
|
||||
product_id: int,
|
||||
data: ProductUpdateModel,
|
||||
ctx: Context = Depends(AuthContextDependency()),
|
||||
):
|
||||
obj = await ctx.repo.product.get(Product.id == product_id)
|
||||
async with ctx.repo.transaction():
|
||||
await ProductController(ctx, obj).update(data)
|
||||
return obj
|
||||
|
||||
|
||||
@router.post("", response_model=ProductModel, status_code=201)
|
||||
async def create_product(
|
||||
data: ProductCreateModel,
|
||||
ctx: Context = Depends(AuthContextDependency()),
|
||||
):
|
||||
async with ctx.repo.transaction():
|
||||
ctrl = await ProductController.create(ctx, data)
|
||||
return ctrl.obj
|
||||
|
|
|
|||
|
|
@ -6,14 +6,14 @@ from datetime import datetime
|
|||
|
||||
from fooder.model.token import TokenResponse
|
||||
from fooder.context import ContextDependency, Context
|
||||
from fooder.controller import UserController
|
||||
from fooder.controller.token import TokenController
|
||||
from fooder.controller.user import UserController
|
||||
from fooder.utils.jwt import RefreshToken, generate_token_pair
|
||||
|
||||
router = APIRouter(tags=["token"])
|
||||
|
||||
|
||||
def gen_token_response(token_ctrl: TokenController, now: datetime) -> TokenResponse:
|
||||
access_token, refresh_token = token_ctrl.generate_token_pair(now)
|
||||
def gen_token_response(entity_id: int, now: datetime) -> TokenResponse:
|
||||
access_token, refresh_token = generate_token_pair(entity_id, now)
|
||||
return TokenResponse(
|
||||
access_token=access_token.encode(),
|
||||
refresh_token=refresh_token.encode(),
|
||||
|
|
@ -27,7 +27,7 @@ async def token_create(
|
|||
) -> TokenResponse:
|
||||
now = ctx.clock()
|
||||
user_ctrl = await UserController.session_start(ctx, data.username, data.password)
|
||||
return gen_token_response(user_ctrl.token_ctrl(), now)
|
||||
return gen_token_response(user_ctrl.obj.id, now)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
|
|
@ -36,5 +36,5 @@ async def token_refresh(
|
|||
ctx: Context = Depends(ContextDependency()),
|
||||
) -> TokenResponse:
|
||||
now = ctx.clock()
|
||||
token_ctrl = TokenController.from_refresh_token(ctx, refresh_token)
|
||||
return gen_token_response(token_ctrl, now)
|
||||
token = RefreshToken.decode(refresh_token)
|
||||
return gen_token_response(token.sub, now)
|
||||
|
|
|
|||
Loading…
Reference in a new issue