diff --git a/Cargo.toml b/Cargo.toml index 15977fc..57725ff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,7 +34,7 @@ missing_const_for_fn = "deny" trivially_copy_pass_by_ref = "deny" cast_possible_truncation = "deny" explicit_iter_loop = "deny" -wildcard_enum_match_arm = "deny" +wildcard_enum_match_arm = "allow" indexing_slicing = "deny" self_named_module_files = "deny" precedence_bits = "deny" diff --git a/cli/src/main.rs b/cli/src/main.rs index acf4dfb..1f33bb1 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -49,7 +49,7 @@ fn stage(file: &PathBuf) -> Result<()> { checker::elaborate_program(&core_arena, &program).context("failed to elaborate program")?; drop(src_arena); - // Unstage into out_arena; src_arena and core_arena are no longer needed. + // Unstage into out_arena; core_arena is no longer needed after this. let out_arena = bumpalo::Bump::new(); let staged = eval::unstage_program(&out_arena, &core_program).context("failed to stage program")?; diff --git a/compiler/src/checker/mod.rs b/compiler/src/checker/mod.rs index dd36320..c5b0aa3 100644 --- a/compiler/src/checker/mod.rs +++ b/compiler/src/checker/mod.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use anyhow::{Context as _, Result, anyhow, bail, ensure}; -use crate::core::{self, IntType, IntWidth, Lvl, Prim}; +use crate::core::{self, IntType, IntWidth, Lam, Lvl, Pi, Prim, alpha_eq, subst}; use crate::parser::ast::{self, Phase}; /// Elaboration context. @@ -19,15 +19,16 @@ pub struct Ctx<'core, 'globals> { /// Local variables: (source name, core type) /// Indexed by De Bruijn level (0 = outermost in current scope, len-1 = most recent) locals: Vec<(&'core str, &'core core::Term<'core>)>, - /// Global function signatures: name -> signature. + /// Global function types: name -> Pi term. + /// Storing `&Term` (always a Pi) unifies type lookup for globals and locals. /// Borrowed independently of the arena so the map can live on the stack. - globals: &'globals HashMap, core::FunSig<'core>>, + globals: &'globals HashMap, &'core core::Term<'core>>, } impl<'core, 'globals> Ctx<'core, 'globals> { pub const fn new( arena: &'core bumpalo::Bump, - globals: &'globals HashMap, core::FunSig<'core>>, + globals: &'globals HashMap, &'core core::Term<'core>>, ) -> Self { Ctx { arena, @@ -110,7 +111,9 @@ impl<'core, 'globals> Ctx<'core, 'globals> { // Primitive types inhabit the relevant universe. core::Term::Prim(Prim::IntTy(it)) => core::Term::universe(it.phase), // Type, VmType, and [[T]] all inhabit Type (meta universe). - core::Term::Prim(Prim::U(_)) | core::Term::Lift(_) => &core::Term::TYPE, + core::Term::Prim(Prim::U(_)) | core::Term::Lift(_) | core::Term::Pi(_) => { + &core::Term::TYPE + } // Comparison ops return u1 at the operand phase. core::Term::Prim( @@ -127,15 +130,18 @@ impl<'core, 'globals> Ctx<'core, 'globals> { self.alloc(core::Term::Lift(core::Term::int_ty(*w, Phase::Object))) } - // Application: return type comes from the head. - core::Term::App(app) => match &app.head { - core::Head::Global(name) => { - self.globals - .get(name) - .expect("App/Global with unknown name (typechecker invariant)") - .ret_ty - } - core::Head::Prim(p) => match *p { + // Global reference: look up its Pi type directly from the globals table. + core::Term::Global(name) => self + .globals + .get(name) + .copied() + .expect("Global with unknown name (typechecker invariant)"), + + // App: dispatch on func. + // - Prim callee: return type is determined by the primitive. + // - Other callee: peel Pi types, substituting each arg. + core::Term::App(app) => match app.func { + core::Term::Prim(prim) => match prim { Prim::Add(it) | Prim::Sub(it) | Prim::Mul(it) @@ -150,14 +156,47 @@ impl<'core, 'globals> Ctx<'core, 'globals> { | Prim::Le(it) | Prim::Ge(it) => core::Term::u1_ty(it.phase), Prim::Embed(w) => { - self.alloc(core::Term::Lift(core::Term::int_ty(w, Phase::Object))) + self.alloc(core::Term::Lift(core::Term::int_ty(*w, Phase::Object))) } Prim::IntTy(_) | Prim::U(_) => { - unreachable!("type-level prim in App head (typechecker invariant)") + unreachable!("type-level prim in App (typechecker invariant)") } }, + _ => { + // Global function signatures are elaborated in an empty context, + // so the i-th Pi binder is at De Bruijn level (base_depth + i) + // where base_depth counts args already applied by outer Apps. + let base_depth = app_base_depth(app.func); + let func_ty = self.type_of(app.func); + match func_ty { + core::Term::Pi(pi) => { + // Substitute each arg for its corresponding Pi param. + let mut result = pi.body_ty; + for (i, arg) in app.args.iter().enumerate() { + result = subst(self.arena, result, Lvl(base_depth + i), arg); + } + result + } + _ => unreachable!( + "App func must have Pi type (typechecker invariant)" + ), + } + } }, + // Lam: synthesise Pi from params and body type. + core::Term::Lam(lam) => { + for &(name, ty) in lam.params { + self.push_local(name, ty); + } + let body_ty = self.type_of(lam.body); + for _ in lam.params { + self.pop_local(); + } + let params = self.alloc_slice(lam.params.iter().copied()); + self.alloc(core::Term::Pi(Pi { params, body_ty, phase: Phase::Meta })) + } + // #(t) : [[type_of(t)]] core::Term::Quote(inner) => { let inner_ty = self.type_of(inner); @@ -172,7 +211,10 @@ impl<'core, 'globals> Ctx<'core, 'globals> { core::Term::Var(_) | core::Term::Prim(_) | core::Term::Lit(..) + | core::Term::Global(_) | core::Term::App(_) + | core::Term::Pi(_) + | core::Term::Lam(_) | core::Term::Quote(_) | core::Term::Splice(_) | core::Term::Let(_) @@ -226,10 +268,11 @@ fn builtin_prim_ty(name: &str, phase: Phase) -> Option<&'static core::Term<'stat }) } +/// Elaborate one function's signature into a `Term::Pi` (the globals table entry). fn elaborate_sig<'src, 'core>( arena: &'core bumpalo::Bump, func: &ast::Function<'src>, -) -> Result> { +) -> Result<&'core core::Term<'core>> { let empty_globals = HashMap::new(); let mut ctx = Ctx::new(arena, &empty_globals); @@ -237,28 +280,24 @@ fn elaborate_sig<'src, 'core>( arena.alloc_slice_try_fill_iter(func.params.iter().map(|p| -> Result<_> { let param_name: &'core str = arena.alloc_str(p.name.as_str()); let param_ty = infer(&mut ctx, func.phase, p.ty)?; + ctx.push_local(param_name, param_ty); Ok((param_name, param_ty)) }))?; - let ret_ty = infer(&mut ctx, func.phase, func.ret_ty)?; + let body_ty = infer(&mut ctx, func.phase, func.ret_ty)?; - Ok(core::FunSig { - params, - ret_ty, - phase: func.phase, - }) + Ok(arena.alloc(core::Term::Pi(Pi { params, body_ty, phase: func.phase }))) } /// Pass 1: collect all top-level function signatures into a globals table. /// -/// Type annotations on parameters and return types are elaborated here so that -/// pass 2 (body elaboration) has fully-typed signatures available for all -/// functions, including forward references. +/// Each entry is a `Term::Pi` carrying the function's phase, param types, and return type. +/// This allows pass 2 to look up a global's type the same way it looks up a local's type. pub(crate) fn collect_signatures<'src, 'core>( arena: &'core bumpalo::Bump, program: &ast::Program<'src>, -) -> Result, core::FunSig<'core>>> { - let mut globals: HashMap, core::FunSig<'core>> = HashMap::new(); +) -> Result, &'core core::Term<'core>>> { + let mut globals: HashMap, &'core core::Term<'core>> = HashMap::new(); for func in program.functions { let name = core::Name::new(arena.alloc_str(func.name.as_str())); @@ -268,9 +307,9 @@ pub(crate) fn collect_signatures<'src, 'core>( "duplicate function name `{name}`" ); - let sig = elaborate_sig(arena, func).with_context(|| format!("in function `{name}`"))?; + let ty = elaborate_sig(arena, func).with_context(|| format!("in function `{name}`"))?; - globals.insert(name, sig); + globals.insert(name, ty); } Ok(globals) @@ -280,34 +319,30 @@ pub(crate) fn collect_signatures<'src, 'core>( fn elaborate_bodies<'src, 'core>( arena: &'core bumpalo::Bump, program: &ast::Program<'src>, - globals: &HashMap, core::FunSig<'core>>, + globals: &HashMap, &'core core::Term<'core>>, ) -> Result> { let functions: &'core [core::Function<'core>] = arena.alloc_slice_try_fill_iter(program.functions.iter().map(|func| -> Result<_> { let name = core::Name::new(arena.alloc_str(func.name.as_str())); - let ast_sig = globals.get(&name).expect("signature missing from pass 1"); + let ty = *globals.get(&name).expect("signature missing from pass 1"); + let pi = match ty { + core::Term::Pi(pi) => pi, + _ => unreachable!("globals table must contain Pi types"), + }; // Build a fresh context borrowing the stack-owned globals map. let mut ctx = Ctx::new(arena, globals); // Push parameters as locals so the body can reference them. - for (pname, pty) in ast_sig.params { + for (pname, pty) in pi.params { ctx.push_local(pname, pty); } // Elaborate the body, checking it against the declared return type. - let body = check(&mut ctx, ast_sig.phase, func.body, ast_sig.ret_ty) + let body = check(&mut ctx, pi.phase, func.body, pi.body_ty) .with_context(|| format!("in function `{name}`"))?; - // Re-borrow sig from globals (ctx was consumed in the check above). - // We need the sig fields for the Function; collect them before moving ctx. - let sig = core::FunSig { - params: ast_sig.params, - ret_ty: ast_sig.ret_ty, - phase: ast_sig.phase, - }; - - Ok(core::Function { name, sig, body }) + Ok(core::Function { name, ty, body }) }))?; Ok(core::Program { functions }) @@ -329,26 +364,59 @@ pub fn elaborate_program<'core>( /// - `U(Meta)` (Type) inhabits `U(Meta)` (type-in-type for the meta universe) /// - `U(Object)` (`VmType`) inhabits `U(Meta)` (the meta universe classifies object types) /// - `Lift(_)` inhabits `U(Meta)` -const fn type_universe(ty: &core::Term<'_>) -> Option { +/// - `Pi` inhabits `U(Meta)` (function types are meta-level) +/// - `Var(lvl)` — look up the variable's type in `locals`; if it is `U(p)`, it is a type in `p` +fn type_universe<'core>( + ty: &core::Term<'_>, + locals: &[(&'core str, &'core core::Term<'core>)], +) -> Option { match ty { core::Term::Prim(Prim::IntTy(IntType { phase, .. })) => Some(*phase), - core::Term::Prim(Prim::U(_)) | core::Term::Lift(_) => Some(Phase::Meta), - core::Term::Var(_) - | core::Term::Prim(_) + core::Term::Prim(Prim::U(_)) | core::Term::Lift(_) | core::Term::Pi(_) => Some(Phase::Meta), + // A type variable: its universe is determined by what universe its type inhabits. + // E.g. if `A : Type` (= U(Meta)), then A is a meta-level type. + core::Term::Var(lvl) => match locals.get(lvl.0)?.1 { + core::Term::Prim(Prim::U(phase)) => Some(*phase), + core::Term::Var(_) + | core::Term::Prim(_) + | core::Term::Lit(..) + | core::Term::Global(_) + | core::Term::App(_) + | core::Term::Pi(_) + | core::Term::Lam(_) + | core::Term::Lift(_) + | core::Term::Quote(_) + | core::Term::Splice(_) + | core::Term::Let(_) + | core::Term::Match(_) => None, + }, + core::Term::Prim(_) | core::Term::Lit(..) - | core::Term::App { .. } + | core::Term::Global(_) + | core::Term::App(_) + | core::Term::Lam(_) | core::Term::Quote(_) | core::Term::Splice(_) - | core::Term::Let { .. } - | core::Term::Match { .. } => None, + | core::Term::Let(_) + | core::Term::Match(_) => None, } } -/// Structural equality of core types (no normalisation needed for this prototype). +/// Type equality: alpha-equality (ignores param names in Pi/Lam). fn types_equal(a: &core::Term<'_>, b: &core::Term<'_>) -> bool { - // Uses pointer equality as a fast path — terms allocated from the same arena - // slot are guaranteed identical without recursion. - std::ptr::eq(a, b) || a == b + alpha_eq(a, b) +} + +/// Count the total number of arguments already applied by nested `App` nodes. +/// +/// Used to determine which Pi binder level to target during dependent-return-type +/// substitution: global function signatures are elaborated in an empty context, so the +/// binder introduced by the i-th Pi in the chain sits at De Bruijn level i. +fn app_base_depth(term: &core::Term<'_>) -> usize { + match term { + core::Term::App(app) => app_base_depth(app.func) + app.args.len(), + _ => 0, + } } /// Synthesise and return the elaborated core term; recover its type via `ctx.type_of`. @@ -374,11 +442,16 @@ pub fn infer<'src, 'core>( } return Ok(term); } - // Otherwise look in locals. - let (lvl, _) = ctx - .lookup_local(name_str) - .ok_or_else(|| anyhow!("unbound variable `{name_str}`"))?; - Ok(ctx.alloc(core::Term::Var(lvl))) + // Check locals. + if let Some((lvl, _)) = ctx.lookup_local(name_str) { + return Ok(ctx.alloc(core::Term::Var(lvl))); + } + // Check globals — bare reference without call, produces Global term. + let core_name = core::Name::new(ctx.arena.alloc_str(name_str)); + if ctx.globals.contains_key(&core_name) { + return Ok(ctx.alloc(core::Term::Global(core_name))); + } + Err(anyhow!("unbound variable `{name_str}`")) } // ------------------------------------------------------------------ Lit @@ -387,53 +460,59 @@ pub fn infer<'src, 'core>( "cannot infer type of a literal; add a type annotation" )), - // ------------------------------------------------------------------ App { Global } - // Look up the callee in globals, check each argument, return the return type. + // ------------------------------------------------------------------ App { Global or local } + // Function calls: look up callee, elaborate as curried FunApp chain. ast::Term::App { - func: ast::FunName::Name(name), + func: ast::FunName::Term(func_term), args, } => { - let sig = ctx - .globals - .get(name) - .ok_or_else(|| anyhow!("unknown function `{name}`"))?; - - // The call phase must match the current elaboration phase. - let call_phase = sig.phase; - ensure!( - call_phase == phase, - "function `{name}` is a {call_phase}-phase function, but called in {phase}-phase context" - ); - let params = sig.params; + // Elaborate the callee + let callee = infer(ctx, phase, func_term)?; + let callee_ty = ctx.type_of(callee); + + // Callee type must be Pi; arity must match. + let pi = match callee_ty { + core::Term::Pi(pi) => pi, + _ => bail!("callee is not a function type"), + }; + // For globals, verify phase matches (phase is now carried on the Pi itself). + if let core::Term::Global(gname) = callee { + ensure!( + pi.phase == phase, + "function `{gname}` is a {}-phase function, but called in {phase}-phase context", + pi.phase + ); + } ensure!( - args.len() == params.len(), - "function `{name}` expects {} argument(s), got {}", - params.len(), + args.len() == pi.params.len(), + "wrong number of arguments: callee expects {}, got {}", + pi.params.len(), args.len() ); - // Check each argument against its declared parameter type. - let core_args: &'core [&'core core::Term<'core>] = ctx - .arena - .alloc_slice_try_fill_iter(args.iter().zip(params.iter()).map( - |(arg, (pname, pty))| -> Result<_> { - let core_arg = check(ctx, call_phase, arg, pty) - .with_context(|| format!("in call to '{name}' argument '{pname}'"))?; - Ok(core_arg) - }, - ))?; + // Check each arg against its Pi param type. + // Global sigs are elaborated in an empty context, so param i is at De Bruijn level i. + // For dependent types, substitute earlier args into later param types. + let base = app_base_depth(callee); + let mut core_args: Vec<&'core core::Term<'core>> = Vec::with_capacity(args.len()); + for (i, (arg, &(_, mut param_ty))) in + args.iter().zip(pi.params.iter()).enumerate() + { + for (j, &earlier_arg) in core_args.iter().enumerate() { + param_ty = subst(ctx.arena, param_ty, Lvl(base + j), earlier_arg); + } + let core_arg = check(ctx, phase, arg, param_ty) + .with_context(|| format!("in argument {i} of function call"))?; + core_args.push(core_arg); + } - Ok(ctx.alloc(core::Term::new_app( - core::Head::Global(core::Name::new(ctx.arena.alloc_str(name.as_str()))), - core_args, - ))) + let args_slice = ctx.alloc_slice(core_args); + Ok(ctx.alloc(core::Term::new_app(callee, args_slice))) } // ------------------------------------------------------------------ App { Prim (BinOp/UnOp) } - // Arithmetic/bitwise ops are check-only (width comes from expected type). - // Comparison ops are inferable: they always return u1, and the operand type - // is inferred from the first argument (the second is checked to match). + // Comparison ops are inferable: they always return u1. ast::Term::App { func: ast::FunName::BinOp(op), args, @@ -452,25 +531,24 @@ pub fn infer<'src, 'core>( bail!("binary operation expects exactly 2 arguments") }; - // Infer the operand type from the first argument. let core_arg0 = infer(ctx, phase, lhs)?; let operand_ty = ctx.type_of(core_arg0); - // Check the second argument against the same operand type. let core_arg1 = check(ctx, phase, rhs, operand_ty)?; - // Verify both operands are integers and build the prim carrying the operand type. let op_int_ty = match operand_ty { core::Term::Prim(Prim::IntTy(it)) => *it, core::Term::Var(_) | core::Term::Prim(_) | core::Term::Lit(..) + | core::Term::Global(_) | core::Term::App(_) + | core::Term::Pi(_) + | core::Term::Lam(_) | core::Term::Lift(_) | core::Term::Quote(_) | core::Term::Splice(_) | core::Term::Let(_) | core::Term::Match(_) => { - ensure!(false, "comparison operands must be integers"); - unreachable!() + bail!("comparison operands must be integers"); } }; let prim = match op { @@ -488,7 +566,10 @@ pub fn infer<'src, 'core>( | BinOp::BitOr => unreachable!(), }; let core_args = ctx.alloc_slice([core_arg0, core_arg1]); - Ok(ctx.alloc(core::Term::new_app(core::Head::Prim(prim), core_args))) + Ok(ctx.alloc(core::Term::new_app( + ctx.alloc(core::Term::Prim(prim)), + core_args, + ))) } ast::Term::App { func: ast::FunName::BinOp(_) | ast::FunName::UnOp(_), @@ -497,17 +578,76 @@ pub fn infer<'src, 'core>( "cannot infer type of a primitive operation; add a type annotation" )), + // ------------------------------------------------------------------ Pi + // Function type expression: elaborate each param type, push locals, elaborate body type. + ast::Term::Pi { params, ret_ty } => { + ensure!( + phase == Phase::Meta, + "function types are only valid in meta-phase context" + ); + let depth_before = ctx.depth(); + + let mut elaborated_params: Vec<(&'core str, &'core core::Term<'core>)> = Vec::new(); + for p in *params { + let param_name: &'core str = ctx.arena.alloc_str(p.name.as_str()); + let param_ty = infer(ctx, Phase::Meta, p.ty)?; + ensure!( + type_universe(param_ty, &ctx.locals).is_some(), + "parameter type must be a type" + ); + elaborated_params.push((param_name, param_ty)); + ctx.push_local(param_name, param_ty); + } + + let core_ret_ty = infer(ctx, Phase::Meta, ret_ty)?; + ensure!( + type_universe(core_ret_ty, &ctx.locals).is_some(), + "return type must be a type" + ); + + for _ in &elaborated_params { + ctx.pop_local(); + } + assert_eq!(ctx.depth(), depth_before, "Pi elaboration leaked locals"); + let params_slice = ctx.alloc_slice(elaborated_params); + Ok(ctx.alloc(core::Term::Pi(Pi { params: params_slice, body_ty: core_ret_ty, phase: Phase::Meta }))) + } + + // ------------------------------------------------------------------ Lam (infer mode) + // Lambda with mandatory type annotations — inferable. + ast::Term::Lam { params, body } => { + ensure!( + phase == Phase::Meta, + "lambdas are only valid in meta-phase context" + ); + + let depth_before = ctx.depth(); + let mut elaborated_params: Vec<(&'core str, &'core core::Term<'core>)> = Vec::new(); + + for p in *params { + let param_name: &'core str = ctx.arena.alloc_str(p.name.as_str()); + let param_ty = infer(ctx, Phase::Meta, p.ty)?; + elaborated_params.push((param_name, param_ty)); + ctx.push_local(param_name, param_ty); + } + + let core_body = infer(ctx, phase, body)?; + + for _ in &elaborated_params { + ctx.pop_local(); + } + assert_eq!(ctx.depth(), depth_before, "Lam elaboration leaked locals"); + let params_slice = ctx.alloc_slice(elaborated_params); + Ok(ctx.alloc(core::Term::Lam(Lam { params: params_slice, body: core_body }))) + } + // ------------------------------------------------------------------ Lift - // `[[T]]` — elaborate T at the object phase, type is Type (meta universe). ast::Term::Lift(inner) => { - // Lift is only legal in meta phase. ensure!( phase == Phase::Meta, "`[[...]]` is only valid in a meta-phase context" ); - // The inner expression must be an object type. let core_inner = infer(ctx, Phase::Object, inner)?; - // Verify the inner term is indeed a type (inhabits VmType). ensure!( types_equal(ctx.type_of(core_inner), &core::Term::VM_TYPE), "argument of `[[...]]` must be an object type" @@ -516,9 +656,7 @@ pub fn infer<'src, 'core>( } // ------------------------------------------------------------------ Quote - // `#(t)` — infer iff the inner term is inferable (phase shifts meta→object). ast::Term::Quote(inner) => { - // Quote is only legal in meta phase. ensure!( phase == Phase::Meta, "`#(...)` is only valid in a meta-phase context" @@ -528,11 +666,7 @@ pub fn infer<'src, 'core>( } // ------------------------------------------------------------------ Splice - // `$(t)` — infer iff `t` infers as `[[T]]`; result type is `T` (phase shifts object→meta). - // If `t` infers as a meta integer `IntTy(w, Meta)`, insert an implicit `Embed(w)` - // to produce `[[IntTy(w, Object)]]` before splicing. ast::Term::Splice(inner) => { - // Splice is only legal in object phase. ensure!( phase == Phase::Object, "`$(...)` is only valid in an object-phase context" @@ -541,14 +675,12 @@ pub fn infer<'src, 'core>( let inner_ty = ctx.type_of(core_inner); match inner_ty { core::Term::Lift(_) => Ok(ctx.alloc(core::Term::Splice(core_inner))), - // A meta-level integer is implicitly embedded: insert Embed(w) so that - // the splice argument has type `[[IntTy(w, Object)]]`. core::Term::Prim(Prim::IntTy(IntType { width, phase: Phase::Meta, })) => { let embedded = ctx.alloc(core::Term::new_app( - core::Head::Prim(Prim::Embed(*width)), + ctx.alloc(core::Term::Prim(Prim::Embed(*width))), ctx.alloc_slice([core_inner]), )); Ok(ctx.alloc(core::Term::Splice(embedded))) @@ -556,7 +688,10 @@ pub fn infer<'src, 'core>( core::Term::Var(_) | core::Term::Prim(_) | core::Term::Lit(..) + | core::Term::Global(_) | core::Term::App(_) + | core::Term::Pi(_) + | core::Term::Lam(_) | core::Term::Quote(_) | core::Term::Splice(_) | core::Term::Let(_) @@ -567,18 +702,14 @@ pub fn infer<'src, 'core>( } // ------------------------------------------------------------------ Block (Let*) - // Elaborate each `let` binding in sequence, then the trailing expression. ast::Term::Block { stmts, expr } => { let depth_before = ctx.depth(); let result = infer_block(ctx, phase, stmts, expr); - // Each let-binding is responsible for pushing and popping its own local - // (via `elaborate_let`), so the depth must be restored exactly. assert_eq!(ctx.depth(), depth_before, "infer_block leaked locals"); result } // ------------------------------------------------------------------ Match - // Without an expected type, match is not inferable — require an annotation. ast::Term::Match { .. } => Err(anyhow!( "cannot infer type of match expression; add a type annotation or use in a \ checked position" @@ -587,14 +718,7 @@ pub fn infer<'src, 'core>( } /// Check exhaustiveness of `arms` given the scrutinee type `scrut_ty`. -/// -/// Returns `Err` if coverage cannot be established. fn check_exhaustiveness(scrut_ty: &core::Term<'_>, arms: &[ast::MatchArm<'_>]) -> Result<()> { - // For u0/u1/u8 scrutinees we track which literal values have been covered - // using a Vec of length 1/2/256 respectively. If all entries become - // true the match is exhaustive even without a wildcard. For any other type - // (u16/u32/u64) we only accept a wildcard or bind-all arm as evidence of - // exhaustiveness, since enumerating every value is impractical. let mut covered_lits: Option> = match scrut_ty { core::Term::Prim(Prim::IntTy(ty)) => match ty.width { IntWidth::U0 => Some(vec![false; 1]), @@ -605,12 +729,15 @@ fn check_exhaustiveness(scrut_ty: &core::Term<'_>, arms: &[ast::MatchArm<'_>]) - core::Term::Var(_) | core::Term::Prim(_) | core::Term::Lit(..) - | core::Term::App { .. } + | core::Term::Global(_) + | core::Term::App(_) + | core::Term::Pi(_) + | core::Term::Lam(_) | core::Term::Lift(_) | core::Term::Quote(_) | core::Term::Splice(_) - | core::Term::Let { .. } - | core::Term::Match { .. } => None, + | core::Term::Let(_) + | core::Term::Match(_) => None, }; let mut has_catch_all = false; @@ -639,7 +766,6 @@ fn check_exhaustiveness(scrut_ty: &core::Term<'_>, arms: &[ast::MatchArm<'_>]) - } /// Elaborate a match pattern into a core pattern. -/// Any bound name can be recovered via `core::Pat::bound_name()`. fn elaborate_pat<'core>(ctx: &Ctx<'core, '_>, pat: &ast::Pat<'_>) -> core::Pat<'core> { match pat { ast::Pat::Lit(n) => core::Pat::Lit(*n), @@ -655,15 +781,7 @@ fn elaborate_pat<'core>(ctx: &Ctx<'core, '_>, pat: &ast::Pat<'_>) -> core::Pat<' } } -/// Elaborate a single `let` binding: resolve the binding type, elaborate the -/// initialiser, push the local into the context, call `cont`, then pop and -/// assemble `core::Term::Let`. -/// -/// `cont` receives the extended context and returns any result `T`. A -/// `body_of` accessor is used to extract the body term (needed to build the -/// `Let` node) from `T`, and a `wrap` function replaces the body in `T` with -/// the finished `Let` node — letting the caller thread arbitrary extra data -/// (e.g. the inferred type) through without any dummy pairs. +/// Elaborate a single `let` binding. fn elaborate_let<'src, 'core, T, F, G, W>( ctx: &mut Ctx<'core, '_>, phase: Phase, @@ -677,7 +795,6 @@ where G: FnOnce(&T) -> &'core core::Term<'core>, W: FnOnce(&'core core::Term<'core>, T) -> T, { - // Determine the binding type: use annotation if present, otherwise infer. let (core_expr, bind_ty) = if let Some(ann) = stmt.ty { let ty = infer(ctx, phase, ann)?; let core_e = check(ctx, phase, stmt.expr, ty) @@ -752,9 +869,7 @@ pub fn check<'src, 'core>( expected: &'core core::Term<'core>, ) -> Result<&'core core::Term<'core>> { // Verify `expected` inhabits the correct universe for the current phase. - // Every `expected` originates from `elaborate_ty` or from `infer`, both of which - // only produce `IntTy`, `U`, or `Lift` — so `None` here is an internal compiler bug. - let ty_phase = type_universe(expected) + let ty_phase = type_universe(expected, &ctx.locals) .expect("expected type passed to `check` is not a well-formed type expression"); ensure!( ty_phase == phase, @@ -763,7 +878,6 @@ pub fn check<'src, 'core>( ); match term { // ------------------------------------------------------------------ Lit - // Literals check against any integer type. ast::Term::Lit(n) => match expected { core::Term::Prim(Prim::IntTy(it)) => { let width = it.width; @@ -776,20 +890,19 @@ pub fn check<'src, 'core>( core::Term::Var(_) | core::Term::Prim(_) | core::Term::Lit(..) - | core::Term::App { .. } + | core::Term::Global(_) + | core::Term::App(_) + | core::Term::Pi(_) + | core::Term::Lam(_) | core::Term::Lift(_) | core::Term::Quote(_) | core::Term::Splice(_) - | core::Term::Let { .. } - | core::Term::Match { .. } => { - Err(anyhow!("literal `{n}` cannot have a non-integer type")) - } + | core::Term::Let(_) + | core::Term::Match(_) => Err(anyhow!("literal `{n}` cannot have a non-integer type")), }, // ------------------------------------------------------------------ App { Prim (BinOp) } // Width is resolved from the expected type. - // Comparison ops (Eq/Ne/Lt/Gt/Le/Ge) are handled in infer mode and fall through - // to infer+unify below, since they always return u1 (inferable). ast::Term::App { func: ast::FunName::BinOp(op), args, @@ -808,12 +921,15 @@ pub fn check<'src, 'core>( core::Term::Var(_) | core::Term::Prim(_) | core::Term::Lit(..) - | core::Term::App { .. } + | core::Term::Global(_) + | core::Term::App(_) + | core::Term::Pi(_) + | core::Term::Lam(_) | core::Term::Lift(_) | core::Term::Quote(_) | core::Term::Splice(_) - | core::Term::Let { .. } - | core::Term::Match { .. } => { + | core::Term::Let(_) + | core::Term::Match(_) => { bail!("primitive operation requires an integer type") } }; @@ -839,7 +955,10 @@ pub fn check<'src, 'core>( let core_arg1 = check(ctx, phase, rhs, expected)?; let core_args = ctx.alloc_slice([core_arg0, core_arg1]); - Ok(ctx.alloc(core::Term::new_app(core::Head::Prim(prim), core_args))) + Ok(ctx.alloc(core::Term::new_app( + ctx.alloc(core::Term::Prim(prim)), + core_args, + ))) } // ------------------------------------------------------------------ App { UnOp } @@ -852,7 +971,10 @@ pub fn check<'src, 'core>( core::Term::Var(_) | core::Term::Prim(_) | core::Term::Lit(..) + | core::Term::Global(_) | core::Term::App(_) + | core::Term::Pi(_) + | core::Term::Lam(_) | core::Term::Lift(_) | core::Term::Quote(_) | core::Term::Splice(_) @@ -871,11 +993,13 @@ pub fn check<'src, 'core>( }; let core_arg = check(ctx, phase, arg, expected)?; let core_args = std::slice::from_ref(ctx.arena.alloc(core_arg)); - Ok(ctx.alloc(core::Term::new_app(core::Head::Prim(prim), core_args))) + Ok(ctx.alloc(core::Term::new_app( + ctx.alloc(core::Term::Prim(prim)), + core_args, + ))) } // ------------------------------------------------------------------ Quote (check mode) - // `#(t)` checked against `[[T]]` — check `t` against `T` at object phase. ast::Term::Quote(inner) => match expected { core::Term::Lift(obj_ty) => { let core_inner = check(ctx, Phase::Object, inner, obj_ty)?; @@ -884,7 +1008,10 @@ pub fn check<'src, 'core>( core::Term::Var(_) | core::Term::Prim(_) | core::Term::Lit(..) + | core::Term::Global(_) | core::Term::App(_) + | core::Term::Pi(_) + | core::Term::Lam(_) | core::Term::Quote(_) | core::Term::Splice(_) | core::Term::Let(_) @@ -894,21 +1021,11 @@ pub fn check<'src, 'core>( }, // ------------------------------------------------------------------ Splice (check mode) - // `$(e)` checked against `T` (object) — check `e` against `[[T]]` at meta phase. - // Mirror image of Quote: Quote unwraps `[[T]]` to check inner at object phase; - // Splice wraps `T` in `[[...]]` to check inner at meta phase. - // - // For object integer types `T = IntTy(w, Object)`, also accept `e : IntTy(w, Meta)` - // with an implicit `Embed(w)` insertion — the same coercion as the infer path. ast::Term::Splice(inner) => { ensure!( phase == Phase::Object, "`$(...)` is only valid in an object-phase context" ); - // For object integer expected types, first try the standard [[T]] path; if - // that fails, try the meta-integer embed path (inner has type IntTy(w, Meta)). - // Trying [[T]] first means a variable `x : [[u64]]` is always handled - // correctly and the embed path only activates when [[T]] genuinely fails. if let core::Term::Prim(Prim::IntTy(IntType { width, phase: Phase::Object, @@ -921,7 +1038,7 @@ pub fn check<'src, 'core>( let meta_int_ty = ctx.alloc(core::Term::Prim(Prim::IntTy(IntType::meta(*width)))); let core_inner = check(ctx, Phase::Meta, inner, meta_int_ty)?; let embedded = ctx.alloc(core::Term::new_app( - core::Head::Prim(Prim::Embed(*width)), + ctx.alloc(core::Term::Prim(Prim::Embed(*width))), ctx.arena.alloc_slice_fill_iter([core_inner]), )); return Ok(ctx.alloc(core::Term::Splice(embedded))); @@ -931,8 +1048,52 @@ pub fn check<'src, 'core>( Ok(ctx.alloc(core::Term::Splice(core_inner))) } + // ------------------------------------------------------------------ Lam (check mode) + // Check lambda against an expected Pi type. + ast::Term::Lam { params, body } => { + ensure!( + phase == Phase::Meta, + "lambdas are only valid in meta-phase context" + ); + + let depth_before = ctx.depth(); + + // Expected type must be a Pi with matching arity. + let pi = match expected { + core::Term::Pi(pi) => pi, + _ => bail!("expected a function type for this lambda"), + }; + ensure!( + params.len() == pi.params.len(), + "lambda has {} parameter(s) but expected type has {}", + params.len(), + pi.params.len() + ); + + let mut elaborated_params: Vec<(&'core str, &'core core::Term<'core>)> = Vec::new(); + for (p, &(_, pi_param_ty)) in params.iter().zip(pi.params.iter()) { + let param_name: &'core str = ctx.arena.alloc_str(p.name.as_str()); + let annotated_ty = infer(ctx, Phase::Meta, p.ty)?; + ensure!( + types_equal(annotated_ty, pi_param_ty), + "lambda parameter type mismatch: annotation gives a different type \ + than the expected function type" + ); + elaborated_params.push((param_name, pi_param_ty)); + ctx.push_local(param_name, pi_param_ty); + } + + let core_body = check(ctx, phase, body, pi.body_ty)?; + + for _ in &elaborated_params { + ctx.pop_local(); + } + assert_eq!(ctx.depth(), depth_before, "Lam check leaked locals"); + let params_slice = ctx.alloc_slice(elaborated_params); + Ok(ctx.alloc(core::Term::Lam(Lam { params: params_slice, body: core_body }))) + } + // ------------------------------------------------------------------ Match (check mode) - // Check each arm body against the expected type; the scrutinee is always inferred. ast::Term::Match { scrutinee, arms } => { let core_scrutinee = infer(ctx, phase, scrutinee)?; let scrut_ty = ctx.type_of(core_scrutinee); @@ -943,8 +1104,6 @@ pub fn check<'src, 'core>( ctx.arena .alloc_slice_try_fill_iter(arms.iter().map(|arm| -> Result<_> { let core_pat = elaborate_pat(ctx, &arm.pat); - // If the pattern binds a name, push it into locals for the arm body. - // We use a placeholder type (scrutinee type) — sufficient for the prototype. if let Some(bname) = core_pat.bound_name() { ctx.push_local(bname, scrut_ty); } @@ -966,19 +1125,15 @@ pub fn check<'src, 'core>( } // ------------------------------------------------------------------ Block (check mode) - // Thread the expected type down through let-bindings to the final expression. ast::Term::Block { stmts, expr } => { let depth_before = ctx.depth(); let result = check_block(ctx, phase, stmts, expr, expected); - // Each let-binding is responsible for pushing and popping its own local - // (via `elaborate_let`), so the depth must be restored exactly. assert_eq!(ctx.depth(), depth_before, "check_block leaked locals"); result } // ------------------------------------------------------------------ fallthrough: infer then unify - // For all other forms, infer the type and check it matches expected. - ast::Term::Var(_) | ast::Term::App { .. } | ast::Term::Lift(_) => { + ast::Term::Var(_) | ast::Term::App { .. } | ast::Term::Lift(_) | ast::Term::Pi { .. } => { let core_term = infer(ctx, phase, term)?; ensure!( types_equal(ctx.type_of(core_term), expected), diff --git a/compiler/src/checker/test/apply.rs b/compiler/src/checker/test/apply.rs index 80f7da2..50dcadb 100644 --- a/compiler/src/checker/test/apply.rs +++ b/compiler/src/checker/test/apply.rs @@ -8,11 +8,11 @@ fn infer_global_call_no_args_returns_ret_ty() { let src_arena = bumpalo::Bump::new(); let core_arena = bumpalo::Bump::new(); let mut globals = HashMap::new(); - globals.insert(Name::new("f"), sig_no_params_returns_u64()); + globals.insert(Name::new("f"), sig_no_params_returns_u64(&core_arena)); let mut ctx = test_ctx_with_globals(&core_arena, &globals); let term = src_arena.alloc(ast::Term::App { - func: FunName::Name(ast::Name::new("f")), + func: FunName::Term(src_arena.alloc(ast::Term::Var(ast::Name::new("f")))), args: &[], }); let result = infer(&mut ctx, Phase::Meta, term).expect("should infer"); @@ -34,7 +34,7 @@ fn infer_global_call_unknown_name_fails() { let mut ctx = test_ctx(&core_arena); let term = src_arena.alloc(ast::Term::App { - func: FunName::Name(ast::Name::new("unknown")), + func: FunName::Term(src_arena.alloc(ast::Term::Var(ast::Name::new("unknown")))), args: &[], }); assert!(infer(&mut ctx, Phase::Meta, term).is_err()); @@ -48,11 +48,11 @@ fn infer_global_call_wrong_arity_fails() { let extra_arg = src_arena.alloc(ast::Term::Lit(99)); let args = src_arena.alloc_slice_fill_iter([extra_arg as &ast::Term]); let mut globals = HashMap::new(); - globals.insert(Name::new("f"), sig_no_params_returns_u64()); + globals.insert(Name::new("f"), sig_no_params_returns_u64(&core_arena)); let mut ctx = test_ctx_with_globals(&core_arena, &globals); let term = src_arena.alloc(ast::Term::App { - func: FunName::Name(ast::Name::new("f")), + func: FunName::Term(src_arena.alloc(ast::Term::Var(ast::Name::new("f")))), args, }); assert!(infer(&mut ctx, Phase::Meta, term).is_err()); @@ -67,19 +67,13 @@ fn infer_global_call_phase_mismatch_fails() { // `code fn f() -> u64` — object-phase function let u64_obj = core_arena.alloc(core::Term::Prim(Prim::IntTy(IntType::U64_OBJ))); let mut globals = HashMap::new(); - globals.insert( - Name::new("f"), - FunSig { - params: &[], - ret_ty: u64_obj, - phase: Phase::Object, - }, - ); + let f_ty: &core::Term = core_arena.alloc(core::Term::Pi(Pi { params: &[], body_ty: u64_obj, phase: Phase::Object })); + globals.insert(Name::new("f"), f_ty); let mut ctx = test_ctx_with_globals(&core_arena, &globals); // Call `f()` from meta phase — should be rejected. let term = src_arena.alloc(ast::Term::App { - func: FunName::Name(ast::Name::new("f")), + func: FunName::Term(src_arena.alloc(ast::Term::Var(ast::Name::new("f")))), args: &[], }); assert!(infer(&mut ctx, Phase::Meta, term).is_err()); @@ -98,7 +92,7 @@ fn infer_global_call_with_arg_checks_arg_type() { let arg = src_arena.alloc(ast::Term::Lit(42)); let args = src_arena.alloc_slice_fill_iter([arg as &ast::Term]); let term = src_arena.alloc(ast::Term::App { - func: FunName::Name(ast::Name::new("f")), + func: FunName::Term(src_arena.alloc(ast::Term::Var(ast::Name::new("f")))), args, }); let result = infer(&mut ctx, Phase::Meta, term).expect("should infer"); @@ -140,7 +134,7 @@ fn check_binop_add_against_u32_succeeds() { assert!(matches!( result, core::Term::App(core::App { - head: Head::Prim(Prim::Add(IntType { + func: core::Term::Prim(Prim::Add(IntType { width: IntWidth::U32, .. })), @@ -173,7 +167,7 @@ fn infer_comparison_op_returns_u1() { assert!(matches!( core_term, core::Term::App(core::App { - head: Head::Prim(Prim::Eq(IntType { + func: core::Term::Prim(Prim::Eq(IntType { width: IntWidth::U64, .. })), @@ -283,7 +277,7 @@ fn check_eq_op_produces_u1() { assert!(matches!( result, core::Term::App(core::App { - head: Head::Prim(Prim::Eq(IntType { + func: core::Term::Prim(Prim::Eq(IntType { width: IntWidth::U64, .. })), diff --git a/compiler/src/checker/test/context.rs b/compiler/src/checker/test/context.rs index bb2954c..49acf6a 100644 --- a/compiler/src/checker/test/context.rs +++ b/compiler/src/checker/test/context.rs @@ -168,15 +168,10 @@ fn arithmetic_requires_expected_type() { fn global_call_is_inferable() { let arena = bumpalo::Bump::new(); let arg = arena.alloc(core::Term::Lit(1, IntType::U64_META)); - let args = &*arena.alloc_slice_fill_iter([&*arg]); - let app = arena.alloc(core::Term::new_app(Head::Global(Name::new("foo")), args)); - assert!(matches!( - app, - core::Term::App(core::App { - head: Head::Global(Name("foo")), - .. - }) - )); + let global = arena.alloc(core::Term::Global(Name::new("foo"))); + let args = &*arena.alloc_slice_fill_iter([arg as &core::Term]); + let app = arena.alloc(core::Term::new_app(global, args)); + assert!(matches!(app, core::Term::App(_))); } #[test] @@ -203,10 +198,7 @@ fn lift_type_structure() { #[test] fn quote_inference_mirrors_inner() { let arena = bumpalo::Bump::new(); - let inner = arena.alloc(core::Term::new_app( - Head::Global(Name::new("foo")), - arena.alloc_slice_fill_iter([] as [&core::Term; 0]), - )); + let inner = arena.alloc(core::Term::Global(Name::new("foo"))); let quoted = arena.alloc(core::Term::Quote(inner)); assert!(matches!(quoted, core::Term::Quote(_))); } @@ -276,16 +268,11 @@ fn match_with_binding_pattern() { fn function_call_to_global() { let arena = bumpalo::Bump::new(); let arg = arena.alloc(core::Term::Lit(42, IntType::U64_META)); - let args = &*arena.alloc_slice_fill_iter([&*arg]); - let app = arena.alloc(core::Term::new_app(Head::Global(Name::new("foo")), args)); + let global = arena.alloc(core::Term::Global(Name::new("foo"))); + let args = &*arena.alloc_slice_fill_iter([arg as &core::Term]); + let app = arena.alloc(core::Term::new_app(global, args)); - assert!(matches!( - app, - core::Term::App(core::App { - head: Head::Global(Name("foo")), - .. - }) - )); + assert!(matches!(app, core::Term::App(_))); } #[test] @@ -294,15 +281,13 @@ fn builtin_operation_call() { let arg1 = arena.alloc(core::Term::Lit(1, IntType::U64_OBJ)); let arg2 = arena.alloc(core::Term::Lit(2, IntType::U64_OBJ)); let args = &*arena.alloc_slice_fill_iter([&*arg1, &*arg2]); - let app = arena.alloc(core::Term::new_app( - Head::Prim(Prim::Add(IntType::U64_OBJ)), - args, - )); + let prim = arena.alloc(core::Term::Prim(Prim::Add(IntType::U64_OBJ))); + let app = arena.alloc(core::Term::new_app(prim, args)); assert!(matches!( app, core::Term::App(core::App { - head: Head::Prim(Prim::Add(IntType { + func: core::Term::Prim(Prim::Add(IntType { width: IntWidth::U64, .. })), diff --git a/compiler/src/checker/test/helpers.rs b/compiler/src/checker/test/helpers.rs index 4cd9ac0..0bff9e8 100644 --- a/compiler/src/checker/test/helpers.rs +++ b/compiler/src/checker/test/helpers.rs @@ -4,7 +4,7 @@ use super::*; /// Helper to create a test context with empty globals pub fn test_ctx(arena: &bumpalo::Bump) -> Ctx<'_, '_> { - static EMPTY: std::sync::OnceLock, core::FunSig<'static>>> = + static EMPTY: std::sync::OnceLock, &'static core::Term<'static>>> = std::sync::OnceLock::new(); let globals = EMPTY.get_or_init(HashMap::new); Ctx::new(arena, globals) @@ -15,29 +15,26 @@ pub fn test_ctx(arena: &bumpalo::Bump) -> Ctx<'_, '_> { /// The caller must ensure `globals` outlives the returned `Ctx`. pub fn test_ctx_with_globals<'core, 'globals>( arena: &'core bumpalo::Bump, - globals: &'globals HashMap, core::FunSig<'core>>, + globals: &'globals HashMap, &'core core::Term<'core>>, ) -> Ctx<'core, 'globals> { Ctx::new(arena, globals) } -/// Helper: build a simple `FunSig` for a function `fn f() -> u64` (no params, meta phase). -pub fn sig_no_params_returns_u64() -> FunSig<'static> { - let ret_ty = &core::Term::U64_META; - FunSig { +/// Helper: build a Pi term for a function `fn f() -> u64` (no params, meta phase). +pub fn sig_no_params_returns_u64(arena: &bumpalo::Bump) -> &core::Term<'_> { + arena.alloc(core::Term::Pi(Pi { params: &[], - ret_ty, + body_ty: &core::Term::U64_META, phase: Phase::Meta, - } + })) } -/// Helper: build a `FunSig` for `fn f(x: u32) -> u64`. -pub fn sig_one_param_returns_u64(core_arena: &bumpalo::Bump) -> FunSig<'_> { - let u32_ty = &core::Term::U32_META; - let u64_ty = &core::Term::U64_META; - let param = core_arena.alloc(("x", u32_ty as &core::Term)); - FunSig { - params: std::slice::from_ref(param), - ret_ty: u64_ty, +/// Helper: build a Pi term for `fn f(x: u32) -> u64`. +pub fn sig_one_param_returns_u64<'a>(arena: &'a bumpalo::Bump) -> &'a core::Term<'a> { + let params = arena.alloc_slice_fill_iter([("x", &core::Term::U32_META as &core::Term)]); + arena.alloc(core::Term::Pi(Pi { + params, + body_ty: &core::Term::U64_META, phase: Phase::Meta, - } + })) } diff --git a/compiler/src/checker/test/matching.rs b/compiler/src/checker/test/matching.rs index fafa5db..879b913 100644 --- a/compiler/src/checker/test/matching.rs +++ b/compiler/src/checker/test/matching.rs @@ -10,25 +10,20 @@ fn check_match_all_arms_same_type_succeeds() { let u32_ty_core = &core::Term::U32_META; let mut globals = HashMap::new(); - globals.insert( - Name::new("k32"), - FunSig { - params: &[], - ret_ty: u32_ty_core, - phase: Phase::Meta, - }, - ); + globals.insert(Name::new("k32"), core_arena.alloc(core::Term::Pi(Pi { + params: &[], body_ty: u32_ty_core, phase: Phase::Meta, + })) as &_); let mut ctx = test_ctx_with_globals(&core_arena, &globals); let u32_ty = &core::Term::U32_META; ctx.push_local("x", u32_ty); let scrutinee = src_arena.alloc(ast::Term::Var(ast::Name::new("x"))); let arm0_body = src_arena.alloc(ast::Term::App { - func: FunName::Name(ast::Name::new("k32")), + func: FunName::Term(src_arena.alloc(ast::Term::Var(ast::Name::new("k32")))), args: &[], }); let arm1_body = src_arena.alloc(ast::Term::App { - func: FunName::Name(ast::Name::new("k32")), + func: FunName::Term(src_arena.alloc(ast::Term::Var(ast::Name::new("k32")))), args: &[], }); let arms = src_arena.alloc_slice_fill_iter([ @@ -54,24 +49,19 @@ fn check_match_u1_fully_covered_succeeds() { let u1_ty_core = &core::Term::U1_META; let mut globals = HashMap::new(); - globals.insert( - Name::new("k1"), - FunSig { - params: &[], - ret_ty: u1_ty_core, - phase: Phase::Meta, - }, - ); + globals.insert(Name::new("k1"), core_arena.alloc(core::Term::Pi(Pi { + params: &[], body_ty: u1_ty_core, phase: Phase::Meta, + })) as &_); let mut ctx = test_ctx_with_globals(&core_arena, &globals); ctx.push_local("x", u1_ty_core); let scrutinee = src_arena.alloc(ast::Term::Var(ast::Name::new("x"))); let arm0_body = src_arena.alloc(ast::Term::App { - func: FunName::Name(ast::Name::new("k1")), + func: FunName::Term(src_arena.alloc(ast::Term::Var(ast::Name::new("k1")))), args: &[], }); let arm1_body = src_arena.alloc(ast::Term::App { - func: FunName::Name(ast::Name::new("k1")), + func: FunName::Term(src_arena.alloc(ast::Term::Var(ast::Name::new("k1")))), args: &[], }); // Both values of u1 are covered — exhaustive without a wildcard. @@ -98,20 +88,15 @@ fn infer_match_u1_partially_covered_fails() { let u1_ty_core = &core::Term::U1_META; let mut globals = HashMap::new(); - globals.insert( - Name::new("k1"), - FunSig { - params: &[], - ret_ty: u1_ty_core, - phase: Phase::Meta, - }, - ); + globals.insert(Name::new("k1"), core_arena.alloc(core::Term::Pi(Pi { + params: &[], body_ty: u1_ty_core, phase: Phase::Meta, + })) as &_); let mut ctx = test_ctx_with_globals(&core_arena, &globals); ctx.push_local("x", u1_ty_core); let scrutinee = src_arena.alloc(ast::Term::Var(ast::Name::new("x"))); let arm0_body = src_arena.alloc(ast::Term::App { - func: FunName::Name(ast::Name::new("k1")), + func: FunName::Term(src_arena.alloc(ast::Term::Var(ast::Name::new("k1")))), args: &[], }); // Only 0 covered, 1 is missing — not exhaustive. @@ -132,25 +117,20 @@ fn infer_match_no_catch_all_fails() { let u32_ty_core = &core::Term::U32_META; let mut globals = HashMap::new(); - globals.insert( - Name::new("k32"), - FunSig { - params: &[], - ret_ty: u32_ty_core, - phase: Phase::Meta, - }, - ); + globals.insert(Name::new("k32"), core_arena.alloc(core::Term::Pi(Pi { + params: &[], body_ty: u32_ty_core, phase: Phase::Meta, + })) as &_); let mut ctx = test_ctx_with_globals(&core_arena, &globals); let u32_ty = &core::Term::U32_META; ctx.push_local("x", u32_ty); let scrutinee = src_arena.alloc(ast::Term::Var(ast::Name::new("x"))); let arm0_body = src_arena.alloc(ast::Term::App { - func: FunName::Name(ast::Name::new("k32")), + func: FunName::Term(src_arena.alloc(ast::Term::Var(ast::Name::new("k32")))), args: &[], }); let arm1_body = src_arena.alloc(ast::Term::App { - func: FunName::Name(ast::Name::new("k32")), + func: FunName::Term(src_arena.alloc(ast::Term::Var(ast::Name::new("k32")))), args: &[], }); // Only literal arms, no wildcard/bind — not exhaustive. @@ -178,33 +158,23 @@ fn infer_match_arms_type_mismatch_fails() { let u32_ty_core = &core::Term::U32_META; let u64_ty_core = &core::Term::U64_META; let mut globals = HashMap::new(); - globals.insert( - Name::new("k32"), - FunSig { - params: &[], - ret_ty: u32_ty_core, - phase: Phase::Meta, - }, - ); - globals.insert( - Name::new("k64"), - FunSig { - params: &[], - ret_ty: u64_ty_core, - phase: Phase::Meta, - }, - ); + globals.insert(Name::new("k32"), core_arena.alloc(core::Term::Pi(Pi { + params: &[], body_ty: u32_ty_core, phase: Phase::Meta, + })) as &_); + globals.insert(Name::new("k64"), core_arena.alloc(core::Term::Pi(Pi { + params: &[], body_ty: u64_ty_core, phase: Phase::Meta, + })) as &_); let mut ctx = test_ctx_with_globals(&core_arena, &globals); let u32_ty = &core::Term::U32_META; ctx.push_local("x", u32_ty); let scrutinee = src_arena.alloc(ast::Term::Var(ast::Name::new("x"))); let arm0_body = src_arena.alloc(ast::Term::App { - func: FunName::Name(ast::Name::new("k32")), + func: FunName::Term(src_arena.alloc(ast::Term::Var(ast::Name::new("k32")))), args: &[], }); let arm1_body = src_arena.alloc(ast::Term::App { - func: FunName::Name(ast::Name::new("k64")), + func: FunName::Term(src_arena.alloc(ast::Term::Var(ast::Name::new("k64")))), args: &[], }); let arms = src_arena.alloc_slice_fill_iter([ diff --git a/compiler/src/checker/test/meta.rs b/compiler/src/checker/test/meta.rs index 8347af2..4018ab7 100644 --- a/compiler/src/checker/test/meta.rs +++ b/compiler/src/checker/test/meta.rs @@ -65,17 +65,13 @@ fn infer_quote_of_global_call_returns_lifted_type() { let mut globals = HashMap::new(); globals.insert( Name::new("f"), - FunSig { - params: &[], - ret_ty: u64_ty_core, - phase: Phase::Object, - }, + core_arena.alloc(core::Term::Pi(Pi { params: &[], body_ty: u64_ty_core, phase: Phase::Object })) as &_, ); let mut ctx = test_ctx_with_globals(&core_arena, &globals); // Surface: `#(f())` let inner = src_arena.alloc(ast::Term::App { - func: FunName::Name(ast::Name::new("f")), + func: FunName::Term(src_arena.alloc(ast::Term::Var(ast::Name::new("f")))), args: &[], }); let term = src_arena.alloc(ast::Term::Quote(inner)); @@ -98,16 +94,12 @@ fn infer_quote_at_object_phase_fails() { let mut globals = HashMap::new(); globals.insert( Name::new("f"), - FunSig { - params: &[], - ret_ty: u64_ty_core, - phase: Phase::Object, - }, + core_arena.alloc(core::Term::Pi(Pi { params: &[], body_ty: u64_ty_core, phase: Phase::Object })) as &_, ); let mut ctx = test_ctx_with_globals(&core_arena, &globals); let inner = src_arena.alloc(ast::Term::App { - func: FunName::Name(ast::Name::new("f")), + func: FunName::Term(src_arena.alloc(ast::Term::Var(ast::Name::new("f")))), args: &[], }); let term = src_arena.alloc(ast::Term::Quote(inner)); diff --git a/compiler/src/checker/test/mod.rs b/compiler/src/checker/test/mod.rs index 92d2afc..21dee3e 100644 --- a/compiler/src/checker/test/mod.rs +++ b/compiler/src/checker/test/mod.rs @@ -8,7 +8,7 @@ use std::collections::HashMap; use super::*; -use crate::core::{self, FunSig, Head, IntType, IntWidth, Name, Pat, Prim}; +use crate::core::{self, IntType, IntWidth, Name, Pat, Pi, Prim}; use crate::parser::ast::{self, BinOp, FunName, MatchArm, Phase}; mod helpers; diff --git a/compiler/src/checker/test/signatures.rs b/compiler/src/checker/test/signatures.rs index 015d651..ae5169a 100644 --- a/compiler/src/checker/test/signatures.rs +++ b/compiler/src/checker/test/signatures.rs @@ -50,42 +50,40 @@ fn collect_signatures_two_functions() { assert_eq!(globals.len(), 2); - let id_sig = globals - .get(&Name::new("id")) - .expect("id should be in globals"); - assert_eq!(id_sig.phase, Phase::Meta); - assert_eq!(id_sig.params.len(), 1); - assert_eq!(id_sig.params[0].0, "x"); + let id_ty = globals.get(&Name::new("id")).expect("id should be in globals"); + let core::Term::Pi(id_pi) = id_ty else { panic!("expected Pi") }; + assert_eq!(id_pi.phase, Phase::Meta); + assert_eq!(id_pi.params.len(), 1); + assert_eq!(id_pi.params[0].0, "x"); assert!(matches!( - id_sig.params[0].1, + id_pi.params[0].1, core::Term::Prim(Prim::IntTy(IntType { width: IntWidth::U32, .. })) )); assert!(matches!( - id_sig.ret_ty, + id_pi.body_ty, core::Term::Prim(Prim::IntTy(IntType { width: IntWidth::U32, .. })) )); - let add_sig = globals - .get(&Name::new("add_one")) - .expect("add_one should be in globals"); - assert_eq!(add_sig.phase, Phase::Object); - assert_eq!(add_sig.params.len(), 1); - assert_eq!(add_sig.params[0].0, "y"); + let add_ty = globals.get(&Name::new("add_one")).expect("add_one should be in globals"); + let core::Term::Pi(add_pi) = add_ty else { panic!("expected Pi") }; + assert_eq!(add_pi.phase, Phase::Object); + assert_eq!(add_pi.params.len(), 1); + assert_eq!(add_pi.params[0].0, "y"); assert!(matches!( - add_sig.params[0].1, + add_pi.params[0].1, core::Term::Prim(Prim::IntTy(IntType { width: IntWidth::U64, .. })) )); assert!(matches!( - add_sig.ret_ty, + add_pi.body_ty, core::Term::Prim(Prim::IntTy(IntType { width: IntWidth::U64, .. @@ -256,7 +254,7 @@ fn elaborate_program_code_fn_with_splice() { // pow0's body: $(k()) let k_call = src_arena.alloc(ast::Term::App { - func: FunName::Name(ast::Name::new("k")), + func: FunName::Term(src_arena.alloc(ast::Term::Var(ast::Name::new("k")))), args: &[], }); let pow0_body = src_arena.alloc(ast::Term::Splice(k_call)); @@ -296,7 +294,7 @@ fn elaborate_program_forward_reference_succeeds() { // fn a() -> u32 { b() } let a_body = src_arena.alloc(ast::Term::App { - func: FunName::Name(ast::Name::new("b")), + func: FunName::Term(src_arena.alloc(ast::Term::Var(ast::Name::new("b")))), args: &[], }); // fn b() -> u32 { 42 } diff --git a/compiler/src/core/alpha_eq.rs b/compiler/src/core/alpha_eq.rs new file mode 100644 index 0000000..7cff37a --- /dev/null +++ b/compiler/src/core/alpha_eq.rs @@ -0,0 +1,51 @@ +use super::Term; + +/// Alpha-equality: structural equality ignoring `param_name` fields in Pi/Lam. +pub fn alpha_eq(a: &Term<'_>, b: &Term<'_>) -> bool { + // Fast path: pointer equality + if std::ptr::eq(a, b) { + return true; + } + match (a, b) { + (Term::Var(l1), Term::Var(l2)) => l1 == l2, + (Term::Prim(p1), Term::Prim(p2)) => p1 == p2, + (Term::Lit(n1, t1), Term::Lit(n2, t2)) => n1 == n2 && t1 == t2, + (Term::Global(n1), Term::Global(n2)) => n1 == n2, + (Term::App(a1), Term::App(a2)) => { + alpha_eq(a1.func, a2.func) + && a1.args.len() == a2.args.len() + && a1 + .args + .iter() + .zip(a2.args.iter()) + .all(|(x, y)| alpha_eq(x, y)) + } + (Term::Pi(p1), Term::Pi(p2)) => { + p1.phase == p2.phase + && p1.params.len() == p2.params.len() + && p1.params.iter().zip(p2.params.iter()).all(|((_, t1), (_, t2))| alpha_eq(t1, t2)) + && alpha_eq(p1.body_ty, p2.body_ty) + } + (Term::Lam(l1), Term::Lam(l2)) => { + l1.params.len() == l2.params.len() + && l1.params.iter().zip(l2.params.iter()).all(|((_, t1), (_, t2))| alpha_eq(t1, t2)) + && alpha_eq(l1.body, l2.body) + } + (Term::Lift(i1), Term::Lift(i2)) + | (Term::Quote(i1), Term::Quote(i2)) + | (Term::Splice(i1), Term::Splice(i2)) => alpha_eq(i1, i2), + (Term::Let(l1), Term::Let(l2)) => { + alpha_eq(l1.ty, l2.ty) && alpha_eq(l1.expr, l2.expr) && alpha_eq(l1.body, l2.body) + } + (Term::Match(m1), Term::Match(m2)) => { + alpha_eq(m1.scrutinee, m2.scrutinee) + && m1.arms.len() == m2.arms.len() + && m1 + .arms + .iter() + .zip(m2.arms.iter()) + .all(|(a, b)| a.pat == b.pat && alpha_eq(a.body, b.body)) + } + _ => false, + } +} diff --git a/compiler/src/core/mod.rs b/compiler/src/core/mod.rs index e846c83..8450440 100644 --- a/compiler/src/core/mod.rs +++ b/compiler/src/core/mod.rs @@ -1,8 +1,12 @@ pub mod pretty; mod prim; +mod subst; +pub mod alpha_eq; pub use crate::common::{Name, Phase}; +pub use alpha_eq::alpha_eq; pub use prim::{IntType, IntWidth, Prim}; +pub use subst::subst; /// De Bruijn level (counts from the outermost binder) #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] @@ -19,20 +23,6 @@ impl Lvl { } } -/// Head of an application: either a top-level function or a primitive op -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum Head<'a> { - Global(Name<'a>), // resolved top-level function name - Prim(Prim), // built-in operation with resolved width -} - -impl Head<'_> { - /// Returns `true` if this head is a binary infix primitive operator. - pub const fn is_binop(&self) -> bool { - matches!(self, Self::Prim(p) if p.is_binop()) - } -} - /// Match pattern in the core IR #[derive(Debug, Clone, PartialEq, Eq)] pub enum Pat<'a> { @@ -58,52 +48,65 @@ pub struct Arm<'a> { pub body: &'a Term<'a>, } -/// Top-level function signature (stored in the globals table during elaboration) -#[derive(Debug)] -pub struct FunSig<'a> { - pub params: &'a [(&'a str, &'a Term<'a>)], // (name, type) pairs - pub ret_ty: &'a Term<'a>, - pub phase: Phase, -} - -/// Elaborated top-level function definition +/// Elaborated top-level function definition. +/// +/// `ty` is always a `Term::Pi`; use `Function::pi()` for convenient access. #[derive(Debug)] pub struct Function<'a> { pub name: Name<'a>, - pub sig: FunSig<'a>, + /// Function type (always `Term::Pi`). The Pi carries the phase, params, and return type. + pub ty: &'a Term<'a>, pub body: &'a Term<'a>, } +impl<'a> Function<'a> { + /// Unwrap `self.ty` as a `Pi`. Panics if `ty` is not a `Pi` (typechecker invariant). + pub fn pi(&self) -> &Pi<'a> { + match self.ty { + Term::Pi(pi) => pi, + _ => unreachable!("Function::ty must be a Pi (typechecker invariant)"), + } + } +} + /// Elaborated program: a sequence of top-level function definitions #[derive(Debug)] pub struct Program<'a> { pub functions: &'a [Function<'a>], } -/// Application of a global function or primitive operation to arguments. +/// Function or primitive application: `func(args...)` +/// +/// `func` may be any term yielding a function type — most commonly: +/// - `Term::Global(name)` for top-level function calls +/// - `Term::Prim(p)` for built-in primitive operations +/// - any expression for higher-order calls +/// +/// An empty `args` slice represents a zero-argument call and is distinct from +/// a bare reference to `func`. #[derive(Debug, PartialEq, Eq)] pub struct App<'a> { - pub head: Head<'a>, + pub func: &'a Term<'a>, pub args: &'a [&'a Term<'a>], } -impl App<'_> { - /// Returns the number of arguments. - pub const fn arity(&self) -> usize { - self.args.len() - } +/// Dependent function type: `fn(params...) -> body_ty` +/// +/// `phase` distinguishes meta-level (`fn`) from object-level (`code fn`) functions. +/// This allows the globals table to store `&Term` directly, unifying type lookup +/// for globals and locals. +#[derive(Debug, PartialEq, Eq)] +pub struct Pi<'a> { + pub params: &'a [(&'a str, &'a Term<'a>)], // (name, type) pairs + pub body_ty: &'a Term<'a>, + pub phase: Phase, +} - /// Returns `true` if this application is a binary infix primitive operator. - /// - /// Asserts that the argument count is exactly 2, which is an invariant - /// enforced by the elaborator for all binop applications. - pub fn is_binop(&self) -> bool { - let result = self.head.is_binop(); - if result { - assert_eq!(self.arity(), 2, "binop App must have exactly 2 arguments"); - } - result - } +/// Lambda abstraction: |params...| body +#[derive(Debug, PartialEq, Eq)] +pub struct Lam<'a> { + pub params: &'a [(&'a str, &'a Term<'a>)], // (name, type) pairs + pub body: &'a Term<'a>, } /// Let binding with explicit type annotation and a body. @@ -127,12 +130,18 @@ pub struct Match<'a> { pub enum Term<'a> { /// Local variable, identified by De Bruijn level Var(Lvl), - /// Built-in type or operation + /// Built-in type or operation (not applied) Prim(Prim), /// Numeric literal with its integer type Lit(u64, IntType), - /// Application of a global function or primitive operation to arguments + /// Global function reference + Global(Name<'a>), + /// Function or primitive application: func(args...) App(App<'a>), + /// Dependent function type: fn(x: A) -> B + Pi(Pi<'a>), + /// Lambda abstraction: |x: A| body + Lam(Lam<'a>), /// Lift: [[T]] — meta type representing object-level code of type T Lift(&'a Self), /// Quotation: #(t) — produce object-level code from a meta expression @@ -202,8 +211,8 @@ impl Term<'static> { } impl<'a> Term<'a> { - pub const fn new_app(head: Head<'a>, args: &'a [&'a Self]) -> Self { - Self::App(App { head, args }) + pub const fn new_app(func: &'a Self, args: &'a [&'a Self]) -> Self { + Self::App(App { func, args }) } pub const fn new_let(name: &'a str, ty: &'a Self, expr: &'a Self, body: &'a Self) -> Self { @@ -220,18 +229,18 @@ impl<'a> Term<'a> { } } -impl<'a> From> for Term<'a> { - fn from(app: App<'a>) -> Self { - Self::App(app) - } -} - impl<'a> From> for Term<'a> { fn from(let_: Let<'a>) -> Self { Self::Let(let_) } } +impl<'a> From> for Term<'a> { + fn from(app: App<'a>) -> Self { + Self::App(app) + } +} + impl<'a> From> for Term<'a> { fn from(match_: Match<'a>) -> Self { Self::Match(match_) diff --git a/compiler/src/core/pretty.rs b/compiler/src/core/pretty.rs index 4007591..8d5a57d 100644 --- a/compiler/src/core/pretty.rs +++ b/compiler/src/core/pretty.rs @@ -2,7 +2,7 @@ use std::fmt; use crate::parser::ast::Phase; -use super::{App, Arm, Function, Head, Pat, Program, Term}; +use super::{Arm, Function, Pat, Program, Term}; // ── Helpers ─────────────────────────────────────────────────────────────────── @@ -29,7 +29,10 @@ impl<'a> Term<'a> { Term::Var(_) | Term::Prim(_) | Term::Lit(..) + | Term::Global(_) | Term::App(_) + | Term::Pi(_) + | Term::Lam(_) | Term::Lift(_) | Term::Quote(_) | Term::Splice(_) => { @@ -63,8 +66,57 @@ impl<'a> Term<'a> { // ── Primitive type / universe ───────────────────────────────────────── Term::Prim(p) => write!(f, "{p}"), - // ── Application ────────────────────────────────────────────────────── - Term::App(app) => app.fmt_app(env, indent, f), + // ── Global reference ────────────────────────────────────────────────── + Term::Global(name) => write!(f, "{name}"), + + // ── Application ─────────────────────────────────────────────────────── + Term::App(app) => { + app.func.fmt_expr(env, indent, f)?; + write!(f, "(")?; + for (i, arg) in app.args.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + arg.fmt_expr(env, indent, f)?; + } + write!(f, ")") + } + + // ── Pi type ─────────────────────────────────────────────────────────── + Term::Pi(pi) => { + let env_before = env.len(); + write!(f, "fn(")?; + for (i, &(name, ty)) in pi.params.iter().enumerate() { + if i > 0 { write!(f, ", ")?; } + if name == "_" { + write!(f, "_: ")?; + } else { + write!(f, "{}@{}: ", name, env.len())?; + } + ty.fmt_expr(env, indent, f)?; + env.push(name); + } + write!(f, ") -> ")?; + pi.body_ty.fmt_expr(env, indent, f)?; + env.truncate(env_before); + Ok(()) + } + + // ── Lambda ──────────────────────────────────────────────────────────── + Term::Lam(lam) => { + let env_before = env.len(); + write!(f, "|")?; + for (i, &(name, ty)) in lam.params.iter().enumerate() { + if i > 0 { write!(f, ", ")?; } + write!(f, "{}@{}: ", name, env.len())?; + ty.fmt_expr(env, indent, f)?; + env.push(name); + } + write!(f, "| ")?; + lam.body.fmt_expr(env, indent, f)?; + env.truncate(env_before); + Ok(()) + } // ── Lift / Quote / Splice ───────────────────────────────────────────── Term::Lift(inner) => { @@ -134,7 +186,10 @@ impl<'a> Term<'a> { Term::Var(_) | Term::Prim(_) | Term::Lit(..) + | Term::Global(_) | Term::App(_) + | Term::Pi(_) + | Term::Lam(_) | Term::Lift(_) | Term::Quote(_) | Term::Splice(_) => self.fmt_term_inline(env, indent, f), @@ -142,46 +197,6 @@ impl<'a> Term<'a> { } } -impl<'a> App<'a> { - /// Print an application. - /// - /// All primitives use `@name(arg, arg, ...)` function-call syntax. No infix - /// operators are emitted in the core pretty-printer. - fn fmt_app( - &self, - env: &mut Vec<&'a str>, - indent: usize, - f: &mut fmt::Formatter<'_>, - ) -> fmt::Result { - match &self.head { - // ── Global function call ────────────────────────────────────────────── - Head::Global(name) => { - write!(f, "{name}(")?; - for (i, arg) in self.args.iter().enumerate() { - if i > 0 { - write!(f, ", ")?; - } - arg.fmt_expr(env, indent, f)?; - } - write!(f, ")") - } - - // ── Primitive operation ─────────────────────────────────────────────── - // All builtins use `@name(args...)` function-call syntax. - Head::Prim(prim) => { - write!(f, "{prim}(")?; - for (i, arg) in self.args.iter().enumerate() { - if i > 0 { - write!(f, ", ")?; - } - arg.fmt_expr(env, indent, f)?; - } - write!(f, ")") - } - } - } -} - impl<'a> Arm<'a> { /// Print a single match arm. fn fmt_arm( @@ -224,11 +239,13 @@ impl fmt::Display for Program<'_> { impl fmt::Display for Function<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let pi = self.pi(); + // Build the name environment for the body: one entry per parameter. - let mut env: Vec<&str> = Vec::with_capacity(self.sig.params.len()); + let mut env: Vec<&str> = Vec::with_capacity(pi.params.len()); // Phase prefix. - match self.sig.phase { + match pi.phase { Phase::Object => write!(f, "code ")?, Phase::Meta => {} } @@ -236,7 +253,7 @@ impl fmt::Display for Function<'_> { // Parameters: types are printed with the env as built so far (dependent // function types: earlier params are in scope for later param types). - for (i, (name, ty)) in self.sig.params.iter().enumerate() { + for (i, (name, ty)) in pi.params.iter().enumerate() { if i > 0 { write!(f, ", ")?; } @@ -246,7 +263,7 @@ impl fmt::Display for Function<'_> { } write!(f, ") -> ")?; - self.sig.ret_ty.fmt_expr(&mut env, 1, f)?; + pi.body_ty.fmt_expr(&mut env, 1, f)?; writeln!(f, " {{")?; // Body in statement position at indent depth 1. diff --git a/compiler/src/core/subst.rs b/compiler/src/core/subst.rs new file mode 100644 index 0000000..38ca520 --- /dev/null +++ b/compiler/src/core/subst.rs @@ -0,0 +1,69 @@ +use super::{Arm, Lam, Lvl, Pi, Term}; + +/// Substitute `replacement` for `Var(target)` in `term`. +pub fn subst<'a>( + arena: &'a bumpalo::Bump, + term: &'a Term<'a>, + target: Lvl, + replacement: &'a Term<'a>, +) -> &'a Term<'a> { + match term { + Term::Var(lvl) if *lvl == target => replacement, + Term::Var(_) | Term::Prim(_) | Term::Lit(..) | Term::Global(_) => term, + + Term::App(app) => { + let new_func = subst(arena, app.func, target, replacement); + let new_args = arena.alloc_slice_fill_iter( + app.args + .iter() + .map(|arg| subst(arena, arg, target, replacement)), + ); + arena.alloc(Term::new_app(new_func, new_args)) + } + + Term::Pi(pi) => { + let new_params = arena.alloc_slice_fill_iter( + pi.params.iter().map(|&(name, ty)| (name, subst(arena, ty, target, replacement))), + ); + let new_body_ty = subst(arena, pi.body_ty, target, replacement); + arena.alloc(Term::Pi(Pi { params: new_params, body_ty: new_body_ty, phase: pi.phase })) + } + + Term::Lam(lam) => { + let new_params = arena.alloc_slice_fill_iter( + lam.params.iter().map(|&(name, ty)| (name, subst(arena, ty, target, replacement))), + ); + let new_body = subst(arena, lam.body, target, replacement); + arena.alloc(Term::Lam(Lam { params: new_params, body: new_body })) + } + + Term::Lift(inner) => { + let new_inner = subst(arena, inner, target, replacement); + arena.alloc(Term::Lift(new_inner)) + } + Term::Quote(inner) => { + let new_inner = subst(arena, inner, target, replacement); + arena.alloc(Term::Quote(new_inner)) + } + Term::Splice(inner) => { + let new_inner = subst(arena, inner, target, replacement); + arena.alloc(Term::Splice(new_inner)) + } + + Term::Let(let_) => { + let new_ty = subst(arena, let_.ty, target, replacement); + let new_expr = subst(arena, let_.expr, target, replacement); + let new_body = subst(arena, let_.body, target, replacement); + arena.alloc(Term::new_let(let_.name, new_ty, new_expr, new_body)) + } + + Term::Match(match_) => { + let new_scrutinee = subst(arena, match_.scrutinee, target, replacement); + let new_arms = arena.alloc_slice_fill_iter(match_.arms.iter().map(|arm| Arm { + pat: arm.pat.clone(), + body: subst(arena, arm.body, target, replacement), + })); + arena.alloc(Term::new_match(new_scrutinee, new_arms)) + } + } +} diff --git a/compiler/src/eval/mod.rs b/compiler/src/eval/mod.rs index 19e2f34..1603dc5 100644 --- a/compiler/src/eval/mod.rs +++ b/compiler/src/eval/mod.rs @@ -4,7 +4,7 @@ use anyhow::{Result, anyhow, ensure}; use bumpalo::Bump; use crate::core::{ - App, Arm, FunSig, Function, Head, IntType, IntWidth, Lvl, Name, Pat, Prim, Program, Term, + Arm, Function, IntType, IntWidth, Lam, Lvl, Name, Pat, Pi, Prim, Program, Term, }; use crate::parser::ast::Phase; @@ -12,46 +12,47 @@ use crate::parser::ast::Phase; /// A value produced by meta-level evaluation. /// -/// In this substitution-based prototype there are only two kinds of meta -/// values: integer literals and quoted object-level code. Meta-level -/// lambdas / closures are not yet supported (the current surface language -/// has no meta-level lambda syntax; only top-level `fn` definitions). +/// Two lifetime parameters: +/// - `'out`: lifetime of the output arena (for `Code` values that appear in the result). +/// - `'eval`: lifetime of the evaluation phase — covers both the input program data (`'core`) +/// and any temporary terms allocated in the local eval arena. Since `Term` is covariant +/// in its lifetime, `'core` data can be coerced to `'eval` at call sites. #[derive(Clone, Debug)] -enum MetaVal<'out> { +enum MetaVal<'out, 'eval> { /// A concrete integer value computed at meta (compile) time. - VLit(u64), - /// Quoted object-level code: the result of evaluating `#(t)` or of - /// wrapping a literal via `Embed`. The inner term is a splice-free - /// object term produced by `unstage_obj`. - VCode(&'out Term<'out>), + Lit(u64), + /// Quoted object-level code. + Code(&'out Term<'out>), + /// A type term passed as a type argument (dependent types: types are values). + /// The type term itself is not inspected during evaluation. + Ty, + /// A closure: a lambda body captured with its environment. + Closure { + body: &'eval Term<'eval>, + env: Vec>, + obj_next: Lvl, + }, } // ── Environment ─────────────────────────────────────────────────────────────── /// A binding stored in the evaluation environment, indexed by De Bruijn level. #[derive(Clone, Debug)] -enum Binding<'out> { +enum Binding<'out, 'eval> { /// A meta-level variable bound to a concrete `MetaVal`. - Meta(MetaVal<'out>), - /// An object-level variable. Object variables are opaque during - /// meta-level evaluation and remain as `Var(lvl)` in the output. - /// The Lvl inside signifies the level in the generated output - /// rather than in the original program where bindings for the - /// object level and meta level may be interwoven. + Meta(MetaVal<'out, 'eval>), + /// An object-level variable. Obj(Lvl), } /// Evaluation environment: a stack of bindings indexed by De Bruijn level. -/// -/// Level 0 is the outermost binding (first function parameter); new bindings -/// are pushed onto the end and accessed by their index. #[derive(Debug)] -struct Env<'out> { - bindings: Vec>, +struct Env<'out, 'eval> { + bindings: Vec>, obj_next: Lvl, } -impl<'out> Env<'out> { +impl<'out, 'eval> Env<'out, 'eval> { const fn new(obj_next: Lvl) -> Self { Env { bindings: Vec::new(), @@ -60,14 +61,13 @@ impl<'out> Env<'out> { } /// Look up the binding at level `lvl`. - fn get(&self, lvl: Lvl) -> &Binding<'out> { + fn get(&self, lvl: Lvl) -> &Binding<'out, 'eval> { self.bindings .get(lvl.0) .expect("De Bruijn level in env bounds") } - /// Push an object-level binding. Assigns the next consecutive object-level - /// De Bruijn level and advances `obj_next`. + /// Push an object-level binding. fn push_obj(&mut self) { let lvl = self.obj_next; self.obj_next = lvl.succ(); @@ -75,11 +75,11 @@ impl<'out> Env<'out> { } /// Push a meta-level binding bound to the given value. - fn push_meta(&mut self, val: MetaVal<'out>) { + fn push_meta(&mut self, val: MetaVal<'out, 'eval>) { self.bindings.push(Binding::Meta(val)); } - /// Pop the last binding (used to restore the environment after a let / arm). + /// Pop the last binding. fn pop(&mut self) { match self.bindings.pop().expect("pop on empty environment") { Binding::Obj(_) => { @@ -98,31 +98,23 @@ impl<'out> Env<'out> { // ── Globals table ───────────────────────────────────────────────────────────── /// Everything the evaluator needs to know about a top-level function. -struct GlobalDef<'core> { - sig: &'core FunSig<'core>, - body: &'core Term<'core>, +struct GlobalDef<'a> { + ty: &'a Term<'a>, // always Term::Pi + body: &'a Term<'a>, } -type Globals<'core> = HashMap, GlobalDef<'core>>; +type Globals<'a> = HashMap, GlobalDef<'a>>; // ── Meta-level evaluator ────────────────────────────────────────────────────── /// Evaluate a meta-level `term` to a `MetaVal`. -/// -/// `env` maps De Bruijn levels to their current values. `globals` provides -/// the definitions of all top-level functions. `arena` is used when -/// allocating object terms inside `VCode` values (via `unstage_obj`). -/// -/// Invariants enforced by the typechecker (violations panic via `unreachable!`): -/// - `Splice` nodes never appear at meta level. -/// - `Lift` and type-level `Prim` nodes never appear in term positions. -/// - Object variables (`Binding::Obj`) are never referenced at meta level. -fn eval_meta<'out, 'core>( +fn eval_meta<'out, 'eval>( arena: &'out Bump, - globals: &Globals<'core>, - env: &mut Env<'out>, - term: &'core Term<'core>, -) -> Result> { + eval_arena: &'eval Bump, + globals: &Globals<'eval>, + env: &mut Env<'out, 'eval>, + term: &'eval Term<'eval>, +) -> Result> { match term { // ── Variable ───────────────────────────────────────────────────────── Term::Var(lvl) => match env.get(*lvl) { @@ -134,106 +126,187 @@ fn eval_meta<'out, 'core>( }, // ── Literal ────────────────────────────────────────────────────────── - Term::Lit(n, _) => Ok(MetaVal::VLit(*n)), + Term::Lit(n, _) => Ok(MetaVal::Lit(*n)), + + // ── Global reference ───────────────────────────────────────────────── + Term::Global(name) => { + let def = globals + .get(name) + .unwrap_or_else(|| panic!("unknown global `{name}` during staging")); + let Term::Pi(pi) = def.ty else { + unreachable!("global `{name}` must have a Pi type (typechecker invariant)") + }; + if pi.params.is_empty() { + // Zero-param global: evaluate the body immediately in a fresh env. + let mut callee_env = Env::new(env.obj_next); + eval_meta(arena, eval_arena, globals, &mut callee_env, def.body) + } else { + // Multi-param global: produce a closure. + Ok(global_to_closure(eval_arena, def, env.obj_next)) + } + } + + // ── Lambda ─────────────────────────────────────────────────────────── + Term::Lam(lam) => { + // For a zero-param lambda (thunk), wrap in a Closure whose body IS the + // lambda body; force_thunk evaluates it when applied to zero args. + // For a multi-param lambda, wrap params[1..] in a synthetic Lam so that + // apply_closure can peel one param at a time. + let body = match lam.params { + [] | [_] => lam.body, + [_, rest @ ..] => { + eval_arena.alloc(Term::Lam(Lam { params: rest, body: lam.body })) + } + }; + Ok(MetaVal::Closure { + body, + env: env.bindings.clone(), + obj_next: env.obj_next, + }) + } // ── Application ────────────────────────────────────────────────────── - Term::App(app) => eval_meta_app(arena, globals, env, app), + Term::App(app) => match app.func { + Term::Prim(prim) => eval_meta_prim(arena, eval_arena, globals, env, *prim, app.args), + _ => { + let mut val = eval_meta(arena, eval_arena, globals, env, app.func)?; + if app.args.is_empty() { + // Zero-arg call: force the thunk closure. + val = force_thunk(arena, eval_arena, globals, val)?; + } else { + for arg in app.args { + let arg_val = eval_meta(arena, eval_arena, globals, env, arg)?; + val = apply_closure(arena, eval_arena, globals, val, arg_val)?; + } + } + Ok(val) + } + }, - // ── Quote: #(t) ─────────────────────────────────────────────────────── - // Unstage the enclosed object term (eliminating any splices inside it) - // and wrap the result as object code. + // ── Quote: #(t) ────────────────────────────────────────────────────── Term::Quote(inner) => { - let obj_term = unstage_obj(arena, globals, env, inner)?; - Ok(MetaVal::VCode(obj_term)) + let obj_term = unstage_obj(arena, eval_arena, globals, env, inner)?; + Ok(MetaVal::Code(obj_term)) } - // ── Let binding ─────────────────────────────────────────────────────── + // ── Let binding ────────────────────────────────────────────────────── Term::Let(let_) => { - let val = eval_meta(arena, globals, env, let_.expr)?; + let val = eval_meta(arena, eval_arena, globals, env, let_.expr)?; env.push_meta(val); - let result = eval_meta(arena, globals, env, let_.body); + let result = eval_meta(arena, eval_arena, globals, env, let_.body); env.pop(); result } - // ── Match ───────────────────────────────────────────────────────────── + // ── Match ──────────────────────────────────────────────────────────── Term::Match(match_) => { - let scrut_val = eval_meta(arena, globals, env, match_.scrutinee)?; + let scrut_val = eval_meta(arena, eval_arena, globals, env, match_.scrutinee)?; let n = match scrut_val { - MetaVal::VLit(n) => n, - MetaVal::VCode(_) => unreachable!( - "cannot match on object code at meta level (typechecker invariant)" + MetaVal::Lit(n) => n, + MetaVal::Code(_) | MetaVal::Ty | MetaVal::Closure { .. } => unreachable!( + "cannot match on non-integer at meta level (typechecker invariant)" ), }; - eval_meta_match(arena, globals, env, n, match_.arms) + eval_meta_match(arena, eval_arena, globals, env, n, match_.arms) } - // ── Unreachable in well-typed meta terms ────────────────────────────── + // ── Unreachable in well-typed meta terms ───────────────────────────── Term::Splice(_) => unreachable!("Splice in meta context (typechecker invariant)"), - Term::Lift(_) | Term::Prim(_) => { - unreachable!("type-level term in evaluation position (typechecker invariant)") - } + // Type-level terms evaluate to themselves when passed as type arguments + // in a dependently-typed function call (e.g. `id(u64, x)` passes `u64 : Type`). + Term::Lift(_) | Term::Prim(_) | Term::Pi(_) => Ok(MetaVal::Ty), } } -/// Evaluate a function application at meta level. -fn eval_meta_app<'out, 'core>( - arena: &'out Bump, - globals: &Globals<'core>, - env: &mut Env<'out>, - app: &'core App<'core>, -) -> Result> { - match &app.head { - // ── Global function call ────────────────────────────────────────────── - Head::Global(name) => { - let def = globals - .get(name) - .unwrap_or_else(|| panic!("unknown global `{name}` during staging")); - - assert_eq!( - def.sig.phase, - Phase::Meta, - "object-phase function `{name}` called in meta context during staging" - ); - - // Evaluate each argument in the *caller's* environment. - let mut arg_vals: Vec> = Vec::with_capacity(app.args.len()); - for arg in app.args { - arg_vals.push(eval_meta(arena, globals, env, arg)?); - } +/// Convert a global function definition into a closure value. +/// +/// For a multi-parameter function, we build nested closures. E.g., `fn f(x, y) = body` +/// becomes a closure whose body is a lambda `|y| body`. The synthetic `Lam` wrapper nodes +/// are allocated in `eval_arena`, which is local to `unstage_program` and lives for the +/// duration of staging — long enough to outlive any closure values. +fn global_to_closure<'out, 'eval>( + eval_arena: &'eval Bump, + def: &GlobalDef<'eval>, + obj_next: Lvl, +) -> MetaVal<'out, 'eval> { + // Called only when params is non-empty (zero-param globals are evaluated immediately). + let Term::Pi(pi) = def.ty else { + unreachable!("global must have a Pi type (typechecker invariant)") + }; + let body = match pi.params { + [_] | [] => def.body, + [_, rest @ ..] => { + eval_arena.alloc(Term::Lam(Lam { params: rest, body: def.body })) + } + }; + MetaVal::Closure { body, env: Vec::new(), obj_next } +} - // Build a fresh environment for the callee: one binding per parameter. - let mut callee_env = Env::new(env.obj_next); - for val in arg_vals { - callee_env.push_meta(val); - } +/// Apply a closure value to an argument value. +fn apply_closure<'out, 'eval>( + arena: &'out Bump, + eval_arena: &'eval Bump, + globals: &Globals<'eval>, + func_val: MetaVal<'out, 'eval>, + arg_val: MetaVal<'out, 'eval>, +) -> Result> { + match func_val { + MetaVal::Closure { + body, + env, + obj_next, + .. + } => { + let mut callee_env = Env { + bindings: env, + obj_next, + }; + callee_env.push_meta(arg_val); - eval_meta(arena, globals, &mut callee_env, def.body) + eval_meta(arena, eval_arena, globals, &mut callee_env, body) } + MetaVal::Lit(_) | MetaVal::Code(_) | MetaVal::Ty => { + unreachable!("applying a non-function value (typechecker invariant)") + } + } +} - // ── Primitive operations ────────────────────────────────────────────── - Head::Prim(prim) => eval_meta_prim(arena, globals, env, *prim, app.args), +/// Force a thunk closure: evaluate its body in the captured environment without pushing any arg. +fn force_thunk<'out, 'eval>( + arena: &'out Bump, + eval_arena: &'eval Bump, + globals: &Globals<'eval>, + val: MetaVal<'out, 'eval>, +) -> Result> { + match val { + MetaVal::Closure { body, env, obj_next, .. } => { + let mut callee_env = Env { bindings: env, obj_next }; + eval_meta(arena, eval_arena, globals, &mut callee_env, body) + } + // Already-evaluated value (e.g. a zero-param global reduced to Lit/Code). + // A zero-arg call is a no-op in this case. + other => Ok(other), } } /// Evaluate a primitive operation at meta level. -fn eval_meta_prim<'out, 'core>( +fn eval_meta_prim<'out, 'eval>( arena: &'out Bump, - globals: &Globals<'core>, - env: &mut Env<'out>, + eval_arena: &'eval Bump, + globals: &Globals<'eval>, + env: &mut Env<'out, 'eval>, prim: Prim, - args: &'core [&'core Term<'core>], -) -> Result> { - // Evaluate args[i] and extract its integer value. - // Panics if the value is `VCode` — the typechecker guarantees integer operands. + args: &'eval [&'eval Term<'eval>], +) -> Result> { let eval_lit = |arena: &'out Bump, - globals: &Globals<'core>, - env: &mut Env<'out>, - arg: &'core Term<'core>| { - eval_meta(arena, globals, env, arg).map(|v| match v { - MetaVal::VLit(n) => n, - MetaVal::VCode(_) => unreachable!( - "expected integer meta value for primitive operand, got code (typechecker invariant)" + eval_arena: &'eval Bump, + globals: &Globals<'eval>, + env: &mut Env<'out, 'eval>, + arg: &'eval Term<'eval>| { + eval_meta(arena, eval_arena, globals, env, arg).map(|v| match v { + MetaVal::Lit(n) => n, + MetaVal::Code(_) | MetaVal::Ty | MetaVal::Closure { .. } => unreachable!( + "expected integer meta value for primitive operand (typechecker invariant)" ), }) }; @@ -242,8 +315,8 @@ fn eval_meta_prim<'out, 'core>( match prim { // ── Arithmetic ──────────────────────────────────────────────────────── Prim::Add(IntType { width, .. }) => { - let a = eval_lit(arena, globals, env, args[0])?; - let b = eval_lit(arena, globals, env, args[1])?; + let a = eval_lit(arena, eval_arena, globals, env, args[0])?; + let b = eval_lit(arena, eval_arena, globals, env, args[1])?; let result = a .checked_add(b) .filter(|&r| r <= width.max_value()) @@ -255,22 +328,22 @@ fn eval_meta_prim<'out, 'core>( width.max_value() ) })?; - Ok(MetaVal::VLit(result)) + Ok(MetaVal::Lit(result)) } Prim::Sub(IntType { width, .. }) => { - let a = eval_lit(arena, globals, env, args[0])?; - let b = eval_lit(arena, globals, env, args[1])?; + let a = eval_lit(arena, eval_arena, globals, env, args[0])?; + let b = eval_lit(arena, eval_arena, globals, env, args[1])?; let result = a.checked_sub(b).ok_or_else(|| { anyhow!( "arithmetic overflow during staging: \ {a} - {b} underflows {width}" ) })?; - Ok(MetaVal::VLit(result)) + Ok(MetaVal::Lit(result)) } Prim::Mul(IntType { width, .. }) => { - let a = eval_lit(arena, globals, env, args[0])?; - let b = eval_lit(arena, globals, env, args[1])?; + let a = eval_lit(arena, eval_arena, globals, env, args[0])?; + let b = eval_lit(arena, eval_arena, globals, env, args[1])?; let result = a .checked_mul(b) .filter(|&r| r <= width.max_value()) @@ -282,72 +355,69 @@ fn eval_meta_prim<'out, 'core>( width.max_value() ) })?; - Ok(MetaVal::VLit(result)) + Ok(MetaVal::Lit(result)) } Prim::Div(_) => { - let a = eval_lit(arena, globals, env, args[0])?; - let b = eval_lit(arena, globals, env, args[1])?; + let a = eval_lit(arena, eval_arena, globals, env, args[0])?; + let b = eval_lit(arena, eval_arena, globals, env, args[1])?; ensure!(b != 0, "division by zero during staging"); - Ok(MetaVal::VLit(a / b)) + Ok(MetaVal::Lit(a / b)) } // ── Bitwise ─────────────────────────────────────────────────────────── Prim::BitAnd(_) => { - let a = eval_lit(arena, globals, env, args[0])?; - let b = eval_lit(arena, globals, env, args[1])?; - Ok(MetaVal::VLit(a & b)) + let a = eval_lit(arena, eval_arena, globals, env, args[0])?; + let b = eval_lit(arena, eval_arena, globals, env, args[1])?; + Ok(MetaVal::Lit(a & b)) } Prim::BitOr(_) => { - let a = eval_lit(arena, globals, env, args[0])?; - let b = eval_lit(arena, globals, env, args[1])?; - Ok(MetaVal::VLit(a | b)) + let a = eval_lit(arena, eval_arena, globals, env, args[0])?; + let b = eval_lit(arena, eval_arena, globals, env, args[1])?; + Ok(MetaVal::Lit(a | b)) } Prim::BitNot(IntType { width, .. }) => { - let a = eval_lit(arena, globals, env, args[0])?; - Ok(MetaVal::VLit(mask_to_width(width, !a))) + let a = eval_lit(arena, eval_arena, globals, env, args[0])?; + Ok(MetaVal::Lit(mask_to_width(width, !a))) } // ── Comparison ──────────────────────────────────────────────────────── Prim::Eq(_) => { - let a = eval_lit(arena, globals, env, args[0])?; - let b = eval_lit(arena, globals, env, args[1])?; - Ok(MetaVal::VLit(u64::from(a == b))) + let a = eval_lit(arena, eval_arena, globals, env, args[0])?; + let b = eval_lit(arena, eval_arena, globals, env, args[1])?; + Ok(MetaVal::Lit(u64::from(a == b))) } Prim::Ne(_) => { - let a = eval_lit(arena, globals, env, args[0])?; - let b = eval_lit(arena, globals, env, args[1])?; - Ok(MetaVal::VLit(u64::from(a != b))) + let a = eval_lit(arena, eval_arena, globals, env, args[0])?; + let b = eval_lit(arena, eval_arena, globals, env, args[1])?; + Ok(MetaVal::Lit(u64::from(a != b))) } Prim::Lt(_) => { - let a = eval_lit(arena, globals, env, args[0])?; - let b = eval_lit(arena, globals, env, args[1])?; - Ok(MetaVal::VLit(u64::from(a < b))) + let a = eval_lit(arena, eval_arena, globals, env, args[0])?; + let b = eval_lit(arena, eval_arena, globals, env, args[1])?; + Ok(MetaVal::Lit(u64::from(a < b))) } Prim::Gt(_) => { - let a = eval_lit(arena, globals, env, args[0])?; - let b = eval_lit(arena, globals, env, args[1])?; - Ok(MetaVal::VLit(u64::from(a > b))) + let a = eval_lit(arena, eval_arena, globals, env, args[0])?; + let b = eval_lit(arena, eval_arena, globals, env, args[1])?; + Ok(MetaVal::Lit(u64::from(a > b))) } Prim::Le(_) => { - let a = eval_lit(arena, globals, env, args[0])?; - let b = eval_lit(arena, globals, env, args[1])?; - Ok(MetaVal::VLit(u64::from(a <= b))) + let a = eval_lit(arena, eval_arena, globals, env, args[0])?; + let b = eval_lit(arena, eval_arena, globals, env, args[1])?; + Ok(MetaVal::Lit(u64::from(a <= b))) } Prim::Ge(_) => { - let a = eval_lit(arena, globals, env, args[0])?; - let b = eval_lit(arena, globals, env, args[1])?; - Ok(MetaVal::VLit(u64::from(a >= b))) + let a = eval_lit(arena, eval_arena, globals, env, args[0])?; + let b = eval_lit(arena, eval_arena, globals, env, args[1])?; + Ok(MetaVal::Lit(u64::from(a >= b))) } // ── Embed: meta integer → object code ───────────────────────────────── - // `Embed(w)` applied to a meta integer `n` produces object-level code - // consisting of the literal `n`. This is how a compile-time integer - // constant is embedded into the generated object program. Prim::Embed(width) => { - let n = eval_lit(arena, globals, env, args[0])?; + let n = eval_lit(arena, eval_arena, globals, env, args[0])?; let phase = Phase::Object; let lit_term = arena.alloc(Term::Lit(n, IntType { width, phase })); - Ok(MetaVal::VCode(lit_term)) + Ok(MetaVal::Code(lit_term)) } // ── Type-level prims are unreachable ────────────────────────────────── @@ -370,36 +440,29 @@ const fn mask_to_width(width: IntWidth, val: u64) -> u64 { } /// Evaluate a meta-level `match` expression. -/// -/// `n` is the already-evaluated scrutinee value. -/// Arms are checked in order; the first matching arm wins. -fn eval_meta_match<'out, 'core>( +fn eval_meta_match<'out, 'eval>( arena: &'out Bump, - globals: &Globals<'core>, - env: &mut Env<'out>, + eval_arena: &'eval Bump, + globals: &Globals<'eval>, + env: &mut Env<'out, 'eval>, n: u64, - arms: &'core [Arm<'core>], -) -> Result> { + arms: &'eval [Arm<'eval>], +) -> Result> { for arm in arms { match &arm.pat { Pat::Lit(m) => { if n == *m { - return eval_meta(arena, globals, env, arm.body); + return eval_meta(arena, eval_arena, globals, env, arm.body); } } Pat::Bind(_) | Pat::Wildcard => { - // Catch-all: bind the scrutinee value and evaluate the body. - env.push_meta(MetaVal::VLit(n)); - let result = eval_meta(arena, globals, env, arm.body); + env.push_meta(MetaVal::Lit(n)); + let result = eval_meta(arena, eval_arena, globals, env, arm.body); env.pop(); return result; } } } - // The typechecker enforces exhaustiveness, so this should not be reachable - // for well-typed programs. It can happen if the meta computation produces - // a value outside the covered range (e.g. a u64 with no wildcard arm), which - // is a user-visible runtime staging error rather than an internal bug. Err(anyhow!( "non-exhaustive match during staging (scrutinee = {n})" )) @@ -408,36 +471,31 @@ fn eval_meta_match<'out, 'core>( // ── Object-level unstager ───────────────────────────────────────────────────── /// Unstage an object-level `term`, eliminating all `Splice` nodes. -/// -/// Object variables (`Var`), operations (`App`), `Let`, and `Match` are left -/// structurally intact; only `Splice` nodes are reduced by running the -/// enclosed meta computation. -/// -/// `env` is shared with the meta evaluator so that meta variables that are in -/// scope at a splice point (e.g. from an enclosing `Quote`) remain accessible. -fn unstage_obj<'out, 'core>( +fn unstage_obj<'out, 'eval>( arena: &'out Bump, - globals: &Globals<'core>, - env: &mut Env<'out>, - term: &'core Term<'core>, + eval_arena: &'eval Bump, + globals: &Globals<'eval>, + env: &mut Env<'out, 'eval>, + term: &'eval Term<'eval>, ) -> Result<&'out Term<'out>> { match term { // ── Variable ───────────────────────────────────────────────────────── Term::Var(lvl) => match env.get(*lvl) { - // A plain object variable (e.g. a `code fn` parameter) passes - // through as-is — it will be a free variable in the output. Binding::Obj(out_lvl) => Ok(arena.alloc(Term::Var(*out_lvl))), - // A meta variable of type `[[T]]` is referenced inside a quoted - // object term. Its value is object code. `VCode` is always - // fully staged (produced by `unstage_obj` at quote time), so we - // return it directly without recursing — that would be unsound - // because the levels inside the VCode term are relative to the - // environment at the *quoting site*, not the current env. - // This implements the ∼⟨t⟩ ≡ t definitional equality. - Binding::Meta(MetaVal::VCode(obj)) => Ok(obj), - Binding::Meta(MetaVal::VLit(_)) => unreachable!( + Binding::Meta(MetaVal::Code(obj)) => Ok(obj), + Binding::Meta(MetaVal::Lit(_)) => unreachable!( "integer meta variable at level {} referenced in object context \ - (typechecker invariant: only [[T]]-typed meta vars can appear in object terms)", + (typechecker invariant)", + lvl.0 + ), + Binding::Meta(MetaVal::Closure { .. }) => unreachable!( + "closure meta variable at level {} referenced in object context \ + (typechecker invariant)", + lvl.0 + ), + Binding::Meta(MetaVal::Ty) => unreachable!( + "type meta variable at level {} referenced in object context \ + (typechecker invariant)", lvl.0 ), }, @@ -448,43 +506,39 @@ fn unstage_obj<'out, 'core>( // ── Primitive ──────────────────────────────────────────────────────── Term::Prim(p) => Ok(arena.alloc(Term::Prim(*p))), - // ── Application ────────────────────────────────────────────────────── + // ── Global reference (in object terms, e.g. object-level function call) ── + Term::Global(name) => { + Ok(arena.alloc(Term::Global(Name::new(arena.alloc_str(name.as_str()))))) + } + + // ── App ─────────────────────────────────────────────────────────────── Term::App(app) => { - let staged_head = match &app.head { - Head::Global(name) => Head::Global(Name::new(arena.alloc_str(name.as_str()))), - Head::Prim(p) => Head::Prim(*p), - }; + let staged_func = unstage_obj(arena, eval_arena, globals, env, app.func)?; let staged_args: &'out [&'out Term<'out>] = arena.alloc_slice_try_fill_iter( app.args .iter() - .map(|arg| unstage_obj(arena, globals, env, arg)), + .map(|arg| unstage_obj(arena, eval_arena, globals, env, arg)), )?; - Ok(arena.alloc(Term::new_app(staged_head, staged_args))) + Ok(arena.alloc(Term::new_app(staged_func, staged_args))) } - // ── Splice: $(t) — the key staging step ─────────────────────────────── - // Evaluate the meta term `t` to a `VCode(obj)`. `VCode` values are - // always fully staged (produced by `unstage_obj` at quote time), so - // we return the inner term directly. + // ── Splice: $(t) — the key staging step ────────────────────────────── Term::Splice(inner) => { - let meta_val = eval_meta(arena, globals, env, inner)?; + let meta_val = eval_meta(arena, eval_arena, globals, env, inner)?; match meta_val { - MetaVal::VCode(obj_term) => Ok(obj_term), - MetaVal::VLit(_) => unreachable!( - "splice evaluated to an integer literal (typechecker invariant: \ - splice argument must have type [[T]])" - ), + MetaVal::Code(obj_term) => Ok(obj_term), + MetaVal::Lit(_) | MetaVal::Ty | MetaVal::Closure { .. } => { + unreachable!("splice evaluated to non-code value (typechecker invariant)") + } } } - // ── Let binding ─────────────────────────────────────────────────────── + // ── Let binding ────────────────────────────────────────────────────── Term::Let(let_) => { - let staged_ty = unstage_obj(arena, globals, env, let_.ty)?; - let staged_expr = unstage_obj(arena, globals, env, let_.expr)?; - // Push an object binding so that subsequent Var references by - // De Bruijn level resolve to the correct slot. + let staged_ty = unstage_obj(arena, eval_arena, globals, env, let_.ty)?; + let staged_expr = unstage_obj(arena, eval_arena, globals, env, let_.expr)?; env.push_obj(); - let staged_body = unstage_obj(arena, globals, env, let_.body); + let staged_body = unstage_obj(arena, eval_arena, globals, env, let_.body); env.pop(); Ok(arena.alloc(Term::new_let( arena.alloc_str(let_.name), @@ -494,9 +548,9 @@ fn unstage_obj<'out, 'core>( ))) } - // ── Match ───────────────────────────────────────────────────────────── + // ── Match ──────────────────────────────────────────────────────────── Term::Match(match_) => { - let staged_scrutinee = unstage_obj(arena, globals, env, match_.scrutinee)?; + let staged_scrutinee = unstage_obj(arena, eval_arena, globals, env, match_.scrutinee)?; let staged_arms: &'out [Arm<'out>] = arena.alloc_slice_try_fill_iter(match_.arms.iter().map(|arm| -> Result<_> { let staged_pat = match &arm.pat { @@ -508,7 +562,7 @@ fn unstage_obj<'out, 'core>( if has_binding { env.push_obj(); } - let staged_body = unstage_obj(arena, globals, env, arm.body); + let staged_body = unstage_obj(arena, eval_arena, globals, env, arm.body); if has_binding { env.pop(); } @@ -520,75 +574,64 @@ fn unstage_obj<'out, 'core>( Ok(arena.alloc(Term::new_match(staged_scrutinee, staged_arms))) } - // ── Unreachable in well-typed object terms ──────────────────────────── + // ── Unreachable in well-typed object terms ─────────────────────────── Term::Quote(_) => unreachable!("Quote in object context (typechecker invariant)"), - Term::Lift(_) => unreachable!("Lift in object context (typechecker invariant)"), + Term::Lift(_) | Term::Pi(_) | Term::Lam(_) => { + unreachable!("meta-only term in object context (typechecker invariant)") + } } } // ── Public entry point ──────────────────────────────────────────────────────── -/// Unstage an elaborated program, eliminating all meta-level functions and -/// splices to produce a splice-free object-level program. -/// -/// The output `Program` contains only `Phase::Object` functions. All -/// `Phase::Meta` functions are erased (they served only as compile-time -/// helpers). Every `Splice` node in object-function bodies is replaced by -/// the object code it produces when the enclosing meta computation runs. +/// Unstage an elaborated program. /// -/// # Errors -/// -/// Returns an error for genuine runtime staging errors: division by zero, -/// or a non-exhaustive match on a value not covered by any arm. +/// - `arena`: output arena; the returned `Program<'out>` is allocated here. +/// - `program`: input core program; may be dropped once this function returns. pub fn unstage_program<'out, 'core>( arena: &'out Bump, program: &'core Program<'core>, ) -> Result> { - // Build the globals table from all functions in the program. - let globals: Globals<'core> = program + // A temporary arena for intermediate values (synthetic Lam wrappers for closures, etc.) + // that exist only during evaluation and must not appear in the output. Its lifetime + // `'eval` is shorter than `'core`, so `'core` data is coercible to `'eval` via the + // covariance of `Term`. + let eval_bump = Bump::new(); + + let globals: Globals<'_> = program .functions .iter() - .map(|f| { - ( - f.name, - GlobalDef { - sig: &f.sig, - body: f.body, - }, - ) - }) + .map(|f| (f.name, GlobalDef { ty: f.ty, body: f.body })) .collect(); - // Unstage each object-level function; discard meta-level functions. let staged_fns: Vec> = program .functions .iter() - .filter(|f| f.sig.phase == Phase::Object) + .filter(|f| f.pi().phase == Phase::Object) .map(|f| -> Result<_> { - // Build an initial environment: one Obj binding per parameter, - // processing each parameter's type term before pushing the binding - // so that the env is correct for dependent types. + let pi = f.pi(); let mut env = Env::new(Lvl::new(0)); - let staged_params = arena.alloc_slice_try_fill_iter(f.sig.params.iter().map( + let staged_params = arena.alloc_slice_try_fill_iter(pi.params.iter().map( |(n, ty)| -> Result<(&'out str, &'out Term<'out>)> { - let staged_ty = unstage_obj(arena, &globals, &mut env, ty)?; + let staged_ty = unstage_obj(arena, &eval_bump, &globals, &mut env, ty)?; env.push_obj(); Ok((arena.alloc_str(n), staged_ty)) }, ))?; - let staged_ret_ty = unstage_obj(arena, &globals, &mut env, f.sig.ret_ty)?; + let staged_ret_ty = unstage_obj(arena, &eval_bump, &globals, &mut env, pi.body_ty)?; + let staged_body = unstage_obj(arena, &eval_bump, &globals, &mut env, f.body)?; - let staged_body = unstage_obj(arena, &globals, &mut env, f.body)?; + let staged_ty = arena.alloc(Term::Pi(Pi { + params: staged_params, + body_ty: staged_ret_ty, + phase: Phase::Object, + })); Ok(Function { name: Name::new(arena.alloc_str(f.name.as_str())), - sig: FunSig { - params: staged_params, - ret_ty: staged_ret_ty, - phase: f.sig.phase, - }, + ty: staged_ty, body: staged_body, }) }) diff --git a/compiler/src/parser/ast.rs b/compiler/src/parser/ast.rs index a261d72..8a11b56 100644 --- a/compiler/src/parser/ast.rs +++ b/compiler/src/parser/ast.rs @@ -1,9 +1,9 @@ pub use crate::common::{Assoc, BinOp, Name, Phase, UnOp}; /// Function or operator reference -#[derive(Clone, Copy, PartialEq, Eq)] +#[derive(Clone, Copy)] pub enum FunName<'a> { - Name(Name<'a>), + Term(&'a Term<'a>), BinOp(BinOp), UnOp(UnOp), } @@ -11,7 +11,7 @@ pub enum FunName<'a> { impl std::fmt::Debug for FunName<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Name(n) => n.fmt(f), + Self::Term(t) => t.fmt(f), Self::BinOp(o) => o.fmt(f), Self::UnOp(o) => o.fmt(f), } @@ -65,6 +65,16 @@ pub enum Term<'a> { func: FunName<'a>, args: &'a [&'a Self], }, + /// Function type: `fn(name: ty, ...) -> ret_ty` + Pi { + params: &'a [Param<'a>], + ret_ty: &'a Self, + }, + /// Lambda: `|params| body` + Lam { + params: &'a [Param<'a>], + body: &'a Self, + }, Quote(&'a Self), Splice(&'a Self), Lift(&'a Self), diff --git a/compiler/src/parser/mod.rs b/compiler/src/parser/mod.rs index c3b9873..42329c1 100644 --- a/compiler/src/parser/mod.rs +++ b/compiler/src/parser/mod.rs @@ -256,6 +256,21 @@ where }; loop { + if self.peek() == Some(Token::LParen) { + self.next(); + let args = self.parse_separated_list(Token::RParen, |parser| { + parser.parse_expr().context("parsing function argument") + })?; + self.take(Token::RParen) + .context("expected ')' after function arguments")?; + let args = self.arena.alloc_slice_fill_iter(args); + lhs = Term::App { + func: FunName::Term(self.arena.alloc(lhs)), + args, + }; + continue; + } + let Some(op) = self.match_binop() else { break; }; @@ -294,6 +309,7 @@ where #[expect(clippy::wildcard_enum_match_arm)] fn match_binop(&mut self) -> Option { match self.peek()? { + // `|` after an expression is bitwise OR (never lambda — lambdas are atoms) Token::Bar => Some(BinOp::BitOr), Token::Ampersand => Some(BinOp::BitAnd), Token::EqEq => Some(BinOp::Eq), @@ -319,7 +335,7 @@ where .context("expected ')' after function arguments")?; let args = self.arena.alloc_slice_fill_iter(args); Ok(Term::App { - func: FunName::Name(name), + func: FunName::Term(self.arena.alloc(Term::Var(name))), args, }) } @@ -346,6 +362,48 @@ where Ok(Term::Match { scrutinee, arms }) } + /// Parse a function type: `fn(params) -> ret_ty` + /// + /// Called after consuming the `fn` token. Each param is `name: type`. + fn parse_fn_type(&mut self) -> Result> { + self.take(Token::LParen) + .context("expected '(' in function type")?; + let params = self.parse_params()?; + self.take(Token::RParen) + .context("expected ')' in function type")?; + self.take(Token::Arrow) + .context("expected '->' in function type")?; + let ret_ty = self + .parse_expr() + .context("expected return type in function type")?; + Ok(Term::Pi { params, ret_ty }) + } + + /// Parse a lambda expression: `|params| body` + /// + /// Called after consuming the `|` token. Each param is `name: type`. + fn parse_lambda(&mut self) -> Result> { + let params_vec = self.parse_separated_list(Token::Bar, |parser| { + let name = parser + .take_ident() + .context("expected parameter name in lambda")?; + parser + .take(Token::Colon) + .context("expected ':' in lambda parameter (type annotations are required)")?; + let ty = parser + .parse_atom_owned() + .context("expected parameter type")?; + let ty = parser.arena.alloc(ty); + Ok(Param { name, ty }) + })?; + self.take(Token::Bar) + .context("expected '|' after lambda parameters")?; + + let body = self.parse_expr().context("expected lambda body")?; + let params = self.arena.alloc_slice_fill_iter(params_vec); + Ok(Term::Lam { params, body }) + } + #[expect(clippy::wildcard_enum_match_arm)] fn parse_atom_owned(&mut self) -> Result> { let token = self.next().context("expected expression")??; @@ -358,6 +416,10 @@ where Ok(Term::Var(name)) } } + // `fn` not followed by ident → function type expression + Token::Fn => self.parse_fn_type(), + // `|` in atom position → lambda (not bitwise OR, which is infix) + Token::Bar => self.parse_lambda(), Token::LParen => self.parse_paren_expr(), Token::HashLParen => self.parse_quoted_expr(), Token::HashLBrace => self.parse_quoted_block(), diff --git a/compiler/src/parser/test/expr/app.snap.txt b/compiler/src/parser/test/expr/app.snap.txt index 7e23324..bb9cdd9 100644 --- a/compiler/src/parser/test/expr/app.snap.txt +++ b/compiler/src/parser/test/expr/app.snap.txt @@ -1,5 +1,7 @@ App { - func: "f", + func: Var( + "f", + ), args: [ Var( "x", diff --git a/compiler/src/parser/test/expr/complex.snap.txt b/compiler/src/parser/test/expr/complex.snap.txt index 3246487..ee65230 100644 --- a/compiler/src/parser/test/expr/complex.snap.txt +++ b/compiler/src/parser/test/expr/complex.snap.txt @@ -21,7 +21,9 @@ App { func: Not, args: [ App { - func: "foo", + func: Var( + "foo", + ), args: [ Var( "z", diff --git a/compiler/src/parser/test/mod.rs b/compiler/src/parser/test/mod.rs index aef7ce6..ff90fad 100644 --- a/compiler/src/parser/test/mod.rs +++ b/compiler/src/parser/test/mod.rs @@ -72,7 +72,7 @@ fn parse_expr_prec() { match expr { Term::App { func, args } => { assert_eq!(args.len(), 2); - assert_eq!(func, &FunName::BinOp(BinOp::Add)); + assert!(matches!(func, FunName::BinOp(BinOp::Add))); } _ => panic!("expected App"), } @@ -87,7 +87,7 @@ fn parse_expr_prec2() { match expr { Term::App { func, args } => { assert_eq!(args.len(), 2); - assert_eq!(func, &FunName::BinOp(BinOp::Add)); + assert!(matches!(func, FunName::BinOp(BinOp::Add))); } _ => panic!("expected App"), } @@ -102,7 +102,7 @@ fn parse_expr_paren() { match expr { Term::App { func, args } => { assert_eq!(args.len(), 2); - assert_eq!(func, &FunName::BinOp(BinOp::Mul)); + assert!(matches!(func, FunName::BinOp(BinOp::Mul))); } _ => panic!("expected App"), } diff --git a/compiler/tests/snap/full/let_type/0_input.splic b/compiler/tests/snap/full/let_type/0_input.splic new file mode 100644 index 0000000..d424ea0 --- /dev/null +++ b/compiler/tests/snap/full/let_type/0_input.splic @@ -0,0 +1,7 @@ +fn let_type() -> u32 { + let ty: Type = u32; + let x: ty = 1337; + x +} + +code fn test() -> u32 { $(let_type()) } diff --git a/compiler/tests/snap/full/let_type/1_lex.txt b/compiler/tests/snap/full/let_type/1_lex.txt new file mode 100644 index 0000000..173385d --- /dev/null +++ b/compiler/tests/snap/full/let_type/1_lex.txt @@ -0,0 +1,37 @@ +Fn +Ident("let_type") +LParen +RParen +Arrow +Ident("u32") +LBrace +Let +Ident("ty") +Colon +Ident("Type") +Eq +Ident("u32") +Semi +Let +Ident("x") +Colon +Ident("ty") +Eq +Num(1337) +Semi +Ident("x") +RBrace +Code +Fn +Ident("test") +LParen +RParen +Arrow +Ident("u32") +LBrace +DollarLParen +Ident("let_type") +LParen +RParen +RParen +RBrace diff --git a/compiler/tests/snap/full/let_type/2_parse.txt b/compiler/tests/snap/full/let_type/2_parse.txt new file mode 100644 index 0000000..bfdefb7 --- /dev/null +++ b/compiler/tests/snap/full/let_type/2_parse.txt @@ -0,0 +1,60 @@ +Program { + functions: [ + Function { + phase: Meta, + name: "let_type", + params: [], + ret_ty: Var( + "u32", + ), + body: Block { + stmts: [ + Let { + name: "ty", + ty: Some( + Var( + "Type", + ), + ), + expr: Var( + "u32", + ), + }, + Let { + name: "x", + ty: Some( + Var( + "ty", + ), + ), + expr: Lit( + 1337, + ), + }, + ], + expr: Var( + "x", + ), + }, + }, + Function { + phase: Object, + name: "test", + params: [], + ret_ty: Var( + "u32", + ), + body: Block { + stmts: [], + expr: Splice( + App { + func: Var( + "let_type", + ), + args: [], + }, + ), + }, + }, + ], +} diff --git a/compiler/tests/snap/full/let_type/3_check.txt b/compiler/tests/snap/full/let_type/3_check.txt new file mode 100644 index 0000000..0ea29b3 --- /dev/null +++ b/compiler/tests/snap/full/let_type/3_check.txt @@ -0,0 +1,2 @@ +ERROR +in function `let_type`: in let binding `x`: literal `1337` cannot have a non-integer type diff --git a/compiler/tests/snap/full/pi_apply_non_fn/0_input.splic b/compiler/tests/snap/full/pi_apply_non_fn/0_input.splic new file mode 100644 index 0000000..a45916a --- /dev/null +++ b/compiler/tests/snap/full/pi_apply_non_fn/0_input.splic @@ -0,0 +1,5 @@ +// ERROR: applying a non-function value +fn test() -> u64 { + let x: u64 = 42; + x(1) +} diff --git a/compiler/tests/snap/full/pi_apply_non_fn/1_lex.txt b/compiler/tests/snap/full/pi_apply_non_fn/1_lex.txt new file mode 100644 index 0000000..c417ba6 --- /dev/null +++ b/compiler/tests/snap/full/pi_apply_non_fn/1_lex.txt @@ -0,0 +1,19 @@ +Fn +Ident("test") +LParen +RParen +Arrow +Ident("u64") +LBrace +Let +Ident("x") +Colon +Ident("u64") +Eq +Num(42) +Semi +Ident("x") +LParen +Num(1) +RParen +RBrace diff --git a/compiler/tests/snap/full/pi_apply_non_fn/2_parse.txt b/compiler/tests/snap/full/pi_apply_non_fn/2_parse.txt new file mode 100644 index 0000000..fcdc77a --- /dev/null +++ b/compiler/tests/snap/full/pi_apply_non_fn/2_parse.txt @@ -0,0 +1,37 @@ +Program { + functions: [ + Function { + phase: Meta, + name: "test", + params: [], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [ + Let { + name: "x", + ty: Some( + Var( + "u64", + ), + ), + expr: Lit( + 42, + ), + }, + ], + expr: App { + func: Var( + "x", + ), + args: [ + Lit( + 1, + ), + ], + }, + }, + }, + ], +} diff --git a/compiler/tests/snap/full/pi_apply_non_fn/3_check.txt b/compiler/tests/snap/full/pi_apply_non_fn/3_check.txt new file mode 100644 index 0000000..ad5e030 --- /dev/null +++ b/compiler/tests/snap/full/pi_apply_non_fn/3_check.txt @@ -0,0 +1,2 @@ +ERROR +in function `test`: callee is not a function type diff --git a/compiler/tests/snap/full/pi_arity_mismatch/0_input.splic b/compiler/tests/snap/full/pi_arity_mismatch/0_input.splic new file mode 100644 index 0000000..63db83c --- /dev/null +++ b/compiler/tests/snap/full/pi_arity_mismatch/0_input.splic @@ -0,0 +1,4 @@ +// ERROR: too many arguments to a function-typed variable +fn apply(f: fn(_: u64) -> u64, x: u64) -> u64 { + f(x, x) +} diff --git a/compiler/tests/snap/full/pi_arity_mismatch/1_lex.txt b/compiler/tests/snap/full/pi_arity_mismatch/1_lex.txt new file mode 100644 index 0000000..d009de6 --- /dev/null +++ b/compiler/tests/snap/full/pi_arity_mismatch/1_lex.txt @@ -0,0 +1,28 @@ +Fn +Ident("apply") +LParen +Ident("f") +Colon +Fn +LParen +Ident("_") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +Comma +Ident("x") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +LBrace +Ident("f") +LParen +Ident("x") +Comma +Ident("x") +RParen +RBrace diff --git a/compiler/tests/snap/full/pi_arity_mismatch/2_parse.txt b/compiler/tests/snap/full/pi_arity_mismatch/2_parse.txt new file mode 100644 index 0000000..878f1a8 --- /dev/null +++ b/compiler/tests/snap/full/pi_arity_mismatch/2_parse.txt @@ -0,0 +1,51 @@ +Program { + functions: [ + Function { + phase: Meta, + name: "apply", + params: [ + Param { + name: "f", + ty: Pi { + params: [ + Param { + name: "_", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Var( + "u64", + ), + }, + }, + Param { + name: "x", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: App { + func: Var( + "f", + ), + args: [ + Var( + "x", + ), + Var( + "x", + ), + ], + }, + }, + }, + ], +} diff --git a/compiler/tests/snap/full/pi_arity_mismatch/3_check.txt b/compiler/tests/snap/full/pi_arity_mismatch/3_check.txt new file mode 100644 index 0000000..4051012 --- /dev/null +++ b/compiler/tests/snap/full/pi_arity_mismatch/3_check.txt @@ -0,0 +1,2 @@ +ERROR +in function `apply`: wrong number of arguments: callee expects 1, got 2 diff --git a/compiler/tests/snap/full/pi_arity_mismatch_too_few/0_input.splic b/compiler/tests/snap/full/pi_arity_mismatch_too_few/0_input.splic new file mode 100644 index 0000000..3b84764 --- /dev/null +++ b/compiler/tests/snap/full/pi_arity_mismatch_too_few/0_input.splic @@ -0,0 +1,8 @@ +// ERROR: too few arguments to a two-arg function +fn add(x: u64, y: u64) -> u64 { + x + y +} + +fn test() -> u64 { + add(1) +} diff --git a/compiler/tests/snap/full/pi_arity_mismatch_too_few/1_lex.txt b/compiler/tests/snap/full/pi_arity_mismatch_too_few/1_lex.txt new file mode 100644 index 0000000..7cf4c46 --- /dev/null +++ b/compiler/tests/snap/full/pi_arity_mismatch_too_few/1_lex.txt @@ -0,0 +1,30 @@ +Fn +Ident("add") +LParen +Ident("x") +Colon +Ident("u64") +Comma +Ident("y") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +LBrace +Ident("x") +Plus +Ident("y") +RBrace +Fn +Ident("test") +LParen +RParen +Arrow +Ident("u64") +LBrace +Ident("add") +LParen +Num(1) +RParen +RBrace diff --git a/compiler/tests/snap/full/pi_arity_mismatch_too_few/2_parse.txt b/compiler/tests/snap/full/pi_arity_mismatch_too_few/2_parse.txt new file mode 100644 index 0000000..8fac109 --- /dev/null +++ b/compiler/tests/snap/full/pi_arity_mismatch_too_few/2_parse.txt @@ -0,0 +1,60 @@ +Program { + functions: [ + Function { + phase: Meta, + name: "add", + params: [ + Param { + name: "x", + ty: Var( + "u64", + ), + }, + Param { + name: "y", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: App { + func: Add, + args: [ + Var( + "x", + ), + Var( + "y", + ), + ], + }, + }, + }, + Function { + phase: Meta, + name: "test", + params: [], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: App { + func: Var( + "add", + ), + args: [ + Lit( + 1, + ), + ], + }, + }, + }, + ], +} diff --git a/compiler/tests/snap/full/pi_arity_mismatch_too_few/3_check.txt b/compiler/tests/snap/full/pi_arity_mismatch_too_few/3_check.txt new file mode 100644 index 0000000..c24ac68 --- /dev/null +++ b/compiler/tests/snap/full/pi_arity_mismatch_too_few/3_check.txt @@ -0,0 +1,2 @@ +ERROR +in function `test`: wrong number of arguments: callee expects 2, got 1 diff --git a/compiler/tests/snap/full/pi_basic/0_input.splic b/compiler/tests/snap/full/pi_basic/0_input.splic new file mode 100644 index 0000000..19cb53f --- /dev/null +++ b/compiler/tests/snap/full/pi_basic/0_input.splic @@ -0,0 +1,12 @@ +// Higher-order function: pass a function as an argument +fn apply(f: fn(_: u64) -> u64, x: u64) -> u64 { + f(x) +} + +fn inc(x: u64) -> u64 { x + 1 } + +fn test() -> u64 { + apply(inc, 42) +} + +code fn result() -> u64 { $(test()) } diff --git a/compiler/tests/snap/full/pi_basic/1_lex.txt b/compiler/tests/snap/full/pi_basic/1_lex.txt new file mode 100644 index 0000000..07abe97 --- /dev/null +++ b/compiler/tests/snap/full/pi_basic/1_lex.txt @@ -0,0 +1,68 @@ +Fn +Ident("apply") +LParen +Ident("f") +Colon +Fn +LParen +Ident("_") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +Comma +Ident("x") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +LBrace +Ident("f") +LParen +Ident("x") +RParen +RBrace +Fn +Ident("inc") +LParen +Ident("x") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +LBrace +Ident("x") +Plus +Num(1) +RBrace +Fn +Ident("test") +LParen +RParen +Arrow +Ident("u64") +LBrace +Ident("apply") +LParen +Ident("inc") +Comma +Num(42) +RParen +RBrace +Code +Fn +Ident("result") +LParen +RParen +Arrow +Ident("u64") +LBrace +DollarLParen +Ident("test") +LParen +RParen +RParen +RBrace diff --git a/compiler/tests/snap/full/pi_basic/2_parse.txt b/compiler/tests/snap/full/pi_basic/2_parse.txt new file mode 100644 index 0000000..de3e8f5 --- /dev/null +++ b/compiler/tests/snap/full/pi_basic/2_parse.txt @@ -0,0 +1,120 @@ +Program { + functions: [ + Function { + phase: Meta, + name: "apply", + params: [ + Param { + name: "f", + ty: Pi { + params: [ + Param { + name: "_", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Var( + "u64", + ), + }, + }, + Param { + name: "x", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: App { + func: Var( + "f", + ), + args: [ + Var( + "x", + ), + ], + }, + }, + }, + Function { + phase: Meta, + name: "inc", + params: [ + Param { + name: "x", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: App { + func: Add, + args: [ + Var( + "x", + ), + Lit( + 1, + ), + ], + }, + }, + }, + Function { + phase: Meta, + name: "test", + params: [], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: App { + func: Var( + "apply", + ), + args: [ + Var( + "inc", + ), + Lit( + 42, + ), + ], + }, + }, + }, + Function { + phase: Object, + name: "result", + params: [], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: Splice( + App { + func: Var( + "test", + ), + args: [], + }, + ), + }, + }, + ], +} diff --git a/compiler/tests/snap/full/pi_basic/3_check.txt b/compiler/tests/snap/full/pi_basic/3_check.txt new file mode 100644 index 0000000..aaab64c --- /dev/null +++ b/compiler/tests/snap/full/pi_basic/3_check.txt @@ -0,0 +1,16 @@ +fn apply(f@0: fn(_: u64) -> u64, x@1: u64) -> u64 { + f@0(x@1) +} + +fn inc(x@0: u64) -> u64 { + @add_u64(x@0, 1_u64) +} + +fn test() -> u64 { + apply(inc, 42_u64) +} + +code fn result() -> u64 { + $(@embed_u64(test())) +} + diff --git a/compiler/tests/snap/full/pi_basic/6_stage.txt b/compiler/tests/snap/full/pi_basic/6_stage.txt new file mode 100644 index 0000000..0ae8d94 --- /dev/null +++ b/compiler/tests/snap/full/pi_basic/6_stage.txt @@ -0,0 +1,4 @@ +code fn result() -> u64 { + 43_u64 +} + diff --git a/compiler/tests/snap/full/pi_compose/0_input.splic b/compiler/tests/snap/full/pi_compose/0_input.splic new file mode 100644 index 0000000..55d74a9 --- /dev/null +++ b/compiler/tests/snap/full/pi_compose/0_input.splic @@ -0,0 +1,13 @@ +// Function composition (monomorphic) +fn compose(f: fn(_: u64) -> u64, g: fn(_: u64) -> u64) -> fn(_: u64) -> u64 { + |x: u64| f(g(x)) +} + +fn double(x: u64) -> u64 { x + x } +fn inc(x: u64) -> u64 { x + 1 } + +fn test() -> u64 { + compose(double, inc)(5) +} + +code fn result() -> u64 { $(test()) } diff --git a/compiler/tests/snap/full/pi_compose/1_lex.txt b/compiler/tests/snap/full/pi_compose/1_lex.txt new file mode 100644 index 0000000..72121d6 --- /dev/null +++ b/compiler/tests/snap/full/pi_compose/1_lex.txt @@ -0,0 +1,107 @@ +Fn +Ident("compose") +LParen +Ident("f") +Colon +Fn +LParen +Ident("_") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +Comma +Ident("g") +Colon +Fn +LParen +Ident("_") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +RParen +Arrow +Fn +LParen +Ident("_") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +LBrace +Bar +Ident("x") +Colon +Ident("u64") +Bar +Ident("f") +LParen +Ident("g") +LParen +Ident("x") +RParen +RParen +RBrace +Fn +Ident("double") +LParen +Ident("x") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +LBrace +Ident("x") +Plus +Ident("x") +RBrace +Fn +Ident("inc") +LParen +Ident("x") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +LBrace +Ident("x") +Plus +Num(1) +RBrace +Fn +Ident("test") +LParen +RParen +Arrow +Ident("u64") +LBrace +Ident("compose") +LParen +Ident("double") +Comma +Ident("inc") +RParen +LParen +Num(5) +RParen +RBrace +Code +Fn +Ident("result") +LParen +RParen +Arrow +Ident("u64") +LBrace +DollarLParen +Ident("test") +LParen +RParen +RParen +RBrace diff --git a/compiler/tests/snap/full/pi_compose/2_parse.txt b/compiler/tests/snap/full/pi_compose/2_parse.txt new file mode 100644 index 0000000..596643d --- /dev/null +++ b/compiler/tests/snap/full/pi_compose/2_parse.txt @@ -0,0 +1,193 @@ +Program { + functions: [ + Function { + phase: Meta, + name: "compose", + params: [ + Param { + name: "f", + ty: Pi { + params: [ + Param { + name: "_", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Var( + "u64", + ), + }, + }, + Param { + name: "g", + ty: Pi { + params: [ + Param { + name: "_", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Var( + "u64", + ), + }, + }, + ], + ret_ty: Pi { + params: [ + Param { + name: "_", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Var( + "u64", + ), + }, + body: Block { + stmts: [], + expr: Lam { + params: [ + Param { + name: "x", + ty: Var( + "u64", + ), + }, + ], + body: App { + func: Var( + "f", + ), + args: [ + App { + func: Var( + "g", + ), + args: [ + Var( + "x", + ), + ], + }, + ], + }, + }, + }, + }, + Function { + phase: Meta, + name: "double", + params: [ + Param { + name: "x", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: App { + func: Add, + args: [ + Var( + "x", + ), + Var( + "x", + ), + ], + }, + }, + }, + Function { + phase: Meta, + name: "inc", + params: [ + Param { + name: "x", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: App { + func: Add, + args: [ + Var( + "x", + ), + Lit( + 1, + ), + ], + }, + }, + }, + Function { + phase: Meta, + name: "test", + params: [], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: App { + func: App { + func: Var( + "compose", + ), + args: [ + Var( + "double", + ), + Var( + "inc", + ), + ], + }, + args: [ + Lit( + 5, + ), + ], + }, + }, + }, + Function { + phase: Object, + name: "result", + params: [], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: Splice( + App { + func: Var( + "test", + ), + args: [], + }, + ), + }, + }, + ], +} diff --git a/compiler/tests/snap/full/pi_compose/3_check.txt b/compiler/tests/snap/full/pi_compose/3_check.txt new file mode 100644 index 0000000..8935a86 --- /dev/null +++ b/compiler/tests/snap/full/pi_compose/3_check.txt @@ -0,0 +1,20 @@ +fn compose(f@0: fn(_: u64) -> u64, g@1: fn(_: u64) -> u64) -> fn(_: u64) -> u64 { + |x@2: u64| f@0(g@1(x@2)) +} + +fn double(x@0: u64) -> u64 { + @add_u64(x@0, x@0) +} + +fn inc(x@0: u64) -> u64 { + @add_u64(x@0, 1_u64) +} + +fn test() -> u64 { + compose(double, inc)(5_u64) +} + +code fn result() -> u64 { + $(@embed_u64(test())) +} + diff --git a/compiler/tests/snap/full/pi_compose/6_stage.txt b/compiler/tests/snap/full/pi_compose/6_stage.txt new file mode 100644 index 0000000..5b9ed78 --- /dev/null +++ b/compiler/tests/snap/full/pi_compose/6_stage.txt @@ -0,0 +1,4 @@ +code fn result() -> u64 { + 12_u64 +} + diff --git a/compiler/tests/snap/full/pi_const/0_input.splic b/compiler/tests/snap/full/pi_const/0_input.splic new file mode 100644 index 0000000..837b0d4 --- /dev/null +++ b/compiler/tests/snap/full/pi_const/0_input.splic @@ -0,0 +1,10 @@ +// Const combinator: returns a function that ignores its argument +fn const_(A: Type, B: Type) -> fn(_: A) -> fn(_: B) -> A { + |a: A| |b: B| a +} + +fn test() -> u64 { + const_(u64, u8)(42)(7) +} + +code fn result() -> u64 { $(test()) } diff --git a/compiler/tests/snap/full/pi_const/1_lex.txt b/compiler/tests/snap/full/pi_const/1_lex.txt new file mode 100644 index 0000000..26dee3f --- /dev/null +++ b/compiler/tests/snap/full/pi_const/1_lex.txt @@ -0,0 +1,74 @@ +Fn +Ident("const_") +LParen +Ident("A") +Colon +Ident("Type") +Comma +Ident("B") +Colon +Ident("Type") +RParen +Arrow +Fn +LParen +Ident("_") +Colon +Ident("A") +RParen +Arrow +Fn +LParen +Ident("_") +Colon +Ident("B") +RParen +Arrow +Ident("A") +LBrace +Bar +Ident("a") +Colon +Ident("A") +Bar +Bar +Ident("b") +Colon +Ident("B") +Bar +Ident("a") +RBrace +Fn +Ident("test") +LParen +RParen +Arrow +Ident("u64") +LBrace +Ident("const_") +LParen +Ident("u64") +Comma +Ident("u8") +RParen +LParen +Num(42) +RParen +LParen +Num(7) +RParen +RBrace +Code +Fn +Ident("result") +LParen +RParen +Arrow +Ident("u64") +LBrace +DollarLParen +Ident("test") +LParen +RParen +RParen +RBrace diff --git a/compiler/tests/snap/full/pi_const/2_parse.txt b/compiler/tests/snap/full/pi_const/2_parse.txt new file mode 100644 index 0000000..82b89ee --- /dev/null +++ b/compiler/tests/snap/full/pi_const/2_parse.txt @@ -0,0 +1,128 @@ +Program { + functions: [ + Function { + phase: Meta, + name: "const_", + params: [ + Param { + name: "A", + ty: Var( + "Type", + ), + }, + Param { + name: "B", + ty: Var( + "Type", + ), + }, + ], + ret_ty: Pi { + params: [ + Param { + name: "_", + ty: Var( + "A", + ), + }, + ], + ret_ty: Pi { + params: [ + Param { + name: "_", + ty: Var( + "B", + ), + }, + ], + ret_ty: Var( + "A", + ), + }, + }, + body: Block { + stmts: [], + expr: Lam { + params: [ + Param { + name: "a", + ty: Var( + "A", + ), + }, + ], + body: Lam { + params: [ + Param { + name: "b", + ty: Var( + "B", + ), + }, + ], + body: Var( + "a", + ), + }, + }, + }, + }, + Function { + phase: Meta, + name: "test", + params: [], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: App { + func: App { + func: App { + func: Var( + "const_", + ), + args: [ + Var( + "u64", + ), + Var( + "u8", + ), + ], + }, + args: [ + Lit( + 42, + ), + ], + }, + args: [ + Lit( + 7, + ), + ], + }, + }, + }, + Function { + phase: Object, + name: "result", + params: [], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: Splice( + App { + func: Var( + "test", + ), + args: [], + }, + ), + }, + }, + ], +} diff --git a/compiler/tests/snap/full/pi_const/3_check.txt b/compiler/tests/snap/full/pi_const/3_check.txt new file mode 100644 index 0000000..5c4d60c --- /dev/null +++ b/compiler/tests/snap/full/pi_const/3_check.txt @@ -0,0 +1,12 @@ +fn const_(A@0: Type, B@1: Type) -> fn(_: A@0) -> fn(_: B@1) -> A@0 { + |a@2: A@0| |b@3: B@1| a@2 +} + +fn test() -> u64 { + const_(u64, u8)(42_u64)(7_u8) +} + +code fn result() -> u64 { + $(@embed_u64(test())) +} + diff --git a/compiler/tests/snap/full/pi_const/6_stage.txt b/compiler/tests/snap/full/pi_const/6_stage.txt new file mode 100644 index 0000000..8e18575 --- /dev/null +++ b/compiler/tests/snap/full/pi_const/6_stage.txt @@ -0,0 +1,4 @@ +code fn result() -> u64 { + 42_u64 +} + diff --git a/compiler/tests/snap/full/pi_dependent_ret/0_input.splic b/compiler/tests/snap/full/pi_dependent_ret/0_input.splic new file mode 100644 index 0000000..bde60a5 --- /dev/null +++ b/compiler/tests/snap/full/pi_dependent_ret/0_input.splic @@ -0,0 +1,9 @@ +// Return type depends on the first type argument. +// const_(A, B, a, b) ignores b and returns a : A. +fn const_(A: Type, B: Type, a: A, b: B) -> A { a } + +fn test_u64() -> u64 { const_(u64, u8, 10, 7) } +fn test_u8() -> u8 { const_(u8, u64, 7, 10) } + +code fn result_u64() -> u64 { $(test_u64()) } +code fn result_u8() -> u8 { $(test_u8()) } diff --git a/compiler/tests/snap/full/pi_dependent_ret/1_lex.txt b/compiler/tests/snap/full/pi_dependent_ret/1_lex.txt new file mode 100644 index 0000000..1dcecd2 --- /dev/null +++ b/compiler/tests/snap/full/pi_dependent_ret/1_lex.txt @@ -0,0 +1,88 @@ +Fn +Ident("const_") +LParen +Ident("A") +Colon +Ident("Type") +Comma +Ident("B") +Colon +Ident("Type") +Comma +Ident("a") +Colon +Ident("A") +Comma +Ident("b") +Colon +Ident("B") +RParen +Arrow +Ident("A") +LBrace +Ident("a") +RBrace +Fn +Ident("test_u64") +LParen +RParen +Arrow +Ident("u64") +LBrace +Ident("const_") +LParen +Ident("u64") +Comma +Ident("u8") +Comma +Num(10) +Comma +Num(7) +RParen +RBrace +Fn +Ident("test_u8") +LParen +RParen +Arrow +Ident("u8") +LBrace +Ident("const_") +LParen +Ident("u8") +Comma +Ident("u64") +Comma +Num(7) +Comma +Num(10) +RParen +RBrace +Code +Fn +Ident("result_u64") +LParen +RParen +Arrow +Ident("u64") +LBrace +DollarLParen +Ident("test_u64") +LParen +RParen +RParen +RBrace +Code +Fn +Ident("result_u8") +LParen +RParen +Arrow +Ident("u8") +LBrace +DollarLParen +Ident("test_u8") +LParen +RParen +RParen +RBrace diff --git a/compiler/tests/snap/full/pi_dependent_ret/2_parse.txt b/compiler/tests/snap/full/pi_dependent_ret/2_parse.txt new file mode 100644 index 0000000..bb3e28d --- /dev/null +++ b/compiler/tests/snap/full/pi_dependent_ret/2_parse.txt @@ -0,0 +1,141 @@ +Program { + functions: [ + Function { + phase: Meta, + name: "const_", + params: [ + Param { + name: "A", + ty: Var( + "Type", + ), + }, + Param { + name: "B", + ty: Var( + "Type", + ), + }, + Param { + name: "a", + ty: Var( + "A", + ), + }, + Param { + name: "b", + ty: Var( + "B", + ), + }, + ], + ret_ty: Var( + "A", + ), + body: Block { + stmts: [], + expr: Var( + "a", + ), + }, + }, + Function { + phase: Meta, + name: "test_u64", + params: [], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: App { + func: Var( + "const_", + ), + args: [ + Var( + "u64", + ), + Var( + "u8", + ), + Lit( + 10, + ), + Lit( + 7, + ), + ], + }, + }, + }, + Function { + phase: Meta, + name: "test_u8", + params: [], + ret_ty: Var( + "u8", + ), + body: Block { + stmts: [], + expr: App { + func: Var( + "const_", + ), + args: [ + Var( + "u8", + ), + Var( + "u64", + ), + Lit( + 7, + ), + Lit( + 10, + ), + ], + }, + }, + }, + Function { + phase: Object, + name: "result_u64", + params: [], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: Splice( + App { + func: Var( + "test_u64", + ), + args: [], + }, + ), + }, + }, + Function { + phase: Object, + name: "result_u8", + params: [], + ret_ty: Var( + "u8", + ), + body: Block { + stmts: [], + expr: Splice( + App { + func: Var( + "test_u8", + ), + args: [], + }, + ), + }, + }, + ], +} diff --git a/compiler/tests/snap/full/pi_dependent_ret/3_check.txt b/compiler/tests/snap/full/pi_dependent_ret/3_check.txt new file mode 100644 index 0000000..e9417a9 --- /dev/null +++ b/compiler/tests/snap/full/pi_dependent_ret/3_check.txt @@ -0,0 +1,20 @@ +fn const_(A@0: Type, B@1: Type, a@2: A@0, b@3: B@1) -> A@0 { + a@2 +} + +fn test_u64() -> u64 { + const_(u64, u8, 10_u64, 7_u8) +} + +fn test_u8() -> u8 { + const_(u8, u64, 7_u8, 10_u64) +} + +code fn result_u64() -> u64 { + $(@embed_u64(test_u64())) +} + +code fn result_u8() -> u8 { + $(@embed_u8(test_u8())) +} + diff --git a/compiler/tests/snap/full/pi_dependent_ret/6_stage.txt b/compiler/tests/snap/full/pi_dependent_ret/6_stage.txt new file mode 100644 index 0000000..2d777e3 --- /dev/null +++ b/compiler/tests/snap/full/pi_dependent_ret/6_stage.txt @@ -0,0 +1,8 @@ +code fn result_u64() -> u64 { + 10_u64 +} + +code fn result_u8() -> u8 { + 7_u8 +} + diff --git a/compiler/tests/snap/full/pi_lambda_arg/0_input.splic b/compiler/tests/snap/full/pi_lambda_arg/0_input.splic new file mode 100644 index 0000000..8b92c4c --- /dev/null +++ b/compiler/tests/snap/full/pi_lambda_arg/0_input.splic @@ -0,0 +1,10 @@ +// Pass a lambda as an argument to a higher-order function +fn apply(f: fn(_: u64) -> u64, x: u64) -> u64 { + f(x) +} + +fn test() -> u64 { + apply(|x: u64| x + 1, 42) +} + +code fn result() -> u64 { $(test()) } diff --git a/compiler/tests/snap/full/pi_lambda_arg/1_lex.txt b/compiler/tests/snap/full/pi_lambda_arg/1_lex.txt new file mode 100644 index 0000000..175ebe7 --- /dev/null +++ b/compiler/tests/snap/full/pi_lambda_arg/1_lex.txt @@ -0,0 +1,61 @@ +Fn +Ident("apply") +LParen +Ident("f") +Colon +Fn +LParen +Ident("_") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +Comma +Ident("x") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +LBrace +Ident("f") +LParen +Ident("x") +RParen +RBrace +Fn +Ident("test") +LParen +RParen +Arrow +Ident("u64") +LBrace +Ident("apply") +LParen +Bar +Ident("x") +Colon +Ident("u64") +Bar +Ident("x") +Plus +Num(1) +Comma +Num(42) +RParen +RBrace +Code +Fn +Ident("result") +LParen +RParen +Arrow +Ident("u64") +LBrace +DollarLParen +Ident("test") +LParen +RParen +RParen +RBrace diff --git a/compiler/tests/snap/full/pi_lambda_arg/2_parse.txt b/compiler/tests/snap/full/pi_lambda_arg/2_parse.txt new file mode 100644 index 0000000..06128c6 --- /dev/null +++ b/compiler/tests/snap/full/pi_lambda_arg/2_parse.txt @@ -0,0 +1,109 @@ +Program { + functions: [ + Function { + phase: Meta, + name: "apply", + params: [ + Param { + name: "f", + ty: Pi { + params: [ + Param { + name: "_", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Var( + "u64", + ), + }, + }, + Param { + name: "x", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: App { + func: Var( + "f", + ), + args: [ + Var( + "x", + ), + ], + }, + }, + }, + Function { + phase: Meta, + name: "test", + params: [], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: App { + func: Var( + "apply", + ), + args: [ + Lam { + params: [ + Param { + name: "x", + ty: Var( + "u64", + ), + }, + ], + body: App { + func: Add, + args: [ + Var( + "x", + ), + Lit( + 1, + ), + ], + }, + }, + Lit( + 42, + ), + ], + }, + }, + }, + Function { + phase: Object, + name: "result", + params: [], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: Splice( + App { + func: Var( + "test", + ), + args: [], + }, + ), + }, + }, + ], +} diff --git a/compiler/tests/snap/full/pi_lambda_arg/3_check.txt b/compiler/tests/snap/full/pi_lambda_arg/3_check.txt new file mode 100644 index 0000000..f9a981a --- /dev/null +++ b/compiler/tests/snap/full/pi_lambda_arg/3_check.txt @@ -0,0 +1,12 @@ +fn apply(f@0: fn(_: u64) -> u64, x@1: u64) -> u64 { + f@0(x@1) +} + +fn test() -> u64 { + apply(|x@0: u64| @add_u64(x@0, 1_u64), 42_u64) +} + +code fn result() -> u64 { + $(@embed_u64(test())) +} + diff --git a/compiler/tests/snap/full/pi_lambda_arg/6_stage.txt b/compiler/tests/snap/full/pi_lambda_arg/6_stage.txt new file mode 100644 index 0000000..0ae8d94 --- /dev/null +++ b/compiler/tests/snap/full/pi_lambda_arg/6_stage.txt @@ -0,0 +1,4 @@ +code fn result() -> u64 { + 43_u64 +} + diff --git a/compiler/tests/snap/full/pi_lambda_empty_params/0_input.splic b/compiler/tests/snap/full/pi_lambda_empty_params/0_input.splic new file mode 100644 index 0000000..3c9092e --- /dev/null +++ b/compiler/tests/snap/full/pi_lambda_empty_params/0_input.splic @@ -0,0 +1,10 @@ +// A lambda with an empty parameter list creates a thunk +fn make_thunk(x: u64) -> fn() -> u64 { + || x +} + +fn test() -> u64 { + make_thunk(42)() +} + +code fn result() -> u64 { $(test()) } diff --git a/compiler/tests/snap/full/pi_lambda_empty_params/1_lex.txt b/compiler/tests/snap/full/pi_lambda_empty_params/1_lex.txt new file mode 100644 index 0000000..2a1f00a --- /dev/null +++ b/compiler/tests/snap/full/pi_lambda_empty_params/1_lex.txt @@ -0,0 +1,46 @@ +Fn +Ident("make_thunk") +LParen +Ident("x") +Colon +Ident("u64") +RParen +Arrow +Fn +LParen +RParen +Arrow +Ident("u64") +LBrace +Bar +Bar +Ident("x") +RBrace +Fn +Ident("test") +LParen +RParen +Arrow +Ident("u64") +LBrace +Ident("make_thunk") +LParen +Num(42) +RParen +LParen +RParen +RBrace +Code +Fn +Ident("result") +LParen +RParen +Arrow +Ident("u64") +LBrace +DollarLParen +Ident("test") +LParen +RParen +RParen +RBrace diff --git a/compiler/tests/snap/full/pi_lambda_empty_params/2_parse.txt b/compiler/tests/snap/full/pi_lambda_empty_params/2_parse.txt new file mode 100644 index 0000000..01110b7 --- /dev/null +++ b/compiler/tests/snap/full/pi_lambda_empty_params/2_parse.txt @@ -0,0 +1,74 @@ +Program { + functions: [ + Function { + phase: Meta, + name: "make_thunk", + params: [ + Param { + name: "x", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Pi { + params: [], + ret_ty: Var( + "u64", + ), + }, + body: Block { + stmts: [], + expr: Lam { + params: [], + body: Var( + "x", + ), + }, + }, + }, + Function { + phase: Meta, + name: "test", + params: [], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: App { + func: App { + func: Var( + "make_thunk", + ), + args: [ + Lit( + 42, + ), + ], + }, + args: [], + }, + }, + }, + Function { + phase: Object, + name: "result", + params: [], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: Splice( + App { + func: Var( + "test", + ), + args: [], + }, + ), + }, + }, + ], +} diff --git a/compiler/tests/snap/full/pi_lambda_empty_params/3_check.txt b/compiler/tests/snap/full/pi_lambda_empty_params/3_check.txt new file mode 100644 index 0000000..c9b7ada --- /dev/null +++ b/compiler/tests/snap/full/pi_lambda_empty_params/3_check.txt @@ -0,0 +1,12 @@ +fn make_thunk(x@0: u64) -> fn() -> u64 { + || x@0 +} + +fn test() -> u64 { + make_thunk(42_u64)() +} + +code fn result() -> u64 { + $(@embed_u64(test())) +} + diff --git a/compiler/tests/snap/full/pi_lambda_empty_params/6_stage.txt b/compiler/tests/snap/full/pi_lambda_empty_params/6_stage.txt new file mode 100644 index 0000000..8e18575 --- /dev/null +++ b/compiler/tests/snap/full/pi_lambda_empty_params/6_stage.txt @@ -0,0 +1,4 @@ +code fn result() -> u64 { + 42_u64 +} + diff --git a/compiler/tests/snap/full/pi_lambda_in_object/0_input.splic b/compiler/tests/snap/full/pi_lambda_in_object/0_input.splic new file mode 100644 index 0000000..69981b4 --- /dev/null +++ b/compiler/tests/snap/full/pi_lambda_in_object/0_input.splic @@ -0,0 +1,4 @@ +// ERROR: lambda in object-level function body (meta-level only) +code fn test(x: u64) -> u64 { + |y: u64| y +} diff --git a/compiler/tests/snap/full/pi_lambda_in_object/1_lex.txt b/compiler/tests/snap/full/pi_lambda_in_object/1_lex.txt new file mode 100644 index 0000000..6f1a5b5 --- /dev/null +++ b/compiler/tests/snap/full/pi_lambda_in_object/1_lex.txt @@ -0,0 +1,18 @@ +Code +Fn +Ident("test") +LParen +Ident("x") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +LBrace +Bar +Ident("y") +Colon +Ident("u64") +Bar +Ident("y") +RBrace diff --git a/compiler/tests/snap/full/pi_lambda_in_object/2_parse.txt b/compiler/tests/snap/full/pi_lambda_in_object/2_parse.txt new file mode 100644 index 0000000..9932c1f --- /dev/null +++ b/compiler/tests/snap/full/pi_lambda_in_object/2_parse.txt @@ -0,0 +1,35 @@ +Program { + functions: [ + Function { + phase: Object, + name: "test", + params: [ + Param { + name: "x", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: Lam { + params: [ + Param { + name: "y", + ty: Var( + "u64", + ), + }, + ], + body: Var( + "y", + ), + }, + }, + }, + ], +} diff --git a/compiler/tests/snap/full/pi_lambda_in_object/3_check.txt b/compiler/tests/snap/full/pi_lambda_in_object/3_check.txt new file mode 100644 index 0000000..7e5df17 --- /dev/null +++ b/compiler/tests/snap/full/pi_lambda_in_object/3_check.txt @@ -0,0 +1,2 @@ +ERROR +in function `test`: lambdas are only valid in meta-phase context diff --git a/compiler/tests/snap/full/pi_lambda_missing_annotation/0_input.splic b/compiler/tests/snap/full/pi_lambda_missing_annotation/0_input.splic new file mode 100644 index 0000000..70b6935 --- /dev/null +++ b/compiler/tests/snap/full/pi_lambda_missing_annotation/0_input.splic @@ -0,0 +1,8 @@ +// ERROR: lambda without type annotation should fail +fn apply(f: fn(_: u64) -> u64, x: u64) -> u64 { + f(x) +} + +fn test() -> u64 { + apply(|x| x + 1, 42) +} diff --git a/compiler/tests/snap/full/pi_lambda_missing_annotation/1_lex.txt b/compiler/tests/snap/full/pi_lambda_missing_annotation/1_lex.txt new file mode 100644 index 0000000..ced7ad1 --- /dev/null +++ b/compiler/tests/snap/full/pi_lambda_missing_annotation/1_lex.txt @@ -0,0 +1,45 @@ +Fn +Ident("apply") +LParen +Ident("f") +Colon +Fn +LParen +Ident("_") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +Comma +Ident("x") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +LBrace +Ident("f") +LParen +Ident("x") +RParen +RBrace +Fn +Ident("test") +LParen +RParen +Arrow +Ident("u64") +LBrace +Ident("apply") +LParen +Bar +Ident("x") +Bar +Ident("x") +Plus +Num(1) +Comma +Num(42) +RParen +RBrace diff --git a/compiler/tests/snap/full/pi_lambda_missing_annotation/2_parse.txt b/compiler/tests/snap/full/pi_lambda_missing_annotation/2_parse.txt new file mode 100644 index 0000000..5dad2f0 --- /dev/null +++ b/compiler/tests/snap/full/pi_lambda_missing_annotation/2_parse.txt @@ -0,0 +1,2 @@ +ERROR +in function `test`: expected function body: parsing expression in block: parsing function argument: expected ':' in lambda parameter (type annotations are required): expected Colon, got Bar diff --git a/compiler/tests/snap/full/pi_lambda_type_mismatch/0_input.splic b/compiler/tests/snap/full/pi_lambda_type_mismatch/0_input.splic new file mode 100644 index 0000000..59927da --- /dev/null +++ b/compiler/tests/snap/full/pi_lambda_type_mismatch/0_input.splic @@ -0,0 +1,8 @@ +// ERROR: lambda parameter type doesn't match expected function type +fn apply(f: fn(_: u64) -> u64, x: u64) -> u64 { + f(x) +} + +fn test() -> u64 { + apply(|x: u32| x, 42) +} diff --git a/compiler/tests/snap/full/pi_lambda_type_mismatch/1_lex.txt b/compiler/tests/snap/full/pi_lambda_type_mismatch/1_lex.txt new file mode 100644 index 0000000..b5fda56 --- /dev/null +++ b/compiler/tests/snap/full/pi_lambda_type_mismatch/1_lex.txt @@ -0,0 +1,45 @@ +Fn +Ident("apply") +LParen +Ident("f") +Colon +Fn +LParen +Ident("_") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +Comma +Ident("x") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +LBrace +Ident("f") +LParen +Ident("x") +RParen +RBrace +Fn +Ident("test") +LParen +RParen +Arrow +Ident("u64") +LBrace +Ident("apply") +LParen +Bar +Ident("x") +Colon +Ident("u32") +Bar +Ident("x") +Comma +Num(42) +RParen +RBrace diff --git a/compiler/tests/snap/full/pi_lambda_type_mismatch/2_parse.txt b/compiler/tests/snap/full/pi_lambda_type_mismatch/2_parse.txt new file mode 100644 index 0000000..9650eec --- /dev/null +++ b/compiler/tests/snap/full/pi_lambda_type_mismatch/2_parse.txt @@ -0,0 +1,82 @@ +Program { + functions: [ + Function { + phase: Meta, + name: "apply", + params: [ + Param { + name: "f", + ty: Pi { + params: [ + Param { + name: "_", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Var( + "u64", + ), + }, + }, + Param { + name: "x", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: App { + func: Var( + "f", + ), + args: [ + Var( + "x", + ), + ], + }, + }, + }, + Function { + phase: Meta, + name: "test", + params: [], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: App { + func: Var( + "apply", + ), + args: [ + Lam { + params: [ + Param { + name: "x", + ty: Var( + "u32", + ), + }, + ], + body: Var( + "x", + ), + }, + Lit( + 42, + ), + ], + }, + }, + }, + ], +} diff --git a/compiler/tests/snap/full/pi_lambda_type_mismatch/3_check.txt b/compiler/tests/snap/full/pi_lambda_type_mismatch/3_check.txt new file mode 100644 index 0000000..2109895 --- /dev/null +++ b/compiler/tests/snap/full/pi_lambda_type_mismatch/3_check.txt @@ -0,0 +1,2 @@ +ERROR +in function `test`: in argument 0 of function call: lambda parameter type mismatch: annotation gives a different type than the expected function type diff --git a/compiler/tests/snap/full/pi_nested/0_input.splic b/compiler/tests/snap/full/pi_nested/0_input.splic new file mode 100644 index 0000000..f1b2a04 --- /dev/null +++ b/compiler/tests/snap/full/pi_nested/0_input.splic @@ -0,0 +1,12 @@ +// Nested function types: function that takes a function and applies it twice +fn apply_twice(f: fn(_: u64) -> u64, x: u64) -> u64 { + f(f(x)) +} + +fn inc(x: u64) -> u64 { x + 1 } + +fn test() -> u64 { + apply_twice(inc, 0) +} + +code fn result() -> u64 { $(test()) } diff --git a/compiler/tests/snap/full/pi_nested/1_lex.txt b/compiler/tests/snap/full/pi_nested/1_lex.txt new file mode 100644 index 0000000..5d36ff6 --- /dev/null +++ b/compiler/tests/snap/full/pi_nested/1_lex.txt @@ -0,0 +1,71 @@ +Fn +Ident("apply_twice") +LParen +Ident("f") +Colon +Fn +LParen +Ident("_") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +Comma +Ident("x") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +LBrace +Ident("f") +LParen +Ident("f") +LParen +Ident("x") +RParen +RParen +RBrace +Fn +Ident("inc") +LParen +Ident("x") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +LBrace +Ident("x") +Plus +Num(1) +RBrace +Fn +Ident("test") +LParen +RParen +Arrow +Ident("u64") +LBrace +Ident("apply_twice") +LParen +Ident("inc") +Comma +Num(0) +RParen +RBrace +Code +Fn +Ident("result") +LParen +RParen +Arrow +Ident("u64") +LBrace +DollarLParen +Ident("test") +LParen +RParen +RParen +RBrace diff --git a/compiler/tests/snap/full/pi_nested/2_parse.txt b/compiler/tests/snap/full/pi_nested/2_parse.txt new file mode 100644 index 0000000..89cf92c --- /dev/null +++ b/compiler/tests/snap/full/pi_nested/2_parse.txt @@ -0,0 +1,127 @@ +Program { + functions: [ + Function { + phase: Meta, + name: "apply_twice", + params: [ + Param { + name: "f", + ty: Pi { + params: [ + Param { + name: "_", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Var( + "u64", + ), + }, + }, + Param { + name: "x", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: App { + func: Var( + "f", + ), + args: [ + App { + func: Var( + "f", + ), + args: [ + Var( + "x", + ), + ], + }, + ], + }, + }, + }, + Function { + phase: Meta, + name: "inc", + params: [ + Param { + name: "x", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: App { + func: Add, + args: [ + Var( + "x", + ), + Lit( + 1, + ), + ], + }, + }, + }, + Function { + phase: Meta, + name: "test", + params: [], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: App { + func: Var( + "apply_twice", + ), + args: [ + Var( + "inc", + ), + Lit( + 0, + ), + ], + }, + }, + }, + Function { + phase: Object, + name: "result", + params: [], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: Splice( + App { + func: Var( + "test", + ), + args: [], + }, + ), + }, + }, + ], +} diff --git a/compiler/tests/snap/full/pi_nested/3_check.txt b/compiler/tests/snap/full/pi_nested/3_check.txt new file mode 100644 index 0000000..2f5001f --- /dev/null +++ b/compiler/tests/snap/full/pi_nested/3_check.txt @@ -0,0 +1,16 @@ +fn apply_twice(f@0: fn(_: u64) -> u64, x@1: u64) -> u64 { + f@0(f@0(x@1)) +} + +fn inc(x@0: u64) -> u64 { + @add_u64(x@0, 1_u64) +} + +fn test() -> u64 { + apply_twice(inc, 0_u64) +} + +code fn result() -> u64 { + $(@embed_u64(test())) +} + diff --git a/compiler/tests/snap/full/pi_nested/6_stage.txt b/compiler/tests/snap/full/pi_nested/6_stage.txt new file mode 100644 index 0000000..f4f672a --- /dev/null +++ b/compiler/tests/snap/full/pi_nested/6_stage.txt @@ -0,0 +1,4 @@ +code fn result() -> u64 { + 2_u64 +} + diff --git a/compiler/tests/snap/full/pi_nested_polymorphic/0_input.splic b/compiler/tests/snap/full/pi_nested_polymorphic/0_input.splic new file mode 100644 index 0000000..3ba6fff --- /dev/null +++ b/compiler/tests/snap/full/pi_nested_polymorphic/0_input.splic @@ -0,0 +1,12 @@ +// Polymorphic apply_twice: polymorphic in both the value type and the function +fn apply_twice(A: Type, f: fn(_: A) -> A, x: A) -> A { + f(f(x)) +} + +fn inc(x: u64) -> u64 { x + 1 } + +fn test() -> u64 { + apply_twice(u64, inc, 0) +} + +code fn result() -> u64 { $(test()) } diff --git a/compiler/tests/snap/full/pi_nested_polymorphic/1_lex.txt b/compiler/tests/snap/full/pi_nested_polymorphic/1_lex.txt new file mode 100644 index 0000000..0236883 --- /dev/null +++ b/compiler/tests/snap/full/pi_nested_polymorphic/1_lex.txt @@ -0,0 +1,77 @@ +Fn +Ident("apply_twice") +LParen +Ident("A") +Colon +Ident("Type") +Comma +Ident("f") +Colon +Fn +LParen +Ident("_") +Colon +Ident("A") +RParen +Arrow +Ident("A") +Comma +Ident("x") +Colon +Ident("A") +RParen +Arrow +Ident("A") +LBrace +Ident("f") +LParen +Ident("f") +LParen +Ident("x") +RParen +RParen +RBrace +Fn +Ident("inc") +LParen +Ident("x") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +LBrace +Ident("x") +Plus +Num(1) +RBrace +Fn +Ident("test") +LParen +RParen +Arrow +Ident("u64") +LBrace +Ident("apply_twice") +LParen +Ident("u64") +Comma +Ident("inc") +Comma +Num(0) +RParen +RBrace +Code +Fn +Ident("result") +LParen +RParen +Arrow +Ident("u64") +LBrace +DollarLParen +Ident("test") +LParen +RParen +RParen +RBrace diff --git a/compiler/tests/snap/full/pi_nested_polymorphic/2_parse.txt b/compiler/tests/snap/full/pi_nested_polymorphic/2_parse.txt new file mode 100644 index 0000000..15eccd2 --- /dev/null +++ b/compiler/tests/snap/full/pi_nested_polymorphic/2_parse.txt @@ -0,0 +1,136 @@ +Program { + functions: [ + Function { + phase: Meta, + name: "apply_twice", + params: [ + Param { + name: "A", + ty: Var( + "Type", + ), + }, + Param { + name: "f", + ty: Pi { + params: [ + Param { + name: "_", + ty: Var( + "A", + ), + }, + ], + ret_ty: Var( + "A", + ), + }, + }, + Param { + name: "x", + ty: Var( + "A", + ), + }, + ], + ret_ty: Var( + "A", + ), + body: Block { + stmts: [], + expr: App { + func: Var( + "f", + ), + args: [ + App { + func: Var( + "f", + ), + args: [ + Var( + "x", + ), + ], + }, + ], + }, + }, + }, + Function { + phase: Meta, + name: "inc", + params: [ + Param { + name: "x", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: App { + func: Add, + args: [ + Var( + "x", + ), + Lit( + 1, + ), + ], + }, + }, + }, + Function { + phase: Meta, + name: "test", + params: [], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: App { + func: Var( + "apply_twice", + ), + args: [ + Var( + "u64", + ), + Var( + "inc", + ), + Lit( + 0, + ), + ], + }, + }, + }, + Function { + phase: Object, + name: "result", + params: [], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: Splice( + App { + func: Var( + "test", + ), + args: [], + }, + ), + }, + }, + ], +} diff --git a/compiler/tests/snap/full/pi_nested_polymorphic/3_check.txt b/compiler/tests/snap/full/pi_nested_polymorphic/3_check.txt new file mode 100644 index 0000000..6d94ea5 --- /dev/null +++ b/compiler/tests/snap/full/pi_nested_polymorphic/3_check.txt @@ -0,0 +1,2 @@ +ERROR +in function `apply_twice`: in argument 0 of function call: type mismatch diff --git a/compiler/tests/snap/full/pi_polycompose/0_input.splic b/compiler/tests/snap/full/pi_polycompose/0_input.splic new file mode 100644 index 0000000..0964a51 --- /dev/null +++ b/compiler/tests/snap/full/pi_polycompose/0_input.splic @@ -0,0 +1,13 @@ +// Polymorphic function composition +fn compose(A: Type, B: Type, C: Type, f: fn(_: B) -> C, g: fn(_: A) -> B) -> fn(_: A) -> C { + |x: A| f(g(x)) +} + +fn double(x: u64) -> u64 { x + x } +fn to_u8(x: u64) -> u8 { match x { 0 => 0, 1 => 1, 2 => 2, _ => 3 } } + +fn test() -> u8 { + compose(u64, u64, u8, to_u8, double)(5) +} + +code fn result() -> u8 { $(test()) } diff --git a/compiler/tests/snap/full/pi_polycompose/1_lex.txt b/compiler/tests/snap/full/pi_polycompose/1_lex.txt new file mode 100644 index 0000000..87e85d9 --- /dev/null +++ b/compiler/tests/snap/full/pi_polycompose/1_lex.txt @@ -0,0 +1,141 @@ +Fn +Ident("compose") +LParen +Ident("A") +Colon +Ident("Type") +Comma +Ident("B") +Colon +Ident("Type") +Comma +Ident("C") +Colon +Ident("Type") +Comma +Ident("f") +Colon +Fn +LParen +Ident("_") +Colon +Ident("B") +RParen +Arrow +Ident("C") +Comma +Ident("g") +Colon +Fn +LParen +Ident("_") +Colon +Ident("A") +RParen +Arrow +Ident("B") +RParen +Arrow +Fn +LParen +Ident("_") +Colon +Ident("A") +RParen +Arrow +Ident("C") +LBrace +Bar +Ident("x") +Colon +Ident("A") +Bar +Ident("f") +LParen +Ident("g") +LParen +Ident("x") +RParen +RParen +RBrace +Fn +Ident("double") +LParen +Ident("x") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +LBrace +Ident("x") +Plus +Ident("x") +RBrace +Fn +Ident("to_u8") +LParen +Ident("x") +Colon +Ident("u64") +RParen +Arrow +Ident("u8") +LBrace +Match +Ident("x") +LBrace +Num(0) +DArrow +Num(0) +Comma +Num(1) +DArrow +Num(1) +Comma +Num(2) +DArrow +Num(2) +Comma +Ident("_") +DArrow +Num(3) +RBrace +RBrace +Fn +Ident("test") +LParen +RParen +Arrow +Ident("u8") +LBrace +Ident("compose") +LParen +Ident("u64") +Comma +Ident("u64") +Comma +Ident("u8") +Comma +Ident("to_u8") +Comma +Ident("double") +RParen +LParen +Num(5) +RParen +RBrace +Code +Fn +Ident("result") +LParen +RParen +Arrow +Ident("u8") +LBrace +DollarLParen +Ident("test") +LParen +RParen +RParen +RBrace diff --git a/compiler/tests/snap/full/pi_polycompose/2_parse.txt b/compiler/tests/snap/full/pi_polycompose/2_parse.txt new file mode 100644 index 0000000..1dcc577 --- /dev/null +++ b/compiler/tests/snap/full/pi_polycompose/2_parse.txt @@ -0,0 +1,248 @@ +Program { + functions: [ + Function { + phase: Meta, + name: "compose", + params: [ + Param { + name: "A", + ty: Var( + "Type", + ), + }, + Param { + name: "B", + ty: Var( + "Type", + ), + }, + Param { + name: "C", + ty: Var( + "Type", + ), + }, + Param { + name: "f", + ty: Pi { + params: [ + Param { + name: "_", + ty: Var( + "B", + ), + }, + ], + ret_ty: Var( + "C", + ), + }, + }, + Param { + name: "g", + ty: Pi { + params: [ + Param { + name: "_", + ty: Var( + "A", + ), + }, + ], + ret_ty: Var( + "B", + ), + }, + }, + ], + ret_ty: Pi { + params: [ + Param { + name: "_", + ty: Var( + "A", + ), + }, + ], + ret_ty: Var( + "C", + ), + }, + body: Block { + stmts: [], + expr: Lam { + params: [ + Param { + name: "x", + ty: Var( + "A", + ), + }, + ], + body: App { + func: Var( + "f", + ), + args: [ + App { + func: Var( + "g", + ), + args: [ + Var( + "x", + ), + ], + }, + ], + }, + }, + }, + }, + Function { + phase: Meta, + name: "double", + params: [ + Param { + name: "x", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: App { + func: Add, + args: [ + Var( + "x", + ), + Var( + "x", + ), + ], + }, + }, + }, + Function { + phase: Meta, + name: "to_u8", + params: [ + Param { + name: "x", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Var( + "u8", + ), + body: Block { + stmts: [], + expr: Match { + scrutinee: Var( + "x", + ), + arms: [ + MatchArm { + pat: Lit( + 0, + ), + body: Lit( + 0, + ), + }, + MatchArm { + pat: Lit( + 1, + ), + body: Lit( + 1, + ), + }, + MatchArm { + pat: Lit( + 2, + ), + body: Lit( + 2, + ), + }, + MatchArm { + pat: Name( + "_", + ), + body: Lit( + 3, + ), + }, + ], + }, + }, + }, + Function { + phase: Meta, + name: "test", + params: [], + ret_ty: Var( + "u8", + ), + body: Block { + stmts: [], + expr: App { + func: App { + func: Var( + "compose", + ), + args: [ + Var( + "u64", + ), + Var( + "u64", + ), + Var( + "u8", + ), + Var( + "to_u8", + ), + Var( + "double", + ), + ], + }, + args: [ + Lit( + 5, + ), + ], + }, + }, + }, + Function { + phase: Object, + name: "result", + params: [], + ret_ty: Var( + "u8", + ), + body: Block { + stmts: [], + expr: Splice( + App { + func: Var( + "test", + ), + args: [], + }, + ), + }, + }, + ], +} diff --git a/compiler/tests/snap/full/pi_polycompose/3_check.txt b/compiler/tests/snap/full/pi_polycompose/3_check.txt new file mode 100644 index 0000000..0ee33d6 --- /dev/null +++ b/compiler/tests/snap/full/pi_polycompose/3_check.txt @@ -0,0 +1,25 @@ +fn compose(A@0: Type, B@1: Type, C@2: Type, f@3: fn(_: B@1) -> C@2, g@4: fn(_: A@0) -> B@1) -> fn(_: A@0) -> C@2 { + |x@5: A@0| f@3(g@4(x@5)) +} + +fn double(x@0: u64) -> u64 { + @add_u64(x@0, x@0) +} + +fn to_u8(x@0: u64) -> u8 { + match x@0 { + 0 => 0_u8, + 1 => 1_u8, + 2 => 2_u8, + _ => 3_u8, + } +} + +fn test() -> u8 { + compose(u64, u64, u8, to_u8, double)(5_u64) +} + +code fn result() -> u8 { + $(@embed_u8(test())) +} + diff --git a/compiler/tests/snap/full/pi_polycompose/6_stage.txt b/compiler/tests/snap/full/pi_polycompose/6_stage.txt new file mode 100644 index 0000000..84580a6 --- /dev/null +++ b/compiler/tests/snap/full/pi_polycompose/6_stage.txt @@ -0,0 +1,4 @@ +code fn result() -> u8 { + 3_u8 +} + diff --git a/compiler/tests/snap/full/pi_polymorphic_id/0_input.splic b/compiler/tests/snap/full/pi_polymorphic_id/0_input.splic new file mode 100644 index 0000000..352f9e9 --- /dev/null +++ b/compiler/tests/snap/full/pi_polymorphic_id/0_input.splic @@ -0,0 +1,8 @@ +// Polymorphic identity function using dependent types +fn id(A: Type, x: A) -> A { x } + +fn test_u64() -> u64 { id(u64, 42) } +fn test_u8() -> u8 { id(u8, 7) } + +code fn result_u64() -> u64 { $(test_u64()) } +code fn result_u8() -> u8 { $(test_u8()) } diff --git a/compiler/tests/snap/full/pi_polymorphic_id/1_lex.txt b/compiler/tests/snap/full/pi_polymorphic_id/1_lex.txt new file mode 100644 index 0000000..f19a724 --- /dev/null +++ b/compiler/tests/snap/full/pi_polymorphic_id/1_lex.txt @@ -0,0 +1,72 @@ +Fn +Ident("id") +LParen +Ident("A") +Colon +Ident("Type") +Comma +Ident("x") +Colon +Ident("A") +RParen +Arrow +Ident("A") +LBrace +Ident("x") +RBrace +Fn +Ident("test_u64") +LParen +RParen +Arrow +Ident("u64") +LBrace +Ident("id") +LParen +Ident("u64") +Comma +Num(42) +RParen +RBrace +Fn +Ident("test_u8") +LParen +RParen +Arrow +Ident("u8") +LBrace +Ident("id") +LParen +Ident("u8") +Comma +Num(7) +RParen +RBrace +Code +Fn +Ident("result_u64") +LParen +RParen +Arrow +Ident("u64") +LBrace +DollarLParen +Ident("test_u64") +LParen +RParen +RParen +RBrace +Code +Fn +Ident("result_u8") +LParen +RParen +Arrow +Ident("u8") +LBrace +DollarLParen +Ident("test_u8") +LParen +RParen +RParen +RBrace diff --git a/compiler/tests/snap/full/pi_polymorphic_id/2_parse.txt b/compiler/tests/snap/full/pi_polymorphic_id/2_parse.txt new file mode 100644 index 0000000..08afa73 --- /dev/null +++ b/compiler/tests/snap/full/pi_polymorphic_id/2_parse.txt @@ -0,0 +1,117 @@ +Program { + functions: [ + Function { + phase: Meta, + name: "id", + params: [ + Param { + name: "A", + ty: Var( + "Type", + ), + }, + Param { + name: "x", + ty: Var( + "A", + ), + }, + ], + ret_ty: Var( + "A", + ), + body: Block { + stmts: [], + expr: Var( + "x", + ), + }, + }, + Function { + phase: Meta, + name: "test_u64", + params: [], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: App { + func: Var( + "id", + ), + args: [ + Var( + "u64", + ), + Lit( + 42, + ), + ], + }, + }, + }, + Function { + phase: Meta, + name: "test_u8", + params: [], + ret_ty: Var( + "u8", + ), + body: Block { + stmts: [], + expr: App { + func: Var( + "id", + ), + args: [ + Var( + "u8", + ), + Lit( + 7, + ), + ], + }, + }, + }, + Function { + phase: Object, + name: "result_u64", + params: [], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: Splice( + App { + func: Var( + "test_u64", + ), + args: [], + }, + ), + }, + }, + Function { + phase: Object, + name: "result_u8", + params: [], + ret_ty: Var( + "u8", + ), + body: Block { + stmts: [], + expr: Splice( + App { + func: Var( + "test_u8", + ), + args: [], + }, + ), + }, + }, + ], +} diff --git a/compiler/tests/snap/full/pi_polymorphic_id/3_check.txt b/compiler/tests/snap/full/pi_polymorphic_id/3_check.txt new file mode 100644 index 0000000..957a747 --- /dev/null +++ b/compiler/tests/snap/full/pi_polymorphic_id/3_check.txt @@ -0,0 +1,20 @@ +fn id(A@0: Type, x@1: A@0) -> A@0 { + x@1 +} + +fn test_u64() -> u64 { + id(u64, 42_u64) +} + +fn test_u8() -> u8 { + id(u8, 7_u8) +} + +code fn result_u64() -> u64 { + $(@embed_u64(test_u64())) +} + +code fn result_u8() -> u8 { + $(@embed_u8(test_u8())) +} + diff --git a/compiler/tests/snap/full/pi_polymorphic_id/6_stage.txt b/compiler/tests/snap/full/pi_polymorphic_id/6_stage.txt new file mode 100644 index 0000000..83bfdca --- /dev/null +++ b/compiler/tests/snap/full/pi_polymorphic_id/6_stage.txt @@ -0,0 +1,8 @@ +code fn result_u64() -> u64 { + 42_u64 +} + +code fn result_u8() -> u8 { + 7_u8 +} + diff --git a/compiler/tests/snap/full/pi_repeat/0_input.splic b/compiler/tests/snap/full/pi_repeat/0_input.splic new file mode 100644 index 0000000..6b5d2d6 --- /dev/null +++ b/compiler/tests/snap/full/pi_repeat/0_input.splic @@ -0,0 +1,11 @@ +// The motivating example: pass a code-generating lambda, unroll at compile time +fn repeat(f: fn(_: [[u64]]) -> [[u64]], n: u64, x: [[u64]]) -> [[u64]] { + match n { + 0 => x, + n => repeat(f, n - 1, f(x)), + } +} + +code fn square_twice(x: u64) -> u64 { + $(repeat(|y: [[u64]]| #($(y) * $(y)), 2, #(x))) +} diff --git a/compiler/tests/snap/full/pi_repeat/1_lex.txt b/compiler/tests/snap/full/pi_repeat/1_lex.txt new file mode 100644 index 0000000..af4e85a --- /dev/null +++ b/compiler/tests/snap/full/pi_repeat/1_lex.txt @@ -0,0 +1,97 @@ +Fn +Ident("repeat") +LParen +Ident("f") +Colon +Fn +LParen +Ident("_") +Colon +DoubleLBracket +Ident("u64") +DoubleRBracket +RParen +Arrow +DoubleLBracket +Ident("u64") +DoubleRBracket +Comma +Ident("n") +Colon +Ident("u64") +Comma +Ident("x") +Colon +DoubleLBracket +Ident("u64") +DoubleRBracket +RParen +Arrow +DoubleLBracket +Ident("u64") +DoubleRBracket +LBrace +Match +Ident("n") +LBrace +Num(0) +DArrow +Ident("x") +Comma +Ident("n") +DArrow +Ident("repeat") +LParen +Ident("f") +Comma +Ident("n") +Minus +Num(1) +Comma +Ident("f") +LParen +Ident("x") +RParen +RParen +Comma +RBrace +RBrace +Code +Fn +Ident("square_twice") +LParen +Ident("x") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +LBrace +DollarLParen +Ident("repeat") +LParen +Bar +Ident("y") +Colon +DoubleLBracket +Ident("u64") +DoubleRBracket +Bar +HashLParen +DollarLParen +Ident("y") +RParen +Star +DollarLParen +Ident("y") +RParen +RParen +Comma +Num(2) +Comma +HashLParen +Ident("x") +RParen +RParen +RParen +RBrace diff --git a/compiler/tests/snap/full/pi_repeat/2_parse.txt b/compiler/tests/snap/full/pi_repeat/2_parse.txt new file mode 100644 index 0000000..5d47607 --- /dev/null +++ b/compiler/tests/snap/full/pi_repeat/2_parse.txt @@ -0,0 +1,167 @@ +Program { + functions: [ + Function { + phase: Meta, + name: "repeat", + params: [ + Param { + name: "f", + ty: Pi { + params: [ + Param { + name: "_", + ty: Lift( + Var( + "u64", + ), + ), + }, + ], + ret_ty: Lift( + Var( + "u64", + ), + ), + }, + }, + Param { + name: "n", + ty: Var( + "u64", + ), + }, + Param { + name: "x", + ty: Lift( + Var( + "u64", + ), + ), + }, + ], + ret_ty: Lift( + Var( + "u64", + ), + ), + body: Block { + stmts: [], + expr: Match { + scrutinee: Var( + "n", + ), + arms: [ + MatchArm { + pat: Lit( + 0, + ), + body: Var( + "x", + ), + }, + MatchArm { + pat: Name( + "n", + ), + body: App { + func: Var( + "repeat", + ), + args: [ + Var( + "f", + ), + App { + func: Sub, + args: [ + Var( + "n", + ), + Lit( + 1, + ), + ], + }, + App { + func: Var( + "f", + ), + args: [ + Var( + "x", + ), + ], + }, + ], + }, + }, + ], + }, + }, + }, + Function { + phase: Object, + name: "square_twice", + params: [ + Param { + name: "x", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: Splice( + App { + func: Var( + "repeat", + ), + args: [ + Lam { + params: [ + Param { + name: "y", + ty: Lift( + Var( + "u64", + ), + ), + }, + ], + body: Quote( + App { + func: Mul, + args: [ + Splice( + Var( + "y", + ), + ), + Splice( + Var( + "y", + ), + ), + ], + }, + ), + }, + Lit( + 2, + ), + Quote( + Var( + "x", + ), + ), + ], + }, + ), + }, + }, + ], +} diff --git a/compiler/tests/snap/full/pi_repeat/3_check.txt b/compiler/tests/snap/full/pi_repeat/3_check.txt new file mode 100644 index 0000000..6e926ae --- /dev/null +++ b/compiler/tests/snap/full/pi_repeat/3_check.txt @@ -0,0 +1,11 @@ +fn repeat(f@0: fn(_: [[u64]]) -> [[u64]], n@1: u64, x@2: [[u64]]) -> [[u64]] { + match n@1 { + 0 => x@2, + n@3 => repeat(f@0, @sub_u64(n@3, 1_u64), f@0(x@2)), + } +} + +code fn square_twice(x@0: u64) -> u64 { + $(repeat(|y@1: [[u64]]| #(@mul_u64($(y@1), $(y@1))), 2_u64, #(x@0))) +} + diff --git a/compiler/tests/snap/full/pi_repeat/6_stage.txt b/compiler/tests/snap/full/pi_repeat/6_stage.txt new file mode 100644 index 0000000..93dcff6 --- /dev/null +++ b/compiler/tests/snap/full/pi_repeat/6_stage.txt @@ -0,0 +1,4 @@ +code fn square_twice(x@0: u64) -> u64 { + @mul_u64(@mul_u64(x@0, x@0), @mul_u64(x@0, x@0)) +} + diff --git a/compiler/tests/snap/full/pi_staging_hof/0_input.splic b/compiler/tests/snap/full/pi_staging_hof/0_input.splic new file mode 100644 index 0000000..001f69a --- /dev/null +++ b/compiler/tests/snap/full/pi_staging_hof/0_input.splic @@ -0,0 +1,8 @@ +// Higher-order staging: meta function that transforms code via a lambda +fn map_code(f: fn(_: [[u64]]) -> [[u64]], x: [[u64]]) -> [[u64]] { + f(x) +} + +code fn double(x: u64) -> u64 { + $(map_code(|y: [[u64]]| #($(y) + $(y)), #(x))) +} diff --git a/compiler/tests/snap/full/pi_staging_hof/1_lex.txt b/compiler/tests/snap/full/pi_staging_hof/1_lex.txt new file mode 100644 index 0000000..4257536 --- /dev/null +++ b/compiler/tests/snap/full/pi_staging_hof/1_lex.txt @@ -0,0 +1,71 @@ +Fn +Ident("map_code") +LParen +Ident("f") +Colon +Fn +LParen +Ident("_") +Colon +DoubleLBracket +Ident("u64") +DoubleRBracket +RParen +Arrow +DoubleLBracket +Ident("u64") +DoubleRBracket +Comma +Ident("x") +Colon +DoubleLBracket +Ident("u64") +DoubleRBracket +RParen +Arrow +DoubleLBracket +Ident("u64") +DoubleRBracket +LBrace +Ident("f") +LParen +Ident("x") +RParen +RBrace +Code +Fn +Ident("double") +LParen +Ident("x") +Colon +Ident("u64") +RParen +Arrow +Ident("u64") +LBrace +DollarLParen +Ident("map_code") +LParen +Bar +Ident("y") +Colon +DoubleLBracket +Ident("u64") +DoubleRBracket +Bar +HashLParen +DollarLParen +Ident("y") +RParen +Plus +DollarLParen +Ident("y") +RParen +RParen +Comma +HashLParen +Ident("x") +RParen +RParen +RParen +RBrace diff --git a/compiler/tests/snap/full/pi_staging_hof/2_parse.txt b/compiler/tests/snap/full/pi_staging_hof/2_parse.txt new file mode 100644 index 0000000..301f150 --- /dev/null +++ b/compiler/tests/snap/full/pi_staging_hof/2_parse.txt @@ -0,0 +1,117 @@ +Program { + functions: [ + Function { + phase: Meta, + name: "map_code", + params: [ + Param { + name: "f", + ty: Pi { + params: [ + Param { + name: "_", + ty: Lift( + Var( + "u64", + ), + ), + }, + ], + ret_ty: Lift( + Var( + "u64", + ), + ), + }, + }, + Param { + name: "x", + ty: Lift( + Var( + "u64", + ), + ), + }, + ], + ret_ty: Lift( + Var( + "u64", + ), + ), + body: Block { + stmts: [], + expr: App { + func: Var( + "f", + ), + args: [ + Var( + "x", + ), + ], + }, + }, + }, + Function { + phase: Object, + name: "double", + params: [ + Param { + name: "x", + ty: Var( + "u64", + ), + }, + ], + ret_ty: Var( + "u64", + ), + body: Block { + stmts: [], + expr: Splice( + App { + func: Var( + "map_code", + ), + args: [ + Lam { + params: [ + Param { + name: "y", + ty: Lift( + Var( + "u64", + ), + ), + }, + ], + body: Quote( + App { + func: Add, + args: [ + Splice( + Var( + "y", + ), + ), + Splice( + Var( + "y", + ), + ), + ], + }, + ), + }, + Quote( + Var( + "x", + ), + ), + ], + }, + ), + }, + }, + ], +} diff --git a/compiler/tests/snap/full/pi_staging_hof/3_check.txt b/compiler/tests/snap/full/pi_staging_hof/3_check.txt new file mode 100644 index 0000000..a389070 --- /dev/null +++ b/compiler/tests/snap/full/pi_staging_hof/3_check.txt @@ -0,0 +1,8 @@ +fn map_code(f@0: fn(_: [[u64]]) -> [[u64]], x@1: [[u64]]) -> [[u64]] { + f@0(x@1) +} + +code fn double(x@0: u64) -> u64 { + $(map_code(|y@1: [[u64]]| #(@add_u64($(y@1), $(y@1))), #(x@0))) +} + diff --git a/compiler/tests/snap/full/pi_staging_hof/6_stage.txt b/compiler/tests/snap/full/pi_staging_hof/6_stage.txt new file mode 100644 index 0000000..3a65226 --- /dev/null +++ b/compiler/tests/snap/full/pi_staging_hof/6_stage.txt @@ -0,0 +1,4 @@ +code fn double(x@0: u64) -> u64 { + @add_u64(x@0, x@0) +} + diff --git a/compiler/tests/snap/full/power/2_parse.txt b/compiler/tests/snap/full/power/2_parse.txt index 489213b..f1c5410 100644 --- a/compiler/tests/snap/full/power/2_parse.txt +++ b/compiler/tests/snap/full/power/2_parse.txt @@ -72,7 +72,9 @@ Program { ty: None, expr: Splice( App { - func: "power", + func: Var( + "power", + ), args: [ Var( "x", @@ -196,7 +198,9 @@ Program { stmts: [], expr: Splice( App { - func: "power", + func: Var( + "power", + ), args: [ Quote( Var( @@ -229,7 +233,9 @@ Program { stmts: [], expr: Splice( App { - func: "power", + func: Var( + "power", + ), args: [ Quote( Var( diff --git a/compiler/tests/snap/full/power_acc/2_parse.txt b/compiler/tests/snap/full/power_acc/2_parse.txt index 0b48d10..1e1178c 100644 --- a/compiler/tests/snap/full/power_acc/2_parse.txt +++ b/compiler/tests/snap/full/power_acc/2_parse.txt @@ -109,7 +109,9 @@ Program { 0, ), body: App { - func: "power_acc", + func: Var( + "power_acc", + ), args: [ Quote( App { @@ -150,7 +152,9 @@ Program { stmts: [], expr: Splice( App { - func: "power_acc", + func: Var( + "power_acc", + ), args: [ Quote( App { @@ -311,7 +315,9 @@ Program { 0, ), body: App { - func: "power_acc_1", + func: Var( + "power_acc_1", + ), args: [ Quote( App { @@ -345,7 +351,9 @@ Program { 1, ), body: App { - func: "power_acc", + func: Var( + "power_acc", + ), args: [ Quote( App { @@ -417,7 +425,9 @@ Program { body: Block { stmts: [], expr: App { - func: "power_acc_1", + func: Var( + "power_acc_1", + ), args: [ Var( "x", @@ -447,7 +457,9 @@ Program { stmts: [], expr: Splice( App { - func: "power", + func: Var( + "power", + ), args: [ Quote( Var( @@ -480,7 +492,9 @@ Program { stmts: [], expr: Splice( App { - func: "power", + func: Var( + "power", + ), args: [ Quote( Var( @@ -513,7 +527,9 @@ Program { stmts: [], expr: Splice( App { - func: "power", + func: Var( + "power", + ), args: [ Quote( Block { @@ -569,7 +585,9 @@ Program { stmts: [], expr: Splice( App { - func: "power", + func: Var( + "power", + ), args: [ Quote( Var( @@ -602,7 +620,9 @@ Program { stmts: [], expr: Splice( App { - func: "power", + func: Var( + "power", + ), args: [ Quote( Var( @@ -635,7 +655,9 @@ Program { stmts: [], expr: Splice( App { - func: "power", + func: Var( + "power", + ), args: [ Quote( Var( diff --git a/compiler/tests/snap/full/power_simple/2_parse.txt b/compiler/tests/snap/full/power_simple/2_parse.txt index 5704138..ae2e728 100644 --- a/compiler/tests/snap/full/power_simple/2_parse.txt +++ b/compiler/tests/snap/full/power_simple/2_parse.txt @@ -61,7 +61,9 @@ Program { args: [ Splice( App { - func: "power", + func: Var( + "power", + ), args: [ Var( "x", @@ -112,7 +114,9 @@ Program { stmts: [], expr: Splice( App { - func: "power", + func: Var( + "power", + ), args: [ Quote( Var( @@ -145,7 +149,9 @@ Program { stmts: [], expr: Splice( App { - func: "power", + func: Var( + "power", + ), args: [ Quote( Var( diff --git a/compiler/tests/snap/full/splice_meta_int/2_parse.txt b/compiler/tests/snap/full/splice_meta_int/2_parse.txt index c9af4d8..8f5f797 100644 --- a/compiler/tests/snap/full/splice_meta_int/2_parse.txt +++ b/compiler/tests/snap/full/splice_meta_int/2_parse.txt @@ -25,7 +25,9 @@ Program { stmts: [], expr: Splice( App { - func: "val", + func: Var( + "val", + ), args: [], }, ), diff --git a/compiler/tests/snap/full/staging/2_parse.txt b/compiler/tests/snap/full/staging/2_parse.txt index cdb2845..70fe89e 100644 --- a/compiler/tests/snap/full/staging/2_parse.txt +++ b/compiler/tests/snap/full/staging/2_parse.txt @@ -29,7 +29,9 @@ Program { stmts: [], expr: Splice( App { - func: "k", + func: Var( + "k", + ), args: [], }, ), diff --git a/compiler/tests/snap/full/sum_n/2_parse.txt b/compiler/tests/snap/full/sum_n/2_parse.txt index 0fa4992..5d10b31 100644 --- a/compiler/tests/snap/full/sum_n/2_parse.txt +++ b/compiler/tests/snap/full/sum_n/2_parse.txt @@ -37,7 +37,9 @@ Program { func: Add, args: [ App { - func: "sum_n", + func: Var( + "sum_n", + ), args: [ App { func: Sub, @@ -73,7 +75,9 @@ Program { stmts: [], expr: Splice( App { - func: "sum_n", + func: Var( + "sum_n", + ), args: [ Lit( 5, diff --git a/compiler/tests/snap/stage_error/add_overflow_u32/2_parse.txt b/compiler/tests/snap/stage_error/add_overflow_u32/2_parse.txt index eac0cb9..d3df84e 100644 --- a/compiler/tests/snap/stage_error/add_overflow_u32/2_parse.txt +++ b/compiler/tests/snap/stage_error/add_overflow_u32/2_parse.txt @@ -33,7 +33,9 @@ Program { stmts: [], expr: Splice( App { - func: "f", + func: Var( + "f", + ), args: [], }, ), diff --git a/compiler/tests/snap/stage_error/add_overflow_u8/2_parse.txt b/compiler/tests/snap/stage_error/add_overflow_u8/2_parse.txt index 5b4df65..803f16a 100644 --- a/compiler/tests/snap/stage_error/add_overflow_u8/2_parse.txt +++ b/compiler/tests/snap/stage_error/add_overflow_u8/2_parse.txt @@ -33,7 +33,9 @@ Program { stmts: [], expr: Splice( App { - func: "f", + func: Var( + "f", + ), args: [], }, ), diff --git a/compiler/tests/snap/stage_error/mul_overflow_u8/2_parse.txt b/compiler/tests/snap/stage_error/mul_overflow_u8/2_parse.txt index 3e48c02..4d4f5d0 100644 --- a/compiler/tests/snap/stage_error/mul_overflow_u8/2_parse.txt +++ b/compiler/tests/snap/stage_error/mul_overflow_u8/2_parse.txt @@ -33,7 +33,9 @@ Program { stmts: [], expr: Splice( App { - func: "f", + func: Var( + "f", + ), args: [], }, ), diff --git a/compiler/tests/snap/stage_error/sub_underflow_u8/2_parse.txt b/compiler/tests/snap/stage_error/sub_underflow_u8/2_parse.txt index cd548f1..e62be59 100644 --- a/compiler/tests/snap/stage_error/sub_underflow_u8/2_parse.txt +++ b/compiler/tests/snap/stage_error/sub_underflow_u8/2_parse.txt @@ -33,7 +33,9 @@ Program { stmts: [], expr: Splice( App { - func: "f", + func: Var( + "f", + ), args: [], }, ), diff --git a/docs/SYNTAX.md b/docs/SYNTAX.md index 6094085..e85dc51 100644 --- a/docs/SYNTAX.md +++ b/docs/SYNTAX.md @@ -22,7 +22,7 @@ x + y // This is also a comment | Keyword | Description | |-----------|-------------| -| `fn` | Function definition | +| `fn` | Function definition or function type | | `code` | Object-level marker | | `let` | Variable binding | | `match` | Pattern matching | @@ -45,6 +45,41 @@ x + y // This is also a comment Identifiers matching `u[0-9]+` are reserved for primitive types. +## Function Types + +Function types use the `fn` keyword with parenthesized parameters: + +``` +fn(_: u64) -> u64 // non-dependent function type (wildcard name required) +fn(x: u64) -> u64 // dependent: return type may mention x +fn(A: Type, x: A) -> A // polymorphic: type parameter used in value positions +fn(_: fn(_: u64) -> u64) -> u64 // higher-order: function taking a function +``` + +Function types are right-associative: `fn(_: A) -> fn(_: B) -> C` means `fn(_: A) -> (fn(_: B) -> C)`. + +Multi-parameter function types desugar to nested single-parameter types: + +``` +fn(A: Type, x: A) -> A ≡ fn(A: Type) -> fn(x: A) -> A +``` + +Function types are meta-level only — they inhabit `Type`, not `VmType`. + +## Lambda Expressions + +Lambdas use Rust's closure syntax with mandatory type annotations: + +``` +|x: u64| x + 1 // single parameter +|x: u64, y: u64| x + y // multi-parameter (desugars to nested lambdas) +|f: fn(_: u64) -> u64, x: u64| f(x) // higher-order +``` + +Type annotations on lambda parameters are required. This makes lambdas inferable — the typechecker can synthesise the full function type from the annotations and the body. + +Lambdas are meta-level only — they cannot appear in object-level (`code fn`) bodies. + ## Operators Lowest to highest, left-associative unless noted: @@ -58,6 +93,8 @@ Lowest to highest, left-associative unless noted: | 5 | `*` `/` | | 6 | `!` (unary) | +Note: `|` as bitwise OR is distinguished from `|` as lambda delimiter by position: a leading `|` in atom position starts a lambda; `|` after an expression is bitwise OR. + Note: The comparison operators are provisional. See [bs/comparison_operators.md](bs/comparison_operators.md) for discussion. ## Grammar (EBNF-like) @@ -87,7 +124,9 @@ expr ::= literal | expr "(" expr ("," expr)* ")" -- application | expr binary_op expr | unary_op expr - | "#(" expr ")" -- quotation + | fn_type -- function type + | lambda -- lambda expression + | "#(" expr ")" -- quotation | "#{" stmt* expr "}" -- block quotation | "$(" expr ")" -- splice | "${" stmt* expr "}" -- block splice @@ -95,6 +134,12 @@ expr ::= literal | "match" expr "{" match_arm* "}" | block +fn_type ::= "fn" "(" fn_params ")" "->" expr +fn_params ::= (fn_param ("," fn_param)*)? +fn_param ::= identifier ":" expr -- name required; use "_" for non-dependent + +lambda ::= "|" param ("," param)* "|" expr + binary_op ::= "+" | "-" | "*" | "/" | "==" | "!=" | "<" | ">" | "<=" | ">=" | "&" | "|" unary_op ::= "!" diff --git a/docs/bs/README.md b/docs/bs/README.md index 4a27f13..2a67da1 100644 --- a/docs/bs/README.md +++ b/docs/bs/README.md @@ -12,6 +12,7 @@ The folder name, `bs`, stands for brainstorming. Obviously. - [functional_goto.md](functional_goto.md) — Control flow via SSA-style basic blocks with goto - [comparison_operators.md](comparison_operators.md) — Boolean vs propositional comparisons - [tuples_and_inference.md](tuples_and_inference.md) — Tuple syntax and type inference +- [pi_types.md](pi_types.md) — Dependent function types (Pi) and lambdas at the meta level ## Compiler Internals diff --git a/docs/bs/pi_types.md b/docs/bs/pi_types.md new file mode 100644 index 0000000..96fd2d1 --- /dev/null +++ b/docs/bs/pi_types.md @@ -0,0 +1,264 @@ +# Pi Types: Dependent Function Types at the Meta Level + +This document records the design decisions for adding dependent function types (Pi types) and lambda abstractions to Splic at the meta level. + +## Motivation + +The current prototype has no first-class functions. All functions are top-level named definitions; there is no way to pass a function as an argument or return one from another function. This blocks key use cases: + +**Higher-order meta functions.** The `repeat` combinator from `prototype_next.md` requires passing a code-generating function: + +```splic +fn repeat(f: fn([[u64]]) -> [[u64]], n: u64, x: [[u64]]) -> [[u64]] { + match n { + 0 => x, + n => repeat(f, n - 1, f(x)), + } +} + +code fn square_twice(x: u64) -> u64 { + $(repeat(|y: [[u64]]| #($(y) * $(y)), 2, #(x))) +} +``` + +**Polymorphic functions.** Dependent function types let parameters appear in subsequent types: + +```splic +fn id(A: Type, x: A) -> A { x } +``` + +Here `A` is a type passed at compile time, and the return type depends on it. + +**Type-level computation.** With Pi types, function types are first-class terms in the meta universe, enabling functions that compute types. + +## Syntax + +### Function types + +Dependent function types use the `fn` keyword — the same keyword used for definitions: + +``` +fn(x: A) -> B // dependent: B may mention x +fn(_: A) -> B // non-dependent: wildcard name required +``` + +Right-associative: `fn(_: A) -> fn(_: B) -> C` means `fn(_: A) -> (fn(_: B) -> C)`. + +Multi-parameter function types are **not** desugared to nested Pi — the arity is preserved to enable proper arity checking at call sites: + +``` +fn(x: A, y: B) -> C -- two-argument function, not sugar for nested Pi +``` + +**Rationale.** Using `fn` for types mirrors its use for definitions — in Splic, `fn` introduces anything function-shaped. The parenthesized parameter syntax `fn(x: A)` is visually distinct from a definition `fn name(x: A)` (the presence of a name between `fn` and `(` distinguishes them). The `(x: A) -> B` Agda/Lean convention was considered but `fn(x: A) -> B` is more Rust-flavored. + +### Lambdas + +Lambda expressions use Rust's closure syntax: + +``` +|x: A| body // type annotation required +|x: A, y: B| body // multi-parameter (desugars to nested lambdas) +``` + +Type annotations on lambda parameters are **mandatory**. This makes lambdas inferable — the typechecker can construct the full Pi type from the annotation and the inferred body type, without needing an expected type pushed down from context. This is a deliberate simplification for the prototype; unannotated `|x| body` syntax may be added later when check-mode lambdas are needed. + +**Rationale.** The `|...|` syntax is familiar to Rust users. It reuses the existing `|` token. Disambiguation with bitwise OR is positional: `|` at the start of an atom is a lambda; `|` after an expression is bitwise OR. + +### Scope + +Pi types and lambdas are **meta-level only**. Object-level functions remain top-level `code fn` definitions. A lambda cannot appear in object-level code, and `fn(_: A) -> B` cannot appear as an object-level type. This matches the 2LTT philosophy: the meta level is a rich functional language; the object level is a simple low-level language. + +## Typing Rules + +Pi types inhabit the meta universe (`Type`). The formation, introduction, and elimination rules: + +### Formation (Pi) + +``` +Γ ⊢ A : Type Γ, x : A ⊢ B : Type +────────────────────────────────────── + Γ ⊢ fn(x: A) -> B : Type +``` + +Both `A` and `B` must be types. The parameter `x` is in scope in `B` (dependent case). For non-dependent arrows, `x` does not appear free in `B`. + +### Introduction (Lambda) + +Lambdas are **inferable** because type annotations on parameters are mandatory: + +``` +Γ ⊢ A : Type Γ, x : A ⊢ body ⇒ B +───────────────────────────────────────── + Γ ⊢ |x: A| body ⇒ fn(x: A) -> B +``` + +The parameter type `A` comes from the annotation; the body type `B` is inferred in the extended context. The synthesised type is the Pi type `fn(x: A) -> B`. + +### Elimination (Application) + +Application is inferable when the function is inferable: + +``` +Γ ⊢ f ⇒ fn(x: A) -> B Γ ⊢ arg ⇐ A +───────────────────────────────────────── + Γ ⊢ f(arg) ⇒ B[arg/x] +``` + +The return type `B[arg/x]` is the body type with the argument substituted for the parameter. For non-dependent functions this is just `B`. + +Multi-argument calls desugar to curried application: `f(a, b)` = `f(a)(b)`. + +## Core IR Design + +### New Term variants + +```rust +Pi { param_name: &'a str, param_ty: &'a Term<'a>, body_ty: &'a Term<'a> } +Lam { param_name: &'a str, param_ty: &'a Term<'a>, body: &'a Term<'a> } +FunApp { func: &'a Term<'a>, arg: &'a Term<'a> } +Global(Name<'a>) +PrimApp { prim: Prim, args: &'a [&'a Term<'a>] } +``` + +### Refactoring App/Head + +The current `App { head: Head, args }` where `Head` is `Global(Name) | Prim(Prim)` is replaced by: + +- **`Global(Name)`** — a term representing a reference to a top-level function. Now a first-class term rather than just an application head. +- **`FunApp { func, arg }`** — single-argument curried application. Used for both global and local function calls. Multi-arg calls `foo(a, b)` elaborate to `FunApp(FunApp(Global("foo"), a), b)`. +- **`PrimApp { prim, args }`** — primitive operation application. Kept separate because prims carry resolved `IntType` and are always fully applied. Eventually prims will become regular typed symbols, but the typechecker isn't ready for that yet. + +**`FunSig` is preserved** as a convenience structure in the globals table. It stores the flat parameter list and return type for efficient lookup. A `FunSig::to_pi_type(arena)` method constructs the corresponding nested Pi type when needed (e.g., for `type_of(Global(name))`). + +### Substitution + +Dependent return types require substitution: `B[arg/x]`. Since the core IR uses De Bruijn levels, substitution replaces `Var(lvl)` with the argument term. Levels do not shift, making the implementation straightforward: + +```rust +fn subst<'a>(arena: &'a Bump, term: &'a Term<'a>, lvl: Lvl, replacement: &'a Term<'a>) -> &'a Term<'a> +``` + +### Alpha-equivalence + +The current `PartialEq` on `Term` compares structurally, including `param_name` fields. Two Pi types that differ only in parameter names (`fn(x: A) -> B` vs `fn(y: A) -> B`) should be equal. A dedicated `alpha_eq` function ignores names and compares only structure (De Bruijn levels handle binding correctly). + +## Evaluator Design + +### Closures + +A new `MetaVal` variant captures the environment at lambda creation: + +```rust +VClosure { + param_name: &str, + body: &Term, + env: Vec, + obj_next: Lvl, +} +``` + +This follows the substitution-based approach already in use. Application extends the captured env with the argument value and evaluates the body. + +### Global function references + +When `eval_meta` encounters `Global(name)`, it constructs a closure from the global's body and parameters. When applied via `FunApp`, this closure behaves identically to a lambda — the argument extends the env and the body is evaluated. + +For multi-parameter globals, partial application produces a closure that awaits the remaining arguments. This falls out naturally from curried `FunApp` chains. + +### Pi types in evaluation + +`Pi` terms are type-level and never appear in evaluation position (the typechecker ensures this). They are unreachable in `eval_meta`. + +## Staging Interaction + +### Closures cannot be quoted + +Meta-level closures (`VClosure`) cannot appear in object-level code. The type system prevents this: + +- Pi types inhabit `Type` (meta universe), not `VmType` (object universe) +- Therefore `[[fn(A) -> B]]` is ill-formed — you cannot lift a function type +- Lambdas have Pi types, so they cannot have lifted types, so they cannot be quoted + +This is the correct behavior: closures are compile-time values that are fully evaluated during staging. + +### Code-generating lambdas + +The main staging use case is lambdas that *produce* code: + +```splic +fn repeat(f: fn([[u64]]) -> [[u64]], n: u64, x: [[u64]]) -> [[u64]] { + match n { + 0 => x, + n => repeat(f, n - 1, f(x)), + } +} + +code fn square_twice(x: u64) -> u64 { + $(repeat(|y: [[u64]]| #($(y) * $(y)), 2, #(x))) +} +``` + +Here `f` has type `fn([[u64]]) -> [[u64]]` — it takes object code and returns object code. The lambda `|y| #($(y) * $(y))` is a meta-level function that generates object-level multiplication code. After staging, all meta computation (including the lambda and the `repeat` recursion) is erased: + +```splic +code fn square_twice(x: u64) -> u64 { + (x * x) * (x * x) +} +``` + +### Object-level FunApp/Global + +`FunApp` and `Global` can appear in object-level terms for object-level function calls. The unstager passes them through structurally (copying to the output arena), just as it does for the current `App { head: Global }`. + +## Examples + +### Polymorphic identity + +```splic +fn id(A: Type, x: A) -> A { x } +fn use_id() -> u64 { id(u64, 42) } +``` + +### Const combinator + +```splic +fn const_(A: Type, B: Type) -> fn(A) -> fn(B) -> A { + |a: A| |b: B| a +} +``` + +### Function composition + +```splic +fn compose(A: Type, B: Type, C: Type, f: fn(B) -> C, g: fn(A) -> B) -> fn(A) -> C { + |x: A| f(g(x)) +} +``` + +### Higher-order staging + +```splic +fn map_code(f: fn([[u64]]) -> [[u64]], x: [[u64]]) -> [[u64]] { + f(x) +} + +code fn double(x: u64) -> u64 { + $(map_code(|y: [[u64]]| #($(y) + $(y)), #(x))) +} +// Stages to: code fn double(x: u64) -> u64 { x + x } +``` + +## Future Work + +- **Prims as typed symbols**: Currently prims are special-cased with `PrimApp`. Eventually they should have types (polymorphic in width/phase) and be typechecked uniformly. +- **Object-level closures**: The closure-free approach from Kovács 2024 avoids runtime closures while still supporting higher-order object code. +- **Implicit arguments**: `fn {A: Type}(x: A) -> A` with unification to infer `A` at call sites. +- **Spine-based evaluation**: Replace substitution-based closures with lazy spines before adding full dependent elimination. + +## References + +- Kovács 2022: Staged Compilation with Two-Level Type Theory (ICFP) +- Kovács 2024: Closure-Free Functional Programming in a Two-Level Type Theory (ICFP) +- [prototype_eval.md](prototype_eval.md): Evaluator design and progression plan +- [prototype_next.md](prototype_next.md): Roadmap (Phase 2: Meta-level Functions, Phase 3: Dependent Types)