diff --git a/fooder/context.py b/fooder/context.py index 31f96d8..83530fd 100644 --- a/fooder/context.py +++ b/fooder/context.py @@ -7,7 +7,8 @@ from fooder.db import get_db_session from fooder.domain import User from fooder.repository import Repository from fooder.utils.datetime import utc_now -from fooder.exc import Unauthorized +from fooder.exc import Unauthorized, NotFound +from fooder.utils.jwt import AccessToken class Context: @@ -61,11 +62,11 @@ class AuthContextDependency: token: str = Depends(OAuth2PasswordBearer(tokenUrl="/token")), session: AsyncSession = Depends(get_db_session), ) -> Context: - ctx = Context(repo = Repository(session)) - from fooder.controller.token import TokenController - token_ctrl = TokenController.from_access_token(ctx, token) - user = await ctx.repo.user.get(User.id == token_ctrl.entity_id) - if user is None: + ctx = Context(repo=Repository(session)) + user_id = AccessToken.decode(token).sub + try: + user = await ctx.repo.user.get(User.id == user_id) + except NotFound: raise Unauthorized() ctx.set_user(user) return ctx diff --git a/fooder/controller/__init__.py b/fooder/controller/__init__.py index ef2a6a2..0801b27 100644 --- a/fooder/controller/__init__.py +++ b/fooder/controller/__init__.py @@ -1,2 +1,2 @@ from .user import UserController -from .token import TokenController +from .product import ProductController diff --git a/fooder/controller/product.py b/fooder/controller/product.py new file mode 100644 index 0000000..09bd254 --- /dev/null +++ b/fooder/controller/product.py @@ -0,0 +1,53 @@ +from fooder.controller.base import ModelController +from fooder.domain import Product +from fooder.context import Context +from fooder.model.product import ProductCreateModel, ProductUpdateModel +from fooder.utils.calories import calculate_calories + + +class ProductController(ModelController[Product]): + @classmethod + async def create( + cls, + ctx: Context, + data: ProductCreateModel, + ) -> "ProductController": + obj = Product( + name=data.name, + protein=data.protein, + carb=data.carb, + fat=data.fat, + fiber=data.fiber, + calories=data.resolved_calories, + barcode=data.barcode, + ) + await ctx.repo.product.create(obj) + return cls(ctx, obj) + + async def update(self, data: ProductUpdateModel) -> None: + if data.name is not None: + self.obj.name = data.name + + if data.protein is not None: + self.obj.protein = data.protein + + if data.carb is not None: + self.obj.carb = data.carb + + if data.fat is not None: + self.obj.fat = data.fat + + if data.fiber is not None: + self.obj.fiber = data.fiber + + if data.barcode is not None: + self.obj.barcode = data.barcode + + self.obj.calories = data.calories if data.calories is not None else calculate_calories( + protein=self.obj.protein, + carb=self.obj.carb, + fat=self.obj.fat, + fiber=self.obj.fiber, + ) + + await self.ctx.repo.product.update(self.obj) diff --git a/fooder/controller/token.py b/fooder/controller/token.py deleted file mode 100644 index 3e984f6..0000000 --- a/fooder/controller/token.py +++ /dev/null @@ -1,38 +0,0 @@ -from fooder.controller.base import ControllerBase -from fooder.context import Context -from fooder.utils.jwt import Token, AccessToken, RefreshToken -from typing import Type, TypeVar -from datetime import datetime - -T = TypeVar("T", bound=Token) - - -class TokenController(ControllerBase): - def __init__(self, ctx: Context, entity_id: int) -> None: - super().__init__(ctx) - self.entity_id = entity_id - - @classmethod - def from_token(cls, ctx: Context, token_str: str, token_cls: Type[T]) -> "TokenController": - token = token_cls.decode(token_str) - return cls(ctx, token.sub) - - @classmethod - def from_refresh_token(cls, ctx: Context, token_str: str) -> "TokenController": - return cls.from_token(ctx, token_str, RefreshToken) - - @classmethod - def from_access_token(cls, ctx: Context, token_str: str) -> "TokenController": - return cls.from_token(ctx, token_str, AccessToken) - - def generate_token(self, token_cls: Type[T], now: datetime) -> T: - return token_cls(exp=token_cls.calculate_exp(now), sub=self.entity_id) - - def generate_refresh_token(self, now: datetime) -> RefreshToken: - return self.generate_token(RefreshToken, now) - - def generate_access_token(self, now: datetime) -> AccessToken: - return self.generate_token(AccessToken, now) - - def generate_token_pair(self, now: datetime) -> tuple[AccessToken, RefreshToken]: - return (self.generate_access_token(now), self.generate_refresh_token(now)) diff --git a/fooder/controller/user.py b/fooder/controller/user.py index 8d320fa..26764c7 100644 --- a/fooder/controller/user.py +++ b/fooder/controller/user.py @@ -1,8 +1,7 @@ from fooder.controller.base import ModelController -from fooder.controller.token import TokenController from fooder.domain import User from fooder.context import Context -from fooder.exc import Unauthorized +from fooder.exc import Unauthorized, NotFound class UserController(ModelController[User]): @@ -13,12 +12,12 @@ class UserController(ModelController[User]): username: str, password: str, ) -> "UserController": - obj = await ctx.repo.user.get(User.username == username) + try: + obj = await ctx.repo.user.get(User.username == username) + except NotFound: + raise Unauthorized() - if obj is None or not obj.verify_password(password): + if not obj.verify_password(password): raise Unauthorized() return cls(ctx, obj) - - def token_ctrl(self) -> TokenController: - return TokenController(ctx=self.ctx, entity_id=self.obj.id) diff --git a/fooder/db.py b/fooder/db.py index ab53951..f3c823c 100644 --- a/fooder/db.py +++ b/fooder/db.py @@ -23,7 +23,8 @@ class DatabaseSessionManager: ), ) self._sessionmaker = async_sessionmaker( - autocommit=False, autoflush=False, future=True, bind=self._engine + autocommit=False, autoflush=False, future=True, bind=self._engine, + expire_on_commit=False, ) async def close(self) -> None: diff --git a/fooder/domain/product.py b/fooder/domain/product.py index cbb0c85..9cb72e1 100644 --- a/fooder/domain/product.py +++ b/fooder/domain/product.py @@ -1,4 +1,4 @@ -from sqlalchemy.orm import Mapped +from sqlalchemy.orm import Mapped, mapped_column from fooder.domain.base import Base, CommonMixin @@ -12,16 +12,5 @@ class Product(Base, CommonMixin): carb: Mapped[float] fat: Mapped[float] fiber: Mapped[float] - hard_coded_calories: Mapped[float | None] - barcode: Mapped[str | None] - - @property - def calories(self) -> float: - """calories. - - :rtype: float - """ - if self.hard_coded_calories: - return self.hard_coded_calories - - return self.protein * 4 + self.carb * 4 + self.fat * 9 + self.fiber * 2 + calories: Mapped[float] + barcode: Mapped[str | None] = mapped_column(unique=True) diff --git a/fooder/exc.py b/fooder/exc.py index fb08191..9752a9e 100644 --- a/fooder/exc.py +++ b/fooder/exc.py @@ -17,3 +17,13 @@ class NotFound(ApiException): class Unauthorized(ApiException): HTTP_CODE = 401 MESSAGE = "Unathorized" + + +class InvalidValue(ApiException): + HTTP_CODE = 400 + MESSAGE = "Invalid value" + + +class Conflict(ApiException): + HTTP_CODE = 409 + MESSAGE = "Conflict" diff --git a/fooder/model/product.py b/fooder/model/product.py index 89ac253..754baf8 100644 --- a/fooder/model/product.py +++ b/fooder/model/product.py @@ -1,14 +1,15 @@ from .base import ObjModelMixin -from pydantic import BaseModel +from pydantic import BaseModel, Field +from fooder.utils.calories import calculate_calories class ProductModelBase(BaseModel): name: str - protein: float - carb: float - fat: float - fiber: float - calories: float + protein: float = Field(ge=0, le=100) + carb: float = Field(ge=0, le=100) + fat: float = Field(ge=0, le=100) + fiber: float = Field(ge=0, le=100) + calories: float = Field(ge=0) barcode: str | None @@ -17,7 +18,17 @@ class ProductModel(ObjModelMixin, ProductModelBase): class ProductCreateModel(ProductModelBase): - pass + calories: float | None = None + barcode: str | None = None + + @property + def resolved_calories(self) -> float: + return self.calories or calculate_calories( + protein=self.protein, + carb=self.carb, + fat=self.fat, + fiber=self.fiber, + ) class ProductUpdateModel(ProductModelBase): diff --git a/fooder/repository/base.py b/fooder/repository/base.py index 39b0340..ef6e8a0 100644 --- a/fooder/repository/base.py +++ b/fooder/repository/base.py @@ -1,8 +1,10 @@ from typing import TypeVar, Generic, Type, Sequence from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, delete as sa_delete, ColumnElement +from sqlalchemy.exc import IntegrityError from sqlalchemy.sql import Select from fooder.domain import Base +from fooder.exc import Conflict, NotFound T = TypeVar("T", bound=Base) @@ -23,10 +25,15 @@ class RepositoryBase(Generic[T]): return stmt - async def get(self, *expressions: ColumnElement) -> T | None: + async def get(self, *expressions: ColumnElement) -> T: stmt = self._build_select(*expressions) result = await self.session.execute(stmt) - return result.scalar_one_or_none() + obj = result.scalar_one_or_none() + + if obj is None: + raise NotFound() + + return obj async def list(self, *expressions: ColumnElement, offset: int = 0, limit: int | None = DEFAULT_LIMIT) -> Sequence[T]: stmt = self._build_select(*expressions) @@ -42,10 +49,20 @@ class RepositoryBase(Generic[T]): async def create(self, obj: T) -> T: self.session.add(obj) - await self.session.flush() + try: + await self.session.flush() + except IntegrityError: + raise Conflict() await self.session.refresh(obj) return obj + async def update(self, obj: T) -> T: + try: + await self.session.flush() + except IntegrityError: + raise Conflict() + return obj + async def delete(self, *expressions: ColumnElement): stmt = sa_delete(self.model) diff --git a/fooder/repository/expression.py b/fooder/repository/expression.py index 953275c..0a861db 100644 --- a/fooder/repository/expression.py +++ b/fooder/repository/expression.py @@ -4,5 +4,5 @@ from sqlalchemy import ColumnElement def fuzzy_match(attr: InstrumentedAttribute[str], q: str) -> ColumnElement: q_list = q.split() - qq = "%" + "%".join(q_list) + "%" - return attr.ilike(f"%{qq.lower()}%") + qq = "%".join(q_list) + return attr.ilike(f"%{qq}%") diff --git a/fooder/repository/repository.py b/fooder/repository/repository.py index 8aa467d..31df8d7 100644 --- a/fooder/repository/repository.py +++ b/fooder/repository/repository.py @@ -1,8 +1,12 @@ +from contextlib import asynccontextmanager +from typing import AsyncIterator from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.exc import IntegrityError -from .user import UserRepository -from .product import ProductRepository -from ..domain import User, Product +from fooder.repository.user import UserRepository +from fooder.repository.product import ProductRepository +from fooder.domain import User, Product +from fooder.exc import Conflict class Repository: @@ -12,7 +16,19 @@ class Repository: self.product = ProductRepository(Product, session) async def commit(self) -> None: - await self.session.commit() + try: + await self.session.commit() + except IntegrityError: + raise Conflict() async def rollback(self) -> None: await self.session.rollback() + + @asynccontextmanager + async def transaction(self) -> AsyncIterator["Repository"]: + try: + yield self + await self.commit() + except Exception: + await self.rollback() + raise diff --git a/fooder/router.py b/fooder/router.py index cd9ae26..058ddb7 100644 --- a/fooder/router.py +++ b/fooder/router.py @@ -1,6 +1,8 @@ from fastapi import APIRouter from fooder.view.token import router as token_router +from fooder.view.product import router as product_router router = APIRouter(prefix="/api") router.include_router(token_router, prefix="/token", tags=["token"]) +router.include_router(product_router, prefix="/product", tags=["product"]) diff --git a/fooder/test/controller/test_token.py b/fooder/test/controller/test_token.py deleted file mode 100644 index 15bcab9..0000000 --- a/fooder/test/controller/test_token.py +++ /dev/null @@ -1,46 +0,0 @@ -from datetime import datetime, timezone - -import pytest - -from fooder.controller.token import TokenController -from fooder.exc import Unauthorized -from fooder.utils.jwt import AccessToken, RefreshToken - -NOW = datetime.now(timezone.utc) - - -def test_token_ctrl_generates_token(ctx): - token_ctrl = TokenController(ctx, 1) - token_ctrl.generate_token_pair(ctx.clock()) - - -class TestFromRefreshToken: - def test_returns_controller_with_correct_entity_id(self, ctx): - token = RefreshToken(exp=RefreshToken.calculate_exp(NOW), sub=42) - ctrl = TokenController.from_refresh_token(ctx, token.encode()) - assert ctrl.entity_id == 42 - - def test_invalid_string_raises(self, ctx): - with pytest.raises(Unauthorized): - TokenController.from_refresh_token(ctx, "bad-token") - - def test_access_token_raises(self, ctx): - token = AccessToken(exp=AccessToken.calculate_exp(NOW), sub=1) - with pytest.raises(Unauthorized): - TokenController.from_refresh_token(ctx, token.encode()) - - -class TestFromAccessToken: - def test_returns_controller_with_correct_entity_id(self, ctx): - token = AccessToken(exp=AccessToken.calculate_exp(NOW), sub=7) - ctrl = TokenController.from_access_token(ctx, token.encode()) - assert ctrl.entity_id == 7 - - def test_invalid_string_raises(self, ctx): - with pytest.raises(Unauthorized): - TokenController.from_access_token(ctx, "bad-token") - - def test_refresh_token_raises(self, ctx): - token = RefreshToken(exp=RefreshToken.calculate_exp(NOW), sub=1) - with pytest.raises(Unauthorized): - TokenController.from_access_token(ctx, token.encode()) \ No newline at end of file diff --git a/fooder/test/fixtures/__init__.py b/fooder/test/fixtures/__init__.py index a9fae93..09a0005 100644 --- a/fooder/test/fixtures/__init__.py +++ b/fooder/test/fixtures/__init__.py @@ -4,6 +4,7 @@ from .faker import * from .user import * from .client import * from .context import * +from .product import * @pytest.fixture diff --git a/fooder/test/fixtures/client.py b/fooder/test/fixtures/client.py index a371cf6..06ece03 100644 --- a/fooder/test/fixtures/client.py +++ b/fooder/test/fixtures/client.py @@ -14,3 +14,15 @@ async def client(db_session): async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c: yield c app.dependency_overrides.clear() + + +@pytest_asyncio.fixture +async def auth_client(client, user, user_password): + response = await client.post( + "/api/token", + data={"username": user.username, "password": user_password}, + ) + token = response.json()["access_token"] + client.headers["Authorization"] = f"Bearer {token}" + yield client + del client.headers["Authorization"] diff --git a/fooder/test/fixtures/db.py b/fooder/test/fixtures/db.py index 5df90f7..33543b2 100644 --- a/fooder/test/fixtures/db.py +++ b/fooder/test/fixtures/db.py @@ -47,7 +47,7 @@ def db_manager() -> DatabaseSessionManager: async def db_session(db_manager): async with db_manager._engine.connect() as conn: trans = await conn.begin() - session = AsyncSession(bind=conn) + session = AsyncSession(bind=conn, expire_on_commit=False) nested = await conn.begin_nested() diff --git a/fooder/test/fixtures/product.py b/fooder/test/fixtures/product.py new file mode 100644 index 0000000..faaac39 --- /dev/null +++ b/fooder/test/fixtures/product.py @@ -0,0 +1,26 @@ +import pytest +import pytest_asyncio + +from fooder.controller.product import ProductController +from fooder.model.product import ProductCreateModel + + +@pytest.fixture +def product_payload(): + return { + "name": "Chicken Breast", + "protein": 31.0, + "carb": 0.0, + "fat": 3.6, + "fiber": 0.0, + } + + +@pytest_asyncio.fixture +async def product(ctx): + data = ProductCreateModel(name="Chicken Breast", protein=31.0, carb=0.0, fat=3.6, fiber=0.0) + async with ctx.repo.transaction(): + ctrl = await ProductController.create(ctx, data) + return ctrl.obj + + diff --git a/fooder/test/repository/test_base.py b/fooder/test/repository/test_base.py index 5973c36..6a29b67 100644 --- a/fooder/test/repository/test_base.py +++ b/fooder/test/repository/test_base.py @@ -1,5 +1,6 @@ import pytest from ..fixtures.db import TestModel +from fooder.exc import NotFound # ------------------------------------------------------------------ create --- @@ -19,9 +20,9 @@ async def test_get_returns_existing_record(test_repo, test_model): assert found.id == created.id -async def test_get_returns_none_for_missing_record(test_repo): - result = await test_repo.get(TestModel.id == 1) - assert result is None +async def test_get_raises_not_found_for_missing_record(test_repo): + with pytest.raises(NotFound): + await test_repo.get(TestModel.id == 1) async def test_get_by_field(test_repo, test_model_factory): @@ -62,4 +63,5 @@ async def test_list_returns_empty_when_no_match(test_repo, test_model_factory): async def test_delete_removes_record(test_repo, test_model): model = await test_repo.create(test_model) await test_repo.delete(TestModel.id == model.id) - assert await test_repo.get(TestModel.id == model.id) is None + with pytest.raises(NotFound): + await test_repo.get(TestModel.id == model.id) diff --git a/fooder/test/view/test_product.py b/fooder/test/view/test_product.py new file mode 100644 index 0000000..31f354d --- /dev/null +++ b/fooder/test/view/test_product.py @@ -0,0 +1,106 @@ +import pytest + + +async def test_update_product_returns_200(auth_client, product): + response = await auth_client.patch(f"/api/product/{product.id}", json={"name": "Updated Name"}) + assert response.status_code == 200 + assert response.json()["name"] == "Updated Name" + + +async def test_update_product_recalculates_calories_when_macros_change(auth_client, product): + response = await auth_client.patch(f"/api/product/{product.id}", json={"protein": 10.0}) + # protein=10, carb=0, fat=3.6, fiber=0 → 10*4 + 3.6*9 = 40 + 32.4 = 72.4 + assert response.json()["calories"] == pytest.approx(72.4) + + +async def test_update_product_uses_explicit_calories(auth_client, product): + response = await auth_client.patch( + f"/api/product/{product.id}", json={"protein": 10.0, "calories": 99.0} + ) + assert response.json()["calories"] == 99.0 + + +async def test_update_product_not_found_returns_404(auth_client): + response = await auth_client.patch("/api/product/99999", json={"name": "Ghost"}) + assert response.status_code == 404 + + +async def test_update_product_duplicate_barcode_returns_409(auth_client, product, product_payload): + await auth_client.post("/api/product", json={**product_payload, "name": "Other", "barcode": "AAA"}) + response = await auth_client.patch(f"/api/product/{product.id}", json={"barcode": "AAA"}) + assert response.status_code == 409 + + +async def test_update_product_without_auth_returns_401(client, product): + response = await client.patch(f"/api/product/{product.id}", json={"name": "x"}) + assert response.status_code == 401 + + +async def test_create_product_returns_201(auth_client, product_payload): + response = await auth_client.post("/api/product", json=product_payload) + assert response.status_code == 201 + + +async def test_create_product_returns_correct_fields(auth_client, product_payload): + response = await auth_client.post("/api/product", json=product_payload) + body = response.json() + assert body["name"] == product_payload["name"] + assert body["protein"] == product_payload["protein"] + assert body["carb"] == product_payload["carb"] + assert body["fat"] == product_payload["fat"] + assert body["fiber"] == product_payload["fiber"] + assert "id" in body + + +async def test_create_product_calculates_calories(auth_client, product_payload): + response = await auth_client.post("/api/product", json=product_payload) + # 31*4 + 0*4 + 3.6*9 + 0*2 = 124 + 32.4 = 156.4 + assert response.json()["calories"] == pytest.approx(156.4) + + +async def test_create_product_uses_explicit_calories(auth_client, product_payload): + response = await auth_client.post("/api/product", json={**product_payload, "calories": 50.0}) + assert response.json()["calories"] == 50.0 + + +async def test_create_product_duplicate_barcode_returns_409(auth_client, product_payload): + await auth_client.post("/api/product", json={**product_payload, "barcode": "123456"}) + response = await auth_client.post("/api/product", json={**product_payload, "barcode": "123456"}) + assert response.status_code == 409 + + +async def test_create_product_invalid_protein_returns_422(auth_client, product_payload): + response = await auth_client.post("/api/product", json={**product_payload, "protein": -1.0}) + assert response.status_code == 422 + + +async def test_create_product_protein_over_100_returns_422(auth_client, product_payload): + response = await auth_client.post("/api/product", json={**product_payload, "protein": 101.0}) + assert response.status_code == 422 + + +async def test_create_product_without_auth_returns_401(client, product_payload): + response = await client.post("/api/product", json=product_payload) + assert response.status_code == 401 + + +async def test_list_products_returns_200(auth_client, product): + response = await auth_client.get("/api/product") + assert response.status_code == 200 + assert isinstance(response.json(), list) + + +async def test_list_products_contains_created(auth_client, product): + response = await auth_client.get("/api/product") + ids = [p["id"] for p in response.json()] + assert product.id in ids + + +async def test_list_products_filters_by_name(auth_client, product): + response = await auth_client.get("/api/product", params={"q": product.name}) + assert all(product.name.lower() in p["name"].lower() for p in response.json()) + + +async def test_list_products_without_auth_returns_401(client): + response = await client.get("/api/product") + assert response.status_code == 401 diff --git a/fooder/utils/calories.py b/fooder/utils/calories.py new file mode 100644 index 0000000..1008140 --- /dev/null +++ b/fooder/utils/calories.py @@ -0,0 +1,20 @@ +PROTEIN_KCAL = 4 +CARB_KCAL = 4 +FAT_KCAL = 9 +FIBER_KCAL = 2 + + +def calculate_calories( + protein: float, + carb: float, + fat: float, + fiber: float, +) -> float: + return sum( + ( + PROTEIN_KCAL * protein, + CARB_KCAL * carb, + FAT_KCAL * fat, + FIBER_KCAL * fiber, + ) + ) diff --git a/fooder/utils/jwt.py b/fooder/utils/jwt.py index 0990b63..119ef86 100644 --- a/fooder/utils/jwt.py +++ b/fooder/utils/jwt.py @@ -43,3 +43,10 @@ class AccessToken(Token): class RefreshToken(Token): secret_key = settings.REFRESH_SECRET_KEY expire_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) + + +def generate_token_pair(entity_id: int, now: datetime) -> tuple[AccessToken, RefreshToken]: + return ( + AccessToken(exp=AccessToken.calculate_exp(now), sub=entity_id), + RefreshToken(exp=RefreshToken.calculate_exp(now), sub=entity_id), + ) diff --git a/fooder/view/product.py b/fooder/view/product.py index 35bc364..22c45a4 100644 --- a/fooder/view/product.py +++ b/fooder/view/product.py @@ -1,45 +1,46 @@ -from typing import Optional - from fastapi import APIRouter, Depends from fooder.repository.expression import fuzzy_match from fooder.domain import Product -from fooder.model.product import ProductModel +from fooder.model.product import ProductModel, ProductCreateModel, ProductUpdateModel +from fooder.controller.product import ProductController from fooder.context import Context, AuthContextDependency router = APIRouter(tags=["product"]) -@router.get("/", response_model=list[ProductModel]) +@router.get("", response_model=list[ProductModel]) async def list_products( ctx: Context = Depends(AuthContextDependency()), limit: int = 10, offset: int = 0, - q: Optional[str] = None, + q: str | None = None, ): expressions = [] if q: expressions.append( fuzzy_match(Product.name, q) ) - - objs = await ctx.repo.product.list(*expressions, limit=limit, offset=offset) - return [ProductModel.model_validate(obj) for obj in objs] + return await ctx.repo.product.list(*expressions, limit=limit, offset=offset) -# @router.post("/", response_model=Product) -# async def create_product( -# request: Request, -# data: CreateProductPayload, -# contoller: CreateProduct = Depends(CreateProduct), -# ): -# return await contoller.call(data) -# -# -# @router.get("/by_barcode", response_model=Product) -# async def get_by_bar_code( -# request: Request, -# barcode: str, -# contoller: GetProductByBarCode = Depends(GetProductByBarCode), -# ): -# return await contoller.call(barcode) +@router.patch("/{product_id}", response_model=ProductModel) +async def update_product( + product_id: int, + data: ProductUpdateModel, + ctx: Context = Depends(AuthContextDependency()), +): + obj = await ctx.repo.product.get(Product.id == product_id) + async with ctx.repo.transaction(): + await ProductController(ctx, obj).update(data) + return obj + + +@router.post("", response_model=ProductModel, status_code=201) +async def create_product( + data: ProductCreateModel, + ctx: Context = Depends(AuthContextDependency()), +): + async with ctx.repo.transaction(): + ctrl = await ProductController.create(ctx, data) + return ctrl.obj diff --git a/fooder/view/token.py b/fooder/view/token.py index 7ec5ac8..ac6e8ad 100644 --- a/fooder/view/token.py +++ b/fooder/view/token.py @@ -6,14 +6,14 @@ from datetime import datetime from fooder.model.token import TokenResponse from fooder.context import ContextDependency, Context -from fooder.controller import UserController -from fooder.controller.token import TokenController +from fooder.controller.user import UserController +from fooder.utils.jwt import RefreshToken, generate_token_pair router = APIRouter(tags=["token"]) -def gen_token_response(token_ctrl: TokenController, now: datetime) -> TokenResponse: - access_token, refresh_token = token_ctrl.generate_token_pair(now) +def gen_token_response(entity_id: int, now: datetime) -> TokenResponse: + access_token, refresh_token = generate_token_pair(entity_id, now) return TokenResponse( access_token=access_token.encode(), refresh_token=refresh_token.encode(), @@ -27,7 +27,7 @@ async def token_create( ) -> TokenResponse: now = ctx.clock() user_ctrl = await UserController.session_start(ctx, data.username, data.password) - return gen_token_response(user_ctrl.token_ctrl(), now) + return gen_token_response(user_ctrl.obj.id, now) @router.post("/refresh", response_model=TokenResponse) @@ -36,5 +36,5 @@ async def token_refresh( ctx: Context = Depends(ContextDependency()), ) -> TokenResponse: now = ctx.clock() - token_ctrl = TokenController.from_refresh_token(ctx, refresh_token) - return gen_token_response(token_ctrl, now) + token = RefreshToken.decode(refresh_token) + return gen_token_response(token.sub, now)