2024-05-21 11:11:47 +02:00
|
|
|
from sqlalchemy.orm import Mapped, mapped_column
|
|
|
|
from sqlalchemy import select, BigInteger, func, update
|
2023-04-01 16:19:12 +02:00
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
2023-04-01 20:13:11 +02:00
|
|
|
from typing import AsyncIterator, Optional
|
2023-04-01 16:19:12 +02:00
|
|
|
|
|
|
|
from .base import Base, CommonMixin
|
|
|
|
|
|
|
|
|
|
|
|
class Product(Base, CommonMixin):
    """Food product with macro-nutrient values and a cached usage counter.

    The validation in :meth:`create` bounds every macro to ``0..100``, so the
    values are presumably grams per 100 g of product (TODO confirm the unit).
    The primary key ``id`` is not declared here; it is presumably provided by
    ``Base``/``CommonMixin``.
    """

    # Product name; ``create`` stores it lower-cased.
    name: Mapped[str]

    # Macro-nutrient amounts (each validated to 0..100 in ``create``).
    protein: Mapped[float]
    carb: Mapped[float]
    fat: Mapped[float]
    fiber: Mapped[float]

    # Optional explicit calorie value that overrides the computed one.
    hard_coded_calories: Mapped[Optional[float]]
    # Optional barcode, used for lookups and duplicate detection.
    barcode: Mapped[Optional[str]]

    # Usage counter maintained by ``cache_usage_data``; ``list_all`` sorts by it.
    usage_count_cached: Mapped[int] = mapped_column(
        BigInteger,
        default=0,
        nullable=False,
    )

    @property
    def calories(self) -> float:
        """Return the calorie value of the product.

        An explicitly stored ``hard_coded_calories`` always wins - including an
        explicit ``0`` - otherwise the value is derived from the macros using
        the factors 4 (protein), 4 (carb), 9 (fat) and 2 (fiber).

        :rtype: float
        """
        # ``is not None`` instead of truthiness: a hard-coded 0 is a valid
        # override and must not fall through to the computed value.
        if self.hard_coded_calories is not None:
            return self.hard_coded_calories

        return self.protein * 4 + self.carb * 4 + self.fat * 9 + self.fiber * 2

    @classmethod
    async def list_all(
        cls,
        session: AsyncSession,
        offset: int,
        limit: int,
        q: Optional[str] = None,
    ) -> AsyncIterator["Product"]:
        """Stream products, most used first, optionally filtered by name.

        :param session: async session used to run the query
        :param offset: number of rows to skip
        :param limit: maximum number of rows to yield
        :param q: optional search text; its whitespace-separated words must
            occur in the name in the given order (case-insensitive)
        :return: async iterator over the matching products
        """
        query = select(cls)

        if q:
            # "foo bar" -> "%foo%bar%": words in order, anything in between.
            pattern = "%" + "%".join(q.split()) + "%"
            query = query.filter(cls.name.ilike(pattern))

        # ``id`` is a tie-breaker so that offset/limit pages stay stable when
        # several products share the same usage count.
        query = (
            query.order_by(cls.usage_count_cached.desc(), cls.id)
            .offset(offset)
            .limit(limit)
        )

        stream = await session.stream_scalars(query)
        async for row in stream:
            yield row

    @classmethod
    async def get_by_barcode(
        cls, session: AsyncSession, barcode: str
    ) -> Optional["Product"]:
        """Return the product with the given barcode, or ``None`` if absent.

        :param session: async session used to run the query
        :param barcode: barcode to look up (exact match)
        """
        query = select(cls).where(cls.barcode == barcode)
        return await session.scalar(query)

    @classmethod
    async def create(
        cls,
        session: AsyncSession,
        name: str,
        carb: float,
        protein: float,
        fat: float,
        fiber: float,
        hard_coded_calories: Optional[float] = None,
        barcode: Optional[str] = None,
    ) -> "Product":
        """Validate the input, reject duplicates and add a new product.

        The new row is flushed but not committed; committing is left to the
        caller.

        :param session: async session the product is added to
        :param name: product name (stored lower-cased)
        :param carb: carbohydrate amount, ``0..100``
        :param protein: protein amount, ``0..100``
        :param fat: fat amount, ``0..100``
        :param fiber: fiber amount, ``0..100``
        :param hard_coded_calories: optional explicit calorie value
        :param barcode: optional barcode
        :return: the newly created product
        :raises AssertionError: if a value is out of range, the macros sum to
            more than 100, or a product with the same name (or barcode, when
            one is given) already exists
        """
        # Validation raises AssertionError explicitly rather than using
        # ``assert`` statements: asserts are removed under ``python -O``, which
        # would silently disable these checks. The exception type and messages
        # are the ones callers already see.
        macros = (
            ("carb", carb),
            ("protein", protein),
            ("fat", fat),
            ("fiber", fiber),
        )
        # ``not (x <= 100)`` rather than ``x > 100`` so NaN is rejected too.
        for label, value in macros:
            if not (value <= 100):
                raise AssertionError(f"{label} must be less than 100")
        for label, value in macros:
            if not (value >= 0):
                raise AssertionError(f"{label} must be greater than 0")
        # NOTE(review): fiber is not part of the total - confirm this is
        # intended (e.g. fiber already counted within carb).
        if not (carb + protein + fat <= 100):
            raise AssertionError("total must be less than 100")

        # to avoid duplicates in the database keep name as lower
        name = name.lower()

        # check if product already exists
        if barcode is not None:
            query = select(cls).where((cls.name == name) | (cls.barcode == barcode))
        else:
            query = select(cls).where(cls.name == name)

        existing_product = await session.scalar(query)
        if existing_product is not None:
            raise AssertionError("product already exists")

        product = cls(
            name=name,
            protein=protein,
            carb=carb,
            fat=fat,
            fiber=fiber,
            hard_coded_calories=hard_coded_calories,
            barcode=barcode,
        )

        session.add(product)
        await session.flush()
        return product

    @classmethod
    async def cache_usage_data(
        cls,
        session: AsyncSession,
    ) -> None:
        """Refresh ``usage_count_cached`` from entries not yet processed.

        Every product referenced by at least one entry with
        ``processed == False`` gets its ``usage_count_cached`` set to the
        number of such entries; afterwards ``Entry.mark_processed(session)``
        is called.

        :param session: async session used to run the statements
        """
        # Imported locally - presumably to avoid a circular import between the
        # product and entry modules.
        from .entry import Entry

        # ``== False`` is deliberate: it builds a SQL expression, it is not a
        # Python boolean comparison.
        touched_product_ids = (
            select(Entry.product_id)
            .where(Entry.processed == False)  # noqa: E712
            .distinct()
        )

        # Correlated scalar subquery: unprocessed entries of the row being
        # updated. Made explicit instead of relying on implicit coercion.
        # NOTE(review): this overwrites the cached value with the count of
        # unprocessed entries only; if mark_processed flags those entries for
        # good, earlier usage is dropped on the next run - confirm whether the
        # count should be added to the existing value instead.
        unprocessed_count = (
            select(func.count(Entry.id))
            .where(
                Entry.product_id == cls.id,
                Entry.processed == False,  # noqa: E712
            )
            .scalar_subquery()
        )

        stmt = (
            update(cls)
            .where(cls.id.in_(touched_product_ids))
            .values(usage_count_cached=unprocessed_count)
        )

        await session.execute(stmt)
        await Entry.mark_processed(session)
|