@@ -1071,7 +1071,7 @@ def test_broadcast_prior_distribution_is_called_in_meridian_init(self):
10711071 self .assertEqual (broad .batch_shape , (meridian .n_controls ,))
10721072
10731073 # Validate sigma.
1074- self .assertEqual (meridian .prior_broadcast .sigma .batch_shape , (1 , ))
1074+ self .assertEqual (meridian .prior_broadcast .sigma .batch_shape , ())
10751075
10761076 @parameterized .named_parameters (
10771077 dict (
@@ -1725,7 +1725,7 @@ def test_broadcast_prior_distribution_is_called_in_meridian_init(self):
17251725 self .assertEqual (broad .batch_shape , (meridian .n_non_media_channels ,))
17261726
17271727 # Validate sigma.
1728- self .assertEqual (meridian .prior_broadcast .sigma .batch_shape , (1 , ))
1728+ self .assertEqual (meridian .prior_broadcast .sigma .batch_shape , ())
17291729
17301730 def test_scaled_data_shape (self ):
17311731 meridian = model .Meridian (input_data = self .input_data_non_media_and_organic )
@@ -2182,9 +2182,8 @@ def test_inference_data_non_paid_correct_dims(self):
21822182 dims = prior_dims [param ]
21832183 self .assertEqual (len (dims ), len (tensor .shape ))
21842184 for dim , shape_dim in zip (dims , tensor .shape ):
2185- if dim != constants .SIGMA_DIM :
2186- self .assertIn (dim , prior_coords )
2187- self .assertLen (prior_coords [dim ], shape_dim )
2185+ self .assertIn (dim , prior_coords )
2186+ self .assertLen (prior_coords [dim ], shape_dim )
21882187
21892188 def test_validate_injected_inference_data_correct_shapes (self ):
21902189 """Checks validation passes with correct shapes."""
@@ -2330,155 +2329,6 @@ def test_validate_injected_inference_data_prior_incorrect_coordinates(
23302329 inference_data = inference_data ,
23312330 )
23322331
2333- @parameterized .named_parameters (
2334- dict (
2335- testcase_name = "sigma_dims_unique_sigma" ,
2336- coord = constants .GEO ,
2337- mismatched_priors = {
2338- constants .BETA_GOM : (
2339- 1 ,
2340- input_data_samples ._N_DRAWS ,
2341- input_data_samples ._N_GEOS + 1 ,
2342- input_data_samples ._N_ORGANIC_MEDIA_CHANNELS ,
2343- ),
2344- constants .BETA_GORF : (
2345- 1 ,
2346- input_data_samples ._N_DRAWS ,
2347- input_data_samples ._N_GEOS + 1 ,
2348- input_data_samples ._N_ORGANIC_RF_CHANNELS ,
2349- ),
2350- constants .GAMMA_GN : (
2351- 1 ,
2352- input_data_samples ._N_DRAWS ,
2353- input_data_samples ._N_GEOS + 1 ,
2354- input_data_samples ._N_NON_MEDIA_CHANNELS ,
2355- ),
2356- constants .GAMMA_GC : (
2357- 1 ,
2358- input_data_samples ._N_DRAWS ,
2359- input_data_samples ._N_GEOS + 1 ,
2360- input_data_samples ._N_CONTROLS ,
2361- ),
2362- constants .TAU_G : (
2363- 1 ,
2364- input_data_samples ._N_DRAWS ,
2365- input_data_samples ._N_GEOS + 1 ,
2366- input_data_samples ._N_CONTROLS ,
2367- ),
2368- constants .TAU_G_EXCL_BASELINE : (
2369- 1 ,
2370- input_data_samples ._N_DRAWS ,
2371- input_data_samples ._N_GEOS + 1 ,
2372- ),
2373- constants .BETA_GM : (
2374- 1 ,
2375- input_data_samples ._N_DRAWS ,
2376- input_data_samples ._N_GEOS + 1 ,
2377- input_data_samples ._N_MEDIA_CHANNELS ,
2378- ),
2379- constants .BETA_GRF : (
2380- 1 ,
2381- input_data_samples ._N_DRAWS ,
2382- input_data_samples ._N_GEOS + 1 ,
2383- input_data_samples ._N_RF_CHANNELS ,
2384- ),
2385- constants .BETA_GOM_DEV : (
2386- 1 ,
2387- input_data_samples ._N_DRAWS ,
2388- input_data_samples ._N_GEOS + 1 ,
2389- input_data_samples ._N_ORGANIC_MEDIA_CHANNELS ,
2390- ),
2391- constants .BETA_GORF_DEV : (
2392- 1 ,
2393- input_data_samples ._N_DRAWS ,
2394- input_data_samples ._N_GEOS + 1 ,
2395- input_data_samples ._N_ORGANIC_RF_CHANNELS ,
2396- ),
2397- constants .GAMMA_GN_DEV : (
2398- 1 ,
2399- input_data_samples ._N_DRAWS ,
2400- input_data_samples ._N_GEOS + 1 ,
2401- input_data_samples ._N_NON_MEDIA_CHANNELS ,
2402- ),
2403- constants .GAMMA_GC_DEV : (
2404- 1 ,
2405- input_data_samples ._N_DRAWS ,
2406- input_data_samples ._N_GEOS + 1 ,
2407- input_data_samples ._N_CONTROLS ,
2408- ),
2409- constants .SIGMA : (
2410- 1 ,
2411- input_data_samples ._N_DRAWS ,
2412- input_data_samples ._N_GEOS + 1 ,
2413- ),
2414- },
2415- mismatched_coord_size = input_data_samples ._N_GEOS + 1 ,
2416- expected_coord_size = input_data_samples ._N_GEOS ,
2417- unique_sigma = True ,
2418- ),
2419- dict (
2420- testcase_name = "sigma_dims_not_unique_sigma" ,
2421- coord = constants .SIGMA_DIM ,
2422- mismatched_priors = {
2423- constants .SIGMA : (
2424- 1 ,
2425- input_data_samples ._N_DRAWS ,
2426- 2 ,
2427- ),
2428- },
2429- mismatched_coord_size = 2 ,
2430- expected_coord_size = 1 ,
2431- unique_sigma = False ,
2432- ),
2433- )
2434- def test_validate_injected_inference_data_prior_incorrect_sigma_coordinates (
2435- self ,
2436- coord ,
2437- mismatched_priors ,
2438- mismatched_coord_size ,
2439- expected_coord_size ,
2440- unique_sigma ,
2441- ):
2442- """Checks validation fails with incorrect coordinates for sigma."""
2443- model_spec = spec .ModelSpec (unique_sigma_for_each_geo = unique_sigma )
2444- meridian = model .Meridian (
2445- input_data = self .input_data_non_media_and_organic ,
2446- model_spec = model_spec ,
2447- )
2448- prior_samples = meridian .prior_sampler_callable ._sample_prior (self ._N_DRAWS )
2449- prior_coords = meridian .create_inference_data_coords (1 , self ._N_DRAWS )
2450- prior_dims = meridian .create_inference_data_dims ()
2451-
2452- prior_samples = dict (prior_samples )
2453- for param in mismatched_priors :
2454- prior_samples [param ] = tf .zeros (mismatched_priors [param ])
2455- prior_coords = dict (prior_coords )
2456- prior_coords [coord ] = np .arange (mismatched_coord_size )
2457- if unique_sigma :
2458- prior_coords [constants .GEO ] = np .arange (mismatched_coord_size )
2459- else :
2460- prior_coords [constants .SIGMA_DIM ] = np .arange (mismatched_coord_size )
2461-
2462- inference_data = az .convert_to_inference_data (
2463- prior_samples ,
2464- coords = prior_coords ,
2465- dims = prior_dims ,
2466- group = constants .PRIOR ,
2467- )
2468-
2469- with self .assertRaisesRegex (
2470- ValueError ,
2471- "Injected inference data prior has incorrect coordinate"
2472- f" '{ coord } ': expected"
2473- f" { expected_coord_size } , got"
2474- f" { mismatched_coord_size } " ,
2475- ):
2476- _ = model .Meridian (
2477- input_data = self .input_data_non_media_and_organic ,
2478- model_spec = model_spec ,
2479- inference_data = inference_data ,
2480- )
2481-
24822332 def test_compute_non_media_treatments_baseline_wrong_baseline_values_shape_raises_exception (
24832333 self ,
24842334 ):
0 commit comments