@@ -277,12 +277,42 @@ class TrainJob:
277277# TODO (andreyvelich): Discuss how to keep these configurations is sync with pkg.initializers.types
278278@dataclass
279279class HuggingFaceDatasetInitializer :
280- """Configuration for downloading datasets from HuggingFace Hub."""
280+ """Configuration for downloading datasets from HuggingFace Hub.
281+
282+ Args:
283+ storage_uri (`str`): The HuggingFace Hub model identifier in the format 'hf://username/repo_name'.
284+ ignore_patterns (`Optional[list[str]]`): List of file patterns to ignore during download.
285+ access_token (`Optional[str]`): HuggingFace Hub access token for private datasets.
286+ """
281287
282288 storage_uri : str
289+ ignore_patterns : Optional [list [str ]] = None
283290 access_token : Optional [str ] = None
284291
285292
293+ @dataclass
294+ class S3DatasetInitializer :
295+ """Configuration for downloading datasets from S3-compatible storage.
296+
297+ Args:
298+ storage_uri (`str`): The S3 URI for the model in the format 's3://bucket-name/path/to/model'.
299+ ignore_patterns (`Optional[list[str]]`): List of file patterns to ignore during download.
300+ endpoint (`Optional[str]`): Custom S3 endpoint URL.
301+ access_key_id (`Optional[str]`): Access key for authentication.
302+ secret_access_key (`Optional[str]`): Secret key for authentication.
303+ region (`Optional[str]`): Region used in instantiating the client.
304+ role_arn (`Optional[str]`): The ARN of the role you want to assume.
305+ """
306+
307+ storage_uri : str
308+ ignore_patterns : Optional [list [str ]] = None
309+ endpoint : Optional [str ] = None
310+ access_key_id : Optional [str ] = None
311+ secret_access_key : Optional [str ] = None
312+ region : Optional [str ] = None
313+ role_arn : Optional [str ] = None
314+
315+
286316@dataclass
287317class DataCacheInitializer :
288318 """Configuration for distributed data caching system for training workloads.
@@ -332,23 +362,60 @@ def __post_init__(self):
332362# Configuration for the HuggingFace model initializer.
333363@dataclass
334364class HuggingFaceModelInitializer :
365+ """Configuration for downloading models from HuggingFace Hub.
366+
367+ Args:
368+ storage_uri (`str`): The HuggingFace Hub model identifier in the format 'hf://username/repo_name'.
369+ ignore_patterns (`Optional[list[str]]`): List of file patterns to ignore during download.
370+ access_token (`Optional[str]`): HuggingFace Hub access token.
371+ """
372+
335373 storage_uri : str
374+ ignore_patterns : Optional [list [str ]] = None
336375 access_token : Optional [str ] = None
337376
338377
378+ @dataclass
379+ class S3ModelInitializer :
380+ """Configuration for downloading models from S3-compatible storage.
381+
382+ Args:
383+ storage_uri (`str`): The S3 URI for the model in the format 's3://bucket-name/path/to/model'.
384+ ignore_patterns (`Optional[list[str]]`): List of file patterns to ignore during download.
385+ Defaults to `['*.msgpack', '*.h5', '*.bin', '.pt', '.pth']`.
386+ endpoint (`Optional[str]`): Custom S3 endpoint URL.
387+ access_key_id (`Optional[str]`): Access key for authentication.
388+ secret_access_key (`Optional[str]`): Secret key for authentication.
389+ region (`Optional[str]`): Region used in instantiating the client.
390+ role_arn (`Optional[str]`): The ARN of the role you want to assume.
391+ """
392+
393+ storage_uri : str
394+ ignore_patterns : Optional [list [str ]] = field (
395+ default_factory = lambda : ["*.msgpack" , "*.h5" , "*.bin" , ".pt" , ".pth" ]
396+ )
397+ endpoint : Optional [str ] = None
398+ access_key_id : Optional [str ] = None
399+ secret_access_key : Optional [str ] = None
400+ region : Optional [str ] = None
401+ role_arn : Optional [str ] = None
402+
403+
339404@dataclass
340405class Initializer :
341406 """Initializer defines configurations for dataset and pre-trained model initialization
342407
343408 Args:
344- dataset (`Optional[Union[HuggingFaceDatasetInitializer, DataCacheInitializer]]`):
409+ dataset (`Optional[Union[HuggingFaceDatasetInitializer, S3DatasetInitializer, DataCacheInitializer]]`):
345410 The configuration for one of the supported dataset initializers.
346- model (`Optional[HuggingFaceModelInitializer]`): The configuration for one of the
347- supported model initializers.
348- """
349-
350- dataset : Optional [Union [HuggingFaceDatasetInitializer , DataCacheInitializer ]] = None
351- model : Optional [HuggingFaceModelInitializer ] = None
411+ model (`Optional[Union[HuggingFaceModelInitializer, S3ModelInitializer]]`):
412+ The configuration for one of the supported model initializers.
413+ """ # noqa: E501
414+
415+ dataset : Optional [
416+ Union [HuggingFaceDatasetInitializer , S3DatasetInitializer , DataCacheInitializer ]
417+ ] = None
418+ model : Optional [Union [HuggingFaceModelInitializer , S3ModelInitializer ]] = None
352419
353420
354421# TODO (andreyvelich): Add train() and optimize() methods to this class.
0 commit comments