diff --git a/examples/cpp-dummy/micro_cpp_dummy.cpp b/examples/cpp-dummy/micro_cpp_dummy.cpp index 3127cc84..368ed895 100644 --- a/examples/cpp-dummy/micro_cpp_dummy.cpp +++ b/examples/cpp-dummy/micro_cpp_dummy.cpp @@ -59,6 +59,11 @@ py::list MicroSimulation::get_state() const return state_python; } +int MicroSimulation::get_global_id() const +{ + return _sim_id; +} + PYBIND11_MODULE(micro_dummy, m) { // optional docstring m.doc() = "pybind11 micro dummy plugin"; @@ -68,6 +73,7 @@ PYBIND11_MODULE(micro_dummy, m) { .def("solve", &MicroSimulation::solve) .def("get_state", &MicroSimulation::get_state) .def("set_state", &MicroSimulation::set_state) + .def("get_global_id", &MicroSimulation::get_global_id) // Pickling support does not work currently, as there is no way to pass the simulation ID to the new instance ms. .def(py::pickle( // https://pybind11.readthedocs.io/en/latest/advanced/classes.html#pickling-support [](const MicroSimulation &ms) { // __getstate__ diff --git a/examples/cpp-dummy/micro_cpp_dummy.hpp b/examples/cpp-dummy/micro_cpp_dummy.hpp index fb230ea1..29c306e0 100644 --- a/examples/cpp-dummy/micro_cpp_dummy.hpp +++ b/examples/cpp-dummy/micro_cpp_dummy.hpp @@ -20,6 +20,7 @@ class MicroSimulation void set_state(py::list state); py::list get_state() const; + int get_global_id() const; private: int _sim_id; diff --git a/examples/python-dummy/micro_dummy.py b/examples/python-dummy/micro_dummy.py index 638e2051..3bc074fb 100644 --- a/examples/python-dummy/micro_dummy.py +++ b/examples/python-dummy/micro_dummy.py @@ -32,3 +32,6 @@ def set_state(self, state): def get_state(self): return self._state + + def get_global_id(self): + return self._sim_id diff --git a/micro_manager/adaptivity/adaptivity.py b/micro_manager/adaptivity/adaptivity.py index 293edb7f..70894bf5 100644 --- a/micro_manager/adaptivity/adaptivity.py +++ b/micro_manager/adaptivity/adaptivity.py @@ -11,7 +11,7 @@ class AdaptivityCalculator: def __init__( - self, configurator, nsims, micro_problem_cls, base_logger, rank + self, configurator, nsims, micro_problem_cls, model_manager, base_logger, rank ) -> None: """ Class constructor. @@ -24,6 +24,8 @@ def __init__( Number of micro simulations. micro_problem_cls : callable Class of micro problem. + model_manager : object + Handles instantiation of micro simulation. base_logger : object of class Logger Logger object to log messages. rank : int @@ -37,6 +39,7 @@ def __init__( self._adaptivity_output_type = configurator.get_adaptivity_output_type() self._micro_problem_cls = micro_problem_cls + self._model_manager = model_manager self._coarse_tol = 0.0 self._ref_tol = 0.0 diff --git a/micro_manager/adaptivity/adaptivity_selection.py b/micro_manager/adaptivity/adaptivity_selection.py new file mode 100644 index 00000000..3f005c5f --- /dev/null +++ b/micro_manager/adaptivity/adaptivity_selection.py @@ -0,0 +1,50 @@ +from .global_adaptivity import GlobalAdaptivityCalculator +from .global_adaptivity_lb import GlobalAdaptivityLBCalculator +from .local_adaptivity import LocalAdaptivityCalculator +from .adaptivity import AdaptivityCalculator + + +def create_adaptivity_calculator( + config, + local_number_of_sims, + global_number_of_sims, + global_ids_of_local_sims, + participant, + logger, + rank, + comm, + micro_problem_cls, + model_manager, + use_lb, +) -> AdaptivityCalculator: + adaptivity_type = config.get_adaptivity_type() + + if adaptivity_type == "local": + return LocalAdaptivityCalculator( + config, + local_number_of_sims, + logger, + rank, + comm, + micro_problem_cls, + model_manager, + ) + + if adaptivity_type == "global": + cls = GlobalAdaptivityCalculator + if use_lb: + cls = GlobalAdaptivityLBCalculator + + return cls( + config, + global_number_of_sims, + global_ids_of_local_sims, + participant, + logger, + rank, + comm, + micro_problem_cls, + model_manager, + ) + + raise ValueError("Unknown adaptivity type") diff --git a/micro_manager/adaptivity/global_adaptivity.py b/micro_manager/adaptivity/global_adaptivity.py index 81b29a8e..d8b7afd1 100644 --- a/micro_manager/adaptivity/global_adaptivity.py +++ b/micro_manager/adaptivity/global_adaptivity.py @@ -26,6 +26,7 @@ def __init__( rank: int, comm, micro_problem_cls, + model_manager, ) -> None: """ Class constructor. @@ -48,9 +49,16 @@ def __init__( Communicator for MPI. micro_problem_cls : callable Class of micro problem. + model_manager : object of class ModelManager + Handles instantiation of the micro simulation. """ super().__init__( - configurator, global_number_of_sims, micro_problem_cls, base_logger, rank + configurator, + global_number_of_sims, + micro_problem_cls, + model_manager, + base_logger, + rank, ) self._global_number_of_sims = global_number_of_sims self._global_ids = global_ids @@ -459,7 +467,9 @@ def _update_inactive_sims(self, micro_sims: list) -> None: # Only handle activation of simulations on this rank for gid in to_be_activated_gids: to_be_activated_lid = self._global_ids.index(gid) - micro_sims[to_be_activated_lid] = self._micro_problem_cls(gid) + micro_sims[to_be_activated_lid] = self._model_manager.get_instance( + gid, self._micro_problem_cls + ) assoc_active_gid = self._sim_is_associated_to[gid] if self._is_sim_on_this_rank[ @@ -496,7 +506,9 @@ def _update_inactive_sims(self, micro_sims: list) -> None: local_ids = to_be_activated_map[gid] for lid in local_ids: # Create the micro simulation object and set its state - micro_sims[lid] = self._micro_problem_cls(self._global_ids[lid]) + micro_sims[lid] = self._model_manager.get_instance( + self._global_ids[lid], self._micro_problem_cls + ) micro_sims[lid].set_state(state) # Delete the micro simulation object if it is inactive diff --git a/micro_manager/adaptivity/global_adaptivity_lb.py b/micro_manager/adaptivity/global_adaptivity_lb.py index e6ded12f..4f1a3d0a 100644 --- a/micro_manager/adaptivity/global_adaptivity_lb.py +++ b/micro_manager/adaptivity/global_adaptivity_lb.py @@ -23,6 +23,7 @@ def __init__( rank: int, comm, micro_problem_cls: callable, + model_manager, ) -> None: """ Class constructor. @@ -45,6 +46,8 @@ def __init__( Communicator for MPI. micro_problem_cls : callable Class of micro problem. + model_manager : object of class ModelManager + Handles instantiation of the micro simulation. """ super().__init__( configurator, @@ -55,6 +58,7 @@ def __init__( rank, comm, micro_problem_cls, + model_manager, ) self._base_logger = base_logger @@ -367,7 +371,9 @@ def _move_active_sims( # Create simulations and set them to the received states for req in recv_reqs: output, gid = req.wait() - micro_sims.append(self._micro_problem_cls(gid)) + micro_sims.append( + self._model_manager.get_instance(gid, self._micro_problem_cls) + ) micro_sims[-1].set_state(output) self._global_ids.append(gid) self._is_sim_on_this_rank[gid] = True diff --git a/micro_manager/adaptivity/local_adaptivity.py b/micro_manager/adaptivity/local_adaptivity.py index 6ea9a717..6ba525d1 100644 --- a/micro_manager/adaptivity/local_adaptivity.py +++ b/micro_manager/adaptivity/local_adaptivity.py @@ -20,6 +20,7 @@ def __init__( rank, comm, micro_problem_cls, + model_manager, ) -> None: """ Class constructor. @@ -38,8 +39,12 @@ def __init__( Communicator for MPI. micro_problem_cls : callable Class of micro problem. + model_manager : object of class ModelManager + Handles instantiation of micro simulation. """ - super().__init__(configurator, num_sims, micro_problem_cls, base_logger, rank) + super().__init__( + configurator, num_sims, micro_problem_cls, model_manager, base_logger, rank + ) self._comm = comm # similarity_dists: 2D array having similarity distances between each micro simulation pair @@ -292,7 +297,7 @@ def _update_inactive_sims(self, micro_sims: list) -> None: # Update the set of inactive micro sims for i in to_be_activated_ids: associated_active_id = self._sim_is_associated_to[i] - micro_sims[i] = self._micro_problem_cls(i) + micro_sims[i] = self._model_manager.get_instance(i, self._micro_problem_cls) micro_sims[i].set_state(micro_sims[associated_active_id].get_state()) self._sim_is_associated_to[ i diff --git a/micro_manager/adaptivity/model_adaptivity.py b/micro_manager/adaptivity/model_adaptivity.py index 7d4e049f..366151ce 100644 --- a/micro_manager/adaptivity/model_adaptivity.py +++ b/micro_manager/adaptivity/model_adaptivity.py @@ -4,16 +4,26 @@ from typing import Union, Optional from ..config import Config -from ..micro_simulation import create_simulation_class +from ..micro_simulation import create_simulation_class, load_backend_class from micro_manager.tools.logging_wrapper import Logger from micro_manager.tools.misc import clamp_in_range +from micro_manager.model_manager import ModelManager +from micro_manager.tasking.connection import Connection import numpy as np import importlib class ModelAdaptivity: - def __init__(self, configurator: Config, rank: int, log_file: str) -> None: + def __init__( + self, + model_manager: ModelManager, + configurator: Config, + rank: int, + log_file: str, + conn: Connection, + num_ranks: int, + ) -> None: """ Class constructor. @@ -28,20 +38,27 @@ def __init__(self, configurator: Config, rank: int, log_file: str) -> None: """ self._logger = Logger(__name__, log_file, rank) + self._model_manager = model_manager self._model_files = configurator.get_model_adaptivity_file_names() self._switching_func_name = ( configurator.get_model_adaptivity_switching_function() ) + stateless_flags = configurator.get_model_adaptivity_micro_stateless() self._model_classes = [] - CLASS_NAME = "MicroSimulation" + pos = 0 for model_file in self._model_files: try: - model = getattr( - importlib.import_module(model_file, CLASS_NAME), - CLASS_NAME, + model = load_backend_class(model_file) + self._model_classes.append( + create_simulation_class( + self._logger, model, model_file, num_ranks, conn + ) + ) + self._model_manager.register( + self._model_classes[pos], stateless_flags[pos] ) - self._model_classes.append(create_simulation_class(model)) + pos += 1 except Exception as e: self._logger.log_info_rank_zero( f"Failed to load model class with error: {e}" @@ -114,10 +131,10 @@ def switch_models( self, locations: np.ndarray, t: float, - inputs: list[dict], + inputs: list, prev_output: dict, sims: list, - active_sim_ids: Optional[list[int]] = None, + active_sim_ids: Optional = None, ) -> None: """ Switches models within sims list. If active_sim_ids is None, all sims are considered as active. @@ -150,19 +167,88 @@ def switch_models( if cur_res[idx] == tgt_res[idx]: continue - sim_state = sims[idx].get_state() - sim_id = sims[idx].get_global_id() - sims[idx] = self.get_resolution_sim_class(tgt_res[idx])(sim_id) - sims[idx].set_state(sim_state) + sim = sims[idx] + gid = sim.get_global_id() + tgt_cls = self.get_resolution_sim_class(tgt_res[idx]) + + key = f"{sim.name}-state" + key_new = f"{tgt_cls.name}-state" + + new_state_exists = key_new in sim.attachments + sim.attachments[key] = sim.get_state() + + sim_new = self._model_manager.get_instance( + gid, tgt_cls, late_init=new_state_exists + ) + sim_new.attachments = sim.attachments + sim_new.attachments[key_new] = sim_new.get_state() + + if new_state_exists: + sim_new_state = sim.attachments[key_new] + sim_new.set_state(sim_new_state) + + sims[idx] = sim_new + + def update_states( + self, + sims: list, + active_sim_ids: Optional = None, + ): + """ + Stores the current state of the current model into local buffers. + + Parameters + ---------- + sims : list + List of all simulation objects. + active_sim_ids : [list, None] + List of all active simulation ids. + """ + size = len(sims) + active_sims = self._create_active_mask(active_sim_ids, size) + + for idx in range(size): + if not active_sims[idx]: + continue + + sim = sims[idx] + key = f"{sim.name}-state" + sim.attachments[key] = sim.get_state() + + def write_back_states( + self, + sims: list, + active_sim_ids: Optional = None, + ): + """ + Loads the current state of the current model into local buffers. + + Parameters + ---------- + sims : list + List of all simulation objects. + active_sim_ids : [list, None] + List of all active simulation ids. + """ + size = len(sims) + active_sims = self._create_active_mask(active_sim_ids, size) + + for idx in range(size): + if not active_sims[idx]: + continue + + sim = sims[idx] + key = f"{sim.name}-state" + sim.set_state(sim.attachments[key]) def check_convergence( self, locations: np.ndarray, t: float, - inputs: list[dict], - prev_output: Optional[dict], + inputs: list, + prev_output: Optional, sims: list, - active_sim_ids: Optional[list[int]] = None, + active_sim_ids: Optional = None, ) -> None: """ Similarly to switch_models, checks whether models would be switched in next step. @@ -207,9 +293,7 @@ def get_num_resolutions(self) -> int: """ return len(self._model_classes) - def get_resolution_sim_class( - self, resolution: Union[int, np.ndarray] - ) -> Union[object, np.ndarray]: + def get_resolution_sim_class(self, resolution: Union) -> Union: """ Looks up the class associated with the provided resolution. @@ -227,9 +311,7 @@ def get_resolution_sim_class( clamp_in_range(resolution, 0, len(self._model_classes) - 1) ] - def get_sim_class_resolution( - self, sim: Union[object, np.ndarray] - ) -> Union[int, np.ndarray]: + def get_sim_class_resolution(self, sim: Union) -> Union: """ Looks up the resolution associated with the provided simulation object. @@ -244,11 +326,11 @@ def get_sim_class_resolution( target resolution """ return next( - (idx for idx, cls in enumerate(self._model_classes) if cls == type(sim)) + (idx for idx, cls in enumerate(self._model_classes) if cls.name == sim.name) ) def _gather_current_resolutions( - self, sims: list[object], active_sims: np.ndarray + self, sims: list, active_sims: np.ndarray ) -> np.ndarray: """ Gathers current resolutions. Inactive sims have resolution -1. @@ -277,7 +359,7 @@ def _gather_target_resolutions( cur_res: np.ndarray, locations: np.ndarray, t: float, - inputs: list[dict], + inputs: list, prev_output: dict, active_sims: np.ndarray, ) -> np.ndarray: @@ -320,7 +402,7 @@ def _gather_target_resolutions( ) return res_tgt - def _create_active_mask(self, active_sim_ids: list[int], size: int) -> np.ndarray: + def _create_active_mask(self, active_sim_ids: list, size: int) -> np.ndarray: """ Converts list of active simulation ids to np boolean mask. diff --git a/micro_manager/config.py b/micro_manager/config.py index d7d453e5..b3826fe1 100644 --- a/micro_manager/config.py +++ b/micro_manager/config.py @@ -25,6 +25,7 @@ def __init__(self, config_file_name): self._config_file_name = config_file_name self._logger = None self._micro_file_name = None + self._micro_stateless = False self._precice_config_file_name = None self._macro_mesh_name = None @@ -72,8 +73,14 @@ def __init__(self, config_file_name): # Model Adaptivity information self._m_adap = False self._m_adap_micro_file_names = None + self._m_adap_micro_stateless = None self._m_adap_switching_function = None + # Tasking + self._task_is_slurm = False + self._task_backend = "socket" + self._task_num_workers = 1 + def set_logger(self, logger): """ Set the logger for the Config class. @@ -114,6 +121,17 @@ def _read_json(self, config_file_name): .replace(".py", "") ) + try: + self._micro_stateless = self._data["micro_stateless"] + self._logger.log_info_rank_zero( + "Only creating one full instance of Micro Model." + ) + except: + self._micro_stateless = False + self._logger.log_info_rank_zero( + "Creating full instance of Micro Model per mesh vertex." + ) + self._logger.log_info_rank_zero( "Micro simulation file name: " + self._data["micro_file_name"] ) @@ -182,6 +200,23 @@ def _read_json(self, config_file_name): self._micro_dt = self._data["simulation_params"]["micro_dt"] + try: + if self._data["tasking"]: + backend = self._data["tasking"]["backend"] + if backend not in ["mpi", "socket"]: + raise Exception("Backend must be either 'mpi' or 'socket'.") + self._task_backend = backend + if "is_slurm" in self._data["tasking"]: + self._task_is_slurm = self._data["tasking"]["is_slurm"] + if "num_workers" in self._data["tasking"]: + self._task_num_workers = self._data["tasking"]["num_workers"] + if self._task_is_slurm and backend == "mpi": + raise Exception("MPI backend not supported on SLURM systems.") + except BaseException: + self._logger.log_info_rank_zero( + "No or incorrect tasking information provided. Micro manager will compute locally." + ) + def read_json_micro_manager(self): """ Reads Micro Manager relevant information from JSON configuration file @@ -481,6 +516,28 @@ def read_json_micro_manager(self): "model_adaptivity_settings" ]["switching_function"] + if ( + "micro_stateless" + in self._data["simulation_params"]["model_adaptivity_settings"] + ): + self._m_adap_micro_stateless = self._data["simulation_params"][ + "model_adaptivity_settings" + ]["micro_stateless"] + else: + self._m_adap_micro_stateless = [False] * len( + self._m_adap_micro_file_names + ) + + for i in range(len(self._m_adap_micro_file_names)): + if self._m_adap_micro_stateless[i]: + self._logger.log_info_rank_zero( + f"Only creating one full instance of Micro Model {i}." + ) + else: + self._logger.log_info_rank_zero( + f"Creating full instance of Micro Model {i} per mesh vertex." + ) + if "interpolate_crash" in self._data["simulation_params"]: if self._data["simulation_params"]["interpolate_crash"]: self._interpolate_crash = True @@ -657,6 +714,17 @@ def get_micro_file_name(self): """ return self._micro_file_name + def turn_on_micro_stateless(self): + """ + Boolean stating whether micro model is stateless or not. + + Returns + ------- + stateless : bool + True if micro model is stateless, False otherwise. + """ + return self._micro_stateless + def get_micro_output_n(self): """ Get the micro output frequency @@ -974,6 +1042,17 @@ def get_model_adaptivity_file_names(self): """ return self._m_adap_micro_file_names + def get_model_adaptivity_micro_stateless(self): + """ + List of boolean stating whether the respective micro model is stateless or not. + + Returns + ------- + stateless : list + True if micro model is stateless, False otherwise. + """ + return self._m_adap_micro_stateless + def get_model_adaptivity_switching_function(self): """ Get path to switching function file @@ -984,3 +1063,36 @@ def get_model_adaptivity_switching_function(self): String containing the path to the switching function file """ return self._m_adap_switching_function + + def get_tasking_num_workers(self): + """ + Get number of workers + + Returns + ------- + num_workers : int + Number of workers + """ + return self._task_num_workers + + def get_tasking_backend(self): + """ + Get backend type + + Returns + ------- + backend : str + either socket or mpi + """ + return self._task_backend + + def get_tasking_use_slurm(self): + """ + Get flag whether slurm is used + + Returns + ------- + use_slurm : bool + use slurm or not + """ + return self._task_is_slurm diff --git a/micro_manager/domain_decomposition.py b/micro_manager/domain_decomposition.py index bbb97b16..f0dd97ff 100644 --- a/micro_manager/domain_decomposition.py +++ b/micro_manager/domain_decomposition.py @@ -91,7 +91,7 @@ def get_local_mesh_bounds(self, macro_bounds: list, ranks_per_axis: list) -> lis def get_local_sims_and_macro_coords( self, macro_bounds: list, ranks_per_axis: list, macro_coords: np.ndarray - ) -> tuple[int, list]: + ) -> tuple: """ Decompose the micro simulations among all ranks based on their positions in the macro domain. diff --git a/micro_manager/micro_manager.py b/micro_manager/micro_manager.py index 12bab6ed..7af548de 100644 --- a/micro_manager/micro_manager.py +++ b/micro_manager/micro_manager.py @@ -25,16 +25,19 @@ import precice +from .model_manager import ModelManager + from .micro_manager_base import MicroManager from .adaptivity.global_adaptivity import GlobalAdaptivityCalculator from .adaptivity.local_adaptivity import LocalAdaptivityCalculator from .adaptivity.global_adaptivity_lb import GlobalAdaptivityLBCalculator from .adaptivity.model_adaptivity import ModelAdaptivity +from .adaptivity.adaptivity_selection import create_adaptivity_calculator from .domain_decomposition import DomainDecomposer - -from .micro_simulation import create_simulation_class +from .tasking.connection import spawn_local_workers +from .micro_simulation import create_simulation_class, load_backend_class from .tools.logging_wrapper import Logger @@ -156,6 +159,9 @@ def __init__(self, config_file: str, log_file: str = "") -> None: self._t = 0 # global time self._n = 0 # sim-step + self._model_manager = ModelManager() + self._conn = None + # ************** # Public methods # ************** @@ -204,20 +210,26 @@ def solve(self) -> None: self._micro_sims, self._data_for_adaptivity, ) - - # Write a checkpoint if a simulation is just activated. - # This checkpoint will be asynchronous to the checkpoints written at the start of the time window. - for i in range(self._local_number_of_sims): - if sim_states_cp[i] is None and self._micro_sims[i]: - sim_states_cp[i] = self._micro_sims[i].get_state() - active_sim_gids = ( self._adaptivity_controller.get_active_sim_global_ids() ) - for gid in active_sim_gids: self._micro_sims_active_steps[gid] += 1 + # Write a checkpoint if a simulation is just activated. + # This checkpoint will be asynchronous to the checkpoints written at the start of the time window. + if self._is_model_adaptivity_on: + self._model_adaptivity_controller.update_states( + self._micro_sims, active_sim_gids + ) + for i in range(self._local_number_of_sims): + if sim_states_cp[i] is None and self._micro_sims[i]: + sim_states_cp[i] = self._micro_sims[i].attachments + else: + for i in range(self._local_number_of_sims): + if sim_states_cp[i] is None and self._micro_sims[i]: + sim_states_cp[i] = self._micro_sims[i].get_state() + self._participant.stop_last_profiling_section() if self._is_adaptivity_with_load_balancing: @@ -245,10 +257,29 @@ def solve(self) -> None: # Write a checkpoint if self._participant.requires_writing_checkpoint(): - for i in range(self._local_number_of_sims): - sim_states_cp[i] = ( - self._micro_sims[i].get_state() if self._micro_sims[i] else None + active_sim_gids = None + if self._is_adaptivity_on: + active_sim_gids = ( + self._adaptivity_controller.get_active_sim_local_ids() + ) + + if self._is_model_adaptivity_on: + self._model_adaptivity_controller.update_states( + self._micro_sims, active_sim_gids ) + for i in range(self._local_number_of_sims): + sim_states_cp[i] = ( + self._micro_sims[i].attachments + if self._micro_sims[i] + else None + ) + else: + for i in range(self._local_number_of_sims): + sim_states_cp[i] = ( + self._micro_sims[i].get_state() + if self._micro_sims[i] + else None + ) micro_sims_input = self._read_data_from_precice(dt) @@ -293,14 +324,29 @@ def solve(self) -> None: # Revert micro simulations to their last checkpoints if required if self._participant.requires_reading_checkpoint(): - for i in range(self._local_number_of_sims): - if self._micro_sims[i]: - self._micro_sims[i].set_state(sim_states_cp[i]) + if self._is_model_adaptivity_on: + active_sim_gids = None + if self._is_adaptivity_on: + active_sim_gids = ( + self._adaptivity_controller.get_active_sim_local_ids() + ) + + for i in range(self._local_number_of_sims): + if self._micro_sims[i]: + self._micro_sims[i].attachments.update(sim_states_cp[i]) + self._model_adaptivity_controller.write_back_states( + self._micro_sims, active_sim_gids + ) + + else: + for i in range(self._local_number_of_sims): + if self._micro_sims[i]: + self._micro_sims[i].set_state(sim_states_cp[i]) + first_iteration = False - if ( - self._participant.is_time_window_complete() - ): # Time window has converged, now micro output can be generated + # Time window has converged, now micro output can be generated + if self._participant.is_time_window_complete(): self._t += dt # Update time to the end of the time window self._n += 1 # Update time step to the end of the time window @@ -421,10 +467,7 @@ def initialize(self) -> None: " and does not match the dimensions of the macro mesh." ) - domain_decomposer = DomainDecomposer( - self._rank, - self._size, - ) + domain_decomposer = DomainDecomposer(self._rank, self._size) if self._is_parallel and not self._is_adaptivity_with_load_balancing: coupling_mesh_bounds = domain_decomposer.get_local_mesh_bounds( @@ -539,72 +582,67 @@ def initialize(self) -> None: if self._interpolate_crashed_sims: self._interpolant = Interpolation(self._logger) + # Setup remote workers + base_dir = os.path.dirname(os.path.abspath(__file__)) + worker_exec = os.path.join(base_dir, "tasking", "worker_main.py") + num_ranks = self._config.get_tasking_num_workers() + self._conn = spawn_local_workers( + worker_exec, + num_ranks, + self._config.get_tasking_backend(), + self._config.get_tasking_use_slurm(), + ) + + # load micro sim micro_problem_cls = None if self._is_model_adaptivity_on: self._model_adaptivity_controller: ModelAdaptivity = ModelAdaptivity( - self._config, self._rank, self._log_file + self._model_manager, + self._config, + self._rank, + self._log_file, + self._conn, + num_ranks, ) micro_problem_cls = ( self._model_adaptivity_controller.get_resolution_sim_class(0) ) else: - micro_problem_base = getattr( - importlib.import_module( - self._config.get_micro_file_name(), "MicroSimulation" - ), - "MicroSimulation", - ) + micro_problem_base = load_backend_class(self._config.get_micro_file_name()) micro_problem_cls = create_simulation_class( - micro_problem_base, "MicroSimulationDefault" + self._logger, + micro_problem_base, + self._config.get_micro_file_name(), + self._config.get_tasking_num_workers(), + self._conn, + "MicroSimulationDefault", + ) + self._model_manager.register( + micro_problem_cls, self._config.turn_on_micro_stateless() ) # Create micro simulation objects self._micro_sims = [0] * self._local_number_of_sims if not self._lazy_init: for i in range(self._local_number_of_sims): - self._micro_sims[i] = micro_problem_cls( - self._global_ids_of_local_sims[i] + self._micro_sims[i] = self._model_manager.get_instance( + self._global_ids_of_local_sims[i], micro_problem_cls ) if self._is_adaptivity_on: - if self._config.get_adaptivity_type() == "local": - self._adaptivity_controller: LocalAdaptivityCalculator = ( - LocalAdaptivityCalculator( - self._config, - self._local_number_of_sims, - self._logger, - self._rank, - self._comm, - micro_problem_cls, - ) - ) - elif self._config.get_adaptivity_type() == "global": - if self._is_adaptivity_with_load_balancing: - self._adaptivity_controller: GlobalAdaptivityLBCalculator = ( - GlobalAdaptivityLBCalculator( - self._config, - self._global_number_of_sims, - self._global_ids_of_local_sims, - self._participant, - self._logger, - self._rank, - self._comm, - micro_problem_cls, - ) - ) - else: - self._adaptivity_controller: GlobalAdaptivityCalculator = ( - GlobalAdaptivityCalculator( - self._config, - self._global_number_of_sims, - self._global_ids_of_local_sims, - self._participant, - self._logger, - self._rank, - self._comm, - micro_problem_cls, - ) - ) + self._adaptivity_controller = create_adaptivity_calculator( + self._config, + self._local_number_of_sims, + self._global_number_of_sims, + self._global_ids_of_local_sims, + self._participant, + self._logger, + self._rank, + self._comm, + micro_problem_cls, + self._model_manager, + self._is_adaptivity_with_load_balancing, + ) self._micro_sims_active_steps = np.zeros( self._global_number_of_sims @@ -624,9 +662,8 @@ def initialize(self) -> None: is_initial_data_available = False else: is_initial_data_available = True - if ( - self._lazy_init - ): # For lazy initialization, compute adaptivity with the initial macro data + # For lazy initialization, compute adaptivity with the initial macro data + if self._lazy_init: for i in range(self._local_number_of_sims): for name in self._adaptivity_macro_data_names: self._data_for_adaptivity[name][i] = initial_data[i][name] @@ -644,8 +681,8 @@ def initialize(self) -> None: return for i in active_sim_lids: - self._micro_sims[i] = micro_problem_cls( - self._global_ids_of_local_sims[i] + self._micro_sims[i] = self._model_manager.get_instance( + self._global_ids_of_local_sims[i], micro_problem_cls ) first_id = active_sim_lids[0] # First active simulation ID @@ -654,45 +691,13 @@ def initialize(self) -> None: ) # Boolean which states if the initialize() method of the micro simulation requires initial data - sim_requires_init_data = False - - # Check if provided micro simulation has an initialize() method - if hasattr(micro_problem_cls, "initialize") and callable( - getattr(micro_problem_cls, "initialize") - ): - self._micro_sims_init = True # Starting value before setting - - try: # Try to get the signature of the initialize() method, if it is written in Python - argspec = inspect.getfullargspec(micro_problem_cls.initialize) - if ( - len(argspec.args) == 1 - ): # The first argument in the signature is self - sim_requires_init_data = False - elif len(argspec.args) == 2: - sim_requires_init_data = True - else: - raise Exception( - "The initialize() method of the Micro simulation has an incorrect number of arguments." - ) - except TypeError: - self._logger.log_info_rank_zero( - "The signature of initialize() method of the micro simulation cannot be determined. Trying to determine the signature by calling the method." - ) - # Try to get the signature of the initialize() method, if it is not written in Python - try: # Try to call the initialize() method without initial data - self._micro_sims[first_id].initialize() - sim_requires_init_data = False - except TypeError: - self._logger.log_info_rank_zero( - "The initialize() method of the micro simulation has arguments. Attempting to call it again with initial data." - ) - try: # Try to call the initialize() method with initial data - self._micro_sims[first_id].initialize(initial_data[first_id]) - sim_requires_init_data = True - except TypeError: - raise Exception( - "The initialize() method of the Micro simulation has an incorrect number of arguments." - ) + ( + self._micro_sims_init, + sim_requires_init_data, + ) = micro_problem_cls.check_initialize( + self._micro_sims[first_id], + initial_data[first_id] if is_initial_data_available else None, + ) if sim_requires_init_data and not is_initial_data_available: raise Exception( @@ -701,7 +706,6 @@ def initialize(self) -> None: # Get initial data from micro simulations if initialize() method exists if self._micro_sims_init: - # Call initialize() method of the micro simulation to check if it returns any initial data if sim_requires_init_data: initial_micro_output = self._micro_sims[first_id].initialize( @@ -710,9 +714,8 @@ def initialize(self) -> None: else: initial_micro_output = self._micro_sims[first_id].initialize() - if ( - initial_micro_output is None - ): # Check if the detected initialize() method returns any data + # Check if the detected initialize() method returns any data + if initial_micro_output is None: self._logger.log_warning_rank_zero( "The initialize() call of the Micro simulation has not returned any initial data." " This means that the initialize() call has no effect on the adaptivity. The initialize method will nevertheless still be called." @@ -765,9 +768,8 @@ def initialize(self) -> None: ] = initial_micro_output[name] initial_micro_data[name][i] = initial_micro_output[name] - if ( - self._lazy_init - ): # If lazy initialization is on, initial states of inactive simulations need to be determined + # If lazy initialization is on, initial states of inactive simulations need to be determined + if self._lazy_init: self._adaptivity_controller.get_full_field_micro_output( initial_micro_data ) @@ -789,11 +791,7 @@ def initialize(self) -> None: for i in range(1, self._local_number_of_sims): self._micro_sims[i].initialize() - self._micro_sims_have_output = False - if hasattr(micro_problem_cls, "output") and callable( - getattr(micro_problem_cls, "output") - ): - self._micro_sims_have_output = True + self._micro_sims_have_output = micro_problem_cls.check_output() self._participant.stop_last_profiling_section() diff --git a/micro_manager/micro_simulation.py b/micro_manager/micro_simulation.py index 829c6eb5..e0df10e6 100644 --- a/micro_manager/micro_simulation.py +++ b/micro_manager/micro_simulation.py @@ -4,8 +4,305 @@ class MicroSimulation. A global ID member variable is defined for the class Simu created object is uniquely identifiable in a global setting. """ +from abc import ABC, abstractmethod +import inspect +import importlib as ipl -def create_simulation_class(micro_simulation_class, sim_class_name=None): +from .tasking.task import * + + +class MicroSimulationInterface(ABC): + @abstractmethod + def solve(self, micro_sim_input, dt): + pass + + @abstractmethod + def get_state(self): + pass + + @abstractmethod + def set_state(self, state): + pass + + @abstractmethod + def get_global_id(self): + pass + + @abstractmethod + def set_global_id(self, global_id): + pass + + @abstractmethod + def initialize(self, *args, **kwargs): + pass + + @abstractmethod + def output(self): + pass + + +class MicroSimulationLocal(MicroSimulationInterface): + def __init__(self, gid, late_init, sim_cls): + self._gid = gid + self._instance = sim_cls(-1 if late_init else gid) + + def solve(self, micro_sim_input, dt): + return self._instance.solve(micro_sim_input, dt) + + def get_state(self): + return self._instance.get_state() + + def set_state(self, state): + return self._instance.set_state(state) + + def get_global_id(self): + return self._gid + + def set_global_id(self, global_id): + self._gid = global_id + + def __getattr__(self, name): + return getattr(self._instance, name) + + def initialize(self, *args, **kwargs): + return self._instance.initialize(*args, **kwargs) + + def output(self): + return self._instance.output() + + +class MicroSimulationRemote(MicroSimulationInterface): + def __init__(self, gid, late_init, num_ranks, conn, cls_path): + self._cls_path = cls_path + self._gid = gid + self._num_ranks = num_ranks + self._conn = conn + + construct_cls = ConstructLateTask if late_init else ConstructTask + for worker_id in range(self._num_ranks): + task = construct_cls.send_args(self._gid, self._cls_path) + self._conn.send(worker_id, task) + + for worker_id in range(self._num_ranks): + self._conn.recv(worker_id) + + def solve(self, micro_sim_input, dt): + for worker_id in range(self._num_ranks): + task = SolveTask.send_args(self._gid, micro_sim_input, dt) + self._conn.send(worker_id, task) + + result = None + for worker_id in range(self._num_ranks): + output = self._conn.recv(worker_id) + if worker_id == 0: + result = output + + return result + + def get_state(self): + for worker_id in range(self._num_ranks): + task = GetStateTask.send_args(self._gid) + self._conn.send(worker_id, task) + + result = {} + for worker_id in range(self._num_ranks): + result[worker_id] = self._conn.recv(worker_id) + + return result + + def set_state(self, state): + for worker_id in range(self._num_ranks): + task = SetStateTask.send_args(self._gid, state[worker_id]) + self._conn.send(worker_id, task) + + for worker_id in range(self._num_ranks): + self._conn.recv(worker_id) + + def get_global_id(self): + return self._gid + + def set_global_id(self, global_id): + self._gid = global_id + + def initialize(self, *args, **kwargs): + for worker_id in range(self._num_ranks): + task = InitializeTask.send_args(self._gid, *args, **kwargs) + self._conn.send(worker_id, task) + + result = None + for worker_id in range(self._num_ranks): + output = self._conn.recv(worker_id) + if worker_id == 0: + result = output + + return result + + def output(self): + for worker_id in range(self._num_ranks): + task = OutputTask.send_args(self._gid) + self._conn.send(worker_id, task) + + result = None + for worker_id in range(self._num_ranks): + output = self._conn.recv(worker_id) + if worker_id == 0: + result = output + + return result + + +class MicroSimulationWrapper(MicroSimulationInterface): + """ + If only a single rank is in use: will contain the micro sim instance. + Otherwise, it will delegate method calls to workers and not contain state. + """ + + def __init__(self, name, sim_cls, cls_path, global_id, late_init, num_ranks, conn): + self._impl = None + + if num_ranks > 1 and conn is not None: + self._impl = MicroSimulationRemote( + global_id, late_init, num_ranks, conn, cls_path + ) + else: + self._impl = MicroSimulationLocal(global_id, late_init, sim_cls) + + self._external_data = dict() + self._name = name + + def solve(self, micro_sim_input, dt): + return self._impl.solve(micro_sim_input, dt) + + def get_state(self): + return self._impl.get_state() + + def set_state(self, state): + return self._impl.set_state(state) + + def get_global_id(self): + return self._impl.get_global_id() + + def set_global_id(self, global_id): + return self._impl.set_global_id(global_id) + + def initialize(self, *args, **kwargs): + return self._impl.initialize(*args, **kwargs) + + def output(self): + return self._impl.output() + + def __getattr__(self, name): + return getattr(self._impl, name) + + @property + def attachments(self): + return self._external_data + + @attachments.setter + def attachments(self, value): + self._external_data = value + + @property + def name(self): + return self._name + + +class MicroSimulationClassAdapter: + def __init__(self, sim_cls, cls_path, name, num_ranks, conn, log): + self._sim_cls = sim_cls + self._cls_path = cls_path + self._name = name + self._num_ranks = num_ranks + self._conn = conn + self._log = log + + @property + def name(self): + return self._name + + def __call__(self, gid, *, late_init=False): + return MicroSimulationWrapper( + self._name, + self._sim_cls, + self._cls_path, + gid, + late_init, + self._num_ranks, + self._conn, + ) + + @property + def backend_cls(self): + return self._sim_cls + + def check_initialize(self, test_instance, test_input): + has_init = hasattr(self._sim_cls, "initialize") + if not has_init: + return False, False + callable_init = callable(getattr(self._sim_cls, "initialize")) + if not callable_init: + return False, False + + has_args = False + + # Try to get the signature of the initialize() method, if it is written in Python + try: + argspec = inspect.getfullargspec(self._sim_cls.initialize) + # The first argument in the signature is self + if len(argspec.args) == 1: + has_args = False + elif len(argspec.args) == 2: + has_args = True + else: + raise Exception( + "The initialize() method of the Micro simulation has an incorrect number of arguments." + ) + except TypeError: + self._log.log_info_rank_zero( + "The signature of initialize() method of the micro simulation cannot be determined. " + + "Trying to determine the signature by calling the method." + ) + # Try to call the initialize() method without initial data + try: + test_instance.initialize() + has_args = False + except TypeError: + self._log.log_info_rank_zero( + "The initialize() method of the micro simulation has arguments. " + + "Attempting to call it again with initial data." + ) + try: + test_instance.initialize(test_input) + has_args = True + except TypeError: + raise Exception( + "The initialize() method of the Micro simulation has an incorrect number of arguments." + ) + + return has_init and callable_init, has_args + + def check_output(self): + has_init = hasattr(self._sim_cls, "output") + if not has_init: + return False + callable_init = callable(getattr(self._sim_cls, "output")) + + return has_init and callable_init + + +def load_backend_class(path_to_micro_file): + CLS_NAME = "MicroSimulation" + return getattr(ipl.import_module(path_to_micro_file, CLS_NAME), CLS_NAME) + + +def create_simulation_class( + log, + micro_simulation_class, + path_to_micro_file, + num_ranks, + conn=None, + sim_class_name=None, +): """ Creates a class Simulation which inherits from the class of the micro simulation. @@ -22,6 +319,15 @@ def create_simulation_class(micro_simulation_class, sim_class_name=None): Simulation : class Definition of class Simulation defined in this function. """ + if not hasattr(micro_simulation_class, "get_global_id"): + raise ValueError("Invalid micro simulation class") + if not hasattr(micro_simulation_class, "get_state"): + raise ValueError("Invalid micro simulation class") + if not hasattr(micro_simulation_class, "set_state"): + raise ValueError("Invalid micro simulation class") + if not hasattr(micro_simulation_class, "solve"): + raise ValueError("Invalid micro simulation class") + if sim_class_name is None: if not hasattr(create_simulation_class, "sim_id"): create_simulation_class.sim_id = 0 @@ -29,21 +335,7 @@ def create_simulation_class(micro_simulation_class, sim_class_name=None): create_simulation_class.sim_id += 1 sim_class_name = f"MicroSimulation{create_simulation_class.sim_id}" - sim_class_body = """ -def __init__(self, global_id): - micro_simulation_class.__init__(self, global_id) - self._global_id = global_id - -def get_global_id(self) -> int: - return self._global_id - """ - sim_class_dict = {} - local_globals = { - "__builtins__": __builtins__, - "micro_simulation_class": micro_simulation_class, - } - exec(sim_class_body, local_globals, sim_class_dict) - # print(sim_class_dict) - sim_class = type(sim_class_name, (micro_simulation_class,), sim_class_dict) - - return sim_class + result_cls = MicroSimulationClassAdapter( + micro_simulation_class, path_to_micro_file, sim_class_name, num_ranks, conn, log + ) + return result_cls diff --git a/micro_manager/model_manager.py b/micro_manager/model_manager.py new file mode 100644 index 00000000..e7c1ccea --- /dev/null +++ b/micro_manager/model_manager.py @@ -0,0 +1,88 @@ +class ModelWrapper: + """ + Stateless Model Wrapper + """ + + def __init__(self, global_id, backend, attach_init, attach_output): + self._global_id = global_id + self._backend = backend + + if attach_init: + self.initialize = backend.initialize + if attach_output: + self.output = backend.output + + def get_global_id(self) -> int: + return self._global_id + + def solve(self, macro_data, dt): + return self._backend.solve(macro_data, dt) + + def get_state(self): + return self._backend.get_state() + + def set_state(self, state): + self._backend.set_state(state) + + @property + def __class__(self): + return self._backend.__class__ + + @property + def attachments(self): + return self._backend.attachments + + @attachments.setter + def attachments(self, value): + self._backend.attachments = value + + @property + def name(self): + return self._backend.name + + +class ModelManager: + def __init__(self): + self._registered_classes = [] + self._stateless_map = dict() + self._backend_map = dict() + self._has_init_map = dict() + self._has_output_map = dict() + + def register(self, micro_sim_cls, stateless): + if micro_sim_cls in self._registered_classes: + return + + self._registered_classes.append(micro_sim_cls) + self._stateless_map[micro_sim_cls] = stateless + + if stateless: + self._backend_map[micro_sim_cls] = micro_sim_cls( + len(self._registered_classes) - 1 + ) + + self._has_init_map[micro_sim_cls] = False + if hasattr(micro_sim_cls, "initialize") and callable( + getattr(micro_sim_cls, "initialize") + ): + self._has_init_map[micro_sim_cls] = True + + self._has_output_map[micro_sim_cls] = False + if hasattr(micro_sim_cls, "output") and callable( + getattr(micro_sim_cls, "output") + ): + self._has_output_map[micro_sim_cls] = True + + def get_instance(self, gid, micro_sim_cls, *, late_init=False): + if micro_sim_cls not in self._registered_classes: + raise RuntimeError("Trying to create instance of unknown class!") + + if self._stateless_map[micro_sim_cls]: + return ModelWrapper( + gid, + self._backend_map[micro_sim_cls], + self._has_init_map[micro_sim_cls], + self._has_output_map[micro_sim_cls], + ) + else: + return micro_sim_cls(gid, late_init=late_init) diff --git a/micro_manager/snapshot/snapshot.py b/micro_manager/snapshot/snapshot.py index 4b6ae71f..618e5987 100644 --- a/micro_manager/snapshot/snapshot.py +++ b/micro_manager/snapshot/snapshot.py @@ -17,7 +17,7 @@ from micro_manager.micro_manager import MicroManager from .dataset import ReadWriteHDF -from micro_manager.micro_simulation import create_simulation_class +from micro_manager.micro_simulation import create_simulation_class, load_backend_class from micro_manager.tools.logging_wrapper import Logger @@ -84,7 +84,13 @@ def solve(self) -> None: - Merge output in parallel run. """ - micro_problem_cls = create_simulation_class(self._micro_problem) + micro_problem_cls = create_simulation_class( + self._logger, + self._micro_problem, + self._config.get_micro_file_name(), + 1, + None, + ) # Loop over all macro parameters for elems in range(self._local_number_of_sims): @@ -256,12 +262,7 @@ def initialize(self) -> None: for i in range(self._local_number_of_sims): self._global_ids_of_local_sims.append(sim_id) sim_id += 1 - self._micro_problem = getattr( - importlib.import_module( - self._config.get_micro_file_name(), "MicroSimulation" - ), - "MicroSimulation", - ) + self._micro_problem = load_backend_class(self._config.get_micro_file_name()) self._micro_sims_have_output = False if hasattr(self._micro_problem, "output") and callable( diff --git a/micro_manager/tasking/__init__.py b/micro_manager/tasking/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/micro_manager/tasking/connection.py b/micro_manager/tasking/connection.py new file mode 100644 index 00000000..bcfd4b7a --- /dev/null +++ b/micro_manager/tasking/connection.py @@ -0,0 +1,258 @@ +import pickle +import psutil +import socket +import struct +import subprocess +import os +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional +from mpi4py import MPI + + +class Connection(ABC): + @abstractmethod + def send(self, dst_id: int, obj: Any) -> None: + pass + + @abstractmethod + def recv(self, src_id: int) -> Any: + pass + + @abstractmethod + def close(self) -> None: + pass + + +class MPIConnection(Connection): + def __init__(self): + self.inter_comm = None + + @classmethod + def create_workers( + cls, worker_exec: str, mpi_args: Optional, n_workers: int + ) -> "MPIConnection": + comm = MPI.COMM_SELF + conn = cls() + conn.inter_comm = comm.Spawn( + f"python {worker_exec}", + args=mpi_args or [], + maxprocs=n_workers, + ) + return conn + + @classmethod + def connect_to_micromanager(cls, parent_comm) -> "MPIConnection": + conn = cls() + conn.inter_comm = parent_comm + return conn + + def send(self, dst_id: int, obj: Any) -> None: + data = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) + self.inter_comm.send(data, dest=dst_id, tag=0) + + def recv(self, src_id: int) -> Any: + data = self.inter_comm.recv(source=src_id, tag=1) + return pickle.loads(data) + + def close(self) -> None: + self.inter_comm.Disconnect() + + +class SocketConnection(Connection): + def __init__(self): + self.sockets: Dict[int, socket.socket] = {} + + @classmethod + def create_workers( + cls, worker_exec: str, launcher: list, host: str, n_workers: int + ) -> "SocketConnection": + # create listening socket with ephemeral port + server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server.bind((host, 0)) # kernel picks free port + server.listen() + port = server.getsockname()[1] + + executable = [ + "python", + worker_exec, + "--backend", + "socket", + "--host", + host, + "--port", + str(port), + ] + cmd = [] + cmd.extend(launcher) + cmd.extend(executable) + subprocess.Popen(cmd, env=os.environ.copy()) + + conn = cls() + for wid in range(n_workers): + sock, _ = server.accept() + conn.sockets[wid] = sock + + server.close() + return conn + + @classmethod + def connect_to_micromanager( + cls, worker_id: int, host: str, port: int + ) -> "SocketConnection": + conn = cls() + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect((host, port)) + conn.sockets[worker_id] = sock + return conn + + def send(self, dst_id: int, obj: Any) -> None: + sock = self.sockets[dst_id] + data = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) + header = struct.pack("!Q", len(data)) + sock.sendall(header + data) + + def recv(self, src_id: int) -> Any: + sock = self.sockets[src_id] + header = sock.recv(8) + if not header: + raise EOFError + (size,) = struct.unpack("!Q", header) + payload = b"" + while len(payload) < size: + chunk = sock.recv(size - len(payload)) + if not chunk: + raise EOFError + payload += chunk + return pickle.loads(payload) + + def close(self) -> None: + for sock in self.sockets.values(): + sock.close() + self.sockets.clear() + + +def get_local_ip(preferred_ifaces=None) -> str: + """ + Returns a non-loopback IPv4 address without accessing external networks. + + Parameters + ---------- + preferred_ifaces : list[str], optional + If provided, try interfaces in this order first (e.g., ["ib0", "eno1"]) + + Returns + ------- + str + The selected IPv4 address + """ + addrs = psutil.net_if_addrs() + + candidates = [] + + # Iterate over preferred interfaces first + if preferred_ifaces: + for name in preferred_ifaces: + if name not in addrs: + continue + for a in addrs[name]: + if a.family == socket.AF_INET and not a.address.startswith("127."): + return a.address + + # Fallback: iterate all interfaces + for name, iface_addrs in addrs.items(): + for a in iface_addrs: + if a.family == socket.AF_INET: + ip = a.address + if not ip.startswith("127.") and not ip.startswith("169.254."): + candidates.append(ip) + + if candidates: + return candidates[0] + + raise RuntimeError("No non-loopback IPv4 address found") + + +def spawn_local_workers( + worker_exec: str, + n_workers: int, + backend: str, + is_slurm: bool, +): + """ + Spawn worker processes. On Slurm systems: MPI spawn now supported, socket backend enforced. + Ephemeral port auto-selected. + + Parameters + ---------- + worker_exec : str + path to worker executable + n_workers : int + number of worker processes, must be > 1 otherwise returns None + backend : str + mpi or socket + is_slurm : bool + is our system slurm based? + + Returns + ------- + conn : Connection + Established connection on generator side + """ + from .task import RegisterAllTask + + if n_workers <= 1: + return None + conn = None + + # MPI BACKEND (non-Slurm only) + if backend == "mpi": + if is_slurm: + raise RuntimeError( + "MPI backend is not supported under Slurm. " + "Use socket backend instead." + ) + comm = MPI.COMM_WORLD + local_rank = comm.Get_rank() + conn = MPIConnection.create_workers( + worker_exec=worker_exec, + mpi_args=[ + "--backend", + "mpi", + "--parentrank", + str(local_rank), + ], + n_workers=n_workers, + ) + + # SOCKET BACKEND + if backend == "socket": + host = get_local_ip() + + # launch workers + launcher = None + if is_slurm: + launcher = [ + "srun", + # "--exclusive", + "--ntasks", + str(n_workers), + "--kill-on-bad-exit=1", + ] + else: + launcher = [ + "mpiexec", + "-n", + str(n_workers), + ] + + conn = SocketConnection.create_workers( + worker_exec=worker_exec, launcher=launcher, host=host, n_workers=n_workers + ) + + from ..micro_simulation import load_backend_class + + for worker_id in range(n_workers): + conn.send(worker_id, RegisterAllTask(load_backend_class)) + conn.recv(worker_id) + + return conn diff --git a/micro_manager/tasking/task.py b/micro_manager/tasking/task.py new file mode 100644 index 00000000..55b5cb46 --- /dev/null +++ b/micro_manager/tasking/task.py @@ -0,0 +1,119 @@ +class Task: + def __init__(self, fn, *args, **kwargs): + self.fn = fn + self.args = args + self.kwargs = kwargs + + def __call__(self, state_data: dict): + return self.fn(*self.args, state_data=state_data, **self.kwargs) + + @classmethod + def send_args(cls, *args, **kwargs): + return cls.__name__, args, kwargs + + +class ConstructTask(Task): + def __init__(self, gid, cls_path): + super().__init__(ConstructTask.initializer, gid=gid, cls_path=cls_path) + + @staticmethod + def initializer(gid, cls_path, state_data): + if cls_path not in state_data["sim_classes"]: + state_data["sim_classes"][cls_path] = state_data["load_function"](cls_path) + cls = state_data["sim_classes"][cls_path] + + if gid in state_data["sim_classes"]: + del state_data["sim_classes"][gid] + state_data["sim_instances"][gid] = cls(gid) + return None + + +class ConstructLateTask(Task): + def __init__(self, gid, cls_path): + super().__init__(ConstructLateTask.initializer, gid=gid, cls_path=cls_path) + + @staticmethod + def initializer(gid, cls_path, state_data): + if cls_path not in state_data["sim_classes"]: + state_data["sim_classes"][cls_path] = state_data["load_function"](cls_path) + cls = state_data["sim_classes"][cls_path] + + if gid in state_data["sim_classes"]: + del state_data["sim_classes"][gid] + state_data["sim_instances"][gid] = cls(-1) + return None + + +class SolveTask(Task): + def __init__(self, gid, sim_input, dt): + super().__init__(SolveTask.solve, gid=gid, sim_input=sim_input, dt=dt) + + @staticmethod + def solve(gid, sim_input, dt, state_data): + sim_output = state_data["sim_instances"][gid].solve(sim_input, dt) + return sim_output + + +class GetStateTask(Task): + def __init__(self, gid): + super().__init__(GetStateTask.get, gid=gid) + + @staticmethod + def get(gid, state_data): + return state_data["sim_instances"][gid].get_state() + + +class SetStateTask(Task): + def __init__(self, gid, state): + super().__init__(SetStateTask.set, gid=gid, state=state) + + @staticmethod + def set(gid, state, state_data): + state_data["sim_instances"][gid].set_state(state) + return None + + +class InitializeTask(Task): + def __init__(self, gid, *args, **kwargs): + super().__init__(InitializeTask.initialize, *args, gid=gid, **kwargs) + + @staticmethod + def initialize(gid, state_data, *args, **kwargs): + return state_data["sim_instances"][gid].initialize(*args, **kwargs) + + +class OutputTask(Task): + def __init__(self, gid): + super().__init__(OutputTask.output, gid=gid) + + @staticmethod + def output(gid, state_data): + return state_data["sim_instances"][gid].output() + + +class RegisterAllTask(Task): + def __init__(self, load_function): + super().__init__(RegisterAllTask.register, load_function=load_function) + + @staticmethod + def register(state_data, load_function): + task_dict = dict() + task_dict[ConstructTask.__name__] = ConstructTask + task_dict[ConstructLateTask.__name__] = ConstructLateTask + task_dict[SolveTask.__name__] = SolveTask + task_dict[GetStateTask.__name__] = GetStateTask + task_dict[SetStateTask.__name__] = SetStateTask + task_dict[InitializeTask.__name__] = InitializeTask + task_dict[OutputTask.__name__] = OutputTask + state_data["tasks"] = task_dict + state_data["sim_classes"] = dict() + state_data["sim_instances"] = dict() + state_data["load_function"] = load_function + return None + + +def handle_task(state_data, task_descriptor): + name, args, kwargs = task_descriptor + task = state_data["tasks"][name](*args, **kwargs) + # print(f"handling task: {name} args={args} kwargs={kwargs}") + return task(state_data) diff --git a/micro_manager/tasking/worker_main.py b/micro_manager/tasking/worker_main.py new file mode 100644 index 00000000..d0db4103 --- /dev/null +++ b/micro_manager/tasking/worker_main.py @@ -0,0 +1,64 @@ +import argparse +import os +from mpi4py import MPI +from task import handle_task + +from connection import Connection, MPIConnection, SocketConnection + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--backend", required=True, choices=["mpi", "socket"]) + parser.add_argument("--host", help="IP or localhost") + parser.add_argument("--port", type=int, help="Port to open port in micro manager") + parser.add_argument( + "--parentrank", + type=int, + help="Parent rank of spawning micro manager mpi instance", + ) + args = parser.parse_args() + + rank = MPI.COMM_WORLD.Get_rank() + size = MPI.COMM_WORLD.Get_size() + worker_id = rank + + conn, dst_id, src_id = None, 0, 0 + if args.backend == "mpi": + conn = MPIConnection.connect_to_micromanager(MPI.Comm.Get_parent()) + dst_id = src_id = args.parentrank + else: + conn = SocketConnection.connect_to_micromanager(worker_id, args.host, args.port) + dst_id = src_id = worker_id + + state_data = {} + + # register possible tasks + register_task = None + try: + register_task = conn.recv(src_id) + except Exception: + raise RuntimeError("Failed to recv register tasks") + output = register_task(state_data) + try: + conn.send(dst_id, output) + except Exception: + raise RuntimeError("Failed to send register tasks output") + + while True: + task_descriptor = None + try: + task_descriptor = conn.recv(src_id) + except Exception: + break + + output = None + try: + output = handle_task(state_data, task_descriptor) + except Exception: + break + + try: + conn.send(dst_id, output) + except Exception: + break + + conn.close() diff --git a/pyproject.toml b/pyproject.toml index dfc9c62f..9288f42d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ Repository = "https://github.com/precice/micro-manager" micro-manager-precice = "micro_manager:main" [tool.setuptools] -packages=["micro_manager", "micro_manager.adaptivity", "micro_manager.snapshot", "micro_manager.tools"] +packages=["micro_manager", "micro_manager.adaptivity", "micro_manager.snapshot", "micro_manager.tools", "micro_manager.tasking"] [tool.setuptools-git-versioning] enabled = true diff --git a/tests/integration/test_unit_cube/micro_dummy.py b/tests/integration/test_unit_cube/micro_dummy.py index b2ec8b4b..6e35bc4f 100644 --- a/tests/integration/test_unit_cube/micro_dummy.py +++ b/tests/integration/test_unit_cube/micro_dummy.py @@ -53,3 +53,6 @@ def get_state(self): def set_state(self, state): self._state = copy.deepcopy(state) + + def get_global_id(self): + return self._sim_id diff --git a/tests/unit/test_adaptivity_parallel.py b/tests/unit/test_adaptivity_parallel.py index ba5e78b6..68391e0c 100644 --- a/tests/unit/test_adaptivity_parallel.py +++ b/tests/unit/test_adaptivity_parallel.py @@ -21,6 +21,14 @@ def set_state(self, state): def get_state(self): return self._state.copy() + def solve(self, micro_input, dt): + pass + + +class ModelManager: + def get_instance(self, gid, micro_problem_cls): + return micro_problem_cls(gid) + class TestGlobalAdaptivity(TestCase): def setUp(self): @@ -60,6 +68,7 @@ def test_update_inactive_sims_global_adaptivity(self): rank=self._rank, comm=self._comm, micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), ) adaptivity_controller._is_sim_active = np.array( @@ -134,6 +143,7 @@ def test_update_all_active_sims_global_adaptivity(self): rank=self._rank, comm=self._comm, micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), ) adaptivity_controller._adaptivity_data_names = ["data1", "data2"] @@ -189,6 +199,7 @@ def test_communicate_micro_output(self): rank=self._rank, comm=self._comm, micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), ) adaptivity_controller._is_sim_active = np.array( @@ -228,6 +239,7 @@ def test_get_ranks_of_sims(self): rank=self._rank, comm=self._comm, micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), ) actual_ranks_of_sims = adaptivity_controller._get_ranks_of_sims() diff --git a/tests/unit/test_adaptivity_serial.py b/tests/unit/test_adaptivity_serial.py index fe6412d0..fb4aa04a 100644 --- a/tests/unit/test_adaptivity_serial.py +++ b/tests/unit/test_adaptivity_serial.py @@ -25,6 +25,14 @@ def set_state(self, state): def get_state(self): pass + def solve(self, micro_input, dt): + pass + + +class ModelManager: + def get_instance(self, gid, micro_problem_cls): + return micro_problem_cls(gid) + class TestLocalAdaptivity(TestCase): def setUp(self): @@ -94,6 +102,7 @@ def test_update_similarity_dists(self): configurator, nsims=self._number_of_sims, micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), base_logger=MagicMock(), rank=0, ) @@ -146,6 +155,7 @@ def test_update_active_sims(self): rank=0, comm=MagicMock(), micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), ) adaptivity_controller._refine_const = self._refine_const adaptivity_controller._coarse_const = self._coarse_const @@ -182,6 +192,7 @@ def test_adaptivity_norms(self): base_logger=MagicMock(), rank=0, micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), ) fake_data = np.array([[1], [2], [3]]) @@ -280,6 +291,7 @@ def test_associate_active_to_inactive(self): base_logger=MagicMock(), rank=0, micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), ) adaptivity_controller._refine_const = self._refine_const adaptivity_controller._coarse_const = self._coarse_const @@ -323,6 +335,7 @@ def test_update_inactive_sims_local_adaptivity(self): rank=0, comm=MPI.COMM_WORLD, micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), ) adaptivity_controller._refine_const = self._refine_const adaptivity_controller._coarse_const = self._coarse_const diff --git a/tests/unit/test_global_adaptivity_lb.py b/tests/unit/test_global_adaptivity_lb.py index d60ae12b..4e6af391 100644 --- a/tests/unit/test_global_adaptivity_lb.py +++ b/tests/unit/test_global_adaptivity_lb.py @@ -22,6 +22,14 @@ def set_state(self, state): def get_state(self): return self._state.copy() + def solve(self, micro_input, dt): + pass + + +class ModelManager: + def get_instance(self, gid, micro_problem_cls): + return micro_problem_cls(gid) + class TestGlobalAdaptivityLB(TestCase): def setUp(self): @@ -68,6 +76,7 @@ def test_redistribute_active_sims_two_ranks(self): rank=self._rank, comm=self._comm, micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), ) adaptivity_controller._is_sim_active = np.array( @@ -118,6 +127,7 @@ def test_redistribute_inactive_sims_two_ranks(self): rank=self._rank, comm=self._comm, micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), ) adaptivity_controller._is_sim_active = np.array( @@ -171,6 +181,7 @@ def test_redistribute_active_sims_four_ranks(self): rank=self._rank, comm=self._comm, micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), ) adaptivity_controller._is_sim_active = np.array( @@ -243,6 +254,7 @@ def test_redistribute_inactive_sims_four_ranks(self): rank=self._rank, comm=self._comm, micro_problem_cls=MicroSimulation, + model_manager=ModelManager(), ) adaptivity_controller._is_sim_active = np.array( diff --git a/tests/unit/test_micro_manager.py b/tests/unit/test_micro_manager.py index 9f656b2f..87438ad7 100644 --- a/tests/unit/test_micro_manager.py +++ b/tests/unit/test_micro_manager.py @@ -21,6 +21,15 @@ def solve(self, macro_data, dt): "micro-vector-data": macro_data["macro-vector-data"] + 1, } + def get_global_id(self): + pass + + def get_state(self): + return None + + def set_state(self, state): + pass + class TestFunctioncalls(TestCase): def setUp(self): diff --git a/tests/unit/test_micro_simulation_crash_handling.py b/tests/unit/test_micro_simulation_crash_handling.py index c7861aa1..7da636c6 100644 --- a/tests/unit/test_micro_simulation_crash_handling.py +++ b/tests/unit/test_micro_simulation_crash_handling.py @@ -21,6 +21,15 @@ def solve(self, macro_data, dt): "micro-scalar-data": macro_data["macro-scalar-data"], } + def get_state(self): + return None + + def set_state(self, state): + pass + + def get_global_id(self): + return self.sim_id + class TestSimulationCrashHandling(TestCase): def test_crash_handling(self): diff --git a/tests/unit/test_snapshot_computation.py b/tests/unit/test_snapshot_computation.py index c4cc807d..4553312f 100644 --- a/tests/unit/test_snapshot_computation.py +++ b/tests/unit/test_snapshot_computation.py @@ -19,6 +19,15 @@ def solve(self, macro_data, dt): "micro-vector-data": macro_data["macro-vector-data"] + 1, } + def get_state(self): + return None + + def set_state(self, state): + pass + + def get_global_id(self): + pass + class TestFunctionCalls(TestCase): def setUp(self): @@ -85,7 +94,11 @@ def test_solve_micro_sims(self): snapshot_object._micro_problem = MicroSimulation snapshot_object._micro_sims = create_simulation_class( - snapshot_object._micro_problem + MagicMock(), + snapshot_object._micro_problem, + None, + 1, + None, )(0) micro_sim_output = snapshot_object._solve_micro_simulation(self.fake_read_data)