add refresh tokens
This commit is contained in:
parent
e537bef22e
commit
28d55d3632
8 changed files with 142 additions and 7 deletions
|
@ -7,5 +7,7 @@ DB_URI="postgresql+asyncpg://${POSTGRES_USER}:${POSTGRES_PASSWORD}@database/${PO
|
|||
ECHO_SQL=0
|
||||
|
||||
SECRET_KEY="" # generate with $ openssl rand -hex 32
|
||||
REFRESH_SECRET_KEY="" # generate with $ openssl rand -hex 32
|
||||
ALGORITHM="HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||
REFRESH_TOKEN_EXPIRE_DAYS=30
|
||||
|
|
|
@ -10,6 +10,7 @@ from typing import AsyncGenerator, Dict, Annotated, Optional
|
|||
from datetime import datetime, timedelta
|
||||
from .settings import Settings
|
||||
from .domain.user import User
|
||||
from .domain.token import RefreshToken
|
||||
from .db import get_session
|
||||
|
||||
|
||||
|
@ -41,7 +42,26 @@ async def authenticate_user(
|
|||
return user
|
||||
|
||||
|
||||
def create_access_token(user: User):
|
||||
async def verify_refresh_token(
|
||||
session: AsyncSession, token: str
|
||||
) -> AsyncGenerator[User, None]:
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
return
|
||||
except JWTError:
|
||||
return
|
||||
|
||||
user = await User.get_by_username(session, username)
|
||||
current_token = await RefreshToken.get_token(session, user.id)
|
||||
if current_token is not None and current_token.token == token:
|
||||
return user
|
||||
|
||||
|
||||
def create_access_token(user: User) -> str:
|
||||
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
to_encode = {
|
||||
"sub": user.username,
|
||||
|
@ -53,6 +73,18 @@ def create_access_token(user: User):
|
|||
return encoded_jwt
|
||||
|
||||
|
||||
async def create_refresh_token(session: AsyncSession, user: User) -> RefreshToken:
|
||||
expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
to_encode = {
|
||||
"sub": user.username,
|
||||
"exp": expire,
|
||||
}
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, settings.REFRESH_SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
)
|
||||
return await RefreshToken.create(session, token=encoded_jwt, user_id=user.id)
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
session: AsyncSessionDependency, token: TokenDependency
|
||||
) -> User:
|
||||
|
|
|
@ -3,15 +3,20 @@ from typing import AsyncIterator, Annotated
|
|||
from fastapi import HTTPException, Depends
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from ..model.token import Token
|
||||
from ..model.token import Token, RefreshTokenPayload
|
||||
from ..domain.user import User as DBUser
|
||||
from .base import BaseController, AsyncSession
|
||||
from ..auth import authenticate_user, create_access_token
|
||||
from ..auth import (
|
||||
authenticate_user,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
verify_refresh_token,
|
||||
)
|
||||
|
||||
|
||||
class CreateToken(BaseController):
|
||||
async def call(self, content: OAuth2PasswordRequestForm) -> Token:
|
||||
async with self.async_session() as session:
|
||||
async with self.async_session.begin() as session:
|
||||
user = await authenticate_user(session, content.username, content.password)
|
||||
|
||||
if user is None:
|
||||
|
@ -19,9 +24,29 @@ class CreateToken(BaseController):
|
|||
status_code=401, detail="Invalid username or password"
|
||||
)
|
||||
|
||||
refresh_token = await create_refresh_token(session, user)
|
||||
access_token = create_access_token(user)
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token.token,
|
||||
token_type="bearer",
|
||||
)
|
||||
|
||||
|
||||
class RefreshToken(BaseController):
|
||||
async def call(self, content: RefreshTokenPayload) -> Token:
|
||||
async with self.async_session.begin() as session:
|
||||
user = await verify_refresh_token(session, content.refresh_token)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
||||
refresh_token = await create_refresh_token(session, user)
|
||||
access_token = create_access_token(user)
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token.token,
|
||||
token_type="bearer",
|
||||
)
|
||||
|
|
|
@ -4,3 +4,4 @@ from .entry import Entry
|
|||
from .meal import Meal
|
||||
from .product import Product
|
||||
from .user import User
|
||||
from .token import RefreshToken
|
||||
|
|
59
fooder/domain/token.py
Normal file
59
fooder/domain/token.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
from sqlalchemy.orm import relationship, Mapped, mapped_column, joinedload
|
||||
from sqlalchemy import ForeignKey, Integer, Date
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from datetime import date
|
||||
from typing import Optional
|
||||
|
||||
from .base import Base, CommonMixin
|
||||
from .meal import Meal
|
||||
|
||||
|
||||
class RefreshToken(Base, CommonMixin):
|
||||
"""Diary represents user diary for given day"""
|
||||
|
||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"), unique=True)
|
||||
token: Mapped[str]
|
||||
|
||||
@classmethod
|
||||
async def get_token(
|
||||
cls,
|
||||
session: AsyncSession,
|
||||
user_id: int,
|
||||
) -> "Optional[RefreshToken]":
|
||||
"""get_token."""
|
||||
query = select(cls).where(cls.user_id == user_id)
|
||||
return await session.scalar(query)
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls, session: AsyncSession, user_id: int, token: str
|
||||
) -> "RefreshToken":
|
||||
"""create.
|
||||
|
||||
:param session:
|
||||
:type session: AsyncSession
|
||||
:param user_id:
|
||||
:type user_id: int
|
||||
:param token:
|
||||
:type token: str
|
||||
:rtype: "RefreshToken"
|
||||
"""
|
||||
existing = await cls.get_token(session, user_id)
|
||||
|
||||
if existing:
|
||||
existing.token = token
|
||||
return existing
|
||||
|
||||
token = cls(
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
)
|
||||
session.add(token)
|
||||
|
||||
try:
|
||||
await session.flush()
|
||||
except Exception:
|
||||
raise AssertionError("invalid token")
|
||||
|
||||
return token
|
|
@ -3,8 +3,13 @@ from pydantic import BaseModel
|
|||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
username: str | None = None
|
||||
|
||||
|
||||
class RefreshTokenPayload(BaseModel):
|
||||
refresh_token: str
|
||||
|
|
|
@ -9,7 +9,9 @@ class Settings(BaseSettings):
|
|||
ECHO_SQL: bool
|
||||
|
||||
SECRET_KEY: str
|
||||
REFRESH_SECRET_KEY: str
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 30
|
||||
|
||||
ALLOWED_ORIGINS: List[str] = ["*"]
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from fastapi import APIRouter, Depends, Request
|
||||
from ..model.token import Token
|
||||
from ..controller.token import CreateToken
|
||||
from ..model.token import Token, RefreshTokenPayload
|
||||
from ..controller.token import CreateToken, RefreshToken
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from typing import Annotated
|
||||
|
||||
|
@ -15,3 +15,12 @@ async def create_token(
|
|||
controller: CreateToken = Depends(CreateToken),
|
||||
):
|
||||
return await controller.call(data)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=Token)
|
||||
async def refresh_token(
|
||||
request: Request,
|
||||
data: RefreshTokenPayload,
|
||||
controller: RefreshToken = Depends(RefreshToken),
|
||||
):
|
||||
return await controller.call(data)
|
||||
|
|
Loading…
Reference in a new issue