Skip to content

Commit ccb8276

Browse files
committed
Workaround VECTOR type: use List with metadata
1 parent 665ead5 commit ccb8276

File tree

3 files changed

+53
-22
lines changed

3 files changed

+53
-22
lines changed

python/databend_udf/udf.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
MAX_DECIMAL256_PRECISION = 76
5050
EXTENSION_KEY = b"Extension"
5151
ARROW_EXT_TYPE_VARIANT = b"Variant"
52+
ARROW_EXT_TYPE_VECTOR = b"Vector"
5253

5354
TIMESTAMP_UINT = "us"
5455
_SCHEMA_METADATA_INPUT_COUNT_KEY = b"x-databend-udf-input-count"
@@ -1405,8 +1406,16 @@ def _type_str_to_arrow_field_inner(type_str: str) -> pa.Field:
14051406
elif type_str.startswith("VECTOR"):
14061407
# VECTOR(1024)
14071408
dim = int(type_str[6:].strip("()").strip())
1409+
# Use List(float) with metadata to represent VECTOR(N)
1410+
# This is a workaround because the Databend UDF client might not support FixedSizeList yet.
14081411
return pa.field(
1409-
"", pa.list_(pa.field("item", pa.float32(), nullable=False), dim), False
1412+
"",
1413+
pa.list_(pa.field("item", pa.float32(), nullable=False)),
1414+
nullable=False,
1415+
metadata={
1416+
EXTENSION_KEY: ARROW_EXT_TYPE_VECTOR,
1417+
b"vector_size": str(dim).encode(),
1418+
},
14101419
)
14111420
else:
14121421
raise ValueError(f"Unsupported type: {type_str}")
@@ -1431,6 +1440,10 @@ def _field_type_to_string(field: pa.Field) -> str:
14311440
Convert a `pyarrow.DataType` to a SQL data type string.
14321441
"""
14331442
t = field.type
1443+
if field.metadata and field.metadata.get(EXTENSION_KEY) == ARROW_EXT_TYPE_VECTOR:
1444+
dim = int(field.metadata.get(b"vector_size", b"0"))
1445+
return f"VECTOR({dim})"
1446+
14341447
if pa.types.is_boolean(t):
14351448
return "BOOLEAN"
14361449
elif pa.types.is_int8(t):
@@ -1466,8 +1479,6 @@ def _field_type_to_string(field: pa.Field) -> str:
14661479
return "VARIANT"
14671480
else:
14681481
return "BINARY"
1469-
elif pa.types.is_fixed_size_list(t):
1470-
return f"VECTOR({t.list_size})"
14711482
elif pa.types.is_list(t):
14721483
return f"ARRAY({_inner_field_to_string(t.value_field)})"
14731484
elif pa.types.is_map(t):

python/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ classifiers = [
77
description = "Databend UDF Server"
88
license = { text = "Apache-2.0" }
99
name = "databend-udf"
10-
version = "0.2.13"
10+
version = "0.2.14"
1111
readme = "README.md"
1212
requires-python = ">=3.7"
1313
dependencies = [

python/tests/test_vector_type.py

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,47 +23,67 @@ def test_vector_sql_generation():
2323

2424
def test_vector_type_parsing():
2525
field = _type_str_to_arrow_field("VECTOR(1024)")
26-
assert pa.types.is_fixed_size_list(field.type)
27-
assert field.type.list_size == 1024
26+
# Should be List type with metadata, not FixedSizeList
27+
assert pa.types.is_list(field.type)
28+
assert field.metadata[b"Extension"] == b"Vector"
29+
assert field.metadata[b"vector_size"] == b"1024"
2830
assert pa.types.is_float32(field.type.value_type)
31+
# Default is nullable
2932
assert field.nullable is True
3033

34+
# Test NOT NULL
35+
field_not_null = _type_str_to_arrow_field("VECTOR(1024) NOT NULL")
36+
assert field_not_null.nullable is False
37+
3138

3239
def test_vector_type_formatting():
40+
# Test that a List with VECTOR metadata is formatted as VECTOR(N)
3341
field = pa.field(
3442
"",
35-
pa.list_(pa.field("item", pa.float32(), nullable=False), 1024),
36-
nullable=True,
43+
pa.list_(pa.field("item", pa.float32(), nullable=False)),
44+
nullable=False,
45+
metadata={
46+
b"Extension": b"Vector",
47+
b"vector_size": b"1024",
48+
},
3749
)
3850
type_str = _field_type_to_string(field)
3951
assert type_str == "VECTOR(1024)"
4052

4153

4254
def test_vector_input_processing():
55+
# Input processing should handle List (which is what VECTOR is physically)
4356
field = pa.field(
44-
"", pa.list_(pa.field("item", pa.float32(), nullable=False), 3), nullable=True
57+
"",
58+
pa.list_(pa.field("item", pa.float32(), nullable=False)),
59+
nullable=False,
60+
metadata={
61+
b"Extension": b"Vector",
62+
b"vector_size": b"3",
63+
},
4564
)
4665
func = _input_process_func(field)
4766

4867
# Input is a list of floats
49-
input_data = [1.0, 2.0, 3.0]
50-
result = func(input_data)
51-
assert result == [1.0, 2.0, 3.0]
52-
53-
# Input is None
54-
assert func(None) is None
68+
data = [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]]
69+
processed = func(data)
70+
assert processed == data
5571

5672

5773
def test_vector_output_processing():
74+
# Output processing should handle List
5875
field = pa.field(
59-
"", pa.list_(pa.field("item", pa.float32(), nullable=False), 3), nullable=True
76+
"",
77+
pa.list_(pa.field("item", pa.float32(), nullable=False)),
78+
nullable=False,
79+
metadata={
80+
b"Extension": b"Vector",
81+
b"vector_size": b"3",
82+
},
6083
)
6184
func = _output_process_func(field)
6285

6386
# Output is a list of floats
64-
output_data = [1.0, 2.0, 3.0]
65-
result = func(output_data)
66-
assert result == [1.0, 2.0, 3.0]
67-
68-
# Output is None
69-
assert func(None) is None
87+
data = [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]]
88+
processed = func(data)
89+
assert processed == data

0 commit comments

Comments
 (0)