From b010891ac7023190821c43a87448dfd2c1b29cfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20Doma=C5=84ski?= Date: Thu, 2 Apr 2026 20:38:14 +0200 Subject: [PATCH] [repository] begin implementation --- fooder/db.py | 2 +- fooder/domain/base.py | 24 +++++++- fooder/domain/user.py | 69 +-------------------- fooder/repository/__init__.py | 1 + fooder/repository/base.py | 59 +++++++++++++++++- fooder/repository/repository.py | 8 +++ fooder/repository/user.py | 8 ++- fooder/test/conftest.py | 19 +----- fooder/test/fixtures/__init__.py | 3 +- fooder/test/fixtures/db.py | 72 ++++++++++++++++++++++ fooder/test/fixtures/dbssn.py | 24 -------- fooder/test/fixtures/faker.py | 9 +++ fooder/test/repository/__init__.py | 0 fooder/test/repository/test_base.py | 96 +++++++++++++++++++++++++++++ requirements/local.txt | 1 + 15 files changed, 278 insertions(+), 117 deletions(-) create mode 100644 fooder/repository/repository.py create mode 100644 fooder/test/fixtures/db.py delete mode 100644 fooder/test/fixtures/dbssn.py create mode 100644 fooder/test/fixtures/faker.py create mode 100644 fooder/test/repository/__init__.py create mode 100644 fooder/test/repository/test_base.py diff --git a/fooder/db.py b/fooder/db.py index ab53951..7fa2ccf 100644 --- a/fooder/db.py +++ b/fooder/db.py @@ -1,7 +1,7 @@ import contextlib from typing import AsyncIterator, AsyncGenerator -from fooder.settings import Settings, settings +from .settings import Settings, settings from sqlalchemy.ext.asyncio import ( AsyncConnection, AsyncSession, diff --git a/fooder/domain/base.py b/fooder/domain/base.py index d008cbc..68859a3 100644 --- a/fooder/domain/base.py +++ b/fooder/domain/base.py @@ -8,8 +8,9 @@ class Base(DeclarativeBase): class CommonMixin: - """define a series of common elements that may be applied to mapped - classes using this class as a mixin class.""" + """ + CommonMixin for all common fields in projetc + """ @declared_attr.directive def __tablename__(cls) -> str: @@ -20,3 +21,22 @@ class CommonMixin: return cls.__name__.lower() # type: ignore id: Mapped[int] = mapped_column(primary_key=True) + + +class PasswordMixin: + """ + PasswordMixin for entities with password + """ + + hashed_password: Mapped[str] + + def set_password(self, password) -> None: + """set_password. + + :param password: + :rtype: None + """ + + from ..auth import password_helper + + self.hashed_password = password_helper.hash(password) diff --git a/fooder/domain/user.py b/fooder/domain/user.py index de85d20..3042a72 100644 --- a/fooder/domain/user.py +++ b/fooder/domain/user.py @@ -1,74 +1,9 @@ -from typing import Optional - -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Mapped -from .base import Base, CommonMixin +from .base import Base, CommonMixin, PasswordMixin -class User(Base, CommonMixin): +class User(Base, CommonMixin, PasswordMixin): """Product.""" username: Mapped[str] - hashed_password: Mapped[str] - - def set_password(self, password) -> None: - """set_password. - - :param password: - :rtype: None - """ - from ..auth import password_helper - - self.hashed_password = password_helper.hash(password) - - @classmethod - async def get_by_username( - cls, session: AsyncSession, username: str - ) -> Optional["User"]: - """get_by_username. - - :param session: - :type session: AsyncSession - :param username: - :type username: str - :rtype: Optional["User"] - """ - query = select(cls).filter(cls.username == username) - return await session.scalar(query.order_by(cls.id)) - - @classmethod - async def get(cls, session: AsyncSession, id: int) -> Optional["User"]: - """get_by_username. - - :param session: - :type session: AsyncSession - :param id: - :type id: int - :rtype: Optional["User"] - """ - query = select(cls).filter(cls.id == id) - return await session.scalar(query.order_by(cls.id)) - - @classmethod - async def create( - cls, session: AsyncSession, username: str, password: str - ) -> "User": - """create. - - :param session: - :type session: AsyncSession - :param username: - :type username: str - :param password: - :type password: str - :rtype: "User" - """ - exsisting_user = await User.get_by_username(session, username) - assert exsisting_user is None, "user already exists" - user = cls(username=username) - user.set_password(password) - session.add(user) - await session.flush() - return user diff --git a/fooder/repository/__init__.py b/fooder/repository/__init__.py index e69de29..9c22cb7 100644 --- a/fooder/repository/__init__.py +++ b/fooder/repository/__init__.py @@ -0,0 +1 @@ +from .repository import Repository diff --git a/fooder/repository/base.py b/fooder/repository/base.py index a89349f..a8a8111 100644 --- a/fooder/repository/base.py +++ b/fooder/repository/base.py @@ -1,6 +1,59 @@ +from typing import TypeVar, Generic, Type, Any, Sequence from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, update as sa_update, delete as sa_delete +from sqlalchemy.sql import Select + +T = TypeVar("T") -class RepositoryBase: - def __init__(self, dbssn: AsyncSession): - self.dbssn = dbssn +class RepositoryBase(Generic[T]): + def __init__(self, model: Type[T], session: AsyncSession): + self.model = model + self.session = session + + def _build_select(self, **filters: Any) -> Select[tuple[T]]: + stmt = select(self.model) + + for field, value in filters.items(): + column = getattr(self.model, field, None) + if column is None: + raise ValueError(f"{self.model.__name__} has no attribute '{field}'") + stmt = stmt.where(column == value) + + return stmt + + async def get(self, **filters: Any) -> T | None: + stmt = self._build_select(**filters) + result = await self.session.execute(stmt) + return result.scalar_one_or_none() + + async def list(self, **filters: Any) -> Sequence[T]: + stmt = self._build_select(**filters) + result = await self.session.execute(stmt) + return result.scalars().all() + + async def create(self, obj: T) -> T: + self.session.add(obj) + await self.session.flush() + await self.session.refresh(obj) + return obj + + async def delete(self, **filters: Any) -> int: + stmt = sa_delete(self.model) + + for field, value in filters.items(): + column = getattr(self.model, field) + stmt = stmt.where(column == value) + + result = await self.session.execute(stmt) + return result.rowcount if result.rowcount != -1 else 0 + + async def update(self, filters: dict[str, Any], values: dict[str, Any]) -> int: + stmt = sa_update(self.model) + + for field, value in filters.items(): + stmt = stmt.where(getattr(self.model, field) == value) + + stmt = stmt.values(**values) + result = await self.session.execute(stmt) + return result.rowcount if result.rowcount != -1 else 0 diff --git a/fooder/repository/repository.py b/fooder/repository/repository.py new file mode 100644 index 0000000..b4249ae --- /dev/null +++ b/fooder/repository/repository.py @@ -0,0 +1,8 @@ +from sqlalchemy.ext.asyncio import AsyncSession +from .user import UserRepository +from ..domain import User + + +class Repository: + def __init__(self, session: AsyncSession): + self.user = UserRepository(User, session) diff --git a/fooder/repository/user.py b/fooder/repository/user.py index fe63377..f921404 100644 --- a/fooder/repository/user.py +++ b/fooder/repository/user.py @@ -1,2 +1,8 @@ -class UserRepository: +from sqlalchemy import select, Select + +from .base import RepositoryBase +from ..domain import User + + +class UserRepository(RepositoryBase[User]): pass diff --git a/fooder/test/conftest.py b/fooder/test/conftest.py index 0c9b106..34336ed 100644 --- a/fooder/test/conftest.py +++ b/fooder/test/conftest.py @@ -1,6 +1,4 @@ import os -import pytest -import pytest_asyncio # --------------------------------------------------------------------------- # # Supply minimal dummy env-vars *before* any of our modules are imported. # @@ -16,19 +14,4 @@ os.environ.update( } ) -from fooder.db import DatabaseSessionManager -from fooder.domain import Base -from fooder.settings import settings - - -@pytest.fixture(scope="session") -def db_manager() -> DatabaseSessionManager: - return DatabaseSessionManager(settings) - - -@pytest_asyncio.fixture(scope="session", autouse=True) -async def setup_database(db_manager: DatabaseSessionManager): - async with db_manager.connect() as conn: - await conn.run_sync(Base.metadata.create_all) - - yield +from .fixtures import * diff --git a/fooder/test/fixtures/__init__.py b/fooder/test/fixtures/__init__.py index 3a25d46..65d34dc 100644 --- a/fooder/test/fixtures/__init__.py +++ b/fooder/test/fixtures/__init__.py @@ -1,5 +1,6 @@ import pytest -from .dbssn import * +from .db import * +from .faker import * @pytest.fixture diff --git a/fooder/test/fixtures/db.py b/fooder/test/fixtures/db.py new file mode 100644 index 0000000..5df90f7 --- /dev/null +++ b/fooder/test/fixtures/db.py @@ -0,0 +1,72 @@ +import pytest +import pytest_asyncio + +from sqlalchemy import event +from sqlalchemy.ext.asyncio import AsyncSession +from fooder.db import DatabaseSessionManager +from fooder.domain import Base +from fooder.settings import settings +from fooder.repository.base import RepositoryBase +from sqlalchemy.orm import Mapped, mapped_column + + +class TestModel(Base): + __tablename__ = "test" + id: Mapped[int] = mapped_column(primary_key=True) + property: Mapped[str] + + +class TestRepository(RepositoryBase[TestModel]): + pass + + +@pytest.fixture +def test_repo(db_session): + return TestRepository(TestModel, db_session) + + +@pytest.fixture +def test_model_factory(): + def factory(id: int, property: str = "value") -> TestModel: + return TestModel(id=id, property=property) + + return factory + + +@pytest.fixture +def test_model(test_model_factory): + return test_model_factory(1) + + +@pytest.fixture(scope="session") +def db_manager() -> DatabaseSessionManager: + return DatabaseSessionManager(settings) + + +@pytest_asyncio.fixture +async def db_session(db_manager): + async with db_manager._engine.connect() as conn: + trans = await conn.begin() + session = AsyncSession(bind=conn) + + nested = await conn.begin_nested() + + @event.listens_for(session.sync_session, "after_transaction_end") + def restart_savepoint(sess, transaction): + nonlocal nested + if not nested.is_active: + nested = conn.sync_connection.begin_nested() + + try: + yield session + finally: + await session.close() + await trans.rollback() + + +@pytest_asyncio.fixture(scope="session", autouse=True) +async def setup_database(db_manager: DatabaseSessionManager): + async with db_manager.connect() as conn: + await conn.run_sync(Base.metadata.create_all) + + yield diff --git a/fooder/test/fixtures/dbssn.py b/fooder/test/fixtures/dbssn.py deleted file mode 100644 index 63284d9..0000000 --- a/fooder/test/fixtures/dbssn.py +++ /dev/null @@ -1,24 +0,0 @@ -import pytest_asyncio -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import event - - -@pytest_asyncio.fixture -async def db_session(db_manager): - async with db_manager._engine.connect() as conn: - trans = await conn.begin() - session = AsyncSession(bind=conn) - - nested = await conn.begin_nested() - - @event.listens_for(session.sync_session, "after_transaction_end") - def restart_savepoint(sess, transaction): - nonlocal nested - if not nested.is_active: - nested = conn.sync_connection.begin_nested() - - try: - yield session - finally: - await session.close() - await trans.rollback() diff --git a/fooder/test/fixtures/faker.py b/fooder/test/fixtures/faker.py new file mode 100644 index 0000000..699eb4f --- /dev/null +++ b/fooder/test/fixtures/faker.py @@ -0,0 +1,9 @@ +import pytest +from faker import Faker + + +@pytest.fixture(scope="session", autouse=True) +def faker(): + f = Faker(["en_US", "pl_PL"]) + f.seed_instance(1234) + return f diff --git a/fooder/test/repository/__init__.py b/fooder/test/repository/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fooder/test/repository/test_base.py b/fooder/test/repository/test_base.py new file mode 100644 index 0000000..f1b8d97 --- /dev/null +++ b/fooder/test/repository/test_base.py @@ -0,0 +1,96 @@ +import pytest +import faker +# ------------------------------------------------------------------ create --- + + +async def test_create_returns_object_with_id(test_repo, test_model): + model = await test_repo.create(test_model) + assert model.id is not None + assert model.property is not None + + +# -------------------------------------------------------------------- get ---- + + +async def test_get_returns_existing_record(test_repo, test_model): + created = await test_repo.create(test_model) + found = await test_repo.get(id=created.id) + assert found is not None + assert found.id == created.id + + +async def test_get_returns_none_for_missing_record(test_repo): + result = await test_repo.get(id=1) + assert result is None + + +async def test_get_by_field(test_repo, test_model_factory): + await test_repo.create(test_model_factory(1, "value")) + found = await test_repo.get(property="value") + assert found is not None + assert found.id == 1 + + +# ------------------------------------------------------------------- list ---- + + +async def test_list_returns_all_matching(test_repo, test_model_factory): + await test_repo.create(test_model_factory(1)) + await test_repo.create(test_model_factory(2)) + results = await test_repo.list() + modelnames = {u.id for u in results} + assert {1, 2}.issubset(modelnames) + + +async def test_list_with_filter(test_repo, test_model_factory): + await test_repo.create(test_model_factory(1, "value")) + await test_repo.create(test_model_factory(2, "value2")) + results = await test_repo.list(property="value") + assert len(results) == 1 + assert results[0].id == 1 + + +async def test_list_returns_empty_when_no_match(test_repo, test_model_factory): + await test_repo.create(test_model_factory(1, "value")) + results = await test_repo.list(property="value2") + assert results == [] + + +# ------------------------------------------------------------------ delete --- + + +async def test_delete_removes_record_and_returns_count(test_repo, test_model): + model = await test_repo.create(test_model) + count = await test_repo.delete(id=model.id) + assert count == 1 + assert await test_repo.get(id=model.id) is None + + +async def test_delete_returns_zero_when_nothing_matched(test_repo, test_model_factory): + count = await test_repo.delete(id=999999) + assert count == 0 + + +# ------------------------------------------------------------------ update --- + + +async def test_update_modifies_record_and_returns_count(test_repo, test_model): + model = await test_repo.create(test_model) + count = await test_repo.update({"id": model.id}, {"property": "value2"}) + assert count == 1 + refreshed = await test_repo.get(id=model.id) + assert refreshed is not None + assert refreshed.property == "value2" + + +async def test_update_returns_zero_when_nothing_matched(test_repo, test_model_factory): + count = await test_repo.update({"id": 999999}, {"property": "value"}) + assert count == 0 + + +# ------------------------------------------------------- _build_select ------ + + +def test_build_select_raises_for_unknown_field(test_repo): + with pytest.raises(ValueError, match="has no attribute 'nonexistent'"): + test_repo._build_select(nonexistent="value") diff --git a/requirements/local.txt b/requirements/local.txt index 88e1f55..2e90aa7 100644 --- a/requirements/local.txt +++ b/requirements/local.txt @@ -9,6 +9,7 @@ passlib[bcrypt] fastapi-users pytest pytest-asyncio +faker requests black flake8