Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bats_ai/core/admin/pulse_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@

@admin.register(PulseMetadata)
class PulseMetadataAdmin(admin.ModelAdmin):
list_display = ('recording', 'index', 'bounding_box')
list_display = ('recording', 'index', 'bounding_box', 'curve', 'char_freq', 'knee', 'heel')
list_select_related = True
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Generated by Django 4.2.23 on 2026-02-03 19:43

import django.contrib.gis.db.models.fields
from django.db import migrations


class Migration(migrations.Migration):
dependencies = [
('core', '0028_alter_spectrogramimage_type_pulsemetadata'),
]

operations = [
migrations.AddField(
model_name='pulsemetadata',
name='char_freq',
field=django.contrib.gis.db.models.fields.PointField(blank=True, null=True, srid=4326),
),
migrations.AddField(
model_name='pulsemetadata',
name='curve',
field=django.contrib.gis.db.models.fields.LineStringField(
blank=True, null=True, srid=4326
),
),
migrations.AddField(
model_name='pulsemetadata',
name='heel',
field=django.contrib.gis.db.models.fields.PointField(blank=True, null=True, srid=4326),
),
migrations.AddField(
model_name='pulsemetadata',
name='knee',
field=django.contrib.gis.db.models.fields.PointField(blank=True, null=True, srid=4326),
),
]
5 changes: 4 additions & 1 deletion bats_ai/core/models/pulse_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,7 @@ class PulseMetadata(models.Model):
index = models.IntegerField(null=False, blank=False)
bounding_box = models.PolygonField(null=False, blank=False)
contours = models.JSONField(null=True, blank=True)
# TODO: Add in metadata from batbot
curve = models.LineStringField(null=True, blank=True)
char_freq = models.PointField(null=True, blank=True)
knee = models.PointField(null=True, blank=True)
heel = models.PointField(null=True, blank=True)
45 changes: 44 additions & 1 deletion bats_ai/core/tasks/nabat/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from pathlib import Path
import tempfile

from django.contrib.gis.geos import LineString, Point, Polygon
import requests

from bats_ai.core.models import Configuration, ProcessingTask, Species
from bats_ai.core.models import Configuration, ProcessingTask, PulseMetadata, Species
from bats_ai.core.models.nabat import NABatRecording, NABatRecordingAnnotation
from bats_ai.core.utils.batbot_metadata import generate_spectrogram_assets
from bats_ai.utils.spectrogram_utils import (
Expand Down Expand Up @@ -57,6 +58,48 @@ def generate_spectrograms(
compressed_obj = generate_nabat_compressed_spectrogram(
nabat_recording, spectrogram, compressed
)
segment_index_map = {}
for segment in compressed['contours']['segments']:
pulse_metadata_obj, _ = PulseMetadata.objects.get_or_create(
recording=compressed_obj.recording,
index=segment['segment_index'],
defaults={
'contours': segment['contours'],
'bounding_box': Polygon(
(
(segment['start_ms'], segment['freq_max']),
(segment['stop_ms'], segment['freq_max']),
(segment['stop_ms'], segment['freq_min']),
(segment['start_ms'], segment['freq_min']),
(segment['start_ms'], segment['freq_max']),
)
),
},
)
segment_index_map[segment['segment_index']] = pulse_metadata_obj
for segment in compressed['segments']:
if segment['segment_index'] not in segment_index_map:
PulseMetadata.objects.update_or_create(
recording=compressed_obj.recording,
index=segment['segment_index'],
defaults={
'curve': LineString([Point(x[1], x[0]) for x in segment['curve_hz_ms']]),
'char_freq': Point(segment['char_freq_ms'], segment['char_freq_hz']),
'knee': Point(segment['knee_ms'], segment['knee_hz']),
'heel': Point(segment['heel_ms'], segment['heel_hz']),
},
)
else:
pulse_metadata_obj = segment_index_map[segment['segment_index']]
pulse_metadata_obj.curve = LineString(
[Point(x[1], x[0]) for x in segment['curve_hz_ms']]
)
pulse_metadata_obj.char_freq = Point(
segment['char_freq_ms'], segment['char_freq_hz']
)
pulse_metadata_obj.knee = Point(segment['knee_ms'], segment['knee_hz'])
pulse_metadata_obj.heel = Point(segment['heel_ms'], segment['heel_hz'])
pulse_metadata_obj.save()

try:
config = Configuration.objects.first()
Expand Down
31 changes: 28 additions & 3 deletions bats_ai/core/tasks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import tempfile

from django.contrib.contenttypes.models import ContentType
from django.contrib.gis.geos import Polygon
from django.contrib.gis.geos import LineString, Point, Polygon
from django.core.files import File

from bats_ai.celery import app
Expand Down Expand Up @@ -104,8 +104,9 @@ def recording_compute_spectrogram(recording_id: int):
)

# Create SpectrogramContour objects for each segment
for segment in results['segments']['segments']:
PulseMetadata.objects.update_or_create(
segment_index_map = {}
for segment in compressed['contours']['segments']:
pulse_metadata_obj, _ = PulseMetadata.objects.update_or_create(
recording=compressed_obj.recording,
index=segment['segment_index'],
defaults={
Expand All @@ -121,6 +122,30 @@ def recording_compute_spectrogram(recording_id: int):
),
},
)
segment_index_map[segment['segment_index']] = pulse_metadata_obj
for segment in compressed['segments']:
if segment['segment_index'] not in segment_index_map:
PulseMetadata.objects.update_or_create(
recording=compressed_obj.recording,
index=segment['segment_index'],
defaults={
'curve': LineString([Point(x[1], x[0]) for x in segment['curve_hz_ms']]),
'char_freq': Point(segment['char_freq_ms'], segment['char_freq_hz']),
'knee': Point(segment['knee_ms'], segment['knee_hz']),
'heel': Point(segment['heel_ms'], segment['heel_hz']),
},
)
else:
pulse_metadata_obj = segment_index_map[segment['segment_index']]
pulse_metadata_obj.curve = LineString(
[Point(x[1], x[0]) for x in segment['curve_hz_ms']]
)
pulse_metadata_obj.char_freq = Point(
segment['char_freq_ms'], segment['char_freq_hz']
)
pulse_metadata_obj.knee = Point(segment['knee_ms'], segment['knee_hz'])
pulse_metadata_obj.heel = Point(segment['heel_ms'], segment['heel_hz'])
pulse_metadata_obj.save()

config = Configuration.objects.first()
# TODO: Disabled until prediction is in batbot
Expand Down
41 changes: 38 additions & 3 deletions bats_ai/core/utils/batbot_metadata.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from contextlib import contextmanager
import json
import logging
import os
from pathlib import Path
from typing import Any, TypedDict
Expand All @@ -9,6 +10,8 @@

from .contour_utils import process_spectrogram_assets_for_contours

logger = logging.getLogger(__name__)


class SpectrogramMetadata(BaseModel):
"""Metadata about the spectrogram."""
Expand Down Expand Up @@ -255,6 +258,17 @@ class SpectrogramContourSegment(TypedDict):
stop_ms: float


class BatBotMetadataCurve(TypedDict):
segment_index: int
curve_hz_ms: list[float]
char_freq_ms: float
char_freq_hz: float
knee_ms: float
knee_hz: float
heel_ms: float
heel_hz: float


class SpectrogramContours(TypedDict):
segments: list[SpectrogramContourSegment]
total_segments: int
Expand All @@ -266,7 +280,7 @@ class SpectrogramAssets(TypedDict):
freq_max: int
normal: SpectrogramAssetResult
compressed: SpectrogramCompressedAssetResult
segments: SpectrogramContours | None
contours: SpectrogramContours | None


@contextmanager
Expand All @@ -279,6 +293,25 @@ def working_directory(path):
os.chdir(previous)


def convert_to_segment_data(
metadata: BatbotMetadata,
) -> list[BatBotMetadataCurve]:
segment_data: list[BatBotMetadataCurve] = []
for index, segment in enumerate(metadata.segments):
segment_data_item: BatBotMetadataCurve = {
'segment_index': index,
'curve_hz_ms': segment.curve_hz_ms,
'char_freq_ms': segment.fc_ms,
'char_freq_hz': segment.fc_hz,
'knee_ms': segment.hi_fc_knee_ms,
'knee_hz': segment.hi_fc_knee_hz,
'heel_ms': segment.lo_fc_heel_ms,
'heel_hz': segment.lo_fc_heel_hz,
}
segment_data.append(segment_data_item)
return segment_data


def generate_spectrogram_assets(recording_path: str, output_folder: str):
batbot.pipeline(recording_path, output_folder=output_folder)
# There should be a .metadata.json file in the output_base directory by replacing extentions
Expand All @@ -294,6 +327,7 @@ def generate_spectrogram_assets(recording_path: str, output_folder: str):
metadata.frequencies.max_hz

compressed_metadata = convert_to_compressed_spectrogram_data(metadata)
segment_curve_data = convert_to_segment_data(metadata)
result: SpectrogramAssets = {
'duration': metadata.duration_ms,
'freq_min': metadata.frequencies.min_hz,
Expand All @@ -311,10 +345,11 @@ def generate_spectrogram_assets(recording_path: str, output_folder: str):
'widths': compressed_metadata.widths,
'starts': compressed_metadata.starts,
'stops': compressed_metadata.stops,
'segments': segment_curve_data,
},
}

segments_data = process_spectrogram_assets_for_contours(result)
result['segments'] = segments_data
contour_segments_data = process_spectrogram_assets_for_contours(result)
result['compressed']['contours'] = contour_segments_data

return result
51 changes: 50 additions & 1 deletion bats_ai/core/views/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class UpdateAnnotationsSchema(Schema):
id: int | None


class PulseMetadataSchema(Schema):
class PulseContourSchema(Schema):
id: int | None
index: int
bounding_box: Any
Expand All @@ -146,6 +146,36 @@ def from_orm(cls, obj: PulseMetadata):
)


class PulseMetadataSchema(Schema):
id: int | None
index: int
curve: list[list[float]] | None = None # list of [time, frequency]
char_freq: list[float] | None = None # point [time, frequency]
knee: list[float] | None = None # point [time, frequency]
heel: list[float] | None = None # point [time, frequency]

@classmethod
def from_orm(cls, obj: PulseMetadata):
def point_to_list(pt):
if pt is None:
return None
return [pt.x, pt.y]

def linestring_to_list(ls):
if ls is None:
return None
return [[c[0], c[1]] for c in ls.coords]

return cls(
id=obj.id,
index=obj.index,
curve=linestring_to_list(obj.curve),
char_freq=point_to_list(obj.char_freq),
knee=point_to_list(obj.knee),
heel=point_to_list(obj.heel),
)


@router.post('/')
def create_recording(
request: HttpRequest,
Expand Down Expand Up @@ -560,6 +590,25 @@ def get_annotations(request: HttpRequest, id: int):
return {'error': 'Recording not found'}


@router.get('/{id}/pulse_contours')
def get_pulse_contours(request: HttpRequest, id: int):
try:
recording = Recording.objects.get(pk=id)
if recording.owner == request.user or recording.public:
computed_pulse_annotation_qs = PulseMetadata.objects.filter(
recording=recording
).order_by('index')
return [
PulseContourSchema.from_orm(pulse) for pulse in computed_pulse_annotation_qs.all()
]
else:
return {
'error': 'Permission denied. You do not own this recording, and it is not public.'
}
except Recording.DoesNotExist:
return {'error': 'Recording not found'}


@router.get('/{id}/pulse_data')
def get_pulse_data(request: HttpRequest, id: int):
try:
Expand Down
23 changes: 19 additions & 4 deletions client/src/api/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -574,14 +574,28 @@ export interface Contour {
index: number;
}

export interface ComputedPulseAnnotation {
export interface ComputedPulseContour {
id: number;
index: number;
contours: Contour[];
}

async function getComputedPulseAnnotations(recordingId: number) {
const result = await axiosInstance.get<ComputedPulseAnnotation[]>(`/recording/${recordingId}/pulse_data`);
async function getComputedPulseContour(recordingId: number) {
const result = await axiosInstance.get<ComputedPulseContour[]>(`/recording/${recordingId}/pulse_contours`);
return result.data;
}

export interface PulseMetadata {
id: number;
index: number;
curve: number[][] | null; // list of [time, frequency]
char_freq: number[] | null; // point [time, frequency]
knee: number[] | null; // point [time, frequency]
heel: number[] | null; // point [time, frequency]
}

async function getPulseMetadata(recordingId: number) {
const result = await axiosInstance.get<PulseMetadata[]>(`/recording/${recordingId}/pulse_data`);
return result.data;
}

Expand Down Expand Up @@ -622,7 +636,8 @@ export {
getFileAnnotationDetails,
getExportStatus,
getRecordingTags,
getComputedPulseAnnotations,
getComputedPulseContour,
getPulseMetadata,
getCurrentUser,
getVettingDetailsForUser,
createOrUpdateVettingDetailsForUser,
Expand Down
Loading