Skip to content

Commit 2d7a8e6

Browse files
Add simplified Training Options for TrainJob labels and annotations
Signed-off-by: Abhijeet Dhumal <[email protected]>
1 parent ffc3d62 commit 2d7a8e6

File tree

7 files changed

+703
-417
lines changed

7 files changed

+703
-417
lines changed

kubeflow/trainer/api/trainer_client.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
# limitations under the License.
1414

1515
import logging
16-
from typing import Optional, Union, Iterator
17-
16+
from typing import List, Optional, Union, Iterator
1817
from kubeflow.trainer.constants import constants
1918
from kubeflow.trainer.types import types
19+
from kubeflow.trainer.options.options import Option
2020
from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend
2121
from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig
2222
from kubeflow.trainer.backends.localprocess.backend import LocalProcessBackend
@@ -93,6 +93,7 @@ def train(
9393
runtime: Optional[types.Runtime] = None,
9494
initializer: Optional[types.Initializer] = None,
9595
trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None,
96+
options: Optional[List[Option]] = None,
9697
) -> str:
9798
"""Create a TrainJob. You can configure the TrainJob using one of these trainers:
9899
@@ -107,6 +108,8 @@ def train(
107108
initializer: Optional configuration for the dataset and model initializers.
108109
trainer: Optional configuration for a CustomTrainer or BuiltinTrainer. If not specified,
109110
the TrainJob will use the runtime's default values.
111+
options: Optional list of configuration options to apply to the TrainJob. Use
112+
WithLabels and WithAnnotations for basic metadata configuration.
110113
111114
Returns:
112115
The unique name of the TrainJob that has been generated.
@@ -116,7 +119,24 @@ def train(
116119
TimeoutError: Timeout to create TrainJobs.
117120
RuntimeError: Failed to create TrainJobs.
118121
"""
119-
return self.backend.train(runtime=runtime, initializer=initializer, trainer=trainer)
122+
job_spec = {}
123+
124+
if options:
125+
for option in options:
126+
option.apply(job_spec)
127+
128+
metadata_section = job_spec.get("metadata", {})
129+
130+
labels = metadata_section.get("labels") or None
131+
annotations = metadata_section.get("annotations") or None
132+
133+
return self.backend.train(
134+
runtime=runtime,
135+
initializer=initializer,
136+
trainer=trainer,
137+
labels=labels,
138+
annotations=annotations,
139+
)
120140

121141
def list_jobs(self, runtime: Optional[types.Runtime] = None) -> list[types.TrainJob]:
122142
"""List of the created TrainJobs. If a runtime is specified, only TrainJobs associated with

kubeflow/trainer/backends/kubernetes/backend.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import string
2020
import time
2121
import uuid
22-
from typing import Optional, Union, Iterator
22+
from typing import Dict, Optional, Union, Iterator
2323
import re
2424

2525
from kubeflow.trainer.constants import constants
@@ -181,6 +181,8 @@ def train(
181181
runtime: Optional[types.Runtime] = None,
182182
initializer: Optional[types.Initializer] = None,
183183
trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None,
184+
labels: Optional[Dict[str, str]] = None,
185+
annotations: Optional[Dict[str, str]] = None,
184186
) -> str:
185187
if runtime is None:
186188
runtime = self.get_runtime(constants.TORCH_RUNTIME)
@@ -216,7 +218,11 @@ def train(
216218
train_job = models.TrainerV1alpha1TrainJob(
217219
apiVersion=constants.API_VERSION,
218220
kind=constants.TRAINJOB_KIND,
219-
metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(name=train_job_name),
221+
metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
222+
name=train_job_name,
223+
labels=labels,
224+
annotations=annotations
225+
),
220226
spec=models.TrainerV1alpha1TrainJobSpec(
221227
runtimeRef=models.TrainerV1alpha1RuntimeRef(name=runtime.name),
222228
trainer=(trainer_crd if trainer_crd != models.TrainerV1alpha1Trainer() else None),

kubeflow/trainer/backends/kubernetes/backend_test.py

Lines changed: 107 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,15 @@
2525
import string
2626
import uuid
2727
from dataclasses import asdict
28-
from typing import Optional
28+
from typing import Dict, Optional
2929
from unittest.mock import Mock, patch
3030

3131
import pytest
3232
from kubeflow_trainer_api import models
3333

3434
from kubeflow.trainer.constants import constants
3535
from kubeflow.trainer.types import types
36+
from kubeflow.trainer.options import WithLabels, WithAnnotations
3637
from kubeflow.trainer.utils import utils
3738
from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend
3839
from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig
@@ -253,14 +254,20 @@ def get_train_job(
253254
runtime_name: str,
254255
train_job_name: str = BASIC_TRAIN_JOB_NAME,
255256
train_job_trainer: Optional[models.TrainerV1alpha1Trainer] = None,
257+
labels: Optional[Dict[str, str]] = None,
258+
annotations: Optional[Dict[str, str]] = None,
256259
) -> models.TrainerV1alpha1TrainJob:
257260
"""
258261
Create a mock TrainJob object with optional trainer configurations.
259262
"""
260263
train_job = models.TrainerV1alpha1TrainJob(
261264
apiVersion=constants.API_VERSION,
262265
kind=constants.TRAINJOB_KIND,
263-
metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(name=train_job_name),
266+
metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
267+
name=train_job_name,
268+
labels=labels,
269+
annotations=annotations
270+
),
264271
spec=models.TrainerV1alpha1TrainJobSpec(
265272
runtimeRef=models.TrainerV1alpha1RuntimeRef(name=runtime_name),
266273
trainer=train_job_trainer,
@@ -788,7 +795,73 @@ def test_get_runtime_packages(kubernetes_backend, test_case):
788795
},
789796
expected_error=ValueError,
790797
),
791-
798+
TestCase(
799+
name="valid flow with labels and annotations",
800+
expected_status=SUCCESS,
801+
config={
802+
"labels": {"kueue.x-k8s.io/queue-name": "ml-queue", "team": "ml-engineering"},
803+
"annotations": {"experiment.id": "exp-001", "description": "Test training job"},
804+
},
805+
expected_output=get_train_job(
806+
runtime_name=TORCH_RUNTIME,
807+
train_job_name=BASIC_TRAIN_JOB_NAME,
808+
labels={"kueue.x-k8s.io/queue-name": "ml-queue", "team": "ml-engineering"},
809+
annotations={"experiment.id": "exp-001", "description": "Test training job"},
810+
),
811+
),
812+
TestCase(
813+
name="valid flow with only labels",
814+
expected_status=SUCCESS,
815+
config={
816+
"labels": {"priority": "high"},
817+
},
818+
expected_output=get_train_job(
819+
runtime_name=TORCH_RUNTIME,
820+
train_job_name=BASIC_TRAIN_JOB_NAME,
821+
labels={"priority": "high"},
822+
),
823+
),
824+
TestCase(
825+
name="valid flow with only annotations",
826+
expected_status=SUCCESS,
827+
config={
828+
"annotations": {"created-by": "training-pipeline"},
829+
},
830+
expected_output=get_train_job(
831+
runtime_name=TORCH_RUNTIME,
832+
train_job_name=BASIC_TRAIN_JOB_NAME,
833+
annotations={"created-by": "training-pipeline"},
834+
),
835+
),
836+
# Test cases using the new Options pattern
837+
TestCase(
838+
name="valid flow with WithLabels option",
839+
expected_status=SUCCESS,
840+
config={
841+
"options": [WithLabels({"team": "ml-platform", "project": "training"})],
842+
},
843+
expected_output=get_train_job(
844+
runtime_name=TORCH_RUNTIME,
845+
train_job_name=BASIC_TRAIN_JOB_NAME,
846+
labels={"team": "ml-platform", "project": "training"},
847+
),
848+
),
849+
TestCase(
850+
name="valid flow with multiple options",
851+
expected_status=SUCCESS,
852+
config={
853+
"options": [
854+
WithLabels({"team": "ml-platform"}),
855+
WithAnnotations({"created-by": "sdk"}),
856+
],
857+
},
858+
expected_output=get_train_job(
859+
runtime_name=TORCH_RUNTIME,
860+
train_job_name=BASIC_TRAIN_JOB_NAME,
861+
labels={"team": "ml-platform"},
862+
annotations={"created-by": "sdk"},
863+
),
864+
),
792865
],
793866
)
794867
def test_train(kubernetes_backend, test_case):
@@ -798,8 +871,38 @@ def test_train(kubernetes_backend, test_case):
798871
kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE)
799872
runtime = kubernetes_backend.get_runtime(test_case.config.get("runtime", TORCH_RUNTIME))
800873

874+
job_spec = {}
875+
876+
options = test_case.config.get("options", None)
877+
if options:
878+
for option in options:
879+
option.apply(job_spec)
880+
881+
metadata_section = job_spec.get("metadata", {})
882+
883+
labels = metadata_section.get("labels") or None
884+
annotations = metadata_section.get("annotations") or None
885+
886+
# Merge individual parameters with options
887+
individual_labels = test_case.config.get("labels", None)
888+
individual_annotations = test_case.config.get("annotations", None)
889+
890+
if individual_labels:
891+
if labels:
892+
labels.update(individual_labels)
893+
else:
894+
labels = individual_labels
895+
if individual_annotations:
896+
if annotations:
897+
annotations.update(individual_annotations)
898+
else:
899+
annotations = individual_annotations
900+
801901
train_job_name = kubernetes_backend.train(
802-
runtime=runtime, trainer=test_case.config.get("trainer", None)
902+
runtime=runtime,
903+
trainer=test_case.config.get("trainer", None),
904+
labels=labels,
905+
annotations=annotations,
803906
)
804907

805908
assert test_case.expected_status == SUCCESS
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright 2025 The Kubeflow Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from kubeflow.trainer.options.options import (
16+
Option,
17+
WithAnnotations,
18+
WithLabels,
19+
)
20+
21+
__all__ = [
22+
"Option",
23+
"WithAnnotations",
24+
"WithLabels",
25+
]
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright 2025 The Kubeflow Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from abc import ABC, abstractmethod
16+
from dataclasses import dataclass
17+
from typing import Dict
18+
19+
20+
class Option(ABC):
21+
"""Base class for TrainJob configuration options.
22+
23+
Options provide a composable way to configure different aspects of a TrainJob.
24+
Each option implements the apply() method to modify the TrainJob specification.
25+
"""
26+
27+
@abstractmethod
28+
def apply(self, job_spec: dict) -> None:
29+
"""Apply this option to the TrainJob specification.
30+
31+
Args:
32+
job_spec: The TrainJob specification dictionary to modify.
33+
"""
34+
pass
35+
36+
37+
@dataclass
38+
class WithLabels(Option):
39+
"""Add labels to the TrainJob resource metadata (.metadata.labels).
40+
41+
These labels are applied to the TrainJob resource itself and are used
42+
for resource organization, filtering, and selection.
43+
44+
Args:
45+
labels: Dictionary of labels to apply to the TrainJob metadata.
46+
"""
47+
48+
labels: Dict[str, str]
49+
50+
def apply(self, job_spec: dict) -> None:
51+
"""Apply labels to TrainJob metadata."""
52+
metadata = job_spec.setdefault("metadata", {})
53+
existing_labels = metadata.setdefault("labels", {})
54+
existing_labels.update(self.labels)
55+
56+
57+
@dataclass
58+
class WithAnnotations(Option):
59+
"""Add annotations to the TrainJob resource metadata (.metadata.annotations).
60+
61+
These annotations are applied to the TrainJob resource itself and are used
62+
for storing additional metadata about the training job resource.
63+
64+
Args:
65+
annotations: Dictionary of annotations to apply to the TrainJob metadata.
66+
"""
67+
68+
annotations: Dict[str, str]
69+
70+
def apply(self, job_spec: dict) -> None:
71+
"""Apply annotations to TrainJob metadata."""
72+
metadata = job_spec.setdefault("metadata", {})
73+
existing_annotations = metadata.setdefault("annotations", {})
74+
existing_annotations.update(self.annotations)
75+
76+

kubeflow/trainer/types/__init__.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright 2024 The Kubeflow Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from kubeflow.trainer.types.types import (
16+
BuiltinTrainer,
17+
CustomTrainer,
18+
DataFormat,
19+
DataType,
20+
HuggingFaceDatasetInitializer,
21+
HuggingFaceModelInitializer,
22+
Initializer,
23+
Loss,
24+
Runtime,
25+
RuntimeTrainer,
26+
Step,
27+
TorchTuneConfig,
28+
TorchTuneInstructDataset,
29+
TrainJob,
30+
TrainerType,
31+
)
32+
33+
__all__ = [
34+
"BuiltinTrainer",
35+
"CustomTrainer",
36+
"DataFormat",
37+
"DataType",
38+
"HuggingFaceDatasetInitializer",
39+
"HuggingFaceModelInitializer",
40+
"Initializer",
41+
"Loss",
42+
"Runtime",
43+
"RuntimeTrainer",
44+
"Step",
45+
"TorchTuneConfig",
46+
"TorchTuneInstructDataset",
47+
"TrainJob",
48+
"TrainerType",
49+
]

0 commit comments

Comments
 (0)