[token] refresh logic + some cleanups and a lot of tests
This commit is contained in:
parent
991560d943
commit
bd102360ad
9 changed files with 160 additions and 23 deletions
|
|
@ -7,7 +7,6 @@ from fooder.db import get_db_session
|
||||||
from fooder.domain import User
|
from fooder.domain import User
|
||||||
from fooder.repository import Repository
|
from fooder.repository import Repository
|
||||||
from fooder.utils.datetime import utc_now
|
from fooder.utils.datetime import utc_now
|
||||||
from fooder.utils.jwt import AccessToken
|
|
||||||
from fooder.exc import Unauthorized
|
from fooder.exc import Unauthorized
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -20,11 +19,13 @@ class Context:
|
||||||
self,
|
self,
|
||||||
repo: Repository,
|
repo: Repository,
|
||||||
clock: Callable[[], datetime] = utc_now,
|
clock: Callable[[], datetime] = utc_now,
|
||||||
_user: User | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.repo = repo
|
self.repo = repo
|
||||||
self.clock = clock
|
self.clock = clock
|
||||||
self._user = _user
|
self._user = None
|
||||||
|
|
||||||
|
def set_user(self, user: User) -> None:
|
||||||
|
self._user = user
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def user(self) -> User:
|
def user(self) -> User:
|
||||||
|
|
@ -60,9 +61,11 @@ class AuthContextDependency:
|
||||||
token: str = Depends(OAuth2PasswordBearer(tokenUrl="/token")),
|
token: str = Depends(OAuth2PasswordBearer(tokenUrl="/token")),
|
||||||
session: AsyncSession = Depends(get_db_session),
|
session: AsyncSession = Depends(get_db_session),
|
||||||
) -> Context:
|
) -> Context:
|
||||||
access_token = AccessToken.decode(token)
|
ctx = Context(repo = Repository(session))
|
||||||
repo = Repository(session)
|
from fooder.controller.token import TokenController
|
||||||
user = await repo.user.get(User.id == access_token.sub)
|
token_ctrl = TokenController.from_access_token(ctx, token)
|
||||||
|
user = await ctx.repo.user.get(User.id == token_ctrl.entity_id)
|
||||||
if user is None:
|
if user is None:
|
||||||
raise Unauthorized()
|
raise Unauthorized()
|
||||||
return Context(repo=repo, _user=user)
|
ctx.set_user(user)
|
||||||
|
return ctx
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,19 @@ class TokenController(ControllerBase):
|
||||||
super().__init__(ctx)
|
super().__init__(ctx)
|
||||||
self.entity_id = entity_id
|
self.entity_id = entity_id
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_token(cls, ctx: Context, token_str: str, token_cls: Type[T]) -> "TokenController":
|
||||||
|
token = token_cls.decode(token_str)
|
||||||
|
return cls(ctx, token.sub)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_refresh_token(cls, ctx: Context, token_str: str) -> "TokenController":
|
||||||
|
return cls.from_token(ctx, token_str, RefreshToken)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_access_token(cls, ctx: Context, token_str: str) -> "TokenController":
|
||||||
|
return cls.from_token(ctx, token_str, AccessToken)
|
||||||
|
|
||||||
def generate_token(self, token_cls: Type[T], now: datetime) -> T:
|
def generate_token(self, token_cls: Type[T], now: datetime) -> T:
|
||||||
return token_cls(exp=token_cls.calculate_exp(now), sub=self.entity_id)
|
return token_cls(exp=token_cls.calculate_exp(now), sub=self.entity_id)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,3 @@ class TokenResponse(BaseModel):
|
||||||
access_token: str
|
access_token: str
|
||||||
refresh_token: str
|
refresh_token: str
|
||||||
token_type: str = "bearer"
|
token_type: str = "bearer"
|
||||||
|
|
||||||
|
|
||||||
class RefreshTokenRequest(BaseModel):
|
|
||||||
refresh_token: str
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,46 @@
|
||||||
from fooder.controller import TokenController
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from fooder.controller.token import TokenController
|
||||||
|
from fooder.exc import Unauthorized
|
||||||
|
from fooder.utils.jwt import AccessToken, RefreshToken
|
||||||
|
|
||||||
|
NOW = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
def test_token_ctrl_generates_token(ctx):
|
def test_token_ctrl_generates_token(ctx):
|
||||||
token_ctrl = TokenController(ctx, 1)
|
token_ctrl = TokenController(ctx, 1)
|
||||||
token_ctrl.generate_token_pair(ctx.clock())
|
token_ctrl.generate_token_pair(ctx.clock())
|
||||||
|
|
||||||
|
|
||||||
|
class TestFromRefreshToken:
|
||||||
|
def test_returns_controller_with_correct_entity_id(self, ctx):
|
||||||
|
token = RefreshToken(exp=RefreshToken.calculate_exp(NOW), sub=42)
|
||||||
|
ctrl = TokenController.from_refresh_token(ctx, token.encode())
|
||||||
|
assert ctrl.entity_id == 42
|
||||||
|
|
||||||
|
def test_invalid_string_raises(self, ctx):
|
||||||
|
with pytest.raises(Unauthorized):
|
||||||
|
TokenController.from_refresh_token(ctx, "bad-token")
|
||||||
|
|
||||||
|
def test_access_token_raises(self, ctx):
|
||||||
|
token = AccessToken(exp=AccessToken.calculate_exp(NOW), sub=1)
|
||||||
|
with pytest.raises(Unauthorized):
|
||||||
|
TokenController.from_refresh_token(ctx, token.encode())
|
||||||
|
|
||||||
|
|
||||||
|
class TestFromAccessToken:
|
||||||
|
def test_returns_controller_with_correct_entity_id(self, ctx):
|
||||||
|
token = AccessToken(exp=AccessToken.calculate_exp(NOW), sub=7)
|
||||||
|
ctrl = TokenController.from_access_token(ctx, token.encode())
|
||||||
|
assert ctrl.entity_id == 7
|
||||||
|
|
||||||
|
def test_invalid_string_raises(self, ctx):
|
||||||
|
with pytest.raises(Unauthorized):
|
||||||
|
TokenController.from_access_token(ctx, "bad-token")
|
||||||
|
|
||||||
|
def test_refresh_token_raises(self, ctx):
|
||||||
|
token = RefreshToken(exp=RefreshToken.calculate_exp(NOW), sub=1)
|
||||||
|
with pytest.raises(Unauthorized):
|
||||||
|
TokenController.from_access_token(ctx, token.encode())
|
||||||
32
fooder/test/test_context.py
Normal file
32
fooder/test/test_context.py
Normal file
|
|
@ -0,0 +1,32 @@
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from fooder.context import AuthContextDependency
|
||||||
|
from fooder.exc import Unauthorized
|
||||||
|
from fooder.utils.jwt import AccessToken, RefreshToken
|
||||||
|
|
||||||
|
NOW = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_auth_context_valid_token_returns_correct_user(db_session, user):
|
||||||
|
token = AccessToken(exp=AccessToken.calculate_exp(NOW), sub=user.id)
|
||||||
|
ctx = await AuthContextDependency()(token=token.encode(), session=db_session)
|
||||||
|
assert ctx.user.id == user.id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_auth_context_invalid_token_raises(db_session):
|
||||||
|
with pytest.raises(Unauthorized):
|
||||||
|
await AuthContextDependency()(token="bad-token", session=db_session)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_auth_context_refresh_token_raises(db_session, user):
|
||||||
|
token = RefreshToken(exp=RefreshToken.calculate_exp(NOW), sub=user.id)
|
||||||
|
with pytest.raises(Unauthorized):
|
||||||
|
await AuthContextDependency()(token=token.encode(), session=db_session)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_auth_context_unknown_user_raises(db_session):
|
||||||
|
token = AccessToken(exp=AccessToken.calculate_exp(NOW), sub=99999)
|
||||||
|
with pytest.raises(Unauthorized):
|
||||||
|
await AuthContextDependency()(token=token.encode(), session=db_session)
|
||||||
|
|
@ -11,7 +11,6 @@ PAST = datetime(2000, 1, 1, tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
class WrongKeyToken(Token):
|
class WrongKeyToken(Token):
|
||||||
token_type: Literal["test-type"] = "test-type"
|
|
||||||
secret_key = "wrong-secret"
|
secret_key = "wrong-secret"
|
||||||
expire_delta = timedelta(minutes=30)
|
expire_delta = timedelta(minutes=30)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -45,3 +45,58 @@ async def test_create_token_unknown_user(client):
|
||||||
data={"username": "nobody", "password": "x"},
|
data={"username": "nobody", "password": "x"},
|
||||||
)
|
)
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
async def test_refresh_token_returns_new_tokens(client, user, user_password):
|
||||||
|
response = await client.post(
|
||||||
|
"/api/token",
|
||||||
|
data={"username": user.username, "password": user_password},
|
||||||
|
)
|
||||||
|
refresh_token = response.json()["refresh_token"]
|
||||||
|
|
||||||
|
response = await client.post("/api/token/refresh", params={"refresh_token": refresh_token})
|
||||||
|
assert response.status_code == 200
|
||||||
|
body = response.json()
|
||||||
|
assert "access_token" in body
|
||||||
|
assert "refresh_token" in body
|
||||||
|
assert body["token_type"] == "bearer"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_refresh_token_access_token_is_valid(client, user, user_password):
|
||||||
|
response = await client.post(
|
||||||
|
"/api/token",
|
||||||
|
data={"username": user.username, "password": user_password},
|
||||||
|
)
|
||||||
|
refresh_token = response.json()["refresh_token"]
|
||||||
|
|
||||||
|
response = await client.post("/api/token/refresh", params={"refresh_token": refresh_token})
|
||||||
|
token = AccessToken.decode(response.json()["access_token"])
|
||||||
|
assert token.sub == user.id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_refresh_token_refresh_token_is_valid(client, user, user_password):
|
||||||
|
response = await client.post(
|
||||||
|
"/api/token",
|
||||||
|
data={"username": user.username, "password": user_password},
|
||||||
|
)
|
||||||
|
refresh_token = response.json()["refresh_token"]
|
||||||
|
|
||||||
|
response = await client.post("/api/token/refresh", params={"refresh_token": refresh_token})
|
||||||
|
token = RefreshToken.decode(response.json()["refresh_token"])
|
||||||
|
assert token.sub == user.id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_refresh_token_invalid_returns_401(client):
|
||||||
|
response = await client.post("/api/token/refresh", params={"refresh_token": "bad-token"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
async def test_refresh_token_access_token_as_refresh_returns_401(client, user, user_password):
|
||||||
|
response = await client.post(
|
||||||
|
"/api/token",
|
||||||
|
data={"username": user.username, "password": user_password},
|
||||||
|
)
|
||||||
|
access_token = response.json()["access_token"]
|
||||||
|
|
||||||
|
response = await client.post("/api/token/refresh", params={"refresh_token": access_token})
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
@ -8,7 +8,6 @@ from fooder.exc import Unauthorized
|
||||||
|
|
||||||
|
|
||||||
class Token(BaseModel):
|
class Token(BaseModel):
|
||||||
token_type: str
|
|
||||||
exp: datetime
|
exp: datetime
|
||||||
sub: int
|
sub: int
|
||||||
|
|
||||||
|
|
@ -37,12 +36,10 @@ class Token(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class AccessToken(Token):
|
class AccessToken(Token):
|
||||||
token_type: Literal["access"] = "access"
|
|
||||||
secret_key = settings.SECRET_KEY
|
secret_key = settings.SECRET_KEY
|
||||||
expire_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
expire_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||||
|
|
||||||
|
|
||||||
class RefreshToken(Token):
|
class RefreshToken(Token):
|
||||||
token_type: Literal["refresh"] = "refresh"
|
|
||||||
secret_key = settings.REFRESH_SECRET_KEY
|
secret_key = settings.REFRESH_SECRET_KEY
|
||||||
expire_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
expire_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||||
|
|
|
||||||
|
|
@ -4,15 +4,15 @@ from fastapi import APIRouter, Depends
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from fooder.model.token import TokenResponse, RefreshTokenRequest
|
from fooder.model.token import TokenResponse
|
||||||
from fooder.context import ContextDependency, Context
|
from fooder.context import ContextDependency, Context
|
||||||
from fooder.controller import UserController
|
from fooder.controller import UserController
|
||||||
|
from fooder.controller.token import TokenController
|
||||||
|
|
||||||
router = APIRouter(tags=["token"])
|
router = APIRouter(tags=["token"])
|
||||||
|
|
||||||
|
|
||||||
def gen_token_response(user_ctrl: UserController, now: datetime) -> TokenResponse:
|
def gen_token_response(token_ctrl: TokenController, now: datetime) -> TokenResponse:
|
||||||
token_ctrl = user_ctrl.token_ctrl()
|
|
||||||
access_token, refresh_token = token_ctrl.generate_token_pair(now)
|
access_token, refresh_token = token_ctrl.generate_token_pair(now)
|
||||||
return TokenResponse(
|
return TokenResponse(
|
||||||
access_token=access_token.encode(),
|
access_token=access_token.encode(),
|
||||||
|
|
@ -27,12 +27,14 @@ async def token_create(
|
||||||
) -> TokenResponse:
|
) -> TokenResponse:
|
||||||
now = ctx.clock()
|
now = ctx.clock()
|
||||||
user_ctrl = await UserController.session_start(ctx, data.username, data.password)
|
user_ctrl = await UserController.session_start(ctx, data.username, data.password)
|
||||||
return gen_token_response(user_ctrl, now)
|
return gen_token_response(user_ctrl.token_ctrl(), now)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/refresh", response_model=TokenResponse)
|
@router.post("/refresh", response_model=TokenResponse)
|
||||||
async def token_refresh(
|
async def token_refresh(
|
||||||
data: RefreshTokenRequest,
|
refresh_token: str,
|
||||||
ctx: Context = Depends(ContextDependency()),
|
ctx: Context = Depends(ContextDependency()),
|
||||||
):
|
) -> TokenResponse:
|
||||||
pass
|
now = ctx.clock()
|
||||||
|
token_ctrl = TokenController.from_refresh_token(ctx, refresh_token)
|
||||||
|
return gen_token_response(token_ctrl, now)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue