1414
1515"""Kubernetes-specific training options for the Kubeflow Trainer SDK."""
1616
17+ from __future__ import annotations
18+
1719from dataclasses import dataclass
18- from typing import TYPE_CHECKING , Any , Optional , Union
20+ from typing import TYPE_CHECKING , Any
1921
2022from kubeflow .trainer .backends .base import KubernetesCompatible
2123from kubeflow .trainer .options .common import PodTemplateOverride
@@ -33,7 +35,7 @@ class Labels(KubernetesCompatible):
3335 def __call__ (
3436 self ,
3537 job_spec : dict [str , Any ],
36- trainer : Optional [ Union [ " CustomTrainer" , " BuiltinTrainer" ]] = None ,
38+ trainer : CustomTrainer | BuiltinTrainer | None = None ,
3739 ) -> None :
3840 """Apply labels to the job specification."""
3941 metadata = job_spec .setdefault ("metadata" , {})
@@ -49,7 +51,7 @@ class Annotations(KubernetesCompatible):
4951 def __call__ (
5052 self ,
5153 job_spec : dict [str , Any ],
52- trainer : Optional [ Union [ " CustomTrainer" , " BuiltinTrainer" ]] = None ,
54+ trainer : CustomTrainer | BuiltinTrainer | None = None ,
5355 ) -> None :
5456 """Apply annotations to the job specification."""
5557 metadata = job_spec .setdefault ("metadata" , {})
@@ -69,7 +71,7 @@ class SpecLabels(KubernetesCompatible):
6971 def __call__ (
7072 self ,
7173 job_spec : dict [str , Any ],
72- trainer : Optional [ Union [ " CustomTrainer" , " BuiltinTrainer" ]] = None ,
74+ trainer : CustomTrainer | BuiltinTrainer | None = None ,
7375 ) -> None :
7476 """Apply spec-level labels to the job specification."""
7577 spec = job_spec .setdefault ("spec" , {})
@@ -89,29 +91,13 @@ class SpecAnnotations(KubernetesCompatible):
8991 def __call__ (
9092 self ,
9193 job_spec : dict [str , Any ],
92- trainer : Optional [ Union [ " CustomTrainer" , " BuiltinTrainer" ]] = None ,
94+ trainer : CustomTrainer | BuiltinTrainer | None = None ,
9395 ) -> None :
9496 """Apply spec-level annotations to the job specification."""
9597 spec = job_spec .setdefault ("spec" , {})
9698 spec ["annotations" ] = self .annotations
9799
98100
99- @dataclass
100- class Name (KubernetesCompatible ):
101- """Set a custom name for the TrainJob resource (.metadata.name)."""
102-
103- name : str
104-
105- def __call__ (
106- self ,
107- job_spec : dict [str , Any ],
108- trainer : Optional [Union ["CustomTrainer" , "BuiltinTrainer" ]] = None ,
109- ) -> None :
110- """Apply custom name to the job specification."""
111- metadata = job_spec .setdefault ("metadata" , {})
112- metadata ["name" ] = self .name
113-
114-
115101@dataclass
116102class PodTemplateOverrides (KubernetesCompatible ):
117103 """Add pod template overrides to the TrainJob (.spec.podTemplateOverrides)."""
@@ -121,7 +107,7 @@ class PodTemplateOverrides(KubernetesCompatible):
121107 def __call__ (
122108 self ,
123109 job_spec : dict [str , Any ],
124- trainer : Optional [ Union [ " CustomTrainer" , " BuiltinTrainer" ]] = None ,
110+ trainer : CustomTrainer | BuiltinTrainer | None = None ,
125111 ) -> None :
126112 """Apply pod template overrides to the job specification."""
127113 spec = job_spec .setdefault ("spec" , {})
@@ -187,7 +173,7 @@ class TrainerImage(KubernetesCompatible):
187173 def __call__ (
188174 self ,
189175 job_spec : dict [str , Any ],
190- trainer : Optional [ Union [ " CustomTrainer" , " BuiltinTrainer" ]] = None ,
176+ trainer : CustomTrainer | BuiltinTrainer | None = None ,
191177 ) -> None :
192178 """Apply trainer image override to the job specification.
193179
@@ -209,7 +195,7 @@ class TrainerCommand(KubernetesCompatible):
209195 def __call__ (
210196 self ,
211197 job_spec : dict [str , Any ],
212- trainer : Optional [ Union [ " CustomTrainer" , " BuiltinTrainer" ]] = None ,
198+ trainer : CustomTrainer | BuiltinTrainer | None = None ,
213199 ) -> None :
214200 """Apply trainer command override to the job specification.
215201
@@ -243,7 +229,7 @@ class TrainerArgs(KubernetesCompatible):
243229 def __call__ (
244230 self ,
245231 job_spec : dict [str , Any ],
246- trainer : Optional [ Union [ " CustomTrainer" , " BuiltinTrainer" ]] = None ,
232+ trainer : CustomTrainer | BuiltinTrainer | None = None ,
247233 ) -> None :
248234 """Apply trainer args override to the job specification.
249235
0 commit comments