# Source code for minos.aggregate.snapshots.repositories.memory

from __future__ import (
    annotations,
)

from contextlib import (
    suppress,
)
from functools import (
    cmp_to_key,
)
from operator import (
    attrgetter,
)
from typing import (
    AsyncIterator,
    Optional,
)
from uuid import (
    UUID,
)

from minos.common import (
    NULL_UUID,
    Inject,
    NotProvidedException,
)

from ...events import (
    EventEntry,
    EventRepository,
)
from ...exceptions import (
    AlreadyDeletedException,
    NotFoundException,
)
from ...queries import (
    _Condition,
    _Ordering,
)
from ...transactions import (
    TransactionEntry,
    TransactionRepository,
    TransactionStatus,
)
from ..entries import (
    SnapshotEntry,
)
from .abc import (
    SnapshotRepository,
)


class InMemorySnapshotRepository(SnapshotRepository):
    """InMemory Snapshot class.

    The snapshot provides a direct accessor to the ``RootEntity`` instances stored as events by the event repository
    class. This repository stores nothing by itself: every lookup replays the events read from the event repository,
    restricted to the (non-rejected) transaction chain being observed.
    """

    @Inject()
    def __init__(
        self,
        *args,
        event_repository: EventRepository,
        transaction_repository: TransactionRepository,
        **kwargs,
    ):
        """Initialize the repository.

        :param args: Additional positional arguments forwarded to the parent class.
        :param event_repository: Source of the events that are replayed to build the snapshots.
        :param transaction_repository: Used to read the status of the transactions of a chain.
        :param kwargs: Additional named arguments forwarded to the parent class.
        :raises NotProvidedException: If ``event_repository`` or ``transaction_repository`` is ``None``.
        """
        super().__init__(*args, **kwargs)

        if event_repository is None:
            raise NotProvidedException("An event repository instance is required.")

        if transaction_repository is None:
            raise NotProvidedException("A transaction repository instance is required.")

        self._event_repository = event_repository
        self._transaction_repository = transaction_repository

    async def _find_entries(
        self,
        name: str,
        condition: _Condition,
        ordering: Optional[_Ordering],
        limit: Optional[int],
        exclude_deleted: bool,
        **kwargs,
    ) -> AsyncIterator[SnapshotEntry]:
        """Find the snapshot entries of the ``name`` entities that satisfy ``condition``.

        :param name: Class name of the entities to be searched.
        :param condition: Filter evaluated against each rebuilt instance, or against the entry itself when the
            entity is deleted.
        :param ordering: Optional ordering, given by its ``by`` attribute path and its ``reverse`` flag.
        :param limit: Optional maximum number of entries to be yielded.
        :param exclude_deleted: If ``True``, deleted entities are never yielded.
        :param kwargs: Additional named arguments. ``transaction`` selects the observed transaction chain; the rest
            are forwarded to the instance building.
        :return: An async iterator of ``SnapshotEntry`` instances.
        """
        # The transaction chain is the same for every identifier, so resolve it only once.
        transaction = kwargs.pop("transaction", None)
        transaction_uuids = await self._get_transaction_uuids(transaction)

        uuids = {v.uuid async for v in self._event_repository.select(name=name)}

        entries = list()
        for uuid in uuids:
            event_entries = await self._get_event_entries(name, uuid, transaction_uuids)
            if not event_entries:
                # All the events of this entity belong to transactions outside the observed chain, so it is not
                # visible from here (building it would fail on an empty event list).
                continue
            entry = self._build_instance(event_entries, **kwargs)

            try:
                instance = entry.build()
            except AlreadyDeletedException:
                # noinspection PyTypeChecker
                if not exclude_deleted and condition.evaluate(entry):
                    entries.append(entry)
            else:
                if condition.evaluate(instance):
                    entries.append(entry)

        if ordering is not None:
            getter = attrgetter(ordering.by)

            def _sort_key(entry: SnapshotEntry):
                # Deleted entities cannot be rebuilt, so the ordering attribute is read from the entry itself.
                try:
                    return getter(entry.build())
                except AlreadyDeletedException:
                    return getter(entry)

            # A key function builds every entry once, instead of twice per comparison as ``cmp_to_key`` would.
            entries.sort(key=_sort_key, reverse=ordering.reverse)

        if limit is not None:
            entries = entries[:limit]

        for entry in entries:
            yield entry

    # noinspection PyMethodOverriding
    async def _get(
        self, name: str, uuid: UUID, transaction: Optional[TransactionEntry] = None, **kwargs
    ) -> SnapshotEntry:
        """Build the snapshot entry of a single entity, as observed from ``transaction``.

        :param name: Class name of the entity.
        :param uuid: Identifier of the entity.
        :param transaction: Observed transaction. If ``None``, only events whose ``transaction_uuid`` is
            ``NULL_UUID`` are considered.
        :param kwargs: Additional named arguments forwarded to the instance building.
        :return: The ``SnapshotEntry`` of the entity.
        :raises NotFoundException: If the entity has no events visible from the observed transaction chain.
        """
        transaction_uuids = await self._get_transaction_uuids(transaction)
        entries = await self._get_event_entries(name, uuid, transaction_uuids)
        if not entries:
            raise NotFoundException(f"Not found any entries for the {uuid!r} id.")
        return self._build_instance(entries, **kwargs)

    async def _get_transaction_uuids(self, transaction: Optional[TransactionEntry]) -> tuple[UUID, ...]:
        """Compute the transaction identifiers whose events are visible from ``transaction``.

        Trailing identifiers of the chain whose transaction is ``REJECTED`` are dropped, but the first identifier
        is always kept.

        :param transaction: The transaction to be resolved, or ``None``.
        :return: A tuple of transaction identifiers; ``(NULL_UUID,)`` when ``transaction`` is ``None``.
        """
        if transaction is None:
            return (NULL_UUID,)

        transaction_uuids = tuple(await transaction.uuids)
        while len(transaction_uuids) > 1:
            last = await self._transaction_repository.get(uuid=transaction_uuids[-1])
            if last.status != TransactionStatus.REJECTED:
                break
            transaction_uuids = transaction_uuids[:-1]
        return transaction_uuids

    async def _get_event_entries(
        self, name: str, uuid: UUID, transaction_uuids: tuple[UUID, ...]
    ) -> list[EventEntry]:
        """Get the events of an entity that are visible from the given transaction chain.

        When the events come from more than one transaction, only one event per version is kept: the one whose
        transaction is placed latest in ``transaction_uuids``.

        :param name: Class name of the entity.
        :param uuid: Identifier of the entity.
        :param transaction_uuids: The visible transaction chain.
        :return: The selected events sorted by ascending version (empty if there is none).
        """
        # Position of each transaction in the chain; the first occurrence wins, matching ``tuple.index``.
        positions: dict[UUID, int] = {}
        for index, transaction_uuid in enumerate(transaction_uuids):
            positions.setdefault(transaction_uuid, index)

        entries = [
            v async for v in self._event_repository.select(name=name, uuid=uuid) if v.transaction_uuid in positions
        ]
        entries.sort(key=lambda e: (e.version, positions[e.transaction_uuid]))

        if len({e.transaction_uuid for e in entries}) > 1:
            # Walk backwards from the newest event. For each version, the event of the latest transaction comes
            # first and is kept; events with the same version from earlier transactions are skipped.
            new = [entries.pop()]
            for e in reversed(entries):
                if e.version < new[-1].version:
                    new.append(e)
            entries = list(reversed(new))

        return entries

    @staticmethod
    def _build_instance(entries: list[EventEntry], **kwargs) -> SnapshotEntry:
        """Build a snapshot entry by replaying the given events.

        :param entries: Non-empty list of events, sorted by ascending version.
        :param kwargs: Additional named arguments forwarded to ``from_diff``.
        :return: A ``SnapshotEntry`` built from the last event when it is a deletion, or from the entity rebuilt by
            applying every event in order otherwise.
        """
        if entries[-1].action.is_delete:
            return SnapshotEntry.from_event_entry(entries[-1])

        cls = entries[0].type_
        instance = cls.from_diff(entries[0].event, **kwargs)
        for entry in entries[1:]:
            instance.apply_diff(entry.event)

        return SnapshotEntry.from_root_entity(instance)

    async def _synchronize(self, **kwargs) -> None:
        """Do nothing: snapshots are rebuilt from the event repository on every lookup."""