66 lines
2 KiB
Python
66 lines
2 KiB
Python
from datetime import datetime, timedelta, timezone
|
|
import pytest
|
|
from jose import JWTError
|
|
|
|
from fooder.utils.jwt import AccessToken, RefreshToken, Token
|
|
|
|
|
|
# Fixed, timezone-aware instant far in the past; used as an ``exp`` value
# that is guaranteed to have already elapsed.
PAST = datetime(year=2000, month=1, day=1, tzinfo=timezone.utc)
|
|
|
|
|
|
class WrongKeyToken(Token):
    """Test-only token type whose signing secret is ``"wrong-secret"``.

    Tokens produced by this class are fed to other token classes' ``decode``
    to check that a foreign signature is rejected.
    """

    # Lifetime used by ``calculate_exp`` (presumably ``now + expire_delta``;
    # confirm against ``Token``).
    expire_delta = timedelta(minutes=30)
    # Deliberately different secret; presumably not AccessToken's key.
    secret_key = "wrong-secret"
|
|
|
|
|
|
class TestAccessToken:
    """Encode/decode and rejection tests for ``AccessToken``."""

    def test_encode_decode_roundtrip(self):
        """An encoded access token decodes back to the same subject."""
        now = datetime.now(timezone.utc)
        token = AccessToken(exp=AccessToken.calculate_exp(now), sub=42)

        decoded = AccessToken.decode(token.encode())

        assert decoded.sub == token.sub

    def test_calculate_exp(self):
        """The computed expiry lies strictly after the reference time."""
        now = datetime.now(timezone.utc)

        assert AccessToken.calculate_exp(now) > now

    def test_decode_wrong_key_raises(self):
        """A token encoded by ``WrongKeyToken`` is rejected by ``AccessToken.decode``."""
        now = datetime.now(timezone.utc)
        token = WrongKeyToken(exp=WrongKeyToken.calculate_exp(now), sub=1)
        # Encode outside ``pytest.raises`` so that only ``decode`` can satisfy
        # the expected JWTError; an error raised by ``encode`` must fail the
        # test instead of silently passing it.
        encoded = token.encode()

        with pytest.raises(JWTError):
            AccessToken.decode(encoded)

    def test_decode_expired_raises(self):
        """Decoding an access token whose ``exp`` is in the past raises JWTError."""
        token = AccessToken(exp=PAST, sub=1)
        # Keep ``encode`` out of the raises block: if encoding ever rejected
        # the past expiry, this test would otherwise pass without exercising
        # ``decode`` at all.
        encoded = token.encode()

        with pytest.raises(JWTError):
            AccessToken.decode(encoded)
|
|
|
|
|
|
class TestRefreshToken:
    """Encode/decode tests for ``RefreshToken`` and its separation from ``AccessToken``."""

    def test_encode_decode_roundtrip(self):
        """An encoded refresh token decodes back to the same subject."""
        now = datetime.now(timezone.utc)
        token = RefreshToken(exp=RefreshToken.calculate_exp(now), sub=7)

        decoded = RefreshToken.decode(token.encode())

        assert decoded.sub == token.sub

    def test_calculate_exp(self):
        """The computed expiry lies strictly after the reference time."""
        now = datetime.now(timezone.utc)

        assert RefreshToken.calculate_exp(now) > now

    def test_refresh_token_not_decodable_as_access_token(self):
        """``AccessToken.decode`` rejects a token produced by ``RefreshToken``."""
        now = datetime.now(timezone.utc)
        token = RefreshToken(exp=RefreshToken.calculate_exp(now), sub=1)
        # Encode before entering ``pytest.raises`` so the expected JWTError can
        # only come from the cross-type ``decode`` call.
        encoded = token.encode()

        with pytest.raises(JWTError):
            AccessToken.decode(encoded)

    def test_access_token_not_decodable_as_refresh_token(self):
        """``RefreshToken.decode`` rejects a token produced by ``AccessToken``."""
        now = datetime.now(timezone.utc)
        token = AccessToken(exp=AccessToken.calculate_exp(now), sub=1)
        # Same reasoning as above: only ``decode`` belongs in the raises block.
        encoded = token.encode()

        with pytest.raises(JWTError):
            RefreshToken.decode(encoded)
|