Binary classification inconsistencies #1059
So I had to implement a workaround in the meantime.
Looking at the Accuracy docs (https://torchmetrics.readthedocs.io/en/stable/classification/accuracy.html), it seems like comparing equal-length tensors is reserved for ordinal ints. Seems like a symptom of the same inconsistency.
Replies: 1 comment
The inconsistencies you are experiencing likely stem from passing raw softmax outputs (probabilities) directly to the Accuracy metric. As per the TorchMetrics Accuracy docs, the metric expects integer class labels or binary labels for comparison, not probability tensors.

```python
import torch
from torchmetrics.classification import BinaryAccuracy

# Initialize the metric
accuracy = BinaryAccuracy()

# Model outputs as probabilities (2D, softmax over dim=1)
probs = torch.tensor([[0.3, 0.7], [0.6, 0.4], [0.8, 0.2]])

# Convert to predicted classes: argmax for two-column outputs (shown here),
# or a threshold for single-channel binary outputs (shown below)
preds = torch.argmax(probs, dim=1)

# True labels as integer class indices
target = torch.tensor([1, 0, 0])

# Update the metric and compute accuracy
acc_val = accuracy(preds, target)
print(f"Accuracy: {acc_val.item()}")
```

If you’re using single-channel sigmoid outputs for binary classification, apply a threshold instead:

```python
sigmoid_out = torch.tensor([0.3, 0.7, 0.6])
preds = (sigmoid_out > 0.5).int()
```

Note that `BinaryAccuracy` will also accept float probabilities directly and threshold them at 0.5 (configurable via its `threshold` argument), so the explicit conversion mainly makes the comparison unambiguous. Avoid passing raw two-column softmax outputs directly; converting them to class indices or thresholded binary predictions as above should resolve the issues you described.
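To make the comparison semantics concrete, here is a plain-Python sketch of what binary accuracy over probabilities boils down to: threshold, then compare against integer labels. The function name `binary_accuracy` is hypothetical (not a TorchMetrics API), and this ignores details the real metric handles (logit detection, state accumulation across batches).

```python
def binary_accuracy(probs, targets, threshold=0.5):
    """Hypothetical helper: accuracy from raw probabilities.

    Thresholds each probability into a 0/1 prediction, then
    compares against integer targets.
    """
    preds = [1 if p > threshold else 0 for p in probs]
    correct = sum(p == t for p, t in zip(preds, targets))
    return correct / len(targets)

print(binary_accuracy([0.3, 0.7, 0.6], [0, 1, 1]))  # 1.0
```

Passing the probabilities straight into an equality check against int labels (without the threshold step) is exactly what produces the inconsistent results described above.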