[token] refresh logic + some cleanups and a lot of tests
This commit is contained in:
parent
991560d943
commit
bd102360ad
9 changed files with 160 additions and 23 deletions
|
|
@ -7,7 +7,6 @@ from fooder.db import get_db_session
|
|||
from fooder.domain import User
|
||||
from fooder.repository import Repository
|
||||
from fooder.utils.datetime import utc_now
|
||||
from fooder.utils.jwt import AccessToken
|
||||
from fooder.exc import Unauthorized
|
||||
|
||||
|
||||
|
|
@ -20,11 +19,13 @@ class Context:
|
|||
self,
|
||||
repo: Repository,
|
||||
clock: Callable[[], datetime] = utc_now,
|
||||
_user: User | None = None,
|
||||
) -> None:
|
||||
self.repo = repo
|
||||
self.clock = clock
|
||||
self._user = _user
|
||||
self._user = None
|
||||
|
||||
def set_user(self, user: User) -> None:
|
||||
self._user = user
|
||||
|
||||
@property
|
||||
def user(self) -> User:
|
||||
|
|
@ -60,9 +61,11 @@ class AuthContextDependency:
|
|||
token: str = Depends(OAuth2PasswordBearer(tokenUrl="/token")),
|
||||
session: AsyncSession = Depends(get_db_session),
|
||||
) -> Context:
|
||||
access_token = AccessToken.decode(token)
|
||||
repo = Repository(session)
|
||||
user = await repo.user.get(User.id == access_token.sub)
|
||||
ctx = Context(repo = Repository(session))
|
||||
from fooder.controller.token import TokenController
|
||||
token_ctrl = TokenController.from_access_token(ctx, token)
|
||||
user = await ctx.repo.user.get(User.id == token_ctrl.entity_id)
|
||||
if user is None:
|
||||
raise Unauthorized()
|
||||
return Context(repo=repo, _user=user)
|
||||
ctx.set_user(user)
|
||||
return ctx
|
||||
|
|
|
|||
|
|
@ -12,6 +12,19 @@ class TokenController(ControllerBase):
|
|||
super().__init__(ctx)
|
||||
self.entity_id = entity_id
|
||||
|
||||
@classmethod
|
||||
def from_token(cls, ctx: Context, token_str: str, token_cls: Type[T]) -> "TokenController":
|
||||
token = token_cls.decode(token_str)
|
||||
return cls(ctx, token.sub)
|
||||
|
||||
@classmethod
|
||||
def from_refresh_token(cls, ctx: Context, token_str: str) -> "TokenController":
|
||||
return cls.from_token(ctx, token_str, RefreshToken)
|
||||
|
||||
@classmethod
|
||||
def from_access_token(cls, ctx: Context, token_str: str) -> "TokenController":
|
||||
return cls.from_token(ctx, token_str, AccessToken)
|
||||
|
||||
def generate_token(self, token_cls: Type[T], now: datetime) -> T:
|
||||
return token_cls(exp=token_cls.calculate_exp(now), sub=self.entity_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,3 @@ class TokenResponse(BaseModel):
|
|||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
|
|
|||
|
|
@ -1,6 +1,46 @@
|
|||
from fooder.controller import TokenController
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from fooder.controller.token import TokenController
|
||||
from fooder.exc import Unauthorized
|
||||
from fooder.utils.jwt import AccessToken, RefreshToken
|
||||
|
||||
NOW = datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def test_token_ctrl_generates_token(ctx):
|
||||
token_ctrl = TokenController(ctx, 1)
|
||||
token_ctrl.generate_token_pair(ctx.clock())
|
||||
|
||||
|
||||
class TestFromRefreshToken:
|
||||
def test_returns_controller_with_correct_entity_id(self, ctx):
|
||||
token = RefreshToken(exp=RefreshToken.calculate_exp(NOW), sub=42)
|
||||
ctrl = TokenController.from_refresh_token(ctx, token.encode())
|
||||
assert ctrl.entity_id == 42
|
||||
|
||||
def test_invalid_string_raises(self, ctx):
|
||||
with pytest.raises(Unauthorized):
|
||||
TokenController.from_refresh_token(ctx, "bad-token")
|
||||
|
||||
def test_access_token_raises(self, ctx):
|
||||
token = AccessToken(exp=AccessToken.calculate_exp(NOW), sub=1)
|
||||
with pytest.raises(Unauthorized):
|
||||
TokenController.from_refresh_token(ctx, token.encode())
|
||||
|
||||
|
||||
class TestFromAccessToken:
|
||||
def test_returns_controller_with_correct_entity_id(self, ctx):
|
||||
token = AccessToken(exp=AccessToken.calculate_exp(NOW), sub=7)
|
||||
ctrl = TokenController.from_access_token(ctx, token.encode())
|
||||
assert ctrl.entity_id == 7
|
||||
|
||||
def test_invalid_string_raises(self, ctx):
|
||||
with pytest.raises(Unauthorized):
|
||||
TokenController.from_access_token(ctx, "bad-token")
|
||||
|
||||
def test_refresh_token_raises(self, ctx):
|
||||
token = RefreshToken(exp=RefreshToken.calculate_exp(NOW), sub=1)
|
||||
with pytest.raises(Unauthorized):
|
||||
TokenController.from_access_token(ctx, token.encode())
|
||||
32
fooder/test/test_context.py
Normal file
32
fooder/test/test_context.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from fooder.context import AuthContextDependency
|
||||
from fooder.exc import Unauthorized
|
||||
from fooder.utils.jwt import AccessToken, RefreshToken
|
||||
|
||||
NOW = datetime.now(timezone.utc)
|
||||
|
||||
|
||||
async def test_auth_context_valid_token_returns_correct_user(db_session, user):
|
||||
token = AccessToken(exp=AccessToken.calculate_exp(NOW), sub=user.id)
|
||||
ctx = await AuthContextDependency()(token=token.encode(), session=db_session)
|
||||
assert ctx.user.id == user.id
|
||||
|
||||
|
||||
async def test_auth_context_invalid_token_raises(db_session):
|
||||
with pytest.raises(Unauthorized):
|
||||
await AuthContextDependency()(token="bad-token", session=db_session)
|
||||
|
||||
|
||||
async def test_auth_context_refresh_token_raises(db_session, user):
|
||||
token = RefreshToken(exp=RefreshToken.calculate_exp(NOW), sub=user.id)
|
||||
with pytest.raises(Unauthorized):
|
||||
await AuthContextDependency()(token=token.encode(), session=db_session)
|
||||
|
||||
|
||||
async def test_auth_context_unknown_user_raises(db_session):
|
||||
token = AccessToken(exp=AccessToken.calculate_exp(NOW), sub=99999)
|
||||
with pytest.raises(Unauthorized):
|
||||
await AuthContextDependency()(token=token.encode(), session=db_session)
|
||||
|
|
@ -11,7 +11,6 @@ PAST = datetime(2000, 1, 1, tzinfo=timezone.utc)
|
|||
|
||||
|
||||
class WrongKeyToken(Token):
|
||||
token_type: Literal["test-type"] = "test-type"
|
||||
secret_key = "wrong-secret"
|
||||
expire_delta = timedelta(minutes=30)
|
||||
|
||||
|
|
|
|||
|
|
@ -45,3 +45,58 @@ async def test_create_token_unknown_user(client):
|
|||
data={"username": "nobody", "password": "x"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
async def test_refresh_token_returns_new_tokens(client, user, user_password):
|
||||
response = await client.post(
|
||||
"/api/token",
|
||||
data={"username": user.username, "password": user_password},
|
||||
)
|
||||
refresh_token = response.json()["refresh_token"]
|
||||
|
||||
response = await client.post("/api/token/refresh", params={"refresh_token": refresh_token})
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert "access_token" in body
|
||||
assert "refresh_token" in body
|
||||
assert body["token_type"] == "bearer"
|
||||
|
||||
|
||||
async def test_refresh_token_access_token_is_valid(client, user, user_password):
|
||||
response = await client.post(
|
||||
"/api/token",
|
||||
data={"username": user.username, "password": user_password},
|
||||
)
|
||||
refresh_token = response.json()["refresh_token"]
|
||||
|
||||
response = await client.post("/api/token/refresh", params={"refresh_token": refresh_token})
|
||||
token = AccessToken.decode(response.json()["access_token"])
|
||||
assert token.sub == user.id
|
||||
|
||||
|
||||
async def test_refresh_token_refresh_token_is_valid(client, user, user_password):
|
||||
response = await client.post(
|
||||
"/api/token",
|
||||
data={"username": user.username, "password": user_password},
|
||||
)
|
||||
refresh_token = response.json()["refresh_token"]
|
||||
|
||||
response = await client.post("/api/token/refresh", params={"refresh_token": refresh_token})
|
||||
token = RefreshToken.decode(response.json()["refresh_token"])
|
||||
assert token.sub == user.id
|
||||
|
||||
|
||||
async def test_refresh_token_invalid_returns_401(client):
|
||||
response = await client.post("/api/token/refresh", params={"refresh_token": "bad-token"})
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
async def test_refresh_token_access_token_as_refresh_returns_401(client, user, user_password):
|
||||
response = await client.post(
|
||||
"/api/token",
|
||||
data={"username": user.username, "password": user_password},
|
||||
)
|
||||
access_token = response.json()["access_token"]
|
||||
|
||||
response = await client.post("/api/token/refresh", params={"refresh_token": access_token})
|
||||
assert response.status_code == 401
|
||||
|
|
@ -8,7 +8,6 @@ from fooder.exc import Unauthorized
|
|||
|
||||
|
||||
class Token(BaseModel):
|
||||
token_type: str
|
||||
exp: datetime
|
||||
sub: int
|
||||
|
||||
|
|
@ -37,12 +36,10 @@ class Token(BaseModel):
|
|||
|
||||
|
||||
class AccessToken(Token):
|
||||
token_type: Literal["access"] = "access"
|
||||
secret_key = settings.SECRET_KEY
|
||||
expire_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
|
||||
class RefreshToken(Token):
|
||||
token_type: Literal["refresh"] = "refresh"
|
||||
secret_key = settings.REFRESH_SECRET_KEY
|
||||
expire_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
|
|
|||
|
|
@ -4,15 +4,15 @@ from fastapi import APIRouter, Depends
|
|||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from datetime import datetime
|
||||
|
||||
from fooder.model.token import TokenResponse, RefreshTokenRequest
|
||||
from fooder.model.token import TokenResponse
|
||||
from fooder.context import ContextDependency, Context
|
||||
from fooder.controller import UserController
|
||||
from fooder.controller.token import TokenController
|
||||
|
||||
router = APIRouter(tags=["token"])
|
||||
|
||||
|
||||
def gen_token_response(user_ctrl: UserController, now: datetime) -> TokenResponse:
|
||||
token_ctrl = user_ctrl.token_ctrl()
|
||||
def gen_token_response(token_ctrl: TokenController, now: datetime) -> TokenResponse:
|
||||
access_token, refresh_token = token_ctrl.generate_token_pair(now)
|
||||
return TokenResponse(
|
||||
access_token=access_token.encode(),
|
||||
|
|
@ -27,12 +27,14 @@ async def token_create(
|
|||
) -> TokenResponse:
|
||||
now = ctx.clock()
|
||||
user_ctrl = await UserController.session_start(ctx, data.username, data.password)
|
||||
return gen_token_response(user_ctrl, now)
|
||||
return gen_token_response(user_ctrl.token_ctrl(), now)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
async def token_refresh(
|
||||
data: RefreshTokenRequest,
|
||||
refresh_token: str,
|
||||
ctx: Context = Depends(ContextDependency()),
|
||||
):
|
||||
pass
|
||||
) -> TokenResponse:
|
||||
now = ctx.clock()
|
||||
token_ctrl = TokenController.from_refresh_token(ctx, refresh_token)
|
||||
return gen_token_response(token_ctrl, now)
|
||||
|
|
|
|||
Loading…
Reference in a new issue