diff --git a/pyproject.toml b/pyproject.toml
index b43d3db1..8a9a8f06 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,7 @@ dependencies = [
     "numpy>=2",
     # We use internal pd._libs.missing and experimental ArrowExtensionArray
     "pandas>=2.2.3,<2.4",
-    "pyarrow>=16", # remove struct_field_names when upgraded to 18+
+    "pyarrow>=16", # remove struct_field_names() and struct_fields() when upgraded to 18+
     "universal_pathlib>=0.2",
 ]
 
diff --git a/src/nested_pandas/series/_storage/struct_list_storage.py b/src/nested_pandas/series/_storage/struct_list_storage.py
index 9e9018a6..032c7210 100644
--- a/src/nested_pandas/series/_storage/struct_list_storage.py
+++ b/src/nested_pandas/series/_storage/struct_list_storage.py
@@ -6,9 +6,9 @@
 import pyarrow as pa
 
 from nested_pandas.series.utils import (
+    align_chunked_struct_list_offsets,
     table_to_struct_array,
     transpose_list_struct_chunked,
-    validate_struct_list_array_for_equal_lengths,
 )
 
 if TYPE_CHECKING:
@@ -25,7 +25,9 @@ class StructListStorage:
         Pyarrow struct-array with all fields to be list-arrays.
         All list-values must be "aligned", e.g., have the same length.
     validate : bool (default True)
-        Check that all the lists have the same lengths for each struct-value.
+        Check that all the lists have the same lengths for each struct-value,
+        and that all list offset arrays are the same. Raises ValueError if the
+        lengths differ; reallocates the data if only the offsets differ.
     """
 
     _data: pa.ChunkedArray
@@ -37,8 +39,7 @@ def __init__(self, array: pa.StructArray | pa.ChunkedArray, *, validate: bool =
             raise ValueError("array must be a StructArray or ChunkedArray")
 
         if validate:
-            for chunk in array.chunks:
-                validate_struct_list_array_for_equal_lengths(chunk)
+            array = align_chunked_struct_list_offsets(array)
 
         self._data = array
 
diff --git a/src/nested_pandas/series/_storage/table_storage.py b/src/nested_pandas/series/_storage/table_storage.py
index 1c9516c4..9b5656e0 100644
--- a/src/nested_pandas/series/_storage/table_storage.py
+++ b/src/nested_pandas/series/_storage/table_storage.py
@@ -5,9 +5,9 @@
 import pyarrow as pa
 
 from nested_pandas.series.utils import (
+    align_chunked_struct_list_offsets,
     table_from_struct_array,
     table_to_struct_array,
-    validate_struct_list_array_for_equal_lengths,
 )
 
 if TYPE_CHECKING:
@@ -30,8 +30,8 @@ class TableStorage:
     def __init__(self, table: pa.Table, validate: bool = True) -> None:
         if validate:
             struct_array = table_to_struct_array(table)
-            for chunk in struct_array.iterchunks():
-                validate_struct_list_array_for_equal_lengths(chunk)
+            aligned_struct_array = align_chunked_struct_list_offsets(struct_array)
+            table = table_from_struct_array(aligned_struct_array)
 
         self._data = table
 
diff --git a/src/nested_pandas/series/utils.py b/src/nested_pandas/series/utils.py
index 394d5d40..cb0cbbf5 100644
--- a/src/nested_pandas/series/utils.py
+++ b/src/nested_pandas/series/utils.py
@@ -20,6 +20,15 @@ def struct_field_names(struct_type: pa.StructType) -> list[str]:
     return [f.name for f in struct_type]
 
 
+def struct_fields(struct_type: pa.StructType) -> list[pa.Field]:
+    """Return fields of a pyarrow.StructType in a pyarrow<18-compatible way.
+
+    Note: Once we bump our pyarrow requirement to ">=18", this helper can be
+    replaced with direct usage of ``struct_type.fields`` throughout the codebase.
+    """
+    return [struct_type.field(i) for i in range(struct_type.num_fields)]
+
+
 def is_pa_type_a_list(pa_type: pa.DataType) -> bool:
     """Check if the given pyarrow type is a list type.
 
@@ -58,36 +67,101 @@ def is_pa_type_is_list_struct(pa_type: pa.DataType) -> bool:
     return is_pa_type_a_list(pa_type) and pa.types.is_struct(pa_type.value_type)
 
 
-def validate_struct_list_array_for_equal_lengths(array: pa.StructArray) -> None:
-    """Check if the given struct array has lists of equal length.
+def align_struct_list_offsets(array: pa.StructArray) -> pa.StructArray:
+    """Checks if all struct-list offsets are the same, and reallocates if needed.
 
     Parameters
     ----------
     array : pa.StructArray
         Input struct array.
 
+    Returns
+    -------
+    pa.StructArray
+        Array with all struct-list offsets aligned. May be the input,
+        if it was valid.
+
     Raises
     ------
     ValueError
-        If the struct array has lists of unequal length or type of the input
-        array is not a StructArray or fields are not ListArrays.
+        If the input is not a valid "nested" StructArray.
     """
     if not pa.types.is_struct(array.type):
         raise ValueError(f"Expected a StructArray, got {array.type}")
 
-    first_list_array: pa.ListArray | None = None
+    first_offsets: pa.Array | None = None
     for field in array.type:
         inner_array = array.field(field.name)
         if not is_pa_type_a_list(inner_array.type):
             raise ValueError(f"Expected a ListArray, got {inner_array.type}")
         list_array = cast(pa.ListArray, inner_array)
 
-        if first_list_array is None:
-            first_list_array = list_array
+        if first_offsets is None:
+            first_offsets = list_array.offsets
             continue
         # compare offsets from the first list array with the current one
-        if not first_list_array.offsets.equals(list_array.offsets):
-            raise ValueError("Offsets of all ListArrays must be the same")
+        if not first_offsets.equals(list_array.offsets):
+            break
+    else:
+        # Return the original array if all offsets match
+        return array
+
+    new_offsets = pa.compute.subtract(first_offsets, first_offsets[0])
+    value_lengths = None
+    list_arrays = []
+    for field in array.type:
+        inner_array = array.field(field.name)
+        list_array = cast(pa.ListArray, inner_array)
+
+        if value_lengths is None:
+            value_lengths = list_array.value_lengths()
+        elif not value_lengths.equals(list_array.value_lengths()):
+            raise ValueError(
+                f"List lengths do not match for struct fields {array.type.field(0).name} and {field.name}",
+            )
+
+        list_arrays.append(
+            pa.ListArray.from_arrays(
+                values=list_array.values[list_array.offsets[0].as_py() : list_array.offsets[-1].as_py()],
+                offsets=new_offsets,
+                # Keep list-level nulls, from_arrays() would drop them otherwise
+                mask=list_array.is_null() if list_array.null_count > 0 else None,
+            )
+        )
+    new_array = pa.StructArray.from_arrays(
+        arrays=list_arrays,
+        fields=struct_fields(array.type),
+        # Keep struct-level nulls, from_arrays() would drop them otherwise
+        mask=array.is_null() if array.null_count > 0 else None,
+    )
+    return new_array
+
+
+def align_chunked_struct_list_offsets(array: pa.Array | pa.ChunkedArray) -> pa.ChunkedArray:
+    """Checks if all struct-list offsets are the same, and reallocates if needed.
+
+    Parameters
+    ----------
+    array : pa.ChunkedArray or pa.Array
+        Input chunked array, it must be a valid "nested" struct-list array,
+        i.e. all list lengths must match. Non-chunked arrays are allowed,
+        but the return array will always be chunked.
+
+    Returns
+    -------
+    pa.ChunkedArray
+        Chunked array with all struct-list offsets aligned.
+
+    Raises
+    ------
+    ValueError
+        If the input is not a valid "nested" struct-list-array.
+    """
+    if isinstance(array, pa.Array):
+        array = pa.chunked_array([array])
+    chunks = [align_struct_list_offsets(chunk) for chunk in array.iterchunks()]
+    # Provide type for the case of zero-chunks array
+    return pa.chunked_array(chunks, type=array.type)
 
 
 def transpose_struct_list_type(t: pa.StructType) -> pa.ListType:
@@ -139,7 +213,7 @@ def transpose_struct_list_array(array: pa.StructArray, validate: bool = True) ->
         List array of structs.
     """
     if validate:
-        validate_struct_list_array_for_equal_lengths(array)
+        array = align_struct_list_offsets(array)
 
     mask = array.is_null()
     if not pa.compute.any(mask).as_py():
@@ -220,6 +294,16 @@ def validate_list_struct_type(t: pa.ListType) -> None:
         raise ValueError(f"Expected a StructType as a list value type, got {t.value_type}")
 
 
+def validate_struct_list_type(t: pa.StructType) -> None:
+    """Raise a ValueError if not a struct-list-type."""
+    if not pa.types.is_struct(t):
+        raise ValueError(f"Expected a StructType, got {t}")
+
+    for field in struct_fields(t):
+        if not is_pa_type_a_list(field.type):
+            raise ValueError(f"Expected a ListType for field {field.name}, got {field.type}")
+
+
 def transpose_list_struct_type(t: pa.ListType) -> pa.StructType:
     """Converts a type of list-struct array into a type of struct-list array.
 
diff --git a/tests/nested_pandas/series/test_series_utils.py b/tests/nested_pandas/series/test_series_utils.py
index a966c017..e1512349 100644
--- a/tests/nested_pandas/series/test_series_utils.py
+++ b/tests/nested_pandas/series/test_series_utils.py
@@ -3,6 +3,8 @@
 import pytest
 from nested_pandas import NestedDtype
 from nested_pandas.series.utils import (
+    align_chunked_struct_list_offsets,
+    align_struct_list_offsets,
     nested_types_mapper,
     struct_field_names,
     transpose_list_struct_array,
@@ -10,27 +12,27 @@
     transpose_list_struct_type,
     transpose_struct_list_array,
     transpose_struct_list_type,
-    validate_struct_list_array_for_equal_lengths,
+    validate_struct_list_type,
 )
 
 
-def test_validate_struct_list_array_for_equal_lengths():
-    """Test validate_struct_list_array_for_equal_lengths function."""
+def test_align_struct_list_offsets():
+    """Test align_struct_list_offsets function."""
     # Raises for wrong types
     with pytest.raises(ValueError):
-        validate_struct_list_array_for_equal_lengths(pa.array([], type=pa.int64()))
+        align_struct_list_offsets(pa.array([], type=pa.int64()))
     with pytest.raises(ValueError):
-        validate_struct_list_array_for_equal_lengths(pa.array([], type=pa.list_(pa.int64())))
+        align_struct_list_offsets(pa.array([], type=pa.list_(pa.int64())))
 
     # Raises if one of the fields is not a ListArray
     with pytest.raises(ValueError):
-        validate_struct_list_array_for_equal_lengths(
+        align_struct_list_offsets(
             pa.StructArray.from_arrays([pa.array([[1, 2], [3, 4, 5]]), pa.array([1, 2])], ["a", "b"])
         )
 
     # Raises for mismatched lengths
     with pytest.raises(ValueError):
-        validate_struct_list_array_for_equal_lengths(
+        align_struct_list_offsets(
             pa.StructArray.from_arrays(
                 [pa.array([[1, 2], [3, 4, 5]]), pa.array([[1, 2, 3], [4, 5]])], ["a", "b"]
             )
@@ -43,7 +45,107 @@ def test_validate_struct_list_array_for_equal_lengths():
         ],
         names=["a", "b"],
     )
-    assert validate_struct_list_array_for_equal_lengths(input_array) is None
+    assert align_struct_list_offsets(input_array) is input_array
+
+    a = pa.array([[0, 0, 0], [1, 2], [3, 4], [], [5, 6, 7]])[1:]
+    assert a.offsets[0].as_py() == 3
+    b = pa.array([["x", "y"], ["y", "x"], [], ["d", "e", "f"]])
+    assert b.offsets[0].as_py() == 0
+    input_array = pa.StructArray.from_arrays(
+        arrays=[a, b],
+        names=["a", "b"],
+    )
+    aligned_array = align_struct_list_offsets(input_array)
+    assert aligned_array is not input_array
+    assert aligned_array.equals(input_array)
+
+    # Struct-level nulls must survive the re-allocation
+    null_mask = pa.array([False, True, False, False])
+    input_array = pa.StructArray.from_arrays(
+        arrays=[a, b],
+        names=["a", "b"],
+        mask=null_mask,
+    )
+    aligned_array = align_struct_list_offsets(input_array)
+    assert aligned_array is not input_array
+    assert aligned_array.is_null().equals(null_mask)
+
+
+def test_align_chunked_struct_list_offsets():
+    """Test align_chunked_struct_list_offsets function."""
+    # Input is an array, output is chunked array
+    a = pa.array([[1, 2], [3, 4], [], [5, 6, 7]])
+    b = pa.array([["x", "y"], ["y", "x"], [], ["d", "e", "f"]])
+    input_array = pa.StructArray.from_arrays(
+        arrays=[a, b],
+        names=["a", "b"],
+    )
+    output_array = align_chunked_struct_list_offsets(input_array)
+    assert isinstance(output_array, pa.ChunkedArray)
+    assert output_array.equals(pa.chunked_array([input_array]))
+
+    # Input is an "aligned" chunked array
+    input_array = pa.chunked_array(
+        [
+            pa.StructArray.from_arrays(
+                arrays=[a, b],
+                names=["a", "b"],
+            )
+        ]
+        * 2
+    )
+    output_array = align_chunked_struct_list_offsets(input_array)
+    assert output_array.equals(input_array)
+
+    # Input is an "aligned" chunked array, but offsets do not start with zero
+    a = pa.array([[0, 0, 0], [1, 2], [3, 4], [], [5, 6, 7]])[1:]
+    b = pa.array([["a", "a", "a", "a"], ["x", "y"], ["y", "x"], [], ["d", "e", "f"]])[1:]
+    input_array = pa.chunked_array(
+        [
+            pa.StructArray.from_arrays(
+                arrays=[a, b],
+                names=["a", "b"],
+            )
+        ]
+        * 3
+    )
+    output_array = align_chunked_struct_list_offsets(input_array)
+    assert output_array.equals(input_array)
+
+    # Input is a "non-aligned" chunked array
+    a = pa.array([[0, 0, 0], [1, 2], [3, 4], [], [5, 6, 7]])[1:]
+    b = pa.array([["x", "y"], ["y", "x"], [], ["d", "e", "f"]])
+    input_array = pa.chunked_array(
+        [
+            pa.StructArray.from_arrays(
+                arrays=[a, b],
+                names=["a", "b"],
+            )
+        ]
+        * 4
+    )
+    output_array = align_chunked_struct_list_offsets(input_array)
+    assert output_array.equals(input_array)
+
+
+def test_validate_struct_list_type():
+    """Test validate_struct_list_type function."""
+    with pytest.raises(ValueError):
+        validate_struct_list_type(pa.float64())
+
+    with pytest.raises(ValueError):
+        validate_struct_list_type(pa.list_(pa.struct({"a": pa.int64()})))
+
+    with pytest.raises(ValueError):
+        validate_struct_list_type(pa.struct({"a": pa.float64()}))
+
+    with pytest.raises(ValueError):
+        validate_struct_list_type(pa.struct({"a": pa.list_(pa.float64()), "b": pa.float64()}))
+
+    assert (
+        validate_struct_list_type(pa.struct({"a": pa.list_(pa.float64()), "b": pa.list_(pa.float64())}))
+        is None
+    )
 
 
 def test_transpose_struct_list_type():