Skip to content

Commit 80ce219

Browse files
Improve memory usage of MetaLearnerGridSearch (#62)
* Add options for storing * Tests * Finish TODO * Reduce memory usage by not creating metalearner object * Update CHANGELOG * Use generator_unordered * Add grid_size_ and move attributes initialization to fit * Fix * Fix * grid_size_ docstring * Add new options to tutorial * Remove check empty generator * Apply suggestions from code review Co-authored-by: Kevin Klein <[email protected]> * Add explanation grid_size_ --------- Co-authored-by: Kevin Klein <[email protected]>
1 parent 9406ef7 commit 80ce219

File tree

4 files changed

+173
-33
lines changed

4 files changed

+173
-33
lines changed

CHANGELOG.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ Changelog
1515
* Added :meth:`metalearners.metalearner.MetaLearner.fit_all_nuisance` and
1616
:meth:`metalearners.metalearner.MetaLearner.fit_all_treatment`.
1717

18+
* Add optional ``store_raw_results`` and ``store_results`` parameters to :class:`metalearners.grid_search.MetaLearnerGridSearch`.
19+
20+
* Renamed :class:`metalearners.grid_search._GSResult` to :class:`metalearners.grid_search.GSResult`.
21+
22+
* Added ``grid_size_`` attribute to :class:`metalearners.grid_search.MetaLearnerGridSearch`.
23+
1824
* Implement :meth:`metalearners.cross_fit_estimator.CrossFitEstimator.score`.
1925

2026
**Bug fixes**

docs/examples/example_gridsearch.ipynb

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,26 @@
327327
"gs.results_"
328328
]
329329
},
330+
{
331+
"cell_type": "markdown",
332+
"metadata": {},
333+
"source": [
334+
"What if I run out of memory?\n",
335+
"----------------------------\n",
336+
"\n",
337+
"If you're conducting an optimization task over a large grid with a substantial dataset,\n",
338+
"it is possible that memory usage issues may arise. To try to solve these, you can minimize\n",
339+
"memory usage by adjusting your settings.\n",
340+
"\n",
341+
"In that case you can set ``store_raw_results=False``, the grid search will then operate\n",
342+
"with a generator rather than a list, significantly reducing memory usage.\n",
343+
"\n",
344+
"If the ``results_`` DataFrame is what you're after, you can simply set ``store_results=True``.\n",
345+
"However, if you aim to iterate over the {class}`~metalearners.metalearner.MetaLearner` objects,\n",
346+
"you can set ``store_results=False``. Consequently, ``raw_results_`` will become a generator\n",
347+
"object yielding {class}`~metalearners.grid_search.GSResult`."
348+
]
349+
},
330350
{
331351
"cell_type": "markdown",
332352
"metadata": {},

metalearners/grid_search.py

Lines changed: 81 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
import time
5-
from collections.abc import Mapping, Sequence
5+
from collections.abc import Generator, Mapping, Sequence
66
from dataclasses import dataclass
77
from typing import Any
88

@@ -17,7 +17,8 @@
1717

1818
@dataclass(frozen=True)
1919
class _FitAndScoreJob:
20-
metalearner: MetaLearner
20+
metalearner_factory: type[MetaLearner]
21+
metalearner_params: dict[str, Any]
2122
X_train: Matrix
2223
y_train: Vector
2324
w_train: Vector
@@ -32,7 +33,7 @@ class _FitAndScoreJob:
3233

3334

3435
@dataclass(frozen=True)
35-
class _GSResult:
36+
class GSResult:
3637
r"""Result from a single grid search evaluation."""
3738

3839
metalearner: MetaLearner
@@ -42,23 +43,22 @@ class _GSResult:
4243
score_time: float
4344

4445

45-
def _fit_and_score(job: _FitAndScoreJob) -> _GSResult:
46+
def _fit_and_score(job: _FitAndScoreJob) -> GSResult:
4647
start_time = time.time()
47-
job.metalearner.fit(
48-
job.X_train, job.y_train, job.w_train, **job.metalerner_fit_params
49-
)
48+
ml = job.metalearner_factory(**job.metalearner_params)
49+
ml.fit(job.X_train, job.y_train, job.w_train, **job.metalerner_fit_params)
5050
fit_time = time.time() - start_time
5151
start_time = time.time()
5252

53-
train_scores = job.metalearner.evaluate(
53+
train_scores = ml.evaluate(
5454
X=job.X_train,
5555
y=job.y_train,
5656
w=job.w_train,
5757
is_oos=False,
5858
scoring=job.scoring,
5959
)
6060
if job.X_test is not None and job.y_test is not None and job.w_test is not None:
61-
test_scores = job.metalearner.evaluate(
61+
test_scores = ml.evaluate(
6262
X=job.X_test,
6363
y=job.y_test,
6464
w=job.w_test,
@@ -69,16 +69,18 @@ def _fit_and_score(job: _FitAndScoreJob) -> _GSResult:
6969
else:
7070
test_scores = None
7171
score_time = time.time() - start_time
72-
return _GSResult(
73-
metalearner=job.metalearner,
72+
return GSResult(
73+
metalearner=ml,
7474
fit_time=fit_time,
7575
score_time=score_time,
7676
train_scores=train_scores,
7777
test_scores=test_scores,
7878
)
7979

8080

81-
def _format_results(results: Sequence[_GSResult]) -> pd.DataFrame:
81+
def _format_results(
82+
results: list[GSResult] | Generator[GSResult, None, None]
83+
) -> pd.DataFrame:
8284
rows = []
8385
for result in results:
8486
row: dict[str, str | int | float] = {}
@@ -180,11 +182,33 @@ class MetaLearnerGridSearch:
180182
181183
``verbose`` will be passed to `joblib.Parallel <https://joblib.readthedocs.io/en/latest/parallel.html#parallel-reference-documentation>`_.
182184
183-
After fitting a dataframe with the results will be available in `results_`.
185+
``store_raw_results`` and ``store_results`` define which and how the results are saved
186+
after calling :meth:`~metalearners.grid_search.MetaLearnerGridSearch.fit` depending on
187+
their values:
188+
189+
* Both are ``True`` (default): ``raw_results_`` will be a list of
190+
:class:`~metalearners.grid_search.GSResult` with all the results and ``results_``
191+
will be a DataFrame with the processed results.
192+
* ``store_raw_results=True`` and ``store_results=False``: ``raw_results_`` will be a
193+
list of :class:`~metalearners.grid_search.GSResult` with all the results
194+
and ``results_`` will be ``None``.
195+
* ``store_raw_results=False`` and ``store_results=True``: ``raw_results_`` will be
196+
``None`` and ``results_`` will be a DataFrame with the processed results.
197+
* Both are ``False``: ``raw_results_`` will be a generator which yields a
198+
:class:`~metalearners.grid_search.GSResult` for each configuration and ``results_``
199+
will be None. This configuration can be useful in the case the grid search is big
200+
and you do not want to store all MetaLearner objects, but rather to evaluate them after
201+
fitting each one and just store one.
202+
203+
``grid_size_`` will contain the number of hyperparameter combinations after fitting.
204+
This attribute may be useful in the case ``store_raw_results = False`` and ``store_results = False``.
205+
In that case, the generator object returned in ``raw_results_`` doesn't trigger the fitting
206+
of individual metalearners until explicitly requested, e.g. in a loop. This attribute
207+
can be used to track progress, for instance, by creating a progress bar or a similar utility.
208+
209+
For an illustration see :ref:`our example on Tuning hyperparameters of a MetaLearner with MetaLearnerGridSearch <example-grid-search>`.
184210
"""
185211

186-
# TODO: Add a reference to a docs example once it is written.
187-
188212
def __init__(
189213
self,
190214
metalearner_factory: type[MetaLearner],
@@ -195,16 +219,17 @@ def __init__(
195219
n_jobs: int | None = None,
196220
random_state: int | None = None,
197221
verbose: int = 0,
222+
store_raw_results: bool = True,
223+
store_results: bool = True,
198224
):
199225
self.metalearner_factory = metalearner_factory
200226
self.metalearner_params = metalearner_params
201227
self.scoring = scoring
202228
self.n_jobs = n_jobs
203229
self.random_state = random_state
204230
self.verbose = verbose
205-
206-
self.raw_results_: Sequence[_GSResult] | None = None
207-
self.results_: pd.DataFrame | None = None
231+
self.store_raw_results = store_raw_results
232+
self.store_results = store_results
208233

209234
all_base_models = set(
210235
metalearner_factory.nuisance_model_specifications().keys()
@@ -286,20 +311,33 @@ def fit(
286311
}
287312
propensity_model_params = params.get(PROPENSITY_MODEL, None)
288313

289-
ml = self.metalearner_factory(
290-
**self.metalearner_params,
291-
nuisance_model_factory=nuisance_model_factory,
292-
treatment_model_factory=treatment_model_factory,
293-
propensity_model_factory=propensity_model_factory,
294-
nuisance_model_params=nuisance_model_params,
295-
treatment_model_params=treatment_model_params,
296-
propensity_model_params=propensity_model_params,
297-
random_state=self.random_state,
298-
)
314+
grid_metalearner_params = {
315+
"nuisance_model_factory": nuisance_model_factory,
316+
"treatment_model_factory": treatment_model_factory,
317+
"propensity_model_factory": propensity_model_factory,
318+
"nuisance_model_params": nuisance_model_params,
319+
"treatment_model_params": treatment_model_params,
320+
"propensity_model_params": propensity_model_params,
321+
"random_state": self.random_state,
322+
}
323+
324+
if (
325+
len(
326+
shared_keys := set(grid_metalearner_params.keys())
327+
& set(self.metalearner_params.keys())
328+
)
329+
> 0
330+
):
331+
raise ValueError(
332+
f"{shared_keys} should not be specified in metalearner_params as "
333+
"they are used internally. Please use the correct parameters."
334+
)
299335

300336
jobs.append(
301337
_FitAndScoreJob(
302-
metalearner=ml,
338+
metalearner_factory=self.metalearner_factory,
339+
metalearner_params=dict(self.metalearner_params)
340+
| grid_metalearner_params,
303341
X_train=X,
304342
y_train=y,
305343
w_train=w,
@@ -312,7 +350,17 @@ def fit(
312350
)
313351
)
314352

315-
parallel = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)
316-
raw_results = parallel(delayed(_fit_and_score)(job) for job in jobs)
317-
self.raw_results_ = raw_results
318-
self.results_ = _format_results(results=raw_results)
353+
self.grid_size_ = len(jobs)
354+
self.raw_results_: list[GSResult] | Generator[GSResult, None, None] | None
355+
self.results_: pd.DataFrame | None = None
356+
357+
return_as = "list" if self.store_raw_results else "generator_unordered"
358+
parallel = Parallel(
359+
n_jobs=self.n_jobs, verbose=self.verbose, return_as=return_as
360+
)
361+
self.raw_results_ = parallel(delayed(_fit_and_score)(job) for job in jobs)
362+
if self.store_results:
363+
self.results_ = _format_results(results=self.raw_results_) # type: ignore
364+
if not self.store_raw_results:
365+
# The generator will be empty so we replace it with None
366+
self.raw_results_ = None

tests/test_grid_search.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
# SPDX-License-Identifier: BSD-3-Clause
33

44

5+
from types import GeneratorType
6+
57
import numpy as np
8+
import pandas as pd
69
import pytest
710
from lightgbm import LGBMClassifier, LGBMRegressor
811
from sklearn.linear_model import LinearRegression, LogisticRegression
@@ -153,6 +156,7 @@ def test_metalearnergridsearch_smoke(
153156
assert gs.results_ is not None
154157
assert gs.results_.shape[0] == expected_n_configs
155158
assert gs.results_.index.names == expected_index_cols
159+
assert gs.grid_size_ == expected_n_configs
156160

157161
train_scores_cols = set(
158162
c[6:] for c in list(gs.results_.columns) if c.startswith("train_")
@@ -259,3 +263,65 @@ def test_metalearnergridsearch_reuse_propensity_smoke(grid_search_data):
259263
assert gs.results_ is not None
260264
assert gs.results_.shape[0] == 2
261265
assert len(gs.results_.index.names) == 5
266+
267+
268+
@pytest.mark.parametrize(
269+
"store_raw_results, store_results, expected_type_raw_results, expected_type_results",
270+
[
271+
(True, True, list, pd.DataFrame),
272+
(True, False, list, type(None)),
273+
(False, True, type(None), pd.DataFrame),
274+
(False, False, GeneratorType, type(None)),
275+
],
276+
)
277+
def test_metalearnergridsearch_store(
278+
store_raw_results,
279+
store_results,
280+
expected_type_raw_results,
281+
expected_type_results,
282+
grid_search_data,
283+
):
284+
X, _, y, w, X_test, _, y_test, w_test = grid_search_data
285+
n_variants = len(np.unique(w))
286+
287+
metalearner_params = {
288+
"is_classification": False,
289+
"n_variants": n_variants,
290+
"n_folds": 2,
291+
}
292+
293+
gs = MetaLearnerGridSearch(
294+
metalearner_factory=SLearner,
295+
metalearner_params=metalearner_params,
296+
base_learner_grid={"base_model": [LinearRegression, LGBMRegressor]},
297+
param_grid={"base_model": {"LGBMRegressor": {"n_estimators": [1, 2]}}},
298+
store_raw_results=store_raw_results,
299+
store_results=store_results,
300+
)
301+
302+
gs.fit(X, y, w, X_test, y_test, w_test)
303+
assert isinstance(gs.raw_results_, expected_type_raw_results)
304+
assert isinstance(gs.results_, expected_type_results)
305+
306+
307+
def test_metalearnergridsearch_error(grid_search_data):
308+
X, _, y, w, X_test, _, y_test, w_test = grid_search_data
309+
n_variants = len(np.unique(w))
310+
311+
metalearner_params = {
312+
"is_classification": False,
313+
"n_variants": n_variants,
314+
"n_folds": 2,
315+
"random_state": 1,
316+
}
317+
318+
gs = MetaLearnerGridSearch(
319+
metalearner_factory=SLearner,
320+
metalearner_params=metalearner_params,
321+
base_learner_grid={"base_model": [LinearRegression, LGBMRegressor]},
322+
param_grid={"base_model": {"LGBMRegressor": {"n_estimators": [1, 2]}}},
323+
)
324+
with pytest.raises(
325+
ValueError, match="should not be specified in metalearner_params"
326+
):
327+
gs.fit(X, y, w, X_test, y_test, w_test)

0 commit comments

Comments
 (0)