[repository] begin implementation

This commit is contained in:
Piotr Domański 2026-04-02 20:38:14 +02:00
parent bbbd124d78
commit b010891ac7
15 changed files with 278 additions and 117 deletions

View file

@ -1,7 +1,7 @@
import contextlib import contextlib
from typing import AsyncIterator, AsyncGenerator from typing import AsyncIterator, AsyncGenerator
from fooder.settings import Settings, settings from .settings import Settings, settings
from sqlalchemy.ext.asyncio import ( from sqlalchemy.ext.asyncio import (
AsyncConnection, AsyncConnection,
AsyncSession, AsyncSession,

View file

@ -8,8 +8,9 @@ class Base(DeclarativeBase):
class CommonMixin: class CommonMixin:
"""define a series of common elements that may be applied to mapped """
classes using this class as a mixin class.""" CommonMixin for all common fields in projetc
"""
@declared_attr.directive @declared_attr.directive
def __tablename__(cls) -> str: def __tablename__(cls) -> str:
@ -20,3 +21,22 @@ class CommonMixin:
return cls.__name__.lower() # type: ignore return cls.__name__.lower() # type: ignore
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
class PasswordMixin:
"""
PasswordMixin for entities with password
"""
hashed_password: Mapped[str]
def set_password(self, password) -> None:
"""set_password.
:param password:
:rtype: None
"""
from ..auth import password_helper
self.hashed_password = password_helper.hash(password)

View file

@ -1,74 +1,9 @@
from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped from sqlalchemy.orm import Mapped
from .base import Base, CommonMixin from .base import Base, CommonMixin, PasswordMixin
class User(Base, CommonMixin): class User(Base, CommonMixin, PasswordMixin):
"""Product.""" """Product."""
username: Mapped[str] username: Mapped[str]
hashed_password: Mapped[str]
def set_password(self, password) -> None:
"""set_password.
:param password:
:rtype: None
"""
from ..auth import password_helper
self.hashed_password = password_helper.hash(password)
@classmethod
async def get_by_username(
cls, session: AsyncSession, username: str
) -> Optional["User"]:
"""get_by_username.
:param session:
:type session: AsyncSession
:param username:
:type username: str
:rtype: Optional["User"]
"""
query = select(cls).filter(cls.username == username)
return await session.scalar(query.order_by(cls.id))
@classmethod
async def get(cls, session: AsyncSession, id: int) -> Optional["User"]:
"""get_by_username.
:param session:
:type session: AsyncSession
:param id:
:type id: int
:rtype: Optional["User"]
"""
query = select(cls).filter(cls.id == id)
return await session.scalar(query.order_by(cls.id))
@classmethod
async def create(
cls, session: AsyncSession, username: str, password: str
) -> "User":
"""create.
:param session:
:type session: AsyncSession
:param username:
:type username: str
:param password:
:type password: str
:rtype: "User"
"""
exsisting_user = await User.get_by_username(session, username)
assert exsisting_user is None, "user already exists"
user = cls(username=username)
user.set_password(password)
session.add(user)
await session.flush()
return user

View file

@ -0,0 +1 @@
from .repository import Repository

View file

@ -1,6 +1,59 @@
from typing import TypeVar, Generic, Type, Any, Sequence
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update as sa_update, delete as sa_delete
from sqlalchemy.sql import Select
T = TypeVar("T")
class RepositoryBase: class RepositoryBase(Generic[T]):
def __init__(self, dbssn: AsyncSession): def __init__(self, model: Type[T], session: AsyncSession):
self.dbssn = dbssn self.model = model
self.session = session
def _build_select(self, **filters: Any) -> Select[tuple[T]]:
stmt = select(self.model)
for field, value in filters.items():
column = getattr(self.model, field, None)
if column is None:
raise ValueError(f"{self.model.__name__} has no attribute '{field}'")
stmt = stmt.where(column == value)
return stmt
async def get(self, **filters: Any) -> T | None:
stmt = self._build_select(**filters)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def list(self, **filters: Any) -> Sequence[T]:
stmt = self._build_select(**filters)
result = await self.session.execute(stmt)
return result.scalars().all()
async def create(self, obj: T) -> T:
self.session.add(obj)
await self.session.flush()
await self.session.refresh(obj)
return obj
async def delete(self, **filters: Any) -> int:
stmt = sa_delete(self.model)
for field, value in filters.items():
column = getattr(self.model, field)
stmt = stmt.where(column == value)
result = await self.session.execute(stmt)
return result.rowcount if result.rowcount != -1 else 0
async def update(self, filters: dict[str, Any], values: dict[str, Any]) -> int:
stmt = sa_update(self.model)
for field, value in filters.items():
stmt = stmt.where(getattr(self.model, field) == value)
stmt = stmt.values(**values)
result = await self.session.execute(stmt)
return result.rowcount if result.rowcount != -1 else 0

View file

@ -0,0 +1,8 @@
from sqlalchemy.ext.asyncio import AsyncSession
from .user import UserRepository
from ..domain import User
class Repository:
def __init__(self, session: AsyncSession):
self.user = UserRepository(User, session)

View file

@ -1,2 +1,8 @@
class UserRepository: from sqlalchemy import select, Select
from .base import RepositoryBase
from ..domain import User
class UserRepository(RepositoryBase[User]):
pass pass

View file

@ -1,6 +1,4 @@
import os import os
import pytest
import pytest_asyncio
# --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- #
# Supply minimal dummy env-vars *before* any of our modules are imported. # # Supply minimal dummy env-vars *before* any of our modules are imported. #
@ -16,19 +14,4 @@ os.environ.update(
} }
) )
from fooder.db import DatabaseSessionManager from .fixtures import *
from fooder.domain import Base
from fooder.settings import settings
@pytest.fixture(scope="session")
def db_manager() -> DatabaseSessionManager:
return DatabaseSessionManager(settings)
@pytest_asyncio.fixture(scope="session", autouse=True)
async def setup_database(db_manager: DatabaseSessionManager):
async with db_manager.connect() as conn:
await conn.run_sync(Base.metadata.create_all)
yield

View file

@ -1,5 +1,6 @@
import pytest import pytest
from .dbssn import * from .db import *
from .faker import *
@pytest.fixture @pytest.fixture

72
fooder/test/fixtures/db.py vendored Normal file
View file

@ -0,0 +1,72 @@
import pytest
import pytest_asyncio
from sqlalchemy import event
from sqlalchemy.ext.asyncio import AsyncSession
from fooder.db import DatabaseSessionManager
from fooder.domain import Base
from fooder.settings import settings
from fooder.repository.base import RepositoryBase
from sqlalchemy.orm import Mapped, mapped_column
class TestModel(Base):
__tablename__ = "test"
id: Mapped[int] = mapped_column(primary_key=True)
property: Mapped[str]
class TestRepository(RepositoryBase[TestModel]):
pass
@pytest.fixture
def test_repo(db_session):
return TestRepository(TestModel, db_session)
@pytest.fixture
def test_model_factory():
def factory(id: int, property: str = "value") -> TestModel:
return TestModel(id=id, property=property)
return factory
@pytest.fixture
def test_model(test_model_factory):
return test_model_factory(1)
@pytest.fixture(scope="session")
def db_manager() -> DatabaseSessionManager:
return DatabaseSessionManager(settings)
@pytest_asyncio.fixture
async def db_session(db_manager):
async with db_manager._engine.connect() as conn:
trans = await conn.begin()
session = AsyncSession(bind=conn)
nested = await conn.begin_nested()
@event.listens_for(session.sync_session, "after_transaction_end")
def restart_savepoint(sess, transaction):
nonlocal nested
if not nested.is_active:
nested = conn.sync_connection.begin_nested()
try:
yield session
finally:
await session.close()
await trans.rollback()
@pytest_asyncio.fixture(scope="session", autouse=True)
async def setup_database(db_manager: DatabaseSessionManager):
async with db_manager.connect() as conn:
await conn.run_sync(Base.metadata.create_all)
yield

View file

@ -1,24 +0,0 @@
import pytest_asyncio
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import event
@pytest_asyncio.fixture
async def db_session(db_manager):
async with db_manager._engine.connect() as conn:
trans = await conn.begin()
session = AsyncSession(bind=conn)
nested = await conn.begin_nested()
@event.listens_for(session.sync_session, "after_transaction_end")
def restart_savepoint(sess, transaction):
nonlocal nested
if not nested.is_active:
nested = conn.sync_connection.begin_nested()
try:
yield session
finally:
await session.close()
await trans.rollback()

9
fooder/test/fixtures/faker.py vendored Normal file
View file

@ -0,0 +1,9 @@
import pytest
from faker import Faker
@pytest.fixture(scope="session", autouse=True)
def faker():
f = Faker(["en_US", "pl_PL"])
f.seed_instance(1234)
return f

View file

View file

@ -0,0 +1,96 @@
import pytest
import faker
# ------------------------------------------------------------------ create ---
async def test_create_returns_object_with_id(test_repo, test_model):
model = await test_repo.create(test_model)
assert model.id is not None
assert model.property is not None
# -------------------------------------------------------------------- get ----
async def test_get_returns_existing_record(test_repo, test_model):
created = await test_repo.create(test_model)
found = await test_repo.get(id=created.id)
assert found is not None
assert found.id == created.id
async def test_get_returns_none_for_missing_record(test_repo):
result = await test_repo.get(id=1)
assert result is None
async def test_get_by_field(test_repo, test_model_factory):
await test_repo.create(test_model_factory(1, "value"))
found = await test_repo.get(property="value")
assert found is not None
assert found.id == 1
# ------------------------------------------------------------------- list ----
async def test_list_returns_all_matching(test_repo, test_model_factory):
await test_repo.create(test_model_factory(1))
await test_repo.create(test_model_factory(2))
results = await test_repo.list()
modelnames = {u.id for u in results}
assert {1, 2}.issubset(modelnames)
async def test_list_with_filter(test_repo, test_model_factory):
await test_repo.create(test_model_factory(1, "value"))
await test_repo.create(test_model_factory(2, "value2"))
results = await test_repo.list(property="value")
assert len(results) == 1
assert results[0].id == 1
async def test_list_returns_empty_when_no_match(test_repo, test_model_factory):
await test_repo.create(test_model_factory(1, "value"))
results = await test_repo.list(property="value2")
assert results == []
# ------------------------------------------------------------------ delete ---
async def test_delete_removes_record_and_returns_count(test_repo, test_model):
model = await test_repo.create(test_model)
count = await test_repo.delete(id=model.id)
assert count == 1
assert await test_repo.get(id=model.id) is None
async def test_delete_returns_zero_when_nothing_matched(test_repo, test_model_factory):
count = await test_repo.delete(id=999999)
assert count == 0
# ------------------------------------------------------------------ update ---
async def test_update_modifies_record_and_returns_count(test_repo, test_model):
model = await test_repo.create(test_model)
count = await test_repo.update({"id": model.id}, {"property": "value2"})
assert count == 1
refreshed = await test_repo.get(id=model.id)
assert refreshed is not None
assert refreshed.property == "value2"
async def test_update_returns_zero_when_nothing_matched(test_repo, test_model_factory):
count = await test_repo.update({"id": 999999}, {"property": "value"})
assert count == 0
# ------------------------------------------------------- _build_select ------
def test_build_select_raises_for_unknown_field(test_repo):
with pytest.raises(ValueError, match="has no attribute 'nonexistent'"):
test_repo._build_select(nonexistent="value")

View file

@ -9,6 +9,7 @@ passlib[bcrypt]
fastapi-users fastapi-users
pytest pytest
pytest-asyncio pytest-asyncio
faker
requests requests
black black
flake8 flake8