From 9583cc81c47a269996bb6d2f311cdcc815382dc0 Mon Sep 17 00:00:00 2001 From: wcrzlh Date: Mon, 23 Oct 2023 14:43:18 +0800 Subject: [PATCH] feat: add model script, training configs and trained weights of cait --- configs/cait/README.md | 88 ++++++++++++++++++++++++++++++++ configs/cait/cait_s24_224.yaml | 61 ++++++++++++++++++++++ configs/cait/cait_s36_224.yaml | 61 ++++++++++++++++++++++ configs/cait/cait_xs24_224.yaml | 61 ++++++++++++++++++++++ configs/cait/cait_xxs24_224.yaml | 61 ++++++++++++++++++++++ mindcv/models/cait.py | 44 ++++++++++------ 6 files changed, 361 insertions(+), 15 deletions(-) create mode 100644 configs/cait/README.md create mode 100644 configs/cait/cait_s24_224.yaml create mode 100644 configs/cait/cait_s36_224.yaml create mode 100644 configs/cait/cait_xs24_224.yaml create mode 100644 configs/cait/cait_xxs24_224.yaml diff --git a/configs/cait/README.md b/configs/cait/README.md new file mode 100644 index 000000000..1dd124c4d --- /dev/null +++ b/configs/cait/README.md @@ -0,0 +1,88 @@ +# Going deeper with Image Transformers + +> [Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239) + +## Introduction + +CaiT is built based on ViT but made two contributions to improve model performance. +Firstly, Layerscale is introduced to facilitate the convergence. +Secondly, class-attention offers a more effective processing of the class embedding. +By combing these parts, Cait could get a SOTA performance on ImageNet-1K dataset. + + +## Results + +Our reproduced model performance on ImageNet-1K is reported as follows. + +
+ +| Model | Context | Top-1 (%) | Top-5 (%) | Params(M) | Recipe | Download | +|----------------| -------- |----------|-----------|-----------|--------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------| +| cait_xxs24_224 | D910x8-G | 77.71 | 94.10 | 11.94 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/cait/cait_xxs24_224.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/cait/cait_xxs24-31b307a8.ckpt) | +| cait_xs24_224 | D910x8-G | 81.29 | 95.60 | 26.53 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/cait/cait_xs24_224.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/cait/cait_xs24-ba0c2053.ckpt) | +| cait_s24_224 | D910x8-G | 82.25 | 95.95 | 46.88 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/cait/cait_s24_224.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/cait/cait_s24-0a06be71.ckpt) | +| cait_s36_224 | D910x8-G | 82.11 | 95.84 | 68.16 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/cait/cait_s36_224.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/cait/cait_s36-2e42bfc8.ckpt) | + + +
+ +#### Notes + +- Context: Training context denoted as {device}x{pieces}-{MS mode}, where mindspore mode can be G - graph mode or F - pynative mode with ms function. For example, D910x8-G is for training on 8 pieces of Ascend 910 NPU using graph mode. +- Top-1 and Top-5: Accuracy reported on the validation set of ImageNet-1K. + +## Quick Start + +### Preparation + +#### Installation + +Please refer to the [installation instruction](https://github.com/mindspore-lab/mindcv#installation) in MindCV. + +#### Dataset Preparation + +Please download the [ImageNet-1K](https://www.image-net.org/challenges/LSVRC/2012/index.php) dataset for model training and validation. + +### Training + +* Distributed Training + +It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please run + +```shell +# distributed training on multiple GPU/Ascend devices +mpirun -n 8 python train.py --config configs/cait/cait_xxs24_224.yaml --data_dir /path/to/imagenet +``` +> If the script is executed by the root user, the `--allow-run-as-root` parameter must be added to `mpirun`. + +Similarly, you can train the model on multiple GPU devices with the above `mpirun` command. + +For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindcv/blob/main/config.py). + +**Note:** As the global batch size (batch_size x num_devices) is an important hyper-parameter, it is recommended to keep the global batch size unchanged for reproduction or adjust the learning rate linearly to a new global batch size. + +* Standalone Training + +If you want to train or finetune the model on a smaller dataset without distributed training, please run: + +```shell +# standalone training on a CPU/GPU/Ascend device +python train.py --config configs/cait/cait_xxs24_224.yaml --data_dir /path/to/dataset --distribute False +``` + +### Validation + +To validate the accuracy of the trained model, you can use `validate.py` and parse the checkpoint path with `--ckpt_path`. + +``` +python validate.py -c configs/cait/cait_xxs24_224.yaml --data_dir /path/to/imagenet --ckpt_path /path/to/ckpt +``` + +### Deployment + +Please refer to the [deployment tutorial](https://mindspore-lab.github.io/mindcv/tutorials/deployment/). + +## References + + +[1] Touvron H, Cord M, Sablayrolles A, et al. Going deeper with image transformers[C]//Proceedings of the IEEE/CVF international conference on computer vision. 2021: 32-42. diff --git a/configs/cait/cait_s24_224.yaml b/configs/cait/cait_s24_224.yaml new file mode 100644 index 000000000..a786c7946 --- /dev/null +++ b/configs/cait/cait_s24_224.yaml @@ -0,0 +1,61 @@ +# system config +mode: 0 +distribute: True +num_parallel_workers: 8 +val_while_train: True + +# dataset config +dataset: 'imagenet' +data_dir: '/path/to/imagenet' +shuffle: True +dataset_download: False +batch_size: 64 +drop_remainder: True + +# augmentation config +image_resize: 224 +scale: [0.08, 1.0] +ratio: [0.75, 1.333] +hflip: 0.5 +vflip: 0. +interpolation: 'bicubic' +auto_augment: 'randaug-m9-mstd0.5-inc1' +re_prob: 0.25 +mixup: 0.8 +cutmix: 1 +color_jitter: 0.3 +crop_pct: 1.0 +ema: True +ema_decay: 0.99996 + +# model config +model: 'cait_s24_224' +num_classes: 1000 +pretrained: False +ckpt_path: '' +keep_checkpoint_max: 10 +ckpt_save_dir: '/cache/output/' +epoch_size: 400 +dataset_sink_mode: True +amp_level: 'O2' +drop_path_rate: 0.1 + +# loss config +loss: 'CE' +label_smoothing: 0.1 + +# lr scheduler config +scheduler: 'warmup_cosine_decay' +lr: 0.001 +min_lr: 0.000001 +warmup_epochs: 30 +decay_epochs: 370 +num_cycles: 2 + + +# optimizer config +opt: 'adamw' +weight_decay: 0.05 +filter_bias_and_bn: True +loss_scale: 1024 +use_nesterov: False diff --git a/configs/cait/cait_s36_224.yaml b/configs/cait/cait_s36_224.yaml new file mode 100644 index 000000000..1acc253db --- /dev/null +++ b/configs/cait/cait_s36_224.yaml @@ -0,0 +1,61 @@ +# system config +mode: 0 +distribute: True +num_parallel_workers: 8 +val_while_train: True + +# dataset config +dataset: 'imagenet' +data_dir: '/path/to/iamgenet' +shuffle: True +dataset_download: False +batch_size: 64 +drop_remainder: True + +# augmentation config +image_resize: 224 +scale: [0.08, 1.0] +ratio: [0.75, 1.333] +hflip: 0.5 +vflip: 0. +interpolation: 'bicubic' +auto_augment: 'randaug-m9-mstd0.5-inc1' +re_prob: 0.25 +mixup: 0.8 +cutmix: 1 +color_jitter: 0.3 +crop_pct: 1.0 +ema: True +ema_decay: 0.99996 + +# model config +model: 'cait_s36_224' +num_classes: 1000 +pretrained: False +ckpt_path: '' +keep_checkpoint_max: 10 +ckpt_save_dir: './ckpt' +epoch_size: 400 +dataset_sink_mode: True +amp_level: 'O2' +drop_path_rate: 0.1 + +# loss config +loss: 'CE' +label_smoothing: 0.1 + +# lr scheduler config +scheduler: 'warmup_cosine_decay' +lr: 0.002 +min_lr: 0.000001 +warmup_epochs: 30 +decay_epochs: 370 +num_cycles: 2 + + +# optimizer config +opt: 'adamw' +weight_decay: 0.05 +filter_bias_and_bn: True +loss_scale: 1024 +use_nesterov: False diff --git a/configs/cait/cait_xs24_224.yaml b/configs/cait/cait_xs24_224.yaml new file mode 100644 index 000000000..9f675aca4 --- /dev/null +++ b/configs/cait/cait_xs24_224.yaml @@ -0,0 +1,61 @@ +# system config +mode: 0 +distribute: True +num_parallel_workers: 8 +val_while_train: True + +# dataset config +dataset: 'imagenet' +data_dir: '/path/to/imagenet' +shuffle: True +dataset_download: False +batch_size: 64 +drop_remainder: True + +# augmentation config +image_resize: 224 +scale: [0.08, 1.0] +ratio: [0.75, 1.333] +hflip: 0.5 +vflip: 0. +interpolation: 'bicubic' +auto_augment: 'randaug-m9-mstd0.5-inc1' +re_prob: 0.25 +mixup: 0.8 +cutmix: 1 +color_jitter: 0.3 +crop_pct: 1.0 +ema: True +ema_decay: 0.99996 + +# model config +model: 'cait_xs24_224' +num_classes: 1000 # +pretrained: False # +ckpt_path: '' +keep_checkpoint_max: 10 +ckpt_save_dir: './ckpt' +epoch_size: 400 +dataset_sink_mode: True +amp_level: 'O2' +drop_path_rate: 0.1 + +# loss config +loss: 'CE' +label_smoothing: 0.1 + +# lr scheduler config +scheduler: 'warmup_cosine_decay' +lr: 0.001 +min_lr: 0.000001 +warmup_epochs: 40 +decay_epochs: 360 +num_cycles: 2 + + +# optimizer config +opt: 'adamw' +weight_decay: 0.05 +filter_bias_and_bn: True +loss_scale: 1024 +use_nesterov: False diff --git a/configs/cait/cait_xxs24_224.yaml b/configs/cait/cait_xxs24_224.yaml new file mode 100644 index 000000000..371792b2b --- /dev/null +++ b/configs/cait/cait_xxs24_224.yaml @@ -0,0 +1,61 @@ +# system config +mode: 0 +distribute: True +num_parallel_workers: 8 +val_while_train: True + +# dataset config +dataset: 'imagenet' +data_dir: '/path/to/dataset' +shuffle: True +dataset_download: False +batch_size: 128 +drop_remainder: True + +# augmentation config +image_resize: 224 +scale: [0.08, 1.0] +ratio: [0.75, 1.333] +hflip: 0.5 +vflip: 0. +interpolation: 'bicubic' +auto_augment: 'randaug-m9-mstd0.5-inc1' +re_prob: 0.25 +mixup: 0.8 +cutmix: 1 +color_jitter: 0.3 +crop_pct: 1.0 +ema: True +ema_decay: 0.99996 + +# model config +model: 'cait_xxs24_224' +num_classes: 1000 +pretrained: False +ckpt_path: '' +keep_checkpoint_max: 10 +ckpt_save_dir: './ckpt' +epoch_size: 500 +dataset_sink_mode: True +amp_level: 'O2' +drop_path_rate: 0.1 + +# loss config +loss: 'CE' +label_smoothing: 0.1 + +# lr scheduler config +scheduler: 'warmup_cosine_decay' +lr: 0.001 +min_lr: 0.000001 +warmup_epochs: 40 +decay_epochs: 460 +num_cycles: 2 + + +# optimizer config +opt: 'adamw' +weight_decay: 0.025 +filter_bias_and_bn: True +loss_scale: 1024 +use_nesterov: False diff --git a/mindcv/models/cait.py b/mindcv/models/cait.py index 40ca6ec13..7defe4e68 100644 --- a/mindcv/models/cait.py +++ b/mindcv/models/cait.py @@ -22,10 +22,10 @@ __all__ = [ "CaiT", "cait_xxs24_224", - "cait_xs24_384", + "cait_xs24_224", "cait_s24_224", "cait_s24_384", - "cait_s36_384", + "cait_s36_224", "cait_m36_384", "cait_m48_448", ] @@ -42,10 +42,10 @@ def _cfg(url='', **kwargs): default_cfgs = { "cait_xxs24_224": _cfg(url=''), - "cait_xs24_384": _cfg(url='', input_size=(3, 384, 384)), + "cait_xs24_224": _cfg(url=''), "cait_s24_224": _cfg(url=''), "cait_s24_384": _cfg(url='', input_size=(3, 384, 384)), - "cait_s36_384": _cfg(url='', input_size=(3, 384, 384)), + "cait_s36_224": _cfg(url=''), "cait_m36_384": _cfg(url='', input_size=(3, 384, 384)), "cait_m48_448": _cfg(url='', input_size=(3, 448, 448)), } @@ -156,7 +156,12 @@ def __init__(self, head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5 - self.qkv = nn.Dense(dim, dim * 3, has_bias=qkv_bias) + # self.qkv = nn.Dense(dim, dim * 3, has_bias=qkv_bias) + + self.q = nn.Dense(dim, dim, has_bias=qkv_bias) + self.k = nn.Dense(dim, dim, has_bias=qkv_bias) + self.v = nn.Dense(dim, dim, has_bias=qkv_bias) + self.attn_drop = Dropout(p=attn_drop_rate) self.proj = nn.Dense(dim, dim, has_bias=False) @@ -173,12 +178,21 @@ def __init__(self, def construct(self, x) -> Tensor: B, N, C = x.shape - qkv = ops.reshape(self.qkv(x), (B, N, 3, self.num_heads, C // self.num_heads)) - qkv = ops.transpose(qkv, (2, 0, 3, 1, 4)) - q, k, v = ops.unstack(qkv, axis=0) - q = ops.mul(q, self.scale) - - attn = self.q_matmul_k(q, k) + # qkv = ops.reshape(self.qkv(x), (B, N, 3, self.num_heads, C // self.num_heads)) + # qkv = ops.transpose(qkv, (2, 0, 3, 1, 4)) + # q, k, v = ops.unstack(qkv, axis=0) + # q = ops.mul(q, self.scale) + # + # attn = self.q_matmul_k(q, k) + + q = ops.reshape(self.q(x), (B, N, self.num_heads, C // self.num_heads)) + q = ops.transpose(q, (0, 2, 1, 3)) * self.scale + k = ops.reshape(self.k(x), (B, N, self.num_heads, C // self.num_heads)) + k = ops.transpose(k, (0, 2, 3, 1)) + v = ops.reshape(self.v(x), (B, N, self.num_heads, C // self.num_heads)) + v = ops.transpose(v, (0, 2, 1, 3)) + + attn = ops.BatchMatMul()(q, k) attn = ops.transpose(attn, (0, 2, 3, 1)) attn = self.proj_l(attn) @@ -369,8 +383,8 @@ def cait_xxs24_224(pretrained: bool = False, num_classes: int = 1000, in_channel @register_model -def cait_xs24_384(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs) -> CaiT: - model = CaiT(img_size=384, patch_size=16, in_channels=in_channels, num_classes=num_classes, +def cait_xs24_224(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs) -> CaiT: + model = CaiT(img_size=224, patch_size=16, in_channels=in_channels, num_classes=num_classes, embed_dim=288, depth=24, num_heads=6, mlp_ratio=4, qkv_bias=False, norm_layer=partial(nn.LayerNorm, epsilon=1e-6), init_values=1e-5, depth_token_only=2, **kwargs) @@ -405,8 +419,8 @@ def cait_s24_384(pretrained: bool = False, num_classes: int = 1000, in_channels= @register_model -def cait_s36_384(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs) -> CaiT: - model = CaiT(img_size=384, patch_size=16, in_channels=in_channels, num_classes=num_classes, +def cait_s36_224(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs) -> CaiT: + model = CaiT(img_size=224, patch_size=16, in_channels=in_channels, num_classes=num_classes, embed_dim=384, depth=36, num_heads=8, mlp_ratio=4, qkv_bias=False, norm_layer=partial(nn.LayerNorm, epsilon=1e-6), init_values=1e-6, depth_token_only=2, **kwargs)