1818
1919torchrun --nproc_per_node=1 gemma3vl_finetune.py --data_type=energon \
2020 --data_dir=<YOUR DATA DIR>
21-
2221"""
2322from scripts .vlm import gemma3vl_utils as train_utils
2423# Need to run these filters before importing nemo.
@@ -80,6 +79,8 @@ def main(args):
8079 image_processor = Gemma3ImageProcessor ()
8180 # Data setup
8281 if args .data_type == "energon" :
82+ if args .data_dir is None :
83+ raise ValueError ("data_dir is required for energon data type." )
8384 # Initialize the data module
8485 use_packed_sequence = False
8586 hf_processor = Gemma3Processor .from_pretrained (args .hf_model_id )
@@ -200,11 +201,11 @@ def main(args):
200201 parser = argparse .ArgumentParser (description = "Gemma3VL Model Training Script" )
201202
202203 parser .add_argument ("--data_type" , type = str , required = False , default = "energon" , help = "mock | energon" )
203- parser .add_argument ("--data_dir" , type = str , required = True , default = None , help = "Path to the dataset folder" )
204+ parser .add_argument ("--data_dir" , type = str , required = False , default = None , help = "Path to the dataset folder" )
204205 parser .add_argument (
205206 "--restore_path" , type = str , required = False , default = None , help = "Path to restore model from checkpoint"
206207 )
207- parser .add_argument ("--log_dir" , type = str , required = True , default = None , help = "Path to the log folder" )
208+ parser .add_argument ("--log_dir" , type = str , required = False , default = "/logs" , help = "Path to the log folder" )
208209 parser .add_argument ("--tp_size" , type = int , required = False , default = 1 )
209210 parser .add_argument ("--pp_size" , type = int , required = False , default = 1 )
210211 parser .add_argument ("--num_nodes" , type = int , required = False , default = 1 )
@@ -214,10 +215,6 @@ def main(args):
214215 parser .add_argument ("--max_steps" , type = int , required = False , default = 10 )
215216 parser .add_argument ("--val_check_interval" , type = int , required = False , default = 10 )
216217 parser .add_argument ("--limit_val_batches" , type = float , required = False , default = 1.0 )
217- parser .add_argument ("--every_n_train_steps" , type = int , required = False , default = 100 )
218- parser .add_argument (
219- "--monitor_metric" , type = str , required = False , default = "val_loss"
220- )
221218 parser .add_argument ("--lr" , type = float , required = False , default = 2.0e-06 , help = "Learning rate" )
222219 parser .add_argument ("--hf_model_id" , type = str , required = False , default = "google/gemma-3-4b-it" , help = "HuggingFace Gemma3VL model ids" )
223220 parser .add_argument ("--gbs" , type = int , required = False , default = 32 , help = "Global batch size" )
0 commit comments