Source code for minos.common.database.mixins

from collections.abc import (
    AsyncIterator,
)
from contextlib import (
    suppress,
)
from typing import (
    Any,
    Generic,
    Optional,
    TypeVar,
    get_args,
    get_origin,
)

from ..exceptions import (
    NotProvidedException,
)
from ..injections import (
    Inject,
)
from ..pools import (
    PoolException,
    PoolFactory,
)
from ..setup import (
    SetupMixin,
)
from .operations import (
    DatabaseOperation,
    DatabaseOperationFactory,
)
from .pools import (
    DatabaseClient,
    DatabaseClientPool,
)

GenericDatabaseOperationFactory = TypeVar("GenericDatabaseOperationFactory", bound=DatabaseOperationFactory)


[docs]class DatabaseMixin(SetupMixin, Generic[GenericDatabaseOperationFactory]): """Database Mixin class."""
[docs] @Inject() def __init__( self, database_pool: Optional[DatabaseClientPool] = None, pool_factory: Optional[PoolFactory] = None, database_key: Optional[tuple[str]] = None, operation_factory: Optional[GenericDatabaseOperationFactory] = None, operation_factory_cls: Optional[type[GenericDatabaseOperationFactory]] = None, *args, **kwargs, ): super().__init__(*args, **kwargs, pool_factory=pool_factory) if database_pool is None and pool_factory is not None: database_pool = self._get_pool_from_factory(pool_factory, database_key) if not isinstance(database_pool, DatabaseClientPool): raise NotProvidedException(f"A {DatabaseClientPool!r} instance is required. Obtained: {database_pool}") self._pool = database_pool if operation_factory is None: if operation_factory_cls is None: operation_factory_cls = self._get_generic_operation_factory() if operation_factory_cls is not None: operation_factory = self.database_client_cls.get_factory(operation_factory_cls) self._operation_factory = operation_factory
@staticmethod def _get_pool_from_factory(pool_factory: PoolFactory, database_key: Optional[tuple[str]]): if database_key is None: database_key = tuple() for identifier in database_key: with suppress(PoolException): return pool_factory.get_pool(type_="database", identifier=identifier) return pool_factory.get_pool(type_="database") @property def database_operation_factory(self) -> Optional[GenericDatabaseOperationFactory]: """Get the operation factory if any. :return: A ``OperationFactory`` if it has been set or ``None`` otherwise. """ return self._operation_factory def _get_generic_operation_factory(self) -> Optional[type[GenericDatabaseOperationFactory]]: operation_factory_cls = None # noinspection PyUnresolvedReferences for base in self.__orig_bases__: origin = get_origin(base) if origin is None or not issubclass(origin, DatabaseMixin): continue args = get_args(base) operation_factory_cls = args[0] if not isinstance(operation_factory_cls, type) or not issubclass( operation_factory_cls, DatabaseOperationFactory ): raise TypeError(f"{type(self)!r} must contain a {DatabaseOperationFactory!r} as generic value.") return operation_factory_cls
[docs] async def execute_on_database_and_fetch_one(self, operation: DatabaseOperation) -> Any: """Submit an Operation and get the first response. :param operation: The operation to be executed. :return: This method does not return anything. """ async with self.database_pool.acquire() as client: await client.execute(operation) return await client.fetch_one()
# noinspection PyUnusedLocal
[docs] async def execute_on_database_and_fetch_all( self, operation: DatabaseOperation, streaming_mode: Optional[bool] = None ) -> AsyncIterator[tuple]: """Submit an Operation and return an asynchronous iterator. :param operation: The operation to be executed. :param streaming_mode: If ``True`` return the values in streaming directly from the database (keep an open database connection), otherwise preloads the full set of values on memory and then retrieves them. :return: This method does not return anything. """ if streaming_mode is None: streaming_mode = False async with self.database_pool.acquire() as client: await client.execute(operation) async_iterable = client.fetch_all() if streaming_mode: async for value in async_iterable: yield value return iterable = [value async for value in async_iterable] for value in iterable: yield value
# noinspection PyUnusedLocal
[docs] async def execute_on_database(self, operation: DatabaseOperation) -> None: """Submit an Operation. :param operation: The operation to be executed. :return: This method does not return anything. """ async with self.database_pool.acquire() as client: return await client.execute(operation)
@property def database_client_cls(self) -> type[DatabaseClient]: """Get the client's class. :return: A ``type`` instance that is subclass of ``DatabaseClient``. """ return self.database_pool.client_cls @property def database_pool(self) -> DatabaseClientPool: """Get the database pool. :return: A ``DatabaseClientPool`` object. """ return self._pool