Skip to content

Commit b0b3799

Browse files
authored
perf: download products and assets in parallel (#1890)
1 parent 90fafbe commit b0b3799

File tree

15 files changed

+2471
-1699
lines changed

15 files changed

+2471
-1699
lines changed

docs/notebooks/api_user_guide/7_download.ipynb

Lines changed: 2033 additions & 1493 deletions
Large diffs are not rendered by default.

eodag/api/core.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
from eodag.utils.stac_reader import fetch_stac_items
8888

8989
if TYPE_CHECKING:
90+
from concurrent.futures import ThreadPoolExecutor
9091
from shapely.geometry.base import BaseGeometry
9192

9293
from eodag.api.product import EOProduct
@@ -1919,6 +1920,7 @@ def download_all(
19191920
search_result: SearchResult,
19201921
downloaded_callback: Optional[DownloadedCallback] = None,
19211922
progress_callback: Optional[ProgressCallback] = None,
1923+
executor: Optional[ThreadPoolExecutor] = None,
19221924
wait: float = DEFAULT_DOWNLOAD_WAIT,
19231925
timeout: float = DEFAULT_DOWNLOAD_TIMEOUT,
19241926
**kwargs: Unpack[DownloadConf],
@@ -1936,6 +1938,8 @@ def download_all(
19361938
size as inputs and handle progress bar
19371939
creation and update to give the user a
19381940
feedback on the download progress
1941+
:param executor: (optional) An executor to download EO products of ``search_result`` in parallel
1942+
which will also be reused to download assets of these products in parallel.
19391943
:param wait: (optional) If download fails, wait time in minutes between
19401944
two download tries of the same product
19411945
:param timeout: (optional) If download fails, maximum time in minutes
@@ -1956,15 +1960,15 @@ def download_all(
19561960
paths = []
19571961
if search_result:
19581962
logger.info("Downloading %s products", len(search_result))
1959-
# Get download plugin using first product assuming product from several provider
1960-
# aren't mixed into a search result
1963+
# Get download plugin using first product assuming all plugins use base.Download.download_all
19611964
download_plugin = self._plugins_manager.get_download_plugin(
19621965
search_result[0]
19631966
)
19641967
paths = download_plugin.download_all(
19651968
search_result,
19661969
downloaded_callback=downloaded_callback,
19671970
progress_callback=progress_callback,
1971+
executor=executor,
19681972
wait=wait,
19691973
timeout=timeout,
19701974
**kwargs,
@@ -2026,6 +2030,7 @@ def download(
20262030
self,
20272031
product: EOProduct,
20282032
progress_callback: Optional[ProgressCallback] = None,
2033+
executor: Optional[ThreadPoolExecutor] = None,
20292034
wait: float = DEFAULT_DOWNLOAD_WAIT,
20302035
timeout: float = DEFAULT_DOWNLOAD_TIMEOUT,
20312036
**kwargs: Unpack[DownloadConf],
@@ -2056,6 +2061,8 @@ def download(
20562061
size as inputs and handle progress bar
20572062
creation and update to give the user a
20582063
feedback on the download progress
2064+
:param executor: (optional) An executor to download assets of ``product`` in parallel, if it has any.
2065+
If ``None``, a default executor will be created.
20592066
:param wait: (optional) If download fails, wait time in minutes between
20602067
two download tries
20612068
:param timeout: (optional) If download fails, maximum time in minutes
@@ -2080,7 +2087,11 @@ def download(
20802087
return uri_to_path(product.location)
20812088
self._setup_downloader(product)
20822089
path = product.download(
2083-
progress_callback=progress_callback, wait=wait, timeout=timeout, **kwargs
2090+
progress_callback=progress_callback,
2091+
executor=executor,
2092+
wait=wait,
2093+
timeout=timeout,
2094+
**kwargs,
20842095
)
20852096

20862097
return path

eodag/api/product/_product.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from eodag.utils.repr import dict_to_html_table
6262

6363
if TYPE_CHECKING:
64+
from concurrent.futures import ThreadPoolExecutor
6465
from shapely.geometry.base import BaseGeometry
6566

6667
from eodag.api.product.drivers.base import DatasetDriver
@@ -122,6 +123,8 @@ class EOProduct:
122123
search_kwargs: Any
123124
#: Datetime for download next try
124125
next_try: datetime
126+
#: Stream for requests
127+
_stream: requests.Response
125128

126129
def __init__(
127130
self, provider: str, properties: dict[str, Any], **kwargs: Any
@@ -337,6 +340,7 @@ def register_downloader(
337340
def download(
338341
self,
339342
progress_callback: Optional[ProgressCallback] = None,
343+
executor: Optional[ThreadPoolExecutor] = None,
340344
wait: float = DEFAULT_DOWNLOAD_WAIT,
341345
timeout: float = DEFAULT_DOWNLOAD_TIMEOUT,
342346
**kwargs: Unpack[DownloadConf],
@@ -353,6 +357,8 @@ def download(
353357
size as inputs and handle progress bar
354358
creation and update to give the user a
355359
feedback on the download progress
360+
:param executor: (optional) An executor to download assets of the product in parallel, if it has any.
361+
If ``None``, a default executor will be created.
356362
:param wait: (optional) If download fails, wait time in minutes between
357363
two download tries
358364
:param timeout: (optional) If download fails, maximum time in minutes
@@ -377,17 +383,26 @@ def download(
377383
)
378384

379385
progress_callback, close_progress_callback = self._init_progress_bar(
380-
progress_callback
386+
progress_callback, executor
381387
)
388+
382389
fs_path = self.downloader.download(
383390
self,
384391
auth=auth,
385392
progress_callback=progress_callback,
393+
executor=executor,
386394
wait=wait,
387395
timeout=timeout,
388396
**kwargs,
389397
)
390398

399+
# shutdown executor if it was not created during parallel product downloads
400+
if (
401+
executor is not None
402+
and executor._thread_name_prefix != "eodag-download-all"
403+
):
404+
executor.shutdown(wait=True)
405+
391406
# close progress bar if needed
392407
if close_progress_callback:
393408
progress_callback.close()
@@ -408,15 +423,22 @@ def download(
408423
return fs_path
409424

410425
def _init_progress_bar(
411-
self, progress_callback: Optional[ProgressCallback]
426+
self,
427+
progress_callback: Optional[ProgressCallback],
428+
executor: Optional[ThreadPoolExecutor],
412429
) -> tuple[ProgressCallback, bool]:
430+
# take the progress bar position from the executor's counter so that
431+
# bars do not overwrite each other in case of parallel downloads
432+
count = executor._counter() if executor is not None else 1 # type: ignore
433+
413434
# progress bar init
414435
if progress_callback is None:
415-
progress_callback = ProgressCallback(position=1)
436+
progress_callback = ProgressCallback(position=count)
416437
# one shot progress callback to close after download
417438
close_progress_callback = True
418439
else:
419440
close_progress_callback = False
441+
progress_callback.pos = count
420442
# update units as bar may have been previously used for extraction
421443
progress_callback.unit = "B"
422444
progress_callback.unit_scale = True

eodag/cli.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from urllib.parse import parse_qs
4949

5050
import click
51+
from concurrent.futures import ThreadPoolExecutor
5152

5253
from eodag.api.collection import CollectionsList
5354
from eodag.api.core import EODataAccessGateway, SearchResult
@@ -556,6 +557,11 @@ def discover_col(ctx: Context, **kwargs: Any) -> None:
556557
type=click.Path(dir_okay=True, file_okay=False),
557558
help="Products or quicklooks download directory (Default: local temporary directory)",
558559
)
560+
@click.option(
561+
"--max-workers",
562+
type=int,
563+
help="The maximum number of workers to use for downloading products and assets in parallel",
564+
)
559565
@click.pass_context
560566
def download(ctx: Context, **kwargs: Any) -> None:
561567
"""Download a bunch of products from a serialized search result"""
@@ -601,7 +607,10 @@ def download(ctx: Context, **kwargs: Any) -> None:
601607

602608
else:
603609
# Download products
604-
downloaded_files = satim_api.download_all(search_results, output_dir=output_dir)
610+
executor = ThreadPoolExecutor(max_workers=kwargs.pop("max_workers"))
611+
downloaded_files = satim_api.download_all(
612+
search_results, output_dir=output_dir, executor=executor
613+
)
605614
if downloaded_files and len(downloaded_files) > 0:
606615
for downloaded_file in downloaded_files:
607616
if downloaded_file is None:

eodag/plugins/apis/ecmwf.py

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,15 @@
4848
if TYPE_CHECKING:
4949
from typing import Any, Optional, Union
5050

51+
from concurrent.futures import ThreadPoolExecutor
5152
from mypy_boto3_s3 import S3ServiceResource
5253
from requests.auth import AuthBase
5354

5455
from eodag.api.product import EOProduct
5556
from eodag.api.search_result import SearchResult
5657
from eodag.config import PluginConfig
5758
from eodag.types.download_args import DownloadConf
58-
from eodag.utils import DownloadedCallback, ProgressCallback, Unpack
59+
from eodag.utils import ProgressCallback, Unpack
5960

6061

6162
logger = logging.getLogger("eodag.apis.ecmwf")
@@ -185,6 +186,7 @@ def download(
185186
product: EOProduct,
186187
auth: Optional[Union[AuthBase, S3ServiceResource]] = None,
187188
progress_callback: Optional[ProgressCallback] = None,
189+
executor: Optional[ThreadPoolExecutor] = None,
188190
wait: float = DEFAULT_DOWNLOAD_WAIT,
189191
timeout: float = DEFAULT_DOWNLOAD_TIMEOUT,
190192
**kwargs: Unpack[DownloadConf],
@@ -269,29 +271,6 @@ def download(
269271
product.location = path_to_uri(product_path)
270272
return product_path
271273

272-
def download_all(
273-
self,
274-
products: SearchResult,
275-
auth: Optional[Union[AuthBase, S3ServiceResource]] = None,
276-
downloaded_callback: Optional[DownloadedCallback] = None,
277-
progress_callback: Optional[ProgressCallback] = None,
278-
wait: float = DEFAULT_DOWNLOAD_WAIT,
279-
timeout: float = DEFAULT_DOWNLOAD_TIMEOUT,
280-
**kwargs: Unpack[DownloadConf],
281-
) -> list[str]:
282-
"""
283-
Download all using parent (base plugin) method
284-
"""
285-
return super(EcmwfApi, self).download_all(
286-
products,
287-
auth=auth,
288-
downloaded_callback=downloaded_callback,
289-
progress_callback=progress_callback,
290-
wait=wait,
291-
timeout=timeout,
292-
**kwargs,
293-
)
294-
295274
def clear(self) -> None:
296275
"""Clear search context"""
297276
pass

eodag/plugins/apis/usgs.py

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,13 @@
5757
)
5858

5959
if TYPE_CHECKING:
60+
from concurrent.futures import ThreadPoolExecutor
6061
from mypy_boto3_s3 import S3ServiceResource
6162
from requests.auth import AuthBase
6263

6364
from eodag.config import PluginConfig
6465
from eodag.types.download_args import DownloadConf
65-
from eodag.utils import DownloadedCallback, Unpack
66+
from eodag.utils import Unpack
6667

6768
logger = logging.getLogger("eodag.apis.usgs")
6869

@@ -312,6 +313,7 @@ def download(
312313
product: EOProduct,
313314
auth: Optional[Union[AuthBase, S3ServiceResource]] = None,
314315
progress_callback: Optional[ProgressCallback] = None,
316+
executor: Optional[ThreadPoolExecutor] = None,
315317
wait: float = DEFAULT_DOWNLOAD_WAIT,
316318
timeout: float = DEFAULT_DOWNLOAD_TIMEOUT,
317319
**kwargs: Unpack[DownloadConf],
@@ -477,26 +479,3 @@ def download_request(
477479
shutil.move(fs_path, new_fs_path)
478480
product.location = path_to_uri(new_fs_path)
479481
return new_fs_path
480-
481-
def download_all(
482-
self,
483-
products: SearchResult,
484-
auth: Optional[Union[AuthBase, S3ServiceResource]] = None,
485-
downloaded_callback: Optional[DownloadedCallback] = None,
486-
progress_callback: Optional[ProgressCallback] = None,
487-
wait: float = DEFAULT_DOWNLOAD_WAIT,
488-
timeout: float = DEFAULT_DOWNLOAD_TIMEOUT,
489-
**kwargs: Unpack[DownloadConf],
490-
) -> list[str]:
491-
"""
492-
Download all using parent (base plugin) method
493-
"""
494-
return super(UsgsApi, self).download_all(
495-
products,
496-
auth=auth,
497-
downloaded_callback=downloaded_callback,
498-
progress_callback=progress_callback,
499-
wait=wait,
500-
timeout=timeout,
501-
**kwargs,
502-
)

0 commit comments

Comments
 (0)