add refresh tokens

This commit is contained in:
Piotr Domański 2023-04-02 15:20:53 +02:00
parent 3b3e709e56
commit e38949e03a
4 changed files with 79 additions and 26 deletions

View file

@ -44,7 +44,7 @@ async def authenticate_user(
async def verify_refresh_token(
session: AsyncSession, token: str
) -> AsyncGenerator[User, None]:
) -> AsyncGenerator[RefreshToken, None]:
try:
payload = jwt.decode(
token, settings.REFRESH_SECRET_KEY, algorithms=[settings.ALGORITHM]
@ -56,9 +56,11 @@ async def verify_refresh_token(
return
user = await User.get_by_username(session, username)
current_token = await RefreshToken.get_token(session, user.id)
if current_token is not None and current_token.token == token:
return user
if user is None:
return
current_token = await RefreshToken.get_token(session, user.id, token)
if current_token is not None:
return current_token
def create_access_token(user: User) -> str:

View file

@ -37,11 +37,14 @@ class CreateToken(BaseController):
class RefreshToken(BaseController):
async def call(self, content: RefreshTokenPayload) -> Token:
async with self.async_session.begin() as session:
user = await verify_refresh_token(session, content.refresh_token)
current_token = await verify_refresh_token(session, content.refresh_token)
if user is None:
if current_token is None:
raise HTTPException(status_code=401, detail="Invalid token")
user = await DBUser.get(session, current_token.user_id)
await current_token.delete(session)
refresh_token = await create_refresh_token(session, user)
access_token = create_access_token(user)

View file

@ -1,18 +1,16 @@
from sqlalchemy.orm import relationship, Mapped, mapped_column, joinedload
from sqlalchemy import ForeignKey, Integer, Date
from sqlalchemy.orm import relationship, Mapped, mapped_column, joinedload, relationship
from sqlalchemy import ForeignKey, Integer
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import date
from typing import Optional
from .base import Base, CommonMixin
from .meal import Meal
class RefreshToken(Base, CommonMixin):
"""Diary represents user diary for given day"""
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"), unique=True)
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
token: Mapped[str]
@classmethod
@ -20,9 +18,19 @@ class RefreshToken(Base, CommonMixin):
cls,
session: AsyncSession,
user_id: int,
token: str,
) -> "Optional[RefreshToken]":
"""get_token."""
query = select(cls).where(cls.user_id == user_id)
"""get_token.
:param session:
:type session: AsyncSession
:param user_id:
:type user_id: int
:param token:
:type token: str
:rtype: "Optional[RefreshToken]"
"""
query = select(cls).where(cls.user_id == user_id).where(cls.token == token)
return await session.scalar(query)
@classmethod
@ -39,12 +47,6 @@ class RefreshToken(Base, CommonMixin):
:type token: str
:rtype: "RefreshToken"
"""
existing = await cls.get_token(session, user_id)
if existing:
existing.token = token
return existing
token = cls(
user_id=user_id,
token=token,
@ -57,3 +59,13 @@ class RefreshToken(Base, CommonMixin):
raise AssertionError("invalid token")
return token
async def delete(self, session: AsyncSession) -> None:
"""delete.
:param session:
:type session: AsyncSession
:rtype: None
"""
await session.delete(self)
await session.flush()

View file

@ -12,22 +12,58 @@ class User(Base, CommonMixin):
username: Mapped[str]
hashed_password: Mapped[str]
@classmethod
async def get_by_username(
cls, session: AsyncSession, username: str
) -> Optional["User"]:
query = select(cls).filter(cls.username == username)
return await session.scalar(query.order_by(cls.id))
def set_password(self, password) -> None:
"""set_password.
:param password:
:rtype: None
"""
from ..auth import password_helper
self.hashed_password = password_helper.hash(password)
@classmethod
async def get_by_username(
cls, session: AsyncSession, username: str
) -> Optional["User"]:
"""get_by_username.
:param session:
:type session: AsyncSession
:param username:
:type username: str
:rtype: Optional["User"]
"""
query = select(cls).filter(cls.username == username)
return await session.scalar(query.order_by(cls.id))
@classmethod
async def get(cls, session: AsyncSession, id: int) -> Optional["User"]:
"""get_by_username.
:param session:
:type session: AsyncSession
:param id:
:type id: int
:rtype: Optional["User"]
"""
query = select(cls).filter(cls.id == id)
return await session.scalar(query.order_by(cls.id))
@classmethod
async def create(
cls, session: AsyncSession, username: str, password: str
) -> "User":
"""create.
:param session:
:type session: AsyncSession
:param username:
:type username: str
:param password:
:type password: str
:rtype: "User"
"""
exsisting_user = await User.get_by_username(session, username)
assert exsisting_user is None, "user already exists"
user = cls(username=username)