2525import string
2626import uuid
2727from dataclasses import asdict
28- from typing import Optional
28+ from typing import Dict , Optional
2929from unittest .mock import Mock , patch
3030
3131import pytest
3232from kubeflow_trainer_api import models
3333
3434from kubeflow .trainer .constants import constants
3535from kubeflow .trainer .types import types
36+ from kubeflow .trainer .options import WithLabels , WithAnnotations
3637from kubeflow .trainer .utils import utils
3738from kubeflow .trainer .backends .kubernetes .backend import KubernetesBackend
3839from kubeflow .trainer .backends .kubernetes .types import KubernetesBackendConfig
@@ -253,14 +254,20 @@ def get_train_job(
253254 runtime_name : str ,
254255 train_job_name : str = BASIC_TRAIN_JOB_NAME ,
255256 train_job_trainer : Optional [models .TrainerV1alpha1Trainer ] = None ,
257+ labels : Optional [Dict [str , str ]] = None ,
258+ annotations : Optional [Dict [str , str ]] = None ,
256259) -> models .TrainerV1alpha1TrainJob :
257260 """
258261 Create a mock TrainJob object with optional trainer configurations.
259262 """
260263 train_job = models .TrainerV1alpha1TrainJob (
261264 apiVersion = constants .API_VERSION ,
262265 kind = constants .TRAINJOB_KIND ,
263- metadata = models .IoK8sApimachineryPkgApisMetaV1ObjectMeta (name = train_job_name ),
266+ metadata = models .IoK8sApimachineryPkgApisMetaV1ObjectMeta (
267+ name = train_job_name ,
268+ labels = labels ,
269+ annotations = annotations
270+ ),
264271 spec = models .TrainerV1alpha1TrainJobSpec (
265272 runtimeRef = models .TrainerV1alpha1RuntimeRef (name = runtime_name ),
266273 trainer = train_job_trainer ,
@@ -788,7 +795,73 @@ def test_get_runtime_packages(kubernetes_backend, test_case):
788795 },
789796 expected_error = ValueError ,
790797 ),
791-
798+ TestCase (
799+ name = "valid flow with labels and annotations" ,
800+ expected_status = SUCCESS ,
801+ config = {
802+ "labels" : {"kueue.x-k8s.io/queue-name" : "ml-queue" , "team" : "ml-engineering" },
803+ "annotations" : {"experiment.id" : "exp-001" , "description" : "Test training job" },
804+ },
805+ expected_output = get_train_job (
806+ runtime_name = TORCH_RUNTIME ,
807+ train_job_name = BASIC_TRAIN_JOB_NAME ,
808+ labels = {"kueue.x-k8s.io/queue-name" : "ml-queue" , "team" : "ml-engineering" },
809+ annotations = {"experiment.id" : "exp-001" , "description" : "Test training job" },
810+ ),
811+ ),
812+ TestCase (
813+ name = "valid flow with only labels" ,
814+ expected_status = SUCCESS ,
815+ config = {
816+ "labels" : {"priority" : "high" },
817+ },
818+ expected_output = get_train_job (
819+ runtime_name = TORCH_RUNTIME ,
820+ train_job_name = BASIC_TRAIN_JOB_NAME ,
821+ labels = {"priority" : "high" },
822+ ),
823+ ),
824+ TestCase (
825+ name = "valid flow with only annotations" ,
826+ expected_status = SUCCESS ,
827+ config = {
828+ "annotations" : {"created-by" : "training-pipeline" },
829+ },
830+ expected_output = get_train_job (
831+ runtime_name = TORCH_RUNTIME ,
832+ train_job_name = BASIC_TRAIN_JOB_NAME ,
833+ annotations = {"created-by" : "training-pipeline" },
834+ ),
835+ ),
836+ # Test cases using the new Options pattern
837+ TestCase (
838+ name = "valid flow with WithLabels option" ,
839+ expected_status = SUCCESS ,
840+ config = {
841+ "options" : [WithLabels ({"team" : "ml-platform" , "project" : "training" })],
842+ },
843+ expected_output = get_train_job (
844+ runtime_name = TORCH_RUNTIME ,
845+ train_job_name = BASIC_TRAIN_JOB_NAME ,
846+ labels = {"team" : "ml-platform" , "project" : "training" },
847+ ),
848+ ),
849+ TestCase (
850+ name = "valid flow with multiple options" ,
851+ expected_status = SUCCESS ,
852+ config = {
853+ "options" : [
854+ WithLabels ({"team" : "ml-platform" }),
855+ WithAnnotations ({"created-by" : "sdk" }),
856+ ],
857+ },
858+ expected_output = get_train_job (
859+ runtime_name = TORCH_RUNTIME ,
860+ train_job_name = BASIC_TRAIN_JOB_NAME ,
861+ labels = {"team" : "ml-platform" },
862+ annotations = {"created-by" : "sdk" },
863+ ),
864+ ),
792865 ],
793866)
794867def test_train (kubernetes_backend , test_case ):
@@ -798,8 +871,38 @@ def test_train(kubernetes_backend, test_case):
798871 kubernetes_backend .namespace = test_case .config .get ("namespace" , DEFAULT_NAMESPACE )
799872 runtime = kubernetes_backend .get_runtime (test_case .config .get ("runtime" , TORCH_RUNTIME ))
800873
874+ job_spec = {}
875+
876+ options = test_case .config .get ("options" , None )
877+ if options :
878+ for option in options :
879+ option .apply (job_spec )
880+
881+ metadata_section = job_spec .get ("metadata" , {})
882+
883+ labels = metadata_section .get ("labels" ) or None
884+ annotations = metadata_section .get ("annotations" ) or None
885+
886+ # Merge individual parameters with options
887+ individual_labels = test_case .config .get ("labels" , None )
888+ individual_annotations = test_case .config .get ("annotations" , None )
889+
890+ if individual_labels :
891+ if labels :
892+ labels .update (individual_labels )
893+ else :
894+ labels = individual_labels
895+ if individual_annotations :
896+ if annotations :
897+ annotations .update (individual_annotations )
898+ else :
899+ annotations = individual_annotations
900+
801901 train_job_name = kubernetes_backend .train (
802- runtime = runtime , trainer = test_case .config .get ("trainer" , None )
902+ runtime = runtime ,
903+ trainer = test_case .config .get ("trainer" , None ),
904+ labels = labels ,
905+ annotations = annotations ,
803906 )
804907
805908 assert test_case .expected_status == SUCCESS
0 commit comments