
Commit 0a36e57

cyrjano authored and facebook-github-bot committed
Add documentation and clean some typing information for common file. (meta-pytorch#1662)
Summary: Improve documentation of the _format methods and clean up a few typing errors.

Differential Revision: D86432880
1 parent 7092433 commit 0a36e57

File tree

1 file changed: +48 -31 lines


captum/_utils/common.py

Lines changed: 48 additions & 31 deletions
@@ -1,7 +1,6 @@
 #!/usr/bin/env python3

 # pyre-strict
-import typing
 from enum import Enum
 from functools import reduce
 from inspect import signature
@@ -65,7 +64,7 @@ def safe_div(
     denom: Union[Tensor, int, float],
     default_denom: Union[Tensor, int, float] = 1.0,
 ) -> Tensor:
-    r"""
+    """
     A simple utility function to perform `numerator / denom`
     if the statement is undefined => result will be `numerator / default_denorm`
     """
@@ -81,15 +80,15 @@ def safe_div(
     return numerator / torch.where(denom != 0, denom, default_denom)


-@typing.overload
+@overload
 def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...


-@typing.overload
+@overload
 def _is_tuple(inputs: Tensor) -> Literal[False]: ...


-@typing.overload
+@overload
 def _is_tuple(
     inputs: Union[Tensor, Tuple[Tensor, ...]],
 ) -> bool: ...
@@ -150,7 +149,7 @@ def _validate_input(


 def _zeros(inputs: Tuple[Tensor, ...]) -> Tuple[int, ...]:
-    r"""
+    """
     Takes a tuple of tensors as input and returns a tuple that has the same
     length as `inputs` with each element as the integer 0.
     """
@@ -160,6 +159,10 @@ def _zeros(inputs: Tuple[Tensor, ...]) -> Tuple[int, ...]:
 def _format_baseline(
     baselines: BaselineType, inputs: Tuple[Tensor, ...]
 ) -> Tuple[Union[Tensor, int, float], ...]:
+    """
+    Converts baselines to tuple format, returning zeros if None,
+    or wrapping single values in a tuple.
+    """
     if baselines is None:
         return _zeros(inputs)

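A sketch of the behavior described by the new docstring; the pass-through of an already-tuple baseline is implied by the return type and is an assumption here:

```python
import torch
from captum._utils.common import _format_baseline

inputs = (torch.rand(2, 3),)

print(_format_baseline(None, inputs))  # (0,) -- zeros generated via _zeros
print(_format_baseline(0.5, inputs))   # (0.5,) -- single value wrapped in a tuple
print(_format_baseline((torch.zeros(2, 3),), inputs))  # already a tuple, returned as-is
```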
@@ -197,11 +200,8 @@ def _format_feature_mask(
     start_idx: int = 0,
 ) -> Tuple[Tensor, ...]:
     """
-    Format a feature mask into a tuple of tensors.
-    The `inputs` should be correctly formatted first
-    If `feature_mask` is None, assign each non-batch dimension with a consecutive
-    integer from `start_idx`.
-    If `feature_mask` is a tensor, wrap it in a tuple.
+    Converts feature mask to tuple format, auto-generating default mask
+    from start_idx if None.
     """
     if feature_mask is None:
         formatted_mask = []
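Roughly what the docstring describes, assuming the parameter order `(feature_mask, inputs, start_idx)` from the signature; the exact ids in the auto-generated mask depend on the input shapes:

```python
import torch
from captum._utils.common import _format_feature_mask

inputs = (torch.rand(1, 3),)

# A single mask tensor is wrapped into a tuple.
mask = torch.tensor([[0, 0, 1]])
print(_format_feature_mask(mask, inputs))  # (tensor([[0, 0, 1]]),)

# With feature_mask=None, a default mask is generated with consecutive
# feature ids starting at start_idx over the non-batch dimensions.
print(_format_feature_mask(None, inputs, start_idx=0))
```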
@@ -240,6 +240,9 @@ def _format_tensor_into_tuples(
 def _format_tensor_into_tuples(
     inputs: Union[None, Tensor, Tuple[Tensor, ...]],
 ) -> Union[None, Tuple[Tensor, ...]]:
+    """
+    Converts tensor inputs to tuple format, returning None unchanged if None.
+    """
     if inputs is None:
         return None
     if not isinstance(inputs, tuple):
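A quick sketch of the documented behavior:

```python
import torch
from captum._utils.common import _format_tensor_into_tuples

t = torch.rand(2, 3)

print(_format_tensor_into_tuples(None))    # None is passed through unchanged
print(_format_tensor_into_tuples(t))       # (t,) -- single tensor wrapped in a tuple
print(_format_tensor_into_tuples((t, t)))  # (t, t) -- tuples returned as-is
```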
@@ -252,6 +255,10 @@ def _format_tensor_into_tuples(


 def _format_inputs(inputs: Any, unpack_inputs: bool = True) -> Any:
+    """
+    Returns inputs unchanged if already tuple/list
+    and unpack_inputs=True, otherwise wraps in tuple.
+    """
     return (
         inputs
         if (isinstance(inputs, tuple) or isinstance(inputs, list)) and unpack_inputs
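A sketch of the three cases covered by the new docstring:

```python
import torch
from captum._utils.common import _format_inputs

t1, t2 = torch.rand(2), torch.rand(2)

print(_format_inputs(t1))                             # (t1,) -- wrapped in a tuple
print(_format_inputs((t1, t2)))                       # (t1, t2) -- tuple passed through
print(_format_inputs([t1, t2], unpack_inputs=False))  # ([t1, t2],) -- list kept as one element
```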
@@ -262,6 +269,9 @@ def _format_inputs(inputs: Any, unpack_inputs: bool = True) -> Any:
 def _format_float_or_tensor_into_tuples(
     inputs: Union[float, Tensor, Tuple[Union[float, Tensor], ...]],
 ) -> Tuple[Union[float, Tensor], ...]:
+    """
+    Converts float or tensor inputs to tuple format, wrapping single values in a tuple.
+    """
     if not isinstance(inputs, tuple):
         assert isinstance(
             inputs, (torch.Tensor, float)
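A brief illustration, assuming the same wrap-if-not-tuple behavior shown in the condition above:

```python
import torch
from captum._utils.common import _format_float_or_tensor_into_tuples

print(_format_float_or_tensor_into_tuples(0.5))            # (0.5,)
print(_format_float_or_tensor_into_tuples(torch.rand(3)))  # (tensor,) wrapped in a tuple
print(_format_float_or_tensor_into_tuples((0.1, 0.2)))     # already a tuple, returned as-is
```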
@@ -274,23 +284,28 @@ def _format_float_or_tensor_into_tuples(

 @overload
 def _format_additional_forward_args(
-    # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
-    additional_forward_args: Union[Tensor, Tuple],
-    # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
-) -> Tuple: ...
+    additional_forward_args: Union[Tensor, Tuple[Any, ...]],
+) -> Tuple[Any, ...]: ...


 @overload
-def _format_additional_forward_args(  # type: ignore
+def _format_additional_forward_args(
+    additional_forward_args: None,
+) -> None: ...
+
+
+@overload
+def _format_additional_forward_args(
     additional_forward_args: Optional[object],
-    # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
-) -> Union[None, Tuple]: ...
+) -> Optional[Tuple[Any, ...]]: ...


 def _format_additional_forward_args(
     additional_forward_args: Optional[object],
-    # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
-) -> Union[None, Tuple]:
+) -> Optional[Tuple[Any, ...]]:
+    """
+    Converts additional forward args to tuple format, returning None unchanged if None.
+    """
     if additional_forward_args is not None and not isinstance(
         additional_forward_args, tuple
     ):
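With the new overloads, callers get a precise return type per argument kind; a behavior sketch based on the docstring:

```python
import torch
from captum._utils.common import _format_additional_forward_args

print(_format_additional_forward_args(None))           # None is passed through
print(_format_additional_forward_args(torch.rand(2)))  # (tensor,) -- wrapped in a tuple
print(_format_additional_forward_args((1, "flag")))    # (1, "flag") -- tuples returned as-is
```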
@@ -478,21 +493,21 @@ def _expand_and_update_feature_mask(n_samples: int, kwargs: dict) -> None:
     kwargs["feature_mask"] = feature_mask


-@typing.overload
+@overload
 def _format_output(
     is_inputs_tuple: Literal[True],
     output: Tuple[Tensor, ...],
 ) -> Tuple[Tensor, ...]: ...


-@typing.overload
+@overload
 def _format_output(
     is_inputs_tuple: Literal[False],
     output: Tuple[Tensor, ...],
 ) -> Tensor: ...


-@typing.overload
+@overload
 def _format_output(
     is_inputs_tuple: bool, output: Tuple[Tensor, ...]
 ) -> Union[Tensor, Tuple[Tensor, ...]]: ...
@@ -501,7 +516,7 @@ def _format_output(
 def _format_output(
     is_inputs_tuple: bool, output: Tuple[Tensor, ...]
 ) -> Union[Tensor, Tuple[Tensor, ...]]:
-    r"""
+    """
     In case input is a tensor and the output is returned in form of a
     tuple we take the first element of the output's tuple to match the
     same shape signatues of the inputs
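In other words, a sketch of the documented unwrapping rule:

```python
import torch
from captum._utils.common import _format_output

attr = (torch.rand(2, 3),)

print(_format_output(True, attr))   # (tensor,) -- inputs were a tuple, keep the tuple
print(_format_output(False, attr))  # tensor    -- single-tensor input, unwrap the output
```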
@@ -516,21 +531,21 @@ def _format_output(
     return output if is_inputs_tuple else output[0]


-@typing.overload
+@overload
 def _format_outputs(
     is_multiple_inputs: Literal[False],
     outputs: List[Tuple[Tensor, ...]],
 ) -> Union[Tensor, Tuple[Tensor, ...]]: ...


-@typing.overload
+@overload
 def _format_outputs(
     is_multiple_inputs: Literal[True],
     outputs: List[Tuple[Tensor, ...]],
 ) -> List[Union[Tensor, Tuple[Tensor, ...]]]: ...


-@typing.overload
+@overload
 def _format_outputs(
     is_multiple_inputs: bool, outputs: List[Tuple[Tensor, ...]]
 ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]: ...
@@ -539,6 +554,10 @@ def _format_outputs(
 def _format_outputs(
     is_multiple_inputs: bool, outputs: List[Tuple[Tensor, ...]]
 ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
+    """
+    Formats list of output tuples: returns list if is_multiple_inputs is True,
+    otherwise single formatted output.
+    """
     assert isinstance(outputs, list), "Outputs must be a list"
     assert is_multiple_inputs or len(outputs) == 1, (
         "outputs should contain multiple inputs or have a single output"
@@ -554,9 +573,7 @@ def _format_outputs(

 # pyre-fixme[24] Callable requires 2 arguments
 def _construct_future_forward(original_forward: Callable) -> Callable:
-    # pyre-fixme[3] return type not specified
-    def future_forward(*args: Any, **kwargs: Any):
-        # pyre-fixme[29]: `typing.Type[torch.futures.Future]` is not a function.
+    def future_forward(*args: Any, **kwargs: Any) -> torch.futures.Future[Tensor]:
         fut: torch.futures.Future[Tensor] = torch.futures.Future()
         fut.set_result(original_forward(*args, **kwargs))
         return fut
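The wrapper returns an already-completed `torch.futures.Future`; for example (the `forward_fn` here is a hypothetical stand-in for a model's forward function):

```python
import torch
from captum._utils.common import _construct_future_forward

def forward_fn(x: torch.Tensor) -> torch.Tensor:
    return x * 2

future_forward = _construct_future_forward(forward_fn)
fut = future_forward(torch.tensor([1.0, 2.0]))  # torch.futures.Future[Tensor]
print(fut.wait())                               # tensor([2., 4.])
```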
@@ -829,7 +846,7 @@ def _flatten_tensor_or_tuple(inp: TensorOrTupleOfTensorsGeneric) -> Tensor:


 def _get_module_from_name(model: Module, layer_name: str) -> Any:
-    r"""
+    """
     Returns the module (layer) object, given its (string) name
     in the model.

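A minimal usage sketch; the `Net` module and its `encoder` attribute are hypothetical names used only for illustration:

```python
import torch.nn as nn
from captum._utils.common import _get_module_from_name

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(4, 8), nn.ReLU())

model = Net()
print(_get_module_from_name(model, "encoder"))  # the nn.Sequential submodule
```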