Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 119 additions & 35 deletions source/opt/merge_return_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,9 @@ bool MergeReturnPass::ProcessStructured(
state_.pop_back();
}

ProcessStructuredBlock(block);
if (!ProcessStructuredBlock(block)) {
return false;
}

// Generate state for next block if warranted
GenerateState(block);
Expand Down Expand Up @@ -169,7 +171,9 @@ bool MergeReturnPass::ProcessStructured(
// We have not kept the dominator tree up-to-date.
// Invalidate it at this point to make sure it will be rebuilt.
context()->RemoveDominatorAnalysis(function);
AddNewPhiNodes();
if (!AddNewPhiNodes()) {
return false;
}
return true;
}

Expand All @@ -196,7 +200,9 @@ bool MergeReturnPass::CreateReturnBlock() {
}

bool MergeReturnPass::CreateReturn(BasicBlock* block) {
AddReturnValue();
if (!AddReturnValue()) {
return false;
}

if (return_value_) {
// Load and return the final return value
Expand Down Expand Up @@ -229,12 +235,18 @@ bool MergeReturnPass::CreateReturn(BasicBlock* block) {
return true;
}

void MergeReturnPass::ProcessStructuredBlock(BasicBlock* block) {
bool MergeReturnPass::ProcessStructuredBlock(BasicBlock* block) {
if (block->tail() == block->end()) {
return true;
}

spv::Op tail_opcode = block->tail()->opcode();
if (tail_opcode == spv::Op::OpReturn ||
tail_opcode == spv::Op::OpReturnValue) {
if (!return_flag_) {
AddReturnFlag();
if (!AddReturnFlag()) {
return false;
}
}
}

Expand All @@ -243,43 +255,57 @@ void MergeReturnPass::ProcessStructuredBlock(BasicBlock* block) {
tail_opcode == spv::Op::OpUnreachable) {
assert(CurrentState().InBreakable() &&
"Should be in the placeholder construct.");
BranchToBlock(block, CurrentState().BreakMergeId());
if (!BranchToBlock(block, CurrentState().BreakMergeId())) {
return false;
}
return_blocks_.insert(block->id());
}
return true;
}

void MergeReturnPass::BranchToBlock(BasicBlock* block, uint32_t target) {
bool MergeReturnPass::BranchToBlock(BasicBlock* block, uint32_t target) {
if (block->tail()->opcode() == spv::Op::OpReturn ||
block->tail()->opcode() == spv::Op::OpReturnValue) {
RecordReturned(block);
if (!RecordReturned(block)) {
return false;
}
RecordReturnValue(block);
}

BasicBlock* target_block = context()->get_instr_block(target);
if (target_block->GetLoopMergeInst()) {
cfg()->SplitLoopHeader(target_block);
}
UpdatePhiNodes(block, target_block);
if (!UpdatePhiNodes(block, target_block)) {
return false;
}

Instruction* return_inst = block->terminator();
return_inst->SetOpcode(spv::Op::OpBranch);
return_inst->ReplaceOperands({{SPV_OPERAND_TYPE_ID, {target}}});
context()->get_def_use_mgr()->AnalyzeInstDefUse(return_inst);
new_edges_[target_block].insert(block->id());
cfg()->AddEdge(block->id(), target);
return true;
}

void MergeReturnPass::UpdatePhiNodes(BasicBlock* new_source,
bool MergeReturnPass::UpdatePhiNodes(BasicBlock* new_source,
BasicBlock* target) {
target->ForEachPhiInst([this, new_source](Instruction* inst) {
bool succeeded = true;
target->ForEachPhiInst([this, new_source, &succeeded](Instruction* inst) {
uint32_t undefId = Type2Undef(inst->type_id());
if (undefId == 0) {
succeeded = false;
return;
}
inst->AddOperand({SPV_OPERAND_TYPE_ID, {undefId}});
inst->AddOperand({SPV_OPERAND_TYPE_ID, {new_source->id()}});
context()->UpdateDefUse(inst);
});
return succeeded;
}

void MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block,
bool MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block,
Instruction& inst) {
DominatorAnalysis* dom_tree =
context()->GetDominatorAnalysis(merge_block->GetParent());
Expand Down Expand Up @@ -313,7 +339,7 @@ void MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block,
});

if (users_to_update.empty()) {
return;
return true;
}

// There is at least one values that needs to be replaced.
Expand Down Expand Up @@ -357,6 +383,9 @@ void MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block,
if (regenerateInstruction) {
std::unique_ptr<Instruction> regen_inst(inst.Clone(context()));
uint32_t new_id = TakeNextId();
if (new_id == 0) {
return false;
}
regen_inst->SetResultId(new_id);
Instruction* insert_pos = &*merge_block->begin();
while (insert_pos->opcode() == spv::Op::OpPhi) {
Expand All @@ -366,19 +395,31 @@ void MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block,
get_def_use_mgr()->AnalyzeInstDefUse(new_phi);
context()->set_instr_block(new_phi, merge_block);

new_phi->ForEachInId([dom_tree, merge_block, this](uint32_t* use_id) {
bool succeeded = true;
new_phi->ForEachInId([dom_tree, merge_block, this,
&succeeded](uint32_t* use_id) {
if (!succeeded) {
return;
}
Instruction* use = get_def_use_mgr()->GetDef(*use_id);
BasicBlock* use_bb = context()->get_instr_block(use);
if (use_bb != nullptr && !dom_tree->Dominates(use_bb, merge_block)) {
CreatePhiNodesForInst(merge_block, *use);
if (!CreatePhiNodesForInst(merge_block, *use)) {
succeeded = false;
}
}
});
if (!succeeded) {
return false;
}
} else {
InstructionBuilder builder(
context(), &*merge_block->begin(),
IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
// TODO(1841): Handle id overflow.
new_phi = builder.AddPhi(inst.type_id(), phi_operands);
if (new_phi == nullptr) {
return false;
}
}
uint32_t result_of_phi = new_phi->result_id();

Expand All @@ -392,6 +433,7 @@ void MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block,
context()->AnalyzeUses(user);
}
}
return true;
}

bool MergeReturnPass::PredicateBlocks(
Expand Down Expand Up @@ -484,6 +526,9 @@ bool MergeReturnPass::BreakFromConstruct(
cfg()->RemoveSuccessorEdges(block);

auto old_body_id = TakeNextId();
if (old_body_id == 0) {
return false;
}
BasicBlock* old_body = block->SplitBasicBlock(context(), old_body_id, iter);
predicated->insert(old_body);

Expand Down Expand Up @@ -520,9 +565,11 @@ bool MergeReturnPass::BreakFromConstruct(
analysis::Bool bool_type;
uint32_t bool_id = context()->get_type_mgr()->GetId(&bool_type);
assert(bool_id != 0);
// TODO(1841): Handle id overflow.
uint32_t load_id =
builder.AddLoad(bool_id, return_flag_->result_id())->result_id();
Instruction* load_inst = builder.AddLoad(bool_id, return_flag_->result_id());
if (load_inst == nullptr) {
return false;
}
uint32_t load_id = load_inst->result_id();

// 2. Branch to |merge_block| (true) or |old_body| (false)
builder.AddConditionalBranch(load_id, merge_block->id(), old_body->id(),
Expand All @@ -535,7 +582,9 @@ bool MergeReturnPass::BreakFromConstruct(
}

// 3. Update OpPhi instructions in |merge_block|.
UpdatePhiNodes(block, merge_block);
if (!UpdatePhiNodes(block, merge_block)) {
return false;
}

// 4. Update the CFG. We do this after updating the OpPhi instructions
// because |UpdatePhiNodes| assumes the edge from |block| has not been added
Expand All @@ -548,10 +597,10 @@ bool MergeReturnPass::BreakFromConstruct(
return true;
}

void MergeReturnPass::RecordReturned(BasicBlock* block) {
bool MergeReturnPass::RecordReturned(BasicBlock* block) {
if (block->tail()->opcode() != spv::Op::OpReturn &&
block->tail()->opcode() != spv::Op::OpReturnValue)
return;
return true;

assert(return_flag_ && "Did not generate the return flag variable.");

Expand All @@ -564,6 +613,9 @@ void MergeReturnPass::RecordReturned(BasicBlock* block) {
const analysis::Constant* true_const =
const_mgr->GetConstant(bool_type, {true});
constant_true_ = const_mgr->GetDefiningInstruction(true_const);
if (!constant_true_) {
return false;
}
context()->UpdateDefUse(constant_true_);
}

Expand All @@ -577,6 +629,7 @@ void MergeReturnPass::RecordReturned(BasicBlock* block) {
&*block->tail().InsertBefore(std::move(return_store));
context()->set_instr_block(store_inst, block);
context()->AnalyzeDefUse(store_inst);
return true;
}

void MergeReturnPass::RecordReturnValue(BasicBlock* block) {
Expand All @@ -600,18 +653,21 @@ void MergeReturnPass::RecordReturnValue(BasicBlock* block) {
context()->AnalyzeDefUse(store_inst);
}

void MergeReturnPass::AddReturnValue() {
if (return_value_) return;
bool MergeReturnPass::AddReturnValue() {
if (return_value_) return true;

uint32_t return_type_id = function_->type_id();
if (get_def_use_mgr()->GetDef(return_type_id)->opcode() ==
spv::Op::OpTypeVoid)
return;
return true;

uint32_t return_ptr_type = context()->get_type_mgr()->FindPointerToType(
return_type_id, spv::StorageClass::Function);

uint32_t var_id = TakeNextId();
if (var_id == 0) {
return false;
}
std::unique_ptr<Instruction> returnValue(
new Instruction(context(), spv::Op::OpVariable, return_ptr_type, var_id,
std::initializer_list<Operand>{
Expand All @@ -627,27 +683,44 @@ void MergeReturnPass::AddReturnValue() {

context()->get_decoration_mgr()->CloneDecorations(
function_->result_id(), var_id, {spv::Decoration::RelaxedPrecision});
return true;
}

void MergeReturnPass::AddReturnFlag() {
if (return_flag_) return;
bool MergeReturnPass::AddReturnFlag() {
if (return_flag_) return true;

analysis::TypeManager* type_mgr = context()->get_type_mgr();
analysis::ConstantManager* const_mgr = context()->get_constant_mgr();

analysis::Bool temp;
uint32_t bool_id = type_mgr->GetTypeInstruction(&temp);
if (bool_id == 0) {
return false;
}
analysis::Bool* bool_type = type_mgr->GetType(bool_id)->AsBool();

const analysis::Constant* false_const =
const_mgr->GetConstant(bool_type, {false});
uint32_t const_false_id =
const_mgr->GetDefiningInstruction(false_const)->result_id();
Instruction* false_inst = const_mgr->GetDefiningInstruction(false_const);
if (false_inst == nullptr) {
return false;
}
uint32_t const_false_id = false_inst->result_id();

uint32_t bool_ptr_id =
type_mgr->FindPointerToType(bool_id, spv::StorageClass::Function);

if (bool_ptr_id == 0) {
return false;
;
}

uint32_t var_id = TakeNextId();

if (var_id == 0) {
return false;
}

std::unique_ptr<Instruction> returnFlag(new Instruction(
context(), spv::Op::OpVariable, bool_ptr_id, var_id,
std::initializer_list<Operand>{{SPV_OPERAND_TYPE_STORAGE_CLASS,
Expand All @@ -661,6 +734,7 @@ void MergeReturnPass::AddReturnFlag() {
return_flag_ = &*entry_block->begin();
context()->AnalyzeDefUse(return_flag_);
context()->set_instr_block(return_flag_, entry_block);
return true;
}

std::vector<BasicBlock*> MergeReturnPass::CollectReturnBlocks(
Expand Down Expand Up @@ -739,16 +813,19 @@ bool MergeReturnPass::MergeReturnBlocks(
return true;
}

void MergeReturnPass::AddNewPhiNodes() {
bool MergeReturnPass::AddNewPhiNodes() {
std::list<BasicBlock*> order;
cfg()->ComputeStructuredOrder(function_, &*function_->begin(), &order);

for (BasicBlock* bb : order) {
AddNewPhiNodes(bb);
if (!AddNewPhiNodes(bb)) {
return false;
}
}
return true;
}

void MergeReturnPass::AddNewPhiNodes(BasicBlock* bb) {
bool MergeReturnPass::AddNewPhiNodes(BasicBlock* bb) {
// New phi nodes are needed for any id whose definition used to dominate |bb|,
// but no longer dominates |bb|. These are found by walking the dominator
// tree starting at the original immediate dominator of |bb| and ending at its
Expand All @@ -766,16 +843,19 @@ void MergeReturnPass::AddNewPhiNodes(BasicBlock* bb) {

BasicBlock* dominator = dom_tree->ImmediateDominator(bb);
if (dominator == nullptr) {
return;
return true;
}

BasicBlock* current_bb = context()->get_instr_block(original_dominator_[bb]);
while (current_bb != nullptr && current_bb != dominator) {
for (Instruction& inst : *current_bb) {
CreatePhiNodesForInst(bb, inst);
if (!CreatePhiNodesForInst(bb, inst)) {
return false;
}
}
current_bb = dom_tree->ImmediateDominator(current_bb);
}
return true;
}

void MergeReturnPass::RecordImmediateDominators(Function* function) {
Expand Down Expand Up @@ -859,8 +939,12 @@ bool MergeReturnPass::CreateSingleCaseSwitch(BasicBlock* merge_target) {
++split_pos;
}

uint32_t new_block_id = TakeNextId();
if (new_block_id == 0) {
return false;
}
BasicBlock* old_block =
start_block->SplitBasicBlock(context(), TakeNextId(), split_pos);
start_block->SplitBasicBlock(context(), new_block_id, split_pos);

// Find DebugFunctionDefinition inst in the old block, and if we can find it,
// move it to the entry block. Since DebugFunctionDefinition is not necessary
Expand Down
Loading