
Commit 61ca373

implement eval
1 parent 9d2972e commit 61ca373

File tree

experiments/ablations/continued_pretrain.py
src/chemnlp/data/sampler.py
src/chemnlp/data/sampler_cli.py

3 files changed: +25 −13 lines changed

experiments/ablations/continued_pretrain.py

Lines changed: 19 additions & 7 deletions
@@ -8,6 +8,7 @@
 from datasets import load_dataset
 import fire

+
 def load_model(
     rank: int = 128,
     train_embeddings: bool = True,
@@ -56,7 +57,7 @@ def load_model(


 def train(
-    model, tokenizer, dataset, run_name: str, batch_size: int = 64, max_seq_length=2048
+    model, tokenizer, dataset, run_name: str, batch_size: int = 64, max_seq_length=2048, eval_dataset=None
 ):
     wandb.init(project="chemnlp-ablations", name=run_name)
     trainer = UnslothTrainer(
@@ -66,6 +67,7 @@ def train(
         dataset_text_field="text",
         max_seq_length=max_seq_length,
         dataset_num_proc=2,
+        eval_dataset=eval_dataset,
         args=UnslothTrainingArguments(
             per_device_train_batch_size=batch_size,
             gradient_accumulation_steps=1,
@@ -81,6 +83,8 @@ def train(
             lr_scheduler_type="linear",
             seed=3407,
             output_dir=f"outputs_{run_name}",
+            eval_strategy="steps" if eval_dataset is not None else "no",
+            eval_steps=10_000 if eval_dataset is not None else None,
         ),
     )

@@ -116,19 +120,27 @@ def formatting_prompts_func(examples):
     return dataset


-def run(data_files: List[str], run_name: str, batch_size: int=64, add_special_tokens: Optional[List[str]]=None, train_embeddings: bool=True):
+def run(
+    data_files: List[str],
+    run_name: str,
+    batch_size: int = 64,
+    add_special_tokens: Optional[List[str]] = None,
+    train_embeddings: bool = True,
+    eval_data_files: Optional[List[str]] = None,
+):
     print(f"Data files {data_files}")
     print(f"Run name {run_name}")
     print(f"Batch size {batch_size}")
     print(f"Add special tokens {add_special_tokens}")
     print(f"Train embeddings {train_embeddings}")
-    model, tokenizer = load_model(train_embeddings=train_embeddings, add_special_tokens=add_special_tokens)
-
-    dataset = create_dataset(
-        tokenizer, data_files
+    model, tokenizer = load_model(
+        train_embeddings=train_embeddings, add_special_tokens=add_special_tokens
     )

-    train(model, tokenizer, dataset, run_name, batch_size=batch_size)
+    dataset = create_dataset(tokenizer, data_files)
+    eval_dataset = create_dataset(tokenizer, eval_data_files) if eval_data_files else None
+
+    train(model, tokenizer, dataset, run_name, batch_size=batch_size, eval_dataset=eval_dataset)


 if __name__ == "__main__":
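
Taken together, these hunks thread an optional eval set through the script: run() gains an eval_data_files parameter, the eval set is built through the same create_dataset path as the training data, and train() forwards it to UnslothTrainer, switching eval_strategy to "steps" (every 10,000 steps) only when an eval set is present. A minimal sketch of calling the new entry point; the file names are placeholders, and the import path assumes the module is importable as continued_pretrain:

    # Sketch of the new eval flow; file names are placeholders, not
    # paths from the repository.
    from continued_pretrain import run

    run(
        data_files=["train_shard.jsonl"],   # continued-pretraining corpus
        run_name="eval-ablation",           # names the wandb run and the outputs_{run_name} dir
        eval_data_files=["heldout.jsonl"],  # optional; enables eval every 10_000 steps
    )

Omitting eval_data_files reproduces the old behavior exactly, since eval_strategy then falls back to "no".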

src/chemnlp/data/sampler.py

Lines changed: 0 additions & 3 deletions
@@ -172,7 +172,6 @@ def _wrap_identifier(self, identifier: str, value: str) -> str:
         """Wrap the identifier value with tags if wrap_identifiers is enabled."""

         if not self.wrap_identifiers:
-            logger.debug("Not wrapping identifiers.")
             return value

         identifier_type = next(
@@ -189,11 +188,9 @@ def _wrap_identifier(self, identifier: str, value: str) -> str:
         except ValueError:
             identifier_type = None

-        logger.debug(f"Identifier type: {identifier_type}, value: {value}")
         if identifier_type and identifier_type not in self.config.get(
             "excluded_from_wrapping", []
         ):
-            logger.debug(f"Wrapping {identifier_type} with tags.")
             return f"[BEGIN_{identifier_type}]{value}[END_{identifier_type}]"
         return value

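
This change only drops debug logging; the wrapping rule itself is untouched. For reference, a standalone sketch of the behavior the remaining lines implement (the names here are illustrative, not the sampler's API):

    # Illustrative restatement of the kept wrapping rule; `wrap` is not
    # a function in the codebase.
    def wrap(identifier_type, value, excluded):
        if identifier_type and identifier_type not in excluded:
            return f"[BEGIN_{identifier_type}]{value}[END_{identifier_type}]"
        return value

    assert wrap("SMILES", "CCO", ["Other"]) == "[BEGIN_SMILES]CCO[END_SMILES]"
    assert wrap("Other", "x", ["Other"]) == "x"  # excluded types pass through untagged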

src/chemnlp/data/sampler_cli.py

Lines changed: 6 additions & 3 deletions
@@ -102,7 +102,6 @@ def process_dataset(
         "excluded_from_wrapping": ["Other"],
     }

-
     templates = meta["templates"]
     if benchmarking:
         templates = [t for t in templates if "<EOI>" in t]
@@ -117,9 +116,13 @@ def process_dataset(
         logger.debug(f"Processing chunk {chunk_idx} to {chunk_output_dir}")
         os.makedirs(chunk_output_dir, exist_ok=True)

-        sampler = TemplateSampler(df_chunk, meta=meta, config=config, path_data_dir=data_dir)
+        sampler = TemplateSampler(
+            df_chunk, meta=meta, config=config, path_data_dir=data_dir
+        )
         if wrap_identifiers:
-            assert sampler.wrap_identifiers, "Wrap identifiers must be enabled in the sampler"
+            assert (
+                sampler.wrap_identifiers
+            ), "Wrap identifiers must be enabled in the sampler"

         for template_idx, template in enumerate(templates):
             print(
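
The reformatted assert documents a contract: when the CLI is invoked with wrap_identifiers, the TemplateSampler built from config must actually expose wrapping as enabled. A sketch of the config shape this implies; only "excluded_from_wrapping" appears verbatim in the diff, and the "wrap_identifiers" key is an assumption inferred from the assert:

    # Assumed shape of the config passed to TemplateSampler; only
    # "excluded_from_wrapping" is shown verbatim in this diff.
    config = {
        "wrap_identifiers": True,             # what the assert expects the sampler to reflect
        "excluded_from_wrapping": ["Other"],  # identifier types never tagged
    }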
