24 lines
710 B
Python
24 lines
710 B
Python
import pytest_asyncio
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import event
|
|
|
|
|
|
@pytest_asyncio.fixture
async def db_session(db_manager):
    """Yield an ``AsyncSession`` whose writes are rolled back after the test.

    The session is bound to a single connection opened from
    ``db_manager._engine``. An outer transaction is begun on that connection
    and a SAVEPOINT is opened inside it. Whenever the session finishes a
    transaction and that savepoint is no longer active (e.g. after the test
    rolls back), a new savepoint is started so the session stays usable.
    On teardown the outer transaction is rolled back, discarding everything
    written through the session.

    Args:
        db_manager: Fixture providing the database manager; only its
            ``_engine`` attribute is used here (presumably an
            ``AsyncEngine`` -- confirm against the fixture definition).

    Yields:
        The ``AsyncSession`` bound to the transactional connection.
    """
    async with db_manager._engine.connect() as conn:
        # Outer transaction: never committed, always rolled back in teardown.
        trans = await conn.begin()
        session = AsyncSession(bind=conn)

        # Savepoint the test's work runs inside; intended so that a rollback
        # issued by the test unwinds only to here instead of ending ``trans``.
        nested = await conn.begin_nested()

        # Session events are dispatched on the wrapped synchronous Session,
        # so the listener is attached to ``session.sync_session`` and uses
        # the synchronous connection API to reopen the savepoint.
        @event.listens_for(session.sync_session, "after_transaction_end")
        def restart_savepoint(sess, transaction):
            nonlocal nested
            if not nested.is_active:
                nested = conn.sync_connection.begin_nested()

        try:
            yield session
        finally:
            # Detach the listener first: closing the session ends its
            # transaction, which would otherwise fire ``restart_savepoint``
            # and open a savepoint that is thrown away immediately below.
            event.remove(
                session.sync_session, "after_transaction_end", restart_savepoint
            )
            await session.close()
            # Discard everything the test wrote through this connection.
            await trans.rollback()
|