22# SPDX-License-Identifier: BSD-3-Clause
33
44import time
5- from collections .abc import Mapping , Sequence
5+ from collections .abc import Generator , Mapping , Sequence
66from dataclasses import dataclass
77from typing import Any
88
1717
1818@dataclass (frozen = True )
1919class _FitAndScoreJob :
20- metalearner : MetaLearner
20+ metalearner_factory : type [MetaLearner ]
21+ metalearner_params : dict [str , Any ]
2122 X_train : Matrix
2223 y_train : Vector
2324 w_train : Vector
@@ -32,7 +33,7 @@ class _FitAndScoreJob:
3233
3334
3435@dataclass (frozen = True )
35- class _GSResult :
36+ class GSResult :
3637 r"""Result from a single grid search evaluation."""
3738
3839 metalearner : MetaLearner
@@ -42,23 +43,22 @@ class _GSResult:
4243 score_time : float
4344
4445
45- def _fit_and_score (job : _FitAndScoreJob ) -> _GSResult :
46+ def _fit_and_score (job : _FitAndScoreJob ) -> GSResult :
4647 start_time = time .time ()
47- job .metalearner .fit (
48- job .X_train , job .y_train , job .w_train , ** job .metalerner_fit_params
49- )
48+ ml = job .metalearner_factory (** job .metalearner_params )
49+ ml .fit (job .X_train , job .y_train , job .w_train , ** job .metalerner_fit_params )
5050 fit_time = time .time () - start_time
5151 start_time = time .time ()
5252
53- train_scores = job . metalearner .evaluate (
53+ train_scores = ml .evaluate (
5454 X = job .X_train ,
5555 y = job .y_train ,
5656 w = job .w_train ,
5757 is_oos = False ,
5858 scoring = job .scoring ,
5959 )
6060 if job .X_test is not None and job .y_test is not None and job .w_test is not None :
61- test_scores = job . metalearner .evaluate (
61+ test_scores = ml .evaluate (
6262 X = job .X_test ,
6363 y = job .y_test ,
6464 w = job .w_test ,
@@ -69,16 +69,18 @@ def _fit_and_score(job: _FitAndScoreJob) -> _GSResult:
6969 else :
7070 test_scores = None
7171 score_time = time .time () - start_time
72- return _GSResult (
73- metalearner = job . metalearner ,
72+ return GSResult (
73+ metalearner = ml ,
7474 fit_time = fit_time ,
7575 score_time = score_time ,
7676 train_scores = train_scores ,
7777 test_scores = test_scores ,
7878 )
7979
8080
81- def _format_results (results : Sequence [_GSResult ]) -> pd .DataFrame :
81+ def _format_results (
82+ results : list [GSResult ] | Generator [GSResult , None , None ]
83+ ) -> pd .DataFrame :
8284 rows = []
8385 for result in results :
8486 row : dict [str , str | int | float ] = {}
@@ -180,11 +182,33 @@ class MetaLearnerGridSearch:
180182
181183 ``verbose`` will be passed to `joblib.Parallel <https://joblib.readthedocs.io/en/latest/parallel.html#parallel-reference-documentation>`_.
182184
183- After fitting a dataframe with the results will be available in `results_`.
185+ ``store_raw_results`` and ``store_results`` define which results are saved, and in
186+ what form, after calling :meth:`~metalearners.grid_search.MetaLearnerGridSearch.fit`,
187+ depending on their values:
188+
189+ * Both are ``True`` (default): ``raw_results_`` will be a list of
190+ :class:`~metalearners.grid_search.GSResult` with all the results and ``results_``
191+ will be a DataFrame with the processed results.
192+ * ``store_raw_results=True`` and ``store_results=False``: ``raw_results_`` will be a
193+ list of :class:`~metalearners.grid_search.GSResult` with all the results
194+ and ``results_`` will be ``None``.
195+ * ``store_raw_results=False`` and ``store_results=True``: ``raw_results_`` will be
196+ ``None`` and ``results_`` will be a DataFrame with the processed results.
197+ * Both are ``False``: ``raw_results_`` will be a generator which yields a
198+ :class:`~metalearners.grid_search.GSResult` for each configuration and ``results_``
199+ will be ``None``. This configuration can be useful when the grid search is big
200+ and you do not want to store all MetaLearner objects, but rather evaluate each one
201+ right after it is fitted and keep only one of them.
202+
203+ ``grid_size_`` will contain the number of hyperparameter combinations after fitting.
204+ This attribute may be useful in the case ``store_raw_results = False`` and ``store_results = False``.
205+ In that case, the generator object returned in ``raw_results_`` doesn't trigger the fitting
206+ of individual metalearners until explicitly requested, e.g. in a loop. This attribute
207+ can be used to track the progress, for instance, by creating a progress bar or a similar utility.
208+
209+ For an illustration see :ref:`our example on Tuning hyperparameters of a MetaLearner with MetaLearnerGridSearch <example-grid-search>`.
184210 """
185211
186- # TODO: Add a reference to a docs example once it is written.
187-
188212 def __init__ (
189213 self ,
190214 metalearner_factory : type [MetaLearner ],
@@ -195,16 +219,17 @@ def __init__(
195219 n_jobs : int | None = None ,
196220 random_state : int | None = None ,
197221 verbose : int = 0 ,
222+ store_raw_results : bool = True ,
223+ store_results : bool = True ,
198224 ):
199225 self .metalearner_factory = metalearner_factory
200226 self .metalearner_params = metalearner_params
201227 self .scoring = scoring
202228 self .n_jobs = n_jobs
203229 self .random_state = random_state
204230 self .verbose = verbose
205-
206- self .raw_results_ : Sequence [_GSResult ] | None = None
207- self .results_ : pd .DataFrame | None = None
231+ self .store_raw_results = store_raw_results
232+ self .store_results = store_results
208233
209234 all_base_models = set (
210235 metalearner_factory .nuisance_model_specifications ().keys ()
@@ -286,20 +311,33 @@ def fit(
286311 }
287312 propensity_model_params = params .get (PROPENSITY_MODEL , None )
288313
289- ml = self .metalearner_factory (
290- ** self .metalearner_params ,
291- nuisance_model_factory = nuisance_model_factory ,
292- treatment_model_factory = treatment_model_factory ,
293- propensity_model_factory = propensity_model_factory ,
294- nuisance_model_params = nuisance_model_params ,
295- treatment_model_params = treatment_model_params ,
296- propensity_model_params = propensity_model_params ,
297- random_state = self .random_state ,
298- )
314+ grid_metalearner_params = {
315+ "nuisance_model_factory" : nuisance_model_factory ,
316+ "treatment_model_factory" : treatment_model_factory ,
317+ "propensity_model_factory" : propensity_model_factory ,
318+ "nuisance_model_params" : nuisance_model_params ,
319+ "treatment_model_params" : treatment_model_params ,
320+ "propensity_model_params" : propensity_model_params ,
321+ "random_state" : self .random_state ,
322+ }
323+
324+ if (
325+ len (
326+ shared_keys := set (grid_metalearner_params .keys ())
327+ & set (self .metalearner_params .keys ())
328+ )
329+ > 0
330+ ):
331+ raise ValueError (
332+ f"{ shared_keys } should not be specified in metalearner_params as "
333+ "they are used internally. Please use the correct parameters."
334+ )
299335
300336 jobs .append (
301337 _FitAndScoreJob (
302- metalearner = ml ,
338+ metalearner_factory = self .metalearner_factory ,
339+ metalearner_params = dict (self .metalearner_params )
340+ | grid_metalearner_params ,
303341 X_train = X ,
304342 y_train = y ,
305343 w_train = w ,
@@ -312,7 +350,17 @@ def fit(
312350 )
313351 )
314352
315- parallel = Parallel (n_jobs = self .n_jobs , verbose = self .verbose )
316- raw_results = parallel (delayed (_fit_and_score )(job ) for job in jobs )
317- self .raw_results_ = raw_results
318- self .results_ = _format_results (results = raw_results )
353+ self .grid_size_ = len (jobs )
354+ self .raw_results_ : list [GSResult ] | Generator [GSResult , None , None ] | None
355+ self .results_ : pd .DataFrame | None = None
356+
357+ return_as = "list" if self .store_raw_results else "generator_unordered"
358+ parallel = Parallel (
359+ n_jobs = self .n_jobs , verbose = self .verbose , return_as = return_as
360+ )
361+ self .raw_results_ = parallel (delayed (_fit_and_score )(job ) for job in jobs )
362+ if self .store_results :
363+ self .results_ = _format_results (results = self .raw_results_ ) # type: ignore
364+ if not self .store_raw_results :
365+ # The generator will be empty so we replace it with None
366+ self .raw_results_ = None
0 commit comments