[mypy] fixed auth

This commit is contained in:
Piotr Domański 2024-05-21 14:40:31 +02:00
parent 5d9c2e8bd8
commit cc2c381dbf
3 changed files with 51 additions and 12 deletions

View file

@ -5,7 +5,7 @@ from jose import JWTError, jwt
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from fastapi import Depends, HTTPException from fastapi import Depends, HTTPException
from fastapi_users.password import PasswordHelper from fastapi_users.password import PasswordHelper
from typing import AsyncGenerator, Annotated from typing import Annotated
from datetime import datetime, timedelta from datetime import datetime, timedelta
from .settings import Settings from .settings import Settings
from .domain.user import User from .domain.user import User
@ -16,7 +16,7 @@ from .db import get_session
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token")
settings = Settings() settings = Settings()
password_helper = PasswordHelper(pwd_context) password_helper = PasswordHelper(pwd_context) # type: ignore
AsyncSessionDependency = Annotated[async_sessionmaker, Depends(get_session)] AsyncSessionDependency = Annotated[async_sessionmaker, Depends(get_session)]
TokenDependency = Annotated[str, Depends(oauth2_scheme)] TokenDependency = Annotated[str, Depends(oauth2_scheme)]
@ -32,35 +32,57 @@ def get_password_hash(password: str) -> str:
async def authenticate_user( async def authenticate_user(
session: AsyncSession, username: str, password: str session: AsyncSession, username: str, password: str
) -> AsyncGenerator[User, None]: ) -> User | None:
user = await User.get_by_username(session, username) user = await User.get_by_username(session, username)
if not user:
if user is None:
return None return None
assert user is not None
if not verify_password(password, user.hashed_password): if not verify_password(password, user.hashed_password):
return None return None
return user return user
async def verify_refresh_token( async def verify_refresh_token(
session: AsyncSession, token: str session: AsyncSession, token: str
) -> AsyncGenerator[RefreshToken, None]: ) -> RefreshToken | None:
try: try:
payload = jwt.decode( payload = jwt.decode(
token, settings.REFRESH_SECRET_KEY, algorithms=[settings.ALGORITHM] token, settings.REFRESH_SECRET_KEY, algorithms=[settings.ALGORITHM]
) )
username: str = payload.get("sub") sub = payload.get("sub")
if sub is None:
return None
if not isinstance(sub, str):
return None
username: str = str(sub)
if username is None: if username is None:
return return None
except JWTError: except JWTError:
return return None
user = await User.get_by_username(session, username) user = await User.get_by_username(session, username)
if user is None: if user is None:
return return None
assert user is not None
current_token = await RefreshToken.get_token(session, user.id, token) current_token = await RefreshToken.get_token(session, user.id, token)
if current_token is not None: if current_token is not None:
return current_token return current_token
return None
def create_access_token(user: User) -> str: def create_access_token(user: User) -> str:
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
@ -94,13 +116,29 @@ async def get_current_user(
payload = jwt.decode( payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
) )
username: str = payload.get("sub") sub = payload.get("sub")
if sub is None:
raise HTTPException(status_code=401, detail="Unathorized")
if not isinstance(sub, str):
raise HTTPException(status_code=401, detail="Unathorized")
username: str = str(sub)
if username is None: if username is None:
raise HTTPException(status_code=401, detail="Unathorized") raise HTTPException(status_code=401, detail="Unathorized")
except JWTError: except JWTError:
raise HTTPException(status_code=401, detail="Unathorized") raise HTTPException(status_code=401, detail="Unathorized")
return await User.get_by_username(session, username) user = await User.get_by_username(session, username)
if user is None:
raise HTTPException(status_code=401, detail="Unathorized")
assert user is not None
return user
async def authorize_api_key( async def authorize_api_key(

View file

@ -159,6 +159,6 @@ class Entry(Base, CommonMixin):
cls, cls,
session: AsyncSession, session: AsyncSession,
) -> None: ) -> None:
stmt = update(cls).where(cls.processed is False).values(processed=True) stmt = update(cls).where(cls.processed == False).values(processed=True)
await session.execute(stmt) await session.execute(stmt)

View file

@ -10,3 +10,4 @@ platform = linux
warn_unused_configs = True warn_unused_configs = True
warn_unused_ignores = True warn_unused_ignores = True
allow_redefinition = True