begin tedious work
This commit is contained in:
parent
9792b0feb3
commit
bbbd124d78
29 changed files with 340 additions and 578 deletions
|
|
@ -17,7 +17,7 @@ FROM python:3.11.5-bullseye
|
||||||
|
|
||||||
RUN apt-get -y install libpq-dev
|
RUN apt-get -y install libpq-dev
|
||||||
|
|
||||||
COPY requirements.txt requirements.txt
|
COPY requirements/docker.txt requirements.txt
|
||||||
RUN pip install -r requirements.txt
|
RUN pip install -r requirements.txt
|
||||||
|
|
||||||
RUN useradd fooder
|
RUN useradd fooder
|
||||||
|
|
|
||||||
2
Makefile
2
Makefile
|
|
@ -40,7 +40,7 @@ version:
|
||||||
.PHONY: create-venv
|
.PHONY: create-venv
|
||||||
create-venv:
|
create-venv:
|
||||||
python3 -m venv .venv --prompt="fooderapi-venv" --system-site-packages
|
python3 -m venv .venv --prompt="fooderapi-venv" --system-site-packages
|
||||||
bash -c "source .venv/bin/activate && pip install -r requirements_local.txt"
|
bash -c "source .venv/bin/activate && pip install -r requirements/local.txt"
|
||||||
|
|
||||||
.PHONY: test
|
.PHONY: test
|
||||||
test:
|
test:
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from jose import JWTError, jwt
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
from .db import get_session
|
from .db import get_db_session
|
||||||
from .domain.token import RefreshToken
|
from .domain.token import RefreshToken
|
||||||
from .domain.user import User
|
from .domain.user import User
|
||||||
from .settings import Settings
|
from .settings import Settings
|
||||||
|
|
@ -18,7 +18,7 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token")
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
password_helper = PasswordHelper(pwd_context) # type: ignore
|
password_helper = PasswordHelper(pwd_context) # type: ignore
|
||||||
|
|
||||||
AsyncSessionDependency = Annotated[async_sessionmaker, Depends(get_session)]
|
AsyncSessionDependency = Annotated[async_sessionmaker, Depends(get_db_session)]
|
||||||
TokenDependency = Annotated[str, Depends(oauth2_scheme)]
|
TokenDependency = Annotated[str, Depends(oauth2_scheme)]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
30
fooder/context.py
Normal file
30
fooder/context.py
Normal file
|
|
@ -0,0 +1,30 @@
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from fastapi import Depends
|
||||||
|
from fooder.db import get_db_session
|
||||||
|
|
||||||
|
|
||||||
|
class Context:
|
||||||
|
"""
|
||||||
|
Main API context, aggregating dependencies
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dbssn: AsyncSession) -> None:
|
||||||
|
self.dbssn = dbssn
|
||||||
|
|
||||||
|
|
||||||
|
class ContextDependency:
|
||||||
|
"""
|
||||||
|
Configurable context dependecy. Allows for shared interface configuring
|
||||||
|
method required dependencies
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
dbssn: AsyncSession = Depends(get_db_session),
|
||||||
|
):
|
||||||
|
return Context(dbssn=dbssn)
|
||||||
|
|
@ -4,17 +4,16 @@ from fastapi import Depends
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||||
|
|
||||||
from ..auth import authorize_api_key, get_current_user
|
from ..auth import authorize_api_key, get_current_user
|
||||||
from ..db import get_session
|
from ..db import get_db_session, AsyncSession
|
||||||
from ..domain.user import User
|
from ..domain.user import User
|
||||||
|
|
||||||
AsyncSession = Annotated[async_sessionmaker, Depends(get_session)]
|
|
||||||
UserDependency = Annotated[User, Depends(get_current_user)]
|
UserDependency = Annotated[User, Depends(get_current_user)]
|
||||||
ApiKeyDependency = Annotated[None, Depends(authorize_api_key)]
|
ApiKeyDependency = Annotated[None, Depends(authorize_api_key)]
|
||||||
|
|
||||||
|
|
||||||
class BaseController:
|
class BaseController:
|
||||||
def __init__(self, session: AsyncSession) -> None:
|
def __init__(self, session: AsyncSession) -> None:
|
||||||
self.async_session = session
|
self.session = session
|
||||||
|
|
||||||
async def call(self, *args, **kwargs) -> Any:
|
async def call(self, *args, **kwargs) -> Any:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,7 @@ from .base import AuthorizedController
|
||||||
|
|
||||||
class GetDiary(AuthorizedController):
|
class GetDiary(AuthorizedController):
|
||||||
async def call(self, date: date) -> Diary:
|
async def call(self, date: date) -> Diary:
|
||||||
async with self.async_session() as session:
|
diary = await DBDiary.get_diary(self.session, self.user.id, date)
|
||||||
diary = await DBDiary.get_diary(session, self.user.id, date)
|
|
||||||
|
|
||||||
if diary is not None:
|
if diary is not None:
|
||||||
return Diary.from_orm(diary)
|
return Diary.from_orm(diary)
|
||||||
|
|
|
||||||
85
fooder/db.py
85
fooder/db.py
|
|
@ -1,38 +1,69 @@
|
||||||
import logging
|
import contextlib
|
||||||
from typing import AsyncIterator
|
from typing import AsyncIterator, AsyncGenerator
|
||||||
|
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from fooder.settings import Settings, settings
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import (
|
||||||
|
AsyncConnection,
|
||||||
|
AsyncSession,
|
||||||
|
async_sessionmaker,
|
||||||
|
create_async_engine,
|
||||||
|
)
|
||||||
|
|
||||||
from .settings import Settings
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
class DatabaseSessionManager:
|
||||||
settings = Settings.parse_obj({})
|
def __init__(self, settings: Settings) -> None:
|
||||||
|
self._engine = create_async_engine(
|
||||||
if settings.DB_URI.startswith("sqlite"):
|
|
||||||
settings.DB_URI = settings.DB_URI + "?check_same_thread=False"
|
|
||||||
|
|
||||||
"""
|
|
||||||
Asynchronous PostgreSQL database engine.
|
|
||||||
"""
|
|
||||||
async_engine = create_async_engine(
|
|
||||||
settings.DB_URI,
|
settings.DB_URI,
|
||||||
pool_pre_ping=True,
|
pool_pre_ping=True,
|
||||||
echo=settings.ECHO_SQL,
|
echo=settings.ECHO_SQL,
|
||||||
connect_args=(
|
connect_args=(
|
||||||
{"check_same_thread": False} if settings.DB_URI.startswith("sqlite") else {}
|
{"check_same_thread": False}
|
||||||
|
if settings.DB_URI.startswith("sqlite")
|
||||||
|
else {}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
AsyncSessionLocal = async_sessionmaker(
|
self._sessionmaker = async_sessionmaker(
|
||||||
bind=async_engine,
|
autocommit=False, autoflush=False, future=True, bind=self._engine
|
||||||
autocommit=False,
|
)
|
||||||
autoflush=False,
|
|
||||||
future=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
if self._engine is None:
|
||||||
|
raise Exception("DatabaseSessionManager is not initialized")
|
||||||
|
await self._engine.dispose()
|
||||||
|
|
||||||
async def get_session() -> AsyncIterator[async_sessionmaker]:
|
self._engine = None
|
||||||
|
self._sessionmaker = None
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def connect(self) -> AsyncIterator[AsyncConnection]:
|
||||||
|
if self._engine is None:
|
||||||
|
raise Exception("DatabaseSessionManager is not initialized")
|
||||||
|
|
||||||
|
async with self._engine.begin() as connection:
|
||||||
try:
|
try:
|
||||||
yield AsyncSessionLocal
|
yield connection
|
||||||
except SQLAlchemyError as e:
|
except Exception:
|
||||||
log.exception(e)
|
await connection.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def session(self) -> AsyncIterator[AsyncSession]:
|
||||||
|
if self._sessionmaker is None:
|
||||||
|
raise Exception("DatabaseSessionManager is not initialized")
|
||||||
|
|
||||||
|
session = self._sessionmaker()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
|
||||||
|
session_manager = DatabaseSessionManager(settings)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
async with session_manager.session() as session:
|
||||||
|
yield session
|
||||||
|
|
|
||||||
0
fooder/repository/__init__.py
Normal file
0
fooder/repository/__init__.py
Normal file
6
fooder/repository/base.py
Normal file
6
fooder/repository/base.py
Normal file
|
|
@ -0,0 +1,6 @@
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
|
class RepositoryBase:
|
||||||
|
def __init__(self, dbssn: AsyncSession):
|
||||||
|
self.dbssn = dbssn
|
||||||
2
fooder/repository/user.py
Normal file
2
fooder/repository/user.py
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
class UserRepository:
|
||||||
|
pass
|
||||||
|
|
@ -13,8 +13,11 @@ class Settings(BaseSettings):
|
||||||
REFRESH_SECRET_KEY: str
|
REFRESH_SECRET_KEY: str
|
||||||
ALGORITHM: str = "HS256"
|
ALGORITHM: str = "HS256"
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 30
|
REFRESH_TOKEN_EXPIRE_DAYS: int = 120
|
||||||
|
|
||||||
ALLOWED_ORIGINS: List[str] = ["*"]
|
ALLOWED_ORIGINS: List[str] = ["*"]
|
||||||
|
|
||||||
API_KEY: str
|
API_KEY: str
|
||||||
|
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
|
|
|
||||||
|
|
@ -1 +1,34 @@
|
||||||
from .fixtures import * # noqa
|
import os
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Supply minimal dummy env-vars *before* any of our modules are imported. #
|
||||||
|
# This lets the global `settings = Settings()` call succeed. #
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
os.environ.update(
|
||||||
|
{
|
||||||
|
"DB_URI": "sqlite+aiosqlite:///:memory:",
|
||||||
|
"ECHO_SQL": "false",
|
||||||
|
"SECRET_KEY": "test-secret",
|
||||||
|
"REFRESH_SECRET_KEY": "test-refresh",
|
||||||
|
"API_KEY": "test-key",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
from fooder.db import DatabaseSessionManager
|
||||||
|
from fooder.domain import Base
|
||||||
|
from fooder.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def db_manager() -> DatabaseSessionManager:
|
||||||
|
return DatabaseSessionManager(settings)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="session", autouse=True)
|
||||||
|
async def setup_database(db_manager: DatabaseSessionManager):
|
||||||
|
async with db_manager.connect() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
|
||||||
6
fooder/test/fixtures/__init__.py
vendored
6
fooder/test/fixtures/__init__.py
vendored
|
|
@ -1,9 +1,5 @@
|
||||||
from .client import * # noqa
|
|
||||||
from .user import * # noqa
|
|
||||||
from .product import * # noqa
|
|
||||||
from .meal import * # noqa
|
|
||||||
from .entry import * # noqa
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from .dbssn import *
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
||||||
110
fooder/test/fixtures/client.py
vendored
110
fooder/test/fixtures/client.py
vendored
|
|
@ -1,110 +0,0 @@
|
||||||
from fooder.app import app
|
|
||||||
from fooder.tasks_app import app as tasks_app
|
|
||||||
from httpx import ASGITransport, AsyncClient
|
|
||||||
import pytest
|
|
||||||
import httpx
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
class Client:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
username: str | None = None,
|
|
||||||
password: str | None = None,
|
|
||||||
):
|
|
||||||
self.client = lambda: AsyncClient(
|
|
||||||
transport=ASGITransport(app=app),
|
|
||||||
base_url="http://testserver/api",
|
|
||||||
headers=self.headers,
|
|
||||||
)
|
|
||||||
self.headers = {"Accept": "application/json"}
|
|
||||||
|
|
||||||
def set_token(self, token: str) -> None:
|
|
||||||
"""set_token.
|
|
||||||
|
|
||||||
:param token:
|
|
||||||
:type token: str
|
|
||||||
:rtype: None
|
|
||||||
"""
|
|
||||||
self.headers["Authorization"] = "Bearer " + token
|
|
||||||
|
|
||||||
async def create_user(self, username: str, password: str) -> None:
|
|
||||||
data = {"username": username, "password": password}
|
|
||||||
response = await self.post("user", json=data)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
async def login(self, username: str, password: str, force_login: bool) -> None:
|
|
||||||
"""login.
|
|
||||||
|
|
||||||
:param username:
|
|
||||||
:type username: str
|
|
||||||
:param password:
|
|
||||||
:type password: str
|
|
||||||
:param force_login:
|
|
||||||
:type password: bool
|
|
||||||
:rtype: None
|
|
||||||
"""
|
|
||||||
data = {"username": username, "password": password}
|
|
||||||
|
|
||||||
response = await self.post("token", data=data)
|
|
||||||
|
|
||||||
if response.status_code != 200:
|
|
||||||
if force_login:
|
|
||||||
await self.create_user(username, password)
|
|
||||||
return await self.login(username, password, False)
|
|
||||||
else:
|
|
||||||
raise Exception(
|
|
||||||
f"Could not login as {username}! Detail: {response.text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
result = response.json()
|
|
||||||
self.set_token(result["access_token"])
|
|
||||||
|
|
||||||
async def get(self, path: str, **kwargs) -> httpx.Response:
|
|
||||||
async with self.client() as client:
|
|
||||||
return await client.get(path, **kwargs)
|
|
||||||
|
|
||||||
async def delete(self, path: str, **kwargs) -> httpx.Response:
|
|
||||||
async with self.client() as client:
|
|
||||||
return await client.delete(path, **kwargs)
|
|
||||||
|
|
||||||
async def post(self, path: str, **kwargs) -> httpx.Response:
|
|
||||||
async with self.client() as client:
|
|
||||||
return await client.post(path, **kwargs)
|
|
||||||
|
|
||||||
async def patch(self, path: str, **kwargs) -> httpx.Response:
|
|
||||||
async with self.client() as client:
|
|
||||||
return await client.patch(path, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class TasksClient(Client):
|
|
||||||
def __init__(self, authorized: bool = True):
|
|
||||||
super().__init__()
|
|
||||||
self.client = lambda: AsyncClient(
|
|
||||||
transport=ASGITransport(app=tasks_app),
|
|
||||||
base_url="http://testserver/api",
|
|
||||||
headers=self.headers,
|
|
||||||
)
|
|
||||||
|
|
||||||
if authorized:
|
|
||||||
self.headers["Authorization"] = "Bearer " + self.get_token()
|
|
||||||
|
|
||||||
def get_token(self) -> str:
|
|
||||||
return os.getenv("API_KEY")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def unauthorized_client() -> Client:
|
|
||||||
return Client()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def tasks_client() -> Client:
|
|
||||||
return TasksClient()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def client(user_payload) -> Client:
|
|
||||||
client = Client()
|
|
||||||
await client.login(user_payload["username"], user_payload["password"], True)
|
|
||||||
return client
|
|
||||||
24
fooder/test/fixtures/dbssn.py
vendored
Normal file
24
fooder/test/fixtures/dbssn.py
vendored
Normal file
|
|
@ -0,0 +1,24 @@
|
||||||
|
import pytest_asyncio
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy import event
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def db_session(db_manager):
|
||||||
|
async with db_manager._engine.connect() as conn:
|
||||||
|
trans = await conn.begin()
|
||||||
|
session = AsyncSession(bind=conn)
|
||||||
|
|
||||||
|
nested = await conn.begin_nested()
|
||||||
|
|
||||||
|
@event.listens_for(session.sync_session, "after_transaction_end")
|
||||||
|
def restart_savepoint(sess, transaction):
|
||||||
|
nonlocal nested
|
||||||
|
if not nested.is_active:
|
||||||
|
nested = conn.sync_connection.begin_nested()
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
await trans.rollback()
|
||||||
14
fooder/test/fixtures/entry.py
vendored
14
fooder/test/fixtures/entry.py
vendored
|
|
@ -1,14 +0,0 @@
|
||||||
import pytest
|
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def entry_payload_factory() -> Callable[[int, int, float], dict[str, int | float]]:
|
|
||||||
def factory(meal_id: int, product_id: int, grams: float) -> dict[str, int | float]:
|
|
||||||
return {
|
|
||||||
"meal_id": meal_id,
|
|
||||||
"product_id": product_id,
|
|
||||||
"grams": grams,
|
|
||||||
}
|
|
||||||
|
|
||||||
return factory
|
|
||||||
37
fooder/test/fixtures/meal.py
vendored
37
fooder/test/fixtures/meal.py
vendored
|
|
@ -1,37 +0,0 @@
|
||||||
import pytest
|
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def meal_payload_factory() -> Callable[[int, int], dict[str, int | str]]:
|
|
||||||
def factory(diary_id: int, order: int) -> dict[str, int | str]:
|
|
||||||
return {
|
|
||||||
"order": order,
|
|
||||||
"diary_id": diary_id,
|
|
||||||
"name": f"meal {order}",
|
|
||||||
}
|
|
||||||
|
|
||||||
return factory
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def meal_save_payload() -> Callable[[int], dict[str, str]]:
|
|
||||||
def factory(meal_id: int) -> dict[str, str]:
|
|
||||||
return {
|
|
||||||
"name": "new name",
|
|
||||||
}
|
|
||||||
|
|
||||||
return factory
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def meal_from_preset() -> Callable[[int, int, int], dict[str, str | int]]:
|
|
||||||
def factory(order: int, diary_id: int, preset_id: int) -> dict[str, str | int]:
|
|
||||||
return {
|
|
||||||
"name": "new name",
|
|
||||||
"order": order,
|
|
||||||
"diary_id": diary_id,
|
|
||||||
"preset_id": preset_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
return factory
|
|
||||||
22
fooder/test/fixtures/product.py
vendored
22
fooder/test/fixtures/product.py
vendored
|
|
@ -1,22 +0,0 @@
|
||||||
import pytest
|
|
||||||
import uuid
|
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def product_payload_factory() -> Callable[[], dict[str, str | float]]:
|
|
||||||
def factory() -> dict[str, str | float]:
|
|
||||||
return {
|
|
||||||
"name": "test" + str(uuid.uuid4().hex),
|
|
||||||
"protein": 1.0,
|
|
||||||
"carb": 1.0,
|
|
||||||
"fat": 1.0,
|
|
||||||
"fiber": 1.0,
|
|
||||||
}
|
|
||||||
|
|
||||||
return factory
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def product_payload(product_payload_factory) -> dict[str, str | float]:
|
|
||||||
return product_payload_factory()
|
|
||||||
22
fooder/test/fixtures/user.py
vendored
22
fooder/test/fixtures/user.py
vendored
|
|
@ -1,22 +0,0 @@
|
||||||
import pytest
|
|
||||||
from typing import Callable
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def user_payload() -> dict[str, str]:
|
|
||||||
return {
|
|
||||||
"username": "test",
|
|
||||||
"password": "test",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def user_payload_factory(user_payload) -> Callable[[], dict[str, str]]:
|
|
||||||
def factory() -> dict[str, str]:
|
|
||||||
return {
|
|
||||||
"username": "test" + str(uuid.uuid4().hex),
|
|
||||||
"password": "test",
|
|
||||||
}
|
|
||||||
|
|
||||||
return factory
|
|
||||||
148
fooder/test/test_db.py
Normal file
148
fooder/test/test_db.py
Normal file
|
|
@ -0,0 +1,148 @@
|
||||||
|
# tests/test_db.py
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
|
||||||
|
|
||||||
|
from fooder.db import DatabaseSessionManager, get_db_session
|
||||||
|
from fooder.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fresh_manager():
|
||||||
|
return DatabaseSessionManager(settings)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_init_creates_engine_and_sessionmaker(db_manager: DatabaseSessionManager):
|
||||||
|
assert db_manager._engine is not None
|
||||||
|
assert db_manager._sessionmaker is not None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_close_disposes_engine_and_nullifies_attrs(
|
||||||
|
fresh_manager: DatabaseSessionManager,
|
||||||
|
):
|
||||||
|
await fresh_manager.close()
|
||||||
|
assert fresh_manager._engine is None
|
||||||
|
assert fresh_manager._sessionmaker is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_connect_after_close_raises(fresh_manager: DatabaseSessionManager):
|
||||||
|
await fresh_manager.close()
|
||||||
|
|
||||||
|
with pytest.raises(Exception, match="not initialized"):
|
||||||
|
async with fresh_manager.connect():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def test_session_after_close_raises(fresh_manager: DatabaseSessionManager):
|
||||||
|
await fresh_manager.close()
|
||||||
|
|
||||||
|
with pytest.raises(Exception, match="not initialized"):
|
||||||
|
async with fresh_manager.session():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def test_close_when_already_closed_raises(fresh_manager: DatabaseSessionManager):
|
||||||
|
await fresh_manager.close()
|
||||||
|
|
||||||
|
with pytest.raises(Exception, match="not initialized"):
|
||||||
|
await fresh_manager.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_session_commit_persists_data(db_manager: DatabaseSessionManager):
|
||||||
|
async with db_manager.connect() as conn:
|
||||||
|
await conn.execute(text("CREATE TABLE test_commit(x int)"))
|
||||||
|
|
||||||
|
async with db_manager.session() as session:
|
||||||
|
await session.execute(text("INSERT INTO test_commit VALUES (42)"))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async with db_manager.session() as session:
|
||||||
|
res = await session.execute(text("SELECT x FROM test_commit"))
|
||||||
|
assert res.scalar() == 42
|
||||||
|
|
||||||
|
|
||||||
|
async def test_session_does_not_autocommit(db_manager: DatabaseSessionManager):
|
||||||
|
async with db_manager.connect() as conn:
|
||||||
|
await conn.execute(text("CREATE TABLE test_no_commit(x int)"))
|
||||||
|
|
||||||
|
async with db_manager.session() as session:
|
||||||
|
await session.execute(text("INSERT INTO test_no_commit VALUES (1)"))
|
||||||
|
# no commit
|
||||||
|
|
||||||
|
async with db_manager.session() as session:
|
||||||
|
res = await session.execute(text("SELECT * FROM test_no_commit"))
|
||||||
|
assert res.first() is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_connect_context_yields_working_connection(
|
||||||
|
db_manager: DatabaseSessionManager,
|
||||||
|
):
|
||||||
|
async with db_manager.connect() as conn:
|
||||||
|
assert isinstance(conn, AsyncConnection)
|
||||||
|
# prove the connection is real
|
||||||
|
res = await conn.execute(text("SELECT 1"))
|
||||||
|
assert res.scalar() == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_connect_rolls_back_on_exception(db_manager: DatabaseSessionManager):
|
||||||
|
"""Raising inside connect() must roll back the txn."""
|
||||||
|
|
||||||
|
class BoomError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
with pytest.raises(BoomError):
|
||||||
|
async with db_manager.connect() as conn:
|
||||||
|
await conn.execute(text("CREATE TABLE t(x int)"))
|
||||||
|
await conn.execute(text("INSERT INTO t VALUES (1)"))
|
||||||
|
raise BoomError("deliberate")
|
||||||
|
|
||||||
|
# Use a *fresh* connection so the failed one is really gone
|
||||||
|
async with db_manager.connect() as conn:
|
||||||
|
res = await conn.execute(text("SELECT * FROM t"))
|
||||||
|
assert res.first() is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_session_rolls_back_on_exception(db_manager: DatabaseSessionManager):
|
||||||
|
"""Raising inside session() must roll back the txn."""
|
||||||
|
|
||||||
|
class BoomError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
with pytest.raises(BoomError):
|
||||||
|
async with db_manager.session() as session:
|
||||||
|
await session.execute(text("CREATE TABLE s(a int)"))
|
||||||
|
await session.execute(text("INSERT INTO s VALUES (1)"))
|
||||||
|
raise BoomError("deliberate")
|
||||||
|
|
||||||
|
# Fresh session / connection
|
||||||
|
async with db_manager.session() as session:
|
||||||
|
res = await session.execute(text("SELECT * FROM s"))
|
||||||
|
assert res.first() is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_db_session_yields_active_session():
|
||||||
|
async for session in get_db_session():
|
||||||
|
assert isinstance(session, AsyncSession)
|
||||||
|
res = await session.execute(text("SELECT 1337"))
|
||||||
|
assert res.scalar() == 1337
|
||||||
|
break # single yield is enough
|
||||||
|
|
||||||
|
|
||||||
|
async def test_concurrent_sessions(db_manager: DatabaseSessionManager):
|
||||||
|
async with db_manager.connect() as conn:
|
||||||
|
await conn.execute(text("CREATE TABLE test_concurrent(x int)"))
|
||||||
|
|
||||||
|
async def worker(val):
|
||||||
|
async with db_manager.session() as session:
|
||||||
|
await session.execute(
|
||||||
|
text("INSERT INTO test_concurrent VALUES (:v)"),
|
||||||
|
{"v": val},
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
await asyncio.gather(*(worker(i) for i in range(5)))
|
||||||
|
|
||||||
|
async with db_manager.session() as session:
|
||||||
|
res = await session.execute(text("SELECT COUNT(*) FROM test_concurrent"))
|
||||||
|
assert res.scalar() == 5
|
||||||
|
|
@ -1,95 +0,0 @@
|
||||||
import datetime
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_get_diary(client):
|
|
||||||
today = datetime.date.today().isoformat()
|
|
||||||
response = await client.get("diary", params={"date": today})
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
|
|
||||||
assert response.json()["date"] == today
|
|
||||||
# new diary should contain exactly one meal
|
|
||||||
assert len(response.json()["meals"]) == 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_diary_add_meal(client, meal_payload_factory):
|
|
||||||
today = datetime.date.today().isoformat()
|
|
||||||
response = await client.get("diary", params={"date": today})
|
|
||||||
|
|
||||||
diary_id = response.json()["id"]
|
|
||||||
meal_order = len(response.json()["meals"]) + 1
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"meal", json=meal_payload_factory(diary_id, meal_order)
|
|
||||||
)
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_diary_delete_meal(client):
|
|
||||||
today = datetime.date.today().isoformat()
|
|
||||||
response = await client.get("diary", params={"date": today})
|
|
||||||
|
|
||||||
meals_amount = len(response.json()["meals"])
|
|
||||||
meal_id = response.json()["meals"][0]["id"]
|
|
||||||
|
|
||||||
response = await client.delete(f"meal/{meal_id}")
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
|
|
||||||
response = await client.get("diary", params={"date": today})
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
assert len(response.json()["meals"]) == meals_amount - 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_diary_add_entry(client, product_payload_factory, entry_payload_factory):
|
|
||||||
today = datetime.date.today().isoformat()
|
|
||||||
response = await client.get("diary", params={"date": today})
|
|
||||||
|
|
||||||
meal_id = response.json()["meals"][0]["id"]
|
|
||||||
|
|
||||||
product_id = (await client.post("product", json=product_payload_factory())).json()[
|
|
||||||
"id"
|
|
||||||
]
|
|
||||||
|
|
||||||
entry_payload = entry_payload_factory(meal_id, product_id, 100.0)
|
|
||||||
response = await client.post("entry", json=entry_payload)
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_diary_edit_entry(client, entry_payload_factory):
|
|
||||||
today = datetime.date.today().isoformat()
|
|
||||||
response = await client.get("diary", params={"date": today})
|
|
||||||
|
|
||||||
entry = response.json()["meals"][0]["entries"][0]
|
|
||||||
id_ = entry["id"]
|
|
||||||
entry_payload = entry_payload_factory(
|
|
||||||
entry["meal_id"], entry["product"]["id"], entry["grams"] + 100.0
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await client.patch(f"entry/{id_}", json=entry_payload)
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
assert response.json()["grams"] == entry_payload["grams"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_diary_delete_entry(client):
|
|
||||||
today = datetime.date.today().isoformat()
|
|
||||||
response = await client.get("diary", params={"date": today})
|
|
||||||
|
|
||||||
entry_id = response.json()["meals"][0]["entries"][0]["id"]
|
|
||||||
response = await client.delete(f"entry/{entry_id}")
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
|
|
||||||
response = await client.get("diary", params={"date": today})
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
deleted_entries = [
|
|
||||||
entry
|
|
||||||
for meal in response.json()["meals"]
|
|
||||||
for entry in meal["entries"]
|
|
||||||
if entry["id"] == entry_id
|
|
||||||
]
|
|
||||||
assert len(deleted_entries) == 0
|
|
||||||
|
|
@ -1,107 +0,0 @@
|
||||||
import datetime
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_create_meal(
|
|
||||||
client, meal_payload_factory, product_payload_factory, entry_payload_factory
|
|
||||||
):
|
|
||||||
today = datetime.date.today().isoformat()
|
|
||||||
response = await client.get("diary", params={"date": today})
|
|
||||||
|
|
||||||
diary_id = response.json()["id"]
|
|
||||||
meal_order = len(response.json()["meals"]) + 1
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"meal", json=meal_payload_factory(diary_id, meal_order)
|
|
||||||
)
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
|
|
||||||
meal_id = response.json()["id"]
|
|
||||||
|
|
||||||
product_id = (await client.post("product", json=product_payload_factory())).json()[
|
|
||||||
"id"
|
|
||||||
]
|
|
||||||
|
|
||||||
entry_payload = entry_payload_factory(meal_id, product_id, 100.0)
|
|
||||||
response = await client.post("entry", json=entry_payload)
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_save_meal(client, meal_save_payload):
|
|
||||||
today = datetime.date.today().isoformat()
|
|
||||||
response = await client.get("diary", params={"date": today})
|
|
||||||
|
|
||||||
meal = response.json()["meals"][0]
|
|
||||||
meal_id = meal["id"]
|
|
||||||
save_payload = meal_save_payload(meal_id)
|
|
||||||
|
|
||||||
response = await client.post(f"meal/{meal_id}/save", json=save_payload)
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
|
|
||||||
preset = response.json()
|
|
||||||
|
|
||||||
for k, v in preset.items():
|
|
||||||
if k in ("id", "name", "entries"):
|
|
||||||
continue
|
|
||||||
|
|
||||||
assert meal[k] == v, f"{k} != {v}"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_list_presets(client, meal_save_payload):
|
|
||||||
response = await client.get("preset")
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
assert len(response.json()["presets"]) > 0, response.json()
|
|
||||||
|
|
||||||
name = meal_save_payload(0)["name"]
|
|
||||||
response = await client.get(f"preset?q={name}")
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
assert len(response.json()["presets"]) > 0, response.json()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_create_meal_from_preset(client, meal_from_preset):
|
|
||||||
today = datetime.date.today().isoformat()
|
|
||||||
response = await client.get("diary", params={"date": today})
|
|
||||||
|
|
||||||
diary_id = response.json()["id"]
|
|
||||||
meal_order = len(response.json()["meals"]) + 1
|
|
||||||
|
|
||||||
response = await client.get("preset")
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
assert len(response.json()["presets"]) > 0, response.json()
|
|
||||||
|
|
||||||
preset = response.json()["presets"][0]
|
|
||||||
|
|
||||||
payload = meal_from_preset(
|
|
||||||
meal_order,
|
|
||||||
diary_id,
|
|
||||||
preset["id"],
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await client.post("meal/from_preset", json=payload)
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
meal = response.json()
|
|
||||||
|
|
||||||
for k, v in preset.items():
|
|
||||||
if k in ("id", "name", "entries"):
|
|
||||||
continue
|
|
||||||
|
|
||||||
assert meal[k] == v, f"{k} != {v}"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_delete_preset(client):
|
|
||||||
presets = (await client.get("preset")).json()["presets"]
|
|
||||||
preset_id = presets[0]["id"]
|
|
||||||
|
|
||||||
response = await client.get(f"preset/{preset_id}")
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
|
|
||||||
response = await client.delete(f"preset/{preset_id}")
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
|
|
||||||
response = await client.get(f"preset/{preset_id}")
|
|
||||||
assert response.status_code == 404, response.json()
|
|
||||||
|
|
@ -1,35 +0,0 @@
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_create_product(client, product_payload):
|
|
||||||
response = await client.post("product", json=product_payload)
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_list_product(client):
|
|
||||||
response = await client.get("product")
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
|
|
||||||
data = response.json()["products"]
|
|
||||||
assert len(data) != 0
|
|
||||||
|
|
||||||
product_ids = set()
|
|
||||||
for product in data:
|
|
||||||
assert product["id"] not in product_ids
|
|
||||||
product_ids.add(product["id"])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_get_product_by_barcode(client):
|
|
||||||
response = await client.get(
|
|
||||||
"product/by_barcode", params={"barcode": "4056489666028"}
|
|
||||||
)
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
|
|
||||||
name = response.json()["name"]
|
|
||||||
|
|
||||||
response = await client.get("product", params={"q": name})
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
assert len(response.json()["products"]) == 1
|
|
||||||
|
|
@ -1,15 +0,0 @@
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_cache_product_usage(client, tasks_client):
|
|
||||||
response = await client.get("product")
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
old_data = response.json()
|
|
||||||
|
|
||||||
response = await tasks_client.post("/cache_product_usage_data")
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
|
|
||||||
response = await client.get("product")
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
assert response.json() != old_data
|
|
||||||
|
|
@ -1,29 +0,0 @@
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_user_creation(unauthorized_client, user_payload_factory):
|
|
||||||
response = await unauthorized_client.post("user", json=user_payload_factory())
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_user_login(client, user_payload):
|
|
||||||
response = await client.post("token", data=user_payload)
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["access_token"] is not None
|
|
||||||
assert data["refresh_token"] is not None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_user_refresh_token(client, user_payload):
|
|
||||||
response = await client.post("token", data=user_payload)
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
|
|
||||||
token = response.json()["refresh_token"]
|
|
||||||
payload = {"refresh_token": token}
|
|
||||||
|
|
||||||
response = await client.post("token/refresh", json=payload)
|
|
||||||
assert response.status_code == 200, response.json()
|
|
||||||
2
pytest.ini
Normal file
2
pytest.ini
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
[pytest]
|
||||||
|
asyncio_mode = auto
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
fastapi
|
fastapi
|
||||||
pydantic
|
pydantic
|
||||||
pydantic_settings
|
pydantic-settings
|
||||||
sqlalchemy[postgresql_asyncpg]
|
sqlalchemy[postgresql_asyncpg]
|
||||||
uvicorn[standard]
|
uvicorn[standard]
|
||||||
asyncpg
|
asyncpg
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
fastapi
|
fastapi
|
||||||
pydantic
|
pydantic
|
||||||
pydantic_settings
|
pydantic-settings
|
||||||
sqlalchemy[postgresql_asyncpg]
|
sqlalchemy[postgresql_asyncpg]
|
||||||
uvicorn[standard]
|
uvicorn[standard]
|
||||||
python-jose[cryptography]
|
python-jose[cryptography]
|
||||||
|
|
@ -8,6 +8,7 @@ bcrypt<5.0.0
|
||||||
passlib[bcrypt]
|
passlib[bcrypt]
|
||||||
fastapi-users
|
fastapi-users
|
||||||
pytest
|
pytest
|
||||||
|
pytest-asyncio
|
||||||
requests
|
requests
|
||||||
black
|
black
|
||||||
flake8
|
flake8
|
||||||
32
test.sh
32
test.sh
|
|
@ -5,36 +5,10 @@
|
||||||
|
|
||||||
echo "Running fooder api tests"
|
echo "Running fooder api tests"
|
||||||
|
|
||||||
# if exists, remove test.db
|
|
||||||
[ -f test.db ] && rm test.db
|
|
||||||
|
|
||||||
# create test env values
|
|
||||||
export DB_URI="sqlite+aiosqlite:///test.db"
|
|
||||||
export ECHO_SQL=0
|
|
||||||
export SECRET_KEY=$(openssl rand -hex 32)
|
|
||||||
export REFRESH_SECRET_KEY=$(openssl rand -hex 32)
|
|
||||||
export API_KEY=$(openssl rand -hex 32)
|
|
||||||
|
|
||||||
python3 -m fooder --create-tables
|
|
||||||
|
|
||||||
# finally run tests
|
|
||||||
if [[ $# -eq 1 ]]; then
|
if [[ $# -eq 1 ]]; then
|
||||||
python3 -m pytest fooder --disable-warnings -sv -k "${1}"
|
python -m pytest fooder --disable-warnings -sv -k "${1}"
|
||||||
else
|
else
|
||||||
python3 -m pytest fooder --disable-warnings -sv
|
python -m pytest fooder --disable-warnings -sv
|
||||||
fi
|
fi
|
||||||
|
|
||||||
status=$?
|
exit $?
|
||||||
|
|
||||||
# unset test env values
|
|
||||||
unset POSTGRES_USER
|
|
||||||
unset POSTGRES_DATABASE
|
|
||||||
unset POSTGRES_PASSWORD
|
|
||||||
unset SECRET_KEY
|
|
||||||
unset REFRESH_SECRET
|
|
||||||
unset API_KEY
|
|
||||||
|
|
||||||
# if exists, remove test.db
|
|
||||||
[ -f test.db ] && rm test.db
|
|
||||||
|
|
||||||
exit $status
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue