diff --git a/README.md b/README.md index 7e80ba79..5e60e69a 100644 --- a/README.md +++ b/README.md @@ -11,9 +11,6 @@ This is the simplest configuration for developers to start with. and follow the prompts to create your own user 3. Run `docker compose run --rm django ./manage.py loaddata species` to load species data into the database -4. Run `docker compose run --rm django ./manage.py makeclient \ - --username your.super.user@email.address \ - --uri http://localhost:3000/` ### Run Vue Frontend @@ -38,7 +35,7 @@ To non-destructively update your development stack at any time: ## Dev Tool Endpoints -1. Main Site Interface [http://localhost:3000/](http://localhost:3000/) +1. Main Site Interface [http://localhost:8080/](http://localhost:8080/) 2. Site Administration [http://localhost:8000/admin/](http://localhost:8000/admin/) 3. Swagger API (These are default swagger endpoints using Django-REST) [http://localhost:8000/api/docs/swagger/](http://localhost:8000/api/docs/swagger/) 4. Django Ninja API [http://localhost:8000/api/v1/docs#/](http://localhost:8000/api/v1/docs#/) diff --git a/bats_ai/core/management/commands/generateNABat.py b/bats_ai/core/management/commands/generateNABat.py index d8d6dc81..49aceec4 100644 --- a/bats_ai/core/management/commands/generateNABat.py +++ b/bats_ai/core/management/commands/generateNABat.py @@ -12,10 +12,10 @@ from bats_ai.core.models import Species from bats_ai.core.models.nabat import NABatRecording, NABatRecordingAnnotation +from bats_ai.core.utils.batbot_metadata import generate_spectrogram_assets from bats_ai.utils.spectrogram_utils import ( generate_nabat_compressed_spectrogram, generate_nabat_spectrogram, - generate_spectrogram_assets, ) fake = Faker() diff --git a/bats_ai/core/management/commands/importRecordings.py b/bats_ai/core/management/commands/importRecordings.py index 43211804..1b34e942 100644 --- a/bats_ai/core/management/commands/importRecordings.py +++ b/bats_ai/core/management/commands/importRecordings.py @@ -40,10 +40,6 @@ def add_arguments(self, parser): ) def handle(self, *args, **options): - import matplotlib - - matplotlib.use('Agg') - directory_path = Path(options['directory']) owner_username = options.get('owner') is_public = options.get('public', False) @@ -78,7 +74,11 @@ def handle(self, *args, **options): self.stdout.write(self.style.WARNING(f'Using default owner: {owner.username}')) # Find all WAV files - wav_files = list(directory_path.rglob('*.wav', case_sensitive=False)) + wav_files = list( + directory_path.rglob( + '*.wav', + ) + ) if not wav_files: self.stdout.write( diff --git a/bats_ai/core/management/commands/loadGRTS.py b/bats_ai/core/management/commands/loadGRTS.py index f6548eb2..667138dc 100644 --- a/bats_ai/core/management/commands/loadGRTS.py +++ b/bats_ai/core/management/commands/loadGRTS.py @@ -1,6 +1,7 @@ import logging import os import tempfile +import urllib from urllib.request import urlretrieve import zipfile @@ -18,27 +19,30 @@ 'https://www.sciencebase.gov/catalog/file/get/5b7753bde4b0f5d578820455?facet=conus_mastersample_10km_GRTS', # noqa: E501 14, 'CONUS', + # Backup URL + 'https://data.kitware.com/api/v1/item/697cc601e7dea9be44ec5aee/download', # noqa: E501 ), # CONUS - ( - 'https://www.sciencebase.gov/catalog/file/get/5b7753a8e4b0f5d578820452?facet=akcan_mastersample_10km_GRTS', # noqa: E501 - 20, - 'Alaska/Canada', - ), # Alaska/Canada - ( - 'https://www.sciencebase.gov/catalog/file/get/5b7753c2e4b0f5d578820457?facet=HI_mastersample_5km_GRTS', # noqa: E501 - 15, - 'Hawaii', - ), # Hawaii - ( - 
'https://www.sciencebase.gov/catalog/file/get/5b7753d3e4b0f5d578820459?facet=mex_mastersample_10km_GRTS', # noqa: E501 - 12, - 'Mexico', - ), # Mexico - ( - 'https://www.sciencebase.gov/catalog/file/get/5b7753d8e4b0f5d57882045b?facet=PR_mastersample_5km_GRTS', # noqa: E501 - 21, - 'Puerto Rico', - ), # Puerto Rico + # Removed other regions for now because of sciencebase.gov being down + # ( + # 'https://www.sciencebase.gov/catalog/file/get/5b7753a8e4b0f5d578820452?facet=akcan_mastersample_10km_GRTS', # noqa: E501 + # 20, + # 'Alaska/Canada', + # ), # Alaska/Canada + # ( + # 'https://www.sciencebase.gov/catalog/file/get/5b7753c2e4b0f5d578820457?facet=HI_mastersample_5km_GRTS', # noqa: E501 + # 15, + # 'Hawaii', + # ), # Hawaii + # ( + # 'https://www.sciencebase.gov/catalog/file/get/5b7753d3e4b0f5d578820459?facet=mex_mastersample_10km_GRTS', # noqa: E501 + # 12, + # 'Mexico', + # ), # Mexico + # ( + # 'https://www.sciencebase.gov/catalog/file/get/5b7753d8e4b0f5d57882045b?facet=PR_mastersample_5km_GRTS', # noqa: E501 + # 21, + # 'Puerto Rico', + # ), # Puerto Rico ] @@ -56,11 +60,26 @@ def handle(self, *args, **options): # Track existing IDs to avoid duplicates existing_ids = set(GRTSCells.objects.values_list('id', flat=True)) - for url, sample_frame_id, name in SHAPEFILES: + for url, sample_frame_id, name, backup_url in SHAPEFILES: logger.info(f'Downloading shapefile for Location {name}...') with tempfile.TemporaryDirectory() as tmpdir: zip_path = os.path.join(tmpdir, 'file.zip') - urlretrieve(url, zip_path) + try: + urlretrieve(url, zip_path) + except urllib.error.URLError as e: + logger.warning( + f'Failed to download from primary URL: {e}. \ + Attempting backup URL...' + ) + if backup_url is None: + logger.warning('No backup URL provided, skipping this shapefile.') + continue + try: + urlretrieve(backup_url, zip_path) + except urllib.error.URLError as e2: + raise CommandError( + f'Failed to download from backup URL as well: {e2}' + ) from e2 logger.info(f'Downloaded to {zip_path}') logger.info('Extracting zip file...') diff --git a/bats_ai/core/migrations/0027_alter_annotations_end_time_and_more.py b/bats_ai/core/migrations/0027_alter_annotations_end_time_and_more.py new file mode 100644 index 00000000..6be78597 --- /dev/null +++ b/bats_ai/core/migrations/0027_alter_annotations_end_time_and_more.py @@ -0,0 +1,143 @@ +# Generated by Django 4.2.23 on 2026-02-03 13:05 + +import django.contrib.postgres.fields +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ('core', '0026_merge'), + ] + + operations = [ + migrations.AlterField( + model_name='annotations', + name='end_time', + field=models.FloatField(blank=True, null=True), + ), + migrations.AlterField( + model_name='annotations', + name='high_freq', + field=models.FloatField(blank=True, null=True), + ), + migrations.AlterField( + model_name='annotations', + name='low_freq', + field=models.FloatField(blank=True, null=True), + ), + migrations.AlterField( + model_name='annotations', + name='start_time', + field=models.FloatField(blank=True, null=True), + ), + migrations.AlterField( + model_name='compressedspectrogram', + name='length', + field=models.FloatField(), + ), + migrations.AlterField( + model_name='compressedspectrogram', + name='starts', + field=django.contrib.postgres.fields.ArrayField( + base_field=django.contrib.postgres.fields.ArrayField( + base_field=models.FloatField(), size=None + ), + size=None, + ), + ), + migrations.AlterField( + model_name='compressedspectrogram', + 
name='stops', + field=django.contrib.postgres.fields.ArrayField( + base_field=django.contrib.postgres.fields.ArrayField( + base_field=models.FloatField(), size=None + ), + size=None, + ), + ), + migrations.AlterField( + model_name='compressedspectrogram', + name='widths', + field=django.contrib.postgres.fields.ArrayField( + base_field=django.contrib.postgres.fields.ArrayField( + base_field=models.FloatField(), size=None + ), + size=None, + ), + ), + migrations.AlterField( + model_name='nabatcompressedspectrogram', + name='length', + field=models.FloatField(), + ), + migrations.AlterField( + model_name='nabatcompressedspectrogram', + name='starts', + field=django.contrib.postgres.fields.ArrayField( + base_field=django.contrib.postgres.fields.ArrayField( + base_field=models.FloatField(), size=None + ), + size=None, + ), + ), + migrations.AlterField( + model_name='nabatcompressedspectrogram', + name='stops', + field=django.contrib.postgres.fields.ArrayField( + base_field=django.contrib.postgres.fields.ArrayField( + base_field=models.FloatField(), size=None + ), + size=None, + ), + ), + migrations.AlterField( + model_name='nabatcompressedspectrogram', + name='widths', + field=django.contrib.postgres.fields.ArrayField( + base_field=django.contrib.postgres.fields.ArrayField( + base_field=models.FloatField(), size=None + ), + size=None, + ), + ), + migrations.AlterField( + model_name='nabatspectrogram', + name='duration', + field=models.FloatField(), + ), + migrations.AlterField( + model_name='nabatspectrogram', + name='frequency_max', + field=models.FloatField(), + ), + migrations.AlterField( + model_name='nabatspectrogram', + name='frequency_min', + field=models.FloatField(), + ), + migrations.AlterField( + model_name='sequenceannotations', + name='end_time', + field=models.FloatField(blank=True, null=True), + ), + migrations.AlterField( + model_name='sequenceannotations', + name='start_time', + field=models.FloatField(blank=True, null=True), + ), + migrations.AlterField( + model_name='spectrogram', + name='duration', + field=models.FloatField(), + ), + migrations.AlterField( + model_name='spectrogram', + name='frequency_max', + field=models.FloatField(), + ), + migrations.AlterField( + model_name='spectrogram', + name='frequency_min', + field=models.FloatField(), + ), + ] diff --git a/bats_ai/core/models/annotations.py b/bats_ai/core/models/annotations.py index 6028ffb4..7feee8c7 100644 --- a/bats_ai/core/models/annotations.py +++ b/bats_ai/core/models/annotations.py @@ -10,10 +10,10 @@ class Annotations(TimeStampedModel, models.Model): recording = models.ForeignKey(Recording, on_delete=models.CASCADE) owner = models.ForeignKey(User, on_delete=models.CASCADE) - start_time = models.IntegerField(blank=True, null=True) - end_time = models.IntegerField(blank=True, null=True) - low_freq = models.IntegerField(blank=True, null=True) - high_freq = models.IntegerField(blank=True, null=True) + start_time = models.FloatField(blank=True, null=True) + end_time = models.FloatField(blank=True, null=True) + low_freq = models.FloatField(blank=True, null=True) + high_freq = models.FloatField(blank=True, null=True) type = models.TextField(blank=True, null=True) species = models.ManyToManyField(Species) comments = models.TextField(blank=True, null=True) diff --git a/bats_ai/core/models/compressed_spectrogram.py b/bats_ai/core/models/compressed_spectrogram.py index f4fa734a..089dd4e6 100644 --- a/bats_ai/core/models/compressed_spectrogram.py +++ b/bats_ai/core/models/compressed_spectrogram.py @@ -15,11 +15,11 @@ 
class CompressedSpectrogram(TimeStampedModel, models.Model): recording = models.ForeignKey(Recording, on_delete=models.CASCADE) spectrogram = models.ForeignKey(Spectrogram, on_delete=models.CASCADE) - length = models.IntegerField() + length = models.FloatField() images = GenericRelation(SpectrogramImage) - starts = ArrayField(ArrayField(models.IntegerField())) - stops = ArrayField(ArrayField(models.IntegerField())) - widths = ArrayField(ArrayField(models.IntegerField())) + starts = ArrayField(ArrayField(models.FloatField())) + stops = ArrayField(ArrayField(models.FloatField())) + widths = ArrayField(ArrayField(models.FloatField())) cache_invalidated = models.BooleanField(default=True) @property diff --git a/bats_ai/core/models/nabat/nabat_compressed_spectrogram.py b/bats_ai/core/models/nabat/nabat_compressed_spectrogram.py index 14bd9310..d6b584bd 100644 --- a/bats_ai/core/models/nabat/nabat_compressed_spectrogram.py +++ b/bats_ai/core/models/nabat/nabat_compressed_spectrogram.py @@ -17,10 +17,10 @@ class NABatCompressedSpectrogram(TimeStampedModel, models.Model): nabat_recording = models.ForeignKey(NABatRecording, on_delete=models.CASCADE) spectrogram = models.ForeignKey(NABatSpectrogram, on_delete=models.CASCADE) images = GenericRelation(SpectrogramImage) - length = models.IntegerField() - starts = ArrayField(ArrayField(models.IntegerField())) - stops = ArrayField(ArrayField(models.IntegerField())) - widths = ArrayField(ArrayField(models.IntegerField())) + length = models.FloatField() + starts = ArrayField(ArrayField(models.FloatField())) + stops = ArrayField(ArrayField(models.FloatField())) + widths = ArrayField(ArrayField(models.FloatField())) cache_invalidated = models.BooleanField(default=True) @property diff --git a/bats_ai/core/models/nabat/nabat_spectrogram.py b/bats_ai/core/models/nabat/nabat_spectrogram.py index 97c5e10c..ed9cf45c 100644 --- a/bats_ai/core/models/nabat/nabat_spectrogram.py +++ b/bats_ai/core/models/nabat/nabat_spectrogram.py @@ -20,9 +20,9 @@ class NABatSpectrogram(TimeStampedModel, models.Model): images = GenericRelation(SpectrogramImage) width = models.IntegerField() # pixels height = models.IntegerField() # pixels - duration = models.IntegerField() # milliseconds - frequency_min = models.IntegerField() # hz - frequency_max = models.IntegerField() # hz + duration = models.FloatField() # milliseconds + frequency_min = models.FloatField() # hz + frequency_max = models.FloatField() # hz @property def image_url_list(self): diff --git a/bats_ai/core/models/sequence_annotations.py b/bats_ai/core/models/sequence_annotations.py index 5326aeb0..1007c4fb 100644 --- a/bats_ai/core/models/sequence_annotations.py +++ b/bats_ai/core/models/sequence_annotations.py @@ -8,8 +8,8 @@ class SequenceAnnotations(models.Model): recording = models.ForeignKey(Recording, on_delete=models.CASCADE) owner = models.ForeignKey(User, on_delete=models.CASCADE) - start_time = models.IntegerField(blank=True, null=True) - end_time = models.IntegerField(blank=True, null=True) + start_time = models.FloatField(blank=True, null=True) + end_time = models.FloatField(blank=True, null=True) type = models.TextField(blank=True, null=True) comments = models.TextField(blank=True, null=True) species = models.ManyToManyField(Species) diff --git a/bats_ai/core/models/spectrogram.py b/bats_ai/core/models/spectrogram.py index 1d200878..d8491738 100644 --- a/bats_ai/core/models/spectrogram.py +++ b/bats_ai/core/models/spectrogram.py @@ -14,9 +14,9 @@ class Spectrogram(TimeStampedModel, models.Model): images = 
GenericRelation(SpectrogramImage) width = models.IntegerField() # pixels height = models.IntegerField() # pixels - duration = models.IntegerField() # milliseconds - frequency_min = models.IntegerField() # hz - frequency_max = models.IntegerField() # hz + duration = models.FloatField() # milliseconds + frequency_min = models.FloatField() # hz + frequency_max = models.FloatField() # hz @property def image_url_list(self): diff --git a/bats_ai/core/tasks/nabat/tasks.py b/bats_ai/core/tasks/nabat/tasks.py index 8c5fd6fd..932e8981 100644 --- a/bats_ai/core/tasks/nabat/tasks.py +++ b/bats_ai/core/tasks/nabat/tasks.py @@ -6,10 +6,10 @@ from bats_ai.core.models import Configuration, ProcessingTask, Species from bats_ai.core.models.nabat import NABatRecording, NABatRecordingAnnotation +from bats_ai.core.utils.batbot_metadata import generate_spectrogram_assets from bats_ai.utils.spectrogram_utils import ( generate_nabat_compressed_spectrogram, generate_nabat_spectrogram, - generate_spectrogram_assets, predict_from_compressed, ) @@ -60,7 +60,9 @@ def generate_spectrograms( try: config = Configuration.objects.first() - if config and config.run_inference_on_upload: + # TODO: Disabled until prediction is in batbot + # https://github.com/Kitware/batbot/issues/29 + if config and config.run_inference_on_upload and False: self.update_state( state='Progress', meta={'description': 'Running Prediction on Spectrogram'}, diff --git a/bats_ai/core/tasks/tasks.py b/bats_ai/core/tasks/tasks.py index ef0d5097..af7e7655 100644 --- a/bats_ai/core/tasks/tasks.py +++ b/bats_ai/core/tasks/tasks.py @@ -1,5 +1,6 @@ import logging import os +import shutil import tempfile from django.contrib.contenttypes.models import ContentType @@ -15,7 +16,8 @@ Spectrogram, SpectrogramImage, ) -from bats_ai.utils.spectrogram_utils import generate_spectrogram_assets, predict_from_compressed +from bats_ai.core.utils.batbot_metadata import generate_spectrogram_assets +from bats_ai.utils.spectrogram_utils import predict_from_compressed logging.basicConfig(level=logging.INFO) logger = logging.getLogger('NABatDataRetrieval') @@ -26,7 +28,15 @@ def recording_compute_spectrogram(recording_id: int): recording = Recording.objects.get(pk=recording_id) with tempfile.TemporaryDirectory() as tmpdir: - results = generate_spectrogram_assets(recording.audio_file, tmpdir) + # Copy the audio file from FileField to a temporary file + audio_filename = os.path.basename(recording.audio_file.name) + temp_audio_path = os.path.join(tmpdir, audio_filename) + + with recording.audio_file.open('rb') as source_file: + with open(temp_audio_path, 'wb') as dest_file: + shutil.copyfileobj(source_file, dest_file) + + results = generate_spectrogram_assets(temp_audio_path, output_folder=tmpdir) # Create or get Spectrogram spectrogram, _ = Spectrogram.objects.get_or_create( recording=recording, @@ -79,7 +89,9 @@ def recording_compute_spectrogram(recording_id: int): ) config = Configuration.objects.first() - if config and config.run_inference_on_upload: + # TODO: Disabled until prediction is in batbot + # https://github.com/Kitware/batbot/issues/29 + if config and config.run_inference_on_upload and False: predict_results = predict_from_compressed(compressed_obj) label = predict_results['label'] score = predict_results['score'] diff --git a/bats_ai/core/utils/batbot_metadata.py b/bats_ai/core/utils/batbot_metadata.py new file mode 100644 index 00000000..dab0c0bc --- /dev/null +++ b/bats_ai/core/utils/batbot_metadata.py @@ -0,0 +1,287 @@ +from contextlib import contextmanager 
+import json +import os +from pathlib import Path +from typing import Any, TypedDict + +import batbot +from pydantic import BaseModel, ConfigDict, Field, field_validator + + +class SpectrogramMetadata(BaseModel): + """Metadata about the spectrogram.""" + + uncompressed_path: list[str] = Field(alias='uncompressed.path') + compressed_path: list[str] = Field(alias='compressed.path') + + +class UncompressedSize(BaseModel): + """Uncompressed spectrogram dimensions.""" + + width_px: int = Field(alias='width.px') + height_px: int = Field(alias='height.px') + + +class CompressedSize(BaseModel): + """Compressed spectrogram dimensions.""" + + width_px: int = Field(alias='width.px') + height_px: int = Field(alias='height.px') + + +class SizeMetadata(BaseModel): + """Size metadata for spectrograms.""" + + uncompressed: UncompressedSize + compressed: CompressedSize + + +class FrequencyMetadata(BaseModel): + """Frequency range metadata.""" + + min_hz: int = Field(alias='min.hz') + max_hz: int = Field(alias='max.hz') + pixels_hz: list[int] = Field(alias='pixels.hz') + + +class SegmentCurvePoint(BaseModel): + """A single point in a segment curve.""" + + frequency_hz: int + time_ms: float + + +class Segment(BaseModel): + """A detected segment in the spectrogram.""" + + curve_hz_ms: list[list[float]] = Field(alias='curve.(hz,ms)') + start_ms: float = Field(alias='segment start.ms') + end_ms: float = Field(alias='segment end.ms') + duration_ms: float = Field(alias='segment duration.ms') + contour_start_ms: float = Field(alias='contour start.ms') + contour_end_ms: float = Field(alias='contour end.ms') + contour_duration_ms: float = Field(alias='contour duration.ms') + threshold_amp: int = Field(alias='threshold.amp') + peak_f_ms: float | None = Field(None, alias='peak f.ms') + fc_ms: float | None = Field(None, alias='fc.ms') + hi_fc_knee_ms: float | None = Field(None, alias='hi fc:knee.ms') + lo_fc_heel_ms: float | None = Field(None, alias='lo fc:heel.ms') + bandwidth_hz: int | None = Field(None, alias='bandwidth.hz') + hi_f_hz: int | None = Field(None, alias='hi f.hz') + lo_f_hz: int | None = Field(None, alias='lo f.hz') + peak_f_hz: int | None = Field(None, alias='peak f.hz') + fc_hz: int | None = Field(None, alias='fc.hz') + hi_fc_knee_hz: int | None = Field(None, alias='hi fc:knee.hz') + lo_fc_heel_hz: int | None = Field(None, alias='lo fc:heel.hz') + harmonic_flag: bool = Field(False, alias='harmonic.flag') + harmonic_peak_f_ms: float | None = Field(None, alias='harmonic peak f.ms') + harmonic_peak_f_hz: int | None = Field(None, alias='harmonic peak f.hz') + echo_flag: bool = Field(False, alias='echo.flag') + echo_peak_f_ms: float | None = Field(None, alias='echo peak f.ms') + echo_peak_f_hz: int | None = Field(None, alias='echo peak f.hz') + # Slope fields (optional, many variations) + slope_at_hi_fc_knee_khz_per_ms: float | None = Field(None, alias='slope@hi fc:knee.khz/ms') + slope_at_fc_khz_per_ms: float | None = Field(None, alias='slope@fc.khz/ms') + slope_at_low_fc_heel_khz_per_ms: float | None = Field(None, alias='slope@low fc:heel.khz/ms') + slope_at_peak_khz_per_ms: float | None = Field(None, alias='slope@peak.khz/ms') + slope_avg_khz_per_ms: float | None = Field(None, alias='slope[avg].khz/ms') + slope_hi_avg_khz_per_ms: float | None = Field(None, alias='slope/hi[avg].khz/ms') + slope_mid_avg_khz_per_ms: float | None = Field(None, alias='slope/mid[avg].khz/ms') + slope_lo_avg_khz_per_ms: float | None = Field(None, alias='slope/lo[avg].khz/ms') + slope_box_khz_per_ms: float | None = Field(None, 
alias='slope[box].khz/ms') + slope_hi_box_khz_per_ms: float | None = Field(None, alias='slope/hi[box].khz/ms') + slope_mid_box_khz_per_ms: float | None = Field(None, alias='slope/mid[box].khz/ms') + slope_lo_box_khz_per_ms: float | None = Field(None, alias='slope/lo[box].khz/ms') + + @field_validator('curve_hz_ms', mode='before') + @classmethod + def validate_curve(cls, v: Any) -> list[list[float]]: + """Ensure curve is a list of [frequency, time] pairs.""" + if isinstance(v, list): + return v + return [] + + +class BatbotMetadata(BaseModel): + """Complete BatBot metadata structure.""" + + model_config = ConfigDict(validate_by_name=True, validate_by_alias=True) + wav_path: str = Field(alias='wav.path') + spectrogram: SpectrogramMetadata + global_threshold_amp: int = Field(alias='global_threshold.amp') + sr_hz: int = Field(alias='sr.hz') + duration_ms: float = Field(alias='duration.ms') + frequencies: FrequencyMetadata + size: SizeMetadata + segments: list[Segment] + + +class SpectrogramData(BaseModel): + """Data structure for creating a Spectrogram model.""" + + width: int + height: int + duration: int # milliseconds + frequency_min: int # hz + frequency_max: int # hz + + +class CompressedSpectrogramData(BaseModel): + """Data structure for creating a CompressedSpectrogram model.""" + + starts: list[float] + stops: list[float] + widths: list[float] + + +def parse_batbot_metadata(file_path: str | Path) -> BatbotMetadata: + """Parse a BatBot metadata JSON file. + + Args: + file_path: Path to the metadata JSON file + + Returns: + Parsed BatbotMetadata object + """ + file_path = Path(file_path) + with open(file_path) as f: + data = json.load(f) + return BatbotMetadata(**data) + + +def convert_to_spectrogram_data(metadata: BatbotMetadata) -> SpectrogramData: + """Convert BatBot metadata to Spectrogram model data. + + Args: + metadata: Parsed BatBot metadata + + Returns: + SpectrogramData with fields for Spectrogram model + """ + return SpectrogramData( + width=metadata.size.uncompressed.width_px, + height=metadata.size.uncompressed.height_px, + duration=int(round(metadata.duration_ms)), + frequency_min=metadata.frequencies.min_hz, + frequency_max=metadata.frequencies.max_hz, + ) + + +def convert_to_compressed_spectrogram_data(metadata: BatbotMetadata) -> CompressedSpectrogramData: + """Convert BatBot metadata to CompressedSpectrogram model data. + + This function calculates starts, stops, and widths for each compressed image + based on the segments and the relationship between uncompressed and compressed widths. + + The compressed image is a concatenation of segments from the uncompressed image. 
+    - starts/stops: time values in milliseconds (matching the pattern in spectrogram_utils.py)
+    - widths: pixel widths in the compressed image (where segments are concatenated)
+
+    Args:
+        metadata: Parsed BatBot metadata
+
+    Returns:
+        CompressedSpectrogramData with fields for CompressedSpectrogram model
+    """
+    duration_ms = metadata.duration_ms
+
+    # Accumulate per-segment start/stop times and compressed pixel widths
+    starts_ms: list[float] = []
+    stops_ms: list[float] = []
+    widths_px_compressed: list[float] = []
+    segment_times: list[float] = []
+    compressed_width = metadata.size.compressed.width_px
+    total_time = 0.0
+
+    # If we have segments, use them to determine which parts are kept
+    if metadata.segments:
+        for segment in metadata.segments:
+            starts_ms.append(segment.start_ms)
+            stops_ms.append(segment.end_ms)
+            time = segment.end_ms - segment.start_ms
+            segment_times.append(time)
+            total_time += time
+        # Calculate width in compressed space
+        # The width in compressed space is proportional to the time duration
+        for time in segment_times:
+            width_px = (time / total_time) * compressed_width
+            widths_px_compressed.append(width_px)
+    else:
+        # No segments - the entire image is compressed
+        starts_ms = [0]
+        stops_ms = [duration_ms]
+        widths_px_compressed = [compressed_width]
+
+    return CompressedSpectrogramData(
+        starts=starts_ms,
+        stops=stops_ms,
+        widths=widths_px_compressed,
+    )
+
+
+class SpectrogramAssetResult(TypedDict):
+    paths: list[str]
+    width: int
+    height: int
+
+
+class SpectrogramCompressedAssetResult(TypedDict):
+    paths: list[str]
+    width: int
+    height: int
+    widths: list[float]
+    starts: list[float]
+    stops: list[float]
+
+
+class SpectrogramAssets(TypedDict):
+    duration: float
+    freq_min: int
+    freq_max: int
+    normal: SpectrogramAssetResult
+    compressed: SpectrogramCompressedAssetResult
+
+
+@contextmanager
+def working_directory(path):
+    previous = os.getcwd()
+    os.chdir(path)
+    try:
+        yield
+    finally:
+        os.chdir(previous)
+
+
+def generate_spectrogram_assets(recording_path: str, output_folder: str) -> SpectrogramAssets:
+    batbot.pipeline(recording_path, output_folder=output_folder)
+    # batbot writes a .metadata.json file named after the recording into output_folder
+    metadata_file = Path(recording_path).with_suffix('.metadata.json').name
+    metadata_file = Path(output_folder) / metadata_file
+    metadata = parse_batbot_metadata(metadata_file)
+    # The metadata lists the spectrogram images that were generated
+    uncompressed_paths = metadata.spectrogram.uncompressed_path
+    compressed_paths = metadata.spectrogram.compressed_path
+
+    freq_min_hz = metadata.frequencies.min_hz
+    freq_max_hz = metadata.frequencies.max_hz
+
+    compressed_metadata = convert_to_compressed_spectrogram_data(metadata)
+    result: SpectrogramAssets = {
+        'duration': metadata.duration_ms,
+        'freq_min': freq_min_hz,
+        'freq_max': freq_max_hz,
+        'normal': {
+            'paths': uncompressed_paths,
+            'width': metadata.size.uncompressed.width_px,
+            'height': metadata.size.uncompressed.height_px,
+        },
+        'compressed': {
+            'paths': compressed_paths,
+            'width': metadata.size.compressed.width_px,
+            'height': metadata.size.compressed.height_px,
+            'widths': compressed_metadata.widths,
+            'starts': compressed_metadata.starts,
+            'stops': compressed_metadata.stops,
+        },
+    }
+    return result
diff --git a/bats_ai/utils/spectrogram_utils.py b/bats_ai/utils/spectrogram_utils.py
index 718fc6f5..2f60a3bc 100644
--- a/bats_ai/utils/spectrogram_utils.py
+++ b/bats_ai/utils/spectrogram_utils.py
@@ -1,22 +1,15 @@
-import io
 import json
 import logging
-import math
 import os
 from pathlib 
import Path from typing import TypedDict -from PIL import Image import cv2 from django.contrib.contenttypes.models import ContentType from django.core.files import File -import librosa -import librosa.display -import matplotlib.pyplot as plt import numpy as np import onnx import onnxruntime as ort -import scipy.signal import tqdm from bats_ai.core.models import CompressedSpectrogram, SpectrogramImage @@ -24,10 +17,6 @@ logger = logging.getLogger(__name__) -FREQ_MIN = 5e3 -FREQ_MAX = 120e3 -FREQ_PAD = 2e3 - class SpectrogramAssetResult(TypedDict): paths: list[str] @@ -141,251 +130,6 @@ def predict_from_compressed( return {'label': label, 'score': score, 'confs': confs} -def generate_spectrogram_assets( - recording_path: str, output_base: str, dpi: int = 520 -) -> SpectrogramAssets: - sig, sr = librosa.load(recording_path, sr=None) - duration = len(sig) / sr - - size_mod = 1 - size = int(0.001 * sr) - size = 2 ** (math.ceil(math.log(size, 2)) + size_mod) - hop_length = int(size / 4) - - window = librosa.stft(sig, n_fft=size, hop_length=hop_length, window='hamming') - window = np.abs(window) ** 2 - window = librosa.power_to_db(window) - window -= np.median(window, axis=1, keepdims=True) - window_ = window[window > 0] - thresh = np.median(window_) - window[window <= thresh] = 0 - - bands = librosa.fft_frequencies(sr=sr, n_fft=size) - for index in range(len(bands)): - band_min = bands[index] - band_max = bands[index + 1] if index < len(bands) - 1 else np.inf - if band_max <= FREQ_MIN or FREQ_MAX <= band_min: - window[index, :] = -1 - - window = np.clip(window, 0, None) - freq_low = int(FREQ_MIN - FREQ_PAD) - freq_high = int(FREQ_MAX + FREQ_PAD) - vmin = window.min() - vmax = window.max() - - chunksize = int(2e3) - arange = np.arange(chunksize, window.shape[1], chunksize) - chunks = np.array_split(window, arange, axis=1) - - imgs = [] - for chunk in chunks: - h, w = chunk.shape - alpha = 3 - figsize = (int(math.ceil(w / h)) * alpha + 1, alpha) - fig = plt.figure(figsize=figsize, facecolor='black', dpi=dpi) - ax = plt.axes() - plt.margins(0) - - kwargs = { - 'sr': sr, - 'n_fft': size, - 'hop_length': hop_length, - 'x_axis': 's', - 'y_axis': 'fft', - 'ax': ax, - 'vmin': vmin, - 'vmax': vmax, - } - - librosa.display.specshow(chunk, cmap='gray', **kwargs) - ax.set_ylim(freq_low, freq_high) - ax.axis('off') - - buf = io.BytesIO() - fig.savefig(buf, bbox_inches='tight', pad_inches=0) - plt.close(fig) - - buf.seek(0) - img = Image.open(buf) - img = np.array(img) - mask = img[:, :, -1] - flags = np.where(np.sum(mask != 0, axis=0) == 0)[0] - index = flags.min() if len(flags) > 0 else img.shape[1] - img = img[:, :index, :3] - - imgs.append(img) - - normal_img = np.hstack(imgs) - normal_width = int(8.0 * duration * 1e3) - normal_height = 1200 - normal_img_resized = cv2.resize( - normal_img, (normal_width, normal_height), interpolation=cv2.INTER_LANCZOS4 - ) - - normal_out_path_base = os.path.join( - os.path.dirname(output_base), - 'spectrogram', - os.path.splitext(os.path.basename(output_base))[0] + '_spectrogram', - ) - os.makedirs(os.path.dirname(normal_out_path_base), exist_ok=True) - normal_paths = save_img(normal_img_resized, normal_out_path_base) - real_duration = math.ceil(duration * 1e3) - compressed_img, compressed_paths, widths, starts, stops = generate_compressed( - normal_img_resized, real_duration, output_base - ) - - result = { - 'duration': real_duration, - 'freq_min': freq_low, - 'freq_max': freq_high, - 'normal': { - 'paths': normal_paths, - 'width': normal_img_resized.shape[1], - 
'height': normal_img_resized.shape[0], - }, - 'compressed': { - 'paths': compressed_paths, - 'width': compressed_img.shape[1], - 'height': compressed_img.shape[0], - 'widths': widths, - 'starts': starts, - 'stops': stops, - }, - } - - return result - - -def generate_compressed(img: np.ndarray, duration: float, output_base: str): - threshold = 0.5 - compressed_img = img.copy() - starts, stops = [], [] - - while True: - canvas = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY).astype(np.float32) - is_light = np.median(canvas) > 128.0 - if is_light: - canvas = 255.0 - canvas - - amplitude = canvas.max(axis=0) - amplitude -= amplitude.min() - if amplitude.max() != 0: - amplitude /= amplitude.max() - amplitude[amplitude < threshold] = 0.0 - amplitude[amplitude > 0] = 1.0 - amplitude = amplitude.reshape(1, -1) - - canvas -= canvas.min() - if canvas.max() != 0: - canvas /= canvas.max() - canvas *= 255.0 - canvas *= amplitude - canvas = np.around(canvas).astype(np.uint8) - - mask = canvas.max(axis=0) - mask = scipy.signal.medfilt(mask, 3) - mask[0] = 0 - mask[-1] = 0 - - starts, stops = [], [] - for index in range(1, len(mask) - 1): - if mask[index] != 0: - if mask[index - 1] == 0: - starts.append(index) - if mask[index + 1] == 0: - stops.append(index) - - starts = [val - 40 for val in starts] - stops = [val + 40 for val in stops] - ranges = list(zip(starts, stops)) - - while True: - found = False - merged = [] - index = 0 - while index < len(ranges) - 1: - start1, stop1 = ranges[index] - start2, stop2 = ranges[index + 1] - - # Clamp values within mask length - start1 = min(max(start1, 0), len(mask)) - start2 = min(max(start2, 0), len(mask)) - stop1 = min(max(stop1, 0), len(mask)) - stop2 = min(max(stop2, 0), len(mask)) - - if stop1 >= start2: - found = True - merged.append((start1, stop2)) - index += 2 - else: - merged.append((start1, stop1)) - index += 1 - if index == len(ranges) - 1: - merged.append((start2, stop2)) - ranges = merged - if not found: - break - - starts = [start for start, _ in ranges] - stops = [stop for _, stop in ranges] - - segments = [] - domain = img.shape[1] - widths = [] - for start, stop in ranges: - start_clamped = max(start, 0) - stop_clamped = min(stop, domain) - segment = img[:, start_clamped:stop_clamped] - segments.append(segment) - widths.append(stop_clamped - start_clamped) - - if segments: - compressed_img = np.hstack(segments) - break - - threshold -= 0.05 - if threshold < 0: - compressed_img = img.copy() - widths = [] - starts = [] - stops = [] - break - - # Convert starts and stops to time values relative to duration - starts_time = [int(round(duration * (max(s, 0) / domain))) for s in starts] - stops_time = [int(round(duration * (min(e, domain) / domain))) for e in stops] - - out_folder = os.path.join(os.path.dirname(output_base), 'compressed') - os.makedirs(out_folder, exist_ok=True) - base_name = os.path.splitext(os.path.basename(output_base))[0] - compressed_out_path = os.path.join(out_folder, f'{base_name}_compressed') - - # save_img should be your existing function to save images and return file paths - paths = save_img(compressed_img, compressed_out_path) - - return compressed_img, paths, widths, starts_time, stops_time - - -def save_img(img: np.ndarray, output_base: str): - chunksize = int(5e4) - length = img.shape[1] - chunks = ( - np.split(img, np.arange(chunksize, length, chunksize), axis=1) - if length > chunksize - else [img] - ) - total = len(chunks) - output_paths = [] - for index, chunk in enumerate(chunks): - out_path = f'{output_base}.{index + 
1:02d}_of_{total:02d}.jpg' - out_img = Image.fromarray(chunk, 'RGB') - out_img.save(out_path, format='JPEG', optimize=True, quality=80) - output_paths.append(out_path) - logger.info(f'Saved image: {out_path}') - - return output_paths - - def generate_nabat_spectrogram( nabat_recording: NABatRecording, results: SpectrogramAssets ) -> NABatSpectrogram: diff --git a/client/src/components/ColorSchemeDialog.vue b/client/src/components/ColorSchemeDialog.vue index 874e8512..9a102e6c 100644 --- a/client/src/components/ColorSchemeDialog.vue +++ b/client/src/components/ColorSchemeDialog.vue @@ -41,7 +41,7 @@ watch(colorScheme, () => {