From 65d0a19e41bacf1426ab379cf00ffcddaa6ccf53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20Doma=C5=84ski?= Date: Tue, 7 Apr 2026 21:36:42 +0200 Subject: [PATCH] [preset] implement [repo] simplify fetching for views on most repos --- fooder/alembic/versions/4e8d78ff6e9e_.py | 144 ++++++++++-- fooder/alembic/versions/564e5948f3ed_.py | 71 ++++++ fooder/command/create_entry.py | 5 +- fooder/command/load_preset_as_meal.py | 22 ++ fooder/command/save_meal_as_preset.py | 19 ++ fooder/controller/entry.py | 6 +- fooder/controller/preset.py | 17 ++ fooder/controller/preset_entry.py | 21 ++ fooder/domain/entry.py | 14 +- fooder/domain/meal.py | 7 +- fooder/domain/preset.py | 7 +- fooder/domain/preset_entry.py | 14 +- fooder/domain/user_product_usage.py | 10 +- fooder/model/preset.py | 28 +++ fooder/model/preset_entry.py | 17 ++ fooder/model/token.py | 4 + fooder/repository/entry.py | 14 +- fooder/repository/meal.py | 10 +- fooder/repository/preset.py | 17 ++ fooder/repository/preset_entry.py | 15 ++ fooder/repository/repository.py | 6 + fooder/router.py | 2 + fooder/test/fixtures/diary.py | 2 +- fooder/test/view/test_preset.py | 285 +++++++++++++++++++++++ fooder/test/view/test_product.py | 26 +++ fooder/test/view/test_token.py | 10 +- fooder/view/entry.py | 13 +- fooder/view/meal.py | 35 ++- fooder/view/preset.py | 81 +++++++ fooder/view/token.py | 6 +- 30 files changed, 862 insertions(+), 66 deletions(-) create mode 100644 fooder/alembic/versions/564e5948f3ed_.py create mode 100644 fooder/command/load_preset_as_meal.py create mode 100644 fooder/command/save_meal_as_preset.py create mode 100644 fooder/controller/preset.py create mode 100644 fooder/controller/preset_entry.py create mode 100644 fooder/model/preset.py create mode 100644 fooder/model/preset_entry.py create mode 100644 fooder/repository/preset.py create mode 100644 fooder/repository/preset_entry.py create mode 100644 fooder/test/view/test_preset.py create mode 100644 fooder/view/preset.py diff --git a/fooder/alembic/versions/4e8d78ff6e9e_.py b/fooder/alembic/versions/4e8d78ff6e9e_.py index 2fda5d2..a083cff 100644 --- a/fooder/alembic/versions/4e8d78ff6e9e_.py +++ b/fooder/alembic/versions/4e8d78ff6e9e_.py @@ -20,6 +20,10 @@ depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: """Upgrade schema.""" + _now = sa.text("CURRENT_TIMESTAMP") + _zero_int = sa.text("0") + _zero_float = sa.text("0") + # ### commands auto generated by Alembic - please adjust! ### op.create_table( "userproductusage", @@ -60,30 +64,111 @@ def upgrade() -> None: sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("user_id"), ) + # Create default settings for all existing users + op.execute( + sa.text( + "INSERT INTO usersettings" + " (user_id, protein_goal, carb_goal, fat_goal, fiber_goal, calories_goal," + " version, created_at, last_changed)" + " SELECT id, 0, 0, 0, 0, 0, 0, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP" + ' FROM "user"' + ) + ) + op.drop_table("refreshtoken") - op.add_column("diary", sa.Column("protein_goal", sa.Float(), nullable=False)) - op.add_column("diary", sa.Column("carb_goal", sa.Float(), nullable=False)) - op.add_column("diary", sa.Column("fat_goal", sa.Float(), nullable=False)) - op.add_column("diary", sa.Column("fiber_goal", sa.Float(), nullable=False)) - op.add_column("diary", sa.Column("calories_goal", sa.Float(), nullable=False)) - op.add_column("diary", sa.Column("version", sa.Integer(), nullable=False)) - op.add_column("diary", sa.Column("created_at", sa.DateTime(), nullable=False)) - op.add_column("diary", sa.Column("last_changed", sa.DateTime(), nullable=False)) + op.add_column( + "diary", + sa.Column( + "protein_goal", sa.Float(), nullable=False, server_default=_zero_float + ), + ) + op.add_column( + "diary", + sa.Column("carb_goal", sa.Float(), nullable=False, server_default=_zero_float), + ) + op.add_column( + "diary", + sa.Column("fat_goal", sa.Float(), nullable=False, server_default=_zero_float), + ) + op.add_column( + "diary", + sa.Column("fiber_goal", sa.Float(), nullable=False, server_default=_zero_float), + ) + op.add_column( + "diary", + sa.Column( + "calories_goal", sa.Float(), nullable=False, server_default=_zero_float + ), + ) + op.add_column( + "diary", + sa.Column("version", sa.Integer(), nullable=False, server_default=_zero_int), + ) + op.add_column( + "diary", + sa.Column("created_at", sa.DateTime(), nullable=False, server_default=_now), + ) + op.add_column( + "diary", + sa.Column("last_changed", sa.DateTime(), nullable=False, server_default=_now), + ) op.create_unique_constraint(None, "diary", ["user_id", "date"]) - op.add_column("entry", sa.Column("version", sa.Integer(), nullable=False)) - op.add_column("entry", sa.Column("created_at", sa.DateTime(), nullable=False)) - op.add_column("meal", sa.Column("version", sa.Integer(), nullable=False)) - op.add_column("meal", sa.Column("created_at", sa.DateTime(), nullable=False)) - op.add_column("meal", sa.Column("last_changed", sa.DateTime(), nullable=False)) - op.add_column("preset", sa.Column("version", sa.Integer(), nullable=False)) - op.add_column("preset", sa.Column("created_at", sa.DateTime(), nullable=False)) - op.add_column("preset", sa.Column("last_changed", sa.DateTime(), nullable=False)) - op.add_column("presetentry", sa.Column("version", sa.Integer(), nullable=False)) - op.add_column("presetentry", sa.Column("created_at", sa.DateTime(), nullable=False)) - op.add_column("product", sa.Column("calories", sa.Float(), nullable=False)) - op.add_column("product", sa.Column("version", sa.Integer(), nullable=False)) - op.add_column("product", sa.Column("created_at", sa.DateTime(), nullable=False)) - op.add_column("product", sa.Column("last_changed", sa.DateTime(), nullable=False)) + op.add_column( + "entry", + sa.Column("version", sa.Integer(), nullable=False, server_default=_zero_int), + ) + op.add_column( + "entry", + sa.Column("created_at", sa.DateTime(), nullable=False, server_default=_now), + ) + op.add_column( + "meal", + sa.Column("version", sa.Integer(), nullable=False, server_default=_zero_int), + ) + op.add_column( + "meal", + sa.Column("created_at", sa.DateTime(), nullable=False, server_default=_now), + ) + op.add_column( + "meal", + sa.Column("last_changed", sa.DateTime(), nullable=False, server_default=_now), + ) + op.add_column( + "preset", + sa.Column("version", sa.Integer(), nullable=False, server_default=_zero_int), + ) + op.add_column( + "preset", + sa.Column("created_at", sa.DateTime(), nullable=False, server_default=_now), + ) + op.add_column( + "preset", + sa.Column("last_changed", sa.DateTime(), nullable=False, server_default=_now), + ) + op.add_column( + "presetentry", + sa.Column("version", sa.Integer(), nullable=False, server_default=_zero_int), + ) + op.add_column( + "presetentry", + sa.Column("created_at", sa.DateTime(), nullable=False, server_default=_now), + ) + op.add_column( + "product", + sa.Column("calories", sa.Float(), nullable=False, server_default=_zero_float), + ) + op.add_column( + "product", + sa.Column("version", sa.Integer(), nullable=False, server_default=_zero_int), + ) + op.add_column( + "product", + sa.Column("created_at", sa.DateTime(), nullable=False, server_default=_now), + ) + op.add_column( + "product", + sa.Column("last_changed", sa.DateTime(), nullable=False, server_default=_now), + ) op.add_column("product", sa.Column("deleted_at", sa.DateTime(), nullable=True)) op.create_index( "ix_product_barcode", @@ -95,9 +180,18 @@ def upgrade() -> None: ) op.drop_column("product", "hard_coded_calories") op.drop_column("product", "usage_count_cached") - op.add_column("user", sa.Column("version", sa.Integer(), nullable=False)) - op.add_column("user", sa.Column("created_at", sa.DateTime(), nullable=False)) - op.add_column("user", sa.Column("last_changed", sa.DateTime(), nullable=False)) + op.add_column( + "user", + sa.Column("version", sa.Integer(), nullable=False, server_default=_zero_int), + ) + op.add_column( + "user", + sa.Column("created_at", sa.DateTime(), nullable=False, server_default=_now), + ) + op.add_column( + "user", + sa.Column("last_changed", sa.DateTime(), nullable=False, server_default=_now), + ) op.add_column("user", sa.Column("deleted_at", sa.DateTime(), nullable=True)) # ### end Alembic commands ### diff --git a/fooder/alembic/versions/564e5948f3ed_.py b/fooder/alembic/versions/564e5948f3ed_.py new file mode 100644 index 0000000..31bed52 --- /dev/null +++ b/fooder/alembic/versions/564e5948f3ed_.py @@ -0,0 +1,71 @@ +""" + +Revision ID: 564e5948f3ed +Revises: 4e8d78ff6e9e +Create Date: 2026-04-07 19:31:01.616100 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "564e5948f3ed" +down_revision: Union[str, Sequence[str], None] = "4e8d78ff6e9e" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_index(op.f("ix_entry_meal_id"), "entry", ["meal_id"], unique=False) + op.create_index(op.f("ix_entry_product_id"), "entry", ["product_id"], unique=False) + op.drop_column("entry", "processed") + op.create_index(op.f("ix_meal_diary_id"), "meal", ["diary_id"], unique=False) + op.create_index(op.f("ix_preset_user_id"), "preset", ["user_id"], unique=False) + op.create_index( + op.f("ix_presetentry_preset_id"), + "presetentry", + ["preset_id"], + unique=False, + ) + op.create_index( + op.f("ix_presetentry_product_id"), + "presetentry", + ["product_id"], + unique=False, + ) + op.create_index( + "ix_userproductusage_product_user", + "userproductusage", + ["product_id", "user_id"], + unique=False, + ) + op.create_index( + op.f("ix_userproductusage_user_id"), + "userproductusage", + ["user_id"], + unique=False, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_userproductusage_user_id"), table_name="userproductusage") + op.drop_index("ix_userproductusage_product_user", table_name="userproductusage") + op.drop_index(op.f("ix_presetentry_product_id"), table_name="presetentry") + op.drop_index(op.f("ix_presetentry_preset_id"), table_name="presetentry") + op.drop_index(op.f("ix_preset_user_id"), table_name="preset") + op.drop_index(op.f("ix_meal_diary_id"), table_name="meal") + op.add_column( + "entry", + sa.Column("processed", sa.BOOLEAN(), autoincrement=False, nullable=False), + ) + op.drop_index(op.f("ix_entry_product_id"), table_name="entry") + op.drop_index(op.f("ix_entry_meal_id"), table_name="entry") + # ### end Alembic commands ### diff --git a/fooder/command/create_entry.py b/fooder/command/create_entry.py index d8f18a6..e3e9307 100644 --- a/fooder/command/create_entry.py +++ b/fooder/command/create_entry.py @@ -1,11 +1,12 @@ from fooder.context import Context from fooder.controller.entry import EntryController from fooder.domain import Entry +from fooder.domain.meal import Meal from fooder.model.entry import EntryCreateModel -async def create_entry(ctx: Context, meal_id: int, data: EntryCreateModel) -> Entry: - ctrl = await EntryController.create(ctx, meal_id=meal_id, data=data) +async def create_entry(ctx: Context, meal: Meal, data: EntryCreateModel) -> Entry: + ctrl = await EntryController.create(ctx, meal=meal, data=data) await ctx.repo.user_product_usage.increment( user_id=ctx.user.id, product_id=data.product_id, diff --git a/fooder/command/load_preset_as_meal.py b/fooder/command/load_preset_as_meal.py new file mode 100644 index 0000000..77329f0 --- /dev/null +++ b/fooder/command/load_preset_as_meal.py @@ -0,0 +1,22 @@ +from fooder.context import Context +from fooder.controller.entry import EntryController +from fooder.controller.meal import MealController +from fooder.domain import Meal, Preset +from fooder.model.entry import EntryCreateModel +from fooder.model.meal import MealCreateModel + + +async def load_preset_as_meal( + ctx: Context, preset: Preset, diary_id: int, name: str | None = None +) -> Meal: + meal_ctrl = await MealController.create( + ctx, diary_id=diary_id, data=MealCreateModel(name=name or preset.name) + ) + for entry in preset.entries: + await EntryController.create( + ctx, + meal=meal_ctrl.obj, + data=EntryCreateModel(grams=entry.grams, product_id=entry.product_id), + ) + await ctx.repo.meal.session.refresh(meal_ctrl.obj) + return meal_ctrl.obj diff --git a/fooder/command/save_meal_as_preset.py b/fooder/command/save_meal_as_preset.py new file mode 100644 index 0000000..c52ccee --- /dev/null +++ b/fooder/command/save_meal_as_preset.py @@ -0,0 +1,19 @@ +from fooder.context import Context +from fooder.controller.preset import PresetController +from fooder.domain import Meal, Preset, PresetEntry + + +async def save_meal_as_preset( + ctx: Context, meal: Meal, name: str | None = None +) -> Preset: + await ctx.repo.meal.session.refresh(meal) + ctrl = await PresetController.create(ctx, name=name or meal.name) + for entry in meal.entries: + preset_entry = PresetEntry( + grams=entry.grams, + product_id=entry.product_id, + preset_id=ctrl.obj.id, + ) + await ctx.repo.preset_entry.create(preset_entry) + await ctx.repo.preset.session.refresh(ctrl.obj) + return ctrl.obj diff --git a/fooder/controller/entry.py b/fooder/controller/entry.py index a01594d..1688b00 100644 --- a/fooder/controller/entry.py +++ b/fooder/controller/entry.py @@ -1,15 +1,17 @@ from fooder.context import Context from fooder.controller.base import ModelController from fooder.domain import Entry +from fooder.domain.meal import Meal from fooder.model.entry import EntryCreateModel, EntryUpdateModel class EntryController(ModelController[Entry]): @classmethod async def create( - cls, ctx: Context, meal_id: int, data: EntryCreateModel + cls, ctx: Context, meal: Meal, data: EntryCreateModel ) -> "EntryController": - obj = Entry(grams=data.grams, product_id=data.product_id, meal_id=meal_id) + obj = Entry(grams=data.grams, product_id=data.product_id) + obj.meal = meal await ctx.repo.entry.create(obj) return cls(ctx, obj) diff --git a/fooder/controller/preset.py b/fooder/controller/preset.py new file mode 100644 index 0000000..e62ad2b --- /dev/null +++ b/fooder/controller/preset.py @@ -0,0 +1,17 @@ +from fooder.context import Context +from fooder.controller.base import ModelController +from fooder.domain import Preset +from fooder.model.preset import PresetUpdateModel + + +class PresetController(ModelController[Preset]): + @classmethod + async def create(cls, ctx: Context, name: str) -> "PresetController": + obj = Preset(name=name, user_id=ctx.user.id) + await ctx.repo.preset.create(obj) + return cls(ctx, obj) + + async def update(self, data: PresetUpdateModel) -> None: + if data.name is not None: + self.obj.name = data.name + await self.ctx.repo.preset.update(self.obj) diff --git a/fooder/controller/preset_entry.py b/fooder/controller/preset_entry.py new file mode 100644 index 0000000..d6a5872 --- /dev/null +++ b/fooder/controller/preset_entry.py @@ -0,0 +1,21 @@ +from fooder.context import Context +from fooder.controller.base import ModelController +from fooder.domain import PresetEntry +from fooder.domain.preset import Preset +from fooder.model.entry import EntryCreateModel, EntryUpdateModel + + +class PresetEntryController(ModelController[PresetEntry]): + @classmethod + async def create( + cls, ctx: Context, preset: Preset, data: EntryCreateModel + ) -> "PresetEntryController": + obj = PresetEntry(grams=data.grams, product_id=data.product_id) + obj.preset = preset + await ctx.repo.preset_entry.create(obj) + return cls(ctx, obj) + + async def update(self, data: EntryUpdateModel) -> None: + if data.grams is not None: + self.obj.grams = data.grams + await self.ctx.repo.preset_entry.update(self.obj) diff --git a/fooder/domain/entry.py b/fooder/domain/entry.py index ae20337..eecc4b8 100644 --- a/fooder/domain/entry.py +++ b/fooder/domain/entry.py @@ -1,14 +1,24 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from sqlalchemy import ForeignKey, Integer from sqlalchemy.orm import Mapped, mapped_column, relationship from fooder.domain.base import Base, CommonMixin, EntryMacrosMixin from fooder.domain.product import Product +if TYPE_CHECKING: + from fooder.domain.meal import Meal + class Entry(Base, CommonMixin, EntryMacrosMixin): """Entry.""" grams: Mapped[float] - product_id: Mapped[int] = mapped_column(Integer, ForeignKey("product.id")) + product_id: Mapped[int] = mapped_column( + Integer, ForeignKey("product.id"), index=True + ) product: Mapped[Product] = relationship(lazy="selectin") - meal_id: Mapped[int] = mapped_column(Integer, ForeignKey("meal.id")) + meal_id: Mapped[int] = mapped_column(Integer, ForeignKey("meal.id"), index=True) + meal: Mapped[Meal] = relationship(back_populates="entries") diff --git a/fooder/domain/meal.py b/fooder/domain/meal.py index 28ecad3..0a00930 100644 --- a/fooder/domain/meal.py +++ b/fooder/domain/meal.py @@ -10,7 +10,10 @@ class Meal(Base, CommonMixin, AggregateMacrosMixin): name: Mapped[str] order: Mapped[int] - diary_id: Mapped[int] = mapped_column(Integer, ForeignKey("diary.id")) + diary_id: Mapped[int] = mapped_column(Integer, ForeignKey("diary.id"), index=True) entries: Mapped[list[Entry]] = relationship( - lazy="selectin", order_by=Entry.last_changed, cascade="all, delete-orphan" + lazy="selectin", + order_by=Entry.last_changed, + cascade="all, delete-orphan", + back_populates="meal", ) diff --git a/fooder/domain/preset.py b/fooder/domain/preset.py index b4731f7..fbeac86 100644 --- a/fooder/domain/preset.py +++ b/fooder/domain/preset.py @@ -9,7 +9,10 @@ class Preset(Base, CommonMixin, AggregateMacrosMixin): """Preset.""" name: Mapped[str] - user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id")) + user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"), index=True) entries: Mapped[list[PresetEntry]] = relationship( - lazy="selectin", order_by=PresetEntry.last_changed + lazy="selectin", + order_by=PresetEntry.last_changed, + cascade="all, delete-orphan", + back_populates="preset", ) diff --git a/fooder/domain/preset_entry.py b/fooder/domain/preset_entry.py index 6d2b71a..827bf10 100644 --- a/fooder/domain/preset_entry.py +++ b/fooder/domain/preset_entry.py @@ -1,14 +1,24 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from sqlalchemy import ForeignKey, Integer from sqlalchemy.orm import Mapped, mapped_column, relationship from fooder.domain.base import Base, CommonMixin, EntryMacrosMixin from fooder.domain.product import Product +if TYPE_CHECKING: + from fooder.domain.preset import Preset + class PresetEntry(Base, CommonMixin, EntryMacrosMixin): """PresetEntry.""" grams: Mapped[float] - product_id: Mapped[int] = mapped_column(Integer, ForeignKey("product.id")) + product_id: Mapped[int] = mapped_column( + Integer, ForeignKey("product.id"), index=True + ) product: Mapped[Product] = relationship(lazy="selectin") - preset_id: Mapped[int] = mapped_column(Integer, ForeignKey("preset.id")) + preset_id: Mapped[int] = mapped_column(Integer, ForeignKey("preset.id"), index=True) + preset: Mapped[Preset] = relationship(back_populates="entries") diff --git a/fooder/domain/user_product_usage.py b/fooder/domain/user_product_usage.py index 07d3a22..2c0c357 100644 --- a/fooder/domain/user_product_usage.py +++ b/fooder/domain/user_product_usage.py @@ -1,4 +1,4 @@ -from sqlalchemy import ForeignKey, Integer, UniqueConstraint +from sqlalchemy import ForeignKey, Index, Integer, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column, relationship from fooder.domain.base import Base, CommonMixin @@ -10,10 +10,14 @@ class UserProductUsage(Base, CommonMixin): """Counts how many processed entries a user has for a product. Used to sort products by usage frequency.""" - __table_args__ = (UniqueConstraint("user_id", "product_id"),) + __table_args__ = ( + UniqueConstraint("user_id", "product_id"), + # Covers outerjoin on (product_id, user_id) in list_for_user + Index("ix_userproductusage_product_user", "product_id", "user_id"), + ) product_id: Mapped[int] = mapped_column(Integer, ForeignKey("product.id")) product: Mapped[Product] = relationship(lazy="selectin") - user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id")) + user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"), index=True) user: Mapped[User] = relationship(lazy="selectin") count: Mapped[int] diff --git a/fooder/model/preset.py b/fooder/model/preset.py new file mode 100644 index 0000000..d74d99a --- /dev/null +++ b/fooder/model/preset.py @@ -0,0 +1,28 @@ +from pydantic import BaseModel + +from fooder.model.base import ObjModelMixin +from fooder.model.preset_entry import PresetEntryModel + + +class PresetModel(ObjModelMixin, BaseModel): + name: str + user_id: int + protein: float + carb: float + fat: float + fiber: float + calories: float + entries: list[PresetEntryModel] + + +class SaveAsPresetModel(BaseModel): + name: str | None = None + + +class PresetUpdateModel(BaseModel): + name: str | None = None + + +class LoadPresetAsMealModel(BaseModel): + preset_id: int + name: str | None = None diff --git a/fooder/model/preset_entry.py b/fooder/model/preset_entry.py new file mode 100644 index 0000000..1432e9a --- /dev/null +++ b/fooder/model/preset_entry.py @@ -0,0 +1,17 @@ +from pydantic import BaseModel + +from fooder.model.base import ObjModelMixin +from fooder.model.entry import Grams +from fooder.model.product import ProductModel + + +class PresetEntryModel(ObjModelMixin, BaseModel): + grams: Grams + product_id: int + preset_id: int + product: ProductModel + protein: float + carb: float + fat: float + fiber: float + calories: float diff --git a/fooder/model/token.py b/fooder/model/token.py index 438f788..368ab79 100644 --- a/fooder/model/token.py +++ b/fooder/model/token.py @@ -5,3 +5,7 @@ class TokenResponse(BaseModel): access_token: str refresh_token: str token_type: str = "bearer" + + +class RefreshTokenRequest(BaseModel): + refresh_token: str diff --git a/fooder/repository/entry.py b/fooder/repository/entry.py index d44297d..2815bba 100644 --- a/fooder/repository/entry.py +++ b/fooder/repository/entry.py @@ -1,7 +1,17 @@ +from sqlalchemy import select + from fooder.domain import Entry +from fooder.domain.diary import Diary +from fooder.domain.meal import Meal from fooder.repository.base import RepositoryBase class EntryRepository(RepositoryBase[Entry]): - async def get_by_id_and_meal(self, entry_id: int, meal_id: int) -> Entry: - return await self._get(Entry.id == entry_id, Entry.meal_id == meal_id) + async def get_by_id_and_user(self, entry_id: int, user_id: int) -> Entry: + stmt = ( + select(Entry) + .join(Meal, Entry.meal_id == Meal.id) + .join(Diary, Meal.diary_id == Diary.id) + .where(Entry.id == entry_id, Diary.user_id == user_id) + ) + return await self._get(stmt=stmt) diff --git a/fooder/repository/meal.py b/fooder/repository/meal.py index b9df096..257cb97 100644 --- a/fooder/repository/meal.py +++ b/fooder/repository/meal.py @@ -1,12 +1,18 @@ from sqlalchemy import select, func from fooder.domain import Meal +from fooder.domain.diary import Diary from fooder.repository.base import RepositoryBase class MealRepository(RepositoryBase[Meal]): - async def get_by_id_and_diary(self, meal_id: int, diary_id: int) -> Meal: - return await self._get(Meal.id == meal_id, Meal.diary_id == diary_id) + async def get_by_id_and_user(self, meal_id: int, user_id: int) -> Meal: + stmt = ( + select(Meal) + .join(Diary, Meal.diary_id == Diary.id) + .where(Meal.id == meal_id, Diary.user_id == user_id) + ) + return await self._get(stmt=stmt) async def next_order(self, diary_id: int) -> int: stmt = select(func.max(Meal.order)).where(Meal.diary_id == diary_id) diff --git a/fooder/repository/preset.py b/fooder/repository/preset.py new file mode 100644 index 0000000..d8e9f3d --- /dev/null +++ b/fooder/repository/preset.py @@ -0,0 +1,17 @@ +from typing import Sequence + +from fooder.domain import Preset +from fooder.repository.base import DEFAULT_LIMIT, RepositoryBase + + +class PresetRepository(RepositoryBase[Preset]): + async def get_by_id_and_user(self, preset_id: int, user_id: int) -> Preset: + return await self._get(Preset.id == preset_id, Preset.user_id == user_id) + + async def list_by_user( + self, + user_id: int, + offset: int = 0, + limit: int | None = DEFAULT_LIMIT, + ) -> Sequence[Preset]: + return await self._list(Preset.user_id == user_id, offset=offset, limit=limit) diff --git a/fooder/repository/preset_entry.py b/fooder/repository/preset_entry.py new file mode 100644 index 0000000..9ef49dc --- /dev/null +++ b/fooder/repository/preset_entry.py @@ -0,0 +1,15 @@ +from sqlalchemy import select + +from fooder.domain import PresetEntry +from fooder.domain.preset import Preset +from fooder.repository.base import RepositoryBase + + +class PresetEntryRepository(RepositoryBase[PresetEntry]): + async def get_by_id_and_user(self, entry_id: int, user_id: int) -> PresetEntry: + stmt = ( + select(PresetEntry) + .join(Preset, PresetEntry.preset_id == Preset.id) + .where(PresetEntry.id == entry_id, Preset.user_id == user_id) + ) + return await self._get(stmt=stmt) diff --git a/fooder/repository/repository.py b/fooder/repository/repository.py index cc6443c..bf59a9b 100644 --- a/fooder/repository/repository.py +++ b/fooder/repository/repository.py @@ -10,6 +10,8 @@ from fooder.repository.user_settings import UserSettingsRepository from fooder.repository.diary import DiaryRepository from fooder.repository.meal import MealRepository from fooder.repository.entry import EntryRepository +from fooder.repository.preset import PresetRepository +from fooder.repository.preset_entry import PresetEntryRepository from fooder.domain import ( User, Product, @@ -18,6 +20,8 @@ from fooder.domain import ( Diary, Meal, Entry, + Preset, + PresetEntry, ) from fooder.exc import Conflict @@ -32,6 +36,8 @@ class Repository: self.diary = DiaryRepository(Diary, session) self.meal = MealRepository(Meal, session) self.entry = EntryRepository(Entry, session) + self.preset = PresetRepository(Preset, session) + self.preset_entry = PresetEntryRepository(PresetEntry, session) async def commit(self) -> None: try: diff --git a/fooder/router.py b/fooder/router.py index 1e12fd8..722fb5c 100644 --- a/fooder/router.py +++ b/fooder/router.py @@ -6,6 +6,7 @@ from fooder.view.user_settings import router as user_settings_router from fooder.view.diary import router as diary_router from fooder.view.meal import router as meal_router from fooder.view.entry import router as entry_router +from fooder.view.preset import router as preset_router router = APIRouter(prefix="/api") router.include_router(token_router, prefix="/token", tags=["token"]) @@ -18,3 +19,4 @@ router.include_router(meal_router, prefix="/diary/{date}/meal", tags=["meal"]) router.include_router( entry_router, prefix="/diary/{date}/meal/{meal_id}/entry", tags=["entry"] ) +router.include_router(preset_router, prefix="/preset", tags=["preset"]) diff --git a/fooder/test/fixtures/diary.py b/fooder/test/fixtures/diary.py index d9a9271..08b6089 100644 --- a/fooder/test/fixtures/diary.py +++ b/fooder/test/fixtures/diary.py @@ -42,7 +42,7 @@ async def entry(auth_ctx, meal, product): async with auth_ctx.repo.transaction(): ctrl = await EntryController.create( auth_ctx, - meal_id=meal.id, + meal=meal, data=EntryCreateModel(grams=100.0, product_id=product.id), ) return ctrl.obj diff --git a/fooder/test/view/test_preset.py b/fooder/test/view/test_preset.py new file mode 100644 index 0000000..c4f14be --- /dev/null +++ b/fooder/test/view/test_preset.py @@ -0,0 +1,285 @@ +import datetime + +import pytest_asyncio + +from fooder.command.save_meal_as_preset import save_meal_as_preset + +TODAY = datetime.date.today().isoformat() + + +@pytest_asyncio.fixture +async def preset(auth_ctx, meal, entry): + async with auth_ctx.repo.transaction(): + obj = await save_meal_as_preset(auth_ctx, meal) + await auth_ctx.repo.preset.session.refresh(obj) + return obj + + +@pytest_asyncio.fixture +async def preset_entry(preset): + return preset.entries[0] + + +# --- save meal as preset --- + + +async def test_save_meal_as_preset_returns_201(auth_client, meal, entry): + response = await auth_client.post( + f"/api/diary/{TODAY}/meal/{meal.id}/preset", json={} + ) + assert response.status_code == 201 + + +async def test_save_meal_as_preset_returns_preset_name(auth_client, meal, entry): + response = await auth_client.post( + f"/api/diary/{TODAY}/meal/{meal.id}/preset", json={} + ) + assert response.json()["name"] == meal.name + + +async def test_save_meal_as_preset_overrides_name(auth_client, meal, entry): + response = await auth_client.post( + f"/api/diary/{TODAY}/meal/{meal.id}/preset", json={"name": "My Custom Preset"} + ) + assert response.json()["name"] == "My Custom Preset" + + +async def test_save_meal_as_preset_copies_entries(auth_client, meal, entry): + response = await auth_client.post( + f"/api/diary/{TODAY}/meal/{meal.id}/preset", json={} + ) + body = response.json() + assert len(body["entries"]) == 1 + assert body["entries"][0]["grams"] == entry.grams + assert body["entries"][0]["product_id"] == entry.product_id + + +async def test_save_empty_meal_as_preset_has_no_entries(auth_client, meal): + response = await auth_client.post( + f"/api/diary/{TODAY}/meal/{meal.id}/preset", json={} + ) + assert response.status_code == 201 + assert response.json()["entries"] == [] + + +async def test_save_meal_as_preset_meal_not_found_returns_404(auth_client, diary): + response = await auth_client.post(f"/api/diary/{TODAY}/meal/99999/preset", json={}) + assert response.status_code == 404 + + +async def test_save_meal_as_preset_without_auth_returns_401(client, meal): + response = await client.post(f"/api/diary/{TODAY}/meal/{meal.id}/preset", json={}) + assert response.status_code == 401 + + +# --- list presets --- + + +async def test_list_presets_returns_200(auth_client): + response = await auth_client.get("/api/preset") + assert response.status_code == 200 + assert isinstance(response.json(), list) + + +async def test_list_presets_contains_saved(auth_client, preset): + response = await auth_client.get("/api/preset") + ids = [p["id"] for p in response.json()] + assert preset.id in ids + + +async def test_list_presets_without_auth_returns_401(client): + response = await client.get("/api/preset") + assert response.status_code == 401 + + +# --- update preset --- + + +async def test_update_preset_returns_200(auth_client, preset): + response = await auth_client.patch( + f"/api/preset/{preset.id}", json={"name": "Renamed"} + ) + assert response.status_code == 200 + + +async def test_update_preset_changes_name(auth_client, preset): + response = await auth_client.patch( + f"/api/preset/{preset.id}", json={"name": "Renamed"} + ) + assert response.json()["name"] == "Renamed" + + +async def test_update_preset_not_found_returns_404(auth_client): + response = await auth_client.patch("/api/preset/99999", json={"name": "Ghost"}) + assert response.status_code == 404 + + +async def test_update_preset_without_auth_returns_401(client, preset): + response = await client.patch(f"/api/preset/{preset.id}", json={"name": "x"}) + assert response.status_code == 401 + + +# --- delete preset --- + + +async def test_delete_preset_returns_204(auth_client, preset): + response = await auth_client.delete(f"/api/preset/{preset.id}") + assert response.status_code == 204 + + +async def test_delete_preset_not_found_returns_404(auth_client): + response = await auth_client.delete("/api/preset/99999") + assert response.status_code == 404 + + +async def test_delete_preset_without_auth_returns_401(client, preset): + response = await client.delete(f"/api/preset/{preset.id}") + assert response.status_code == 401 + + +# --- create preset entry --- + + +async def test_create_preset_entry_returns_201(auth_client, preset, product): + response = await auth_client.post( + f"/api/preset/{preset.id}/entry", + json={"grams": 200.0, "product_id": product.id}, + ) + assert response.status_code == 201 + + +async def test_create_preset_entry_returns_correct_grams(auth_client, preset, product): + response = await auth_client.post( + f"/api/preset/{preset.id}/entry", + json={"grams": 200.0, "product_id": product.id}, + ) + assert response.json()["grams"] == 200.0 + + +async def test_create_preset_entry_preset_not_found_returns_404(auth_client, product): + response = await auth_client.post( + "/api/preset/99999/entry", json={"grams": 100.0, "product_id": product.id} + ) + assert response.status_code == 404 + + +async def test_create_preset_entry_without_auth_returns_401(client, preset, product): + response = await client.post( + f"/api/preset/{preset.id}/entry", + json={"grams": 100.0, "product_id": product.id}, + ) + assert response.status_code == 401 + + +# --- update preset entry --- + + +async def test_update_preset_entry_returns_200(auth_client, preset, preset_entry): + response = await auth_client.patch( + f"/api/preset/{preset.id}/entry/{preset_entry.id}", + json={"grams": 250.0}, + ) + assert response.status_code == 200 + + +async def test_update_preset_entry_changes_grams(auth_client, preset, preset_entry): + response = await auth_client.patch( + f"/api/preset/{preset.id}/entry/{preset_entry.id}", + json={"grams": 250.0}, + ) + assert response.json()["grams"] == 250.0 + + +async def test_update_preset_entry_not_found_returns_404(auth_client, preset): + response = await auth_client.patch( + f"/api/preset/{preset.id}/entry/99999", json={"grams": 100.0} + ) + assert response.status_code == 404 + + +async def test_update_preset_entry_without_auth_returns_401( + client, preset, preset_entry +): + response = await client.patch( + f"/api/preset/{preset.id}/entry/{preset_entry.id}", + json={"grams": 100.0}, + ) + assert response.status_code == 401 + + +# --- delete preset entry --- + + +async def test_delete_preset_entry_returns_204(auth_client, preset, preset_entry): + response = await auth_client.delete( + f"/api/preset/{preset.id}/entry/{preset_entry.id}" + ) + assert response.status_code == 204 + + +async def test_delete_preset_entry_not_found_returns_404(auth_client, preset): + response = await auth_client.delete(f"/api/preset/{preset.id}/entry/99999") + assert response.status_code == 404 + + +async def test_delete_preset_entry_without_auth_returns_401( + client, preset, preset_entry +): + response = await client.delete(f"/api/preset/{preset.id}/entry/{preset_entry.id}") + assert response.status_code == 401 + + +# --- load preset as meal --- + + +async def test_load_preset_as_meal_returns_201(auth_client, diary, preset): + response = await auth_client.post( + f"/api/diary/{TODAY}/meal/from_preset", json={"preset_id": preset.id} + ) + assert response.status_code == 201 + + +async def test_load_preset_as_meal_uses_preset_name(auth_client, diary, preset): + response = await auth_client.post( + f"/api/diary/{TODAY}/meal/from_preset", json={"preset_id": preset.id} + ) + assert response.json()["name"] == preset.name + + +async def test_load_preset_as_meal_overrides_name(auth_client, diary, preset): + response = await auth_client.post( + f"/api/diary/{TODAY}/meal/from_preset", + json={"preset_id": preset.id, "name": "Custom Name"}, + ) + assert response.json()["name"] == "Custom Name" + + +async def test_load_preset_as_meal_copies_entries(auth_client, diary, preset): + response = await auth_client.post( + f"/api/diary/{TODAY}/meal/from_preset", json={"preset_id": preset.id} + ) + body = response.json() + assert len(body["entries"]) == len(preset.entries) + assert body["entries"][0]["grams"] == preset.entries[0].grams + assert body["entries"][0]["product_id"] == preset.entries[0].product_id + + +async def test_load_preset_as_meal_diary_not_found_returns_404(auth_client, preset): + response = await auth_client.post( + "/api/diary/2000-01-01/meal/from_preset", json={"preset_id": preset.id} + ) + assert response.status_code == 404 + + +async def test_load_preset_as_meal_preset_not_found_returns_404(auth_client, diary): + response = await auth_client.post( + f"/api/diary/{TODAY}/meal/from_preset", json={"preset_id": 99999} + ) + assert response.status_code == 404 + + +async def test_load_preset_as_meal_without_auth_returns_401(client, preset): + response = await client.post( + f"/api/diary/{TODAY}/meal/from_preset", json={"preset_id": preset.id} + ) + assert response.status_code == 401 diff --git a/fooder/test/view/test_product.py b/fooder/test/view/test_product.py index 6672a65..4b13e76 100644 --- a/fooder/test/view/test_product.py +++ b/fooder/test/view/test_product.py @@ -1,4 +1,30 @@ import pytest +import pytest_asyncio + +from fooder.controller.product import ProductController +from fooder.model.product import ProductCreateModel + + +@pytest_asyncio.fixture +async def second_product(ctx): + data = ProductCreateModel( + name="Broccoli", protein=2.8, carb=7.0, fat=0.4, fiber=2.6 + ) + async with ctx.repo.transaction(): + ctrl = await ProductController.create(ctx, data) + return ctrl.obj + + +async def test_list_products_orders_by_usage( + auth_client, auth_ctx, product, second_product +): + async with auth_ctx.repo.transaction(): + await auth_ctx.repo.user_product_usage.increment( + user_id=auth_ctx.user.id, product_id=second_product.id, count=5 + ) + response = await auth_client.get("/api/product") + ids = [p["id"] for p in response.json()] + assert ids.index(second_product.id) < ids.index(product.id) async def test_update_product_returns_200(auth_client, product): diff --git a/fooder/test/view/test_token.py b/fooder/test/view/test_token.py index a906953..2bb31db 100644 --- a/fooder/test/view/test_token.py +++ b/fooder/test/view/test_token.py @@ -55,7 +55,7 @@ async def test_refresh_token_returns_new_tokens(client, user, user_password): refresh_token = response.json()["refresh_token"] response = await client.post( - "/api/token/refresh", params={"refresh_token": refresh_token} + "/api/token/refresh", json={"refresh_token": refresh_token} ) assert response.status_code == 200 body = response.json() @@ -72,7 +72,7 @@ async def test_refresh_token_access_token_is_valid(client, user, user_password): refresh_token = response.json()["refresh_token"] response = await client.post( - "/api/token/refresh", params={"refresh_token": refresh_token} + "/api/token/refresh", json={"refresh_token": refresh_token} ) token = AccessToken.decode(response.json()["access_token"]) assert token.sub == user.id @@ -86,7 +86,7 @@ async def test_refresh_token_refresh_token_is_valid(client, user, user_password) refresh_token = response.json()["refresh_token"] response = await client.post( - "/api/token/refresh", params={"refresh_token": refresh_token} + "/api/token/refresh", json={"refresh_token": refresh_token} ) token = RefreshToken.decode(response.json()["refresh_token"]) assert token.sub == user.id @@ -94,7 +94,7 @@ async def test_refresh_token_refresh_token_is_valid(client, user, user_password) async def test_refresh_token_invalid_returns_401(client): response = await client.post( - "/api/token/refresh", params={"refresh_token": "bad-token"} + "/api/token/refresh", json={"refresh_token": "bad-token"} ) assert response.status_code == 401 @@ -109,6 +109,6 @@ async def test_refresh_token_access_token_as_refresh_returns_401( access_token = response.json()["access_token"] response = await client.post( - "/api/token/refresh", params={"refresh_token": access_token} + "/api/token/refresh", json={"refresh_token": access_token} ) assert response.status_code == 401 diff --git a/fooder/view/entry.py b/fooder/view/entry.py index 8e0388a..472148d 100644 --- a/fooder/view/entry.py +++ b/fooder/view/entry.py @@ -20,9 +20,8 @@ async def create_entry_route( ctx: Context = Depends(_auth_ctx), ): async with ctx.repo.transaction(): - diary = await ctx.repo.diary.get_by_user_and_date(ctx.user.id, date) - await ctx.repo.meal.get_by_id_and_diary(meal_id, diary.id) - entry = await create_entry(ctx, meal_id=meal_id, data=data) + meal = await ctx.repo.meal.get_by_id_and_user(meal_id, ctx.user.id) + entry = await create_entry(ctx, meal=meal, data=data) return entry @@ -35,9 +34,7 @@ async def update_entry( ctx: Context = Depends(_auth_ctx), ): async with ctx.repo.transaction(): - diary = await ctx.repo.diary.get_by_user_and_date(ctx.user.id, date) - await ctx.repo.meal.get_by_id_and_diary(meal_id, diary.id) - entry = await ctx.repo.entry.get_by_id_and_meal(entry_id, meal_id) + entry = await ctx.repo.entry.get_by_id_and_user(entry_id, ctx.user.id) await EntryController(ctx, entry).update(data) return entry @@ -50,7 +47,5 @@ async def delete_entry( ctx: Context = Depends(_auth_ctx), ): async with ctx.repo.transaction(): - diary = await ctx.repo.diary.get_by_user_and_date(ctx.user.id, date) - await ctx.repo.meal.get_by_id_and_diary(meal_id, diary.id) - entry = await ctx.repo.entry.get_by_id_and_meal(entry_id, meal_id) + entry = await ctx.repo.entry.get_by_id_and_user(entry_id, ctx.user.id) await ctx.repo.entry.delete(entry) diff --git a/fooder/view/meal.py b/fooder/view/meal.py index e768701..c5f5314 100644 --- a/fooder/view/meal.py +++ b/fooder/view/meal.py @@ -2,9 +2,12 @@ import datetime from fastapi import APIRouter, Depends +from fooder.command.load_preset_as_meal import load_preset_as_meal +from fooder.command.save_meal_as_preset import save_meal_as_preset from fooder.context import AuthContextDependency, Context from fooder.controller.meal import MealController from fooder.model.meal import MealCreateModel, MealModel, MealUpdateModel +from fooder.model.preset import LoadPresetAsMealModel, PresetModel, SaveAsPresetModel router = APIRouter(tags=["meal"]) @@ -23,6 +26,19 @@ async def create_meal( return ctrl.obj +@router.post("/from_preset", response_model=MealModel, status_code=201) +async def load_preset_as_meal_view( + date: datetime.date, + data: LoadPresetAsMealModel, + ctx: Context = Depends(_auth_ctx), +): + async with ctx.repo.transaction(): + diary = await ctx.repo.diary.get_by_user_and_date(ctx.user.id, date) + preset = await ctx.repo.preset.get_by_id_and_user(data.preset_id, ctx.user.id) + meal = await load_preset_as_meal(ctx, preset, diary_id=diary.id, name=data.name) + return meal + + @router.patch("/{meal_id}", response_model=MealModel) async def update_meal( date: datetime.date, @@ -31,12 +47,24 @@ async def update_meal( ctx: Context = Depends(_auth_ctx), ): async with ctx.repo.transaction(): - diary = await ctx.repo.diary.get_by_user_and_date(ctx.user.id, date) - meal = await ctx.repo.meal.get_by_id_and_diary(meal_id, diary.id) + meal = await ctx.repo.meal.get_by_id_and_user(meal_id, ctx.user.id) await MealController(ctx, meal).update(data) return meal +@router.post("/{meal_id}/preset", response_model=PresetModel, status_code=201) +async def save_meal_as_preset_view( + date: datetime.date, + meal_id: int, + data: SaveAsPresetModel, + ctx: Context = Depends(_auth_ctx), +): + async with ctx.repo.transaction(): + meal = await ctx.repo.meal.get_by_id_and_user(meal_id, ctx.user.id) + preset = await save_meal_as_preset(ctx, meal, name=data.name) + return preset + + @router.delete("/{meal_id}", status_code=204) async def delete_meal( date: datetime.date, @@ -44,6 +72,5 @@ async def delete_meal( ctx: Context = Depends(_auth_ctx), ): async with ctx.repo.transaction(): - diary = await ctx.repo.diary.get_by_user_and_date(ctx.user.id, date) - meal = await ctx.repo.meal.get_by_id_and_diary(meal_id, diary.id) + meal = await ctx.repo.meal.get_by_id_and_user(meal_id, ctx.user.id) await ctx.repo.meal.delete(meal) diff --git a/fooder/view/preset.py b/fooder/view/preset.py new file mode 100644 index 0000000..fc8367f --- /dev/null +++ b/fooder/view/preset.py @@ -0,0 +1,81 @@ +from fastapi import APIRouter, Depends + +from fooder.context import AuthContextDependency, Context +from fooder.controller.preset import PresetController +from fooder.controller.preset_entry import PresetEntryController +from fooder.model.entry import EntryCreateModel, EntryUpdateModel +from fooder.model.preset import PresetModel, PresetUpdateModel +from fooder.model.preset_entry import PresetEntryModel + +router = APIRouter() + +_auth_ctx = AuthContextDependency() + + +@router.get("", response_model=list[PresetModel]) +async def list_presets( + ctx: Context = Depends(_auth_ctx), + limit: int = 10, + offset: int = 0, +): + return await ctx.repo.preset.list_by_user( + user_id=ctx.user.id, limit=limit, offset=offset + ) + + +@router.patch("/{preset_id}", response_model=PresetModel) +async def update_preset( + preset_id: int, + data: PresetUpdateModel, + ctx: Context = Depends(_auth_ctx), +): + async with ctx.repo.transaction(): + preset = await ctx.repo.preset.get_by_id_and_user(preset_id, ctx.user.id) + await PresetController(ctx, preset).update(data) + return preset + + +@router.delete("/{preset_id}", status_code=204) +async def delete_preset( + preset_id: int, + ctx: Context = Depends(_auth_ctx), +): + async with ctx.repo.transaction(): + preset = await ctx.repo.preset.get_by_id_and_user(preset_id, ctx.user.id) + await ctx.repo.preset.delete(preset) + + +@router.post("/{preset_id}/entry", response_model=PresetEntryModel, status_code=201) +async def create_preset_entry( + preset_id: int, + data: EntryCreateModel, + ctx: Context = Depends(_auth_ctx), +): + async with ctx.repo.transaction(): + preset = await ctx.repo.preset.get_by_id_and_user(preset_id, ctx.user.id) + ctrl = await PresetEntryController.create(ctx, preset=preset, data=data) + return ctrl.obj + + +@router.patch("/{preset_id}/entry/{entry_id}", response_model=PresetEntryModel) +async def update_preset_entry( + preset_id: int, + entry_id: int, + data: EntryUpdateModel, + ctx: Context = Depends(_auth_ctx), +): + async with ctx.repo.transaction(): + entry = await ctx.repo.preset_entry.get_by_id_and_user(entry_id, ctx.user.id) + await PresetEntryController(ctx, entry).update(data) + return entry + + +@router.delete("/{preset_id}/entry/{entry_id}", status_code=204) +async def delete_preset_entry( + preset_id: int, + entry_id: int, + ctx: Context = Depends(_auth_ctx), +): + async with ctx.repo.transaction(): + entry = await ctx.repo.preset_entry.get_by_id_and_user(entry_id, ctx.user.id) + await ctx.repo.preset_entry.delete(entry) diff --git a/fooder/view/token.py b/fooder/view/token.py index 7cf740f..7a92d99 100644 --- a/fooder/view/token.py +++ b/fooder/view/token.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends from fastapi.security import OAuth2PasswordRequestForm from datetime import datetime -from fooder.model.token import TokenResponse +from fooder.model.token import TokenResponse, RefreshTokenRequest from fooder.context import ContextDependency, Context from fooder.controller.user import UserController from fooder.utils.jwt import RefreshToken, generate_token_pair @@ -34,9 +34,9 @@ async def token_create( @router.post("/refresh", response_model=TokenResponse) async def token_refresh( - refresh_token: str, + data: RefreshTokenRequest, ctx: Context = Depends(_ctx), ) -> TokenResponse: now = ctx.clock() - token = RefreshToken.decode(refresh_token) + token = RefreshToken.decode(data.refresh_token) return gen_token_response(token.sub, now)