Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

## [Unreleased]

* Refactor the TensorFlow RNG handler to use stateless seed generation.
* Add `selected_geos` arg to the optimizer.
* Add `selected_geos` arg to `get_aggregated_spend`.
* Fix bug in `optimize()` when using `new_data` with `start_date` and `end_date`
Expand Down
108 changes: 83 additions & 25 deletions meridian/analysis/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Methods to compute analysis metrics of the model and the data."""

from collections.abc import Mapping, Sequence
import dataclasses
import itertools
import numbers
from typing import Any, Optional
Expand Down Expand Up @@ -53,6 +54,7 @@ def _validate_non_media_baseline_values_numbers(


# TODO: Refactor the related unit tests to be under DataTensors.
@dataclasses.dataclass
class DataTensors(backend.ExtensionType):
"""Container for data variable arguments of Analyzer methods.

Expand Down Expand Up @@ -175,12 +177,31 @@ def __init__(
else None
)
self.time = (
backend.to_tensor(time, dtype="string") if time is not None else None
backend.to_tensor(time, dtype=backend.string)
if time is not None
else None
)

def __validate__(self):
self._validate_n_dims()

def __eq__(self, other: Any) -> bool:
"""Provides safe equality comparison for mixed tensor/non-tensor fields."""
if type(self) is not type(other):
return NotImplemented
for field in dataclasses.fields(self):
a = getattr(self, field.name)
b = getattr(other, field.name)
if a is None and b is None:
continue
if a is None or b is None:
return False
try:
if not bool(np.all(backend.to_tensor(backend.equal(a, b)))):
return False
except (ValueError, TypeError):
if a != b:
return False
return True

def total_spend(self) -> backend.Tensor | None:
"""Returns the total spend tensor.

Expand Down Expand Up @@ -216,7 +237,7 @@ def get_modified_times(self, meridian: model.Meridian) -> int | None:
of the corresponding tensor in the `meridian` object. If all time
dimensions are the same, returns `None`.
"""
for field in self._tf_extension_type_fields():
for field in dataclasses.fields(self):
new_tensor = getattr(self, field.name)
if field.name == constants.RF_IMPRESSIONS:
old_tensor = getattr(meridian.rf_tensors, field.name)
Expand Down Expand Up @@ -282,7 +303,7 @@ def validate_and_fill_missing_data(

def _validate_n_dims(self):
"""Raises an error if the tensors have the wrong number of dimensions."""
for field in self._tf_extension_type_fields():
for field in dataclasses.fields(self):
tensor = getattr(self, field.name)
if tensor is None:
continue
Expand Down Expand Up @@ -315,7 +336,7 @@ def _validate_correct_variables_filled(
Warning: If an attribute exists in the `DataTensors` object that is not in
the `required_variables` list, it will be ignored.
"""
for field in self._tf_extension_type_fields():
for field in dataclasses.fields(self):
tensor = getattr(self, field.name)
if tensor is None:
continue
Expand Down Expand Up @@ -468,7 +489,7 @@ def _fill_default_values(
) -> Self:
"""Fills default values and returns a new DataTensors object."""
output = {}
for field in self._tf_extension_type_fields():
for field in dataclasses.fields(self):
var_name = field.name
if var_name not in required_fields:
continue
Expand All @@ -489,7 +510,7 @@ def _fill_default_values(
old_tensor = meridian.revenue_per_kpi
elif var_name == constants.TIME:
old_tensor = backend.to_tensor(
meridian.input_data.time.values.tolist(), dtype="string"
meridian.input_data.time.values.tolist(), dtype=backend.string
)
else:
continue
Expand All @@ -500,6 +521,7 @@ def _fill_default_values(
return DataTensors(**output)


@dataclasses.dataclass
class DistributionTensors(backend.ExtensionType):
"""Container for parameters distributions arguments of Analyzer methods."""

Expand Down Expand Up @@ -583,17 +605,19 @@ def _transformed_new_or_scaled(

def _calc_rsquared(expected, actual):
"""Calculates r-squared between actual and expected outcome."""
return 1 - np.nanmean((expected - actual) ** 2) / np.nanvar(actual)
return 1 - backend.nanmean((expected - actual) ** 2) / backend.nanvar(actual)


def _calc_mape(expected, actual):
"""Calculates MAPE between actual and expected outcome."""
return np.nanmean(np.abs((actual - expected) / actual))
return backend.nanmean(backend.absolute((actual - expected) / actual))


def _calc_weighted_mape(expected, actual):
"""Calculates wMAPE between actual and expected outcome (weighted by actual)."""
return np.nansum(np.abs(actual - expected)) / np.nansum(actual)
return backend.nansum(backend.absolute(actual - expected)) / backend.nansum(
actual
)


def _warn_if_geo_arg_in_kwargs(**kwargs):
Expand Down Expand Up @@ -1399,8 +1423,14 @@ def filter_and_aggregate_geos_and_times(
"`selected_geos` must match the geo dimension names from "
"meridian.InputData."
)
geo_mask = [x in selected_geos for x in mmm.input_data.geo]
tensor = backend.boolean_mask(tensor, geo_mask, axis=geo_dim)
geo_indices = [
i for i, x in enumerate(mmm.input_data.geo) if x in selected_geos
]
tensor = backend.gather(
tensor,
backend.to_tensor(geo_indices, dtype=backend.int32),
axis=geo_dim,
)

if selected_times is not None:
_validate_selected_times(
Expand All @@ -1411,10 +1441,21 @@ def filter_and_aggregate_geos_and_times(
comparison_arg_name="`tensor`",
)
if _is_str_list(selected_times):
time_mask = [x in selected_times for x in mmm.input_data.time]
tensor = backend.boolean_mask(tensor, time_mask, axis=time_dim)
time_indices = [
i for i, x in enumerate(mmm.input_data.time) if x in selected_times
]
tensor = backend.gather(
tensor,
backend.to_tensor(time_indices, dtype=backend.int32),
axis=time_dim,
)
elif _is_bool_list(selected_times):
tensor = backend.boolean_mask(tensor, selected_times, axis=time_dim)
time_indices = [i for i, x in enumerate(selected_times) if x]
tensor = backend.gather(
tensor,
backend.to_tensor(time_indices, dtype=backend.int32),
axis=time_dim,
)

tensor_dims = "...gt" + "m" * has_media_dim
output_dims = (
Expand Down Expand Up @@ -1730,7 +1771,17 @@ def _inverse_outcome(
return kpi
return backend.einsum("gt,...gtm->...gtm", revenue_per_kpi, kpi)

@backend.function(jit_compile=True)
@backend.function(
jit_compile=True,
static_argnames=[
"inverse_transform_outcome",
"use_kpi",
"selected_geos",
"selected_times",
"aggregate_geos",
"aggregate_times",
],
)
def _incremental_outcome_impl(
self,
data_tensors: DataTensors,
Expand Down Expand Up @@ -2142,8 +2193,12 @@ def incremental_outcome(
)
incremental_outcome_temps = [None] * len(batch_starting_indices)
dim_kwargs = {
"selected_geos": selected_geos,
"selected_times": selected_times,
"selected_geos": (
tuple(selected_geos) if selected_geos is not None else None
),
"selected_times": (
tuple(selected_times) if selected_times is not None else None
),
"aggregate_geos": aggregate_geos,
"aggregate_times": aggregate_times,
}
Expand Down Expand Up @@ -3703,9 +3758,11 @@ def optimal_freq(
)

optimal_frequency = [freq_grid[i] for i in optimal_freq_idx]
optimal_frequency_tensor = backend.to_tensor(
backend.ones_like(filled_data.rf_impressions) * optimal_frequency,
backend.float32,
optimal_frequency_values = backend.to_tensor(
optimal_frequency, dtype=backend.float32
)
optimal_frequency_tensor = (
backend.ones_like(filled_data.rf_impressions) * optimal_frequency_values
)
optimal_reach = filled_data.rf_impressions / optimal_frequency_tensor

Expand Down Expand Up @@ -3997,10 +4054,11 @@ def get_rhat(self) -> Mapping[str, backend.Tensor]:
"sample_posterior() must be called prior to calling this method."
)

def _transpose_first_two_dims(x: backend.Tensor) -> backend.Tensor:
n_dim = len(x.shape)
def _transpose_first_two_dims(x: Any) -> backend.Tensor:
x_tensor = backend.to_tensor(x)
n_dim = len(x_tensor.shape)
perm = [1, 0] + list(range(2, n_dim))
return backend.transpose(x, perm)
return backend.transpose(x_tensor, perm)

rhat = backend.mcmc.potential_scale_reduction({
k: _transpose_first_two_dims(v)
Expand Down
Loading