add refresh tokens
This commit is contained in:
parent
3b3e709e56
commit
e38949e03a
4 changed files with 79 additions and 26 deletions
|
@ -44,7 +44,7 @@ async def authenticate_user(
|
|||
|
||||
async def verify_refresh_token(
|
||||
session: AsyncSession, token: str
|
||||
) -> AsyncGenerator[User, None]:
|
||||
) -> AsyncGenerator[RefreshToken, None]:
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.REFRESH_SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||
|
@ -56,9 +56,11 @@ async def verify_refresh_token(
|
|||
return
|
||||
|
||||
user = await User.get_by_username(session, username)
|
||||
current_token = await RefreshToken.get_token(session, user.id)
|
||||
if current_token is not None and current_token.token == token:
|
||||
return user
|
||||
if user is None:
|
||||
return
|
||||
current_token = await RefreshToken.get_token(session, user.id, token)
|
||||
if current_token is not None:
|
||||
return current_token
|
||||
|
||||
|
||||
def create_access_token(user: User) -> str:
|
||||
|
|
|
@ -37,11 +37,14 @@ class CreateToken(BaseController):
|
|||
class RefreshToken(BaseController):
|
||||
async def call(self, content: RefreshTokenPayload) -> Token:
|
||||
async with self.async_session.begin() as session:
|
||||
user = await verify_refresh_token(session, content.refresh_token)
|
||||
current_token = await verify_refresh_token(session, content.refresh_token)
|
||||
|
||||
if user is None:
|
||||
if current_token is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
||||
user = await DBUser.get(session, current_token.user_id)
|
||||
await current_token.delete(session)
|
||||
|
||||
refresh_token = await create_refresh_token(session, user)
|
||||
access_token = create_access_token(user)
|
||||
|
||||
|
|
|
@ -1,18 +1,16 @@
|
|||
from sqlalchemy.orm import relationship, Mapped, mapped_column, joinedload
|
||||
from sqlalchemy import ForeignKey, Integer, Date
|
||||
from sqlalchemy.orm import relationship, Mapped, mapped_column, joinedload, relationship
|
||||
from sqlalchemy import ForeignKey, Integer
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from datetime import date
|
||||
from typing import Optional
|
||||
|
||||
from .base import Base, CommonMixin
|
||||
from .meal import Meal
|
||||
|
||||
|
||||
class RefreshToken(Base, CommonMixin):
|
||||
"""Diary represents user diary for given day"""
|
||||
|
||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"), unique=True)
|
||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
|
||||
token: Mapped[str]
|
||||
|
||||
@classmethod
|
||||
|
@ -20,9 +18,19 @@ class RefreshToken(Base, CommonMixin):
|
|||
cls,
|
||||
session: AsyncSession,
|
||||
user_id: int,
|
||||
token: str,
|
||||
) -> "Optional[RefreshToken]":
|
||||
"""get_token."""
|
||||
query = select(cls).where(cls.user_id == user_id)
|
||||
"""get_token.
|
||||
|
||||
:param session:
|
||||
:type session: AsyncSession
|
||||
:param user_id:
|
||||
:type user_id: int
|
||||
:param token:
|
||||
:type token: str
|
||||
:rtype: "Optional[RefreshToken]"
|
||||
"""
|
||||
query = select(cls).where(cls.user_id == user_id).where(cls.token == token)
|
||||
return await session.scalar(query)
|
||||
|
||||
@classmethod
|
||||
|
@ -39,12 +47,6 @@ class RefreshToken(Base, CommonMixin):
|
|||
:type token: str
|
||||
:rtype: "RefreshToken"
|
||||
"""
|
||||
existing = await cls.get_token(session, user_id)
|
||||
|
||||
if existing:
|
||||
existing.token = token
|
||||
return existing
|
||||
|
||||
token = cls(
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
|
@ -57,3 +59,13 @@ class RefreshToken(Base, CommonMixin):
|
|||
raise AssertionError("invalid token")
|
||||
|
||||
return token
|
||||
|
||||
async def delete(self, session: AsyncSession) -> None:
|
||||
"""delete.
|
||||
|
||||
:param session:
|
||||
:type session: AsyncSession
|
||||
:rtype: None
|
||||
"""
|
||||
await session.delete(self)
|
||||
await session.flush()
|
||||
|
|
|
@ -12,22 +12,58 @@ class User(Base, CommonMixin):
|
|||
username: Mapped[str]
|
||||
hashed_password: Mapped[str]
|
||||
|
||||
@classmethod
|
||||
async def get_by_username(
|
||||
cls, session: AsyncSession, username: str
|
||||
) -> Optional["User"]:
|
||||
query = select(cls).filter(cls.username == username)
|
||||
return await session.scalar(query.order_by(cls.id))
|
||||
|
||||
def set_password(self, password) -> None:
|
||||
"""set_password.
|
||||
|
||||
:param password:
|
||||
:rtype: None
|
||||
"""
|
||||
from ..auth import password_helper
|
||||
|
||||
self.hashed_password = password_helper.hash(password)
|
||||
|
||||
@classmethod
|
||||
async def get_by_username(
|
||||
cls, session: AsyncSession, username: str
|
||||
) -> Optional["User"]:
|
||||
"""get_by_username.
|
||||
|
||||
:param session:
|
||||
:type session: AsyncSession
|
||||
:param username:
|
||||
:type username: str
|
||||
:rtype: Optional["User"]
|
||||
"""
|
||||
query = select(cls).filter(cls.username == username)
|
||||
return await session.scalar(query.order_by(cls.id))
|
||||
|
||||
@classmethod
|
||||
async def get(cls, session: AsyncSession, id: int) -> Optional["User"]:
|
||||
"""get_by_username.
|
||||
|
||||
:param session:
|
||||
:type session: AsyncSession
|
||||
:param id:
|
||||
:type id: int
|
||||
:rtype: Optional["User"]
|
||||
"""
|
||||
query = select(cls).filter(cls.id == id)
|
||||
return await session.scalar(query.order_by(cls.id))
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls, session: AsyncSession, username: str, password: str
|
||||
) -> "User":
|
||||
"""create.
|
||||
|
||||
:param session:
|
||||
:type session: AsyncSession
|
||||
:param username:
|
||||
:type username: str
|
||||
:param password:
|
||||
:type password: str
|
||||
:rtype: "User"
|
||||
"""
|
||||
exsisting_user = await User.get_by_username(session, username)
|
||||
assert exsisting_user is None, "user already exists"
|
||||
user = cls(username=username)
|
||||
|
|
Loading…
Reference in a new issue