diff --git a/configs/cait/README.md b/configs/cait/README.md
new file mode 100644
index 000000000..1dd124c4d
--- /dev/null
+++ b/configs/cait/README.md
@@ -0,0 +1,88 @@
+# Going deeper with Image Transformers
+
+> [Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239)
+
+## Introduction
+
+CaiT is built based on ViT but made two contributions to improve model performance.
+Firstly, Layerscale is introduced to facilitate the convergence.
+Secondly, class-attention offers a more effective processing of the class embedding.
+By combing these parts, Cait could get a SOTA performance on ImageNet-1K dataset.
+
+
+## Results
+
+Our reproduced model performance on ImageNet-1K is reported as follows.
+
+
+
+| Model | Context | Top-1 (%) | Top-5 (%) | Params(M) | Recipe | Download |
+|----------------| -------- |----------|-----------|-----------|--------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------|
+| cait_xxs24_224 | D910x8-G | 77.71 | 94.10 | 11.94 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/cait/cait_xxs24_224.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/cait/cait_xxs24-31b307a8.ckpt) |
+| cait_xs24_224 | D910x8-G | 81.29 | 95.60 | 26.53 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/cait/cait_xs24_224.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/cait/cait_xs24-ba0c2053.ckpt) |
+| cait_s24_224 | D910x8-G | 82.25 | 95.95 | 46.88 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/cait/cait_s24_224.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/cait/cait_s24-0a06be71.ckpt) |
+| cait_s36_224 | D910x8-G | 82.11 | 95.84 | 68.16 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/cait/cait_s36_224.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/cait/cait_s36-2e42bfc8.ckpt) |
+
+
+
+
+#### Notes
+
+- Context: Training context denoted as {device}x{pieces}-{MS mode}, where mindspore mode can be G - graph mode or F - pynative mode with ms function. For example, D910x8-G is for training on 8 pieces of Ascend 910 NPU using graph mode.
+- Top-1 and Top-5: Accuracy reported on the validation set of ImageNet-1K.
+
+## Quick Start
+
+### Preparation
+
+#### Installation
+
+Please refer to the [installation instruction](https://github.com/mindspore-lab/mindcv#installation) in MindCV.
+
+#### Dataset Preparation
+
+Please download the [ImageNet-1K](https://www.image-net.org/challenges/LSVRC/2012/index.php) dataset for model training and validation.
+
+### Training
+
+* Distributed Training
+
+It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please run
+
+```shell
+# distributed training on multiple GPU/Ascend devices
+mpirun -n 8 python train.py --config configs/cait/cait_xxs24_224.yaml --data_dir /path/to/imagenet
+```
+> If the script is executed by the root user, the `--allow-run-as-root` parameter must be added to `mpirun`.
+
+Similarly, you can train the model on multiple GPU devices with the above `mpirun` command.
+
+For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindcv/blob/main/config.py).
+
+**Note:** As the global batch size (batch_size x num_devices) is an important hyper-parameter, it is recommended to keep the global batch size unchanged for reproduction or adjust the learning rate linearly to a new global batch size.
+
+* Standalone Training
+
+If you want to train or finetune the model on a smaller dataset without distributed training, please run:
+
+```shell
+# standalone training on a CPU/GPU/Ascend device
+python train.py --config configs/cait/cait_xxs24_224.yaml --data_dir /path/to/dataset --distribute False
+```
+
+### Validation
+
+To validate the accuracy of the trained model, you can use `validate.py` and parse the checkpoint path with `--ckpt_path`.
+
+```
+python validate.py -c configs/cait/cait_xxs24_224.yaml --data_dir /path/to/imagenet --ckpt_path /path/to/ckpt
+```
+
+### Deployment
+
+Please refer to the [deployment tutorial](https://mindspore-lab.github.io/mindcv/tutorials/deployment/).
+
+## References
+
+
+[1] Touvron H, Cord M, Sablayrolles A, et al. Going deeper with image transformers[C]//Proceedings of the IEEE/CVF international conference on computer vision. 2021: 32-42.
diff --git a/configs/cait/cait_s24_224.yaml b/configs/cait/cait_s24_224.yaml
new file mode 100644
index 000000000..a786c7946
--- /dev/null
+++ b/configs/cait/cait_s24_224.yaml
@@ -0,0 +1,61 @@
+# system config
+mode: 0
+distribute: True
+num_parallel_workers: 8
+val_while_train: True
+
+# dataset config
+dataset: 'imagenet'
+data_dir: '/path/to/imagenet'
+shuffle: True
+dataset_download: False
+batch_size: 64
+drop_remainder: True
+
+# augmentation config
+image_resize: 224
+scale: [0.08, 1.0]
+ratio: [0.75, 1.333]
+hflip: 0.5
+vflip: 0.
+interpolation: 'bicubic'
+auto_augment: 'randaug-m9-mstd0.5-inc1'
+re_prob: 0.25
+mixup: 0.8
+cutmix: 1
+color_jitter: 0.3
+crop_pct: 1.0
+ema: True
+ema_decay: 0.99996
+
+# model config
+model: 'cait_s24_224'
+num_classes: 1000
+pretrained: False
+ckpt_path: ''
+keep_checkpoint_max: 10
+ckpt_save_dir: '/cache/output/'
+epoch_size: 400
+dataset_sink_mode: True
+amp_level: 'O2'
+drop_path_rate: 0.1
+
+# loss config
+loss: 'CE'
+label_smoothing: 0.1
+
+# lr scheduler config
+scheduler: 'warmup_cosine_decay'
+lr: 0.001
+min_lr: 0.000001
+warmup_epochs: 30
+decay_epochs: 370
+num_cycles: 2
+
+
+# optimizer config
+opt: 'adamw'
+weight_decay: 0.05
+filter_bias_and_bn: True
+loss_scale: 1024
+use_nesterov: False
diff --git a/configs/cait/cait_s36_224.yaml b/configs/cait/cait_s36_224.yaml
new file mode 100644
index 000000000..1acc253db
--- /dev/null
+++ b/configs/cait/cait_s36_224.yaml
@@ -0,0 +1,61 @@
+# system config
+mode: 0
+distribute: True
+num_parallel_workers: 8
+val_while_train: True
+
+# dataset config
+dataset: 'imagenet'
+data_dir: '/path/to/iamgenet'
+shuffle: True
+dataset_download: False
+batch_size: 64
+drop_remainder: True
+
+# augmentation config
+image_resize: 224
+scale: [0.08, 1.0]
+ratio: [0.75, 1.333]
+hflip: 0.5
+vflip: 0.
+interpolation: 'bicubic'
+auto_augment: 'randaug-m9-mstd0.5-inc1'
+re_prob: 0.25
+mixup: 0.8
+cutmix: 1
+color_jitter: 0.3
+crop_pct: 1.0
+ema: True
+ema_decay: 0.99996
+
+# model config
+model: 'cait_s36_224'
+num_classes: 1000
+pretrained: False
+ckpt_path: ''
+keep_checkpoint_max: 10
+ckpt_save_dir: './ckpt'
+epoch_size: 400
+dataset_sink_mode: True
+amp_level: 'O2'
+drop_path_rate: 0.1
+
+# loss config
+loss: 'CE'
+label_smoothing: 0.1
+
+# lr scheduler config
+scheduler: 'warmup_cosine_decay'
+lr: 0.002
+min_lr: 0.000001
+warmup_epochs: 30
+decay_epochs: 370
+num_cycles: 2
+
+
+# optimizer config
+opt: 'adamw'
+weight_decay: 0.05
+filter_bias_and_bn: True
+loss_scale: 1024
+use_nesterov: False
diff --git a/configs/cait/cait_xs24_224.yaml b/configs/cait/cait_xs24_224.yaml
new file mode 100644
index 000000000..9f675aca4
--- /dev/null
+++ b/configs/cait/cait_xs24_224.yaml
@@ -0,0 +1,61 @@
+# system config
+mode: 0
+distribute: True
+num_parallel_workers: 8
+val_while_train: True
+
+# dataset config
+dataset: 'imagenet'
+data_dir: '/path/to/imagenet'
+shuffle: True
+dataset_download: False
+batch_size: 64
+drop_remainder: True
+
+# augmentation config
+image_resize: 224
+scale: [0.08, 1.0]
+ratio: [0.75, 1.333]
+hflip: 0.5
+vflip: 0.
+interpolation: 'bicubic'
+auto_augment: 'randaug-m9-mstd0.5-inc1'
+re_prob: 0.25
+mixup: 0.8
+cutmix: 1
+color_jitter: 0.3
+crop_pct: 1.0
+ema: True
+ema_decay: 0.99996
+
+# model config
+model: 'cait_xs24_224'
+num_classes: 1000 #
+pretrained: False #
+ckpt_path: ''
+keep_checkpoint_max: 10
+ckpt_save_dir: './ckpt'
+epoch_size: 400
+dataset_sink_mode: True
+amp_level: 'O2'
+drop_path_rate: 0.1
+
+# loss config
+loss: 'CE'
+label_smoothing: 0.1
+
+# lr scheduler config
+scheduler: 'warmup_cosine_decay'
+lr: 0.001
+min_lr: 0.000001
+warmup_epochs: 40
+decay_epochs: 360
+num_cycles: 2
+
+
+# optimizer config
+opt: 'adamw'
+weight_decay: 0.05
+filter_bias_and_bn: True
+loss_scale: 1024
+use_nesterov: False
diff --git a/configs/cait/cait_xxs24_224.yaml b/configs/cait/cait_xxs24_224.yaml
new file mode 100644
index 000000000..371792b2b
--- /dev/null
+++ b/configs/cait/cait_xxs24_224.yaml
@@ -0,0 +1,61 @@
+# system config
+mode: 0
+distribute: True
+num_parallel_workers: 8
+val_while_train: True
+
+# dataset config
+dataset: 'imagenet'
+data_dir: '/path/to/dataset'
+shuffle: True
+dataset_download: False
+batch_size: 128
+drop_remainder: True
+
+# augmentation config
+image_resize: 224
+scale: [0.08, 1.0]
+ratio: [0.75, 1.333]
+hflip: 0.5
+vflip: 0.
+interpolation: 'bicubic'
+auto_augment: 'randaug-m9-mstd0.5-inc1'
+re_prob: 0.25
+mixup: 0.8
+cutmix: 1
+color_jitter: 0.3
+crop_pct: 1.0
+ema: True
+ema_decay: 0.99996
+
+# model config
+model: 'cait_xxs24_224'
+num_classes: 1000
+pretrained: False
+ckpt_path: ''
+keep_checkpoint_max: 10
+ckpt_save_dir: './ckpt'
+epoch_size: 500
+dataset_sink_mode: True
+amp_level: 'O2'
+drop_path_rate: 0.1
+
+# loss config
+loss: 'CE'
+label_smoothing: 0.1
+
+# lr scheduler config
+scheduler: 'warmup_cosine_decay'
+lr: 0.001
+min_lr: 0.000001
+warmup_epochs: 40
+decay_epochs: 460
+num_cycles: 2
+
+
+# optimizer config
+opt: 'adamw'
+weight_decay: 0.025
+filter_bias_and_bn: True
+loss_scale: 1024
+use_nesterov: False
diff --git a/mindcv/models/cait.py b/mindcv/models/cait.py
index 40ca6ec13..7defe4e68 100644
--- a/mindcv/models/cait.py
+++ b/mindcv/models/cait.py
@@ -22,10 +22,10 @@
__all__ = [
"CaiT",
"cait_xxs24_224",
- "cait_xs24_384",
+ "cait_xs24_224",
"cait_s24_224",
"cait_s24_384",
- "cait_s36_384",
+ "cait_s36_224",
"cait_m36_384",
"cait_m48_448",
]
@@ -42,10 +42,10 @@ def _cfg(url='', **kwargs):
default_cfgs = {
"cait_xxs24_224": _cfg(url=''),
- "cait_xs24_384": _cfg(url='', input_size=(3, 384, 384)),
+ "cait_xs24_224": _cfg(url=''),
"cait_s24_224": _cfg(url=''),
"cait_s24_384": _cfg(url='', input_size=(3, 384, 384)),
- "cait_s36_384": _cfg(url='', input_size=(3, 384, 384)),
+ "cait_s36_224": _cfg(url=''),
"cait_m36_384": _cfg(url='', input_size=(3, 384, 384)),
"cait_m48_448": _cfg(url='', input_size=(3, 448, 448)),
}
@@ -156,7 +156,12 @@ def __init__(self,
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
- self.qkv = nn.Dense(dim, dim * 3, has_bias=qkv_bias)
+ # self.qkv = nn.Dense(dim, dim * 3, has_bias=qkv_bias)
+
+ self.q = nn.Dense(dim, dim, has_bias=qkv_bias)
+ self.k = nn.Dense(dim, dim, has_bias=qkv_bias)
+ self.v = nn.Dense(dim, dim, has_bias=qkv_bias)
+
self.attn_drop = Dropout(p=attn_drop_rate)
self.proj = nn.Dense(dim, dim, has_bias=False)
@@ -173,12 +178,21 @@ def __init__(self,
def construct(self, x) -> Tensor:
B, N, C = x.shape
- qkv = ops.reshape(self.qkv(x), (B, N, 3, self.num_heads, C // self.num_heads))
- qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))
- q, k, v = ops.unstack(qkv, axis=0)
- q = ops.mul(q, self.scale)
-
- attn = self.q_matmul_k(q, k)
+ # qkv = ops.reshape(self.qkv(x), (B, N, 3, self.num_heads, C // self.num_heads))
+ # qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))
+ # q, k, v = ops.unstack(qkv, axis=0)
+ # q = ops.mul(q, self.scale)
+ #
+ # attn = self.q_matmul_k(q, k)
+
+ q = ops.reshape(self.q(x), (B, N, self.num_heads, C // self.num_heads))
+ q = ops.transpose(q, (0, 2, 1, 3)) * self.scale
+ k = ops.reshape(self.k(x), (B, N, self.num_heads, C // self.num_heads))
+ k = ops.transpose(k, (0, 2, 3, 1))
+ v = ops.reshape(self.v(x), (B, N, self.num_heads, C // self.num_heads))
+ v = ops.transpose(v, (0, 2, 1, 3))
+
+ attn = ops.BatchMatMul()(q, k)
attn = ops.transpose(attn, (0, 2, 3, 1))
attn = self.proj_l(attn)
@@ -369,8 +383,8 @@ def cait_xxs24_224(pretrained: bool = False, num_classes: int = 1000, in_channel
@register_model
-def cait_xs24_384(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs) -> CaiT:
- model = CaiT(img_size=384, patch_size=16, in_channels=in_channels, num_classes=num_classes,
+def cait_xs24_224(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs) -> CaiT:
+ model = CaiT(img_size=224, patch_size=16, in_channels=in_channels, num_classes=num_classes,
embed_dim=288, depth=24, num_heads=6, mlp_ratio=4, qkv_bias=False,
norm_layer=partial(nn.LayerNorm, epsilon=1e-6), init_values=1e-5, depth_token_only=2,
**kwargs)
@@ -405,8 +419,8 @@ def cait_s24_384(pretrained: bool = False, num_classes: int = 1000, in_channels=
@register_model
-def cait_s36_384(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs) -> CaiT:
- model = CaiT(img_size=384, patch_size=16, in_channels=in_channels, num_classes=num_classes,
+def cait_s36_224(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs) -> CaiT:
+ model = CaiT(img_size=224, patch_size=16, in_channels=in_channels, num_classes=num_classes,
embed_dim=384, depth=36, num_heads=8, mlp_ratio=4, qkv_bias=False,
norm_layer=partial(nn.LayerNorm, epsilon=1e-6), init_values=1e-6, depth_token_only=2,
**kwargs)