diff --git a/transformers_interpret/explainers/text/multilabel_classification.py b/transformers_interpret/explainers/text/multilabel_classification.py index d2fa657..b03ce79 100644 --- a/transformers_interpret/explainers/text/multilabel_classification.py +++ b/transformers_interpret/explainers/text/multilabel_classification.py @@ -157,7 +157,7 @@ def __call__( ) self.selected_index = i explainer._forward = self._forward - explainer(text, i, embedding_type) + explainer(text, i, None, embedding_type, internal_batch_size, n_steps) self.attributions.append(explainer.attributions) self.input_ids = explainer.input_ids