From 8f5bb82913e6b3eccf9f172184e3973ba58a9272 Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Tue, 21 Oct 2025 18:54:42 -0500 Subject: [PATCH 1/5] CNDB-15703: Fix ScoreTracker impl init and reset methods --- .../jbellis/jvector/graph/ScoreTracker.java | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java index 21be9e3b0..11eed369e 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java @@ -41,7 +41,7 @@ public ScoreTracker getScoreTracker(boolean pruneSearch, int rerankK, float thre if (threshold > 0) { if (twoPhaseTracker == null) { - twoPhaseTracker = new ScoreTracker.TwoPhaseTracker(); + twoPhaseTracker = new ScoreTracker.TwoPhaseTracker(threshold); } else { twoPhaseTracker.reset(threshold); } @@ -49,7 +49,7 @@ public ScoreTracker getScoreTracker(boolean pruneSearch, int rerankK, float thre } else { if (pruneSearch) { if (relaxedMonotonicityTracker == null) { - relaxedMonotonicityTracker = new ScoreTracker.RelaxedMonotonicityTracker(); + relaxedMonotonicityTracker = new ScoreTracker.RelaxedMonotonicityTracker(rerankK); } else { relaxedMonotonicityTracker.reset(rerankK); } @@ -109,10 +109,6 @@ class TwoPhaseTracker implements ScoreTracker { this.threshold = threshold; } - TwoPhaseTracker() { - this(0); - } - void reset(double threshold) { this.bestScores.clear(); this.observationCount = 0; @@ -195,10 +191,6 @@ class RelaxedMonotonicityTracker implements ScoreTracker { this.dSquared = 0; } - RelaxedMonotonicityTracker() { - this(100); - } - private static int getRecentScoresSize(int bestScoresTracked) { // A quick empirical study yields that the number of recent scores // that we need to consider grows by a factor of ~sqrt(bestScoresTracked / 2) @@ -211,7 
+203,11 @@ void reset(int bestScoresTracked) { if (this.recentScoresSize > recentScores.length) { recentScores = ArrayUtil.grow(recentScores, this.recentScoresSize); } - this.bestScores.clear(); + if (this.recentScoresSize > bestScores.size()) { + bestScores = new BoundedLongHeap(this.recentScoresSize); + } else { + bestScores.clear(); + } this.observationCount = 0; this.mean = 0; this.dSquared = 0; @@ -219,6 +215,10 @@ void reset(int bestScoresTracked) { @Override public void track(float score) { + // TODO do we want to implement a resizeable bestScores heap? + if (bestScores.size() == this.recentScoresSize) { + bestScores.pop(); + } bestScores.push(floatToSortableInt(score)); observationCount++; From b59246db4a9b7c951e9ca103a62bfc494ded8a60 Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Tue, 21 Oct 2025 20:35:28 -0500 Subject: [PATCH 2/5] Use setMaxSize on BoundedLongHeap --- .../java/io/github/jbellis/jvector/graph/ScoreTracker.java | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java index 11eed369e..919976641 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java @@ -207,6 +207,7 @@ void reset(int bestScoresTracked) { bestScores = new BoundedLongHeap(this.recentScoresSize); } else { bestScores.clear(); + bestScores.setMaxSize(this.recentScoresSize); } this.observationCount = 0; this.mean = 0; @@ -215,10 +216,6 @@ void reset(int bestScoresTracked) { @Override public void track(float score) { - // TODO do we want to implement a resizeable bestScores heap? 
- if (bestScores.size() == this.recentScoresSize) { - bestScores.pop(); - } bestScores.push(floatToSortableInt(score)); observationCount++; From d9e9e84e661294699814a983dc2ec23aa996ec72 Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Wed, 22 Oct 2025 10:04:55 -0500 Subject: [PATCH 3/5] Fix reset, make it use the right new maxSize --- .../io/github/jbellis/jvector/graph/ScoreTracker.java | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java index 919976641..815476db4 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java @@ -203,12 +203,8 @@ void reset(int bestScoresTracked) { if (this.recentScoresSize > recentScores.length) { recentScores = ArrayUtil.grow(recentScores, this.recentScoresSize); } - if (this.recentScoresSize > bestScores.size()) { - bestScores = new BoundedLongHeap(this.recentScoresSize); - } else { - bestScores.clear(); - bestScores.setMaxSize(this.recentScoresSize); - } + bestScores.clear(); + bestScores.setMaxSize(bestScoresTracked); this.observationCount = 0; this.mean = 0; this.dSquared = 0; From 56c6f9e3630b020359fbdcc0a1c540ddc392d42c Mon Sep 17 00:00:00 2001 From: Mariano Tepper Date: Wed, 22 Oct 2025 09:35:38 -0700 Subject: [PATCH 4/5] Add test to mimic https://github.com/jbellis/pwtest natively --- .../graph/TestLowCardinalityFiltering.java | 115 ++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestLowCardinalityFiltering.java diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestLowCardinalityFiltering.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestLowCardinalityFiltering.java new file mode 100644 index 
000000000..8a9197e8f --- /dev/null +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestLowCardinalityFiltering.java @@ -0,0 +1,115 @@ +package io.github.jbellis.jvector.graph; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import io.github.jbellis.jvector.LuceneTestCase; +import io.github.jbellis.jvector.TestUtil; +import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider; +import io.github.jbellis.jvector.util.BitSet; +import io.github.jbellis.jvector.util.BoundedLongHeap; +import io.github.jbellis.jvector.util.FixedBitSet; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import org.junit.Test; + +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; + +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class TestLowCardinalityFiltering extends LuceneTestCase { + @Test + public void testLowCardinalityFiltering() throws IOException { + testLowCardinalityFiltering(32, 0.01f, 0.87f, false); + testLowCardinalityFiltering(32, 0.01f, 0.87f, true); + } + public void testLowCardinalityFiltering(int maxDegree, float visitedRatioThreshold, float recallThreshold, boolean addHierarchy) throws IOException { + var R = getRandom(); + + int nVectors = 100_000; + int nQueries = 100; + int dimensions = 16; + int topK = 10; + + VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.COSINE; + + // build index + VectorFloat[] vectors = TestVectorGraph.createRandomFloatVectors(nVectors, dimensions, R); + var ravv = new ListRandomAccessVectorValues(List.of(vectors), dimensions); + var builder = new GraphIndexBuilder(ravv, similarityFunction, maxDegree, 2 * maxDegree, 1.2f, 1.2f, addHierarchy); + var onHeapGraph = builder.build(ravv); + + // Build the set of accepted ordinals. 
There are two classes evenly split. + Map bitSets = new HashMap<>(); + bitSets.put(true, new FixedBitSet(nVectors)); + bitSets.put(false, new FixedBitSet(nVectors)); + for (int j = 0; j < nVectors; j++) { + bitSets.get(R.nextBoolean()).set(j); + } + + // test raw vectors + var searcher = new GraphSearcher(onHeapGraph); + + float meanVisitedRatio = 0; + float meanRecall = 0; + + for (int i = 0; i < nQueries; i++) { + VectorFloat query = TestUtil.randomVector(R, dimensions); + boolean queryClass = R.nextBoolean(); + + var sf = ravv.rerankerFor(query, similarityFunction); + var result = searcher.search(new DefaultSearchScoreProvider(sf), topK, 0, bitSets.get(queryClass)); + + float recall = getRecall(ravv, bitSets, similarityFunction, query, queryClass, topK, result); + + meanVisitedRatio += ((float) result.getVisitedCount()) / (vectors.length * nQueries); + meanRecall += recall / (nQueries * topK); + } + + System.out.println("meanVisitedRatio " + meanVisitedRatio); + System.out.println("meanRecall " + meanRecall); + + assert meanVisitedRatio < visitedRatioThreshold : "visited " + meanVisitedRatio * 100 + "% of the vectors, which is more than " + visitedRatioThreshold * 100 + "%"; + assert meanRecall > recallThreshold : "the recall is too low: " + meanRecall + " < " + recallThreshold; + } + + /** + * Returns the overlap count (not a ratio) between the topK search results and the reference set + * built by scoring every accepted vector of the query's class into a NodeQueue bounded to topK. + */ + private float getRecall(RandomAccessVectorValues ravv, Map bitSets, VectorSimilarityFunction similarityFunction, VectorFloat query, boolean queryClass, int topK, SearchResult result) { + var resultNodes = result.getNodes(); + assertEquals(topK, resultNodes.length); + + NodeQueue expected = new NodeQueue(new BoundedLongHeap(topK), NodeQueue.Order.MIN_HEAP); + for (int j = 0; j < ravv.size(); j++) { + if (bitSets.get(queryClass).get(j)) { + expected.push(j,
similarityFunction.compare(query, ravv.getVector(j))); + } + } + var actualNodeIds = Arrays.stream(resultNodes, 0, topK).mapToInt(nodeScore -> nodeScore.node).toArray(); + + return computeOverlap(actualNodeIds, expected.nodesCopy()); + } + + private int computeOverlap(int[] a, int[] b) { + Arrays.sort(a); + Arrays.sort(b); + int overlap = 0; + for (int i = 0, j = 0; i < a.length && j < b.length; ) { + if (a[i] == b[j]) { + ++overlap; + ++i; + ++j; + } else if (a[i] > b[j]) { + ++j; + } else { + ++i; + } + } + return overlap; + } +} From f2630e69959d113bd3f5987df7b121a00cfbd2ba Mon Sep 17 00:00:00 2001 From: Mariano Tepper Date: Wed, 22 Oct 2025 09:44:34 -0700 Subject: [PATCH 5/5] Add missing license --- .../graph/TestLowCardinalityFiltering.java | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestLowCardinalityFiltering.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestLowCardinalityFiltering.java index 8a9197e8f..da9b3e559 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestLowCardinalityFiltering.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestLowCardinalityFiltering.java @@ -1,3 +1,27 @@ +/* + * All changes to the original code are Copyright DataStax, Inc. + * + * Please see the included license file for details. + */ + +/* + * Original license: + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package io.github.jbellis.jvector.graph; import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;