From e3c49cf4695ae9d3328b52487c7dcf31087ca52c Mon Sep 17 00:00:00 2001 From: Ishaan Desai Date: Thu, 11 Dec 2025 19:17:45 +0100 Subject: [PATCH 01/21] Make type extensions compatible with Python 3.8.11 because that is what SuperMUC has --- micro_manager/adaptivity/model_adaptivity.py | 24 ++++++++------------ 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/micro_manager/adaptivity/model_adaptivity.py b/micro_manager/adaptivity/model_adaptivity.py index 7d4e049f..e83108fd 100644 --- a/micro_manager/adaptivity/model_adaptivity.py +++ b/micro_manager/adaptivity/model_adaptivity.py @@ -114,10 +114,10 @@ def switch_models( self, locations: np.ndarray, t: float, - inputs: list[dict], + inputs: list, prev_output: dict, sims: list, - active_sim_ids: Optional[list[int]] = None, + active_sim_ids: Optional = None, ) -> None: """ Switches models within sims list. If active_sim_ids is None, all sims are considered as active. @@ -159,10 +159,10 @@ def check_convergence( self, locations: np.ndarray, t: float, - inputs: list[dict], - prev_output: Optional[dict], + inputs: list, + prev_output: Optional, sims: list, - active_sim_ids: Optional[list[int]] = None, + active_sim_ids: Optional = None, ) -> None: """ Similarly to switch_models, checks whether models would be switched in next step. @@ -207,9 +207,7 @@ def get_num_resolutions(self) -> int: """ return len(self._model_classes) - def get_resolution_sim_class( - self, resolution: Union[int, np.ndarray] - ) -> Union[object, np.ndarray]: + def get_resolution_sim_class(self, resolution: Union) -> Union: """ Looks up the class associated with the provided resolution. @@ -227,9 +225,7 @@ def get_resolution_sim_class( clamp_in_range(resolution, 0, len(self._model_classes) - 1) ] - def get_sim_class_resolution( - self, sim: Union[object, np.ndarray] - ) -> Union[int, np.ndarray]: + def get_sim_class_resolution(self, sim: Union) -> Union: """ Looks up the resolution associated with the provided simulation object. @@ -248,7 +244,7 @@ def get_sim_class_resolution( ) def _gather_current_resolutions( - self, sims: list[object], active_sims: np.ndarray + self, sims: list, active_sims: np.ndarray ) -> np.ndarray: """ Gathers current resolutions. Inactive sims have resolution -1. @@ -277,7 +273,7 @@ def _gather_target_resolutions( cur_res: np.ndarray, locations: np.ndarray, t: float, - inputs: list[dict], + inputs: list, prev_output: dict, active_sims: np.ndarray, ) -> np.ndarray: @@ -320,7 +316,7 @@ def _gather_target_resolutions( ) return res_tgt - def _create_active_mask(self, active_sim_ids: list[int], size: int) -> np.ndarray: + def _create_active_mask(self, active_sim_ids: list, size: int) -> np.ndarray: """ Converts list of active simulation ids to np boolean mask. From 6531d575065372c5e4a4718cc765e6d17653302c Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Fri, 12 Dec 2025 18:18:35 +0100 Subject: [PATCH 02/21] implement model instancing to reduce memory footprint --- micro_manager/adaptivity/adaptivity.py | 5 +- micro_manager/adaptivity/global_adaptivity.py | 9 ++- .../adaptivity/global_adaptivity_lb.py | 6 +- micro_manager/adaptivity/local_adaptivity.py | 7 ++- micro_manager/adaptivity/model_adaptivity.py | 10 ++- micro_manager/config.py | 44 +++++++++++++ micro_manager/micro_manager.py | 18 ++++-- micro_manager/model_manager.py | 61 +++++++++++++++++++ 8 files changed, 146 insertions(+), 14 deletions(-) create mode 100644 micro_manager/model_manager.py diff --git a/micro_manager/adaptivity/adaptivity.py b/micro_manager/adaptivity/adaptivity.py index 293edb7f..70894bf5 100644 --- a/micro_manager/adaptivity/adaptivity.py +++ b/micro_manager/adaptivity/adaptivity.py @@ -11,7 +11,7 @@ class AdaptivityCalculator: def __init__( - self, configurator, nsims, micro_problem_cls, base_logger, rank + self, configurator, nsims, micro_problem_cls, model_manager, base_logger, rank ) -> None: """ Class constructor. @@ -24,6 +24,8 @@ def __init__( Number of micro simulations. micro_problem_cls : callable Class of micro problem. + model_manager : object + Handles instantiation of micro simulation. base_logger : object of class Logger Logger object to log messages. rank : int @@ -37,6 +39,7 @@ def __init__( self._adaptivity_output_type = configurator.get_adaptivity_output_type() self._micro_problem_cls = micro_problem_cls + self._model_manager = model_manager self._coarse_tol = 0.0 self._ref_tol = 0.0 diff --git a/micro_manager/adaptivity/global_adaptivity.py b/micro_manager/adaptivity/global_adaptivity.py index 3de4957f..eecf13cb 100644 --- a/micro_manager/adaptivity/global_adaptivity.py +++ b/micro_manager/adaptivity/global_adaptivity.py @@ -27,6 +27,7 @@ def __init__( rank: int, comm, micro_problem_cls, + model_manager, ) -> None: """ Class constructor. @@ -49,9 +50,11 @@ def __init__( Communicator for MPI. micro_problem_cls : callable Class of micro problem. + model_manager : object of class ModelManager + Handles instantiation of the micro simulation. """ super().__init__( - configurator, global_number_of_sims, micro_problem_cls, base_logger, rank + configurator, global_number_of_sims, micro_problem_cls, model_manager, base_logger, rank ) self._global_number_of_sims = global_number_of_sims self._global_ids = global_ids @@ -460,7 +463,7 @@ def _update_inactive_sims(self, micro_sims: list) -> None: # Only handle activation of simulations on this rank for gid in to_be_activated_gids: to_be_activated_lid = self._global_ids.index(gid) - micro_sims[to_be_activated_lid] = self._micro_problem_cls(gid) + micro_sims[to_be_activated_lid] = self._model_manager.get_instance(gid, self._micro_problem_cls) assoc_active_gid = self._sim_is_associated_to[gid] if self._is_sim_on_this_rank[ @@ -497,7 +500,7 @@ def _update_inactive_sims(self, micro_sims: list) -> None: local_ids = to_be_activated_map[gid] for lid in local_ids: # Create the micro simulation object and set its state - micro_sims[lid] = self._micro_problem_cls(self._global_ids[lid]) + micro_sims[lid] = self._model_manager.get_instance(self._global_ids[lid], self._micro_problem_cls) micro_sims[lid].set_state(state) # Delete the micro simulation object if it is inactive diff --git a/micro_manager/adaptivity/global_adaptivity_lb.py b/micro_manager/adaptivity/global_adaptivity_lb.py index da3cc195..67c1ba91 100644 --- a/micro_manager/adaptivity/global_adaptivity_lb.py +++ b/micro_manager/adaptivity/global_adaptivity_lb.py @@ -26,6 +26,7 @@ def __init__( rank: int, comm, micro_problem_cls: callable, + model_manager, ) -> None: """ Class constructor. @@ -48,6 +49,8 @@ def __init__( Communicator for MPI. micro_problem_cls : callable Class of micro problem. + model_manager : object of class ModelManager + Handles instantiation of the micro simulation. """ super().__init__( configurator, @@ -58,6 +61,7 @@ def __init__( rank, comm, micro_problem_cls, + model_manager, ) self._base_logger = base_logger @@ -366,7 +370,7 @@ def _move_active_sims( # Create simulations and set them to the received states for req in recv_reqs: output, gid = req.wait() - micro_sims.append(self._micro_problem_cls(gid)) + micro_sims.append(self._model_manager.get_instance(gid, self._micro_problem_cls)) micro_sims[-1].set_state(output) self._global_ids.append(gid) self._is_sim_on_this_rank[gid] = True diff --git a/micro_manager/adaptivity/local_adaptivity.py b/micro_manager/adaptivity/local_adaptivity.py index f68f49b5..c449c375 100644 --- a/micro_manager/adaptivity/local_adaptivity.py +++ b/micro_manager/adaptivity/local_adaptivity.py @@ -21,6 +21,7 @@ def __init__( rank, comm, micro_problem_cls, + model_manager, ) -> None: """ Class constructor. @@ -39,8 +40,10 @@ def __init__( Communicator for MPI. micro_problem_cls : callable Class of micro problem. + model_manager : object of class ModelManager + Handles instantiation of micro simulation. """ - super().__init__(configurator, num_sims, micro_problem_cls, base_logger, rank) + super().__init__(configurator, num_sims, micro_problem_cls, model_manager, base_logger, rank) self._comm = comm # similarity_dists: 2D array having similarity distances between each micro simulation pair @@ -293,7 +296,7 @@ def _update_inactive_sims(self, micro_sims: list) -> None: # Update the set of inactive micro sims for i in to_be_activated_ids: associated_active_id = self._sim_is_associated_to[i] - micro_sims[i] = self._micro_problem_cls(i) + micro_sims[i] = self._model_manager.get_instance(i, self._micro_problem_cls) micro_sims[i].set_state(micro_sims[associated_active_id].get_state()) self._sim_is_associated_to[ i diff --git a/micro_manager/adaptivity/model_adaptivity.py b/micro_manager/adaptivity/model_adaptivity.py index 7d4e049f..2022ead5 100644 --- a/micro_manager/adaptivity/model_adaptivity.py +++ b/micro_manager/adaptivity/model_adaptivity.py @@ -7,13 +7,14 @@ from ..micro_simulation import create_simulation_class from micro_manager.tools.logging_wrapper import Logger from micro_manager.tools.misc import clamp_in_range +from micro_manager.model_manager import ModelManager import numpy as np import importlib class ModelAdaptivity: - def __init__(self, configurator: Config, rank: int, log_file: str) -> None: + def __init__(self, model_manager: ModelManager, configurator: Config, rank: int, log_file: str) -> None: """ Class constructor. @@ -28,12 +29,15 @@ def __init__(self, configurator: Config, rank: int, log_file: str) -> None: """ self._logger = Logger(__name__, log_file, rank) + self._model_manager = model_manager self._model_files = configurator.get_model_adaptivity_file_names() self._switching_func_name = ( configurator.get_model_adaptivity_switching_function() ) + stateless_flags = configurator.get_model_adaptivity_micro_stateless() self._model_classes = [] + pos = 0 CLASS_NAME = "MicroSimulation" for model_file in self._model_files: try: @@ -42,6 +46,8 @@ def __init__(self, configurator: Config, rank: int, log_file: str) -> None: CLASS_NAME, ) self._model_classes.append(create_simulation_class(model)) + self._model_manager.register(self._model_classes[pos], stateless_flags[pos]) + pos += 1 except Exception as e: self._logger.log_info_rank_zero( f"Failed to load model class with error: {e}" @@ -152,7 +158,7 @@ def switch_models( sim_state = sims[idx].get_state() sim_id = sims[idx].get_global_id() - sims[idx] = self.get_resolution_sim_class(tgt_res[idx])(sim_id) + sims[idx] = self._model_manager.get_instance(sim_id, self.get_resolution_sim_class(tgt_res[idx])) sims[idx].set_state(sim_state) def check_convergence( diff --git a/micro_manager/config.py b/micro_manager/config.py index 4033d746..b740d900 100644 --- a/micro_manager/config.py +++ b/micro_manager/config.py @@ -25,6 +25,7 @@ def __init__(self, config_file_name): self._config_file_name = config_file_name self._logger = None self._micro_file_name = None + self._micro_stateless = False self._precice_config_file_name = None self._macro_mesh_name = None @@ -72,6 +73,7 @@ def __init__(self, config_file_name): # Model Adaptivity information self._m_adap = False self._m_adap_micro_file_names = None + self._m_adap_micro_stateless = None self._m_adap_switching_function = None def set_logger(self, logger): @@ -114,6 +116,13 @@ def _read_json(self, config_file_name): .replace(".py", "") ) + try: + self._micro_stateless = self._data["micro_stateless"] + self._logger.log_info_rank_zero("Only creating one full instance of Micro Model.") + except: + self._micro_stateless = False + self._logger.log_info_rank_zero("Creating full instance of Micro Model per mesh vertex.") + self._logger.log_info_rank_zero( "Micro simulation file name: " + self._data["micro_file_name"] ) @@ -482,6 +491,19 @@ def read_json_micro_manager(self): "model_adaptivity_settings" ]["switching_function"] + if self._data["simulation_params"]["model_adaptivity_settings"]["micro_stateless"]: + self._m_adap_micro_stateless = self._data["simulation_params"][ + "model_adaptivity_settings" + ]["micro_stateless"] + else: + self._m_adap_micro_stateless = [False] * len(self._m_adap_micro_file_names) + + for i in range(len(self._m_adap_micro_file_names)): + if self._m_adap_micro_stateless[i]: + self._logger.log_info_rank_zero(f"Only creating one full instance of Micro Model {i}.") + else: + self._logger.log_info_rank_zero(f"Creating full instance of Micro Model {i} per mesh vertex.") + if "interpolate_crash" in self._data["simulation_params"]: if self._data["simulation_params"]["interpolate_crash"]: self._interpolate_crash = True @@ -658,6 +680,17 @@ def get_micro_file_name(self): """ return self._micro_file_name + def turn_on_micro_stateless(self): + """ + Boolean stating whether micro model is stateless or not. + + Returns + ------- + stateless : bool + True if micro model is stateless, False otherwise. + """ + return self._micro_stateless + def get_micro_output_n(self): """ Get the micro output frequency @@ -975,6 +1008,17 @@ def get_model_adaptivity_file_names(self): """ return self._m_adap_micro_file_names + def get_model_adaptivity_micro_stateless(self): + """ + List of boolean stating whether the respective micro model is stateless or not. + + Returns + ------- + stateless : list + True if micro model is stateless, False otherwise. + """ + return self._m_adap_micro_stateless + def get_model_adaptivity_switching_function(self): """ Get path to switching function file diff --git a/micro_manager/micro_manager.py b/micro_manager/micro_manager.py index d36e61d3..b2eb9cd4 100644 --- a/micro_manager/micro_manager.py +++ b/micro_manager/micro_manager.py @@ -25,6 +25,8 @@ import precice +from .model_manager import ModelManager + from .micro_manager_base import MicroManager from .adaptivity.global_adaptivity import GlobalAdaptivityCalculator @@ -153,6 +155,8 @@ def __init__(self, config_file: str, log_file: str = "") -> None: self._t = 0 # global time self._n = 0 # sim-step + self._model_manager = ModelManager() + # ************** # Public methods # ************** @@ -531,7 +535,7 @@ def initialize(self) -> None: micro_problem_cls = None if self._is_model_adaptivity_on: self._model_adaptivity_controller: ModelAdaptivity = ModelAdaptivity( - self._config, self._rank, self._log_file + self._model_manager, self._config, self._rank, self._log_file ) micro_problem_cls = ( self._model_adaptivity_controller.get_resolution_sim_class(0) @@ -546,13 +550,14 @@ def initialize(self) -> None: micro_problem_cls = create_simulation_class( micro_problem_base, "MicroSimulationDefault" ) + self._model_manager.register(micro_problem_cls, self._config.turn_on_micro_stateless()) # Create micro simulation objects self._micro_sims = [0] * self._local_number_of_sims if not self._lazy_init: for i in range(self._local_number_of_sims): - self._micro_sims[i] = micro_problem_cls( - self._global_ids_of_local_sims[i] + self._micro_sims[i] = self._model_manager.get_instance( + self._global_ids_of_local_sims[i], micro_problem_cls ) if self._is_adaptivity_on: @@ -565,6 +570,7 @@ def initialize(self) -> None: self._rank, self._comm, micro_problem_cls, + self._model_manager, ) ) elif self._config.get_adaptivity_type() == "global": @@ -579,6 +585,7 @@ def initialize(self) -> None: self._rank, self._comm, micro_problem_cls, + self._model_manager, ) ) else: @@ -592,6 +599,7 @@ def initialize(self) -> None: self._rank, self._comm, micro_problem_cls, + self._model_manager, ) ) @@ -633,8 +641,8 @@ def initialize(self) -> None: return for i in active_sim_lids: - self._micro_sims[i] = micro_problem_cls( - self._global_ids_of_local_sims[i] + self._micro_sims[i] = self._model_manager.get_instance( + self._global_ids_of_local_sims[i], micro_problem_cls ) first_id = active_sim_lids[0] # First active simulation ID diff --git a/micro_manager/model_manager.py b/micro_manager/model_manager.py new file mode 100644 index 00000000..67984134 --- /dev/null +++ b/micro_manager/model_manager.py @@ -0,0 +1,61 @@ + +class ModelWrapper: + """ + Stateless Model Wrapper + """ + def __init__(self, global_id, backend, attach_init, attach_output): + self._global_id = global_id + self._backend = backend + + if attach_init: self.initialize = backend.initialize + if attach_output: self.output = backend.output + + def get_global_id(self) -> int: + return self._global_id + + def solve(self, macro_data, dt): + return self._backend.solve(macro_data, dt) + + def get_state(self): + return self._backend.get_state() + + def set_state(self, state): + self._backend.set_state(state) + +class ModelManager: + def __init__(self): + self._registered_classes = [] + self._stateless_map = dict() + self._backend_map = dict() + self._has_init_map = dict() + self._has_output_map = dict() + + def register(self, micro_sim_cls, stateless): + if micro_sim_cls in self._registered_classes: return + + self._registered_classes.append(micro_sim_cls) + self._stateless_map[micro_sim_cls] = stateless + + if stateless: self._backend_map[micro_sim_cls] = micro_sim_cls(-1) + + self._has_init_map[micro_sim_cls] = False + if hasattr(micro_sim_cls, "initialize") and callable(getattr(micro_sim_cls, "initialize")): + self._has_init_map[micro_sim_cls] = True + + self._has_output_map[micro_sim_cls] = False + if hasattr(micro_sim_cls, "output") and callable(getattr(micro_sim_cls, "output")): + self._has_output_map[micro_sim_cls] = True + + def get_instance(self, gid, micro_sim_cls): + if micro_sim_cls not in self._registered_classes: + raise RuntimeError("Trying to create instance of unknown class!") + + if self._stateless_map[micro_sim_cls]: + return ModelWrapper( + gid, + self._backend_map[micro_sim_cls], + self._has_init_map[micro_sim_cls], + self._has_output_map[micro_sim_cls] + ) + else: + return micro_sim_cls(gid) \ No newline at end of file From 12416b226afe986991fa09a6b887206f1cd31487 Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Fri, 12 Dec 2025 18:31:39 +0100 Subject: [PATCH 03/21] small fix, to make mada work again --- micro_manager/model_manager.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/micro_manager/model_manager.py b/micro_manager/model_manager.py index 67984134..2b119be1 100644 --- a/micro_manager/model_manager.py +++ b/micro_manager/model_manager.py @@ -22,6 +22,10 @@ def get_state(self): def set_state(self, state): self._backend.set_state(state) + @property + def __class__(self): + return self._backend.__class__ + class ModelManager: def __init__(self): self._registered_classes = [] From 062b5559d42d93269f7b20af0540cfb9a59b40e4 Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Fri, 12 Dec 2025 18:40:18 +0100 Subject: [PATCH 04/21] fix formatting --- micro_manager/adaptivity/global_adaptivity.py | 15 ++++++++--- .../adaptivity/global_adaptivity_lb.py | 4 ++- micro_manager/adaptivity/local_adaptivity.py | 6 +++-- micro_manager/adaptivity/model_adaptivity.py | 16 ++++++++--- micro_manager/config.py | 24 ++++++++++++----- micro_manager/micro_manager.py | 6 +++-- micro_manager/model_manager.py | 27 ++++++++++++------- 7 files changed, 72 insertions(+), 26 deletions(-) diff --git a/micro_manager/adaptivity/global_adaptivity.py b/micro_manager/adaptivity/global_adaptivity.py index eecf13cb..93dc754f 100644 --- a/micro_manager/adaptivity/global_adaptivity.py +++ b/micro_manager/adaptivity/global_adaptivity.py @@ -54,7 +54,12 @@ def __init__( Handles instantiation of the micro simulation. """ super().__init__( - configurator, global_number_of_sims, micro_problem_cls, model_manager, base_logger, rank + configurator, + global_number_of_sims, + micro_problem_cls, + model_manager, + base_logger, + rank, ) self._global_number_of_sims = global_number_of_sims self._global_ids = global_ids @@ -463,7 +468,9 @@ def _update_inactive_sims(self, micro_sims: list) -> None: # Only handle activation of simulations on this rank for gid in to_be_activated_gids: to_be_activated_lid = self._global_ids.index(gid) - micro_sims[to_be_activated_lid] = self._model_manager.get_instance(gid, self._micro_problem_cls) + micro_sims[to_be_activated_lid] = self._model_manager.get_instance( + gid, self._micro_problem_cls + ) assoc_active_gid = self._sim_is_associated_to[gid] if self._is_sim_on_this_rank[ @@ -500,7 +507,9 @@ def _update_inactive_sims(self, micro_sims: list) -> None: local_ids = to_be_activated_map[gid] for lid in local_ids: # Create the micro simulation object and set its state - micro_sims[lid] = self._model_manager.get_instance(self._global_ids[lid], self._micro_problem_cls) + micro_sims[lid] = self._model_manager.get_instance( + self._global_ids[lid], self._micro_problem_cls + ) micro_sims[lid].set_state(state) # Delete the micro simulation object if it is inactive diff --git a/micro_manager/adaptivity/global_adaptivity_lb.py b/micro_manager/adaptivity/global_adaptivity_lb.py index 67c1ba91..42213a07 100644 --- a/micro_manager/adaptivity/global_adaptivity_lb.py +++ b/micro_manager/adaptivity/global_adaptivity_lb.py @@ -370,7 +370,9 @@ def _move_active_sims( # Create simulations and set them to the received states for req in recv_reqs: output, gid = req.wait() - micro_sims.append(self._model_manager.get_instance(gid, self._micro_problem_cls)) + micro_sims.append( + self._model_manager.get_instance(gid, self._micro_problem_cls) + ) micro_sims[-1].set_state(output) self._global_ids.append(gid) self._is_sim_on_this_rank[gid] = True diff --git a/micro_manager/adaptivity/local_adaptivity.py b/micro_manager/adaptivity/local_adaptivity.py index c449c375..a5bff8d8 100644 --- a/micro_manager/adaptivity/local_adaptivity.py +++ b/micro_manager/adaptivity/local_adaptivity.py @@ -43,7 +43,9 @@ def __init__( model_manager : object of class ModelManager Handles instantiation of micro simulation. """ - super().__init__(configurator, num_sims, micro_problem_cls, model_manager, base_logger, rank) + super().__init__( + configurator, num_sims, micro_problem_cls, model_manager, base_logger, rank + ) self._comm = comm # similarity_dists: 2D array having similarity distances between each micro simulation pair @@ -296,7 +298,7 @@ def _update_inactive_sims(self, micro_sims: list) -> None: # Update the set of inactive micro sims for i in to_be_activated_ids: associated_active_id = self._sim_is_associated_to[i] - micro_sims[i] = self._model_manager.get_instance(i, self._micro_problem_cls) + micro_sims[i] = self._model_manager.get_instance(i, self._micro_problem_cls) micro_sims[i].set_state(micro_sims[associated_active_id].get_state()) self._sim_is_associated_to[ i diff --git a/micro_manager/adaptivity/model_adaptivity.py b/micro_manager/adaptivity/model_adaptivity.py index 2022ead5..ba92a5b5 100644 --- a/micro_manager/adaptivity/model_adaptivity.py +++ b/micro_manager/adaptivity/model_adaptivity.py @@ -14,7 +14,13 @@ class ModelAdaptivity: - def __init__(self, model_manager: ModelManager, configurator: Config, rank: int, log_file: str) -> None: + def __init__( + self, + model_manager: ModelManager, + configurator: Config, + rank: int, + log_file: str, + ) -> None: """ Class constructor. @@ -46,7 +52,9 @@ def __init__(self, model_manager: ModelManager, configurator: Config, rank: int, CLASS_NAME, ) self._model_classes.append(create_simulation_class(model)) - self._model_manager.register(self._model_classes[pos], stateless_flags[pos]) + self._model_manager.register( + self._model_classes[pos], stateless_flags[pos] + ) pos += 1 except Exception as e: self._logger.log_info_rank_zero( @@ -158,7 +166,9 @@ def switch_models( sim_state = sims[idx].get_state() sim_id = sims[idx].get_global_id() - sims[idx] = self._model_manager.get_instance(sim_id, self.get_resolution_sim_class(tgt_res[idx])) + sims[idx] = self._model_manager.get_instance( + sim_id, self.get_resolution_sim_class(tgt_res[idx]) + ) sims[idx].set_state(sim_state) def check_convergence( diff --git a/micro_manager/config.py b/micro_manager/config.py index b740d900..2af96a04 100644 --- a/micro_manager/config.py +++ b/micro_manager/config.py @@ -118,10 +118,14 @@ def _read_json(self, config_file_name): try: self._micro_stateless = self._data["micro_stateless"] - self._logger.log_info_rank_zero("Only creating one full instance of Micro Model.") + self._logger.log_info_rank_zero( + "Only creating one full instance of Micro Model." + ) except: self._micro_stateless = False - self._logger.log_info_rank_zero("Creating full instance of Micro Model per mesh vertex.") + self._logger.log_info_rank_zero( + "Creating full instance of Micro Model per mesh vertex." + ) self._logger.log_info_rank_zero( "Micro simulation file name: " + self._data["micro_file_name"] @@ -491,18 +495,26 @@ def read_json_micro_manager(self): "model_adaptivity_settings" ]["switching_function"] - if self._data["simulation_params"]["model_adaptivity_settings"]["micro_stateless"]: + if self._data["simulation_params"]["model_adaptivity_settings"][ + "micro_stateless" + ]: self._m_adap_micro_stateless = self._data["simulation_params"][ "model_adaptivity_settings" ]["micro_stateless"] else: - self._m_adap_micro_stateless = [False] * len(self._m_adap_micro_file_names) + self._m_adap_micro_stateless = [False] * len( + self._m_adap_micro_file_names + ) for i in range(len(self._m_adap_micro_file_names)): if self._m_adap_micro_stateless[i]: - self._logger.log_info_rank_zero(f"Only creating one full instance of Micro Model {i}.") + self._logger.log_info_rank_zero( + f"Only creating one full instance of Micro Model {i}." + ) else: - self._logger.log_info_rank_zero(f"Creating full instance of Micro Model {i} per mesh vertex.") + self._logger.log_info_rank_zero( + f"Creating full instance of Micro Model {i} per mesh vertex." + ) if "interpolate_crash" in self._data["simulation_params"]: if self._data["simulation_params"]["interpolate_crash"]: diff --git a/micro_manager/micro_manager.py b/micro_manager/micro_manager.py index b2eb9cd4..46533a12 100644 --- a/micro_manager/micro_manager.py +++ b/micro_manager/micro_manager.py @@ -550,7 +550,9 @@ def initialize(self) -> None: micro_problem_cls = create_simulation_class( micro_problem_base, "MicroSimulationDefault" ) - self._model_manager.register(micro_problem_cls, self._config.turn_on_micro_stateless()) + self._model_manager.register( + micro_problem_cls, self._config.turn_on_micro_stateless() + ) # Create micro simulation objects self._micro_sims = [0] * self._local_number_of_sims @@ -641,7 +643,7 @@ def initialize(self) -> None: return for i in active_sim_lids: - self._micro_sims[i] = self._model_manager.get_instance( + self._micro_sims[i] = self._model_manager.get_instance( self._global_ids_of_local_sims[i], micro_problem_cls ) diff --git a/micro_manager/model_manager.py b/micro_manager/model_manager.py index 2b119be1..6ff3b4f0 100644 --- a/micro_manager/model_manager.py +++ b/micro_manager/model_manager.py @@ -1,14 +1,16 @@ - class ModelWrapper: """ Stateless Model Wrapper """ + def __init__(self, global_id, backend, attach_init, attach_output): self._global_id = global_id self._backend = backend - if attach_init: self.initialize = backend.initialize - if attach_output: self.output = backend.output + if attach_init: + self.initialize = backend.initialize + if attach_output: + self.output = backend.output def get_global_id(self) -> int: return self._global_id @@ -26,6 +28,7 @@ def set_state(self, state): def __class__(self): return self._backend.__class__ + class ModelManager: def __init__(self): self._registered_classes = [] @@ -35,19 +38,25 @@ def __init__(self): self._has_output_map = dict() def register(self, micro_sim_cls, stateless): - if micro_sim_cls in self._registered_classes: return + if micro_sim_cls in self._registered_classes: + return self._registered_classes.append(micro_sim_cls) self._stateless_map[micro_sim_cls] = stateless - if stateless: self._backend_map[micro_sim_cls] = micro_sim_cls(-1) + if stateless: + self._backend_map[micro_sim_cls] = micro_sim_cls(-1) self._has_init_map[micro_sim_cls] = False - if hasattr(micro_sim_cls, "initialize") and callable(getattr(micro_sim_cls, "initialize")): + if hasattr(micro_sim_cls, "initialize") and callable( + getattr(micro_sim_cls, "initialize") + ): self._has_init_map[micro_sim_cls] = True self._has_output_map[micro_sim_cls] = False - if hasattr(micro_sim_cls, "output") and callable(getattr(micro_sim_cls, "output")): + if hasattr(micro_sim_cls, "output") and callable( + getattr(micro_sim_cls, "output") + ): self._has_output_map[micro_sim_cls] = True def get_instance(self, gid, micro_sim_cls): @@ -59,7 +68,7 @@ def get_instance(self, gid, micro_sim_cls): gid, self._backend_map[micro_sim_cls], self._has_init_map[micro_sim_cls], - self._has_output_map[micro_sim_cls] + self._has_output_map[micro_sim_cls], ) else: - return micro_sim_cls(gid) \ No newline at end of file + return micro_sim_cls(gid) From f660372fd5dfaa6d28dd420a13d77d810d0ee70b Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Fri, 12 Dec 2025 18:52:30 +0100 Subject: [PATCH 05/21] fix tests --- tests/unit/test_adaptivity_parallel.py | 9 +++++++++ tests/unit/test_adaptivity_serial.py | 10 ++++++++++ tests/unit/test_global_adaptivity_lb.py | 8 ++++++++ 3 files changed, 27 insertions(+) diff --git a/tests/unit/test_adaptivity_parallel.py b/tests/unit/test_adaptivity_parallel.py index ba5e78b6..12379b40 100644 --- a/tests/unit/test_adaptivity_parallel.py +++ b/tests/unit/test_adaptivity_parallel.py @@ -22,6 +22,11 @@ def get_state(self): return self._state.copy() +class ModelManager: + def get_instance(self, gid, micro_problem_cls): + return micro_problem_cls(gid) + + class TestGlobalAdaptivity(TestCase): def setUp(self): self._comm = MPI.COMM_WORLD @@ -60,6 +65,7 @@ def test_update_inactive_sims_global_adaptivity(self): rank=self._rank, comm=self._comm, micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), ) adaptivity_controller._is_sim_active = np.array( @@ -134,6 +140,7 @@ def test_update_all_active_sims_global_adaptivity(self): rank=self._rank, comm=self._comm, micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), ) adaptivity_controller._adaptivity_data_names = ["data1", "data2"] @@ -189,6 +196,7 @@ def test_communicate_micro_output(self): rank=self._rank, comm=self._comm, micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), ) adaptivity_controller._is_sim_active = np.array( @@ -228,6 +236,7 @@ def test_get_ranks_of_sims(self): rank=self._rank, comm=self._comm, micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), ) actual_ranks_of_sims = adaptivity_controller._get_ranks_of_sims() diff --git a/tests/unit/test_adaptivity_serial.py b/tests/unit/test_adaptivity_serial.py index fe6412d0..f92c5b00 100644 --- a/tests/unit/test_adaptivity_serial.py +++ b/tests/unit/test_adaptivity_serial.py @@ -26,6 +26,11 @@ def get_state(self): pass +class ModelManager: + def get_instance(self, gid, micro_problem_cls): + return micro_problem_cls(gid) + + class TestLocalAdaptivity(TestCase): def setUp(self): self._number_of_sims = 5 @@ -94,6 +99,7 @@ def test_update_similarity_dists(self): configurator, nsims=self._number_of_sims, micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), base_logger=MagicMock(), rank=0, ) @@ -146,6 +152,7 @@ def test_update_active_sims(self): rank=0, comm=MagicMock(), micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), ) adaptivity_controller._refine_const = self._refine_const adaptivity_controller._coarse_const = self._coarse_const @@ -182,6 +189,7 @@ def test_adaptivity_norms(self): base_logger=MagicMock(), rank=0, micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), ) fake_data = np.array([[1], [2], [3]]) @@ -280,6 +288,7 @@ def test_associate_active_to_inactive(self): base_logger=MagicMock(), rank=0, micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), ) adaptivity_controller._refine_const = self._refine_const adaptivity_controller._coarse_const = self._coarse_const @@ -323,6 +332,7 @@ def test_update_inactive_sims_local_adaptivity(self): rank=0, comm=MPI.COMM_WORLD, micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), ) adaptivity_controller._refine_const = self._refine_const adaptivity_controller._coarse_const = self._coarse_const diff --git a/tests/unit/test_global_adaptivity_lb.py b/tests/unit/test_global_adaptivity_lb.py index d60ae12b..fb63cc80 100644 --- a/tests/unit/test_global_adaptivity_lb.py +++ b/tests/unit/test_global_adaptivity_lb.py @@ -23,6 +23,11 @@ def get_state(self): return self._state.copy() +class ModelManager: + def get_instance(self, gid, micro_problem_cls): + return micro_problem_cls(gid) + + class TestGlobalAdaptivityLB(TestCase): def setUp(self): self._comm = MPI.COMM_WORLD @@ -68,6 +73,7 @@ def test_redistribute_active_sims_two_ranks(self): rank=self._rank, comm=self._comm, micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), ) adaptivity_controller._is_sim_active = np.array( @@ -118,6 +124,7 @@ def test_redistribute_inactive_sims_two_ranks(self): rank=self._rank, comm=self._comm, micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), ) adaptivity_controller._is_sim_active = np.array( @@ -171,6 +178,7 @@ def test_redistribute_active_sims_four_ranks(self): rank=self._rank, comm=self._comm, micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), ) adaptivity_controller._is_sim_active = np.array( From 530d31ae133b2bc8e00825d2da5b44c9a610093c Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Fri, 12 Dec 2025 18:57:19 +0100 Subject: [PATCH 06/21] missed one... --- tests/unit/test_global_adaptivity_lb.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/test_global_adaptivity_lb.py b/tests/unit/test_global_adaptivity_lb.py index fb63cc80..1e1e211b 100644 --- a/tests/unit/test_global_adaptivity_lb.py +++ b/tests/unit/test_global_adaptivity_lb.py @@ -251,6 +251,7 @@ def test_redistribute_inactive_sims_four_ranks(self): rank=self._rank, comm=self._comm, micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), ) adaptivity_controller._is_sim_active = np.array( From b0b9487cbf77d7b0e35935f7f004e86f2716e116 Mon Sep 17 00:00:00 2001 From: Ishaan Desai Date: Wed, 31 Dec 2025 13:58:39 +0100 Subject: [PATCH 07/21] Remove tuple type hint --- micro_manager/domain_decomposition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/micro_manager/domain_decomposition.py b/micro_manager/domain_decomposition.py index bbb97b16..f0dd97ff 100644 --- a/micro_manager/domain_decomposition.py +++ b/micro_manager/domain_decomposition.py @@ -91,7 +91,7 @@ def get_local_mesh_bounds(self, macro_bounds: list, ranks_per_axis: list) -> lis def get_local_sims_and_macro_coords( self, macro_bounds: list, ranks_per_axis: list, macro_coords: np.ndarray - ) -> tuple[int, list]: + ) -> tuple: """ Decompose the micro simulations among all ranks based on their positions in the macro domain. From 619254247ddbad5a8f64495fd4c39840caea6fcc Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Sat, 3 Jan 2026 12:09:48 +0100 Subject: [PATCH 08/21] add dummy state property for micro sim class --- micro_manager/micro_simulation.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/micro_manager/micro_simulation.py b/micro_manager/micro_simulation.py index 829c6eb5..ed9cf468 100644 --- a/micro_manager/micro_simulation.py +++ b/micro_manager/micro_simulation.py @@ -36,6 +36,18 @@ def __init__(self, global_id): def get_global_id(self) -> int: return self._global_id + +def get_state(self): + if hasattr(micro_simulation_class, "get_state"): + return super().get_state() + else: + return None + +def set_state(self, state): + if hasattr(micro_simulation_class, "set_state"): + super().set_state(state) + else: + return """ sim_class_dict = {} local_globals = { From e0468c6dd0b8643295851d69e189cd7da2e2f961 Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Mon, 12 Jan 2026 23:16:21 +0100 Subject: [PATCH 09/21] add task model skeleton needs integration and more features --- micro_manager/micro_simulation.py | 143 +++++++++++++++++++++------ micro_manager/tasking/connection.py | 100 +++++++++++++++++++ micro_manager/tasking/task.py | 44 +++++++++ micro_manager/tasking/worker_main.py | 41 ++++++++ 4 files changed, 298 insertions(+), 30 deletions(-) create mode 100644 micro_manager/tasking/connection.py create mode 100644 micro_manager/tasking/task.py create mode 100644 micro_manager/tasking/worker_main.py diff --git a/micro_manager/micro_simulation.py b/micro_manager/micro_simulation.py index ed9cf468..a92bd417 100644 --- a/micro_manager/micro_simulation.py +++ b/micro_manager/micro_simulation.py @@ -3,9 +3,101 @@ class MicroSimulation. A global ID member variable is defined for the class Simulation, which ensures that each created object is uniquely identifiable in a global setting. """ +class MicroSimulationWrapper: + # TODO support optional initialize + # TODO support optional output + def __init__(self, sim_cls, name, global_id, num_ranks, executor, late_init): + self._sim_cls = sim_cls # backend impl class + self._name = name # needs to be unique + self._gid = global_id + self._num_ranks = num_ranks + self._executor = executor -def create_simulation_class(micro_simulation_class, sim_class_name=None): + self._states = [None] * num_ranks # list of sims + self._instance = None + + if late_init: return + + if self._num_ranks <= 1: + self._instance = sim_cls(self._gid) + else: + f_gen_instances = self._executor.submit( + MicroSimulationWrapper.gen_instances, + # args + gid=self._gid, + num_ranks=self._num_ranks, + sim_cls=self._sim_cls, + # execution params + resource_dict={"cores": self._num_ranks}, + ) + + for rank, sim_state in f_gen_instances.result(): + self._states[rank] = sim_state + + def solve(self, micro_sim_input, dt): + if self._num_ranks <= 1: + return self._instance.solve(micro_sim_input, dt) + else: + f_solve = self._executor.submit( + MicroSimulationWrapper.solve_local, + # args + sim_cls=self._sim_cls, + states=self._states, + input=micro_sim_input, + dt=dt, + # execution params + resource_dict={"cores": self._num_ranks}, + ) + + results = f_solve.result() + result = None + for rank, output, state in results: + if rank == 0: result = output + self._states[rank] = state + + return result + + def get_state(self): + if self._num_ranks <= 1: return self._instance.get_state() + else: return self._states + + def set_state(self, states): + if self._num_ranks <= 1: self._instance.set_state(states) + else: self._states = states + + def get_global_id(self): return self._gid + def get_name(self): return self._name + + @staticmethod + def gen_instances(gid, num_ranks, sim_cls): + from mpi4py import MPI + + size = MPI.COMM_WORLD.Get_size() + rank = MPI.COMM_WORLD.Get_rank() + assert size == num_ranks + + return rank, sim_cls(gid).get_state() + + @staticmethod + def solve_local(sim_cls, states, input, dt): + from mpi4py import MPI + + size = MPI.COMM_WORLD.Get_size() + rank = MPI.COMM_WORLD.Get_rank() + assert size == len(states) + sim = sim_cls(-1) + sim.set_state(states[rank]) + + if rank == 0: + output = sim.solve(input, dt) + return rank, output, sim.get_state() + else: + sim.solve(input, dt) + return rank, None, sim.get_state() + + +def create_simulation_class(micro_simulation_class, num_ranks, executor, sim_class_name=None): """ Creates a class Simulation which inherits from the class of the micro simulation. @@ -22,40 +114,31 @@ def create_simulation_class(micro_simulation_class, sim_class_name=None): Simulation : class Definition of class Simulation defined in this function. """ + if not hasattr(micro_simulation_class, "get_global_id"): raise ValueError("Invalid micro simulation class") + if not hasattr(micro_simulation_class, "get_state"): raise ValueError("Invalid micro simulation class") + if not hasattr(micro_simulation_class, "set_state"): raise ValueError("Invalid micro simulation class") + if not hasattr(micro_simulation_class, "solve"): raise ValueError("Invalid micro simulation class") + if sim_class_name is None: - if not hasattr(create_simulation_class, "sim_id"): - create_simulation_class.sim_id = 0 - else: - create_simulation_class.sim_id += 1 + if not hasattr(create_simulation_class, "sim_id"): create_simulation_class.sim_id = 0 + else: create_simulation_class.sim_id += 1 sim_class_name = f"MicroSimulation{create_simulation_class.sim_id}" - sim_class_body = """ -def __init__(self, global_id): - micro_simulation_class.__init__(self, global_id) - self._global_id = global_id - -def get_global_id(self) -> int: - return self._global_id - -def get_state(self): - if hasattr(micro_simulation_class, "get_state"): - return super().get_state() - else: - return None - -def set_state(self, state): - if hasattr(micro_simulation_class, "set_state"): - super().set_state(state) - else: - return + cls_body = """ +backend_cls = sim_cls +def __init__(self, global_id, late_init=False): + wrapper_cls.__init__(self, sim_cls, name, global_id, num_ranks, executor, late_init) """ - sim_class_dict = {} + cls_dict = {} local_globals = { "__builtins__": __builtins__, - "micro_simulation_class": micro_simulation_class, + "wrapper_cls": MicroSimulationWrapper, + "sim_cls": micro_simulation_class, + "name": sim_class_name, + "num_ranks": num_ranks, + "executor": executor, } - exec(sim_class_body, local_globals, sim_class_dict) - # print(sim_class_dict) - sim_class = type(sim_class_name, (micro_simulation_class,), sim_class_dict) - return sim_class + exec(cls_body, local_globals, cls_dict) + result_cls = type(sim_class_name, (MicroSimulationWrapper,), cls_dict) + return result_cls diff --git a/micro_manager/tasking/connection.py b/micro_manager/tasking/connection.py new file mode 100644 index 00000000..8b0a1bc6 --- /dev/null +++ b/micro_manager/tasking/connection.py @@ -0,0 +1,100 @@ +import pickle +import socket +import struct +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional +from mpi4py import MPI + +class Connection(ABC): + @abstractmethod + def send(self, dst_id: int, obj: Any) -> None: pass + @abstractmethod + def recv(self, src_id: int) -> Any: pass + @abstractmethod + def close(self) -> None: pass + + +class MPIConnection(Connection): + def __init__(self): + self.inter_comm = None + + @classmethod + def create_workers(cls, worker_exec: str, mpi_args: Optional, n_workers: int) -> "MPIConnection": + comm = MPI.COMM_SELF + conn = cls() + conn.inter_comm = comm.Spawn( + worker_exec, + args=mpi_args or [], + maxprocs=n_workers, + ) + return conn + + @classmethod + def connect_to_micromanager(cls, parent_comm) -> "MPIConnection": + conn = cls() + conn.inter_comm = parent_comm + return conn + + def send(self, dst_id: int, obj: Any) -> None: + data = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) + self.inter_comm.send(data, dest=dst_id, tag=0) + + def recv(self, src_id: int) -> Any: + data = self.inter_comm.recv(source=src_id, tag=1) + return pickle.loads(data) + + def close(self) -> None: + self.inter_comm.Disconnect() + + +class SocketConnection(Connection): + def __init__(self): + self.sockets: Dict[int, socket.socket] = {} + + @classmethod + def accept_workers(cls, host: str, port: int, n_workers: int) -> "SocketConnection": + conn = cls() + server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server.bind((host, port)) + server.listen() + + for wid in range(n_workers): + sock, _ = server.accept() + conn.sockets[wid] = sock + + server.close() + return conn + + @classmethod + def connect_to_micromanager( + cls, worker_id: int, host: str, port: int + ) -> "SocketConnection": + conn = cls() + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect((host, port)) + conn.sockets[worker_id] = sock + return conn + + def send(self, dst_id: int, obj: Any) -> None: + sock = self.sockets[dst_id] + data = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) + header = struct.pack("!Q", len(data)) + sock.sendall(header + data) + + def recv(self, src_id: int) -> Any: + sock = self.sockets[src_id] + header = sock.recv(8) + if not header: + raise EOFError + (size,) = struct.unpack("!Q", header) + payload = b"" + while len(payload) < size: + chunk = sock.recv(size - len(payload)) + if not chunk: + raise EOFError + payload += chunk + return pickle.loads(payload) + + def close(self) -> None: + for sock in self.sockets.values(): sock.close() + self.sockets.clear() \ No newline at end of file diff --git a/micro_manager/tasking/task.py b/micro_manager/tasking/task.py new file mode 100644 index 00000000..39f5eb1a --- /dev/null +++ b/micro_manager/tasking/task.py @@ -0,0 +1,44 @@ + +class Task: + def __init__(self, fn, *args, **kwargs): + self.fn = fn + self.args = args + self.kwargs = kwargs + + def __call__(self, state_data: dict): + return self.fn(*self.args, state_data=state_data, **self.kwargs) + +class ConstructTask(Task): + def __init__(self, gid, sim_cls): + super().__init__(ConstructTask.initializer, gid=gid, sim_cls=sim_cls) + + @staticmethod + def initializer(gid, sim_cls, state_data): + state_data[gid] = sim_cls(gid) + return None + +class SolveTask(Task): + def __init__(self, gid, sim_input, dt): + super().__init__(SolveTask.solve, gid=gid, sim_input=sim_input, dt=dt) + + @staticmethod + def solve(gid, sim_input, dt, state_data): + sim_output = state_data[gid].solve(sim_input, dt) + return sim_output + +class GetStateTask(Task): + def __init__(self, gid): + super().__init__(GetStateTask.get, gid=gid) + + @staticmethod + def get(gid, state_data): + return state_data[gid].get_state() + +class SetStateTask(Task): + def __init__(self, gid, state): + super().__init__(SetStateTask.set, gid=gid, state=state) + + @staticmethod + def set(gid, state, state_data): + state_data[gid].set_state(state) + return None \ No newline at end of file diff --git a/micro_manager/tasking/worker_main.py b/micro_manager/tasking/worker_main.py new file mode 100644 index 00000000..6d66573a --- /dev/null +++ b/micro_manager/tasking/worker_main.py @@ -0,0 +1,41 @@ +import argparse +import os +from mpi4py import MPI + +from .connection import Connection, MPIConnection, SocketConnection + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--backend", required=True, choices=["mpi", "socket"]) + parser.add_argument("--host", help="IP or localhost") + parser.add_argument("--port", type=int, help="Port to open port in micro manager") + parser.add_argument("--parentrank", type=int, help="Parent rank of spawning micro manager mpi instance") + args = parser.parse_args() + + rank = MPI.COMM_WORLD.Get_rank() + size = MPI.COMM_WORLD.Get_size() + worker_id = rank + + conn, dst_id, src_id = None, 0, 0 + if args.backend == "mpi": + conn = MPIConnection.connect_to_micromanager(MPI.Comm.Get_parent()) + dst_id = src_id = args.parentrank + else: + conn = SocketConnection.connect_to_micromanager(worker_id, args.host, args.port) + dst_id = src_id = worker_id + + state_data = {} + + while True: + data = None + try: data = conn.recv(src_id) + except Exception: break + + # TODO unpickle data into task and handle it + # TODO retain sim_obj... + # TODO should always be smth like this: output = task(state_data) + send_data = None # TODO needs to be set + try: conn.send(dst_id, send_data) + except Exception: break + + conn.close() From d61daf9715cb35a3104856285960c42a1a5e5617 Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Tue, 13 Jan 2026 16:14:23 +0100 Subject: [PATCH 10/21] full impl and integrated needs testing --- .../adaptivity/adaptivity_selection.py | 42 +++ micro_manager/adaptivity/model_adaptivity.py | 5 +- micro_manager/config.py | 55 ++++ micro_manager/micro_manager.py | 160 ++++------- micro_manager/micro_simulation.py | 272 ++++++++++++------ micro_manager/snapshot/snapshot.py | 11 +- micro_manager/tasking/connection.py | 143 ++++++++- micro_manager/tasking/task.py | 18 +- 8 files changed, 489 insertions(+), 217 deletions(-) create mode 100644 micro_manager/adaptivity/adaptivity_selection.py diff --git a/micro_manager/adaptivity/adaptivity_selection.py b/micro_manager/adaptivity/adaptivity_selection.py new file mode 100644 index 00000000..75b5d404 --- /dev/null +++ b/micro_manager/adaptivity/adaptivity_selection.py @@ -0,0 +1,42 @@ +from .global_adaptivity import GlobalAdaptivityCalculator +from .global_adaptivity_lb import GlobalAdaptivityLBCalculator +from .local_adaptivity import LocalAdaptivityCalculator +from adaptivity import AdaptivityCalculator + +def create_adaptivity_calculator( + config, + local_number_of_sims, + global_number_of_sims, + global_ids_of_local_sims, + participant, + logger, + rank, + comm, + micro_problem_cls, + model_manager, + use_lb +) -> AdaptivityCalculator: + adaptivity_type = config.get_adaptivity_type() + + if adaptivity_type == 'local': + return LocalAdaptivityCalculator( + config, local_number_of_sims, logger, rank, comm, micro_problem_cls, model_manager + ) + + if adaptivity_type == 'global': + cls = GlobalAdaptivityCalculator + if use_lb: cls = GlobalAdaptivityLBCalculator + + return cls( + config, + global_number_of_sims, + global_ids_of_local_sims, + participant, + logger, + rank, + comm, + micro_problem_cls, + model_manager + ) + + raise ValueError("Unknown adaptivity type") \ No newline at end of file diff --git a/micro_manager/adaptivity/model_adaptivity.py b/micro_manager/adaptivity/model_adaptivity.py index 951c8669..c8911f5c 100644 --- a/micro_manager/adaptivity/model_adaptivity.py +++ b/micro_manager/adaptivity/model_adaptivity.py @@ -8,6 +8,7 @@ from micro_manager.tools.logging_wrapper import Logger from micro_manager.tools.misc import clamp_in_range from micro_manager.model_manager import ModelManager +from micro_manager.tasking.connection import Connection import numpy as np import importlib @@ -20,6 +21,8 @@ def __init__( configurator: Config, rank: int, log_file: str, + conn: Connection, + num_ranks: int, ) -> None: """ Class constructor. @@ -51,7 +54,7 @@ def __init__( importlib.import_module(model_file, CLASS_NAME), CLASS_NAME, ) - self._model_classes.append(create_simulation_class(model)) + self._model_classes.append(create_simulation_class(self._logger, model, num_ranks, conn)) self._model_manager.register( self._model_classes[pos], stateless_flags[pos] ) diff --git a/micro_manager/config.py b/micro_manager/config.py index f72dba2d..ce41ae98 100644 --- a/micro_manager/config.py +++ b/micro_manager/config.py @@ -76,6 +76,11 @@ def __init__(self, config_file_name): self._m_adap_micro_stateless = None self._m_adap_switching_function = None + # Tasking + self._task_is_slurm = False + self._task_backend = "socket" + self._task_num_workers = 1 + def set_logger(self, logger): """ Set the logger for the Config class. @@ -195,6 +200,23 @@ def _read_json(self, config_file_name): self._micro_dt = self._data["simulation_params"]["micro_dt"] + try: + if self._data["tasking"]: + backend = self._data["tasking"]["backend"] + if backend not in ["mpi", "socket"]: + raise Exception("Backend must be either 'mpi' or 'socket'.") + self._task_backend = backend + if "is_slurm" in self._data["tasking"]: + self._task_is_slurm = self._data["tasking"]["is_slurm"] + if "num_workers" in self._data["tasking"]: + self._task_num_workers = self._data["tasking"]["num_workers"] + if self._task_is_slurm and backend == "mpi": + raise Exception("MPI backend not supported on SLURM systems.") + except BaseException: + self._logger.log_info_rank_zero( + "No or incorrect tasking information provided. Micro manager will compute locally." + ) + def read_json_micro_manager(self): """ Reads Micro Manager relevant information from JSON configuration file @@ -1040,3 +1062,36 @@ def get_model_adaptivity_switching_function(self): String containing the path to the switching function file """ return self._m_adap_switching_function + + def get_tasking_num_workers(self): + """ + Get number of workers + + Returns + ------- + num_workers : int + Number of workers + """ + return self._task_num_workers + + def get_tasking_backend(self): + """ + Get backend type + + Returns + ------- + backend : str + either socket or mpi + """ + return self._task_backend + + def get_tasking_use_slurm(self): + """ + Get flag whether slurm is used + + Returns + ------- + use_slurm : bool + use slurm or not + """ + return self._task_is_slurm diff --git a/micro_manager/micro_manager.py b/micro_manager/micro_manager.py index fb4a0f99..24225773 100644 --- a/micro_manager/micro_manager.py +++ b/micro_manager/micro_manager.py @@ -33,10 +33,11 @@ from .adaptivity.local_adaptivity import LocalAdaptivityCalculator from .adaptivity.global_adaptivity_lb import GlobalAdaptivityLBCalculator from .adaptivity.model_adaptivity import ModelAdaptivity +from .adaptivity.adaptivity_selection import create_adaptivity_calculator from .domain_decomposition import DomainDecomposer - -from .micro_simulation import create_simulation_class +from .tasking.connection import spawn_local_workers +from .micro_simulation import create_simulation_class, load_backend_class from .tools.logging_wrapper import Logger @@ -159,6 +160,7 @@ def __init__(self, config_file: str, log_file: str = "") -> None: self._n = 0 # sim-step self._model_manager = ModelManager() + self._conn = None # ************** # Public methods @@ -425,10 +427,7 @@ def initialize(self) -> None: " and does not match the dimensions of the macro mesh." ) - domain_decomposer = DomainDecomposer( - self._rank, - self._size, - ) + domain_decomposer = DomainDecomposer(self._rank, self._size) if self._is_parallel and not self._is_adaptivity_with_load_balancing: coupling_mesh_bounds = domain_decomposer.get_local_mesh_bounds( @@ -543,23 +542,34 @@ def initialize(self) -> None: if self._interpolate_crashed_sims: self._interpolant = Interpolation(self._logger) + # Setup remote workers + base_dir = os.path.dirname(os.path.abspath(__file__)) + worker_exec = os.path.join(base_dir, "tasking", "worker_main.py") + num_ranks = self._config.get_tasking_num_workers() + self._conn = spawn_local_workers( + worker_exec, + num_ranks, + self._config.get_tasking_backend(), + self._config.get_tasking_use_slurm() + ) + + # load micro sim micro_problem_cls = None if self._is_model_adaptivity_on: self._model_adaptivity_controller: ModelAdaptivity = ModelAdaptivity( - self._model_manager, self._config, self._rank, self._log_file + self._model_manager, self._config, self._rank, self._log_file, self._conn, num_ranks, ) micro_problem_cls = ( self._model_adaptivity_controller.get_resolution_sim_class(0) ) else: - micro_problem_base = getattr( - importlib.import_module( - self._config.get_micro_file_name(), "MicroSimulation" - ), - "MicroSimulation", - ) + micro_problem_base = load_backend_class(self._config.get_micro_file_name()) micro_problem_cls = create_simulation_class( - micro_problem_base, "MicroSimulationDefault" + self._logger, + micro_problem_base, + self._config.get_tasking_num_workers(), + self._conn, + "MicroSimulationDefault" ) self._model_manager.register( micro_problem_cls, self._config.turn_on_micro_stateless() @@ -574,47 +584,19 @@ def initialize(self) -> None: ) if self._is_adaptivity_on: - if self._config.get_adaptivity_type() == "local": - self._adaptivity_controller: LocalAdaptivityCalculator = ( - LocalAdaptivityCalculator( - self._config, - self._local_number_of_sims, - self._logger, - self._rank, - self._comm, - micro_problem_cls, - self._model_manager, - ) - ) - elif self._config.get_adaptivity_type() == "global": - if self._is_adaptivity_with_load_balancing: - self._adaptivity_controller: GlobalAdaptivityLBCalculator = ( - GlobalAdaptivityLBCalculator( - self._config, - self._global_number_of_sims, - self._global_ids_of_local_sims, - self._participant, - self._logger, - self._rank, - self._comm, - micro_problem_cls, - self._model_manager, - ) - ) - else: - self._adaptivity_controller: GlobalAdaptivityCalculator = ( - GlobalAdaptivityCalculator( - self._config, - self._global_number_of_sims, - self._global_ids_of_local_sims, - self._participant, - self._logger, - self._rank, - self._comm, - micro_problem_cls, - self._model_manager, - ) - ) + self._adaptivity_controller = create_adaptivity_calculator( + self._config, + self._local_number_of_sims, + self._global_number_of_sims, + self._global_ids_of_local_sims, + self._participant, + self._logger, + self._rank, + self._comm, + micro_problem_cls, + self._model_manager, + self._is_adaptivity_with_load_balancing, + ) self._micro_sims_active_steps = np.zeros( self._global_number_of_sims @@ -634,9 +616,8 @@ def initialize(self) -> None: is_initial_data_available = False else: is_initial_data_available = True - if ( - self._lazy_init - ): # For lazy initialization, compute adaptivity with the initial macro data + # For lazy initialization, compute adaptivity with the initial macro data + if self._lazy_init: for i in range(self._local_number_of_sims): for name in self._adaptivity_macro_data_names: self._data_for_adaptivity[name][i] = initial_data[i][name] @@ -664,45 +645,13 @@ def initialize(self) -> None: ) # Boolean which states if the initialize() method of the micro simulation requires initial data - sim_requires_init_data = False - - # Check if provided micro simulation has an initialize() method - if hasattr(micro_problem_cls, "initialize") and callable( - getattr(micro_problem_cls, "initialize") - ): - self._micro_sims_init = True # Starting value before setting - - try: # Try to get the signature of the initialize() method, if it is written in Python - argspec = inspect.getfullargspec(micro_problem_cls.initialize) - if ( - len(argspec.args) == 1 - ): # The first argument in the signature is self - sim_requires_init_data = False - elif len(argspec.args) == 2: - sim_requires_init_data = True - else: - raise Exception( - "The initialize() method of the Micro simulation has an incorrect number of arguments." - ) - except TypeError: - self._logger.log_info_rank_zero( - "The signature of initialize() method of the micro simulation cannot be determined. Trying to determine the signature by calling the method." - ) - # Try to get the signature of the initialize() method, if it is not written in Python - try: # Try to call the initialize() method without initial data - self._micro_sims[first_id].initialize() - sim_requires_init_data = False - except TypeError: - self._logger.log_info_rank_zero( - "The initialize() method of the micro simulation has arguments. Attempting to call it again with initial data." - ) - try: # Try to call the initialize() method with initial data - self._micro_sims[first_id].initialize(initial_data[first_id]) - sim_requires_init_data = True - except TypeError: - raise Exception( - "The initialize() method of the Micro simulation has an incorrect number of arguments." - ) + ( + self._micro_sims_init, + sim_requires_init_data + ) = micro_problem_cls.check_initialize( + self._micro_sims[first_id], + initial_data[first_id] if is_initial_data_available else None, + ) if sim_requires_init_data and not is_initial_data_available: raise Exception( @@ -711,7 +660,6 @@ def initialize(self) -> None: # Get initial data from micro simulations if initialize() method exists if self._micro_sims_init: - # Call initialize() method of the micro simulation to check if it returns any initial data if sim_requires_init_data: initial_micro_output = self._micro_sims[first_id].initialize( @@ -720,9 +668,8 @@ def initialize(self) -> None: else: initial_micro_output = self._micro_sims[first_id].initialize() - if ( - initial_micro_output is None - ): # Check if the detected initialize() method returns any data + # Check if the detected initialize() method returns any data + if initial_micro_output is None: self._logger.log_warning_rank_zero( "The initialize() call of the Micro simulation has not returned any initial data." " This means that the initialize() call has no effect on the adaptivity. The initialize method will nevertheless still be called." @@ -775,9 +722,8 @@ def initialize(self) -> None: ] = initial_micro_output[name] initial_micro_data[name][i] = initial_micro_output[name] - if ( - self._lazy_init - ): # If lazy initialization is on, initial states of inactive simulations need to be determined + # If lazy initialization is on, initial states of inactive simulations need to be determined + if self._lazy_init: self._adaptivity_controller.get_full_field_micro_output( initial_micro_data ) @@ -799,11 +745,7 @@ def initialize(self) -> None: for i in range(1, self._local_number_of_sims): self._micro_sims[i].initialize() - self._micro_sims_have_output = False - if hasattr(micro_problem_cls, "output") and callable( - getattr(micro_problem_cls, "output") - ): - self._micro_sims_have_output = True + self._micro_sims_have_output = micro_problem_cls.check_output() self._participant.stop_last_profiling_section() diff --git a/micro_manager/micro_simulation.py b/micro_manager/micro_simulation.py index a92bd417..39c07268 100644 --- a/micro_manager/micro_simulation.py +++ b/micro_manager/micro_simulation.py @@ -3,101 +3,202 @@ class MicroSimulation. A global ID member variable is defined for the class Simulation, which ensures that each created object is uniquely identifiable in a global setting. """ -class MicroSimulationWrapper: - # TODO support optional initialize - # TODO support optional output - - def __init__(self, sim_cls, name, global_id, num_ranks, executor, late_init): - self._sim_cls = sim_cls # backend impl class - self._name = name # needs to be unique - self._gid = global_id - self._num_ranks = num_ranks - self._executor = executor - - self._states = [None] * num_ranks # list of sims - self._instance = None - if late_init: return +from abc import ABC, abstractmethod +import inspect +import importlib as ipl + +from .tasking.task import * + +class MicroSimulationInterface(ABC): + @abstractmethod + def solve(self, micro_sim_input, dt): pass + @abstractmethod + def get_state(self): pass + @abstractmethod + def set_state(self, state): pass + @abstractmethod + def get_global_id(self): pass + @abstractmethod + def initialize(self, *args, **kwargs): pass + @abstractmethod + def output(self): pass + + +class MicroSimulationLocal(MicroSimulationInterface): + def __init__(self, gid, sim_cls): + self._instance = sim_cls(gid) + self._gid = gid + + def solve(self, micro_sim_input, dt): return self._instance.solve(micro_sim_input, dt) + def get_state(self): return self._instance.get_state() + def set_state(self, state): return self._instance.set_state(state) + def get_global_id(self): return self._gid + def initialize(self, *args, **kwargs): return self._instance.initialize(*args, **kwargs) + def output(self): return self._instance.output() + + +class MicroSimulationRemote(MicroSimulationInterface): + def __init__(self, gid, num_ranks, conn, sim_cls): + self._sim_cls = sim_cls # backend impl class + self._gid = gid + self._num_ranks = num_ranks + self._conn = conn - if self._num_ranks <= 1: - self._instance = sim_cls(self._gid) - else: - f_gen_instances = self._executor.submit( - MicroSimulationWrapper.gen_instances, - # args - gid=self._gid, - num_ranks=self._num_ranks, - sim_cls=self._sim_cls, - # execution params - resource_dict={"cores": self._num_ranks}, - ) + for worker_id in range(self._num_ranks): + task = ConstructTask(self._gid, self._sim_cls) + self._conn.send(worker_id, task) - for rank, sim_state in f_gen_instances.result(): - self._states[rank] = sim_state + for worker_id in range(self._num_ranks): + self._conn.recv(worker_id) def solve(self, micro_sim_input, dt): - if self._num_ranks <= 1: - return self._instance.solve(micro_sim_input, dt) - else: - f_solve = self._executor.submit( - MicroSimulationWrapper.solve_local, - # args - sim_cls=self._sim_cls, - states=self._states, - input=micro_sim_input, - dt=dt, - # execution params - resource_dict={"cores": self._num_ranks}, - ) + for worker_id in range(self._num_ranks): + task = SolveTask(self._gid, micro_sim_input, dt) + self._conn.send(worker_id, task) - results = f_solve.result() - result = None - for rank, output, state in results: - if rank == 0: result = output - self._states[rank] = state + result = None + for worker_id in range(self._num_ranks): + output = self._conn.recv(worker_id) + if worker_id == 0: result = output - return result + return result def get_state(self): - if self._num_ranks <= 1: return self._instance.get_state() - else: return self._states + for worker_id in range(self._num_ranks): + task = GetStateTask(self._gid) + self._conn.send(worker_id, task) + + result = {} + for worker_id in range(self._num_ranks): + result[worker_id] = self._conn.recv(worker_id) + + return result + + def set_state(self, state): + for worker_id in range(self._num_ranks): + task = SetStateTask(self._gid, state[worker_id]) + self._conn.send(worker_id, task) - def set_state(self, states): - if self._num_ranks <= 1: self._instance.set_state(states) - else: self._states = states + for worker_id in range(self._num_ranks): + self._conn.recv(worker_id) - def get_global_id(self): return self._gid - def get_name(self): return self._name + def get_global_id(self): + return self._gid - @staticmethod - def gen_instances(gid, num_ranks, sim_cls): - from mpi4py import MPI + def initialize(self, *args, **kwargs): + for worker_id in range(self._num_ranks): + task = InitializeTask(self._gid, *args, **kwargs) + self._conn.send(worker_id, task) - size = MPI.COMM_WORLD.Get_size() - rank = MPI.COMM_WORLD.Get_rank() - assert size == num_ranks + result = None + for worker_id in range(self._num_ranks): + output = self._conn.recv(worker_id) + if worker_id == 0: result = output - return rank, sim_cls(gid).get_state() + return result - @staticmethod - def solve_local(sim_cls, states, input, dt): - from mpi4py import MPI + def output(self): + for worker_id in range(self._num_ranks): + task = OutputTask(self._gid) + self._conn.send(worker_id, task) - size = MPI.COMM_WORLD.Get_size() - rank = MPI.COMM_WORLD.Get_rank() - assert size == len(states) - sim = sim_cls(-1) - sim.set_state(states[rank]) + result = None + for worker_id in range(self._num_ranks): + output = self._conn.recv(worker_id) + if worker_id == 0: result = output - if rank == 0: - output = sim.solve(input, dt) - return rank, output, sim.get_state() + return result + + +class MicroSimulationWrapper(MicroSimulationInterface): + """ + If only a single rank is in use: will contain the micro sim instance. + Otherwise, it will delegate method calls to workers and not contain state. + """ + def __init__(self, sim_cls, global_id, num_ranks, conn): + self._impl = None + + if num_ranks > 1 and conn is not None: + self._impl = MicroSimulationRemote(global_id, num_ranks, conn, sim_cls) else: - sim.solve(input, dt) - return rank, None, sim.get_state() + self._impl = MicroSimulationLocal(global_id, sim_cls) + def solve(self, micro_sim_input, dt): return self._impl.solve(micro_sim_input, dt) + def get_state(self): return self._impl.get_state() + def set_state(self, state): return self._impl.set_state(state) + def get_global_id(self): return self._impl.get_global_id() + def initialize(self, *args, **kwargs): return self._impl.initialize(*args, **kwargs) + def output(self): return self._impl.output() -def create_simulation_class(micro_simulation_class, num_ranks, executor, sim_class_name=None): + +class MicroSimulationClassAdapter: + def __init__(self, sim_cls, name, num_ranks, conn, log): + self._sim_cls = sim_cls + self._name = name + self._num_ranks = num_ranks + self._conn = conn + self._log = log + + def __class__(self): return self._name + def __call__(self, gid): return MicroSimulationWrapper(self._sim_cls, gid, self._num_ranks, self._conn) + @property + def backend_cls(self): return self._sim_cls + + def check_initialize(self, test_instance, test_input): + has_init = hasattr(self._sim_cls, 'initialize') + callable_init = callable(getattr(self._sim_cls, 'initialize')) + if not has_init or not callable_init: return False, False + + has_args = False + + # Try to get the signature of the initialize() method, if it is written in Python + try: + argspec = inspect.getfullargspec(self._sim_cls.initialize) + # The first argument in the signature is self + if len(argspec.args) == 1: has_args = False + elif len(argspec.args) == 2: has_args = True + else: + raise Exception( + "The initialize() method of the Micro simulation has an incorrect number of arguments." + ) + except TypeError: + self._log.log_info_rank_zero( + "The signature of initialize() method of the micro simulation cannot be determined. " + + "Trying to determine the signature by calling the method." + ) + # Try to call the initialize() method without initial data + try: + test_instance.initialize() + has_args = False + except TypeError: + self._log.log_info_rank_zero( + "The initialize() method of the micro simulation has arguments. " + + "Attempting to call it again with initial data." + ) + try: + test_instance.initialize(test_input) + has_args = True + except TypeError: + raise Exception( + "The initialize() method of the Micro simulation has an incorrect number of arguments." + ) + + return has_init and callable_init, has_args + + def check_output(self): + has_init = hasattr(self._sim_cls, 'output') + callable_init = callable(getattr(self._sim_cls, 'output')) + + return has_init and callable_init + + +def load_backend_class(path_to_micro_file): + CLS_NAME = 'MicroSimulation' + return getattr(ipl.import_module(path_to_micro_file, CLS_NAME), CLS_NAME) + + +def create_simulation_class(log, micro_simulation_class, num_ranks, conn=None, sim_class_name=None): """ Creates a class Simulation which inherits from the class of the micro simulation. @@ -117,28 +218,13 @@ def create_simulation_class(micro_simulation_class, num_ranks, executor, sim_cla if not hasattr(micro_simulation_class, "get_global_id"): raise ValueError("Invalid micro simulation class") if not hasattr(micro_simulation_class, "get_state"): raise ValueError("Invalid micro simulation class") if not hasattr(micro_simulation_class, "set_state"): raise ValueError("Invalid micro simulation class") - if not hasattr(micro_simulation_class, "solve"): raise ValueError("Invalid micro simulation class") + if not hasattr(micro_simulation_class, "solve"): raise ValueError("Invalid micro simulation class") if sim_class_name is None: if not hasattr(create_simulation_class, "sim_id"): create_simulation_class.sim_id = 0 else: create_simulation_class.sim_id += 1 sim_class_name = f"MicroSimulation{create_simulation_class.sim_id}" - cls_body = """ -backend_cls = sim_cls -def __init__(self, global_id, late_init=False): - wrapper_cls.__init__(self, sim_cls, name, global_id, num_ranks, executor, late_init) - """ - cls_dict = {} - local_globals = { - "__builtins__": __builtins__, - "wrapper_cls": MicroSimulationWrapper, - "sim_cls": micro_simulation_class, - "name": sim_class_name, - "num_ranks": num_ranks, - "executor": executor, - } - - exec(cls_body, local_globals, cls_dict) - result_cls = type(sim_class_name, (MicroSimulationWrapper,), cls_dict) + + result_cls = MicroSimulationClassAdapter(micro_simulation_class, sim_class_name, num_ranks, conn, log) return result_cls diff --git a/micro_manager/snapshot/snapshot.py b/micro_manager/snapshot/snapshot.py index 4b6ae71f..879ac7ff 100644 --- a/micro_manager/snapshot/snapshot.py +++ b/micro_manager/snapshot/snapshot.py @@ -17,7 +17,7 @@ from micro_manager.micro_manager import MicroManager from .dataset import ReadWriteHDF -from micro_manager.micro_simulation import create_simulation_class +from micro_manager.micro_simulation import create_simulation_class, load_backend_class from micro_manager.tools.logging_wrapper import Logger @@ -84,7 +84,7 @@ def solve(self) -> None: - Merge output in parallel run. """ - micro_problem_cls = create_simulation_class(self._micro_problem) + micro_problem_cls = create_simulation_class(self._logger, self._micro_problem, 1, None) # Loop over all macro parameters for elems in range(self._local_number_of_sims): @@ -256,12 +256,7 @@ def initialize(self) -> None: for i in range(self._local_number_of_sims): self._global_ids_of_local_sims.append(sim_id) sim_id += 1 - self._micro_problem = getattr( - importlib.import_module( - self._config.get_micro_file_name(), "MicroSimulation" - ), - "MicroSimulation", - ) + self._micro_problem = load_backend_class(self._config.get_micro_file_name()) self._micro_sims_have_output = False if hasattr(self._micro_problem, "output") and callable( diff --git a/micro_manager/tasking/connection.py b/micro_manager/tasking/connection.py index 8b0a1bc6..49cbde17 100644 --- a/micro_manager/tasking/connection.py +++ b/micro_manager/tasking/connection.py @@ -1,6 +1,9 @@ import pickle +import psutil import socket import struct +import subprocess +import os from abc import ABC, abstractmethod from typing import Any, Dict, Optional from mpi4py import MPI @@ -23,7 +26,7 @@ def create_workers(cls, worker_exec: str, mpi_args: Optional, n_workers: int) -> comm = MPI.COMM_SELF conn = cls() conn.inter_comm = comm.Spawn( - worker_exec, + f"python {worker_exec}", args=mpi_args or [], maxprocs=n_workers, ) @@ -52,12 +55,26 @@ def __init__(self): self.sockets: Dict[int, socket.socket] = {} @classmethod - def accept_workers(cls, host: str, port: int, n_workers: int) -> "SocketConnection": - conn = cls() + def create_workers(cls, worker_exec: str, launcher: list, host: str, n_workers: int) -> "SocketConnection": + # create listening socket with ephemeral port server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server.bind((host, port)) + server.bind((host, 0)) # kernel picks free port server.listen() + port = server.getsockname()[1] + executable = [ + "python", + worker_exec, + "--backend", "socket", + "--host", host, + "--port", str(port), + ] + cmd = [] + cmd.extend(launcher) + cmd.extend(executable) + subprocess.Popen(cmd, env=os.environ.copy()) + + conn = cls() for wid in range(n_workers): sock, _ = server.accept() conn.sockets[wid] = sock @@ -97,4 +114,120 @@ def recv(self, src_id: int) -> Any: def close(self) -> None: for sock in self.sockets.values(): sock.close() - self.sockets.clear() \ No newline at end of file + self.sockets.clear() + + +def get_local_ip(preferred_ifaces=None) -> str: + """ + Returns a non-loopback IPv4 address without accessing external networks. + + Parameters + ---------- + preferred_ifaces : list[str], optional + If provided, try interfaces in this order first (e.g., ["ib0", "eno1"]) + + Returns + ------- + str + The selected IPv4 address + """ + addrs = psutil.net_if_addrs() + + candidates = [] + + # Iterate over preferred interfaces first + if preferred_ifaces: + for name in preferred_ifaces: + if name not in addrs: + continue + for a in addrs[name]: + if a.family == socket.AF_INET and not a.address.startswith("127."): + return a.address + + # Fallback: iterate all interfaces + for name, iface_addrs in addrs.items(): + for a in iface_addrs: + if a.family == socket.AF_INET: + ip = a.address + if not ip.startswith("127.") and not ip.startswith("169.254."): + candidates.append(ip) + + if candidates: + return candidates[0] + + raise RuntimeError("No non-loopback IPv4 address found") + + +def spawn_local_workers( + worker_exec: str, + n_workers: int, + backend: str, + is_slurm: bool, +): + """ + Spawn worker processes. On Slurm systems: MPI spawn now supported, socket backend enforced. + Ephemeral port auto-selected. + + Parameters + ---------- + worker_exec : str + path to worker executable + n_workers : int + number of worker processes, must be > 1 otherwise returns None + backend : str + mpi or socket + is_slurm : bool + is our system slurm based? + + Returns + ------- + conn : Connection + Established connection on generator side + """ + if n_workers <= 1: return None + conn = None + + # MPI BACKEND (non-Slurm only) + if backend == "mpi": + if is_slurm: raise RuntimeError( + "MPI backend is not supported under Slurm. " + "Use socket backend instead." + ) + comm = MPI.COMM_WORLD + local_rank = comm.Get_rank() + conn = MPIConnection.create_workers( + worker_exec=worker_exec, + mpi_args=[ + "--backend", "mpi", + "--parentrank", str(local_rank), + ], + n_workers=n_workers, + ) + + # SOCKET BACKEND + if backend == "socket": + host = get_local_ip() + + # launch workers + launcher = None + if is_slurm: + launcher = [ + "srun", + #"--exclusive", + "--ntasks", str(n_workers), + "--kill-on-bad-exit=1", + ] + else: + launcher = [ + "mpiexec", + "-n", str(n_workers), + ] + + conn = SocketConnection.create_workers( + worker_exec=worker_exec, + launcher=launcher, + host=host, + n_workers=n_workers + ) + + return conn diff --git a/micro_manager/tasking/task.py b/micro_manager/tasking/task.py index 39f5eb1a..92fe25ee 100644 --- a/micro_manager/tasking/task.py +++ b/micro_manager/tasking/task.py @@ -41,4 +41,20 @@ def __init__(self, gid, state): @staticmethod def set(gid, state, state_data): state_data[gid].set_state(state) - return None \ No newline at end of file + return None + +class InitializeTask(Task): + def __init__(self, gid, *args, **kwargs): + super().__init__(InitializeTask.initialize, *args, gid=gid, **kwargs) + + @staticmethod + def initialize(gid, state_data, *args, **kwargs): + return state_data[gid].initialize(*args, **kwargs) + +class OutputTask(Task): + def __init__(self, gid): + super().__init__(OutputTask.output, gid=gid) + + @staticmethod + def output(gid, state_data): + return state_data[gid].output() \ No newline at end of file From b251cc4794bd5d40de2fc4b1ae30a2fe81ff24c0 Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Tue, 13 Jan 2026 18:50:29 +0100 Subject: [PATCH 11/21] impl late sim init for model adaptivity --- micro_manager/adaptivity/model_adaptivity.py | 77 ++++++++++++++++++-- micro_manager/micro_manager.py | 67 +++++++++++------ micro_manager/micro_simulation.py | 30 ++++++-- micro_manager/model_manager.py | 4 +- micro_manager/tasking/task.py | 9 +++ 5 files changed, 147 insertions(+), 40 deletions(-) diff --git a/micro_manager/adaptivity/model_adaptivity.py b/micro_manager/adaptivity/model_adaptivity.py index c8911f5c..922e28e1 100644 --- a/micro_manager/adaptivity/model_adaptivity.py +++ b/micro_manager/adaptivity/model_adaptivity.py @@ -167,12 +167,77 @@ def switch_models( if cur_res[idx] == tgt_res[idx]: continue - sim_state = sims[idx].get_state() - sim_id = sims[idx].get_global_id() - sims[idx] = self._model_manager.get_instance( - sim_id, self.get_resolution_sim_class(tgt_res[idx]) - ) - sims[idx].set_state(sim_state) + sim = sims[idx] + gid = sim.get_global_id() + tgt_cls = self.get_resolution_sim_class(tgt_res[idx]) + + key = f"{sim.__name__}-state" + key_new = f"{tgt_cls.__name__}-state" + + new_state_exists = key_new in sim.attachments + sim.attachments[key] = sim.get_state() + + sim_new = self._model_manager.get_instance(gid, tgt_cls, late_init=new_state_exists) + sim_new.attachments = sim.attachments + + if new_state_exists: + sim_new_state = sim.attachments[key_new] + sim_new.set_state(sim_new_state) + + sims[idx] = sim_new + + + def update_states( + self, + sims: list, + active_sim_ids: Optional = None, + ): + """ + Stores the current state of the current model into local buffers. + + Parameters + ---------- + sims : list + List of all simulation objects. + active_sim_ids : [list, None] + List of all active simulation ids. + """ + size = len(sims) + active_sims = self._create_active_mask(active_sim_ids, size) + + for idx in range(size): + if not active_sims[idx]: continue + + sim = sims[idx] + key = f"{sim.__name__}-state" + sim.attachments[key] = sim.get_state() + + + def write_back_states( + self, + sims: list, + active_sim_ids: Optional = None, + ): + """ + Loads the current state of the current model into local buffers. + + Parameters + ---------- + sims : list + List of all simulation objects. + active_sim_ids : [list, None] + List of all active simulation ids. + """ + size = len(sims) + active_sims = self._create_active_mask(active_sim_ids, size) + + for idx in range(size): + if not active_sims[idx]: continue + + sim = sims[idx] + key = f"{sim.__name__}-state" + sim.set_state(sim.attachments[key]) + def check_convergence( self, diff --git a/micro_manager/micro_manager.py b/micro_manager/micro_manager.py index 24225773..4fdf31f7 100644 --- a/micro_manager/micro_manager.py +++ b/micro_manager/micro_manager.py @@ -201,28 +201,27 @@ def solve(self) -> None: if (self._adaptivity_in_every_implicit_step or first_iteration) and ( self._n % self._adaptivity_n == 0 ): - self._participant.start_profiling_section( - "micro_manager.solve.adaptivity_computation" - ) + self._participant.start_profiling_section("micro_manager.solve.adaptivity_computation") self._adaptivity_controller.compute_adaptivity( dt, self._micro_sims, self._data_for_adaptivity, ) + active_sim_gids = self._adaptivity_controller.get_active_sim_global_ids() + for gid in active_sim_gids: self._micro_sims_active_steps[gid] += 1 # Write a checkpoint if a simulation is just activated. # This checkpoint will be asynchronous to the checkpoints written at the start of the time window. - for i in range(self._local_number_of_sims): - if sim_states_cp[i] is None and self._micro_sims[i]: - sim_states_cp[i] = self._micro_sims[i].get_state() - - active_sim_gids = ( - self._adaptivity_controller.get_active_sim_global_ids() - ) - - for gid in active_sim_gids: - self._micro_sims_active_steps[gid] += 1 + if self._is_model_adaptivity_on: + self._model_adaptivity_controller.update_states(self._micro_sims, active_sim_gids) + for i in range(self._local_number_of_sims): + if sim_states_cp[i] is None and self._micro_sims[i]: + sim_states_cp[i] = self._micro_sims[i].attachments + else: + for i in range(self._local_number_of_sims): + if sim_states_cp[i] is None and self._micro_sims[i]: + sim_states_cp[i] = self._micro_sims[i].get_state() self._participant.stop_last_profiling_section() @@ -251,10 +250,19 @@ def solve(self) -> None: # Write a checkpoint if self._participant.requires_writing_checkpoint(): - for i in range(self._local_number_of_sims): - sim_states_cp[i] = ( - self._micro_sims[i].get_state() if self._micro_sims[i] else None - ) + active_sim_gids = None + if self._is_adaptivity_on: + active_sim_gids = self._adaptivity_controller.get_active_sim_local_ids() + + if self._is_model_adaptivity_on: + self._model_adaptivity_controller.update_states(self._micro_sims, active_sim_gids) + for i in range(self._local_number_of_sims): + sim_states_cp[i] = self._micro_sims[i].attachments if self._micro_sims[i] else None + else: + for i in range(self._local_number_of_sims): + sim_states_cp[i] = ( + self._micro_sims[i].get_state() if self._micro_sims[i] else None + ) micro_sims_input = self._read_data_from_precice(dt) @@ -299,16 +307,27 @@ def solve(self) -> None: # Revert micro simulations to their last checkpoints if required if self._participant.requires_reading_checkpoint(): - for i in range(self._local_number_of_sims): - if self._micro_sims[i]: - self._micro_sims[i].set_state(sim_states_cp[i]) + if self._is_model_adaptivity_on: + active_sim_gids = None + if self._is_adaptivity_on: + active_sim_gids = self._adaptivity_controller.get_active_sim_local_ids() + + for i in range(self._local_number_of_sims): + if self._micro_sims[i]: + self._micro_sims[i].attachments = sim_states_cp[i] + self._model_adaptivity_controller.write_back_states(self._micro_sims, active_sim_gids) + + else: + for i in range(self._local_number_of_sims): + if self._micro_sims[i]: + self._micro_sims[i].set_state(sim_states_cp[i]) + first_iteration = False - if ( - self._participant.is_time_window_complete() - ): # Time window has converged, now micro output can be generated + # Time window has converged, now micro output can be generated + if self._participant.is_time_window_complete(): self._t += dt # Update time to the end of the time window - self._n += 1 # Update time step to the end of the time window + self._n += 1 # Update time step to the end of the time window if self._micro_sims_have_output: if self._n % self._micro_n_out == 0: diff --git a/micro_manager/micro_simulation.py b/micro_manager/micro_simulation.py index 39c07268..9d787eb2 100644 --- a/micro_manager/micro_simulation.py +++ b/micro_manager/micro_simulation.py @@ -20,33 +20,37 @@ def set_state(self, state): pass @abstractmethod def get_global_id(self): pass @abstractmethod + def set_global_id(self, global_id): pass + @abstractmethod def initialize(self, *args, **kwargs): pass @abstractmethod def output(self): pass class MicroSimulationLocal(MicroSimulationInterface): - def __init__(self, gid, sim_cls): - self._instance = sim_cls(gid) + def __init__(self, gid, late_init, sim_cls): self._gid = gid + self._instance = sim_cls(-1 if late_init else gid) def solve(self, micro_sim_input, dt): return self._instance.solve(micro_sim_input, dt) def get_state(self): return self._instance.get_state() def set_state(self, state): return self._instance.set_state(state) def get_global_id(self): return self._gid + def set_global_id(self, global_id): self._gid = global_id def initialize(self, *args, **kwargs): return self._instance.initialize(*args, **kwargs) def output(self): return self._instance.output() class MicroSimulationRemote(MicroSimulationInterface): - def __init__(self, gid, num_ranks, conn, sim_cls): + def __init__(self, gid, late_init, num_ranks, conn, sim_cls): self._sim_cls = sim_cls # backend impl class self._gid = gid self._num_ranks = num_ranks self._conn = conn + construct_cls = ConstructLateTask if late_init else ConstructTask for worker_id in range(self._num_ranks): - task = ConstructTask(self._gid, self._sim_cls) + task = construct_cls(self._gid, self._sim_cls) self._conn.send(worker_id, task) for worker_id in range(self._num_ranks): @@ -86,6 +90,9 @@ def set_state(self, state): def get_global_id(self): return self._gid + def set_global_id(self, global_id): + self._gid = global_id + def initialize(self, *args, **kwargs): for worker_id in range(self._num_ranks): task = InitializeTask(self._gid, *args, **kwargs) @@ -116,20 +123,27 @@ class MicroSimulationWrapper(MicroSimulationInterface): If only a single rank is in use: will contain the micro sim instance. Otherwise, it will delegate method calls to workers and not contain state. """ - def __init__(self, sim_cls, global_id, num_ranks, conn): + def __init__(self, name, sim_cls, global_id, late_init, num_ranks, conn): self._impl = None if num_ranks > 1 and conn is not None: - self._impl = MicroSimulationRemote(global_id, num_ranks, conn, sim_cls) + self._impl = MicroSimulationRemote(global_id, late_init, num_ranks, conn, sim_cls) else: - self._impl = MicroSimulationLocal(global_id, sim_cls) + self._impl = MicroSimulationLocal(global_id, late_init, sim_cls) + + self._external_data = dict() + self._name = name def solve(self, micro_sim_input, dt): return self._impl.solve(micro_sim_input, dt) def get_state(self): return self._impl.get_state() def set_state(self, state): return self._impl.set_state(state) def get_global_id(self): return self._impl.get_global_id() + def set_global_id(self, global_id): return self._impl.set_global_id(global_id) def initialize(self, *args, **kwargs): return self._impl.initialize(*args, **kwargs) def output(self): return self._impl.output() + @property + def attachments(self): return self._external_data + def __class__(self): return self._name class MicroSimulationClassAdapter: @@ -141,7 +155,7 @@ def __init__(self, sim_cls, name, num_ranks, conn, log): self._log = log def __class__(self): return self._name - def __call__(self, gid): return MicroSimulationWrapper(self._sim_cls, gid, self._num_ranks, self._conn) + def __call__(self, gid, *, late_init=False): return MicroSimulationWrapper(self._name, self._sim_cls, gid, late_init, self._num_ranks, self._conn) @property def backend_cls(self): return self._sim_cls diff --git a/micro_manager/model_manager.py b/micro_manager/model_manager.py index 6ff3b4f0..b81bcc25 100644 --- a/micro_manager/model_manager.py +++ b/micro_manager/model_manager.py @@ -59,7 +59,7 @@ def register(self, micro_sim_cls, stateless): ): self._has_output_map[micro_sim_cls] = True - def get_instance(self, gid, micro_sim_cls): + def get_instance(self, gid, micro_sim_cls, *, late_init=False): if micro_sim_cls not in self._registered_classes: raise RuntimeError("Trying to create instance of unknown class!") @@ -71,4 +71,4 @@ def get_instance(self, gid, micro_sim_cls): self._has_output_map[micro_sim_cls], ) else: - return micro_sim_cls(gid) + return micro_sim_cls(gid, late_init=late_init) diff --git a/micro_manager/tasking/task.py b/micro_manager/tasking/task.py index 92fe25ee..d2f3a01d 100644 --- a/micro_manager/tasking/task.py +++ b/micro_manager/tasking/task.py @@ -17,6 +17,15 @@ def initializer(gid, sim_cls, state_data): state_data[gid] = sim_cls(gid) return None +class ConstructLateTask(Task): + def __init__(self, gid, sim_cls): + super().__init__(ConstructLateTask.initializer, gid=gid, sim_cls=sim_cls) + + @staticmethod + def initializer(gid, sim_cls, state_data): + state_data[gid] = sim_cls(-1) + return None + class SolveTask(Task): def __init__(self, gid, sim_input, dt): super().__init__(SolveTask.solve, gid=gid, sim_input=sim_input, dt=dt) From 5098bdddbee78b604586cda1543c6e0af1aa9d75 Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Tue, 13 Jan 2026 19:38:10 +0100 Subject: [PATCH 12/21] small fixes seems to run, needs more testing --- micro_manager/adaptivity/adaptivity_selection.py | 2 +- micro_manager/micro_simulation.py | 4 +++- micro_manager/tasking/__init__.py | 0 micro_manager/tasking/worker_main.py | 16 ++++++++-------- pyproject.toml | 2 +- 5 files changed, 13 insertions(+), 11 deletions(-) create mode 100644 micro_manager/tasking/__init__.py diff --git a/micro_manager/adaptivity/adaptivity_selection.py b/micro_manager/adaptivity/adaptivity_selection.py index 75b5d404..bca40db1 100644 --- a/micro_manager/adaptivity/adaptivity_selection.py +++ b/micro_manager/adaptivity/adaptivity_selection.py @@ -1,7 +1,7 @@ from .global_adaptivity import GlobalAdaptivityCalculator from .global_adaptivity_lb import GlobalAdaptivityLBCalculator from .local_adaptivity import LocalAdaptivityCalculator -from adaptivity import AdaptivityCalculator +from .adaptivity import AdaptivityCalculator def create_adaptivity_calculator( config, diff --git a/micro_manager/micro_simulation.py b/micro_manager/micro_simulation.py index 9d787eb2..4648ce1d 100644 --- a/micro_manager/micro_simulation.py +++ b/micro_manager/micro_simulation.py @@ -161,8 +161,9 @@ def backend_cls(self): return self._sim_cls def check_initialize(self, test_instance, test_input): has_init = hasattr(self._sim_cls, 'initialize') + if not has_init: return False, False callable_init = callable(getattr(self._sim_cls, 'initialize')) - if not has_init or not callable_init: return False, False + if not callable_init: return False, False has_args = False @@ -202,6 +203,7 @@ def check_initialize(self, test_instance, test_input): def check_output(self): has_init = hasattr(self._sim_cls, 'output') + if not has_init: return False callable_init = callable(getattr(self._sim_cls, 'output')) return has_init and callable_init diff --git a/micro_manager/tasking/__init__.py b/micro_manager/tasking/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/micro_manager/tasking/worker_main.py b/micro_manager/tasking/worker_main.py index 6d66573a..a75c1064 100644 --- a/micro_manager/tasking/worker_main.py +++ b/micro_manager/tasking/worker_main.py @@ -2,7 +2,7 @@ import os from mpi4py import MPI -from .connection import Connection, MPIConnection, SocketConnection +from connection import Connection, MPIConnection, SocketConnection if __name__ == '__main__': parser = argparse.ArgumentParser() @@ -27,15 +27,15 @@ state_data = {} while True: - data = None - try: data = conn.recv(src_id) + task = None + try: task = conn.recv(src_id) except Exception: break - # TODO unpickle data into task and handle it - # TODO retain sim_obj... - # TODO should always be smth like this: output = task(state_data) - send_data = None # TODO needs to be set - try: conn.send(dst_id, send_data) + output = None + try: output = task(state_data) + except Exception: break + + try: conn.send(dst_id, output) except Exception: break conn.close() diff --git a/pyproject.toml b/pyproject.toml index dfc9c62f..9288f42d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ Repository = "https://github.com/precice/micro-manager" micro-manager-precice = "micro_manager:main" [tool.setuptools] -packages=["micro_manager", "micro_manager.adaptivity", "micro_manager.snapshot", "micro_manager.tools"] +packages=["micro_manager", "micro_manager.adaptivity", "micro_manager.snapshot", "micro_manager.tools", "micro_manager.tasking"] [tool.setuptools-git-versioning] enabled = true From 4164781b126d68254a8cceb76cf646d1eb77c7cf Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Thu, 15 Jan 2026 11:45:03 +0100 Subject: [PATCH 13/21] improved task model, less pickling --- micro_manager/adaptivity/model_adaptivity.py | 10 +--- micro_manager/micro_manager.py | 1 + micro_manager/micro_simulation.py | 29 ++++----- micro_manager/model_manager.py | 4 ++ micro_manager/snapshot/snapshot.py | 2 +- micro_manager/tasking/connection.py | 8 +++ micro_manager/tasking/task.py | 63 ++++++++++++++++---- micro_manager/tasking/worker_main.py | 15 ++++- 8 files changed, 94 insertions(+), 38 deletions(-) diff --git a/micro_manager/adaptivity/model_adaptivity.py b/micro_manager/adaptivity/model_adaptivity.py index 922e28e1..6b198c16 100644 --- a/micro_manager/adaptivity/model_adaptivity.py +++ b/micro_manager/adaptivity/model_adaptivity.py @@ -4,7 +4,7 @@ from typing import Union, Optional from ..config import Config -from ..micro_simulation import create_simulation_class +from ..micro_simulation import create_simulation_class, load_backend_class from micro_manager.tools.logging_wrapper import Logger from micro_manager.tools.misc import clamp_in_range from micro_manager.model_manager import ModelManager @@ -47,14 +47,10 @@ def __init__( stateless_flags = configurator.get_model_adaptivity_micro_stateless() self._model_classes = [] pos = 0 - CLASS_NAME = "MicroSimulation" for model_file in self._model_files: try: - model = getattr( - importlib.import_module(model_file, CLASS_NAME), - CLASS_NAME, - ) - self._model_classes.append(create_simulation_class(self._logger, model, num_ranks, conn)) + model = load_backend_class(model_file) + self._model_classes.append(create_simulation_class(self._logger, model, model_file, num_ranks, conn)) self._model_manager.register( self._model_classes[pos], stateless_flags[pos] ) diff --git a/micro_manager/micro_manager.py b/micro_manager/micro_manager.py index 4fdf31f7..df15cfea 100644 --- a/micro_manager/micro_manager.py +++ b/micro_manager/micro_manager.py @@ -586,6 +586,7 @@ def initialize(self) -> None: micro_problem_cls = create_simulation_class( self._logger, micro_problem_base, + self._config.get_micro_file_name(), self._config.get_tasking_num_workers(), self._conn, "MicroSimulationDefault" diff --git a/micro_manager/micro_simulation.py b/micro_manager/micro_simulation.py index 4648ce1d..d8ded0b0 100644 --- a/micro_manager/micro_simulation.py +++ b/micro_manager/micro_simulation.py @@ -42,15 +42,15 @@ def output(self): return self._instance.output() class MicroSimulationRemote(MicroSimulationInterface): - def __init__(self, gid, late_init, num_ranks, conn, sim_cls): - self._sim_cls = sim_cls # backend impl class + def __init__(self, gid, late_init, num_ranks, conn, cls_path): + self._cls_path = cls_path self._gid = gid self._num_ranks = num_ranks self._conn = conn construct_cls = ConstructLateTask if late_init else ConstructTask for worker_id in range(self._num_ranks): - task = construct_cls(self._gid, self._sim_cls) + task = construct_cls.send_args(self._gid, self._cls_path) self._conn.send(worker_id, task) for worker_id in range(self._num_ranks): @@ -58,7 +58,7 @@ def __init__(self, gid, late_init, num_ranks, conn, sim_cls): def solve(self, micro_sim_input, dt): for worker_id in range(self._num_ranks): - task = SolveTask(self._gid, micro_sim_input, dt) + task = SolveTask.send_args(self._gid, micro_sim_input, dt) self._conn.send(worker_id, task) result = None @@ -70,7 +70,7 @@ def solve(self, micro_sim_input, dt): def get_state(self): for worker_id in range(self._num_ranks): - task = GetStateTask(self._gid) + task = GetStateTask.send_args(self._gid) self._conn.send(worker_id, task) result = {} @@ -81,7 +81,7 @@ def get_state(self): def set_state(self, state): for worker_id in range(self._num_ranks): - task = SetStateTask(self._gid, state[worker_id]) + task = SetStateTask.send_args(self._gid, state[worker_id]) self._conn.send(worker_id, task) for worker_id in range(self._num_ranks): @@ -95,7 +95,7 @@ def set_global_id(self, global_id): def initialize(self, *args, **kwargs): for worker_id in range(self._num_ranks): - task = InitializeTask(self._gid, *args, **kwargs) + task = InitializeTask.send_args(self._gid, *args, **kwargs) self._conn.send(worker_id, task) result = None @@ -107,7 +107,7 @@ def initialize(self, *args, **kwargs): def output(self): for worker_id in range(self._num_ranks): - task = OutputTask(self._gid) + task = OutputTask.send_args(self._gid) self._conn.send(worker_id, task) result = None @@ -123,11 +123,11 @@ class MicroSimulationWrapper(MicroSimulationInterface): If only a single rank is in use: will contain the micro sim instance. Otherwise, it will delegate method calls to workers and not contain state. """ - def __init__(self, name, sim_cls, global_id, late_init, num_ranks, conn): + def __init__(self, name, sim_cls, cls_path, global_id, late_init, num_ranks, conn): self._impl = None if num_ranks > 1 and conn is not None: - self._impl = MicroSimulationRemote(global_id, late_init, num_ranks, conn, sim_cls) + self._impl = MicroSimulationRemote(global_id, late_init, num_ranks, conn, cls_path) else: self._impl = MicroSimulationLocal(global_id, late_init, sim_cls) @@ -147,15 +147,16 @@ def __class__(self): return self._name class MicroSimulationClassAdapter: - def __init__(self, sim_cls, name, num_ranks, conn, log): + def __init__(self, sim_cls, cls_path, name, num_ranks, conn, log): self._sim_cls = sim_cls + self._cls_path = cls_path self._name = name self._num_ranks = num_ranks self._conn = conn self._log = log def __class__(self): return self._name - def __call__(self, gid, *, late_init=False): return MicroSimulationWrapper(self._name, self._sim_cls, gid, late_init, self._num_ranks, self._conn) + def __call__(self, gid, *, late_init=False): return MicroSimulationWrapper(self._name, self._sim_cls, self._cls_path, gid, late_init, self._num_ranks, self._conn) @property def backend_cls(self): return self._sim_cls @@ -214,7 +215,7 @@ def load_backend_class(path_to_micro_file): return getattr(ipl.import_module(path_to_micro_file, CLS_NAME), CLS_NAME) -def create_simulation_class(log, micro_simulation_class, num_ranks, conn=None, sim_class_name=None): +def create_simulation_class(log, micro_simulation_class, path_to_micro_file, num_ranks, conn=None, sim_class_name=None): """ Creates a class Simulation which inherits from the class of the micro simulation. @@ -242,5 +243,5 @@ def create_simulation_class(log, micro_simulation_class, num_ranks, conn=None, s sim_class_name = f"MicroSimulation{create_simulation_class.sim_id}" - result_cls = MicroSimulationClassAdapter(micro_simulation_class, sim_class_name, num_ranks, conn, log) + result_cls = MicroSimulationClassAdapter(micro_simulation_class, path_to_micro_file, sim_class_name, num_ranks, conn, log) return result_cls diff --git a/micro_manager/model_manager.py b/micro_manager/model_manager.py index b81bcc25..8f442459 100644 --- a/micro_manager/model_manager.py +++ b/micro_manager/model_manager.py @@ -28,6 +28,10 @@ def set_state(self, state): def __class__(self): return self._backend.__class__ + @property + def attachments(self): + return self._backend.attachments + class ModelManager: def __init__(self): diff --git a/micro_manager/snapshot/snapshot.py b/micro_manager/snapshot/snapshot.py index 879ac7ff..a4411813 100644 --- a/micro_manager/snapshot/snapshot.py +++ b/micro_manager/snapshot/snapshot.py @@ -84,7 +84,7 @@ def solve(self) -> None: - Merge output in parallel run. """ - micro_problem_cls = create_simulation_class(self._logger, self._micro_problem, 1, None) + micro_problem_cls = create_simulation_class(self._logger, self._micro_problem, self._config.get_micro_file_name(), 1, None) # Loop over all macro parameters for elems in range(self._local_number_of_sims): diff --git a/micro_manager/tasking/connection.py b/micro_manager/tasking/connection.py index 49cbde17..3ff91366 100644 --- a/micro_manager/tasking/connection.py +++ b/micro_manager/tasking/connection.py @@ -184,6 +184,8 @@ def spawn_local_workers( conn : Connection Established connection on generator side """ + from .task import RegisterAllTask + if n_workers <= 1: return None conn = None @@ -230,4 +232,10 @@ def spawn_local_workers( n_workers=n_workers ) + from ..micro_simulation import load_backend_class + + for worker_id in range(n_workers): + conn.send(worker_id, RegisterAllTask(load_backend_class)) + conn.recv(worker_id) + return conn diff --git a/micro_manager/tasking/task.py b/micro_manager/tasking/task.py index d2f3a01d..36076b41 100644 --- a/micro_manager/tasking/task.py +++ b/micro_manager/tasking/task.py @@ -8,22 +8,34 @@ def __init__(self, fn, *args, **kwargs): def __call__(self, state_data: dict): return self.fn(*self.args, state_data=state_data, **self.kwargs) + @classmethod + def send_args(cls, *args, **kwargs): + return cls.__name__, args, kwargs + class ConstructTask(Task): - def __init__(self, gid, sim_cls): - super().__init__(ConstructTask.initializer, gid=gid, sim_cls=sim_cls) + def __init__(self, gid, cls_path): + super().__init__(ConstructTask.initializer, gid=gid, cls_path=cls_path) @staticmethod - def initializer(gid, sim_cls, state_data): - state_data[gid] = sim_cls(gid) + def initializer(gid, cls_path, state_data): + if cls_path not in state_data['sim_classes']: + state_data['sim_classes'][cls_path] = state_data['load_function'](cls_path) + cls = state_data['sim_classes'][cls_path] + + state_data['sim_instances'][gid] = cls(gid) return None class ConstructLateTask(Task): - def __init__(self, gid, sim_cls): - super().__init__(ConstructLateTask.initializer, gid=gid, sim_cls=sim_cls) + def __init__(self, gid, cls_path): + super().__init__(ConstructLateTask.initializer, gid=gid, cls_path=cls_path) @staticmethod - def initializer(gid, sim_cls, state_data): - state_data[gid] = sim_cls(-1) + def initializer(gid, cls_path, state_data): + if cls_path not in state_data['sim_classes']: + state_data['sim_classes'][cls_path] = state_data['load_function'](cls_path) + cls = state_data['sim_classes'][cls_path] + + state_data['sim_instances'][gid] = cls(-1) return None class SolveTask(Task): @@ -32,7 +44,7 @@ def __init__(self, gid, sim_input, dt): @staticmethod def solve(gid, sim_input, dt, state_data): - sim_output = state_data[gid].solve(sim_input, dt) + sim_output = state_data['sim_instances'][gid].solve(sim_input, dt) return sim_output class GetStateTask(Task): @@ -41,7 +53,7 @@ def __init__(self, gid): @staticmethod def get(gid, state_data): - return state_data[gid].get_state() + return state_data['sim_instances'][gid].get_state() class SetStateTask(Task): def __init__(self, gid, state): @@ -49,7 +61,7 @@ def __init__(self, gid, state): @staticmethod def set(gid, state, state_data): - state_data[gid].set_state(state) + state_data['sim_instances'][gid].set_state(state) return None class InitializeTask(Task): @@ -58,7 +70,7 @@ def __init__(self, gid, *args, **kwargs): @staticmethod def initialize(gid, state_data, *args, **kwargs): - return state_data[gid].initialize(*args, **kwargs) + return state_data['sim_instances'][gid].initialize(*args, **kwargs) class OutputTask(Task): def __init__(self, gid): @@ -66,4 +78,29 @@ def __init__(self, gid): @staticmethod def output(gid, state_data): - return state_data[gid].output() \ No newline at end of file + return state_data['sim_instances'][gid].output() + +class RegisterAllTask(Task): + def __init__(self, load_function): + super().__init__(RegisterAllTask.register, load_function=load_function) + + @staticmethod + def register(state_data, load_function): + task_dict = dict() + task_dict[ConstructTask.__name__] = ConstructTask + task_dict[ConstructLateTask.__name__] = ConstructLateTask + task_dict[SolveTask.__name__] = SolveTask + task_dict[GetStateTask.__name__] = GetStateTask + task_dict[SetStateTask.__name__] = SetStateTask + task_dict[InitializeTask.__name__] = InitializeTask + task_dict[OutputTask.__name__] = OutputTask + state_data['tasks'] = task_dict + state_data['sim_classes'] = dict() + state_data['sim_instances'] = dict() + state_data['load_function'] = load_function + return None + +def handle_task(state_data, task_descriptor): + name, args, kwargs = task_descriptor + task = state_data['tasks'][name](*args, **kwargs) + return task(state_data) \ No newline at end of file diff --git a/micro_manager/tasking/worker_main.py b/micro_manager/tasking/worker_main.py index a75c1064..2de3b915 100644 --- a/micro_manager/tasking/worker_main.py +++ b/micro_manager/tasking/worker_main.py @@ -1,6 +1,7 @@ import argparse import os from mpi4py import MPI +from task import handle_task from connection import Connection, MPIConnection, SocketConnection @@ -26,13 +27,21 @@ state_data = {} + # register possible tasks + register_task = None + try: register_task = conn.recv(src_id) + except Exception: raise RuntimeError("Failed to recv register tasks") + output = register_task(state_data) + try: conn.send(dst_id, output) + except Exception: raise RuntimeError("Failed to send register tasks output") + while True: - task = None - try: task = conn.recv(src_id) + task_descriptor = None + try: task_descriptor = conn.recv(src_id) except Exception: break output = None - try: output = task(state_data) + try: output = handle_task(state_data, task_descriptor) except Exception: break try: conn.send(dst_id, output) From fa26e43f656df215b533a5d68d7c5f43cca437e9 Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Fri, 16 Jan 2026 11:36:51 +0100 Subject: [PATCH 14/21] integrate tasking with other modules --- .../adaptivity/adaptivity_selection.py | 22 +- micro_manager/adaptivity/model_adaptivity.py | 42 ++-- micro_manager/config.py | 5 +- micro_manager/micro_manager.py | 58 ++++-- micro_manager/micro_simulation.py | 194 +++++++++++++----- micro_manager/model_manager.py | 12 +- micro_manager/snapshot/snapshot.py | 8 +- micro_manager/tasking/connection.py | 63 +++--- micro_manager/tasking/task.py | 53 +++-- micro_manager/tasking/worker_main.py | 38 ++-- 10 files changed, 342 insertions(+), 153 deletions(-) diff --git a/micro_manager/adaptivity/adaptivity_selection.py b/micro_manager/adaptivity/adaptivity_selection.py index bca40db1..3f005c5f 100644 --- a/micro_manager/adaptivity/adaptivity_selection.py +++ b/micro_manager/adaptivity/adaptivity_selection.py @@ -3,6 +3,7 @@ from .local_adaptivity import LocalAdaptivityCalculator from .adaptivity import AdaptivityCalculator + def create_adaptivity_calculator( config, local_number_of_sims, @@ -14,18 +15,25 @@ def create_adaptivity_calculator( comm, micro_problem_cls, model_manager, - use_lb + use_lb, ) -> AdaptivityCalculator: adaptivity_type = config.get_adaptivity_type() - if adaptivity_type == 'local': + if adaptivity_type == "local": return LocalAdaptivityCalculator( - config, local_number_of_sims, logger, rank, comm, micro_problem_cls, model_manager + config, + local_number_of_sims, + logger, + rank, + comm, + micro_problem_cls, + model_manager, ) - if adaptivity_type == 'global': + if adaptivity_type == "global": cls = GlobalAdaptivityCalculator - if use_lb: cls = GlobalAdaptivityLBCalculator + if use_lb: + cls = GlobalAdaptivityLBCalculator return cls( config, @@ -36,7 +44,7 @@ def create_adaptivity_calculator( rank, comm, micro_problem_cls, - model_manager + model_manager, ) - raise ValueError("Unknown adaptivity type") \ No newline at end of file + raise ValueError("Unknown adaptivity type") diff --git a/micro_manager/adaptivity/model_adaptivity.py b/micro_manager/adaptivity/model_adaptivity.py index 6b198c16..366151ce 100644 --- a/micro_manager/adaptivity/model_adaptivity.py +++ b/micro_manager/adaptivity/model_adaptivity.py @@ -50,7 +50,11 @@ def __init__( for model_file in self._model_files: try: model = load_backend_class(model_file) - self._model_classes.append(create_simulation_class(self._logger, model, model_file, num_ranks, conn)) + self._model_classes.append( + create_simulation_class( + self._logger, model, model_file, num_ranks, conn + ) + ) self._model_manager.register( self._model_classes[pos], stateless_flags[pos] ) @@ -167,14 +171,17 @@ def switch_models( gid = sim.get_global_id() tgt_cls = self.get_resolution_sim_class(tgt_res[idx]) - key = f"{sim.__name__}-state" - key_new = f"{tgt_cls.__name__}-state" + key = f"{sim.name}-state" + key_new = f"{tgt_cls.name}-state" new_state_exists = key_new in sim.attachments sim.attachments[key] = sim.get_state() - sim_new = self._model_manager.get_instance(gid, tgt_cls, late_init=new_state_exists) + sim_new = self._model_manager.get_instance( + gid, tgt_cls, late_init=new_state_exists + ) sim_new.attachments = sim.attachments + sim_new.attachments[key_new] = sim_new.get_state() if new_state_exists: sim_new_state = sim.attachments[key_new] @@ -182,11 +189,10 @@ def switch_models( sims[idx] = sim_new - def update_states( - self, - sims: list, - active_sim_ids: Optional = None, + self, + sims: list, + active_sim_ids: Optional = None, ): """ Stores the current state of the current model into local buffers. @@ -202,17 +208,17 @@ def update_states( active_sims = self._create_active_mask(active_sim_ids, size) for idx in range(size): - if not active_sims[idx]: continue + if not active_sims[idx]: + continue sim = sims[idx] - key = f"{sim.__name__}-state" + key = f"{sim.name}-state" sim.attachments[key] = sim.get_state() - def write_back_states( - self, - sims: list, - active_sim_ids: Optional = None, + self, + sims: list, + active_sim_ids: Optional = None, ): """ Loads the current state of the current model into local buffers. @@ -228,13 +234,13 @@ def write_back_states( active_sims = self._create_active_mask(active_sim_ids, size) for idx in range(size): - if not active_sims[idx]: continue + if not active_sims[idx]: + continue sim = sims[idx] - key = f"{sim.__name__}-state" + key = f"{sim.name}-state" sim.set_state(sim.attachments[key]) - def check_convergence( self, locations: np.ndarray, @@ -320,7 +326,7 @@ def get_sim_class_resolution(self, sim: Union) -> Union: target resolution """ return next( - (idx for idx, cls in enumerate(self._model_classes) if cls == type(sim)) + (idx for idx, cls in enumerate(self._model_classes) if cls.name == sim.name) ) def _gather_current_resolutions( diff --git a/micro_manager/config.py b/micro_manager/config.py index ce41ae98..b3826fe1 100644 --- a/micro_manager/config.py +++ b/micro_manager/config.py @@ -516,9 +516,10 @@ def read_json_micro_manager(self): "model_adaptivity_settings" ]["switching_function"] - if self._data["simulation_params"]["model_adaptivity_settings"][ + if ( "micro_stateless" - ]: + in self._data["simulation_params"]["model_adaptivity_settings"] + ): self._m_adap_micro_stateless = self._data["simulation_params"][ "model_adaptivity_settings" ]["micro_stateless"] diff --git a/micro_manager/micro_manager.py b/micro_manager/micro_manager.py index df15cfea..7af548de 100644 --- a/micro_manager/micro_manager.py +++ b/micro_manager/micro_manager.py @@ -201,20 +201,27 @@ def solve(self) -> None: if (self._adaptivity_in_every_implicit_step or first_iteration) and ( self._n % self._adaptivity_n == 0 ): - self._participant.start_profiling_section("micro_manager.solve.adaptivity_computation") + self._participant.start_profiling_section( + "micro_manager.solve.adaptivity_computation" + ) self._adaptivity_controller.compute_adaptivity( dt, self._micro_sims, self._data_for_adaptivity, ) - active_sim_gids = self._adaptivity_controller.get_active_sim_global_ids() - for gid in active_sim_gids: self._micro_sims_active_steps[gid] += 1 + active_sim_gids = ( + self._adaptivity_controller.get_active_sim_global_ids() + ) + for gid in active_sim_gids: + self._micro_sims_active_steps[gid] += 1 # Write a checkpoint if a simulation is just activated. # This checkpoint will be asynchronous to the checkpoints written at the start of the time window. if self._is_model_adaptivity_on: - self._model_adaptivity_controller.update_states(self._micro_sims, active_sim_gids) + self._model_adaptivity_controller.update_states( + self._micro_sims, active_sim_gids + ) for i in range(self._local_number_of_sims): if sim_states_cp[i] is None and self._micro_sims[i]: sim_states_cp[i] = self._micro_sims[i].attachments @@ -252,16 +259,26 @@ def solve(self) -> None: if self._participant.requires_writing_checkpoint(): active_sim_gids = None if self._is_adaptivity_on: - active_sim_gids = self._adaptivity_controller.get_active_sim_local_ids() + active_sim_gids = ( + self._adaptivity_controller.get_active_sim_local_ids() + ) if self._is_model_adaptivity_on: - self._model_adaptivity_controller.update_states(self._micro_sims, active_sim_gids) + self._model_adaptivity_controller.update_states( + self._micro_sims, active_sim_gids + ) for i in range(self._local_number_of_sims): - sim_states_cp[i] = self._micro_sims[i].attachments if self._micro_sims[i] else None + sim_states_cp[i] = ( + self._micro_sims[i].attachments + if self._micro_sims[i] + else None + ) else: for i in range(self._local_number_of_sims): sim_states_cp[i] = ( - self._micro_sims[i].get_state() if self._micro_sims[i] else None + self._micro_sims[i].get_state() + if self._micro_sims[i] + else None ) micro_sims_input = self._read_data_from_precice(dt) @@ -310,12 +327,16 @@ def solve(self) -> None: if self._is_model_adaptivity_on: active_sim_gids = None if self._is_adaptivity_on: - active_sim_gids = self._adaptivity_controller.get_active_sim_local_ids() + active_sim_gids = ( + self._adaptivity_controller.get_active_sim_local_ids() + ) for i in range(self._local_number_of_sims): if self._micro_sims[i]: - self._micro_sims[i].attachments = sim_states_cp[i] - self._model_adaptivity_controller.write_back_states(self._micro_sims, active_sim_gids) + self._micro_sims[i].attachments.update(sim_states_cp[i]) + self._model_adaptivity_controller.write_back_states( + self._micro_sims, active_sim_gids + ) else: for i in range(self._local_number_of_sims): @@ -327,7 +348,7 @@ def solve(self) -> None: # Time window has converged, now micro output can be generated if self._participant.is_time_window_complete(): self._t += dt # Update time to the end of the time window - self._n += 1 # Update time step to the end of the time window + self._n += 1 # Update time step to the end of the time window if self._micro_sims_have_output: if self._n % self._micro_n_out == 0: @@ -569,14 +590,19 @@ def initialize(self) -> None: worker_exec, num_ranks, self._config.get_tasking_backend(), - self._config.get_tasking_use_slurm() + self._config.get_tasking_use_slurm(), ) # load micro sim micro_problem_cls = None if self._is_model_adaptivity_on: self._model_adaptivity_controller: ModelAdaptivity = ModelAdaptivity( - self._model_manager, self._config, self._rank, self._log_file, self._conn, num_ranks, + self._model_manager, + self._config, + self._rank, + self._log_file, + self._conn, + num_ranks, ) micro_problem_cls = ( self._model_adaptivity_controller.get_resolution_sim_class(0) @@ -589,7 +615,7 @@ def initialize(self) -> None: self._config.get_micro_file_name(), self._config.get_tasking_num_workers(), self._conn, - "MicroSimulationDefault" + "MicroSimulationDefault", ) self._model_manager.register( micro_problem_cls, self._config.turn_on_micro_stateless() @@ -667,7 +693,7 @@ def initialize(self) -> None: # Boolean which states if the initialize() method of the micro simulation requires initial data ( self._micro_sims_init, - sim_requires_init_data + sim_requires_init_data, ) = micro_problem_cls.check_initialize( self._micro_sims[first_id], initial_data[first_id] if is_initial_data_available else None, diff --git a/micro_manager/micro_simulation.py b/micro_manager/micro_simulation.py index d8ded0b0..0014246f 100644 --- a/micro_manager/micro_simulation.py +++ b/micro_manager/micro_simulation.py @@ -10,21 +10,35 @@ class MicroSimulation. A global ID member variable is defined for the class Simu from .tasking.task import * + class MicroSimulationInterface(ABC): @abstractmethod - def solve(self, micro_sim_input, dt): pass + def solve(self, micro_sim_input, dt): + pass + @abstractmethod - def get_state(self): pass + def get_state(self): + pass + @abstractmethod - def set_state(self, state): pass + def set_state(self, state): + pass + @abstractmethod - def get_global_id(self): pass + def get_global_id(self): + pass + @abstractmethod - def set_global_id(self, global_id): pass + def set_global_id(self, global_id): + pass + @abstractmethod - def initialize(self, *args, **kwargs): pass + def initialize(self, *args, **kwargs): + pass + @abstractmethod - def output(self): pass + def output(self): + pass class MicroSimulationLocal(MicroSimulationInterface): @@ -32,13 +46,26 @@ def __init__(self, gid, late_init, sim_cls): self._gid = gid self._instance = sim_cls(-1 if late_init else gid) - def solve(self, micro_sim_input, dt): return self._instance.solve(micro_sim_input, dt) - def get_state(self): return self._instance.get_state() - def set_state(self, state): return self._instance.set_state(state) - def get_global_id(self): return self._gid - def set_global_id(self, global_id): self._gid = global_id - def initialize(self, *args, **kwargs): return self._instance.initialize(*args, **kwargs) - def output(self): return self._instance.output() + def solve(self, micro_sim_input, dt): + return self._instance.solve(micro_sim_input, dt) + + def get_state(self): + return self._instance.get_state() + + def set_state(self, state): + return self._instance.set_state(state) + + def get_global_id(self): + return self._gid + + def set_global_id(self, global_id): + self._gid = global_id + + def initialize(self, *args, **kwargs): + return self._instance.initialize(*args, **kwargs) + + def output(self): + return self._instance.output() class MicroSimulationRemote(MicroSimulationInterface): @@ -64,7 +91,8 @@ def solve(self, micro_sim_input, dt): result = None for worker_id in range(self._num_ranks): output = self._conn.recv(worker_id) - if worker_id == 0: result = output + if worker_id == 0: + result = output return result @@ -101,7 +129,8 @@ def initialize(self, *args, **kwargs): result = None for worker_id in range(self._num_ranks): output = self._conn.recv(worker_id) - if worker_id == 0: result = output + if worker_id == 0: + result = output return result @@ -113,7 +142,8 @@ def output(self): result = None for worker_id in range(self._num_ranks): output = self._conn.recv(worker_id) - if worker_id == 0: result = output + if worker_id == 0: + result = output return result @@ -123,27 +153,52 @@ class MicroSimulationWrapper(MicroSimulationInterface): If only a single rank is in use: will contain the micro sim instance. Otherwise, it will delegate method calls to workers and not contain state. """ + def __init__(self, name, sim_cls, cls_path, global_id, late_init, num_ranks, conn): self._impl = None if num_ranks > 1 and conn is not None: - self._impl = MicroSimulationRemote(global_id, late_init, num_ranks, conn, cls_path) + self._impl = MicroSimulationRemote( + global_id, late_init, num_ranks, conn, cls_path + ) else: self._impl = MicroSimulationLocal(global_id, late_init, sim_cls) self._external_data = dict() self._name = name - def solve(self, micro_sim_input, dt): return self._impl.solve(micro_sim_input, dt) - def get_state(self): return self._impl.get_state() - def set_state(self, state): return self._impl.set_state(state) - def get_global_id(self): return self._impl.get_global_id() - def set_global_id(self, global_id): return self._impl.set_global_id(global_id) - def initialize(self, *args, **kwargs): return self._impl.initialize(*args, **kwargs) - def output(self): return self._impl.output() + def solve(self, micro_sim_input, dt): + return self._impl.solve(micro_sim_input, dt) + + def get_state(self): + return self._impl.get_state() + + def set_state(self, state): + return self._impl.set_state(state) + + def get_global_id(self): + return self._impl.get_global_id() + + def set_global_id(self, global_id): + return self._impl.set_global_id(global_id) + + def initialize(self, *args, **kwargs): + return self._impl.initialize(*args, **kwargs) + + def output(self): + return self._impl.output() + @property - def attachments(self): return self._external_data - def __class__(self): return self._name + def attachments(self): + return self._external_data + + @attachments.setter + def attachments(self, value): + self._external_data = value + + @property + def name(self): + return self._name class MicroSimulationClassAdapter: @@ -155,16 +210,32 @@ def __init__(self, sim_cls, cls_path, name, num_ranks, conn, log): self._conn = conn self._log = log - def __class__(self): return self._name - def __call__(self, gid, *, late_init=False): return MicroSimulationWrapper(self._name, self._sim_cls, self._cls_path, gid, late_init, self._num_ranks, self._conn) @property - def backend_cls(self): return self._sim_cls + def name(self): + return self._name + + def __call__(self, gid, *, late_init=False): + return MicroSimulationWrapper( + self._name, + self._sim_cls, + self._cls_path, + gid, + late_init, + self._num_ranks, + self._conn, + ) + + @property + def backend_cls(self): + return self._sim_cls def check_initialize(self, test_instance, test_input): - has_init = hasattr(self._sim_cls, 'initialize') - if not has_init: return False, False - callable_init = callable(getattr(self._sim_cls, 'initialize')) - if not callable_init: return False, False + has_init = hasattr(self._sim_cls, "initialize") + if not has_init: + return False, False + callable_init = callable(getattr(self._sim_cls, "initialize")) + if not callable_init: + return False, False has_args = False @@ -172,16 +243,18 @@ def check_initialize(self, test_instance, test_input): try: argspec = inspect.getfullargspec(self._sim_cls.initialize) # The first argument in the signature is self - if len(argspec.args) == 1: has_args = False - elif len(argspec.args) == 2: has_args = True + if len(argspec.args) == 1: + has_args = False + elif len(argspec.args) == 2: + has_args = True else: raise Exception( "The initialize() method of the Micro simulation has an incorrect number of arguments." ) except TypeError: self._log.log_info_rank_zero( - "The signature of initialize() method of the micro simulation cannot be determined. " + - "Trying to determine the signature by calling the method." + "The signature of initialize() method of the micro simulation cannot be determined. " + + "Trying to determine the signature by calling the method." ) # Try to call the initialize() method without initial data try: @@ -189,8 +262,8 @@ def check_initialize(self, test_instance, test_input): has_args = False except TypeError: self._log.log_info_rank_zero( - "The initialize() method of the micro simulation has arguments. " + - "Attempting to call it again with initial data." + "The initialize() method of the micro simulation has arguments. " + + "Attempting to call it again with initial data." ) try: test_instance.initialize(test_input) @@ -203,19 +276,27 @@ def check_initialize(self, test_instance, test_input): return has_init and callable_init, has_args def check_output(self): - has_init = hasattr(self._sim_cls, 'output') - if not has_init: return False - callable_init = callable(getattr(self._sim_cls, 'output')) + has_init = hasattr(self._sim_cls, "output") + if not has_init: + return False + callable_init = callable(getattr(self._sim_cls, "output")) return has_init and callable_init def load_backend_class(path_to_micro_file): - CLS_NAME = 'MicroSimulation' + CLS_NAME = "MicroSimulation" return getattr(ipl.import_module(path_to_micro_file, CLS_NAME), CLS_NAME) -def create_simulation_class(log, micro_simulation_class, path_to_micro_file, num_ranks, conn=None, sim_class_name=None): +def create_simulation_class( + log, + micro_simulation_class, + path_to_micro_file, + num_ranks, + conn=None, + sim_class_name=None, +): """ Creates a class Simulation which inherits from the class of the micro simulation. @@ -232,16 +313,23 @@ def create_simulation_class(log, micro_simulation_class, path_to_micro_file, num Simulation : class Definition of class Simulation defined in this function. """ - if not hasattr(micro_simulation_class, "get_global_id"): raise ValueError("Invalid micro simulation class") - if not hasattr(micro_simulation_class, "get_state"): raise ValueError("Invalid micro simulation class") - if not hasattr(micro_simulation_class, "set_state"): raise ValueError("Invalid micro simulation class") - if not hasattr(micro_simulation_class, "solve"): raise ValueError("Invalid micro simulation class") + if not hasattr(micro_simulation_class, "get_global_id"): + raise ValueError("Invalid micro simulation class") + if not hasattr(micro_simulation_class, "get_state"): + raise ValueError("Invalid micro simulation class") + if not hasattr(micro_simulation_class, "set_state"): + raise ValueError("Invalid micro simulation class") + if not hasattr(micro_simulation_class, "solve"): + raise ValueError("Invalid micro simulation class") if sim_class_name is None: - if not hasattr(create_simulation_class, "sim_id"): create_simulation_class.sim_id = 0 - else: create_simulation_class.sim_id += 1 + if not hasattr(create_simulation_class, "sim_id"): + create_simulation_class.sim_id = 0 + else: + create_simulation_class.sim_id += 1 sim_class_name = f"MicroSimulation{create_simulation_class.sim_id}" - - result_cls = MicroSimulationClassAdapter(micro_simulation_class, path_to_micro_file, sim_class_name, num_ranks, conn, log) + result_cls = MicroSimulationClassAdapter( + micro_simulation_class, path_to_micro_file, sim_class_name, num_ranks, conn, log + ) return result_cls diff --git a/micro_manager/model_manager.py b/micro_manager/model_manager.py index 8f442459..e7c1ccea 100644 --- a/micro_manager/model_manager.py +++ b/micro_manager/model_manager.py @@ -32,6 +32,14 @@ def __class__(self): def attachments(self): return self._backend.attachments + @attachments.setter + def attachments(self, value): + self._backend.attachments = value + + @property + def name(self): + return self._backend.name + class ModelManager: def __init__(self): @@ -49,7 +57,9 @@ def register(self, micro_sim_cls, stateless): self._stateless_map[micro_sim_cls] = stateless if stateless: - self._backend_map[micro_sim_cls] = micro_sim_cls(-1) + self._backend_map[micro_sim_cls] = micro_sim_cls( + len(self._registered_classes) - 1 + ) self._has_init_map[micro_sim_cls] = False if hasattr(micro_sim_cls, "initialize") and callable( diff --git a/micro_manager/snapshot/snapshot.py b/micro_manager/snapshot/snapshot.py index a4411813..618e5987 100644 --- a/micro_manager/snapshot/snapshot.py +++ b/micro_manager/snapshot/snapshot.py @@ -84,7 +84,13 @@ def solve(self) -> None: - Merge output in parallel run. """ - micro_problem_cls = create_simulation_class(self._logger, self._micro_problem, self._config.get_micro_file_name(), 1, None) + micro_problem_cls = create_simulation_class( + self._logger, + self._micro_problem, + self._config.get_micro_file_name(), + 1, + None, + ) # Loop over all macro parameters for elems in range(self._local_number_of_sims): diff --git a/micro_manager/tasking/connection.py b/micro_manager/tasking/connection.py index 3ff91366..bcfd4b7a 100644 --- a/micro_manager/tasking/connection.py +++ b/micro_manager/tasking/connection.py @@ -8,13 +8,19 @@ from typing import Any, Dict, Optional from mpi4py import MPI + class Connection(ABC): @abstractmethod - def send(self, dst_id: int, obj: Any) -> None: pass + def send(self, dst_id: int, obj: Any) -> None: + pass + @abstractmethod - def recv(self, src_id: int) -> Any: pass + def recv(self, src_id: int) -> Any: + pass + @abstractmethod - def close(self) -> None: pass + def close(self) -> None: + pass class MPIConnection(Connection): @@ -22,7 +28,9 @@ def __init__(self): self.inter_comm = None @classmethod - def create_workers(cls, worker_exec: str, mpi_args: Optional, n_workers: int) -> "MPIConnection": + def create_workers( + cls, worker_exec: str, mpi_args: Optional, n_workers: int + ) -> "MPIConnection": comm = MPI.COMM_SELF conn = cls() conn.inter_comm = comm.Spawn( @@ -55,7 +63,9 @@ def __init__(self): self.sockets: Dict[int, socket.socket] = {} @classmethod - def create_workers(cls, worker_exec: str, launcher: list, host: str, n_workers: int) -> "SocketConnection": + def create_workers( + cls, worker_exec: str, launcher: list, host: str, n_workers: int + ) -> "SocketConnection": # create listening socket with ephemeral port server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server.bind((host, 0)) # kernel picks free port @@ -65,9 +75,12 @@ def create_workers(cls, worker_exec: str, launcher: list, host: str, n_workers: executable = [ "python", worker_exec, - "--backend", "socket", - "--host", host, - "--port", str(port), + "--backend", + "socket", + "--host", + host, + "--port", + str(port), ] cmd = [] cmd.extend(launcher) @@ -113,7 +126,8 @@ def recv(self, src_id: int) -> Any: return pickle.loads(payload) def close(self) -> None: - for sock in self.sockets.values(): sock.close() + for sock in self.sockets.values(): + sock.close() self.sockets.clear() @@ -186,22 +200,26 @@ def spawn_local_workers( """ from .task import RegisterAllTask - if n_workers <= 1: return None + if n_workers <= 1: + return None conn = None # MPI BACKEND (non-Slurm only) if backend == "mpi": - if is_slurm: raise RuntimeError( - "MPI backend is not supported under Slurm. " - "Use socket backend instead." - ) + if is_slurm: + raise RuntimeError( + "MPI backend is not supported under Slurm. " + "Use socket backend instead." + ) comm = MPI.COMM_WORLD local_rank = comm.Get_rank() conn = MPIConnection.create_workers( worker_exec=worker_exec, mpi_args=[ - "--backend", "mpi", - "--parentrank", str(local_rank), + "--backend", + "mpi", + "--parentrank", + str(local_rank), ], n_workers=n_workers, ) @@ -215,21 +233,20 @@ def spawn_local_workers( if is_slurm: launcher = [ "srun", - #"--exclusive", - "--ntasks", str(n_workers), + # "--exclusive", + "--ntasks", + str(n_workers), "--kill-on-bad-exit=1", ] else: launcher = [ "mpiexec", - "-n", str(n_workers), + "-n", + str(n_workers), ] conn = SocketConnection.create_workers( - worker_exec=worker_exec, - launcher=launcher, - host=host, - n_workers=n_workers + worker_exec=worker_exec, launcher=launcher, host=host, n_workers=n_workers ) from ..micro_simulation import load_backend_class diff --git a/micro_manager/tasking/task.py b/micro_manager/tasking/task.py index 36076b41..1b5b0f67 100644 --- a/micro_manager/tasking/task.py +++ b/micro_manager/tasking/task.py @@ -1,4 +1,3 @@ - class Task: def __init__(self, fn, *args, **kwargs): self.fn = fn @@ -12,48 +11,57 @@ def __call__(self, state_data: dict): def send_args(cls, *args, **kwargs): return cls.__name__, args, kwargs + class ConstructTask(Task): def __init__(self, gid, cls_path): super().__init__(ConstructTask.initializer, gid=gid, cls_path=cls_path) @staticmethod def initializer(gid, cls_path, state_data): - if cls_path not in state_data['sim_classes']: - state_data['sim_classes'][cls_path] = state_data['load_function'](cls_path) - cls = state_data['sim_classes'][cls_path] + if cls_path not in state_data["sim_classes"]: + state_data["sim_classes"][cls_path] = state_data["load_function"](cls_path) + cls = state_data["sim_classes"][cls_path] - state_data['sim_instances'][gid] = cls(gid) + if gid in state_data["sim_classes"]: + del state_data["sim_classes"][gid] + state_data["sim_instances"][gid] = cls(gid) return None + class ConstructLateTask(Task): def __init__(self, gid, cls_path): super().__init__(ConstructLateTask.initializer, gid=gid, cls_path=cls_path) @staticmethod def initializer(gid, cls_path, state_data): - if cls_path not in state_data['sim_classes']: - state_data['sim_classes'][cls_path] = state_data['load_function'](cls_path) - cls = state_data['sim_classes'][cls_path] + if cls_path not in state_data["sim_classes"]: + state_data["sim_classes"][cls_path] = state_data["load_function"](cls_path) + cls = state_data["sim_classes"][cls_path] - state_data['sim_instances'][gid] = cls(-1) + if gid in state_data["sim_classes"]: + del state_data["sim_classes"][gid] + state_data["sim_instances"][gid] = cls(-1) return None + class SolveTask(Task): def __init__(self, gid, sim_input, dt): super().__init__(SolveTask.solve, gid=gid, sim_input=sim_input, dt=dt) @staticmethod def solve(gid, sim_input, dt, state_data): - sim_output = state_data['sim_instances'][gid].solve(sim_input, dt) + sim_output = state_data["sim_instances"][gid].solve(sim_input, dt) return sim_output + class GetStateTask(Task): def __init__(self, gid): super().__init__(GetStateTask.get, gid=gid) @staticmethod def get(gid, state_data): - return state_data['sim_instances'][gid].get_state() + return state_data["sim_instances"][gid].get_state() + class SetStateTask(Task): def __init__(self, gid, state): @@ -61,16 +69,18 @@ def __init__(self, gid, state): @staticmethod def set(gid, state, state_data): - state_data['sim_instances'][gid].set_state(state) + state_data["sim_instances"][gid].set_state(state) return None + class InitializeTask(Task): def __init__(self, gid, *args, **kwargs): super().__init__(InitializeTask.initialize, *args, gid=gid, **kwargs) @staticmethod def initialize(gid, state_data, *args, **kwargs): - return state_data['sim_instances'][gid].initialize(*args, **kwargs) + return state_data["sim_instances"][gid].initialize(*args, **kwargs) + class OutputTask(Task): def __init__(self, gid): @@ -78,7 +88,8 @@ def __init__(self, gid): @staticmethod def output(gid, state_data): - return state_data['sim_instances'][gid].output() + return state_data["sim_instances"][gid].output() + class RegisterAllTask(Task): def __init__(self, load_function): @@ -94,13 +105,15 @@ def register(state_data, load_function): task_dict[SetStateTask.__name__] = SetStateTask task_dict[InitializeTask.__name__] = InitializeTask task_dict[OutputTask.__name__] = OutputTask - state_data['tasks'] = task_dict - state_data['sim_classes'] = dict() - state_data['sim_instances'] = dict() - state_data['load_function'] = load_function + state_data["tasks"] = task_dict + state_data["sim_classes"] = dict() + state_data["sim_instances"] = dict() + state_data["load_function"] = load_function return None + def handle_task(state_data, task_descriptor): name, args, kwargs = task_descriptor - task = state_data['tasks'][name](*args, **kwargs) - return task(state_data) \ No newline at end of file + task = state_data["tasks"][name](*args, **kwargs) + print(f"handling task: {name} args={args} kwargs={kwargs}") + return task(state_data) diff --git a/micro_manager/tasking/worker_main.py b/micro_manager/tasking/worker_main.py index 2de3b915..d0db4103 100644 --- a/micro_manager/tasking/worker_main.py +++ b/micro_manager/tasking/worker_main.py @@ -5,12 +5,16 @@ from connection import Connection, MPIConnection, SocketConnection -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--backend", required=True, choices=["mpi", "socket"]) parser.add_argument("--host", help="IP or localhost") parser.add_argument("--port", type=int, help="Port to open port in micro manager") - parser.add_argument("--parentrank", type=int, help="Parent rank of spawning micro manager mpi instance") + parser.add_argument( + "--parentrank", + type=int, + help="Parent rank of spawning micro manager mpi instance", + ) args = parser.parse_args() rank = MPI.COMM_WORLD.Get_rank() @@ -29,22 +33,32 @@ # register possible tasks register_task = None - try: register_task = conn.recv(src_id) - except Exception: raise RuntimeError("Failed to recv register tasks") + try: + register_task = conn.recv(src_id) + except Exception: + raise RuntimeError("Failed to recv register tasks") output = register_task(state_data) - try: conn.send(dst_id, output) - except Exception: raise RuntimeError("Failed to send register tasks output") + try: + conn.send(dst_id, output) + except Exception: + raise RuntimeError("Failed to send register tasks output") while True: task_descriptor = None - try: task_descriptor = conn.recv(src_id) - except Exception: break + try: + task_descriptor = conn.recv(src_id) + except Exception: + break output = None - try: output = handle_task(state_data, task_descriptor) - except Exception: break + try: + output = handle_task(state_data, task_descriptor) + except Exception: + break - try: conn.send(dst_id, output) - except Exception: break + try: + conn.send(dst_id, output) + except Exception: + break conn.close() From 7b366ea33d27852a4752c87e433845bd27d84c16 Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Fri, 16 Jan 2026 16:26:07 +0100 Subject: [PATCH 15/21] remove debug output --- micro_manager/tasking/task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/micro_manager/tasking/task.py b/micro_manager/tasking/task.py index 1b5b0f67..d04e6359 100644 --- a/micro_manager/tasking/task.py +++ b/micro_manager/tasking/task.py @@ -115,5 +115,5 @@ def register(state_data, load_function): def handle_task(state_data, task_descriptor): name, args, kwargs = task_descriptor task = state_data["tasks"][name](*args, **kwargs) - print(f"handling task: {name} args={args} kwargs={kwargs}") + #print(f"handling task: {name} args={args} kwargs={kwargs}") return task(state_data) From 37d497cfb086e58838cf60ccbf02ab8cb3baddca Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Tue, 20 Jan 2026 09:16:54 +0100 Subject: [PATCH 16/21] fix tests --- micro_manager/tasking/task.py | 2 +- tests/integration/test_unit_cube/micro_dummy.py | 3 +++ tests/unit/test_micro_manager.py | 9 +++++++++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/micro_manager/tasking/task.py b/micro_manager/tasking/task.py index d04e6359..55b5cb46 100644 --- a/micro_manager/tasking/task.py +++ b/micro_manager/tasking/task.py @@ -115,5 +115,5 @@ def register(state_data, load_function): def handle_task(state_data, task_descriptor): name, args, kwargs = task_descriptor task = state_data["tasks"][name](*args, **kwargs) - #print(f"handling task: {name} args={args} kwargs={kwargs}") + # print(f"handling task: {name} args={args} kwargs={kwargs}") return task(state_data) diff --git a/tests/integration/test_unit_cube/micro_dummy.py b/tests/integration/test_unit_cube/micro_dummy.py index b2ec8b4b..6e35bc4f 100644 --- a/tests/integration/test_unit_cube/micro_dummy.py +++ b/tests/integration/test_unit_cube/micro_dummy.py @@ -53,3 +53,6 @@ def get_state(self): def set_state(self, state): self._state = copy.deepcopy(state) + + def get_global_id(self): + return self._sim_id diff --git a/tests/unit/test_micro_manager.py b/tests/unit/test_micro_manager.py index 9f656b2f..87438ad7 100644 --- a/tests/unit/test_micro_manager.py +++ b/tests/unit/test_micro_manager.py @@ -21,6 +21,15 @@ def solve(self, macro_data, dt): "micro-vector-data": macro_data["macro-vector-data"] + 1, } + def get_global_id(self): + pass + + def get_state(self): + return None + + def set_state(self, state): + pass + class TestFunctioncalls(TestCase): def setUp(self): From e6d67031d8545d066ada172388171e447f50d725 Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Tue, 20 Jan 2026 09:42:59 +0100 Subject: [PATCH 17/21] fix example micro files --- examples/cpp-dummy/micro_cpp_dummy.cpp | 6 ++++++ examples/cpp-dummy/micro_cpp_dummy.hpp | 1 + examples/python-dummy/micro_dummy.py | 3 +++ 3 files changed, 10 insertions(+) diff --git a/examples/cpp-dummy/micro_cpp_dummy.cpp b/examples/cpp-dummy/micro_cpp_dummy.cpp index 3127cc84..368ed895 100644 --- a/examples/cpp-dummy/micro_cpp_dummy.cpp +++ b/examples/cpp-dummy/micro_cpp_dummy.cpp @@ -59,6 +59,11 @@ py::list MicroSimulation::get_state() const return state_python; } +int MicroSimulation::get_global_id() const +{ + return _sim_id; +} + PYBIND11_MODULE(micro_dummy, m) { // optional docstring m.doc() = "pybind11 micro dummy plugin"; @@ -68,6 +73,7 @@ PYBIND11_MODULE(micro_dummy, m) { .def("solve", &MicroSimulation::solve) .def("get_state", &MicroSimulation::get_state) .def("set_state", &MicroSimulation::set_state) + .def("get_global_id", &MicroSimulation::get_global_id) // Pickling support does not work currently, as there is no way to pass the simulation ID to the new instance ms. .def(py::pickle( // https://pybind11.readthedocs.io/en/latest/advanced/classes.html#pickling-support [](const MicroSimulation &ms) { // __getstate__ diff --git a/examples/cpp-dummy/micro_cpp_dummy.hpp b/examples/cpp-dummy/micro_cpp_dummy.hpp index fb230ea1..29c306e0 100644 --- a/examples/cpp-dummy/micro_cpp_dummy.hpp +++ b/examples/cpp-dummy/micro_cpp_dummy.hpp @@ -20,6 +20,7 @@ class MicroSimulation void set_state(py::list state); py::list get_state() const; + int get_global_id() const; private: int _sim_id; diff --git a/examples/python-dummy/micro_dummy.py b/examples/python-dummy/micro_dummy.py index 638e2051..3bc074fb 100644 --- a/examples/python-dummy/micro_dummy.py +++ b/examples/python-dummy/micro_dummy.py @@ -32,3 +32,6 @@ def set_state(self, state): def get_state(self): return self._state + + def get_global_id(self): + return self._sim_id From f0bf90aad9cb844df69604b263d7148f20ad794c Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Tue, 20 Jan 2026 09:53:14 +0100 Subject: [PATCH 18/21] fix more test add member field delegation for sim wrapper for testing purposes --- micro_manager/micro_simulation.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/micro_manager/micro_simulation.py b/micro_manager/micro_simulation.py index 0014246f..1355ef27 100644 --- a/micro_manager/micro_simulation.py +++ b/micro_manager/micro_simulation.py @@ -61,6 +61,9 @@ def get_global_id(self): def set_global_id(self, global_id): self._gid = global_id + def __getattr__(self, name): + return getattr(self._impl, name) + def initialize(self, *args, **kwargs): return self._instance.initialize(*args, **kwargs) @@ -188,6 +191,9 @@ def initialize(self, *args, **kwargs): def output(self): return self._impl.output() + def __getattr__(self, name): + return getattr(self._impl, name) + @property def attachments(self): return self._external_data From 28d424e96074adf559582e59c919064e8c6b55de Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Tue, 20 Jan 2026 09:57:17 +0100 Subject: [PATCH 19/21] fix last change --- micro_manager/micro_simulation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/micro_manager/micro_simulation.py b/micro_manager/micro_simulation.py index 1355ef27..e0df10e6 100644 --- a/micro_manager/micro_simulation.py +++ b/micro_manager/micro_simulation.py @@ -62,7 +62,7 @@ def set_global_id(self, global_id): self._gid = global_id def __getattr__(self, name): - return getattr(self._impl, name) + return getattr(self._instance, name) def initialize(self, *args, **kwargs): return self._instance.initialize(*args, **kwargs) From 5d025a4026216a6788b97f30903192912e98cde2 Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Tue, 20 Jan 2026 10:08:38 +0100 Subject: [PATCH 20/21] fix test micro sim classes --- tests/unit/test_adaptivity_parallel.py | 3 +++ tests/unit/test_adaptivity_serial.py | 3 +++ tests/unit/test_global_adaptivity_lb.py | 3 +++ tests/unit/test_micro_simulation_crash_handling.py | 9 +++++++++ tests/unit/test_snapshot_computation.py | 9 +++++++++ 5 files changed, 27 insertions(+) diff --git a/tests/unit/test_adaptivity_parallel.py b/tests/unit/test_adaptivity_parallel.py index 12379b40..68391e0c 100644 --- a/tests/unit/test_adaptivity_parallel.py +++ b/tests/unit/test_adaptivity_parallel.py @@ -21,6 +21,9 @@ def set_state(self, state): def get_state(self): return self._state.copy() + def solve(self, micro_input, dt): + pass + class ModelManager: def get_instance(self, gid, micro_problem_cls): diff --git a/tests/unit/test_adaptivity_serial.py b/tests/unit/test_adaptivity_serial.py index f92c5b00..fb4aa04a 100644 --- a/tests/unit/test_adaptivity_serial.py +++ b/tests/unit/test_adaptivity_serial.py @@ -25,6 +25,9 @@ def set_state(self, state): def get_state(self): pass + def solve(self, micro_input, dt): + pass + class ModelManager: def get_instance(self, gid, micro_problem_cls): diff --git a/tests/unit/test_global_adaptivity_lb.py b/tests/unit/test_global_adaptivity_lb.py index 1e1e211b..4e6af391 100644 --- a/tests/unit/test_global_adaptivity_lb.py +++ b/tests/unit/test_global_adaptivity_lb.py @@ -22,6 +22,9 @@ def set_state(self, state): def get_state(self): return self._state.copy() + def solve(self, micro_input, dt): + pass + class ModelManager: def get_instance(self, gid, micro_problem_cls): diff --git a/tests/unit/test_micro_simulation_crash_handling.py b/tests/unit/test_micro_simulation_crash_handling.py index c7861aa1..7da636c6 100644 --- a/tests/unit/test_micro_simulation_crash_handling.py +++ b/tests/unit/test_micro_simulation_crash_handling.py @@ -21,6 +21,15 @@ def solve(self, macro_data, dt): "micro-scalar-data": macro_data["macro-scalar-data"], } + def get_state(self): + return None + + def set_state(self, state): + pass + + def get_global_id(self): + return self.sim_id + class TestSimulationCrashHandling(TestCase): def test_crash_handling(self): diff --git a/tests/unit/test_snapshot_computation.py b/tests/unit/test_snapshot_computation.py index c4cc807d..614fcab8 100644 --- a/tests/unit/test_snapshot_computation.py +++ b/tests/unit/test_snapshot_computation.py @@ -19,6 +19,15 @@ def solve(self, macro_data, dt): "micro-vector-data": macro_data["macro-vector-data"] + 1, } + def get_state(self): + return None + + def set_state(self, state): + pass + + def get_global_id(self): + pass + class TestFunctionCalls(TestCase): def setUp(self): From 70e1d9812879ab092858615b16d9a790baf23a93 Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Tue, 20 Jan 2026 10:25:09 +0100 Subject: [PATCH 21/21] fix snapshot test --- tests/unit/test_snapshot_computation.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_snapshot_computation.py b/tests/unit/test_snapshot_computation.py index 614fcab8..4553312f 100644 --- a/tests/unit/test_snapshot_computation.py +++ b/tests/unit/test_snapshot_computation.py @@ -94,7 +94,11 @@ def test_solve_micro_sims(self): snapshot_object._micro_problem = MicroSimulation snapshot_object._micro_sims = create_simulation_class( - snapshot_object._micro_problem + MagicMock(), + snapshot_object._micro_problem, + None, + 1, + None, )(0) micro_sim_output = snapshot_object._solve_micro_simulation(self.fake_read_data)