diff --git a/ibis/backends/datafusion/__init__.py b/ibis/backends/datafusion/__init__.py
index 873d33730d11..83127d04f2ca 100644
--- a/ibis/backends/datafusion/__init__.py
+++ b/ibis/backends/datafusion/__init__.py
@@ -32,6 +32,7 @@
 )
 from ibis.backends.sql import SQLBackend
 from ibis.backends.sql.compilers.base import C
+from ibis.backends.sql.rewrites import convert_pandas_udf_to_pyarrow
 from ibis.common.dispatch import lazy_singledispatch
 from ibis.expr.operations.udf import InputType
 from ibis.formats.pyarrow import PyArrowSchema, PyArrowType
@@ -268,14 +269,17 @@ def _register_udfs(self, expr: ir.Expr) -> None:
             if udf_node.__input_type__ == InputType.PYARROW:
                 udf = self._compile_pyarrow_udf(udf_node)
                 self.con.register_udf(udf)
+            if udf_node.__input_type__ == InputType.PANDAS:
+                udf = self._compile_pandas_udf(udf_node)
+                self.con.register_udf(udf)
 
         for udf_node in expr.op().find(ops.ElementWiseVectorizedUDF):
             udf = self._compile_elementwise_udf(udf_node)
             self.con.register_udf(udf)
 
-    def _compile_pyarrow_udf(self, udf_node):
+    def _compile_udf(self, udf_node, func):
         return df.udf(
-            udf_node.__func__,
+            func,
             input_types=[PyArrowType.from_ibis(arg.dtype) for arg in udf_node.args],
             return_type=PyArrowType.from_ibis(udf_node.dtype),
             volatility=getattr(udf_node, "__config__", {}).get(
@@ -284,6 +288,13 @@ def _compile_pyarrow_udf(self, udf_node):
             name=udf_node.__func_name__,
         )
 
+    def _compile_pyarrow_udf(self, udf_node):
+        return self._compile_udf(udf_node, func=udf_node.__func__)
+
+    def _compile_pandas_udf(self, udf_node):
+        pyarrow_udf = convert_pandas_udf_to_pyarrow(udf_node.__func__)
+        return self._compile_udf(udf_node, func=pyarrow_udf)
+
     def _compile_elementwise_udf(self, udf_node):
         return df.udf(
             udf_node.func,
diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py
index 739fdf259cd0..27f6177a7108 100644
--- a/ibis/backends/duckdb/__init__.py
+++ b/ibis/backends/duckdb/__init__.py
@@ -33,6 +33,7 @@
 )
 from ibis.backends.sql import SQLBackend
 from ibis.backends.sql.compilers.base import STAR, AlterTable, C, RenameTable
+from ibis.backends.sql.rewrites import convert_pandas_udf_to_pyarrow
 from ibis.common.dispatch import lazy_singledispatch
 from ibis.expr.operations.udf import InputType
 
@@ -1739,20 +1740,22 @@ def _register_udfs(self, expr: ir.Expr) -> None:
             if registration_func is not None:
                 registration_func(con)
 
-    def _register_udf(self, udf_node: ops.ScalarUDF):
+    def _register_udf(
+        self,
+        udf_node: ops.ScalarUDF,
+        *,
+        func: callable | None = None,
+        input_type: InputType | None = None,
+    ):
         type_mapper = self.compiler.type_mapper
-        input_types = [
-            type_mapper.to_string(param.annotation.pattern.dtype)
-            for param in udf_node.__signature__.parameters.values()
-        ]
 
         def register_udf(con):
             return con.create_function(
                 name=type(udf_node).__name__,
-                function=udf_node.__func__,
-                parameters=input_types,
+                function=func or udf_node.__func__,
+                parameters=[type_mapper.to_string(arg.dtype) for arg in udf_node.args],
                 return_type=type_mapper.to_string(udf_node.dtype),
-                type=_UDF_INPUT_TYPE_MAPPING[udf_node.__input_type__],
+                type=_UDF_INPUT_TYPE_MAPPING[input_type or udf_node.__input_type__],
                 **udf_node.__config__,
             )
 
@@ -1761,6 +1764,12 @@ def register_udf(con):
     _register_python_udf = _register_udf
     _register_pyarrow_udf = _register_udf
 
+    def _register_pandas_udf(self, pandas_udf_node: ops.ScalarUDF) -> str:
+        pyarrow_function = convert_pandas_udf_to_pyarrow(pandas_udf_node.__func__)
+        return self._register_udf(
+            pandas_udf_node, func=pyarrow_function, input_type=InputType.PYARROW
+        )
+
     def _get_temp_view_definition(self, name: str, definition: str) -> str:
         return sge.Create(
             this=sg.to_identifier(name, quoted=self.compiler.quoted),
diff --git a/ibis/backends/sql/compilers/datafusion.py b/ibis/backends/sql/compilers/datafusion.py
index bb4e56042e20..5c26795cb161 100644
--- a/ibis/backends/sql/compilers/datafusion.py
+++ b/ibis/backends/sql/compilers/datafusion.py
@@ -172,7 +172,7 @@ def visit_StandardDev(self, op, *, arg, how, where):
 
     def visit_ScalarUDF(self, op, **kw):
         input_type = op.__input_type__
-        if input_type in (InputType.PYARROW, InputType.BUILTIN):
+        if input_type in (InputType.PYARROW, InputType.BUILTIN, InputType.PANDAS):
             return self.f.anon[op.__func_name__](*kw.values())
         else:
             raise NotImplementedError(
diff --git a/ibis/backends/sql/rewrites.py b/ibis/backends/sql/rewrites.py
index addc87ef7fea..79558fe40e6b 100644
--- a/ibis/backends/sql/rewrites.py
+++ b/ibis/backends/sql/rewrites.py
@@ -2,9 +2,10 @@
 
 from __future__ import annotations
 
+import functools
 import operator
 import sys
-from collections.abc import Mapping
+from collections.abc import Callable, Mapping
 from functools import reduce
 from typing import TYPE_CHECKING, Any
 
@@ -670,3 +671,31 @@ def argument_replacer(_, y, **kwargs):
                 return ops.Subtract(y, 1)
 
     return _.copy(body=_.body.replace(argument_replacer))
+
+
+def convert_pandas_udf_to_pyarrow(pandas_udf: Callable) -> Callable:
+    """Convert a pandas UDF to a PyArrow UDF.
+
+    This is useful for backends that support PyArrow UDFs but not pandas UDFs.
+
+    Parameters
+    ----------
+    pandas_udf
+        The pandas UDF to convert.
+
+    Returns
+    -------
+    A PyArrow UDF that wraps the original pandas UDF.
+    """
+
+    @functools.wraps(pandas_udf)
+    def pyarrow_udf(*pa_args, **pa_kwargs):
+        import pyarrow as pa
+
+        pandas_args = [arg.to_pandas() for arg in pa_args]
+        pandas_kwargs = {k: v.to_pandas() for k, v in pa_kwargs.items()}
+        pandas_result = pandas_udf(*pandas_args, **pandas_kwargs)
+        pa_result = pa.Array.from_pandas(pandas_result)
+        return pa_result
+
+    return pyarrow_udf
diff --git a/ibis/backends/tests/test_udf.py b/ibis/backends/tests/test_udf.py
index 04fc98f659fb..8744b809924e 100644
--- a/ibis/backends/tests/test_udf.py
+++ b/ibis/backends/tests/test_udf.py
@@ -171,7 +171,7 @@ def add_one_pyarrow(s: int) -> int:  # s is series, int is the element type
             add_one_pandas,
             marks=[
                 mark.notyet(
-                    ["duckdb", "datafusion", "polars", "sqlite"],
+                    ["polars", "sqlite"],
                     raises=NotImplementedError,
                     reason="backend doesn't support pandas UDFs",
                 ),
diff --git a/ibis/expr/operations/udf.py b/ibis/expr/operations/udf.py
index 1a73aec93f4c..190630fea517 100644
--- a/ibis/expr/operations/udf.py
+++ b/ibis/expr/operations/udf.py
@@ -50,8 +50,16 @@ class InputType(enum.Enum):
     PYTHON = enum.auto()
 
 
+class _UDFMixin:
+    __input_type__: InputType
+    __func__: Callable
+    __func_name__: str
+    __config__: FrozenDict
+    __udf_namespace__: ops.Namespace
+
+
 @public
-class ScalarUDF(ops.Impure):
+class ScalarUDF(ops.Impure, _UDFMixin):
     @attribute
     def shape(self):
         if not (args := getattr(self, "args")):  # noqa: B009
@@ -65,7 +73,7 @@ def shape(self):
 
 
 @public
-class AggUDF(ops.Reduction, ops.Impure):
+class AggUDF(ops.Reduction, ops.Impure, _UDFMixin):
     where: Optional[ops.Value[dt.Boolean]] = None
 
 
@@ -479,7 +487,7 @@ def pandas(
         ... def str_cap(x: str) -> str:
         ...     # note usage of pandas `str` method
         ...     return x.str.capitalize()
-        >>> str_cap(t.str_col)  # doctest: +SKIP
+        >>> str_cap(t.str_col)
         ┏━━━━━━━━━━━━━━━━━━━━━━━┓
         ┃ string_cap_0(str_col) ┃
         ┡━━━━━━━━━━━━━━━━━━━━━━━┩
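
For reference, a minimal usage sketch of what this patch enables, mirroring the str_cap example from the docstring above. The in-memory table, its column name, and the choice of a DuckDB connection are illustrative assumptions rather than part of the diff; at registration time the backend wraps the pandas function with convert_pandas_udf_to_pyarrow, so the engine only ever executes a PyArrow UDF.

    import ibis

    # hypothetical in-memory table, not taken from the diff
    t = ibis.memtable({"str_col": ["abc", "def"]})

    @ibis.udf.scalar.pandas
    def str_cap(x: str) -> str:
        # x arrives as a pandas Series; note the pandas `str` accessor
        return x.str.capitalize()

    con = ibis.duckdb.connect()  # ibis.datafusion.connect() is expected to behave the same way
    result = con.execute(str_cap(t.str_col))  # previously hit the pandas-UDF NotImplementedError on these backends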