[tasks] basic implementation of the concept

This commit is contained in:
Piotr Domański 2024-05-21 11:11:47 +02:00
parent 969b57e993
commit 5d9c2e8bd8
13 changed files with 150 additions and 9 deletions

View file

@ -11,3 +11,5 @@ REFRESH_SECRET_KEY="${REFRESH_SECRET_KEY}" # generate with $ openssl rand -hex 3
ALGORITHM="HS256" ALGORITHM="HS256"
ACCESS_TOKEN_EXPIRE_MINUTES=30 ACCESS_TOKEN_EXPIRE_MINUTES=30
REFRESH_TOKEN_EXPIRE_DAYS=30 REFRESH_TOKEN_EXPIRE_DAYS=30
API_KEY="${API_KEY}" # generate with $ openssl rand -hex 32

View file

@ -101,3 +101,11 @@ async def get_current_user(
raise HTTPException(status_code=401, detail="Unathorized") raise HTTPException(status_code=401, detail="Unathorized")
return await User.get_by_username(session, username) return await User.get_by_username(session, username)
async def authorize_api_key(
session: AsyncSessionDependency, token: TokenDependency
) -> None:
if token == settings.API_KEY:
return None
raise HTTPException(status_code=401, detail="Unathorized")

View file

@ -2,12 +2,13 @@ from typing import Annotated, Any
from fastapi import Depends from fastapi import Depends
from sqlalchemy.ext.asyncio import async_sessionmaker from sqlalchemy.ext.asyncio import async_sessionmaker
from ..db import get_session from ..db import get_session
from ..auth import get_current_user from ..auth import get_current_user, authorize_api_key
from ..domain.user import User from ..domain.user import User
AsyncSession = Annotated[async_sessionmaker, Depends(get_session)] AsyncSession = Annotated[async_sessionmaker, Depends(get_session)]
UserDependency = Annotated[User, Depends(get_current_user)] UserDependency = Annotated[User, Depends(get_current_user)]
ApiKeyDependency = Annotated[bool, Depends(authorize_api_key)]
class BaseController: class BaseController:
@ -25,3 +26,8 @@ class AuthorizedController(BaseController):
def __init__(self, session: AsyncSession, user: UserDependency) -> None: def __init__(self, session: AsyncSession, user: UserDependency) -> None:
super().__init__(session) super().__init__(session)
self.user = user self.user = user
class TasksSessionController(BaseController):
def __init__(self, session: AsyncSession, api_key: ApiKeyDependency) -> None:
super().__init__(session)

View file

@ -0,0 +1,13 @@
from fastapi import HTTPException
from ..domain.product import Product as DBProduct
from .base import TasksSessionController
class CacheProductUsageData(TasksSessionController):
async def call(self) -> None:
async with self.async_session.begin() as session:
try:
await DBProduct.cache_usage_data(session)
await session.commit()
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))

View file

@ -1,8 +1,8 @@
from sqlalchemy.orm import Mapped, mapped_column, relationship, joinedload from sqlalchemy.orm import Mapped, mapped_column, relationship, joinedload
from sqlalchemy import ForeignKey, Integer, DateTime from sqlalchemy import ForeignKey, Integer, DateTime, Boolean
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy import select from sqlalchemy import select, update
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
@ -20,6 +20,7 @@ class Entry(Base, CommonMixin):
last_changed: Mapped[datetime] = mapped_column( last_changed: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
) )
processed: Mapped[bool] = mapped_column(Boolean, default=False)
@property @property
def amount(self) -> float: def amount(self) -> float:
@ -152,3 +153,12 @@ class Entry(Base, CommonMixin):
"""delete.""" """delete."""
await session.delete(self) await session.delete(self)
await session.flush() await session.flush()
@classmethod
async def mark_processed(
cls,
session: AsyncSession,
) -> None:
stmt = update(cls).where(cls.processed is False).values(processed=True)
await session.execute(stmt)

View file

@ -1,5 +1,5 @@
from sqlalchemy.orm import Mapped from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import select from sqlalchemy import select, BigInteger, func, update
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from typing import AsyncIterator, Optional from typing import AsyncIterator, Optional
@ -18,6 +18,12 @@ class Product(Base, CommonMixin):
hard_coded_calories: Mapped[Optional[float]] hard_coded_calories: Mapped[Optional[float]]
barcode: Mapped[Optional[str]] barcode: Mapped[Optional[str]]
usage_count_cached: Mapped[int] = mapped_column(
BigInteger,
default=0,
nullable=False,
)
@property @property
def calories(self) -> float: def calories(self) -> float:
"""calories. """calories.
@ -41,11 +47,13 @@ class Product(Base, CommonMixin):
if q: if q:
q_list = q.split() q_list = q.split()
for qq in q_list: qq = "%" + "%".join(q_list) + "%"
query = query.filter(cls.name.ilike(f"%{qq.lower()}%")) query = query.filter(cls.name.ilike(f"%{qq.lower()}%"))
query = query.offset(offset).limit(limit) query = query.offset(offset).limit(limit)
stream = await session.stream_scalars(query.order_by(cls.id)) stream = await session.stream_scalars(
query.order_by(cls.usage_count_cached.desc())
)
async for row in stream: async for row in stream:
yield row yield row
@ -104,3 +112,28 @@ class Product(Base, CommonMixin):
session.add(product) session.add(product)
await session.flush() await session.flush()
return product return product
@classmethod
async def cache_usage_data(
cls,
session: AsyncSession,
) -> None:
from .entry import Entry
stmt = (
update(cls)
.where(
cls.id.in_(
select(Entry.product_id).where(Entry.processed == False).distinct()
)
)
.values(
usage_count_cached=select(func.count(Entry.id)).where(
Entry.product_id == cls.id,
Entry.processed == False,
)
)
)
await session.execute(stmt)
await Entry.mark_processed(session)

View file

@ -15,3 +15,5 @@ class Settings(BaseSettings):
REFRESH_TOKEN_EXPIRE_DAYS: int = 30 REFRESH_TOKEN_EXPIRE_DAYS: int = 30
ALLOWED_ORIGINS: List[str] = ["*"] ALLOWED_ORIGINS: List[str] = ["*"]
API_KEY: str

17
fooder/tasks_app.py Normal file
View file

@ -0,0 +1,17 @@
from fastapi import FastAPI
from .view.tasks import router
from .settings import Settings
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI(title="Fooder Tasks admininstrative API")
app.include_router(router)
app.add_middleware(
CORSMiddleware,
allow_origins=Settings().ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

View file

@ -1,7 +1,9 @@
from fooder.app import app from fooder.app import app
from fooder.tasks_app import app as tasks_app
from httpx import AsyncClient from httpx import AsyncClient
import pytest import pytest
import httpx import httpx
import os
class Client: class Client:
@ -67,11 +69,29 @@ class Client:
return await self.client.patch(path, **kwargs) return await self.client.patch(path, **kwargs)
class TasksClient(Client):
def __init__(self, authorized: bool = True):
super().__init__()
self.client = AsyncClient(app=tasks_app, base_url="http://testserver/api")
self.client.headers["Accept"] = "application/json"
if authorized:
self.client.headers["Authorization"] = "Bearer " + self.get_token()
def get_token(self) -> str:
return os.getenv("API_KEY")
@pytest.fixture @pytest.fixture
def unauthorized_client() -> Client: def unauthorized_client() -> Client:
return Client() return Client()
@pytest.fixture
def tasks_client() -> Client:
return TasksClient()
@pytest.fixture @pytest.fixture
async def client(user_payload) -> Client: async def client(user_payload) -> Client:
client = Client() client = Client()

15
fooder/test/test_tasks.py Normal file
View file

@ -0,0 +1,15 @@
import pytest
@pytest.mark.anyio
async def test_cache_product_usage(client, tasks_client):
response = await client.get("product")
assert response.status_code == 200, response.json()
old_data = response.json()
response = await tasks_client.post("/cache_product_usage_data")
assert response.status_code == 200, response.json()
response = await client.get("product")
assert response.status_code == 200, response.json()
assert response.json() != old_data

13
fooder/view/tasks.py Normal file
View file

@ -0,0 +1,13 @@
from fastapi import APIRouter, Depends, Request
from ..controller.tasks import CacheProductUsageData
router = APIRouter(prefix="/api", tags=["tasks"])
@router.post("/cache_product_usage_data")
async def create_user(
request: Request,
contoller: CacheProductUsageData = Depends(CacheProductUsageData),
):
return await contoller.call()

View file

@ -1,6 +1,6 @@
[flake8] [flake8]
max-line-length = 80 max-line-length = 80
extend-select = B950 extend-select = B950
extend-ignore = E203,E501,E701 extend-ignore = E203,E501,E701,E712
extend-immutable-calls = extend-immutable-calls =
Depends Depends

View file

@ -13,6 +13,7 @@ export DB_URI="sqlite+aiosqlite:///test.db"
export ECHO_SQL=0 export ECHO_SQL=0
export SECRET_KEY=$(openssl rand -hex 32) export SECRET_KEY=$(openssl rand -hex 32)
export REFRESH_SECRET_KEY=$(openssl rand -hex 32) export REFRESH_SECRET_KEY=$(openssl rand -hex 32)
export API_KEY=$(openssl rand -hex 32)
python -m fooder --create-tables python -m fooder --create-tables
@ -31,6 +32,7 @@ unset POSTGRES_DATABASE
unset POSTGRES_PASSWORD unset POSTGRES_PASSWORD
unset SECRET_KEY unset SECRET_KEY
unset REFRESH_SECRET unset REFRESH_SECRET
unset API_KEY
# if exists, remove test.db # if exists, remove test.db
[ -f test.db ] && rm test.db [ -f test.db ] && rm test.db