[mypy] fixed auth

This commit is contained in:
Piotr Domański 2024-05-21 14:40:31 +02:00
parent 5d9c2e8bd8
commit cc2c381dbf
3 changed files with 51 additions and 12 deletions

View file

@ -5,7 +5,7 @@ from jose import JWTError, jwt
from fastapi.security import OAuth2PasswordBearer
from fastapi import Depends, HTTPException
from fastapi_users.password import PasswordHelper
from typing import AsyncGenerator, Annotated
from typing import Annotated
from datetime import datetime, timedelta
from .settings import Settings
from .domain.user import User
@ -16,7 +16,7 @@ from .db import get_session
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token")
settings = Settings()
password_helper = PasswordHelper(pwd_context)
password_helper = PasswordHelper(pwd_context) # type: ignore
AsyncSessionDependency = Annotated[async_sessionmaker, Depends(get_session)]
TokenDependency = Annotated[str, Depends(oauth2_scheme)]
@ -32,35 +32,57 @@ def get_password_hash(password: str) -> str:
async def authenticate_user(
session: AsyncSession, username: str, password: str
) -> AsyncGenerator[User, None]:
) -> User | None:
user = await User.get_by_username(session, username)
if not user:
if user is None:
return None
assert user is not None
if not verify_password(password, user.hashed_password):
return None
return user
async def verify_refresh_token(
session: AsyncSession, token: str
) -> AsyncGenerator[RefreshToken, None]:
) -> RefreshToken | None:
try:
payload = jwt.decode(
token, settings.REFRESH_SECRET_KEY, algorithms=[settings.ALGORITHM]
)
username: str = payload.get("sub")
sub = payload.get("sub")
if sub is None:
return None
if not isinstance(sub, str):
return None
username: str = str(sub)
if username is None:
return
return None
except JWTError:
return
return None
user = await User.get_by_username(session, username)
if user is None:
return
return None
assert user is not None
current_token = await RefreshToken.get_token(session, user.id, token)
if current_token is not None:
return current_token
return None
def create_access_token(user: User) -> str:
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
@ -94,13 +116,29 @@ async def get_current_user(
payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
username: str = payload.get("sub")
sub = payload.get("sub")
if sub is None:
raise HTTPException(status_code=401, detail="Unathorized")
if not isinstance(sub, str):
raise HTTPException(status_code=401, detail="Unathorized")
username: str = str(sub)
if username is None:
raise HTTPException(status_code=401, detail="Unathorized")
except JWTError:
raise HTTPException(status_code=401, detail="Unathorized")
return await User.get_by_username(session, username)
user = await User.get_by_username(session, username)
if user is None:
raise HTTPException(status_code=401, detail="Unathorized")
assert user is not None
return user
async def authorize_api_key(

View file

@ -159,6 +159,6 @@ class Entry(Base, CommonMixin):
cls,
session: AsyncSession,
) -> None:
stmt = update(cls).where(cls.processed is False).values(processed=True)
stmt = update(cls).where(cls.processed == False).values(processed=True)
await session.execute(stmt)

View file

@ -10,3 +10,4 @@ platform = linux
warn_unused_configs = True
warn_unused_ignores = True
allow_redefinition = True