Merged
30 commits
dc17b86
Fixing lines for multispeaker pipeline (#15030)
tango4j Nov 6, 2025
9e11c3e
Support gemma3vl tuning with verified performances
genquan9 Nov 7, 2025
2bf98a6
minor update gemma3vl parameters for easier usages
genquan9 Nov 7, 2025
7277489
Apply isort and black reformatting
genquan9 Nov 7, 2025
aa3d1cf
Inference optimization for cache-aware pipelines (#15035)
naymaraq Nov 9, 2025
6fa91ab
fix loading of hyb ctc rnnt bpe models when using from pretrained (#1…
nithinraok Nov 9, 2025
652b7d2
revert ckpt scripts removal from #14617 (#15048)
dimapihtar Nov 10, 2025
b025ad5
chore: remove ExportDeploy (#15033)
pablo-garay Nov 10, 2025
3e7510c
fix after ED remove (#15051)
pablo-garay Nov 10, 2025
778e322
Update changelog for `v2.5.3` (#15055)
github-actions[bot] Nov 10, 2025
dedff29
[voice agent] Fix RTVI missing bot message (#15068)
stevehuang52 Nov 13, 2025
ce18f64
[voice agent] make parakeet-eou model default stt (#15069)
stevehuang52 Nov 14, 2025
c141f69
minor fixes to remove unused headers/lines and add exception
genquan9 Nov 14, 2025
e7c2c3c
resolve merge conflicts from github
genquan9 Nov 17, 2025
7987349
removed old buffered CTC script (#15061)
naymaraq Nov 17, 2025
4853439
remove unused imports
genquan9 Nov 17, 2025
a502afa
remove nlp related notebooks (#15070)
nithinraok Nov 17, 2025
d62ef4c
chore: Remove Automodel module (#15044)
thomasdhc Nov 18, 2025
8d73e0d
add support for parallel ckpt removal (#15073)
dimapihtar Nov 18, 2025
632c362
Fix vlm engine changes in mcore (#15076)
meatybobby Nov 18, 2025
f622545
Add docstring for encode_vqa_sample_multi_turns, and fix long comments
genquan9 Nov 19, 2025
8d68799
Update MagpieTTS model with latest changes (#15031)
blisc Nov 19, 2025
4a6f319
Revert "Fix vlm engine changes in mcore (#15076)" (#15090)
pablo-garay Nov 19, 2025
a3fc9a6
ASR inference: expose RNN-T decoding params for context biasing (#15091)
artbataev Nov 19, 2025
cffc47e
Fix vlm engine changes in mcore (#15076)
meatybobby Nov 18, 2025
bd4362d
Revert "Fix vlm engine changes in mcore (#15076)" (#15090)
pablo-garay Nov 19, 2025
6969a25
update notebook (#15093)
nithinraok Nov 20, 2025
4dfb343
fix lines with malformed anchor tags (#15095)
pablo-garay Nov 20, 2025
104d821
add copyright header for missing files
genquan9 Nov 20, 2025
f3a2462
Merge branch 'NVIDIA-NeMo:main' into main
genquan9 Nov 20, 2025
2 changes: 1 addition & 1 deletion examples/voice_agent/server/server_configs/default.yaml
@@ -51,4 +51,4 @@ tts:
type: kokoro # choices in ['nemo', 'kokoro']
model: "hexgrad/Kokoro-82M"
model_config: "./server_configs/tts_configs/kokoro_82M.yaml"
device: "cuda"
device: "cuda"
16 changes: 12 additions & 4 deletions nemo/collections/llm/gpt/model/gemma3.py
@@ -21,6 +21,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Callable, Optional, Tuple, Union

import torch
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
@@ -32,7 +33,6 @@
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from torch import Tensor, nn

from nemo.collections.llm.fn.activation import openai_gelu
from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
from nemo.collections.llm.gpt.model.gemma2 import TERowParallelLinearLayerNorm
from nemo.collections.llm.utils import Config
@@ -146,16 +146,19 @@ class Gemma3Config(GPTConfig):
attention_backend: AttnBackend = AttnBackend.flash

# mlp
bias_activation_fusion: bool = True
gated_linear_unit: bool = True
add_bias_linear: bool = False
activation_func: Callable = openai_gelu
activation_func: Callable = torch.nn.functional.gelu

# Do not change
is_vision_language: bool = False
flash_decode: bool = False
gradient_accumulation_fusion: bool = False
transformer_layer_spec: Union[ModuleSpec, Callable[["GPTConfig"], ModuleSpec]] = gemma3_layer_spec
scatter_embedding_sequence_parallel: bool = True
apply_rope_fusion: bool = True
cross_entropy_fusion_impl: str = 'te'

def configure_model(
self,
@@ -338,7 +341,12 @@ def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) -
"""Get global and local rope embedding"""
rope_global = super().forward(max_seq_len, offset, packed_seq)
rope_local = self.rope_local.forward(max_seq_len, offset, packed_seq)
return rope_local, rope_global
# When recompute_granularity is "full", save_for_backward is called to save
# all of a layer's activations. It can only save tensors, not tuples, so
# stack rope_local and rope_global into a single tensor to avoid the error.
return torch.stack((rope_local, rope_global), dim=0)


def _is_local_attn_layer(
@@ -372,7 +380,6 @@ def forward(
inference_params: Optional[BaseInferenceContext] = None,
) -> Tuple[Tensor, Tensor]:
"""Switch to either local or global rope embedding before forward"""
assert isinstance(rotary_pos_emb, tuple)
assert rotary_pos_cos is None and rotary_pos_sin is None

if _is_local_attn_layer(self.layer_number, self.config.interleaved_attn_pattern):
@@ -614,6 +621,7 @@ def config(self):
architectures=["Gemma3ForCausalLM"],
num_hidden_layers=source.num_layers,
hidden_size=source.hidden_size,
sliding_window=source.window_size,
intermediate_size=source.ffn_hidden_size,
num_attention_heads=source.num_attention_heads,
head_dim=source.kv_channels,
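The rope change above stacks the local and global embeddings instead of returning a Python tuple, because save_for_backward under full activation recompute can only stash tensors. A minimal stand-alone sketch of the pattern — the shapes and the use_local flag are illustrative, not the real NeMo call sites:

import torch

# Two rope embeddings of identical shape are stacked into one tensor so that
# save_for_backward (which only accepts tensors) can checkpoint them.
rope_local = torch.randn(128, 1, 1, 64)   # hypothetical (seq, 1, 1, dim) layout
rope_global = torch.randn(128, 1, 1, 64)
stacked = torch.stack((rope_local, rope_global), dim=0)  # shape (2, 128, 1, 1, 64)

# Downstream, a layer picks the local or global rope depending on whether it is
# a sliding-window (local) attention layer in the interleaved pattern.
use_local = True
rotary_pos_emb = stacked[0] if use_local else stacked[1]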
8 changes: 7 additions & 1 deletion nemo/collections/vlm/gemma3vl/data/mock.py
@@ -186,13 +186,19 @@ def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
generation_prompt_size = 5 # "ASSISTANT:" like
prompt_end_idx = img_start_idx + IMAGE_TOKENS + generation_prompt_size
labels[:prompt_end_idx] = -100
# Shift labels by one token, mirroring the label clipping done by the real data loader.
labels = labels[1:]

# 5) prepare loss masks
# Calculate the loss mask from labels, to be consistent with the real data pipeline and avoid confusion.
loss_mask = torch.ones_like(labels, dtype=torch.float)
loss_mask[labels < 0] = 0.0

return {
"input_ids": input_ids,
"position_ids": position_ids,
"pixel_values": pixel_values,
"loss_mask": self.loss_mask,
"loss_mask": loss_mask,
"labels": labels,
}

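The mock loader now derives the loss mask from the shifted labels rather than reusing a stored mask, so the masked positions are exactly the ignore-index (-100) positions. The convention is easy to check in isolation; the label values below are made up:

import torch

labels = torch.tensor([-100, -100, 42, 7, -100, 13])
loss_mask = torch.ones_like(labels, dtype=torch.float)
loss_mask[labels < 0] = 0.0
assert loss_mask.tolist() == [0.0, 0.0, 1.0, 1.0, 0.0, 1.0]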
71 changes: 48 additions & 23 deletions nemo/collections/vlm/gemma3vl/data/task_encoder.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
from dataclasses import dataclass, field
from typing import Optional
@@ -24,6 +25,7 @@
from nemo.collections.vlm.data.task_encoder import TaskEncoder as BaseTaskEncoder
from nemo.collections.vlm.data.task_encoder import TaskEncoderConfig as BaseTaskEncoderConfig
from nemo.collections.vlm.data.utils import _find_pattern_indices
from nemo.utils import logging


@dataclass
@@ -101,58 +103,65 @@ def encode_batch(self, batch_data: DataBatch) -> dict:
batch_data["media"] = batch_data["media"].reshape(-1, *batch_data["media"].shape[2:])
return batch_data

def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:
"""Encode a VQA sample into a DataSample format.
def encode_vqa_sample_multi_turns(self, input_sample: VQASample):
"""Encode a VQA sample multi turns into a DataSample format.

Args:
input_sample (VQASample): Input VQA sample containing image, context and answers

Returns:
DataSample: Encoded sample with processed image, tokens, labels and loss mask
Tuple of (tokens, labels, images).
"""
images = input_sample.image if isinstance(input_sample.image, list) else [input_sample.image]

contexts = json.loads(input_sample.context.decode('utf-8'))
messages = []
if self.config.system_prompt:
messages.append({'role': 'system', 'content': self.config.system_prompt})

# Ensure context and answers are lists for consistent processing
contexts = input_sample.context if isinstance(input_sample.context, list) else [input_sample.context]
answers = input_sample.answers if isinstance(input_sample.answers, list) else [input_sample.answers]

# Build the conversation messages, replacing image placeholder
min_length = min(len(contexts), len(answers))
for i in range(min_length):
context_with_placeholder = contexts[i].replace("<image>", self.config.image_token)
messages.append({'role': self.config.roles[0], 'content': context_with_placeholder})
messages.append({'role': self.config.roles[1], 'content': answers[i]})
for context in contexts:
messages.append(context)

# Apply chat template and process with HF processor
converted_messages = self.hf_processor.apply_chat_template(messages, tokenize=False)
# `add_generation_prompt=False` because we're providing the full ground-truth sequence.
# We strip the leading <bos> token with removeprefix('<bos>') since we're finetuning:
# the processor adds this token again before training and the model expects only one.
converted_messages = self.hf_processor.apply_chat_template(
messages, add_generation_prompt=False, tokenize=False
).removeprefix('<bos>')
outputs = self.hf_processor(
images=input_sample.image,
images=images,
text=converted_messages,
return_tensors="pt",
images_kwargs={"do_rescale": False},
)

# Get tokens and images from processor output
# Squeeze the batch dimension as we process one sample at a time
tokens = outputs["input_ids"].squeeze(0)
images = outputs.get("pixel_values") # Use .get() for optional images

# --- Label Generation ---
# Same as: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/vlm/qwen2vl/data/task_encoder.py#L263-L270
# Initialize labels with ignore placeholder
labels = torch.full_like(tokens, self.config.ignore_place_holder)

search_start_index = 0
for answer in answers:
for context in contexts:
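# Only assistant turns are supervised; every other turn stays at the ignore index.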
if context['role'] != 'assistant':
continue
# Tokenize the answer, including the stop string if provided
answer_with_stop = answer + (self.config.stop_string or "")
answer_with_stop = (
context['content'][0]['text'].rstrip().lstrip() + "<end_of_turn>" + (self.config.stop_string or "")
)
answer_with_stop = answer_with_stop.rstrip().lstrip()
answer_tokens = self.tokenizer.tokenizer(answer_with_stop, add_special_tokens=False)["input_ids"]
answer_tokens_tensor = torch.tensor(answer_tokens, device=tokens.device) # Ensure same device

# Sometimes the tokenizer can add an additional space. See:
# https://github.com/huggingface/transformers/issues/25073#issuecomment-1655271420
if self.tokenizer.tokenizer.decode(answer_tokens[0]) == "":
answer_tokens_tensor = answer_tokens_tensor[1:]

# Find answer pattern in tokens
answer_start, answer_end = _find_pattern_indices(tokens, answer_tokens_tensor, search_start_index)

if answer_start >= 0:
labels[answer_start:answer_end] = tokens[answer_start:answer_end]
search_start_index = answer_end
@@ -170,11 +179,24 @@ def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:
search_start_index,
)
break
return tokens, labels, images

def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:
"""Encode a VQA sample into a DataSample format.

Args:
input_sample (VQASample): Input VQA sample containing image, context and answers

Returns:
DataSample: Encoded sample with processed image, tokens, labels and loss mask
"""
tokens, labels, images = self.encode_vqa_sample_multi_turns(input_sample)

# Prepare final tensors
tokens = tokens[:-1].contiguous()
labels = labels[1:].contiguous()
seqlen = len(tokens) # Original sequence length before padding
position_ids = torch.arange(seqlen, dtype=torch.int64)

# Pad tokens and labels to a multiple of `pad_to_multiple_of` if specified
if self.config.pad_to_multiple_of:
Expand All @@ -191,7 +213,7 @@ def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:

# Compute loss mask
loss_mask = torch.ones_like(labels, dtype=torch.float)
loss_mask[labels == self.config.ignore_place_holder] = 0.0
loss_mask[labels < 0] = 0.0

# Convert images to bfloat16 and stack, or create an empty tensor if no images
if images is not None and images.numel() > 0:
@@ -202,13 +224,16 @@ def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:
# Create an empty tensor with appropriate dimensions and dtype if no images
processed_image = None

return Gemma3DataSample(
sample = Gemma3DataSample(
__key__=input_sample.__key__,
__restore_key__=input_sample.__restore_key__,
__subflavor__=input_sample.__subflavor__,
__subflavors__=input_sample.__subflavors__,
pixel_values=processed_image,
input_ids=tokens,
position_ids=position_ids,
labels=labels,
loss_mask=loss_mask,
)

return sample
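For the new multi-turn path, input_sample.context is expected to carry UTF-8 encoded JSON holding a chat-style message list, and only the assistant turns become supervised targets. A hypothetical payload — the exact content schema and field values are illustrative, not taken from the PR — could look like:

import json

messages = [
    {"role": "user", "content": [{"type": "text", "text": "<image> What is in the picture?"}]},
    {"role": "assistant", "content": [{"type": "text", "text": "A cat sitting on a mat."}]},
    {"role": "user", "content": [{"type": "text", "text": "What color is it?"}]},
    {"role": "assistant", "content": [{"type": "text", "text": "Black and white."}]},
]
# encode_vqa_sample_multi_turns does json.loads(input_sample.context.decode('utf-8')),
# so the sample's context field would hold:
context_bytes = json.dumps(messages).encode("utf-8")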
18 changes: 12 additions & 6 deletions nemo/collections/vlm/gemma3vl/model/base.py
@@ -37,6 +37,7 @@
from nemo.collections.llm.gpt.model.gemma3 import Gemma3Config
from nemo.collections.vlm.gemma3vl.model.vision import Gemma3VLMultimodalProjectorConfig, Gemma3VLVisionConfig
from nemo.collections.vlm.neva.model.base import MODEL_CONFIG_ATTR, NevaModel, restore_model_weights
from nemo.collections.vlm.qwen2vl.data.multimodal_tokens import IGNORE_INDEX
from nemo.lightning import io
from nemo.lightning.pytorch.optim import OptimizerModule
from nemo.utils.import_utils import safe_import_from
@@ -78,6 +79,7 @@ def gemma3vl_data_step(dataloader_iter) -> Dict[str, torch.Tensor]:
key: val.cuda(non_blocking=True) if key in required_keys and val is not None else None
for key, val in _batch.items()
}

return _batch


@@ -392,18 +394,18 @@ def forward(
input_ids = F.pad(input_ids, (0, padded_seq_len))
position_ids = F.pad(position_ids, (0, padded_seq_len))
if self.post_process:
labels = F.pad(labels, (0, padded_seq_len))
loss_mask = F.pad(loss_mask, (0, padded_seq_len))
labels = F.pad(labels, (0, padded_seq_len), value=IGNORE_INDEX)
loss_mask = F.pad(loss_mask, (0, padded_seq_len), value=0.0)

# Compute language embedding
if self.pre_process:
safe_input_ids = input_ids
# Replace image_token_id with 0 to avoid embedding index error
if self.image_token_id >= self.vocab_size:
image_token_mask = input_ids == self.image_token_id
safe_input_ids = input_ids.clone()
safe_input_ids[image_token_mask] = 0
image_token_mask = input_ids == self.image_token_id
safe_input_ids = input_ids.clone()
safe_input_ids[image_token_mask] = 0
# (T, B, D)
# position_ids is None here for Qwen2 models, but is passed through for Gemma3VL models.
language_embedding = self.language_model.embedding(input_ids=safe_input_ids, position_ids=position_ids)
# (B, T, D)
language_embedding = language_embedding.transpose(1, 0).contiguous()
@@ -428,6 +430,7 @@
combined_embedding = combined_embedding.transpose(1, 0).contiguous()

# Run decoder model
# position_ids is None here for Gemma3VL models, but is set for Qwen2 models.
output = self.language_model(
input_ids=None,
position_ids=None,
@@ -441,6 +444,8 @@

if labels is None or loss_mask is None:
return output

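# Zero out the per-token loss wherever the label carries the ignore index.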
output = output.masked_fill(labels < 0, 0.0)
return output, loss_mask

def _preprocess_data(
Expand Down Expand Up @@ -536,6 +541,7 @@ def _process_sequence_parallel(
combined_embedding = scatter_to_sequence_parallel_region(combined_embedding)
return combined_embedding, labels, loss_mask, packed_seq_params

@torch.compile
def _compute_attention_mask(
self,
input_ids: torch.Tensor,
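The padding fix above matters because padded positions must never contribute to the loss: labels are padded with IGNORE_INDEX and the loss mask with zeros, so they stay inert. A stand-alone check of the same pattern, assuming IGNORE_INDEX is -100 as in the Qwen2-VL multimodal_tokens module:

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # assumed value

labels = torch.tensor([[5, 9, 2]])
loss_mask = torch.tensor([[1.0, 1.0, 1.0]])

padded_seq_len = 2
labels = F.pad(labels, (0, padded_seq_len), value=IGNORE_INDEX)
loss_mask = F.pad(loss_mask, (0, padded_seq_len), value=0.0)

assert labels.tolist() == [[5, 9, 2, -100, -100]]
assert loss_mask.tolist() == [[1.0, 1.0, 1.0, 0.0, 0.0]]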
83 changes: 83 additions & 0 deletions scripts/vlm/gemma3vl_export.py
@@ -0,0 +1,83 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Export Gemma3VL NeMo checkpoints to Hugging Face format."""

import argparse
from pathlib import Path

from huggingface_hub import hf_hub_download

from nemo.collections import llm


def main():
parser = argparse.ArgumentParser(
description=("Export NeMo vision language model checkpoint to Hugging Face format.")
)
parser.add_argument(
"--nemo_ckpt_path",
type=str,
required=True,
default=None,
help="Path to the NeMo checkpoint directory.",
)
parser.add_argument(
"--output_hf_path",
type=str,
required=True,
default=None,
help="Path to save the converted Hugging Face checkpoint.",
)
parser.add_argument(
"--model_name",
type=str,
required=False,
default=None,
help="Name of the model on Hugging Face.",
)

args = parser.parse_args()

llm.export_ckpt(
path=Path(args.nemo_ckpt_path),
target="hf",
output_path=Path(args.output_hf_path),
overwrite=True,
)
if args.model_name:
# Copy the necessary files from Hugging Face, if they exist, for the Gemma3VL model export.
copy_file_list = [
"preprocessor_config.json",
"chat_template.json",
"config.json",
"generation_config.json",
"merges.txt",
"tokenizer.json",
"tokenizer_config.json",
"vocab.json",
]
for file_name in copy_file_list:
try:
downloaded_path = hf_hub_download(
repo_id=args.model_name,
filename=file_name,
local_dir=args.output_hf_path,
)
print(f"Downloaded {downloaded_path} while exporting the Gemma3VL model.")
except Exception as e:
print(f"Skipping {file_name} during Gemma3VL export: {e}")


if __name__ == "__main__":
main()
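The script drives llm.export_ckpt via --nemo_ckpt_path and --output_hf_path, and optionally backfills tokenizer and processor files from the Hub when --model_name is given. The same export can also be run programmatically; a minimal sketch with placeholder paths:

from pathlib import Path

from nemo.collections import llm

# Convert a NeMo Gemma3VL checkpoint to Hugging Face format, overwriting any
# existing output directory. Both paths below are placeholders.
llm.export_ckpt(
    path=Path("/results/gemma3vl_finetune/checkpoints/last"),
    target="hf",
    output_path=Path("/results/gemma3vl_hf"),
    overwrite=True,
)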