From 37a992def32b6959bbfda82ceaa77105e657049a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 27 Dec 2025 10:36:37 +0000 Subject: [PATCH 01/29] Initial plan From 522d55fa08249c2e1e21349ffbfea6b6a74128e0 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 29 Dec 2025 17:20:38 +0000 Subject: [PATCH 02/29] Add cryptography>=42.0.4 to dev dependencies for column encryption tests Co-authored-by: mykaul <4655593+mykaul@users.noreply.github.com> --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f0dffa23c9..7fa90dffa1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ requires-python = ">=3.9" [project.optional-dependencies] graph = ['gremlinpython==3.7.4'] -cle = ['cryptography>=35.0'] +cle = ['cryptography>=42.0.4'] compress-lz4 = ['lz4'] compress-snappy = ['python-snappy'] @@ -51,6 +51,7 @@ dev = [ "futurist", "asynctest", "pyyaml", + "cryptography>=42.0.4", "ccm @ git+https://git@github.com/scylladb/scylla-ccm.git@master", ] From 76733f2e22c7ff821ca312b7fd5fc6298dc81e7f Mon Sep 17 00:00:00 2001 From: Yaniv Kaul Date: Tue, 16 Dec 2025 18:15:21 +0200 Subject: [PATCH 03/29] Apply suggested fix to docs/api/cassandra/protocol.rst from Copilot Autofix Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com> --- docs/api/cassandra/protocol.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api/cassandra/protocol.rst b/docs/api/cassandra/protocol.rst index 258c6baeb6..581c4f1204 100644 --- a/docs/api/cassandra/protocol.rst +++ b/docs/api/cassandra/protocol.rst @@ -53,5 +53,5 @@ These protocol handlers comprise different parsers, and return results as descri - LazyProtocolHandler: near drop-in replacement for the above, except that it returns an iterator over rows, lazily decoded into the default row format (this is more efficient since all decoded results are not materialized at once) -- NumpyProtocolHander: deserializes results directly into NumPy arrays. This facilitates efficient integration with +- NumpyProtocolHandler: deserializes results directly into NumPy arrays. This facilitates efficient integration with analysis toolkits such as Pandas. From f1deca6f81013cf003ea9d4f7e0607b1913f4dae Mon Sep 17 00:00:00 2001 From: Yaniv Kaul Date: Tue, 16 Dec 2025 18:15:22 +0200 Subject: [PATCH 04/29] Apply suggested fix to docs/api/cassandra/protocol.rst from Copilot Autofix Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com> --- docs/api/cassandra/protocol.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api/cassandra/protocol.rst b/docs/api/cassandra/protocol.rst index 581c4f1204..8b8f303574 100644 --- a/docs/api/cassandra/protocol.rst +++ b/docs/api/cassandra/protocol.rst @@ -16,7 +16,7 @@ holding custom key/value pairs. By default these are ignored by the server. They can be useful for servers implementing a custom QueryHandler. -See :meth:`.Session.execute`, ::meth:`.Session.execute_async`, :attr:`.ResponseFuture.custom_payload`. +See :meth:`.Session.execute`, :meth:`.Session.execute_async`, :attr:`.ResponseFuture.custom_payload`. .. 
autoclass:: _ProtocolHandler From 09cc20108796fbed84e8a4e5827af6f69f75ffb2 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Mon, 29 Dec 2025 12:10:22 +0000 Subject: [PATCH 05/29] chore(deps): update dependency hatchling to v1.28.0 --- docs/pyproject.toml | 4 ++-- docs/uv.lock | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/pyproject.toml b/docs/pyproject.toml index 49a4d000ae..f6ee417aee 100644 --- a/docs/pyproject.toml +++ b/docs/pyproject.toml @@ -25,14 +25,14 @@ dependencies = [ [dependency-groups] # Add any dev-only tools here; example shown -dev = ["hatchling==1.27.0"] +dev = ["hatchling==1.28.0"] [tool.uv.sources] # Keep the driver editable from the parent directory scylla-driver = { path = "../", editable = true } [build-system] -requires = ["hatchling==1.27.0"] +requires = ["hatchling==1.28.0"] build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] diff --git a/docs/uv.lock b/docs/uv.lock index 4d40698dcc..23468c7170 100644 --- a/docs/uv.lock +++ b/docs/uv.lock @@ -350,7 +350,7 @@ wheels = [ [[package]] name = "hatchling" -version = "1.27.0" +version = "1.28.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "packaging" }, @@ -358,9 +358,9 @@ dependencies = [ { name = "pluggy" }, { name = "trove-classifiers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8f/8a/cc1debe3514da292094f1c3a700e4ca25442489731ef7c0814358816bb03/hatchling-1.27.0.tar.gz", hash = "sha256:971c296d9819abb3811112fc52c7a9751c8d381898f36533bb16f9791e941fd6", size = 54983, upload-time = "2024-12-15T17:08:11.894Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0b/8e/e480359492affde4119a131da729dd26da742c2c9b604dff74836e47eef9/hatchling-1.28.0.tar.gz", hash = "sha256:4d50b02aece6892b8cd0b3ce6c82cb218594d3ec5836dbde75bf41a21ab004c8", size = 55365, upload-time = "2025-11-27T00:31:13.766Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/08/e7/ae38d7a6dfba0533684e0b2136817d667588ae3ec984c1a4e5df5eb88482/hatchling-1.27.0-py3-none-any.whl", hash = "sha256:d3a2f3567c4f926ea39849cdf924c7e99e6686c9c8e288ae1037c8fa2a5d937b", size = 75794, upload-time = "2024-12-15T17:08:10.364Z" }, + { url = "https://files.pythonhosted.org/packages/0d/a5/48cb7efb8b4718b1a4c0c331e3364a3a33f614ff0d6afd2b93ee883d3c47/hatchling-1.28.0-py3-none-any.whl", hash = "sha256:dc48722b68b3f4bbfa3ff618ca07cdea6750e7d03481289ffa8be1521d18a961", size = 76075, upload-time = "2025-11-27T00:31:12.544Z" }, ] [[package]] @@ -665,7 +665,7 @@ requires-dist = [ ] [package.metadata.requires-dev] -dev = [{ name = "hatchling", specifier = "==1.27.0" }] +dev = [{ name = "hatchling", specifier = "==1.28.0" }] [[package]] name = "pyyaml" From c114180233177fcab73ad54d917c6870e4af8e23 Mon Sep 17 00:00:00 2001 From: Yaniv Kaul Date: Tue, 16 Dec 2025 18:18:01 +0200 Subject: [PATCH 06/29] Fix for Missing call to superclass `__init__` during object initialization Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> --- cassandra/metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cassandra/metadata.py b/cassandra/metadata.py index 6379de069a..bbfaf2605b 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -1892,7 +1892,7 @@ def hash_fn(cls, key): def __init__(self, token): """ `token` is an int or string representing the token. 
""" - self.value = int(token) + super().__init__(int(token)) class MD5Token(HashToken): From 66b81c33e8fd77c894148635ff6c0a055ee88dd3 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 23 Dec 2025 09:37:35 +0000 Subject: [PATCH 07/29] Apply Python style improvements to test assertions Co-authored-by: mykaul <4655593+mykaul@users.noreply.github.com> --- tests/integration/standard/test_control_connection.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/integration/standard/test_control_connection.py b/tests/integration/standard/test_control_connection.py index 206945f0b3..a2f4e051d3 100644 --- a/tests/integration/standard/test_control_connection.py +++ b/tests/integration/standard/test_control_connection.py @@ -90,14 +90,14 @@ def test_get_control_connection_host(self): """ host = self.cluster.get_control_connection_host() - assert host == None + assert host is None self.session = self.cluster.connect() cc_host = self.cluster.control_connection._connection.host host = self.cluster.get_control_connection_host() assert host.address == cc_host - assert host.is_up == True + assert host.is_up # reconnect and make sure that the new host is reflected correctly self.cluster.control_connection._reconnect() @@ -117,16 +117,16 @@ def test_control_connection_port_discovery(self): self.cluster = TestCluster() host = self.cluster.get_control_connection_host() - assert host == None + assert host is None self.session = self.cluster.connect() cc_endpoint = self.cluster.control_connection._connection.endpoint host = self.cluster.get_control_connection_host() assert host.endpoint == cc_endpoint - assert host.is_up == True + assert host.is_up hosts = self.cluster.metadata.all_hosts() - assert 3 == len(hosts) + assert len(hosts) == 3 for host in hosts: assert 9042 == host.broadcast_rpc_port From 78145550a78fec9b943618f8bdc9067fec3e1863 Mon Sep 17 00:00:00 2001 From: Yaniv Kaul Date: Mon, 22 Dec 2025 15:51:39 +0200 Subject: [PATCH 08/29] Potential fix for code scanning alert no. 3: Workflow does not contain permissions Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> --- .github/workflows/docs-pages.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/docs-pages.yml b/.github/workflows/docs-pages.yml index f017201c7c..0da86fef34 100644 --- a/.github/workflows/docs-pages.yml +++ b/.github/workflows/docs-pages.yml @@ -2,6 +2,9 @@ name: "Docs / Publish" # For more information, # see https://sphinx-theme.scylladb.com/stable/deployment/production.html#available-workflows +permissions: + contents: write + on: push: branches: From e79b4c6d14a93057ea0011c5d7392360865072cc Mon Sep 17 00:00:00 2001 From: Yaniv Kaul Date: Mon, 22 Dec 2025 13:51:44 +0200 Subject: [PATCH 09/29] .github/workflows/publish-manually.yml: Potential fix for code scanning alert no. 
7: Workflow does not contain permissions

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
---
 .github/workflows/publish-manually.yml | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/.github/workflows/publish-manually.yml b/.github/workflows/publish-manually.yml
index d69475f394..09b9779117 100644
--- a/.github/workflows/publish-manually.yml
+++ b/.github/workflows/publish-manually.yml
@@ -1,5 +1,8 @@
 name: Build and upload to PyPi manually

+permissions:
+  contents: read
+
 on:
   workflow_dispatch:
     inputs:

From af4d83cc7134409c2ee9897cfc2ca68d9473ef14 Mon Sep 17 00:00:00 2001
From: Dmitry Kropachev
Date: Tue, 30 Dec 2025 12:27:47 -0400
Subject: [PATCH 10/29] Don't mark node down when control connection fails to connect

Node pools should be stable; if the control connection fails to connect, that is not a good enough reason to kill the pool or to mark the node down.
---
 cassandra/cluster.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/cassandra/cluster.py b/cassandra/cluster.py
index 66bf7c7049..1fff739e97 100644
--- a/cassandra/cluster.py
+++ b/cassandra/cluster.py
@@ -3546,10 +3546,6 @@ def _connect_host_in_lbp(self):
         for host in lbp.make_query_plan():
             try:
                 return (self._try_connect(host), None)
-            except ConnectionException as exc:
-                errors[str(host.endpoint)] = exc
-                log.warning("[control connection] Error connecting to %s:", host, exc_info=True)
-                self._cluster.signal_connection_failure(host, exc, is_host_addition=False)
             except Exception as exc:
                 errors[str(host.endpoint)] = exc
                 log.warning("[control connection] Error connecting to %s:", host, exc_info=True)

From f11f55f85efd56603a2d1882fa5aade67694ffdc Mon Sep 17 00:00:00 2001
From: Yaniv Michael Kaul
Date: Sun, 4 Jan 2026 09:54:37 +0200
Subject: [PATCH 11/29] (improvement) remove support for protocols <3 from cython files

Continued effort to remove protocol versions < 3 as was done in the native Python code.

Signed-off-by: Yaniv Kaul
---
 cassandra/deserializers.pyx | 74 ++++++++++++------------------------
 1 file changed, 24 insertions(+), 50 deletions(-)

diff --git a/cassandra/deserializers.pyx b/cassandra/deserializers.pyx
index 7c256674b0..97d249d02f 100644
--- a/cassandra/deserializers.pyx
+++ b/cassandra/deserializers.pyx
@@ -208,15 +208,9 @@ cdef class _DesSingleParamType(_DesParameterizedType):

 cdef class DesListType(_DesSingleParamType):
     cdef deserialize(self, Buffer *buf, int protocol_version):
-        cdef uint16_t v2_and_below = 2
-        cdef int32_t v3_and_above = 3
-        if protocol_version >= 3:
-            result = _deserialize_list_or_set[int32_t](
-                v3_and_above, buf, protocol_version, self.deserializer)
-        else:
-            result = _deserialize_list_or_set[uint16_t](
-                v2_and_below, buf, protocol_version, self.deserializer)
+        result = _deserialize_list_or_set(
+            buf, protocol_version, self.deserializer)
         return result


 cdef class DesSetType(DesListType):
         return util.sortedset(DesListType.deserialize(self, buf, protocol_version))


-ctypedef fused itemlen_t:
-    uint16_t # protocol <= v2
-    int32_t  # protocol >= v3
-
-cdef list _deserialize_list_or_set(itemlen_t dummy_version,
-                                   Buffer *buf, int protocol_version,
+cdef list _deserialize_list_or_set(Buffer *buf, int protocol_version,
                                    Deserializer deserializer):
     """
     Deserialize a list or set.
-
-    The 'dummy' parameter is needed to make fused types work, so that
-    we can specialize on the protocol version.
""" cdef Buffer itemlen_buf cdef Buffer elem_buf - cdef itemlen_t numelements + cdef int32_t numelements cdef int offset cdef list result = [] - _unpack_len[itemlen_t](buf, 0, &numelements) - offset = sizeof(itemlen_t) + _unpack_len(buf, 0, &numelements) + offset = sizeof(int32_t) protocol_version = max(3, protocol_version) for _ in range(numelements): - subelem[itemlen_t](buf, &elem_buf, &offset, dummy_version) + subelem(buf, &elem_buf, &offset) result.append(from_binary(deserializer, &elem_buf, protocol_version)) return result cdef inline int subelem( - Buffer *buf, Buffer *elem_buf, int* offset, itemlen_t dummy) except -1: + Buffer *buf, Buffer *elem_buf, int* offset) except -1: """ Read the next element from the buffer: first read the size (in bytes) of the element, then fill elem_buf with a newly sliced buffer of this size (and the right offset). """ - cdef itemlen_t elemlen + cdef int32_t elemlen - _unpack_len[itemlen_t](buf, offset[0], &elemlen) - offset[0] += sizeof(itemlen_t) + _unpack_len(buf, offset[0], &elemlen) + offset[0] += sizeof(int32_t) slice_buffer(buf, elem_buf, offset[0], elemlen) offset[0] += elemlen return 0 -cdef int _unpack_len(Buffer *buf, int offset, itemlen_t *output) except -1: +cdef int _unpack_len(Buffer *buf, int offset, int32_t *output) except -1: cdef Buffer itemlen_buf - slice_buffer(buf, &itemlen_buf, offset, sizeof(itemlen_t)) + slice_buffer(buf, &itemlen_buf, offset, sizeof(int32_t)) - if itemlen_t is uint16_t: - output[0] = unpack_num[uint16_t](&itemlen_buf) - else: - output[0] = unpack_num[int32_t](&itemlen_buf) + output[0] = unpack_num[int32_t](&itemlen_buf) return 0 @@ -295,42 +278,33 @@ cdef class DesMapType(_DesParameterizedType): self.val_deserializer = self.deserializers[1] cdef deserialize(self, Buffer *buf, int protocol_version): - cdef uint16_t v2_and_below = 0 - cdef int32_t v3_and_above = 0 key_type, val_type = self.cqltype.subtypes - if protocol_version >= 3: - result = _deserialize_map[int32_t]( - v3_and_above, buf, protocol_version, - self.key_deserializer, self.val_deserializer, - key_type, val_type) - else: - result = _deserialize_map[uint16_t]( - v2_and_below, buf, protocol_version, - self.key_deserializer, self.val_deserializer, - key_type, val_type) + result = _deserialize_map( + buf, protocol_version, + self.key_deserializer, self.val_deserializer, + key_type, val_type) return result -cdef _deserialize_map(itemlen_t dummy_version, - Buffer *buf, int protocol_version, +cdef _deserialize_map(Buffer *buf, int protocol_version, Deserializer key_deserializer, Deserializer val_deserializer, key_type, val_type): cdef Buffer key_buf, val_buf cdef Buffer itemlen_buf - cdef itemlen_t numelements + cdef int32_t numelements cdef int offset cdef list result = [] - _unpack_len[itemlen_t](buf, 0, &numelements) - offset = sizeof(itemlen_t) + _unpack_len(buf, 0, &numelements) + offset = sizeof(int32_t) themap = util.OrderedMapSerializedKey(key_type, protocol_version) protocol_version = max(3, protocol_version) for _ in range(numelements): - subelem[itemlen_t](buf, &key_buf, &offset, dummy_version) - subelem[itemlen_t](buf, &val_buf, &offset, numelements) + subelem(buf, &key_buf, &offset) + subelem(buf, &val_buf, &offset) key = from_binary(key_deserializer, &key_buf, protocol_version) val = from_binary(val_deserializer, &val_buf, protocol_version) themap._insert_unchecked(key, to_bytes(&key_buf), val) From 0ebd9f517a7a967fd14029f37ac81e573118e9a3 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Thu, 1 Jan 2026 00:59:34 -0400 Subject: [PATCH 
12/29] Pull version information from system.local when version info is not present

---
 cassandra/metadata.py | 26 +++++++++++++++++++++++---
 tests/unit/test_metadata.py | 35 ++++++++++++++++++++++++++++++++++-
 2 files changed, 57 insertions(+), 4 deletions(-)

diff --git a/cassandra/metadata.py b/cassandra/metadata.py
index bbfaf2605b..85f6c45ac6 100644
--- a/cassandra/metadata.py
+++ b/cassandra/metadata.py
@@ -139,8 +139,9 @@ def export_schema_as_string(self):

     def refresh(self, connection, timeout, target_type=None, change_type=None, fetch_size=None,
                 metadata_request_timeout=None, **kwargs):
-        server_version = self.get_host(connection.original_endpoint).release_version
-        dse_version = self.get_host(connection.original_endpoint).dse_version
+        host = self.get_host(connection.original_endpoint)
+        server_version = host.release_version if host else None
+        dse_version = host.dse_version if host else None
         parser = get_schema_parser(connection, server_version, dse_version, timeout, metadata_request_timeout, fetch_size)

         if not target_type:
@@ -3409,8 +3410,27 @@ def __init__(
         self.to_clustering_columns = to_clustering_columns


+def get_column_from_system_local(connection, column_name: str, timeout, metadata_request_timeout) -> str:
+    success, local_result = connection.wait_for_response(
+        QueryMessage(
+            query=maybe_add_timeout_to_query(
+                "SELECT " + column_name + " FROM system.local WHERE key='local'",
+                metadata_request_timeout),
+            consistency_level=ConsistencyLevel.ONE)
+        , timeout=timeout, fail_on_error=False)
+    if not success or not local_result.parsed_rows:
+        return ""
+    local_rows = dict_factory(local_result.column_names, local_result.parsed_rows)
+    local_row = local_rows[0]
+    return local_row.get(column_name)
+
+
 def get_schema_parser(connection, server_version, dse_version, timeout, metadata_request_timeout, fetch_size=None):
-    version = Version(server_version)
+    if server_version is None and dse_version is None:
+        server_version = get_column_from_system_local(connection, "release_version", timeout, metadata_request_timeout)
+        dse_version = get_column_from_system_local(connection, "dse_version", timeout, metadata_request_timeout)
+
+    version = Version(server_version or "0")
     if dse_version:
         v = Version(dse_version)
         if v >= Version('6.8.0'):
diff --git a/tests/unit/test_metadata.py b/tests/unit/test_metadata.py
index 3069f6bced..c471fab827 100644
--- a/tests/unit/test_metadata.py
+++ b/tests/unit/test_metadata.py
@@ -30,9 +30,11 @@
                                 UserType, KeyspaceMetadata, get_schema_parser,
                                 _UnknownStrategy, ColumnMetadata, TableMetadata,
                                 IndexMetadata, Function, Aggregate,
-                                Metadata, TokenMap, ReplicationFactor)
+                                Metadata, TokenMap, ReplicationFactor,
+                                SchemaParserDSE68)
 from cassandra.policies import SimpleConvictionPolicy
 from cassandra.pool import Host
+from cassandra.protocol import QueryMessage
 from tests.util import assertCountEqual

 import pytest
@@ -616,6 +618,37 @@ def test_build_index_as_cql(self):
         assert index_meta.as_cql_query() == "CREATE CUSTOM INDEX index_name_here ON keyspace_name_here.table_name_here (column_name_here) USING 'class_name_here'"


+class SchemaParserLookupTests(unittest.TestCase):
+
+    def test_reads_versions_from_system_local_when_missing(self):
+        connection = Mock()
+
+        release_version_resp = Mock()
+        release_version_resp.column_names = ["release_version"]
+        release_version_resp.parsed_rows = [["4.0.0"]]
+
+        dse_version_resp = Mock()
+        dse_version_resp.column_names = ["dse_version"]
+        dse_version_resp.parsed_rows = [["6.8.0"]]
+
+        def mock_system_local(query, *args,
**kwargs): + if not isinstance(query, QueryMessage): + raise RuntimeError("first argument should be a QueryMessage") + if "release_version" in query.query: + return (True, release_version_resp) + if "dse_version" in query.query: + return (True, dse_version_resp) + raise RuntimeError("unexpected query") + + connection.wait_for_response.side_effect = mock_system_local + + parser = get_schema_parser(connection, None, None, 0.1, None) + + assert isinstance(parser, SchemaParserDSE68) + message = connection.wait_for_response.call_args[0][0] + assert "system.local" in message.query + + class UnicodeIdentifiersTests(unittest.TestCase): """ Exercise cql generation with unicode characters. Keyspace, Table, and Index names From d08d0e2a9408f8dc11c290a649b23430b27ee6fc Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 27 Dec 2025 11:31:35 +0000 Subject: [PATCH 13/29] Fix infinite retry when single host fails with server error Co-authored-by: mykaul <4655593+mykaul@users.noreply.github.com> --- cassandra/cluster.py | 2 +- tests/unit/test_response_future.py | 57 +++++++++++++++++++++++++++++- 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 1fff739e97..fe1c0ea4c5 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -4543,7 +4543,7 @@ def _make_query_plan(self): # or to the explicit host target if set if self._host: # returning a single value effectively disables retries - self.query_plan = [self._host] + self.query_plan = iter([self._host]) else: # convert the list/generator/etc to an iterator so that subsequent # calls to send_request (which retries may do) will resume where diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index bcca28ac73..7168ad2940 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -24,7 +24,7 @@ from cassandra.protocol import (ReadTimeoutErrorMessage, WriteTimeoutErrorMessage, UnavailableErrorMessage, ResultMessage, QueryMessage, OverloadedErrorMessage, IsBootstrappingErrorMessage, - PreparedQueryNotFound, PrepareMessage, + PreparedQueryNotFound, PrepareMessage, ServerError, RESULT_KIND_ROWS, RESULT_KIND_SET_KEYSPACE, RESULT_KIND_SCHEMA_CHANGE, RESULT_KIND_PREPARED, ProtocolHandler) @@ -668,3 +668,58 @@ def test_timeout_does_not_release_stream_id(self): assert len(connection.request_ids) == 0, \ "Request IDs should be empty but it's not: {}".format(connection.request_ids) + + def test_single_host_query_plan_exhausted_after_one_retry(self): + """ + Test that when a specific host is provided, the query plan is properly + exhausted after one attempt and doesn't cause infinite retries. + + This test reproduces the issue where providing a single host in the query plan + (via the host parameter) would cause infinite retries on server errors because + the query_plan was a list instead of an iterator. 
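+
+        Illustration of the Python semantics behind the fix (not driver API,
+        just list vs. iterator behavior):
+
+            plan = [host]         # a list can be iterated again on every retry
+            plan = iter([host])   # an iterator is exhausted after one pass
+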
+ """ + session = self.make_basic_session() + pool = self.make_pool() + session._pools.get.return_value = pool + + # Create a specific host + specific_host = Mock() + + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 1) + + query = SimpleStatement("INSERT INTO foo (a, b) VALUES (1, 2)") + message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) + + # Create ResponseFuture with a specific host (this is the key to reproducing the bug) + rf = ResponseFuture(session, message, query, 1, host=specific_host) + rf.send_request() + + # Verify initial request was sent + rf.session._pools.get.assert_called_once_with(specific_host) + pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) + connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) + + # Simulate a ServerError response (which triggers RETRY_NEXT_HOST by default) + result = Mock(spec=ServerError, info={}) + result.to_exception.return_value = result + rf._set_result(specific_host, None, None, result) + + # The retry should be scheduled + rf.session.cluster.scheduler.schedule.assert_called_once_with(ANY, rf._retry_task, False, specific_host) + assert 1 == rf._query_retries + + # Reset mocks to track next calls + pool.borrow_connection.reset_mock() + connection.send_msg.reset_mock() + + # Now simulate the retry task executing + # The bug would cause this to succeed and retry again infinitely + # The fix ensures the iterator is exhausted after the first try + rf._retry_task(False, specific_host) + + # After the retry, send_request should be called but the query_plan iterator + # should be exhausted, so no new request should be sent + # Instead, it should set a NoHostAvailable exception + assert rf._final_exception is not None + assert isinstance(rf._final_exception, NoHostAvailable) From e61d2651da30e18d897f7c538c97caed7e899ae1 Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Mon, 29 Dec 2025 14:46:58 +0100 Subject: [PATCH 14/29] Use endpoint instead od Host in _try_connect --- cassandra/cluster.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index fe1c0ea4c5..8db58c13cd 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -838,8 +838,8 @@ def default_retry_policy(self, policy): Using ssl_options without ssl_context is deprecated and will be removed in the next major release. - An optional dict which will be used as kwargs for ``ssl.SSLContext.wrap_socket`` - when new sockets are created. This should be used when client encryption is enabled + An optional dict which will be used as kwargs for ``ssl.SSLContext.wrap_socket`` + when new sockets are created. This should be used when client encryption is enabled in Cassandra. The following documentation only applies when ssl_options is used without ssl_context. @@ -1086,10 +1086,10 @@ def default_retry_policy(self, policy): """ Specifies a server-side timeout (in seconds) for all internal driver queries, such as schema metadata lookups and cluster topology requests. - + The timeout is enforced by appending `USING TIMEOUT ` to queries executed by the driver. - + - A value of `0` disables explicit timeout enforcement. In this case, the driver does not add `USING TIMEOUT`, and the timeout is determined by the server's defaults. 
@@ -1717,7 +1717,7 @@ def connect(self, keyspace=None, wait_for_all_pools=False):
                        self.contact_points, self.protocol_version)
         self.connection_class.initialize_reactor()
         _register_cluster_shutdown(self)
-        
+
         self._add_resolved_hosts()

         try:
@@ -3534,7 +3534,7 @@ def _set_new_connection(self, conn):
             if old:
                 log.debug("[control connection] Closing old connection %r, replacing with %r", old, conn)
                 old.close()
-    
+
     def _connect_host_in_lbp(self):
         errors = {}
         lbp = (
@@ -3545,13 +3545,13 @@
         for host in lbp.make_query_plan():
             try:
-                return (self._try_connect(host), None)
+                return (self._try_connect(host.endpoint), None)
             except Exception as exc:
                 errors[str(host.endpoint)] = exc
                 log.warning("[control connection] Error connecting to %s:", host, exc_info=True)
             if self._is_shutdown:
                 raise DriverException("[control connection] Reconnection in progress during shutdown")
-    
+
         return (None, errors)

     def _reconnect_internal(self):
@@ -3575,31 +3575,31 @@
             (conn, errors) = self._connect_host_in_lbp()
             if conn is not None:
                 return conn
-    
+
         raise NoHostAvailable("Unable to connect to any servers", errors)

-    def _try_connect(self, host):
+    def _try_connect(self, endpoint):
         """
         Creates a new Connection, registers for pushed events, and refreshes
         node/token and schema metadata.
         """
-        log.debug("[control connection] Opening new connection to %s", host)
+        log.debug("[control connection] Opening new connection to %s", endpoint)

         while True:
             try:
-                connection = self._cluster.connection_factory(host.endpoint, is_control_connection=True)
+                connection = self._cluster.connection_factory(endpoint, is_control_connection=True)
                 if self._is_shutdown:
                     connection.close()
                     raise DriverException("Reconnecting during shutdown")
                 break
             except ProtocolVersionUnsupported as e:
-                self._cluster.protocol_downgrade(host.endpoint, e.startup_version)
+                self._cluster.protocol_downgrade(endpoint, e.startup_version)
             except ProtocolException as e:
                 # protocol v5 is out of beta in C* >=4.0-beta5 and is now the default driver
                 # protocol version. If the protocol version was not explicitly specified,
                 # and that the server raises a beta protocol error, we should downgrade.
                 if not self._cluster._protocol_version_explicit and e.is_beta_protocol_error:
-                    self._cluster.protocol_downgrade(host.endpoint, self._cluster.protocol_version)
+                    self._cluster.protocol_downgrade(endpoint, self._cluster.protocol_version)
                 else:
                     raise

From 296a98162358d58710cce85104bc79bdcaa20a64 Mon Sep 17 00:00:00 2001
From: sylwiaszunejko
Date: Mon, 29 Dec 2025 15:28:56 +0100
Subject: [PATCH 15/29] tests/integration/standard: fix test to reflect RR policy randomizing starting point

The `test_profile_lb_swap` test logic assumed that `populate` was called before the control connection (cc) was created, meaning only the contact points from the cluster configuration were known (a single host). Because of that, the starting point was not random.

This commit updates the test to reflect the new behavior, where `populate` is called on the load-balancing policy after the control connection is created. This allows the policy to be updated with all known hosts and ensures the starting point is properly randomized.
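For illustration, a rough sketch of the round-robin property the updated test
asserts (assuming `RoundRobinPolicy` picks a random starting position in
`populate` and advances it by one for each query plan; plain strings stand in
for Host objects):

    from cassandra.policies import RoundRobinPolicy

    hosts = ['h1', 'h2', 'h3']      # stand-ins for Host instances
    policy = RoundRobinPolicy()
    policy.populate(None, hosts)    # the cluster argument is unused here

    # The first host of each successive plan advances round-robin from a
    # random offset, so the sequence repeats with period len(hosts).
    firsts = [next(iter(policy.make_query_plan())) for _ in range(2 * len(hosts))]
    assert all(firsts[i] == firsts[i + len(hosts)] for i in range(len(hosts)))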
--- tests/integration/standard/test_cluster.py | 36 ++++++++++++++-------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py index d7f89ad598..1208edb9d2 100644 --- a/tests/integration/standard/test_cluster.py +++ b/tests/integration/standard/test_cluster.py @@ -900,8 +900,9 @@ def test_profile_lb_swap(self): """ Tests that profile load balancing policies are not shared - Creates two LBP, runs a few queries, and validates that each LBP is execised - seperately between EP's + Creates two LBP, runs a few queries, and validates that each LBP is exercised + separately between EP's. Each RoundRobinPolicy starts from its own random + position and maintains independent round-robin ordering. @since 3.5 @jira_ticket PYTHON-569 @@ -916,17 +917,28 @@ def test_profile_lb_swap(self): with TestCluster(execution_profiles=exec_profiles) as cluster: session = cluster.connect(wait_for_all_pools=True) - # default is DCA RR for all hosts expected_hosts = set(cluster.metadata.all_hosts()) - rr1_queried_hosts = set() - rr2_queried_hosts = set() - - rs = session.execute(query, execution_profile='rr1') - rr1_queried_hosts.add(rs.response_future._current_host) - rs = session.execute(query, execution_profile='rr2') - rr2_queried_hosts.add(rs.response_future._current_host) - - assert rr2_queried_hosts == rr1_queried_hosts + num_hosts = len(expected_hosts) + assert num_hosts > 1, "Need at least 2 hosts for this test" + + rr1_queried_hosts = [] + rr2_queried_hosts = [] + + for _ in range(num_hosts * 2): + rs = session.execute(query, execution_profile='rr1') + rr1_queried_hosts.append(rs.response_future._current_host) + rs = session.execute(query, execution_profile='rr2') + rr2_queried_hosts.append(rs.response_future._current_host) + + # Both policies should have queried all hosts + assert set(rr1_queried_hosts) == expected_hosts + assert set(rr2_queried_hosts) == expected_hosts + + # The order of hosts should demonstrate round-robin behavior + # After num_hosts queries, the pattern should repeat + for i in range(num_hosts): + assert rr1_queried_hosts[i] == rr1_queried_hosts[i + num_hosts] + assert rr2_queried_hosts[i] == rr2_queried_hosts[i + num_hosts] def test_ta_lbp(self): """ From 2b7dd50a60ef8ce1c5d6abaac031e3ea45a362a7 Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Mon, 29 Dec 2025 15:37:31 +0100 Subject: [PATCH 16/29] tests/integration/standard: update test to reflect new behavior Previously, the driver relied on the load-balancing policy (LBP) to determine the order of hosts to connect to. Since the default LBP is Round Robin, each reconnection would start from a different host. After removing fake hosts with random IDs at startup, this behavior changed. When the LBP is not yet initialized, the driver now uses the endpoints provided by the control connection (CC), so there is no guarantee that different hosts will be selected on reconnection. This change updates the test logic to first establish a connection and initialize the LBP, and only then verify that two subsequent reconnections land on different hosts in a healthy cluster. 
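Condensed sketch of the updated check (taken from the test change below; a
healthy multi-node cluster is assumed):

    cluster = TestCluster()
    session = cluster.connect()    # populates the LBP with all known hosts

    cluster.control_connection._reconnect()
    first = cluster.get_control_connection_host()
    cluster.control_connection._reconnect()
    second = cluster.get_control_connection_host()

    # With the round-robin LBP initialized, consecutive reconnects should
    # rotate to a different host.
    assert first != second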
---
 tests/integration/standard/test_control_connection.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/tests/integration/standard/test_control_connection.py b/tests/integration/standard/test_control_connection.py
index a2f4e051d3..cb7820f0a6 100644
--- a/tests/integration/standard/test_control_connection.py
+++ b/tests/integration/standard/test_control_connection.py
@@ -101,8 +101,12 @@ def test_get_control_connection_host(self):
         # reconnect and make sure that the new host is reflected correctly
         self.cluster.control_connection._reconnect()
-        new_host = self.cluster.get_control_connection_host()
-        assert host != new_host
+        new_host1 = self.cluster.get_control_connection_host()
+
+        self.cluster.control_connection._reconnect()
+        new_host2 = self.cluster.get_control_connection_host()
+
+        assert new_host1 != new_host2

     # TODO: enable after https://github.com/scylladb/python-driver/issues/121 is fixed
     @unittest.skip('Fails on scylla due to the broadcast_rpc_port is None')

From 796b0fcde4f521113c04286d15006723874b3c69 Mon Sep 17 00:00:00 2001
From: sylwiaszunejko
Date: Mon, 29 Dec 2025 16:03:38 +0100
Subject: [PATCH 17/29] tests/integration/standard: don't compare Host instances

Only compare host endpoints, not whole Host instances, since we don't know the host ids.
---
 tests/integration/standard/test_policies.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/tests/integration/standard/test_policies.py b/tests/integration/standard/test_policies.py
index 0c84fd06be..2de12f7b7f 100644
--- a/tests/integration/standard/test_policies.py
+++ b/tests/integration/standard/test_policies.py
@@ -45,9 +45,6 @@ def test_predicate_changes(self):
         external_event = True
         contact_point = DefaultEndPoint("127.0.0.1")

-        single_host = {Host(contact_point, SimpleConvictionPolicy)}
-        all_hosts = {Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy) for i in (1, 2, 3)}
-
         predicate = lambda host: host.endpoint == contact_point if external_event else True
         hfp = ExecutionProfile(
             load_balancing_policy=HostFilterPolicy(RoundRobinPolicy(), predicate=predicate)
@@ -62,7 +59,8 @@ def test_predicate_changes(self):
             response = session.execute("SELECT * from system.local WHERE key='local'")
             queried_hosts.update(response.response_future.attempted_hosts)

-        assert queried_hosts == single_host
+        assert len(queried_hosts) == 1
+        assert queried_hosts.pop().endpoint == contact_point

         external_event = False
         futures = session.update_created_pools()
@@ -72,7 +70,8 @@
         for _ in range(10):
             response = session.execute("SELECT * from system.local WHERE key='local'")
             queried_hosts.update(response.response_future.attempted_hosts)
-        assert queried_hosts == all_hosts
+        assert len(queried_hosts) == 3
+        assert {host.endpoint for host in queried_hosts} == {DefaultEndPoint(f"127.0.0.{i}") for i in range(1, 4)}

 class WhiteListRoundRobinPolicyTests(unittest.TestCase):

From 1b2488015ddacfbcd3b9546a9398ebf17f315c19 Mon Sep 17 00:00:00 2001
From: sylwiaszunejko
Date: Mon, 22 Dec 2025 14:50:01 +0100
Subject: [PATCH 18/29] tests/unit: Provide host_id when initializing Host

---
 tests/unit/advanced/test_policies.py | 5 +-
 tests/unit/test_cluster.py | 23 +++----
 tests/unit/test_concurrent.py | 3 +-
 tests/unit/test_host_connection_pool.py | 15 +++--
 tests/unit/test_metadata.py | 51 +++++++--------
 tests/unit/test_policies.py | 85 +++++++++++++------------
 tests/unit/test_types.py | 6 +-
 7 files changed, 97 insertions(+), 91 deletions(-)

diff --git
a/tests/unit/advanced/test_policies.py b/tests/unit/advanced/test_policies.py index 8e421a859d..75cfd3fbf9 100644 --- a/tests/unit/advanced/test_policies.py +++ b/tests/unit/advanced/test_policies.py @@ -13,6 +13,7 @@ # limitations under the License. import unittest from unittest.mock import Mock +import uuid from cassandra.pool import Host from cassandra.policies import RoundRobinPolicy @@ -72,7 +73,7 @@ def test_target_no_host(self): def test_target_host_down(self): node_count = 4 - hosts = [Host(i, Mock()) for i in range(node_count)] + hosts = [Host(i, Mock(), host_id=uuid.uuid4()) for i in range(node_count)] target_host = hosts[1] policy = DSELoadBalancingPolicy(RoundRobinPolicy()) @@ -87,7 +88,7 @@ def test_target_host_down(self): def test_target_host_nominal(self): node_count = 4 - hosts = [Host(i, Mock()) for i in range(node_count)] + hosts = [Host(i, Mock(), host_id=uuid.uuid4()) for i in range(node_count)] target_host = hosts[1] target_host.is_up = True diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index f3efed9f54..49208ac53e 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -17,6 +17,7 @@ import socket from unittest.mock import patch, Mock +import uuid from cassandra import ConsistencyLevel, DriverException, Timeout, Unavailable, RequestExecutionException, ReadTimeout, WriteTimeout, CoordinationFailure, ReadFailure, WriteFailure, FunctionFailure, AlreadyExists,\ InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion @@ -200,7 +201,7 @@ def test_default_serial_consistency_level_ep(self, *_): PR #510 """ c = Cluster(protocol_version=4) - s = Session(c, [Host("127.0.0.1", SimpleConvictionPolicy)]) + s = Session(c, [Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) c.connection_class.initialize_reactor() # default is None @@ -229,7 +230,7 @@ def test_default_serial_consistency_level_legacy(self, *_): PR #510 """ c = Cluster(protocol_version=4) - s = Session(c, [Host("127.0.0.1", SimpleConvictionPolicy)]) + s = Session(c, [Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) c.connection_class.initialize_reactor() # default is None assert s.default_serial_consistency_level is None @@ -286,7 +287,7 @@ def test_default_exec_parameters(self): assert cluster.profile_manager.default.load_balancing_policy.__class__ == default_lbp_factory().__class__ assert cluster.default_retry_policy.__class__ == RetryPolicy assert cluster.profile_manager.default.retry_policy.__class__ == RetryPolicy - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) assert session.default_timeout == 10.0 assert cluster.profile_manager.default.request_timeout == 10.0 assert session.default_consistency_level == ConsistencyLevel.LOCAL_ONE @@ -300,7 +301,7 @@ def test_default_exec_parameters(self): def test_default_legacy(self): cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), default_retry_policy=DowngradingConsistencyRetryPolicy()) assert cluster._config_mode == _ConfigMode.LEGACY - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) session.default_timeout = 3.7 session.default_consistency_level = ConsistencyLevel.ALL session.default_serial_consistency_level = 
ConsistencyLevel.SERIAL @@ -314,7 +315,7 @@ def test_default_legacy(self): def test_default_profile(self): non_default_profile = ExecutionProfile(RoundRobinPolicy(), *[object() for _ in range(2)]) cluster = Cluster(execution_profiles={'non-default': non_default_profile}) - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) assert cluster._config_mode == _ConfigMode.PROFILES @@ -347,7 +348,7 @@ def test_serial_consistency_level_validation(self): def test_statement_params_override_legacy(self): cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), default_retry_policy=DowngradingConsistencyRetryPolicy()) assert cluster._config_mode == _ConfigMode.LEGACY - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) ss = SimpleStatement("query", retry_policy=DowngradingConsistencyRetryPolicy(), consistency_level=ConsistencyLevel.ALL, serial_consistency_level=ConsistencyLevel.SERIAL) @@ -368,7 +369,7 @@ def test_statement_params_override_legacy(self): def test_statement_params_override_profile(self): non_default_profile = ExecutionProfile(RoundRobinPolicy(), *[object() for _ in range(2)]) cluster = Cluster(execution_profiles={'non-default': non_default_profile}) - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) assert cluster._config_mode == _ConfigMode.PROFILES @@ -406,7 +407,7 @@ def test_no_profile_with_legacy(self): # session settings lock out profiles cluster = Cluster() - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) for attr, value in (('default_timeout', 1), ('default_consistency_level', ConsistencyLevel.ANY), ('default_serial_consistency_level', ConsistencyLevel.SERIAL), @@ -432,7 +433,7 @@ def test_no_legacy_with_profile(self): ('load_balancing_policy', default_lbp_factory())): with pytest.raises(ValueError): setattr(cluster, attr, value) - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) for attr, value in (('default_timeout', 1), ('default_consistency_level', ConsistencyLevel.ANY), ('default_serial_consistency_level', ConsistencyLevel.SERIAL), @@ -445,7 +446,7 @@ def test_profile_name_value(self): internalized_profile = ExecutionProfile(RoundRobinPolicy(), *[object() for _ in range(2)]) cluster = Cluster(execution_profiles={'by-name': internalized_profile}) - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) assert cluster._config_mode == _ConfigMode.PROFILES rf = session.execute_async("query", execution_profile='by-name') @@ -459,7 +460,7 @@ def test_profile_name_value(self): def test_exec_profile_clone(self): cluster = Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(), 'one': ExecutionProfile()}) - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) profile_attrs = 
{'request_timeout': 1, 'consistency_level': ConsistencyLevel.ANY, diff --git a/tests/unit/test_concurrent.py b/tests/unit/test_concurrent.py index a3587a3e16..9c85b1ccac 100644 --- a/tests/unit/test_concurrent.py +++ b/tests/unit/test_concurrent.py @@ -22,6 +22,7 @@ from queue import PriorityQueue import sys import platform +import uuid from cassandra.cluster import Cluster, Session from cassandra.concurrent import execute_concurrent, execute_concurrent_with_args @@ -248,7 +249,7 @@ def test_recursion_limited(self): PYTHON-585 """ max_recursion = sys.getrecursionlimit() - s = Session(Cluster(), [Host("127.0.0.1", SimpleConvictionPolicy)]) + s = Session(Cluster(), [Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) with pytest.raises(TypeError): execute_concurrent_with_args(s, "doesn't matter", [('param',)] * max_recursion, raise_on_first_error=True) diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index e7b930a990..580eb336b2 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -14,6 +14,7 @@ from concurrent.futures import ThreadPoolExecutor import logging import time +import uuid from cassandra.protocol_features import ProtocolFeatures from cassandra.shard_info import _ShardingInfo @@ -205,20 +206,20 @@ def test_host_instantiations(self): """ with pytest.raises(ValueError): - Host(None, None) + Host(None, None, host_id=uuid.uuid4()) with pytest.raises(ValueError): - Host('127.0.0.1', None) + Host('127.0.0.1', None, host_id=uuid.uuid4()) with pytest.raises(ValueError): - Host(None, SimpleConvictionPolicy) + Host(None, SimpleConvictionPolicy, host_id=uuid.uuid4()) def test_host_equality(self): """ Test host equality has correct logic """ - a = Host('127.0.0.1', SimpleConvictionPolicy) - b = Host('127.0.0.1', SimpleConvictionPolicy) - c = Host('127.0.0.2', SimpleConvictionPolicy) + a = Host('127.0.0.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) + b = Host('127.0.0.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) + c = Host('127.0.0.2', SimpleConvictionPolicy, host_id=uuid.uuid4()) assert a == b, 'Two Host instances should be equal when sharing.' assert a != c, 'Two Host instances should NOT be equal when using two different addresses.' 
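+        # Note (illustrative, inferred from the assertions above): Host
+        # equality compares addresses/endpoints and ignores host_id, which is
+        # why the random uuid4() host_ids used throughout these tests are safe.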
@@ -253,7 +254,7 @@ def mock_connection_factory(self, *args, **kwargs): connection.is_shutdown = False connection.is_defunct = False connection.is_closed = False - connection.features = ProtocolFeatures(shard_id=self.connection_counter, + connection.features = ProtocolFeatures(shard_id=self.connection_counter, sharding_info=_ShardingInfo(shard_id=1, shards_count=14, partitioner="", sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port="", shard_aware_port_ssl="")) diff --git a/tests/unit/test_metadata.py b/tests/unit/test_metadata.py index c471fab827..dcbb840447 100644 --- a/tests/unit/test_metadata.py +++ b/tests/unit/test_metadata.py @@ -18,6 +18,7 @@ from unittest.mock import Mock import os import timeit +import uuid import cassandra from cassandra.cqltypes import strip_frozen @@ -123,7 +124,7 @@ def test_simple_replication_type_parsing(self): # make token replica map ring = [MD5Token(0), MD5Token(1), MD5Token(2)] - hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy) for host in range(3)] + hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy, host_id=uuid.uuid4()) for host in range(3)] token_to_host = dict(zip(ring, hosts)) assert simple_int.make_token_replica_map(token_to_host, ring) == simple_str.make_token_replica_map(token_to_host, ring) @@ -141,7 +142,7 @@ def test_transient_replication_parsing(self): # make token replica map ring = [MD5Token(0), MD5Token(1), MD5Token(2)] - hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy) for host in range(3)] + hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy, host_id=uuid.uuid4()) for host in range(3)] token_to_host = dict(zip(ring, hosts)) assert simple_transient.make_token_replica_map(token_to_host, ring) == simple_str.make_token_replica_map(token_to_host, ring) @@ -162,7 +163,7 @@ def test_nts_replication_parsing(self): # make token replica map ring = [MD5Token(0), MD5Token(1), MD5Token(2)] - hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy) for host in range(3)] + hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy, host_id=uuid.uuid4()) for host in range(3)] token_to_host = dict(zip(ring, hosts)) assert nts_int.make_token_replica_map(token_to_host, ring) == nts_str.make_token_replica_map(token_to_host, ring) @@ -182,30 +183,30 @@ def test_nts_transient_parsing(self): # make token replica map ring = [MD5Token(0), MD5Token(1), MD5Token(2)] - hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy) for host in range(3)] + hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy, host_id=uuid.uuid4()) for host in range(3)] token_to_host = dict(zip(ring, hosts)) assert nts_transient.make_token_replica_map(token_to_host, ring) == nts_str.make_token_replica_map(token_to_host, ring) def test_nts_make_token_replica_map(self): token_to_host_owner = {} - dc1_1 = Host('dc1.1', SimpleConvictionPolicy) - dc1_2 = Host('dc1.2', SimpleConvictionPolicy) - dc1_3 = Host('dc1.3', SimpleConvictionPolicy) + dc1_1 = Host('dc1.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc1_2 = Host('dc1.2', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc1_3 = Host('dc1.3', SimpleConvictionPolicy, host_id=uuid.uuid4()) for host in (dc1_1, dc1_2, dc1_3): host.set_location_info('dc1', 'rack1') token_to_host_owner[MD5Token(0)] = dc1_1 token_to_host_owner[MD5Token(100)] = dc1_2 token_to_host_owner[MD5Token(200)] = dc1_3 - dc2_1 = Host('dc2.1', SimpleConvictionPolicy) - dc2_2 = Host('dc2.2', SimpleConvictionPolicy) + dc2_1 = Host('dc2.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc2_2 = 
Host('dc2.2', SimpleConvictionPolicy, host_id=uuid.uuid4()) dc2_1.set_location_info('dc2', 'rack1') dc2_2.set_location_info('dc2', 'rack1') token_to_host_owner[MD5Token(1)] = dc2_1 token_to_host_owner[MD5Token(101)] = dc2_2 - dc3_1 = Host('dc3.1', SimpleConvictionPolicy) + dc3_1 = Host('dc3.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) dc3_1.set_location_info('dc3', 'rack3') token_to_host_owner[MD5Token(2)] = dc3_1 @@ -240,7 +241,7 @@ def test_nts_token_performance(self): vnodes_per_host = 500 for i in range(dc1hostnum): - host = Host('dc1.{0}'.format(i), SimpleConvictionPolicy) + host = Host('dc1.{0}'.format(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) host.set_location_info('dc1', "rack1") for vnode_num in range(vnodes_per_host): md5_token = MD5Token(current_token+vnode_num) @@ -264,10 +265,10 @@ def test_nts_make_token_replica_map_multi_rack(self): token_to_host_owner = {} # (A) not enough distinct racks, first skipped is used - dc1_1 = Host('dc1.1', SimpleConvictionPolicy) - dc1_2 = Host('dc1.2', SimpleConvictionPolicy) - dc1_3 = Host('dc1.3', SimpleConvictionPolicy) - dc1_4 = Host('dc1.4', SimpleConvictionPolicy) + dc1_1 = Host('dc1.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc1_2 = Host('dc1.2', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc1_3 = Host('dc1.3', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc1_4 = Host('dc1.4', SimpleConvictionPolicy, host_id=uuid.uuid4()) dc1_1.set_location_info('dc1', 'rack1') dc1_2.set_location_info('dc1', 'rack1') dc1_3.set_location_info('dc1', 'rack2') @@ -278,9 +279,9 @@ def test_nts_make_token_replica_map_multi_rack(self): token_to_host_owner[MD5Token(300)] = dc1_4 # (B) distinct racks, but not contiguous - dc2_1 = Host('dc2.1', SimpleConvictionPolicy) - dc2_2 = Host('dc2.2', SimpleConvictionPolicy) - dc2_3 = Host('dc2.3', SimpleConvictionPolicy) + dc2_1 = Host('dc2.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc2_2 = Host('dc2.2', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc2_3 = Host('dc2.3', SimpleConvictionPolicy, host_id=uuid.uuid4()) dc2_1.set_location_info('dc2', 'rack1') dc2_2.set_location_info('dc2', 'rack1') dc2_3.set_location_info('dc2', 'rack2') @@ -303,7 +304,7 @@ def test_nts_make_token_replica_map_multi_rack(self): assertCountEqual(token_replicas, (dc1_1, dc1_2, dc1_3, dc2_1, dc2_3)) def test_nts_make_token_replica_map_empty_dc(self): - host = Host('1', SimpleConvictionPolicy) + host = Host('1', SimpleConvictionPolicy, host_id=uuid.uuid4()) host.set_location_info('dc1', 'rack1') token_to_host_owner = {MD5Token(0): host} ring = [MD5Token(0)] @@ -317,9 +318,9 @@ def test_nts_export_for_schema(self): assert "{'class': 'NetworkTopologyStrategy', 'dc1': '1', 'dc2': '2'}" == strategy.export_for_schema() def test_simple_strategy_make_token_replica_map(self): - host1 = Host('1', SimpleConvictionPolicy) - host2 = Host('2', SimpleConvictionPolicy) - host3 = Host('3', SimpleConvictionPolicy) + host1 = Host('1', SimpleConvictionPolicy, host_id=uuid.uuid4()) + host2 = Host('2', SimpleConvictionPolicy, host_id=uuid.uuid4()) + host3 = Host('3', SimpleConvictionPolicy, host_id=uuid.uuid4()) token_to_host_owner = { MD5Token(0): host1, MD5Token(100): host2, @@ -408,7 +409,7 @@ def test_is_valid_name(self): class GetReplicasTest(unittest.TestCase): def _get_replicas(self, token_klass): tokens = [token_klass(i) for i in range(0, (2 ** 127 - 1), 2 ** 125)] - hosts = [Host("ip%d" % i, SimpleConvictionPolicy) for i in range(len(tokens))] + hosts = [Host("ip%d" % i, SimpleConvictionPolicy, host_id=uuid.uuid4()) for i 
in range(len(tokens))] token_to_primary_replica = dict(zip(tokens, hosts)) keyspace = KeyspaceMetadata("ks", True, "SimpleStrategy", {"replication_factor": "1"}) metadata = Mock(spec=Metadata, keyspaces={'ks': keyspace}) @@ -817,8 +818,8 @@ def test_iterate_all_hosts_and_modify(self): PYTHON-572 """ metadata = Metadata() - metadata.add_or_return_host(Host('dc1.1', SimpleConvictionPolicy)) - metadata.add_or_return_host(Host('dc1.2', SimpleConvictionPolicy)) + metadata.add_or_return_host(Host('dc1.1', SimpleConvictionPolicy, host_id=uuid.uuid4())) + metadata.add_or_return_host(Host('dc1.2', SimpleConvictionPolicy, host_id=uuid.uuid4())) assert len(metadata.all_hosts()) == 2 diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index e15705c8f7..65feaf72e5 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -17,6 +17,7 @@ from itertools import islice, cycle from unittest.mock import Mock, patch, call from random import randint +import uuid import pytest from _thread import LockType import sys @@ -46,7 +47,7 @@ def test_non_implemented(self): """ policy = LoadBalancingPolicy() - host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy) + host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) host.set_location_info("dc1", "rack1") with pytest.raises(NotImplementedError): @@ -192,11 +193,11 @@ class TestRackOrDCAwareRoundRobinPolicy: def test_no_remote(self, policy_specialization, constructor_args): hosts = [] for i in range(2): - h = Host(DefaultEndPoint(i), SimpleConvictionPolicy) + h = Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) h.set_location_info("dc1", "rack2") hosts.append(h) for i in range(2): - h = Host(DefaultEndPoint(i + 2), SimpleConvictionPolicy) + h = Host(DefaultEndPoint(i + 2), SimpleConvictionPolicy, host_id=uuid.uuid4()) h.set_location_info("dc1", "rack1") hosts.append(h) @@ -208,7 +209,7 @@ def test_no_remote(self, policy_specialization, constructor_args): assert sorted(qplan) == sorted(hosts) def test_with_remotes(self, policy_specialization, constructor_args): - hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(6)] + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(6)] for h in hosts[:2]: h.set_location_info("dc1", "rack1") for h in hosts[2:4]: @@ -263,7 +264,7 @@ def test_get_distance(self, policy_specialization, constructor_args): policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=0) # same dc, same rack - host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy) + host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) host.set_location_info("dc1", "rack1") policy.populate(Mock(), [host]) @@ -273,14 +274,14 @@ def test_get_distance(self, policy_specialization, constructor_args): assert policy.distance(host) == HostDistance.LOCAL_RACK # same dc different rack - host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy) + host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) host.set_location_info("dc1", "rack2") policy.populate(Mock(), [host]) assert policy.distance(host) == HostDistance.LOCAL # used_hosts_per_remote_dc is set to 0, so ignore it - remote_host = Host(DefaultEndPoint("ip2"), SimpleConvictionPolicy) + remote_host = Host(DefaultEndPoint("ip2"), SimpleConvictionPolicy, host_id=uuid.uuid4()) remote_host.set_location_info("dc2", "rack1") assert policy.distance(remote_host) == HostDistance.IGNORED @@ -294,14 +295,14 @@ 
def test_get_distance(self, policy_specialization, constructor_args): # since used_hosts_per_remote_dc is set to 1, only the first # remote host in dc2 will be REMOTE, the rest are IGNORED - second_remote_host = Host(DefaultEndPoint("ip3"), SimpleConvictionPolicy) + second_remote_host = Host(DefaultEndPoint("ip3"), SimpleConvictionPolicy, host_id=uuid.uuid4()) second_remote_host.set_location_info("dc2", "rack1") policy.populate(Mock(), [host, remote_host, second_remote_host]) distances = set([policy.distance(remote_host), policy.distance(second_remote_host)]) assert distances == set([HostDistance.REMOTE, HostDistance.IGNORED]) def test_status_updates(self, policy_specialization, constructor_args): - hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(5)] + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(5)] for h in hosts[:2]: h.set_location_info("dc1", "rack1") for h in hosts[2:4]: @@ -314,11 +315,11 @@ def test_status_updates(self, policy_specialization, constructor_args): policy.on_down(hosts[0]) policy.on_remove(hosts[2]) - new_local_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy) + new_local_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy, host_id=uuid.uuid4()) new_local_host.set_location_info("dc1", "rack1") policy.on_up(new_local_host) - new_remote_host = Host(DefaultEndPoint(6), SimpleConvictionPolicy) + new_remote_host = Host(DefaultEndPoint(6), SimpleConvictionPolicy, host_id=uuid.uuid4()) new_remote_host.set_location_info("dc9000", "rack1") policy.on_add(new_remote_host) @@ -343,7 +344,7 @@ def test_status_updates(self, policy_specialization, constructor_args): assert qplan == [] def test_modification_during_generation(self, policy_specialization, constructor_args): - hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for h in hosts[:2]: h.set_location_info("dc1", "rack1") for h in hosts[2:]: @@ -357,7 +358,7 @@ def test_modification_during_generation(self, policy_specialization, constructor # approach that changes specific things during known phases of the # generator. 
- new_host = Host(DefaultEndPoint(4), SimpleConvictionPolicy) + new_host = Host(DefaultEndPoint(4), SimpleConvictionPolicy, host_id=uuid.uuid4()) new_host.set_location_info("dc1", "rack1") # new local before iteration @@ -468,7 +469,7 @@ def test_modification_during_generation(self, policy_specialization, constructor policy.on_up(hosts[2]) policy.on_up(hosts[3]) - another_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy) + another_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy, host_id=uuid.uuid4()) another_host.set_location_info("dc3", "rack1") new_host.set_location_info("dc3", "rack1") @@ -502,7 +503,7 @@ def test_no_live_nodes(self, policy_specialization, constructor_args): hosts = [] for i in range(4): - h = Host(DefaultEndPoint(i), SimpleConvictionPolicy) + h = Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) h.set_location_info("dc1", "rack1") hosts.append(h) @@ -527,7 +528,7 @@ def test_no_nodes(self, policy_specialization, constructor_args): assert qplan == [] def test_wrong_dc(self, policy_specialization, constructor_args): - hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(3)] + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(3)] for h in hosts[:3]: h.set_location_info("dc2", "rack2") @@ -539,9 +540,9 @@ def test_wrong_dc(self, policy_specialization, constructor_args): class DCAwareRoundRobinPolicyTest(unittest.TestCase): def test_default_dc(self): - host_local = Host(DefaultEndPoint(1), SimpleConvictionPolicy, 'local') - host_remote = Host(DefaultEndPoint(2), SimpleConvictionPolicy, 'remote') - host_none = Host(DefaultEndPoint(1), SimpleConvictionPolicy) + host_local = Host(DefaultEndPoint(1), SimpleConvictionPolicy, 'local', host_id=uuid.uuid4()) + host_remote = Host(DefaultEndPoint(2), SimpleConvictionPolicy, 'remote', host_id=uuid.uuid4()) + host_none = Host(DefaultEndPoint(1), SimpleConvictionPolicy, host_id=uuid.uuid4()) # contact point is '1' cluster = Mock(endpoints_resolved=[DefaultEndPoint(1)]) @@ -585,7 +586,7 @@ def test_wrap_round_robin(self): cluster.metadata = Mock(spec=Metadata) cluster.metadata._tablets = Mock(spec=Tablets) cluster.metadata._tablets.table_has_tablets.return_value = [] - hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for host in hosts: host.set_up() @@ -618,7 +619,7 @@ def test_wrap_dc_aware(self): cluster.metadata = Mock(spec=Metadata) cluster.metadata._tablets = Mock(spec=Tablets) cluster.metadata._tablets.table_has_tablets.return_value = [] - hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for host in hosts: host.set_up() for h in hosts[:2]: @@ -667,7 +668,7 @@ def test_wrap_rack_aware(self): cluster.metadata = Mock(spec=Metadata) cluster.metadata._tablets = Mock(spec=Tablets) cluster.metadata._tablets.table_has_tablets.return_value = [] - hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(8)] + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(8)] for host in hosts: host.set_up() hosts[0].set_location_info("dc1", "rack1") @@ -731,7 +732,7 @@ def test_get_distance(self): """ policy = TokenAwarePolicy(DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=0)) - host = 
Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy) + host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) host.set_location_info("dc1", "rack1") policy.populate(self.FakeCluster(), [host]) @@ -739,7 +740,7 @@ def test_get_distance(self): assert policy.distance(host) == HostDistance.LOCAL # used_hosts_per_remote_dc is set to 0, so ignore it - remote_host = Host(DefaultEndPoint("ip2"), SimpleConvictionPolicy) + remote_host = Host(DefaultEndPoint("ip2"), SimpleConvictionPolicy, host_id=uuid.uuid4()) remote_host.set_location_info("dc2", "rack1") assert policy.distance(remote_host) == HostDistance.IGNORED @@ -753,7 +754,7 @@ def test_get_distance(self): # since used_hosts_per_remote_dc is set to 1, only the first # remote host in dc2 will be REMOTE, the rest are IGNORED - second_remote_host = Host(DefaultEndPoint("ip3"), SimpleConvictionPolicy) + second_remote_host = Host(DefaultEndPoint("ip3"), SimpleConvictionPolicy, host_id=uuid.uuid4()) second_remote_host.set_location_info("dc2", "rack1") policy.populate(self.FakeCluster(), [host, remote_host, second_remote_host]) distances = set([policy.distance(remote_host), policy.distance(second_remote_host)]) @@ -764,7 +765,7 @@ def test_status_updates(self): Same test as DCAwareRoundRobinPolicyTest.test_status_updates() """ - hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for h in hosts[:2]: h.set_location_info("dc1", "rack1") for h in hosts[2:]: @@ -775,11 +776,11 @@ def test_status_updates(self): policy.on_down(hosts[0]) policy.on_remove(hosts[2]) - new_local_host = Host(DefaultEndPoint(4), SimpleConvictionPolicy) + new_local_host = Host(DefaultEndPoint(4), SimpleConvictionPolicy, host_id=uuid.uuid4()) new_local_host.set_location_info("dc1", "rack1") policy.on_up(new_local_host) - new_remote_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy) + new_remote_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy, host_id=uuid.uuid4()) new_remote_host.set_location_info("dc9000", "rack1") policy.on_add(new_remote_host) @@ -802,7 +803,7 @@ def test_status_updates(self): assert qplan == [] def test_statement_keyspace(self): - hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for host in hosts: host.set_up() @@ -896,7 +897,7 @@ def test_no_shuffle_if_given_no_routing_key(self): self._assert_shuffle(cluster=self._prepare_cluster_with_tablets(), keyspace='keyspace', routing_key=None) def _prepare_cluster_with_vnodes(self): - hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for host in hosts: host.set_up() cluster = Mock(spec=Cluster) @@ -908,7 +909,7 @@ def _prepare_cluster_with_vnodes(self): return cluster def _prepare_cluster_with_tablets(self): - hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for host in hosts: host.set_up() cluster = Mock(spec=Cluster) @@ -1422,7 +1423,7 @@ class WhiteListRoundRobinPolicyTest(unittest.TestCase): def test_hosts_with_hostname(self): hosts = ['localhost'] policy = WhiteListRoundRobinPolicy(hosts) - host = Host(DefaultEndPoint("127.0.0.1"), 
SimpleConvictionPolicy) + host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) policy.populate(None, [host]) qplan = list(policy.make_query_plan()) @@ -1433,7 +1434,7 @@ def test_hosts_with_hostname(self): def test_hosts_with_socket_hostname(self): hosts = [UnixSocketEndPoint('/tmp/scylla-workdir/cql.m')] policy = WhiteListRoundRobinPolicy(hosts) - host = Host(UnixSocketEndPoint('/tmp/scylla-workdir/cql.m'), SimpleConvictionPolicy) + host = Host(UnixSocketEndPoint('/tmp/scylla-workdir/cql.m'), SimpleConvictionPolicy, host_id=uuid.uuid4()) policy.populate(None, [host]) qplan = list(policy.make_query_plan()) @@ -1559,8 +1560,8 @@ def setUp(self): child_policy=Mock(name='child_policy', distance=Mock(name='distance')), predicate=lambda host: host.address == 'acceptme' ) - self.ignored_host = Host(DefaultEndPoint('ignoreme'), conviction_policy_factory=Mock()) - self.accepted_host = Host(DefaultEndPoint('acceptme'), conviction_policy_factory=Mock()) + self.ignored_host = Host(DefaultEndPoint('ignoreme'), conviction_policy_factory=Mock(), host_id=uuid.uuid4()) + self.accepted_host = Host(DefaultEndPoint('acceptme'), conviction_policy_factory=Mock(), host_id=uuid.uuid4()) def test_ignored_with_filter(self): assert self.hfp.distance(self.ignored_host) == HostDistance.IGNORED @@ -1629,7 +1630,7 @@ def test_query_plan_deferred_to_child(self): def test_wrap_token_aware(self): cluster = Mock(spec=Cluster) - hosts = [Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy) for i in range(1, 6)] + hosts = [Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(1, 6)] for host in hosts: host.set_up() @@ -1656,13 +1657,13 @@ def get_replicas(keyspace, packed_key): query_plan = hfp.make_query_plan("keyspace", mocked_query) # First the not filtered replica, and then the rest of the allowed hosts ordered query_plan = list(query_plan) - assert query_plan[0] == Host(DefaultEndPoint("127.0.0.2"), SimpleConvictionPolicy) - assert set(query_plan[1:]) == {Host(DefaultEndPoint("127.0.0.3"), SimpleConvictionPolicy), - Host(DefaultEndPoint("127.0.0.5"), SimpleConvictionPolicy)} + assert query_plan[0] == Host(DefaultEndPoint("127.0.0.2"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + assert set(query_plan[1:]) == {Host(DefaultEndPoint("127.0.0.3"), SimpleConvictionPolicy, host_id=uuid.uuid4()), + Host(DefaultEndPoint("127.0.0.5"), SimpleConvictionPolicy, host_id=uuid.uuid4())} def test_create_whitelist(self): cluster = Mock(spec=Cluster) - hosts = [Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy) for i in range(1, 6)] + hosts = [Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(1, 6)] for host in hosts: host.set_up() @@ -1680,5 +1681,5 @@ def test_create_whitelist(self): mocked_query = Mock() query_plan = hfp.make_query_plan("keyspace", mocked_query) # Only the filtered replicas should be allowed - assert set(query_plan) == {Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy), - Host(DefaultEndPoint("127.0.0.4"), SimpleConvictionPolicy)} + assert set(query_plan) == {Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()), + Host(DefaultEndPoint("127.0.0.4"), SimpleConvictionPolicy, host_id=uuid.uuid4())} diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index 3390f6dbd6..a5bd028b26 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -1002,11 +1002,11 @@ def 
test_host_order(self): @test_category data_types """ - hosts = [Host(addr, SimpleConvictionPolicy) for addr in + hosts = [Host(addr, SimpleConvictionPolicy, host_id=uuid.uuid4()) for addr in ("127.0.0.1", "127.0.0.2", "127.0.0.3", "127.0.0.4")] - hosts_equal = [Host(addr, SimpleConvictionPolicy) for addr in + hosts_equal = [Host(addr, SimpleConvictionPolicy, host_id=uuid.uuid4()) for addr in ("127.0.0.1", "127.0.0.1")] - hosts_equal_conviction = [Host("127.0.0.1", SimpleConvictionPolicy), Host("127.0.0.1", ConvictionPolicy)] + hosts_equal_conviction = [Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()), Host("127.0.0.1", ConvictionPolicy, host_id=uuid.uuid4())] check_sequence_consistency(hosts) check_sequence_consistency(hosts_equal, equal=True) check_sequence_consistency(hosts_equal_conviction, equal=True) From d6459b91c80d610fe2e8896c70787fb9fc96a260 Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Tue, 30 Dec 2025 09:23:32 +0100 Subject: [PATCH 19/29] tests/integration/standard: return empty query plan if there are no live hosts --- tests/integration/standard/test_query.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index a3bdf8a735..9cebc22b05 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -460,7 +460,8 @@ def make_query_plan(self, working_keyspace=None, query=None): live_hosts = sorted(list(self._live_hosts)) host = [] try: - host = [live_hosts[self.host_index_to_use]] + if len(live_hosts) > 0: + host = [live_hosts[self.host_index_to_use]] except IndexError as e: raise IndexError( 'You specified an index larger than the number of hosts. Total hosts: {}. Index specified: {}'.format( From 7e4bd1f851950247c2131e81ff585e04ea3275ae Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Mon, 12 Jan 2026 13:29:48 +0100 Subject: [PATCH 20/29] tests/integration/standard: allow execute to throw Unavailable exception --- tests/integration/standard/test_metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/standard/test_metrics.py b/tests/integration/standard/test_metrics.py index 8ccd278ee4..48c7b49b95 100644 --- a/tests/integration/standard/test_metrics.py +++ b/tests/integration/standard/test_metrics.py @@ -218,7 +218,7 @@ def test_metrics_per_cluster(self): try: # Test write query = SimpleStatement("INSERT INTO {0}.{0} (k, v) VALUES (2, 2)".format(self.ks_name), consistency_level=ConsistencyLevel.ALL) - with pytest.raises(WriteTimeout): + with pytest.raises((WriteTimeout, Unavailable)): self.session.execute(query, timeout=None) finally: get_node(1).resume() @@ -230,7 +230,7 @@ def test_metrics_per_cluster(self): stats_cluster2 = cluster2.metrics.get_stats() # Test direct access to stats - assert 1 == self.cluster.metrics.stats.write_timeouts + assert (1 == self.cluster.metrics.stats.write_timeouts or 1 == self.cluster.metrics.stats.unavailables) assert 0 == cluster2.metrics.stats.write_timeouts # Test direct access to a child stats From 2034f95470ae6861d0e931d97e561c75f28f01cd Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Thu, 8 Jan 2026 13:20:26 +0100 Subject: [PATCH 21/29] Don't check if host is in initial contact points when setting default local_dc --- cassandra/policies.py | 18 +++++------------- tests/unit/test_policies.py | 11 +---------- 2 files changed, 6 insertions(+), 23 deletions(-) diff --git a/cassandra/policies.py b/cassandra/policies.py index 
bcfd797706..7eea1e709a 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -245,7 +245,6 @@ def __init__(self, local_dc='', used_hosts_per_remote_dc=0): self.used_hosts_per_remote_dc = used_hosts_per_remote_dc self._dc_live_hosts = {} self._position = 0 - self._endpoints = [] LoadBalancingPolicy.__init__(self) def _dc(self, host): @@ -255,11 +254,6 @@ def populate(self, cluster, hosts): for dc, dc_hosts in groupby(hosts, lambda h: self._dc(h)): self._dc_live_hosts[dc] = tuple({*dc_hosts, *self._dc_live_hosts.get(dc, [])}) - if not self.local_dc: - self._endpoints = [ - endpoint - for endpoint in cluster.endpoints_resolved] - self._position = randint(0, len(hosts) - 1) if hosts else 0 def distance(self, host): @@ -301,13 +295,11 @@ def on_up(self, host): # not worrying about threads because this will happen during # control connection startup/refresh if not self.local_dc and host.datacenter: - if host.endpoint in self._endpoints: - self.local_dc = host.datacenter - log.info("Using datacenter '%s' for DCAwareRoundRobinPolicy (via host '%s'); " - "if incorrect, please specify a local_dc to the constructor, " - "or limit contact points to local cluster nodes" % - (self.local_dc, host.endpoint)) - del self._endpoints + self.local_dc = host.datacenter + log.info("Using datacenter '%s' for DCAwareRoundRobinPolicy (via host '%s'); " + "if incorrect, please specify a local_dc to the constructor, " + "or limit contact points to local cluster nodes" % + (self.local_dc, host.endpoint)) dc = self._dc(host) with self._hosts_lock: diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index 65feaf72e5..ecaf6ca7e4 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -556,15 +556,6 @@ def test_default_dc(self): assert policy.local_dc != host_remote.datacenter assert policy.local_dc == host_local.datacenter - # contact DC second - policy = DCAwareRoundRobinPolicy() - policy.populate(cluster, [host_none]) - assert not policy.local_dc - policy.on_add(host_remote) - policy.on_add(host_local) - assert policy.local_dc != host_remote.datacenter - assert policy.local_dc == host_local.datacenter - # no DC policy = DCAwareRoundRobinPolicy() policy.populate(cluster, [host_none]) @@ -577,7 +568,7 @@ def test_default_dc(self): policy.populate(cluster, [host_none]) assert not policy.local_dc policy.on_add(host_remote) - assert not policy.local_dc + assert policy.local_dc class TokenAwarePolicyTest(unittest.TestCase): From 5f7f413bed414ee3fb8a7439cb96921d51baa49e Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Thu, 8 Jan 2026 15:30:03 +0100 Subject: [PATCH 22/29] Call on_add before distance to properly initialize lbp In DC aware lbp when local_dc is not provided we set it in on_add and it needs to be initialized for distance to give proper results. 
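To make the ordering issue concrete, here is a minimal sketch against this branch's API (it assumes the post-series Host constructor that requires host_id, and DCAwareRoundRobinPolicy's default used_hosts_per_remote_dc=0):

    import uuid

    from cassandra.connection import DefaultEndPoint
    from cassandra.policies import (DCAwareRoundRobinPolicy, HostDistance,
                                    SimpleConvictionPolicy)
    from cassandra.pool import Host

    policy = DCAwareRoundRobinPolicy()  # local_dc left empty, to be inferred

    host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy,
                host_id=uuid.uuid4())
    host.set_location_info("dc1", "rack1")

    # Before on_add, local_dc is still unset, so the host compares as
    # remote and is IGNORED (used_hosts_per_remote_dc is 0).
    assert policy.distance(host) == HostDistance.IGNORED

    policy.on_add(host)  # infers local_dc from the first host seen

    # After on_add, the same host is classified correctly.
    assert policy.distance(host) == HostDistance.LOCAL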
--- cassandra/cluster.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 8db58c13cd..8af4e19801 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2016,14 +2016,14 @@ def on_add(self, host, refresh_nodes=True): log.debug("Handling new host %r and notifying listeners", host) + self.profile_manager.on_add(host) + self.control_connection.on_add(host, refresh_nodes) + distance = self.profile_manager.distance(host) if distance != HostDistance.IGNORED: self._prepare_all_queries(host) log.debug("Done preparing queries for new host %r", host) - self.profile_manager.on_add(host) - self.control_connection.on_add(host, refresh_nodes) - if distance == HostDistance.IGNORED: log.debug("Not adding connection pool for new host %r because the " "load balancing policy has marked it as IGNORED", host) From 921f324311ee7410ce6984bdb41a6282797fb8a3 Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Mon, 29 Dec 2025 14:57:42 +0100 Subject: [PATCH 23/29] Don't create Host instances with random host_id Previously, we used endpoints provided to the cluster to create Host instances with random host_ids in order to populate the LBP before the ControlConnection was established. This logic led to creating many connections that were opened and then quickly closed, because once we learned the correct host_ids from system.peers, we removed the old Hosts with random IDs and created new ones with the proper host_ids. This commit introduces a new approach. To establish the ControlConnection, we now use only the resolved contact points from the cluster configuration. Only after a successful connection do we populate Host information in the LBP. If the LBP is already initialized during ControlConnection reconnection, we reuse the existing values. 
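The resulting attempt order can be summarized with a short sketch (candidate_endpoints is a hypothetical helper; the actual logic lives in ControlConnection._try_connect_to_hosts in the diff below):

    from itertools import chain

    def candidate_endpoints(lbp_hosts, resolved_contact_points):
        # Try endpoints the load balancing policy already knows about
        # first (the reconnection case), then fall back to the resolved
        # contact points from the cluster configuration (the initial
        # connection case).
        return chain((host.endpoint for host in lbp_hosts),
                     resolved_contact_points)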
--- cassandra/cluster.py | 108 +++++++----------------------------------- cassandra/metadata.py | 2 +- cassandra/pool.py | 2 +- 3 files changed, 20 insertions(+), 92 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 8af4e19801..a9c1d00e97 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1683,14 +1683,7 @@ def protocol_downgrade(self, host_endpoint, previous_version): "http://datastax.github.io/python-driver/api/cassandra/cluster.html#cassandra.cluster.Cluster.protocol_version", self.protocol_version, new_version, host_endpoint) self.protocol_version = new_version - def _add_resolved_hosts(self): - for endpoint in self.endpoints_resolved: - host, new = self.add_host(endpoint, signal=False) - if new: - host.set_up() - for listener in self.listeners: - listener.on_add(host) - + def _populate_hosts(self): self.profile_manager.populate( weakref.proxy(self), self.metadata.all_hosts()) self.load_balancing_policy.populate( @@ -1718,16 +1711,9 @@ def connect(self, keyspace=None, wait_for_all_pools=False): self.connection_class.initialize_reactor() _register_cluster_shutdown(self) - self._add_resolved_hosts() - try: self.control_connection.connect() - - # we set all contact points up for connecting, but we won't infer state after this - for endpoint in self.endpoints_resolved: - h = self.metadata.get_host(endpoint) - if h and self.profile_manager.distance(h) == HostDistance.IGNORED: - h.is_up = None + self._populate_hosts() log.debug("Control connection created") except Exception: @@ -3535,20 +3521,18 @@ def _set_new_connection(self, conn): log.debug("[control connection] Closing old connection %r, replacing with %r", old, conn) old.close() - def _connect_host_in_lbp(self): + def _try_connect_to_hosts(self): errors = {} - lbp = ( - self._cluster.load_balancing_policy - if self._cluster._config_mode == _ConfigMode.LEGACY else - self._cluster._default_load_balancing_policy - ) - for host in lbp.make_query_plan(): + lbp = self._cluster.load_balancing_policy \ + if self._cluster._config_mode == _ConfigMode.LEGACY else self._cluster._default_load_balancing_policy + + for endpoint in chain((host.endpoint for host in lbp.make_query_plan()), self._cluster.endpoints_resolved): try: - return (self._try_connect(host.endpoint), None) + return (self._try_connect(endpoint), None) except Exception as exc: - errors[str(host.endpoint)] = exc - log.warning("[control connection] Error connecting to %s:", host, exc_info=True) + errors[str(endpoint)] = exc + log.warning("[control connection] Error connecting to %s:", endpoint, exc_info=True) if self._is_shutdown: raise DriverException("[control connection] Reconnection in progress during shutdown") @@ -3563,16 +3547,16 @@ def _reconnect_internal(self): to the exception that was raised when an attempt was made to open a connection to that host. 
""" - (conn, _) = self._connect_host_in_lbp() + (conn, _) = self._try_connect_to_hosts() if conn is not None: return conn # Try to re-resolve hostnames as a fallback when all hosts are unreachable self._cluster._resolve_hostnames() - self._cluster._add_resolved_hosts() + self._cluster._populate_hosts() - (conn, errors) = self._connect_host_in_lbp() + (conn, errors) = self._try_connect_to_hosts() if conn is not None: return conn @@ -3814,67 +3798,10 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, self._cluster.metadata.cluster_name = cluster_name partitioner = local_row.get("partitioner") - tokens = local_row.get("tokens") - - host = self._cluster.metadata.get_host(connection.original_endpoint) - if host: - datacenter = local_row.get("data_center") - rack = local_row.get("rack") - self._update_location_info(host, datacenter, rack) - - # support the use case of connecting only with public address - if isinstance(self._cluster.endpoint_factory, SniEndPointFactory): - new_endpoint = self._cluster.endpoint_factory.create(local_row) - - if new_endpoint.address: - host.endpoint = new_endpoint - - host.host_id = local_row.get("host_id") - - found_host_ids.add(host.host_id) - found_endpoints.add(host.endpoint) - - host.listen_address = local_row.get("listen_address") - host.listen_port = local_row.get("listen_port") - host.broadcast_address = _NodeInfo.get_broadcast_address(local_row) - host.broadcast_port = _NodeInfo.get_broadcast_port(local_row) - - host.broadcast_rpc_address = _NodeInfo.get_broadcast_rpc_address(local_row) - host.broadcast_rpc_port = _NodeInfo.get_broadcast_rpc_port(local_row) - if host.broadcast_rpc_address is None: - if self._token_meta_enabled: - # local rpc_address is not available, use the connection endpoint - host.broadcast_rpc_address = connection.endpoint.address - host.broadcast_rpc_port = connection.endpoint.port - else: - # local rpc_address has not been queried yet, try to fetch it - # separately, which might fail because C* < 2.1.6 doesn't have rpc_address - # in system.local. See CASSANDRA-9436. - local_rpc_address_query = QueryMessage( - query=maybe_add_timeout_to_query(self._SELECT_LOCAL_NO_TOKENS_RPC_ADDRESS, self._metadata_request_timeout), - consistency_level=ConsistencyLevel.ONE) - success, local_rpc_address_result = connection.wait_for_response( - local_rpc_address_query, timeout=self._timeout, fail_on_error=False) - if success: - row = dict_factory( - local_rpc_address_result.column_names, - local_rpc_address_result.parsed_rows) - host.broadcast_rpc_address = _NodeInfo.get_broadcast_rpc_address(row[0]) - host.broadcast_rpc_port = _NodeInfo.get_broadcast_rpc_port(row[0]) - else: - host.broadcast_rpc_address = connection.endpoint.address - host.broadcast_rpc_port = connection.endpoint.port - - host.release_version = local_row.get("release_version") - host.dse_version = local_row.get("dse_version") - host.dse_workload = local_row.get("workload") - host.dse_workloads = local_row.get("workloads") + tokens = local_row.get("tokens", None) - if partitioner and tokens: - token_map[host] = tokens + peers_result.insert(0, local_row) - self._cluster.metadata.update_host(host, old_endpoint=connection.endpoint) - connection.original_endpoint = connection.endpoint = host.endpoint # Check metadata.partitioner to see if we haven't built anything yet. If # every node in the cluster was in the contact points, we won't discover # any new nodes, so we need this additional check. 
(See PYTHON-90) @@ -4173,8 +4100,9 @@ def _get_peers_query(self, peers_query_type, connection=None): query_template = (self._SELECT_SCHEMA_PEERS_TEMPLATE if peers_query_type == self.PeersQueryType.PEERS_SCHEMA else self._SELECT_PEERS_NO_TOKENS_TEMPLATE) - host_release_version = self._cluster.metadata.get_host(connection.original_endpoint).release_version - host_dse_version = self._cluster.metadata.get_host(connection.original_endpoint).dse_version + original_endpoint_host = self._cluster.metadata.get_host(connection.original_endpoint) + host_release_version = None if original_endpoint_host is None else original_endpoint_host.release_version + host_dse_version = None if original_endpoint_host is None else original_endpoint_host.dse_version uses_native_address_query = ( host_dse_version and Version(host_dse_version) >= self._MINIMUM_NATIVE_ADDRESS_DSE_VERSION) diff --git a/cassandra/metadata.py b/cassandra/metadata.py index 85f6c45ac6..b85308449e 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -3481,7 +3481,7 @@ def group_keys_by_replica(session, keyspace, table, keys): :class:`~.NO_VALID_REPLICA` Example usage:: - + >>> result = group_keys_by_replica( ... session, "system", "peers", ... (("127.0.0.1", ), ("127.0.0.2", ))) diff --git a/cassandra/pool.py b/cassandra/pool.py index b8a8ef7493..2da657256f 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -176,7 +176,7 @@ def __init__(self, endpoint, conviction_policy_factory, datacenter=None, rack=No self.endpoint = endpoint if isinstance(endpoint, EndPoint) else DefaultEndPoint(endpoint) self.conviction_policy = conviction_policy_factory(self) if not host_id: - host_id = uuid.uuid4() + raise ValueError("host_id may not be None") self.host_id = host_id self.set_location_info(datacenter, rack) self.lock = RLock() From f2d9022cb1544cc44dd393db47333a2c07e95595 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Wed, 21 Jan 2026 09:19:32 +0200 Subject: [PATCH 24/29] (improvement) TokenAwarePolicy::make_query_plan(): remove redundant check if a table is using tablets table_has_tablets() performs the same self._tablets.get((keyspace, table)) lookup that get_tablet_for_key(), which is called one line later, performs, so it's redundant. Removed the former. Note - perhaps table_has_tablets() itself is not needed and can be removed? Unsure; it's unclear whether it's part of the public API. It's now completely unused across the code. Adjusted unit tests as well.
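A condensed sketch of the resulting control flow (tablet_aware_plan is a hypothetical standalone function; the real code operates on self._cluster_metadata inside make_query_plan, as shown in the diff below):

    def tablet_aware_plan(tablets, metadata, child_plan, keyspace, table,
                          token, routing_key):
        # One get_tablet_for_key() lookup now answers both questions: it
        # returns None when the table has no tablets, which is exactly
        # what table_has_tablets() used to check via the same dict access.
        tablet = tablets.get_tablet_for_key(keyspace, table, token)
        if tablet is not None:
            replica_ids = {replica_id for replica_id, _ in tablet.replicas}
            return [host for host in child_plan
                    if host.host_id in replica_ids]
        # No tablets: fall back to token-ring replicas.
        return metadata.get_replicas(keyspace, routing_key)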
Signed-off-by: Yaniv Kaul --- cassandra/policies.py | 11 +++++------ tests/unit/test_policies.py | 15 +++++++-------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/cassandra/policies.py b/cassandra/policies.py index 7eea1e709a..e742708019 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -503,15 +503,14 @@ def make_query_plan(self, working_keyspace=None, query=None): return replicas = [] - if self._cluster_metadata._tablets.table_has_tablets(keyspace, query.table): - tablet = self._cluster_metadata._tablets.get_tablet_for_key( + tablet = self._cluster_metadata._tablets.get_tablet_for_key( keyspace, query.table, self._cluster_metadata.token_map.token_class.from_key(query.routing_key)) - if tablet is not None: - replicas_mapped = set(map(lambda r: r[0], tablet.replicas)) - child_plan = child.make_query_plan(keyspace, query) + if tablet is not None: + replicas_mapped = set(map(lambda r: r[0], tablet.replicas)) + child_plan = child.make_query_plan(keyspace, query) - replicas = [host for host in child_plan if host.host_id in replicas_mapped] + replicas = [host for host in child_plan if host.host_id in replicas_mapped] else: replicas = self._cluster_metadata.get_replicas(keyspace, query.routing_key) diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index ecaf6ca7e4..6142af1aa1 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -576,7 +576,7 @@ def test_wrap_round_robin(self): cluster = Mock(spec=Cluster) cluster.metadata = Mock(spec=Metadata) cluster.metadata._tablets = Mock(spec=Tablets) - cluster.metadata._tablets.table_has_tablets.return_value = [] + cluster.metadata._tablets.get_tablet_for_key.return_value = None hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for host in hosts: host.set_up() @@ -609,7 +609,7 @@ def test_wrap_dc_aware(self): cluster = Mock(spec=Cluster) cluster.metadata = Mock(spec=Metadata) cluster.metadata._tablets = Mock(spec=Tablets) - cluster.metadata._tablets.table_has_tablets.return_value = [] + cluster.metadata._tablets.get_tablet_for_key.return_value = None hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for host in hosts: host.set_up() @@ -658,7 +658,7 @@ def test_wrap_rack_aware(self): cluster = Mock(spec=Cluster) cluster.metadata = Mock(spec=Metadata) cluster.metadata._tablets = Mock(spec=Tablets) - cluster.metadata._tablets.table_has_tablets.return_value = [] + cluster.metadata._tablets.get_tablet_for_key.return_value = None hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(8)] for host in hosts: host.set_up() @@ -803,7 +803,7 @@ def test_statement_keyspace(self): cluster.metadata._tablets = Mock(spec=Tablets) replicas = hosts[2:] cluster.metadata.get_replicas.return_value = replicas - cluster.metadata._tablets.table_has_tablets.return_value = [] + cluster.metadata._tablets.get_tablet_for_key.return_value = None child_policy = Mock() child_policy.make_query_plan.return_value = hosts @@ -896,7 +896,7 @@ def _prepare_cluster_with_vnodes(self): cluster.metadata._tablets = Mock(spec=Tablets) cluster.metadata.all_hosts.return_value = hosts cluster.metadata.get_replicas.return_value = hosts[2:] - cluster.metadata._tablets.table_has_tablets.return_value = False + cluster.metadata._tablets.get_tablet_for_key.return_value = None return cluster def _prepare_cluster_with_tablets(self): @@ -908,7 +908,6 @@ def 
_prepare_cluster_with_tablets(self): cluster.metadata._tablets = Mock(spec=Tablets) cluster.metadata.all_hosts.return_value = hosts cluster.metadata.get_replicas.return_value = hosts[2:] - cluster.metadata._tablets.table_has_tablets.return_value = True cluster.metadata._tablets.get_tablet_for_key.return_value = Tablet(replicas=[(h.host_id, 0) for h in hosts[2:]]) return cluster @@ -923,7 +922,7 @@ def _assert_shuffle(self, patched_shuffle, cluster, keyspace, routing_key): policy = TokenAwarePolicy(child_policy, shuffle_replicas=True) policy.populate(cluster, hosts) - is_tablets = cluster.metadata._tablets.table_has_tablets() + is_tablets = cluster.metadata._tablets.get_tablet_for_key() is not None cluster.metadata.get_replicas.reset_mock() child_policy.make_query_plan.reset_mock() @@ -1630,7 +1629,7 @@ def get_replicas(keyspace, packed_key): cluster.metadata.get_replicas.side_effect = get_replicas cluster.metadata._tablets = Mock(spec=Tablets) - cluster.metadata._tablets.table_has_tablets.return_value = [] + cluster.metadata._tablets.get_tablet_for_key.return_value = None child_policy = TokenAwarePolicy(RoundRobinPolicy()) From a00ffa736cd4f0fb09db53cf19effb4568edf187 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Sun, 18 Jan 2026 09:19:16 +0200 Subject: [PATCH 25/29] test: optimize test_fast_shutdown with event-based synchronization Replace sleep-based timing with proper Event synchronization: - Remove executor_init with 0.5s sleep - Replace time.sleep(0.5) with connection_created.wait() - Replace time.sleep(3) with executor.shutdown(wait=True) - Reduce iterations from 20 to 3 (deterministic with proper sync) - Add explicit assertions for shutdown state Performance improvement: - Before: 70.13s call tests/unit/test_host_connection_pool.py::HostConnectionTests::test_fast_shutdown - After: 0.01s call tests/unit/test_host_connection_pool.py::HostConnectionTests::test_fast_shutdown Signed-off-by: Yaniv Kaul --- tests/unit/test_host_connection_pool.py | 30 ++++++++++++++++--------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index 580eb336b2..f92bb53785 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -239,7 +239,8 @@ class MockSession(MagicMock): def __init__(self, *args, **kwargs): super(MockSession, self).__init__(*args, **kwargs) self.cluster = MagicMock() - self.cluster.executor = ThreadPoolExecutor(max_workers=2, initializer=self.executor_init) + self.connection_created = Event() + self.cluster.executor = ThreadPoolExecutor(max_workers=2) self.cluster.signal_connection_failure = lambda *args, **kwargs: False self.cluster.connection_factory = self.mock_connection_factory self.connection_counter = 0 @@ -259,23 +260,30 @@ def mock_connection_factory(self, *args, **kwargs): partitioner="", sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port="", shard_aware_port_ssl="")) self.connection_counter += 1 + self.connection_created.set() return connection - def executor_init(self, *args): - time.sleep(0.5) - LOGGER.info("Future start: %s", args) - - for attempt_num in range(20): - LOGGER.info("Testing fast shutdown %d / 20 times", attempt_num + 1) + for attempt_num in range(3): + LOGGER.info("Testing fast shutdown %d / 3 times", attempt_num + 1) host = MagicMock() host.endpoint = "1.2.3.4" - session = self.make_session() + session = MockSession() pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, 
session=session) LOGGER.info("Initialized pool %s", pool) + + # Wait for initial connection to be created (with timeout) + if not session.connection_created.wait(timeout=2.0): + pytest.fail("Initial connection failed to be created within 2 seconds") + LOGGER.info("Connections: %s", pool._connections) - time.sleep(0.5) + + # Shutdown the pool pool.shutdown() - time.sleep(3) - session.cluster.executor.shutdown() + + # Verify pool is shut down + assert pool.is_shutdown, "Pool should be marked as shutdown" + + # Cleanup executor with proper wait + session.cluster.executor.shutdown(wait=True) From 9f27bcf7afdd36ba977c8a35b4782f0d2d639460 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Fri, 23 Jan 2026 19:44:49 +0200 Subject: [PATCH 26/29] (Fix) race condition during host IP address update When a host changes its IP address, the driver previously updated the host endpoint to the new IP before calling on_down. This caused on_down to mistakenly target the new IP for connection cleanup. This change reorders the operations to ensure on_down cleans up the old IP's resources before the host object is updated and on_up is called for the new IP. Signed-off-by: Yaniv Kaul --- cassandra/cluster.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index a9c1d00e97..099043eae0 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -3831,14 +3831,16 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, host = self._cluster.metadata.get_host_by_host_id(host_id) if host and host.endpoint != endpoint: log.debug("[control connection] Updating host ip from %s to %s for (%s)", host.endpoint, endpoint, host_id) - old_endpoint = host.endpoint - host.endpoint = endpoint - self._cluster.metadata.update_host(host, old_endpoint) reconnector = host.get_and_set_reconnection_handler(None) if reconnector: reconnector.cancel() self._cluster.on_down(host, is_host_addition=False, expect_host_to_be_down=True) + old_endpoint = host.endpoint + host.endpoint = endpoint + self._cluster.metadata.update_host(host, old_endpoint) + self._cluster.on_up(host) + if host is None: log.debug("[control connection] Found new host to connect to: %s", endpoint) host, _ = self._cluster.add_host(endpoint, datacenter=datacenter, rack=rack, signal=True, refresh_nodes=False, host_id=host_id) From 82f99aaa8433627239070647f4b295389d6a63c9 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Thu, 29 Jan 2026 08:29:37 -0400 Subject: [PATCH 27/29] add uv files to .gitignore --- .gitignore | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.gitignore b/.gitignore index 28cf1ba218..1b60f54642 100644 --- a/.gitignore +++ b/.gitignore @@ -43,6 +43,11 @@ tests/unit/cython/bytesio_testhelper.c #iPython *.ipynb +uv.lock +.venv/ + + + # Files from upstream that we don't need Jenkinsfile Jenkinsfile.bak From 1886f8e763f65c3060918226891370a3e1f6d9f5 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Sat, 10 Jan 2026 00:05:31 +0200 Subject: [PATCH 28/29] Optimize write path in protocol.py to reduce copies Refactored `_ProtocolHandler.encode_message` to reduce memory copies and allocations. - Implemented 'Reserve and Seek' strategy for the write path. - In uncompressed scenarios (including Protocol V5+), we now write directly to the final buffer instead of an intermediate one, avoiding `bytes` creation and buffer copying. - Reserved space for the frame header, wrote the body, and then back-filled the header with the correct length (a standalone sketch of this strategy appears at the end of this series).
- Unified buffer initialization and header writing logic for cleaner code. - Optimized conditional checks for compression support. Signed-off-by: Yaniv Kaul --- cassandra/protocol.py | 36 ++++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/cassandra/protocol.py b/cassandra/protocol.py index e574965de8..f37633a756 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -1085,20 +1085,10 @@ def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta :param compressor: optional compression function to be used on the body """ flags = 0 - body = io.BytesIO() if msg.custom_payload: if protocol_version < 4: raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher") flags |= CUSTOM_PAYLOAD_FLAG - write_bytesmap(body, msg.custom_payload) - msg.send_body(body, protocol_version) - body = body.getvalue() - - # With checksumming, the compression is done at the segment frame encoding - if (not ProtocolVersion.has_checksumming_support(protocol_version) - and compressor and len(body) > 0): - body = compressor(body) - flags |= COMPRESSED_FLAG if msg.tracing: flags |= TRACING_FLAG @@ -1107,9 +1097,31 @@ def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta flags |= USE_BETA_FLAG buff = io.BytesIO() - cls._write_header(buff, protocol_version, flags, stream_id, msg.opcode, len(body)) - buff.write(body) + buff.seek(9) + + # With checksumming, the compression is done at the segment frame encoding + if (compressor and not ProtocolVersion.has_checksumming_support(protocol_version)): + body = io.BytesIO() + if msg.custom_payload: + write_bytesmap(body, msg.custom_payload) + msg.send_body(body, protocol_version) + body = body.getvalue() + + if len(body) > 0: + body = compressor(body) + flags |= COMPRESSED_FLAG + + buff.write(body) + length = len(body) + else: + if msg.custom_payload: + write_bytesmap(buff, msg.custom_payload) + msg.send_body(buff, protocol_version) + + length = buff.tell() - 9 + buff.seek(0) + cls._write_header(buff, protocol_version, flags, stream_id, msg.opcode, length) return buff.getvalue() @staticmethod From a61336654bee0ca8581b5d7e8843ddc3aac06940 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 27 Dec 2025 10:36:37 +0000 Subject: [PATCH 29/29] Initial plan
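As referenced in PATCH 28, the 'Reserve and Seek' strategy can be shown in isolation. This is a minimal sketch, not the driver's actual method: it assumes the 9-byte native-protocol (v3+) frame header, and write_body/write_header stand in for msg.send_body and _write_header:

    import io

    HEADER_SIZE = 9  # native protocol v3+ frame header length

    def encode_reserve_and_seek(write_body, write_header):
        buff = io.BytesIO()
        buff.seek(HEADER_SIZE)           # reserve room for the header
        write_body(buff)                 # stream the body straight in
        body_length = buff.tell() - HEADER_SIZE
        buff.seek(0)
        write_header(buff, body_length)  # back-fill the real length
        return buff.getvalue()

Writing the body directly into the final buffer avoids materializing it as a separate bytes object and copying it a second time, which is where the savings in the uncompressed path come from.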