From 652ca9f5b7d832244884d9452ec2f566faa8d680 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20B=C3=B6ving?= Date: Wed, 11 Mar 2026 15:19:54 +0100 Subject: [PATCH] refactor: port EmitC to LCNF (#12781) This PR ports the C emission pass from IR to LCNF, marking the last step of the IR/LCNF conversion and thus enabling end-to-end code generation through the new compilation infrastructure. --- src/Lean/Compiler/IR.lean | 2 - src/Lean/Compiler/IR/AddExtern.lean | 3 + src/Lean/Compiler/IR/EmitC.lean | 1013 ----------------- src/Lean/Compiler/IR/EmitLLVM.lean | 1 - src/Lean/Compiler/IR/SimpCase.lean | 23 - src/Lean/Compiler/LCNF/Basic.lean | 2 + src/Lean/Compiler/LCNF/EmitC.lean | 1163 ++++++++++++++++++++ src/Lean/Compiler/LCNF/EmitUtil.lean | 56 + src/Lean/Compiler/LCNF/Passes.lean | 6 +- src/Lean/Compiler/LCNF/PhaseExt.lean | 28 +- src/Lean/Compiler/LCNF/PublicDeclsExt.lean | 13 +- src/Lean/Compiler/LCNF/SimpCase.lean | 10 + src/Lean/Compiler/LCNF/Toposort.lean | 15 +- src/Lean/Shell.lean | 5 +- 14 files changed, 1284 insertions(+), 1056 deletions(-) delete mode 100644 src/Lean/Compiler/IR/EmitC.lean delete mode 100644 src/Lean/Compiler/IR/SimpCase.lean create mode 100644 src/Lean/Compiler/LCNF/EmitC.lean create mode 100644 src/Lean/Compiler/LCNF/EmitUtil.lean diff --git a/src/Lean/Compiler/IR.lean b/src/Lean/Compiler/IR.lean index ecd227f90d..6a686959b1 100644 --- a/src/Lean/Compiler/IR.lean +++ b/src/Lean/Compiler/IR.lean @@ -13,7 +13,6 @@ public import Lean.Compiler.IR.CompilerM public import Lean.Compiler.IR.NormIds public import Lean.Compiler.IR.Checker public import Lean.Compiler.IR.UnboxResult -public import Lean.Compiler.IR.EmitC public import Lean.Compiler.IR.Sorry public import Lean.Compiler.IR.ToIR public import Lean.Compiler.IR.ToIRType @@ -34,7 +33,6 @@ def compile (decls : Array Decl) : CompilerM (Array Decl) := do let mut decls := decls decls ← updateSorryDep decls logDecls `result decls - checkDecls decls addDecls decls inferMeta decls return decls 
diff --git a/src/Lean/Compiler/IR/AddExtern.lean b/src/Lean/Compiler/IR/AddExtern.lean index 2f21cfe10d..91de6c2f06 100644 --- a/src/Lean/Compiler/IR/AddExtern.lean +++ b/src/Lean/Compiler/IR/AddExtern.lean @@ -70,6 +70,9 @@ where decl.saveImpure let decls ← Compiler.LCNF.addBoxedVersions #[decl] let decls ← Compiler.LCNF.runExplicitRc decls + for decl in decls do + decl.saveImpure + modifyEnv fun env => Compiler.LCNF.recordFinalImpureDecl env decl.name return decls addIr (decls : Array (Compiler.LCNF.Decl .impure)) : CoreM Unit := do diff --git a/src/Lean/Compiler/IR/EmitC.lean b/src/Lean/Compiler/IR/EmitC.lean deleted file mode 100644 index 72a73d9502..0000000000 --- a/src/Lean/Compiler/IR/EmitC.lean +++ /dev/null @@ -1,1013 +0,0 @@ -/- -Copyright (c) 2019 Microsoft Corporation. All rights reserved. -Released under Apache 2.0 license as described in the file LICENSE. -Authors: Leonardo de Moura --/ -module - -prelude -public import Lean.Compiler.NameMangling -public import Lean.Compiler.IR.EmitUtil -public import Lean.Compiler.IR.NormIds -public import Lean.Compiler.IR.SimpCase -public import Lean.Compiler.ModPkgExt -import Lean.Compiler.LCNF.Types -import Lean.Compiler.ClosedTermCache -import Lean.Compiler.LCNF.SimpleGroundExpr -import Init.Omega -import Init.While -import Init.Data.Range.Polymorphic.Iterators -import Lean.Runtime - -public section - -namespace Lean.IR.EmitC -open Lean.Compiler.LCNF (isBoxedName isSimpleGroundDecl getSimpleGroundExpr - getSimpleGroundExprWithResolvedRefs uint64ToByteArrayLE SimpleGroundExpr SimpleGroundArg - addSimpleGroundDecl) - -def leanMainFn := "_lean_main" - -structure Context where - env : Environment - modName : Name - jpMap : JPParamsMap := {} - mainFn : FunId := default - mainParams : Array Param := #[] - -abbrev M := ReaderT Context (EStateM String String) - -@[inline] def getEnv : M Environment := Context.env <$> read - -@[inline] def getModName : M Name := Context.modName <$> read - -@[inline] def getModInitFn 
(phases : IRPhases) : M String := do - let pkg? := (← getEnv).getModulePackage? - return mkModuleInitializationFunctionName (phases := phases) (← getModName) pkg? - -def getDecl (n : Name) : M Decl := do - let env ← getEnv - match findEnvDecl env n with - | some d => pure d - | none => throw s!"unknown declaration '{n}'" - -@[inline] def emit {α : Type} [ToString α] (a : α) : M Unit := - modify fun out => out ++ toString a - -@[inline] def emitLn {α : Type} [ToString α] (a : α) : M Unit := do - emit a; emit "\n" - -def emitLns {α : Type} [ToString α] (as : List α) : M Unit := - as.forM fun a => emitLn a - -def argToCString (x : Arg) : String := - match x with - | .var x => toString x - | .erased => "lean_box(0)" - -def emitArg (x : Arg) : M Unit := - emit (argToCString x) - -def toCType : IRType → String - | IRType.float => "double" - | IRType.float32 => "float" - | IRType.uint8 => "uint8_t" - | IRType.uint16 => "uint16_t" - | IRType.uint32 => "uint32_t" - | IRType.uint64 => "uint64_t" - | IRType.usize => "size_t" - | IRType.object => "lean_object*" - | IRType.tagged => "lean_object*" - | IRType.tobject => "lean_object*" - | IRType.erased => "lean_object*" - | IRType.void => "lean_object*" - | IRType.struct _ _ => panic! "not implemented yet" - | IRType.union _ _ => panic! "not implemented yet" - -def toHexDigit (c : Nat) : String := - String.singleton c.digitChar - -def quoteString (s : String) : String := - let q := "\""; - let q := s.foldl - (fun q c => q ++ - if c == '\n' then "\\n" - else if c == '\r' then "\\r" - else if c == '\t' then "\\t" - else if c == '\\' then "\\\\" - else if c == '\"' then "\\\"" - else if c == '?' then "\\?" -- avoid trigraphs - else if c.toNat <= 31 then - "\\x" ++ toHexDigit (c.toNat / 16) ++ toHexDigit (c.toNat % 16) - -- TODO(Leo): we should use `\unnnn` for escaping unicode characters. 
- else String.singleton c) - q; - q ++ "\"" - -def throwInvalidExportName {α : Type} (n : Name) : M α := - throw s!"invalid export name '{n}'" - -def toCName (n : Name) : M String := do - let env ← getEnv; - -- TODO: we should support simple export names only - match getExportNameFor? env n with - | some (.str .anonymous s) => return s - | some _ => throwInvalidExportName n - | none => return if n == `main then leanMainFn else getSymbolStem env n - -def emitCName (n : Name) : M Unit := - toCName n >>= emit - -def toCInitName (n : Name) : M String := do - let env ← getEnv; - -- TODO: we should support simple export names only - match getExportNameFor? env n with - | some (.str .anonymous s) => return "_init_" ++ s - | some _ => throwInvalidExportName n - | none => return "_init_" ++ getSymbolStem env n - -def emitCInitName (n : Name) : M Unit := - toCInitName n >>= emit - -def ctorScalarSizeStr (usize : Nat) (ssize : Nat) : String := - if usize == 0 then toString ssize - else if ssize == 0 then s!"sizeof(size_t)*{usize}" - else s!"sizeof(size_t)*{usize} + {ssize}" - -structure GroundState where - auxCounter : Nat := 0 - -abbrev GroundM := StateT GroundState M - -partial def emitGroundDecl (decl : Decl) (cppBaseName : String) : M Unit := do - let some ground := getSimpleGroundExpr (← getEnv) decl.name | unreachable! 
- discard <| compileGround ground |>.run {} -where - compileGround (e : SimpleGroundExpr) : GroundM Unit := do - let valueName ← compileGroundToValue e - let declPrefix := if isClosedTermName (← getEnv) decl.name then "static" else "LEAN_EXPORT" - emitLn <| s!"{declPrefix} const lean_object* {cppBaseName} = (const lean_object*)&{valueName};" - - compileGroundToValue (e : SimpleGroundExpr) : GroundM String := do - match e with - | .ctor cidx objArgs usizeArgs scalarArgs => - let val ← compileCtor cidx objArgs usizeArgs scalarArgs - mkValueCLit "lean_ctor_object" val - | .string data => - let leanStringTag := 249 - let header := mkHeader 0 0 leanStringTag - let size := data.utf8ByteSize + 1 -- null byte - let length := data.length - let data : String := quoteString data - mkValueCLit - "lean_string_object" - s!"\{.m_header = {header}, .m_size = {size}, .m_capacity = {size}, .m_length = {length}, .m_data = {data}}" - | .pap func args => - let numFixed := args.size - let leanClosureTag := 245 - let header := mkHeader s!"sizeof(lean_closure_object) + sizeof(void*)*{numFixed}" 0 leanClosureTag - let funPtr := s!"(void*){← toCName func}" - let arity := (← getDecl func).params.size - let args ← args.mapM groundArgToCLit - let argArray := String.intercalate "," args.toList - mkValueCLit - "lean_closure_object" - s!"\{.m_header = {header}, .m_fun = {funPtr}, .m_arity = {arity}, .m_num_fixed = {numFixed}, .m_objs = \{{argArray}} }" - | .nameMkStr args => - let obj ← groundNameMkStrToCLit args - mkValueCLit "lean_ctor_object" obj - | .array elems => - let leanArrayTag := 246 - let header := mkHeader s!"sizeof(lean_array_object) + sizeof(void*)*{elems.size}" 0 leanArrayTag - let elemLits ← elems.mapM groundArgToCLit - let dataArray := String.intercalate "," elemLits.toList - mkValueCLit - "lean_array_object" - s!"\{.m_header = {header}, .m_size = {elems.size}, .m_capacity = {elems.size}, .m_data = \{{dataArray}}}" - | .byteArray data => - let leanScalarArrayTag := 248 - let 
elemSize : Nat := 1 - let header := mkHeader s!"sizeof(lean_sarray_object) + {data.size}" elemSize leanScalarArrayTag - let dataLits := data.map toString - let dataArray := String.intercalate "," dataLits.toList - mkValueCLit - "lean_sarray_object" - s!"\{.m_header = {header}, .m_size = {data.size}, .m_capacity = {data.size}, .m_data = \{{dataArray}}}" - | .reference refDecl => findValueDecl refDecl - - mkValueName (name : String) : String := - name ++ "_value" - - mkAuxValueName (name : String) (idx : Nat) : String := - mkValueName name ++ s!"_aux_{idx}" - - mkAuxDecl (type value : String) : GroundM String := do - let idx ← modifyGet fun s => (s.auxCounter, { s with auxCounter := s.auxCounter + 1 }) - let name := mkAuxValueName cppBaseName idx - emitLn <| s!"static const {type} {name} = {value};" - return name - - mkValueCLit (type value : String) : GroundM String := do - let valueName := mkValueName cppBaseName - emitLn <| s!"static const {type} {valueName} = {value};" - return valueName - - groundNameMkStrToCLit (args : Array (Name × UInt64)) : GroundM String := do - assert! args.size > 0 - if args.size == 1 then - let (ref, hash) := args[0]! - let hash := uint64ToByteArrayLE hash - compileCtor 1 #[.tagged 0, .reference ref] #[] hash - else - let (ref, hash) := args.back! 
- let args := args.pop - let lit ← groundNameMkStrToCLit args - let auxName ← mkAuxDecl "lean_ctor_object" lit - let hash := uint64ToByteArrayLE hash - compileCtor 1 #[.rawReference auxName, .reference ref] #[] hash - - groundArgToCLit (a : SimpleGroundArg) : GroundM String := do - match a with - | .tagged val => return s!"((lean_object*)(((size_t)({val}) << 1) | 1))" - | .reference decl => return s!"((lean_object*)&{← findValueDecl decl})" - | .rawReference decl => return s!"((lean_object*)&{decl})" - - findValueDecl (decl : Name) : GroundM String := do - let mut decl := decl - while true do - if let some (.reference ref) := getSimpleGroundExpr (← getEnv) decl then - decl := ref - else - break - return mkValueName (← toCName decl) - - compileCtor (cidx : Nat) (objArgs : Array SimpleGroundArg) (usizeArgs : Array UInt64) - (scalarArgs : Array UInt8) : GroundM String := do - let header := mkCtorHeader objArgs.size usizeArgs.size scalarArgs.size cidx - let objArgs ← objArgs.mapM groundArgToCLit - let usizeArgs : Array String := usizeArgs.map fun val => s!"(lean_object*)(size_t)({val}ULL)" - assert! scalarArgs.size % 8 == 0 - let scalarArgs : Array String := Id.run do - let chunks := scalarArgs.size / 8 - let mut packed := Array.emptyWithCapacity chunks - for idx in 0...chunks do - let b1 := scalarArgs[idx * 8]! - let b2 := scalarArgs[idx * 8 + 1]! - let b3 := scalarArgs[idx * 8 + 2]! - let b4 := scalarArgs[idx * 8 + 3]! - let b5 := scalarArgs[idx * 8 + 4]! - let b6 := scalarArgs[idx * 8 + 5]! - let b7 := scalarArgs[idx * 8 + 6]! - let b8 := scalarArgs[idx * 8 + 7]! 
- let lit := s!"LEAN_SCALAR_PTR_LITERAL({b1}, {b2}, {b3}, {b4}, {b5}, {b6}, {b7}, {b8})" - packed := packed.push lit - return packed - let argArray := String.intercalate "," (objArgs ++ usizeArgs ++ scalarArgs).toList - return s!"\{.m_header = {header}, .m_objs = \{{argArray}}}" - - mkCtorHeader (numObjs : Nat) (usize : Nat) (ssize : Nat) (tag : Nat) : String := - let size := s!"sizeof(lean_ctor_object) + sizeof(void*)*{numObjs} + {ctorScalarSizeStr usize ssize}" - mkHeader size numObjs tag - - mkHeader {α : Type} [ToString α] (csSz : α) (other : Nat) (tag : Nat) : String := - s!"\{.m_rc = 0, .m_cs_sz = {csSz}, .m_other = {other}, .m_tag = {tag}}" - -def toTokenName (cppBaseName : String) : String := - s!"{cppBaseName}_once" - -def emitFnClosedDecl (decl : Decl) (cppBaseName : String) : M Unit := do - emitLn s!"static lean_once_cell_t {toTokenName cppBaseName} = LEAN_ONCE_CELL_INITIALIZER;" - emitLn s!"static {toCType decl.resultType} {cppBaseName};" - -def emitFnDeclAux (decl : Decl) (cppBaseName : String) (isExternal : Bool) : M Unit := do - let ps := decl.params - let env ← getEnv - - if isSimpleGroundDecl env decl.name then - emitGroundDecl decl cppBaseName - else if isClosedTermName env decl.name then - emitFnClosedDecl decl cppBaseName - else - if ps.isEmpty then - if isExternal then - emit "extern " - else - emit "LEAN_EXPORT " - else - if !isExternal then emit "LEAN_EXPORT " - emit (toCType decl.resultType ++ " " ++ cppBaseName) - unless ps.isEmpty do - emit "(" - -- We omit void parameters, note that they are guaranteed not to occur in boxed functions - let ps := ps.filter (fun p => !p.ty.isVoid) - -- We omit erased parameters for extern constants - let ps := if isExternC env decl.name then ps.filter (fun p => !p.ty.isErased) else ps - if ps.size > closureMaxArgs && isBoxedName decl.name then - emit "lean_object**" - else - ps.size.forM fun i _ => do - if i > 0 then emit ", " - emit (toCType ps[i].ty) - emit ")" - emitLn ";" - -def emitFnDecl (decl : Decl) 
(isExternal : Bool) : M Unit := do - let cppBaseName ← toCName decl.name - emitFnDeclAux decl cppBaseName isExternal - -def emitExternDeclAux (decl : Decl) (cNameStr : String) : M Unit := do - let env ← getEnv - let extC := isExternC env decl.name - emitFnDeclAux decl cNameStr extC - -def emitFnDecls : M Unit := do - let env ← getEnv - let decls := getDecls env |>.reverse - let modDecls : NameSet := decls.foldl (fun s d => s.insert d.name) {} - let usedDecls := collectUsedDecls env decls - usedDecls.forM fun n => do - let decl ← getDecl n; - match getExternNameFor env `c decl.name with - | some cName => emitExternDeclAux decl cName - | none => emitFnDecl decl (!modDecls.contains n) - -def emitMainFn : M Unit := do - let d ← getDecl `main - match d with - | .fdecl (xs := xs) .. => do - unless xs.size == 2 || xs.size == 1 do throw "invalid main function, incorrect arity when generating code" - let env ← getEnv - let usesLeanAPI := usesModuleFrom env `Lean - emitLn "char ** lean_setup_args(int argc, char ** argv);"; - if usesLeanAPI then - emitLn "void lean_initialize();" - else - emitLn "void lean_initialize_runtime_module();"; - emitLn " - #if defined(WIN32) || defined(_WIN32) - #include - #endif - - int main(int argc, char ** argv) { - #if defined(WIN32) || defined(_WIN32) - SetErrorMode(SEM_FAILCRITICALERRORS); - SetConsoleOutputCP(CP_UTF8); - #endif - lean_object* in; lean_object* res;"; - emitLn "argv = lean_setup_args(argc, argv);"; - if usesLeanAPI then - emitLn "lean_initialize();" - else - emitLn "lean_initialize_runtime_module();" - /- We disable panic messages because they do not mesh well with extracted closed terms. - See issue #534. We can remove this workaround after we implement issue #467. 
-/ - emitLn "lean_set_panic_messages(false);" - emitLn s!"res = {← getModInitFn (phases := if env.header.isModule then .runtime else .all)}(1 /* builtin */);" - emitLn "lean_set_panic_messages(true);" - emitLns ["lean_io_mark_end_initialization();", - "if (lean_io_result_is_ok(res)) {", - "lean_dec_ref(res);", - "lean_init_task_manager();"]; - if xs.size == 2 then - emitLns ["in = lean_box(0);", - "int i = argc;", - "while (i > 1) {", - " lean_object* n;", - " i--;", - " n = lean_alloc_ctor(1,2,0); lean_ctor_set(n, 0, lean_mk_string(argv[i])); lean_ctor_set(n, 1, in);", - " in = n;", - "}"] - emitLn ("res = " ++ leanMainFn ++ "(in);") - else - emitLn ("res = " ++ leanMainFn ++ "();") - emitLn "}" - -- `IO _` - let retTy := env.find? `main |>.get! |>.type |>.getForallBody - -- either `UInt32` or `(P)Unit` - let retTy := retTy.appArg! - -- finalize at least the task manager to avoid leak sanitizer false positives from tasks outliving the main thread - emitLns ["lean_finalize_task_manager();", - "if (lean_io_result_is_ok(res)) {", - " int ret = " ++ if retTy.constName? 
== some ``UInt32 then "lean_unbox_uint32(lean_io_result_get_value(res));" else "0;", - " lean_dec_ref(res);", - " return ret;", - "} else {", - " lean_io_result_show_error(res);", - " lean_dec_ref(res);", - " return 1;", - "}"] - emitLn "}" - | _ => throw "function declaration expected" - -def hasMainFn : M Bool := do - let env ← getEnv - let decls := getDecls env - return decls.any (fun d => d.name == `main) - -def emitMainFnIfNeeded : M Unit := do - if (← hasMainFn) then emitMainFn - -def emitFileHeader : M Unit := do - let env ← getEnv - let modName ← getModName - emitLn "// Lean compiler output" - emitLn ("// Module: " ++ toString modName) - emit "// Imports:" - env.imports.forM fun m => emit (" " ++ toString m) - emitLn "" - emitLn "#include " - emitLns [ - "#if defined(__clang__)", - "#pragma clang diagnostic ignored \"-Wunused-parameter\"", - "#pragma clang diagnostic ignored \"-Wunused-label\"", - "#elif defined(__GNUC__) && !defined(__CLANG__)", - "#pragma GCC diagnostic ignored \"-Wunused-parameter\"", - "#pragma GCC diagnostic ignored \"-Wunused-label\"", - "#pragma GCC diagnostic ignored \"-Wunused-but-set-variable\"", - "#endif", - "#ifdef __cplusplus", - "extern \"C\" {", - "#endif" - ] - -def emitFileFooter : M Unit := - emitLns [ - "#ifdef __cplusplus", - "}", - "#endif" - ] - -def throwUnknownVar {α : Type} (x : VarId) : M α := - throw s!"unknown variable '{x}'" - -def getJPParams (j : JoinPointId) : M (Array Param) := do - let ctx ← read; - match ctx.jpMap[j]? 
with - | some ps => pure ps - | none => throw "unknown join point" - -def declareVar (x : VarId) (t : IRType) : M Unit := do - emit (toCType t); emit " "; emit x; emit "; " - -def declareParams (ps : Array Param) : M Unit := - ps.forM fun p => declareVar p.x p.ty - -partial def declareVars : FnBody → Bool → M Bool - | e@(FnBody.vdecl x t _ b), d => do - let ctx ← read - if isTailCallTo ctx.mainFn e then - pure d - else - declareVar x t; declareVars b true - | FnBody.jdecl _ xs _ b, d => do declareParams xs; declareVars b (d || xs.size > 0) - | e, d => if e.isTerminal then pure d else declareVars e.body d - -def emitTag (x : VarId) (xType : IRType) : M Unit := do - if xType.isObj then do - emit "lean_obj_tag("; emit x; emit ")" - else - emit x - -def isIf (alts : Array Alt) : Option (Nat × FnBody × FnBody) := - if h : alts.size ≠ 2 then none - else match alts[0] with - | Alt.ctor c b => some (c.cidx, b, alts[1].body) - | _ => none - -def emitInc (x : VarId) (n : Nat) (checkRef : Bool) : M Unit := do - emit $ - if checkRef then (if n == 1 then "lean_inc" else "lean_inc_n") - else (if n == 1 then "lean_inc_ref" else "lean_inc_ref_n") - emit "("; emit x - if n != 1 then emit ", "; emit n - emitLn ");" - -def emitDec (x : VarId) (n : Nat) (checkRef : Bool) : M Unit := do - emit (if checkRef then "lean_dec" else "lean_dec_ref"); - emit "("; emit x; - if n != 1 then emit ", "; emit n - emitLn ");" - -def emitDel (x : VarId) : M Unit := do - emit "lean_del_object("; emit x; emitLn ");" - -def emitSetTag (x : VarId) (i : Nat) : M Unit := do - emit "lean_ctor_set_tag("; emit x; emit ", "; emit i; emitLn ");" - -def emitSet (x : VarId) (i : Nat) (y : Arg) : M Unit := do - emit "lean_ctor_set("; emit x; emit ", "; emit i; emit ", "; emitArg y; emitLn ");" - -def emitOffset (n : Nat) (offset : Nat) : M Unit := do - if n > 0 then - emit "sizeof(void*)*"; emit n; - if offset > 0 then emit " + "; emit offset - else - emit offset - -def emitUSet (x : VarId) (n : Nat) (y : VarId) : 
M Unit := do - emit "lean_ctor_set_usize("; emit x; emit ", "; emit n; emit ", "; emit y; emitLn ");" - -def emitSSet (x : VarId) (n : Nat) (offset : Nat) (y : VarId) (t : IRType) : M Unit := do - match t with - | IRType.float => emit "lean_ctor_set_float" - | IRType.float32 => emit "lean_ctor_set_float32" - | IRType.uint8 => emit "lean_ctor_set_uint8" - | IRType.uint16 => emit "lean_ctor_set_uint16" - | IRType.uint32 => emit "lean_ctor_set_uint32" - | IRType.uint64 => emit "lean_ctor_set_uint64" - | _ => throw "invalid instruction"; - emit "("; emit x; emit ", "; emitOffset n offset; emit ", "; emit y; emitLn ");" - -def emitJmp (j : JoinPointId) (xs : Array Arg) : M Unit := do - let ps ← getJPParams j - if h : xs.size = ps.size then - xs.size.forM fun i _ => do - let p := ps[i] - let x := xs[i] - emit p.x; emit " = "; emitArg x; emitLn ";" - emit "goto "; emit j; emitLn ";" - else - do throw "invalid goto" - -def emitLhs (z : VarId) : M Unit := do - emit z; emit " = " - -def emitArgs (ys : Array Arg) : M Unit := - ys.size.forM fun i _ => do - if i > 0 then emit ", " - emitArg ys[i] - -def emitCtorScalarSize (usize : Nat) (ssize : Nat) : M Unit := - emit <| ctorScalarSizeStr usize ssize - -def emitAllocCtor (c : CtorInfo) : M Unit := do - emit "lean_alloc_ctor("; emit c.cidx; emit ", "; emit c.size; emit ", " - emitCtorScalarSize c.usize c.ssize; emitLn ");" - -def emitCtorSetArgs (z : VarId) (ys : Array Arg) : M Unit := - ys.size.forM fun i _ => do - emit "lean_ctor_set("; emit z; emit ", "; emit i; emit ", "; emitArg ys[i]; emitLn ");" - -def emitCtor (z : VarId) (c : CtorInfo) (ys : Array Arg) : M Unit := do - emitLhs z; - if c.size == 0 && c.usize == 0 && c.ssize == 0 then do - emit "lean_box("; emit c.cidx; emitLn ");" - else do - emitAllocCtor c; emitCtorSetArgs z ys - -def emitReset (z : VarId) (n : Nat) (x : VarId) : M Unit := do - emit "if (lean_is_exclusive("; emit x; emitLn ")) {"; - n.forM fun i _ => do - emit " lean_ctor_release("; emit x; emit ", "; 
emit i; emitLn ");" - emit " "; emitLhs z; emit x; emitLn ";"; - emitLn "} else {"; - emit " lean_dec_ref("; emit x; emitLn ");"; - emit " "; emitLhs z; emitLn "lean_box(0);"; - emitLn "}" - -def emitReuse (z : VarId) (x : VarId) (c : CtorInfo) (updtHeader : Bool) (ys : Array Arg) : M Unit := do - emit "if (lean_is_scalar("; emit x; emitLn ")) {"; - emit " "; emitLhs z; emitAllocCtor c; - emitLn "} else {"; - emit " "; emitLhs z; emit x; emitLn ";"; - if updtHeader then emit " lean_ctor_set_tag("; emit z; emit ", "; emit c.cidx; emitLn ");" - emitLn "}"; - emitCtorSetArgs z ys - -def emitProj (z : VarId) (i : Nat) (x : VarId) : M Unit := do - emitLhs z; emit "lean_ctor_get("; emit x; emit ", "; emit i; emitLn ");" - -def emitUProj (z : VarId) (i : Nat) (x : VarId) : M Unit := do - emitLhs z; emit "lean_ctor_get_usize("; emit x; emit ", "; emit i; emitLn ");" - -def emitSProj (z : VarId) (t : IRType) (n offset : Nat) (x : VarId) : M Unit := do - emitLhs z; - match t with - | IRType.float => emit "lean_ctor_get_float" - | IRType.float32 => emit "lean_ctor_get_float32" - | IRType.uint8 => emit "lean_ctor_get_uint8" - | IRType.uint16 => emit "lean_ctor_get_uint16" - | IRType.uint32 => emit "lean_ctor_get_uint32" - | IRType.uint64 => emit "lean_ctor_get_uint64" - | _ => throw "invalid instruction" - emit "("; emit x; emit ", "; emitOffset n offset; emitLn ");" - -def toStringArgs (ys : Array Arg) : List String := - ys.toList.map argToCString - -def emitSimpleExternalCall (f : String) (ps : Array Param) (ys : Array Arg) : M Unit := do - emit f; emit "(" - -- We must remove erased arguments to extern calls. 
- discard <| ys.size.foldM - (fun i _ (first : Bool) => - let ty := ps[i]!.ty - if ty.isErased || ty.isVoid then - pure first - else do - unless first do emit ", " - emitArg ys[i] - pure false) - true - emitLn ");" - pure () - -def emitExternCall (f : FunId) (ps : Array Param) (extData : ExternAttrData) (ys : Array Arg) : M Unit := - match getExternEntryFor extData `c with - | some (ExternEntry.standard _ extFn) => emitSimpleExternalCall extFn ps ys - | some (ExternEntry.inline _ pat) => do emit (expandExternPattern pat (toStringArgs ys)); emitLn ";" - | _ => throw s!"failed to emit extern application '{f}'" - -def emitLeanFunReference (t : IRType) (f : FunId) : M Unit := do - let env ← getEnv - if isSimpleGroundDecl env f then - emit s!"((lean_object*)({← toCName f}))" - else if isClosedTermName env f then - emitClosedTermRead t f - else - emitCName f -where - emitClosedTermRead (t : IRType) (f : FunId) : M Unit := do - let fn ← - match t with - | .float => pure "lean_float_once" - | .float32 => pure "lean_float32_once" - | .uint8 => pure "lean_uint8_once" - | .uint16 => pure "lean_uint16_once" - | .uint32 => pure "lean_uint32_once" - | .uint64 => pure "lean_uint64_once" - | .usize => pure "lean_usize_once" - | .object | .tobject | .tagged | .void => pure "lean_obj_once" - | _ => throw s!"failed to emit closed term read for '{f}'" - emit s!"{fn}(&{← toCName f}, &{toTokenName (← toCName f)}, {← toCInitName f})" - -def emitFullApp (z : VarId) (t : IRType) (f : FunId) (ys : Array Arg) : M Unit := do - emitLhs z - let decl ← getDecl f - match decl with - | .fdecl (xs := ps) .. | .extern (xs := ps) (ext := { entries := [.opaque], .. }) .. 
=> - emitLeanFunReference t f - if ys.size > 0 then - let (ys, _) := ys.zip ps |>.filter (fun (_, p) => !p.ty.isVoid) |>.unzip - emit "("; emitArgs ys; emit ")" - emitLn ";" - | Decl.extern _ ps _ extData => emitExternCall f ps extData ys - -def emitPartialApp (z : VarId) (f : FunId) (ys : Array Arg) : M Unit := do - let decl ← getDecl f - let arity := decl.params.size; - emitLhs z; emit "lean_alloc_closure((void*)("; emitCName f; emit "), "; emit arity; emit ", "; emit ys.size; emitLn ");"; - ys.size.forM fun i _ => do - let y := ys[i] - emit "lean_closure_set("; emit z; emit ", "; emit i; emit ", "; emitArg y; emitLn ");" - -def emitApp (z : VarId) (f : VarId) (ys : Array Arg) : M Unit := - if ys.size > closureMaxArgs then do - emit "{ lean_object* _aargs[] = {"; emitArgs ys; emitLn "};"; - emitLhs z; emit "lean_apply_m("; emit f; emit ", "; emit ys.size; emitLn ", _aargs); }" - else do - emitLhs z; emit "lean_apply_"; emit ys.size; emit "("; emit f; emit ", "; emitArgs ys; emitLn ");" - -def emitBoxFn (xType : IRType) : M Unit := - match xType with - | IRType.usize => emit "lean_box_usize" - | IRType.uint32 => emit "lean_box_uint32" - | IRType.uint64 => emit "lean_box_uint64" - | IRType.float => emit "lean_box_float" - | IRType.float32 => emit "lean_box_float32" - | _ => emit "lean_box" - -def emitBox (z : VarId) (x : VarId) (xType : IRType) : M Unit := do - emitLhs z; emitBoxFn xType; emit "("; emit x; emitLn ");" - -def emitUnbox (z : VarId) (t : IRType) (x : VarId) : M Unit := do - emitLhs z - emit (getUnboxOpName t) - emit "("; emit x; emitLn ");" - -def emitIsShared (z : VarId) (x : VarId) : M Unit := do - emitLhs z; emit "!lean_is_exclusive("; emit x; emitLn ");" - -def emitNumLit (t : IRType) (v : Nat) : M Unit := do - if t.isObj then - if v < UInt32.size then - emit "lean_unsigned_to_nat("; emit v; emit "u)" - else - emit "lean_cstr_to_nat(\""; emit v; emit "\")" - else - if v < UInt32.size then - emit v - else if t == .usize then - emit "((size_t)" - 
emit v - emit "ULL)" - else - emit v - emit "ULL" - -def emitLit (z : VarId) (t : IRType) (v : LitVal) : M Unit := do - emitLhs z; - match v with - | LitVal.num v => emitNumLit t v; emitLn ";" - | LitVal.str v => - emit "lean_mk_string_unchecked("; - emit (quoteString v); emit ", "; - emit v.utf8ByteSize; emit ", "; - emit v.length; emitLn ");" - -def emitVDecl (z : VarId) (t : IRType) (v : Expr) : M Unit := - match v with - | Expr.ctor c ys => emitCtor z c ys - | Expr.reset n x => emitReset z n x - | Expr.reuse x c u ys => emitReuse z x c u ys - | Expr.proj i x => emitProj z i x - | Expr.uproj i x => emitUProj z i x - | Expr.sproj n o x => emitSProj z t n o x - | Expr.fap c ys => emitFullApp z t c ys - | Expr.pap c ys => emitPartialApp z c ys - | Expr.ap x ys => emitApp z x ys - | Expr.box t x => emitBox z x t - | Expr.unbox x => emitUnbox z t x - | Expr.isShared x => emitIsShared z x - | Expr.lit v => emitLit z t v - -def isTailCall (x : VarId) (v : Expr) (b : FnBody) : M Bool := do - let ctx ← read; - match v, b with - | Expr.fap f _, FnBody.ret (.var y) => return f == ctx.mainFn && x == y - | _, _ => pure false - -def paramEqArg (p : Param) (x : Arg) : Bool := - match x with - | .var x => p.x == x - | .erased => false - -/-- -Given `[p_0, ..., p_{n-1}]`, `[y_0, ..., y_{n-1}]`, representing the assignments -``` -p_0 := y_0, -... -p_{n-1} := y_{n-1} -``` -Return true iff we have `(i, j)` where `j > i`, and `y_j == p_i`. -That is, we have -``` - p_i := y_i, - ... - p_j := p_i, -- p_i was overwritten above -``` --/ -def overwriteParam (ps : Array Param) (ys : Array Arg) : Bool := - let n := ps.size; - n.any fun i _ => - let p := ps[i] - (i+1, n).anyI fun j _ _ => paramEqArg p ys[j]! 
- -def emitTailCall (v : Expr) : M Unit := - match v with - | Expr.fap _ ys => do - let ctx ← read - let ps := ctx.mainParams - if h : ps.size = ys.size then - let (ps, ys) := ps.zip ys |>.filter (fun (p, _) => !p.ty.isVoid) |>.unzip - if overwriteParam ps ys then - emitLn "{" - ps.size.forM fun i _ => do - let p := ps[i] - let y := ys[i]! - unless paramEqArg p y do - emit (toCType p.ty); emit " _tmp_"; emit i; emit " = "; emitArg y; emitLn ";" - ps.size.forM fun i _ => do - let p := ps[i] - let y := ys[i]! - unless paramEqArg p y do emit p.x; emit " = _tmp_"; emit i; emitLn ";" - emitLn "}" - else - ps.size.forM fun i _ => do - let p := ps[i] - let y := ys[i]! - unless paramEqArg p y do emit p.x; emit " = "; emitArg y; emitLn ";" - emitLn "goto _start;" - else - throw "invalid tail call" - | _ => throw "bug at emitTailCall" - -mutual - -partial def emitIf (x : VarId) (xType : IRType) (tag : Nat) (t : FnBody) (e : FnBody) : M Unit := do - emit "if ("; emitTag x xType; emit " == "; emit tag; emitLn ")"; - emitFnBody t; - emitLn "else"; - emitFnBody e - -partial def emitCase (x : VarId) (xType : IRType) (alts : Array Alt) : M Unit := - match isIf alts with - | some (tag, t, e) => emitIf x xType tag t e - | _ => do - emit "switch ("; emitTag x xType; emitLn ") {"; - let alts := ensureHasDefault alts; - alts.forM fun alt => do - match alt with - | Alt.ctor c b => emit "case "; emit c.cidx; emitLn ":"; emitFnBody b - | Alt.default b => emitLn "default: "; emitFnBody b - emitLn "}" - -partial def emitBlock (b : FnBody) : M Unit := do - match b with - | FnBody.jdecl _ _ _ b => emitBlock b - | d@(FnBody.vdecl x t v b) => - let ctx ← read - if isTailCallTo ctx.mainFn d then - emitTailCall v - else - emitVDecl x t v - emitBlock b - | FnBody.inc x n c p b => - unless p do emitInc x n c - emitBlock b - | FnBody.dec x n c p b => - unless p do emitDec x n c - emitBlock b - | FnBody.del x b => emitDel x; emitBlock b - | FnBody.setTag x i b => emitSetTag x i; emitBlock b - | 
FnBody.set x i y b => emitSet x i y; emitBlock b - | FnBody.uset x i y b => emitUSet x i y; emitBlock b - | FnBody.sset x i o y t b => emitSSet x i o y t; emitBlock b - | FnBody.ret x => emit "return "; emitArg x; emitLn ";" - | FnBody.case _ x xType alts => emitCase x xType alts - | FnBody.jmp j xs => emitJmp j xs - | FnBody.unreachable => emitLn "lean_internal_panic_unreachable();" - -partial def emitJPs : FnBody → M Unit - | FnBody.jdecl j _ v b => do emit j; emitLn ":"; emitFnBody v; emitJPs b - | e => do unless e.isTerminal do emitJPs e.body - -partial def emitFnBody (b : FnBody) : M Unit := do - emitLn "{" - let declared ← declareVars b false - if declared then emitLn "" - emitBlock b - emitJPs b - emitLn "}" - -end - -def emitDeclAux (d : Decl) : M Unit := do - let env ← getEnv - let (_, jpMap) := mkVarJPMaps d - withReader (fun ctx => { ctx with jpMap := jpMap }) do - unless hasInitAttr env d.name || isSimpleGroundDecl env d.name do - match d with - | .fdecl (f := f) (xs := xs) (type := t) (body := b) .. => - let baseName ← toCName f; - if xs.size == 0 then - emit "static " - else - emit "LEAN_EXPORT " -- make symbol visible to the interpreter - emit (toCType t); emit " "; - if xs.size > 0 then - let xs := xs.filter (fun p => !p.ty.isVoid) - emit baseName; - emit "("; - if xs.size > closureMaxArgs && isBoxedName d.name then - emit "lean_object** _args" - else - xs.size.forM fun i _ => do - if i > 0 then emit ", " - let x := xs[i] - emit (toCType x.ty); emit " "; emit x.x - emit ")" - else - emit ("_init_" ++ baseName ++ "(void)") - emitLn " {"; - if xs.size > closureMaxArgs && isBoxedName d.name then - xs.size.forM fun i _ => do - let x := xs[i]! 
- emit "lean_object* "; emit x.x; emit " = _args["; emit i; emitLn "];" - emitLn "_start:"; - withReader (fun ctx => { ctx with mainFn := f, mainParams := xs }) (emitFnBody b); - emitLn "}" - | _ => pure () - -def emitDecl (d : Decl) : M Unit := do - let d := d.normalizeIds; -- ensure we don't have gaps in the variable indices - try - emitDeclAux d - catch err => - throw s!"{err}\ncompiling:\n{d}" - -def emitFns : M Unit := do - let env ← getEnv; - let decls := getDecls env; - decls.reverse.forM emitDecl - -def emitMarkPersistent (d : Decl) (n : Name) : M Unit := do - if d.resultType.isObj then - emit "lean_mark_persistent(" - emitCName n - emitLn ");" - -def withErrRet (emitIORes : M Unit) : M Unit := do - emit "res = "; emitIORes; emitLn ";" - emitLn "if (lean_io_result_is_error(res)) return res;" - -def emitDeclInit (d : Decl) (isBuiltin : Bool) : M Unit := do - let env ← getEnv - let n := d.name - if (isBuiltin && isIOUnitBuiltinInitFn env n) || isIOUnitInitFn env n then - withErrRet do - emitCName n; emitLn "()" - emitLn "lean_dec_ref(res);" - else if d.params.size == 0 then - if let some initFn := (guard isBuiltin *> getBuiltinInitFnNameFor? env d.name) <|> getInitFnNameFor? env d.name then - withErrRet do - emitCName initFn; emitLn "()" - emitCName n - if d.resultType.isScalar then - emitLn (" = " ++ getUnboxOpName d.resultType ++ "(lean_io_result_get_value(res));") - else - emitLn " = lean_io_result_get_value(res);" - emitMarkPersistent d n - emitLn "lean_dec_ref(res);" - else if !isClosedTermName env d.name && !isSimpleGroundDecl env d.name then - emitCName n; emit " = "; emitCInitName n; emitLn "();"; emitMarkPersistent d n - -def emitInitFn (phases : IRPhases) : M Unit := do - let env ← getEnv - let impInitFns ← env.imports.filterMapM fun imp => do - if phases != .all && imp.isMeta != (phases == .comptime) then - return none - let some idx := env.getModuleIdx? 
imp.module - | throw "(internal) import without module index" -- should be unreachable - let pkg? := env.getModulePackageByIdx? idx - let fn := mkModuleInitializationFunctionName (phases := if phases == .all then .all else if imp.isMeta then .runtime else phases) imp.module pkg? - emitLn s!"lean_object* {fn}(uint8_t builtin);" - return some fn - let initialized := s!"_G_{mkModuleInitializationPrefix phases}initialized" - emitLns [ - s!"static bool {initialized} = false;", - s!"LEAN_EXPORT lean_object* {← getModInitFn (phases := phases)}(uint8_t builtin) \{", - "lean_object * res;", - s!"if ({initialized}) return lean_io_result_mk_ok(lean_box(0));", - s!"{initialized} = true;" - ] - impInitFns.forM fun fn => do - withErrRet do - emitLn s!"{fn}(builtin)" - emitLn "lean_dec_ref(res);" - let decls := getDecls env - for d in decls.reverse do - if phases == .all || (phases == .comptime) == isMarkedMeta env d.name then - emitDeclInit d (isBuiltin := phases != .comptime) - emitLns ["return lean_io_result_mk_ok(lean_box(0));", "}"] - -/-- Init function used before phase split under module system, keep for compatibility. -/ -def emitLegacyInitFn : M Unit := do - let env ← getEnv - let impInitFns ← env.imports.filterMapM fun imp => do - let some idx := env.getModuleIdx? imp.module - | throw "(internal) import without module index" -- should be unreachable - let pkg? := env.getModulePackageByIdx? idx - let fn := mkModuleInitializationFunctionName imp.module pkg? 
- emitLn s!"lean_object* {fn}(uint8_t builtin);" - return some fn - let initialized := s!"_G_initialized" - emitLns [ - s!"static bool {initialized} = false;", - s!"LEAN_EXPORT lean_object* {← getModInitFn (phases := .all)}(uint8_t builtin) \{", - "lean_object * res;", - s!"if ({initialized}) return lean_io_result_mk_ok(lean_box(0));", - s!"{initialized} = true;" - ] - impInitFns.forM fun fn => do - withErrRet do - emitLn s!"{fn}(builtin)" - emitLn "lean_dec_ref(res);" - withErrRet do - emitLn s!"{← getModInitFn (phases := .runtime)}(builtin)" - emitLn "lean_dec_ref(res);" - withErrRet do - emitLn s!"{← getModInitFn (phases := .comptime)}(builtin)" - emitLn "lean_dec_ref(res);" - emitLns [s!"return {← getModInitFn (phases := .all)}(builtin);", "}"] - -def main : M Unit := do - emitFileHeader - emitFnDecls - emitFns - if (← getEnv).header.isModule then - emitInitFn (phases := .runtime) - emitInitFn (phases := .comptime) - emitLegacyInitFn - else - emitInitFn (phases := .all) - emitMainFnIfNeeded - emitFileFooter - -end EmitC - -def emitC (env : Environment) (modName : Name) : IO String := - match EmitC.main { env, modName } |>.run "" with - | EStateM.Result.ok _ s => return s - | EStateM.Result.error err _ => throw <| .userError err - -end Lean.IR diff --git a/src/Lean/Compiler/IR/EmitLLVM.lean b/src/Lean/Compiler/IR/EmitLLVM.lean index f58849ab61..1586d712fb 100644 --- a/src/Lean/Compiler/IR/EmitLLVM.lean +++ b/src/Lean/Compiler/IR/EmitLLVM.lean @@ -9,7 +9,6 @@ prelude public import Lean.Compiler.NameMangling public import Lean.Compiler.IR.EmitUtil public import Lean.Compiler.IR.NormIds -public import Lean.Compiler.IR.SimpCase public import Lean.Compiler.IR.LLVMBindings import Lean.Compiler.LCNF.Types import Lean.Compiler.ModPkgExt diff --git a/src/Lean/Compiler/IR/SimpCase.lean b/src/Lean/Compiler/IR/SimpCase.lean deleted file mode 100644 index b10dcc84d4..0000000000 --- a/src/Lean/Compiler/IR/SimpCase.lean +++ /dev/null @@ -1,23 +0,0 @@ -/- -Copyright (c) 2019 
Microsoft Corporation. All rights reserved. -Released under Apache 2.0 license as described in the file LICENSE. -Authors: Leonardo de Moura --/ -module - -prelude -public import Lean.Compiler.IR.Basic - -public section - -namespace Lean.IR - -def ensureHasDefault (alts : Array Alt) : Array Alt := - if alts.any Alt.isDefault then alts - else if alts.size < 2 then alts - else - let last := alts.back! - let alts := alts.pop - alts.push (Alt.default last.body) - -end Lean.IR diff --git a/src/Lean/Compiler/LCNF/Basic.lean b/src/Lean/Compiler/LCNF/Basic.lean index 6cff07bec6..16efc3776c 100644 --- a/src/Lean/Compiler/LCNF/Basic.lean +++ b/src/Lean/Compiler/LCNF/Basic.lean @@ -850,9 +850,11 @@ where | .jmp .. => inc | .return .. | unreach .. => return () +@[inline] partial def Code.forM [Monad m] (c : Code pu) (f : Code pu → m Unit) : m Unit := go c where + @[specialize] go (c : Code pu) : m Unit := do f c match c with diff --git a/src/Lean/Compiler/LCNF/EmitC.lean b/src/Lean/Compiler/LCNF/EmitC.lean new file mode 100644 index 0000000000..d88eb484b4 --- /dev/null +++ b/src/Lean/Compiler/LCNF/EmitC.lean @@ -0,0 +1,1163 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. 
+Authors: Henrik Böving
+-/
+module
+
+prelude
+public import Lean.Compiler.LCNF.CompilerM
+import Lean.Compiler.LCNF.EmitUtil
+import Lean.Compiler.NameMangling
+import Lean.Compiler.LCNF.PhaseExt
+import Lean.Compiler.ExportAttr
+import Lean.Compiler.ModPkgExt
+import Lean.Compiler.LCNF.SimpleGroundExpr
+import Lean.Compiler.ClosedTermCache
+import Lean.Runtime
+import Lean.Compiler.LCNF.Internalize
+import Lean.Compiler.InitAttr
+import Init.Omega
+import Init.While
+import Lean.Compiler.LCNF.SimpCase
+import Lean.Compiler.LCNF.PrettyPrinter
+
+namespace Lean.Compiler.LCNF
+
+/-- C symbol name that `toCName` substitutes for a declaration named `main`. -/
+def leanMainFn := "_lean_main"
+
+namespace ImpureType
+
+/--
+Returns the C type used to represent a value of the given type.
+Scalar types map to their C counterparts; `object`, `tagged`, `tobject`, `erased` and `void`
+are all represented as `lean_object*`. Any other expression is treated as unreachable.
+-/
+def Lean.Expr.toCType : Expr → String
+  | float => "double"
+  | float32 => "float"
+  | uint8 => "uint8_t"
+  | uint16 => "uint16_t"
+  | uint32 => "uint32_t"
+  | uint64 => "uint64_t"
+  | usize => "size_t"
+  | object => "lean_object*"
+  | tagged => "lean_object*"
+  | tobject => "lean_object*"
+  | erased => "lean_object*"
+  | void => "lean_object*"
+  | _ => unreachable!
+
+/--
+Name of the C function emitted to unbox a value into type `t`.
+`usize`, `uint32`, `uint64`, `float` and `float32` have dedicated functions; every other type
+falls back to `lean_unbox`.
+-/
+def Lean.Expr.unboxOpName (t : Expr) : String :=
+  match t with
+  | usize => "lean_unbox_usize"
+  | uint32 => "lean_unbox_uint32"
+  | uint64 => "lean_unbox_uint64"
+  | float => "lean_unbox_float"
+  | float32 => "lean_unbox_float32"
+  | _ => "lean_unbox"
+
+/--
+Name of the C function emitted to box a value of type `t`.
+Mirrors `unboxOpName`: the same five types have dedicated functions, every other type falls back
+to `lean_box`.
+-/
+def Lean.Expr.boxOpName (t : Expr) : String :=
+  match t with
+  | usize => "lean_box_usize"
+  | uint32 => "lean_box_uint32"
+  | uint64 => "lean_box_uint64"
+  | float => "lean_box_float"
+  | float32 => "lean_box_float32"
+  | _ => "lean_box"
+
+/--
+Name of the C function emitted to read a scalar field of type `t` from a constructor object
+(see `emitSproj`). Only the float and fixed-width unsigned integer types are handled; any other
+type is treated as unreachable.
+-/
+def Lean.Expr.sprojOpName (t : Expr) : String :=
+  match t with
+  | float => "lean_ctor_get_float"
+  | float32 => "lean_ctor_get_float32"
+  | uint8 => "lean_ctor_get_uint8"
+  | uint16 => "lean_ctor_get_uint16"
+  | uint32 => "lean_ctor_get_uint32"
+  | uint64 => "lean_ctor_get_uint64"
+  | _ => unreachable!
+
+/--
+Name of the C function emitted to write a scalar field of type `t` into a constructor object
+(see `emitSset`). Handles the same types as `sprojOpName`; any other type is treated as
+unreachable.
+-/
+def Lean.Expr.ssetOpName (t : Expr) : String :=
+  match t with
+  | float => "lean_ctor_set_float"
+  | float32 => "lean_ctor_set_float32"
+  | uint8 => "lean_ctor_set_uint8"
+  | uint16 => "lean_ctor_set_uint16"
+  | uint32 => "lean_ctor_set_uint32"
+  | uint64 => "lean_ctor_set_uint64"
+  | _ => unreachable!
+
+/--
+Name of the C function emitted to read a closed term of type `t`.
+`emitLeanFunReference` calls it with the address of the cached value, the address of the
+associated once-cell and the name of the initializer function.
+-/
+def Lean.Expr.closedTermReadOpName (t : Expr) : String :=
+  match t with
+  | float => "lean_float_once"
+  | float32 => "lean_float32_once"
+  | uint8 => "lean_uint8_once"
+  | uint16 => "lean_uint16_once"
+  | uint32 => "lean_uint32_once"
+  | uint64 => "lean_uint64_once"
+  | usize => "lean_usize_once"
+  | object | tobject | tagged | void => "lean_obj_once"
+  | _ => unreachable!
+
+end ImpureType
+
+open ImpureType
+
+/-- Read-only data available while emitting C code. -/
+structure Context where
+  /--
+  Declarations from the current module in topologically sorted order.
+  -/
+  localDecls : Array (Decl .impure)
+  /--
+  Signatures of declarations from other modules.
+  -/
+  otherModuleDecls : Array (Signature .impure)
+  /--
+  The name of the current module.
+  -/
+  modName : Name
+  /--
+  The function that is currently being emitted.
+  -/
+  currFn : Name := default
+  /--
+  The parameters of the function that is currently being emitted.
+  -/
+  currParams : Array (Param .impure) := #[]
+
+/-- Mutable state of the emitter. -/
+structure State where
+  /-- The output produced so far; `emit` appends to it. -/
+  buf : String := ""
+  /-- Memoized mangled names of local variables (see the `EmitToString Name` instance). -/
+  varMangleCache : Std.HashMap Name String := {}
+  /-- Memoized results of `toCName`. -/
+  funMangleCache : Std.HashMap Name String := {}
+  /-- Memoized results of `toCInitName`. -/
+  funInitMangleCache : Std.HashMap Name String := {}
+
+/-- The emitter monad: a `Context` reader and a `State` reference on top of `CompilerM`. -/
+abbrev EmitM := ReaderT Context StateRefT State CompilerM
+
+/-- The name of the module being emitted. -/
+@[inline] def getModName : EmitM Name := return (← read).modName
+
+/--
+Name of the initialization function of the current module for the given `phases`, taking the
+module's package (if any) into account.
+-/
+@[inline] def getModInitFn (phases : IRPhases) : EmitM String := do
+  let pkg? := (← getEnv).getModulePackage?
+  return mkModuleInitializationFunctionName (phases := phases) (← getModName) pkg?
+
+/-- The function that is currently being emitted. -/
+@[inline] def getCurrFn : EmitM Name := return (← read).currFn
+
+/-- The parameters of the function that is currently being emitted. -/
+@[inline] def getCurrParams : EmitM (Array (Param .impure)) := return (← read).currParams
+
+/-- The declarations of the current module. -/
+@[inline] def getLocalDecls : EmitM (Array (Decl .impure)) := return (← read).localDecls
+
+/-- The signatures of declarations from other modules. -/
+@[inline] def getOtherModuleDecls : EmitM (Array (Signature .impure)) :=
+  return (← read).otherModuleDecls
+
+/--
+Values that can be rendered as output text. The conversion runs in `EmitM` so that instances can
+read and update the emitter state (e.g. the name mangling caches).
+-/
+class EmitToString (α : Type) where
+  toEmitString : α → EmitM String
+
+/-- Low-priority fallback: anything with a `ToString` instance is emitted via `toString`. -/
+instance (priority := low) [ToString α] : EmitToString α where
+  toEmitString x := return toString x
+
+/-- A `Name` is emitted mangled with the `v_` prefix; results are memoized in `varMangleCache`. -/
+instance : EmitToString Name where
+  toEmitString v := do
+    modifyGet fun s =>
+      if let some mangled := s.varMangleCache[v]? then
+        (mangled, s)
+      else
+        let mangled := v.mangle (pre := "v_")
+        (mangled, { s with varMangleCache := s.varMangleCache.insert v mangled })
+
+/-- A free variable is emitted as its (mangled) binder name. -/
+instance : EmitToString FVarId where
+  toEmitString fvarId := do EmitToString.toEmitString (← getBinderName fvarId)
+
+/-- C expression for an argument: the variable's name, or `lean_box(0)` for an erased argument. -/
+def Arg.toCString (a : Arg .impure) : EmitM String := do
+  match a with
+  | .fvar fvarId => EmitToString.toEmitString fvarId
+  | .erased => return "lean_box(0)"
+
+instance : EmitToString (Arg .impure) where
+  toEmitString a := a.toCString
+
+
+/-- Appends the rendering of `a` to the output buffer. -/
+@[inline] def emit [EmitToString α] (a : α) : EmitM Unit := do
+  let str ← EmitToString.toEmitString a
+  modify fun out => { out with buf := out.buf ++ str }
+
+/-- Like `emit`, followed by a newline. -/
+@[inline] def emitLn [EmitToString α] (a : α) : EmitM Unit := do
+  emit a; emit "\n"
+
+/-- Emits the C call expression `fn(arg)` (no trailing `;`). -/
+@[inline]
+def emitCApp1 {α : Type} [EmitToString α] (fn : String) (arg : α) : EmitM Unit := do
+  emit fn; emit "("; emit arg; emit ")"
+
+/-- Emits the C call expression `fn(arg1, arg2)` (no trailing `;`). -/
+@[inline]
+def emitCApp2 {α β : Type} [EmitToString α] [EmitToString β] (fn : String) (arg1 : α) (arg2 : β) :
+    EmitM Unit := do
+  emit fn; emit "("; emit arg1; emit ", "; emit arg2; emit ")"
+
+/-- Emits the C call expression `fn(arg1, arg2, arg3)` (no trailing `;`). -/
+@[inline]
+def emitCApp3 {α β γ : Type} [EmitToString α] [EmitToString β] [EmitToString γ] (fn : String)
+    (arg1 : α) (arg2 : β) (arg3 : γ) : EmitM Unit := do
+  emit fn; emit "("; emit arg1; emit ", "; emit arg2; emit ", "; emit arg3; emit ")"
+
+/-- Renders every argument with `Arg.toCString`. -/
+def toStringArgs (ys : Array (Arg .impure)) : EmitM (List String) :=
+  ys.toList.mapM (·.toCString)
+
+/-- Emits `args` separated by `", "`. -/
+def emitArgs (args : Array (Arg .impure)) : EmitM Unit := do
+  for h : i in 0...args.size do
+    if i > 0 then emit ", "
+    emit args[i]
+
+/-- Emits every element of `as` on its own line. -/
+def emitLns [EmitToString α] (as : List α) : EmitM Unit :=
+  as.forM fun a => emitLn a
+
+/-- Runs `x` between a `{` line and a `}` line and returns its result. -/
+@[inline] def withEmitBlock (x : EmitM α) : EmitM α := do
+  emitLn "{"
+  let ret ← x
+  emitLn "}"
+  return ret
+
+/-- The one-character string for the digit `c`, as produced by `Nat.digitChar`. -/
+def toHexDigit (c : Nat) : String :=
+  String.singleton c.digitChar
+
+/--
+Renders `s` as a double-quoted C string literal.
+
+Newline, carriage return, tab, backslash and double quote get their usual escapes, `?` is escaped
+to avoid trigraphs, and the remaining control characters (code points up to 31) are written as
+three-digit octal escapes. All other characters are copied unchanged.
+-/
+def quoteString (s : String) : String :=
+  let q := "\"";
+  let q := s.foldl
+    (fun q c => q ++
+      if c == '\n' then "\\n"
+      else if c == '\r' then "\\r"
+      else if c == '\t' then "\\t"
+      else if c == '\\' then "\\\\"
+      else if c == '\"' then "\\\""
+      else if c == '?' then "\\?" -- avoid trigraphs
+      else if c.toNat <= 31 then
+        -- A C hex escape (`\x..`) extends over *all* following hex digits, so e.g. the character
+        -- with code 1 followed by `a` would be read back as the single character `\x1a`.
+        -- Octal escapes stop after three digits, hence a fixed-width `\ooo` is unambiguous.
+        -- (`toHexDigit` is only applied to values below 8 here.)
+        "\\" ++ toHexDigit (c.toNat / 64) ++ toHexDigit (c.toNat / 8 % 8) ++ toHexDigit (c.toNat % 8)
+      -- TODO(Leo): we should use `\unnnn` for escaping unicode characters.
+      else String.singleton c)
+    q;
+  q ++ "\""
+
+/-- Aborts with an error reporting that the export name of `n` cannot be used. -/
+def throwInvalidExportName (n : Name) : EmitM α :=
+  throwError s!"invalid export name '{n}'"
+
+/--
+C symbol name of the declaration `n`; memoized in `funMangleCache`.
+
+An export name is used verbatim if it is a single string component and rejected otherwise.
+Without an export name, `main` becomes `leanMainFn` and everything else its symbol stem.
+-/
+def toCName (n : Name) : EmitM String := do
+  if let some cached := (← get).funMangleCache[n]? then
+    return cached
+  let mangled ← go
+  modify fun s => { s with funMangleCache := s.funMangleCache.insert n mangled }
+  return mangled
+where
+  go : EmitM String := do
+    let env ← getEnv
+    -- TODO: we should support simple export names only
+    match getExportNameFor? env n with
+    | some (.str .anonymous s) => return s
+    | some _ => throwInvalidExportName n
+    | none => return if n == `main then leanMainFn else getSymbolStem env n
+
+/-- Emits the result of `toCName n`. -/
+def emitCName (n : Name) : EmitM Unit :=
+  toCName n >>= emit
+
+/--
+Name of the C initializer function of the declaration `n`: `_init_` followed by the export name
+or the symbol stem; memoized in `funInitMangleCache`.
+-/
+def toCInitName (n : Name) : EmitM String := do
+  if let some cached := (← get).funInitMangleCache[n]?
then + return cached + let mangled ← go + modify fun s => { s with funInitMangleCache := s.funInitMangleCache.insert n mangled } + return mangled +where + go : EmitM String := do + let env ← getEnv; + -- TODO: we should support simple export names only + match getExportNameFor? env n with + | some (.str .anonymous s) => return "_init_" ++ s + | some _ => throwInvalidExportName n + | none => return "_init_" ++ getSymbolStem env n + +def emitCInitName (n : Name) : EmitM Unit := + toCInitName n >>= emit + +def emitFileHeader : EmitM Unit := do + let env ← getEnv + let modName ← getModName + emitLn "// Lean compiler output" + emitLn s!"// Module: {modName}" + emit "// Imports:" + env.imports.forM fun m => emit (" " ++ toString m) + emitLn "" + emitLn "#include " + emitLns [ + "#if defined(__clang__)", + "#pragma clang diagnostic ignored \"-Wunused-parameter\"", + "#pragma clang diagnostic ignored \"-Wunused-label\"", + "#elif defined(__GNUC__) && !defined(__CLANG__)", + "#pragma GCC diagnostic ignored \"-Wunused-parameter\"", + "#pragma GCC diagnostic ignored \"-Wunused-label\"", + "#pragma GCC diagnostic ignored \"-Wunused-but-set-variable\"", + "#endif", + "#ifdef __cplusplus", + "extern \"C\" {", + "#endif" + ] + +def ctorScalarSizeExpression (usize : Nat) (ssize : Nat) : String := + if usize == 0 then + s!"{ssize}" + else if ssize == 0 then + s!"sizeof(size_t)*{usize}" + else + s!"sizeof(size_t)*{usize} + {ssize}" + +structure GroundState where + auxCounter : Nat := 0 + +abbrev GroundM := StateRefT GroundState EmitM + +partial def emitGroundDecl (decl : Decl .impure) (cppBaseName : String) : EmitM Unit := do + let some ground := getSimpleGroundExpr (← getEnv) decl.name | unreachable! 
+  discard <| compileGround ground |>.run {}
+where
+  /--
+  Emits the value object for `e` and then the definition of `cppBaseName` as a
+  `const lean_object*` pointing at it. Closed terms are emitted `static`, everything else
+  `LEAN_EXPORT`.
+  -/
+  compileGround (e : SimpleGroundExpr) : GroundM Unit := do
+    let valueName ← compileGroundToValue e
+    let declPrefix := if isClosedTermName (← getEnv) decl.name then "static" else "LEAN_EXPORT"
+    emitLn <| s!"{declPrefix} const lean_object* {cppBaseName} = (const lean_object*)&{valueName};"
+
+  /--
+  Emits a static C object holding the value of `e` and returns the name of that object.
+  For a `.reference`, nothing is emitted and the value name of the referenced declaration is
+  returned instead.
+  -/
+  compileGroundToValue (e : SimpleGroundExpr) : GroundM String := do
+    match e with
+    | .ctor cidx objArgs usizeArgs scalarArgs =>
+      let val ← compileCtor cidx objArgs usizeArgs scalarArgs
+      mkValueCLit "lean_ctor_object" val
+    | .string data =>
+      let leanStringTag := 249
+      let header := mkHeader 0 0 leanStringTag
+      let size := data.utf8ByteSize + 1 -- null byte
+      let length := data.length
+      -- From here on `data` is the quoted C literal; `size`/`length` were taken from the raw string.
+      let data : String := quoteString data
+      mkValueCLit
+        "lean_string_object"
+        s!"\{.m_header = {header}, .m_size = {size}, .m_capacity = {size}, .m_length = {length}, .m_data = {data}}"
+    | .pap func args =>
+      let numFixed := args.size
+      let leanClosureTag := 245
+      let header := mkHeader s!"sizeof(lean_closure_object) + sizeof(void*)*{numFixed}" 0 leanClosureTag
+      let funPtr := s!"(void*){← toCName func}"
+      let arity := (← getImpureSignature?
func).get!.params.size
+      let args ← args.mapM groundArgToCLit
+      let argArray := String.intercalate "," args.toList
+      mkValueCLit
+        "lean_closure_object"
+        s!"\{.m_header = {header}, .m_fun = {funPtr}, .m_arity = {arity}, .m_num_fixed = {numFixed}, .m_objs = \{{argArray}} }"
+    | .nameMkStr args =>
+      let obj ← groundNameMkStrToCLit args
+      mkValueCLit "lean_ctor_object" obj
+    | .array elems =>
+      let leanArrayTag := 246
+      let header := mkHeader s!"sizeof(lean_array_object) + sizeof(void*)*{elems.size}" 0 leanArrayTag
+      let elemLits ← elems.mapM groundArgToCLit
+      let dataArray := String.intercalate "," elemLits.toList
+      mkValueCLit
+        "lean_array_object"
+        s!"\{.m_header = {header}, .m_size = {elems.size}, .m_capacity = {elems.size}, .m_data = \{{dataArray}}}"
+    | .byteArray data =>
+      let leanScalarArrayTag := 248
+      let elemSize : Nat := 1
+      let header := mkHeader s!"sizeof(lean_sarray_object) + {data.size}" elemSize leanScalarArrayTag
+      let dataLits := data.map toString
+      let dataArray := String.intercalate "," dataLits.toList
+      mkValueCLit
+        "lean_sarray_object"
+        s!"\{.m_header = {header}, .m_size = {data.size}, .m_capacity = {data.size}, .m_data = \{{dataArray}}}"
+    | .reference refDecl => findValueDecl refDecl
+
+  /-- Name of the static object that holds the value of the C symbol `name`. -/
+  mkValueName (name : String) : String :=
+    name ++ "_value"
+
+  /-- Name of the `idx`-th auxiliary object belonging to the C symbol `name`. -/
+  mkAuxValueName (name : String) (idx : Nat) : String :=
+    mkValueName name ++ s!"_aux_{idx}"
+
+  /--
+  Emits `static const type <fresh aux name> = value;` and returns the fresh name.
+  The index comes from `auxCounter`, which is incremented on every call.
+  -/
+  mkAuxDecl (type value : String) : GroundM String := do
+    let idx ← modifyGet fun s => (s.auxCounter, { s with auxCounter := s.auxCounter + 1 })
+    let name := mkAuxValueName cppBaseName idx
+    emitLn <| s!"static const {type} {name} = {value};"
+    return name
+
+  /-- Emits `static const type <value name of cppBaseName> = value;` and returns that name. -/
+  mkValueCLit (type value : String) : GroundM String := do
+    let valueName := mkValueName cppBaseName
+    emitLn <| s!"static const {type} {valueName} = {value};"
+    return valueName
+
+  /--
+  Returns the initializer of the constructor object for the last entry of `args`, a non-empty
+  array of (string declaration, hash) pairs. All earlier entries are emitted as auxiliary
+  objects, each one referencing its predecessor.
+  -/
+  groundNameMkStrToCLit (args : Array (Name × UInt64)) : GroundM String := do
+    assert!
args.size > 0
+    if h : args.size = 1 then
+      -- Innermost component: its first field is the tagged scalar `0` instead of a prefix object.
+      let (ref, hash) := args[0]
+      let hash := uint64ToByteArrayLE hash
+      compileCtor 1 #[.tagged 0, .reference ref] #[] hash
+    else
+      -- Emit the prefix as an auxiliary object first, then reference it from this component.
+      let (ref, hash) := args.back!
+      let args := args.pop
+      let lit ← groundNameMkStrToCLit args
+      let auxName ← mkAuxDecl "lean_ctor_object" lit
+      let hash := uint64ToByteArrayLE hash
+      compileCtor 1 #[.rawReference auxName, .reference ref] #[] hash
+
+  /--
+  C expression of type `lean_object*` for a ground argument: a tagged scalar
+  (`(val << 1) | 1`), the address of another declaration's value object, or the address of an
+  object given directly by its C name.
+  -/
+  groundArgToCLit (a : SimpleGroundArg) : GroundM String := do
+    match a with
+    | .tagged val => return s!"((lean_object*)(((size_t)({val}) << 1) | 1))"
+    | .reference decl => return s!"((lean_object*)&{← findValueDecl decl})"
+    | .rawReference decl => return s!"((lean_object*)&{decl})"
+
+  /--
+  Name of the value object of `decl`. Declarations whose ground expression is just a
+  `.reference` to another declaration are followed until one with its own value is reached.
+  -/
+  findValueDecl (decl : Name) : GroundM String := do
+    let mut decl := decl
+    while true do
+      if let some (.reference ref) := getSimpleGroundExpr (← getEnv) decl then
+        decl := ref
+      else
+        break
+    return mkValueName (← toCName decl)
+
+  /--
+  Returns the C initializer of a `lean_ctor_object` with constructor index `cidx`.
+  `m_objs` receives the object arguments, then the `usize` arguments, then the scalar bytes
+  packed eight at a time; `scalarArgs.size` must therefore be a multiple of 8.
+  -/
+  compileCtor (cidx : Nat) (objArgs : Array SimpleGroundArg) (usizeArgs : Array UInt64)
+      (scalarArgs : Array UInt8) : GroundM String := do
+    let header := mkCtorHeader objArgs.size usizeArgs.size scalarArgs.size cidx
+    let objArgs ← objArgs.mapM groundArgToCLit
+    let usizeArgs : Array String := usizeArgs.map fun val => s!"(lean_object*)(size_t)({val}ULL)"
+    assert!
scalarArgs.size % 8 == 0
+    let scalarArgs : Array String := Id.run do
+      let chunks := scalarArgs.size / 8
+      let mut packed := Array.emptyWithCapacity chunks
+      for h : idx in 0...chunks do
+        -- Bounds proof for the eight array accesses below.
+        have : idx * 8 + 7 < scalarArgs.size := by
+          have : idx < scalarArgs.size / 8 := Std.Rco.lt_upper_of_mem h
+          simp at this
+          omega
+        let b1 := scalarArgs[idx * 8]
+        let b2 := scalarArgs[idx * 8 + 1]
+        let b3 := scalarArgs[idx * 8 + 2]
+        let b4 := scalarArgs[idx * 8 + 3]
+        let b5 := scalarArgs[idx * 8 + 4]
+        let b6 := scalarArgs[idx * 8 + 5]
+        let b7 := scalarArgs[idx * 8 + 6]
+        let b8 := scalarArgs[idx * 8 + 7]
+        let lit := s!"LEAN_SCALAR_PTR_LITERAL({b1}, {b2}, {b3}, {b4}, {b5}, {b6}, {b7}, {b8})"
+        packed := packed.push lit
+      return packed
+    let argArray := String.intercalate "," (objArgs ++ usizeArgs ++ scalarArgs).toList
+    return s!"\{.m_header = {header}, .m_objs = \{{argArray}}}"
+
+  /-- Object header initializer for a constructor object with the given field counts and tag. -/
+  mkCtorHeader (numObjs : Nat) (usize : Nat) (ssize : Nat) (tag : Nat) : String :=
+    let size := s!"sizeof(lean_ctor_object) + sizeof(void*)*{numObjs} + {ctorScalarSizeExpression usize ssize}"
+    mkHeader size numObjs tag
+
+  /-- Object header initializer with reference count `0` and the given size, `other` and tag. -/
+  mkHeader {α : Type} [ToString α] (csSz : α) (other : Nat) (tag : Nat) : String :=
+    s!"\{.m_rc = 0, .m_cs_sz = {csSz}, .m_other = {other}, .m_tag = {tag}}"
+
+/-- Name of the once-cell that belongs to the closed term with C name `cppBaseName`. -/
+def toOnceTokenName (cppBaseName : String) : String :=
+  s!"{cppBaseName}_once"
+
+/-- The parameters in `ps` whose type is not void. -/
+@[inline]
+def paramsWithoutVoid (ps : Array (Param .impure)) :=
+  ps.filter (!·.type.isVoid)
+
+/-- The parameters in `ps` whose type is not erased. -/
+@[inline]
+def paramsWithoutErased (ps : Array (Param .impure)) :=
+  ps.filter (!·.type.isErased)
+
+/--
+Emits the forward declarations for the whole file: first for the declarations from other
+modules, then for the local ones. Declarations with a C extern name are declared under that
+name.
+-/
+def emitFnDecls : EmitM Unit := do
+  (← getOtherModuleDecls).forM fun sig => do
+    match getExternNameFor (← getEnv) `c sig.name with
+    | some externName => emitExternDecl sig externName
+    | none => emitFnDeclStandard sig true
+  (← getLocalDecls).forM fun decl => do
+    match getExternNameFor (← getEnv) `c decl.name with
+    | some externName => emitExternDecl decl.toSignature externName
+    | none => emitFnDecl decl false
+where
+  /-- Declares `sig` under its extern name `externName`. -/
+  emitExternDecl (sig
: Signature .impure) (externName : String) : EmitM Unit := do
+    let env ← getEnv
+    let extC := isExternC env sig.name
+    emitFnDeclAux sig externName extC
+
+  /--
+  Declares a local declaration: simple ground declarations are emitted as static data, closed
+  terms as a once-cell plus a cache variable, everything else as an ordinary prototype.
+  -/
+  emitFnDecl (decl : Decl .impure) (isExternal : Bool) : EmitM Unit := do
+    let env ← getEnv
+    let cppBaseName ← toCName decl.name
+    if isSimpleGroundDecl env decl.name then
+      emitGroundDecl decl cppBaseName
+    else if isClosedTermName env decl.name then
+      emitFnDeclClosed decl cppBaseName
+    else
+      emitFnDeclStandard decl.toSignature isExternal
+
+  /-- Emits the once-cell and the static variable that caches the value of a closed term. -/
+  emitFnDeclClosed (decl : Decl .impure) (cppBaseName : String) : EmitM Unit := do
+    emitLn s!"static lean_once_cell_t {toOnceTokenName cppBaseName} = LEAN_ONCE_CELL_INITIALIZER;"
+    emitLn s!"static {decl.type.toCType} {cppBaseName};"
+
+  /-- Declares `sig` under the name computed by `toCName`. -/
+  emitFnDeclStandard (sig : Signature .impure) (isExternal : Bool) : EmitM Unit := do
+    let cppBaseName ← toCName sig.name
+    emitFnDeclAux sig cppBaseName isExternal
+
+  /--
+  Emits the C declaration of `sig` under the name `cppBaseName`.
+
+  Without parameters the declaration is prefixed with `extern` or `LEAN_EXPORT` depending on
+  `isExternal`; with parameters only non-external declarations get the `LEAN_EXPORT` prefix.
+  -/
+  emitFnDeclAux (sig : Signature .impure) (cppBaseName : String) (isExternal : Bool) :
+      EmitM Unit := do
+    let ps := sig.params
+    let env ← getEnv
+
+    if ps.isEmpty then
+      if isExternal then
+        emit "extern "
+      else
+        emit "LEAN_EXPORT "
+    else if !isExternal then
+      emit "LEAN_EXPORT "
+    emit <| sig.type.toCType ++ " " ++ cppBaseName
+    unless ps.isEmpty do
+      emit "("
+      -- We omit void parameters, note that they are guaranteed not to occur in boxed functions
+      let ps := paramsWithoutVoid ps
+      -- We omit erased parameters for extern constants
+      let ps := if isExternC env sig.name then paramsWithoutErased ps else ps
+      -- Boxed functions with more than `closureMaxArgs` parameters take one argument array.
+      if ps.size > closureMaxArgs && isBoxedName sig.name then
+        emit "lean_object**"
+      else
+        ps.size.forM fun i _ => do
+          if i > 0 then emit ", "
+          emit ps[i].type.toCType
+      emit ")"
+    emitLn ";"
+
+/--
+C expression for a byte offset made up of `i` pointer-sized slots plus `offset` bytes.
+Zero summands are left out; if `i` is `0` the result is just `offset`.
+-/
+def offsetExpression (i : Nat) (offset : Nat) : String :=
+  if i > 0 then
+    if offset > 0 then
+      s!"sizeof(void*)*{i} + {offset}"
+    else
+      s!"sizeof(void*)*{i}"
+  else
+    s!"{offset}"
+
+/--
+Returns `true` iff `code` has the shape `let x := f ..; return x` where `f` is the function
+that is currently being emitted.
+-/
+def isTailCall (code : Code .impure) : EmitM Bool :=
+  match code
with
+  | .let { fvarId := fvarId, value := .fap declName _, .. } (.return fvarId') =>
+    return fvarId == fvarId' && (← getCurrFn) == declName
+  | _ => return false
+
+/--
+Emits C declarations for the variables bound along the spine of `code`: every `let` variable
+(except one that is part of a self tail call) and the parameters of every join point.
+Branches of `cases` and bodies of join points are not visited. Returns `true` iff at least one
+declaration was emitted.
+-/
+def declareVars (code : Code .impure) : EmitM Bool :=
+  go code false
+where
+  go (code : Code .impure) (didChange : Bool) : EmitM Bool := do
+    match code with
+    | .let decl k =>
+      if ← isTailCall code then
+        return didChange
+      else
+        declareVar decl.binderName decl.type
+        go k true
+    | .jp decl k =>
+      declareParams decl.params
+      go k (didChange || !decl.params.isEmpty)
+    | .del (k := k) .. | .dec (k := k) .. | .inc (k := k) .. | .setTag (k := k) ..
+    | .sset (k := k) .. | .uset (k := k) .. | .oset (k := k) .. => go k didChange
+    | .cases .. | .return .. | .jmp .. | .unreach .. => return didChange
+
+
+  /-- Emits `<C type> <name>; ` without a line break. -/
+  declareVar (binderName : Name) (type : Expr) : EmitM Unit := do
+    emit type.toCType; emit " "; emit binderName; emit "; "
+
+  declareParams (ps : Array (Param .impure)) : EmitM Unit := do
+    ps.forM fun p => declareVar p.binderName p.type
+
+/--
+Emits the C statements that compute the value of `decl` and assign it to the variable of
+`decl`. The variable itself must already have been declared (see `declareVars`).
+-/
+def emitLetDecl (decl : LetDecl .impure) : EmitM Unit := do
+  match decl.value with
+  | .ctor info args => emitCtor info args
+  | .reset n fvarId => emitReset n fvarId
+  | .reuse fvarId info update args => emitReuse fvarId info update args
+  | .oproj i fvarId => emitOproj i fvarId
+  | .uproj i fvarId => emitUproj i fvarId
+  | .sproj n offset fvarId => emitSproj n offset fvarId
+  | .fap fn args => emitFap fn args
+  | .pap fn args => emitPap fn args
+  | .fvar fvarId args => emitAp fvarId args
+  | .box ty fvarId => emitBox ty fvarId
+  | .unbox fvarId => emitUnbox fvarId
+  | .isShared fvarId => emitIsShared fvarId
+  | .lit v => emitLit v
+  | .erased => emitErased
+where
+  /-- Emits the `lean_alloc_ctor(..)` call for `info` (expression only). -/
+  emitAllocCtor (info : CtorInfo) : EmitM Unit :=
+    emitCApp3 "lean_alloc_ctor" info.cidx info.size (ctorScalarSizeExpression info.usize info.ssize)
+
+  /-- Emits one `lean_ctor_set(targetId, i, arg);` line per element of `args`. -/
+  emitCtorSetArgs (targetId : FVarId) (args : Array (Arg .impure)) : EmitM Unit := do
+    for h : i in 0...args.size do
+      let arg := args[i]
+      emitCApp3 "lean_ctor_set" targetId i arg; emitLn ";"
+
+  /--
+  Constructors without any fields are emitted as the boxed constructor index; all others are
+  allocated and then filled with `args`.
+  -/
+  emitCtor (info : CtorInfo) (args : Array (Arg .impure)) : EmitM Unit := do
+    if info.size == 0 && info.usize == 0 && info.ssize == 0 then do
+      withEmitAssignment do emitCApp1 "lean_box" info.cidx
+    else do
+      withEmitAssignment do emitAllocCtor info
+      emitCtorSetArgs decl.fvarId args
+
+  /--
+  If `fvarId` is exclusive, its first `n` fields are released and the object itself becomes the
+  result; otherwise its reference count is decremented and the result is `lean_box(0)`.
+  -/
+  emitReset (n : Nat) (fvarId : FVarId) : EmitM Unit := do
+    emit "if("; emitCApp1 "lean_is_exclusive" fvarId; emit ")"
+    withEmitBlock do
+      for i in 0...n do
+        emitCApp2 "lean_ctor_release" fvarId i; emitLn ";"
+      withEmitAssignment do emit fvarId
+    emit "else"
+    withEmitBlock do
+      emitCApp1 "lean_dec_ref" fvarId; emitLn ";"
+      withEmitAssignment do emit "lean_box(0)"
+
+  /--
+  If `fvarId` is a scalar a fresh constructor object is allocated; otherwise `fvarId` is reused
+  (and re-tagged when `update` is set). Afterwards the fields are set to `args`.
+  -/
+  emitReuse (fvarId : FVarId) (info : CtorInfo) (update : Bool) (args : Array (Arg .impure)) :
+      EmitM Unit := do
+    emit "if("; emitCApp1 "lean_is_scalar" fvarId; emit ")"
+    withEmitBlock do
+      withEmitAssignment do emitAllocCtor info
+    emit "else"
+    withEmitBlock do
+      withEmitAssignment do emit fvarId
+      if update then
+        emitCApp2 "lean_ctor_set_tag" decl.fvarId info.cidx; emitLn ";"
+    emitCtorSetArgs decl.fvarId args
+
+  emitOproj (i : Nat) (fvarId : FVarId) : EmitM Unit := do
+    withEmitAssignment do
+      emitCApp2 "lean_ctor_get" fvarId i
+
+  emitUproj (i : Nat) (fvarId : FVarId) : EmitM Unit := do
+    withEmitAssignment do
+      emitCApp2 "lean_ctor_get_usize" fvarId i
+
+  /-- The accessor function is selected by the type of the variable being assigned. -/
+  emitSproj (n : Nat) (offset : Nat) (fvarId : FVarId) : EmitM Unit := do
+    withEmitAssignment do
+      emitCApp2 decl.type.sprojOpName fvarId (offsetExpression n offset)
+
+  /--
+  Emits the C expression that denotes the declaration `f`: a cast of its symbol for simple
+  ground declarations, a once-guarded read for closed terms, and the plain C name otherwise.
+  -/
+  emitLeanFunReference (ty : Expr) (f : Name) : EmitM Unit := do
+    let env ← getEnv
+    if isSimpleGroundDecl env f then
+      emit s!"((lean_object*)({← toCName f}))"
+    else if isClosedTermName env f then
+      let cname ← toCName f
+      let cnameRef := s!"&{cname}"
+      let tokenRef := s!"&{toOnceTokenName cname}"
+      let initName ← toCInitName f
+      emitCApp3 ty.closedTermReadOpName cnameRef tokenRef initName
+    else
+      emitCName f
+
+  /--
+  Emits the full application of `fn` to `args`, taking a C extern entry of `fn` into account.
+  -/
+  emitFap (fn : Name) (args :
Array (Arg .impure)) : EmitM Unit := do
+    let some sig ← getImpureSignature? fn | unreachable!
+    let ps := sig.params
+    withEmitAssignment do
+      match getExternAttrData? (← getEnv) fn |>.bind (getExternEntryFor · `c) with
+      | some (.standard _ fn) =>
+        -- Extern function: arguments of void and erased parameters are not passed.
+        let (_, args) :=
+          ps.zip args
+          |>.filter (fun (p, _) => !(p.type.isVoid || p.type.isErased))
+          |>.unzip
+        emit fn; emit "("
+        for h : i in 0...args.size do
+          if i > 0 then emit ", "
+          emit args[i]
+        emit ")"
+      | some (.inline _ pat) =>
+        -- Inline pattern: expanded with the C renderings of all arguments.
+        emit (expandExternPattern pat (← toStringArgs args))
+      | some .opaque | none =>
+        -- Ordinary Lean function: only arguments of void parameters are dropped. Without any
+        -- arguments just the reference is emitted, with no call parentheses.
+        emitLeanFunReference decl.type fn
+        if args.size > 0 then
+          let (_, args) :=
+            ps.zip args
+            |>.filter (fun (p, _) => !p.type.isVoid)
+            |>.unzip
+          emit "("; emitArgs args; emit ")"
+      | _ => throwError s!"failed to emit extern application '{fn}'"
+
+  /--
+  Allocates a closure for `fn` with the arity taken from its signature and `args.size` fixed
+  arguments, then stores every argument with `lean_closure_set`.
+  -/
+  emitPap (fn : Name) (args : Array (Arg .impure)) : EmitM Unit := do
+    let some sig ← getImpureSignature? fn | unreachable!
+    let arity := sig.params.size
+    withEmitAssignment do
+      emitCApp3 "lean_alloc_closure" s!"(void*)({← toCName fn})" arity args.size
+    for h : i in 0...args.size do
+      let arg := args[i]
+      emitCApp3 "lean_closure_set" decl.fvarId i arg; emitLn ";"
+
+  /--
+  Applies the closure `fvarId` to `args` (which must be non-empty): via `lean_apply_<n>` for up
+  to `closureMaxArgs` arguments, via `lean_apply_m` and a temporary argument array beyond that.
+  -/
+  emitAp (fvarId : FVarId) (args : Array (Arg .impure)) : EmitM Unit := do
+    assert!
!args.isEmpty
+    if args.size > closureMaxArgs then
+      withEmitBlock do
+        emit "lean_object* _aargs[] = {"; emitArgs args; emitLn "};"
+        withEmitAssignment do
+          emitCApp3 "lean_apply_m" fvarId args.size "_aargs"
+    else
+      withEmitAssignment do
+        emit s!"lean_apply_{args.size}("; emit fvarId; emit ", "; emitArgs args; emit ")"
+
+  /-- Boxes `fvarId`; `ty` selects the boxing function. -/
+  emitBox (ty : Expr) (fvarId : FVarId) : EmitM Unit := do
+    withEmitAssignment do
+      emitCApp1 ty.boxOpName fvarId
+
+  /-- Unboxes `fvarId`; the type of the variable being assigned selects the function. -/
+  emitUnbox (fvarId : FVarId) : EmitM Unit := do
+    withEmitAssignment do
+      emitCApp1 decl.type.unboxOpName fvarId
+
+  emitIsShared (fvarId : FVarId) : EmitM Unit := do
+    withEmitAssignment do
+      emit "!"; emitCApp1 "lean_is_exclusive" fvarId
+
+  /--
+  Emits a literal. Natural numbers below `UInt32.size` go through `lean_unsigned_to_nat`, larger
+  ones are parsed from their decimal string at runtime; strings are built from a quoted C
+  literal together with their byte size and length.
+  -/
+  emitLit (v : LitValue) : EmitM Unit := do
+    withEmitAssignment do
+      match v with
+      | .uint8 v | .uint16 v | .uint32 v => emit v
+      | .uint64 v => emit v; emit "ULL"
+      | .usize v => emit "((size_t)"; emit v; emit "ULL)"
+      | .nat v =>
+        if v < UInt32.size then
+          emit "lean_unsigned_to_nat("; emit v; emit "u)"
+        else
+          emit "lean_cstr_to_nat(\""; emit v; emit "\")"
+      | .str v =>
+        emitCApp3 "lean_mk_string_unchecked" (quoteString v) v.utf8ByteSize v.length
+
+  emitErased : EmitM Unit := do
+    withEmitAssignment do
+      emit "lean_box(0)"
+
+  emitLhs (binderName : Name) : EmitM Unit := do
+    emit binderName; emit " = "
+
+  /-- Emits `<variable of decl> = `, then whatever `x` emits, then `;` and a line break. -/
+  @[inline]
+  withEmitAssignment {α : Type} (x : EmitM α) : EmitM α := do
+    emitLhs decl.binderName
+    let ret ← x
+    emitLn ";"
+    return ret
+
+/--
+Emits a self tail call as assignments to the parameters of the current function followed by
+`goto _start;`. `decl` must be a full application with exactly one argument per parameter.
+Void parameters are skipped, as are assignments of a parameter to itself.
+-/
+def emitTailCall (decl : LetDecl .impure) : EmitM Unit := do
+  let .fap _ args := decl.value | unreachable!
+  let ps ← getCurrParams
+  assert! ps.size == args.size
+  let (ps, args) := ps.zip args |>.filter (fun (p, _) => !p.type.isVoid) |>.unzip
+  if overwriteParam ps args then
+    -- A later argument reads a parameter that an earlier assignment would already have
+    -- overwritten: stage all new values in temporaries first.
+    withEmitBlock do
+      for h : i in 0...ps.size do
+        let p := ps[i]
+        let arg := args[i]!
+        unless paramEqArg p arg do
+          emit p.type.toCType; emit " _tmp_"; emit i; emit " = "; emit arg; emitLn ";"
+
+      -- Second pass: copy the staged values into the parameters.
+      for h : i in 0...ps.size do
+        let p := ps[i]
+        let arg := args[i]!
+        unless paramEqArg p arg do
+          emit p.binderName; emit " = _tmp_"; emit i; emitLn ";"
+  else
+    -- No assignment clobbers a value that is still needed, so assign in place.
+    for p in ps, arg in args do
+      unless paramEqArg p arg do
+        emit p.binderName; emit " = "; emit arg; emitLn ";"
+  emitLn "goto _start;"
+where
+  /--
+  Given `[p_0, ..., p_{n-1}]`, `[arg_0, ..., arg_{n-1}]`, representing the assignments
+  ```
+  p_0 := arg_0,
+  ...
+  p_{n-1} := arg_{n-1}
+  ```
+  Return true iff we have `(i, j)` where `j > i`, and `arg_j == p_i`.
+  That is, we have
+  ```
+  p_i := arg_i,
+  ...
+  p_j := p_i, -- p_i was overwritten above
+  ```
+  -/
+  overwriteParam (ps : Array (Param .impure)) (args : Array (Arg .impure)) : Bool := Id.run do
+    for h1 : i in 0...ps.size do
+      let p := ps[i]
+      for h2 : j in (i+1)...args.size do
+        if paramEqArg p args[j] then
+          return true
+    return false
+
+  /-- `true` iff `arg` is the variable of the parameter `p`; an erased argument never is. -/
+  paramEqArg (p : Param .impure) (arg : Arg .impure) : Bool :=
+    match arg with
+    | .fvar fvarId => p.fvarId == fvarId
+    | .erased => false
+
+
+mutual
+
+/--
+Emits the statements along the spine of `code` up to and including its terminator.
+Join point declarations are skipped here (their bodies are emitted by `emitJoinPoints`), self
+tail calls are turned into jumps, and `inc`/`dec` marked as persistent emit nothing.
+-/
+private partial def emitBasicBlock (code : Code .impure) : EmitM Unit := do
+  match code with
+  | .jp (k := k) .. => emitBasicBlock k
+  | .let decl k =>
+    if ← isTailCall code then
+      emitTailCall decl
+    else
+      emitLetDecl decl
+      emitBasicBlock k
+  | .inc fvarId n check persistent k =>
+    unless persistent do emitInc fvarId n check
+    emitBasicBlock k
+  | .dec fvarId n check persistent k =>
+    unless persistent do emitDec fvarId n check
+    emitBasicBlock k
+  | .del fvarId k =>
+    emitDel fvarId
+    emitBasicBlock k
+  | .setTag fvarId cidx k =>
+    emitSetTag fvarId cidx
+    emitBasicBlock k
+  | .oset fvarId i y k =>
+    emitOset fvarId i y
+    emitBasicBlock k
+  | .uset fvarId i y k =>
+    emitUset fvarId i y
+    emitBasicBlock k
+  | .sset fvarId i offset y ty k =>
+    emitSset fvarId i offset y ty
+    emitBasicBlock k
+  | .cases cs => emitCases cs
+  | .return fvarId => emitReturn fvarId
+  | .jmp fvarId args => emitJmp fvarId args
+  | .unreach ..
=> emitUnreach +where + emitInc (fvarId : FVarId) (n : Nat) (check : Bool) : EmitM Unit := do + if n == 1 then + let incFn := if check then "lean_inc" else "lean_inc_ref" + emitCApp1 incFn fvarId + else + let incFn := if check then "lean_inc_n" else "lean_inc_ref_n" + emitCApp2 incFn fvarId n + emitLn ";" + + emitDec (fvarId : FVarId) (n : Nat) (check : Bool) : EmitM Unit := do + -- Anything else is unsupported at the moment + assert! n == 1 + let decFn := if check then "lean_dec" else "lean_dec_ref" + emitCApp1 decFn fvarId + emitLn ";" + + emitDel (fvarId : FVarId) : EmitM Unit := do + emitCApp1 "lean_del_object" fvarId + emitLn ";" + + emitSetTag (fvarId : FVarId) (cidx : Nat) : EmitM Unit := do + emitCApp2 "lean_ctor_set_tag" fvarId cidx + emitLn ";" + + emitOset (fvarId : FVarId) (i : Nat) (y : Arg .impure) : EmitM Unit := do + emitCApp3 "lean_ctor_set" fvarId i y + emitLn ";" + + emitUset (fvarId : FVarId) (i : Nat) (y : FVarId) : EmitM Unit := do + emitCApp3 "lean_ctor_set_usize" fvarId i y + emitLn ";" + + emitSset (fvarId : FVarId) (i : Nat) (offset : Nat) (y : FVarId) (ty : Expr) : EmitM Unit := do + emitCApp3 ty.ssetOpName fvarId (offsetExpression i offset) y + emitLn ";" + + isIf (cs : Cases .impure) : EmitM (Option (Nat × Code .impure × Code .impure)) := do + if h : cs.alts.size = 2 then + match cs.alts[0] with + | .ctorAlt info k => return some (info.cidx, k, cs.alts[1].getCode) + | _ => return none + else + return none + + emitTag (fvarId : FVarId) : EmitM Unit := do + let type ← getType fvarId + if type.isObj then do + emitCApp1 "lean_obj_tag" fvarId + else + emit fvarId + + emitCases (cs : Cases .impure) : EmitM Unit := do + match ← isIf cs with + | some (tag, t, e) => + emit "if ("; emitTag cs.discr; emit " == "; emit tag; emitLn ")"; + emitCode t + emitLn "else" + emitCode e + | none => + emit "switch("; emitTag cs.discr; emitLn ")" + withEmitBlock do + let alts := ensureHasDefault cs.alts + -- TODO: consider UB if no default? 
+ alts.forM fun alt => do + match alt with + | .ctorAlt info k => + emit "case "; emit info.cidx; emitLn ":" + emitCode k + | .default k => + emitLn "default: "; + emitCode k + + emitReturn (fvarId : FVarId) : EmitM Unit := do + emit "return "; emit fvarId; emitLn ";" + + emitJmp (fvarId : FVarId) (args : Array (Arg .impure)) : EmitM Unit := do + let some jpDecl ← findFunDecl? (pu := .impure) fvarId | unreachable! + let ps := jpDecl.params + if args.size != ps.size then + throwError "invalid jump" + for arg in args, p in ps do + emit p.binderName; emit " = "; emit arg; emitLn ";" + emit "goto "; emit fvarId; emitLn ";" + + emitUnreach : EmitM Unit := do + emitLn "lean_internal_panic_unreachable();" + +private partial def emitJoinPoints (code : Code .impure) : EmitM Unit := do + match code with + | .jp decl k => + emit decl.binderName; emitLn ":" + emitCode decl.value + emitJoinPoints k + | .let (k := k) .. | .del (k := k) .. | .dec (k := k) .. | .inc (k := k) .. | .setTag (k := k) .. + | .sset (k := k) .. | .uset (k := k) .. | .oset (k := k) .. => emitJoinPoints k + | .cases .. | .return .. | .jmp .. | .unreach .. => return () + +private partial def emitCode (code : Code .impure) : EmitM Unit := do + withEmitBlock do + let declared ← declareVars code + if declared then emitLn "" + emitBasicBlock code + emitJoinPoints code + +end + +def emitDecl (decl : Decl .impure) : EmitM Unit := do + let env ← getEnv + if hasInitAttr env decl.name || isSimpleGroundDecl env decl.name then + return () + match decl.value with + | .extern .. 
=> return () + | .code code => + let baseName ← toCName decl.name + let ps := decl.params + if ps.isEmpty then + emit "static " + else + -- make the symbol visible to the interpreter for native execution + emit "LEAN_EXPORT " + + emit decl.type.toCType; emit " " + + if ps.isEmpty then + emitCInitName decl.name + emit "(void)" + else + emit baseName + emit "(" + let ps := paramsWithoutVoid ps + if ps.size > closureMaxArgs && isBoxedName decl.name then + emit "lean_object** _args" + else + ps.size.forM fun i _ => do + if i > 0 then emit ", " + let p := ps[i] + emit p.type.toCType; emit " "; emit p.binderName + emit ")" + + withEmitBlock do + if ps.size > closureMaxArgs && isBoxedName decl.name then + ps.size.forM fun i _ => do + let p := ps[i] + emit "lean_object* "; emit p.binderName; emit " = _args["; emit i; emitLn "];" + -- goto marker for tail recursion + emitLn "_start:" + withReader (fun ctx => { ctx with currFn := decl.name, currParams := ps }) do + emitCode code + +def emitFns : EmitM Unit := do + (← getLocalDecls).forM go +where + go (decl : Decl .impure) : EmitM Unit := do + let decl ← decl.internalize (uniqueIdents := true) + try + emitDecl decl + catch err => + throwError m!"{err.toMessageData}\ncompiling:\n{decl.name}" + +def withErrRet (emitIORes : EmitM Unit) : EmitM Unit := do + emit "res = "; emitIORes; emitLn ";" + emitLn "if (lean_io_result_is_error(res)) return res;" + +def emitMarkPersistent (decl : Decl .impure) : EmitM Unit := do + if decl.type.isObj then + emitCApp1 "lean_mark_persistent" (← toCName decl.name); emitLn ";" + +def emitDeclInit (decl : Decl .impure) (isBuiltin : Bool) : EmitM Unit := do + let env ← getEnv + if (isBuiltin && isIOUnitBuiltinInitFn env decl.name) || isIOUnitInitFn env decl.name then + withErrRet do + emitCName decl.name; emit "()" + emitLn "lean_dec_ref(res);" + else if decl.params.isEmpty then + if let some initFn := (guard isBuiltin *> getBuiltinInitFnNameFor? env decl.name) <|> getInitFnNameFor? 
env decl.name then + withErrRet do + emitCName initFn; emit "()" + emitCName decl.name + if decl.type.isScalar then + emitLn <| " = " ++ decl.type.unboxOpName ++ "(lean_io_result_get_value(res));" + else + emitLn " = lean_io_result_get_value(res);" + emitMarkPersistent decl + emitLn "lean_dec_ref(res);" + else if !(isClosedTermName env decl.name || isSimpleGroundDecl env decl.name) then + emitCName decl.name; emit " = "; emitCInitName decl.name; emitLn "();" + emitMarkPersistent decl + +def emitInitFn (phases : IRPhases) : EmitM Unit := do + let env ← getEnv + let impInitFns ← env.imports.filterMapM fun imp => do + if phases != .all && imp.isMeta != (phases == .comptime) then + return none + let some idx := env.getModuleIdx? imp.module + | throwError "(internal) import without module index" -- should be unreachable + let pkg? := env.getModulePackageByIdx? idx + let fn := mkModuleInitializationFunctionName (phases := if phases == .all then .all else if imp.isMeta then .runtime else phases) imp.module pkg? + emitLn s!"lean_object* {fn}(uint8_t builtin);" + return some fn + let initialized := s!"_G_{mkModuleInitializationPrefix phases}initialized" + emitLns [ + s!"static bool {initialized} = false;", + s!"LEAN_EXPORT lean_object* {← getModInitFn (phases := phases)}(uint8_t builtin) \{", + "lean_object * res;", + s!"if ({initialized}) return lean_io_result_mk_ok(lean_box(0));", + s!"{initialized} = true;" + ] + impInitFns.forM fun fn => do + withErrRet do + emit s!"{fn}(builtin)" + emitLn "lean_dec_ref(res);" + for decl in (← getLocalDecls) do + if phases == .all || (phases == .comptime) == isMarkedMeta env decl.name then + emitDeclInit decl (isBuiltin := phases != .comptime) + emitLn "return lean_io_result_mk_ok(lean_box(0));" + emitLn "}" + +/-- Init function used before phase split under module system, keep for compatibility. 
-/ +def emitLegacyInitFn : EmitM Unit := do + let env ← getEnv + let impInitFns ← env.imports.filterMapM fun imp => do + let some idx := env.getModuleIdx? imp.module + | throwError "(internal) import without module index" -- should be unreachable + let pkg? := env.getModulePackageByIdx? idx + let fn := mkModuleInitializationFunctionName imp.module pkg? + emitLn s!"lean_object* {fn}(uint8_t builtin);" + return some fn + let initialized := s!"_G_initialized" + emitLns [ + s!"static bool {initialized} = false;", + s!"LEAN_EXPORT lean_object* {← getModInitFn (phases := .all)}(uint8_t builtin) \{", + "lean_object * res;", + s!"if ({initialized}) return lean_io_result_mk_ok(lean_box(0));", + s!"{initialized} = true;" + ] + impInitFns.forM fun fn => do + withErrRet do + emit s!"{fn}(builtin)" + emitLn "lean_dec_ref(res);" + withErrRet do + emit s!"{← getModInitFn (phases := .runtime)}(builtin)" + emitLn "lean_dec_ref(res);" + withErrRet do + emit s!"{← getModInitFn (phases := .comptime)}(builtin)" + emitLn "lean_dec_ref(res);" + emitLn s!"return {← getModInitFn (phases := .all)}(builtin);" + emitLn "}" + +def emitMainFnIfNeeded : EmitM Unit := do + if let some mainFn ← hasMainFn then + emitMainFn mainFn +where + hasMainFn : EmitM (Option (Decl .impure)) := do + return (← getLocalDecls).find? (·.name == `main) + + emitMainFn (decl : Decl .impure) : EmitM Unit := do + let .code .. 
:= decl.value | throwError "Expected Lean function declaration as `main`" + let ps := decl.params + if ps.size != 1 && ps.size != 2 then + throwError "invalid main function, incorrect arity when generating code" + let env ← getEnv + let usesLeanAPI := usesModuleFrom env `Lean + emitLns [ + "char ** lean_setup_args(int argc, char ** argv);", + if usesLeanAPI then "void lean_initialize();" else "void lean_initialize_runtime_module();", + "#if defined(WIN32) || defined(_WIN32)", + "#include <windows.h>", + "#endif", + "int main(int argc, char ** argv) {", + "#if defined(WIN32) || defined(_WIN32)", + " SetErrorMode(SEM_FAILCRITICALERRORS);", + " SetConsoleOutputCP(CP_UTF8);", + "#endif", + " lean_object* in; lean_object* res;", + " argv = lean_setup_args(argc, argv);", + if usesLeanAPI then " lean_initialize();" else " lean_initialize_runtime_module();", + s!" res = {← getModInitFn (phases := if env.header.isModule then .runtime else .all)}(1 /* builtin */);", + " lean_io_mark_end_initialization();", + " if (lean_io_result_is_ok(res)) {", + " lean_dec_ref(res);", + " lean_init_task_manager();", + ] + if ps.size == 2 then + emitLns [ + " in = lean_box(0);", + " int i = argc;", + " while (i > 1) {", + " lean_object* n;", + " i--;", + " n = lean_alloc_ctor(1,2,0); lean_ctor_set(n, 0, lean_mk_string(argv[i])); lean_ctor_set(n, 1, in);", + " in = n;", + " }" + ] + emitLn <| " res = " ++ leanMainFn ++ "(in);" + else + emitLn <| " res = " ++ leanMainFn ++ "();" + emitLn " }" + -- `IO _` + let retTy := env.find? `main |>.get! |>.type |>.getForallBody + -- either `UInt32` or `(P)Unit` + let retTy := retTy.appArg! 
+ let hasExitCode := retTy.isConstOf ``UInt32 + -- finalize at least the task manager to avoid leak sanitizer false positives from tasks outliving the main thread + emitLns [ + " lean_finalize_task_manager();", + " if (lean_io_result_is_ok(res)) {", + " int ret = " ++ if hasExitCode then "lean_unbox_uint32(lean_io_result_get_value(res));" else "0;", + " lean_dec_ref(res);", + " return ret;", + " } else {", + " lean_io_result_show_error(res);", + " lean_dec_ref(res);", + " return 1;", + " }"] + emitLn "}" + +def emitFileFooter : EmitM Unit := + emitLns [ + "#ifdef __cplusplus", + "}", + "#endif" + ] + +def main : EmitM Unit := do + emitFileHeader + emitFnDecls + emitFns + if (← getEnv).header.isModule then + emitInitFn (phases := .runtime) + emitInitFn (phases := .comptime) + emitLegacyInitFn + else + emitInitFn (phases := .all) + emitMainFnIfNeeded + emitFileFooter + +public def emitCForDecls (modName : Name) (decls : Array Name) : CoreM String := do + let (localDecls, otherModuleDecls) ← collectUsedDecls decls + let env ← getEnv + let indexMap := getImpureDeclIndices env decls + let localDecls := localDecls.qsort fun l r => indexMap[l.name]! < indexMap[r.name]! + let (_, { buf, .. }) ← + main + |>.run { localDecls, otherModuleDecls, modName } + |>.run {} + |>.run (phase := .impure) + return buf + +public def emitC (modName : Name) : CoreM String := do + emitCForDecls modName (← getLocalImpureDecls) + +end Lean.Compiler.LCNF diff --git a/src/Lean/Compiler/LCNF/EmitUtil.lean b/src/Lean/Compiler/LCNF/EmitUtil.lean new file mode 100644 index 0000000000..9363d52307 --- /dev/null +++ b/src/Lean/Compiler/LCNF/EmitUtil.lean @@ -0,0 +1,56 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. 
+Authors: Henrik Böving +-/ +module + +prelude +public import Lean.Compiler.LCNF.CompilerM +import Lean.Compiler.LCNF.PhaseExt +import Lean.Compiler.InitAttr + +namespace Lean.Compiler.LCNF + +private structure CollectUsedDeclsState where + visited : NameSet := {} + localDecls : Array (Decl .impure) := #[] + extSigs : Array (Signature .impure) := #[] + +/-- +Find all declarations that the declarations in `decls` transitively depend on. They are returned +partitioned into the declarations from the current module and declarations from other modules. +-/ +public partial def collectUsedDecls (decls : Array Name) : + CoreM (Array (Decl .impure) × Array (Signature .impure)) := do + let (_, state) ← go decls |>.run {} + return (state.localDecls, state.extSigs) +where + go (names : Array Name) : StateRefT CollectUsedDeclsState CoreM Unit := + names.forM fun name => do + if (← get).visited.contains name then return + modify fun s => { s with visited := s.visited.insert name } + if let some decl ← getLocalImpureDecl? name then + modify fun s => { s with localDecls := s.localDecls.push decl } + decl.value.forCodeM (·.forM visitCode) + let env ← getEnv + if let some initializer := getBuiltinInitFnNameFor? env decl.name <|> getInitFnNameFor? env decl.name then + go #[initializer] + else if let some sig ← getImpureSignature? name then + modify fun s => { s with extSigs := s.extSigs.push sig } + else + panic! s!"collectUsedDecls: could not find declaration or signature for '{name}'" + + visitCode (code : Code .impure) : StateRefT CollectUsedDeclsState CoreM Unit := do + match code with + | .let decl _ => + match decl.value with + | .const declName .. | .fap declName .. | .pap declName .. 
=> + go #[declName] + | _ => return () + | _ => return () + +public def usesModuleFrom (env : Environment) (modulePrefix : Name) : Bool := + env.header.modules.any fun mod => mod.irPhases != .comptime && modulePrefix.isPrefixOf mod.module + +end Lean.Compiler.LCNF diff --git a/src/Lean/Compiler/LCNF/Passes.lean b/src/Lean/Compiler/LCNF/Passes.lean index 12c1c50879..90f140fb7b 100644 --- a/src/Lean/Compiler/LCNF/Passes.lean +++ b/src/Lean/Compiler/LCNF/Passes.lean @@ -85,7 +85,9 @@ def saveImpure : Pass where phaseOut := .impure name := `saveImpure run decls := decls.mapM fun decl => do - (← normalizeFVarIds decl).saveImpure + let decl ← normalizeFVarIds decl + decl.saveImpure + modifyEnv fun env => recordFinalImpureDecl env decl.name return decl shouldAlwaysRunCheck := true @@ -160,8 +162,8 @@ def builtinPassManager : PassManager := { pushProj (occurrence := 1), detectSimpleGround, inferVisibility (phase := .impure), - saveImpure, -- End of impure phase toposortPass, + saveImpure, -- End of impure phase ] } diff --git a/src/Lean/Compiler/LCNF/PhaseExt.lean b/src/Lean/Compiler/LCNF/PhaseExt.lean index 35321a1cfc..627d694c2c 100644 --- a/src/Lean/Compiler/LCNF/PhaseExt.lean +++ b/src/Lean/Compiler/LCNF/PhaseExt.lean @@ -23,15 +23,15 @@ namespace Lean.Compiler.LCNF /-- Set of public declarations whose base bodies should be exported to other modules -/ -private builtin_initialize baseTransparentDeclsExt : EnvExtension (List Name × NameSet) ← mkDeclSetExt +private builtin_initialize baseTransparentDeclsExt : EnvExtension (List Name × NameSet) ← mkOrderedDeclSetExt /-- Set of public declarations whose mono bodies should be exported to other modules -/ -private builtin_initialize monoTransparentDeclsExt : EnvExtension (List Name × NameSet) ← mkDeclSetExt +private builtin_initialize monoTransparentDeclsExt : EnvExtension (List Name × NameSet) ← mkOrderedDeclSetExt /-- Set of public declarations whose impure bodies should be exported to other modules -/ -private 
builtin_initialize impureTransparentDeclsExt : EnvExtension (List Name × NameSet) ← mkDeclSetExt +private builtin_initialize impureTransparentDeclsExt : EnvExtension (List Name × NameSet) ← mkOrderedDeclSetExt private def getTransparencyExt : Phase → EnvExtension (List Name × NameSet) | .base => baseTransparentDeclsExt @@ -171,6 +171,9 @@ def getMonoDecl? (declName : Name) : CoreM (Option (Decl .pure)) := do def getLocalImpureDecl? (declName : Name) : CoreM (Option (Decl .impure)) := do return impureExt.getState (← getEnv) |>.find? declName +def getLocalImpureDecls : CoreM (Array Name) := do + return impureExt.getState (← getEnv) |>.toArray |>.map (·.fst) + def getImpureSignature? (declName : Name) : CoreM (Option (Signature .impure)) := do return getSigCore? (← getEnv) impureSigExt declName @@ -224,4 +227,23 @@ def getLocalDecl? (declName : Name) : CompilerM (Option ((pu : Purity) × Decl p let some decl ← getLocalDeclAt? declName (← getPhase) | return none return some ⟨_, decl⟩ +builtin_initialize declOrderExt : EnvExtension (List Name × NameSet) ← mkOrderedDeclSetExt + +def recordFinalImpureDecl (env : Environment) (name : Name) : Environment := + declOrderExt.modifyState env fun s => + (name :: s.1, s.2.insert name) + +def getImpureDeclIndices (env : Environment) (targets : Array Name) : Std.HashMap Name Nat := Id.run do + let (names, set) := declOrderExt.getState env + let mut map := Std.HashMap.emptyWithCapacity set.size + let targetSet := Std.HashSet.ofArray targets + let mut i := set.size + for name in names do + if targetSet.contains name then + map := map.insert name i + assert! i != 0 + i := i - 1 + assert! 
map.size == targets.size + return map + end Lean.Compiler.LCNF diff --git a/src/Lean/Compiler/LCNF/PublicDeclsExt.lean b/src/Lean/Compiler/LCNF/PublicDeclsExt.lean index 95d1c82917..9c38ca4f17 100644 --- a/src/Lean/Compiler/LCNF/PublicDeclsExt.lean +++ b/src/Lean/Compiler/LCNF/PublicDeclsExt.lean @@ -10,15 +10,18 @@ public import Lean.Environment namespace Lean.Compiler.LCNF -/-- Creates a replayable local environment extension holding a name set. -/ -public def mkDeclSetExt : IO (EnvExtension (List Name × NameSet)) := +/-- +Creates a replayable local environment extension holding a name set and the list of names in the +order they were added to the set. +-/ +public def mkOrderedDeclSetExt : IO (EnvExtension (List Name × NameSet)) := registerEnvExtension (mkInitial := pure ([], {})) (asyncMode := .sync) (replay? := some <| fun oldState newState _ s => let newEntries := newState.1.take (newState.1.length - oldState.1.length) - newEntries.foldl (init := s) fun s n => - if s.1.contains n then + newEntries.reverse.foldl (init := s) fun s n => + if s.2.contains n then s else (n :: s.1, if newState.2.contains n then s.2.insert n else s.2)) @@ -26,7 +29,7 @@ public def mkDeclSetExt : IO (EnvExtension (List Name × NameSet)) := /-- Set of declarations to be exported to other modules; visibility shared by base/mono/IR phases. 
-/ -private builtin_initialize publicDeclsExt : EnvExtension (List Name × NameSet) ← mkDeclSetExt +private builtin_initialize publicDeclsExt : EnvExtension (List Name × NameSet) ← mkOrderedDeclSetExt public def isDeclPublic (env : Environment) (declName : Name) : Bool := Id.run do if !env.header.isModule then diff --git a/src/Lean/Compiler/LCNF/SimpCase.lean b/src/Lean/Compiler/LCNF/SimpCase.lean index abe521bcbd..adb444da26 100644 --- a/src/Lean/Compiler/LCNF/SimpCase.lean +++ b/src/Lean/Compiler/LCNF/SimpCase.lean @@ -115,6 +115,16 @@ def Decl.simpCase (decl : Decl .impure) : CompilerM (Decl .impure) := do let value ← decl.value.mapCodeM (·.simpCase) return { decl with value } +public def ensureHasDefault (alts : Array (Alt .impure)) : Array (Alt .impure) := + if alts.any (· matches .default ..) then + alts + else if alts.size < 2 then + alts + else + let last := alts.back! + let alts := alts.pop + alts.push (.default last.getCode) + public def simpCase : Pass := Pass.mkPerDeclaration `simpCase .impure Decl.simpCase 0 diff --git a/src/Lean/Compiler/LCNF/Toposort.lean b/src/Lean/Compiler/LCNF/Toposort.lean index ee710dc246..166308e5e5 100644 --- a/src/Lean/Compiler/LCNF/Toposort.lean +++ b/src/Lean/Compiler/LCNF/Toposort.lean @@ -8,6 +8,7 @@ module prelude public import Lean.Compiler.LCNF.CompilerM public import Lean.Compiler.LCNF.PassManager +import Lean.Compiler.InitAttr /-! This module "topologically sorts" an SCC of decls (an SCC of decls in the pipeline may in fact @@ -42,8 +43,11 @@ where if (← get).seen.contains decl.name then return () + let env ← getEnv modify fun s => { s with seen := s.seen.insert decl.name } decl.value.forCodeM (·.forM visitConsts) + if let some initializer := getBuiltinInitFnNameFor? env decl.name <|> getInitFnNameFor? 
env decl.name then + visitConst initializer modify fun s => { s with order := s.order.push decl } visitConsts (code : Code pu) : ToposortM pu Unit := do @@ -51,15 +55,16 @@ where | .let decl _ => match decl.value with | .const declName .. | .fap declName .. | .pap declName .. => - if let some d := (← read).declsMap[declName]? then - process d + visitConst declName | _ => return () | _ => return () + visitConst (declName : Name) : ToposortM pu Unit := do + if let some d := (← read).declsMap[declName]? then + process d + public def toposortDecls (decls : Array (Decl pu)) : CompilerM (Array (Decl pu)) := do - let (externDecls, otherDecls) := decls.partition (fun decl => decl.value matches .extern ..) - let otherDecls ← toposort otherDecls - return externDecls ++ otherDecls + toposort decls public def toposortPass : Pass where phase := .impure diff --git a/src/Lean/Shell.lean b/src/Lean/Shell.lean index 7a1246833f..dd5c809b19 100644 --- a/src/Lean/Shell.lean +++ b/src/Lean/Shell.lean @@ -10,7 +10,7 @@ import Lean.Elab.Frontend import Lean.Elab.ParseImportsFast import Lean.Server.Watchdog import Lean.Server.FileWorker -import Lean.Compiler.IR.EmitC +import Lean.Compiler.LCNF.EmitC import Init.System.Platform /- Lean companion to `shell.cpp` -/ @@ -545,7 +545,8 @@ def shellMain (args : List String) (opts : ShellOptions) : IO UInt32 := do | IO.eprintln s!"failed to create '{c}'" return 1 profileitIO "C code generation" opts.leanOpts do - let data ← IR.emitC env mainModuleName + let data ← Compiler.LCNF.emitC mainModuleName + |>.toIO' { fileName, fileMap := default } { env } out.write data.toUTF8 if let some bc := opts.bcFileName? then initLLVM