preserve in early pipeline

This commit is contained in:
Henrik Böving
2026-03-06 14:04:42 +00:00
parent 57df23f27e
commit 0f2532f683
4 changed files with 181 additions and 28 deletions

View File

@@ -93,7 +93,7 @@ where
match type with
| .forallE _ d b _ =>
let d := d.instantiateRev xs
let p mkAuxParam d
let p mkAuxParam d (isMarkedBorrowed d)
go b (xs.push (.fvar p.fvarId)) (ps.push p)
| _ =>
let type := type.instantiateRev xs

View File

@@ -159,7 +159,7 @@ def toDecl (declName : Name) : CompilerM (Decl .pure) := do
/- Recall that `inlineMatchers` may have exposed `ite`s and `dite`s which are tagged as `[macro_inline]`. -/
let value macroInline value
return (type, value)
let code toLCNF value
let code toLCNF value type
let mut decl if let .fun decl (.return _) := code then
eraseFunDecl decl (recursive := false)
pure { name := declName, params := decl.params, type, value := .code decl.value, levelParams := info.levelParams, safe, inlineAttr? : Decl .pure }

View File

@@ -206,6 +206,7 @@ structure Context where
eventually.
-/
ignoreNoncomputable : Bool := false
expectedType : Option Expr
structure State where
/-- Local context containing the original Lean types (not LCNF ones). -/
@@ -265,8 +266,18 @@ def toCode (result : Arg .pure) : M (Code .pure) := do
let fvarId mkAuxLetDecl .erased
seqToCode ( get).seq (.return fvarId)
def run (x : M α) : CompilerM α :=
x.run {} |>.run' {}
def run (expectedType : Expr) (x : M α) : CompilerM α :=
x.run { expectedType } |>.run' {}
@[inline]
def withExpectedType (e : Option Expr) (x : M α) : M α :=
withReader (fun ctx => { ctx with expectedType := e }) do
x
@[inline]
def withoutExpectedType (x : M α) : M α :=
withExpectedType none do
x
/--
Return true iff `type` is `Sort _` or `As → Sort _`.
@@ -340,9 +351,9 @@ def cleanupBinderName (binderName : Name) : CompilerM Name :=
return binderName
/-- Create a new local declaration using a Lean regular type. -/
def mkParam (binderName : Name) (type : Expr) : M (Param .pure) := do
def mkParam (binderName : Name) (type : Expr) (borrow : Bool := isMarkedBorrowed type) :
M (Param .pure) := do
let binderName cleanupBinderName binderName
let borrow := isMarkedBorrowed type
let type' toLCNFType type
let param LCNF.mkParam binderName type' borrow
modify fun s => { s with lctx := s.lctx.mkLocalDecl param.fvarId binderName type .default }
@@ -361,16 +372,22 @@ def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (type' : Expr) (a
}
return letDecl
def visitLambda (e : Expr) : M (Array (Param .pure) × Expr) :=
go e #[] #[]
def visitLambda (e : Expr) : M (Array (Param .pure) × Expr × Option Expr) := do
go e #[] #[] ( read).expectedType
where
go (e : Expr) (xs : Array Expr) (ps : Array (Param .pure)) := do
go (e : Expr) (xs : Array Expr) (ps : Array (Param .pure)) (eType? : Option Expr) := do
if let .lam binderName type body _ := e then
let type := type.instantiateRev xs
let p mkParam binderName type
go body (xs.push p.toExpr) (ps.push p)
if let some (.forallE _ type' eType _) := eType? then
let borrow := isMarkedBorrowed type || isMarkedBorrowed type'
let p mkParam binderName type borrow
-- no need to instantiate eType, we only ever check if for `isMarkedBorrowed`
go body (xs.push p.toExpr) (ps.push p) (some eType)
else
let p mkParam binderName type
go body (xs.push p.toExpr) (ps.push p) none
else
return (ps, e.instantiateRev xs)
return (ps, e.instantiateRev xs, eType?.map (·.instantiateRev xs))
def visitBoundedLambda (e : Expr) (n : Nat) : M (Array (Param .pure) × Expr) :=
go e n #[] #[]
@@ -446,8 +463,8 @@ Put the given expression in `LCNF`.
- Eta-expand applications of declarations that satisfy `shouldEtaExpand`.
- Put computationally relevant expressions in A-normal form.
-/
partial def toLCNF (e : Expr) : CompilerM (Code .pure) := do
run do toCode ( visit e)
partial def toLCNF (e : Expr) (eType : Expr) : CompilerM (Code .pure) := do
run eType do toCode ( visit e)
where
visitCore (e : Expr) : M (Arg .pure) := withIncRecDepth do
if let some arg := ( get).cache.find? e then
@@ -505,7 +522,7 @@ where
visitAppDefaultConst (f : Expr) (args : Array Expr) : M (Arg .pure) := do
let env getEnv
let .const declName us CSimp.replaceConstant env f | unreachable!
let args args.mapM visitAppArg
let args args.mapM (withoutExpectedType do visitAppArg ·)
if hasNeverExtractAttribute env declName then
modify fun s => {s with shouldCache := false }
letValueToArg <| .const declName us args
@@ -549,10 +566,12 @@ where
let altType c.inferType
return (altType, .default c)
| .ctor ctorName numParams =>
let mut (ps, e) visitBoundedLambda e numParams
let mut (ps, e) withoutExpectedType do
visitBoundedLambda e numParams
if ps.size < numParams then
e etaExpandN e (numParams - ps.size)
let (ps', e') ToLCNF.visitLambda e
let (ps', e', _) withoutExpectedType do
ToLCNF.visitLambda e
ps := ps ++ ps'
e := e'
/-
@@ -609,11 +628,17 @@ where
fieldArgs := fieldArgs.push fieldArg
return fieldArgs
let f := args[casesInfo.altsRange.lower]!
let result visit (mkAppN f fieldArgs)
mkOverApplication result args casesInfo.arity
let arity := casesInfo.arity
if args.size == arity then
visit (mkAppN f fieldArgs)
else
withoutExpectedType do
let result visit (mkAppN f fieldArgs)
mkOverApplication result args casesInfo.arity
else
let mut alts := #[]
let discr visitAppArg args[casesInfo.discrPos]!
let discr withoutExpectedType do
visitAppArg args[casesInfo.discrPos]!
let discrFVarId match discr with
| .fvar discrFVarId => pure discrFVarId
| .erased | .type .. => mkAuxLetDecl .erased
@@ -625,9 +650,11 @@ where
let auxDecl mkAuxParam resultType
pushElement (.cases auxDecl cases)
let result := .fvar auxDecl.fvarId
mkOverApplication result args casesInfo.arity
withoutExpectedType do
mkOverApplication result args casesInfo.arity
visitCtor (arity : Nat) (e : Expr) : M (Arg .pure) :=
withoutExpectedType do
etaIfUnderApplied e arity do
let f := e.getAppFn
let args := e.getAppArgs
@@ -638,7 +665,7 @@ where
-- We can rely on `toMono` erasing ctor params eventually; we do not do so here so that type
-- inference on the value is preserved.
withReader (fun ctx =>
{ ignoreNoncomputable := ctx.ignoreNoncomputable || ctorInfo?.any (idx < ·.numParams) }) do
{ ctx with ignoreNoncomputable := ctx.ignoreNoncomputable || ctorInfo?.any (idx < ·.numParams) }) do
visitAppArg arg
if hasNeverExtractAttribute env declName then
modify fun s => {s with shouldCache := false }
@@ -646,6 +673,7 @@ where
visitQuotLift (e : Expr) : M (Arg .pure) := do
let arity := 6
withoutExpectedType do
etaIfUnderApplied e arity do
let mut args := e.getAppArgs
let α visitAppArg args[0]!
@@ -661,6 +689,7 @@ where
visitEqRec (e : Expr) : M (Arg .pure) :=
let arity := 6
withoutExpectedType do
etaIfUnderApplied e arity do
let args := e.getAppArgs
let minor := if e.isAppOf ``Eq.rec || e.isAppOf ``Eq.ndrec then args[3]! else args[5]!
@@ -669,6 +698,7 @@ where
visitHEqRec (e : Expr) : M (Arg .pure) :=
let arity := 7
withoutExpectedType do
etaIfUnderApplied e arity do
let args := e.getAppArgs
let minor := if e.isAppOf ``HEq.rec || e.isAppOf ``HEq.ndrec then args[3]! else args[6]!
@@ -677,18 +707,21 @@ where
visitFalseRec (e : Expr) : M (Arg .pure) :=
let arity := 2
withoutExpectedType do
etaIfUnderApplied e arity do
let type toLCNFType ( liftMetaM do Meta.inferType e)
mkUnreachable type
visitLcUnreachable (e : Expr) : M (Arg .pure) :=
let arity := 1
withoutExpectedType do
etaIfUnderApplied e arity do
let type toLCNFType ( liftMetaM do Meta.inferType e)
mkUnreachable type
visitAndIffRecCore (e : Expr) (minorPos : Nat) : M (Arg .pure) :=
let arity := 5
withoutExpectedType do
etaIfUnderApplied e arity do
let args := e.getAppArgs
let ha := mkLcProof args[0]! -- We should not use `lcErased` here since we use it to create a pre-LCNF Expr.
@@ -701,6 +734,7 @@ where
let .const declName _ := e.getAppFn | unreachable!
let info := getNoConfusionInfo ( getEnv) declName
let typeName := declName.getPrefix
withoutExpectedType do
etaIfUnderApplied e info.arity do
let args := e.getAppArgs
let visitMajor (numNonPropFields : Nat) := do
@@ -786,10 +820,13 @@ where
e.withApp visitAppDefaultConst
else
e.withApp fun f args => do
match ( visit f) with
-- TODO: we can try to make the type more precise here. However, it probably won't matter
-- much as user defined borrow annotations will likely never occur in a position where they
-- are relevant for specifically the function of an application.
match ( withoutExpectedType do visit f) with
| .erased | .type .. => return .erased
| .fvar fvarId =>
let args args.mapM visitAppArg
let args args.mapM (withoutExpectedType do visitAppArg ·)
letValueToArg <| .fvar fvarId args
visitLambda (e : Expr) : M (Arg .pure) := do
@@ -821,8 +858,9 @@ where
visit b
else
let funDecl withNewScope do
let (ps, e) ToLCNF.visitLambda e
let e visit e
let (ps, e, eType?) ToLCNF.visitLambda e
let e withExpectedType eType? do
visit e
let c toCode e
mkAuxFunDecl ps c
pushElement (.fun funDecl)
@@ -837,7 +875,7 @@ where
let projExpr liftMetaM <| Meta.mkProjection e structInfo.fieldNames[i]!
visitApp projExpr
else
match ( visit e) with
match ( withoutExpectedType do visit e) with
| .erased | .type .. => return .erased
| .fvar fvarId => letValueToArg <| .proj s i fvarId
@@ -850,7 +888,9 @@ where
visitLet body (xs.push value)
else
let type' toLCNFType type
let letDecl mkLetDecl binderName type value type' ( visit value)
let value' withExpectedType type' do
visit value
let letDecl mkLetDecl binderName type value type' value'
visitLet body (xs.push (.fvar letDecl.fvarId))
| _ =>
let e := e.instantiateRev xs

View File

@@ -0,0 +1,113 @@
module
public section
/-!
Tests that borrow annotations from declaration/let-binding types survive LCNF conversion.
The `@&` annotations live in the forall type, not in the lambda binders, and are based on the
(rather brittle) mdata so LCNF must infer them to a degree.
-/
/--
trace: [Compiler.saveBase] size: 1
def borrowTop @&xs : Nat :=
let _x.1 := @List.lengthTR _ xs;
return _x.1
-/
#guard_msgs in
set_option trace.Compiler.saveBase true in
def borrowTop (xs : @& List Nat) : Nat := xs.length
/--
trace: [Compiler.saveBase] size: 3
def borrowMixed n @&xs m : Nat :=
let _x.1 := @List.lengthTR _ xs;
let _x.2 := Nat.add n _x.1;
let _x.3 := Nat.add _x.2 m;
return _x.3
-/
#guard_msgs in
set_option trace.Compiler.saveBase true in
def borrowMixed (n : Nat) (xs : @& List Nat) (m : Nat) : Nat :=
n + xs.length + m
/--
trace: [Compiler.saveBase] size: 5
def borrowLet n xs ys : Nat :=
fun f @&ys : Nat :=
let _x.1 := @List.lengthTR _ ys;
let _x.2 := Nat.add _x.1 n;
return _x.2;
let _x.3 := f xs;
let _x.4 := f ys;
let _x.5 := Nat.add _x.3 _x.4;
return _x.5
-/
#guard_msgs in
set_option trace.Compiler.saveBase true in
def borrowLet (n : Nat) (xs ys : List Nat) : Nat :=
let f : (@& List Nat) Nat := fun ys => ys.length + n
f xs + f ys
/--
trace: [Compiler.saveBase] size: 2
def applyTwice f @&a.1 : Nat :=
let _x.2 := f a.1;
let _x.3 := f _x.2;
return _x.3
-/
#guard_msgs in
set_option trace.Compiler.saveBase true in
def applyTwice (f : Nat Nat) : (@& Nat) Nat :=
let g := f f
g
structure Ctx where
values : List Nat
abbrev MyReaderM (α : Type) := (@& Ctx) α
@[inline]
def MyReaderM.bind (f : MyReaderM α) (g : α MyReaderM β) : MyReaderM β :=
fun ctx => g (f ctx) ctx
instance : Monad MyReaderM where
pure a := fun _ => a
bind := MyReaderM.bind
@[inline] def MyReaderM.read : MyReaderM Ctx := fun ctx => ctx
/--
trace: [Compiler.saveBase] size: 2
def withMyReader α f x @&ctx : α :=
let _x.1 := f ctx;
let _x.2 := x _x.1;
return _x.2
-/
#guard_msgs in
set_option trace.Compiler.saveBase true in
@[noinline]
def withMyReader (f : Ctx Ctx) (x : MyReaderM α) : MyReaderM α :=
fun ctx => x (f ctx)
/--
trace: [Compiler.saveBase] size: 6
def getLength other @&a.1 : Nat :=
fun _f.2 ctx : Ctx :=
let _x.3 := ctx # 0;
let _x.4 := @List.appendTR _ _x.3 other;
let _x.5 := Ctx.mk _x.4;
return _x.5;
fun _f.6 _y.7 : Nat :=
let _x.8 := _y.7 # 0;
let _x.9 := @List.lengthTR _ _x.8;
return _x.9;
let _x.10 := @withMyReader _ _f.2 _f.6 a.1;
return _x.10
-/
#guard_msgs in
set_option trace.Compiler.saveBase true in
def getLength (other : List Nat) : MyReaderM Nat := do
withMyReader (fun ctx => { ctx with values := ctx.values ++ other }) do
let ctx MyReaderM.read
return ctx.values.length