diff --git a/hazelcast/asyncio/client.py b/hazelcast/asyncio/client.py index 6920eef361..1e08bab92f 100644 --- a/hazelcast/asyncio/client.py +++ b/hazelcast/asyncio/client.py @@ -8,7 +8,6 @@ from hazelcast.config import Config, IndexConfig from hazelcast.internal.asyncio_connection import ConnectionManager, DefaultAddressProvider from hazelcast.core import DistributedObjectEvent, DistributedObjectInfo -from hazelcast.cp import CPSubsystem, ProxySessionManager from hazelcast.discovery import HazelcastCloudAddressProvider from hazelcast.errors import IllegalStateError, InvalidConfigurationError from hazelcast.internal.asyncio_invocation import InvocationService, Invocation @@ -16,7 +15,7 @@ from hazelcast.lifecycle import LifecycleService, LifecycleState, _InternalLifecycleService from hazelcast.internal.asyncio_listener import ClusterViewListenerService, ListenerService from hazelcast.near_cache import NearCacheManager -from hazelcast.partition import PartitionService, _InternalPartitionService +from hazelcast.internal.asyncio_partition import PartitionService, InternalPartitionService from hazelcast.protocol.codec import ( client_add_distributed_object_listener_codec, client_get_distributed_objects_codec, @@ -34,7 +33,7 @@ from hazelcast.serialization import SerializationServiceV1 from hazelcast.sql import SqlService, _InternalSqlService from hazelcast.internal.asyncio_statistics import Statistics -from hazelcast.types import KeyType, ValueType, ItemType, MessageType +from hazelcast.types import KeyType, ValueType from hazelcast.util import AtomicInteger, RoundRobinLB __all__ = ("HazelcastClient",) @@ -84,7 +83,7 @@ def __init__(self, config: Config | None = None, **kwargs): self._config, ) self._address_provider = self._create_address_provider() - self._internal_partition_service = _InternalPartitionService(self) + self._internal_partition_service = InternalPartitionService(self) self._partition_service = PartitionService( self._internal_partition_service, self._serialization_service, @@ -111,8 +110,6 @@ def __init__(self, config: Config | None = None, **kwargs): self._compact_schema_service, ) self._proxy_manager = ProxyManager(self._context) - self._cp_subsystem = CPSubsystem(self._context) - self._proxy_session_manager = ProxySessionManager(self._context) self._lock_reference_id_generator = AtomicInteger(1) self._statistics = Statistics( self, @@ -159,7 +156,6 @@ def _init_context(self): self._near_cache_manager, self._lock_reference_id_generator, self._name, - self._proxy_session_manager, self._reactor, self._compact_schema_service, ) @@ -167,7 +163,7 @@ def _init_context(self): async def _start(self): try: self._internal_lifecycle_service.start() - self._invocation_service.start() + await self._invocation_service.start() membership_listeners = self._config.membership_listeners self._internal_cluster_service.start(self._connection_manager, membership_listeners) self._cluster_view_listener.start() @@ -278,7 +274,6 @@ async def shutdown(self) -> None: if self._internal_lifecycle_service.running: self._internal_lifecycle_service.fire_lifecycle_event(LifecycleState.SHUTTING_DOWN) self._internal_lifecycle_service.shutdown() - self._proxy_session_manager.shutdown().result() self._near_cache_manager.destroy_near_caches() await self._connection_manager.shutdown() self._invocation_service.shutdown() @@ -301,10 +296,6 @@ def partition_service(self) -> PartitionService: def cluster_service(self) -> ClusterService: return self._cluster_service - @property - def cp_subsystem(self) -> CPSubsystem: 
-        return self._cp_subsystem
-
     def _create_address_provider(self):
         config = self._config
         cluster_members = config.cluster_members
@@ -360,7 +351,6 @@ def __init__(self):
         self.near_cache_manager = None
         self.lock_reference_id_generator = None
         self.name = None
-        self.proxy_session_manager = None
         self.reactor = None
         self.compact_schema_service = None
@@ -378,7 +368,6 @@ def init_context(
         near_cache_manager,
         lock_reference_id_generator,
         name,
-        proxy_session_manager,
         reactor,
         compact_schema_service,
     ):
@@ -394,6 +383,5 @@ def init_context(
         self.near_cache_manager = near_cache_manager
         self.lock_reference_id_generator = lock_reference_id_generator
         self.name = name
-        self.proxy_session_manager = proxy_session_manager
         self.reactor = reactor
         self.compact_schema_service = compact_schema_service
diff --git a/hazelcast/internal/asyncio_compact.py b/hazelcast/internal/asyncio_compact.py
index 06f53ab97c..1533befdc3 100644
--- a/hazelcast/internal/asyncio_compact.py
+++ b/hazelcast/internal/asyncio_compact.py
@@ -57,35 +57,38 @@ def fetch_schema(self, schema_id: int) -> asyncio.Future:
         self._invocation_service.invoke(fetch_schema_invocation)
         return fetch_schema_invocation.future

-    def send_schema_and_retry(
+    async def send_schema_and_retry(
         self,
         error: "SchemaNotReplicatedError",
         func: typing.Callable[..., asyncio.Future],
         *args: typing.Any,
         **kwargs: typing.Any,
-    ) -> asyncio.Future:
+    ) -> typing.Any:
         schema = error.schema
         clazz = error.clazz
         request = client_send_schema_codec.encode_request(schema)

-        def callback():
+        async def callback():
             self._has_replicated_schemas = True
             self._compact_serializer.register_schema_to_type(schema, clazz)
-            return func(*args, **kwargs)
+            maybe_coro = func(*args, **kwargs)
+            # maybe_coro may be a coroutine or None
+            if maybe_coro:
+                return await maybe_coro

-        return self._replicate_schema(
-            schema, request, CompactSchemaService._SEND_SCHEMA_RETRY_COUNT, callback
+        return await self._replicate_schema(
+            schema, request, CompactSchemaService._SEND_SCHEMA_RETRY_COUNT, callback()
         )

-    def _replicate_schema(
+    async def _replicate_schema(
         self,
         schema: "Schema",
         request: "OutboundMessage",
         remaining_retries: int,
-        callback: typing.Callable[..., asyncio.Future],
-    ) -> asyncio.Future:
-        def continuation(future: asyncio.Future):
-            replicated_members = future.result()
+        callback: typing.Coroutine[typing.Any, typing.Any, typing.Any],
+    ) -> typing.Any:
+        for _ in range(remaining_retries):
+            replicated_members = await self._send_schema_replication_request(request)
             members = self._cluster_service.get_members()
             for member in members:
                 if member.uuid not in replicated_members:
@@ -93,41 +96,25 @@ def continuation(future: asyncio.Future):
             else:
                 # Loop completed normally.
                 # All members in our member list all known to have the schema
-                return callback()
+                return await callback

             # There is a member in our member list that the schema
             # is not known to be replicated yet. We should retry
             # sending it in a random member.
-            if remaining_retries <= 1:
-                # We tried to send it a couple of times, but the member list
-                # in our local and the member list returned by the initiator
-                # nodes did not match.
-                raise IllegalStateError(
-                    f"The schema {schema} cannot be replicated in the cluster, "
-                    f"after {CompactSchemaService._SEND_SCHEMA_RETRY_COUNT} retries. "
-                    f"It might be the case that the client is connected to the two "
-                    f"halves of the cluster that is experiencing a split-brain, "
-                    f"and continue putting the data associated with that schema "
-                    f"might result in data loss. 
It might be possible to replicate " - f"the schema after some time, when the cluster is healed." - ) - - delayed_future: asyncio.Future = asyncio.get_running_loop().create_future() - self._reactor.add_timer( - self._invocation_retry_pause, - lambda: delayed_future.set_result(None), - ) - - def retry(_): - return self._replicate_schema( - schema, request.copy(), remaining_retries - 1, callback - ) - - return delayed_future.add_done_callback(retry) - - fut = self._send_schema_replication_request(request) - fut.add_done_callback(continuation) - return fut + await asyncio.sleep(self._invocation_retry_pause) + + # We tried to send it a couple of times, but the member list + # in our local and the member list returned by the initiator + # nodes did not match. + raise IllegalStateError( + f"The schema {schema} cannot be replicated in the cluster, " + f"after {CompactSchemaService._SEND_SCHEMA_RETRY_COUNT} retries. " + f"It might be the case that the client is connected to the two " + f"halves of the cluster that is experiencing a split-brain, " + f"and continue putting the data associated with that schema " + f"might result in data loss. It might be possible to replicate " + f"the schema after some time, when the cluster is healed." + ) def _send_schema_replication_request(self, request: "OutboundMessage") -> asyncio.Future: invocation = Invocation(request, response_handler=client_send_schema_codec.decode_response) diff --git a/hazelcast/internal/asyncio_invocation.py b/hazelcast/internal/asyncio_invocation.py index 591740faa0..f7effc6e5b 100644 --- a/hazelcast/internal/asyncio_invocation.py +++ b/hazelcast/internal/asyncio_invocation.py @@ -96,7 +96,7 @@ def __init__(self, client, config, reactor): self._backup_ack_to_client_enabled = smart_routing and config.backup_ack_to_client_enabled self._fail_on_indeterminate_state = config.fail_on_indeterminate_operation_state self._backup_timeout = config.operation_backup_timeout - self._clean_resources_timer = None + self._clean_resources_task = None self._shutdown = False self._compact_schema_service = None @@ -107,8 +107,8 @@ def init(self, partition_service, connection_manager, listener_service, compact_ self._check_invocation_allowed_fn = connection_manager.check_invocation_allowed self._compact_schema_service = compact_schema_service - def start(self): - self._start_clean_resources_timer() + async def start(self): + await self._start_clean_resources_timer() async def add_backup_listener(self): if self._backup_ack_to_client_enabled: @@ -152,8 +152,8 @@ def shutdown(self): return self._shutdown = True - if self._clean_resources_timer: - self._clean_resources_timer.cancel() + if self._clean_resources_task: + self._clean_resources_task.cancel() for invocation in list(self._pending.values()): self._notify_error(invocation, HazelcastClientNotActiveError()) @@ -400,8 +400,9 @@ def _notify_backup_complete(self, invocation): self._complete(invocation, invocation.pending_response) - def _start_clean_resources_timer(self): - def run(): + async def _start_clean_resources_timer(self): + async def run(): + await asyncio.sleep(self._CLEAN_RESOURCES_PERIOD) if self._shutdown: return @@ -419,9 +420,9 @@ def run(): if self._backup_ack_to_client_enabled: self._detect_and_handle_backup_timeout(invocation, now) - self._clean_resources_timer = self._reactor.add_timer(self._CLEAN_RESOURCES_PERIOD, run) + self._clean_resources_task = asyncio.create_task(run()) - self._clean_resources_timer = self._reactor.add_timer(self._CLEAN_RESOURCES_PERIOD, run) + 
self._clean_resources_task = asyncio.create_task(run()) def _detect_and_handle_backup_timeout(self, invocation, now): if not invocation.pending_response: diff --git a/hazelcast/internal/asyncio_partition.py b/hazelcast/internal/asyncio_partition.py new file mode 100644 index 0000000000..a560971d69 --- /dev/null +++ b/hazelcast/internal/asyncio_partition.py @@ -0,0 +1,165 @@ +import logging +import uuid + +import typing + +from hazelcast.errors import ClientOfflineError +from hazelcast.hash import hash_to_index +from hazelcast.serialization.compact import SchemaNotReplicatedError + +_logger = logging.getLogger(__name__) + + +class _PartitionTable: + __slots__ = ("connection", "version", "partitions") + + def __init__(self, connection, version, partitions): + self.connection = connection + self.version = version + self.partitions = partitions + + def __repr__(self): + return "PartitionTable(connection=%s, version=%s)" % (self.connection, self.version) + + +class PartitionService: + """ + Allows retrieving information about the partition count, the partition + owner or the partition id of a key. + """ + + __slots__ = ("_service", "_serialization_service", "_send_schema_and_retry_fn") + + def __init__(self, internal_partition_service, serialization_service, send_schema_and_retry_fn): + self._service = internal_partition_service + self._serialization_service = serialization_service + self._send_schema_and_retry_fn = send_schema_and_retry_fn + + def get_partition_owner(self, partition_id: int) -> typing.Optional[uuid.UUID]: + """ + Returns the owner of the partition if it's set, ``None`` otherwise. + + Args: + partition_id: The partition id. + + Returns: + Owner of the partition + """ + return self._service.get_partition_owner(partition_id) + + async def get_partition_id(self, key: typing.Any) -> int: + """ + Returns the partition id for a key data. + + Args: + key: The given key. + + Returns: + The partition id. + """ + try: + key_data = self._serialization_service.to_data(key) + except SchemaNotReplicatedError as e: + await self._send_schema_and_retry_fn(e, lambda: None) + return await self.get_partition_id(key) + + return self._service.get_partition_id(key_data) + + def get_partition_count(self) -> int: + """ + Returns partition count of the connected cluster. + + If partition table is not fetched yet, this method returns ``0``. + + Returns: + The partition count + """ + return self._service.partition_count + + +class InternalPartitionService: + __slots__ = ("partition_count", "_client", "_partition_table") + + def __init__(self, client): + self.partition_count = 0 + self._client = client + self._partition_table = _PartitionTable(None, -1, {}) + + def handle_partitions_view_event(self, connection, partitions, version): + _logger.debug("Handling new partition table with version: %s", version) + + table = self._partition_table + if not self._should_be_applied(connection, partitions, version, table): + return + + new_partitions = self._prepare_partitions(partitions) + new_table = _PartitionTable(connection, version, new_partitions) + self._partition_table = new_table + + def get_partition_owner(self, partition_id): + return self._partition_table.partitions.get(partition_id, None) + + def get_partition_id(self, key): + if self.partition_count == 0: + # Partition count can not be zero for the SYNC mode. + # On the SYNC mode, we are waiting for the first connection to be established. + # We are initializing the partition count with the value coming from the server with authentication. 
+ # This error is used only for ASYNC mode client. + raise ClientOfflineError() + return hash_to_index(key.get_partition_hash(), self.partition_count) + + def check_and_set_partition_count(self, partition_count): + if self.partition_count == 0: + self.partition_count = partition_count + return True + return self.partition_count == partition_count + + @classmethod + def _should_be_applied(cls, connection, partitions, version, current): + if not partitions: + _logger.debug( + "Partition view will not be applied since response is empty. " + "Sending connection: %s, version: %s, current table: %s", + connection, + version, + current, + ) + return False + + if connection != current.connection: + _logger.debug( + "Partition view event coming from a new connection. Old: %s, new: %s", + current.connection, + connection, + ) + return True + + if version <= current.version: + _logger.debug( + "Partition view will not be applied since response state version is older. " + "Sending connection: %s, version: %s, current table: %s", + connection, + version, + current, + ) + return False + + return True + + @classmethod + def _prepare_partitions(cls, partitions): + new_partitions = {} + for uuid, partition_list in partitions: + for partition in partition_list: + new_partitions[partition] = uuid + return new_partitions + + +def string_partition_strategy(key): + if key is None: + return None + try: + index_of = key.index("@") + return key[index_of + 1 :] + except ValueError: + return key diff --git a/hazelcast/internal/asyncio_proxy/base.py b/hazelcast/internal/asyncio_proxy/base.py index 62d4ccb44d..8edd64a53b 100644 --- a/hazelcast/internal/asyncio_proxy/base.py +++ b/hazelcast/internal/asyncio_proxy/base.py @@ -6,7 +6,6 @@ from hazelcast.core import MemberInfo from hazelcast.types import KeyType, ValueType, ItemType, MessageType, BlockingProxyType from hazelcast.internal.asyncio_invocation import Invocation -from hazelcast.partition import string_partition_strategy from hazelcast.util import get_attr_name MAX_SIZE = float("inf") @@ -67,15 +66,14 @@ def _invoke_on_target( self._invocation_service.invoke(invocation) return invocation.future - def _invoke_on_key( + async def _invoke_on_key( self, request, key_data, response_handler=_no_op_response_handler - ) -> asyncio.Future: + ) -> typing.Any: partition_id = self._partition_service.get_partition_id(key_data) invocation = Invocation( request, partition_id=partition_id, response_handler=response_handler ) - self._invocation_service.invoke(invocation) - return invocation.future + return await self._invocation_service.ainvoke(invocation) def _invoke_on_partition( self, request, partition_id, response_handler=_no_op_response_handler @@ -93,22 +91,6 @@ async def _ainvoke_on_partition( return await fut -class PartitionSpecificProxy(Proxy[BlockingProxyType], abc.ABC): - """Provides basic functionality for Partition Specific Proxies.""" - - def __init__(self, service_name, name, context): - super(PartitionSpecificProxy, self).__init__(service_name, name, context) - partition_key = context.serialization_service.to_data(string_partition_strategy(self.name)) - self._partition_id = context.partition_service.get_partition_id(partition_key) - - def _invoke(self, request, response_handler=_no_op_response_handler): - invocation = Invocation( - request, partition_id=self._partition_id, response_handler=response_handler - ) - self._invocation_service.invoke(invocation) - return invocation.future - - class ItemEventType: """Type of item events.""" diff --git 
a/hazelcast/internal/asyncio_proxy/map.py b/hazelcast/internal/asyncio_proxy/map.py index 913bd4ef34..84c8ecaa14 100644 --- a/hazelcast/internal/asyncio_proxy/map.py +++ b/hazelcast/internal/asyncio_proxy/map.py @@ -272,7 +272,7 @@ async def add_interceptor(self, interceptor: typing.Any) -> str: try: interceptor_data = self._to_data(interceptor) except SchemaNotReplicatedError as e: - return self._send_schema_and_retry(e, self.add_interceptor, interceptor) + return await self._send_schema_and_retry(e, self.add_interceptor, interceptor) request = map_add_interceptor_codec.encode_request(self.name, interceptor_data) return await self._invoke(request, map_add_interceptor_codec.decode_response) @@ -872,7 +872,7 @@ def _delete_internal(self, key_data): request = map_delete_codec.encode_request(self.name, key_data, thread_id()) return self._invoke_on_key(request, key_data) - def _put_internal(self, key_data, value_data, ttl, max_idle): + async def _put_internal(self, key_data, value_data, ttl, max_idle): def handler(message): return self._to_object(map_put_codec.decode_response(message)) @@ -884,7 +884,7 @@ def handler(message): request = map_put_codec.encode_request( self.name, key_data, value_data, thread_id(), to_millis(ttl) ) - return self._invoke_on_key(request, key_data, handler) + return await self._invoke_on_key(request, key_data, handler) def _set_internal(self, key_data, value_data, ttl, max_idle): if max_idle is not None: @@ -1106,9 +1106,11 @@ def _put_transient_internal(self, key_data, value_data, ttl, max_idle): key_data, value_data, ttl, max_idle ) - def _put_internal(self, key_data, value_data, ttl, max_idle): + async def _put_internal(self, key_data, value_data, ttl, max_idle): self._invalidate_cache(key_data) - return super(MapFeatNearCache, self)._put_internal(key_data, value_data, ttl, max_idle) + return await super(MapFeatNearCache, self)._put_internal( + key_data, value_data, ttl, max_idle + ) def _put_if_absent_internal(self, key_data, value_data, ttl, max_idle): self._invalidate_cache(key_data) diff --git a/hazelcast/internal/asyncio_proxy/vector_collection.py b/hazelcast/internal/asyncio_proxy/vector_collection.py index ef12e032af..4abe7fb3b2 100644 --- a/hazelcast/internal/asyncio_proxy/vector_collection.py +++ b/hazelcast/internal/asyncio_proxy/vector_collection.py @@ -125,12 +125,12 @@ async def size(self) -> int: request = vector_collection_size_codec.encode_request(self.name) return await self._invoke(request, vector_collection_size_codec.decode_response) - def _set_internal(self, key: Any, document: Document) -> asyncio.Future[None]: + async def _set_internal(self, key: Any, document: Document) -> None: try: key_data = self._to_data(key) value_data = self._to_data(document.value) except SchemaNotReplicatedError as e: - return self._send_schema_and_retry(e, self.set, key, document) + return await self._send_schema_and_retry(e, self.set, key, document) document = copy.copy(document) document.value = value_data request = vector_collection_set_codec.encode_request( @@ -138,9 +138,9 @@ def _set_internal(self, key: Any, document: Document) -> asyncio.Future[None]: key_data, document, ) - return self._invoke_on_key(request, key_data) + return await self._invoke_on_key(request, key_data) - def _get_internal(self, key: Any) -> asyncio.Future[Any]: + async def _get_internal(self, key: Any) -> Any: def handler(message): doc = vector_collection_get_codec.decode_response(message) return self._transform_document(doc) @@ -148,12 +148,12 @@ def handler(message): try: key_data = 
self._to_data(key) except SchemaNotReplicatedError as e: - return self._send_schema_and_retry(e, self.get, key) + return await self._send_schema_and_retry(e, self.get, key) request = vector_collection_get_codec.encode_request( self.name, key_data, ) - return self._invoke_on_key(request, key_data, response_handler=handler) + return await self._invoke_on_key(request, key_data, response_handler=handler) def _search_near_vector_internal( self, @@ -191,21 +191,21 @@ def handler(message): ) return self._invoke(request, response_handler=handler) - def _delete_internal(self, key: Any) -> asyncio.Future[None]: + async def _delete_internal(self, key: Any) -> None: key_data = self._to_data(key) request = vector_collection_delete_codec.encode_request(self.name, key_data) - return self._invoke_on_key(request, key_data) + return await self._invoke_on_key(request, key_data) - def _remove_internal(self, key: Any) -> asyncio.Future[Document | None]: + async def _remove_internal(self, key: Any) -> Document | None: def handler(message): doc = vector_collection_remove_codec.decode_response(message) return self._transform_document(doc) key_data = self._to_data(key) request = vector_collection_remove_codec.encode_request(self.name, key_data) - return self._invoke_on_key(request, key_data, response_handler=handler) + return await self._invoke_on_key(request, key_data, response_handler=handler) - def _put_internal(self, key: Any, document: Document) -> asyncio.Future[Document | None]: + async def _put_internal(self, key: Any, document: Document) -> Document | None: def handler(message): doc = vector_collection_put_codec.decode_response(message) return self._transform_document(doc) @@ -214,7 +214,7 @@ def handler(message): key_data = self._to_data(key) value_data = self._to_data(document.value) except SchemaNotReplicatedError as e: - return self._send_schema_and_retry(e, self.set, key, document) + return await self._send_schema_and_retry(e, self.set, key, document) document = copy.copy(document) document.value = value_data request = vector_collection_put_codec.encode_request( @@ -222,11 +222,9 @@ def handler(message): key_data, document, ) - return self._invoke_on_key(request, key_data, response_handler=handler) + return await self._invoke_on_key(request, key_data, response_handler=handler) - def _put_if_absent_internal( - self, key: Any, document: Document - ) -> asyncio.Future[Document | None]: + async def _put_if_absent_internal(self, key: Any, document: Document) -> Document | None: def handler(message): doc = vector_collection_put_if_absent_codec.decode_response(message) return self._transform_document(doc) @@ -235,14 +233,14 @@ def handler(message): key_data = self._to_data(key) value_data = self._to_data(document.value) except SchemaNotReplicatedError as e: - return self._send_schema_and_retry(e, self.set, key, document) + return await self._send_schema_and_retry(e, self.set, key, document) document.value = value_data request = vector_collection_put_if_absent_codec.encode_request( self.name, key_data, document, ) - return self._invoke_on_key(request, key_data, response_handler=handler) + return await self._invoke_on_key(request, key_data, response_handler=handler) def _transform_document(self, doc: Optional[Document]) -> Optional[Document]: if doc is not None: diff --git a/hazelcast/internal/asyncio_reactor.py b/hazelcast/internal/asyncio_reactor.py index e1026c78c3..7311f10bfc 100644 --- a/hazelcast/internal/asyncio_reactor.py +++ b/hazelcast/internal/asyncio_reactor.py @@ -172,7 +172,6 @@ def 
__init__(self, conn: AsyncioConnection): self._write_buf = io.BytesIO() self._write_buf_size = 0 self._recv_buf = None - self._alive = True # asyncio tasks are weakly referenced # storing tasks here in order not to lose them midway # see: https: // docs.python.org / 3 / library / asyncio - task.html # creating-tasks @@ -186,8 +185,10 @@ def connection_made(self, transport: transports.BaseTransport): self._conn._loop.call_soon(self._write_loop) def connection_lost(self, exc): - self._alive = False - task = self._conn._loop.create_task(self._conn.close_connection(str(exc), None)) + _logger.warning("Connection closed by server") + task = self._conn._loop.create_task( + self._conn.close_connection(None, IOError("Connection closed by server")) + ) self._tasks.add(task) task.add_done_callback(self._tasks.discard) return False @@ -213,9 +214,6 @@ def buffer_updated(self, nbytes): if self._conn._reader.length: self._conn._reader.process() - def eof_received(self): - self._alive = False - def _do_write(self): if not self._write_buf_size: return diff --git a/tests/integration/asyncio/listener_test.py b/tests/integration/asyncio/listener_test.py index da13a25438..9d132a8ad4 100644 --- a/tests/integration/asyncio/listener_test.py +++ b/tests/integration/asyncio/listener_test.py @@ -52,7 +52,7 @@ async def _remove_member_test(self, is_smart): self.client_config["smart_routing"] = is_smart client = await self.create_client(self.client_config) await wait_for_partition_table(client) - key_m1 = generate_key_owned_by_instance(client, self.m1.uuid) + key_m1 = await generate_key_owned_by_instance(client, self.m1.uuid) random_map = await client.get_map(random_string()) await random_map.add_entry_listener(added_func=self.collector) await asyncio.to_thread(self.m1.shutdown) @@ -92,7 +92,7 @@ async def _add_member_test(self, is_smart): await random_map.add_entry_listener(added_func=self.collector, updated_func=self.collector) m2 = await asyncio.to_thread(self.cluster.start_member) await wait_for_partition_table(client) - key_m2 = generate_key_owned_by_instance(client, m2.uuid) + key_m2 = await generate_key_owned_by_instance(client, m2.uuid) assertion_succeeded = False async def run(): diff --git a/tests/integration/asyncio/serialization/__init__.py b/tests/integration/asyncio/serialization/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/asyncio/serialization/compact_compatibility/__init__.py b/tests/integration/asyncio/serialization/compact_compatibility/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/asyncio/serialization/compact_compatibility/compact_compatibility_test.py b/tests/integration/asyncio/serialization/compact_compatibility/compact_compatibility_test.py new file mode 100644 index 0000000000..b1a55abab3 --- /dev/null +++ b/tests/integration/asyncio/serialization/compact_compatibility/compact_compatibility_test.py @@ -0,0 +1,650 @@ +import copy +import enum +import typing +import unittest + +from hazelcast.errors import NullPointerError +from hazelcast.predicate import Predicate +from tests.integration.asyncio.base import HazelcastTestCase +from tests.util import ( + random_string, + compare_client_version, + compare_server_version_with_rc, + skip_if_client_version_older_than, +) + +try: + from hazelcast.serialization.api import ( + CompactSerializer, + CompactWriter, + CompactReader, + ) +except ImportError: + # For backward compatibility tests + + T = typing.TypeVar("T") + + class 
CompactSerializer(typing.Generic[T]): + pass + + class CompactReader: + pass + + class CompactWriter: + pass + + class FieldKind(enum.Enum): + pass + + +class InnerCompact: + def __init__(self, string_field: str): + self.string_field = string_field + + def __eq__(self, o: object) -> bool: + return isinstance(o, InnerCompact) and self.string_field == o.string_field + + def __hash__(self) -> int: + return hash(self.string_field) + + def __repr__(self): + return f"InnerCompact(string_field={self.string_field})" + + +class OuterCompact: + def __init__(self, int_field: int, inner_field: InnerCompact): + self.int_field = int_field + self.inner_field = inner_field + + def __eq__(self, o: object) -> bool: + return ( + isinstance(o, OuterCompact) + and self.int_field == o.int_field + and self.inner_field == o.inner_field + ) + + def __hash__(self) -> int: + return hash((self.int_field, self.inner_field)) + + def __repr__(self): + return f"OuterCompact(int_field={self.int_field}, inner_field={self.inner_field})" + + +class InnerSerializer(CompactSerializer[InnerCompact]): + def read(self, reader: CompactReader) -> InnerCompact: + return InnerCompact(reader.read_string("stringField")) + + def write(self, writer: CompactWriter, obj: InnerCompact) -> None: + writer.write_string("stringField", obj.string_field) + + def get_type_name(self) -> str: + return "com.hazelcast.serialization.compact.InnerCompact" + + def get_class(self): + return InnerCompact + + +class OuterSerializer(CompactSerializer[OuterCompact]): + def read(self, reader: CompactReader) -> OuterCompact: + return OuterCompact( + reader.read_int32("intField"), + reader.read_compact("innerField"), + ) + + def write(self, writer: CompactWriter, obj: OuterCompact) -> None: + writer.write_int32("intField", obj.int_field) + writer.write_compact("innerField", obj.inner_field) + + def get_type_name(self) -> str: + return "com.hazelcast.serialization.compact.OuterCompact" + + def get_class(self): + return OuterCompact + + +class CompactIncrementFunction: + pass + + +class CompactIncrementFunctionSerializer(CompactSerializer[CompactIncrementFunction]): + def read(self, reader: CompactReader) -> CompactIncrementFunction: + return CompactIncrementFunction() + + def write(self, writer: CompactWriter, obj: CompactIncrementFunction) -> None: + pass + + def get_type_name(self) -> str: + return "com.hazelcast.serialization.compact.CompactIncrementFunction" + + def get_class(self): + return CompactIncrementFunction + + +class CompactReturningFunction: + pass + + +class CompactReturningFunctionSerializer(CompactSerializer[CompactReturningFunction]): + def read(self, reader: CompactReader) -> CompactReturningFunction: + return CompactReturningFunction() + + def write(self, writer: CompactWriter, obj: CompactReturningFunction) -> None: + pass + + def get_type_name(self) -> str: + return "com.hazelcast.serialization.compact.CompactReturningFunction" + + def get_class(self): + return CompactReturningFunction + + +class CompactReturningCallable: + pass + + +class CompactReturningCallableSerializer(CompactSerializer[CompactReturningCallable]): + def read(self, reader: CompactReader) -> CompactReturningCallable: + return CompactReturningCallable() + + def write(self, writer: CompactWriter, obj: CompactReturningCallable) -> None: + pass + + def get_type_name(self) -> str: + return "com.hazelcast.serialization.compact.CompactReturningCallable" + + def get_class(self): + return CompactReturningCallable + + +class CompactPredicate(Predicate): + pass + + +class 
CompactPredicateSerializer(CompactSerializer[CompactPredicate]): + def read(self, reader: CompactReader) -> CompactPredicate: + return CompactPredicate() + + def write(self, writer: CompactWriter, obj: CompactPredicate) -> None: + pass + + def get_type_name(self) -> str: + return "com.hazelcast.serialization.compact.CompactPredicate" + + def get_class(self): + return CompactPredicate + + +class CompactReturningMapInterceptor: + pass + + +class CompactReturningMapInterceptorSerializer(CompactSerializer[CompactReturningMapInterceptor]): + def read(self, reader: CompactReader) -> CompactReturningMapInterceptor: + return CompactReturningMapInterceptor() + + def write(self, writer: CompactWriter, obj: CompactReturningMapInterceptor) -> None: + pass + + def get_type_name(self) -> str: + return "com.hazelcast.serialization.compact.CompactReturningMapInterceptor" + + def get_class(self): + return CompactReturningMapInterceptor + + +try: + from hazelcast.aggregator import Aggregator + + class CompactReturningAggregator(Aggregator): + pass + +except ImportError: + + class CompactReturningAggregator: + pass + + +class CompactReturningAggregatorSerializer(CompactSerializer[CompactReturningAggregator]): + def read(self, reader: CompactReader) -> CompactReturningAggregator: + return CompactReturningAggregator() + + def write(self, writer: CompactWriter, obj: CompactReturningAggregator) -> None: + pass + + def get_type_name(self) -> str: + return "com.hazelcast.serialization.compact.CompactReturningAggregator" + + def get_class(self): + return CompactReturningAggregator + + +class CompactReturningEntryProcessor: + pass + + +class CompactReturningEntryProcessorSerializer(CompactSerializer[CompactReturningEntryProcessor]): + def read(self, reader: CompactReader) -> CompactReturningEntryProcessor: + return CompactReturningEntryProcessor() + + def write(self, writer: CompactWriter, obj: CompactReturningEntryProcessor) -> None: + pass + + def get_type_name(self) -> str: + return "com.hazelcast.serialization.compact.CompactReturningEntryProcessor" + + def get_class(self): + return CompactReturningEntryProcessor + + +class CompactReturningProjection: + pass + + +class CompactReturningProjectionSerializer(CompactSerializer[CompactReturningProjection]): + def read(self, reader: CompactReader) -> CompactReturningProjection: + return CompactReturningProjection() + + def write(self, writer: CompactWriter, obj: CompactReturningProjection) -> None: + pass + + def get_type_name(self) -> str: + return "com.hazelcast.serialization.compact.CompactReturningProjection" + + def get_class(self): + return CompactReturningProjection + + +class CompactFilter: + pass + + +class CompactFilterSerializer(CompactSerializer[CompactFilter]): + def read(self, reader: CompactReader) -> CompactFilter: + return CompactFilter() + + def write(self, writer: CompactWriter, obj: CompactFilter) -> None: + pass + + def get_type_name(self) -> str: + return "com.hazelcast.serialization.compact.CompactFilter" + + def get_class(self): + return CompactFilter + + +INNER_COMPACT_INSTANCE = InnerCompact("42") +OUTER_COMPACT_INSTANCE = OuterCompact(42, INNER_COMPACT_INSTANCE) + + +@unittest.skipIf( + compare_client_version("5.2") < 0, "Tests the features added in 5.2 version of the client" +) +class CompactCompatibilityBase(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + rc = None + cluster = None + client_config = None + + @classmethod + def setUpClass(cls) -> None: + cls.rc = cls.create_rc() + if compare_server_version_with_rc(cls.rc, "5.2") < 
0: + cls.rc.exit() + raise unittest.SkipTest("Compact serialization requires 5.2 server") + + config = f""" + + + + """ + + cls.cluster = cls.create_cluster(cls.rc, config) + cls.cluster.start_member() + cls.client_config = { + "cluster_name": cls.cluster.id, + "compact_serializers": [ + InnerSerializer(), + OuterSerializer(), + CompactIncrementFunctionSerializer(), + CompactReturningFunctionSerializer(), + CompactReturningCallableSerializer(), + CompactPredicateSerializer(), + CompactReturningMapInterceptorSerializer(), + CompactReturningAggregatorSerializer(), + CompactReturningEntryProcessorSerializer(), + CompactReturningProjectionSerializer(), + CompactFilterSerializer(), + ], + } + + @classmethod + def tearDownClass(cls) -> None: + cls.rc.terminateCluster(cls.cluster.id) + cls.rc.exit() + + async def asyncSetUp(self) -> None: + self.client = await self.create_client(self.client_config) + + async def asyncTearDown(self) -> None: + await self.shutdown_all_clients() + + +class MapCompatibilityTest(CompactCompatibilityBase): + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.map = await self.client.get_map(random_string()) + + async def asyncTearDown(self) -> None: + await self.map.destroy() + await super().asyncTearDown() + + async def test_add_entry_listener_with_key_and_predicate(self): + events = [] + + def listener(event): + events.append(event) + + await self.map.add_entry_listener( + include_value=True, + key=INNER_COMPACT_INSTANCE, + predicate=CompactPredicate(), + added_func=listener, + ) + + await self._assert_entry_event(events) + + async def test_add_entry_listener_with_predicate(self): + events = [] + + def listener(event): + events.append(event) + + await self.map.add_entry_listener( + include_value=True, + predicate=CompactPredicate(), + added_func=listener, + ) + + await self._assert_entry_event(events) + + async def test_add_entry_listener_with_key(self): + events = [] + + def listener(event): + events.append(event) + + await self.map.add_entry_listener( + include_value=True, + key=INNER_COMPACT_INSTANCE, + added_func=listener, + ) + + await self._assert_entry_event(events) + + async def test_add_interceptor(self): + await self.map.add_interceptor(CompactReturningMapInterceptor()) + self.assertEqual(OUTER_COMPACT_INSTANCE, await self.map.get("non-existent-key")) + + async def test_aggregate(self): + self.assertEqual( + OUTER_COMPACT_INSTANCE, + await self.map.aggregate(CompactReturningAggregator()), + ) + + async def test_aggregate_with_predicate(self): + self.assertEqual( + OUTER_COMPACT_INSTANCE, + await self.map.aggregate(CompactReturningAggregator(), predicate=CompactPredicate()), + ) + + async def test_contains_key(self): + self.assertFalse(await self.map.contains_key(OUTER_COMPACT_INSTANCE)) + await self.map.put(OUTER_COMPACT_INSTANCE, INNER_COMPACT_INSTANCE) + self.assertTrue(await self.map.contains_key(OUTER_COMPACT_INSTANCE)) + + async def test_contains_value(self): + self.assertFalse(await self.map.contains_value(OUTER_COMPACT_INSTANCE)) + await self.map.put(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE) + self.assertTrue(await self.map.contains_value(OUTER_COMPACT_INSTANCE)) + + async def test_delete(self): + await self.map.delete(OUTER_COMPACT_INSTANCE) + await self.map.put(OUTER_COMPACT_INSTANCE, INNER_COMPACT_INSTANCE) + await self.map.delete(OUTER_COMPACT_INSTANCE) + self.assertIsNone(await self.map.get(OUTER_COMPACT_INSTANCE)) + + async def test_entry_set(self): + await self._put_from_another_client(INNER_COMPACT_INSTANCE, 
OUTER_COMPACT_INSTANCE) + self.assertEqual( + [(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE)], await self.map.entry_set() + ) + + async def test_entry_set_with_predicate(self): + await self._put_from_another_client(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE) + self.assertEqual( + [(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE)], + await self.map.entry_set(CompactPredicate()), + ) + + async def test_evict(self): + self.assertFalse(await self.map.evict(OUTER_COMPACT_INSTANCE)) + await self.map.put(OUTER_COMPACT_INSTANCE, INNER_COMPACT_INSTANCE) + self.assertTrue(await self.map.evict(OUTER_COMPACT_INSTANCE)) + + async def test_execute_on_entries(self): + await self._put_from_another_client(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE) + self.assertEqual( + [(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE)], + await self.map.execute_on_entries(CompactReturningEntryProcessor()), + ) + + async def test_execute_on_entries_predicate(self): + await self._put_from_another_client(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE) + self.assertEqual( + [(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE)], + await self.map.execute_on_entries(CompactReturningEntryProcessor(), CompactPredicate()), + ) + + async def test_execute_on_key(self): + await self.map.put(OUTER_COMPACT_INSTANCE, INNER_COMPACT_INSTANCE) + self.assertEqual( + OUTER_COMPACT_INSTANCE, + await self.map.execute_on_key(OUTER_COMPACT_INSTANCE, CompactReturningEntryProcessor()), + ) + + async def test_execute_on_keys(self): + await self._put_from_another_client(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE) + self.assertEqual( + [(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE)], + await self.map.execute_on_keys( + [INNER_COMPACT_INSTANCE], CompactReturningEntryProcessor() + ), + ) + + async def test_get(self): + await self._put_from_another_client(OUTER_COMPACT_INSTANCE, INNER_COMPACT_INSTANCE) + self.assertEqual(INNER_COMPACT_INSTANCE, await self.map.get(OUTER_COMPACT_INSTANCE)) + + async def test_get_all(self): + await self._put_from_another_client(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE) + self.assertEqual( + {INNER_COMPACT_INSTANCE: OUTER_COMPACT_INSTANCE}, + await self.map.get_all([INNER_COMPACT_INSTANCE]), + ) + + async def test_get_entry_view(self): + await self._put_from_another_client(OUTER_COMPACT_INSTANCE, INNER_COMPACT_INSTANCE) + entry_view = await self.map.get_entry_view(OUTER_COMPACT_INSTANCE) + self.assertEqual(OUTER_COMPACT_INSTANCE, entry_view.key) + self.assertEqual(INNER_COMPACT_INSTANCE, entry_view.value) + + async def test_key_set(self): + await self._put_from_another_client(OUTER_COMPACT_INSTANCE, INNER_COMPACT_INSTANCE) + self.assertEqual([OUTER_COMPACT_INSTANCE], await self.map.key_set()) + + async def test_key_set_with_predicate(self): + await self._put_from_another_client(OUTER_COMPACT_INSTANCE, INNER_COMPACT_INSTANCE) + self.assertEqual( + [OUTER_COMPACT_INSTANCE], + await self.map.key_set(CompactPredicate()), + ) + + async def test_load_all_with_keys(self): + try: + await self.map.load_all(keys=[INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE]) + except NullPointerError: + # Since there is no loader configured + # the server throws this error. In this test, + # we only care about sending the serialized + # for of the keys to the server. So, we don't + # care about what server does with these keys. + # It should probably handle this gracefully, + # but it is OK. 
+ pass + + async def test_project(self): + await self._put_from_another_client(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE) + self.assertEqual( + [OUTER_COMPACT_INSTANCE], + await self.map.project(CompactReturningProjection()), + ) + + async def test_project_with_predicate(self): + await self._put_from_another_client(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE) + self.assertEqual( + [OUTER_COMPACT_INSTANCE], + await self.map.project(CompactReturningProjection(), CompactPredicate()), + ) + + async def test_put(self): + await self.map.put(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE) + self.assertEqual(OUTER_COMPACT_INSTANCE, await self.map.get(INNER_COMPACT_INSTANCE)) + + async def test_put_all(self): + await self.map.put_all({OUTER_COMPACT_INSTANCE: INNER_COMPACT_INSTANCE}) + self.assertEqual(INNER_COMPACT_INSTANCE, await self.map.get(OUTER_COMPACT_INSTANCE)) + + async def test_put_if_absent(self): + self.assertIsNone( + await self.map.put_if_absent(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE) + ) + self.assertEqual( + OUTER_COMPACT_INSTANCE, + await self.map.put_if_absent(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE), + ) + + async def test_put_transient(self): + await self.map.put_transient(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE) + self.assertEqual(OUTER_COMPACT_INSTANCE, await self.map.get(INNER_COMPACT_INSTANCE)) + + async def test_remove(self): + self.assertIsNone(await self.map.remove(OUTER_COMPACT_INSTANCE)) + await self._put_from_another_client(OUTER_COMPACT_INSTANCE, INNER_COMPACT_INSTANCE) + self.assertEqual(INNER_COMPACT_INSTANCE, await self.map.remove(OUTER_COMPACT_INSTANCE)) + + async def test_remove_all(self): + skip_if_client_version_older_than(self, "5.2") + + await self._put_from_another_client(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE) + self.assertIsNone(await self.map.remove_all(CompactPredicate())) + self.assertEqual(0, await self.map.size()) + + async def test_remove_if_same(self): + self.assertFalse( + await self.map.remove_if_same(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE) + ) + await self._put_from_another_client(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE) + self.assertTrue( + await self.map.remove_if_same(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE) + ) + + async def test_replace(self): + self.assertIsNone(await self.map.replace(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE)) + await self.map.put(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE) + self.assertEqual( + OUTER_COMPACT_INSTANCE, + await self.map.replace(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE), + ) + + async def test_replace_if_same(self): + self.assertFalse( + await self.map.replace_if_same( + INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE + ) + ) + await self.map.put(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE) + self.assertTrue( + await self.map.replace_if_same( + INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE + ) + ) + + async def test_set(self): + await self.map.set(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE) + self.assertEqual(OUTER_COMPACT_INSTANCE, await self.map.get(INNER_COMPACT_INSTANCE)) + + async def test_set_ttl(self): + await self._put_from_another_client(OUTER_COMPACT_INSTANCE, INNER_COMPACT_INSTANCE) + await self.map.set_ttl(OUTER_COMPACT_INSTANCE, 999) + + async def test_try_put(self): + self.assertTrue(await self.map.try_put(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE)) + + async def test_try_remove(self): + self.assertFalse(await self.map.try_remove(OUTER_COMPACT_INSTANCE)) + await 
self.map.put(OUTER_COMPACT_INSTANCE, INNER_COMPACT_INSTANCE)
+        self.assertTrue(await self.map.try_remove(OUTER_COMPACT_INSTANCE))
+
+    async def test_values(self):
+        await self._put_from_another_client(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE)
+        self.assertEqual([OUTER_COMPACT_INSTANCE], await self.map.values())
+
+    async def test_values_with_predicate(self):
+        await self._put_from_another_client(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE)
+        self.assertEqual([OUTER_COMPACT_INSTANCE], await self.map.values(CompactPredicate()))
+
+    async def _put_from_another_client(self, key, value):
+        other_client = await self.create_client(self.client_config)
+        other_client_map = await other_client.get_map(self.map.name)
+        await other_client_map.put(key, value)
+
+    async def _assert_entry_event(self, events):
+        await self._put_from_another_client(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE)
+
+        def assertion():
+            self.assertEqual(1, len(events))
+            event = events[0]
+            self.assertEqual(INNER_COMPACT_INSTANCE, event.key)
+            self.assertEqual(OUTER_COMPACT_INSTANCE, event.value)
+            self.assertIsNone(event.old_value)
+            self.assertIsNone(event.merging_value)
+
+        await self.assertTrueEventually(assertion)
+
+
+class NearCachedMapCompactCompatibilityTest(MapCompatibilityTest):
+    async def asyncSetUp(self) -> None:
+        map_name = random_string()
+        self.client_config = copy.deepcopy(NearCachedMapCompactCompatibilityTest.client_config)
+        self.client_config["near_caches"] = {map_name: {}}
+        await super().asyncSetUp()
+        self.map = await self.client.get_map(map_name)
+
+    async def test_get_for_near_cache(self):
+        # Another variant of the test in the superclass, where we look up a key
+        # which has a value whose schema is not fully sent to the
+        # cluster by this client. The near cache will try to serialize it,
+        # but it should not attempt to send this schema to the cluster, as it
+        # is just fetched from there.
+ await self._put_from_another_client(INNER_COMPACT_INSTANCE, OUTER_COMPACT_INSTANCE) + self.assertEqual(OUTER_COMPACT_INSTANCE, await self.map.get(INNER_COMPACT_INSTANCE)) + + +class PartitionServiceCompactCompatibilityTest(CompactCompatibilityBase): + async def test_partition_service(self): + self.assertEqual( + 267, + await self.client.partition_service.get_partition_id(OUTER_COMPACT_INSTANCE), + ) diff --git a/tests/integration/asyncio/serialization/compact_test.py b/tests/integration/asyncio/serialization/compact_test.py new file mode 100644 index 0000000000..c49bf4e1fe --- /dev/null +++ b/tests/integration/asyncio/serialization/compact_test.py @@ -0,0 +1,929 @@ +import asyncio +import copy +import datetime +import decimal +import enum +import itertools +import random +import typing +import unittest + +from hazelcast.errors import HazelcastSerializationError +from hazelcast.predicate import sql +from hazelcast.util import AtomicInteger +from tests.integration.asyncio.base import HazelcastTestCase +from tests.util import ( + is_equal, + random_string, + compare_client_version, + compare_server_version_with_rc, +) + +try: + from hazelcast.serialization.api import ( + CompactSerializer, + CompactReader, + CompactWriter, + FieldKind, + ) + from hazelcast.serialization.compact import FIELD_OPERATIONS + + _COMPACT_AVAILABLE = True +except ImportError: + # For backward compatibility tests + + T = typing.TypeVar("T") + + class CompactSerializer(typing.Generic[T]): + pass + + class CompactReader: + pass + + class CompactWriter: + pass + + class FieldKind(enum.Enum): + pass + + _COMPACT_AVAILABLE = False + + +if _COMPACT_AVAILABLE: + FIELD_KINDS = [kind for kind in FieldKind if FIELD_OPERATIONS[kind] is not None] + FIX_SIZED_FIELD_KINDS = [ + kind for kind in FIELD_KINDS if not FIELD_OPERATIONS[kind].is_var_sized() + ] + VAR_SIZED_FIELD_KINDS = [kind for kind in FIELD_KINDS if FIELD_OPERATIONS[kind].is_var_sized()] + + FIX_SIZED_TO_NULLABLE = { + FieldKind.BOOLEAN: FieldKind.NULLABLE_BOOLEAN, + FieldKind.INT8: FieldKind.NULLABLE_INT8, + FieldKind.INT16: FieldKind.NULLABLE_INT16, + FieldKind.INT32: FieldKind.NULLABLE_INT32, + FieldKind.INT64: FieldKind.NULLABLE_INT64, + FieldKind.FLOAT32: FieldKind.NULLABLE_FLOAT32, + FieldKind.FLOAT64: FieldKind.NULLABLE_FLOAT64, + } + + FIX_SIZED_ARRAY_TO_NULLABLE_FIX_SIZED_ARRAY = { + FieldKind.ARRAY_OF_BOOLEAN: FieldKind.ARRAY_OF_NULLABLE_BOOLEAN, + FieldKind.ARRAY_OF_INT8: FieldKind.ARRAY_OF_NULLABLE_INT8, + FieldKind.ARRAY_OF_INT16: FieldKind.ARRAY_OF_NULLABLE_INT16, + FieldKind.ARRAY_OF_INT32: FieldKind.ARRAY_OF_NULLABLE_INT32, + FieldKind.ARRAY_OF_INT64: FieldKind.ARRAY_OF_NULLABLE_INT64, + FieldKind.ARRAY_OF_FLOAT32: FieldKind.ARRAY_OF_NULLABLE_FLOAT32, + FieldKind.ARRAY_OF_FLOAT64: FieldKind.ARRAY_OF_NULLABLE_FLOAT64, + } + + ARRAY_FIELD_KINDS_WITH_NULLABLE_ITEMS = [ + kind + for kind in VAR_SIZED_FIELD_KINDS + if ("ARRAY" in kind.name) and (kind not in FIX_SIZED_ARRAY_TO_NULLABLE_FIX_SIZED_ARRAY) + ] +else: + FIELD_KINDS = [] + FIX_SIZED_FIELD_KINDS = [] + VAR_SIZED_FIELD_KINDS = [] + FIX_SIZED_TO_NULLABLE = {} + FIX_SIZED_ARRAY_TO_NULLABLE_FIX_SIZED_ARRAY = {} + ARRAY_FIELD_KINDS_WITH_NULLABLE_ITEMS = [] + + +@unittest.skipIf( + compare_client_version("5.2") < 0, "Tests the features added in 5.2 version of the client" +) +class CompactTestBase(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + rc = None + cluster = None + member = None + + @classmethod + def setUpClass(cls) -> None: + cls.rc = cls.create_rc() + if 
compare_server_version_with_rc(cls.rc, "5.2") < 0: + cls.rc.exit() + raise unittest.SkipTest("Compact serialization requires 5.2 server") + + cls.cluster = cls.create_cluster(cls.rc, None) + cls.member = cls.cluster.start_member() + + @classmethod + def tearDownClass(cls): + cls.rc.terminateCluster(cls.cluster.id) + cls.rc.exit() + + async def asyncTearDown(self) -> None: + await self.shutdown_all_clients() + + +class CompactTest(CompactTestBase): + async def test_write_then_read_with_all_fields(self): + serializer = SomeFieldsSerializer.from_kinds(FIELD_KINDS) + await self._write_then_read(FIELD_KINDS, REFERENCE_OBJECTS, serializer) + + async def test_write_then_read_with_no_fields(self): + serializer = SomeFieldsSerializer.from_kinds([]) + await self._write_then_read([], {}, serializer) + + async def test_write_then_read_with_just_var_sized_fields(self): + serializer = SomeFieldsSerializer.from_kinds(VAR_SIZED_FIELD_KINDS) + await self._write_then_read(VAR_SIZED_FIELD_KINDS, REFERENCE_OBJECTS, serializer) + + async def test_write_then_read_with_just_fix_sized_fields(self): + serializer = SomeFieldsSerializer.from_kinds(FIX_SIZED_FIELD_KINDS) + await self._write_then_read(FIX_SIZED_FIELD_KINDS, REFERENCE_OBJECTS, serializer) + + async def test_write_then_read_object_with_different_position_readers(self): + params = [ + ("uint8_reader", 1), + ("uint16_reader", 20), + ("int32_reader", 42), + ] + for name, array_item_count in params: + with self.subTest(name, array_item_count=array_item_count): + reference_objects = { + FieldKind.ARRAY_OF_STRING: [ + "x" * (i * 100) for i in range(1, array_item_count) + ], + FieldKind.INT32: 32, + FieldKind.STRING: "hey", + } + reference_objects[FieldKind.ARRAY_OF_STRING].append(None) + serializer = SomeFieldsSerializer.from_kinds(list(reference_objects.keys())) + await self._write_then_read( + list(reference_objects.keys()), reference_objects, serializer + ) + + async def test_write_then_read_boolean_array(self): + params = [ + ("0", 0), + ("1", 1), + ("8", 8), + ("10", 10), + ("100", 100), + ("1000", 1000), + ] + for name, item_count in params: + with self.subTest(name, item_count=item_count): + reference_objects = { + FieldKind.ARRAY_OF_BOOLEAN: [ + random.randrange(0, 10) % 2 == 0 for _ in range(item_count) + ] + } + serializer = SomeFieldsSerializer.from_kinds(list(reference_objects.keys())) + await self._write_then_read( + list(reference_objects.keys()), reference_objects, serializer + ) + + async def test_write_and_read_with_multiple_boolean_fields(self): + params = [ + ("0", 0), + ("1", 1), + ("8", 8), + ("10", 10), + ("100", 100), + ("1000", 1000), + ] + for name, field_count in params: + with self.subTest(name, field_count=field_count): + all_fields = {str(i): random.randrange(0, 2) % 2 == 0 for i in range(field_count)} + + class Serializer(CompactSerializer[SomeFields]): + def __init__(self, field_names: typing.List[str]): + self._field_names = field_names + + def read(self, reader: CompactReader) -> SomeFields: + fields = {} + for field_name in self._field_names: + fields[field_name] = reader.read_boolean(field_name) + + return SomeFields(**fields) + + def write(self, writer: CompactWriter, obj: SomeFields) -> None: + for field_name in self._field_names: + writer.write_boolean(field_name, getattr(obj, field_name)) + + def get_type_name(self) -> str: + return SomeFields.__name__ + + def get_class(self) -> typing.Type[SomeFields]: + return SomeFields + + await self._write_then_read0(all_fields, Serializer(list(all_fields.keys()))) + + async def 
test_write_then_read(self): + params = [(field_kind.name, field_kind) for field_kind in FIELD_KINDS] + if not params: + self.skipTest("empty") + for name, field_kind in params: + with self.subTest(name, field_kind=field_kind): + field_name = field_kind.name.lower() + m = await self._put_entry( + map_name=random_string(), + value_to_put=REFERENCE_OBJECTS[field_kind], + field_name=field_name, + ) + obj = await m.get("key") + self.assertTrue(is_equal(REFERENCE_OBJECTS[field_kind], getattr(obj, field_name))) + + async def test_write_none_then_read(self): + params = [(field_kind.name, field_kind) for field_kind in VAR_SIZED_FIELD_KINDS] + if not params: + self.skipTest("empty") + for name, field_kind in params: + with self.subTest(name, field_kind=field_kind): + field_name = field_kind.name.lower() + m = await self._put_entry( + map_name=random_string(), + value_to_put=None, + field_name=field_name, + ) + obj = await m.get("key") + self.assertIsNone(getattr(obj, field_name)) + + async def test_write_array_with_none_items_then_read(self): + params = [ + (field_kind.name, field_kind) for field_kind in ARRAY_FIELD_KINDS_WITH_NULLABLE_ITEMS + ] + if not params: + self.skipTest("empty") + for name, field_kind in params: + with self.subTest(name, field_kind=field_kind): + field_name = field_kind.name.lower() + value = [None] + REFERENCE_OBJECTS[field_kind] + [None] + value.insert(2, None) + m = await self._put_entry( + map_name=random_string(), + value_to_put=value, + field_name=field_name, + ) + obj = await m.get("key") + self.assertTrue(is_equal(value, getattr(obj, field_name))) + + async def test_read_when_field_does_not_exist(self): + params = [(field_kind.name, field_kind) for field_kind in FIELD_KINDS] + if not params: + self.skipTest("empty") + for name, field_kind in params: + with self.subTest(name, field_kind=field_kind): + map_name = random_string() + field_name = field_kind.name.lower() + await self._put_entry( + map_name=map_name, + value_to_put=REFERENCE_OBJECTS[field_kind], + field_name=field_name, + ) + + client = await self.create_client( + { + "cluster_name": self.cluster.id, + "compact_serializers": [ + SomeFieldsSerializer( + [ + FieldDefinition( + name=field_name, + name_to_read="not-a-field", + reader_method_name=f"read_{field_name}", + ) + ] + ), + NestedSerializer(), + ], + } + ) + + evolved_m = await client.get_map(map_name) + with self.assertRaisesRegex(HazelcastSerializationError, "No field with the name"): + await evolved_m.get("key") + + async def test_read_with_type_mismatch(self): + params = [(field_kind.name, field_kind) for field_kind in FIELD_KINDS] + if not params: + self.skipTest("empty") + for name, field_kind in params: + with self.subTest(name, field_kind=field_kind): + map_name = random_string() + mismatched_field_kind = FIELD_KINDS[(field_kind.value + 1) % len(FIELD_KINDS)] + field_name = field_kind.name.lower() + await self._put_entry( + map_name=map_name, + value_to_put=REFERENCE_OBJECTS[mismatched_field_kind], + field_name=field_name, + writer_method_name=f"write_{mismatched_field_kind.name.lower()}", + ) + + client = await self.create_client( + { + "cluster_name": self.cluster.id, + "compact_serializers": [ + SomeFieldsSerializer([FieldDefinition(name=field_name)]), + NestedSerializer(), + ], + } + ) + + m = await client.get_map(map_name) + with self.assertRaisesRegex(HazelcastSerializationError, "Mismatched field types"): + await m.get("key") + + async def test_write_then_read_as_nullable(self): + params = [ + (field_kind.name, field_kind, 
nullable_field_kind) + for field_kind, nullable_field_kind in itertools.chain( + FIX_SIZED_TO_NULLABLE.items(), + FIX_SIZED_ARRAY_TO_NULLABLE_FIX_SIZED_ARRAY.items(), + ) + ] + if not params: + self.skipTest("empty") + for name, field_kind, nullable_field_kind in params: + with self.subTest(name, field_kind=field_kind, nullable_field_kind=nullable_field_kind): + map_name = random_string() + field_name = field_kind.name.lower() + await self._put_entry( + map_name=map_name, + value_to_put=REFERENCE_OBJECTS[field_kind], + field_name=field_name, + ) + nullable_method_suffix = nullable_field_kind.name.lower() + client = await self.create_client( + { + "cluster_name": self.cluster.id, + "compact_serializers": [ + SomeFieldsSerializer( + [ + FieldDefinition( + name=field_name, + reader_method_name=f"read_{nullable_method_suffix}", + ) + ] + ), + ], + } + ) + + m = await client.get_map(map_name) + obj = await m.get("key") + self.assertTrue(is_equal(REFERENCE_OBJECTS[field_kind], getattr(obj, field_name))) + + async def test_write_as_nullable_then_read(self): + params = [ + (field_kind.name, field_kind, nullable_field_kind) + for field_kind, nullable_field_kind in itertools.chain( + FIX_SIZED_TO_NULLABLE.items(), + FIX_SIZED_ARRAY_TO_NULLABLE_FIX_SIZED_ARRAY.items(), + ) + ] + if not params: + self.skipTest("empty") + for name, field_kind, nullable_field_kind in params: + with self.subTest(name, field_kind=field_kind, nullable_field_kind=nullable_field_kind): + map_name = random_string() + nullable_method_suffix = nullable_field_kind.name.lower() + field_name = field_kind.name.lower() + await self._put_entry( + map_name=map_name, + value_to_put=REFERENCE_OBJECTS[field_kind], + field_name=field_name, + writer_method_name=f"write_{nullable_method_suffix}", + ) + client = await self.create_client( + { + "cluster_name": self.cluster.id, + "compact_serializers": [ + SomeFieldsSerializer([FieldDefinition(name=field_name)]), + ], + } + ) + + m = await client.get_map(map_name) + obj = await m.get("key") + self.assertTrue(is_equal(REFERENCE_OBJECTS[field_kind], getattr(obj, field_name))) + + async def test_write_nullable_fix_sized_as_none_then_read_as_fix_sized(self): + params = [ + (field_kind.name, field_kind, nullable_field_kind) + for field_kind, nullable_field_kind in FIX_SIZED_TO_NULLABLE.items() + ] + if not params: + self.skipTest("empty") + for name, field_kind, nullable_field_kind in params: + with self.subTest(name, field_kind=field_kind, nullable_field_kind=nullable_field_kind): + map_name = random_string() + nullable_method_suffix = nullable_field_kind.name.lower() + field_name = field_kind.name.lower() + await self._put_entry( + map_name=map_name, + value_to_put=None, + field_name=field_name, + writer_method_name=f"write_{nullable_method_suffix}", + ) + client = await self.create_client( + { + "cluster_name": self.cluster.id, + "compact_serializers": [ + SomeFieldsSerializer([FieldDefinition(name=field_name)]), + ], + } + ) + + m = await client.get_map(map_name) + with self.assertRaisesRegex( + HazelcastSerializationError, "A 'None' value cannot be read" + ): + await m.get("key") + + async def test_write_nullable_fix_sized_array_with_none_item_then_read_as_fix_sized_array(self): + params = [ + (field_kind.name, field_kind, nullable_field_kind) + for field_kind, nullable_field_kind in FIX_SIZED_ARRAY_TO_NULLABLE_FIX_SIZED_ARRAY.items() + ] + if not params: + self.skipTest("empty") + for name, field_kind, nullable_field_kind in params: + with self.subTest(name, field_kind=field_kind, 
nullable_field_kind=nullable_field_kind): + map_name = random_string() + nullable_method_suffix = nullable_field_kind.name.lower() + field_name = field_kind.name.lower() + await self._put_entry( + map_name=map_name, + value_to_put=[None], + field_name=field_name, + writer_method_name=f"write_{nullable_method_suffix}", + ) + client = await self.create_client( + { + "cluster_name": self.cluster.id, + "compact_serializers": [ + SomeFieldsSerializer([FieldDefinition(name=field_name)]), + ], + } + ) + m = await client.get_map(map_name) + with self.assertRaisesRegex( + HazelcastSerializationError, "A `None` item cannot be read" + ): + await m.get("key") + + async def test_write_then_read_with_default_value(self): + params = [(field_kind.name, field_kind) for field_kind in FIELD_KINDS] + if not params: + self.skipTest("empty") + for name, field_kind in params: + with self.subTest(name, field_kind=field_kind): + field_name = field_kind.name.lower() + m = await self._put_entry( + map_name=random_string(), + value_to_put=REFERENCE_OBJECTS[field_kind], + field_name=field_name, + reader_method_name=f"read_{field_name}_or_default", + default_value_to_read=object(), + ) + obj = await m.get("key") + self.assertTrue(is_equal(REFERENCE_OBJECTS[field_kind], getattr(obj, field_name))) + + async def test_write_then_read_with_default_value_when_field_name_does_not_match(self): + params = [(field_kind.name, field_kind) for field_kind in FIELD_KINDS] + if not params: + self.skipTest("empty") + for name, field_kind in params: + with self.subTest(name, field_kind=field_kind): + field_name = field_kind.name.lower() + default_value = object() + m = await self._put_entry( + map_name=random_string(), + value_to_put=REFERENCE_OBJECTS[field_kind], + field_name=field_name, + field_name_to_read="not-a-field", + reader_method_name=f"read_{field_name}_or_default", + default_value_to_read=default_value, + ) + obj = await m.get("key") + self.assertTrue(getattr(obj, field_name) is default_value) + + async def test_write_then_read_with_default_value_when_field_type_does_not_match(self): + params = [(field_kind.name, field_kind) for field_kind in FIELD_KINDS] + if not params: + self.skipTest("empty") + for name, field_kind in params: + with self.subTest(name, field_kind=field_kind): + field_name = field_kind.name.lower() + mismatched_field_kind = FIELD_KINDS[(field_kind.value + 1) % len(FIELD_KINDS)] + default_value = object() + m = await self._put_entry( + map_name=random_string(), + value_to_put=REFERENCE_OBJECTS[mismatched_field_kind], + field_name=field_name, + field_name_to_read=field_name, + writer_method_name=f"write_{mismatched_field_kind.name.lower()}", + reader_method_name=f"read_{field_name}_or_default", + default_value_to_read=default_value, + ) + obj = await m.get("key") + self.assertTrue(getattr(obj, field_name) is default_value) + + async def _put_entry( + self, + *, + map_name: str, + value_to_put: typing.Any, + field_name: str, + field_name_to_read=None, + writer_method_name=None, + reader_method_name=None, + default_value_to_read=None, + ): + field_definition = FieldDefinition( + name=field_name, + name_to_read=field_name_to_read or field_name, + writer_method_name=writer_method_name or f"write_{field_name}", + reader_method_name=reader_method_name or f"read_{field_name}", + default_value_to_read=default_value_to_read, + ) + client = await self.create_client( + { + "cluster_name": self.cluster.id, + "compact_serializers": [ + SomeFieldsSerializer([field_definition]), + NestedSerializer(), + ], + } + ) + + m = 
await client.get_map(map_name) + await m.put("key", SomeFields(**{field_name: value_to_put})) + return m + + async def _write_then_read( + self, + kinds: typing.List[FieldKind], + reference_objects: typing.Dict[FieldKind, typing.Any], + serializer: CompactSerializer, + ): + fields = {kind.name.lower(): reference_objects[kind] for kind in kinds} + await self._write_then_read0(fields, serializer) + + async def _write_then_read0( + self, fields: typing.Dict[str, typing.Any], serializer: CompactSerializer + ): + client = await self.create_client( + { + "cluster_name": self.cluster.id, + "compact_serializers": [serializer, NestedSerializer()], + } + ) + + m = await client.get_map(random_string()) + await m.put("key", SomeFields(**fields)) + obj = await m.get("key") + for name, value in fields.items(): + self.assertTrue(is_equal(value, getattr(obj, name))) + + +class CompactSchemaEvolutionTest(CompactTestBase): + async def test_adding_a_fix_sized_field(self): + await self._verify_adding_a_field( + ("int32", 42), + ("string", "42"), + new_field_name="int64", + new_field_value=24, + new_field_default_value=12, + ) + + async def test_removing_a_fix_sized_field(self): + await self._verify_removing_a_field( + ("int64", 1234), + ("string", "hey"), + removed_field_name="int64", + removed_field_default_value=43321, + ) + + async def test_adding_a_var_sized_field(self): + await self._verify_adding_a_field( + ("int32", 42), + ("string", "42"), + new_field_name="array_of_boolean", + new_field_value=[True, False, True], + new_field_default_value=[False, False, False, True], + ) + + async def test_removing_a_var_sized_field(self): + await self._verify_removing_a_field( + ("int64", 1234), + ("string", "hey"), + removed_field_name="string", + removed_field_default_value="43321", + ) + + async def _create_client(self, serializer: CompactSerializer): + return await self.create_client( + { + "cluster_name": self.cluster.id, + "compact_serializers": [serializer], + } + ) + + async def _verify_adding_a_field( + self, + *existing_fields: typing.Tuple[str, typing.Any], + new_field_name: str, + new_field_value: typing.Any, + new_field_default_value: typing.Any, + ): + map_name = random_string() + v1_field_definitions = [FieldDefinition(name=name) for name, _ in existing_fields] + v1_serializer = SomeFieldsSerializer(v1_field_definitions) + v1_client = await self._create_client(v1_serializer) + v1_map = await v1_client.get_map(map_name) + v1_fields = {name: value for name, value in existing_fields} + await v1_map.put("key1", SomeFields(**v1_fields)) + + v2_field_definitions = v1_field_definitions + [FieldDefinition(name=new_field_name)] + v2_serializer = SomeFieldsSerializer(v2_field_definitions) + v2_client = await self._create_client(v2_serializer) + v2_map = await v2_client.get_map(map_name) + v2_fields = copy.deepcopy(v1_fields) + v2_fields[new_field_name] = new_field_value + await v2_map.put("key2", SomeFields(**v2_fields)) + + careful_v2_field_definitions = v1_field_definitions + [ + FieldDefinition( + name=new_field_name, + reader_method_name=f"read_{new_field_name}_or_default", + default_value_to_read=new_field_default_value, + ) + ] + careful_v2_serializer = SomeFieldsSerializer(careful_v2_field_definitions) + careful_client_v2 = await self._create_client(careful_v2_serializer) + careful_v2_map = await careful_client_v2.get_map(map_name) + + # Old client can read data written by the new client + v1_obj = await v1_map.get("key2") + for name in v1_fields: + self.assertEqual(v2_fields[name], getattr(v1_obj, 
name)) + + # New client cannot read data written by the old client, since + # there is no such field on the old data. + + with self.assertRaisesRegex(HazelcastSerializationError, "No field with the name"): + await v2_map.get("key1") + + # However, if it has default value, everything should work + + careful_v2_obj = await careful_v2_map.get("key1") + for name in v2_fields: + self.assertEqual( + v1_fields.get(name) or new_field_default_value, + getattr(careful_v2_obj, name), + ) + + async def _verify_removing_a_field( + self, + *existing_fields: typing.Tuple[str, typing.Any], + removed_field_name: str, + removed_field_default_value: typing.Any, + ): + map_name = random_string() + v1_field_definitions = [FieldDefinition(name=name) for name, _ in existing_fields] + v1_serializer = SomeFieldsSerializer(v1_field_definitions) + v1_client = await self._create_client(v1_serializer) + v1_map = await v1_client.get_map(map_name) + v1_fields = {name: value for name, value in existing_fields} + await v1_map.put("key1", SomeFields(**v1_fields)) + + v2_field_definitions = [ + FieldDefinition(name=name) for name, _ in existing_fields if name != removed_field_name + ] + v2_serializer = SomeFieldsSerializer(v2_field_definitions) + v2_client = await self._create_client(v2_serializer) + v2_map = await v2_client.get_map(map_name) + v2_fields = copy.deepcopy(v1_fields) + del v2_fields[removed_field_name] + await v2_map.put("key2", SomeFields(**v2_fields)) + + careful_v1_field_definitions = v2_field_definitions + [ + FieldDefinition( + name=removed_field_name, + reader_method_name=f"read_{removed_field_name}_or_default", + default_value_to_read=removed_field_default_value, + ) + ] + careful_v1_serializer = SomeFieldsSerializer(careful_v1_field_definitions) + careful_client_v1 = await self._create_client(careful_v1_serializer) + careful_v1_map = await careful_client_v1.get_map(map_name) + + # Old client cannot read data written by the new client, since + # there is no such field on the new data + + with self.assertRaisesRegex(HazelcastSerializationError, "No field with the name"): + await v1_map.get("key2") + + # However, if it has default value, everything should work + v1_obj = await careful_v1_map.get("key2") + for name in v1_fields: + self.assertEqual( + v2_fields.get(name) or removed_field_default_value, + getattr(v1_obj, name), + ) + + # New client can read data written by the old client + v2_obj = await v2_map.get("key1") + for name in v2_fields: + self.assertEqual(v1_fields[name], getattr(v2_obj, name)) + + with self.assertRaises(AttributeError): + getattr(v2_obj, removed_field_name) # no such field for the new schema + + +class CompactOnClusterRestartTest(CompactTestBase): + async def test_cluster_restart(self): + client = await self.create_client( + { + "cluster_name": self.cluster.id, + "compact_serializers": [SomeFieldsSerializer([FieldDefinition(name="int32")])], + } + ) + m = await client.get_map(random_string()) + await m.put(1, SomeFields(int32=42)) + # self.rc.terminateMember(self.cluster.id, self.member.uuid) + # CompactOnClusterRestartTest.member = self.cluster.start_member() + await asyncio.to_thread(self._restart) + await m.put(1, SomeFields(int32=42)) + obj = await m.get(1) + self.assertEqual(42, obj.int32) + # Perform a query to make sure that the schema is available on the cluster + self.assertEqual(1, len(await m.values(sql("int32 == 42")))) + + def _restart(self): + self.rc.terminateMember(self.cluster.id, self.member.uuid) + CompactOnClusterRestartTest.member = 
self.cluster.start_member() + + +class CompactWithListenerTest(CompactTestBase): + async def test_map_listener(self): + config = { + "cluster_name": self.cluster.id, + "compact_serializers": [SomeFieldsSerializer([FieldDefinition(name="int32")])], + } + client = await self.create_client(config) + map_name = random_string() + m = await client.get_map(map_name) + counter = AtomicInteger() + + def listener(_): + counter.add(1) + + await m.add_entry_listener(include_value=True, added_func=listener) + # Put the entry from other client to not create a local + # registry in the actual client. This will force it to + # go the cluster to fetch the schema. + other_client = await self.create_client(config) + other_client_map = await other_client.get_map(map_name) + await other_client_map.put(1, SomeFields(int32=42)) + await self.assertTrueEventually(lambda: self.assertEqual(1, counter.get())) + + +class SomeFields: + def __init__(self, **fields): + self._fields = fields + + def __getattr__(self, item): + if item not in self._fields: + raise AttributeError() + + return self._fields[item] + + +class Nested: + def __init__(self, i32_field, string_field): + self.i32_field = i32_field + self.string_field = string_field + + def __eq__(self, other): + return ( + isinstance(other, Nested) + and self.i32_field == other.i32_field + and self.string_field == other.string_field + ) + + +class NestedSerializer(CompactSerializer[Nested]): + def read(self, reader: CompactReader) -> Nested: + return Nested(reader.read_int32("i32_field"), reader.read_string("string_field")) + + def write(self, writer: CompactWriter, obj: Nested) -> None: + writer.write_int32("i32_field", obj.i32_field) + writer.write_string("string_field", obj.string_field) + + def get_type_name(self) -> str: + return Nested.__name__ + + def get_class(self) -> typing.Type[Nested]: + return Nested + + +class FieldDefinition: + def __init__( + self, + *, + name: str, + name_to_read: str = None, + writer_method_name: str = None, + reader_method_name: str = None, + default_value_to_read: typing.Any = None, + ): + self.name = name + self.name_to_read = name_to_read or name + self.writer_method_name = writer_method_name or f"write_{name}" + self.reader_method_name = reader_method_name or f"read_{name}" + self.default_value_to_read = default_value_to_read + + +class SomeFieldsSerializer(CompactSerializer[SomeFields]): + def __init__(self, field_definitions: typing.List[FieldDefinition]): + self._field_definitions = field_definitions + + def read(self, reader: CompactReader) -> SomeFields: + fields = {} + for field_definition in self._field_definitions: + reader_parameters = [field_definition.name_to_read] + default_value_to_read = field_definition.default_value_to_read + if default_value_to_read is not None: + reader_parameters.append(default_value_to_read) + + value = getattr(reader, field_definition.reader_method_name)(*reader_parameters) + fields[field_definition.name] = value + + return SomeFields(**fields) + + def write(self, writer: CompactWriter, obj: SomeFields) -> None: + for field_definition in self._field_definitions: + getattr(writer, field_definition.writer_method_name)( + field_definition.name, + getattr(obj, field_definition.name), + ) + + def get_type_name(self) -> str: + return SomeFields.__name__ + + def get_class(self) -> typing.Type[SomeFields]: + return SomeFields + + @staticmethod + def from_kinds(kinds: typing.List[FieldKind]) -> "SomeFieldsSerializer": + field_definitions = [FieldDefinition(name=kind.name.lower()) for kind in kinds] + 
return SomeFieldsSerializer(field_definitions) + + +if _COMPACT_AVAILABLE: + REFERENCE_OBJECTS = { + FieldKind.BOOLEAN: True, + FieldKind.ARRAY_OF_BOOLEAN: [True, False, True, True, True, False, True, True, False], + FieldKind.INT8: 42, + FieldKind.ARRAY_OF_INT8: [42, -128, -1, 127], + FieldKind.INT16: -456, + FieldKind.ARRAY_OF_INT16: [-4231, 12343, 0], + FieldKind.INT32: 21212121, + FieldKind.ARRAY_OF_INT32: [-1, 1, 0, 9999999], + FieldKind.INT64: 123456789, + FieldKind.ARRAY_OF_INT64: [11, -123456789], + FieldKind.FLOAT32: 12.5, + FieldKind.ARRAY_OF_FLOAT32: [-13.13, 12345.67, 0.1, 9876543.2, -99999.99], + FieldKind.FLOAT64: 12345678.90123, + FieldKind.ARRAY_OF_FLOAT64: [-12345.67], + FieldKind.STRING: "üğişçöa", + FieldKind.ARRAY_OF_STRING: ["17", "😊 😇 🙂", "abc"], + FieldKind.DECIMAL: decimal.Decimal("123.456"), + FieldKind.ARRAY_OF_DECIMAL: [decimal.Decimal("0"), decimal.Decimal("-123456.789")], + FieldKind.TIME: datetime.time(2, 3, 4, 5), + FieldKind.ARRAY_OF_TIME: [datetime.time(8, 7, 6, 5)], + FieldKind.DATE: datetime.date(2022, 1, 1), + FieldKind.ARRAY_OF_DATE: [datetime.date(2021, 11, 11), datetime.date(2020, 3, 3)], + FieldKind.TIMESTAMP: datetime.datetime(2022, 2, 2, 3, 3, 3, 4), + FieldKind.ARRAY_OF_TIMESTAMP: [datetime.datetime(1990, 2, 12, 13, 14, 54, 98765)], + FieldKind.TIMESTAMP_WITH_TIMEZONE: datetime.datetime( + 200, 10, 10, 16, 44, 42, 12345, datetime.timezone(datetime.timedelta(hours=2)) + ), + FieldKind.ARRAY_OF_TIMESTAMP_WITH_TIMEZONE: [ + datetime.datetime( + 2001, 1, 10, 12, 24, 2, 45, datetime.timezone(datetime.timedelta(hours=-2)) + ) + ], + FieldKind.COMPACT: Nested(42, "42"), + FieldKind.ARRAY_OF_COMPACT: [Nested(-42, "-42"), Nested(123, "123")], + FieldKind.NULLABLE_BOOLEAN: False, + FieldKind.ARRAY_OF_NULLABLE_BOOLEAN: [False, False, True], + FieldKind.NULLABLE_INT8: 34, + FieldKind.ARRAY_OF_NULLABLE_INT8: [-32, 32], + FieldKind.NULLABLE_INT16: 36, + FieldKind.ARRAY_OF_NULLABLE_INT16: [37, -37, 0, 12345], + FieldKind.NULLABLE_INT32: -38, + FieldKind.ARRAY_OF_NULLABLE_INT32: [-39, 2134567, -8765432, 39], + FieldKind.NULLABLE_INT64: -4040, + FieldKind.ARRAY_OF_NULLABLE_INT64: [1, 41, -1, 12312312312, -9312912391], + FieldKind.NULLABLE_FLOAT32: 42.4, + FieldKind.ARRAY_OF_NULLABLE_FLOAT32: [-43.4, 434.43], + FieldKind.NULLABLE_FLOAT64: 44.12, + FieldKind.ARRAY_OF_NULLABLE_FLOAT64: [45.678, -4567.8, 0.12345], + } diff --git a/tests/integration/asyncio/util.py b/tests/integration/asyncio/util.py index 6a15d9c8ec..f96297d2b1 100644 --- a/tests/integration/asyncio/util.py +++ b/tests/integration/asyncio/util.py @@ -12,16 +12,16 @@ async def fill_map(map, size=10, key_prefix="key", value_prefix="val"): async def open_connection_to_address(client, uuid): - key = generate_key_owned_by_instance(client, uuid) + key = await generate_key_owned_by_instance(client, uuid) m = await client.get_map(str(uuid4())) await m.put(key, 0) await m.destroy() -def generate_key_owned_by_instance(client, uuid): +async def generate_key_owned_by_instance(client, uuid) -> str: while True: key = str(uuid4()) - partition_id = client.partition_service.get_partition_id(key) + partition_id = await client.partition_service.get_partition_id(key) owner = str(client.partition_service.get_partition_owner(partition_id)) if owner == uuid: return key
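
As a minimal sketch of how a hand-written CompactSerializer is registered and used — the same compact_serializers option the tests above pass through create_client — the following uses the synchronous client's public API (hazelcast.HazelcastClient); the asyncio client this PR builds is assumed to accept the equivalent configuration. Point, PointSerializer, and the map name are illustrative and not part of this change.

import typing

import hazelcast
from hazelcast.serialization.api import CompactReader, CompactSerializer, CompactWriter


class Point:
    def __init__(self, x: int, y: int):
        self.x = x
        self.y = y


class PointSerializer(CompactSerializer[Point]):
    def read(self, reader: CompactReader) -> Point:
        # Field names must match the ones used in write().
        return Point(reader.read_int32("x"), reader.read_int32("y"))

    def write(self, writer: CompactWriter, obj: Point) -> None:
        writer.write_int32("x", obj.x)
        writer.write_int32("y", obj.y)

    def get_type_name(self) -> str:
        # Identifies the schema on the cluster; keep it stable across versions.
        return "Point"

    def get_class(self) -> typing.Type[Point]:
        return Point


client = hazelcast.HazelcastClient(compact_serializers=[PointSerializer()])
points = client.get_map("points").blocking()
points.put("origin", Point(1, 2))
print(points.get("origin").x)  # 1
client.shutdown()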
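
CompactSchemaEvolutionTest relies on the read_*_or_default readers to tolerate fields that exist in only one schema version. A rough sketch of that pattern in a hand-written serializer follows, assuming the read_int32_or_default(name, default) reader the tests call; PointV2 and the added field "z" are hypothetical.

import typing

from hazelcast.serialization.api import CompactReader, CompactSerializer, CompactWriter


class PointV2:
    def __init__(self, x: int, y: int, z: int = 0):
        self.x = x
        self.y = y
        self.z = z


class PointV2Serializer(CompactSerializer[PointV2]):
    def read(self, reader: CompactReader) -> PointV2:
        x = reader.read_int32("x")
        y = reader.read_int32("y")
        # "z" is absent from data written with the old schema; fall back to 0
        # instead of raising "No field with the name" as the strict reader would.
        z = reader.read_int32_or_default("z", 0)
        return PointV2(x, y, z)

    def write(self, writer: CompactWriter, obj: PointV2) -> None:
        writer.write_int32("x", obj.x)
        writer.write_int32("y", obj.y)
        writer.write_int32("z", obj.z)

    def get_type_name(self) -> str:
        return "Point"  # same type name as the old schema, different field set

    def get_class(self) -> typing.Type[PointV2]:
        return PointV2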
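
The nullable round-trip tests work because the fix-sized and nullable accessors are interchangeable until a None value is involved. A small sketch of the reading side, assuming the read_nullable_int32 reader invoked by the tests; read_score and the field name "score" are illustrative.

import typing

from hazelcast.serialization.api import CompactReader


def read_score(reader: CompactReader) -> typing.Optional[int]:
    # The tests above show a value written with write_int32 can be read back
    # through read_nullable_int32; reading a stored None with read_int32
    # instead raises HazelcastSerializationError ("A 'None' value cannot be read").
    return reader.read_nullable_int32("score")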