[token] refresh logic + some cleanups and a lot of tests

This commit is contained in:
Piotr Domański 2026-04-07 13:01:10 +02:00
parent 991560d943
commit bd102360ad
9 changed files with 160 additions and 23 deletions

View file

@ -7,7 +7,6 @@ from fooder.db import get_db_session
from fooder.domain import User
from fooder.repository import Repository
from fooder.utils.datetime import utc_now
from fooder.utils.jwt import AccessToken
from fooder.exc import Unauthorized
@ -20,11 +19,13 @@ class Context:
self,
repo: Repository,
clock: Callable[[], datetime] = utc_now,
_user: User | None = None,
) -> None:
self.repo = repo
self.clock = clock
self._user = _user
self._user = None
def set_user(self, user: User) -> None:
self._user = user
@property
def user(self) -> User:
@ -60,9 +61,11 @@ class AuthContextDependency:
token: str = Depends(OAuth2PasswordBearer(tokenUrl="/token")),
session: AsyncSession = Depends(get_db_session),
) -> Context:
access_token = AccessToken.decode(token)
repo = Repository(session)
user = await repo.user.get(User.id == access_token.sub)
ctx = Context(repo = Repository(session))
from fooder.controller.token import TokenController
token_ctrl = TokenController.from_access_token(ctx, token)
user = await ctx.repo.user.get(User.id == token_ctrl.entity_id)
if user is None:
raise Unauthorized()
return Context(repo=repo, _user=user)
ctx.set_user(user)
return ctx

View file

@ -12,6 +12,19 @@ class TokenController(ControllerBase):
super().__init__(ctx)
self.entity_id = entity_id
@classmethod
def from_token(cls, ctx: Context, token_str: str, token_cls: Type[T]) -> "TokenController":
token = token_cls.decode(token_str)
return cls(ctx, token.sub)
@classmethod
def from_refresh_token(cls, ctx: Context, token_str: str) -> "TokenController":
return cls.from_token(ctx, token_str, RefreshToken)
@classmethod
def from_access_token(cls, ctx: Context, token_str: str) -> "TokenController":
return cls.from_token(ctx, token_str, AccessToken)
def generate_token(self, token_cls: Type[T], now: datetime) -> T:
return token_cls(exp=token_cls.calculate_exp(now), sub=self.entity_id)

View file

@ -5,7 +5,3 @@ class TokenResponse(BaseModel):
access_token: str
refresh_token: str
token_type: str = "bearer"
class RefreshTokenRequest(BaseModel):
refresh_token: str

View file

@ -1,6 +1,46 @@
from fooder.controller import TokenController
from datetime import datetime, timezone
import pytest
from fooder.controller.token import TokenController
from fooder.exc import Unauthorized
from fooder.utils.jwt import AccessToken, RefreshToken
NOW = datetime.now(timezone.utc)
def test_token_ctrl_generates_token(ctx):
token_ctrl = TokenController(ctx, 1)
token_ctrl.generate_token_pair(ctx.clock())
class TestFromRefreshToken:
def test_returns_controller_with_correct_entity_id(self, ctx):
token = RefreshToken(exp=RefreshToken.calculate_exp(NOW), sub=42)
ctrl = TokenController.from_refresh_token(ctx, token.encode())
assert ctrl.entity_id == 42
def test_invalid_string_raises(self, ctx):
with pytest.raises(Unauthorized):
TokenController.from_refresh_token(ctx, "bad-token")
def test_access_token_raises(self, ctx):
token = AccessToken(exp=AccessToken.calculate_exp(NOW), sub=1)
with pytest.raises(Unauthorized):
TokenController.from_refresh_token(ctx, token.encode())
class TestFromAccessToken:
def test_returns_controller_with_correct_entity_id(self, ctx):
token = AccessToken(exp=AccessToken.calculate_exp(NOW), sub=7)
ctrl = TokenController.from_access_token(ctx, token.encode())
assert ctrl.entity_id == 7
def test_invalid_string_raises(self, ctx):
with pytest.raises(Unauthorized):
TokenController.from_access_token(ctx, "bad-token")
def test_refresh_token_raises(self, ctx):
token = RefreshToken(exp=RefreshToken.calculate_exp(NOW), sub=1)
with pytest.raises(Unauthorized):
TokenController.from_access_token(ctx, token.encode())

View file

@ -0,0 +1,32 @@
from datetime import datetime, timezone
import pytest
from fooder.context import AuthContextDependency
from fooder.exc import Unauthorized
from fooder.utils.jwt import AccessToken, RefreshToken
NOW = datetime.now(timezone.utc)
async def test_auth_context_valid_token_returns_correct_user(db_session, user):
token = AccessToken(exp=AccessToken.calculate_exp(NOW), sub=user.id)
ctx = await AuthContextDependency()(token=token.encode(), session=db_session)
assert ctx.user.id == user.id
async def test_auth_context_invalid_token_raises(db_session):
with pytest.raises(Unauthorized):
await AuthContextDependency()(token="bad-token", session=db_session)
async def test_auth_context_refresh_token_raises(db_session, user):
token = RefreshToken(exp=RefreshToken.calculate_exp(NOW), sub=user.id)
with pytest.raises(Unauthorized):
await AuthContextDependency()(token=token.encode(), session=db_session)
async def test_auth_context_unknown_user_raises(db_session):
token = AccessToken(exp=AccessToken.calculate_exp(NOW), sub=99999)
with pytest.raises(Unauthorized):
await AuthContextDependency()(token=token.encode(), session=db_session)

View file

@ -11,7 +11,6 @@ PAST = datetime(2000, 1, 1, tzinfo=timezone.utc)
class WrongKeyToken(Token):
token_type: Literal["test-type"] = "test-type"
secret_key = "wrong-secret"
expire_delta = timedelta(minutes=30)

View file

@ -45,3 +45,58 @@ async def test_create_token_unknown_user(client):
data={"username": "nobody", "password": "x"},
)
assert response.status_code == 401
async def test_refresh_token_returns_new_tokens(client, user, user_password):
response = await client.post(
"/api/token",
data={"username": user.username, "password": user_password},
)
refresh_token = response.json()["refresh_token"]
response = await client.post("/api/token/refresh", params={"refresh_token": refresh_token})
assert response.status_code == 200
body = response.json()
assert "access_token" in body
assert "refresh_token" in body
assert body["token_type"] == "bearer"
async def test_refresh_token_access_token_is_valid(client, user, user_password):
response = await client.post(
"/api/token",
data={"username": user.username, "password": user_password},
)
refresh_token = response.json()["refresh_token"]
response = await client.post("/api/token/refresh", params={"refresh_token": refresh_token})
token = AccessToken.decode(response.json()["access_token"])
assert token.sub == user.id
async def test_refresh_token_refresh_token_is_valid(client, user, user_password):
response = await client.post(
"/api/token",
data={"username": user.username, "password": user_password},
)
refresh_token = response.json()["refresh_token"]
response = await client.post("/api/token/refresh", params={"refresh_token": refresh_token})
token = RefreshToken.decode(response.json()["refresh_token"])
assert token.sub == user.id
async def test_refresh_token_invalid_returns_401(client):
response = await client.post("/api/token/refresh", params={"refresh_token": "bad-token"})
assert response.status_code == 401
async def test_refresh_token_access_token_as_refresh_returns_401(client, user, user_password):
response = await client.post(
"/api/token",
data={"username": user.username, "password": user_password},
)
access_token = response.json()["access_token"]
response = await client.post("/api/token/refresh", params={"refresh_token": access_token})
assert response.status_code == 401

View file

@ -8,7 +8,6 @@ from fooder.exc import Unauthorized
class Token(BaseModel):
token_type: str
exp: datetime
sub: int
@ -37,12 +36,10 @@ class Token(BaseModel):
class AccessToken(Token):
token_type: Literal["access"] = "access"
secret_key = settings.SECRET_KEY
expire_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
class RefreshToken(Token):
token_type: Literal["refresh"] = "refresh"
secret_key = settings.REFRESH_SECRET_KEY
expire_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)

View file

@ -4,15 +4,15 @@ from fastapi import APIRouter, Depends
from fastapi.security import OAuth2PasswordRequestForm
from datetime import datetime
from fooder.model.token import TokenResponse, RefreshTokenRequest
from fooder.model.token import TokenResponse
from fooder.context import ContextDependency, Context
from fooder.controller import UserController
from fooder.controller.token import TokenController
router = APIRouter(tags=["token"])
def gen_token_response(user_ctrl: UserController, now: datetime) -> TokenResponse:
token_ctrl = user_ctrl.token_ctrl()
def gen_token_response(token_ctrl: TokenController, now: datetime) -> TokenResponse:
access_token, refresh_token = token_ctrl.generate_token_pair(now)
return TokenResponse(
access_token=access_token.encode(),
@ -27,12 +27,14 @@ async def token_create(
) -> TokenResponse:
now = ctx.clock()
user_ctrl = await UserController.session_start(ctx, data.username, data.password)
return gen_token_response(user_ctrl, now)
return gen_token_response(user_ctrl.token_ctrl(), now)
@router.post("/refresh", response_model=TokenResponse)
async def token_refresh(
data: RefreshTokenRequest,
refresh_token: str,
ctx: Context = Depends(ContextDependency()),
):
pass
) -> TokenResponse:
now = ctx.clock()
token_ctrl = TokenController.from_refresh_token(ctx, refresh_token)
return gen_token_response(token_ctrl, now)