72 lines
1.8 KiB
Python
72 lines
1.8 KiB
Python
import pytest
|
|
import pytest_asyncio
|
|
|
|
from sqlalchemy import event
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from fooder.db import DatabaseSessionManager
|
|
from fooder.domain import Base
|
|
from fooder.settings import settings
|
|
from fooder.repository.base import RepositoryBase
|
|
from sqlalchemy.orm import Mapped, mapped_column
|
|
|
|
|
|
class TestModel(Base):
    """Minimal mapped model used to exercise the generic repository in tests.

    Subclassing ``Base`` registers the ``"test"`` table on ``Base.metadata``,
    which is what ``setup_database`` passes to ``create_all``.
    """

    __tablename__ = "test"

    # Integer primary key column.
    id: Mapped[int] = mapped_column(primary_key=True)
    # Non-Optional ``Mapped[str]`` maps to a NOT NULL string column; an
    # argument-less ``mapped_column()`` is equivalent to the bare annotation.
    property: Mapped[str] = mapped_column()
|
|
|
|
|
|
class TestRepository(RepositoryBase[TestModel]):
    """``RepositoryBase`` specialised for ``TestModel``.

    Adds nothing to the base class; tests obtain an instance through the
    ``test_repo`` fixture.
    """
|
|
|
|
|
|
@pytest.fixture
def test_repo(db_session):
    """Provide a ``TestRepository`` bound to the current test's session.

    The repository is constructed with the ``TestModel`` class and the
    ``db_session`` fixture's session, so its work is isolated per test in the
    same way that session is.
    """
    repository = TestRepository(TestModel, db_session)
    return repository
|
|
|
|
|
|
@pytest.fixture
def test_model_factory():
    """Provide a callable that builds new ``TestModel`` instances.

    The callable takes a required ``id`` and an optional ``property``
    (default ``"value"``) and returns a freshly constructed ``TestModel``.
    The instance is not added to any session here.
    """

    def make_model(id: int, property: str = "value") -> TestModel:
        # Parameter names mirror the model's columns so callers may pass them
        # by keyword; they intentionally shadow the builtins locally.
        return TestModel(id=id, property=property)

    return make_model
|
|
|
|
|
|
@pytest.fixture
def test_model(test_model_factory):
    """Provide a single ``TestModel`` with ``id=1`` and the default property."""
    model = test_model_factory(1)
    return model
|
|
|
|
|
|
@pytest.fixture(scope="session")
def db_manager() -> DatabaseSessionManager:
    """Build one ``DatabaseSessionManager`` from ``settings`` for the whole run.

    Session-scoped, so every test and fixture shares the same manager.
    NOTE(review): nothing here closes or disposes the manager at the end of
    the run; confirm whether that is handled elsewhere.
    """
    manager = DatabaseSessionManager(settings)
    return manager
|
|
|
|
|
|
@pytest_asyncio.fixture
async def db_session(db_manager):
    """Yield an ``AsyncSession`` whose changes are discarded after each test.

    A dedicated connection is opened and an outer transaction is begun on it
    before the session is created. With
    ``join_transaction_mode="create_savepoint"`` the session never commits or
    rolls back that outer transaction. Instead, each session-level
    ``commit()``/``rollback()`` operates on a SAVEPOINT nested inside it, so
    tests may commit freely. At teardown the session is closed and the outer
    transaction is rolled back, so nothing the test wrote is persisted.

    This is the SQLAlchemy 2.0 replacement for the 1.4-era recipe, which
    opened a SAVEPOINT manually and restarted it from an
    ``after_transaction_end`` event listener.
    """
    # NOTE(review): reaches into the manager's private ``_engine`` attribute;
    # consider exposing a public accessor on DatabaseSessionManager.
    async with db_manager._engine.connect() as conn:
        outer = await conn.begin()
        session = AsyncSession(
            bind=conn,
            expire_on_commit=False,
            # Session commits/rollbacks become SAVEPOINT release/rollback,
            # leaving ``outer`` untouched until teardown.
            join_transaction_mode="create_savepoint",
        )
        try:
            yield session
        finally:
            await session.close()
            # Discard everything the test did, including "committed" work.
            await outer.rollback()
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="session", autouse=True)
async def setup_database(db_manager: DatabaseSessionManager):
    """Create every table registered on ``Base.metadata`` once per test run.

    Autouse and session-scoped, so it runs before the first test without
    being requested. ``create_all`` is synchronous, so it is executed on the
    async connection via ``run_sync``. Nothing runs after ``yield``, so the
    tables are left in place when the session ends.

    NOTE(review): whether the DDL is committed depends on
    ``DatabaseSessionManager.connect()``; confirm it yields a transactional
    (``engine.begin()``-style) connection.
    """
    async with db_manager.connect() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield
|