[repository] begin implementation

This commit is contained in:
Piotr Domański 2026-04-02 20:38:14 +02:00
parent bbbd124d78
commit b010891ac7
15 changed files with 278 additions and 117 deletions

View file

@ -1,7 +1,7 @@
import contextlib
from typing import AsyncIterator, AsyncGenerator
from fooder.settings import Settings, settings
from .settings import Settings, settings
from sqlalchemy.ext.asyncio import (
AsyncConnection,
AsyncSession,

View file

@ -8,8 +8,9 @@ class Base(DeclarativeBase):
class CommonMixin:
"""define a series of common elements that may be applied to mapped
classes using this class as a mixin class."""
"""
CommonMixin for all common fields in projetc
"""
@declared_attr.directive
def __tablename__(cls) -> str:
@ -20,3 +21,22 @@ class CommonMixin:
return cls.__name__.lower() # type: ignore
id: Mapped[int] = mapped_column(primary_key=True)
class PasswordMixin:
"""
PasswordMixin for entities with password
"""
hashed_password: Mapped[str]
def set_password(self, password) -> None:
"""set_password.
:param password:
:rtype: None
"""
from ..auth import password_helper
self.hashed_password = password_helper.hash(password)

View file

@ -1,74 +1,9 @@
from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped
from .base import Base, CommonMixin
from .base import Base, CommonMixin, PasswordMixin
class User(Base, CommonMixin):
class User(Base, CommonMixin, PasswordMixin):
"""Product."""
username: Mapped[str]
hashed_password: Mapped[str]
def set_password(self, password) -> None:
"""set_password.
:param password:
:rtype: None
"""
from ..auth import password_helper
self.hashed_password = password_helper.hash(password)
@classmethod
async def get_by_username(
cls, session: AsyncSession, username: str
) -> Optional["User"]:
"""get_by_username.
:param session:
:type session: AsyncSession
:param username:
:type username: str
:rtype: Optional["User"]
"""
query = select(cls).filter(cls.username == username)
return await session.scalar(query.order_by(cls.id))
@classmethod
async def get(cls, session: AsyncSession, id: int) -> Optional["User"]:
"""get_by_username.
:param session:
:type session: AsyncSession
:param id:
:type id: int
:rtype: Optional["User"]
"""
query = select(cls).filter(cls.id == id)
return await session.scalar(query.order_by(cls.id))
@classmethod
async def create(
cls, session: AsyncSession, username: str, password: str
) -> "User":
"""create.
:param session:
:type session: AsyncSession
:param username:
:type username: str
:param password:
:type password: str
:rtype: "User"
"""
exsisting_user = await User.get_by_username(session, username)
assert exsisting_user is None, "user already exists"
user = cls(username=username)
user.set_password(password)
session.add(user)
await session.flush()
return user

View file

@ -0,0 +1 @@
from .repository import Repository

View file

@ -1,6 +1,59 @@
from typing import TypeVar, Generic, Type, Any, Sequence
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update as sa_update, delete as sa_delete
from sqlalchemy.sql import Select
T = TypeVar("T")
class RepositoryBase:
def __init__(self, dbssn: AsyncSession):
self.dbssn = dbssn
class RepositoryBase(Generic[T]):
def __init__(self, model: Type[T], session: AsyncSession):
self.model = model
self.session = session
def _build_select(self, **filters: Any) -> Select[tuple[T]]:
stmt = select(self.model)
for field, value in filters.items():
column = getattr(self.model, field, None)
if column is None:
raise ValueError(f"{self.model.__name__} has no attribute '{field}'")
stmt = stmt.where(column == value)
return stmt
async def get(self, **filters: Any) -> T | None:
stmt = self._build_select(**filters)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def list(self, **filters: Any) -> Sequence[T]:
stmt = self._build_select(**filters)
result = await self.session.execute(stmt)
return result.scalars().all()
async def create(self, obj: T) -> T:
self.session.add(obj)
await self.session.flush()
await self.session.refresh(obj)
return obj
async def delete(self, **filters: Any) -> int:
stmt = sa_delete(self.model)
for field, value in filters.items():
column = getattr(self.model, field)
stmt = stmt.where(column == value)
result = await self.session.execute(stmt)
return result.rowcount if result.rowcount != -1 else 0
async def update(self, filters: dict[str, Any], values: dict[str, Any]) -> int:
stmt = sa_update(self.model)
for field, value in filters.items():
stmt = stmt.where(getattr(self.model, field) == value)
stmt = stmt.values(**values)
result = await self.session.execute(stmt)
return result.rowcount if result.rowcount != -1 else 0

View file

@ -0,0 +1,8 @@
from sqlalchemy.ext.asyncio import AsyncSession
from .user import UserRepository
from ..domain import User
class Repository:
def __init__(self, session: AsyncSession):
self.user = UserRepository(User, session)

View file

@ -1,2 +1,8 @@
class UserRepository:
from sqlalchemy import select, Select
from .base import RepositoryBase
from ..domain import User
class UserRepository(RepositoryBase[User]):
pass

View file

@ -1,6 +1,4 @@
import os
import pytest
import pytest_asyncio
# --------------------------------------------------------------------------- #
# Supply minimal dummy env-vars *before* any of our modules are imported. #
@ -16,19 +14,4 @@ os.environ.update(
}
)
from fooder.db import DatabaseSessionManager
from fooder.domain import Base
from fooder.settings import settings
@pytest.fixture(scope="session")
def db_manager() -> DatabaseSessionManager:
return DatabaseSessionManager(settings)
@pytest_asyncio.fixture(scope="session", autouse=True)
async def setup_database(db_manager: DatabaseSessionManager):
async with db_manager.connect() as conn:
await conn.run_sync(Base.metadata.create_all)
yield
from .fixtures import *

View file

@ -1,5 +1,6 @@
import pytest
from .dbssn import *
from .db import *
from .faker import *
@pytest.fixture

72
fooder/test/fixtures/db.py vendored Normal file
View file

@ -0,0 +1,72 @@
import pytest
import pytest_asyncio
from sqlalchemy import event
from sqlalchemy.ext.asyncio import AsyncSession
from fooder.db import DatabaseSessionManager
from fooder.domain import Base
from fooder.settings import settings
from fooder.repository.base import RepositoryBase
from sqlalchemy.orm import Mapped, mapped_column
class TestModel(Base):
__tablename__ = "test"
id: Mapped[int] = mapped_column(primary_key=True)
property: Mapped[str]
class TestRepository(RepositoryBase[TestModel]):
pass
@pytest.fixture
def test_repo(db_session):
return TestRepository(TestModel, db_session)
@pytest.fixture
def test_model_factory():
def factory(id: int, property: str = "value") -> TestModel:
return TestModel(id=id, property=property)
return factory
@pytest.fixture
def test_model(test_model_factory):
return test_model_factory(1)
@pytest.fixture(scope="session")
def db_manager() -> DatabaseSessionManager:
return DatabaseSessionManager(settings)
@pytest_asyncio.fixture
async def db_session(db_manager):
async with db_manager._engine.connect() as conn:
trans = await conn.begin()
session = AsyncSession(bind=conn)
nested = await conn.begin_nested()
@event.listens_for(session.sync_session, "after_transaction_end")
def restart_savepoint(sess, transaction):
nonlocal nested
if not nested.is_active:
nested = conn.sync_connection.begin_nested()
try:
yield session
finally:
await session.close()
await trans.rollback()
@pytest_asyncio.fixture(scope="session", autouse=True)
async def setup_database(db_manager: DatabaseSessionManager):
async with db_manager.connect() as conn:
await conn.run_sync(Base.metadata.create_all)
yield

View file

@ -1,24 +0,0 @@
import pytest_asyncio
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import event
@pytest_asyncio.fixture
async def db_session(db_manager):
async with db_manager._engine.connect() as conn:
trans = await conn.begin()
session = AsyncSession(bind=conn)
nested = await conn.begin_nested()
@event.listens_for(session.sync_session, "after_transaction_end")
def restart_savepoint(sess, transaction):
nonlocal nested
if not nested.is_active:
nested = conn.sync_connection.begin_nested()
try:
yield session
finally:
await session.close()
await trans.rollback()

9
fooder/test/fixtures/faker.py vendored Normal file
View file

@ -0,0 +1,9 @@
import pytest
from faker import Faker
@pytest.fixture(scope="session", autouse=True)
def faker():
f = Faker(["en_US", "pl_PL"])
f.seed_instance(1234)
return f

View file

View file

@ -0,0 +1,96 @@
import pytest
import faker
# ------------------------------------------------------------------ create ---
async def test_create_returns_object_with_id(test_repo, test_model):
model = await test_repo.create(test_model)
assert model.id is not None
assert model.property is not None
# -------------------------------------------------------------------- get ----
async def test_get_returns_existing_record(test_repo, test_model):
created = await test_repo.create(test_model)
found = await test_repo.get(id=created.id)
assert found is not None
assert found.id == created.id
async def test_get_returns_none_for_missing_record(test_repo):
result = await test_repo.get(id=1)
assert result is None
async def test_get_by_field(test_repo, test_model_factory):
await test_repo.create(test_model_factory(1, "value"))
found = await test_repo.get(property="value")
assert found is not None
assert found.id == 1
# ------------------------------------------------------------------- list ----
async def test_list_returns_all_matching(test_repo, test_model_factory):
await test_repo.create(test_model_factory(1))
await test_repo.create(test_model_factory(2))
results = await test_repo.list()
modelnames = {u.id for u in results}
assert {1, 2}.issubset(modelnames)
async def test_list_with_filter(test_repo, test_model_factory):
await test_repo.create(test_model_factory(1, "value"))
await test_repo.create(test_model_factory(2, "value2"))
results = await test_repo.list(property="value")
assert len(results) == 1
assert results[0].id == 1
async def test_list_returns_empty_when_no_match(test_repo, test_model_factory):
await test_repo.create(test_model_factory(1, "value"))
results = await test_repo.list(property="value2")
assert results == []
# ------------------------------------------------------------------ delete ---
async def test_delete_removes_record_and_returns_count(test_repo, test_model):
model = await test_repo.create(test_model)
count = await test_repo.delete(id=model.id)
assert count == 1
assert await test_repo.get(id=model.id) is None
async def test_delete_returns_zero_when_nothing_matched(test_repo, test_model_factory):
count = await test_repo.delete(id=999999)
assert count == 0
# ------------------------------------------------------------------ update ---
async def test_update_modifies_record_and_returns_count(test_repo, test_model):
model = await test_repo.create(test_model)
count = await test_repo.update({"id": model.id}, {"property": "value2"})
assert count == 1
refreshed = await test_repo.get(id=model.id)
assert refreshed is not None
assert refreshed.property == "value2"
async def test_update_returns_zero_when_nothing_matched(test_repo, test_model_factory):
count = await test_repo.update({"id": 999999}, {"property": "value"})
assert count == 0
# ------------------------------------------------------- _build_select ------
def test_build_select_raises_for_unknown_field(test_repo):
with pytest.raises(ValueError, match="has no attribute 'nonexistent'"):
test_repo._build_select(nonexistent="value")

View file

@ -9,6 +9,7 @@ passlib[bcrypt]
fastapi-users
pytest
pytest-asyncio
faker
requests
black
flake8