
import logging
import math
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple, TypeVar, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+)

import torch
from captum._utils.common import (
logger: logging.Logger = logging.getLogger(__name__)


+def _parse_forward_out(forward_output: object) -> Tensor:
+    """
+    A temporary wrapper around the global _run_forward util that forces
+    type assertion & conversion of the forward output.
+    Remove once the strict logic is supported by all attribution classes.
+    """
+    if isinstance(forward_output, Tensor):
+        return forward_output
+
+    output_type = type(forward_output)
+    assert output_type is int or output_type is float, (
+        "the return of forward_func must be a tensor, int, or float,"
+        f" received: {forward_output}"
+    )
+
+    # using a Python built-in type as the torch dtype:
+    # int -> torch.int64, float -> torch.float64
+    # ref: https://github.com/pytorch/pytorch/pull/21215
+    return torch.tensor(forward_output, dtype=cast(dtype, output_type))
+
+
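For reference, a minimal usage sketch of the relocated helper; the sample values are illustrative, not from this diff:

import torch

_parse_forward_out(torch.ones(2))  # a Tensor passes through unchanged
_parse_forward_out(3)              # -> tensor(3), dtype torch.int64
_parse_forward_out(0.5)            # -> tensor(0.5000), dtype torch.float64
_parse_forward_out("x")            # -> AssertionError (not tensor/int/float)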
+def process_initial_eval(
+    initial_eval: Tensor,
+    inputs: Iterable[Tensor],
+    use_weights: bool = False,
+) -> Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]:
+
+    initial_eval = _parse_forward_out(initial_eval)
+
+    # number of elements in the output of forward_func
+    n_outputs = initial_eval.numel() if isinstance(initial_eval, Tensor) else 1
+
+    # flatten eval outputs into 1D (n_outputs)
+    # add the leading dim for n_feature_perturbed
+    flattened_initial_eval = initial_eval.reshape(1, -1)
+
+    # Initialize attribution totals and counts
+    attrib_type = flattened_initial_eval.dtype
+
+    total_attrib = [
+        # attribute w.r.t each output element
+        torch.zeros(
+            (n_outputs,) + input.shape[1:],
+            dtype=attrib_type,
+            device=input.device,
+        )
+        for input in inputs
+    ]
+
+    # Weights are used in cases where ablations may be overlapping.
+    weights = []
+    if use_weights:
+        weights = [
+            torch.zeros((n_outputs,) + input.shape[1:], device=input.device).float()
+            for input in inputs
+        ]
+
+    return (
+        total_attrib,
+        weights,
+        initial_eval,
+        flattened_initial_eval,
+        n_outputs,
+        attrib_type,
+    )
+
+
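A hedged usage sketch of process_initial_eval; the batch size, input shapes, and scalar-per-example output below are assumptions for illustration:

import torch

inputs = (torch.rand(4, 3), torch.rand(4, 5, 5))  # hypothetical formatted inputs
initial_eval = torch.rand(4)                      # one forward output per example

total_attrib, weights, init, flat, n_outputs, attrib_type = process_initial_eval(
    initial_eval, inputs, use_weights=True
)

assert n_outputs == 4
assert flat.shape == (1, 4)                 # flattened, with a leading perturb dim
assert total_attrib[0].shape == (4, 3)      # (n_outputs,) + input.shape[1:]
assert total_attrib[1].shape == (4, 5, 5)
assert len(weights) == 2                    # populated because use_weights=True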
+def format_result(
+    total_attrib: List[Tensor],
+    weights: List[Tensor],
+    is_inputs_tuple: bool,
+    use_weights: bool,
+) -> Union[Tensor, Tuple[Tensor, ...]]:
+    """
+    Normalizes attributions by weights if enabled and
+    formats output as a single tensor or a tuple.
+    """
+    # Divide total attributions by counts and return formatted attributions
+    if use_weights:
+        attrib = tuple(
+            single_attrib.float() / weight
+            for single_attrib, weight in zip(total_attrib, weights)
+        )
+    else:
+        attrib = tuple(total_attrib)
+    return _format_output(is_inputs_tuple, attrib)
+
+
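A small sketch of the weight normalization format_result performs; the toy values assume a single input tensor and one output row:

import torch

total = [torch.tensor([[2.0, 4.0]])]   # summed attributions
counts = [torch.tensor([[1.0, 2.0]])]  # how many ablations covered each feature
out = format_result(total, counts, is_inputs_tuple=False, use_weights=True)
# _format_output unwraps the tuple when is_inputs_tuple=False:
# out == tensor([[2., 2.]])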
class FeatureAblation(PerturbationAttribution):
    r"""
    A perturbation based approach to computing attribution, involving
@@ -331,9 +430,8 @@ def attribute(
            flattened_initial_eval,
            n_outputs,
            attrib_type,
-        ) = self._process_initial_eval(
-            initial_eval,
-            formatted_inputs,
+        ) = process_initial_eval(
+            initial_eval, formatted_inputs, use_weights=self.use_weights
        )

        total_attrib, weights = self._attribute_with_cross_tensor_feature_masks(
@@ -358,7 +456,9 @@ def attribute(

        return cast(
            TensorOrTupleOfTensorsGeneric,
-            self._generate_result(total_attrib, weights, is_inputs_tuple),
+            format_result(
+                total_attrib, weights, is_inputs_tuple, use_weights=self.use_weights
+            ),
        )

    def _attribute_with_cross_tensor_feature_masks(
@@ -586,8 +686,8 @@ def _initial_eval_to_processed_initial_eval_fut(
                "initial_eval_to_processed_initial_eval_fut: "
                "initial_eval should be a Tensor"
            )
-            result = self._process_initial_eval(
-                initial_eval_processed, formatted_inputs
+            result = process_initial_eval(
+                initial_eval_processed, formatted_inputs, use_weights=self.use_weights
            )

        except FeatureAblationFutureError as e:
@@ -886,10 +986,8 @@ def _generate_async_result_cross_tensor(
        )

        result_fut = collect_all(accumulate_fut_list).then(
-            lambda x: self._generate_result(
-                total_attrib,
-                weights,
-                is_inputs_tuple,
+            lambda x: format_result(
+                total_attrib, weights, is_inputs_tuple, use_weights=self.use_weights
            )
        )

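The async path chains torch.futures.collect_all with .then, so format_result runs only once every accumulation future has resolved. A minimal standalone sketch of that pattern, with toy futures standing in for the real accumulate_fut_list:

from torch.futures import Future, collect_all

f1: Future = Future()
f2: Future = Future()

# The callback fires only after both futures complete, mirroring the lambda above.
done = collect_all([f1, f2]).then(lambda _: "all accumulations finished")

f1.set_result(1)
f2.set_result(2)
print(done.wait())  # -> all accumulations finished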
@@ -955,70 +1053,6 @@ def _eval_fut_to_ablated_out_fut_cross_tensor(
            ) from e
        return total_attrib, weights

-    def _parse_forward_out(self, forward_output: Tensor) -> Tensor:
-        """
-        A temp wrapper for global _run_forward util to force forward output
-        type assertion & conversion.
-        Remove after the strict logic is supported by all attr classes
-        """
-        if isinstance(forward_output, Tensor):
-            return forward_output
-
-        output_type = type(forward_output)
-        assert output_type is int or output_type is float, (
-            "the return of forward_func must be a tensor, int, or float,"
-            f" received: {forward_output}"
-        )
-
-        # using python built-in type as torch dtype
-        # int -> torch.int64, float -> torch.float64
-        # ref: https://github.com/pytorch/pytorch/pull/21215
-        return torch.tensor(forward_output, dtype=cast(dtype, output_type))
-
-    def _process_initial_eval(
-        self,
-        initial_eval: Tensor,
-        inputs: TensorOrTupleOfTensorsGeneric,
-    ) -> Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]:
-        initial_eval = self._parse_forward_out(initial_eval)
-
-        # number of elements in the output of forward_func
-        n_outputs = initial_eval.numel() if isinstance(initial_eval, Tensor) else 1
-
-        # flatten eval outputs into 1D (n_outputs)
-        # add the leading dim for n_feature_perturbed
-        flattened_initial_eval = initial_eval.reshape(1, -1)
-
-        # Initialize attribution totals and counts
-        attrib_type = flattened_initial_eval.dtype
-
-        total_attrib = [
-            # attribute w.r.t each output element
-            torch.zeros(
-                (n_outputs,) + input.shape[1:],
-                dtype=attrib_type,
-                device=input.device,
-            )
-            for input in inputs
-        ]
-
-        # Weights are used in cases where ablations may be overlapping.
-        weights = []
-        if self.use_weights:
-            weights = [
-                torch.zeros((n_outputs,) + input.shape[1:], device=input.device).float()
-                for input in inputs
-            ]
-
-        return (
-            total_attrib,
-            weights,
-            initial_eval,
-            flattened_initial_eval,
-            n_outputs,
-            attrib_type,
-        )
-
    def _process_ablated_out_full(
        self,
        modified_eval: Tensor,
@@ -1033,7 +1067,7 @@ def _process_ablated_out_full(
        attrib_type: dtype,
        perturbations_per_eval: int,
    ) -> Tuple[List[Tensor], List[Tensor]]:
-        modified_eval = self._parse_forward_out(modified_eval)
+        modified_eval = _parse_forward_out(modified_eval)
        # if perturbations_per_eval > 1, the output shape must grow with
        # input and not be aggregated
        current_batch_size = inputs[0].shape[0]
@@ -1086,19 +1120,3 @@ def _process_ablated_out_full(
                total_attrib[i] += (eval_diff * mask.to(attrib_type)).sum(dim=0)

        return total_attrib, weights
-
-    def _generate_result(
-        self,
-        total_attrib: List[Tensor],
-        weights: List[Tensor],
-        is_inputs_tuple: bool,
-    ) -> Union[Tensor, Tuple[Tensor, ...]]:
-        # Divide total attributions by counts and return formatted attributions
-        if self.use_weights:
-            attrib = tuple(
-                single_attrib.float() / weight
-                for single_attrib, weight in zip(total_attrib, weights)
-            )
-        else:
-            attrib = tuple(total_attrib)
-        return _format_output(is_inputs_tuple, attrib)
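For context on the total_attrib[i] += (eval_diff * mask.to(attrib_type)).sum(dim=0) context line in the hunk above: multiplying the eval difference by the feature mask and summing over the leading perturbation dimension credits each ablated position with its share of the output change. A toy sketch with assumed shapes (2 perturbations, 1 output, 2 features):

import torch

eval_diff = torch.tensor([[[1.0, 1.0]], [[0.5, 0.5]]])  # (n_perturb, n_outputs, n_feat)
mask = torch.tensor([[[1.0, 0.0]], [[0.0, 1.0]]])       # feature ablated at each step
(eval_diff * mask).sum(dim=0)                           # -> tensor([[1.0000, 0.5000]])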