
Commit cbae6c3

Add support for AIS batch loading for ASR audio processing
* Introduced environment variable USE_AIS_GET_BATCH to toggle AIS batch loading.
* Updated PromptedAudioToTextLhotseDataset to utilize AISBatchLoader when enabled.
* Modified LazyNeMoTarredIterator to handle URL-based recordings when AIS batch loading is active.

Implement URL-based audio loading using AIStore's Get-Batch API to improve data pipeline efficiency. This allows batch fetching of multiple audio files without local tar archive extraction, offloading that processing to AIStore.

Signed-off-by: Abhishek Gaikwad <[email protected]>
1 parent e0f7ca9 commit cbae6c3
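
The feature is opt-in and controlled entirely through the environment. As a minimal sketch (assuming the dataset and iterator are constructed later in the same process), enabling it looks like this:

```python
import os

# Opt in to AIStore Get-Batch loading. Both PromptedAudioToTextLhotseDataset and
# LazyNeMoTarredIterator read this variable at construction time, so it must be
# set before they are instantiated (e.g. at the top of the training script).
os.environ["USE_AIS_GET_BATCH"] = "true"

# Any other value, or leaving the variable unset, keeps the default
# in-memory tar-extraction path. The check is case-insensitive.
enabled = os.environ.get("USE_AIS_GET_BATCH", "False").lower() == "true"
print(enabled)  # True
```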

File tree

4 files changed (+478, -31 lines)

nemo/collections/asr/data/audio_to_text_lhotse_prompted.py

Lines changed: 23 additions & 1 deletion
```diff
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from dataclasses import dataclass
 from typing import Optional, Union
 
```

```diff
@@ -66,6 +67,11 @@ class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset):
     If `enable_chunking` is True, each audio sample is split into optimally sized chunks
     (see `find_optimal_chunk_size` and `chunk_waveform`). This is useful for long audio inputs,
     allowing the model to process them in manageable segments.
+
+    NOTE:
+        If the environment variable `USE_AIS_GET_BATCH` is set to `true` (case-insensitive),
+        batch audio loading from AIStore is enabled for this dataset via ``AISBatchLoader``,
+        which can improve data loading efficiency in some setups.
     """
 
     def __init__(
```
```diff
@@ -76,12 +82,28 @@ def __init__(
     ):
         super().__init__()
         self.tokenizer = tokenizer
-        self.load_audio = AudioSamples(fault_tolerant=True)
+        self.use_ais_get_batch = os.environ.get("USE_AIS_GET_BATCH", "False").lower() == "true"
+
+        # Try to use use_batch_loader if available (Lhotse >= 1.32.0)
+        try:
+            self.load_audio = AudioSamples(fault_tolerant=True, use_batch_loader=self.use_ais_get_batch)
+        except TypeError:
+            # Lhotse < 1.32.0 doesn't support use_batch_loader
+            if self.use_ais_get_batch:
+                import logging
+
+                logging.warning(
+                    "AIS batch loading requested but not supported by this Lhotse version. "
+                    "Please upgrade to Lhotse >= 1.32.0"
+                )
+            self.load_audio = AudioSamples(fault_tolerant=True)
+
         self.padding_value = self.tokenizer.pad_id
         self.prompt = prompt
         self.enable_chunking = enable_chunking
 
     def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch:
+        # Load the audio samples from AIS and attach them to the CutSet
         audio, audio_lens, cuts = self.load_audio(cuts)
 
         # Will work if batch_size is set to 1.
```

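The try/except above degrades gracefully when the installed Lhotse predates `use_batch_loader`. As an alternative, hedged sketch (not part of this commit), the same decision could be made by checking the installed version explicitly against the 1.32.0 cutoff pinned in requirements_asr.txt:

```python
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def ais_batch_loading_supported(min_version: str = "1.32.0") -> bool:
    """Return True if the installed Lhotse is new enough to accept `use_batch_loader`."""
    try:
        return Version(version("lhotse")) >= Version(min_version)
    except PackageNotFoundError:
        # Lhotse is not installed at all.
        return False


if __name__ == "__main__":
    print(ais_batch_loading_supported())
```

The try/except in the commit has the advantage of probing the actual constructor signature rather than a version string, so it also covers backports and dev installs.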
nemo/collections/common/data/lhotse/nemo_adapters.py

Lines changed: 115 additions & 27 deletions
```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import random
 import re
 import tarfile
```
```diff
@@ -331,6 +332,7 @@ def __init__(
         self.slice_length = slice_length
         self.epoch = 0
         self._validate()
+        self.use_ais_get_batch = os.environ.get("USE_AIS_GET_BATCH", "False").lower() == "true"
 
     def to_shards(self) -> List["LazyNeMoTarredIterator"]:
         """Convert this iterator to a list of separate iterators for each shard."""
```
```diff
@@ -426,29 +428,62 @@ def basename(d: dict) -> str:
 
             shard_manifest: dict[str, list[dict]] = groupby(basename, self.shard_id_to_manifest[sid])
             tar_path = self.shard_id_to_tar_path[sid]
-            try:
-                for data, raw_audio, tar_info in self._iter_sequential(tar_path, shard_manifest, manifest_path, rng):
-                    try:
-                        meta = soundfile.info(BytesIO(raw_audio))
-                    except Exception:
-                        logging.warning(f"Skipped corrupted file '{tar_info.path}' in {tar_path=}.")
-                        continue
-                    recording = Recording(
-                        id=tar_info.path,
-                        sources=[AudioSource(type="memory", channels=list(range(meta.channels)), source=raw_audio)],
-                        sampling_rate=int(meta.samplerate),
-                        num_samples=meta.frames,
-                        duration=meta.duration,
-                    )
-                    cuts_for_recording = []
-                    for data in sorted(shard_manifest[tar_info.name], key=lambda d: d["audio_filepath"]):
+            if self.use_ais_get_batch:
+                # Don't open tar file, just prepare URL-based recordings
+                # Calculate slice offset for random skipping
+                total_entries = sum(len(entries) for entries in shard_manifest.values())
+                slice_offset = (
+                    rng.randint(0, total_entries - self.slice_length)
+                    if self.slice_length is not None and self.slice_length < total_entries
+                    else -1
+                )
+                cntr = 0
+                entries_processed = 0
+
+                for audio_filename, manifest_entries in shard_manifest.items():
+                    for data in manifest_entries:
+                        # Skip entries if we haven't reached the slice offset yet
+                        if entries_processed < slice_offset:
+                            entries_processed += 1
+                            continue
+                        # Stop if we've reached the slice length limit
+                        elif cntr == self.slice_length:
+                            break
+
                         # filter out entries with valid "_skipme" values.
                         if data.get("_skipme", False):
+                            entries_processed += 1
                             continue
-                        # Cut the recording into corresponding segment and discard audio data outside the segment.
-                        cut = make_cut_with_subset_inmemory_recording(
-                            recording, offset=data.get("offset", 0.0), duration=data.get("duration")
+
+                        # Construct URL: tar_path/audio_filename
+                        audio_url = f"{tar_path.rstrip('/')}/{audio_filename.lstrip('/')}"
+
+                        # Get metadata from manifest
+                        duration = data.get("duration")
+                        if duration is None:
+                            logging.warning(f"Skipping '{audio_filename}' - missing duration in manifest")
+                            entries_processed += 1
+                            continue
+
+                        offset = data.get("offset", 0.0)
+                        sampling_rate = data.get("sampling_rate", 16000)  # default to 16 kHz if not specified
+
+                        # Create URL-based recording
+                        recording = Recording(
+                            id=audio_filename,
+                            sources=[AudioSource(type="url", channels=[0], source=audio_url)],
+                            sampling_rate=sampling_rate,
+                            num_samples=compute_num_samples(duration, sampling_rate),
+                            duration=duration,
                         )
+
+                        # Create cut from recording (audio will be loaded lazily from URL when needed)
+                        cut = recording.to_cut()
+                        if offset > 0:
+                            cut = cut.truncate(offset=offset, duration=duration, preserve_id=True)
+                            cut.id = f"{cut.id}-{round(offset * 1e2):06d}-{round(duration * 1e2):06d}"
+
+                        # Add supervision (transcript metadata)
                         cut.supervisions.append(
                             SupervisionSegment(
                                 id=cut.id,
```
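
The counters above (`slice_offset`, `cntr`, `entries_processed`) implement a random contiguous window over a shard's manifest entries: skip a random prefix, then yield at most `slice_length` entries. A standalone sketch of the same windowing logic, with a hypothetical `entries` list standing in for the shard manifest:

```python
import random
from typing import Dict, Iterator, List, Optional


def iter_slice(entries: List[Dict], slice_length: Optional[int], rng: random.Random) -> Iterator[Dict]:
    """Yield a random contiguous window of at most `slice_length` entries (all of them if None)."""
    total = len(entries)
    # -1 disables skipping whenever the whole shard already fits inside the slice.
    offset = rng.randint(0, total - slice_length) if slice_length is not None and slice_length < total else -1
    yielded = 0
    for i, entry in enumerate(entries):
        if i < offset:
            continue  # still before the randomly chosen window
        if yielded == slice_length:
            break  # window exhausted
        yield entry
        yielded += 1


if __name__ == "__main__":
    rng = random.Random(0)
    manifest = [{"idx": i} for i in range(10)]
    print([d["idx"] for d in iter_slice(manifest, 3, rng)])  # three consecutive indices, e.g. [4, 5, 6]
```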
```diff
@@ -459,19 +494,72 @@ def basename(d: dict) -> str:
                                 language=data.get(self.lang_field),
                             )
                         )
+
+                        # Attach custom fields and metadata
                         cut.custom = _to_custom_attr_dict(data)
                         cut.manifest_origin = manifest_path
                         cut.tar_origin = tar_path
                         for extra_field in extra_fields:
                             extra_field.attach_to(cut)
-                        cuts_for_recording.append(cut)
-                    del recording  # free the memory - helps with very large audio files
-                    del raw_audio
-                    yield from cuts_for_recording
-            except tarfile.ReadError:
-                logging.warning(
-                    f"Skipping tar file due to read errors (unstable storage or bad file?): {tar_path=}",
-                )
+
+                        cntr += 1
+                        entries_processed += 1
+                        yield cut
+
+                    # Break outer loop if we've reached the slice length limit
+                    if cntr == self.slice_length:
+                        break
+            else:
+                try:
+                    for data, raw_audio, tar_info in self._iter_sequential(
+                        tar_path, shard_manifest, manifest_path, rng
+                    ):
+                        try:
+                            meta = soundfile.info(BytesIO(raw_audio))
+                        except Exception:
+                            logging.warning(f"Skipped corrupted file '{tar_info.path}' in {tar_path=}.")
+                            continue
+                        recording = Recording(
+                            id=tar_info.path,
+                            sources=[
+                                AudioSource(type="memory", channels=list(range(meta.channels)), source=raw_audio)
+                            ],
+                            sampling_rate=int(meta.samplerate),
+                            num_samples=meta.frames,
+                            duration=meta.duration,
+                        )
+                        cuts_for_recording = []
+                        for data in shard_manifest[tar_info.name]:
+                            # filter out entries with valid "_skipme" values.
+                            if data.get("_skipme", False):
+                                continue
+                            # Cut the recording into corresponding segment and discard audio data outside the segment.
+                            cut = make_cut_with_subset_inmemory_recording(
+                                recording, offset=data.get("offset", 0.0), duration=data.get("duration")
+                            )
+                            cut.supervisions.append(
+                                SupervisionSegment(
+                                    id=cut.id,
+                                    recording_id=cut.recording_id,
+                                    start=0,
+                                    duration=cut.duration,
+                                    text=data.get(self.text_field),
+                                    language=data.get(self.lang_field),
+                                )
+                            )
+                            cut.custom = _to_custom_attr_dict(data)
+                            cut.manifest_origin = manifest_path
+                            cut.tar_origin = tar_path
+                            for extra_field in extra_fields:
+                                extra_field.attach_to(cut)
+                            cuts_for_recording.append(cut)
+                        del recording  # free the memory - helps with very large audio files
+                        del raw_audio
+                        yield from cuts_for_recording
+                except tarfile.ReadError:
+                    logging.warning(
+                        f"Skipping tar file due to read errors (unstable storage or bad file?): {tar_path=}",
+                    )
 
         self.epoch += 1
 
```

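For reference, the URL-based branch boils down to building a Lhotse `Recording` whose `AudioSource` points at a remote object rather than at in-memory bytes, so no audio is fetched until the cut is actually loaded. A minimal sketch using Lhotse's public API; the URL and manifest values below are placeholders, not taken from the commit:

```python
from lhotse import Recording, SupervisionSegment
from lhotse.audio import AudioSource
from lhotse.utils import compute_num_samples

# Placeholder manifest entry: in the diff these values come from the NeMo tarred
# manifest (duration is required; sampling_rate falls back to 16 kHz).
audio_url = "https://aistore.example.com/bucket/shard_000000.tar/sample_0001.wav"
duration, sampling_rate, offset = 4.2, 16000, 0.0

# Constructing the Recording performs no I/O; the source is just a URL reference.
recording = Recording(
    id="sample_0001.wav",
    sources=[AudioSource(type="url", channels=[0], source=audio_url)],
    sampling_rate=sampling_rate,
    num_samples=compute_num_samples(duration, sampling_rate),
    duration=duration,
)

# The cut stays lazy as well: bytes are fetched only when its audio is requested,
# e.g. by cut.load_audio() or by the dataset's AudioSamples collator.
cut = recording.to_cut()
if offset > 0:
    cut = cut.truncate(offset=offset, duration=duration, preserve_id=True)

cut.supervisions.append(
    SupervisionSegment(
        id=cut.id, recording_id=cut.recording_id, start=0, duration=cut.duration, text="hello world"
    )
)
print(cut)
```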
requirements/requirements_asr.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
braceexpand
22
ctc_segmentation==1.7.4
3+
diskcache
34
editdistance
45
einops
56
jiwer>=3.1.0,<4.0.0
67
kaldi-python-io
7-
lhotse>=1.31.1
8+
kaldialign<=0.9.1
9+
lhotse>=1.32.0
810
# Align with upstream PyTorch requirements
911
librosa>=0.10.1
1012
marshmallow
@@ -19,6 +21,4 @@ ruamel.yaml
1921
scipy>=0.14
2022
soundfile
2123
sox<=1.5.0
22-
kaldialign<=0.9.1
2324
whisper_normalizer
24-
diskcache
