diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 95b7416..ed156c3 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -15,7 +15,7 @@ services: api: restart: unless-stopped - image: api + image: registry.domandoman.xyz/fooder/api build: dockerfile: Dockerfile context: . diff --git a/fooder/controller/product.py b/fooder/controller/product.py index 944df1b..2155ed8 100644 --- a/fooder/controller/product.py +++ b/fooder/controller/product.py @@ -2,6 +2,7 @@ from typing import AsyncIterator, Optional from fastapi import HTTPException +from ..utils import product_finder from ..model.product import Product, CreateProductPayload from ..domain.product import Product as DBProduct from .base import AuthorizedController @@ -33,3 +34,36 @@ class ListProduct(AuthorizedController): session, limit=limit, offset=offset, q=q ): yield Product.from_orm(product) + + +class GetProductByBarCode(AuthorizedController): + async def call(self, barcode: str) -> Product: + async with self.async_session() as session: + product = await DBProduct.get_by_barcode(session, barcode) + + if product: + return Product.from_orm(product) + + try: + product_data = product_finder.find(barcode) + except product_finder.ProductNotFound: + raise HTTPException(status_code=404, detail="Product not found") + except product_finder.ParseError: + raise HTTPException( + status_code=400, detail="Product was found, but unable to import" + ) + + try: + product = await DBProduct.create( + session, + product_data.name, + product_data.carb, + product_data.protein, + product_data.fat, + product_data.fiber, + product_data.kcal, + barcode, + ) + return Product.from_orm(product) + except AssertionError as e: + raise HTTPException(status_code=400, detail=e.args[0]) diff --git a/fooder/domain/product.py b/fooder/domain/product.py index 4b50e8b..8011c8e 100644 --- a/fooder/domain/product.py +++ b/fooder/domain/product.py @@ -15,6 +15,8 @@ class Product(Base, CommonMixin): carb: Mapped[float] fat: Mapped[float] fiber: Mapped[float] + hard_coded_calories: Mapped[Optional[float]] = None + barcode: Mapped[Optional[str]] = None @property def calories(self) -> float: @@ -22,11 +24,18 @@ class Product(Base, CommonMixin): :rtype: float """ + if self.hard_coded_calories: + return self.hard_coded_calories + return self.protein * 4 + self.carb * 4 + self.fat * 9 + self.fiber * 2 @classmethod async def list_all( - cls, session: AsyncSession, offset: int, limit: int, q: Optional[str] = None + cls, + session: AsyncSession, + offset: int, + limit: int, + q: Optional[str] = None, ) -> AsyncIterator["Product"]: query = select(cls) @@ -40,6 +49,13 @@ class Product(Base, CommonMixin): async for row in stream: yield row + @classmethod + async def get_by_barcode( + cls, session: AsyncSession, barcode: str + ) -> Optional["Product"]: + query = select(cls).where(cls.barcode == barcode) + return await session.scalar(query) + @classmethod async def create( cls, @@ -49,6 +65,8 @@ class Product(Base, CommonMixin): protein: float, fat: float, fiber: float, + hard_coded_calories: Optional[float] = None, + barcode: Optional[str] = None, ) -> "Product": # validation here assert carb <= 100, "carb must be less than 100" @@ -65,7 +83,11 @@ class Product(Base, CommonMixin): name = name.lower() # check if product already exists - query = select(cls).where(cls.name == name) + if barcode is not None: + query = select(cls).where((cls.name == name) | (cls.barcode == barcode)) + else: + query = select(cls).where(cls.name == name) + existing_product = await session.scalar(query) assert existing_product is None, "product already exists" @@ -75,7 +97,10 @@ class Product(Base, CommonMixin): carb=carb, fat=fat, fiber=fiber, + hard_coded_calories=hard_coded_calories, + barcode=barcode, ) + session.add(product) await session.flush() return product diff --git a/fooder/test/test_product.py b/fooder/test/test_product.py index c9f4e7d..05fca3d 100644 --- a/fooder/test/test_product.py +++ b/fooder/test/test_product.py @@ -1,5 +1,4 @@ import pytest -import datetime @pytest.mark.dependency() @@ -20,3 +19,9 @@ def test_list_product(client): for product in data: assert product["id"] not in product_ids product_ids.add(product["id"]) + + +@pytest.mark.dependency(depends=["test_create_product"]) +def test_get_product_by_barcode(client): + response = client.get("product/by_barcode", params={"barcode": "4056489666028"}) + assert response.status_code == 200, response.json() diff --git a/fooder/utils/__init__.py b/fooder/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fooder/utils/product_finder.py b/fooder/utils/product_finder.py new file mode 100644 index 0000000..64cffa3 --- /dev/null +++ b/fooder/utils/product_finder.py @@ -0,0 +1,52 @@ +import requests as r +from dataclasses import dataclass +from logging import getLogger + + +logger = getLogger(__name__) + + +class NotFound(Exception): + pass + + +class ParseError(Exception): + pass + + +@dataclass +class Product: + name: str + kcal: float + fat: float + protein: float + carb: float + fiber: float + + +def find(bar_code: str) -> Product: + url = f"https://world.openfoodfacts.org/api/v2/product/{bar_code}.json" + response = r.get(url) + + if response.status_code == 404: + raise NotFound() + + try: + data = response.json() + + name = data["product"]["product_name"] + + if data["product"]["brands"]: + name = data["product"]["brands"] + " " + name + + return Product( + name=name, + kcal=data["product"]["nutriments"]["energy-kcal_100g"], + fat=data["product"]["nutriments"]["fat_100g"], + protein=data["product"]["nutriments"]["proteins_100g"], + carb=data["product"]["nutriments"]["carbohydrates_100g"], + fiber=data["product"]["nutriments"].get("fiber_100g", 0.0), + ) + except Exception as e: + logger.error(e) + raise ParseError() diff --git a/fooder/view/product.py b/fooder/view/product.py index 7f97c64..c60b339 100644 --- a/fooder/view/product.py +++ b/fooder/view/product.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, Depends, Request from ..model.product import Product, CreateProductPayload, ListProductPayload -from ..controller.product import ListProduct, CreateProduct +from ..controller.product import ListProduct, CreateProduct, GetProductByBarCode from typing import Optional @@ -27,3 +27,12 @@ async def create_product( contoller: CreateProduct = Depends(CreateProduct), ): return await contoller.call(data) + + +@router.get("/by_barcode", response_model=Product) +async def get_by_bar_code( + request: Request, + barcode: str, + contoller: GetProductByBarCode = Depends(GetProductByBarCode), +): + return await contoller.call(barcode) diff --git a/requirements.txt b/requirements.txt index aae4f8a..7ce385c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ psycopg2-binary==2.9.3 python-jose[cryptography] passlib[bcrypt] fastapi-users +requests diff --git a/requirements_local.txt b/requirements_local.txt index 3a8626f..ff4b6e7 100644 --- a/requirements_local.txt +++ b/requirements_local.txt @@ -7,3 +7,5 @@ python-jose[cryptography] passlib[bcrypt] fastapi-users pytest +requests +black