@@ -77,7 +77,7 @@ def create_finetune_request(
     wandb_base_url: str | None = None,
     wandb_project_name: str | None = None,
     wandb_name: str | None = None,
-    train_on_inputs: bool | Literal["auto"] = "auto",
+    train_on_inputs: bool | Literal["auto"] | None = None,
     training_method: str = "sft",
     dpo_beta: float | None = None,
     from_checkpoint: str | None = None,
@@ -162,6 +162,15 @@ def create_finetune_request(
             f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
         )
 
+    if train_on_inputs is not None and training_method != "sft":
+        raise ValueError("train_on_inputs is only supported for SFT training")
+
+    if train_on_inputs is None and training_method == "sft":
+        log_warn_once(
+            "train_on_inputs is not set for SFT training, it will be set to 'auto' automatically"
+        )
+        train_on_inputs = "auto"
+
     lr_scheduler: FinetuneLRScheduler
     if lr_scheduler_type == "cosine":
         if scheduler_num_cycles <= 0.0:
@@ -179,7 +188,9 @@ def create_finetune_request(
             lr_scheduler_args=LinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
         )
 
-    training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
+    training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT(
+        train_on_inputs=train_on_inputs
+    )
     if training_method == "dpo":
         training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
 
@@ -202,7 +213,6 @@ def create_finetune_request(
         wandb_base_url=wandb_base_url,
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
-        train_on_inputs=train_on_inputs,
         training_method=training_method_cls,
         from_checkpoint=from_checkpoint,
     )
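
Taken together, these hunks move `train_on_inputs` off the request body and onto `TrainingMethodSFT`, with two guard rails in front: a non-SFT method rejects an explicit value, and an unset value falls back to `"auto"` for SFT. A minimal, self-contained sketch of the resolution rules; the helper name `resolve_train_on_inputs` is hypothetical, used only to restate the diff's logic outside the SDK:

```python
from typing import Literal


def resolve_train_on_inputs(
    train_on_inputs: bool | Literal["auto"] | None,
    training_method: str,
) -> bool | Literal["auto"] | None:
    # Explicit value with a non-SFT method is rejected outright.
    if train_on_inputs is not None and training_method != "sft":
        raise ValueError("train_on_inputs is only supported for SFT training")
    # Unset value for SFT falls back to "auto" (the real code also
    # emits log_warn_once(...) before falling back).
    if train_on_inputs is None and training_method == "sft":
        train_on_inputs = "auto"
    return train_on_inputs


assert resolve_train_on_inputs(None, "sft") == "auto"  # warn + fall back
assert resolve_train_on_inputs(False, "sft") is False  # explicit value kept
assert resolve_train_on_inputs(None, "dpo") is None    # DPO: stays unset
try:
    resolve_train_on_inputs(True, "dpo")
except ValueError as e:
    print(e)  # train_on_inputs is only supported for SFT training
```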
@@ -307,7 +317,7 @@ def create(
         wandb_name: str | None = None,
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
-        train_on_inputs: bool | Literal["auto"] = "auto",
+        train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
         from_checkpoint: str | None = None,
@@ -352,12 +362,12 @@ def create(
                 Defaults to False.
             model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
                 Defaults to None.
-            train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
+            train_on_inputs (bool or "auto", optional): Whether to mask the user messages in conversational data or prompts in instruction data.
                 "auto" will automatically determine whether to mask the inputs based on the data format.
                 For datasets with the "text" field (general format), inputs will not be masked.
                 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                 (Instruction format), inputs will be masked.
-                Defaults to "auto".
+                Defaults to None, or "auto" if training_method is "sft" (set in create_finetune_request).
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
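
From the caller's side, the sync API keeps working as before when `train_on_inputs` is omitted. A hedged usage sketch, assuming the standard `Together` client entry point; the model name and file ID are placeholders, not values from this diff:

```python
from together import Together

client = Together()

# train_on_inputs may now simply be omitted: for SFT it falls back to
# "auto" (with a one-time warning), matching the previous default.
job = client.fine_tuning.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",
    training_file="file-abc123",
    train_on_inputs=False,  # explicit masking choice, valid only with SFT
)
```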
@@ -695,7 +705,7 @@ async def create(
         wandb_name: str | None = None,
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
-        train_on_inputs: bool | Literal["auto"] = "auto",
+        train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
         from_checkpoint: str | None = None,
@@ -745,7 +755,7 @@ async def create(
                 For datasets with the "text" field (general format), inputs will not be masked.
                 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                 (Instruction format), inputs will be masked.
-                Defaults to "auto".
+                Defaults to None, or "auto" if training_method is "sft" (set in create_finetune_request).
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
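
The async surface mirrors the same change; a brief sketch assuming `AsyncTogether` exposes the same `fine_tuning.create` method as the sync client, with the same placeholder identifiers:

```python
import asyncio

from together import AsyncTogether


async def main() -> None:
    client = AsyncTogether()
    # Omitted train_on_inputs: resolved to "auto" for SFT inside
    # create_finetune_request, exactly as in the sync path.
    job = await client.fine_tuning.create(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",
        training_file="file-abc123",
    )
    print(job.id)


asyncio.run(main())
```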