Compare commits

...

4 Commits

Author SHA1 Message Date
Leonardo de Moura
9c25a94e7c feat: add PreInstanceSet 2024-12-31 11:48:24 -08:00
Leonardo de Moura
f441a55495 feat: cleanup enode and parents map 2024-12-31 11:48:24 -08:00
Leonardo de Moura
1e66ead216 chore: move configuration getters to Types.lean 2024-12-31 11:48:24 -08:00
Leonardo de Moura
67fb33e05a refactor: avoid internalizing e-match new instances in the EMatch.lean module.
Reason: It does not use the correct monad. We want to preprocess the
new instances before internalizing them, and preprocessing may trigger
the creation of new goals.
2024-12-31 11:48:24 -08:00
4 changed files with 124 additions and 61 deletions

View File

@@ -8,14 +8,14 @@ import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Internalize
namespace Lean.Meta.Grind
/-- Returns maximum term generation that is considered during ematching -/
private def getMaxGeneration : GoalM Nat := do
return 10000 -- TODO
/-- Returns `true` if the maximum number of instances has been reached. -/
private def checkMaxInstancesExceeded : GoalM Bool := do
return false -- TODO
/--
Theorem instance found using E-matching.
Recall that we only internalize new instances after we complete a full round of E-matching. -/
structure EMatchTheoremInstance where
proof : Expr
prop : Expr
generation : Nat
deriving Inhabited
namespace EMatch
/-! This module implements a simple E-matching procedure as a backtracking search. -/
@@ -51,13 +51,6 @@ structure Choice where
assignment : Array Expr
deriving Inhabited
/-- Theorem instances found so far. We only internalize them after we complete a full round of E-matching. -/
structure TheoremInstance where
proof : Expr
prop : Expr
generation : Nat
deriving Inhabited
/-- Context for the E-matching monad. -/
structure Context where
/-- `useMT` is `true` if we are using the mod-time optimization. It is always set to false for new `EMatchTheorem`s. -/
@@ -70,7 +63,7 @@ structure Context where
structure State where
/-- Choices that still have to be processed. -/
choiceStack : List Choice := []
newInstances : PArray TheoremInstance := {}
newInstances : Array EMatchTheoremInstance := #[]
deriving Inhabited
abbrev M := ReaderT Context $ StateRefT State GoalM
@@ -181,6 +174,8 @@ Missing parameters are synthesized using type inference and type class synthesis
-/
private partial def instantiateTheorem (c : Choice) : M Unit := withDefault do
let thm := ( read).thm
unless ( markTheorenInstance thm.proof c.assignment) do
return ()
trace[grind.ematch.instance.assignment] "{← thm.origin.pp}: {assignmentToMessageData c.assignment}"
let proof thm.getProofWithFreshMVarLevels
let numParams := thm.numParams
@@ -285,22 +280,26 @@ where
def ematchTheorems (thms : PArray EMatchTheorem) : M Unit := do
thms.forM ematchTheorem
def internalizeNewInstances : M Unit := do
-- TODO
return ()
end EMatch
open EMatch
/-- Performs one round of E-matching, and internalizes new instances. -/
def ematch : GoalM Unit := do
let go (thms newThms : PArray EMatchTheorem) : EMatch.M Unit := do
/-- Performs one round of E-matching, and returns new instances. -/
def ematch : GoalM (Array EMatchTheoremInstance) := do
let go (thms newThms : PArray EMatchTheorem) : EMatch.M (Array EMatchTheoremInstance) := do
withReader (fun ctx => { ctx with useMT := true }) <| ematchTheorems thms
withReader (fun ctx => { ctx with useMT := false }) <| ematchTheorems newThms
internalizeNewInstances
unless ( checkMaxInstancesExceeded) do
go ( get).thms ( get).newThms |>.run'
modify fun s => { s with thms := s.thms ++ s.newThms, newThms := {}, gmt := s.gmt + 1 }
return ( get).newInstances
if ( checkMaxInstancesExceeded) then
return #[]
else
let insts go ( get).thms ( get).newThms |>.run'
modify fun s => { s with
thms := s.thms ++ s.newThms
newThms := {}
gmt := s.gmt + 1
numInstances := s.numInstances + insts.size
}
return insts
end Lean.Meta.Grind

View File

@@ -59,7 +59,10 @@ private partial def activateTheoremPatterns (fName : Name) (generation : Nat) :
let thm := { thm with symbols }
match symbols with
| [] =>
let thm := { thm with patterns := ( thm.patterns.mapM (internalizePattern · generation)) }
-- Recall that we use the proof as part of the key for a set of instances found so far.
-- We don't want to use structural equality when comparing keys.
let proof shareCommon thm.proof
let thm := { thm with proof, patterns := ( thm.patterns.mapM (internalizePattern · generation)) }
trace[grind.ematch] "activated `{thm.origin.key}`, {thm.patterns.map ppPattern}"
modify fun s => { s with newThms := s.newThms.push thm }
| _ =>

View File

@@ -144,7 +144,7 @@ def preprocess (mvarId : MVarId) : PreM State := do
loop ( mkGoal mvarId)
let goals := ( get).goals
-- Testing `ematch` module here. We will rewrite this part later.
let goals goals.mapM fun goal => GoalM.run' goal ematch
let goals goals.mapM fun goal => GoalM.run' goal (discard <| ematch)
if ( isTracingEnabledFor `grind.pre) then
trace[grind.debug.pre] ( ppGoals goals)
for goal in goals do

View File

@@ -101,6 +101,11 @@ def getMainDeclName : GrindM Name :=
@[inline] def getMethodsRef : GrindM MethodsRef :=
read
/--
Returns maximum term generation that is considered during ematching. -/
def getMaxGeneration : GrindM Nat := do
return 10000 -- TODO
/--
Abtracts nested proofs in `e`. This is a preprocessing step performed before internalization.
-/
@@ -193,31 +198,44 @@ structure NewEq where
proof : Expr
isHEq : Bool
abbrev ENodes := PHashMap USize ENode
/--
Key for the `ENodeMap` and `ParentMap` map.
We use pointer addresses and rely on the fact all internalized expressions
have been hash-consed, i.e., we have applied `shareCommon`.
-/
private structure ENodeKey where
expr : Expr
structure CongrKey (enodes : ENodes) where
instance : Hashable ENodeKey where
hash k := unsafe (ptrAddrUnsafe k.expr).toUInt64
instance : BEq ENodeKey where
beq k₁ k₂ := isSameExpr k₁.expr k₂.expr
abbrev ENodeMap := PHashMap ENodeKey ENode
/--
Key for the congruence table.
We need access to the `enodes` to be able to retrieve the equivalence class roots.
-/
structure CongrKey (enodes : ENodeMap) where
e : Expr
private abbrev toENodeKey (e : Expr) : USize :=
unsafe ptrAddrUnsafe e
private def hashRoot (enodes : ENodes) (e : Expr) : UInt64 :=
if let some node := enodes.find? (toENodeKey e) then
toENodeKey node.root |>.toUInt64
private def hashRoot (enodes : ENodeMap) (e : Expr) : UInt64 :=
if let some node := enodes.find? { expr := e } then
unsafe (ptrAddrUnsafe node.root).toUInt64
else
13
private def hasSameRoot (enodes : ENodes) (a b : Expr) : Bool := Id.run do
let ka := toENodeKey a
let kb := toENodeKey b
if ka == kb then
private def hasSameRoot (enodes : ENodeMap) (a b : Expr) : Bool := Id.run do
if isSameExpr a b then
return true
else
let some n1 := enodes.find? ka | return false
let some n2 := enodes.find? kb | return false
toENodeKey n1.root == toENodeKey n2.root
let some n1 := enodes.find? { expr := a } | return false
let some n2 := enodes.find? { expr := b } | return false
isSameExpr n1.root n2.root
def congrHash (enodes : ENodes) (e : Expr) : UInt64 :=
def congrHash (enodes : ENodeMap) (e : Expr) : UInt64 :=
if e.isAppOfArity ``Lean.Grind.nestedProof 2 then
-- We only hash the proposition
hashRoot enodes (e.getArg! 0)
@@ -229,7 +247,7 @@ where
| .app f a => go f (mixHash r (hashRoot enodes a))
| _ => mixHash r (hashRoot enodes e)
partial def isCongruent (enodes : ENodes) (a b : Expr) : Bool :=
partial def isCongruent (enodes : ENodeMap) (a b : Expr) : Bool :=
if a.isAppOfArity ``Lean.Grind.nestedProof 2 && b.isAppOfArity ``Lean.Grind.nestedProof 2 then
hasSameRoot enodes (a.getArg! 0) (b.getArg! 0)
else
@@ -249,15 +267,43 @@ instance : Hashable (CongrKey enodes) where
instance : BEq (CongrKey enodes) where
beq k1 k2 := isCongruent enodes k1.e k2.e
abbrev CongrTable (enodes : ENodes) := PHashSet (CongrKey enodes)
abbrev CongrTable (enodes : ENodeMap) := PHashSet (CongrKey enodes)
-- Remark: we cannot use pointer addresses here because we have to traverse the tree.
abbrev ParentSet := RBTree Expr Expr.quickComp
abbrev ParentMap := PHashMap USize ParentSet
abbrev ParentMap := PHashMap ENodeKey ParentSet
/--
The E-matching module instantiates theorems using the `EMatchTheorem proof` and a (partial) assignment.
We want to avoid instantiating the same theorem with the same assignment more than once.
Therefore, we store the (pre-)instance information in set.
Recall that the proofs of activated theorems have been hash-consed.
The assignment contains internalized expressions, which have also been hash-consed.
-/
structure PreInstance where
proof : Expr
assignment : Array Expr
instance : Hashable PreInstance where
hash i := Id.run do
let mut r := unsafe (ptrAddrUnsafe i.proof >>> 3).toUInt64
for v in i.assignment do
r := mixHash r (unsafe (ptrAddrUnsafe v >>> 3).toUInt64)
return r
instance : BEq PreInstance where
beq i₁ i₂ := Id.run do
unless isSameExpr i₁.proof i₂.proof do return false
unless i₁.assignment.size == i₂.assignment.size do return false
for v₁ in i₁.assignment, v₂ in i₂.assignment do
unless isSameExpr v₁ v₂ do return false
return true
abbrev PreInstanceSet := PHashSet PreInstance
structure Goal where
mvarId : MVarId
enodes : ENodes := {}
enodes : ENodeMap := {}
parents : ParentMap := {}
congrTable : CongrTable enodes := {}
/--
@@ -285,6 +331,8 @@ structure Goal where
thmMap : EMatchTheorems
/-- Number of theorem instances generated so far -/
numInstances : Nat := 0
/-- (pre-)instances found so far -/
instances : PreInstanceSet := {}
deriving Inhabited
def Goal.admit (goal : Goal) : MetaM Unit :=
@@ -294,6 +342,21 @@ abbrev GoalM := StateRefT Goal GrindM
abbrev Propagator := Expr GoalM Unit
/--
A helper function used to mark a theorem instance found by the E-matching module.
It returns `true` if it is a new instance and `false` otherwise.
-/
def markTheorenInstance (proof : Expr) (assignment : Array Expr) : GoalM Bool := do
let k := { proof, assignment }
if ( get).instances.contains k then
return false
modify fun s => { s with instances := s.instances.insert k }
return true
/-- Returns `true` if the maximum number of instances has been reached. -/
def checkMaxInstancesExceeded : GoalM Bool := do
return false -- TODO
/-- Returns `true` if `e` is the internalized `True` expression. -/
def isTrueExpr (e : Expr) : GrindM Bool :=
return isSameExpr e ( getTrueExpr)
@@ -307,11 +370,11 @@ Returns `some n` if `e` has already been "internalized" into the
Otherwise, returns `none`s.
-/
def getENode? (e : Expr) : GoalM (Option ENode) :=
return ( get).enodes.find? (unsafe ptrAddrUnsafe e)
return ( get).enodes.find? { expr := e }
/-- Returns node associated with `e`. It assumes `e` has already been internalized. -/
def getENode (e : Expr) : GoalM ENode := do
let some n := ( get).enodes.find? (unsafe ptrAddrUnsafe e)
let some n := ( get).enodes.find? { expr := e }
| throwError "internal `grind` error, term has not been internalized{indentExpr e}"
return n
@@ -362,7 +425,7 @@ def getNext (e : Expr) : GoalM Expr :=
/-- Returns `true` if `e` has already been internalized. -/
def alreadyInternalized (e : Expr) : GoalM Bool :=
return ( get).enodes.contains (unsafe ptrAddrUnsafe e)
return ( get).enodes.contains { expr := e }
def getTarget? (e : Expr) : GoalM (Option Expr) := do
let some n getENode? e | return none
@@ -407,9 +470,8 @@ information in the root (aka canonical representative) of `child`.
-/
def registerParent (parent : Expr) (child : Expr) : GoalM Unit := do
let some childRoot getRoot? child | return ()
let key := toENodeKey childRoot
let parents := if let some parents := ( get).parents.find? key then parents else {}
modify fun s => { s with parents := s.parents.insert key (parents.insert parent) }
let parents := if let some parents := ( get).parents.find? { expr := childRoot } then parents else {}
modify fun s => { s with parents := s.parents.insert { expr := childRoot } (parents.insert parent) }
/--
Returns the set of expressions `e` is a child of, or an expression in
@@ -417,7 +479,7 @@ Returns the set of expressions `e` is a child of, or an expression in
The information is only up to date if `e` is the root (aka canonical representative) of the equivalence class.
-/
def getParents (e : Expr) : GoalM ParentSet := do
let some parents := ( get).parents.find? (toENodeKey e) | return {}
let some parents := ( get).parents.find? { expr := e } | return {}
return parents
/--
@@ -425,7 +487,7 @@ Similar to `getParents`, but also removes the entry `e ↦ parents` from the par
-/
def getParentsAndReset (e : Expr) : GoalM ParentSet := do
let parents getParents e
modify fun s => { s with parents := s.parents.erase (toENodeKey e) }
modify fun s => { s with parents := s.parents.erase { expr := e } }
return parents
/--
@@ -433,15 +495,14 @@ Copy `parents` to the parents of `root`.
`root` must be the root of its equivalence class.
-/
def copyParentsTo (parents : ParentSet) (root : Expr) : GoalM Unit := do
let key := toENodeKey root
let mut curr := if let some parents := ( get).parents.find? key then parents else {}
let mut curr := if let some parents := ( get).parents.find? { expr := root } then parents else {}
for parent in parents do
curr := curr.insert parent
modify fun s => { s with parents := s.parents.insert key curr }
modify fun s => { s with parents := s.parents.insert { expr := root } curr }
def setENode (e : Expr) (n : ENode) : GoalM Unit :=
modify fun s => { s with
enodes := s.enodes.insert (unsafe ptrAddrUnsafe e) n
enodes := s.enodes.insert { expr := e } n
congrTable := unsafe unsafeCast s.congrTable
}