Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion nemo/collections/asr/data/audio_to_text_lhotse_prompted.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass
from typing import Optional, Union

Expand Down Expand Up @@ -66,6 +67,11 @@ class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset):
If `enable_chunking` is True, each audio sample is split into optimally sized chunks
(see `find_optimal_chunk_size` and `chunk_waveform`). This is useful for long audio inputs,
allowing the model to process them in manageable segments.

NOTE:
If the environment variable `USE_AIS_GET_BATCH` is set to `true` (case-insensitive),
batch audio loading from AIStore is enabled for this dataset via the AISBatchLoader.
This can improve data loading efficiency in some setups.
"""

def __init__(
Expand All @@ -76,12 +82,28 @@ def __init__(
):
super().__init__()
self.tokenizer = tokenizer
self.load_audio = AudioSamples(fault_tolerant=True)
self.use_ais_get_batch = os.environ.get("USE_AIS_GET_BATCH", "False").lower() == "true"

# Try to use use_batch_loader if available (Lhotse >= 1.32.0)
try:
self.load_audio = AudioSamples(fault_tolerant=True, use_batch_loader=self.use_ais_get_batch)
except TypeError:
# Lhotse < 1.32.0 doesn't support use_batch_loader
if self.use_ais_get_batch:
import logging

logging.warning(
"AIS batch loading requested but not supported by this Lhotse version. "
"Please upgrade to Lhotse >= 1.32.0"
)
self.load_audio = AudioSamples(fault_tolerant=True)

self.padding_value = self.tokenizer.pad_id
self.prompt = prompt
self.enable_chunking = enable_chunking

def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch:
        # Load the audio samples from AIS and add them to the CutSet
audio, audio_lens, cuts = self.load_audio(cuts)

# Will work if batch_size is set to 1.
Expand Down
96 changes: 96 additions & 0 deletions nemo/collections/common/data/lhotse/nemo_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
import re
import tarfile
Expand Down Expand Up @@ -331,6 +332,7 @@ def __init__(
self.slice_length = slice_length
self.epoch = 0
self._validate()
self.use_ais_get_batch = os.environ.get("USE_AIS_GET_BATCH", "False").lower() == "true"

def to_shards(self) -> List["LazyNeMoTarredIterator"]:
"""Convert this iterator to a list of separate iterators for each shard."""
Expand Down Expand Up @@ -370,6 +372,93 @@ def _get_seed(self) -> int:
def shard_ids(self) -> List[int]:
return sorted(self.shard_id_to_manifest.keys())

def _iter_batch_for_ais_get_batch(
    self, tar_path, shard_manifest, manifest_path, rng, extra_fields
) -> Generator[Cut, None, None]:
    """
    Iterator for batch reading mode (AIS get batch).

    Yields cuts whose recordings are URL-based (``tar_path/audio_filename``) so that
    audio can be fetched lazily later, without opening tar files here.

    Args:
        tar_path: Base path/URL of the tar shard; joined with each audio filename
            to form the recording source URL.
        shard_manifest: Mapping of audio filename -> list of manifest entry dicts
            for that file within this shard.
        manifest_path: Path of the originating manifest; attached to each cut as
            ``manifest_origin``.
        rng: Random generator used to pick a random starting offset when
            ``self.slice_length`` is set.
        extra_fields: Objects with an ``attach_to(cut)`` method used to add extra
            custom fields to each cut.

    Yields:
        Cut: one cut per manifest entry, with a supervision carrying the text/language
        fields and any custom manifest attributes attached.
    """
    # When slice_length is set and smaller than the shard, pick a random start
    # offset so each epoch yields a different contiguous slice of entries;
    # -1 disables skipping (entries_processed < -1 is never true).
    total_entries = sum(len(entries) for entries in shard_manifest.values())
    slice_offset = (
        rng.randint(0, total_entries - self.slice_length)
        if self.slice_length is not None and self.slice_length < total_entries
        else -1
    )
    cntr = 0  # number of cuts yielded so far (capped at slice_length)
    entries_processed = 0  # total entries seen, including skipped ones

    for audio_filename, manifest_entries in shard_manifest.items():
        for data in manifest_entries:
            # Skip entries if we haven't reached the slice offset yet
            if entries_processed < slice_offset:
                entries_processed += 1
                continue
            # Stop if we've reached the slice length limit
            elif cntr == self.slice_length:
                break

            # Skip entries explicitly flagged with a truthy "_skipme" value.
            if data.get("_skipme", False):
                entries_processed += 1
                continue

            # Construct URL: tar_path/audio_filename
            audio_url = f"{tar_path.rstrip('/')}/{audio_filename.lstrip('/')}"

            # Duration is mandatory: without it we cannot size the Recording,
            # so the entry is skipped with a warning rather than failing the epoch.
            duration = data.get("duration")
            if duration is None:
                logging.warning(f"Skipping '{audio_filename}' - missing duration in manifest")
                entries_processed += 1
                continue

            offset = data.get("offset", 0.0)
            sampling_rate = data.get("sampling_rate", 16000)  # default to 16kHz if not specified

            # Create URL-based recording; num_samples is derived from the manifest
            # duration, not from probing the (unopened) audio file.
            recording = Recording(
                id=audio_filename,
                sources=[AudioSource(type="url", channels=[0], source=audio_url)],
                sampling_rate=sampling_rate,
                num_samples=compute_num_samples(duration, sampling_rate),
                duration=duration,
            )

            # Create cut from recording (audio will be loaded lazily from URL when needed)
            cut = recording.to_cut()
            if offset > 0:
                # NOTE(review): the Recording above spans only `duration` seconds, yet
                # truncate() requests [offset, offset + duration] — presumably the
                # manifest duration covers the segment end; confirm this does not
                # exceed the recording extent for offset > 0 entries.
                cut = cut.truncate(offset=offset, duration=duration, preserve_id=True)
                # Make the id unique per segment: suffix with offset/duration in centiseconds.
                cut.id = f"{cut.id}-{round(offset * 1e2):06d}-{round(duration * 1e2):06d}"

            # Add supervision (transcript metadata) covering the whole cut.
            cut.supervisions.append(
                SupervisionSegment(
                    id=cut.id,
                    recording_id=cut.recording_id,
                    start=0,
                    duration=cut.duration,
                    text=data.get(self.text_field),
                    language=data.get(self.lang_field),
                )
            )

            # Attach remaining manifest fields as custom attributes, plus provenance
            # (manifest/tar origin) and any caller-supplied extra fields.
            cut.custom = _to_custom_attr_dict(data)
            cut.manifest_origin = manifest_path
            cut.tar_origin = tar_path
            for extra_field in extra_fields:
                extra_field.attach_to(cut)

            cntr += 1
            entries_processed += 1
            yield cut

        # Break outer loop if we've reached the slice length limit
        if cntr == self.slice_length:
            break

def _iter_sequential(
self, tar_path, shard_manifest, manifest_path, rng
) -> Generator[tuple[dict, bytes], None, None]:
Expand Down Expand Up @@ -426,6 +515,13 @@ def basename(d: dict) -> str:

shard_manifest: dict[str, list[dict]] = groupby(basename, self.shard_id_to_manifest[sid])
tar_path = self.shard_id_to_tar_path[sid]

if self.use_ais_get_batch:
# Use batch reading mode - URL-based recordings without opening tar files
yield from self._iter_batch_for_ais_get_batch(
tar_path, shard_manifest, manifest_path, rng, extra_fields
)
continue
try:
for data, raw_audio, tar_info in self._iter_sequential(tar_path, shard_manifest, manifest_path, rng):
try:
Expand Down
11 changes: 8 additions & 3 deletions nemo/collections/speechlm2/data/s2s_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,14 @@ def __init__(
def __getitem__(self, cuts: CutSet) -> dict:
cuts = cuts.transform_text(_strip_timestamps)
source_audio, source_audio_lens = collate_audio(cuts.resample(self.source_sample_rate))
target_audio, target_audio_lens = collate_audio(
cuts.resample(self.target_sample_rate), recording_field="target_audio"
)
# Manually resample target_audio attribute since cuts.resample() only affects the main recording
cuts_with_resampled_target = []
for cut in cuts:
if hasattr(cut, "target_audio") and cut.target_audio is not None:
cut.target_audio = cut.target_audio.resample(self.target_sample_rate)
cuts_with_resampled_target.append(cut)
cuts_with_resampled_target = CutSet(cuts_with_resampled_target)
target_audio, target_audio_lens = collate_audio(cuts_with_resampled_target, recording_field="target_audio")
target_tokens, target_token_lens = collate_token_channel(
cuts, self.tokenizer, self.frame_length, roles=self.output_roles
)
Expand Down
6 changes: 3 additions & 3 deletions requirements/requirements_asr.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
braceexpand
ctc_segmentation==1.7.4
diskcache
editdistance
einops
jiwer>=3.1.0,<4.0.0
kaldi-python-io
lhotse>=1.31.1
kaldialign<=0.9.1
lhotse>=1.32.0
# Align with upstream PyTorch requirements
librosa>=0.10.1
marshmallow
Expand All @@ -19,6 +21,4 @@ ruamel.yaml
scipy>=0.14
soundfile
sox<=1.5.0
kaldialign<=0.9.1
whisper_normalizer
diskcache
Loading
Loading