Skip to content

Commit d95bc61

Browse files
Fix ScoreTracker initialization and reset methods (#551)
PR #501 introduced a subtle bug in the ScoreTracker that was delaying the early termination (the search takes longer but recall is higher). This PR returns to the old behavior with earlier terminations.
1 parent d20733b commit d95bc61

File tree

2 files changed

+143
-11
lines changed

2 files changed

+143
-11
lines changed

jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,15 @@ public ScoreTracker getScoreTracker(boolean pruneSearch, int rerankK, float thre
4141

4242
if (threshold > 0) {
4343
if (twoPhaseTracker == null) {
44-
twoPhaseTracker = new ScoreTracker.TwoPhaseTracker();
44+
twoPhaseTracker = new ScoreTracker.TwoPhaseTracker(threshold);
4545
} else {
4646
twoPhaseTracker.reset(threshold);
4747
}
4848
scoreTracker = twoPhaseTracker;
4949
} else {
5050
if (pruneSearch) {
5151
if (relaxedMonotonicityTracker == null) {
52-
relaxedMonotonicityTracker = new ScoreTracker.RelaxedMonotonicityTracker();
52+
relaxedMonotonicityTracker = new ScoreTracker.RelaxedMonotonicityTracker(rerankK);
5353
} else {
5454
relaxedMonotonicityTracker.reset(rerankK);
5555
}
@@ -109,10 +109,6 @@ class TwoPhaseTracker implements ScoreTracker {
109109
this.threshold = threshold;
110110
}
111111

112-
TwoPhaseTracker() {
113-
this(0);
114-
}
115-
116112
void reset(double threshold) {
117113
this.bestScores.clear();
118114
this.observationCount = 0;
@@ -195,10 +191,6 @@ class RelaxedMonotonicityTracker implements ScoreTracker {
195191
this.dSquared = 0;
196192
}
197193

198-
RelaxedMonotonicityTracker() {
199-
this(100);
200-
}
201-
202194
private static int getRecentScoresSize(int bestScoresTracked) {
203195
// A quick empirical study yields that the number of recent scores
204196
// that we need to consider grows by a factor of ~sqrt(bestScoresTracked / 2)
@@ -211,7 +203,8 @@ void reset(int bestScoresTracked) {
211203
if (this.recentScoresSize > recentScores.length) {
212204
recentScores = ArrayUtil.grow(recentScores, this.recentScoresSize);
213205
}
214-
this.bestScores.clear();
206+
bestScores.clear();
207+
bestScores.setMaxSize(bestScoresTracked);
215208
this.observationCount = 0;
216209
this.mean = 0;
217210
this.dSquared = 0;
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
/*
2+
* All changes to the original code are Copyright DataStax, Inc.
3+
*
4+
* Please see the included license file for details.
5+
*/
6+
7+
/*
8+
* Original license:
9+
* Licensed to the Apache Software Foundation (ASF) under one or more
10+
* contributor license agreements. See the NOTICE file distributed with
11+
* this work for additional information regarding copyright ownership.
12+
* The ASF licenses this file to You under the Apache License, Version 2.0
13+
* (the "License"); you may not use this file except in compliance with
14+
* the License. You may obtain a copy of the License at
15+
*
16+
* http://www.apache.org/licenses/LICENSE-2.0
17+
*
18+
* Unless required by applicable law or agreed to in writing, software
19+
* distributed under the License is distributed on an "AS IS" BASIS,
20+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21+
* See the License for the specific language governing permissions and
22+
* limitations under the License.
23+
*/
24+
25+
package io.github.jbellis.jvector.graph;
26+
27+
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
28+
import io.github.jbellis.jvector.LuceneTestCase;
29+
import io.github.jbellis.jvector.TestUtil;
30+
import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider;
31+
import io.github.jbellis.jvector.util.BitSet;
32+
import io.github.jbellis.jvector.util.BoundedLongHeap;
33+
import io.github.jbellis.jvector.util.FixedBitSet;
34+
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
35+
import io.github.jbellis.jvector.vector.types.VectorFloat;
36+
import org.junit.Test;
37+
38+
import java.io.IOException;
39+
import java.util.Arrays;
40+
import java.util.HashMap;
41+
import java.util.List;
42+
import java.util.Map;
43+
44+
import static org.junit.Assert.assertEquals;
45+
46+
@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
47+
public class TestLowCardinalityFiltering extends LuceneTestCase {
48+
@Test
49+
public void testLowCardinalityFiltering() throws IOException {
50+
testLowCardinalityFiltering(32, 0.01f, 0.87f, false);
51+
testLowCardinalityFiltering(32, 0.01f, 0.87f, true);
52+
}
53+
public void testLowCardinalityFiltering(int maxDegree, float visitedRatioThreshold, float recallThreshold, boolean addHierarchy) throws IOException {
54+
var R = getRandom();
55+
56+
int nVectors = 100_000;
57+
int nQueries = 100;
58+
int dimensions = 16;
59+
int topK = 10;
60+
61+
VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.COSINE;
62+
63+
// build index
64+
VectorFloat<?>[] vectors = TestVectorGraph.createRandomFloatVectors(nVectors, dimensions, R);
65+
var ravv = new ListRandomAccessVectorValues(List.of(vectors), dimensions);
66+
var builder = new GraphIndexBuilder(ravv, similarityFunction, maxDegree, 2 * maxDegree, 1.2f, 1.2f, addHierarchy);
67+
var onHeapGraph = builder.build(ravv);
68+
69+
// Build the set of accepted ordinals. There are two classes evenly split.
70+
Map<Boolean, BitSet> bitSets = new HashMap<>();
71+
bitSets.put(true, new FixedBitSet(nVectors));
72+
bitSets.put(false, new FixedBitSet(nVectors));
73+
for (int j = 0; j < nVectors; j++) {
74+
bitSets.get(R.nextBoolean()).set(j);
75+
}
76+
77+
// test raw vectors
78+
var searcher = new GraphSearcher(onHeapGraph);
79+
80+
float meanVisitedRatio = 0;
81+
float meanRecall = 0;
82+
83+
for (int i = 0; i < nQueries; i++) {
84+
VectorFloat<?> query = TestUtil.randomVector(R, dimensions);
85+
boolean queryClass = R.nextBoolean();
86+
87+
var sf = ravv.rerankerFor(query, similarityFunction);
88+
var result = searcher.search(new DefaultSearchScoreProvider(sf), topK, 0, bitSets.get(queryClass));
89+
90+
float recall = getRecall(ravv, bitSets, similarityFunction, query, queryClass, topK, result);
91+
92+
meanVisitedRatio += ((float) result.getVisitedCount()) / (vectors.length * nQueries);
93+
meanRecall += recall / (nQueries * topK);
94+
}
95+
96+
System.out.println("meanVisitedRatio " + meanVisitedRatio);
97+
System.out.println("meanRecall " + meanRecall);
98+
99+
assert meanVisitedRatio < visitedRatioThreshold : "visited " + meanVisitedRatio * 100 + "% of the vectors, which is more than " + visitedRatioThreshold * 100 + "%";
100+
assert meanRecall > recallThreshold : "the recall is too low: " + meanRecall + " < " + recallThreshold;
101+
}
102+
103+
/**
104+
* Create "interesting" test parameters -- shouldn't match too many (we want to validate
105+
* that threshold code doesn't just crawl the entire graph) or too few (we might not find them)
106+
*/
107+
private float getRecall(RandomAccessVectorValues ravv, Map<Boolean, BitSet> bitSets, VectorSimilarityFunction similarityFunction, VectorFloat<?> query, boolean queryClass, int topK, SearchResult result) {
108+
var resultNodes = result.getNodes();
109+
assertEquals(topK, resultNodes.length);
110+
111+
NodeQueue expected = new NodeQueue(new BoundedLongHeap(topK), NodeQueue.Order.MIN_HEAP);
112+
for (int j = 0; j < ravv.size(); j++) {
113+
if (bitSets.get(queryClass).get(j)) {
114+
expected.push(j, similarityFunction.compare(query, ravv.getVector(j)));
115+
}
116+
}
117+
var actualNodeIds = Arrays.stream(resultNodes, 0, topK).mapToInt(nodeScore -> nodeScore.node).toArray();
118+
119+
return computeOverlap(actualNodeIds, expected.nodesCopy());
120+
}
121+
122+
private int computeOverlap(int[] a, int[] b) {
123+
Arrays.sort(a);
124+
Arrays.sort(b);
125+
int overlap = 0;
126+
for (int i = 0, j = 0; i < a.length && j < b.length; ) {
127+
if (a[i] == b[j]) {
128+
++overlap;
129+
++i;
130+
++j;
131+
} else if (a[i] > b[j]) {
132+
++j;
133+
} else {
134+
++i;
135+
}
136+
}
137+
return overlap;
138+
}
139+
}

0 commit comments

Comments
 (0)