
Commit 8269619

andyl7an authored and The Meridian Authors committed
[JAX] Implement JAX for optimizer
PiperOrigin-RevId: 823249034
1 parent c343225 commit 8269619

File tree

11 files changed: +919 -475 lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
 
 ## [Unreleased]
 
+* Refactor the TensorFlow RNG handler to use stateless seed generation.
 * Add `selected_geos` arg to the optimizer.
 * Add `selected_geos` arg to `get_aggregated_spend`.
 * Fix bug in `optimize()` when using `new_data` with `start_date` and `end_date`

meridian/analysis/analyzer.py

Lines changed: 83 additions & 25 deletions
@@ -15,6 +15,7 @@
 """Methods to compute analysis metrics of the model and the data."""
 
 from collections.abc import Mapping, Sequence
+import dataclasses
 import itertools
 import numbers
 from typing import Any, Optional
@@ -53,6 +54,7 @@ def _validate_non_media_baseline_values_numbers(
 
 
 # TODO: Refactor the related unit tests to be under DataTensors.
+@dataclasses.dataclass
 class DataTensors(backend.ExtensionType):
   """Container for data variable arguments of Analyzer methods.
 
@@ -175,12 +177,31 @@ def __init__(
         else None
     )
     self.time = (
-        backend.to_tensor(time, dtype="string") if time is not None else None
+        backend.to_tensor(time, dtype=backend.string)
+        if time is not None
+        else None
     )
-
-  def __validate__(self):
     self._validate_n_dims()
 
+  def __eq__(self, other: Any) -> bool:
+    """Provides safe equality comparison for mixed tensor/non-tensor fields."""
+    if type(self) is not type(other):
+      return NotImplemented
+    for field in dataclasses.fields(self):
+      a = getattr(self, field.name)
+      b = getattr(other, field.name)
+      if a is None and b is None:
+        continue
+      if a is None or b is None:
+        return False
+      try:
+        if not bool(np.all(backend.to_tensor(backend.equal(a, b)))):
+          return False
+      except (ValueError, TypeError):
+        if a != b:
+          return False
+    return True
+
   def total_spend(self) -> backend.Tensor | None:
     """Returns the total spend tensor.
 
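The new `__eq__` has to cope with fields that may be tensors, plain Python values, or `None`. A minimal standalone sketch of the same field-wise comparison, written against plain dataclasses and jax.numpy (the `Pair` class and its fields are illustrative, not part of Meridian):

import dataclasses
from typing import Any, Optional

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class Pair:
  """Toy container mixing an array field and a plain Python field."""

  values: Optional[jax.Array] = None
  label: Optional[str] = None

  def __eq__(self, other: Any) -> bool:
    if type(self) is not type(other):
      return NotImplemented
    for field in dataclasses.fields(self):
      a, b = getattr(self, field.name), getattr(other, field.name)
      if a is None and b is None:
        continue
      if a is None or b is None:
        return False
      try:
        # Array-like fields: element-wise comparison reduced to a single bool.
        if not bool(jnp.all(jnp.equal(a, b))):
          return False
      except (ValueError, TypeError):
        # Non-tensor fields (e.g. strings) fall back to plain ==.
        if a != b:
          return False
    return True


assert Pair(jnp.ones(3), "a") == Pair(jnp.ones(3), "a")
assert Pair(jnp.ones(3), "a") != Pair(jnp.zeros(3), "a")
assert Pair(None, "a") != Pair(jnp.ones(3), "a")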
@@ -216,7 +237,7 @@ def get_modified_times(self, meridian: model.Meridian) -> int | None:
       of the corresponding tensor in the `meridian` object. If all time
       dimensions are the same, returns `None`.
     """
-    for field in self._tf_extension_type_fields():
+    for field in dataclasses.fields(self):
       new_tensor = getattr(self, field.name)
       if field.name == constants.RF_IMPRESSIONS:
         old_tensor = getattr(meridian.rf_tensors, field.name)
@@ -282,7 +303,7 @@ def validate_and_fill_missing_data(
 
   def _validate_n_dims(self):
     """Raises an error if the tensors have the wrong number of dimensions."""
-    for field in self._tf_extension_type_fields():
+    for field in dataclasses.fields(self):
       tensor = getattr(self, field.name)
       if tensor is None:
         continue
@@ -315,7 +336,7 @@ def _validate_correct_variables_filled(
     Warning: If an attribute exists in the `DataTensors` object that is not in
     the `required_variables` list, it will be ignored.
     """
-    for field in self._tf_extension_type_fields():
+    for field in dataclasses.fields(self):
       tensor = getattr(self, field.name)
       if tensor is None:
         continue
@@ -468,7 +489,7 @@ def _fill_default_values(
   ) -> Self:
     """Fills default values and returns a new DataTensors object."""
     output = {}
-    for field in self._tf_extension_type_fields():
+    for field in dataclasses.fields(self):
       var_name = field.name
       if var_name not in required_fields:
         continue
@@ -489,7 +510,7 @@ def _fill_default_values(
         old_tensor = meridian.revenue_per_kpi
       elif var_name == constants.TIME:
         old_tensor = backend.to_tensor(
-            meridian.input_data.time.values.tolist(), dtype="string"
+            meridian.input_data.time.values.tolist(), dtype=backend.string
         )
       else:
         continue
@@ -500,6 +521,7 @@ def _fill_default_values(
     return DataTensors(**output)
 
 
+@dataclasses.dataclass
 class DistributionTensors(backend.ExtensionType):
   """Container for parameters distributions arguments of Analyzer methods."""
 
@@ -583,17 +605,19 @@ def _transformed_new_or_scaled(
 
 def _calc_rsquared(expected, actual):
   """Calculates r-squared between actual and expected outcome."""
-  return 1 - np.nanmean((expected - actual) ** 2) / np.nanvar(actual)
+  return 1 - backend.nanmean((expected - actual) ** 2) / backend.nanvar(actual)
 
 
 def _calc_mape(expected, actual):
   """Calculates MAPE between actual and expected outcome."""
-  return np.nanmean(np.abs((actual - expected) / actual))
+  return backend.nanmean(backend.absolute((actual - expected) / actual))
 
 
 def _calc_weighted_mape(expected, actual):
   """Calculates wMAPE between actual and expected outcome (weighted by actual)."""
-  return np.nansum(np.abs(actual - expected)) / np.nansum(actual)
+  return backend.nansum(backend.absolute(actual - expected)) / backend.nansum(
+      actual
+  )
 
 
 def _warn_if_geo_arg_in_kwargs(**kwargs):
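For reference, the same NaN-aware fit metrics written directly against jax.numpy rather than the backend shim; this assumes `backend.nanmean`, `backend.nanvar`, `backend.nansum`, and `backend.absolute` dispatch to equivalents like these under the JAX backend:

import jax.numpy as jnp


def calc_rsquared(expected, actual):
  # R^2 = 1 - MSE / variance of the actual series, ignoring NaN entries.
  return 1 - jnp.nanmean((expected - actual) ** 2) / jnp.nanvar(actual)


def calc_mape(expected, actual):
  # Mean absolute percentage error, ignoring NaN entries.
  return jnp.nanmean(jnp.abs((actual - expected) / actual))


def calc_weighted_mape(expected, actual):
  # wMAPE: total absolute error normalized by total actual outcome.
  return jnp.nansum(jnp.abs(actual - expected)) / jnp.nansum(actual)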
@@ -1399,8 +1423,14 @@ def filter_and_aggregate_geos_and_times(
             "`selected_geos` must match the geo dimension names from "
             "meridian.InputData."
         )
-      geo_mask = [x in selected_geos for x in mmm.input_data.geo]
-      tensor = backend.boolean_mask(tensor, geo_mask, axis=geo_dim)
+      geo_indices = [
+          i for i, x in enumerate(mmm.input_data.geo) if x in selected_geos
+      ]
+      tensor = backend.gather(
+          tensor,
+          backend.to_tensor(geo_indices, dtype=backend.int32),
+          axis=geo_dim,
+      )
 
     if selected_times is not None:
       _validate_selected_times(
@@ -1411,10 +1441,21 @@ def filter_and_aggregate_geos_and_times(
           comparison_arg_name="`tensor`",
       )
       if _is_str_list(selected_times):
-        time_mask = [x in selected_times for x in mmm.input_data.time]
-        tensor = backend.boolean_mask(tensor, time_mask, axis=time_dim)
+        time_indices = [
+            i for i, x in enumerate(mmm.input_data.time) if x in selected_times
+        ]
+        tensor = backend.gather(
+            tensor,
+            backend.to_tensor(time_indices, dtype=backend.int32),
+            axis=time_dim,
+        )
       elif _is_bool_list(selected_times):
-        tensor = backend.boolean_mask(tensor, selected_times, axis=time_dim)
+        time_indices = [i for i, x in enumerate(selected_times) if x]
+        tensor = backend.gather(
+            tensor,
+            backend.to_tensor(time_indices, dtype=backend.int32),
+            axis=time_dim,
+        )
 
     tensor_dims = "...gt" + "m" * has_media_dim
     output_dims = (
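Switching from `boolean_mask` to `gather` matters under JIT compilation: a boolean mask yields a data-dependent output shape, whereas integer indices computed in Python keep the filtered shape known at trace time. A small illustration of the idea with jax.jit (the names are illustrative, not Meridian's):

import jax
import jax.numpy as jnp

geos = ["geo_a", "geo_b", "geo_c"]   # labels stay in Python, outside the trace
selected = {"geo_a", "geo_c"}
geo_indices = jnp.asarray(
    [i for i, g in enumerate(geos) if g in selected], dtype=jnp.int32
)


@jax.jit
def filter_geos(tensor, indices):
  # Integer take: the output shape follows indices.shape, known at trace time.
  return jnp.take(tensor, indices, axis=0)


x = jnp.arange(12.0).reshape(3, 4)
print(filter_geos(x, geo_indices))   # keeps rows 0 and 2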
@@ -1730,7 +1771,17 @@ def _inverse_outcome(
       return kpi
     return backend.einsum("gt,...gtm->...gtm", revenue_per_kpi, kpi)
 
-  @backend.function(jit_compile=True)
+  @backend.function(
+      jit_compile=True,
+      static_argnames=[
+          "inverse_transform_outcome",
+          "use_kpi",
+          "selected_geos",
+          "selected_times",
+          "aggregate_geos",
+          "aggregate_times",
+      ],
+  )
   def _incremental_outcome_impl(
       self,
       data_tensors: DataTensors,
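Passing `static_argnames` tells the JIT to treat those arguments as compile-time constants rather than traced values, recompiling once per distinct combination. Assuming `backend.function` forwards these options to jax.jit under the JAX backend, the behaviour looks roughly like this:

from functools import partial

import jax
import jax.numpy as jnp


# `aggregate` changes the structure of the computation, so it is marked static;
# each distinct value triggers one (cached) compilation.
@partial(jax.jit, static_argnames=["aggregate"])
def summarize(x, aggregate):
  return jnp.sum(x) if aggregate else x


x = jnp.arange(4.0)
print(summarize(x, aggregate=True))    # scalar sum
print(summarize(x, aggregate=False))   # vector passed through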
@@ -2142,8 +2193,12 @@ def incremental_outcome(
     )
     incremental_outcome_temps = [None] * len(batch_starting_indices)
     dim_kwargs = {
-        "selected_geos": selected_geos,
-        "selected_times": selected_times,
+        "selected_geos": (
+            tuple(selected_geos) if selected_geos is not None else None
+        ),
+        "selected_times": (
+            tuple(selected_times) if selected_times is not None else None
+        ),
         "aggregate_geos": aggregate_geos,
         "aggregate_times": aggregate_times,
     }
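Wrapping `selected_geos` and `selected_times` in tuples is what makes them usable as static arguments: static values become part of the JIT compilation cache key, so they must be hashable, which lists are not. A two-line check of that constraint:

assert hash(("geo_a", "geo_b")) is not None   # tuples hash, so they can key the JIT cache
try:
  hash(["geo_a", "geo_b"])
except TypeError:
  print("lists are unhashable and would be rejected as static arguments")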
@@ -3703,9 +3758,11 @@ def optimal_freq(
     )
 
     optimal_frequency = [freq_grid[i] for i in optimal_freq_idx]
-    optimal_frequency_tensor = backend.to_tensor(
-        backend.ones_like(filled_data.rf_impressions) * optimal_frequency,
-        backend.float32,
+    optimal_frequency_values = backend.to_tensor(
+        optimal_frequency, dtype=backend.float32
+    )
+    optimal_frequency_tensor = (
+        backend.ones_like(filled_data.rf_impressions) * optimal_frequency_values
     )
     optimal_reach = filled_data.rf_impressions / optimal_frequency_tensor
 
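The `optimal_freq` change converts the per-channel frequency list to a float32 tensor once and lets broadcasting expand it against the impressions shape, instead of relying on an implicit list-times-tensor conversion. A sketch of the same broadcast with illustrative shapes:

import jax.numpy as jnp

rf_impressions = jnp.ones((2, 5, 3))        # (geo, time, rf_channel), illustrative
optimal_frequency = [1.5, 2.0, 2.5]         # one optimal frequency per RF channel

freq_values = jnp.asarray(optimal_frequency, dtype=jnp.float32)
# Broadcasting expands the trailing channel axis across the full tensor shape.
freq_tensor = jnp.ones_like(rf_impressions) * freq_values
optimal_reach = rf_impressions / freq_tensor
print(optimal_reach.shape)                  # (2, 5, 3)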
@@ -3997,10 +4054,11 @@ def get_rhat(self) -> Mapping[str, backend.Tensor]:
           "sample_posterior() must be called prior to calling this method."
       )
 
-    def _transpose_first_two_dims(x: backend.Tensor) -> backend.Tensor:
-      n_dim = len(x.shape)
+    def _transpose_first_two_dims(x: Any) -> backend.Tensor:
+      x_tensor = backend.to_tensor(x)
+      n_dim = len(x_tensor.shape)
       perm = [1, 0] + list(range(2, n_dim))
-      return backend.transpose(x, perm)
+      return backend.transpose(x_tensor, perm)
 
     rhat = backend.mcmc.potential_scale_reduction({
         k: _transpose_first_two_dims(v)
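The updated helper accepts anything convertible to a tensor and swaps the first two axes before handing the samples to the potential-scale-reduction diagnostic. A standalone version of the same permutation (shapes illustrative):

import jax.numpy as jnp


def transpose_first_two_dims(x):
  x = jnp.asarray(x)
  # Swap the first two axes (e.g. chain/draw ordering) and keep the rest as-is.
  perm = [1, 0] + list(range(2, x.ndim))
  return jnp.transpose(x, perm)


samples = jnp.zeros((4, 1000, 3))   # (chains, draws, parameter), illustrative
print(transpose_first_two_dims(samples).shape)   # (1000, 4, 3)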

0 commit comments
