add refresh tokens
This commit is contained in:
parent
3b3e709e56
commit
e38949e03a
4 changed files with 79 additions and 26 deletions
|
@ -44,7 +44,7 @@ async def authenticate_user(
|
||||||
|
|
||||||
async def verify_refresh_token(
|
async def verify_refresh_token(
|
||||||
session: AsyncSession, token: str
|
session: AsyncSession, token: str
|
||||||
) -> AsyncGenerator[User, None]:
|
) -> AsyncGenerator[RefreshToken, None]:
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(
|
payload = jwt.decode(
|
||||||
token, settings.REFRESH_SECRET_KEY, algorithms=[settings.ALGORITHM]
|
token, settings.REFRESH_SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||||
|
@ -56,9 +56,11 @@ async def verify_refresh_token(
|
||||||
return
|
return
|
||||||
|
|
||||||
user = await User.get_by_username(session, username)
|
user = await User.get_by_username(session, username)
|
||||||
current_token = await RefreshToken.get_token(session, user.id)
|
if user is None:
|
||||||
if current_token is not None and current_token.token == token:
|
return
|
||||||
return user
|
current_token = await RefreshToken.get_token(session, user.id, token)
|
||||||
|
if current_token is not None:
|
||||||
|
return current_token
|
||||||
|
|
||||||
|
|
||||||
def create_access_token(user: User) -> str:
|
def create_access_token(user: User) -> str:
|
||||||
|
|
|
@ -37,11 +37,14 @@ class CreateToken(BaseController):
|
||||||
class RefreshToken(BaseController):
|
class RefreshToken(BaseController):
|
||||||
async def call(self, content: RefreshTokenPayload) -> Token:
|
async def call(self, content: RefreshTokenPayload) -> Token:
|
||||||
async with self.async_session.begin() as session:
|
async with self.async_session.begin() as session:
|
||||||
user = await verify_refresh_token(session, content.refresh_token)
|
current_token = await verify_refresh_token(session, content.refresh_token)
|
||||||
|
|
||||||
if user is None:
|
if current_token is None:
|
||||||
raise HTTPException(status_code=401, detail="Invalid token")
|
raise HTTPException(status_code=401, detail="Invalid token")
|
||||||
|
|
||||||
|
user = await DBUser.get(session, current_token.user_id)
|
||||||
|
await current_token.delete(session)
|
||||||
|
|
||||||
refresh_token = await create_refresh_token(session, user)
|
refresh_token = await create_refresh_token(session, user)
|
||||||
access_token = create_access_token(user)
|
access_token = create_access_token(user)
|
||||||
|
|
||||||
|
|
|
@ -1,18 +1,16 @@
|
||||||
from sqlalchemy.orm import relationship, Mapped, mapped_column, joinedload
|
from sqlalchemy.orm import relationship, Mapped, mapped_column, joinedload, relationship
|
||||||
from sqlalchemy import ForeignKey, Integer, Date
|
from sqlalchemy import ForeignKey, Integer
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from datetime import date
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from .base import Base, CommonMixin
|
from .base import Base, CommonMixin
|
||||||
from .meal import Meal
|
|
||||||
|
|
||||||
|
|
||||||
class RefreshToken(Base, CommonMixin):
|
class RefreshToken(Base, CommonMixin):
|
||||||
"""Diary represents user diary for given day"""
|
"""Diary represents user diary for given day"""
|
||||||
|
|
||||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"), unique=True)
|
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
|
||||||
token: Mapped[str]
|
token: Mapped[str]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -20,9 +18,19 @@ class RefreshToken(Base, CommonMixin):
|
||||||
cls,
|
cls,
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
user_id: int,
|
user_id: int,
|
||||||
|
token: str,
|
||||||
) -> "Optional[RefreshToken]":
|
) -> "Optional[RefreshToken]":
|
||||||
"""get_token."""
|
"""get_token.
|
||||||
query = select(cls).where(cls.user_id == user_id)
|
|
||||||
|
:param session:
|
||||||
|
:type session: AsyncSession
|
||||||
|
:param user_id:
|
||||||
|
:type user_id: int
|
||||||
|
:param token:
|
||||||
|
:type token: str
|
||||||
|
:rtype: "Optional[RefreshToken]"
|
||||||
|
"""
|
||||||
|
query = select(cls).where(cls.user_id == user_id).where(cls.token == token)
|
||||||
return await session.scalar(query)
|
return await session.scalar(query)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -39,12 +47,6 @@ class RefreshToken(Base, CommonMixin):
|
||||||
:type token: str
|
:type token: str
|
||||||
:rtype: "RefreshToken"
|
:rtype: "RefreshToken"
|
||||||
"""
|
"""
|
||||||
existing = await cls.get_token(session, user_id)
|
|
||||||
|
|
||||||
if existing:
|
|
||||||
existing.token = token
|
|
||||||
return existing
|
|
||||||
|
|
||||||
token = cls(
|
token = cls(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token=token,
|
token=token,
|
||||||
|
@ -57,3 +59,13 @@ class RefreshToken(Base, CommonMixin):
|
||||||
raise AssertionError("invalid token")
|
raise AssertionError("invalid token")
|
||||||
|
|
||||||
return token
|
return token
|
||||||
|
|
||||||
|
async def delete(self, session: AsyncSession) -> None:
|
||||||
|
"""delete.
|
||||||
|
|
||||||
|
:param session:
|
||||||
|
:type session: AsyncSession
|
||||||
|
:rtype: None
|
||||||
|
"""
|
||||||
|
await session.delete(self)
|
||||||
|
await session.flush()
|
||||||
|
|
|
@ -12,22 +12,58 @@ class User(Base, CommonMixin):
|
||||||
username: Mapped[str]
|
username: Mapped[str]
|
||||||
hashed_password: Mapped[str]
|
hashed_password: Mapped[str]
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get_by_username(
|
|
||||||
cls, session: AsyncSession, username: str
|
|
||||||
) -> Optional["User"]:
|
|
||||||
query = select(cls).filter(cls.username == username)
|
|
||||||
return await session.scalar(query.order_by(cls.id))
|
|
||||||
|
|
||||||
def set_password(self, password) -> None:
|
def set_password(self, password) -> None:
|
||||||
|
"""set_password.
|
||||||
|
|
||||||
|
:param password:
|
||||||
|
:rtype: None
|
||||||
|
"""
|
||||||
from ..auth import password_helper
|
from ..auth import password_helper
|
||||||
|
|
||||||
self.hashed_password = password_helper.hash(password)
|
self.hashed_password = password_helper.hash(password)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_by_username(
|
||||||
|
cls, session: AsyncSession, username: str
|
||||||
|
) -> Optional["User"]:
|
||||||
|
"""get_by_username.
|
||||||
|
|
||||||
|
:param session:
|
||||||
|
:type session: AsyncSession
|
||||||
|
:param username:
|
||||||
|
:type username: str
|
||||||
|
:rtype: Optional["User"]
|
||||||
|
"""
|
||||||
|
query = select(cls).filter(cls.username == username)
|
||||||
|
return await session.scalar(query.order_by(cls.id))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get(cls, session: AsyncSession, id: int) -> Optional["User"]:
|
||||||
|
"""get_by_username.
|
||||||
|
|
||||||
|
:param session:
|
||||||
|
:type session: AsyncSession
|
||||||
|
:param id:
|
||||||
|
:type id: int
|
||||||
|
:rtype: Optional["User"]
|
||||||
|
"""
|
||||||
|
query = select(cls).filter(cls.id == id)
|
||||||
|
return await session.scalar(query.order_by(cls.id))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create(
|
async def create(
|
||||||
cls, session: AsyncSession, username: str, password: str
|
cls, session: AsyncSession, username: str, password: str
|
||||||
) -> "User":
|
) -> "User":
|
||||||
|
"""create.
|
||||||
|
|
||||||
|
:param session:
|
||||||
|
:type session: AsyncSession
|
||||||
|
:param username:
|
||||||
|
:type username: str
|
||||||
|
:param password:
|
||||||
|
:type password: str
|
||||||
|
:rtype: "User"
|
||||||
|
"""
|
||||||
exsisting_user = await User.get_by_username(session, username)
|
exsisting_user = await User.get_by_username(session, username)
|
||||||
assert exsisting_user is None, "user already exists"
|
assert exsisting_user is None, "user already exists"
|
||||||
user = cls(username=username)
|
user = cls(username=username)
|
||||||
|
|
Loading…
Reference in a new issue