diff --git a/cassandra/policies.py b/cassandra/policies.py index bcfd797706..c4b2dad0fc 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -505,21 +505,27 @@ def make_query_plan(self, working_keyspace=None, query=None): keyspace = query.keyspace if query and query.keyspace else working_keyspace child = self._child_policy + + # Early return case: pass through the generator to preserve lazy evaluation + # This avoids materializing the full host list in memory when we don't need token-aware routing if query is None or query.routing_key is None or keyspace is None: for host in child.make_query_plan(keyspace, query): yield host return + # Call child.make_query_plan only once and convert to list for reuse + # List conversion is necessary because we iterate over it twice: + # 1. To identify replicas (either from tablets or token ring) + # 2. To yield remaining hosts not in the replica set + child_plan = list(child.make_query_plan(keyspace, query)) + replicas = [] - if self._cluster_metadata._tablets.table_has_tablets(keyspace, query.table): - tablet = self._cluster_metadata._tablets.get_tablet_for_key( + tablet = self._cluster_metadata._tablets.get_tablet_for_key( keyspace, query.table, self._cluster_metadata.token_map.token_class.from_key(query.routing_key)) - if tablet is not None: - replicas_mapped = set(map(lambda r: r[0], tablet.replicas)) - child_plan = child.make_query_plan(keyspace, query) - - replicas = [host for host in child_plan if host.host_id in replicas_mapped] + if tablet is not None: + replicas_mapped = set(map(lambda r: r[0], tablet.replicas)) + replicas = [host for host in child_plan if host.host_id in replicas_mapped] else: replicas = self._cluster_metadata.get_replicas(keyspace, query.routing_key) @@ -535,7 +541,7 @@ def yield_in_order(hosts): # yield replicas: local_rack, local, remote yield from yield_in_order(replicas) # yield rest of the cluster: local_rack, local, remote - yield from yield_in_order([host for host in child.make_query_plan(keyspace, query) if host not in replicas]) + yield from yield_in_order([host for host in child_plan if host not in replicas]) def on_up(self, *args, **kwargs): return self._child_policy.on_up(*args, **kwargs) diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index e15705c8f7..f80f21dd51 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -584,7 +584,7 @@ def test_wrap_round_robin(self): cluster = Mock(spec=Cluster) cluster.metadata = Mock(spec=Metadata) cluster.metadata._tablets = Mock(spec=Tablets) - cluster.metadata._tablets.table_has_tablets.return_value = [] + cluster.metadata._tablets.get_tablet_for_key.return_value = None hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] for host in hosts: host.set_up() @@ -617,7 +617,7 @@ def test_wrap_dc_aware(self): cluster = Mock(spec=Cluster) cluster.metadata = Mock(spec=Metadata) cluster.metadata._tablets = Mock(spec=Tablets) - cluster.metadata._tablets.table_has_tablets.return_value = [] + cluster.metadata._tablets.get_tablet_for_key.return_value = None hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] for host in hosts: host.set_up() @@ -666,7 +666,7 @@ def test_wrap_rack_aware(self): cluster = Mock(spec=Cluster) cluster.metadata = Mock(spec=Metadata) cluster.metadata._tablets = Mock(spec=Tablets) - cluster.metadata._tablets.table_has_tablets.return_value = [] + cluster.metadata._tablets.get_tablet_for_key.return_value = None hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(8)] for host in hosts: host.set_up() @@ -811,7 +811,7 @@ def test_statement_keyspace(self): cluster.metadata._tablets = Mock(spec=Tablets) replicas = hosts[2:] cluster.metadata.get_replicas.return_value = replicas - cluster.metadata._tablets.table_has_tablets.return_value = [] + cluster.metadata._tablets.get_tablet_for_key.return_value = None child_policy = Mock() child_policy.make_query_plan.return_value = hosts @@ -904,7 +904,7 @@ def _prepare_cluster_with_vnodes(self): cluster.metadata._tablets = Mock(spec=Tablets) cluster.metadata.all_hosts.return_value = hosts cluster.metadata.get_replicas.return_value = hosts[2:] - cluster.metadata._tablets.table_has_tablets.return_value = False + cluster.metadata._tablets.get_tablet_for_key.return_value = None return cluster def _prepare_cluster_with_tablets(self): @@ -916,7 +916,6 @@ def _prepare_cluster_with_tablets(self): cluster.metadata._tablets = Mock(spec=Tablets) cluster.metadata.all_hosts.return_value = hosts cluster.metadata.get_replicas.return_value = hosts[2:] - cluster.metadata._tablets.table_has_tablets.return_value = True cluster.metadata._tablets.get_tablet_for_key.return_value = Tablet(replicas=[(h.host_id, 0) for h in hosts[2:]]) return cluster @@ -931,8 +930,6 @@ def _assert_shuffle(self, patched_shuffle, cluster, keyspace, routing_key): policy = TokenAwarePolicy(child_policy, shuffle_replicas=True) policy.populate(cluster, hosts) - is_tablets = cluster.metadata._tablets.table_has_tablets() - cluster.metadata.get_replicas.reset_mock() child_policy.make_query_plan.reset_mock() query = Statement(routing_key=routing_key) @@ -945,13 +942,69 @@ def _assert_shuffle(self, patched_shuffle, cluster, keyspace, routing_key): else: assert set(replicas) == set(qplan[:2]) assert hosts[:2] == qplan[2:] - if is_tablets: - child_policy.make_query_plan.assert_called_with(keyspace, query) - assert child_policy.make_query_plan.call_count == 2 - else: - child_policy.make_query_plan.assert_called_once_with(keyspace, query) + # After optimization, child.make_query_plan should be called once for both tablets and vnodes + child_policy.make_query_plan.assert_called_once_with(keyspace, query) assert patched_shuffle.call_count == 1 + def test_child_make_query_plan_called_once(self): + """ + Test to validate that child.make_query_plan is called only once + in all scenarios (with/without tablets, with/without routing key) + + @test_category policy + """ + # Test with vnodes (no tablets) + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] + for host in hosts: + host.set_up() + + cluster = Mock(spec=Cluster) + cluster.metadata = Mock(spec=Metadata) + cluster.metadata._tablets = Mock(spec=Tablets) + cluster.metadata._tablets.get_tablet_for_key.return_value = None # No tablets for this table + replicas = hosts[2:] + cluster.metadata.get_replicas.return_value = replicas + + child_policy = Mock() + child_policy.make_query_plan.return_value = hosts + child_policy.distance.return_value = HostDistance.LOCAL + + policy = TokenAwarePolicy(child_policy) + policy.populate(cluster, hosts) + + # Test case 1: With routing key and keyspace (should call once) + child_policy.reset_mock() + keyspace = 'keyspace' + routing_key = 'routing_key' + query = Statement(routing_key=routing_key, keyspace=keyspace) + qplan = list(policy.make_query_plan(keyspace, query)) + child_policy.make_query_plan.assert_called_once_with(keyspace, query) + + # Test case 2: Without routing key (should call once) + child_policy.reset_mock() + query = Statement(routing_key=None, keyspace=keyspace) + qplan = list(policy.make_query_plan(keyspace, query)) + child_policy.make_query_plan.assert_called_once_with(keyspace, query) + + # Test case 3: Without keyspace (should call once) + child_policy.reset_mock() + query = Statement(routing_key=routing_key, keyspace=None) + qplan = list(policy.make_query_plan(None, query)) + child_policy.make_query_plan.assert_called_once_with(None, query) + + # Test case 4: With tablets (should call once) + tablet = Mock(spec=Tablet) + tablet.replicas = [(hosts[0].host_id, None), (hosts[1].host_id, None)] + cluster.metadata._tablets.get_tablet_for_key.return_value = tablet + cluster.metadata.token_map = Mock() + cluster.metadata.token_map.token_class = Mock() + cluster.metadata.token_map.token_class.from_key.return_value = 'token' + + child_policy.reset_mock() + query = Statement(routing_key=routing_key, keyspace=keyspace, table='test_table') + qplan = list(policy.make_query_plan(keyspace, query)) + child_policy.make_query_plan.assert_called_once_with(keyspace, query) + class ConvictionPolicyTest(unittest.TestCase): def test_not_implemented(self): @@ -1638,7 +1691,7 @@ def get_replicas(keyspace, packed_key): cluster.metadata.get_replicas.side_effect = get_replicas cluster.metadata._tablets = Mock(spec=Tablets) - cluster.metadata._tablets.table_has_tablets.return_value = [] + cluster.metadata._tablets.get_tablet_for_key.return_value = None child_policy = TokenAwarePolicy(RoundRobinPolicy())