
Commit 684875d

Merge branch 'dev'

2 parents 1b09896 + eae0b84

28 files changed: +1246, −63 lines

README.md

Lines changed: 24 additions & 20 deletions

```diff
@@ -43,9 +43,10 @@ English | [简体中文](README_zh-CN.md)
 
 MMSelfSup is an open source self-supervised representation learning toolbox based on PyTorch. It is a part of the [OpenMMLab](https://openmmlab.com/) project.
 
-The master branch works with **PyTorch 1.5** or higher.
+The master branch works with **PyTorch 1.5+**.
 
-### Major features
+<details open>
+<summary>Major features</summary>
 
 - **Methods All in One**
 
@@ -63,37 +64,39 @@ The master branch works with **PyTorch 1.5** or higher.
 
 Since MMSelfSup adopts a similar design of modules and interfaces to other OpenMMLab projects, it supports smooth evaluation on downstream tasks, such as object detection and segmentation, with other OpenMMLab projects.
 
+</details>
+
 ## What's New
 
-### Preview of 1.x version
+### 💎 Stable version
 
-A brand new version of **MMSelfSup v1.0.0rc1** was released on 01/09/2022:
+MMSelfSup **v0.11.0** was released on 30/12/2022.
 
 Highlights of the new version:
 
-- Based on [MMEngine](https://github.com/open-mmlab/mmengine) and [MMCV](https://github.com/open-mmlab/mmcv/tree/2.x).
-- Released with refactor.
-- Refine all [documents](https://mmselfsup.readthedocs.io/en/1.x/).
-- Support `MAE`, `SimMIM`, `MoCoV3` with different pre-training epochs and backbones of different scales.
-- More concise API.
-- More powerful data pipeline.
-- Higher accuracy for some algorithms.
+- Support `InterCLR`
+- Fix some bugs
 
-Find more new features in [1.x branch](https://github.com/open-mmlab/mmselfsup/tree/1.x). Issues and PRs are welcome!
+Please refer to [changelog.md](docs/en/changelog.md) for details and release history.
 
-### Stable version
+Differences between MMSelfSup and OpenSelfSup codebases can be found in [compatibility.md](docs/en/compatibility.md).
 
-MMSelfSup **v0.10.1** was released on 01/11/2022.
+### 🌟 Preview of 1.x version
 
-Highlights of the new version:
+A brand new version of **MMSelfSup v1.0.0rc4** was released on 07/12/2022.
 
-- Support MaskFeat
-- Update issue form
-- Fix some typos in documents
+Highlights of the new version:
 
-Please refer to [changelog.md](docs/en/changelog.md) for details and release history.
+- Based on [MMEngine](https://github.com/open-mmlab/mmengine) and [MMCV](https://github.com/open-mmlab/mmcv/tree/2.x).
+- Refine all [documents](https://mmselfsup.readthedocs.io/en/1.x/).
+- Support `BEiT v1`, `BEiT v2`, `MILAN`, `MixMIM`, `EVA`.
+- Support `MAE`, `SimMIM`, `MoCoV3` with different pre-training epochs and backbones of different scales.
+- More concise APIs.
+- More visualization tools.
+- More powerful data pipeline.
+- Higher accuracy for some algorithms.
 
-Differences between MMSelfSup and OpenSelfSup codebases can be found in [compatibility.md](docs/en/compatibility.md).
+Find more new features in [1.x branch](https://github.com/open-mmlab/mmselfsup/tree/1.x). Issues and PRs are welcome!
 
 ## Installation
 
@@ -138,6 +141,7 @@ Supported algorithms:
 - [x] [SimSiam (CVPR'2021)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/simsiam)
 - [x] [Barlow Twins (ICML'2021)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/barlowtwins)
 - [x] [MoCo v3 (ICCV'2021)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/mocov3)
+- [x] [InterCLR (IJCV'2022)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/interclr)
 - [x] [MAE (CVPR'2022)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/mae)
 - [x] [SimMIM (CVPR'2022)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/simmim)
 - [x] [MaskFeat (CVPR'2022)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/maskfeat)
```

README_zh-CN.md

Lines changed: 25 additions & 21 deletions

(Content translated from Chinese; the two section headings 介绍/简介, both meaning "Introduction", are kept verbatim since the change is a Chinese wording swap.)

```diff
@@ -39,13 +39,14 @@
 
 </div>
 
-## 介绍
+## 简介
 
 MMSelfSup is an open source self-supervised representation learning toolbox based on PyTorch, and a member of the [OpenMMLab](https://openmmlab.com/) project.
 
 The master branch works with **PyTorch 1.5** and above.
 
-### Major features
+<details open>
+<summary>Major features</summary>
 
 - **Methods All in One**
 
@@ -63,37 +64,39 @@ MMSelfSup is an open source self-supervised representation learning toolbox...
 
 Compatible with all major OpenMMLab codebases, with abundant downstream evaluation tasks and applications of pre-trained models.
 
-## Updates
+</details>
 
-### Preview of 1.x version
+## What's New
 
-The brand new **MMSelfSup v1.0.0rc1** was released on 2022.09.01.
+### 💎 Stable version
+
+The latest **v0.11.0** was released on 2022.12.30.
 
 Highlights of the new version:
 
-- Based on the brand new [MMEngine](https://github.com/open-mmlab/mmengine) and [MMCV](https://github.com/open-mmlab/mmcv/tree/2.x).
-- Refactored codebase with unified interfaces.
-- Refined the new [documentation](https://mmselfsup.readthedocs.io/en/1.x/).
-- Support `MAE`, `SimMIM`, `MoCoV3` pre-trained models with different training schedules and backbone scales.
-- More concise APIs.
-- More powerful data pipeline.
-- Higher accuracy for some models.
+- Support `InterCLR`
+- Fix some bugs
 
-See the [1.x branch](https://github.com/open-mmlab/mmselfsup/tree/1.x) for more new features. Issues and PRs are welcome!
+Please refer to the [changelog](docs/zh_cn/changelog.md) for details and release history.
 
-### Stable version
+Differences between MMSelfSup and OpenSelfSup are described in [compatibility.md](docs/en/compatibility.md).
 
-The latest **v0.10.1** was released on 2022.11.1.
+### 🌟 Preview of 1.x version
 
-Highlights of the new version:
+The brand new **v1.0.0rc4** was released on 2022.12.07.
 
-- Support MaskFeat
-- Update the issue templates
-- Fix some documentation errors
+Highlights of the new version:
 
-Please refer to the [changelog](docs/zh_cn/changelog.md) for details and release history.
+- Based on the brand new [MMEngine](https://github.com/open-mmlab/mmengine) and [MMCV](https://github.com/open-mmlab/mmcv/tree/2.x).
+- Refined the new [documentation](https://mmselfsup.readthedocs.io/en/1.x/).
+- Support `BEiT v1`, `BEiT v2`, `MILAN`, `MixMIM`, `EVA`.
+- Support `MAE`, `SimMIM`, `MoCoV3` pre-trained models with different training schedules and backbone scales.
+- More concise APIs.
+- More visualization tools.
+- More powerful data pipeline.
+- Higher accuracy for some models.
 
-Differences between MMSelfSup and OpenSelfSup are described in [compatibility.md](docs/en/compatibility.md).
+See the [1.x branch](https://github.com/open-mmlab/mmselfsup/tree/1.x) for more new features. Issues and PRs are welcome!
 
 ## Installation
 
@@ -138,6 +141,7 @@ MMSelfSup depends on [PyTorch](https://pytorch.org/), [MMCV](https://github.com/open
 - [x] [SimSiam (CVPR'2021)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/simsiam)
 - [x] [Barlow Twins (ICML'2021)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/barlowtwins)
 - [x] [MoCo v3 (ICCV'2021)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/mocov3)
+- [x] [InterCLR (IJCV'2022)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/interclr)
 - [x] [MAE (CVPR'2022)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/mae)
 - [x] [SimMIM (CVPR'2022)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/simmim)
 - [x] [MaskFeat (CVPR'2022)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/maskfeat)
```
configs/selfsup/_base_/datasets/imagenet_interclr-moco.py

Lines changed: 83 additions & 0 deletions (new file; path inferred from the `_base_` list in the InterCLR config below)

```python
# dataset settings
data_source = 'ImageNet'
train_dataset_type = 'MultiViewDataset'
extract_dataset_type = 'SingleViewDataset'
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_pipeline = [
    dict(type='RandomResizedCrop', size=224, scale=(0.2, 1.)),
    dict(
        type='RandomAppliedTrans',
        transforms=[
            dict(
                type='ColorJitter',
                brightness=0.4,
                contrast=0.4,
                saturation=0.4,
                hue=0.1)
        ],
        p=0.8),
    dict(type='RandomGrayscale', p=0.2),
    dict(type='GaussianBlur', sigma_min=0.1, sigma_max=2.0, p=0.5),
    dict(type='RandomHorizontalFlip'),
]
extract_pipeline = [
    dict(type='Resize', size=256),
    dict(type='CenterCrop', size=224),
]

# prefetch
prefetch = False
if not prefetch:
    train_pipeline.extend(
        [dict(type='ToTensor'),
         dict(type='Normalize', **img_norm_cfg)])
    extract_pipeline.extend(
        [dict(type='ToTensor'),
         dict(type='Normalize', **img_norm_cfg)])

# dataset summary
data = dict(
    samples_per_gpu=32,  # total 32*8=256
    replace=True,
    workers_per_gpu=4,
    drop_last=True,
    train=dict(
        type=train_dataset_type,
        data_source=dict(
            type=data_source,
            data_prefix='data/imagenet/train',
            ann_file='data/imagenet/meta/train.txt',
        ),
        num_views=[2],
        pipelines=[train_pipeline],
        prefetch=prefetch))

# additional hooks
num_classes = 10000
custom_hooks = [
    dict(
        type='InterCLRHook',
        extractor=dict(
            samples_per_gpu=256,
            workers_per_gpu=8,
            dataset=dict(
                type=extract_dataset_type,
                data_source=dict(
                    type=data_source,
                    data_prefix='data/imagenet/train',
                    ann_file='data/imagenet/meta/train.txt',
                ),
                pipeline=extract_pipeline,
                prefetch=prefetch),
            prefetch=prefetch,
            img_norm_cfg=img_norm_cfg),
        clustering=dict(type='Kmeans', k=num_classes, pca_dim=-1),  # no pca
        centroids_update_interval=10,  # iter
        deal_with_small_clusters_interval=1,
        evaluate_interval=50,
        warmup_epochs=0,
        init_memory=True,
        initial=True,  # call initially
        online_labels=True,
        interval=10)  # same as the checkpoint interval
]
```
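The `custom_hooks` entry is what drives InterCLR's pseudo-label maintenance: going by its keys, `InterCLRHook` periodically runs the `extract_pipeline` over the full training set, clusters the features with k-means into `num_classes` pseudo-classes, and keeps the resulting pseudo-labels in the memory bank for inter-image contrastive learning. A rough sketch of one such clustering pass, assuming scikit-learn; `extract_features` and `memory_bank` are hypothetical stand-ins for the hook's extractor and the InterCLR memory bank:

```python
import numpy as np
from sklearn.cluster import KMeans


def refresh_pseudo_labels(extract_features, memory_bank, num_classes=10000):
    """One offline clustering pass: features -> k-means -> pseudo-labels."""
    feats = extract_features()  # (N, feat_dim) array over the whole train set
    feats /= np.linalg.norm(feats, axis=1, keepdims=True)  # L2-normalize
    # pca_dim=-1 in the config above means no PCA before clustering
    km = KMeans(n_clusters=num_classes, n_init=1, max_iter=20).fit(feats)
    memory_bank.update_labels(km.labels_)  # each cluster id is a pseudo-label
    return km.cluster_centers_
```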

configs/selfsup/_base_/datasets/imagenet_odc.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -32,7 +32,7 @@
 # dataset summary
 data = dict(
     samples_per_gpu=64,  # 64*8
-    sampling_replace=True,
+    replace=True,
     workers_per_gpu=4,
     train=dict(
         type=dataset_type,
```
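For context, this one-line change renames the ODC dataset option `sampling_replace` to `replace`, matching the new InterCLR dataset config above. Judging by the name, the flag makes the sampler draw training indices with replacement; a minimal sketch of that behavior in plain PyTorch, with a toy dataset that is purely illustrative:

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(10).float())

# replacement=True -> indices are drawn with replacement, so one epoch may
# visit some samples several times and skip others entirely
sampler = RandomSampler(dataset, replacement=True, num_samples=len(dataset))
loader = DataLoader(dataset, batch_size=4, sampler=sampler, drop_last=True)

for (batch,) in loader:
    print(batch.tolist())
```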
configs/selfsup/_base_/models/interclr-moco.py

Lines changed: 40 additions & 0 deletions (new file; path inferred from the `_base_` list in the InterCLR config below)

```python
# model settings
model = dict(
    type='InterCLRMoCo',
    queue_len=65536,
    feat_dim=128,
    momentum=0.999,
    backbone=dict(
        type='ResNet',
        depth=50,
        in_channels=3,
        out_indices=[4],  # 0: conv-1, x: stage-x
        norm_cfg=dict(type='BN')),
    neck=dict(
        type='MoCoV2Neck',
        in_channels=2048,
        hid_channels=2048,
        out_channels=128,
        with_avg_pool=True),
    head=dict(type='ContrastiveHead', temperature=0.2),
    memory_bank=dict(
        type='InterCLRMemory',
        length=1281167,
        feat_dim=128,
        momentum=1.,
        num_classes=10000,
        min_cluster=20,
        debug=False),
    online_labels=True,
    neg_num=16384,
    neg_sampling='semihard',  # 'hard', 'semihard', 'random', 'semieasy'
    semihard_neg_pool_num=128000,
    semieasy_neg_pool_num=128000,
    intra_cos_marign_loss=False,  # key spelling as in the original config
    intra_cos_margin=0,
    inter_cos_marign_loss=True,  # key spelling as in the original config
    inter_cos_margin=-0.5,
    intra_loss_weight=0.75,
    inter_loss_weight=0.25,
    share_neck=True,
    num_classes=10000)
```
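In this config, `momentum=0.999` is the exponential-moving-average coefficient for the MoCo key encoder and `queue_len=65536` is the size of the negative queue. As a reminder of what the coefficient controls, here is a minimal sketch of the standard MoCo momentum update; this is generic MoCo practice, not code taken from this repository:

```python
import torch


@torch.no_grad()
def momentum_update(query_encoder, key_encoder, m=0.999):
    """EMA update: the key encoder slowly tracks the query encoder."""
    for p_q, p_k in zip(query_encoder.parameters(), key_encoder.parameters()):
        p_k.data.mul_(m).add_(p_q.data, alpha=1.0 - m)
```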

configs/selfsup/interclr/README.md

Lines changed: 62 additions & 0 deletions (new file)

# InterCLR

> [Delving into Inter-Image Invariance for Unsupervised Visual Representations](https://arxiv.org/abs/2008.11702)

<!-- [ALGORITHM] -->

## Abstract

Contrastive learning has recently shown immense potential in unsupervised visual representation learning. Existing studies in this track mainly focus on intra-image invariance learning. The learning typically uses rich intra-image transformations to construct positive pairs and then maximizes agreement using a contrastive loss. The merits of inter-image invariance, conversely, remain much less explored. One major obstacle to exploiting inter-image invariance is that it is unclear how to reliably construct inter-image positive pairs, and further derive effective supervision from them, since no pair annotations are available. In this work, we present a comprehensive empirical study to better understand the role of inter-image invariance learning from three main constituting components: pseudo-label maintenance, sampling strategy, and decision boundary design. To facilitate the study, we introduce a unified and generic framework that supports the integration of unsupervised intra- and inter-image invariance learning. Through carefully designed comparisons and analysis, multiple valuable observations are revealed: 1) online labels converge faster and perform better than offline labels; 2) semi-hard negative samples are more reliable and unbiased than hard negative samples; 3) a less stringent decision boundary is more favorable for inter-image invariance learning. With all the obtained recipes, our final model, namely InterCLR, shows consistent improvements over state-of-the-art intra-image invariance learning methods on multiple standard benchmarks. We hope this work will provide useful experience for devising effective unsupervised inter-image invariance learning.

<div align="center">
<img src="https://user-images.githubusercontent.com/52497952/205854109-2385b765-e12b-4e22-b7b8-45db6292895b.png" width="800" />
</div>

## Results and Models

On this page, we provide as many benchmarks as possible to evaluate our pre-trained models. Unless otherwise noted, all models are pre-trained on the ImageNet-1k dataset. Here, we use MoCov2-InterCLR as an example. More models and results are coming soon.

### Classification

#### VOC SVM / Low-shot SVM

The **Best Layer** indicates from which layer's feature map the best result is obtained. For example, if the **Best Layer** is **feature3**, the best result comes from the second stage of ResNet (1 is the stem layer; 2-5 are the four stage layers).

Besides, k=1 to 96 is the hyper-parameter of the low-shot SVM: the number of labeled samples per class used for training.

| Self-Supervised Config | Best Layer | SVM | k=1 | k=2 | k=4 | k=8 | k=16 | k=32 | k=64 | k=96 |
| ---------------------- | ---------- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ---- | ----- |
| [interclr-moco_resnet50_8xb32-coslr-200e](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/interclr/interclr-moco_resnet50_8xb32-coslr-200e_in1k.py) | feature5 | 85.24 | 45.08 | 59.25 | 65.99 | 74.31 | 77.95 | 80.68 | 82.7 | 83.49 |
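For readers unfamiliar with the protocol, each low-shot score above comes from training a linear SVM on frozen backbone features with only k labeled images per class. A self-contained sketch of that evaluation with scikit-learn, using random features as a stand-in for real backbone outputs:

```python
import numpy as np
from sklearn.svm import LinearSVC

rng = np.random.default_rng(0)
feats = rng.normal(size=(1000, 2048)).astype(np.float32)  # frozen features
labels = rng.integers(0, 20, size=1000)                   # 20 toy classes

k = 8  # low-shot setting: k labeled samples per class
train_idx = np.concatenate([
    rng.choice(np.flatnonzero(labels == c), size=k, replace=False)
    for c in np.unique(labels)
])

clf = LinearSVC(C=0.5).fit(feats[train_idx], labels[train_idx])
test_mask = np.ones(len(labels), dtype=bool)
test_mask[train_idx] = False
print('held-out accuracy:', clf.score(feats[test_mask], labels[test_mask]))
```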
#### ImageNet Linear Evaluation

**Feature1 - Feature5** do not use GlobalAveragePooling; each feature map is pooled to a specific dimension and then followed by a linear layer for classification. Please refer to [resnet50_mhead_linear-8xb32-steplr-90e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/benchmarks/classification/imagenet/resnet50_mhead_linear-8xb32-steplr-90e_in1k.py) for config details.

| Self-Supervised Config | Feature1 | Feature2 | Feature3 | Feature4 | Feature5 |
| ---------------------- | -------- | -------- | -------- | -------- | -------- |
| [interclr-moco_resnet50_8xb32-coslr-200e](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/interclr/interclr-moco_resnet50_8xb32-coslr-200e_in1k.py) | 15.59 | 35.10 | 47.36 | 62.86 | 68.04 |

## Citation

```bibtex
@article{xie2022delving,
  title={Delving into inter-image invariance for unsupervised visual representations},
  author={Xie, Jiahao and Zhan, Xiaohang and Liu, Ziwei and Ong, Yew-Soon and Loy, Chen Change},
  journal={International Journal of Computer Vision},
  year={2022}
}
```
configs/selfsup/interclr/interclr-moco_resnet50_8xb32-coslr-200e_in1k.py

Lines changed: 18 additions & 0 deletions (new file; path inferred from the links in the README above)

```python
_base_ = [
    '../_base_/models/interclr-moco.py',
    '../_base_/datasets/imagenet_interclr-moco.py',
    '../_base_/schedules/sgd_coslr-200e_in1k.py',
    '../_base_/default_runtime.py',
]

# model settings
model = dict(
    memory_bank=dict(num_classes={{_base_.num_classes}}),
    num_classes={{_base_.num_classes}},
)

# runtime settings
# max_keep_ckpts controls the max number of checkpoint files in your work_dirs:
# if it is 3, when CheckpointHook (in mmcv) saves the 4th checkpoint,
# it removes the oldest one to keep the total number of checkpoints at 3
checkpoint_config = dict(interval=10, max_keep_ckpts=3)
```
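As a quick sanity check of the `{{_base_.num_classes}}` substitution, the composed config can be loaded with mmcv's `Config`; this sketch assumes an MMSelfSup 0.x checkout with mmcv installed and is run from the repository root:

```python
from mmcv import Config

cfg = Config.fromfile(
    'configs/selfsup/interclr/interclr-moco_resnet50_8xb32-coslr-200e_in1k.py')

# num_classes is defined once in the base dataset config and substituted
# into both fields via {{_base_.num_classes}}
print(cfg.model.num_classes)                 # 10000
print(cfg.model.memory_bank.num_classes)     # 10000
print(cfg.checkpoint_config.max_keep_ckpts)  # 3
```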
