diff --git a/rust/lance-core/src/error.rs b/rust/lance-core/src/error.rs index f80dbca4a7..7a20392618 100644 --- a/rust/lance-core/src/error.rs +++ b/rust/lance-core/src/error.rs @@ -118,6 +118,13 @@ pub enum Error { source: BoxedError, location: Location, }, + /// External error passed through from user code. + /// + /// This variant preserves errors that users pass into Lance APIs (e.g., via streams + /// with custom error types). The original error can be recovered using [`Error::into_external`] + /// or inspected using [`Error::external_source`]. + #[snafu(transparent)] + External { source: BoxedError }, } impl Error { @@ -164,6 +171,31 @@ impl Error { location, } } + + /// Create an External error from a boxed error source. + pub fn external(source: BoxedError) -> Self { + Self::External { source } + } + + /// Returns a reference to the external error source if this is an `External` variant. + /// + /// This allows downcasting to recover the original error type. + pub fn external_source(&self) -> Option<&BoxedError> { + match self { + Self::External { source } => Some(source), + _ => None, + } + } + + /// Consumes the error and returns the external source if this is an `External` variant. + /// + /// Returns `Err(self)` if this is not an `External` variant, allowing for chained handling. + pub fn into_external(self) -> std::result::Result { + match self { + Self::External { source } => Ok(source), + other => Err(other), + } + } } pub trait LanceOptionExt { @@ -202,9 +234,18 @@ pub type DataFusionResult = std::result::Result for Error { #[track_caller] fn from(e: ArrowError) -> Self { - Self::Arrow { - message: e.to_string(), - location: std::panic::Location::caller().to_snafu_location(), + match e { + ArrowError::ExternalError(source) => { + // Try to downcast to lance_core::Error first to recover the original + match source.downcast::() { + Ok(lance_err) => *lance_err, + Err(source) => Self::External { source }, + } + } + other => Self::Arrow { + message: other.to_string(), + location: std::panic::Location::caller().to_snafu_location(), + }, } } } @@ -309,20 +350,15 @@ impl From for Error { } } -#[track_caller] -fn arrow_io_error_from_msg(message: String) -> ArrowError { - ArrowError::IoError(message.clone(), std::io::Error::other(message)) -} - impl From for ArrowError { fn from(value: Error) -> Self { match value { - Error::Arrow { message, .. } => arrow_io_error_from_msg(message), // we lose the error type converting to LanceError - Error::IO { source, .. } => arrow_io_error_from_msg(source.to_string()), + // Pass through external errors directly + Error::External { source } => Self::ExternalError(source), + // Preserve schema errors with their specific type Error::Schema { message, .. } => Self::SchemaError(message), - Error::Index { message, .. } => arrow_io_error_from_msg(message), - Error::Stop => arrow_io_error_from_msg("early stop".to_string()), - e => arrow_io_error_from_msg(e.to_string()), // Find a more scalable way of doing this + // Wrap all other lance errors so they can be recovered + e => Self::ExternalError(Box::new(e)), } } } @@ -353,7 +389,7 @@ impl From for Error { impl From for datafusion_common::DataFusionError { #[track_caller] fn from(e: Error) -> Self { - Self::Execution(e.to_string()) + Self::External(Box::new(e)) } } @@ -373,10 +409,22 @@ impl From for Error { message: e.to_string(), location, }, - datafusion_common::DataFusionError::ArrowError(..) => Self::Arrow { - message: e.to_string(), - location, - }, + datafusion_common::DataFusionError::ArrowError(arrow_err, _) => { + // Check if the ArrowError wraps an external error and extract it + match *arrow_err { + ArrowError::ExternalError(source) => { + // Try to downcast to lance_core::Error first + match source.downcast::() { + Ok(lance_err) => *lance_err, + Err(source) => Self::External { source }, + } + } + other => Self::Arrow { + message: other.to_string(), + location, + }, + } + } datafusion_common::DataFusionError::NotImplemented(..) => Self::NotSupported { source: box_error(e), location, @@ -385,6 +433,13 @@ impl From for Error { message: e.to_string(), location, }, + datafusion_common::DataFusionError::External(source) => { + // Try to downcast to lance_core::Error first + match source.downcast::() { + Ok(lance_err) => *lance_err, + Err(source) => Self::External { source }, + } + } _ => Self::IO { source: box_error(e), location, @@ -439,6 +494,7 @@ impl From> for CloneableResult { #[cfg(test)] mod test { use super::*; + use std::fmt; #[test] fn test_caller_location_capture() { @@ -461,4 +517,208 @@ mod test { _ => panic!("expected ObjectStore error"), } } + + #[derive(Debug)] + struct MyCustomError { + code: i32, + message: String, + } + + impl fmt::Display for MyCustomError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "MyCustomError({}): {}", self.code, self.message) + } + } + + impl std::error::Error for MyCustomError {} + + #[test] + fn test_external_error_creation() { + let custom_err = MyCustomError { + code: 42, + message: "test error".to_string(), + }; + let err = Error::external(Box::new(custom_err)); + + match &err { + Error::External { source } => { + let recovered = source.downcast_ref::().unwrap(); + assert_eq!(recovered.code, 42); + assert_eq!(recovered.message, "test error"); + } + _ => panic!("Expected External variant"), + } + } + + #[test] + fn test_external_source_method() { + let custom_err = MyCustomError { + code: 123, + message: "source test".to_string(), + }; + let err = Error::external(Box::new(custom_err)); + + let source = err.external_source().expect("should have external source"); + let recovered = source.downcast_ref::().unwrap(); + assert_eq!(recovered.code, 123); + + // Test that non-External variants return None + let io_err = Error::io("test", snafu::Location::new("test", 1, 1)); + assert!(io_err.external_source().is_none()); + } + + #[test] + fn test_into_external_method() { + let custom_err = MyCustomError { + code: 456, + message: "into test".to_string(), + }; + let err = Error::external(Box::new(custom_err)); + + match err.into_external() { + Ok(source) => { + let recovered = source.downcast::().unwrap(); + assert_eq!(recovered.code, 456); + } + Err(_) => panic!("Expected Ok"), + } + + // Test that non-External variants return Err(self) + let io_err = Error::io("test", snafu::Location::new("test", 1, 1)); + match io_err.into_external() { + Err(Error::IO { .. }) => {} + _ => panic!("Expected Err with IO variant"), + } + } + + #[test] + fn test_arrow_external_error_conversion() { + let custom_err = MyCustomError { + code: 789, + message: "arrow test".to_string(), + }; + let arrow_err = ArrowError::ExternalError(Box::new(custom_err)); + let lance_err: Error = arrow_err.into(); + + match lance_err { + Error::External { source } => { + let recovered = source.downcast_ref::().unwrap(); + assert_eq!(recovered.code, 789); + } + _ => panic!("Expected External variant, got {:?}", lance_err), + } + } + + #[test] + fn test_external_to_arrow_roundtrip() { + let custom_err = MyCustomError { + code: 999, + message: "roundtrip".to_string(), + }; + let lance_err = Error::external(Box::new(custom_err)); + let arrow_err: ArrowError = lance_err.into(); + + match arrow_err { + ArrowError::ExternalError(source) => { + let recovered = source.downcast_ref::().unwrap(); + assert_eq!(recovered.code, 999); + } + _ => panic!("Expected ExternalError variant"), + } + } + + #[cfg(feature = "datafusion")] + #[test] + fn test_datafusion_external_error_conversion() { + let custom_err = MyCustomError { + code: 111, + message: "datafusion test".to_string(), + }; + let df_err = datafusion_common::DataFusionError::External(Box::new(custom_err)); + let lance_err: Error = df_err.into(); + + match lance_err { + Error::External { source } => { + let recovered = source.downcast_ref::().unwrap(); + assert_eq!(recovered.code, 111); + } + _ => panic!("Expected External variant"), + } + } + + #[cfg(feature = "datafusion")] + #[test] + fn test_datafusion_arrow_external_error_conversion() { + // Test the nested case: ArrowError::ExternalError inside DataFusionError::ArrowError + let custom_err = MyCustomError { + code: 222, + message: "nested test".to_string(), + }; + let arrow_err = ArrowError::ExternalError(Box::new(custom_err)); + let df_err = datafusion_common::DataFusionError::ArrowError(Box::new(arrow_err), None); + let lance_err: Error = df_err.into(); + + match lance_err { + Error::External { source } => { + let recovered = source.downcast_ref::().unwrap(); + assert_eq!(recovered.code, 222); + } + _ => panic!("Expected External variant, got {:?}", lance_err), + } + } + + /// Test that lance_core::Error round-trips through ArrowError. + /// + /// This simulates the case where a user defines an iterator in terms of + /// lance_core::Error, and the error goes through Arrow's error type + /// (e.g., via RecordBatchIterator) before being converted back. + #[test] + fn test_lance_error_roundtrip_through_arrow() { + let original = Error::invalid_input( + "test validation error", + snafu::Location::new("test.rs", 10, 1), + ); + + // Simulate what happens when using ? in an Arrow context + let arrow_err: ArrowError = original.into(); + + // Convert back to lance error (as happens when Lance consumes the stream) + let recovered: Error = arrow_err.into(); + + // Should get back the original lance error directly (not wrapped in External) + match recovered { + Error::InvalidInput { .. } => { + assert!(recovered.to_string().contains("test validation error")); + } + _ => panic!("Expected InvalidInput variant, got {:?}", recovered), + } + } + + /// Test that lance_core::Error round-trips through DataFusionError. + /// + /// This simulates the case where a user defines a stream in terms of + /// lance_core::Error, and the error goes through DataFusion's error type + /// (e.g., via SendableRecordBatchStream) before being converted back. + #[cfg(feature = "datafusion")] + #[test] + fn test_lance_error_roundtrip_through_datafusion() { + let original = Error::invalid_input( + "test validation error", + snafu::Location::new("test.rs", 10, 1), + ); + + // Simulate what happens when using ? in a DataFusion context + let df_err: datafusion_common::DataFusionError = original.into(); + + // Convert back to lance error (as happens when Lance consumes the stream) + let recovered: Error = df_err.into(); + + // Should get back the original lance error directly (not wrapped in External) + match recovered { + Error::InvalidInput { .. } => { + assert!(recovered.to_string().contains("test validation error")); + } + _ => panic!("Expected InvalidInput variant, got {:?}", recovered), + } + } } diff --git a/rust/lance/src/dataset/write/insert.rs b/rust/lance/src/dataset/write/insert.rs index d1ee6db414..7763c5e8f7 100644 --- a/rust/lance/src/dataset/write/insert.rs +++ b/rust/lance/src/dataset/write/insert.rs @@ -434,8 +434,8 @@ struct WriteContext<'a> { #[cfg(test)] mod test { - use arrow_array::{Int32Array, StructArray}; - use arrow_schema::{DataType, Field, Schema}; + use arrow_array::{Int32Array, RecordBatchReader, StructArray}; + use arrow_schema::{ArrowError, DataType, Field, Schema}; use crate::session::Session; @@ -515,4 +515,100 @@ mod test { assert!(matches!(result, Err(Error::InvalidInput { .. }))); } + + mod external_error { + use super::*; + use std::fmt; + + #[derive(Debug)] + struct MyTestError { + code: i32, + details: String, + } + + impl fmt::Display for MyTestError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "MyTestError({}): {}", self.code, self.details) + } + } + + impl std::error::Error for MyTestError {} + + fn create_failing_iterator( + schema: Arc, + fail_at_batch: usize, + error_code: i32, + ) -> impl Iterator> { + let mut batch_count = 0; + std::iter::from_fn(move || { + if batch_count >= 5 { + return None; + } + batch_count += 1; + if batch_count == fail_at_batch { + Some(Err(ArrowError::ExternalError(Box::new(MyTestError { + code: error_code, + details: format!("Failed at batch {}", batch_count), + })))) + } else { + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![batch_count as i32; 10]))], + ) + .unwrap(); + Some(Ok(batch)) + } + }) + } + + #[tokio::test] + async fn test_insert_builder_preserves_external_error() { + let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + + let error_code = 42; + let iter = create_failing_iterator(schema.clone(), 3, error_code); + let reader = RecordBatchIterator::new(iter, schema); + + let result = InsertBuilder::new("memory://test_external_error") + .execute_stream(Box::new(reader) as Box) + .await; + + match result { + Err(Error::External { source }) => { + let original = source + .downcast_ref::() + .expect("Should be able to downcast to MyTestError"); + assert_eq!(original.code, error_code); + assert!(original.details.contains("batch 3")); + } + Err(other) => panic!("Expected Error::External variant, got: {:?}", other), + Ok(_) => panic!("Expected error, got success"), + } + } + + #[tokio::test] + async fn test_insert_builder_first_batch_error() { + let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + + let error_code = 999; + let iter = std::iter::once(Err(ArrowError::ExternalError(Box::new(MyTestError { + code: error_code, + details: "immediate failure".to_string(), + })))); + let reader = RecordBatchIterator::new(iter, schema); + + let result = InsertBuilder::new("memory://test_first_batch_error") + .execute_stream(Box::new(reader) as Box) + .await; + + match result { + Err(Error::External { source }) => { + let original = source.downcast_ref::().unwrap(); + assert_eq!(original.code, error_code); + } + Err(other) => panic!("Expected External, got: {:?}", other), + Ok(_) => panic!("Expected error"), + } + } + } } diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs index c4709624bc..abf36994e4 100644 --- a/rust/lance/src/dataset/write/merge_insert.rs +++ b/rust/lance/src/dataset/write/merge_insert.rs @@ -5314,4 +5314,72 @@ MergeInsert: on=[id], when_matched=UpdateAll, when_not_matched=InsertAll, when_n assert_eq!(result, expected); } + + mod external_error { + use super::*; + use arrow_schema::{ArrowError, Field as ArrowField, Schema as ArrowSchema}; + use std::fmt; + + #[derive(Debug)] + struct MyTestError { + code: i32, + details: String, + } + + impl fmt::Display for MyTestError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "MyTestError({}): {}", self.code, self.details) + } + } + + impl std::error::Error for MyTestError {} + + #[tokio::test] + async fn test_merge_insert_execute_reader_preserves_external_error() { + let schema = Arc::new(ArrowSchema::new(vec![ + ArrowField::new("key", DataType::Int32, false), + ArrowField::new("value", DataType::Int32, false), + ])); + + // Create initial dataset + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(Int32Array::from(vec![10, 20, 30])), + ], + ) + .unwrap(); + let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone()); + let dataset = Arc::new( + Dataset::write(reader, "memory://test_merge_external", None) + .await + .unwrap(), + ); + + // Try merge insert with failing source + let error_code = 789; + let iter = std::iter::once(Err(ArrowError::ExternalError(Box::new(MyTestError { + code: error_code, + details: "merge insert failure".to_string(), + })))); + let reader = RecordBatchIterator::new(iter, schema); + + let result = MergeInsertBuilder::try_new(dataset, vec!["key".to_string()]) + .unwrap() + .try_build() + .unwrap() + .execute_reader(Box::new(reader) as Box) + .await; + + match result { + Err(Error::External { source }) => { + let original = source.downcast_ref::().unwrap(); + assert_eq!(original.code, error_code); + } + Err(other) => panic!("Expected External, got: {:?}", other), + Ok(_) => panic!("Expected error"), + } + } + } }