     min_learning_rate=1e-6,
     full_training=FinetuneFullTrainingLimits(
         max_batch_size=96,
+        max_batch_size_dpo=48,
         min_batch_size=8,
     ),
     lora_training=FinetuneLoraTrainingLimits(
         max_batch_size=128,
+        max_batch_size_dpo=64,
         min_batch_size=8,
         max_rank=64,
         target_modules=["q", "k", "v", "o", "mlp"],
@@ -83,6 +85,36 @@ def test_lora_request():
     assert request.batch_size == _MODEL_LIMITS.lora_training.max_batch_size


+def test_dpo_request_lora():
+    request = create_finetune_request(
+        model_limits=_MODEL_LIMITS,
+        model=_MODEL_NAME,
+        training_file=_TRAINING_FILE,
+        training_method="dpo",
+        lora=True,
+    )
+
+    assert request.training_type.type == "Lora"
+    assert request.training_type.lora_r == _MODEL_LIMITS.lora_training.max_rank
+    assert request.training_type.lora_alpha == _MODEL_LIMITS.lora_training.max_rank * 2
+    assert request.training_type.lora_dropout == 0.0
+    assert request.training_type.lora_trainable_modules == "all-linear"
+    assert request.batch_size == _MODEL_LIMITS.lora_training.max_batch_size_dpo
+
+
+def test_dpo_request():
+    request = create_finetune_request(
+        model_limits=_MODEL_LIMITS,
+        model=_MODEL_NAME,
+        training_file=_TRAINING_FILE,
+        training_method="dpo",
+        lora=False,
+    )
+
+    assert request.training_type.type == "Full"
+    assert request.batch_size == _MODEL_LIMITS.full_training.max_batch_size_dpo
+
+
 def test_from_checkpoint_request():
     request = create_finetune_request(
         model_limits=_MODEL_LIMITS,
@@ -160,6 +192,7 @@ def test_non_lora_model():
         min_learning_rate=1e-6,
         full_training=FinetuneFullTrainingLimits(
             max_batch_size=96,
+            max_batch_size_dpo=48,
             min_batch_size=8,
         ),
         lora_training=None,
@@ -181,6 +214,7 @@ def test_non_full_model():
         min_learning_rate=1e-6,
         lora_training=FinetuneLoraTrainingLimits(
             max_batch_size=96,
+            max_batch_size_dpo=48,
             min_batch_size=8,
             max_rank=64,
             target_modules=["q", "k", "v", "o", "mlp"],
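
The new tests pin down one behavioral change: when `training_method="dpo"`, the resolved batch size must be capped by the DPO-specific limit (`max_batch_size_dpo`) instead of the regular `max_batch_size`, for both full and LoRA training. Below is a minimal sketch of that selection rule, using simplified stand-in dataclasses for the limit objects; it is an illustration of what the assertions imply, not the actual `create_finetune_request` implementation.

```python
from dataclasses import dataclass


# Hypothetical, simplified stand-ins for FinetuneFullTrainingLimits /
# FinetuneLoraTrainingLimits; only the fields relevant to batch-size capping.
@dataclass
class TrainingLimits:
    max_batch_size: int
    max_batch_size_dpo: int
    min_batch_size: int


def resolve_max_batch_size(limits: TrainingLimits, training_method: str) -> int:
    # DPO jobs fall back to the (smaller) DPO-specific cap.
    if training_method == "dpo":
        return limits.max_batch_size_dpo
    return limits.max_batch_size


# Mirrors the expectations in test_dpo_request_lora / test_dpo_request above.
lora_limits = TrainingLimits(max_batch_size=128, max_batch_size_dpo=64, min_batch_size=8)
full_limits = TrainingLimits(max_batch_size=96, max_batch_size_dpo=48, min_batch_size=8)

assert resolve_max_batch_size(lora_limits, "dpo") == 64
assert resolve_max_batch_size(full_limits, "dpo") == 48
assert resolve_max_batch_size(lora_limits, "sft") == 128
```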