# fooder-api/fooder/test/test_db.py
#
# 148 lines
# 4.9 KiB
# Python

# tests/test_db.py
import pytest
import asyncio
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
from fooder.db import DatabaseSessionManager, get_db_session
from fooder.settings import settings
@pytest.fixture
def fresh_manager() -> DatabaseSessionManager:
    """Provide a dedicated DatabaseSessionManager built from the app settings.

    Tests that close the manager they receive use this instance rather than
    the ``db_manager`` fixture.
    """
    manager = DatabaseSessionManager(settings)
    return manager
async def test_init_creates_engine_and_sessionmaker(
    db_manager: DatabaseSessionManager,
):
    """An initialised manager holds both an engine and a sessionmaker."""
    engine, maker = db_manager._engine, db_manager._sessionmaker
    assert engine is not None
    assert maker is not None
async def test_close_disposes_engine_and_nullifies_attrs(
    fresh_manager: DatabaseSessionManager,
):
    """After close(), both ``_engine`` and ``_sessionmaker`` are None."""
    manager = fresh_manager
    await manager.close()
    for attribute in ("_engine", "_sessionmaker"):
        assert getattr(manager, attribute) is None
async def test_connect_after_close_raises(fresh_manager: DatabaseSessionManager):
    """Entering connect() on a closed manager fails with 'not initialized'."""
    await fresh_manager.close()
    expected_message = "not initialized"
    with pytest.raises(Exception, match=expected_message):
        async with fresh_manager.connect():
            pass
async def test_session_after_close_raises(fresh_manager: DatabaseSessionManager):
    """Entering session() on a closed manager fails with 'not initialized'."""
    await fresh_manager.close()
    expected_message = "not initialized"
    with pytest.raises(Exception, match=expected_message):
        async with fresh_manager.session():
            pass
async def test_close_when_already_closed_raises(
    fresh_manager: DatabaseSessionManager,
):
    """A second close() is rejected with a 'not initialized' error."""
    await fresh_manager.close()
    # The first close() succeeded; repeating it must not pass silently.
    with pytest.raises(Exception, match="not initialized"):
        await fresh_manager.close()
async def test_session_commit_persists_data(db_manager: DatabaseSessionManager):
    """A row committed in one session is readable from a later session."""
    create_stmt = text("CREATE TABLE test_commit(x int)")
    insert_stmt = text("INSERT INTO test_commit VALUES (42)")
    select_stmt = text("SELECT x FROM test_commit")

    async with db_manager.connect() as conn:
        await conn.execute(create_stmt)

    async with db_manager.session() as writer:
        await writer.execute(insert_stmt)
        await writer.commit()

    async with db_manager.session() as reader:
        result = await reader.execute(select_stmt)
        assert result.scalar() == 42
async def test_session_does_not_autocommit(db_manager: DatabaseSessionManager):
    """Work left uncommitted when a session exits is not visible afterwards."""
    async with db_manager.connect() as conn:
        await conn.execute(text("CREATE TABLE test_no_commit(x int)"))

    async with db_manager.session() as writer:
        # Deliberately no commit() before this session exits.
        await writer.execute(text("INSERT INTO test_no_commit VALUES (1)"))

    async with db_manager.session() as reader:
        rows = await reader.execute(text("SELECT * FROM test_no_commit"))
        assert rows.first() is None
async def test_connect_context_yields_working_connection(
    db_manager: DatabaseSessionManager,
):
    """connect() yields an AsyncConnection that can execute a query."""
    async with db_manager.connect() as connection:
        assert isinstance(connection, AsyncConnection)
        # A round-trip query proves the connection is real and usable.
        result = await connection.execute(text("SELECT 1"))
        value = result.scalar()
        assert value == 1
async def test_connect_rolls_back_on_exception(db_manager: DatabaseSessionManager):
    """Raising inside connect() must roll back the txn.

    The table is created beforehand in its own ``connect()`` block, the same
    way the other tests in this module set up their tables, so the failing
    block contains only the INSERT under test. Creating the table inside the
    failing block would make the final SELECT depend on whether the backend
    runs DDL transactionally. Where it does (e.g. PostgreSQL), the rollback
    would also remove the table, and the SELECT would error instead of
    returning no rows.
    """

    class BoomError(Exception):
        pass

    async with db_manager.connect() as conn:
        await conn.execute(text("CREATE TABLE test_connect_rollback(x int)"))

    with pytest.raises(BoomError):
        async with db_manager.connect() as conn:
            await conn.execute(
                text("INSERT INTO test_connect_rollback VALUES (1)")
            )
            raise BoomError("deliberate")

    # Use a *fresh* connection so the failed one is really gone
    async with db_manager.connect() as conn:
        res = await conn.execute(text("SELECT * FROM test_connect_rollback"))
        assert res.first() is None
async def test_session_rolls_back_on_exception(db_manager: DatabaseSessionManager):
    """Raising inside session() must roll back the txn.

    The table is created beforehand in its own ``connect()`` block, the same
    way the other tests in this module set up their tables, so the failing
    session contains only the INSERT under test. Creating the table inside
    the failing session would make the final SELECT depend on whether the
    backend runs DDL transactionally. Where it does (e.g. PostgreSQL), the
    rollback would also remove the table, and the SELECT would error instead
    of returning no rows.
    """

    class BoomError(Exception):
        pass

    async with db_manager.connect() as conn:
        await conn.execute(text("CREATE TABLE test_session_rollback(a int)"))

    with pytest.raises(BoomError):
        async with db_manager.session() as session:
            await session.execute(
                text("INSERT INTO test_session_rollback VALUES (1)")
            )
            raise BoomError("deliberate")

    # Fresh session / connection
    async with db_manager.session() as session:
        res = await session.execute(text("SELECT * FROM test_session_rollback"))
        assert res.first() is None
async def test_get_db_session_yields_active_session():
    """get_db_session() yields a usable AsyncSession.

    Only the first yielded session is needed, so exactly one item is pulled.
    The iterator is then closed explicitly with ``aclose()``. Abandoning an
    async generator via ``break`` defers its finalisation to garbage
    collection, so any cleanup after its ``yield`` would run at an
    unpredictable time, possibly after the event loop has shut down.

    Pulling the item with ``__anext__()`` also makes the test fail if nothing
    is yielded. The previous ``async for`` loop would simply not execute and
    the test would pass vacuously.
    """
    sessions = get_db_session()
    try:
        session = await sessions.__anext__()
        assert isinstance(session, AsyncSession)
        res = await session.execute(text("SELECT 1337"))
        assert res.scalar() == 1337
    finally:
        # Assumes get_db_session is an async generator (it is consumed with
        # ``async for`` elsewhere), so it provides aclose().
        await sessions.aclose()
async def test_concurrent_sessions(db_manager: DatabaseSessionManager):
    """Several sessions inserting and committing concurrently all persist."""
    worker_count = 5

    async with db_manager.connect() as conn:
        await conn.execute(text("CREATE TABLE test_concurrent(x int)"))

    async def insert_one(value):
        # Each task uses its own session and commits independently.
        async with db_manager.session() as session:
            await session.execute(
                text("INSERT INTO test_concurrent VALUES (:v)"),
                {"v": value},
            )
            await session.commit()

    pending = [insert_one(n) for n in range(worker_count)]
    await asyncio.gather(*pending)

    async with db_manager.session() as session:
        total = await session.execute(text("SELECT COUNT(*) FROM test_concurrent"))
        assert total.scalar() == worker_count