Commit be7daf5

Add ArrowArrayExportable class and use it to create pyarrow arrays for python UDFs
1 parent e8d7dbf commit be7daf5
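For context: this change works because pyarrow recognizes the Arrow PyCapsule interface. Any Python object exposing an __arrow_c_array__ method can be passed to pa.array(), which imports the underlying buffers without copying. A minimal pure-Python sketch of the protocol the new Rust class implements (the Exportable wrapper here is illustrative only, not part of this commit):

import pyarrow as pa

class Exportable:
    """Toy stand-in for the Rust ArrowArrayExportable class added below."""

    def __init__(self, arr: pa.Array):
        self._arr = arr

    def __arrow_c_array__(self, requested_schema=None):
        # Return the (arrow_schema, arrow_array) PyCapsule pair. Here we
        # delegate to pyarrow's own exporter; the Rust class instead builds
        # the capsules from FFI_ArrowSchema / FFI_ArrowArray itself.
        return self._arr.__arrow_c_array__(requested_schema)

# pa.array() detects __arrow_c_array__ and imports the data zero-copy.
print(pa.array(Exportable(pa.array([1, 2, 3]))))  # -> [1, 2, 3]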

File tree

6 files changed (+115, -77 lines)

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -141,6 +141,7 @@ dev = [
     "maturin>=1.8.1",
     "numpy>1.25.0;python_version<'3.14'",
     "numpy>=2.3.2;python_version>='3.14'",
+    "pyarrow>=19.0.0",
     "pre-commit>=4.3.0",
     "pyyaml>=6.0.3",
     "pytest>=7.4.4",

python/tests/test_udf.py

Lines changed: 24 additions & 0 deletions
@@ -18,6 +18,7 @@
 import pyarrow as pa
 import pytest
 from datafusion import column, udf
+from datafusion import functions as f


 @pytest.fixture
@@ -124,3 +125,26 @@ def udf_with_param(values: pa.Array) -> pa.Array:
     result = df2.collect()[0].column(0)

     assert result == pa.array([False, True, True])
+
+
+def test_udf_with_metadata(ctx) -> None:
+    from uuid import UUID
+
+    @udf([pa.string()], pa.uuid(), "stable")
+    def uuid_from_string(uuid_string):
+        return pa.array((UUID(s).bytes for s in uuid_string.to_pylist()), pa.uuid())
+
+    @udf([pa.uuid()], pa.int64(), "stable")
+    def uuid_version(uuid):
+        return pa.array(s.version for s in uuid.to_pylist())
+
+    batch = pa.record_batch({"idx": pa.array(range(5))})
+    results = (
+        ctx.create_dataframe([[batch]])
+        .with_column("uuid_string", f.uuid())
+        .with_column("uuid", uuid_from_string(column("uuid_string")))
+        .select(uuid_version(column("uuid")).alias("uuid_version"))
+        .collect()
+    )
+
+    assert results[0][0].to_pylist() == [4, 4, 4, 4, 4]
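The new test is really about metadata: pa.uuid() is a canonical Arrow extension type (available only in recent pyarrow releases, which is likely why the dev dependency above pins pyarrow>=19.0.0), and each UDF argument must be exported together with its Field so the extension annotation survives the Rust-to-Python hop. A standalone sketch of roughly what uuid_version receives per batch (illustrative, not test code):

import pyarrow as pa
from uuid import uuid4

# Build a uuid extension array the same way uuid_from_string does.
storage = pa.array([uuid4().bytes for _ in range(3)], pa.binary(16))
uuid_arr = pa.ExtensionArray.from_storage(pa.uuid(), storage)

# Because the field (not just the storage type) crosses the FFI boundary,
# the UDF sees a uuid-typed array whose to_pylist() yields uuid.UUID values.
print(uuid_arr.type)                              # uuid extension, not binary(16)
print([u.version for u in uuid_arr.to_pylist()])  # [4, 4, 4]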

src/array.rs

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::errors::PyDataFusionResult;
+use crate::utils::validate_pycapsule;
+use arrow::array::{Array, ArrayRef};
+use arrow::datatypes::{Field, FieldRef};
+use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
+use arrow::pyarrow::ToPyArrow;
+use pyo3::prelude::{PyAnyMethods, PyCapsuleMethods};
+use pyo3::types::PyCapsule;
+use pyo3::{pyclass, pymethods, Bound, PyObject, PyResult, Python};
+use std::sync::Arc;
+
+/// A Python object which implements the Arrow PyCapsule interface for
+/// importing into other libraries.
+#[pyclass(name = "ArrowArrayExportable", module = "datafusion", frozen)]
+#[derive(Clone)]
+pub struct PyArrowArrayExportable {
+    array: ArrayRef,
+    field: FieldRef,
+}
+
+#[pymethods]
+impl PyArrowArrayExportable {
+    #[pyo3(signature = (requested_schema=None))]
+    fn __arrow_c_array__<'py>(
+        &'py self,
+        py: Python<'py>,
+        requested_schema: Option<Bound<'py, PyCapsule>>,
+    ) -> PyDataFusionResult<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>)> {
+        let field = if let Some(schema_capsule) = requested_schema {
+            validate_pycapsule(&schema_capsule, "arrow_schema")?;
+
+            let schema_ptr = unsafe { schema_capsule.reference::<FFI_ArrowSchema>() };
+            let desired_field = Field::try_from(schema_ptr)?;
+
+            Arc::new(desired_field)
+        } else {
+            Arc::clone(&self.field)
+        };
+
+        let ffi_schema = FFI_ArrowSchema::try_from(&field)?;
+        let schema_capsule = PyCapsule::new(py, ffi_schema, Some(cr"arrow_schema".into()))?;
+
+        let ffi_array = FFI_ArrowArray::new(&self.array.to_data());
+        let array_capsule = PyCapsule::new(py, ffi_array, Some(cr"arrow_array".into()))?;
+
+        Ok((schema_capsule, array_capsule))
+    }
+}
+
+impl ToPyArrow for PyArrowArrayExportable {
+    fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
+        let module = py.import("pyarrow")?;
+        let method = module.getattr("array")?;
+        let array = method.call((self.clone(),), None)?;
+        Ok(array.unbind())
+    }
+}
+
+impl PyArrowArrayExportable {
+    pub fn new(array: ArrayRef, field: FieldRef) -> Self {
+        Self { array, field }
+    }
+}
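Two spec details are visible here: the capsules must be named arrow_schema and arrow_array (the cr"..." literals), and a consumer may pass requested_schema to negotiate the exported layout, which this implementation honors by re-exporting under the requested field. The ToPyArrow impl then simply hands the object to pyarrow.array(). For reference, the same negotiation driven from pure Python against pyarrow's own implementation of the protocol (_import_from_c_capsule is a private pyarrow API, used here only to make the capsule pair visible; ordinary consumers just call pa.array()):

import pyarrow as pa

src = pa.array([1, 2, 3], pa.int32())

# Ask the producer for a different layout via a requested_schema capsule.
# pyarrow's producer casts; the Rust code above instead appears to re-export
# the same buffers under the requested field.
requested = pa.int64().__arrow_c_schema__()
schema_capsule, array_capsule = src.__arrow_c_array__(requested_schema=requested)

out = pa.Array._import_from_c_capsule(schema_capsule, array_capsule)
print(out.type)  # int64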

src/lib.rs

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@ pub mod store;
 pub mod table;
 pub mod unparser;

+mod array;
 #[cfg(feature = "substrait")]
 pub mod substrait;
 #[allow(clippy::borrow_deref_ref)]

src/udf.rs

Lines changed: 6 additions & 25 deletions
@@ -15,13 +15,14 @@
 // specific language governing permissions and limitations
 // under the License.

+use crate::array::PyArrowArrayExportable;
 use crate::errors::to_datafusion_err;
 use crate::errors::{py_datafusion_err, PyDataFusionResult};
 use crate::expr::PyExpr;
 use crate::utils::{parse_volatility, validate_pycapsule};
 use arrow::datatypes::{Field, FieldRef};
-use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
-use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef};
+use arrow::pyarrow::ToPyArrow;
+use datafusion::arrow::array::{make_array, ArrayData};
 use datafusion::arrow::datatypes::DataType;
 use datafusion::arrow::pyarrow::FromPyArrow;
 use datafusion::arrow::pyarrow::PyArrowType;
@@ -31,12 +32,10 @@ use datafusion::logical_expr::{
     ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
 };
 use datafusion_ffi::udf::{FFI_ScalarUDF, ForeignScalarUDF};
-use pyo3::ffi::Py_uintptr_t;
 use pyo3::types::PyCapsule;
 use pyo3::{prelude::*, types::PyTuple};
 use std::any::Any;
 use std::hash::{Hash, Hasher};
-use std::ptr::addr_of;
 use std::sync::Arc;

 /// This struct holds the Python written function that is a
@@ -92,26 +91,6 @@ impl Hash for PythonFunctionScalarUDF {
     }
 }

-fn array_to_pyarrow_with_field(
-    py: Python,
-    array: ArrayRef,
-    field: &FieldRef,
-) -> PyResult<PyObject> {
-    let array = FFI_ArrowArray::new(&array.to_data());
-    let schema = FFI_ArrowSchema::try_from(field).map_err(py_datafusion_err)?;
-
-    let module = py.import("pyarrow")?;
-    let class = module.getattr("Array")?;
-    let array = class.call_method1(
-        "_import_from_c",
-        (
-            addr_of!(array) as Py_uintptr_t,
-            addr_of!(schema) as Py_uintptr_t,
-        ),
-    )?;
-    Ok(array.unbind())
-}
-
 impl ScalarUDFImpl for PythonFunctionScalarUDF {
     fn as_any(&self) -> &dyn Any {
         self
@@ -150,7 +129,9 @@ impl ScalarUDFImpl for PythonFunctionScalarUDF {
             .zip(args.arg_fields)
             .map(|(arg, field)| {
                 let array = arg.to_array(num_rows)?;
-                array_to_pyarrow_with_field(py, array, &field).map_err(to_datafusion_err)
+                PyArrowArrayExportable::new(array, field)
+                    .to_pyarrow(py)
+                    .map_err(to_datafusion_err)
             })
             .collect::<Result<Vec<_>, _>>()?;
         let py_args = PyTuple::new(py, py_args).map_err(to_datafusion_err)?;
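For comparison, the deleted helper predates the capsule protocol: it filled FFI_ArrowArray / FFI_ArrowSchema structs and handed their raw addresses to pyarrow's private Array._import_from_c. Roughly the same pointer dance in pure Python (needs the optional cffi package, which pyarrow.cffi wraps; shown only to contrast with the capsule path, since newer pyarrow steers users toward the PyCapsule interface instead):

import pyarrow as pa
from pyarrow.cffi import ffi  # ships the ArrowArray/ArrowSchema C definitions

src = pa.array([1, 2, 3])

# Allocate the two C structs and pass their addresses around as integers,
# much like addr_of!(...) as Py_uintptr_t did on the Rust side.
c_schema = ffi.new("struct ArrowSchema*")
c_array = ffi.new("struct ArrowArray*")
schema_ptr = int(ffi.cast("uintptr_t", c_schema))
array_ptr = int(ffi.cast("uintptr_t", c_array))

src._export_to_c(array_ptr, schema_ptr)
legacy = pa.Array._import_from_c(array_ptr, schema_ptr)
assert legacy.equals(src)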

uv.lock

Lines changed: 3 additions & 52 deletions
Some generated files are not rendered by default.
