Compare commits
No commits in common. "d66eb8affa93d2c16a12253deac8127aad12cc50" and "d4d3d97204670eac3b96f7143b31574518ece0b9" have entirely different histories.
d66eb8affa
...
d4d3d97204
28 changed files with 66 additions and 289 deletions
18
.github/workflows/python.yml
vendored
18
.github/workflows/python.yml
vendored
|
@ -1,21 +1,9 @@
|
||||||
name: Python lint and test
|
name: Python lint and test
|
||||||
|
|
||||||
on:
|
on: [push, pull_request]
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- main
|
|
||||||
- 'releases/**'
|
|
||||||
paths:
|
|
||||||
- '**.py'
|
|
||||||
pull_request:
|
|
||||||
branches:
|
|
||||||
- main
|
|
||||||
- 'releases/**'
|
|
||||||
paths:
|
|
||||||
- '**.py'
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
linttest:
|
lint:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
@ -30,7 +18,5 @@ jobs:
|
||||||
run: black --check fooder
|
run: black --check fooder
|
||||||
- name: Run flake8
|
- name: Run flake8
|
||||||
run: flake8 fooder
|
run: flake8 fooder
|
||||||
- name: Run mypy
|
|
||||||
run: mypy fooder
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: ./test.sh
|
run: ./test.sh
|
||||||
|
|
6
Makefile
6
Makefile
|
@ -14,15 +14,15 @@ push:
|
||||||
docker push registry.domandoman.xyz/fooder/api
|
docker push registry.domandoman.xyz/fooder/api
|
||||||
|
|
||||||
black:
|
black:
|
||||||
python -m black fooder
|
black fooder
|
||||||
|
|
||||||
.PHONY: mypy
|
.PHONY: mypy
|
||||||
mypy:
|
mypy:
|
||||||
python -m mypy fooder
|
mypy fooder
|
||||||
|
|
||||||
.PHONY: flake
|
.PHONY: flake
|
||||||
flake:
|
flake:
|
||||||
python -m flake8 fooder
|
flake8 fooder
|
||||||
|
|
||||||
.PHONY: lint
|
.PHONY: lint
|
||||||
lint: black mypy flake
|
lint: black mypy flake
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
# FOODER
|
# FOODER
|
||||||
|
|
||||||
Simple API for food diary application. It uses FastAPI and async postgres.
|
Simple API for food diary application. It uses FastAPI and async postgres for faster operation.
|
||||||
|
|
||||||
|
I plan on developing a few clients for the API, for now only one is available:
|
||||||
|
- [Fooder CLI Client](https://github.com/ickyicky/fooder-cli-client)
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
|
|
|
@ -11,5 +11,3 @@ REFRESH_SECRET_KEY="${REFRESH_SECRET_KEY}" # generate with $ openssl rand -hex 3
|
||||||
ALGORITHM="HS256"
|
ALGORITHM="HS256"
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES=30
|
ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||||
REFRESH_TOKEN_EXPIRE_DAYS=30
|
REFRESH_TOKEN_EXPIRE_DAYS=30
|
||||||
|
|
||||||
API_KEY="${API_KEY}" # generate with $ openssl rand -hex 32
|
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from fastapi import Depends, HTTPException
|
from fastapi import Depends, HTTPException
|
||||||
from fastapi_users.password import PasswordHelper
|
from fastapi_users.password import PasswordHelper
|
||||||
from typing import Annotated
|
from typing import AsyncGenerator, Annotated
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from .settings import Settings
|
from .settings import Settings
|
||||||
from .domain.user import User
|
from .domain.user import User
|
||||||
|
@ -15,7 +16,7 @@ from .db import get_session
|
||||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token")
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
password_helper = PasswordHelper(pwd_context) # type: ignore
|
password_helper = PasswordHelper(pwd_context)
|
||||||
|
|
||||||
AsyncSessionDependency = Annotated[async_sessionmaker, Depends(get_session)]
|
AsyncSessionDependency = Annotated[async_sessionmaker, Depends(get_session)]
|
||||||
TokenDependency = Annotated[str, Depends(oauth2_scheme)]
|
TokenDependency = Annotated[str, Depends(oauth2_scheme)]
|
||||||
|
@ -31,57 +32,35 @@ def get_password_hash(password: str) -> str:
|
||||||
|
|
||||||
async def authenticate_user(
|
async def authenticate_user(
|
||||||
session: AsyncSession, username: str, password: str
|
session: AsyncSession, username: str, password: str
|
||||||
) -> User | None:
|
) -> AsyncGenerator[User, None]:
|
||||||
user = await User.get_by_username(session, username)
|
user = await User.get_by_username(session, username)
|
||||||
|
if not user:
|
||||||
if user is None:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
assert user is not None
|
|
||||||
|
|
||||||
if not verify_password(password, user.hashed_password):
|
if not verify_password(password, user.hashed_password):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
async def verify_refresh_token(
|
async def verify_refresh_token(
|
||||||
session: AsyncSession, token: str
|
session: AsyncSession, token: str
|
||||||
) -> RefreshToken | None:
|
) -> AsyncGenerator[RefreshToken, None]:
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(
|
payload = jwt.decode(
|
||||||
token, settings.REFRESH_SECRET_KEY, algorithms=[settings.ALGORITHM]
|
token, settings.REFRESH_SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||||
)
|
)
|
||||||
sub = payload.get("sub")
|
username: str = payload.get("sub")
|
||||||
|
|
||||||
if sub is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if not isinstance(sub, str):
|
|
||||||
return None
|
|
||||||
|
|
||||||
username: str = str(sub)
|
|
||||||
|
|
||||||
if username is None:
|
if username is None:
|
||||||
return None
|
return
|
||||||
|
|
||||||
except JWTError:
|
except JWTError:
|
||||||
return None
|
return
|
||||||
|
|
||||||
user = await User.get_by_username(session, username)
|
user = await User.get_by_username(session, username)
|
||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
return None
|
return
|
||||||
|
|
||||||
assert user is not None
|
|
||||||
|
|
||||||
current_token = await RefreshToken.get_token(session, user.id, token)
|
current_token = await RefreshToken.get_token(session, user.id, token)
|
||||||
|
|
||||||
if current_token is not None:
|
if current_token is not None:
|
||||||
return current_token
|
return current_token
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def create_access_token(user: User) -> str:
|
def create_access_token(user: User) -> str:
|
||||||
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||||
|
@ -107,38 +86,18 @@ async def create_refresh_token(session: AsyncSession, user: User) -> RefreshToke
|
||||||
return await RefreshToken.create(session, token=encoded_jwt, user_id=user.id)
|
return await RefreshToken.create(session, token=encoded_jwt, user_id=user.id)
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(ssn: AsyncSessionDependency, token: TokenDependency) -> User:
|
async def get_current_user(
|
||||||
async with ssn() as session:
|
session: AsyncSessionDependency, token: TokenDependency
|
||||||
|
) -> User:
|
||||||
|
async with session() as session:
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(
|
payload = jwt.decode(
|
||||||
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||||
)
|
)
|
||||||
sub = payload.get("sub")
|
username: str = payload.get("sub")
|
||||||
|
|
||||||
if sub is None:
|
|
||||||
raise HTTPException(status_code=401, detail="Unathorized")
|
|
||||||
|
|
||||||
if not isinstance(sub, str):
|
|
||||||
raise HTTPException(status_code=401, detail="Unathorized")
|
|
||||||
|
|
||||||
username: str = str(sub)
|
|
||||||
|
|
||||||
if username is None:
|
if username is None:
|
||||||
raise HTTPException(status_code=401, detail="Unathorized")
|
raise HTTPException(status_code=401, detail="Unathorized")
|
||||||
|
|
||||||
except JWTError:
|
except JWTError:
|
||||||
raise HTTPException(status_code=401, detail="Unathorized")
|
raise HTTPException(status_code=401, detail="Unathorized")
|
||||||
|
|
||||||
user = await User.get_by_username(session, username)
|
return await User.get_by_username(session, username)
|
||||||
|
|
||||||
if user is None:
|
|
||||||
raise HTTPException(status_code=401, detail="Unathorized")
|
|
||||||
|
|
||||||
assert user is not None
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
async def authorize_api_key(token: TokenDependency) -> None:
|
|
||||||
if token == settings.API_KEY:
|
|
||||||
return None
|
|
||||||
raise HTTPException(status_code=401, detail="Unathorized")
|
|
||||||
|
|
|
@ -2,13 +2,12 @@ from typing import Annotated, Any
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||||
from ..db import get_session
|
from ..db import get_session
|
||||||
from ..auth import get_current_user, authorize_api_key
|
from ..auth import get_current_user
|
||||||
from ..domain.user import User
|
from ..domain.user import User
|
||||||
|
|
||||||
|
|
||||||
AsyncSession = Annotated[async_sessionmaker, Depends(get_session)]
|
AsyncSession = Annotated[async_sessionmaker, Depends(get_session)]
|
||||||
UserDependency = Annotated[User, Depends(get_current_user)]
|
UserDependency = Annotated[User, Depends(get_current_user)]
|
||||||
ApiKeyDependency = Annotated[None, Depends(authorize_api_key)]
|
|
||||||
|
|
||||||
|
|
||||||
class BaseController:
|
class BaseController:
|
||||||
|
@ -26,8 +25,3 @@ class AuthorizedController(BaseController):
|
||||||
def __init__(self, session: AsyncSession, user: UserDependency) -> None:
|
def __init__(self, session: AsyncSession, user: UserDependency) -> None:
|
||||||
super().__init__(session)
|
super().__init__(session)
|
||||||
self.user = user
|
self.user = user
|
||||||
|
|
||||||
|
|
||||||
class TasksSessionController(BaseController):
|
|
||||||
def __init__(self, session: AsyncSession, api_key: ApiKeyDependency) -> None:
|
|
||||||
super().__init__(session)
|
|
||||||
|
|
|
@ -39,7 +39,7 @@ class UpdateEntry(AuthorizedController):
|
||||||
|
|
||||||
|
|
||||||
class DeleteEntry(AuthorizedController):
|
class DeleteEntry(AuthorizedController):
|
||||||
async def call(self, entry_id: int) -> None:
|
async def call(self, entry_id: int) -> Entry:
|
||||||
async with self.async_session.begin() as session:
|
async with self.async_session.begin() as session:
|
||||||
entry = await DBEntry.get_by_id(session, self.user.id, entry_id)
|
entry = await DBEntry.get_by_id(session, self.user.id, entry_id)
|
||||||
if entry is None:
|
if entry is None:
|
||||||
|
|
|
@ -29,7 +29,7 @@ class CreateMeal(AuthorizedController):
|
||||||
|
|
||||||
|
|
||||||
class SaveMeal(AuthorizedController):
|
class SaveMeal(AuthorizedController):
|
||||||
async def call(self, meal_id: int, payload: SaveMealPayload) -> Preset:
|
async def call(self, meal_id: id, payload: SaveMealPayload) -> Preset:
|
||||||
async with self.async_session.begin() as session:
|
async with self.async_session.begin() as session:
|
||||||
meal = await DBMeal.get_by_id(session, self.user.id, meal_id)
|
meal = await DBMeal.get_by_id(session, self.user.id, meal_id)
|
||||||
if meal is None:
|
if meal is None:
|
||||||
|
@ -38,10 +38,7 @@ class SaveMeal(AuthorizedController):
|
||||||
try:
|
try:
|
||||||
return Preset.from_orm(
|
return Preset.from_orm(
|
||||||
await DBPreset.create(
|
await DBPreset.create(
|
||||||
session,
|
session, user_id=self.user.id, name=payload.name, meal=meal
|
||||||
user_id=self.user.id,
|
|
||||||
name=payload.name or meal.name,
|
|
||||||
meal=meal,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
|
@ -49,7 +46,7 @@ class SaveMeal(AuthorizedController):
|
||||||
|
|
||||||
|
|
||||||
class DeleteMeal(AuthorizedController):
|
class DeleteMeal(AuthorizedController):
|
||||||
async def call(self, meal_id: int) -> None:
|
async def call(self, meal_id: id) -> None:
|
||||||
async with self.async_session.begin() as session:
|
async with self.async_session.begin() as session:
|
||||||
meal = await DBMeal.get_by_id(session, self.user.id, meal_id)
|
meal = await DBMeal.get_by_id(session, self.user.id, meal_id)
|
||||||
if meal is None:
|
if meal is None:
|
||||||
|
|
|
@ -32,7 +32,7 @@ class DeletePreset(AuthorizedController):
|
||||||
async def call(
|
async def call(
|
||||||
self,
|
self,
|
||||||
id: int,
|
id: int,
|
||||||
) -> None:
|
) -> AsyncIterator[Preset]:
|
||||||
async with self.async_session.begin() as session:
|
async with self.async_session.begin() as session:
|
||||||
preset = await DBPreset.get(session, self.user.id, id)
|
preset = await DBPreset.get(session, self.user.id, id)
|
||||||
|
|
||||||
|
|
|
@ -1,13 +0,0 @@
|
||||||
from fastapi import HTTPException
|
|
||||||
from ..domain.product import Product as DBProduct
|
|
||||||
from .base import TasksSessionController
|
|
||||||
|
|
||||||
|
|
||||||
class CacheProductUsageData(TasksSessionController):
|
|
||||||
async def call(self) -> None:
|
|
||||||
async with self.async_session.begin() as session:
|
|
||||||
try:
|
|
||||||
await DBProduct.cache_usage_data(session)
|
|
||||||
await session.commit()
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
|
|
@ -41,11 +41,6 @@ class RefreshToken(BaseController):
|
||||||
raise HTTPException(status_code=401, detail="Invalid token")
|
raise HTTPException(status_code=401, detail="Invalid token")
|
||||||
|
|
||||||
user = await DBUser.get(session, current_token.user_id)
|
user = await DBUser.get(session, current_token.user_id)
|
||||||
|
|
||||||
if user is None:
|
|
||||||
raise HTTPException(status_code=401, detail="Invalid token")
|
|
||||||
|
|
||||||
assert user is not None
|
|
||||||
await current_token.delete(session)
|
await current_token.delete(session)
|
||||||
|
|
||||||
refresh_token = await create_refresh_token(session, user)
|
refresh_token = await create_refresh_token(session, user)
|
||||||
|
|
|
@ -17,6 +17,6 @@ class CommonMixin:
|
||||||
|
|
||||||
:rtype: str
|
:rtype: str
|
||||||
"""
|
"""
|
||||||
return cls.__name__.lower() # type: ignore
|
return cls.__name__.lower()
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
|
|
|
@ -3,7 +3,7 @@ from sqlalchemy import ForeignKey, Integer, Date
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.sql.selectable import Select
|
from sqlalchemy.sql.selectable import Select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
import datetime
|
from datetime import date
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from .base import Base, CommonMixin
|
from .base import Base, CommonMixin
|
||||||
|
@ -17,7 +17,7 @@ class Diary(Base, CommonMixin):
|
||||||
meals: Mapped[list[Meal]] = relationship(
|
meals: Mapped[list[Meal]] = relationship(
|
||||||
lazy="selectin", order_by=Meal.order.desc()
|
lazy="selectin", order_by=Meal.order.desc()
|
||||||
)
|
)
|
||||||
date: Mapped[datetime.date] = mapped_column(Date)
|
date: Mapped[date] = mapped_column(Date)
|
||||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
|
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -74,16 +74,14 @@ class Diary(Base, CommonMixin):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_diary(
|
async def get_diary(
|
||||||
cls, session: AsyncSession, user_id: int, date: datetime.date
|
cls, session: AsyncSession, user_id: int, date: date
|
||||||
) -> "Optional[Diary]":
|
) -> "Optional[Diary]":
|
||||||
"""get_diary."""
|
"""get_diary."""
|
||||||
query = cls.query(user_id).where(cls.date == date)
|
query = cls.query(user_id).where(cls.date == date)
|
||||||
return await session.scalar(query)
|
return await session.scalar(query)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create(
|
async def create(cls, session: AsyncSession, user_id: int, date: date) -> "Diary":
|
||||||
cls, session: AsyncSession, user_id: int, date: datetime.date
|
|
||||||
) -> "Diary":
|
|
||||||
diary = Diary(
|
diary = Diary(
|
||||||
date=date,
|
date=date,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
@ -95,13 +93,12 @@ class Diary(Base, CommonMixin):
|
||||||
except Exception:
|
except Exception:
|
||||||
raise RuntimeError()
|
raise RuntimeError()
|
||||||
|
|
||||||
db_diary = await cls.get_by_id(session, user_id, diary.id)
|
diary = await cls.get_by_id(session, user_id, diary.id)
|
||||||
|
|
||||||
if not db_diary:
|
if not diary:
|
||||||
raise RuntimeError()
|
raise RuntimeError()
|
||||||
|
await Meal.create(session, diary.id)
|
||||||
await Meal.create(session, db_diary.id)
|
return diary
|
||||||
return db_diary
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_id(
|
async def get_by_id(
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship, joinedload
|
from sqlalchemy.orm import Mapped, mapped_column, relationship, joinedload
|
||||||
from sqlalchemy import ForeignKey, Integer, DateTime, Boolean
|
from sqlalchemy import ForeignKey, Integer, DateTime
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy import select, update
|
from sqlalchemy import select
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
@ -20,7 +20,6 @@ class Entry(Base, CommonMixin):
|
||||||
last_changed: Mapped[datetime] = mapped_column(
|
last_changed: Mapped[datetime] = mapped_column(
|
||||||
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
|
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
|
||||||
)
|
)
|
||||||
processed: Mapped[bool] = mapped_column(Boolean, default=False)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def amount(self) -> float:
|
def amount(self) -> float:
|
||||||
|
@ -88,10 +87,10 @@ class Entry(Base, CommonMixin):
|
||||||
except IntegrityError:
|
except IntegrityError:
|
||||||
raise AssertionError("meal or product does not exist")
|
raise AssertionError("meal or product does not exist")
|
||||||
|
|
||||||
db_entry = await cls._get_by_id(session, entry.id)
|
entry = await cls._get_by_id(session, entry.id)
|
||||||
if not db_entry:
|
if not entry:
|
||||||
raise RuntimeError()
|
raise RuntimeError()
|
||||||
return db_entry
|
return entry
|
||||||
|
|
||||||
async def update(
|
async def update(
|
||||||
self,
|
self,
|
||||||
|
@ -153,12 +152,3 @@ class Entry(Base, CommonMixin):
|
||||||
"""delete."""
|
"""delete."""
|
||||||
await session.delete(self)
|
await session.delete(self)
|
||||||
await session.flush()
|
await session.flush()
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def mark_processed(
|
|
||||||
cls,
|
|
||||||
session: AsyncSession,
|
|
||||||
) -> None:
|
|
||||||
stmt = update(cls).where(cls.processed == False).values(processed=True)
|
|
||||||
|
|
||||||
await session.execute(stmt)
|
|
||||||
|
|
|
@ -84,10 +84,10 @@ class Meal(Base, CommonMixin):
|
||||||
except IntegrityError:
|
except IntegrityError:
|
||||||
raise AssertionError("diary does not exist")
|
raise AssertionError("diary does not exist")
|
||||||
|
|
||||||
db_meal = await cls._get_by_id(session, meal.id)
|
meal = await cls._get_by_id(session, meal.id)
|
||||||
if not db_meal:
|
if not meal:
|
||||||
raise RuntimeError()
|
raise RuntimeError()
|
||||||
return db_meal
|
return meal
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create_from_preset(
|
async def create_from_preset(
|
||||||
|
@ -118,10 +118,10 @@ class Meal(Base, CommonMixin):
|
||||||
for entry in preset.entries:
|
for entry in preset.entries:
|
||||||
await Entry.create(session, meal.id, entry.product_id, entry.grams)
|
await Entry.create(session, meal.id, entry.product_id, entry.grams)
|
||||||
|
|
||||||
db_meal = await cls._get_by_id(session, meal.id)
|
meal = await cls._get_by_id(session, meal.id)
|
||||||
if not db_meal:
|
if not meal:
|
||||||
raise RuntimeError()
|
raise RuntimeError()
|
||||||
return db_meal
|
return meal
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _get_by_id(cls, session: AsyncSession, id: int) -> "Optional[Meal]":
|
async def _get_by_id(cls, session: AsyncSession, id: int) -> "Optional[Meal]":
|
||||||
|
|
|
@ -63,7 +63,7 @@ class Preset(Base, CommonMixin):
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create(
|
async def create(
|
||||||
cls, session: AsyncSession, user_id: int, name: str, meal: "Meal"
|
cls, session: AsyncSession, user_id: int, name: str, meal: "Meal"
|
||||||
) -> "Preset":
|
) -> None:
|
||||||
preset = Preset(user_id=user_id, name=name)
|
preset = Preset(user_id=user_id, name=name)
|
||||||
|
|
||||||
session.add(preset)
|
session.add(preset)
|
||||||
|
@ -76,12 +76,7 @@ class Preset(Base, CommonMixin):
|
||||||
for entry in meal.entries:
|
for entry in meal.entries:
|
||||||
await PresetEntry.create(session, preset.id, entry)
|
await PresetEntry.create(session, preset.id, entry)
|
||||||
|
|
||||||
db_preset = await cls.get(session, user_id, preset.id)
|
return await cls.get(session, user_id, preset.id)
|
||||||
|
|
||||||
if not db_preset:
|
|
||||||
raise RuntimeError()
|
|
||||||
|
|
||||||
return db_preset
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def list_all(
|
async def list_all(
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped
|
||||||
from sqlalchemy import select, BigInteger, func, update
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from typing import AsyncIterator, Optional
|
from typing import AsyncIterator, Optional
|
||||||
|
|
||||||
|
@ -15,14 +15,8 @@ class Product(Base, CommonMixin):
|
||||||
carb: Mapped[float]
|
carb: Mapped[float]
|
||||||
fat: Mapped[float]
|
fat: Mapped[float]
|
||||||
fiber: Mapped[float]
|
fiber: Mapped[float]
|
||||||
hard_coded_calories: Mapped[Optional[float]]
|
hard_coded_calories: Mapped[Optional[float]] = None
|
||||||
barcode: Mapped[Optional[str]]
|
barcode: Mapped[Optional[str]] = None
|
||||||
|
|
||||||
usage_count_cached: Mapped[int] = mapped_column(
|
|
||||||
BigInteger,
|
|
||||||
default=0,
|
|
||||||
nullable=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def calories(self) -> float:
|
def calories(self) -> float:
|
||||||
|
@ -47,13 +41,11 @@ class Product(Base, CommonMixin):
|
||||||
|
|
||||||
if q:
|
if q:
|
||||||
q_list = q.split()
|
q_list = q.split()
|
||||||
qq = "%" + "%".join(q_list) + "%"
|
for qq in q_list:
|
||||||
query = query.filter(cls.name.ilike(f"%{qq.lower()}%"))
|
query = query.filter(cls.name.ilike(f"%{qq.lower()}%"))
|
||||||
|
|
||||||
query = query.offset(offset).limit(limit)
|
query = query.offset(offset).limit(limit)
|
||||||
stream = await session.stream_scalars(
|
stream = await session.stream_scalars(query.order_by(cls.id))
|
||||||
query.order_by(cls.usage_count_cached.desc())
|
|
||||||
)
|
|
||||||
async for row in stream:
|
async for row in stream:
|
||||||
yield row
|
yield row
|
||||||
|
|
||||||
|
@ -112,28 +104,3 @@ class Product(Base, CommonMixin):
|
||||||
session.add(product)
|
session.add(product)
|
||||||
await session.flush()
|
await session.flush()
|
||||||
return product
|
return product
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def cache_usage_data(
|
|
||||||
cls,
|
|
||||||
session: AsyncSession,
|
|
||||||
) -> None:
|
|
||||||
from .entry import Entry
|
|
||||||
|
|
||||||
stmt = (
|
|
||||||
update(cls)
|
|
||||||
.where(
|
|
||||||
cls.id.in_(
|
|
||||||
select(Entry.product_id).where(Entry.processed == False).distinct()
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.values(
|
|
||||||
usage_count_cached=select(func.count(Entry.id)).where(
|
|
||||||
Entry.product_id == cls.id,
|
|
||||||
Entry.processed == False,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
await session.execute(stmt)
|
|
||||||
await Entry.mark_processed(session)
|
|
||||||
|
|
|
@ -47,18 +47,18 @@ class RefreshToken(Base, CommonMixin):
|
||||||
:type token: str
|
:type token: str
|
||||||
:rtype: "RefreshToken"
|
:rtype: "RefreshToken"
|
||||||
"""
|
"""
|
||||||
db_token = cls(
|
token = cls(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token=token,
|
token=token,
|
||||||
)
|
)
|
||||||
session.add(db_token)
|
session.add(token)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await session.flush()
|
await session.flush()
|
||||||
except Exception:
|
except Exception:
|
||||||
raise AssertionError("invalid token")
|
raise AssertionError("invalid token")
|
||||||
|
|
||||||
return db_token
|
return token
|
||||||
|
|
||||||
async def delete(self, session: AsyncSession) -> None:
|
async def delete(self, session: AsyncSession) -> None:
|
||||||
"""delete.
|
"""delete.
|
||||||
|
|
|
@ -15,5 +15,3 @@ class Settings(BaseSettings):
|
||||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 30
|
REFRESH_TOKEN_EXPIRE_DAYS: int = 30
|
||||||
|
|
||||||
ALLOWED_ORIGINS: List[str] = ["*"]
|
ALLOWED_ORIGINS: List[str] = ["*"]
|
||||||
|
|
||||||
API_KEY: str
|
|
||||||
|
|
|
@ -1,17 +0,0 @@
|
||||||
from fastapi import FastAPI
|
|
||||||
from .view.tasks import router
|
|
||||||
from .settings import Settings
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="Fooder Tasks admininstrative API")
|
|
||||||
app.include_router(router)
|
|
||||||
|
|
||||||
|
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=Settings().ALLOWED_ORIGINS,
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
)
|
|
20
fooder/test/fixtures/client.py
vendored
20
fooder/test/fixtures/client.py
vendored
|
@ -1,9 +1,7 @@
|
||||||
from fooder.app import app
|
from fooder.app import app
|
||||||
from fooder.tasks_app import app as tasks_app
|
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
import pytest
|
import pytest
|
||||||
import httpx
|
import httpx
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
class Client:
|
class Client:
|
||||||
|
@ -69,29 +67,11 @@ class Client:
|
||||||
return await self.client.patch(path, **kwargs)
|
return await self.client.patch(path, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class TasksClient(Client):
|
|
||||||
def __init__(self, authorized: bool = True):
|
|
||||||
super().__init__()
|
|
||||||
self.client = AsyncClient(app=tasks_app, base_url="http://testserver/api")
|
|
||||||
self.client.headers["Accept"] = "application/json"
|
|
||||||
|
|
||||||
if authorized:
|
|
||||||
self.client.headers["Authorization"] = "Bearer " + self.get_token()
|
|
||||||
|
|
||||||
def get_token(self) -> str:
|
|
||||||
return os.getenv("API_KEY")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def unauthorized_client() -> Client:
|
def unauthorized_client() -> Client:
|
||||||
return Client()
|
return Client()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def tasks_client() -> Client:
|
|
||||||
return TasksClient()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def client(user_payload) -> Client:
|
async def client(user_payload) -> Client:
|
||||||
client = Client()
|
client = Client()
|
||||||
|
|
|
@ -1,15 +0,0 @@
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_cache_product_usage(client, tasks_client):
|
|
||||||
response = await client.get("product")
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
old_data = response.json()
|
|
||||||
|
|
||||||
response = await tasks_client.post("/cache_product_usage_data")
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
|
|
||||||
response = await client.get("product")
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
assert response.json() != old_data
|
|
|
@ -48,5 +48,6 @@ def find(bar_code: str) -> Product:
|
||||||
fiber=data["product"]["nutriments"].get("fiber_100g", 0.0),
|
fiber=data["product"]["nutriments"].get("fiber_100g", 0.0),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
raise e
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
raise ParseError()
|
raise ParseError()
|
||||||
|
|
|
@ -1,13 +0,0 @@
|
||||||
from fastapi import APIRouter, Depends, Request
|
|
||||||
from ..controller.tasks import CacheProductUsageData
|
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api", tags=["tasks"])
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/cache_product_usage_data")
|
|
||||||
async def create_user(
|
|
||||||
request: Request,
|
|
||||||
contoller: CacheProductUsageData = Depends(CacheProductUsageData),
|
|
||||||
):
|
|
||||||
return await contoller.call()
|
|
15
mypy.ini
15
mypy.ini
|
@ -1,15 +0,0 @@
|
||||||
|
|
||||||
[mypy]
|
|
||||||
plugins = sqlalchemy.ext.mypy.plugin,pydantic.mypy
|
|
||||||
|
|
||||||
exclude = .*/test/.*
|
|
||||||
|
|
||||||
pretty = True
|
|
||||||
|
|
||||||
platform = linux
|
|
||||||
|
|
||||||
warn_unused_configs = True
|
|
||||||
warn_unused_ignores = True
|
|
||||||
|
|
||||||
[mypy-fooder.controller.*]
|
|
||||||
disable_error_code=override
|
|
|
@ -13,7 +13,3 @@ flake8
|
||||||
flake8-bugbear
|
flake8-bugbear
|
||||||
httpx
|
httpx
|
||||||
aiosqlite
|
aiosqlite
|
||||||
mypy
|
|
||||||
types-requests
|
|
||||||
types-passlib
|
|
||||||
types-python-jose
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[flake8]
|
[flake8]
|
||||||
max-line-length = 80
|
max-line-length = 80
|
||||||
extend-select = B950
|
extend-select = B950
|
||||||
extend-ignore = E203,E501,E701,E712
|
extend-ignore = E203,E501,E701
|
||||||
extend-immutable-calls =
|
extend-immutable-calls =
|
||||||
Depends
|
Depends
|
||||||
|
|
6
test.sh
6
test.sh
|
@ -13,7 +13,6 @@ export DB_URI="sqlite+aiosqlite:///test.db"
|
||||||
export ECHO_SQL=0
|
export ECHO_SQL=0
|
||||||
export SECRET_KEY=$(openssl rand -hex 32)
|
export SECRET_KEY=$(openssl rand -hex 32)
|
||||||
export REFRESH_SECRET_KEY=$(openssl rand -hex 32)
|
export REFRESH_SECRET_KEY=$(openssl rand -hex 32)
|
||||||
export API_KEY=$(openssl rand -hex 32)
|
|
||||||
|
|
||||||
python -m fooder --create-tables
|
python -m fooder --create-tables
|
||||||
|
|
||||||
|
@ -24,17 +23,12 @@ else
|
||||||
python -m pytest fooder --disable-warnings -sv
|
python -m pytest fooder --disable-warnings -sv
|
||||||
fi
|
fi
|
||||||
|
|
||||||
status=$?
|
|
||||||
|
|
||||||
# unset test env values
|
# unset test env values
|
||||||
unset POSTGRES_USER
|
unset POSTGRES_USER
|
||||||
unset POSTGRES_DATABASE
|
unset POSTGRES_DATABASE
|
||||||
unset POSTGRES_PASSWORD
|
unset POSTGRES_PASSWORD
|
||||||
unset SECRET_KEY
|
unset SECRET_KEY
|
||||||
unset REFRESH_SECRET
|
unset REFRESH_SECRET
|
||||||
unset API_KEY
|
|
||||||
|
|
||||||
# if exists, remove test.db
|
# if exists, remove test.db
|
||||||
[ -f test.db ] && rm test.db
|
[ -f test.db ] && rm test.db
|
||||||
|
|
||||||
exit $status
|
|
||||||
|
|
Loading…
Reference in a new issue