Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions java/lance-jni/src/blocking_scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,13 +336,12 @@ fn inner_create_scanner<'local>(
scanner.refine(refine_factor);
}

let distance_type_jstr: JString = env
.call_method(&java_obj, "getDistanceType", "()Ljava/lang/String;", &[])?
.l()?
.into();
let distance_type_str: String = env.get_string(&distance_type_jstr)?.into();
let distance_type = DistanceType::try_from(distance_type_str.as_str())?;
scanner.distance_metric(distance_type);
if let Some(distance_type_str) =
env.get_optional_string_from_method(&java_obj, "getDistanceTypeString")?
{
let distance_type = DistanceType::try_from(distance_type_str.as_str())?;
scanner.distance_metric(distance_type);
}

let use_index = env.get_boolean_from_method(&java_obj, "isUseIndex")?;
scanner.use_index(use_index);
Expand Down
13 changes: 7 additions & 6 deletions java/lance-jni/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,13 @@ pub fn get_query(env: &mut JNIEnv, query_obj: JObject) -> Result<Option<Query>>

let refine_factor = env.get_optional_u32_from_method(&java_obj, "getRefineFactor")?;

let distance_type_jstr: JString = env
.call_method(&java_obj, "getDistanceType", "()Ljava/lang/String;", &[])?
.l()?
.into();
let distance_type_str: String = env.get_string(&distance_type_jstr)?.into();
let distance_type = DistanceType::try_from(distance_type_str.as_str())?;
let distance_type = if let Some(distance_type_str) =
env.get_optional_string_from_method(&java_obj, "getDistanceTypeString")?
{
Some(DistanceType::try_from(distance_type_str.as_str())?)
} else {
None
};

let use_index = env.get_boolean_from_method(&java_obj, "isUseIndex")?;

Expand Down
21 changes: 14 additions & 7 deletions java/src/main/java/org/lance/ipc/Query.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public class Query {
private final Optional<Integer> maximumNprobes;
private final Optional<Integer> ef;
private final Optional<Integer> refineFactor;
private final DistanceType distanceType;
private final Optional<DistanceType> distanceType;
private final boolean useIndex;

private Query(Builder builder) {
Expand All @@ -48,7 +48,7 @@ private Query(Builder builder) {
this.maximumNprobes = builder.maximumNprobes;
this.ef = builder.ef;
this.refineFactor = builder.refineFactor;
this.distanceType = Preconditions.checkNotNull(builder.distanceType, "Metric type must be set");
this.distanceType = builder.distanceType;
this.useIndex = builder.useIndex;
}

Expand Down Expand Up @@ -80,8 +80,12 @@ public Optional<Integer> getRefineFactor() {
return refineFactor;
}

public String getDistanceType() {
return distanceType.toString();
public Optional<DistanceType> getDistanceType() {
return distanceType;
}

public Optional<String> getDistanceTypeString() {
return distanceType.map(DistanceType::toString);
}

public boolean isUseIndex() {
Expand All @@ -98,7 +102,7 @@ public String toString() {
.add("maximumNprobes", maximumNprobes.orElse(null))
.add("ef", ef.orElse(null))
.add("refineFactor", refineFactor.orElse(null))
.add("distanceType", distanceType)
.add("distanceType", distanceType.orElse(null))
.add("useIndex", useIndex)
.toString();
}
Expand All @@ -111,7 +115,7 @@ public static class Builder {
private Optional<Integer> maximumNprobes = Optional.empty();
private Optional<Integer> ef = Optional.empty();
private Optional<Integer> refineFactor = Optional.empty();
private DistanceType distanceType = DistanceType.L2;
private Optional<DistanceType> distanceType = Optional.empty();
private boolean useIndex = true;

/**
Expand Down Expand Up @@ -219,11 +223,14 @@ public Builder setRefineFactor(int refineFactor) {
/**
* Sets the distance metric type.
*
* <p>If not set, the query will use the index's metric type (if an index is available), or the
* default metric for the data type (L2 for float vectors, Hamming for binary).
*
* @param distanceType The DistanceType to use for the query.
* @return The Builder instance for method chaining.
*/
public Builder setDistanceType(DistanceType distanceType) {
this.distanceType = distanceType;
this.distanceType = Optional.ofNullable(distanceType);
return this;
}

Expand Down
59 changes: 7 additions & 52 deletions python/python/tests/test_vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ def test_ivf_flat_over_binary_vector(tmp_path):


def test_ivf_flat_respects_index_metric_binary(tmp_path):
# Binary vectors indexed with Hamming should ignore a user-specified L2 metric.
# Searching with binary vectors should default to hamming distance
table = pa.Table.from_pydict(
{
"vector": pa.array([[0], [128], [255]], type=pa.list_(pa.uint8(), 1)),
Expand All @@ -697,67 +697,22 @@ def test_ivf_flat_respects_index_metric_binary(tmp_path):

query = np.array([128], dtype=np.uint8)

# Search should succeed and use the index's Hamming metric despite the L2 hint.
indexed = ds.to_table(
# Search should succeed and use the index's Hamming metric.
indexed = ds.scanner(
columns=["id"],
nearest={
"column": "vector",
"q": query,
"k": 3,
"metric": "l2",
},
)
plan = indexed.explain_plan()
indexed = indexed.to_table()

# Should succeed even though user asked for L2 (index metric is used).
assert indexed["id"].to_pylist() == [1, 0, 2]


def test_ivf_flat_respects_index_metric_float(tmp_path):
# Float vectors indexed with L2 should ignore a user-specified Hamming metric.
vectors = np.array(
[
[0.0, 0.0],
[1.0, 0.0],
[0.0, 2.0],
],
dtype=np.float32,
)
table = pa.Table.from_pydict(
{
"vector": pa.array(vectors.tolist(), type=pa.list_(pa.float32(), 2)),
"id": pa.array([0, 1, 2], type=pa.int32()),
}
)

ds = lance.write_dataset(table, tmp_path)
ds = ds.create_index(
"vector",
index_type="IVF_FLAT",
num_partitions=1,
metric="l2",
)

query = np.array([0.5, 0.0], dtype=np.float32)

indexed = ds.to_table(
columns=["id"],
nearest={
"column": "vector",
"q": query,
"k": 3,
"metric": "hamming",
},
)

expected = ds.to_table(
columns=["id"],
nearest={"column": "vector", "q": query, "k": 3},
)

assert indexed["id"].to_pylist() == expected["id"].to_pylist()
assert np.allclose(
indexed["_distance"].to_numpy(), expected["_distance"].to_numpy()
)
assert "metric=Hamming" in plan
assert "metric=L2" not in plan


def test_bruteforce_uses_user_metric(tmp_path):
Expand Down
5 changes: 3 additions & 2 deletions rust/lance-index/src/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,9 @@ pub struct Query {
/// TODO: should we support fraction / float number here?
pub refine_factor: Option<u32>,

/// Distance metric type
pub metric_type: DistanceType,
/// Distance metric type. If None, uses the index's metric (if available)
/// or the default for the data type.
pub metric_type: Option<DistanceType>,

/// Whether to use an ANN index if available
pub use_index: bool,
Expand Down
Loading
Loading