[black]
This commit is contained in:
parent
20437905ba
commit
724c350e99
20 changed files with 170 additions and 120 deletions
|
|
@ -1,6 +1,5 @@
|
|||
from argparse import ArgumentParser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = ArgumentParser()
|
||||
group = parser.add_mutually_exclusive_group()
|
||||
|
|
|
|||
|
|
@ -61,60 +61,30 @@ def upgrade() -> None:
|
|||
sa.UniqueConstraint("user_id"),
|
||||
)
|
||||
op.drop_table("refreshtoken")
|
||||
op.add_column(
|
||||
"diary", sa.Column("protein_goal", sa.Float(), nullable=False)
|
||||
)
|
||||
op.add_column("diary", sa.Column("protein_goal", sa.Float(), nullable=False))
|
||||
op.add_column("diary", sa.Column("carb_goal", sa.Float(), nullable=False))
|
||||
op.add_column("diary", sa.Column("fat_goal", sa.Float(), nullable=False))
|
||||
op.add_column("diary", sa.Column("fiber_goal", sa.Float(), nullable=False))
|
||||
op.add_column(
|
||||
"diary", sa.Column("calories_goal", sa.Float(), nullable=False)
|
||||
)
|
||||
op.add_column("diary", sa.Column("calories_goal", sa.Float(), nullable=False))
|
||||
op.add_column("diary", sa.Column("version", sa.Integer(), nullable=False))
|
||||
op.add_column(
|
||||
"diary", sa.Column("created_at", sa.DateTime(), nullable=False)
|
||||
)
|
||||
op.add_column(
|
||||
"diary", sa.Column("last_changed", sa.DateTime(), nullable=False)
|
||||
)
|
||||
op.add_column("diary", sa.Column("created_at", sa.DateTime(), nullable=False))
|
||||
op.add_column("diary", sa.Column("last_changed", sa.DateTime(), nullable=False))
|
||||
op.create_unique_constraint(None, "diary", ["user_id", "date"])
|
||||
op.add_column("entry", sa.Column("version", sa.Integer(), nullable=False))
|
||||
op.add_column(
|
||||
"entry", sa.Column("created_at", sa.DateTime(), nullable=False)
|
||||
)
|
||||
op.add_column("entry", sa.Column("created_at", sa.DateTime(), nullable=False))
|
||||
op.add_column("meal", sa.Column("version", sa.Integer(), nullable=False))
|
||||
op.add_column(
|
||||
"meal", sa.Column("created_at", sa.DateTime(), nullable=False)
|
||||
)
|
||||
op.add_column(
|
||||
"meal", sa.Column("last_changed", sa.DateTime(), nullable=False)
|
||||
)
|
||||
op.add_column("meal", sa.Column("created_at", sa.DateTime(), nullable=False))
|
||||
op.add_column("meal", sa.Column("last_changed", sa.DateTime(), nullable=False))
|
||||
op.add_column("preset", sa.Column("version", sa.Integer(), nullable=False))
|
||||
op.add_column(
|
||||
"preset", sa.Column("created_at", sa.DateTime(), nullable=False)
|
||||
)
|
||||
op.add_column(
|
||||
"preset", sa.Column("last_changed", sa.DateTime(), nullable=False)
|
||||
)
|
||||
op.add_column(
|
||||
"presetentry", sa.Column("version", sa.Integer(), nullable=False)
|
||||
)
|
||||
op.add_column(
|
||||
"presetentry", sa.Column("created_at", sa.DateTime(), nullable=False)
|
||||
)
|
||||
op.add_column("preset", sa.Column("created_at", sa.DateTime(), nullable=False))
|
||||
op.add_column("preset", sa.Column("last_changed", sa.DateTime(), nullable=False))
|
||||
op.add_column("presetentry", sa.Column("version", sa.Integer(), nullable=False))
|
||||
op.add_column("presetentry", sa.Column("created_at", sa.DateTime(), nullable=False))
|
||||
op.add_column("product", sa.Column("calories", sa.Float(), nullable=False))
|
||||
op.add_column(
|
||||
"product", sa.Column("version", sa.Integer(), nullable=False)
|
||||
)
|
||||
op.add_column(
|
||||
"product", sa.Column("created_at", sa.DateTime(), nullable=False)
|
||||
)
|
||||
op.add_column(
|
||||
"product", sa.Column("last_changed", sa.DateTime(), nullable=False)
|
||||
)
|
||||
op.add_column(
|
||||
"product", sa.Column("deleted_at", sa.DateTime(), nullable=True)
|
||||
)
|
||||
op.add_column("product", sa.Column("version", sa.Integer(), nullable=False))
|
||||
op.add_column("product", sa.Column("created_at", sa.DateTime(), nullable=False))
|
||||
op.add_column("product", sa.Column("last_changed", sa.DateTime(), nullable=False))
|
||||
op.add_column("product", sa.Column("deleted_at", sa.DateTime(), nullable=True))
|
||||
op.create_index(
|
||||
"ix_product_barcode",
|
||||
"product",
|
||||
|
|
@ -126,15 +96,9 @@ def upgrade() -> None:
|
|||
op.drop_column("product", "hard_coded_calories")
|
||||
op.drop_column("product", "usage_count_cached")
|
||||
op.add_column("user", sa.Column("version", sa.Integer(), nullable=False))
|
||||
op.add_column(
|
||||
"user", sa.Column("created_at", sa.DateTime(), nullable=False)
|
||||
)
|
||||
op.add_column(
|
||||
"user", sa.Column("last_changed", sa.DateTime(), nullable=False)
|
||||
)
|
||||
op.add_column(
|
||||
"user", sa.Column("deleted_at", sa.DateTime(), nullable=True)
|
||||
)
|
||||
op.add_column("user", sa.Column("created_at", sa.DateTime(), nullable=False))
|
||||
op.add_column("user", sa.Column("last_changed", sa.DateTime(), nullable=False))
|
||||
op.add_column("user", sa.Column("deleted_at", sa.DateTime(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
|
|
@ -195,9 +159,7 @@ def downgrade() -> None:
|
|||
op.drop_column("diary", "protein_goal")
|
||||
op.create_table(
|
||||
"refreshtoken",
|
||||
sa.Column(
|
||||
"user_id", sa.INTEGER(), autoincrement=False, nullable=False
|
||||
),
|
||||
sa.Column("user_id", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||
sa.Column("token", sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from fooder.context import Context
|
||||
from typing import TypeVar, Generic
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -45,11 +45,15 @@ class ProductController(ModelController[Product]):
|
|||
if data.barcode is not None:
|
||||
self.obj.barcode = data.barcode
|
||||
|
||||
self.obj.calories = data.calories if data.calories is not None else calculate_calories(
|
||||
protein=self.obj.protein,
|
||||
carb=self.obj.carb,
|
||||
fat=self.obj.fat,
|
||||
fiber=self.obj.fiber,
|
||||
self.obj.calories = (
|
||||
data.calories
|
||||
if data.calories is not None
|
||||
else calculate_calories(
|
||||
protein=self.obj.protein,
|
||||
carb=self.obj.carb,
|
||||
fat=self.obj.fat,
|
||||
fiber=self.obj.fiber,
|
||||
)
|
||||
)
|
||||
|
||||
await self.ctx.repo.product.update(self.obj)
|
||||
|
|
@ -63,12 +67,15 @@ class ProductController(ModelController[Product]):
|
|||
except product_finder.ParseError:
|
||||
raise InvalidValue()
|
||||
|
||||
return await cls.create(ctx, ProductCreateModel(
|
||||
name=found.name,
|
||||
calories=found.kcal,
|
||||
fat=found.fat,
|
||||
protein=found.protein,
|
||||
carb=found.carb,
|
||||
fiber=found.fiber,
|
||||
barcode=barcode,
|
||||
))
|
||||
return await cls.create(
|
||||
ctx,
|
||||
ProductCreateModel(
|
||||
name=found.name,
|
||||
calories=found.kcal,
|
||||
fat=found.fat,
|
||||
protein=found.protein,
|
||||
carb=found.carb,
|
||||
fiber=found.fiber,
|
||||
barcode=barcode,
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,9 @@ class DatabaseSessionManager:
|
|||
),
|
||||
)
|
||||
self._sessionmaker = async_sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=self._engine,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=self._engine,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -19,13 +19,17 @@ class CommonMixin:
|
|||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
version: Mapped[int] = mapped_column(default=0)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=utc_now)
|
||||
last_changed: Mapped[datetime] = mapped_column(DateTime, default=utc_now, onupdate=utc_now)
|
||||
last_changed: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=utc_now, onupdate=utc_now
|
||||
)
|
||||
|
||||
__mapper_args__ = {"version_id_col": version}
|
||||
|
||||
|
||||
class SoftDeleteMixin:
|
||||
deleted_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
|
||||
deleted_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime, nullable=True, default=None
|
||||
)
|
||||
|
||||
|
||||
class PasswordMixin:
|
||||
|
|
|
|||
|
|
@ -12,7 +12,9 @@ class Diary(Base, CommonMixin):
|
|||
|
||||
__table_args__ = (UniqueConstraint("user_id", "date"),)
|
||||
|
||||
meals: Mapped[list[Meal]] = relationship(lazy="selectin", order_by=Meal.order.desc())
|
||||
meals: Mapped[list[Meal]] = relationship(
|
||||
lazy="selectin", order_by=Meal.order.desc()
|
||||
)
|
||||
date: Mapped[datetime.date] = mapped_column(Date)
|
||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
|
||||
# snapshot of user settings at diary creation time — intentionally decoupled
|
||||
|
|
|
|||
|
|
@ -11,4 +11,6 @@ class Meal(Base, CommonMixin, AggregateMacrosMixin):
|
|||
name: Mapped[str]
|
||||
order: Mapped[int]
|
||||
diary_id: Mapped[int] = mapped_column(Integer, ForeignKey("diary.id"))
|
||||
entries: Mapped[list[Entry]] = relationship(lazy="selectin", order_by=Entry.last_changed)
|
||||
entries: Mapped[list[Entry]] = relationship(
|
||||
lazy="selectin", order_by=Entry.last_changed
|
||||
)
|
||||
|
|
|
|||
|
|
@ -8,9 +8,13 @@ class Product(Base, CommonMixin, SoftDeleteMixin):
|
|||
"""Product."""
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_product_barcode", "barcode", unique=True,
|
||||
postgresql_where=text("deleted_at IS NULL"),
|
||||
sqlite_where=text("deleted_at IS NULL")),
|
||||
Index(
|
||||
"ix_product_barcode",
|
||||
"barcode",
|
||||
unique=True,
|
||||
postgresql_where=text("deleted_at IS NULL"),
|
||||
sqlite_where=text("deleted_at IS NULL"),
|
||||
),
|
||||
)
|
||||
|
||||
name: Mapped[str]
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ class ObjModelMixin:
|
|||
"""
|
||||
Shared code for ObjModel.
|
||||
"""
|
||||
|
||||
id: int
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
|
|
|||
|
|
@ -23,11 +23,15 @@ class ProductCreateModel(ProductModelBase):
|
|||
|
||||
@property
|
||||
def resolved_calories(self) -> float:
|
||||
return self.calories if self.calories is not None else calculate_calories(
|
||||
protein=self.protein,
|
||||
carb=self.carb,
|
||||
fat=self.fat,
|
||||
fiber=self.fiber,
|
||||
return (
|
||||
self.calories
|
||||
if self.calories is not None
|
||||
else calculate_calories(
|
||||
protein=self.protein,
|
||||
carb=self.carb,
|
||||
fat=self.fat,
|
||||
fiber=self.fiber,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -50,7 +50,12 @@ class RepositoryBase(Generic[T]):
|
|||
|
||||
return obj
|
||||
|
||||
async def _list(self, *expressions: ColumnElement, offset: int = 0, limit: int | None = DEFAULT_LIMIT) -> Sequence[T]:
|
||||
async def _list(
|
||||
self,
|
||||
*expressions: ColumnElement,
|
||||
offset: int = 0,
|
||||
limit: int | None = DEFAULT_LIMIT
|
||||
) -> Sequence[T]:
|
||||
stmt = self._build_select(*expressions)
|
||||
|
||||
if offset:
|
||||
|
|
|
|||
4
fooder/test/fixtures/client.py
vendored
4
fooder/test/fixtures/client.py
vendored
|
|
@ -11,7 +11,9 @@ async def client(db_session):
|
|||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db_session] = override_get_db_session
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c:
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as c:
|
||||
yield c
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
|
|
|||
15
fooder/test/fixtures/product.py
vendored
15
fooder/test/fixtures/product.py
vendored
|
|
@ -20,7 +20,9 @@ def product_payload():
|
|||
|
||||
@pytest_asyncio.fixture
|
||||
async def product(ctx):
|
||||
data = ProductCreateModel(name="Chicken Breast", protein=31.0, carb=0.0, fat=3.6, fiber=0.0)
|
||||
data = ProductCreateModel(
|
||||
name="Chicken Breast", protein=31.0, carb=0.0, fat=3.6, fiber=0.0
|
||||
)
|
||||
async with ctx.repo.transaction():
|
||||
ctrl = await ProductController.create(ctx, data)
|
||||
return ctrl.obj
|
||||
|
|
@ -28,7 +30,14 @@ async def product(ctx):
|
|||
|
||||
@pytest_asyncio.fixture
|
||||
async def product_with_barcode(ctx):
|
||||
data = ProductCreateModel(name="Barcoded Product", protein=10.0, carb=5.0, fat=2.0, fiber=1.0, barcode="1234567890")
|
||||
data = ProductCreateModel(
|
||||
name="Barcoded Product",
|
||||
protein=10.0,
|
||||
carb=5.0,
|
||||
fat=2.0,
|
||||
fiber=1.0,
|
||||
barcode="1234567890",
|
||||
)
|
||||
async with ctx.repo.transaction():
|
||||
ctrl = await ProductController.create(ctx, data)
|
||||
return ctrl.obj
|
||||
|
|
@ -60,5 +69,3 @@ def mock_product_finder_not_found(monkeypatch):
|
|||
raise product_finder.NotFound()
|
||||
|
||||
monkeypatch.setattr(product_finder, "find", fake_find)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import pytest
|
||||
from ..fixtures.db import TestModel
|
||||
from fooder.exc import NotFound
|
||||
|
||||
# ------------------------------------------------------------------ create ---
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from typing import Literal
|
|||
from fooder.exc import Unauthorized
|
||||
from fooder.utils.jwt import AccessToken, RefreshToken, Token
|
||||
|
||||
|
||||
PAST = datetime(2000, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,13 +2,19 @@ import pytest
|
|||
|
||||
|
||||
async def test_update_product_returns_200(auth_client, product):
|
||||
response = await auth_client.patch(f"/api/product/{product.id}", json={"name": "Updated Name"})
|
||||
response = await auth_client.patch(
|
||||
f"/api/product/{product.id}", json={"name": "Updated Name"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "Updated Name"
|
||||
|
||||
|
||||
async def test_update_product_recalculates_calories_when_macros_change(auth_client, product):
|
||||
response = await auth_client.patch(f"/api/product/{product.id}", json={"protein": 10.0})
|
||||
async def test_update_product_recalculates_calories_when_macros_change(
|
||||
auth_client, product
|
||||
):
|
||||
response = await auth_client.patch(
|
||||
f"/api/product/{product.id}", json={"protein": 10.0}
|
||||
)
|
||||
# protein=10, carb=0, fat=3.6, fiber=0 → 10*4 + 3.6*9 = 40 + 32.4 = 72.4
|
||||
assert response.json()["calories"] == pytest.approx(72.4)
|
||||
|
||||
|
|
@ -25,9 +31,15 @@ async def test_update_product_not_found_returns_404(auth_client):
|
|||
assert response.status_code == 404
|
||||
|
||||
|
||||
async def test_update_product_duplicate_barcode_returns_409(auth_client, product, product_payload):
|
||||
await auth_client.post("/api/product", json={**product_payload, "name": "Other", "barcode": "AAA"})
|
||||
response = await auth_client.patch(f"/api/product/{product.id}", json={"barcode": "AAA"})
|
||||
async def test_update_product_duplicate_barcode_returns_409(
|
||||
auth_client, product, product_payload
|
||||
):
|
||||
await auth_client.post(
|
||||
"/api/product", json={**product_payload, "name": "Other", "barcode": "AAA"}
|
||||
)
|
||||
response = await auth_client.patch(
|
||||
f"/api/product/{product.id}", json={"barcode": "AAA"}
|
||||
)
|
||||
assert response.status_code == 409
|
||||
|
||||
|
||||
|
|
@ -59,23 +71,37 @@ async def test_create_product_calculates_calories(auth_client, product_payload):
|
|||
|
||||
|
||||
async def test_create_product_uses_explicit_calories(auth_client, product_payload):
|
||||
response = await auth_client.post("/api/product", json={**product_payload, "calories": 50.0})
|
||||
response = await auth_client.post(
|
||||
"/api/product", json={**product_payload, "calories": 50.0}
|
||||
)
|
||||
assert response.json()["calories"] == 50.0
|
||||
|
||||
|
||||
async def test_create_product_duplicate_barcode_returns_409(auth_client, product_payload):
|
||||
await auth_client.post("/api/product", json={**product_payload, "barcode": "123456"})
|
||||
response = await auth_client.post("/api/product", json={**product_payload, "barcode": "123456"})
|
||||
async def test_create_product_duplicate_barcode_returns_409(
|
||||
auth_client, product_payload
|
||||
):
|
||||
await auth_client.post(
|
||||
"/api/product", json={**product_payload, "barcode": "123456"}
|
||||
)
|
||||
response = await auth_client.post(
|
||||
"/api/product", json={**product_payload, "barcode": "123456"}
|
||||
)
|
||||
assert response.status_code == 409
|
||||
|
||||
|
||||
async def test_create_product_invalid_protein_returns_422(auth_client, product_payload):
|
||||
response = await auth_client.post("/api/product", json={**product_payload, "protein": -1.0})
|
||||
response = await auth_client.post(
|
||||
"/api/product", json={**product_payload, "protein": -1.0}
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
async def test_create_product_protein_over_100_returns_422(auth_client, product_payload):
|
||||
response = await auth_client.post("/api/product", json={**product_payload, "protein": 101.0})
|
||||
async def test_create_product_protein_over_100_returns_422(
|
||||
auth_client, product_payload
|
||||
):
|
||||
response = await auth_client.post(
|
||||
"/api/product", json={**product_payload, "protein": 101.0}
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
|
|
@ -106,13 +132,19 @@ async def test_list_products_without_auth_returns_401(client):
|
|||
assert response.status_code == 401
|
||||
|
||||
|
||||
async def test_get_by_barcode_returns_product_from_db(auth_client, product_with_barcode):
|
||||
response = await auth_client.get(f"/api/product/barcode/{product_with_barcode.barcode}")
|
||||
async def test_get_by_barcode_returns_product_from_db(
|
||||
auth_client, product_with_barcode
|
||||
):
|
||||
response = await auth_client.get(
|
||||
f"/api/product/barcode/{product_with_barcode.barcode}"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["id"] == product_with_barcode.id
|
||||
|
||||
|
||||
async def test_get_by_barcode_imports_when_not_in_db(auth_client, mock_product_finder, external_product):
|
||||
async def test_get_by_barcode_imports_when_not_in_db(
|
||||
auth_client, mock_product_finder, external_product
|
||||
):
|
||||
response = await auth_client.get("/api/product/barcode/9999999999")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
|
|
@ -125,13 +157,17 @@ async def test_get_by_barcode_imports_when_not_in_db(auth_client, mock_product_f
|
|||
assert body["barcode"] == "9999999999"
|
||||
|
||||
|
||||
async def test_get_by_barcode_persists_imported_product(auth_client, mock_product_finder):
|
||||
async def test_get_by_barcode_persists_imported_product(
|
||||
auth_client, mock_product_finder
|
||||
):
|
||||
await auth_client.get("/api/product/barcode/8888888888")
|
||||
response = await auth_client.get("/api/product/barcode/8888888888")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
async def test_get_by_barcode_not_found_returns_404(auth_client, mock_product_finder_not_found):
|
||||
async def test_get_by_barcode_not_found_returns_404(
|
||||
auth_client, mock_product_finder_not_found
|
||||
):
|
||||
response = await auth_client.get("/api/product/barcode/0000000000")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
|
|
|||
|
|
@ -54,7 +54,9 @@ async def test_refresh_token_returns_new_tokens(client, user, user_password):
|
|||
)
|
||||
refresh_token = response.json()["refresh_token"]
|
||||
|
||||
response = await client.post("/api/token/refresh", params={"refresh_token": refresh_token})
|
||||
response = await client.post(
|
||||
"/api/token/refresh", params={"refresh_token": refresh_token}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert "access_token" in body
|
||||
|
|
@ -69,7 +71,9 @@ async def test_refresh_token_access_token_is_valid(client, user, user_password):
|
|||
)
|
||||
refresh_token = response.json()["refresh_token"]
|
||||
|
||||
response = await client.post("/api/token/refresh", params={"refresh_token": refresh_token})
|
||||
response = await client.post(
|
||||
"/api/token/refresh", params={"refresh_token": refresh_token}
|
||||
)
|
||||
token = AccessToken.decode(response.json()["access_token"])
|
||||
assert token.sub == user.id
|
||||
|
||||
|
|
@ -81,22 +85,30 @@ async def test_refresh_token_refresh_token_is_valid(client, user, user_password)
|
|||
)
|
||||
refresh_token = response.json()["refresh_token"]
|
||||
|
||||
response = await client.post("/api/token/refresh", params={"refresh_token": refresh_token})
|
||||
response = await client.post(
|
||||
"/api/token/refresh", params={"refresh_token": refresh_token}
|
||||
)
|
||||
token = RefreshToken.decode(response.json()["refresh_token"])
|
||||
assert token.sub == user.id
|
||||
|
||||
|
||||
async def test_refresh_token_invalid_returns_401(client):
|
||||
response = await client.post("/api/token/refresh", params={"refresh_token": "bad-token"})
|
||||
response = await client.post(
|
||||
"/api/token/refresh", params={"refresh_token": "bad-token"}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
async def test_refresh_token_access_token_as_refresh_returns_401(client, user, user_password):
|
||||
async def test_refresh_token_access_token_as_refresh_returns_401(
|
||||
client, user, user_password
|
||||
):
|
||||
response = await client.post(
|
||||
"/api/token",
|
||||
data={"username": user.username, "password": user_password},
|
||||
)
|
||||
access_token = response.json()["access_token"]
|
||||
|
||||
response = await client.post("/api/token/refresh", params={"refresh_token": access_token})
|
||||
assert response.status_code == 401
|
||||
response = await client.post(
|
||||
"/api/token/refresh", params={"refresh_token": access_token}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
|
|
|||
|
|
@ -45,7 +45,9 @@ class RefreshToken(Token):
|
|||
expire_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
|
||||
def generate_token_pair(entity_id: int, now: datetime) -> tuple[AccessToken, RefreshToken]:
|
||||
def generate_token_pair(
|
||||
entity_id: int, now: datetime
|
||||
) -> tuple[AccessToken, RefreshToken]:
|
||||
return (
|
||||
AccessToken(exp=AccessToken.calculate_exp(now), sub=entity_id),
|
||||
RefreshToken(exp=RefreshToken.calculate_exp(now), sub=entity_id),
|
||||
|
|
|
|||
|
|
@ -51,4 +51,4 @@ async def find(barcode: str) -> ExternalProduct:
|
|||
)
|
||||
except (KeyError, TypeError) as e:
|
||||
logger.error("Failed to parse product %s: %s", barcode, e)
|
||||
raise ParseError() from e
|
||||
raise ParseError() from e
|
||||
|
|
|
|||
Loading…
Reference in a new issue