Compare commits

...

14 Commits

Author SHA1 Message Date
Leonardo de Moura
481cb95a59 chore: fix test 2025-12-21 18:42:49 -08:00
Leonardo de Moura
d208dba93b chore: request update stage0 2025-12-21 18:16:12 -08:00
Leonardo de Moura
eb3e6047b7 test: 2025-12-21 18:14:55 -08:00
Leonardo de Moura
a75de82685 feat: resolve grind attributes 2025-12-21 18:14:42 -08:00
Leonardo de Moura
61127e672e test: 2025-12-21 18:03:25 -08:00
Leonardo de Moura
795fc2ca63 feat: register_grind_attr 2025-12-21 18:03:00 -08:00
Leonardo de Moura
0a9b6d538f chore: descr 2025-12-21 17:26:07 -08:00
Leonardo de Moura
34258fd00d chore: missing instance 2025-12-21 17:17:08 -08:00
Leonardo de Moura
bb031ff6d2 feat: add Grind.registerAttr 2025-12-21 17:08:29 -08:00
Leonardo de Moura
cd85eb6620 feat: multiple grind attributes 2025-12-21 17:03:38 -08:00
Leonardo de Moura
0feaf40fc5 feat: support for multiple extensions 2025-12-21 16:09:36 -08:00
Leonardo de Moura
32cd5390cd feat: normToUnfold, ExtensionStateArray, and TheoremsArray 2025-12-21 13:49:29 -08:00
Leonardo de Moura
61d4b00b42 refactor: add Extension.lean
This PR groups all `grind` extensions in a single structure
2025-12-21 11:32:26 -08:00
Leonardo de Moura
b8460f58bc chore: doc 2025-12-21 11:32:26 -08:00
26 changed files with 720 additions and 275 deletions

View File

@@ -255,6 +255,7 @@ theorem fg_eq (h : x > 0) : f (g x) = x
-- With minimal subexpression:
@[grind! <-] theorem fg_eq (h : x > 0) : f (g x) = x
-- Pattern selected: `g x`
```
-/
syntax (name := grind!) "grind!" (ppSpace grindMod)? : attr
/--

View File

@@ -88,15 +88,14 @@ def mkConfig (items : Array (TSyntax `Lean.Parser.Tactic.configItem)) : TermElab
elabConfigItems defaultConfig items
def mkParams (config : Grind.Config) : MetaM Params := do
let params Meta.Grind.mkParams config
let casesTypes Grind.getCasesTypes
let mut ematch getEMatchTheorems
let params Meta.Grind.mkDefaultParams config
let mut ematch := params.extensions[0]!.ematch
for declName in muteExt.getState ( getEnv) do
try
ematch ematch.eraseDecl declName
catch _ =>
pure () -- Ignore failures here.
return { params with ematch, casesTypes }
return { params with extensions[0].ematch := ematch }
/-- Returns the total number of generated instances. -/
def sum (cs : PHashMap Grind.Origin Nat) : Nat := Id.run do

View File

@@ -210,22 +210,13 @@ def elabGrindParamsAndSuggestions
def mkGrindParams
(config : Grind.Config) (only : Bool) (ps : TSyntaxArray ``Parser.Tactic.grindParam) (mvarId : MVarId) :
TermElabM Grind.Params := do
let params Grind.mkParams config
let ematch if only then pure default else Grind.getEMatchTheorems
let inj if only then pure default else Grind.getInjectiveTheorems
/-
**Note**: We used to skip the global cases attribute when `only = true`, but
this is not very effective. We now use anchors to restrict the set of case-splits.
-/
let casesTypes Grind.getCasesTypes
let funCCs Grind.getFunCCSet
let params := { params with ematch, casesTypes, inj, funCCs }
let params if only then Grind.mkOnlyParams config else Grind.mkDefaultParams config
let suggestions if config.suggestions then
LibrarySuggestions.select mvarId { caller := some "grind" }
else
pure #[]
let mut params elabGrindParamsAndSuggestions params ps suggestions (only := only) (lax := config.lax)
trace[grind.debug.inj] "{params.inj.getOrigins.map (·.pp)}"
trace[grind.debug.inj] "{params.extensions[0]!.inj.getOrigins.map (·.pp)}"
if params.anchorRefs?.isSome then
/-
**Note**: anchors are automatically computed in interactive mode where

View File

@@ -20,7 +20,49 @@ open Meta
`grind` parameter elaboration
-/
def warnRedundantEMatchArg (s : Grind.EMatchTheorems) (declName : Name) : MetaM Unit := do
def _root_.Lean.Meta.Grind.Params.insertCasesTypes (params : Grind.Params) (declName : Name) (eager : Bool) : Grind.Params :=
{ params with extensions := params.extensions.modify 0 fun ext => { ext with casesTypes := ext.casesTypes.insert declName eager } }
def _root_.Lean.Meta.Grind.Params.eraseCasesTypes (params : Grind.Params) (declName : Name) : CoreM Grind.Params := do
unless params.extensions.any fun ext => ext.casesTypes.contains declName do
Grind.throwNotMarkedWithGrindAttribute declName
return { params with extensions := params.extensions.modify 0 fun ext => { ext with casesTypes := ext.casesTypes.erase declName } }
def _root_.Lean.Meta.Grind.Params.insertFunCC (params : Grind.Params) (declName : Name) : Grind.Params :=
{ params with extensions := params.extensions.modify 0 fun ext => { ext with funCC := ext.funCC.insert declName } }
def _root_.Lean.Meta.Grind.Params.containsEMatch (params : Grind.Params) (declName : Name) : Bool :=
params.extensions.any fun ext => ext.ematch.contains (.decl declName)
def _root_.Lean.Meta.Grind.Params.eraseEMatchCore (params : Grind.Params) (declName : Name) : Grind.Params :=
{ params with extensions := params.extensions.modify 0 fun ext => { ext with ematch := ext.ematch.erase (.decl declName) } }
def _root_.Lean.Meta.Grind.Params.eraseEMatch (params : Grind.Params) (declName : Name) : MetaM Grind.Params := do
if !wasOriginallyTheorem ( getEnv) declName then
if let some eqns getEqnsFor? declName then
unless eqns.all fun eqn => params.containsEMatch eqn do
Grind.throwNotMarkedWithGrindAttribute declName
return eqns.foldl (init := params) fun params eqn => params.eraseEMatchCore eqn
else
Grind.throwNotMarkedWithGrindAttribute declName
else
unless params.containsEMatch declName do
Grind.throwNotMarkedWithGrindAttribute declName
return params.eraseEMatchCore declName
def _root_.Lean.Meta.Grind.Params.eraseInj (params : Grind.Params) (declName : Name) : Grind.Params :=
{ params with extensions := params.extensions.modify 0 fun ext => { ext with inj := ext.inj.erase (.decl declName) } }
def _root_.Lean.Meta.Grind.ExtensionStateArray.getKindsFor (s : Grind.ExtensionStateArray) (origin : Grind.Origin) : List Grind.EMatchTheoremKind := Id.run do
let mut result := []
for ext in s do
let s : Grind.EMatchTheorems := ext.ematch
let ks := s.getKindsFor origin
unless ks.isEmpty do
result := result ++ ks
return result
def warnRedundantEMatchArg (s : Grind.ExtensionStateArray) (declName : Name) : MetaM Unit := do
let minIndexable := false -- TODO: infer it
let kinds match s.getKindsFor (.decl declName) with
| [] => return ()
@@ -54,9 +96,9 @@ public def addEMatchTheorem (params : Grind.Params) (id : Ident) (declName : Nam
let thm₁ Grind.mkEMatchTheoremForDecl declName (.eqLhs gen) params.symPrios
let thm₂ Grind.mkEMatchTheoremForDecl declName (.eqRhs gen) params.symPrios
if warn &&
params.ematch.containsWithSamePatterns thm₁.origin thm₁.patterns thm₁.cnstrs &&
params.ematch.containsWithSamePatterns thm₂.origin thm₂.patterns thm₂.cnstrs then
warnRedundantEMatchArg params.ematch declName
params.extensions.containsWithSamePatterns thm₁.origin thm₁.patterns thm₁.cnstrs &&
params.extensions.containsWithSamePatterns thm₂.origin thm₂.patterns thm₂.cnstrs then
warnRedundantEMatchArg params.extensions declName
return { params with extra := params.extra.push thm₁ |>.push thm₂ }
| _ =>
if kind matches .eqLhs _ | .eqRhs _ then
@@ -65,8 +107,8 @@ public def addEMatchTheorem (params : Grind.Params) (id : Ident) (declName : Nam
Grind.mkEMatchTheoremAndSuggest id declName params.symPrios minIndexable (isParam := true)
else
Grind.mkEMatchTheoremForDecl declName kind params.symPrios (minIndexable := minIndexable)
if warn && params.ematch.containsWithSamePatterns thm.origin thm.patterns thm.cnstrs then
warnRedundantEMatchArg params.ematch declName
if warn && params.extensions.containsWithSamePatterns thm.origin thm.patterns thm.cnstrs then
warnRedundantEMatchArg params.extensions declName
return { params with extra := params.extra.push thm }
| .defn =>
if ( isReducible declName) then
@@ -154,9 +196,13 @@ def processParam (params : Grind.Params)
catch err =>
if ( resolveLocalName id.getId).isSome then
throwErrorAt id "redundant parameter `{id}`, `grind` uses local hypotheses automatically"
else if let some ext Grind.getExtension? id.getId then
if let some mod := mod? then
throwErrorAt mod "invalid use of modifier in `grind` attribute `{id.getId}`"
return { params with extensions := params.extensions.push (ext.getState ( getEnv)) }
else if !id.getId.getPrefix.isAnonymous then
-- Fall back to term elaboration for compound identifiers like `foo.le` (dot notation on declarations)
return processTermParam params p mod? id minIndexable
return ( processTermParam params p mod? id minIndexable)
else
throw err
Linter.checkDeprecated declName
@@ -179,7 +225,7 @@ def processParam (params : Grind.Params)
if incremental then throwError "`cases` parameter are not supported here"
ensureNoMinIndexable minIndexable
withRef p <| Grind.validateCasesAttr declName eager
params := { params with casesTypes := params.casesTypes.insert declName eager }
params := params.insertCasesTypes declName eager
| .intro =>
if let some info Grind.isCasesAttrPredicateCandidate? declName false then
if incremental then throwError "`cases` parameter are not supported here"
@@ -194,7 +240,7 @@ def processParam (params : Grind.Params)
throwError "`[grind ext]` cannot be set using parameters"
| .infer =>
if let some declName Grind.isCasesAttrCandidate? declName false then
params := { params with casesTypes := params.casesTypes.insert declName false }
params := params.insertCasesTypes declName false
if let some info isInductivePredicate? declName then
-- If it is an inductive predicate,
-- we also add the constructors (intro rules) as E-matching rules
@@ -207,7 +253,7 @@ def processParam (params : Grind.Params)
ensureNoMinIndexable minIndexable
params := { params with symPrios := params.symPrios.insert declName prio }
| .funCC =>
params := { params with funCCs := params.funCCs.insert declName }
params := params.insertFunCC declName
return params
/--
@@ -228,11 +274,11 @@ public def elabGrindParams (params : Grind.Params) (ps : TSyntaxArray ``Parser.T
Linter.checkDeprecated declName
if let some declName Grind.isCasesAttrCandidate? declName false then
Grind.ensureNotBuiltinCases declName
params := { params with casesTypes := ( params.casesTypes.eraseDecl declName) }
params params.eraseCasesTypes declName
else if ( Grind.isInjectiveTheorem declName) then
params := { params with inj := params.inj.erase (.decl declName) }
params := params.eraseInj declName
else
params := { params with ematch := ( params.ematch.eraseDecl declName) }
params params.eraseEMatch declName
| `(Parser.Tactic.grindParam| $[$mod?:grindMod]? $id:ident) =>
-- Check if this is dot notation on a local variable (e.g., `n.triv` for `Nat.triv n`).
-- If so, process as term to let elaboration resolve the dot notation properly.
@@ -294,7 +340,7 @@ public def withParams (params : Grind.Params) (ps : TSyntaxArray ``Parser.Tactic
let mut params := params
if only then
params := { params with
ematch := {}
extensions := params.extensions.map fun ext => { ext with ematch := {} }
anchorRefs? := none
}
params elabGrindParams params ps (only := only) (incremental := true)

View File

@@ -48,6 +48,7 @@ public import Lean.Meta.Tactic.Grind.Filter
public import Lean.Meta.Tactic.Grind.CollectParams
public import Lean.Meta.Tactic.Grind.Finish
public import Lean.Meta.Tactic.Grind.FunCC
public import Lean.Meta.Tactic.Grind.RegisterCommand
public section
namespace Lean

View File

@@ -63,40 +63,89 @@ def getAttrKindFromOpt (stx : Syntax) : CoreM AttrKind := do
def throwInvalidUsrModifier : CoreM α :=
throwError "the modifier `usr` is only relevant in parameters for `grind only`"
private def Extension.addCasesAttr (ext : Extension) (declName : Name) (eager : Bool) (attrKind : AttributeKind) : CoreM Unit := do
validateCasesAttr declName eager
ext.add (.cases declName eager) attrKind
private def Extension.addExtAttr (ext : Extension) (declName : Name) (attrKind : AttributeKind) : CoreM Unit := do
validateExtAttr declName
ext.add (.ext declName) attrKind
private def Extension.addFunCCAttr (ext : Extension) (declName : Name) (attrKind : AttributeKind) : CoreM Unit := do
ext.add (.funCC declName) attrKind
private def Extension.eraseExtAttr (ext : Extension) (declName : Name) : CoreM Unit := do
let s := ext.getState ( getEnv)
let extThms s.extThms.eraseDecl declName
modifyEnv fun env => ext.modifyState env fun s => { s with extThms }
private def Extension.eraseCasesAttr (ext : Extension) (declName : Name) : CoreM Unit := do
ensureNotBuiltinCases declName
let s := ext.getState ( getEnv)
let casesTypes s.casesTypes.eraseDecl declName
modifyEnv fun env => ext.modifyState env fun s => { s with casesTypes }
private def Extension.eraseFunCCAttr (ext : Extension) (declName : Name) : CoreM Unit := do
let s := ext.getState ( getEnv)
unless s.funCC.contains declName do
throwNotMarkedWithGrindAttribute declName
let funCC := s.funCC.erase declName
modifyEnv fun env => ext.modifyState env fun s => { s with funCC }
private def Extension.eraseEMatchAttr (ext : Extension) (declName : Name) : MetaM Unit := do
let s := ext.getState ( getEnv)
let ematch s.ematch.eraseDecl declName
modifyEnv fun env => ext.modifyState env fun s => { s with ematch }
private def Extension.eraseInjectiveAttr (ext : Extension) (declName : Name) : MetaM Unit := do
let s := ext.getState ( getEnv)
let inj s.inj.eraseDecl declName
modifyEnv fun env => ext.modifyState env fun s => { s with inj }
private def Extension.isExtTheorem (ext : Extension) (declName : Name) : CoreM Bool := do
return ext.getState ( getEnv) |>.extThms.contains declName
private def Extension.isInjectiveTheorem (ext : Extension) (declName : Name) : CoreM Bool := do
return ext.getState ( getEnv) |>.inj.contains (.decl declName)
private def Extension.hasFunCCAttr (ext : Extension) (declName : Name) : CoreM Bool := do
return ext.getState ( getEnv) |>.funCC.contains declName
/--
Auxiliary function for registering `grind`, `grind!`, `grind?`, and `grind!?` attributes.
`grind!` is like `grind` but selects minimal indexable subterms.
The `grind?` and `grind!?` are aliases for `grind` and `grind!` which displays patterns using `logInfo`.
It is just a convenience for users.
-/
private def registerGrindAttr (minIndexable : Bool) (showInfo : Bool) : IO Unit :=
private def mkGrindAttr (attrName : Name) (minIndexable : Bool) (showInfo : Bool) (ext? : Option Extension := none) (ref : Name := by exact decl_name%) : IO Unit :=
registerBuiltinAttribute {
ref := ref
name := match minIndexable, showInfo with
| false, false => `grind
| false, true => `grind?
| true, false => `grind!
| true, true => `grind!?
| false, false => attrName
| false, true => attrName.appendAfter "?"
| true, false => attrName.appendAfter "!"
| true, true => attrName.appendAfter "!?"
descr :=
let header := match minIndexable, showInfo with
| false, false => "The `[grind]` attribute is used to annotate declarations."
| false, true => "The `[grind?]` attribute is identical to the `[grind]` attribute, but displays inferred pattern information."
| true, false => "The `[grind!]` attribute is used to annotate declarations, but selecting minimal indexable subterms."
| true, true => "The `[grind!?]` attribute is identical to the `[grind!]` attribute, but displays inferred pattern information."
header ++ "\
| false, false => s!"The `[{attrName}]` attribute is used to annotate declarations."
| false, true => s!"The `[{attrName}?]` attribute is identical to the `[{attrName}]` attribute, but displays inferred pattern information."
| true, false => s!"The `[{attrName}!]` attribute is used to annotate declarations, but selecting minimal indexable subterms."
| true, true => s!"The `[{attrName}!?]` attribute is identical to the `[{attrName}!]` attribute, but displays inferred pattern information."
header ++ s!"\
\
When applied to an equational theorem, `[grind =]`, `[grind =_]`, or `[grind _=_]`\
will mark the theorem for use in heuristic instantiations by the `grind` tactic,
When applied to an equational theorem, `[{attrName} =]`, `[{attrName} =_]`, or `[{attrName} _=_]`\
will mark the theorem for use in heuristic instantiations by the `{attrName}` tactic,
using respectively the left-hand side, the right-hand side, or both sides of the theorem.\
When applied to a function, `[grind =]` automatically annotates the equational theorems associated with that function.\
When applied to a theorem `[grind ←]` will instantiate the theorem whenever it encounters the conclusion of the theorem
When applied to a function, `[{attrName} =]` automatically annotates the equational theorems associated with that function.\
When applied to a theorem `[{attrName} ←]` will instantiate the theorem whenever it encounters the conclusion of the theorem
(that is, it will use the theorem for backwards reasoning).\
When applied to a theorem `[grind →]` will instantiate the theorem whenever it encounters sufficiently many of the propositional hypotheses
When applied to a theorem `[{attrName} →]` will instantiate the theorem whenever it encounters sufficiently many of the propositional hypotheses
(that is, it will use the theorem for forwards reasoning).\
\
The attribute `[grind]` by itself will effectively try `[grind ←]` (if the conclusion is sufficient for instantiation) and then `[grind →]`.\
The attribute `[{attrName}]` by itself will effectively try `[{attrName} ←]` (if the conclusion is sufficient for instantiation) and then `[{attrName} →]`.\
\
The `grind` tactic utilizes annotated theorems to add instances of matching patterns into the local context during proof search.\
For example, if a theorem `@[grind =] theorem foo_idempotent : foo (foo x) = foo x` is annotated,\
For example, if a theorem `@[{attrName} =] theorem foo_idempotent : foo (foo x) = foo x` is annotated,\
`grind` will add an instance of this theorem to the local context whenever it encounters the pattern `foo (foo x)`."
applicationTime := .afterCompilation
add := fun declName stx attrKind => MetaM.run' do
@@ -105,49 +154,111 @@ private def registerGrindAttr (minIndexable : Bool) (showInfo : Bool) : IO Unit
-- When the body is not available (i.e. the def equations are private), the attribute will not
-- be exported; see `ematchTheoremsExt.exportEntry?`.
withoutExporting do
match ( getAttrKindFromOpt stx) with
| .ematch .user => throwInvalidUsrModifier
| .ematch k => addEMatchAttr declName attrKind k ( getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo)
| .cases eager => addCasesAttr declName eager attrKind
| .intro =>
if let some info isCasesAttrPredicateCandidate? declName false then
for ctor in info.ctors do
addEMatchAttr ctor attrKind (.default false) ( getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo)
else
throwError "invalid `[grind intro]`, `{.ofConstName declName}` is not an inductive predicate"
| .ext => addExtAttr declName attrKind
| .infer =>
if let some declName isCasesAttrCandidate? declName false then
addCasesAttr declName false attrKind
if let some info isInductivePredicate? declName then
-- If it is an inductive predicate,
-- we also add the constructors (intro rules) as E-matching rules
if let some ext := ext? then
match ( getAttrKindFromOpt stx) with
| .symbol prio =>
unless attrName == `grind do
throwError "symbol priorities must be set using the default `[grind]` attribute"
addSymbolPriorityAttr declName attrKind prio
| .cases eager => ext.addCasesAttr declName eager attrKind
| .funCC => ext.addFunCCAttr declName attrKind
| .ext => ext.addExtAttr declName attrKind
| .ematch .user => throwInvalidUsrModifier
| .ematch k => ext.addEMatchAttr declName attrKind k ( getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo)
| .intro =>
if let some info isCasesAttrPredicateCandidate? declName false then
for ctor in info.ctors do
ext.addEMatchAttr ctor attrKind (.default false) ( getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo)
else
throwError "invalid `[{attrName} intro]`, `{.ofConstName declName}` is not an inductive predicate"
| .infer =>
if let some declName isCasesAttrCandidate? declName false then
ext.addCasesAttr declName false attrKind
if let some info isInductivePredicate? declName then
-- If it is an inductive predicate,
-- we also add the constructors (intro rules) as E-matching rules
for ctor in info.ctors do
ext.addEMatchAttr ctor attrKind (.default false) ( getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo)
else
ext.addEMatchAttrAndSuggest stx declName attrKind ( getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo)
| .inj => ext.addInjectiveAttr declName attrKind
else
-- **TODO**: delete after update stage0 and new extension for default `grind` attribute
match ( getAttrKindFromOpt stx) with
| .ematch .user => throwInvalidUsrModifier
| .ematch k => addEMatchAttr declName attrKind k ( getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo)
| .cases eager => addCasesAttr declName eager attrKind
| .intro =>
if let some info isCasesAttrPredicateCandidate? declName false then
for ctor in info.ctors do
addEMatchAttr ctor attrKind (.default false) ( getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo)
else
addEMatchAttrAndSuggest stx declName attrKind ( getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo)
| .symbol prio => addSymbolPriorityAttr declName attrKind prio
| .inj => addInjectiveAttr declName attrKind
| .funCC => addFunCCAttr declName attrKind
else
throwError "invalid `[{attrName} intro]`, `{.ofConstName declName}` is not an inductive predicate"
| .ext => addExtAttr declName attrKind
| .infer =>
if let some declName isCasesAttrCandidate? declName false then
addCasesAttr declName false attrKind
if let some info isInductivePredicate? declName then
-- If it is an inductive predicate,
-- we also add the constructors (intro rules) as E-matching rules
for ctor in info.ctors do
addEMatchAttr ctor attrKind (.default false) ( getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo)
else
addEMatchAttrAndSuggest stx declName attrKind ( getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo)
| .symbol prio => addSymbolPriorityAttr declName attrKind prio
| .inj => addInjectiveAttr declName attrKind
| .funCC => addFunCCAttr declName attrKind
erase := fun declName => MetaM.run' do
if showInfo then
throwError "`[grind?]` is a helper attribute for displaying inferred patterns, if you want to remove the attribute, consider using `[grind]` instead"
if ( isCasesAttrCandidate declName false) then
eraseCasesAttr declName
else if ( isExtTheorem declName) then
eraseExtAttr declName
else if ( isInjectiveTheorem declName) then
eraseInjectiveAttr declName
else if ( hasFunCCAttr declName) then
eraseFunCCAttr declName
throwError "`[{attrName}?]` is a helper attribute for displaying inferred patterns, if you want to remove the attribute, consider using `[{attrName}]` instead"
if let some ext := ext? then
if ( isCasesAttrCandidate declName false) then
ext.eraseCasesAttr declName
else if ( ext.isExtTheorem declName) then
ext.eraseExtAttr declName
else if ( ext.isInjectiveTheorem declName) then
ext.eraseInjectiveAttr declName
else if ( ext.hasFunCCAttr declName) then
ext.eraseFunCCAttr declName
else
ext.eraseEMatchAttr declName
else
eraseEMatchAttr declName
-- **TODO**: delete after update stage0 and new extension for default `grind` attribute
if ( isCasesAttrCandidate declName false) then
eraseCasesAttr declName
else if ( isExtTheorem declName) then
eraseExtAttr declName
else if ( isInjectiveTheorem declName) then
eraseInjectiveAttr declName
else if ( hasFunCCAttr declName) then
eraseFunCCAttr declName
else
eraseEMatchAttr declName
}
private def registerDefaultGrindAttr (minIndexable : Bool) (showInfo : Bool) : IO Unit :=
mkGrindAttr `grind minIndexable showInfo
builtin_initialize
registerGrindAttr (minIndexable := false) (showInfo := true)
registerGrindAttr (minIndexable := false) (showInfo := false)
registerGrindAttr (minIndexable := true) (showInfo := true)
registerGrindAttr (minIndexable := true) (showInfo := false)
registerDefaultGrindAttr (minIndexable := false) (showInfo := true)
registerDefaultGrindAttr (minIndexable := false) (showInfo := false)
registerDefaultGrindAttr (minIndexable := true) (showInfo := true)
registerDefaultGrindAttr (minIndexable := true) (showInfo := false)
abbrev ExtensionMap := Std.HashMap Name Extension
builtin_initialize extensionMapRef : IO.Ref ExtensionMap IO.mkRef {}
def getExtension? (attrName : Name) : IO (Option Extension) :=
return ( extensionMapRef.get)[attrName]?
def registerAttr (attrName : Name) (ref : Name := by exact decl_name%) : IO Extension := do
let ext mkExtension ref
mkGrindAttr attrName (minIndexable := false) (showInfo := true) (ext? := some ext) (ref := ref)
mkGrindAttr attrName (minIndexable := false) (showInfo := false) (ext? := some ext) (ref := ref)
mkGrindAttr attrName (minIndexable := true) (showInfo := true) (ext? := some ext) (ref := ref)
mkGrindAttr attrName (minIndexable := true) (showInfo := false) (ext? := some ext) (ref := ref)
extensionMapRef.modify fun map => map.insert attrName ext
return ext
end Lean.Meta.Grind

View File

@@ -6,19 +6,18 @@ Authors: Leonardo de Moura
module
prelude
public import Lean.Meta.Tactic.Cases
public import Lean.Meta.Tactic.Grind.Extension
public section
namespace Lean.Meta.Grind
/-- Types that `grind` will case-split on. -/
structure CasesTypes where
casesMap : PHashMap Name Bool := {}
deriving Inhabited
-- TODO: delete
structure CasesEntry where
declName : Name
eager : Bool
deriving Inhabited
/-- A collection of `CasesTypes`. -/
abbrev CasesTypesArray := Array CasesTypes
/--
`grind` always case-splits on the following types. Even when using `grind only`.
The goal is to reduce noise in the tactic generated by `grind?`
@@ -43,9 +42,6 @@ def CasesTypes.contains (s : CasesTypes) (declName : Name) : Bool :=
def CasesTypes.erase (s : CasesTypes) (declName : Name) : CasesTypes :=
{ s with casesMap := s.casesMap.erase declName }
def CasesTypes.insert (s : CasesTypes) (declName : Name) (eager : Bool) : CasesTypes :=
{ s with casesMap := s.casesMap.insert declName eager }
def CasesTypes.find? (s : CasesTypes) (declName : Name) : Option Bool :=
s.casesMap.find? declName
@@ -55,6 +51,9 @@ def CasesTypes.isEagerSplit (s : CasesTypes) (declName : Name) : Bool :=
def CasesTypes.isSplit (s : CasesTypes) (declName : Name) : Bool :=
(s.casesMap.find? declName |>.isSome) || isBuiltinEagerCases declName
/-
TODO: group into a `grind` extension object
-/
builtin_initialize casesExt : SimpleScopedEnvExtension CasesEntry CasesTypes
registerSimpleScopedEnvExtension {
initial := {}
@@ -68,7 +67,7 @@ def getCasesTypes : CoreM CasesTypes :=
return casesExt.getState ( getEnv)
/-- Returns `true` is `declName` is a builtin split or has been tagged with `[grind]` attribute. -/
def isSplit (declName : Name) : CoreM Bool := do
def isGlobalSplit (declName : Name) : CoreM Bool := do
return ( getCasesTypes).isSplit declName
partial def isCasesAttrCandidate? (declName : Name) (eager : Bool) : CoreM (Option Name) := do
@@ -98,7 +97,7 @@ def CasesTypes.eraseDecl (s : CasesTypes) (declName : Name) : CoreM CasesTypes :
if s.contains declName then
return s.erase declName
else
throwError "`{.ofConstName declName}` is not marked with the `[grind]` attribute"
throwNotMarkedWithGrindAttribute declName
def ensureNotBuiltinCases (declName : Name) : CoreM Unit := do
if isBuiltinEagerCases declName then

View File

@@ -5,7 +5,7 @@ Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Tactic.Grind.Theorems
public import Lean.Meta.Tactic.Grind.Extension
import Init.Grind.Util
import Lean.Util.ForEachExpr
import Lean.Meta.Tactic.Grind.Util
@@ -13,14 +13,7 @@ import Lean.Meta.Match.Basic
import Lean.Meta.Tactic.TryThis
public section
namespace Lean.Meta.Grind
/--
`grind` uses symbol priorities when inferring patterns for E-matching.
Symbols not in `map` are assumed to have default priority (i.e., `eval_prio default`).
-/
structure SymbolPriorities where
map : PHashMap Name Nat := {}
deriving Inhabited
-- TODO: delete
structure SymbolPriorityEntry where
declName : Name
prio : Nat
@@ -30,10 +23,6 @@ structure SymbolPriorityEntry where
def SymbolPriorities.erase (s : SymbolPriorities) (declName : Name) : SymbolPriorities :=
{ s with map := s.map.erase declName }
/-- Inserts `declName ↦ prio` into `s`. -/
def SymbolPriorities.insert (s : SymbolPriorities) (declName : Name) (prio : Nat) : SymbolPriorities :=
{ s with map := s.map.insert declName prio }
/-- Returns `declName` priority for E-matching pattern inference in `s`. -/
def SymbolPriorities.getPrio (s : SymbolPriorities) (declName : Name) : Nat :=
if let some prio := s.map.find? declName then
@@ -48,6 +37,9 @@ Recall that symbols not in `s` are assumed to have default priority.
def SymbolPriorities.contains (s : SymbolPriorities) (declName : Name) : Bool :=
s.map.contains declName
/-
TODO: group into a `grind` extension object
-/
private builtin_initialize symbolPrioExt : SimpleScopedEnvExtension SymbolPriorityEntry SymbolPriorities
registerSimpleScopedEnvExtension {
initial := {}
@@ -286,19 +278,6 @@ def preprocessPattern (pat : Expr) (normalizePattern := true) : MetaM Expr := do
let pat foldProjs pat
return pat
inductive EMatchTheoremKind where
| eqLhs (gen : Bool)
| eqRhs (gen : Bool)
| eqBoth (gen : Bool)
| eqBwd
| fwd
| bwd (gen : Bool)
| leftRight
| rightLeft
| default (gen : Bool)
| user /- pattern specified using `grind_pattern` command -/
deriving Inhabited, BEq, Repr, Hashable
def EMatchTheoremKind.isEqLhs : EMatchTheoremKind Bool
| .eqLhs _ => true
| _ => false
@@ -345,106 +324,13 @@ private def EMatchTheoremKind.explainFailure : EMatchTheoremKind → String
| .default _ => "failed to find patterns"
| .user => unreachable!
structure CnstrRHS where
/-- Abstracted universe level param names in the `rhs` -/
levelNames : Array Name
/-- Number of abstracted metavariable in the `rhs` -/
numMVars : Nat
/-- The actual `rhs`. -/
expr : Expr
deriving Inhabited, BEq, Repr
/--
Grind patterns may have constraints associated with them.
-/
inductive EMatchTheoremConstraint where
| /--
A constraint of the form `lhs =/= rhs`.
The `lhs` is one of the bound variables, and the `rhs` an abstract term that must not be definitionally
equal to a term `t` assigned to `lhs`. -/
notDefEq (lhs : Nat) (rhs : CnstrRHS)
| /--
A constraint of the form `lhs =?= rhs`.
The `lhs` is one of the bound variables, and the `rhs` an abstract term that must be definitionally
equal to a term `t` assigned to `lhs`. -/
defEq (lhs : Nat) (rhs : CnstrRHS)
| /--
A constraint of the form `size lhs < n`. The `lhs` is one of the bound variables.
The size is computed ignoring implicit terms, but sharing is not taken into account.
-/
sizeLt (lhs : Nat) (n : Nat)
| /--
A constraint of the form `depth lhs < n`. The `lhs` is one of the bound variables.
The depth is computed in constant time using the `approxDepth` field attached to expressions.
-/
depthLt (lhs : Nat) (n : Nat)
| /--
Instantiates the theorem only if its generation is less than `n`
-/
genLt (n : Nat)
| /--
Constraints of the form `is_ground x`. Instantiates the theorem only if
`x` is ground term.
-/
isGround (bvarIdx : Nat)
| /--
Constraints of the form `is_value x` and `is_strict_value x`.
A value is defined as
- A constructor fully applied to value arguments.
- A literal: numerals, strings, etc.
- A lambda. In the strict case, lambdas are not considered.
-/
isValue (bvarIdx : Nat) (strict : Bool)
| /--
Instantiates the theorem only if less than `n` instances have been generated for this theorem.
-/
maxInsts (n : Nat)
| /--
It instructs `grind` to postpone the instantiation of the theorem until `e` is known to be `true`.
-/
guard (e : Expr)
| /--
Similar to `guard`, but checks whether `e` is implied by asserting `¬e`.
-/
check (e : Expr)
| /--
Constraints of the form `not_value x` and `not_strict_value x`.
They are the negations of `is_value x` and `is_strict_value x`.
-/
notValue (bvarIdx : Nat) (strict : Bool)
deriving Inhabited, Repr, BEq
/-- A theorem for heuristic instantiation based on E-matching. -/
structure EMatchTheorem where
/--
It stores universe parameter names for universe polymorphic proofs.
Recall that it is non-empty only when we elaborate an expression provided by the user.
When `proof` is just a constant, we can use the universe parameter names stored in the declaration.
-/
levelParams : Array Name
proof : Expr
numParams : Nat
patterns : List Expr
/-- Contains all symbols used in `patterns`. -/
symbols : List HeadIndex
origin : Origin
/-- The `kind` is used for generating the `patterns`. We save it here to implement `grind?`. -/
kind : EMatchTheoremKind
/-- Stores whether patterns were inferred using the minimal indexable subexpression condition. -/
minIndexable : Bool
cnstrs : List EMatchTheoremConstraint := []
deriving Inhabited
instance : TheoremLike EMatchTheorem where
getSymbols thm := thm.symbols
setSymbols thm symbols := { thm with symbols }
getOrigin thm := thm.origin
getProof thm := thm.proof
getLevelParams thm := thm.levelParams
/-- Set of E-matching theorems. -/
abbrev EMatchTheorems := Theorems EMatchTheorem
/-- A collection of sets of E-matching theorems. -/
abbrev EMatchTheoremsArray := TheoremsArray EMatchTheorem
/--
Returns `true` if there is a theorem with exactly the same pattern and constraints is already in `s`
-/
@@ -453,6 +339,10 @@ def EMatchTheorems.containsWithSamePatterns (s : EMatchTheorems) (origin : Origi
let thms := s.find origin
thms.any fun thm => thm.patterns == patterns && thm.cnstrs == cnstrs
def ExtensionStateArray.containsWithSamePatterns (s : ExtensionStateArray) (origin : Origin)
(patterns : List Expr) (cnstrs : List EMatchTheoremConstraint) : Bool :=
s.any (EMatchTheorems.containsWithSamePatterns ·.ematch origin patterns cnstrs)
def EMatchTheorems.getKindsFor (s : EMatchTheorems) (origin : Origin) : List EMatchTheoremKind :=
let thms := s.find origin
thms.map (·.kind)
@@ -460,6 +350,9 @@ def EMatchTheorems.getKindsFor (s : EMatchTheorems) (origin : Origin) : List EMa
def EMatchTheorem.getProofWithFreshMVarLevels (thm : EMatchTheorem) : MetaM Expr := do
Grind.getProofWithFreshMVarLevels thm
/-
TODO: group into a `grind` extension object
-/
private builtin_initialize ematchTheoremsExt : SimpleScopedEnvExtension EMatchTheorem (Theorems EMatchTheorem)
registerSimpleScopedEnvExtension {
addEntry := Theorems.insert
@@ -1482,6 +1375,7 @@ def mkEMatchEqTheoremsForDef? (declName : Name) (showInfo := false) : MetaM (Opt
eqns.mapM fun eqn => do
mkEMatchEqTheorem eqn (normalizePattern := true) (showInfo := showInfo)
-- TODO: delete
private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (useLhs := true) (showInfo := false) : MetaM Unit := do
if wasOriginallyTheorem ( getEnv) declName then
ematchTheoremsExt.add ( mkEMatchEqTheorem declName (normalizePattern := true) (useLhs := useLhs) (gen := thmKind.gen) (showInfo := showInfo)) attrKind
@@ -1492,26 +1386,34 @@ private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) (thmKind
else
throwError s!"`{thmKind.toAttribute false}` attribute can only be applied to equational theorems or function definitions"
private def Extension.addGrindEqAttr (ext : Extension) (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (useLhs := true) (showInfo := false) : MetaM Unit := do
if wasOriginallyTheorem ( getEnv) declName then
ext.add (.ematch ( mkEMatchEqTheorem declName (normalizePattern := true) (useLhs := useLhs) (gen := thmKind.gen) (showInfo := showInfo))) attrKind
else if let some thms mkEMatchEqTheoremsForDef? declName (showInfo := showInfo) then
unless useLhs do
throwError "`{.ofConstName declName}` is a definition, you must only use the left-hand side for extracting patterns"
thms.forM fun thm => ext.add (.ematch thm) attrKind
else
throwError s!"`{thmKind.toAttribute false}` attribute can only be applied to equational theorems or function definitions"
def EMatchTheorems.eraseDecl (s : EMatchTheorems) (declName : Name) : MetaM EMatchTheorems := do
let throwErr {α} : MetaM α :=
throwError "`{.ofConstName declName}` is not marked with the `[grind]` attribute"
if !wasOriginallyTheorem ( getEnv) declName then
if let some eqns getEqnsFor? declName then
let s := ematchTheoremsExt.getState ( getEnv)
unless eqns.all fun eqn => s.contains (.decl eqn) do
throwErr
throwNotMarkedWithGrindAttribute declName
return eqns.foldl (init := s) fun s eqn => s.erase (.decl eqn)
else
throwErr
throwNotMarkedWithGrindAttribute declName
else
unless ematchTheoremsExt.getState ( getEnv) |>.contains (.decl declName) do
throwErr
unless s.contains (.decl declName) do
throwNotMarkedWithGrindAttribute declName
return s.erase <| .decl declName
private def ensureNoMinIndexable (minIndexable : Bool) : MetaM Unit := do
if minIndexable then
throwError "redundant modifier `!` in `grind` attribute"
-- TODO: delete
def addEMatchAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (prios : SymbolPriorities)
(showInfo := false) (minIndexable := false) : MetaM Unit := do
match thmKind with
@@ -1534,6 +1436,28 @@ def addEMatchAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatch
let thm mkEMatchTheoremForDecl declName thmKind prios (showInfo := showInfo) (minIndexable := minIndexable)
ematchTheoremsExt.add thm attrKind
def Extension.addEMatchAttr (ext : Extension) (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (prios : SymbolPriorities)
(showInfo := false) (minIndexable := false) : MetaM Unit := do
match thmKind with
| .eqLhs _ =>
ensureNoMinIndexable minIndexable
ext.addGrindEqAttr declName attrKind thmKind (useLhs := true) (showInfo := showInfo)
| .eqRhs _ =>
ensureNoMinIndexable minIndexable
ext.addGrindEqAttr declName attrKind thmKind (useLhs := false) (showInfo := showInfo)
| .eqBoth _ =>
ensureNoMinIndexable minIndexable
ext.addGrindEqAttr declName attrKind thmKind (useLhs := true) (showInfo := showInfo)
ext.addGrindEqAttr declName attrKind thmKind (useLhs := false) (showInfo := showInfo)
| _ =>
let info getConstInfo declName
if !wasOriginallyTheorem ( getEnv) declName && !info.isCtor && !info.isAxiom then
ensureNoMinIndexable minIndexable
ext.addGrindEqAttr declName attrKind thmKind (showInfo := showInfo)
else
let thm mkEMatchTheoremForDecl declName thmKind prios (showInfo := showInfo) (minIndexable := minIndexable)
ext.add (.ematch thm) attrKind
private structure SelectM.State where
-- **Note**: hack, an attribute is not a tactic.
suggestions : Array Tactic.TryThis.Suggestion := #[]
@@ -1665,6 +1589,7 @@ Tries different modifiers, logs info messages with modifiers that worked, but st
Remark: if `backward.grind.inferPattern` is `true`, then `.default false` is used.
The parameter `showInfo` is only taken into account when `backward.grind.inferPattern` is `true`.
-/
-- TODO: delete
def addEMatchAttrAndSuggest (ref : Syntax) (declName : Name) (attrKind : AttributeKind) (prios : SymbolPriorities)
(minIndexable : Bool) (showInfo : Bool) (isParam : Bool := false) : MetaM Unit := do
let info getConstInfo declName
@@ -1677,6 +1602,25 @@ def addEMatchAttrAndSuggest (ref : Syntax) (declName : Name) (attrKind : Attribu
let thm mkEMatchTheoremAndSuggest ref declName prios minIndexable isParam
ematchTheoremsExt.add thm attrKind
/--
Tries different modifiers, logs info messages with modifiers that worked, but stores just the first one that worked.
Remark: if `backward.grind.inferPattern` is `true`, then `.default false` is used.
The parameter `showInfo` is only taken into account when `backward.grind.inferPattern` is `true`.
-/
-- TODO: delete
def Extension.addEMatchAttrAndSuggest (ext : Extension) (ref : Syntax) (declName : Name) (attrKind : AttributeKind) (prios : SymbolPriorities)
(minIndexable : Bool) (showInfo : Bool) (isParam : Bool := false) : MetaM Unit := do
let info getConstInfo declName
if !wasOriginallyTheorem ( getEnv) declName && !info.isCtor && !info.isAxiom then
ensureNoMinIndexable minIndexable
ext.addGrindEqAttr declName attrKind (.default false) (showInfo := showInfo)
else if backward.grind.inferPattern.get ( getOptions) then
ext.addEMatchAttr declName attrKind (.default false) prios (minIndexable := minIndexable) (showInfo := showInfo)
else
let thm mkEMatchTheoremAndSuggest ref declName prios minIndexable isParam
ext.add (.ematch thm) attrKind
def eraseEMatchAttr (declName : Name) : MetaM Unit := do
/-
Remark: consider the following example

View File

@@ -6,13 +6,14 @@ Authors: Leonardo de Moura
module
prelude
public import Lean.Meta.Tactic.Ext
public import Lean.Meta.Tactic.Grind.Extension
public section
namespace Lean.Meta.Grind
/-! Grind extensionality attribute to mark which `[ext]` theorems should be used. -/
/-- Extensionality theorems that can be used by `grind` -/
abbrev ExtTheorems := PHashSet Name
/-
TODO: group into a `grind` extension object
-/
builtin_initialize extTheoremsExt : SimpleScopedEnvExtension Name ExtTheorems
registerSimpleScopedEnvExtension {
initial := {}
@@ -28,7 +29,7 @@ def addExtAttr (declName : Name) (attrKind : AttributeKind) : CoreM Unit := do
validateExtAttr declName
extTheoremsExt.add declName attrKind
private def eraseDecl (s : ExtTheorems) (declName : Name) : CoreM ExtTheorems := do
def ExtTheorems.eraseDecl (s : ExtTheorems) (declName : Name) : CoreM ExtTheorems := do
if s.contains declName then
return s.erase declName
else
@@ -36,10 +37,13 @@ private def eraseDecl (s : ExtTheorems) (declName : Name) : CoreM ExtTheorems :=
def eraseExtAttr (declName : Name) : CoreM Unit := do
let s := extTheoremsExt.getState ( getEnv)
let s eraseDecl s declName
let s s.eraseDecl declName
modifyEnv fun env => extTheoremsExt.modifyState env fun _ => s
def isExtTheorem (declName : Name) : CoreM Bool := do
return extTheoremsExt.getState ( getEnv) |>.contains declName
def getGlobalExtTheorems : CoreM ExtTheorems := do
return extTheoremsExt.getState ( getEnv)
end Lean.Meta.Grind

View File

@@ -0,0 +1,221 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Expr
public import Lean.Data.PersistentHashMap
public import Lean.Meta.Tactic.Grind.Theorems
public section
namespace Lean.Meta.Grind
/-- Types that `grind` will case-split on. -/
structure CasesTypes where
casesMap : PHashMap Name Bool := {}
deriving Inhabited
def CasesTypes.insert (s : CasesTypes) (declName : Name) (eager : Bool) : CasesTypes :=
{ s with casesMap := s.casesMap.insert declName eager }
abbrev ExtTheorems := PHashSet Name
structure SymbolPriorities where
map : PHashMap Name Nat := {}
deriving Inhabited
/-- Inserts `declName ↦ prio` into `s`. -/
def SymbolPriorities.insert (s : SymbolPriorities) (declName : Name) (prio : Nat) : SymbolPriorities :=
{ s with map := s.map.insert declName prio }
inductive EMatchTheoremKind where
| eqLhs (gen : Bool)
| eqRhs (gen : Bool)
| eqBoth (gen : Bool)
| eqBwd
| fwd
| bwd (gen : Bool)
| leftRight
| rightLeft
| default (gen : Bool)
| user /- pattern specified using `grind_pattern` command -/
deriving Inhabited, BEq, Repr, Hashable
structure CnstrRHS where
/-- Abstracted universe level param names in the `rhs` -/
levelNames : Array Name
/-- Number of abstracted metavariable in the `rhs` -/
numMVars : Nat
/-- The actual `rhs`. -/
expr : Expr
deriving Inhabited, BEq, Repr
/--
Grind patterns may have constraints associated with them.
-/
inductive EMatchTheoremConstraint where
| /--
A constraint of the form `lhs =/= rhs`.
The `lhs` is one of the bound variables, and the `rhs` an abstract term that must not be definitionally
equal to a term `t` assigned to `lhs`. -/
notDefEq (lhs : Nat) (rhs : CnstrRHS)
| /--
A constraint of the form `lhs =?= rhs`.
The `lhs` is one of the bound variables, and the `rhs` an abstract term that must be definitionally
equal to a term `t` assigned to `lhs`. -/
defEq (lhs : Nat) (rhs : CnstrRHS)
| /--
A constraint of the form `size lhs < n`. The `lhs` is one of the bound variables.
The size is computed ignoring implicit terms, but sharing is not taken into account.
-/
sizeLt (lhs : Nat) (n : Nat)
| /--
A constraint of the form `depth lhs < n`. The `lhs` is one of the bound variables.
The depth is computed in constant time using the `approxDepth` field attached to expressions.
-/
depthLt (lhs : Nat) (n : Nat)
| /--
Instantiates the theorem only if its generation is less than `n`
-/
genLt (n : Nat)
| /--
Constraints of the form `is_ground x`. Instantiates the theorem only if
`x` is ground term.
-/
isGround (bvarIdx : Nat)
| /--
Constraints of the form `is_value x` and `is_strict_value x`.
A value is defined as
- A constructor fully applied to value arguments.
- A literal: numerals, strings, etc.
- A lambda. In the strict case, lambdas are not considered.
-/
isValue (bvarIdx : Nat) (strict : Bool)
| /--
Instantiates the theorem only if less than `n` instances have been generated for this theorem.
-/
maxInsts (n : Nat)
| /--
It instructs `grind` to postpone the instantiation of the theorem until `e` is known to be `true`.
-/
guard (e : Expr)
| /--
Similar to `guard`, but checks whether `e` is implied by asserting `¬e`.
-/
check (e : Expr)
| /--
Constraints of the form `not_value x` and `not_strict_value x`.
They are the negations of `is_value x` and `is_strict_value x`.
-/
notValue (bvarIdx : Nat) (strict : Bool)
deriving Inhabited, Repr, BEq
/-- A theorem for heuristic instantiation based on E-matching. -/
structure EMatchTheorem where
/--
It stores universe parameter names for universe polymorphic proofs.
Recall that it is non-empty only when we elaborate an expression provided by the user.
When `proof` is just a constant, we can use the universe parameter names stored in the declaration.
-/
levelParams : Array Name
proof : Expr
numParams : Nat
patterns : List Expr
/-- Contains all symbols used in `patterns`. -/
symbols : List HeadIndex
origin : Origin
/-- The `kind` is used for generating the `patterns`. We save it here to implement `grind?`. -/
kind : EMatchTheoremKind
/-- Stores whether patterns were inferred using the minimal indexable subexpression condition. -/
minIndexable : Bool
cnstrs : List EMatchTheoremConstraint := []
deriving Inhabited
instance : TheoremLike EMatchTheorem where
getSymbols thm := thm.symbols
setSymbols thm symbols := { thm with symbols }
getOrigin thm := thm.origin
getProof thm := thm.proof
getLevelParams thm := thm.levelParams
/-- A theorem marked with `@[grind inj]` -/
structure InjectiveTheorem where
levelParams : Array Name
proof : Expr
/-- Contains all symbols used in the term `f` at the theorem's conclusion: `Function.Injective f`. -/
symbols : List HeadIndex
origin : Origin
deriving Inhabited
instance : TheoremLike InjectiveTheorem where
getSymbols thm := thm.symbols
setSymbols thm symbols := { thm with symbols }
getOrigin thm := thm.origin
getProof thm := thm.proof
getLevelParams thm := thm.levelParams
inductive Entry where
| ext (declName : Name)
| funCC (declName : Name)
| cases (declName : Name) (eager : Bool)
| ematch (thm : EMatchTheorem)
| inj (thm : InjectiveTheorem)
deriving Inhabited
/-
**Note**: We currently have a single normalization and symbol priority sets for all `grind` attributes.
Reason: the E-match patterns must be normalized with respect to them. If we are using multiple
`grind` attributes, they patterns would have to be re-normalized using the union of all normalizers.
Alternative design: allow a single `grind` attribute per `grind` call. Cons: when creating a new
`grind` attribute users would have to carefully setup the normalizer to ensure all `grind` assumptions
are met. Cons: users would not be able to write `grind only [attr_1, attr_2]`.
-/
structure ExtensionState where
casesTypes : CasesTypes := {}
extThms : ExtTheorems := {}
funCC : NameSet := {}
ematch : Theorems EMatchTheorem := {}
inj : Theorems InjectiveTheorem := {}
deriving Inhabited
abbrev Extension := SimpleScopedEnvExtension Entry ExtensionState
def ExtensionState.addEntry (s : ExtensionState) (e : Entry) : ExtensionState :=
match e with
| .cases declName eager => { s with casesTypes := s.casesTypes.insert declName eager }
| .ext declName => { s with extThms := s.extThms.insert declName }
| .funCC declName => { s with funCC := s.funCC.insert declName }
| .ematch thm => { s with ematch := s.ematch.insert thm }
| .inj thm => { s with inj := s.inj.insert thm }
def mkExtension (name : Name := by exact decl_name%) : IO Extension :=
registerSimpleScopedEnvExtension {
name := name
initial := {}
addEntry := ExtensionState.addEntry
exportEntry? := fun lvl e => do
-- export only annotations on public decls
let declName := match e with
| .inj thm | .ematch thm =>
match thm.origin with
| .decl declName => declName
| _ => unreachable!
| .ext declName | .cases declName _ | .funCC declName => declName
guard (lvl == .private || !isPrivateName declName)
return e
}
/--
`grind` is parametrized by a collection of `ExtensionState`. The motivation is to allow
users to use multiple extensions simultaneously without merging them into a single structure.
The collection is scanned linearly. In practice, we expect the array to be very small.
-/
abbrev ExtensionStateArray := Array ExtensionState
def throwNotMarkedWithGrindAttribute (declName : Name) : CoreM α :=
throwError "`{.ofConstName declName}` is not marked with the `[grind]` attribute"
end Lean.Meta.Grind

View File

@@ -9,6 +9,9 @@ public import Lean.ScopedEnvExtension
public section
namespace Lean.Meta.Grind
/-
TODO: group into a `grind` extension object
-/
private builtin_initialize funCCExt : SimpleScopedEnvExtension Name NameSet
registerSimpleScopedEnvExtension {
initial := {}

View File

@@ -14,25 +14,15 @@ builtin_initialize registerTraceClass `grind.inj
builtin_initialize registerTraceClass `grind.inj.assert
builtin_initialize registerTraceClass `grind.debug.inj
/-- A theorem marked with `@[grind inj]` -/
structure InjectiveTheorem where
levelParams : Array Name
proof : Expr
/-- Contains all symbols used in the term `f` at the theorem's conclusion: `Function.Injective f`. -/
symbols : List HeadIndex
origin : Origin
deriving Inhabited
instance : TheoremLike InjectiveTheorem where
getSymbols thm := thm.symbols
setSymbols thm symbols := { thm with symbols }
getOrigin thm := thm.origin
getProof thm := thm.proof
getLevelParams thm := thm.levelParams
/-- Set of Injective theorems. -/
abbrev InjectiveTheorems := Theorems InjectiveTheorem
/-- A collections of sets of Injective theorems. -/
abbrev InjectiveTheoremsArray := TheoremsArray InjectiveTheorem
/-
TODO: group into a `grind` extension object
-/
private builtin_initialize injectiveTheoremsExt : SimpleScopedEnvExtension InjectiveTheorem (Theorems InjectiveTheorem)
registerSimpleScopedEnvExtension {
addEntry := Theorems.insert
@@ -85,9 +75,13 @@ def mkInjectiveTheorem (declName : Name) : MetaM InjectiveTheorem := do
proof, symbols
}
-- TODO: delete
def addInjectiveAttr (declName : Name) (attrKind : AttributeKind) : MetaM Unit := do
injectiveTheoremsExt.add ( mkInjectiveTheorem declName) attrKind
def Extension.addInjectiveAttr (ext : Extension) (declName : Name) (attrKind : AttributeKind) : MetaM Unit := do
ext.add (.inj ( mkInjectiveTheorem declName)) attrKind
def eraseInjectiveAttr (declName : Name) : MetaM Unit := do
let s := injectiveTheoremsExt.getState ( getEnv)
let s s.eraseDecl declName

View File

@@ -146,13 +146,13 @@ private def checkAndAddSplitCandidate (e : Expr) : GoalM Unit := do
return ()
unless ( isInductivePredicate declName) do
return ()
if ( get).split.casesTypes.isSplit declName then
if ( isSplit declName) then
addDefaultSplitCandidate e
else if ( getConfig).splitIndPred then
addDefaultSplitCandidate e
| .fvar .. =>
let .const declName _ := ( whnf ( inferType e)).getAppFn | return ()
if ( get).split.casesTypes.isSplit declName then
if ( isSplit declName) then
addDefaultSplitCandidate e
| .forallE _ d _ _ =>
let currSplitSource := ( readThe Context).splitSource
@@ -275,8 +275,8 @@ private def addMatchEqns (f : Expr) (generation : Nat) : GoalM Unit := do
@[specialize]
private def activateTheoremsCore [TheoremLike α] (declName : Name)
(getThms : GoalM (Theorems α))
(setThms : Theorems α GoalM Unit)
(getThms : GoalM (TheoremsArray α))
(setThms : TheoremsArray α GoalM Unit)
(reinsertThm : α GoalM Unit)
(activateThm : α GoalM Unit) : GoalM Unit := do
if let some (thms, s) := ( getThms).retrieve? declName then
@@ -444,7 +444,7 @@ private def tryEta (e : Expr) (generation : Nat) : GoalM Unit := do
Returns `true` if we should use `funCC` for applications of the given constant symbol.
-/
private def useFunCongrAtDecl (declName : Name) : GrindM Bool := do
if ( readThe Grind.Context).funCCs.contains declName then
if ( hasFunCCModifier declName) then
return true
if ( isInstance declName) then
/- **Note**: Instances are support elements. No `funCC` -/

View File

@@ -182,9 +182,9 @@ private partial def introNext (goal : Goal) (generation : Nat) : GrindM IntroRes
else
return .done goal
private def isEagerCasesCandidate (goal : Goal) (type : Expr) : Bool := Id.run do
private def isEagerCasesCandidate (type : Expr) : GrindM Bool := do
let .const declName _ := type.getAppFn | return false
return goal.split.casesTypes.isEagerSplit declName
isEagerSplit declName
/-- Returns `true` if `type` is an inductive type with at most one constructor. -/
private def isCheapInductive (type : Expr) : CoreM Bool := do
@@ -215,7 +215,7 @@ private def applyCases? (goal : Goal) (fvarId : FVarId) (kp : ActionCont) : Grin
Example: `a b` is defined as `∃ x, b = a * x`
-/
let type whnf ( fvarId.getType)
unless isEagerCasesCandidate goal type do return none
unless ( isEagerCasesCandidate type) do return none
if ( cheapCasesOnly) then
unless ( isCheapInductive type) do return none
if let .const declName _ := type.getAppFn then
@@ -268,7 +268,7 @@ def intros (generation : Nat) : Action :=
/-- Asserts a new fact `prop` with proof `proof` to the given `goal`. -/
private def assertAt (proof : Expr) (prop : Expr) (generation : Nat) : Action := fun goal kna kp => do
if isEagerCasesCandidate goal prop then
if ( isEagerCasesCandidate prop) then
let mvarId goal.mvarId.assert ( mkFreshUserName `h) prop proof
intros generation { goal with mvarId } kna kp
else goal.withContext do

View File

@@ -32,26 +32,51 @@ import Lean.Meta.Tactic.Grind.Core
public section
namespace Lean.Meta.Grind
/--
Returns the `ExtensionState` for the default `grind` attribute.
-/
def getDefaultExtensionState : MetaM ExtensionState := do
-- **TODO**: update after update stage0
let casesTypes getCasesTypes
let funCC getFunCCSet
let extThms getGlobalExtTheorems
let ematch getEMatchTheorems
let inj getInjectiveTheorems
return {
casesTypes, funCC, extThms, ematch, inj
}
def getOnlyExtensionState : MetaM ExtensionState := do
let casesTypes getCasesTypes
let funCC getFunCCSet
let extThms getGlobalExtTheorems
return {
casesTypes, funCC, extThms
}
structure Params where
config : Grind.Config
ematch : EMatchTheorems := default
inj : InjectiveTheorems := default
symPrios : SymbolPriorities := {}
casesTypes : CasesTypes := {}
extensions : ExtensionStateArray := #[]
extra : PArray EMatchTheorem := {}
extraInj : PArray InjectiveTheorem := {}
extraFacts : PArray Expr := {}
funCCs : NameSet := {}
symPrios : SymbolPriorities := {}
norm : Simp.Context
normProcs : Array Simprocs
anchorRefs? : Option (Array AnchorRef) := none
-- TODO: inductives to split
def mkParams (config : Grind.Config) : MetaM Params := do
def mkParams (config : Grind.Config) (extensions : ExtensionStateArray) : MetaM Params := do
let norm Grind.getSimpContext config
let normProcs Grind.getSimprocs
let symPrios getGlobalSymbolPriorities
return { config, norm, normProcs, symPrios }
return { config, norm, normProcs, symPrios, extensions }
def mkDefaultParams (config : Grind.Config) : MetaM Params := do
mkParams config #[ getDefaultExtensionState]
def mkOnlyParams (config : Grind.Config) : MetaM Params := do
mkParams config #[ getOnlyExtensionState]
def mkMethods (evalTactic? : Option EvalTactic := none) : CoreM Methods := do
let builtinPropagators builtinPropagatorsRef.get
@@ -99,9 +124,11 @@ def GrindM.run (x : GrindM α) (params : Params) (evalTactic? : Option EvalTacti
let simp := params.norm
let config := params.config
let symPrios := params.symPrios
let extensions := params.extensions
let anchorRefs? := params.anchorRefs?
let funCCs := params.funCCs
x ( mkMethods evalTactic?).toMethodsRef { config, anchorRefs?, simpMethods, simp, trueExpr, falseExpr, natZExpr, btrueExpr, bfalseExpr, ordEqExpr, intExpr, symPrios, funCCs }
x ( mkMethods evalTactic?).toMethodsRef
{ config, anchorRefs?, simpMethods, simp, extensions, symPrios
trueExpr, falseExpr, natZExpr, btrueExpr, bfalseExpr, ordEqExpr, intExpr }
|>.run' { scState }
private def mkCleanState (mvarId : MVarId) (params : Params) : MetaM Clean.State := mvarId.withContext do
@@ -136,11 +163,11 @@ private def mkGoal (mvarId : MVarId) (params : Params) : GrindM Goal := do
let bfalseExpr getBoolFalseExpr
let natZeroExpr getNatZeroExpr
let ordEqExpr getOrderingEqExpr
let thmMap := params.ematch
let casesTypes := params.casesTypes
let thmEMatch := params.extensions.map fun ext => ext.ematch
let thmInj := params.extensions.map fun ext => ext.inj
let clean mkCleanState mvarId params
let sstates Solvers.mkInitialStates
GoalM.run' { mvarId, ematch.thmMap := thmMap, inj.thms := params.inj, split.casesTypes := casesTypes, clean, sstates } do
GoalM.run' { mvarId, ematch.thmMap := thmEMatch, inj.thms := thmInj, clean, sstates } do
initENodeCore falseExpr (interpreted := true) (ctor := false)
initENodeCore trueExpr (interpreted := true) (ctor := false)
initENodeCore btrueExpr (interpreted := false) (ctor := true)

View File

@@ -52,6 +52,9 @@ private def addBuiltin (propagatorName : Name) (stx : Syntax) : AttrM Unit := do
declareBuiltin initDeclName val
go.run' {}
/-
**Note**: We currently use the same propagators for all `grind` attributes.
-/
builtin_initialize
registerBuiltinAttribute {
ref := by exact decl_name%

View File

@@ -0,0 +1,30 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Tactic.Grind.Types
meta import Lean.Meta.Tactic.Grind.Attr
public section
namespace Lean.Meta.Grind
macro (name := _root_.Lean.Parser.Command.registerGrindAttr) doc:(docComment)?
"register_grind_attr" id:ident : command => do
let str1 := id.getId.toString
let idParser1 := mkIdentFrom id (`Lean.Parser.Attr ++ id.getId)
let str2 := id.getId.toString ++ "!"
let idParser2 := mkIdentFrom id (`Lean.Parser.Attr ++ (id.getId.appendAfter "!"))
let str3 := id.getId.toString ++ "?"
let idParser3 := mkIdentFrom id (`Lean.Parser.Attr ++ (id.getId.appendAfter "?"))
let str4 := id.getId.toString ++ "!?"
let idParser4 := mkIdentFrom id (`Lean.Parser.Attr ++ (id.getId.appendAfter "!?"))
`($[$doc:docComment]? initialize ext : Extension registerAttr $(quote id.getId) (ref := $(quote id.getId))
$[$doc:docComment]? syntax (name := $idParser1:ident) $(quote str1):str (ppSpace Lean.Parser.Attr.grindMod)? : attr
$[$doc:docComment]? syntax (name := $idParser2:ident) $(quote str2):str (ppSpace Lean.Parser.Attr.grindMod)? : attr
$[$doc:docComment]? syntax (name := $idParser3:ident) $(quote str3):str (ppSpace Lean.Parser.Attr.grindMod)? : attr
$[$doc:docComment]? syntax (name := $idParser4:ident) $(quote str4):str (ppSpace Lean.Parser.Attr.grindMod)? : attr
)
end Lean.Meta.Grind

View File

@@ -18,6 +18,9 @@ import Init.Grind.Norm
public section
namespace Lean.Meta.Grind
/-
TODO: group into a `grind` extension object
-/
builtin_initialize normExt : SimpExtension mkSimpExt
def registerNormTheorems (preDeclNames : Array Name) (postDeclNames : Array Name) : MetaM Unit := do
@@ -176,14 +179,18 @@ private def addDeclToUnfold (s : SimpTheorems) (declName : Name) : MetaM SimpThe
else
return s
/-- Returns the simplification context used by `grind`. -/
protected def getSimpContext (config : Grind.Config) : MetaM Simp.Context := do
def getNormTheorems : MetaM SimpTheorems := do
let mut thms normExt.getTheorems
thms addDeclToUnfold thms ``GE.ge
thms addDeclToUnfold thms ``GT.gt
thms addDeclToUnfold thms ``Nat.cast
thms addDeclToUnfold thms ``Bool.xor
thms addDeclToUnfold thms ``Ne
return thms
/-- Returns the simplification context used by `grind`. -/
protected def getSimpContext (config : Grind.Config) : MetaM Simp.Context := do
let thms getNormTheorems
Simp.mkContext
(config :=
{ arith := true

View File

@@ -174,4 +174,36 @@ def getProofForDecl (declName : Name) : MetaM Expr := do
let us := info.levelParams.map mkLevelParam
return mkConst declName us
/--
A `TheoremsArray α` is a collection of `Theorems α`.
The array is scanned linear during theorem activation.
This avoids the need for efficiently merging the `Theorems α` data structure.
-/
abbrev TheoremsArray (α : Type) := Array (Theorems α)
@[specialize]
def TheoremsArray.retrieve? (s : TheoremsArray α) (sym : Name) : Option (List α × TheoremsArray α) := Id.run do
for h : i in *...s.size do
if let some (thms, a) s[i].retrieve? sym then
return some (thms, s.set i a)
return none
def TheoremsArray.insert [TheoremLike α] (s : TheoremsArray α) (thm : α) : TheoremsArray α := Id.run do
if s.isEmpty then
let thms := { : Theorems α}
#[thms.insert thm]
else
s.modify 0 (·.insert thm)
def TheoremsArray.isErased (s : TheoremsArray α) (origin : Origin) : Bool :=
s.any fun thms => thms.erased.contains origin
def TheoremsArray.find (s : TheoremsArray α) (origin : Origin) : List α := Id.run do
let mut r := []
for h : i in *...s.size do
let thms := s[i].find origin
unless thms.isEmpty do
r := r ++ thms
return r
end Lean.Meta.Grind

View File

@@ -9,6 +9,7 @@ public import Lean.Meta.Tactic.Simp.Types
public import Lean.Meta.Tactic.Grind.AlphaShareCommon
public import Lean.Meta.Tactic.Grind.Attr
public import Lean.Meta.Tactic.Grind.CheckResult
public import Lean.Meta.Tactic.Grind.Extension
public import Init.Data.Queue
import Lean.Meta.Tactic.Grind.ExprPtr
import Lean.HeadIndex
@@ -158,8 +159,7 @@ structure Context where
splitSource : SplitSource := .input
/-- Symbol priorities for inferring E-matching patterns -/
symPrios : SymbolPriorities
/-- Global declarations marked with `@[grind funCC]` -/
funCCs : NameSet
extensions : ExtensionStateArray := #[]
trueExpr : Expr
falseExpr : Expr
natZExpr : Expr
@@ -346,6 +346,18 @@ def reportMVarInternalization : GrindM Bool :=
def getSymbolPriorities : GrindM SymbolPriorities := do
return ( readThe Context).symPrios
/--
Returns `true` if we `declName` is tagged with `funCC` modifier.
-/
def hasFunCCModifier (declName : Name) : GrindM Bool :=
return ( readThe Context).extensions.any fun ext => ext.funCC.contains declName
def isSplit (declName : Name) : GrindM Bool :=
return ( readThe Context).extensions.any fun ext => ext.casesTypes.isSplit declName
def isEagerSplit (declName : Name) : GrindM Bool :=
return ( readThe Context).extensions.any fun ext => ext.casesTypes.isEagerSplit declName
/--
Returns `true` if `declName` is the name of a `match` equation or a `match` congruence equation.
-/
@@ -758,7 +770,7 @@ structure EMatch.State where
Inactive global theorems. As we internalize terms, we activate theorems as we find their symbols.
Local theorem provided by users are added directly into `newThms`.
-/
thmMap : EMatchTheorems
thmMap : EMatchTheoremsArray
/-- Goal modification time. -/
gmt : Nat := 0
/-- Active theorems that we have performed ematching at least once. -/
@@ -840,8 +852,6 @@ structure SplitArg where
structure Split.State where
/-- Number of splits performed to get to this goal. -/
num : Nat := 0
/-- Inductive datatypes marked for case-splitting -/
casesTypes : CasesTypes := {}
/-- Case-split candidates. -/
candidates : List SplitInfo := []
/-- Case-splits that have been inserted at `candidates` at some point. -/
@@ -901,7 +911,7 @@ structure InjectiveInfo where
/-- State for injective theorem support. -/
structure Injective.State where
thms : InjectiveTheorems
thms : InjectiveTheoremsArray
fns : PHashMap ExprPtr InjectiveInfo := {}
deriving Inhabited

View File

@@ -55,7 +55,7 @@ def grindDischarger (mvarId : MVarId) : MetaM (Option (List MVarId)) := do
let [subgoal] mvarId.apply markerExpr
| return none
-- Solve the subgoal with grind
let params Grind.mkParams {}
let params Grind.mkDefaultParams {}
let result Grind.main subgoal params
if result.hasFailed then
return none

View File

@@ -126,7 +126,7 @@ def checkInductive (localDecl : LocalDecl) : M Unit := do
let .const declName _ := type.getAppFn | return ()
let .inductInfo val getConstInfo declName | return ()
if ( isEligible declName) then
unless ( Grind.isSplit declName) do
unless ( Grind.isGlobalSplit declName) do
modify fun s => { s with indCandidates := s.indCandidates.push { fvarId := localDecl.fvarId, val } }
unsafe abbrev Cache := PtrSet Expr

View File

@@ -1,3 +1,4 @@
// update me!
#include "util/options.h"
namespace lean {

View File

@@ -3,7 +3,7 @@ import Lean
open Lean Meta Tactic Grind
def runGrind (x : GrindM α) : MetaM α := do
GrindM.run x ( mkParams {})
GrindM.run x ( mkDefaultParams {})
@[noinline] def mkA (x : Nat) := x + 1

View File

@@ -31,3 +31,5 @@ initialize registerBuiltinAttribute {
logInfo m!"trace_add attribute added to {decl}"
-- applicationTime := .afterCompilation
}
register_grind_attr my_grind

View File

@@ -155,3 +155,22 @@ termination_by n => n
end
end TraceAdd
namespace GrindAttr
opaque f : Nat Nat
opaque g : Nat Nat
@[my_grind] theorem fax : f (f x) = f x := sorry
@[my_grind =] theorem fax2 : f (f (f x)) = f x := by
fail_if_success grind
grind [my_grind]
@[my_grind? .] theorem fg : g (f x) = x := sorry
@[my_grind? =] theorem fax3 : f (f (f x)) = f x := sorry
@[my_grind!? .] theorem fax4 : f (f (f x)) = f x := sorry
end GrindAttr