Skip to content

Commit 3e8c052

Browse files
committed
feat(trainer): add s3 initializers, add ignore_patterns to hf initializers
Signed-off-by: rudeigerc <[email protected]>
1 parent e152b71 commit 3e8c052

File tree

1 file changed

+75
-8
lines changed

1 file changed

+75
-8
lines changed

kubeflow/trainer/types/types.py

Lines changed: 75 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -277,12 +277,42 @@ class TrainJob:
277277
# TODO (andreyvelich): Discuss how to keep these configurations is sync with pkg.initializers.types
278278
@dataclass
279279
class HuggingFaceDatasetInitializer:
280-
"""Configuration for downloading datasets from HuggingFace Hub."""
280+
"""Configuration for downloading datasets from HuggingFace Hub.
281+
282+
Args:
283+
storage_uri (`str`): The HuggingFace Hub model identifier in the format 'hf://username/repo_name'.
284+
ignore_patterns (`Optional[list[str]]`): List of file patterns to ignore during download.
285+
access_token (`Optional[str]`): HuggingFace Hub access token for private datasets.
286+
"""
281287

282288
storage_uri: str
289+
ignore_patterns: Optional[list[str]] = None
283290
access_token: Optional[str] = None
284291

285292

293+
@dataclass
294+
class S3DatasetInitializer:
295+
"""Configuration for downloading datasets from S3-compatible storage.
296+
297+
Args:
298+
storage_uri (`str`): The S3 URI for the model in the format 's3://bucket-name/path/to/model'.
299+
ignore_patterns (`Optional[list[str]]`): List of file patterns to ignore during download.
300+
endpoint (`Optional[str]`): Custom S3 endpoint URL.
301+
access_key_id (`Optional[str]`): Access key for authentication.
302+
secret_access_key (`Optional[str]`): Secret key for authentication.
303+
region (`Optional[str]`): Region used in instantiating the client.
304+
role_arn (`Optional[str]`): The ARN of the role you want to assume.
305+
"""
306+
307+
storage_uri: str
308+
ignore_patterns: Optional[list[str]] = None
309+
endpoint: Optional[str] = None
310+
access_key_id: Optional[str] = None
311+
secret_access_key: Optional[str] = None
312+
region: Optional[str] = None
313+
role_arn: Optional[str] = None
314+
315+
286316
@dataclass
287317
class DataCacheInitializer:
288318
"""Configuration for distributed data caching system for training workloads.
@@ -332,23 +362,60 @@ def __post_init__(self):
332362
# Configuration for the HuggingFace model initializer.
333363
@dataclass
334364
class HuggingFaceModelInitializer:
365+
"""Configuration for downloading models from HuggingFace Hub.
366+
367+
Args:
368+
storage_uri (`str`): The HuggingFace Hub model identifier in the format 'hf://username/repo_name'.
369+
ignore_patterns (`Optional[list[str]]`): List of file patterns to ignore during download.
370+
access_token (`Optional[str]`): HuggingFace Hub access token.
371+
"""
372+
335373
storage_uri: str
374+
ignore_patterns: Optional[list[str]] = None
336375
access_token: Optional[str] = None
337376

338377

378+
@dataclass
379+
class S3ModelInitializer:
380+
"""Configuration for downloading models from S3-compatible storage.
381+
382+
Args:
383+
storage_uri (`str`): The S3 URI for the model in the format 's3://bucket-name/path/to/model'.
384+
ignore_patterns (`Optional[list[str]]`): List of file patterns to ignore during download.
385+
Defaults to `['*.msgpack', '*.h5', '*.bin', '.pt', '.pth']`.
386+
endpoint (`Optional[str]`): Custom S3 endpoint URL.
387+
access_key_id (`Optional[str]`): Access key for authentication.
388+
secret_access_key (`Optional[str]`): Secret key for authentication.
389+
region (`Optional[str]`): Region used in instantiating the client.
390+
role_arn (`Optional[str]`): The ARN of the role you want to assume.
391+
"""
392+
393+
storage_uri: str
394+
ignore_patterns: Optional[list[str]] = field(
395+
default_factory=lambda: ["*.msgpack", "*.h5", "*.bin", ".pt", ".pth"]
396+
)
397+
endpoint: Optional[str] = None
398+
access_key_id: Optional[str] = None
399+
secret_access_key: Optional[str] = None
400+
region: Optional[str] = None
401+
role_arn: Optional[str] = None
402+
403+
339404
@dataclass
340405
class Initializer:
341406
"""Initializer defines configurations for dataset and pre-trained model initialization
342407
343408
Args:
344-
dataset (`Optional[Union[HuggingFaceDatasetInitializer, DataCacheInitializer]]`):
409+
dataset (`Optional[Union[HuggingFaceDatasetInitializer, S3DatasetInitializer, DataCacheInitializer]]`):
345410
The configuration for one of the supported dataset initializers.
346-
model (`Optional[HuggingFaceModelInitializer]`): The configuration for one of the
347-
supported model initializers.
348-
"""
349-
350-
dataset: Optional[Union[HuggingFaceDatasetInitializer, DataCacheInitializer]] = None
351-
model: Optional[HuggingFaceModelInitializer] = None
411+
model (`Optional[Union[HuggingFaceModelInitializer, S3ModelInitializer]]`):
412+
The configuration for one of the supported model initializers.
413+
""" # noqa: E501
414+
415+
dataset: Optional[
416+
Union[HuggingFaceDatasetInitializer, S3DatasetInitializer, DataCacheInitializer]
417+
] = None
418+
model: Optional[Union[HuggingFaceModelInitializer, S3ModelInitializer]] = None
352419

353420

354421
# TODO (andreyvelich): Add train() and optimize() methods to this class.

0 commit comments

Comments
 (0)