@@ -71,12 +71,16 @@ def get_subsampling_factor(self) -> int:
"""
return self.asr_model.encoder.subsampling_factor

def encode(self, processed_signal: Tensor, processed_signal_length: Tensor) -> tuple[Tensor, Tensor]:
def encode(
self, processed_signal: Tensor, processed_signal_length: Tensor, prompt_vectors: Tensor | None = None
) -> tuple[Tensor, Tensor]:
"""
Get encoder output from the model. It is used for streaming inference.
Args:
processed_signal: (Tensor) processed signal. Shape is torch.Size([B, C, T]).
processed_signal_length: (Tensor) processed signal length. Shape is torch.Size([B]).
prompt_vectors: (Tensor | None) Optional prompt vectors for multilingual models.
Shape can be torch.Size([B, num_prompts]) or torch.Size([B, T_enc, num_prompts]) if already expanded.
Returns:
(tuple[Tensor, Tensor]) encoder output and encoder output length of shape torch.Size([B, T, D]), torch.Size([B]).
"""
@@ -92,9 +96,15 @@ def encode(self, processed_signal: Tensor, processed_signal_length: Tensor) -> tuple[Tensor, Tensor]:
torch.no_grad(),
):

forward_outs = self.asr_model(
processed_signal=processed_signal.to(self.cast_dtype), processed_signal_length=processed_signal_length
)
# Prepare model arguments
model_args = {
'processed_signal': processed_signal.to(self.cast_dtype),
'processed_signal_length': processed_signal_length,
}
if prompt_vectors is not None:
model_args['prompt'] = prompt_vectors

forward_outs = self.asr_model(**model_args)

encoded, encoded_len = forward_outs
return encoded, encoded_len
@@ -113,3 +123,25 @@ def decode(self, encoded: Tensor, encoded_len: Tensor, partial_hypotheses: list)
encoded.to(self.cast_dtype), encoded_len, return_hypotheses=True, partial_hypotheses=partial_hypotheses
)
return best_hyp

def encode_with_prompts(
self, processed_signal: Tensor, processed_signal_length: Tensor, prompt_vectors: Tensor
) -> tuple[Tensor, Tensor]:
"""
Convenience wrapper for prompt-enabled encoding.
Expands prompt vectors across the time dimension before calling encode.
Args:
processed_signal: (Tensor) processed signal. Shape is torch.Size([B, C, T]).
processed_signal_length: (Tensor) processed signal length. Shape is torch.Size([B]).
prompt_vectors: (Tensor) prompt vectors. Shape is torch.Size([B, num_prompts]).
Returns:
(tuple[Tensor, Tensor]) encoder output and encoder output length.
"""
encoder_time_steps = processed_signal.shape[2] // self.get_subsampling_factor()
# Expand prompts: [B, num_prompts] -> [B, T_enc, num_prompts]
prompt_vectors = prompt_vectors.unsqueeze(1).expand(-1, encoder_time_steps, -1)
return self.encode(
processed_signal=processed_signal,
processed_signal_length=processed_signal_length,
prompt_vectors=prompt_vectors,
)
161 changes: 152 additions & 9 deletions nemo/collections/asr/inference/pipelines/buffered_rnnt_pipeline.py
Reviewer comment (Collaborator):

There is duplicated code for building prompts in both encode_raw_signals and encode_processed_signals. It would be better to move this logic into a separate helper method.
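A minimal sketch of the shared helper the reviewer suggests, assuming the attributes and methods introduced in this PR (self.prompt_enabled, self.get_state, self._build_prompt_vectors, self.asr_model.encode, self.asr_model.encode_with_prompts); the helper name _encode_buffers is hypothetical:

def _encode_buffers(self, processed_signal: Tensor, processed_signal_length: Tensor, frames: list) -> tuple[Tensor, Tensor]:
    """Shared encode path for encode_raw_signals and encode_processed_signals."""
    if self.prompt_enabled:
        # Resolve each frame's streaming state and build [B, num_prompts] prompt vectors
        states = [self.get_state(f.stream_id) for f in frames]
        prompt_vectors = self._build_prompt_vectors(states)
        # encode_with_prompts expands the prompts across the encoder time dimension
        return self.asr_model.encode_with_prompts(
            processed_signal=processed_signal,
            processed_signal_length=processed_signal_length,
            prompt_vectors=prompt_vectors,
        )
    return self.asr_model.encode(
        processed_signal=processed_signal, processed_signal_length=processed_signal_length
    )

Both encode_raw_signals and encode_processed_signals could then call this helper with their respective feature buffers and frame lists instead of repeating the prompt-building branch.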
@@ -70,6 +70,7 @@ def __init__(
self.init_bpe_decoder()
self.init_decoding_computer()
self.init_text_processor(cfg, itn_model)
self.init_prompt_support()
super().__init__()

def init_parameters(self, cfg: DictConfig) -> None:
@@ -174,6 +175,99 @@ def init_decoding_computer(self) -> None:
if self.stateful:
self.decoding_computer = self.asr_model.asr_model.decoding.decoding.decoding_computer

def init_prompt_support(self) -> None:
"""Initialize prompt support for multilingual models."""
self.prompt_enabled = hasattr(self.asr_model.asr_model, 'concat') and self.asr_model.asr_model.concat

if self.prompt_enabled:
self._prompt_config = self._load_prompt_config()
self._prompt_matrix_cache = {}

Comment on lines +184 to +185
Copilot AI (Nov 21, 2025):

When prompt_enabled is True (line 180) but _load_prompt_config() returns an empty dict (line 212), the _prompt_config will be empty. However, the code at lines 271-274 will attempt to use it, causing a RuntimeError at line 226.

This indicates a configuration issue, but it happens at runtime rather than initialization. Consider adding validation in init_prompt_support():

if self.prompt_enabled:
    self._prompt_config = self._load_prompt_config()
    if not self._prompt_config:
        raise RuntimeError(
            "Model has concat=True but prompt configuration (num_prompts, prompt_dictionary) "
            "is missing or invalid in model_defaults."
        )
    self._prompt_matrix_cache = {}
Suggested change:
-            self._prompt_matrix_cache = {}
+            if not self._prompt_config:
+                raise RuntimeError(
+                    "Model has concat=True but prompt configuration (num_prompts, prompt_dictionary) "
+                    "is missing or invalid in model_defaults."
+                )
+            self._prompt_matrix_cache = {}

def _load_prompt_config(self) -> dict:
"""
Load and cache prompt configuration once at initialization.
Returns:
(dict) Prompt configuration containing num_prompts, prompt_dict, and compute_dtype.
"""
cfg = self.asr_model.asr_model.cfg
if cfg and hasattr(cfg, 'model_defaults'):
model_defaults = cfg.model_defaults
num_prompts = model_defaults.get('num_prompts', None)
prompt_dict = model_defaults.get('prompt_dictionary', None)

# Validate and convert types once
num_prompts_int = int(num_prompts) if num_prompts is not None else 0

is_dict_like = isinstance(prompt_dict, dict) or (
hasattr(prompt_dict, 'get') and hasattr(prompt_dict, '__contains__')
)

if num_prompts_int > 0 and is_dict_like:
return {
'num_prompts': num_prompts_int,
'prompt_dict': prompt_dict,
'compute_dtype': getattr(self.asr_model.asr_model, 'dtype', torch.float32),
}

return {}

def _resolve_prompt_index(self, language_code: str) -> int:
"""
Resolve language_code to a strict prompt index; raise if invalid.
Args:
language_code: (str) Language code to resolve (e.g., "en-US", "es-ES").
Returns:
(int) Prompt index corresponding to the language code.
Raises:
RuntimeError: If prompt configuration is missing.
ValueError: If language_code is not found in prompt dictionary.
"""
if not hasattr(self, '_prompt_config') or not self._prompt_config:
raise RuntimeError("Prompt configuration is missing for a prompt-enabled model.")
prompt_dict = self._prompt_config['prompt_dict']
lang_index = prompt_dict.get(language_code, None)
if lang_index is None:
raise ValueError(
f"Language code '{language_code}' not found in prompt dictionary. "
f"Available languages: {list(prompt_dict.keys())}"
)
return lang_index

def _get_prompt_matrix(self) -> Tensor:
"""
Return cached identity matrix [num_prompts, num_prompts] on device/dtype.
Returns:
(Tensor) Identity matrix for prompt selection.
"""
if not hasattr(self, '_prompt_config') or not self._prompt_config:
raise RuntimeError("Prompt configuration is missing for a prompt-enabled model.")
key = (self.device, self._prompt_config['compute_dtype'])
cached = self._prompt_matrix_cache.get(key)
if cached is not None:
return cached
num_prompts = self._prompt_config['num_prompts']
compute_dtype = self._prompt_config['compute_dtype']
eye = torch.eye(num_prompts, device=self.device, dtype=compute_dtype)
self._prompt_matrix_cache[key] = eye
return eye

def _build_prompt_vectors(self, states: list) -> Tensor:
"""
Build prompt vectors for a batch of states.
Args:
states: (list) List of streaming states.
Returns:
(Tensor) Prompt vectors of shape [B, num_prompts].
Raises:
ValueError: If any prompt index is out of range.
"""
indices = torch.tensor([getattr(s, 'prompt_idx', 0) for s in states], device=self.device, dtype=torch.long)
num_prompts = self._prompt_config['num_prompts']
if torch.any((indices < 0) | (indices >= num_prompts)):
raise ValueError("Found out-of-range prompt index in batch.")
prompt_matrix = self._get_prompt_matrix()
return prompt_matrix.index_select(0, indices) # [B, num_prompts]

def init_zero_enc(self) -> Tensor:
"""
Initialize the encoder output for the zero buffer.
@@ -190,9 +284,24 @@ def init_zero_enc(self) -> Tensor:
buffer_lens=torch.tensor([zero_buffer.shape[1]], device=self.device),
expected_feature_buffer_len=self.expected_feature_buffer_len,
)
zero_encoded, _ = self.asr_model.encode(
processed_signal=zero_features, processed_signal_length=zero_features_len
)

if self.prompt_enabled:
# Use "en-US" as the default prompt for zero encoding
# This region is sliced out before decoding, so language choice doesn't matter
default_prompt_idx = self._resolve_prompt_index("en-US")
Comment on lines +289 to +291
Copilot AI (Nov 21, 2025):

The hardcoded "en-US" language code assumes this language will always be present in the prompt dictionary. If a prompt-enabled model doesn't include "en-US" in its prompt dictionary, this will cause a ValueError during initialization in init_zero_enc().

Consider either:

  1. Using the first available language from the prompt dictionary
  2. Making this configurable
  3. Adding validation at initialization to ensure "en-US" exists

Example fix:

# Get the first available language or use a configurable default
available_languages = list(self._prompt_config['prompt_dict'].keys())
default_lang = available_languages[0] if available_languages else "en-US"
default_prompt_idx = self._resolve_prompt_index(default_lang)
Suggested change:
-            # Use "en-US" as the default prompt for zero encoding
-            # This region is sliced out before decoding, so language choice doesn't matter
-            default_prompt_idx = self._resolve_prompt_index("en-US")
+            # Use the first available language as the default prompt for zero encoding
+            # This region is sliced out before decoding, so language choice doesn't matter
+            available_languages = list(self._prompt_config['prompt_dict'].keys())
+            default_lang = available_languages[0] if available_languages else "en-US"
+            default_prompt_idx = self._resolve_prompt_index(default_lang)

prompt_matrix = self._get_prompt_matrix()
prompt_vector = prompt_matrix[default_prompt_idx].unsqueeze(0) # [1, num_prompts]

zero_encoded, _ = self.asr_model.encode_with_prompts(
processed_signal=zero_features,
processed_signal_length=zero_features_len,
prompt_vectors=prompt_vector,
)
else:
zero_encoded, _ = self.asr_model.encode(
processed_signal=zero_features, processed_signal_length=zero_features_len
)

return zero_encoded[0]

def create_state(self, options: ASRRequestOptions) -> RNNTStreamingState:
@@ -210,8 +319,18 @@ def create_state(self, options: ASRRequestOptions) -> RNNTStreamingState:
default_enable_pnc=self.text_processor.is_pnc_enabled(),
default_stop_history_eou=self.stop_history_eou_in_milliseconds,
default_asr_output_granularity=self.asr_output_granularity,
Copilot AI (Nov 21, 2025):

The augment_with_defaults call is missing the default_language_code parameter. According to the updated signature in request_options.py, this parameter should be passed here. Without it, the default language code will always be None even if a default was intended to be set.

Consider adding:

new_options = options.augment_with_defaults(
    default_enable_itn=self.text_processor.is_itn_enabled(),
    default_enable_pnc=self.text_processor.is_pnc_enabled(),
    default_stop_history_eou=self.stop_history_eou_in_milliseconds,
    default_asr_output_granularity=self.asr_output_granularity,
    default_language_code=None,  # or an appropriate default
)
Suggested change:
-            default_asr_output_granularity=self.asr_output_granularity,
+            default_asr_output_granularity=self.asr_output_granularity,
+            default_language_code=None,  # or "en-US" if a default is desired

default_language_code="en-US" if self.prompt_enabled else None,
)
state.set_options(new_options)

# Create per-stream prompt index for prompt-enabled models
if self.prompt_enabled:
lang_code = getattr(new_options, "language_code", None)
if not isinstance(lang_code, str) or len(lang_code) == 0:
raise ValueError("Prompt-enabled model requires a valid language_code in request options.")
prompt_idx = self._resolve_prompt_index(lang_code)
state.set_prompt_index(prompt_idx)

return state

def get_sep(self) -> str:
@@ -295,9 +414,21 @@ def encode_raw_signals(
expected_feature_buffer_len=self.expected_feature_buffer_len,
)

encoded, encoded_len = self.asr_model.encode(
processed_signal=feature_buffers, processed_signal_length=feature_buffer_lens
)
# Build prompt vectors if prompts are enabled
if self.prompt_enabled:
requests_states = [self.get_state(f.stream_id) for f in frames]
prompt_vectors = self._build_prompt_vectors(requests_states)

# Use encode_with_prompts which handles dimension expansion
encoded, encoded_len = self.asr_model.encode_with_prompts(
processed_signal=feature_buffers,
processed_signal_length=feature_buffer_lens,
prompt_vectors=prompt_vectors,
)
else:
encoded, encoded_len = self.asr_model.encode(
processed_signal=feature_buffers, processed_signal_length=feature_buffer_lens
)
encoded = encoded.clone()
encoded_len = encoded_len.clone()

@@ -331,9 +462,21 @@ def encode_processed_signals(
processed_signals = normalize_features(processed_signals, processed_signal_lengths)
processed_signal_lengths = processed_signal_lengths.clamp(max=processed_signals.shape[2])

encoded, encoded_len = self.asr_model.encode(
processed_signal=processed_signals, processed_signal_length=processed_signal_lengths
)
# Build prompt vectors if prompts are enabled
if self.prompt_enabled:
requests_states = [self.get_state(f.stream_id) for f in fbuffers]
prompt_vectors = self._build_prompt_vectors(requests_states)

# Use encode_with_prompts which handles dimension expansion
encoded, encoded_len = self.asr_model.encode_with_prompts(
processed_signal=processed_signals,
processed_signal_length=processed_signal_lengths,
prompt_vectors=prompt_vectors,
)
else:
encoded, encoded_len = self.asr_model.encode(
processed_signal=processed_signals, processed_signal_length=processed_signal_lengths
)
encoded = encoded.clone()
encoded_len = encoded_len.clone()

@@ -29,6 +29,7 @@ class ASRRequestOptions:
enable_pnc: bool = None
stop_history_eou: int = None
asr_output_granularity: ASROutputGranularity | str = None
language_code: str | None = None

def __post_init__(self) -> None:
"""
@@ -56,6 +57,7 @@ def augment_with_defaults(
default_enable_pnc: bool,
default_stop_history_eou: int,
default_asr_output_granularity: ASROutputGranularity | str,
default_language_code: str | None = None,
) -> "ASRRequestOptions":
"""
Augment the options with the default values.
@@ -64,6 +66,7 @@ def augment_with_defaults(
default_enable_pnc (bool): Default enable PNC.
default_stop_history_eou (int): Default stop history EOU.
default_asr_output_granularity (ASROutputGranularity | str): Default output granularity.
default_language_code (str | None): Default language code for prompt-enabled models.
Returns:
ASRRequestOptions: Augmented options.
"""
@@ -76,6 +79,7 @@ def augment_with_defaults(
asr_output_granularity=(
default_asr_output_granularity if self.asr_output_granularity is None else self.asr_output_granularity
),
language_code=default_language_code if self.language_code is None else self.language_code,
)


11 changes: 11 additions & 0 deletions nemo/collections/asr/inference/streaming/state/state.py
@@ -106,6 +109,9 @@ def _reset_streaming_state(self) -> None:
# Request options
self.options = None

# Prompt-related index (set by pipelines that use prompts)
self.prompt_idx = None

def set_options(self, options: RequestOptions) -> None:
"""
Set the options
@@ -114,6 +117,14 @@ def set_options(self, options: RequestOptions) -> None:
"""
self.options = options

def set_prompt_index(self, prompt_idx: int) -> None:
"""
Store the resolved prompt index for prompt-enabled models.
Args:
prompt_idx: (int) The prompt index to store in the state
"""
self.prompt_idx = prompt_idx

def set_incomplete_segment_tokens(self, incomplete_segment_tokens: list) -> None:
"""
Set the partial tokens