1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import os
1516import random
1617import re
1718import tarfile
@@ -331,6 +332,7 @@ def __init__(
331332 self .slice_length = slice_length
332333 self .epoch = 0
333334 self ._validate ()
335+ self .use_ais_get_batch = os .environ .get ("USE_AIS_GET_BATCH" , "False" ).lower () == "true"
334336
335337 def to_shards (self ) -> List ["LazyNeMoTarredIterator" ]:
336338 """Convert this iterator to a list of separate iterators for each shard."""
@@ -426,29 +428,62 @@ def basename(d: dict) -> str:
426428
427429 shard_manifest : dict [str , list [dict ]] = groupby (basename , self .shard_id_to_manifest [sid ])
428430 tar_path = self .shard_id_to_tar_path [sid ]
429- try :
430- for data , raw_audio , tar_info in self ._iter_sequential (tar_path , shard_manifest , manifest_path , rng ):
431- try :
432- meta = soundfile .info (BytesIO (raw_audio ))
433- except Exception :
434- logging .warning (f"Skipped corrupted file '{ tar_info .path } ' in { tar_path = } ." )
435- continue
436- recording = Recording (
437- id = tar_info .path ,
438- sources = [AudioSource (type = "memory" , channels = list (range (meta .channels )), source = raw_audio )],
439- sampling_rate = int (meta .samplerate ),
440- num_samples = meta .frames ,
441- duration = meta .duration ,
442- )
443- cuts_for_recording = []
444- for data in sorted (shard_manifest [tar_info .name ], key = lambda d : d ["audio_filepath" ]):
431+ if self .use_ais_get_batch :
432+ # Don't open tar file, just prepare URL-based recordings
433+ # Calculate slice offset for random skipping
434+ total_entries = sum (len (entries ) for entries in shard_manifest .values ())
435+ slice_offset = (
436+ rng .randint (0 , total_entries - self .slice_length )
437+ if self .slice_length is not None and self .slice_length < total_entries
438+ else - 1
439+ )
440+ cntr = 0
441+ entries_processed = 0
442+
443+ for audio_filename , manifest_entries in shard_manifest .items ():
444+ for data in manifest_entries :
445+ # Skip entries if we haven't reached the slice offset yet
446+ if entries_processed < slice_offset :
447+ entries_processed += 1
448+ continue
449+ # Stop if we've reached the slice length limit
450+ elif cntr == self .slice_length :
451+ break
452+
445453 # filter out entries with valid "_skipme" values.
446454 if data .get ("_skipme" , False ):
455+ entries_processed += 1
447456 continue
448- # Cut the recording into corresponding segment and discard audio data outside the segment.
449- cut = make_cut_with_subset_inmemory_recording (
450- recording , offset = data .get ("offset" , 0.0 ), duration = data .get ("duration" )
457+
458+ # Construct URL: tar_path/audio_filename
459+ audio_url = f"{ tar_path .rstrip ('/' )} /{ audio_filename .lstrip ('/' )} "
460+
461+ # Get metadata from manifest
462+ duration = data .get ("duration" )
463+ if duration is None :
464+ logging .warning (f"Skipping '{ audio_filename } ' - missing duration in manifest" )
465+ entries_processed += 1
466+ continue
467+
468+ offset = data .get ("offset" , 0.0 )
469+ sampling_rate = data .get ("sampling_rate" , 16000 ) # default to 16kHz if not specified
470+
471+ # Create URL-based recording
472+ recording = Recording (
473+ id = audio_filename ,
474+ sources = [AudioSource (type = "url" , channels = [0 ], source = audio_url )],
475+ sampling_rate = sampling_rate ,
476+ num_samples = compute_num_samples (duration , sampling_rate ),
477+ duration = duration ,
451478 )
479+
480+ # Create cut from recording (audio will be loaded lazily from URL when needed)
481+ cut = recording .to_cut ()
482+ if offset > 0 :
483+ cut = cut .truncate (offset = offset , duration = duration , preserve_id = True )
484+ cut .id = f"{ cut .id } -{ round (offset * 1e2 ):06d} -{ round (duration * 1e2 ):06d} "
485+
486+ # Add supervision (transcript metadata)
452487 cut .supervisions .append (
453488 SupervisionSegment (
454489 id = cut .id ,
@@ -459,19 +494,72 @@ def basename(d: dict) -> str:
459494 language = data .get (self .lang_field ),
460495 )
461496 )
497+
498+ # Attach custom fields and metadata
462499 cut .custom = _to_custom_attr_dict (data )
463500 cut .manifest_origin = manifest_path
464501 cut .tar_origin = tar_path
465502 for extra_field in extra_fields :
466503 extra_field .attach_to (cut )
467- cuts_for_recording .append (cut )
468- del recording # free the memory - helps with very large audio files
469- del raw_audio
470- yield from cuts_for_recording
471- except tarfile .ReadError :
472- logging .warning (
473- f"Skipping tar file due to read errors (unstable storage or bad file?): { tar_path = } " ,
474- )
504+
505+ cntr += 1
506+ entries_processed += 1
507+ yield cut
508+
509+ # Break outer loop if we've reached the slice length limit
510+ if cntr == self .slice_length :
511+ break
512+ else :
513+ try :
514+ for data , raw_audio , tar_info in self ._iter_sequential (
515+ tar_path , shard_manifest , manifest_path , rng
516+ ):
517+ try :
518+ meta = soundfile .info (BytesIO (raw_audio ))
519+ except Exception :
520+ logging .warning (f"Skipped corrupted file '{ tar_info .path } ' in { tar_path = } ." )
521+ continue
522+ recording = Recording (
523+ id = tar_info .path ,
524+ sources = [
525+ AudioSource (type = "memory" , channels = list (range (meta .channels )), source = raw_audio )
526+ ],
527+ sampling_rate = int (meta .samplerate ),
528+ num_samples = meta .frames ,
529+ duration = meta .duration ,
530+ )
531+ cuts_for_recording = []
532+ for data in shard_manifest [tar_info .name ]:
533+ # filter out entries with valid "_skipme" values.
534+ if data .get ("_skipme" , False ):
535+ continue
536+ # Cut the recording into corresponding segment and discard audio data outside the segment.
537+ cut = make_cut_with_subset_inmemory_recording (
538+ recording , offset = data .get ("offset" , 0.0 ), duration = data .get ("duration" )
539+ )
540+ cut .supervisions .append (
541+ SupervisionSegment (
542+ id = cut .id ,
543+ recording_id = cut .recording_id ,
544+ start = 0 ,
545+ duration = cut .duration ,
546+ text = data .get (self .text_field ),
547+ language = data .get (self .lang_field ),
548+ )
549+ )
550+ cut .custom = _to_custom_attr_dict (data )
551+ cut .manifest_origin = manifest_path
552+ cut .tar_origin = tar_path
553+ for extra_field in extra_fields :
554+ extra_field .attach_to (cut )
555+ cuts_for_recording .append (cut )
556+ del recording # free the memory - helps with very large audio files
557+ del raw_audio
558+ yield from cuts_for_recording
559+ except tarfile .ReadError :
560+ logging .warning (
561+ f"Skipping tar file due to read errors (unstable storage or bad file?): { tar_path = } " ,
562+ )
475563
476564 self .epoch += 1
477565
0 commit comments