[token] implement jwt tokens + first view!
This commit is contained in:
parent
20ffc18044
commit
10ef646d93
12 changed files with 240 additions and 40 deletions
|
|
@ -13,21 +13,7 @@ from .domain.token import RefreshToken
|
||||||
from .domain.user import User
|
from .domain.user import User
|
||||||
from .settings import Settings
|
from .settings import Settings
|
||||||
|
|
||||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token")
|
||||||
settings = Settings()
|
|
||||||
password_helper = PasswordHelper(pwd_context) # type: ignore
|
|
||||||
|
|
||||||
AsyncSessionDependency = Annotated[async_sessionmaker, Depends(get_db_session)]
|
|
||||||
TokenDependency = Annotated[str, Depends(oauth2_scheme)]
|
|
||||||
|
|
||||||
|
|
||||||
def verify(plain_password: str, hashed_password: str) -> bool:
|
|
||||||
return pwd_context.verify(plain_password, hashed_password)
|
|
||||||
|
|
||||||
|
|
||||||
def hash(password: str) -> str:
|
|
||||||
return pwd_context.hash(password)
|
|
||||||
|
|
||||||
|
|
||||||
async def authenticate_user(
|
async def authenticate_user(
|
||||||
|
|
@ -137,9 +123,3 @@ async def get_current_user(ssn: AsyncSessionDependency, token: TokenDependency)
|
||||||
|
|
||||||
assert user is not None
|
assert user is not None
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
async def authorize_api_key(token: TokenDependency) -> None:
|
|
||||||
if token == settings.API_KEY:
|
|
||||||
return None
|
|
||||||
raise HTTPException(status_code=401, detail="Unathorized")
|
|
||||||
|
|
|
||||||
|
|
@ -5,4 +5,11 @@ from ..domain import User
|
||||||
|
|
||||||
class Repository:
|
class Repository:
|
||||||
def __init__(self, session: AsyncSession):
|
def __init__(self, session: AsyncSession):
|
||||||
|
self.session = session
|
||||||
self.user = UserRepository(User, session)
|
self.user = UserRepository(User, session)
|
||||||
|
|
||||||
|
async def commit(self) -> None:
|
||||||
|
await self.session.commit()
|
||||||
|
|
||||||
|
async def rollback(self) -> None:
|
||||||
|
await self.session.rollback()
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,18 @@
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
from .view.diary import router as diary_router
|
# from .view.diary import router as diary_router
|
||||||
from .view.entry import router as entry_router
|
# from .view.entry import router as entry_router
|
||||||
from .view.meal import router as meal_router
|
# from .view.meal import router as meal_router
|
||||||
from .view.preset import router as preset_router
|
# from .view.preset import router as preset_router
|
||||||
from .view.product import router as product_router
|
# from .view.product import router as product_router
|
||||||
from .view.token import router as token_router
|
from .view.token import router as token_router
|
||||||
from .view.user import router as user_router
|
# from .view.user import router as user_router
|
||||||
|
|
||||||
router = APIRouter(prefix="/api")
|
router = APIRouter(prefix="/api")
|
||||||
router.include_router(product_router, prefix="/product", tags=["product"])
|
# router.include_router(product_router, prefix="/product", tags=["product"])
|
||||||
router.include_router(diary_router, prefix="/diary", tags=["diary"])
|
# router.include_router(diary_router, prefix="/diary", tags=["diary"])
|
||||||
router.include_router(meal_router, prefix="/meal", tags=["meal"])
|
# router.include_router(meal_router, prefix="/meal", tags=["meal"])
|
||||||
router.include_router(entry_router, prefix="/entry", tags=["entry"])
|
# router.include_router(entry_router, prefix="/entry", tags=["entry"])
|
||||||
router.include_router(token_router, prefix="/token", tags=["token"])
|
router.include_router(token_router, prefix="/token", tags=["token"])
|
||||||
router.include_router(user_router, prefix="/user", tags=["user"])
|
# router.include_router(user_router, prefix="/user", tags=["user"])
|
||||||
router.include_router(preset_router, prefix="/preset", tags=["preset"])
|
# router.include_router(preset_router, prefix="/preset", tags=["preset"])
|
||||||
|
|
|
||||||
3
fooder/test/fixtures/__init__.py
vendored
3
fooder/test/fixtures/__init__.py
vendored
|
|
@ -1,6 +1,9 @@
|
||||||
import pytest
|
import pytest
|
||||||
from .db import *
|
from .db import *
|
||||||
from .faker import *
|
from .faker import *
|
||||||
|
from .user import *
|
||||||
|
from .client import *
|
||||||
|
from .context import *
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
||||||
16
fooder/test/fixtures/client.py
vendored
Normal file
16
fooder/test/fixtures/client.py
vendored
Normal file
|
|
@ -0,0 +1,16 @@
|
||||||
|
import pytest_asyncio
|
||||||
|
from httpx import AsyncClient, ASGITransport
|
||||||
|
|
||||||
|
from fooder.app import app
|
||||||
|
from fooder.db import get_db_session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def client(db_session):
|
||||||
|
async def override_get_db_session():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_db_session] = override_get_db_session
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c:
|
||||||
|
yield c
|
||||||
|
app.dependency_overrides.clear()
|
||||||
8
fooder/test/fixtures/context.py
vendored
Normal file
8
fooder/test/fixtures/context.py
vendored
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
import pytest
|
||||||
|
from fooder.context import Context
|
||||||
|
from fooder.repository import Repository
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def ctx(db_session):
|
||||||
|
return Context(repo=Repository(db_session))
|
||||||
25
fooder/test/fixtures/user.py
vendored
Normal file
25
fooder/test/fixtures/user.py
vendored
Normal file
|
|
@ -0,0 +1,25 @@
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from fooder.domain.user import User
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def user_password(faker):
|
||||||
|
return faker.password()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def user_factory(ctx):
|
||||||
|
async def factory(username, password):
|
||||||
|
user = User(username=username)
|
||||||
|
user.set_password(password)
|
||||||
|
await ctx.repo.user.create(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
return factory
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def user(faker, user_password, user_factory):
|
||||||
|
return await user_factory(faker.name(), user_password)
|
||||||
66
fooder/test/utils/test_jwt.py
Normal file
66
fooder/test/utils/test_jwt.py
Normal file
|
|
@ -0,0 +1,66 @@
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
import pytest
|
||||||
|
from jose import JWTError
|
||||||
|
|
||||||
|
from fooder.utils.jwt import AccessToken, RefreshToken, Token
|
||||||
|
|
||||||
|
|
||||||
|
PAST = datetime(2000, 1, 1, tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
class WrongKeyToken(Token):
|
||||||
|
secret_key = "wrong-secret"
|
||||||
|
expire_delta = timedelta(minutes=30)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAccessToken:
|
||||||
|
def test_encode_decode_roundtrip(self):
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
token = AccessToken(exp=AccessToken.calculate_exp(now), sub=42)
|
||||||
|
decoded = AccessToken.decode(token.encode())
|
||||||
|
|
||||||
|
assert decoded.sub == token.sub
|
||||||
|
|
||||||
|
def test_calculate_exp(self):
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
assert AccessToken.calculate_exp(now) > now
|
||||||
|
|
||||||
|
def test_decode_wrong_key_raises(self):
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
token = WrongKeyToken(exp=WrongKeyToken.calculate_exp(now), sub=1)
|
||||||
|
|
||||||
|
with pytest.raises(JWTError):
|
||||||
|
AccessToken.decode(token.encode())
|
||||||
|
|
||||||
|
def test_decode_expired_raises(self):
|
||||||
|
token = AccessToken(exp=PAST, sub=1)
|
||||||
|
|
||||||
|
with pytest.raises(JWTError):
|
||||||
|
AccessToken.decode(token.encode())
|
||||||
|
|
||||||
|
|
||||||
|
class TestRefreshToken:
|
||||||
|
def test_encode_decode_roundtrip(self):
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
token = RefreshToken(exp=RefreshToken.calculate_exp(now), sub=7)
|
||||||
|
decoded = RefreshToken.decode(token.encode())
|
||||||
|
|
||||||
|
assert decoded.sub == token.sub
|
||||||
|
|
||||||
|
def test_calculate_exp(self):
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
assert RefreshToken.calculate_exp(now) > now
|
||||||
|
|
||||||
|
def test_refresh_token_not_decodable_as_access_token(self):
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
token = RefreshToken(exp=RefreshToken.calculate_exp(now), sub=1)
|
||||||
|
|
||||||
|
with pytest.raises(JWTError):
|
||||||
|
AccessToken.decode(token.encode())
|
||||||
|
|
||||||
|
def test_access_token_not_decodable_as_refresh_token(self):
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
token = AccessToken(exp=AccessToken.calculate_exp(now), sub=1)
|
||||||
|
|
||||||
|
with pytest.raises(JWTError):
|
||||||
|
RefreshToken.decode(token.encode())
|
||||||
0
fooder/test/view/__init__.py
Normal file
0
fooder/test/view/__init__.py
Normal file
47
fooder/test/view/test_token.py
Normal file
47
fooder/test/view/test_token.py
Normal file
|
|
@ -0,0 +1,47 @@
|
||||||
|
from fooder.utils.jwt import AccessToken, RefreshToken
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_token_returns_tokens(client, user, user_password):
|
||||||
|
response = await client.post(
|
||||||
|
"/api/token",
|
||||||
|
data={"username": user.username, "password": user_password},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
body = response.json()
|
||||||
|
assert "access_token" in body
|
||||||
|
assert "refresh_token" in body
|
||||||
|
assert body["token_type"] == "bearer"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_token_access_token_is_valid(client, user, user_password):
|
||||||
|
response = await client.post(
|
||||||
|
"/api/token",
|
||||||
|
data={"username": user.username, "password": user_password},
|
||||||
|
)
|
||||||
|
token = AccessToken.decode(response.json()["access_token"])
|
||||||
|
assert token.sub == user.id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_token_refresh_token_is_valid(client, user, user_password):
|
||||||
|
response = await client.post(
|
||||||
|
"/api/token",
|
||||||
|
data={"username": user.username, "password": user_password},
|
||||||
|
)
|
||||||
|
token = RefreshToken.decode(response.json()["refresh_token"])
|
||||||
|
assert token.sub == user.id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_token_wrong_password(client, user):
|
||||||
|
response = await client.post(
|
||||||
|
"/api/token",
|
||||||
|
data={"username": user.username, "password": "wrong"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_token_unknown_user(client):
|
||||||
|
response = await client.post(
|
||||||
|
"/api/token",
|
||||||
|
data={"username": "nobody", "password": "x"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 401
|
||||||
37
fooder/utils/jwt.py
Normal file
37
fooder/utils/jwt.py
Normal file
|
|
@ -0,0 +1,37 @@
|
||||||
|
from jose import jwt
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from datetime import timedelta, datetime
|
||||||
|
from typing import ClassVar
|
||||||
|
from ..settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class Token(BaseModel):
|
||||||
|
exp: datetime
|
||||||
|
sub: int
|
||||||
|
|
||||||
|
secret_key: ClassVar[str]
|
||||||
|
expire_delta: ClassVar[timedelta]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def calculate_exp(cls, now: datetime) -> datetime:
|
||||||
|
return now + cls.expire_delta
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def decode(cls, jwt_token: str | bytes) -> "Token":
|
||||||
|
data = jwt.decode(jwt_token, cls.secret_key, algorithms=[settings.ALGORITHM])
|
||||||
|
return cls(**data)
|
||||||
|
|
||||||
|
def encode(self) -> str:
|
||||||
|
data = self.model_dump()
|
||||||
|
data["sub"] = str(data["sub"])
|
||||||
|
return jwt.encode(data, self.secret_key, settings.ALGORITHM)
|
||||||
|
|
||||||
|
|
||||||
|
class AccessToken(Token):
|
||||||
|
secret_key = settings.SECRET_KEY
|
||||||
|
expire_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshToken(Token):
|
||||||
|
secret_key = settings.REFRESH_SECRET_KEY
|
||||||
|
expire_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||||
|
|
@ -1,27 +1,38 @@
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Request
|
from datetime import datetime, timezone
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
|
||||||
from ..controller.token import CreateToken, RefreshToken
|
|
||||||
from ..model.token import RefreshTokenPayload, Token
|
from ..model.token import RefreshTokenPayload, Token
|
||||||
|
from ..context import ContextDependency, Context
|
||||||
|
from ..utils.jwt import AccessToken, RefreshToken
|
||||||
|
|
||||||
router = APIRouter(tags=["token"])
|
router = APIRouter(tags=["token"])
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=Token)
|
@router.post("", response_model=Token)
|
||||||
async def create_token(
|
async def create_token(
|
||||||
request: Request,
|
|
||||||
data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||||
controller: CreateToken = Depends(CreateToken),
|
ctx: Context = Depends(ContextDependency()),
|
||||||
):
|
):
|
||||||
return await controller.call(data)
|
user = await ctx.repo.user.get(username=data.username)
|
||||||
|
|
||||||
|
if user is None or not user.verify_password(data.password):
|
||||||
|
raise HTTPException(status_code=401, detail="Unathorized")
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
access_token = AccessToken(sub=user.id, exp=AccessToken.calculate_exp(now))
|
||||||
|
refresh_token = RefreshToken(sub=user.id, exp=RefreshToken.calculate_exp(now))
|
||||||
|
return Token(
|
||||||
|
access_token=access_token.encode(),
|
||||||
|
refresh_token=refresh_token.encode(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/refresh", response_model=Token)
|
@router.post("/refresh", response_model=Token)
|
||||||
async def refresh_token(
|
async def refresh_token(
|
||||||
request: Request,
|
|
||||||
data: RefreshTokenPayload,
|
data: RefreshTokenPayload,
|
||||||
controller: RefreshToken = Depends(RefreshToken),
|
ctx: Context = Depends(ContextDependency()),
|
||||||
):
|
):
|
||||||
return await controller.call(data)
|
pass
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue