controllers and commands, create_token works now

This commit is contained in:
Piotr Domański 2026-04-03 16:15:03 +02:00
parent 4950f0dfa4
commit 74ec8aa834
31 changed files with 326 additions and 304 deletions

View file

@ -13,9 +13,8 @@ if __name__ == "__main__":
import sqlalchemy
from sqlalchemy.orm import Session
from .domain import Base
from .settings import Settings
from .settings import settings
settings = Settings()
engine = sqlalchemy.create_engine(
settings.DB_URI.replace("+asyncpg", "").replace("+aiosqlite", "")
)

View file

@ -1,8 +1,11 @@
from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import logging
from .router import router
from .settings import Settings
from .settings import settings
from .exc import ApiException
app = FastAPI(title="Fooder")
app.include_router(router)
@ -10,8 +13,17 @@ app.include_router(router)
app.add_middleware(
CORSMiddleware,
allow_origins=Settings().ALLOWED_ORIGINS,
allow_origins=settings.ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.exception_handler(ApiException)
async def exception_handler_ErrorBase(_: Request, exc: ApiException):
headers = {"www-authenticate": "Bearer"} if exc.HTTP_CODE == 401 else {}
logging.exception(exc)
return JSONResponse(
status_code=exc.HTTP_CODE, content={"message": exc.message}, headers=headers
)

View file

@ -1,125 +0,0 @@
from datetime import datetime, timedelta
from typing import Annotated
from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from fastapi_users.password import PasswordHelper
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from .db import get_db_session
from .domain.token import RefreshToken
from .domain.user import User
from .settings import Settings
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token")
async def authenticate_user(
session: AsyncSession, username: str, password: str
) -> User | None:
user = await User.get_by_username(session, username)
if user is None:
return None
assert user is not None
if not verify_password(password, user.hashed_password):
return None
return user
async def verify_refresh_token(
session: AsyncSession, token: str
) -> RefreshToken | None:
try:
payload = jwt.decode(
token, settings.REFRESH_SECRET_KEY, algorithms=[settings.ALGORITHM]
)
sub = payload.get("sub")
if sub is None:
return None
if not isinstance(sub, str):
return None
username: str = str(sub)
if username is None:
return None
except JWTError:
return None
user = await User.get_by_username(session, username)
if user is None:
return None
assert user is not None
current_token = await RefreshToken.get_token(session, user.id, token)
if current_token is not None:
return current_token
return None
def create_access_token(user: User) -> str:
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode = {
"sub": user.username,
"exp": expire,
}
encoded_jwt = jwt.encode(
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
return encoded_jwt
async def create_refresh_token(session: AsyncSession, user: User) -> RefreshToken:
expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
to_encode = {
"sub": user.username,
"exp": expire,
}
encoded_jwt = jwt.encode(
to_encode, settings.REFRESH_SECRET_KEY, algorithm=settings.ALGORITHM
)
return await RefreshToken.create(session, token=encoded_jwt, user_id=user.id)
async def get_current_user(ssn: AsyncSessionDependency, token: TokenDependency) -> User:
async with ssn() as session:
try:
payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
sub = payload.get("sub")
if sub is None:
raise HTTPException(status_code=401, detail="Unathorized")
if not isinstance(sub, str):
raise HTTPException(status_code=401, detail="Unathorized")
username: str = str(sub)
if username is None:
raise HTTPException(status_code=401, detail="Unathorized")
except JWTError:
raise HTTPException(status_code=401, detail="Unathorized")
user = await User.get_by_username(session, username)
if user is None:
raise HTTPException(status_code=401, detail="Unathorized")
assert user is not None
return user

View file

View file

@ -1,7 +1,14 @@
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer
from typing import Callable
from datetime import datetime
from .db import get_db_session
from .domain import User
from .repository import Repository
from .utils.datetime import utc_now
from .utils.jwt import AccessToken
from .exc import Unauthorized
class Context:
@ -9,14 +16,26 @@ class Context:
Main API context, aggregating dependencies
"""
def __init__(self, repo: Repository) -> None:
def __init__(
self,
repo: Repository,
clock: Callable[[], datetime] = utc_now,
_user: User | None = None,
) -> None:
self.repo = repo
self.clock = clock
self._user = _user
@property
def user(self) -> User:
if self._user is None:
raise Unauthorized()
return self._user
class ContextDependency:
"""
Configurable context dependecy. Allows for shared interface configuring
method required dependencies
Context dependecy
"""
def __init__(
@ -29,3 +48,21 @@ class ContextDependency:
session: AsyncSession = Depends(get_db_session),
):
return Context(repo=Repository(session))
class AuthContextDependency:
"""
Context dependecy for authorized endpionts
"""
async def __call__(
self,
token: str = Depends(OAuth2PasswordBearer(tokenUrl="/token")),
session: AsyncSession = Depends(get_db_session),
) -> Context:
access_token = AccessToken.decode(token)
repo = Repository(session)
user = await repo.user.get(User.id == access_token.sub)
if user is None:
raise Unauthorized()
return Context(repo=repo, _user=user)

View file

@ -0,0 +1,2 @@
from .user import UserController
from .token import TokenController

View file

@ -1,33 +1,17 @@
from typing import Annotated, Any
from fastapi import Depends
from sqlalchemy.ext.asyncio import async_sessionmaker
from ..auth import authorize_api_key, get_current_user
from ..db import get_db_session, AsyncSession
from ..domain.user import User
UserDependency = Annotated[User, Depends(get_current_user)]
ApiKeyDependency = Annotated[None, Depends(authorize_api_key)]
from ..context import Context
from typing import TypeVar, Generic
from sqlalchemy import BinaryExpression
class BaseController:
def __init__(self, session: AsyncSession) -> None:
self.session = session
async def call(self, *args, **kwargs) -> Any:
raise NotImplementedError
async def __call__(self, *args, **kwargs) -> Any:
return await self.call(*args, **kwargs)
T = TypeVar("T")
class AuthorizedController(BaseController):
def __init__(self, session: AsyncSession, user: UserDependency) -> None:
super().__init__(session)
self.user = user
class ControllerBase:
def __init__(self, ctx: Context) -> None:
self.ctx = ctx
class TasksSessionController(BaseController):
def __init__(self, session: AsyncSession, api_key: ApiKeyDependency) -> None:
super().__init__(session)
class ModelController(Generic[T], ControllerBase):
def __init__(self, ctx: Context, obj: T):
super().__init__(ctx)
self.obj = obj

View file

@ -1,58 +1,25 @@
from fastapi import HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from .base import ControllerBase
from ..context import Context
from ..utils.jwt import Token, AccessToken, RefreshToken
from typing import Type, TypeVar
from datetime import datetime
from ..auth import (
authenticate_user,
create_access_token,
create_refresh_token,
verify_refresh_token,
)
from ..domain.user import User as DBUser
from ..model.token import RefreshTokenPayload, Token
from .base import BaseController
T = TypeVar("T", bound=Token)
class CreateToken(BaseController):
async def call(self, content: OAuth2PasswordRequestForm) -> Token:
async with self.async_session.begin() as session:
user = await authenticate_user(session, content.username, content.password)
class TokenController(ControllerBase):
def __init__(self, ctx: Context, entity_id: int) -> None:
super().__init__(ctx)
self.entity_id = entity_id
if user is None:
raise HTTPException(
status_code=401, detail="Invalid username or password"
)
def generate_token(self, token_cls: Type[T], now: datetime) -> T:
return token_cls(exp=token_cls.calculate_exp(now), sub=self.entity_id)
refresh_token = await create_refresh_token(session, user)
access_token = create_access_token(user)
def generate_refresh_token(self, now: datetime) -> RefreshToken:
return self.generate_token(RefreshToken, now)
return Token(
access_token=access_token,
refresh_token=refresh_token.token,
token_type="bearer",
)
def generate_access_token(self, now: datetime) -> AccessToken:
return self.generate_token(AccessToken, now)
class RefreshToken(BaseController):
async def call(self, content: RefreshTokenPayload) -> Token:
async with self.async_session.begin() as session:
current_token = await verify_refresh_token(session, content.refresh_token)
if current_token is None:
raise HTTPException(status_code=401, detail="Invalid token")
user = await DBUser.get(session, current_token.user_id)
if user is None:
raise HTTPException(status_code=401, detail="Invalid token")
assert user is not None
await current_token.delete(session)
refresh_token = await create_refresh_token(session, user)
access_token = create_access_token(user)
return Token(
access_token=access_token,
refresh_token=refresh_token.token,
token_type="bearer",
)
def generate_token_pair(self, now: datetime) -> tuple[AccessToken, RefreshToken]:
return (self.generate_access_token(now), self.generate_refresh_token(now))

View file

@ -1,19 +1,24 @@
from fastapi import HTTPException
from ..domain.user import User as DBUser
from ..model.user import CreateUserPayload, User
from .base import BaseController
from .base import ModelController
from ..domain import User
from ..context import Context
from ..exc import Unauthorized
from .token import TokenController
class CreateUser(BaseController):
async def call(self, content: CreateUserPayload) -> User:
async with self.async_session.begin() as session:
try:
user = await DBUser.create(
session,
content.username,
content.password,
)
return User.from_orm(user)
except AssertionError as e:
raise HTTPException(status_code=400, detail=e.args[0])
class UserController(ModelController[User]):
@classmethod
async def session_start(
cls,
ctx: Context,
username: str,
password: str,
) -> "UserController":
obj = await ctx.repo.user.get(User.username == username)
if obj is None or not obj.verify_password(password):
raise Unauthorized()
return cls(ctx, obj)
def token_ctrl(self) -> TokenController:
return TokenController(ctx=self.ctx, entity_id=self.obj.id)

View file

View file

@ -0,0 +1,33 @@
from typing import Annotated, Any
from fastapi import Depends
from sqlalchemy.ext.asyncio import async_sessionmaker
from ..auth import authorize_api_key, get_current_user
from ..db import get_db_session, AsyncSession
from ..domain.user import User
UserDependency = Annotated[User, Depends(get_current_user)]
ApiKeyDependency = Annotated[None, Depends(authorize_api_key)]
class BaseController:
def __init__(self, session: AsyncSession) -> None:
self.session = session
async def call(self, *args, **kwargs) -> Any:
raise NotImplementedError
async def __call__(self, *args, **kwargs) -> Any:
return await self.call(*args, **kwargs)
class AuthorizedController(BaseController):
def __init__(self, session: AsyncSession, user: UserDependency) -> None:
super().__init__(session)
self.user = user
class TasksSessionController(BaseController):
def __init__(self, session: AsyncSession, api_key: ApiKeyDependency) -> None:
super().__init__(session)

View file

@ -0,0 +1,58 @@
from fastapi import HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from ..auth import (
authenticate_user,
create_access_token,
create_refresh_token,
verify_refresh_token,
)
from ..domain.user import User as DBUser
from ..model.token import RefreshTokenPayload, Token
from .base import BaseController
class CreateToken(BaseController):
async def call(self, content: OAuth2PasswordRequestForm) -> Token:
async with self.async_session.begin() as session:
user = await authenticate_user(session, content.username, content.password)
if user is None:
raise HTTPException(
status_code=401, detail="Invalid username or password"
)
refresh_token = await create_refresh_token(session, user)
access_token = create_access_token(user)
return Token(
access_token=access_token,
refresh_token=refresh_token.token,
token_type="bearer",
)
class RefreshToken(BaseController):
async def call(self, content: RefreshTokenPayload) -> Token:
async with self.async_session.begin() as session:
current_token = await verify_refresh_token(session, content.refresh_token)
if current_token is None:
raise HTTPException(status_code=401, detail="Invalid token")
user = await DBUser.get(session, current_token.user_id)
if user is None:
raise HTTPException(status_code=401, detail="Invalid token")
assert user is not None
await current_token.delete(session)
refresh_token = await create_refresh_token(session, user)
access_token = create_access_token(user)
return Token(
access_token=access_token,
refresh_token=refresh_token.token,
token_type="bearer",
)

View file

@ -0,0 +1,19 @@
from fastapi import HTTPException
from ..domain.user import User as DBUser
from ..model.user import CreateUserPayload, User
from .base import BaseController
class CreateUser(BaseController):
async def call(self, content: CreateUserPayload) -> User:
async with self.async_session.begin() as session:
try:
user = await DBUser.create(
session,
content.username,
content.password,
)
return User.from_orm(user)
except AssertionError as e:
raise HTTPException(status_code=400, detail=e.args[0])

19
fooder/exc.py Normal file
View file

@ -0,0 +1,19 @@
from typing import ClassVar
class ApiException(Exception):
HTTP_CODE: ClassVar[int]
MESSAGE: ClassVar[str]
def __init__(self, message: str | None = None) -> None:
self.message = message or self.MESSAGE
class NotFound(ApiException):
HTTP_CODE = 404
MESSAGE = "Not found"
class Unauthorized(ApiException):
HTTP_CODE = 401
MESSAGE = "Unathorized"

View file

@ -1,15 +1,11 @@
from pydantic import BaseModel
class Token(BaseModel):
class TokenResponse(BaseModel):
access_token: str
refresh_token: str
token_type: str = "bearer"
class TokenData(BaseModel):
username: str | None = None
class RefreshTokenPayload(BaseModel):
class RefreshTokenRequest(BaseModel):
refresh_token: str

View file

@ -1,14 +1,10 @@
from typing import TypeVar, Generic, Type, Sequence
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import (
select,
update as sa_update,
delete as sa_delete,
BinaryExpression,
)
from sqlalchemy import select, delete as sa_delete, ColumnElement
from sqlalchemy.sql import Select
from ..domain import Base
T = TypeVar("T")
T = TypeVar("T", bound=Base)
class RepositoryBase(Generic[T]):
@ -16,7 +12,7 @@ class RepositoryBase(Generic[T]):
self.model = model
self.session = session
def _build_select(self, *expressions: BinaryExpression) -> Select[tuple[T]]:
def _build_select(self, *expressions: ColumnElement) -> Select[tuple[T]]:
stmt = select(self.model)
if expressions:
@ -24,12 +20,12 @@ class RepositoryBase(Generic[T]):
return stmt
async def get(self, *expressions: BinaryExpression) -> T | None:
async def get(self, *expressions: ColumnElement) -> T | None:
stmt = self._build_select(*expressions)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def list(self, *expressions: BinaryExpression) -> Sequence[T]:
async def list(self, *expressions: ColumnElement) -> Sequence[T]:
stmt = self._build_select(*expressions)
result = await self.session.execute(stmt)
return result.scalars().all()
@ -40,7 +36,7 @@ class RepositoryBase(Generic[T]):
await self.session.refresh(obj)
return obj
async def delete(self, *expressions: BinaryExpression):
async def delete(self, *expressions: ColumnElement):
stmt = sa_delete(self.model)
if expressions:

View file

@ -1,18 +1,6 @@
from fastapi import APIRouter
# from .view.diary import router as diary_router
# from .view.entry import router as entry_router
# from .view.meal import router as meal_router
# from .view.preset import router as preset_router
# from .view.product import router as product_router
from .view.token import router as token_router
# from .view.user import router as user_router
router = APIRouter(prefix="/api")
# router.include_router(product_router, prefix="/product", tags=["product"])
# router.include_router(diary_router, prefix="/diary", tags=["diary"])
# router.include_router(meal_router, prefix="/meal", tags=["meal"])
# router.include_router(entry_router, prefix="/entry", tags=["entry"])
router.include_router(token_router, prefix="/token", tags=["token"])
# router.include_router(user_router, prefix="/user", tags=["user"])
# router.include_router(preset_router, prefix="/preset", tags=["preset"])

View file

@ -1,17 +0,0 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from .settings import Settings
from .view.tasks import router
app = FastAPI(title="Fooder Tasks admininstrative API")
app.include_router(router)
app.add_middleware(
CORSMiddleware,
allow_origins=Settings().ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

View file

View file

@ -0,0 +1,6 @@
from fooder.controller import TokenController
def test_token_ctrl_generates_token(ctx):
token_ctrl = TokenController(ctx, 1)
token_ctrl.generate_token_pair(ctx.clock())

11
fooder/test/test_exc.py Normal file
View file

@ -0,0 +1,11 @@
from fooder.exc import ApiException
class TestException(ApiException):
HTTP_CODE = 0
MESSAGE = "test"
def test_exc_message():
assert TestException().message == TestException.MESSAGE
assert TestException("other message").message == "other message"

View file

@ -1,7 +1,9 @@
from datetime import datetime, timedelta, timezone
import pytest
from jose import JWTError
from jose import jwt
from typing import Literal
from fooder.exc import Unauthorized
from fooder.utils.jwt import AccessToken, RefreshToken, Token
@ -9,6 +11,7 @@ PAST = datetime(2000, 1, 1, tzinfo=timezone.utc)
class WrongKeyToken(Token):
token_type: Literal["test-type"] = "test-type"
secret_key = "wrong-secret"
expire_delta = timedelta(minutes=30)
@ -29,15 +32,22 @@ class TestAccessToken:
now = datetime.now(timezone.utc)
token = WrongKeyToken(exp=WrongKeyToken.calculate_exp(now), sub=1)
with pytest.raises(JWTError):
with pytest.raises(Unauthorized):
AccessToken.decode(token.encode())
def test_decode_expired_raises(self):
token = AccessToken(exp=PAST, sub=1)
with pytest.raises(JWTError):
with pytest.raises(Unauthorized):
AccessToken.decode(token.encode())
def test_encoded_fields(self):
now = datetime.now(timezone.utc)
token = AccessToken(exp=AccessToken.calculate_exp(now), sub=42)
payload = jwt.decode(token.encode(), "", options={"verify_signature": False})
assert "secret_key" not in payload
assert "expire_delta" not in payload
class TestRefreshToken:
def test_encode_decode_roundtrip(self):
@ -55,12 +65,19 @@ class TestRefreshToken:
now = datetime.now(timezone.utc)
token = RefreshToken(exp=RefreshToken.calculate_exp(now), sub=1)
with pytest.raises(JWTError):
with pytest.raises(Unauthorized):
AccessToken.decode(token.encode())
def test_access_token_not_decodable_as_refresh_token(self):
now = datetime.now(timezone.utc)
token = AccessToken(exp=AccessToken.calculate_exp(now), sub=1)
with pytest.raises(JWTError):
with pytest.raises(Unauthorized):
RefreshToken.decode(token.encode())
def test_encoded_fields(self):
now = datetime.now(timezone.utc)
token = AccessToken(exp=RefreshToken.calculate_exp(now), sub=42)
payload = jwt.decode(token.encode(), "", options={"verify_signature": False})
assert "secret_key" not in payload
assert "expire_delta" not in payload

5
fooder/utils/datetime.py Normal file
View file

@ -0,0 +1,5 @@
from datetime import datetime, timezone
def utc_now() -> datetime:
return datetime.now(timezone.utc)

View file

@ -1,11 +1,14 @@
from jose import jwt
from jose import jwt, JOSEError
from pydantic import BaseModel
from datetime import timedelta, datetime
from typing import ClassVar
from typing import ClassVar, Literal
import logging
from ..settings import settings
from ..exc import Unauthorized
class Token(BaseModel):
token_type: str
exp: datetime
sub: int
@ -18,7 +21,13 @@ class Token(BaseModel):
@classmethod
def decode(cls, jwt_token: str | bytes) -> "Token":
data = jwt.decode(jwt_token, cls.secret_key, algorithms=[settings.ALGORITHM])
try:
data = jwt.decode(
jwt_token, cls.secret_key, algorithms=[settings.ALGORITHM]
)
except JOSEError as e:
logging.error(e)
raise Unauthorized()
return cls(**data)
def encode(self) -> str:
@ -28,10 +37,12 @@ class Token(BaseModel):
class AccessToken(Token):
token_type: Literal["access"] = "access"
secret_key = settings.SECRET_KEY
expire_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
class RefreshToken(Token):
token_type: Literal["refresh"] = "refresh"
secret_key = settings.REFRESH_SECRET_KEY
expire_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)

View file

@ -1,39 +1,38 @@
from typing import Annotated
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends
from fastapi.security import OAuth2PasswordRequestForm
from datetime import datetime
from ..model.token import RefreshTokenPayload, Token
from ..model.token import TokenResponse, RefreshTokenRequest
from ..context import ContextDependency, Context
from ..utils.jwt import AccessToken, RefreshToken
from ..domain import User
from ..controller import UserController
router = APIRouter(tags=["token"])
@router.post("", response_model=Token)
async def create_token(
data: Annotated[OAuth2PasswordRequestForm, Depends()],
ctx: Context = Depends(ContextDependency()),
):
user = await ctx.repo.user.get(User.username == data.username)
if user is None or not user.verify_password(data.password):
raise HTTPException(status_code=401, detail="Unathorized")
now = datetime.now(timezone.utc)
access_token = AccessToken(sub=user.id, exp=AccessToken.calculate_exp(now))
refresh_token = RefreshToken(sub=user.id, exp=RefreshToken.calculate_exp(now))
return Token(
def gen_token_response(user_ctrl: UserController, now: datetime) -> TokenResponse:
token_ctrl = user_ctrl.token_ctrl()
access_token, refresh_token = token_ctrl.generate_token_pair(now)
return TokenResponse(
access_token=access_token.encode(),
refresh_token=refresh_token.encode(),
)
@router.post("/refresh", response_model=Token)
async def refresh_token(
data: RefreshTokenPayload,
@router.post("", response_model=TokenResponse)
async def token_create(
data: Annotated[OAuth2PasswordRequestForm, Depends()],
ctx: Context = Depends(ContextDependency()),
) -> TokenResponse:
now = ctx.clock()
user_ctrl = await UserController.session_start(ctx, data.username, data.password)
return gen_token_response(user_ctrl, now)
@router.post("/refresh", response_model=TokenResponse)
async def token_refresh(
data: RefreshTokenRequest,
ctx: Context = Depends(ContextDependency()),
):
pass