Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ license = "MIT"
requires-python = ">=3.10.15"
dynamic = ["version"]
dependencies = [
"ezmsg-baseproc>=1.0.2",
"ezmsg-sigproc>=2.14.0",
"ezmsg-baseproc>=1.3.0",
"ezmsg-sigproc>=2.15.0",
"river>=0.22.0",
"scikit-learn>=1.6.0",
"torch>=2.6.0",
Expand Down
25 changes: 12 additions & 13 deletions src/ezmsg/learn/process/adaptive_linear_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
BaseAdaptiveTransformerUnit,
processor_state,
)
from ezmsg.sigproc.sampler import SampleMessage
from ezmsg.util.messages.axisarray import AxisArray, replace

from ..util import AdaptiveLinearRegressor, RegressorType, get_regressor
Expand Down Expand Up @@ -78,30 +77,30 @@ def _reset_state(self, message: AxisArray) -> None:
# .template is updated in partial_fit
pass

def partial_fit(self, message: SampleMessage) -> None:
if np.any(np.isnan(message.sample.data)):
def partial_fit(self, message: AxisArray) -> None:
if np.any(np.isnan(message.data)):
return

if self.settings.model_type in [
AdaptiveLinearRegressor.LINEAR,
AdaptiveLinearRegressor.LOGISTIC,
]:
x = pd.DataFrame.from_dict({k: v for k, v in zip(message.sample.axes["ch"].data, message.sample.data.T)})
x = pd.DataFrame.from_dict({k: v for k, v in zip(message.axes["ch"].data, message.data.T)})
y = pd.Series(
data=message.trigger.value.data[:, 0],
name=message.trigger.value.axes["ch"].data[0],
data=message.attrs["trigger"].value.data[:, 0],
name=message.attrs["trigger"].value.axes["ch"].data[0],
)
self.state.model.learn_many(x, y)
else:
X = message.sample.data
if message.sample.get_axis_idx("time") != 0:
X = np.moveaxis(X, message.sample.get_axis_idx("time"), 0)
self.state.model.partial_fit(X, message.trigger.value.data)
X = message.data
if message.get_axis_idx("time") != 0:
X = np.moveaxis(X, message.get_axis_idx("time"), 0)
self.state.model.partial_fit(X, message.attrs["trigger"].value.data)

self.state.template = replace(
message.trigger.value,
data=np.empty_like(message.trigger.value.data),
key=message.trigger.value.key + "_pred",
message.attrs["trigger"].value,
data=np.empty_like(message.attrs["trigger"].value.data),
key=message.attrs["trigger"].value.key + "_pred",
)

def _process(self, message: AxisArray) -> AxisArray | None:
Expand Down
13 changes: 6 additions & 7 deletions src/ezmsg/learn/process/linear_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
BaseAdaptiveTransformerUnit,
processor_state,
)
from ezmsg.sigproc.sampler import SampleMessage
from ezmsg.util.messages.axisarray import AxisArray, replace
from sklearn.linear_model._base import LinearModel

Expand Down Expand Up @@ -53,18 +52,18 @@ def _reset_state(self, message: AxisArray) -> None:
# .model and .template are initialized in __init__
pass

def partial_fit(self, message: SampleMessage) -> None:
if np.any(np.isnan(message.sample.data)):
def partial_fit(self, message: AxisArray) -> None:
if np.any(np.isnan(message.data)):
return

X = message.sample.data
y = message.trigger.value.data
X = message.data
y = message.attrs["trigger"].value.data
# TODO: Resample should provide identical durations.
self.state.model = self.state.model.fit(X[: y.shape[0]], y[: X.shape[0]])
self.state.template = replace(
message.trigger.value,
message.attrs["trigger"].value,
data=np.array([[]]),
key=message.trigger.value.key + "_pred",
key=message.attrs["trigger"].value.key + "_pred",
)

def _process(self, message: AxisArray) -> AxisArray:
Expand Down
7 changes: 3 additions & 4 deletions src/ezmsg/learn/process/mlp_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
BaseAdaptiveTransformerUnit,
processor_state,
)
from ezmsg.sigproc.sampler import SampleMessage
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

Expand Down Expand Up @@ -134,14 +133,14 @@ def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
dtype = torch.float32 if self.settings.single_precision else torch.float64
return torch.tensor(data, dtype=dtype, device=self._state.device)

def partial_fit(self, message: SampleMessage) -> None:
def partial_fit(self, message: AxisArray) -> None:
self._state.model.train()

# TODO: loss_fn should be determined by setting
loss_fn = torch.nn.functional.mse_loss

X = self._to_tensor(message.sample.data)
y_targ = self._to_tensor(message.trigger.value)
X = self._to_tensor(message.data)
y_targ = self._to_tensor(message.attrs["trigger"].value)

with torch.set_grad_enabled(True):
self._state.model.train()
Expand Down
11 changes: 5 additions & 6 deletions src/ezmsg/learn/process/refit_kalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
BaseAdaptiveTransformerUnit,
processor_state,
)
from ezmsg.sigproc.sampler import SampleMessage
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

Expand Down Expand Up @@ -284,22 +283,22 @@ def _process(self, message: AxisArray) -> AxisArray:
key=f"{message.key}_filtered" if hasattr(message, "key") else "filtered",
)

def partial_fit(self, message: SampleMessage) -> None:
def partial_fit(self, message: AxisArray) -> None:
"""
Perform refitting using externally provided data.

Expects message.sample.data (neural input) and message.trigger.value as a dict with:
Expects message.data (neural input) and message.attrs["trigger"].value as a dict with:
- Y_state: (n_samples, n_states) array
- intention_velocity_indices: Optional[int]
- target_positions: Optional[np.ndarray]
- cursor_positions: Optional[np.ndarray]
- hold_flags: Optional[list[bool]]
"""
if not hasattr(message, "sample") or not hasattr(message, "trigger"):
if "trigger" not in message.attrs:
raise ValueError("Invalid message format for partial_fit.")

X = np.array(message.sample.data)
values = message.trigger.value
X = np.array(message.data)
values = message.attrs["trigger"].value

if not isinstance(values, dict) or "Y_state" not in values:
raise ValueError("partial_fit expects trigger.value to include at least 'Y_state'.")
Expand Down
9 changes: 4 additions & 5 deletions src/ezmsg/learn/process/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch
from ezmsg.baseproc import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit
from ezmsg.baseproc.util.profile import profile_subpub
from ezmsg.sigproc.sampler import SampleMessage
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

Expand Down Expand Up @@ -184,18 +183,18 @@ def _train_step(
if self._state.scheduler is not None:
self._state.scheduler.step()

def partial_fit(self, message: SampleMessage) -> None:
def partial_fit(self, message: AxisArray) -> None:
self._state.model.train()

X = self._to_tensor(message.sample.data)
X = self._to_tensor(message.data)

# Add batch dimension if missing
X, batched = self._ensure_batched(X)

batch_size = X.shape[0]
preserve_state = self._maybe_reset_state(message.sample, batch_size)
preserve_state = self._maybe_reset_state(message, batch_size)

y_targ = message.trigger.value
y_targ = message.attrs["trigger"].value
if not isinstance(y_targ, dict):
y_targ = {"output": y_targ}
y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
Expand Down
13 changes: 6 additions & 7 deletions src/ezmsg/learn/process/sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from ezmsg.baseproc import (
BaseAdaptiveTransformer,
BaseAdaptiveTransformerUnit,
SampleMessage,
processor_state,
)
from ezmsg.util.messages.axisarray import AxisArray
Expand Down Expand Up @@ -87,23 +86,23 @@ def _process(self, message: AxisArray) -> ClassifierMessage | None:
key=message.key,
)

def partial_fit(self, message: SampleMessage) -> None:
def partial_fit(self, message: AxisArray) -> None:
if self._hash != 0:
self._reset_state(message.sample)
self._reset_state(message)
self._hash = 0

if np.any(np.isnan(message.sample.data)):
if np.any(np.isnan(message.data)):
return
train_sample = message.sample.data.reshape(1, -1)
train_sample = message.data.reshape(1, -1)
if self._state.b_first_train:
self._state.model.partial_fit(
train_sample,
[message.trigger.value],
[message.attrs["trigger"].value],
classes=list(self.settings.label_weights.keys()),
)
self._state.b_first_train = False
else:
self._state.model.partial_fit(train_sample, [message.trigger.value])
self._state.model.partial_fit(train_sample, [message.attrs["trigger"].value])


class SGDDecoder(
Expand Down
17 changes: 8 additions & 9 deletions src/ezmsg/learn/process/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
BaseAdaptiveTransformerUnit,
processor_state,
)
from ezmsg.sigproc.sampler import SampleMessage
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

Expand Down Expand Up @@ -116,25 +115,25 @@ def _reset_state(self, message: AxisArray) -> None:
# No checkpoint, initialize from scratch
self._init_model()

def partial_fit(self, message: SampleMessage) -> None:
X = message.sample.data
y = message.trigger.value
def partial_fit(self, message: AxisArray) -> None:
X = message.data
y = message.attrs["trigger"].value
if self._state.model is None:
self._reset_state(message.sample)
self._reset_state(message)
if hasattr(self._state.model, "partial_fit"):
kwargs = {}
if self.settings.partial_fit_classes is not None:
kwargs["classes"] = self.settings.partial_fit_classes
self._state.model.partial_fit(X, y, **kwargs)
elif hasattr(self._state.model, "learn_many"):
df_X = pd.DataFrame({k: v for k, v in zip(message.sample.axes["ch"].data, message.sample.data.T)})
df_X = pd.DataFrame({k: v for k, v in zip(message.axes["ch"].data, message.data.T)})
name = (
message.trigger.value.axes["ch"].data[0]
if hasattr(message.trigger.value, "axes") and "ch" in message.trigger.value.axes
message.attrs["trigger"].value.axes["ch"].data[0]
if hasattr(message.attrs["trigger"].value, "axes") and "ch" in message.attrs["trigger"].value.axes
else "target"
)
ser_y = pd.Series(
data=np.asarray(message.trigger.value.data).flatten(),
data=np.asarray(message.attrs["trigger"].value.data).flatten(),
name=name,
)
self._state.model.learn_many(df_X, ser_y)
Expand Down
7 changes: 3 additions & 4 deletions src/ezmsg/learn/process/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
processor_state,
)
from ezmsg.baseproc.util.profile import profile_subpub
from ezmsg.sigproc.sampler import SampleMessage
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

Expand Down Expand Up @@ -294,13 +293,13 @@ def _reset_state(self, message: AxisArray) -> None:
def _process(self, message: AxisArray) -> list[AxisArray]:
return self._common_process(message)

def partial_fit(self, message: SampleMessage) -> None:
def partial_fit(self, message: AxisArray) -> None:
self._state.model.train()

X = self._to_tensor(message.sample.data)
X = self._to_tensor(message.data)
X, batched = self._ensure_batched(X)

y_targ = message.trigger.value
y_targ = message.attrs["trigger"].value
if not isinstance(y_targ, dict):
y_targ = {"output": y_targ}
y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
Expand Down
7 changes: 3 additions & 4 deletions src/ezmsg/learn/process/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import torch
from ezmsg.baseproc import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit
from ezmsg.baseproc.util.profile import profile_subpub
from ezmsg.sigproc.sampler import SampleMessage
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

Expand Down Expand Up @@ -125,13 +124,13 @@ def _process(self, message: AxisArray) -> list[AxisArray]:
)
]

def partial_fit(self, message: SampleMessage) -> None:
def partial_fit(self, message: AxisArray) -> None:
self._state.model.train()

X = self._to_tensor(message.sample.data)
X = self._to_tensor(message.data)
X, batched = self._ensure_batched(X)

y_targ = message.trigger.value
y_targ = message.attrs["trigger"].value
if not isinstance(y_targ, dict):
y_targ = {"output": y_targ}
y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_adaptive_linear_regressor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
import pytest
from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage
from ezmsg.baseproc import SampleTriggerMessage
from ezmsg.util.messages.axisarray import AxisArray, replace

from ezmsg.learn.process.adaptive_linear_regressor import (
Expand Down Expand Up @@ -42,7 +42,7 @@ def test_adaptive_linear_regressor(model_type: str):
period=(0.0, dur),
value=value_axarr,
)
samp = SampleMessage(trigger=samp_trig, sample=sig_axarr)
samp = replace(sig_axarr, attrs={"trigger": samp_trig})

proc = AdaptiveLinearRegressorTransformer(model_type=model_type)
_ = proc.send(samp)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_linear_regressor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
import pytest
from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage
from ezmsg.baseproc import SampleTriggerMessage
from ezmsg.util.messages.axisarray import AxisArray, replace

from ezmsg.learn.process.linear_regressor import LinearRegressorTransformer
Expand Down Expand Up @@ -40,7 +40,7 @@ def test_linear_regressor(model_type: str):
period=(0.0, dur),
value=value_axarr,
)
samp = SampleMessage(trigger=samp_trig, sample=sig_axarr)
samp = replace(sig_axarr, attrs={"trigger": samp_trig})

gen = LinearRegressorTransformer(model_type=model_type)
_ = gen.send(samp)
Expand Down
18 changes: 9 additions & 9 deletions tests/unit/test_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def test_mlp_checkpoint_io(tmp_path, sample_input, mlp_settings):


def test_mlp_partial_fit_learns(sample_input, mlp_settings):
from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage
from ezmsg.baseproc import SampleTriggerMessage
from ezmsg.util.messages.util import replace

proc = TorchModelProcessor(
model_class="ezmsg.learn.model.mlp.MLP",
Expand All @@ -88,13 +89,12 @@ def test_mlp_partial_fit_learns(sample_input, mlp_settings):
)
proc(sample_input)

sample = AxisArray(
data=sample_input.data[:1], dims=["time", "ch"], axes=sample_input.axes
)
sample = AxisArray(data=sample_input.data[:1], dims=["time", "ch"], axes=sample_input.axes)
target = np.random.randn(1, 5)

msg = SampleMessage(
sample=sample, trigger=SampleTriggerMessage(timestamp=0.0, value=target)
msg = replace(
sample,
attrs={**sample.attrs, "trigger": SampleTriggerMessage(timestamp=0.0, value=target)},
)

before = [p.detach().clone() for p in proc.state.model.parameters()]
Expand Down Expand Up @@ -135,9 +135,9 @@ def test_mlp_hidden_size_integer(sample_input):
device="cpu",
)
proc(sample_input)
hidden_layers = [
m for m in proc._state.model.modules() if isinstance(m, torch.nn.Linear)
][:-1] # Exclude the output head
hidden_layers = [m for m in proc._state.model.modules() if isinstance(m, torch.nn.Linear)][
:-1
] # Exclude the output head
assert len(hidden_layers) == 3 # num_layers = 3
assert hidden_layers[0].in_features == 8
assert all(layer.out_features == 32 for layer in hidden_layers[:-1])
Expand Down
Loading