mirror of
https://github.com/leanprover/lean4.git
synced 2026-03-17 18:34:06 +00:00
Compare commits
14 Commits
57df23f27e
...
grind_ext_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
481cb95a59 | ||
|
|
d208dba93b | ||
|
|
eb3e6047b7 | ||
|
|
a75de82685 | ||
|
|
61127e672e | ||
|
|
795fc2ca63 | ||
|
|
0a9b6d538f | ||
|
|
34258fd00d | ||
|
|
bb031ff6d2 | ||
|
|
cd85eb6620 | ||
|
|
0feaf40fc5 | ||
|
|
32cd5390cd | ||
|
|
61d4b00b42 | ||
|
|
b8460f58bc |
@@ -255,6 +255,7 @@ theorem fg_eq (h : x > 0) : f (g x) = x
|
||||
-- With minimal subexpression:
|
||||
@[grind! <-] theorem fg_eq (h : x > 0) : f (g x) = x
|
||||
-- Pattern selected: `g x`
|
||||
```
|
||||
-/
|
||||
syntax (name := grind!) "grind!" (ppSpace grindMod)? : attr
|
||||
/--
|
||||
|
||||
@@ -88,15 +88,14 @@ def mkConfig (items : Array (TSyntax `Lean.Parser.Tactic.configItem)) : TermElab
|
||||
elabConfigItems defaultConfig items
|
||||
|
||||
def mkParams (config : Grind.Config) : MetaM Params := do
|
||||
let params ← Meta.Grind.mkParams config
|
||||
let casesTypes ← Grind.getCasesTypes
|
||||
let mut ematch ← getEMatchTheorems
|
||||
let params ← Meta.Grind.mkDefaultParams config
|
||||
let mut ematch := params.extensions[0]!.ematch
|
||||
for declName in muteExt.getState (← getEnv) do
|
||||
try
|
||||
ematch ← ematch.eraseDecl declName
|
||||
catch _ =>
|
||||
pure () -- Ignore failures here.
|
||||
return { params with ematch, casesTypes }
|
||||
return { params with extensions[0].ematch := ematch }
|
||||
|
||||
/-- Returns the total number of generated instances. -/
|
||||
def sum (cs : PHashMap Grind.Origin Nat) : Nat := Id.run do
|
||||
|
||||
@@ -210,22 +210,13 @@ def elabGrindParamsAndSuggestions
|
||||
def mkGrindParams
|
||||
(config : Grind.Config) (only : Bool) (ps : TSyntaxArray ``Parser.Tactic.grindParam) (mvarId : MVarId) :
|
||||
TermElabM Grind.Params := do
|
||||
let params ← Grind.mkParams config
|
||||
let ematch ← if only then pure default else Grind.getEMatchTheorems
|
||||
let inj ← if only then pure default else Grind.getInjectiveTheorems
|
||||
/-
|
||||
**Note**: We used to skip the global cases attribute when `only = true`, but
|
||||
this is not very effective. We now use anchors to restrict the set of case-splits.
|
||||
-/
|
||||
let casesTypes ← Grind.getCasesTypes
|
||||
let funCCs ← Grind.getFunCCSet
|
||||
let params := { params with ematch, casesTypes, inj, funCCs }
|
||||
let params ← if only then Grind.mkOnlyParams config else Grind.mkDefaultParams config
|
||||
let suggestions ← if config.suggestions then
|
||||
LibrarySuggestions.select mvarId { caller := some "grind" }
|
||||
else
|
||||
pure #[]
|
||||
let mut params ← elabGrindParamsAndSuggestions params ps suggestions (only := only) (lax := config.lax)
|
||||
trace[grind.debug.inj] "{params.inj.getOrigins.map (·.pp)}"
|
||||
trace[grind.debug.inj] "{params.extensions[0]!.inj.getOrigins.map (·.pp)}"
|
||||
if params.anchorRefs?.isSome then
|
||||
/-
|
||||
**Note**: anchors are automatically computed in interactive mode where
|
||||
|
||||
@@ -20,7 +20,49 @@ open Meta
|
||||
`grind` parameter elaboration
|
||||
-/
|
||||
|
||||
def warnRedundantEMatchArg (s : Grind.EMatchTheorems) (declName : Name) : MetaM Unit := do
|
||||
def _root_.Lean.Meta.Grind.Params.insertCasesTypes (params : Grind.Params) (declName : Name) (eager : Bool) : Grind.Params :=
|
||||
{ params with extensions := params.extensions.modify 0 fun ext => { ext with casesTypes := ext.casesTypes.insert declName eager } }
|
||||
|
||||
def _root_.Lean.Meta.Grind.Params.eraseCasesTypes (params : Grind.Params) (declName : Name) : CoreM Grind.Params := do
|
||||
unless params.extensions.any fun ext => ext.casesTypes.contains declName do
|
||||
Grind.throwNotMarkedWithGrindAttribute declName
|
||||
return { params with extensions := params.extensions.modify 0 fun ext => { ext with casesTypes := ext.casesTypes.erase declName } }
|
||||
|
||||
def _root_.Lean.Meta.Grind.Params.insertFunCC (params : Grind.Params) (declName : Name) : Grind.Params :=
|
||||
{ params with extensions := params.extensions.modify 0 fun ext => { ext with funCC := ext.funCC.insert declName } }
|
||||
|
||||
def _root_.Lean.Meta.Grind.Params.containsEMatch (params : Grind.Params) (declName : Name) : Bool :=
|
||||
params.extensions.any fun ext => ext.ematch.contains (.decl declName)
|
||||
|
||||
def _root_.Lean.Meta.Grind.Params.eraseEMatchCore (params : Grind.Params) (declName : Name) : Grind.Params :=
|
||||
{ params with extensions := params.extensions.modify 0 fun ext => { ext with ematch := ext.ematch.erase (.decl declName) } }
|
||||
|
||||
def _root_.Lean.Meta.Grind.Params.eraseEMatch (params : Grind.Params) (declName : Name) : MetaM Grind.Params := do
|
||||
if !wasOriginallyTheorem (← getEnv) declName then
|
||||
if let some eqns ← getEqnsFor? declName then
|
||||
unless eqns.all fun eqn => params.containsEMatch eqn do
|
||||
Grind.throwNotMarkedWithGrindAttribute declName
|
||||
return eqns.foldl (init := params) fun params eqn => params.eraseEMatchCore eqn
|
||||
else
|
||||
Grind.throwNotMarkedWithGrindAttribute declName
|
||||
else
|
||||
unless params.containsEMatch declName do
|
||||
Grind.throwNotMarkedWithGrindAttribute declName
|
||||
return params.eraseEMatchCore declName
|
||||
|
||||
def _root_.Lean.Meta.Grind.Params.eraseInj (params : Grind.Params) (declName : Name) : Grind.Params :=
|
||||
{ params with extensions := params.extensions.modify 0 fun ext => { ext with inj := ext.inj.erase (.decl declName) } }
|
||||
|
||||
def _root_.Lean.Meta.Grind.ExtensionStateArray.getKindsFor (s : Grind.ExtensionStateArray) (origin : Grind.Origin) : List Grind.EMatchTheoremKind := Id.run do
|
||||
let mut result := []
|
||||
for ext in s do
|
||||
let s : Grind.EMatchTheorems := ext.ematch
|
||||
let ks := s.getKindsFor origin
|
||||
unless ks.isEmpty do
|
||||
result := result ++ ks
|
||||
return result
|
||||
|
||||
def warnRedundantEMatchArg (s : Grind.ExtensionStateArray) (declName : Name) : MetaM Unit := do
|
||||
let minIndexable := false -- TODO: infer it
|
||||
let kinds ← match s.getKindsFor (.decl declName) with
|
||||
| [] => return ()
|
||||
@@ -54,9 +96,9 @@ public def addEMatchTheorem (params : Grind.Params) (id : Ident) (declName : Nam
|
||||
let thm₁ ← Grind.mkEMatchTheoremForDecl declName (.eqLhs gen) params.symPrios
|
||||
let thm₂ ← Grind.mkEMatchTheoremForDecl declName (.eqRhs gen) params.symPrios
|
||||
if warn &&
|
||||
params.ematch.containsWithSamePatterns thm₁.origin thm₁.patterns thm₁.cnstrs &&
|
||||
params.ematch.containsWithSamePatterns thm₂.origin thm₂.patterns thm₂.cnstrs then
|
||||
warnRedundantEMatchArg params.ematch declName
|
||||
params.extensions.containsWithSamePatterns thm₁.origin thm₁.patterns thm₁.cnstrs &&
|
||||
params.extensions.containsWithSamePatterns thm₂.origin thm₂.patterns thm₂.cnstrs then
|
||||
warnRedundantEMatchArg params.extensions declName
|
||||
return { params with extra := params.extra.push thm₁ |>.push thm₂ }
|
||||
| _ =>
|
||||
if kind matches .eqLhs _ | .eqRhs _ then
|
||||
@@ -65,8 +107,8 @@ public def addEMatchTheorem (params : Grind.Params) (id : Ident) (declName : Nam
|
||||
Grind.mkEMatchTheoremAndSuggest id declName params.symPrios minIndexable (isParam := true)
|
||||
else
|
||||
Grind.mkEMatchTheoremForDecl declName kind params.symPrios (minIndexable := minIndexable)
|
||||
if warn && params.ematch.containsWithSamePatterns thm.origin thm.patterns thm.cnstrs then
|
||||
warnRedundantEMatchArg params.ematch declName
|
||||
if warn && params.extensions.containsWithSamePatterns thm.origin thm.patterns thm.cnstrs then
|
||||
warnRedundantEMatchArg params.extensions declName
|
||||
return { params with extra := params.extra.push thm }
|
||||
| .defn =>
|
||||
if (← isReducible declName) then
|
||||
@@ -154,9 +196,13 @@ def processParam (params : Grind.Params)
|
||||
catch err =>
|
||||
if (← resolveLocalName id.getId).isSome then
|
||||
throwErrorAt id "redundant parameter `{id}`, `grind` uses local hypotheses automatically"
|
||||
else if let some ext ← Grind.getExtension? id.getId then
|
||||
if let some mod := mod? then
|
||||
throwErrorAt mod "invalid use of modifier in `grind` attribute `{id.getId}`"
|
||||
return { params with extensions := params.extensions.push (ext.getState (← getEnv)) }
|
||||
else if !id.getId.getPrefix.isAnonymous then
|
||||
-- Fall back to term elaboration for compound identifiers like `foo.le` (dot notation on declarations)
|
||||
return ← processTermParam params p mod? id minIndexable
|
||||
return (← processTermParam params p mod? id minIndexable)
|
||||
else
|
||||
throw err
|
||||
Linter.checkDeprecated declName
|
||||
@@ -179,7 +225,7 @@ def processParam (params : Grind.Params)
|
||||
if incremental then throwError "`cases` parameter are not supported here"
|
||||
ensureNoMinIndexable minIndexable
|
||||
withRef p <| Grind.validateCasesAttr declName eager
|
||||
params := { params with casesTypes := params.casesTypes.insert declName eager }
|
||||
params := params.insertCasesTypes declName eager
|
||||
| .intro =>
|
||||
if let some info ← Grind.isCasesAttrPredicateCandidate? declName false then
|
||||
if incremental then throwError "`cases` parameter are not supported here"
|
||||
@@ -194,7 +240,7 @@ def processParam (params : Grind.Params)
|
||||
throwError "`[grind ext]` cannot be set using parameters"
|
||||
| .infer =>
|
||||
if let some declName ← Grind.isCasesAttrCandidate? declName false then
|
||||
params := { params with casesTypes := params.casesTypes.insert declName false }
|
||||
params := params.insertCasesTypes declName false
|
||||
if let some info ← isInductivePredicate? declName then
|
||||
-- If it is an inductive predicate,
|
||||
-- we also add the constructors (intro rules) as E-matching rules
|
||||
@@ -207,7 +253,7 @@ def processParam (params : Grind.Params)
|
||||
ensureNoMinIndexable minIndexable
|
||||
params := { params with symPrios := params.symPrios.insert declName prio }
|
||||
| .funCC =>
|
||||
params := { params with funCCs := params.funCCs.insert declName }
|
||||
params := params.insertFunCC declName
|
||||
return params
|
||||
|
||||
/--
|
||||
@@ -228,11 +274,11 @@ public def elabGrindParams (params : Grind.Params) (ps : TSyntaxArray ``Parser.T
|
||||
Linter.checkDeprecated declName
|
||||
if let some declName ← Grind.isCasesAttrCandidate? declName false then
|
||||
Grind.ensureNotBuiltinCases declName
|
||||
params := { params with casesTypes := (← params.casesTypes.eraseDecl declName) }
|
||||
params ← params.eraseCasesTypes declName
|
||||
else if (← Grind.isInjectiveTheorem declName) then
|
||||
params := { params with inj := params.inj.erase (.decl declName) }
|
||||
params := params.eraseInj declName
|
||||
else
|
||||
params := { params with ematch := (← params.ematch.eraseDecl declName) }
|
||||
params ← params.eraseEMatch declName
|
||||
| `(Parser.Tactic.grindParam| $[$mod?:grindMod]? $id:ident) =>
|
||||
-- Check if this is dot notation on a local variable (e.g., `n.triv` for `Nat.triv n`).
|
||||
-- If so, process as term to let elaboration resolve the dot notation properly.
|
||||
@@ -294,7 +340,7 @@ public def withParams (params : Grind.Params) (ps : TSyntaxArray ``Parser.Tactic
|
||||
let mut params := params
|
||||
if only then
|
||||
params := { params with
|
||||
ematch := {}
|
||||
extensions := params.extensions.map fun ext => { ext with ematch := {} }
|
||||
anchorRefs? := none
|
||||
}
|
||||
params ← elabGrindParams params ps (only := only) (incremental := true)
|
||||
|
||||
@@ -48,6 +48,7 @@ public import Lean.Meta.Tactic.Grind.Filter
|
||||
public import Lean.Meta.Tactic.Grind.CollectParams
|
||||
public import Lean.Meta.Tactic.Grind.Finish
|
||||
public import Lean.Meta.Tactic.Grind.FunCC
|
||||
public import Lean.Meta.Tactic.Grind.RegisterCommand
|
||||
public section
|
||||
namespace Lean
|
||||
|
||||
|
||||
@@ -63,40 +63,89 @@ def getAttrKindFromOpt (stx : Syntax) : CoreM AttrKind := do
|
||||
def throwInvalidUsrModifier : CoreM α :=
|
||||
throwError "the modifier `usr` is only relevant in parameters for `grind only`"
|
||||
|
||||
private def Extension.addCasesAttr (ext : Extension) (declName : Name) (eager : Bool) (attrKind : AttributeKind) : CoreM Unit := do
|
||||
validateCasesAttr declName eager
|
||||
ext.add (.cases declName eager) attrKind
|
||||
|
||||
private def Extension.addExtAttr (ext : Extension) (declName : Name) (attrKind : AttributeKind) : CoreM Unit := do
|
||||
validateExtAttr declName
|
||||
ext.add (.ext declName) attrKind
|
||||
|
||||
private def Extension.addFunCCAttr (ext : Extension) (declName : Name) (attrKind : AttributeKind) : CoreM Unit := do
|
||||
ext.add (.funCC declName) attrKind
|
||||
|
||||
private def Extension.eraseExtAttr (ext : Extension) (declName : Name) : CoreM Unit := do
|
||||
let s := ext.getState (← getEnv)
|
||||
let extThms ← s.extThms.eraseDecl declName
|
||||
modifyEnv fun env => ext.modifyState env fun s => { s with extThms }
|
||||
|
||||
private def Extension.eraseCasesAttr (ext : Extension) (declName : Name) : CoreM Unit := do
|
||||
ensureNotBuiltinCases declName
|
||||
let s := ext.getState (← getEnv)
|
||||
let casesTypes ← s.casesTypes.eraseDecl declName
|
||||
modifyEnv fun env => ext.modifyState env fun s => { s with casesTypes }
|
||||
|
||||
private def Extension.eraseFunCCAttr (ext : Extension) (declName : Name) : CoreM Unit := do
|
||||
let s := ext.getState (← getEnv)
|
||||
unless s.funCC.contains declName do
|
||||
throwNotMarkedWithGrindAttribute declName
|
||||
let funCC := s.funCC.erase declName
|
||||
modifyEnv fun env => ext.modifyState env fun s => { s with funCC }
|
||||
|
||||
private def Extension.eraseEMatchAttr (ext : Extension) (declName : Name) : MetaM Unit := do
|
||||
let s := ext.getState (← getEnv)
|
||||
let ematch ← s.ematch.eraseDecl declName
|
||||
modifyEnv fun env => ext.modifyState env fun s => { s with ematch }
|
||||
|
||||
private def Extension.eraseInjectiveAttr (ext : Extension) (declName : Name) : MetaM Unit := do
|
||||
let s := ext.getState (← getEnv)
|
||||
let inj ← s.inj.eraseDecl declName
|
||||
modifyEnv fun env => ext.modifyState env fun s => { s with inj }
|
||||
|
||||
private def Extension.isExtTheorem (ext : Extension) (declName : Name) : CoreM Bool := do
|
||||
return ext.getState (← getEnv) |>.extThms.contains declName
|
||||
|
||||
private def Extension.isInjectiveTheorem (ext : Extension) (declName : Name) : CoreM Bool := do
|
||||
return ext.getState (← getEnv) |>.inj.contains (.decl declName)
|
||||
|
||||
private def Extension.hasFunCCAttr (ext : Extension) (declName : Name) : CoreM Bool := do
|
||||
return ext.getState (← getEnv) |>.funCC.contains declName
|
||||
|
||||
/--
|
||||
Auxiliary function for registering `grind`, `grind!`, `grind?`, and `grind!?` attributes.
|
||||
`grind!` is like `grind` but selects minimal indexable subterms.
|
||||
The `grind?` and `grind!?` are aliases for `grind` and `grind!` which displays patterns using `logInfo`.
|
||||
It is just a convenience for users.
|
||||
-/
|
||||
private def registerGrindAttr (minIndexable : Bool) (showInfo : Bool) : IO Unit :=
|
||||
private def mkGrindAttr (attrName : Name) (minIndexable : Bool) (showInfo : Bool) (ext? : Option Extension := none) (ref : Name := by exact decl_name%) : IO Unit :=
|
||||
registerBuiltinAttribute {
|
||||
ref := ref
|
||||
name := match minIndexable, showInfo with
|
||||
| false, false => `grind
|
||||
| false, true => `grind?
|
||||
| true, false => `grind!
|
||||
| true, true => `grind!?
|
||||
| false, false => attrName
|
||||
| false, true => attrName.appendAfter "?"
|
||||
| true, false => attrName.appendAfter "!"
|
||||
| true, true => attrName.appendAfter "!?"
|
||||
descr :=
|
||||
let header := match minIndexable, showInfo with
|
||||
| false, false => "The `[grind]` attribute is used to annotate declarations."
|
||||
| false, true => "The `[grind?]` attribute is identical to the `[grind]` attribute, but displays inferred pattern information."
|
||||
| true, false => "The `[grind!]` attribute is used to annotate declarations, but selecting minimal indexable subterms."
|
||||
| true, true => "The `[grind!?]` attribute is identical to the `[grind!]` attribute, but displays inferred pattern information."
|
||||
header ++ "\
|
||||
| false, false => s!"The `[{attrName}]` attribute is used to annotate declarations."
|
||||
| false, true => s!"The `[{attrName}?]` attribute is identical to the `[{attrName}]` attribute, but displays inferred pattern information."
|
||||
| true, false => s!"The `[{attrName}!]` attribute is used to annotate declarations, but selecting minimal indexable subterms."
|
||||
| true, true => s!"The `[{attrName}!?]` attribute is identical to the `[{attrName}!]` attribute, but displays inferred pattern information."
|
||||
header ++ s!"\
|
||||
\
|
||||
When applied to an equational theorem, `[grind =]`, `[grind =_]`, or `[grind _=_]`\
|
||||
will mark the theorem for use in heuristic instantiations by the `grind` tactic,
|
||||
When applied to an equational theorem, `[{attrName} =]`, `[{attrName} =_]`, or `[{attrName} _=_]`\
|
||||
will mark the theorem for use in heuristic instantiations by the `{attrName}` tactic,
|
||||
using respectively the left-hand side, the right-hand side, or both sides of the theorem.\
|
||||
When applied to a function, `[grind =]` automatically annotates the equational theorems associated with that function.\
|
||||
When applied to a theorem `[grind ←]` will instantiate the theorem whenever it encounters the conclusion of the theorem
|
||||
When applied to a function, `[{attrName} =]` automatically annotates the equational theorems associated with that function.\
|
||||
When applied to a theorem `[{attrName} ←]` will instantiate the theorem whenever it encounters the conclusion of the theorem
|
||||
(that is, it will use the theorem for backwards reasoning).\
|
||||
When applied to a theorem `[grind →]` will instantiate the theorem whenever it encounters sufficiently many of the propositional hypotheses
|
||||
When applied to a theorem `[{attrName} →]` will instantiate the theorem whenever it encounters sufficiently many of the propositional hypotheses
|
||||
(that is, it will use the theorem for forwards reasoning).\
|
||||
\
|
||||
The attribute `[grind]` by itself will effectively try `[grind ←]` (if the conclusion is sufficient for instantiation) and then `[grind →]`.\
|
||||
The attribute `[{attrName}]` by itself will effectively try `[{attrName} ←]` (if the conclusion is sufficient for instantiation) and then `[{attrName} →]`.\
|
||||
\
|
||||
The `grind` tactic utilizes annotated theorems to add instances of matching patterns into the local context during proof search.\
|
||||
For example, if a theorem `@[grind =] theorem foo_idempotent : foo (foo x) = foo x` is annotated,\
|
||||
For example, if a theorem `@[{attrName} =] theorem foo_idempotent : foo (foo x) = foo x` is annotated,\
|
||||
`grind` will add an instance of this theorem to the local context whenever it encounters the pattern `foo (foo x)`."
|
||||
applicationTime := .afterCompilation
|
||||
add := fun declName stx attrKind => MetaM.run' do
|
||||
@@ -105,49 +154,111 @@ private def registerGrindAttr (minIndexable : Bool) (showInfo : Bool) : IO Unit
|
||||
-- When the body is not available (i.e. the def equations are private), the attribute will not
|
||||
-- be exported; see `ematchTheoremsExt.exportEntry?`.
|
||||
withoutExporting do
|
||||
match (← getAttrKindFromOpt stx) with
|
||||
| .ematch .user => throwInvalidUsrModifier
|
||||
| .ematch k => addEMatchAttr declName attrKind k (← getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo)
|
||||
| .cases eager => addCasesAttr declName eager attrKind
|
||||
| .intro =>
|
||||
if let some info ← isCasesAttrPredicateCandidate? declName false then
|
||||
for ctor in info.ctors do
|
||||
addEMatchAttr ctor attrKind (.default false) (← getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo)
|
||||
else
|
||||
throwError "invalid `[grind intro]`, `{.ofConstName declName}` is not an inductive predicate"
|
||||
| .ext => addExtAttr declName attrKind
|
||||
| .infer =>
|
||||
if let some declName ← isCasesAttrCandidate? declName false then
|
||||
addCasesAttr declName false attrKind
|
||||
if let some info ← isInductivePredicate? declName then
|
||||
-- If it is an inductive predicate,
|
||||
-- we also add the constructors (intro rules) as E-matching rules
|
||||
if let some ext := ext? then
|
||||
match (← getAttrKindFromOpt stx) with
|
||||
| .symbol prio =>
|
||||
unless attrName == `grind do
|
||||
throwError "symbol priorities must be set using the default `[grind]` attribute"
|
||||
addSymbolPriorityAttr declName attrKind prio
|
||||
| .cases eager => ext.addCasesAttr declName eager attrKind
|
||||
| .funCC => ext.addFunCCAttr declName attrKind
|
||||
| .ext => ext.addExtAttr declName attrKind
|
||||
| .ematch .user => throwInvalidUsrModifier
|
||||
| .ematch k => ext.addEMatchAttr declName attrKind k (← getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo)
|
||||
| .intro =>
|
||||
if let some info ← isCasesAttrPredicateCandidate? declName false then
|
||||
for ctor in info.ctors do
|
||||
ext.addEMatchAttr ctor attrKind (.default false) (← getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo)
|
||||
else
|
||||
throwError "invalid `[{attrName} intro]`, `{.ofConstName declName}` is not an inductive predicate"
|
||||
| .infer =>
|
||||
if let some declName ← isCasesAttrCandidate? declName false then
|
||||
ext.addCasesAttr declName false attrKind
|
||||
if let some info ← isInductivePredicate? declName then
|
||||
-- If it is an inductive predicate,
|
||||
-- we also add the constructors (intro rules) as E-matching rules
|
||||
for ctor in info.ctors do
|
||||
ext.addEMatchAttr ctor attrKind (.default false) (← getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo)
|
||||
else
|
||||
ext.addEMatchAttrAndSuggest stx declName attrKind (← getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo)
|
||||
| .inj => ext.addInjectiveAttr declName attrKind
|
||||
else
|
||||
-- **TODO**: delete after update stage0 and new extension for default `grind` attribute
|
||||
match (← getAttrKindFromOpt stx) with
|
||||
| .ematch .user => throwInvalidUsrModifier
|
||||
| .ematch k => addEMatchAttr declName attrKind k (← getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo)
|
||||
| .cases eager => addCasesAttr declName eager attrKind
|
||||
| .intro =>
|
||||
if let some info ← isCasesAttrPredicateCandidate? declName false then
|
||||
for ctor in info.ctors do
|
||||
addEMatchAttr ctor attrKind (.default false) (← getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo)
|
||||
else
|
||||
addEMatchAttrAndSuggest stx declName attrKind (← getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo)
|
||||
| .symbol prio => addSymbolPriorityAttr declName attrKind prio
|
||||
| .inj => addInjectiveAttr declName attrKind
|
||||
| .funCC => addFunCCAttr declName attrKind
|
||||
else
|
||||
throwError "invalid `[{attrName} intro]`, `{.ofConstName declName}` is not an inductive predicate"
|
||||
| .ext => addExtAttr declName attrKind
|
||||
| .infer =>
|
||||
if let some declName ← isCasesAttrCandidate? declName false then
|
||||
addCasesAttr declName false attrKind
|
||||
if let some info ← isInductivePredicate? declName then
|
||||
-- If it is an inductive predicate,
|
||||
-- we also add the constructors (intro rules) as E-matching rules
|
||||
for ctor in info.ctors do
|
||||
addEMatchAttr ctor attrKind (.default false) (← getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo)
|
||||
else
|
||||
addEMatchAttrAndSuggest stx declName attrKind (← getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo)
|
||||
| .symbol prio => addSymbolPriorityAttr declName attrKind prio
|
||||
| .inj => addInjectiveAttr declName attrKind
|
||||
| .funCC => addFunCCAttr declName attrKind
|
||||
erase := fun declName => MetaM.run' do
|
||||
if showInfo then
|
||||
throwError "`[grind?]` is a helper attribute for displaying inferred patterns, if you want to remove the attribute, consider using `[grind]` instead"
|
||||
if (← isCasesAttrCandidate declName false) then
|
||||
eraseCasesAttr declName
|
||||
else if (← isExtTheorem declName) then
|
||||
eraseExtAttr declName
|
||||
else if (← isInjectiveTheorem declName) then
|
||||
eraseInjectiveAttr declName
|
||||
else if (← hasFunCCAttr declName) then
|
||||
eraseFunCCAttr declName
|
||||
throwError "`[{attrName}?]` is a helper attribute for displaying inferred patterns, if you want to remove the attribute, consider using `[{attrName}]` instead"
|
||||
if let some ext := ext? then
|
||||
if (← isCasesAttrCandidate declName false) then
|
||||
ext.eraseCasesAttr declName
|
||||
else if (← ext.isExtTheorem declName) then
|
||||
ext.eraseExtAttr declName
|
||||
else if (← ext.isInjectiveTheorem declName) then
|
||||
ext.eraseInjectiveAttr declName
|
||||
else if (← ext.hasFunCCAttr declName) then
|
||||
ext.eraseFunCCAttr declName
|
||||
else
|
||||
ext.eraseEMatchAttr declName
|
||||
else
|
||||
eraseEMatchAttr declName
|
||||
-- **TODO**: delete after update stage0 and new extension for default `grind` attribute
|
||||
if (← isCasesAttrCandidate declName false) then
|
||||
eraseCasesAttr declName
|
||||
else if (← isExtTheorem declName) then
|
||||
eraseExtAttr declName
|
||||
else if (← isInjectiveTheorem declName) then
|
||||
eraseInjectiveAttr declName
|
||||
else if (← hasFunCCAttr declName) then
|
||||
eraseFunCCAttr declName
|
||||
else
|
||||
eraseEMatchAttr declName
|
||||
}
|
||||
|
||||
private def registerDefaultGrindAttr (minIndexable : Bool) (showInfo : Bool) : IO Unit :=
|
||||
mkGrindAttr `grind minIndexable showInfo
|
||||
|
||||
builtin_initialize
|
||||
registerGrindAttr (minIndexable := false) (showInfo := true)
|
||||
registerGrindAttr (minIndexable := false) (showInfo := false)
|
||||
registerGrindAttr (minIndexable := true) (showInfo := true)
|
||||
registerGrindAttr (minIndexable := true) (showInfo := false)
|
||||
registerDefaultGrindAttr (minIndexable := false) (showInfo := true)
|
||||
registerDefaultGrindAttr (minIndexable := false) (showInfo := false)
|
||||
registerDefaultGrindAttr (minIndexable := true) (showInfo := true)
|
||||
registerDefaultGrindAttr (minIndexable := true) (showInfo := false)
|
||||
|
||||
abbrev ExtensionMap := Std.HashMap Name Extension
|
||||
|
||||
builtin_initialize extensionMapRef : IO.Ref ExtensionMap ← IO.mkRef {}
|
||||
|
||||
def getExtension? (attrName : Name) : IO (Option Extension) :=
|
||||
return (← extensionMapRef.get)[attrName]?
|
||||
|
||||
def registerAttr (attrName : Name) (ref : Name := by exact decl_name%) : IO Extension := do
|
||||
let ext ← mkExtension ref
|
||||
mkGrindAttr attrName (minIndexable := false) (showInfo := true) (ext? := some ext) (ref := ref)
|
||||
mkGrindAttr attrName (minIndexable := false) (showInfo := false) (ext? := some ext) (ref := ref)
|
||||
mkGrindAttr attrName (minIndexable := true) (showInfo := true) (ext? := some ext) (ref := ref)
|
||||
mkGrindAttr attrName (minIndexable := true) (showInfo := false) (ext? := some ext) (ref := ref)
|
||||
extensionMapRef.modify fun map => map.insert attrName ext
|
||||
return ext
|
||||
|
||||
end Lean.Meta.Grind
|
||||
|
||||
@@ -6,19 +6,18 @@ Authors: Leonardo de Moura
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Cases
|
||||
public import Lean.Meta.Tactic.Grind.Extension
|
||||
public section
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
/-- Types that `grind` will case-split on. -/
|
||||
structure CasesTypes where
|
||||
casesMap : PHashMap Name Bool := {}
|
||||
deriving Inhabited
|
||||
|
||||
-- TODO: delete
|
||||
structure CasesEntry where
|
||||
declName : Name
|
||||
eager : Bool
|
||||
deriving Inhabited
|
||||
|
||||
/-- A collection of `CasesTypes`. -/
|
||||
abbrev CasesTypesArray := Array CasesTypes
|
||||
|
||||
/--
|
||||
`grind` always case-splits on the following types. Even when using `grind only`.
|
||||
The goal is to reduce noise in the tactic generated by `grind?`
|
||||
@@ -43,9 +42,6 @@ def CasesTypes.contains (s : CasesTypes) (declName : Name) : Bool :=
|
||||
def CasesTypes.erase (s : CasesTypes) (declName : Name) : CasesTypes :=
|
||||
{ s with casesMap := s.casesMap.erase declName }
|
||||
|
||||
def CasesTypes.insert (s : CasesTypes) (declName : Name) (eager : Bool) : CasesTypes :=
|
||||
{ s with casesMap := s.casesMap.insert declName eager }
|
||||
|
||||
def CasesTypes.find? (s : CasesTypes) (declName : Name) : Option Bool :=
|
||||
s.casesMap.find? declName
|
||||
|
||||
@@ -55,6 +51,9 @@ def CasesTypes.isEagerSplit (s : CasesTypes) (declName : Name) : Bool :=
|
||||
def CasesTypes.isSplit (s : CasesTypes) (declName : Name) : Bool :=
|
||||
(s.casesMap.find? declName |>.isSome) || isBuiltinEagerCases declName
|
||||
|
||||
/-
|
||||
TODO: group into a `grind` extension object
|
||||
-/
|
||||
builtin_initialize casesExt : SimpleScopedEnvExtension CasesEntry CasesTypes ←
|
||||
registerSimpleScopedEnvExtension {
|
||||
initial := {}
|
||||
@@ -68,7 +67,7 @@ def getCasesTypes : CoreM CasesTypes :=
|
||||
return casesExt.getState (← getEnv)
|
||||
|
||||
/-- Returns `true` is `declName` is a builtin split or has been tagged with `[grind]` attribute. -/
|
||||
def isSplit (declName : Name) : CoreM Bool := do
|
||||
def isGlobalSplit (declName : Name) : CoreM Bool := do
|
||||
return (← getCasesTypes).isSplit declName
|
||||
|
||||
partial def isCasesAttrCandidate? (declName : Name) (eager : Bool) : CoreM (Option Name) := do
|
||||
@@ -98,7 +97,7 @@ def CasesTypes.eraseDecl (s : CasesTypes) (declName : Name) : CoreM CasesTypes :
|
||||
if s.contains declName then
|
||||
return s.erase declName
|
||||
else
|
||||
throwError "`{.ofConstName declName}` is not marked with the `[grind]` attribute"
|
||||
throwNotMarkedWithGrindAttribute declName
|
||||
|
||||
def ensureNotBuiltinCases (declName : Name) : CoreM Unit := do
|
||||
if isBuiltinEagerCases declName then
|
||||
|
||||
@@ -5,7 +5,7 @@ Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.Theorems
|
||||
public import Lean.Meta.Tactic.Grind.Extension
|
||||
import Init.Grind.Util
|
||||
import Lean.Util.ForEachExpr
|
||||
import Lean.Meta.Tactic.Grind.Util
|
||||
@@ -13,14 +13,7 @@ import Lean.Meta.Match.Basic
|
||||
import Lean.Meta.Tactic.TryThis
|
||||
public section
|
||||
namespace Lean.Meta.Grind
|
||||
/--
|
||||
`grind` uses symbol priorities when inferring patterns for E-matching.
|
||||
Symbols not in `map` are assumed to have default priority (i.e., `eval_prio default`).
|
||||
-/
|
||||
structure SymbolPriorities where
|
||||
map : PHashMap Name Nat := {}
|
||||
deriving Inhabited
|
||||
|
||||
-- TODO: delete
|
||||
structure SymbolPriorityEntry where
|
||||
declName : Name
|
||||
prio : Nat
|
||||
@@ -30,10 +23,6 @@ structure SymbolPriorityEntry where
|
||||
def SymbolPriorities.erase (s : SymbolPriorities) (declName : Name) : SymbolPriorities :=
|
||||
{ s with map := s.map.erase declName }
|
||||
|
||||
/-- Inserts `declName ↦ prio` into `s`. -/
|
||||
def SymbolPriorities.insert (s : SymbolPriorities) (declName : Name) (prio : Nat) : SymbolPriorities :=
|
||||
{ s with map := s.map.insert declName prio }
|
||||
|
||||
/-- Returns `declName` priority for E-matching pattern inference in `s`. -/
|
||||
def SymbolPriorities.getPrio (s : SymbolPriorities) (declName : Name) : Nat :=
|
||||
if let some prio := s.map.find? declName then
|
||||
@@ -48,6 +37,9 @@ Recall that symbols not in `s` are assumed to have default priority.
|
||||
def SymbolPriorities.contains (s : SymbolPriorities) (declName : Name) : Bool :=
|
||||
s.map.contains declName
|
||||
|
||||
/-
|
||||
TODO: group into a `grind` extension object
|
||||
-/
|
||||
private builtin_initialize symbolPrioExt : SimpleScopedEnvExtension SymbolPriorityEntry SymbolPriorities ←
|
||||
registerSimpleScopedEnvExtension {
|
||||
initial := {}
|
||||
@@ -286,19 +278,6 @@ def preprocessPattern (pat : Expr) (normalizePattern := true) : MetaM Expr := do
|
||||
let pat ← foldProjs pat
|
||||
return pat
|
||||
|
||||
inductive EMatchTheoremKind where
|
||||
| eqLhs (gen : Bool)
|
||||
| eqRhs (gen : Bool)
|
||||
| eqBoth (gen : Bool)
|
||||
| eqBwd
|
||||
| fwd
|
||||
| bwd (gen : Bool)
|
||||
| leftRight
|
||||
| rightLeft
|
||||
| default (gen : Bool)
|
||||
| user /- pattern specified using `grind_pattern` command -/
|
||||
deriving Inhabited, BEq, Repr, Hashable
|
||||
|
||||
def EMatchTheoremKind.isEqLhs : EMatchTheoremKind → Bool
|
||||
| .eqLhs _ => true
|
||||
| _ => false
|
||||
@@ -345,106 +324,13 @@ private def EMatchTheoremKind.explainFailure : EMatchTheoremKind → String
|
||||
| .default _ => "failed to find patterns"
|
||||
| .user => unreachable!
|
||||
|
||||
structure CnstrRHS where
|
||||
/-- Abstracted universe level param names in the `rhs` -/
|
||||
levelNames : Array Name
|
||||
/-- Number of abstracted metavariable in the `rhs` -/
|
||||
numMVars : Nat
|
||||
/-- The actual `rhs`. -/
|
||||
expr : Expr
|
||||
deriving Inhabited, BEq, Repr
|
||||
|
||||
/--
|
||||
Grind patterns may have constraints associated with them.
|
||||
-/
|
||||
inductive EMatchTheoremConstraint where
|
||||
| /--
|
||||
A constraint of the form `lhs =/= rhs`.
|
||||
The `lhs` is one of the bound variables, and the `rhs` an abstract term that must not be definitionally
|
||||
equal to a term `t` assigned to `lhs`. -/
|
||||
notDefEq (lhs : Nat) (rhs : CnstrRHS)
|
||||
| /--
|
||||
A constraint of the form `lhs =?= rhs`.
|
||||
The `lhs` is one of the bound variables, and the `rhs` an abstract term that must be definitionally
|
||||
equal to a term `t` assigned to `lhs`. -/
|
||||
defEq (lhs : Nat) (rhs : CnstrRHS)
|
||||
| /--
|
||||
A constraint of the form `size lhs < n`. The `lhs` is one of the bound variables.
|
||||
The size is computed ignoring implicit terms, but sharing is not taken into account.
|
||||
-/
|
||||
sizeLt (lhs : Nat) (n : Nat)
|
||||
| /--
|
||||
A constraint of the form `depth lhs < n`. The `lhs` is one of the bound variables.
|
||||
The depth is computed in constant time using the `approxDepth` field attached to expressions.
|
||||
-/
|
||||
depthLt (lhs : Nat) (n : Nat)
|
||||
| /--
|
||||
Instantiates the theorem only if its generation is less than `n`
|
||||
-/
|
||||
genLt (n : Nat)
|
||||
| /--
|
||||
Constraints of the form `is_ground x`. Instantiates the theorem only if
|
||||
`x` is ground term.
|
||||
-/
|
||||
isGround (bvarIdx : Nat)
|
||||
| /--
|
||||
Constraints of the form `is_value x` and `is_strict_value x`.
|
||||
A value is defined as
|
||||
- A constructor fully applied to value arguments.
|
||||
- A literal: numerals, strings, etc.
|
||||
- A lambda. In the strict case, lambdas are not considered.
|
||||
-/
|
||||
isValue (bvarIdx : Nat) (strict : Bool)
|
||||
| /--
|
||||
Instantiates the theorem only if less than `n` instances have been generated for this theorem.
|
||||
-/
|
||||
maxInsts (n : Nat)
|
||||
| /--
|
||||
It instructs `grind` to postpone the instantiation of the theorem until `e` is known to be `true`.
|
||||
-/
|
||||
guard (e : Expr)
|
||||
| /--
|
||||
Similar to `guard`, but checks whether `e` is implied by asserting `¬e`.
|
||||
-/
|
||||
check (e : Expr)
|
||||
| /--
|
||||
Constraints of the form `not_value x` and `not_strict_value x`.
|
||||
They are the negations of `is_value x` and `is_strict_value x`.
|
||||
-/
|
||||
notValue (bvarIdx : Nat) (strict : Bool)
|
||||
deriving Inhabited, Repr, BEq
|
||||
|
||||
/-- A theorem for heuristic instantiation based on E-matching. -/
|
||||
structure EMatchTheorem where
|
||||
/--
|
||||
It stores universe parameter names for universe polymorphic proofs.
|
||||
Recall that it is non-empty only when we elaborate an expression provided by the user.
|
||||
When `proof` is just a constant, we can use the universe parameter names stored in the declaration.
|
||||
-/
|
||||
levelParams : Array Name
|
||||
proof : Expr
|
||||
numParams : Nat
|
||||
patterns : List Expr
|
||||
/-- Contains all symbols used in `patterns`. -/
|
||||
symbols : List HeadIndex
|
||||
origin : Origin
|
||||
/-- The `kind` is used for generating the `patterns`. We save it here to implement `grind?`. -/
|
||||
kind : EMatchTheoremKind
|
||||
/-- Stores whether patterns were inferred using the minimal indexable subexpression condition. -/
|
||||
minIndexable : Bool
|
||||
cnstrs : List EMatchTheoremConstraint := []
|
||||
deriving Inhabited
|
||||
|
||||
instance : TheoremLike EMatchTheorem where
|
||||
getSymbols thm := thm.symbols
|
||||
setSymbols thm symbols := { thm with symbols }
|
||||
getOrigin thm := thm.origin
|
||||
getProof thm := thm.proof
|
||||
getLevelParams thm := thm.levelParams
|
||||
|
||||
/-- Set of E-matching theorems. -/
|
||||
abbrev EMatchTheorems := Theorems EMatchTheorem
|
||||
|
||||
/-- A collection of sets of E-matching theorems. -/
|
||||
abbrev EMatchTheoremsArray := TheoremsArray EMatchTheorem
|
||||
|
||||
/--
|
||||
Returns `true` if there is a theorem with exactly the same pattern and constraints is already in `s`
|
||||
-/
|
||||
@@ -453,6 +339,10 @@ def EMatchTheorems.containsWithSamePatterns (s : EMatchTheorems) (origin : Origi
|
||||
let thms := s.find origin
|
||||
thms.any fun thm => thm.patterns == patterns && thm.cnstrs == cnstrs
|
||||
|
||||
def ExtensionStateArray.containsWithSamePatterns (s : ExtensionStateArray) (origin : Origin)
|
||||
(patterns : List Expr) (cnstrs : List EMatchTheoremConstraint) : Bool :=
|
||||
s.any (EMatchTheorems.containsWithSamePatterns ·.ematch origin patterns cnstrs)
|
||||
|
||||
def EMatchTheorems.getKindsFor (s : EMatchTheorems) (origin : Origin) : List EMatchTheoremKind :=
|
||||
let thms := s.find origin
|
||||
thms.map (·.kind)
|
||||
@@ -460,6 +350,9 @@ def EMatchTheorems.getKindsFor (s : EMatchTheorems) (origin : Origin) : List EMa
|
||||
def EMatchTheorem.getProofWithFreshMVarLevels (thm : EMatchTheorem) : MetaM Expr := do
|
||||
Grind.getProofWithFreshMVarLevels thm
|
||||
|
||||
/-
|
||||
TODO: group into a `grind` extension object
|
||||
-/
|
||||
private builtin_initialize ematchTheoremsExt : SimpleScopedEnvExtension EMatchTheorem (Theorems EMatchTheorem) ←
|
||||
registerSimpleScopedEnvExtension {
|
||||
addEntry := Theorems.insert
|
||||
@@ -1482,6 +1375,7 @@ def mkEMatchEqTheoremsForDef? (declName : Name) (showInfo := false) : MetaM (Opt
|
||||
eqns.mapM fun eqn => do
|
||||
mkEMatchEqTheorem eqn (normalizePattern := true) (showInfo := showInfo)
|
||||
|
||||
-- TODO: delete
|
||||
private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (useLhs := true) (showInfo := false) : MetaM Unit := do
|
||||
if wasOriginallyTheorem (← getEnv) declName then
|
||||
ematchTheoremsExt.add (← mkEMatchEqTheorem declName (normalizePattern := true) (useLhs := useLhs) (gen := thmKind.gen) (showInfo := showInfo)) attrKind
|
||||
@@ -1492,26 +1386,34 @@ private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) (thmKind
|
||||
else
|
||||
throwError s!"`{thmKind.toAttribute false}` attribute can only be applied to equational theorems or function definitions"
|
||||
|
||||
private def Extension.addGrindEqAttr (ext : Extension) (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (useLhs := true) (showInfo := false) : MetaM Unit := do
|
||||
if wasOriginallyTheorem (← getEnv) declName then
|
||||
ext.add (.ematch (← mkEMatchEqTheorem declName (normalizePattern := true) (useLhs := useLhs) (gen := thmKind.gen) (showInfo := showInfo))) attrKind
|
||||
else if let some thms ← mkEMatchEqTheoremsForDef? declName (showInfo := showInfo) then
|
||||
unless useLhs do
|
||||
throwError "`{.ofConstName declName}` is a definition, you must only use the left-hand side for extracting patterns"
|
||||
thms.forM fun thm => ext.add (.ematch thm) attrKind
|
||||
else
|
||||
throwError s!"`{thmKind.toAttribute false}` attribute can only be applied to equational theorems or function definitions"
|
||||
|
||||
def EMatchTheorems.eraseDecl (s : EMatchTheorems) (declName : Name) : MetaM EMatchTheorems := do
|
||||
let throwErr {α} : MetaM α :=
|
||||
throwError "`{.ofConstName declName}` is not marked with the `[grind]` attribute"
|
||||
if !wasOriginallyTheorem (← getEnv) declName then
|
||||
if let some eqns ← getEqnsFor? declName then
|
||||
let s := ematchTheoremsExt.getState (← getEnv)
|
||||
unless eqns.all fun eqn => s.contains (.decl eqn) do
|
||||
throwErr
|
||||
throwNotMarkedWithGrindAttribute declName
|
||||
return eqns.foldl (init := s) fun s eqn => s.erase (.decl eqn)
|
||||
else
|
||||
throwErr
|
||||
throwNotMarkedWithGrindAttribute declName
|
||||
else
|
||||
unless ematchTheoremsExt.getState (← getEnv) |>.contains (.decl declName) do
|
||||
throwErr
|
||||
unless s.contains (.decl declName) do
|
||||
throwNotMarkedWithGrindAttribute declName
|
||||
return s.erase <| .decl declName
|
||||
|
||||
private def ensureNoMinIndexable (minIndexable : Bool) : MetaM Unit := do
|
||||
if minIndexable then
|
||||
throwError "redundant modifier `!` in `grind` attribute"
|
||||
|
||||
-- TODO: delete
|
||||
def addEMatchAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (prios : SymbolPriorities)
|
||||
(showInfo := false) (minIndexable := false) : MetaM Unit := do
|
||||
match thmKind with
|
||||
@@ -1534,6 +1436,28 @@ def addEMatchAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatch
|
||||
let thm ← mkEMatchTheoremForDecl declName thmKind prios (showInfo := showInfo) (minIndexable := minIndexable)
|
||||
ematchTheoremsExt.add thm attrKind
|
||||
|
||||
def Extension.addEMatchAttr (ext : Extension) (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (prios : SymbolPriorities)
|
||||
(showInfo := false) (minIndexable := false) : MetaM Unit := do
|
||||
match thmKind with
|
||||
| .eqLhs _ =>
|
||||
ensureNoMinIndexable minIndexable
|
||||
ext.addGrindEqAttr declName attrKind thmKind (useLhs := true) (showInfo := showInfo)
|
||||
| .eqRhs _ =>
|
||||
ensureNoMinIndexable minIndexable
|
||||
ext.addGrindEqAttr declName attrKind thmKind (useLhs := false) (showInfo := showInfo)
|
||||
| .eqBoth _ =>
|
||||
ensureNoMinIndexable minIndexable
|
||||
ext.addGrindEqAttr declName attrKind thmKind (useLhs := true) (showInfo := showInfo)
|
||||
ext.addGrindEqAttr declName attrKind thmKind (useLhs := false) (showInfo := showInfo)
|
||||
| _ =>
|
||||
let info ← getConstInfo declName
|
||||
if !wasOriginallyTheorem (← getEnv) declName && !info.isCtor && !info.isAxiom then
|
||||
ensureNoMinIndexable minIndexable
|
||||
ext.addGrindEqAttr declName attrKind thmKind (showInfo := showInfo)
|
||||
else
|
||||
let thm ← mkEMatchTheoremForDecl declName thmKind prios (showInfo := showInfo) (minIndexable := minIndexable)
|
||||
ext.add (.ematch thm) attrKind
|
||||
|
||||
private structure SelectM.State where
|
||||
-- **Note**: hack, an attribute is not a tactic.
|
||||
suggestions : Array Tactic.TryThis.Suggestion := #[]
|
||||
@@ -1665,6 +1589,7 @@ Tries different modifiers, logs info messages with modifiers that worked, but st
|
||||
Remark: if `backward.grind.inferPattern` is `true`, then `.default false` is used.
|
||||
The parameter `showInfo` is only taken into account when `backward.grind.inferPattern` is `true`.
|
||||
-/
|
||||
-- TODO: delete
|
||||
def addEMatchAttrAndSuggest (ref : Syntax) (declName : Name) (attrKind : AttributeKind) (prios : SymbolPriorities)
|
||||
(minIndexable : Bool) (showInfo : Bool) (isParam : Bool := false) : MetaM Unit := do
|
||||
let info ← getConstInfo declName
|
||||
@@ -1677,6 +1602,25 @@ def addEMatchAttrAndSuggest (ref : Syntax) (declName : Name) (attrKind : Attribu
|
||||
let thm ← mkEMatchTheoremAndSuggest ref declName prios minIndexable isParam
|
||||
ematchTheoremsExt.add thm attrKind
|
||||
|
||||
/--
|
||||
Tries different modifiers, logs info messages with modifiers that worked, but stores just the first one that worked.
|
||||
|
||||
Remark: if `backward.grind.inferPattern` is `true`, then `.default false` is used.
|
||||
The parameter `showInfo` is only taken into account when `backward.grind.inferPattern` is `true`.
|
||||
-/
|
||||
-- TODO: delete
|
||||
def Extension.addEMatchAttrAndSuggest (ext : Extension) (ref : Syntax) (declName : Name) (attrKind : AttributeKind) (prios : SymbolPriorities)
|
||||
(minIndexable : Bool) (showInfo : Bool) (isParam : Bool := false) : MetaM Unit := do
|
||||
let info ← getConstInfo declName
|
||||
if !wasOriginallyTheorem (← getEnv) declName && !info.isCtor && !info.isAxiom then
|
||||
ensureNoMinIndexable minIndexable
|
||||
ext.addGrindEqAttr declName attrKind (.default false) (showInfo := showInfo)
|
||||
else if backward.grind.inferPattern.get (← getOptions) then
|
||||
ext.addEMatchAttr declName attrKind (.default false) prios (minIndexable := minIndexable) (showInfo := showInfo)
|
||||
else
|
||||
let thm ← mkEMatchTheoremAndSuggest ref declName prios minIndexable isParam
|
||||
ext.add (.ematch thm) attrKind
|
||||
|
||||
def eraseEMatchAttr (declName : Name) : MetaM Unit := do
|
||||
/-
|
||||
Remark: consider the following example
|
||||
|
||||
@@ -6,13 +6,14 @@ Authors: Leonardo de Moura
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Ext
|
||||
public import Lean.Meta.Tactic.Grind.Extension
|
||||
public section
|
||||
namespace Lean.Meta.Grind
|
||||
/-! Grind extensionality attribute to mark which `[ext]` theorems should be used. -/
|
||||
|
||||
/-- Extensionality theorems that can be used by `grind` -/
|
||||
abbrev ExtTheorems := PHashSet Name
|
||||
|
||||
/-
|
||||
TODO: group into a `grind` extension object
|
||||
-/
|
||||
builtin_initialize extTheoremsExt : SimpleScopedEnvExtension Name ExtTheorems ←
|
||||
registerSimpleScopedEnvExtension {
|
||||
initial := {}
|
||||
@@ -28,7 +29,7 @@ def addExtAttr (declName : Name) (attrKind : AttributeKind) : CoreM Unit := do
|
||||
validateExtAttr declName
|
||||
extTheoremsExt.add declName attrKind
|
||||
|
||||
private def eraseDecl (s : ExtTheorems) (declName : Name) : CoreM ExtTheorems := do
|
||||
def ExtTheorems.eraseDecl (s : ExtTheorems) (declName : Name) : CoreM ExtTheorems := do
|
||||
if s.contains declName then
|
||||
return s.erase declName
|
||||
else
|
||||
@@ -36,10 +37,13 @@ private def eraseDecl (s : ExtTheorems) (declName : Name) : CoreM ExtTheorems :=
|
||||
|
||||
def eraseExtAttr (declName : Name) : CoreM Unit := do
|
||||
let s := extTheoremsExt.getState (← getEnv)
|
||||
let s ← eraseDecl s declName
|
||||
let s ← s.eraseDecl declName
|
||||
modifyEnv fun env => extTheoremsExt.modifyState env fun _ => s
|
||||
|
||||
def isExtTheorem (declName : Name) : CoreM Bool := do
|
||||
return extTheoremsExt.getState (← getEnv) |>.contains declName
|
||||
|
||||
def getGlobalExtTheorems : CoreM ExtTheorems := do
|
||||
return extTheoremsExt.getState (← getEnv)
|
||||
|
||||
end Lean.Meta.Grind
|
||||
|
||||
221
src/Lean/Meta/Tactic/Grind/Extension.lean
Normal file
221
src/Lean/Meta/Tactic/Grind/Extension.lean
Normal file
@@ -0,0 +1,221 @@
|
||||
/-
|
||||
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Expr
|
||||
public import Lean.Data.PersistentHashMap
|
||||
public import Lean.Meta.Tactic.Grind.Theorems
|
||||
public section
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
/-- Types that `grind` will case-split on. -/
|
||||
structure CasesTypes where
|
||||
casesMap : PHashMap Name Bool := {}
|
||||
deriving Inhabited
|
||||
|
||||
def CasesTypes.insert (s : CasesTypes) (declName : Name) (eager : Bool) : CasesTypes :=
|
||||
{ s with casesMap := s.casesMap.insert declName eager }
|
||||
|
||||
abbrev ExtTheorems := PHashSet Name
|
||||
|
||||
structure SymbolPriorities where
|
||||
map : PHashMap Name Nat := {}
|
||||
deriving Inhabited
|
||||
|
||||
/-- Inserts `declName ↦ prio` into `s`. -/
|
||||
def SymbolPriorities.insert (s : SymbolPriorities) (declName : Name) (prio : Nat) : SymbolPriorities :=
|
||||
{ s with map := s.map.insert declName prio }
|
||||
|
||||
inductive EMatchTheoremKind where
|
||||
| eqLhs (gen : Bool)
|
||||
| eqRhs (gen : Bool)
|
||||
| eqBoth (gen : Bool)
|
||||
| eqBwd
|
||||
| fwd
|
||||
| bwd (gen : Bool)
|
||||
| leftRight
|
||||
| rightLeft
|
||||
| default (gen : Bool)
|
||||
| user /- pattern specified using `grind_pattern` command -/
|
||||
deriving Inhabited, BEq, Repr, Hashable
|
||||
|
||||
structure CnstrRHS where
|
||||
/-- Abstracted universe level param names in the `rhs` -/
|
||||
levelNames : Array Name
|
||||
/-- Number of abstracted metavariable in the `rhs` -/
|
||||
numMVars : Nat
|
||||
/-- The actual `rhs`. -/
|
||||
expr : Expr
|
||||
deriving Inhabited, BEq, Repr
|
||||
|
||||
/--
|
||||
Grind patterns may have constraints associated with them.
|
||||
-/
|
||||
inductive EMatchTheoremConstraint where
|
||||
| /--
|
||||
A constraint of the form `lhs =/= rhs`.
|
||||
The `lhs` is one of the bound variables, and the `rhs` an abstract term that must not be definitionally
|
||||
equal to a term `t` assigned to `lhs`. -/
|
||||
notDefEq (lhs : Nat) (rhs : CnstrRHS)
|
||||
| /--
|
||||
A constraint of the form `lhs =?= rhs`.
|
||||
The `lhs` is one of the bound variables, and the `rhs` an abstract term that must be definitionally
|
||||
equal to a term `t` assigned to `lhs`. -/
|
||||
defEq (lhs : Nat) (rhs : CnstrRHS)
|
||||
| /--
|
||||
A constraint of the form `size lhs < n`. The `lhs` is one of the bound variables.
|
||||
The size is computed ignoring implicit terms, but sharing is not taken into account.
|
||||
-/
|
||||
sizeLt (lhs : Nat) (n : Nat)
|
||||
| /--
|
||||
A constraint of the form `depth lhs < n`. The `lhs` is one of the bound variables.
|
||||
The depth is computed in constant time using the `approxDepth` field attached to expressions.
|
||||
-/
|
||||
depthLt (lhs : Nat) (n : Nat)
|
||||
| /--
|
||||
Instantiates the theorem only if its generation is less than `n`
|
||||
-/
|
||||
genLt (n : Nat)
|
||||
| /--
|
||||
Constraints of the form `is_ground x`. Instantiates the theorem only if
|
||||
`x` is ground term.
|
||||
-/
|
||||
isGround (bvarIdx : Nat)
|
||||
| /--
|
||||
Constraints of the form `is_value x` and `is_strict_value x`.
|
||||
A value is defined as
|
||||
- A constructor fully applied to value arguments.
|
||||
- A literal: numerals, strings, etc.
|
||||
- A lambda. In the strict case, lambdas are not considered.
|
||||
-/
|
||||
isValue (bvarIdx : Nat) (strict : Bool)
|
||||
| /--
|
||||
Instantiates the theorem only if less than `n` instances have been generated for this theorem.
|
||||
-/
|
||||
maxInsts (n : Nat)
|
||||
| /--
|
||||
It instructs `grind` to postpone the instantiation of the theorem until `e` is known to be `true`.
|
||||
-/
|
||||
guard (e : Expr)
|
||||
| /--
|
||||
Similar to `guard`, but checks whether `e` is implied by asserting `¬e`.
|
||||
-/
|
||||
check (e : Expr)
|
||||
| /--
|
||||
Constraints of the form `not_value x` and `not_strict_value x`.
|
||||
They are the negations of `is_value x` and `is_strict_value x`.
|
||||
-/
|
||||
notValue (bvarIdx : Nat) (strict : Bool)
|
||||
deriving Inhabited, Repr, BEq
|
||||
|
||||
/-- A theorem for heuristic instantiation based on E-matching. -/
|
||||
structure EMatchTheorem where
|
||||
/--
|
||||
It stores universe parameter names for universe polymorphic proofs.
|
||||
Recall that it is non-empty only when we elaborate an expression provided by the user.
|
||||
When `proof` is just a constant, we can use the universe parameter names stored in the declaration.
|
||||
-/
|
||||
levelParams : Array Name
|
||||
proof : Expr
|
||||
numParams : Nat
|
||||
patterns : List Expr
|
||||
/-- Contains all symbols used in `patterns`. -/
|
||||
symbols : List HeadIndex
|
||||
origin : Origin
|
||||
/-- The `kind` is used for generating the `patterns`. We save it here to implement `grind?`. -/
|
||||
kind : EMatchTheoremKind
|
||||
/-- Stores whether patterns were inferred using the minimal indexable subexpression condition. -/
|
||||
minIndexable : Bool
|
||||
cnstrs : List EMatchTheoremConstraint := []
|
||||
deriving Inhabited
|
||||
|
||||
instance : TheoremLike EMatchTheorem where
|
||||
getSymbols thm := thm.symbols
|
||||
setSymbols thm symbols := { thm with symbols }
|
||||
getOrigin thm := thm.origin
|
||||
getProof thm := thm.proof
|
||||
getLevelParams thm := thm.levelParams
|
||||
|
||||
/-- A theorem marked with `@[grind inj]` -/
|
||||
structure InjectiveTheorem where
|
||||
levelParams : Array Name
|
||||
proof : Expr
|
||||
/-- Contains all symbols used in the term `f` at the theorem's conclusion: `Function.Injective f`. -/
|
||||
symbols : List HeadIndex
|
||||
origin : Origin
|
||||
deriving Inhabited
|
||||
|
||||
instance : TheoremLike InjectiveTheorem where
|
||||
getSymbols thm := thm.symbols
|
||||
setSymbols thm symbols := { thm with symbols }
|
||||
getOrigin thm := thm.origin
|
||||
getProof thm := thm.proof
|
||||
getLevelParams thm := thm.levelParams
|
||||
|
||||
inductive Entry where
|
||||
| ext (declName : Name)
|
||||
| funCC (declName : Name)
|
||||
| cases (declName : Name) (eager : Bool)
|
||||
| ematch (thm : EMatchTheorem)
|
||||
| inj (thm : InjectiveTheorem)
|
||||
deriving Inhabited
|
||||
|
||||
/-
|
||||
**Note**: We currently have a single normalization and symbol priority sets for all `grind` attributes.
|
||||
Reason: the E-match patterns must be normalized with respect to them. If we are using multiple
|
||||
`grind` attributes, they patterns would have to be re-normalized using the union of all normalizers.
|
||||
|
||||
Alternative design: allow a single `grind` attribute per `grind` call. Cons: when creating a new
|
||||
`grind` attribute users would have to carefully setup the normalizer to ensure all `grind` assumptions
|
||||
are met. Cons: users would not be able to write `grind only [attr_1, attr_2]`.
|
||||
-/
|
||||
|
||||
structure ExtensionState where
|
||||
casesTypes : CasesTypes := {}
|
||||
extThms : ExtTheorems := {}
|
||||
funCC : NameSet := {}
|
||||
ematch : Theorems EMatchTheorem := {}
|
||||
inj : Theorems InjectiveTheorem := {}
|
||||
deriving Inhabited
|
||||
|
||||
abbrev Extension := SimpleScopedEnvExtension Entry ExtensionState
|
||||
|
||||
def ExtensionState.addEntry (s : ExtensionState) (e : Entry) : ExtensionState :=
|
||||
match e with
|
||||
| .cases declName eager => { s with casesTypes := s.casesTypes.insert declName eager }
|
||||
| .ext declName => { s with extThms := s.extThms.insert declName }
|
||||
| .funCC declName => { s with funCC := s.funCC.insert declName }
|
||||
| .ematch thm => { s with ematch := s.ematch.insert thm }
|
||||
| .inj thm => { s with inj := s.inj.insert thm }
|
||||
|
||||
def mkExtension (name : Name := by exact decl_name%) : IO Extension :=
|
||||
registerSimpleScopedEnvExtension {
|
||||
name := name
|
||||
initial := {}
|
||||
addEntry := ExtensionState.addEntry
|
||||
exportEntry? := fun lvl e => do
|
||||
-- export only annotations on public decls
|
||||
let declName := match e with
|
||||
| .inj thm | .ematch thm =>
|
||||
match thm.origin with
|
||||
| .decl declName => declName
|
||||
| _ => unreachable!
|
||||
| .ext declName | .cases declName _ | .funCC declName => declName
|
||||
guard (lvl == .private || !isPrivateName declName)
|
||||
return e
|
||||
}
|
||||
|
||||
/--
|
||||
`grind` is parametrized by a collection of `ExtensionState`. The motivation is to allow
|
||||
users to use multiple extensions simultaneously without merging them into a single structure.
|
||||
The collection is scanned linearly. In practice, we expect the array to be very small.
|
||||
-/
|
||||
abbrev ExtensionStateArray := Array ExtensionState
|
||||
|
||||
def throwNotMarkedWithGrindAttribute (declName : Name) : CoreM α :=
|
||||
throwError "`{.ofConstName declName}` is not marked with the `[grind]` attribute"
|
||||
|
||||
end Lean.Meta.Grind
|
||||
@@ -9,6 +9,9 @@ public import Lean.ScopedEnvExtension
|
||||
public section
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
/-
|
||||
TODO: group into a `grind` extension object
|
||||
-/
|
||||
private builtin_initialize funCCExt : SimpleScopedEnvExtension Name NameSet ←
|
||||
registerSimpleScopedEnvExtension {
|
||||
initial := {}
|
||||
|
||||
@@ -14,25 +14,15 @@ builtin_initialize registerTraceClass `grind.inj
|
||||
builtin_initialize registerTraceClass `grind.inj.assert
|
||||
builtin_initialize registerTraceClass `grind.debug.inj
|
||||
|
||||
/-- A theorem marked with `@[grind inj]` -/
|
||||
structure InjectiveTheorem where
|
||||
levelParams : Array Name
|
||||
proof : Expr
|
||||
/-- Contains all symbols used in the term `f` at the theorem's conclusion: `Function.Injective f`. -/
|
||||
symbols : List HeadIndex
|
||||
origin : Origin
|
||||
deriving Inhabited
|
||||
|
||||
instance : TheoremLike InjectiveTheorem where
|
||||
getSymbols thm := thm.symbols
|
||||
setSymbols thm symbols := { thm with symbols }
|
||||
getOrigin thm := thm.origin
|
||||
getProof thm := thm.proof
|
||||
getLevelParams thm := thm.levelParams
|
||||
|
||||
/-- Set of Injective theorems. -/
|
||||
abbrev InjectiveTheorems := Theorems InjectiveTheorem
|
||||
|
||||
/-- A collections of sets of Injective theorems. -/
|
||||
abbrev InjectiveTheoremsArray := TheoremsArray InjectiveTheorem
|
||||
|
||||
/-
|
||||
TODO: group into a `grind` extension object
|
||||
-/
|
||||
private builtin_initialize injectiveTheoremsExt : SimpleScopedEnvExtension InjectiveTheorem (Theorems InjectiveTheorem) ←
|
||||
registerSimpleScopedEnvExtension {
|
||||
addEntry := Theorems.insert
|
||||
@@ -85,9 +75,13 @@ def mkInjectiveTheorem (declName : Name) : MetaM InjectiveTheorem := do
|
||||
proof, symbols
|
||||
}
|
||||
|
||||
-- TODO: delete
|
||||
def addInjectiveAttr (declName : Name) (attrKind : AttributeKind) : MetaM Unit := do
|
||||
injectiveTheoremsExt.add (← mkInjectiveTheorem declName) attrKind
|
||||
|
||||
def Extension.addInjectiveAttr (ext : Extension) (declName : Name) (attrKind : AttributeKind) : MetaM Unit := do
|
||||
ext.add (.inj (← mkInjectiveTheorem declName)) attrKind
|
||||
|
||||
def eraseInjectiveAttr (declName : Name) : MetaM Unit := do
|
||||
let s := injectiveTheoremsExt.getState (← getEnv)
|
||||
let s ← s.eraseDecl declName
|
||||
|
||||
@@ -146,13 +146,13 @@ private def checkAndAddSplitCandidate (e : Expr) : GoalM Unit := do
|
||||
return ()
|
||||
unless (← isInductivePredicate declName) do
|
||||
return ()
|
||||
if (← get).split.casesTypes.isSplit declName then
|
||||
if (← isSplit declName) then
|
||||
addDefaultSplitCandidate e
|
||||
else if (← getConfig).splitIndPred then
|
||||
addDefaultSplitCandidate e
|
||||
| .fvar .. =>
|
||||
let .const declName _ := (← whnf (← inferType e)).getAppFn | return ()
|
||||
if (← get).split.casesTypes.isSplit declName then
|
||||
if (← isSplit declName) then
|
||||
addDefaultSplitCandidate e
|
||||
| .forallE _ d _ _ =>
|
||||
let currSplitSource := (← readThe Context).splitSource
|
||||
@@ -275,8 +275,8 @@ private def addMatchEqns (f : Expr) (generation : Nat) : GoalM Unit := do
|
||||
|
||||
@[specialize]
|
||||
private def activateTheoremsCore [TheoremLike α] (declName : Name)
|
||||
(getThms : GoalM (Theorems α))
|
||||
(setThms : Theorems α → GoalM Unit)
|
||||
(getThms : GoalM (TheoremsArray α))
|
||||
(setThms : TheoremsArray α → GoalM Unit)
|
||||
(reinsertThm : α → GoalM Unit)
|
||||
(activateThm : α → GoalM Unit) : GoalM Unit := do
|
||||
if let some (thms, s) := (← getThms).retrieve? declName then
|
||||
@@ -444,7 +444,7 @@ private def tryEta (e : Expr) (generation : Nat) : GoalM Unit := do
|
||||
Returns `true` if we should use `funCC` for applications of the given constant symbol.
|
||||
-/
|
||||
private def useFunCongrAtDecl (declName : Name) : GrindM Bool := do
|
||||
if (← readThe Grind.Context).funCCs.contains declName then
|
||||
if (← hasFunCCModifier declName) then
|
||||
return true
|
||||
if (← isInstance declName) then
|
||||
/- **Note**: Instances are support elements. No `funCC` -/
|
||||
|
||||
@@ -182,9 +182,9 @@ private partial def introNext (goal : Goal) (generation : Nat) : GrindM IntroRes
|
||||
else
|
||||
return .done goal
|
||||
|
||||
private def isEagerCasesCandidate (goal : Goal) (type : Expr) : Bool := Id.run do
|
||||
private def isEagerCasesCandidate (type : Expr) : GrindM Bool := do
|
||||
let .const declName _ := type.getAppFn | return false
|
||||
return goal.split.casesTypes.isEagerSplit declName
|
||||
isEagerSplit declName
|
||||
|
||||
/-- Returns `true` if `type` is an inductive type with at most one constructor. -/
|
||||
private def isCheapInductive (type : Expr) : CoreM Bool := do
|
||||
@@ -215,7 +215,7 @@ private def applyCases? (goal : Goal) (fvarId : FVarId) (kp : ActionCont) : Grin
|
||||
Example: `a ∣ b` is defined as `∃ x, b = a * x`
|
||||
-/
|
||||
let type ← whnf (← fvarId.getType)
|
||||
unless isEagerCasesCandidate goal type do return none
|
||||
unless (← isEagerCasesCandidate type) do return none
|
||||
if (← cheapCasesOnly) then
|
||||
unless (← isCheapInductive type) do return none
|
||||
if let .const declName _ := type.getAppFn then
|
||||
@@ -268,7 +268,7 @@ def intros (generation : Nat) : Action :=
|
||||
|
||||
/-- Asserts a new fact `prop` with proof `proof` to the given `goal`. -/
|
||||
private def assertAt (proof : Expr) (prop : Expr) (generation : Nat) : Action := fun goal kna kp => do
|
||||
if isEagerCasesCandidate goal prop then
|
||||
if (← isEagerCasesCandidate prop) then
|
||||
let mvarId ← goal.mvarId.assert (← mkFreshUserName `h) prop proof
|
||||
intros generation { goal with mvarId } kna kp
|
||||
else goal.withContext do
|
||||
|
||||
@@ -32,26 +32,51 @@ import Lean.Meta.Tactic.Grind.Core
|
||||
public section
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
/--
|
||||
Returns the `ExtensionState` for the default `grind` attribute.
|
||||
-/
|
||||
def getDefaultExtensionState : MetaM ExtensionState := do
|
||||
-- **TODO**: update after update stage0
|
||||
let casesTypes ← getCasesTypes
|
||||
let funCC ← getFunCCSet
|
||||
let extThms ← getGlobalExtTheorems
|
||||
let ematch ← getEMatchTheorems
|
||||
let inj ← getInjectiveTheorems
|
||||
return {
|
||||
casesTypes, funCC, extThms, ematch, inj
|
||||
}
|
||||
|
||||
def getOnlyExtensionState : MetaM ExtensionState := do
|
||||
let casesTypes ← getCasesTypes
|
||||
let funCC ← getFunCCSet
|
||||
let extThms ← getGlobalExtTheorems
|
||||
return {
|
||||
casesTypes, funCC, extThms
|
||||
}
|
||||
|
||||
structure Params where
|
||||
config : Grind.Config
|
||||
ematch : EMatchTheorems := default
|
||||
inj : InjectiveTheorems := default
|
||||
symPrios : SymbolPriorities := {}
|
||||
casesTypes : CasesTypes := {}
|
||||
extensions : ExtensionStateArray := #[]
|
||||
extra : PArray EMatchTheorem := {}
|
||||
extraInj : PArray InjectiveTheorem := {}
|
||||
extraFacts : PArray Expr := {}
|
||||
funCCs : NameSet := {}
|
||||
symPrios : SymbolPriorities := {}
|
||||
norm : Simp.Context
|
||||
normProcs : Array Simprocs
|
||||
anchorRefs? : Option (Array AnchorRef) := none
|
||||
-- TODO: inductives to split
|
||||
|
||||
def mkParams (config : Grind.Config) : MetaM Params := do
|
||||
def mkParams (config : Grind.Config) (extensions : ExtensionStateArray) : MetaM Params := do
|
||||
let norm ← Grind.getSimpContext config
|
||||
let normProcs ← Grind.getSimprocs
|
||||
let symPrios ← getGlobalSymbolPriorities
|
||||
return { config, norm, normProcs, symPrios }
|
||||
return { config, norm, normProcs, symPrios, extensions }
|
||||
|
||||
def mkDefaultParams (config : Grind.Config) : MetaM Params := do
|
||||
mkParams config #[← getDefaultExtensionState]
|
||||
|
||||
def mkOnlyParams (config : Grind.Config) : MetaM Params := do
|
||||
mkParams config #[← getOnlyExtensionState]
|
||||
|
||||
def mkMethods (evalTactic? : Option EvalTactic := none) : CoreM Methods := do
|
||||
let builtinPropagators ← builtinPropagatorsRef.get
|
||||
@@ -99,9 +124,11 @@ def GrindM.run (x : GrindM α) (params : Params) (evalTactic? : Option EvalTacti
|
||||
let simp := params.norm
|
||||
let config := params.config
|
||||
let symPrios := params.symPrios
|
||||
let extensions := params.extensions
|
||||
let anchorRefs? := params.anchorRefs?
|
||||
let funCCs := params.funCCs
|
||||
x (← mkMethods evalTactic?).toMethodsRef { config, anchorRefs?, simpMethods, simp, trueExpr, falseExpr, natZExpr, btrueExpr, bfalseExpr, ordEqExpr, intExpr, symPrios, funCCs }
|
||||
x (← mkMethods evalTactic?).toMethodsRef
|
||||
{ config, anchorRefs?, simpMethods, simp, extensions, symPrios
|
||||
trueExpr, falseExpr, natZExpr, btrueExpr, bfalseExpr, ordEqExpr, intExpr }
|
||||
|>.run' { scState }
|
||||
|
||||
private def mkCleanState (mvarId : MVarId) (params : Params) : MetaM Clean.State := mvarId.withContext do
|
||||
@@ -136,11 +163,11 @@ private def mkGoal (mvarId : MVarId) (params : Params) : GrindM Goal := do
|
||||
let bfalseExpr ← getBoolFalseExpr
|
||||
let natZeroExpr ← getNatZeroExpr
|
||||
let ordEqExpr ← getOrderingEqExpr
|
||||
let thmMap := params.ematch
|
||||
let casesTypes := params.casesTypes
|
||||
let thmEMatch := params.extensions.map fun ext => ext.ematch
|
||||
let thmInj := params.extensions.map fun ext => ext.inj
|
||||
let clean ← mkCleanState mvarId params
|
||||
let sstates ← Solvers.mkInitialStates
|
||||
GoalM.run' { mvarId, ematch.thmMap := thmMap, inj.thms := params.inj, split.casesTypes := casesTypes, clean, sstates } do
|
||||
GoalM.run' { mvarId, ematch.thmMap := thmEMatch, inj.thms := thmInj, clean, sstates } do
|
||||
initENodeCore falseExpr (interpreted := true) (ctor := false)
|
||||
initENodeCore trueExpr (interpreted := true) (ctor := false)
|
||||
initENodeCore btrueExpr (interpreted := false) (ctor := true)
|
||||
|
||||
@@ -52,6 +52,9 @@ private def addBuiltin (propagatorName : Name) (stx : Syntax) : AttrM Unit := do
|
||||
declareBuiltin initDeclName val
|
||||
go.run' {}
|
||||
|
||||
/-
|
||||
**Note**: We currently use the same propagators for all `grind` attributes.
|
||||
-/
|
||||
builtin_initialize
|
||||
registerBuiltinAttribute {
|
||||
ref := by exact decl_name%
|
||||
|
||||
30
src/Lean/Meta/Tactic/Grind/RegisterCommand.lean
Normal file
30
src/Lean/Meta/Tactic/Grind/RegisterCommand.lean
Normal file
@@ -0,0 +1,30 @@
|
||||
/-
|
||||
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.Types
|
||||
meta import Lean.Meta.Tactic.Grind.Attr
|
||||
public section
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
macro (name := _root_.Lean.Parser.Command.registerGrindAttr) doc:(docComment)?
|
||||
"register_grind_attr" id:ident : command => do
|
||||
let str1 := id.getId.toString
|
||||
let idParser1 := mkIdentFrom id (`Lean.Parser.Attr ++ id.getId)
|
||||
let str2 := id.getId.toString ++ "!"
|
||||
let idParser2 := mkIdentFrom id (`Lean.Parser.Attr ++ (id.getId.appendAfter "!"))
|
||||
let str3 := id.getId.toString ++ "?"
|
||||
let idParser3 := mkIdentFrom id (`Lean.Parser.Attr ++ (id.getId.appendAfter "?"))
|
||||
let str4 := id.getId.toString ++ "!?"
|
||||
let idParser4 := mkIdentFrom id (`Lean.Parser.Attr ++ (id.getId.appendAfter "!?"))
|
||||
`($[$doc:docComment]? initialize ext : Extension ← registerAttr $(quote id.getId) (ref := $(quote id.getId))
|
||||
$[$doc:docComment]? syntax (name := $idParser1:ident) $(quote str1):str (ppSpace Lean.Parser.Attr.grindMod)? : attr
|
||||
$[$doc:docComment]? syntax (name := $idParser2:ident) $(quote str2):str (ppSpace Lean.Parser.Attr.grindMod)? : attr
|
||||
$[$doc:docComment]? syntax (name := $idParser3:ident) $(quote str3):str (ppSpace Lean.Parser.Attr.grindMod)? : attr
|
||||
$[$doc:docComment]? syntax (name := $idParser4:ident) $(quote str4):str (ppSpace Lean.Parser.Attr.grindMod)? : attr
|
||||
)
|
||||
|
||||
end Lean.Meta.Grind
|
||||
@@ -18,6 +18,9 @@ import Init.Grind.Norm
|
||||
public section
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
/-
|
||||
TODO: group into a `grind` extension object
|
||||
-/
|
||||
builtin_initialize normExt : SimpExtension ← mkSimpExt
|
||||
|
||||
def registerNormTheorems (preDeclNames : Array Name) (postDeclNames : Array Name) : MetaM Unit := do
|
||||
@@ -176,14 +179,18 @@ private def addDeclToUnfold (s : SimpTheorems) (declName : Name) : MetaM SimpThe
|
||||
else
|
||||
return s
|
||||
|
||||
/-- Returns the simplification context used by `grind`. -/
|
||||
protected def getSimpContext (config : Grind.Config) : MetaM Simp.Context := do
|
||||
def getNormTheorems : MetaM SimpTheorems := do
|
||||
let mut thms ← normExt.getTheorems
|
||||
thms ← addDeclToUnfold thms ``GE.ge
|
||||
thms ← addDeclToUnfold thms ``GT.gt
|
||||
thms ← addDeclToUnfold thms ``Nat.cast
|
||||
thms ← addDeclToUnfold thms ``Bool.xor
|
||||
thms ← addDeclToUnfold thms ``Ne
|
||||
return thms
|
||||
|
||||
/-- Returns the simplification context used by `grind`. -/
|
||||
protected def getSimpContext (config : Grind.Config) : MetaM Simp.Context := do
|
||||
let thms ← getNormTheorems
|
||||
Simp.mkContext
|
||||
(config :=
|
||||
{ arith := true
|
||||
|
||||
@@ -174,4 +174,36 @@ def getProofForDecl (declName : Name) : MetaM Expr := do
|
||||
let us := info.levelParams.map mkLevelParam
|
||||
return mkConst declName us
|
||||
|
||||
/--
|
||||
A `TheoremsArray α` is a collection of `Theorems α`.
|
||||
The array is scanned linear during theorem activation.
|
||||
This avoids the need for efficiently merging the `Theorems α` data structure.
|
||||
-/
|
||||
abbrev TheoremsArray (α : Type) := Array (Theorems α)
|
||||
|
||||
@[specialize]
|
||||
def TheoremsArray.retrieve? (s : TheoremsArray α) (sym : Name) : Option (List α × TheoremsArray α) := Id.run do
|
||||
for h : i in *...s.size do
|
||||
if let some (thms, a) ← s[i].retrieve? sym then
|
||||
return some (thms, s.set i a)
|
||||
return none
|
||||
|
||||
def TheoremsArray.insert [TheoremLike α] (s : TheoremsArray α) (thm : α) : TheoremsArray α := Id.run do
|
||||
if s.isEmpty then
|
||||
let thms := { : Theorems α}
|
||||
#[thms.insert thm]
|
||||
else
|
||||
s.modify 0 (·.insert thm)
|
||||
|
||||
def TheoremsArray.isErased (s : TheoremsArray α) (origin : Origin) : Bool :=
|
||||
s.any fun thms => thms.erased.contains origin
|
||||
|
||||
def TheoremsArray.find (s : TheoremsArray α) (origin : Origin) : List α := Id.run do
|
||||
let mut r := []
|
||||
for h : i in *...s.size do
|
||||
let thms := s[i].find origin
|
||||
unless thms.isEmpty do
|
||||
r := r ++ thms
|
||||
return r
|
||||
|
||||
end Lean.Meta.Grind
|
||||
|
||||
@@ -9,6 +9,7 @@ public import Lean.Meta.Tactic.Simp.Types
|
||||
public import Lean.Meta.Tactic.Grind.AlphaShareCommon
|
||||
public import Lean.Meta.Tactic.Grind.Attr
|
||||
public import Lean.Meta.Tactic.Grind.CheckResult
|
||||
public import Lean.Meta.Tactic.Grind.Extension
|
||||
public import Init.Data.Queue
|
||||
import Lean.Meta.Tactic.Grind.ExprPtr
|
||||
import Lean.HeadIndex
|
||||
@@ -158,8 +159,7 @@ structure Context where
|
||||
splitSource : SplitSource := .input
|
||||
/-- Symbol priorities for inferring E-matching patterns -/
|
||||
symPrios : SymbolPriorities
|
||||
/-- Global declarations marked with `@[grind funCC]` -/
|
||||
funCCs : NameSet
|
||||
extensions : ExtensionStateArray := #[]
|
||||
trueExpr : Expr
|
||||
falseExpr : Expr
|
||||
natZExpr : Expr
|
||||
@@ -346,6 +346,18 @@ def reportMVarInternalization : GrindM Bool :=
|
||||
def getSymbolPriorities : GrindM SymbolPriorities := do
|
||||
return (← readThe Context).symPrios
|
||||
|
||||
/--
|
||||
Returns `true` if we `declName` is tagged with `funCC` modifier.
|
||||
-/
|
||||
def hasFunCCModifier (declName : Name) : GrindM Bool :=
|
||||
return (← readThe Context).extensions.any fun ext => ext.funCC.contains declName
|
||||
|
||||
def isSplit (declName : Name) : GrindM Bool :=
|
||||
return (← readThe Context).extensions.any fun ext => ext.casesTypes.isSplit declName
|
||||
|
||||
def isEagerSplit (declName : Name) : GrindM Bool :=
|
||||
return (← readThe Context).extensions.any fun ext => ext.casesTypes.isEagerSplit declName
|
||||
|
||||
/--
|
||||
Returns `true` if `declName` is the name of a `match` equation or a `match` congruence equation.
|
||||
-/
|
||||
@@ -758,7 +770,7 @@ structure EMatch.State where
|
||||
Inactive global theorems. As we internalize terms, we activate theorems as we find their symbols.
|
||||
Local theorem provided by users are added directly into `newThms`.
|
||||
-/
|
||||
thmMap : EMatchTheorems
|
||||
thmMap : EMatchTheoremsArray
|
||||
/-- Goal modification time. -/
|
||||
gmt : Nat := 0
|
||||
/-- Active theorems that we have performed ematching at least once. -/
|
||||
@@ -840,8 +852,6 @@ structure SplitArg where
|
||||
structure Split.State where
|
||||
/-- Number of splits performed to get to this goal. -/
|
||||
num : Nat := 0
|
||||
/-- Inductive datatypes marked for case-splitting -/
|
||||
casesTypes : CasesTypes := {}
|
||||
/-- Case-split candidates. -/
|
||||
candidates : List SplitInfo := []
|
||||
/-- Case-splits that have been inserted at `candidates` at some point. -/
|
||||
@@ -901,7 +911,7 @@ structure InjectiveInfo where
|
||||
|
||||
/-- State for injective theorem support. -/
|
||||
structure Injective.State where
|
||||
thms : InjectiveTheorems
|
||||
thms : InjectiveTheoremsArray
|
||||
fns : PHashMap ExprPtr InjectiveInfo := {}
|
||||
deriving Inhabited
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ def grindDischarger (mvarId : MVarId) : MetaM (Option (List MVarId)) := do
|
||||
let [subgoal] ← mvarId.apply markerExpr
|
||||
| return none
|
||||
-- Solve the subgoal with grind
|
||||
let params ← Grind.mkParams {}
|
||||
let params ← Grind.mkDefaultParams {}
|
||||
let result ← Grind.main subgoal params
|
||||
if result.hasFailed then
|
||||
return none
|
||||
|
||||
@@ -126,7 +126,7 @@ def checkInductive (localDecl : LocalDecl) : M Unit := do
|
||||
let .const declName _ := type.getAppFn | return ()
|
||||
let .inductInfo val ← getConstInfo declName | return ()
|
||||
if (← isEligible declName) then
|
||||
unless (← Grind.isSplit declName) do
|
||||
unless (← Grind.isGlobalSplit declName) do
|
||||
modify fun s => { s with indCandidates := s.indCandidates.push { fvarId := localDecl.fvarId, val } }
|
||||
|
||||
unsafe abbrev Cache := PtrSet Expr
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// update me!
|
||||
#include "util/options.h"
|
||||
|
||||
namespace lean {
|
||||
|
||||
@@ -3,7 +3,7 @@ import Lean
|
||||
open Lean Meta Tactic Grind
|
||||
|
||||
def runGrind (x : GrindM α) : MetaM α := do
|
||||
GrindM.run x (← mkParams {})
|
||||
GrindM.run x (← mkDefaultParams {})
|
||||
|
||||
@[noinline] def mkA (x : Nat) := x + 1
|
||||
|
||||
|
||||
@@ -31,3 +31,5 @@ initialize registerBuiltinAttribute {
|
||||
logInfo m!"trace_add attribute added to {decl}"
|
||||
-- applicationTime := .afterCompilation
|
||||
}
|
||||
|
||||
register_grind_attr my_grind
|
||||
|
||||
@@ -155,3 +155,22 @@ termination_by n => n
|
||||
end
|
||||
|
||||
end TraceAdd
|
||||
|
||||
namespace GrindAttr
|
||||
|
||||
opaque f : Nat → Nat
|
||||
opaque g : Nat → Nat
|
||||
|
||||
@[my_grind] theorem fax : f (f x) = f x := sorry
|
||||
|
||||
@[my_grind =] theorem fax2 : f (f (f x)) = f x := by
|
||||
fail_if_success grind
|
||||
grind [my_grind]
|
||||
|
||||
@[my_grind? .] theorem fg : g (f x) = x := sorry
|
||||
|
||||
@[my_grind? =] theorem fax3 : f (f (f x)) = f x := sorry
|
||||
|
||||
@[my_grind!? .] theorem fax4 : f (f (f x)) = f x := sorry
|
||||
|
||||
end GrindAttr
|
||||
|
||||
Reference in New Issue
Block a user