From c0202a27fd12bebe47bceba7e77e3a020e842272 Mon Sep 17 00:00:00 2001 From: Abhishek Gaikwad Date: Mon, 24 Nov 2025 15:54:00 -0800 Subject: [PATCH] Add support for AIS batch loading for ASR audio processing * Introduced environment variable USE_AIS_GET_BATCH to toggle AIS batch loading. * Updated PromptedAudioToTextLhotseDataset to utilize AISBatchLoader when enabled. * Modified LazyNeMoTarredIterator to handle URL-based recordings when AIS batch loading is active. Implement URL-based audio loading using AIStore's Get-Batch API to improve data pipeline efficiency. This allows batch fetching of multiple audio files without local tar archive extraction, offloading processing to AIStore. Signed-off-by: Abhishek Gaikwad --- .../asr/data/audio_to_text_lhotse_prompted.py | 24 +- .../common/data/lhotse/nemo_adapters.py | 96 +++++ .../collections/speechlm2/data/s2s_dataset.py | 11 +- requirements/requirements_asr.txt | 6 +- ...test_lhotse_nemo_adapters_ais_get_batch.py | 337 ++++++++++++++++++ 5 files changed, 467 insertions(+), 7 deletions(-) create mode 100644 tests/collections/common/test_lhotse_nemo_adapters_ais_get_batch.py diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py index d3b02d161348..f6046b6f075d 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os from dataclasses import dataclass from typing import Optional, Union @@ -66,6 +67,11 @@ class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset): If `enable_chunking` is True, each audio sample is split into optimally sized chunks (see `find_optimal_chunk_size` and `chunk_waveform`). This is useful for long audio inputs, allowing the model to process them in manageable segments. + + NOTE: + If the environment variable `USE_AIS_GET_BATCH` is set to `true` (case-insensitive), + then batch audio loading from AIStore will be enabled for this dataset. This will use the + AISBatchLoader to load the audio from AIStore. This can improve data loading efficiency in some setups. """ def __init__( @@ -76,12 +82,28 @@ def __init__( ): super().__init__() self.tokenizer = tokenizer - self.load_audio = AudioSamples(fault_tolerant=True) + self.use_ais_get_batch = os.environ.get("USE_AIS_GET_BATCH", "False").lower() == "true" + + # Try to use use_batch_loader if available (Lhotse >= 1.32.0) + try: + self.load_audio = AudioSamples(fault_tolerant=True, use_batch_loader=self.use_ais_get_batch) + except TypeError: + # Lhotse < 1.32.0 doesn't support use_batch_loader + if self.use_ais_get_batch: + import logging + + logging.warning( + "AIS batch loading requested but not supported by this Lhotse version. " + "Please upgrade to Lhotse >= 1.32.0" + ) + self.load_audio = AudioSamples(fault_tolerant=True) + self.padding_value = self.tokenizer.pad_id self.prompt = prompt self.enable_chunking = enable_chunking def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch: + # Load the audio's from AIS and add them to the CutSet audio, audio_lens, cuts = self.load_audio(cuts) # Will work if batch_size is set to 1. diff --git a/nemo/collections/common/data/lhotse/nemo_adapters.py b/nemo/collections/common/data/lhotse/nemo_adapters.py index dfc00eba034d..802c02465f45 100644 --- a/nemo/collections/common/data/lhotse/nemo_adapters.py +++ b/nemo/collections/common/data/lhotse/nemo_adapters.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import random import re import tarfile @@ -331,6 +332,7 @@ def __init__( self.slice_length = slice_length self.epoch = 0 self._validate() + self.use_ais_get_batch = os.environ.get("USE_AIS_GET_BATCH", "False").lower() == "true" def to_shards(self) -> List["LazyNeMoTarredIterator"]: """Convert this iterator to a list of separate iterators for each shard.""" @@ -370,6 +372,93 @@ def _get_seed(self) -> int: def shard_ids(self) -> List[int]: return sorted(self.shard_id_to_manifest.keys()) + def _iter_batch_for_ais_get_batch( + self, tar_path, shard_manifest, manifest_path, rng, extra_fields + ) -> Generator[Cut, None, None]: + """ + Iterator for batch reading mode (AIS get batch). + Yields cuts with URL-based recordings without opening tar files. + """ + # Calculate slice offset for random skipping + total_entries = sum(len(entries) for entries in shard_manifest.values()) + slice_offset = ( + rng.randint(0, total_entries - self.slice_length) + if self.slice_length is not None and self.slice_length < total_entries + else -1 + ) + cntr = 0 + entries_processed = 0 + + for audio_filename, manifest_entries in shard_manifest.items(): + for data in manifest_entries: + # Skip entries if we haven't reached the slice offset yet + if entries_processed < slice_offset: + entries_processed += 1 + continue + # Stop if we've reached the slice length limit + elif cntr == self.slice_length: + break + + # filter out entries with valid "_skipme" values. + if data.get("_skipme", False): + entries_processed += 1 + continue + + # Construct URL: tar_path/audio_filename + audio_url = f"{tar_path.rstrip('/')}/{audio_filename.lstrip('/')}" + + # Get metadata from manifest + duration = data.get("duration") + if duration is None: + logging.warning(f"Skipping '{audio_filename}' - missing duration in manifest") + entries_processed += 1 + continue + + offset = data.get("offset", 0.0) + sampling_rate = data.get("sampling_rate", 16000) # default to 16kHz if not specified + + # Create URL-based recording + recording = Recording( + id=audio_filename, + sources=[AudioSource(type="url", channels=[0], source=audio_url)], + sampling_rate=sampling_rate, + num_samples=compute_num_samples(duration, sampling_rate), + duration=duration, + ) + + # Create cut from recording (audio will be loaded lazily from URL when needed) + cut = recording.to_cut() + if offset > 0: + cut = cut.truncate(offset=offset, duration=duration, preserve_id=True) + cut.id = f"{cut.id}-{round(offset * 1e2):06d}-{round(duration * 1e2):06d}" + + # Add supervision (transcript metadata) + cut.supervisions.append( + SupervisionSegment( + id=cut.id, + recording_id=cut.recording_id, + start=0, + duration=cut.duration, + text=data.get(self.text_field), + language=data.get(self.lang_field), + ) + ) + + # Attach custom fields and metadata + cut.custom = _to_custom_attr_dict(data) + cut.manifest_origin = manifest_path + cut.tar_origin = tar_path + for extra_field in extra_fields: + extra_field.attach_to(cut) + + cntr += 1 + entries_processed += 1 + yield cut + + # Break outer loop if we've reached the slice length limit + if cntr == self.slice_length: + break + def _iter_sequential( self, tar_path, shard_manifest, manifest_path, rng ) -> Generator[tuple[dict, bytes], None, None]: @@ -426,6 +515,13 @@ def basename(d: dict) -> str: shard_manifest: dict[str, list[dict]] = groupby(basename, self.shard_id_to_manifest[sid]) tar_path = self.shard_id_to_tar_path[sid] + + if self.use_ais_get_batch: + # Use batch reading mode - URL-based recordings without opening tar files + yield from self._iter_batch_for_ais_get_batch( + tar_path, shard_manifest, manifest_path, rng, extra_fields + ) + continue try: for data, raw_audio, tar_info in self._iter_sequential(tar_path, shard_manifest, manifest_path, rng): try: diff --git a/nemo/collections/speechlm2/data/s2s_dataset.py b/nemo/collections/speechlm2/data/s2s_dataset.py index ae32310e28e1..dd29c1d00d32 100644 --- a/nemo/collections/speechlm2/data/s2s_dataset.py +++ b/nemo/collections/speechlm2/data/s2s_dataset.py @@ -102,9 +102,14 @@ def __init__( def __getitem__(self, cuts: CutSet) -> dict: cuts = cuts.transform_text(_strip_timestamps) source_audio, source_audio_lens = collate_audio(cuts.resample(self.source_sample_rate)) - target_audio, target_audio_lens = collate_audio( - cuts.resample(self.target_sample_rate), recording_field="target_audio" - ) + # Manually resample target_audio attribute since cuts.resample() only affects the main recording + cuts_with_resampled_target = [] + for cut in cuts: + if hasattr(cut, "target_audio") and cut.target_audio is not None: + cut.target_audio = cut.target_audio.resample(self.target_sample_rate) + cuts_with_resampled_target.append(cut) + cuts_with_resampled_target = CutSet(cuts_with_resampled_target) + target_audio, target_audio_lens = collate_audio(cuts_with_resampled_target, recording_field="target_audio") target_tokens, target_token_lens = collate_token_channel( cuts, self.tokenizer, self.frame_length, roles=self.output_roles ) diff --git a/requirements/requirements_asr.txt b/requirements/requirements_asr.txt index db55c673653e..26e4af323001 100644 --- a/requirements/requirements_asr.txt +++ b/requirements/requirements_asr.txt @@ -1,10 +1,12 @@ braceexpand ctc_segmentation==1.7.4 +diskcache editdistance einops jiwer>=3.1.0,<4.0.0 kaldi-python-io -lhotse>=1.31.1 +kaldialign<=0.9.1 +lhotse>=1.32.0 # Align with upstream PyTorch requirements librosa>=0.10.1 marshmallow @@ -19,6 +21,4 @@ ruamel.yaml scipy>=0.14 soundfile sox<=1.5.0 -kaldialign<=0.9.1 whisper_normalizer -diskcache diff --git a/tests/collections/common/test_lhotse_nemo_adapters_ais_get_batch.py b/tests/collections/common/test_lhotse_nemo_adapters_ais_get_batch.py new file mode 100644 index 000000000000..cd699859abcc --- /dev/null +++ b/tests/collections/common/test_lhotse_nemo_adapters_ais_get_batch.py @@ -0,0 +1,337 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for GetBatch logic in LazyNeMoTarredIterator with AIS enabled.""" + +import tarfile +from pathlib import Path + +import pytest +from lhotse import CutSet +from lhotse.serialization import load_jsonl, save_to_jsonl +from lhotse.testing.dummies import DummyManifest + +from nemo.collections.common.data.lhotse.nemo_adapters import LazyNeMoTarredIterator + + +@pytest.fixture +def nemo_tarred_manifest_path_for_slicing(tmp_path_factory): + """Create a tarred NeMo manifest with 20 utterances (2 shards of 10 each) for slice testing.""" + tmpdir = tmp_path_factory.mktemp("nemo_tarred_slice_data") + + # Create dummy audio files + cuts = DummyManifest(CutSet, begin_id=0, end_id=20, with_data=True).save_audios(tmpdir, progress_bar=False) + + # Create two tar files (shard 0 and shard 1) + for shard_id in [0, 1]: + tar_path = tmpdir / f"audio_{shard_id}.tar" + manifest_path = tmpdir / f"manifest_{shard_id}.json" + + start_idx = shard_id * 10 + end_idx = start_idx + 10 + shard_cuts = list(cuts)[start_idx:end_idx] + + # Create tar file + with tarfile.open(tar_path, "w") as tar: + for idx, c in enumerate(shard_cuts): + audio_file = c.recording.sources[0].source + tar.add(audio_file, arcname=f"audio_{start_idx + idx}.wav") + + # Create manifest + manifest = [] + for idx, c in enumerate(shard_cuts): + manifest.append( + { + "audio_filepath": f"audio_{start_idx + idx}.wav", + "text": f"utterance {start_idx + idx}", + "duration": c.duration, + "lang": "en", + "shard_id": shard_id, + } + ) + + save_to_jsonl(manifest, manifest_path) + + # Return paths using NeMo's shard notation + manifest_paths = f"{tmpdir}/manifest__OP_0..1_CL_.json" + tar_paths = f"{tmpdir}/audio__OP_0..1_CL_.tar" + + return str(manifest_paths), str(tar_paths) + + +@pytest.mark.unit +def test_batch_reading_with_slice_offset(nemo_tarred_manifest_path_for_slicing, monkeypatch): + """Test that slice_length and slice_offset work correctly for batch reading mode.""" + monkeypatch.setenv("USE_AIS_GET_BATCH", "true") + + manifest_path, tar_path = nemo_tarred_manifest_path_for_slicing + + # Test with slice_length=5 (should get 5 entries per shard = 10 total) + iterator = LazyNeMoTarredIterator( + manifest_path=manifest_path, + tar_paths=tar_path, + shuffle_shards=False, + slice_length=5, + shard_seed=42, + ) + + cuts = list(iterator) + + # Should have exactly 5 cuts per shard = 10 cuts total + assert len(cuts) == 10 + + # Verify cuts have valid metadata + for cut in cuts: + assert cut.has_recording + assert cut.supervisions[0].text.startswith("utterance") + assert cut.duration == 1.0 + + +@pytest.mark.unit +def test_batch_reading_slice_offset_randomness(nemo_tarred_manifest_path_for_slicing, monkeypatch): + """Test that slice_offset varies across epochs for batch reading.""" + monkeypatch.setenv("USE_AIS_GET_BATCH", "true") + + manifest_path, tar_path = nemo_tarred_manifest_path_for_slicing + + iterator = LazyNeMoTarredIterator( + manifest_path=manifest_path, + tar_paths=tar_path, + shuffle_shards=False, + slice_length=5, + shard_seed=42, + ) + + # Collect IDs from multiple epochs + epoch_ids_list = [] + for _ in range(10): + epoch_ids_list.append(tuple([cut.id for cut in iterator])) + + # Verify at least some epochs differ (randomness is working) + unique_epochs = len(set(epoch_ids_list)) + assert ( + unique_epochs > 1 + ), f"Expected multiple unique epochs, got {unique_epochs}. All epochs identical - randomness not working." + + +@pytest.mark.unit +def test_batch_reading_without_slice_length(nemo_tarred_manifest_path_for_slicing, monkeypatch): + """Test that batch reading without slice_length returns all entries.""" + monkeypatch.setenv("USE_AIS_GET_BATCH", "true") + + manifest_path, tar_path = nemo_tarred_manifest_path_for_slicing + + iterator = LazyNeMoTarredIterator( + manifest_path=manifest_path, + tar_paths=tar_path, + shuffle_shards=False, + slice_length=None, # No slicing + shard_seed=42, + ) + + cuts = list(iterator) + + # Should have all 20 cuts (10 per shard) + assert len(cuts) == 20 + + +@pytest.mark.unit +def test_batch_reading_with_skipme(nemo_tarred_manifest_path_for_slicing, monkeypatch): + """Test that batch reading correctly skips entries with _skipme=True.""" + monkeypatch.setenv("USE_AIS_GET_BATCH", "true") + + tmpdir = Path(nemo_tarred_manifest_path_for_slicing[0]).parent + + # Modify manifest to add _skipme to some entries + for shard_id in [0, 1]: + manifest_path = tmpdir / f"manifest_{shard_id}.json" + + items = list(load_jsonl(manifest_path)) + # Mark every other item as skipme + for idx, item in enumerate(items): + if idx % 2 == 0: + item['_skipme'] = True + + save_to_jsonl(items, manifest_path) + + manifest_paths = f"{tmpdir}/manifest__OP_0..1_CL_.json" + tar_paths = f"{tmpdir}/audio__OP_0..1_CL_.tar" + + iterator = LazyNeMoTarredIterator( + manifest_path=manifest_paths, + tar_paths=tar_paths, + shuffle_shards=False, + slice_length=None, + shard_seed=42, + ) + + cuts = list(iterator) + + # Should have 10 cuts (half were skipped) + assert len(cuts) == 10 + + # Verify none of the cuts have _skipme in their custom fields + for cut in cuts: + assert not cut.custom.get('_skipme', False) + + +@pytest.mark.unit +def test_batch_reading_slice_length_larger_than_manifest(nemo_tarred_manifest_path_for_slicing, monkeypatch): + """Test that slice_length larger than manifest size returns all entries.""" + monkeypatch.setenv("USE_AIS_GET_BATCH", "true") + + manifest_path, tar_path = nemo_tarred_manifest_path_for_slicing + + iterator = LazyNeMoTarredIterator( + manifest_path=manifest_path, + tar_paths=tar_path, + shuffle_shards=False, + slice_length=100, # Much larger than 10 entries per shard + shard_seed=42, + ) + + cuts = list(iterator) + + # Should still return all 20 cuts + assert len(cuts) == 20 + + +@pytest.mark.unit +def test_batch_reading_slice_offset_respects_entries_processed(nemo_tarred_manifest_path_for_slicing, monkeypatch): + """Test that slice_offset correctly counts all entries including skipped ones.""" + monkeypatch.setenv("USE_AIS_GET_BATCH", "true") + + tmpdir = Path(nemo_tarred_manifest_path_for_slicing[0]).parent + + # Modify manifest to add _skipme to first 3 entries of each shard + for shard_id in [0, 1]: + manifest_path = tmpdir / f"manifest_{shard_id}.json" + + items = list(load_jsonl(manifest_path)) + for idx in range(3): + items[idx]['_skipme'] = True + + save_to_jsonl(items, manifest_path) + + manifest_paths = f"{tmpdir}/manifest__OP_0..1_CL_.json" + tar_paths = f"{tmpdir}/audio__OP_0..1_CL_.tar" + + iterator = LazyNeMoTarredIterator( + manifest_path=manifest_paths, + tar_paths=tar_paths, + shuffle_shards=False, + slice_length=5, + shard_seed=42, + ) + + cuts = list(iterator) + + # Should have at most 10 cuts (5 per shard), but could be less if slice_offset skips valid entries + assert len(cuts) <= 10 + + # All returned cuts should not have _skipme + for cut in cuts: + assert not cut.custom.get('_skipme', False) + + +@pytest.mark.unit +def test_batch_reading_creates_url_sources(nemo_tarred_manifest_path_for_slicing, monkeypatch): + """Test that batch mode creates URL-based recording sources.""" + monkeypatch.setenv("USE_AIS_GET_BATCH", "true") + + manifest_path, tar_path = nemo_tarred_manifest_path_for_slicing + + iterator = LazyNeMoTarredIterator( + manifest_path=manifest_path, + tar_paths=tar_path, + shuffle_shards=False, + slice_length=5, + shard_seed=42, + ) + + cuts = list(iterator) + + # Verify recording sources are URLs + for cut in cuts: + source = cut.recording.sources[0] + assert source.type == "url" + assert ".tar/" in source.source + assert source.source.endswith(".wav") + + +@pytest.mark.unit +def test_batch_reading_url_format(nemo_tarred_manifest_path_for_slicing, monkeypatch): + """Test that URL format is tar_path/audio_filename.""" + monkeypatch.setenv("USE_AIS_GET_BATCH", "true") + + manifest_path, tar_path = nemo_tarred_manifest_path_for_slicing + + iterator = LazyNeMoTarredIterator( + manifest_path=manifest_path, + tar_paths=tar_path, + shuffle_shards=False, + slice_length=None, + shard_seed=42, + ) + + cuts = list(iterator) + tmpdir = Path(manifest_path.split("__OP_")[0]).parent + + # Verify URL structure + for cut in cuts: + url = cut.recording.sources[0].source + assert str(tmpdir) in url + assert "audio_" in url + + # Verify exactly one .tar in URL + parts = url.split("/") + tar_parts = [p for p in parts if p.endswith(".tar")] + assert len(tar_parts) == 1 + + +@pytest.mark.unit +def test_batch_vs_sequential_mode(nemo_tarred_manifest_path_for_slicing, monkeypatch): + """Test that batch and sequential modes produce different source types.""" + manifest_path, tar_path = nemo_tarred_manifest_path_for_slicing + + # Batch mode + monkeypatch.setenv("USE_AIS_GET_BATCH", "true") + batch_iterator = LazyNeMoTarredIterator( + manifest_path=manifest_path, + tar_paths=tar_path, + shuffle_shards=False, + slice_length=None, + shard_seed=42, + ) + batch_cuts = list(batch_iterator) + + # Sequential mode + monkeypatch.setenv("USE_AIS_GET_BATCH", "false") + seq_iterator = LazyNeMoTarredIterator( + manifest_path=manifest_path, + tar_paths=tar_path, + shuffle_shards=False, + slice_length=None, + shard_seed=42, + ) + seq_cuts = list(seq_iterator) + + # Batch mode should use URL sources + assert batch_cuts[0].recording.sources[0].type == "url" + + # Sequential mode should use different approach + assert ( + seq_cuts[0].recording.sources[0].type != "url" + or seq_cuts[0].recording.sources[0].source != batch_cuts[0].recording.sources[0].source + )