148 lines
4.9 KiB
Python
148 lines
4.9 KiB
Python
# tests/test_db.py
|
|
import pytest
|
|
import asyncio
|
|
from sqlalchemy import text
|
|
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
|
|
|
|
from fooder.db import DatabaseSessionManager, get_db_session
|
|
from fooder.settings import settings
|
|
|
|
|
|
@pytest.fixture
def fresh_manager():
    """Build a brand-new DatabaseSessionManager from the application settings.

    A new instance is created for every test that requests it, so tests may
    close it freely without touching the ``db_manager`` fixture's instance.
    """
    manager = DatabaseSessionManager(settings)
    return manager
|
|
|
|
|
|
async def test_init_creates_engine_and_sessionmaker(db_manager: DatabaseSessionManager):
    """An initialised manager holds both an engine and a sessionmaker."""
    engine = db_manager._engine
    sessionmaker = db_manager._sessionmaker

    assert engine is not None
    assert sessionmaker is not None
|
|
|
|
|
|
async def test_close_disposes_engine_and_nullifies_attrs(
    fresh_manager: DatabaseSessionManager,
):
    """After close(), the manager no longer references an engine or sessionmaker."""
    manager = fresh_manager
    await manager.close()

    assert manager._engine is None
    assert manager._sessionmaker is None
|
|
|
|
|
|
async def test_connect_after_close_raises(fresh_manager: DatabaseSessionManager):
    """Using connect() on a closed manager fails with a "not initialized" error."""
    await fresh_manager.close()

    # The call and the context entry both stay inside pytest.raises, since
    # either one may be where the manager detects it is closed.
    with pytest.raises(Exception, match="not initialized"):
        async with fresh_manager.connect() as _conn:
            pass
|
|
|
|
|
|
async def test_session_after_close_raises(fresh_manager: DatabaseSessionManager):
    """Using session() on a closed manager fails with a "not initialized" error."""
    await fresh_manager.close()

    # Keep both creating and entering the context under pytest.raises.
    with pytest.raises(Exception, match="not initialized"):
        async with fresh_manager.session() as _session:
            pass
|
|
|
|
|
|
async def test_close_when_already_closed_raises(fresh_manager: DatabaseSessionManager):
    """close() is not idempotent: a second call reports "not initialized"."""
    await fresh_manager.close()

    with pytest.raises(Exception, match="not initialized"):
        # Second close on the same manager.
        await fresh_manager.close()
|
|
|
|
|
|
async def test_session_commit_persists_data(db_manager: DatabaseSessionManager):
    """A row committed in one session is readable from a later session."""
    # Create the table via connect(); the sessions below rely on it existing.
    async with db_manager.connect() as connection:
        await connection.execute(text("CREATE TABLE test_commit(x int)"))

    insert_stmt = text("INSERT INTO test_commit VALUES (42)")
    async with db_manager.session() as writer:
        await writer.execute(insert_stmt)
        await writer.commit()

    async with db_manager.session() as reader:
        result = await reader.execute(text("SELECT x FROM test_commit"))
        assert result.scalar() == 42
|
|
|
|
|
|
async def test_session_does_not_autocommit(db_manager: DatabaseSessionManager):
    """Leaving session() without calling commit() must not persist the insert."""
    async with db_manager.connect() as connection:
        await connection.execute(text("CREATE TABLE test_no_commit(x int)"))

    async with db_manager.session() as writer:
        await writer.execute(text("INSERT INTO test_no_commit VALUES (1)"))
        # Deliberately exit without commit().

    async with db_manager.session() as reader:
        rows = await reader.execute(text("SELECT * FROM test_no_commit"))
        assert rows.first() is None
|
|
|
|
|
|
async def test_connect_context_yields_working_connection(
    db_manager: DatabaseSessionManager,
):
    """connect() yields an AsyncConnection that can actually run a query."""
    async with db_manager.connect() as connection:
        assert isinstance(connection, AsyncConnection)

        # A round-trip query shows the connection is live, not just well-typed.
        one = (await connection.execute(text("SELECT 1"))).scalar()
        assert one == 1
|
|
|
|
|
|
async def test_connect_rolls_back_on_exception(db_manager: DatabaseSessionManager):
    """Raising inside connect() must roll back the txn."""

    class BoomError(Exception):
        pass

    # Create the table in its own, cleanly-exited connect() block so that only
    # the INSERT is part of the transaction under test. If the CREATE ran in
    # the failing block, the outcome of the final SELECT would depend on
    # whether the backend rolls back DDL: on one that does (e.g. PostgreSQL)
    # the table itself would disappear and the SELECT would raise instead of
    # returning no rows.
    async with db_manager.connect() as conn:
        await conn.execute(text("CREATE TABLE test_connect_rollback(x int)"))

    with pytest.raises(BoomError):
        async with db_manager.connect() as conn:
            await conn.execute(text("INSERT INTO test_connect_rollback VALUES (1)"))
            raise BoomError("deliberate")

    # Use a *fresh* connection so the failed one is really gone
    async with db_manager.connect() as conn:
        res = await conn.execute(text("SELECT * FROM test_connect_rollback"))
        assert res.first() is None
|
|
|
|
|
|
async def test_session_rolls_back_on_exception(db_manager: DatabaseSessionManager):
    """Raising inside session() must roll back the txn."""

    class BoomError(Exception):
        pass

    # Create the table via connect() beforehand, the same pattern the other
    # session tests use, so the rollback under test covers only the INSERT.
    # Otherwise the final SELECT would break on backends with transactional
    # DDL, where rolling back would also remove the table.
    async with db_manager.connect() as conn:
        await conn.execute(text("CREATE TABLE test_session_rollback(a int)"))

    with pytest.raises(BoomError):
        async with db_manager.session() as session:
            await session.execute(text("INSERT INTO test_session_rollback VALUES (1)"))
            raise BoomError("deliberate")

    # Fresh session / connection
    async with db_manager.session() as session:
        res = await session.execute(text("SELECT * FROM test_session_rollback"))
        assert res.first() is None
|
|
|
|
|
|
async def test_get_db_session_yields_active_session():
    """get_db_session() must yield a usable AsyncSession.

    The first item is pulled explicitly rather than with ``async for ...: break``:
    that form passes silently if nothing is yielded, and it leaves the
    generator suspended so its cleanup only runs whenever it is finalized.
    """
    gen = get_db_session()
    try:
        try:
            session = await gen.__anext__()
        except StopAsyncIteration:
            pytest.fail("get_db_session() did not yield a session")

        assert isinstance(session, AsyncSession)
        res = await session.execute(text("SELECT 1337"))
        assert res.scalar() == 1337
    finally:
        # Resume the generator to completion now so its cleanup runs deterministically.
        await gen.aclose()
|
|
|
|
|
|
async def test_concurrent_sessions(db_manager: DatabaseSessionManager):
    """Five sessions inserting concurrently each commit their own row."""
    async with db_manager.connect() as connection:
        await connection.execute(text("CREATE TABLE test_concurrent(x int)"))

    insert_stmt = text("INSERT INTO test_concurrent VALUES (:v)")

    async def insert_and_commit(value):
        # Every task opens its own session, so the inserts are independent.
        async with db_manager.session() as session:
            await session.execute(insert_stmt, {"v": value})
            await session.commit()

    tasks = [insert_and_commit(value) for value in range(5)]
    await asyncio.gather(*tasks)

    async with db_manager.session() as session:
        count = await session.execute(text("SELECT COUNT(*) FROM test_concurrent"))
        assert count.scalar() == 5
|