From bd102360adcc12c82bb3e68ac823442bdbbb9e00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20Doma=C5=84ski?= Date: Tue, 7 Apr 2026 13:01:10 +0200 Subject: [PATCH] [token] refresh logic + some cleanups and a lot of tests --- fooder/context.py | 17 +++++---- fooder/controller/token.py | 13 +++++++ fooder/model/token.py | 4 -- fooder/test/controller/test_token.py | 42 ++++++++++++++++++++- fooder/test/test_context.py | 32 ++++++++++++++++ fooder/test/utils/test_jwt.py | 1 - fooder/test/view/test_token.py | 55 ++++++++++++++++++++++++++++ fooder/utils/jwt.py | 3 -- fooder/view/token.py | 16 ++++---- 9 files changed, 160 insertions(+), 23 deletions(-) create mode 100644 fooder/test/test_context.py diff --git a/fooder/context.py b/fooder/context.py index 602e596..31f96d8 100644 --- a/fooder/context.py +++ b/fooder/context.py @@ -7,7 +7,6 @@ from fooder.db import get_db_session from fooder.domain import User from fooder.repository import Repository from fooder.utils.datetime import utc_now -from fooder.utils.jwt import AccessToken from fooder.exc import Unauthorized @@ -20,11 +19,13 @@ class Context: self, repo: Repository, clock: Callable[[], datetime] = utc_now, - _user: User | None = None, ) -> None: self.repo = repo self.clock = clock - self._user = _user + self._user = None + + def set_user(self, user: User) -> None: + self._user = user @property def user(self) -> User: @@ -60,9 +61,11 @@ class AuthContextDependency: token: str = Depends(OAuth2PasswordBearer(tokenUrl="/token")), session: AsyncSession = Depends(get_db_session), ) -> Context: - access_token = AccessToken.decode(token) - repo = Repository(session) - user = await repo.user.get(User.id == access_token.sub) + ctx = Context(repo = Repository(session)) + from fooder.controller.token import TokenController + token_ctrl = TokenController.from_access_token(ctx, token) + user = await ctx.repo.user.get(User.id == token_ctrl.entity_id) if user is None: raise Unauthorized() - return Context(repo=repo, _user=user) + ctx.set_user(user) + return ctx diff --git a/fooder/controller/token.py b/fooder/controller/token.py index 58e8279..3e984f6 100644 --- a/fooder/controller/token.py +++ b/fooder/controller/token.py @@ -12,6 +12,19 @@ class TokenController(ControllerBase): super().__init__(ctx) self.entity_id = entity_id + @classmethod + def from_token(cls, ctx: Context, token_str: str, token_cls: Type[T]) -> "TokenController": + token = token_cls.decode(token_str) + return cls(ctx, token.sub) + + @classmethod + def from_refresh_token(cls, ctx: Context, token_str: str) -> "TokenController": + return cls.from_token(ctx, token_str, RefreshToken) + + @classmethod + def from_access_token(cls, ctx: Context, token_str: str) -> "TokenController": + return cls.from_token(ctx, token_str, AccessToken) + def generate_token(self, token_cls: Type[T], now: datetime) -> T: return token_cls(exp=token_cls.calculate_exp(now), sub=self.entity_id) diff --git a/fooder/model/token.py b/fooder/model/token.py index 368ab79..438f788 100644 --- a/fooder/model/token.py +++ b/fooder/model/token.py @@ -5,7 +5,3 @@ class TokenResponse(BaseModel): access_token: str refresh_token: str token_type: str = "bearer" - - -class RefreshTokenRequest(BaseModel): - refresh_token: str diff --git a/fooder/test/controller/test_token.py b/fooder/test/controller/test_token.py index 0b2b4aa..15bcab9 100644 --- a/fooder/test/controller/test_token.py +++ b/fooder/test/controller/test_token.py @@ -1,6 +1,46 @@ -from fooder.controller import TokenController +from datetime import datetime, timezone + +import pytest + +from fooder.controller.token import TokenController +from fooder.exc import Unauthorized +from fooder.utils.jwt import AccessToken, RefreshToken + +NOW = datetime.now(timezone.utc) def test_token_ctrl_generates_token(ctx): token_ctrl = TokenController(ctx, 1) token_ctrl.generate_token_pair(ctx.clock()) + + +class TestFromRefreshToken: + def test_returns_controller_with_correct_entity_id(self, ctx): + token = RefreshToken(exp=RefreshToken.calculate_exp(NOW), sub=42) + ctrl = TokenController.from_refresh_token(ctx, token.encode()) + assert ctrl.entity_id == 42 + + def test_invalid_string_raises(self, ctx): + with pytest.raises(Unauthorized): + TokenController.from_refresh_token(ctx, "bad-token") + + def test_access_token_raises(self, ctx): + token = AccessToken(exp=AccessToken.calculate_exp(NOW), sub=1) + with pytest.raises(Unauthorized): + TokenController.from_refresh_token(ctx, token.encode()) + + +class TestFromAccessToken: + def test_returns_controller_with_correct_entity_id(self, ctx): + token = AccessToken(exp=AccessToken.calculate_exp(NOW), sub=7) + ctrl = TokenController.from_access_token(ctx, token.encode()) + assert ctrl.entity_id == 7 + + def test_invalid_string_raises(self, ctx): + with pytest.raises(Unauthorized): + TokenController.from_access_token(ctx, "bad-token") + + def test_refresh_token_raises(self, ctx): + token = RefreshToken(exp=RefreshToken.calculate_exp(NOW), sub=1) + with pytest.raises(Unauthorized): + TokenController.from_access_token(ctx, token.encode()) \ No newline at end of file diff --git a/fooder/test/test_context.py b/fooder/test/test_context.py new file mode 100644 index 0000000..e304ac9 --- /dev/null +++ b/fooder/test/test_context.py @@ -0,0 +1,32 @@ +from datetime import datetime, timezone + +import pytest + +from fooder.context import AuthContextDependency +from fooder.exc import Unauthorized +from fooder.utils.jwt import AccessToken, RefreshToken + +NOW = datetime.now(timezone.utc) + + +async def test_auth_context_valid_token_returns_correct_user(db_session, user): + token = AccessToken(exp=AccessToken.calculate_exp(NOW), sub=user.id) + ctx = await AuthContextDependency()(token=token.encode(), session=db_session) + assert ctx.user.id == user.id + + +async def test_auth_context_invalid_token_raises(db_session): + with pytest.raises(Unauthorized): + await AuthContextDependency()(token="bad-token", session=db_session) + + +async def test_auth_context_refresh_token_raises(db_session, user): + token = RefreshToken(exp=RefreshToken.calculate_exp(NOW), sub=user.id) + with pytest.raises(Unauthorized): + await AuthContextDependency()(token=token.encode(), session=db_session) + + +async def test_auth_context_unknown_user_raises(db_session): + token = AccessToken(exp=AccessToken.calculate_exp(NOW), sub=99999) + with pytest.raises(Unauthorized): + await AuthContextDependency()(token=token.encode(), session=db_session) diff --git a/fooder/test/utils/test_jwt.py b/fooder/test/utils/test_jwt.py index 47ebe9a..df826c8 100644 --- a/fooder/test/utils/test_jwt.py +++ b/fooder/test/utils/test_jwt.py @@ -11,7 +11,6 @@ PAST = datetime(2000, 1, 1, tzinfo=timezone.utc) class WrongKeyToken(Token): - token_type: Literal["test-type"] = "test-type" secret_key = "wrong-secret" expire_delta = timedelta(minutes=30) diff --git a/fooder/test/view/test_token.py b/fooder/test/view/test_token.py index 7c38a95..4cbdb04 100644 --- a/fooder/test/view/test_token.py +++ b/fooder/test/view/test_token.py @@ -45,3 +45,58 @@ async def test_create_token_unknown_user(client): data={"username": "nobody", "password": "x"}, ) assert response.status_code == 401 + + +async def test_refresh_token_returns_new_tokens(client, user, user_password): + response = await client.post( + "/api/token", + data={"username": user.username, "password": user_password}, + ) + refresh_token = response.json()["refresh_token"] + + response = await client.post("/api/token/refresh", params={"refresh_token": refresh_token}) + assert response.status_code == 200 + body = response.json() + assert "access_token" in body + assert "refresh_token" in body + assert body["token_type"] == "bearer" + + +async def test_refresh_token_access_token_is_valid(client, user, user_password): + response = await client.post( + "/api/token", + data={"username": user.username, "password": user_password}, + ) + refresh_token = response.json()["refresh_token"] + + response = await client.post("/api/token/refresh", params={"refresh_token": refresh_token}) + token = AccessToken.decode(response.json()["access_token"]) + assert token.sub == user.id + + +async def test_refresh_token_refresh_token_is_valid(client, user, user_password): + response = await client.post( + "/api/token", + data={"username": user.username, "password": user_password}, + ) + refresh_token = response.json()["refresh_token"] + + response = await client.post("/api/token/refresh", params={"refresh_token": refresh_token}) + token = RefreshToken.decode(response.json()["refresh_token"]) + assert token.sub == user.id + + +async def test_refresh_token_invalid_returns_401(client): + response = await client.post("/api/token/refresh", params={"refresh_token": "bad-token"}) + assert response.status_code == 401 + + +async def test_refresh_token_access_token_as_refresh_returns_401(client, user, user_password): + response = await client.post( + "/api/token", + data={"username": user.username, "password": user_password}, + ) + access_token = response.json()["access_token"] + + response = await client.post("/api/token/refresh", params={"refresh_token": access_token}) + assert response.status_code == 401 \ No newline at end of file diff --git a/fooder/utils/jwt.py b/fooder/utils/jwt.py index ff6d833..0990b63 100644 --- a/fooder/utils/jwt.py +++ b/fooder/utils/jwt.py @@ -8,7 +8,6 @@ from fooder.exc import Unauthorized class Token(BaseModel): - token_type: str exp: datetime sub: int @@ -37,12 +36,10 @@ class Token(BaseModel): class AccessToken(Token): - token_type: Literal["access"] = "access" secret_key = settings.SECRET_KEY expire_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) class RefreshToken(Token): - token_type: Literal["refresh"] = "refresh" secret_key = settings.REFRESH_SECRET_KEY expire_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) diff --git a/fooder/view/token.py b/fooder/view/token.py index 2b28ed0..7ec5ac8 100644 --- a/fooder/view/token.py +++ b/fooder/view/token.py @@ -4,15 +4,15 @@ from fastapi import APIRouter, Depends from fastapi.security import OAuth2PasswordRequestForm from datetime import datetime -from fooder.model.token import TokenResponse, RefreshTokenRequest +from fooder.model.token import TokenResponse from fooder.context import ContextDependency, Context from fooder.controller import UserController +from fooder.controller.token import TokenController router = APIRouter(tags=["token"]) -def gen_token_response(user_ctrl: UserController, now: datetime) -> TokenResponse: - token_ctrl = user_ctrl.token_ctrl() +def gen_token_response(token_ctrl: TokenController, now: datetime) -> TokenResponse: access_token, refresh_token = token_ctrl.generate_token_pair(now) return TokenResponse( access_token=access_token.encode(), @@ -27,12 +27,14 @@ async def token_create( ) -> TokenResponse: now = ctx.clock() user_ctrl = await UserController.session_start(ctx, data.username, data.password) - return gen_token_response(user_ctrl, now) + return gen_token_response(user_ctrl.token_ctrl(), now) @router.post("/refresh", response_model=TokenResponse) async def token_refresh( - data: RefreshTokenRequest, + refresh_token: str, ctx: Context = Depends(ContextDependency()), -): - pass +) -> TokenResponse: + now = ctx.clock() + token_ctrl = TokenController.from_refresh_token(ctx, refresh_token) + return gen_token_response(token_ctrl, now)