 # limitations under the License.

 import logging
+import json
 from dataclasses import dataclass, field
 from typing import Optional

 from nemo.collections.vlm.data.task_encoder import TaskEncoder as BaseTaskEncoder
 from nemo.collections.vlm.data.task_encoder import TaskEncoderConfig as BaseTaskEncoderConfig
 from nemo.collections.vlm.data.utils import _find_pattern_indices
+from nemo.collections.vlm.qwen2vl.data.multimodal_tokens import IGNORE_INDEX, IMAGE_TOKEN_INDEX
+from nemo.utils import logging


 @dataclass
@@ -101,58 +104,55 @@ def encode_batch(self, batch_data: DataBatch) -> dict:
         batch_data["media"] = batch_data["media"].reshape(-1, *batch_data["media"].shape[2:])
         return batch_data

-    def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:
-        """Encode a VQA sample into a DataSample format.

-        Args:
-            input_sample (VQASample): Input VQA sample containing image, context and answers
+    def encode_vqa_sample_multi_turns(self, input_sample: VQASample):
+        """Encode a multi-turn VQA sample into (tokens, labels, images)."""
+        images = input_sample.image if isinstance(input_sample.image, list) else [input_sample.image]

-        Returns:
-            DataSample: Encoded sample with processed image, tokens, labels and loss mask
-        """
+        contexts = json.loads(input_sample.context.decode('utf-8'))
         messages = []
         if self.config.system_prompt:
             messages.append({'role': 'system', 'content': self.config.system_prompt})
-
-        # Ensure context and answers are lists for consistent processing
-        contexts = input_sample.context if isinstance(input_sample.context, list) else [input_sample.context]
-        answers = input_sample.answers if isinstance(input_sample.answers, list) else [input_sample.answers]
-
-        # Build the conversation messages, replacing image placeholder
-        min_length = min(len(contexts), len(answers))
-        for i in range(min_length):
-            context_with_placeholder = contexts[i].replace("<image>", self.config.image_token)
-            messages.append({'role': self.config.roles[0], 'content': context_with_placeholder})
-            messages.append({'role': self.config.roles[1], 'content': answers[i]})
+        for context in contexts:
+            messages.append(context)

         # Apply chat template and process with HF processor
-        converted_messages = self.hf_processor.apply_chat_template(messages, tokenize=False)
+        # `add_generation_prompt=False` because we're providing the full ground-truth sequence.
+        # We remove the <bos> token with removeprefix('<bos>') since we're finetuning:
+        # the processor will add this token before training and the model expects only one.
+        converted_messages = self.hf_processor.apply_chat_template(messages, add_generation_prompt=False, tokenize=False).removeprefix('<bos>')
         outputs = self.hf_processor(
-            images=input_sample.image,
+            images=images,
             text=converted_messages,
             return_tensors="pt",
             images_kwargs={"do_rescale": False},
         )
-
         # Get tokens and images from processor output
         # Squeeze the batch dimension as we process one sample at a time
         tokens = outputs["input_ids"].squeeze(0)
         images = outputs.get("pixel_values")  # Use .get() for optional images

         # --- Label Generation ---
+        # Same as: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/vlm/qwen2vl/data/task_encoder.py#L263-L270.
         # Initialize labels with ignore placeholder
         labels = torch.full_like(tokens, self.config.ignore_place_holder)
-
         search_start_index = 0
-        for answer in answers:
+        for context in contexts:
+            if context['role'] != 'assistant':
+                continue
             # Tokenize the answer, including the stop string if provided
-            answer_with_stop = answer + (self.config.stop_string or "")
+            answer_with_stop = context['content'][0]['text'].strip() + "<end_of_turn>" + (self.config.stop_string or "")
+            answer_with_stop = answer_with_stop.strip()
             answer_tokens = self.tokenizer.tokenizer(answer_with_stop, add_special_tokens=False)["input_ids"]
             answer_tokens_tensor = torch.tensor(answer_tokens, device=tokens.device)  # Ensure same device

+            # Sometimes the tokenizer can add an extra leading space token. See:
+            # https://github.com/huggingface/transformers/issues/25073#issuecomment-1655271420
+            if self.tokenizer.tokenizer.decode(answer_tokens[0]) == "":
+                answer_tokens_tensor = answer_tokens_tensor[1:]
+
             # Find answer pattern in tokens
             answer_start, answer_end = _find_pattern_indices(tokens, answer_tokens_tensor, search_start_index)
-
             if answer_start >= 0:
                 labels[answer_start:answer_end] = tokens[answer_start:answer_end]
                 search_start_index = answer_end
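
As a reference for the multi-turn path added above, `encode_vqa_sample_multi_turns` expects `input_sample.context` to be a UTF-8, JSON-encoded list of chat messages that can be handed directly to `apply_chat_template`, with assistant turns carrying their text under `content[0]['text']`. The dataset schema itself is not shown in this diff, so the sketch below only illustrates a payload shape that would satisfy the code (turn count and values are made up):

import json

# Hypothetical two-turn conversation; the encoder reads message['role'] and,
# for assistant turns, message['content'][0]['text'] when building labels.
context = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "What is shown in the picture?"},
    ]},
    {"role": "assistant", "content": [{"type": "text", "text": "A cat sitting on a mat."}]},
    {"role": "user", "content": [{"type": "text", "text": "What color is it?"}]},
    {"role": "assistant", "content": [{"type": "text", "text": "Black and white."}]},
]

# A webdataset sample would carry this as bytes in `context`; the encoder
# decodes and json.loads() it before appending the turns to `messages`.
encoded_context = json.dumps(context).encode("utf-8")

Only assistant turns are matched against the token stream, so user and system turns stay at the ignore placeholder and the loss is computed on answers alone.
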
@@ -170,11 +169,25 @@ def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:
                     search_start_index,
                 )
                 break
+        return tokens, labels, images
+
+
+    def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:
+        """Encode a VQA sample into a DataSample format.
+
+        Args:
+            input_sample (VQASample): Input VQA sample containing image, context and answers
+
+        Returns:
+            DataSample: Encoded sample with processed image, tokens, labels and loss mask
+        """
+        tokens, labels, images = self.encode_vqa_sample_multi_turns(input_sample)

         # Prepare final tensors
         tokens = tokens[:-1].contiguous()
         labels = labels[1:].contiguous()
         seqlen = len(tokens)  # Original sequence length before padding
+        position_ids = torch.arange(seqlen, dtype=torch.int64)

         # Pad tokens and labels to a multiple of `pad_to_multiple_of` if specified
         if self.config.pad_to_multiple_of:
@@ -191,7 +204,7 @@ def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:

         # Compute loss mask
         loss_mask = torch.ones_like(labels, dtype=torch.float)
-        loss_mask[labels == self.config.ignore_place_holder] = 0.0
+        loss_mask[labels < 0] = 0.0

         # Convert images to bfloat16 and stack, or create an empty tensor if no images
         if images is not None and images.numel() > 0:
@@ -202,13 +215,17 @@ def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:
             # Create an empty tensor with appropriate dimensions and dtype if no images
             processed_image = None

-        return Gemma3DataSample(
+        sample = Gemma3DataSample(
             __key__=input_sample.__key__,
             __restore_key__=input_sample.__restore_key__,
             __subflavor__=input_sample.__subflavor__,
             __subflavors__=input_sample.__subflavors__,
             pixel_values=processed_image,
             input_ids=tokens,
+            position_ids=position_ids,
             labels=labels,
             loss_mask=loss_mask,
         )
+
+        return sample
+
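
The final tensor preparation above uses the standard next-token shift: inputs drop the last token, labels drop the first, and the loss mask zeroes every position whose label is still negative (the ignore placeholder, plus any other negative sentinel ids caught by `labels < 0`). A small standalone illustration with made-up token ids and an assumed ignore value of -100:

import torch

IGNORE = -100  # assumed ignore placeholder, following the usual IGNORE_INDEX convention

# Toy sequence: three prompt tokens (unsupervised) followed by a three-token answer.
tokens = torch.tensor([5, 8, 13, 21, 34, 55])
labels = torch.tensor([IGNORE, IGNORE, IGNORE, 21, 34, 55])

# Next-token shift: position t is trained to predict labels[t + 1].
inputs = tokens[:-1].contiguous()    # [5, 8, 13, 21, 34]
targets = labels[1:].contiguous()    # [-100, -100, 21, 34, 55]

loss_mask = torch.ones_like(targets, dtype=torch.float)
loss_mask[targets < 0] = 0.0         # [0., 0., 1., 1., 1.]

position_ids = torch.arange(len(inputs), dtype=torch.int64)  # [0, 1, 2, 3, 4]

The `position_ids` tensor introduced by this change simply enumerates the unpadded sequence; it is computed from `seqlen` before any `pad_to_multiple_of` padding is applied.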