controllers and commands, create_token works now
This commit is contained in:
parent
4950f0dfa4
commit
74ec8aa834
31 changed files with 326 additions and 304 deletions
|
|
@ -13,9 +13,8 @@ if __name__ == "__main__":
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from .domain import Base
|
from .domain import Base
|
||||||
from .settings import Settings
|
from .settings import settings
|
||||||
|
|
||||||
settings = Settings()
|
|
||||||
engine = sqlalchemy.create_engine(
|
engine = sqlalchemy.create_engine(
|
||||||
settings.DB_URI.replace("+asyncpg", "").replace("+aiosqlite", "")
|
settings.DB_URI.replace("+asyncpg", "").replace("+aiosqlite", "")
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,11 @@
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
import logging
|
||||||
|
|
||||||
from .router import router
|
from .router import router
|
||||||
from .settings import Settings
|
from .settings import settings
|
||||||
|
from .exc import ApiException
|
||||||
|
|
||||||
app = FastAPI(title="Fooder")
|
app = FastAPI(title="Fooder")
|
||||||
app.include_router(router)
|
app.include_router(router)
|
||||||
|
|
@ -10,8 +13,17 @@ app.include_router(router)
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=Settings().ALLOWED_ORIGINS,
|
allow_origins=settings.ALLOWED_ORIGINS,
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.exception_handler(ApiException)
|
||||||
|
async def exception_handler_ErrorBase(_: Request, exc: ApiException):
|
||||||
|
headers = {"www-authenticate": "Bearer"} if exc.HTTP_CODE == 401 else {}
|
||||||
|
logging.exception(exc)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.HTTP_CODE, content={"message": exc.message}, headers=headers
|
||||||
|
)
|
||||||
|
|
|
||||||
125
fooder/auth.py
125
fooder/auth.py
|
|
@ -1,125 +0,0 @@
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Annotated
|
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException
|
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
|
||||||
from fastapi_users.password import PasswordHelper
|
|
||||||
from jose import JWTError, jwt
|
|
||||||
from passlib.context import CryptContext
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|
||||||
|
|
||||||
from .db import get_db_session
|
|
||||||
from .domain.token import RefreshToken
|
|
||||||
from .domain.user import User
|
|
||||||
from .settings import Settings
|
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token")
|
|
||||||
|
|
||||||
|
|
||||||
async def authenticate_user(
|
|
||||||
session: AsyncSession, username: str, password: str
|
|
||||||
) -> User | None:
|
|
||||||
user = await User.get_by_username(session, username)
|
|
||||||
|
|
||||||
if user is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
assert user is not None
|
|
||||||
|
|
||||||
if not verify_password(password, user.hashed_password):
|
|
||||||
return None
|
|
||||||
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
async def verify_refresh_token(
|
|
||||||
session: AsyncSession, token: str
|
|
||||||
) -> RefreshToken | None:
|
|
||||||
try:
|
|
||||||
payload = jwt.decode(
|
|
||||||
token, settings.REFRESH_SECRET_KEY, algorithms=[settings.ALGORITHM]
|
|
||||||
)
|
|
||||||
sub = payload.get("sub")
|
|
||||||
|
|
||||||
if sub is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if not isinstance(sub, str):
|
|
||||||
return None
|
|
||||||
|
|
||||||
username: str = str(sub)
|
|
||||||
|
|
||||||
if username is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
except JWTError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
user = await User.get_by_username(session, username)
|
|
||||||
|
|
||||||
if user is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
assert user is not None
|
|
||||||
|
|
||||||
current_token = await RefreshToken.get_token(session, user.id, token)
|
|
||||||
|
|
||||||
if current_token is not None:
|
|
||||||
return current_token
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def create_access_token(user: User) -> str:
|
|
||||||
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
|
||||||
to_encode = {
|
|
||||||
"sub": user.username,
|
|
||||||
"exp": expire,
|
|
||||||
}
|
|
||||||
encoded_jwt = jwt.encode(
|
|
||||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
|
||||||
)
|
|
||||||
return encoded_jwt
|
|
||||||
|
|
||||||
|
|
||||||
async def create_refresh_token(session: AsyncSession, user: User) -> RefreshToken:
|
|
||||||
expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
|
||||||
to_encode = {
|
|
||||||
"sub": user.username,
|
|
||||||
"exp": expire,
|
|
||||||
}
|
|
||||||
encoded_jwt = jwt.encode(
|
|
||||||
to_encode, settings.REFRESH_SECRET_KEY, algorithm=settings.ALGORITHM
|
|
||||||
)
|
|
||||||
return await RefreshToken.create(session, token=encoded_jwt, user_id=user.id)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(ssn: AsyncSessionDependency, token: TokenDependency) -> User:
|
|
||||||
async with ssn() as session:
|
|
||||||
try:
|
|
||||||
payload = jwt.decode(
|
|
||||||
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
|
||||||
)
|
|
||||||
sub = payload.get("sub")
|
|
||||||
|
|
||||||
if sub is None:
|
|
||||||
raise HTTPException(status_code=401, detail="Unathorized")
|
|
||||||
|
|
||||||
if not isinstance(sub, str):
|
|
||||||
raise HTTPException(status_code=401, detail="Unathorized")
|
|
||||||
|
|
||||||
username: str = str(sub)
|
|
||||||
|
|
||||||
if username is None:
|
|
||||||
raise HTTPException(status_code=401, detail="Unathorized")
|
|
||||||
|
|
||||||
except JWTError:
|
|
||||||
raise HTTPException(status_code=401, detail="Unathorized")
|
|
||||||
|
|
||||||
user = await User.get_by_username(session, username)
|
|
||||||
|
|
||||||
if user is None:
|
|
||||||
raise HTTPException(status_code=401, detail="Unathorized")
|
|
||||||
|
|
||||||
assert user is not None
|
|
||||||
return user
|
|
||||||
0
fooder/command/__init__.py
Normal file
0
fooder/command/__init__.py
Normal file
|
|
@ -1,7 +1,14 @@
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from typing import Callable
|
||||||
|
from datetime import datetime
|
||||||
from .db import get_db_session
|
from .db import get_db_session
|
||||||
|
from .domain import User
|
||||||
from .repository import Repository
|
from .repository import Repository
|
||||||
|
from .utils.datetime import utc_now
|
||||||
|
from .utils.jwt import AccessToken
|
||||||
|
from .exc import Unauthorized
|
||||||
|
|
||||||
|
|
||||||
class Context:
|
class Context:
|
||||||
|
|
@ -9,14 +16,26 @@ class Context:
|
||||||
Main API context, aggregating dependencies
|
Main API context, aggregating dependencies
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, repo: Repository) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
repo: Repository,
|
||||||
|
clock: Callable[[], datetime] = utc_now,
|
||||||
|
_user: User | None = None,
|
||||||
|
) -> None:
|
||||||
self.repo = repo
|
self.repo = repo
|
||||||
|
self.clock = clock
|
||||||
|
self._user = _user
|
||||||
|
|
||||||
|
@property
|
||||||
|
def user(self) -> User:
|
||||||
|
if self._user is None:
|
||||||
|
raise Unauthorized()
|
||||||
|
return self._user
|
||||||
|
|
||||||
|
|
||||||
class ContextDependency:
|
class ContextDependency:
|
||||||
"""
|
"""
|
||||||
Configurable context dependecy. Allows for shared interface configuring
|
Context dependecy
|
||||||
method required dependencies
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -29,3 +48,21 @@ class ContextDependency:
|
||||||
session: AsyncSession = Depends(get_db_session),
|
session: AsyncSession = Depends(get_db_session),
|
||||||
):
|
):
|
||||||
return Context(repo=Repository(session))
|
return Context(repo=Repository(session))
|
||||||
|
|
||||||
|
|
||||||
|
class AuthContextDependency:
|
||||||
|
"""
|
||||||
|
Context dependecy for authorized endpionts
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def __call__(
|
||||||
|
self,
|
||||||
|
token: str = Depends(OAuth2PasswordBearer(tokenUrl="/token")),
|
||||||
|
session: AsyncSession = Depends(get_db_session),
|
||||||
|
) -> Context:
|
||||||
|
access_token = AccessToken.decode(token)
|
||||||
|
repo = Repository(session)
|
||||||
|
user = await repo.user.get(User.id == access_token.sub)
|
||||||
|
if user is None:
|
||||||
|
raise Unauthorized()
|
||||||
|
return Context(repo=repo, _user=user)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,2 @@
|
||||||
|
from .user import UserController
|
||||||
|
from .token import TokenController
|
||||||
|
|
@ -1,33 +1,17 @@
|
||||||
from typing import Annotated, Any
|
from ..context import Context
|
||||||
|
from typing import TypeVar, Generic
|
||||||
from fastapi import Depends
|
from sqlalchemy import BinaryExpression
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
|
||||||
|
|
||||||
from ..auth import authorize_api_key, get_current_user
|
|
||||||
from ..db import get_db_session, AsyncSession
|
|
||||||
from ..domain.user import User
|
|
||||||
|
|
||||||
UserDependency = Annotated[User, Depends(get_current_user)]
|
|
||||||
ApiKeyDependency = Annotated[None, Depends(authorize_api_key)]
|
|
||||||
|
|
||||||
|
|
||||||
class BaseController:
|
T = TypeVar("T")
|
||||||
def __init__(self, session: AsyncSession) -> None:
|
|
||||||
self.session = session
|
|
||||||
|
|
||||||
async def call(self, *args, **kwargs) -> Any:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def __call__(self, *args, **kwargs) -> Any:
|
|
||||||
return await self.call(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class AuthorizedController(BaseController):
|
class ControllerBase:
|
||||||
def __init__(self, session: AsyncSession, user: UserDependency) -> None:
|
def __init__(self, ctx: Context) -> None:
|
||||||
super().__init__(session)
|
self.ctx = ctx
|
||||||
self.user = user
|
|
||||||
|
|
||||||
|
|
||||||
class TasksSessionController(BaseController):
|
class ModelController(Generic[T], ControllerBase):
|
||||||
def __init__(self, session: AsyncSession, api_key: ApiKeyDependency) -> None:
|
def __init__(self, ctx: Context, obj: T):
|
||||||
super().__init__(session)
|
super().__init__(ctx)
|
||||||
|
self.obj = obj
|
||||||
|
|
|
||||||
|
|
@ -1,58 +1,25 @@
|
||||||
from fastapi import HTTPException
|
from .base import ControllerBase
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from ..context import Context
|
||||||
|
from ..utils.jwt import Token, AccessToken, RefreshToken
|
||||||
|
from typing import Type, TypeVar
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from ..auth import (
|
T = TypeVar("T", bound=Token)
|
||||||
authenticate_user,
|
|
||||||
create_access_token,
|
|
||||||
create_refresh_token,
|
|
||||||
verify_refresh_token,
|
|
||||||
)
|
|
||||||
from ..domain.user import User as DBUser
|
|
||||||
from ..model.token import RefreshTokenPayload, Token
|
|
||||||
from .base import BaseController
|
|
||||||
|
|
||||||
|
|
||||||
class CreateToken(BaseController):
|
class TokenController(ControllerBase):
|
||||||
async def call(self, content: OAuth2PasswordRequestForm) -> Token:
|
def __init__(self, ctx: Context, entity_id: int) -> None:
|
||||||
async with self.async_session.begin() as session:
|
super().__init__(ctx)
|
||||||
user = await authenticate_user(session, content.username, content.password)
|
self.entity_id = entity_id
|
||||||
|
|
||||||
if user is None:
|
def generate_token(self, token_cls: Type[T], now: datetime) -> T:
|
||||||
raise HTTPException(
|
return token_cls(exp=token_cls.calculate_exp(now), sub=self.entity_id)
|
||||||
status_code=401, detail="Invalid username or password"
|
|
||||||
)
|
|
||||||
|
|
||||||
refresh_token = await create_refresh_token(session, user)
|
def generate_refresh_token(self, now: datetime) -> RefreshToken:
|
||||||
access_token = create_access_token(user)
|
return self.generate_token(RefreshToken, now)
|
||||||
|
|
||||||
return Token(
|
def generate_access_token(self, now: datetime) -> AccessToken:
|
||||||
access_token=access_token,
|
return self.generate_token(AccessToken, now)
|
||||||
refresh_token=refresh_token.token,
|
|
||||||
token_type="bearer",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
def generate_token_pair(self, now: datetime) -> tuple[AccessToken, RefreshToken]:
|
||||||
class RefreshToken(BaseController):
|
return (self.generate_access_token(now), self.generate_refresh_token(now))
|
||||||
async def call(self, content: RefreshTokenPayload) -> Token:
|
|
||||||
async with self.async_session.begin() as session:
|
|
||||||
current_token = await verify_refresh_token(session, content.refresh_token)
|
|
||||||
|
|
||||||
if current_token is None:
|
|
||||||
raise HTTPException(status_code=401, detail="Invalid token")
|
|
||||||
|
|
||||||
user = await DBUser.get(session, current_token.user_id)
|
|
||||||
|
|
||||||
if user is None:
|
|
||||||
raise HTTPException(status_code=401, detail="Invalid token")
|
|
||||||
|
|
||||||
assert user is not None
|
|
||||||
await current_token.delete(session)
|
|
||||||
|
|
||||||
refresh_token = await create_refresh_token(session, user)
|
|
||||||
access_token = create_access_token(user)
|
|
||||||
|
|
||||||
return Token(
|
|
||||||
access_token=access_token,
|
|
||||||
refresh_token=refresh_token.token,
|
|
||||||
token_type="bearer",
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,24 @@
|
||||||
from fastapi import HTTPException
|
from .base import ModelController
|
||||||
|
from ..domain import User
|
||||||
from ..domain.user import User as DBUser
|
from ..context import Context
|
||||||
from ..model.user import CreateUserPayload, User
|
from ..exc import Unauthorized
|
||||||
from .base import BaseController
|
from .token import TokenController
|
||||||
|
|
||||||
|
|
||||||
class CreateUser(BaseController):
|
class UserController(ModelController[User]):
|
||||||
async def call(self, content: CreateUserPayload) -> User:
|
@classmethod
|
||||||
async with self.async_session.begin() as session:
|
async def session_start(
|
||||||
try:
|
cls,
|
||||||
user = await DBUser.create(
|
ctx: Context,
|
||||||
session,
|
username: str,
|
||||||
content.username,
|
password: str,
|
||||||
content.password,
|
) -> "UserController":
|
||||||
)
|
obj = await ctx.repo.user.get(User.username == username)
|
||||||
return User.from_orm(user)
|
|
||||||
except AssertionError as e:
|
if obj is None or not obj.verify_password(password):
|
||||||
raise HTTPException(status_code=400, detail=e.args[0])
|
raise Unauthorized()
|
||||||
|
|
||||||
|
return cls(ctx, obj)
|
||||||
|
|
||||||
|
def token_ctrl(self) -> TokenController:
|
||||||
|
return TokenController(ctx=self.ctx, entity_id=self.obj.id)
|
||||||
|
|
|
||||||
0
fooder/controller_old/__init__.py
Normal file
0
fooder/controller_old/__init__.py
Normal file
33
fooder/controller_old/base.py
Normal file
33
fooder/controller_old/base.py
Normal file
|
|
@ -0,0 +1,33 @@
|
||||||
|
from typing import Annotated, Any
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||||
|
|
||||||
|
from ..auth import authorize_api_key, get_current_user
|
||||||
|
from ..db import get_db_session, AsyncSession
|
||||||
|
from ..domain.user import User
|
||||||
|
|
||||||
|
UserDependency = Annotated[User, Depends(get_current_user)]
|
||||||
|
ApiKeyDependency = Annotated[None, Depends(authorize_api_key)]
|
||||||
|
|
||||||
|
|
||||||
|
class BaseController:
|
||||||
|
def __init__(self, session: AsyncSession) -> None:
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
async def call(self, *args, **kwargs) -> Any:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def __call__(self, *args, **kwargs) -> Any:
|
||||||
|
return await self.call(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthorizedController(BaseController):
|
||||||
|
def __init__(self, session: AsyncSession, user: UserDependency) -> None:
|
||||||
|
super().__init__(session)
|
||||||
|
self.user = user
|
||||||
|
|
||||||
|
|
||||||
|
class TasksSessionController(BaseController):
|
||||||
|
def __init__(self, session: AsyncSession, api_key: ApiKeyDependency) -> None:
|
||||||
|
super().__init__(session)
|
||||||
58
fooder/controller_old/token.py
Normal file
58
fooder/controller_old/token.py
Normal file
|
|
@ -0,0 +1,58 @@
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
|
||||||
|
from ..auth import (
|
||||||
|
authenticate_user,
|
||||||
|
create_access_token,
|
||||||
|
create_refresh_token,
|
||||||
|
verify_refresh_token,
|
||||||
|
)
|
||||||
|
from ..domain.user import User as DBUser
|
||||||
|
from ..model.token import RefreshTokenPayload, Token
|
||||||
|
from .base import BaseController
|
||||||
|
|
||||||
|
|
||||||
|
class CreateToken(BaseController):
|
||||||
|
async def call(self, content: OAuth2PasswordRequestForm) -> Token:
|
||||||
|
async with self.async_session.begin() as session:
|
||||||
|
user = await authenticate_user(session, content.username, content.password)
|
||||||
|
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401, detail="Invalid username or password"
|
||||||
|
)
|
||||||
|
|
||||||
|
refresh_token = await create_refresh_token(session, user)
|
||||||
|
access_token = create_access_token(user)
|
||||||
|
|
||||||
|
return Token(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=refresh_token.token,
|
||||||
|
token_type="bearer",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshToken(BaseController):
|
||||||
|
async def call(self, content: RefreshTokenPayload) -> Token:
|
||||||
|
async with self.async_session.begin() as session:
|
||||||
|
current_token = await verify_refresh_token(session, content.refresh_token)
|
||||||
|
|
||||||
|
if current_token is None:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid token")
|
||||||
|
|
||||||
|
user = await DBUser.get(session, current_token.user_id)
|
||||||
|
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid token")
|
||||||
|
|
||||||
|
assert user is not None
|
||||||
|
await current_token.delete(session)
|
||||||
|
|
||||||
|
refresh_token = await create_refresh_token(session, user)
|
||||||
|
access_token = create_access_token(user)
|
||||||
|
|
||||||
|
return Token(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=refresh_token.token,
|
||||||
|
token_type="bearer",
|
||||||
|
)
|
||||||
19
fooder/controller_old/user.py
Normal file
19
fooder/controller_old/user.py
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from ..domain.user import User as DBUser
|
||||||
|
from ..model.user import CreateUserPayload, User
|
||||||
|
from .base import BaseController
|
||||||
|
|
||||||
|
|
||||||
|
class CreateUser(BaseController):
|
||||||
|
async def call(self, content: CreateUserPayload) -> User:
|
||||||
|
async with self.async_session.begin() as session:
|
||||||
|
try:
|
||||||
|
user = await DBUser.create(
|
||||||
|
session,
|
||||||
|
content.username,
|
||||||
|
content.password,
|
||||||
|
)
|
||||||
|
return User.from_orm(user)
|
||||||
|
except AssertionError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=e.args[0])
|
||||||
19
fooder/exc.py
Normal file
19
fooder/exc.py
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
|
|
||||||
|
class ApiException(Exception):
|
||||||
|
HTTP_CODE: ClassVar[int]
|
||||||
|
MESSAGE: ClassVar[str]
|
||||||
|
|
||||||
|
def __init__(self, message: str | None = None) -> None:
|
||||||
|
self.message = message or self.MESSAGE
|
||||||
|
|
||||||
|
|
||||||
|
class NotFound(ApiException):
|
||||||
|
HTTP_CODE = 404
|
||||||
|
MESSAGE = "Not found"
|
||||||
|
|
||||||
|
|
||||||
|
class Unauthorized(ApiException):
|
||||||
|
HTTP_CODE = 401
|
||||||
|
MESSAGE = "Unathorized"
|
||||||
|
|
@ -1,15 +1,11 @@
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class Token(BaseModel):
|
class TokenResponse(BaseModel):
|
||||||
access_token: str
|
access_token: str
|
||||||
refresh_token: str
|
refresh_token: str
|
||||||
token_type: str = "bearer"
|
token_type: str = "bearer"
|
||||||
|
|
||||||
|
|
||||||
class TokenData(BaseModel):
|
class RefreshTokenRequest(BaseModel):
|
||||||
username: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class RefreshTokenPayload(BaseModel):
|
|
||||||
refresh_token: str
|
refresh_token: str
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,10 @@
|
||||||
from typing import TypeVar, Generic, Type, Sequence
|
from typing import TypeVar, Generic, Type, Sequence
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy import (
|
from sqlalchemy import select, delete as sa_delete, ColumnElement
|
||||||
select,
|
|
||||||
update as sa_update,
|
|
||||||
delete as sa_delete,
|
|
||||||
BinaryExpression,
|
|
||||||
)
|
|
||||||
from sqlalchemy.sql import Select
|
from sqlalchemy.sql import Select
|
||||||
|
from ..domain import Base
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T", bound=Base)
|
||||||
|
|
||||||
|
|
||||||
class RepositoryBase(Generic[T]):
|
class RepositoryBase(Generic[T]):
|
||||||
|
|
@ -16,7 +12,7 @@ class RepositoryBase(Generic[T]):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.session = session
|
self.session = session
|
||||||
|
|
||||||
def _build_select(self, *expressions: BinaryExpression) -> Select[tuple[T]]:
|
def _build_select(self, *expressions: ColumnElement) -> Select[tuple[T]]:
|
||||||
stmt = select(self.model)
|
stmt = select(self.model)
|
||||||
|
|
||||||
if expressions:
|
if expressions:
|
||||||
|
|
@ -24,12 +20,12 @@ class RepositoryBase(Generic[T]):
|
||||||
|
|
||||||
return stmt
|
return stmt
|
||||||
|
|
||||||
async def get(self, *expressions: BinaryExpression) -> T | None:
|
async def get(self, *expressions: ColumnElement) -> T | None:
|
||||||
stmt = self._build_select(*expressions)
|
stmt = self._build_select(*expressions)
|
||||||
result = await self.session.execute(stmt)
|
result = await self.session.execute(stmt)
|
||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
async def list(self, *expressions: BinaryExpression) -> Sequence[T]:
|
async def list(self, *expressions: ColumnElement) -> Sequence[T]:
|
||||||
stmt = self._build_select(*expressions)
|
stmt = self._build_select(*expressions)
|
||||||
result = await self.session.execute(stmt)
|
result = await self.session.execute(stmt)
|
||||||
return result.scalars().all()
|
return result.scalars().all()
|
||||||
|
|
@ -40,7 +36,7 @@ class RepositoryBase(Generic[T]):
|
||||||
await self.session.refresh(obj)
|
await self.session.refresh(obj)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
async def delete(self, *expressions: BinaryExpression):
|
async def delete(self, *expressions: ColumnElement):
|
||||||
stmt = sa_delete(self.model)
|
stmt = sa_delete(self.model)
|
||||||
|
|
||||||
if expressions:
|
if expressions:
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,6 @@
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
# from .view.diary import router as diary_router
|
|
||||||
# from .view.entry import router as entry_router
|
|
||||||
# from .view.meal import router as meal_router
|
|
||||||
# from .view.preset import router as preset_router
|
|
||||||
# from .view.product import router as product_router
|
|
||||||
from .view.token import router as token_router
|
from .view.token import router as token_router
|
||||||
# from .view.user import router as user_router
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api")
|
router = APIRouter(prefix="/api")
|
||||||
# router.include_router(product_router, prefix="/product", tags=["product"])
|
|
||||||
# router.include_router(diary_router, prefix="/diary", tags=["diary"])
|
|
||||||
# router.include_router(meal_router, prefix="/meal", tags=["meal"])
|
|
||||||
# router.include_router(entry_router, prefix="/entry", tags=["entry"])
|
|
||||||
router.include_router(token_router, prefix="/token", tags=["token"])
|
router.include_router(token_router, prefix="/token", tags=["token"])
|
||||||
# router.include_router(user_router, prefix="/user", tags=["user"])
|
|
||||||
# router.include_router(preset_router, prefix="/preset", tags=["preset"])
|
|
||||||
|
|
|
||||||
|
|
@ -1,17 +0,0 @@
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
|
|
||||||
from .settings import Settings
|
|
||||||
from .view.tasks import router
|
|
||||||
|
|
||||||
app = FastAPI(title="Fooder Tasks admininstrative API")
|
|
||||||
app.include_router(router)
|
|
||||||
|
|
||||||
|
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=Settings().ALLOWED_ORIGINS,
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
)
|
|
||||||
0
fooder/test/controller/__init__.py
Normal file
0
fooder/test/controller/__init__.py
Normal file
6
fooder/test/controller/test_token.py
Normal file
6
fooder/test/controller/test_token.py
Normal file
|
|
@ -0,0 +1,6 @@
|
||||||
|
from fooder.controller import TokenController
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_ctrl_generates_token(ctx):
|
||||||
|
token_ctrl = TokenController(ctx, 1)
|
||||||
|
token_ctrl.generate_token_pair(ctx.clock())
|
||||||
11
fooder/test/test_exc.py
Normal file
11
fooder/test/test_exc.py
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
from fooder.exc import ApiException
|
||||||
|
|
||||||
|
|
||||||
|
class TestException(ApiException):
|
||||||
|
HTTP_CODE = 0
|
||||||
|
MESSAGE = "test"
|
||||||
|
|
||||||
|
|
||||||
|
def test_exc_message():
|
||||||
|
assert TestException().message == TestException.MESSAGE
|
||||||
|
assert TestException("other message").message == "other message"
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
import pytest
|
import pytest
|
||||||
from jose import JWTError
|
from jose import jwt
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from fooder.exc import Unauthorized
|
||||||
from fooder.utils.jwt import AccessToken, RefreshToken, Token
|
from fooder.utils.jwt import AccessToken, RefreshToken, Token
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -9,6 +11,7 @@ PAST = datetime(2000, 1, 1, tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
class WrongKeyToken(Token):
|
class WrongKeyToken(Token):
|
||||||
|
token_type: Literal["test-type"] = "test-type"
|
||||||
secret_key = "wrong-secret"
|
secret_key = "wrong-secret"
|
||||||
expire_delta = timedelta(minutes=30)
|
expire_delta = timedelta(minutes=30)
|
||||||
|
|
||||||
|
|
@ -29,15 +32,22 @@ class TestAccessToken:
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
token = WrongKeyToken(exp=WrongKeyToken.calculate_exp(now), sub=1)
|
token = WrongKeyToken(exp=WrongKeyToken.calculate_exp(now), sub=1)
|
||||||
|
|
||||||
with pytest.raises(JWTError):
|
with pytest.raises(Unauthorized):
|
||||||
AccessToken.decode(token.encode())
|
AccessToken.decode(token.encode())
|
||||||
|
|
||||||
def test_decode_expired_raises(self):
|
def test_decode_expired_raises(self):
|
||||||
token = AccessToken(exp=PAST, sub=1)
|
token = AccessToken(exp=PAST, sub=1)
|
||||||
|
|
||||||
with pytest.raises(JWTError):
|
with pytest.raises(Unauthorized):
|
||||||
AccessToken.decode(token.encode())
|
AccessToken.decode(token.encode())
|
||||||
|
|
||||||
|
def test_encoded_fields(self):
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
token = AccessToken(exp=AccessToken.calculate_exp(now), sub=42)
|
||||||
|
payload = jwt.decode(token.encode(), "", options={"verify_signature": False})
|
||||||
|
assert "secret_key" not in payload
|
||||||
|
assert "expire_delta" not in payload
|
||||||
|
|
||||||
|
|
||||||
class TestRefreshToken:
|
class TestRefreshToken:
|
||||||
def test_encode_decode_roundtrip(self):
|
def test_encode_decode_roundtrip(self):
|
||||||
|
|
@ -55,12 +65,19 @@ class TestRefreshToken:
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
token = RefreshToken(exp=RefreshToken.calculate_exp(now), sub=1)
|
token = RefreshToken(exp=RefreshToken.calculate_exp(now), sub=1)
|
||||||
|
|
||||||
with pytest.raises(JWTError):
|
with pytest.raises(Unauthorized):
|
||||||
AccessToken.decode(token.encode())
|
AccessToken.decode(token.encode())
|
||||||
|
|
||||||
def test_access_token_not_decodable_as_refresh_token(self):
|
def test_access_token_not_decodable_as_refresh_token(self):
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
token = AccessToken(exp=AccessToken.calculate_exp(now), sub=1)
|
token = AccessToken(exp=AccessToken.calculate_exp(now), sub=1)
|
||||||
|
|
||||||
with pytest.raises(JWTError):
|
with pytest.raises(Unauthorized):
|
||||||
RefreshToken.decode(token.encode())
|
RefreshToken.decode(token.encode())
|
||||||
|
|
||||||
|
def test_encoded_fields(self):
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
token = AccessToken(exp=RefreshToken.calculate_exp(now), sub=42)
|
||||||
|
payload = jwt.decode(token.encode(), "", options={"verify_signature": False})
|
||||||
|
assert "secret_key" not in payload
|
||||||
|
assert "expire_delta" not in payload
|
||||||
|
|
|
||||||
5
fooder/utils/datetime.py
Normal file
5
fooder/utils/datetime.py
Normal file
|
|
@ -0,0 +1,5 @@
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
|
||||||
|
def utc_now() -> datetime:
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
@ -1,11 +1,14 @@
|
||||||
from jose import jwt
|
from jose import jwt, JOSEError
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from datetime import timedelta, datetime
|
from datetime import timedelta, datetime
|
||||||
from typing import ClassVar
|
from typing import ClassVar, Literal
|
||||||
|
import logging
|
||||||
from ..settings import settings
|
from ..settings import settings
|
||||||
|
from ..exc import Unauthorized
|
||||||
|
|
||||||
|
|
||||||
class Token(BaseModel):
|
class Token(BaseModel):
|
||||||
|
token_type: str
|
||||||
exp: datetime
|
exp: datetime
|
||||||
sub: int
|
sub: int
|
||||||
|
|
||||||
|
|
@ -18,7 +21,13 @@ class Token(BaseModel):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def decode(cls, jwt_token: str | bytes) -> "Token":
|
def decode(cls, jwt_token: str | bytes) -> "Token":
|
||||||
data = jwt.decode(jwt_token, cls.secret_key, algorithms=[settings.ALGORITHM])
|
try:
|
||||||
|
data = jwt.decode(
|
||||||
|
jwt_token, cls.secret_key, algorithms=[settings.ALGORITHM]
|
||||||
|
)
|
||||||
|
except JOSEError as e:
|
||||||
|
logging.error(e)
|
||||||
|
raise Unauthorized()
|
||||||
return cls(**data)
|
return cls(**data)
|
||||||
|
|
||||||
def encode(self) -> str:
|
def encode(self) -> str:
|
||||||
|
|
@ -28,10 +37,12 @@ class Token(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class AccessToken(Token):
|
class AccessToken(Token):
|
||||||
|
token_type: Literal["access"] = "access"
|
||||||
secret_key = settings.SECRET_KEY
|
secret_key = settings.SECRET_KEY
|
||||||
expire_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
expire_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||||
|
|
||||||
|
|
||||||
class RefreshToken(Token):
|
class RefreshToken(Token):
|
||||||
|
token_type: Literal["refresh"] = "refresh"
|
||||||
secret_key = settings.REFRESH_SECRET_KEY
|
secret_key = settings.REFRESH_SECRET_KEY
|
||||||
expire_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
expire_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||||
|
|
|
||||||
|
|
@ -1,39 +1,38 @@
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
from fastapi import APIRouter, Depends
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from ..model.token import RefreshTokenPayload, Token
|
from ..model.token import TokenResponse, RefreshTokenRequest
|
||||||
from ..context import ContextDependency, Context
|
from ..context import ContextDependency, Context
|
||||||
from ..utils.jwt import AccessToken, RefreshToken
|
from ..controller import UserController
|
||||||
from ..domain import User
|
|
||||||
|
|
||||||
router = APIRouter(tags=["token"])
|
router = APIRouter(tags=["token"])
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=Token)
|
def gen_token_response(user_ctrl: UserController, now: datetime) -> TokenResponse:
|
||||||
async def create_token(
|
token_ctrl = user_ctrl.token_ctrl()
|
||||||
data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
access_token, refresh_token = token_ctrl.generate_token_pair(now)
|
||||||
ctx: Context = Depends(ContextDependency()),
|
return TokenResponse(
|
||||||
):
|
|
||||||
user = await ctx.repo.user.get(User.username == data.username)
|
|
||||||
|
|
||||||
if user is None or not user.verify_password(data.password):
|
|
||||||
raise HTTPException(status_code=401, detail="Unathorized")
|
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
access_token = AccessToken(sub=user.id, exp=AccessToken.calculate_exp(now))
|
|
||||||
refresh_token = RefreshToken(sub=user.id, exp=RefreshToken.calculate_exp(now))
|
|
||||||
return Token(
|
|
||||||
access_token=access_token.encode(),
|
access_token=access_token.encode(),
|
||||||
refresh_token=refresh_token.encode(),
|
refresh_token=refresh_token.encode(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/refresh", response_model=Token)
|
@router.post("", response_model=TokenResponse)
|
||||||
async def refresh_token(
|
async def token_create(
|
||||||
data: RefreshTokenPayload,
|
data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||||
|
ctx: Context = Depends(ContextDependency()),
|
||||||
|
) -> TokenResponse:
|
||||||
|
now = ctx.clock()
|
||||||
|
user_ctrl = await UserController.session_start(ctx, data.username, data.password)
|
||||||
|
return gen_token_response(user_ctrl, now)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refresh", response_model=TokenResponse)
|
||||||
|
async def token_refresh(
|
||||||
|
data: RefreshTokenRequest,
|
||||||
ctx: Context = Depends(ContextDependency()),
|
ctx: Context = Depends(ContextDependency()),
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue