mirror of
https://github.com/leanprover/lean4.git
synced 2026-03-17 18:34:06 +00:00
preserve in early pipeline
This commit is contained in:
@@ -93,7 +93,7 @@ where
|
||||
match type with
|
||||
| .forallE _ d b _ =>
|
||||
let d := d.instantiateRev xs
|
||||
let p ← mkAuxParam d
|
||||
let p ← mkAuxParam d (isMarkedBorrowed d)
|
||||
go b (xs.push (.fvar p.fvarId)) (ps.push p)
|
||||
| _ =>
|
||||
let type := type.instantiateRev xs
|
||||
|
||||
@@ -159,7 +159,7 @@ def toDecl (declName : Name) : CompilerM (Decl .pure) := do
|
||||
/- Recall that `inlineMatchers` may have exposed `ite`s and `dite`s which are tagged as `[macro_inline]`. -/
|
||||
let value ← macroInline value
|
||||
return (type, value)
|
||||
let code ← toLCNF value
|
||||
let code ← toLCNF value type
|
||||
let mut decl ← if let .fun decl (.return _) := code then
|
||||
eraseFunDecl decl (recursive := false)
|
||||
pure { name := declName, params := decl.params, type, value := .code decl.value, levelParams := info.levelParams, safe, inlineAttr? : Decl .pure }
|
||||
|
||||
@@ -206,6 +206,7 @@ structure Context where
|
||||
eventually.
|
||||
-/
|
||||
ignoreNoncomputable : Bool := false
|
||||
expectedType : Option Expr
|
||||
|
||||
structure State where
|
||||
/-- Local context containing the original Lean types (not LCNF ones). -/
|
||||
@@ -265,8 +266,18 @@ def toCode (result : Arg .pure) : M (Code .pure) := do
|
||||
let fvarId ← mkAuxLetDecl .erased
|
||||
seqToCode (← get).seq (.return fvarId)
|
||||
|
||||
def run (x : M α) : CompilerM α :=
|
||||
x.run {} |>.run' {}
|
||||
def run (expectedType : Expr) (x : M α) : CompilerM α :=
|
||||
x.run { expectedType } |>.run' {}
|
||||
|
||||
@[inline]
|
||||
def withExpectedType (e : Option Expr) (x : M α) : M α :=
|
||||
withReader (fun ctx => { ctx with expectedType := e }) do
|
||||
x
|
||||
|
||||
@[inline]
|
||||
def withoutExpectedType (x : M α) : M α :=
|
||||
withExpectedType none do
|
||||
x
|
||||
|
||||
/--
|
||||
Return true iff `type` is `Sort _` or `As → Sort _`.
|
||||
@@ -340,9 +351,9 @@ def cleanupBinderName (binderName : Name) : CompilerM Name :=
|
||||
return binderName
|
||||
|
||||
/-- Create a new local declaration using a Lean regular type. -/
|
||||
def mkParam (binderName : Name) (type : Expr) : M (Param .pure) := do
|
||||
def mkParam (binderName : Name) (type : Expr) (borrow : Bool := isMarkedBorrowed type) :
|
||||
M (Param .pure) := do
|
||||
let binderName ← cleanupBinderName binderName
|
||||
let borrow := isMarkedBorrowed type
|
||||
let type' ← toLCNFType type
|
||||
let param ← LCNF.mkParam binderName type' borrow
|
||||
modify fun s => { s with lctx := s.lctx.mkLocalDecl param.fvarId binderName type .default }
|
||||
@@ -361,16 +372,22 @@ def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (type' : Expr) (a
|
||||
}
|
||||
return letDecl
|
||||
|
||||
def visitLambda (e : Expr) : M (Array (Param .pure) × Expr) :=
|
||||
go e #[] #[]
|
||||
def visitLambda (e : Expr) : M (Array (Param .pure) × Expr × Option Expr) := do
|
||||
go e #[] #[] (← read).expectedType
|
||||
where
|
||||
go (e : Expr) (xs : Array Expr) (ps : Array (Param .pure)) := do
|
||||
go (e : Expr) (xs : Array Expr) (ps : Array (Param .pure)) (eType? : Option Expr) := do
|
||||
if let .lam binderName type body _ := e then
|
||||
let type := type.instantiateRev xs
|
||||
let p ← mkParam binderName type
|
||||
go body (xs.push p.toExpr) (ps.push p)
|
||||
if let some (.forallE _ type' eType _) := eType? then
|
||||
let borrow := isMarkedBorrowed type || isMarkedBorrowed type'
|
||||
let p ← mkParam binderName type borrow
|
||||
-- no need to instantiate eType, we only ever check if for `isMarkedBorrowed`
|
||||
go body (xs.push p.toExpr) (ps.push p) (some eType)
|
||||
else
|
||||
let p ← mkParam binderName type
|
||||
go body (xs.push p.toExpr) (ps.push p) none
|
||||
else
|
||||
return (ps, e.instantiateRev xs)
|
||||
return (ps, e.instantiateRev xs, eType?.map (·.instantiateRev xs))
|
||||
|
||||
def visitBoundedLambda (e : Expr) (n : Nat) : M (Array (Param .pure) × Expr) :=
|
||||
go e n #[] #[]
|
||||
@@ -446,8 +463,8 @@ Put the given expression in `LCNF`.
|
||||
- Eta-expand applications of declarations that satisfy `shouldEtaExpand`.
|
||||
- Put computationally relevant expressions in A-normal form.
|
||||
-/
|
||||
partial def toLCNF (e : Expr) : CompilerM (Code .pure) := do
|
||||
run do toCode (← visit e)
|
||||
partial def toLCNF (e : Expr) (eType : Expr) : CompilerM (Code .pure) := do
|
||||
run eType do toCode (← visit e)
|
||||
where
|
||||
visitCore (e : Expr) : M (Arg .pure) := withIncRecDepth do
|
||||
if let some arg := (← get).cache.find? e then
|
||||
@@ -505,7 +522,7 @@ where
|
||||
visitAppDefaultConst (f : Expr) (args : Array Expr) : M (Arg .pure) := do
|
||||
let env ← getEnv
|
||||
let .const declName us ← CSimp.replaceConstant env f | unreachable!
|
||||
let args ← args.mapM visitAppArg
|
||||
let args ← args.mapM (withoutExpectedType do visitAppArg ·)
|
||||
if hasNeverExtractAttribute env declName then
|
||||
modify fun s => {s with shouldCache := false }
|
||||
letValueToArg <| .const declName us args
|
||||
@@ -549,10 +566,12 @@ where
|
||||
let altType ← c.inferType
|
||||
return (altType, .default c)
|
||||
| .ctor ctorName numParams =>
|
||||
let mut (ps, e) ← visitBoundedLambda e numParams
|
||||
let mut (ps, e) ← withoutExpectedType do
|
||||
visitBoundedLambda e numParams
|
||||
if ps.size < numParams then
|
||||
e ← etaExpandN e (numParams - ps.size)
|
||||
let (ps', e') ← ToLCNF.visitLambda e
|
||||
let (ps', e', _) ← withoutExpectedType do
|
||||
ToLCNF.visitLambda e
|
||||
ps := ps ++ ps'
|
||||
e := e'
|
||||
/-
|
||||
@@ -609,11 +628,17 @@ where
|
||||
fieldArgs := fieldArgs.push fieldArg
|
||||
return fieldArgs
|
||||
let f := args[casesInfo.altsRange.lower]!
|
||||
let result ← visit (mkAppN f fieldArgs)
|
||||
mkOverApplication result args casesInfo.arity
|
||||
let arity := casesInfo.arity
|
||||
if args.size == arity then
|
||||
visit (mkAppN f fieldArgs)
|
||||
else
|
||||
withoutExpectedType do
|
||||
let result ← visit (mkAppN f fieldArgs)
|
||||
mkOverApplication result args casesInfo.arity
|
||||
else
|
||||
let mut alts := #[]
|
||||
let discr ← visitAppArg args[casesInfo.discrPos]!
|
||||
let discr ← withoutExpectedType do
|
||||
visitAppArg args[casesInfo.discrPos]!
|
||||
let discrFVarId ← match discr with
|
||||
| .fvar discrFVarId => pure discrFVarId
|
||||
| .erased | .type .. => mkAuxLetDecl .erased
|
||||
@@ -625,9 +650,11 @@ where
|
||||
let auxDecl ← mkAuxParam resultType
|
||||
pushElement (.cases auxDecl cases)
|
||||
let result := .fvar auxDecl.fvarId
|
||||
mkOverApplication result args casesInfo.arity
|
||||
withoutExpectedType do
|
||||
mkOverApplication result args casesInfo.arity
|
||||
|
||||
visitCtor (arity : Nat) (e : Expr) : M (Arg .pure) :=
|
||||
withoutExpectedType do
|
||||
etaIfUnderApplied e arity do
|
||||
let f := e.getAppFn
|
||||
let args := e.getAppArgs
|
||||
@@ -638,7 +665,7 @@ where
|
||||
-- We can rely on `toMono` erasing ctor params eventually; we do not do so here so that type
|
||||
-- inference on the value is preserved.
|
||||
withReader (fun ctx =>
|
||||
{ ignoreNoncomputable := ctx.ignoreNoncomputable || ctorInfo?.any (idx < ·.numParams) }) do
|
||||
{ ctx with ignoreNoncomputable := ctx.ignoreNoncomputable || ctorInfo?.any (idx < ·.numParams) }) do
|
||||
visitAppArg arg
|
||||
if hasNeverExtractAttribute env declName then
|
||||
modify fun s => {s with shouldCache := false }
|
||||
@@ -646,6 +673,7 @@ where
|
||||
|
||||
visitQuotLift (e : Expr) : M (Arg .pure) := do
|
||||
let arity := 6
|
||||
withoutExpectedType do
|
||||
etaIfUnderApplied e arity do
|
||||
let mut args := e.getAppArgs
|
||||
let α ← visitAppArg args[0]!
|
||||
@@ -661,6 +689,7 @@ where
|
||||
|
||||
visitEqRec (e : Expr) : M (Arg .pure) :=
|
||||
let arity := 6
|
||||
withoutExpectedType do
|
||||
etaIfUnderApplied e arity do
|
||||
let args := e.getAppArgs
|
||||
let minor := if e.isAppOf ``Eq.rec || e.isAppOf ``Eq.ndrec then args[3]! else args[5]!
|
||||
@@ -669,6 +698,7 @@ where
|
||||
|
||||
visitHEqRec (e : Expr) : M (Arg .pure) :=
|
||||
let arity := 7
|
||||
withoutExpectedType do
|
||||
etaIfUnderApplied e arity do
|
||||
let args := e.getAppArgs
|
||||
let minor := if e.isAppOf ``HEq.rec || e.isAppOf ``HEq.ndrec then args[3]! else args[6]!
|
||||
@@ -677,18 +707,21 @@ where
|
||||
|
||||
visitFalseRec (e : Expr) : M (Arg .pure) :=
|
||||
let arity := 2
|
||||
withoutExpectedType do
|
||||
etaIfUnderApplied e arity do
|
||||
let type ← toLCNFType (← liftMetaM do Meta.inferType e)
|
||||
mkUnreachable type
|
||||
|
||||
visitLcUnreachable (e : Expr) : M (Arg .pure) :=
|
||||
let arity := 1
|
||||
withoutExpectedType do
|
||||
etaIfUnderApplied e arity do
|
||||
let type ← toLCNFType (← liftMetaM do Meta.inferType e)
|
||||
mkUnreachable type
|
||||
|
||||
visitAndIffRecCore (e : Expr) (minorPos : Nat) : M (Arg .pure) :=
|
||||
let arity := 5
|
||||
withoutExpectedType do
|
||||
etaIfUnderApplied e arity do
|
||||
let args := e.getAppArgs
|
||||
let ha := mkLcProof args[0]! -- We should not use `lcErased` here since we use it to create a pre-LCNF Expr.
|
||||
@@ -701,6 +734,7 @@ where
|
||||
let .const declName _ := e.getAppFn | unreachable!
|
||||
let info := getNoConfusionInfo (← getEnv) declName
|
||||
let typeName := declName.getPrefix
|
||||
withoutExpectedType do
|
||||
etaIfUnderApplied e info.arity do
|
||||
let args := e.getAppArgs
|
||||
let visitMajor (numNonPropFields : Nat) := do
|
||||
@@ -786,10 +820,13 @@ where
|
||||
e.withApp visitAppDefaultConst
|
||||
else
|
||||
e.withApp fun f args => do
|
||||
match (← visit f) with
|
||||
-- TODO: we can try to make the type more precise here. However, it probably won't matter
|
||||
-- much as user defined borrow annotations will likely never occur in a position where they
|
||||
-- are relevant for specifically the function of an application.
|
||||
match (← withoutExpectedType do visit f) with
|
||||
| .erased | .type .. => return .erased
|
||||
| .fvar fvarId =>
|
||||
let args ← args.mapM visitAppArg
|
||||
let args ← args.mapM (withoutExpectedType do visitAppArg ·)
|
||||
letValueToArg <| .fvar fvarId args
|
||||
|
||||
visitLambda (e : Expr) : M (Arg .pure) := do
|
||||
@@ -821,8 +858,9 @@ where
|
||||
visit b
|
||||
else
|
||||
let funDecl ← withNewScope do
|
||||
let (ps, e) ← ToLCNF.visitLambda e
|
||||
let e ← visit e
|
||||
let (ps, e, eType?) ← ToLCNF.visitLambda e
|
||||
let e ← withExpectedType eType? do
|
||||
visit e
|
||||
let c ← toCode e
|
||||
mkAuxFunDecl ps c
|
||||
pushElement (.fun funDecl)
|
||||
@@ -837,7 +875,7 @@ where
|
||||
let projExpr ← liftMetaM <| Meta.mkProjection e structInfo.fieldNames[i]!
|
||||
visitApp projExpr
|
||||
else
|
||||
match (← visit e) with
|
||||
match (← withoutExpectedType do visit e) with
|
||||
| .erased | .type .. => return .erased
|
||||
| .fvar fvarId => letValueToArg <| .proj s i fvarId
|
||||
|
||||
@@ -850,7 +888,9 @@ where
|
||||
visitLet body (xs.push value)
|
||||
else
|
||||
let type' ← toLCNFType type
|
||||
let letDecl ← mkLetDecl binderName type value type' (← visit value)
|
||||
let value' ← withExpectedType type' do
|
||||
visit value
|
||||
let letDecl ← mkLetDecl binderName type value type' value'
|
||||
visitLet body (xs.push (.fvar letDecl.fvarId))
|
||||
| _ =>
|
||||
let e := e.instantiateRev xs
|
||||
|
||||
113
tests/elab/lcnf_borrow_expected_type.lean
Normal file
113
tests/elab/lcnf_borrow_expected_type.lean
Normal file
@@ -0,0 +1,113 @@
|
||||
module
|
||||
|
||||
public section
|
||||
|
||||
/-!
|
||||
Tests that borrow annotations from declaration/let-binding types survive LCNF conversion.
|
||||
The `@&` annotations live in the forall type, not in the lambda binders, and are based on the
|
||||
(rather brittle) mdata so LCNF must infer them to a degree.
|
||||
-/
|
||||
|
||||
/--
|
||||
trace: [Compiler.saveBase] size: 1
|
||||
def borrowTop @&xs : Nat :=
|
||||
let _x.1 := @List.lengthTR _ xs;
|
||||
return _x.1
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option trace.Compiler.saveBase true in
|
||||
def borrowTop (xs : @& List Nat) : Nat := xs.length
|
||||
|
||||
/--
|
||||
trace: [Compiler.saveBase] size: 3
|
||||
def borrowMixed n @&xs m : Nat :=
|
||||
let _x.1 := @List.lengthTR _ xs;
|
||||
let _x.2 := Nat.add n _x.1;
|
||||
let _x.3 := Nat.add _x.2 m;
|
||||
return _x.3
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option trace.Compiler.saveBase true in
|
||||
def borrowMixed (n : Nat) (xs : @& List Nat) (m : Nat) : Nat :=
|
||||
n + xs.length + m
|
||||
|
||||
/--
|
||||
trace: [Compiler.saveBase] size: 5
|
||||
def borrowLet n xs ys : Nat :=
|
||||
fun f @&ys : Nat :=
|
||||
let _x.1 := @List.lengthTR _ ys;
|
||||
let _x.2 := Nat.add _x.1 n;
|
||||
return _x.2;
|
||||
let _x.3 := f xs;
|
||||
let _x.4 := f ys;
|
||||
let _x.5 := Nat.add _x.3 _x.4;
|
||||
return _x.5
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option trace.Compiler.saveBase true in
|
||||
def borrowLet (n : Nat) (xs ys : List Nat) : Nat :=
|
||||
let f : (@& List Nat) → Nat := fun ys => ys.length + n
|
||||
f xs + f ys
|
||||
|
||||
/--
|
||||
trace: [Compiler.saveBase] size: 2
|
||||
def applyTwice f @&a.1 : Nat :=
|
||||
let _x.2 := f a.1;
|
||||
let _x.3 := f _x.2;
|
||||
return _x.3
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option trace.Compiler.saveBase true in
|
||||
def applyTwice (f : Nat → Nat) : (@& Nat) → Nat :=
|
||||
let g := f ∘ f
|
||||
g
|
||||
|
||||
structure Ctx where
|
||||
values : List Nat
|
||||
|
||||
abbrev MyReaderM (α : Type) := (@& Ctx) → α
|
||||
|
||||
@[inline]
|
||||
def MyReaderM.bind (f : MyReaderM α) (g : α → MyReaderM β) : MyReaderM β :=
|
||||
fun ctx => g (f ctx) ctx
|
||||
|
||||
instance : Monad MyReaderM where
|
||||
pure a := fun _ => a
|
||||
bind := MyReaderM.bind
|
||||
|
||||
@[inline] def MyReaderM.read : MyReaderM Ctx := fun ctx => ctx
|
||||
|
||||
/--
|
||||
trace: [Compiler.saveBase] size: 2
|
||||
def withMyReader α f x @&ctx : α :=
|
||||
let _x.1 := f ctx;
|
||||
let _x.2 := x _x.1;
|
||||
return _x.2
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option trace.Compiler.saveBase true in
|
||||
@[noinline]
|
||||
def withMyReader (f : Ctx → Ctx) (x : MyReaderM α) : MyReaderM α :=
|
||||
fun ctx => x (f ctx)
|
||||
|
||||
/--
|
||||
trace: [Compiler.saveBase] size: 6
|
||||
def getLength other @&a.1 : Nat :=
|
||||
fun _f.2 ctx : Ctx :=
|
||||
let _x.3 := ctx # 0;
|
||||
let _x.4 := @List.appendTR _ _x.3 other;
|
||||
let _x.5 := Ctx.mk _x.4;
|
||||
return _x.5;
|
||||
fun _f.6 _y.7 : Nat :=
|
||||
let _x.8 := _y.7 # 0;
|
||||
let _x.9 := @List.lengthTR _ _x.8;
|
||||
return _x.9;
|
||||
let _x.10 := @withMyReader _ _f.2 _f.6 a.1;
|
||||
return _x.10
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option trace.Compiler.saveBase true in
|
||||
def getLength (other : List Nat) : MyReaderM Nat := do
|
||||
withMyReader (fun ctx => { ctx with values := ctx.values ++ other }) do
|
||||
let ctx ← MyReaderM.read
|
||||
return ctx.values.length
|
||||
Reference in New Issue
Block a user