controllers and commands, create_token works now
This commit is contained in:
parent
4950f0dfa4
commit
74ec8aa834
31 changed files with 326 additions and 304 deletions
|
|
@ -13,9 +13,8 @@ if __name__ == "__main__":
|
|||
import sqlalchemy
|
||||
from sqlalchemy.orm import Session
|
||||
from .domain import Base
|
||||
from .settings import Settings
|
||||
from .settings import settings
|
||||
|
||||
settings = Settings()
|
||||
engine = sqlalchemy.create_engine(
|
||||
settings.DB_URI.replace("+asyncpg", "").replace("+aiosqlite", "")
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
import logging
|
||||
|
||||
from .router import router
|
||||
from .settings import Settings
|
||||
from .settings import settings
|
||||
from .exc import ApiException
|
||||
|
||||
app = FastAPI(title="Fooder")
|
||||
app.include_router(router)
|
||||
|
|
@ -10,8 +13,17 @@ app.include_router(router)
|
|||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=Settings().ALLOWED_ORIGINS,
|
||||
allow_origins=settings.ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(ApiException)
|
||||
async def exception_handler_ErrorBase(_: Request, exc: ApiException):
|
||||
headers = {"www-authenticate": "Bearer"} if exc.HTTP_CODE == 401 else {}
|
||||
logging.exception(exc)
|
||||
return JSONResponse(
|
||||
status_code=exc.HTTP_CODE, content={"message": exc.message}, headers=headers
|
||||
)
|
||||
|
|
|
|||
125
fooder/auth.py
125
fooder/auth.py
|
|
@ -1,125 +0,0 @@
|
|||
from datetime import datetime, timedelta
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from .db import get_db_session
|
||||
from .domain.token import RefreshToken
|
||||
from .domain.user import User
|
||||
from .settings import Settings
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token")
|
||||
|
||||
|
||||
async def authenticate_user(
|
||||
session: AsyncSession, username: str, password: str
|
||||
) -> User | None:
|
||||
user = await User.get_by_username(session, username)
|
||||
|
||||
if user is None:
|
||||
return None
|
||||
|
||||
assert user is not None
|
||||
|
||||
if not verify_password(password, user.hashed_password):
|
||||
return None
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def verify_refresh_token(
|
||||
session: AsyncSession, token: str
|
||||
) -> RefreshToken | None:
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.REFRESH_SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
sub = payload.get("sub")
|
||||
|
||||
if sub is None:
|
||||
return None
|
||||
|
||||
if not isinstance(sub, str):
|
||||
return None
|
||||
|
||||
username: str = str(sub)
|
||||
|
||||
if username is None:
|
||||
return None
|
||||
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
user = await User.get_by_username(session, username)
|
||||
|
||||
if user is None:
|
||||
return None
|
||||
|
||||
assert user is not None
|
||||
|
||||
current_token = await RefreshToken.get_token(session, user.id, token)
|
||||
|
||||
if current_token is not None:
|
||||
return current_token
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def create_access_token(user: User) -> str:
|
||||
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
to_encode = {
|
||||
"sub": user.username,
|
||||
"exp": expire,
|
||||
}
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
async def create_refresh_token(session: AsyncSession, user: User) -> RefreshToken:
|
||||
expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
to_encode = {
|
||||
"sub": user.username,
|
||||
"exp": expire,
|
||||
}
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, settings.REFRESH_SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
)
|
||||
return await RefreshToken.create(session, token=encoded_jwt, user_id=user.id)
|
||||
|
||||
|
||||
async def get_current_user(ssn: AsyncSessionDependency, token: TokenDependency) -> User:
|
||||
async with ssn() as session:
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
sub = payload.get("sub")
|
||||
|
||||
if sub is None:
|
||||
raise HTTPException(status_code=401, detail="Unathorized")
|
||||
|
||||
if not isinstance(sub, str):
|
||||
raise HTTPException(status_code=401, detail="Unathorized")
|
||||
|
||||
username: str = str(sub)
|
||||
|
||||
if username is None:
|
||||
raise HTTPException(status_code=401, detail="Unathorized")
|
||||
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=401, detail="Unathorized")
|
||||
|
||||
user = await User.get_by_username(session, username)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="Unathorized")
|
||||
|
||||
assert user is not None
|
||||
return user
|
||||
0
fooder/command/__init__.py
Normal file
0
fooder/command/__init__.py
Normal file
|
|
@ -1,7 +1,14 @@
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from fastapi import Depends
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from typing import Callable
|
||||
from datetime import datetime
|
||||
from .db import get_db_session
|
||||
from .domain import User
|
||||
from .repository import Repository
|
||||
from .utils.datetime import utc_now
|
||||
from .utils.jwt import AccessToken
|
||||
from .exc import Unauthorized
|
||||
|
||||
|
||||
class Context:
|
||||
|
|
@ -9,14 +16,26 @@ class Context:
|
|||
Main API context, aggregating dependencies
|
||||
"""
|
||||
|
||||
def __init__(self, repo: Repository) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
repo: Repository,
|
||||
clock: Callable[[], datetime] = utc_now,
|
||||
_user: User | None = None,
|
||||
) -> None:
|
||||
self.repo = repo
|
||||
self.clock = clock
|
||||
self._user = _user
|
||||
|
||||
@property
|
||||
def user(self) -> User:
|
||||
if self._user is None:
|
||||
raise Unauthorized()
|
||||
return self._user
|
||||
|
||||
|
||||
class ContextDependency:
|
||||
"""
|
||||
Configurable context dependecy. Allows for shared interface configuring
|
||||
method required dependencies
|
||||
Context dependecy
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -29,3 +48,21 @@ class ContextDependency:
|
|||
session: AsyncSession = Depends(get_db_session),
|
||||
):
|
||||
return Context(repo=Repository(session))
|
||||
|
||||
|
||||
class AuthContextDependency:
|
||||
"""
|
||||
Context dependecy for authorized endpionts
|
||||
"""
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
token: str = Depends(OAuth2PasswordBearer(tokenUrl="/token")),
|
||||
session: AsyncSession = Depends(get_db_session),
|
||||
) -> Context:
|
||||
access_token = AccessToken.decode(token)
|
||||
repo = Repository(session)
|
||||
user = await repo.user.get(User.id == access_token.sub)
|
||||
if user is None:
|
||||
raise Unauthorized()
|
||||
return Context(repo=repo, _user=user)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,2 @@
|
|||
from .user import UserController
|
||||
from .token import TokenController
|
||||
|
|
@ -1,33 +1,17 @@
|
|||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
|
||||
from ..auth import authorize_api_key, get_current_user
|
||||
from ..db import get_db_session, AsyncSession
|
||||
from ..domain.user import User
|
||||
|
||||
UserDependency = Annotated[User, Depends(get_current_user)]
|
||||
ApiKeyDependency = Annotated[None, Depends(authorize_api_key)]
|
||||
from ..context import Context
|
||||
from typing import TypeVar, Generic
|
||||
from sqlalchemy import BinaryExpression
|
||||
|
||||
|
||||
class BaseController:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self.session = session
|
||||
|
||||
async def call(self, *args, **kwargs) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
async def __call__(self, *args, **kwargs) -> Any:
|
||||
return await self.call(*args, **kwargs)
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class AuthorizedController(BaseController):
|
||||
def __init__(self, session: AsyncSession, user: UserDependency) -> None:
|
||||
super().__init__(session)
|
||||
self.user = user
|
||||
class ControllerBase:
|
||||
def __init__(self, ctx: Context) -> None:
|
||||
self.ctx = ctx
|
||||
|
||||
|
||||
class TasksSessionController(BaseController):
|
||||
def __init__(self, session: AsyncSession, api_key: ApiKeyDependency) -> None:
|
||||
super().__init__(session)
|
||||
class ModelController(Generic[T], ControllerBase):
|
||||
def __init__(self, ctx: Context, obj: T):
|
||||
super().__init__(ctx)
|
||||
self.obj = obj
|
||||
|
|
|
|||
|
|
@ -1,58 +1,25 @@
|
|||
from fastapi import HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from .base import ControllerBase
|
||||
from ..context import Context
|
||||
from ..utils.jwt import Token, AccessToken, RefreshToken
|
||||
from typing import Type, TypeVar
|
||||
from datetime import datetime
|
||||
|
||||
from ..auth import (
|
||||
authenticate_user,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
verify_refresh_token,
|
||||
)
|
||||
from ..domain.user import User as DBUser
|
||||
from ..model.token import RefreshTokenPayload, Token
|
||||
from .base import BaseController
|
||||
T = TypeVar("T", bound=Token)
|
||||
|
||||
|
||||
class CreateToken(BaseController):
|
||||
async def call(self, content: OAuth2PasswordRequestForm) -> Token:
|
||||
async with self.async_session.begin() as session:
|
||||
user = await authenticate_user(session, content.username, content.password)
|
||||
class TokenController(ControllerBase):
|
||||
def __init__(self, ctx: Context, entity_id: int) -> None:
|
||||
super().__init__(ctx)
|
||||
self.entity_id = entity_id
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Invalid username or password"
|
||||
)
|
||||
def generate_token(self, token_cls: Type[T], now: datetime) -> T:
|
||||
return token_cls(exp=token_cls.calculate_exp(now), sub=self.entity_id)
|
||||
|
||||
refresh_token = await create_refresh_token(session, user)
|
||||
access_token = create_access_token(user)
|
||||
def generate_refresh_token(self, now: datetime) -> RefreshToken:
|
||||
return self.generate_token(RefreshToken, now)
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token.token,
|
||||
token_type="bearer",
|
||||
)
|
||||
def generate_access_token(self, now: datetime) -> AccessToken:
|
||||
return self.generate_token(AccessToken, now)
|
||||
|
||||
|
||||
class RefreshToken(BaseController):
|
||||
async def call(self, content: RefreshTokenPayload) -> Token:
|
||||
async with self.async_session.begin() as session:
|
||||
current_token = await verify_refresh_token(session, content.refresh_token)
|
||||
|
||||
if current_token is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
||||
user = await DBUser.get(session, current_token.user_id)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
||||
assert user is not None
|
||||
await current_token.delete(session)
|
||||
|
||||
refresh_token = await create_refresh_token(session, user)
|
||||
access_token = create_access_token(user)
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token.token,
|
||||
token_type="bearer",
|
||||
)
|
||||
def generate_token_pair(self, now: datetime) -> tuple[AccessToken, RefreshToken]:
|
||||
return (self.generate_access_token(now), self.generate_refresh_token(now))
|
||||
|
|
|
|||
|
|
@ -1,19 +1,24 @@
|
|||
from fastapi import HTTPException
|
||||
|
||||
from ..domain.user import User as DBUser
|
||||
from ..model.user import CreateUserPayload, User
|
||||
from .base import BaseController
|
||||
from .base import ModelController
|
||||
from ..domain import User
|
||||
from ..context import Context
|
||||
from ..exc import Unauthorized
|
||||
from .token import TokenController
|
||||
|
||||
|
||||
class CreateUser(BaseController):
|
||||
async def call(self, content: CreateUserPayload) -> User:
|
||||
async with self.async_session.begin() as session:
|
||||
try:
|
||||
user = await DBUser.create(
|
||||
session,
|
||||
content.username,
|
||||
content.password,
|
||||
)
|
||||
return User.from_orm(user)
|
||||
except AssertionError as e:
|
||||
raise HTTPException(status_code=400, detail=e.args[0])
|
||||
class UserController(ModelController[User]):
|
||||
@classmethod
|
||||
async def session_start(
|
||||
cls,
|
||||
ctx: Context,
|
||||
username: str,
|
||||
password: str,
|
||||
) -> "UserController":
|
||||
obj = await ctx.repo.user.get(User.username == username)
|
||||
|
||||
if obj is None or not obj.verify_password(password):
|
||||
raise Unauthorized()
|
||||
|
||||
return cls(ctx, obj)
|
||||
|
||||
def token_ctrl(self) -> TokenController:
|
||||
return TokenController(ctx=self.ctx, entity_id=self.obj.id)
|
||||
|
|
|
|||
0
fooder/controller_old/__init__.py
Normal file
0
fooder/controller_old/__init__.py
Normal file
33
fooder/controller_old/base.py
Normal file
33
fooder/controller_old/base.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
|
||||
from ..auth import authorize_api_key, get_current_user
|
||||
from ..db import get_db_session, AsyncSession
|
||||
from ..domain.user import User
|
||||
|
||||
UserDependency = Annotated[User, Depends(get_current_user)]
|
||||
ApiKeyDependency = Annotated[None, Depends(authorize_api_key)]
|
||||
|
||||
|
||||
class BaseController:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self.session = session
|
||||
|
||||
async def call(self, *args, **kwargs) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
async def __call__(self, *args, **kwargs) -> Any:
|
||||
return await self.call(*args, **kwargs)
|
||||
|
||||
|
||||
class AuthorizedController(BaseController):
|
||||
def __init__(self, session: AsyncSession, user: UserDependency) -> None:
|
||||
super().__init__(session)
|
||||
self.user = user
|
||||
|
||||
|
||||
class TasksSessionController(BaseController):
|
||||
def __init__(self, session: AsyncSession, api_key: ApiKeyDependency) -> None:
|
||||
super().__init__(session)
|
||||
58
fooder/controller_old/token.py
Normal file
58
fooder/controller_old/token.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
from fastapi import HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from ..auth import (
|
||||
authenticate_user,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
verify_refresh_token,
|
||||
)
|
||||
from ..domain.user import User as DBUser
|
||||
from ..model.token import RefreshTokenPayload, Token
|
||||
from .base import BaseController
|
||||
|
||||
|
||||
class CreateToken(BaseController):
|
||||
async def call(self, content: OAuth2PasswordRequestForm) -> Token:
|
||||
async with self.async_session.begin() as session:
|
||||
user = await authenticate_user(session, content.username, content.password)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Invalid username or password"
|
||||
)
|
||||
|
||||
refresh_token = await create_refresh_token(session, user)
|
||||
access_token = create_access_token(user)
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token.token,
|
||||
token_type="bearer",
|
||||
)
|
||||
|
||||
|
||||
class RefreshToken(BaseController):
|
||||
async def call(self, content: RefreshTokenPayload) -> Token:
|
||||
async with self.async_session.begin() as session:
|
||||
current_token = await verify_refresh_token(session, content.refresh_token)
|
||||
|
||||
if current_token is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
||||
user = await DBUser.get(session, current_token.user_id)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
||||
assert user is not None
|
||||
await current_token.delete(session)
|
||||
|
||||
refresh_token = await create_refresh_token(session, user)
|
||||
access_token = create_access_token(user)
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token.token,
|
||||
token_type="bearer",
|
||||
)
|
||||
19
fooder/controller_old/user.py
Normal file
19
fooder/controller_old/user.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
from fastapi import HTTPException
|
||||
|
||||
from ..domain.user import User as DBUser
|
||||
from ..model.user import CreateUserPayload, User
|
||||
from .base import BaseController
|
||||
|
||||
|
||||
class CreateUser(BaseController):
|
||||
async def call(self, content: CreateUserPayload) -> User:
|
||||
async with self.async_session.begin() as session:
|
||||
try:
|
||||
user = await DBUser.create(
|
||||
session,
|
||||
content.username,
|
||||
content.password,
|
||||
)
|
||||
return User.from_orm(user)
|
||||
except AssertionError as e:
|
||||
raise HTTPException(status_code=400, detail=e.args[0])
|
||||
19
fooder/exc.py
Normal file
19
fooder/exc.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
from typing import ClassVar
|
||||
|
||||
|
||||
class ApiException(Exception):
|
||||
HTTP_CODE: ClassVar[int]
|
||||
MESSAGE: ClassVar[str]
|
||||
|
||||
def __init__(self, message: str | None = None) -> None:
|
||||
self.message = message or self.MESSAGE
|
||||
|
||||
|
||||
class NotFound(ApiException):
|
||||
HTTP_CODE = 404
|
||||
MESSAGE = "Not found"
|
||||
|
||||
|
||||
class Unauthorized(ApiException):
|
||||
HTTP_CODE = 401
|
||||
MESSAGE = "Unathorized"
|
||||
|
|
@ -1,15 +1,11 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
username: str | None = None
|
||||
|
||||
|
||||
class RefreshTokenPayload(BaseModel):
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
|
|
|||
|
|
@ -1,14 +1,10 @@
|
|||
from typing import TypeVar, Generic, Type, Sequence
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import (
|
||||
select,
|
||||
update as sa_update,
|
||||
delete as sa_delete,
|
||||
BinaryExpression,
|
||||
)
|
||||
from sqlalchemy import select, delete as sa_delete, ColumnElement
|
||||
from sqlalchemy.sql import Select
|
||||
from ..domain import Base
|
||||
|
||||
T = TypeVar("T")
|
||||
T = TypeVar("T", bound=Base)
|
||||
|
||||
|
||||
class RepositoryBase(Generic[T]):
|
||||
|
|
@ -16,7 +12,7 @@ class RepositoryBase(Generic[T]):
|
|||
self.model = model
|
||||
self.session = session
|
||||
|
||||
def _build_select(self, *expressions: BinaryExpression) -> Select[tuple[T]]:
|
||||
def _build_select(self, *expressions: ColumnElement) -> Select[tuple[T]]:
|
||||
stmt = select(self.model)
|
||||
|
||||
if expressions:
|
||||
|
|
@ -24,12 +20,12 @@ class RepositoryBase(Generic[T]):
|
|||
|
||||
return stmt
|
||||
|
||||
async def get(self, *expressions: BinaryExpression) -> T | None:
|
||||
async def get(self, *expressions: ColumnElement) -> T | None:
|
||||
stmt = self._build_select(*expressions)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list(self, *expressions: BinaryExpression) -> Sequence[T]:
|
||||
async def list(self, *expressions: ColumnElement) -> Sequence[T]:
|
||||
stmt = self._build_select(*expressions)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
|
@ -40,7 +36,7 @@ class RepositoryBase(Generic[T]):
|
|||
await self.session.refresh(obj)
|
||||
return obj
|
||||
|
||||
async def delete(self, *expressions: BinaryExpression):
|
||||
async def delete(self, *expressions: ColumnElement):
|
||||
stmt = sa_delete(self.model)
|
||||
|
||||
if expressions:
|
||||
|
|
|
|||
|
|
@ -1,18 +1,6 @@
|
|||
from fastapi import APIRouter
|
||||
|
||||
# from .view.diary import router as diary_router
|
||||
# from .view.entry import router as entry_router
|
||||
# from .view.meal import router as meal_router
|
||||
# from .view.preset import router as preset_router
|
||||
# from .view.product import router as product_router
|
||||
from .view.token import router as token_router
|
||||
# from .view.user import router as user_router
|
||||
|
||||
router = APIRouter(prefix="/api")
|
||||
# router.include_router(product_router, prefix="/product", tags=["product"])
|
||||
# router.include_router(diary_router, prefix="/diary", tags=["diary"])
|
||||
# router.include_router(meal_router, prefix="/meal", tags=["meal"])
|
||||
# router.include_router(entry_router, prefix="/entry", tags=["entry"])
|
||||
router.include_router(token_router, prefix="/token", tags=["token"])
|
||||
# router.include_router(user_router, prefix="/user", tags=["user"])
|
||||
# router.include_router(preset_router, prefix="/preset", tags=["preset"])
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from .settings import Settings
|
||||
from .view.tasks import router
|
||||
|
||||
app = FastAPI(title="Fooder Tasks admininstrative API")
|
||||
app.include_router(router)
|
||||
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=Settings().ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
0
fooder/test/controller/__init__.py
Normal file
0
fooder/test/controller/__init__.py
Normal file
6
fooder/test/controller/test_token.py
Normal file
6
fooder/test/controller/test_token.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
from fooder.controller import TokenController
|
||||
|
||||
|
||||
def test_token_ctrl_generates_token(ctx):
|
||||
token_ctrl = TokenController(ctx, 1)
|
||||
token_ctrl.generate_token_pair(ctx.clock())
|
||||
11
fooder/test/test_exc.py
Normal file
11
fooder/test/test_exc.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
from fooder.exc import ApiException
|
||||
|
||||
|
||||
class TestException(ApiException):
|
||||
HTTP_CODE = 0
|
||||
MESSAGE = "test"
|
||||
|
||||
|
||||
def test_exc_message():
|
||||
assert TestException().message == TestException.MESSAGE
|
||||
assert TestException("other message").message == "other message"
|
||||
|
|
@ -1,7 +1,9 @@
|
|||
from datetime import datetime, timedelta, timezone
|
||||
import pytest
|
||||
from jose import JWTError
|
||||
from jose import jwt
|
||||
from typing import Literal
|
||||
|
||||
from fooder.exc import Unauthorized
|
||||
from fooder.utils.jwt import AccessToken, RefreshToken, Token
|
||||
|
||||
|
||||
|
|
@ -9,6 +11,7 @@ PAST = datetime(2000, 1, 1, tzinfo=timezone.utc)
|
|||
|
||||
|
||||
class WrongKeyToken(Token):
|
||||
token_type: Literal["test-type"] = "test-type"
|
||||
secret_key = "wrong-secret"
|
||||
expire_delta = timedelta(minutes=30)
|
||||
|
||||
|
|
@ -29,15 +32,22 @@ class TestAccessToken:
|
|||
now = datetime.now(timezone.utc)
|
||||
token = WrongKeyToken(exp=WrongKeyToken.calculate_exp(now), sub=1)
|
||||
|
||||
with pytest.raises(JWTError):
|
||||
with pytest.raises(Unauthorized):
|
||||
AccessToken.decode(token.encode())
|
||||
|
||||
def test_decode_expired_raises(self):
|
||||
token = AccessToken(exp=PAST, sub=1)
|
||||
|
||||
with pytest.raises(JWTError):
|
||||
with pytest.raises(Unauthorized):
|
||||
AccessToken.decode(token.encode())
|
||||
|
||||
def test_encoded_fields(self):
|
||||
now = datetime.now(timezone.utc)
|
||||
token = AccessToken(exp=AccessToken.calculate_exp(now), sub=42)
|
||||
payload = jwt.decode(token.encode(), "", options={"verify_signature": False})
|
||||
assert "secret_key" not in payload
|
||||
assert "expire_delta" not in payload
|
||||
|
||||
|
||||
class TestRefreshToken:
|
||||
def test_encode_decode_roundtrip(self):
|
||||
|
|
@ -55,12 +65,19 @@ class TestRefreshToken:
|
|||
now = datetime.now(timezone.utc)
|
||||
token = RefreshToken(exp=RefreshToken.calculate_exp(now), sub=1)
|
||||
|
||||
with pytest.raises(JWTError):
|
||||
with pytest.raises(Unauthorized):
|
||||
AccessToken.decode(token.encode())
|
||||
|
||||
def test_access_token_not_decodable_as_refresh_token(self):
|
||||
now = datetime.now(timezone.utc)
|
||||
token = AccessToken(exp=AccessToken.calculate_exp(now), sub=1)
|
||||
|
||||
with pytest.raises(JWTError):
|
||||
with pytest.raises(Unauthorized):
|
||||
RefreshToken.decode(token.encode())
|
||||
|
||||
def test_encoded_fields(self):
|
||||
now = datetime.now(timezone.utc)
|
||||
token = AccessToken(exp=RefreshToken.calculate_exp(now), sub=42)
|
||||
payload = jwt.decode(token.encode(), "", options={"verify_signature": False})
|
||||
assert "secret_key" not in payload
|
||||
assert "expire_delta" not in payload
|
||||
|
|
|
|||
5
fooder/utils/datetime.py
Normal file
5
fooder/utils/datetime.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def utc_now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
|
@ -1,11 +1,14 @@
|
|||
from jose import jwt
|
||||
from jose import jwt, JOSEError
|
||||
from pydantic import BaseModel
|
||||
from datetime import timedelta, datetime
|
||||
from typing import ClassVar
|
||||
from typing import ClassVar, Literal
|
||||
import logging
|
||||
from ..settings import settings
|
||||
from ..exc import Unauthorized
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
token_type: str
|
||||
exp: datetime
|
||||
sub: int
|
||||
|
||||
|
|
@ -18,7 +21,13 @@ class Token(BaseModel):
|
|||
|
||||
@classmethod
|
||||
def decode(cls, jwt_token: str | bytes) -> "Token":
|
||||
data = jwt.decode(jwt_token, cls.secret_key, algorithms=[settings.ALGORITHM])
|
||||
try:
|
||||
data = jwt.decode(
|
||||
jwt_token, cls.secret_key, algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
except JOSEError as e:
|
||||
logging.error(e)
|
||||
raise Unauthorized()
|
||||
return cls(**data)
|
||||
|
||||
def encode(self) -> str:
|
||||
|
|
@ -28,10 +37,12 @@ class Token(BaseModel):
|
|||
|
||||
|
||||
class AccessToken(Token):
|
||||
token_type: Literal["access"] = "access"
|
||||
secret_key = settings.SECRET_KEY
|
||||
expire_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
|
||||
class RefreshToken(Token):
|
||||
token_type: Literal["refresh"] = "refresh"
|
||||
secret_key = settings.REFRESH_SECRET_KEY
|
||||
expire_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
|
|
|||
|
|
@ -1,39 +1,38 @@
|
|||
from typing import Annotated
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from datetime import datetime
|
||||
|
||||
from ..model.token import RefreshTokenPayload, Token
|
||||
from ..model.token import TokenResponse, RefreshTokenRequest
|
||||
from ..context import ContextDependency, Context
|
||||
from ..utils.jwt import AccessToken, RefreshToken
|
||||
from ..domain import User
|
||||
from ..controller import UserController
|
||||
|
||||
router = APIRouter(tags=["token"])
|
||||
|
||||
|
||||
@router.post("", response_model=Token)
|
||||
async def create_token(
|
||||
data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
ctx: Context = Depends(ContextDependency()),
|
||||
):
|
||||
user = await ctx.repo.user.get(User.username == data.username)
|
||||
|
||||
if user is None or not user.verify_password(data.password):
|
||||
raise HTTPException(status_code=401, detail="Unathorized")
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
access_token = AccessToken(sub=user.id, exp=AccessToken.calculate_exp(now))
|
||||
refresh_token = RefreshToken(sub=user.id, exp=RefreshToken.calculate_exp(now))
|
||||
return Token(
|
||||
def gen_token_response(user_ctrl: UserController, now: datetime) -> TokenResponse:
|
||||
token_ctrl = user_ctrl.token_ctrl()
|
||||
access_token, refresh_token = token_ctrl.generate_token_pair(now)
|
||||
return TokenResponse(
|
||||
access_token=access_token.encode(),
|
||||
refresh_token=refresh_token.encode(),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=Token)
|
||||
async def refresh_token(
|
||||
data: RefreshTokenPayload,
|
||||
@router.post("", response_model=TokenResponse)
|
||||
async def token_create(
|
||||
data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
ctx: Context = Depends(ContextDependency()),
|
||||
) -> TokenResponse:
|
||||
now = ctx.clock()
|
||||
user_ctrl = await UserController.session_start(ctx, data.username, data.password)
|
||||
return gen_token_response(user_ctrl, now)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
async def token_refresh(
|
||||
data: RefreshTokenRequest,
|
||||
ctx: Context = Depends(ContextDependency()),
|
||||
):
|
||||
pass
|
||||
|
|
|
|||
Loading…
Reference in a new issue