@@ -8,6 +8,7 @@
 from datasets import load_dataset
 import fire
 
+
 def load_model(
     rank: int = 128,
     train_embeddings: bool = True,
@@ -56,7 +57,7 @@ def load_model(
 
 
 def train(
-    model, tokenizer, dataset, run_name: str, batch_size: int = 64, max_seq_length=2048
+    model, tokenizer, dataset, run_name: str, batch_size: int = 64, max_seq_length=2048, eval_dataset=None
 ):
     wandb.init(project="chemnlp-ablations", name=run_name)
     trainer = UnslothTrainer(
@@ -66,6 +67,7 @@ def train(
         dataset_text_field="text",
         max_seq_length=max_seq_length,
         dataset_num_proc=2,
+        eval_dataset=eval_dataset,
         args=UnslothTrainingArguments(
             per_device_train_batch_size=batch_size,
             gradient_accumulation_steps=1,
@@ -81,6 +83,8 @@ def train(
             lr_scheduler_type="linear",
             seed=3407,
             output_dir=f"outputs_{run_name}",
+            eval_strategy="steps" if eval_dataset is not None else "no",
+            eval_steps=10_000 if eval_dataset is not None else None,
         ),
     )
 
@@ -116,19 +120,27 @@ def formatting_prompts_func(examples):
     return dataset
 
 
-def run(data_files: List[str], run_name: str, batch_size: int = 64, add_special_tokens: Optional[List[str]] = None, train_embeddings: bool = True):
+def run(
+    data_files: List[str],
+    run_name: str,
+    batch_size: int = 64,
+    add_special_tokens: Optional[List[str]] = None,
+    train_embeddings: bool = True,
+    eval_data_files: Optional[List[str]] = None,
+):
     print(f"Data files {data_files}")
     print(f"Run name {run_name}")
     print(f"Batch size {batch_size}")
     print(f"Add special tokens {add_special_tokens}")
     print(f"Train embeddings {train_embeddings}")
-    model, tokenizer = load_model(train_embeddings=train_embeddings, add_special_tokens=add_special_tokens)
-
-    dataset = create_dataset(
-        tokenizer, data_files
+    model, tokenizer = load_model(
+        train_embeddings=train_embeddings, add_special_tokens=add_special_tokens
     )
 
-    train(model, tokenizer, dataset, run_name, batch_size=batch_size)
+    dataset = create_dataset(tokenizer, data_files)
+    eval_dataset = create_dataset(tokenizer, eval_data_files) if eval_data_files else None
+
+    train(model, tokenizer, dataset, run_name, batch_size=batch_size, eval_dataset=eval_dataset)
 
 
 if __name__ == "__main__":
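For context, a minimal sketch of how the new evaluation path would be exercised. The file paths, run name, and module name below are placeholder assumptions, not taken from this commit; the excerpt also cuts off before the entry point (the `fire` import suggests `fire.Fire(run)`), so `run` is called directly here:

```python
# Hypothetical driver for the updated run() signature. All paths and the
# run name are placeholder assumptions for illustration only.
from train_lora import run  # hypothetical module name for this script

run(
    data_files=["data/train_shard_0.jsonl"],     # placeholder training shards
    run_name="r128-embeddings-eval",             # placeholder W&B run name
    batch_size=64,
    eval_data_files=["data/val_shard_0.jsonl"],  # placeholder held-out split
)

# With eval_data_files set, train() receives a tokenized eval_dataset and the
# trainer evaluates every 10_000 steps (eval_strategy="steps"); without it,
# evaluation is disabled (eval_strategy="no").
```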