Commit 5aa06aa
Merge branch 'dev' (2 parents: b6547c3 + cbbf4b0)

8 files changed: +391 -121 lines

nemo/collections/llm/gpt/model/gemma3.py

Lines changed: 13 additions & 3 deletions
@@ -30,6 +30,8 @@
 from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
 from megatron.core.transformer.enums import AttnBackend, AttnMaskType
 from megatron.core.transformer.mlp import MLP, MLPSubmodules
+
+import torch
 from torch import Tensor, nn

 from nemo.collections.llm.fn.activation import openai_gelu
@@ -146,16 +148,19 @@ class Gemma3Config(GPTConfig):
     attention_backend: AttnBackend = AttnBackend.flash

     # mlp
+    bias_activation_fusion: bool = True
     gated_linear_unit: bool = True
     add_bias_linear: bool = False
-    activation_func: Callable = openai_gelu
+    activation_func: Callable = torch.nn.functional.gelu

     # Do not change
     is_vision_language: bool = False
     flash_decode: bool = False
     gradient_accumulation_fusion: bool = False
     transformer_layer_spec: Union[ModuleSpec, Callable[["GPTConfig"], ModuleSpec]] = gemma3_layer_spec
     scatter_embedding_sequence_parallel: bool = True
+    apply_rope_fusion: bool = True
+    cross_entropy_fusion_impl: str = 'te'

     def configure_model(
         self,
@@ -338,7 +343,12 @@ def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) -
         """Get global and local rope embedding"""
         rope_global = super().forward(max_seq_len, offset, packed_seq)
         rope_local = self.rope_local.forward(max_seq_len, offset, packed_seq)
-        return rope_local, rope_global
+        # When recompute_granularity is full, save_for_backward is called
+        # to save all variables in a layer. It can only save tensors, not
+        # tuples.
+        # Stack rope_local and rope_global into a single tensor to avoid
+        # the error.
+        return torch.stack((rope_local, rope_global), dim=0)


 def _is_local_attn_layer(
@@ -372,7 +382,6 @@ def forward(
         inference_params: Optional[BaseInferenceContext] = None,
     ) -> Tuple[Tensor, Tensor]:
         """Switch to either local or global rope embedding before forward"""
-        assert isinstance(rotary_pos_emb, tuple)
         assert rotary_pos_cos is None and rotary_pos_sin is None

         if _is_local_attn_layer(self.layer_number, self.config.interleaved_attn_pattern):
@@ -614,6 +623,7 @@ def config(self):
             architectures=["Gemma3ForCausalLM"],
             num_hidden_layers=source.num_layers,
             hidden_size=source.hidden_size,
+            sliding_window=source.window_size,
             intermediate_size=source.ffn_hidden_size,
             num_attention_heads=source.num_attention_heads,
             head_dim=source.kv_channels,
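
Note on the rope change above: with recompute_granularity set to full, Megatron's activation checkpointing passes a layer's inputs through save_for_backward, which accepts tensors but not tuples, so the local and global rotary embeddings are stacked along a new leading dimension and the attention layer indexes back into that stack. A minimal standalone sketch of the pattern (hypothetical shapes and helper names, not the NeMo classes themselves):

    import torch

    def pack_rope(rope_local: torch.Tensor, rope_global: torch.Tensor) -> torch.Tensor:
        # Pack both embeddings into a single tensor so checkpointing saves one tensor, not a tuple.
        return torch.stack((rope_local, rope_global), dim=0)

    def pick_rope(packed: torch.Tensor, use_local: bool) -> torch.Tensor:
        # Consumer side: index the leading dimension to recover the embedding this layer needs.
        return packed[0] if use_local else packed[1]

    # Hypothetical rope shape (seq_len, 1, 1, head_dim).
    local, global_ = torch.randn(128, 1, 1, 64), torch.randn(128, 1, 1, 64)
    packed = pack_rope(local, global_)  # shape (2, 128, 1, 1, 64)
    assert torch.equal(pick_rope(packed, True), local)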

nemo/collections/vlm/gemma3vl/data/mock.py

Lines changed: 7 additions & 1 deletion
@@ -186,13 +186,19 @@ def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
         generation_prompt_size = 5  # "ASSISTANT:" like
         prompt_end_idx = img_start_idx + IMAGE_TOKENS + generation_prompt_size
         labels[:prompt_end_idx] = -100
+        # Add the labels clipping to the mock data loader.
         labels = labels[1:]

+        # 5) prepare loss masks
+        # Calculate the loss mask from labels, to be consistent with real data and reduce confusion.
+        loss_mask = torch.ones_like(labels, dtype=torch.float)
+        loss_mask[labels < 0] = 0.0
+
         return {
             "input_ids": input_ids,
             "position_ids": position_ids,
             "pixel_values": pixel_values,
-            "loss_mask": self.loss_mask,
+            "loss_mask": loss_mask,
             "labels": labels,
         }
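
The mock loader now derives the loss mask from the shifted labels rather than returning a stored attribute, mirroring the real task encoder. A small sketch of the derivation, with a hypothetical label vector and -100 assumed as the ignore index:

    import torch

    IGNORE_INDEX = -100  # label value for positions that should not contribute to the loss

    # Hypothetical shifted labels: prompt/image positions ignored, answer tokens kept.
    labels = torch.tensor([IGNORE_INDEX, IGNORE_INDEX, 42, 7, IGNORE_INDEX, 13])

    loss_mask = torch.ones_like(labels, dtype=torch.float)
    loss_mask[labels < 0] = 0.0  # any negative label gets zero weight

    print(loss_mask)  # tensor([0., 0., 1., 1., 0., 1.])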

nemo/collections/vlm/gemma3vl/data/task_encoder.py

Lines changed: 44 additions & 27 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.

 import logging
+import json
 from dataclasses import dataclass, field
 from typing import Optional

@@ -24,6 +25,8 @@
 from nemo.collections.vlm.data.task_encoder import TaskEncoder as BaseTaskEncoder
 from nemo.collections.vlm.data.task_encoder import TaskEncoderConfig as BaseTaskEncoderConfig
 from nemo.collections.vlm.data.utils import _find_pattern_indices
+from nemo.collections.vlm.qwen2vl.data.multimodal_tokens import IGNORE_INDEX, IMAGE_TOKEN_INDEX
+from nemo.utils import logging


 @dataclass
@@ -101,58 +104,54 @@ def encode_batch(self, batch_data: DataBatch) -> dict:
         batch_data["media"] = batch_data["media"].reshape(-1, *batch_data["media"].shape[2:])
         return batch_data

-    def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:
-        """Encode a VQA sample into a DataSample format.

-        Args:
-            input_sample (VQASample): Input VQA sample containing image, context and answers
+    def encode_vqa_sample_multi_turns(self, input_sample: VQASample):
+        images = input_sample.image if isinstance(input_sample.image, list) else [input_sample.image]

-        Returns:
-            DataSample: Encoded sample with processed image, tokens, labels and loss mask
-        """
+        contexts = json.loads(input_sample.context.decode('utf-8'))
         messages = []
         if self.config.system_prompt:
             messages.append({'role': 'system', 'content': self.config.system_prompt})
-
-        # Ensure context and answers are lists for consistent processing
-        contexts = input_sample.context if isinstance(input_sample.context, list) else [input_sample.context]
-        answers = input_sample.answers if isinstance(input_sample.answers, list) else [input_sample.answers]
-
-        # Build the conversation messages, replacing image placeholder
-        min_length = min(len(contexts), len(answers))
-        for i in range(min_length):
-            context_with_placeholder = contexts[i].replace("<image>", self.config.image_token)
-            messages.append({'role': self.config.roles[0], 'content': context_with_placeholder})
-            messages.append({'role': self.config.roles[1], 'content': answers[i]})
+        for context in contexts:
+            messages.append(context)

         # Apply chat template and process with HF processor
-        converted_messages = self.hf_processor.apply_chat_template(messages, tokenize=False)
+        # `add_generation_prompt=False` because we're providing the full ground-truth sequence.
+        # We remove the <bos> token using removeprefix('<bos>') since we're finetuning:
+        # the processor will add this token before training and the model expects only one.
+        converted_messages = self.hf_processor.apply_chat_template(messages, add_generation_prompt=False, tokenize=False).removeprefix('<bos>')
         outputs = self.hf_processor(
-            images=input_sample.image,
+            images=images,
             text=converted_messages,
             return_tensors="pt",
             images_kwargs={"do_rescale": False},
         )
-
         # Get tokens and images from processor output
         # Squeeze the batch dimension as we process one sample at a time
         tokens = outputs["input_ids"].squeeze(0)
         images = outputs.get("pixel_values")  # Use .get() for optional images

         # --- Label Generation ---
+        # Same as: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/vlm/qwen2vl/data/task_encoder.py#L263-L270.
         # Initialize labels with ignore placeholder
         labels = torch.full_like(tokens, self.config.ignore_place_holder)
-
         search_start_index = 0
-        for answer in answers:
+        for context in contexts:
+            if context['role'] != 'assistant':
+                continue
             # Tokenize the answer, including the stop string if provided
-            answer_with_stop = answer + (self.config.stop_string or "")
+            answer_with_stop = context['content'][0]['text'].rstrip().lstrip() + "<end_of_turn>" + (self.config.stop_string or "")
+            answer_with_stop = answer_with_stop.rstrip().lstrip()
             answer_tokens = self.tokenizer.tokenizer(answer_with_stop, add_special_tokens=False)["input_ids"]
             answer_tokens_tensor = torch.tensor(answer_tokens, device=tokens.device)  # Ensure same device

+            # Sometimes the tokenizer can add an additional space. See:
+            # https://github.com/huggingface/transformers/issues/25073#issuecomment-1655271420
+            if self.tokenizer.tokenizer.decode(answer_tokens[0]) == "":
+                answer_tokens_tensor = answer_tokens_tensor[1:]
+
             # Find answer pattern in tokens
             answer_start, answer_end = _find_pattern_indices(tokens, answer_tokens_tensor, search_start_index)
-
             if answer_start >= 0:
                 labels[answer_start:answer_end] = tokens[answer_start:answer_end]
                 search_start_index = answer_end
@@ -170,11 +169,25 @@ def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:
                     search_start_index,
                 )
                 break
+        return tokens, labels, images
+
+
+    def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:
+        """Encode a VQA sample into a DataSample format.
+
+        Args:
+            input_sample (VQASample): Input VQA sample containing image, context and answers
+
+        Returns:
+            DataSample: Encoded sample with processed image, tokens, labels and loss mask
+        """
+        tokens, labels, images = self.encode_vqa_sample_multi_turns(input_sample)

         # Prepare final tensors
         tokens = tokens[:-1].contiguous()
         labels = labels[1:].contiguous()
         seqlen = len(tokens)  # Original sequence length before padding
+        position_ids = torch.arange(seqlen, dtype=torch.int64)

         # Pad tokens and labels to a multiple of `pad_to_multiple_of` if specified
         if self.config.pad_to_multiple_of:
@@ -191,7 +204,7 @@ def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:

         # Compute loss mask
         loss_mask = torch.ones_like(labels, dtype=torch.float)
-        loss_mask[labels == self.config.ignore_place_holder] = 0.0
+        loss_mask[labels < 0] = 0.0

         # Convert images to bfloat16 and stack, or create an empty tensor if no images
         if images is not None and images.numel() > 0:
@@ -202,13 +215,17 @@ def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:
             # Create an empty tensor with appropriate dimensions and dtype if no images
             processed_image = None

-        return Gemma3DataSample(
+        sample = Gemma3DataSample(
             __key__=input_sample.__key__,
             __restore_key__=input_sample.__restore_key__,
             __subflavor__=input_sample.__subflavor__,
             __subflavors__=input_sample.__subflavors__,
             pixel_values=processed_image,
             input_ids=tokens,
+            position_ids=position_ids,
             labels=labels,
             loss_mask=loss_mask,
         )
+
+        return sample
+
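The new encode_vqa_sample_multi_turns path assumes input_sample.context is a UTF-8 JSON byte string containing a chat-style message list, and only the assistant turns are copied into the labels. A simplified, self-contained sketch of that label-building loop (a naive pattern search stands in for _find_pattern_indices, the tokenizer is passed in as a plain callable, and -100 is assumed as the ignore index):

    import torch

    IGNORE_INDEX = -100

    def find_pattern(tokens: torch.Tensor, pattern: torch.Tensor, start: int):
        # Naive stand-in for _find_pattern_indices: first match as (start, end), else (-1, -1).
        n, m = tokens.numel(), pattern.numel()
        for i in range(start, n - m + 1):
            if torch.equal(tokens[i:i + m], pattern):
                return i, i + m
        return -1, -1

    def build_labels(tokens: torch.Tensor, contexts: list, tokenize) -> torch.Tensor:
        # Only assistant turns contribute to the loss; everything else stays at IGNORE_INDEX.
        labels = torch.full_like(tokens, IGNORE_INDEX)
        search_start = 0
        for turn in contexts:
            if turn["role"] != "assistant":
                continue
            answer = turn["content"][0]["text"].strip() + "<end_of_turn>"
            answer_ids = torch.tensor(tokenize(answer))
            start, end = find_pattern(tokens, answer_ids, search_start)
            if start >= 0:
                labels[start:end] = tokens[start:end]
                search_start = end
        return labels

Shifting tokens[:-1] against labels[1:] and zeroing the loss mask where labels < 0 then proceeds exactly as in the diff above.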

nemo/collections/vlm/gemma3vl/model/base.py

Lines changed: 13 additions & 6 deletions
@@ -26,6 +26,7 @@
 from megatron.core.inference_params import InferenceParams
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.parallel_state import get_context_parallel_group
+from nemo.collections.vlm.qwen2vl.data.multimodal_tokens import IGNORE_INDEX
 from megatron.core.tensor_parallel import scatter_to_sequence_parallel_region
 from megatron.core.transformer import MegatronModule
 from megatron.core.transformer.transformer_config import TransformerConfig
@@ -41,6 +42,7 @@
 from nemo.lightning.pytorch.optim import OptimizerModule
 from nemo.utils.import_utils import safe_import_from

+
 TENorm, _ = safe_import_from("megatron.core.extensions.transformer_engine", "TENorm")

 HAVE_TEX = True
@@ -78,6 +80,7 @@ def gemma3vl_data_step(dataloader_iter) -> Dict[str, torch.Tensor]:
         key: val.cuda(non_blocking=True) if key in required_keys and val is not None else None
         for key, val in _batch.items()
     }
+
     return _batch


@@ -392,18 +395,18 @@ def forward(
             input_ids = F.pad(input_ids, (0, padded_seq_len))
             position_ids = F.pad(position_ids, (0, padded_seq_len))
             if self.post_process:
-                labels = F.pad(labels, (0, padded_seq_len))
-                loss_mask = F.pad(loss_mask, (0, padded_seq_len))
+                labels = F.pad(labels, (0, padded_seq_len), value=IGNORE_INDEX)
+                loss_mask = F.pad(loss_mask, (0, padded_seq_len), value=0.0)

         # Compute language embedding
         if self.pre_process:
             safe_input_ids = input_ids
             # Replace image_token_id with 0 to avoid embedding index error
-            if self.image_token_id >= self.vocab_size:
-                image_token_mask = input_ids == self.image_token_id
-                safe_input_ids = input_ids.clone()
-                safe_input_ids[image_token_mask] = 0
+            image_token_mask = input_ids == self.image_token_id
+            safe_input_ids = input_ids.clone()
+            safe_input_ids[image_token_mask] = 0
             # (T, B, D)
+            # position_ids is None for qwen2 models, but is passed here for gemma3vl models.
             language_embedding = self.language_model.embedding(input_ids=safe_input_ids, position_ids=position_ids)
             # (B, T, D)
             language_embedding = language_embedding.transpose(1, 0).contiguous()
@@ -428,6 +431,7 @@
         combined_embedding = combined_embedding.transpose(1, 0).contiguous()

         # Run decoder model
+        # position_ids is None here for gemma3vl models, but is passed for qwen2 models.
         output = self.language_model(
             input_ids=None,
             position_ids=None,
@@ -441,6 +445,8 @@

         if labels is None or loss_mask is None:
             return output
+
+        output = output.masked_fill(labels < 0, 0.0)
         return output, loss_mask

     def _preprocess_data(
@@ -536,6 +542,7 @@ def _process_sequence_parallel(
         combined_embedding = scatter_to_sequence_parallel_region(combined_embedding)
         return combined_embedding, labels, loss_mask, packed_seq_params

+    @torch.compile
     def _compute_attention_mask(
         self,
         input_ids: torch.Tensor,
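
The forward changes above hinge on padding with values that are already masked out and then zeroing whatever the decoder returns at ignored positions. A small sketch of the idea with hypothetical shapes (IGNORE_INDEX assumed to be -100, matching the imported qwen2vl constant):

    import torch
    import torch.nn.functional as F

    IGNORE_INDEX = -100

    # Hypothetical per-sample tensors before padding to a fixed sequence length.
    labels = torch.tensor([[5, 9, IGNORE_INDEX, 11]])
    loss_mask = torch.tensor([[1.0, 1.0, 0.0, 1.0]])
    padded_seq_len = 3

    # Pad with values that are already "masked out", so the extra positions never count.
    labels = F.pad(labels, (0, padded_seq_len), value=IGNORE_INDEX)
    loss_mask = F.pad(loss_mask, (0, padded_seq_len), value=0.0)

    # Hypothetical per-token loss from the decoder; zero it wherever the label is ignored.
    per_token_loss = torch.rand_like(loss_mask)
    per_token_loss = per_token_loss.masked_fill(labels < 0, 0.0)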

scripts/vlm/gemma3vl_export.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+"""Export Gemma3VL NeMo checkpoints to Hugging Face format."""
+
+import argparse
+from huggingface_hub import hf_hub_download
+import importlib
+import os
+from pathlib import Path
+import sys
+from nemo.collections import llm
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description=(
+            "Export NeMo vision language model checkpoint to Hugging Face format."
+        )
+    )
+    parser.add_argument(
+        "--nemo_ckpt_path",
+        type=str,
+        required=True,
+        default=None,
+        help="Path to the NeMo checkpoint directory.",
+    )
+    parser.add_argument(
+        "--output_hf_path",
+        type=str,
+        required=True,
+        default=None,
+        help="Path to save the converted Hugging Face checkpoint.",
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        required=False,
+        default=None,
+        help="Name of the model on Hugging Face.",
+    )
+
+    args = parser.parse_args()
+
+    llm.export_ckpt(
+        path=Path(args.nemo_ckpt_path),
+        target="hf",
+        output_path=Path(args.output_hf_path),
+        overwrite=True,
+    )
+    if args.model_name:
+        # Copy the necessary files, if they exist, from Hugging Face for the Gemma3VL model export.
+        copy_file_list = [
+            "preprocessor_config.json",
+            "chat_template.json",
+            "config.json",
+            "generation_config.json",
+            "merges.txt",
+            "tokenizer.json",
+            "tokenizer_config.json",
+            "vocab.json",
+        ]
+        for file_name in copy_file_list:
+            try:
+                downloaded_path = hf_hub_download(
+                    repo_id=args.model_name,
+                    filename=file_name,
+                    local_dir=args.output_hf_path,
+                )
+                print(f"Downloaded {downloaded_path} while exporting the gemma3vl model.")
+            except Exception:
+                print(f"Skipping {file_name} while exporting the gemma3vl model.")
+
+
+if __name__ == "__main__":
+    main()
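
For reference, a typical invocation would point the script at a trained NeMo checkpoint and an output directory, optionally passing the matching Hugging Face repo so the tokenizer and processor configs are copied next to the converted weights; the paths and model name below are placeholders:

    python scripts/vlm/gemma3vl_export.py \
        --nemo_ckpt_path /path/to/gemma3vl_nemo_ckpt \
        --output_hf_path /path/to/hf_export \
        --model_name google/gemma-3-4b-it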
