Skip to content
2 changes: 1 addition & 1 deletion bats_ai/core/admin/pulse_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@

@admin.register(PulseMetadata)
class PulseMetadataAdmin(admin.ModelAdmin):
list_display = ('recording', 'index', 'bounding_box')
list_display = ('recording', 'index', 'bounding_box', 'curve', 'char_freq', 'knee', 'heel')
list_select_related = True
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Generated by Django 4.2.23 on 2026-02-03 19:43

import django.contrib.gis.db.models.fields
from django.db import migrations


class Migration(migrations.Migration):
dependencies = [
('core', '0028_alter_spectrogramimage_type_pulsemetadata'),
]

operations = [
migrations.AddField(
model_name='pulsemetadata',
name='char_freq',
field=django.contrib.gis.db.models.fields.PointField(blank=True, null=True, srid=4326),
),
migrations.AddField(
model_name='pulsemetadata',
name='curve',
field=django.contrib.gis.db.models.fields.LineStringField(
blank=True, null=True, srid=4326
),
),
migrations.AddField(
model_name='pulsemetadata',
name='heel',
field=django.contrib.gis.db.models.fields.PointField(blank=True, null=True, srid=4326),
),
migrations.AddField(
model_name='pulsemetadata',
name='knee',
field=django.contrib.gis.db.models.fields.PointField(blank=True, null=True, srid=4326),
),
]
5 changes: 4 additions & 1 deletion bats_ai/core/models/pulse_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,7 @@ class PulseMetadata(models.Model):
index = models.IntegerField(null=False, blank=False)
bounding_box = models.PolygonField(null=False, blank=False)
contours = models.JSONField(null=True, blank=True)
# TODO: Add in metadata from batbot
curve = models.LineStringField(null=True, blank=True)
char_freq = models.PointField(null=True, blank=True)
knee = models.PointField(null=True, blank=True)
heel = models.PointField(null=True, blank=True)
45 changes: 44 additions & 1 deletion bats_ai/core/tasks/nabat/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from pathlib import Path
import tempfile

from django.contrib.gis.geos import LineString, Point, Polygon
import requests

from bats_ai.core.models import Configuration, ProcessingTask, Species
from bats_ai.core.models import Configuration, ProcessingTask, PulseMetadata, Species
from bats_ai.core.models.nabat import NABatRecording, NABatRecordingAnnotation
from bats_ai.core.utils.batbot_metadata import generate_spectrogram_assets
from bats_ai.utils.spectrogram_utils import (
Expand Down Expand Up @@ -57,6 +58,48 @@ def generate_spectrograms(
compressed_obj = generate_nabat_compressed_spectrogram(
nabat_recording, spectrogram, compressed
)
segment_index_map = {}
for segment in compressed['contours']['segments']:
pulse_metadata_obj, _ = PulseMetadata.objects.get_or_create(
recording=compressed_obj.recording,
index=segment['segment_index'],
defaults={
'contours': segment['contours'],
'bounding_box': Polygon(
(
(segment['start_ms'], segment['freq_max']),
(segment['stop_ms'], segment['freq_max']),
(segment['stop_ms'], segment['freq_min']),
(segment['start_ms'], segment['freq_min']),
(segment['start_ms'], segment['freq_max']),
)
),
},
)
segment_index_map[segment['segment_index']] = pulse_metadata_obj
for segment in compressed['segments']:
if segment['segment_index'] not in segment_index_map:
PulseMetadata.objects.update_or_create(
recording=compressed_obj.recording,
index=segment['segment_index'],
defaults={
'curve': LineString([Point(x[1], x[0]) for x in segment['curve_hz_ms']]),
'char_freq': Point(segment['char_freq_ms'], segment['char_freq_hz']),
'knee': Point(segment['knee_ms'], segment['knee_hz']),
'heel': Point(segment['heel_ms'], segment['heel_hz']),
},
)
else:
pulse_metadata_obj = segment_index_map[segment['segment_index']]
pulse_metadata_obj.curve = LineString(
[Point(x[1], x[0]) for x in segment['curve_hz_ms']]
)
pulse_metadata_obj.char_freq = Point(
segment['char_freq_ms'], segment['char_freq_hz']
)
pulse_metadata_obj.knee = Point(segment['knee_ms'], segment['knee_hz'])
pulse_metadata_obj.heel = Point(segment['heel_ms'], segment['heel_hz'])
pulse_metadata_obj.save()

try:
config = Configuration.objects.first()
Expand Down
31 changes: 28 additions & 3 deletions bats_ai/core/tasks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import tempfile

from django.contrib.contenttypes.models import ContentType
from django.contrib.gis.geos import Polygon
from django.contrib.gis.geos import LineString, Point, Polygon
from django.core.files import File

from bats_ai.celery import app
Expand Down Expand Up @@ -104,8 +104,9 @@ def recording_compute_spectrogram(recording_id: int):
)

# Create SpectrogramContour objects for each segment
for segment in results['segments']['segments']:
PulseMetadata.objects.update_or_create(
segment_index_map = {}
for segment in compressed['contours']['segments']:
pulse_metadata_obj, _ = PulseMetadata.objects.update_or_create(
recording=compressed_obj.recording,
index=segment['segment_index'],
defaults={
Expand All @@ -121,6 +122,30 @@ def recording_compute_spectrogram(recording_id: int):
),
},
)
segment_index_map[segment['segment_index']] = pulse_metadata_obj
for segment in compressed['segments']:
if segment['segment_index'] not in segment_index_map:
PulseMetadata.objects.update_or_create(
recording=compressed_obj.recording,
index=segment['segment_index'],
defaults={
'curve': LineString([Point(x[1], x[0]) for x in segment['curve_hz_ms']]),
'char_freq': Point(segment['char_freq_ms'], segment['char_freq_hz']),
'knee': Point(segment['knee_ms'], segment['knee_hz']),
'heel': Point(segment['heel_ms'], segment['heel_hz']),
},
)
else:
pulse_metadata_obj = segment_index_map[segment['segment_index']]
pulse_metadata_obj.curve = LineString(
[Point(x[1], x[0]) for x in segment['curve_hz_ms']]
)
pulse_metadata_obj.char_freq = Point(
segment['char_freq_ms'], segment['char_freq_hz']
)
pulse_metadata_obj.knee = Point(segment['knee_ms'], segment['knee_hz'])
pulse_metadata_obj.heel = Point(segment['heel_ms'], segment['heel_hz'])
pulse_metadata_obj.save()

config = Configuration.objects.first()
# TODO: Disabled until prediction is in batbot
Expand Down
41 changes: 38 additions & 3 deletions bats_ai/core/utils/batbot_metadata.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from contextlib import contextmanager
import json
import logging
import os
from pathlib import Path
from typing import Any, TypedDict
Expand All @@ -9,6 +10,8 @@

from .contour_utils import process_spectrogram_assets_for_contours

logger = logging.getLogger(__name__)


class SpectrogramMetadata(BaseModel):
"""Metadata about the spectrogram."""
Expand Down Expand Up @@ -255,6 +258,17 @@ class SpectrogramContourSegment(TypedDict):
stop_ms: float


class BatBotMetadataCurve(TypedDict):
segment_index: int
curve_hz_ms: list[float]
char_freq_ms: float
char_freq_hz: float
knee_ms: float
knee_hz: float
heel_ms: float
heel_hz: float


class SpectrogramContours(TypedDict):
segments: list[SpectrogramContourSegment]
total_segments: int
Expand All @@ -266,7 +280,7 @@ class SpectrogramAssets(TypedDict):
freq_max: int
normal: SpectrogramAssetResult
compressed: SpectrogramCompressedAssetResult
segments: SpectrogramContours | None
contours: SpectrogramContours | None


@contextmanager
Expand All @@ -279,6 +293,25 @@ def working_directory(path):
os.chdir(previous)


def convert_to_segment_data(
metadata: BatbotMetadata,
) -> list[BatBotMetadataCurve]:
segment_data: list[BatBotMetadataCurve] = []
for index, segment in enumerate(metadata.segments):
segment_data_item: BatBotMetadataCurve = {
'segment_index': index,
'curve_hz_ms': segment.curve_hz_ms,
'char_freq_ms': segment.fc_ms,
'char_freq_hz': segment.fc_hz,
'knee_ms': segment.hi_fc_knee_ms,
'knee_hz': segment.hi_fc_knee_hz,
'heel_ms': segment.lo_fc_heel_ms,
'heel_hz': segment.lo_fc_heel_hz,
}
segment_data.append(segment_data_item)
return segment_data


def generate_spectrogram_assets(recording_path: str, output_folder: str):
batbot.pipeline(recording_path, output_folder=output_folder)
# There should be a .metadata.json file in the output_base directory by replacing extentions
Expand All @@ -294,6 +327,7 @@ def generate_spectrogram_assets(recording_path: str, output_folder: str):
metadata.frequencies.max_hz

compressed_metadata = convert_to_compressed_spectrogram_data(metadata)
segment_curve_data = convert_to_segment_data(metadata)
result: SpectrogramAssets = {
'duration': metadata.duration_ms,
'freq_min': metadata.frequencies.min_hz,
Expand All @@ -311,10 +345,11 @@ def generate_spectrogram_assets(recording_path: str, output_folder: str):
'widths': compressed_metadata.widths,
'starts': compressed_metadata.starts,
'stops': compressed_metadata.stops,
'segments': segment_curve_data,
},
}

segments_data = process_spectrogram_assets_for_contours(result)
result['segments'] = segments_data
contour_segments_data = process_spectrogram_assets_for_contours(result)
result['compressed']['contours'] = contour_segments_data

return result
51 changes: 50 additions & 1 deletion bats_ai/core/views/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class UpdateAnnotationsSchema(Schema):
id: int | None


class PulseMetadataSchema(Schema):
class PulseContourSchema(Schema):
id: int | None
index: int
bounding_box: Any
Expand All @@ -146,6 +146,36 @@ def from_orm(cls, obj: PulseMetadata):
)


class PulseMetadataSchema(Schema):
id: int | None
index: int
curve: list[list[float]] | None = None # list of [time, frequency]
char_freq: list[float] | None = None # point [time, frequency]
knee: list[float] | None = None # point [time, frequency]
heel: list[float] | None = None # point [time, frequency]

@classmethod
def from_orm(cls, obj: PulseMetadata):
def point_to_list(pt):
if pt is None:
return None
return [pt.x, pt.y]

def linestring_to_list(ls):
if ls is None:
return None
return [[c[0], c[1]] for c in ls.coords]

return cls(
id=obj.id,
index=obj.index,
curve=linestring_to_list(obj.curve),
char_freq=point_to_list(obj.char_freq),
knee=point_to_list(obj.knee),
heel=point_to_list(obj.heel),
)


@router.post('/')
def create_recording(
request: HttpRequest,
Expand Down Expand Up @@ -560,6 +590,25 @@ def get_annotations(request: HttpRequest, id: int):
return {'error': 'Recording not found'}


@router.get('/{id}/pulse_contours')
def get_pulse_contours(request: HttpRequest, id: int):
try:
recording = Recording.objects.get(pk=id)
if recording.owner == request.user or recording.public:
computed_pulse_annotation_qs = PulseMetadata.objects.filter(
recording=recording
).order_by('index')
return [
PulseContourSchema.from_orm(pulse) for pulse in computed_pulse_annotation_qs.all()
]
else:
return {
'error': 'Permission denied. You do not own this recording, and it is not public.'
}
except Recording.DoesNotExist:
return {'error': 'Recording not found'}


@router.get('/{id}/pulse_data')
def get_pulse_data(request: HttpRequest, id: int):
try:
Expand Down
23 changes: 19 additions & 4 deletions client/src/api/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -574,14 +574,28 @@ export interface Contour {
index: number;
}

export interface ComputedPulseAnnotation {
export interface ComputedPulseContour {
id: number;
index: number;
contours: Contour[];
}

async function getComputedPulseAnnotations(recordingId: number) {
const result = await axiosInstance.get<ComputedPulseAnnotation[]>(`/recording/${recordingId}/pulse_data`);
async function getComputedPulseContour(recordingId: number) {
const result = await axiosInstance.get<ComputedPulseContour[]>(`/recording/${recordingId}/pulse_contours`);
return result.data;
}

export interface PulseMetadata {
id: number;
index: number;
curve: number[][] | null; // list of [time, frequency]
char_freq: number[] | null; // point [time, frequency]
knee: number[] | null; // point [time, frequency]
heel: number[] | null; // point [time, frequency]
}

async function getPulseMetadata(recordingId: number) {
const result = await axiosInstance.get<PulseMetadata[]>(`/recording/${recordingId}/pulse_data`);
return result.data;
}

Expand Down Expand Up @@ -622,7 +636,8 @@ export {
getFileAnnotationDetails,
getExportStatus,
getRecordingTags,
getComputedPulseAnnotations,
getComputedPulseContour,
getPulseMetadata,
getCurrentUser,
getVettingDetailsForUser,
createOrUpdateVettingDetailsForUser,
Expand Down
Loading
Loading