benchmark.py: 0 additions, 7 deletions

@@ -25,13 +25,6 @@
 from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs,\
     reparameterize_model
 
-has_apex = False
-try:
-    from apex import amp
-    has_apex = True
-except ImportError:
-    pass
-
 try:
     from deepspeed.profiling.flops_profiler import get_model_profile
     has_deepspeed_profiling = True
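With the `has_apex` flag and its import fallback deleted, the only mixed-precision path these scripts can take is native `torch.amp`. For reference, a minimal sketch of that stock recipe for a training-style step; the model, optimizer, and data below are illustrative stand-ins, not code from this diff:

import torch

# Native AMP step: autocast the forward, scale the loss for fp16 backward.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
use_amp = device.type == 'cuda'
model = torch.nn.Linear(16, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.amp.GradScaler(device.type, enabled=use_amp)

input = torch.randn(8, 16, device=device)
target = torch.randint(0, 4, (8,), device=device)

with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
    loss = torch.nn.functional.cross_entropy(model(input), target)
scaler.scale(loss).backward()  # scale up to avoid fp16 gradient underflow
scaler.step(optimizer)         # unscales gradients, then optimizer.step()
scaler.update()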
inference.py: 1 addition, 7 deletions

@@ -23,12 +23,6 @@
 from timm.models import create_model
 from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs
 
-try:
-    from apex import amp
-    has_apex = True
-except ImportError:
-    has_apex = False
-
 try:
     from functorch.compile import memory_efficient_fusion
     has_functorch = True

@@ -170,7 +164,7 @@ def main():
         assert args.model_dtype in ('float32', 'float16', 'bfloat16')
         model_dtype = getattr(torch, args.model_dtype)
 
-    # resolve AMP arguments based on PyTorch / Apex availability
+    # resolve AMP arguments based on PyTorch availability
     amp_autocast = suppress
     if args.amp:
         assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
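The second hunk is truncated right where `amp_autocast` gets resolved. For reference, a sketch of how the resolution typically continues once only native AMP exists; the `args.amp_dtype` handling and the stand-in namespace are assumptions, not lines from this diff:

import argparse
from contextlib import suppress
from functools import partial

import torch

# Illustrative stand-ins; in inference.py these come from the CLI parser
# and device resolution earlier in main().
args = argparse.Namespace(amp=True, amp_dtype='float16')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# With Apex gone, --amp can only mean native torch.amp: amp_autocast is
# either a no-op suppress() or a partial binding of torch.autocast.
amp_autocast = suppress
if args.amp:
    amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
    amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)

with amp_autocast():
    pass  # the model forward runs under this context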
timm/task/__init__.py: 17 additions, 0 deletions

@@ -0,0 +1,17 @@
+"""Training task abstractions for timm.
+
+This module provides task-based abstractions for training loops where each task
+encapsulates both the forward pass and loss computation, returning a dictionary
+with loss components and outputs for logging.
+"""
+from .task import TrainingTask
+from .classification import ClassificationTask
+from .distillation import DistillationTeacher, LogitDistillationTask, FeatureDistillationTask
+
+__all__ = [
+    'TrainingTask',
+    'ClassificationTask',
+    'DistillationTeacher',
+    'LogitDistillationTask',
+    'FeatureDistillationTask',
+]
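Given the contract the module docstring describes (each task call returns a dictionary with loss components and outputs for logging), a minimal sketch of a training step driven by a task. The toy model, data, and optimizer are illustrative; only the `task(input, target)` -> `result['loss']` contract comes from this PR:

import torch
import torch.nn as nn

from timm.task import ClassificationTask

# Toy setup; any model/criterion pair is driven the same way.
model = nn.Linear(16, 4)
task = ClassificationTask(model, nn.CrossEntropyLoss(), verbose=False)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

input = torch.randn(8, 16)
target = torch.randint(0, 4, (8,))

# One step: the task fuses forward pass and loss computation, returning a
# dict so extra terms (e.g. distillation losses) can be logged alongside.
result = task(input, target)
optimizer.zero_grad()
result['loss'].backward()
optimizer.step()
print(result['output'].shape)  # model logits, shape [8, 4]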
timm/task/classification.py: 90 additions, 0 deletions

@@ -0,0 +1,90 @@
+"""Classification training task."""
+import logging
+from typing import Callable, Dict, Optional, Union
+
+import torch
+import torch.nn as nn
+
+from .task import TrainingTask
+
+_logger = logging.getLogger(__name__)
+
+
+class ClassificationTask(TrainingTask):
+    """Standard supervised classification task.
+
+    Simple task that performs a forward pass through the model and computes
+    the classification loss.
+
+    Args:
+        model: The model to train
+        criterion: Loss function (e.g., CrossEntropyLoss)
+        device: Device for task tensors/buffers
+        dtype: Dtype for task tensors/buffers
+        verbose: Enable info logging
+
+    Example:
+        >>> task = ClassificationTask(model, nn.CrossEntropyLoss(), device=torch.device('cuda'))
+        >>> result = task(input, target)
+        >>> result['loss'].backward()
+    """
+
+    def __init__(
+            self,
+            model: nn.Module,
+            criterion: Union[nn.Module, Callable],
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
+            verbose: bool = True,
+    ):
+        super().__init__(device=device, dtype=dtype, verbose=verbose)
+        self.model = model
+        self.criterion = criterion
+
+        if self.verbose:
+            loss_name = getattr(criterion, '__name__', None) or type(criterion).__name__
+            _logger.info(f"ClassificationTask: criterion={loss_name}")
+
+    def prepare_distributed(
+            self,
+            device_ids: Optional[list] = None,
+            **ddp_kwargs
+    ) -> 'ClassificationTask':
+        """Prepare task for distributed training.
+
+        Wraps the model in DistributedDataParallel (DDP).
+
+        Args:
+            device_ids: List of device IDs for DDP (e.g., [local_rank])
+            **ddp_kwargs: Additional arguments passed to DistributedDataParallel
+
+        Returns:
+            self (for method chaining)
+        """
+        from torch.nn.parallel import DistributedDataParallel as DDP
+        self.model = DDP(self.model, device_ids=device_ids, **ddp_kwargs)
+        return self
+
+    def forward(
+            self,
+            input: torch.Tensor,
+            target: torch.Tensor,
+    ) -> Dict[str, torch.Tensor]:
+        """Forward pass through model and compute classification loss.
+
+        Args:
+            input: Input tensor [B, C, H, W]
+            target: Target labels [B]
+
+        Returns:
+            Dictionary containing:
+                - 'loss': Classification loss
+                - 'output': Model logits
+        """
+        output = self.model(input)
+        loss = self.criterion(output, target)
+
+        return {
+            'loss': loss,
+            'output': output,
+        }
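Because `prepare_distributed()` returns `self`, DDP setup chains onto construction. A hypothetical multi-process launch sketch (e.g. under torchrun); the process-group boilerplate is standard `torch.distributed` usage, not code from this PR:

import torch
import torch.distributed as dist
import torch.nn as nn

from timm.task import ClassificationTask

# Standard distributed init; only prepare_distributed() and its chained
# return value come from this PR.
dist.init_process_group(backend='nccl')
local_rank = dist.get_rank() % torch.cuda.device_count()
device = torch.device('cuda', local_rank)
torch.cuda.set_device(device)

model = nn.Linear(16, 4).to(device)
task = ClassificationTask(
    model,
    nn.CrossEntropyLoss(),
    device=device,
).prepare_distributed(device_ids=[local_rank])  # wraps task.model in DDP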