diff --git a/benchmark.py b/benchmark.py index db8fe4308d..beaf257a1d 100755 --- a/benchmark.py +++ b/benchmark.py @@ -25,13 +25,6 @@ from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs,\ reparameterize_model -has_apex = False -try: - from apex import amp - has_apex = True -except ImportError: - pass - try: from deepspeed.profiling.flops_profiler import get_model_profile has_deepspeed_profiling = True diff --git a/inference.py b/inference.py index 7ccaa334ce..21db6194fa 100755 --- a/inference.py +++ b/inference.py @@ -23,12 +23,6 @@ from timm.models import create_model from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs -try: - from apex import amp - has_apex = True -except ImportError: - has_apex = False - try: from functorch.compile import memory_efficient_fusion has_functorch = True @@ -170,7 +164,7 @@ def main(): assert args.model_dtype in ('float32', 'float16', 'bfloat16') model_dtype = getattr(torch, args.model_dtype) - # resolve AMP arguments based on PyTorch / Apex availability + # resolve AMP arguments based on PyTorch availability amp_autocast = suppress if args.amp: assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP' diff --git a/timm/task/__init__.py b/timm/task/__init__.py new file mode 100644 index 0000000000..625488fd25 --- /dev/null +++ b/timm/task/__init__.py @@ -0,0 +1,17 @@ +"""Training task abstractions for timm. + +This module provides task-based abstractions for training loops where each task +encapsulates both the forward pass and loss computation, returning a dictionary +with loss components and outputs for logging. +""" +from .task import TrainingTask +from .classification import ClassificationTask +from .distillation import DistillationTeacher, LogitDistillationTask, FeatureDistillationTask + +__all__ = [ + 'TrainingTask', + 'ClassificationTask', + 'DistillationTeacher', + 'LogitDistillationTask', + 'FeatureDistillationTask', +] diff --git a/timm/task/classification.py b/timm/task/classification.py new file mode 100644 index 0000000000..2f81871b3a --- /dev/null +++ b/timm/task/classification.py @@ -0,0 +1,90 @@ +"""Classification training task.""" +import logging +from typing import Callable, Dict, Optional, Union + +import torch +import torch.nn as nn + +from .task import TrainingTask + +_logger = logging.getLogger(__name__) + + +class ClassificationTask(TrainingTask): + """Standard supervised classification task. + + Simple task that performs a forward pass through the model and computes + the classification loss. 
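For reference, a minimal end-to-end sketch of the new task API exported from timm/task/__init__.py above, assuming a CUDA-or-CPU device and an arbitrary model choice (resnet18, batch size 8, and 10 classes are only illustrative):

# Illustrative only: one optimizer step driven through ClassificationTask.
import torch
import torch.nn as nn
from timm import create_model
from timm.task import ClassificationTask

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = create_model('resnet18', num_classes=10).to(device)
task = ClassificationTask(model, nn.CrossEntropyLoss(), device=device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

input = torch.randn(8, 3, 224, 224, device=device)
target = torch.randint(0, 10, (8,), device=device)

result = task(input, target)   # dict with 'loss' and 'output'
result['loss'].backward()
optimizer.step()
optimizer.zero_grad()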
+ + Args: + model: The model to train + criterion: Loss function (e.g., CrossEntropyLoss) + device: Device for task tensors/buffers + dtype: Dtype for task tensors/buffers + verbose: Enable info logging + + Example: + >>> task = ClassificationTask(model, nn.CrossEntropyLoss(), device=torch.device('cuda')) + >>> result = task(input, target) + >>> result['loss'].backward() + """ + + def __init__( + self, + model: nn.Module, + criterion: Union[nn.Module, Callable], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + verbose: bool = True, + ): + super().__init__(device=device, dtype=dtype, verbose=verbose) + self.model = model + self.criterion = criterion + + if self.verbose: + loss_name = getattr(criterion, '__name__', None) or type(criterion).__name__ + _logger.info(f"ClassificationTask: criterion={loss_name}") + + def prepare_distributed( + self, + device_ids: Optional[list] = None, + **ddp_kwargs + ) -> 'ClassificationTask': + """Prepare task for distributed training. + + Wraps the model in DistributedDataParallel (DDP). + + Args: + device_ids: List of device IDs for DDP (e.g., [local_rank]) + **ddp_kwargs: Additional arguments passed to DistributedDataParallel + + Returns: + self (for method chaining) + """ + from torch.nn.parallel import DistributedDataParallel as DDP + self.model = DDP(self.model, device_ids=device_ids, **ddp_kwargs) + return self + + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + """Forward pass through model and compute classification loss. + + Args: + input: Input tensor [B, C, H, W] + target: Target labels [B] + + Returns: + Dictionary containing: + - 'loss': Classification loss + - 'output': Model logits + """ + output = self.model(input) + loss = self.criterion(output, target) + + return { + 'loss': loss, + 'output': output, + } diff --git a/timm/task/distillation.py b/timm/task/distillation.py new file mode 100644 index 0000000000..ff92b44d93 --- /dev/null +++ b/timm/task/distillation.py @@ -0,0 +1,574 @@ +"""Knowledge distillation training tasks and components.""" +import logging +from typing import Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.models import create_model +from timm.utils import unwrap_model + +from .task import TrainingTask + + +_logger = logging.getLogger(__name__) + + +class DistillationTeacher(nn.Module): + """Wrapper for a teacher model used in knowledge distillation. + + Creates and manages a pre-trained teacher model for knowledge distillation, + handling model compilation and normalization differences between teacher and student. 
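A short usage sketch for the teacher wrapper described above; the model name is an arbitrary example and pretrained weights are assumed to be available for download:

# Illustrative only: build a frozen teacher and query logits or pooled features.
import torch
from timm.task import DistillationTeacher

teacher = DistillationTeacher(
    model_name='convnext_base',   # any timm model with pretrained weights
    num_classes=1000,
    device=torch.device('cuda'),
    dtype=torch.float32,
)
x = torch.randn(2, 3, 224, 224, device='cuda')
with torch.no_grad():
    logits = teacher(x)                        # [2, 1000]
    feats = teacher(x, return_features=True)   # pooled pre-logits features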
+ + Args: + model_name: Name of the teacher model to create + num_classes: Number of output classes + in_chans: Number of input channels + pretrained_path: Optional path to pretrained weights + device: Device to place the model on + dtype: Model dtype (uses float32 if None) + """ + + def __init__( + self, + model_name: str, + num_classes: int, + in_chans: int = 3, + pretrained_path: Optional[str] = None, + device: torch.device = torch.device('cuda'), + dtype: torch.dtype = None, + ): + super().__init__() + + _logger.info(f"Creating KD teacher model: '{model_name}'") + + pretrained_kwargs = {'pretrained': True} + if pretrained_path: + # specify a local checkpoint path to load pretrained weights from + pretrained_kwargs['pretrained_cfg_overlay'] = dict( + file=pretrained_path, + num_classes=num_classes, + ) + + model_kd = create_model( + model_name=model_name, + num_classes=num_classes, + in_chans=in_chans, + device=device, + dtype=dtype, + **pretrained_kwargs, + ) + + model_kd.eval() + self.model = model_kd + + # Register normalization values as non-persistent buffers + # Shape: [1, 3, 1, 1] for proper broadcasting over BCHW images + mean_kd = torch.tensor(model_kd.pretrained_cfg['mean'], device=device, dtype=dtype).view(1, -1, 1, 1) + std_kd = torch.tensor(model_kd.pretrained_cfg['std'], device=device, dtype=dtype).view(1, -1, 1, 1) + self.register_buffer('mean_kd', mean_kd, persistent=False) + self.register_buffer('std_kd', std_kd, persistent=False) + + def forward( + self, + input: torch.Tensor, + return_features: bool = False, + ) -> torch.Tensor: + """Forward pass through teacher model. + + Args: + input: Input tensor (should already be normalized for teacher) + return_features: Whether to return pooled pre-logits features instead of logits + + Returns: + Logits or pooled pre-logits features depending on return_features flag + """ + if return_features: + if not hasattr(self.model, 'forward_features') or not hasattr(self.model, 'forward_head'): + raise ValueError( + f"Model {self.model.__class__.__name__} does not support feature extraction. " + "Ensure the model has 'forward_features' and 'forward_head' methods." + ) + # Extract spatial features and pool to pre-logits + feature_map = self.model.forward_features(input) + return self.model.forward_head(feature_map, pre_logits=True) + else: + return self.model(input) + + def normalize_input( + self, + input: torch.Tensor, + student_mean: Optional[torch.Tensor] = None, + student_std: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Normalize input to match teacher's expected normalization. + + Handles different normalization between teacher and student models by + converting the student's normalized input to the teacher's expected format. 
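The conversion below folds de-normalization and re-normalization into one expression; a quick sketch showing the fused form matches the two-step version (mean/std values here are example constants only):

# Illustrative only: (x * std_s + mean_s - mean_t) / std_t equals undo-then-redo.
import torch

x_s = torch.randn(2, 3, 8, 8)                               # student-normalized input
mean_s = torch.tensor([0.5, 0.5, 0.5]).view(1, -1, 1, 1)
std_s = torch.tensor([0.5, 0.5, 0.5]).view(1, -1, 1, 1)
mean_t = torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1)
std_t = torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1)

x_raw = x_s * std_s + mean_s                                # undo student normalization
x_two_step = (x_raw - mean_t) / std_t                       # apply teacher normalization
x_fused = (x_s * std_s + mean_s - mean_t) / std_t
assert torch.allclose(x_two_step, x_fused)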
+ + Args: + input: Input tensor (already normalized for student) + student_mean: Student normalization mean buffer [1, 3, 1, 1] (None if same as teacher) + student_std: Student normalization std buffer [1, 3, 1, 1] (None if same as teacher) + + Returns: + Input tensor normalized for the teacher model + """ + # If no student normalization provided, assume it matches teacher (no conversion needed) + if student_mean is None or student_std is None: + return input + + # Check if renormalization is actually needed + if torch.equal(student_mean, self.mean_kd) and torch.equal(student_std, self.std_kd): + return input + + # De-normalize (Student) -> Re-normalize (Teacher) + # Combined for efficiency: (input * std_s + mean_s - mean_t) / std_t + return (input * student_std + student_mean - self.mean_kd) / self.std_kd + + +class LogitDistillationTask(TrainingTask): + """Logit-based knowledge distillation task. + + Performs distillation by matching student and teacher output logits using + KL divergence with temperature scaling. + + Loss weighting supports two modes: + 1. Independent weights: loss = task_loss_weight * task_loss + distill_loss_weight * distill_loss + 2. Complementary mode: loss = task_loss_weight * task_loss + (1 - task_loss_weight) * distill_loss + (used when only task_loss_weight is specified) + + Args: + student_model: Student model to train + teacher: Pre-configured teacher model wrapper + criterion: Task loss function (e.g., CrossEntropyLoss) + loss_type: Type of distillation loss (currently only 'kl' supported, reserved for future extensions) + distill_loss_weight: Weight for distillation loss + task_loss_weight: Weight for task loss + temperature: Softmax temperature for distillation (typical values: 1-4) + device: Device for task tensors/buffers + dtype: Dtype for task tensors/buffers + verbose: Enable info logging + + Example: + >>> # Independent weights + >>> task = LogitDistillationTask( + ... student_model=model, teacher=teacher, criterion=nn.CrossEntropyLoss(), + ... distill_loss_weight=1.0, task_loss_weight=1.0, temperature=4.0, + ... device=torch.device('cuda'), + ... ) + >>> # Complementary mode (task_weight=0.3 means distill gets 0.7) + >>> task = LogitDistillationTask( + ... student_model=model, teacher=teacher, criterion=nn.CrossEntropyLoss(), + ... task_loss_weight=0.3, temperature=4.0, + ... device=torch.device('cuda'), + ... ) + """ + + def __init__( + self, + student_model: nn.Module, + teacher: DistillationTeacher, + criterion: nn.Module, + loss_type: str = 'kl', + distill_loss_weight: Optional[float] = None, + task_loss_weight: Optional[float] = None, + temperature: float = 1.0, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + verbose: bool = True, + ): + super().__init__(device=device, dtype=dtype, verbose=verbose) + self.student = student_model + self.teacher = teacher + self.criterion = criterion + self.loss_type = loss_type + self.temperature = temperature + + if loss_type != 'kl': + raise ValueError(f"Unsupported loss_type '{loss_type}'. 
Currently only 'kl' is supported.") + + # Register student normalization values as non-persistent buffers + # Shape: [1, 3, 1, 1] for proper broadcasting over BCHW images + student_unwrapped = unwrap_model(student_model) + student_mean = torch.tensor( + student_unwrapped.pretrained_cfg['mean'], + device=self.device, + dtype=self.dtype, + ).view(1, -1, 1, 1) + student_std = torch.tensor( + student_unwrapped.pretrained_cfg['std'], + device=self.device, + dtype=self.dtype, + ).view(1, -1, 1, 1) + self.register_buffer('student_mean', student_mean, persistent=False) + self.register_buffer('student_std', student_std, persistent=False) + + # Determine weighting mode + if distill_loss_weight is not None: + # Mode 1: distill_weight specified - independent weights (task defaults to 1.0 if not set) + self.distill_loss_weight = distill_loss_weight + self.task_loss_weight = task_loss_weight if task_loss_weight is not None else 1.0 + if self.verbose: + _logger.info( + f"LogitDistillationTask: Independent weights - " + f"task_weight={self.task_loss_weight}, distill_weight={distill_loss_weight}" + ) + elif task_loss_weight is not None: + # Mode 2: Only task_weight specified - complementary mode + self.task_loss_weight = task_loss_weight + self.distill_loss_weight = 1.0 - task_loss_weight + if self.verbose: + _logger.info( + f"LogitDistillationTask: Complementary mode - " + f"task_weight={task_loss_weight}, distill_weight={self.distill_loss_weight}" + ) + else: + # Neither specified - use defaults (equal weighting) + self.distill_loss_weight = 1.0 + self.task_loss_weight = 1.0 + if self.verbose: + _logger.info( + f"LogitDistillationTask: Default equal weights - " + f"task_weight={self.task_loss_weight}, distill_weight={self.distill_loss_weight}" + ) + + if self.verbose: + _logger.info( + f"LogitDistillationTask: loss_type={loss_type}, temperature={temperature}" + ) + + def prepare_distributed( + self, + device_ids: Optional[list] = None, + **ddp_kwargs + ) -> 'LogitDistillationTask': + """Prepare task for distributed training. + + Wraps the student model in DistributedDataParallel (DDP) while leaving + the frozen teacher model unwrapped. + + Args: + device_ids: List of device IDs for DDP (e.g., [local_rank]) + **ddp_kwargs: Additional arguments passed to DistributedDataParallel + + Returns: + self (for method chaining) + """ + from torch.nn.parallel import DistributedDataParallel as DDP + + # Ensure teacher parameters are frozen + for param in self.teacher.parameters(): + param.requires_grad = False + + # Wrap only student in DDP + self.student = DDP(self.student, device_ids=device_ids, **ddp_kwargs) + return self + + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + """Forward pass with logit distillation. 
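The distillation term computed in this forward pass reduces to a temperature-scaled KL divergence between softened logits; a standalone sketch with dummy logits:

# Illustrative only: temperature-scaled KL term as used in the forward below.
import torch
import torch.nn.functional as F

T = 4.0
student_logits = torch.randn(8, 1000)
teacher_logits = torch.randn(8, 1000)

log_p_s = F.log_softmax(student_logits / T, dim=-1)
log_p_t = F.log_softmax(teacher_logits / T, dim=-1)
# log_target=True means both inputs are log-probabilities; the T**2 factor keeps
# gradient magnitudes comparable across temperatures.
kd_loss = F.kl_div(log_p_s, log_p_t, reduction='batchmean', log_target=True) * T ** 2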
+ + Args: + input: Input tensor [B, C, H, W] + target: Target labels [B] + + Returns: + Dictionary containing: + - 'loss': Combined training loss (task + distillation) + - 'output': Student logits + - 'task_loss': Classification loss component + - 'kd_loss': Distillation loss component + """ + # Student forward pass + student_logits = self.student(input) + + # Compute task loss + task_loss = self.criterion(student_logits, target) + + # Teacher forward pass (no gradient) + with torch.no_grad(): + input_kd = self.teacher.normalize_input(input, self.student_mean, self.student_std) + teacher_logits = self.teacher(input_kd.detach(), return_features=False) + + # Compute distillation loss (KL divergence with temperature scaling) + prob_s = F.log_softmax(student_logits / self.temperature, dim=-1) + prob_t = F.log_softmax(teacher_logits / self.temperature, dim=-1) + kd_loss = F.kl_div(prob_s, prob_t, reduction='batchmean', log_target=True) * (self.temperature ** 2) + + # Combine losses with weights + total_loss = self.task_loss_weight * task_loss + self.distill_loss_weight * kd_loss + + return { + 'loss': total_loss, + 'output': student_logits, + 'task_loss': task_loss, + 'kd_loss': kd_loss, + } + + +class FeatureDistillationTrainableModule(nn.Module): + """Trainable module for feature distillation. + + Wraps student model and projection layer into a single module where all + trainable forward operations happen inside forward(). This ensures proper + DDP wrapping when the module is used with DistributedDataParallel. + + Args: + student_model: Student model to train + projection: Optional projection layer (Linear layer or None) + + Returns: + Tuple of (logits, projected_features) + """ + + def __init__( + self, + student_model: nn.Module, + projection: Optional[nn.Module] = None, + ): + super().__init__() + self.student = student_model + self.projection = projection + + def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass through student and projection. + + Args: + input: Input tensor [B, C, H, W] + + Returns: + Tuple of (student_logits, projected_features) + """ + # Extract features and compute logits + feature_map = self.student.forward_features(input) + student_logits = self.student.forward_head(feature_map) + student_features = self.student.forward_head(feature_map, pre_logits=True) + + # Apply projection if needed + if self.projection is not None: + student_features = self.projection(student_features) + + return student_logits, student_features + + +class FeatureDistillationTask(TrainingTask): + """Feature-based knowledge distillation task. + + Performs distillation by matching student and teacher intermediate features + (pooled pre-logits) using MSE loss. Automatically creates a projection layer + if student and teacher feature dimensions differ. + + Loss weighting supports two modes: + 1. Independent weights: loss = task_loss_weight * task_loss + distill_loss_weight * distill_loss + 2. 
Complementary mode: loss = task_loss_weight * task_loss + (1 - task_loss_weight) * distill_loss + (used when only task_loss_weight is specified) + + Args: + student_model: Student model to train + teacher: Pre-configured teacher model wrapper + criterion: Task loss function (e.g., CrossEntropyLoss) + distill_loss_weight: Weight for distillation loss + task_loss_weight: Weight for task loss + student_feature_dim: Student pre-logits dimension (auto-detected if None) + teacher_feature_dim: Teacher pre-logits dimension (auto-detected if None) + device: Device for task tensors/buffers + dtype: Dtype for task tensors/buffers + verbose: Enable info logging + + Example: + >>> # Independent weights + >>> task = FeatureDistillationTask( + ... student_model=model, teacher=teacher, criterion=nn.CrossEntropyLoss(), + ... distill_loss_weight=5.0, task_loss_weight=1.0, + ... device=torch.device('cuda'), + ... ) + >>> # Complementary mode + >>> task = FeatureDistillationTask( + ... student_model=model, teacher=teacher, criterion=nn.CrossEntropyLoss(), + ... task_loss_weight=0.3, + ... device=torch.device('cuda'), + ... ) + """ + + def __init__( + self, + student_model: nn.Module, + teacher: DistillationTeacher, + criterion: nn.Module, + distill_loss_weight: Optional[float] = None, + task_loss_weight: Optional[float] = None, + student_feature_dim: Optional[int] = None, + teacher_feature_dim: Optional[int] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + verbose: bool = True, + ): + super().__init__(device=device, dtype=dtype, verbose=verbose) + self.teacher = teacher + self.criterion = criterion + + # Determine weighting mode + if distill_loss_weight is not None: + # Mode 1: distill_weight specified - independent weights (task defaults to 1.0 if not set) + self.distill_loss_weight = distill_loss_weight + self.task_loss_weight = task_loss_weight if task_loss_weight is not None else 1.0 + if self.verbose: + _logger.info( + f"FeatureDistillationTask: Independent weights - " + f"task_weight={self.task_loss_weight}, distill_weight={distill_loss_weight}" + ) + elif task_loss_weight is not None: + # Mode 2: Only task_weight specified - complementary mode + self.task_loss_weight = task_loss_weight + self.distill_loss_weight = 1.0 - task_loss_weight + if self.verbose: + _logger.info( + f"FeatureDistillationTask: Complementary mode - " + f"task_weight={task_loss_weight}, distill_weight={self.distill_loss_weight}" + ) + else: + # Neither specified - use defaults (equal weighting) + self.distill_loss_weight = 1.0 + self.task_loss_weight = 1.0 + if self.verbose: + _logger.info( + f"FeatureDistillationTask: Default equal weights - " + f"task_weight={self.task_loss_weight}, distill_weight={self.distill_loss_weight}" + ) + + # Auto-detect feature dimensions if not provided + if student_feature_dim is None: + student_feature_dim = self._detect_feature_dim(student_model) + if teacher_feature_dim is None: + teacher_feature_dim = self._detect_feature_dim(teacher.model) + + # Create projection layer if dimensions differ + projection = None + if student_feature_dim != teacher_feature_dim: + if self.verbose: + _logger.info( + f"Creating projection layer: {student_feature_dim} -> {teacher_feature_dim}" + ) + projection = nn.Linear(student_feature_dim, teacher_feature_dim, device=self.device, dtype=self.dtype) + else: + if self.verbose: + _logger.info("Feature dimensions match, no projection needed") + + # Create trainable module wrapping student and projection + self.trainable_module = 
FeatureDistillationTrainableModule(student_model, projection) + + # Register student normalization values as non-persistent buffers + # Shape: [1, 3, 1, 1] for proper broadcasting over BCHW images + student_unwrapped = unwrap_model(student_model) + student_mean = torch.tensor( + student_unwrapped.pretrained_cfg['mean'], + device=self.device, + dtype=self.dtype, + ).view(1, -1, 1, 1) + student_std = torch.tensor( + student_unwrapped.pretrained_cfg['std'], + device=self.device, + dtype=self.dtype, + ).view(1, -1, 1, 1) + self.register_buffer('student_mean', student_mean, persistent=False) + self.register_buffer('student_std', student_std, persistent=False) + + if self.verbose: + _logger.info( + f"FeatureDistillationTask: " + f"student_dim={student_feature_dim}, teacher_dim={teacher_feature_dim}" + ) + + @staticmethod + def _detect_feature_dim(model: nn.Module) -> int: + """Auto-detect feature dimension from model. + + Tries head_hidden_size first (pre-logits dimension), then num_features. + """ + # Unwrap DDP/EMA wrapper if present + model = unwrap_model(model) + + if hasattr(model, 'head_hidden_size'): + return model.head_hidden_size + elif hasattr(model, 'num_features'): + return model.num_features + else: + raise ValueError( + "Cannot auto-detect feature dimension. Model must have " + "'head_hidden_size' or 'num_features' attribute, or you must " + "specify student_feature_dim and teacher_feature_dim explicitly." + ) + + def prepare_distributed( + self, + device_ids: Optional[list] = None, + **ddp_kwargs + ) -> 'FeatureDistillationTask': + """Prepare task for distributed training. + + Wraps the trainable module (student + projection) in DistributedDataParallel (DDP) + while leaving the frozen teacher model unwrapped. + + Args: + device_ids: List of device IDs for DDP (e.g., [local_rank]) + **ddp_kwargs: Additional arguments passed to DistributedDataParallel + + Returns: + self (for method chaining) + """ + from torch.nn.parallel import DistributedDataParallel as DDP + + # Ensure teacher parameters are frozen + for param in self.teacher.parameters(): + param.requires_grad = False + + # Wrap trainable module (student + projection) in DDP + self.trainable_module = DDP(self.trainable_module, device_ids=device_ids, **ddp_kwargs) + return self + + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + """Forward pass with feature distillation. 
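The feature-matching term computed in this forward pass is a plain MSE between the teacher's pooled pre-logits and the (optionally projected) student features; a standalone sketch with example feature widths:

# Illustrative only: project student features to the teacher width, then MSE.
import torch
import torch.nn as nn
import torch.nn.functional as F

student_feats = torch.randn(8, 768)    # example student pre-logits width
teacher_feats = torch.randn(8, 1024)   # example teacher pre-logits width
proj = nn.Linear(768, 1024)            # created automatically when widths differ

kd_loss = F.mse_loss(proj(student_feats), teacher_feats)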
+ + Args: + input: Input tensor [B, C, H, W] + target: Target labels [B] + + Returns: + Dictionary containing: + - 'loss': Combined training loss (task + distillation) + - 'output': Student logits + - 'task_loss': Classification loss component + - 'kd_loss': Feature distillation loss component + """ + # Student forward pass through trainable module (student + projection) + student_logits, student_features = self.trainable_module(input) + + # Compute task loss + task_loss = self.criterion(student_logits, target) + + # Teacher forward pass (no gradient) + with torch.no_grad(): + input_kd = self.teacher.normalize_input(input, self.student_mean, self.student_std) + teacher_features = self.teacher(input_kd.detach(), return_features=True) + + # Compute feature distillation loss (MSE) + kd_loss = F.mse_loss(student_features, teacher_features) + + # Combine losses with weights + total_loss = self.task_loss_weight * task_loss + self.distill_loss_weight * kd_loss + + return { + 'loss': total_loss, + 'output': student_logits, + 'task_loss': task_loss, + 'kd_loss': kd_loss, + } diff --git a/timm/task/task.py b/timm/task/task.py new file mode 100644 index 0000000000..719c58a600 --- /dev/null +++ b/timm/task/task.py @@ -0,0 +1,100 @@ +"""Base training task abstraction. + +This module provides the base TrainingTask class that encapsulates a complete +forward pass including loss computation. Tasks return a dictionary with loss +components and outputs for logging. +""" +from typing import Dict, Optional + +import torch +import torch.nn as nn + + +class TrainingTask(nn.Module): + """Base class for training tasks. + + A training task encapsulates a complete forward pass including loss computation. + Tasks return a dictionary containing the training loss and other components for logging. + + The returned dictionary must contain: + - 'loss': The training loss for backward pass (required) + - 'output': Model output/logits for metric computation (recommended) + - Other task-specific loss components for logging (optional) + + Args: + device: Device for task tensors/buffers (defaults to cpu) + dtype: Dtype for task tensors/buffers (defaults to torch default) + verbose: Enable info logging + + Example: + >>> task = SomeTask(model, criterion, device=torch.device('cuda')) + >>> + >>> # Prepare for distributed training (if needed) + >>> if distributed: + >>> task.prepare_distributed(device_ids=[local_rank]) + >>> + >>> # Training loop + >>> result = task(input, target) + >>> result['loss'].backward() + """ + + def __init__( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + verbose: bool = True, + ): + super().__init__() + self.device = device if device is not None else torch.device('cpu') + self.dtype = dtype if dtype is not None else torch.get_default_dtype() + self.verbose = verbose + + def to(self, *args, **kwargs): + """Move task to device/dtype, keeping self.device and self.dtype in sync.""" + dummy = torch.empty(0).to(*args, **kwargs) + self.device = dummy.device + self.dtype = dummy.dtype + return super().to(*args, **kwargs) + + def prepare_distributed( + self, + device_ids: Optional[list] = None, + **ddp_kwargs + ) -> 'TrainingTask': + """Prepare task for distributed training. + + This method wraps trainable components in DistributedDataParallel (DDP) + while leaving non-trainable components (like frozen teacher models) unwrapped. + + Should be called after task initialization but before training loop. 
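For tasks beyond classification and distillation, a subclass only needs to honour the dictionary contract above; a minimal, hypothetical example (class name and loss choice are illustrative):

# Illustrative only: a custom task returning the required 'loss' plus extras.
from typing import Dict

import torch
import torch.nn as nn
from timm.task import TrainingTask

class SmoothedClassificationTask(TrainingTask):   # hypothetical subclass
    def __init__(self, model: nn.Module, label_smoothing: float = 0.1, **kwargs):
        super().__init__(**kwargs)
        self.model = model
        self.criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
        output = self.model(input)
        return {'loss': self.criterion(output, target), 'output': output}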
+ + Args: + device_ids: List of device IDs for DDP (e.g., [local_rank]) + **ddp_kwargs: Additional arguments passed to DistributedDataParallel + + Returns: + self (for method chaining) + + Example: + >>> task = LogitDistillationTask(student, teacher, criterion) + >>> task.prepare_distributed(device_ids=[args.local_rank]) + >>> task = torch.compile(task) # Compile after DDP + """ + # Default implementation - subclasses override if they need DDP + return self + + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + """Perform forward pass and compute loss. + + Args: + input: Input tensor [B, C, H, W] + target: Target labels [B] + + Returns: + Dictionary with at least 'loss' key containing the training loss + """ + raise NotImplementedError diff --git a/train.py b/train.py index 0dacdc1dc2..efa7725506 100755 --- a/train.py +++ b/train.py @@ -30,7 +30,6 @@ import torch.nn as nn import torchvision.utils import yaml -from torch.nn.parallel import DistributedDataParallel as NativeDDP from timm import utils from timm.data import create_dataset, create_loader, create_naflex_loader, resolve_data_config, \ @@ -40,15 +39,8 @@ from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters from timm.optim import create_optimizer_v2, optimizer_kwargs from timm.scheduler import create_scheduler_v2, scheduler_kwargs -from timm.utils import ApexScaler, NativeScaler - -try: - from apex import amp - from apex.parallel import DistributedDataParallel as ApexDDP - from apex.parallel import convert_syncbn_model - has_apex = True -except ImportError: - has_apex = False +from timm.utils import NativeScaler +from timm.task import DistillationTeacher, ClassificationTask, LogitDistillationTask, FeatureDistillationTask try: @@ -173,11 +165,9 @@ group.add_argument('--device', default='cuda', type=str, help="Device (accelerator) to use.") group.add_argument('--amp', action='store_true', default=False, - help='use NVIDIA Apex AMP or Native AMP for mixed precision training') + help='use AMP for mixed precision training') group.add_argument('--amp-dtype', default='float16', type=str, help='lower precision AMP dtype (default: float16)') -group.add_argument('--amp-impl', default='native', type=str, - help='AMP impl to use, "native" or "apex" (default: native)') group.add_argument('--model-dtype', default=None, type=str, help='Model dtype override (non-AMP) (default: float32)') group.add_argument('--no-ddp-bb', action='store_true', default=False, @@ -345,7 +335,7 @@ group.add_argument('--bn-eps', type=float, default=None, help='BatchNorm epsilon override (if not None)') group.add_argument('--sync-bn', action='store_true', - help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') + help='Enable synchronized BatchNorm.') group.add_argument('--dist-bn', type=str, default='reduce', help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') group.add_argument('--split-bn', action='store_true', @@ -417,6 +407,25 @@ group.add_argument('--naflex-loss-scale', default='linear', type=str, help='Scale loss (gradient) by batch_size ("none", "sqrt", or "linear")') +# Knowledge Distillation parameters +parser.add_argument('--kd-model-name', default=None, type=str, + help='Name of teacher model for knowledge distillation') +parser.add_argument('--kd-distill-type', default='logit', type=str, choices=['logit', 'feature'], + help='Type of distillation: "logit" for output distillation, "feature" for intermediate 
features (default: logit)') +parser.add_argument('--kd-loss-type', default='kl', type=str, + help='Loss function for logit distillation (default: kl). Currently only "kl" supported, reserved for future extensions.') +parser.add_argument('--distill-loss-weight', default=None, type=float, + help='Weight for distillation loss. If both weights specified: loss = task_weight * task + distill_weight * distill. ' + 'If only task_weight: loss = task_weight * task + (1-task_weight) * distill. Default: 1.0 if only this specified.') +parser.add_argument('--task-loss-weight', default=None, type=float, + help='Weight for task (classification) loss. See --distill-loss-weight for weighting modes. Default: 1.0 if unspecified.') +parser.add_argument('--kd-temperature', default=4.0, type=float, + help='Temperature for softmax in distillation (default: 4.0, typical range: 1-4)') +parser.add_argument('--kd-student-feature-dim', default=None, type=int, + help='Student model feature dimension (auto-detected from model.head_hidden_size or model.num_features if not specified)') +parser.add_argument('--kd-teacher-feature-dim', default=None, type=int, + help='Teacher model feature dimension (auto-detected from model.head_hidden_size or model.num_features if not specified)') + def _parse_args(): # Do we have a config file to parse? @@ -465,18 +474,11 @@ def main(): if model_dtype == torch.float16: _logger.warning('float16 is not recommended for training, for half precision bfloat16 is recommended.') - # resolve AMP arguments based on PyTorch / Apex availability - use_amp = None + # resolve AMP arguments based on PyTorch availability amp_dtype = torch.float16 if args.amp: assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP' - if args.amp_impl == 'apex': - assert has_apex, 'AMP impl specified as APEX but APEX is not installed.' - use_amp = 'apex' - assert args.amp_dtype == 'float16' - else: - use_amp = 'native' - assert args.amp_dtype in ('float16', 'bfloat16') + assert args.amp_dtype in ('float16', 'bfloat16') if args.amp_dtype == 'bfloat16': amp_dtype = torch.bfloat16 @@ -531,6 +533,9 @@ def main(): if args.grad_checkpointing: model.set_grad_checkpointing(enable=True) + # Create training task (classification or distillation) + task = None + if utils.is_primary(args): _logger.info( f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}') @@ -557,12 +562,7 @@ def main(): if args.distributed and args.sync_bn: args.dist_bn = '' # disable dist_bn when sync BN active assert not args.split_bn - if has_apex and use_amp == 'apex': - # Apex SyncBN used with Apex AMP - # WARNING this won't currently work with models using BatchNormAct2d - model = convert_syncbn_model(model) - else: - model = convert_sync_batchnorm(model) + model = convert_sync_batchnorm(model) if utils.is_primary(args): _logger.info( 'Converted model to use Synchronized BatchNorm. 
WARNING: You may have issues if using ' @@ -575,7 +575,6 @@ def main(): if args.torchscript: assert not args.torchcompile - assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' model = torch.jit.script(model) @@ -609,13 +608,7 @@ def main(): # setup automatic mixed-precision (AMP) loss scaling and op casting amp_autocast = suppress # do nothing loss_scaler = None - if use_amp == 'apex': - assert device.type == 'cuda' - model, optimizer = amp.initialize(model, optimizer, opt_level='O1') - loss_scaler = ApexScaler() - if utils.is_primary(args): - _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') - elif use_amp == 'native': + if args.amp: amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype) if device.type in ('cuda',) and amp_dtype == torch.float16: # loss scaler only used for float16 (half) dtype, bfloat16 does not need it @@ -656,24 +649,6 @@ def main(): mode=args.torchcompile_mode, ) - # setup distributed training - if args.distributed: - if has_apex and use_amp == 'apex': - # Apex DDP preferred unless native amp is activated - if utils.is_primary(args): - _logger.info("Using NVIDIA APEX DistributedDataParallel.") - model = ApexDDP(model, delay_allreduce=True) - else: - if utils.is_primary(args): - _logger.info("Using native Torch DistributedDataParallel.") - model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb) - # NOTE: EMA model does not need to be wrapped by DDP - - if args.torchcompile: - # torch compile should be done after DDP - assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.' - model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode) - # create the train and eval datasets if args.data and not args.data_dir: args.data_dir = args.data @@ -907,6 +882,69 @@ def main(): train_loss_fn = train_loss_fn.to(device=device) validate_loss_fn = nn.CrossEntropyLoss().to(device=device) + # Setup training task (classification or distillation) + if args.kd_model_name is not None: + # Create teacher model + teacher = DistillationTeacher( + model_name=args.kd_model_name, + num_classes=args.num_classes, + in_chans=in_chans, + device=device, + dtype=model_dtype, + ) + + # Create distillation task + if args.kd_distill_type == 'logit': + task = LogitDistillationTask( + student_model=model, + teacher=teacher, + criterion=train_loss_fn, + loss_type=args.kd_loss_type, + distill_loss_weight=args.distill_loss_weight, + task_loss_weight=args.task_loss_weight, + temperature=args.kd_temperature, + device=device, + dtype=model_dtype, + verbose=utils.is_primary(args), + ) + elif args.kd_distill_type == 'feature': + task = FeatureDistillationTask( + student_model=model, + teacher=teacher, + criterion=train_loss_fn, + distill_loss_weight=args.distill_loss_weight, + task_loss_weight=args.task_loss_weight, + student_feature_dim=args.kd_student_feature_dim, + teacher_feature_dim=args.kd_teacher_feature_dim, + device=device, + dtype=model_dtype, + verbose=utils.is_primary(args), + ) + else: + raise ValueError(f"Unknown distillation type: {args.kd_distill_type}") + else: + # Standard classification task + task = ClassificationTask( + model=model, + criterion=train_loss_fn, + device=device, + dtype=model_dtype, + verbose=utils.is_primary(args), + ) + + # Prepare task for distributed training + if args.distributed: + if utils.is_primary(args): + 
_logger.info("Preparing task for distributed training") + task.prepare_distributed(device_ids=[device]) + + # Compile task if requested (should be done after DDP) + if args.torchcompile: + assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.' + if utils.is_primary(args): + _logger.info(f"Compiling task with backend={args.torchcompile}, mode={args.torchcompile_mode}") + task = torch.compile(task, backend=args.torchcompile, mode=args.torchcompile_mode) + # setup checkpoint saver and eval metric tracking eval_metric = args.eval_metric if loader_eval is not None else 'loss' decreasing_metric = eval_metric == 'loss' @@ -995,8 +1033,8 @@ def main(): model, loader_train, optimizer, - train_loss_fn, args, + task=task, device=device, lr_scheduler=lr_scheduler, saver=saver, @@ -1091,6 +1129,9 @@ def main(): except KeyboardInterrupt: pass + if args.distributed: + torch.distributed.destroy_process_group() + if best_metric is not None: # log best metric as tracked by checkpoint saver _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) @@ -1110,8 +1151,8 @@ def train_one_epoch( model, loader, optimizer, - loss_fn, args, + task=None, device=torch.device('cuda'), lr_scheduler=None, saver=None, @@ -1167,11 +1208,13 @@ def train_one_epoch( def _forward(): with amp_autocast(): - output = model(input) - _loss = loss_fn(output, target) + # Task handles the complete forward pass and loss computation + result = task(input, target) + _loss = result['loss'] + if accum_steps > 1: _loss /= accum_steps - return _loss + return _loss, result def _backward(_loss): if loss_scaler is not None: @@ -1220,13 +1263,13 @@ def _backward(_loss): if has_no_sync and not need_update: with model.no_sync(): - loss = _forward() + loss, result = _forward() scaled_loss = local_scale * loss if dist_scale is not None: scaled_loss *= dist_scale _backward(scaled_loss) else: - loss = _forward() + loss, result = _forward() scaled_loss = local_scale * loss if dist_scale is not None: scaled_loss *= dist_scale @@ -1238,10 +1281,10 @@ def _backward(_loss): if has_no_sync and not need_update: with model.no_sync(): - loss = _forward() + loss, result = _forward() _backward(loss) else: - loss = _forward() + loss, result = _forward() _backward(loss) losses_m.update(loss.item() * accum_steps, batch_size) diff --git a/validate.py b/validate.py index 75657a764d..03c572929d 100755 --- a/validate.py +++ b/validate.py @@ -28,11 +28,6 @@ from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \ decay_batch_step, check_batch_size_retry, ParseKwargs, reparameterize_model -try: - from apex import amp - has_apex = True -except ImportError: - has_apex = False try: from functorch.compile import memory_efficient_fusion @@ -124,11 +119,9 @@ parser.add_argument('--device', default='cuda', type=str, help="Device (accelerator) to use.") parser.add_argument('--amp', action='store_true', default=False, - help='use NVIDIA Apex AMP or Native AMP for mixed precision training') + help='use Native AMP for mixed precision inference') parser.add_argument('--amp-dtype', default='float16', type=str, help='lower precision AMP dtype (default: float16)') -parser.add_argument('--amp-impl', default='native', type=str, - help='AMP impl to use, "native" or "apex" (default: native)') parser.add_argument('--model-dtype', default=None, type=str, help='Model dtype override (non-AMP) (default: float32)') parser.add_argument('--tf-preprocessing', action='store_true', 
default=False, @@ -197,22 +190,14 @@ def validate(args): assert args.model_dtype in ('float32', 'float16', 'bfloat16') model_dtype = getattr(torch, args.model_dtype) - # resolve AMP arguments based on PyTorch / Apex availability - use_amp = None + # resolve AMP arguments based on PyTorch availability amp_autocast = suppress if args.amp: assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP' - if args.amp_impl == 'apex': - assert has_apex, 'AMP impl specified as APEX but APEX is not installed.' - assert args.amp_dtype == 'float16' - use_amp = 'apex' - _logger.info('Validating in mixed precision with NVIDIA APEX AMP.') - else: - assert args.amp_dtype in ('float16', 'bfloat16') - use_amp = 'native' - amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16 - amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype) - _logger.info('Validating in mixed precision with native PyTorch AMP.') + assert args.amp_dtype in ('float16', 'bfloat16') + amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16 + amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype) + _logger.info('Validating in mixed precision with native PyTorch AMP.') else: _logger.info(f'Validating in {model_dtype or torch.float32}. AMP not enabled.') @@ -266,7 +251,6 @@ def validate(args): model = model.to(memory_format=torch.channels_last) if args.torchscript: - assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' model = torch.jit.script(model) elif args.torchcompile: assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.' @@ -276,9 +260,6 @@ def validate(args): assert has_functorch, "functorch is needed for --aot-autograd" model = memory_efficient_fusion(model) - if use_amp == 'apex': - model = amp.initialize(model, opt_level='O1') - if args.num_gpu > 1: model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
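Taken together, train.py now builds one of the tasks from the --kd-* arguments (or ClassificationTask by default), calls task.prepare_distributed(...) under --distributed, optionally wraps the task in torch.compile, and drives the epoch loop through the returned dictionary. A simplified sketch of that inner loop, assuming amp_autocast is either contextlib.suppress or the torch.autocast partial set up above:

# Illustrative only: simplified task-driven inner loop mirroring train_one_epoch.
import torch

def run_steps(task, loader, optimizer, amp_autocast, device):
    for input, target in loader:
        input = input.to(device)
        target = target.to(device)
        with amp_autocast():
            result = task(input, target)   # 'loss', 'output'; KD tasks add 'task_loss', 'kd_loss'
        result['loss'].backward()
        optimizer.step()
        optimizer.zero_grad()
        if 'kd_loss' in result:
            task_loss = result['task_loss'].item()   # e.g. feed into an AverageMeter / logger
            kd_loss = result['kd_loss'].item()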