Commit b3a0069

Author: gaoyang07

    update nsga2 search with pymoo_v0.50

1 parent bbb58f1

22 files changed: +2675 −4 lines
Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+_base_ = ['./spos_shufflenet_supernet_8xb128_in1k.py']
+
+model = dict(norm_training=True)
+
+train_cfg = dict(
+    _delete_=True,
+    type='mmrazor.NSGA2SearchLoop',
+    dataloader=_base_.val_dataloader,
+    evaluator=_base_.val_evaluator,
+    max_epochs=20,
+    num_candidates=50,
+    top_k=10,
+    num_mutation=25,
+    num_crossover=25,
+    mutate_prob=0.1,
+    constraints_range=dict(flops=(0., 360.)),
+    predictor_cfg=dict(
+        type='mmrazor.MetricPredictor',
+        encoding_type='normal',
+        train_samples=2,
+        handler_cfg=dict(type='mmrazor.GaussProcessHandler')),
+)
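Note: `constraints_range=dict(flops=(0., 360.))` caps sampled subnets at 360 MFLOPs during candidate screening; in the search loop this is handled by `_check_constraints`, which also returns the measured resources. A simplified, standalone sketch of the screening rule (the real loop measures flops with a resource estimator rather than taking them as an argument):

```python
# Simplified sketch of `constraints_range` screening; the real loop's
# `_check_constraints` measures flops with a resource estimator.
CONSTRAINTS_RANGE = {'flops': (0., 360.)}

def within_constraints(flops: float) -> bool:
    low, high = CONSTRAINTS_RANGE['flops']
    return low <= flops <= high

print(within_constraints(312.5))  # True: kept as a candidate
print(within_constraints(428.0))  # False: filtered out
```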

mmrazor/engine/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -4,12 +4,13 @@
 from .runner import (AutoSlimValLoop, DartsEpochBasedTrainLoop,
                      DartsIterBasedTrainLoop, EvolutionSearchLoop,
                      GreedySamplerTrainLoop, SelfDistillValLoop,
-                     SingleTeacherDistillValLoop, SlimmableValLoop)
+                     SingleTeacherDistillValLoop, SlimmableValLoop,
+                     NSGA2SearchLoop)
 
 __all__ = [
     'SeparateOptimWrapperConstructor', 'DumpSubnetHook',
     'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
     'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
     'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'EstimateResourcesHook',
-    'SelfDistillValLoop'
+    'SelfDistillValLoop', 'NSGA2SearchLoop'
 ]

mmrazor/engine/runner/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -3,11 +3,13 @@
 from .darts_loop import DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop
 from .distill_val_loop import SelfDistillValLoop, SingleTeacherDistillValLoop
 from .evolution_search_loop import EvolutionSearchLoop
+from .nsganetv2_search_loop import NSGA2SearchLoop
 from .slimmable_val_loop import SlimmableValLoop
 from .subnet_sampler_loop import GreedySamplerTrainLoop
 
 __all__ = [
     'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
     'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
-    'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop'
+    'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop',
+    'NSGA2SearchLoop'
 ]
mmrazor/engine/runner/attentive_search_loop.py

Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmrazor.registry import LOOPS
+from .evolution_search_loop import EvolutionSearchLoop
+
+
+@LOOPS.register_module()
+class AttentiveSearchLoop(EvolutionSearchLoop):
+    """Loop for evolution searching with attentive tricks from AttentiveNAS.
+
+    Args:
+        runner (Runner): A reference of runner.
+        dataloader (Dataloader or dict): A dataloader object or a dict to
+            build a dataloader.
+        evaluator (Evaluator or dict or list): Used for computing metrics.
+        max_epochs (int): Total searching epochs. Defaults to 20.
+        max_keep_ckpts (int): The maximum checkpoints of searcher to keep.
+            Defaults to 3.
+        resume_from (str, optional): Specify the path of saved .pkl file
+            for resuming searching.
+        num_candidates (int): The length of candidate pool. Defaults to 50.
+        top_k (int): Specify top k candidates based on scores. Defaults
+            to 10.
+        num_mutation (int): The number of candidates got by mutation.
+            Defaults to 25.
+        num_crossover (int): The number of candidates got by crossover.
+            Defaults to 25.
+        mutate_prob (float): The probability of mutation. Defaults to 0.1.
+        flops_range (tuple, optional): It is used for screening candidates.
+        resource_estimator_cfg (dict): The config for building the
+            estimator, which is used to estimate the flops of sampled
+            subnets. Defaults to None, which means the default config is
+            used.
+        score_key (str): Specify one metric in evaluation results to score
+            candidates. Defaults to 'accuracy_top-1'.
+        init_candidates (str, optional): The candidates file path, which is
+            used to init `self.candidates`. Its format is usually .yaml.
+            Defaults to None.
+    """
+
+    def _init_pareto(self):
+        # TODO (gaoyang): Fix apis with mmrazor2.0
+        for k, v in self.constraints.items():
+            if not isinstance(v, (list, tuple)):
+                self.constraints[k] = (0, v)
+
+        assert len(self.constraints) == 1, \
+            'Only one kind of constraint is accepted.'
+        self.pareto_candidates = dict()
+        constraints = list(self.constraints.items())[0]
+        discretize_step = self.pareto_mode['discretize_step']
+        ds = discretize_step
+        # advance to the left bound of the constraint range
+        while ds + 0.5 * discretize_step < constraints[1][0]:
+            ds += discretize_step
+        self.pareto_candidates[ds] = []
+        # create one empty bucket per step up to the right bound
+        while ds - 0.5 * discretize_step < constraints[1][1]:
+            self.pareto_candidates[ds] = []
+            ds += discretize_step
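To make the discretization in `_init_pareto` concrete, here is a standalone sketch under assumed values (a flops constraint of (100., 360.) and `pareto_mode=dict(discretize_step=25)`); the keys of `pareto_candidates` become the bucket centers:

```python
# Standalone sketch of the discretization in `_init_pareto`, assuming
# a flops constraint of (100., 360.) and discretize_step=25.
low, high = 100., 360.
discretize_step = 25.

pareto_candidates = {}
ds = discretize_step
while ds + 0.5 * discretize_step < low:  # advance to the left bound
    ds += discretize_step
pareto_candidates[ds] = []
while ds - 0.5 * discretize_step < high:  # one bucket per step
    pareto_candidates[ds] = []
    ds += discretize_step

print(sorted(pareto_candidates))  # [100.0, 125.0, ..., 350.0]
```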

mmrazor/engine/runner/evolution_search_loop.py

Lines changed: 1 addition & 0 deletions
@@ -96,6 +96,7 @@ def __init__(self,
         self.crossover_prob = crossover_prob
         self.max_keep_ckpts = max_keep_ckpts
         self.resume_from = resume_from
+        self.trade_off = dict(max_score_key=40)
 
         if init_candidates is None:
             self.candidates = Candidates()
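The hard-coded `trade_off = dict(max_score_key=40)` feeds the score inversion used throughout the NSGA2 loop below: pymoo minimizes every objective, so a to-be-maximized score is flipped into an error-like value first (see `run_epoch` and `sort_candidates` in the next file). A tiny illustration with made-up scores:

```python
# Illustration of the `max_score_key` inversion (made-up scores):
# NSGA2 minimizes objectives, so higher-is-better scores are flipped.
max_score_key = 40
scores = [36, 34, 38]                        # higher is better
errors = [max_score_key - s for s in scores]
print(errors)  # [4, 6, 2] -- now lower is better
```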
mmrazor/engine/runner/nsganetv2_search_loop.py

Lines changed: 253 additions & 0 deletions

@@ -0,0 +1,253 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from copy import deepcopy
+
+import numpy as np
+from mmengine import fileio
+from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting
+
+from mmrazor.models.task_modules import (GeneticOptimizer, NSGA2Optimizer,
+                                         AuxiliarySingleLevelProblem,
+                                         SubsetProblem)
+from mmrazor.registry import LOOPS
+from mmrazor.structures import Candidates, export_fix_subnet
+from .attentive_search_loop import AttentiveSearchLoop
+from .utils.high_tradeoff_points import HighTradeoffPoints
+
+# pymoo v0.5.0 equivalents of the wrapped optimizers:
+# from pymoo.algorithms.moo.nsga2 import NSGA2 as NSGA2Optimizer
+# from pymoo.algorithms.soo.nonconvex.ga import GA as GeneticOptimizer
+# from pymoo.optimize import minimize
+
+
+@LOOPS.register_module()
+class NSGA2SearchLoop(AttentiveSearchLoop):
+    """Evolution search loop with NSGA2 optimizer."""
+
+    def run_epoch(self) -> None:
+        """Iterate one epoch.
+
+        Steps:
+            0. Collect archives and predictor.
+            1. Sample some new candidates from the supernet, then append
+                them to the candidate pool until it reaches the specified
+                size.
+            2. Validate these candidates (step 1) and update their scores.
+            3. Pick the top k candidates based on the scores (step 2),
+                which will be used in mutation and crossover.
+            4. Implement mutation and crossover to generate better
+                candidates.
+        """
+        archive = Candidates()
+        for subnet, score, flops in zip(self.candidates.subnets,
+                                        self.candidates.scores,
+                                        self.candidates.resources('flops')):
+            # flip a to-be-maximized score into an error-like objective
+            if self.trade_off['max_score_key'] != 0:
+                score = self.trade_off['max_score_key'] - score
+            archive.append(subnet)
+            archive.set_score(-1, score)
+            archive.set_resource(-1, flops, 'flops')
+
+        self.sample_candidates(random=(self._epoch == 0), archive=archive)
+        self.update_candidates_scores()
+
+        scores_before = self.top_k_candidates.scores
+        self.runner.logger.info(f'top k scores before update: '
+                                f'{scores_before}')
+
+        self.candidates.extend(self.top_k_candidates)
+        self.sort_candidates()
+        self.top_k_candidates = Candidates(self.candidates[:self.top_k])
+
+        scores_after = self.top_k_candidates.scores
+        self.runner.logger.info(f'top k scores after update: '
+                                f'{scores_after}')
+
+        mutation_candidates = self.gen_mutation_candidates()
+        self.candidates_mutator_crossover = Candidates(mutation_candidates)
+        crossover_candidates = self.gen_crossover_candidates()
+        self.candidates_mutator_crossover.extend(crossover_candidates)
+
+        assert len(self.candidates_mutator_crossover
+                   ) <= self.num_candidates, 'Total of mutation and \
+            crossover should be no more than the number of candidates.'
+
+        self.candidates = self.candidates_mutator_crossover
+        self._epoch += 1
+    def sample_candidates(self, random: bool = True, archive=None) -> None:
+        # initialized up front so the resources update below is safe on
+        # the `random` path as well
+        candidates_resources = []
+        if random:
+            super().sample_candidates()
+        else:
+            candidates = self.sample_candidates_with_nsga2(
+                archive, self.num_candidates)
+            new_candidates = []
+            for candidate in candidates:
+                is_pass, result = self._check_constraints(candidate)
+                if is_pass:
+                    new_candidates.append(candidate)
+                    candidates_resources.append(result)
+            self.candidates = Candidates(new_candidates)
+
+        if len(candidates_resources) > 0:
+            self.candidates.update_resources(
+                candidates_resources,
+                start=len(self.candidates.data) - len(candidates_resources))
+    def sample_candidates_with_nsga2(self, archive: Candidates,
+                                     num_candidates):
+        """Search for candidates with high-fidelity evaluation."""
+        F = np.column_stack((archive.scores, archive.resources('flops')))
+        front_index = NonDominatedSorting().do(
+            F, only_non_dominated_front=True)
+
+        fronts = np.array(archive.subnets)[front_index]
+        fronts = np.array(
+            [self.predictor.model2vector(cand) for cand in fronts])
+        fronts = self.predictor.preprocess(fronts)
+
+        # initialize the candidate finding optimization problem
+        problem = AuxiliarySingleLevelProblem(self, len(fronts[0]))
+
+        # initiate a multi-objective solver to optimize the problem
+        method = NSGA2Optimizer(
+            pop_size=4,
+            sampling=fronts,  # initialize with current nd archs
+            eliminate_duplicates=True,
+            logger=self.runner.logger)
+
+        # kick off the search
+        method.initialize(problem, n_gen=2, verbose=True)
+        result = method.solve()
+
+        # check for duplicates
+        check_list = []
+        for x in result['pop'].get('X'):
+            assert x is not None
+            check_list.append(self.predictor.vector2model(x))
+
+        not_duplicate = np.logical_not(
+            [x in archive.subnets for x in check_list])
+
+        # extra process after nsga2 search
+        sub_problem = SubsetProblem(
+            result['pop'][not_duplicate].get('F')[:, 1],
+            F[front_index, 1], num_candidates)
+        sub_method = GeneticOptimizer(
+            pop_size=num_candidates, eliminate_duplicates=True)
+        sub_method.initialize(sub_problem, n_gen=4, verbose=False)
+        indices = sub_method.solve()['X']
+
+        candidates = Candidates()
+        pop = result['pop'][not_duplicate][indices]
+        for x in pop.get('X'):
+            candidates.append(self.predictor.vector2model(x))
+
+        return candidates
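The front extraction at the top of this method relies on pymoo's non-dominated sorting, which v0.5.0 keeps under `pymoo.util.nds`. A self-contained demo with made-up (error, flops) pairs:

```python
import numpy as np
from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting

# made-up objective matrix: columns are (error, flops), both minimized
F = np.array([
    [2.0, 300.0],
    [3.0, 200.0],
    [1.5, 350.0],
    [3.5, 320.0],  # dominated by [2.0, 300.0] on both objectives
])
front = NonDominatedSorting().do(F, only_non_dominated_front=True)
print(front)  # [0 1 2] -- index 3 is dominated, so it drops out
```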
+    def sort_candidates(self) -> None:
+        """Support sorting candidates in single- and multi-objective
+        optimization."""
+        assert self.trade_off is not None, (
+            '`self.trade_off` is required when sorting candidates in '
+            'NSGA2SearchLoop. Got self.trade_off is None.')
+        ratio = self.trade_off.get('ratio', 1)
+        multiple_obj_score = []
+        for score, flops in zip(self.candidates.scores,
+                                self.candidates.resources('flops')):
+            multiple_obj_score.append((score, flops))
+        multiple_obj_score = np.array(multiple_obj_score)
+        max_score_key = self.trade_off.get('max_score_key', 100)
+        if max_score_key != 0:
+            multiple_obj_score[:, 0] = \
+                max_score_key - multiple_obj_score[:, 0]
+        sort_idx = np.argsort(multiple_obj_score[:, 0])
+        F = multiple_obj_score[sort_idx]
+        dm = HighTradeoffPoints(ratio, n_survive=len(multiple_obj_score))
+        candidate_index = dm.do(F)
+        candidate_index = sort_idx[candidate_index]
+        self.candidates = [self.candidates[idx] for idx in candidate_index]
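`HighTradeoffPoints` (imported from `.utils.high_tradeoff_points`) picks knee points on the (error, flops) front; the surrounding bookkeeping is just the inversion and sort. A sketch of that bookkeeping with made-up numbers, leaving the knee-point selection itself out:

```python
import numpy as np

# made-up candidates: higher score is better, lower flops is better
scores = np.array([36.0, 38.0, 34.0])
flops = np.array([310.0, 355.0, 240.0])

max_score_key = 100
F = np.column_stack((max_score_key - scores, flops))  # minimize both
sort_idx = np.argsort(F[:, 0])  # ascending error = descending score
print(sort_idx)  # [1 0 2] -- the 38.0-score candidate comes first
```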
+    def _save_searcher_ckpt(self, archive=[]):
+        """Save searcher ckpt, which is different from common ckpt.
+
+        It mainly contains the candidate pool, the top-k candidates with
+        scores and the current epoch.
+        """
+        if self.runner.rank == 0:
+            rmse, rho, tau = 0, 0, 0
+            if len(archive) > 0:
+                top1_err_pred = self.fit_predictor(archive)
+                rmse, rho, tau = self.predictor.get_correlation(
+                    top1_err_pred, np.array([x[1] for x in archive]))
+
+            save_for_resume = dict()
+            save_for_resume['_epoch'] = self._epoch
+            for k in ['candidates', 'top_k_candidates']:
+                save_for_resume[k] = getattr(self, k)
+            fileio.dump(
+                save_for_resume,
+                osp.join(self.runner.work_dir,
+                         f'search_epoch_{self._epoch}.pkl'))
+
+            correlation_str = 'fitting '
+            # correlation_str += f'{self.predictor.type}: '
+            correlation_str += f'RMSE = {rmse:.4f}, '
+            correlation_str += f'Spearmans Rho = {rho:.4f}, '
+            correlation_str += f'Kendalls Tau = {tau:.4f}'
+
+            # pareto mode is disabled in this version
+            self.pareto_mode = False
+            if self.pareto_mode:
+                step_str = '\n'
+                for step, candidates in self.pareto_candidates.items():
+                    if len(candidates) > 0:
+                        step_str += f'step: {step}: '
+                        step_str += f'{candidates[0][self.score_key]}\n'
+                self.runner.logger.info(
+                    f'Epoch:[{self._epoch + 1}/{self._max_epochs}], '
+                    f'top1_score: {step_str} '
+                    f'{correlation_str}')
+            else:
+                self.runner.logger.info(
+                    f'Epoch:[{self._epoch + 1}/{self._max_epochs}], '
+                    f'top1_score: {self.top_k_candidates.scores[0]} '
+                    f'{correlation_str}')
+    def fit_predictor(self, candidates):
+        """Fit the predictor and return predicted scores (error rate)."""
+        inputs = [export_fix_subnet(x) for x in candidates.subnets]
+        inputs = np.array([self.predictor.model2vector(x) for x in inputs])
+
+        targets = np.array([x[1] for x in candidates])
+
+        if not self.predictor.pretrained:
+            self.predictor.fit(inputs, targets)
+
+        metrics = self.predictor.predict(inputs)
+        max_score_key = self.trade_off['max_score_key']
+        if max_score_key != 0:
+            for i in range(len(metrics)):
+                metrics[i] = max_score_key - metrics[i]
+        return metrics
+    def finetune_step(self, model):
+        """Finetune before candidates evaluation."""
+        # TODO (gaoyang): update with 2.0 version.
+        self.runner.logger.info('start finetuning...')
+        model.train()
+        while self._finetune_epoch < self._max_finetune_epochs:
+            self.runner.call_hook('before_train_epoch')
+            for idx, data_batch in enumerate(self.dataloader):
+                self.runner.call_hook(
+                    'before_train_iter',
+                    batch_idx=idx,
+                    data_batch=data_batch)
+
+                outputs = model.train_step(
+                    data_batch, optim_wrapper=self.optim_wrapper)
+
+                self.runner.call_hook(
+                    'after_train_iter',
+                    batch_idx=idx,
+                    data_batch=data_batch,
+                    outputs=outputs)
+
+            self.runner.call_hook('after_train_epoch')
+            self._finetune_epoch += 1
+
+        model.eval()
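A hedged sketch of consuming the checkpoint that `_save_searcher_ckpt` dumps each epoch; the path below is illustrative, and the same file can be handed to the loop's `resume_from` argument:

```python
from mmengine import fileio

# illustrative path; `_save_searcher_ckpt` writes the file into
# `runner.work_dir` as `search_epoch_{epoch}.pkl`
ckpt = fileio.load('work_dirs/nsga2_search/search_epoch_19.pkl')

print(ckpt['_epoch'])                   # last finished search epoch
print(len(ckpt['candidates']))          # size of the candidate pool
print(ckpt['top_k_candidates'].scores)  # scores of the best subnets
```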
