1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import os
1516import random
1617import re
1718import tarfile
@@ -331,6 +332,7 @@ def __init__(
331332 self .slice_length = slice_length
332333 self .epoch = 0
333334 self ._validate ()
335+ self .use_ais_get_batch = os .environ .get ("USE_AIS_GET_BATCH" , "False" ).lower () == "true"
334336
335337 def to_shards (self ) -> List ["LazyNeMoTarredIterator" ]:
336338 """Convert this iterator to a list of separate iterators for each shard."""
@@ -370,6 +372,93 @@ def _get_seed(self) -> int:
370372 def shard_ids (self ) -> List [int ]:
371373 return sorted (self .shard_id_to_manifest .keys ())
372374
def _iter_batch_for_ais_get_batch(
    self, tar_path, shard_manifest, manifest_path, rng, extra_fields
) -> Generator[Cut, None, None]:
    """
    Iterator for batch reading mode (AIS get batch).

    Builds cuts with URL-based recordings straight from the manifest entries;
    the tar file at ``tar_path`` is never opened by this method.

    Args:
        tar_path: Location of the shard's tar file. Used as the URL prefix of every
            audio file and stored on each cut as ``tar_origin``.
        shard_manifest: Mapping of audio filename to the list of manifest entries
            (dicts) that refer to that file.
        manifest_path: Stored on each cut as ``manifest_origin``.
        rng: Object exposing ``randint(a, b)``; only consulted when ``self.slice_length``
            is set and smaller than the number of entries in the shard.
        extra_fields: Iterable of objects exposing ``attach_to(cut)``, applied to every cut.

    Yields:
        One cut per manifest entry that is not flagged ``_skipme`` and has a ``duration``,
        limited to ``self.slice_length`` cuts when slicing is active.
    """
    # Pick a random starting position for the slice (or -1 to disable skipping).
    total_entries = sum(len(entries) for entries in shard_manifest.values())
    slice_offset = (
        rng.randint(0, total_entries - self.slice_length)
        if self.slice_length is not None and self.slice_length < total_entries
        else -1
    )

    # Flatten {filename: [entries]} into a single stream so that the position of an
    # entry in the shard is simply its enumeration index.
    flat_entries = (
        (audio_filename, data)
        for audio_filename, manifest_entries in shard_manifest.items()
        for data in manifest_entries
    )

    num_yielded = 0
    for position, (audio_filename, data) in enumerate(flat_entries):
        # Skip entries located before the slice offset.
        if position < slice_offset:
            continue
        # Stop once the slice length limit is reached (never true when slice_length is None).
        if num_yielded == self.slice_length:
            return

        # filter out entries with valid "_skipme" values.
        if data.get("_skipme", False):
            continue

        # Get metadata from manifest
        duration = data.get("duration")
        if duration is None:
            logging.warning(f"Skipping '{audio_filename}' - missing duration in manifest")
            continue

        # A manifest may carry an explicit null offset; treat it like a missing one
        # instead of failing on the `offset > 0` comparison below.
        offset = data.get("offset") or 0.0
        sampling_rate = data.get("sampling_rate", 16000)  # default to 16kHz if not specified

        # Construct URL: tar_path/audio_filename
        audio_url = f"{tar_path.rstrip('/')}/{audio_filename.lstrip('/')}"

        # The true length of the audio file is unknown without opening it. The recording
        # must nevertheless span the whole segment [offset, offset + duration], otherwise
        # the truncated cut below would end past the end of its own recording.
        recording_duration = offset + duration if offset > 0 else duration

        # Create URL-based recording
        recording = Recording(
            id=audio_filename,
            sources=[AudioSource(type="url", channels=[0], source=audio_url)],
            sampling_rate=sampling_rate,
            num_samples=compute_num_samples(recording_duration, sampling_rate),
            duration=recording_duration,
        )

        # Create cut from recording (audio will be loaded lazily from URL when needed)
        cut = recording.to_cut()
        if offset > 0:
            cut = cut.truncate(offset=offset, duration=duration, preserve_id=True)
            # Several segments may come from the same file; make their IDs distinct.
            cut.id = f"{cut.id}-{round(offset * 1e2):06d}-{round(duration * 1e2):06d}"

        # Add supervision (transcript metadata)
        cut.supervisions.append(
            SupervisionSegment(
                id=cut.id,
                recording_id=cut.recording_id,
                start=0,
                duration=cut.duration,
                text=data.get(self.text_field),
                language=data.get(self.lang_field),
            )
        )

        # Attach custom fields and metadata
        cut.custom = _to_custom_attr_dict(data)
        cut.manifest_origin = manifest_path
        cut.tar_origin = tar_path
        for extra_field in extra_fields:
            extra_field.attach_to(cut)

        num_yielded += 1
        yield cut
461+
373462 def _iter_sequential (
374463 self , tar_path , shard_manifest , manifest_path , rng
375464 ) -> Generator [tuple [dict , bytes ], None , None ]:
@@ -426,6 +515,13 @@ def basename(d: dict) -> str:
426515
427516 shard_manifest : dict [str , list [dict ]] = groupby (basename , self .shard_id_to_manifest [sid ])
428517 tar_path = self .shard_id_to_tar_path [sid ]
518+
519+ if self .use_ais_get_batch :
520+ # Use batch reading mode - URL-based recordings without opening tar files
521+ yield from self ._iter_batch_for_ais_get_batch (
522+ tar_path , shard_manifest , manifest_path , rng , extra_fields
523+ )
524+ continue
429525 try :
430526 for data , raw_audio , tar_info in self ._iter_sequential (tar_path , shard_manifest , manifest_path , rng ):
431527 try :
@@ -441,7 +537,7 @@ def basename(d: dict) -> str:
441537 duration = meta .duration ,
442538 )
443539 cuts_for_recording = []
444- for data in sorted ( shard_manifest [tar_info .name ], key = lambda d : d [ "audio_filepath" ]) :
540+ for data in shard_manifest [tar_info .name ]:
445541 # filter out entries with valid "_skipme" values.
446542 if data .get ("_skipme" , False ):
447543 continue
0 commit comments