Skip to content

Commit f6b294a

Browse files
committed
spirv-opt: Handle id overflow in MergeReturnPass
This CL adds error handling to the MergeReturnPass to gracefully handle cases where the pass runs out of IDs. The following functions were modified to return a boolean indicating success or failure: - AddNewPhiNodes - AddReturnFlag - AddReturnValue - BranchToBlock - CreatePhiNodesForInst - ProcessStructuredBlock - RecordReturned - UpdatePhiNodes The callers of these functions were updated to check the return value and propagate the failure. This prevents the pass from crashing when it runs out of IDs.
1 parent 7f2d9ee commit f6b294a

File tree

2 files changed

+149
-68
lines changed

2 files changed

+149
-68
lines changed

source/opt/merge_return_pass.cpp

Lines changed: 119 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,9 @@ bool MergeReturnPass::ProcessStructured(
134134
state_.pop_back();
135135
}
136136

137-
ProcessStructuredBlock(block);
137+
if (!ProcessStructuredBlock(block)) {
138+
return false;
139+
}
138140

139141
// Generate state for next block if warranted
140142
GenerateState(block);
@@ -169,7 +171,9 @@ bool MergeReturnPass::ProcessStructured(
169171
// We have not kept the dominator tree up-to-date.
170172
// Invalidate it at this point to make sure it will be rebuilt.
171173
context()->RemoveDominatorAnalysis(function);
172-
AddNewPhiNodes();
174+
if (!AddNewPhiNodes()) {
175+
return false;
176+
}
173177
return true;
174178
}
175179

@@ -196,7 +200,9 @@ bool MergeReturnPass::CreateReturnBlock() {
196200
}
197201

198202
bool MergeReturnPass::CreateReturn(BasicBlock* block) {
199-
AddReturnValue();
203+
if (!AddReturnValue()) {
204+
return false;
205+
}
200206

201207
if (return_value_) {
202208
// Load and return the final return value
@@ -229,12 +235,18 @@ bool MergeReturnPass::CreateReturn(BasicBlock* block) {
229235
return true;
230236
}
231237

232-
void MergeReturnPass::ProcessStructuredBlock(BasicBlock* block) {
238+
bool MergeReturnPass::ProcessStructuredBlock(BasicBlock* block) {
239+
if (block->tail() == block->end()) {
240+
return true;
241+
}
242+
233243
spv::Op tail_opcode = block->tail()->opcode();
234244
if (tail_opcode == spv::Op::OpReturn ||
235245
tail_opcode == spv::Op::OpReturnValue) {
236246
if (!return_flag_) {
237-
AddReturnFlag();
247+
if (!AddReturnFlag()) {
248+
return false;
249+
}
238250
}
239251
}
240252

@@ -243,43 +255,57 @@ void MergeReturnPass::ProcessStructuredBlock(BasicBlock* block) {
243255
tail_opcode == spv::Op::OpUnreachable) {
244256
assert(CurrentState().InBreakable() &&
245257
"Should be in the placeholder construct.");
246-
BranchToBlock(block, CurrentState().BreakMergeId());
258+
if (!BranchToBlock(block, CurrentState().BreakMergeId())) {
259+
return false;
260+
}
247261
return_blocks_.insert(block->id());
248262
}
263+
return true;
249264
}
250265

251-
void MergeReturnPass::BranchToBlock(BasicBlock* block, uint32_t target) {
266+
bool MergeReturnPass::BranchToBlock(BasicBlock* block, uint32_t target) {
252267
if (block->tail()->opcode() == spv::Op::OpReturn ||
253268
block->tail()->opcode() == spv::Op::OpReturnValue) {
254-
RecordReturned(block);
269+
if (!RecordReturned(block)) {
270+
return false;
271+
}
255272
RecordReturnValue(block);
256273
}
257274

258275
BasicBlock* target_block = context()->get_instr_block(target);
259276
if (target_block->GetLoopMergeInst()) {
260277
cfg()->SplitLoopHeader(target_block);
261278
}
262-
UpdatePhiNodes(block, target_block);
279+
if (!UpdatePhiNodes(block, target_block)) {
280+
return false;
281+
}
263282

264283
Instruction* return_inst = block->terminator();
265284
return_inst->SetOpcode(spv::Op::OpBranch);
266285
return_inst->ReplaceOperands({{SPV_OPERAND_TYPE_ID, {target}}});
267286
context()->get_def_use_mgr()->AnalyzeInstDefUse(return_inst);
268287
new_edges_[target_block].insert(block->id());
269288
cfg()->AddEdge(block->id(), target);
289+
return true;
270290
}
271291

272-
void MergeReturnPass::UpdatePhiNodes(BasicBlock* new_source,
292+
bool MergeReturnPass::UpdatePhiNodes(BasicBlock* new_source,
273293
BasicBlock* target) {
274-
target->ForEachPhiInst([this, new_source](Instruction* inst) {
294+
bool succeeded = true;
295+
target->ForEachPhiInst([this, new_source, &succeeded](Instruction* inst) {
275296
uint32_t undefId = Type2Undef(inst->type_id());
297+
if (undefId == 0) {
298+
succeeded = false;
299+
return;
300+
}
276301
inst->AddOperand({SPV_OPERAND_TYPE_ID, {undefId}});
277302
inst->AddOperand({SPV_OPERAND_TYPE_ID, {new_source->id()}});
278303
context()->UpdateDefUse(inst);
279304
});
305+
return succeeded;
280306
}
281307

282-
void MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block,
308+
bool MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block,
283309
Instruction& inst) {
284310
DominatorAnalysis* dom_tree =
285311
context()->GetDominatorAnalysis(merge_block->GetParent());
@@ -313,7 +339,7 @@ void MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block,
313339
});
314340

315341
if (users_to_update.empty()) {
316-
return;
342+
return true;
317343
}
318344

319345
// There is at least one values that needs to be replaced.
@@ -357,6 +383,9 @@ void MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block,
357383
if (regenerateInstruction) {
358384
std::unique_ptr<Instruction> regen_inst(inst.Clone(context()));
359385
uint32_t new_id = TakeNextId();
386+
if (new_id == 0) {
387+
return false;
388+
}
360389
regen_inst->SetResultId(new_id);
361390
Instruction* insert_pos = &*merge_block->begin();
362391
while (insert_pos->opcode() == spv::Op::OpPhi) {
@@ -366,19 +395,31 @@ void MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block,
366395
get_def_use_mgr()->AnalyzeInstDefUse(new_phi);
367396
context()->set_instr_block(new_phi, merge_block);
368397

369-
new_phi->ForEachInId([dom_tree, merge_block, this](uint32_t* use_id) {
398+
bool succeeded = true;
399+
new_phi->ForEachInId([dom_tree, merge_block, this,
400+
&succeeded](uint32_t* use_id) {
401+
if (!succeeded) {
402+
return;
403+
}
370404
Instruction* use = get_def_use_mgr()->GetDef(*use_id);
371405
BasicBlock* use_bb = context()->get_instr_block(use);
372406
if (use_bb != nullptr && !dom_tree->Dominates(use_bb, merge_block)) {
373-
CreatePhiNodesForInst(merge_block, *use);
407+
if (!CreatePhiNodesForInst(merge_block, *use)) {
408+
succeeded = false;
409+
}
374410
}
375411
});
412+
if (!succeeded) {
413+
return false;
414+
}
376415
} else {
377416
InstructionBuilder builder(
378417
context(), &*merge_block->begin(),
379418
IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
380-
// TODO(1841): Handle id overflow.
381419
new_phi = builder.AddPhi(inst.type_id(), phi_operands);
420+
if (new_phi == nullptr) {
421+
return false;
422+
}
382423
}
383424
uint32_t result_of_phi = new_phi->result_id();
384425

@@ -392,6 +433,7 @@ void MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block,
392433
context()->AnalyzeUses(user);
393434
}
394435
}
436+
return true;
395437
}
396438

397439
bool MergeReturnPass::PredicateBlocks(
@@ -484,6 +526,9 @@ bool MergeReturnPass::BreakFromConstruct(
484526
cfg()->RemoveSuccessorEdges(block);
485527

486528
auto old_body_id = TakeNextId();
529+
if (old_body_id == 0) {
530+
return false;
531+
}
487532
BasicBlock* old_body = block->SplitBasicBlock(context(), old_body_id, iter);
488533
predicated->insert(old_body);
489534

@@ -520,9 +565,11 @@ bool MergeReturnPass::BreakFromConstruct(
520565
analysis::Bool bool_type;
521566
uint32_t bool_id = context()->get_type_mgr()->GetId(&bool_type);
522567
assert(bool_id != 0);
523-
// TODO(1841): Handle id overflow.
524-
uint32_t load_id =
525-
builder.AddLoad(bool_id, return_flag_->result_id())->result_id();
568+
Instruction* load_inst = builder.AddLoad(bool_id, return_flag_->result_id());
569+
if (load_inst == nullptr) {
570+
return false;
571+
}
572+
uint32_t load_id = load_inst->result_id();
526573

527574
// 2. Branch to |merge_block| (true) or |old_body| (false)
528575
builder.AddConditionalBranch(load_id, merge_block->id(), old_body->id(),
@@ -535,7 +582,9 @@ bool MergeReturnPass::BreakFromConstruct(
535582
}
536583

537584
// 3. Update OpPhi instructions in |merge_block|.
538-
UpdatePhiNodes(block, merge_block);
585+
if (!UpdatePhiNodes(block, merge_block)) {
586+
return false;
587+
}
539588

540589
// 4. Update the CFG. We do this after updating the OpPhi instructions
541590
// because |UpdatePhiNodes| assumes the edge from |block| has not been added
@@ -548,10 +597,10 @@ bool MergeReturnPass::BreakFromConstruct(
548597
return true;
549598
}
550599

551-
void MergeReturnPass::RecordReturned(BasicBlock* block) {
600+
bool MergeReturnPass::RecordReturned(BasicBlock* block) {
552601
if (block->tail()->opcode() != spv::Op::OpReturn &&
553602
block->tail()->opcode() != spv::Op::OpReturnValue)
554-
return;
603+
return true;
555604

556605
assert(return_flag_ && "Did not generate the return flag variable.");
557606

@@ -564,6 +613,9 @@ void MergeReturnPass::RecordReturned(BasicBlock* block) {
564613
const analysis::Constant* true_const =
565614
const_mgr->GetConstant(bool_type, {true});
566615
constant_true_ = const_mgr->GetDefiningInstruction(true_const);
616+
if (!constant_true_) {
617+
return false;
618+
}
567619
context()->UpdateDefUse(constant_true_);
568620
}
569621

@@ -577,6 +629,7 @@ void MergeReturnPass::RecordReturned(BasicBlock* block) {
577629
&*block->tail().InsertBefore(std::move(return_store));
578630
context()->set_instr_block(store_inst, block);
579631
context()->AnalyzeDefUse(store_inst);
632+
return true;
580633
}
581634

582635
void MergeReturnPass::RecordReturnValue(BasicBlock* block) {
@@ -600,18 +653,21 @@ void MergeReturnPass::RecordReturnValue(BasicBlock* block) {
600653
context()->AnalyzeDefUse(store_inst);
601654
}
602655

603-
void MergeReturnPass::AddReturnValue() {
604-
if (return_value_) return;
656+
bool MergeReturnPass::AddReturnValue() {
657+
if (return_value_) return true;
605658

606659
uint32_t return_type_id = function_->type_id();
607660
if (get_def_use_mgr()->GetDef(return_type_id)->opcode() ==
608661
spv::Op::OpTypeVoid)
609-
return;
662+
return true;
610663

611664
uint32_t return_ptr_type = context()->get_type_mgr()->FindPointerToType(
612665
return_type_id, spv::StorageClass::Function);
613666

614667
uint32_t var_id = TakeNextId();
668+
if (var_id == 0) {
669+
return false;
670+
}
615671
std::unique_ptr<Instruction> returnValue(
616672
new Instruction(context(), spv::Op::OpVariable, return_ptr_type, var_id,
617673
std::initializer_list<Operand>{
@@ -627,27 +683,44 @@ void MergeReturnPass::AddReturnValue() {
627683

628684
context()->get_decoration_mgr()->CloneDecorations(
629685
function_->result_id(), var_id, {spv::Decoration::RelaxedPrecision});
686+
return true;
630687
}
631688

632-
void MergeReturnPass::AddReturnFlag() {
633-
if (return_flag_) return;
689+
bool MergeReturnPass::AddReturnFlag() {
690+
if (return_flag_) return true;
634691

635692
analysis::TypeManager* type_mgr = context()->get_type_mgr();
636693
analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
637694

638695
analysis::Bool temp;
639696
uint32_t bool_id = type_mgr->GetTypeInstruction(&temp);
697+
if (bool_id == 0) {
698+
return false;
699+
}
640700
analysis::Bool* bool_type = type_mgr->GetType(bool_id)->AsBool();
641701

642702
const analysis::Constant* false_const =
643703
const_mgr->GetConstant(bool_type, {false});
644-
uint32_t const_false_id =
645-
const_mgr->GetDefiningInstruction(false_const)->result_id();
704+
Instruction* false_inst = const_mgr->GetDefiningInstruction(false_const);
705+
if (false_inst == nullptr) {
706+
return false;
707+
}
708+
uint32_t const_false_id = false_inst->result_id();
646709

647710
uint32_t bool_ptr_id =
648711
type_mgr->FindPointerToType(bool_id, spv::StorageClass::Function);
649712

713+
if (bool_ptr_id == 0) {
714+
return false;
715+
;
716+
}
717+
650718
uint32_t var_id = TakeNextId();
719+
720+
if (var_id == 0) {
721+
return false;
722+
}
723+
651724
std::unique_ptr<Instruction> returnFlag(new Instruction(
652725
context(), spv::Op::OpVariable, bool_ptr_id, var_id,
653726
std::initializer_list<Operand>{{SPV_OPERAND_TYPE_STORAGE_CLASS,
@@ -661,6 +734,7 @@ void MergeReturnPass::AddReturnFlag() {
661734
return_flag_ = &*entry_block->begin();
662735
context()->AnalyzeDefUse(return_flag_);
663736
context()->set_instr_block(return_flag_, entry_block);
737+
return true;
664738
}
665739

666740
std::vector<BasicBlock*> MergeReturnPass::CollectReturnBlocks(
@@ -739,16 +813,19 @@ bool MergeReturnPass::MergeReturnBlocks(
739813
return true;
740814
}
741815

742-
void MergeReturnPass::AddNewPhiNodes() {
816+
bool MergeReturnPass::AddNewPhiNodes() {
743817
std::list<BasicBlock*> order;
744818
cfg()->ComputeStructuredOrder(function_, &*function_->begin(), &order);
745819

746820
for (BasicBlock* bb : order) {
747-
AddNewPhiNodes(bb);
821+
if (!AddNewPhiNodes(bb)) {
822+
return false;
823+
}
748824
}
825+
return true;
749826
}
750827

751-
void MergeReturnPass::AddNewPhiNodes(BasicBlock* bb) {
828+
bool MergeReturnPass::AddNewPhiNodes(BasicBlock* bb) {
752829
// New phi nodes are needed for any id whose definition used to dominate |bb|,
753830
// but no longer dominates |bb|. These are found by walking the dominator
754831
// tree starting at the original immediate dominator of |bb| and ending at its
@@ -766,16 +843,19 @@ void MergeReturnPass::AddNewPhiNodes(BasicBlock* bb) {
766843

767844
BasicBlock* dominator = dom_tree->ImmediateDominator(bb);
768845
if (dominator == nullptr) {
769-
return;
846+
return true;
770847
}
771848

772849
BasicBlock* current_bb = context()->get_instr_block(original_dominator_[bb]);
773850
while (current_bb != nullptr && current_bb != dominator) {
774851
for (Instruction& inst : *current_bb) {
775-
CreatePhiNodesForInst(bb, inst);
852+
if (!CreatePhiNodesForInst(bb, inst)) {
853+
return false;
854+
}
776855
}
777856
current_bb = dom_tree->ImmediateDominator(current_bb);
778857
}
858+
return true;
779859
}
780860

781861
void MergeReturnPass::RecordImmediateDominators(Function* function) {
@@ -859,8 +939,12 @@ bool MergeReturnPass::CreateSingleCaseSwitch(BasicBlock* merge_target) {
859939
++split_pos;
860940
}
861941

942+
uint32_t new_block_id = TakeNextId();
943+
if (new_block_id == 0) {
944+
return false;
945+
}
862946
BasicBlock* old_block =
863-
start_block->SplitBasicBlock(context(), TakeNextId(), split_pos);
947+
start_block->SplitBasicBlock(context(), new_block_id, split_pos);
864948

865949
// Find DebugFunctionDefinition inst in the old block, and if we can find it,
866950
// move it to the entry block. Since DebugFunctionDefinition is not necessary

0 commit comments

Comments
 (0)