Compare commits

...

7 Commits

Author SHA1 Message Date
Leonardo de Moura
0504f50cdb feat: Expr internalization, add facts, congruence theorem cache 2024-12-18 15:46:27 -08:00
Leonardo de Moura
c1dade0664 feat: fold Expr.proj back into projection applications 2024-12-18 11:38:42 -08:00
Leonardo de Moura
8ce20910df feat: add eraseIrrelevantMData 2024-12-18 11:38:42 -08:00
Leonardo de Moura
9c1435df0e chore: update to new code conventions
Use new features that were not available when this code was originally
written.
2024-12-18 11:38:42 -08:00
Leonardo de Moura
a3f803860f chore: update doc strings 2024-12-18 11:38:42 -08:00
Leonardo de Moura
609ce14b02 chore: add new tracing classes for grind 2024-12-18 11:38:41 -08:00
Leonardo de Moura
5af20c1093 chore: add TODO 2024-12-18 11:38:41 -08:00
7 changed files with 319 additions and 82 deletions

View File

@@ -12,3 +12,12 @@ import Lean.Meta.Tactic.Grind.Util
import Lean.Meta.Tactic.Grind.Cases
import Lean.Meta.Tactic.Grind.Injection
import Lean.Meta.Tactic.Grind.Core
namespace Lean
builtin_initialize registerTraceClass `grind
builtin_initialize registerTraceClass `grind.eq
builtin_initialize registerTraceClass `grind.issues
builtin_initialize registerTraceClass `grind.add
end Lean

View File

@@ -20,8 +20,8 @@ def isInterpreted (e : Expr) : MetaM Bool := do
Creates an `ENode` for `e` if one does not already exist.
This method assumes `e` has been hashconsed.
-/
def mkENode (e : Expr) (generation : Nat := 0) : GoalM Unit := do
if ( getENode? e).isSome then return ()
def mkENode (e : Expr) (generation : Nat) : GoalM Unit := do
if ( alreadyInternalized e) then return ()
let ctor := ( isConstructorAppCore? e).isSome
let interpreted isInterpreted e
mkENodeCore e interpreted ctor generation
@@ -40,11 +40,6 @@ def getNext (e : Expr) : GoalM Expr := do
let some n getENode? e | return e
return n.next
@[inline] def isSameExpr (a b : Expr) : Bool :=
-- It is safe to use pointer equality because we hashcons all expressions
-- inserted into the E-graph
unsafe ptrEq a b
private def pushNewEqCore (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit :=
modify fun s => { s with newEqs := s.newEqs.push { lhs, rhs, proof, isHEq } }
@@ -54,6 +49,44 @@ private def pushNewEqCore (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit :=
@[inline] private def pushNewHEq (lhs rhs proof : Expr) : GoalM Unit :=
pushNewEqCore lhs rhs proof (isHEq := true)
/--
Adds `e` to congruence table.
-/
def addCongrTable (_e : Expr) : GoalM Unit := do
-- TODO
return ()
partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do
if ( alreadyInternalized e) then return ()
match e with
| .bvar .. => unreachable!
| .sort .. => return ()
| .fvar .. | .letE .. | .lam .. | .forallE .. =>
mkENodeCore e (ctor := false) (interpreted := false) (generation := generation)
| .lit .. | .const .. =>
mkENode e generation
| .mvar ..
| .mdata ..
| .proj .. =>
trace[grind.issues] "unexpected term during internalization{indentExpr e}"
mkENodeCore e (ctor := false) (interpreted := false) (generation := generation)
| .app .. => e.withApp fun f args => do
let congrThm mkHCongrWithArity f args.size
let info getFunInfo f
let shouldInternalize (i : Nat) : GoalM Bool := do
if h : i < info.paramInfo.size then
let pinfo := info.paramInfo[i]
if pinfo.binderInfo.isInstImplicit || pinfo.isProp then
return false
return true
for h : i in [: args.size] do
let arg := args[i]
if ( shouldInternalize i) then
unless ( isTypeFormerType arg) do
internalize arg generation
mkENode e generation
addCongrTable e
/--
The fields `target?` and `proof?` in `e`'s `ENode` are encoding a transitivity proof
from `e` to the root of the equivalence class
@@ -77,7 +110,11 @@ where
private def markAsInconsistent : GoalM Unit :=
modify fun s => { s with inconsistent := true }
def isInconsistent : GoalM Bool :=
return ( get).inconsistent
private partial def addEqStep (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit := do
trace[grind.eq] "{lhs} {if isHEq then "=" else ""} {rhs}"
let some lhsNode getENode? lhs | return () -- `lhs` has not been internalized yet
let some rhsNode getENode? rhs | return () -- `rhs` has not been internalized yet
if isSameExpr lhsNode.root rhsNode.root then return () -- `lhs` and `rhs` are already in the same equivalence class.
@@ -136,13 +173,17 @@ where
loop n.next
loop lhs
/-- Ensures collection of equations to be processed is empty. -/
def resetNewEqs : GoalM Unit :=
modify fun s => { s with newEqs := #[] }
partial def addEqCore (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit := do
addEqStep lhs rhs proof isHEq
processTodo
where
processTodo : GoalM Unit := do
if ( get).inconsistent then
modify fun s => { s with newEqs := #[] }
if ( isInconsistent) then
resetNewEqs
return ()
let some { lhs, rhs, proof, isHEq } := ( get).newEqs.back? | return ()
addEqStep lhs rhs proof isHEq
@@ -154,4 +195,66 @@ def addEq (lhs rhs proof : Expr) : GoalM Unit := do
def addHEq (lhs rhs proof : Expr) : GoalM Unit := do
addEqCore lhs rhs proof true
/--
Adds a new `fact` justified by the given proof and using the given generation.
-/
def add (fact : Expr) (proof : Expr) (generation := 0) : GoalM Unit := do
trace[grind.add] "{proof} : {fact}"
if ( isInconsistent) then return ()
resetNewEqs
let_expr Not p := fact
| go fact false
go p true
where
go (p : Expr) (isNeg : Bool) : GoalM Unit := do
trace[grind.add] "isNeg: {isNeg}, {p}"
match_expr p with
| Eq _ lhs rhs => goEq p lhs rhs isNeg false
| HEq _ _ lhs rhs => goEq p lhs rhs isNeg true
| _ =>
internalize p generation
if isNeg then
addEq p ( getFalseExpr) ( mkEqFalse proof)
else
addEq p ( getFalseExpr) ( mkEqTrue proof)
goEq (p : Expr) (lhs rhs : Expr) (isNeg : Bool) (isHEq : Bool) : GoalM Unit := do
if isNeg then
internalize p generation
addEq p ( getFalseExpr) ( mkEqFalse proof)
else
internalize lhs generation
internalize rhs generation
addEqCore lhs rhs proof isHEq
/--
Adds a new hypothesis.
-/
def addHyp (fvarId : FVarId) (generation := 0) : GoalM Unit := do
add ( fvarId.getType) (mkFVar fvarId) generation
/--
Returns expressions in the given expression equivalence class.
-/
partial def getEqc (e : Expr) : GoalM (List Expr) :=
go e e []
where
go (first : Expr) (e : Expr) (acc : List Expr) : GoalM (List Expr) := do
let next getNext e
let acc := e :: acc
if isSameExpr e next then
return acc
else
go first next acc
/--
Returns all equivalence classes in the current goal.
-/
partial def getEqcs : GoalM (List (List Expr)) := do
let mut r := []
for (_, node) in ( get).enodes do
if isSameExpr node.root node.self then
r := ( getEqc node.self) :: r
return r
end Lean.Meta.Grind

View File

@@ -68,8 +68,11 @@ def introNext (goal : Goal) : PreM IntroResult := do
else
let tag goal.mvarId.getTag
let q := target.bindingBody!
-- TODO: keep applying simp/eraseIrrelevantMData/canon/shareCommon until no progress
let r simp goal p
let p' := r.expr
let p' eraseIrrelevantMData p'
let p' foldProjs p'
let p' canon p'
let p' shareCommon p'
let fvarId mkFreshFVarId
@@ -133,8 +136,7 @@ partial def loop (goal : Goal) : PreM Unit := do
else if let some goal applyInjection? goal fvarId then
loop goal
else
let clause goal.mvarId.withContext do mkInputClause fvarId
loop { goal with clauses := goal.clauses.push clause }
loop ( GoalM.run' goal <| addHyp fvarId)
| .newDepHyp goal =>
loop goal
| .newLocal fvarId goal =>
@@ -153,6 +155,7 @@ open Preprocessor
partial def main (mvarId : MVarId) (mainDeclName : Name) : MetaM (List MVarId) := do
mvarId.ensureProp
-- TODO: abstract metavars
mvarId.ensureNoMVar
let mvarId mvarId.clearAuxDecls
let mvarId mvarId.revertAll

View File

@@ -6,38 +6,90 @@ Authors: Leonardo de Moura
prelude
import Lean.Util.ShareCommon
import Lean.Meta.Basic
import Lean.Meta.CongrTheorems
import Lean.Meta.AbstractNestedProofs
import Lean.Meta.Canonicalizer
import Lean.Meta.Tactic.Util
namespace Lean.Meta.Grind
@[inline] def isSameExpr (a b : Expr) : Bool :=
-- It is safe to use pointer equality because we hashcons all expressions
-- inserted into the E-graph
unsafe ptrEq a b
structure Context where
mainDeclName : Name
/--
Key for the congruence theorem cache.
-/
structure CongrTheoremCacheKey where
f : Expr
numArgs : Nat
-- We manually define `BEq` because we wannt to use pointer equality.
instance : BEq CongrTheoremCacheKey where
beq a b := isSameExpr a.f b.f && a.numArgs == b.numArgs
-- We manually define `Hashable` because we wannt to use pointer equality.
instance : Hashable CongrTheoremCacheKey where
hash a := mixHash (unsafe ptrAddrUnsafe a.f).toUInt64 (hash a.numArgs)
structure State where
canon : Canonicalizer.State := {}
/-- `ShareCommon` (aka `Hashconsing`) state. -/
scState : ShareCommon.State.{0} ShareCommon.objectFactory := ShareCommon.State.mk _
/-- Next index for creating auxiliary theorems. -/
nextThmIdx : Nat := 1
/--
Congruence theorems generated so far. Recall that for constant symbols
we rely on the reserved name feature (i.e., `mkHCongrWithArityForConst?`).
Remark: we currently do not reuse congruence theorems
-/
congrThms : PHashMap CongrTheoremCacheKey CongrTheorem := {}
trueExpr : Expr
falseExpr : Expr
abbrev GrindM := ReaderT Context $ StateRefT State MetaM
@[inline] def GrindM.run (x : GrindM α) (mainDeclName : Name) : MetaM α :=
x { mainDeclName } |>.run' {}
def GrindM.run (x : GrindM α) (mainDeclName : Name) : MetaM α := do
let scState := ShareCommon.State.mk _
let (falseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``False)
let (trueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``True)
x { mainDeclName } |>.run' { scState, trueExpr, falseExpr }
def getTrueExpr : GrindM Expr := do
return ( get).trueExpr
def getFalseExpr : GrindM Expr := do
return ( get).falseExpr
/--
Abtracts nested proofs in `e`. This is a preprocessing step performed before internalization.
-/
def abstractNestedProofs (e : Expr) : GrindM Expr := do
let nextIdx := ( get).nextThmIdx
let (e, s') AbstractNestedProofs.visit e |>.run { baseName := ( read).mainDeclName } |>.run |>.run { nextIdx }
modify fun s => { s with nextThmIdx := s'.nextIdx }
return e
/--
Applies hash-consing to `e`. Recall that all expressions in a `grind` goal have
been hash-consing. We perform this step before we internalize expressions.
-/
def shareCommon (e : Expr) : GrindM Expr := do
modifyGet fun { canon, scState, nextThmIdx } =>
modifyGet fun { canon, scState, nextThmIdx, congrThms, trueExpr, falseExpr } =>
let (e, scState) := ShareCommon.State.shareCommon scState e
(e, { canon, scState, nextThmIdx })
(e, { canon, scState, nextThmIdx, congrThms, trueExpr, falseExpr })
/--
Applies the canonicalizer to all subterms of `e`.
-/
-- TODO: the current canonicalizer is not a good solution for `grind`.
-- The problem is that two different applications `@f inst_1 a` and `@f inst_2 b`
-- may still have syntaticaally different instances. Thus, if we learn that `a = b`,
-- congruence closure will fail to see that the two applications are congruent.
def canon (e : Expr) : GrindM Expr := do
let canonS modifyGet fun s => (s.canon, { s with canon := {} })
let (e, canonS) Canonicalizer.CanonM.run (canonRec e) (s := canonS)
@@ -52,11 +104,28 @@ where
return .done e
transform e post
/--
Creates a congruence theorem for a `f`-applications with `numArgs` arguments.
-/
def mkHCongrWithArity (f : Expr) (numArgs : Nat) : GrindM CongrTheorem := do
let key := { f, numArgs }
if let some result := ( get).congrThms.find? key then
return result
if let .const declName us := f then
if let some result mkHCongrWithArityForConst? declName us numArgs then
modify fun s => { s with congrThms := s.congrThms.insert key result }
return result
let result Meta.mkHCongrWithArity f numArgs
modify fun s => { s with congrThms := s.congrThms.insert key result }
return result
/--
Stores information for a node in the egraph.
Each internalized expression `e` has an `ENode` associated with it.
-/
structure ENode where
/-- Node represented by this ENode. -/
self : Expr
/-- Next element in the equivalence class. -/
next : Expr
/-- Root (aka canonical representative) of the equivalence class -/
@@ -84,19 +153,15 @@ structure ENode where
on heterogeneous equality.
-/
heqProofs : Bool := false
/--
Unique index used for pretty printing and debugging purposes.
-/
idx : Nat := 0
generation : Nat := 0
/-- Modification time -/
mt : Nat := 0
-- TODO: see Lean 3 implementation
structure Clause where
expr : Expr
proof : Expr
deriving Inhabited
def mkInputClause (fvarId : FVarId) : MetaM Clause :=
return { expr := ( fvarId.getType), proof := mkFVar fvarId }
structure NewEq where
lhs : Expr
rhs : Expr
@@ -105,13 +170,15 @@ structure NewEq where
structure Goal where
mvarId : MVarId
clauses : PArray Clause := {}
enodes : PHashMap USize ENode := {}
/-- Equations to be processed. -/
newEqs : Array NewEq := #[]
/-- `inconsistent := true` if `ENode`s for `True` and `False` are in the same equivalence class. -/
inconsistent : Bool := false
/-- Goal modification time. -/
gmt : Nat := 0
/-- Next unique index for creating ENodes -/
nextIdx : Nat := 0
deriving Inhabited
def Goal.admit (goal : Goal) : MetaM Unit :=
@@ -120,10 +187,10 @@ def Goal.admit (goal : Goal) : MetaM Unit :=
abbrev GoalM := StateRefT Goal GrindM
@[inline] def GoalM.run (goal : Goal) (x : GoalM α) : GrindM (α × Goal) :=
StateRefT'.run x goal
goal.mvarId.withContext do StateRefT'.run x goal
@[inline] def GoalM.run' (goal : Goal) (x : GoalM Unit) : GrindM Goal :=
StateRefT'.run' (x *> get) goal
goal.mvarId.withContext do StateRefT'.run' (x *> get) goal
/--
Returns `some n` if `e` has already been "internalized" into the
@@ -132,22 +199,30 @@ Otherwise, returns `none`s.
def getENode? (e : Expr) : GoalM (Option ENode) :=
return ( get).enodes.find? (unsafe ptrAddrUnsafe e)
/-- Returns `true` if `e` has already been internalized. -/
def alreadyInternalized (e : Expr) : GoalM Bool :=
return ( get).enodes.contains (unsafe ptrAddrUnsafe e)
def setENode (e : Expr) (n : ENode) : GoalM Unit :=
modify fun s => { s with enodes := s.enodes.insert (unsafe ptrAddrUnsafe e) n }
def mkENodeCore (e : Expr) (interpreted ctor : Bool) (generation : Nat) : GoalM Unit := do
setENode e {
next := e, root := e, cgRoot := e, size := 1
self := e, next := e, root := e, cgRoot := e, size := 1
flipped := false
heqProofs := false
hasLambdas := e.isLambda
mt := ( get).gmt
idx := ( get).nextIdx
interpreted, ctor, generation
}
modify fun s => { s with nextIdx := s.nextIdx + 1 }
def mkGoal (mvarId : MVarId) : GrindM Goal := do
let trueExpr getTrueExpr
let falseExpr getFalseExpr
GoalM.run' { mvarId } do
mkENodeCore ( shareCommon (mkConst ``True)) (interpreted := true) (ctor := false) (generation := 0)
mkENodeCore ( shareCommon (mkConst ``False)) (interpreted := true) (ctor := false) (generation := 0)
mkENodeCore falseExpr (interpreted := true) (ctor := false) (generation := 0)
mkENodeCore trueExpr (interpreted := true) (ctor := false) (generation := 0)
end Lean.Meta.Grind

View File

@@ -5,6 +5,7 @@ Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.AbstractNestedProofs
import Lean.Meta.Transform
import Lean.Meta.Tactic.Util
import Lean.Meta.Tactic.Clear
@@ -35,7 +36,7 @@ def _root_.Lean.MVarId.transformTarget (mvarId : MVarId) (f : Expr → MetaM Exp
return mvarNew.mvarId!
/--
Unfold all `reducible` declarations occurring in `e`.
Unfolds all `reducible` declarations occurring in `e`.
-/
def unfoldReducible (e : Expr) : MetaM Expr :=
let pre (e : Expr) : MetaM TransformStep := do
@@ -46,25 +47,25 @@ def unfoldReducible (e : Expr) : MetaM Expr :=
Core.transform e (pre := pre)
/--
Unfold all `reducible` declarations occurring in the goal's target.
Unfolds all `reducible` declarations occurring in the goal's target.
-/
def _root_.Lean.MVarId.unfoldReducible (mvarId : MVarId) : MetaM MVarId :=
mvarId.transformTarget Grind.unfoldReducible
/--
Abstract nested proofs occurring in the goal's target.
Abstracts nested proofs occurring in the goal's target.
-/
def _root_.Lean.MVarId.abstractNestedProofs (mvarId : MVarId) (mainDeclName : Name) : MetaM MVarId :=
mvarId.transformTarget (Lean.Meta.abstractNestedProofs mainDeclName)
/--
Beta-reduce the goal's target.
Beta-reduces the goal's target.
-/
def _root_.Lean.MVarId.betaReduce (mvarId : MVarId) : MetaM MVarId :=
mvarId.transformTarget (Core.betaReduce ·)
/--
If the target is not `False`, apply `byContradiction`.
If the target is not `False`, applies `byContradiction`.
-/
def _root_.Lean.MVarId.byContra? (mvarId : MVarId) : MetaM (Option MVarId) := mvarId.withContext do
mvarId.checkNotAssigned `grind.by_contra
@@ -77,7 +78,7 @@ def _root_.Lean.MVarId.byContra? (mvarId : MVarId) : MetaM (Option MVarId) := mv
return mvarNew.mvarId!
/--
Clear auxiliary decls used to encode recursive declarations.
Clears auxiliary decls used to encode recursive declarations.
`grind` eliminates them to ensure they are not accidentally used by its proof automation.
-/
def _root_.Lean.MVarId.clearAuxDecls (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
@@ -96,4 +97,31 @@ def _root_.Lean.MVarId.clearAuxDecls (mvarId : MVarId) : MetaM MVarId := mvarId.
throwTacticEx `grind.clear_aux_decls mvarId "failed to clear local auxiliary declaration"
return mvarId
/--
In the `grind` tactic, during `Expr` internalization, we don't expect to find `Expr.mdata`.
This function ensures `Expr.mdata` is not found during internalization.
Recall that we do not internalize `Expr.forallE` and `Expr.lam` components.
-/
def eraseIrrelevantMData (e : Expr) : CoreM Expr := do
let pre (e : Expr) := do
match e with
| .letE .. | .lam .. | .forallE .. => return .done e
| .mdata _ e => return .continue e
| _ => return .continue e
Core.transform e (pre := pre)
/--
Converts nested `Expr.proj`s into projection applications if possible.
-/
def foldProjs (e : Expr) : MetaM Expr := do
let post (e : Expr) := do
let .proj structName idx s := e | return .done e
let some info := getStructureInfo? ( getEnv) structName | return .done e
if h : idx < info.fieldNames.size then
let fieldName := info.fieldNames[idx]
return .done ( mkProjection s fieldName)
else
return .done e
Meta.transform e (post := post)
end Lean.Meta.Grind

View File

@@ -57,13 +57,13 @@ partial def transform {m} [Monad m] [MonadLiftT CoreM m] [MonadControlT CoreM m]
| .continue e? =>
let e := e?.getD e
match e with
| Expr.forallE _ d b _ => visitPost (e.updateForallE! ( visit d) ( visit b))
| Expr.lam _ d b _ => visitPost (e.updateLambdaE! ( visit d) ( visit b))
| Expr.letE _ t v b _ => visitPost (e.updateLet! ( visit t) ( visit v) ( visit b))
| Expr.app .. => e.withApp fun f args => do visitPost (mkAppN ( visit f) ( args.mapM visit))
| Expr.mdata _ b => visitPost (e.updateMData! ( visit b))
| Expr.proj _ _ b => visitPost (e.updateProj! ( visit b))
| _ => visitPost e
| .forallE _ d b _ => visitPost (e.updateForallE! ( visit d) ( visit b))
| .lam _ d b _ => visitPost (e.updateLambdaE! ( visit d) ( visit b))
| .letE _ t v b _ => visitPost (e.updateLet! ( visit t) ( visit v) ( visit b))
| .app .. => e.withApp fun f args => do visitPost (mkAppN ( visit f) ( args.mapM visit))
| .mdata _ b => visitPost (e.updateMData! ( visit b))
| .proj _ _ b => visitPost (e.updateProj! ( visit b))
| _ => visitPost e
visit input |>.run
def betaReduce (e : Expr) : CoreM Expr :=
@@ -99,19 +99,19 @@ partial def transform {m} [Monad m] [MonadLiftT MetaM m] [MonadControlT MetaM m]
| .continue e? => pure (e?.getD e)
let rec visitLambda (fvars : Array Expr) (e : Expr) : MonadCacheT ExprStructEq Expr m Expr := do
match e with
| Expr.lam n d b c =>
| .lam n d b c =>
withLocalDecl n c ( visit (d.instantiateRev fvars)) fun x =>
visitLambda (fvars.push x) b
| e => visitPost ( mkLambdaFVars (usedLetOnly := usedLetOnly) fvars ( visit (e.instantiateRev fvars)))
let rec visitForall (fvars : Array Expr) (e : Expr) : MonadCacheT ExprStructEq Expr m Expr := do
match e with
| Expr.forallE n d b c =>
| .forallE n d b c =>
withLocalDecl n c ( visit (d.instantiateRev fvars)) fun x =>
visitForall (fvars.push x) b
| e => visitPost ( mkForallFVars (usedLetOnly := usedLetOnly) fvars ( visit (e.instantiateRev fvars)))
let rec visitLet (fvars : Array Expr) (e : Expr) : MonadCacheT ExprStructEq Expr m Expr := do
match e with
| Expr.letE n t v b _ =>
| .letE n t v b _ =>
withLetDecl n ( visit (t.instantiateRev fvars)) ( visit (v.instantiateRev fvars)) fun x =>
visitLet (fvars.push x) b
| e => visitPost ( mkLetFVars (usedLetOnly := usedLetOnly) fvars ( visit (e.instantiateRev fvars)))
@@ -127,28 +127,22 @@ partial def transform {m} [Monad m] [MonadLiftT MetaM m] [MonadControlT MetaM m]
| .continue e? =>
let e := e?.getD e
match e with
| Expr.forallE .. => visitForall #[] e
| Expr.lam .. => visitLambda #[] e
| Expr.letE .. => visitLet #[] e
| Expr.app .. => visitApp e
| Expr.mdata _ b => visitPost (e.updateMData! ( visit b))
| Expr.proj _ _ b => visitPost (e.updateProj! ( visit b))
| _ => visitPost e
| .forallE .. => visitForall #[] e
| .lam .. => visitLambda #[] e
| .letE .. => visitLet #[] e
| .app .. => visitApp e
| .mdata _ b => visitPost (e.updateMData! ( visit b))
| .proj _ _ b => visitPost (e.updateProj! ( visit b))
| _ => visitPost e
visit input |>.run
-- TODO: add options to distinguish zeta and zetaDelta reduction
def zetaReduce (e : Expr) : MetaM Expr := do
let pre (e : Expr) : MetaM TransformStep := do
match e with
| Expr.fvar fvarId =>
match ( getLCtx).find? fvarId with
| none => return TransformStep.done e
| some localDecl =>
if let some value := localDecl.value? then
return TransformStep.visit ( instantiateMVars value)
else
return TransformStep.done e
| _ => return .continue
let .fvar fvarId := e | return .continue
let some localDecl := ( getLCtx).find? fvarId | return .done e
let some value := localDecl.value? | return .done e
return .visit ( instantiateMVars value)
transform e (pre := pre) (usedLetOnly := true)
/--
@@ -161,10 +155,9 @@ def zetaDeltaFVars (e : Expr) (fvars : Array FVarId) : MetaM Expr :=
else
return none
let pre (e : Expr) : MetaM TransformStep := do
if let .fvar fvarId := e.getAppFn then
if let some val unfold? fvarId then
return .visit <| ( instantiateMVars val).beta e.getAppArgs
return .continue
let .fvar fvarId := e.getAppFn | return .continue
let some val unfold? fvarId | return .continue
return .visit <| ( instantiateMVars val).beta e.getAppArgs
transform e (pre := pre)
/-- Unfold definitions and theorems in `e` that are not in the current environment, but are in `biggerEnv`. -/
@@ -173,25 +166,20 @@ def unfoldDeclsFrom (biggerEnv : Environment) (e : Expr) : CoreM Expr := do
let env getEnv
setEnv biggerEnv -- `e` has declarations from `biggerEnv` that are not in `env`
let pre (e : Expr) : CoreM TransformStep := do
match e with
| Expr.const declName us .. =>
if env.contains declName then
return TransformStep.done e
else if let some info := biggerEnv.find? declName then
if info.hasValue then
return TransformStep.visit ( instantiateValueLevelParams info us)
else
return TransformStep.done e
else
return TransformStep.done e
| _ => return .continue
let .const declName us := e | return .continue
if env.contains declName then
return .done e
let some info := biggerEnv.find? declName | return .done e
if info.hasValue then
return .visit ( instantiateValueLevelParams info us)
else
return .done e
Core.transform e (pre := pre)
def eraseInaccessibleAnnotations (e : Expr) : CoreM Expr :=
Core.transform e (post := fun e => return TransformStep.done <| if let some e := inaccessible? e then e else e)
Core.transform e (post := fun e => return .done <| if let some e := inaccessible? e then e else e)
def erasePatternRefAnnotations (e : Expr) : CoreM Expr :=
Core.transform e (post := fun e => return TransformStep.done <| if let some (_, e) := patternWithRef? e then e else e)
Core.transform e (post := fun e => return .done <| if let some (_, e) := patternWithRef? e then e else e)
end Meta
end Lean
end Lean.Meta

View File

@@ -0,0 +1,31 @@
import Lean.Meta.Tactic.Grind
import Lean.Elab.Tactic
structure A (α : Type u) where
x : α
y : α
f : α α
structure B (α : Type u) extends A α where
z : α
b : Bool
open Lean Meta Elab Tactic
elab "fold_projs" : tactic => liftMetaTactic1 fun mvarId => do
mvarId.replaceTargetDefEq ( Grind.foldProjs ( mvarId.getType))
example (a : Nat × Bool) : a.fst = 10 := by
unfold Prod.fst
fail_if_success guard_target = a.fst = 10
fold_projs
guard_target = a.fst = 10
sorry
example (b : B (List Nat)) : b.y = [] := by
unfold B.toA
unfold A.y
fail_if_success unfold A.y
fail_if_success guard_target = b.y = []
fold_projs
guard_target = b.y = []
sorry