[mypy] fixed auth
This commit is contained in:
		
							parent
							
								
									5d9c2e8bd8
								
							
						
					
					
						commit
						cc2c381dbf
					
				
					 3 changed files with 51 additions and 12 deletions
				
			
		| 
						 | 
					@ -5,7 +5,7 @@ from jose import JWTError, jwt
 | 
				
			||||||
from fastapi.security import OAuth2PasswordBearer
 | 
					from fastapi.security import OAuth2PasswordBearer
 | 
				
			||||||
from fastapi import Depends, HTTPException
 | 
					from fastapi import Depends, HTTPException
 | 
				
			||||||
from fastapi_users.password import PasswordHelper
 | 
					from fastapi_users.password import PasswordHelper
 | 
				
			||||||
from typing import AsyncGenerator, Annotated
 | 
					from typing import Annotated
 | 
				
			||||||
from datetime import datetime, timedelta
 | 
					from datetime import datetime, timedelta
 | 
				
			||||||
from .settings import Settings
 | 
					from .settings import Settings
 | 
				
			||||||
from .domain.user import User
 | 
					from .domain.user import User
 | 
				
			||||||
| 
						 | 
					@ -16,7 +16,7 @@ from .db import get_session
 | 
				
			||||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
 | 
					pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
 | 
				
			||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token")
 | 
					oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token")
 | 
				
			||||||
settings = Settings()
 | 
					settings = Settings()
 | 
				
			||||||
password_helper = PasswordHelper(pwd_context)
 | 
					password_helper = PasswordHelper(pwd_context)  # type: ignore
 | 
				
			||||||
 | 
					
 | 
				
			||||||
AsyncSessionDependency = Annotated[async_sessionmaker, Depends(get_session)]
 | 
					AsyncSessionDependency = Annotated[async_sessionmaker, Depends(get_session)]
 | 
				
			||||||
TokenDependency = Annotated[str, Depends(oauth2_scheme)]
 | 
					TokenDependency = Annotated[str, Depends(oauth2_scheme)]
 | 
				
			||||||
| 
						 | 
					@ -32,35 +32,57 @@ def get_password_hash(password: str) -> str:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def authenticate_user(
 | 
					async def authenticate_user(
 | 
				
			||||||
    session: AsyncSession, username: str, password: str
 | 
					    session: AsyncSession, username: str, password: str
 | 
				
			||||||
) -> AsyncGenerator[User, None]:
 | 
					) -> User | None:
 | 
				
			||||||
    user = await User.get_by_username(session, username)
 | 
					    user = await User.get_by_username(session, username)
 | 
				
			||||||
    if not user:
 | 
					
 | 
				
			||||||
 | 
					    if user is None:
 | 
				
			||||||
        return None
 | 
					        return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert user is not None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if not verify_password(password, user.hashed_password):
 | 
					    if not verify_password(password, user.hashed_password):
 | 
				
			||||||
        return None
 | 
					        return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return user
 | 
					    return user
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def verify_refresh_token(
 | 
					async def verify_refresh_token(
 | 
				
			||||||
    session: AsyncSession, token: str
 | 
					    session: AsyncSession, token: str
 | 
				
			||||||
) -> AsyncGenerator[RefreshToken, None]:
 | 
					) -> RefreshToken | None:
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        payload = jwt.decode(
 | 
					        payload = jwt.decode(
 | 
				
			||||||
            token, settings.REFRESH_SECRET_KEY, algorithms=[settings.ALGORITHM]
 | 
					            token, settings.REFRESH_SECRET_KEY, algorithms=[settings.ALGORITHM]
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        username: str = payload.get("sub")
 | 
					        sub = payload.get("sub")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if sub is None:
 | 
				
			||||||
 | 
					            return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if not isinstance(sub, str):
 | 
				
			||||||
 | 
					            return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        username: str = str(sub)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if username is None:
 | 
					        if username is None:
 | 
				
			||||||
            return
 | 
					            return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    except JWTError:
 | 
					    except JWTError:
 | 
				
			||||||
        return
 | 
					        return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    user = await User.get_by_username(session, username)
 | 
					    user = await User.get_by_username(session, username)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if user is None:
 | 
					    if user is None:
 | 
				
			||||||
        return
 | 
					        return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert user is not None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    current_token = await RefreshToken.get_token(session, user.id, token)
 | 
					    current_token = await RefreshToken.get_token(session, user.id, token)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if current_token is not None:
 | 
					    if current_token is not None:
 | 
				
			||||||
        return current_token
 | 
					        return current_token
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def create_access_token(user: User) -> str:
 | 
					def create_access_token(user: User) -> str:
 | 
				
			||||||
    expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
 | 
					    expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
 | 
				
			||||||
| 
						 | 
					@ -94,13 +116,29 @@ async def get_current_user(
 | 
				
			||||||
            payload = jwt.decode(
 | 
					            payload = jwt.decode(
 | 
				
			||||||
                token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
 | 
					                token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            username: str = payload.get("sub")
 | 
					            sub = payload.get("sub")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if sub is None:
 | 
				
			||||||
 | 
					                raise HTTPException(status_code=401, detail="Unathorized")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if not isinstance(sub, str):
 | 
				
			||||||
 | 
					                raise HTTPException(status_code=401, detail="Unathorized")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            username: str = str(sub)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if username is None:
 | 
					            if username is None:
 | 
				
			||||||
                raise HTTPException(status_code=401, detail="Unathorized")
 | 
					                raise HTTPException(status_code=401, detail="Unathorized")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        except JWTError:
 | 
					        except JWTError:
 | 
				
			||||||
            raise HTTPException(status_code=401, detail="Unathorized")
 | 
					            raise HTTPException(status_code=401, detail="Unathorized")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return await User.get_by_username(session, username)
 | 
					        user = await User.get_by_username(session, username)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if user is None:
 | 
				
			||||||
 | 
					            raise HTTPException(status_code=401, detail="Unathorized")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        assert user is not None
 | 
				
			||||||
 | 
					        return user
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def authorize_api_key(
 | 
					async def authorize_api_key(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -159,6 +159,6 @@ class Entry(Base, CommonMixin):
 | 
				
			||||||
        cls,
 | 
					        cls,
 | 
				
			||||||
        session: AsyncSession,
 | 
					        session: AsyncSession,
 | 
				
			||||||
    ) -> None:
 | 
					    ) -> None:
 | 
				
			||||||
        stmt = update(cls).where(cls.processed is False).values(processed=True)
 | 
					        stmt = update(cls).where(cls.processed == False).values(processed=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        await session.execute(stmt)
 | 
					        await session.execute(stmt)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										1
									
								
								mypy.ini
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								mypy.ini
									
									
									
									
									
								
							| 
						 | 
					@ -10,3 +10,4 @@ platform = linux
 | 
				
			||||||
 | 
					
 | 
				
			||||||
warn_unused_configs = True
 | 
					warn_unused_configs = True
 | 
				
			||||||
warn_unused_ignores = True
 | 
					warn_unused_ignores = True
 | 
				
			||||||
 | 
					allow_redefinition = True
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue