
Commit 153a726

refactor: add fragment_group_size to reduce lance scan task (#5261)
## Changes Made

When a dataset contains many fragments, the current implementation assigns one scan task per fragment, which leads to long planning times. This change adds fragment filtering and fragment grouping so that multiple fragments can be combined into a single scan task, reducing the total number of tasks.

## Checklist

- [ ] Documented in API Docs (if applicable)
- [ ] Documented in User Guide (if applicable)
- [ ] If adding a new documentation page, doc is added to `docs/mkdocs.yml` navigation
- [ ] Documentation builds and is formatted properly (tag @/ccmao1130 for docs review)
1 parent 93a0dcd commit 153a726
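
The planning win is easy to quantify: with one task per fragment, a scan over N fragments plans N tasks; with grouping, it plans ceil(N / fragment_group_size). A minimal sketch of that arithmetic (the fragment counts below are illustrative, not from this PR):

```python
import math

num_fragments = 10_000  # illustrative fragment count

# One task per fragment (the old behavior) -> 10_000 scan tasks.
# With grouping, the planner emits ceil(num_fragments / group_size) tasks.
for group_size in (1, 8, 64):
    num_tasks = math.ceil(num_fragments / group_size)
    print(f"fragment_group_size={group_size:>2} -> {num_tasks} scan tasks")
```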

File tree

3 files changed: +79 −27 lines changed

daft/io/lance/_lance.py

Lines changed: 9 additions & 3 deletions

```diff
@@ -36,6 +36,7 @@ def read_lance(
     index_cache_size: Optional[int] = None,
     default_scan_options: Optional[dict[str, str]] = None,
     metadata_cache_size_bytes: Optional[int] = None,
+    fragment_group_size: Optional[int] = None,
 ) -> DataFrame:
     """Create a DataFrame from a LanceDB table.

@@ -60,7 +61,6 @@ def read_lance(

         Roughly, for an ``IVF_PQ`` partition with ``n`` rows, the size of each index
         page equals the combination of the pq code (``np.array([n,pq], dtype=uint8))``
-        and the row ids (``np.array([n], dtype=uint64)``).
         Approximately, ``n = Total Rows / number of IVF partitions``.
         ``pq = number of PQ sub-vectors``.
     storage_options : optional, dict
@@ -82,11 +82,13 @@ def read_lance(
         Size of the metadata cache in bytes. This cache is used to store metadata
         information about the dataset, such as schema and statistics. If not specified,
         a default size will be used.
+    fragment_group_size : optional, int
+        Number of fragments to group together in a single scan task. If None or <= 1,
+        each fragment will be processed individually (default behavior).

     Returns:
         DataFrame: a DataFrame with the schema converted from the specified LanceDB table

-    Note:
     This function requires the use of [LanceDB](https://lancedb.github.io/lancedb/), which is the Python library for the LanceDB project.
     To ensure that this is installed with Daft, you may install: `pip install daft[lance]`

@@ -104,6 +106,10 @@ def read_lance(
     Read a local LanceDB table and specify a version:
     >>> df = daft.read_lance("s3://my-lancedb-bucket/data/", version=1)
     >>> df.show()
+
+    Read a local LanceDB table with fragment grouping:
+    >>> df = daft.read_lance("s3://my-lancedb-bucket/data/", fragment_group_size=5)
+    >>> df.show()
     """
     try:
         import lance
@@ -126,7 +132,7 @@ def read_lance(
         default_scan_options=default_scan_options,
         metadata_cache_size_bytes=metadata_cache_size_bytes,
     )
-    lance_operator = LanceDBScanOperator(ds)
+    lance_operator = LanceDBScanOperator(ds, fragment_group_size=fragment_group_size)

     handle = ScanOperatorHandle.from_python_scan_operator(lance_operator)
     builder = LogicalPlanBuilder.from_tabular_scan(scan_operator=handle)
```
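
As a usage sketch of the new parameter (the dataset path is illustrative and assumed to point at a Lance dataset with many fragments; `num_partitions()` is used the same way in the test further below):

```python
import daft

# Default: one scan task per non-empty fragment.
df_default = daft.read_lance("data/my_table.lance")

# Grouped: up to 5 fragments are batched into each scan task,
# shrinking the planned task count for fragment-heavy datasets.
df_grouped = daft.read_lance("data/my_table.lance", fragment_group_size=5)

print(df_default.num_partitions(), df_grouped.num_partitions())
```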

daft/io/lance/lance_scan.py

Lines changed: 56 additions & 24 deletions

```diff
@@ -52,9 +52,10 @@ def _lancedb_count_result_function(


 class LanceDBScanOperator(ScanOperator, SupportsPushdownFilters):
-    def __init__(self, ds: "lance.LanceDataset"):
+    def __init__(self, ds: "lance.LanceDataset", fragment_group_size: Optional[int] = None):
         self._ds = ds
         self._pushed_filters: Union[list[PyExpr], None] = None
+        self._fragment_group_size = fragment_group_size

     def name(self) -> str:
         return "LanceDBScanOperator"
@@ -203,30 +204,61 @@ def _create_regular_scan_tasks(
     ) -> Iterator[ScanTask]:
         """Create regular scan tasks without count pushdown."""
         fragments = self._ds.get_fragments()
-        for fragment in fragments:
-            # TODO: figure out how if we can get this metadata from LanceDB fragments cheaply
-            size_bytes = None
-            stats = None
-
-            # NOTE: `fragment.count_rows()` should result in 1 IO call for the data file
-            # (1 fragment = 1 data file) and 1 more IO call for the deletion file (if present).
-            # This could potentially be expensive to perform serially if there are thousands of files.
-            # Given that num_rows isn't leveraged for much at the moment, and without statistics
-            # we will probably end up materializing the data anyways for any operations, we leave this
-            # as None.
-            num_rows = None
-            pushed_expr = self._combine_filters_to_arrow()
+        pushed_expr = self._combine_filters_to_arrow()

-            yield ScanTask.python_factory_func_scan_task(
-                module=_lancedb_table_factory_function.__module__,
-                func_name=_lancedb_table_factory_function.__name__,
-                func_args=(self._ds, [fragment.fragment_id], required_columns, pushed_expr, pushdowns.limit),
-                schema=self.schema()._schema,
-                num_rows=num_rows,
-                size_bytes=size_bytes,
-                pushdowns=pushdowns,
-                stats=stats,
-            )
+        if self._fragment_group_size is None or self._fragment_group_size <= 1:
+            # Default behavior: one fragment per task
+            for fragment in fragments:
+                size_bytes = None
+                stats = None
+                num_rows = None
+                if fragment.count_rows(pushed_expr) == 0:
+                    continue
+
+                yield ScanTask.python_factory_func_scan_task(
+                    module=_lancedb_table_factory_function.__module__,
+                    func_name=_lancedb_table_factory_function.__name__,
+                    func_args=(self._ds, [fragment.fragment_id], required_columns, pushed_expr, pushdowns.limit),
+                    schema=self.schema()._schema,
+                    num_rows=num_rows,
+                    size_bytes=size_bytes,
+                    pushdowns=pushdowns,
+                    stats=stats,
+                )
+        else:
+            # Group fragments
+            fragment_groups = []
+            current_group = []
+
+            for fragment in fragments:
+                if fragment.count_rows(pushed_expr) == 0:
+                    continue
+                current_group.append(fragment)
+                if len(current_group) >= self._fragment_group_size:
+                    fragment_groups.append(current_group)
+                    current_group = []
+
+            # Add the last group if it has any fragments
+            if current_group:
+                fragment_groups.append(current_group)
+
+            # Create scan tasks for each fragment group
+            for fragment_group in fragment_groups:
+                fragment_ids = [fragment.fragment_id for fragment in fragment_group]
+                size_bytes = None
+                stats = None
+                num_rows = None
+
+                yield ScanTask.python_factory_func_scan_task(
+                    module=_lancedb_table_factory_function.__module__,
+                    func_name=_lancedb_table_factory_function.__name__,
+                    func_args=(self._ds, fragment_ids, required_columns, pushed_expr, pushdowns.limit),
+                    schema=self.schema()._schema,
+                    num_rows=num_rows,
+                    size_bytes=size_bytes,
+                    pushdowns=pushdowns,
+                    stats=stats,
+                )

     def _combine_filters_to_arrow(self) -> Optional["pa.compute.Expression"]:
         if self._pushed_filters is not None:
```
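
Setting aside the ScanTask plumbing, the else-branch is a fixed-size chunking of the fragments that survive the `count_rows` filter. A standalone sketch of the same logic (the `group_fragments` helper is hypothetical, not part of this PR):

```python
from typing import Iterator, Optional

def group_fragments(fragment_ids: list[int], group_size: Optional[int]) -> Iterator[list[int]]:
    """Chunk fragment ids into batches of at most `group_size`.

    Mirrors the branch above: a group_size of None or <= 1 falls back to
    one fragment per batch, i.e. the default one-task-per-fragment behavior.
    """
    if group_size is None or group_size <= 1:
        for fid in fragment_ids:
            yield [fid]
        return

    current: list[int] = []
    for fid in fragment_ids:
        current.append(fid)
        if len(current) >= group_size:
            yield current
            current = []
    if current:  # the trailing partial batch becomes its own task
        yield current

# 10 fragments with group_size=3 -> 4 batches: [0,1,2], [3,4,5], [6,7,8], [9]
assert len(list(group_fragments(list(range(10)), 3))) == 4
```

The trailing partial batch is why the test below expects 4 tasks for 10 fragments with `fragment_group_size=3`.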

tests/io/lancedb/test_lancedb_reads.py

Lines changed: 14 additions & 0 deletions

```diff
@@ -142,6 +142,20 @@ def test_lancedb_read_pushdown(lance_dataset_path, capsys):
     ), f"Physical plan contains {filter_count} Filter nodes and {scan_source_count} ScanTaskSource nodes, which is not expected"


+def test_lancedb_read_parallelism_fragment_merging(large_lance_dataset_path):
+    """Test that fragment_group_size reduces the number of scan tasks by merging fragments."""
+    df_no_fragment_group = daft.read_lance(large_lance_dataset_path)
+    assert len(lance.dataset(large_lance_dataset_path).get_fragments()) == df_no_fragment_group.num_partitions()
+
+    df = daft.read_lance(large_lance_dataset_path, fragment_group_size=3)
+    df.explain(show_all=True)
+    assert df.num_partitions() == 4  # 10 fragments, group size 3 -> 4 scan tasks
+
+    result = df.to_pydict()
+    assert len(result["vector"]) == 10000
+    assert len(result["big_int"]) == 10000
+
+
 class TestLanceDBCountPushdown:
     tmp_data = {
         "a": ["a", "b", "c", "d", "e", None],
```
