[repository] begin implementation
This commit is contained in:
parent
bbbd124d78
commit
b010891ac7
15 changed files with 278 additions and 117 deletions
|
|
@ -1,7 +1,7 @@
|
||||||
import contextlib
|
import contextlib
|
||||||
from typing import AsyncIterator, AsyncGenerator
|
from typing import AsyncIterator, AsyncGenerator
|
||||||
|
|
||||||
from fooder.settings import Settings, settings
|
from .settings import Settings, settings
|
||||||
from sqlalchemy.ext.asyncio import (
|
from sqlalchemy.ext.asyncio import (
|
||||||
AsyncConnection,
|
AsyncConnection,
|
||||||
AsyncSession,
|
AsyncSession,
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,9 @@ class Base(DeclarativeBase):
|
||||||
|
|
||||||
|
|
||||||
class CommonMixin:
|
class CommonMixin:
|
||||||
"""define a series of common elements that may be applied to mapped
|
"""
|
||||||
classes using this class as a mixin class."""
|
CommonMixin for all common fields in projetc
|
||||||
|
"""
|
||||||
|
|
||||||
@declared_attr.directive
|
@declared_attr.directive
|
||||||
def __tablename__(cls) -> str:
|
def __tablename__(cls) -> str:
|
||||||
|
|
@ -20,3 +21,22 @@ class CommonMixin:
|
||||||
return cls.__name__.lower() # type: ignore
|
return cls.__name__.lower() # type: ignore
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
|
|
||||||
|
|
||||||
|
class PasswordMixin:
|
||||||
|
"""
|
||||||
|
PasswordMixin for entities with password
|
||||||
|
"""
|
||||||
|
|
||||||
|
hashed_password: Mapped[str]
|
||||||
|
|
||||||
|
def set_password(self, password) -> None:
|
||||||
|
"""set_password.
|
||||||
|
|
||||||
|
:param password:
|
||||||
|
:rtype: None
|
||||||
|
"""
|
||||||
|
|
||||||
|
from ..auth import password_helper
|
||||||
|
|
||||||
|
self.hashed_password = password_helper.hash(password)
|
||||||
|
|
|
||||||
|
|
@ -1,74 +1,9 @@
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.orm import Mapped
|
from sqlalchemy.orm import Mapped
|
||||||
|
|
||||||
from .base import Base, CommonMixin
|
from .base import Base, CommonMixin, PasswordMixin
|
||||||
|
|
||||||
|
|
||||||
class User(Base, CommonMixin):
|
class User(Base, CommonMixin, PasswordMixin):
|
||||||
"""Product."""
|
"""Product."""
|
||||||
|
|
||||||
username: Mapped[str]
|
username: Mapped[str]
|
||||||
hashed_password: Mapped[str]
|
|
||||||
|
|
||||||
def set_password(self, password) -> None:
|
|
||||||
"""set_password.
|
|
||||||
|
|
||||||
:param password:
|
|
||||||
:rtype: None
|
|
||||||
"""
|
|
||||||
from ..auth import password_helper
|
|
||||||
|
|
||||||
self.hashed_password = password_helper.hash(password)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get_by_username(
|
|
||||||
cls, session: AsyncSession, username: str
|
|
||||||
) -> Optional["User"]:
|
|
||||||
"""get_by_username.
|
|
||||||
|
|
||||||
:param session:
|
|
||||||
:type session: AsyncSession
|
|
||||||
:param username:
|
|
||||||
:type username: str
|
|
||||||
:rtype: Optional["User"]
|
|
||||||
"""
|
|
||||||
query = select(cls).filter(cls.username == username)
|
|
||||||
return await session.scalar(query.order_by(cls.id))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get(cls, session: AsyncSession, id: int) -> Optional["User"]:
|
|
||||||
"""get_by_username.
|
|
||||||
|
|
||||||
:param session:
|
|
||||||
:type session: AsyncSession
|
|
||||||
:param id:
|
|
||||||
:type id: int
|
|
||||||
:rtype: Optional["User"]
|
|
||||||
"""
|
|
||||||
query = select(cls).filter(cls.id == id)
|
|
||||||
return await session.scalar(query.order_by(cls.id))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def create(
|
|
||||||
cls, session: AsyncSession, username: str, password: str
|
|
||||||
) -> "User":
|
|
||||||
"""create.
|
|
||||||
|
|
||||||
:param session:
|
|
||||||
:type session: AsyncSession
|
|
||||||
:param username:
|
|
||||||
:type username: str
|
|
||||||
:param password:
|
|
||||||
:type password: str
|
|
||||||
:rtype: "User"
|
|
||||||
"""
|
|
||||||
exsisting_user = await User.get_by_username(session, username)
|
|
||||||
assert exsisting_user is None, "user already exists"
|
|
||||||
user = cls(username=username)
|
|
||||||
user.set_password(password)
|
|
||||||
session.add(user)
|
|
||||||
await session.flush()
|
|
||||||
return user
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
from .repository import Repository
|
||||||
|
|
@ -1,6 +1,59 @@
|
||||||
|
from typing import TypeVar, Generic, Type, Any, Sequence
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy import select, update as sa_update, delete as sa_delete
|
||||||
|
from sqlalchemy.sql import Select
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class RepositoryBase:
|
class RepositoryBase(Generic[T]):
|
||||||
def __init__(self, dbssn: AsyncSession):
|
def __init__(self, model: Type[T], session: AsyncSession):
|
||||||
self.dbssn = dbssn
|
self.model = model
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
def _build_select(self, **filters: Any) -> Select[tuple[T]]:
|
||||||
|
stmt = select(self.model)
|
||||||
|
|
||||||
|
for field, value in filters.items():
|
||||||
|
column = getattr(self.model, field, None)
|
||||||
|
if column is None:
|
||||||
|
raise ValueError(f"{self.model.__name__} has no attribute '{field}'")
|
||||||
|
stmt = stmt.where(column == value)
|
||||||
|
|
||||||
|
return stmt
|
||||||
|
|
||||||
|
async def get(self, **filters: Any) -> T | None:
|
||||||
|
stmt = self._build_select(**filters)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def list(self, **filters: Any) -> Sequence[T]:
|
||||||
|
stmt = self._build_select(**filters)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def create(self, obj: T) -> T:
|
||||||
|
self.session.add(obj)
|
||||||
|
await self.session.flush()
|
||||||
|
await self.session.refresh(obj)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
async def delete(self, **filters: Any) -> int:
|
||||||
|
stmt = sa_delete(self.model)
|
||||||
|
|
||||||
|
for field, value in filters.items():
|
||||||
|
column = getattr(self.model, field)
|
||||||
|
stmt = stmt.where(column == value)
|
||||||
|
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
return result.rowcount if result.rowcount != -1 else 0
|
||||||
|
|
||||||
|
async def update(self, filters: dict[str, Any], values: dict[str, Any]) -> int:
|
||||||
|
stmt = sa_update(self.model)
|
||||||
|
|
||||||
|
for field, value in filters.items():
|
||||||
|
stmt = stmt.where(getattr(self.model, field) == value)
|
||||||
|
|
||||||
|
stmt = stmt.values(**values)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
return result.rowcount if result.rowcount != -1 else 0
|
||||||
|
|
|
||||||
8
fooder/repository/repository.py
Normal file
8
fooder/repository/repository.py
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from .user import UserRepository
|
||||||
|
from ..domain import User
|
||||||
|
|
||||||
|
|
||||||
|
class Repository:
|
||||||
|
def __init__(self, session: AsyncSession):
|
||||||
|
self.user = UserRepository(User, session)
|
||||||
|
|
@ -1,2 +1,8 @@
|
||||||
class UserRepository:
|
from sqlalchemy import select, Select
|
||||||
|
|
||||||
|
from .base import RepositoryBase
|
||||||
|
from ..domain import User
|
||||||
|
|
||||||
|
|
||||||
|
class UserRepository(RepositoryBase[User]):
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,4 @@
|
||||||
import os
|
import os
|
||||||
import pytest
|
|
||||||
import pytest_asyncio
|
|
||||||
|
|
||||||
# --------------------------------------------------------------------------- #
|
# --------------------------------------------------------------------------- #
|
||||||
# Supply minimal dummy env-vars *before* any of our modules are imported. #
|
# Supply minimal dummy env-vars *before* any of our modules are imported. #
|
||||||
|
|
@ -16,19 +14,4 @@ os.environ.update(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
from fooder.db import DatabaseSessionManager
|
from .fixtures import *
|
||||||
from fooder.domain import Base
|
|
||||||
from fooder.settings import settings
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def db_manager() -> DatabaseSessionManager:
|
|
||||||
return DatabaseSessionManager(settings)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session", autouse=True)
|
|
||||||
async def setup_database(db_manager: DatabaseSessionManager):
|
|
||||||
async with db_manager.connect() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
|
||||||
|
|
||||||
yield
|
|
||||||
|
|
|
||||||
3
fooder/test/fixtures/__init__.py
vendored
3
fooder/test/fixtures/__init__.py
vendored
|
|
@ -1,5 +1,6 @@
|
||||||
import pytest
|
import pytest
|
||||||
from .dbssn import *
|
from .db import *
|
||||||
|
from .faker import *
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
||||||
72
fooder/test/fixtures/db.py
vendored
Normal file
72
fooder/test/fixtures/db.py
vendored
Normal file
|
|
@ -0,0 +1,72 @@
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from sqlalchemy import event
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from fooder.db import DatabaseSessionManager
|
||||||
|
from fooder.domain import Base
|
||||||
|
from fooder.settings import settings
|
||||||
|
from fooder.repository.base import RepositoryBase
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
|
||||||
|
class TestModel(Base):
|
||||||
|
__tablename__ = "test"
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
|
property: Mapped[str]
|
||||||
|
|
||||||
|
|
||||||
|
class TestRepository(RepositoryBase[TestModel]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_repo(db_session):
|
||||||
|
return TestRepository(TestModel, db_session)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_model_factory():
|
||||||
|
def factory(id: int, property: str = "value") -> TestModel:
|
||||||
|
return TestModel(id=id, property=property)
|
||||||
|
|
||||||
|
return factory
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_model(test_model_factory):
|
||||||
|
return test_model_factory(1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def db_manager() -> DatabaseSessionManager:
|
||||||
|
return DatabaseSessionManager(settings)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def db_session(db_manager):
|
||||||
|
async with db_manager._engine.connect() as conn:
|
||||||
|
trans = await conn.begin()
|
||||||
|
session = AsyncSession(bind=conn)
|
||||||
|
|
||||||
|
nested = await conn.begin_nested()
|
||||||
|
|
||||||
|
@event.listens_for(session.sync_session, "after_transaction_end")
|
||||||
|
def restart_savepoint(sess, transaction):
|
||||||
|
nonlocal nested
|
||||||
|
if not nested.is_active:
|
||||||
|
nested = conn.sync_connection.begin_nested()
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
await trans.rollback()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="session", autouse=True)
|
||||||
|
async def setup_database(db_manager: DatabaseSessionManager):
|
||||||
|
async with db_manager.connect() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
|
yield
|
||||||
24
fooder/test/fixtures/dbssn.py
vendored
24
fooder/test/fixtures/dbssn.py
vendored
|
|
@ -1,24 +0,0 @@
|
||||||
import pytest_asyncio
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy import event
|
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
|
||||||
async def db_session(db_manager):
|
|
||||||
async with db_manager._engine.connect() as conn:
|
|
||||||
trans = await conn.begin()
|
|
||||||
session = AsyncSession(bind=conn)
|
|
||||||
|
|
||||||
nested = await conn.begin_nested()
|
|
||||||
|
|
||||||
@event.listens_for(session.sync_session, "after_transaction_end")
|
|
||||||
def restart_savepoint(sess, transaction):
|
|
||||||
nonlocal nested
|
|
||||||
if not nested.is_active:
|
|
||||||
nested = conn.sync_connection.begin_nested()
|
|
||||||
|
|
||||||
try:
|
|
||||||
yield session
|
|
||||||
finally:
|
|
||||||
await session.close()
|
|
||||||
await trans.rollback()
|
|
||||||
9
fooder/test/fixtures/faker.py
vendored
Normal file
9
fooder/test/fixtures/faker.py
vendored
Normal file
|
|
@ -0,0 +1,9 @@
|
||||||
|
import pytest
|
||||||
|
from faker import Faker
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def faker():
|
||||||
|
f = Faker(["en_US", "pl_PL"])
|
||||||
|
f.seed_instance(1234)
|
||||||
|
return f
|
||||||
0
fooder/test/repository/__init__.py
Normal file
0
fooder/test/repository/__init__.py
Normal file
96
fooder/test/repository/test_base.py
Normal file
96
fooder/test/repository/test_base.py
Normal file
|
|
@ -0,0 +1,96 @@
|
||||||
|
import pytest
|
||||||
|
import faker
|
||||||
|
# ------------------------------------------------------------------ create ---
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_returns_object_with_id(test_repo, test_model):
|
||||||
|
model = await test_repo.create(test_model)
|
||||||
|
assert model.id is not None
|
||||||
|
assert model.property is not None
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------- get ----
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_returns_existing_record(test_repo, test_model):
|
||||||
|
created = await test_repo.create(test_model)
|
||||||
|
found = await test_repo.get(id=created.id)
|
||||||
|
assert found is not None
|
||||||
|
assert found.id == created.id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_returns_none_for_missing_record(test_repo):
|
||||||
|
result = await test_repo.get(id=1)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_by_field(test_repo, test_model_factory):
|
||||||
|
await test_repo.create(test_model_factory(1, "value"))
|
||||||
|
found = await test_repo.get(property="value")
|
||||||
|
assert found is not None
|
||||||
|
assert found.id == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------- list ----
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_returns_all_matching(test_repo, test_model_factory):
|
||||||
|
await test_repo.create(test_model_factory(1))
|
||||||
|
await test_repo.create(test_model_factory(2))
|
||||||
|
results = await test_repo.list()
|
||||||
|
modelnames = {u.id for u in results}
|
||||||
|
assert {1, 2}.issubset(modelnames)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_with_filter(test_repo, test_model_factory):
|
||||||
|
await test_repo.create(test_model_factory(1, "value"))
|
||||||
|
await test_repo.create(test_model_factory(2, "value2"))
|
||||||
|
results = await test_repo.list(property="value")
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].id == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_returns_empty_when_no_match(test_repo, test_model_factory):
|
||||||
|
await test_repo.create(test_model_factory(1, "value"))
|
||||||
|
results = await test_repo.list(property="value2")
|
||||||
|
assert results == []
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ delete ---
|
||||||
|
|
||||||
|
|
||||||
|
async def test_delete_removes_record_and_returns_count(test_repo, test_model):
|
||||||
|
model = await test_repo.create(test_model)
|
||||||
|
count = await test_repo.delete(id=model.id)
|
||||||
|
assert count == 1
|
||||||
|
assert await test_repo.get(id=model.id) is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_delete_returns_zero_when_nothing_matched(test_repo, test_model_factory):
|
||||||
|
count = await test_repo.delete(id=999999)
|
||||||
|
assert count == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ update ---
|
||||||
|
|
||||||
|
|
||||||
|
async def test_update_modifies_record_and_returns_count(test_repo, test_model):
|
||||||
|
model = await test_repo.create(test_model)
|
||||||
|
count = await test_repo.update({"id": model.id}, {"property": "value2"})
|
||||||
|
assert count == 1
|
||||||
|
refreshed = await test_repo.get(id=model.id)
|
||||||
|
assert refreshed is not None
|
||||||
|
assert refreshed.property == "value2"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_update_returns_zero_when_nothing_matched(test_repo, test_model_factory):
|
||||||
|
count = await test_repo.update({"id": 999999}, {"property": "value"})
|
||||||
|
assert count == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------- _build_select ------
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_select_raises_for_unknown_field(test_repo):
|
||||||
|
with pytest.raises(ValueError, match="has no attribute 'nonexistent'"):
|
||||||
|
test_repo._build_select(nonexistent="value")
|
||||||
|
|
@ -9,6 +9,7 @@ passlib[bcrypt]
|
||||||
fastapi-users
|
fastapi-users
|
||||||
pytest
|
pytest
|
||||||
pytest-asyncio
|
pytest-asyncio
|
||||||
|
faker
|
||||||
requests
|
requests
|
||||||
black
|
black
|
||||||
flake8
|
flake8
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue