From e38949e03a5d02cc3c49aae87c3fc5e93f0a17cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20Doma=C5=84ski?= Date: Sun, 2 Apr 2023 15:20:53 +0200 Subject: [PATCH] add refresh tokens --- fooder/auth.py | 10 +++++--- fooder/controller/token.py | 7 ++++-- fooder/domain/token.py | 38 +++++++++++++++++++---------- fooder/domain/user.py | 50 ++++++++++++++++++++++++++++++++------ 4 files changed, 79 insertions(+), 26 deletions(-) diff --git a/fooder/auth.py b/fooder/auth.py index 65cbd5c..856cbea 100644 --- a/fooder/auth.py +++ b/fooder/auth.py @@ -44,7 +44,7 @@ async def authenticate_user( async def verify_refresh_token( session: AsyncSession, token: str -) -> AsyncGenerator[User, None]: +) -> AsyncGenerator[RefreshToken, None]: try: payload = jwt.decode( token, settings.REFRESH_SECRET_KEY, algorithms=[settings.ALGORITHM] @@ -56,9 +56,11 @@ async def verify_refresh_token( return user = await User.get_by_username(session, username) - current_token = await RefreshToken.get_token(session, user.id) - if current_token is not None and current_token.token == token: - return user + if user is None: + return + current_token = await RefreshToken.get_token(session, user.id, token) + if current_token is not None: + return current_token def create_access_token(user: User) -> str: diff --git a/fooder/controller/token.py b/fooder/controller/token.py index 6f8ce93..5ac77f2 100644 --- a/fooder/controller/token.py +++ b/fooder/controller/token.py @@ -37,11 +37,14 @@ class CreateToken(BaseController): class RefreshToken(BaseController): async def call(self, content: RefreshTokenPayload) -> Token: async with self.async_session.begin() as session: - user = await verify_refresh_token(session, content.refresh_token) + current_token = await verify_refresh_token(session, content.refresh_token) - if user is None: + if current_token is None: raise HTTPException(status_code=401, detail="Invalid token") + user = await DBUser.get(session, current_token.user_id) + await current_token.delete(session) + refresh_token = await create_refresh_token(session, user) access_token = create_access_token(user) diff --git a/fooder/domain/token.py b/fooder/domain/token.py index d93c008..67a2e6a 100644 --- a/fooder/domain/token.py +++ b/fooder/domain/token.py @@ -1,18 +1,16 @@ -from sqlalchemy.orm import relationship, Mapped, mapped_column, joinedload -from sqlalchemy import ForeignKey, Integer, Date +from sqlalchemy.orm import relationship, Mapped, mapped_column, joinedload, relationship +from sqlalchemy import ForeignKey, Integer from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from datetime import date from typing import Optional from .base import Base, CommonMixin -from .meal import Meal class RefreshToken(Base, CommonMixin): """Diary represents user diary for given day""" - user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"), unique=True) + user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id")) token: Mapped[str] @classmethod @@ -20,9 +18,19 @@ class RefreshToken(Base, CommonMixin): cls, session: AsyncSession, user_id: int, + token: str, ) -> "Optional[RefreshToken]": - """get_token.""" - query = select(cls).where(cls.user_id == user_id) + """get_token. + + :param session: + :type session: AsyncSession + :param user_id: + :type user_id: int + :param token: + :type token: str + :rtype: "Optional[RefreshToken]" + """ + query = select(cls).where(cls.user_id == user_id).where(cls.token == token) return await session.scalar(query) @classmethod @@ -39,12 +47,6 @@ class RefreshToken(Base, CommonMixin): :type token: str :rtype: "RefreshToken" """ - existing = await cls.get_token(session, user_id) - - if existing: - existing.token = token - return existing - token = cls( user_id=user_id, token=token, @@ -57,3 +59,13 @@ class RefreshToken(Base, CommonMixin): raise AssertionError("invalid token") return token + + async def delete(self, session: AsyncSession) -> None: + """delete. + + :param session: + :type session: AsyncSession + :rtype: None + """ + await session.delete(self) + await session.flush() diff --git a/fooder/domain/user.py b/fooder/domain/user.py index abf6ccc..47157ff 100644 --- a/fooder/domain/user.py +++ b/fooder/domain/user.py @@ -12,22 +12,58 @@ class User(Base, CommonMixin): username: Mapped[str] hashed_password: Mapped[str] - @classmethod - async def get_by_username( - cls, session: AsyncSession, username: str - ) -> Optional["User"]: - query = select(cls).filter(cls.username == username) - return await session.scalar(query.order_by(cls.id)) - def set_password(self, password) -> None: + """set_password. + + :param password: + :rtype: None + """ from ..auth import password_helper self.hashed_password = password_helper.hash(password) + @classmethod + async def get_by_username( + cls, session: AsyncSession, username: str + ) -> Optional["User"]: + """get_by_username. + + :param session: + :type session: AsyncSession + :param username: + :type username: str + :rtype: Optional["User"] + """ + query = select(cls).filter(cls.username == username) + return await session.scalar(query.order_by(cls.id)) + + @classmethod + async def get(cls, session: AsyncSession, id: int) -> Optional["User"]: + """get_by_username. + + :param session: + :type session: AsyncSession + :param id: + :type id: int + :rtype: Optional["User"] + """ + query = select(cls).filter(cls.id == id) + return await session.scalar(query.order_by(cls.id)) + @classmethod async def create( cls, session: AsyncSession, username: str, password: str ) -> "User": + """create. + + :param session: + :type session: AsyncSession + :param username: + :type username: str + :param password: + :type password: str + :rtype: "User" + """ exsisting_user = await User.get_by_username(session, username) assert exsisting_user is None, "user already exists" user = cls(username=username)