diff --git a/pyproject.toml b/pyproject.toml
index df3a1dd..e641068 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,8 +9,8 @@ license = "MIT"
 requires-python = ">=3.10.15"
 dynamic = ["version"]
 dependencies = [
-    "ezmsg-baseproc>=1.0.2",
-    "ezmsg-sigproc>=2.14.0",
+    "ezmsg-baseproc>=1.3.0",
+    "ezmsg-sigproc>=2.15.0",
     "river>=0.22.0",
     "scikit-learn>=1.6.0",
     "torch>=2.6.0",
diff --git a/src/ezmsg/learn/process/adaptive_linear_regressor.py b/src/ezmsg/learn/process/adaptive_linear_regressor.py
index dbfbb00..3e00909 100644
--- a/src/ezmsg/learn/process/adaptive_linear_regressor.py
+++ b/src/ezmsg/learn/process/adaptive_linear_regressor.py
@@ -11,7 +11,6 @@
     BaseAdaptiveTransformerUnit,
     processor_state,
 )
-from ezmsg.sigproc.sampler import SampleMessage
 from ezmsg.util.messages.axisarray import AxisArray, replace

 from ..util import AdaptiveLinearRegressor, RegressorType, get_regressor
@@ -78,30 +77,30 @@ def _reset_state(self, message: AxisArray) -> None:
         # .template is updated in partial_fit
         pass

-    def partial_fit(self, message: SampleMessage) -> None:
-        if np.any(np.isnan(message.sample.data)):
+    def partial_fit(self, message: AxisArray) -> None:
+        if np.any(np.isnan(message.data)):
             return
         if self.settings.model_type in [
             AdaptiveLinearRegressor.LINEAR,
             AdaptiveLinearRegressor.LOGISTIC,
         ]:
-            x = pd.DataFrame.from_dict({k: v for k, v in zip(message.sample.axes["ch"].data, message.sample.data.T)})
+            x = pd.DataFrame.from_dict({k: v for k, v in zip(message.axes["ch"].data, message.data.T)})
             y = pd.Series(
-                data=message.trigger.value.data[:, 0],
-                name=message.trigger.value.axes["ch"].data[0],
+                data=message.attrs["trigger"].value.data[:, 0],
+                name=message.attrs["trigger"].value.axes["ch"].data[0],
             )
             self.state.model.learn_many(x, y)
         else:
-            X = message.sample.data
-            if message.sample.get_axis_idx("time") != 0:
-                X = np.moveaxis(X, message.sample.get_axis_idx("time"), 0)
-            self.state.model.partial_fit(X, message.trigger.value.data)
+            X = message.data
+            if message.get_axis_idx("time") != 0:
+                X = np.moveaxis(X, message.get_axis_idx("time"), 0)
+            self.state.model.partial_fit(X, message.attrs["trigger"].value.data)

         self.state.template = replace(
-            message.trigger.value,
-            data=np.empty_like(message.trigger.value.data),
-            key=message.trigger.value.key + "_pred",
+            message.attrs["trigger"].value,
+            data=np.empty_like(message.attrs["trigger"].value.data),
+            key=message.attrs["trigger"].value.key + "_pred",
         )

     def _process(self, message: AxisArray) -> AxisArray | None:
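Note: the hunk above is the core of this patch. partial_fit no longer takes a SampleMessage; the training target now travels on the sample itself as attrs["trigger"]. A minimal sketch of the new convention, assuming ezmsg-baseproc >= 1.3.0 (which, per this diff, exports SampleTriggerMessage) and the AxisArray attrs mapping used throughout; shapes and names are illustrative only:

import numpy as np
from ezmsg.baseproc import SampleTriggerMessage
from ezmsg.util.messages.axisarray import AxisArray, replace

# Feature stream (the former SampleMessage.sample); shapes are illustrative.
sig = AxisArray(data=np.random.randn(10, 4), dims=["time", "ch"])
# Training target (the former SampleMessage.trigger.value).
target = AxisArray(data=np.random.randn(10, 2), dims=["time", "ch"])

# Old style (removed): SampleMessage(trigger=SampleTriggerMessage(...), sample=sig)
# New style: the trigger rides along in the sample's attrs.
train_msg = replace(
    sig,
    attrs={**sig.attrs, "trigger": SampleTriggerMessage(timestamp=0.0, value=target)},
)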
diff --git a/src/ezmsg/learn/process/linear_regressor.py b/src/ezmsg/learn/process/linear_regressor.py
index 49567f5..873406a 100644
--- a/src/ezmsg/learn/process/linear_regressor.py
+++ b/src/ezmsg/learn/process/linear_regressor.py
@@ -7,7 +7,6 @@
     BaseAdaptiveTransformerUnit,
     processor_state,
 )
-from ezmsg.sigproc.sampler import SampleMessage
 from ezmsg.util.messages.axisarray import AxisArray, replace
 from sklearn.linear_model._base import LinearModel

@@ -53,18 +52,18 @@ def _reset_state(self, message: AxisArray) -> None:
         # .model and .template are initialized in __init__
         pass

-    def partial_fit(self, message: SampleMessage) -> None:
-        if np.any(np.isnan(message.sample.data)):
+    def partial_fit(self, message: AxisArray) -> None:
+        if np.any(np.isnan(message.data)):
             return

-        X = message.sample.data
-        y = message.trigger.value.data
+        X = message.data
+        y = message.attrs["trigger"].value.data
         # TODO: Resample should provide identical durations.
         self.state.model = self.state.model.fit(X[: y.shape[0]], y[: X.shape[0]])

         self.state.template = replace(
-            message.trigger.value,
+            message.attrs["trigger"].value,
             data=np.array([[]]),
-            key=message.trigger.value.key + "_pred",
+            key=message.attrs["trigger"].value.key + "_pred",
         )

     def _process(self, message: AxisArray) -> AxisArray:
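Note: the mutual slicing in the fit call above is a min-length alignment. Each array is cut to the other's length, so the small duration mismatches flagged by the TODO cannot break fit. A toy illustration with hypothetical shapes:

import numpy as np

X = np.arange(12).reshape(6, 2)  # 6 time samples of features
y = np.arange(10).reshape(5, 2)  # 5 target samples (one short)

# Slicing each by the other's length truncates both to the common minimum.
X_fit, y_fit = X[: y.shape[0]], y[: X.shape[0]]
assert X_fit.shape[0] == y_fit.shape[0] == min(X.shape[0], y.shape[0])  # 5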
diff --git a/src/ezmsg/learn/process/mlp_old.py b/src/ezmsg/learn/process/mlp_old.py
index b059cbf..b8c87d3 100644
--- a/src/ezmsg/learn/process/mlp_old.py
+++ b/src/ezmsg/learn/process/mlp_old.py
@@ -9,7 +9,6 @@
     BaseAdaptiveTransformerUnit,
     processor_state,
 )
-from ezmsg.sigproc.sampler import SampleMessage
 from ezmsg.util.messages.axisarray import AxisArray
 from ezmsg.util.messages.util import replace

@@ -134,14 +133,14 @@ def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
         dtype = torch.float32 if self.settings.single_precision else torch.float64
         return torch.tensor(data, dtype=dtype, device=self._state.device)

-    def partial_fit(self, message: SampleMessage) -> None:
+    def partial_fit(self, message: AxisArray) -> None:
         self._state.model.train()

         # TODO: loss_fn should be determined by setting
         loss_fn = torch.nn.functional.mse_loss

-        X = self._to_tensor(message.sample.data)
-        y_targ = self._to_tensor(message.trigger.value)
+        X = self._to_tensor(message.data)
+        y_targ = self._to_tensor(message.attrs["trigger"].value)

         with torch.set_grad_enabled(True):
             self._state.model.train()
diff --git a/src/ezmsg/learn/process/refit_kalman.py b/src/ezmsg/learn/process/refit_kalman.py
index bc04d12..6684bc9 100644
--- a/src/ezmsg/learn/process/refit_kalman.py
+++ b/src/ezmsg/learn/process/refit_kalman.py
@@ -8,7 +8,6 @@
     BaseAdaptiveTransformerUnit,
     processor_state,
 )
-from ezmsg.sigproc.sampler import SampleMessage
 from ezmsg.util.messages.axisarray import AxisArray
 from ezmsg.util.messages.util import replace

@@ -284,22 +283,22 @@ def _process(self, message: AxisArray) -> AxisArray:
             key=f"{message.key}_filtered" if hasattr(message, "key") else "filtered",
         )

-    def partial_fit(self, message: SampleMessage) -> None:
+    def partial_fit(self, message: AxisArray) -> None:
         """
         Perform refitting using externally provided data.

-        Expects message.sample.data (neural input) and message.trigger.value as a dict with:
+        Expects message.data (neural input) and message.attrs["trigger"].value as a dict with:
         - Y_state: (n_samples, n_states) array
         - intention_velocity_indices: Optional[int]
         - target_positions: Optional[np.ndarray]
         - cursor_positions: Optional[np.ndarray]
         - hold_flags: Optional[list[bool]]
         """
-        if not hasattr(message, "sample") or not hasattr(message, "trigger"):
+        if "trigger" not in message.attrs:
             raise ValueError("Invalid message format for partial_fit.")

-        X = np.array(message.sample.data)
-        values = message.trigger.value
+        X = np.array(message.data)
+        values = message.attrs["trigger"].value

         if not isinstance(values, dict) or "Y_state" not in values:
             raise ValueError("partial_fit expects trigger.value to include at least 'Y_state'.")
diff --git a/src/ezmsg/learn/process/rnn.py b/src/ezmsg/learn/process/rnn.py
index ec1f3da..b3cfd62 100644
--- a/src/ezmsg/learn/process/rnn.py
+++ b/src/ezmsg/learn/process/rnn.py
@@ -5,7 +5,6 @@
 import torch
 from ezmsg.baseproc import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit
 from ezmsg.baseproc.util.profile import profile_subpub
-from ezmsg.sigproc.sampler import SampleMessage
 from ezmsg.util.messages.axisarray import AxisArray
 from ezmsg.util.messages.util import replace

@@ -184,18 +183,18 @@ def _train_step(
         if self._state.scheduler is not None:
             self._state.scheduler.step()

-    def partial_fit(self, message: SampleMessage) -> None:
+    def partial_fit(self, message: AxisArray) -> None:
         self._state.model.train()

-        X = self._to_tensor(message.sample.data)
+        X = self._to_tensor(message.data)

         # Add batch dimension if missing
         X, batched = self._ensure_batched(X)
         batch_size = X.shape[0]

-        preserve_state = self._maybe_reset_state(message.sample, batch_size)
+        preserve_state = self._maybe_reset_state(message, batch_size)

-        y_targ = message.trigger.value
+        y_targ = message.attrs["trigger"].value
         if not isinstance(y_targ, dict):
             y_targ = {"output": y_targ}
         y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
diff --git a/src/ezmsg/learn/process/sgd.py b/src/ezmsg/learn/process/sgd.py
index d00b39c..baedab5 100644
--- a/src/ezmsg/learn/process/sgd.py
+++ b/src/ezmsg/learn/process/sgd.py
@@ -5,7 +5,6 @@
 from ezmsg.baseproc import (
     BaseAdaptiveTransformer,
     BaseAdaptiveTransformerUnit,
-    SampleMessage,
     processor_state,
 )
 from ezmsg.util.messages.axisarray import AxisArray
@@ -87,23 +86,23 @@ def _process(self, message: AxisArray) -> ClassifierMessage | None:
             key=message.key,
         )

-    def partial_fit(self, message: SampleMessage) -> None:
+    def partial_fit(self, message: AxisArray) -> None:
         if self._hash != 0:
-            self._reset_state(message.sample)
+            self._reset_state(message)
             self._hash = 0
-        if np.any(np.isnan(message.sample.data)):
+        if np.any(np.isnan(message.data)):
             return
-        train_sample = message.sample.data.reshape(1, -1)
+        train_sample = message.data.reshape(1, -1)
         if self._state.b_first_train:
             self._state.model.partial_fit(
                 train_sample,
-                [message.trigger.value],
+                [message.attrs["trigger"].value],
                 classes=list(self.settings.label_weights.keys()),
             )
             self._state.b_first_train = False
         else:
-            self._state.model.partial_fit(train_sample, [message.trigger.value])
+            self._state.model.partial_fit(train_sample, [message.attrs["trigger"].value])


 class SGDDecoder(
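Note: the b_first_train branch above follows scikit-learn's incremental-learning contract: the first call to partial_fit on a classifier must declare the complete label set via classes, and later calls may omit it. A standalone sketch with illustrative data:

import numpy as np
from sklearn.linear_model import SGDClassifier

clf = SGDClassifier()
# First call: declare every class the stream may ever produce.
clf.partial_fit(np.random.randn(1, 4), ["a"], classes=["a", "b"])
# Subsequent calls: no classes kwarg needed.
clf.partial_fit(np.random.randn(1, 4), ["b"])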
diff --git a/src/ezmsg/learn/process/sklearn.py b/src/ezmsg/learn/process/sklearn.py
index 37d0b23..15a4596 100644
--- a/src/ezmsg/learn/process/sklearn.py
+++ b/src/ezmsg/learn/process/sklearn.py
@@ -10,7 +10,6 @@
     BaseAdaptiveTransformerUnit,
     processor_state,
 )
-from ezmsg.sigproc.sampler import SampleMessage
 from ezmsg.util.messages.axisarray import AxisArray
 from ezmsg.util.messages.util import replace

@@ -116,25 +115,25 @@ def _reset_state(self, message: AxisArray) -> None:
             # No checkpoint, initialize from scratch
             self._init_model()

-    def partial_fit(self, message: SampleMessage) -> None:
-        X = message.sample.data
-        y = message.trigger.value
+    def partial_fit(self, message: AxisArray) -> None:
+        X = message.data
+        y = message.attrs["trigger"].value
         if self._state.model is None:
-            self._reset_state(message.sample)
+            self._reset_state(message)
         if hasattr(self._state.model, "partial_fit"):
             kwargs = {}
             if self.settings.partial_fit_classes is not None:
                 kwargs["classes"] = self.settings.partial_fit_classes
             self._state.model.partial_fit(X, y, **kwargs)
         elif hasattr(self._state.model, "learn_many"):
-            df_X = pd.DataFrame({k: v for k, v in zip(message.sample.axes["ch"].data, message.sample.data.T)})
+            df_X = pd.DataFrame({k: v for k, v in zip(message.axes["ch"].data, message.data.T)})
             name = (
-                message.trigger.value.axes["ch"].data[0]
-                if hasattr(message.trigger.value, "axes") and "ch" in message.trigger.value.axes
+                message.attrs["trigger"].value.axes["ch"].data[0]
+                if hasattr(message.attrs["trigger"].value, "axes") and "ch" in message.attrs["trigger"].value.axes
                 else "target"
             )
             ser_y = pd.Series(
-                data=np.asarray(message.trigger.value.data).flatten(),
+                data=np.asarray(message.attrs["trigger"].value.data).flatten(),
                 name=name,
             )
             self._state.model.learn_many(df_X, ser_y)
diff --git a/src/ezmsg/learn/process/torch.py b/src/ezmsg/learn/process/torch.py
index db62022..ed3ef14 100644
--- a/src/ezmsg/learn/process/torch.py
+++ b/src/ezmsg/learn/process/torch.py
@@ -12,7 +12,6 @@
     processor_state,
 )
 from ezmsg.baseproc.util.profile import profile_subpub
-from ezmsg.sigproc.sampler import SampleMessage
 from ezmsg.util.messages.axisarray import AxisArray
 from ezmsg.util.messages.util import replace

@@ -294,13 +293,13 @@ def _reset_state(self, message: AxisArray) -> None:
     def _process(self, message: AxisArray) -> list[AxisArray]:
         return self._common_process(message)

-    def partial_fit(self, message: SampleMessage) -> None:
+    def partial_fit(self, message: AxisArray) -> None:
         self._state.model.train()

-        X = self._to_tensor(message.sample.data)
+        X = self._to_tensor(message.data)
         X, batched = self._ensure_batched(X)

-        y_targ = message.trigger.value
+        y_targ = message.attrs["trigger"].value
         if not isinstance(y_targ, dict):
             y_targ = {"output": y_targ}
         y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
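Note: the y_targ normalization above (wrapping a bare array as {"output": ...}) means one trigger can carry either a single target or a dict of per-head targets for multi-head models. Both forms below would be accepted; the head names come from the tests in this patch and must match the model's output heads:

import numpy as np
from ezmsg.baseproc import SampleTriggerMessage

# Single-head target: gets wrapped as {"output": value} internally.
single = SampleTriggerMessage(timestamp=0.0, value=np.random.randn(1, 2))

# Multi-head targets: one entry per output head (shapes illustrative).
multi = SampleTriggerMessage(
    timestamp=0.0,
    value={"head_a": np.random.randn(1, 2), "head_b": np.random.randn(1, 3)},
)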
diff --git a/src/ezmsg/learn/process/transformer.py b/src/ezmsg/learn/process/transformer.py
index 9784192..b3abeb8 100644
--- a/src/ezmsg/learn/process/transformer.py
+++ b/src/ezmsg/learn/process/transformer.py
@@ -4,7 +4,6 @@
 import torch
 from ezmsg.baseproc import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit
 from ezmsg.baseproc.util.profile import profile_subpub
-from ezmsg.sigproc.sampler import SampleMessage
 from ezmsg.util.messages.axisarray import AxisArray
 from ezmsg.util.messages.util import replace

@@ -125,13 +124,13 @@ def _process(self, message: AxisArray) -> list[AxisArray]:
             )
         ]

-    def partial_fit(self, message: SampleMessage) -> None:
+    def partial_fit(self, message: AxisArray) -> None:
         self._state.model.train()

-        X = self._to_tensor(message.sample.data)
+        X = self._to_tensor(message.data)
         X, batched = self._ensure_batched(X)

-        y_targ = message.trigger.value
+        y_targ = message.attrs["trigger"].value
         if not isinstance(y_targ, dict):
             y_targ = {"output": y_targ}
         y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
diff --git a/tests/unit/test_adaptive_linear_regressor.py b/tests/unit/test_adaptive_linear_regressor.py
index d699f3a..6098aaa 100644
--- a/tests/unit/test_adaptive_linear_regressor.py
+++ b/tests/unit/test_adaptive_linear_regressor.py
@@ -1,6 +1,6 @@
 import numpy as np
 import pytest
-from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage
+from ezmsg.baseproc import SampleTriggerMessage
 from ezmsg.util.messages.axisarray import AxisArray, replace

 from ezmsg.learn.process.adaptive_linear_regressor import (
@@ -42,7 +42,7 @@ def test_adaptive_linear_regressor(model_type: str):
         period=(0.0, dur),
         value=value_axarr,
     )
-    samp = SampleMessage(trigger=samp_trig, sample=sig_axarr)
+    samp = replace(sig_axarr, attrs={"trigger": samp_trig})

     proc = AdaptiveLinearRegressorTransformer(model_type=model_type)
     _ = proc.send(samp)
diff --git a/tests/unit/test_linear_regressor.py b/tests/unit/test_linear_regressor.py
index 3b7ba8b..0bfbeff 100644
--- a/tests/unit/test_linear_regressor.py
+++ b/tests/unit/test_linear_regressor.py
@@ -1,6 +1,6 @@
 import numpy as np
 import pytest
-from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage
+from ezmsg.baseproc import SampleTriggerMessage
 from ezmsg.util.messages.axisarray import AxisArray, replace

 from ezmsg.learn.process.linear_regressor import LinearRegressorTransformer
@@ -40,7 +40,7 @@ def test_linear_regressor(model_type: str):
         period=(0.0, dur),
         value=value_axarr,
     )
-    samp = SampleMessage(trigger=samp_trig, sample=sig_axarr)
+    samp = replace(sig_axarr, attrs={"trigger": samp_trig})

     gen = LinearRegressorTransformer(model_type=model_type)
     _ = gen.send(samp)
diff --git a/tests/unit/test_mlp.py b/tests/unit/test_mlp.py
index 1479068..e40e957 100644
--- a/tests/unit/test_mlp.py
+++ b/tests/unit/test_mlp.py
@@ -77,7 +77,8 @@ def test_mlp_checkpoint_io(tmp_path, sample_input, mlp_settings):


 def test_mlp_partial_fit_learns(sample_input, mlp_settings):
-    from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage
+    from ezmsg.baseproc import SampleTriggerMessage
+    from ezmsg.util.messages.util import replace

     proc = TorchModelProcessor(
         model_class="ezmsg.learn.model.mlp.MLP",
@@ -88,13 +89,12 @@ def test_mlp_partial_fit_learns(sample_input, mlp_settings):
     )
     proc(sample_input)

-    sample = AxisArray(
-        data=sample_input.data[:1], dims=["time", "ch"], axes=sample_input.axes
-    )
+    sample = AxisArray(data=sample_input.data[:1], dims=["time", "ch"], axes=sample_input.axes)
     target = np.random.randn(1, 5)

-    msg = SampleMessage(
-        sample=sample, trigger=SampleTriggerMessage(timestamp=0.0, value=target)
+    msg = replace(
+        sample,
+        attrs={**sample.attrs, "trigger": SampleTriggerMessage(timestamp=0.0, value=target)},
     )

     before = [p.detach().clone() for p in proc.state.model.parameters()]
@@ -135,9 +135,9 @@ def test_mlp_hidden_size_integer(sample_input):
         device="cpu",
     )
     proc(sample_input)
-    hidden_layers = [
-        m for m in proc._state.model.modules() if isinstance(m, torch.nn.Linear)
-    ][:-1]  # Exclude the output head
+    hidden_layers = [m for m in proc._state.model.modules() if isinstance(m, torch.nn.Linear)][
+        :-1
+    ]  # Exclude the output head
     assert len(hidden_layers) == 3  # num_layers = 3
     assert hidden_layers[0].in_features == 8
     assert all(layer.out_features == 32 for layer in hidden_layers[:-1])
diff --git a/tests/unit/test_mlp_old.py b/tests/unit/test_mlp_old.py
index 9b6172e..169ec9a 100644
--- a/tests/unit/test_mlp_old.py
+++ b/tests/unit/test_mlp_old.py
@@ -4,8 +4,9 @@
 import pytest
 import torch
 import torch.nn
-from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage
+from ezmsg.baseproc import SampleTriggerMessage
 from ezmsg.util.messages.axisarray import AxisArray
+from ezmsg.util.messages.util import replace
 from sklearn.model_selection import train_test_split

 from ezmsg.learn.process.mlp_old import MLPProcessor
@@ -146,7 +147,10 @@ def xy_gen(set: int = 0):
         template.data[:] = X  # This would fail if n_samps / batch_size had a remainder.
         template.axes["time"].offset = ts
         if set == 0:
-            yield SampleMessage(trigger=SampleTriggerMessage(timestamp=ts, value=y), sample=template)
+            yield replace(
+                template,
+                attrs={**template.attrs, "trigger": SampleTriggerMessage(timestamp=ts, value=y)},
+            )
         else:
             yield template, y

@@ -167,14 +171,15 @@ def xy_gen(set: int = 0):
     result = []
     train_loss = []
     for sample_msg in xy_gen(set=0):
-        # Naive closed-loop inference
-        result.append(proc(sample_msg.sample))
+        # Naive closed-loop inference — strip trigger attrs before inference
+        plain_msg = replace(sample_msg, attrs={})
+        result.append(proc(plain_msg))

         # Collect the loss to see if it decreases with training.
         train_loss.append(
             torch.nn.MSELoss()(
                 torch.tensor(result[-1].data),
-                torch.tensor(sample_msg.trigger.value.reshape(-1, 1), dtype=torch.float32),
+                torch.tensor(sample_msg.attrs["trigger"].value.reshape(-1, 1), dtype=torch.float32),
             ).item()
         )
diff --git a/tests/unit/test_refit_kalman.py b/tests/unit/test_refit_kalman.py
index bb3d83a..a4ea4ee 100644
--- a/tests/unit/test_refit_kalman.py
+++ b/tests/unit/test_refit_kalman.py
@@ -4,6 +4,7 @@

 import numpy as np
 import pytest
+from ezmsg.baseproc import SampleTriggerMessage
 from ezmsg.util.messages.axisarray import AxisArray

 from ezmsg.learn.process.refit_kalman import (
@@ -299,12 +300,6 @@ def test_partial_fit_functionality(create_test_message, checkpoint_file):
     H_initial = checkpoint_data["H_observation_matrix"]
     Q_initial = checkpoint_data["Q_measurement_noise_covariance"]

-    # Create a mock SampleMessage with the expected structure
-    class MockSampleMessage:
-        def __init__(self, neural_data, trigger_value):
-            self.sample = type("obj", (object,), {"data": neural_data})()
-            self.trigger = type("obj", (object,), {"value": trigger_value})()
-
     # Create test data
     neural_data = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])  # 3 samples, 2 channels
     trigger_value = {
@@ -315,8 +310,12 @@ def __init__(self, neural_data, trigger_value):
         "hold_flags": [False, False, False],
     }

-    mock_message = MockSampleMessage(neural_data, trigger_value)
-    processor.partial_fit(mock_message)
+    sample_msg = AxisArray(
+        data=neural_data,
+        dims=["time", "ch"],
+        attrs={"trigger": SampleTriggerMessage(timestamp=0.0, value=trigger_value)},
+    )
+    processor.partial_fit(sample_msg)

     assert not np.allclose(H_initial, processor._state.model.H_observation_matrix)
     assert not np.allclose(Q_initial, processor._state.model.Q_measurement_noise_covariance)
diff --git a/tests/unit/test_rnn.py b/tests/unit/test_rnn.py
index f490d20..66cf42c 100644
--- a/tests/unit/test_rnn.py
+++ b/tests/unit/test_rnn.py
@@ -5,8 +5,9 @@
 import pytest
 import torch
 import torch.nn
-from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage
+from ezmsg.baseproc import SampleTriggerMessage
 from ezmsg.util.messages.axisarray import AxisArray
+from ezmsg.util.messages.util import replace

 from ezmsg.learn.process.rnn import RNNProcessor
@@ -107,9 +108,7 @@ def test_rnn_process(rnn_type, simple_message):
     # We don't pass in the hx state so it should be initialized to zeros, same as in the first call to proc.
     in_tensor = torch.tensor(simple_message.data[None, ...], dtype=torch.float32)
     with torch.no_grad():
-        expected_result = (
-            proc.state.model(in_tensor)[0]["output"].cpu().numpy().squeeze(0)
-        )
+        expected_result = proc.state.model(in_tensor)[0]["output"].cpu().numpy().squeeze(0)

     assert np.allclose(output.data, expected_result)
@@ -139,9 +138,9 @@ def test_rnn_partial_fit(simple_message):
     target_shape = (simple_message.data.shape[0], output_size)
     target_value = np.ones(target_shape, dtype=np.float32)

-    sample_message = SampleMessage(
-        trigger=SampleTriggerMessage(timestamp=0.0, value=target_value),
-        sample=simple_message,
+    sample_message = replace(
+        simple_message,
+        attrs={**simple_message.attrs, "trigger": SampleTriggerMessage(timestamp=0.0, value=target_value)},
     )

     proc(sample_message)
@@ -149,9 +148,7 @@ def test_rnn_partial_fit(simple_message):
     assert not proc.state.model.training

     updated_weights = [p.detach() for p in proc.state.model.parameters()]
-    assert any(
-        not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights)
-    )
+    assert any(not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights))


 def test_rnn_checkpoint_save_load(simple_message):
@@ -201,9 +198,7 @@ def test_rnn_checkpoint_save_load(simple_message):
         for key in state_dict1:
             assert key in state_dict2, f"Missing key {key} in loaded state_dict"
-            assert torch.equal(state_dict1[key], state_dict2[key]), (
-                f"Mismatch in parameter {key}"
-            )
+            assert torch.equal(state_dict1[key], state_dict2[key]), f"Mismatch in parameter {key}"

     finally:
         # Ensure the temporary file is deleted
@@ -244,20 +239,21 @@ def test_rnn_partial_fit_multiloss(simple_message):
         dtype=torch.long,
     )

-    sample_message = SampleMessage(
-        trigger=SampleTriggerMessage(
-            timestamp=0.0,
-            value={"traj": traj_target, "state": state_target},
-        ),
-        sample=simple_message,
+    sample_message = replace(
+        simple_message,
+        attrs={
+            **simple_message.attrs,
+            "trigger": SampleTriggerMessage(
+                timestamp=0.0,
+                value={"traj": traj_target, "state": state_target},
+            ),
+        },
     )

     proc.partial_fit(sample_message)

     updated_weights = [p.detach() for p in proc.state.model.parameters()]
-    assert any(
-        not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights)
-    )
+    assert any(not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights))


 @pytest.mark.parametrize(
@@ -269,9 +265,7 @@ def test_rnn_partial_fit_multiloss(simple_message):
         ("auto", 0.05, 0.1, False),  # overlapping → reset
     ],
 )
-def test_rnn_preserve_state(
-    preserve_state_across_windows, win_stride, win_len, should_preserve
-):
+def test_rnn_preserve_state(preserve_state_across_windows, win_stride, win_len, should_preserve):
     hidden_size = 16
     num_layers = 1
     output_size = 2
diff --git a/tests/unit/test_sgd.py b/tests/unit/test_sgd.py
index 0785224..b497e28 100644
--- a/tests/unit/test_sgd.py
+++ b/tests/unit/test_sgd.py
@@ -1,5 +1,5 @@
 import numpy as np
-from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage
+from ezmsg.baseproc import SampleTriggerMessage
 from ezmsg.util.messages.axisarray import AxisArray

 from ezmsg.learn.process.sgd import SGDDecoderSettings, SGDDecoderTransformer
@@ -13,9 +13,10 @@ def test_sgd():
         data = np.random.normal(scale=0.05, size=(3, 2, 1))
         data[time_idx[label] : time_idx[label] + 1, 0, 0] += 1.0
         samples.append(
-            SampleMessage(
-                trigger=SampleTriggerMessage(timestamp=len(samples), period=None, value=label),
-                sample=AxisArray(data=data, dims=["time", "ch", "freq"]),
+            AxisArray(
+                data=data,
+                dims=["time", "ch", "freq"],
+                attrs={"trigger": SampleTriggerMessage(timestamp=len(samples), period=None, value=label)},
             )
         )
diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py
index 398d335..ca32c72 100644
--- a/tests/unit/test_sklearn.py
+++ b/tests/unit/test_sklearn.py
@@ -1,7 +1,8 @@
 import numpy as np
 import pytest
-from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage
+from ezmsg.baseproc import SampleTriggerMessage
 from ezmsg.util.messages.axisarray import AxisArray
+from ezmsg.util.messages.util import replace

 from ezmsg.learn.process.sklearn import SklearnModelProcessor
@@ -83,9 +84,9 @@ def test_partial_fit_supported_models(
     proc = SklearnModelProcessor(**settings_kwargs)
     proc._reset_state(input_axisarray)

-    sample_msg = SampleMessage(
-        sample=input_axisarray,
-        trigger=SampleTriggerMessage(timestamp=0.0, value=labels),
+    sample_msg = replace(
+        input_axisarray,
+        attrs={**input_axisarray.attrs, "trigger": SampleTriggerMessage(timestamp=0.0, value=labels)},
     )

     proc.partial_fit(sample_msg)
@@ -96,9 +97,9 @@ def test_partial_fit_supported_models(
 def test_partial_fit_unsupported_model(input_axisarray, labels_regression):
     proc = SklearnModelProcessor(model_class="sklearn.linear_model.Ridge")
     proc._reset_state(input_axisarray)
-    sample_msg = SampleMessage(
-        sample=input_axisarray,
-        trigger=SampleTriggerMessage(timestamp=0.0, value=labels_regression),
+    sample_msg = replace(
+        input_axisarray,
+        attrs={**input_axisarray.attrs, "trigger": SampleTriggerMessage(timestamp=0.0, value=labels_regression)},
     )

     with pytest.raises(NotImplementedError, match="partial_fit"):
         proc.partial_fit(sample_msg)
@@ -108,9 +109,9 @@ def test_partial_fit_changes_model(input_axisarray, labels_regression):
     proc = SklearnModelProcessor(model_class="sklearn.linear_model.SGDRegressor")
     proc._reset_state(input_axisarray)

-    sample_msg = SampleMessage(
-        sample=input_axisarray,
-        trigger=SampleTriggerMessage(timestamp=0.0, value=labels_regression),
+    sample_msg = replace(
+        input_axisarray,
+        attrs={**input_axisarray.attrs, "trigger": SampleTriggerMessage(timestamp=0.0, value=labels_regression)},
     )

     proc.partial_fit(sample_msg)
@@ -127,9 +128,7 @@ def test_model_save_and_load(tmp_path, input_axisarray):
     checkpoint_path = tmp_path / "model_checkpoint.pkl"
     proc.save_checkpoint(str(checkpoint_path))

-    new_proc = SklearnModelProcessor(
-        model_class="sklearn.linear_model.Ridge", checkpoint_path=str(checkpoint_path)
-    )
+    new_proc = SklearnModelProcessor(model_class="sklearn.linear_model.Ridge", checkpoint_path=str(checkpoint_path))
     new_proc._reset_state(input_axisarray)

     assert new_proc._state.model is not None
diff --git a/tests/unit/test_torch.py b/tests/unit/test_torch.py
index 6d36228..6fb0789 100644
--- a/tests/unit/test_torch.py
+++ b/tests/unit/test_torch.py
@@ -5,8 +5,9 @@
 import numpy as np
 import pytest
 import torch
-from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage
+from ezmsg.baseproc import SampleTriggerMessage
 from ezmsg.util.messages.axisarray import AxisArray
+from ezmsg.util.messages.util import replace

 from ezmsg.learn.process.torch import TorchModelProcessor
@@ -185,9 +186,9 @@ def test_partial_fit_changes_weights(batch_message, device):
         },
     )

-    msg = SampleMessage(
-        sample=sample,
-        trigger=SampleTriggerMessage(timestamp=0.0, value=y),
+    msg = replace(
+        sample,
+        attrs={"trigger": SampleTriggerMessage(timestamp=0.0, value=y)},
     )

     proc(sample)  # run forward pass once to init model
@@ -318,14 +319,11 @@ def test_multihead_partial_fit_with_loss_dict(batch_message, device):
         "head_a": np.random.randn(1, 2),
         "head_b": np.random.randn(1, 3),
     }
-    sample = AxisArray(
+    msg = AxisArray(
         data=batch_message.data[:1],
         dims=["time", "ch"],
         axes=batch_message.axes,
-    )
-    msg = SampleMessage(
-        sample=sample,
-        trigger=SampleTriggerMessage(timestamp=0.0, value=y_targ),
+        attrs={"trigger": SampleTriggerMessage(timestamp=0.0, value=y_targ)},
     )

     before_a = proc._state.model.head_a.weight.clone()
@@ -360,14 +358,11 @@ def test_partial_fit_with_loss_weights(batch_message, device):
         "head_a": np.random.randn(1, 2),
         "head_b": np.random.randn(1, 3),
     }
-    sample = AxisArray(
+    msg = AxisArray(
         data=batch_message.data[:1],
         dims=["time", "ch"],
         axes=batch_message.axes,
-    )
-    msg = SampleMessage(
-        sample=sample,
-        trigger=SampleTriggerMessage(timestamp=0.0, value=y_targ),
+        attrs={"trigger": SampleTriggerMessage(timestamp=0.0, value=y_targ)},
     )

     # Expect no error, and just run once
diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py
index 71952b0..73efd6c 100644
--- a/tests/unit/test_transformer.py
+++ b/tests/unit/test_transformer.py
@@ -5,8 +5,9 @@
 import pytest
 import torch
 import torch.nn
-from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage
+from ezmsg.baseproc import SampleTriggerMessage
 from ezmsg.util.messages.axisarray import AxisArray
+from ezmsg.util.messages.util import replace

 from ezmsg.learn.process.transformer import TransformerProcessor
@@ -138,9 +139,9 @@ def test_transformer_partial_fit(simple_message, decoder_layers):
     target_shape = (simple_message.data.shape[0], output_size)
     target_value = np.ones(target_shape, dtype=np.float32)

-    sample_message = SampleMessage(
-        trigger=SampleTriggerMessage(timestamp=0.0, value=target_value),
-        sample=simple_message,
+    sample_message = replace(
+        simple_message,
+        attrs={**simple_message.attrs, "trigger": SampleTriggerMessage(timestamp=0.0, value=target_value)},
     )

     proc.partial_fit(sample_message)
@@ -149,9 +150,7 @@ def test_transformer_partial_fit(simple_message, decoder_layers):
     assert proc.state.tgt_cache is None

     updated_weights = [p.detach() for p in proc.state.model.parameters()]
-    assert any(
-        not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights)
-    )
+    assert any(not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights))


 def test_transformer_checkpoint_save_load(simple_message):
@@ -201,9 +200,7 @@ def test_transformer_checkpoint_save_load(simple_message):
         for key in state_dict1:
             assert key in state_dict2, f"Missing key {key} in loaded state_dict"
-            assert torch.equal(state_dict1[key], state_dict2[key]), (
-                f"Mismatch in parameter {key}"
-            )
+            assert torch.equal(state_dict1[key], state_dict2[key]), f"Mismatch in parameter {key}"

     finally:
         # Ensure the temporary file is deleted
@@ -244,20 +241,21 @@ def test_transformer_partial_fit_multiloss(simple_message):
         dtype=torch.long,
     )

-    sample_message = SampleMessage(
-        trigger=SampleTriggerMessage(
-            timestamp=0.0,
-            value={"traj": traj_target, "state": state_target},
-        ),
-        sample=simple_message,
+    sample_message = replace(
+        simple_message,
+        attrs={
+            **simple_message.attrs,
+            "trigger": SampleTriggerMessage(
+                timestamp=0.0,
+                value={"traj": traj_target, "state": state_target},
+            ),
+        },
     )

     proc.partial_fit(sample_message)

     updated_weights = [p.detach() for p in proc.state.model.parameters()]
-    assert any(
-        not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights)
-    )
+    assert any(not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights))


 def test_autoregressive_cache_behavior(simple_message):
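Note: distilling the test changes above, the end-to-end train/infer loop under the new message convention looks roughly like this. The processor choice, shapes, and model_kwargs are illustrative placeholders (model_kwargs in particular is assumed; see the fixtures in tests/unit/test_mlp.py for real configurations):

import numpy as np
from ezmsg.baseproc import SampleTriggerMessage
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from ezmsg.learn.process.torch import TorchModelProcessor

# Hypothetical configuration; real kwargs live in the test fixtures.
proc = TorchModelProcessor(
    model_class="ezmsg.learn.model.mlp.MLP",
    model_kwargs={"input_size": 4, "output_size": 2},  # assumed kwargs
    device="cpu",
)

msg = AxisArray(data=np.random.randn(8, 4), dims=["time", "ch"])
proc(msg)  # forward pass once to initialize the model

# Train: attach the target via attrs["trigger"], then call partial_fit.
train_msg = replace(
    msg,
    attrs={"trigger": SampleTriggerMessage(timestamp=0.0, value=np.random.randn(8, 2))},
)
proc.partial_fit(train_msg)

# Infer: send a message without the trigger attr, as test_mlp_old.py now does.
out = proc(replace(msg, attrs={}))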