# fooder-api/fooder/test/fixtures/db.py
#
# 72 lines
# 1.8 KiB
# Python

import pytest
import pytest_asyncio
from sqlalchemy import event
from sqlalchemy.ext.asyncio import AsyncSession
from fooder.db import DatabaseSessionManager
from fooder.domain import Base
from fooder.settings import settings
from fooder.repository.base import RepositoryBase
from sqlalchemy.orm import Mapped, mapped_column
class TestModel(Base):
    """Minimal ORM model mapped to the ``test`` table, used by repository tests.

    The class name matches pytest's default ``Test*`` class pattern, so
    ``__test__ = False`` is set to stop pytest from trying to collect it as a
    test case when it is imported into a test module.
    """

    # Opt out of pytest collection despite the ``Test`` name prefix; dunder
    # attributes are not treated as mapped columns by declarative mapping.
    __test__ = False
    __tablename__ = "test"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Non-null string column named ``property``; referenced by keyword in the
    # factory fixture, so the name must stay as-is.
    property: Mapped[str]
class TestRepository(RepositoryBase[TestModel]):
    """Concrete ``RepositoryBase`` for ``TestModel``; adds no behaviour of its own."""

    # Keep pytest from collecting this helper as a test class when it is
    # imported into a test module (its name matches the ``Test*`` pattern).
    __test__ = False
@pytest.fixture()
def test_repo(db_session):
    """Provide a ``TestRepository`` for ``TestModel`` bound to ``db_session``."""
    repository = TestRepository(TestModel, db_session)
    return repository
@pytest.fixture()
def test_model_factory():
    """Provide a callable that constructs ``TestModel`` instances.

    The returned callable takes ``id`` and an optional ``property``
    (default ``"value"``) and returns a freshly constructed ``TestModel``.
    It only builds the object; persisting it is left to the caller.
    """

    def build(id: int, property: str = "value") -> TestModel:
        # Parameter names mirror the model's column names so callers may
        # pass them by keyword.
        return TestModel(id=id, property=property)

    return build
@pytest.fixture()
def test_model(test_model_factory):
    """Provide one ``TestModel`` with id 1 and the factory's default ``property``."""
    model = test_model_factory(1)
    return model
@pytest.fixture(scope="session")
def db_manager() -> DatabaseSessionManager:
    """Build one ``DatabaseSessionManager`` shared by the whole test session.

    The manager is constructed from the application's ``settings`` object.
    """
    manager = DatabaseSessionManager(settings)
    return manager
@pytest_asyncio.fixture
async def db_session(db_manager):
    """Yield an ``AsyncSession`` whose writes are discarded after each test.

    The session is bound to a dedicated connection that holds an outer
    transaction for the whole test. With
    ``join_transaction_mode="create_savepoint"`` the session does its work
    inside a SAVEPOINT on that connection. ``session.commit()`` and
    ``session.rollback()`` in the code under test therefore only release or
    roll back the savepoint, and never end the outer transaction. Rolling
    back the outer transaction at teardown discards everything the test
    wrote.

    Args:
        db_manager: Session-scoped manager. Its private ``_engine`` is used
            directly (rather than ``db_manager.connect()``), presumably so
            this fixture can own the outer transaction itself.

    Yields:
        An ``AsyncSession`` with ``expire_on_commit=False``.
    """
    async with db_manager._engine.connect() as conn:
        trans = await conn.begin()
        # SQLAlchemy 2.0 built-in replacement for the 1.4-era recipe that
        # used an "after_transaction_end" listener to restart a nested
        # SAVEPOINT whenever the session ended its transaction.
        session = AsyncSession(
            bind=conn,
            expire_on_commit=False,
            join_transaction_mode="create_savepoint",
        )
        try:
            yield session
        finally:
            # Close the session first so it releases the connection, then
            # roll back the outer transaction to undo the test's writes.
            await session.close()
            await trans.rollback()
@pytest_asyncio.fixture(scope="session", autouse=True)
async def setup_database(db_manager: DatabaseSessionManager):
    """Create every table registered on ``Base.metadata`` once per test session.

    Runs automatically (``autouse=True``) before the session's tests. The
    ``test`` table of ``TestModel`` is included because that model is
    declared in this module.
    """
    async with db_manager.connect() as connection:
        # ``create_all`` is a synchronous API, so it is executed through
        # ``run_sync`` on the async connection.
        await connection.run_sync(Base.metadata.create_all)
    # The connection is released before yielding, so nothing stays open for
    # the rest of the session.
    # NOTE(review): tables are never dropped; create_all skips tables that
    # already exist, so model changes won't reach an existing schema.
    yield