add refresh tokens

This commit is contained in:
Piotr Domański 2023-04-02 15:20:53 +02:00
parent 3b3e709e56
commit e38949e03a
4 changed files with 79 additions and 26 deletions

View file

@ -44,7 +44,7 @@ async def authenticate_user(
async def verify_refresh_token( async def verify_refresh_token(
session: AsyncSession, token: str session: AsyncSession, token: str
) -> AsyncGenerator[User, None]: ) -> AsyncGenerator[RefreshToken, None]:
try: try:
payload = jwt.decode( payload = jwt.decode(
token, settings.REFRESH_SECRET_KEY, algorithms=[settings.ALGORITHM] token, settings.REFRESH_SECRET_KEY, algorithms=[settings.ALGORITHM]
@ -56,9 +56,11 @@ async def verify_refresh_token(
return return
user = await User.get_by_username(session, username) user = await User.get_by_username(session, username)
current_token = await RefreshToken.get_token(session, user.id) if user is None:
if current_token is not None and current_token.token == token: return
return user current_token = await RefreshToken.get_token(session, user.id, token)
if current_token is not None:
return current_token
def create_access_token(user: User) -> str: def create_access_token(user: User) -> str:

View file

@ -37,11 +37,14 @@ class CreateToken(BaseController):
class RefreshToken(BaseController): class RefreshToken(BaseController):
async def call(self, content: RefreshTokenPayload) -> Token: async def call(self, content: RefreshTokenPayload) -> Token:
async with self.async_session.begin() as session: async with self.async_session.begin() as session:
user = await verify_refresh_token(session, content.refresh_token) current_token = await verify_refresh_token(session, content.refresh_token)
if user is None: if current_token is None:
raise HTTPException(status_code=401, detail="Invalid token") raise HTTPException(status_code=401, detail="Invalid token")
user = await DBUser.get(session, current_token.user_id)
await current_token.delete(session)
refresh_token = await create_refresh_token(session, user) refresh_token = await create_refresh_token(session, user)
access_token = create_access_token(user) access_token = create_access_token(user)

View file

@ -1,18 +1,16 @@
from sqlalchemy.orm import relationship, Mapped, mapped_column, joinedload from sqlalchemy.orm import relationship, Mapped, mapped_column, joinedload, relationship
from sqlalchemy import ForeignKey, Integer, Date from sqlalchemy import ForeignKey, Integer
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from datetime import date
from typing import Optional from typing import Optional
from .base import Base, CommonMixin from .base import Base, CommonMixin
from .meal import Meal
class RefreshToken(Base, CommonMixin): class RefreshToken(Base, CommonMixin):
"""Diary represents user diary for given day""" """Diary represents user diary for given day"""
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"), unique=True) user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
token: Mapped[str] token: Mapped[str]
@classmethod @classmethod
@ -20,9 +18,19 @@ class RefreshToken(Base, CommonMixin):
cls, cls,
session: AsyncSession, session: AsyncSession,
user_id: int, user_id: int,
token: str,
) -> "Optional[RefreshToken]": ) -> "Optional[RefreshToken]":
"""get_token.""" """get_token.
query = select(cls).where(cls.user_id == user_id)
:param session:
:type session: AsyncSession
:param user_id:
:type user_id: int
:param token:
:type token: str
:rtype: "Optional[RefreshToken]"
"""
query = select(cls).where(cls.user_id == user_id).where(cls.token == token)
return await session.scalar(query) return await session.scalar(query)
@classmethod @classmethod
@ -39,12 +47,6 @@ class RefreshToken(Base, CommonMixin):
:type token: str :type token: str
:rtype: "RefreshToken" :rtype: "RefreshToken"
""" """
existing = await cls.get_token(session, user_id)
if existing:
existing.token = token
return existing
token = cls( token = cls(
user_id=user_id, user_id=user_id,
token=token, token=token,
@ -57,3 +59,13 @@ class RefreshToken(Base, CommonMixin):
raise AssertionError("invalid token") raise AssertionError("invalid token")
return token return token
async def delete(self, session: AsyncSession) -> None:
"""delete.
:param session:
:type session: AsyncSession
:rtype: None
"""
await session.delete(self)
await session.flush()

View file

@ -12,22 +12,58 @@ class User(Base, CommonMixin):
username: Mapped[str] username: Mapped[str]
hashed_password: Mapped[str] hashed_password: Mapped[str]
@classmethod
async def get_by_username(
cls, session: AsyncSession, username: str
) -> Optional["User"]:
query = select(cls).filter(cls.username == username)
return await session.scalar(query.order_by(cls.id))
def set_password(self, password) -> None: def set_password(self, password) -> None:
"""set_password.
:param password:
:rtype: None
"""
from ..auth import password_helper from ..auth import password_helper
self.hashed_password = password_helper.hash(password) self.hashed_password = password_helper.hash(password)
@classmethod
async def get_by_username(
cls, session: AsyncSession, username: str
) -> Optional["User"]:
"""get_by_username.
:param session:
:type session: AsyncSession
:param username:
:type username: str
:rtype: Optional["User"]
"""
query = select(cls).filter(cls.username == username)
return await session.scalar(query.order_by(cls.id))
@classmethod
async def get(cls, session: AsyncSession, id: int) -> Optional["User"]:
"""get_by_username.
:param session:
:type session: AsyncSession
:param id:
:type id: int
:rtype: Optional["User"]
"""
query = select(cls).filter(cls.id == id)
return await session.scalar(query.order_by(cls.id))
@classmethod @classmethod
async def create( async def create(
cls, session: AsyncSession, username: str, password: str cls, session: AsyncSession, username: str, password: str
) -> "User": ) -> "User":
"""create.
:param session:
:type session: AsyncSession
:param username:
:type username: str
:param password:
:type password: str
:rtype: "User"
"""
exsisting_user = await User.get_by_username(session, username) exsisting_user = await User.get_by_username(session, username)
assert exsisting_user is None, "user already exists" assert exsisting_user is None, "user already exists"
user = cls(username=username) user = cls(username=username)