
Commit 458270f

migrate to sft_on_inputs, and change defaults to match
1 parent 367d2bd

File tree (3 files changed, +30 -21 lines):

src/together/cli/api/finetune.py
src/together/resources/finetune.py
src/together/types/finetune.py

src/together/cli/api/finetune.py

Lines changed: 10 additions & 10 deletions
```diff
@@ -1,10 +1,10 @@
 from __future__ import annotations
 
 import json
+import re
 from datetime import datetime, timezone
 from textwrap import wrap
 from typing import Any, Literal
-import re
 
 import click
 from click.core import ParameterSource  # type: ignore[attr-defined]
@@ -13,17 +13,17 @@
 
 from together import Together
 from together.cli.api.utils import BOOL_WITH_AUTO, INT_WITH_MAX
+from together.types.finetune import (
+    DownloadCheckpointType,
+    FinetuneEventType,
+    FinetuneTrainingLimits,
+)
 from together.utils import (
     finetune_price_to_dollars,
+    format_timestamp,
     log_warn,
     log_warn_once,
     parse_timestamp,
-    format_timestamp,
-)
-from together.types.finetune import (
-    DownloadCheckpointType,
-    FinetuneTrainingLimits,
-    FinetuneEventType,
 )
 
 
@@ -340,9 +340,9 @@ def list(ctx: click.Context) -> None:
                 "Model Output Name": "\n".join(wrap(i.output_name or "", width=30)),
                 "Status": i.status,
                 "Created At": i.created_at,
-                "Price": f"""${finetune_price_to_dollars(
-                    float(str(i.total_price))
-                )}""",  # convert to string for mypy typing
+                "Price": f"""${
+                    finetune_price_to_dollars(float(str(i.total_price)))
+                }""",  # convert to string for mypy typing
             }
         )
     table = tabulate(display_list, headers="keys", tablefmt="grid", showindex=True)
```

src/together/resources/finetune.py

Lines changed: 18 additions & 8 deletions
```diff
@@ -77,7 +77,7 @@ def create_finetune_request(
     wandb_base_url: str | None = None,
     wandb_project_name: str | None = None,
     wandb_name: str | None = None,
-    train_on_inputs: bool | Literal["auto"] = "auto",
+    train_on_inputs: bool | Literal["auto"] | None = None,
     training_method: str = "sft",
     dpo_beta: float | None = None,
     from_checkpoint: str | None = None,
```
```diff
@@ -162,6 +162,15 @@
             f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
         )
 
+    if train_on_inputs is not None and training_method != "sft":
+        raise ValueError("train_on_inputs is only supported for SFT training")
+
+    if train_on_inputs is None and training_method == "sft":
+        log_warn_once(
+            "train_on_inputs is not set for SFT training, it will be set to 'auto' automatically"
+        )
+        train_on_inputs = "auto"
+
     lr_scheduler: FinetuneLRScheduler
     if lr_scheduler_type == "cosine":
         if scheduler_num_cycles <= 0.0:
```
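With this hunk, `train_on_inputs` defaults to `None` and is resolved inside `create_finetune_request`. A minimal, self-contained sketch of the new resolution rules (not the SDK function itself; `log_warn_once` is replaced by a plain print here):

```python
from __future__ import annotations

from typing import Literal


def resolve_train_on_inputs(
    train_on_inputs: bool | Literal["auto"] | None,
    training_method: str,
) -> bool | Literal["auto"] | None:
    # Mirrors the two checks added above (sketch only).
    if train_on_inputs is not None and training_method != "sft":
        raise ValueError("train_on_inputs is only supported for SFT training")
    if train_on_inputs is None and training_method == "sft":
        # The SDK emits this via log_warn_once; a print stands in here.
        print("train_on_inputs is not set for SFT training, it will be set to 'auto' automatically")
        return "auto"
    return train_on_inputs


assert resolve_train_on_inputs(None, "sft") == "auto"   # unset SFT flag falls back to "auto"
assert resolve_train_on_inputs(False, "sft") is False   # explicit values pass through
assert resolve_train_on_inputs(None, "dpo") is None     # DPO: flag stays unset
# resolve_train_on_inputs(True, "dpo") raises ValueError
```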
```diff
@@ -179,7 +188,9 @@
             lr_scheduler_args=LinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
         )
 
-    training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
+    training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT(
+        train_on_inputs=train_on_inputs
+    )
     if training_method == "dpo":
         training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
 
```
```diff
@@ -202,7 +213,6 @@
         wandb_base_url=wandb_base_url,
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
-        train_on_inputs=train_on_inputs,
         training_method=training_method_cls,
         from_checkpoint=from_checkpoint,
     )
```
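Net effect on the request built by `create_finetune_request`: the flag no longer sits at the top level of `FinetuneRequest` and instead travels inside the `training_method` object. A rough sketch of the shape change (exact serialization and the omitted fields are assumptions, shown only to locate the moved field):

```python
# Before this commit (other request fields omitted):
old_body = {
    "train_on_inputs": "auto",
    "training_method": {"method": "sft"},
}

# After this commit:
new_body = {
    "training_method": {"method": "sft", "train_on_inputs": "auto"},
}
```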
```diff
@@ -307,7 +317,7 @@ def create(
         wandb_name: str | None = None,
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
-        train_on_inputs: bool | Literal["auto"] = "auto",
+        train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
         from_checkpoint: str | None = None,
```
```diff
@@ -352,12 +362,12 @@
                 Defaults to False.
             model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
                 Defaults to None.
-            train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
+            train_on_inputs (bool or "auto", optional): Whether to mask the user messages in conversational data or prompts in instruction data.
                 "auto" will automatically determine whether to mask the inputs based on the data format.
                 For datasets with the "text" field (general format), inputs will not be masked.
                 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                 (Instruction format), inputs will be masked.
-                Defaults to "auto".
+                Defaults to None, or "auto" if training_method is "sft" (set in create_finetune_request).
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
```
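As a quick reference for the docstring above, the three dataset formats map to the "auto" behavior as follows (the sample records are made up for illustration; only the field names matter):

```python
# Masking behavior under train_on_inputs="auto", per the docstring above.

general_record = {"text": "Once upon a time..."}
# "text" field (general format): inputs are NOT masked; the whole text is trained on.

conversational_record = {
    "messages": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "Paris."},
    ]
}
# "messages" field (conversational format): user messages are masked.

instruction_record = {"prompt": "Summarize: ...", "completion": "..."}
# "prompt"/"completion" fields (instruction format): the prompt is masked.
```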
```diff
@@ -695,7 +705,7 @@ async def create(
         wandb_name: str | None = None,
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
-        train_on_inputs: bool | Literal["auto"] = "auto",
+        train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
         from_checkpoint: str | None = None,
```
```diff
@@ -745,7 +755,7 @@
                 For datasets with the "text" field (general format), inputs will not be masked.
                 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                 (Instruction format), inputs will be masked.
-                Defaults to "auto".
+                Defaults to None, or "auto" if training_method is "sft" (set in create_finetune_request).
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
```

src/together/types/finetune.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -3,7 +3,7 @@
 from enum import Enum
 from typing import List, Literal
 
-from pydantic import StrictBool, Field, field_validator
+from pydantic import Field, StrictBool, field_validator
 
 from together.types.abstract import BaseModel
 from together.types.common import (
```
```diff
@@ -149,6 +149,7 @@ class TrainingMethodSFT(TrainingMethod):
     """
 
     method: Literal["sft"] = "sft"
+    train_on_inputs: StrictBool | Literal["auto"] = "auto"
 
 
 class TrainingMethodDPO(TrainingMethod):
```
```diff
@@ -201,8 +202,6 @@ class FinetuneRequest(BaseModel):
     wandb_name: str | None = None
     # training type
     training_type: FullTrainingType | LoRATrainingType | None = None
-    # train on inputs
-    train_on_inputs: StrictBool | Literal["auto"] = "auto"
     # training method
     training_method: TrainingMethodSFT | TrainingMethodDPO = Field(
         default_factory=TrainingMethodSFT
```
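With the field moved onto `TrainingMethodSFT`, it can be set and inspected on the method object directly. A small sketch (assumes pydantic v2-style `model_dump`, consistent with the `field_validator` import above):

```python
from together.types.finetune import TrainingMethodDPO, TrainingMethodSFT

# The type-level default is still "auto".
sft = TrainingMethodSFT()
print(sft.model_dump())  # e.g. {'method': 'sft', 'train_on_inputs': 'auto'}

# StrictBool plus the "auto" literal: real booleans and "auto" validate;
# loose strings such as "yes" are rejected.
sft_full = TrainingMethodSFT(train_on_inputs=True)

# DPO has no train_on_inputs field at all, matching the ValueError added in
# create_finetune_request for non-SFT training methods.
dpo = TrainingMethodDPO(dpo_beta=0.1)
```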
