add refresh tokens
This commit is contained in:
parent
e537bef22e
commit
28d55d3632
8 changed files with 142 additions and 7 deletions
|
@ -7,5 +7,7 @@ DB_URI="postgresql+asyncpg://${POSTGRES_USER}:${POSTGRES_PASSWORD}@database/${PO
|
||||||
ECHO_SQL=0
|
ECHO_SQL=0
|
||||||
|
|
||||||
SECRET_KEY="" # generate with $ openssl rand -hex 32
|
SECRET_KEY="" # generate with $ openssl rand -hex 32
|
||||||
|
REFRESH_SECRET_KEY="" # generate with $ openssl rand -hex 32
|
||||||
ALGORITHM="HS256"
|
ALGORITHM="HS256"
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES=30
|
ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||||
|
REFRESH_TOKEN_EXPIRE_DAYS=30
|
||||||
|
|
|
@ -10,6 +10,7 @@ from typing import AsyncGenerator, Dict, Annotated, Optional
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from .settings import Settings
|
from .settings import Settings
|
||||||
from .domain.user import User
|
from .domain.user import User
|
||||||
|
from .domain.token import RefreshToken
|
||||||
from .db import get_session
|
from .db import get_session
|
||||||
|
|
||||||
|
|
||||||
|
@ -41,7 +42,26 @@ async def authenticate_user(
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
def create_access_token(user: User):
|
async def verify_refresh_token(
|
||||||
|
session: AsyncSession, token: str
|
||||||
|
) -> AsyncGenerator[User, None]:
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||||
|
)
|
||||||
|
username: str = payload.get("sub")
|
||||||
|
if username is None:
|
||||||
|
return
|
||||||
|
except JWTError:
|
||||||
|
return
|
||||||
|
|
||||||
|
user = await User.get_by_username(session, username)
|
||||||
|
current_token = await RefreshToken.get_token(session, user.id)
|
||||||
|
if current_token is not None and current_token.token == token:
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def create_access_token(user: User) -> str:
|
||||||
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||||
to_encode = {
|
to_encode = {
|
||||||
"sub": user.username,
|
"sub": user.username,
|
||||||
|
@ -53,6 +73,18 @@ def create_access_token(user: User):
|
||||||
return encoded_jwt
|
return encoded_jwt
|
||||||
|
|
||||||
|
|
||||||
|
async def create_refresh_token(session: AsyncSession, user: User) -> RefreshToken:
|
||||||
|
expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||||
|
to_encode = {
|
||||||
|
"sub": user.username,
|
||||||
|
"exp": expire,
|
||||||
|
}
|
||||||
|
encoded_jwt = jwt.encode(
|
||||||
|
to_encode, settings.REFRESH_SECRET_KEY, algorithm=settings.ALGORITHM
|
||||||
|
)
|
||||||
|
return await RefreshToken.create(session, token=encoded_jwt, user_id=user.id)
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
session: AsyncSessionDependency, token: TokenDependency
|
session: AsyncSessionDependency, token: TokenDependency
|
||||||
) -> User:
|
) -> User:
|
||||||
|
|
|
@ -3,15 +3,20 @@ from typing import AsyncIterator, Annotated
|
||||||
from fastapi import HTTPException, Depends
|
from fastapi import HTTPException, Depends
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
|
||||||
from ..model.token import Token
|
from ..model.token import Token, RefreshTokenPayload
|
||||||
from ..domain.user import User as DBUser
|
from ..domain.user import User as DBUser
|
||||||
from .base import BaseController, AsyncSession
|
from .base import BaseController, AsyncSession
|
||||||
from ..auth import authenticate_user, create_access_token
|
from ..auth import (
|
||||||
|
authenticate_user,
|
||||||
|
create_access_token,
|
||||||
|
create_refresh_token,
|
||||||
|
verify_refresh_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CreateToken(BaseController):
|
class CreateToken(BaseController):
|
||||||
async def call(self, content: OAuth2PasswordRequestForm) -> Token:
|
async def call(self, content: OAuth2PasswordRequestForm) -> Token:
|
||||||
async with self.async_session() as session:
|
async with self.async_session.begin() as session:
|
||||||
user = await authenticate_user(session, content.username, content.password)
|
user = await authenticate_user(session, content.username, content.password)
|
||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
|
@ -19,9 +24,29 @@ class CreateToken(BaseController):
|
||||||
status_code=401, detail="Invalid username or password"
|
status_code=401, detail="Invalid username or password"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
refresh_token = await create_refresh_token(session, user)
|
||||||
access_token = create_access_token(user)
|
access_token = create_access_token(user)
|
||||||
|
|
||||||
return Token(
|
return Token(
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
|
refresh_token=refresh_token.token,
|
||||||
|
token_type="bearer",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshToken(BaseController):
|
||||||
|
async def call(self, content: RefreshTokenPayload) -> Token:
|
||||||
|
async with self.async_session.begin() as session:
|
||||||
|
user = await verify_refresh_token(session, content.refresh_token)
|
||||||
|
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid token")
|
||||||
|
|
||||||
|
refresh_token = await create_refresh_token(session, user)
|
||||||
|
access_token = create_access_token(user)
|
||||||
|
|
||||||
|
return Token(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=refresh_token.token,
|
||||||
token_type="bearer",
|
token_type="bearer",
|
||||||
)
|
)
|
||||||
|
|
|
@ -4,3 +4,4 @@ from .entry import Entry
|
||||||
from .meal import Meal
|
from .meal import Meal
|
||||||
from .product import Product
|
from .product import Product
|
||||||
from .user import User
|
from .user import User
|
||||||
|
from .token import RefreshToken
|
||||||
|
|
59
fooder/domain/token.py
Normal file
59
fooder/domain/token.py
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
from sqlalchemy.orm import relationship, Mapped, mapped_column, joinedload
|
||||||
|
from sqlalchemy import ForeignKey, Integer, Date
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from datetime import date
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from .base import Base, CommonMixin
|
||||||
|
from .meal import Meal
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshToken(Base, CommonMixin):
|
||||||
|
"""Diary represents user diary for given day"""
|
||||||
|
|
||||||
|
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"), unique=True)
|
||||||
|
token: Mapped[str]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_token(
|
||||||
|
cls,
|
||||||
|
session: AsyncSession,
|
||||||
|
user_id: int,
|
||||||
|
) -> "Optional[RefreshToken]":
|
||||||
|
"""get_token."""
|
||||||
|
query = select(cls).where(cls.user_id == user_id)
|
||||||
|
return await session.scalar(query)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create(
|
||||||
|
cls, session: AsyncSession, user_id: int, token: str
|
||||||
|
) -> "RefreshToken":
|
||||||
|
"""create.
|
||||||
|
|
||||||
|
:param session:
|
||||||
|
:type session: AsyncSession
|
||||||
|
:param user_id:
|
||||||
|
:type user_id: int
|
||||||
|
:param token:
|
||||||
|
:type token: str
|
||||||
|
:rtype: "RefreshToken"
|
||||||
|
"""
|
||||||
|
existing = await cls.get_token(session, user_id)
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
existing.token = token
|
||||||
|
return existing
|
||||||
|
|
||||||
|
token = cls(
|
||||||
|
user_id=user_id,
|
||||||
|
token=token,
|
||||||
|
)
|
||||||
|
session.add(token)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await session.flush()
|
||||||
|
except Exception:
|
||||||
|
raise AssertionError("invalid token")
|
||||||
|
|
||||||
|
return token
|
|
@ -3,8 +3,13 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
class Token(BaseModel):
|
class Token(BaseModel):
|
||||||
access_token: str
|
access_token: str
|
||||||
token_type: str
|
refresh_token: str
|
||||||
|
token_type: str = "bearer"
|
||||||
|
|
||||||
|
|
||||||
class TokenData(BaseModel):
|
class TokenData(BaseModel):
|
||||||
username: str | None = None
|
username: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshTokenPayload(BaseModel):
|
||||||
|
refresh_token: str
|
||||||
|
|
|
@ -9,7 +9,9 @@ class Settings(BaseSettings):
|
||||||
ECHO_SQL: bool
|
ECHO_SQL: bool
|
||||||
|
|
||||||
SECRET_KEY: str
|
SECRET_KEY: str
|
||||||
|
REFRESH_SECRET_KEY: str
|
||||||
ALGORITHM: str = "HS256"
|
ALGORITHM: str = "HS256"
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||||
|
REFRESH_TOKEN_EXPIRE_DAYS: int = 30
|
||||||
|
|
||||||
ALLOWED_ORIGINS: List[str] = ["*"]
|
ALLOWED_ORIGINS: List[str] = ["*"]
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from fastapi import APIRouter, Depends, Request
|
from fastapi import APIRouter, Depends, Request
|
||||||
from ..model.token import Token
|
from ..model.token import Token, RefreshTokenPayload
|
||||||
from ..controller.token import CreateToken
|
from ..controller.token import CreateToken, RefreshToken
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
|
@ -15,3 +15,12 @@ async def create_token(
|
||||||
controller: CreateToken = Depends(CreateToken),
|
controller: CreateToken = Depends(CreateToken),
|
||||||
):
|
):
|
||||||
return await controller.call(data)
|
return await controller.call(data)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refresh", response_model=Token)
|
||||||
|
async def refresh_token(
|
||||||
|
request: Request,
|
||||||
|
data: RefreshTokenPayload,
|
||||||
|
controller: RefreshToken = Depends(RefreshToken),
|
||||||
|
):
|
||||||
|
return await controller.call(data)
|
||||||
|
|
Loading…
Reference in a new issue