diff --git a/.github/workflows/docs-pages.yml b/.github/workflows/docs-pages.yml index f017201c7c..0da86fef34 100644 --- a/.github/workflows/docs-pages.yml +++ b/.github/workflows/docs-pages.yml @@ -2,6 +2,9 @@ name: "Docs / Publish" # For more information, # see https://sphinx-theme.scylladb.com/stable/deployment/production.html#available-workflows +permissions: + contents: write + on: push: branches: diff --git a/.github/workflows/publish-manually.yml b/.github/workflows/publish-manually.yml index d69475f394..09b9779117 100644 --- a/.github/workflows/publish-manually.yml +++ b/.github/workflows/publish-manually.yml @@ -1,5 +1,8 @@ name: Build and upload to PyPi manually +permissions: + contents: read + on: workflow_dispatch: inputs: diff --git a/.gitignore b/.gitignore index 28cf1ba218..1b60f54642 100644 --- a/.gitignore +++ b/.gitignore @@ -43,6 +43,11 @@ tests/unit/cython/bytesio_testhelper.c #iPython *.ipynb +uv.lock +.venv/ + + + # Files from upstream that we don't need Jenkinsfile Jenkinsfile.bak diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 66bf7c7049..099043eae0 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -838,8 +838,8 @@ def default_retry_policy(self, policy): Using ssl_options without ssl_context is deprecated and will be removed in the next major release. - An optional dict which will be used as kwargs for ``ssl.SSLContext.wrap_socket`` - when new sockets are created. This should be used when client encryption is enabled + An optional dict which will be used as kwargs for ``ssl.SSLContext.wrap_socket`` + when new sockets are created. This should be used when client encryption is enabled in Cassandra. The following documentation only applies when ssl_options is used without ssl_context. @@ -1086,10 +1086,10 @@ def default_retry_policy(self, policy): """ Specifies a server-side timeout (in seconds) for all internal driver queries, such as schema metadata lookups and cluster topology requests. - + The timeout is enforced by appending `USING TIMEOUT ` to queries executed by the driver. - + - A value of `0` disables explicit timeout enforcement. In this case, the driver does not add `USING TIMEOUT`, and the timeout is determined by the server's defaults. @@ -1683,14 +1683,7 @@ def protocol_downgrade(self, host_endpoint, previous_version): "http://datastax.github.io/python-driver/api/cassandra/cluster.html#cassandra.cluster.Cluster.protocol_version", self.protocol_version, new_version, host_endpoint) self.protocol_version = new_version - def _add_resolved_hosts(self): - for endpoint in self.endpoints_resolved: - host, new = self.add_host(endpoint, signal=False) - if new: - host.set_up() - for listener in self.listeners: - listener.on_add(host) - + def _populate_hosts(self): self.profile_manager.populate( weakref.proxy(self), self.metadata.all_hosts()) self.load_balancing_policy.populate( @@ -1717,17 +1710,10 @@ def connect(self, keyspace=None, wait_for_all_pools=False): self.contact_points, self.protocol_version) self.connection_class.initialize_reactor() _register_cluster_shutdown(self) - - self._add_resolved_hosts() try: self.control_connection.connect() - - # we set all contact points up for connecting, but we won't infer state after this - for endpoint in self.endpoints_resolved: - h = self.metadata.get_host(endpoint) - if h and self.profile_manager.distance(h) == HostDistance.IGNORED: - h.is_up = None + self._populate_hosts() log.debug("Control connection created") except Exception: @@ -2016,14 +2002,14 @@ def on_add(self, host, refresh_nodes=True): log.debug("Handling new host %r and notifying listeners", host) + self.profile_manager.on_add(host) + self.control_connection.on_add(host, refresh_nodes) + distance = self.profile_manager.distance(host) if distance != HostDistance.IGNORED: self._prepare_all_queries(host) log.debug("Done preparing queries for new host %r", host) - self.profile_manager.on_add(host) - self.control_connection.on_add(host, refresh_nodes) - if distance == HostDistance.IGNORED: log.debug("Not adding connection pool for new host %r because the " "load balancing policy has marked it as IGNORED", host) @@ -3534,28 +3520,22 @@ def _set_new_connection(self, conn): if old: log.debug("[control connection] Closing old connection %r, replacing with %r", old, conn) old.close() - - def _connect_host_in_lbp(self): + + def _try_connect_to_hosts(self): errors = {} - lbp = ( - self._cluster.load_balancing_policy - if self._cluster._config_mode == _ConfigMode.LEGACY else - self._cluster._default_load_balancing_policy - ) - for host in lbp.make_query_plan(): + lbp = self._cluster.load_balancing_policy \ + if self._cluster._config_mode == _ConfigMode.LEGACY else self._cluster._default_load_balancing_policy + + for endpoint in chain((host.endpoint for host in lbp.make_query_plan()), self._cluster.endpoints_resolved): try: - return (self._try_connect(host), None) - except ConnectionException as exc: - errors[str(host.endpoint)] = exc - log.warning("[control connection] Error connecting to %s:", host, exc_info=True) - self._cluster.signal_connection_failure(host, exc, is_host_addition=False) + return (self._try_connect(endpoint), None) except Exception as exc: - errors[str(host.endpoint)] = exc - log.warning("[control connection] Error connecting to %s:", host, exc_info=True) + errors[str(endpoint)] = exc + log.warning("[control connection] Error connecting to %s:", endpoint, exc_info=True) if self._is_shutdown: raise DriverException("[control connection] Reconnection in progress during shutdown") - + return (None, errors) def _reconnect_internal(self): @@ -3567,43 +3547,43 @@ def _reconnect_internal(self): to the exception that was raised when an attempt was made to open a connection to that host. """ - (conn, _) = self._connect_host_in_lbp() + (conn, _) = self._try_connect_to_hosts() if conn is not None: return conn # Try to re-resolve hostnames as a fallback when all hosts are unreachable self._cluster._resolve_hostnames() - self._cluster._add_resolved_hosts() + self._cluster._populate_hosts() - (conn, errors) = self._connect_host_in_lbp() + (conn, errors) = self._try_connect_to_hosts() if conn is not None: return conn - + raise NoHostAvailable("Unable to connect to any servers", errors) - def _try_connect(self, host): + def _try_connect(self, endpoint): """ Creates a new Connection, registers for pushed events, and refreshes node/token and schema metadata. """ - log.debug("[control connection] Opening new connection to %s", host) + log.debug("[control connection] Opening new connection to %s", endpoint) while True: try: - connection = self._cluster.connection_factory(host.endpoint, is_control_connection=True) + connection = self._cluster.connection_factory(endpoint, is_control_connection=True) if self._is_shutdown: connection.close() raise DriverException("Reconnecting during shutdown") break except ProtocolVersionUnsupported as e: - self._cluster.protocol_downgrade(host.endpoint, e.startup_version) + self._cluster.protocol_downgrade(endpoint, e.startup_version) except ProtocolException as e: # protocol v5 is out of beta in C* >=4.0-beta5 and is now the default driver # protocol version. If the protocol version was not explicitly specified, # and that the server raises a beta protocol error, we should downgrade. if not self._cluster._protocol_version_explicit and e.is_beta_protocol_error: - self._cluster.protocol_downgrade(host.endpoint, self._cluster.protocol_version) + self._cluster.protocol_downgrade(endpoint, self._cluster.protocol_version) else: raise @@ -3818,67 +3798,10 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, self._cluster.metadata.cluster_name = cluster_name partitioner = local_row.get("partitioner") - tokens = local_row.get("tokens") - - host = self._cluster.metadata.get_host(connection.original_endpoint) - if host: - datacenter = local_row.get("data_center") - rack = local_row.get("rack") - self._update_location_info(host, datacenter, rack) - - # support the use case of connecting only with public address - if isinstance(self._cluster.endpoint_factory, SniEndPointFactory): - new_endpoint = self._cluster.endpoint_factory.create(local_row) - - if new_endpoint.address: - host.endpoint = new_endpoint - - host.host_id = local_row.get("host_id") - - found_host_ids.add(host.host_id) - found_endpoints.add(host.endpoint) - - host.listen_address = local_row.get("listen_address") - host.listen_port = local_row.get("listen_port") - host.broadcast_address = _NodeInfo.get_broadcast_address(local_row) - host.broadcast_port = _NodeInfo.get_broadcast_port(local_row) - - host.broadcast_rpc_address = _NodeInfo.get_broadcast_rpc_address(local_row) - host.broadcast_rpc_port = _NodeInfo.get_broadcast_rpc_port(local_row) - if host.broadcast_rpc_address is None: - if self._token_meta_enabled: - # local rpc_address is not available, use the connection endpoint - host.broadcast_rpc_address = connection.endpoint.address - host.broadcast_rpc_port = connection.endpoint.port - else: - # local rpc_address has not been queried yet, try to fetch it - # separately, which might fail because C* < 2.1.6 doesn't have rpc_address - # in system.local. See CASSANDRA-9436. - local_rpc_address_query = QueryMessage( - query=maybe_add_timeout_to_query(self._SELECT_LOCAL_NO_TOKENS_RPC_ADDRESS, self._metadata_request_timeout), - consistency_level=ConsistencyLevel.ONE) - success, local_rpc_address_result = connection.wait_for_response( - local_rpc_address_query, timeout=self._timeout, fail_on_error=False) - if success: - row = dict_factory( - local_rpc_address_result.column_names, - local_rpc_address_result.parsed_rows) - host.broadcast_rpc_address = _NodeInfo.get_broadcast_rpc_address(row[0]) - host.broadcast_rpc_port = _NodeInfo.get_broadcast_rpc_port(row[0]) - else: - host.broadcast_rpc_address = connection.endpoint.address - host.broadcast_rpc_port = connection.endpoint.port - - host.release_version = local_row.get("release_version") - host.dse_version = local_row.get("dse_version") - host.dse_workload = local_row.get("workload") - host.dse_workloads = local_row.get("workloads") + tokens = local_row.get("tokens", None) - if partitioner and tokens: - token_map[host] = tokens + peers_result.insert(0, local_row) - self._cluster.metadata.update_host(host, old_endpoint=connection.endpoint) - connection.original_endpoint = connection.endpoint = host.endpoint # Check metadata.partitioner to see if we haven't built anything yet. If # every node in the cluster was in the contact points, we won't discover # any new nodes, so we need this additional check. (See PYTHON-90) @@ -3908,14 +3831,16 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, host = self._cluster.metadata.get_host_by_host_id(host_id) if host and host.endpoint != endpoint: log.debug("[control connection] Updating host ip from %s to %s for (%s)", host.endpoint, endpoint, host_id) - old_endpoint = host.endpoint - host.endpoint = endpoint - self._cluster.metadata.update_host(host, old_endpoint) reconnector = host.get_and_set_reconnection_handler(None) if reconnector: reconnector.cancel() self._cluster.on_down(host, is_host_addition=False, expect_host_to_be_down=True) + old_endpoint = host.endpoint + host.endpoint = endpoint + self._cluster.metadata.update_host(host, old_endpoint) + self._cluster.on_up(host) + if host is None: log.debug("[control connection] Found new host to connect to: %s", endpoint) host, _ = self._cluster.add_host(endpoint, datacenter=datacenter, rack=rack, signal=True, refresh_nodes=False, host_id=host_id) @@ -4177,8 +4102,9 @@ def _get_peers_query(self, peers_query_type, connection=None): query_template = (self._SELECT_SCHEMA_PEERS_TEMPLATE if peers_query_type == self.PeersQueryType.PEERS_SCHEMA else self._SELECT_PEERS_NO_TOKENS_TEMPLATE) - host_release_version = self._cluster.metadata.get_host(connection.original_endpoint).release_version - host_dse_version = self._cluster.metadata.get_host(connection.original_endpoint).dse_version + original_endpoint_host = self._cluster.metadata.get_host(connection.original_endpoint) + host_release_version = None if original_endpoint_host is None else original_endpoint_host.release_version + host_dse_version = None if original_endpoint_host is None else original_endpoint_host.dse_version uses_native_address_query = ( host_dse_version and Version(host_dse_version) >= self._MINIMUM_NATIVE_ADDRESS_DSE_VERSION) @@ -4547,7 +4473,7 @@ def _make_query_plan(self): # or to the explicit host target if set if self._host: # returning a single value effectively disables retries - self.query_plan = [self._host] + self.query_plan = iter([self._host]) else: # convert the list/generator/etc to an iterator so that subsequent # calls to send_request (which retries may do) will resume where diff --git a/cassandra/deserializers.pyx b/cassandra/deserializers.pyx index 7c256674b0..97d249d02f 100644 --- a/cassandra/deserializers.pyx +++ b/cassandra/deserializers.pyx @@ -208,15 +208,9 @@ cdef class _DesSingleParamType(_DesParameterizedType): cdef class DesListType(_DesSingleParamType): cdef deserialize(self, Buffer *buf, int protocol_version): - cdef uint16_t v2_and_below = 2 - cdef int32_t v3_and_above = 3 - if protocol_version >= 3: - result = _deserialize_list_or_set[int32_t]( - v3_and_above, buf, protocol_version, self.deserializer) - else: - result = _deserialize_list_or_set[uint16_t]( - v2_and_below, buf, protocol_version, self.deserializer) + result = _deserialize_list_or_set( + buf, protocol_version, self.deserializer) return result @@ -225,60 +219,49 @@ cdef class DesSetType(DesListType): return util.sortedset(DesListType.deserialize(self, buf, protocol_version)) -ctypedef fused itemlen_t: - uint16_t # protocol <= v2 - int32_t # protocol >= v3 - -cdef list _deserialize_list_or_set(itemlen_t dummy_version, - Buffer *buf, int protocol_version, +cdef list _deserialize_list_or_set(Buffer *buf, int protocol_version, Deserializer deserializer): """ Deserialize a list or set. - - The 'dummy' parameter is needed to make fused types work, so that - we can specialize on the protocol version. """ cdef Buffer itemlen_buf cdef Buffer elem_buf - cdef itemlen_t numelements + cdef int32_t numelements cdef int offset cdef list result = [] - _unpack_len[itemlen_t](buf, 0, &numelements) - offset = sizeof(itemlen_t) + _unpack_len(buf, 0, &numelements) + offset = sizeof(int32_t) protocol_version = max(3, protocol_version) for _ in range(numelements): - subelem[itemlen_t](buf, &elem_buf, &offset, dummy_version) + subelem(buf, &elem_buf, &offset) result.append(from_binary(deserializer, &elem_buf, protocol_version)) return result cdef inline int subelem( - Buffer *buf, Buffer *elem_buf, int* offset, itemlen_t dummy) except -1: + Buffer *buf, Buffer *elem_buf, int* offset) except -1: """ Read the next element from the buffer: first read the size (in bytes) of the element, then fill elem_buf with a newly sliced buffer of this size (and the right offset). """ - cdef itemlen_t elemlen + cdef int32_t elemlen - _unpack_len[itemlen_t](buf, offset[0], &elemlen) - offset[0] += sizeof(itemlen_t) + _unpack_len(buf, offset[0], &elemlen) + offset[0] += sizeof(int32_t) slice_buffer(buf, elem_buf, offset[0], elemlen) offset[0] += elemlen return 0 -cdef int _unpack_len(Buffer *buf, int offset, itemlen_t *output) except -1: +cdef int _unpack_len(Buffer *buf, int offset, int32_t *output) except -1: cdef Buffer itemlen_buf - slice_buffer(buf, &itemlen_buf, offset, sizeof(itemlen_t)) + slice_buffer(buf, &itemlen_buf, offset, sizeof(int32_t)) - if itemlen_t is uint16_t: - output[0] = unpack_num[uint16_t](&itemlen_buf) - else: - output[0] = unpack_num[int32_t](&itemlen_buf) + output[0] = unpack_num[int32_t](&itemlen_buf) return 0 @@ -295,42 +278,33 @@ cdef class DesMapType(_DesParameterizedType): self.val_deserializer = self.deserializers[1] cdef deserialize(self, Buffer *buf, int protocol_version): - cdef uint16_t v2_and_below = 0 - cdef int32_t v3_and_above = 0 key_type, val_type = self.cqltype.subtypes - if protocol_version >= 3: - result = _deserialize_map[int32_t]( - v3_and_above, buf, protocol_version, - self.key_deserializer, self.val_deserializer, - key_type, val_type) - else: - result = _deserialize_map[uint16_t]( - v2_and_below, buf, protocol_version, - self.key_deserializer, self.val_deserializer, - key_type, val_type) + result = _deserialize_map( + buf, protocol_version, + self.key_deserializer, self.val_deserializer, + key_type, val_type) return result -cdef _deserialize_map(itemlen_t dummy_version, - Buffer *buf, int protocol_version, +cdef _deserialize_map(Buffer *buf, int protocol_version, Deserializer key_deserializer, Deserializer val_deserializer, key_type, val_type): cdef Buffer key_buf, val_buf cdef Buffer itemlen_buf - cdef itemlen_t numelements + cdef int32_t numelements cdef int offset cdef list result = [] - _unpack_len[itemlen_t](buf, 0, &numelements) - offset = sizeof(itemlen_t) + _unpack_len(buf, 0, &numelements) + offset = sizeof(int32_t) themap = util.OrderedMapSerializedKey(key_type, protocol_version) protocol_version = max(3, protocol_version) for _ in range(numelements): - subelem[itemlen_t](buf, &key_buf, &offset, dummy_version) - subelem[itemlen_t](buf, &val_buf, &offset, numelements) + subelem(buf, &key_buf, &offset) + subelem(buf, &val_buf, &offset) key = from_binary(key_deserializer, &key_buf, protocol_version) val = from_binary(val_deserializer, &val_buf, protocol_version) themap._insert_unchecked(key, to_bytes(&key_buf), val) diff --git a/cassandra/metadata.py b/cassandra/metadata.py index 6379de069a..b85308449e 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -139,8 +139,9 @@ def export_schema_as_string(self): def refresh(self, connection, timeout, target_type=None, change_type=None, fetch_size=None, metadata_request_timeout=None, **kwargs): - server_version = self.get_host(connection.original_endpoint).release_version - dse_version = self.get_host(connection.original_endpoint).dse_version + host = self.get_host(connection.original_endpoint) + server_version = host.release_version if host else None + dse_version = host.dse_version if host else None parser = get_schema_parser(connection, server_version, dse_version, timeout, metadata_request_timeout, fetch_size) if not target_type: @@ -1892,7 +1893,7 @@ def hash_fn(cls, key): def __init__(self, token): """ `token` is an int or string representing the token. """ - self.value = int(token) + super().__init__(int(token)) class MD5Token(HashToken): @@ -3409,8 +3410,27 @@ def __init__( self.to_clustering_columns = to_clustering_columns +def get_column_from_system_local(connection, column_name: str, timeout, metadata_request_timeout) -> str: + success, local_result = connection.wait_for_response( + QueryMessage( + query=maybe_add_timeout_to_query( + "SELECT " + column_name + " FROM system.local WHERE key='local'", + metadata_request_timeout), + consistency_level=ConsistencyLevel.ONE) + , timeout=timeout, fail_on_error=False) + if not success or not local_result.parsed_rows: + return "" + local_rows = dict_factory(local_result.column_names, local_result.parsed_rows) + local_row = local_rows[0] + return local_row.get(column_name) + + def get_schema_parser(connection, server_version, dse_version, timeout, metadata_request_timeout, fetch_size=None): - version = Version(server_version) + if server_version is None and dse_version is None: + server_version = get_column_from_system_local(connection, "release_version", timeout, metadata_request_timeout) + dse_version = get_column_from_system_local(connection, "dse_version", timeout, metadata_request_timeout) + + version = Version(server_version or "0") if dse_version: v = Version(dse_version) if v >= Version('6.8.0'): @@ -3461,7 +3481,7 @@ def group_keys_by_replica(session, keyspace, table, keys): :class:`~.NO_VALID_REPLICA` Example usage:: - + >>> result = group_keys_by_replica( ... session, "system", "peers", ... (("127.0.0.1", ), ("127.0.0.2", ))) diff --git a/cassandra/policies.py b/cassandra/policies.py index bcfd797706..e742708019 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -245,7 +245,6 @@ def __init__(self, local_dc='', used_hosts_per_remote_dc=0): self.used_hosts_per_remote_dc = used_hosts_per_remote_dc self._dc_live_hosts = {} self._position = 0 - self._endpoints = [] LoadBalancingPolicy.__init__(self) def _dc(self, host): @@ -255,11 +254,6 @@ def populate(self, cluster, hosts): for dc, dc_hosts in groupby(hosts, lambda h: self._dc(h)): self._dc_live_hosts[dc] = tuple({*dc_hosts, *self._dc_live_hosts.get(dc, [])}) - if not self.local_dc: - self._endpoints = [ - endpoint - for endpoint in cluster.endpoints_resolved] - self._position = randint(0, len(hosts) - 1) if hosts else 0 def distance(self, host): @@ -301,13 +295,11 @@ def on_up(self, host): # not worrying about threads because this will happen during # control connection startup/refresh if not self.local_dc and host.datacenter: - if host.endpoint in self._endpoints: - self.local_dc = host.datacenter - log.info("Using datacenter '%s' for DCAwareRoundRobinPolicy (via host '%s'); " - "if incorrect, please specify a local_dc to the constructor, " - "or limit contact points to local cluster nodes" % - (self.local_dc, host.endpoint)) - del self._endpoints + self.local_dc = host.datacenter + log.info("Using datacenter '%s' for DCAwareRoundRobinPolicy (via host '%s'); " + "if incorrect, please specify a local_dc to the constructor, " + "or limit contact points to local cluster nodes" % + (self.local_dc, host.endpoint)) dc = self._dc(host) with self._hosts_lock: @@ -511,15 +503,14 @@ def make_query_plan(self, working_keyspace=None, query=None): return replicas = [] - if self._cluster_metadata._tablets.table_has_tablets(keyspace, query.table): - tablet = self._cluster_metadata._tablets.get_tablet_for_key( + tablet = self._cluster_metadata._tablets.get_tablet_for_key( keyspace, query.table, self._cluster_metadata.token_map.token_class.from_key(query.routing_key)) - if tablet is not None: - replicas_mapped = set(map(lambda r: r[0], tablet.replicas)) - child_plan = child.make_query_plan(keyspace, query) + if tablet is not None: + replicas_mapped = set(map(lambda r: r[0], tablet.replicas)) + child_plan = child.make_query_plan(keyspace, query) - replicas = [host for host in child_plan if host.host_id in replicas_mapped] + replicas = [host for host in child_plan if host.host_id in replicas_mapped] else: replicas = self._cluster_metadata.get_replicas(keyspace, query.routing_key) diff --git a/cassandra/pool.py b/cassandra/pool.py index b8a8ef7493..2da657256f 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -176,7 +176,7 @@ def __init__(self, endpoint, conviction_policy_factory, datacenter=None, rack=No self.endpoint = endpoint if isinstance(endpoint, EndPoint) else DefaultEndPoint(endpoint) self.conviction_policy = conviction_policy_factory(self) if not host_id: - host_id = uuid.uuid4() + raise ValueError("host_id may not be None") self.host_id = host_id self.set_location_info(datacenter, rack) self.lock = RLock() diff --git a/cassandra/protocol.py b/cassandra/protocol.py index e574965de8..f37633a756 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -1085,20 +1085,10 @@ def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta :param compressor: optional compression function to be used on the body """ flags = 0 - body = io.BytesIO() if msg.custom_payload: if protocol_version < 4: raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher") flags |= CUSTOM_PAYLOAD_FLAG - write_bytesmap(body, msg.custom_payload) - msg.send_body(body, protocol_version) - body = body.getvalue() - - # With checksumming, the compression is done at the segment frame encoding - if (not ProtocolVersion.has_checksumming_support(protocol_version) - and compressor and len(body) > 0): - body = compressor(body) - flags |= COMPRESSED_FLAG if msg.tracing: flags |= TRACING_FLAG @@ -1107,9 +1097,31 @@ def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta flags |= USE_BETA_FLAG buff = io.BytesIO() - cls._write_header(buff, protocol_version, flags, stream_id, msg.opcode, len(body)) - buff.write(body) + buff.seek(9) + + # With checksumming, the compression is done at the segment frame encoding + if (compressor and not ProtocolVersion.has_checksumming_support(protocol_version)): + body = io.BytesIO() + if msg.custom_payload: + write_bytesmap(body, msg.custom_payload) + msg.send_body(body, protocol_version) + body = body.getvalue() + + if len(body) > 0: + body = compressor(body) + flags |= COMPRESSED_FLAG + + buff.write(body) + length = len(body) + else: + if msg.custom_payload: + write_bytesmap(buff, msg.custom_payload) + msg.send_body(buff, protocol_version) + + length = buff.tell() - 9 + buff.seek(0) + cls._write_header(buff, protocol_version, flags, stream_id, msg.opcode, length) return buff.getvalue() @staticmethod diff --git a/docs/api/cassandra/protocol.rst b/docs/api/cassandra/protocol.rst index 258c6baeb6..8b8f303574 100644 --- a/docs/api/cassandra/protocol.rst +++ b/docs/api/cassandra/protocol.rst @@ -16,7 +16,7 @@ holding custom key/value pairs. By default these are ignored by the server. They can be useful for servers implementing a custom QueryHandler. -See :meth:`.Session.execute`, ::meth:`.Session.execute_async`, :attr:`.ResponseFuture.custom_payload`. +See :meth:`.Session.execute`, :meth:`.Session.execute_async`, :attr:`.ResponseFuture.custom_payload`. .. autoclass:: _ProtocolHandler @@ -53,5 +53,5 @@ These protocol handlers comprise different parsers, and return results as descri - LazyProtocolHandler: near drop-in replacement for the above, except that it returns an iterator over rows, lazily decoded into the default row format (this is more efficient since all decoded results are not materialized at once) -- NumpyProtocolHander: deserializes results directly into NumPy arrays. This facilitates efficient integration with +- NumpyProtocolHandler: deserializes results directly into NumPy arrays. This facilitates efficient integration with analysis toolkits such as Pandas. diff --git a/docs/pyproject.toml b/docs/pyproject.toml index 49a4d000ae..f6ee417aee 100644 --- a/docs/pyproject.toml +++ b/docs/pyproject.toml @@ -25,14 +25,14 @@ dependencies = [ [dependency-groups] # Add any dev-only tools here; example shown -dev = ["hatchling==1.27.0"] +dev = ["hatchling==1.28.0"] [tool.uv.sources] # Keep the driver editable from the parent directory scylla-driver = { path = "../", editable = true } [build-system] -requires = ["hatchling==1.27.0"] +requires = ["hatchling==1.28.0"] build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] diff --git a/docs/uv.lock b/docs/uv.lock index 4d40698dcc..23468c7170 100644 --- a/docs/uv.lock +++ b/docs/uv.lock @@ -350,7 +350,7 @@ wheels = [ [[package]] name = "hatchling" -version = "1.27.0" +version = "1.28.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "packaging" }, @@ -358,9 +358,9 @@ dependencies = [ { name = "pluggy" }, { name = "trove-classifiers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8f/8a/cc1debe3514da292094f1c3a700e4ca25442489731ef7c0814358816bb03/hatchling-1.27.0.tar.gz", hash = "sha256:971c296d9819abb3811112fc52c7a9751c8d381898f36533bb16f9791e941fd6", size = 54983, upload-time = "2024-12-15T17:08:11.894Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0b/8e/e480359492affde4119a131da729dd26da742c2c9b604dff74836e47eef9/hatchling-1.28.0.tar.gz", hash = "sha256:4d50b02aece6892b8cd0b3ce6c82cb218594d3ec5836dbde75bf41a21ab004c8", size = 55365, upload-time = "2025-11-27T00:31:13.766Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/08/e7/ae38d7a6dfba0533684e0b2136817d667588ae3ec984c1a4e5df5eb88482/hatchling-1.27.0-py3-none-any.whl", hash = "sha256:d3a2f3567c4f926ea39849cdf924c7e99e6686c9c8e288ae1037c8fa2a5d937b", size = 75794, upload-time = "2024-12-15T17:08:10.364Z" }, + { url = "https://files.pythonhosted.org/packages/0d/a5/48cb7efb8b4718b1a4c0c331e3364a3a33f614ff0d6afd2b93ee883d3c47/hatchling-1.28.0-py3-none-any.whl", hash = "sha256:dc48722b68b3f4bbfa3ff618ca07cdea6750e7d03481289ffa8be1521d18a961", size = 76075, upload-time = "2025-11-27T00:31:12.544Z" }, ] [[package]] @@ -665,7 +665,7 @@ requires-dist = [ ] [package.metadata.requires-dev] -dev = [{ name = "hatchling", specifier = "==1.27.0" }] +dev = [{ name = "hatchling", specifier = "==1.28.0" }] [[package]] name = "pyyaml" diff --git a/pyproject.toml b/pyproject.toml index f0dffa23c9..7fa90dffa1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ requires-python = ">=3.9" [project.optional-dependencies] graph = ['gremlinpython==3.7.4'] -cle = ['cryptography>=35.0'] +cle = ['cryptography>=42.0.4'] compress-lz4 = ['lz4'] compress-snappy = ['python-snappy'] @@ -51,6 +51,7 @@ dev = [ "futurist", "asynctest", "pyyaml", + "cryptography>=42.0.4", "ccm @ git+https://git@github.com/scylladb/scylla-ccm.git@master", ] diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py index d7f89ad598..1208edb9d2 100644 --- a/tests/integration/standard/test_cluster.py +++ b/tests/integration/standard/test_cluster.py @@ -900,8 +900,9 @@ def test_profile_lb_swap(self): """ Tests that profile load balancing policies are not shared - Creates two LBP, runs a few queries, and validates that each LBP is execised - seperately between EP's + Creates two LBP, runs a few queries, and validates that each LBP is exercised + separately between EP's. Each RoundRobinPolicy starts from its own random + position and maintains independent round-robin ordering. @since 3.5 @jira_ticket PYTHON-569 @@ -916,17 +917,28 @@ def test_profile_lb_swap(self): with TestCluster(execution_profiles=exec_profiles) as cluster: session = cluster.connect(wait_for_all_pools=True) - # default is DCA RR for all hosts expected_hosts = set(cluster.metadata.all_hosts()) - rr1_queried_hosts = set() - rr2_queried_hosts = set() - - rs = session.execute(query, execution_profile='rr1') - rr1_queried_hosts.add(rs.response_future._current_host) - rs = session.execute(query, execution_profile='rr2') - rr2_queried_hosts.add(rs.response_future._current_host) - - assert rr2_queried_hosts == rr1_queried_hosts + num_hosts = len(expected_hosts) + assert num_hosts > 1, "Need at least 2 hosts for this test" + + rr1_queried_hosts = [] + rr2_queried_hosts = [] + + for _ in range(num_hosts * 2): + rs = session.execute(query, execution_profile='rr1') + rr1_queried_hosts.append(rs.response_future._current_host) + rs = session.execute(query, execution_profile='rr2') + rr2_queried_hosts.append(rs.response_future._current_host) + + # Both policies should have queried all hosts + assert set(rr1_queried_hosts) == expected_hosts + assert set(rr2_queried_hosts) == expected_hosts + + # The order of hosts should demonstrate round-robin behavior + # After num_hosts queries, the pattern should repeat + for i in range(num_hosts): + assert rr1_queried_hosts[i] == rr1_queried_hosts[i + num_hosts] + assert rr2_queried_hosts[i] == rr2_queried_hosts[i + num_hosts] def test_ta_lbp(self): """ diff --git a/tests/integration/standard/test_control_connection.py b/tests/integration/standard/test_control_connection.py index 206945f0b3..cb7820f0a6 100644 --- a/tests/integration/standard/test_control_connection.py +++ b/tests/integration/standard/test_control_connection.py @@ -90,19 +90,23 @@ def test_get_control_connection_host(self): """ host = self.cluster.get_control_connection_host() - assert host == None + assert host is None self.session = self.cluster.connect() cc_host = self.cluster.control_connection._connection.host host = self.cluster.get_control_connection_host() assert host.address == cc_host - assert host.is_up == True + assert host.is_up # reconnect and make sure that the new host is reflected correctly self.cluster.control_connection._reconnect() - new_host = self.cluster.get_control_connection_host() - assert host != new_host + new_host1 = self.cluster.get_control_connection_host() + + self.cluster.control_connection._reconnect() + new_host2 = self.cluster.get_control_connection_host() + + assert new_host1 != new_host2 # TODO: enable after https://github.com/scylladb/python-driver/issues/121 is fixed @unittest.skip('Fails on scylla due to the broadcast_rpc_port is None') @@ -117,16 +121,16 @@ def test_control_connection_port_discovery(self): self.cluster = TestCluster() host = self.cluster.get_control_connection_host() - assert host == None + assert host is None self.session = self.cluster.connect() cc_endpoint = self.cluster.control_connection._connection.endpoint host = self.cluster.get_control_connection_host() assert host.endpoint == cc_endpoint - assert host.is_up == True + assert host.is_up hosts = self.cluster.metadata.all_hosts() - assert 3 == len(hosts) + assert len(hosts) == 3 for host in hosts: assert 9042 == host.broadcast_rpc_port diff --git a/tests/integration/standard/test_metrics.py b/tests/integration/standard/test_metrics.py index 8ccd278ee4..48c7b49b95 100644 --- a/tests/integration/standard/test_metrics.py +++ b/tests/integration/standard/test_metrics.py @@ -218,7 +218,7 @@ def test_metrics_per_cluster(self): try: # Test write query = SimpleStatement("INSERT INTO {0}.{0} (k, v) VALUES (2, 2)".format(self.ks_name), consistency_level=ConsistencyLevel.ALL) - with pytest.raises(WriteTimeout): + with pytest.raises((WriteTimeout, Unavailable)): self.session.execute(query, timeout=None) finally: get_node(1).resume() @@ -230,7 +230,7 @@ def test_metrics_per_cluster(self): stats_cluster2 = cluster2.metrics.get_stats() # Test direct access to stats - assert 1 == self.cluster.metrics.stats.write_timeouts + assert (1 == self.cluster.metrics.stats.write_timeouts or 1 == self.cluster.metrics.stats.unavailables) assert 0 == cluster2.metrics.stats.write_timeouts # Test direct access to a child stats diff --git a/tests/integration/standard/test_policies.py b/tests/integration/standard/test_policies.py index 0c84fd06be..2de12f7b7f 100644 --- a/tests/integration/standard/test_policies.py +++ b/tests/integration/standard/test_policies.py @@ -45,9 +45,6 @@ def test_predicate_changes(self): external_event = True contact_point = DefaultEndPoint("127.0.0.1") - single_host = {Host(contact_point, SimpleConvictionPolicy)} - all_hosts = {Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy) for i in (1, 2, 3)} - predicate = lambda host: host.endpoint == contact_point if external_event else True hfp = ExecutionProfile( load_balancing_policy=HostFilterPolicy(RoundRobinPolicy(), predicate=predicate) @@ -62,7 +59,8 @@ def test_predicate_changes(self): response = session.execute("SELECT * from system.local WHERE key='local'") queried_hosts.update(response.response_future.attempted_hosts) - assert queried_hosts == single_host + assert len(queried_hosts) == 1 + assert queried_hosts.pop().endpoint == contact_point external_event = False futures = session.update_created_pools() @@ -72,7 +70,8 @@ def test_predicate_changes(self): for _ in range(10): response = session.execute("SELECT * from system.local WHERE key='local'") queried_hosts.update(response.response_future.attempted_hosts) - assert queried_hosts == all_hosts + assert len(queried_hosts) == 3 + assert {host.endpoint for host in queried_hosts} == {DefaultEndPoint(f"127.0.0.{i}") for i in range(1, 4)} class WhiteListRoundRobinPolicyTests(unittest.TestCase): diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index a3bdf8a735..9cebc22b05 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -460,7 +460,8 @@ def make_query_plan(self, working_keyspace=None, query=None): live_hosts = sorted(list(self._live_hosts)) host = [] try: - host = [live_hosts[self.host_index_to_use]] + if len(live_hosts) > 0: + host = [live_hosts[self.host_index_to_use]] except IndexError as e: raise IndexError( 'You specified an index larger than the number of hosts. Total hosts: {}. Index specified: {}'.format( diff --git a/tests/unit/advanced/test_policies.py b/tests/unit/advanced/test_policies.py index 8e421a859d..75cfd3fbf9 100644 --- a/tests/unit/advanced/test_policies.py +++ b/tests/unit/advanced/test_policies.py @@ -13,6 +13,7 @@ # limitations under the License. import unittest from unittest.mock import Mock +import uuid from cassandra.pool import Host from cassandra.policies import RoundRobinPolicy @@ -72,7 +73,7 @@ def test_target_no_host(self): def test_target_host_down(self): node_count = 4 - hosts = [Host(i, Mock()) for i in range(node_count)] + hosts = [Host(i, Mock(), host_id=uuid.uuid4()) for i in range(node_count)] target_host = hosts[1] policy = DSELoadBalancingPolicy(RoundRobinPolicy()) @@ -87,7 +88,7 @@ def test_target_host_down(self): def test_target_host_nominal(self): node_count = 4 - hosts = [Host(i, Mock()) for i in range(node_count)] + hosts = [Host(i, Mock(), host_id=uuid.uuid4()) for i in range(node_count)] target_host = hosts[1] target_host.is_up = True diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index f3efed9f54..49208ac53e 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -17,6 +17,7 @@ import socket from unittest.mock import patch, Mock +import uuid from cassandra import ConsistencyLevel, DriverException, Timeout, Unavailable, RequestExecutionException, ReadTimeout, WriteTimeout, CoordinationFailure, ReadFailure, WriteFailure, FunctionFailure, AlreadyExists,\ InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion @@ -200,7 +201,7 @@ def test_default_serial_consistency_level_ep(self, *_): PR #510 """ c = Cluster(protocol_version=4) - s = Session(c, [Host("127.0.0.1", SimpleConvictionPolicy)]) + s = Session(c, [Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) c.connection_class.initialize_reactor() # default is None @@ -229,7 +230,7 @@ def test_default_serial_consistency_level_legacy(self, *_): PR #510 """ c = Cluster(protocol_version=4) - s = Session(c, [Host("127.0.0.1", SimpleConvictionPolicy)]) + s = Session(c, [Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) c.connection_class.initialize_reactor() # default is None assert s.default_serial_consistency_level is None @@ -286,7 +287,7 @@ def test_default_exec_parameters(self): assert cluster.profile_manager.default.load_balancing_policy.__class__ == default_lbp_factory().__class__ assert cluster.default_retry_policy.__class__ == RetryPolicy assert cluster.profile_manager.default.retry_policy.__class__ == RetryPolicy - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) assert session.default_timeout == 10.0 assert cluster.profile_manager.default.request_timeout == 10.0 assert session.default_consistency_level == ConsistencyLevel.LOCAL_ONE @@ -300,7 +301,7 @@ def test_default_exec_parameters(self): def test_default_legacy(self): cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), default_retry_policy=DowngradingConsistencyRetryPolicy()) assert cluster._config_mode == _ConfigMode.LEGACY - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) session.default_timeout = 3.7 session.default_consistency_level = ConsistencyLevel.ALL session.default_serial_consistency_level = ConsistencyLevel.SERIAL @@ -314,7 +315,7 @@ def test_default_legacy(self): def test_default_profile(self): non_default_profile = ExecutionProfile(RoundRobinPolicy(), *[object() for _ in range(2)]) cluster = Cluster(execution_profiles={'non-default': non_default_profile}) - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) assert cluster._config_mode == _ConfigMode.PROFILES @@ -347,7 +348,7 @@ def test_serial_consistency_level_validation(self): def test_statement_params_override_legacy(self): cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), default_retry_policy=DowngradingConsistencyRetryPolicy()) assert cluster._config_mode == _ConfigMode.LEGACY - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) ss = SimpleStatement("query", retry_policy=DowngradingConsistencyRetryPolicy(), consistency_level=ConsistencyLevel.ALL, serial_consistency_level=ConsistencyLevel.SERIAL) @@ -368,7 +369,7 @@ def test_statement_params_override_legacy(self): def test_statement_params_override_profile(self): non_default_profile = ExecutionProfile(RoundRobinPolicy(), *[object() for _ in range(2)]) cluster = Cluster(execution_profiles={'non-default': non_default_profile}) - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) assert cluster._config_mode == _ConfigMode.PROFILES @@ -406,7 +407,7 @@ def test_no_profile_with_legacy(self): # session settings lock out profiles cluster = Cluster() - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) for attr, value in (('default_timeout', 1), ('default_consistency_level', ConsistencyLevel.ANY), ('default_serial_consistency_level', ConsistencyLevel.SERIAL), @@ -432,7 +433,7 @@ def test_no_legacy_with_profile(self): ('load_balancing_policy', default_lbp_factory())): with pytest.raises(ValueError): setattr(cluster, attr, value) - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) for attr, value in (('default_timeout', 1), ('default_consistency_level', ConsistencyLevel.ANY), ('default_serial_consistency_level', ConsistencyLevel.SERIAL), @@ -445,7 +446,7 @@ def test_profile_name_value(self): internalized_profile = ExecutionProfile(RoundRobinPolicy(), *[object() for _ in range(2)]) cluster = Cluster(execution_profiles={'by-name': internalized_profile}) - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) assert cluster._config_mode == _ConfigMode.PROFILES rf = session.execute_async("query", execution_profile='by-name') @@ -459,7 +460,7 @@ def test_profile_name_value(self): def test_exec_profile_clone(self): cluster = Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(), 'one': ExecutionProfile()}) - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) profile_attrs = {'request_timeout': 1, 'consistency_level': ConsistencyLevel.ANY, diff --git a/tests/unit/test_concurrent.py b/tests/unit/test_concurrent.py index a3587a3e16..9c85b1ccac 100644 --- a/tests/unit/test_concurrent.py +++ b/tests/unit/test_concurrent.py @@ -22,6 +22,7 @@ from queue import PriorityQueue import sys import platform +import uuid from cassandra.cluster import Cluster, Session from cassandra.concurrent import execute_concurrent, execute_concurrent_with_args @@ -248,7 +249,7 @@ def test_recursion_limited(self): PYTHON-585 """ max_recursion = sys.getrecursionlimit() - s = Session(Cluster(), [Host("127.0.0.1", SimpleConvictionPolicy)]) + s = Session(Cluster(), [Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) with pytest.raises(TypeError): execute_concurrent_with_args(s, "doesn't matter", [('param',)] * max_recursion, raise_on_first_error=True) diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index e7b930a990..f92bb53785 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -14,6 +14,7 @@ from concurrent.futures import ThreadPoolExecutor import logging import time +import uuid from cassandra.protocol_features import ProtocolFeatures from cassandra.shard_info import _ShardingInfo @@ -205,20 +206,20 @@ def test_host_instantiations(self): """ with pytest.raises(ValueError): - Host(None, None) + Host(None, None, host_id=uuid.uuid4()) with pytest.raises(ValueError): - Host('127.0.0.1', None) + Host('127.0.0.1', None, host_id=uuid.uuid4()) with pytest.raises(ValueError): - Host(None, SimpleConvictionPolicy) + Host(None, SimpleConvictionPolicy, host_id=uuid.uuid4()) def test_host_equality(self): """ Test host equality has correct logic """ - a = Host('127.0.0.1', SimpleConvictionPolicy) - b = Host('127.0.0.1', SimpleConvictionPolicy) - c = Host('127.0.0.2', SimpleConvictionPolicy) + a = Host('127.0.0.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) + b = Host('127.0.0.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) + c = Host('127.0.0.2', SimpleConvictionPolicy, host_id=uuid.uuid4()) assert a == b, 'Two Host instances should be equal when sharing.' assert a != c, 'Two Host instances should NOT be equal when using two different addresses.' @@ -238,7 +239,8 @@ class MockSession(MagicMock): def __init__(self, *args, **kwargs): super(MockSession, self).__init__(*args, **kwargs) self.cluster = MagicMock() - self.cluster.executor = ThreadPoolExecutor(max_workers=2, initializer=self.executor_init) + self.connection_created = Event() + self.cluster.executor = ThreadPoolExecutor(max_workers=2) self.cluster.signal_connection_failure = lambda *args, **kwargs: False self.cluster.connection_factory = self.mock_connection_factory self.connection_counter = 0 @@ -253,28 +255,35 @@ def mock_connection_factory(self, *args, **kwargs): connection.is_shutdown = False connection.is_defunct = False connection.is_closed = False - connection.features = ProtocolFeatures(shard_id=self.connection_counter, + connection.features = ProtocolFeatures(shard_id=self.connection_counter, sharding_info=_ShardingInfo(shard_id=1, shards_count=14, partitioner="", sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port="", shard_aware_port_ssl="")) self.connection_counter += 1 + self.connection_created.set() return connection - def executor_init(self, *args): - time.sleep(0.5) - LOGGER.info("Future start: %s", args) - - for attempt_num in range(20): - LOGGER.info("Testing fast shutdown %d / 20 times", attempt_num + 1) + for attempt_num in range(3): + LOGGER.info("Testing fast shutdown %d / 3 times", attempt_num + 1) host = MagicMock() host.endpoint = "1.2.3.4" - session = self.make_session() + session = MockSession() pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session) LOGGER.info("Initialized pool %s", pool) + + # Wait for initial connection to be created (with timeout) + if not session.connection_created.wait(timeout=2.0): + pytest.fail("Initial connection failed to be created within 2 seconds") + LOGGER.info("Connections: %s", pool._connections) - time.sleep(0.5) + + # Shutdown the pool pool.shutdown() - time.sleep(3) - session.cluster.executor.shutdown() + + # Verify pool is shut down + assert pool.is_shutdown, "Pool should be marked as shutdown" + + # Cleanup executor with proper wait + session.cluster.executor.shutdown(wait=True) diff --git a/tests/unit/test_metadata.py b/tests/unit/test_metadata.py index 3069f6bced..dcbb840447 100644 --- a/tests/unit/test_metadata.py +++ b/tests/unit/test_metadata.py @@ -18,6 +18,7 @@ from unittest.mock import Mock import os import timeit +import uuid import cassandra from cassandra.cqltypes import strip_frozen @@ -30,9 +31,11 @@ UserType, KeyspaceMetadata, get_schema_parser, _UnknownStrategy, ColumnMetadata, TableMetadata, IndexMetadata, Function, Aggregate, - Metadata, TokenMap, ReplicationFactor) + Metadata, TokenMap, ReplicationFactor, + SchemaParserDSE68) from cassandra.policies import SimpleConvictionPolicy from cassandra.pool import Host +from cassandra.protocol import QueryMessage from tests.util import assertCountEqual import pytest @@ -121,7 +124,7 @@ def test_simple_replication_type_parsing(self): # make token replica map ring = [MD5Token(0), MD5Token(1), MD5Token(2)] - hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy) for host in range(3)] + hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy, host_id=uuid.uuid4()) for host in range(3)] token_to_host = dict(zip(ring, hosts)) assert simple_int.make_token_replica_map(token_to_host, ring) == simple_str.make_token_replica_map(token_to_host, ring) @@ -139,7 +142,7 @@ def test_transient_replication_parsing(self): # make token replica map ring = [MD5Token(0), MD5Token(1), MD5Token(2)] - hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy) for host in range(3)] + hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy, host_id=uuid.uuid4()) for host in range(3)] token_to_host = dict(zip(ring, hosts)) assert simple_transient.make_token_replica_map(token_to_host, ring) == simple_str.make_token_replica_map(token_to_host, ring) @@ -160,7 +163,7 @@ def test_nts_replication_parsing(self): # make token replica map ring = [MD5Token(0), MD5Token(1), MD5Token(2)] - hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy) for host in range(3)] + hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy, host_id=uuid.uuid4()) for host in range(3)] token_to_host = dict(zip(ring, hosts)) assert nts_int.make_token_replica_map(token_to_host, ring) == nts_str.make_token_replica_map(token_to_host, ring) @@ -180,30 +183,30 @@ def test_nts_transient_parsing(self): # make token replica map ring = [MD5Token(0), MD5Token(1), MD5Token(2)] - hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy) for host in range(3)] + hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy, host_id=uuid.uuid4()) for host in range(3)] token_to_host = dict(zip(ring, hosts)) assert nts_transient.make_token_replica_map(token_to_host, ring) == nts_str.make_token_replica_map(token_to_host, ring) def test_nts_make_token_replica_map(self): token_to_host_owner = {} - dc1_1 = Host('dc1.1', SimpleConvictionPolicy) - dc1_2 = Host('dc1.2', SimpleConvictionPolicy) - dc1_3 = Host('dc1.3', SimpleConvictionPolicy) + dc1_1 = Host('dc1.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc1_2 = Host('dc1.2', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc1_3 = Host('dc1.3', SimpleConvictionPolicy, host_id=uuid.uuid4()) for host in (dc1_1, dc1_2, dc1_3): host.set_location_info('dc1', 'rack1') token_to_host_owner[MD5Token(0)] = dc1_1 token_to_host_owner[MD5Token(100)] = dc1_2 token_to_host_owner[MD5Token(200)] = dc1_3 - dc2_1 = Host('dc2.1', SimpleConvictionPolicy) - dc2_2 = Host('dc2.2', SimpleConvictionPolicy) + dc2_1 = Host('dc2.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc2_2 = Host('dc2.2', SimpleConvictionPolicy, host_id=uuid.uuid4()) dc2_1.set_location_info('dc2', 'rack1') dc2_2.set_location_info('dc2', 'rack1') token_to_host_owner[MD5Token(1)] = dc2_1 token_to_host_owner[MD5Token(101)] = dc2_2 - dc3_1 = Host('dc3.1', SimpleConvictionPolicy) + dc3_1 = Host('dc3.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) dc3_1.set_location_info('dc3', 'rack3') token_to_host_owner[MD5Token(2)] = dc3_1 @@ -238,7 +241,7 @@ def test_nts_token_performance(self): vnodes_per_host = 500 for i in range(dc1hostnum): - host = Host('dc1.{0}'.format(i), SimpleConvictionPolicy) + host = Host('dc1.{0}'.format(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) host.set_location_info('dc1', "rack1") for vnode_num in range(vnodes_per_host): md5_token = MD5Token(current_token+vnode_num) @@ -262,10 +265,10 @@ def test_nts_make_token_replica_map_multi_rack(self): token_to_host_owner = {} # (A) not enough distinct racks, first skipped is used - dc1_1 = Host('dc1.1', SimpleConvictionPolicy) - dc1_2 = Host('dc1.2', SimpleConvictionPolicy) - dc1_3 = Host('dc1.3', SimpleConvictionPolicy) - dc1_4 = Host('dc1.4', SimpleConvictionPolicy) + dc1_1 = Host('dc1.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc1_2 = Host('dc1.2', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc1_3 = Host('dc1.3', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc1_4 = Host('dc1.4', SimpleConvictionPolicy, host_id=uuid.uuid4()) dc1_1.set_location_info('dc1', 'rack1') dc1_2.set_location_info('dc1', 'rack1') dc1_3.set_location_info('dc1', 'rack2') @@ -276,9 +279,9 @@ def test_nts_make_token_replica_map_multi_rack(self): token_to_host_owner[MD5Token(300)] = dc1_4 # (B) distinct racks, but not contiguous - dc2_1 = Host('dc2.1', SimpleConvictionPolicy) - dc2_2 = Host('dc2.2', SimpleConvictionPolicy) - dc2_3 = Host('dc2.3', SimpleConvictionPolicy) + dc2_1 = Host('dc2.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc2_2 = Host('dc2.2', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc2_3 = Host('dc2.3', SimpleConvictionPolicy, host_id=uuid.uuid4()) dc2_1.set_location_info('dc2', 'rack1') dc2_2.set_location_info('dc2', 'rack1') dc2_3.set_location_info('dc2', 'rack2') @@ -301,7 +304,7 @@ def test_nts_make_token_replica_map_multi_rack(self): assertCountEqual(token_replicas, (dc1_1, dc1_2, dc1_3, dc2_1, dc2_3)) def test_nts_make_token_replica_map_empty_dc(self): - host = Host('1', SimpleConvictionPolicy) + host = Host('1', SimpleConvictionPolicy, host_id=uuid.uuid4()) host.set_location_info('dc1', 'rack1') token_to_host_owner = {MD5Token(0): host} ring = [MD5Token(0)] @@ -315,9 +318,9 @@ def test_nts_export_for_schema(self): assert "{'class': 'NetworkTopologyStrategy', 'dc1': '1', 'dc2': '2'}" == strategy.export_for_schema() def test_simple_strategy_make_token_replica_map(self): - host1 = Host('1', SimpleConvictionPolicy) - host2 = Host('2', SimpleConvictionPolicy) - host3 = Host('3', SimpleConvictionPolicy) + host1 = Host('1', SimpleConvictionPolicy, host_id=uuid.uuid4()) + host2 = Host('2', SimpleConvictionPolicy, host_id=uuid.uuid4()) + host3 = Host('3', SimpleConvictionPolicy, host_id=uuid.uuid4()) token_to_host_owner = { MD5Token(0): host1, MD5Token(100): host2, @@ -406,7 +409,7 @@ def test_is_valid_name(self): class GetReplicasTest(unittest.TestCase): def _get_replicas(self, token_klass): tokens = [token_klass(i) for i in range(0, (2 ** 127 - 1), 2 ** 125)] - hosts = [Host("ip%d" % i, SimpleConvictionPolicy) for i in range(len(tokens))] + hosts = [Host("ip%d" % i, SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(len(tokens))] token_to_primary_replica = dict(zip(tokens, hosts)) keyspace = KeyspaceMetadata("ks", True, "SimpleStrategy", {"replication_factor": "1"}) metadata = Mock(spec=Metadata, keyspaces={'ks': keyspace}) @@ -616,6 +619,37 @@ def test_build_index_as_cql(self): assert index_meta.as_cql_query() == "CREATE CUSTOM INDEX index_name_here ON keyspace_name_here.table_name_here (column_name_here) USING 'class_name_here'" +class SchemaParserLookupTests(unittest.TestCase): + + def test_reads_versions_from_system_local_when_missing(self): + connection = Mock() + + release_version_resp = Mock() + release_version_resp.column_names = ["release_version"] + release_version_resp.parsed_rows = [["4.0.0"]] + + dse_version_resp = Mock() + dse_version_resp.column_names = ["dse_version"] + dse_version_resp.parsed_rows = [["6.8.0"]] + + def mock_system_local(query, *args, **kwargs): + if not isinstance(query, QueryMessage): + raise RuntimeError("first argument should be a QueryMessage") + if "release_version" in query.query: + return (True, release_version_resp) + if "dse_version" in query.query: + return (True, dse_version_resp) + raise RuntimeError("unexpected query") + + connection.wait_for_response.side_effect = mock_system_local + + parser = get_schema_parser(connection, None, None, 0.1, None) + + assert isinstance(parser, SchemaParserDSE68) + message = connection.wait_for_response.call_args[0][0] + assert "system.local" in message.query + + class UnicodeIdentifiersTests(unittest.TestCase): """ Exercise cql generation with unicode characters. Keyspace, Table, and Index names @@ -784,8 +818,8 @@ def test_iterate_all_hosts_and_modify(self): PYTHON-572 """ metadata = Metadata() - metadata.add_or_return_host(Host('dc1.1', SimpleConvictionPolicy)) - metadata.add_or_return_host(Host('dc1.2', SimpleConvictionPolicy)) + metadata.add_or_return_host(Host('dc1.1', SimpleConvictionPolicy, host_id=uuid.uuid4())) + metadata.add_or_return_host(Host('dc1.2', SimpleConvictionPolicy, host_id=uuid.uuid4())) assert len(metadata.all_hosts()) == 2 diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index e15705c8f7..6142af1aa1 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -17,6 +17,7 @@ from itertools import islice, cycle from unittest.mock import Mock, patch, call from random import randint +import uuid import pytest from _thread import LockType import sys @@ -46,7 +47,7 @@ def test_non_implemented(self): """ policy = LoadBalancingPolicy() - host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy) + host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) host.set_location_info("dc1", "rack1") with pytest.raises(NotImplementedError): @@ -192,11 +193,11 @@ class TestRackOrDCAwareRoundRobinPolicy: def test_no_remote(self, policy_specialization, constructor_args): hosts = [] for i in range(2): - h = Host(DefaultEndPoint(i), SimpleConvictionPolicy) + h = Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) h.set_location_info("dc1", "rack2") hosts.append(h) for i in range(2): - h = Host(DefaultEndPoint(i + 2), SimpleConvictionPolicy) + h = Host(DefaultEndPoint(i + 2), SimpleConvictionPolicy, host_id=uuid.uuid4()) h.set_location_info("dc1", "rack1") hosts.append(h) @@ -208,7 +209,7 @@ def test_no_remote(self, policy_specialization, constructor_args): assert sorted(qplan) == sorted(hosts) def test_with_remotes(self, policy_specialization, constructor_args): - hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(6)] + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(6)] for h in hosts[:2]: h.set_location_info("dc1", "rack1") for h in hosts[2:4]: @@ -263,7 +264,7 @@ def test_get_distance(self, policy_specialization, constructor_args): policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=0) # same dc, same rack - host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy) + host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) host.set_location_info("dc1", "rack1") policy.populate(Mock(), [host]) @@ -273,14 +274,14 @@ def test_get_distance(self, policy_specialization, constructor_args): assert policy.distance(host) == HostDistance.LOCAL_RACK # same dc different rack - host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy) + host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) host.set_location_info("dc1", "rack2") policy.populate(Mock(), [host]) assert policy.distance(host) == HostDistance.LOCAL # used_hosts_per_remote_dc is set to 0, so ignore it - remote_host = Host(DefaultEndPoint("ip2"), SimpleConvictionPolicy) + remote_host = Host(DefaultEndPoint("ip2"), SimpleConvictionPolicy, host_id=uuid.uuid4()) remote_host.set_location_info("dc2", "rack1") assert policy.distance(remote_host) == HostDistance.IGNORED @@ -294,14 +295,14 @@ def test_get_distance(self, policy_specialization, constructor_args): # since used_hosts_per_remote_dc is set to 1, only the first # remote host in dc2 will be REMOTE, the rest are IGNORED - second_remote_host = Host(DefaultEndPoint("ip3"), SimpleConvictionPolicy) + second_remote_host = Host(DefaultEndPoint("ip3"), SimpleConvictionPolicy, host_id=uuid.uuid4()) second_remote_host.set_location_info("dc2", "rack1") policy.populate(Mock(), [host, remote_host, second_remote_host]) distances = set([policy.distance(remote_host), policy.distance(second_remote_host)]) assert distances == set([HostDistance.REMOTE, HostDistance.IGNORED]) def test_status_updates(self, policy_specialization, constructor_args): - hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(5)] + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(5)] for h in hosts[:2]: h.set_location_info("dc1", "rack1") for h in hosts[2:4]: @@ -314,11 +315,11 @@ def test_status_updates(self, policy_specialization, constructor_args): policy.on_down(hosts[0]) policy.on_remove(hosts[2]) - new_local_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy) + new_local_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy, host_id=uuid.uuid4()) new_local_host.set_location_info("dc1", "rack1") policy.on_up(new_local_host) - new_remote_host = Host(DefaultEndPoint(6), SimpleConvictionPolicy) + new_remote_host = Host(DefaultEndPoint(6), SimpleConvictionPolicy, host_id=uuid.uuid4()) new_remote_host.set_location_info("dc9000", "rack1") policy.on_add(new_remote_host) @@ -343,7 +344,7 @@ def test_status_updates(self, policy_specialization, constructor_args): assert qplan == [] def test_modification_during_generation(self, policy_specialization, constructor_args): - hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for h in hosts[:2]: h.set_location_info("dc1", "rack1") for h in hosts[2:]: @@ -357,7 +358,7 @@ def test_modification_during_generation(self, policy_specialization, constructor # approach that changes specific things during known phases of the # generator. - new_host = Host(DefaultEndPoint(4), SimpleConvictionPolicy) + new_host = Host(DefaultEndPoint(4), SimpleConvictionPolicy, host_id=uuid.uuid4()) new_host.set_location_info("dc1", "rack1") # new local before iteration @@ -468,7 +469,7 @@ def test_modification_during_generation(self, policy_specialization, constructor policy.on_up(hosts[2]) policy.on_up(hosts[3]) - another_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy) + another_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy, host_id=uuid.uuid4()) another_host.set_location_info("dc3", "rack1") new_host.set_location_info("dc3", "rack1") @@ -502,7 +503,7 @@ def test_no_live_nodes(self, policy_specialization, constructor_args): hosts = [] for i in range(4): - h = Host(DefaultEndPoint(i), SimpleConvictionPolicy) + h = Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) h.set_location_info("dc1", "rack1") hosts.append(h) @@ -527,7 +528,7 @@ def test_no_nodes(self, policy_specialization, constructor_args): assert qplan == [] def test_wrong_dc(self, policy_specialization, constructor_args): - hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(3)] + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(3)] for h in hosts[:3]: h.set_location_info("dc2", "rack2") @@ -539,9 +540,9 @@ def test_wrong_dc(self, policy_specialization, constructor_args): class DCAwareRoundRobinPolicyTest(unittest.TestCase): def test_default_dc(self): - host_local = Host(DefaultEndPoint(1), SimpleConvictionPolicy, 'local') - host_remote = Host(DefaultEndPoint(2), SimpleConvictionPolicy, 'remote') - host_none = Host(DefaultEndPoint(1), SimpleConvictionPolicy) + host_local = Host(DefaultEndPoint(1), SimpleConvictionPolicy, 'local', host_id=uuid.uuid4()) + host_remote = Host(DefaultEndPoint(2), SimpleConvictionPolicy, 'remote', host_id=uuid.uuid4()) + host_none = Host(DefaultEndPoint(1), SimpleConvictionPolicy, host_id=uuid.uuid4()) # contact point is '1' cluster = Mock(endpoints_resolved=[DefaultEndPoint(1)]) @@ -555,15 +556,6 @@ def test_default_dc(self): assert policy.local_dc != host_remote.datacenter assert policy.local_dc == host_local.datacenter - # contact DC second - policy = DCAwareRoundRobinPolicy() - policy.populate(cluster, [host_none]) - assert not policy.local_dc - policy.on_add(host_remote) - policy.on_add(host_local) - assert policy.local_dc != host_remote.datacenter - assert policy.local_dc == host_local.datacenter - # no DC policy = DCAwareRoundRobinPolicy() policy.populate(cluster, [host_none]) @@ -576,7 +568,7 @@ def test_default_dc(self): policy.populate(cluster, [host_none]) assert not policy.local_dc policy.on_add(host_remote) - assert not policy.local_dc + assert policy.local_dc class TokenAwarePolicyTest(unittest.TestCase): @@ -584,8 +576,8 @@ def test_wrap_round_robin(self): cluster = Mock(spec=Cluster) cluster.metadata = Mock(spec=Metadata) cluster.metadata._tablets = Mock(spec=Tablets) - cluster.metadata._tablets.table_has_tablets.return_value = [] - hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] + cluster.metadata._tablets.get_tablet_for_key.return_value = None + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for host in hosts: host.set_up() @@ -617,8 +609,8 @@ def test_wrap_dc_aware(self): cluster = Mock(spec=Cluster) cluster.metadata = Mock(spec=Metadata) cluster.metadata._tablets = Mock(spec=Tablets) - cluster.metadata._tablets.table_has_tablets.return_value = [] - hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] + cluster.metadata._tablets.get_tablet_for_key.return_value = None + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for host in hosts: host.set_up() for h in hosts[:2]: @@ -666,8 +658,8 @@ def test_wrap_rack_aware(self): cluster = Mock(spec=Cluster) cluster.metadata = Mock(spec=Metadata) cluster.metadata._tablets = Mock(spec=Tablets) - cluster.metadata._tablets.table_has_tablets.return_value = [] - hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(8)] + cluster.metadata._tablets.get_tablet_for_key.return_value = None + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(8)] for host in hosts: host.set_up() hosts[0].set_location_info("dc1", "rack1") @@ -731,7 +723,7 @@ def test_get_distance(self): """ policy = TokenAwarePolicy(DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=0)) - host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy) + host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) host.set_location_info("dc1", "rack1") policy.populate(self.FakeCluster(), [host]) @@ -739,7 +731,7 @@ def test_get_distance(self): assert policy.distance(host) == HostDistance.LOCAL # used_hosts_per_remote_dc is set to 0, so ignore it - remote_host = Host(DefaultEndPoint("ip2"), SimpleConvictionPolicy) + remote_host = Host(DefaultEndPoint("ip2"), SimpleConvictionPolicy, host_id=uuid.uuid4()) remote_host.set_location_info("dc2", "rack1") assert policy.distance(remote_host) == HostDistance.IGNORED @@ -753,7 +745,7 @@ def test_get_distance(self): # since used_hosts_per_remote_dc is set to 1, only the first # remote host in dc2 will be REMOTE, the rest are IGNORED - second_remote_host = Host(DefaultEndPoint("ip3"), SimpleConvictionPolicy) + second_remote_host = Host(DefaultEndPoint("ip3"), SimpleConvictionPolicy, host_id=uuid.uuid4()) second_remote_host.set_location_info("dc2", "rack1") policy.populate(self.FakeCluster(), [host, remote_host, second_remote_host]) distances = set([policy.distance(remote_host), policy.distance(second_remote_host)]) @@ -764,7 +756,7 @@ def test_status_updates(self): Same test as DCAwareRoundRobinPolicyTest.test_status_updates() """ - hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for h in hosts[:2]: h.set_location_info("dc1", "rack1") for h in hosts[2:]: @@ -775,11 +767,11 @@ def test_status_updates(self): policy.on_down(hosts[0]) policy.on_remove(hosts[2]) - new_local_host = Host(DefaultEndPoint(4), SimpleConvictionPolicy) + new_local_host = Host(DefaultEndPoint(4), SimpleConvictionPolicy, host_id=uuid.uuid4()) new_local_host.set_location_info("dc1", "rack1") policy.on_up(new_local_host) - new_remote_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy) + new_remote_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy, host_id=uuid.uuid4()) new_remote_host.set_location_info("dc9000", "rack1") policy.on_add(new_remote_host) @@ -802,7 +794,7 @@ def test_status_updates(self): assert qplan == [] def test_statement_keyspace(self): - hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for host in hosts: host.set_up() @@ -811,7 +803,7 @@ def test_statement_keyspace(self): cluster.metadata._tablets = Mock(spec=Tablets) replicas = hosts[2:] cluster.metadata.get_replicas.return_value = replicas - cluster.metadata._tablets.table_has_tablets.return_value = [] + cluster.metadata._tablets.get_tablet_for_key.return_value = None child_policy = Mock() child_policy.make_query_plan.return_value = hosts @@ -896,7 +888,7 @@ def test_no_shuffle_if_given_no_routing_key(self): self._assert_shuffle(cluster=self._prepare_cluster_with_tablets(), keyspace='keyspace', routing_key=None) def _prepare_cluster_with_vnodes(self): - hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for host in hosts: host.set_up() cluster = Mock(spec=Cluster) @@ -904,11 +896,11 @@ def _prepare_cluster_with_vnodes(self): cluster.metadata._tablets = Mock(spec=Tablets) cluster.metadata.all_hosts.return_value = hosts cluster.metadata.get_replicas.return_value = hosts[2:] - cluster.metadata._tablets.table_has_tablets.return_value = False + cluster.metadata._tablets.get_tablet_for_key.return_value = None return cluster def _prepare_cluster_with_tablets(self): - hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for host in hosts: host.set_up() cluster = Mock(spec=Cluster) @@ -916,7 +908,6 @@ def _prepare_cluster_with_tablets(self): cluster.metadata._tablets = Mock(spec=Tablets) cluster.metadata.all_hosts.return_value = hosts cluster.metadata.get_replicas.return_value = hosts[2:] - cluster.metadata._tablets.table_has_tablets.return_value = True cluster.metadata._tablets.get_tablet_for_key.return_value = Tablet(replicas=[(h.host_id, 0) for h in hosts[2:]]) return cluster @@ -931,7 +922,7 @@ def _assert_shuffle(self, patched_shuffle, cluster, keyspace, routing_key): policy = TokenAwarePolicy(child_policy, shuffle_replicas=True) policy.populate(cluster, hosts) - is_tablets = cluster.metadata._tablets.table_has_tablets() + is_tablets = cluster.metadata._tablets.get_tablet_for_key() is not None cluster.metadata.get_replicas.reset_mock() child_policy.make_query_plan.reset_mock() @@ -1422,7 +1413,7 @@ class WhiteListRoundRobinPolicyTest(unittest.TestCase): def test_hosts_with_hostname(self): hosts = ['localhost'] policy = WhiteListRoundRobinPolicy(hosts) - host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy) + host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) policy.populate(None, [host]) qplan = list(policy.make_query_plan()) @@ -1433,7 +1424,7 @@ def test_hosts_with_hostname(self): def test_hosts_with_socket_hostname(self): hosts = [UnixSocketEndPoint('/tmp/scylla-workdir/cql.m')] policy = WhiteListRoundRobinPolicy(hosts) - host = Host(UnixSocketEndPoint('/tmp/scylla-workdir/cql.m'), SimpleConvictionPolicy) + host = Host(UnixSocketEndPoint('/tmp/scylla-workdir/cql.m'), SimpleConvictionPolicy, host_id=uuid.uuid4()) policy.populate(None, [host]) qplan = list(policy.make_query_plan()) @@ -1559,8 +1550,8 @@ def setUp(self): child_policy=Mock(name='child_policy', distance=Mock(name='distance')), predicate=lambda host: host.address == 'acceptme' ) - self.ignored_host = Host(DefaultEndPoint('ignoreme'), conviction_policy_factory=Mock()) - self.accepted_host = Host(DefaultEndPoint('acceptme'), conviction_policy_factory=Mock()) + self.ignored_host = Host(DefaultEndPoint('ignoreme'), conviction_policy_factory=Mock(), host_id=uuid.uuid4()) + self.accepted_host = Host(DefaultEndPoint('acceptme'), conviction_policy_factory=Mock(), host_id=uuid.uuid4()) def test_ignored_with_filter(self): assert self.hfp.distance(self.ignored_host) == HostDistance.IGNORED @@ -1629,7 +1620,7 @@ def test_query_plan_deferred_to_child(self): def test_wrap_token_aware(self): cluster = Mock(spec=Cluster) - hosts = [Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy) for i in range(1, 6)] + hosts = [Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(1, 6)] for host in hosts: host.set_up() @@ -1638,7 +1629,7 @@ def get_replicas(keyspace, packed_key): cluster.metadata.get_replicas.side_effect = get_replicas cluster.metadata._tablets = Mock(spec=Tablets) - cluster.metadata._tablets.table_has_tablets.return_value = [] + cluster.metadata._tablets.get_tablet_for_key.return_value = None child_policy = TokenAwarePolicy(RoundRobinPolicy()) @@ -1656,13 +1647,13 @@ def get_replicas(keyspace, packed_key): query_plan = hfp.make_query_plan("keyspace", mocked_query) # First the not filtered replica, and then the rest of the allowed hosts ordered query_plan = list(query_plan) - assert query_plan[0] == Host(DefaultEndPoint("127.0.0.2"), SimpleConvictionPolicy) - assert set(query_plan[1:]) == {Host(DefaultEndPoint("127.0.0.3"), SimpleConvictionPolicy), - Host(DefaultEndPoint("127.0.0.5"), SimpleConvictionPolicy)} + assert query_plan[0] == Host(DefaultEndPoint("127.0.0.2"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + assert set(query_plan[1:]) == {Host(DefaultEndPoint("127.0.0.3"), SimpleConvictionPolicy, host_id=uuid.uuid4()), + Host(DefaultEndPoint("127.0.0.5"), SimpleConvictionPolicy, host_id=uuid.uuid4())} def test_create_whitelist(self): cluster = Mock(spec=Cluster) - hosts = [Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy) for i in range(1, 6)] + hosts = [Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(1, 6)] for host in hosts: host.set_up() @@ -1680,5 +1671,5 @@ def test_create_whitelist(self): mocked_query = Mock() query_plan = hfp.make_query_plan("keyspace", mocked_query) # Only the filtered replicas should be allowed - assert set(query_plan) == {Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy), - Host(DefaultEndPoint("127.0.0.4"), SimpleConvictionPolicy)} + assert set(query_plan) == {Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()), + Host(DefaultEndPoint("127.0.0.4"), SimpleConvictionPolicy, host_id=uuid.uuid4())} diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index bcca28ac73..7168ad2940 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -24,7 +24,7 @@ from cassandra.protocol import (ReadTimeoutErrorMessage, WriteTimeoutErrorMessage, UnavailableErrorMessage, ResultMessage, QueryMessage, OverloadedErrorMessage, IsBootstrappingErrorMessage, - PreparedQueryNotFound, PrepareMessage, + PreparedQueryNotFound, PrepareMessage, ServerError, RESULT_KIND_ROWS, RESULT_KIND_SET_KEYSPACE, RESULT_KIND_SCHEMA_CHANGE, RESULT_KIND_PREPARED, ProtocolHandler) @@ -668,3 +668,58 @@ def test_timeout_does_not_release_stream_id(self): assert len(connection.request_ids) == 0, \ "Request IDs should be empty but it's not: {}".format(connection.request_ids) + + def test_single_host_query_plan_exhausted_after_one_retry(self): + """ + Test that when a specific host is provided, the query plan is properly + exhausted after one attempt and doesn't cause infinite retries. + + This test reproduces the issue where providing a single host in the query plan + (via the host parameter) would cause infinite retries on server errors because + the query_plan was a list instead of an iterator. + """ + session = self.make_basic_session() + pool = self.make_pool() + session._pools.get.return_value = pool + + # Create a specific host + specific_host = Mock() + + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 1) + + query = SimpleStatement("INSERT INTO foo (a, b) VALUES (1, 2)") + message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) + + # Create ResponseFuture with a specific host (this is the key to reproducing the bug) + rf = ResponseFuture(session, message, query, 1, host=specific_host) + rf.send_request() + + # Verify initial request was sent + rf.session._pools.get.assert_called_once_with(specific_host) + pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) + connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) + + # Simulate a ServerError response (which triggers RETRY_NEXT_HOST by default) + result = Mock(spec=ServerError, info={}) + result.to_exception.return_value = result + rf._set_result(specific_host, None, None, result) + + # The retry should be scheduled + rf.session.cluster.scheduler.schedule.assert_called_once_with(ANY, rf._retry_task, False, specific_host) + assert 1 == rf._query_retries + + # Reset mocks to track next calls + pool.borrow_connection.reset_mock() + connection.send_msg.reset_mock() + + # Now simulate the retry task executing + # The bug would cause this to succeed and retry again infinitely + # The fix ensures the iterator is exhausted after the first try + rf._retry_task(False, specific_host) + + # After the retry, send_request should be called but the query_plan iterator + # should be exhausted, so no new request should be sent + # Instead, it should set a NoHostAvailable exception + assert rf._final_exception is not None + assert isinstance(rf._final_exception, NoHostAvailable) diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index 3390f6dbd6..a5bd028b26 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -1002,11 +1002,11 @@ def test_host_order(self): @test_category data_types """ - hosts = [Host(addr, SimpleConvictionPolicy) for addr in + hosts = [Host(addr, SimpleConvictionPolicy, host_id=uuid.uuid4()) for addr in ("127.0.0.1", "127.0.0.2", "127.0.0.3", "127.0.0.4")] - hosts_equal = [Host(addr, SimpleConvictionPolicy) for addr in + hosts_equal = [Host(addr, SimpleConvictionPolicy, host_id=uuid.uuid4()) for addr in ("127.0.0.1", "127.0.0.1")] - hosts_equal_conviction = [Host("127.0.0.1", SimpleConvictionPolicy), Host("127.0.0.1", ConvictionPolicy)] + hosts_equal_conviction = [Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()), Host("127.0.0.1", ConvictionPolicy, host_id=uuid.uuid4())] check_sequence_consistency(hosts) check_sequence_consistency(hosts_equal, equal=True) check_sequence_consistency(hosts_equal_conviction, equal=True)