"""Source code for ``minos.aggregate.entities.collections``."""

from __future__ import (
    annotations,
)

from operator import (
    attrgetter,
)
from typing import (
    Any,
    Iterable,
    Iterator,
    Optional,
    TypeVar,
    Union,
    get_args,
)
from uuid import (
    UUID,
)

from minos.common import (
    DataDecoder,
    DataEncoder,
    DeclarativeModel,
    Model,
    ModelType,
    SchemaEncoder,
)

from ..collections import (
    IncrementalSet,
    IncrementalSetDiff,
)

T = TypeVar("T", bound=Model)


class EntitySet(IncrementalSet[T]):
    """Entity set class.

    Entities are stored in ``data``, a dictionary keyed by the string form of each
    entity's ``uuid``.
    """

    data: dict[str, T]

    def __init__(self, data: Optional[Iterable[T]] = None, *args, **kwargs):
        """Initialize the entity set.

        :param data: The initial content. ``None`` produces an empty set, a ``dict``
            is used as given, and any other iterable is indexed by ``str(entity.uuid)``.
        :param args: Additional positional arguments forwarded to
            ``DeclarativeModel.__init__``.
        :param kwargs: Additional named arguments forwarded to
            ``DeclarativeModel.__init__``.
        """
        if data is None:
            data = {}
        elif not isinstance(data, dict):
            data = {str(entity.uuid): entity for entity in data}
        # The declarative initializer is invoked explicitly (not through ``super()``).
        DeclarativeModel.__init__(self, data, *args, **kwargs)

    def add(self, entity: T) -> None:
        """Add an entity.

        An entity already stored under the same identifier is replaced.

        :param entity: The entity to be added.
        :return: This method does not return anything.
        """
        self.data[str(entity.uuid)] = entity

    def discard(self, entity: Union[T, UUID]) -> None:
        """Discard an entity.

        Nothing happens if the entity is not stored, or if the given object exposes
        no ``uuid`` attribute (such an object cannot be a member of the set).

        :param entity: The entity to be discarded, or its identifier.
        :return: This method does not return anything.
        """
        if not isinstance(entity, UUID):
            if not hasattr(entity, "uuid"):
                # Same tolerance as ``__contains__``: non-entities are never members.
                return
            entity = entity.uuid
        self.data.pop(str(entity), None)

    def get(self, uuid: UUID) -> T:
        """Get an entity by identifier.

        :param uuid: The identifier of the entity.
        :return: A entity instance.
        :raises KeyError: If no entity is stored under the given identifier.
        """
        return self.data[str(uuid)]

    def __contains__(self, entity: Union[T, UUID]) -> bool:
        """Check whether an entity is stored.

        :param entity: The entity to look for, or its identifier.
        :return: ``True`` if it is stored, ``False`` otherwise (including when the
            given object is neither a ``UUID`` nor exposes a ``uuid`` attribute).
        """
        if not isinstance(entity, UUID):
            if not hasattr(entity, "uuid"):
                return False
            entity = entity.uuid
        return str(entity) in self.data

    def __iter__(self) -> Iterator[T]:
        """Iterate over the stored entities.

        :return: An iterator over the values of ``data``.
        """
        yield from self.data.values()

    def __eq__(self, other):
        """Compare with another object.

        :param other: The object to compare with. An ``EntitySet`` is compared through
            the parent implementation, a ``dict`` is compared against ``data``, and
            anything else is compared against ``set(self)``.
        :return: The result of the comparison.
        """
        if isinstance(other, EntitySet):
            return super().__eq__(other)
        if isinstance(other, dict):
            return self.data == other
        return set(self) == other

    def diff(self, another: EntitySet[T]) -> IncrementalSetDiff:
        """Compute the difference between self and another entity set.

        Entities are matched through their ``uuid`` attribute.

        :param another: Another entity set instance.
        :return: The difference between both entity sets.
        """
        return IncrementalSetDiff.from_difference(self, another, get_fn=attrgetter("uuid"))

    @property
    def data_cls(self) -> Optional[type]:
        """Get data class if available.

        :return: The second type argument of the ``data`` type hint (i.e. the value
            type of the ``dict[str, T]`` annotation).
        """
        args = get_args(self.type_hints["data"])
        return args[1]

    # noinspection PyMethodParameters
    @classmethod
    def encode_schema(cls, encoder: SchemaEncoder, target: Any, **kwargs) -> Any:
        """Encode schema with the given encoder.

        The schema is built as a list of the entity type (the last type argument of
        the ``data`` type hint) and tagged with this class as ``logicalType``.

        :param encoder: The encoder instance.
        :param target: An optional pre-encoded schema.
        :param kwargs: Additional named arguments forwarded to the encoder.
        :return: The encoded schema of the instance.
        """
        type_ = get_args(target.type_hints["data"])[-1]
        schema = encoder.build(list[type_], **kwargs)
        return schema | {"logicalType": cls.classname}

    @staticmethod
    def encode_data(encoder: DataEncoder, target: Any, **kwargs) -> Any:
        """Encode data with the given encoder.

        Only the values of ``target["data"]`` are encoded, as a list; the keys are
        rebuilt from each entity's ``uuid`` when decoding.

        :param encoder: The encoder instance.
        :param target: An optional pre-encoded data.
        :param kwargs: Additional named arguments forwarded to the encoder.
        :return: The encoded data of the instance.
        """
        entities = list(target["data"].values())
        return encoder.build(entities, **kwargs)

    @classmethod
    def decode_data(cls, decoder: DataDecoder, target: Any, type_: ModelType, **kwargs) -> IncrementalSet:
        """Decode data with the given decoder.

        :param decoder: The decoder instance.
        :param target: The data to be decoded.
        :param type_: The data type.
        :param kwargs: Additional named arguments forwarded to the decoder.
        :return: A decoded instance.
        """
        data_hint = type_.type_hints["data"]
        data_cls = get_args(data_hint)[1]
        # Decode each entity first, then rebuild the uuid-keyed mapping the set stores.
        entities = (decoder.build(raw, data_cls, **kwargs) for raw in target)
        keyed = {str(entity["uuid"]): entity for entity in entities}
        decoded = decoder.build(keyed, data_hint, **kwargs)
        return cls(decoded, additional_type_hints=type_.type_hints)