add refresh tokens
This commit is contained in:
		
							parent
							
								
									3b3e709e56
								
							
						
					
					
						commit
						e38949e03a
					
				
					 4 changed files with 79 additions and 26 deletions
				
			
		| 
						 | 
					@ -44,7 +44,7 @@ async def authenticate_user(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def verify_refresh_token(
 | 
					async def verify_refresh_token(
 | 
				
			||||||
    session: AsyncSession, token: str
 | 
					    session: AsyncSession, token: str
 | 
				
			||||||
) -> AsyncGenerator[User, None]:
 | 
					) -> AsyncGenerator[RefreshToken, None]:
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        payload = jwt.decode(
 | 
					        payload = jwt.decode(
 | 
				
			||||||
            token, settings.REFRESH_SECRET_KEY, algorithms=[settings.ALGORITHM]
 | 
					            token, settings.REFRESH_SECRET_KEY, algorithms=[settings.ALGORITHM]
 | 
				
			||||||
| 
						 | 
					@ -56,9 +56,11 @@ async def verify_refresh_token(
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    user = await User.get_by_username(session, username)
 | 
					    user = await User.get_by_username(session, username)
 | 
				
			||||||
    current_token = await RefreshToken.get_token(session, user.id)
 | 
					    if user is None:
 | 
				
			||||||
    if current_token is not None and current_token.token == token:
 | 
					        return
 | 
				
			||||||
        return user
 | 
					    current_token = await RefreshToken.get_token(session, user.id, token)
 | 
				
			||||||
 | 
					    if current_token is not None:
 | 
				
			||||||
 | 
					        return current_token
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def create_access_token(user: User) -> str:
 | 
					def create_access_token(user: User) -> str:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -37,11 +37,14 @@ class CreateToken(BaseController):
 | 
				
			||||||
class RefreshToken(BaseController):
 | 
					class RefreshToken(BaseController):
 | 
				
			||||||
    async def call(self, content: RefreshTokenPayload) -> Token:
 | 
					    async def call(self, content: RefreshTokenPayload) -> Token:
 | 
				
			||||||
        async with self.async_session.begin() as session:
 | 
					        async with self.async_session.begin() as session:
 | 
				
			||||||
            user = await verify_refresh_token(session, content.refresh_token)
 | 
					            current_token = await verify_refresh_token(session, content.refresh_token)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if user is None:
 | 
					            if current_token is None:
 | 
				
			||||||
                raise HTTPException(status_code=401, detail="Invalid token")
 | 
					                raise HTTPException(status_code=401, detail="Invalid token")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            user = await DBUser.get(session, current_token.user_id)
 | 
				
			||||||
 | 
					            await current_token.delete(session)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            refresh_token = await create_refresh_token(session, user)
 | 
					            refresh_token = await create_refresh_token(session, user)
 | 
				
			||||||
            access_token = create_access_token(user)
 | 
					            access_token = create_access_token(user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,18 +1,16 @@
 | 
				
			||||||
from sqlalchemy.orm import relationship, Mapped, mapped_column, joinedload
 | 
					from sqlalchemy.orm import relationship, Mapped, mapped_column, joinedload, relationship
 | 
				
			||||||
from sqlalchemy import ForeignKey, Integer, Date
 | 
					from sqlalchemy import ForeignKey, Integer
 | 
				
			||||||
from sqlalchemy import select
 | 
					from sqlalchemy import select
 | 
				
			||||||
from sqlalchemy.ext.asyncio import AsyncSession
 | 
					from sqlalchemy.ext.asyncio import AsyncSession
 | 
				
			||||||
from datetime import date
 | 
					 | 
				
			||||||
from typing import Optional
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .base import Base, CommonMixin
 | 
					from .base import Base, CommonMixin
 | 
				
			||||||
from .meal import Meal
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class RefreshToken(Base, CommonMixin):
 | 
					class RefreshToken(Base, CommonMixin):
 | 
				
			||||||
    """Diary represents user diary for given day"""
 | 
					    """Diary represents user diary for given day"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"), unique=True)
 | 
					    user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
 | 
				
			||||||
    token: Mapped[str]
 | 
					    token: Mapped[str]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
| 
						 | 
					@ -20,9 +18,19 @@ class RefreshToken(Base, CommonMixin):
 | 
				
			||||||
        cls,
 | 
					        cls,
 | 
				
			||||||
        session: AsyncSession,
 | 
					        session: AsyncSession,
 | 
				
			||||||
        user_id: int,
 | 
					        user_id: int,
 | 
				
			||||||
 | 
					        token: str,
 | 
				
			||||||
    ) -> "Optional[RefreshToken]":
 | 
					    ) -> "Optional[RefreshToken]":
 | 
				
			||||||
        """get_token."""
 | 
					        """get_token.
 | 
				
			||||||
        query = select(cls).where(cls.user_id == user_id)
 | 
					
 | 
				
			||||||
 | 
					        :param session:
 | 
				
			||||||
 | 
					        :type session: AsyncSession
 | 
				
			||||||
 | 
					        :param user_id:
 | 
				
			||||||
 | 
					        :type user_id: int
 | 
				
			||||||
 | 
					        :param token:
 | 
				
			||||||
 | 
					        :type token: str
 | 
				
			||||||
 | 
					        :rtype: "Optional[RefreshToken]"
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        query = select(cls).where(cls.user_id == user_id).where(cls.token == token)
 | 
				
			||||||
        return await session.scalar(query)
 | 
					        return await session.scalar(query)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
| 
						 | 
					@ -39,12 +47,6 @@ class RefreshToken(Base, CommonMixin):
 | 
				
			||||||
        :type token: str
 | 
					        :type token: str
 | 
				
			||||||
        :rtype: "RefreshToken"
 | 
					        :rtype: "RefreshToken"
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        existing = await cls.get_token(session, user_id)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if existing:
 | 
					 | 
				
			||||||
            existing.token = token
 | 
					 | 
				
			||||||
            return existing
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        token = cls(
 | 
					        token = cls(
 | 
				
			||||||
            user_id=user_id,
 | 
					            user_id=user_id,
 | 
				
			||||||
            token=token,
 | 
					            token=token,
 | 
				
			||||||
| 
						 | 
					@ -57,3 +59,13 @@ class RefreshToken(Base, CommonMixin):
 | 
				
			||||||
            raise AssertionError("invalid token")
 | 
					            raise AssertionError("invalid token")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return token
 | 
					        return token
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def delete(self, session: AsyncSession) -> None:
 | 
				
			||||||
 | 
					        """delete.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :param session:
 | 
				
			||||||
 | 
					        :type session: AsyncSession
 | 
				
			||||||
 | 
					        :rtype: None
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        await session.delete(self)
 | 
				
			||||||
 | 
					        await session.flush()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -12,22 +12,58 @@ class User(Base, CommonMixin):
 | 
				
			||||||
    username: Mapped[str]
 | 
					    username: Mapped[str]
 | 
				
			||||||
    hashed_password: Mapped[str]
 | 
					    hashed_password: Mapped[str]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					 | 
				
			||||||
    async def get_by_username(
 | 
					 | 
				
			||||||
        cls, session: AsyncSession, username: str
 | 
					 | 
				
			||||||
    ) -> Optional["User"]:
 | 
					 | 
				
			||||||
        query = select(cls).filter(cls.username == username)
 | 
					 | 
				
			||||||
        return await session.scalar(query.order_by(cls.id))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def set_password(self, password) -> None:
 | 
					    def set_password(self, password) -> None:
 | 
				
			||||||
 | 
					        """set_password.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :param password:
 | 
				
			||||||
 | 
					        :rtype: None
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
        from ..auth import password_helper
 | 
					        from ..auth import password_helper
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.hashed_password = password_helper.hash(password)
 | 
					        self.hashed_password = password_helper.hash(password)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    async def get_by_username(
 | 
				
			||||||
 | 
					        cls, session: AsyncSession, username: str
 | 
				
			||||||
 | 
					    ) -> Optional["User"]:
 | 
				
			||||||
 | 
					        """get_by_username.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :param session:
 | 
				
			||||||
 | 
					        :type session: AsyncSession
 | 
				
			||||||
 | 
					        :param username:
 | 
				
			||||||
 | 
					        :type username: str
 | 
				
			||||||
 | 
					        :rtype: Optional["User"]
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        query = select(cls).filter(cls.username == username)
 | 
				
			||||||
 | 
					        return await session.scalar(query.order_by(cls.id))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    async def get(cls, session: AsyncSession, id: int) -> Optional["User"]:
 | 
				
			||||||
 | 
					        """get_by_username.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :param session:
 | 
				
			||||||
 | 
					        :type session: AsyncSession
 | 
				
			||||||
 | 
					        :param id:
 | 
				
			||||||
 | 
					        :type id: int
 | 
				
			||||||
 | 
					        :rtype: Optional["User"]
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        query = select(cls).filter(cls.id == id)
 | 
				
			||||||
 | 
					        return await session.scalar(query.order_by(cls.id))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    async def create(
 | 
					    async def create(
 | 
				
			||||||
        cls, session: AsyncSession, username: str, password: str
 | 
					        cls, session: AsyncSession, username: str, password: str
 | 
				
			||||||
    ) -> "User":
 | 
					    ) -> "User":
 | 
				
			||||||
 | 
					        """create.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :param session:
 | 
				
			||||||
 | 
					        :type session: AsyncSession
 | 
				
			||||||
 | 
					        :param username:
 | 
				
			||||||
 | 
					        :type username: str
 | 
				
			||||||
 | 
					        :param password:
 | 
				
			||||||
 | 
					        :type password: str
 | 
				
			||||||
 | 
					        :rtype: "User"
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
        exsisting_user = await User.get_by_username(session, username)
 | 
					        exsisting_user = await User.get_by_username(session, username)
 | 
				
			||||||
        assert exsisting_user is None, "user already exists"
 | 
					        assert exsisting_user is None, "user already exists"
 | 
				
			||||||
        user = cls(username=username)
 | 
					        user = cls(username=username)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue