diff --git a/nemo/collections/asr/inference/model_wrappers/rnnt_inference_wrapper.py b/nemo/collections/asr/inference/model_wrappers/rnnt_inference_wrapper.py
index f15b5ff07ef7..a554b960f071 100644
--- a/nemo/collections/asr/inference/model_wrappers/rnnt_inference_wrapper.py
+++ b/nemo/collections/asr/inference/model_wrappers/rnnt_inference_wrapper.py
@@ -71,12 +71,16 @@ def get_subsampling_factor(self) -> int:
         """
         return self.asr_model.encoder.subsampling_factor

-    def encode(self, processed_signal: Tensor, processed_signal_length: Tensor) -> tuple[Tensor, Tensor]:
+    def encode(
+        self, processed_signal: Tensor, processed_signal_length: Tensor, prompt_vectors: Tensor | None = None
+    ) -> tuple[Tensor, Tensor]:
         """
         Get encoder output from the model. It is used for streaming inference.
         Args:
             processed_signal: (Tensor) processed signal. Shape is torch.Size([B, C, T]).
             processed_signal_length: (Tensor) processed signal length. Shape is torch.Size([B]).
+            prompt_vectors: (Tensor | None) Optional prompt vectors for multilingual models.
+                Shape can be torch.Size([B, num_prompts]) or torch.Size([B, T_enc, num_prompts]) if already expanded.
         Returns:
             (tuple[Tensor, Tensor]) encoder output and encoder output length of shape torch.Size([B, T, D]), torch.Size([B]).
         """
@@ -92,9 +96,15 @@ def encode(self, processed_signal: Tensor, processed_signal_length: Tensor) -> t
             torch.no_grad(),
         ):
-            forward_outs = self.asr_model(
-                processed_signal=processed_signal.to(self.cast_dtype), processed_signal_length=processed_signal_length
-            )
+            # Prepare model arguments
+            model_args = {
+                'processed_signal': processed_signal.to(self.cast_dtype),
+                'processed_signal_length': processed_signal_length,
+            }
+            if prompt_vectors is not None:
+                model_args['prompt'] = prompt_vectors
+
+            forward_outs = self.asr_model(**model_args)
             encoded, encoded_len = forward_outs

         return encoded, encoded_len

@@ -113,3 +123,25 @@ def decode(self, encoded: Tensor, encoded_len: Tensor, partial_hypotheses: list)
             encoded.to(self.cast_dtype), encoded_len, return_hypotheses=True, partial_hypotheses=partial_hypotheses
         )
         return best_hyp
+
+    def encode_with_prompts(
+        self, processed_signal: Tensor, processed_signal_length: Tensor, prompt_vectors: Tensor
+    ) -> tuple[Tensor, Tensor]:
+        """
+        Convenience wrapper for prompt-enabled encoding.
+        Expands prompt vectors across the time dimension before calling encode.
+        Args:
+            processed_signal: (Tensor) processed signal. Shape is torch.Size([B, C, T]).
+            processed_signal_length: (Tensor) processed signal length. Shape is torch.Size([B]).
+            prompt_vectors: (Tensor) prompt vectors. Shape is torch.Size([B, num_prompts]).
+        Returns:
+            (tuple[Tensor, Tensor]) encoder output and encoder output length.
+ """ + encoder_time_steps = processed_signal.shape[2] // self.get_subsampling_factor() + # Expand prompts: [B, num_prompts] -> [B, T_enc, num_prompts] + prompt_vectors = prompt_vectors.unsqueeze(1).expand(-1, encoder_time_steps, -1) + return self.encode( + processed_signal=processed_signal, + processed_signal_length=processed_signal_length, + prompt_vectors=prompt_vectors, + ) diff --git a/nemo/collections/asr/inference/pipelines/buffered_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/buffered_rnnt_pipeline.py index 915063d4d22b..7fe59f993f9a 100644 --- a/nemo/collections/asr/inference/pipelines/buffered_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/buffered_rnnt_pipeline.py @@ -70,6 +70,7 @@ def __init__( self.init_bpe_decoder() self.init_decoding_computer() self.init_text_processor(cfg, itn_model) + self.init_prompt_support() super().__init__() def init_parameters(self, cfg: DictConfig) -> None: @@ -174,6 +175,99 @@ def init_decoding_computer(self) -> None: if self.stateful: self.decoding_computer = self.asr_model.asr_model.decoding.decoding.decoding_computer + def init_prompt_support(self) -> None: + """Initialize prompt support for multilingual models.""" + self.prompt_enabled = hasattr(self.asr_model.asr_model, 'concat') and self.asr_model.asr_model.concat + + if self.prompt_enabled: + self._prompt_config = self._load_prompt_config() + self._prompt_matrix_cache = {} + + def _load_prompt_config(self) -> dict: + """ + Load and cache prompt configuration once at initialization. + Returns: + (dict) Prompt configuration containing num_prompts, prompt_dict, and compute_dtype. + """ + cfg = self.asr_model.asr_model.cfg + if cfg and hasattr(cfg, 'model_defaults'): + model_defaults = cfg.model_defaults + num_prompts = model_defaults.get('num_prompts', None) + prompt_dict = model_defaults.get('prompt_dictionary', None) + + # Validate and convert types once + num_prompts_int = int(num_prompts) if num_prompts is not None else 0 + + is_dict_like = isinstance(prompt_dict, dict) or ( + hasattr(prompt_dict, 'get') and hasattr(prompt_dict, '__contains__') + ) + + if num_prompts_int > 0 and is_dict_like: + return { + 'num_prompts': num_prompts_int, + 'prompt_dict': prompt_dict, + 'compute_dtype': getattr(self.asr_model.asr_model, 'dtype', torch.float32), + } + + return {} + + def _resolve_prompt_index(self, language_code: str) -> int: + """ + Resolve language_code to a strict prompt index; raise if invalid. + Args: + language_code: (str) Language code to resolve (e.g., "en-US", "es-ES"). + Returns: + (int) Prompt index corresponding to the language code. + Raises: + RuntimeError: If prompt configuration is missing. + ValueError: If language_code is not found in prompt dictionary. + """ + if not hasattr(self, '_prompt_config') or not self._prompt_config: + raise RuntimeError("Prompt configuration is missing for a prompt-enabled model.") + prompt_dict = self._prompt_config['prompt_dict'] + lang_index = prompt_dict.get(language_code, None) + if lang_index is None: + raise ValueError( + f"Language code '{language_code}' not found in prompt dictionary. " + f"Available languages: {list(prompt_dict.keys())}" + ) + return lang_index + + def _get_prompt_matrix(self) -> Tensor: + """ + Return cached identity matrix [num_prompts, num_prompts] on device/dtype. + Returns: + (Tensor) Identity matrix for prompt selection. 
+ """ + if not hasattr(self, '_prompt_config') or not self._prompt_config: + raise RuntimeError("Prompt configuration is missing for a prompt-enabled model.") + key = (self.device, self._prompt_config['compute_dtype']) + cached = self._prompt_matrix_cache.get(key) + if cached is not None: + return cached + num_prompts = self._prompt_config['num_prompts'] + compute_dtype = self._prompt_config['compute_dtype'] + eye = torch.eye(num_prompts, device=self.device, dtype=compute_dtype) + self._prompt_matrix_cache[key] = eye + return eye + + def _build_prompt_vectors(self, states: list) -> Tensor: + """ + Build prompt vectors for a batch of states. + Args: + states: (list) List of streaming states. + Returns: + (Tensor) Prompt vectors of shape [B, num_prompts]. + Raises: + ValueError: If any prompt index is out of range. + """ + indices = torch.tensor([getattr(s, 'prompt_idx', 0) for s in states], device=self.device, dtype=torch.long) + num_prompts = self._prompt_config['num_prompts'] + if torch.any((indices < 0) | (indices >= num_prompts)): + raise ValueError("Found out-of-range prompt index in batch.") + prompt_matrix = self._get_prompt_matrix() + return prompt_matrix.index_select(0, indices) # [B, num_prompts] + def init_zero_enc(self) -> Tensor: """ Initialize the encoder output for the zero buffer. @@ -190,9 +284,24 @@ def init_zero_enc(self) -> Tensor: buffer_lens=torch.tensor([zero_buffer.shape[1]], device=self.device), expected_feature_buffer_len=self.expected_feature_buffer_len, ) - zero_encoded, _ = self.asr_model.encode( - processed_signal=zero_features, processed_signal_length=zero_features_len - ) + + if self.prompt_enabled: + # Use "en-US" as the default prompt for zero encoding + # This region is sliced out before decoding, so language choice doesn't matter + default_prompt_idx = self._resolve_prompt_index("en-US") + prompt_matrix = self._get_prompt_matrix() + prompt_vector = prompt_matrix[default_prompt_idx].unsqueeze(0) # [1, num_prompts] + + zero_encoded, _ = self.asr_model.encode_with_prompts( + processed_signal=zero_features, + processed_signal_length=zero_features_len, + prompt_vectors=prompt_vector, + ) + else: + zero_encoded, _ = self.asr_model.encode( + processed_signal=zero_features, processed_signal_length=zero_features_len + ) + return zero_encoded[0] def create_state(self, options: ASRRequestOptions) -> RNNTStreamingState: @@ -210,8 +319,18 @@ def create_state(self, options: ASRRequestOptions) -> RNNTStreamingState: default_enable_pnc=self.text_processor.is_pnc_enabled(), default_stop_history_eou=self.stop_history_eou_in_milliseconds, default_asr_output_granularity=self.asr_output_granularity, + default_language_code="en-US" if self.prompt_enabled else None, ) state.set_options(new_options) + + # Create per-stream prompt index for prompt-enabled models + if self.prompt_enabled: + lang_code = getattr(new_options, "language_code", None) + if not isinstance(lang_code, str) or len(lang_code) == 0: + raise ValueError("Prompt-enabled model requires a valid language_code in request options.") + prompt_idx = self._resolve_prompt_index(lang_code) + state.set_prompt_index(prompt_idx) + return state def get_sep(self) -> str: @@ -295,9 +414,21 @@ def encode_raw_signals( expected_feature_buffer_len=self.expected_feature_buffer_len, ) - encoded, encoded_len = self.asr_model.encode( - processed_signal=feature_buffers, processed_signal_length=feature_buffer_lens - ) + # Build prompt vectors if prompts are enabled + if self.prompt_enabled: + requests_states = 
[self.get_state(f.stream_id) for f in frames] + prompt_vectors = self._build_prompt_vectors(requests_states) + + # Use encode_with_prompts which handles dimension expansion + encoded, encoded_len = self.asr_model.encode_with_prompts( + processed_signal=feature_buffers, + processed_signal_length=feature_buffer_lens, + prompt_vectors=prompt_vectors, + ) + else: + encoded, encoded_len = self.asr_model.encode( + processed_signal=feature_buffers, processed_signal_length=feature_buffer_lens + ) encoded = encoded.clone() encoded_len = encoded_len.clone() @@ -331,9 +462,21 @@ def encode_processed_signals( processed_signals = normalize_features(processed_signals, processed_signal_lengths) processed_signal_lengths = processed_signal_lengths.clamp(max=processed_signals.shape[2]) - encoded, encoded_len = self.asr_model.encode( - processed_signal=processed_signals, processed_signal_length=processed_signal_lengths - ) + # Build prompt vectors if prompts are enabled + if self.prompt_enabled: + requests_states = [self.get_state(f.stream_id) for f in fbuffers] + prompt_vectors = self._build_prompt_vectors(requests_states) + + # Use encode_with_prompts which handles dimension expansion + encoded, encoded_len = self.asr_model.encode_with_prompts( + processed_signal=processed_signals, + processed_signal_length=processed_signal_lengths, + prompt_vectors=prompt_vectors, + ) + else: + encoded, encoded_len = self.asr_model.encode( + processed_signal=processed_signals, processed_signal_length=processed_signal_lengths + ) encoded = encoded.clone() encoded_len = encoded_len.clone() diff --git a/nemo/collections/asr/inference/streaming/framing/request_options.py b/nemo/collections/asr/inference/streaming/framing/request_options.py index fff6f7677c2a..e8e6d796ab57 100644 --- a/nemo/collections/asr/inference/streaming/framing/request_options.py +++ b/nemo/collections/asr/inference/streaming/framing/request_options.py @@ -29,6 +29,7 @@ class ASRRequestOptions: enable_pnc: bool = None stop_history_eou: int = None asr_output_granularity: ASROutputGranularity | str = None + language_code: str | None = None def __post_init__(self) -> None: """ @@ -56,6 +57,7 @@ def augment_with_defaults( default_enable_pnc: bool, default_stop_history_eou: int, default_asr_output_granularity: ASROutputGranularity | str, + default_language_code: str | None = None, ) -> "ASRRequestOptions": """ Augment the options with the default values. @@ -64,6 +66,7 @@ def augment_with_defaults( default_enable_pnc (bool): Default enable PNC. default_stop_history_eou (int): Default stop history EOU. default_asr_output_granularity (ASROutputGranularity | str): Default output granularity. + default_language_code (str | None): Default language code for prompt-enabled models. Returns: ASRRequestOptions: Augmented options. 
""" @@ -76,6 +79,7 @@ def augment_with_defaults( asr_output_granularity=( default_asr_output_granularity if self.asr_output_granularity is None else self.asr_output_granularity ), + language_code=default_language_code if self.language_code is None else self.language_code, ) diff --git a/nemo/collections/asr/inference/streaming/state/state.py b/nemo/collections/asr/inference/streaming/state/state.py index 59f5031110f4..29474be45a5c 100644 --- a/nemo/collections/asr/inference/streaming/state/state.py +++ b/nemo/collections/asr/inference/streaming/state/state.py @@ -106,6 +106,9 @@ def _reset_streaming_state(self) -> None: # Request options self.options = None + # Prompt-related index (set by pipelines that use prompts) + self.prompt_idx = None + def set_options(self, options: RequestOptions) -> None: """ Set the options @@ -114,6 +117,14 @@ def set_options(self, options: RequestOptions) -> None: """ self.options = options + def set_prompt_index(self, prompt_idx: int) -> None: + """ + Store the resolved prompt index for prompt-enabled models. + Args: + prompt_idx: (int) The prompt index to store in the state + """ + self.prompt_idx = prompt_idx + def set_incomplete_segment_tokens(self, incomplete_segment_tokens: list) -> None: """ Set the partial tokens