|
| 1 | +# Copyright (c) OpenMMLab. All rights reserved. |
| 2 | +import os.path as osp |
| 3 | +from copy import deepcopy |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +from mmengine import fileio |
| 7 | +from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting |
| 8 | + |
| 9 | +from mmrazor.models.task_modules import (GeneticOptimizer, |
| 10 | + NSGA2Optimizer, |
| 11 | + AuxiliarySingleLevelProblem, |
| 12 | + SubsetProblem) |
| 13 | +from mmrazor.registry import LOOPS |
| 14 | +from mmrazor.structures import Candidates, export_fix_subnet |
| 15 | +from .attentive_search_loop import AttentiveSearchLoop |
| 16 | +from .utils.high_tradeoff_points import HighTradeoffPoints |
| 17 | + |
| 18 | +# from pymoo.algorithms.moo.nsga2 import NSGA2 as NSGA2Optimizer |
| 19 | +# from pymoo.algorithms.soo.nonconvex.ga import GA as GeneticOptimizer |
| 20 | +# from pymoo.optimize import minimize |
| 21 | + |
| 22 | + |
| 23 | +@LOOPS.register_module() |
| 24 | +class NSGA2SearchLoop(AttentiveSearchLoop): |
| 25 | + """Evolution search loop with NSGA2 optimizer.""" |
| 26 | + |
def run_epoch(self) -> None:
    """Iterate one epoch.

    Steps:
        0. Rebuild an archive from the current candidates: every subnet
           is copied together with its score and its 'flops' resource.
           When ``max_score_key`` of ``self.trade_off`` is non-zero the
           stored score is ``max_score_key - score``.
        1. Sample new candidates. ``random`` is True only on epoch 0;
           afterwards the archive is handed to ``sample_candidates``.
        2. Validate these candidates and update their scores.
        3. Merge the previous top-k candidates into the pool, sort it and
           keep the first ``self.top_k`` entries as the new top-k.
        4. Generate mutation and crossover candidates; together they
           become the candidate pool of the next epoch.
    """
    # Read the score ceiling the same way `sort_candidates` does, so both
    # code paths convert scores identically. A missing key falls back to
    # 100 instead of raising KeyError; a missing `trade_off` is still
    # rejected by the assertion inside `sort_candidates` below.
    max_score_key = (self.trade_off or {}).get('max_score_key', 100)

    archive = Candidates()
    for subnet, score, flops in zip(self.candidates.subnets,
                                    self.candidates.scores,
                                    self.candidates.resources('flops')):
        if max_score_key != 0:
            score = max_score_key - score
        archive.append(subnet)
        # Index -1 addresses the entry that was just appended.
        archive.set_score(-1, score)
        archive.set_resource(-1, flops, 'flops')

    self.sample_candidates(random=(self._epoch == 0), archive=archive)
    self.update_candidates_scores()

    scores_before = self.top_k_candidates.scores
    self.runner.logger.info(
        f'top k scores before update: {scores_before}')

    self.candidates.extend(self.top_k_candidates)
    self.sort_candidates()
    # `sort_candidates` rebinds `self.candidates` to a plain list, so the
    # slice has to be wrapped in `Candidates` again.
    self.top_k_candidates = Candidates(self.candidates[:self.top_k])

    scores_after = self.top_k_candidates.scores
    self.runner.logger.info(
        f'top k scores after update: {scores_after}')

    mutation_candidates = self.gen_mutation_candidates()
    self.candidates_mutator_crossover = Candidates(mutation_candidates)
    crossover_candidates = self.gen_crossover_candidates()
    self.candidates_mutator_crossover.extend(crossover_candidates)

    # Implicit string concatenation keeps the message free of the
    # indentation whitespace a backslash continuation would embed.
    assert len(self.candidates_mutator_crossover) <= self.num_candidates, (
        'Total of mutation and crossover should not exceed the number '
        'of candidates.')

    self.candidates = self.candidates_mutator_crossover
    self._epoch += 1
| 76 | + |
def sample_candidates(self, random: bool = True, archive=None) -> None:
    """Refill ``self.candidates`` either randomly or through NSGA2.

    Args:
        random (bool): When True, defer to the parent implementation.
            When False, candidates proposed by
            ``sample_candidates_with_nsga2`` are filtered with
            ``_check_constraints`` and replace ``self.candidates``.
            Defaults to True.
        archive: Forwarded unchanged to ``sample_candidates_with_nsga2``;
            only used when ``random`` is False. Defaults to None.
    """
    if random:
        super().sample_candidates()
        return

    proposals = self.sample_candidates_with_nsga2(archive,
                                                  self.num_candidates)

    accepted = []
    accepted_resources = []
    for proposal in proposals:
        passed, resources = self._check_constraints(proposal)
        if not passed:
            continue
        accepted.append(proposal)
        accepted_resources.append(resources)

    self.candidates = Candidates(accepted)

    if accepted_resources:
        # Resources are written for the trailing entries of the pool.
        offset = len(self.candidates.data) - len(accepted_resources)
        self.candidates.update_resources(accepted_resources, start=offset)
| 96 | + |
def sample_candidates_with_nsga2(self, archive: Candidates, num_candidates):
    """Search for new candidates with NSGA2, seeded by the archive front.

    Args:
        archive (Candidates): Evaluated subnets; their scores and 'flops'
            resources form the two objectives.
        num_candidates: Forwarded to ``SubsetProblem`` and used as the
            ``pop_size`` of the subset-selection ``GeneticOptimizer``.

    Returns:
        Candidates: Subnets decoded from the selected NSGA2 population
        members. Empty when the archive is empty or when every NSGA2
        result is already contained in the archive.
    """
    if len(archive) == 0:
        # Without archived subnets there is no front to seed NSGA2 with;
        # `fronts[0]` below would raise IndexError.
        self.runner.logger.warning(
            'NSGA2 sampling skipped: the archive is empty.')
        return Candidates()

    # Objective matrix: column 0 holds the archive scores, column 1 flops.
    F = np.column_stack((archive.scores, archive.resources('flops')))
    front_index = NonDominatedSorting().do(
        F, only_non_dominated_front=True)

    # Encode the non-dominated subnets for the predictor.
    known_subnets = archive.subnets
    fronts = np.array(known_subnets)[front_index]
    fronts = np.array(
        [self.predictor.model2vector(cand) for cand in fronts])
    fronts = self.predictor.preprocess(fronts)

    # initialize the candidate finding optimization problem
    problem = AuxiliarySingleLevelProblem(self, len(fronts[0]))

    # initiate a multi-objective solver to optimize the problem
    method = NSGA2Optimizer(
        pop_size=4,
        sampling=fronts,  # initialize with current nd archs
        eliminate_duplicates=True,
        logger=self.runner.logger)

    # kick-off the search
    method.initialize(problem, n_gen=2, verbose=True)
    result = method.solve()

    # check for duplicates
    check_list = []
    for x in result['pop'].get('X'):
        assert x is not None
        check_list.append(self.predictor.vector2model(x))

    not_duplicate = np.logical_not(
        [subnet in known_subnets for subnet in check_list])
    if not not_duplicate.any():
        # Nothing new was found; slicing column 1 of an empty objective
        # array below would fail.
        self.runner.logger.warning(
            'NSGA2 sampling produced no subnet outside the archive.')
        return Candidates()

    fresh_pop = result['pop'][not_duplicate]

    # extra process after nsga2 search: pick a subset of the new
    # individuals using the second objective column of the population
    # (presumably flops, mirroring `F` -- verify against the problem
    # definition) and the flops of the current front.
    sub_problem = SubsetProblem(
        fresh_pop.get('F')[:, 1], F[front_index, 1], num_candidates)
    sub_method = GeneticOptimizer(
        pop_size=num_candidates, eliminate_duplicates=True)
    sub_method.initialize(sub_problem, n_gen=4, verbose=False)
    indices = sub_method.solve()['X']

    candidates = Candidates()
    for x in fresh_pop[indices].get('X'):
        candidates.append(self.predictor.vector2model(x))

    return candidates
| 144 | + |
def sort_candidates(self) -> None:
    """Reorder ``self.candidates`` by score/flops trade-off.

    Builds a (score, flops) matrix for the current candidates, converts
    the score column to ``max_score_key - score`` when ``max_score_key``
    is non-zero, sorts by that column and lets ``HighTradeoffPoints``
    pick the final order.

    Note:
        ``self.candidates`` is rebound to a plain ``list`` of the
        reordered entries.

    Raises:
        AssertionError: If ``self.trade_off`` is None.
    """
    assert self.trade_off is not None, (
        '`self.trade_off` is required when sorting candidates in '
        'NSGA2SearchLoop. Got self.trade_off is None.')
    trade_off_ratio = self.trade_off.get('ratio', 1)

    # One row per candidate: column 0 is the score, column 1 the flops.
    objectives = np.array([
        (score, flops)
        for score, flops in zip(self.candidates.scores,
                                self.candidates.resources('flops'))
    ])

    score_ceiling = self.trade_off.get('max_score_key', 100)
    if score_ceiling != 0:
        objectives[:, 0] = score_ceiling - objectives[:, 0]

    order = np.argsort(objectives[:, 0])
    ordered_objectives = objectives[order]
    selector = HighTradeoffPoints(
        trade_off_ratio, n_survive=len(objectives))
    picked = selector.do(ordered_objectives)

    # Map positions in the sorted matrix back to candidate indices.
    survivors = order[picked]
    self.candidates = [self.candidates[idx] for idx in survivors]
| 166 | + |
def _save_searcher_ckpt(self, archive=None):
    """Save searcher ckpt, which is different from common ckpt.

    It mainly contains the candidate pool, the top-k candidates with scores
    and the current epoch. Nothing is written or logged unless
    ``self.runner.rank`` is 0.

    Args:
        archive (Candidates, optional): When given and non-empty, the
            predictor is fitted on it through ``fit_predictor`` and its
            correlation with the archive scores is logged. Defaults to
            None (a shared mutable ``[]`` default is avoided on purpose).
    """
    if self.runner.rank != 0:
        return

    rmse, rho, tau = 0, 0, 0
    if archive is not None and len(archive) > 0:
        top1_err_pred = self.fit_predictor(archive)
        # Use the `scores` accessor instead of indexing each entry with
        # `[1]`; `fit_predictor` already requires a `Candidates` here.
        rmse, rho, tau = self.predictor.get_correlation(
            top1_err_pred, np.array(archive.scores))

    save_for_resume = {
        '_epoch': self._epoch,
        'candidates': self.candidates,
        'top_k_candidates': self.top_k_candidates,
    }
    fileio.dump(
        save_for_resume,
        osp.join(self.runner.work_dir, f'search_epoch_{self._epoch}.pkl'))

    correlation_str = (f'fitting RMSE = {rmse:.4f}, '
                       f'Spearmans Rho = {rho:.4f}, '
                       f'Kendalls Tau = {tau:.4f}')

    # NOTE(review): `pareto_mode` is forced to False right before it is
    # tested, so the pareto branch below can never run. Kept as-is to
    # preserve behavior -- confirm whether the override is intentional.
    self.pareto_mode = False
    if self.pareto_mode:
        step_str = '\n'
        for step, candidates in self.pareto_candidates.items():
            if len(candidates) > 0:
                step_str += f'step: {step}: '
                step_str += f'{candidates[0][self.score_key]}\n'
        self.runner.logger.info(
            f'Epoch:[{self._epoch + 1}/{self._max_epochs}], '
            f'top1_score: {step_str} '
            f'{correlation_str}')
    else:
        self.runner.logger.info(
            f'Epoch:[{self._epoch + 1}/{self._max_epochs}], '
            f'top1_score: {self.top_k_candidates.scores[0]} '
            f'{correlation_str}')
| 211 | + |
def fit_predictor(self, candidates):
    """Fit the predictor on evaluated candidates and predict them back.

    Args:
        candidates (Candidates): Evaluated candidates; their ``subnets``
            are encoded as predictor inputs and their ``scores`` are the
            regression targets.

    Returns:
        The output of ``self.predictor.predict`` for the encoded
        candidates. When ``self.max_score_key`` is non-zero every entry
        is replaced in place by ``self.max_score_key - entry``.
    """
    inputs = [export_fix_subnet(x) for x in candidates.subnets]
    inputs = np.array([self.predictor.model2vector(x) for x in inputs])

    # Read the targets through the `scores` accessor used everywhere else
    # in this loop rather than indexing each entry with `[1]`.
    targets = np.array(candidates.scores)

    # A pretrained predictor is used as-is; otherwise fit it first.
    if not self.predictor.pretrained:
        self.predictor.fit(inputs, targets)

    metrics = self.predictor.predict(inputs)
    # NOTE(review): other methods read `max_score_key` from
    # `self.trade_off`; confirm `self.max_score_key` is kept in sync.
    if self.max_score_key != 0:
        for i, metric in enumerate(metrics):
            metrics[i] = self.max_score_key - metric
    return metrics
| 227 | + |
def finetune_step(self, model):
    """Finetune ``model`` before candidates evaluation.

    Runs training epochs over ``self.dataloader`` until
    ``self._finetune_epoch`` reaches ``self._max_finetune_epochs``,
    firing the runner's train hooks around every epoch and iteration,
    then switches the model back to eval mode.

    Args:
        model: Object providing ``train()``, ``eval()`` and
            ``train_step(data_batch, optim_wrapper=...)``.
    """
    # TODO (gaoyang): update with 2.0 version.
    self.runner.logger.info('start finetuning...')
    model.train()
    # The loop condition must read the same counter that is incremented
    # at the end of each epoch; a misspelled attribute here would either
    # raise AttributeError or keep the loop from ever terminating.
    while self._finetune_epoch < self._max_finetune_epochs:
        self.runner.call_hook('before_train_epoch')
        for idx, data_batch in enumerate(self.dataloader):
            self.runner.call_hook(
                'before_train_iter',
                batch_idx=idx,
                data_batch=data_batch)

            outputs = model.train_step(
                data_batch, optim_wrapper=self.optim_wrapper)

            self.runner.call_hook(
                'after_train_iter',
                batch_idx=idx,
                data_batch=data_batch,
                outputs=outputs)

        self.runner.call_hook('after_train_epoch')
        self._finetune_epoch += 1

    model.eval()
0 commit comments