Skip to content

Commit b1ce0f5

Browse files
fix: Minor improvemnts in imports and trainer types
Signed-off-by: Abhijeet Dhumal <[email protected]>
1 parent 5a7c563 commit b1ce0f5

File tree

8 files changed

+226
-201
lines changed

8 files changed

+226
-201
lines changed

kubeflow/trainer/api/trainer_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def train(
112112
trainer: Optional configuration for a CustomTrainer or BuiltinTrainer. If not specified,
113113
the TrainJob will use the runtime's default values.
114114
options: Optional list of configuration options to apply to the TrainJob. Use
115-
WithLabels and WithAnnotations for basic metadata configuration.
115+
Labels and Annotations for basic metadata configuration.
116116
117117
Returns:
118118
The unique name of the TrainJob that has been generated.

kubeflow/trainer/options/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@
2020

2121
from kubeflow.trainer.options.common import (
2222
ContainerOverride,
23+
Name,
2324
PodTemplateOverride,
2425
PodTemplateSpecOverride,
2526
)
2627
from kubeflow.trainer.options.kubernetes import (
2728
Annotations,
2829
Labels,
29-
Name,
3030
PodTemplateOverrides,
3131
SpecAnnotations,
3232
SpecLabels,
@@ -39,12 +39,12 @@
3939
__all__ = [
4040
# Common options
4141
"ContainerOverride",
42+
"Name",
4243
"PodTemplateOverride",
4344
"PodTemplateSpecOverride",
4445
# Kubernetes options
4546
"Annotations",
4647
"Labels",
47-
"Name",
4848
"PodTemplateOverrides",
4949
"SpecAnnotations",
5050
"SpecLabels",

kubeflow/trainer/options/common.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,40 @@
1414

1515
"""Common options and helper classes used across multiple backends."""
1616

17+
from __future__ import annotations
18+
1719
from dataclasses import dataclass
18-
from typing import Optional
20+
from typing import TYPE_CHECKING, Any
21+
22+
from kubeflow.trainer.backends.base import UniversalCompatible
23+
24+
if TYPE_CHECKING:
25+
from kubeflow.trainer.types.types import BuiltinTrainer, CustomTrainer
1926

2027
__all__ = [
2128
"ContainerOverride",
29+
"Name",
2230
"PodTemplateOverride",
2331
"PodTemplateSpecOverride",
2432
]
2533

2634

35+
@dataclass
36+
class Name(UniversalCompatible):
37+
"""Set a custom name for the TrainJob resource."""
38+
39+
name: str
40+
41+
def __call__(
42+
self,
43+
job_spec: dict[str, Any],
44+
trainer: BuiltinTrainer | CustomTrainer | None = None,
45+
) -> None:
46+
"""Apply custom name to the job specification."""
47+
metadata = job_spec.setdefault("metadata", {})
48+
metadata["name"] = self.name
49+
50+
2751
@dataclass
2852
class ContainerOverride:
2953
"""Configuration for overriding a specific container in a pod.
@@ -37,8 +61,8 @@ class ContainerOverride:
3761
"""
3862

3963
name: str
40-
env: Optional[list[dict]] = None
41-
volume_mounts: Optional[list[dict]] = None
64+
env: list[dict] | None = None
65+
volume_mounts: list[dict] | None = None
4266

4367
def __post_init__(self):
4468
"""Validate the container override configuration."""
@@ -108,15 +132,15 @@ class PodTemplateSpecOverride:
108132
image_pull_secrets: Image pull secrets for the pods.
109133
"""
110134

111-
service_account_name: Optional[str] = None
112-
node_selector: Optional[dict[str, str]] = None
113-
affinity: Optional[dict] = None
114-
tolerations: Optional[list[dict]] = None
115-
volumes: Optional[list[dict]] = None
116-
init_containers: Optional[list[ContainerOverride]] = None
117-
containers: Optional[list[ContainerOverride]] = None
118-
scheduling_gates: Optional[list[dict]] = None
119-
image_pull_secrets: Optional[list[dict]] = None
135+
service_account_name: str | None = None
136+
node_selector: dict[str, str] | None = None
137+
affinity: dict | None = None
138+
tolerations: list[dict] | None = None
139+
volumes: list[dict] | None = None
140+
init_containers: list[ContainerOverride] | None = None
141+
containers: list[ContainerOverride] | None = None
142+
scheduling_gates: list[dict] | None = None
143+
image_pull_secrets: list[dict] | None = None
120144

121145

122146
@dataclass
@@ -130,5 +154,5 @@ class PodTemplateOverride:
130154
"""
131155

132156
target_jobs: list[str]
133-
metadata: Optional[dict] = None
134-
spec: Optional[PodTemplateSpecOverride] = None
157+
metadata: dict | None = None
158+
spec: PodTemplateSpecOverride | None = None

kubeflow/trainer/options/kubernetes.py

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414

1515
"""Kubernetes-specific training options for the Kubeflow Trainer SDK."""
1616

17+
from __future__ import annotations
18+
1719
from dataclasses import dataclass
18-
from typing import TYPE_CHECKING, Any, Optional, Union
20+
from typing import TYPE_CHECKING, Any
1921

2022
from kubeflow.trainer.backends.base import KubernetesCompatible
2123
from kubeflow.trainer.options.common import PodTemplateOverride
@@ -33,7 +35,7 @@ class Labels(KubernetesCompatible):
3335
def __call__(
3436
self,
3537
job_spec: dict[str, Any],
36-
trainer: Optional[Union["CustomTrainer", "BuiltinTrainer"]] = None,
38+
trainer: CustomTrainer | BuiltinTrainer | None = None,
3739
) -> None:
3840
"""Apply labels to the job specification."""
3941
metadata = job_spec.setdefault("metadata", {})
@@ -49,7 +51,7 @@ class Annotations(KubernetesCompatible):
4951
def __call__(
5052
self,
5153
job_spec: dict[str, Any],
52-
trainer: Optional[Union["CustomTrainer", "BuiltinTrainer"]] = None,
54+
trainer: CustomTrainer | BuiltinTrainer | None = None,
5355
) -> None:
5456
"""Apply annotations to the job specification."""
5557
metadata = job_spec.setdefault("metadata", {})
@@ -69,7 +71,7 @@ class SpecLabels(KubernetesCompatible):
6971
def __call__(
7072
self,
7173
job_spec: dict[str, Any],
72-
trainer: Optional[Union["CustomTrainer", "BuiltinTrainer"]] = None,
74+
trainer: CustomTrainer | BuiltinTrainer | None = None,
7375
) -> None:
7476
"""Apply spec-level labels to the job specification."""
7577
spec = job_spec.setdefault("spec", {})
@@ -89,29 +91,13 @@ class SpecAnnotations(KubernetesCompatible):
8991
def __call__(
9092
self,
9193
job_spec: dict[str, Any],
92-
trainer: Optional[Union["CustomTrainer", "BuiltinTrainer"]] = None,
94+
trainer: CustomTrainer | BuiltinTrainer | None = None,
9395
) -> None:
9496
"""Apply spec-level annotations to the job specification."""
9597
spec = job_spec.setdefault("spec", {})
9698
spec["annotations"] = self.annotations
9799

98100

99-
@dataclass
100-
class Name(KubernetesCompatible):
101-
"""Set a custom name for the TrainJob resource (.metadata.name)."""
102-
103-
name: str
104-
105-
def __call__(
106-
self,
107-
job_spec: dict[str, Any],
108-
trainer: Optional[Union["CustomTrainer", "BuiltinTrainer"]] = None,
109-
) -> None:
110-
"""Apply custom name to the job specification."""
111-
metadata = job_spec.setdefault("metadata", {})
112-
metadata["name"] = self.name
113-
114-
115101
@dataclass
116102
class PodTemplateOverrides(KubernetesCompatible):
117103
"""Add pod template overrides to the TrainJob (.spec.podTemplateOverrides)."""
@@ -121,7 +107,7 @@ class PodTemplateOverrides(KubernetesCompatible):
121107
def __call__(
122108
self,
123109
job_spec: dict[str, Any],
124-
trainer: Optional[Union["CustomTrainer", "BuiltinTrainer"]] = None,
110+
trainer: CustomTrainer | BuiltinTrainer | None = None,
125111
) -> None:
126112
"""Apply pod template overrides to the job specification."""
127113
spec = job_spec.setdefault("spec", {})
@@ -187,7 +173,7 @@ class TrainerImage(KubernetesCompatible):
187173
def __call__(
188174
self,
189175
job_spec: dict[str, Any],
190-
trainer: Optional[Union["CustomTrainer", "BuiltinTrainer"]] = None,
176+
trainer: CustomTrainer | BuiltinTrainer | None = None,
191177
) -> None:
192178
"""Apply trainer image override to the job specification.
193179
@@ -209,7 +195,7 @@ class TrainerCommand(KubernetesCompatible):
209195
def __call__(
210196
self,
211197
job_spec: dict[str, Any],
212-
trainer: Optional[Union["CustomTrainer", "BuiltinTrainer"]] = None,
198+
trainer: CustomTrainer | BuiltinTrainer | None = None,
213199
) -> None:
214200
"""Apply trainer command override to the job specification.
215201
@@ -243,7 +229,7 @@ class TrainerArgs(KubernetesCompatible):
243229
def __call__(
244230
self,
245231
job_spec: dict[str, Any],
246-
trainer: Optional[Union["CustomTrainer", "BuiltinTrainer"]] = None,
232+
trainer: CustomTrainer | BuiltinTrainer | None = None,
247233
) -> None:
248234
"""Apply trainer args override to the job specification.
249235

kubeflow/trainer/options/localprocess.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414

1515
"""LocalProcess-specific training options for the Kubeflow Trainer SDK."""
1616

17+
from __future__ import annotations
18+
1719
from dataclasses import dataclass
18-
from typing import TYPE_CHECKING, Optional, Union
20+
from typing import TYPE_CHECKING
1921

2022
from kubeflow.trainer.backends.base import LocalProcessCompatible
2123

@@ -32,7 +34,7 @@ class ProcessTimeout(LocalProcessCompatible):
3234
def __call__(
3335
self,
3436
job_spec: dict,
35-
trainer: Optional[Union["BuiltinTrainer", "CustomTrainer"]] = None,
37+
trainer: BuiltinTrainer | CustomTrainer | None = None,
3638
) -> None:
3739
"""Apply timeout to local process configuration.
3840
@@ -52,7 +54,7 @@ class WorkingDirectory(LocalProcessCompatible):
5254
def __call__(
5355
self,
5456
job_spec: dict,
55-
trainer: Optional[Union["BuiltinTrainer", "CustomTrainer"]] = None,
57+
trainer: BuiltinTrainer | CustomTrainer | None = None,
5658
) -> None:
5759
"""Apply working directory to local process configuration.
5860

0 commit comments

Comments
 (0)