Skip to content

Commit 92bcb51

Browse files
cyrjanofacebook-github-bot
authored andcommitted
Move methods out of FeatureAblation and make them free methods. (#1663)
Summary: This diff moves methods out of the `FeatureAblation` class and makes them free methods. The changes include creating a new method `_parse_forward_out` to force forward output type assertion and conversion, and modifying the `add_one_back` module to use the new methods. The `attr/fb/add_one_back.py` file has been modified to use the new methods. The `attr/fb/within_group_utils.py` file has also been modified to use the new methods. The `attr/fb/test_within_group_utils.py` file has been Differential Revision: D86785624
1 parent a1a64d0 commit 92bcb51

File tree

3 files changed

+210
-94
lines changed

3 files changed

+210
-94
lines changed

captum/attr/_core/feature_ablation.py

Lines changed: 95 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,91 @@
3737
logger: logging.Logger = logging.getLogger(__name__)
3838

3939

40+
def _parse_forward_out(forward_output: object) -> Tensor:
41+
"""
42+
A temp wrapper for global _run_forward util to force forward output
43+
type assertion & conversion.
44+
Remove after the strict logic is supported by all attr classes
45+
"""
46+
if isinstance(forward_output, Tensor):
47+
return forward_output
48+
49+
output_type = type(forward_output)
50+
assert output_type is int or output_type is float, (
51+
"the return of forward_func must be a tensor, int, or float,"
52+
f" received: {forward_output}"
53+
)
54+
55+
# using python built-in type as torch dtype
56+
# int -> torch.int64, float -> torch.float64
57+
# ref: https://github.com/pytorch/pytorch/pull/21215
58+
return torch.tensor(forward_output, dtype=cast(dtype, output_type))
59+
60+
61+
def process_initial_eval(
62+
initial_eval: Tensor,
63+
inputs: TensorOrTupleOfTensorsGeneric,
64+
use_weights: bool = False,
65+
) -> Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]:
66+
67+
initial_eval = _parse_forward_out(initial_eval)
68+
69+
# number of elements in the output of forward_func
70+
n_outputs = initial_eval.numel() if isinstance(initial_eval, Tensor) else 1
71+
72+
# flatten eval outputs into 1D (n_outputs)
73+
# add the leading dim for n_feature_perturbed
74+
flattened_initial_eval = initial_eval.reshape(1, -1)
75+
76+
# Initialize attribution totals and counts
77+
attrib_type = flattened_initial_eval.dtype
78+
79+
total_attrib = [
80+
# attribute w.r.t each output element
81+
torch.zeros(
82+
(n_outputs,) + input.shape[1:],
83+
dtype=attrib_type,
84+
device=input.device,
85+
)
86+
for input in inputs
87+
]
88+
89+
# Weights are used in cases where ablations may be overlapping.
90+
weights = []
91+
if use_weights:
92+
weights = [
93+
torch.zeros((n_outputs,) + input.shape[1:], device=input.device).float()
94+
for input in inputs
95+
]
96+
97+
return (
98+
total_attrib,
99+
weights,
100+
initial_eval,
101+
flattened_initial_eval,
102+
n_outputs,
103+
attrib_type,
104+
)
105+
106+
107+
def format_result(
108+
total_attrib: List[Tensor],
109+
weights: List[Tensor],
110+
is_inputs_tuple: bool,
111+
use_weights: bool,
112+
) -> Union[Tensor, Tuple[Tensor, ...]]:
113+
"""Normalizes attributions by weights if enabled and formats output as single tensor or tuple."""
114+
# Divide total attributions by counts and return formatted attributions
115+
if use_weights:
116+
attrib = tuple(
117+
single_attrib.float() / weight
118+
for single_attrib, weight in zip(total_attrib, weights)
119+
)
120+
else:
121+
attrib = tuple(total_attrib)
122+
return _format_output(is_inputs_tuple, attrib)
123+
124+
40125
class FeatureAblation(PerturbationAttribution):
41126
r"""
42127
A perturbation based approach to computing attribution, involving
@@ -331,9 +416,8 @@ def attribute(
331416
flattened_initial_eval,
332417
n_outputs,
333418
attrib_type,
334-
) = self._process_initial_eval(
335-
initial_eval,
336-
formatted_inputs,
419+
) = process_initial_eval(
420+
initial_eval, formatted_inputs, use_weights=self.use_weights
337421
)
338422

339423
total_attrib, weights = self._attribute_with_cross_tensor_feature_masks(
@@ -358,7 +442,9 @@ def attribute(
358442

359443
return cast(
360444
TensorOrTupleOfTensorsGeneric,
361-
self._generate_result(total_attrib, weights, is_inputs_tuple),
445+
format_result(
446+
total_attrib, weights, is_inputs_tuple, use_weights=self.use_weights
447+
),
362448
)
363449

364450
def _attribute_with_cross_tensor_feature_masks(
@@ -586,8 +672,8 @@ def _initial_eval_to_processed_initial_eval_fut(
586672
"initial_eval_to_processed_initial_eval_fut: "
587673
"initial_eval should be a Tensor"
588674
)
589-
result = self._process_initial_eval(
590-
initial_eval_processed, formatted_inputs
675+
result = process_initial_eval(
676+
initial_eval_processed, formatted_inputs, use_weights=self.use_weights
591677
)
592678

593679
except FeatureAblationFutureError as e:
@@ -886,10 +972,8 @@ def _generate_async_result_cross_tensor(
886972
)
887973

888974
result_fut = collect_all(accumulate_fut_list).then(
889-
lambda x: self._generate_result(
890-
total_attrib,
891-
weights,
892-
is_inputs_tuple,
975+
lambda x: format_result(
976+
total_attrib, weights, is_inputs_tuple, use_weights=self.use_weights
893977
)
894978
)
895979

@@ -955,70 +1039,6 @@ def _eval_fut_to_ablated_out_fut_cross_tensor(
9551039
) from e
9561040
return total_attrib, weights
9571041

958-
def _parse_forward_out(self, forward_output: Tensor) -> Tensor:
959-
"""
960-
A temp wrapper for global _run_forward util to force forward output
961-
type assertion & conversion.
962-
Remove after the strict logic is supported by all attr classes
963-
"""
964-
if isinstance(forward_output, Tensor):
965-
return forward_output
966-
967-
output_type = type(forward_output)
968-
assert output_type is int or output_type is float, (
969-
"the return of forward_func must be a tensor, int, or float,"
970-
f" received: {forward_output}"
971-
)
972-
973-
# using python built-in type as torch dtype
974-
# int -> torch.int64, float -> torch.float64
975-
# ref: https://github.com/pytorch/pytorch/pull/21215
976-
return torch.tensor(forward_output, dtype=cast(dtype, output_type))
977-
978-
def _process_initial_eval(
979-
self,
980-
initial_eval: Tensor,
981-
inputs: TensorOrTupleOfTensorsGeneric,
982-
) -> Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]:
983-
initial_eval = self._parse_forward_out(initial_eval)
984-
985-
# number of elements in the output of forward_func
986-
n_outputs = initial_eval.numel() if isinstance(initial_eval, Tensor) else 1
987-
988-
# flatten eval outputs into 1D (n_outputs)
989-
# add the leading dim for n_feature_perturbed
990-
flattened_initial_eval = initial_eval.reshape(1, -1)
991-
992-
# Initialize attribution totals and counts
993-
attrib_type = flattened_initial_eval.dtype
994-
995-
total_attrib = [
996-
# attribute w.r.t each output element
997-
torch.zeros(
998-
(n_outputs,) + input.shape[1:],
999-
dtype=attrib_type,
1000-
device=input.device,
1001-
)
1002-
for input in inputs
1003-
]
1004-
1005-
# Weights are used in cases where ablations may be overlapping.
1006-
weights = []
1007-
if self.use_weights:
1008-
weights = [
1009-
torch.zeros((n_outputs,) + input.shape[1:], device=input.device).float()
1010-
for input in inputs
1011-
]
1012-
1013-
return (
1014-
total_attrib,
1015-
weights,
1016-
initial_eval,
1017-
flattened_initial_eval,
1018-
n_outputs,
1019-
attrib_type,
1020-
)
1021-
10221042
def _process_ablated_out_full(
10231043
self,
10241044
modified_eval: Tensor,
@@ -1033,7 +1053,7 @@ def _process_ablated_out_full(
10331053
attrib_type: dtype,
10341054
perturbations_per_eval: int,
10351055
) -> Tuple[List[Tensor], List[Tensor]]:
1036-
modified_eval = self._parse_forward_out(modified_eval)
1056+
modified_eval = _parse_forward_out(modified_eval)
10371057
# if perturbations_per_eval > 1, the output shape must grow with
10381058
# input and not be aggregated
10391059
current_batch_size = inputs[0].shape[0]
@@ -1086,19 +1106,3 @@ def _process_ablated_out_full(
10861106
total_attrib[i] += (eval_diff * mask.to(attrib_type)).sum(dim=0)
10871107

10881108
return total_attrib, weights
1089-
1090-
def _generate_result(
1091-
self,
1092-
total_attrib: List[Tensor],
1093-
weights: List[Tensor],
1094-
is_inputs_tuple: bool,
1095-
) -> Union[Tensor, Tuple[Tensor, ...]]:
1096-
# Divide total attributions by counts and return formatted attributions
1097-
if self.use_weights:
1098-
attrib = tuple(
1099-
single_attrib.float() / weight
1100-
for single_attrib, weight in zip(total_attrib, weights)
1101-
)
1102-
else:
1103-
attrib = tuple(total_attrib)
1104-
return _format_output(is_inputs_tuple, attrib)

captum/testing/helpers/basic_models.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,6 @@ def forward(
581581
self.relu(lin1_out)
582582
else:
583583
relu_out = self.relu(lin1_out)
584-
# pyre-fixme [29]: `typing.Type[Future]` is not a function
585584
result = Future()
586585
lin2_out = self.linear2(relu_out)
587586
if multidim_output:

tests/attr/test_feature_ablation.py

Lines changed: 115 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
import torch
1313
from captum._utils.common import _construct_future_forward
1414
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
15-
from captum.attr._core.feature_ablation import FeatureAblation
15+
from captum.attr._core.feature_ablation import (
16+
_parse_forward_out,
17+
FeatureAblation,
18+
format_result,
19+
)
1620
from captum.attr._core.noise_tunnel import NoiseTunnel
1721
from captum.attr._utils.attribution import Attribution
1822
from captum.testing.helpers import BaseTest
@@ -595,7 +599,6 @@ def slow_set_future(fut: torch.futures.Future[Tensor], value: Tensor) -> None:
595599
fut.set_result(out)
596600

597601
def forward_func(inp: Tensor) -> torch.futures.Future[Tensor]:
598-
# pyre-fixme[29]: `typing.Type[torch.futures.Future]` is not a function.
599602
fut: torch.futures.Future[Tensor] = torch.futures.Future()
600603
t = threading.Thread(target=slow_set_future, args=(fut, inp))
601604
t.start()
@@ -900,5 +903,115 @@ def _ablation_test_assert(
900903
assertTensorAlmostEqual(self, attributions, expected_ablation)
901904

902905

906+
class TestParseForwardOutput(BaseTest):
907+
908+
def test_parse_forward_out_tensor_passthrough(self) -> None:
909+
input_tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
910+
result = _parse_forward_out(input_tensor)
911+
912+
self.assertIs(result, input_tensor)
913+
assertTensorAlmostEqual(self, result, input_tensor)
914+
915+
def test_parse_forward_out_python_int(self) -> None:
916+
input_value = 42
917+
result = _parse_forward_out(input_value)
918+
919+
self.assertIsInstance(result, Tensor)
920+
self.assertEqual(result.dtype, torch.int64)
921+
assertTensorAlmostEqual(self, result, torch.tensor(42))
922+
923+
def test_parse_forward_out_python_float(self) -> None:
924+
input_value = 3.14
925+
result = _parse_forward_out(input_value)
926+
927+
self.assertIsInstance(result, Tensor)
928+
self.assertEqual(result.dtype, torch.float64)
929+
assertTensorAlmostEqual(self, result, torch.tensor(3.14))
930+
931+
def test_parse_forward_out_invalid_none(self) -> None:
932+
with self.assertRaises(AssertionError):
933+
_parse_forward_out(None)
934+
935+
936+
class TestFormatResult(BaseTest):
937+
938+
def test_format_result_single_tensor_no_weights(self) -> None:
939+
total_attrib = [torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])]
940+
weights = []
941+
is_inputs_tuple = False
942+
use_weights = False
943+
944+
result = format_result(total_attrib, weights, is_inputs_tuple, use_weights)
945+
946+
self.assertIsInstance(result, Tensor)
947+
assert isinstance(result, Tensor) # Type narrowing for pyre
948+
self.assertEqual(result.shape, (2, 3))
949+
assertTensorAlmostEqual(
950+
self, result, torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
951+
)
952+
953+
def test_format_result_tuple_output_no_weights(self) -> None:
954+
total_attrib = [
955+
torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
956+
torch.tensor([[5.0, 6.0], [7.0, 8.0]]),
957+
]
958+
weights = []
959+
is_inputs_tuple = True
960+
use_weights = False
961+
962+
result = format_result(total_attrib, weights, is_inputs_tuple, use_weights)
963+
964+
self.assertIsInstance(result, tuple)
965+
self.assertEqual(len(result), 2)
966+
assertTensorAlmostEqual(self, result[0], torch.tensor([[1.0, 2.0], [3.0, 4.0]]))
967+
assertTensorAlmostEqual(self, result[1], torch.tensor([[5.0, 6.0], [7.0, 8.0]]))
968+
969+
def test_format_result_single_tensor_with_weights(self) -> None:
970+
total_attrib = [torch.tensor([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]])]
971+
weights = [torch.tensor([[2.0, 4.0, 5.0], [8.0, 10.0, 12.0]])]
972+
is_inputs_tuple = False
973+
use_weights = True
974+
975+
result = format_result(total_attrib, weights, is_inputs_tuple, use_weights)
976+
977+
self.assertIsInstance(result, Tensor)
978+
expected = torch.tensor([[5.0, 5.0, 6.0], [5.0, 5.0, 5.0]])
979+
assertTensorAlmostEqual(self, result, expected)
980+
981+
def test_format_result_tuple_output_with_weights(self) -> None:
982+
total_attrib = [
983+
torch.tensor([[10.0, 20.0], [30.0, 40.0]]),
984+
torch.tensor([[50.0, 60.0], [70.0, 80.0]]),
985+
]
986+
weights = [
987+
torch.tensor([[2.0, 4.0], [5.0, 8.0]]),
988+
torch.tensor([[10.0, 12.0], [14.0, 16.0]]),
989+
]
990+
is_inputs_tuple = True
991+
use_weights = True
992+
993+
result = format_result(total_attrib, weights, is_inputs_tuple, use_weights)
994+
995+
self.assertIsInstance(result, tuple)
996+
self.assertEqual(len(result), 2)
997+
assertTensorAlmostEqual(self, result[0], torch.tensor([[5.0, 5.0], [6.0, 5.0]]))
998+
assertTensorAlmostEqual(self, result[1], torch.tensor([[5.0, 5.0], [5.0, 5.0]]))
999+
1000+
def test_format_result_integer_dtype_no_weights(self) -> None:
1001+
total_attrib = [torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int32)]
1002+
weights = []
1003+
is_inputs_tuple = False
1004+
use_weights = False
1005+
1006+
result = format_result(total_attrib, weights, is_inputs_tuple, use_weights)
1007+
1008+
self.assertIsInstance(result, Tensor)
1009+
assert isinstance(result, Tensor) # Type narrowing for pyre
1010+
self.assertEqual(result.dtype, torch.int32)
1011+
assertTensorAlmostEqual(
1012+
self, result, torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int32)
1013+
)
1014+
1015+
9031016
if __name__ == "__main__":
9041017
unittest.main()

0 commit comments

Comments
 (0)