Commit 6b7847d
Update FeatureAblation to handle precision loss when the baseline is more granular than the input and cross-tensor attribution is enabled
Summary:
Noticed that this test case failed when flipping the flag:
https://www.internalfb.com/code/fbsource/[faf71541b1ec0fae639f82d487b81fb18ea3e523]/fbcode/pytorch/captum/tests/attr/test_dataloader_attr.py?lines=138%2C134
The ablated tensor was `tensor([0])` instead of `tensor([0.1])`, since the baseline was a float tensor while the input tensors were int tensors.
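The same truncation reproduces outside Captum in a few lines (the values here are illustrative, not the test's actual tensors):

```python
import torch

inp = torch.tensor([1])   # int64 input, like the test's int tensors
baseline = 0.1            # float baseline, more granular than the input

ablated = inp.clone()     # the copy inherits the int64 dtype
ablated[0] = baseline     # 0.1 is silently truncated to 0 on assignment
print(ablated)            # tensor([0]) instead of tensor([0.1])
```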
https://www.internalfb.com/code/fbsource/[f2fcc926a6f3669602bac4d28c2d92e4197c96b9]/fbcode/pytorch/captum/captum/attr/_core/feature_ablation.py?lines=707-709
`ablated_input` is just a copy of `input_tensor`, so during assignment the ablated feature values are incorrectly cast to the input's int dtype in this case.
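One way to avoid the cast is to promote the copy to a dtype that can represent both operands before assigning the baseline. Below is only a minimal sketch of that idea; the `ablate` helper, its signature, and the mask-based assignment are hypothetical, not the actual code from `feature_ablation.py`:

```python
import torch

def ablate(input_tensor: torch.Tensor, baseline: torch.Tensor,
           mask: torch.Tensor) -> torch.Tensor:
    # Promote to the common dtype (e.g. int64 + float32 -> float32) so a
    # float baseline assigned into an int input is no longer truncated.
    dtype = torch.promote_types(input_tensor.dtype, baseline.dtype)
    ablated = input_tensor.clone().to(dtype)
    ablated[mask] = baseline.to(dtype)
    return ablated

print(ablate(torch.tensor([1, 2, 3]), torch.tensor(0.1),
             torch.tensor([True, False, False])))
# tensor([0.1000, 2.0000, 3.0000])
```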
Differential Revision: D819802191
File tree
2 files changed: +41 / -3 lines
- captum/attr/_core
- tests/attr
captum/attr/_core/feature_ablation.py: lines 707-709 replaced by new lines 707-711 (-3/+5 lines)
tests/attr/test_dataloader_attr.py: lines 223-258 added after line 222 (+36 lines)