diff --git a/libs/@local/hashql/mir/src/pass/transform/cfg_simplify/mod.rs b/libs/@local/hashql/mir/src/pass/transform/cfg_simplify/mod.rs index 4adfaea5df4..9795613e5e4 100644 --- a/libs/@local/hashql/mir/src/pass/transform/cfg_simplify/mod.rs +++ b/libs/@local/hashql/mir/src/pass/transform/cfg_simplify/mod.rs @@ -140,21 +140,29 @@ impl CfgSimplify { return false; } - let [block, target] = body - .basic_blocks - .as_mut() - .get_disjoint_mut([id, goto.target.block]) - .unwrap_or_else(|_err| unreachable!("self-loops excluded by check above")); - // This is the only special case, if there are multiple predecessors, and the target itself // is a self-loop we cannot safely merge them. The reason is that in that case we wouldn't - // be able to make any progress, as expansion would be infinite. - if let TerminatorKind::Goto(target_goto) = target.terminator.kind + // be able to make any progress upon expansion, as we would replace our own terminator with + // the exact same one. We could broaden the search to also check params (which would still + // be correct), this case alone leads to more code generation as we're generating a + // superfluous assignment. + // The `target_predecessors_len` check isn't 100% necessary, as this case can only happen + // iff the target is a self-loop, hence has multiple predecessors, but allows us to be a bit + // more defensive about that fact. + if target_predecessors_len > 1 + && let TerminatorKind::Goto(target_goto) = + body.basic_blocks[goto.target.block].terminator.kind && target_goto.target.block == goto.target.block { return false; } + let [block, target] = body + .basic_blocks + .as_mut() + .get_disjoint_mut([id, goto.target.block]) + .unwrap_or_else(|_err| unreachable!("self-loops excluded by check above")); + // Step 1: Assign block parameters before moving statements to maintain def-before-use. 
debug_assert_eq!(target.params.len(), goto.target.args.len()); for (&param, &arg) in target.params.iter().zip(goto.target.args) { diff --git a/libs/@local/hashql/mir/src/pass/transform/ssa_repair/mod.rs b/libs/@local/hashql/mir/src/pass/transform/ssa_repair/mod.rs index 3dc07242817..e26dcad7fa6 100644 --- a/libs/@local/hashql/mir/src/pass/transform/ssa_repair/mod.rs +++ b/libs/@local/hashql/mir/src/pass/transform/ssa_repair/mod.rs @@ -51,8 +51,8 @@ use crate::{ local::{Local, LocalDecl, LocalVec}, location::Location, operand::Operand, - place::{DefUse, Place, PlaceContext}, - statement::Statement, + place::{DefUse, Place, PlaceContext, PlaceWriteContext}, + statement::{Assign, Statement}, terminator::Target, }, context::MirContext, @@ -192,12 +192,14 @@ enum FindDefFromTop { Idom(Local), /// A new block parameter was inserted to merge reaching definitions from predecessors. Param(Local), + /// The reaching definition is already present in the block parameter. + Existing(Local), } impl FindDefFromTop { /// Extracts the local from either variant. const fn into_local(self) -> Local { - let (Self::Idom(local) | Self::Param(local)) = self; + let (Self::Idom(local) | Self::Param(local) | Self::Existing(local)) = self; local } @@ -355,13 +357,26 @@ impl<'ctx, 'mir, 'env, 'heap> SsaViolationRepair<'ctx, 'mir, 'env, 'heap> { // def to unify each calling site. // The calling sites are then wired up in the next step. if self.iterated.contains(id) { - // We create a new local declaration for the block param. - let local_decl = local_decls[self.local]; - let local = local_decls.push(local_decl); - - // If we're part of the DF+ then we will have a new local declaration for the block - // param. 
- self.block_top.insert(id, FindDefFromTop::Param(local)); + // Check if we already have a local declaration for the block param + let local = if let Some(index) = self.locations.iter().position( + |&Location { + block, + statement_index, + }| block == id && statement_index == 0, + ) { + let local = self.locals[index]; + self.block_top.insert(id, FindDefFromTop::Existing(local)); + local + } else { + // We must create a new local declaration for the block param. + let local_decl = local_decls[self.local]; + let local = local_decls.push(local_decl); + + // If we're part of the DF+ then we will have a new local declaration for the block + // param. + self.block_top.insert(id, FindDefFromTop::Param(local)); + local + }; // It's important that we set the block def *before* we recurse, otherwise a loop will // create an infinite recursion case. The live-out (aka the block def) is always the @@ -553,6 +568,25 @@ impl<'heap> VisitorMut<'heap> for RewireBody<'_, 'heap> { Ok(()) } + fn visit_statement_assign( + &mut self, + location: Location, + Assign { lhs, rhs }: &mut Assign<'heap>, + ) -> Self::Result<()> { + { + // We must visit the rvalue BEFORE the lvalue, to not pollute the namespace. + self.visit_rvalue(location, rhs)?; + + self.visit_place( + location, + PlaceContext::Write(PlaceWriteContext::Assign), + lhs, + )?; + + Ok(()) + } + } + fn visit_local( &mut self, location: Location, @@ -685,6 +719,24 @@ impl<'heap> Visitor<'heap> for UseBeforeDef { visit::r#ref::walk_statement(self, location, statement) } + fn visit_statement_assign( + &mut self, + location: Location, + Assign { lhs, rhs }: &Assign<'heap>, + ) -> Self::Result { + // We must visit the right-hand side first to ensure that all the values are defined before + // we use them. 
+ self.visit_rvalue(location, rhs)?; + + self.visit_place( + location, + PlaceContext::Write(PlaceWriteContext::Assign), + lhs, + )?; + + ControlFlow::Continue(()) + } + fn visit_local(&mut self, _: Location, context: PlaceContext, local: Local) -> Self::Result { if local != self.local { return ControlFlow::Continue(()); diff --git a/libs/@local/hashql/mir/src/pass/transform/ssa_repair/tests.rs b/libs/@local/hashql/mir/src/pass/transform/ssa_repair/tests.rs index 3016c7b97a1..2a204a6545f 100644 --- a/libs/@local/hashql/mir/src/pass/transform/ssa_repair/tests.rs +++ b/libs/@local/hashql/mir/src/pass/transform/ssa_repair/tests.rs @@ -772,3 +772,38 @@ fn irreducible() { }, ); } + +#[test] +fn reassign_rodeo() { + scaffold!(heap, interner, builder); + let env = Environment::new(&heap); + + let x = builder.local("x", TypeBuilder::synthetic(&env).integer()); + let const_0 = builder.const_int(0); + + let bb0 = builder.reserve_block([]); + let bb1 = builder.reserve_block([x]); + + let x = builder.place_local(x); + + builder + .build_block(bb0) + .assign_place(x, |rv| rv.load(const_0)) + .assign_place(x, |rv| rv.load(x)) + .goto(bb1, [x.into()]); + + builder.build_block(bb1).goto(bb1, [x.into()]); + + let body = builder.finish(0, TypeBuilder::synthetic(&env).null()); + + assert_ssa_pass( + "reassign_rodeo", + body, + MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }, + ); +} diff --git a/libs/@local/hashql/mir/tests/ui/pass/ssa_repair/reassign_rodeo.snap b/libs/@local/hashql/mir/tests/ui/pass/ssa_repair/reassign_rodeo.snap new file mode 100644 index 00000000000..bdf05bb031e --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/ssa_repair/reassign_rodeo.snap @@ -0,0 +1,37 @@ +--- +source: libs/@local/hashql/mir/src/pass/transform/ssa_repair/tests.rs +expression: value +--- +fn {intrinsic#4294967040}() -> Null { + let %0: Integer + + bb0(): { + %0 = 0 + %0 = %0 + + goto -> bb1(%0) + } + + bb1(%0): { + goto -> 
bb1(%0) + } +} + +------------------------------------ + +fn {intrinsic#4294967040}() -> Null { + let %0: Integer + let %1: Integer + let %2: Integer + + bb0(): { + %1 = 0 + %2 = %1 + + goto -> bb1(%2) + } + + bb1(%0): { + goto -> bb1(%0) + } +}