Changes required for enabling prompt based models in Nemo Inference #15036
Conversation
Signed-off-by: arushid <[email protected]>
Signed-off-by: arushidNV <[email protected]>
Pull Request Overview
This PR adds support for RNNT multilingual models with prompt-based language selection in NeMo Inference. The implementation enables language-specific prompts to be passed through the inference pipeline, allowing a single model to handle multiple languages.
Key Changes:
- Added `language_code` field to `ASRRequestOptions` for specifying the target language per request
- Introduced `prompt_idx` tracking in streaming state to maintain language selection across the stream lifecycle
- Implemented prompt vector generation and caching infrastructure in the buffered RNNT pipeline with validation and efficient batch processing
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| `nemo/collections/asr/inference/streaming/state/state.py` | Adds `prompt_idx` field and `set_prompt_index()` method to track prompt index per stream |
| `nemo/collections/asr/inference/streaming/framing/request_options.py` | Adds `language_code` field to request options for language selection |
| `nemo/collections/asr/inference/model_wrappers/rnnt_inference_wrapper.py` | Extends encode methods to support optional prompt vectors with dimension expansion |
| `nemo/collections/asr/inference/pipelines/buffered_rnnt_pipeline.py` | Implements prompt support infrastructure including configuration loading, validation, and prompt vector batch generation |
```python
default_enable_itn=self.text_processor.is_itn_enabled(),
default_enable_pnc=self.text_processor.is_pnc_enabled(),
default_stop_history_eou=self.stop_history_eou_in_milliseconds,
default_asr_output_granularity=self.asr_output_granularity,
```
Copilot AI · Nov 21, 2025
The augment_with_defaults call is missing the default_language_code parameter. According to the updated signature in request_options.py, this parameter should be passed here. Without it, the default language code will always be None even if a default was intended to be set.
Consider adding:

```python
new_options = options.augment_with_defaults(
    default_enable_itn=self.text_processor.is_itn_enabled(),
    default_enable_pnc=self.text_processor.is_pnc_enabled(),
    default_stop_history_eou=self.stop_history_eou_in_milliseconds,
    default_asr_output_granularity=self.asr_output_granularity,
    default_language_code=None,  # or "en-US" if a default is desired
)
```
```python
# Build prompt vectors if prompts are enabled
if self.prompt_enabled:
    requests_states = [self.get_state(f.stream_id) for f in frames]
    indices = torch.tensor([s.prompt_idx for s in requests_states], device=self.device, dtype=torch.long)
```
Copilot AI · Nov 21, 2025
The code assumes s.prompt_idx is set for all states when self.prompt_enabled is True. However, if a state was created before the prompt feature was added or if set_prompt_index() wasn't called for some reason, this will cause an AttributeError when accessing s.prompt_idx.
Consider adding a check or a default value:

```python
indices = torch.tensor([getattr(s, 'prompt_idx', 0) for s in requests_states], device=self.device, dtype=torch.long)
```

Or better, ensure `prompt_idx` is always initialized in the state's `_reset_streaming_state` method.
```python
# Build prompt vectors if prompts are enabled
if self.prompt_enabled:
    requests_states = [self.get_state(f.stream_id) for f in fbuffers]
    indices = torch.tensor([s.prompt_idx for s in requests_states], device=self.device, dtype=torch.long)
```
Copilot AI · Nov 21, 2025
Same issue as line 402: the code assumes s.prompt_idx is always set for all states. If a state was created before the prompt feature was added or if set_prompt_index() wasn't called, this will cause an AttributeError.
Consider adding a check or default value as in line 402.
```python
indices = torch.tensor([getattr(s, "prompt_idx", 0) for s in requests_states], device=self.device, dtype=torch.long)
```
```
Args:
    processed_signal: (Tensor) processed signal. Shape is torch.Size([B, C, T]).
    processed_signal_length: (Tensor) processed signal length. Shape is torch.Size([B]).
    prompt_vectors: (Tensor | None) Optional prompt vectors for multilingual models. Shape is torch.Size([B, num_prompts]).
```
Copilot AI · Nov 21, 2025
The docstring states that prompt_vectors should have shape [B, num_prompts], but the method encode_with_prompts() expands it to [B, T_enc, num_prompts] before calling this method. This creates an inconsistency in the expected shape documentation.
Either:
- Update the docstring to clarify that both shapes are accepted, or
- Document that this method accepts time-expanded prompts when called directly
Consider updating the docstring to:

```
prompt_vectors: (Tensor | None) Optional prompt vectors for multilingual models.
    Shape can be torch.Size([B, num_prompts]) or torch.Size([B, T_enc, num_prompts]) if already expanded.
```
please fix prompt_vectors shape in the docstring
| """ | ||
| return input_time_steps // self.get_subsampling_factor() |
Copilot AI · Nov 21, 2025
The time step estimation uses integer division which may not accurately predict the actual encoder output time steps. Depending on the encoder architecture (e.g., with padding, different convolution parameters), the actual output time steps might differ from input_time_steps // subsampling_factor.
This could cause issues when the expanded prompt tensor has a different time dimension than the actual encoder expects. Consider:
- Using the actual encoder's time estimation logic if available
- Adding a safety buffer or validation
- Documenting this as an approximation that may need adjustment
Example:

```python
# Add some buffer or use the encoder's actual logic
return (input_time_steps + subsampling_factor - 1) // subsampling_factor
```

Suggested change:

```python
    Note: This is an approximation. For most encoders, the output time steps are calculated as
    (input_time_steps + subsampling_factor - 1) // subsampling_factor to avoid underestimation.
    """
    subsampling_factor = self.get_subsampling_factor()
    return (input_time_steps + subsampling_factor - 1) // subsampling_factor
```
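The difference between floor and ceiling division is easy to check with a small stand-alone sketch (plain Python, not the encoder's actual length logic):

```python
def floor_estimate(input_time_steps: int, subsampling_factor: int) -> int:
    # Floor division: drops the remainder, so it can underestimate.
    return input_time_steps // subsampling_factor

def ceil_estimate(input_time_steps: int, subsampling_factor: int) -> int:
    # Ceiling division avoids underestimating when the input length
    # is not an exact multiple of the subsampling factor.
    return (input_time_steps + subsampling_factor - 1) // subsampling_factor

# With 100 input steps and a factor of 8, floor gives 12 while ceiling gives 13;
# both agree (12) when the input is an exact multiple like 96.
```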
| # Use "en-US" as the default prompt for zero encoding | ||
| # This region is sliced out before decoding, so language choice doesn't matter | ||
| default_prompt_idx = self._resolve_prompt_index("en-US") |
Copilot AI · Nov 21, 2025
The hardcoded "en-US" language code assumes this language will always be present in the prompt dictionary. If a prompt-enabled model doesn't include "en-US" in its prompt dictionary, this will cause a ValueError during initialization in init_zero_enc().
Consider either:
- Using the first available language from the prompt dictionary
- Making this configurable
- Adding validation at initialization to ensure "en-US" exists
Example fix:

```python
# Use the first available language as the default prompt for zero encoding
# This region is sliced out before decoding, so language choice doesn't matter
available_languages = list(self._prompt_config['prompt_dict'].keys())
default_lang = available_languages[0] if available_languages else "en-US"
default_prompt_idx = self._resolve_prompt_index(default_lang)
```
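The fallback can also prefer "en-US" when the model provides it and only then fall back to the first configured language. A minimal sketch (the helper `resolve_default_prompt_language` is hypothetical, not code from this PR):

```python
def resolve_default_prompt_language(prompt_dict: dict, preferred: str = "en-US") -> str:
    """Prefer `preferred` if the model's prompt dictionary contains it,
    otherwise fall back to the first available language."""
    if preferred in prompt_dict:
        return preferred
    # dicts preserve insertion order, so this is the first configured language
    return next(iter(prompt_dict), preferred)
```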
```python
indices = torch.tensor([s.prompt_idx for s in requests_states], device=self.device, dtype=torch.long)
# Validate indices
num_prompts = self._prompt_config['num_prompts']
if torch.any((indices < 0) | (indices >= num_prompts)):
    raise ValueError("Found out-of-range prompt index in batch.")
prompt_matrix = self._get_prompt_matrix()
prompt_vectors = prompt_matrix.index_select(0, indices)  # [B, num_prompts]
```
Copilot AI · Nov 21, 2025
The logic for building prompt vectors (lines 400-408) is duplicated in encode_processed_signals() (lines 454-461). Consider extracting this into a helper method to reduce code duplication and improve maintainability.
Example:

```python
def _build_prompt_vectors(self, states: list) -> Tensor:
    """Build prompt vectors for a batch of states."""
    indices = torch.tensor([s.prompt_idx for s in states], device=self.device, dtype=torch.long)
    num_prompts = self._prompt_config['num_prompts']
    if torch.any((indices < 0) | (indices >= num_prompts)):
        raise ValueError("Found out-of-range prompt index in batch.")
    prompt_matrix = self._get_prompt_matrix()
    return prompt_matrix.index_select(0, indices)
```

Both call sites then reduce to:

```python
prompt_vectors = self._build_prompt_vectors(requests_states)
```
```python
self._prompt_matrix_cache = {}
```
Copilot AI · Nov 21, 2025
When prompt_enabled is True (line 180) but _load_prompt_config() returns an empty dict (line 212), the _prompt_config will be empty. However, the code at lines 271-274 will attempt to use it, causing a RuntimeError at line 226.
This indicates a configuration issue, but it happens at runtime rather than initialization. Consider adding validation in init_prompt_support():
```python
if self.prompt_enabled:
    self._prompt_config = self._load_prompt_config()
    if not self._prompt_config:
        raise RuntimeError(
            "Model has concat=True but prompt configuration (num_prompts, prompt_dictionary) "
            "is missing or invalid in model_defaults."
        )
self._prompt_matrix_cache = {}
```
```
Args:
    processed_signal: (Tensor) processed signal. Shape is torch.Size([B, C, T]).
    processed_signal_length: (Tensor) processed signal length. Shape is torch.Size([B]).
    prompt_vectors: (Tensor | None) Optional prompt vectors for multilingual models. Shape is torch.Size([B, num_prompts]).
```
please fix prompt_vectors shape in the docstring
```python
    Returns:
        (tuple[Tensor, Tensor]) encoder output and encoder output length.
    """
    encoder_time_steps = self._estimate_encoder_time_steps(processed_signal.shape[2])
```
It looks like `_estimate_encoder_time_steps` is a single-line method. There's no need to keep it separate, since it contains only one line and isn't reused elsewhere.
There is duplicated code for building prompts in both encode_raw_signals and encode_processed_signals. It would be better to move this logic into a separate helper method.
Important

The **Update branch** button must only be pressed on very rare occasions. An outdated branch is never blocking the merge of a PR. Please reach out to the automation team before pressing that button.
What does this PR do?
Adds support for RNNT multilingual model with prompt input in Nemo Inference
Collection: ASR
Changelog
- `language_code` field added to `ASRRequestOptions` for specifying the target language
- `prompt_idx` field and `set_prompt_index()` method for creating a prompt vector for each stream
- `prompt_vectors` parameter added to the `encode()` method
- `encode_with_prompts()` for prompt models
- `encode_raw_signals()` and `encode_processed_signals()` updated to apply prompts

Usage
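A hedged sketch of per-request language selection, based only on the field name in the changelog above (the `ASRRequestOptions` class below is a minimal stand-in dataclass, not the real NeMo class, and its constructor signature is an assumption):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ASRRequestOptions:
    """Minimal stand-in mirroring the `language_code` field added by this PR."""
    language_code: Optional[str] = None

# Each request can select its own language; None falls back to the pipeline default.
options = ASRRequestOptions(language_code="es-ES")
```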
```python
# Add a code snippet demonstrating how to use this
```

GitHub Actions CI
The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.
The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items you can still open a "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
Contributor guidelines contains specific people who can review PRs to various areas.
Additional Information