diff --git a/.gitignore b/.gitignore index 3d18b83..d53c77a 100644 --- a/.gitignore +++ b/.gitignore @@ -22,6 +22,7 @@ docs/.ub_cache/ **/*.so **/*.dylib **/*.wasm +**/*.wasm.rs # Test artifacts *.profraw diff --git a/Cargo.lock b/Cargo.lock index 9963e89..823120e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -359,6 +359,7 @@ dependencies = [ "anyhow", "clap", "heck", + "herkos-runtime", "wasmparser 0.227.1", "wat", ] diff --git a/crates/herkos-runtime/src/lib.rs b/crates/herkos-runtime/src/lib.rs index 02e722b..12dc322 100644 --- a/crates/herkos-runtime/src/lib.rs +++ b/crates/herkos-runtime/src/lib.rs @@ -26,7 +26,8 @@ mod ops; pub use ops::{ i32_div_s, i32_div_u, i32_rem_s, i32_rem_u, i32_trunc_f32_s, i32_trunc_f32_u, i32_trunc_f64_s, i32_trunc_f64_u, i64_div_s, i64_div_u, i64_rem_s, i64_rem_u, i64_trunc_f32_s, i64_trunc_f32_u, - i64_trunc_f64_s, i64_trunc_f64_u, + i64_trunc_f64_s, i64_trunc_f64_u, wasm_max_f32, wasm_max_f64, wasm_min_f32, wasm_min_f64, + wasm_nearest_f32, wasm_nearest_f64, }; /// Wasm execution errors — no panics, no unwinding. diff --git a/crates/herkos-runtime/src/ops.rs b/crates/herkos-runtime/src/ops.rs index bef9388..3312abc 100644 --- a/crates/herkos-runtime/src/ops.rs +++ b/crates/herkos-runtime/src/ops.rs @@ -194,6 +194,147 @@ pub fn i64_rem_u(lhs: i64, rhs: i64) -> WasmResult { .ok_or(WasmTrap::DivisionByZero) } +// ── Wasm float min/max/nearest ──────────────────────────────────────────────── + +/// Wasm `f32.min`: propagates NaN (unlike Rust's `f32::min` which ignores it). +/// Also preserves the Wasm rule `min(-0.0, +0.0) = -0.0`. +pub fn wasm_min_f32(a: f32, b: f32) -> f32 { + if a.is_nan() || b.is_nan() { + return f32::NAN; + } + if a == 0.0 && b == 0.0 { + return if a.is_sign_negative() { a } else { b }; + } + if a <= b { + a + } else { + b + } +} + +/// Wasm `f32.max`: propagates NaN. `max(-0.0, +0.0) = +0.0`. 
+pub fn wasm_max_f32(a: f32, b: f32) -> f32 { + if a.is_nan() || b.is_nan() { + return f32::NAN; + } + if a == 0.0 && b == 0.0 { + return if a.is_sign_positive() { a } else { b }; + } + if a >= b { + a + } else { + b + } +} + +/// Wasm `f64.min`: propagates NaN. `min(-0.0, +0.0) = -0.0`. +pub fn wasm_min_f64(a: f64, b: f64) -> f64 { + if a.is_nan() || b.is_nan() { + return f64::NAN; + } + if a == 0.0 && b == 0.0 { + return if a.is_sign_negative() { a } else { b }; + } + if a <= b { + a + } else { + b + } +} + +/// Wasm `f64.max`: propagates NaN. `max(-0.0, +0.0) = +0.0`. +pub fn wasm_max_f64(a: f64, b: f64) -> f64 { + if a.is_nan() || b.is_nan() { + return f64::NAN; + } + if a == 0.0 && b == 0.0 { + return if a.is_sign_positive() { a } else { b }; + } + if a >= b { + a + } else { + b + } +} + +/// Wasm `f32.nearest` — round to nearest even (banker's rounding). +/// +/// Uses `as i32` for truncation-toward-zero (safe since we guard against values >= 2^23, +/// which have no fractional bits). Avoids `f32::round`/`f32::trunc` +/// which are not available in `no_std` without `libm`. +pub fn wasm_nearest_f32(v: f32) -> f32 { + if v.is_nan() || v.is_infinite() || v == 0.0 { + return v; + } + // Floats >= 2^23 have no fractional bits — already an integer. + const NO_FRAC: f32 = 8_388_608.0; // 2^23 + if v >= NO_FRAC || v <= -NO_FRAC { + return v; + } + let trunc_i = v as i32; // truncates toward zero; safe since |v| < 2^23 + let trunc_f = trunc_i as f32; + let frac = v - trunc_f; // in (-1.0, 1.0), same sign as v + if frac > 0.5 { + (trunc_i + 1) as f32 + } else if frac < -0.5 { + (trunc_i - 1) as f32 + } else if frac == 0.5 { + // Tie: round to even (trunc_i is the floor for positive v). + if trunc_i % 2 == 0 { + trunc_f + } else { + (trunc_i + 1) as f32 + } + } else if frac == -0.5 { + // Tie: round to even. copysign preserves -0.0 when trunc_i == 0. 
+ if trunc_i % 2 == 0 { + f32::copysign(trunc_f, v) + } else { + (trunc_i - 1) as f32 + } + } else { + trunc_f + } +} + +/// Wasm `f64.nearest` — round to nearest even (banker's rounding). +/// +/// Uses `as i64` for truncation-toward-zero (safe since we guard against values >= 2^52, +/// which have no fractional bits). Avoids `f64::round`/`f64::trunc` +/// which are not available in `no_std` without `libm`. +pub fn wasm_nearest_f64(v: f64) -> f64 { + if v.is_nan() || v.is_infinite() || v == 0.0 { + return v; + } + // Floats >= 2^52 have no fractional bits — already an integer. + const NO_FRAC: f64 = 4_503_599_627_370_496.0; // 2^52 + if v >= NO_FRAC || v <= -NO_FRAC { + return v; + } + let trunc_i = v as i64; // truncates toward zero; safe since |v| < 2^52 + let trunc_f = trunc_i as f64; + let frac = v - trunc_f; + if frac > 0.5 { + (trunc_i + 1) as f64 + } else if frac < -0.5 { + (trunc_i - 1) as f64 + } else if frac == 0.5 { + if trunc_i % 2 == 0 { + trunc_f + } else { + (trunc_i + 1) as f64 + } + } else if frac == -0.5 { + if trunc_i % 2 == 0 { + f64::copysign(trunc_f, v) + } else { + (trunc_i - 1) as f64 + } + } else { + trunc_f + } +} + // ── Tests ───────────────────────────────────────────────────────────────────── #[cfg(test)] diff --git a/crates/herkos/Cargo.toml b/crates/herkos/Cargo.toml index e3a4c30..c32d69d 100644 --- a/crates/herkos/Cargo.toml +++ b/crates/herkos/Cargo.toml @@ -16,6 +16,7 @@ wasmparser = { workspace = true } anyhow = { workspace = true } clap = { workspace = true } heck = { workspace = true } +herkos-runtime = { path = "../herkos-runtime", version = "0.1.1" } [dev-dependencies] wat = { workspace = true } diff --git a/crates/herkos/src/backend/mod.rs b/crates/herkos/src/backend/mod.rs index 267f660..eb8f2c5 100644 --- a/crates/herkos/src/backend/mod.rs +++ b/crates/herkos/src/backend/mod.rs @@ -106,6 +106,19 @@ pub trait Backend { if_false_idx: usize, ) -> String; + /// Emit Rust code for a conditional branch with an inlined 
comparison. + /// + /// Instead of `if condition != 0`, emits the comparison directly: + /// `if lhs >= rhs { ... } else { ... }`. + fn emit_branch_cmp_to_index( + &self, + op: BinOp, + lhs: VarId, + rhs: VarId, + if_true_idx: usize, + if_false_idx: usize, + ) -> String; + /// Emit Rust code for multi-way branch (br_table) using block indices. fn emit_branch_table_to_index( &self, diff --git a/crates/herkos/src/backend/safe.rs b/crates/herkos/src/backend/safe.rs index 1951d3c..217e08a 100644 --- a/crates/herkos/src/backend/safe.rs +++ b/crates/herkos/src/backend/safe.rs @@ -564,6 +564,21 @@ impl Backend for SafeBackend { ) } + fn emit_branch_cmp_to_index( + &self, + op: BinOp, + lhs: VarId, + rhs: VarId, + if_true_idx: usize, + if_false_idx: usize, + ) -> String { + let cmp_expr = format_cmp_expr(op, lhs, rhs); + format!( + " if {cmp_expr} {{\n __current_block = Block::B{};\n }} else {{\n __current_block = Block::B{};\n }}\n continue;", + if_true_idx, if_false_idx + ) + } + fn emit_branch_table_to_index( &self, index: VarId, @@ -597,3 +612,42 @@ impl Backend for SafeBackend { code } } + +/// Format a comparison expression for use in branch conditions. 
+fn format_cmp_expr(op: BinOp, lhs: VarId, rhs: VarId) -> String { + match op { + // Signed i32 comparisons + BinOp::I32Eq => format!("{lhs} == {rhs}"), + BinOp::I32Ne => format!("{lhs} != {rhs}"), + BinOp::I32LtS => format!("{lhs} < {rhs}"), + BinOp::I32GtS => format!("{lhs} > {rhs}"), + BinOp::I32LeS => format!("{lhs} <= {rhs}"), + BinOp::I32GeS => format!("{lhs} >= {rhs}"), + // Unsigned i32 comparisons + BinOp::I32LtU => format!("({lhs} as u32) < ({rhs} as u32)"), + BinOp::I32GtU => format!("({lhs} as u32) > ({rhs} as u32)"), + BinOp::I32LeU => format!("({lhs} as u32) <= ({rhs} as u32)"), + BinOp::I32GeU => format!("({lhs} as u32) >= ({rhs} as u32)"), + // Signed i64 comparisons + BinOp::I64Eq => format!("{lhs} == {rhs}"), + BinOp::I64Ne => format!("{lhs} != {rhs}"), + BinOp::I64LtS => format!("{lhs} < {rhs}"), + BinOp::I64GtS => format!("{lhs} > {rhs}"), + BinOp::I64LeS => format!("{lhs} <= {rhs}"), + BinOp::I64GeS => format!("{lhs} >= {rhs}"), + // Unsigned i64 comparisons + BinOp::I64LtU => format!("({lhs} as u64) < ({rhs} as u64)"), + BinOp::I64GtU => format!("({lhs} as u64) > ({rhs} as u64)"), + BinOp::I64LeU => format!("({lhs} as u64) <= ({rhs} as u64)"), + BinOp::I64GeU => format!("({lhs} as u64) >= ({rhs} as u64)"), + // Float comparisons + BinOp::F32Eq | BinOp::F64Eq => format!("{lhs} == {rhs}"), + BinOp::F32Ne | BinOp::F64Ne => format!("{lhs} != {rhs}"), + BinOp::F32Lt | BinOp::F64Lt => format!("{lhs} < {rhs}"), + BinOp::F32Gt | BinOp::F64Gt => format!("{lhs} > {rhs}"), + BinOp::F32Le | BinOp::F64Le => format!("{lhs} <= {rhs}"), + BinOp::F32Ge | BinOp::F64Ge => format!("{lhs} >= {rhs}"), + // Non-comparison ops should never reach here + _ => format!("{lhs} != 0 /* unexpected non-comparison op */"), + } +} diff --git a/crates/herkos/src/codegen/function.rs b/crates/herkos/src/codegen/function.rs index 409a2ba..859f00a 100644 --- a/crates/herkos/src/codegen/function.rs +++ b/crates/herkos/src/codegen/function.rs @@ -172,10 +172,27 @@ pub fn 
generate_function_with_info( output.push_str(" loop {\n"); output.push_str(" match __current_block {\n"); + // Build global use counts for branch condition inlining. + let global_uses = crate::optimizer::utils::build_global_use_count(ir_func); + for (idx, block) in ir_func.blocks.iter().enumerate() { output.push_str(&format!(" Block::B{} => {{\n", idx)); + // Detect if the BranchIf condition is a single-use comparison BinOp + // defined in this block — if so, skip emitting it and inline into branch. + let inlined_cmp = detect_inlined_cmp(block, &global_uses); + let skip_var = inlined_cmp.as_ref().map(|_| match &block.terminator { + IrTerminator::BranchIf { condition, .. } => *condition, + _ => unreachable!(), + }); + for instr in &block.instructions { + // Skip the inlined comparison instruction + if let Some(skip) = skip_var { + if crate::optimizer::utils::instr_dest(instr) == Some(skip) { + continue; + } + } let code = crate::codegen::instruction::generate_instruction_with_info(backend, instr, info)?; output.push_str(&code); @@ -187,6 +204,7 @@ pub fn generate_function_with_info( &block.terminator, &block_id_to_index, ir_func.return_type, + inlined_cmp.as_ref(), ); output.push_str(&term_code); output.push('\n'); @@ -290,6 +308,48 @@ fn generate_signature_with_info( sig } +/// Detect whether a block's `BranchIf` condition is defined by a single-use +/// comparison `BinOp` instruction within the same block. If so, return the +/// comparison info for inlining into the branch. +fn detect_inlined_cmp( + block: &IrBlock, + global_uses: &std::collections::HashMap, +) -> Option { + let condition = match &block.terminator { + IrTerminator::BranchIf { condition, .. } => *condition, + _ => return None, + }; + + // Condition must have exactly one use (the BranchIf itself). + if global_uses.get(&condition).copied().unwrap_or(0) != 1 { + return None; + } + + // Find the defining instruction in this block. 
+ for (i, instr) in block.instructions.iter().enumerate() { + if let IrInstr::BinOp { dest, op, lhs, rhs } = instr { + if *dest == condition && op.is_comparison() { + // Safety: verify no instruction after this one redefines + // lhs or rhs. If they do, inlining at branch-time would see + // the wrong operand values. + let operands_stable = !block.instructions[i + 1..].iter().any(|later| { + let d = crate::optimizer::utils::instr_dest(later); + d == Some(*lhs) || d == Some(*rhs) + }); + if operands_stable { + return Some(crate::codegen::instruction::InlinedCmp { + op: *op, + lhs: *lhs, + rhs: *rhs, + }); + } + } + } + } + + None +} + /// Check if an IR function has any import calls. fn has_import_calls(ir_func: &IrFunction) -> bool { ir_func.blocks.iter().any(|block| { diff --git a/crates/herkos/src/codegen/instruction.rs b/crates/herkos/src/codegen/instruction.rs index 5ae955f..9430b34 100644 --- a/crates/herkos/src/codegen/instruction.rs +++ b/crates/herkos/src/codegen/instruction.rs @@ -103,16 +103,36 @@ pub fn generate_instruction_with_info( val2, condition, } => backend.emit_select(*dest, *val1, *val2, *condition), + + // Phi nodes must be lowered to Assign instructions by the lower_phis pass + // before codegen runs. Reaching this arm is a compiler bug. + IrInstr::Phi { .. } => { + unreachable!( + "IrInstr::Phi must be lowered before codegen (lower_phis pass missed this block)" + ) + } }; Ok(code) } +/// An inlined comparison for a `BranchIf` terminator. +/// +/// When the `BranchIf` condition is defined by a single-use comparison BinOp, +/// the codegen skips emitting the comparison instruction and passes this info +/// to the terminator so the backend can emit `if lhs >= rhs { ... }` directly. +pub struct InlinedCmp { + pub op: BinOp, + pub lhs: VarId, + pub rhs: VarId, +} + /// Generate code for a terminator with BlockId to index mapping. 
pub fn generate_terminator_with_mapping( backend: &B, term: &IrTerminator, block_id_to_index: &HashMap, func_return_type: Option, + inlined_cmp: Option<&InlinedCmp>, ) -> String { match term { IrTerminator::Return { value } => { @@ -137,7 +157,11 @@ pub fn generate_terminator_with_mapping( } => { let true_idx = block_id_to_index[if_true]; let false_idx = block_id_to_index[if_false]; - backend.emit_branch_if_to_index(*condition, true_idx, false_idx) + if let Some(cmp) = inlined_cmp { + backend.emit_branch_cmp_to_index(cmp.op, cmp.lhs, cmp.rhs, true_idx, false_idx) + } else { + backend.emit_branch_if_to_index(*condition, true_idx, false_idx) + } } IrTerminator::BranchTable { diff --git a/crates/herkos/src/codegen/mod.rs b/crates/herkos/src/codegen/mod.rs index c618d96..21feebd 100644 --- a/crates/herkos/src/codegen/mod.rs +++ b/crates/herkos/src/codegen/mod.rs @@ -165,7 +165,7 @@ impl<'a, B: Backend> CodeGenerator<'a, B> { /// Generate a complete Rust module from IR with full module info. /// /// This is the main entry point. It generates a module wrapper structure. 
- pub fn generate_module_with_info(&self, info: &ModuleInfo) -> Result { + pub fn generate_module_with_info(&self, info: &LoweredModuleInfo) -> Result { module::generate_module_with_info(self.backend, info) } } @@ -500,7 +500,8 @@ mod tests { let backend = SafeBackend::new(); let codegen = CodeGenerator::new(&backend); - let code = codegen.generate_module_with_info(&info).unwrap(); + let lowered = crate::ir::lower_phis::lower(info); + let code = codegen.generate_module_with_info(&lowered).unwrap(); println!("Generated wrapper code:\n{}", code); @@ -566,7 +567,8 @@ mod tests { let backend = SafeBackend::new(); let codegen = CodeGenerator::new(&backend); - let code = codegen.generate_module_with_info(&info).unwrap(); + let lowered = crate::ir::lower_phis::lower(info); + let code = codegen.generate_module_with_info(&lowered).unwrap(); println!("Generated wrapper code:\n{}", code); @@ -633,7 +635,8 @@ mod tests { let backend = SafeBackend::new(); let codegen = CodeGenerator::new(&backend); - let code = codegen.generate_module_with_info(&info).unwrap(); + let lowered = crate::ir::lower_phis::lower(info); + let code = codegen.generate_module_with_info(&lowered).unwrap(); println!("Generated code with immutable global:\n{}", code); diff --git a/crates/herkos/src/codegen/module.rs b/crates/herkos/src/codegen/module.rs index c6d763c..6e1b25f 100644 --- a/crates/herkos/src/codegen/module.rs +++ b/crates/herkos/src/codegen/module.rs @@ -15,7 +15,10 @@ use anyhow::{Context, Result}; /// Generate a complete Rust module from IR functions with full module info. /// /// This is the main entry point. It generates a module wrapper structure. 
-pub fn generate_module_with_info(backend: &B, info: &ModuleInfo) -> Result { +pub fn generate_module_with_info( + backend: &B, + info: &LoweredModuleInfo, +) -> Result { generate_wrapper_module(backend, info) } diff --git a/crates/herkos/src/ir/builder/core.rs b/crates/herkos/src/ir/builder/core.rs index a74f51f..29b4183 100644 --- a/crates/herkos/src/ir/builder/core.rs +++ b/crates/herkos/src/ir/builder/core.rs @@ -221,7 +221,43 @@ pub(super) struct ControlFrame { /// ``` /// /// If result_var is None, the structure produces no value. - pub(super) result_var: Option, + pub(super) result_var: Option, + + /// Snapshot of `local_vars` at the time this frame was pushed. + /// + /// Uses: + /// - **Else frames**: `Operator::Else` restores `local_vars` to this snapshot so + /// the else branch starts with the same local state as the then branch. + /// - **If frames (no else)**: The implicit else path uses these as phi sources. + /// - **Loop frames**: Pre-loop local values; used as the entry predecessor for loop phis. + pub(super) locals_at_entry: Vec, + + /// Forward branches that target this frame's `end_block`. + /// Each entry is `(predecessor_block, local_vars_snapshot_at_branch_time)`. + /// + /// Populated by `Br`/`BrIf`/`BrTable` instructions that resolve to this frame's end. + /// Not used for Loop frames (backward branches go to `IrBuilder::phi_patches` instead). + pub(super) branch_incoming: Vec<(BlockId, Vec)>, + + /// For Else frames only: info about the then-branch fall-through. + /// + /// Set when `Operator::Else` is processed and the If frame is converted to Else. + /// Contains `(then_end_block, local_vars_at_then_end)`. + /// Used to compute phis at the `end_block` join point. + pub(super) then_pred_info: Option<(BlockId, Vec)>, + + /// For Loop frames only: pre-allocated phi VarIds (one per Wasm local). + /// + /// At push time, `local_vars[i]` is updated to `loop_phi_vars[i]` for all i. 
+ /// This ensures all code inside the loop already references the phi vars. + /// Phi sources are filled in at `End` of the loop from `IrBuilder::phi_patches`. + pub(super) loop_phi_vars: Vec, + + /// For Loop frames only: the block immediately before the loop header. + /// + /// This block terminates with a `Jump` to `start_block` (the loop header). + /// Used as the entry predecessor when emitting loop phi nodes. + pub(super) pre_loop_block: Option, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -269,15 +305,15 @@ pub struct IrBuilder { pub(super) next_block_id: u32, /// Wasm value stack (now SSA variables instead of actual values) - pub(super) value_stack: Vec, + pub(super) value_stack: Vec, /// Control flow stack for nested blocks/loops/if pub(super) control_stack: Vec, - /// Mapping from Wasm local index → VarId. + /// Mapping from Wasm local index → UseVar. /// Populated at the start of each `translate_function` call. /// Indices 0..param_count-1 are parameters; param_count.. are declared locals. - pub(super) local_vars: Vec, + pub(super) local_vars: Vec, /// Callee function signatures: (param_count, return_type) per function index. /// Set at the start of each `translate_function` call. @@ -294,6 +330,23 @@ pub struct IrBuilder { /// Function import details: (module_name, func_name) for each imported function. /// Indexed by import_idx (0..num_imported_functions-1). pub(super) func_imports: Vec<(String, String)>, + + /// True when the current insertion point is unreachable code. + /// + /// Set to `true` by `Br`, `BrTable`, `Return`, and `Unreachable` instructions. + /// Cleared by `start_real_block()`. + /// + /// When `dead_code` is true, emitted instructions are discarded and branches + /// are NOT recorded as phi predecessors. + pub(super) dead_code: bool, + + /// Deferred loop phi sources from backward branches. + /// + /// Each entry is `(phi_dest, pred_block, src_var)`. 
When a `Br` targets a loop + /// header (backward edge), we record the current values of all loop phi vars here + /// instead of directly mutating the phi instructions. Consumed at `End` of each + /// Loop frame by `emit_loop_phis()`. + pub(super) phi_patches: Vec<(UseVar, BlockId, UseVar)>, } impl IrBuilder { @@ -315,14 +368,47 @@ impl IrBuilder { type_signatures: Vec::new(), num_imported_functions: 0, func_imports: Vec::new(), + dead_code: false, + phi_patches: Vec::new(), } } - /// Allocate a new SSA variable. - pub(super) fn new_var(&mut self) -> VarId { + /// Allocate a new SSA variable definition token. + /// + /// Returns a [`DefVar`] that must be consumed by exactly one call to + /// [`emit_def`] or [`emit_phi_def`]. The compiler will reject any attempt + /// to emit the same variable twice. + pub(super) fn new_var(&mut self) -> DefVar { let id = VarId(self.next_var_id); self.next_var_id += 1; - id + DefVar(id) + } + + /// Emit an instruction that produces a value. + /// + /// Consumes `dest` (enforcing single-definition) and returns a [`UseVar`] + /// that can be read any number of times. The closure receives the raw + /// [`VarId`] to embed in the [`IrInstr`]. + pub(super) fn emit_def(&mut self, dest: DefVar, f: impl FnOnce(VarId) -> IrInstr) -> UseVar { + let id = dest.into_var_id(); + self.emit_void(f(id)); + UseVar(id) + } + + /// Allocate a variable for phi pre-allocation or function parameters. + /// + /// Returns both the raw [`VarId`] (for use in [`IrInstr`] dest fields that are + /// assembled and prepended to non-current blocks) and a [`UseVar`] for later + /// reading. 
Unlike [`new_var`]+[`emit_def`], this does **not** enforce + /// single-definition at compile time — use it only for: + /// - Function entry parameters/locals (implicitly defined by the call) + /// - `push_control` result_var (phi convergence slot assigned by each branch) + /// - Loop phi pre-allocation (emitted later by `emit_loop_phis`) + /// - `insert_phis_at_join` (phi dests inserted into non-current blocks) + pub(super) fn new_pre_alloc_var(&mut self) -> (VarId, UseVar) { + let id = VarId(self.next_var_id); + self.next_var_id += 1; + (id, UseVar(id)) } /// Allocate a new basic block. @@ -332,8 +418,8 @@ impl IrBuilder { id } - /// Emit an instruction to the current block. - pub(super) fn emit(&mut self, instr: IrInstr) { + /// Emit an instruction (with no result, or whose result is already embedded) to the current block. + pub(super) fn emit_void(&mut self, instr: IrInstr) { if let Some(block) = self.blocks.iter_mut().find(|b| b.id == self.current_block) { block.instructions.push(instr); } else { @@ -370,22 +456,26 @@ impl IrBuilder { self.next_block_id = 0; self.current_block = BlockId(0); self.local_vars.clear(); + self.dead_code = false; + self.phi_patches.clear(); self.func_signatures = module_ctx.func_signatures.clone(); self.type_signatures = module_ctx.type_signatures.clone(); self.num_imported_functions = module_ctx.num_imported_functions; self.func_imports = module_ctx.func_imports.clone(); // Allocate VarIds for all locals (params first, then declared locals). - // This ensures local_index maps directly to the correct VarId. - let mut local_index_to_var: Vec = Vec::new(); + // This ensures local_index maps directly to the correct UseVar. + // Parameters and zero-initialized locals are implicitly defined at function entry + // (not via emit_def), so we use new_pre_alloc_var to get both VarId and UseVar. 
+ let mut local_index_to_var: Vec = Vec::new(); // Allocate variables for parameters let param_vars: Vec<(VarId, WasmType)> = params .iter() .map(|(_, ty)| { - let var = self.new_var(); - local_index_to_var.push(var); - (var, *ty) + let (var_id, use_var) = self.new_pre_alloc_var(); + local_index_to_var.push(use_var); + (var_id, *ty) }) .collect(); @@ -393,9 +483,9 @@ impl IrBuilder { let mut func_locals: Vec<(VarId, WasmType)> = Vec::new(); for vt in locals { let ty = WasmType::from_wasmparser(*vt); - let var = self.new_var(); - local_index_to_var.push(var); - func_locals.push((var, ty)); + let (var_id, use_var) = self.new_pre_alloc_var(); + local_index_to_var.push(use_var); + func_locals.push((var_id, ty)); } self.local_vars = local_index_to_var; @@ -434,6 +524,10 @@ impl IrBuilder { } /// Push a control frame onto the control stack. + /// + /// For Loop frames, pre-allocates phi VarIds for all locals and immediately updates + /// `self.local_vars[i]` to the phi vars. This ensures that all code inside the loop + /// body reads/writes through the phi vars, making backward-branch phi sources correct. pub(super) fn push_control( &mut self, kind: ControlKind, @@ -442,13 +536,32 @@ impl IrBuilder { else_block: Option, result_type: Option, ) { - // Allocate a result variable if block has result type + // Allocate a result variable if block has result type. + // result_var is a phi convergence slot assigned by multiple branches — + // use new_pre_alloc_var to get UseVar directly. let result_var = if result_type.is_some() { - Some(self.new_var()) + let (_, use_var) = self.new_pre_alloc_var(); + Some(use_var) } else { None }; + // Snapshot local state at frame entry (before any phi-var substitution). + let locals_at_entry = self.local_vars.clone(); + + // For Loop frames: pre-allocate phi vars for all locals and immediately substitute. 
+ let (loop_phi_vars, pre_loop_block) = if kind == ControlKind::Loop { + let phi_vars: Vec = (0..self.local_vars.len()) + .map(|_| self.new_pre_alloc_var().1) + .collect(); + // Update local_vars so code inside the loop uses phi vars. + self.local_vars.clone_from(&phi_vars); + let pre_loop = self.current_block; + (phi_vars, Some(pre_loop)) + } else { + (Vec::new(), None) + }; + self.control_stack.push(ControlFrame { kind, start_block, @@ -456,6 +569,11 @@ impl IrBuilder { else_block, result_type, result_var, + locals_at_entry, + branch_incoming: Vec::new(), + then_pred_info: None, + loop_phi_vars, + pre_loop_block, }); } @@ -504,6 +622,164 @@ impl IrBuilder { terminator: IrTerminator::Unreachable, }); } + + /// Start a new reachable block: creates the block and clears `dead_code`. + /// + /// Use this instead of `start_block` whenever the new block is a real join point + /// reachable from live code (e.g., after If/Else/End, or after BrIf fallthrough). + pub(super) fn start_real_block(&mut self, block_id: BlockId) { + self.dead_code = false; + self.start_block(block_id); + } + + /// Record a forward branch to a non-loop frame. + /// + /// Saves `(current_block, local_vars_snapshot)` in the target frame's `branch_incoming`. + /// No-op if `dead_code` is set (unreachable branches are not phi predecessors). + /// + /// `frame_idx` is the index into `self.control_stack`. + pub(super) fn record_forward_branch(&mut self, frame_idx: usize) { + if self.dead_code { + return; + } + let pred_block = self.current_block; + let locals_snap = self.local_vars.clone(); + self.control_stack[frame_idx] + .branch_incoming + .push((pred_block, locals_snap)); + } + + /// Record a backward branch to a loop frame (adds to `phi_patches`). + /// + /// For each loop phi var, records `(phi_var, current_block, current_local_value)`. + /// No-op if `dead_code` is set. + /// + /// `frame_idx` is the index into `self.control_stack` for the Loop frame. 
+ pub(super) fn record_loop_back_branch(&mut self, frame_idx: usize) { + if self.dead_code { + return; + } + let pred_block = self.current_block; + // Clone to avoid borrow conflict (local_vars is also in self) + let phi_vars = self.control_stack[frame_idx].loop_phi_vars.clone(); + for (local_idx, &phi_var) in phi_vars.iter().enumerate() { + let src_var = self.local_vars[local_idx]; + self.phi_patches.push((phi_var, pred_block, src_var)); + } + } + + /// Insert SSA phi nodes at a join block for locals with differing predecessor values. + /// + /// For each local index, if any predecessor provides a different VarId for that local, + /// a `IrInstr::Phi` node is inserted at the beginning of `join_block` and + /// `self.local_vars[i]` is updated to the phi dest. + /// + /// Phis are inserted in local-index order at the very start of the block's instruction + /// list, before any instructions already in the block. + pub(super) fn insert_phis_at_join( + &mut self, + join_block: BlockId, + predecessors: &[(BlockId, Vec)], + ) { + let num_locals = self.local_vars.len(); + if predecessors.is_empty() || num_locals == 0 { + return; + } + + // Collect phis to insert: allocate dest vars before touching self.blocks. + let mut phi_instrs: Vec = Vec::new(); + let mut new_locals = self.local_vars.clone(); + + for local_idx in 0..num_locals { + let first_var = predecessors[0].1[local_idx]; + let all_same = predecessors + .iter() + .all(|(_, locals)| locals[local_idx] == first_var); + if !all_same { + // Use new_pre_alloc_var because the phi dest is inserted into a + // non-current block — we can't go through emit_def here. 
+ let (dest_id, dest_use) = self.new_pre_alloc_var(); + let srcs: Vec<(BlockId, VarId)> = predecessors + .iter() + .map(|(bid, locals)| (*bid, locals[local_idx].var_id())) + .collect(); + new_locals[local_idx] = dest_use; + phi_instrs.push(IrInstr::Phi { + dest: dest_id, + srcs, + }); + } else { + // All predecessors agree on this local — no phi needed, but we still + // update local_vars to the canonical value. This ensures correctness + // when arriving from dead code (where local_vars may be stale). + new_locals[local_idx] = first_var; + } + } + + self.local_vars = new_locals; + + if !phi_instrs.is_empty() { + if let Some(block) = self.blocks.iter_mut().find(|b| b.id == join_block) { + let old = std::mem::take(&mut block.instructions); + block.instructions = phi_instrs; + block.instructions.extend(old); + } + } + } + + /// Emit phi instructions for a loop frame into its header block. + /// + /// Called at `End` of a Loop frame (after `pop_control`). Inserts `IrInstr::Phi` + /// at the start of `frame.start_block` for each local. Sources come from: + /// 1. The pre-loop predecessor (`frame.pre_loop_block`, `frame.locals_at_entry`). + /// 2. All backward branches recorded in `self.phi_patches` for this loop's phi vars. + /// + /// Consumes the relevant entries from `self.phi_patches`. + /// Trivial phis (all sources are the same var, or the only non-self source) are left + /// for the `lower_phis` pass to eliminate. 
+ pub(super) fn emit_loop_phis(&mut self, frame: &ControlFrame) { + debug_assert_eq!(frame.kind, ControlKind::Loop); + let num_locals = frame.loop_phi_vars.len(); + if num_locals == 0 { + return; + } + + let mut phi_srcs: Vec> = vec![Vec::new(); num_locals]; + + // Entry from before the loop + if let Some(pre_block) = frame.pre_loop_block { + for (local_idx, phi_src) in phi_srcs.iter_mut().enumerate() { + phi_src.push((pre_block, frame.locals_at_entry[local_idx].var_id())); + } + } + + // Backward branch sources from phi_patches + for &(phi_dest, pred_block, src_var) in &self.phi_patches { + if let Some(local_idx) = frame.loop_phi_vars.iter().position(|v| *v == phi_dest) { + phi_srcs[local_idx].push((pred_block, src_var.var_id())); + } + } + + // Consume the processed patches + self.phi_patches + .retain(|&(phi_dest, _, _)| !frame.loop_phi_vars.contains(&phi_dest)); + + // Build phi instructions and prepend to loop header + let mut phi_instrs: Vec = Vec::new(); + for (local_idx, &phi_var) in frame.loop_phi_vars.iter().enumerate() { + let srcs = std::mem::take(&mut phi_srcs[local_idx]); + phi_instrs.push(IrInstr::Phi { + dest: phi_var.var_id(), + srcs, + }); + } + + if let Some(block) = self.blocks.iter_mut().find(|b| b.id == frame.start_block) { + let old = std::mem::take(&mut block.instructions); + block.instructions = phi_instrs; + block.instructions.extend(old); + } + } } impl Default for IrBuilder { diff --git a/crates/herkos/src/ir/builder/translate.rs b/crates/herkos/src/ir/builder/translate.rs index de96acf..bc968df 100644 --- a/crates/herkos/src/ir/builder/translate.rs +++ b/crates/herkos/src/ir/builder/translate.rs @@ -14,39 +14,31 @@ impl IrBuilder { match op { // Constants Operator::I32Const { value } => { - let dest = self.new_var(); - self.emit(IrInstr::Const { - dest, - value: IrValue::I32(*value), - }); - self.value_stack.push(dest); + let v = IrValue::I32(*value); + let _def = self.new_var(); + let use_v = self.emit_def(_def, |d| IrInstr::Const { 
dest: d, value: v }); + self.value_stack.push(use_v); } Operator::I64Const { value } => { - let dest = self.new_var(); - self.emit(IrInstr::Const { - dest, - value: IrValue::I64(*value), - }); - self.value_stack.push(dest); + let v = IrValue::I64(*value); + let _def = self.new_var(); + let use_v = self.emit_def(_def, |d| IrInstr::Const { dest: d, value: v }); + self.value_stack.push(use_v); } Operator::F32Const { value } => { - let dest = self.new_var(); - self.emit(IrInstr::Const { - dest, - value: IrValue::F32(f32::from_bits(value.bits())), - }); - self.value_stack.push(dest); + let v = IrValue::F32(f32::from_bits(value.bits())); + let _def = self.new_var(); + let use_v = self.emit_def(_def, |d| IrInstr::Const { dest: d, value: v }); + self.value_stack.push(use_v); } Operator::F64Const { value } => { - let dest = self.new_var(); - self.emit(IrInstr::Const { - dest, - value: IrValue::F64(f64::from_bits(value.bits())), - }); - self.value_stack.push(dest); + let v = IrValue::F64(f64::from_bits(value.bits())); + let _def = self.new_var(); + let use_v = self.emit_def(_def, |d| IrInstr::Const { dest: d, value: v }); + self.value_stack.push(use_v); } // Local variable access @@ -58,59 +50,74 @@ impl IrBuilder { .ok_or_else(|| { anyhow::anyhow!("local.get: local index {} out of range", local_index) })?; - // Emit a copy rather than pushing the local's VarId directly. - // If we push the local's VarId, a later local.tee/local.set that + // Emit a copy rather than pushing the local's UseVar directly. + // If we push the local's UseVar, a later local.tee/local.set that // overwrites the same local will corrupt any already-pushed reference // to it, because the backend emits sequential mutable assignments. // A fresh variable captures the value at this point in time. 
- let dest = self.new_var(); - self.emit(IrInstr::Assign { dest, src }); - self.value_stack.push(dest); + let _def = self.new_var(); + let use_v = self.emit_def(_def, |d| IrInstr::Assign { + dest: d, + src: src.var_id(), + }); + self.value_stack.push(use_v); } Operator::LocalSet { local_index } => { + let idx = *local_index as usize; // Pop value and assign to local let value = self .value_stack .pop() .ok_or_else(|| anyhow::anyhow!("Stack underflow for local.set"))?; - let dest = self - .local_vars - .get(*local_index as usize) - .copied() - .ok_or_else(|| { - anyhow::anyhow!("local.set: local index {} out of range", local_index) - })?; - self.emit(IrInstr::Assign { dest, src: value }); + if idx >= self.local_vars.len() { + bail!("local.set: local index {} out of range", local_index); + } + + // Allocate a fresh dest to satisfy SSA single-definition rule. + let _def = self.new_var(); + let use_v = self.emit_def(_def, |d| IrInstr::Assign { + dest: d, + src: value.var_id(), + }); + // Update the local mapping so subsequent reads see the new value. + self.local_vars[idx] = use_v; } Operator::LocalTee { local_index } => { + let idx = *local_index as usize; // Like LocalSet but keeps value on stack let value = self .value_stack .last() + .copied() .ok_or_else(|| anyhow::anyhow!("Stack underflow for local.tee"))?; - let dest = self - .local_vars - .get(*local_index as usize) - .copied() - .ok_or_else(|| { - anyhow::anyhow!("local.tee: local index {} out of range", local_index) - })?; - self.emit(IrInstr::Assign { dest, src: *value }); - // Value stays on stack (we already have it via .last()) + if idx >= self.local_vars.len() { + bail!("local.tee: local index {} out of range", local_index); + } + + // Allocate a fresh dest to satisfy SSA single-definition rule. + let _def = self.new_var(); + let use_v = self.emit_def(_def, |d| IrInstr::Assign { + dest: d, + src: value.var_id(), + }); + // Update the local mapping so subsequent reads see the new value. 
+ self.local_vars[idx] = use_v; + // Value stays on stack (already there via .last()) } // Global variable access Operator::GlobalGet { global_index } => { - let dest = self.new_var(); - self.emit(IrInstr::GlobalGet { - dest, - index: GlobalIdx::new(*global_index as usize), + let idx = GlobalIdx::new(*global_index as usize); + let _def = self.new_var(); + let use_v = self.emit_def(_def, |d| IrInstr::GlobalGet { + dest: d, + index: idx, }); - self.value_stack.push(dest); + self.value_stack.push(use_v); } Operator::GlobalSet { global_index } => { @@ -118,9 +125,9 @@ impl IrBuilder { .value_stack .pop() .ok_or_else(|| anyhow::anyhow!("Stack underflow for global.set"))?; - self.emit(IrInstr::GlobalSet { + self.emit_void(IrInstr::GlobalSet { index: GlobalIdx::new(*global_index as usize), - value, + value: value.var_id(), }); } @@ -281,6 +288,7 @@ impl IrBuilder { // the already-terminated block's control flow. let dead_block = self.new_block(); self.start_block(dead_block); + self.dead_code = true; } // End (end of function or block) @@ -289,110 +297,200 @@ impl IrBuilder { // End of function - treat as implicit return self.emit_return()?; } else { - // End of block/loop/if/else let frame = self.pop_control()?; - // Check if this is an If frame - if frame.kind == super::core::ControlKind::If { - if let Some(else_block) = frame.else_block { - // === STEP 1: Finalize the THEN branch === - // At this point, we've finished executing all instructions in the if's then block. - // If the if has a result type (e.g., "if i32 ... end"), any result value - // is now on top of the value_stack, and we need to assign it to result_var - // so it can be passed to the join point. + match frame.kind { + super::core::ControlKind::If => { + // === IF without ELSE === + // The then-branch just finished. We create an implicit empty else + // block and insert phi nodes at the join point. 
+ + let else_block = + frame.else_block.expect("If frame must have else_block"); + + // Collect predecessors of end_block for phi computation. + // 1. Fall-through from then-body (if reachable) + // 2. Any `br` inside then-body that targeted end_block + // 3. Implicit else block (with pre-if locals = locals_at_entry) + let mut preds: Vec<(BlockId, Vec)> = Vec::new(); + if !self.dead_code { + preds.push((self.current_block, self.local_vars.clone())); + } + preds.extend(frame.branch_incoming.iter().cloned()); - // Then branch: assign result if needed (only if value is on stack) + // Assign result if needed (then-branch fall-through) if let Some(result_var) = frame.result_var { if let Some(stack_value) = self.value_stack.pop() { - self.emit(IrInstr::Assign { - dest: result_var, - src: stack_value, + self.emit_void(IrInstr::Assign { + dest: result_var.var_id(), + src: stack_value.var_id(), }); } - // If stack is empty, then branch ended with br/return (unreachable after) } - // === STEP 2: Terminate the THEN branch with a forward jump === - // Jump to the end block (the join point after the if-else). - // This merges both the then and else branches back together. - self.terminate(IrTerminator::Jump { - target: frame.end_block, - }); + // Terminate then-branch (if reachable) + if !self.dead_code { + self.terminate(IrTerminator::Jump { + target: frame.end_block, + }); + } - // === STEP 3: Create the ELSE block === - // Even if the source WebAssembly had NO explicit "else" clause, - // we MUST create one in the IR because: - // - WebAssembly's `if` always has two branches (true/false) - // - The IR needs an explicit control flow graph with both paths - // - An implicit else (no code written) becomes an empty else block - // - // This is a fundamental design choice: the IR makes ALL control flow explicit. + // Create implicit else block (empty; just jumps to end). 
+ // The else_block is always reachable (false-branch of BranchIf), + // but we don't need to clear dead_code for it — it has no user + // instructions; we just terminate it directly. self.start_block(else_block); - - // === STEP 4: Else block body (empty in this case) === - // Since the source Wasm had no explicit "else" clause, the else block - // has no instructions. It just falls through to the join point. - // We represent this as a single jump to the end block. self.terminate(IrTerminator::Jump { target: frame.end_block, }); + // The implicit else carries the pre-if local state. + preds.push((else_block, frame.locals_at_entry.clone())); + + // Restore local_vars to pre-if state before computing phis + // (preds already captured the necessary snapshots above). + self.local_vars = frame.locals_at_entry.clone(); + + // Start the join block (always reachable — else_block always jumps here) + self.start_real_block(frame.end_block); + + // Insert phi nodes for locals with differing predecessor values. + // If no live predecessors, mark as dead code. + if preds.is_empty() { + self.dead_code = true; + } else { + self.insert_phis_at_join(frame.end_block, &preds); + } + } - // === STEP 5: Continue in the END block (join point) === - // After both then and else branches have jumped here, - // future instructions execute in this end block. - self.start_block(frame.end_block); - } else { - // Should not happen - If always has else_block - bail!("If frame missing else_block"); + super::core::ControlKind::Else => { + // === IF-ELSE END === + // The else-branch just finished. Insert phis at the join point + // using then-pred and else-pred info saved during Operator::Else. + + // Collect predecessors of end_block: + // 1. then-branch fall-through (saved as then_pred_info in Else frame) + // 2. else-branch fall-through (current block, if reachable) + // 3. 
Any `br` from either branch targeting end_block (branch_incoming) + let mut preds: Vec<(BlockId, Vec)> = Vec::new(); + if let Some((then_block, then_locals)) = frame.then_pred_info.clone() { + preds.push((then_block, then_locals)); + } + if !self.dead_code { + preds.push((self.current_block, self.local_vars.clone())); + } + preds.extend(frame.branch_incoming.iter().cloned()); + + // Assign result if needed (else-branch fall-through) + if let Some(result_var) = frame.result_var { + if let Some(stack_value) = self.value_stack.pop() { + self.emit_void(IrInstr::Assign { + dest: result_var.var_id(), + src: stack_value.var_id(), + }); + } + } + + // Terminate else-branch (if reachable) + if !self.dead_code { + self.terminate(IrTerminator::Jump { + target: frame.end_block, + }); + } + + // Start join block + self.start_real_block(frame.end_block); + + if preds.is_empty() { + self.dead_code = true; + } else { + self.insert_phis_at_join(frame.end_block, &preds); + } } - } else { - // === This handles Block, Loop, and Else constructs (NOT If) === - // These are simpler than If: they have no branching, just linear control flow. - - // === STEP 1: Capture the block's result value (if any) === - // If this block/loop/else has a result type (e.g., "block i32 ... end"), - // the result value should be on top of the value_stack when we exit. - // We assign it to result_var so it can be used at the join point. - if let Some(result_var) = frame.result_var { - if let Some(stack_value) = self.value_stack.pop() { - // Normal case: block fell through with a result value - self.emit(IrInstr::Assign { - dest: result_var, - src: stack_value, + + super::core::ControlKind::Loop => { + // === LOOP END === + // Emit the phi instructions into the loop header (start_block). + // Then fall through to end_block. + + // Collect fall-through predecessor BEFORE switching blocks. 
+ let mut preds: Vec<(BlockId, Vec)> = + frame.branch_incoming.clone(); + if !self.dead_code { + preds.push((self.current_block, self.local_vars.clone())); + } + + // Assign result if needed (loop fall-through) + if let Some(result_var) = frame.result_var { + if let Some(stack_value) = self.value_stack.pop() { + self.emit_void(IrInstr::Assign { + dest: result_var.var_id(), + src: stack_value.var_id(), + }); + } + } + + // Terminate loop body fall-through (if reachable) + if !self.dead_code { + self.terminate(IrTerminator::Jump { + target: frame.end_block, }); } - // WHY IS EMPTY STACK NOT AN ERROR? - // ──────────────────────────────── - // If stack is empty here, it means this block ended with a branch - // (br/br_if/br_table) or return instruction. These terminators: - // 1. Consume the value from the stack before jumping/returning - // 2. Jump away, making all subsequent code unreachable - // - // So even though result_var exists, it won't be used (the code - // after this block is unreachable via the normal path). - // This is NOT an error—it's valid control flow to have dead code - // after a terminating instruction. - // - // Example: - // block i32 - // i32.const 5 - // br 0 ◄─── Consumes the 5, jumps, stack becomes empty - // end ◄─── Stack is empty here, but that's OK + + // Emit Phi instructions into the loop header block. + // This consumes the relevant phi_patches. + self.emit_loop_phis(&frame); + + // Start the loop's exit block + self.start_real_block(frame.end_block); + + // Insert phis at end_block for locals with differing exit values + if preds.is_empty() { + self.dead_code = true; + } else { + self.insert_phis_at_join(frame.end_block, &preds); + } } - // === STEP 2: Terminate this block with a forward jump === - // Jump to the end block (the join point after this control structure). - // This is the normal exit path (only reached if no br/return interrupted us). 
- self.terminate(IrTerminator::Jump { - target: frame.end_block, - }); + super::core::ControlKind::Block => { + // === BLOCK END === + // Collect fall-through predecessor BEFORE switching blocks. + let mut preds: Vec<(BlockId, Vec)> = + frame.branch_incoming.clone(); + if !self.dead_code { + preds.push((self.current_block, self.local_vars.clone())); + } + + // Assign result if needed (block fall-through) + if let Some(result_var) = frame.result_var { + if let Some(stack_value) = self.value_stack.pop() { + // Normal case: block fell through with a result value + self.emit_void(IrInstr::Assign { + dest: result_var.var_id(), + src: stack_value.var_id(), + }); + } + // Empty stack: block ended with br/return (dead code after) + } + + // Terminate block fall-through (if reachable) + if !self.dead_code { + self.terminate(IrTerminator::Jump { + target: frame.end_block, + }); + } - // === STEP 3: Continue in the END block (join point) === - // All paths (normal fall-through, branches into this block) meet here. - self.start_block(frame.end_block); + // Start join block + self.start_real_block(frame.end_block); + + if preds.is_empty() { + self.dead_code = true; + } else { + self.insert_phis_at_join(frame.end_block, &preds); + } + } } - // If block has result type, push result var onto stack + // If the control structure produced a result, push it onto the value stack. if let Some(result_var) = frame.result_var { self.value_stack.push(result_var); } @@ -545,9 +643,9 @@ impl IrBuilder { // === Memory size and grow === Operator::MemorySize { mem: 0, .. } => { - let dest = self.new_var(); - self.emit(IrInstr::MemorySize { dest }); - self.value_stack.push(dest); + let _def = self.new_var(); + let use_v = self.emit_def(_def, |d| IrInstr::MemorySize { dest: d }); + self.value_stack.push(use_v); } Operator::MemoryGrow { mem: 0, .. 
} => { @@ -555,9 +653,12 @@ impl IrBuilder { .value_stack .pop() .ok_or_else(|| anyhow::anyhow!("Stack underflow for memory.grow"))?; - let dest = self.new_var(); - self.emit(IrInstr::MemoryGrow { dest, delta }); - self.value_stack.push(dest); + let _def = self.new_var(); + let use_v = self.emit_def(_def, |d| IrInstr::MemoryGrow { + dest: d, + delta: delta.var_id(), + }); + self.value_stack.push(use_v); } // Control flow @@ -632,18 +733,14 @@ impl IrBuilder { target: loop_header, }); - // === STEP 2: Begin codegen in the loop header block === - // This block is the entry point to the loop and the target of backward - // branches (via "br" inside the loop body). - self.start_block(loop_header); - - // === STEP 3: Push control frame === - // Record this loop's control structure: - // - kind=Loop: marks this as a loop (for br/br_if dispatch) - // - start_block=loop_header: where backward br's jump - // - end_block=end_block: where forward br's jump (exit) - // - else_block=None: loops don't have else branches (that's just If) - // - result_type: the loop's result type (if any) + // === STEP 2: Push control frame BEFORE switching blocks === + // push_control captures self.current_block as pre_loop_block (the block + // that jumps into the loop). It must be called while current_block still + // points to the pre-loop block, before start_block changes it. + // + // push_control also pre-allocates phi vars for all locals and updates + // self.local_vars to point to them, so all code inside the loop body + // reads/writes through the phi vars from the start. self.push_control( super::core::ControlKind::Loop, loop_header, @@ -651,6 +748,11 @@ impl IrBuilder { None, result_type, ); + + // === STEP 3: Begin codegen in the loop header block === + // This block is the entry point to the loop and the target of backward + // branches (via "br" inside the loop body). 
+ self.start_block(loop_header); } Operator::If { blockty } => { @@ -669,7 +771,8 @@ impl IrBuilder { let condition = self .value_stack .pop() - .ok_or_else(|| anyhow::anyhow!("Stack underflow for if condition"))?; + .ok_or_else(|| anyhow::anyhow!("Stack underflow for if condition"))? + .var_id(); // === STEP 2: Pre-allocate all three blocks === // We create all blocks upfront so we can reference them in the BranchIf. @@ -692,8 +795,9 @@ impl IrBuilder { }); // === STEP 4: Start building the THEN branch === - // Activate the then_block so subsequent instructions are emitted there. - self.start_block(then_block); + // Activate the then_block. This is always reachable (the BranchIf true path), + // so use start_real_block to clear dead_code. + self.start_real_block(then_block); // === STEP 5: Push If control frame === // Record this if's control structure for later resolution: @@ -725,31 +829,44 @@ impl IrBuilder { bail!("else without matching if"); } + // Save then-branch fall-through info (predecessor for end_block phis). + // Only recorded if the then-branch body is reachable (not dead code). 
+ let then_pred_info = if !self.dead_code { + Some((self.current_block, self.local_vars.clone())) + } else { + None + }; + // Then branch: assign result if needed let result_var = if_frame.result_var; if let Some(result_var) = result_var { - let stack_value = self.value_stack.pop().ok_or_else(|| { - anyhow::anyhow!("Stack underflow for then result in else") - })?; - self.emit(IrInstr::Assign { - dest: result_var, - src: stack_value, + if let Some(stack_value) = self.value_stack.pop() { + self.emit_void(IrInstr::Assign { + dest: result_var.var_id(), + src: stack_value.var_id(), + }); + } + } + + // Current block (then branch) jumps to end (if reachable) + if !self.dead_code { + self.terminate(IrTerminator::Jump { + target: if_frame.end_block, }); } - // Current block (then branch) jumps to end - self.terminate(IrTerminator::Jump { - target: if_frame.end_block, - }); + // Restore local_vars to pre-if state so the else branch starts + // with the same local variable bindings as the then branch. + self.local_vars = if_frame.locals_at_entry.clone(); - // Use the pre-allocated else block + // Use the pre-allocated else block (always reachable: false branch of BranchIf) let else_block = if_frame .else_block .expect("If frame should have else_block"); - self.start_block(else_block); + self.start_real_block(else_block); - // Push else frame (same end block, same result_var, no else_block needed) - // We manually create the frame to preserve result_var + // Push Else frame, preserving result_var and transferring branch_incoming + // (any `br 0` inside the then-branch already recorded in if_frame). 
self.control_stack.push(super::core::ControlFrame { kind: super::core::ControlKind::Else, start_block: else_block, @@ -757,27 +874,79 @@ impl IrBuilder { else_block: None, result_type: if_frame.result_type, result_var, + locals_at_entry: if_frame.locals_at_entry, + branch_incoming: if_frame.branch_incoming, // transfer then-body br's + then_pred_info, + loop_phi_vars: Vec::new(), + pre_loop_block: None, }); } Operator::Br { relative_depth } => { - let target = self.get_branch_target(*relative_depth)?; + let depth = *relative_depth as usize; + let frame_idx = + self.control_stack + .len() + .checked_sub(depth + 1) + .ok_or_else(|| { + anyhow::anyhow!("br: depth {} exceeds control stack", relative_depth) + })?; + + let (target, is_loop) = { + let frame = &self.control_stack[frame_idx]; + match frame.kind { + super::core::ControlKind::Loop => (frame.start_block, true), + _ => (frame.end_block, false), + } + }; + + // Record this branch as a phi predecessor (before terminate) + if is_loop { + self.record_loop_back_branch(frame_idx); + } else { + self.record_forward_branch(frame_idx); + } + self.terminate(IrTerminator::Jump { target }); - // Create unreachable continuation block - let unreachable_block = self.new_block(); - self.start_block(unreachable_block); + // Everything after an unconditional branch is unreachable + let dead_block = self.new_block(); + self.start_block(dead_block); + self.dead_code = true; } Operator::BrIf { relative_depth } => { let condition = self .value_stack .pop() - .ok_or_else(|| anyhow::anyhow!("Stack underflow for br_if"))?; + .ok_or_else(|| anyhow::anyhow!("Stack underflow for br_if"))? 
+ .var_id(); + + let depth = *relative_depth as usize; + let frame_idx = + self.control_stack + .len() + .checked_sub(depth + 1) + .ok_or_else(|| { + anyhow::anyhow!("br_if: depth {} exceeds control stack", relative_depth) + })?; - let target = self.get_branch_target(*relative_depth)?; + let (target, is_loop) = { + let frame = &self.control_stack[frame_idx]; + match frame.kind { + super::core::ControlKind::Loop => (frame.start_block, true), + _ => (frame.end_block, false), + } + }; + + // Record the taken branch as a phi predecessor + if is_loop { + self.record_loop_back_branch(frame_idx); + } else { + self.record_forward_branch(frame_idx); + } - // Create continuation block (fallthrough) + // Create continuation block (fallthrough — always reachable) let continue_block = self.new_block(); self.terminate(IrTerminator::BranchIf { @@ -786,24 +955,49 @@ impl IrBuilder { if_false: continue_block, }); - // Continue building in continuation block - self.start_block(continue_block); + // Fallthrough is always reachable; clear dead_code + self.start_real_block(continue_block); } Operator::BrTable { targets } => { let index = self .value_stack .pop() - .ok_or_else(|| anyhow::anyhow!("Stack underflow for br_table"))?; + .ok_or_else(|| anyhow::anyhow!("Stack underflow for br_table"))? + .var_id(); - // Convert targets to BlockIds + // Collect all depths (table entries + default), deduplicate by frame_idx + // to avoid recording the same predecessor block twice for the same target. 
                 let target_depths: Vec<u32> = targets.targets().collect::<Result<Vec<_>, _>>()?;
+                let default_depth = targets.default();
+
+                let stack_len = self.control_stack.len();
+                let mut recorded: std::collections::HashSet<usize> =
+                    std::collections::HashSet::new();
+                for depth in target_depths
+                    .iter()
+                    .copied()
+                    .chain(std::iter::once(default_depth))
+                {
+                    let depth = depth as usize;
+                    let frame_idx = stack_len.saturating_sub(depth + 1);
+                    if recorded.insert(frame_idx) {
+                        let is_loop =
+                            self.control_stack[frame_idx].kind == super::core::ControlKind::Loop;
+                        if is_loop {
+                            self.record_loop_back_branch(frame_idx);
+                        } else {
+                            self.record_forward_branch(frame_idx);
+                        }
+                    }
+                }
+
                 let target_blocks: Vec<BlockId> = target_depths
                     .iter()
                     .map(|depth| self.get_branch_target(*depth))
                     .collect::<Result<Vec<_>>>()?;
-                let default = self.get_branch_target(targets.default())?;
+                let default = self.get_branch_target(default_depth)?;
 
                 self.terminate(IrTerminator::BranchTable {
                     index,
@@ -811,9 +1005,10 @@ impl IrBuilder {
                     default,
                 });
 
-                // Create unreachable continuation block
-                let unreachable_block = self.new_block();
-                self.start_block(unreachable_block);
+                // Everything after br_table is unreachable
+                let dead_block = self.new_block();
+                self.start_block(dead_block);
+                self.dead_code = true;
             }
 
             Operator::Call { function_index } => {
@@ -826,7 +1021,14 @@ impl IrBuilder {
                 let args =
                     self.pop_call_args(param_count, &format!("call to func_{}", func_idx))?;
 
-                let dest = callee_return_type.map(|_| self.new_var());
+                // For optional-result calls we use new_pre_alloc_var: the dest is
+                // defined by the call instruction itself, not via emit_def.
+ let (dest_id, dest_use) = if callee_return_type.is_some() { + let (id, u) = self.new_pre_alloc_var(); + (Some(id), Some(u)) + } else { + (None, None) + }; // Check if this is a call to an imported function or a local function if func_idx < self.num_imported_functions { @@ -837,8 +1039,8 @@ impl IrBuilder { anyhow::anyhow!("Call: import index {} out of range", import_idx) })?; - self.emit(IrInstr::CallImport { - dest, + self.emit_void(IrInstr::CallImport { + dest: dest_id, import_idx: ImportIdx::new(import_idx), module_name, func_name, @@ -847,15 +1049,15 @@ impl IrBuilder { } else { // Call to local function - convert to local index let local_func_idx = func_idx - self.num_imported_functions; - self.emit(IrInstr::Call { - dest, + self.emit_void(IrInstr::Call { + dest: dest_id, func_idx: LocalFuncIdx::new(local_func_idx), args, }); } - if let Some(d) = dest { - self.value_stack.push(d); + if let Some(u) = dest_use { + self.value_stack.push(u); } } @@ -873,9 +1075,13 @@ impl IrBuilder { })?; // Pop table element index (on top of stack) - let table_idx_var = self.value_stack.pop().ok_or_else(|| { - anyhow::anyhow!("Stack underflow for call_indirect table index") - })?; + let table_idx_var = self + .value_stack + .pop() + .ok_or_else(|| { + anyhow::anyhow!("Stack underflow for call_indirect table index") + })? 
+ .var_id(); // Pop arguments let args = self.pop_call_args( @@ -883,16 +1089,21 @@ impl IrBuilder { &format!("call_indirect type {}", type_idx_usize), )?; - let dest = callee_return_type.map(|_| self.new_var()); - self.emit(IrInstr::CallIndirect { - dest, + let (dest_id, dest_use) = if callee_return_type.is_some() { + let (id, u) = self.new_pre_alloc_var(); + (Some(id), Some(u)) + } else { + (None, None) + }; + self.emit_void(IrInstr::CallIndirect { + dest: dest_id, type_idx: TypeIdx::new(*type_index as usize), table_idx: table_idx_var, args, }); - if let Some(d) = dest { - self.value_stack.push(d); + if let Some(u) = dest_use { + self.value_stack.push(u); } } @@ -901,6 +1112,7 @@ impl IrBuilder { // Create unreachable continuation block (dead code follows) let unreachable_block = self.new_block(); self.start_block(unreachable_block); + self.dead_code = true; } Operator::Select => { @@ -919,14 +1131,14 @@ impl IrBuilder { .value_stack .pop() .ok_or_else(|| anyhow::anyhow!("stack underflow in Select (val1)"))?; - let dest = self.new_var(); - self.emit(IrInstr::Select { - dest, - val1, - val2, - condition, + let _def = self.new_var(); + let use_v = self.emit_def(_def, |d| IrInstr::Select { + dest: d, + val1: val1.var_id(), + val2: val2.var_id(), + condition: condition.var_id(), }); - self.value_stack.push(dest); + self.value_stack.push(use_v); } // === Bulk memory operations === @@ -947,7 +1159,11 @@ impl IrBuilder { .value_stack .pop() .ok_or_else(|| anyhow::anyhow!("Stack underflow for memory.copy (dst)"))?; - self.emit(IrInstr::MemoryCopy { dst, src, len }); + self.emit_void(IrInstr::MemoryCopy { + dst: dst.var_id(), + src: src.var_id(), + len: len.var_id(), + }); } _ => bail!("Unsupported operator: {:?}", op), @@ -964,7 +1180,8 @@ impl IrBuilder { Some( self.value_stack .pop() - .ok_or_else(|| anyhow::anyhow!("stack underflow in return"))?, + .ok_or_else(|| anyhow::anyhow!("stack underflow in return"))? 
+ .var_id(), ) }; self.terminate(IrTerminator::Return { value }); @@ -986,9 +1203,13 @@ impl IrBuilder { .pop() .ok_or_else(|| anyhow::anyhow!("stack underflow in binop (lhs)"))?; let dest = self.new_var(); - - self.emit(IrInstr::BinOp { dest, op, lhs, rhs }); - self.value_stack.push(dest); + let use_v = self.emit_def(dest, |v| IrInstr::BinOp { + dest: v, + op, + lhs: lhs.var_id(), + rhs: rhs.var_id(), + }); + self.value_stack.push(use_v); Ok(()) } @@ -1004,9 +1225,12 @@ impl IrBuilder { .pop() .ok_or_else(|| anyhow::anyhow!("stack underflow in unop (operand)"))?; let dest = self.new_var(); - - self.emit(IrInstr::UnOp { dest, op, operand }); - self.value_stack.push(dest); + let use_v = self.emit_def(dest, |v| IrInstr::UnOp { + dest: v, + op, + operand: operand.var_id(), + }); + self.value_stack.push(use_v); Ok(()) } @@ -1035,17 +1259,15 @@ impl IrBuilder { .pop() .ok_or_else(|| anyhow::anyhow!("stack underflow in load (addr)"))?; let dest = self.new_var(); - - self.emit(IrInstr::Load { - dest, + let use_v = self.emit_def(dest, |v| IrInstr::Load { + dest: v, ty, - addr, + addr: addr.var_id(), offset: offset as u32, width, sign, }); - - self.value_stack.push(dest); + self.value_stack.push(use_v); Ok(()) } @@ -1063,9 +1285,14 @@ impl IrBuilder { } let mut args = Vec::with_capacity(param_count); for _ in 0..param_count { - args.push(self.value_stack.pop().ok_or_else(|| { - anyhow::anyhow!("stack underflow collecting {} arguments", context) - })?); + args.push( + self.value_stack + .pop() + .ok_or_else(|| { + anyhow::anyhow!("stack underflow collecting {} arguments", context) + })? 
+ .var_id(), + ); } args.reverse(); Ok(args) @@ -1092,10 +1319,10 @@ impl IrBuilder { .pop() .ok_or_else(|| anyhow::anyhow!("stack underflow in store (addr)"))?; - self.emit(IrInstr::Store { + self.emit_void(IrInstr::Store { ty, - addr, - value, + addr: addr.var_id(), + value: value.var_id(), offset: offset as u32, width, }); diff --git a/crates/herkos/src/ir/lower_phis.rs b/crates/herkos/src/ir/lower_phis.rs new file mode 100644 index 0000000..ee81849 --- /dev/null +++ b/crates/herkos/src/ir/lower_phis.rs @@ -0,0 +1,425 @@ +//! SSA phi-node lowering. +//! +//! ## Overview +//! +//! This pass converts `IrInstr::Phi` nodes — which are SSA-form join-point +//! selectors — into ordinary `IrInstr::Assign` instructions placed at the end +//! of predecessor blocks (just before their terminators). After this pass, no +//! `IrInstr::Phi` instructions remain in the IR. +//! +//! This pass **must run before all optimizer passes** and **before codegen**. +//! It is a phase transition (SSA destruction), not an optimization. +//! +//! ## Algorithm +//! +//! For each function: +//! +//! 1. **Prune stale phi sources**: dead_blocks may have removed predecessor +//! blocks. Remove any `(pred_block, _)` entry in a Phi's `srcs` whose +//! `pred_block` no longer exists in `func.blocks`. +//! +//! 2. **Simplify trivial phis**: A phi is trivial if: +//! - It has no sources → dead (replace with self-assign, will be removed by +//! dead_instrs if unused). +//! - Ignoring self-references (`src == dest`), all remaining sources resolve +//! to the same single variable. +//! +//! In both cases, replace `Phi { dest, srcs }` with `Assign { dest, src }`. +//! Repeat until no more trivial phis can be simplified. +//! +//! 3. **Lower non-trivial phis**: For each remaining `Phi { dest, srcs }`: +//! - In each predecessor block, insert `Assign { dest, src }` just before +//! the block's terminator. +//! - Remove the `Phi` instruction from the join block. +//! +//! ## Example +//! +//! 
Given an `if/else` that produces a value, the SSA IR contains a phi at the +//! merge block: +//! +//! ```text +//! block0 (entry): +//! v0 = i32.const 1 +//! v1 = i32.const 2 +//! br_if v_cond → block1, block2 +//! +//! block1 (then): +//! br → block3 +//! +//! block2 (else): +//! br → block3 +//! +//! block3 (merge): +//! v2 = phi [(block1, v0), (block2, v1)] ← SSA phi +//! ...use v2... +//! ``` +//! +//! After lowering, the phi is removed and each predecessor gets an assignment: +//! +//! ```text +//! block1 (then): +//! v2 = v0 ← inserted before terminator +//! br → block3 +//! +//! block2 (else): +//! v2 = v1 ← inserted before terminator +//! br → block3 +//! +//! block3 (merge): +//! ...use v2... ← phi gone; v2 already set +//! ``` +//! +//! ## Why predecessor assignments? +//! +//! A phi at a join point conceptually says "take the value from whichever +//! predecessor we came from". In the generated Rust state machine, the predecessor +//! block's code runs immediately before transitioning to the join block, so +//! assigning there is equivalent to selecting based on the taken path. + +use crate::ir::{BlockId, IrFunction, IrInstr, ModuleInfo, VarId}; +use std::collections::HashSet; + +/// Lower all `IrInstr::Phi` nodes in `module_info`, returning a [`super::LoweredModuleInfo`]. +/// +/// After this call, no `IrInstr::Phi` instructions remain in any function. +/// The returned [`super::LoweredModuleInfo`] can be passed to the optimizer and codegen. +pub fn lower(module_info: ModuleInfo) -> super::LoweredModuleInfo { + let mut module_info = module_info; + for func in &mut module_info.ir_functions { + lower_func(func); + debug_assert!( + func.blocks + .iter() + .flat_map(|b| &b.instructions) + .all(|i| !matches!(i, IrInstr::Phi { .. })), + "lower_phis: phi instructions remain after lowering — bug in lower_func" + ); + } + super::LoweredModuleInfo(module_info) +} + +/// Lower all `IrInstr::Phi` nodes in a single function. 
+fn lower_func(func: &mut IrFunction) {
+    // Collect the set of live block IDs present in the IR right now.
+    // Note: lower_phis runs before the optimizer, so no optimizer dead-block
+    // elimination has happened yet. This pruning handles blocks that were never
+    // emitted (e.g. code after `unreachable` in the IR builder) — their block
+    // IDs may still appear as phi sources even though the blocks were dropped.
+    let live_blocks: HashSet<BlockId> = func.blocks.iter().map(|b| b.id).collect();
+
+    // Step 1: Prune phi sources that refer to removed (dead) predecessor blocks.
+    for block in &mut func.blocks {
+        for instr in &mut block.instructions {
+            if let IrInstr::Phi { srcs, .. } = instr {
+                srcs.retain(|(pred_id, _)| live_blocks.contains(pred_id));
+            }
+        }
+    }
+
+    // Step 2: Simplify trivial phis to Assign, iterating to fixpoint.
+    //
+    // A phi is trivial when — after removing self-references — all remaining
+    // sources point to the same VarId. We iterate because simplifying one phi
+    // may allow another that referenced it to become trivial.
+ // + // Example — loop-invariant variable (self-reference only): + // + // v1 = phi [(block0, v0), (block2, v1)] + // ^^^^^^^^^^^^^^^^^^ pre-loop value + // ^^^^^^^^^^^ back-edge (self-ref, ignored) + // + // → all non-self sources resolve to v0 → trivial + // → simplified to: v1 = v0 + // + // Example — chain: simplifying v2 unlocks v3: + // + // v2 = phi [(block0, v0), (block1, v0)] ← both sources same → trivial + // v3 = phi [(block0, v2), (block1, v2)] ← after v2→v0: both v0 → trivial + // + // pass 1: v2 → Assign { dest: v2, src: v0 } + // pass 2: v3 → Assign { dest: v3, src: v0 } + loop { + let mut changed = false; + for block in &mut func.blocks { + for instr in &mut block.instructions { + if let IrInstr::Phi { dest, srcs } = instr { + let phi_dest = *dest; + + // Collect unique non-self sources + let non_self: Vec = srcs + .iter() + .map(|(_, v)| *v) + .filter(|&v| v != phi_dest) + .collect::>() + .into_iter() + .collect(); + + let trivial_src = if non_self.is_empty() { + // All sources are self-references — phi is dead. + // Replace with an assign from itself (a no-op) so the + // dead_instrs pass can remove it if unused. + Some(phi_dest) + } else if non_self.len() == 1 { + // All non-self sources agree on one value. + Some(non_self[0]) + } else { + None + }; + + if let Some(src) = trivial_src { + *instr = IrInstr::Assign { + dest: phi_dest, + src, + }; + changed = true; + } + } + } + } + + if !changed { + break; + } + + // After simplifying a phi, propagate the change: replace uses of the + // former phi's dest with its now-known single source in the same block. + // (Full global propagation is handled by copy_prop; we just do the local + // fixpoint simplification of other phis in the same block here.) + for block in &mut func.blocks { + // Collect Assign(dest, src) replacements from this iteration's simplifications. 
+ let replacements: Vec<(VarId, VarId)> = block + .instructions + .iter() + .filter_map(|i| { + if let IrInstr::Assign { dest, src } = i { + // Only propagate trivial-phi-turned-assigns where src != dest + if *dest != *src { + Some((*dest, *src)) + } else { + None + } + } else { + None + } + }) + .collect(); + + for (old, new) in replacements { + for instr in &mut block.instructions { + replace_phi_src(instr, old, new); + } + } + } + } + + // Step 3: Lower non-trivial phis to predecessor-block assignments. + // + // We collect all phi nodes first, then mutate blocks. + // `phi_assignments` maps `(dest, pred_block) -> src`. + // `phi_locations` maps `block_id -> [phi_dest]` so we know which phis to remove. + let mut phi_assignments: Vec<(BlockId, IrInstr)> = Vec::new(); + let mut phi_block_dests: Vec<(BlockId, VarId)> = Vec::new(); + + for block in &func.blocks { + for instr in &block.instructions { + if let IrInstr::Phi { dest, srcs } = instr { + phi_block_dests.push((block.id, *dest)); + for (pred_block, src) in srcs { + phi_assignments.push(( + *pred_block, + IrInstr::Assign { + dest: *dest, + src: *src, + }, + )); + } + } + } + } + + // Insert assignments into predecessor blocks (before the terminator). + for (pred_id, assign) in phi_assignments { + if let Some(block) = func.blocks.iter_mut().find(|b| b.id == pred_id) { + // Insert just before the terminator (i.e., at the end of the instruction list). + block.instructions.push(assign); + } + } + + // Remove phi instructions from their join blocks. + let phi_dests: HashSet = phi_block_dests.iter().map(|(_, d)| *d).collect(); + for block in &mut func.blocks { + block + .instructions + .retain(|i| !matches!(i, IrInstr::Phi { dest, .. } if phi_dests.contains(dest))); + } +} + +/// Replace every read-occurrence of `old` with `new` in `instr`. +/// Used during trivial-phi simplification to propagate the resolved source. 
fn replace_phi_src(instr: &mut IrInstr, old: VarId, new: VarId) {
    // Substitute a single VarId slot in place if it matches `old`.
    let sub = |v: &mut VarId| {
        if *v == old {
            *v = new;
        }
    };
    match instr {
        IrInstr::Phi { srcs, .. } => {
            for (_, src) in srcs {
                sub(src);
            }
        }
        IrInstr::Assign { src, .. } => sub(src),
        // Other instruction kinds don't appear before phi lowering completes
        // within the trivial-phi loop, so only phi/assign need handling here.
        _ => {}
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BlockId, IrBlock, IrFunction, IrTerminator, IrValue, TypeIdx, VarId};

    /// Build a minimal one-function [`ModuleInfo`] around `blocks` for phi-lowering tests.
    fn make_module(blocks: Vec<IrBlock>) -> ModuleInfo {
        ModuleInfo {
            has_memory: false,
            has_memory_import: false,
            max_pages: 0,
            initial_pages: 0,
            table_initial: 0,
            table_max: 0,
            element_segments: Vec::new(),
            globals: Vec::new(),
            data_segments: Vec::new(),
            func_exports: Vec::new(),
            type_signatures: Vec::new(),
            canonical_type: Vec::new(),
            func_imports: Vec::new(),
            imported_globals: Vec::new(),
            ir_functions: vec![IrFunction {
                params: vec![],
                locals: vec![],
                blocks,
                entry_block: BlockId(0),
                return_type: None,
                type_idx: TypeIdx::new(0),
                needs_host: false,
            }],
            wasm_version: 1,
        }
    }

    /// A trivial phi where all non-self sources agree: phi(v0, v0) → Assign(dest, v0).
+ #[test] + fn test_trivial_phi_same_source() { + // block_0: v0 = const 1; jump block_1 + // block_1: v1 = phi((block_0, v0), (block_0, v0)); return v1 + let block0 = IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }; + let block1 = IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Phi { + dest: VarId(1), + srcs: vec![(BlockId(0), VarId(0)), (BlockId(0), VarId(0))], + }], + terminator: IrTerminator::Return { + value: Some(VarId(1)), + }, + }; + + let module = make_module(vec![block0, block1]); + let lowered = lower(module); + + // phi should be simplified to Assign + let instrs = &lowered.ir_functions[0].blocks[1].instructions; + assert_eq!(instrs.len(), 1); + assert!(matches!( + instrs[0], + IrInstr::Assign { + dest: VarId(1), + src: VarId(0) + } + )); + } + + /// A non-trivial phi with two different sources gets lowered to predecessor assignments. + #[test] + fn test_non_trivial_phi_lowering() { + // block_0: v0=1; br_if block_1 else block_2 + // block_1: v1=2; jump block_3 + // block_2: v2=3; jump block_3 + // block_3: v3=phi((block_1,v1),(block_2,v2)); return v3 + let block0 = IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(0), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(0), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }; + let block1 = IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(2), + }], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }; + let block2 = IrBlock { + id: BlockId(2), + instructions: vec![IrInstr::Const { + dest: VarId(2), + value: IrValue::I32(3), + }], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }; + let block3 = IrBlock { + id: BlockId(3), + instructions: vec![IrInstr::Phi { + dest: VarId(3), + srcs: vec![(BlockId(1), VarId(1)), 
(BlockId(2), VarId(2))], + }], + terminator: IrTerminator::Return { + value: Some(VarId(3)), + }, + }; + + let module = make_module(vec![block0, block1, block2, block3]); + let lowered = lower(module); + let func = &lowered.ir_functions[0]; + + // No phi in block3 after lowering + assert!(!func.blocks[3] + .instructions + .iter() + .any(|i| matches!(i, IrInstr::Phi { .. }))); + + // block1 should have an Assign v3 = v1 appended + assert!(func.blocks[1].instructions.iter().any(|i| matches!( + i, + IrInstr::Assign { + dest: VarId(3), + src: VarId(1) + } + ))); + + // block2 should have an Assign v3 = v2 appended + assert!(func.blocks[2].instructions.iter().any(|i| matches!( + i, + IrInstr::Assign { + dest: VarId(3), + src: VarId(2) + } + ))); + } +} diff --git a/crates/herkos/src/ir/mod.rs b/crates/herkos/src/ir/mod.rs index e70694f..694c185 100644 --- a/crates/herkos/src/ir/mod.rs +++ b/crates/herkos/src/ir/mod.rs @@ -7,9 +7,33 @@ //! It includes: //! - **Per-function IR** ([`IrFunction`], [`IrBlock`], [`IrInstr`]): SSA-form IR for function bodies //! - **Module-level IR** ([`ModuleInfo`] and related types): Module structure and metadata +//! - **[`LoweredModuleInfo`]**: Post-SSA-destruction wrapper; no `IrInstr::Phi` nodes remain mod types; pub use types::*; pub mod builder; pub use builder::{build_module_info, ModuleContext}; + +pub mod lower_phis; + +/// [`ModuleInfo`] with all `IrInstr::Phi` nodes lowered to `IrInstr::Assign`. +/// +/// Constructed exclusively by [`lower_phis::lower`]. Signals the phase +/// boundary between SSA IR (with phi nodes) and post-SSA IR (without). +/// After this point, no optimizer pass or codegen module will encounter +/// `IrInstr::Phi` in any function body. 
+pub struct LoweredModuleInfo(ModuleInfo); + +impl std::ops::Deref for LoweredModuleInfo { + type Target = ModuleInfo; + fn deref(&self) -> &ModuleInfo { + &self.0 + } +} + +impl std::ops::DerefMut for LoweredModuleInfo { + fn deref_mut(&mut self) -> &mut ModuleInfo { + &mut self.0 + } +} diff --git a/crates/herkos/src/ir/types.rs b/crates/herkos/src/ir/types.rs index 02bf3ab..efc7443 100644 --- a/crates/herkos/src/ir/types.rs +++ b/crates/herkos/src/ir/types.rs @@ -12,6 +12,44 @@ use std::fmt; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct VarId(pub u32); +/// One-time-use definition token for an SSA variable. +/// +/// Returned by [`IrBuilder::new_var()`] and consumed by [`IrBuilder::emit_def()`]. +/// Non-`Copy`/non-`Clone` so that the borrow checker enforces single-definition: +/// attempting to emit the same variable twice is a compile-time error. +/// +/// ```compile_fail +/// let v = builder.new_var(); +/// let _ = builder.emit_def(v, |id| IrInstr::Const { dest: id, value: IrValue::I32(1) }); +/// let _ = builder.emit_def(v, |id| IrInstr::Const { dest: id, value: IrValue::I32(2) }); // error: use of moved value +/// ``` +#[must_use = "DefVar must be emitted exactly once via emit_def or emit_phi_def"] +#[derive(Debug)] +pub struct DefVar(pub(super) VarId); + +impl DefVar { + /// Consume this token and return the underlying [`VarId`]. + /// Used internally by emit methods to build [`IrInstr`] with the correct dest. + pub(super) fn into_var_id(self) -> VarId { + self.0 + } +} + +/// Multi-use read token for an already-defined SSA variable. +/// +/// Obtained from [`IrBuilder::emit_def()`] or [`IrBuilder::emit_phi_def()`]. +/// `Copy + Clone` because a variable may be read any number of times. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct UseVar(pub(super) VarId); + +impl UseVar { + /// Return the underlying [`VarId`]. + /// Used internally when building [`IrInstr`] source fields. 
+ pub(super) fn var_id(self) -> VarId { + self.0 + } +} + /// Generic index type with a phantom tag to distinguish different index spaces. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct Idx { @@ -315,6 +353,17 @@ pub enum IrInstr { val2: VarId, condition: VarId, }, + + /// SSA phi node: at a join point, select the reaching definition based on which + /// predecessor block was taken. `srcs` maps predecessor BlockId to the VarId that + /// holds the value of the local at the end of that predecessor. + /// + /// Phi nodes are inserted during IR construction and lowered (converted to Assign + /// instructions in predecessor blocks) before codegen. + Phi { + dest: VarId, + srcs: Vec<(BlockId, VarId)>, + }, } /// Block terminator — how control flow exits a basic block. @@ -345,7 +394,7 @@ pub enum IrTerminator { } /// Constant value in the IR. -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq)] pub enum IrValue { I32(i32), I64(i64), @@ -383,7 +432,7 @@ impl fmt::Display for IrValue { } /// Binary operations. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum BinOp { // i32 operations I32Add, @@ -479,7 +528,7 @@ pub enum BinOp { } /// Unary operations. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum UnOp { // i32 unary I32Clz, // Count leading zeros @@ -547,6 +596,45 @@ pub enum UnOp { } impl BinOp { + /// Returns `true` if this operation is a comparison (produces 0 or 1). 
+ pub fn is_comparison(&self) -> bool { + matches!( + self, + BinOp::I32Eq + | BinOp::I32Ne + | BinOp::I32LtS + | BinOp::I32LtU + | BinOp::I32GtS + | BinOp::I32GtU + | BinOp::I32LeS + | BinOp::I32LeU + | BinOp::I32GeS + | BinOp::I32GeU + | BinOp::I64Eq + | BinOp::I64Ne + | BinOp::I64LtS + | BinOp::I64LtU + | BinOp::I64GtS + | BinOp::I64GtU + | BinOp::I64LeS + | BinOp::I64LeU + | BinOp::I64GeS + | BinOp::I64GeU + | BinOp::F32Eq + | BinOp::F32Ne + | BinOp::F32Lt + | BinOp::F32Gt + | BinOp::F32Le + | BinOp::F32Ge + | BinOp::F64Eq + | BinOp::F64Ne + | BinOp::F64Lt + | BinOp::F64Gt + | BinOp::F64Le + | BinOp::F64Ge + ) + } + /// Returns the WasmType of the result produced by this operation. /// /// Note: all comparison operations return i32 (0 or 1), even for i64/f32/f64 operands. diff --git a/crates/herkos/src/lib.rs b/crates/herkos/src/lib.rs index efa298a..d540971 100644 --- a/crates/herkos/src/lib.rs +++ b/crates/herkos/src/lib.rs @@ -14,7 +14,7 @@ pub use anyhow::{Context, Result}; use backend::SafeBackend; use codegen::CodeGenerator; use ir::builder::build_module_info; -use ir::ModuleInfo; +use ir::{lower_phis, LoweredModuleInfo}; use optimizer::optimize_ir; use parser::parse_wasm; @@ -65,6 +65,11 @@ pub fn transpile(wasm_bytes: &[u8], options: &TranspileOptions) -> Result Result Result { +fn generate_rust_code(module_info: &LoweredModuleInfo) -> Result { let backend = SafeBackend::new(); let codegen = CodeGenerator::new(&backend); diff --git a/crates/herkos/src/optimizer/algebraic.rs b/crates/herkos/src/optimizer/algebraic.rs new file mode 100644 index 0000000..c0aa32f --- /dev/null +++ b/crates/herkos/src/optimizer/algebraic.rs @@ -0,0 +1,782 @@ +//! Algebraic simplifications. +//! +//! Rewrites `BinOp` instructions when one operand is a known constant and an +//! identity or annihilator rule applies. Runs after `const_prop` so that +//! constant operands are already resolved. +//! +//! ## Rules +//! +//! | Pattern | Result | +//! 
|---------------------|--------------| +//! | `x + 0`, `0 + x` | `x` | +//! | `x - 0` | `x` | +//! | `x * 1`, `1 * x` | `x` | +//! | `x * 0`, `0 * x` | `0` | +//! | `x & 0` | `0` | +//! | `x & -1` | `x` | +//! | `x | 0` | `x` | +//! | `x | -1` | `-1` | +//! | `x ^ 0` | `x` | +//! | `x ^ x` | `0` | +//! | `x << 0`, `x >> 0` | `x` | +//! | `x == x` | `1` | +//! | `x != x` | `0` | + +use super::utils::instr_dest; +use crate::ir::{BinOp, IrFunction, IrInstr, IrValue, VarId}; +use std::collections::HashMap; + +// ── Public entry point ──────────────────────────────────────────────────────── + +pub fn eliminate(func: &mut IrFunction) { + let global_consts = build_global_const_map(func); + + for block in &mut func.blocks { + let mut local_consts = global_consts.clone(); + + for instr in &mut block.instructions { + // Track constants defined in this block. + if let IrInstr::Const { dest, value } = instr { + local_consts.insert(*dest, *value); + continue; + } + + let (dest, op, lhs, rhs) = match instr { + IrInstr::BinOp { dest, op, lhs, rhs } => (*dest, *op, *lhs, *rhs), + _ => continue, + }; + + // Same-operand rules (no constant needed). 
+ if lhs == rhs { + if let Some(replacement) = same_operand_rule(op, dest, lhs) { + *instr = replacement; + if let IrInstr::Const { dest, value } = instr { + local_consts.insert(*dest, *value); + } + continue; + } + } + + let lhs_val = local_consts.get(&lhs).copied(); + let rhs_val = local_consts.get(&rhs).copied(); + + if let Some(replacement) = constant_operand_rule(op, dest, lhs, rhs, lhs_val, rhs_val) { + *instr = replacement; + if let IrInstr::Const { dest, value } = instr { + local_consts.insert(*dest, *value); + } + } + } + } +} + +// ── Global constant map ─────────────────────────────────────────────────────── + +fn build_global_const_map(func: &IrFunction) -> HashMap { + let mut total_defs: HashMap = HashMap::new(); + let mut const_defs: HashMap = HashMap::new(); + + for block in &func.blocks { + for instr in &block.instructions { + if let Some(dest) = instr_dest(instr) { + *total_defs.entry(dest).or_insert(0) += 1; + if let IrInstr::Const { dest, value } = instr { + const_defs.insert(*dest, *value); + } + } + } + } + + const_defs + .into_iter() + .filter(|(v, _)| total_defs.get(v).copied().unwrap_or(0) == 1) + .collect() +} + +// ── Same-operand rules ──────────────────────────────────────────────────────── + +fn same_operand_rule(op: BinOp, dest: VarId, operand: VarId) -> Option { + match op { + // x ^ x → 0 + BinOp::I32Xor => Some(IrInstr::Const { + dest, + value: IrValue::I32(0), + }), + BinOp::I64Xor => Some(IrInstr::Const { + dest, + value: IrValue::I64(0), + }), + + // x == x → 1 (integers only; floats have NaN) + BinOp::I32Eq | BinOp::I32LeS | BinOp::I32LeU | BinOp::I32GeS | BinOp::I32GeU => { + Some(IrInstr::Const { + dest, + value: IrValue::I32(1), + }) + } + BinOp::I64Eq | BinOp::I64LeS | BinOp::I64LeU | BinOp::I64GeS | BinOp::I64GeU => { + Some(IrInstr::Const { + dest, + value: IrValue::I32(1), + }) + } + + // x != x → 0 (integers only) + BinOp::I32Ne | BinOp::I32LtS | BinOp::I32LtU | BinOp::I32GtS | BinOp::I32GtU => { + Some(IrInstr::Const { 
+ dest, + value: IrValue::I32(0), + }) + } + BinOp::I64Ne | BinOp::I64LtS | BinOp::I64LtU | BinOp::I64GtS | BinOp::I64GtU => { + Some(IrInstr::Const { + dest, + value: IrValue::I32(0), + }) + } + + // x - x → 0 (integers only; floats have Inf - Inf = NaN) + BinOp::I32Sub => Some(IrInstr::Const { + dest, + value: IrValue::I32(0), + }), + BinOp::I64Sub => Some(IrInstr::Const { + dest, + value: IrValue::I64(0), + }), + + // x & x → x, x | x → x + BinOp::I32And | BinOp::I32Or | BinOp::I64And | BinOp::I64Or => { + Some(IrInstr::Assign { dest, src: operand }) + } + + _ => None, + } +} + +// ── Constant-operand rules ──────────────────────────────────────────────────── + +fn constant_operand_rule( + op: BinOp, + dest: VarId, + lhs: VarId, + rhs: VarId, + lhs_val: Option, + rhs_val: Option, +) -> Option { + match op { + // ── Add ────────────────────────────────────────────────────────── + BinOp::I32Add => match (lhs_val, rhs_val) { + (_, Some(IrValue::I32(0))) => Some(IrInstr::Assign { dest, src: lhs }), + (Some(IrValue::I32(0)), _) => Some(IrInstr::Assign { dest, src: rhs }), + _ => None, + }, + BinOp::I64Add => match (lhs_val, rhs_val) { + (_, Some(IrValue::I64(0))) => Some(IrInstr::Assign { dest, src: lhs }), + (Some(IrValue::I64(0)), _) => Some(IrInstr::Assign { dest, src: rhs }), + _ => None, + }, + + // ── Sub ────────────────────────────────────────────────────────── + BinOp::I32Sub => match rhs_val { + Some(IrValue::I32(0)) => Some(IrInstr::Assign { dest, src: lhs }), + _ => None, + }, + BinOp::I64Sub => match rhs_val { + Some(IrValue::I64(0)) => Some(IrInstr::Assign { dest, src: lhs }), + _ => None, + }, + + // ── Mul ────────────────────────────────────────────────────────── + BinOp::I32Mul => match (lhs_val, rhs_val) { + (_, Some(IrValue::I32(1))) => Some(IrInstr::Assign { dest, src: lhs }), + (Some(IrValue::I32(1)), _) => Some(IrInstr::Assign { dest, src: rhs }), + (_, Some(IrValue::I32(0))) | (Some(IrValue::I32(0)), _) => Some(IrInstr::Const { + dest, + 
value: IrValue::I32(0), + }), + _ => None, + }, + BinOp::I64Mul => match (lhs_val, rhs_val) { + (_, Some(IrValue::I64(1))) => Some(IrInstr::Assign { dest, src: lhs }), + (Some(IrValue::I64(1)), _) => Some(IrInstr::Assign { dest, src: rhs }), + (_, Some(IrValue::I64(0))) => Some(IrInstr::Const { + dest, + value: IrValue::I64(0), + }), + (Some(IrValue::I64(0)), _) => Some(IrInstr::Const { + dest, + value: IrValue::I64(0), + }), + _ => None, + }, + + // ── And ────────────────────────────────────────────────────────── + BinOp::I32And => match (lhs_val, rhs_val) { + (_, Some(IrValue::I32(0))) | (Some(IrValue::I32(0)), _) => Some(IrInstr::Const { + dest, + value: IrValue::I32(0), + }), + (_, Some(IrValue::I32(-1))) => Some(IrInstr::Assign { dest, src: lhs }), + (Some(IrValue::I32(-1)), _) => Some(IrInstr::Assign { dest, src: rhs }), + _ => None, + }, + BinOp::I64And => match (lhs_val, rhs_val) { + (_, Some(IrValue::I64(0))) | (Some(IrValue::I64(0)), _) => Some(IrInstr::Const { + dest, + value: IrValue::I64(0), + }), + (_, Some(IrValue::I64(-1))) => Some(IrInstr::Assign { dest, src: lhs }), + (Some(IrValue::I64(-1)), _) => Some(IrInstr::Assign { dest, src: rhs }), + _ => None, + }, + + // ── Or ─────────────────────────────────────────────────────────── + BinOp::I32Or => match (lhs_val, rhs_val) { + (_, Some(IrValue::I32(0))) => Some(IrInstr::Assign { dest, src: lhs }), + (Some(IrValue::I32(0)), _) => Some(IrInstr::Assign { dest, src: rhs }), + (_, Some(IrValue::I32(-1))) | (Some(IrValue::I32(-1)), _) => Some(IrInstr::Const { + dest, + value: IrValue::I32(-1), + }), + _ => None, + }, + BinOp::I64Or => match (lhs_val, rhs_val) { + (_, Some(IrValue::I64(0))) => Some(IrInstr::Assign { dest, src: lhs }), + (Some(IrValue::I64(0)), _) => Some(IrInstr::Assign { dest, src: rhs }), + (_, Some(IrValue::I64(-1))) | (Some(IrValue::I64(-1)), _) => Some(IrInstr::Const { + dest, + value: IrValue::I64(-1), + }), + _ => None, + }, + + // ── Xor 
────────────────────────────────────────────────────────── + BinOp::I32Xor => match (lhs_val, rhs_val) { + (_, Some(IrValue::I32(0))) => Some(IrInstr::Assign { dest, src: lhs }), + (Some(IrValue::I32(0)), _) => Some(IrInstr::Assign { dest, src: rhs }), + _ => None, + }, + BinOp::I64Xor => match (lhs_val, rhs_val) { + (_, Some(IrValue::I64(0))) => Some(IrInstr::Assign { dest, src: lhs }), + (Some(IrValue::I64(0)), _) => Some(IrInstr::Assign { dest, src: rhs }), + _ => None, + }, + + // ── Shifts / Rotates ───────────────────────────────────────────── + BinOp::I32Shl | BinOp::I32ShrS | BinOp::I32ShrU | BinOp::I32Rotl | BinOp::I32Rotr => { + match rhs_val { + Some(IrValue::I32(0)) => Some(IrInstr::Assign { dest, src: lhs }), + _ => None, + } + } + BinOp::I64Shl | BinOp::I64ShrS | BinOp::I64ShrU | BinOp::I64Rotl | BinOp::I64Rotr => { + match rhs_val { + Some(IrValue::I64(0)) => Some(IrInstr::Assign { dest, src: lhs }), + _ => None, + } + } + + _ => None, + } +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{BlockId, IrBlock, IrTerminator, TypeIdx}; + + fn make_func(blocks: Vec) -> IrFunction { + IrFunction { + params: vec![], + locals: vec![], + blocks, + entry_block: BlockId(0), + return_type: None, + type_idx: TypeIdx::new(0), + needs_host: false, + } + } + + fn single_block(instrs: Vec) -> Vec { + vec![IrBlock { + id: BlockId(0), + instructions: instrs, + terminator: IrTerminator::Return { value: None }, + }] + } + + // ── Additive identity ──────────────────────────────────────────────── + + #[test] + fn add_zero_rhs() { + // v1 = 0; v2 = v0 + v1 → v2 = Assign(v0) + let mut func = make_func(single_block(vec![ + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(0), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + ])); + eliminate(&mut func); + assert!(matches!( + func.blocks[0].instructions[1], + 
IrInstr::Assign { + dest: VarId(2), + src: VarId(0) + } + )); + } + + #[test] + fn add_zero_lhs() { + // v1 = 0; v2 = v1 + v0 → v2 = Assign(v0) + let mut func = make_func(single_block(vec![ + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(0), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(1), + rhs: VarId(0), + }, + ])); + eliminate(&mut func); + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::Assign { + dest: VarId(2), + src: VarId(0) + } + )); + } + + // ── Multiplicative identity ────────────────────────────────────────── + + #[test] + fn mul_one() { + let mut func = make_func(single_block(vec![ + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(1), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Mul, + lhs: VarId(0), + rhs: VarId(1), + }, + ])); + eliminate(&mut func); + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::Assign { + dest: VarId(2), + src: VarId(0) + } + )); + } + + #[test] + fn mul_zero() { + let mut func = make_func(single_block(vec![ + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(0), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Mul, + lhs: VarId(0), + rhs: VarId(1), + }, + ])); + eliminate(&mut func); + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::Const { + dest: VarId(2), + value: IrValue::I32(0) + } + )); + } + + // ── XOR same operand ───────────────────────────────────────────────── + + #[test] + fn xor_self() { + let mut func = make_func(single_block(vec![IrInstr::BinOp { + dest: VarId(1), + op: BinOp::I32Xor, + lhs: VarId(0), + rhs: VarId(0), + }])); + eliminate(&mut func); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(0) + } + )); + } + + // ── Equality same operand ──────────────────────────────────────────── + + #[test] + fn eq_self() { + let mut func = make_func(single_block(vec![IrInstr::BinOp { + dest: VarId(1), + op: BinOp::I32Eq, + lhs: 
VarId(0), + rhs: VarId(0), + }])); + eliminate(&mut func); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(1) + } + )); + } + + #[test] + fn ne_self() { + let mut func = make_func(single_block(vec![IrInstr::BinOp { + dest: VarId(1), + op: BinOp::I32Ne, + lhs: VarId(0), + rhs: VarId(0), + }])); + eliminate(&mut func); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(0) + } + )); + } + + // ── AND / OR with constants ────────────────────────────────────────── + + #[test] + fn and_zero() { + let mut func = make_func(single_block(vec![ + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(0), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32And, + lhs: VarId(0), + rhs: VarId(1), + }, + ])); + eliminate(&mut func); + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::Const { + dest: VarId(2), + value: IrValue::I32(0) + } + )); + } + + #[test] + fn and_all_ones() { + let mut func = make_func(single_block(vec![ + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(-1), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32And, + lhs: VarId(0), + rhs: VarId(1), + }, + ])); + eliminate(&mut func); + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::Assign { + dest: VarId(2), + src: VarId(0) + } + )); + } + + #[test] + fn or_all_ones() { + let mut func = make_func(single_block(vec![ + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(-1), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Or, + lhs: VarId(0), + rhs: VarId(1), + }, + ])); + eliminate(&mut func); + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::Const { + dest: VarId(2), + value: IrValue::I32(-1) + } + )); + } + + // ── Shift by zero ──────────────────────────────────────────────────── + + #[test] + fn shl_zero() { + let mut func = make_func(single_block(vec![ + IrInstr::Const { + dest: VarId(1), + value: 
IrValue::I32(0), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Shl, + lhs: VarId(0), + rhs: VarId(1), + }, + ])); + eliminate(&mut func); + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::Assign { + dest: VarId(2), + src: VarId(0) + } + )); + } + + // ── Sub self ───────────────────────────────────────────────────────── + + #[test] + fn sub_self() { + let mut func = make_func(single_block(vec![IrInstr::BinOp { + dest: VarId(1), + op: BinOp::I32Sub, + lhs: VarId(0), + rhs: VarId(0), + }])); + eliminate(&mut func); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(0) + } + )); + } + + // ── Sub zero ───────────────────────────────────────────────────────── + + #[test] + fn sub_zero() { + let mut func = make_func(single_block(vec![ + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(0), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Sub, + lhs: VarId(0), + rhs: VarId(1), + }, + ])); + eliminate(&mut func); + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::Assign { + dest: VarId(2), + src: VarId(0) + } + )); + } + + // ── i64 variants ───────────────────────────────────────────────────── + + #[test] + fn i64_add_zero() { + let mut func = make_func(single_block(vec![ + IrInstr::Const { + dest: VarId(1), + value: IrValue::I64(0), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I64Add, + lhs: VarId(0), + rhs: VarId(1), + }, + ])); + eliminate(&mut func); + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::Assign { + dest: VarId(2), + src: VarId(0) + } + )); + } + + #[test] + fn i64_xor_self() { + let mut func = make_func(single_block(vec![IrInstr::BinOp { + dest: VarId(1), + op: BinOp::I64Xor, + lhs: VarId(0), + rhs: VarId(0), + }])); + eliminate(&mut func); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { + dest: VarId(1), + value: IrValue::I64(0) + } + )); + } + + // ── Cross-block constant 
───────────────────────────────────────────── + + #[test] + fn cross_block_const_simplification() { + // B0: v1 = 0 → B1: v2 = v0 + v1 → v2 = Assign(v0) + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(0), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::Return { + value: Some(VarId(2)), + }, + }, + ]); + eliminate(&mut func); + assert!(matches!( + func.blocks[1].instructions[0], + IrInstr::Assign { + dest: VarId(2), + src: VarId(0) + } + )); + } + + // ── No-op: non-identity constant unchanged ─────────────────────────── + + #[test] + fn add_nonzero_unchanged() { + let mut func = make_func(single_block(vec![ + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(5), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + ])); + eliminate(&mut func); + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::BinOp { .. } + )); + } + + // ── Float ops are NOT simplified (NaN concerns) ────────────────────── + + #[test] + fn float_add_zero_unchanged() { + let mut func = make_func(single_block(vec![ + IrInstr::Const { + dest: VarId(1), + value: IrValue::F32(0.0), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::F32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + ])); + eliminate(&mut func); + // Float add with 0 is NOT simplified because -0.0 + 0.0 = 0.0 ≠ -0.0 + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::BinOp { .. 
} + )); + } + + // ── AND/OR self ────────────────────────────────────────────────────── + + #[test] + fn and_self() { + let mut func = make_func(single_block(vec![IrInstr::BinOp { + dest: VarId(1), + op: BinOp::I32And, + lhs: VarId(0), + rhs: VarId(0), + }])); + eliminate(&mut func); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Assign { + dest: VarId(1), + src: VarId(0) + } + )); + } + + #[test] + fn or_self() { + let mut func = make_func(single_block(vec![IrInstr::BinOp { + dest: VarId(1), + op: BinOp::I32Or, + lhs: VarId(0), + rhs: VarId(0), + }])); + eliminate(&mut func); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Assign { + dest: VarId(1), + src: VarId(0) + } + )); + } +} diff --git a/crates/herkos/src/optimizer/branch_fold.rs b/crates/herkos/src/optimizer/branch_fold.rs new file mode 100644 index 0000000..eecfbe8 --- /dev/null +++ b/crates/herkos/src/optimizer/branch_fold.rs @@ -0,0 +1,366 @@ +//! Branch condition folding. +//! +//! Simplifies `BranchIf` terminators by looking at the instruction that +//! defines the condition variable: +//! +//! - `Eqz(x)` as condition → swap branch targets, use `x` directly +//! - `Ne(x, 0)` as condition → use `x` directly +//! - `Eq(x, 0)` as condition → swap branch targets, use `x` directly +//! +//! After substitution, the defining instruction becomes dead (single use was +//! the branch) and is cleaned up by `dead_instrs`. + +use super::utils::{build_global_use_count, instr_dest}; +use crate::ir::{BinOp, IrFunction, IrInstr, IrTerminator, IrValue, UnOp, VarId}; +use std::collections::HashMap; + +pub fn eliminate(func: &mut IrFunction) { + loop { + let global_uses = build_global_use_count(func); + if !fold_one(func, &global_uses) { + break; + } + } +} + +/// Attempt a single branch fold across the function. Returns `true` if a +/// change was made. +fn fold_one(func: &mut IrFunction, global_uses: &HashMap) -> bool { + // Build a map of VarId → defining instruction info. 
+ // We only care about single-use vars defined by Eqz, Ne(x,0), or Eq(x,0). + let mut var_defs: HashMap = HashMap::new(); + + // Also build a global constant map for checking if an operand is zero. + let global_consts = build_global_const_map(func); + + for block in &func.blocks { + let mut local_consts = global_consts.clone(); + for instr in &block.instructions { + if let IrInstr::Const { dest, value } = instr { + local_consts.insert(*dest, *value); + } + + if let Some(dest) = instr_dest(instr) { + match instr { + IrInstr::UnOp { + op: UnOp::I32Eqz | UnOp::I64Eqz, + operand, + .. + } => { + var_defs.insert(dest, VarDef::Eqz(*operand)); + } + IrInstr::BinOp { + op: BinOp::I32Ne | BinOp::I64Ne, + lhs, + rhs, + .. + } => { + if is_zero(rhs, &local_consts) { + var_defs.insert(dest, VarDef::NeZero(*lhs)); + } else if is_zero(lhs, &local_consts) { + var_defs.insert(dest, VarDef::NeZero(*rhs)); + } + } + IrInstr::BinOp { + op: BinOp::I32Eq | BinOp::I64Eq, + lhs, + rhs, + .. + } => { + if is_zero(rhs, &local_consts) { + var_defs.insert(dest, VarDef::EqZero(*lhs)); + } else if is_zero(lhs, &local_consts) { + var_defs.insert(dest, VarDef::EqZero(*rhs)); + } + } + _ => {} + } + } + } + } + + // Now scan terminators for BranchIf with a foldable condition. + for block in &mut func.blocks { + let condition = match &block.terminator { + IrTerminator::BranchIf { condition, .. } => *condition, + _ => continue, + }; + + // Only fold if the condition has exactly one use (the BranchIf). 
+ if global_uses.get(&condition).copied().unwrap_or(0) != 1 { + continue; + } + + let def = match var_defs.get(&condition) { + Some(d) => d, + None => continue, + }; + + match def { + VarDef::Eqz(inner) | VarDef::EqZero(inner) => { + // eqz(x) != 0 ≡ x == 0, so swap targets and use x + if let IrTerminator::BranchIf { + condition: cond, + if_true, + if_false, + } = &mut block.terminator + { + *cond = *inner; + std::mem::swap(if_true, if_false); + } + return true; + } + VarDef::NeZero(inner) => { + // ne(x, 0) != 0 ≡ x != 0, so just use x + if let IrTerminator::BranchIf { + condition: cond, .. + } = &mut block.terminator + { + *cond = *inner; + } + return true; + } + } + } + + false +} + +#[derive(Clone, Copy)] +enum VarDef { + Eqz(VarId), + NeZero(VarId), + EqZero(VarId), +} + +fn is_zero(var: &VarId, consts: &HashMap) -> bool { + matches!( + consts.get(var), + Some(IrValue::I32(0)) | Some(IrValue::I64(0)) + ) +} + +fn build_global_const_map(func: &IrFunction) -> HashMap { + let mut total_defs: HashMap = HashMap::new(); + let mut const_defs: HashMap = HashMap::new(); + + for block in &func.blocks { + for instr in &block.instructions { + if let Some(dest) = instr_dest(instr) { + *total_defs.entry(dest).or_insert(0) += 1; + if let IrInstr::Const { dest, value } = instr { + const_defs.insert(*dest, *value); + } + } + } + } + + const_defs + .into_iter() + .filter(|(v, _)| total_defs.get(v).copied().unwrap_or(0) == 1) + .collect() +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{BlockId, IrBlock, TypeIdx}; + + fn make_func(blocks: Vec) -> IrFunction { + IrFunction { + params: vec![], + locals: vec![], + blocks, + entry_block: BlockId(0), + return_type: None, + type_idx: TypeIdx::new(0), + needs_host: false, + } + } + + #[test] + fn eqz_swaps_targets() { + // v1 = Eqz(v0); BranchIf(v1, B1, B2) → BranchIf(v0, B2, B1) + let mut func = make_func(vec![IrBlock { + id: 
BlockId(0), + instructions: vec![IrInstr::UnOp { + dest: VarId(1), + op: UnOp::I32Eqz, + operand: VarId(0), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(1), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }]); + eliminate(&mut func); + match &func.blocks[0].terminator { + IrTerminator::BranchIf { + condition, + if_true, + if_false, + } => { + assert_eq!(*condition, VarId(0)); + assert_eq!(*if_true, BlockId(2), "targets should be swapped"); + assert_eq!(*if_false, BlockId(1)); + } + other => panic!("expected BranchIf, got {other:?}"), + } + } + + #[test] + fn ne_zero_simplifies() { + // v1 = 0; v2 = Ne(v0, v1); BranchIf(v2, B1, B2) → BranchIf(v0, B1, B2) + let mut func = make_func(vec![IrBlock { + id: BlockId(0), + instructions: vec![ + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(0), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Ne, + lhs: VarId(0), + rhs: VarId(1), + }, + ], + terminator: IrTerminator::BranchIf { + condition: VarId(2), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }]); + eliminate(&mut func); + match &func.blocks[0].terminator { + IrTerminator::BranchIf { + condition, + if_true, + if_false, + } => { + assert_eq!(*condition, VarId(0)); + assert_eq!(*if_true, BlockId(1), "targets should NOT be swapped"); + assert_eq!(*if_false, BlockId(2)); + } + other => panic!("expected BranchIf, got {other:?}"), + } + } + + #[test] + fn eq_zero_swaps() { + // v1 = 0; v2 = Eq(v0, v1); BranchIf(v2, B1, B2) → BranchIf(v0, B2, B1) + let mut func = make_func(vec![IrBlock { + id: BlockId(0), + instructions: vec![ + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(0), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Eq, + lhs: VarId(0), + rhs: VarId(1), + }, + ], + terminator: IrTerminator::BranchIf { + condition: VarId(2), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }]); + eliminate(&mut func); + match &func.blocks[0].terminator { + IrTerminator::BranchIf { + condition, + if_true, + 
if_false, + } => { + assert_eq!(*condition, VarId(0)); + assert_eq!(*if_true, BlockId(2), "targets should be swapped"); + assert_eq!(*if_false, BlockId(1)); + } + other => panic!("expected BranchIf, got {other:?}"), + } + } + + #[test] + fn multi_use_not_folded() { + // v1 = Eqz(v0); use(v1) elsewhere → don't fold + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::UnOp { + dest: VarId(1), + op: UnOp::I32Eqz, + operand: VarId(0), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(1), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Return { + value: Some(VarId(1)), // second use of v1 + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + eliminate(&mut func); + // Should NOT fold — v1 has 2 uses + match &func.blocks[0].terminator { + IrTerminator::BranchIf { condition, .. } => { + assert_eq!(*condition, VarId(1), "should not have been folded"); + } + other => panic!("expected BranchIf, got {other:?}"), + } + } + + #[test] + fn cross_block_zero_const() { + // B0: v1 = 0; Jump(B1) + // B1: v2 = Ne(v0, v1); BranchIf(v2, B2, B3) → BranchIf(v0, B2, B3) + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(0), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Ne, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(2), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + ]); + eliminate(&mut func); + match &func.blocks[1].terminator { + IrTerminator::BranchIf { condition, .. 
} => { + assert_eq!(*condition, VarId(0)); + } + other => panic!("expected BranchIf, got {other:?}"), + } + } +} diff --git a/crates/herkos/src/optimizer/const_prop.rs b/crates/herkos/src/optimizer/const_prop.rs new file mode 100644 index 0000000..d8bcf49 --- /dev/null +++ b/crates/herkos/src/optimizer/const_prop.rs @@ -0,0 +1,1309 @@ +//! Constant folding and propagation. +//! +//! ## What it does +//! +//! Tracks which `VarId`s hold known constant values, then: +//! +//! - **Propagates** constants through `Assign` chains. +//! - **Folds** `BinOp(Const, Const)` and `UnOp(Const)` into `Const` when the +//! result is statically computable. +//! +//! ## Algorithm +//! +//! Each iteration of the outer fixpoint loop: +//! +//! 1. Build a **global constant map**: variables that have exactly one definition +//! across the entire function and that definition is a `Const` instruction. +//! These are safe to treat as constants in any block that uses them (SSA +//! single-definition guarantee ensures the value is the same everywhere). +//! +//! 2. For each block, maintain a `HashMap` of known constants, +//! **seeded** with the global constant map. Walk instructions in order: +//! - `Const { dest, value }` → record `dest → value` +//! - `Assign { dest, src }` where src is known → replace with `Const`, record +//! - `BinOp { dest, op, lhs, rhs }` where both known → fold via `try_eval_binop` +//! - `UnOp { dest, op, operand }` where operand known → fold via `try_eval_unop` +//! +//! 3. Run to fixpoint. +//! +//! The global constant seed is what enables cross-block propagation: after +//! const_prop converts a local `Assign` to a `Const` in block B0, the next +//! outer-loop iteration sees that new `Const` definition in the global map and +//! can fold `Assign(that_var)` instructions in dominated blocks B1, B2, etc. +//! +//! ## Safety +//! +//! Operations that may trap at runtime are **not** folded: +//! - `DivS`/`DivU`/`RemS`/`RemU` with divisor 0 +//! 
- `I32DivS(I32_MIN, -1)` / `I64DivS(I64_MIN, -1)` — signed overflow +//! - `TruncF*` with NaN or out-of-range float + +use super::utils::instr_dest; +use crate::ir::{BinOp, IrFunction, IrInstr, IrValue, UnOp, VarId}; +use herkos_runtime::{ + i32_trunc_f32_s, i32_trunc_f32_u, i32_trunc_f64_s, i32_trunc_f64_u, i64_trunc_f32_s, + i64_trunc_f32_u, i64_trunc_f64_s, i64_trunc_f64_u, wasm_max_f32, wasm_max_f64, wasm_min_f32, + wasm_min_f64, wasm_nearest_f32, wasm_nearest_f64, +}; +use std::collections::HashMap; + +// ── Public entry point ──────────────────────────────────────────────────────── + +/// Variables with exactly one definition across the function that is a `Const` +/// instruction. These can be treated as constants in any block that uses them. +fn build_global_const_map(func: &IrFunction) -> HashMap { + // Count total definitions per variable (any instruction with a dest). + let mut total_defs: HashMap = HashMap::new(); + let mut const_defs: HashMap = HashMap::new(); + + for block in &func.blocks { + for instr in &block.instructions { + if let Some(dest) = super::utils::instr_dest(instr) { + *total_defs.entry(dest).or_insert(0) += 1; + if let IrInstr::Const { dest, value } = instr { + const_defs.insert(*dest, *value); + } + } + } + } + + // Only include variables whose sole definition is a Const instruction. + const_defs + .into_iter() + .filter(|(v, _)| total_defs.get(v).copied().unwrap_or(0) == 1) + .collect() +} + +/// Run constant propagation and folding to fixpoint. +pub fn eliminate(func: &mut IrFunction) { + loop { + // Seed each block's known-constants map with variables that are defined + // exactly once as a Const across the whole function. This enables + // cross-block propagation: after a block-local fold turns an Assign into + // a Const in one block, the next outer iteration seeds that constant into + // all other blocks that use the variable. 
+ let global_consts = build_global_const_map(func); + + let mut changed = false; + for block in &mut func.blocks { + let mut known: HashMap = global_consts.clone(); + + for instr in &mut block.instructions { + // Track whether we recorded a known constant for this instruction's dest. + let mut folded = false; + + match instr { + IrInstr::Const { dest, value } => { + known.insert(*dest, *value); + folded = true; + } + IrInstr::Assign { dest, src } => { + let d = *dest; + if let Some(val) = known.get(src).copied() { + *instr = IrInstr::Const { + dest: d, + value: val, + }; + known.insert(d, val); + changed = true; + folded = true; + } + } + IrInstr::BinOp { + dest, op, lhs, rhs, .. + } => { + let (d, o) = (*dest, *op); + if let (Some(lv), Some(rv)) = + (known.get(lhs).copied(), known.get(rhs).copied()) + { + if let Some(result) = try_eval_binop(o, lv, rv) { + *instr = IrInstr::Const { + dest: d, + value: result, + }; + known.insert(d, result); + changed = true; + folded = true; + } + } + } + IrInstr::UnOp { + dest, op, operand, .. + } => { + let (d, o) = (*dest, *op); + if let Some(val) = known.get(operand).copied() { + if let Some(result) = try_eval_unop(o, val) { + *instr = IrInstr::Const { + dest: d, + value: result, + }; + known.insert(d, result); + changed = true; + folded = true; + } + } + } + _ => {} + } + + // If we didn't fold this instruction to a constant, ensure its dest + // is not in `known` (purely defensive: in strict SSA form each variable + // is defined exactly once, so this is always a no-op). + if !folded { + if let Some(dest) = instr_dest(instr) { + known.remove(&dest); + } + } + } + } + if !changed { + break; + } + } +} + +// ── Binary operation folding ────────────────────────────────────────────────── + +/// Attempt to evaluate a binary operation on two constant values. +/// +/// Returns `None` when: +/// - The value types don't match the expected operand types for the op. 
+/// - The operation would trap at runtime (div/rem by zero, signed overflow, +/// etc.) — we must preserve the runtime trap rather than folding it away. +fn try_eval_binop(op: BinOp, lhs: IrValue, rhs: IrValue) -> Option { + match (op, lhs, rhs) { + // ── i32 arithmetic ────────────────────────────────────────────── + (BinOp::I32Add, IrValue::I32(a), IrValue::I32(b)) => Some(IrValue::I32(a.wrapping_add(b))), + (BinOp::I32Sub, IrValue::I32(a), IrValue::I32(b)) => Some(IrValue::I32(a.wrapping_sub(b))), + (BinOp::I32Mul, IrValue::I32(a), IrValue::I32(b)) => Some(IrValue::I32(a.wrapping_mul(b))), + + // Division/remainder: do NOT fold if it would trap. + (BinOp::I32DivS, IrValue::I32(a), IrValue::I32(b)) => a.checked_div(b).map(IrValue::I32), + (BinOp::I32DivU, IrValue::I32(a), IrValue::I32(b)) => (a as u32) + .checked_div(b as u32) + .map(|v| IrValue::I32(v as i32)), + (BinOp::I32RemS, IrValue::I32(a), IrValue::I32(b)) => { + if b == 0 { + None + } else if a == i32::MIN && b == -1 { + Some(IrValue::I32(0)) + } else { + Some(IrValue::I32(a % b)) + } + } + (BinOp::I32RemU, IrValue::I32(a), IrValue::I32(b)) => (a as u32) + .checked_rem(b as u32) + .map(|v| IrValue::I32(v as i32)), + + // Bitwise + (BinOp::I32And, IrValue::I32(a), IrValue::I32(b)) => Some(IrValue::I32(a & b)), + (BinOp::I32Or, IrValue::I32(a), IrValue::I32(b)) => Some(IrValue::I32(a | b)), + (BinOp::I32Xor, IrValue::I32(a), IrValue::I32(b)) => Some(IrValue::I32(a ^ b)), + + // Shifts/rotates (Wasm masks shift amount by type width) + (BinOp::I32Shl, IrValue::I32(a), IrValue::I32(b)) => { + Some(IrValue::I32(a.wrapping_shl(b as u32 & 31))) + } + (BinOp::I32ShrS, IrValue::I32(a), IrValue::I32(b)) => { + Some(IrValue::I32(a.wrapping_shr(b as u32 & 31))) + } + (BinOp::I32ShrU, IrValue::I32(a), IrValue::I32(b)) => { + Some(IrValue::I32((a as u32).wrapping_shr(b as u32 & 31) as i32)) + } + (BinOp::I32Rotl, IrValue::I32(a), IrValue::I32(b)) => { + Some(IrValue::I32((a as u32).rotate_left(b as u32 & 31) as i32)) 
+ } + (BinOp::I32Rotr, IrValue::I32(a), IrValue::I32(b)) => { + Some(IrValue::I32((a as u32).rotate_right(b as u32 & 31) as i32)) + } + + // i32 comparisons + (BinOp::I32Eq, IrValue::I32(a), IrValue::I32(b)) => { + Some(IrValue::I32(if a == b { 1 } else { 0 })) + } + (BinOp::I32Ne, IrValue::I32(a), IrValue::I32(b)) => { + Some(IrValue::I32(if a != b { 1 } else { 0 })) + } + (BinOp::I32LtS, IrValue::I32(a), IrValue::I32(b)) => { + Some(IrValue::I32(if a < b { 1 } else { 0 })) + } + (BinOp::I32LtU, IrValue::I32(a), IrValue::I32(b)) => { + Some(IrValue::I32(if (a as u32) < (b as u32) { 1 } else { 0 })) + } + (BinOp::I32GtS, IrValue::I32(a), IrValue::I32(b)) => { + Some(IrValue::I32(if a > b { 1 } else { 0 })) + } + (BinOp::I32GtU, IrValue::I32(a), IrValue::I32(b)) => { + Some(IrValue::I32(if (a as u32) > (b as u32) { 1 } else { 0 })) + } + (BinOp::I32LeS, IrValue::I32(a), IrValue::I32(b)) => { + Some(IrValue::I32(if a <= b { 1 } else { 0 })) + } + (BinOp::I32LeU, IrValue::I32(a), IrValue::I32(b)) => { + Some(IrValue::I32(if (a as u32) <= (b as u32) { 1 } else { 0 })) + } + (BinOp::I32GeS, IrValue::I32(a), IrValue::I32(b)) => { + Some(IrValue::I32(if a >= b { 1 } else { 0 })) + } + (BinOp::I32GeU, IrValue::I32(a), IrValue::I32(b)) => { + Some(IrValue::I32(if (a as u32) >= (b as u32) { 1 } else { 0 })) + } + + // ── i64 arithmetic ────────────────────────────────────────────── + (BinOp::I64Add, IrValue::I64(a), IrValue::I64(b)) => Some(IrValue::I64(a.wrapping_add(b))), + (BinOp::I64Sub, IrValue::I64(a), IrValue::I64(b)) => Some(IrValue::I64(a.wrapping_sub(b))), + (BinOp::I64Mul, IrValue::I64(a), IrValue::I64(b)) => Some(IrValue::I64(a.wrapping_mul(b))), + + (BinOp::I64DivS, IrValue::I64(a), IrValue::I64(b)) => a.checked_div(b).map(IrValue::I64), + (BinOp::I64DivU, IrValue::I64(a), IrValue::I64(b)) => (a as u64) + .checked_div(b as u64) + .map(|v| IrValue::I64(v as i64)), + (BinOp::I64RemS, IrValue::I64(a), IrValue::I64(b)) => { + if b == 0 { + None + } else if a == 
i64::MIN && b == -1 { + Some(IrValue::I64(0)) + } else { + Some(IrValue::I64(a % b)) + } + } + (BinOp::I64RemU, IrValue::I64(a), IrValue::I64(b)) => (a as u64) + .checked_rem(b as u64) + .map(|v| IrValue::I64(v as i64)), + + // Bitwise + (BinOp::I64And, IrValue::I64(a), IrValue::I64(b)) => Some(IrValue::I64(a & b)), + (BinOp::I64Or, IrValue::I64(a), IrValue::I64(b)) => Some(IrValue::I64(a | b)), + (BinOp::I64Xor, IrValue::I64(a), IrValue::I64(b)) => Some(IrValue::I64(a ^ b)), + + // Shifts/rotates (Wasm masks by 63 for i64) + (BinOp::I64Shl, IrValue::I64(a), IrValue::I64(b)) => { + Some(IrValue::I64(a.wrapping_shl(b as u32 & 63))) + } + (BinOp::I64ShrS, IrValue::I64(a), IrValue::I64(b)) => { + Some(IrValue::I64(a.wrapping_shr(b as u32 & 63))) + } + (BinOp::I64ShrU, IrValue::I64(a), IrValue::I64(b)) => { + Some(IrValue::I64((a as u64).wrapping_shr(b as u32 & 63) as i64)) + } + (BinOp::I64Rotl, IrValue::I64(a), IrValue::I64(b)) => { + Some(IrValue::I64((a as u64).rotate_left(b as u32 & 63) as i64)) + } + (BinOp::I64Rotr, IrValue::I64(a), IrValue::I64(b)) => { + Some(IrValue::I64((a as u64).rotate_right(b as u32 & 63) as i64)) + } + + // i64 comparisons (result is i32) + (BinOp::I64Eq, IrValue::I64(a), IrValue::I64(b)) => { + Some(IrValue::I32(if a == b { 1 } else { 0 })) + } + (BinOp::I64Ne, IrValue::I64(a), IrValue::I64(b)) => { + Some(IrValue::I32(if a != b { 1 } else { 0 })) + } + (BinOp::I64LtS, IrValue::I64(a), IrValue::I64(b)) => { + Some(IrValue::I32(if a < b { 1 } else { 0 })) + } + (BinOp::I64LtU, IrValue::I64(a), IrValue::I64(b)) => { + Some(IrValue::I32(if (a as u64) < (b as u64) { 1 } else { 0 })) + } + (BinOp::I64GtS, IrValue::I64(a), IrValue::I64(b)) => { + Some(IrValue::I32(if a > b { 1 } else { 0 })) + } + (BinOp::I64GtU, IrValue::I64(a), IrValue::I64(b)) => { + Some(IrValue::I32(if (a as u64) > (b as u64) { 1 } else { 0 })) + } + (BinOp::I64LeS, IrValue::I64(a), IrValue::I64(b)) => { + Some(IrValue::I32(if a <= b { 1 } else { 0 })) + } + 
(BinOp::I64LeU, IrValue::I64(a), IrValue::I64(b)) => { + Some(IrValue::I32(if (a as u64) <= (b as u64) { 1 } else { 0 })) + } + (BinOp::I64GeS, IrValue::I64(a), IrValue::I64(b)) => { + Some(IrValue::I32(if a >= b { 1 } else { 0 })) + } + (BinOp::I64GeU, IrValue::I64(a), IrValue::I64(b)) => { + Some(IrValue::I32(if (a as u64) >= (b as u64) { 1 } else { 0 })) + } + + // ── f32 arithmetic ────────────────────────────────────────────── + (BinOp::F32Add, IrValue::F32(a), IrValue::F32(b)) => Some(IrValue::F32(a + b)), + (BinOp::F32Sub, IrValue::F32(a), IrValue::F32(b)) => Some(IrValue::F32(a - b)), + (BinOp::F32Mul, IrValue::F32(a), IrValue::F32(b)) => Some(IrValue::F32(a * b)), + (BinOp::F32Div, IrValue::F32(a), IrValue::F32(b)) => Some(IrValue::F32(a / b)), + // Wasm f32.min/max: if either operand is NaN the result is NaN. + // Rust's f32::min/max ignores NaN (returns the other operand), which + // violates the Wasm spec. + (BinOp::F32Min, IrValue::F32(a), IrValue::F32(b)) => Some(IrValue::F32(wasm_min_f32(a, b))), + (BinOp::F32Max, IrValue::F32(a), IrValue::F32(b)) => Some(IrValue::F32(wasm_max_f32(a, b))), + (BinOp::F32Copysign, IrValue::F32(a), IrValue::F32(b)) => Some(IrValue::F32(a.copysign(b))), + + // f32 comparisons (result is i32) + (BinOp::F32Eq, IrValue::F32(a), IrValue::F32(b)) => { + Some(IrValue::I32(if a == b { 1 } else { 0 })) + } + (BinOp::F32Ne, IrValue::F32(a), IrValue::F32(b)) => { + Some(IrValue::I32(if a != b { 1 } else { 0 })) + } + (BinOp::F32Lt, IrValue::F32(a), IrValue::F32(b)) => { + Some(IrValue::I32(if a < b { 1 } else { 0 })) + } + (BinOp::F32Gt, IrValue::F32(a), IrValue::F32(b)) => { + Some(IrValue::I32(if a > b { 1 } else { 0 })) + } + (BinOp::F32Le, IrValue::F32(a), IrValue::F32(b)) => { + Some(IrValue::I32(if a <= b { 1 } else { 0 })) + } + (BinOp::F32Ge, IrValue::F32(a), IrValue::F32(b)) => { + Some(IrValue::I32(if a >= b { 1 } else { 0 })) + } + + // ── f64 arithmetic ────────────────────────────────────────────── + (BinOp::F64Add, 
IrValue::F64(a), IrValue::F64(b)) => Some(IrValue::F64(a + b)), + (BinOp::F64Sub, IrValue::F64(a), IrValue::F64(b)) => Some(IrValue::F64(a - b)), + (BinOp::F64Mul, IrValue::F64(a), IrValue::F64(b)) => Some(IrValue::F64(a * b)), + (BinOp::F64Div, IrValue::F64(a), IrValue::F64(b)) => Some(IrValue::F64(a / b)), + // Wasm f64.min/max: if either operand is NaN the result is NaN. + (BinOp::F64Min, IrValue::F64(a), IrValue::F64(b)) => Some(IrValue::F64(wasm_min_f64(a, b))), + (BinOp::F64Max, IrValue::F64(a), IrValue::F64(b)) => Some(IrValue::F64(wasm_max_f64(a, b))), + (BinOp::F64Copysign, IrValue::F64(a), IrValue::F64(b)) => Some(IrValue::F64(a.copysign(b))), + + // f64 comparisons (result is i32) + (BinOp::F64Eq, IrValue::F64(a), IrValue::F64(b)) => { + Some(IrValue::I32(if a == b { 1 } else { 0 })) + } + (BinOp::F64Ne, IrValue::F64(a), IrValue::F64(b)) => { + Some(IrValue::I32(if a != b { 1 } else { 0 })) + } + (BinOp::F64Lt, IrValue::F64(a), IrValue::F64(b)) => { + Some(IrValue::I32(if a < b { 1 } else { 0 })) + } + (BinOp::F64Gt, IrValue::F64(a), IrValue::F64(b)) => { + Some(IrValue::I32(if a > b { 1 } else { 0 })) + } + (BinOp::F64Le, IrValue::F64(a), IrValue::F64(b)) => { + Some(IrValue::I32(if a <= b { 1 } else { 0 })) + } + (BinOp::F64Ge, IrValue::F64(a), IrValue::F64(b)) => { + Some(IrValue::I32(if a >= b { 1 } else { 0 })) + } + + // Type mismatch — don't fold. + _ => None, + } +} + +// ── Unary operation folding ─────────────────────────────────────────────────── + +/// Attempt to evaluate a unary operation on a constant value. +/// +/// Returns `None` for trapping conversions (`TruncF*` with NaN/out-of-range). 
+fn try_eval_unop(op: UnOp, val: IrValue) -> Option { + match (op, val) { + // ── i32 unary ─────────────────────────────────────────────────── + (UnOp::I32Clz, IrValue::I32(v)) => Some(IrValue::I32((v as u32).leading_zeros() as i32)), + (UnOp::I32Ctz, IrValue::I32(v)) => Some(IrValue::I32((v as u32).trailing_zeros() as i32)), + (UnOp::I32Popcnt, IrValue::I32(v)) => Some(IrValue::I32((v as u32).count_ones() as i32)), + (UnOp::I32Eqz, IrValue::I32(v)) => Some(IrValue::I32(if v == 0 { 1 } else { 0 })), + + // ── i64 unary ─────────────────────────────────────────────────── + (UnOp::I64Clz, IrValue::I64(v)) => Some(IrValue::I64((v as u64).leading_zeros() as i64)), + (UnOp::I64Ctz, IrValue::I64(v)) => Some(IrValue::I64((v as u64).trailing_zeros() as i64)), + (UnOp::I64Popcnt, IrValue::I64(v)) => Some(IrValue::I64((v as u64).count_ones() as i64)), + (UnOp::I64Eqz, IrValue::I64(v)) => Some(IrValue::I32(if v == 0 { 1 } else { 0 })), + + // ── f32 unary ─────────────────────────────────────────────────── + (UnOp::F32Abs, IrValue::F32(v)) => Some(IrValue::F32(v.abs())), + (UnOp::F32Neg, IrValue::F32(v)) => Some(IrValue::F32(-v)), + (UnOp::F32Ceil, IrValue::F32(v)) => Some(IrValue::F32(v.ceil())), + (UnOp::F32Floor, IrValue::F32(v)) => Some(IrValue::F32(v.floor())), + (UnOp::F32Trunc, IrValue::F32(v)) => Some(IrValue::F32(v.trunc())), + (UnOp::F32Nearest, IrValue::F32(v)) => Some(IrValue::F32(wasm_nearest_f32(v))), + (UnOp::F32Sqrt, IrValue::F32(v)) => Some(IrValue::F32(v.sqrt())), + + // ── f64 unary ─────────────────────────────────────────────────── + (UnOp::F64Abs, IrValue::F64(v)) => Some(IrValue::F64(v.abs())), + (UnOp::F64Neg, IrValue::F64(v)) => Some(IrValue::F64(-v)), + (UnOp::F64Ceil, IrValue::F64(v)) => Some(IrValue::F64(v.ceil())), + (UnOp::F64Floor, IrValue::F64(v)) => Some(IrValue::F64(v.floor())), + (UnOp::F64Trunc, IrValue::F64(v)) => Some(IrValue::F64(v.trunc())), + (UnOp::F64Nearest, IrValue::F64(v)) => Some(IrValue::F64(wasm_nearest_f64(v))), + 
(UnOp::F64Sqrt, IrValue::F64(v)) => Some(IrValue::F64(v.sqrt())), + + // ── Integer conversions ───────────────────────────────────────── + (UnOp::I32WrapI64, IrValue::I64(v)) => Some(IrValue::I32(v as i32)), + (UnOp::I64ExtendI32S, IrValue::I32(v)) => Some(IrValue::I64(v as i64)), + (UnOp::I64ExtendI32U, IrValue::I32(v)) => Some(IrValue::I64((v as u32) as i64)), + + // ── Float → integer (trapping) — do NOT fold on NaN/overflow ── + (UnOp::I32TruncF32S, IrValue::F32(v)) => i32_trunc_f32_s(v).ok().map(IrValue::I32), + (UnOp::I32TruncF32U, IrValue::F32(v)) => i32_trunc_f32_u(v).ok().map(IrValue::I32), + (UnOp::I32TruncF64S, IrValue::F64(v)) => i32_trunc_f64_s(v).ok().map(IrValue::I32), + (UnOp::I32TruncF64U, IrValue::F64(v)) => i32_trunc_f64_u(v).ok().map(IrValue::I32), + (UnOp::I64TruncF32S, IrValue::F32(v)) => i64_trunc_f32_s(v).ok().map(IrValue::I64), + (UnOp::I64TruncF32U, IrValue::F32(v)) => i64_trunc_f32_u(v).ok().map(IrValue::I64), + (UnOp::I64TruncF64S, IrValue::F64(v)) => i64_trunc_f64_s(v).ok().map(IrValue::I64), + (UnOp::I64TruncF64U, IrValue::F64(v)) => i64_trunc_f64_u(v).ok().map(IrValue::I64), + + // ── Integer → float conversions ───────────────────────────────── + (UnOp::F32ConvertI32S, IrValue::I32(v)) => Some(IrValue::F32(v as f32)), + (UnOp::F32ConvertI32U, IrValue::I32(v)) => Some(IrValue::F32((v as u32) as f32)), + (UnOp::F32ConvertI64S, IrValue::I64(v)) => Some(IrValue::F32(v as f32)), + (UnOp::F32ConvertI64U, IrValue::I64(v)) => Some(IrValue::F32((v as u64) as f32)), + (UnOp::F64ConvertI32S, IrValue::I32(v)) => Some(IrValue::F64(v as f64)), + (UnOp::F64ConvertI32U, IrValue::I32(v)) => Some(IrValue::F64((v as u32) as f64)), + (UnOp::F64ConvertI64S, IrValue::I64(v)) => Some(IrValue::F64(v as f64)), + (UnOp::F64ConvertI64U, IrValue::I64(v)) => Some(IrValue::F64((v as u64) as f64)), + + // ── Float precision conversions ───────────────────────────────── + (UnOp::F32DemoteF64, IrValue::F64(v)) => Some(IrValue::F32(v as f32)), + 
(UnOp::F64PromoteF32, IrValue::F32(v)) => Some(IrValue::F64(v as f64)), + + // ── Reinterpretations (bitcast) ───────────────────────────────── + (UnOp::I32ReinterpretF32, IrValue::F32(v)) => Some(IrValue::I32(v.to_bits() as i32)), + (UnOp::I64ReinterpretF64, IrValue::F64(v)) => Some(IrValue::I64(v.to_bits() as i64)), + (UnOp::F32ReinterpretI32, IrValue::I32(v)) => Some(IrValue::F32(f32::from_bits(v as u32))), + (UnOp::F64ReinterpretI64, IrValue::I64(v)) => Some(IrValue::F64(f64::from_bits(v as u64))), + + // Type mismatch — don't fold. + _ => None, + } +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{BlockId, IrBlock, IrFunction, IrTerminator, TypeIdx, WasmType}; + + fn make_func(blocks: Vec) -> IrFunction { + IrFunction { + params: vec![], + locals: vec![], + blocks, + entry_block: BlockId(0), + return_type: None, + type_idx: TypeIdx::new(0), + needs_host: false, + } + } + + fn single_block(instrs: Vec, term: IrTerminator) -> Vec { + vec![IrBlock { + id: BlockId(0), + instructions: instrs, + terminator: term, + }] + } + + fn ret_none() -> IrTerminator { + IrTerminator::Return { value: None } + } + + // ── Basic constant propagation through Assign ──────────────────────── + + #[test] + fn assign_propagation() { + // v0 = Const(42); v1 = Assign(v0) → v1 = Const(42) + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }, + IrInstr::Assign { + dest: VarId(1), + src: VarId(0), + }, + ], + ret_none(), + )); + eliminate(&mut func); + match &func.blocks[0].instructions[1] { + IrInstr::Const { + dest, + value: IrValue::I32(42), + } => assert_eq!(*dest, VarId(1)), + other => panic!("expected Const(v1, 42), got {other:?}"), + } + } + + // ── BinOp folding ─────────────────────────────────────────────────── + + #[test] + fn fold_i32_add() { + // v0 = 10; v1 = 20; v2 = v0 + v1 → v2 = 30 + let mut func = 
make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(10), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(20), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + ], + ret_none(), + )); + eliminate(&mut func); + match &func.blocks[0].instructions[2] { + IrInstr::Const { + dest, + value: IrValue::I32(30), + } => assert_eq!(*dest, VarId(2)), + other => panic!("expected Const(v2, 30), got {other:?}"), + } + } + + #[test] + fn fold_i32_comparison() { + // v0 = 5; v1 = 10; v2 = v0 < v1 → v2 = 1 + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(5), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(10), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32LtS, + lhs: VarId(0), + rhs: VarId(1), + }, + ], + ret_none(), + )); + eliminate(&mut func); + match &func.blocks[0].instructions[2] { + IrInstr::Const { + value: IrValue::I32(1), + .. + } => {} + other => panic!("expected Const(1), got {other:?}"), + } + } + + // ── Div by zero: must NOT fold ────────────────────────────────────── + + #[test] + fn div_by_zero_not_folded() { + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(10), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(0), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32DivS, + lhs: VarId(0), + rhs: VarId(1), + }, + ], + ret_none(), + )); + eliminate(&mut func); + assert!( + matches!(&func.blocks[0].instructions[2], IrInstr::BinOp { .. 
}), + "div-by-zero must not be folded" + ); + } + + #[test] + fn i32_div_s_overflow_not_folded() { + // i32::MIN / -1 → trap, do NOT fold + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(i32::MIN), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(-1), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32DivS, + lhs: VarId(0), + rhs: VarId(1), + }, + ], + ret_none(), + )); + eliminate(&mut func); + assert!(matches!( + &func.blocks[0].instructions[2], + IrInstr::BinOp { .. } + )); + } + + #[test] + fn i32_rem_s_min_neg_one_folds_to_zero() { + // i32::MIN % -1 = 0 (does NOT trap in Wasm) + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(i32::MIN), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(-1), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32RemS, + lhs: VarId(0), + rhs: VarId(1), + }, + ], + ret_none(), + )); + eliminate(&mut func); + match &func.blocks[0].instructions[2] { + IrInstr::Const { + value: IrValue::I32(0), + .. + } => {} + other => panic!("expected Const(0), got {other:?}"), + } + } + + // ── UnOp folding ──────────────────────────────────────────────────── + + #[test] + fn fold_i32_eqz() { + // v0 = 0; v1 = eqz(v0) → v1 = 1 + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(0), + }, + IrInstr::UnOp { + dest: VarId(1), + op: UnOp::I32Eqz, + operand: VarId(0), + }, + ], + ret_none(), + )); + eliminate(&mut func); + match &func.blocks[0].instructions[1] { + IrInstr::Const { + value: IrValue::I32(1), + .. 
+ } => {} + other => panic!("expected Const(1), got {other:?}"), + } + } + + #[test] + fn fold_i32_clz() { + // v0 = 1; v1 = clz(v0) → v1 = 31 + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + }, + IrInstr::UnOp { + dest: VarId(1), + op: UnOp::I32Clz, + operand: VarId(0), + }, + ], + ret_none(), + )); + eliminate(&mut func); + match &func.blocks[0].instructions[1] { + IrInstr::Const { + value: IrValue::I32(31), + .. + } => {} + other => panic!("expected Const(31), got {other:?}"), + } + } + + // ── Trapping TruncF must NOT fold ─────────────────────────────────── + + #[test] + fn trunc_f32_nan_not_folded() { + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::F32(f32::NAN), + }, + IrInstr::UnOp { + dest: VarId(1), + op: UnOp::I32TruncF32S, + operand: VarId(0), + }, + ], + ret_none(), + )); + eliminate(&mut func); + assert!( + matches!(&func.blocks[0].instructions[1], IrInstr::UnOp { .. }), + "trunc(NaN) must not be folded" + ); + } + + #[test] + fn trunc_f32_valid_folds() { + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::F32(3.7), + }, + IrInstr::UnOp { + dest: VarId(1), + op: UnOp::I32TruncF32S, + operand: VarId(0), + }, + ], + ret_none(), + )); + eliminate(&mut func); + match &func.blocks[0].instructions[1] { + IrInstr::Const { + value: IrValue::I32(3), + .. 
+ } => {} + other => panic!("expected Const(3), got {other:?}"), + } + } + + // ── Chain folding (fixpoint) ──────────────────────────────────────── + + #[test] + fn chain_through_assign_then_binop() { + // v0 = 5; v1 = Assign(v0); v2 = 3; v3 = v1 + v2 → v3 = 8 + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(5), + }, + IrInstr::Assign { + dest: VarId(1), + src: VarId(0), + }, + IrInstr::Const { + dest: VarId(2), + value: IrValue::I32(3), + }, + IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Add, + lhs: VarId(1), + rhs: VarId(2), + }, + ], + ret_none(), + )); + eliminate(&mut func); + match &func.blocks[0].instructions[3] { + IrInstr::Const { + value: IrValue::I32(8), + .. + } => {} + other => panic!("expected Const(8), got {other:?}"), + } + } + + // ── Non-constant operand: must NOT fold ───────────────────────────── + + #[test] + fn non_constant_operand_not_folded() { + // v0 = param (not const); v1 = 5; v2 = v0 + v1 — v0 unknown + let mut func = IrFunction { + params: vec![(VarId(0), WasmType::I32)], + locals: vec![], + blocks: single_block( + vec![ + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(5), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + ], + ret_none(), + ), + entry_block: BlockId(0), + return_type: None, + type_idx: TypeIdx::new(0), + needs_host: false, + }; + eliminate(&mut func); + assert!( + matches!(&func.blocks[0].instructions[1], IrInstr::BinOp { .. 
}), + "BinOp with non-const operand must not be folded" + ); + } + + // ── Conversion folding ────────────────────────────────────────────── + + #[test] + fn fold_i32_wrap_i64() { + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I64(0x1_0000_0005), + }, + IrInstr::UnOp { + dest: VarId(1), + op: UnOp::I32WrapI64, + operand: VarId(0), + }, + ], + ret_none(), + )); + eliminate(&mut func); + match &func.blocks[0].instructions[1] { + IrInstr::Const { + value: IrValue::I32(5), + .. + } => {} + other => panic!("expected Const(5), got {other:?}"), + } + } + + #[test] + fn fold_reinterpret_roundtrip() { + // i32 → f32 → i32 via reinterpret should preserve bits + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(0x4048_0000), // f32 bits for 3.125 + }, + IrInstr::UnOp { + dest: VarId(1), + op: UnOp::F32ReinterpretI32, + operand: VarId(0), + }, + IrInstr::UnOp { + dest: VarId(2), + op: UnOp::I32ReinterpretF32, + operand: VarId(1), + }, + ], + ret_none(), + )); + eliminate(&mut func); + match &func.blocks[0].instructions[2] { + IrInstr::Const { + value: IrValue::I32(v), + .. + } => assert_eq!(*v, 0x4048_0000), + other => panic!("expected Const(0x40480000), got {other:?}"), + } + } + + // ── i32 shift masking ─────────────────────────────────────────────── + + #[test] + fn fold_i32_shl_masks_shift_amount() { + // Wasm: shift amount masked by 31. shl(1, 33) == shl(1, 1) == 2 + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(33), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Shl, + lhs: VarId(0), + rhs: VarId(1), + }, + ], + ret_none(), + )); + eliminate(&mut func); + match &func.blocks[0].instructions[2] { + IrInstr::Const { + value: IrValue::I32(2), + .. 
+ } => {} + other => panic!("expected Const(2), got {other:?}"), + } + } + + // ── Wrapping arithmetic ───────────────────────────────────────────── + + #[test] + fn fold_i32_wrapping_add() { + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(i32::MAX), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(1), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + ], + ret_none(), + )); + eliminate(&mut func); + match &func.blocks[0].instructions[2] { + IrInstr::Const { + value: IrValue::I32(v), + .. + } => assert_eq!(*v, i32::MIN), + other => panic!("expected Const(i32::MIN), got {other:?}"), + } + } + + // ── f64 folding ───────────────────────────────────────────────────── + + #[test] + fn fold_f64_mul() { + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::F64(2.5), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::F64(4.0), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::F64Mul, + lhs: VarId(0), + rhs: VarId(1), + }, + ], + ret_none(), + )); + eliminate(&mut func); + match &func.blocks[0].instructions[2] { + IrInstr::Const { + value: IrValue::F64(v), + .. + } => assert!((*v - 10.0).abs() < f64::EPSILON), + other => panic!("expected Const(10.0), got {other:?}"), + } + } + + // ── Multi-block: single-definition constants propagate cross-block ── + + #[test] + fn single_def_const_propagates_cross_block() { + // B0: v0 = Const(5); Jump(B1) + // B1: v1 = Assign(v0) + // + // v0 has exactly one definition (a Const) across the function. + // The global-constant seed therefore includes v0 → 5, so the Assign + // in B1 is folded to Const(5). 
+ let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(5), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Assign { + dest: VarId(1), + src: VarId(0), + }], + terminator: IrTerminator::Return { value: None }, + }, + ]); + eliminate(&mut func); + // The Assign in B1 should be folded: v0 is a single-definition global constant. + assert!( + matches!( + &func.blocks[1].instructions[0], + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(5) + } + ), + "Assign should be folded to Const(5) via global constant propagation; got: {:?}", + &func.blocks[1].instructions[0] + ); + } + + #[test] + fn multi_def_var_not_treated_as_global_const() { + // v0 is defined twice (Const in B0, Assign in B1) — must NOT be in + // global_const_map, so v1 = Assign(v0) in B2 is NOT folded. + // + // B0: v0 = Const(5); Jump(B1) + // B1: v0 = Assign(v2); Jump(B2) ← second def of v0 + // B2: v1 = Assign(v0) ← must remain + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(5), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Assign { + dest: VarId(0), // second definition of v0 + src: VarId(2), + }], + terminator: IrTerminator::Jump { target: BlockId(2) }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![IrInstr::Assign { + dest: VarId(1), + src: VarId(0), + }], + terminator: IrTerminator::Return { value: None }, + }, + ]); + eliminate(&mut func); + // v0 has 2 definitions → NOT a global constant → Assign in B2 must remain. + assert!( + matches!(&func.blocks[2].instructions[0], IrInstr::Assign { .. 
}), + "Assign must remain — v0 has multiple definitions" + ); + } + + // ── try_eval_binop unit tests ─────────────────────────────────────── + + #[test] + fn binop_i32_unsigned_comparison() { + // -1 as u32 > 0 as u32 + assert_eq!( + try_eval_binop(BinOp::I32GtU, IrValue::I32(-1), IrValue::I32(0)), + Some(IrValue::I32(1)) + ); + } + + #[test] + fn binop_i32_rotl() { + assert_eq!( + try_eval_binop( + BinOp::I32Rotl, + IrValue::I32(0x8000_0001_u32 as i32), + IrValue::I32(1) + ), + Some(IrValue::I32(3)) + ); + } + + #[test] + fn binop_i64_div_zero() { + assert_eq!( + try_eval_binop(BinOp::I64DivS, IrValue::I64(10), IrValue::I64(0)), + None + ); + } + + #[test] + fn binop_i64_div_signed_overflow() { + assert_eq!( + try_eval_binop(BinOp::I64DivS, IrValue::I64(i64::MIN), IrValue::I64(-1)), + None + ); + } + + #[test] + fn binop_i64_rem_s_min_neg_one() { + assert_eq!( + try_eval_binop(BinOp::I64RemS, IrValue::I64(i64::MIN), IrValue::I64(-1)), + Some(IrValue::I64(0)) + ); + } + + #[test] + fn binop_type_mismatch_returns_none() { + assert_eq!( + try_eval_binop(BinOp::I32Add, IrValue::I32(1), IrValue::I64(2)), + None + ); + } + + // ── try_eval_unop unit tests ──────────────────────────────────────── + + #[test] + fn unop_i64_extend_s() { + assert_eq!( + try_eval_unop(UnOp::I64ExtendI32S, IrValue::I32(-1)), + Some(IrValue::I64(-1)) + ); + } + + #[test] + fn unop_i64_extend_u() { + assert_eq!( + try_eval_unop(UnOp::I64ExtendI32U, IrValue::I32(-1)), + Some(IrValue::I64(0xFFFF_FFFF)) + ); + } + + #[test] + fn unop_f32_neg() { + match try_eval_unop(UnOp::F32Neg, IrValue::F32(1.5)) { + Some(IrValue::F32(v)) => assert!((v - (-1.5)).abs() < f32::EPSILON), + other => panic!("expected F32(-1.5), got {other:?}"), + } + } + + #[test] + fn unop_type_mismatch() { + assert_eq!(try_eval_unop(UnOp::I32Clz, IrValue::I64(1)), None); + } + + // ── Wasm nearest ──────────────────────────────────────────────────── + + #[test] + fn nearest_f32_bankers_rounding() { + 
assert_eq!(wasm_nearest_f32(0.5), 0.0); // round to even (0) + assert_eq!(wasm_nearest_f32(1.5), 2.0); // round to even (2) + assert_eq!(wasm_nearest_f32(2.5), 2.0); // round to even (2) + assert_eq!(wasm_nearest_f32(3.5), 4.0); // round to even (4) + } + + #[test] + fn nearest_f64_bankers_rounding() { + assert_eq!(wasm_nearest_f64(0.5), 0.0); + assert_eq!(wasm_nearest_f64(1.5), 2.0); + assert_eq!(wasm_nearest_f64(2.5), 2.0); + } + + // ── Valid i32 div/rem that SHOULD fold ─────────────────────────────── + + #[test] + fn i32_div_s_valid_folds() { + assert_eq!( + try_eval_binop(BinOp::I32DivS, IrValue::I32(10), IrValue::I32(3)), + Some(IrValue::I32(3)) + ); + } + + #[test] + fn i32_div_u_valid_folds() { + // -1 as u32 = u32::MAX; u32::MAX / 2 = 2147483647 + assert_eq!( + try_eval_binop(BinOp::I32DivU, IrValue::I32(-1), IrValue::I32(2)), + Some(IrValue::I32(2147483647)) + ); + } + + #[test] + fn i32_rem_u_valid_folds() { + assert_eq!( + try_eval_binop(BinOp::I32RemU, IrValue::I32(10), IrValue::I32(3)), + Some(IrValue::I32(1)) + ); + } + + // ── Wasm f32/f64 min/max NaN semantics ────────────────────────────── + + #[test] + fn f32_min_nan_propagates() { + // Wasm: f32.min(NaN, 1.0) = NaN (Rust's f32::min returns 1.0 — wrong) + let result = wasm_min_f32(f32::NAN, 1.0); + assert!(result.is_nan(), "f32.min(NaN, 1.0) must return NaN"); + + let result = wasm_min_f32(1.0, f32::NAN); + assert!(result.is_nan(), "f32.min(1.0, NaN) must return NaN"); + } + + #[test] + fn f32_max_nan_propagates() { + let result = wasm_max_f32(f32::NAN, 1.0); + assert!(result.is_nan(), "f32.max(NaN, 1.0) must return NaN"); + + let result = wasm_max_f32(1.0, f32::NAN); + assert!(result.is_nan(), "f32.max(1.0, NaN) must return NaN"); + } + + #[test] + fn f32_min_negative_zero() { + // Wasm: f32.min(-0.0, +0.0) = -0.0 + let result = wasm_min_f32(-0.0f32, 0.0f32); + assert!( + result.is_sign_negative(), + "f32.min(-0.0, +0.0) must return -0.0" + ); + + let result = wasm_min_f32(0.0f32, -0.0f32); + 
assert!( + result.is_sign_negative(), + "f32.min(+0.0, -0.0) must return -0.0" + ); + } + + #[test] + fn f32_max_positive_zero() { + // Wasm: f32.max(-0.0, +0.0) = +0.0 + let result = wasm_max_f32(-0.0f32, 0.0f32); + assert!( + result.is_sign_positive(), + "f32.max(-0.0, +0.0) must return +0.0" + ); + } + + #[test] + fn f64_min_nan_propagates() { + let result = wasm_min_f64(f64::NAN, 1.0); + assert!(result.is_nan(), "f64.min(NaN, 1.0) must return NaN"); + } + + #[test] + fn f64_max_nan_propagates() { + let result = wasm_max_f64(f64::NAN, 1.0); + assert!(result.is_nan(), "f64.max(NaN, 1.0) must return NaN"); + } + + #[test] + fn f32_min_max_normal_values() { + assert_eq!(wasm_min_f32(1.0, 2.0), 1.0); + assert_eq!(wasm_max_f32(1.0, 2.0), 2.0); + assert_eq!(wasm_min_f64(3.0, 4.0), 3.0); + assert_eq!(wasm_max_f64(3.0, 4.0), 4.0); + } +} diff --git a/crates/herkos/src/optimizer/copy_prop.rs b/crates/herkos/src/optimizer/copy_prop.rs new file mode 100644 index 0000000..d175275 --- /dev/null +++ b/crates/herkos/src/optimizer/copy_prop.rs @@ -0,0 +1,1347 @@ +//! Copy propagation: backward coalescing and forward substitution. +//! +//! ## Backward pass — single-use Assign coalescing +//! +//! When an instruction I_def defines variable `v_src`, and `v_src` is used +//! exactly once in the same block — by `Assign { dest: v_dst, src: v_src }` — +//! and `v_dst` is neither read nor written in instructions between I_def and the +//! Assign, we can: +//! +//! 1. Change I_def to write directly to `v_dst` (eliminating the copy-through `v_src`). +//! 2. Remove the `Assign` instruction. +//! +//! This eliminates the single-use temporaries that arise from Wasm's stack-based +//! evaluation model. For example: +//! +//! ```text +//! v7 = Const(2) → v1 = Const(2) +//! v1 = Assign(v7) → (removed) +//! +//! v16 = v4.add(v3) → v5 = v4.add(v3) +//! v5 = Assign(v16) → (removed) +//! ``` +//! +//! ## Forward pass — Assign substitution +//! +//! 
When `v_dst = Assign(v_src)` and all uses of `v_dst` are within the same +//! block after the Assign, and `v_src` is not redefined before those uses, +//! replace every read of `v_dst` with `v_src` and remove the Assign. +//! +//! This eliminates the `local.get` temporaries that Wasm emits when reading a +//! parameter or loop variable: +//! +//! ```text +//! v20 = Assign(v1) → (removed) +//! v24 = v20.wrapping_add(v23) → v24 = v1.wrapping_add(v23) +//! ``` +//! +//! ## Fixpoint and ordering +//! +//! Both passes run to fixpoint. The backward pass runs first (it creates no +//! new forward opportunities), then the forward pass runs. After both passes +//! settle, dead variables are pruned from `IrFunction::locals`. + +use super::utils::{ + build_global_use_count, count_uses_of, count_uses_of_terminator, for_each_use, instr_dest, + prune_dead_locals, replace_uses_of, replace_uses_of_terminator, set_instr_dest, +}; +use crate::ir::{IrBlock, IrFunction, IrInstr, VarId}; +use std::collections::HashMap; + +// ── Public entry point ──────────────────────────────────────────────────────── + +/// Eliminate single-use Assign copies and prune now-dead locals. +pub fn eliminate(func: &mut IrFunction) { + // ── Global (cross-block) copy propagation ──────────────────────────────── + // + // In SSA form every `Assign { dest, src }` is a global fact: `dest` is + // defined exactly once and always equals `src`. We collect all Assigns, + // build a substitution map chasing chains to their root, and rewrite every + // use across the entire function. The now-dead Assign instructions are + // left for dead_instrs to clean up (or the backward pass below). + global_copy_prop(func); + + // ── Backward pass: redirect I_def dest through single-use Assigns ──────── + // + // We rebuild the global use-count map before each round because a successful + // coalescing removes one Assign (changing use counts), so the previous map is + // stale. 
We break out of the per-block scan as soon as any block changes and + // restart the outer loop so we always work from a fresh global count. + loop { + let global_uses = build_global_use_count(func); + let mut any_changed = false; + for block in &mut func.blocks { + if coalesce_one(block, &global_uses) { + any_changed = true; + break; // global_uses is now stale; rebuild before continuing + } + } + if !any_changed { + break; + } + } + + // ── Forward pass: substitute v_src for v_dst at each use site ──────────── + // + // Runs after the backward pass because backward creates no new forward + // opportunities (it only removes Assigns, never adds them). The forward + // pass eliminates the `local.get` snapshots that Wasm emits when reading + // parameters or loop-carried variables, e.g.: + // + // v20 = Assign(v1) → (removed) + // v24 = v20.wrapping_add(v23) → v24 = v1.wrapping_add(v23) + loop { + let global_uses = build_global_use_count(func); + let mut any_changed = false; + for block in &mut func.blocks { + if forward_propagate_one(block, &global_uses) { + any_changed = true; + break; + } + } + if !any_changed { + break; + } + } + + // Prune locals that are no longer referenced anywhere. + prune_dead_locals(func); +} + +// ── Global (cross-block) copy propagation ───────────────────────────────────── + +/// Replaces every use of an Assign's `dest` with its `src`, chasing chains. +/// +/// In SSA form, `Assign { dest, src }` means `dest == src` globally. We +/// collect all such pairs, resolve transitive chains (e.g. `v25 → v24 → v19` +/// stops at `v19` if `v19` is not itself an Assign dest), and rewrite all +/// variable reads across every block. +fn global_copy_prop(func: &mut IrFunction) { + // Step 0: count definitions per variable. Only variables defined exactly + // once (true SSA) are safe for global substitution. Function parameters + // count as one definition each. 
+ let mut def_count: HashMap<VarId, usize> = HashMap::new();
+ for (param_var, _) in &func.params {
+ *def_count.entry(*param_var).or_insert(0) += 1;
+ }
+ for block in &func.blocks {
+ for instr in &block.instructions {
+ if let Some(dest) = instr_dest(instr) {
+ *def_count.entry(dest).or_insert(0) += 1;
+ }
+ }
+ }
+
+ // Step 1: collect Assign { dest, src } pairs where dest has exactly one def.
+ let mut copy_map: HashMap<VarId, VarId> = HashMap::new();
+ for block in &func.blocks {
+ for instr in &block.instructions {
+ if let IrInstr::Assign { dest, src } = instr {
+ // Both dest and src must have at most one definition.
+ // dest must have exactly 1 (this Assign). src must have at
+ // most 1 — parameters count as one definition (Step 0), so a
+ // count of 1 covers both a parameter and a single defining
+ // instruction. If src had multiple definitions, replacing
+ // uses of dest with src could pick up a wrong definition.
+ let dest_ok = def_count.get(dest).copied() == Some(1);
+ let src_ok = def_count.get(src).copied().unwrap_or(0) <= 1;
+ if dest != src && dest_ok && src_ok {
+ copy_map.insert(*dest, *src);
+ }
+ }
+ }
+ }
+
+ if copy_map.is_empty() {
+ return;
+ }
+
+ // Step 2: chase chains to find the root for each key.
+ // E.g. if v25 → v24 and v24 → v19, then v25's root is v19.
+ let resolved: HashMap<VarId, VarId> = copy_map
+ .keys()
+ .map(|&var| {
+ let mut root = var;
+ // Follow the chain with a depth limit to avoid infinite loops
+ // (shouldn't happen in well-formed SSA, but defensive).
+ let mut steps = 0;
+ while let Some(&next) = copy_map.get(&root) {
+ root = next;
+ steps += 1;
+ if steps > copy_map.len() {
+ break; // cycle guard
+ }
+ }
+ (var, root)
+ })
+ .filter(|(var, root)| var != root)
+ .collect();
+
+ if resolved.is_empty() {
+ return;
+ }
+
+ // Step 3: rewrite all uses across the entire function.
+ for block in &mut func.blocks { + for instr in &mut block.instructions { + for (&old, &new) in &resolved { + replace_uses_of(instr, old, new); + } + } + for (&old, &new) in &resolved { + replace_uses_of_terminator(&mut block.terminator, old, new); + } + } + + // Remove Assign instructions whose dest was resolved (they're now dead: + // all uses of their dest have been rewritten to the root). + for block in &mut func.blocks { + block.instructions.retain(|instr| { + if let IrInstr::Assign { dest, .. } = instr { + !resolved.contains_key(dest) + } else { + true + } + }); + } +} + +// ── Core coalescing logic ───────────────────────────────────────────────────── + +/// Tries to perform a single Assign coalescing in `block`. +/// Returns `true` if a coalescing was performed. +/// +/// `global_uses` is the function-wide read-count for every variable (built by +/// `build_global_use_count`). We use it — rather than a per-block count — for +/// the single-use check so that variables read in *other* blocks are not +/// incorrectly considered single-use. +fn coalesce_one(block: &mut IrBlock, global_uses: &HashMap) -> bool { + // ── Step 1: build per-block def-site map ───────────────────────────── + let mut def_site: HashMap = HashMap::new(); // var → instruction index + + for (i, instr) in block.instructions.iter().enumerate() { + if let Some(dest) = instr_dest(instr) { + def_site.insert(dest, i); + } + } + + // ── Step 2: find a coalesceable Assign ─────────────────────────────── + for assign_idx in 0..block.instructions.len() { + let (v_dst, v_src) = match &block.instructions[assign_idx] { + IrInstr::Assign { dest, src } => (*dest, *src), + _ => continue, + }; + + // Skip self-assignments (v_dst = v_src where they are the same). + if v_dst == v_src { + // Self-assignment: just remove it. + block.instructions.remove(assign_idx); + return true; + } + + // v_src must have exactly one use *globally* (this Assign). 
+ // + // Using the global count (not a per-block count) is the key safety + // invariant: if v_src is also read in another block, coalescing would + // redirect v_src's definition to write v_dst instead, leaving v_src + // undefined for those other-block reads. + if global_uses.get(&v_src).copied().unwrap_or(0) != 1 { + continue; + } + + // v_src must be defined by an instruction in this block. + let def_idx = match def_site.get(&v_src) { + Some(&i) => i, + None => continue, // v_src is not defined in this block, so can't coalesce + }; + + // The def must precede the Assign. + // In strict SSA form each variable is defined exactly once, so this check + // is always satisfied — but kept as a safety guard. + if def_idx >= assign_idx { + continue; + } + + // Safety check: v_dst must not be read or written in the instructions + // strictly between def_idx and assign_idx. + // + // Rationale: after redirect, I_def writes to v_dst at def_idx. Any + // intervening read would see the new value instead of the old one (wrong). + // Any intervening write would clobber the value before it can be used (wrong). + // + // In strict SSA form v_dst has exactly one definition (this Assign), so it + // cannot be written between the two indices. It also cannot be read before its + // definition (the Assign), so the check is effectively a no-op. Kept as a + // guard against any future relaxation of the invariant. 
+ let conflict = block.instructions[def_idx + 1..assign_idx].iter().any(|i| { + let mut found = false; + for_each_use(i, |v| { + if v == v_dst { + found = true; + } + }); + if instr_dest(i) == Some(v_dst) { + found = true; + } + found + }); + if conflict { + continue; + } + + // ── Safe: perform the redirect ──────────────────────────────────── + set_instr_dest(&mut block.instructions[def_idx], v_dst); + block.instructions.remove(assign_idx); + return true; + } + + false +} + +// ── Forward propagation ─────────────────────────────────────────────────────── + +/// Tries to perform a single forward substitution in `block`. +/// Returns `true` if any substitution was performed. +/// +/// For each `Assign { dest: v_dst, src: v_src }` at position `assign_idx`: +/// +/// 1. All global reads of `v_dst` must occur within this block, strictly after +/// `assign_idx` (ensures no cross-block uses, no pre-Assign uses in this block). +/// 2. `v_dst` must not be redefined after `assign_idx` within this block (avoids +/// incorrectly replacing uses that read a later definition). +/// 3. `v_src` must not be redefined between `assign_idx` (exclusive) and the +/// last use of `v_dst` (exclusive), preserving the value the Assign captured. +/// +/// When all conditions hold, every use of `v_dst` after `assign_idx` is replaced +/// by `v_src`, and the Assign is removed. +fn forward_propagate_one(block: &mut IrBlock, global_uses: &HashMap) -> bool { + for assign_idx in 0..block.instructions.len() { + let (v_dst, v_src) = match &block.instructions[assign_idx] { + IrInstr::Assign { dest, src } => (*dest, *src), + _ => continue, + }; + + // Self-assignments are handled by the backward pass. + if v_dst == v_src { + continue; + } + + // Count uses of v_dst in this block strictly after assign_idx. + let uses_after_instrs: usize = block.instructions[assign_idx + 1..] 
+ .iter() + .map(|i| count_uses_of(i, v_dst)) + .sum(); + let uses_in_term = count_uses_of_terminator(&block.terminator, v_dst); + let local_uses_after = uses_after_instrs + uses_in_term; + + // All global reads of v_dst must be accounted for by local_uses_after. + // Any excess means v_dst is used in another block, or before assign_idx + // in this block — both unsafe to substitute. + let global_count = global_uses.get(&v_dst).copied().unwrap_or(0); + if global_count != local_uses_after { + continue; + } + + // Nothing to do if v_dst is never read after the Assign. + if local_uses_after == 0 { + continue; + } + + // v_dst must not be redefined in instructions after assign_idx. + // If it were, uses past the redefinition would read a different value. + if block.instructions[assign_idx + 1..] + .iter() + .any(|i| instr_dest(i) == Some(v_dst)) + { + continue; + } + + // Determine the range of instructions in which v_src must remain stable. + // + // If the terminator reads v_dst, v_src must survive all instructions + // after assign_idx. Otherwise, only up to (but not including) the last + // instruction that reads v_dst: reads happen before the dest-write in + // the same instruction, so a same-position redefinition of v_src is safe. + let check_end = if uses_in_term > 0 { + block.instructions.len() + } else { + // last instruction index (0-based into the full block) that reads v_dst + block.instructions[assign_idx + 1..] + .iter() + .enumerate() + .filter(|(_, i)| count_uses_of(i, v_dst) > 0) + .map(|(rel, _)| assign_idx + 1 + rel) + .next_back() + .unwrap_or(assign_idx) // unreachable: local_uses_after > 0 + }; + + // Check v_src is not written in [assign_idx+1, check_end). + if block.instructions[assign_idx + 1..check_end] + .iter() + .any(|i| instr_dest(i) == Some(v_src)) + { + continue; + } + + // Safe: substitute v_src for every read of v_dst after assign_idx. + for instr in &mut block.instructions[assign_idx + 1..] 
{ + replace_uses_of(instr, v_dst, v_src); + } + replace_uses_of_terminator(&mut block.terminator, v_dst, v_src); + block.instructions.remove(assign_idx); + return true; + } + false +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ + BinOp, BlockId, IrBlock, IrFunction, IrTerminator, IrValue, TypeIdx, WasmType, + }; + + fn make_func(blocks: Vec) -> IrFunction { + IrFunction { + params: vec![], + locals: vec![], + blocks, + entry_block: BlockId(0), + return_type: None, + type_idx: TypeIdx::new(0), + needs_host: false, + } + } + + fn make_func_with_locals(blocks: Vec, locals: Vec<(VarId, WasmType)>) -> IrFunction { + IrFunction { + params: vec![], + locals, + blocks, + entry_block: BlockId(0), + return_type: None, + type_idx: TypeIdx::new(0), + needs_host: false, + } + } + + fn single_block(instrs: Vec, term: IrTerminator) -> Vec { + vec![IrBlock { + id: BlockId(0), + instructions: instrs, + terminator: term, + }] + } + + fn ret_none() -> IrTerminator { + IrTerminator::Return { value: None } + } + + // ── Basic: Const → Assign ───────────────────────────────────────────── + + #[test] + fn const_assign_coalesced() { + // v7 = Const(2); v1 = Assign(v7) + // Global copy prop: v1→v7, removes Assign. Result: v7 = Const(2). 
+ let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(7), + value: IrValue::I32(2), + }, + IrInstr::Assign { + dest: VarId(1), + src: VarId(7), + }, + ], + ret_none(), + )); + eliminate(&mut func); + let block = &func.blocks[0]; + assert_eq!(block.instructions.len(), 1, "Assign should be removed"); + match &block.instructions[0] { + IrInstr::Const { + dest, + value: IrValue::I32(2), + } => assert_eq!(*dest, VarId(7), "producer v7 survives"), + other => panic!("expected Const, got {other:?}"), + } + } + + // ── Basic: BinOp → Assign ───────────────────────────────────────────── + + #[test] + fn binop_assign_coalesced() { + // v16 = v4 + v3; v5 = Assign(v16) + // Global copy prop: v5→v16, removes Assign. Result: v16 = v4 + v3. + let mut func = make_func(single_block( + vec![ + IrInstr::BinOp { + dest: VarId(16), + op: BinOp::I32Add, + lhs: VarId(4), + rhs: VarId(3), + }, + IrInstr::Assign { + dest: VarId(5), + src: VarId(16), + }, + ], + ret_none(), + )); + eliminate(&mut func); + let block = &func.blocks[0]; + assert_eq!(block.instructions.len(), 1); + match &block.instructions[0] { + IrInstr::BinOp { dest, .. } => assert_eq!(*dest, VarId(16)), + other => panic!("expected BinOp, got {other:?}"), + } + } + + // ── Multi-use src: must NOT coalesce ───────────────────────────────── + + #[test] + fn multi_use_src_global_prop_removes_both() { + // v7 = Const(2); v1 = Assign(v7); v2 = Assign(v7) + // Global copy prop: v1→v7, v2→v7, removes both Assigns. + // (Backward pass can't coalesce because v7 has 2 uses, but global + // copy prop works by rewriting uses of v1/v2 to v7.) 
+ let mut func = make_func(single_block(
+ vec![
+ IrInstr::Const {
+ dest: VarId(7),
+ value: IrValue::I32(2),
+ },
+ IrInstr::Assign {
+ dest: VarId(1),
+ src: VarId(7),
+ },
+ IrInstr::Assign {
+ dest: VarId(2),
+ src: VarId(7),
+ },
+ ],
+ ret_none(),
+ ));
+ eliminate(&mut func);
+ assert_eq!(func.blocks[0].instructions.len(), 1);
+ match &func.blocks[0].instructions[0] {
+ IrInstr::Const { dest, .. } => assert_eq!(*dest, VarId(7)),
+ other => panic!("expected Const, got {other:?}"),
+ }
+ }
+
+ // ── Intervening read of v_dst: must NOT coalesce ──────────────────────
+
+ #[test]
+ fn intervening_read_of_dst_blocks_coalesce() {
+ // Non-SSA pattern: v5 is a parameter AND redefined by Assign.
+ // v16 = v4+v3; v8 = v5+v1 (reads param v5); v5 = Assign(v16)
+ // Backward coalescing v16→v5 blocked because v5 is read in between.
+ // Global copy prop skips v5 because it has 2 defs (param + Assign).
+ let mut func = IrFunction {
+ params: vec![(VarId(5), WasmType::I32)],
+ locals: vec![],
+ blocks: single_block(
+ vec![
+ IrInstr::BinOp {
+ dest: VarId(16),
+ op: BinOp::I32Add,
+ lhs: VarId(4),
+ rhs: VarId(3),
+ },
+ IrInstr::BinOp {
+ dest: VarId(8),
+ op: BinOp::I32Add,
+ lhs: VarId(5),
+ rhs: VarId(1),
+ },
+ IrInstr::Assign {
+ dest: VarId(5),
+ src: VarId(16),
+ },
+ ],
+ ret_none(),
+ ),
+ entry_block: BlockId(0),
+ return_type: None,
+ type_idx: TypeIdx::new(0),
+ needs_host: false,
+ };
+ eliminate(&mut func);
+ // Coalescing v16→v5 is blocked because v5 is read between def(v16) and Assign.
+ assert_eq!(func.blocks[0].instructions.len(), 3);
+ }
+
+ // ── Intervening write of v_dst: must NOT coalesce ─────────────────────
+
+ #[test]
+ fn intervening_write_of_dst_blocks_coalesce() {
+ // v99 = v4+v3; v4 = Assign(v0) [write to v4 in between]; v4 = Assign(v99)
+ // v4 is written between def(v99) and the Assign → conflict.
+ let mut func = make_func(single_block( + vec![ + IrInstr::BinOp { + dest: VarId(99), // temp + op: BinOp::I32Add, + lhs: VarId(4), + rhs: VarId(3), + }, + IrInstr::Assign { + // writes v4 (= v_dst of next Assign) + dest: VarId(4), + src: VarId(0), + }, + IrInstr::Assign { + dest: VarId(4), + src: VarId(99), + }, + ], + ret_none(), + )); + eliminate(&mut func); + // Coalescing v99→v4 is blocked because v4 is written between def(v99) and Assign. + assert_eq!(func.blocks[0].instructions.len(), 3); + } + + // ── Self-assignment removal ─────────────────────────────────────────── + + #[test] + fn self_assign_removed() { + // v1 = Assign(v1) is a no-op and should be removed. + let mut func = make_func(single_block( + vec![IrInstr::Assign { + dest: VarId(1), + src: VarId(1), + }], + ret_none(), + )); + eliminate(&mut func); + assert_eq!(func.blocks[0].instructions.len(), 0); + } + + // ── Chain coalescing ────────────────────────────────────────────────── + + #[test] + fn chain_coalesced() { + // v7 = Const(2); v10 = Assign(v7); v1 = Assign(v10) + // Global copy prop: v10→v7, v1→v10→v7. Both Assigns removed. + // Result: v7 = Const(2). + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(7), + value: IrValue::I32(2), + }, + IrInstr::Assign { + dest: VarId(10), + src: VarId(7), + }, + IrInstr::Assign { + dest: VarId(1), + src: VarId(10), + }, + ], + ret_none(), + )); + eliminate(&mut func); + assert_eq!(func.blocks[0].instructions.len(), 1); + match &func.blocks[0].instructions[0] { + IrInstr::Const { + dest, + value: IrValue::I32(2), + } => assert_eq!(*dest, VarId(7)), + other => panic!("expected Const(v7,2), got {other:?}"), + } + } + + // ── Dead local pruning ──────────────────────────────────────────────── + + #[test] + fn dead_local_pruned_after_coalesce() { + // v7 = Const(2); v1 = Assign(v7) + // Global copy prop: v1→v7, Assign removed. v7 survives, v1 is dead. 
+ let mut func = make_func_with_locals( + single_block( + vec![ + IrInstr::Const { + dest: VarId(7), + value: IrValue::I32(2), + }, + IrInstr::Assign { + dest: VarId(1), + src: VarId(7), + }, + ], + ret_none(), + ), + vec![(VarId(7), WasmType::I32), (VarId(1), WasmType::I32)], + ); + eliminate(&mut func); + // v1 should be pruned (its uses rewritten to v7); v7 survives. + assert!( + func.locals.iter().any(|(v, _)| *v == VarId(7)), + "v7 should remain in locals" + ); + assert!( + !func.locals.iter().any(|(v, _)| *v == VarId(1)), + "v1 should be pruned from locals" + ); + } + + // ── No-op: no Assigns → nothing changes ────────────────────────────── + + #[test] + fn no_assigns_unchanged() { + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }, + IrInstr::BinOp { + dest: VarId(1), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(0), + }, + ], + IrTerminator::Return { + value: Some(VarId(1)), + }, + )); + eliminate(&mut func); + assert_eq!(func.blocks[0].instructions.len(), 2); + } + + // ── Realistic fibo B0 pattern ───────────────────────────────────────── + + #[test] + fn fibo_b0_pattern() { + // v7 = Const(2); v1 = Assign(v7); v8 = Const(2); v9 = BinOp(v0, v8) + // Global copy prop: v1→v7, Assign removed. 3 instrs remain. + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(7), + value: IrValue::I32(2), + }, + IrInstr::Assign { + dest: VarId(1), + src: VarId(7), + }, + IrInstr::Const { + dest: VarId(8), + value: IrValue::I32(2), + }, + IrInstr::BinOp { + dest: VarId(9), + op: BinOp::I32LtS, + lhs: VarId(0), + rhs: VarId(8), + }, + ], + IrTerminator::BranchIf { + condition: VarId(9), + if_true: BlockId(1), + if_false: BlockId(2), + }, + )); + eliminate(&mut func); + let instrs = &func.blocks[0].instructions; + assert_eq!( + instrs.len(), + 3, + "only v7+Assign pair removed, leaving 3 instrs" + ); + // First instruction is v7 = Const(2) (producer survives). 
+ match &instrs[0] { + IrInstr::Const { + dest, + value: IrValue::I32(2), + } => assert_eq!(*dest, VarId(7)), + other => panic!("expected Const(v7,2), got {other:?}"), + } + } + + // ── Forward pass tests ──────────────────────────────────────────────── + + #[test] + fn forward_basic_param_snapshot() { + // v_dst = Assign(v_src); v_out = BinOp(v_dst, v1) → v_out = BinOp(v_src, v1) + let mut func = make_func(single_block( + vec![ + IrInstr::Assign { + dest: VarId(10), + src: VarId(0), // parameter — not defined by any in-block instruction + }, + IrInstr::BinOp { + dest: VarId(11), + op: BinOp::I32Add, + lhs: VarId(10), + rhs: VarId(1), + }, + ], + IrTerminator::Return { + value: Some(VarId(11)), + }, + )); + eliminate(&mut func); + let instrs = &func.blocks[0].instructions; + assert_eq!(instrs.len(), 1, "Assign should be removed"); + match &instrs[0] { + IrInstr::BinOp { lhs, .. } => assert_eq!(*lhs, VarId(0), "lhs should be v0"), + other => panic!("expected BinOp, got {other:?}"), + } + } + + #[test] + fn forward_multi_use_all_in_block() { + // v_dst = Assign(v_src); use(v_dst) twice in same block → both replaced + let mut func = make_func(single_block( + vec![ + IrInstr::Assign { + dest: VarId(10), + src: VarId(0), + }, + IrInstr::BinOp { + dest: VarId(11), + op: BinOp::I32Add, + lhs: VarId(10), + rhs: VarId(1), + }, + IrInstr::BinOp { + dest: VarId(12), + op: BinOp::I32Add, + lhs: VarId(10), + rhs: VarId(2), + }, + ], + ret_none(), + )); + eliminate(&mut func); + let instrs = &func.blocks[0].instructions; + assert_eq!(instrs.len(), 2, "Assign should be removed"); + // Both BinOps should now reference v0 directly + for instr in instrs { + match instr { + IrInstr::BinOp { lhs, .. 
} => assert_eq!(*lhs, VarId(0)), + other => panic!("expected BinOp, got {other:?}"), + } + } + } + + #[test] + fn forward_use_in_terminator() { + // v_dst = Assign(v_src); Return(v_dst) → Return(v_src) + let mut func = make_func(single_block( + vec![IrInstr::Assign { + dest: VarId(10), + src: VarId(0), + }], + IrTerminator::Return { + value: Some(VarId(10)), + }, + )); + eliminate(&mut func); + assert_eq!( + func.blocks[0].instructions.len(), + 0, + "Assign should be removed" + ); + match &func.blocks[0].terminator { + IrTerminator::Return { value: Some(v) } => assert_eq!(*v, VarId(0)), + other => panic!("expected Return(v0), got {other:?}"), + } + } + + #[test] + fn forward_cross_block_use_propagated() { + // v_dst used in another block → global copy prop replaces v10 with v0 + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Assign { + dest: VarId(10), + src: VarId(0), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::BinOp { + dest: VarId(11), + op: BinOp::I32Add, + lhs: VarId(10), // reads v10 from B0 — cross-block use + rhs: VarId(1), + }], + terminator: IrTerminator::Return { + value: Some(VarId(11)), + }, + }, + ]); + eliminate(&mut func); + // Global copy prop rewrites v10 → v0 in B1; then the Assign has no + // remaining uses and is removed by the forward pass. + match &func.blocks[1].instructions[0] { + IrInstr::BinOp { lhs, .. } => assert_eq!(*lhs, VarId(0), "lhs should be v0"), + other => panic!("expected BinOp, got {other:?}"), + } + } + + #[test] + fn forward_blocked_by_src_redef_before_last_use() { + // Non-SSA: v0 is a param AND redefined by BinOp (2 defs). + // v10 = Assign(v0); v0 = v0+v1; v11 = v10+v2 + // Global copy prop skips v10 because v0 (src) has >1 def. + // Forward pass also blocked: v0 redefined before last use of v10. 
+ let mut func = IrFunction { + params: vec![(VarId(0), WasmType::I32)], + locals: vec![], + blocks: single_block( + vec![ + IrInstr::Assign { + dest: VarId(10), + src: VarId(0), + }, + IrInstr::BinOp { + dest: VarId(0), // redefines v_src = v0 + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + IrInstr::BinOp { + dest: VarId(11), + op: BinOp::I32Add, + lhs: VarId(10), // last use of v_dst + rhs: VarId(2), + }, + ], + ret_none(), + ), + entry_block: BlockId(0), + return_type: None, + type_idx: TypeIdx::new(0), + needs_host: false, + }; + eliminate(&mut func); + // v10 = Assign(v0) must NOT be eliminated: v0 is redefined before v10's last use + assert_eq!(func.blocks[0].instructions.len(), 3); + } + + #[test] + fn forward_safe_when_src_redef_at_last_use() { + // v_dst = Assign(v_src) + // v_src = v_dst + 5 ← uses v_dst AND redefines v_src at the same position + // + // v_src is redefined at check_end (exclusive), not before it, so the + // substitution is safe: v_src = (old v_src) + 5. + let mut func = make_func(single_block( + vec![ + IrInstr::Assign { + dest: VarId(10), + src: VarId(0), + }, + IrInstr::BinOp { + dest: VarId(0), // redefines v0 (v_src) — but this is also the last use of v10 + op: BinOp::I32Add, + lhs: VarId(10), // reads v10 (v_dst) + rhs: VarId(5), + }, + ], + ret_none(), + )); + eliminate(&mut func); + // Assign should be removed; v0 = BinOp(v0, v5) is the result + let instrs = &func.blocks[0].instructions; + assert_eq!(instrs.len(), 1, "Assign should be removed"); + match &instrs[0] { + IrInstr::BinOp { dest, lhs, .. } => { + assert_eq!(*dest, VarId(0)); + assert_eq!(*lhs, VarId(0), "lhs should be v0 (substituted from v10)"); + } + other => panic!("expected BinOp, got {other:?}"), + } + } + + #[test] + fn forward_blocked_by_dst_redef() { + // v_dst = Assign(v_src) + // v_dst = BinOp(...) 
← redefines v_dst: later uses read the new value + // use(v_dst) + let mut func = make_func(single_block( + vec![ + IrInstr::Assign { + dest: VarId(10), + src: VarId(0), + }, + IrInstr::BinOp { + dest: VarId(10), // redefines v_dst + op: BinOp::I32Add, + lhs: VarId(1), + rhs: VarId(2), + }, + IrInstr::BinOp { + dest: VarId(11), + op: BinOp::I32Add, + lhs: VarId(10), + rhs: VarId(3), + }, + ], + ret_none(), + )); + eliminate(&mut func); + assert_eq!(func.blocks[0].instructions.len(), 3); + } + + #[test] + fn forward_fibo_b3_local_get_chain() { + // Mirrors the v16/v17 pattern from func_7 B3: + // v16 = Assign(v1) + // v17 = Assign(v0) + // v18 = BinOp(I32GeS, v16, v17) + // BranchIf(v18, B5, B4) + // + // After forward pass: v18 = BinOp(v1, v0), no Assigns. + let mut func = make_func(vec![IrBlock { + id: BlockId(0), + instructions: vec![ + IrInstr::Assign { + dest: VarId(16), + src: VarId(1), + }, + IrInstr::Assign { + dest: VarId(17), + src: VarId(0), + }, + IrInstr::BinOp { + dest: VarId(18), + op: BinOp::I32GeS, + lhs: VarId(16), + rhs: VarId(17), + }, + ], + terminator: IrTerminator::BranchIf { + condition: VarId(18), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }]); + eliminate(&mut func); + let instrs = &func.blocks[0].instructions; + assert_eq!(instrs.len(), 1, "both Assigns should be removed"); + match &instrs[0] { + IrInstr::BinOp { + lhs, + rhs, + op: BinOp::I32GeS, + .. + } => { + assert_eq!(*lhs, VarId(1), "lhs should be v1"); + assert_eq!(*rhs, VarId(0), "rhs should be v0"); + } + other => panic!("expected BinOp(I32GeS), got {other:?}"), + } + } + + #[test] + fn forward_fibo_b4_multi_snapshot() { + // Mirrors the v20/v21/v22/v25 pattern from func_7 B4. + // v20 = Assign(v1); v21 = Assign(v1); v22 = Assign(v0) + // v23 = BinOp(I32LtS, v21, v22) + // v24 = BinOp(I32Add, v20, v23) + // v25 = Assign(v0) + // v26 = BinOp(I32LeS, v24, v25) + // BranchIf(v26, ...) 
+ // + // After forward pass: + // v23 = BinOp(I32LtS, v1, v0) + // v24 = BinOp(I32Add, v1, v23) + // v26 = BinOp(I32LeS, v24, v0) + let mut func = make_func(vec![IrBlock { + id: BlockId(0), + instructions: vec![ + IrInstr::Assign { + dest: VarId(20), + src: VarId(1), + }, + IrInstr::Assign { + dest: VarId(21), + src: VarId(1), + }, + IrInstr::Assign { + dest: VarId(22), + src: VarId(0), + }, + IrInstr::BinOp { + dest: VarId(23), + op: BinOp::I32LtS, + lhs: VarId(21), + rhs: VarId(22), + }, + IrInstr::BinOp { + dest: VarId(24), + op: BinOp::I32Add, + lhs: VarId(20), + rhs: VarId(23), + }, + IrInstr::Assign { + dest: VarId(25), + src: VarId(0), + }, + IrInstr::BinOp { + dest: VarId(26), + op: BinOp::I32LeS, + lhs: VarId(24), + rhs: VarId(25), + }, + ], + terminator: IrTerminator::BranchIf { + condition: VarId(26), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }]); + eliminate(&mut func); + + let instrs = &func.blocks[0].instructions; + // Only the 3 BinOps remain; all 4 Assigns are gone. 
+ assert_eq!( + instrs.len(), + 3, + "all 4 Assigns should be removed, leaving 3 BinOps" + ); + + // v23 = BinOp(I32LtS, v1, v0) + match &instrs[0] { + IrInstr::BinOp { + dest, + op: BinOp::I32LtS, + lhs, + rhs, + } => { + assert_eq!(*dest, VarId(23)); + assert_eq!(*lhs, VarId(1)); + assert_eq!(*rhs, VarId(0)); + } + other => panic!("instrs[0]: expected BinOp(I32LtS, v1, v0), got {other:?}"), + } + + // v24 = BinOp(I32Add, v1, v23) + match &instrs[1] { + IrInstr::BinOp { + dest, + op: BinOp::I32Add, + lhs, + rhs, + } => { + assert_eq!(*dest, VarId(24)); + assert_eq!(*lhs, VarId(1)); + assert_eq!(*rhs, VarId(23)); + } + other => panic!("instrs[1]: expected BinOp(I32Add, v1, v23), got {other:?}"), + } + + // v26 = BinOp(I32LeS, v24, v0) + match &instrs[2] { + IrInstr::BinOp { + dest, + op: BinOp::I32LeS, + lhs, + rhs, + } => { + assert_eq!(*dest, VarId(26)); + assert_eq!(*lhs, VarId(24)); + assert_eq!(*rhs, VarId(0)); + } + other => panic!("instrs[2]: expected BinOp(I32LeS, v24, v0), got {other:?}"), + } + } + + // ── Cross-block copy propagation ──────────────────────────────────── + + #[test] + fn global_copy_prop_chain_across_blocks() { + // B0: v10 = Assign(v0) + // B1: v20 = Assign(v10) + // B2: Return(v20) + // + // Chain: v20 → v10 → v0. After global copy prop, B2 returns v0. 
+ let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Assign { + dest: VarId(10), + src: VarId(0), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Assign { + dest: VarId(20), + src: VarId(10), + }], + terminator: IrTerminator::Jump { target: BlockId(2) }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Return { + value: Some(VarId(20)), + }, + }, + ]); + eliminate(&mut func); + // v20 → v10 → v0; Return should use v0 + match &func.blocks[2].terminator { + IrTerminator::Return { value: Some(v) } => assert_eq!(*v, VarId(0)), + other => panic!("expected Return(v0), got {other:?}"), + } + } + + #[test] + fn global_copy_prop_multiple_uses_across_blocks() { + // B0: v10 = Assign(v0) + // B1: v11 = BinOp(v10, v10) — two uses of v10 in another block + // + // Both uses should be rewritten to v0. + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Assign { + dest: VarId(10), + src: VarId(0), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::BinOp { + dest: VarId(11), + op: BinOp::I32Add, + lhs: VarId(10), + rhs: VarId(10), + }], + terminator: IrTerminator::Return { + value: Some(VarId(11)), + }, + }, + ]); + eliminate(&mut func); + match &func.blocks[1].instructions[0] { + IrInstr::BinOp { lhs, rhs, .. } => { + assert_eq!(*lhs, VarId(0)); + assert_eq!(*rhs, VarId(0)); + } + other => panic!("expected BinOp, got {other:?}"), + } + } + + // ── Regression: cross-block v_src must NOT be coalesced ────────────── + // + // Bug: the old per-block use_count counted v_src uses only within the + // current block. 
 If v_src had exactly 1 use in the current block (the
+ // Assign) but was also read in another block, copy_prop would incorrectly
+ // coalesce, redirecting v_src's definition to v_dst and leaving v_src
+ // undefined in the other block. Functions like `lcm` and `isqrt` then
+ // returned 0 (the default initial value) instead of the correct result.
+
+ #[test]
+ fn cross_block_all_copies_resolved() {
+ // Block 0: v3 = Const(1); v0 = Assign(v3); v10 = Assign(v0)
+ // Block 1: v20 = Assign(v0); Return(v20)
+ //
+ // Global copy prop chains: v0→v3, v10→v0→v3, v20→v0→v3.
+ // All Assigns are removed, Return uses v3. v3 is still defined.
+ let mut func = make_func(vec![
+ IrBlock {
+ id: BlockId(0),
+ instructions: vec![
+ IrInstr::Const {
+ dest: VarId(3),
+ value: IrValue::I32(1),
+ },
+ IrInstr::Assign {
+ dest: VarId(0),
+ src: VarId(3),
+ },
+ IrInstr::Assign {
+ dest: VarId(10),
+ src: VarId(0),
+ },
+ ],
+ terminator: IrTerminator::Jump { target: BlockId(1) },
+ },
+ IrBlock {
+ id: BlockId(1),
+ instructions: vec![IrInstr::Assign {
+ dest: VarId(20),
+ src: VarId(0),
+ }],
+ terminator: IrTerminator::Return {
+ value: Some(VarId(20)),
+ },
+ },
+ ]);
+ eliminate(&mut func);
+
+ // v3 must still be defined in Block 0 (it's the root).
+ let b0_dests: Vec<VarId> = func.blocks[0]
+ .instructions
+ .iter()
+ .filter_map(instr_dest)
+ .collect();
+ assert!(
+ b0_dests.contains(&VarId(3)),
+ "v3 must still be defined in Block 0; got: {b0_dests:?}"
+ );
+ // Return should use v3 directly. 
+ match &func.blocks[1].terminator {
+ IrTerminator::Return { value: Some(v) } => assert_eq!(*v, VarId(3)),
+ other => panic!("expected Return(v3), got {other:?}"),
+ }
+ }
+
+ // ── Multi-block: each block is independent ────────────────────────────
+
+ #[test]
+ fn multi_block_each_coalesced_independently() {
+ // Block 0: v7 = Const(1); v1 = Assign(v7) → v1 = Const(1)
+ // Block 1: v8 = Const(2); v2 = Assign(v8) → v2 = Const(2)
+ let mut func = make_func(vec![
+ IrBlock {
+ id: BlockId(0),
+ instructions: vec![
+ IrInstr::Const {
+ dest: VarId(7),
+ value: IrValue::I32(1),
+ },
+ IrInstr::Assign {
+ dest: VarId(1),
+ src: VarId(7),
+ },
+ ],
+ terminator: IrTerminator::Jump { target: BlockId(1) },
+ },
+ IrBlock {
+ id: BlockId(1),
+ instructions: vec![
+ IrInstr::Const {
+ dest: VarId(8),
+ value: IrValue::I32(2),
+ },
+ IrInstr::Assign {
+ dest: VarId(2),
+ src: VarId(8),
+ },
+ ],
+ terminator: IrTerminator::Return { value: None },
+ },
+ ]);
+ eliminate(&mut func);
+ assert_eq!(func.blocks[0].instructions.len(), 1);
+ assert_eq!(func.blocks[1].instructions.len(), 1);
+ }
+}
diff --git a/crates/herkos/src/optimizer/dead_blocks.rs b/crates/herkos/src/optimizer/dead_blocks.rs
index 8db587d..1db178f 100644
--- a/crates/herkos/src/optimizer/dead_blocks.rs
+++ b/crates/herkos/src/optimizer/dead_blocks.rs
@@ -4,28 +4,11 @@
 //! arise naturally during IR translation when code follows an `unreachable` or
 //! `return` instruction inside a Wasm structured control flow construct.
 
-use crate::ir::{BlockId, IrBlock, IrFunction, IrTerminator};
+use super::utils::terminator_successors;
+use crate::ir::{BlockId, IrBlock, IrFunction};
 use anyhow::{bail, Result};
 use std::collections::{HashMap, HashSet};
 
-/// Returns the successor block IDs for a terminator.
-fn terminator_successors(term: &IrTerminator) -> Vec<BlockId> {
- match term {
- IrTerminator::Return { .. 
} | IrTerminator::Unreachable => vec![],
- IrTerminator::Jump { target } => vec![*target],
- IrTerminator::BranchIf {
- if_true, if_false, ..
- } => vec![*if_true, *if_false],
- IrTerminator::BranchTable {
- targets, default, ..
- } => targets
- .iter()
- .chain(std::iter::once(default))
- .copied()
- .collect(),
- }
-}
-
 /// Computes the set of block IDs reachable from the entry block via BFS.
 fn reachable_blocks(func: &IrFunction) -> Result<HashSet<BlockId>> {
 // Index blocks by ID for O(1) lookup during traversal.
@@ -61,7 +44,7 @@ pub fn eliminate(func: &mut IrFunction) -> Result<()> {
 #[cfg(test)]
 mod tests {
 use super::*;
- use crate::ir::{IrInstr, IrValue, TypeIdx, VarId, WasmType};
+ use crate::ir::{IrInstr, IrTerminator, IrValue, TypeIdx, VarId, WasmType};
 
 /// Build a minimal `IrFunction` with the given blocks.
 /// Entry block is always `BlockId(0)`.
diff --git a/crates/herkos/src/optimizer/dead_instrs.rs b/crates/herkos/src/optimizer/dead_instrs.rs
new file mode 100644
index 0000000..fd6dc5c
--- /dev/null
+++ b/crates/herkos/src/optimizer/dead_instrs.rs
@@ -0,0 +1,363 @@
+//! Dead instruction elimination.
+//!
+//! Removes instructions whose destination `VarId` has zero uses across the
+//! entire function and whose operation is side-effect-free.
+//!
+//! ## Algorithm
+//!
+//! 1. Build the global use-count map (`VarId → number of reads`).
+//! 2. For each instruction that produces a value (`instr_dest` returns `Some`):
+//! if the use count is zero **and** the instruction is side-effect-free,
+//! mark it for removal.
+//! 3. Remove all marked instructions.
+//! 4. Repeat to fixpoint — removing an instruction may make its operands'
+//! definitions unused.
+//! 5. Prune dead locals from `IrFunction::locals`.
+
+use super::utils::{build_global_use_count, instr_dest, is_side_effect_free, prune_dead_locals};
+use crate::ir::IrFunction;
+
+/// Run dead instruction elimination to fixpoint, then prune dead locals. 
+pub fn eliminate(func: &mut IrFunction) {
+ loop {
+ let uses = build_global_use_count(func);
+ let mut changed = false;
+
+ for block in &mut func.blocks {
+ block.instructions.retain(|instr| {
+ if let Some(dest) = instr_dest(instr) {
+ if uses.get(&dest).copied().unwrap_or(0) == 0 && is_side_effect_free(instr) {
+ changed = true;
+ return false; // remove
+ }
+ }
+ true // keep
+ });
+ }
+
+ if !changed {
+ break;
+ }
+ }
+
+ prune_dead_locals(func);
+}
+
+// ── Tests ────────────────────────────────────────────────────────────────────
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::ir::{
+ BinOp, BlockId, GlobalIdx, IrBlock, IrFunction, IrInstr, IrTerminator, IrValue,
+ MemoryAccessWidth, TypeIdx, VarId, WasmType,
+ };
+
+ fn make_func(blocks: Vec<IrBlock>) -> IrFunction {
+ IrFunction {
+ params: vec![],
+ locals: vec![],
+ blocks,
+ entry_block: BlockId(0),
+ return_type: None,
+ type_idx: TypeIdx::new(0),
+ needs_host: false,
+ }
+ }
+
+ fn make_func_with_locals(blocks: Vec<IrBlock>, locals: Vec<(VarId, WasmType)>) -> IrFunction {
+ IrFunction {
+ params: vec![],
+ locals,
+ blocks,
+ entry_block: BlockId(0),
+ return_type: None,
+ type_idx: TypeIdx::new(0),
+ needs_host: false,
+ }
+ }
+
+ fn single_block(instrs: Vec<IrInstr>, term: IrTerminator) -> Vec<IrBlock> {
+ vec![IrBlock {
+ id: BlockId(0),
+ instructions: instrs,
+ terminator: term,
+ }]
+ }
+
+ fn ret_none() -> IrTerminator {
+ IrTerminator::Return { value: None }
+ }
+
+ // ── Basic: unused side-effect-free instruction is removed ─────────────
+
+ #[test]
+ fn unused_const_removed() {
+ let mut func = make_func(single_block(
+ vec![IrInstr::Const {
+ dest: VarId(0),
+ value: IrValue::I32(42),
+ }],
+ ret_none(),
+ ));
+ eliminate(&mut func);
+ assert_eq!(func.blocks[0].instructions.len(), 0);
+ }
+
+ #[test]
+ fn unused_binop_removed() {
+ let mut func = make_func(single_block(
+ vec![
+ IrInstr::Const {
+ dest: VarId(0),
+ value: IrValue::I32(1),
+ },
+ IrInstr::Const {
+ dest: VarId(1),
+ value: IrValue::I32(2),
+ },
+ 
IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + ], + ret_none(), + )); + eliminate(&mut func); + // v2 unused → removed; then v0, v1 become unused → removed + assert_eq!(func.blocks[0].instructions.len(), 0); + } + + // ── Used instruction is kept ───────────────────────────────────────── + + #[test] + fn used_const_kept() { + let mut func = make_func(single_block( + vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }], + IrTerminator::Return { + value: Some(VarId(0)), + }, + )); + eliminate(&mut func); + assert_eq!(func.blocks[0].instructions.len(), 1); + } + + // ── Side-effectful instructions are kept even when unused ───────────── + + #[test] + fn unused_load_kept() { + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(0), + }, + IrInstr::Load { + dest: VarId(1), + ty: WasmType::I32, + addr: VarId(0), + offset: 0, + width: MemoryAccessWidth::Full, + sign: None, + }, + ], + ret_none(), + )); + eliminate(&mut func); + // Load may trap → kept; v0 is used by Load → kept + assert_eq!(func.blocks[0].instructions.len(), 2); + } + + #[test] + fn store_kept() { + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(0), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(99), + }, + IrInstr::Store { + ty: WasmType::I32, + addr: VarId(0), + value: VarId(1), + offset: 0, + width: MemoryAccessWidth::Full, + }, + ], + ret_none(), + )); + eliminate(&mut func); + // Store has side effects → kept; v0, v1 used by Store → kept + assert_eq!(func.blocks[0].instructions.len(), 3); + } + + // ── Fixpoint: cascading removal ────────────────────────────────────── + + #[test] + fn fixpoint_cascading_removal() { + // v0 = Const(1) + // v1 = Const(2) + // v2 = BinOp(v0, v1) ← only use of v0, v1 + // v3 = BinOp(v2, v2) ← only use of v2 + // Return(None) ← v3 unused + // + // Round 1: v3 unused → remove 
v3's BinOp + // Round 2: v2 unused → remove v2's BinOp + // Round 3: v0, v1 unused → remove both Consts + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(2), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Mul, + lhs: VarId(2), + rhs: VarId(2), + }, + ], + ret_none(), + )); + eliminate(&mut func); + assert_eq!(func.blocks[0].instructions.len(), 0); + } + + // ── Mixed: some dead, some live ────────────────────────────────────── + + #[test] + fn mixed_dead_and_live() { + // v0 = Const(1) ← used by Return + // v1 = Const(2) ← unused → dead + // v2 = BinOp(v1, v1) ← unused → dead + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(2), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(1), + rhs: VarId(1), + }, + ], + IrTerminator::Return { + value: Some(VarId(0)), + }, + )); + eliminate(&mut func); + assert_eq!(func.blocks[0].instructions.len(), 1); + match &func.blocks[0].instructions[0] { + IrInstr::Const { + dest, + value: IrValue::I32(1), + } => assert_eq!(*dest, VarId(0)), + other => panic!("expected Const(v0, 1), got {other:?}"), + } + } + + // ── Dead locals are pruned ─────────────────────────────────────────── + + #[test] + fn dead_locals_pruned() { + let mut func = make_func_with_locals( + single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(2), + }, + ], + IrTerminator::Return { + value: Some(VarId(0)), + }, + ), + vec![(VarId(0), WasmType::I32), (VarId(1), WasmType::I32)], + ); + eliminate(&mut func); + // v1 is dead → removed from instructions and locals + 
assert!(!func.locals.iter().any(|(v, _)| *v == VarId(1))); + assert!(func.locals.iter().any(|(v, _)| *v == VarId(0))); + } + + // ── Multi-block: dead in one, live in another ──────────────────────── + + #[test] + fn multi_block_cross_reference_kept() { + // Block 0: v0 = Const(1); Jump(B1) + // Block 1: Return(v0) + // v0 is used in B1 → must NOT be removed from B0 + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Return { + value: Some(VarId(0)), + }, + }, + ]); + eliminate(&mut func); + assert_eq!(func.blocks[0].instructions.len(), 1); + } + + // ── GlobalGet (side-effect-free) is removed when unused ────────────── + + #[test] + fn unused_global_get_removed() { + let mut func = make_func(single_block( + vec![IrInstr::GlobalGet { + dest: VarId(0), + index: GlobalIdx::new(0), + }], + ret_none(), + )); + eliminate(&mut func); + assert_eq!(func.blocks[0].instructions.len(), 0); + } + + // ── No-op: empty function ──────────────────────────────────────────── + + #[test] + fn empty_function_unchanged() { + let mut func = make_func(single_block(vec![], ret_none())); + eliminate(&mut func); + assert_eq!(func.blocks[0].instructions.len(), 0); + } +} diff --git a/crates/herkos/src/optimizer/empty_blocks.rs b/crates/herkos/src/optimizer/empty_blocks.rs new file mode 100644 index 0000000..ad39494 --- /dev/null +++ b/crates/herkos/src/optimizer/empty_blocks.rs @@ -0,0 +1,333 @@ +//! Empty block / passthrough elimination. +//! +//! A passthrough block contains no instructions and ends with an unconditional +//! `Jump`. All references to such a block can be replaced by references to its +//! ultimate target, eliminating the block entirely. +//! +//! Example (from the fibo transpilation): +//! B5: {} → Jump(B6) +//! 
B6: {} → Jump(B7) +//! +//! After this pass B4's branch-false target is rewritten from B5 to B7, and +//! B5/B6 become unreferenced dead blocks, removed by `dead_blocks::eliminate` +//! in the next pass. + +use crate::ir::{BlockId, IrFunction, IrTerminator}; +use std::collections::HashMap; + +/// Replace every reference to a passthrough block with its ultimate target. +/// +/// After this call, all passthrough blocks are unreferenced and will be +/// removed by the subsequent `dead_blocks::eliminate` pass. +pub fn eliminate(func: &mut IrFunction) { + // ── Step 1: Build the raw forwarding map ──────────────────────────── + // A block is a passthrough if it has no instructions and its terminator + // is an unconditional Jump. + let mut forward: HashMap = HashMap::new(); + for block in &func.blocks { + if block.instructions.is_empty() { + if let IrTerminator::Jump { target } = block.terminator { + forward.insert(block.id, target); + } + } + } + + if forward.is_empty() { + return; + } + + // ── Step 2: Resolve chains, cycle-safe ────────────────────────────── + // Collapse A → B → C chains into A → C. + // Bound hop count to func.blocks.len() to handle cycles (e.g. A→B→A). + let max_hops = func.blocks.len(); + let resolved: HashMap = forward + .keys() + .copied() + .map(|start| { + let mut cur = start; + for _ in 0..max_hops { + match forward.get(&cur) { + Some(&next) => cur = next, + None => break, + } + } + (start, cur) + }) + .collect(); + + // ── Step 3: Rewrite all terminator targets ─────────────────────────── + let fwd = |id: BlockId| resolved.get(&id).copied().unwrap_or(id); + + for block in &mut func.blocks { + match &mut block.terminator { + IrTerminator::Jump { target } => { + *target = fwd(*target); + } + IrTerminator::BranchIf { + if_true, if_false, .. + } => { + *if_true = fwd(*if_true); + *if_false = fwd(*if_false); + } + IrTerminator::BranchTable { + targets, default, .. 
+ } => {
+ for t in targets.iter_mut() {
+ *t = fwd(*t);
+ }
+ *default = fwd(*default);
+ }
+ IrTerminator::Return { .. } | IrTerminator::Unreachable => {}
+ }
+ }
+ // Passthrough blocks are now unreferenced; dead_blocks::eliminate will
+ // remove them in the next pass.
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::ir::{IrBlock, IrFunction, IrInstr, IrTerminator, IrValue, TypeIdx, VarId};
+
+ fn make_func(blocks: Vec<IrBlock>) -> IrFunction {
+ IrFunction {
+ params: vec![],
+ locals: vec![],
+ blocks,
+ entry_block: BlockId(0),
+ return_type: None,
+ type_idx: TypeIdx::new(0),
+ needs_host: false,
+ }
+ }
+
+ fn jump(id: u32, target: u32) -> IrBlock {
+ IrBlock {
+ id: BlockId(id),
+ instructions: vec![],
+ terminator: IrTerminator::Jump {
+ target: BlockId(target),
+ },
+ }
+ }
+
+ fn ret(id: u32) -> IrBlock {
+ IrBlock {
+ id: BlockId(id),
+ instructions: vec![],
+ terminator: IrTerminator::Return { value: None },
+ }
+ }
+
+ fn branch(id: u32, cond: u32, if_true: u32, if_false: u32) -> IrBlock {
+ IrBlock {
+ id: BlockId(id),
+ instructions: vec![],
+ terminator: IrTerminator::BranchIf {
+ condition: VarId(cond),
+ if_true: BlockId(if_true),
+ if_false: BlockId(if_false),
+ },
+ }
+ }
+
+ fn target_of(func: &IrFunction, id: u32) -> Option<BlockId> {
+ func.blocks
+ .iter()
+ .find(|b| b.id == BlockId(id))
+ .and_then(|b| match b.terminator {
+ IrTerminator::Jump { target } => Some(target),
+ _ => None,
+ })
+ }
+
+ fn branch_targets(func: &IrFunction, id: u32) -> Option<(BlockId, BlockId)> {
+ func.blocks
+ .iter()
+ .find(|b| b.id == BlockId(id))
+ .and_then(|b| match b.terminator {
+ IrTerminator::BranchIf {
+ if_true, if_false, .. 
+ } => Some((if_true, if_false)), + _ => None, + }) + } + + // ── Basic cases ────────────────────────────────────────────────────── + + #[test] + fn no_passthrough_unchanged() { + // B0: instr → Jump(B1), B1: Return — no passthrough, nothing changes + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + ret(1), + ]); + eliminate(&mut func); + assert_eq!(target_of(&func, 0), Some(BlockId(1))); + assert_eq!(func.blocks.len(), 2); + } + + #[test] + fn single_passthrough_redirected() { + // B0 → B1(pass) → B2: B0's target becomes B2 + let mut func = make_func(vec![jump(0, 1), jump(1, 2), ret(2)]); + eliminate(&mut func); + assert_eq!(target_of(&func, 0), Some(BlockId(2))); + } + + #[test] + fn chain_collapsed() { + // B0 → B1(pass) → B2(pass) → B3: B0's target becomes B3 + let mut func = make_func(vec![jump(0, 1), jump(1, 2), jump(2, 3), ret(3)]); + eliminate(&mut func); + assert_eq!(target_of(&func, 0), Some(BlockId(3))); + // B1 should also forward to B3 + assert_eq!(target_of(&func, 1), Some(BlockId(3))); + } + + // ── BranchIf ──────────────────────────────────────────────────────── + + #[test] + fn branch_if_both_arms_redirected() { + // B0: BranchIf(true→B1(pass)→B3, false→B2(pass)→B4) + let mut func = make_func(vec![ + branch(0, 0, 1, 2), + jump(1, 3), + jump(2, 4), + ret(3), + ret(4), + ]); + eliminate(&mut func); + let (t, f) = branch_targets(&func, 0).unwrap(); + assert_eq!(t, BlockId(3)); + assert_eq!(f, BlockId(4)); + } + + #[test] + fn branch_if_one_arm_redirected() { + // B0: BranchIf(true→B1(non-pass), false→B2(pass)→B3) + let mut func = make_func(vec![ + branch(0, 0, 1, 2), + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(42), + }], + terminator: IrTerminator::Return { value: None }, + }, + jump(2, 3), + ret(3), + ]); + eliminate(&mut 
func); + let (t, f) = branch_targets(&func, 0).unwrap(); + assert_eq!(t, BlockId(1)); // unchanged + assert_eq!(f, BlockId(3)); // forwarded + } + + // ── BranchTable ────────────────────────────────────────────────────── + + #[test] + fn branch_table_redirected() { + // B0: BranchTable(targets:[B1(pass)→B3, B2(pass)→B4], default:B5) + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::BranchTable { + index: VarId(0), + targets: vec![BlockId(1), BlockId(2)], + default: BlockId(5), + }, + }, + jump(1, 3), + jump(2, 4), + ret(3), + ret(4), + ret(5), + ]); + eliminate(&mut func); + let b = func.blocks.iter().find(|b| b.id == BlockId(0)).unwrap(); + match &b.terminator { + IrTerminator::BranchTable { + targets, default, .. + } => { + assert_eq!(targets[0], BlockId(3)); + assert_eq!(targets[1], BlockId(4)); + assert_eq!(*default, BlockId(5)); // non-passthrough, unchanged + } + _ => panic!("expected BranchTable"), + } + } + + // ── Edge cases ─────────────────────────────────────────────────────── + + #[test] + fn cycle_safe() { + // B0 → B1(pass) → B2(pass) → B1 (cycle) + // Should not infinite loop; B0 ends up pointing somewhere in the cycle + let mut func = make_func(vec![jump(0, 1), jump(1, 2), jump(2, 1)]); + // Must complete without hanging; exact target is unspecified for cycles + eliminate(&mut func); + } + + #[test] + fn entry_passthrough_not_removed() { + // Entry block B0 is itself a passthrough: B0(pass) → B1 → Return + // After pass B0's jump stays (it's a passthrough of a passthrough pointing at B1), + // dead_blocks won't remove B0 (it starts BFS from entry). + let mut func = make_func(vec![jump(0, 1), ret(1)]); + eliminate(&mut func); + // B0 is a passthrough pointing to B1; resolve(B0)=B1 but nobody *jumps to* B0, + // so B0's own terminator remains Jump(B1) (forwarded to itself, i.e. B1). 
+ assert_eq!(target_of(&func, 0), Some(BlockId(1))); + assert_eq!(func.blocks.len(), 2); // dead_blocks not called here, both still present + } + + // ── Realistic fibo pattern ──────────────────────────────────────────── + + #[test] + fn fibo_pattern() { + // Mirrors the B3/B4/B5/B6/B7 structure from func_7 (release build): + // B3: BranchIf(cond→B7, else→B4) + // B4: BranchIf(cond→B3, else→B5) + // B5: {} → Jump(B6) ← passthrough + // B6: {} → Jump(B7) ← passthrough + // B7: Return + // + // After eliminate(): + // B4's false-arm should be B7 (not B5) + // B3 and B7 are unchanged + let mut func = make_func(vec![ + branch(3, 0, 7, 4), + branch(4, 1, 3, 5), + jump(5, 6), // passthrough + jump(6, 7), // passthrough + ret(7), + ]); + func.entry_block = BlockId(3); + + eliminate(&mut func); + + // B4's false-arm: was B5, must now be B7 + let (true_arm, false_arm) = branch_targets(&func, 4).unwrap(); + assert_eq!(true_arm, BlockId(3)); // back-edge unchanged + assert_eq!(false_arm, BlockId(7)); // forwarded through B5→B6→B7 + + // B3's true-arm was already B7 — still B7 + let (t3, f3) = branch_targets(&func, 3).unwrap(); + assert_eq!(t3, BlockId(7)); + assert_eq!(f3, BlockId(4)); + + // B5 and B6 themselves now point to B7 (resolved) + assert_eq!(target_of(&func, 5), Some(BlockId(7))); + assert_eq!(target_of(&func, 6), Some(BlockId(7))); + } +} diff --git a/crates/herkos/src/optimizer/gvn.rs b/crates/herkos/src/optimizer/gvn.rs new file mode 100644 index 0000000..4d7605d --- /dev/null +++ b/crates/herkos/src/optimizer/gvn.rs @@ -0,0 +1,619 @@ +//! Global value numbering (GVN) — cross-block CSE using the dominator tree. +//! +//! Extends block-local CSE ([`super::local_cse`]) to work across basic blocks. +//! If block A dominates block B (every path to B passes through A), then any +//! pure computation defined in A with the same value key as one in B can be +//! reused in B instead of recomputing. +//! +//! ## Algorithm +//! +//! 1. 
Compute the immediate dominator of each block (Cooper/Harvey/Kennedy +//! iterative algorithm) to build the dominator tree. +//! 2. Walk the dominator tree in preorder using a scoped value-number table. +//! On entry to a block, push a new scope; on exit, pop it. +//! 3. For each pure instruction (`Const`, `BinOp`, `UnOp`) in the current +//! block, compute a value key. If the key already exists in any enclosing +//! scope (meaning it was computed in a dominating block), record a +//! replacement: `dest → first_var`. Otherwise insert the key into the +//! current scope. +//! 4. After the walk, rewrite all recorded destinations to +//! `Assign { dest, src: first_var }` and let copy-propagation clean up. +//! +//! **Only pure instructions are eligible.** Loads, calls, and memory ops are +//! never deduplicated (they may trap or have observable side effects). + +use super::utils::{build_predecessors, instr_dest, prune_dead_locals, terminator_successors}; +use crate::ir::{BinOp, BlockId, IrFunction, IrInstr, IrValue, UnOp, VarId}; +use std::collections::{HashMap, HashSet}; + +// ── Value key ──────────────────────────────────────────────────────────────── + +/// Hashable representation of a pure computation for deduplication. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +enum ValueKey { + Const(ConstKey), + BinOp { op: BinOp, lhs: VarId, rhs: VarId }, + UnOp { op: UnOp, operand: VarId }, +} + +/// Bit-level constant key that implements `Eq`/`Hash` correctly for floats. 
+#[derive(Debug, Clone, PartialEq, Eq, Hash)] +enum ConstKey { + I32(i32), + I64(i64), + F32(u32), + F64(u64), +} + +impl From for ConstKey { + fn from(v: IrValue) -> Self { + match v { + IrValue::I32(x) => ConstKey::I32(x), + IrValue::I64(x) => ConstKey::I64(x), + IrValue::F32(x) => ConstKey::F32(x.to_bits()), + IrValue::F64(x) => ConstKey::F64(x.to_bits()), + } + } +} + +fn is_commutative(op: &BinOp) -> bool { + matches!( + op, + BinOp::I32Add + | BinOp::I32Mul + | BinOp::I32And + | BinOp::I32Or + | BinOp::I32Xor + | BinOp::I32Eq + | BinOp::I32Ne + | BinOp::I64Add + | BinOp::I64Mul + | BinOp::I64And + | BinOp::I64Or + | BinOp::I64Xor + | BinOp::I64Eq + | BinOp::I64Ne + | BinOp::F32Add + | BinOp::F32Mul + | BinOp::F32Eq + | BinOp::F32Ne + | BinOp::F64Add + | BinOp::F64Mul + | BinOp::F64Eq + | BinOp::F64Ne + ) +} + +fn binop_key(op: BinOp, lhs: VarId, rhs: VarId) -> ValueKey { + let (lhs, rhs) = if is_commutative(&op) && lhs.0 > rhs.0 { + (rhs, lhs) + } else { + (lhs, rhs) + }; + ValueKey::BinOp { op, lhs, rhs } +} + +// ── Multi-definition detection ─────────────────────────────────────────────── + +/// Build the set of variables defined more than once across the function. +/// +/// After phi lowering the code is no longer in strict SSA form: loop phi +/// variables receive an initial assignment in the pre-loop block and a +/// back-edge update at the end of each iteration. These variables carry +/// different values at different program points, so any BinOp/UnOp that uses +/// them cannot be safely hoisted or deduplicated across blocks. +/// +/// `Const` instructions are always safe (they have no operands). 
+fn build_multi_def_vars(func: &IrFunction) -> HashSet { + let mut def_count: HashMap = HashMap::new(); + for block in &func.blocks { + for instr in &block.instructions { + if let Some(dest) = instr_dest(instr) { + *def_count.entry(dest).or_insert(0) += 1; + } + } + } + def_count + .into_iter() + .filter(|&(_, count)| count > 1) + .map(|(v, _)| v) + .collect() +} + +// ── Dominator tree ─────────────────────────────────────────────────────────── + +/// Compute the reverse-postorder traversal of the CFG starting from `entry`. +fn compute_rpo(func: &IrFunction) -> Vec { + let block_idx: HashMap = + func.blocks.iter().enumerate().map(|(i, b)| (b.id, i)).collect(); + + let mut visited = vec![false; func.blocks.len()]; + let mut postorder = Vec::with_capacity(func.blocks.len()); + + dfs_postorder(func, func.entry_block, &block_idx, &mut visited, &mut postorder); + + postorder.reverse(); + postorder +} + +fn dfs_postorder( + func: &IrFunction, + block_id: BlockId, + block_idx: &HashMap, + visited: &mut Vec, + postorder: &mut Vec, +) { + let idx = match block_idx.get(&block_id) { + Some(&i) => i, + None => return, + }; + if visited[idx] { + return; + } + visited[idx] = true; + + for succ in terminator_successors(&func.blocks[idx].terminator) { + dfs_postorder(func, succ, block_idx, visited, postorder); + } + postorder.push(block_id); +} + +/// Compute the immediate dominator of each block using Cooper/Harvey/Kennedy. +/// +/// Returns `idom[b] = immediate dominator of b`, with `idom[entry] = entry`. 
+fn compute_idoms(func: &IrFunction) -> HashMap { + let rpo = compute_rpo(func); + // rpo_num[b] = index in RPO order (entry = 0, smallest index = processed first) + let rpo_num: HashMap = + rpo.iter().enumerate().map(|(i, &b)| (b, i)).collect(); + + let preds = build_predecessors(func); + let entry = func.entry_block; + + let mut idom: HashMap = HashMap::new(); + idom.insert(entry, entry); + + let mut changed = true; + while changed { + changed = false; + // Process blocks in RPO order, skipping the entry. + for &b in rpo.iter().skip(1) { + let block_preds = &preds[&b]; + + // Start with the first predecessor that already has an idom assigned. + let mut new_idom = match block_preds + .iter() + .filter(|&&p| idom.contains_key(&p)) + .min_by_key(|&&p| rpo_num[&p]) + { + Some(&p) => p, + None => continue, // unreachable block — skip + }; + + // Intersect (walk up dom tree) with all other processed predecessors. + for &p in block_preds { + if p != new_idom && idom.contains_key(&p) { + new_idom = intersect(p, new_idom, &idom, &rpo_num); + } + } + + if idom.get(&b) != Some(&new_idom) { + idom.insert(b, new_idom); + changed = true; + } + } + } + + idom +} + +/// Walk up both fingers until they meet — the standard Cooper intersect. +fn intersect( + mut a: BlockId, + mut b: BlockId, + idom: &HashMap, + rpo_num: &HashMap, +) -> BlockId { + while a != b { + while rpo_num[&a] > rpo_num[&b] { + a = idom[&a]; + } + while rpo_num[&b] > rpo_num[&a] { + b = idom[&b]; + } + } + a +} + +/// Build dominator-tree children from the `idom` map. +fn build_dom_children( + idom: &HashMap, + entry: BlockId, +) -> HashMap> { + let mut children: HashMap> = HashMap::new(); + for (&b, &d) in idom { + if b != entry { + children.entry(d).or_default().push(b); + } + } + // Sort children for deterministic output. 
+ for v in children.values_mut() { + v.sort_unstable_by_key(|id| id.0); + } + children +} + +// ── GVN walk ───────────────────────────────────────────────────────────────── + +/// Recursively walk the dominator tree in preorder. +/// +/// `value_map` is a flat map that acts as a scoped table: on entry we insert +/// new keys (recording them in `frame_keys`), on exit we remove them, restoring +/// the parent scope. Any key already present in `value_map` when we visit a +/// block was computed in a dominating block — safe to reuse. +fn collect_replacements( + func: &IrFunction, + block_id: BlockId, + dom_children: &HashMap>, + block_idx: &HashMap, + multi_def_vars: &HashSet, + value_map: &mut HashMap, + replacements: &mut HashMap, +) { + let idx = match block_idx.get(&block_id) { + Some(&i) => i, + None => return, + }; + + let mut frame_keys: Vec = Vec::new(); + + for instr in &func.blocks[idx].instructions { + match instr { + IrInstr::Const { dest, value } => { + // A multiply-defined dest (loop phi var) must be skipped + // entirely: adding it to replacements would replace ALL of + // its definitions with Assign(first), clobbering back-edge + // updates; inserting it into value_map would let dominated + // blocks wrongly reuse a value that changes each iteration. + if multi_def_vars.contains(dest) { + continue; + } + let key = ValueKey::Const(ConstKey::from(*value)); + if let Some(&first) = value_map.get(&key) { + replacements.insert(*dest, first); + } else { + value_map.insert(key.clone(), *dest); + frame_keys.push(key); + } + } + + IrInstr::BinOp { dest, op, lhs, rhs, .. } => { + // Skip if dest is multiply-defined (same reason as Const). + // Also skip if any operand is multiply-defined: a loop phi + // var carries different values per iteration, so the same + // BinOp in two dominated blocks can produce different results. 
+ if multi_def_vars.contains(dest) + || multi_def_vars.contains(lhs) + || multi_def_vars.contains(rhs) + { + continue; + } + let key = binop_key(*op, *lhs, *rhs); + if let Some(&first) = value_map.get(&key) { + replacements.insert(*dest, first); + } else { + value_map.insert(key.clone(), *dest); + frame_keys.push(key); + } + } + + IrInstr::UnOp { dest, op, operand } => { + if multi_def_vars.contains(dest) || multi_def_vars.contains(operand) { + continue; + } + let key = ValueKey::UnOp { op: *op, operand: *operand }; + if let Some(&first) = value_map.get(&key) { + replacements.insert(*dest, first); + } else { + value_map.insert(key.clone(), *dest); + frame_keys.push(key); + } + } + + _ => {} + } + } + + // Recurse into dominated children. + if let Some(children) = dom_children.get(&block_id) { + for &child in children { + collect_replacements( + func, + child, + dom_children, + block_idx, + multi_def_vars, + value_map, + replacements, + ); + } + } + + // Pop this block's scope. + for key in frame_keys { + value_map.remove(&key); + } +} + +// ── Pass entry point ───────────────────────────────────────────────────────── + +/// Eliminates common subexpressions across basic blocks using the dominator tree. 
+pub fn eliminate(func: &mut IrFunction) { + if func.blocks.len() < 2 { + return; // nothing to do for single-block functions (local_cse covers those) + } + + let idom = compute_idoms(func); + let dom_children = build_dom_children(&idom, func.entry_block); + let block_idx: HashMap = + func.blocks.iter().enumerate().map(|(i, b)| (b.id, i)).collect(); + + let multi_def_vars = build_multi_def_vars(func); + let mut value_map: HashMap = HashMap::new(); + let mut replacements: HashMap = HashMap::new(); + + collect_replacements( + func, + func.entry_block, + &dom_children, + &block_idx, + &multi_def_vars, + &mut value_map, + &mut replacements, + ); + + if replacements.is_empty() { + return; + } + + for block in &mut func.blocks { + for instr in &mut block.instructions { + if let Some(dest) = instr_dest(instr) { + if let Some(&src) = replacements.get(&dest) { + *instr = IrInstr::Assign { dest, src }; + } + } + } + } + + prune_dead_locals(func); +} + +// ── Tests ──────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{IrBlock, IrTerminator, IrValue, TypeIdx, WasmType}; + + fn make_func(blocks: Vec) -> IrFunction { + IrFunction { + params: vec![], + locals: vec![], + blocks, + entry_block: BlockId(0), + return_type: None, + type_idx: TypeIdx::new(0), + needs_host: false, + } + } + + /// Entry (B0) → B1: const duplicated across the edge. + /// B0 dominates B1, so the duplicate in B1 should be replaced with Assign. 
+ #[test] + fn cross_block_const_deduplication() { + let b0 = IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { dest: VarId(0), value: IrValue::I32(42) }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }; + let b1 = IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Const { dest: VarId(1), value: IrValue::I32(42) }], + terminator: IrTerminator::Return { value: Some(VarId(1)) }, + }; + let mut func = make_func(vec![b0, b1]); + func.locals = vec![(VarId(0), WasmType::I32), (VarId(1), WasmType::I32)]; + + eliminate(&mut func); + + assert!( + matches!(func.blocks[0].instructions[0], IrInstr::Const { dest: VarId(0), .. }), + "first definition should stay as Const" + ); + assert!( + matches!( + func.blocks[1].instructions[0], + IrInstr::Assign { dest: VarId(1), src: VarId(0) } + ), + "dominated duplicate should become Assign" + ); + } + + /// Entry (B0) → B1: BinOp duplicated across the edge. + #[test] + fn cross_block_binop_deduplication() { + let b0 = IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }; + let b1 = IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::Return { value: Some(VarId(3)) }, + }; + let mut func = make_func(vec![b0, b1]); + func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; + + eliminate(&mut func); + + assert!(matches!(func.blocks[0].instructions[0], IrInstr::BinOp { .. })); + assert!( + matches!( + func.blocks[1].instructions[0], + IrInstr::Assign { dest: VarId(3), src: VarId(2) } + ), + "dominated duplicate BinOp should become Assign" + ); + } + + /// B0 branches to B1 and B2 (diamond). B1 and B2 don't dominate each other, + /// so a const in B1 should NOT eliminate the same const in B2. 
+ #[test] + fn sibling_blocks_not_deduplicated() { + // B0 → B1, B0 → B2, both converge to B3 + let b0 = IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::BranchIf { + condition: VarId(0), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }; + let b1 = IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Const { dest: VarId(1), value: IrValue::I32(7) }], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }; + let b2 = IrBlock { + id: BlockId(2), + instructions: vec![IrInstr::Const { dest: VarId(2), value: IrValue::I32(7) }], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }; + let b3 = IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }; + let mut func = make_func(vec![b0, b1, b2, b3]); + func.locals = vec![(VarId(1), WasmType::I32), (VarId(2), WasmType::I32)]; + + eliminate(&mut func); + + // Both consts should remain — neither block dominates the other. + assert!( + matches!(func.blocks[1].instructions[0], IrInstr::Const { dest: VarId(1), .. }), + "const in B1 must not be eliminated" + ); + assert!( + matches!(func.blocks[2].instructions[0], IrInstr::Const { dest: VarId(2), .. }), + "const in B2 must not be eliminated" + ); + } + + /// A const defined in B0 (entry) should be reused in a deeply dominated block. 
+ #[test] + fn deep_domination_chain() { + // B0 → B1 → B2: const defined in B0, duplicated in B2 + let b0 = IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { dest: VarId(0), value: IrValue::I32(99) }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }; + let b1 = IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(2) }, + }; + let b2 = IrBlock { + id: BlockId(2), + instructions: vec![IrInstr::Const { dest: VarId(1), value: IrValue::I32(99) }], + terminator: IrTerminator::Return { value: Some(VarId(1)) }, + }; + let mut func = make_func(vec![b0, b1, b2]); + func.locals = vec![(VarId(0), WasmType::I32), (VarId(1), WasmType::I32)]; + + eliminate(&mut func); + + assert!( + matches!(func.blocks[0].instructions[0], IrInstr::Const { dest: VarId(0), .. }) + ); + assert!( + matches!( + func.blocks[2].instructions[0], + IrInstr::Assign { dest: VarId(1), src: VarId(0) } + ), + "deeply dominated duplicate should be eliminated" + ); + } + + /// Commutative BinOps with swapped operands in a dominated block should be deduped. 
+ #[test] + fn cross_block_commutative_deduplication() { + let b0 = IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Mul, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }; + let b1 = IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Mul, + lhs: VarId(1), // swapped + rhs: VarId(0), + }], + terminator: IrTerminator::Return { value: Some(VarId(3)) }, + }; + let mut func = make_func(vec![b0, b1]); + func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; + + eliminate(&mut func); + + assert!( + matches!( + func.blocks[1].instructions[0], + IrInstr::Assign { dest: VarId(3), src: VarId(2) } + ), + "commutative cross-block BinOp should be deduplicated" + ); + } + + /// Single-block functions are skipped entirely (handled by local_cse). + #[test] + fn single_block_function_unchanged() { + let b0 = IrBlock { + id: BlockId(0), + instructions: vec![ + IrInstr::Const { dest: VarId(0), value: IrValue::I32(1) }, + IrInstr::Const { dest: VarId(1), value: IrValue::I32(1) }, + ], + terminator: IrTerminator::Return { value: None }, + }; + let mut func = make_func(vec![b0]); + func.locals = vec![(VarId(0), WasmType::I32), (VarId(1), WasmType::I32)]; + + eliminate(&mut func); + + // GVN skips single-block functions; duplicates remain (local_cse's job). + assert!(matches!(func.blocks[0].instructions[0], IrInstr::Const { .. })); + assert!(matches!(func.blocks[0].instructions[1], IrInstr::Const { .. })); + } +} diff --git a/crates/herkos/src/optimizer/licm.rs b/crates/herkos/src/optimizer/licm.rs new file mode 100644 index 0000000..172c948 --- /dev/null +++ b/crates/herkos/src/optimizer/licm.rs @@ -0,0 +1,1307 @@ +//! Loop-invariant code motion (LICM). +//! +//! Identifies instructions in loop headers whose operands don't change across +//! iterations, and moves them to a preheader block. +//! +//! ## Algorithm +//! 
+//! 1. Compute dominators (iterative algorithm) +//! 2. Find back edges: edge (src → tgt) where tgt dominates src +//! 3. Find natural loops: for each back edge, collect all blocks that reach +//! the source without going through the header +//! 4. For each loop, identify invariant instructions in the header (fixpoint): +//! - `Const` — trivially invariant +//! - `BinOp`, `UnOp`, `Assign`, `Select` — invariant if all operands are +//! defined outside the loop or by other invariant instructions +//! - Skip: `Load`, `Store`, `Call*`, `Global*`, `Memory*` +//! 5. Create or reuse a preheader block and move invariant instructions there +//! +//! **V1 simplification:** only hoists from the loop header block (which +//! dominates all loop blocks by definition). + +use super::utils::{ + build_predecessors, for_each_use, instr_dest, rewrite_terminator_target, terminator_successors, +}; +use crate::ir::{BlockId, IrBlock, IrFunction, IrInstr, IrTerminator, VarId}; +use std::collections::{HashMap, HashSet}; + +/// Run loop-invariant code motion on `func`. +pub fn eliminate(func: &mut IrFunction) { + if func.blocks.len() < 2 { + return; + } + + let preds = build_predecessors(func); + let dominators = compute_dominators(func, &preds); + let back_edges = find_back_edges(func, &dominators); + + if back_edges.is_empty() { + return; + } + + let loops = find_natural_loops(&back_edges, &preds); + + for (header, loop_blocks) in &loops { + hoist_invariants(func, *header, loop_blocks); + } +} + +// ── Dominator computation ──────────────────────────────────────────────────── + +/// Compute the dominator set for each block using the iterative algorithm. +/// +/// Returns a map from each block to the set of blocks that dominate it. 
+fn compute_dominators( + func: &IrFunction, + preds: &HashMap>, +) -> HashMap> { + let entry = func.entry_block; + let all_block_ids: HashSet = func.blocks.iter().map(|b| b.id).collect(); + + let mut dom: HashMap> = HashMap::new(); + dom.insert(entry, HashSet::from([entry])); + + for block in &func.blocks { + if block.id != entry { + dom.insert(block.id, all_block_ids.clone()); + } + } + + loop { + let mut changed = false; + for block in &func.blocks { + if block.id == entry { + continue; + } + let pred_set = &preds[&block.id]; + if pred_set.is_empty() { + continue; + } + + // new_dom = {self} ∪ ∩(dom[p] for p in preds) + let mut new_dom: Option> = None; + for p in pred_set { + if let Some(p_dom) = dom.get(p) { + new_dom = Some(match new_dom { + None => p_dom.clone(), + Some(current) => current.intersection(p_dom).copied().collect(), + }); + } + } + + let mut new_dom = new_dom.unwrap_or_default(); + new_dom.insert(block.id); + + if new_dom != dom[&block.id] { + dom.insert(block.id, new_dom); + changed = true; + } + } + if !changed { + break; + } + } + + dom +} + +// ── Back edge detection ────────────────────────────────────────────────────── + +/// Find all back edges in the CFG. +/// +/// A back edge is (src, tgt) where tgt dominates src. +fn find_back_edges( + func: &IrFunction, + dominators: &HashMap>, +) -> Vec<(BlockId, BlockId)> { + let mut back_edges = Vec::new(); + for block in &func.blocks { + for succ in terminator_successors(&block.terminator) { + if dominators + .get(&block.id) + .is_some_and(|dom_set| dom_set.contains(&succ)) + { + back_edges.push((block.id, succ)); + } + } + } + back_edges +} + +// ── Natural loop detection ─────────────────────────────────────────────────── + +/// Find natural loops from back edges. +/// +/// For each back edge (src → header), collects all blocks that can reach `src` +/// without going through `header`. Multiple back edges with the same header +/// are merged into one loop. 
+fn find_natural_loops( + back_edges: &[(BlockId, BlockId)], + preds: &HashMap>, +) -> Vec<(BlockId, HashSet)> { + let mut loops: HashMap> = HashMap::new(); + + for &(src, header) in back_edges { + let loop_blocks = loops.entry(header).or_insert_with(|| { + let mut set = HashSet::new(); + set.insert(header); + set + }); + + let mut worklist = vec![src]; + while let Some(n) = worklist.pop() { + if loop_blocks.insert(n) { + if let Some(n_preds) = preds.get(&n) { + for &p in n_preds { + worklist.push(p); + } + } + } + } + } + + loops.into_iter().collect() +} + +// ── Invariant identification & hoisting ────────────────────────────────────── + +/// Returns `true` if the instruction type is eligible for LICM hoisting. +/// +/// Only pure, side-effect-free computations are hoistable. Instructions that +/// depend on mutable state (`Global*`, `Memory*`) or have side effects +/// (`Load`, `Store`, `Call*`) are excluded. +fn is_licm_hoistable(instr: &IrInstr) -> bool { + matches!( + instr, + IrInstr::Const { .. } + | IrInstr::BinOp { .. } + | IrInstr::UnOp { .. } + | IrInstr::Assign { .. } + | IrInstr::Select { .. } + ) +} + +/// Identify loop-invariant instructions in the header and hoist them to a preheader. +fn hoist_invariants(func: &mut IrFunction, header: BlockId, loop_blocks: &HashSet) { + let header_idx = match func.blocks.iter().position(|b| b.id == header) { + Some(idx) => idx, + None => return, + }; + + // Collect all VarIds defined in any loop block. + let mut loop_defs: HashSet = HashSet::new(); + for block in &func.blocks { + if loop_blocks.contains(&block.id) { + for instr in &block.instructions { + if let Some(dest) = instr_dest(instr) { + loop_defs.insert(dest); + } + } + } + } + + // Fixpoint: identify invariant instructions in the header. 
+ let mut invariant_dests: HashSet = HashSet::new(); + loop { + let mut changed = false; + for instr in &func.blocks[header_idx].instructions { + if !is_licm_hoistable(instr) { + continue; + } + let dest = match instr_dest(instr) { + Some(d) => d, + None => continue, + }; + if invariant_dests.contains(&dest) { + continue; + } + + let mut all_ops_invariant = true; + for_each_use(instr, |v| { + if loop_defs.contains(&v) && !invariant_dests.contains(&v) { + all_ops_invariant = false; + } + }); + + if all_ops_invariant { + invariant_dests.insert(dest); + changed = true; + } + } + if !changed { + break; + } + } + + if invariant_dests.is_empty() { + return; + } + + // Find or create preheader. + let preheader_id = find_or_create_preheader(func, header, loop_blocks); + + // Re-lookup indices after possible block insertion. + let header_idx = func.blocks.iter().position(|b| b.id == header).unwrap(); + let preheader_idx = func + .blocks + .iter() + .position(|b| b.id == preheader_id) + .unwrap(); + + // Move invariant instructions from header to preheader (in order). + let mut hoisted = Vec::new(); + let mut remaining = Vec::new(); + + for instr in func.blocks[header_idx].instructions.drain(..) { + if let Some(dest) = instr_dest(&instr) { + if invariant_dests.contains(&dest) { + hoisted.push(instr); + continue; + } + } + remaining.push(instr); + } + + func.blocks[header_idx].instructions = remaining; + func.blocks[preheader_idx].instructions.extend(hoisted); +} + +/// Allocate a fresh `BlockId` that doesn't conflict with existing blocks. +fn fresh_block_id(func: &IrFunction) -> BlockId { + let max_id = func.blocks.iter().map(|b| b.id.0).max().unwrap_or(0); + BlockId(max_id + 1) +} + +/// Find an existing preheader or create a new one. +/// +/// A preheader is reused if it is the sole non-loop predecessor and ends +/// with an unconditional jump to the header. Otherwise a new preheader +/// block is created and non-loop predecessors are redirected to it. 
+fn find_or_create_preheader( + func: &mut IrFunction, + header: BlockId, + loop_blocks: &HashSet, +) -> BlockId { + let preds = build_predecessors(func); + let header_preds = &preds[&header]; + let non_loop_preds: Vec = header_preds + .iter() + .filter(|p| !loop_blocks.contains(p)) + .copied() + .collect(); + + if non_loop_preds.is_empty() { + // Header has no non-loop predecessors (entry block or unreachable from outside). + let preheader_id = fresh_block_id(func); + func.blocks.push(IrBlock { + id: preheader_id, + instructions: vec![], + terminator: IrTerminator::Jump { target: header }, + }); + if header == func.entry_block { + func.entry_block = preheader_id; + } + return preheader_id; + } + + // Reuse if single non-loop predecessor with unconditional jump to header. + if non_loop_preds.len() == 1 { + let pred = non_loop_preds[0]; + let pred_idx = func.blocks.iter().position(|b| b.id == pred).unwrap(); + if matches!(func.blocks[pred_idx].terminator, IrTerminator::Jump { target } if target == header) + { + return pred; + } + } + + // Create a new preheader and redirect non-loop predecessors. 
+ let preheader_id = fresh_block_id(func); + func.blocks.push(IrBlock { + id: preheader_id, + instructions: vec![], + terminator: IrTerminator::Jump { target: header }, + }); + + for block in &mut func.blocks { + if non_loop_preds.contains(&block.id) { + rewrite_terminator_target(&mut block.terminator, header, preheader_id); + } + } + + preheader_id +} + +// ── Tests ──────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ + BinOp, IrBlock, IrFunction, IrInstr, IrTerminator, IrValue, TypeIdx, VarId, WasmType, + }; + + fn make_func(blocks: Vec) -> IrFunction { + IrFunction { + params: vec![], + locals: vec![], + blocks, + entry_block: BlockId(0), + return_type: None, + type_idx: TypeIdx::new(0), + needs_host: false, + } + } + + // ── No loops → no changes ──────────────────────────────────────────── + + #[test] + fn no_loop_no_change() { + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Return { + value: Some(VarId(0)), + }, + }, + ]); + + eliminate(&mut func); + + // No loops, so the const stays in block 0. + assert_eq!(func.blocks[0].instructions.len(), 1); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { dest: VarId(0), .. 
} + )); + } + + // ── Simple loop: const in header → hoisted to preheader ────────────── + + #[test] + fn simple_loop_const_hoisted() { + // B0 (entry): Jump(B1) + // B1 (header): v0 = Const(42), BranchIf(v1, B2, B3) + // B2 (body): Jump(B1) ← back edge + // B3 (exit): Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(1), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // B0 is the sole non-loop predecessor with Jump → reused as preheader. + // v0 = Const(42) should be hoisted to B0. + assert_eq!(func.blocks[0].instructions.len(), 1); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + } + )); + + // B1 (header) should have no instructions. 
+ let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 0); + } + + // ── BinOp with operands from outside loop → hoisted ────────────────── + + #[test] + fn invariant_binop_hoisted() { + // B0 (entry): v0 = Const(10), v1 = Const(20), Jump(B1) + // B1 (header): v2 = BinOp::Add(v0, v1), BranchIf(v3, B2, B3) + // B2 (body): Jump(B1) + // B3 (exit): Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(10), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(20), + }, + ], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(3), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // v2 = BinOp should be hoisted to B0 (preheader). + assert_eq!(func.blocks[0].instructions.len(), 3); + assert!(matches!( + func.blocks[0].instructions[2], + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + .. + } + )); + + // Header should be empty. 
+ let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 0); + } + + // ── Chained invariants: const → binop using that const ─────────────── + + #[test] + fn chained_invariants_hoisted() { + // B0 (entry): v0 = Const(10), Jump(B1) + // B1 (header): v1 = Const(65536), v2 = BinOp::Add(v0, v1), BranchIf(v3, B2, B3) + // B2 (body): Jump(B1) + // B3 (exit): Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(10), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![ + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(65536), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + ], + terminator: IrTerminator::BranchIf { + condition: VarId(3), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // Both v1 = Const and v2 = BinOp should be hoisted to B0. + // B0 now has: v0 = Const(10), v1 = Const(65536), v2 = Add(v0, v1). + assert_eq!(func.blocks[0].instructions.len(), 3); + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(65536), + } + )); + assert!(matches!( + func.blocks[0].instructions[2], + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + .. + } + )); + + // Header should be empty. 
+ let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 0); + } + + // ── Non-hoistable instructions stay in the header ──────────────────── + + #[test] + fn side_effectful_not_hoisted() { + use crate::ir::MemoryAccessWidth; + + // B0: Jump(B1) + // B1 (header): v0 = Const(0), v1 = Load(v0), BranchIf(v2, B2, B3) + // B2: Jump(B1) + // B3: Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(0), + }, + IrInstr::Load { + dest: VarId(1), + ty: WasmType::I32, + addr: VarId(0), + offset: 0, + width: MemoryAccessWidth::Full, + sign: None, + }, + ], + terminator: IrTerminator::BranchIf { + condition: VarId(2), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // v0 = Const is hoisted (invariant), but Load stays (not hoistable). + assert_eq!(func.blocks[0].instructions.len(), 1); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { dest: VarId(0), .. } + )); + + let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 1); + assert!(matches!(header.instructions[0], IrInstr::Load { .. 
})); + } + + // ── BinOp with operand from loop body → NOT hoisted ────────────────── + + #[test] + fn loop_dependent_not_hoisted() { + // B0: v0 = Const(1), Jump(B1) + // B1 (header): v2 = BinOp::Add(v0, v1), BranchIf(v3, B2, B3) + // v1 is defined in B2 (loop body) → v2 is NOT invariant + // B2: v1 = Const(5), Jump(B1) + // B3: Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(3), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(5), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // v2 = BinOp should NOT be hoisted because v1 is defined in B2 (loop body). + let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 1); + assert!(matches!(header.instructions[0], IrInstr::BinOp { .. 
})); + } + + // ── Preheader reuse: single non-loop predecessor with Jump ─────────── + + #[test] + fn preheader_reused_when_possible() { + // B0 (entry): v0 = Const(99), Jump(B1) + // B1 (header): v1 = Const(42), BranchIf(v2, B2, B3) + // B2: Jump(B1) + // B3: Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(99), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(42), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(2), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // B0 should be reused as preheader (sole non-loop pred with Jump). + // No new blocks should be created. 
+ assert_eq!(func.blocks.len(), 4); + assert_eq!(func.blocks[0].instructions.len(), 2); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(99), + } + )); + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(42), + } + )); + } + + // ── Preheader creation: multiple non-loop predecessors ─────────────── + + #[test] + fn preheader_created_when_needed() { + // B0 (entry): BranchIf(v0, B1, B2) + // B1: Jump(B3) + // B2: Jump(B3) + // B3 (header): v1 = Const(42), BranchIf(v2, B4, B5) + // B4 (body): Jump(B3) ← back edge + // B5 (exit): Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::BranchIf { + condition: VarId(0), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(42), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(2), + if_true: BlockId(4), + if_false: BlockId(5), + }, + }, + IrBlock { + id: BlockId(4), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }, + IrBlock { + id: BlockId(5), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // A new preheader (B6) should be created. 
+ assert_eq!(func.blocks.len(), 7); + + let preheader = func.blocks.iter().find(|b| b.id == BlockId(6)).unwrap(); + assert_eq!(preheader.instructions.len(), 1); + assert!(matches!( + preheader.instructions[0], + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(42), + } + )); + assert!(matches!( + preheader.terminator, + IrTerminator::Jump { target: BlockId(3) } + )); + + // B1 and B2 should now jump to the preheader (B6). + let b1 = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert!(matches!( + b1.terminator, + IrTerminator::Jump { target: BlockId(6) } + )); + let b2 = func.blocks.iter().find(|b| b.id == BlockId(2)).unwrap(); + assert!(matches!( + b2.terminator, + IrTerminator::Jump { target: BlockId(6) } + )); + + // Header (B3) should be empty. + let header = func.blocks.iter().find(|b| b.id == BlockId(3)).unwrap(); + assert_eq!(header.instructions.len(), 0); + } + + // ── GlobalGet not hoisted (depends on mutable state) ───────────────── + + #[test] + fn global_get_not_hoisted() { + use crate::ir::GlobalIdx; + + // B0: Jump(B1) + // B1 (header): v0 = GlobalGet(0), BranchIf(v1, B2, B3) + // B2: Jump(B1) + // B3: Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::GlobalGet { + dest: VarId(0), + index: GlobalIdx::new(0), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(1), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // GlobalGet should NOT be hoisted (mutable global may change each iteration). 
+ let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 1); + assert!(matches!(header.instructions[0], IrInstr::GlobalGet { .. })); + } + + // ── Self-loop: header is also the back-edge source ─────────────────── + + #[test] + fn self_loop_const_hoisted() { + // B0: Jump(B1) + // B1: v0 = Const(42), BranchIf(v1, B1, B2) ← self-loop + // B2: Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(1), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // Const should be hoisted to B0 (preheader). + assert_eq!(func.blocks[0].instructions.len(), 1); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + } + )); + + let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 0); + } + + // ── No invariant instructions → no changes ─────────────────────────── + + #[test] + fn no_invariants_no_change() { + use crate::ir::MemoryAccessWidth; + + // B0: v0 = Const(0), Jump(B1) + // B1 (header): v1 = Load(v0), BranchIf(v2, B2, B3) + // B2: Jump(B1) + // B3: Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(0), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Load { + dest: VarId(1), + ty: WasmType::I32, + addr: VarId(0), + offset: 0, + width: MemoryAccessWidth::Full, + sign: None, + }], + terminator: 
IrTerminator::BranchIf { + condition: VarId(2), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // No invariants to hoist — no new blocks, header unchanged. + assert_eq!(func.blocks.len(), 4); + let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 1); + assert!(matches!(header.instructions[0], IrInstr::Load { .. })); + } + + // ── Entry block as loop header ─────────────────────────────────────── + + #[test] + fn entry_block_loop_header() { + // B0 (entry/header): v0 = Const(42), BranchIf(v1, B1, B2) + // B1 (body): Jump(B0) ← back edge + // B2 (exit): Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(1), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(0) }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // A preheader should be created, and entry_block updated. + assert_eq!(func.blocks.len(), 4); + let preheader_id = func.entry_block; + assert_ne!(preheader_id, BlockId(0)); + + let preheader = func.blocks.iter().find(|b| b.id == preheader_id).unwrap(); + assert_eq!(preheader.instructions.len(), 1); + assert!(matches!( + preheader.instructions[0], + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + } + )); + assert!(matches!( + preheader.terminator, + IrTerminator::Jump { target: BlockId(0) } + )); + + // Original header (B0) should be empty. 
+ let header = func.blocks.iter().find(|b| b.id == BlockId(0)).unwrap(); + assert_eq!(header.instructions.len(), 0); + } + + // ── Mixed: some hoistable, some not ────────────────────────────────── + + #[test] + fn mixed_hoistable_and_non_hoistable() { + use crate::ir::MemoryAccessWidth; + + // B0: Jump(B1) + // B1 (header): v0 = Const(100), v1 = Load(v0), v2 = Const(200) + // BranchIf(v3, B2, B3) + // B2: Jump(B1) + // B3: Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(100), + }, + IrInstr::Load { + dest: VarId(1), + ty: WasmType::I32, + addr: VarId(0), + offset: 0, + width: MemoryAccessWidth::Full, + sign: None, + }, + IrInstr::Const { + dest: VarId(2), + value: IrValue::I32(200), + }, + ], + terminator: IrTerminator::BranchIf { + condition: VarId(3), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // v0 and v2 (Consts) should be hoisted; Load stays. + assert_eq!(func.blocks[0].instructions.len(), 2); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(100), + } + )); + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::Const { + dest: VarId(2), + value: IrValue::I32(200), + } + )); + + let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 1); + assert!(matches!(header.instructions[0], IrInstr::Load { .. 
})); + } + + // ── Single-block function → no change ──────────────────────────────── + + #[test] + fn single_block_function_no_change() { + let mut func = make_func(vec![IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }], + terminator: IrTerminator::Return { + value: Some(VarId(0)), + }, + }]); + + eliminate(&mut func); + + assert_eq!(func.blocks.len(), 1); + assert_eq!(func.blocks[0].instructions.len(), 1); + } + + // ── Dominator computation tests ────────────────────────────────────── + + #[test] + fn dominators_linear_chain() { + // B0 → B1 → B2 + let func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(2) }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + let preds = build_predecessors(&func); + let dom = compute_dominators(&func, &preds); + + assert_eq!(dom[&BlockId(0)], HashSet::from([BlockId(0)])); + assert_eq!(dom[&BlockId(1)], HashSet::from([BlockId(0), BlockId(1)])); + assert_eq!( + dom[&BlockId(2)], + HashSet::from([BlockId(0), BlockId(1), BlockId(2)]) + ); + } + + #[test] + fn dominators_diamond() { + // B0 → B1, B0 → B2, B1 → B3, B2 → B3 + let func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::BranchIf { + condition: VarId(0), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + let preds = build_predecessors(&func); + 
let dom = compute_dominators(&func, &preds); + + // B3 is dominated by B0 (only common dominator of B1 and B2). + assert_eq!(dom[&BlockId(3)], HashSet::from([BlockId(0), BlockId(3)])); + } + + #[test] + fn back_edges_detected() { + // B0 → B1 → B2 → B1 (back edge) + let func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(2) }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + ]); + + let preds = build_predecessors(&func); + let dom = compute_dominators(&func, &preds); + let back_edges = find_back_edges(&func, &dom); + + assert_eq!(back_edges.len(), 1); + assert_eq!(back_edges[0], (BlockId(2), BlockId(1))); + } + + #[test] + fn natural_loop_blocks() { + // B0 → B1 → B2 → B3 → B1 (back edge) + // Loop = {B1, B2, B3} + let func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(2) }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + ]); + + let preds = build_predecessors(&func); + let dom = compute_dominators(&func, &preds); + let back_edges = find_back_edges(&func, &dom); + let loops = find_natural_loops(&back_edges, &preds); + + assert_eq!(loops.len(), 1); + let (header, loop_blocks) = &loops[0]; + assert_eq!(*header, BlockId(1)); + assert_eq!( + *loop_blocks, + HashSet::from([BlockId(1), BlockId(2), BlockId(3)]) + ); + } +} diff --git a/crates/herkos/src/optimizer/local_cse.rs b/crates/herkos/src/optimizer/local_cse.rs new file mode 
100644 index 0000000..b920ab8 --- /dev/null +++ b/crates/herkos/src/optimizer/local_cse.rs @@ -0,0 +1,575 @@ +//! Local common subexpression elimination (CSE) via value numbering. +//! +//! Within each block, identifies identical computations and replaces duplicates +//! with references to the first result. Only side-effect-free instructions are +//! considered (`BinOp`, `UnOp`, `Const`). Duplicates are replaced with +//! `Assign { dest, src: previous_result }`, which copy propagation cleans up. + +use crate::ir::{BinOp, IrFunction, IrInstr, IrValue, UnOp, VarId}; +use std::collections::HashMap; + +use super::utils::prune_dead_locals; + +// ── Value key ──────────────────────────────────────────────────────────────── + +/// Hashable representation of a pure computation for deduplication. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +enum ValueKey { + /// Constant value (using bit-level equality for floats). + Const(ConstKey), + + /// Binary operation with operand variable IDs. + BinOp { op: BinOp, lhs: VarId, rhs: VarId }, + + /// Unary operation with operand variable ID. + UnOp { op: UnOp, operand: VarId }, +} + +/// Bit-level constant key that implements Eq/Hash correctly for floats. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +enum ConstKey { + I32(i32), + I64(i64), + F32(u32), + F64(u64), +} + +impl From<IrValue> for ConstKey { + fn from(v: IrValue) -> Self { + match v { + IrValue::I32(x) => ConstKey::I32(x), + IrValue::I64(x) => ConstKey::I64(x), + IrValue::F32(x) => ConstKey::F32(x.to_bits()), + IrValue::F64(x) => ConstKey::F64(x.to_bits()), + } + } +} + +// ── Commutative op detection ───────────────────────────────────────────────── + +/// Returns true for operations where `op(a, b) == op(b, a)`. 
+fn is_commutative(op: &BinOp) -> bool { + matches!( + op, + BinOp::I32Add + | BinOp::I32Mul + | BinOp::I32And + | BinOp::I32Or + | BinOp::I32Xor + | BinOp::I32Eq + | BinOp::I32Ne + | BinOp::I64Add + | BinOp::I64Mul + | BinOp::I64And + | BinOp::I64Or + | BinOp::I64Xor + | BinOp::I64Eq + | BinOp::I64Ne + | BinOp::F32Add + | BinOp::F32Mul + | BinOp::F32Eq + | BinOp::F32Ne + | BinOp::F64Add + | BinOp::F64Mul + | BinOp::F64Eq + | BinOp::F64Ne + ) +} + +/// Build a `ValueKey` for a `BinOp`, normalizing operand order for commutative ops. +fn binop_key(op: BinOp, lhs: VarId, rhs: VarId) -> ValueKey { + let (lhs, rhs) = if is_commutative(&op) && lhs.0 > rhs.0 { + (rhs, lhs) + } else { + (lhs, rhs) + }; + ValueKey::BinOp { op, lhs, rhs } +} + +// ── Pass entry point ───────────────────────────────────────────────────────── + +/// Eliminates common subexpressions within each block of `func`. +pub fn eliminate(func: &mut IrFunction) { + let mut changed = false; + + for block in &mut func.blocks { + // Maps a pure computation to the first VarId that computed it. + let mut value_map: HashMap<ValueKey, VarId> = HashMap::new(); + + for instr in &mut block.instructions { + // In strict SSA form each variable is defined exactly once, so there + // is no need to invalidate cached CSE entries on redefinition. + match instr { + IrInstr::Const { dest, value } => { + let key = ValueKey::Const(ConstKey::from(*value)); + if let Some(&first) = value_map.get(&key) { + *instr = IrInstr::Assign { + dest: *dest, + src: first, + }; + changed = true; + } else { + value_map.insert(key, *dest); + } + } + + IrInstr::BinOp { + dest, op, lhs, rhs, .. 
+ } => { + let key = binop_key(*op, *lhs, *rhs); + if let Some(&first) = value_map.get(&key) { + *instr = IrInstr::Assign { + dest: *dest, + src: first, + }; + changed = true; + } else { + value_map.insert(key, *dest); + } + } + + IrInstr::UnOp { dest, op, operand } => { + let key = ValueKey::UnOp { + op: *op, + operand: *operand, + }; + if let Some(&first) = value_map.get(&key) { + *instr = IrInstr::Assign { + dest: *dest, + src: first, + }; + changed = true; + } else { + value_map.insert(key, *dest); + } + } + + // All other instructions are not eligible for CSE. + _ => {} + } + } + } + + if changed { + prune_dead_locals(func); + } +} + +// ── Tests ──────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{BlockId, IrBlock, IrTerminator, TypeIdx, WasmType}; + + /// Helper: create a minimal IrFunction with the given blocks. + fn make_func(blocks: Vec<IrBlock>) -> IrFunction { + IrFunction { + params: vec![], + locals: vec![], + blocks, + entry_block: BlockId(0), + return_type: None, + type_idx: TypeIdx::new(0), + needs_host: false, + } + } + + /// Helper: create a block with given instructions and a simple return terminator. + fn make_block(id: u32, instructions: Vec<IrInstr>) -> IrBlock { + IrBlock { + id: BlockId(id), + instructions, + terminator: IrTerminator::Return { value: None }, + } + } + + #[test] + fn duplicate_binop_is_eliminated() { + let instrs = vec![ + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; + eliminate(&mut func); + + let block = &func.blocks[0]; + assert!(matches!(block.instructions[0], IrInstr::BinOp { .. 
})); + assert!( + matches!( + block.instructions[1], + IrInstr::Assign { + dest: VarId(3), + src: VarId(2) + } + ), + "Duplicate BinOp should be replaced with Assign" + ); + } + + #[test] + fn commutative_binop_is_deduplicated() { + // v2 = v0 + v1, v3 = v1 + v0 → v3 should become Assign from v2 + let instrs = vec![ + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Add, + lhs: VarId(1), + rhs: VarId(0), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; + eliminate(&mut func); + + assert!( + matches!( + func.blocks[0].instructions[1], + IrInstr::Assign { + dest: VarId(3), + src: VarId(2) + } + ), + "Commutative BinOp with swapped operands should be deduplicated" + ); + } + + #[test] + fn non_commutative_binop_not_deduplicated() { + // v2 = v0 - v1, v3 = v1 - v0 → different computations, keep both + let instrs = vec![ + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Sub, + lhs: VarId(0), + rhs: VarId(1), + }, + IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Sub, + lhs: VarId(1), + rhs: VarId(0), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; + eliminate(&mut func); + + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::BinOp { .. } + )); + assert!( + matches!(func.blocks[0].instructions[1], IrInstr::BinOp { .. 
}), + "Non-commutative BinOp with swapped operands should NOT be deduplicated" + ); + } + + #[test] + fn duplicate_unop_is_eliminated() { + let instrs = vec![ + IrInstr::UnOp { + dest: VarId(1), + op: UnOp::I32Clz, + operand: VarId(0), + }, + IrInstr::UnOp { + dest: VarId(2), + op: UnOp::I32Clz, + operand: VarId(0), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(1), WasmType::I32), (VarId(2), WasmType::I32)]; + eliminate(&mut func); + + assert!( + matches!( + func.blocks[0].instructions[1], + IrInstr::Assign { + dest: VarId(2), + src: VarId(1) + } + ), + "Duplicate UnOp should be replaced with Assign" + ); + } + + #[test] + fn duplicate_const_is_eliminated() { + let instrs = vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(42), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(0), WasmType::I32), (VarId(1), WasmType::I32)]; + eliminate(&mut func); + + assert!( + matches!( + func.blocks[0].instructions[1], + IrInstr::Assign { + dest: VarId(1), + src: VarId(0) + } + ), + "Duplicate Const should be replaced with Assign" + ); + } + + #[test] + fn float_const_nan_bits_handled() { + // Two NaN constants with the same bit pattern should be deduplicated. 
+ let instrs = vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::F32(f32::NAN), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::F32(f32::NAN), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(0), WasmType::F32), (VarId(1), WasmType::F32)]; + eliminate(&mut func); + + assert!( + matches!( + func.blocks[0].instructions[1], + IrInstr::Assign { + dest: VarId(1), + src: VarId(0) + } + ), + "NaN constants with same bit pattern should be deduplicated" + ); + } + + #[test] + fn different_ops_not_deduplicated() { + let instrs = vec![ + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Sub, + lhs: VarId(0), + rhs: VarId(1), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; + eliminate(&mut func); + + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::BinOp { .. } + )); + assert!( + matches!(func.blocks[0].instructions[1], IrInstr::BinOp { .. }), + "Different operations should not be deduplicated" + ); + } + + #[test] + fn cross_block_not_deduplicated() { + // Each block should have its own value map — no cross-block CSE. + let block0 = IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }; + let block1 = IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::Return { value: None }, + }; + + let mut func = make_func(vec![block0, block1]); + func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; + eliminate(&mut func); + + // Both should remain as BinOp (no cross-block elimination). 
+ assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::BinOp { .. } + )); + assert!( + matches!(func.blocks[1].instructions[0], IrInstr::BinOp { .. }), + "Cross-block duplicate should NOT be eliminated" + ); + } + + /// In strict SSA form every variable is defined exactly once within a block, + /// so (v0 + v1) always refers to the same computation and can be CSE'd. + #[test] + fn ssa_unique_defs_allow_cse() { + // v2 = v0 + v1 ← first occurrence + // v3 = v0 + v1 ← identical keys with same VarIds → should be eliminated + let instrs = vec![ + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; + eliminate(&mut func); + + // v3 should be eliminated to Assign(v3, v2). + assert!( + matches!( + func.blocks[0].instructions[1], + IrInstr::Assign { + dest: VarId(3), + src: VarId(2) + } + ), + "duplicate (v0 + v1) should be CSE'd to Assign in strict SSA" + ); + } + + #[test] + fn side_effect_instructions_not_eliminated() { + // Load, Store, Call, etc. should never be CSE'd. + use crate::ir::MemoryAccessWidth; + + let instrs = vec![ + IrInstr::Load { + dest: VarId(1), + ty: WasmType::I32, + addr: VarId(0), + offset: 0, + width: MemoryAccessWidth::Full, + sign: None, + }, + IrInstr::Load { + dest: VarId(2), + ty: WasmType::I32, + addr: VarId(0), + offset: 0, + width: MemoryAccessWidth::Full, + sign: None, + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(1), WasmType::I32), (VarId(2), WasmType::I32)]; + eliminate(&mut func); + + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Load { .. } + )); + assert!( + matches!(func.blocks[0].instructions[1], IrInstr::Load { .. 
}), + "Load instructions should not be CSE'd" + ); + } + + #[test] + fn triple_duplicate_eliminates_both() { + // Three identical BinOps: second and third should become Assigns to first. + let instrs = vec![ + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Mul, + lhs: VarId(0), + rhs: VarId(1), + }, + IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Mul, + lhs: VarId(0), + rhs: VarId(1), + }, + IrInstr::BinOp { + dest: VarId(4), + op: BinOp::I32Mul, + lhs: VarId(0), + rhs: VarId(1), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![ + (VarId(2), WasmType::I32), + (VarId(3), WasmType::I32), + (VarId(4), WasmType::I32), + ]; + eliminate(&mut func); + + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::BinOp { .. } + )); + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::Assign { + dest: VarId(3), + src: VarId(2) + } + )); + assert!(matches!( + func.blocks[0].instructions[2], + IrInstr::Assign { + dest: VarId(4), + src: VarId(2) + } + )); + } +} diff --git a/crates/herkos/src/optimizer/merge_blocks.rs b/crates/herkos/src/optimizer/merge_blocks.rs new file mode 100644 index 0000000..583fd8b --- /dev/null +++ b/crates/herkos/src/optimizer/merge_blocks.rs @@ -0,0 +1,384 @@ +//! Single-predecessor block merging. +//! +//! When a block `B` has exactly one predecessor `P`, and `P` reaches `B` via an +//! unconditional `Jump`, then `B` can be appended to `P` — its instructions are +//! concatenated and `P` inherits `B`'s terminator. +//! +//! The pass iterates to a fixed point so that chains like +//! B0 → Jump → B1 → Jump → B2 → Return +//! collapse into a single block B0 → Return. +//! +//! After merging, absorbed blocks are removed from `func.blocks`. + +use super::utils::build_predecessors; +use crate::ir::{BlockId, IrFunction, IrTerminator}; +use std::collections::{HashMap, HashSet}; + +/// Merge single-predecessor blocks reached via unconditional `Jump`. 
+/// +/// Iterates to a fixed point, then removes absorbed blocks. +pub fn eliminate(func: &mut IrFunction) { + loop { + let preds = build_predecessors(func); + + // Index blocks by ID for lookup during merging. + let block_map: HashMap = func + .blocks + .iter() + .enumerate() + .map(|(i, b)| (b.id, i)) + .collect(); + + // Collect merge pairs: (predecessor_idx, target_idx) where target has + // exactly one predecessor and that predecessor reaches it via Jump. + let mut merges: Vec<(usize, usize)> = Vec::new(); + // Track which blocks are already involved in a merge this round to avoid + // conflicting operations (a block can't be both a merge source and target + // in the same round). + let mut involved: HashSet = HashSet::new(); + + for block in &func.blocks { + if let IrTerminator::Jump { target } = block.terminator { + // Skip self-loops. + if target == block.id { + continue; + } + // Never merge away the entry block. + if target == func.entry_block { + continue; + } + if let Some(pred_set) = preds.get(&target) { + if pred_set.len() == 1 { + let pred_idx = block_map[&block.id]; + let target_idx = block_map[&target]; + // Avoid conflicts: each block participates in at most one + // merge per round. + if !involved.contains(&pred_idx) && !involved.contains(&target_idx) { + merges.push((pred_idx, target_idx)); + involved.insert(pred_idx); + involved.insert(target_idx); + } + } + } + } + } + + if merges.is_empty() { + break; + } + + // Perform merges. We collect the target block data first to avoid borrow + // conflicts on func.blocks. + let absorbed_sorted = { + let mut absorbed: Vec = merges.iter().map(|(_, t)| *t).collect(); + absorbed.sort_unstable_by(|a, b| b.cmp(a)); + absorbed + }; + + for (pred_idx, target_idx) in &merges { + // Take target block's data out. 
+ let target_instrs = std::mem::take(&mut func.blocks[*target_idx].instructions); + let target_term = std::mem::replace( + &mut func.blocks[*target_idx].terminator, + IrTerminator::Unreachable, + ); + // Append to predecessor. + func.blocks[*pred_idx].instructions.extend(target_instrs); + func.blocks[*pred_idx].terminator = target_term; + } + + // Remove absorbed blocks (iterate in reverse to preserve indices). + for idx in absorbed_sorted { + func.blocks.remove(idx); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{IrBlock, IrFunction, IrInstr, IrTerminator, IrValue, TypeIdx, VarId}; + + fn make_func(blocks: Vec) -> IrFunction { + IrFunction { + params: vec![], + locals: vec![], + blocks, + entry_block: BlockId(0), + return_type: None, + type_idx: TypeIdx::new(0), + needs_host: false, + } + } + + fn block_ids(func: &IrFunction) -> Vec { + func.blocks.iter().map(|b| b.id.0).collect() + } + + fn instr_block(id: u32, dest: u32, val: i32, term: IrTerminator) -> IrBlock { + IrBlock { + id: BlockId(id), + instructions: vec![IrInstr::Const { + dest: VarId(dest), + value: IrValue::I32(val), + }], + terminator: term, + } + } + + // ── Basic cases ────────────────────────────────────────────────────── + + #[test] + fn single_block_unchanged() { + let mut func = make_func(vec![IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }]); + eliminate(&mut func); + assert_eq!(block_ids(&func), vec![0]); + } + + #[test] + fn linear_chain_collapses() { + // B0 → Jump → B1 → Jump → B2 → Return + let mut func = make_func(vec![ + instr_block(0, 0, 1, IrTerminator::Jump { target: BlockId(1) }), + instr_block(1, 1, 2, IrTerminator::Jump { target: BlockId(2) }), + instr_block( + 2, + 2, + 3, + IrTerminator::Return { + value: Some(VarId(2)), + }, + ), + ]); + eliminate(&mut func); + // All merged into B0. 
+ assert_eq!(block_ids(&func), vec![0]); + assert_eq!(func.blocks[0].instructions.len(), 3); + assert!(matches!( + func.blocks[0].terminator, + IrTerminator::Return { + value: Some(VarId(2)) + } + )); + } + + #[test] + fn conditional_predecessor_not_merged() { + // B0: BranchIf → B1 / B2 — both have 1 predecessor but via conditional + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(0), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + eliminate(&mut func); + // Nothing merged — conditional edges are not Jump. + assert_eq!(block_ids(&func), vec![0, 1, 2]); + } + + #[test] + fn multiple_predecessors_not_merged() { + // B0 → Jump → B2, B1 → Jump → B2 — B2 has 2 predecessors + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(2) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(2) }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + // B1 is dead (no predecessor), but merge_blocks doesn't remove dead blocks. + // B2 has 2 predecessors (B0 and B1) → not merged. 
+ eliminate(&mut func); + assert_eq!(block_ids(&func), vec![0, 1, 2]); + } + + #[test] + fn self_loop_not_merged() { + // B0 → Jump → B0 (self-loop) + let mut func = make_func(vec![IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(0) }, + }]); + eliminate(&mut func); + assert_eq!(block_ids(&func), vec![0]); + } + + #[test] + fn entry_block_not_absorbed() { + // B1 → Jump → B0 (entry) — B0 has 1 predecessor but is entry + let mut func = IrFunction { + params: vec![], + locals: vec![], + blocks: vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(0) }, + }, + ], + entry_block: BlockId(0), + return_type: None, + type_idx: TypeIdx::new(0), + needs_host: false, + }; + eliminate(&mut func); + // B0 is entry, must not be absorbed into B1. + assert!(func.blocks.iter().any(|b| b.id == BlockId(0))); + } + + // ── Fixed-point iteration ────────────────────────────────────────── + + #[test] + fn fixed_point_three_block_chain() { + // B0 → B1 → B2 → B3 → Return + // Round 1: B1→B0, B3→B2. Round 2: B2→B0. 
+ let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(2) }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + eliminate(&mut func); + assert_eq!(block_ids(&func), vec![0]); + assert!(matches!( + func.blocks[0].terminator, + IrTerminator::Return { value: None } + )); + } + + // ── Realistic pattern ────────────────────────────────────────────── + + #[test] + fn jump_then_branch_merges_prologue() { + // B0 → Jump → B1 → BranchIf(B2, B3) + // B2: Return, B3: Return + // B1 has 1 predecessor (B0) via Jump → merge. + let mut func = make_func(vec![ + instr_block(0, 0, 10, IrTerminator::Jump { target: BlockId(1) }), + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(20), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(0), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + eliminate(&mut func); + // B1 merged into B0. + assert_eq!(block_ids(&func), vec![0, 2, 3]); + assert_eq!(func.blocks[0].instructions.len(), 2); + assert!(matches!( + func.blocks[0].terminator, + IrTerminator::BranchIf { .. 
} + )); + } + + #[test] + fn loop_back_edge_prevents_merge() { + // B0 → Jump → B1 → BranchIf(B2, B3) + // B2 → Jump → B1 (back-edge, B1 now has 2 predecessors) + // B3: Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(0), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + eliminate(&mut func); + // B1 has 2 predecessors (B0 and B2) → not merged. + // No blocks are mergeable. + assert_eq!(block_ids(&func), vec![0, 1, 2, 3]); + } +} diff --git a/crates/herkos/src/optimizer/mod.rs b/crates/herkos/src/optimizer/mod.rs index fb666b1..3cc1792 100644 --- a/crates/herkos/src/optimizer/mod.rs +++ b/crates/herkos/src/optimizer/mod.rs @@ -6,17 +6,65 @@ //! Each optimization is a self-contained sub-module. The top-level //! [`optimize_ir`] function runs all passes in order. -use crate::ir::ModuleInfo; +use crate::ir::LoweredModuleInfo; use anyhow::Result; +// ── Shared utilities ───────────────────────────────────────────────────────── +pub(crate) mod utils; + // ── Passes ─────────────────────────────────────────────────────────────────── +mod algebraic; +mod branch_fold; +mod const_prop; +mod copy_prop; mod dead_blocks; +mod dead_instrs; +mod empty_blocks; +mod gvn; +mod licm; +mod local_cse; +mod merge_blocks; /// Optimizes the IR representation by running all passes in order. -pub fn optimize_ir(module_info: ModuleInfo) -> Result { +/// +/// Expects a [`LoweredModuleInfo`] — i.e. 
phi nodes have already been lowered +/// by [`crate::ir::lower_phis::lower`] before calling this function. +pub fn optimize_ir(module_info: LoweredModuleInfo) -> Result { let mut module_info = module_info; for func in &mut module_info.ir_functions { - dead_blocks::eliminate(func)?; + // Two structural passes: dead_instrs may have emptied blocks that + // lower_phis had populated with now-dead assignments (e.g. loop-exit + // locals that were never used after the join block). Re-run structural + // cleanup to remove those passthrough blocks. + for _ in 0..2 { + // Empty block optimizations + empty_blocks::eliminate(func); + dead_blocks::eliminate(func)?; + + // Control flow optimizations + merge_blocks::eliminate(func); + dead_blocks::eliminate(func)?; + + // Value optimizations + const_prop::eliminate(func); + algebraic::eliminate(func); + copy_prop::eliminate(func); + + // Redundancy elimination + local_cse::eliminate(func); + gvn::eliminate(func); + copy_prop::eliminate(func); + dead_instrs::eliminate(func); + + // Branch simplification + branch_fold::eliminate(func); + dead_instrs::eliminate(func); + + // Loop optimization + licm::eliminate(func); + dead_instrs::eliminate(func); + copy_prop::eliminate(func); + } } Ok(module_info) } @@ -25,62 +73,5 @@ pub fn optimize_ir(module_info: ModuleInfo) -> Result { #[cfg(test)] mod tests { - use crate::ir::{BlockId, IrBlock, IrFunction, IrTerminator, ModuleInfo, TypeIdx}; - - #[test] - fn optimize_ir_eliminates_dead_blocks_across_functions() { - let make_ir_func = |blocks: Vec| IrFunction { - params: vec![], - locals: vec![], - blocks, - entry_block: BlockId(0), - return_type: None, - type_idx: TypeIdx::new(0), - needs_host: false, - }; - - let module = ModuleInfo { - ir_functions: vec![ - // func 0: block_0 → Return, block_1 dead - make_ir_func(vec![ - IrBlock { - id: BlockId(0), - instructions: vec![], - terminator: IrTerminator::Return { value: None }, - }, - IrBlock { - id: BlockId(1), - instructions: vec![], - 
terminator: IrTerminator::Return { value: None }, - }, - ]), - // func 1: block_0 → Jump → block_1 → Return (all live) - make_ir_func(vec![ - IrBlock { - id: BlockId(0), - instructions: vec![], - terminator: IrTerminator::Jump { target: BlockId(1) }, - }, - IrBlock { - id: BlockId(1), - instructions: vec![], - terminator: IrTerminator::Return { value: None }, - }, - ]), - ], - ..Default::default() - }; - - let result = super::optimize_ir(module).unwrap(); - assert_eq!( - result.ir_functions[0].blocks.len(), - 1, - "dead block in func 0 should be removed" - ); - assert_eq!( - result.ir_functions[1].blocks.len(), - 2, - "both blocks in func 1 should be kept" - ); - } + // TODO: Add tests that verify the correctness of the optimized IR and the generated code. } diff --git a/crates/herkos/src/optimizer/utils.rs b/crates/herkos/src/optimizer/utils.rs new file mode 100644 index 0000000..fd4fdd4 --- /dev/null +++ b/crates/herkos/src/optimizer/utils.rs @@ -0,0 +1,644 @@ +//! Shared utility functions for IR optimization passes. +//! +//! Provides common operations on IR instructions, terminators, and control flow +//! that are needed by multiple optimization passes. +//! +//! Some functions are not yet used by existing passes but are provided for +//! upcoming optimization passes (const_prop, dead_instrs, local_cse, licm). +#![allow(dead_code)] + +use crate::ir::{BinOp, BlockId, IrFunction, IrInstr, IrTerminator, UnOp, VarId}; +use std::collections::{HashMap, HashSet}; + +// ── Terminator successors ──────────────────────────────────────────────────── + +/// Returns the successor block IDs for a terminator. +pub fn terminator_successors(term: &IrTerminator) -> Vec { + match term { + IrTerminator::Return { .. } | IrTerminator::Unreachable => vec![], + IrTerminator::Jump { target } => vec![*target], + IrTerminator::BranchIf { + if_true, if_false, .. + } => vec![*if_true, *if_false], + IrTerminator::BranchTable { + targets, default, .. 
+ } => targets + .iter() + .chain(std::iter::once(default)) + .copied() + .collect(), + } +} + +// ── Predecessor map ────────────────────────────────────────────────────────── + +/// Build a map from each block ID to the set of *distinct* predecessor block IDs. +pub fn build_predecessors(func: &IrFunction) -> HashMap> { + let mut preds: HashMap> = HashMap::new(); + // Ensure every block has an entry (even if no predecessors). + for block in &func.blocks { + preds.entry(block.id).or_default(); + } + for block in &func.blocks { + for succ in terminator_successors(&block.terminator) { + preds.entry(succ).or_default().insert(block.id); + } + } + preds +} + +// ── Instruction variable traversal ─────────────────────────────────────────── + +/// Calls `f` with every variable read by `instr`. +pub fn for_each_use(instr: &IrInstr, mut f: F) { + match instr { + IrInstr::Const { .. } => {} + IrInstr::BinOp { lhs, rhs, .. } => { + f(*lhs); + f(*rhs); + } + IrInstr::UnOp { operand, .. } => { + f(*operand); + } + IrInstr::Load { addr, .. } => { + f(*addr); + } + IrInstr::Store { addr, value, .. } => { + f(*addr); + f(*value); + } + IrInstr::Call { args, .. } | IrInstr::CallImport { args, .. } => { + for a in args { + f(*a); + } + } + IrInstr::CallIndirect { + table_idx, args, .. + } => { + f(*table_idx); + for a in args { + f(*a); + } + } + IrInstr::Assign { src, .. } => { + f(*src); + } + IrInstr::GlobalGet { .. } => {} + IrInstr::GlobalSet { value, .. } => { + f(*value); + } + IrInstr::MemorySize { .. } => {} + IrInstr::MemoryGrow { delta, .. } => { + f(*delta); + } + IrInstr::MemoryCopy { dst, src, len } => { + f(*dst); + f(*src); + f(*len); + } + IrInstr::Select { + val1, + val2, + condition, + .. + } => { + f(*val1); + f(*val2); + f(*condition); + } + IrInstr::Phi { .. } => unreachable!("Phi nodes must be lowered before optimization"), + } +} + +/// Calls `f` with every variable read by a block terminator. 
+pub fn for_each_use_terminator(term: &IrTerminator, mut f: F) { + match term { + IrTerminator::Return { value: Some(v) } => { + f(*v); + } + IrTerminator::Return { value: None } + | IrTerminator::Jump { .. } + | IrTerminator::Unreachable => {} + IrTerminator::BranchIf { condition, .. } => { + f(*condition); + } + IrTerminator::BranchTable { index, .. } => { + f(*index); + } + } +} + +// ── Instruction destination ────────────────────────────────────────────────── + +/// Returns the variable written by `instr`, or `None` for side-effect-only instructions. +pub fn instr_dest(instr: &IrInstr) -> Option { + match instr { + IrInstr::Const { dest, .. } + | IrInstr::BinOp { dest, .. } + | IrInstr::UnOp { dest, .. } + | IrInstr::Load { dest, .. } + | IrInstr::Assign { dest, .. } + | IrInstr::GlobalGet { dest, .. } + | IrInstr::MemorySize { dest } + | IrInstr::MemoryGrow { dest, .. } + | IrInstr::Select { dest, .. } => Some(*dest), + + IrInstr::Call { dest, .. } + | IrInstr::CallImport { dest, .. } + | IrInstr::CallIndirect { dest, .. } => *dest, + + IrInstr::Store { .. } | IrInstr::GlobalSet { .. } | IrInstr::MemoryCopy { .. } => None, + + IrInstr::Phi { .. } => unreachable!("Phi nodes must be lowered before optimization"), + } +} + +/// Redirects the destination variable of `instr` to `new_dest`. +/// +/// Only called when `instr_dest(instr)` is `Some(_)`, i.e. the instruction +/// produces a value. Instructions without a dest are left unchanged. +pub fn set_instr_dest(instr: &mut IrInstr, new_dest: VarId) { + match instr { + IrInstr::Const { dest, .. } + | IrInstr::BinOp { dest, .. } + | IrInstr::UnOp { dest, .. } + | IrInstr::Load { dest, .. } + | IrInstr::Assign { dest, .. } + | IrInstr::GlobalGet { dest, .. } + | IrInstr::MemorySize { dest } + | IrInstr::MemoryGrow { dest, .. } + | IrInstr::Select { dest, .. } => { + *dest = new_dest; + } + IrInstr::Call { dest, .. } + | IrInstr::CallImport { dest, .. } + | IrInstr::CallIndirect { dest, .. 
} => { + *dest = Some(new_dest); + } + // No dest — unreachable given precondition, but harmless to ignore. + IrInstr::Store { .. } | IrInstr::GlobalSet { .. } | IrInstr::MemoryCopy { .. } => {} + + IrInstr::Phi { .. } => unreachable!("Phi nodes must be lowered before optimization"), + } +} + +// ── Use-count helpers ──────────────────────────────────────────────────────── + +/// Count how many times `var` appears as an operand (read) in `instr`. +pub fn count_uses_of(instr: &IrInstr, var: VarId) -> usize { + let mut count = 0usize; + for_each_use(instr, |v| { + if v == var { + count += 1; + } + }); + count +} + +/// Count how many times `var` appears as an operand in `term`. +pub fn count_uses_of_terminator(term: &IrTerminator, var: VarId) -> usize { + let mut count = 0usize; + for_each_use_terminator(term, |v| { + if v == var { + count += 1; + } + }); + count +} + +// ── Use-replacement helpers ────────────────────────────────────────────────── + +/// Replace every read-occurrence of `old` with `new` in `instr`. +/// Only touches operand (source) slots; the destination slot is never modified. +pub fn replace_uses_of(instr: &mut IrInstr, old: VarId, new: VarId) { + let sub = |v: &mut VarId| { + if *v == old { + *v = new; + } + }; + match instr { + IrInstr::Const { .. } => {} + IrInstr::BinOp { lhs, rhs, .. } => { + sub(lhs); + sub(rhs); + } + IrInstr::UnOp { operand, .. } => { + sub(operand); + } + IrInstr::Load { addr, .. } => { + sub(addr); + } + IrInstr::Store { addr, value, .. } => { + sub(addr); + sub(value); + } + IrInstr::Call { args, .. } | IrInstr::CallImport { args, .. } => { + for a in args { + sub(a); + } + } + IrInstr::CallIndirect { + table_idx, args, .. + } => { + sub(table_idx); + for a in args { + sub(a); + } + } + IrInstr::Assign { src, .. } => { + sub(src); + } + IrInstr::GlobalGet { .. } => {} + IrInstr::GlobalSet { value, .. } => { + sub(value); + } + IrInstr::MemorySize { .. } => {} + IrInstr::MemoryGrow { delta, .. 
} => { + sub(delta); + } + IrInstr::MemoryCopy { dst, src, len } => { + sub(dst); + sub(src); + sub(len); + } + IrInstr::Select { + val1, + val2, + condition, + .. + } => { + sub(val1); + sub(val2); + sub(condition); + } + IrInstr::Phi { .. } => unreachable!("Phi nodes must be lowered before optimization"), + } +} + +/// Replace every read-occurrence of `old` with `new` in `term`. +pub fn replace_uses_of_terminator(term: &mut IrTerminator, old: VarId, new: VarId) { + let sub = |v: &mut VarId| { + if *v == old { + *v = new; + } + }; + match term { + IrTerminator::Return { value: Some(v) } => { + sub(v); + } + IrTerminator::Return { value: None } + | IrTerminator::Jump { .. } + | IrTerminator::Unreachable => {} + IrTerminator::BranchIf { condition, .. } => { + sub(condition); + } + IrTerminator::BranchTable { index, .. } => { + sub(index); + } + } +} + +// ── Global use-count ───────────────────────────────────────────────────────── + +/// Counts how many times each variable is *read* across the entire function +/// (all blocks, all instructions, all terminators). +pub fn build_global_use_count(func: &IrFunction) -> HashMap { + let mut counts: HashMap = HashMap::new(); + for block in &func.blocks { + for instr in &block.instructions { + for_each_use(instr, |v| { + *counts.entry(v).or_insert(0) += 1; + }); + } + for_each_use_terminator(&block.terminator, |v| { + *counts.entry(v).or_insert(0) += 1; + }); + } + counts +} + +// ── Dead-local pruning ─────────────────────────────────────────────────────── + +/// Remove from `func.locals` any variable that no longer appears in any +/// instruction or terminator of any block. +pub fn prune_dead_locals(func: &mut IrFunction) { + // Collect all variables still referenced anywhere in the function. 
+ let mut live: HashSet = HashSet::new(); + + for block in &func.blocks { + for instr in &block.instructions { + for_each_use(instr, |v| { + live.insert(v); + }); + if let Some(dest) = instr_dest(instr) { + live.insert(dest); + } + } + for_each_use_terminator(&block.terminator, |v| { + live.insert(v); + }); + } + + // Keep params unconditionally; prune locals that are not in `live`. + func.locals.retain(|(var, _)| live.contains(var)); +} + +// ── Side-effect classification ─────────────────────────────────────────────── + +/// Returns `true` if the instruction is side-effect-free and can be safely +/// removed when its result is unused. +/// +/// Instructions that may trap (Load, MemoryGrow, integer div/rem, float-to-int +/// truncation), modify external state (Store, GlobalSet, MemoryCopy), or have +/// unknown effects (Call*) are considered side-effectful and must be retained +/// even if their result is unused — removing them would suppress a Wasm trap. +pub fn is_side_effect_free(instr: &IrInstr) -> bool { + match instr { + // Integer division and remainder trap on divisor == 0 (and i*::MIN / -1 + // for signed division). Must be preserved even when the result is dead. + IrInstr::BinOp { op, .. } => !matches!( + op, + BinOp::I32DivS + | BinOp::I32DivU + | BinOp::I32RemS + | BinOp::I32RemU + | BinOp::I64DivS + | BinOp::I64DivU + | BinOp::I64RemS + | BinOp::I64RemU + ), + // Float-to-integer truncations trap on NaN or out-of-range inputs. + IrInstr::UnOp { op, .. } => !matches!( + op, + UnOp::I32TruncF32S + | UnOp::I32TruncF32U + | UnOp::I32TruncF64S + | UnOp::I32TruncF64U + | UnOp::I64TruncF32S + | UnOp::I64TruncF32U + | UnOp::I64TruncF64S + | UnOp::I64TruncF64U + ), + IrInstr::Const { .. } + | IrInstr::Assign { .. } + | IrInstr::Select { .. } + | IrInstr::GlobalGet { .. } + | IrInstr::MemorySize { .. } => true, + IrInstr::Phi { .. 
} => unreachable!("Phi nodes must be lowered before optimization"), + _ => false, + } +} + +// ── Rewrite terminator block targets ───────────────────────────────────────── + +/// Rewrite all block-ID references in a terminator from `old` to `new`. +pub fn rewrite_terminator_target(term: &mut IrTerminator, old: BlockId, new: BlockId) { + let replace = |b: &mut BlockId| { + if *b == old { + *b = new; + } + }; + match term { + IrTerminator::Jump { target } => replace(target), + IrTerminator::BranchIf { + if_true, if_false, .. + } => { + replace(if_true); + replace(if_false); + } + IrTerminator::BranchTable { + targets, default, .. + } => { + for t in targets.iter_mut() { + replace(t); + } + replace(default); + } + IrTerminator::Return { .. } | IrTerminator::Unreachable => {} + } +} + +// ── Tests ──────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{BinOp, IrBlock, IrValue, WasmType}; + + #[test] + fn for_each_use_covers_binop() { + let instr = IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }; + let mut uses = vec![]; + for_each_use(&instr, |v| uses.push(v)); + assert_eq!(uses, vec![VarId(0), VarId(1)]); + } + + #[test] + fn instr_dest_returns_none_for_store() { + let instr = IrInstr::Store { + ty: WasmType::I32, + addr: VarId(0), + value: VarId(1), + offset: 0, + width: crate::ir::MemoryAccessWidth::Full, + }; + assert_eq!(instr_dest(&instr), None); + } + + #[test] + fn instr_dest_returns_some_for_const() { + let instr = IrInstr::Const { + dest: VarId(5), + value: IrValue::I32(42), + }; + assert_eq!(instr_dest(&instr), Some(VarId(5))); + } + + #[test] + fn is_side_effect_free_classification() { + assert!(is_side_effect_free(&IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + })); + assert!(is_side_effect_free(&IrInstr::BinOp { + dest: VarId(0), + op: BinOp::I32Add, + lhs: VarId(1), + rhs: VarId(2), + })); + 
assert!(is_side_effect_free(&IrInstr::Assign { + dest: VarId(0), + src: VarId(1), + })); + assert!(!is_side_effect_free(&IrInstr::Store { + ty: WasmType::I32, + addr: VarId(0), + value: VarId(1), + offset: 0, + width: crate::ir::MemoryAccessWidth::Full, + })); + assert!(!is_side_effect_free(&IrInstr::Load { + dest: VarId(0), + ty: WasmType::I32, + addr: VarId(1), + offset: 0, + width: crate::ir::MemoryAccessWidth::Full, + sign: None, + })); + } + + #[test] + fn trapping_binops_not_side_effect_free() { + // Integer div/rem must NOT be classified as side-effect-free because they + // can trap at runtime (division by zero, i*::MIN / -1 for signed div). + for op in [ + BinOp::I32DivS, + BinOp::I32DivU, + BinOp::I32RemS, + BinOp::I32RemU, + BinOp::I64DivS, + BinOp::I64DivU, + BinOp::I64RemS, + BinOp::I64RemU, + ] { + let instr = IrInstr::BinOp { + dest: VarId(0), + op, + lhs: VarId(1), + rhs: VarId(2), + }; + assert!( + !is_side_effect_free(&instr), + "{op:?} should NOT be side-effect-free" + ); + } + // Non-trapping BinOps remain side-effect-free. + assert!(is_side_effect_free(&IrInstr::BinOp { + dest: VarId(0), + op: BinOp::I32Mul, + lhs: VarId(1), + rhs: VarId(2), + })); + } + + #[test] + fn trapping_unops_not_side_effect_free() { + use crate::ir::UnOp; + // Float-to-integer truncations trap on NaN or out-of-range values. + for op in [ + UnOp::I32TruncF32S, + UnOp::I32TruncF32U, + UnOp::I32TruncF64S, + UnOp::I32TruncF64U, + UnOp::I64TruncF32S, + UnOp::I64TruncF32U, + UnOp::I64TruncF64S, + UnOp::I64TruncF64U, + ] { + let instr = IrInstr::UnOp { + dest: VarId(0), + op, + operand: VarId(1), + }; + assert!( + !is_side_effect_free(&instr), + "{op:?} should NOT be side-effect-free" + ); + } + // Non-trapping UnOp remains side-effect-free. 
+ assert!(is_side_effect_free(&IrInstr::UnOp { + dest: VarId(0), + op: UnOp::I32Clz, + operand: VarId(1), + })); + } + + #[test] + fn terminator_successors_coverage() { + assert_eq!( + terminator_successors(&IrTerminator::Return { value: None }), + vec![] + ); + assert_eq!( + terminator_successors(&IrTerminator::Jump { target: BlockId(3) }), + vec![BlockId(3)] + ); + assert_eq!( + terminator_successors(&IrTerminator::BranchIf { + condition: VarId(0), + if_true: BlockId(1), + if_false: BlockId(2), + }), + vec![BlockId(1), BlockId(2)] + ); + } + + #[test] + fn build_predecessors_simple() { + let func = IrFunction { + params: vec![], + locals: vec![], + blocks: vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ], + entry_block: BlockId(0), + return_type: None, + type_idx: crate::ir::TypeIdx::new(0), + needs_host: false, + }; + let preds = build_predecessors(&func); + assert!(preds[&BlockId(0)].is_empty()); + assert_eq!(preds[&BlockId(1)], HashSet::from([BlockId(0)])); + } + + #[test] + fn replace_uses_of_substitutes_correctly() { + let mut instr = IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(0), + }; + replace_uses_of(&mut instr, VarId(0), VarId(5)); + match &instr { + IrInstr::BinOp { lhs, rhs, .. } => { + assert_eq!(*lhs, VarId(5)); + assert_eq!(*rhs, VarId(5)); + } + _ => panic!("expected BinOp"), + } + } + + #[test] + fn rewrite_terminator_target_works() { + let mut term = IrTerminator::BranchIf { + condition: VarId(0), + if_true: BlockId(1), + if_false: BlockId(2), + }; + rewrite_terminator_target(&mut term, BlockId(1), BlockId(5)); + match &term { + IrTerminator::BranchIf { + if_true, if_false, .. 
+ } => { + assert_eq!(*if_true, BlockId(5)); + assert_eq!(*if_false, BlockId(2)); + } + _ => panic!("expected BranchIf"), + } + } +} diff --git a/crates/herkos/tests/e2e.rs b/crates/herkos/tests/e2e.rs index 95efdf1..93b9cfc 100644 --- a/crates/herkos/tests/e2e.rs +++ b/crates/herkos/tests/e2e.rs @@ -102,9 +102,6 @@ fn test_constant_arithmetic() -> Result<()> { println!("Generated Rust code:\n{}", rust_code); assert!(rust_code.contains("pub fn func_0")); - assert!(rust_code.contains("10i32")); - assert!(rust_code.contains("20i32")); - assert!(rust_code.contains("wrapping_add")); assert!(rust_code.contains("return Ok(")); Ok(())