@@ -37,6 +37,91 @@
 logger: logging.Logger = logging.getLogger(__name__)
 
 
+def _parse_forward_out(forward_output: object) -> Tensor:
+    """
+    A temp wrapper for global _run_forward util to force forward output
+    type assertion & conversion.
+    Remove after the strict logic is supported by all attr classes
+    """
+    if isinstance(forward_output, Tensor):
+        return forward_output
+
+    output_type = type(forward_output)
+    assert output_type is int or output_type is float, (
+        "the return of forward_func must be a tensor, int, or float,"
+        f" received: {forward_output}"
+    )
+
+    # using python built-in type as torch dtype
+    # int -> torch.int64, float -> torch.float64
+    # ref: https://github.com/pytorch/pytorch/pull/21215
+    return torch.tensor(forward_output, dtype=cast(dtype, output_type))
+
+
+def process_initial_eval(
+    initial_eval: Tensor,
+    inputs: TensorOrTupleOfTensorsGeneric,
+    use_weights: bool = False,
+) -> Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]:
+
+    initial_eval = _parse_forward_out(initial_eval)
+
+    # number of elements in the output of forward_func
+    n_outputs = initial_eval.numel() if isinstance(initial_eval, Tensor) else 1
+
+    # flatten eval outputs into 1D (n_outputs)
+    # add the leading dim for n_feature_perturbed
+    flattened_initial_eval = initial_eval.reshape(1, -1)
+
+    # Initialize attribution totals and counts
+    attrib_type = flattened_initial_eval.dtype
+
+    total_attrib = [
+        # attribute w.r.t each output element
+        torch.zeros(
+            (n_outputs,) + input.shape[1:],
+            dtype=attrib_type,
+            device=input.device,
+        )
+        for input in inputs
+    ]
+
+    # Weights are used in cases where ablations may be overlapping.
+    weights = []
+    if use_weights:
+        weights = [
+            torch.zeros((n_outputs,) + input.shape[1:], device=input.device).float()
+            for input in inputs
+        ]
+
+    return (
+        total_attrib,
+        weights,
+        initial_eval,
+        flattened_initial_eval,
+        n_outputs,
+        attrib_type,
+    )
+
+
+def format_result(
+    total_attrib: List[Tensor],
+    weights: List[Tensor],
+    is_inputs_tuple: bool,
+    use_weights: bool,
+) -> Union[Tensor, Tuple[Tensor, ...]]:
+    """Normalizes attributions by weights if enabled and formats output as single tensor or tuple."""
+    # Divide total attributions by counts and return formatted attributions
+    if use_weights:
+        attrib = tuple(
+            single_attrib.float() / weight
+            for single_attrib, weight in zip(total_attrib, weights)
+        )
+    else:
+        attrib = tuple(total_attrib)
+    return _format_output(is_inputs_tuple, attrib)
+
+
 class FeatureAblation(PerturbationAttribution):
     r"""
     A perturbation based approach to computing attribution, involving
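Side note on the dtype cast in `_parse_forward_out` above: PyTorch accepts Python built-in types wherever a `torch.dtype` is expected, mapping `int` to `torch.int64` and `float` to `torch.float64` (the PR linked in the code comment). A minimal sketch of the behavior the helper relies on:

    import torch

    # python built-in types double as torch dtypes:
    # int -> torch.int64, float -> torch.float64
    t_int = torch.tensor(3, dtype=int)
    t_float = torch.tensor(0.5, dtype=float)
    assert t_int.dtype is torch.int64
    assert t_float.dtype is torch.float64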
@@ -331,9 +416,8 @@ def attribute(
             flattened_initial_eval,
             n_outputs,
             attrib_type,
-        ) = self._process_initial_eval(
-            initial_eval,
-            formatted_inputs,
+        ) = process_initial_eval(
+            initial_eval, formatted_inputs, use_weights=self.use_weights
         )
 
         total_attrib, weights = self._attribute_with_cross_tensor_feature_masks(
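For reference, a toy call to the new module-level `process_initial_eval` (the tensors here are made up for illustration; in `attribute` the arguments come from the initial forward evaluation and the formatted inputs):

    import torch

    initial_eval = torch.tensor([0.1, 0.9])   # forward output with 2 elements
    inputs = (torch.rand(4, 3),)              # one input tensor, batch of 4

    total_attrib, weights, init_eval, flat_eval, n_outputs, attrib_type = (
        process_initial_eval(initial_eval, inputs, use_weights=False)
    )
    # flat_eval.shape == (1, 2) and n_outputs == 2
    # total_attrib[0].shape == (2, 3): one row of zeros per output element,
    # feature dims taken from inputs[0].shape[1:]
    # weights == [] since use_weights=False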
@@ -358,7 +442,9 @@ def attribute(
 
         return cast(
             TensorOrTupleOfTensorsGeneric,
-            self._generate_result(total_attrib, weights, is_inputs_tuple),
+            format_result(
+                total_attrib, weights, is_inputs_tuple, use_weights=self.use_weights
+            ),
         )
 
     def _attribute_with_cross_tensor_feature_masks(
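The `use_weights=True` branch of `format_result` amounts to an elementwise division of each attribution total by its accumulated weight; a small sketch with made-up values:

    import torch

    total_attrib = [torch.tensor([[2.0, 4.0]])]
    weights = [torch.tensor([[2.0, 2.0]])]

    # divides elementwise, then formats; with is_inputs_tuple=False a single
    # Tensor comes back rather than a 1-tuple
    out = format_result(total_attrib, weights, is_inputs_tuple=False, use_weights=True)
    # out == tensor([[1., 2.]])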
@@ -586,8 +672,8 @@ def _initial_eval_to_processed_initial_eval_fut(
                 "initial_eval_to_processed_initial_eval_fut: "
                 "initial_eval should be a Tensor"
             )
-            result = self._process_initial_eval(
-                initial_eval_processed, formatted_inputs
+            result = process_initial_eval(
+                initial_eval_processed, formatted_inputs, use_weights=self.use_weights
             )
 
         except FeatureAblationFutureError as e:
@@ -886,10 +972,8 @@ def _generate_async_result_cross_tensor(
         )
 
         result_fut = collect_all(accumulate_fut_list).then(
-            lambda x: self._generate_result(
-                total_attrib,
-                weights,
-                is_inputs_tuple,
+            lambda x: format_result(
+                total_attrib, weights, is_inputs_tuple, use_weights=self.use_weights
             )
         )
 
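`collect_all` in the hunk above is presumably `torch.futures.collect_all` (imported elsewhere in this file), so `format_result` only runs once every accumulation future has completed. A self-contained sketch of that chaining pattern, with placeholder futures:

    from torch.futures import Future, collect_all

    f1: Future = Future()
    f2: Future = Future()

    # .then() receives the completed future-of-futures; its return value
    # becomes the value of result_fut
    result_fut = collect_all([f1, f2]).then(lambda x: "all accumulation done")

    f1.set_result(1)
    f2.set_result(2)
    print(result_fut.wait())  # -> "all accumulation done"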
@@ -955,70 +1039,6 @@ def _eval_fut_to_ablated_out_fut_cross_tensor(
             ) from e
         return total_attrib, weights
 
-    def _parse_forward_out(self, forward_output: Tensor) -> Tensor:
-        """
-        A temp wrapper for global _run_forward util to force forward output
-        type assertion & conversion.
-        Remove after the strict logic is supported by all attr classes
-        """
-        if isinstance(forward_output, Tensor):
-            return forward_output
-
-        output_type = type(forward_output)
-        assert output_type is int or output_type is float, (
-            "the return of forward_func must be a tensor, int, or float,"
-            f" received: {forward_output}"
-        )
-
-        # using python built-in type as torch dtype
-        # int -> torch.int64, float -> torch.float64
-        # ref: https://github.com/pytorch/pytorch/pull/21215
-        return torch.tensor(forward_output, dtype=cast(dtype, output_type))
-
-    def _process_initial_eval(
-        self,
-        initial_eval: Tensor,
-        inputs: TensorOrTupleOfTensorsGeneric,
-    ) -> Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]:
-        initial_eval = self._parse_forward_out(initial_eval)
-
-        # number of elements in the output of forward_func
-        n_outputs = initial_eval.numel() if isinstance(initial_eval, Tensor) else 1
-
-        # flatten eval outputs into 1D (n_outputs)
-        # add the leading dim for n_feature_perturbed
-        flattened_initial_eval = initial_eval.reshape(1, -1)
-
-        # Initialize attribution totals and counts
-        attrib_type = flattened_initial_eval.dtype
-
-        total_attrib = [
-            # attribute w.r.t each output element
-            torch.zeros(
-                (n_outputs,) + input.shape[1:],
-                dtype=attrib_type,
-                device=input.device,
-            )
-            for input in inputs
-        ]
-
-        # Weights are used in cases where ablations may be overlapping.
-        weights = []
-        if self.use_weights:
-            weights = [
-                torch.zeros((n_outputs,) + input.shape[1:], device=input.device).float()
-                for input in inputs
-            ]
-
-        return (
-            total_attrib,
-            weights,
-            initial_eval,
-            flattened_initial_eval,
-            n_outputs,
-            attrib_type,
-        )
-
     def _process_ablated_out_full(
         self,
         modified_eval: Tensor,
@@ -1033,7 +1053,7 @@ def _process_ablated_out_full(
         attrib_type: dtype,
         perturbations_per_eval: int,
     ) -> Tuple[List[Tensor], List[Tensor]]:
-        modified_eval = self._parse_forward_out(modified_eval)
+        modified_eval = _parse_forward_out(modified_eval)
         # if perturbations_per_eval > 1, the output shape must grow with
         # input and not be aggregated
         current_batch_size = inputs[0].shape[0]
@@ -1086,19 +1106,3 @@ def _process_ablated_out_full(
             total_attrib[i] += (eval_diff * mask.to(attrib_type)).sum(dim=0)
 
         return total_attrib, weights
-
-    def _generate_result(
-        self,
-        total_attrib: List[Tensor],
-        weights: List[Tensor],
-        is_inputs_tuple: bool,
-    ) -> Union[Tensor, Tuple[Tensor, ...]]:
-        # Divide total attributions by counts and return formatted attributions
-        if self.use_weights:
-            attrib = tuple(
-                single_attrib.float() / weight
-                for single_attrib, weight in zip(total_attrib, weights)
-            )
-        else:
-            attrib = tuple(total_attrib)
-        return _format_output(is_inputs_tuple, attrib)
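The accumulation line kept as context in the last hunk, `total_attrib[i] += (eval_diff * mask.to(attrib_type)).sum(dim=0)`, spreads each output difference over the features ablated in that perturbation. A shape-level sketch with toy values (the shapes are illustrative, not the exact internal ones):

    import torch

    # 2 perturbations, n_outputs=1, 3 features
    eval_diff = torch.tensor([0.5, 1.0]).reshape(2, 1, 1)
    mask = torch.tensor([[[1, 1, 0]], [[0, 0, 1]]])  # features ablated per step
    total_attrib = torch.zeros(1, 3)

    # each feature accumulates the output change of every perturbation
    # that ablated it
    total_attrib += (eval_diff * mask.to(total_attrib.dtype)).sum(dim=0)
    # total_attrib -> tensor([[0.5, 0.5, 1.0]])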