diff --git a/env.template b/env.template index 9dc6843..749f54f 100644 --- a/env.template +++ b/env.template @@ -7,5 +7,7 @@ DB_URI="postgresql+asyncpg://${POSTGRES_USER}:${POSTGRES_PASSWORD}@database/${PO ECHO_SQL=0 SECRET_KEY="" # generate with $ openssl rand -hex 32 +REFRESH_SECRET_KEY="" # generate with $ openssl rand -hex 32 ALGORITHM="HS256" ACCESS_TOKEN_EXPIRE_MINUTES=30 +REFRESH_TOKEN_EXPIRE_DAYS=30 diff --git a/fooder/auth.py b/fooder/auth.py index 1e12b0e..7c8641b 100644 --- a/fooder/auth.py +++ b/fooder/auth.py @@ -10,6 +10,7 @@ from typing import AsyncGenerator, Dict, Annotated, Optional from datetime import datetime, timedelta from .settings import Settings from .domain.user import User +from .domain.token import RefreshToken from .db import get_session @@ -41,7 +42,26 @@ async def authenticate_user( return user -def create_access_token(user: User): +async def verify_refresh_token( + session: AsyncSession, token: str +) -> AsyncGenerator[User, None]: + try: + payload = jwt.decode( + token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] + ) + username: str = payload.get("sub") + if username is None: + return + except JWTError: + return + + user = await User.get_by_username(session, username) + current_token = await RefreshToken.get_token(session, user.id) + if current_token is not None and current_token.token == token: + return user + + +def create_access_token(user: User) -> str: expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) to_encode = { "sub": user.username, @@ -53,6 +73,18 @@ def create_access_token(user: User): return encoded_jwt +async def create_refresh_token(session: AsyncSession, user: User) -> RefreshToken: + expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) + to_encode = { + "sub": user.username, + "exp": expire, + } + encoded_jwt = jwt.encode( + to_encode, settings.REFRESH_SECRET_KEY, algorithm=settings.ALGORITHM + ) + return await RefreshToken.create(session, token=encoded_jwt, user_id=user.id) + + async def get_current_user( session: AsyncSessionDependency, token: TokenDependency ) -> User: diff --git a/fooder/controller/token.py b/fooder/controller/token.py index d08a28f..6f8ce93 100644 --- a/fooder/controller/token.py +++ b/fooder/controller/token.py @@ -3,15 +3,20 @@ from typing import AsyncIterator, Annotated from fastapi import HTTPException, Depends from fastapi.security import OAuth2PasswordRequestForm -from ..model.token import Token +from ..model.token import Token, RefreshTokenPayload from ..domain.user import User as DBUser from .base import BaseController, AsyncSession -from ..auth import authenticate_user, create_access_token +from ..auth import ( + authenticate_user, + create_access_token, + create_refresh_token, + verify_refresh_token, +) class CreateToken(BaseController): async def call(self, content: OAuth2PasswordRequestForm) -> Token: - async with self.async_session() as session: + async with self.async_session.begin() as session: user = await authenticate_user(session, content.username, content.password) if user is None: @@ -19,9 +24,29 @@ class CreateToken(BaseController): status_code=401, detail="Invalid username or password" ) + refresh_token = await create_refresh_token(session, user) access_token = create_access_token(user) return Token( access_token=access_token, + refresh_token=refresh_token.token, + token_type="bearer", + ) + + +class RefreshToken(BaseController): + async def call(self, content: RefreshTokenPayload) -> Token: + async with self.async_session.begin() as session: + user = await verify_refresh_token(session, content.refresh_token) + + if user is None: + raise HTTPException(status_code=401, detail="Invalid token") + + refresh_token = await create_refresh_token(session, user) + access_token = create_access_token(user) + + return Token( + access_token=access_token, + refresh_token=refresh_token.token, token_type="bearer", ) diff --git a/fooder/domain/__init__.py b/fooder/domain/__init__.py index 08b4fe3..c9e10a6 100644 --- a/fooder/domain/__init__.py +++ b/fooder/domain/__init__.py @@ -4,3 +4,4 @@ from .entry import Entry from .meal import Meal from .product import Product from .user import User +from .token import RefreshToken diff --git a/fooder/domain/token.py b/fooder/domain/token.py new file mode 100644 index 0000000..d93c008 --- /dev/null +++ b/fooder/domain/token.py @@ -0,0 +1,59 @@ +from sqlalchemy.orm import relationship, Mapped, mapped_column, joinedload +from sqlalchemy import ForeignKey, Integer, Date +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from datetime import date +from typing import Optional + +from .base import Base, CommonMixin +from .meal import Meal + + +class RefreshToken(Base, CommonMixin): + """Diary represents user diary for given day""" + + user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"), unique=True) + token: Mapped[str] + + @classmethod + async def get_token( + cls, + session: AsyncSession, + user_id: int, + ) -> "Optional[RefreshToken]": + """get_token.""" + query = select(cls).where(cls.user_id == user_id) + return await session.scalar(query) + + @classmethod + async def create( + cls, session: AsyncSession, user_id: int, token: str + ) -> "RefreshToken": + """create. + + :param session: + :type session: AsyncSession + :param user_id: + :type user_id: int + :param token: + :type token: str + :rtype: "RefreshToken" + """ + existing = await cls.get_token(session, user_id) + + if existing: + existing.token = token + return existing + + token = cls( + user_id=user_id, + token=token, + ) + session.add(token) + + try: + await session.flush() + except Exception: + raise AssertionError("invalid token") + + return token diff --git a/fooder/model/token.py b/fooder/model/token.py index 0802867..3de026e 100644 --- a/fooder/model/token.py +++ b/fooder/model/token.py @@ -3,8 +3,13 @@ from pydantic import BaseModel class Token(BaseModel): access_token: str - token_type: str + refresh_token: str + token_type: str = "bearer" class TokenData(BaseModel): username: str | None = None + + +class RefreshTokenPayload(BaseModel): + refresh_token: str diff --git a/fooder/settings.py b/fooder/settings.py index 7d004bc..89b5688 100644 --- a/fooder/settings.py +++ b/fooder/settings.py @@ -9,7 +9,9 @@ class Settings(BaseSettings): ECHO_SQL: bool SECRET_KEY: str + REFRESH_SECRET_KEY: str ALGORITHM: str = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 + REFRESH_TOKEN_EXPIRE_DAYS: int = 30 ALLOWED_ORIGINS: List[str] = ["*"] diff --git a/fooder/view/token.py b/fooder/view/token.py index 6f9f45a..5e7ab77 100644 --- a/fooder/view/token.py +++ b/fooder/view/token.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, Depends, Request -from ..model.token import Token -from ..controller.token import CreateToken +from ..model.token import Token, RefreshTokenPayload +from ..controller.token import CreateToken, RefreshToken from fastapi.security import OAuth2PasswordRequestForm from typing import Annotated @@ -15,3 +15,12 @@ async def create_token( controller: CreateToken = Depends(CreateToken), ): return await controller.call(data) + + +@router.post("/refresh", response_model=Token) +async def refresh_token( + request: Request, + data: RefreshTokenPayload, + controller: RefreshToken = Depends(RefreshToken), +): + return await controller.call(data)