Skip to content

Commit 07da129

Browse files
committed
minor update gemma3vl parameters for easier usages
Signed-off-by: genquan9 <[email protected]>
1 parent 30dfd71 commit 07da129

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

scripts/vlm/gemma3vl_finetune.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
1919
torchrun --nproc_per_node=1 gemma3vl_finetune.py --data_type=energon \
2020
--data_dir=<YOUR DATA DIR>
21-
2221
"""
2322
from scripts.vlm import gemma3vl_utils as train_utils
2423
# Need to run these filters before importing nemo.
@@ -80,6 +79,8 @@ def main(args):
8079
image_processor = Gemma3ImageProcessor()
8180
# Data setup
8281
if args.data_type == "energon":
82+
if args.data_dir is None:
83+
raise ValueError("data_dir is required for energon data type.")
8384
# Initialize the data module
8485
use_packed_sequence = False
8586
hf_processor = Gemma3Processor.from_pretrained(args.hf_model_id)
@@ -200,11 +201,11 @@ def main(args):
200201
parser = argparse.ArgumentParser(description="Gemma3VL Model Training Script")
201202

202203
parser.add_argument("--data_type", type=str, required=False, default="energon", help="mock | energon")
203-
parser.add_argument("--data_dir", type=str, required=True, default=None, help="Path to the dataset folder")
204+
parser.add_argument("--data_dir", type=str, required=False, default=None, help="Path to the dataset folder")
204205
parser.add_argument(
205206
"--restore_path", type=str, required=False, default=None, help="Path to restore model from checkpoint"
206207
)
207-
parser.add_argument("--log_dir", type=str, required=True, default=None, help="Path to the log folder")
208+
parser.add_argument("--log_dir", type=str, required=False, default="/logs", help="Path to the log folder")
208209
parser.add_argument("--tp_size", type=int, required=False, default=1)
209210
parser.add_argument("--pp_size", type=int, required=False, default=1)
210211
parser.add_argument("--num_nodes", type=int, required=False, default=1)
@@ -214,10 +215,6 @@ def main(args):
214215
parser.add_argument("--max_steps", type=int, required=False, default=10)
215216
parser.add_argument("--val_check_interval", type=int, required=False, default=10)
216217
parser.add_argument("--limit_val_batches", type=float, required=False, default=1.0)
217-
parser.add_argument("--every_n_train_steps", type=int, required=False, default=100)
218-
parser.add_argument(
219-
"--monitor_metric", type=str, required=False, default="val_loss"
220-
)
221218
parser.add_argument("--lr", type=float, required=False, default=2.0e-06, help="Learning rate")
222219
parser.add_argument("--hf_model_id", type=str, required=False, default="google/gemma-3-4b-it", help="HuggingFace Gemma3VL model ids")
223220
parser.add_argument("--gbs", type=int, required=False, default=32, help="Global batch size")

0 commit comments

Comments
 (0)