import contextlib
from typing import AsyncGenerator, AsyncIterator

from fooder.settings import Settings, settings
from sqlalchemy.ext.asyncio import (
    AsyncConnection,
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)

# Shared message for every method that requires a live engine/sessionmaker.
_NOT_INITIALIZED = "DatabaseSessionManager is not initialized"


class DatabaseSessionManager:
    """Own the async SQLAlchemy engine and its session factory.

    The engine and sessionmaker are created in ``__init__`` and released by
    :meth:`close`. After ``close`` has run, :meth:`close`, :meth:`connect`
    and :meth:`session` raise ``RuntimeError``.
    """

    def __init__(self, settings: Settings) -> None:
        """Build the engine and sessionmaker from *settings*.

        Args:
            settings: Application settings; ``DB_URI`` and ``ECHO_SQL`` are read.
        """
        # ``check_same_thread`` is a sqlite3-specific connect argument, so it is
        # only passed when the URI targets SQLite.
        connect_args = (
            {"check_same_thread": False}
            if settings.DB_URI.startswith("sqlite")
            else {}
        )
        self._engine: AsyncEngine | None = create_async_engine(
            settings.DB_URI,
            pool_pre_ping=True,
            echo=settings.ECHO_SQL,
            connect_args=connect_args,
        )
        self._sessionmaker: async_sessionmaker[AsyncSession] | None = (
            async_sessionmaker(
                autocommit=False,
                autoflush=False,
                bind=self._engine,
                expire_on_commit=False,
            )
        )

    async def close(self) -> None:
        """Dispose of the engine's connection pool and drop both handles.

        Raises:
            RuntimeError: if the manager has already been closed.
        """
        if self._engine is None:
            raise RuntimeError(_NOT_INITIALIZED)
        await self._engine.dispose()
        self._engine = None
        self._sessionmaker = None

    @contextlib.asynccontextmanager
    async def connect(self) -> AsyncIterator[AsyncConnection]:
        """Yield a connection with a transaction already begun.

        ``AsyncEngine.begin`` commits the transaction when the ``async with``
        body exits normally and rolls it back when the body raises, so no
        explicit rollback is needed here.

        Raises:
            RuntimeError: if the manager has been closed.
        """
        if self._engine is None:
            raise RuntimeError(_NOT_INITIALIZED)
        async with self._engine.begin() as connection:
            yield connection

    @contextlib.asynccontextmanager
    async def session(self) -> AsyncIterator[AsyncSession]:
        """Yield a new ``AsyncSession`` and always close it afterwards.

        The session is not committed automatically; callers must commit
        explicitly. If the body raises an ``Exception``, the session is
        rolled back before the exception propagates.

        Raises:
            RuntimeError: if the manager has been closed.
        """
        if self._sessionmaker is None:
            raise RuntimeError(_NOT_INITIALIZED)
        session = self._sessionmaker()
        try:
            yield session
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()


# Module-level manager built from the imported application settings.
session_manager = DatabaseSessionManager(settings)


async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield one session from ``session_manager`` for the caller's lifetime.

    NOTE(review): shaped like a dependency-injection generator (e.g. a FastAPI
    ``Depends``); confirm against callers.
    """
    async with session_manager.session() as session:
        yield session