Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions python/databend_udf/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,9 +358,14 @@ def require_stage_locations(self, names: List[str]) -> Dict[str, StageLocation]:
mapping = self._stage_mapping()
missing = [name for name in names if name not in mapping]
if missing:
raise ValueError(
"Missing stage mapping for parameter(s): " + ", ".join(sorted(missing))
msg = (
"Missing stage mapping for parameter(s): "
+ ", ".join(sorted(missing))
+ ".\n"
"Please check your CREATE FUNCTION statement to ensure that the stage location is correctly specified.\n"
"For example: CREATE FUNCTION ... (stage_param STAGE_LOCATION) ...\n"
)
raise ValueError(msg)
return {name: mapping[name] for name in names}


Expand Down Expand Up @@ -493,7 +498,7 @@ def __init__(
for kind, identifier in self._call_arg_layout:
if kind == "stage":
stage_ref_name = self._stage_param_to_ref.get(identifier, identifier)
self._sql_parameter_defs.append(f"STAGE_LOCATION {stage_ref_name}")
self._sql_parameter_defs.append(f"{stage_ref_name} STAGE_LOCATION")
elif kind == "data":
field = data_field_map[identifier]
self._sql_parameter_defs.append(
Expand Down Expand Up @@ -1144,11 +1149,11 @@ def add_function(self, udf: UserDefinedFunction):
f"{field.name} {_inner_field_to_string(field)}"
for field in udf._result_schema
)
output_type = f"({column_defs})"
output_type = f"TABLE ({column_defs})"
else:
output_type = _arrow_field_to_string(udf._result_schema[0])
sql = (
f"CREATE FUNCTION {name} ({input_types}) "
f"CREATE OR REPLACE FUNCTION {name} ({input_types}) "
f"RETURNS {output_type} LANGUAGE python "
f"HANDLER = '{name}' ADDRESS = 'http://{self._location}';"
)
Expand Down Expand Up @@ -1412,7 +1417,7 @@ def _arrow_field_to_string(field: pa.Field) -> str:
def _inner_field_to_string(field: pa.Field) -> str:
# inner field default is NOT NULL in databend
type_str = _field_type_to_string(field)
return f"{type_str} NULL" if field.nullable else type_str
return f"{type_str} NOT NULL" if not field.nullable else type_str
Comment on lines 1417 to +1420

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Preserve NULLability in table output columns

Table-valued UDFs with nullable result columns are now emitted without an explicit NULL marker: _inner_field_to_string returns only the type when field.nullable is true even though the comment notes inner fields default to NOT NULL in Databend. That causes RETURNS TABLE (...) definitions to register nullable columns as NOT NULL, so UDFs that actually return nulls will have a mismatched signature and can fail at runtime. Please keep emitting NULL when a result field is nullable.

Useful? React with 👍 / 👎.



def _field_type_to_string(field: pa.Field) -> str:
Expand Down
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ classifiers = [
description = "Databend UDF Server"
license = { text = "Apache-2.0" }
name = "databend-udf"
version = "0.2.8"
version = "0.2.9"
readme = "README.md"
requires-python = ">=3.7"
dependencies = [
Expand Down
54 changes: 54 additions & 0 deletions python/tests/test_sql_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import logging
import pyarrow as pa
from prometheus_client import REGISTRY
from databend_udf import udf, StageLocation, UDFServer


# Minimal scalar UDF fixture: one INT argument, returned unchanged.
# The SQL-generation tests only need its declared signature, not real logic.
@udf(input_types=["INT"], result_type="INT")
def scalar_func(x: int) -> int:
    return x


# UDF fixture that declares one stage reference ("stage_loc") ahead of its
# INT data argument; the stage argument itself is never read by the body.
@udf(stage_refs=["stage_loc"], input_types=["INT"], result_type="INT")
def stage_func(stage_loc: StageLocation, x: int) -> int:
    return x


# Batch-mode UDF fixture whose result_type is a list, yielding a single
# one-column record batch built from the input value.
@udf(input_types=["INT"], result_type=["INT"], batch_mode=True)
def table_func(x: int):
    result_column = pa.array([x])
    batch = pa.RecordBatch.from_arrays([result_column], names=["res"])
    yield batch


def setup_function():
    """Unregister every collector currently held by the Prometheus REGISTRY.

    NOTE(review): presumably needed because each UDFServer created by the
    tests registers metrics in the global registry -- confirm against
    UDFServer's constructor.
    """
    # Take a snapshot first: the loop below unregisters entries while we
    # walk them, so iterating the live mapping directly would be unsafe.
    registered = tuple(REGISTRY._collector_to_names)
    for metric_collector in registered:
        REGISTRY.unregister(metric_collector)


def test_scalar_sql(caplog):
    """Registering a scalar UDF must log the expected CREATE FUNCTION SQL."""
    with caplog.at_level(logging.INFO):
        # Keep the server bound to a name so it stays alive for the whole test.
        udf_server = UDFServer("0.0.0.0:0")
        udf_server.add_function(scalar_func)

    logged = caplog.text
    for fragment in (
        "CREATE OR REPLACE FUNCTION scalar_func (x INT)",
        "RETURNS INT LANGUAGE python",
    ):
        assert fragment in logged


def test_stage_sql(caplog):
    """Registering a stage-aware UDF must log `<name> STAGE_LOCATION` params."""
    with caplog.at_level(logging.INFO):
        # Keep the server bound to a name so it stays alive for the whole test.
        udf_server = UDFServer("0.0.0.0:0")
        udf_server.add_function(stage_func)

    logged = caplog.text
    expected_signature = (
        "CREATE OR REPLACE FUNCTION stage_func (stage_loc STAGE_LOCATION, x INT)"
    )
    expected_returns = "RETURNS INT LANGUAGE python"
    assert expected_signature in logged
    assert expected_returns in logged


def test_table_sql(caplog):
    """Registering a batch-mode UDF must log a `RETURNS TABLE (...)` clause."""
    with caplog.at_level(logging.INFO):
        # Keep the server bound to a name so it stays alive for the whole test.
        udf_server = UDFServer("0.0.0.0:0")
        udf_server.add_function(table_func)

    logged = caplog.text
    expected_fragments = (
        "CREATE OR REPLACE FUNCTION table_func (x INT)",
        "RETURNS TABLE (col0 INT) LANGUAGE python",
    )
    assert all(fragment in logged for fragment in expected_fragments[:1])
    assert all(fragment in logged for fragment in expected_fragments[1:])
4 changes: 3 additions & 1 deletion python/tests/test_stage_location.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,9 @@ def test_multiple_stage_entries():


def test_missing_stage_mapping():
with pytest.raises(ValueError, match="Missing stage mapping"):
with pytest.raises(
ValueError, match=r"Missing stage mapping(.|\n)*CREATE FUNCTION"
):
_collect(describe_stage, _make_batch([1]), Headers())


Expand Down