
Commit 9f48afb

cyrjano authored and meta-codesync[bot] committed
Move methods out of FeatureAblation and make them free methods. (#1663)
Summary:
Pull Request resolved: #1663

This diff moves methods out of the `FeatureAblation` class and makes them free methods. The changes include creating a new method `_parse_forward_out` to force forward output type assertion and conversion, and modifying the `add_one_back` module to use the new methods. The `attr/fb/add_one_back.py` and `attr/fb/within_group_utils.py` files have been modified to use the new methods, and the `attr/fb/test_within_group_utils.py` file has been updated accordingly.

Reviewed By: jjuncho

Differential Revision: D86785624

fbshipit-source-id: 5b2709d469cfffdff3b4cbeafe52fff5bff5b5e4
1 parent 4f74df8 commit 9f48afb
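
For orientation, this refactor does not change the public FeatureAblation.attribute API; only the internal helpers move to module level. A minimal sketch of an end-to-end call that exercises those helpers (the toy forward function, input shapes, and printed values are illustrative, not taken from the commit):

import torch
from captum.attr import FeatureAblation

def forward_func(x: torch.Tensor) -> torch.Tensor:
    # toy forward: one scalar output per example
    return x.sum(dim=1)

ablator = FeatureAblation(forward_func)
inputs = torch.rand(4, 3)
attributions = ablator.attribute(inputs)
print(attributions.shape)  # torch.Size([4, 3])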

File tree

3 files changed: +225 −95 lines


captum/attr/_core/feature_ablation.py

Lines changed: 110 additions & 92 deletions
@@ -4,7 +4,18 @@
 
 import logging
 import math
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple, TypeVar, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
 import torch
 from captum._utils.common import (
@@ -37,6 +48,94 @@
 logger: logging.Logger = logging.getLogger(__name__)
 
 
+def _parse_forward_out(forward_output: object) -> Tensor:
+    """
+    A temp wrapper for global _run_forward util to force forward output
+    type assertion & conversion.
+    Remove after the strict logic is supported by all attr classes
+    """
+    if isinstance(forward_output, Tensor):
+        return forward_output
+
+    output_type = type(forward_output)
+    assert output_type is int or output_type is float, (
+        "the return of forward_func must be a tensor, int, or float,"
+        f" received: {forward_output}"
+    )
+
+    # using python built-in type as torch dtype
+    # int -> torch.int64, float -> torch.float64
+    # ref: https://github.com/pytorch/pytorch/pull/21215
+    return torch.tensor(forward_output, dtype=cast(dtype, output_type))
+
+
+def process_initial_eval(
+    initial_eval: Tensor,
+    inputs: Iterable[Tensor],
+    use_weights: bool = False,
+) -> Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]:
+
+    initial_eval = _parse_forward_out(initial_eval)
+
+    # number of elements in the output of forward_func
+    n_outputs = initial_eval.numel() if isinstance(initial_eval, Tensor) else 1
+
+    # flatten eval outputs into 1D (n_outputs)
+    # add the leading dim for n_feature_perturbed
+    flattened_initial_eval = initial_eval.reshape(1, -1)
+
+    # Initialize attribution totals and counts
+    attrib_type = flattened_initial_eval.dtype
+
+    total_attrib = [
+        # attribute w.r.t each output element
+        torch.zeros(
+            (n_outputs,) + input.shape[1:],
+            dtype=attrib_type,
+            device=input.device,
+        )
+        for input in inputs
+    ]
+
+    # Weights are used in cases where ablations may be overlapping.
+    weights = []
+    if use_weights:
+        weights = [
+            torch.zeros((n_outputs,) + input.shape[1:], device=input.device).float()
+            for input in inputs
+        ]
+
+    return (
+        total_attrib,
+        weights,
+        initial_eval,
+        flattened_initial_eval,
+        n_outputs,
+        attrib_type,
+    )
+
+
+def format_result(
+    total_attrib: List[Tensor],
+    weights: List[Tensor],
+    is_inputs_tuple: bool,
+    use_weights: bool,
+) -> Union[Tensor, Tuple[Tensor, ...]]:
+    """
+    Normalizes attributions by weights if enabled and
+    formats output as single tensor or tuple.
+    """
+    # Divide total attributions by counts and return formatted attributions
+    if use_weights:
+        attrib = tuple(
+            single_attrib.float() / weight
+            for single_attrib, weight in zip(total_attrib, weights)
+        )
+    else:
+        attrib = tuple(total_attrib)
+    return _format_output(is_inputs_tuple, attrib)
+
+
 class FeatureAblation(PerturbationAttribution):
     r"""
     A perturbation based approach to computing attribution, involving
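
The three helpers added above can be exercised directly once they are module-level functions. A minimal sketch on toy tensors (assumes they are importable from captum.attr._core.feature_ablation after this change; the shapes and values are illustrative):

import torch
from captum.attr._core.feature_ablation import (
    _parse_forward_out,
    format_result,
    process_initial_eval,
)

# _parse_forward_out passes tensors through and wraps Python scalars in tensors
assert _parse_forward_out(torch.tensor([1.0, 2.0])).shape == (2,)
assert _parse_forward_out(3).dtype == torch.int64      # int -> torch.int64
assert _parse_forward_out(0.5).dtype == torch.float64  # float -> torch.float64

# process_initial_eval flattens the initial forward output and builds zeroed
# per-output accumulators for each input tensor
inputs = (torch.rand(4, 3), torch.rand(4, 5))
initial_eval = torch.rand(4)  # one output per example
total_attrib, weights, init, flat_init, n_outputs, attrib_type = process_initial_eval(
    initial_eval, inputs, use_weights=False
)
print(n_outputs, flat_init.shape)       # 4 torch.Size([1, 4])
print([t.shape for t in total_attrib])  # [torch.Size([4, 3]), torch.Size([4, 5])]

# format_result mirrors the input structure: a tuple when is_inputs_tuple is True
out = format_result(total_attrib, weights, is_inputs_tuple=True, use_weights=False)
print(isinstance(out, tuple))  # True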
@@ -331,9 +430,8 @@ def attribute(
             flattened_initial_eval,
             n_outputs,
             attrib_type,
-        ) = self._process_initial_eval(
-            initial_eval,
-            formatted_inputs,
+        ) = process_initial_eval(
+            initial_eval, formatted_inputs, use_weights=self.use_weights
         )
 
         total_attrib, weights = self._attribute_with_cross_tensor_feature_masks(
@@ -358,7 +456,9 @@ def attribute(
 
         return cast(
             TensorOrTupleOfTensorsGeneric,
-            self._generate_result(total_attrib, weights, is_inputs_tuple),
+            format_result(
+                total_attrib, weights, is_inputs_tuple, use_weights=self.use_weights
+            ),
         )
 
     def _attribute_with_cross_tensor_feature_masks(
@@ -586,8 +686,8 @@ def _initial_eval_to_processed_initial_eval_fut(
                     "initial_eval_to_processed_initial_eval_fut: "
                     "initial_eval should be a Tensor"
                 )
-            result = self._process_initial_eval(
-                initial_eval_processed, formatted_inputs
+            result = process_initial_eval(
+                initial_eval_processed, formatted_inputs, use_weights=self.use_weights
             )
 
         except FeatureAblationFutureError as e:
@@ -886,10 +986,8 @@ def _generate_async_result_cross_tensor(
             )
 
         result_fut = collect_all(accumulate_fut_list).then(
-            lambda x: self._generate_result(
-                total_attrib,
-                weights,
-                is_inputs_tuple,
+            lambda x: format_result(
+                total_attrib, weights, is_inputs_tuple, use_weights=self.use_weights
             )
         )
 
@@ -955,70 +1053,6 @@ def _eval_fut_to_ablated_out_fut_cross_tensor(
                 ) from e
         return total_attrib, weights
 
-    def _parse_forward_out(self, forward_output: Tensor) -> Tensor:
-        """
-        A temp wrapper for global _run_forward util to force forward output
-        type assertion & conversion.
-        Remove after the strict logic is supported by all attr classes
-        """
-        if isinstance(forward_output, Tensor):
-            return forward_output
-
-        output_type = type(forward_output)
-        assert output_type is int or output_type is float, (
-            "the return of forward_func must be a tensor, int, or float,"
-            f" received: {forward_output}"
-        )
-
-        # using python built-in type as torch dtype
-        # int -> torch.int64, float -> torch.float64
-        # ref: https://github.com/pytorch/pytorch/pull/21215
-        return torch.tensor(forward_output, dtype=cast(dtype, output_type))
-
-    def _process_initial_eval(
-        self,
-        initial_eval: Tensor,
-        inputs: TensorOrTupleOfTensorsGeneric,
-    ) -> Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]:
-        initial_eval = self._parse_forward_out(initial_eval)
-
-        # number of elements in the output of forward_func
-        n_outputs = initial_eval.numel() if isinstance(initial_eval, Tensor) else 1
-
-        # flatten eval outputs into 1D (n_outputs)
-        # add the leading dim for n_feature_perturbed
-        flattened_initial_eval = initial_eval.reshape(1, -1)
-
-        # Initialize attribution totals and counts
-        attrib_type = flattened_initial_eval.dtype
-
-        total_attrib = [
-            # attribute w.r.t each output element
-            torch.zeros(
-                (n_outputs,) + input.shape[1:],
-                dtype=attrib_type,
-                device=input.device,
-            )
-            for input in inputs
-        ]
-
-        # Weights are used in cases where ablations may be overlapping.
-        weights = []
-        if self.use_weights:
-            weights = [
-                torch.zeros((n_outputs,) + input.shape[1:], device=input.device).float()
-                for input in inputs
-            ]
-
-        return (
-            total_attrib,
-            weights,
-            initial_eval,
-            flattened_initial_eval,
-            n_outputs,
-            attrib_type,
-        )
-
     def _process_ablated_out_full(
         self,
         modified_eval: Tensor,
@@ -1033,7 +1067,7 @@ def _process_ablated_out_full(
         attrib_type: dtype,
         perturbations_per_eval: int,
     ) -> Tuple[List[Tensor], List[Tensor]]:
-        modified_eval = self._parse_forward_out(modified_eval)
+        modified_eval = _parse_forward_out(modified_eval)
         # if perturbations_per_eval > 1, the output shape must grow with
         # input and not be aggregated
         current_batch_size = inputs[0].shape[0]
@@ -1086,19 +1120,3 @@ def _process_ablated_out_full(
                 total_attrib[i] += (eval_diff * mask.to(attrib_type)).sum(dim=0)
 
         return total_attrib, weights
-
-    def _generate_result(
-        self,
-        total_attrib: List[Tensor],
-        weights: List[Tensor],
-        is_inputs_tuple: bool,
-    ) -> Union[Tensor, Tuple[Tensor, ...]]:
-        # Divide total attributions by counts and return formatted attributions
-        if self.use_weights:
-            attrib = tuple(
-                single_attrib.float() / weight
-                for single_attrib, weight in zip(total_attrib, weights)
-            )
-        else:
-            attrib = tuple(total_attrib)
-        return _format_output(is_inputs_tuple, attrib)
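
When use_weights is enabled (ablation groups may overlap), format_result divides the accumulated attributions by the accumulated weights, just as the removed _generate_result did with self.use_weights. A toy numeric sketch (values and import path are illustrative, as in the sketch above):

import torch
from captum.attr._core.feature_ablation import format_result

# each of the two features was counted twice across overlapping perturbations,
# so the raw totals are divided by their counts to get per-feature averages
total_attrib = [torch.tensor([[4.0, 6.0]])]
weights = [torch.tensor([[2.0, 2.0]])]

averaged = format_result(total_attrib, weights, is_inputs_tuple=False, use_weights=True)
print(averaged)  # tensor([[2., 3.]])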

captum/testing/helpers/basic_models.py

Lines changed: 0 additions & 1 deletion
@@ -581,7 +581,6 @@ def forward(
             self.relu(lin1_out)
         else:
             relu_out = self.relu(lin1_out)
-        # pyre-fixme [29]: `typing.Type[Future]` is not a function
         result = Future()
         lin2_out = self.linear2(relu_out)
         if multidim_output:
