1515"""Methods to compute analysis metrics of the model and the data.""" 
1616
1717from  collections .abc  import  Mapping , Sequence 
18+ import  dataclasses 
1819import  itertools 
1920import  numbers 
2021from  typing  import  Any , Optional 
@@ -53,6 +54,7 @@ def _validate_non_media_baseline_values_numbers(
5354
5455
5556# TODO: Refactor the related unit tests to be under DataTensors. 
57+ @dataclasses .dataclass  
5658class  DataTensors (backend .ExtensionType ):
5759  """Container for data variable arguments of Analyzer methods. 
5860
@@ -175,12 +177,31 @@ def __init__(
175177        else  None 
176178    )
177179    self .time  =  (
178-         backend .to_tensor (time , dtype = "string" ) if  time  is  not   None  else  None 
180+         backend .to_tensor (time , dtype = backend .string )
181+         if  time  is  not   None 
182+         else  None 
179183    )
180- 
181-   def  __validate__ (self ):
182184    self ._validate_n_dims ()
183185
186+   def  __eq__ (self , other : Any ) ->  bool :
187+     """Provides safe equality comparison for mixed tensor/non-tensor fields.""" 
188+     if  type (self ) is  not   type (other ):
189+       return  NotImplemented 
190+     for  field  in  dataclasses .fields (self ):
191+       a  =  getattr (self , field .name )
192+       b  =  getattr (other , field .name )
193+       if  a  is  None  and  b  is  None :
194+         continue 
195+       if  a  is  None  or  b  is  None :
196+         return  False 
197+       try :
198+         if  not  bool (np .all (backend .to_tensor (backend .equal (a , b )))):
199+           return  False 
200+       except  (ValueError , TypeError ):
201+         if  a  !=  b :
202+           return  False 
203+     return  True 
204+ 
184205  def  total_spend (self ) ->  backend .Tensor  |  None :
185206    """Returns the total spend tensor. 
186207
@@ -216,7 +237,7 @@ def get_modified_times(self, meridian: model.Meridian) -> int | None:
216237      of the corresponding tensor in the `meridian` object. If all time 
217238      dimensions are the same, returns `None`. 
218239    """ 
219-     for  field  in  self . _tf_extension_type_fields ( ):
240+     for  field  in  dataclasses . fields ( self ):
220241      new_tensor  =  getattr (self , field .name )
221242      if  field .name  ==  constants .RF_IMPRESSIONS :
222243        old_tensor  =  getattr (meridian .rf_tensors , field .name )
@@ -282,7 +303,7 @@ def validate_and_fill_missing_data(
282303
283304  def  _validate_n_dims (self ):
284305    """Raises an error if the tensors have the wrong number of dimensions.""" 
285-     for  field  in  self . _tf_extension_type_fields ( ):
306+     for  field  in  dataclasses . fields ( self ):
286307      tensor  =  getattr (self , field .name )
287308      if  tensor  is  None :
288309        continue 
@@ -315,7 +336,7 @@ def _validate_correct_variables_filled(
315336      Warning: If an attribute exists in the `DataTensors` object that is not in 
316337        the `required_variables` list, it will be ignored. 
317338    """ 
318-     for  field  in  self . _tf_extension_type_fields ( ):
339+     for  field  in  dataclasses . fields ( self ):
319340      tensor  =  getattr (self , field .name )
320341      if  tensor  is  None :
321342        continue 
@@ -468,7 +489,7 @@ def _fill_default_values(
468489  ) ->  Self :
469490    """Fills default values and returns a new DataTensors object.""" 
470491    output  =  {}
471-     for  field  in  self . _tf_extension_type_fields ( ):
492+     for  field  in  dataclasses . fields ( self ):
472493      var_name  =  field .name 
473494      if  var_name  not  in   required_fields :
474495        continue 
@@ -489,7 +510,7 @@ def _fill_default_values(
489510        old_tensor  =  meridian .revenue_per_kpi 
490511      elif  var_name  ==  constants .TIME :
491512        old_tensor  =  backend .to_tensor (
492-             meridian .input_data .time .values .tolist (), dtype = " string" 
513+             meridian .input_data .time .values .tolist (), dtype = backend . string 
493514        )
494515      else :
495516        continue 
@@ -500,6 +521,7 @@ def _fill_default_values(
500521    return  DataTensors (** output )
501522
502523
524+ @dataclasses .dataclass  
503525class  DistributionTensors (backend .ExtensionType ):
504526  """Container for parameters distributions arguments of Analyzer methods.""" 
505527
@@ -583,17 +605,19 @@ def _transformed_new_or_scaled(
583605
584606def  _calc_rsquared (expected , actual ):
585607  """Calculates r-squared between actual and expected outcome.""" 
586-   return  1  -  np .nanmean ((expected  -  actual ) **  2 ) /  np .nanvar (actual )
608+   return  1  -  backend .nanmean ((expected  -  actual ) **  2 ) /  backend .nanvar (actual )
587609
588610
589611def  _calc_mape (expected , actual ):
590612  """Calculates MAPE between actual and expected outcome.""" 
591-   return  np .nanmean (np . abs ((actual  -  expected ) /  actual ))
613+   return  backend .nanmean (backend . absolute ((actual  -  expected ) /  actual ))
592614
593615
594616def  _calc_weighted_mape (expected , actual ):
595617  """Calculates wMAPE between actual and expected outcome (weighted by actual).""" 
596-   return  np .nansum (np .abs (actual  -  expected )) /  np .nansum (actual )
618+   return  backend .nansum (backend .absolute (actual  -  expected )) /  backend .nansum (
619+       actual 
620+   )
597621
598622
599623def  _warn_if_geo_arg_in_kwargs (** kwargs ):
@@ -1399,8 +1423,14 @@ def filter_and_aggregate_geos_and_times(
13991423            "`selected_geos` must match the geo dimension names from " 
14001424            "meridian.InputData." 
14011425        )
1402-       geo_mask  =  [x  in  selected_geos  for  x  in  mmm .input_data .geo ]
1403-       tensor  =  backend .boolean_mask (tensor , geo_mask , axis = geo_dim )
1426+       geo_indices  =  [
1427+           i  for  i , x  in  enumerate (mmm .input_data .geo ) if  x  in  selected_geos 
1428+       ]
1429+       tensor  =  backend .gather (
1430+           tensor ,
1431+           backend .to_tensor (geo_indices , dtype = backend .int32 ),
1432+           axis = geo_dim ,
1433+       )
14041434
14051435    if  selected_times  is  not   None :
14061436      _validate_selected_times (
@@ -1411,10 +1441,21 @@ def filter_and_aggregate_geos_and_times(
14111441          comparison_arg_name = "`tensor`" ,
14121442      )
14131443      if  _is_str_list (selected_times ):
1414-         time_mask  =  [x  in  selected_times  for  x  in  mmm .input_data .time ]
1415-         tensor  =  backend .boolean_mask (tensor , time_mask , axis = time_dim )
1444+         time_indices  =  [
1445+             i  for  i , x  in  enumerate (mmm .input_data .time ) if  x  in  selected_times 
1446+         ]
1447+         tensor  =  backend .gather (
1448+             tensor ,
1449+             backend .to_tensor (time_indices , dtype = backend .int32 ),
1450+             axis = time_dim ,
1451+         )
14161452      elif  _is_bool_list (selected_times ):
1417-         tensor  =  backend .boolean_mask (tensor , selected_times , axis = time_dim )
1453+         time_indices  =  [i  for  i , x  in  enumerate (selected_times ) if  x ]
1454+         tensor  =  backend .gather (
1455+             tensor ,
1456+             backend .to_tensor (time_indices , dtype = backend .int32 ),
1457+             axis = time_dim ,
1458+         )
14181459
14191460    tensor_dims  =  "...gt"  +  "m"  *  has_media_dim 
14201461    output_dims  =  (
@@ -1730,7 +1771,17 @@ def _inverse_outcome(
17301771      return  kpi 
17311772    return  backend .einsum ("gt,...gtm->...gtm" , revenue_per_kpi , kpi )
17321773
1733-   @backend .function (jit_compile = True ) 
1774+   @backend .function ( 
1775+       jit_compile = True , 
1776+       static_argnames = [ 
1777+           "inverse_transform_outcome" , 
1778+           "use_kpi" , 
1779+           "selected_geos" , 
1780+           "selected_times" , 
1781+           "aggregate_geos" , 
1782+           "aggregate_times" , 
1783+       ], 
1784+   ) 
17341785  def  _incremental_outcome_impl (
17351786      self ,
17361787      data_tensors : DataTensors ,
@@ -2142,8 +2193,12 @@ def incremental_outcome(
21422193    )
21432194    incremental_outcome_temps  =  [None ] *  len (batch_starting_indices )
21442195    dim_kwargs  =  {
2145-         "selected_geos" : selected_geos ,
2146-         "selected_times" : selected_times ,
2196+         "selected_geos" : (
2197+             tuple (selected_geos ) if  selected_geos  is  not   None  else  None 
2198+         ),
2199+         "selected_times" : (
2200+             tuple (selected_times ) if  selected_times  is  not   None  else  None 
2201+         ),
21472202        "aggregate_geos" : aggregate_geos ,
21482203        "aggregate_times" : aggregate_times ,
21492204    }
@@ -3703,9 +3758,11 @@ def optimal_freq(
37033758    )
37043759
37053760    optimal_frequency  =  [freq_grid [i ] for  i  in  optimal_freq_idx ]
3706-     optimal_frequency_tensor  =  backend .to_tensor (
3707-         backend .ones_like (filled_data .rf_impressions ) *  optimal_frequency ,
3708-         backend .float32 ,
3761+     optimal_frequency_values  =  backend .to_tensor (
3762+         optimal_frequency , dtype = backend .float32 
3763+     )
3764+     optimal_frequency_tensor  =  (
3765+         backend .ones_like (filled_data .rf_impressions ) *  optimal_frequency_values 
37093766    )
37103767    optimal_reach  =  filled_data .rf_impressions  /  optimal_frequency_tensor 
37113768
@@ -3997,10 +4054,11 @@ def get_rhat(self) -> Mapping[str, backend.Tensor]:
39974054          "sample_posterior() must be called prior to calling this method." 
39984055      )
39994056
4000-     def  _transpose_first_two_dims (x : backend .Tensor ) ->  backend .Tensor :
4001-       n_dim  =  len (x .shape )
4057+     def  _transpose_first_two_dims (x : Any ) ->  backend .Tensor :
4058+       x_tensor  =  backend .to_tensor (x )
4059+       n_dim  =  len (x_tensor .shape )
40024060      perm  =  [1 , 0 ] +  list (range (2 , n_dim ))
4003-       return  backend .transpose (x , perm )
4061+       return  backend .transpose (x_tensor , perm )
40044062
40054063    rhat  =  backend .mcmc .potential_scale_reduction ({
40064064        k : _transpose_first_two_dims (v )
0 commit comments