[token] implement jwt tokens + first view!
This commit is contained in:
parent
20ffc18044
commit
10ef646d93
12 changed files with 240 additions and 40 deletions
|
|
@ -13,21 +13,7 @@ from .domain.token import RefreshToken
|
|||
from .domain.user import User
|
||||
from .settings import Settings
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token")
|
||||
settings = Settings()
|
||||
password_helper = PasswordHelper(pwd_context) # type: ignore
|
||||
|
||||
AsyncSessionDependency = Annotated[async_sessionmaker, Depends(get_db_session)]
|
||||
TokenDependency = Annotated[str, Depends(oauth2_scheme)]
|
||||
|
||||
|
||||
def verify(plain_password: str, hashed_password: str) -> bool:
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def hash(password: str) -> str:
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
async def authenticate_user(
|
||||
|
|
@ -137,9 +123,3 @@ async def get_current_user(ssn: AsyncSessionDependency, token: TokenDependency)
|
|||
|
||||
assert user is not None
|
||||
return user
|
||||
|
||||
|
||||
async def authorize_api_key(token: TokenDependency) -> None:
|
||||
if token == settings.API_KEY:
|
||||
return None
|
||||
raise HTTPException(status_code=401, detail="Unathorized")
|
||||
|
|
|
|||
|
|
@ -5,4 +5,11 @@ from ..domain import User
|
|||
|
||||
class Repository:
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
self.user = UserRepository(User, session)
|
||||
|
||||
async def commit(self) -> None:
|
||||
await self.session.commit()
|
||||
|
||||
async def rollback(self) -> None:
|
||||
await self.session.rollback()
|
||||
|
|
|
|||
|
|
@ -1,18 +1,18 @@
|
|||
from fastapi import APIRouter
|
||||
|
||||
from .view.diary import router as diary_router
|
||||
from .view.entry import router as entry_router
|
||||
from .view.meal import router as meal_router
|
||||
from .view.preset import router as preset_router
|
||||
from .view.product import router as product_router
|
||||
# from .view.diary import router as diary_router
|
||||
# from .view.entry import router as entry_router
|
||||
# from .view.meal import router as meal_router
|
||||
# from .view.preset import router as preset_router
|
||||
# from .view.product import router as product_router
|
||||
from .view.token import router as token_router
|
||||
from .view.user import router as user_router
|
||||
# from .view.user import router as user_router
|
||||
|
||||
router = APIRouter(prefix="/api")
|
||||
router.include_router(product_router, prefix="/product", tags=["product"])
|
||||
router.include_router(diary_router, prefix="/diary", tags=["diary"])
|
||||
router.include_router(meal_router, prefix="/meal", tags=["meal"])
|
||||
router.include_router(entry_router, prefix="/entry", tags=["entry"])
|
||||
# router.include_router(product_router, prefix="/product", tags=["product"])
|
||||
# router.include_router(diary_router, prefix="/diary", tags=["diary"])
|
||||
# router.include_router(meal_router, prefix="/meal", tags=["meal"])
|
||||
# router.include_router(entry_router, prefix="/entry", tags=["entry"])
|
||||
router.include_router(token_router, prefix="/token", tags=["token"])
|
||||
router.include_router(user_router, prefix="/user", tags=["user"])
|
||||
router.include_router(preset_router, prefix="/preset", tags=["preset"])
|
||||
# router.include_router(user_router, prefix="/user", tags=["user"])
|
||||
# router.include_router(preset_router, prefix="/preset", tags=["preset"])
|
||||
|
|
|
|||
3
fooder/test/fixtures/__init__.py
vendored
3
fooder/test/fixtures/__init__.py
vendored
|
|
@ -1,6 +1,9 @@
|
|||
import pytest
|
||||
from .db import *
|
||||
from .faker import *
|
||||
from .user import *
|
||||
from .client import *
|
||||
from .context import *
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
16
fooder/test/fixtures/client.py
vendored
Normal file
16
fooder/test/fixtures/client.py
vendored
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
import pytest_asyncio
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
|
||||
from fooder.app import app
|
||||
from fooder.db import get_db_session
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(db_session):
|
||||
async def override_get_db_session():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db_session] = override_get_db_session
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c:
|
||||
yield c
|
||||
app.dependency_overrides.clear()
|
||||
8
fooder/test/fixtures/context.py
vendored
Normal file
8
fooder/test/fixtures/context.py
vendored
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
import pytest
|
||||
from fooder.context import Context
|
||||
from fooder.repository import Repository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ctx(db_session):
|
||||
return Context(repo=Repository(db_session))
|
||||
25
fooder/test/fixtures/user.py
vendored
Normal file
25
fooder/test/fixtures/user.py
vendored
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from fooder.domain.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_password(faker):
|
||||
return faker.password()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def user_factory(ctx):
|
||||
async def factory(username, password):
|
||||
user = User(username=username)
|
||||
user.set_password(password)
|
||||
await ctx.repo.user.create(user)
|
||||
return user
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def user(faker, user_password, user_factory):
|
||||
return await user_factory(faker.name(), user_password)
|
||||
66
fooder/test/utils/test_jwt.py
Normal file
66
fooder/test/utils/test_jwt.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
from datetime import datetime, timedelta, timezone
|
||||
import pytest
|
||||
from jose import JWTError
|
||||
|
||||
from fooder.utils.jwt import AccessToken, RefreshToken, Token
|
||||
|
||||
|
||||
PAST = datetime(2000, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
class WrongKeyToken(Token):
|
||||
secret_key = "wrong-secret"
|
||||
expire_delta = timedelta(minutes=30)
|
||||
|
||||
|
||||
class TestAccessToken:
|
||||
def test_encode_decode_roundtrip(self):
|
||||
now = datetime.now(timezone.utc)
|
||||
token = AccessToken(exp=AccessToken.calculate_exp(now), sub=42)
|
||||
decoded = AccessToken.decode(token.encode())
|
||||
|
||||
assert decoded.sub == token.sub
|
||||
|
||||
def test_calculate_exp(self):
|
||||
now = datetime.now(timezone.utc)
|
||||
assert AccessToken.calculate_exp(now) > now
|
||||
|
||||
def test_decode_wrong_key_raises(self):
|
||||
now = datetime.now(timezone.utc)
|
||||
token = WrongKeyToken(exp=WrongKeyToken.calculate_exp(now), sub=1)
|
||||
|
||||
with pytest.raises(JWTError):
|
||||
AccessToken.decode(token.encode())
|
||||
|
||||
def test_decode_expired_raises(self):
|
||||
token = AccessToken(exp=PAST, sub=1)
|
||||
|
||||
with pytest.raises(JWTError):
|
||||
AccessToken.decode(token.encode())
|
||||
|
||||
|
||||
class TestRefreshToken:
|
||||
def test_encode_decode_roundtrip(self):
|
||||
now = datetime.now(timezone.utc)
|
||||
token = RefreshToken(exp=RefreshToken.calculate_exp(now), sub=7)
|
||||
decoded = RefreshToken.decode(token.encode())
|
||||
|
||||
assert decoded.sub == token.sub
|
||||
|
||||
def test_calculate_exp(self):
|
||||
now = datetime.now(timezone.utc)
|
||||
assert RefreshToken.calculate_exp(now) > now
|
||||
|
||||
def test_refresh_token_not_decodable_as_access_token(self):
|
||||
now = datetime.now(timezone.utc)
|
||||
token = RefreshToken(exp=RefreshToken.calculate_exp(now), sub=1)
|
||||
|
||||
with pytest.raises(JWTError):
|
||||
AccessToken.decode(token.encode())
|
||||
|
||||
def test_access_token_not_decodable_as_refresh_token(self):
|
||||
now = datetime.now(timezone.utc)
|
||||
token = AccessToken(exp=AccessToken.calculate_exp(now), sub=1)
|
||||
|
||||
with pytest.raises(JWTError):
|
||||
RefreshToken.decode(token.encode())
|
||||
0
fooder/test/view/__init__.py
Normal file
0
fooder/test/view/__init__.py
Normal file
47
fooder/test/view/test_token.py
Normal file
47
fooder/test/view/test_token.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
from fooder.utils.jwt import AccessToken, RefreshToken
|
||||
|
||||
|
||||
async def test_create_token_returns_tokens(client, user, user_password):
|
||||
response = await client.post(
|
||||
"/api/token",
|
||||
data={"username": user.username, "password": user_password},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert "access_token" in body
|
||||
assert "refresh_token" in body
|
||||
assert body["token_type"] == "bearer"
|
||||
|
||||
|
||||
async def test_create_token_access_token_is_valid(client, user, user_password):
|
||||
response = await client.post(
|
||||
"/api/token",
|
||||
data={"username": user.username, "password": user_password},
|
||||
)
|
||||
token = AccessToken.decode(response.json()["access_token"])
|
||||
assert token.sub == user.id
|
||||
|
||||
|
||||
async def test_create_token_refresh_token_is_valid(client, user, user_password):
|
||||
response = await client.post(
|
||||
"/api/token",
|
||||
data={"username": user.username, "password": user_password},
|
||||
)
|
||||
token = RefreshToken.decode(response.json()["refresh_token"])
|
||||
assert token.sub == user.id
|
||||
|
||||
|
||||
async def test_create_token_wrong_password(client, user):
|
||||
response = await client.post(
|
||||
"/api/token",
|
||||
data={"username": user.username, "password": "wrong"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
async def test_create_token_unknown_user(client):
|
||||
response = await client.post(
|
||||
"/api/token",
|
||||
data={"username": "nobody", "password": "x"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
37
fooder/utils/jwt.py
Normal file
37
fooder/utils/jwt.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
from jose import jwt
|
||||
from pydantic import BaseModel
|
||||
from datetime import timedelta, datetime
|
||||
from typing import ClassVar
|
||||
from ..settings import settings
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
exp: datetime
|
||||
sub: int
|
||||
|
||||
secret_key: ClassVar[str]
|
||||
expire_delta: ClassVar[timedelta]
|
||||
|
||||
@classmethod
|
||||
def calculate_exp(cls, now: datetime) -> datetime:
|
||||
return now + cls.expire_delta
|
||||
|
||||
@classmethod
|
||||
def decode(cls, jwt_token: str | bytes) -> "Token":
|
||||
data = jwt.decode(jwt_token, cls.secret_key, algorithms=[settings.ALGORITHM])
|
||||
return cls(**data)
|
||||
|
||||
def encode(self) -> str:
|
||||
data = self.model_dump()
|
||||
data["sub"] = str(data["sub"])
|
||||
return jwt.encode(data, self.secret_key, settings.ALGORITHM)
|
||||
|
||||
|
||||
class AccessToken(Token):
|
||||
secret_key = settings.SECRET_KEY
|
||||
expire_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
|
||||
class RefreshToken(Token):
|
||||
secret_key = settings.REFRESH_SECRET_KEY
|
||||
expire_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
|
@ -1,27 +1,38 @@
|
|||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from datetime import datetime, timezone
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from ..controller.token import CreateToken, RefreshToken
|
||||
from ..model.token import RefreshTokenPayload, Token
|
||||
from ..context import ContextDependency, Context
|
||||
from ..utils.jwt import AccessToken, RefreshToken
|
||||
|
||||
router = APIRouter(tags=["token"])
|
||||
|
||||
|
||||
@router.post("", response_model=Token)
|
||||
async def create_token(
|
||||
request: Request,
|
||||
data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
controller: CreateToken = Depends(CreateToken),
|
||||
ctx: Context = Depends(ContextDependency()),
|
||||
):
|
||||
return await controller.call(data)
|
||||
user = await ctx.repo.user.get(username=data.username)
|
||||
|
||||
if user is None or not user.verify_password(data.password):
|
||||
raise HTTPException(status_code=401, detail="Unathorized")
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
access_token = AccessToken(sub=user.id, exp=AccessToken.calculate_exp(now))
|
||||
refresh_token = RefreshToken(sub=user.id, exp=RefreshToken.calculate_exp(now))
|
||||
return Token(
|
||||
access_token=access_token.encode(),
|
||||
refresh_token=refresh_token.encode(),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=Token)
|
||||
async def refresh_token(
|
||||
request: Request,
|
||||
data: RefreshTokenPayload,
|
||||
controller: RefreshToken = Depends(RefreshToken),
|
||||
ctx: Context = Depends(ContextDependency()),
|
||||
):
|
||||
return await controller.call(data)
|
||||
pass
|
||||
|
|
|
|||
Loading…
Reference in a new issue