Skip to content
Closed
85 changes: 83 additions & 2 deletions src/Baballonia/Models/SliderBindableSetting.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
using CommunityToolkit.Mvvm.ComponentModel;
using Baballonia.Contracts;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.ComponentModel;
using System.Threading.Tasks;
using Avalonia.Threading;

namespace Baballonia.Models;

Expand All @@ -18,7 +24,82 @@ public SliderBindableSetting(string name, float lower = 0f, float upper = 1f, fl
Name = name;
Lower = lower;
Upper = upper;
Min = max;
Max = min;
Min = min;
Max = max;
}
}

public partial class ParameterGroupCollection : ObservableCollection<SliderBindableSetting>
{
public string GroupName { get; }
public IFilterSettings FilterSettings { get; }

public ParameterGroupCollection(string groupName, IFilterSettings filterSettings, IEnumerable<SliderBindableSetting> items)
: base(items)
{
GroupName = groupName;
FilterSettings = filterSettings;
}
}

public interface IFilterSettings : INotifyPropertyChanged
{
bool Enabled { get; set; }
float MinFreqCutoff { get; set; }
float SpeedCutoff { get; set; }
}

public partial class GroupFilterSettings : ObservableObject, IFilterSettings
{
private readonly ILocalSettingsService _localSettingsService;
private readonly string _prefix;

[ObservableProperty]
private bool _enabled;

[ObservableProperty]
private float _minFreqCutoff;

[ObservableProperty]
private float _speedCutoff;

public GroupFilterSettings(ILocalSettingsService localSettingsService, string settingPrefix,
bool defaultEnabled, float defaultMinFreqCutoff, float defaultSpeedCutoff)
{
_localSettingsService = localSettingsService;
_prefix = settingPrefix;

Enabled = defaultEnabled;
MinFreqCutoff = defaultMinFreqCutoff;
SpeedCutoff = defaultSpeedCutoff;

Task.Run(async () =>
{
var enabled = await _localSettingsService.ReadSettingAsync($"{_prefix}_Enabled", defaultEnabled);
var min = await _localSettingsService.ReadSettingAsync($"{_prefix}_MinFreq", defaultMinFreqCutoff);
var speed = await _localSettingsService.ReadSettingAsync($"{_prefix}_Speed", defaultSpeedCutoff);
Dispatcher.UIThread.Post(() =>
{
Enabled = enabled;
MinFreqCutoff = min;
SpeedCutoff = speed;
});
});

PropertyChanged += async (_, e) =>
{
switch (e.PropertyName)
{
case nameof(Enabled):
await _localSettingsService.SaveSettingAsync($"{_prefix}_Enabled", Enabled);
break;
case nameof(MinFreqCutoff):
await _localSettingsService.SaveSettingAsync($"{_prefix}_MinFreq", MinFreqCutoff);
break;
case nameof(SpeedCutoff):
await _localSettingsService.SaveSettingAsync($"{_prefix}_Speed", SpeedCutoff);
break;
}
};
}
}
160 changes: 84 additions & 76 deletions src/Baballonia/Services/Inference/Filters/OneEuroFilter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,101 +3,109 @@

namespace Baballonia.Services.Inference.Filters;

public class OneEuroFilter : IFilter
public class GroupedOneEuroFilter : IFilter
{
private float[] minCutoff;
private float[] beta;
private float[] dCutoff;
private float[] xPrev;
private float[] dxPrev;
private DateTime tPrev;
public OneEuroFilter(float[] x0, float minCutoff = 1.0f, float beta = 0.0f)
private sealed class GroupState
{
float dx0 = 0.0f;
float dCutoff = 1.0f;
int length = x0.Length;
this.minCutoff = CreateFilledArray(length, minCutoff);
this.beta = CreateFilledArray(length, beta);
this.dCutoff = CreateFilledArray(length, dCutoff);
// Previous values.
this.xPrev = (float[])x0.Clone();
this.dxPrev = CreateFilledArray(length, dx0);
this.tPrev = DateTime.UtcNow;

public int[] Indices = Array.Empty<int>();
public float[] XPrev = Array.Empty<float>();
public float[] DxPrev = Array.Empty<float>();
public float MinCutoff;
public float Beta;
public float DCutoff = 1.0f;
public DateTime TPrev;
public bool Initialized;
}

public float[] Filter(float[] x)
{
if (x.Length != xPrev.Length)
throw new ArgumentException($"Input shape does not match initial shape. Expected: {xPrev.Length}, got: {x.Length}");

DateTime now = DateTime.UtcNow;
float elapsedTime = (float)(now - tPrev).TotalSeconds;

if (elapsedTime == 0.0f)
{
xPrev = (float[])x.Clone();
return x;
}

float[] t_e = CreateFilledArray(x.Length, elapsedTime);

// Derivative
float[] dx = new float[x.Length];
for (int i = 0; i < x.Length; i++)
{
dx[i] = (x[i] - xPrev[i]) / t_e[i];
}
private readonly Dictionary<string, GroupState> _groups = new();

float[] a_d = SmoothingFactor(t_e, dCutoff);
float[] dxHat = ExponentialSmoothing(a_d, dx, dxPrev);
public void ConfigureGroup(string groupName, int[] parameterIndices, float minCutoff, float beta)
{
if (parameterIndices.Length == 0)
return;

// Adjusted cutoff
float[] cutoff = new float[x.Length];
for (int i = 0; i < x.Length; i++)
var state = new GroupState
{
cutoff[i] = minCutoff[i] + beta[i] * Math.Abs(dxHat[i]);
}

float[] a = SmoothingFactor(t_e, cutoff);
float[] xHat = ExponentialSmoothing(a, x, xPrev);

// Store previous values
xPrev = xHat;
dxPrev = dxHat;
tPrev = now;

return xHat;
Indices = (int[])parameterIndices.Clone(),
XPrev = new float[parameterIndices.Length],
DxPrev = new float[parameterIndices.Length],
MinCutoff = Math.Max(0.001f, minCutoff),
Beta = Math.Max(0f, beta),
TPrev = DateTime.UtcNow,
Initialized = false
};

_groups[groupName] = state;
}

private float[] CreateFilledArray(int length, float value)
public void DisableGroup(string groupName)
{
float[] arr = new float[length];
for (int i = 0; i < length; i++) arr[i] = value;
return arr;
_groups.Remove(groupName);
}

private float[] SmoothingFactor(float[] t_e, float[] cutoff)
public float[] Filter(float[] input)
{
int length = t_e.Length;
float[] result = new float[length];
for (int i = 0; i < length; i++)
if (_groups.Count == 0)
return input;

var now = DateTime.UtcNow;
float[] result = (float[])input.Clone();

foreach (var kvp in _groups)
{
float r = 2 * (float)Math.PI * cutoff[i] * t_e[i];
result[i] = r / (r + 1);
var state = kvp.Value;
if (state.Indices.Length == 0)
continue;

int n = state.Indices.Length;
float[] x = new float[n];
var indices = state.Indices;
for (int i = 0; i < n; i++)
{
x[i] = input[indices[i]];
}

float dt = (float)(now - state.TPrev).TotalSeconds;
if (!state.Initialized || dt <= 0f)
{
for (int i = 0; i < n; i++)
state.XPrev[i] = x[i];
state.TPrev = now;
state.Initialized = true;
continue;
}

// dx = (x - xPrev) / dt
for (int i = 0; i < n; i++)
{
state.DxPrev[i] = OneEuroSmooth(state.DCutoff, dt, (x[i] - state.XPrev[i]) / dt, state.DxPrev[i]);
}

// cutoff = minCutoff + beta * |dxHat|
for (int i = 0; i < n; i++)
{
float cutoff = state.MinCutoff + state.Beta * MathF.Abs(state.DxPrev[i]);
float a = SmoothingFactor(cutoff, dt);
float xHat = a * x[i] + (1f - a) * state.XPrev[i];
state.XPrev[i] = xHat;
result[indices[i]] = xHat;
}

state.TPrev = now;
}

return result;
}

private float[] ExponentialSmoothing(float[] a, float[] x, float[] xPrev)
private static float OneEuroSmooth(float cutoff, float dt, float value, float prev)
{
int length = a.Length;
float[] result = new float[length];
for (int i = 0; i < length; i++)
{
result[i] = a[i] * x[i] + (1 - a[i]) * xPrev[i];
}
return result;
float a = SmoothingFactor(cutoff, dt);
return a * value + (1f - a) * prev;
}

private static float SmoothingFactor(float cutoff, float dt)
{
float r = 2f * MathF.PI * cutoff * dt;
return r / (r + 1f);
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using Baballonia.Services.Inference;
using Baballonia.Services.Inference.Filters;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
Expand All @@ -10,7 +11,7 @@ public class PlatformSettings(
Size inputSize,
InferenceSession session,
DenseTensor<float> tensor,
OneEuroFilter oneEuroFilter,
IFilter oneEuroFilter,
float lastTime,
string inputName,
string modelName)
Expand All @@ -19,7 +20,7 @@ public class PlatformSettings(
public InferenceSession Session { get; } = session;
public DenseTensor<float> Tensor { get; } = tensor;

public OneEuroFilter Filter { get; } = oneEuroFilter;
public IFilter Filter { get; } = oneEuroFilter;
public float LastTime { get; set; } = lastTime;
public string InputName { get; } = inputName;
public string ModelName { get; } = modelName;
Expand Down
58 changes: 3 additions & 55 deletions src/Baballonia/Services/ProcessingLoopService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,66 +52,14 @@ public ProcessingLoopService(
dualTransformer.RightTransformer.TargetSize = new Size(128, 128);
EyesProcessingPipeline.ImageTransformer = dualTransformer;

var face = LoadFaceInference();
var eyes = LoadEyeInference();

FaceProcessingPipeline.InferenceService = face;
EyesProcessingPipeline.InferenceService = eyes;

var faceFilter = LoadFaceFilter();
var eyeFilter = LoadEyeFilter();
FaceProcessingPipeline.Filter = faceFilter;
EyesProcessingPipeline.Filter = eyeFilter;
LoadEyeStabilizationSetting();
_ = SetupFaceInference();
_ = SetupEyeInference();

_drawTimer.Tick += TimerEvent;
_drawTimer.Start();
}

private IFilter? LoadFaceFilter()
{
var enabled = _localSettingsService.ReadSetting<bool>("AppSettings_OneEuroEnabled");
var cutoff = _localSettingsService.ReadSetting<float>("AppSettings_OneEuroMinFreqCutoff");
var speedCutoff = _localSettingsService.ReadSetting<float>("AppSettings_OneEuroSpeedCutoff");

if (!enabled)
return null;

float[] faceArray = new float[Utils.FaceRawExpressions];
var faceFilter = new OneEuroFilter(
faceArray,
minCutoff: cutoff,
beta: speedCutoff
);

return faceFilter;
}

private IFilter? LoadEyeFilter()
{
var enabled = _localSettingsService.ReadSetting<bool>("AppSettings_OneEuroEnabled");
var cutoff = _localSettingsService.ReadSetting<float>("AppSettings_OneEuroMinFreqCutoff");
var speedCutoff = _localSettingsService.ReadSetting<float>("AppSettings_OneEuroSpeedCutoff");

if (!enabled)
return null;

float[] eyeArray = new float[Utils.EyeRawExpressions];
var eyeFilter = new OneEuroFilter(
eyeArray,
minCutoff: cutoff,
beta: speedCutoff
);
return eyeFilter;
}


public Task<DefaultInferenceRunner> LoadEyeInferenceAsync()
{
return Task.Run(LoadEyeInference);
}

public DefaultInferenceRunner LoadEyeInference()
public async Task SetupEyeInference()
{
const string defaultEyeModel = "eyeModel.onnx";
var eyeModel = _localSettingsService.ReadSetting<string>("EyeHome_EyeModel", defaultEyeModel);
Expand Down
Loading