Official implementation of “On the Adversarial Vulnerability of Label‑Free Test‑Time Adaptation” (ICLR 2025). We expose how partial-batch adversarial perturbations degrade adaptation and provide reference scripts to reproduce the paper’s results.
- Python 3.8+ with CUDA-capable PyTorch.
- Install dependencies:

```shell
pip install torch torchvision timm robustbench yacs tqdm pandas numpy matplotlib pillow
```
Datasets:
- CIFAR-10/100-C: place under corrupted_data (RobustBench format) or set cfg.DATA.PATH.
- ImageNet-C: expected under corrupted_data/ImageNet-C///.

Pretrained checkpoints:
- CIFAR-10-C / CIFAR-100-C: download the relevant checkpoints from this Link.
- ImageNet-C: model definition and checkpoints come from torchvision.
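To sanity-check the CIFAR-10-C layout before a run, a small stdlib sketch may help. It assumes the usual RobustBench convention of one NumPy file per corruption plus a shared labels.npy; the helper name and the exact layout your checkout expects are assumptions, not part of this repo:

```python
import os

# Hypothetical helper: list the files a RobustBench-format CIFAR-10-C
# directory typically contains (one <corruption>.npy per corruption,
# plus labels.npy). Adjust to the layout your checkout actually uses.
def expected_cifar10c_files(data_root, corruptions):
    base = os.path.join(data_root, "CIFAR-10-C")
    paths = [os.path.join(base, f"{c}.npy") for c in corruptions]
    paths.append(os.path.join(base, "labels.npy"))
    return paths

if __name__ == "__main__":
    paths = expected_cifar10c_files("corrupted_data", ["gaussian_noise", "fog"])
    for p in paths:
        status = "ok" if os.path.exists(p) else "missing"
        print(f"{status}: {p}")
```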
Single run:
```shell
python3 main.py \
    --attack dia \
    --tta tent \
    --dataset cifar10c \
    --severity 3 \
    --batch_size 200 \
    --gpu_id 0
```

Configuration options:
- BASE.SEED – random seed
- BASE.GPU_ID – GPU device ID
- BASE.ATTACK – attack type (dia, u_dia)
- TTA.NAME – test-time adaptation method
- DATA.BATCH_SIZE – test-time batch size
- DATA.SEVERITY – corruption severity (1–5)
- CORRUPTION.DATASET – dataset (cifar10c, cifar100c, imagenetc)
- DATA.PATH – path to dataset root
- DIA.EPS – max perturbation
- DIA.ALPHA – step size
- DIA.STEPS – number of attack iterations
- DIA.MAL_PORTION – fraction of batch to attack
- DIA.PSEUDO – enable pseudo-labeling
- DIA.SRC_ONLY – use source model only
- DIA.CONTINUAL – attack carries across batches
- DIA.ADV_MODEL – alternate model for adversarial gradients
- MODEL.ARCH – backbone architecture
- MODEL.EPISODIC – reset model each episode
- OPTIM.* – optimizer settings (LR, steps, method)
- TTA.* – method-specific configuration blocks (e.g., TTA.TENT, TTA.EATA, TTA.NORM)