From 74ec8aa83474cadd50b6856af93f51b979a098c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20Doma=C5=84ski?= Date: Fri, 3 Apr 2026 16:15:03 +0200 Subject: [PATCH] controllers and commands, create_token works now --- fooder/__main__.py | 3 +- fooder/app.py | 18 ++- fooder/auth.py | 125 ------------------ fooder/command/__init__.py | 0 fooder/context.py | 43 +++++- fooder/controller/__init__.py | 2 + fooder/controller/base.py | 38 ++---- fooder/controller/token.py | 69 +++------- fooder/controller/user.py | 39 +++--- fooder/controller_old/__init__.py | 0 fooder/controller_old/base.py | 33 +++++ .../{controller => controller_old}/diary.py | 0 .../{controller => controller_old}/entry.py | 0 fooder/{controller => controller_old}/meal.py | 0 .../{controller => controller_old}/preset.py | 0 .../{controller => controller_old}/product.py | 0 .../{controller => controller_old}/tasks.py | 0 fooder/controller_old/token.py | 58 ++++++++ fooder/controller_old/user.py | 19 +++ fooder/exc.py | 19 +++ fooder/model/token.py | 8 +- fooder/repository/base.py | 18 +-- fooder/router.py | 12 -- fooder/tasks_app.py | 17 --- fooder/test/controller/__init__.py | 0 fooder/test/controller/test_token.py | 6 + fooder/test/test_exc.py | 11 ++ fooder/test/utils/test_jwt.py | 27 +++- fooder/utils/datetime.py | 5 + fooder/utils/jwt.py | 17 ++- fooder/view/token.py | 43 +++--- 31 files changed, 326 insertions(+), 304 deletions(-) delete mode 100644 fooder/auth.py create mode 100644 fooder/command/__init__.py create mode 100644 fooder/controller_old/__init__.py create mode 100644 fooder/controller_old/base.py rename fooder/{controller => controller_old}/diary.py (100%) rename fooder/{controller => controller_old}/entry.py (100%) rename fooder/{controller => controller_old}/meal.py (100%) rename fooder/{controller => controller_old}/preset.py (100%) rename fooder/{controller => controller_old}/product.py (100%) rename fooder/{controller => controller_old}/tasks.py (100%) create mode 100644 fooder/controller_old/token.py create mode 100644 fooder/controller_old/user.py create mode 100644 fooder/exc.py delete mode 100644 fooder/tasks_app.py create mode 100644 fooder/test/controller/__init__.py create mode 100644 fooder/test/controller/test_token.py create mode 100644 fooder/test/test_exc.py create mode 100644 fooder/utils/datetime.py diff --git a/fooder/__main__.py b/fooder/__main__.py index d2adece..71c8662 100644 --- a/fooder/__main__.py +++ b/fooder/__main__.py @@ -13,9 +13,8 @@ if __name__ == "__main__": import sqlalchemy from sqlalchemy.orm import Session from .domain import Base - from .settings import Settings + from .settings import settings - settings = Settings() engine = sqlalchemy.create_engine( settings.DB_URI.replace("+asyncpg", "").replace("+aiosqlite", "") ) diff --git a/fooder/app.py b/fooder/app.py index 5f414d7..c970475 100644 --- a/fooder/app.py +++ b/fooder/app.py @@ -1,8 +1,11 @@ -from fastapi import FastAPI +from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +import logging from .router import router -from .settings import Settings +from .settings import settings +from .exc import ApiException app = FastAPI(title="Fooder") app.include_router(router) @@ -10,8 +13,17 @@ app.include_router(router) app.add_middleware( CORSMiddleware, - allow_origins=Settings().ALLOWED_ORIGINS, + allow_origins=settings.ALLOWED_ORIGINS, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) + + +@app.exception_handler(ApiException) +async def exception_handler_ErrorBase(_: Request, exc: ApiException): + headers = {"www-authenticate": "Bearer"} if exc.HTTP_CODE == 401 else {} + logging.exception(exc) + return JSONResponse( + status_code=exc.HTTP_CODE, content={"message": exc.message}, headers=headers + ) diff --git a/fooder/auth.py b/fooder/auth.py deleted file mode 100644 index 5be919d..0000000 --- a/fooder/auth.py +++ /dev/null @@ -1,125 +0,0 @@ -from datetime import datetime, timedelta -from typing import Annotated - -from fastapi import Depends, HTTPException -from fastapi.security import OAuth2PasswordBearer -from fastapi_users.password import PasswordHelper -from jose import JWTError, jwt -from passlib.context import CryptContext -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker - -from .db import get_db_session -from .domain.token import RefreshToken -from .domain.user import User -from .settings import Settings - -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token") - - -async def authenticate_user( - session: AsyncSession, username: str, password: str -) -> User | None: - user = await User.get_by_username(session, username) - - if user is None: - return None - - assert user is not None - - if not verify_password(password, user.hashed_password): - return None - - return user - - -async def verify_refresh_token( - session: AsyncSession, token: str -) -> RefreshToken | None: - try: - payload = jwt.decode( - token, settings.REFRESH_SECRET_KEY, algorithms=[settings.ALGORITHM] - ) - sub = payload.get("sub") - - if sub is None: - return None - - if not isinstance(sub, str): - return None - - username: str = str(sub) - - if username is None: - return None - - except JWTError: - return None - - user = await User.get_by_username(session, username) - - if user is None: - return None - - assert user is not None - - current_token = await RefreshToken.get_token(session, user.id, token) - - if current_token is not None: - return current_token - - return None - - -def create_access_token(user: User) -> str: - expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) - to_encode = { - "sub": user.username, - "exp": expire, - } - encoded_jwt = jwt.encode( - to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM - ) - return encoded_jwt - - -async def create_refresh_token(session: AsyncSession, user: User) -> RefreshToken: - expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) - to_encode = { - "sub": user.username, - "exp": expire, - } - encoded_jwt = jwt.encode( - to_encode, settings.REFRESH_SECRET_KEY, algorithm=settings.ALGORITHM - ) - return await RefreshToken.create(session, token=encoded_jwt, user_id=user.id) - - -async def get_current_user(ssn: AsyncSessionDependency, token: TokenDependency) -> User: - async with ssn() as session: - try: - payload = jwt.decode( - token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] - ) - sub = payload.get("sub") - - if sub is None: - raise HTTPException(status_code=401, detail="Unathorized") - - if not isinstance(sub, str): - raise HTTPException(status_code=401, detail="Unathorized") - - username: str = str(sub) - - if username is None: - raise HTTPException(status_code=401, detail="Unathorized") - - except JWTError: - raise HTTPException(status_code=401, detail="Unathorized") - - user = await User.get_by_username(session, username) - - if user is None: - raise HTTPException(status_code=401, detail="Unathorized") - - assert user is not None - return user diff --git a/fooder/command/__init__.py b/fooder/command/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fooder/context.py b/fooder/context.py index 566453b..febfbf0 100644 --- a/fooder/context.py +++ b/fooder/context.py @@ -1,7 +1,14 @@ from sqlalchemy.ext.asyncio import AsyncSession from fastapi import Depends +from fastapi.security import OAuth2PasswordBearer +from typing import Callable +from datetime import datetime from .db import get_db_session +from .domain import User from .repository import Repository +from .utils.datetime import utc_now +from .utils.jwt import AccessToken +from .exc import Unauthorized class Context: @@ -9,14 +16,26 @@ class Context: Main API context, aggregating dependencies """ - def __init__(self, repo: Repository) -> None: + def __init__( + self, + repo: Repository, + clock: Callable[[], datetime] = utc_now, + _user: User | None = None, + ) -> None: self.repo = repo + self.clock = clock + self._user = _user + + @property + def user(self) -> User: + if self._user is None: + raise Unauthorized() + return self._user class ContextDependency: """ - Configurable context dependecy. Allows for shared interface configuring - method required dependencies + Context dependecy """ def __init__( @@ -29,3 +48,21 @@ class ContextDependency: session: AsyncSession = Depends(get_db_session), ): return Context(repo=Repository(session)) + + +class AuthContextDependency: + """ + Context dependecy for authorized endpionts + """ + + async def __call__( + self, + token: str = Depends(OAuth2PasswordBearer(tokenUrl="/token")), + session: AsyncSession = Depends(get_db_session), + ) -> Context: + access_token = AccessToken.decode(token) + repo = Repository(session) + user = await repo.user.get(User.id == access_token.sub) + if user is None: + raise Unauthorized() + return Context(repo=repo, _user=user) diff --git a/fooder/controller/__init__.py b/fooder/controller/__init__.py index e69de29..ef2a6a2 100644 --- a/fooder/controller/__init__.py +++ b/fooder/controller/__init__.py @@ -0,0 +1,2 @@ +from .user import UserController +from .token import TokenController diff --git a/fooder/controller/base.py b/fooder/controller/base.py index a9948b8..ae326d3 100644 --- a/fooder/controller/base.py +++ b/fooder/controller/base.py @@ -1,33 +1,17 @@ -from typing import Annotated, Any - -from fastapi import Depends -from sqlalchemy.ext.asyncio import async_sessionmaker - -from ..auth import authorize_api_key, get_current_user -from ..db import get_db_session, AsyncSession -from ..domain.user import User - -UserDependency = Annotated[User, Depends(get_current_user)] -ApiKeyDependency = Annotated[None, Depends(authorize_api_key)] +from ..context import Context +from typing import TypeVar, Generic +from sqlalchemy import BinaryExpression -class BaseController: - def __init__(self, session: AsyncSession) -> None: - self.session = session - - async def call(self, *args, **kwargs) -> Any: - raise NotImplementedError - - async def __call__(self, *args, **kwargs) -> Any: - return await self.call(*args, **kwargs) +T = TypeVar("T") -class AuthorizedController(BaseController): - def __init__(self, session: AsyncSession, user: UserDependency) -> None: - super().__init__(session) - self.user = user +class ControllerBase: + def __init__(self, ctx: Context) -> None: + self.ctx = ctx -class TasksSessionController(BaseController): - def __init__(self, session: AsyncSession, api_key: ApiKeyDependency) -> None: - super().__init__(session) +class ModelController(Generic[T], ControllerBase): + def __init__(self, ctx: Context, obj: T): + super().__init__(ctx) + self.obj = obj diff --git a/fooder/controller/token.py b/fooder/controller/token.py index b0b3445..dc437ce 100644 --- a/fooder/controller/token.py +++ b/fooder/controller/token.py @@ -1,58 +1,25 @@ -from fastapi import HTTPException -from fastapi.security import OAuth2PasswordRequestForm +from .base import ControllerBase +from ..context import Context +from ..utils.jwt import Token, AccessToken, RefreshToken +from typing import Type, TypeVar +from datetime import datetime -from ..auth import ( - authenticate_user, - create_access_token, - create_refresh_token, - verify_refresh_token, -) -from ..domain.user import User as DBUser -from ..model.token import RefreshTokenPayload, Token -from .base import BaseController +T = TypeVar("T", bound=Token) -class CreateToken(BaseController): - async def call(self, content: OAuth2PasswordRequestForm) -> Token: - async with self.async_session.begin() as session: - user = await authenticate_user(session, content.username, content.password) +class TokenController(ControllerBase): + def __init__(self, ctx: Context, entity_id: int) -> None: + super().__init__(ctx) + self.entity_id = entity_id - if user is None: - raise HTTPException( - status_code=401, detail="Invalid username or password" - ) + def generate_token(self, token_cls: Type[T], now: datetime) -> T: + return token_cls(exp=token_cls.calculate_exp(now), sub=self.entity_id) - refresh_token = await create_refresh_token(session, user) - access_token = create_access_token(user) + def generate_refresh_token(self, now: datetime) -> RefreshToken: + return self.generate_token(RefreshToken, now) - return Token( - access_token=access_token, - refresh_token=refresh_token.token, - token_type="bearer", - ) + def generate_access_token(self, now: datetime) -> AccessToken: + return self.generate_token(AccessToken, now) - -class RefreshToken(BaseController): - async def call(self, content: RefreshTokenPayload) -> Token: - async with self.async_session.begin() as session: - current_token = await verify_refresh_token(session, content.refresh_token) - - if current_token is None: - raise HTTPException(status_code=401, detail="Invalid token") - - user = await DBUser.get(session, current_token.user_id) - - if user is None: - raise HTTPException(status_code=401, detail="Invalid token") - - assert user is not None - await current_token.delete(session) - - refresh_token = await create_refresh_token(session, user) - access_token = create_access_token(user) - - return Token( - access_token=access_token, - refresh_token=refresh_token.token, - token_type="bearer", - ) + def generate_token_pair(self, now: datetime) -> tuple[AccessToken, RefreshToken]: + return (self.generate_access_token(now), self.generate_refresh_token(now)) diff --git a/fooder/controller/user.py b/fooder/controller/user.py index de48c9d..c3f5f1a 100644 --- a/fooder/controller/user.py +++ b/fooder/controller/user.py @@ -1,19 +1,24 @@ -from fastapi import HTTPException - -from ..domain.user import User as DBUser -from ..model.user import CreateUserPayload, User -from .base import BaseController +from .base import ModelController +from ..domain import User +from ..context import Context +from ..exc import Unauthorized +from .token import TokenController -class CreateUser(BaseController): - async def call(self, content: CreateUserPayload) -> User: - async with self.async_session.begin() as session: - try: - user = await DBUser.create( - session, - content.username, - content.password, - ) - return User.from_orm(user) - except AssertionError as e: - raise HTTPException(status_code=400, detail=e.args[0]) +class UserController(ModelController[User]): + @classmethod + async def session_start( + cls, + ctx: Context, + username: str, + password: str, + ) -> "UserController": + obj = await ctx.repo.user.get(User.username == username) + + if obj is None or not obj.verify_password(password): + raise Unauthorized() + + return cls(ctx, obj) + + def token_ctrl(self) -> TokenController: + return TokenController(ctx=self.ctx, entity_id=self.obj.id) diff --git a/fooder/controller_old/__init__.py b/fooder/controller_old/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fooder/controller_old/base.py b/fooder/controller_old/base.py new file mode 100644 index 0000000..a9948b8 --- /dev/null +++ b/fooder/controller_old/base.py @@ -0,0 +1,33 @@ +from typing import Annotated, Any + +from fastapi import Depends +from sqlalchemy.ext.asyncio import async_sessionmaker + +from ..auth import authorize_api_key, get_current_user +from ..db import get_db_session, AsyncSession +from ..domain.user import User + +UserDependency = Annotated[User, Depends(get_current_user)] +ApiKeyDependency = Annotated[None, Depends(authorize_api_key)] + + +class BaseController: + def __init__(self, session: AsyncSession) -> None: + self.session = session + + async def call(self, *args, **kwargs) -> Any: + raise NotImplementedError + + async def __call__(self, *args, **kwargs) -> Any: + return await self.call(*args, **kwargs) + + +class AuthorizedController(BaseController): + def __init__(self, session: AsyncSession, user: UserDependency) -> None: + super().__init__(session) + self.user = user + + +class TasksSessionController(BaseController): + def __init__(self, session: AsyncSession, api_key: ApiKeyDependency) -> None: + super().__init__(session) diff --git a/fooder/controller/diary.py b/fooder/controller_old/diary.py similarity index 100% rename from fooder/controller/diary.py rename to fooder/controller_old/diary.py diff --git a/fooder/controller/entry.py b/fooder/controller_old/entry.py similarity index 100% rename from fooder/controller/entry.py rename to fooder/controller_old/entry.py diff --git a/fooder/controller/meal.py b/fooder/controller_old/meal.py similarity index 100% rename from fooder/controller/meal.py rename to fooder/controller_old/meal.py diff --git a/fooder/controller/preset.py b/fooder/controller_old/preset.py similarity index 100% rename from fooder/controller/preset.py rename to fooder/controller_old/preset.py diff --git a/fooder/controller/product.py b/fooder/controller_old/product.py similarity index 100% rename from fooder/controller/product.py rename to fooder/controller_old/product.py diff --git a/fooder/controller/tasks.py b/fooder/controller_old/tasks.py similarity index 100% rename from fooder/controller/tasks.py rename to fooder/controller_old/tasks.py diff --git a/fooder/controller_old/token.py b/fooder/controller_old/token.py new file mode 100644 index 0000000..b0b3445 --- /dev/null +++ b/fooder/controller_old/token.py @@ -0,0 +1,58 @@ +from fastapi import HTTPException +from fastapi.security import OAuth2PasswordRequestForm + +from ..auth import ( + authenticate_user, + create_access_token, + create_refresh_token, + verify_refresh_token, +) +from ..domain.user import User as DBUser +from ..model.token import RefreshTokenPayload, Token +from .base import BaseController + + +class CreateToken(BaseController): + async def call(self, content: OAuth2PasswordRequestForm) -> Token: + async with self.async_session.begin() as session: + user = await authenticate_user(session, content.username, content.password) + + if user is None: + raise HTTPException( + status_code=401, detail="Invalid username or password" + ) + + refresh_token = await create_refresh_token(session, user) + access_token = create_access_token(user) + + return Token( + access_token=access_token, + refresh_token=refresh_token.token, + token_type="bearer", + ) + + +class RefreshToken(BaseController): + async def call(self, content: RefreshTokenPayload) -> Token: + async with self.async_session.begin() as session: + current_token = await verify_refresh_token(session, content.refresh_token) + + if current_token is None: + raise HTTPException(status_code=401, detail="Invalid token") + + user = await DBUser.get(session, current_token.user_id) + + if user is None: + raise HTTPException(status_code=401, detail="Invalid token") + + assert user is not None + await current_token.delete(session) + + refresh_token = await create_refresh_token(session, user) + access_token = create_access_token(user) + + return Token( + access_token=access_token, + refresh_token=refresh_token.token, + token_type="bearer", + ) diff --git a/fooder/controller_old/user.py b/fooder/controller_old/user.py new file mode 100644 index 0000000..de48c9d --- /dev/null +++ b/fooder/controller_old/user.py @@ -0,0 +1,19 @@ +from fastapi import HTTPException + +from ..domain.user import User as DBUser +from ..model.user import CreateUserPayload, User +from .base import BaseController + + +class CreateUser(BaseController): + async def call(self, content: CreateUserPayload) -> User: + async with self.async_session.begin() as session: + try: + user = await DBUser.create( + session, + content.username, + content.password, + ) + return User.from_orm(user) + except AssertionError as e: + raise HTTPException(status_code=400, detail=e.args[0]) diff --git a/fooder/exc.py b/fooder/exc.py new file mode 100644 index 0000000..fb08191 --- /dev/null +++ b/fooder/exc.py @@ -0,0 +1,19 @@ +from typing import ClassVar + + +class ApiException(Exception): + HTTP_CODE: ClassVar[int] + MESSAGE: ClassVar[str] + + def __init__(self, message: str | None = None) -> None: + self.message = message or self.MESSAGE + + +class NotFound(ApiException): + HTTP_CODE = 404 + MESSAGE = "Not found" + + +class Unauthorized(ApiException): + HTTP_CODE = 401 + MESSAGE = "Unathorized" diff --git a/fooder/model/token.py b/fooder/model/token.py index 3de026e..368ab79 100644 --- a/fooder/model/token.py +++ b/fooder/model/token.py @@ -1,15 +1,11 @@ from pydantic import BaseModel -class Token(BaseModel): +class TokenResponse(BaseModel): access_token: str refresh_token: str token_type: str = "bearer" -class TokenData(BaseModel): - username: str | None = None - - -class RefreshTokenPayload(BaseModel): +class RefreshTokenRequest(BaseModel): refresh_token: str diff --git a/fooder/repository/base.py b/fooder/repository/base.py index 77785aa..11e3da8 100644 --- a/fooder/repository/base.py +++ b/fooder/repository/base.py @@ -1,14 +1,10 @@ from typing import TypeVar, Generic, Type, Sequence from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import ( - select, - update as sa_update, - delete as sa_delete, - BinaryExpression, -) +from sqlalchemy import select, delete as sa_delete, ColumnElement from sqlalchemy.sql import Select +from ..domain import Base -T = TypeVar("T") +T = TypeVar("T", bound=Base) class RepositoryBase(Generic[T]): @@ -16,7 +12,7 @@ class RepositoryBase(Generic[T]): self.model = model self.session = session - def _build_select(self, *expressions: BinaryExpression) -> Select[tuple[T]]: + def _build_select(self, *expressions: ColumnElement) -> Select[tuple[T]]: stmt = select(self.model) if expressions: @@ -24,12 +20,12 @@ class RepositoryBase(Generic[T]): return stmt - async def get(self, *expressions: BinaryExpression) -> T | None: + async def get(self, *expressions: ColumnElement) -> T | None: stmt = self._build_select(*expressions) result = await self.session.execute(stmt) return result.scalar_one_or_none() - async def list(self, *expressions: BinaryExpression) -> Sequence[T]: + async def list(self, *expressions: ColumnElement) -> Sequence[T]: stmt = self._build_select(*expressions) result = await self.session.execute(stmt) return result.scalars().all() @@ -40,7 +36,7 @@ class RepositoryBase(Generic[T]): await self.session.refresh(obj) return obj - async def delete(self, *expressions: BinaryExpression): + async def delete(self, *expressions: ColumnElement): stmt = sa_delete(self.model) if expressions: diff --git a/fooder/router.py b/fooder/router.py index dc01191..560e8f2 100644 --- a/fooder/router.py +++ b/fooder/router.py @@ -1,18 +1,6 @@ from fastapi import APIRouter -# from .view.diary import router as diary_router -# from .view.entry import router as entry_router -# from .view.meal import router as meal_router -# from .view.preset import router as preset_router -# from .view.product import router as product_router from .view.token import router as token_router -# from .view.user import router as user_router router = APIRouter(prefix="/api") -# router.include_router(product_router, prefix="/product", tags=["product"]) -# router.include_router(diary_router, prefix="/diary", tags=["diary"]) -# router.include_router(meal_router, prefix="/meal", tags=["meal"]) -# router.include_router(entry_router, prefix="/entry", tags=["entry"]) router.include_router(token_router, prefix="/token", tags=["token"]) -# router.include_router(user_router, prefix="/user", tags=["user"]) -# router.include_router(preset_router, prefix="/preset", tags=["preset"]) diff --git a/fooder/tasks_app.py b/fooder/tasks_app.py deleted file mode 100644 index 88e639f..0000000 --- a/fooder/tasks_app.py +++ /dev/null @@ -1,17 +0,0 @@ -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware - -from .settings import Settings -from .view.tasks import router - -app = FastAPI(title="Fooder Tasks admininstrative API") -app.include_router(router) - - -app.add_middleware( - CORSMiddleware, - allow_origins=Settings().ALLOWED_ORIGINS, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) diff --git a/fooder/test/controller/__init__.py b/fooder/test/controller/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fooder/test/controller/test_token.py b/fooder/test/controller/test_token.py new file mode 100644 index 0000000..0b2b4aa --- /dev/null +++ b/fooder/test/controller/test_token.py @@ -0,0 +1,6 @@ +from fooder.controller import TokenController + + +def test_token_ctrl_generates_token(ctx): + token_ctrl = TokenController(ctx, 1) + token_ctrl.generate_token_pair(ctx.clock()) diff --git a/fooder/test/test_exc.py b/fooder/test/test_exc.py new file mode 100644 index 0000000..3e2663a --- /dev/null +++ b/fooder/test/test_exc.py @@ -0,0 +1,11 @@ +from fooder.exc import ApiException + + +class TestException(ApiException): + HTTP_CODE = 0 + MESSAGE = "test" + + +def test_exc_message(): + assert TestException().message == TestException.MESSAGE + assert TestException("other message").message == "other message" diff --git a/fooder/test/utils/test_jwt.py b/fooder/test/utils/test_jwt.py index 8202a79..47ebe9a 100644 --- a/fooder/test/utils/test_jwt.py +++ b/fooder/test/utils/test_jwt.py @@ -1,7 +1,9 @@ from datetime import datetime, timedelta, timezone import pytest -from jose import JWTError +from jose import jwt +from typing import Literal +from fooder.exc import Unauthorized from fooder.utils.jwt import AccessToken, RefreshToken, Token @@ -9,6 +11,7 @@ PAST = datetime(2000, 1, 1, tzinfo=timezone.utc) class WrongKeyToken(Token): + token_type: Literal["test-type"] = "test-type" secret_key = "wrong-secret" expire_delta = timedelta(minutes=30) @@ -29,15 +32,22 @@ class TestAccessToken: now = datetime.now(timezone.utc) token = WrongKeyToken(exp=WrongKeyToken.calculate_exp(now), sub=1) - with pytest.raises(JWTError): + with pytest.raises(Unauthorized): AccessToken.decode(token.encode()) def test_decode_expired_raises(self): token = AccessToken(exp=PAST, sub=1) - with pytest.raises(JWTError): + with pytest.raises(Unauthorized): AccessToken.decode(token.encode()) + def test_encoded_fields(self): + now = datetime.now(timezone.utc) + token = AccessToken(exp=AccessToken.calculate_exp(now), sub=42) + payload = jwt.decode(token.encode(), "", options={"verify_signature": False}) + assert "secret_key" not in payload + assert "expire_delta" not in payload + class TestRefreshToken: def test_encode_decode_roundtrip(self): @@ -55,12 +65,19 @@ class TestRefreshToken: now = datetime.now(timezone.utc) token = RefreshToken(exp=RefreshToken.calculate_exp(now), sub=1) - with pytest.raises(JWTError): + with pytest.raises(Unauthorized): AccessToken.decode(token.encode()) def test_access_token_not_decodable_as_refresh_token(self): now = datetime.now(timezone.utc) token = AccessToken(exp=AccessToken.calculate_exp(now), sub=1) - with pytest.raises(JWTError): + with pytest.raises(Unauthorized): RefreshToken.decode(token.encode()) + + def test_encoded_fields(self): + now = datetime.now(timezone.utc) + token = AccessToken(exp=RefreshToken.calculate_exp(now), sub=42) + payload = jwt.decode(token.encode(), "", options={"verify_signature": False}) + assert "secret_key" not in payload + assert "expire_delta" not in payload diff --git a/fooder/utils/datetime.py b/fooder/utils/datetime.py new file mode 100644 index 0000000..fa26346 --- /dev/null +++ b/fooder/utils/datetime.py @@ -0,0 +1,5 @@ +from datetime import datetime, timezone + + +def utc_now() -> datetime: + return datetime.now(timezone.utc) diff --git a/fooder/utils/jwt.py b/fooder/utils/jwt.py index 6bd3f94..fa0e667 100644 --- a/fooder/utils/jwt.py +++ b/fooder/utils/jwt.py @@ -1,11 +1,14 @@ -from jose import jwt +from jose import jwt, JOSEError from pydantic import BaseModel from datetime import timedelta, datetime -from typing import ClassVar +from typing import ClassVar, Literal +import logging from ..settings import settings +from ..exc import Unauthorized class Token(BaseModel): + token_type: str exp: datetime sub: int @@ -18,7 +21,13 @@ class Token(BaseModel): @classmethod def decode(cls, jwt_token: str | bytes) -> "Token": - data = jwt.decode(jwt_token, cls.secret_key, algorithms=[settings.ALGORITHM]) + try: + data = jwt.decode( + jwt_token, cls.secret_key, algorithms=[settings.ALGORITHM] + ) + except JOSEError as e: + logging.error(e) + raise Unauthorized() return cls(**data) def encode(self) -> str: @@ -28,10 +37,12 @@ class Token(BaseModel): class AccessToken(Token): + token_type: Literal["access"] = "access" secret_key = settings.SECRET_KEY expire_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) class RefreshToken(Token): + token_type: Literal["refresh"] = "refresh" secret_key = settings.REFRESH_SECRET_KEY expire_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) diff --git a/fooder/view/token.py b/fooder/view/token.py index f5f86af..6b44e2a 100644 --- a/fooder/view/token.py +++ b/fooder/view/token.py @@ -1,39 +1,38 @@ from typing import Annotated -from datetime import datetime, timezone -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends from fastapi.security import OAuth2PasswordRequestForm +from datetime import datetime -from ..model.token import RefreshTokenPayload, Token +from ..model.token import TokenResponse, RefreshTokenRequest from ..context import ContextDependency, Context -from ..utils.jwt import AccessToken, RefreshToken -from ..domain import User +from ..controller import UserController router = APIRouter(tags=["token"]) -@router.post("", response_model=Token) -async def create_token( - data: Annotated[OAuth2PasswordRequestForm, Depends()], - ctx: Context = Depends(ContextDependency()), -): - user = await ctx.repo.user.get(User.username == data.username) - - if user is None or not user.verify_password(data.password): - raise HTTPException(status_code=401, detail="Unathorized") - - now = datetime.now(timezone.utc) - access_token = AccessToken(sub=user.id, exp=AccessToken.calculate_exp(now)) - refresh_token = RefreshToken(sub=user.id, exp=RefreshToken.calculate_exp(now)) - return Token( +def gen_token_response(user_ctrl: UserController, now: datetime) -> TokenResponse: + token_ctrl = user_ctrl.token_ctrl() + access_token, refresh_token = token_ctrl.generate_token_pair(now) + return TokenResponse( access_token=access_token.encode(), refresh_token=refresh_token.encode(), ) -@router.post("/refresh", response_model=Token) -async def refresh_token( - data: RefreshTokenPayload, +@router.post("", response_model=TokenResponse) +async def token_create( + data: Annotated[OAuth2PasswordRequestForm, Depends()], + ctx: Context = Depends(ContextDependency()), +) -> TokenResponse: + now = ctx.clock() + user_ctrl = await UserController.session_start(ctx, data.username, data.password) + return gen_token_response(user_ctrl, now) + + +@router.post("/refresh", response_model=TokenResponse) +async def token_refresh( + data: RefreshTokenRequest, ctx: Context = Depends(ContextDependency()), ): pass