fooder-api/fooder/domain/product.py
Piotr Domański e2af49046c
All checks were successful
Python lint and test / linttest (push) Successful in 5m30s
[isort] ran with black profile
2024-08-04 16:17:16 +02:00

140 lines
3.9 KiB
Python

from typing import AsyncIterator, Optional
from sqlalchemy import BigInteger, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, mapped_column
from .base import Base, CommonMixin
class Product(Base, CommonMixin):
    """Food product with macronutrient values and a cached usage counter.

    Attributes not declared here (for example ``id``, which the queries below
    rely on) are inherited from ``Base`` / ``CommonMixin``.
    """

    # Product name; ``create`` stores it lower-cased.
    name: Mapped[str]
    # Macronutrient amounts. ``create`` restricts each to the 0-100 range and
    # carb + protein + fat to at most 100 -- presumably grams per 100 g of
    # product; TODO confirm the unit.
    protein: Mapped[float]
    carb: Mapped[float]
    fat: Mapped[float]
    fiber: Mapped[float]
    # Optional explicit calorie value that overrides the computed one.
    hard_coded_calories: Mapped[Optional[float]]
    # Optional barcode, used for lookups and duplicate detection.
    barcode: Mapped[Optional[str]]
    # Usage counter refreshed by ``cache_usage_data``; drives list ordering.
    usage_count_cached: Mapped[int] = mapped_column(
        BigInteger,
        default=0,
        nullable=False,
    )

    @property
    def calories(self) -> float:
        """Return the calorie value of the product.

        A stored ``hard_coded_calories`` takes precedence; otherwise the value
        is derived from the macronutrients: 4 per unit of protein and carb,
        9 per unit of fat and 2 per unit of fiber.

        :rtype: float
        """
        # Compare against None instead of relying on truthiness so that an
        # explicit value of 0 is honoured rather than replaced by the formula.
        if self.hard_coded_calories is not None:
            return self.hard_coded_calories
        return self.protein * 4 + self.carb * 4 + self.fat * 9 + self.fiber * 2

    @classmethod
    async def list_all(
        cls,
        session: AsyncSession,
        offset: int,
        limit: int,
        q: Optional[str] = None,
    ) -> AsyncIterator["Product"]:
        """Stream products, most used first, optionally filtered by name.

        :param session: session used to run the query
        :param offset: number of rows to skip
        :param limit: maximum number of rows to yield
        :param q: optional search phrase; its whitespace-separated terms must
            occur in the name in the given order, case-insensitively
        :returns: asynchronous iterator over the matching products
        """
        query = select(cls)
        if q:
            # Terms are joined with "%" so anything may appear between them.
            # NOTE: "%" and "_" typed by the user are not escaped and
            # therefore act as LIKE wildcards.
            pattern = "%" + "%".join(q.split()).lower() + "%"
            query = query.filter(cls.name.ilike(pattern))
        # The id tiebreaker gives a total order, so consecutive offset/limit
        # pages cannot repeat or skip products that share a usage count.
        query = (
            query.order_by(cls.usage_count_cached.desc(), cls.id)
            .offset(offset)
            .limit(limit)
        )
        stream = await session.stream_scalars(query)
        async for row in stream:
            yield row

    @classmethod
    async def get_by_barcode(
        cls, session: AsyncSession, barcode: str
    ) -> Optional["Product"]:
        """Return the product with the given barcode, or ``None``.

        :param session: session used to run the query
        :param barcode: barcode to look up (exact match)
        :rtype: Optional[Product]
        """
        query = select(cls).where(cls.barcode == barcode)
        return await session.scalar(query)

    @classmethod
    async def create(
        cls,
        session: AsyncSession,
        name: str,
        carb: float,
        protein: float,
        fat: float,
        fiber: float,
        hard_coded_calories: Optional[float] = None,
        barcode: Optional[str] = None,
    ) -> "Product":
        """Validate the values, then add and flush a new product.

        :param session: session the product is added to
        :param name: product name; stored lower-cased
        :param carb: carb amount, 0-100
        :param protein: protein amount, 0-100
        :param fat: fat amount, 0-100
        :param fiber: fiber amount, 0-100
        :param hard_coded_calories: optional explicit calorie value
        :param barcode: optional barcode
        :returns: the newly created product (flushed, not committed)
        :raises AssertionError: when a value is out of range, when
            carb + protein + fat exceeds 100, or when a product with the same
            name (or the same barcode, if one is given) already exists
        """
        # Validation raises AssertionError explicitly instead of using
        # ``assert`` statements: the exception type and messages stay the
        # same for callers, but the checks are no longer stripped when
        # Python runs with -O.
        macros = (
            ("carb", carb),
            ("protein", protein),
            ("fat", fat),
            ("fiber", fiber),
        )
        # ``not (x <= 100)`` rather than ``x > 100`` so that NaN is rejected.
        for label, value in macros:
            if not (value <= 100):
                raise AssertionError(f"{label} must be less than 100")
        for label, value in macros:
            if not (value >= 0):
                raise AssertionError(f"{label} must be greater than 0")
        if not (carb + protein + fat <= 100):
            raise AssertionError("total must be less than 100")

        # to avoid duplicates in the database keep name as lower
        name = name.lower()

        # A product is a duplicate when the name matches or, if a barcode was
        # supplied, when the barcode matches.
        duplicate_filter = cls.name == name
        if barcode is not None:
            duplicate_filter = duplicate_filter | (cls.barcode == barcode)
        existing_product = await session.scalar(select(cls).where(duplicate_filter))
        if existing_product is not None:
            raise AssertionError("product already exists")

        product = cls(
            name=name,
            protein=protein,
            carb=carb,
            fat=fat,
            fiber=fiber,
            hard_coded_calories=hard_coded_calories,
            barcode=barcode,
        )
        session.add(product)
        # Flush so the INSERT is sent to the database before returning.
        await session.flush()
        return product

    @classmethod
    async def cache_usage_data(
        cls,
        session: AsyncSession,
    ) -> None:
        """Refresh ``usage_count_cached`` from unprocessed entries.

        Every product referenced by at least one unprocessed ``Entry`` gets
        its counter set to the number of its unprocessed entries, after which
        ``Entry.mark_processed`` is invoked.

        :param session: session used to run the statements
        """
        # Imported locally -- presumably to avoid a circular import between
        # the product and entry modules; TODO confirm.
        from .entry import Entry

        # ``== False`` is intentional: it builds a SQL comparison.
        unprocessed = Entry.processed == False  # noqa: E712

        # Correlated scalar subquery: unprocessed entries of the product row
        # currently being updated.
        pending_count = (
            select(func.count(Entry.id))
            .where(Entry.product_id == cls.id, unprocessed)
            .scalar_subquery()
        )
        # NOTE(review): the counter is overwritten with the count of
        # unprocessed entries rather than incremented by it. If
        # ``mark_processed`` flags those entries as processed, earlier usage
        # is dropped on the next run -- confirm whether
        # ``cls.usage_count_cached + pending_count`` was intended.
        stmt = (
            update(cls)
            .where(
                cls.id.in_(select(Entry.product_id).where(unprocessed).distinct())
            )
            .values(usage_count_cached=pending_count)
        )
        await session.execute(stmt)
        await Entry.mark_processed(session)