Skip to content

Commit 03994e7

Browse files
Add simplified Training Options for TrainJob labels, annotations and podSpecOverride
Signed-off-by: Abhijeet Dhumal <[email protected]>
1 parent 4faad04 commit 03994e7

File tree

15 files changed

+2244
-265
lines changed

15 files changed

+2244
-265
lines changed

kubeflow/trainer/api/trainer_client.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
LocalProcessBackend,
2323
LocalProcessBackendConfig,
2424
)
25+
from kubeflow.trainer.backends.options import Option
2526
from kubeflow.trainer.constants import constants
2627
from kubeflow.trainer.types import types
2728

@@ -96,6 +97,7 @@ def train(
9697
runtime: Optional[types.Runtime] = None,
9798
initializer: Optional[types.Initializer] = None,
9899
trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None,
100+
options: Optional[list[Option]] = None,
99101
) -> str:
100102
"""Create a TrainJob. You can configure the TrainJob using one of these trainers:
101103
@@ -110,6 +112,8 @@ def train(
110112
initializer: Optional configuration for the dataset and model initializers.
111113
trainer: Optional configuration for a CustomTrainer or BuiltinTrainer. If not specified,
112114
the TrainJob will use the runtime's default values.
115+
options: Optional list of configuration options to apply to the TrainJob. Use
116+
WithLabels and WithAnnotations for basic metadata configuration.
113117
114118
Returns:
115119
The unique name of the TrainJob that has been generated.
@@ -119,7 +123,16 @@ def train(
119123
TimeoutError: Timeout to create TrainJobs.
120124
RuntimeError: Failed to create TrainJobs.
121125
"""
122-
return self.backend.train(runtime=runtime, initializer=initializer, trainer=trainer)
126+
# Validate options compatibility with backend
127+
if options:
128+
self.backend.validate_options(options)
129+
130+
return self.backend.train(
131+
runtime=runtime,
132+
initializer=initializer,
133+
trainer=trainer,
134+
options=options,
135+
)
123136

124137
def list_jobs(self, runtime: Optional[types.Runtime] = None) -> list[types.TrainJob]:
125138
"""List of the created TrainJobs. If a runtime is specified, only TrainJobs associated with
Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
# Copyright 2025 The Kubeflow Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Unit tests for TrainerClient option handling and error messages.
17+
"""
18+
19+
from unittest.mock import Mock, patch
20+
21+
import pytest
22+
23+
from kubeflow.trainer.api.trainer_client import TrainerClient
24+
from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig
25+
from kubeflow.trainer.backends.localprocess.types import LocalProcessBackendConfig
26+
from kubeflow.trainer.options import WithAnnotations, WithLabels
27+
from kubeflow.trainer.types import types
28+
29+
30+
class TestTrainerClientOptionValidation:
    """Exercise option validation as seen through TrainerClient.train()."""

    def test_trainer_client_passes_options_to_backend(self):
        """Options given to train() show up in the LocalProcess incompatibility error."""
        trainer_client = TrainerClient(backend_config=LocalProcessBackendConfig())

        def simple_func():
            return "test"

        custom_trainer = types.CustomTrainer(func=simple_func)
        label_options = [WithLabels({"app": "test"})]

        with pytest.raises(ValueError) as caught:
            trainer_client.train(trainer=custom_trainer, options=label_options)

        message = str(caught.value)
        for fragment in (
            "The following options are not compatible with this backend",
            "WithLabels (labels)",
        ):
            assert fragment in message

    @patch("kubernetes.config.load_kube_config")
    @patch("kubernetes.client.CustomObjectsApi")
    @patch("kubernetes.client.CoreV1Api")
    def test_trainer_client_with_kubernetes_backend(
        self, mock_core_api, mock_custom_api, mock_load_config
    ):
        """Labels and annotations on the Kubernetes backend fail only with a runtime/API error."""
        # The Kubernetes client classes are patched; have them hand back plain Mock objects.
        mock_core_api.return_value = Mock()
        mock_custom_api.return_value = Mock()

        trainer_client = TrainerClient(backend_config=KubernetesBackendConfig())

        def simple_func():
            return "test"

        custom_trainer = types.CustomTrainer(func=simple_func)
        metadata_options = [WithLabels({"app": "test"}), WithAnnotations({"desc": "test"})]

        with pytest.raises((ValueError, RuntimeError)) as caught:
            trainer_client.train(trainer=custom_trainer, options=metadata_options)

        message = str(caught.value)
        # Either the missing runtime is reported or the (mocked) cluster lookup fails.
        acceptable_fragments = (
            "Runtime is required",
            "Failed to get clustertrainingruntimes",
        )
        assert any(fragment in message for fragment in acceptable_fragments)

    def test_trainer_client_empty_options(self):
        """An empty options list leads to the missing-runtime error, not an option error."""
        trainer_client = TrainerClient(backend_config=LocalProcessBackendConfig())

        def simple_func():
            return "test"

        custom_trainer = types.CustomTrainer(func=simple_func)

        with pytest.raises(ValueError) as caught:
            trainer_client.train(trainer=custom_trainer, options=[])

        message = str(caught.value)
        assert "Runtime must be provided for LocalProcessBackend" in message
95+
96+
97+
class TestTrainerClientErrorHandling:
    """Test the error messages that TrainerClient.train() surfaces to callers."""

    def test_missing_runtime_error_message(self):
        """train() without a runtime on the LocalProcess backend reports the missing runtime."""
        config = LocalProcessBackendConfig()
        client = TrainerClient(backend_config=config)

        def simple_func():
            return "test"

        trainer = types.CustomTrainer(func=simple_func)

        with pytest.raises(ValueError) as exc_info:
            client.train(trainer=trainer)

        error_msg = str(exc_info.value)
        # The error message should contain the runtime requirement
        assert "Runtime must be provided for LocalProcessBackend" in error_msg

    def test_option_validation_error_propagation(self):
        """Every incompatible option is named in the ValueError raised by train()."""
        config = LocalProcessBackendConfig()
        client = TrainerClient(backend_config=config)

        def simple_func():
            return "test"

        trainer = types.CustomTrainer(func=simple_func)
        options = [WithLabels({"app": "test"}), WithAnnotations({"desc": "test"})]

        with pytest.raises(ValueError) as exc_info:
            client.train(trainer=trainer, options=options)

        error_msg = str(exc_info.value)
        # One check per expected fragment; the header line is asserted only once.
        assert "The following options are not compatible with this backend" in error_msg
        assert "WithLabels (labels)" in error_msg
        assert "WithAnnotations (annotations)" in error_msg

    def test_error_message_does_not_contain_runtime_help_for_option_errors(self):
        """Option validation errors must not carry the "Example usage:" help text."""
        config = LocalProcessBackendConfig()
        client = TrainerClient(backend_config=config)

        def simple_func():
            return "test"

        trainer = types.CustomTrainer(func=simple_func)
        options = [WithLabels({"app": "test"})]

        with pytest.raises(ValueError) as exc_info:
            client.train(trainer=trainer, options=options)

        error_msg = str(exc_info.value)
        assert "The following options are not compatible with this backend" in error_msg
        assert "Example usage:" not in error_msg

    @patch("kubernetes.config.load_kube_config")
    @patch("kubernetes.client.CustomObjectsApi")
    @patch("kubernetes.client.CoreV1Api")
    def test_kubernetes_backend_error_handling(
        self, mock_core_api, mock_custom_api, mock_load_config
    ):
        """train() without options on the Kubernetes backend fails with a runtime/API error."""
        # The Kubernetes client classes are patched; have them return plain Mock objects.
        mock_custom_api.return_value = Mock()
        mock_core_api.return_value = Mock()

        config = KubernetesBackendConfig()
        client = TrainerClient(backend_config=config)

        def simple_func():
            return "test"

        trainer = types.CustomTrainer(func=simple_func)

        with pytest.raises((ValueError, RuntimeError)) as exc_info:
            client.train(trainer=trainer)

        error_msg = str(exc_info.value)
        # Should either fail with runtime requirement or K8s connection error
        assert (
            "Runtime is required" in error_msg
            or "Failed to get clustertrainingruntimes" in error_msg
        )
182+
183+
184+
class TestTrainerClientBackendSelection:
    """Verify which backend class TrainerClient instantiates for each configuration."""

    @patch("kubernetes.config.load_kube_config")
    @patch("kubernetes.client.CustomObjectsApi")
    @patch("kubernetes.client.CoreV1Api")
    def test_default_backend_is_kubernetes(self, mock_core_api, mock_custom_api, mock_load_config):
        """A TrainerClient created without a backend config uses KubernetesBackend."""
        # The Kubernetes client classes are patched; have them return plain Mock objects.
        mock_core_api.return_value = Mock()
        mock_custom_api.return_value = Mock()

        default_client = TrainerClient()

        from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend

        selected_backend = default_client.backend
        assert isinstance(selected_backend, KubernetesBackend)

    def test_local_process_backend_selection(self):
        """A LocalProcessBackendConfig yields a LocalProcessBackend."""
        local_client = TrainerClient(backend_config=LocalProcessBackendConfig())

        from kubeflow.trainer.backends.localprocess.backend import LocalProcessBackend

        selected_backend = local_client.backend
        assert isinstance(selected_backend, LocalProcessBackend)

    @patch("kubernetes.config.load_kube_config")
    @patch("kubernetes.client.CustomObjectsApi")
    @patch("kubernetes.client.CoreV1Api")
    def test_kubernetes_backend_selection(self, mock_core_api, mock_custom_api, mock_load_config):
        """An explicit KubernetesBackendConfig yields a KubernetesBackend."""
        # The Kubernetes client classes are patched; have them return plain Mock objects.
        mock_core_api.return_value = Mock()
        mock_custom_api.return_value = Mock()

        k8s_client = TrainerClient(backend_config=KubernetesBackendConfig())

        from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend

        selected_backend = k8s_client.backend
        assert isinstance(selected_backend, KubernetesBackend)
224+
225+
226+
class TestTrainerClientOptionFlow:
    """Follow options end to end through TrainerClient.train() on the LocalProcess backend."""

    def test_option_validation_happens_early(self):
        """With no runtime given, the option error is still the one that is raised."""
        trainer_client = TrainerClient(backend_config=LocalProcessBackendConfig())

        def simple_func():
            return "test"

        custom_trainer = types.CustomTrainer(func=simple_func)
        label_options = [WithLabels({"app": "test"})]

        with pytest.raises(ValueError) as caught:
            trainer_client.train(trainer=custom_trainer, options=label_options)

        message = str(caught.value)
        assert "The following options are not compatible with this backend" in message

    def test_multiple_option_types_validation(self):
        """Each rejected option type is listed in the incompatibility error."""
        trainer_client = TrainerClient(backend_config=LocalProcessBackendConfig())

        def simple_func():
            return "test"

        custom_trainer = types.CustomTrainer(func=simple_func)
        labels = WithLabels({"app": "test"})
        annotations = WithAnnotations({"desc": "test"})
        mixed_options = [labels, annotations]

        with pytest.raises(ValueError) as caught:
            trainer_client.train(trainer=custom_trainer, options=mixed_options)

        message = str(caught.value)
        for fragment in (
            "The following options are not compatible with this backend",
            "WithLabels (labels)",
            "WithAnnotations (annotations)",
        ):
            assert fragment in message

    def test_none_options_handling(self):
        """Passing options=None leads to the missing-runtime error, not an option error."""
        trainer_client = TrainerClient(backend_config=LocalProcessBackendConfig())

        def simple_func():
            return "test"

        custom_trainer = types.CustomTrainer(func=simple_func)

        with pytest.raises(ValueError) as caught:
            trainer_client.train(trainer=custom_trainer, options=None)

        message = str(caught.value)
        assert "Runtime must be provided for LocalProcessBackend" in message

0 commit comments

Comments
 (0)