11#!/usr/bin/env python3
22
33# pyre-strict
4- import typing
54from enum import Enum
65from functools import reduce
76from inspect import signature
@@ -65,7 +64,7 @@ def safe_div(
6564 denom : Union [Tensor , int , float ],
6665 default_denom : Union [Tensor , int , float ] = 1.0 ,
6766) -> Tensor :
68- r """
67+ """
6968 A simple utility function to perform `numerator / denom`
7069 if the statement is undefined => result will be `numerator / default_denorm`
7170 """
@@ -81,15 +80,15 @@ def safe_div(
8180 return numerator / torch .where (denom != 0 , denom , default_denom )
8281
8382
84- @typing . overload
83+ @overload
8584def _is_tuple (inputs : Tuple [Tensor , ...]) -> Literal [True ]: ...
8685
8786
88- @typing . overload
87+ @overload
8988def _is_tuple (inputs : Tensor ) -> Literal [False ]: ...
9089
9190
92- @typing . overload
91+ @overload
9392def _is_tuple (
9493 inputs : Union [Tensor , Tuple [Tensor , ...]],
9594) -> bool : ...
@@ -150,7 +149,7 @@ def _validate_input(
150149
151150
152151def _zeros (inputs : Tuple [Tensor , ...]) -> Tuple [int , ...]:
153- r """
152+ """
154153 Takes a tuple of tensors as input and returns a tuple that has the same
155154 length as `inputs` with each element as the integer 0.
156155 """
@@ -160,6 +159,10 @@ def _zeros(inputs: Tuple[Tensor, ...]) -> Tuple[int, ...]:
160159def _format_baseline (
161160 baselines : BaselineType , inputs : Tuple [Tensor , ...]
162161) -> Tuple [Union [Tensor , int , float ], ...]:
162+ """
163+ Converts baselines to tuple format, returning zeros if None,
164+ or wrapping single values in a tuple.
165+ """
163166 if baselines is None :
164167 return _zeros (inputs )
165168
@@ -197,11 +200,8 @@ def _format_feature_mask(
197200 start_idx : int = 0 ,
198201) -> Tuple [Tensor , ...]:
199202 """
200- Format a feature mask into a tuple of tensors.
201- The `inputs` should be correctly formatted first
202- If `feature_mask` is None, assign each non-batch dimension with a consecutive
203- integer from `start_idx`.
204- If `feature_mask` is a tensor, wrap it in a tuple.
203+ Converts feature mask to tuple format, auto-generating default mask
204+ from start_idx if None.
205205 """
206206 if feature_mask is None :
207207 formatted_mask = []
@@ -240,6 +240,9 @@ def _format_tensor_into_tuples(
240240def _format_tensor_into_tuples (
241241 inputs : Union [None , Tensor , Tuple [Tensor , ...]],
242242) -> Union [None , Tuple [Tensor , ...]]:
243+ """
244+ Converts tensor inputs to tuple format, returning None unchanged if None.
245+ """
243246 if inputs is None :
244247 return None
245248 if not isinstance (inputs , tuple ):
@@ -252,6 +255,10 @@ def _format_tensor_into_tuples(
252255
253256
254257def _format_inputs (inputs : Any , unpack_inputs : bool = True ) -> Any :
258+ """
259+ Returns inputs unchanged if already tuple/list
260+ and unpack_inputs=True, otherwise wraps in tuple.
261+ """
255262 return (
256263 inputs
257264 if (isinstance (inputs , tuple ) or isinstance (inputs , list )) and unpack_inputs
@@ -262,6 +269,9 @@ def _format_inputs(inputs: Any, unpack_inputs: bool = True) -> Any:
262269def _format_float_or_tensor_into_tuples (
263270 inputs : Union [float , Tensor , Tuple [Union [float , Tensor ], ...]],
264271) -> Tuple [Union [float , Tensor ], ...]:
272+ """
273+ Converts float or tensor inputs to tuple format, wrapping single values in a tuple.
274+ """
265275 if not isinstance (inputs , tuple ):
266276 assert isinstance (
267277 inputs , (torch .Tensor , float )
@@ -274,23 +284,28 @@ def _format_float_or_tensor_into_tuples(
274284
275285@overload
276286def _format_additional_forward_args (
277- # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
278- additional_forward_args : Union [Tensor , Tuple ],
279- # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
280- ) -> Tuple : ...
287+ additional_forward_args : Union [Tensor , Tuple [Any , ...]],
288+ ) -> Tuple [Any , ...]: ...
281289
282290
283291@overload
284- def _format_additional_forward_args ( # type: ignore
292+ def _format_additional_forward_args (
293+ additional_forward_args : None ,
294+ ) -> None : ...
295+
296+
297+ @overload
298+ def _format_additional_forward_args (
285299 additional_forward_args : Optional [object ],
286- # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
287- ) -> Union [None , Tuple ]: ...
300+ ) -> Optional [Tuple [Any , ...]]: ...
288301
289302
290303def _format_additional_forward_args (
291304 additional_forward_args : Optional [object ],
292- # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
293- ) -> Union [None , Tuple ]:
305+ ) -> Optional [Tuple [Any , ...]]:
306+ """
307+ Converts additional forward args to tuple format, returning None unchanged if None.
308+ """
294309 if additional_forward_args is not None and not isinstance (
295310 additional_forward_args , tuple
296311 ):
@@ -478,21 +493,21 @@ def _expand_and_update_feature_mask(n_samples: int, kwargs: dict) -> None:
478493 kwargs ["feature_mask" ] = feature_mask
479494
480495
481- @typing . overload
496+ @overload
482497def _format_output (
483498 is_inputs_tuple : Literal [True ],
484499 output : Tuple [Tensor , ...],
485500) -> Tuple [Tensor , ...]: ...
486501
487502
488- @typing . overload
503+ @overload
489504def _format_output (
490505 is_inputs_tuple : Literal [False ],
491506 output : Tuple [Tensor , ...],
492507) -> Tensor : ...
493508
494509
495- @typing . overload
510+ @overload
496511def _format_output (
497512 is_inputs_tuple : bool , output : Tuple [Tensor , ...]
498513) -> Union [Tensor , Tuple [Tensor , ...]]: ...
@@ -501,7 +516,7 @@ def _format_output(
501516def _format_output (
502517 is_inputs_tuple : bool , output : Tuple [Tensor , ...]
503518) -> Union [Tensor , Tuple [Tensor , ...]]:
504- r """
519+ """
505520 In case input is a tensor and the output is returned in form of a
506521 tuple we take the first element of the output's tuple to match the
507522 same shape signatues of the inputs
@@ -516,21 +531,21 @@ def _format_output(
516531 return output if is_inputs_tuple else output [0 ]
517532
518533
519- @typing . overload
534+ @overload
520535def _format_outputs (
521536 is_multiple_inputs : Literal [False ],
522537 outputs : List [Tuple [Tensor , ...]],
523538) -> Union [Tensor , Tuple [Tensor , ...]]: ...
524539
525540
526- @typing . overload
541+ @overload
527542def _format_outputs (
528543 is_multiple_inputs : Literal [True ],
529544 outputs : List [Tuple [Tensor , ...]],
530545) -> List [Union [Tensor , Tuple [Tensor , ...]]]: ...
531546
532547
533- @typing . overload
548+ @overload
534549def _format_outputs (
535550 is_multiple_inputs : bool , outputs : List [Tuple [Tensor , ...]]
536551) -> Union [Tensor , Tuple [Tensor , ...], List [Union [Tensor , Tuple [Tensor , ...]]]]: ...
@@ -539,6 +554,10 @@ def _format_outputs(
539554def _format_outputs (
540555 is_multiple_inputs : bool , outputs : List [Tuple [Tensor , ...]]
541556) -> Union [Tensor , Tuple [Tensor , ...], List [Union [Tensor , Tuple [Tensor , ...]]]]:
557+ """
558+ Formats list of output tuples: returns list if is_multiple_inputs is True,
559+ otherwise single formatted output.
560+ """
542561 assert isinstance (outputs , list ), "Outputs must be a list"
543562 assert is_multiple_inputs or len (outputs ) == 1 , (
544563 "outputs should contain multiple inputs or have a single output"
@@ -554,9 +573,7 @@ def _format_outputs(
554573
555574# pyre-fixme[24] Callable requires 2 arguments
556575def _construct_future_forward (original_forward : Callable ) -> Callable :
557- # pyre-fixme[3] return type not specified
558- def future_forward (* args : Any , ** kwargs : Any ):
559- # pyre-fixme[29]: `typing.Type[torch.futures.Future]` is not a function.
576+ def future_forward (* args : Any , ** kwargs : Any ) -> torch .futures .Future [Tensor ]:
560577 fut : torch .futures .Future [Tensor ] = torch .futures .Future ()
561578 fut .set_result (original_forward (* args , ** kwargs ))
562579 return fut
@@ -829,7 +846,7 @@ def _flatten_tensor_or_tuple(inp: TensorOrTupleOfTensorsGeneric) -> Tensor:
829846
830847
831848def _get_module_from_name (model : Module , layer_name : str ) -> Any :
832- r """
849+ """
833850 Returns the module (layer) object, given its (string) name
834851 in the model.
835852
0 commit comments