Skip to content

Commit 18557ae

Browse files
Add support for AIS batch loading for ASR audio processing
* Introduced environment variable USE_AIS_GET_BATCH to toggle AIS batch loading. * Updated PromptedAudioToTextLhotseDataset to utilize AISBatchLoader when enabled. * Modified LazyNeMoTarredIterator to handle URL-based recordings when AIS batch loading is active. Implement URL-based audio loading using AIStore's Get-Batch API to improve data pipeline efficiency. This allows batch fetching of multiple audio files without local tar archive extraction, offloading processing to AIStore. Signed-off-by: Abhishek Gaikwad <[email protected]>
1 parent 6389a71 commit 18557ae

File tree

4 files changed

+459
-4
lines changed

4 files changed

+459
-4
lines changed

nemo/collections/asr/data/audio_to_text_lhotse_prompted.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import os
1415
from dataclasses import dataclass
1516
from typing import Optional, Union
1617

@@ -66,6 +67,11 @@ class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset):
6667
If `enable_chunking` is True, each audio sample is split into optimally sized chunks
6768
(see `find_optimal_chunk_size` and `chunk_waveform`). This is useful for long audio inputs,
6869
allowing the model to process them in manageable segments.
70+
71+
NOTE:
72+
If the environment variable `USE_AIS_GET_BATCH` is set to `true` (case-insensitive),
73+
then batch audio loading from AIStore will be enabled for this dataset. This will use the
74+
AISBatchLoader to load the audio from AIStore. This can improve data loading efficiency in some setups.
6975
"""
7076

7177
def __init__(
@@ -76,12 +82,28 @@ def __init__(
7682
):
7783
super().__init__()
7884
self.tokenizer = tokenizer
79-
self.load_audio = AudioSamples(fault_tolerant=True)
85+
self.use_ais_get_batch = os.environ.get("USE_AIS_GET_BATCH", "False").lower() == "true"
86+
87+
# Try to use use_batch_loader if available (Lhotse >= 1.32.0)
88+
try:
89+
self.load_audio = AudioSamples(fault_tolerant=True, use_batch_loader=self.use_ais_get_batch)
90+
except TypeError:
91+
# Lhotse < 1.32.0 doesn't support use_batch_loader
92+
if self.use_ais_get_batch:
93+
import logging
94+
95+
logging.warning(
96+
"AIS batch loading requested but not supported by this Lhotse version. "
97+
"Please upgrade to Lhotse >= 1.32.0"
98+
)
99+
self.load_audio = AudioSamples(fault_tolerant=True)
100+
80101
self.padding_value = self.tokenizer.pad_id
81102
self.prompt = prompt
82103
self.enable_chunking = enable_chunking
83104

84105
def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch:
106+
# Load the audio's from AIS and add them to the CutSet
85107
audio, audio_lens, cuts = self.load_audio(cuts)
86108

87109
# Will work if batch_size is set to 1.

nemo/collections/common/data/lhotse/nemo_adapters.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
1516
import random
1617
import re
1718
import tarfile
@@ -331,6 +332,7 @@ def __init__(
331332
self.slice_length = slice_length
332333
self.epoch = 0
333334
self._validate()
335+
self.use_ais_get_batch = os.environ.get("USE_AIS_GET_BATCH", "False").lower() == "true"
334336

335337
def to_shards(self) -> List["LazyNeMoTarredIterator"]:
336338
"""Convert this iterator to a list of separate iterators for each shard."""
@@ -370,6 +372,93 @@ def _get_seed(self) -> int:
370372
def shard_ids(self) -> List[int]:
371373
return sorted(self.shard_id_to_manifest.keys())
372374

375+
def _iter_batch_for_ais_get_batch(
376+
self, tar_path, shard_manifest, manifest_path, rng, extra_fields
377+
) -> Generator[Cut, None, None]:
378+
"""
379+
Iterator for batch reading mode (AIS get batch).
380+
Yields cuts with URL-based recordings without opening tar files.
381+
"""
382+
# Calculate slice offset for random skipping
383+
total_entries = sum(len(entries) for entries in shard_manifest.values())
384+
slice_offset = (
385+
rng.randint(0, total_entries - self.slice_length)
386+
if self.slice_length is not None and self.slice_length < total_entries
387+
else -1
388+
)
389+
cntr = 0
390+
entries_processed = 0
391+
392+
for audio_filename, manifest_entries in shard_manifest.items():
393+
for data in manifest_entries:
394+
# Skip entries if we haven't reached the slice offset yet
395+
if entries_processed < slice_offset:
396+
entries_processed += 1
397+
continue
398+
# Stop if we've reached the slice length limit
399+
elif cntr == self.slice_length:
400+
break
401+
402+
# filter out entries with valid "_skipme" values.
403+
if data.get("_skipme", False):
404+
entries_processed += 1
405+
continue
406+
407+
# Construct URL: tar_path/audio_filename
408+
audio_url = f"{tar_path.rstrip('/')}/{audio_filename.lstrip('/')}"
409+
410+
# Get metadata from manifest
411+
duration = data.get("duration")
412+
if duration is None:
413+
logging.warning(f"Skipping '{audio_filename}' - missing duration in manifest")
414+
entries_processed += 1
415+
continue
416+
417+
offset = data.get("offset", 0.0)
418+
sampling_rate = data.get("sampling_rate", 16000) # default to 16kHz if not specified
419+
420+
# Create URL-based recording
421+
recording = Recording(
422+
id=audio_filename,
423+
sources=[AudioSource(type="url", channels=[0], source=audio_url)],
424+
sampling_rate=sampling_rate,
425+
num_samples=compute_num_samples(duration, sampling_rate),
426+
duration=duration,
427+
)
428+
429+
# Create cut from recording (audio will be loaded lazily from URL when needed)
430+
cut = recording.to_cut()
431+
if offset > 0:
432+
cut = cut.truncate(offset=offset, duration=duration, preserve_id=True)
433+
cut.id = f"{cut.id}-{round(offset * 1e2):06d}-{round(duration * 1e2):06d}"
434+
435+
# Add supervision (transcript metadata)
436+
cut.supervisions.append(
437+
SupervisionSegment(
438+
id=cut.id,
439+
recording_id=cut.recording_id,
440+
start=0,
441+
duration=cut.duration,
442+
text=data.get(self.text_field),
443+
language=data.get(self.lang_field),
444+
)
445+
)
446+
447+
# Attach custom fields and metadata
448+
cut.custom = _to_custom_attr_dict(data)
449+
cut.manifest_origin = manifest_path
450+
cut.tar_origin = tar_path
451+
for extra_field in extra_fields:
452+
extra_field.attach_to(cut)
453+
454+
cntr += 1
455+
entries_processed += 1
456+
yield cut
457+
458+
# Break outer loop if we've reached the slice length limit
459+
if cntr == self.slice_length:
460+
break
461+
373462
def _iter_sequential(
374463
self, tar_path, shard_manifest, manifest_path, rng
375464
) -> Generator[tuple[dict, bytes], None, None]:
@@ -426,6 +515,13 @@ def basename(d: dict) -> str:
426515

427516
shard_manifest: dict[str, list[dict]] = groupby(basename, self.shard_id_to_manifest[sid])
428517
tar_path = self.shard_id_to_tar_path[sid]
518+
519+
if self.use_ais_get_batch:
520+
# Use batch reading mode - URL-based recordings without opening tar files
521+
yield from self._iter_batch_for_ais_get_batch(
522+
tar_path, shard_manifest, manifest_path, rng, extra_fields
523+
)
524+
continue
429525
try:
430526
for data, raw_audio, tar_info in self._iter_sequential(tar_path, shard_manifest, manifest_path, rng):
431527
try:

requirements/requirements_asr.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
braceexpand
22
ctc_segmentation==1.7.4
3+
diskcache
34
editdistance
45
einops
56
jiwer>=3.1.0,<4.0.0
67
kaldi-python-io
7-
lhotse>=1.31.1
8+
kaldialign<=0.9.1
9+
lhotse>=1.32.0
810
# Align with upstream PyTorch requirements
911
librosa>=0.10.1
1012
marshmallow
@@ -19,6 +21,4 @@ ruamel.yaml
1921
scipy>=0.14
2022
soundfile
2123
sox<=1.5.0
22-
kaldialign<=0.9.1
2324
whisper_normalizer
24-
diskcache

0 commit comments

Comments
 (0)