|  | 
|  | 1 | +/* | 
|  | 2 | + * All changes to the original code are Copyright DataStax, Inc. | 
|  | 3 | + * | 
|  | 4 | + * Please see the included license file for details. | 
|  | 5 | + */ | 
|  | 6 | + | 
|  | 7 | +/* | 
|  | 8 | + * Original license: | 
|  | 9 | + * Licensed to the Apache Software Foundation (ASF) under one or more | 
|  | 10 | + * contributor license agreements.  See the NOTICE file distributed with | 
|  | 11 | + * this work for additional information regarding copyright ownership. | 
|  | 12 | + * The ASF licenses this file to You under the Apache License, Version 2.0 | 
|  | 13 | + * (the "License"); you may not use this file except in compliance with | 
|  | 14 | + * the License.  You may obtain a copy of the License at | 
|  | 15 | + * | 
|  | 16 | + *     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 17 | + * | 
|  | 18 | + * Unless required by applicable law or agreed to in writing, software | 
|  | 19 | + * distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 20 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 21 | + * See the License for the specific language governing permissions and | 
|  | 22 | + * limitations under the License. | 
|  | 23 | + */ | 
|  | 24 | + | 
|  | 25 | +package io.github.jbellis.jvector.graph; | 
|  | 26 | + | 
|  | 27 | +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; | 
|  | 28 | +import io.github.jbellis.jvector.LuceneTestCase; | 
|  | 29 | +import io.github.jbellis.jvector.TestUtil; | 
|  | 30 | +import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider; | 
|  | 31 | +import io.github.jbellis.jvector.util.BitSet; | 
|  | 32 | +import io.github.jbellis.jvector.util.BoundedLongHeap; | 
|  | 33 | +import io.github.jbellis.jvector.util.FixedBitSet; | 
|  | 34 | +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; | 
|  | 35 | +import io.github.jbellis.jvector.vector.types.VectorFloat; | 
|  | 36 | +import org.junit.Test; | 
|  | 37 | + | 
|  | 38 | +import java.io.IOException; | 
|  | 39 | +import java.util.Arrays; | 
|  | 40 | +import java.util.HashMap; | 
|  | 41 | +import java.util.List; | 
|  | 42 | +import java.util.Map; | 
|  | 43 | + | 
|  | 44 | +import static org.junit.Assert.assertEquals; | 
|  | 45 | + | 
|  | 46 | +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) | 
|  | 47 | +public class TestLowCardinalityFiltering extends LuceneTestCase { | 
|  | 48 | +    @Test | 
|  | 49 | +    public void testLowCardinalityFiltering() throws IOException { | 
|  | 50 | +        testLowCardinalityFiltering(32, 0.01f, 0.87f, false); | 
|  | 51 | +        testLowCardinalityFiltering(32, 0.01f, 0.87f, true); | 
|  | 52 | +    } | 
|  | 53 | +    public void testLowCardinalityFiltering(int maxDegree, float visitedRatioThreshold, float recallThreshold, boolean addHierarchy) throws IOException { | 
|  | 54 | +        var R = getRandom(); | 
|  | 55 | + | 
|  | 56 | +        int nVectors = 100_000; | 
|  | 57 | +        int nQueries = 100; | 
|  | 58 | +        int dimensions = 16; | 
|  | 59 | +        int topK = 10; | 
|  | 60 | + | 
|  | 61 | +        VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.COSINE; | 
|  | 62 | + | 
|  | 63 | +        // build index | 
|  | 64 | +        VectorFloat<?>[] vectors = TestVectorGraph.createRandomFloatVectors(nVectors, dimensions, R); | 
|  | 65 | +        var ravv = new ListRandomAccessVectorValues(List.of(vectors), dimensions); | 
|  | 66 | +        var builder = new GraphIndexBuilder(ravv, similarityFunction, maxDegree, 2 * maxDegree, 1.2f, 1.2f, addHierarchy); | 
|  | 67 | +        var onHeapGraph = builder.build(ravv); | 
|  | 68 | + | 
|  | 69 | +        // Build the set of accepted ordinals. There are two classes evenly split. | 
|  | 70 | +        Map<Boolean, BitSet> bitSets = new HashMap<>(); | 
|  | 71 | +        bitSets.put(true, new FixedBitSet(nVectors)); | 
|  | 72 | +        bitSets.put(false, new FixedBitSet(nVectors)); | 
|  | 73 | +        for (int j = 0; j < nVectors; j++) { | 
|  | 74 | +            bitSets.get(R.nextBoolean()).set(j); | 
|  | 75 | +        } | 
|  | 76 | + | 
|  | 77 | +        // test raw vectors | 
|  | 78 | +        var searcher = new GraphSearcher(onHeapGraph); | 
|  | 79 | + | 
|  | 80 | +        float meanVisitedRatio = 0; | 
|  | 81 | +        float meanRecall = 0; | 
|  | 82 | + | 
|  | 83 | +        for (int i = 0; i < nQueries; i++) { | 
|  | 84 | +            VectorFloat<?> query = TestUtil.randomVector(R, dimensions); | 
|  | 85 | +            boolean queryClass = R.nextBoolean(); | 
|  | 86 | + | 
|  | 87 | +            var sf = ravv.rerankerFor(query, similarityFunction); | 
|  | 88 | +            var result = searcher.search(new DefaultSearchScoreProvider(sf), topK, 0, bitSets.get(queryClass)); | 
|  | 89 | + | 
|  | 90 | +            float recall = getRecall(ravv, bitSets, similarityFunction, query, queryClass, topK, result); | 
|  | 91 | + | 
|  | 92 | +            meanVisitedRatio += ((float) result.getVisitedCount()) / (vectors.length * nQueries); | 
|  | 93 | +            meanRecall += recall / (nQueries * topK); | 
|  | 94 | +        } | 
|  | 95 | + | 
|  | 96 | +        System.out.println("meanVisitedRatio " +  meanVisitedRatio); | 
|  | 97 | +        System.out.println("meanRecall " +  meanRecall); | 
|  | 98 | + | 
|  | 99 | +        assert meanVisitedRatio < visitedRatioThreshold : "visited " + meanVisitedRatio * 100 + "% of the vectors, which is more than " + visitedRatioThreshold * 100 + "%"; | 
|  | 100 | +        assert meanRecall > recallThreshold : "the recall is too low: " + meanRecall + " < " + recallThreshold; | 
|  | 101 | +    } | 
|  | 102 | + | 
|  | 103 | +    /** | 
|  | 104 | +     * Create "interesting" test parameters -- shouldn't match too many (we want to validate | 
|  | 105 | +     * that threshold code doesn't just crawl the entire graph) or too few (we might not find them) | 
|  | 106 | +     */ | 
|  | 107 | +    private float getRecall(RandomAccessVectorValues ravv, Map<Boolean, BitSet> bitSets, VectorSimilarityFunction similarityFunction, VectorFloat<?> query, boolean queryClass, int topK, SearchResult result) { | 
|  | 108 | +        var resultNodes = result.getNodes(); | 
|  | 109 | +        assertEquals(topK, resultNodes.length); | 
|  | 110 | + | 
|  | 111 | +        NodeQueue expected = new NodeQueue(new BoundedLongHeap(topK), NodeQueue.Order.MIN_HEAP); | 
|  | 112 | +        for (int j = 0; j < ravv.size(); j++) { | 
|  | 113 | +            if (bitSets.get(queryClass).get(j)) { | 
|  | 114 | +                expected.push(j, similarityFunction.compare(query, ravv.getVector(j))); | 
|  | 115 | +            } | 
|  | 116 | +        } | 
|  | 117 | +        var actualNodeIds = Arrays.stream(resultNodes, 0, topK).mapToInt(nodeScore -> nodeScore.node).toArray(); | 
|  | 118 | + | 
|  | 119 | +        return computeOverlap(actualNodeIds, expected.nodesCopy()); | 
|  | 120 | +    } | 
|  | 121 | + | 
|  | 122 | +    private int computeOverlap(int[] a, int[] b) { | 
|  | 123 | +        Arrays.sort(a); | 
|  | 124 | +        Arrays.sort(b); | 
|  | 125 | +        int overlap = 0; | 
|  | 126 | +        for (int i = 0, j = 0; i < a.length && j < b.length; ) { | 
|  | 127 | +            if (a[i] == b[j]) { | 
|  | 128 | +                ++overlap; | 
|  | 129 | +                ++i; | 
|  | 130 | +                ++j; | 
|  | 131 | +            } else if (a[i] > b[j]) { | 
|  | 132 | +                ++j; | 
|  | 133 | +            } else { | 
|  | 134 | +                ++i; | 
|  | 135 | +            } | 
|  | 136 | +        } | 
|  | 137 | +        return overlap; | 
|  | 138 | +    } | 
|  | 139 | +} | 
0 commit comments