Compare commits

..

10 commits

28 changed files with 289 additions and 66 deletions

View file

@ -1,9 +1,21 @@
name: Python lint and test
on: [push, pull_request]
on:
push:
branches:
- main
- 'releases/**'
paths:
- '**.py'
pull_request:
branches:
- main
- 'releases/**'
paths:
- '**.py'
jobs:
lint:
linttest:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
@ -18,5 +30,7 @@ jobs:
run: black --check fooder
- name: Run flake8
run: flake8 fooder
- name: Run mypy
run: mypy fooder
- name: Run tests
run: ./test.sh

View file

@ -14,15 +14,15 @@ push:
docker push registry.domandoman.xyz/fooder/api
black:
black fooder
python -m black fooder
.PHONY: mypy
mypy:
mypy fooder
python -m mypy fooder
.PHONY: flake
flake:
flake8 fooder
python -m flake8 fooder
.PHONY: lint
lint: black mypy flake

View file

@ -1,9 +1,6 @@
# FOODER
Simple API for food diary application. It uses FastAPI and async postgres for faster operation.
I plan on developing a few clients for the API, for now only one is available:
- [Fooder CLI Client](https://github.com/ickyicky/fooder-cli-client)
Simple API for food diary application. It uses FastAPI and async postgres.
## Usage

View file

@ -11,3 +11,5 @@ REFRESH_SECRET_KEY="${REFRESH_SECRET_KEY}" # generate with $ openssl rand -hex 3
ALGORITHM="HS256"
ACCESS_TOKEN_EXPIRE_MINUTES=30
REFRESH_TOKEN_EXPIRE_DAYS=30
API_KEY="${API_KEY}" # generate with $ openssl rand -hex 32

View file

@ -1,11 +1,10 @@
from passlib.context import CryptContext
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import async_sessionmaker
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from jose import JWTError, jwt
from fastapi.security import OAuth2PasswordBearer
from fastapi import Depends, HTTPException
from fastapi_users.password import PasswordHelper
from typing import AsyncGenerator, Annotated
from typing import Annotated
from datetime import datetime, timedelta
from .settings import Settings
from .domain.user import User
@ -16,7 +15,7 @@ from .db import get_session
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token")
settings = Settings()
password_helper = PasswordHelper(pwd_context)
password_helper = PasswordHelper(pwd_context) # type: ignore
AsyncSessionDependency = Annotated[async_sessionmaker, Depends(get_session)]
TokenDependency = Annotated[str, Depends(oauth2_scheme)]
@ -32,35 +31,57 @@ def get_password_hash(password: str) -> str:
async def authenticate_user(
session: AsyncSession, username: str, password: str
) -> AsyncGenerator[User, None]:
) -> User | None:
user = await User.get_by_username(session, username)
if not user:
if user is None:
return None
assert user is not None
if not verify_password(password, user.hashed_password):
return None
return user
async def verify_refresh_token(
session: AsyncSession, token: str
) -> AsyncGenerator[RefreshToken, None]:
) -> RefreshToken | None:
try:
payload = jwt.decode(
token, settings.REFRESH_SECRET_KEY, algorithms=[settings.ALGORITHM]
)
username: str = payload.get("sub")
sub = payload.get("sub")
if sub is None:
return None
if not isinstance(sub, str):
return None
username: str = str(sub)
if username is None:
return
return None
except JWTError:
return
return None
user = await User.get_by_username(session, username)
if user is None:
return
return None
assert user is not None
current_token = await RefreshToken.get_token(session, user.id, token)
if current_token is not None:
return current_token
return None
def create_access_token(user: User) -> str:
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
@ -86,18 +107,38 @@ async def create_refresh_token(session: AsyncSession, user: User) -> RefreshToke
return await RefreshToken.create(session, token=encoded_jwt, user_id=user.id)
async def get_current_user(
session: AsyncSessionDependency, token: TokenDependency
) -> User:
async with session() as session:
async def get_current_user(ssn: AsyncSessionDependency, token: TokenDependency) -> User:
async with ssn() as session:
try:
payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
username: str = payload.get("sub")
sub = payload.get("sub")
if sub is None:
raise HTTPException(status_code=401, detail="Unathorized")
if not isinstance(sub, str):
raise HTTPException(status_code=401, detail="Unathorized")
username: str = str(sub)
if username is None:
raise HTTPException(status_code=401, detail="Unathorized")
except JWTError:
raise HTTPException(status_code=401, detail="Unathorized")
return await User.get_by_username(session, username)
user = await User.get_by_username(session, username)
if user is None:
raise HTTPException(status_code=401, detail="Unathorized")
assert user is not None
return user
async def authorize_api_key(token: TokenDependency) -> None:
if token == settings.API_KEY:
return None
raise HTTPException(status_code=401, detail="Unathorized")

View file

@ -2,12 +2,13 @@ from typing import Annotated, Any
from fastapi import Depends
from sqlalchemy.ext.asyncio import async_sessionmaker
from ..db import get_session
from ..auth import get_current_user
from ..auth import get_current_user, authorize_api_key
from ..domain.user import User
AsyncSession = Annotated[async_sessionmaker, Depends(get_session)]
UserDependency = Annotated[User, Depends(get_current_user)]
ApiKeyDependency = Annotated[None, Depends(authorize_api_key)]
class BaseController:
@ -25,3 +26,8 @@ class AuthorizedController(BaseController):
def __init__(self, session: AsyncSession, user: UserDependency) -> None:
super().__init__(session)
self.user = user
class TasksSessionController(BaseController):
def __init__(self, session: AsyncSession, api_key: ApiKeyDependency) -> None:
super().__init__(session)

View file

@ -39,7 +39,7 @@ class UpdateEntry(AuthorizedController):
class DeleteEntry(AuthorizedController):
async def call(self, entry_id: int) -> Entry:
async def call(self, entry_id: int) -> None:
async with self.async_session.begin() as session:
entry = await DBEntry.get_by_id(session, self.user.id, entry_id)
if entry is None:

View file

@ -29,7 +29,7 @@ class CreateMeal(AuthorizedController):
class SaveMeal(AuthorizedController):
async def call(self, meal_id: id, payload: SaveMealPayload) -> Preset:
async def call(self, meal_id: int, payload: SaveMealPayload) -> Preset:
async with self.async_session.begin() as session:
meal = await DBMeal.get_by_id(session, self.user.id, meal_id)
if meal is None:
@ -38,7 +38,10 @@ class SaveMeal(AuthorizedController):
try:
return Preset.from_orm(
await DBPreset.create(
session, user_id=self.user.id, name=payload.name, meal=meal
session,
user_id=self.user.id,
name=payload.name or meal.name,
meal=meal,
)
)
except AssertionError as e:
@ -46,7 +49,7 @@ class SaveMeal(AuthorizedController):
class DeleteMeal(AuthorizedController):
async def call(self, meal_id: id) -> None:
async def call(self, meal_id: int) -> None:
async with self.async_session.begin() as session:
meal = await DBMeal.get_by_id(session, self.user.id, meal_id)
if meal is None:

View file

@ -32,7 +32,7 @@ class DeletePreset(AuthorizedController):
async def call(
self,
id: int,
) -> AsyncIterator[Preset]:
) -> None:
async with self.async_session.begin() as session:
preset = await DBPreset.get(session, self.user.id, id)

View file

@ -0,0 +1,13 @@
from fastapi import HTTPException
from ..domain.product import Product as DBProduct
from .base import TasksSessionController
class CacheProductUsageData(TasksSessionController):
async def call(self) -> None:
async with self.async_session.begin() as session:
try:
await DBProduct.cache_usage_data(session)
await session.commit()
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))

View file

@ -41,6 +41,11 @@ class RefreshToken(BaseController):
raise HTTPException(status_code=401, detail="Invalid token")
user = await DBUser.get(session, current_token.user_id)
if user is None:
raise HTTPException(status_code=401, detail="Invalid token")
assert user is not None
await current_token.delete(session)
refresh_token = await create_refresh_token(session, user)

View file

@ -17,6 +17,6 @@ class CommonMixin:
:rtype: str
"""
return cls.__name__.lower()
return cls.__name__.lower() # type: ignore
id: Mapped[int] = mapped_column(primary_key=True)

View file

@ -3,7 +3,7 @@ from sqlalchemy import ForeignKey, Integer, Date
from sqlalchemy import select
from sqlalchemy.sql.selectable import Select
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import date
import datetime
from typing import Optional
from .base import Base, CommonMixin
@ -17,7 +17,7 @@ class Diary(Base, CommonMixin):
meals: Mapped[list[Meal]] = relationship(
lazy="selectin", order_by=Meal.order.desc()
)
date: Mapped[date] = mapped_column(Date)
date: Mapped[datetime.date] = mapped_column(Date)
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
@property
@ -74,14 +74,16 @@ class Diary(Base, CommonMixin):
@classmethod
async def get_diary(
cls, session: AsyncSession, user_id: int, date: date
cls, session: AsyncSession, user_id: int, date: datetime.date
) -> "Optional[Diary]":
"""get_diary."""
query = cls.query(user_id).where(cls.date == date)
return await session.scalar(query)
@classmethod
async def create(cls, session: AsyncSession, user_id: int, date: date) -> "Diary":
async def create(
cls, session: AsyncSession, user_id: int, date: datetime.date
) -> "Diary":
diary = Diary(
date=date,
user_id=user_id,
@ -93,12 +95,13 @@ class Diary(Base, CommonMixin):
except Exception:
raise RuntimeError()
diary = await cls.get_by_id(session, user_id, diary.id)
db_diary = await cls.get_by_id(session, user_id, diary.id)
if not diary:
if not db_diary:
raise RuntimeError()
await Meal.create(session, diary.id)
return diary
await Meal.create(session, db_diary.id)
return db_diary
@classmethod
async def get_by_id(

View file

@ -1,8 +1,8 @@
from sqlalchemy.orm import Mapped, mapped_column, relationship, joinedload
from sqlalchemy import ForeignKey, Integer, DateTime
from sqlalchemy import ForeignKey, Integer, DateTime, Boolean
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import IntegrityError
from sqlalchemy import select
from sqlalchemy import select, update
from datetime import datetime
from typing import Optional
@ -20,6 +20,7 @@ class Entry(Base, CommonMixin):
last_changed: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
)
processed: Mapped[bool] = mapped_column(Boolean, default=False)
@property
def amount(self) -> float:
@ -87,10 +88,10 @@ class Entry(Base, CommonMixin):
except IntegrityError:
raise AssertionError("meal or product does not exist")
entry = await cls._get_by_id(session, entry.id)
if not entry:
db_entry = await cls._get_by_id(session, entry.id)
if not db_entry:
raise RuntimeError()
return entry
return db_entry
async def update(
self,
@ -152,3 +153,12 @@ class Entry(Base, CommonMixin):
"""delete."""
await session.delete(self)
await session.flush()
@classmethod
async def mark_processed(
cls,
session: AsyncSession,
) -> None:
stmt = update(cls).where(cls.processed == False).values(processed=True)
await session.execute(stmt)

View file

@ -84,10 +84,10 @@ class Meal(Base, CommonMixin):
except IntegrityError:
raise AssertionError("diary does not exist")
meal = await cls._get_by_id(session, meal.id)
if not meal:
db_meal = await cls._get_by_id(session, meal.id)
if not db_meal:
raise RuntimeError()
return meal
return db_meal
@classmethod
async def create_from_preset(
@ -118,10 +118,10 @@ class Meal(Base, CommonMixin):
for entry in preset.entries:
await Entry.create(session, meal.id, entry.product_id, entry.grams)
meal = await cls._get_by_id(session, meal.id)
if not meal:
db_meal = await cls._get_by_id(session, meal.id)
if not db_meal:
raise RuntimeError()
return meal
return db_meal
@classmethod
async def _get_by_id(cls, session: AsyncSession, id: int) -> "Optional[Meal]":

View file

@ -63,7 +63,7 @@ class Preset(Base, CommonMixin):
@classmethod
async def create(
cls, session: AsyncSession, user_id: int, name: str, meal: "Meal"
) -> None:
) -> "Preset":
preset = Preset(user_id=user_id, name=name)
session.add(preset)
@ -76,7 +76,12 @@ class Preset(Base, CommonMixin):
for entry in meal.entries:
await PresetEntry.create(session, preset.id, entry)
return await cls.get(session, user_id, preset.id)
db_preset = await cls.get(session, user_id, preset.id)
if not db_preset:
raise RuntimeError()
return db_preset
@classmethod
async def list_all(

View file

@ -1,5 +1,5 @@
from sqlalchemy.orm import Mapped
from sqlalchemy import select
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import select, BigInteger, func, update
from sqlalchemy.ext.asyncio import AsyncSession
from typing import AsyncIterator, Optional
@ -15,8 +15,14 @@ class Product(Base, CommonMixin):
carb: Mapped[float]
fat: Mapped[float]
fiber: Mapped[float]
hard_coded_calories: Mapped[Optional[float]] = None
barcode: Mapped[Optional[str]] = None
hard_coded_calories: Mapped[Optional[float]]
barcode: Mapped[Optional[str]]
usage_count_cached: Mapped[int] = mapped_column(
BigInteger,
default=0,
nullable=False,
)
@property
def calories(self) -> float:
@ -41,11 +47,13 @@ class Product(Base, CommonMixin):
if q:
q_list = q.split()
for qq in q_list:
query = query.filter(cls.name.ilike(f"%{qq.lower()}%"))
qq = "%" + "%".join(q_list) + "%"
query = query.filter(cls.name.ilike(f"%{qq.lower()}%"))
query = query.offset(offset).limit(limit)
stream = await session.stream_scalars(query.order_by(cls.id))
stream = await session.stream_scalars(
query.order_by(cls.usage_count_cached.desc())
)
async for row in stream:
yield row
@ -104,3 +112,28 @@ class Product(Base, CommonMixin):
session.add(product)
await session.flush()
return product
@classmethod
async def cache_usage_data(
cls,
session: AsyncSession,
) -> None:
from .entry import Entry
stmt = (
update(cls)
.where(
cls.id.in_(
select(Entry.product_id).where(Entry.processed == False).distinct()
)
)
.values(
usage_count_cached=select(func.count(Entry.id)).where(
Entry.product_id == cls.id,
Entry.processed == False,
)
)
)
await session.execute(stmt)
await Entry.mark_processed(session)

View file

@ -47,18 +47,18 @@ class RefreshToken(Base, CommonMixin):
:type token: str
:rtype: "RefreshToken"
"""
token = cls(
db_token = cls(
user_id=user_id,
token=token,
)
session.add(token)
session.add(db_token)
try:
await session.flush()
except Exception:
raise AssertionError("invalid token")
return token
return db_token
async def delete(self, session: AsyncSession) -> None:
"""delete.

View file

@ -15,3 +15,5 @@ class Settings(BaseSettings):
REFRESH_TOKEN_EXPIRE_DAYS: int = 30
ALLOWED_ORIGINS: List[str] = ["*"]
API_KEY: str

17
fooder/tasks_app.py Normal file
View file

@ -0,0 +1,17 @@
from fastapi import FastAPI
from .view.tasks import router
from .settings import Settings
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI(title="Fooder Tasks admininstrative API")
app.include_router(router)
app.add_middleware(
CORSMiddleware,
allow_origins=Settings().ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

View file

@ -1,7 +1,9 @@
from fooder.app import app
from fooder.tasks_app import app as tasks_app
from httpx import AsyncClient
import pytest
import httpx
import os
class Client:
@ -67,11 +69,29 @@ class Client:
return await self.client.patch(path, **kwargs)
class TasksClient(Client):
def __init__(self, authorized: bool = True):
super().__init__()
self.client = AsyncClient(app=tasks_app, base_url="http://testserver/api")
self.client.headers["Accept"] = "application/json"
if authorized:
self.client.headers["Authorization"] = "Bearer " + self.get_token()
def get_token(self) -> str:
return os.getenv("API_KEY")
@pytest.fixture
def unauthorized_client() -> Client:
return Client()
@pytest.fixture
def tasks_client() -> Client:
return TasksClient()
@pytest.fixture
async def client(user_payload) -> Client:
client = Client()

15
fooder/test/test_tasks.py Normal file
View file

@ -0,0 +1,15 @@
import pytest
@pytest.mark.anyio
async def test_cache_product_usage(client, tasks_client):
response = await client.get("product")
assert response.status_code == 200, response.json()
old_data = response.json()
response = await tasks_client.post("/cache_product_usage_data")
assert response.status_code == 200, response.json()
response = await client.get("product")
assert response.status_code == 200, response.json()
assert response.json() != old_data

View file

@ -48,6 +48,5 @@ def find(bar_code: str) -> Product:
fiber=data["product"]["nutriments"].get("fiber_100g", 0.0),
)
except Exception as e:
raise e
logger.error(e)
raise ParseError()

13
fooder/view/tasks.py Normal file
View file

@ -0,0 +1,13 @@
from fastapi import APIRouter, Depends, Request
from ..controller.tasks import CacheProductUsageData
router = APIRouter(prefix="/api", tags=["tasks"])
@router.post("/cache_product_usage_data")
async def create_user(
request: Request,
contoller: CacheProductUsageData = Depends(CacheProductUsageData),
):
return await contoller.call()

15
mypy.ini Normal file
View file

@ -0,0 +1,15 @@
[mypy]
plugins = sqlalchemy.ext.mypy.plugin,pydantic.mypy
exclude = .*/test/.*
pretty = True
platform = linux
warn_unused_configs = True
warn_unused_ignores = True
[mypy-fooder.controller.*]
disable_error_code=override

View file

@ -13,3 +13,7 @@ flake8
flake8-bugbear
httpx
aiosqlite
mypy
types-requests
types-passlib
types-python-jose

View file

@ -1,6 +1,6 @@
[flake8]
max-line-length = 80
extend-select = B950
extend-ignore = E203,E501,E701
extend-ignore = E203,E501,E701,E712
extend-immutable-calls =
Depends

View file

@ -13,6 +13,7 @@ export DB_URI="sqlite+aiosqlite:///test.db"
export ECHO_SQL=0
export SECRET_KEY=$(openssl rand -hex 32)
export REFRESH_SECRET_KEY=$(openssl rand -hex 32)
export API_KEY=$(openssl rand -hex 32)
python -m fooder --create-tables
@ -23,12 +24,17 @@ else
python -m pytest fooder --disable-warnings -sv
fi
status=$?
# unset test env values
unset POSTGRES_USER
unset POSTGRES_DATABASE
unset POSTGRES_PASSWORD
unset SECRET_KEY
unset REFRESH_SECRET
unset API_KEY
# if exists, remove test.db
[ -f test.db ] && rm test.db
exit $status