[repository] begin implementation
This commit is contained in:
parent
bbbd124d78
commit
b010891ac7
15 changed files with 278 additions and 117 deletions
|
|
@ -1,7 +1,7 @@
|
|||
import contextlib
|
||||
from typing import AsyncIterator, AsyncGenerator
|
||||
|
||||
from fooder.settings import Settings, settings
|
||||
from .settings import Settings, settings
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncConnection,
|
||||
AsyncSession,
|
||||
|
|
|
|||
|
|
@ -8,8 +8,9 @@ class Base(DeclarativeBase):
|
|||
|
||||
|
||||
class CommonMixin:
|
||||
"""define a series of common elements that may be applied to mapped
|
||||
classes using this class as a mixin class."""
|
||||
"""
|
||||
CommonMixin for all common fields in projetc
|
||||
"""
|
||||
|
||||
@declared_attr.directive
|
||||
def __tablename__(cls) -> str:
|
||||
|
|
@ -20,3 +21,22 @@ class CommonMixin:
|
|||
return cls.__name__.lower() # type: ignore
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
|
||||
|
||||
class PasswordMixin:
|
||||
"""
|
||||
PasswordMixin for entities with password
|
||||
"""
|
||||
|
||||
hashed_password: Mapped[str]
|
||||
|
||||
def set_password(self, password) -> None:
|
||||
"""set_password.
|
||||
|
||||
:param password:
|
||||
:rtype: None
|
||||
"""
|
||||
|
||||
from ..auth import password_helper
|
||||
|
||||
self.hashed_password = password_helper.hash(password)
|
||||
|
|
|
|||
|
|
@ -1,74 +1,9 @@
|
|||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Mapped
|
||||
|
||||
from .base import Base, CommonMixin
|
||||
from .base import Base, CommonMixin, PasswordMixin
|
||||
|
||||
|
||||
class User(Base, CommonMixin):
|
||||
class User(Base, CommonMixin, PasswordMixin):
|
||||
"""Product."""
|
||||
|
||||
username: Mapped[str]
|
||||
hashed_password: Mapped[str]
|
||||
|
||||
def set_password(self, password) -> None:
|
||||
"""set_password.
|
||||
|
||||
:param password:
|
||||
:rtype: None
|
||||
"""
|
||||
from ..auth import password_helper
|
||||
|
||||
self.hashed_password = password_helper.hash(password)
|
||||
|
||||
@classmethod
|
||||
async def get_by_username(
|
||||
cls, session: AsyncSession, username: str
|
||||
) -> Optional["User"]:
|
||||
"""get_by_username.
|
||||
|
||||
:param session:
|
||||
:type session: AsyncSession
|
||||
:param username:
|
||||
:type username: str
|
||||
:rtype: Optional["User"]
|
||||
"""
|
||||
query = select(cls).filter(cls.username == username)
|
||||
return await session.scalar(query.order_by(cls.id))
|
||||
|
||||
@classmethod
|
||||
async def get(cls, session: AsyncSession, id: int) -> Optional["User"]:
|
||||
"""get_by_username.
|
||||
|
||||
:param session:
|
||||
:type session: AsyncSession
|
||||
:param id:
|
||||
:type id: int
|
||||
:rtype: Optional["User"]
|
||||
"""
|
||||
query = select(cls).filter(cls.id == id)
|
||||
return await session.scalar(query.order_by(cls.id))
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls, session: AsyncSession, username: str, password: str
|
||||
) -> "User":
|
||||
"""create.
|
||||
|
||||
:param session:
|
||||
:type session: AsyncSession
|
||||
:param username:
|
||||
:type username: str
|
||||
:param password:
|
||||
:type password: str
|
||||
:rtype: "User"
|
||||
"""
|
||||
exsisting_user = await User.get_by_username(session, username)
|
||||
assert exsisting_user is None, "user already exists"
|
||||
user = cls(username=username)
|
||||
user.set_password(password)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
return user
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
from .repository import Repository
|
||||
|
|
@ -1,6 +1,59 @@
|
|||
from typing import TypeVar, Generic, Type, Any, Sequence
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update as sa_update, delete as sa_delete
|
||||
from sqlalchemy.sql import Select
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class RepositoryBase:
|
||||
def __init__(self, dbssn: AsyncSession):
|
||||
self.dbssn = dbssn
|
||||
class RepositoryBase(Generic[T]):
|
||||
def __init__(self, model: Type[T], session: AsyncSession):
|
||||
self.model = model
|
||||
self.session = session
|
||||
|
||||
def _build_select(self, **filters: Any) -> Select[tuple[T]]:
|
||||
stmt = select(self.model)
|
||||
|
||||
for field, value in filters.items():
|
||||
column = getattr(self.model, field, None)
|
||||
if column is None:
|
||||
raise ValueError(f"{self.model.__name__} has no attribute '{field}'")
|
||||
stmt = stmt.where(column == value)
|
||||
|
||||
return stmt
|
||||
|
||||
async def get(self, **filters: Any) -> T | None:
|
||||
stmt = self._build_select(**filters)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list(self, **filters: Any) -> Sequence[T]:
|
||||
stmt = self._build_select(**filters)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
async def create(self, obj: T) -> T:
|
||||
self.session.add(obj)
|
||||
await self.session.flush()
|
||||
await self.session.refresh(obj)
|
||||
return obj
|
||||
|
||||
async def delete(self, **filters: Any) -> int:
|
||||
stmt = sa_delete(self.model)
|
||||
|
||||
for field, value in filters.items():
|
||||
column = getattr(self.model, field)
|
||||
stmt = stmt.where(column == value)
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
return result.rowcount if result.rowcount != -1 else 0
|
||||
|
||||
async def update(self, filters: dict[str, Any], values: dict[str, Any]) -> int:
|
||||
stmt = sa_update(self.model)
|
||||
|
||||
for field, value in filters.items():
|
||||
stmt = stmt.where(getattr(self.model, field) == value)
|
||||
|
||||
stmt = stmt.values(**values)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.rowcount if result.rowcount != -1 else 0
|
||||
|
|
|
|||
8
fooder/repository/repository.py
Normal file
8
fooder/repository/repository.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from .user import UserRepository
|
||||
from ..domain import User
|
||||
|
||||
|
||||
class Repository:
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.user = UserRepository(User, session)
|
||||
|
|
@ -1,2 +1,8 @@
|
|||
class UserRepository:
|
||||
from sqlalchemy import select, Select
|
||||
|
||||
from .base import RepositoryBase
|
||||
from ..domain import User
|
||||
|
||||
|
||||
class UserRepository(RepositoryBase[User]):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,6 +1,4 @@
|
|||
import os
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Supply minimal dummy env-vars *before* any of our modules are imported. #
|
||||
|
|
@ -16,19 +14,4 @@ os.environ.update(
|
|||
}
|
||||
)
|
||||
|
||||
from fooder.db import DatabaseSessionManager
|
||||
from fooder.domain import Base
|
||||
from fooder.settings import settings
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def db_manager() -> DatabaseSessionManager:
|
||||
return DatabaseSessionManager(settings)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", autouse=True)
|
||||
async def setup_database(db_manager: DatabaseSessionManager):
|
||||
async with db_manager.connect() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield
|
||||
from .fixtures import *
|
||||
|
|
|
|||
3
fooder/test/fixtures/__init__.py
vendored
3
fooder/test/fixtures/__init__.py
vendored
|
|
@ -1,5 +1,6 @@
|
|||
import pytest
|
||||
from .dbssn import *
|
||||
from .db import *
|
||||
from .faker import *
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
72
fooder/test/fixtures/db.py
vendored
Normal file
72
fooder/test/fixtures/db.py
vendored
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from fooder.db import DatabaseSessionManager
|
||||
from fooder.domain import Base
|
||||
from fooder.settings import settings
|
||||
from fooder.repository.base import RepositoryBase
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
||||
class TestModel(Base):
|
||||
__tablename__ = "test"
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
property: Mapped[str]
|
||||
|
||||
|
||||
class TestRepository(RepositoryBase[TestModel]):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_repo(db_session):
|
||||
return TestRepository(TestModel, db_session)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_model_factory():
|
||||
def factory(id: int, property: str = "value") -> TestModel:
|
||||
return TestModel(id=id, property=property)
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_model(test_model_factory):
|
||||
return test_model_factory(1)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def db_manager() -> DatabaseSessionManager:
|
||||
return DatabaseSessionManager(settings)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_session(db_manager):
|
||||
async with db_manager._engine.connect() as conn:
|
||||
trans = await conn.begin()
|
||||
session = AsyncSession(bind=conn)
|
||||
|
||||
nested = await conn.begin_nested()
|
||||
|
||||
@event.listens_for(session.sync_session, "after_transaction_end")
|
||||
def restart_savepoint(sess, transaction):
|
||||
nonlocal nested
|
||||
if not nested.is_active:
|
||||
nested = conn.sync_connection.begin_nested()
|
||||
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
await trans.rollback()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", autouse=True)
|
||||
async def setup_database(db_manager: DatabaseSessionManager):
|
||||
async with db_manager.connect() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield
|
||||
24
fooder/test/fixtures/dbssn.py
vendored
24
fooder/test/fixtures/dbssn.py
vendored
|
|
@ -1,24 +0,0 @@
|
|||
import pytest_asyncio
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import event
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_session(db_manager):
|
||||
async with db_manager._engine.connect() as conn:
|
||||
trans = await conn.begin()
|
||||
session = AsyncSession(bind=conn)
|
||||
|
||||
nested = await conn.begin_nested()
|
||||
|
||||
@event.listens_for(session.sync_session, "after_transaction_end")
|
||||
def restart_savepoint(sess, transaction):
|
||||
nonlocal nested
|
||||
if not nested.is_active:
|
||||
nested = conn.sync_connection.begin_nested()
|
||||
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
await trans.rollback()
|
||||
9
fooder/test/fixtures/faker.py
vendored
Normal file
9
fooder/test/fixtures/faker.py
vendored
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def faker():
|
||||
f = Faker(["en_US", "pl_PL"])
|
||||
f.seed_instance(1234)
|
||||
return f
|
||||
0
fooder/test/repository/__init__.py
Normal file
0
fooder/test/repository/__init__.py
Normal file
96
fooder/test/repository/test_base.py
Normal file
96
fooder/test/repository/test_base.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
import pytest
|
||||
import faker
|
||||
# ------------------------------------------------------------------ create ---
|
||||
|
||||
|
||||
async def test_create_returns_object_with_id(test_repo, test_model):
|
||||
model = await test_repo.create(test_model)
|
||||
assert model.id is not None
|
||||
assert model.property is not None
|
||||
|
||||
|
||||
# -------------------------------------------------------------------- get ----
|
||||
|
||||
|
||||
async def test_get_returns_existing_record(test_repo, test_model):
|
||||
created = await test_repo.create(test_model)
|
||||
found = await test_repo.get(id=created.id)
|
||||
assert found is not None
|
||||
assert found.id == created.id
|
||||
|
||||
|
||||
async def test_get_returns_none_for_missing_record(test_repo):
|
||||
result = await test_repo.get(id=1)
|
||||
assert result is None
|
||||
|
||||
|
||||
async def test_get_by_field(test_repo, test_model_factory):
|
||||
await test_repo.create(test_model_factory(1, "value"))
|
||||
found = await test_repo.get(property="value")
|
||||
assert found is not None
|
||||
assert found.id == 1
|
||||
|
||||
|
||||
# ------------------------------------------------------------------- list ----
|
||||
|
||||
|
||||
async def test_list_returns_all_matching(test_repo, test_model_factory):
|
||||
await test_repo.create(test_model_factory(1))
|
||||
await test_repo.create(test_model_factory(2))
|
||||
results = await test_repo.list()
|
||||
modelnames = {u.id for u in results}
|
||||
assert {1, 2}.issubset(modelnames)
|
||||
|
||||
|
||||
async def test_list_with_filter(test_repo, test_model_factory):
|
||||
await test_repo.create(test_model_factory(1, "value"))
|
||||
await test_repo.create(test_model_factory(2, "value2"))
|
||||
results = await test_repo.list(property="value")
|
||||
assert len(results) == 1
|
||||
assert results[0].id == 1
|
||||
|
||||
|
||||
async def test_list_returns_empty_when_no_match(test_repo, test_model_factory):
|
||||
await test_repo.create(test_model_factory(1, "value"))
|
||||
results = await test_repo.list(property="value2")
|
||||
assert results == []
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ delete ---
|
||||
|
||||
|
||||
async def test_delete_removes_record_and_returns_count(test_repo, test_model):
|
||||
model = await test_repo.create(test_model)
|
||||
count = await test_repo.delete(id=model.id)
|
||||
assert count == 1
|
||||
assert await test_repo.get(id=model.id) is None
|
||||
|
||||
|
||||
async def test_delete_returns_zero_when_nothing_matched(test_repo, test_model_factory):
|
||||
count = await test_repo.delete(id=999999)
|
||||
assert count == 0
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ update ---
|
||||
|
||||
|
||||
async def test_update_modifies_record_and_returns_count(test_repo, test_model):
|
||||
model = await test_repo.create(test_model)
|
||||
count = await test_repo.update({"id": model.id}, {"property": "value2"})
|
||||
assert count == 1
|
||||
refreshed = await test_repo.get(id=model.id)
|
||||
assert refreshed is not None
|
||||
assert refreshed.property == "value2"
|
||||
|
||||
|
||||
async def test_update_returns_zero_when_nothing_matched(test_repo, test_model_factory):
|
||||
count = await test_repo.update({"id": 999999}, {"property": "value"})
|
||||
assert count == 0
|
||||
|
||||
|
||||
# ------------------------------------------------------- _build_select ------
|
||||
|
||||
|
||||
def test_build_select_raises_for_unknown_field(test_repo):
|
||||
with pytest.raises(ValueError, match="has no attribute 'nonexistent'"):
|
||||
test_repo._build_select(nonexistent="value")
|
||||
|
|
@ -9,6 +9,7 @@ passlib[bcrypt]
|
|||
fastapi-users
|
||||
pytest
|
||||
pytest-asyncio
|
||||
faker
|
||||
requests
|
||||
black
|
||||
flake8
|
||||
|
|
|
|||
Loading…
Reference in a new issue