Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.vector.types.VectorFloat;

import java.util.Arrays;

/**
 * A {@link RandomAccessVectorValues} view that translates every requested ordinal through a
 * lookup array before forwarding the call to an underlying RAVV.
 */
public class RemappedRandomAccessVectorValues implements RandomAccessVectorValues {
    /** The underlying vector source that actually holds the vectors. */
    private final RandomAccessVectorValues delegate;

    /** Lookup table: index is the graph ordinal, value is the ordinal in {@link #delegate}. */
    private final int[] ordinalMap;

    /**
     * Wraps a RAVV so that it can be addressed with a different set of ordinals. Use this when
     * the graph's ordinals are not the same as the ordinals of the RAVV.
     * <p>
     * The mapping array is stored by reference; it is not copied here.
     *
     * @param ravv the RAVV to remap
     * @param graphToRavvOrdMap lookup table from graph ordinals to RAVV ordinals, where
     *                          graphToRavvOrdMap[i] is the RAVV ordinal for graph ordinal i
     */
    public RemappedRandomAccessVectorValues(RandomAccessVectorValues ravv, int[] graphToRavvOrdMap) {
        this.delegate = ravv;
        this.ordinalMap = graphToRavvOrdMap;
    }

    /**
     * @return the number of entries in the ordinal mapping (not the size of the wrapped RAVV)
     */
    @Override
    public int size() {
        return ordinalMap.length;
    }

    /**
     * @return the dimension reported by the wrapped RAVV
     */
    @Override
    public int dimension() {
        return delegate.dimension();
    }

    /**
     * Looks up the vector for a graph ordinal.
     *
     * @param node the graph ordinal; used as an index into the mapping array
     * @return whatever the wrapped RAVV returns for the mapped ordinal
     */
    @Override
    public VectorFloat<?> getVector(int node) {
        int ravvOrdinal = ordinalMap[node];
        return delegate.getVector(ravvOrdinal);
    }

    /**
     * @return the wrapped RAVV's answer to {@code isValueShared()}
     */
    @Override
    public boolean isValueShared() {
        return delegate.isValueShared();
    }

    /**
     * @return a new remapped view built from {@code copy()} of the wrapped RAVV and a fresh
     *         copy of the mapping array
     */
    @Override
    public RandomAccessVectorValues copy() {
        RandomAccessVectorValues delegateCopy = delegate.copy();
        int[] mapCopy = Arrays.copyOf(ordinalMap, ordinalMap.length);
        return new RemappedRandomAccessVectorValues(delegateCopy, mapCopy);
    }

    /**
     * Forwards to the wrapped RAVV's {@code getVectorInto} using the mapped ordinal.
     *
     * @param node the graph ordinal; used as an index into the mapping array
     * @param result passed through unchanged to the wrapped RAVV
     * @param offset passed through unchanged to the wrapped RAVV
     */
    @Override
    public void getVectorInto(int node, VectorFloat<?> result, int offset) {
        int ravvOrdinal = ordinalMap[node];
        delegate.getVectorInto(ravvOrdinal, result, offset);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package io.github.jbellis.jvector.graph.similarity;

import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.graph.RemappedRandomAccessVectorValues;
import io.github.jbellis.jvector.quantization.BQVectors;
import io.github.jbellis.jvector.quantization.PQVectors;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
Expand All @@ -25,8 +26,6 @@
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;

import java.util.stream.IntStream;

/**
* Encapsulates comparing node distances for GraphIndexBuilder.
*/
Expand Down Expand Up @@ -88,15 +87,15 @@ public interface BuildScoreProvider {
*
* Helper method for the special case that mapping between graph node IDs and ravv ordinals is the identity function.
*/
static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, VectorSimilarityFunction similarityFunction) {
return randomAccessScoreProvider(ravv, IntStream.range(0, ravv.size()).toArray(), similarityFunction);
static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, int[] graphToRavvOrdMap, VectorSimilarityFunction similarityFunction) {
return randomAccessScoreProvider(new RemappedRandomAccessVectorValues(ravv, graphToRavvOrdMap), similarityFunction);
}

/**
* Returns a BSP that performs exact score comparisons using the given RandomAccessVectorValues and VectorSimilarityFunction.
* graphToRavvOrdMap maps graph node IDs to ravv ordinals.
*/
static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, int[] graphToRavvOrdMap, VectorSimilarityFunction similarityFunction) {
static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, VectorSimilarityFunction similarityFunction) {
// We need two sources of vectors in order to perform diversity check comparisons without
// colliding. ThreadLocalSupplier makes this a no-op if the RAVV is actually un-shared.
var vectors = ravv.threadLocalSupplier();
Expand Down Expand Up @@ -125,22 +124,22 @@ public VectorFloat<?> approximateCentroid() {
@Override
public SearchScoreProvider searchProviderFor(VectorFloat<?> vector) {
var vc = vectorsCopy.get();
return DefaultSearchScoreProvider.exact(vector, graphToRavvOrdMap, similarityFunction, vc);
return DefaultSearchScoreProvider.exact(vector, similarityFunction, vc);
}

@Override
public SearchScoreProvider searchProviderFor(int node1) {
RandomAccessVectorValues randomAccessVectorValues = vectors.get();
var v = randomAccessVectorValues.getVector(graphToRavvOrdMap[node1]);
var v = randomAccessVectorValues.getVector(node1);
return searchProviderFor(v);
}

@Override
public SearchScoreProvider diversityProviderFor(int node1) {
RandomAccessVectorValues randomAccessVectorValues = vectors.get();
var v = randomAccessVectorValues.getVector(graphToRavvOrdMap[node1]);
var v = randomAccessVectorValues.getVector(node1);
var vc = vectorsCopy.get();
return DefaultSearchScoreProvider.exact(v, graphToRavvOrdMap, similarityFunction, vc);
return DefaultSearchScoreProvider.exact(v, similarityFunction, vc);
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,19 @@ public void testSaveAndLoad() throws IOException {
}
assertGraphEquals(graph, builder.graph);
}

// A RandomAccessVectorValues can be backed by a source that keeps growing after the builder has
// been created. This test appends one vector at a time to the backing list and immediately adds
// the matching node to the builder; it passes as long as none of those calls throws.
@Test
public void testAddNodesToVectorValuesIteratively() throws IOException {
    final int nodeCount = 10;
    int dimension = randomIntBetween(2, 32);
    var backingList = new ArrayList<VectorFloat<?>>();
    RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(backingList, dimension);
    try (var builder = new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 2, 10, 1.0f, 1.0f, true)) {
        while (backingList.size() < nodeCount) {
            // The next ordinal is the index the new vector is about to occupy in the list.
            int ordinal = backingList.size();
            backingList.add(TestUtil.randomVector(random(), dimension));
            builder.addGraphNode(ordinal, ravv.getVector(ordinal));
        }
    }
}
}
Loading