add refresh tokens
This commit is contained in:
		
							parent
							
								
									e537bef22e
								
							
						
					
					
						commit
						28d55d3632
					
				
					 8 changed files with 142 additions and 7 deletions
				
			
		| 
						 | 
					@ -7,5 +7,7 @@ DB_URI="postgresql+asyncpg://${POSTGRES_USER}:${POSTGRES_PASSWORD}@database/${PO
 | 
				
			||||||
ECHO_SQL=0
 | 
					ECHO_SQL=0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
SECRET_KEY="" # generate with $ openssl rand -hex 32
 | 
					SECRET_KEY="" # generate with $ openssl rand -hex 32
 | 
				
			||||||
 | 
					REFRESH_SECRET_KEY="" # generate with $ openssl rand -hex 32
 | 
				
			||||||
ALGORITHM="HS256"
 | 
					ALGORITHM="HS256"
 | 
				
			||||||
ACCESS_TOKEN_EXPIRE_MINUTES=30
 | 
					ACCESS_TOKEN_EXPIRE_MINUTES=30
 | 
				
			||||||
 | 
					REFRESH_TOKEN_EXPIRE_DAYS=30
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -10,6 +10,7 @@ from typing import AsyncGenerator, Dict, Annotated, Optional
 | 
				
			||||||
from datetime import datetime, timedelta
 | 
					from datetime import datetime, timedelta
 | 
				
			||||||
from .settings import Settings
 | 
					from .settings import Settings
 | 
				
			||||||
from .domain.user import User
 | 
					from .domain.user import User
 | 
				
			||||||
 | 
					from .domain.token import RefreshToken
 | 
				
			||||||
from .db import get_session
 | 
					from .db import get_session
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -41,7 +42,26 @@ async def authenticate_user(
 | 
				
			||||||
    return user
 | 
					    return user
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def create_access_token(user: User):
 | 
					async def verify_refresh_token(
 | 
				
			||||||
 | 
					    session: AsyncSession, token: str
 | 
				
			||||||
 | 
					) -> AsyncGenerator[User, None]:
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        payload = jwt.decode(
 | 
				
			||||||
 | 
					            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        username: str = payload.get("sub")
 | 
				
			||||||
 | 
					        if username is None:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					    except JWTError:
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    user = await User.get_by_username(session, username)
 | 
				
			||||||
 | 
					    current_token = await RefreshToken.get_token(session, user.id)
 | 
				
			||||||
 | 
					    if current_token is not None and current_token.token == token:
 | 
				
			||||||
 | 
					        return user
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def create_access_token(user: User) -> str:
 | 
				
			||||||
    expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
 | 
					    expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
 | 
				
			||||||
    to_encode = {
 | 
					    to_encode = {
 | 
				
			||||||
        "sub": user.username,
 | 
					        "sub": user.username,
 | 
				
			||||||
| 
						 | 
					@ -53,6 +73,18 @@ def create_access_token(user: User):
 | 
				
			||||||
    return encoded_jwt
 | 
					    return encoded_jwt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def create_refresh_token(session: AsyncSession, user: User) -> RefreshToken:
 | 
				
			||||||
 | 
					    expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
 | 
				
			||||||
 | 
					    to_encode = {
 | 
				
			||||||
 | 
					        "sub": user.username,
 | 
				
			||||||
 | 
					        "exp": expire,
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    encoded_jwt = jwt.encode(
 | 
				
			||||||
 | 
					        to_encode, settings.REFRESH_SECRET_KEY, algorithm=settings.ALGORITHM
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    return await RefreshToken.create(session, token=encoded_jwt, user_id=user.id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def get_current_user(
 | 
					async def get_current_user(
 | 
				
			||||||
    session: AsyncSessionDependency, token: TokenDependency
 | 
					    session: AsyncSessionDependency, token: TokenDependency
 | 
				
			||||||
) -> User:
 | 
					) -> User:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -3,15 +3,20 @@ from typing import AsyncIterator, Annotated
 | 
				
			||||||
from fastapi import HTTPException, Depends
 | 
					from fastapi import HTTPException, Depends
 | 
				
			||||||
from fastapi.security import OAuth2PasswordRequestForm
 | 
					from fastapi.security import OAuth2PasswordRequestForm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..model.token import Token
 | 
					from ..model.token import Token, RefreshTokenPayload
 | 
				
			||||||
from ..domain.user import User as DBUser
 | 
					from ..domain.user import User as DBUser
 | 
				
			||||||
from .base import BaseController, AsyncSession
 | 
					from .base import BaseController, AsyncSession
 | 
				
			||||||
from ..auth import authenticate_user, create_access_token
 | 
					from ..auth import (
 | 
				
			||||||
 | 
					    authenticate_user,
 | 
				
			||||||
 | 
					    create_access_token,
 | 
				
			||||||
 | 
					    create_refresh_token,
 | 
				
			||||||
 | 
					    verify_refresh_token,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class CreateToken(BaseController):
 | 
					class CreateToken(BaseController):
 | 
				
			||||||
    async def call(self, content: OAuth2PasswordRequestForm) -> Token:
 | 
					    async def call(self, content: OAuth2PasswordRequestForm) -> Token:
 | 
				
			||||||
        async with self.async_session() as session:
 | 
					        async with self.async_session.begin() as session:
 | 
				
			||||||
            user = await authenticate_user(session, content.username, content.password)
 | 
					            user = await authenticate_user(session, content.username, content.password)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if user is None:
 | 
					            if user is None:
 | 
				
			||||||
| 
						 | 
					@ -19,9 +24,29 @@ class CreateToken(BaseController):
 | 
				
			||||||
                    status_code=401, detail="Invalid username or password"
 | 
					                    status_code=401, detail="Invalid username or password"
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            refresh_token = await create_refresh_token(session, user)
 | 
				
			||||||
            access_token = create_access_token(user)
 | 
					            access_token = create_access_token(user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            return Token(
 | 
					            return Token(
 | 
				
			||||||
                access_token=access_token,
 | 
					                access_token=access_token,
 | 
				
			||||||
 | 
					                refresh_token=refresh_token.token,
 | 
				
			||||||
 | 
					                token_type="bearer",
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class RefreshToken(BaseController):
 | 
				
			||||||
 | 
					    async def call(self, content: RefreshTokenPayload) -> Token:
 | 
				
			||||||
 | 
					        async with self.async_session.begin() as session:
 | 
				
			||||||
 | 
					            user = await verify_refresh_token(session, content.refresh_token)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if user is None:
 | 
				
			||||||
 | 
					                raise HTTPException(status_code=401, detail="Invalid token")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            refresh_token = await create_refresh_token(session, user)
 | 
				
			||||||
 | 
					            access_token = create_access_token(user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            return Token(
 | 
				
			||||||
 | 
					                access_token=access_token,
 | 
				
			||||||
 | 
					                refresh_token=refresh_token.token,
 | 
				
			||||||
                token_type="bearer",
 | 
					                token_type="bearer",
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -4,3 +4,4 @@ from .entry import Entry
 | 
				
			||||||
from .meal import Meal
 | 
					from .meal import Meal
 | 
				
			||||||
from .product import Product
 | 
					from .product import Product
 | 
				
			||||||
from .user import User
 | 
					from .user import User
 | 
				
			||||||
 | 
					from .token import RefreshToken
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										59
									
								
								fooder/domain/token.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								fooder/domain/token.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,59 @@
 | 
				
			||||||
 | 
					from sqlalchemy.orm import relationship, Mapped, mapped_column, joinedload
 | 
				
			||||||
 | 
					from sqlalchemy import ForeignKey, Integer, Date
 | 
				
			||||||
 | 
					from sqlalchemy import select
 | 
				
			||||||
 | 
					from sqlalchemy.ext.asyncio import AsyncSession
 | 
				
			||||||
 | 
					from datetime import date
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .base import Base, CommonMixin
 | 
				
			||||||
 | 
					from .meal import Meal
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class RefreshToken(Base, CommonMixin):
 | 
				
			||||||
 | 
					    """Diary represents user diary for given day"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"), unique=True)
 | 
				
			||||||
 | 
					    token: Mapped[str]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    async def get_token(
 | 
				
			||||||
 | 
					        cls,
 | 
				
			||||||
 | 
					        session: AsyncSession,
 | 
				
			||||||
 | 
					        user_id: int,
 | 
				
			||||||
 | 
					    ) -> "Optional[RefreshToken]":
 | 
				
			||||||
 | 
					        """get_token."""
 | 
				
			||||||
 | 
					        query = select(cls).where(cls.user_id == user_id)
 | 
				
			||||||
 | 
					        return await session.scalar(query)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    async def create(
 | 
				
			||||||
 | 
					        cls, session: AsyncSession, user_id: int, token: str
 | 
				
			||||||
 | 
					    ) -> "RefreshToken":
 | 
				
			||||||
 | 
					        """create.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :param session:
 | 
				
			||||||
 | 
					        :type session: AsyncSession
 | 
				
			||||||
 | 
					        :param user_id:
 | 
				
			||||||
 | 
					        :type user_id: int
 | 
				
			||||||
 | 
					        :param token:
 | 
				
			||||||
 | 
					        :type token: str
 | 
				
			||||||
 | 
					        :rtype: "RefreshToken"
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        existing = await cls.get_token(session, user_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if existing:
 | 
				
			||||||
 | 
					            existing.token = token
 | 
				
			||||||
 | 
					            return existing
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        token = cls(
 | 
				
			||||||
 | 
					            user_id=user_id,
 | 
				
			||||||
 | 
					            token=token,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        session.add(token)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            await session.flush()
 | 
				
			||||||
 | 
					        except Exception:
 | 
				
			||||||
 | 
					            raise AssertionError("invalid token")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return token
 | 
				
			||||||
| 
						 | 
					@ -3,8 +3,13 @@ from pydantic import BaseModel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Token(BaseModel):
 | 
					class Token(BaseModel):
 | 
				
			||||||
    access_token: str
 | 
					    access_token: str
 | 
				
			||||||
    token_type: str
 | 
					    refresh_token: str
 | 
				
			||||||
 | 
					    token_type: str = "bearer"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TokenData(BaseModel):
 | 
					class TokenData(BaseModel):
 | 
				
			||||||
    username: str | None = None
 | 
					    username: str | None = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class RefreshTokenPayload(BaseModel):
 | 
				
			||||||
 | 
					    refresh_token: str
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -9,7 +9,9 @@ class Settings(BaseSettings):
 | 
				
			||||||
    ECHO_SQL: bool
 | 
					    ECHO_SQL: bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    SECRET_KEY: str
 | 
					    SECRET_KEY: str
 | 
				
			||||||
 | 
					    REFRESH_SECRET_KEY: str
 | 
				
			||||||
    ALGORITHM: str = "HS256"
 | 
					    ALGORITHM: str = "HS256"
 | 
				
			||||||
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
 | 
					    ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
 | 
				
			||||||
 | 
					    REFRESH_TOKEN_EXPIRE_DAYS: int = 30
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    ALLOWED_ORIGINS: List[str] = ["*"]
 | 
					    ALLOWED_ORIGINS: List[str] = ["*"]
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,6 +1,6 @@
 | 
				
			||||||
from fastapi import APIRouter, Depends, Request
 | 
					from fastapi import APIRouter, Depends, Request
 | 
				
			||||||
from ..model.token import Token
 | 
					from ..model.token import Token, RefreshTokenPayload
 | 
				
			||||||
from ..controller.token import CreateToken
 | 
					from ..controller.token import CreateToken, RefreshToken
 | 
				
			||||||
from fastapi.security import OAuth2PasswordRequestForm
 | 
					from fastapi.security import OAuth2PasswordRequestForm
 | 
				
			||||||
from typing import Annotated
 | 
					from typing import Annotated
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -15,3 +15,12 @@ async def create_token(
 | 
				
			||||||
    controller: CreateToken = Depends(CreateToken),
 | 
					    controller: CreateToken = Depends(CreateToken),
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    return await controller.call(data)
 | 
					    return await controller.call(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@router.post("/refresh", response_model=Token)
 | 
				
			||||||
 | 
					async def refresh_token(
 | 
				
			||||||
 | 
					    request: Request,
 | 
				
			||||||
 | 
					    data: RefreshTokenPayload,
 | 
				
			||||||
 | 
					    controller: RefreshToken = Depends(RefreshToken),
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    return await controller.call(data)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue