[token] implement jwt tokens + first view!

This commit is contained in:
Piotr Domański 2026-04-02 22:32:42 +02:00
parent 20ffc18044
commit 10ef646d93
12 changed files with 240 additions and 40 deletions

View file

@ -13,21 +13,7 @@ from .domain.token import RefreshToken
from .domain.user import User from .domain.user import User
from .settings import Settings from .settings import Settings
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token")
settings = Settings()
password_helper = PasswordHelper(pwd_context) # type: ignore
AsyncSessionDependency = Annotated[async_sessionmaker, Depends(get_db_session)]
TokenDependency = Annotated[str, Depends(oauth2_scheme)]
def verify(plain_password: str, hashed_password: str) -> bool:
return pwd_context.verify(plain_password, hashed_password)
def hash(password: str) -> str:
return pwd_context.hash(password)
async def authenticate_user( async def authenticate_user(
@ -137,9 +123,3 @@ async def get_current_user(ssn: AsyncSessionDependency, token: TokenDependency)
assert user is not None assert user is not None
return user return user
async def authorize_api_key(token: TokenDependency) -> None:
if token == settings.API_KEY:
return None
raise HTTPException(status_code=401, detail="Unathorized")

View file

@ -5,4 +5,11 @@ from ..domain import User
class Repository: class Repository:
def __init__(self, session: AsyncSession): def __init__(self, session: AsyncSession):
self.session = session
self.user = UserRepository(User, session) self.user = UserRepository(User, session)
async def commit(self) -> None:
await self.session.commit()
async def rollback(self) -> None:
await self.session.rollback()

View file

@ -1,18 +1,18 @@
from fastapi import APIRouter from fastapi import APIRouter
from .view.diary import router as diary_router # from .view.diary import router as diary_router
from .view.entry import router as entry_router # from .view.entry import router as entry_router
from .view.meal import router as meal_router # from .view.meal import router as meal_router
from .view.preset import router as preset_router # from .view.preset import router as preset_router
from .view.product import router as product_router # from .view.product import router as product_router
from .view.token import router as token_router from .view.token import router as token_router
from .view.user import router as user_router # from .view.user import router as user_router
router = APIRouter(prefix="/api") router = APIRouter(prefix="/api")
router.include_router(product_router, prefix="/product", tags=["product"]) # router.include_router(product_router, prefix="/product", tags=["product"])
router.include_router(diary_router, prefix="/diary", tags=["diary"]) # router.include_router(diary_router, prefix="/diary", tags=["diary"])
router.include_router(meal_router, prefix="/meal", tags=["meal"]) # router.include_router(meal_router, prefix="/meal", tags=["meal"])
router.include_router(entry_router, prefix="/entry", tags=["entry"]) # router.include_router(entry_router, prefix="/entry", tags=["entry"])
router.include_router(token_router, prefix="/token", tags=["token"]) router.include_router(token_router, prefix="/token", tags=["token"])
router.include_router(user_router, prefix="/user", tags=["user"]) # router.include_router(user_router, prefix="/user", tags=["user"])
router.include_router(preset_router, prefix="/preset", tags=["preset"]) # router.include_router(preset_router, prefix="/preset", tags=["preset"])

View file

@ -1,6 +1,9 @@
import pytest import pytest
from .db import * from .db import *
from .faker import * from .faker import *
from .user import *
from .client import *
from .context import *
@pytest.fixture @pytest.fixture

16
fooder/test/fixtures/client.py vendored Normal file
View file

@ -0,0 +1,16 @@
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from fooder.app import app
from fooder.db import get_db_session
@pytest_asyncio.fixture
async def client(db_session):
async def override_get_db_session():
yield db_session
app.dependency_overrides[get_db_session] = override_get_db_session
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c:
yield c
app.dependency_overrides.clear()

8
fooder/test/fixtures/context.py vendored Normal file
View file

@ -0,0 +1,8 @@
import pytest
from fooder.context import Context
from fooder.repository import Repository
@pytest.fixture
def ctx(db_session):
return Context(repo=Repository(db_session))

25
fooder/test/fixtures/user.py vendored Normal file
View file

@ -0,0 +1,25 @@
import pytest
import pytest_asyncio
from fooder.domain.user import User
@pytest.fixture
def user_password(faker):
return faker.password()
@pytest_asyncio.fixture
async def user_factory(ctx):
async def factory(username, password):
user = User(username=username)
user.set_password(password)
await ctx.repo.user.create(user)
return user
return factory
@pytest_asyncio.fixture
async def user(faker, user_password, user_factory):
return await user_factory(faker.name(), user_password)

View file

@ -0,0 +1,66 @@
from datetime import datetime, timedelta, timezone
import pytest
from jose import JWTError
from fooder.utils.jwt import AccessToken, RefreshToken, Token
PAST = datetime(2000, 1, 1, tzinfo=timezone.utc)
class WrongKeyToken(Token):
secret_key = "wrong-secret"
expire_delta = timedelta(minutes=30)
class TestAccessToken:
def test_encode_decode_roundtrip(self):
now = datetime.now(timezone.utc)
token = AccessToken(exp=AccessToken.calculate_exp(now), sub=42)
decoded = AccessToken.decode(token.encode())
assert decoded.sub == token.sub
def test_calculate_exp(self):
now = datetime.now(timezone.utc)
assert AccessToken.calculate_exp(now) > now
def test_decode_wrong_key_raises(self):
now = datetime.now(timezone.utc)
token = WrongKeyToken(exp=WrongKeyToken.calculate_exp(now), sub=1)
with pytest.raises(JWTError):
AccessToken.decode(token.encode())
def test_decode_expired_raises(self):
token = AccessToken(exp=PAST, sub=1)
with pytest.raises(JWTError):
AccessToken.decode(token.encode())
class TestRefreshToken:
def test_encode_decode_roundtrip(self):
now = datetime.now(timezone.utc)
token = RefreshToken(exp=RefreshToken.calculate_exp(now), sub=7)
decoded = RefreshToken.decode(token.encode())
assert decoded.sub == token.sub
def test_calculate_exp(self):
now = datetime.now(timezone.utc)
assert RefreshToken.calculate_exp(now) > now
def test_refresh_token_not_decodable_as_access_token(self):
now = datetime.now(timezone.utc)
token = RefreshToken(exp=RefreshToken.calculate_exp(now), sub=1)
with pytest.raises(JWTError):
AccessToken.decode(token.encode())
def test_access_token_not_decodable_as_refresh_token(self):
now = datetime.now(timezone.utc)
token = AccessToken(exp=AccessToken.calculate_exp(now), sub=1)
with pytest.raises(JWTError):
RefreshToken.decode(token.encode())

View file

View file

@ -0,0 +1,47 @@
from fooder.utils.jwt import AccessToken, RefreshToken
async def test_create_token_returns_tokens(client, user, user_password):
response = await client.post(
"/api/token",
data={"username": user.username, "password": user_password},
)
assert response.status_code == 200
body = response.json()
assert "access_token" in body
assert "refresh_token" in body
assert body["token_type"] == "bearer"
async def test_create_token_access_token_is_valid(client, user, user_password):
response = await client.post(
"/api/token",
data={"username": user.username, "password": user_password},
)
token = AccessToken.decode(response.json()["access_token"])
assert token.sub == user.id
async def test_create_token_refresh_token_is_valid(client, user, user_password):
response = await client.post(
"/api/token",
data={"username": user.username, "password": user_password},
)
token = RefreshToken.decode(response.json()["refresh_token"])
assert token.sub == user.id
async def test_create_token_wrong_password(client, user):
response = await client.post(
"/api/token",
data={"username": user.username, "password": "wrong"},
)
assert response.status_code == 401
async def test_create_token_unknown_user(client):
response = await client.post(
"/api/token",
data={"username": "nobody", "password": "x"},
)
assert response.status_code == 401

37
fooder/utils/jwt.py Normal file
View file

@ -0,0 +1,37 @@
from jose import jwt
from pydantic import BaseModel
from datetime import timedelta, datetime
from typing import ClassVar
from ..settings import settings
class Token(BaseModel):
exp: datetime
sub: int
secret_key: ClassVar[str]
expire_delta: ClassVar[timedelta]
@classmethod
def calculate_exp(cls, now: datetime) -> datetime:
return now + cls.expire_delta
@classmethod
def decode(cls, jwt_token: str | bytes) -> "Token":
data = jwt.decode(jwt_token, cls.secret_key, algorithms=[settings.ALGORITHM])
return cls(**data)
def encode(self) -> str:
data = self.model_dump()
data["sub"] = str(data["sub"])
return jwt.encode(data, self.secret_key, settings.ALGORITHM)
class AccessToken(Token):
secret_key = settings.SECRET_KEY
expire_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
class RefreshToken(Token):
secret_key = settings.REFRESH_SECRET_KEY
expire_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)

View file

@ -1,27 +1,38 @@
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, Depends, Request from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from ..controller.token import CreateToken, RefreshToken
from ..model.token import RefreshTokenPayload, Token from ..model.token import RefreshTokenPayload, Token
from ..context import ContextDependency, Context
from ..utils.jwt import AccessToken, RefreshToken
router = APIRouter(tags=["token"]) router = APIRouter(tags=["token"])
@router.post("", response_model=Token) @router.post("", response_model=Token)
async def create_token( async def create_token(
request: Request,
data: Annotated[OAuth2PasswordRequestForm, Depends()], data: Annotated[OAuth2PasswordRequestForm, Depends()],
controller: CreateToken = Depends(CreateToken), ctx: Context = Depends(ContextDependency()),
): ):
return await controller.call(data) user = await ctx.repo.user.get(username=data.username)
if user is None or not user.verify_password(data.password):
raise HTTPException(status_code=401, detail="Unathorized")
now = datetime.now(timezone.utc)
access_token = AccessToken(sub=user.id, exp=AccessToken.calculate_exp(now))
refresh_token = RefreshToken(sub=user.id, exp=RefreshToken.calculate_exp(now))
return Token(
access_token=access_token.encode(),
refresh_token=refresh_token.encode(),
)
@router.post("/refresh", response_model=Token) @router.post("/refresh", response_model=Token)
async def refresh_token( async def refresh_token(
request: Request,
data: RefreshTokenPayload, data: RefreshTokenPayload,
controller: RefreshToken = Depends(RefreshToken), ctx: Context = Depends(ContextDependency()),
): ):
return await controller.call(data) pass