Compare commits

...

4 Commits

Author SHA1 Message Date
Leonardo de Moura
9b68268dbf feat: add simple union-find to grind 2024-05-21 20:34:02 -07:00
Leonardo de Moura
e4b43e0c41 feat: add isLitValue 2024-05-21 19:52:47 -07:00
Leonardo de Moura
73ba9a4968 feat: add MVarId.clearAuxDecls 2024-05-21 15:58:42 -07:00
Leonardo de Moura
0c8a727e8a feat: add Grind.Config and [grind_cases] annotations 2024-05-21 14:00:44 -07:00
10 changed files with 370 additions and 81 deletions

View File

@@ -7,3 +7,4 @@ prelude
import Init.Grind.Norm
import Init.Grind.Tactics
import Init.Grind.Lemmas
import Init.Grind.Cases

15
src/Init/Grind/Cases.lean Normal file
View File

@@ -0,0 +1,15 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Core
attribute [grind_cases] And Prod False Empty True Unit Exists
namespace Lean.Grind.Eager
attribute [scoped grind_cases] Or
end Lean.Grind.Eager

View File

@@ -7,8 +7,19 @@ prelude
import Init.Tactics
namespace Lean.Grind
/--
The configuration for `grind`.
Passed to `grind` using, for example, the `grind (config := { eager := true })` syntax.
-/
structure Config where
/--
When `eager` is true (default: `false`), `grind` eagerly splits `if-then-else` and `match`
expressions.
-/
eager : Bool := false
deriving Inhabited, BEq
/-!
`grind` tactic and related tactics.
-/
end Lean.Grind

View File

@@ -99,6 +99,8 @@ def getUInt64Value? (e : Expr) : MetaM (Option UInt64) := OptionT.run do
let (n, _) getOfNatValue? e ``UInt64
return UInt64.ofNat n
-- TODO: extensibility
/--
If `e` is a literal value, ensure it is encoded using the standard representation.
Otherwise, just return `e`.
@@ -117,6 +119,23 @@ def normLitValue (e : Expr) : MetaM Expr := do
if let some n getUInt64Value? e then return toExpr n
return e
/--
Returns `true` if `e` is a literal value.
-/
def isLitValue (e : Expr) : MetaM Bool := do
let e instantiateMVars e
if ( getNatValue? e).isSome then return true
if ( getIntValue? e).isSome then return true
if ( getFinValue? e).isSome then return true
if ( getBitVecValue? e).isSome then return true
if (getStringValue? e).isSome then return true
if ( getCharValue? e).isSome then return true
if ( getUInt8Value? e).isSome then return true
if ( getUInt16Value? e).isSome then return true
if ( getUInt32Value? e).isSome then return true
if ( getUInt64Value? e).isSome then return true
return false
/--
If `e` is a `Nat`, `Int`, or `Fin` literal value, converts it into a constructor application.
Otherwise, just return `e`.

View File

@@ -11,3 +11,4 @@ import Lean.Meta.Tactic.Grind.Preprocessor
import Lean.Meta.Tactic.Grind.Util
import Lean.Meta.Tactic.Grind.Cases
import Lean.Meta.Tactic.Grind.Injection
import Lean.Meta.Tactic.Grind.Core

View File

@@ -0,0 +1,157 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.LitValues
namespace Lean.Meta.Grind
/--
Returns `true` if `e` is `True`, `False`, or a literal value.
See `LitValues` for supported literals.
-/
def isInterpreted (e : Expr) : MetaM Bool := do
if e.isTrue || e.isFalse then return true
isLitValue e
/--
Creates an `ENode` for `e` if one does not already exist.
This method assumes `e` has been hashconsed.
-/
def mkENode (e : Expr) (generation : Nat := 0) : GoalM Unit := do
if ( getENode? e).isSome then return ()
let ctor := ( isConstructorAppCore? e).isSome
let interpreted isInterpreted e
mkENodeCore e interpreted ctor generation
/--
Returns the root element in the equivalence class of `e`.
-/
def getRoot (e : Expr) : GoalM Expr := do
let some n getENode? e | return e
return n.root
/--
Returns the next element in the equivalence class of `e`.
-/
def getNext (e : Expr) : GoalM Expr := do
let some n getENode? e | return e
return n.next
@[inline] def isSameExpr (a b : Expr) : Bool :=
-- It is safe to use pointer equality because we hashcons all expressions
-- inserted into the E-graph
unsafe ptrEq a b
private def pushNewEqCore (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit :=
modify fun s => { s with newEqs := s.newEqs.push { lhs, rhs, proof, isHEq } }
@[inline] private def pushNewEq (lhs rhs proof : Expr) : GoalM Unit :=
pushNewEqCore lhs rhs proof (isHEq := false)
@[inline] private def pushNewHEq (lhs rhs proof : Expr) : GoalM Unit :=
pushNewEqCore lhs rhs proof (isHEq := true)
/--
The fields `target?` and `proof?` in `e`'s `ENode` are encoding a transitivity proof
from `e` to the root of the equivalence class
This method "inverts" the proof, and makes it to go from the root of the equivalence class to `e`.
We use this method when merging two equivalence classes.
-/
private partial def invertTrans (e : Expr) : GoalM Unit := do
go e false none none
where
go (e : Expr) (flippedNew : Bool) (targetNew? : Option Expr) (proofNew? : Option Expr) : GoalM Unit := do
let some node getENode? e | unreachable!
if let some target := node.target? then
go target (!node.flipped) (some e) node.proof?
setENode e { node with
target? := targetNew?
flipped := flippedNew
proof? := proofNew?
}
private def markAsInconsistent : GoalM Unit :=
modify fun s => { s with inconsistent := true }
private partial def addEqStep (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit := do
let some lhsNode getENode? lhs | return () -- `lhs` has not been internalized yet
let some rhsNode getENode? rhs | return () -- `rhs` has not been internalized yet
if isSameExpr lhsNode.root rhsNode.root then return () -- `lhs` and `rhs` are already in the same equivalence class.
let some lhsRoot getENode? lhsNode.root | unreachable!
let some rhsRoot getENode? rhsNode.root | unreachable!
if (lhsRoot.interpreted && !rhsRoot.interpreted)
|| (lhsRoot.ctor && !rhsRoot.ctor)
|| (lhsRoot.size > rhsRoot.size && !rhsRoot.interpreted && !rhsRoot.ctor) then
go rhs lhs rhsNode lhsNode rhsRoot lhsRoot true
else
go lhs rhs lhsNode rhsNode lhsRoot rhsRoot false
where
go (lhs rhs : Expr) (lhsNode rhsNode lhsRoot rhsRoot : ENode) (flipped : Bool) : GoalM Unit := do
let mut valueInconsistency := false
if lhsRoot.interpreted && rhsRoot.interpreted then
if lhsNode.root.isTrue || rhsNode.root.isTrue then
markAsInconsistent
else
valueInconsistency := true
-- TODO: process valueInconsistency := true
/-
We have the following `target?/proof?`
`lhs -> ... -> lhsNode.root`
`rhs -> ... -> rhsNode.root`
We want to convert it to
`lhsNode.root -> ... -> lhs -*-> rhs -> ... -> rhsNode.root`
where step `-*->` is justified by `proof` (or `proof.symm` if `flipped := true`)
-/
invertTrans lhs
setENode lhs { lhsNode with
target? := rhs
proof? := proof
flipped
}
-- TODO: Remove parents from congruence table
-- TODO: set propagateBool
updateRoots lhs rhsNode.root true -- TODO
-- TODO: Reinsert parents into congruence table
setENode lhsNode.root { lhsRoot with
next := rhsRoot.next
}
setENode rhsNode.root { rhsRoot with
next := lhsRoot.next
size := rhsRoot.size + lhsRoot.size
hasLambdas := rhsRoot.hasLambdas || lhsRoot.hasLambdas
heqProofs := isHEq || rhsRoot.heqProofs || lhsRoot.heqProofs
}
-- TODO: copy parentst from lhsRoot parents to rhsRoot parents
updateRoots (lhs : Expr) (rootNew : Expr) (_propagateBool : Bool) : GoalM Unit := do
let rec loop (e : Expr) : GoalM Unit := do
-- TODO: propagateBool
let some n getENode? e | unreachable!
setENode e { n with root := rootNew }
if isSameExpr lhs n.next then return ()
loop n.next
loop lhs
partial def addEqCore (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit := do
addEqStep lhs rhs proof isHEq
processTodo
where
processTodo : GoalM Unit := do
if ( get).inconsistent then
modify fun s => { s with newEqs := #[] }
return ()
let some { lhs, rhs, proof, isHEq } := ( get).newEqs.back? | return ()
addEqStep lhs rhs proof isHEq
processTodo
def addEq (lhs rhs proof : Expr) : GoalM Unit := do
addEqCore lhs rhs proof false
def addHEq (lhs rhs proof : Expr) : GoalM Unit := do
addEqCore lhs rhs proof true
end Lean.Meta.Grind

View File

@@ -15,6 +15,7 @@ import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Util
import Lean.Meta.Tactic.Grind.Cases
import Lean.Meta.Tactic.Grind.Injection
import Lean.Meta.Tactic.Grind.Core
namespace Lean.Meta.Grind
namespace Preprocessor
@@ -29,6 +30,7 @@ structure Context where
structure State where
simpStats : Simp.Stats := {}
goals : PArray Goal := {}
deriving Inhabited
abbrev PreM := ReaderT Context $ StateRefT State GrindM
@@ -43,18 +45,13 @@ def PreM.run (x : PreM α) : GrindM α := do
}
x { simp, simprocs } |>.run' {}
def simp (e : Expr) : PreM Simp.Result := do
def simp (_goal : Goal) (e : Expr) : PreM Simp.Result := do
-- TODO: use `goal` state in the simplifier
let simpStats := ( get).simpStats
let (r, simpStats) Meta.simp e ( read).simp ( read).simprocs (stats := simpStats)
modify fun s => { s with simpStats }
return r
def simpHyp? (mvarId : MVarId) (fvarId : FVarId) : PreM (Option (FVarId × MVarId)) := do
let simpStats := ( get).simpStats
let (result, simpStats) simpLocalDecl mvarId fvarId ( read).simp ( read).simprocs (stats := simpStats)
modify fun s => { s with simpStats }
return result
inductive IntroResult where
| done
| newHyp (fvarId : FVarId) (goal : Goal)
@@ -72,7 +69,7 @@ def introNext (goal : Goal) : PreM IntroResult := do
else
let tag goal.mvarId.getTag
let q := target.bindingBody!
let r simp p
let r simp goal p
let p' := r.expr
let p' canon p'
let p' shareCommon p'
@@ -105,7 +102,7 @@ def introNext (goal : Goal) : PreM IntroResult := do
return .done
def pushResult (goal : Goal) : PreM Unit :=
modifyThe Grind.State fun s => { s with goals := s.goals.push goal }
modify fun s => { s with goals := s.goals.push goal }
def isCasesCandidate (fvarId : FVarId) : MetaM Bool := do
let .const declName _ := ( fvarId.getType).getAppFn | return false
@@ -124,42 +121,47 @@ def applyInjection? (goal : Goal) (fvarId : FVarId) : MetaM (Option Goal) := do
else
return none
partial def preprocess (goal : Goal) : PreM Unit := do
partial def loop (goal : Goal) : PreM Unit := do
match ( introNext goal) with
| .done =>
if let some mvarId goal.mvarId.byContra? then
preprocess { goal with mvarId }
loop { goal with mvarId }
else
pushResult goal
| .newHyp fvarId goal =>
if let some goals applyCases? goal fvarId then
goals.forM preprocess
goals.forM loop
else if let some goal applyInjection? goal fvarId then
preprocess goal
loop goal
else
let clause goal.mvarId.withContext do mkInputClause fvarId
preprocess { goal with clauses := goal.clauses.push clause }
loop { goal with clauses := goal.clauses.push clause }
| .newDepHyp goal =>
preprocess goal
loop goal
| .newLocal fvarId goal =>
if let some goals applyCases? goal fvarId then
goals.forM preprocess
goals.forM loop
else
preprocess goal
loop goal
def preprocess (mvarId : MVarId) : PreM State := do
loop ( mkGoal mvarId)
get
end Preprocessor
open Preprocessor
partial def main (mvarId : MVarId) (mainDeclName : Name) : MetaM Grind.State := do
partial def main (mvarId : MVarId) (mainDeclName : Name) : MetaM (List MVarId) := do
mvarId.ensureProp
mvarId.ensureNoMVar
let mvarId mvarId.clearAuxDecls
let mvarId mvarId.revertAll
mvarId.ensureNoMVar
let mvarId mvarId.abstractNestedProofs mainDeclName
let mvarId mvarId.unfoldReducible
let mvarId mvarId.betaReduce
let s (preprocess { mvarId } *> getThe Grind.State) |>.run |>.run mainDeclName
return s
let s preprocess mvarId |>.run |>.run mainDeclName
return s.goals.toList.map (·.mvarId)
end Lean.Meta.Grind

View File

@@ -11,6 +11,47 @@ import Lean.Meta.Canonicalizer
import Lean.Meta.Tactic.Util
namespace Lean.Meta.Grind
structure Context where
mainDeclName : Name
structure State where
canon : Canonicalizer.State := {}
/-- `ShareCommon` (aka `Hashconsing`) state. -/
scState : ShareCommon.State.{0} ShareCommon.objectFactory := ShareCommon.State.mk _
/-- Next index for creating auxiliary theorems. -/
nextThmIdx : Nat := 1
abbrev GrindM := ReaderT Context $ StateRefT State MetaM
@[inline] def GrindM.run (x : GrindM α) (mainDeclName : Name) : MetaM α :=
x { mainDeclName } |>.run' {}
def abstractNestedProofs (e : Expr) : GrindM Expr := do
let nextIdx := ( get).nextThmIdx
let (e, s') AbstractNestedProofs.visit e |>.run { baseName := ( read).mainDeclName } |>.run |>.run { nextIdx }
modify fun s => { s with nextThmIdx := s'.nextIdx }
return e
def shareCommon (e : Expr) : GrindM Expr := do
modifyGet fun { canon, scState, nextThmIdx } =>
let (e, scState) := ShareCommon.State.shareCommon scState e
(e, { canon, scState, nextThmIdx })
def canon (e : Expr) : GrindM Expr := do
let canonS modifyGet fun s => (s.canon, { s with canon := {} })
let (e, canonS) Canonicalizer.CanonM.run (canonRec e) (s := canonS)
modify fun s => { s with canon := canonS }
return e
where
canonRec (e : Expr) : CanonM Expr := do
let post (e : Expr) : CanonM TransformStep := do
if e.isApp then
return .done ( Meta.canon e)
else
return .done e
transform e post
/--
Stores information for a node in the egraph.
Each internalized expression `e` has an `ENode` associated with it.
@@ -43,6 +84,9 @@ structure ENode where
on heterogeneous equality.
-/
heqProofs : Bool := false
generation : Nat := 0
/-- Modification time -/
mt : Nat := 0
-- TODO: see Lean 3 implementation
structure Clause where
@@ -53,58 +97,57 @@ structure Clause where
def mkInputClause (fvarId : FVarId) : MetaM Clause :=
return { expr := ( fvarId.getType), proof := mkFVar fvarId }
structure Goal where
mvarId : MVarId
clauses : PArray Clause := {}
enodes : PHashMap UInt64 ENode := {}
-- TODO: occurrences for propagation, etc
deriving Inhabited
structure NewEq where
lhs : Expr
rhs : Expr
proof : Expr
isHEq : Bool
def mkGoal (mvarId : MVarId) : Goal :=
{ mvarId }
structure Goal where
mvarId : MVarId
clauses : PArray Clause := {}
enodes : PHashMap USize ENode := {}
newEqs : Array NewEq := #[]
/-- `inconsistent := true` if `ENode`s for `True` and `False` are in the same equivalence class. -/
inconsistent : Bool := false
/-- Goal modification time. -/
gmt : Nat := 0
deriving Inhabited
def Goal.admit (goal : Goal) : MetaM Unit :=
goal.mvarId.admit
structure Context where
mainDeclName : Name
abbrev GoalM := StateRefT Goal GrindM
structure State where
canon : Canonicalizer.State := {}
/-- `ShareCommon` (aka `Hashconsing`) state. -/
scState : ShareCommon.State.{0} ShareCommon.objectFactory := ShareCommon.State.mk _
/-- Next index for creating auxiliary theorems. -/
nextThmIdx : Nat := 1
goals : PArray Goal := {}
@[inline] def GoalM.run (goal : Goal) (x : GoalM α) : GrindM (α × Goal) :=
StateRefT'.run x goal
abbrev GrindM := ReaderT Context $ StateRefT State MetaM
@[inline] def GoalM.run' (goal : Goal) (x : GoalM Unit) : GrindM Goal :=
StateRefT'.run' (x *> get) goal
def GrindM.run (x : GrindM α) (mainDeclName : Name) : MetaM α :=
x { mainDeclName } |>.run' {}
/--
Returns `some n` if `e` has already been "internalized" into the
Otherwise, returns `none`s.
-/
def getENode? (e : Expr) : GoalM (Option ENode) :=
return ( get).enodes.find? (unsafe ptrAddrUnsafe e)
def abstractNestedProofs (e : Expr) : GrindM Expr := do
let nextIdx := ( get).nextThmIdx
let (e, s') AbstractNestedProofs.visit e |>.run { baseName := ( read).mainDeclName } |>.run |>.run { nextIdx }
modify fun s => { s with nextThmIdx := s'.nextIdx }
return e
def setENode (e : Expr) (n : ENode) : GoalM Unit :=
modify fun s => { s with enodes := s.enodes.insert (unsafe ptrAddrUnsafe e) n }
def shareCommon (e : Expr) : GrindM Expr := do
modifyGet fun { canon, scState, nextThmIdx, goals } =>
let (e, scState) := ShareCommon.State.shareCommon scState e
(e, { canon, scState, nextThmIdx, goals })
def mkENodeCore (e : Expr) (interpreted ctor : Bool) (generation : Nat) : GoalM Unit := do
setENode e {
next := e, root := e, cgRoot := e, size := 1
flipped := false
heqProofs := false
hasLambdas := e.isLambda
mt := ( get).gmt
interpreted, ctor, generation
}
def canon (e : Expr) : GrindM Expr := do
let canonS modifyGet fun s => (s.canon, { s with canon := {} })
let (e, canonS) Canonicalizer.CanonM.run (canonRec e) (s := canonS)
modify fun s => { s with canon := canonS }
return e
where
canonRec (e : Expr) : CanonM Expr := do
let post (e : Expr) : CanonM TransformStep := do
if e.isApp then
return .done ( Meta.canon e)
else
return .done e
transform e post
def mkGoal (mvarId : MVarId) : GrindM Goal := do
GoalM.run' { mvarId } do
mkENodeCore ( shareCommon (mkConst ``True)) (interpreted := true) (ctor := false) (generation := 0)
mkENodeCore ( shareCommon (mkConst ``False)) (interpreted := true) (ctor := false) (generation := 0)
end Lean.Meta.Grind

View File

@@ -6,6 +6,7 @@ Authors: Leonardo de Moura
prelude
import Lean.Meta.AbstractNestedProofs
import Lean.Meta.Tactic.Util
import Lean.Meta.Tactic.Clear
namespace Lean.Meta.Grind
/--
@@ -66,7 +67,7 @@ def _root_.Lean.MVarId.betaReduce (mvarId : MVarId) : MetaM MVarId :=
If the target is not `False`, apply `byContradiction`.
-/
def _root_.Lean.MVarId.byContra? (mvarId : MVarId) : MetaM (Option MVarId) := mvarId.withContext do
mvarId.checkNotAssigned `grind
mvarId.checkNotAssigned `grind.by_contra
let target mvarId.getType
if target.isFalse then return none
let targetNew mkArrow (mkNot target) (mkConst ``False)
@@ -75,4 +76,24 @@ def _root_.Lean.MVarId.byContra? (mvarId : MVarId) : MetaM (Option MVarId) := mv
mvarId.assign <| mkApp2 (mkConst ``Classical.byContradiction) target mvarNew
return mvarNew.mvarId!
/--
Clear auxiliary decls used to encode recursive declarations.
`grind` eliminates them to ensure they are not accidentaly used by its proof automation.
-/
def _root_.Lean.MVarId.clearAuxDecls (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
mvarId.checkNotAssigned `grind.clear_aux_decls
let mut toClear := []
for localDecl in ( getLCtx) do
if localDecl.isAuxDecl then
toClear := localDecl.fvarId :: toClear
if toClear.isEmpty then
return mvarId
let mut mvarId := mvarId
for fvarId in toClear do
try
mvarId mvarId.clear fvarId
catch _ =>
throwTacticEx `grind.clear_aux_decls mvarId "failed to clear local auxiliary declaration"
return mvarId
end Lean.Meta.Grind

View File

@@ -4,13 +4,29 @@ open Lean Meta Elab Tactic Grind in
elab "grind_pre" : tactic => do
let declName := ( Term.getDeclName?).getD `_main
liftMetaTactic fun mvarId => do
let result Meta.Grind.main mvarId declName
return result.goals.map (·.mvarId) |>.toList
Meta.Grind.main mvarId declName
abbrev f (a : α) := a
attribute [grind_cases] And Or
/--
warning: declaration uses 'sorry'
---
info: a b c : Bool
p q : Prop
left✝ : a = true
right✝ : b = true c = true
left : p
right : q
x✝ : b = false a = false
⊢ False
-/
#guard_msgs in
theorem ex (h : (f a && (b || f (f c))) = true) (h' : p q) : b && a := by
grind_pre
trace_state
all_goals sorry
open Lean.Grind.Eager in
/--
warning: declaration uses 'sorry'
---
@@ -51,11 +67,12 @@ h : a = false
⊢ False
-/
#guard_msgs in
theorem ex (h : (f a && (b || f (f c))) = true) (h' : p q) : b && a := by
theorem ex2 (h : (f a && (b || f (f c))) = true) (h' : p q) : b && a := by
grind_pre
trace_state
all_goals sorry
def g (i : Nat) (j : Nat) (_ : i > j := by omega) := i + j
example (i j : Nat) (h : i + 1 > j + 1) : g (i+1) j = f ((fun x => x) i) + f j + 1 := by
@@ -65,27 +82,29 @@ example (i j : Nat) (h : i + 1 > j + 1) : g (i+1) j = f ((fun x => x) i) + f j +
guard_hyp hn : ¬g (i + 1) j _ = i + j + 1
simp_arith [g] at hn
structure Point where
x : Nat
y : Int
/--
warning: declaration uses 'sorry'
---
info: α✝ : Type u_1
β✝ : Type u_2
a : α✝ × β✝
a : α✝
a₃ : β✝
as : List (α✝ × β✝)
b : α✝ × β✝
b : α✝
b₃ : β✝
bs : List (α✝ × β✝)
info: a₁ : Point
a₂ : Nat
a : Int
as : List Point
b₁ : Point
bs : List Point
b : Nat
b : Int
head_eq : a₁ = b₁
fst_eq : a₂ = b₂
snd_eq : a₃ = b₃
x_eq : a₂ = b₂
y_eq : a₃ = b₃
tail_eq : as = bs
⊢ False
-/
#guard_msgs in
theorem ex2 (h : a₁ :: (a₂, a₃) :: as = b₁ :: (b₂, b₃) :: bs) : False := by
theorem ex3 (h : a₁ :: { x := a₂, y := a₃ : Point } :: as = b₁ :: { x := b₂, y := b₃} :: bs) : False := by
grind_pre
trace_state
sorry