Compare commits

...

9 Commits

Author SHA1 Message Date
Leonardo de Moura
0953d3f983 fix: match-equation theorem activation in grind 2025-01-03 17:49:12 -08:00
Leonardo de Moura
b72eadfbc1 feat: activate match-equations automatically in the grind tactic 2025-01-03 17:28:14 -08:00
Leonardo de Moura
c029ced7a9 chore: cleanup 2025-01-03 16:38:01 -08:00
Leonardo de Moura
337a261871 chore: mark additional Exists theorems with [grind_norm] 2025-01-03 16:21:19 -08:00
Leonardo de Moura
c3231b5f74 chore: use MessageData.ofConst at trace[grind.ematch.pattern] 2025-01-03 16:02:34 -08:00
Leonardo de Moura
1068f45274 feat: activate match-equations during internalization 2025-01-03 16:02:19 -08:00
Leonardo de Moura
5417d86e5c feat: add mkEMatchTheorem and mkEMatchEqTheorem 2025-01-03 15:28:06 -08:00
Leonardo de Moura
7958f685fb refactor: EMatchTheorems 2025-01-03 15:19:57 -08:00
Leonardo de Moura
c9a8c86295 feat: add addEMatchEqTheorem 2025-01-03 14:39:03 -08:00
11 changed files with 224 additions and 28 deletions

View File

@@ -5,6 +5,7 @@ Authors: Leonardo de Moura
-/
prelude
import Init.SimpLemmas
import Init.PropLemmas
import Init.Classical
import Init.ByCases
@@ -64,7 +65,7 @@ attribute [grind_norm] forall_and
-- Exists
@[grind_norm] theorem not_exists (p : α Prop) : (¬ x, p x) = x, ¬p x := by simp
attribute [grind_norm] exists_const exists_or
attribute [grind_norm] exists_const exists_or exists_prop exists_and_left exists_and_right
-- Bool cond
@[grind_norm] theorem cond_eq_ite (c : Bool) (a b : α) : cond c a b = ite c a b := by

View File

@@ -28,6 +28,8 @@ structure Config where
gen : Nat := 5
/-- Maximum number of theorem instances generated using E-matching in a proof search tree branch. -/
instances : Nat := 1000
/-- If `matchEqs` is `true`, `grind` uses `match`-equations as E-matching theorems. -/
matchEqs : Bool := true
deriving Inhabited, BEq
end Lean.Grind

View File

@@ -11,6 +11,13 @@ namespace Lean.Grind
/-- A helper gadget for annotating nested proofs in goals. -/
def nestedProof (p : Prop) (h : p) : p := h
/--
Gadget for marking terms that should not be normalized by `grind`s simplifier.
`grind` uses a simproc to implement this feature.
We use it when adding instances of `match`-equations to prevent them from being simplified to true.
-/
def doNotSimp {α : Sort u} (a : α) : α := a
set_option pp.proofs true
theorem nestedProof_congr (p q : Prop) (h : p = q) (hp : p) (hq : q) : HEq (nestedProof p hp) (nestedProof q hq) := by

View File

@@ -0,0 +1,35 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Grind.Util
import Init.Simproc
import Lean.Meta.Tactic.Simp.Simproc
namespace Lean.Meta.Grind
/--
Returns `Grind.doNotSimp e`.
Recall that `Grind.doNotSimp` is an identity function, but the following simproc is used to prevent the term `e` from being simplified.
-/
def markAsDoNotSimp (e : Expr) : MetaM Expr :=
mkAppM ``Grind.doNotSimp #[e]
builtin_dsimproc_decl reduceDoNotSimp (Grind.doNotSimp _) := fun e => do
let_expr Grind.doNotSimp _ _ e | return .continue
return .done e
/-- Adds `reduceDoNotSimp` to `s` -/
def addDoNotSimp (s : Simprocs) : CoreM Simprocs := do
s.add ``reduceDoNotSimp (post := false)
/-- Erases `Grind.doNotSimp` annotations. -/
def eraseDoNotSimp (e : Expr) : CoreM Expr := do
let pre (e : Expr) := do
let_expr Grind.doNotSimp _ a := e | return .continue e
return .continue a
Core.transform e (pre := pre)
end Lean.Meta.Grind

View File

@@ -6,6 +6,7 @@ Authors: Leonardo de Moura
prelude
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Intro
import Lean.Meta.Tactic.Grind.DoNotSimp
namespace Lean.Meta.Grind
namespace EMatch
@@ -146,6 +147,15 @@ private def processContinue (c : Choice) (p : Expr) : M Unit := do
let c := { c with gen := Nat.max gen c.gen }
modify fun s => { s with choiceStack := c :: s.choiceStack }
/-- Helper function for marking parts of `match`-equation theorem as "do-not-simplify" -/
private partial def annotateMatchEqnType (prop : Expr) : M Expr := do
if let .forallE n d b bi := prop then
withLocalDecl n bi ( markAsDoNotSimp d) fun x => do
mkForallFVars #[x] ( annotateMatchEqnType (b.instantiate1 x))
else
let_expr f@Eq α lhs rhs := prop | return prop
return mkApp3 f α ( markAsDoNotSimp lhs) rhs
/--
Stores new theorem instance in the state.
Recall that new instances are internalized later, after a full round of ematching.
@@ -154,7 +164,9 @@ private def addNewInstance (origin : Origin) (proof : Expr) (generation : Nat) :
let proof instantiateMVars proof
if grind.debug.proofs.get ( getOptions) then
check proof
let prop inferType proof
let mut prop inferType proof
if Match.isMatchEqnTheorem ( getEnv) origin.key then
prop annotateMatchEqnType prop
trace[grind.ematch.instance] "{← origin.pp}: {prop}"
addTheoremInstance proof prop generation
@@ -189,10 +201,10 @@ private partial def instantiateTheorem (c : Choice) : M Unit := withDefault do w
unless ( synthesizeInstance mvar type) do
trace[grind.issues] "failed to synthesize instance when instantiating {← thm.origin.pp}{indentExpr type}"
return ()
let proof := mkAppN proof mvars
if ( mvars.allM (·.mvarId!.isAssigned)) then
addNewInstance thm.origin (mkAppN proof mvars) c.gen
addNewInstance thm.origin proof c.gen
else
let proof := mkAppN proof mvars
let mvars mvars.filterM fun mvar => return !( mvar.mvarId!.isAssigned)
if let some mvarBad mvars.findM? fun mvar => return !( isProof mvar) then
trace[grind.issues] "failed to instantiate {← thm.origin.pp}, failed to instantiate non propositional argument with type{indentExpr (← inferType mvarBad)}"

View File

@@ -56,17 +56,47 @@ structure EMatchTheorem where
origin : Origin
deriving Inhabited
/-- The key is a symbol from `EMatchTheorem.symbols`. -/
abbrev EMatchTheorems := PHashMap Name (List EMatchTheorem)
/-- Set of E-matching theorems. -/
structure EMatchTheorems where
/-- The key is a symbol from `EMatchTheorem.symbols`. -/
private map : PHashMap Name (List EMatchTheorem) := {}
/-- Set of theorem names that have been inserted using `insert`. -/
private thmNames : PHashSet Name := {}
deriving Inhabited
/--
Inserts a `thm` with symbols `[s_1, ..., s_n]` to `s`.
We add `s_1 -> { thm with symbols := [s_2, ..., s_n] }`.
When `grind` internalizes a term containing symbol `s`, we
process all theorems `thm` associated with key `s`.
If their `thm.symbols` is empty, we say they are activated.
Otherwise, we reinsert into `map`.
-/
def EMatchTheorems.insert (s : EMatchTheorems) (thm : EMatchTheorem) : EMatchTheorems := Id.run do
let .const declName :: syms := thm.symbols
| unreachable!
let thm := { thm with symbols := syms }
if let some thms := s.find? declName then
return PersistentHashMap.insert s declName (thm::thms)
let { map, thmNames } := s
let thmNames := thmNames.insert thm.origin.key
if let some thms := map.find? declName then
return { map := map.insert declName (thm::thms), thmNames }
else
return PersistentHashMap.insert s declName [thm]
return { map := map.insert declName [thm], thmNames }
/--
Retrieves theorems from `s` associated with the given symbol. See `EMatchTheorem.insert`.
The theorems are removed from `s`.
-/
@[inline]
def EMatchTheorems.retrieve? (s : EMatchTheorems) (sym : Name) : Option (List EMatchTheorem × EMatchTheorems) :=
if let some thms := s.map.find? sym then
some (thms, { s with map := s.map.erase sym })
else
none
/-- Returns `true` if `declName` is the name of a theorem that was inserted using `insert`. -/
def EMatchTheorems.containsTheoremName (s : EMatchTheorems) (declName : Name) : Bool :=
s.thmNames.contains declName
def EMatchTheorem.getProofWithFreshMVarLevels (thm : EMatchTheorem) : MetaM Expr := do
if thm.proof.isConst && thm.levelParams.isEmpty then
@@ -85,7 +115,7 @@ def EMatchTheorem.getProofWithFreshMVarLevels (thm : EMatchTheorem) : MetaM Expr
private builtin_initialize ematchTheoremsExt : SimpleScopedEnvExtension EMatchTheorem EMatchTheorems
registerSimpleScopedEnvExtension {
addEntry := EMatchTheorems.insert
initial := .empty
initial := {}
}
-- TODO: create attribute?
@@ -320,8 +350,8 @@ private def checkCoverage (thmProof : Expr) (numParams : Nat) (bvarsFound : Std.
Given a theorem with proof `proof` and `numParams` parameters, returns a message
containing the parameters at positions `paramPos`.
-/
private def ppParamsAt (proof : Expr) (numParms : Nat) (paramPos : List Nat) : MetaM MessageData := do
forallBoundedTelescope ( inferType proof) numParms fun xs _ => do
private def ppParamsAt (proof : Expr) (numParams : Nat) (paramPos : List Nat) : MetaM MessageData := do
forallBoundedTelescope ( inferType proof) numParams fun xs _ => do
let mut msg := m!""
let mut first := true
for h : i in [:xs.size] do
@@ -331,23 +361,53 @@ private def ppParamsAt (proof : Expr) (numParms : Nat) (paramPos : List Nat) : M
msg := msg ++ m!"{x} : {← inferType x}"
addMessageContextFull msg
def addEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr) : MetaM Unit := do
/--
Creates an E-matching theorem for `declName` with `numParams` parameters, and the given set of patterns.
Pattern variables are represented using de Bruijn indices.
-/
def mkEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr) : MetaM EMatchTheorem := do
let .thmInfo info getConstInfo declName
| throwError "`{declName}` is not a theorem, you cannot assign patterns to non-theorems for the `grind` tactic"
let us := info.levelParams.map mkLevelParam
let proof := mkConst declName us
let (patterns, symbols, bvarFound) NormalizePattern.main patterns
assert! symbols.all fun s => s matches .const _
trace[grind.ematch.pattern] "{declName}: {patterns.map ppPattern}"
trace[grind.ematch.pattern] "{MessageData.ofConst proof}: {patterns.map ppPattern}"
if let .missing pos checkCoverage proof numParams bvarFound then
let pats : MessageData := m!"{patterns.map ppPattern}"
throwError "invalid pattern(s) for `{declName}`{indentD pats}\nthe following theorem parameters cannot be instantiated:{indentD (← ppParamsAt proof numParams pos)}"
ematchTheoremsExt.add {
proof, patterns, numParams, symbols
levelParams := #[]
origin := .decl declName
return {
proof, patterns, numParams, symbols
levelParams := #[]
origin := .decl declName
}
/--
Given theorem with name `declName` and type of the form `∀ (a_1 ... a_n), lhs = rhs`,
creates an E-matching pattern for it using `addEMatchTheorem n [lhs]`
-/
def mkEMatchEqTheorem (declName : Name) : MetaM EMatchTheorem := do
let info getConstInfo declName
let (numParams, patterns) forallTelescopeReducing info.type fun xs type => do
let_expr Eq _ lhs _ := type | throwError "invalid E-matching equality theorem, conclusion must be an equality{indentExpr type}"
return (xs.size, [lhs.abstract xs])
mkEMatchTheorem declName numParams patterns
/--
Adds an E-matching theorem to the environment.
See `mkEMatchTheorem`.
-/
def addEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr) : MetaM Unit := do
ematchTheoremsExt.add ( mkEMatchTheorem declName numParams patterns)
/--
Adds an E-matching equality theorem to the environment.
See `mkEMatchEqTheorem`.
-/
def addEMatchEqTheorem (declName : Name) : MetaM Unit := do
ematchTheoremsExt.add ( mkEMatchEqTheorem declName)
/-- Returns the E-matching theorems registered in the environment. -/
def getEMatchTheorems : CoreM EMatchTheorems :=
return ematchTheoremsExt.getState ( getEnv)

View File

@@ -6,6 +6,8 @@ Authors: Leonardo de Moura
prelude
import Init.Grind.Util
import Lean.Meta.LitValues
import Lean.Meta.Match.MatcherInfo
import Lean.Meta.Match.MatchEqsExt
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Util
@@ -50,21 +52,36 @@ private partial def internalizePattern (pattern : Expr) (generation : Nat) : Goa
else pattern.withApp fun f args => do
return mkAppN f ( args.mapM (internalizePattern · generation))
private partial def activateTheorem (thm : EMatchTheorem) (generation : Nat) : GoalM Unit := do
-- Recall that we use the proof as part of the key for a set of instances found so far.
-- We don't want to use structural equality when comparing keys.
let proof shareCommon thm.proof
let thm := { thm with proof, patterns := ( thm.patterns.mapM (internalizePattern · generation)) }
trace[grind.ematch] "activated `{thm.origin.key}`, {thm.patterns.map ppPattern}"
modify fun s => { s with newThms := s.newThms.push thm }
/--
If `Config.matchEqs` is set to `true`, and `f` is `match`-auxiliary function,
adds its equations to `newThms`.
-/
private partial def addMatchEqns (f : Expr) (generation : Nat) : GoalM Unit := do
if !( getConfig).matchEqs then return ()
let .const declName _ := f | return ()
if !( isMatcher declName) then return ()
if ( get).matchEqNames.contains declName then return ()
modify fun s => { s with matchEqNames := s.matchEqNames.insert declName }
for eqn in ( Match.getEquationsFor declName).eqnNames do
activateTheorem ( mkEMatchEqTheorem eqn) generation
private partial def activateTheoremPatterns (fName : Name) (generation : Nat) : GoalM Unit := do
if let some thms := ( get).thmMap.find? fName then
modify fun s => { s with thmMap := s.thmMap.erase fName }
if let some (thms, thmMap) := ( get).thmMap.retrieve? fName then
modify fun s => { s with thmMap }
let appMap := ( get).appMap
for thm in thms do
let symbols := thm.symbols.filter fun sym => !appMap.contains sym
let thm := { thm with symbols }
match symbols with
| [] =>
-- Recall that we use the proof as part of the key for a set of instances found so far.
-- We don't want to use structural equality when comparing keys.
let proof shareCommon thm.proof
let thm := { thm with proof, patterns := ( thm.patterns.mapM (internalizePattern · generation)) }
trace[grind.ematch] "activated `{thm.origin.key}`, {thm.patterns.map ppPattern}"
modify fun s => { s with newThms := s.newThms.push thm }
| [] => activateTheorem thm generation
| _ =>
trace[grind.ematch] "reinsert `{thm.origin.key}`"
modify fun s => { s with thmMap := s.thmMap.insert thm }
@@ -95,6 +112,7 @@ partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do
-- We do not want to internalize the components of a literal value.
mkENode e generation
else e.withApp fun f args => do
addMatchEqns f generation
if f.isConstOf ``Lean.Grind.nestedProof && args.size == 2 then
-- We only internalize the proposition. We can skip the proof because of
-- proof irrelevance

View File

@@ -14,6 +14,7 @@ import Lean.Meta.Tactic.Grind.Util
import Lean.Meta.Tactic.Grind.Inv
import Lean.Meta.Tactic.Grind.Intro
import Lean.Meta.Tactic.Grind.EMatch
import Lean.Meta.Tactic.Grind.DoNotSimp
namespace Lean.Meta.Grind
@@ -38,7 +39,7 @@ def GrindM.run (x : GrindM α) (mainDeclName : Name) (config : Grind.Config) (fa
let (falseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``False)
let (trueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``True)
let thms grindNormExt.getTheorems
let simprocs := #[( grindNormSimprocExt.getSimprocs)]
let simprocs := #[( addDoNotSimp ( grindNormSimprocExt.getSimprocs))]
let simp Simp.mkContext
(config := { arith := true })
(simpTheorems := #[thms])

View File

@@ -9,6 +9,7 @@ import Lean.Meta.Tactic.Assert
import Lean.Meta.Tactic.Simp.Main
import Lean.Meta.Tactic.Grind.Util
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.DoNotSimp
import Lean.Meta.Tactic.Grind.MarkNestedProofs
namespace Lean.Meta.Grind
@@ -33,6 +34,7 @@ def simp (e : Expr) : GrindM Simp.Result := do
let e' eraseIrrelevantMData e'
let e' foldProjs e'
let e' normalizeLevels e'
let e' eraseDoNotSimp e'
let e' canon e'
let e' shareCommon e'
trace[grind.simp] "{e}\n===>\n{e'}"

View File

@@ -379,6 +379,8 @@ structure Goal where
preInstances : PreInstanceSet := {}
/-- new facts to be processed. -/
newFacts : Std.Queue NewFact :=
/-- `match` auxiliary functions whose equations have already been created and activated. -/
matchEqNames : PHashSet Name := {}
deriving Inhabited
def Goal.admit (goal : Goal) : MetaM Unit :=

View File

@@ -0,0 +1,56 @@
def g (xs : List α) (ys : List α) :=
match xs, ys with
| [], _ => ys
| _::_::_, [ ] => []
| x::xs, ys => x :: g xs ys
attribute [simp] g
set_option trace.grind.assert true
/--
info: [grind.assert] (match as, bs with
| [], x => bs
| head :: head_1 :: tail, [] => []
| x :: xs, ys => x :: g xs ys) =
d
[grind.assert] bs = []
[grind.assert] a₁ :: f 0 = as
[grind.assert] f 0 = a₂ :: f 1
[grind.assert] ¬d = []
[grind.assert] (match a₁ :: a₂ :: f 1, [] with
| [], x => bs
| head :: head_1 :: tail, [] => []
| x :: xs, ys => x :: g xs ys) =
[]
-/
#guard_msgs (info) in
example (f : Nat List Nat) : g as bs = d bs = [] a₁ :: f 0 = as f 0 = a₂ :: f 1 d = [] := by
unfold g
grind
example : g as bs = d as = [] d = bs := by
unfold g
grind
def f (x : List α) : Bool :=
match x with
| [] => true
| _::_ => false
example : f a = b a = [] b = true := by
unfold f
grind
def f' (x : List Nat) : Bool :=
match x with
| [] => true
| _::_ => false
attribute [simp] f'
#check f'.match_1.eq_1
example : f' a = b a = [] b = true := by
unfold f'
grind