diff --git a/CMakeLists.txt b/CMakeLists.txt index 65ebe03..da88eef 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -125,13 +125,14 @@ set(RUST_FFI_DEPENDS ${CMAKE_CURRENT_LIST_DIR}/rust/ffi/mod.rs ${CMAKE_CURRENT_LIST_DIR}/rust/ffi/projection.rs ${CMAKE_CURRENT_LIST_DIR}/rust/ffi/scan.rs - ${CMAKE_CURRENT_LIST_DIR}/rust/ffi/search.rs - ${CMAKE_CURRENT_LIST_DIR}/rust/ffi/stream.rs - ${CMAKE_CURRENT_LIST_DIR}/rust/ffi/types.rs - ${CMAKE_CURRENT_LIST_DIR}/rust/ffi/update.rs - ${CMAKE_CURRENT_LIST_DIR}/rust/ffi/util.rs - ${CMAKE_CURRENT_LIST_DIR}/rust/ffi/write.rs - ${CMAKE_CURRENT_LIST_DIR}/rust/lib.rs + ${CMAKE_CURRENT_LIST_DIR}/rust/ffi/search.rs + ${CMAKE_CURRENT_LIST_DIR}/rust/ffi/stream.rs + ${CMAKE_CURRENT_LIST_DIR}/rust/ffi/take.rs + ${CMAKE_CURRENT_LIST_DIR}/rust/ffi/types.rs + ${CMAKE_CURRENT_LIST_DIR}/rust/ffi/update.rs + ${CMAKE_CURRENT_LIST_DIR}/rust/ffi/util.rs + ${CMAKE_CURRENT_LIST_DIR}/rust/ffi/write.rs + ${CMAKE_CURRENT_LIST_DIR}/rust/lib.rs ${CMAKE_CURRENT_LIST_DIR}/rust/runtime.rs ${CMAKE_CURRENT_LIST_DIR}/rust/scanner.rs) diff --git a/rust/error.rs b/rust/error.rs index c89c848..a965257 100644 --- a/rust/error.rs +++ b/rust/error.rs @@ -60,6 +60,7 @@ pub enum ErrorCode { DatasetListIndices = 45, DatasetCreateScalarIndex = 46, DatasetCalculateDataStats = 47, + DatasetTake = 48, } struct LastError { diff --git a/rust/ffi/dataset.rs b/rust/ffi/dataset.rs index 3e44590..c2dfbac 100644 --- a/rust/ffi/dataset.rs +++ b/rust/ffi/dataset.rs @@ -3,11 +3,13 @@ use std::ffi::{c_char, c_void, CStr}; use std::ptr; use std::sync::Arc; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion_sql::unparser::expr_to_sql; use lance::dataset::statistics::DatasetStatisticsExt; use lance::dataset::builder::DatasetBuilder; use lance::Dataset; +use crate::constants::ROW_ID_COLUMN; use crate::error::{clear_last_error, set_last_error, ErrorCode}; use crate::runtime; @@ -201,6 +203,38 @@ fn get_schema_inner(dataset: *mut c_void) -> FfiResult *mut c_void { + match get_schema_for_scan_inner(dataset) { + Ok(schema) => { + clear_last_error(); + Box::into_raw(Box::new(schema)) as *mut c_void + } + Err(err) => { + set_last_error(err.code, err.message); + ptr::null_mut() + } + } +} + +fn get_schema_for_scan_inner(dataset: *mut c_void) -> FfiResult { + let handle = unsafe { super::util::dataset_handle(dataset)? }; + + let mut schema: Schema = (*handle.arrow_schema).clone(); + let has_row_id = schema.fields.iter().any(|f| f.name() == ROW_ID_COLUMN); + if !has_row_id { + let mut fields = schema.fields.iter().cloned().collect::>(); + fields.push(Arc::new(Field::new( + ROW_ID_COLUMN, + DataType::UInt64, + false, + ))); + schema.fields = fields.into(); + } + + Ok(Arc::new(schema)) +} + #[no_mangle] pub unsafe extern "C" fn lance_dataset_list_fragments( dataset: *mut c_void, diff --git a/rust/ffi/index.rs b/rust/ffi/index.rs index 9cf493a..c0d8edb 100644 --- a/rust/ffi/index.rs +++ b/rust/ffi/index.rs @@ -125,7 +125,12 @@ fn create_index_list_stream_inner(dataset: *mut c_void) -> FfiResult bool { ) } -fn build_vector_params(index_type: &str, params_json: Option<&str>) -> FfiResult { +fn build_vector_params( + index_type: &str, + params_json: Option<&str>, +) -> FfiResult { let mut params = serde_json::Map::::new(); if let Some(json) = params_json { let v: serde_json::Value = serde_json::from_str(json).map_err(|err| { @@ -459,19 +476,15 @@ fn build_vector_params(index_type: &str, params_json: Option<&str>) -> FfiResult .and_then(|v| v.as_str()) .unwrap_or("l2"); let metric_type = DistanceType::try_from(metric).map_err(|err| { - FfiError::new( - ErrorCode::DatasetCreateIndex, - format!("metric_type: {err}"), - ) + FfiError::new(ErrorCode::DatasetCreateIndex, format!("metric_type: {err}")) })?; let version = params .get("version") .and_then(|v| v.as_str()) .unwrap_or("v3"); - let version = IndexFileVersion::try_from(version).map_err(|err| { - FfiError::new(ErrorCode::DatasetCreateIndex, format!("version: {err}")) - })?; + let version = IndexFileVersion::try_from(version) + .map_err(|err| FfiError::new(ErrorCode::DatasetCreateIndex, format!("version: {err}")))?; let num_partitions = params .get("num_partitions") @@ -481,10 +494,7 @@ fn build_vector_params(index_type: &str, params_json: Option<&str>) -> FfiResult let out = match index_type { "IVF_FLAT" => VectorIndexParams::ivf_flat(num_partitions, metric_type), "IVF_PQ" => { - let num_bits = params - .get("num_bits") - .and_then(|v| v.as_u64()) - .unwrap_or(8) as u8; + let num_bits = params.get("num_bits").and_then(|v| v.as_u64()).unwrap_or(8) as u8; let num_sub_vectors = params .get("num_sub_vectors") .and_then(|v| v.as_u64()) @@ -502,17 +512,11 @@ fn build_vector_params(index_type: &str, params_json: Option<&str>) -> FfiResult ) } "IVF_RQ" => { - let num_bits = params - .get("num_bits") - .and_then(|v| v.as_u64()) - .unwrap_or(8) as u8; + let num_bits = params.get("num_bits").and_then(|v| v.as_u64()).unwrap_or(8) as u8; VectorIndexParams::ivf_rq(num_partitions, num_bits, metric_type) } "IVF_SQ" => { - let num_bits = params - .get("num_bits") - .and_then(|v| v.as_u64()) - .unwrap_or(8) as u8; + let num_bits = params.get("num_bits").and_then(|v| v.as_u64()).unwrap_or(8) as u8; let ivf = lance_index::vector::ivf::IvfBuildParams::new(num_partitions); let sq = lance_index::vector::sq::builder::SQBuildParams { num_bits: num_bits as u16, diff --git a/rust/ffi/mod.rs b/rust/ffi/mod.rs index 1cadd8e..d8fb835 100644 --- a/rust/ffi/mod.rs +++ b/rust/ffi/mod.rs @@ -6,9 +6,10 @@ mod knn; mod namespace; mod projection; mod scan; -mod search; mod schema_evolution; +mod search; mod stream; +mod take; mod types; mod update; mod util; diff --git a/rust/ffi/scan.rs b/rust/ffi/scan.rs index 8e65813..5ccdb68 100644 --- a/rust/ffi/scan.rs +++ b/rust/ffi/scan.rs @@ -4,6 +4,7 @@ use std::ptr; use crate::error::{clear_last_error, set_last_error, ErrorCode}; use crate::runtime; use crate::scanner::LanceStream; +use crate::constants::ROW_ID_COLUMN; use super::types::StreamHandle; use super::util::{ @@ -63,6 +64,9 @@ fn create_fragment_stream_ir_inner( let projection = unsafe { optional_cstr_array(columns, columns_len, "columns")? }; if !projection.is_empty() { + if projection.iter().any(|c| c == ROW_ID_COLUMN) { + scan.with_row_id(); + } scan.project(&projection).map_err(|err| { FfiError::new( ErrorCode::FragmentScan, @@ -147,6 +151,9 @@ fn create_dataset_stream_ir_inner( let projection = unsafe { optional_cstr_array(columns, columns_len, "columns")? }; if !projection.is_empty() { + if projection.iter().any(|c| c == ROW_ID_COLUMN) { + scan.with_row_id(); + } scan.project(&projection).map_err(|err| { FfiError::new( ErrorCode::DatasetScan, @@ -171,10 +178,7 @@ fn create_dataset_stream_ir_inner( let limit_opt = if limit == -1 { None } else { Some(limit) }; let offset_opt = if offset == 0 { None } else { Some(offset) }; scan.limit(limit_opt, offset_opt).map_err(|err| { - FfiError::new( - ErrorCode::DatasetScan, - format!("dataset scan limit: {err}"), - ) + FfiError::new(ErrorCode::DatasetScan, format!("dataset scan limit: {err}")) })?; } diff --git a/rust/ffi/schema_evolution.rs b/rust/ffi/schema_evolution.rs index 012f33c..a671725 100644 --- a/rust/ffi/schema_evolution.rs +++ b/rust/ffi/schema_evolution.rs @@ -56,7 +56,13 @@ pub unsafe extern "C" fn lance_dataset_add_columns( expressions_len: usize, batch_size: u32, ) -> i32 { - match dataset_add_columns_inner(dataset, new_columns_schema, expressions, expressions_len, batch_size) { + match dataset_add_columns_inner( + dataset, + new_columns_schema, + expressions, + expressions_len, + batch_size, + ) { Ok(()) => { clear_last_error(); 0 @@ -185,13 +191,12 @@ fn dataset_add_columns_inner( ) })?; let arr = if arr.data_type() != field.data_type() { - compute::cast(&arr, field.data_type()) - .map_err(|err| { - lance::Error::invalid_input( - format!("expression[{idx}] cast: {err}"), - location!(), - ) - })? + compute::cast(&arr, field.data_type()).map_err(|err| { + lance::Error::invalid_input( + format!("expression[{idx}] cast: {err}"), + location!(), + ) + })? } else { arr }; @@ -422,9 +427,11 @@ fn dataset_update_table_metadata_inner( let value = if value.is_null() { None } else { - Some(unsafe { CStr::from_ptr(value) }.to_str().map_err(|err| { - FfiError::new(ErrorCode::Utf8, format!("value utf8: {err}")) - })?) + Some( + unsafe { CStr::from_ptr(value) } + .to_str() + .map_err(|err| FfiError::new(ErrorCode::Utf8, format!("value utf8: {err}")))?, + ) }; let mut ds = (*handle.dataset).clone(); @@ -467,9 +474,11 @@ fn dataset_update_config_inner( let value = if value.is_null() { None } else { - Some(unsafe { CStr::from_ptr(value) }.to_str().map_err(|err| { - FfiError::new(ErrorCode::Utf8, format!("value utf8: {err}")) - })?) + Some( + unsafe { CStr::from_ptr(value) } + .to_str() + .map_err(|err| FfiError::new(ErrorCode::Utf8, format!("value utf8: {err}")))?, + ) }; let mut ds = (*handle.dataset).clone(); @@ -512,9 +521,11 @@ fn dataset_update_schema_metadata_inner( let value = if value.is_null() { None } else { - Some(unsafe { CStr::from_ptr(value) }.to_str().map_err(|err| { - FfiError::new(ErrorCode::Utf8, format!("value utf8: {err}")) - })?) + Some( + unsafe { CStr::from_ptr(value) } + .to_str() + .map_err(|err| FfiError::new(ErrorCode::Utf8, format!("value utf8: {err}")))?, + ) }; let mut ds = (*handle.dataset).clone(); @@ -560,9 +571,11 @@ fn dataset_update_field_metadata_inner( let value = if value.is_null() { None } else { - Some(unsafe { CStr::from_ptr(value) }.to_str().map_err(|err| { - FfiError::new(ErrorCode::Utf8, format!("value utf8: {err}")) - })?) + Some( + unsafe { CStr::from_ptr(value) } + .to_str() + .map_err(|err| FfiError::new(ErrorCode::Utf8, format!("value utf8: {err}")))?, + ) }; let mut ds = (*handle.dataset).clone(); @@ -770,12 +783,7 @@ fn dataset_list_kv_inner(dataset: *mut c_void, which: &'static str) -> FfiResult out.push('\n'); } } - _ => { - return Err(FfiError::new( - ErrorCode::InvalidArgument, - "unknown kv type", - )) - } + _ => return Err(FfiError::new(ErrorCode::InvalidArgument, "unknown kv type")), } Ok(out) diff --git a/rust/ffi/take.rs b/rust/ffi/take.rs new file mode 100644 index 0000000..c2eb081 --- /dev/null +++ b/rust/ffi/take.rs @@ -0,0 +1,95 @@ +use std::ffi::{c_char, c_void}; +use std::ptr; + +use lance::dataset::ProjectionRequest; + +use crate::error::{clear_last_error, set_last_error, ErrorCode}; +use crate::runtime; + +use super::types::StreamHandle; +use super::util::{optional_cstr_array, slice_from_ptr, FfiError, FfiResult}; + +#[no_mangle] +pub unsafe extern "C" fn lance_create_dataset_take_stream( + dataset: *mut c_void, + row_ids: *const u64, + row_ids_len: usize, + columns: *const *const c_char, + columns_len: usize, +) -> *mut c_void { + match create_dataset_take_stream_inner(dataset, row_ids, row_ids_len, columns, columns_len) { + Ok(stream) => { + clear_last_error(); + Box::into_raw(Box::new(stream)) as *mut c_void + } + Err(err) => { + set_last_error(err.code, err.message); + ptr::null_mut() + } + } +} + +fn create_dataset_take_stream_inner( + dataset: *mut c_void, + row_ids: *const u64, + row_ids_len: usize, + columns: *const *const c_char, + columns_len: usize, +) -> FfiResult { + let handle = unsafe { super::util::dataset_handle(dataset)? }; + + let row_ids = if row_ids_len == 0 { + &[][..] + } else { + unsafe { slice_from_ptr(row_ids, row_ids_len, "row_ids")? } + }; + let row_ids_filtered; + let row_ids = if row_ids.is_empty() { + row_ids + } else { + let max_row_id = if handle.dataset.manifest.uses_stable_row_ids() { + handle.dataset.manifest.next_row_id + } else { + handle + .dataset + .manifest + .fragments + .iter() + .map(|fragment| fragment.num_rows().unwrap_or_default() as u64) + .sum::() + }; + if row_ids.iter().all(|id| *id < max_row_id) { + row_ids + } else { + row_ids_filtered = row_ids + .iter() + .copied() + .filter(|id| *id < max_row_id) + .collect::>(); + row_ids_filtered.as_slice() + } + }; + + let projection_cols = unsafe { optional_cstr_array(columns, columns_len, "columns")? }; + let projection = if projection_cols.is_empty() { + ProjectionRequest::from_schema(handle.dataset.schema().clone()) + } else { + ProjectionRequest::from_columns( + projection_cols.iter().map(|s| s.as_str()), + handle.dataset.schema(), + ) + }; + + let batch = match runtime::block_on(handle.dataset.take_rows(row_ids, projection)) { + Ok(Ok(batch)) => batch, + Ok(Err(err)) => { + return Err(FfiError::new( + ErrorCode::DatasetTake, + format!("dataset take_rows: {err}"), + )) + } + Err(err) => return Err(FfiError::new(ErrorCode::Runtime, format!("runtime: {err}"))), + }; + + Ok(StreamHandle::Batches(vec![batch].into_iter())) +} diff --git a/src/include/lance_ffi.hpp b/src/include/lance_ffi.hpp index fec2afb..b1c6ecc 100644 --- a/src/include/lance_ffi.hpp +++ b/src/include/lance_ffi.hpp @@ -33,6 +33,7 @@ void *lance_open_dataset_in_namespace( void lance_close_dataset(void *dataset); void *lance_get_schema(void *dataset); +void *lance_get_schema_for_scan(void *dataset); void lance_free_schema(void *schema); int32_t lance_schema_to_arrow(void *schema, ArrowSchema *out_schema); @@ -114,6 +115,9 @@ void *lance_create_dataset_stream_ir(void *dataset, const char **columns, const uint8_t *filter_ir, size_t filter_ir_len, int64_t limit, int64_t offset); +void *lance_create_dataset_take_stream(void *dataset, const uint64_t *row_ids, + size_t row_ids_len, const char **columns, + size_t columns_len); void *lance_open_writer_with_storage_options( const char *path, const char *mode, const char **option_keys, diff --git a/src/include/lance_scan_bind_data.hpp b/src/include/lance_scan_bind_data.hpp index 4255462..82b989a 100644 --- a/src/include/lance_scan_bind_data.hpp +++ b/src/include/lance_scan_bind_data.hpp @@ -12,12 +12,16 @@ struct LanceScanBindData : public TableFunctionData { bool explain_verbose = false; void *dataset = nullptr; ArrowSchemaWrapper schema_root; + ArrowSchemaWrapper scan_schema_root; ArrowTableSchema arrow_table; + ArrowTableSchema scan_arrow_table; vector names; vector types; vector lance_pushed_filter_ir_parts; vector duckdb_pushed_filter_sql_parts; + vector take_row_ids; + bool limit_offset_pushed_down = false; optional_idx pushed_limit = optional_idx::Invalid(); idx_t pushed_offset = 0; diff --git a/src/lance_scan.cpp b/src/lance_scan.cpp index 0291155..04b1b48 100644 --- a/src/lance_scan.cpp +++ b/src/lance_scan.cpp @@ -1,4 +1,5 @@ #include "duckdb.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/common/arrow/arrow.hpp" #include "duckdb/common/arrow/arrow_converter.hpp" #include "duckdb/common/exception.hpp" @@ -11,9 +12,13 @@ #include "duckdb/parser/constraints/not_null_constraint.hpp" #include "duckdb/optimizer/optimizer_extension.hpp" #include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" #include "duckdb/planner/operator/logical_limit.hpp" #include "duckdb/planner/operator/logical_projection.hpp" #include "duckdb/planner/table_filter.hpp" +#include "duckdb/planner/filter/constant_filter.hpp" +#include "duckdb/planner/filter/in_filter.hpp" +#include "duckdb/planner/filter/optional_filter.hpp" #include "duckdb/parser/expression/cast_expression.hpp" #include "duckdb/parser/expression/columnref_expression.hpp" #include "duckdb/parser/parsed_data/alter_table_info.hpp" @@ -21,6 +26,11 @@ #include "duckdb/parser/parsed_data/create_table_info.hpp" #include "duckdb/parser/parsed_data/create_table_function_info.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_operator_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/expression_iterator.hpp" #include "lance_common.hpp" @@ -86,6 +96,22 @@ LanceScanCardinality(ClientContext &context, const FunctionData *bind_data_p) { return make_uniq(count, count); } +static bool LancePushdownExpression(ClientContext &, const LogicalGet &, + Expression &expr) { + if (expr.GetExpressionClass() != ExpressionClass::BOUND_FUNCTION) { + return true; + } + auto &func = expr.Cast(); + auto &name = func.function.name; + // Keep LIKE/ILIKE as expressions so we can build and surface Lance Filter IR + // (instead of letting DuckDB rewrite them into TableFilters). + if (name == "~~" || name == "~~*" || name == "like_escape" || + name == "ilike_escape") { + return false; + } + return true; +} + static vector LanceScanGetPartitionStats(ClientContext &context, GetPartitionStatsInput &input) { @@ -116,6 +142,27 @@ LanceScanBindData::~LanceScanBindData() { } } +static constexpr column_t LANCE_COLUMN_IDENTIFIER_ROW_ID = + UINT64_C(9223372036854775900); + +static constexpr const char *LANCE_ROW_ID_COLUMN_NAME = "_rowid"; + +static bool IsLanceVirtualRowIdColumnId(column_t col_id) { + return col_id == LANCE_COLUMN_IDENTIFIER_ROW_ID; +} + +static virtual_column_map_t LanceGetVirtualColumns(ClientContext &, + optional_ptr) { + virtual_column_map_t result; + result.emplace(COLUMN_IDENTIFIER_ROW_ID, + TableColumn("rowid", LogicalType::ROW_TYPE)); + result.emplace(COLUMN_IDENTIFIER_EMPTY, + TableColumn("", LogicalType::BOOLEAN)); + result.emplace(LANCE_COLUMN_IDENTIFIER_ROW_ID, + TableColumn(LANCE_ROW_ID_COLUMN_NAME, LogicalType::UBIGINT)); + return result; +} + static bool TryLanceExplainDatasetScan(void *dataset, const vector *columns, const string *filter_ir, @@ -175,16 +222,20 @@ struct LanceScanGlobalState : public GlobalTableFunctionState { std::atomic filter_pushdown_fallbacks{0}; bool use_dataset_scanner = false; + bool use_dataset_take = false; + bool scan_includes_virtual_rowid = false; bool limit_offset_pushed_down = false; optional_idx pushed_limit = optional_idx::Invalid(); idx_t pushed_offset = 0; + vector take_row_ids; vector fragment_ids; idx_t max_threads = 1; vector projection_ids; vector scanned_types; + vector scan_column_ids; vector scan_column_names; string lance_filter_ir; bool filter_pushed_down = false; @@ -230,6 +281,116 @@ static bool LanceSupportsPushdownType(const FunctionData &bind_data, return LanceFilterIRSupportsLogicalType(scan_bind.types[col_idx]); } +static bool TryParseRowIdValue(const Value &value, uint64_t &out) { + if (value.IsNull()) { + return false; + } + + auto cast_signed = [&out](int64_t v) { + if (v < 0) { + return false; + } + out = NumericCast(v); + return true; + }; + + switch (value.type().id()) { + case LogicalTypeId::TINYINT: + return cast_signed(value.GetValue()); + case LogicalTypeId::SMALLINT: + return cast_signed(value.GetValue()); + case LogicalTypeId::INTEGER: + return cast_signed(value.GetValue()); + case LogicalTypeId::BIGINT: + return cast_signed(value.GetValue()); + case LogicalTypeId::UTINYINT: + out = value.GetValue(); + return true; + case LogicalTypeId::USMALLINT: + out = value.GetValue(); + return true; + case LogicalTypeId::UINTEGER: + out = value.GetValue(); + return true; + case LogicalTypeId::UBIGINT: + out = value.GetValue(); + return true; + default: + return false; + } +} + +static bool TryExtractTakeRowIdsFromFilter(const TableFilter &filter, + vector &out_row_ids) { + switch (filter.filter_type) { + case TableFilterType::IN_FILTER: { + auto &in = filter.Cast(); + out_row_ids.reserve(out_row_ids.size() + in.values.size()); + for (auto &v : in.values) { + uint64_t row_id = 0; + if (!TryParseRowIdValue(v, row_id)) { + throw InvalidInputException( + "Lance point lookup requires integer _rowid values"); + } + out_row_ids.push_back(row_id); + } + return true; + } + case TableFilterType::CONSTANT_COMPARISON: { + auto &cmp = filter.Cast(); + if (cmp.comparison_type != ExpressionType::COMPARE_EQUAL) { + return false; + } + uint64_t row_id = 0; + if (!TryParseRowIdValue(cmp.constant, row_id)) { + throw InvalidInputException( + "Lance point lookup requires integer _rowid values"); + } + out_row_ids.push_back(row_id); + return true; + } + case TableFilterType::OPTIONAL_FILTER: { + auto &opt = filter.Cast(); + if (!opt.child_filter) { + return false; + } + return TryExtractTakeRowIdsFromFilter(*opt.child_filter, out_row_ids); + } + default: + return false; + } +} + +static bool IsTakeRowIdFilter(const TableFilter &filter) { + switch (filter.filter_type) { + case TableFilterType::IN_FILTER: + return true; + case TableFilterType::CONSTANT_COMPARISON: { + auto &cmp = filter.Cast(); + return cmp.comparison_type == ExpressionType::COMPARE_EQUAL; + } + case TableFilterType::OPTIONAL_FILTER: { + auto &opt = filter.Cast(); + if (!opt.child_filter) { + return false; + } + return IsTakeRowIdFilter(*opt.child_filter); + } + default: + return false; + } +} + +static bool TryExtractTakeRowIdsFromFilters(const TableFilterSet &filters, + const idx_t row_id_col_idx, + vector &out_row_ids) { + auto it = filters.filters.find(row_id_col_idx); + if (it == filters.filters.end() || !it->second) { + return false; + } + return TryExtractTakeRowIdsFromFilter(*it->second, out_row_ids); +} + static void LancePushdownComplexFilter(ClientContext &context, LogicalGet &get, FunctionData *bind_data, @@ -270,11 +431,117 @@ LancePushdownComplexFilter(ClientContext &context, LogicalGet &get, get_column_names.push_back(scan_bind.names[col_id]); } + auto is_rowid_ref = [&](const BoundColumnRefExpression &colref) { + if (colref.binding.table_index != get.table_index || + colref.binding.column_index >= col_ids.size()) { + return false; + } + return IsLanceVirtualRowIdColumnId( + col_ids[colref.binding.column_index].GetPrimaryIndex()); + }; + + std::function &)> + try_extract_rowids; + try_extract_rowids = [&](const Expression &expr, vector &out) { + if (expr.GetExpressionType() == ExpressionType::COMPARE_IN && + expr.GetExpressionClass() == ExpressionClass::BOUND_OPERATOR) { + auto &op = expr.Cast(); + if (op.children.size() <= 1 || !op.children[0] || + op.children[0]->GetExpressionClass() != + ExpressionClass::BOUND_COLUMN_REF) { + return false; + } + auto &colref = op.children[0]->Cast(); + if (!is_rowid_ref(colref)) { + return false; + } + out.clear(); + out.reserve(op.children.size() - 1); + for (idx_t i = 1; i < op.children.size(); i++) { + auto &child = op.children[i]; + if (!child || + child->GetExpressionClass() != ExpressionClass::BOUND_CONSTANT) { + return false; + } + uint64_t row_id = 0; + auto &v = child->Cast().value; + if (!TryParseRowIdValue(v, row_id)) { + throw InvalidInputException( + "Lance point lookup requires integer _rowid values"); + } + out.push_back(row_id); + } + return true; + } + + if (expr.GetExpressionType() == ExpressionType::COMPARE_EQUAL && + expr.GetExpressionClass() == ExpressionClass::BOUND_COMPARISON) { + auto &cmp = expr.Cast(); + if (!cmp.left || !cmp.right) { + return false; + } + auto extract_one = [&](const Expression &lhs, + const Expression &rhs) -> bool { + if (lhs.GetExpressionClass() != ExpressionClass::BOUND_COLUMN_REF || + rhs.GetExpressionClass() != ExpressionClass::BOUND_CONSTANT) { + return false; + } + auto &colref = lhs.Cast(); + if (!is_rowid_ref(colref)) { + return false; + } + uint64_t row_id = 0; + auto &v = rhs.Cast().value; + if (!TryParseRowIdValue(v, row_id)) { + throw InvalidInputException( + "Lance point lookup requires integer _rowid values"); + } + out.clear(); + out.push_back(row_id); + return true; + }; + if (extract_one(*cmp.left, *cmp.right)) { + return true; + } + if (extract_one(*cmp.right, *cmp.left)) { + return true; + } + return false; + } + + if (expr.GetExpressionType() == ExpressionType::CONJUNCTION_OR && + expr.GetExpressionClass() == ExpressionClass::BOUND_CONJUNCTION) { + auto &conj = expr.Cast(); + vector child_out; + out.clear(); + out.reserve(conj.children.size()); + for (auto &child : conj.children) { + if (!child) { + return false; + } + child_out.clear(); + if (!try_extract_rowids(*child, child_out) || child_out.empty()) { + return false; + } + out.insert(out.end(), child_out.begin(), child_out.end()); + } + return !out.empty(); + } + + return false; + }; + for (auto &expr : filters) { if (!expr || expr->HasParameter() || expr->IsVolatile() || expr->CanThrow()) { continue; } + if (scan_bind.take_row_ids.empty()) { + vector take_row_ids; + if (try_extract_rowids(*expr, take_row_ids) && !take_row_ids.empty()) { + scan_bind.take_row_ids = std::move(take_row_ids); + } + } string filter_ir; if (!TryBuildLanceExprFilterIR(get, scan_bind.names, scan_bind.types, false, *expr, filter_ir)) { @@ -350,6 +617,24 @@ static unique_ptr LanceScanBind(ClientContext &context, config, result->arrow_table, result->schema_root.arrow_schema); result->names = result->arrow_table.GetNames(); result->types = result->arrow_table.GetTypes(); + + auto *scan_schema_handle = lance_get_schema_for_scan(result->dataset); + if (!scan_schema_handle) { + throw IOException("Failed to get scan schema from Lance dataset: " + + result->file_path + LanceFormatErrorSuffix()); + } + memset(&result->scan_schema_root.arrow_schema, 0, + sizeof(result->scan_schema_root.arrow_schema)); + if (lance_schema_to_arrow(scan_schema_handle, + &result->scan_schema_root.arrow_schema) != 0) { + lance_free_schema(scan_schema_handle); + throw IOException( + "Failed to export Lance scan schema to Arrow C Data Interface" + + LanceFormatErrorSuffix()); + } + lance_free_schema(scan_schema_handle); + ArrowTableFunction::PopulateArrowTableSchema( + config, result->scan_arrow_table, result->scan_schema_root.arrow_schema); names = result->names; return_types = result->types; return std::move(result); @@ -424,6 +709,25 @@ LanceNamespaceScanBind(ClientContext &context, TableFunctionBindInput &input, config, result->arrow_table, result->schema_root.arrow_schema); result->names = result->arrow_table.GetNames(); result->types = result->arrow_table.GetTypes(); + + auto *scan_schema_handle = lance_get_schema_for_scan(result->dataset); + if (!scan_schema_handle) { + throw IOException( + "Failed to get scan schema from Lance dataset via namespace: " + + result->file_path + LanceFormatErrorSuffix()); + } + memset(&result->scan_schema_root.arrow_schema, 0, + sizeof(result->scan_schema_root.arrow_schema)); + if (lance_schema_to_arrow(scan_schema_handle, + &result->scan_schema_root.arrow_schema) != 0) { + lance_free_schema(scan_schema_handle); + throw IOException( + "Failed to export Lance scan schema to Arrow C Data Interface" + + LanceFormatErrorSuffix()); + } + lance_free_schema(scan_schema_handle); + ArrowTableFunction::PopulateArrowTableSchema( + config, result->scan_arrow_table, result->scan_schema_root.arrow_schema); names = result->names; return_types = result->types; return std::move(result); @@ -440,6 +744,7 @@ LanceScanInitGlobal(ClientContext &context, TableFunctionInitInput &input) { scan_state.pushed_offset = bind_data.pushed_offset; scan_state.projection_ids = input.projection_ids; + auto rowid_internal_index = NumericCast(bind_data.types.size()); if (!input.projection_ids.empty()) { scan_state.scanned_types.reserve(input.column_ids.size()); for (auto col_id : input.column_ids) { @@ -447,6 +752,11 @@ LanceScanInitGlobal(ClientContext &context, TableFunctionInitInput &input) { col_id == COLUMN_IDENTIFIER_EMPTY) { continue; } + if (IsLanceVirtualRowIdColumnId(col_id)) { + scan_state.scan_includes_virtual_rowid = true; + scan_state.scanned_types.push_back(LogicalType::UBIGINT); + continue; + } if (col_id >= bind_data.types.size()) { throw IOException("Invalid column id in projection"); } @@ -460,9 +770,16 @@ LanceScanInitGlobal(ClientContext &context, TableFunctionInitInput &input) { col_id == COLUMN_IDENTIFIER_EMPTY) { continue; } + if (IsLanceVirtualRowIdColumnId(col_id)) { + scan_state.scan_includes_virtual_rowid = true; + scan_state.scan_column_ids.push_back(rowid_internal_index); + scan_state.scan_column_names.push_back(LANCE_ROW_ID_COLUMN_NAME); + continue; + } if (col_id >= bind_data.names.size()) { throw IOException("Invalid column id in projection"); } + scan_state.scan_column_ids.push_back(col_id); scan_state.scan_column_names.push_back(bind_data.names[col_id]); } @@ -488,6 +805,42 @@ LanceScanInitGlobal(ClientContext &context, TableFunctionInitInput &input) { scan_state.filter_pushed_down = table_filters.all_filters_pushed && !scan_state.lance_filter_ir.empty(); + if (!bind_data.take_row_ids.empty()) { + if (bind_data.limit_offset_pushed_down) { + throw IOException( + "Lance point lookup does not support limit/offset pushdown"); + } + scan_state.lance_filter_ir.clear(); + scan_state.filter_pushed_down = false; + scan_state.use_dataset_scanner = true; + scan_state.use_dataset_take = true; + scan_state.max_threads = 1; + scan_state.take_row_ids = bind_data.take_row_ids; + return state; + } + + idx_t row_id_col_idx = DConstants::INVALID_INDEX; + for (idx_t i = 0; i < bind_data.names.size(); i++) { + if (bind_data.names[i] == "_rowid") { + row_id_col_idx = i; + break; + } + } + if (row_id_col_idx != DConstants::INVALID_INDEX && input.filters && + TryExtractTakeRowIdsFromFilters(*input.filters, row_id_col_idx, + scan_state.take_row_ids)) { + if (bind_data.limit_offset_pushed_down) { + throw IOException( + "Lance point lookup does not support limit/offset pushdown"); + } + scan_state.lance_filter_ir.clear(); + scan_state.filter_pushed_down = false; + scan_state.use_dataset_scanner = true; + scan_state.use_dataset_take = true; + scan_state.max_threads = 1; + return state; + } + if (scan_state.scan_column_names.empty() && scan_state.lance_filter_ir.empty()) { auto rows = lance_dataset_count_rows(bind_data.dataset); @@ -496,7 +849,22 @@ LanceScanInitGlobal(ClientContext &context, TableFunctionInitInput &input) { LanceFormatErrorSuffix()); } scan_state.count_only = true; - scan_state.count_only_total_rows = NumericCast(rows); + auto total_rows = NumericCast(rows); + if (scan_state.limit_offset_pushed_down) { + if (scan_state.pushed_offset >= total_rows) { + scan_state.count_only_total_rows = 0; + } else { + auto remaining = total_rows - scan_state.pushed_offset; + if (scan_state.pushed_limit.IsValid()) { + scan_state.count_only_total_rows = + MinValue(remaining, scan_state.pushed_limit.GetIndex()); + } else { + scan_state.count_only_total_rows = remaining; + } + } + } else { + scan_state.count_only_total_rows = total_rows; + } scan_state.max_threads = 1; return state; } @@ -537,7 +905,7 @@ LanceScanLocalInit(ExecutionContext &context, TableFunctionInitInput &input, auto chunk = make_uniq(); auto result = make_uniq(std::move(chunk), context.client); - result->column_ids = input.column_ids; + result->column_ids = scan_global.scan_column_ids; result->filters = input.filters.get(); result->global_state = &scan_global; result->filter_pushed_down = scan_global.filter_pushed_down; @@ -583,7 +951,14 @@ static bool LanceScanOpenStream(ClientContext &context, global_state.filter_pushed_down && filter_ir && filter_ir_len > 0; void *stream = nullptr; - if (global_state.use_dataset_scanner) { + if (global_state.use_dataset_take) { + auto row_ids_ptr = global_state.take_row_ids.empty() + ? nullptr + : global_state.take_row_ids.data(); + stream = lance_create_dataset_take_stream(bind_data.dataset, row_ids_ptr, + global_state.take_row_ids.size(), + columns.data(), columns.size()); + } else if (global_state.use_dataset_scanner) { auto limit_i64 = global_state.pushed_limit.IsValid() ? NumericCast(global_state.pushed_limit.GetIndex()) @@ -664,6 +1039,39 @@ static bool LanceScanLoadNextBatch(LanceScanLocalState &local_state) { lance_free_batch(batch); + if (local_state.global_state && tmp_schema.n_children > 0 && + new_chunk->arrow_array.n_children == tmp_schema.n_children && + !local_state.global_state->scan_column_names.empty()) { + unordered_map idx_by_name; + idx_t child_count = NumericCast(tmp_schema.n_children); + idx_by_name.reserve(child_count); + for (idx_t i = 0; i < child_count; i++) { + auto *child_schema = tmp_schema.children[i]; + if (!child_schema || !child_schema->name) { + continue; + } + idx_by_name.emplace(child_schema->name, i); + } + + auto expected_count = + NumericCast(local_state.global_state->scan_column_names.size()); + vector old_children; + old_children.reserve(child_count); + for (idx_t i = 0; i < child_count; i++) { + old_children.push_back(new_chunk->arrow_array.children[i]); + } + + for (idx_t i = 0; i < expected_count; i++) { + auto &expected = local_state.global_state->scan_column_names[i]; + auto it = idx_by_name.find(expected); + if (it == idx_by_name.end()) { + throw IOException("Missing expected column in Arrow batch: " + + expected); + } + new_chunk->arrow_array.children[i] = old_children[it->second]; + } + } + if (local_state.global_state) { local_state.global_state->record_batches.fetch_add(1); auto rows = NumericCast(new_chunk->arrow_array.length); @@ -688,6 +1096,9 @@ static void LanceScanFunc(ClientContext &context, TableFunctionInput &data, auto &bind_data = data.bind_data->Cast(); auto &global_state = data.global_state->Cast(); auto &local_state = data.local_state->Cast(); + auto &arrow_columns = global_state.scan_includes_virtual_rowid + ? bind_data.scan_arrow_table.GetColumns() + : bind_data.arrow_table.GetColumns(); if (global_state.count_only) { auto start = global_state.count_only_offset.fetch_add(STANDARD_VECTOR_SIZE); @@ -728,8 +1139,7 @@ static void LanceScanFunc(ClientContext &context, TableFunctionInput &data, if (global_state.CanRemoveFilterColumns()) { local_state.all_columns.Reset(); local_state.all_columns.SetCardinality(output_size); - ArrowTableFunction::ArrowToDuckDB(local_state, - bind_data.arrow_table.GetColumns(), + ArrowTableFunction::ArrowToDuckDB(local_state, arrow_columns, local_state.all_columns, start); local_state.chunk_offset += output_size; if (local_state.filters && !local_state.filter_pushed_down) { @@ -741,8 +1151,8 @@ static void LanceScanFunc(ClientContext &context, TableFunctionInput &data, output.SetCardinality(local_state.all_columns); } else { output.SetCardinality(output_size); - ArrowTableFunction::ArrowToDuckDB( - local_state, bind_data.arrow_table.GetColumns(), output, start); + ArrowTableFunction::ArrowToDuckDB(local_state, arrow_columns, output, + start); local_state.chunk_offset += output_size; if (local_state.filters && !local_state.filter_pushed_down) { ApplyDuckDBFilters(context, *local_state.filters, output, @@ -904,6 +1314,243 @@ static bool IsLanceScanTableFunction(const TableFunction &fn) { fn.name == "__lance_namespace_scan"; } +static bool IsLanceRowIdColumn(const LogicalGet &get, + const LanceScanBindData &scan_bind, + const BoundColumnRefExpression &colref) { + if (colref.binding.table_index != get.table_index) { + return false; + } + auto &col_ids = get.GetColumnIds(); + if (colref.binding.column_index >= col_ids.size()) { + return false; + } + + auto col_id = col_ids[colref.binding.column_index].GetPrimaryIndex(); + (void)scan_bind; + return IsLanceVirtualRowIdColumnId(col_id); +} + +static bool IsLanceRowIdColumn(const LogicalGet &get, + const LanceScanBindData &scan_bind, + const BoundReferenceExpression &ref) { + auto &col_ids = get.GetColumnIds(); + auto idx = NumericCast(ref.index); + if (idx >= col_ids.size()) { + return false; + } + + auto col_id = col_ids[idx].GetPrimaryIndex(); + (void)scan_bind; + return IsLanceVirtualRowIdColumnId(col_id); +} + +static bool TryExtractTakeRowIdsFromExpression( + const LogicalGet &get, const LanceScanBindData &scan_bind, + const Expression &expr, vector &out_row_ids) { + auto is_row_id = [&](const Expression &candidate) { + switch (candidate.GetExpressionClass()) { + case ExpressionClass::BOUND_COLUMN_REF: + return IsLanceRowIdColumn(get, scan_bind, + candidate.Cast()); + case ExpressionClass::BOUND_REF: + return IsLanceRowIdColumn(get, scan_bind, + candidate.Cast()); + default: + return false; + } + }; + + auto parse_row_id = [&](const Value &v, uint64_t &out) { + if (!TryParseRowIdValue(v, out)) { + throw InvalidInputException( + "Lance point lookup requires integer _rowid values"); + } + return true; + }; + + if (expr.GetExpressionType() == ExpressionType::COMPARE_IN && + expr.GetExpressionClass() == ExpressionClass::BOUND_OPERATOR) { + auto &op = expr.Cast(); + if (op.children.size() <= 1 || !op.children[0] || + !is_row_id(*op.children[0])) { + return false; + } + + out_row_ids.clear(); + out_row_ids.reserve(op.children.size() - 1); + for (idx_t i = 1; i < op.children.size(); i++) { + auto &child = op.children[i]; + if (!child || + child->GetExpressionClass() != ExpressionClass::BOUND_CONSTANT) { + return false; + } + uint64_t row_id = 0; + parse_row_id(child->Cast().value, row_id); + out_row_ids.push_back(row_id); + } + return true; + } + + auto try_extract_equal = [&](const Expression &candidate, uint64_t &out) { + if (candidate.GetExpressionType() != ExpressionType::COMPARE_EQUAL || + candidate.GetExpressionClass() != ExpressionClass::BOUND_COMPARISON) { + return false; + } + auto &cmp = candidate.Cast(); + if (!cmp.left || !cmp.right) { + return false; + } + + if (is_row_id(*cmp.left) && + cmp.right->GetExpressionClass() == ExpressionClass::BOUND_CONSTANT) { + parse_row_id(cmp.right->Cast().value, out); + return true; + } + if (is_row_id(*cmp.right) && + cmp.left->GetExpressionClass() == ExpressionClass::BOUND_CONSTANT) { + parse_row_id(cmp.left->Cast().value, out); + return true; + } + return false; + }; + + if (expr.GetExpressionType() == ExpressionType::CONJUNCTION_OR && + expr.GetExpressionClass() == ExpressionClass::BOUND_CONJUNCTION) { + out_row_ids.clear(); + vector stack; + stack.push_back(&expr); + while (!stack.empty()) { + auto *current = stack.back(); + stack.pop_back(); + if (!current) { + continue; + } + if (current->GetExpressionType() == ExpressionType::CONJUNCTION_OR && + current->GetExpressionClass() == ExpressionClass::BOUND_CONJUNCTION) { + auto &conj = current->Cast(); + // Preserve left-to-right evaluation order. + for (auto it = conj.children.rbegin(); it != conj.children.rend(); + ++it) { + stack.push_back(it->get()); + } + continue; + } + uint64_t row_id = 0; + if (!try_extract_equal(*current, row_id)) { + return false; + } + out_row_ids.push_back(row_id); + } + return !out_row_ids.empty(); + } + + return false; +} + +static unique_ptr +LanceRowIdInRewrite(unique_ptr op) { + for (auto &child : op->children) { + child = LanceRowIdInRewrite(std::move(child)); + } + + if (op->type == LogicalOperatorType::LOGICAL_GET) { + auto &get = op->Cast(); + if (!IsLanceScanTableFunction(get.function) || !get.bind_data) { + return op; + } + auto &scan_bind = get.bind_data->Cast(); + + auto it = get.table_filters.filters.find(LANCE_COLUMN_IDENTIFIER_ROW_ID); + if (it == get.table_filters.filters.end() || !it->second) { + return op; + } + + vector row_ids; + bool can_take = scan_bind.take_row_ids.empty() && + TryExtractTakeRowIdsFromFilter(*it->second, row_ids) && + !row_ids.empty(); + if (can_take) { + scan_bind.take_row_ids = std::move(row_ids); + get.table_filters.filters.erase(it); + return op; + } + + auto &col_ids = get.GetColumnIds(); + optional_idx col_pos = optional_idx::Invalid(); + for (idx_t i = 0; i < col_ids.size(); i++) { + if (col_ids[i].GetPrimaryIndex() == LANCE_COLUMN_IDENTIFIER_ROW_ID) { + col_pos = optional_idx(i); + break; + } + } + if (!col_pos.IsValid()) { + throw InternalException( + "Lance scan found a _rowid table filter without a _rowid column"); + } + + auto colref = make_uniq( + LogicalType::UBIGINT, + ColumnBinding(get.table_index, col_pos.GetIndex())); + auto filter_expr = it->second->ToExpression(*colref); + get.table_filters.filters.erase(it); + + auto estimated = op->estimated_cardinality; + auto filter = make_uniq(); + filter->expressions.push_back(std::move(filter_expr)); + filter->children.push_back(std::move(op)); + filter->estimated_cardinality = estimated; + return std::move(filter); + } + + if (op->type != LogicalOperatorType::LOGICAL_FILTER || + op->children.size() != 1 || !op->children[0]) { + return op; + } + + auto &filter_op = op->Cast(); + auto *node = op->children[0].get(); + while (node && node->type == LogicalOperatorType::LOGICAL_PROJECTION) { + if (node->children.empty() || !node->children[0]) { + return op; + } + node = node->children[0].get(); + } + if (!node || node->type != LogicalOperatorType::LOGICAL_GET) { + return op; + } + + auto &get = node->Cast(); + if (!IsLanceScanTableFunction(get.function) || !get.bind_data) { + return op; + } + auto &scan_bind = get.bind_data->Cast(); + + vector row_ids; + idx_t idx = 0; + bool found = false; + for (idx_t i = 0; i < filter_op.expressions.size(); i++) { + if (TryExtractTakeRowIdsFromExpression( + get, scan_bind, *filter_op.expressions[i], row_ids)) { + found = true; + idx = i; + break; + } + } + if (!found) { + return op; + } + + scan_bind.take_row_ids = row_ids; + filter_op.expressions.erase(filter_op.expressions.begin() + + NumericCast(idx)); + if (filter_op.expressions.empty()) { + auto child = std::move(op->children[0]); + child->estimated_cardinality = op->estimated_cardinality; + return child; + } + return op; +} + static unique_ptr LanceLimitOffsetPushdown(unique_ptr op) { for (auto &child : op->children) { @@ -941,6 +1588,22 @@ LanceLimitOffsetPushdown(unique_ptr op) { } auto &scan_bind = get.bind_data->Cast(); + if (!scan_bind.take_row_ids.empty()) { + return op; + } + auto &col_ids = get.GetColumnIds(); + for (auto &entry : get.table_filters.filters) { + if (!entry.second || entry.first >= col_ids.size()) { + continue; + } + auto col_id = col_ids[entry.first].GetPrimaryIndex(); + if (!IsLanceVirtualRowIdColumnId(col_id)) { + continue; + } + if (IsTakeRowIdFilter(*entry.second)) { + return op; + } + } scan_bind.limit_offset_pushed_down = true; scan_bind.pushed_limit = pushed_limit; scan_bind.pushed_offset = pushed_offset; @@ -956,10 +1619,142 @@ LanceLimitOffsetPushdownOptimizer(OptimizerExtensionInput &, plan = LanceLimitOffsetPushdown(std::move(plan)); } +static void LanceRowIdInRewriteOptimizer(OptimizerExtensionInput &, + unique_ptr &plan) { + plan = LanceRowIdInRewrite(std::move(plan)); +} + +static unique_ptr +LanceLikePushdown(unique_ptr op) { + for (auto &child : op->children) { + child = LanceLikePushdown(std::move(child)); + } + + if (op->type != LogicalOperatorType::LOGICAL_FILTER || + op->children.size() != 1 || !op->children[0]) { + return op; + } + + auto &filter_op = op->Cast(); + auto *node = op->children[0].get(); + while (node && node->type == LogicalOperatorType::LOGICAL_PROJECTION) { + if (node->children.empty() || !node->children[0]) { + return op; + } + node = node->children[0].get(); + } + if (!node || node->type != LogicalOperatorType::LOGICAL_GET) { + return op; + } + + auto &get = node->Cast(); + if (!IsLanceScanTableFunction(get.function) || !get.bind_data) { + return op; + } + auto &scan_bind = get.bind_data->Cast(); + + for (auto &expr : filter_op.expressions) { + if (!expr || expr->HasParameter() || expr->IsVolatile() || + expr->CanThrow()) { + continue; + } + if (expr->GetExpressionClass() != ExpressionClass::BOUND_FUNCTION) { + continue; + } + auto &func = expr->Cast(); + auto &name = func.function.name; + if (name != "~~" && name != "~~*" && name != "like_escape" && + name != "ilike_escape") { + continue; + } + + string filter_ir; + if (!TryBuildLanceExprFilterIR(get, scan_bind.names, scan_bind.types, false, + *expr, filter_ir)) { + continue; + } + scan_bind.lance_pushed_filter_ir_parts.push_back(std::move(filter_ir)); + } + + return op; +} + +static void LanceLikePushdownOptimizer(OptimizerExtensionInput &, + unique_ptr &plan) { + plan = LanceLikePushdown(std::move(plan)); +} + +static unique_ptr +LanceCardinalityFixup(ClientContext &context, unique_ptr op) { + for (auto &child : op->children) { + child = LanceCardinalityFixup(context, std::move(child)); + } + + if (op->type != LogicalOperatorType::LOGICAL_GET) { + return op; + } + + auto &get = op->Cast(); + if (!IsLanceScanTableFunction(get.function) || !get.bind_data) { + if (get.function.name != "lance_scan" || get.parameters.size() != 1 || + get.parameters[0].IsNull() || + get.parameters[0].type() != LogicalType::VARCHAR) { + return op; + } + + auto path = get.parameters[0].GetValue(); + auto *dataset = LanceOpenDataset(context, path); + if (!dataset) { + return op; + } + auto rows = lance_dataset_count_rows(dataset); + lance_close_dataset(dataset); + if (rows < 0) { + return op; + } + get.SetEstimatedCardinality(NumericCast(rows)); + return op; + } + + auto &scan_bind = get.bind_data->Cast(); + if (!scan_bind.take_row_ids.empty()) { + get.SetEstimatedCardinality(scan_bind.take_row_ids.size()); + return op; + } + + if (!scan_bind.dataset) { + return op; + } + + auto rows = lance_dataset_count_rows(scan_bind.dataset); + if (rows < 0) { + return op; + } + get.SetEstimatedCardinality(NumericCast(rows)); + return op; +} + +static void LanceCardinalityFixupOptimizer(OptimizerExtensionInput &input, + unique_ptr &plan) { + plan = LanceCardinalityFixup(input.context, std::move(plan)); +} + void RegisterLanceScanOptimizer(DBConfig &config) { - OptimizerExtension ext; - ext.optimize_function = LanceLimitOffsetPushdownOptimizer; - config.optimizer_extensions.push_back(std::move(ext)); + OptimizerExtension rowid_take_ext; + rowid_take_ext.optimize_function = LanceRowIdInRewriteOptimizer; + config.optimizer_extensions.push_back(std::move(rowid_take_ext)); + + OptimizerExtension like_ext; + like_ext.optimize_function = LanceLikePushdownOptimizer; + config.optimizer_extensions.push_back(std::move(like_ext)); + + OptimizerExtension limit_ext; + limit_ext.optimize_function = LanceLimitOffsetPushdownOptimizer; + config.optimizer_extensions.push_back(std::move(limit_ext)); + + OptimizerExtension cardinality_ext; + cardinality_ext.optimize_function = LanceCardinalityFixupOptimizer; + config.optimizer_extensions.push_back(std::move(cardinality_ext)); } static TableFunction LanceTableScanFunction() { @@ -972,6 +1767,8 @@ static TableFunction LanceTableScanFunction() { function.get_partition_stats = LanceScanGetPartitionStats; function.supports_pushdown_type = LanceSupportsPushdownType; function.pushdown_complex_filter = LancePushdownComplexFilter; + function.pushdown_expression = LancePushdownExpression; + function.get_virtual_columns = LanceGetVirtualColumns; function.to_string = LanceScanToString; function.dynamic_to_string = LanceScanDynamicToString; function.init_global = LanceScanInitGlobal; @@ -1392,6 +2189,24 @@ LanceTableEntry::GetScanFunction(ClientContext &context, result->names = result->arrow_table.GetNames(); result->types = result->arrow_table.GetTypes(); + auto *scan_schema_handle = lance_get_schema_for_scan(result->dataset); + if (!scan_schema_handle) { + throw IOException("Failed to get scan schema from Lance dataset: " + + result->file_path + LanceFormatErrorSuffix()); + } + memset(&result->scan_schema_root.arrow_schema, 0, + sizeof(result->scan_schema_root.arrow_schema)); + if (lance_schema_to_arrow(scan_schema_handle, + &result->scan_schema_root.arrow_schema) != 0) { + lance_free_schema(scan_schema_handle); + throw IOException( + "Failed to export Lance scan schema to Arrow C Data Interface" + + LanceFormatErrorSuffix()); + } + lance_free_schema(scan_schema_handle); + ArrowTableFunction::PopulateArrowTableSchema( + config, result->scan_arrow_table, result->scan_schema_root.arrow_schema); + bind_data = std::move(result); return LanceTableScanFunction(); } @@ -1409,6 +2224,8 @@ void RegisterLanceScan(ExtensionLoader &loader) { lance_scan.get_partition_stats = LanceScanGetPartitionStats; lance_scan.supports_pushdown_type = LanceSupportsPushdownType; lance_scan.pushdown_complex_filter = LancePushdownComplexFilter; + lance_scan.pushdown_expression = LancePushdownExpression; + lance_scan.get_virtual_columns = LanceGetVirtualColumns; lance_scan.to_string = LanceScanToString; lance_scan.dynamic_to_string = LanceScanDynamicToString; loader.RegisterFunction(lance_scan); @@ -1432,6 +2249,8 @@ void RegisterLanceScan(ExtensionLoader &loader) { internal_namespace_scan.get_partition_stats = LanceScanGetPartitionStats; internal_namespace_scan.supports_pushdown_type = LanceSupportsPushdownType; internal_namespace_scan.pushdown_complex_filter = LancePushdownComplexFilter; + internal_namespace_scan.pushdown_expression = LancePushdownExpression; + internal_namespace_scan.get_virtual_columns = LanceGetVirtualColumns; internal_namespace_scan.to_string = LanceScanToString; internal_namespace_scan.dynamic_to_string = LanceScanDynamicToString; diff --git a/test/sql/lance_rowid_in.test b/test/sql/lance_rowid_in.test new file mode 100644 index 0000000..f2281b6 --- /dev/null +++ b/test/sql/lance_rowid_in.test @@ -0,0 +1,29 @@ +# name: test/sql/lance_rowid_in.test +# description: Point lookup via `WHERE _rowid IN (...)` rewrite +# group: [sql] + +require lance + +# Basic ordering: output matches the IN-list order +query T +SELECT name FROM 'test/data/test_data.lance' WHERE _rowid IN (2, 0, 4) +---- +Charlie +Alice +Eve + +# Duplicate row ids: output preserves duplicates +query T +SELECT name FROM 'test/data/test_data.lance' WHERE _rowid IN (1, 1, 3) +---- +Bob +Bob +David + +# Non-existent row ids: missing ids are ignored +query T +SELECT name FROM 'test/data/test_data.lance' WHERE _rowid IN (0, 999, 4) +---- +Alice +Eve +