Skip to content

Commit aeeef20

Browse files
lukmazThe Meridian Authors
authored andcommitted
Total treatment contribution prior.
PiperOrigin-RevId: 767234149
1 parent 4624447 commit aeeef20

11 files changed

+501
-366
lines changed

demo/RF_Data_Simulation_for_Meridian.ipynb

Lines changed: 116 additions & 116 deletions
Large diffs are not rendered by default.

meridian/constants.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,6 @@
228228
CHAIN = 'chain'
229229
DRAW = 'draw'
230230
KNOTS = 'knots'
231-
SIGMA_DIM = 'sigma_dim'
232231

233232

234233
# Model parameters.
@@ -244,6 +243,16 @@
244243
CONTRIBUTION_OM = 'contribution_om'
245244
CONTRIBUTION_ORF = 'contribution_orf'
246245
CONTRIBUTION_N = 'contribution_n'
246+
TOTAL_TREATMENT_CONTRIBUTION = 'total_treatment_contribution'
247+
TOTAL_TREATMENT_ALLOCATION_CONCENTRATION = (
248+
'total_treatment_allocation_concentration'
249+
)
250+
TOTAL_TREATMENT_ALLOCATION = 'total_treatment_allocation'
251+
TOTAL_TREATMENT_ALLOCATION_M = 'total_treatment_allocation_m'
252+
TOTAL_TREATMENT_ALLOCATION_RF = 'total_treatment_allocation_rf'
253+
TOTAL_TREATMENT_ALLOCATION_OM = 'total_treatment_allocation_om'
254+
TOTAL_TREATMENT_ALLOCATION_ORF = 'total_treatment_allocation_orf'
255+
TOTAL_TREATMENT_ALLOCATION_N = 'total_treatment_allocation_n'
247256
GAMMA_C = 'gamma_c'
248257
GAMMA_N = 'gamma_n'
249258
XI_C = 'xi_c'
@@ -317,6 +326,7 @@
317326
ROI_M,
318327
MROI_M,
319328
CONTRIBUTION_M,
329+
TOTAL_TREATMENT_ALLOCATION_M,
320330
BETA_M,
321331
ETA_M,
322332
ALPHA_M,
@@ -327,6 +337,7 @@
327337
ROI_RF,
328338
MROI_RF,
329339
CONTRIBUTION_RF,
340+
TOTAL_TREATMENT_ALLOCATION_RF,
330341
BETA_RF,
331342
ETA_RF,
332343
ALPHA_RF,
@@ -335,6 +346,7 @@
335346
)
336347
ORGANIC_MEDIA_PARAMETERS = (
337348
CONTRIBUTION_OM,
349+
TOTAL_TREATMENT_ALLOCATION_OM,
338350
BETA_OM,
339351
ETA_OM,
340352
ALPHA_OM,
@@ -343,6 +355,7 @@
343355
)
344356
ORGANIC_RF_PARAMETERS = (
345357
CONTRIBUTION_ORF,
358+
TOTAL_TREATMENT_ALLOCATION_ORF,
346359
BETA_ORF,
347360
ETA_ORF,
348361
ALPHA_ORF,
@@ -351,6 +364,7 @@
351364
)
352365
NON_MEDIA_PARAMETERS = (
353366
CONTRIBUTION_N,
367+
TOTAL_TREATMENT_ALLOCATION_N,
354368
GAMMA_N,
355369
XI_N,
356370
)
@@ -370,6 +384,10 @@
370384
GEO_RF_PARAMETERS = (BETA_GRF,)
371385
GEO_CONTROL_PARAMETERS = (GAMMA_GC,)
372386
GEO_NON_MEDIA_PARAMETERS = (GAMMA_GN,)
387+
TOTAL_TREATMENT_CONTRIBUTION_PARAMETERS = (
388+
TOTAL_TREATMENT_CONTRIBUTION,
389+
TOTAL_TREATMENT_ALLOCATION_CONCENTRATION,
390+
)
373391

374392
ALL_PRIOR_DISTRIBUTION_PARAMETERS = (
375393
*KNOTS_PARAMETERS,
@@ -382,6 +400,7 @@
382400
*SIGMA_PARAMETERS,
383401
TAU_G_EXCL_BASELINE,
384402
*TIME_PARAMETERS,
403+
*TOTAL_TREATMENT_CONTRIBUTION_PARAMETERS,
385404
)
386405

387406
UNSAVED_PARAMETERS = (
@@ -392,6 +411,7 @@
392411
GAMMA_GC_DEV,
393412
GAMMA_GN_DEV,
394413
TAU_G_EXCL_BASELINE, # Used to derive TAU_G.
414+
TOTAL_TREATMENT_ALLOCATION,
395415
)
396416
IGNORED_PRIORS_MEDIA = immutabledict.immutabledict({
397417
TREATMENT_PRIOR_TYPE_ROI: (
@@ -462,6 +482,8 @@
462482
BETA_GORF: (GEO, ORGANIC_RF_CHANNEL),
463483
GAMMA_GC: (GEO, CONTROL_VARIABLE),
464484
GAMMA_GN: (GEO, NON_MEDIA_CHANNEL),
485+
TOTAL_TREATMENT_CONTRIBUTION: (),
486+
TOTAL_TREATMENT_ALLOCATION_CONCENTRATION: (),
465487
}
466488
| {param: (CONTROL_VARIABLE,) for param in CONTROL_PARAMETERS}
467489
| {param: (NON_MEDIA_CHANNEL,) for param in NON_MEDIA_PARAMETERS}
@@ -508,6 +530,8 @@
508530
P_MEAN = 0.4
509531
# Prior standard deviation proportion of KPI incremental to all media.
510532
P_SD = 0.2
533+
# Default prior mean proportion of contribution allocated to paid channels.
534+
TOTAL_TREATMENT_PAID_MEAN_ALLOCATION = 0.75
511535

512536

513537
# Model metrics.

meridian/model/model.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -295,10 +295,6 @@ def n_media_times(self) -> int:
295295
def is_national(self) -> bool:
296296
return self.n_geos == 1
297297

298-
@property
299-
def _sigma_shape(self) -> int:
300-
return len(self.input_data.geo) if self.unique_sigma_for_each_geo else 1
301-
302298
@functools.cached_property
303299
def knot_info(self) -> knots.KnotInfo:
304300
return knots.get_knot_info(
@@ -449,7 +445,7 @@ def prior_broadcast(self) -> prior_distribution.PriorDistribution:
449445
n_organic_rf_channels=self.n_organic_rf_channels,
450446
n_controls=self.n_controls,
451447
n_non_media_channels=self.n_non_media_channels,
452-
sigma_shape=self._sigma_shape,
448+
unique_sigma_for_each_geo=self.unique_sigma_for_each_geo,
453449
n_knots=self.knot_info.n_knots,
454450
is_national=self.is_national,
455451
set_total_media_contribution_prior=self._set_total_media_contribution_prior,
@@ -663,10 +659,6 @@ def _validate_injected_inference_data_group(
663659
self._validate_injected_inference_data_group_coord(
664660
inference_data, group, constants.TIME, self.n_times
665661
)
666-
if not self.model_spec.unique_sigma_for_each_geo:
667-
self._validate_injected_inference_data_group_coord(
668-
inference_data, group, constants.SIGMA_DIM, self._sigma_shape
669-
)
670662
self._validate_injected_inference_data_group_coord(
671663
inference_data,
672664
group,
@@ -1429,7 +1421,7 @@ def create_inference_data_dims(self) -> Mapping[str, Sequence[str]]:
14291421
if self.unique_sigma_for_each_geo:
14301422
inference_dims[constants.SIGMA] = [constants.GEO]
14311423
else:
1432-
inference_dims[constants.SIGMA] = [constants.SIGMA_DIM]
1424+
inference_dims[constants.SIGMA] = []
14331425

14341426
return {
14351427
param: [constants.CHAIN, constants.DRAW] + list(dims)

meridian/model/model_test.py

Lines changed: 4 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,7 +1071,7 @@ def test_broadcast_prior_distribution_is_called_in_meridian_init(self):
10711071
self.assertEqual(broad.batch_shape, (meridian.n_controls,))
10721072

10731073
# Validate sigma.
1074-
self.assertEqual(meridian.prior_broadcast.sigma.batch_shape, (1,))
1074+
self.assertEqual(meridian.prior_broadcast.sigma.batch_shape, ())
10751075

10761076
@parameterized.named_parameters(
10771077
dict(
@@ -1725,7 +1725,7 @@ def test_broadcast_prior_distribution_is_called_in_meridian_init(self):
17251725
self.assertEqual(broad.batch_shape, (meridian.n_non_media_channels,))
17261726

17271727
# Validate sigma.
1728-
self.assertEqual(meridian.prior_broadcast.sigma.batch_shape, (1,))
1728+
self.assertEqual(meridian.prior_broadcast.sigma.batch_shape, ())
17291729

17301730
def test_scaled_data_shape(self):
17311731
meridian = model.Meridian(input_data=self.input_data_non_media_and_organic)
@@ -2182,9 +2182,8 @@ def test_inference_data_non_paid_correct_dims(self):
21822182
dims = prior_dims[param]
21832183
self.assertEqual(len(dims), len(tensor.shape))
21842184
for dim, shape_dim in zip(dims, tensor.shape):
2185-
if dim != constants.SIGMA_DIM:
2186-
self.assertIn(dim, prior_coords)
2187-
self.assertLen(prior_coords[dim], shape_dim)
2185+
self.assertIn(dim, prior_coords)
2186+
self.assertLen(prior_coords[dim], shape_dim)
21882187

21892188
def test_validate_injected_inference_data_correct_shapes(self):
21902189
"""Checks validation passes with correct shapes."""
@@ -2330,155 +2329,6 @@ def test_validate_injected_inference_data_prior_incorrect_coordinates(
23302329
inference_data=inference_data,
23312330
)
23322331

2333-
@parameterized.named_parameters(
2334-
dict(
2335-
testcase_name="sigma_dims_unique_sigma",
2336-
coord=constants.GEO,
2337-
mismatched_priors={
2338-
constants.BETA_GOM: (
2339-
1,
2340-
input_data_samples._N_DRAWS,
2341-
input_data_samples._N_GEOS + 1,
2342-
input_data_samples._N_ORGANIC_MEDIA_CHANNELS,
2343-
),
2344-
constants.BETA_GORF: (
2345-
1,
2346-
input_data_samples._N_DRAWS,
2347-
input_data_samples._N_GEOS + 1,
2348-
input_data_samples._N_ORGANIC_RF_CHANNELS,
2349-
),
2350-
constants.GAMMA_GN: (
2351-
1,
2352-
input_data_samples._N_DRAWS,
2353-
input_data_samples._N_GEOS + 1,
2354-
input_data_samples._N_NON_MEDIA_CHANNELS,
2355-
),
2356-
constants.GAMMA_GC: (
2357-
1,
2358-
input_data_samples._N_DRAWS,
2359-
input_data_samples._N_GEOS + 1,
2360-
input_data_samples._N_CONTROLS,
2361-
),
2362-
constants.TAU_G: (
2363-
1,
2364-
input_data_samples._N_DRAWS,
2365-
input_data_samples._N_GEOS + 1,
2366-
input_data_samples._N_CONTROLS,
2367-
),
2368-
constants.TAU_G_EXCL_BASELINE: (
2369-
1,
2370-
input_data_samples._N_DRAWS,
2371-
input_data_samples._N_GEOS + 1,
2372-
),
2373-
constants.BETA_GM: (
2374-
1,
2375-
input_data_samples._N_DRAWS,
2376-
input_data_samples._N_GEOS + 1,
2377-
input_data_samples._N_MEDIA_CHANNELS,
2378-
),
2379-
constants.BETA_GRF: (
2380-
1,
2381-
input_data_samples._N_DRAWS,
2382-
input_data_samples._N_GEOS + 1,
2383-
input_data_samples._N_RF_CHANNELS,
2384-
),
2385-
constants.BETA_GOM_DEV: (
2386-
1,
2387-
input_data_samples._N_DRAWS,
2388-
input_data_samples._N_GEOS + 1,
2389-
input_data_samples._N_ORGANIC_MEDIA_CHANNELS,
2390-
),
2391-
constants.BETA_GORF_DEV: (
2392-
1,
2393-
input_data_samples._N_DRAWS,
2394-
input_data_samples._N_GEOS + 1,
2395-
input_data_samples._N_ORGANIC_RF_CHANNELS,
2396-
),
2397-
constants.GAMMA_GN_DEV: (
2398-
1,
2399-
input_data_samples._N_DRAWS,
2400-
input_data_samples._N_GEOS + 1,
2401-
input_data_samples._N_NON_MEDIA_CHANNELS,
2402-
),
2403-
constants.GAMMA_GC_DEV: (
2404-
1,
2405-
input_data_samples._N_DRAWS,
2406-
input_data_samples._N_GEOS + 1,
2407-
input_data_samples._N_CONTROLS,
2408-
),
2409-
constants.SIGMA: (
2410-
1,
2411-
input_data_samples._N_DRAWS,
2412-
input_data_samples._N_GEOS + 1,
2413-
),
2414-
},
2415-
mismatched_coord_size=input_data_samples._N_GEOS + 1,
2416-
expected_coord_size=input_data_samples._N_GEOS,
2417-
unique_sigma=True,
2418-
),
2419-
dict(
2420-
testcase_name="sigma_dims_not_unique_sigma",
2421-
coord=constants.SIGMA_DIM,
2422-
mismatched_priors={
2423-
constants.SIGMA: (
2424-
1,
2425-
input_data_samples._N_DRAWS,
2426-
2,
2427-
),
2428-
},
2429-
mismatched_coord_size=2,
2430-
expected_coord_size=1,
2431-
unique_sigma=False,
2432-
),
2433-
)
2434-
def test_validate_injected_inference_data_prior_incorrect_sigma_coordinates(
2435-
self,
2436-
coord,
2437-
mismatched_priors,
2438-
mismatched_coord_size,
2439-
expected_coord_size,
2440-
unique_sigma,
2441-
):
2442-
"""Checks validation fails with incorrect coordinates for sigma."""
2443-
model_spec = spec.ModelSpec(unique_sigma_for_each_geo=unique_sigma)
2444-
meridian = model.Meridian(
2445-
input_data=self.input_data_non_media_and_organic,
2446-
model_spec=model_spec,
2447-
)
2448-
prior_samples = meridian.prior_sampler_callable._sample_prior(self._N_DRAWS)
2449-
prior_coords = meridian.create_inference_data_coords(1, self._N_DRAWS)
2450-
prior_dims = meridian.create_inference_data_dims()
2451-
2452-
prior_samples = dict(prior_samples)
2453-
for param in mismatched_priors:
2454-
prior_samples[param] = tf.zeros(mismatched_priors[param])
2455-
prior_coords = dict(prior_coords)
2456-
prior_coords[coord] = np.arange(mismatched_coord_size)
2457-
if unique_sigma:
2458-
prior_coords[constants.GEO] = np.arange(mismatched_coord_size)
2459-
else:
2460-
prior_coords[constants.SIGMA_DIM] = np.arange(mismatched_coord_size)
2461-
2462-
inference_data = az.convert_to_inference_data(
2463-
prior_samples,
2464-
coords=prior_coords,
2465-
dims=prior_dims,
2466-
group=constants.PRIOR,
2467-
)
2468-
2469-
with self.assertRaisesRegex(
2470-
ValueError,
2471-
"Injected inference data prior has incorrect coordinate"
2472-
f" '{coord}': expected"
2473-
f" {expected_coord_size}, got"
2474-
f" {mismatched_coord_size}",
2475-
):
2476-
_ = model.Meridian(
2477-
input_data=self.input_data_non_media_and_organic,
2478-
model_spec=model_spec,
2479-
inference_data=inference_data,
2480-
)
2481-
24822332
def test_compute_non_media_treatments_baseline_wrong_baseline_values_shape_raises_exception(
24832333
self,
24842334
):

0 commit comments

Comments
 (0)