begin tedious work
This commit is contained in:
parent
9792b0feb3
commit
bbbd124d78
29 changed files with 340 additions and 578 deletions
|
|
@ -17,7 +17,7 @@ FROM python:3.11.5-bullseye
|
|||
|
||||
RUN apt-get -y install libpq-dev
|
||||
|
||||
COPY requirements.txt requirements.txt
|
||||
COPY requirements/docker.txt requirements.txt
|
||||
RUN pip install -r requirements.txt
|
||||
|
||||
RUN useradd fooder
|
||||
|
|
|
|||
2
Makefile
2
Makefile
|
|
@ -40,7 +40,7 @@ version:
|
|||
.PHONY: create-venv
|
||||
create-venv:
|
||||
python3 -m venv .venv --prompt="fooderapi-venv" --system-site-packages
|
||||
bash -c "source .venv/bin/activate && pip install -r requirements_local.txt"
|
||||
bash -c "source .venv/bin/activate && pip install -r requirements/local.txt"
|
||||
|
||||
.PHONY: test
|
||||
test:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from jose import JWTError, jwt
|
|||
from passlib.context import CryptContext
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from .db import get_session
|
||||
from .db import get_db_session
|
||||
from .domain.token import RefreshToken
|
||||
from .domain.user import User
|
||||
from .settings import Settings
|
||||
|
|
@ -18,7 +18,7 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/token")
|
|||
settings = Settings()
|
||||
password_helper = PasswordHelper(pwd_context) # type: ignore
|
||||
|
||||
AsyncSessionDependency = Annotated[async_sessionmaker, Depends(get_session)]
|
||||
AsyncSessionDependency = Annotated[async_sessionmaker, Depends(get_db_session)]
|
||||
TokenDependency = Annotated[str, Depends(oauth2_scheme)]
|
||||
|
||||
|
||||
|
|
|
|||
30
fooder/context.py
Normal file
30
fooder/context.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from fastapi import Depends
|
||||
from fooder.db import get_db_session
|
||||
|
||||
|
||||
class Context:
|
||||
"""
|
||||
Main API context, aggregating dependencies
|
||||
"""
|
||||
|
||||
def __init__(self, dbssn: AsyncSession) -> None:
|
||||
self.dbssn = dbssn
|
||||
|
||||
|
||||
class ContextDependency:
|
||||
"""
|
||||
Configurable context dependecy. Allows for shared interface configuring
|
||||
method required dependencies
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
dbssn: AsyncSession = Depends(get_db_session),
|
||||
):
|
||||
return Context(dbssn=dbssn)
|
||||
|
|
@ -4,17 +4,16 @@ from fastapi import Depends
|
|||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
|
||||
from ..auth import authorize_api_key, get_current_user
|
||||
from ..db import get_session
|
||||
from ..db import get_db_session, AsyncSession
|
||||
from ..domain.user import User
|
||||
|
||||
AsyncSession = Annotated[async_sessionmaker, Depends(get_session)]
|
||||
UserDependency = Annotated[User, Depends(get_current_user)]
|
||||
ApiKeyDependency = Annotated[None, Depends(authorize_api_key)]
|
||||
|
||||
|
||||
class BaseController:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self.async_session = session
|
||||
self.session = session
|
||||
|
||||
async def call(self, *args, **kwargs) -> Any:
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -9,8 +9,7 @@ from .base import AuthorizedController
|
|||
|
||||
class GetDiary(AuthorizedController):
|
||||
async def call(self, date: date) -> Diary:
|
||||
async with self.async_session() as session:
|
||||
diary = await DBDiary.get_diary(session, self.user.id, date)
|
||||
diary = await DBDiary.get_diary(self.session, self.user.id, date)
|
||||
|
||||
if diary is not None:
|
||||
return Diary.from_orm(diary)
|
||||
|
|
|
|||
85
fooder/db.py
85
fooder/db.py
|
|
@ -1,38 +1,69 @@
|
|||
import logging
|
||||
from typing import AsyncIterator
|
||||
import contextlib
|
||||
from typing import AsyncIterator, AsyncGenerator
|
||||
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from fooder.settings import Settings, settings
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncConnection,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from .settings import Settings
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
settings = Settings.parse_obj({})
|
||||
|
||||
if settings.DB_URI.startswith("sqlite"):
|
||||
settings.DB_URI = settings.DB_URI + "?check_same_thread=False"
|
||||
|
||||
"""
|
||||
Asynchronous PostgreSQL database engine.
|
||||
"""
|
||||
async_engine = create_async_engine(
|
||||
class DatabaseSessionManager:
|
||||
def __init__(self, settings: Settings) -> None:
|
||||
self._engine = create_async_engine(
|
||||
settings.DB_URI,
|
||||
pool_pre_ping=True,
|
||||
echo=settings.ECHO_SQL,
|
||||
connect_args=(
|
||||
{"check_same_thread": False} if settings.DB_URI.startswith("sqlite") else {}
|
||||
{"check_same_thread": False}
|
||||
if settings.DB_URI.startswith("sqlite")
|
||||
else {}
|
||||
),
|
||||
)
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
bind=async_engine,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
future=True,
|
||||
)
|
||||
)
|
||||
self._sessionmaker = async_sessionmaker(
|
||||
autocommit=False, autoflush=False, future=True, bind=self._engine
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._engine is None:
|
||||
raise Exception("DatabaseSessionManager is not initialized")
|
||||
await self._engine.dispose()
|
||||
|
||||
async def get_session() -> AsyncIterator[async_sessionmaker]:
|
||||
self._engine = None
|
||||
self._sessionmaker = None
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def connect(self) -> AsyncIterator[AsyncConnection]:
|
||||
if self._engine is None:
|
||||
raise Exception("DatabaseSessionManager is not initialized")
|
||||
|
||||
async with self._engine.begin() as connection:
|
||||
try:
|
||||
yield AsyncSessionLocal
|
||||
except SQLAlchemyError as e:
|
||||
log.exception(e)
|
||||
yield connection
|
||||
except Exception:
|
||||
await connection.rollback()
|
||||
raise
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def session(self) -> AsyncIterator[AsyncSession]:
|
||||
if self._sessionmaker is None:
|
||||
raise Exception("DatabaseSessionManager is not initialized")
|
||||
|
||||
session = self._sessionmaker()
|
||||
try:
|
||||
yield session
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
session_manager = DatabaseSessionManager(settings)
|
||||
|
||||
|
||||
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with session_manager.session() as session:
|
||||
yield session
|
||||
|
|
|
|||
0
fooder/repository/__init__.py
Normal file
0
fooder/repository/__init__.py
Normal file
6
fooder/repository/base.py
Normal file
6
fooder/repository/base.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
class RepositoryBase:
|
||||
def __init__(self, dbssn: AsyncSession):
|
||||
self.dbssn = dbssn
|
||||
2
fooder/repository/user.py
Normal file
2
fooder/repository/user.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
class UserRepository:
|
||||
pass
|
||||
|
|
@ -13,8 +13,11 @@ class Settings(BaseSettings):
|
|||
REFRESH_SECRET_KEY: str
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 30
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 120
|
||||
|
||||
ALLOWED_ORIGINS: List[str] = ["*"]
|
||||
|
||||
API_KEY: str
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
|
|
|||
|
|
@ -1 +1,34 @@
|
|||
from .fixtures import * # noqa
|
||||
import os
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Supply minimal dummy env-vars *before* any of our modules are imported. #
|
||||
# This lets the global `settings = Settings()` call succeed. #
|
||||
# --------------------------------------------------------------------------- #
|
||||
os.environ.update(
|
||||
{
|
||||
"DB_URI": "sqlite+aiosqlite:///:memory:",
|
||||
"ECHO_SQL": "false",
|
||||
"SECRET_KEY": "test-secret",
|
||||
"REFRESH_SECRET_KEY": "test-refresh",
|
||||
"API_KEY": "test-key",
|
||||
}
|
||||
)
|
||||
|
||||
from fooder.db import DatabaseSessionManager
|
||||
from fooder.domain import Base
|
||||
from fooder.settings import settings
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def db_manager() -> DatabaseSessionManager:
|
||||
return DatabaseSessionManager(settings)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", autouse=True)
|
||||
async def setup_database(db_manager: DatabaseSessionManager):
|
||||
async with db_manager.connect() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield
|
||||
|
|
|
|||
6
fooder/test/fixtures/__init__.py
vendored
6
fooder/test/fixtures/__init__.py
vendored
|
|
@ -1,9 +1,5 @@
|
|||
from .client import * # noqa
|
||||
from .user import * # noqa
|
||||
from .product import * # noqa
|
||||
from .meal import * # noqa
|
||||
from .entry import * # noqa
|
||||
import pytest
|
||||
from .dbssn import *
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
110
fooder/test/fixtures/client.py
vendored
110
fooder/test/fixtures/client.py
vendored
|
|
@ -1,110 +0,0 @@
|
|||
from fooder.app import app
|
||||
from fooder.tasks_app import app as tasks_app
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
import pytest
|
||||
import httpx
|
||||
import os
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(
|
||||
self,
|
||||
username: str | None = None,
|
||||
password: str | None = None,
|
||||
):
|
||||
self.client = lambda: AsyncClient(
|
||||
transport=ASGITransport(app=app),
|
||||
base_url="http://testserver/api",
|
||||
headers=self.headers,
|
||||
)
|
||||
self.headers = {"Accept": "application/json"}
|
||||
|
||||
def set_token(self, token: str) -> None:
|
||||
"""set_token.
|
||||
|
||||
:param token:
|
||||
:type token: str
|
||||
:rtype: None
|
||||
"""
|
||||
self.headers["Authorization"] = "Bearer " + token
|
||||
|
||||
async def create_user(self, username: str, password: str) -> None:
|
||||
data = {"username": username, "password": password}
|
||||
response = await self.post("user", json=data)
|
||||
response.raise_for_status()
|
||||
|
||||
async def login(self, username: str, password: str, force_login: bool) -> None:
|
||||
"""login.
|
||||
|
||||
:param username:
|
||||
:type username: str
|
||||
:param password:
|
||||
:type password: str
|
||||
:param force_login:
|
||||
:type password: bool
|
||||
:rtype: None
|
||||
"""
|
||||
data = {"username": username, "password": password}
|
||||
|
||||
response = await self.post("token", data=data)
|
||||
|
||||
if response.status_code != 200:
|
||||
if force_login:
|
||||
await self.create_user(username, password)
|
||||
return await self.login(username, password, False)
|
||||
else:
|
||||
raise Exception(
|
||||
f"Could not login as {username}! Detail: {response.text}"
|
||||
)
|
||||
|
||||
result = response.json()
|
||||
self.set_token(result["access_token"])
|
||||
|
||||
async def get(self, path: str, **kwargs) -> httpx.Response:
|
||||
async with self.client() as client:
|
||||
return await client.get(path, **kwargs)
|
||||
|
||||
async def delete(self, path: str, **kwargs) -> httpx.Response:
|
||||
async with self.client() as client:
|
||||
return await client.delete(path, **kwargs)
|
||||
|
||||
async def post(self, path: str, **kwargs) -> httpx.Response:
|
||||
async with self.client() as client:
|
||||
return await client.post(path, **kwargs)
|
||||
|
||||
async def patch(self, path: str, **kwargs) -> httpx.Response:
|
||||
async with self.client() as client:
|
||||
return await client.patch(path, **kwargs)
|
||||
|
||||
|
||||
class TasksClient(Client):
|
||||
def __init__(self, authorized: bool = True):
|
||||
super().__init__()
|
||||
self.client = lambda: AsyncClient(
|
||||
transport=ASGITransport(app=tasks_app),
|
||||
base_url="http://testserver/api",
|
||||
headers=self.headers,
|
||||
)
|
||||
|
||||
if authorized:
|
||||
self.headers["Authorization"] = "Bearer " + self.get_token()
|
||||
|
||||
def get_token(self) -> str:
|
||||
return os.getenv("API_KEY")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unauthorized_client() -> Client:
|
||||
return Client()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tasks_client() -> Client:
|
||||
return TasksClient()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client(user_payload) -> Client:
|
||||
client = Client()
|
||||
await client.login(user_payload["username"], user_payload["password"], True)
|
||||
return client
|
||||
24
fooder/test/fixtures/dbssn.py
vendored
Normal file
24
fooder/test/fixtures/dbssn.py
vendored
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
import pytest_asyncio
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import event
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_session(db_manager):
|
||||
async with db_manager._engine.connect() as conn:
|
||||
trans = await conn.begin()
|
||||
session = AsyncSession(bind=conn)
|
||||
|
||||
nested = await conn.begin_nested()
|
||||
|
||||
@event.listens_for(session.sync_session, "after_transaction_end")
|
||||
def restart_savepoint(sess, transaction):
|
||||
nonlocal nested
|
||||
if not nested.is_active:
|
||||
nested = conn.sync_connection.begin_nested()
|
||||
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
await trans.rollback()
|
||||
14
fooder/test/fixtures/entry.py
vendored
14
fooder/test/fixtures/entry.py
vendored
|
|
@ -1,14 +0,0 @@
|
|||
import pytest
|
||||
from typing import Callable
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def entry_payload_factory() -> Callable[[int, int, float], dict[str, int | float]]:
|
||||
def factory(meal_id: int, product_id: int, grams: float) -> dict[str, int | float]:
|
||||
return {
|
||||
"meal_id": meal_id,
|
||||
"product_id": product_id,
|
||||
"grams": grams,
|
||||
}
|
||||
|
||||
return factory
|
||||
37
fooder/test/fixtures/meal.py
vendored
37
fooder/test/fixtures/meal.py
vendored
|
|
@ -1,37 +0,0 @@
|
|||
import pytest
|
||||
from typing import Callable
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def meal_payload_factory() -> Callable[[int, int], dict[str, int | str]]:
|
||||
def factory(diary_id: int, order: int) -> dict[str, int | str]:
|
||||
return {
|
||||
"order": order,
|
||||
"diary_id": diary_id,
|
||||
"name": f"meal {order}",
|
||||
}
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def meal_save_payload() -> Callable[[int], dict[str, str]]:
|
||||
def factory(meal_id: int) -> dict[str, str]:
|
||||
return {
|
||||
"name": "new name",
|
||||
}
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def meal_from_preset() -> Callable[[int, int, int], dict[str, str | int]]:
|
||||
def factory(order: int, diary_id: int, preset_id: int) -> dict[str, str | int]:
|
||||
return {
|
||||
"name": "new name",
|
||||
"order": order,
|
||||
"diary_id": diary_id,
|
||||
"preset_id": preset_id,
|
||||
}
|
||||
|
||||
return factory
|
||||
22
fooder/test/fixtures/product.py
vendored
22
fooder/test/fixtures/product.py
vendored
|
|
@ -1,22 +0,0 @@
|
|||
import pytest
|
||||
import uuid
|
||||
from typing import Callable
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def product_payload_factory() -> Callable[[], dict[str, str | float]]:
|
||||
def factory() -> dict[str, str | float]:
|
||||
return {
|
||||
"name": "test" + str(uuid.uuid4().hex),
|
||||
"protein": 1.0,
|
||||
"carb": 1.0,
|
||||
"fat": 1.0,
|
||||
"fiber": 1.0,
|
||||
}
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def product_payload(product_payload_factory) -> dict[str, str | float]:
|
||||
return product_payload_factory()
|
||||
22
fooder/test/fixtures/user.py
vendored
22
fooder/test/fixtures/user.py
vendored
|
|
@ -1,22 +0,0 @@
|
|||
import pytest
|
||||
from typing import Callable
|
||||
import uuid
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_payload() -> dict[str, str]:
|
||||
return {
|
||||
"username": "test",
|
||||
"password": "test",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_payload_factory(user_payload) -> Callable[[], dict[str, str]]:
|
||||
def factory() -> dict[str, str]:
|
||||
return {
|
||||
"username": "test" + str(uuid.uuid4().hex),
|
||||
"password": "test",
|
||||
}
|
||||
|
||||
return factory
|
||||
148
fooder/test/test_db.py
Normal file
148
fooder/test/test_db.py
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
# tests/test_db.py
|
||||
import pytest
|
||||
import asyncio
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
|
||||
|
||||
from fooder.db import DatabaseSessionManager, get_db_session
|
||||
from fooder.settings import settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fresh_manager():
|
||||
return DatabaseSessionManager(settings)
|
||||
|
||||
|
||||
async def test_init_creates_engine_and_sessionmaker(db_manager: DatabaseSessionManager):
|
||||
assert db_manager._engine is not None
|
||||
assert db_manager._sessionmaker is not None
|
||||
|
||||
|
||||
async def test_close_disposes_engine_and_nullifies_attrs(
|
||||
fresh_manager: DatabaseSessionManager,
|
||||
):
|
||||
await fresh_manager.close()
|
||||
assert fresh_manager._engine is None
|
||||
assert fresh_manager._sessionmaker is None
|
||||
|
||||
|
||||
async def test_connect_after_close_raises(fresh_manager: DatabaseSessionManager):
|
||||
await fresh_manager.close()
|
||||
|
||||
with pytest.raises(Exception, match="not initialized"):
|
||||
async with fresh_manager.connect():
|
||||
pass
|
||||
|
||||
|
||||
async def test_session_after_close_raises(fresh_manager: DatabaseSessionManager):
|
||||
await fresh_manager.close()
|
||||
|
||||
with pytest.raises(Exception, match="not initialized"):
|
||||
async with fresh_manager.session():
|
||||
pass
|
||||
|
||||
|
||||
async def test_close_when_already_closed_raises(fresh_manager: DatabaseSessionManager):
|
||||
await fresh_manager.close()
|
||||
|
||||
with pytest.raises(Exception, match="not initialized"):
|
||||
await fresh_manager.close()
|
||||
|
||||
|
||||
async def test_session_commit_persists_data(db_manager: DatabaseSessionManager):
|
||||
async with db_manager.connect() as conn:
|
||||
await conn.execute(text("CREATE TABLE test_commit(x int)"))
|
||||
|
||||
async with db_manager.session() as session:
|
||||
await session.execute(text("INSERT INTO test_commit VALUES (42)"))
|
||||
await session.commit()
|
||||
|
||||
async with db_manager.session() as session:
|
||||
res = await session.execute(text("SELECT x FROM test_commit"))
|
||||
assert res.scalar() == 42
|
||||
|
||||
|
||||
async def test_session_does_not_autocommit(db_manager: DatabaseSessionManager):
|
||||
async with db_manager.connect() as conn:
|
||||
await conn.execute(text("CREATE TABLE test_no_commit(x int)"))
|
||||
|
||||
async with db_manager.session() as session:
|
||||
await session.execute(text("INSERT INTO test_no_commit VALUES (1)"))
|
||||
# no commit
|
||||
|
||||
async with db_manager.session() as session:
|
||||
res = await session.execute(text("SELECT * FROM test_no_commit"))
|
||||
assert res.first() is None
|
||||
|
||||
|
||||
async def test_connect_context_yields_working_connection(
|
||||
db_manager: DatabaseSessionManager,
|
||||
):
|
||||
async with db_manager.connect() as conn:
|
||||
assert isinstance(conn, AsyncConnection)
|
||||
# prove the connection is real
|
||||
res = await conn.execute(text("SELECT 1"))
|
||||
assert res.scalar() == 1
|
||||
|
||||
|
||||
async def test_connect_rolls_back_on_exception(db_manager: DatabaseSessionManager):
|
||||
"""Raising inside connect() must roll back the txn."""
|
||||
|
||||
class BoomError(Exception):
|
||||
pass
|
||||
|
||||
with pytest.raises(BoomError):
|
||||
async with db_manager.connect() as conn:
|
||||
await conn.execute(text("CREATE TABLE t(x int)"))
|
||||
await conn.execute(text("INSERT INTO t VALUES (1)"))
|
||||
raise BoomError("deliberate")
|
||||
|
||||
# Use a *fresh* connection so the failed one is really gone
|
||||
async with db_manager.connect() as conn:
|
||||
res = await conn.execute(text("SELECT * FROM t"))
|
||||
assert res.first() is None
|
||||
|
||||
|
||||
async def test_session_rolls_back_on_exception(db_manager: DatabaseSessionManager):
|
||||
"""Raising inside session() must roll back the txn."""
|
||||
|
||||
class BoomError(Exception):
|
||||
pass
|
||||
|
||||
with pytest.raises(BoomError):
|
||||
async with db_manager.session() as session:
|
||||
await session.execute(text("CREATE TABLE s(a int)"))
|
||||
await session.execute(text("INSERT INTO s VALUES (1)"))
|
||||
raise BoomError("deliberate")
|
||||
|
||||
# Fresh session / connection
|
||||
async with db_manager.session() as session:
|
||||
res = await session.execute(text("SELECT * FROM s"))
|
||||
assert res.first() is None
|
||||
|
||||
|
||||
async def test_get_db_session_yields_active_session():
|
||||
async for session in get_db_session():
|
||||
assert isinstance(session, AsyncSession)
|
||||
res = await session.execute(text("SELECT 1337"))
|
||||
assert res.scalar() == 1337
|
||||
break # single yield is enough
|
||||
|
||||
|
||||
async def test_concurrent_sessions(db_manager: DatabaseSessionManager):
|
||||
async with db_manager.connect() as conn:
|
||||
await conn.execute(text("CREATE TABLE test_concurrent(x int)"))
|
||||
|
||||
async def worker(val):
|
||||
async with db_manager.session() as session:
|
||||
await session.execute(
|
||||
text("INSERT INTO test_concurrent VALUES (:v)"),
|
||||
{"v": val},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
await asyncio.gather(*(worker(i) for i in range(5)))
|
||||
|
||||
async with db_manager.session() as session:
|
||||
res = await session.execute(text("SELECT COUNT(*) FROM test_concurrent"))
|
||||
assert res.scalar() == 5
|
||||
|
|
@ -1,95 +0,0 @@
|
|||
import datetime
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_diary(client):
|
||||
today = datetime.date.today().isoformat()
|
||||
response = await client.get("diary", params={"date": today})
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
assert response.json()["date"] == today
|
||||
# new diary should contain exactly one meal
|
||||
assert len(response.json()["meals"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_diary_add_meal(client, meal_payload_factory):
|
||||
today = datetime.date.today().isoformat()
|
||||
response = await client.get("diary", params={"date": today})
|
||||
|
||||
diary_id = response.json()["id"]
|
||||
meal_order = len(response.json()["meals"]) + 1
|
||||
|
||||
response = await client.post(
|
||||
"meal", json=meal_payload_factory(diary_id, meal_order)
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_diary_delete_meal(client):
|
||||
today = datetime.date.today().isoformat()
|
||||
response = await client.get("diary", params={"date": today})
|
||||
|
||||
meals_amount = len(response.json()["meals"])
|
||||
meal_id = response.json()["meals"][0]["id"]
|
||||
|
||||
response = await client.delete(f"meal/{meal_id}")
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
response = await client.get("diary", params={"date": today})
|
||||
assert response.status_code == 200, response.json()
|
||||
assert len(response.json()["meals"]) == meals_amount - 1
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_diary_add_entry(client, product_payload_factory, entry_payload_factory):
|
||||
today = datetime.date.today().isoformat()
|
||||
response = await client.get("diary", params={"date": today})
|
||||
|
||||
meal_id = response.json()["meals"][0]["id"]
|
||||
|
||||
product_id = (await client.post("product", json=product_payload_factory())).json()[
|
||||
"id"
|
||||
]
|
||||
|
||||
entry_payload = entry_payload_factory(meal_id, product_id, 100.0)
|
||||
response = await client.post("entry", json=entry_payload)
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_diary_edit_entry(client, entry_payload_factory):
|
||||
today = datetime.date.today().isoformat()
|
||||
response = await client.get("diary", params={"date": today})
|
||||
|
||||
entry = response.json()["meals"][0]["entries"][0]
|
||||
id_ = entry["id"]
|
||||
entry_payload = entry_payload_factory(
|
||||
entry["meal_id"], entry["product"]["id"], entry["grams"] + 100.0
|
||||
)
|
||||
|
||||
response = await client.patch(f"entry/{id_}", json=entry_payload)
|
||||
assert response.status_code == 200, response.json()
|
||||
assert response.json()["grams"] == entry_payload["grams"]
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_diary_delete_entry(client):
|
||||
today = datetime.date.today().isoformat()
|
||||
response = await client.get("diary", params={"date": today})
|
||||
|
||||
entry_id = response.json()["meals"][0]["entries"][0]["id"]
|
||||
response = await client.delete(f"entry/{entry_id}")
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
response = await client.get("diary", params={"date": today})
|
||||
assert response.status_code == 200, response.json()
|
||||
deleted_entries = [
|
||||
entry
|
||||
for meal in response.json()["meals"]
|
||||
for entry in meal["entries"]
|
||||
if entry["id"] == entry_id
|
||||
]
|
||||
assert len(deleted_entries) == 0
|
||||
|
|
@ -1,107 +0,0 @@
|
|||
import datetime
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_meal(
|
||||
client, meal_payload_factory, product_payload_factory, entry_payload_factory
|
||||
):
|
||||
today = datetime.date.today().isoformat()
|
||||
response = await client.get("diary", params={"date": today})
|
||||
|
||||
diary_id = response.json()["id"]
|
||||
meal_order = len(response.json()["meals"]) + 1
|
||||
|
||||
response = await client.post(
|
||||
"meal", json=meal_payload_factory(diary_id, meal_order)
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
meal_id = response.json()["id"]
|
||||
|
||||
product_id = (await client.post("product", json=product_payload_factory())).json()[
|
||||
"id"
|
||||
]
|
||||
|
||||
entry_payload = entry_payload_factory(meal_id, product_id, 100.0)
|
||||
response = await client.post("entry", json=entry_payload)
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_save_meal(client, meal_save_payload):
|
||||
today = datetime.date.today().isoformat()
|
||||
response = await client.get("diary", params={"date": today})
|
||||
|
||||
meal = response.json()["meals"][0]
|
||||
meal_id = meal["id"]
|
||||
save_payload = meal_save_payload(meal_id)
|
||||
|
||||
response = await client.post(f"meal/{meal_id}/save", json=save_payload)
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
preset = response.json()
|
||||
|
||||
for k, v in preset.items():
|
||||
if k in ("id", "name", "entries"):
|
||||
continue
|
||||
|
||||
assert meal[k] == v, f"{k} != {v}"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_presets(client, meal_save_payload):
|
||||
response = await client.get("preset")
|
||||
assert response.status_code == 200, response.json()
|
||||
assert len(response.json()["presets"]) > 0, response.json()
|
||||
|
||||
name = meal_save_payload(0)["name"]
|
||||
response = await client.get(f"preset?q={name}")
|
||||
assert response.status_code == 200, response.json()
|
||||
assert len(response.json()["presets"]) > 0, response.json()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_meal_from_preset(client, meal_from_preset):
|
||||
today = datetime.date.today().isoformat()
|
||||
response = await client.get("diary", params={"date": today})
|
||||
|
||||
diary_id = response.json()["id"]
|
||||
meal_order = len(response.json()["meals"]) + 1
|
||||
|
||||
response = await client.get("preset")
|
||||
assert response.status_code == 200, response.json()
|
||||
assert len(response.json()["presets"]) > 0, response.json()
|
||||
|
||||
preset = response.json()["presets"][0]
|
||||
|
||||
payload = meal_from_preset(
|
||||
meal_order,
|
||||
diary_id,
|
||||
preset["id"],
|
||||
)
|
||||
|
||||
response = await client.post("meal/from_preset", json=payload)
|
||||
assert response.status_code == 200, response.json()
|
||||
meal = response.json()
|
||||
|
||||
for k, v in preset.items():
|
||||
if k in ("id", "name", "entries"):
|
||||
continue
|
||||
|
||||
assert meal[k] == v, f"{k} != {v}"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_preset(client):
|
||||
presets = (await client.get("preset")).json()["presets"]
|
||||
preset_id = presets[0]["id"]
|
||||
|
||||
response = await client.get(f"preset/{preset_id}")
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
response = await client.delete(f"preset/{preset_id}")
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
response = await client.get(f"preset/{preset_id}")
|
||||
assert response.status_code == 404, response.json()
|
||||
|
|
@ -1,35 +0,0 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_product(client, product_payload):
|
||||
response = await client.post("product", json=product_payload)
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_product(client):
|
||||
response = await client.get("product")
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
data = response.json()["products"]
|
||||
assert len(data) != 0
|
||||
|
||||
product_ids = set()
|
||||
for product in data:
|
||||
assert product["id"] not in product_ids
|
||||
product_ids.add(product["id"])
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_product_by_barcode(client):
|
||||
response = await client.get(
|
||||
"product/by_barcode", params={"barcode": "4056489666028"}
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
name = response.json()["name"]
|
||||
|
||||
response = await client.get("product", params={"q": name})
|
||||
assert response.status_code == 200, response.json()
|
||||
assert len(response.json()["products"]) == 1
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cache_product_usage(client, tasks_client):
|
||||
response = await client.get("product")
|
||||
assert response.status_code == 200, response.json()
|
||||
old_data = response.json()
|
||||
|
||||
response = await tasks_client.post("/cache_product_usage_data")
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
response = await client.get("product")
|
||||
assert response.status_code == 200, response.json()
|
||||
assert response.json() != old_data
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_user_creation(unauthorized_client, user_payload_factory):
|
||||
response = await unauthorized_client.post("user", json=user_payload_factory())
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_user_login(client, user_payload):
|
||||
response = await client.post("token", data=user_payload)
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
data = response.json()
|
||||
assert data["access_token"] is not None
|
||||
assert data["refresh_token"] is not None
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_user_refresh_token(client, user_payload):
|
||||
response = await client.post("token", data=user_payload)
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
token = response.json()["refresh_token"]
|
||||
payload = {"refresh_token": token}
|
||||
|
||||
response = await client.post("token/refresh", json=payload)
|
||||
assert response.status_code == 200, response.json()
|
||||
2
pytest.ini
Normal file
2
pytest.ini
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
[pytest]
|
||||
asyncio_mode = auto
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
fastapi
|
||||
pydantic
|
||||
pydantic_settings
|
||||
pydantic-settings
|
||||
sqlalchemy[postgresql_asyncpg]
|
||||
uvicorn[standard]
|
||||
asyncpg
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
fastapi
|
||||
pydantic
|
||||
pydantic_settings
|
||||
pydantic-settings
|
||||
sqlalchemy[postgresql_asyncpg]
|
||||
uvicorn[standard]
|
||||
python-jose[cryptography]
|
||||
|
|
@ -8,6 +8,7 @@ bcrypt<5.0.0
|
|||
passlib[bcrypt]
|
||||
fastapi-users
|
||||
pytest
|
||||
pytest-asyncio
|
||||
requests
|
||||
black
|
||||
flake8
|
||||
32
test.sh
32
test.sh
|
|
@ -5,36 +5,10 @@
|
|||
|
||||
echo "Running fooder api tests"
|
||||
|
||||
# if exists, remove test.db
|
||||
[ -f test.db ] && rm test.db
|
||||
|
||||
# create test env values
|
||||
export DB_URI="sqlite+aiosqlite:///test.db"
|
||||
export ECHO_SQL=0
|
||||
export SECRET_KEY=$(openssl rand -hex 32)
|
||||
export REFRESH_SECRET_KEY=$(openssl rand -hex 32)
|
||||
export API_KEY=$(openssl rand -hex 32)
|
||||
|
||||
python3 -m fooder --create-tables
|
||||
|
||||
# finally run tests
|
||||
if [[ $# -eq 1 ]]; then
|
||||
python3 -m pytest fooder --disable-warnings -sv -k "${1}"
|
||||
python -m pytest fooder --disable-warnings -sv -k "${1}"
|
||||
else
|
||||
python3 -m pytest fooder --disable-warnings -sv
|
||||
python -m pytest fooder --disable-warnings -sv
|
||||
fi
|
||||
|
||||
status=$?
|
||||
|
||||
# unset test env values
|
||||
unset POSTGRES_USER
|
||||
unset POSTGRES_DATABASE
|
||||
unset POSTGRES_PASSWORD
|
||||
unset SECRET_KEY
|
||||
unset REFRESH_SECRET
|
||||
unset API_KEY
|
||||
|
||||
# if exists, remove test.db
|
||||
[ -f test.db ] && rm test.db
|
||||
|
||||
exit $status
|
||||
exit $?
|
||||
|
|
|
|||
Loading…
Reference in a new issue