add refresh tokens

This commit is contained in:
Piotr Domański 2023-04-02 14:38:22 +02:00
parent e537bef22e
commit 28d55d3632
8 changed files with 142 additions and 7 deletions

View file

@ -7,5 +7,7 @@ DB_URI="postgresql+asyncpg://${POSTGRES_USER}:${POSTGRES_PASSWORD}@database/${PO
ECHO_SQL=0
SECRET_KEY="" # generate with $ openssl rand -hex 32
REFRESH_SECRET_KEY="" # generate with $ openssl rand -hex 32
ALGORITHM="HS256"
ACCESS_TOKEN_EXPIRE_MINUTES=30
REFRESH_TOKEN_EXPIRE_DAYS=30

View file

@ -10,6 +10,7 @@ from typing import AsyncGenerator, Dict, Annotated, Optional
from datetime import datetime, timedelta
from .settings import Settings
from .domain.user import User
from .domain.token import RefreshToken
from .db import get_session
@ -41,7 +42,26 @@ async def authenticate_user(
return user
def create_access_token(user: User):
async def verify_refresh_token(
session: AsyncSession, token: str
) -> AsyncGenerator[User, None]:
try:
payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
username: str = payload.get("sub")
if username is None:
return
except JWTError:
return
user = await User.get_by_username(session, username)
current_token = await RefreshToken.get_token(session, user.id)
if current_token is not None and current_token.token == token:
return user
def create_access_token(user: User) -> str:
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode = {
"sub": user.username,
@ -53,6 +73,18 @@ def create_access_token(user: User):
return encoded_jwt
async def create_refresh_token(session: AsyncSession, user: User) -> RefreshToken:
expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
to_encode = {
"sub": user.username,
"exp": expire,
}
encoded_jwt = jwt.encode(
to_encode, settings.REFRESH_SECRET_KEY, algorithm=settings.ALGORITHM
)
return await RefreshToken.create(session, token=encoded_jwt, user_id=user.id)
async def get_current_user(
session: AsyncSessionDependency, token: TokenDependency
) -> User:

View file

@ -3,15 +3,20 @@ from typing import AsyncIterator, Annotated
from fastapi import HTTPException, Depends
from fastapi.security import OAuth2PasswordRequestForm
from ..model.token import Token
from ..model.token import Token, RefreshTokenPayload
from ..domain.user import User as DBUser
from .base import BaseController, AsyncSession
from ..auth import authenticate_user, create_access_token
from ..auth import (
authenticate_user,
create_access_token,
create_refresh_token,
verify_refresh_token,
)
class CreateToken(BaseController):
async def call(self, content: OAuth2PasswordRequestForm) -> Token:
async with self.async_session() as session:
async with self.async_session.begin() as session:
user = await authenticate_user(session, content.username, content.password)
if user is None:
@ -19,9 +24,29 @@ class CreateToken(BaseController):
status_code=401, detail="Invalid username or password"
)
refresh_token = await create_refresh_token(session, user)
access_token = create_access_token(user)
return Token(
access_token=access_token,
refresh_token=refresh_token.token,
token_type="bearer",
)
class RefreshToken(BaseController):
async def call(self, content: RefreshTokenPayload) -> Token:
async with self.async_session.begin() as session:
user = await verify_refresh_token(session, content.refresh_token)
if user is None:
raise HTTPException(status_code=401, detail="Invalid token")
refresh_token = await create_refresh_token(session, user)
access_token = create_access_token(user)
return Token(
access_token=access_token,
refresh_token=refresh_token.token,
token_type="bearer",
)

View file

@ -4,3 +4,4 @@ from .entry import Entry
from .meal import Meal
from .product import Product
from .user import User
from .token import RefreshToken

59
fooder/domain/token.py Normal file
View file

@ -0,0 +1,59 @@
from sqlalchemy.orm import relationship, Mapped, mapped_column, joinedload
from sqlalchemy import ForeignKey, Integer, Date
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import date
from typing import Optional
from .base import Base, CommonMixin
from .meal import Meal
class RefreshToken(Base, CommonMixin):
"""Diary represents user diary for given day"""
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"), unique=True)
token: Mapped[str]
@classmethod
async def get_token(
cls,
session: AsyncSession,
user_id: int,
) -> "Optional[RefreshToken]":
"""get_token."""
query = select(cls).where(cls.user_id == user_id)
return await session.scalar(query)
@classmethod
async def create(
cls, session: AsyncSession, user_id: int, token: str
) -> "RefreshToken":
"""create.
:param session:
:type session: AsyncSession
:param user_id:
:type user_id: int
:param token:
:type token: str
:rtype: "RefreshToken"
"""
existing = await cls.get_token(session, user_id)
if existing:
existing.token = token
return existing
token = cls(
user_id=user_id,
token=token,
)
session.add(token)
try:
await session.flush()
except Exception:
raise AssertionError("invalid token")
return token

View file

@ -3,8 +3,13 @@ from pydantic import BaseModel
class Token(BaseModel):
access_token: str
token_type: str
refresh_token: str
token_type: str = "bearer"
class TokenData(BaseModel):
username: str | None = None
class RefreshTokenPayload(BaseModel):
refresh_token: str

View file

@ -9,7 +9,9 @@ class Settings(BaseSettings):
ECHO_SQL: bool
SECRET_KEY: str
REFRESH_SECRET_KEY: str
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
REFRESH_TOKEN_EXPIRE_DAYS: int = 30
ALLOWED_ORIGINS: List[str] = ["*"]

View file

@ -1,6 +1,6 @@
from fastapi import APIRouter, Depends, Request
from ..model.token import Token
from ..controller.token import CreateToken
from ..model.token import Token, RefreshTokenPayload
from ..controller.token import CreateToken, RefreshToken
from fastapi.security import OAuth2PasswordRequestForm
from typing import Annotated
@ -15,3 +15,12 @@ async def create_token(
controller: CreateToken = Depends(CreateToken),
):
return await controller.call(data)
@router.post("/refresh", response_model=Token)
async def refresh_token(
request: Request,
data: RefreshTokenPayload,
controller: RefreshToken = Depends(RefreshToken),
):
return await controller.call(data)