This commit is contained in:
Piotr Domański 2026-04-07 16:49:30 +02:00
parent 20437905ba
commit 724c350e99
20 changed files with 170 additions and 120 deletions

View file

@ -1,6 +1,5 @@
from argparse import ArgumentParser from argparse import ArgumentParser
if __name__ == "__main__": if __name__ == "__main__":
parser = ArgumentParser() parser = ArgumentParser()
group = parser.add_mutually_exclusive_group() group = parser.add_mutually_exclusive_group()

View file

@ -61,60 +61,30 @@ def upgrade() -> None:
sa.UniqueConstraint("user_id"), sa.UniqueConstraint("user_id"),
) )
op.drop_table("refreshtoken") op.drop_table("refreshtoken")
op.add_column( op.add_column("diary", sa.Column("protein_goal", sa.Float(), nullable=False))
"diary", sa.Column("protein_goal", sa.Float(), nullable=False)
)
op.add_column("diary", sa.Column("carb_goal", sa.Float(), nullable=False)) op.add_column("diary", sa.Column("carb_goal", sa.Float(), nullable=False))
op.add_column("diary", sa.Column("fat_goal", sa.Float(), nullable=False)) op.add_column("diary", sa.Column("fat_goal", sa.Float(), nullable=False))
op.add_column("diary", sa.Column("fiber_goal", sa.Float(), nullable=False)) op.add_column("diary", sa.Column("fiber_goal", sa.Float(), nullable=False))
op.add_column( op.add_column("diary", sa.Column("calories_goal", sa.Float(), nullable=False))
"diary", sa.Column("calories_goal", sa.Float(), nullable=False)
)
op.add_column("diary", sa.Column("version", sa.Integer(), nullable=False)) op.add_column("diary", sa.Column("version", sa.Integer(), nullable=False))
op.add_column( op.add_column("diary", sa.Column("created_at", sa.DateTime(), nullable=False))
"diary", sa.Column("created_at", sa.DateTime(), nullable=False) op.add_column("diary", sa.Column("last_changed", sa.DateTime(), nullable=False))
)
op.add_column(
"diary", sa.Column("last_changed", sa.DateTime(), nullable=False)
)
op.create_unique_constraint(None, "diary", ["user_id", "date"]) op.create_unique_constraint(None, "diary", ["user_id", "date"])
op.add_column("entry", sa.Column("version", sa.Integer(), nullable=False)) op.add_column("entry", sa.Column("version", sa.Integer(), nullable=False))
op.add_column( op.add_column("entry", sa.Column("created_at", sa.DateTime(), nullable=False))
"entry", sa.Column("created_at", sa.DateTime(), nullable=False)
)
op.add_column("meal", sa.Column("version", sa.Integer(), nullable=False)) op.add_column("meal", sa.Column("version", sa.Integer(), nullable=False))
op.add_column( op.add_column("meal", sa.Column("created_at", sa.DateTime(), nullable=False))
"meal", sa.Column("created_at", sa.DateTime(), nullable=False) op.add_column("meal", sa.Column("last_changed", sa.DateTime(), nullable=False))
)
op.add_column(
"meal", sa.Column("last_changed", sa.DateTime(), nullable=False)
)
op.add_column("preset", sa.Column("version", sa.Integer(), nullable=False)) op.add_column("preset", sa.Column("version", sa.Integer(), nullable=False))
op.add_column( op.add_column("preset", sa.Column("created_at", sa.DateTime(), nullable=False))
"preset", sa.Column("created_at", sa.DateTime(), nullable=False) op.add_column("preset", sa.Column("last_changed", sa.DateTime(), nullable=False))
) op.add_column("presetentry", sa.Column("version", sa.Integer(), nullable=False))
op.add_column( op.add_column("presetentry", sa.Column("created_at", sa.DateTime(), nullable=False))
"preset", sa.Column("last_changed", sa.DateTime(), nullable=False)
)
op.add_column(
"presetentry", sa.Column("version", sa.Integer(), nullable=False)
)
op.add_column(
"presetentry", sa.Column("created_at", sa.DateTime(), nullable=False)
)
op.add_column("product", sa.Column("calories", sa.Float(), nullable=False)) op.add_column("product", sa.Column("calories", sa.Float(), nullable=False))
op.add_column( op.add_column("product", sa.Column("version", sa.Integer(), nullable=False))
"product", sa.Column("version", sa.Integer(), nullable=False) op.add_column("product", sa.Column("created_at", sa.DateTime(), nullable=False))
) op.add_column("product", sa.Column("last_changed", sa.DateTime(), nullable=False))
op.add_column( op.add_column("product", sa.Column("deleted_at", sa.DateTime(), nullable=True))
"product", sa.Column("created_at", sa.DateTime(), nullable=False)
)
op.add_column(
"product", sa.Column("last_changed", sa.DateTime(), nullable=False)
)
op.add_column(
"product", sa.Column("deleted_at", sa.DateTime(), nullable=True)
)
op.create_index( op.create_index(
"ix_product_barcode", "ix_product_barcode",
"product", "product",
@ -126,15 +96,9 @@ def upgrade() -> None:
op.drop_column("product", "hard_coded_calories") op.drop_column("product", "hard_coded_calories")
op.drop_column("product", "usage_count_cached") op.drop_column("product", "usage_count_cached")
op.add_column("user", sa.Column("version", sa.Integer(), nullable=False)) op.add_column("user", sa.Column("version", sa.Integer(), nullable=False))
op.add_column( op.add_column("user", sa.Column("created_at", sa.DateTime(), nullable=False))
"user", sa.Column("created_at", sa.DateTime(), nullable=False) op.add_column("user", sa.Column("last_changed", sa.DateTime(), nullable=False))
) op.add_column("user", sa.Column("deleted_at", sa.DateTime(), nullable=True))
op.add_column(
"user", sa.Column("last_changed", sa.DateTime(), nullable=False)
)
op.add_column(
"user", sa.Column("deleted_at", sa.DateTime(), nullable=True)
)
# ### end Alembic commands ### # ### end Alembic commands ###
@ -195,9 +159,7 @@ def downgrade() -> None:
op.drop_column("diary", "protein_goal") op.drop_column("diary", "protein_goal")
op.create_table( op.create_table(
"refreshtoken", "refreshtoken",
sa.Column( sa.Column("user_id", sa.INTEGER(), autoincrement=False, nullable=False),
"user_id", sa.INTEGER(), autoincrement=False, nullable=False
),
sa.Column("token", sa.VARCHAR(), autoincrement=False, nullable=False), sa.Column("token", sa.VARCHAR(), autoincrement=False, nullable=False),
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False), sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
sa.ForeignKeyConstraint( sa.ForeignKeyConstraint(

View file

@ -1,7 +1,6 @@
from fooder.context import Context from fooder.context import Context
from typing import TypeVar, Generic from typing import TypeVar, Generic
T = TypeVar("T") T = TypeVar("T")

View file

@ -45,12 +45,16 @@ class ProductController(ModelController[Product]):
if data.barcode is not None: if data.barcode is not None:
self.obj.barcode = data.barcode self.obj.barcode = data.barcode
self.obj.calories = data.calories if data.calories is not None else calculate_calories( self.obj.calories = (
data.calories
if data.calories is not None
else calculate_calories(
protein=self.obj.protein, protein=self.obj.protein,
carb=self.obj.carb, carb=self.obj.carb,
fat=self.obj.fat, fat=self.obj.fat,
fiber=self.obj.fiber, fiber=self.obj.fiber,
) )
)
await self.ctx.repo.product.update(self.obj) await self.ctx.repo.product.update(self.obj)
@ -63,7 +67,9 @@ class ProductController(ModelController[Product]):
except product_finder.ParseError: except product_finder.ParseError:
raise InvalidValue() raise InvalidValue()
return await cls.create(ctx, ProductCreateModel( return await cls.create(
ctx,
ProductCreateModel(
name=found.name, name=found.name,
calories=found.kcal, calories=found.kcal,
fat=found.fat, fat=found.fat,
@ -71,4 +77,5 @@ class ProductController(ModelController[Product]):
carb=found.carb, carb=found.carb,
fiber=found.fiber, fiber=found.fiber,
barcode=barcode, barcode=barcode,
)) ),
)

View file

@ -23,7 +23,9 @@ class DatabaseSessionManager:
), ),
) )
self._sessionmaker = async_sessionmaker( self._sessionmaker = async_sessionmaker(
autocommit=False, autoflush=False, bind=self._engine, autocommit=False,
autoflush=False,
bind=self._engine,
expire_on_commit=False, expire_on_commit=False,
) )

View file

@ -19,13 +19,17 @@ class CommonMixin:
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
version: Mapped[int] = mapped_column(default=0) version: Mapped[int] = mapped_column(default=0)
created_at: Mapped[datetime] = mapped_column(DateTime, default=utc_now) created_at: Mapped[datetime] = mapped_column(DateTime, default=utc_now)
last_changed: Mapped[datetime] = mapped_column(DateTime, default=utc_now, onupdate=utc_now) last_changed: Mapped[datetime] = mapped_column(
DateTime, default=utc_now, onupdate=utc_now
)
__mapper_args__ = {"version_id_col": version} __mapper_args__ = {"version_id_col": version}
class SoftDeleteMixin: class SoftDeleteMixin:
deleted_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) deleted_at: Mapped[datetime | None] = mapped_column(
DateTime, nullable=True, default=None
)
class PasswordMixin: class PasswordMixin:

View file

@ -12,7 +12,9 @@ class Diary(Base, CommonMixin):
__table_args__ = (UniqueConstraint("user_id", "date"),) __table_args__ = (UniqueConstraint("user_id", "date"),)
meals: Mapped[list[Meal]] = relationship(lazy="selectin", order_by=Meal.order.desc()) meals: Mapped[list[Meal]] = relationship(
lazy="selectin", order_by=Meal.order.desc()
)
date: Mapped[datetime.date] = mapped_column(Date) date: Mapped[datetime.date] = mapped_column(Date)
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id")) user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
# snapshot of user settings at diary creation time — intentionally decoupled # snapshot of user settings at diary creation time — intentionally decoupled

View file

@ -11,4 +11,6 @@ class Meal(Base, CommonMixin, AggregateMacrosMixin):
name: Mapped[str] name: Mapped[str]
order: Mapped[int] order: Mapped[int]
diary_id: Mapped[int] = mapped_column(Integer, ForeignKey("diary.id")) diary_id: Mapped[int] = mapped_column(Integer, ForeignKey("diary.id"))
entries: Mapped[list[Entry]] = relationship(lazy="selectin", order_by=Entry.last_changed) entries: Mapped[list[Entry]] = relationship(
lazy="selectin", order_by=Entry.last_changed
)

View file

@ -8,9 +8,13 @@ class Product(Base, CommonMixin, SoftDeleteMixin):
"""Product.""" """Product."""
__table_args__ = ( __table_args__ = (
Index("ix_product_barcode", "barcode", unique=True, Index(
"ix_product_barcode",
"barcode",
unique=True,
postgresql_where=text("deleted_at IS NULL"), postgresql_where=text("deleted_at IS NULL"),
sqlite_where=text("deleted_at IS NULL")), sqlite_where=text("deleted_at IS NULL"),
),
) )
name: Mapped[str] name: Mapped[str]

View file

@ -5,6 +5,7 @@ class ObjModelMixin:
""" """
Shared code for ObjModel. Shared code for ObjModel.
""" """
id: int id: int
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)

View file

@ -23,12 +23,16 @@ class ProductCreateModel(ProductModelBase):
@property @property
def resolved_calories(self) -> float: def resolved_calories(self) -> float:
return self.calories if self.calories is not None else calculate_calories( return (
self.calories
if self.calories is not None
else calculate_calories(
protein=self.protein, protein=self.protein,
carb=self.carb, carb=self.carb,
fat=self.fat, fat=self.fat,
fiber=self.fiber, fiber=self.fiber,
) )
)
class ProductUpdateModel(ProductModelBase): class ProductUpdateModel(ProductModelBase):

View file

@ -50,7 +50,12 @@ class RepositoryBase(Generic[T]):
return obj return obj
async def _list(self, *expressions: ColumnElement, offset: int = 0, limit: int | None = DEFAULT_LIMIT) -> Sequence[T]: async def _list(
self,
*expressions: ColumnElement,
offset: int = 0,
limit: int | None = DEFAULT_LIMIT
) -> Sequence[T]:
stmt = self._build_select(*expressions) stmt = self._build_select(*expressions)
if offset: if offset:

View file

@ -11,7 +11,9 @@ async def client(db_session):
yield db_session yield db_session
app.dependency_overrides[get_db_session] = override_get_db_session app.dependency_overrides[get_db_session] = override_get_db_session
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c: async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as c:
yield c yield c
app.dependency_overrides.clear() app.dependency_overrides.clear()

View file

@ -20,7 +20,9 @@ def product_payload():
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def product(ctx): async def product(ctx):
data = ProductCreateModel(name="Chicken Breast", protein=31.0, carb=0.0, fat=3.6, fiber=0.0) data = ProductCreateModel(
name="Chicken Breast", protein=31.0, carb=0.0, fat=3.6, fiber=0.0
)
async with ctx.repo.transaction(): async with ctx.repo.transaction():
ctrl = await ProductController.create(ctx, data) ctrl = await ProductController.create(ctx, data)
return ctrl.obj return ctrl.obj
@ -28,7 +30,14 @@ async def product(ctx):
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def product_with_barcode(ctx): async def product_with_barcode(ctx):
data = ProductCreateModel(name="Barcoded Product", protein=10.0, carb=5.0, fat=2.0, fiber=1.0, barcode="1234567890") data = ProductCreateModel(
name="Barcoded Product",
protein=10.0,
carb=5.0,
fat=2.0,
fiber=1.0,
barcode="1234567890",
)
async with ctx.repo.transaction(): async with ctx.repo.transaction():
ctrl = await ProductController.create(ctx, data) ctrl = await ProductController.create(ctx, data)
return ctrl.obj return ctrl.obj
@ -60,5 +69,3 @@ def mock_product_finder_not_found(monkeypatch):
raise product_finder.NotFound() raise product_finder.NotFound()
monkeypatch.setattr(product_finder, "find", fake_find) monkeypatch.setattr(product_finder, "find", fake_find)

View file

@ -1,6 +1,7 @@
import pytest import pytest
from ..fixtures.db import TestModel from ..fixtures.db import TestModel
from fooder.exc import NotFound from fooder.exc import NotFound
# ------------------------------------------------------------------ create --- # ------------------------------------------------------------------ create ---

View file

@ -6,7 +6,6 @@ from typing import Literal
from fooder.exc import Unauthorized from fooder.exc import Unauthorized
from fooder.utils.jwt import AccessToken, RefreshToken, Token from fooder.utils.jwt import AccessToken, RefreshToken, Token
PAST = datetime(2000, 1, 1, tzinfo=timezone.utc) PAST = datetime(2000, 1, 1, tzinfo=timezone.utc)

View file

@ -2,13 +2,19 @@ import pytest
async def test_update_product_returns_200(auth_client, product): async def test_update_product_returns_200(auth_client, product):
response = await auth_client.patch(f"/api/product/{product.id}", json={"name": "Updated Name"}) response = await auth_client.patch(
f"/api/product/{product.id}", json={"name": "Updated Name"}
)
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["name"] == "Updated Name" assert response.json()["name"] == "Updated Name"
async def test_update_product_recalculates_calories_when_macros_change(auth_client, product): async def test_update_product_recalculates_calories_when_macros_change(
response = await auth_client.patch(f"/api/product/{product.id}", json={"protein": 10.0}) auth_client, product
):
response = await auth_client.patch(
f"/api/product/{product.id}", json={"protein": 10.0}
)
# protein=10, carb=0, fat=3.6, fiber=0 → 10*4 + 3.6*9 = 40 + 32.4 = 72.4 # protein=10, carb=0, fat=3.6, fiber=0 → 10*4 + 3.6*9 = 40 + 32.4 = 72.4
assert response.json()["calories"] == pytest.approx(72.4) assert response.json()["calories"] == pytest.approx(72.4)
@ -25,9 +31,15 @@ async def test_update_product_not_found_returns_404(auth_client):
assert response.status_code == 404 assert response.status_code == 404
async def test_update_product_duplicate_barcode_returns_409(auth_client, product, product_payload): async def test_update_product_duplicate_barcode_returns_409(
await auth_client.post("/api/product", json={**product_payload, "name": "Other", "barcode": "AAA"}) auth_client, product, product_payload
response = await auth_client.patch(f"/api/product/{product.id}", json={"barcode": "AAA"}) ):
await auth_client.post(
"/api/product", json={**product_payload, "name": "Other", "barcode": "AAA"}
)
response = await auth_client.patch(
f"/api/product/{product.id}", json={"barcode": "AAA"}
)
assert response.status_code == 409 assert response.status_code == 409
@ -59,23 +71,37 @@ async def test_create_product_calculates_calories(auth_client, product_payload):
async def test_create_product_uses_explicit_calories(auth_client, product_payload): async def test_create_product_uses_explicit_calories(auth_client, product_payload):
response = await auth_client.post("/api/product", json={**product_payload, "calories": 50.0}) response = await auth_client.post(
"/api/product", json={**product_payload, "calories": 50.0}
)
assert response.json()["calories"] == 50.0 assert response.json()["calories"] == 50.0
async def test_create_product_duplicate_barcode_returns_409(auth_client, product_payload): async def test_create_product_duplicate_barcode_returns_409(
await auth_client.post("/api/product", json={**product_payload, "barcode": "123456"}) auth_client, product_payload
response = await auth_client.post("/api/product", json={**product_payload, "barcode": "123456"}) ):
await auth_client.post(
"/api/product", json={**product_payload, "barcode": "123456"}
)
response = await auth_client.post(
"/api/product", json={**product_payload, "barcode": "123456"}
)
assert response.status_code == 409 assert response.status_code == 409
async def test_create_product_invalid_protein_returns_422(auth_client, product_payload): async def test_create_product_invalid_protein_returns_422(auth_client, product_payload):
response = await auth_client.post("/api/product", json={**product_payload, "protein": -1.0}) response = await auth_client.post(
"/api/product", json={**product_payload, "protein": -1.0}
)
assert response.status_code == 422 assert response.status_code == 422
async def test_create_product_protein_over_100_returns_422(auth_client, product_payload): async def test_create_product_protein_over_100_returns_422(
response = await auth_client.post("/api/product", json={**product_payload, "protein": 101.0}) auth_client, product_payload
):
response = await auth_client.post(
"/api/product", json={**product_payload, "protein": 101.0}
)
assert response.status_code == 422 assert response.status_code == 422
@ -106,13 +132,19 @@ async def test_list_products_without_auth_returns_401(client):
assert response.status_code == 401 assert response.status_code == 401
async def test_get_by_barcode_returns_product_from_db(auth_client, product_with_barcode): async def test_get_by_barcode_returns_product_from_db(
response = await auth_client.get(f"/api/product/barcode/{product_with_barcode.barcode}") auth_client, product_with_barcode
):
response = await auth_client.get(
f"/api/product/barcode/{product_with_barcode.barcode}"
)
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["id"] == product_with_barcode.id assert response.json()["id"] == product_with_barcode.id
async def test_get_by_barcode_imports_when_not_in_db(auth_client, mock_product_finder, external_product): async def test_get_by_barcode_imports_when_not_in_db(
auth_client, mock_product_finder, external_product
):
response = await auth_client.get("/api/product/barcode/9999999999") response = await auth_client.get("/api/product/barcode/9999999999")
assert response.status_code == 200 assert response.status_code == 200
body = response.json() body = response.json()
@ -125,13 +157,17 @@ async def test_get_by_barcode_imports_when_not_in_db(auth_client, mock_product_f
assert body["barcode"] == "9999999999" assert body["barcode"] == "9999999999"
async def test_get_by_barcode_persists_imported_product(auth_client, mock_product_finder): async def test_get_by_barcode_persists_imported_product(
auth_client, mock_product_finder
):
await auth_client.get("/api/product/barcode/8888888888") await auth_client.get("/api/product/barcode/8888888888")
response = await auth_client.get("/api/product/barcode/8888888888") response = await auth_client.get("/api/product/barcode/8888888888")
assert response.status_code == 200 assert response.status_code == 200
async def test_get_by_barcode_not_found_returns_404(auth_client, mock_product_finder_not_found): async def test_get_by_barcode_not_found_returns_404(
auth_client, mock_product_finder_not_found
):
response = await auth_client.get("/api/product/barcode/0000000000") response = await auth_client.get("/api/product/barcode/0000000000")
assert response.status_code == 404 assert response.status_code == 404

View file

@ -54,7 +54,9 @@ async def test_refresh_token_returns_new_tokens(client, user, user_password):
) )
refresh_token = response.json()["refresh_token"] refresh_token = response.json()["refresh_token"]
response = await client.post("/api/token/refresh", params={"refresh_token": refresh_token}) response = await client.post(
"/api/token/refresh", params={"refresh_token": refresh_token}
)
assert response.status_code == 200 assert response.status_code == 200
body = response.json() body = response.json()
assert "access_token" in body assert "access_token" in body
@ -69,7 +71,9 @@ async def test_refresh_token_access_token_is_valid(client, user, user_password):
) )
refresh_token = response.json()["refresh_token"] refresh_token = response.json()["refresh_token"]
response = await client.post("/api/token/refresh", params={"refresh_token": refresh_token}) response = await client.post(
"/api/token/refresh", params={"refresh_token": refresh_token}
)
token = AccessToken.decode(response.json()["access_token"]) token = AccessToken.decode(response.json()["access_token"])
assert token.sub == user.id assert token.sub == user.id
@ -81,22 +85,30 @@ async def test_refresh_token_refresh_token_is_valid(client, user, user_password)
) )
refresh_token = response.json()["refresh_token"] refresh_token = response.json()["refresh_token"]
response = await client.post("/api/token/refresh", params={"refresh_token": refresh_token}) response = await client.post(
"/api/token/refresh", params={"refresh_token": refresh_token}
)
token = RefreshToken.decode(response.json()["refresh_token"]) token = RefreshToken.decode(response.json()["refresh_token"])
assert token.sub == user.id assert token.sub == user.id
async def test_refresh_token_invalid_returns_401(client): async def test_refresh_token_invalid_returns_401(client):
response = await client.post("/api/token/refresh", params={"refresh_token": "bad-token"}) response = await client.post(
"/api/token/refresh", params={"refresh_token": "bad-token"}
)
assert response.status_code == 401 assert response.status_code == 401
async def test_refresh_token_access_token_as_refresh_returns_401(client, user, user_password): async def test_refresh_token_access_token_as_refresh_returns_401(
client, user, user_password
):
response = await client.post( response = await client.post(
"/api/token", "/api/token",
data={"username": user.username, "password": user_password}, data={"username": user.username, "password": user_password},
) )
access_token = response.json()["access_token"] access_token = response.json()["access_token"]
response = await client.post("/api/token/refresh", params={"refresh_token": access_token}) response = await client.post(
"/api/token/refresh", params={"refresh_token": access_token}
)
assert response.status_code == 401 assert response.status_code == 401

View file

@ -45,7 +45,9 @@ class RefreshToken(Token):
expire_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) expire_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
def generate_token_pair(entity_id: int, now: datetime) -> tuple[AccessToken, RefreshToken]: def generate_token_pair(
entity_id: int, now: datetime
) -> tuple[AccessToken, RefreshToken]:
return ( return (
AccessToken(exp=AccessToken.calculate_exp(now), sub=entity_id), AccessToken(exp=AccessToken.calculate_exp(now), sub=entity_id),
RefreshToken(exp=RefreshToken.calculate_exp(now), sub=entity_id), RefreshToken(exp=RefreshToken.calculate_exp(now), sub=entity_id),