# tests/test_db.py
"""Tests for ``DatabaseSessionManager`` and ``get_db_session`` from ``fooder.db``.

The ``db_manager`` fixture used by most tests is not defined in this module;
presumably it is provided by a shared ``conftest.py`` -- TODO confirm.
"""

import asyncio

import pytest
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession

from fooder.db import DatabaseSessionManager, get_db_session
from fooder.settings import settings


@pytest.fixture
def fresh_manager() -> DatabaseSessionManager:
    """Return a new ``DatabaseSessionManager`` built from the project settings.

    The lifecycle tests close the manager themselves, so they use this
    fixture instead of ``db_manager``.
    """
    return DatabaseSessionManager(settings)


async def test_init_creates_engine_and_sessionmaker(
    db_manager: DatabaseSessionManager,
) -> None:
    """A manager exposes both an engine and a sessionmaker."""
    assert db_manager._engine is not None
    assert db_manager._sessionmaker is not None


async def test_close_disposes_engine_and_nullifies_attrs(
    fresh_manager: DatabaseSessionManager,
) -> None:
    """``close()`` resets ``_engine`` and ``_sessionmaker`` to ``None``."""
    await fresh_manager.close()

    assert fresh_manager._engine is None
    assert fresh_manager._sessionmaker is None


async def test_connect_after_close_raises(
    fresh_manager: DatabaseSessionManager,
) -> None:
    """``connect()`` on a closed manager raises a "not initialized" error."""
    await fresh_manager.close()

    with pytest.raises(Exception, match="not initialized"):
        async with fresh_manager.connect():
            pass


async def test_session_after_close_raises(
    fresh_manager: DatabaseSessionManager,
) -> None:
    """``session()`` on a closed manager raises a "not initialized" error."""
    await fresh_manager.close()

    with pytest.raises(Exception, match="not initialized"):
        async with fresh_manager.session():
            pass


async def test_close_when_already_closed_raises(
    fresh_manager: DatabaseSessionManager,
) -> None:
    """A second ``close()`` raises a "not initialized" error."""
    await fresh_manager.close()

    with pytest.raises(Exception, match="not initialized"):
        await fresh_manager.close()


async def test_session_commit_persists_data(
    db_manager: DatabaseSessionManager,
) -> None:
    """Rows committed in one session are visible from a later session."""
    # NOTE(review): assumes the table created through connect() is visible to
    # the sessions opened below -- confirm connect() commit behaviour.
    async with db_manager.connect() as conn:
        await conn.execute(text("CREATE TABLE test_commit(x int)"))

    async with db_manager.session() as session:
        await session.execute(text("INSERT INTO test_commit VALUES (42)"))
        await session.commit()

    async with db_manager.session() as session:
        res = await session.execute(text("SELECT x FROM test_commit"))
        assert res.scalar() == 42


async def test_session_does_not_autocommit(
    db_manager: DatabaseSessionManager,
) -> None:
    """Rows inserted without an explicit commit are not visible afterwards."""
    # NOTE(review): assumes the table created through connect() is visible to
    # the sessions opened below -- confirm connect() commit behaviour.
    async with db_manager.connect() as conn:
        await conn.execute(text("CREATE TABLE test_no_commit(x int)"))

    async with db_manager.session() as session:
        await session.execute(text("INSERT INTO test_no_commit VALUES (1)"))
        # no commit

    async with db_manager.session() as session:
        res = await session.execute(text("SELECT * FROM test_no_commit"))
        assert res.first() is None


async def test_connect_context_yields_working_connection(
    db_manager: DatabaseSessionManager,
) -> None:
    """``connect()`` yields an ``AsyncConnection`` that can run a query."""
    async with db_manager.connect() as conn:
        assert isinstance(conn, AsyncConnection)
        # prove the connection is real
        res = await conn.execute(text("SELECT 1"))
        assert res.scalar() == 1


async def test_connect_rolls_back_on_exception(
    db_manager: DatabaseSessionManager,
) -> None:
    """Raising inside connect() must roll back the txn."""

    class BoomError(Exception):
        pass

    with pytest.raises(BoomError):
        async with db_manager.connect() as conn:
            await conn.execute(text("CREATE TABLE t(x int)"))
            await conn.execute(text("INSERT INTO t VALUES (1)"))
            raise BoomError("deliberate")

    # Use a *fresh* connection so the failed one is really gone
    # NOTE(review): the SELECT below needs table ``t`` to still exist, i.e. the
    # CREATE TABLE must survive the rollback while the INSERT does not. That
    # depends on how the test backend treats DDL in transactions -- confirm.
    async with db_manager.connect() as conn:
        res = await conn.execute(text("SELECT * FROM t"))
        assert res.first() is None


async def test_session_rolls_back_on_exception(
    db_manager: DatabaseSessionManager,
) -> None:
    """Raising inside session() must roll back the txn."""

    class BoomError(Exception):
        pass

    with pytest.raises(BoomError):
        async with db_manager.session() as session:
            await session.execute(text("CREATE TABLE s(a int)"))
            await session.execute(text("INSERT INTO s VALUES (1)"))
            raise BoomError("deliberate")

    # Fresh session / connection
    # NOTE(review): as above, this relies on table ``s`` surviving the rollback
    # while the inserted row does not -- confirm against the test backend.
    async with db_manager.session() as session:
        res = await session.execute(text("SELECT * FROM s"))
        assert res.first() is None


async def test_get_db_session_yields_active_session() -> None:
    """``get_db_session()`` yields an ``AsyncSession`` that can run a query."""
    # NOTE(review): assumes get_db_session is an async generator function (it
    # is consumed with ``async for`` and closed with ``aclose()``) -- confirm.
    sessions = get_db_session()
    yielded = False
    try:
        async for session in sessions:
            yielded = True
            assert isinstance(session, AsyncSession)
            res = await session.execute(text("SELECT 1337"))
            assert res.scalar() == 1337
            break  # single yield is enough
    finally:
        # Breaking out of ``async for`` leaves an async generator suspended at
        # its ``yield``; close it explicitly so its cleanup runs now instead of
        # whenever the generator happens to be garbage-collected.
        await sessions.aclose()

    # Without this guard the test would pass vacuously if nothing was yielded.
    assert yielded, "get_db_session() did not yield a session"


async def test_concurrent_sessions(db_manager: DatabaseSessionManager) -> None:
    """Five sessions inserting concurrently each persist their row."""
    async with db_manager.connect() as conn:
        await conn.execute(text("CREATE TABLE test_concurrent(x int)"))

    async def worker(val: int) -> None:
        """Insert ``val`` in its own session and commit it."""
        async with db_manager.session() as session:
            await session.execute(
                text("INSERT INTO test_concurrent VALUES (:v)"),
                {"v": val},
            )
            await session.commit()

    await asyncio.gather(*(worker(i) for i in range(5)))

    async with db_manager.session() as session:
        res = await session.execute(text("SELECT COUNT(*) FROM test_concurrent"))
        assert res.scalar() == 5