Compare commits

...

4 Commits

Author SHA1 Message Date
Leonardo de Moura
488c2c409e chore: grind? skeleton 2025-01-26 17:08:19 -08:00
Leonardo de Moura
8a6b9c5e4f feat elaborate grind modifier usr 2025-01-26 17:01:54 -08:00
Leonardo de Moura
346efc1cf2 feat: store EMatchTheorem kind
We also added a new attribute to specify that E-matching theorem was
created using `grind_pattern` command
2025-01-26 15:59:12 -08:00
Leonardo de Moura
5012a24f58 chore: add grind? parser 2025-01-26 15:30:22 -08:00
5 changed files with 186 additions and 52 deletions

View File

@@ -14,10 +14,11 @@ syntax grindEqRhs := atomic("=" "_")
syntax grindEqBwd := atomic("" "=")
syntax grindBwd := ""
syntax grindFwd := ""
syntax grindUsr := &"usr"
syntax grindCases := &"cases"
syntax grindCasesEager := atomic(&"cases" &"eager")
syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindCasesEager <|> grindCases
syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindUsr <|> grindCasesEager <|> grindCases
syntax (name := grind) "grind" (grindMod)? : attr
@@ -75,4 +76,10 @@ syntax (name := grind)
(" [" withoutPosition(grindParam,*) "]")?
("on_failure " term)? : tactic
syntax (name := grindTrace)
"grind?" optConfig (&" only")?
(" [" withoutPosition(grindParam,*) "]")?
("on_failure " term)? : tactic
end Lean.Parser.Tactic

View File

@@ -31,7 +31,7 @@ def elabGrindPattern : CommandElab := fun stx => do
let pattern instantiateMVars pattern
let pattern Grind.preprocessPattern pattern
return pattern.abstract xs
Grind.addEMatchTheorem declName xs.size patterns.toList
Grind.addEMatchTheorem declName xs.size patterns.toList .user
| _ => throwUnsupportedSyntax
open Command Term in
@@ -45,7 +45,7 @@ def elabInitGrindNorm : CommandElab := fun stx =>
Grind.registerNormTheorems pre post
| _ => throwUnsupportedSyntax
def elabGrindParams (params : Grind.Params) (ps : TSyntaxArray ``Parser.Tactic.grindParam) : MetaM Grind.Params := do
def elabGrindParams (params : Grind.Params) (ps : TSyntaxArray ``Parser.Tactic.grindParam) (only : Bool) : MetaM Grind.Params := do
let mut params := params
for p in ps do
match p with
@@ -59,6 +59,16 @@ def elabGrindParams (params : Grind.Params) (ps : TSyntaxArray ``Parser.Tactic.
let declName realizeGlobalConstNoOverloadWithInfo id
let kind if let some mod := mod? then Grind.getAttrKindCore mod else pure .infer
match kind with
| .ematch .user =>
unless only do
withRef p <| Grind.throwInvalidUsrModifier
let s Grind.getEMatchTheorems
let thms := s.find (.decl declName)
let thms := thms.filter fun thm => thm.kind == .user
if thms.isEmpty then
throwErrorAt p "invalid use of `usr` modifier, `{declName}` does not have patterns specified with the command `grind_pattern`"
for thm in thms do
params := { params with extra := params.extra.push thm }
| .ematch kind =>
params withRef p <| addEMatchTheorem params declName kind
| .cases eager =>
@@ -97,7 +107,7 @@ def mkGrindParams (config : Grind.Config) (only : Bool) (ps : TSyntaxArray ``Pa
let ematch if only then pure {} else Grind.getEMatchTheorems
let casesTypes if only then pure {} else Grind.getCasesTypes
let params := { params with ematch, casesTypes }
elabGrindParams params ps
elabGrindParams params ps only
def grind
(mvarId : MVarId) (config : Grind.Config)
@@ -126,16 +136,32 @@ private def elabFallback (fallback? : Option Term) : TermElabM (Grind.GoalM Unit
pure auxDeclName
unsafe evalConst (Grind.GoalM Unit) auxDeclName
private def evalGrindCore
(ref : Syntax)
(config : TSyntax `Lean.Parser.Tactic.optConfig)
(only : Option Syntax)
(params : Option (Syntax.TSepArray `Lean.Parser.Tactic.grindParam ","))
(fallback? : Option Term)
(_trace : Bool) -- TODO
: TacticM Unit := do
let fallback elabFallback fallback?
let only := only.isSome
let params := if let some params := params then params.getElems else #[]
logWarningAt ref "The `grind` tactic is experimental and still under development. Avoid using it in production projects"
let declName := ( Term.getDeclName?).getD `_grind
let config elabGrindConfig config
withMainContext do liftMetaFinishingTactic (grind · config only params declName fallback)
@[builtin_tactic Lean.Parser.Tactic.grind] def evalGrind : Tactic := fun stx => do
match stx with
| `(tactic| grind $config:optConfig $[only%$only]? $[ [$params:grindParam,*] ]? $[on_failure $fallback?]?) =>
let fallback elabFallback fallback?
let only := only.isSome
let params := if let some params := params then params.getElems else #[]
logWarningAt stx "The `grind` tactic is experimental and still under development. Avoid using it in production projects"
let declName := ( Term.getDeclName?).getD `_grind
let config elabGrindConfig config
withMainContext do liftMetaFinishingTactic (grind · config only params declName fallback)
evalGrindCore stx config only params fallback? false
| _ => throwUnsupportedSyntax
@[builtin_tactic Lean.Parser.Tactic.grindTrace] def evalGrindTrace : Tactic := fun stx => do
match stx with
| `(tactic| grind? $config:optConfig $[only%$only]? $[ [$params:grindParam,*] ]? $[on_failure $fallback?]?) =>
evalGrindCore stx config only params fallback? true
| _ => throwUnsupportedSyntax
end Lean.Elab.Tactic

View File

@@ -23,6 +23,7 @@ def getAttrKindCore (stx : Syntax) : CoreM AttrKind := do
| `(Parser.Attr.grindMod| =_) => return .ematch .eqRhs
| `(Parser.Attr.grindMod| _=_) => return .ematch .eqBoth
| `(Parser.Attr.grindMod| =) => return .ematch .eqBwd
| `(Parser.Attr.grindMod| usr) => return .ematch .user
| `(Parser.Attr.grindMod| cases) => return .cases false
| `(Parser.Attr.grindMod| cases eager) => return .cases true
| _ => throwError "unexpected `grind` theorem kind: `{stx}`"
@@ -34,6 +35,9 @@ def getAttrKindFromOpt (stx : Syntax) : CoreM AttrKind := do
else
getAttrKindCore stx[1][0]
def throwInvalidUsrModifier : CoreM α :=
throwError "the modifier `usr` is only relevant in parameters for `grind only`"
builtin_initialize
registerBuiltinAttribute {
name := `grind
@@ -57,6 +61,7 @@ builtin_initialize
applicationTime := .afterCompilation
add := fun declName stx attrKind => MetaM.run' do
match ( getAttrKindFromOpt stx) with
| .ematch .user => throwInvalidUsrModifier
| .ematch k => addEMatchAttr declName attrKind k
| .cases eager => addCasesAttr declName eager attrKind
| .infer =>

View File

@@ -92,6 +92,30 @@ instance : BEq Origin where
instance : Hashable Origin where
hash a := hash a.key
inductive TheoremKind where
| eqLhs | eqRhs | eqBoth | eqBwd | fwd | bwd | default | user /- pattern specified using `grind_pattern` command -/
deriving Inhabited, BEq, Repr
private def TheoremKind.toAttribute : TheoremKind String
| .eqLhs => "[grind =]"
| .eqRhs => "[grind =_]"
| .eqBoth => "[grind _=_]"
| .eqBwd => "[grind ←=]"
| .fwd => "[grind →]"
| .bwd => "[grind ←]"
| .default => "[grind]"
| .user => "[grind]"
private def TheoremKind.explainFailure : TheoremKind String
| .eqLhs => "failed to find pattern in the left-hand side of the theorem's conclusion"
| .eqRhs => "failed to find pattern in the right-hand side of the theorem's conclusion"
| .eqBoth => unreachable! -- eqBoth is a macro
| .eqBwd => "failed to use theorem's conclusion as a pattern"
| .fwd => "failed to find patterns in the antecedents of the theorem"
| .bwd => "failed to find patterns in the theorem's conclusion"
| .default => "failed to find patterns"
| .user => unreachable!
/-- A theorem for heuristic instantiation based on E-matching. -/
structure EMatchTheorem where
/--
@@ -106,16 +130,20 @@ structure EMatchTheorem where
/-- Contains all symbols used in `pattterns`. -/
symbols : List HeadIndex
origin : Origin
/-- The `kind` is used for generating the `patterns`. We save it here to implement `grind?`. -/
kind : TheoremKind
deriving Inhabited
/-- Set of E-matching theorems. -/
structure EMatchTheorems where
/-- The key is a symbol from `EMatchTheorem.symbols`. -/
private map : PHashMap Name (List EMatchTheorem) := {}
private smap : PHashMap Name (List EMatchTheorem) := {}
/-- Set of theorem ids that have been inserted using `insert`. -/
private origins : PHashSet Origin := {}
/-- Theorems that have been marked as erased -/
private erased : PHashSet Origin := {}
/-- Mapping from origin to E-matching theorems associated with this origin. -/
private omap : PHashMap Origin (List EMatchTheorem) := {}
deriving Inhabited
/--
@@ -130,13 +158,19 @@ def EMatchTheorems.insert (s : EMatchTheorems) (thm : EMatchTheorem) : EMatchThe
let .const declName :: syms := thm.symbols
| unreachable!
let thm := { thm with symbols := syms }
let { map, origins, erased } := s
let origins := origins.insert thm.origin
let erased := erased.erase thm.origin
if let some thms := map.find? declName then
return { map := map.insert declName (thm::thms), origins, erased }
let { smap, origins, erased, omap } := s
let origin := thm.origin
let origins := origins.insert origin
let erased := erased.erase origin
let smap := if let some thms := smap.find? declName then
smap.insert declName (thm::thms)
else
return { map := map.insert declName [thm], origins, erased }
smap.insert declName [thm]
let omap := if let some thms := omap.find? origin then
omap.insert origin (thm::thms)
else
omap.insert origin [thm]
return { smap, origins, erased, omap }
/-- Returns `true` if `s` contains a theorem with the given origin. -/
def EMatchTheorems.contains (s : EMatchTheorems) (origin : Origin) : Bool :=
@@ -156,11 +190,20 @@ The theorems are removed from `s`.
-/
@[inline]
def EMatchTheorems.retrieve? (s : EMatchTheorems) (sym : Name) : Option (List EMatchTheorem × EMatchTheorems) :=
if let some thms := s.map.find? sym then
some (thms, { s with map := s.map.erase sym })
if let some thms := s.smap.find? sym then
some (thms, { s with smap := s.smap.erase sym })
else
none
/--
Returns theorems associated with the given origin.
-/
def EMatchTheorems.find (s : EMatchTheorems) (origin : Origin) : List EMatchTheorem :=
if let some thms := s.omap.find? origin then
thms
else
[]
def EMatchTheorem.getProofWithFreshMVarLevels (thm : EMatchTheorem) : MetaM Expr := do
if thm.proof.isConst && thm.levelParams.isEmpty then
let declName := thm.proof.constName!
@@ -491,7 +534,7 @@ private def ppParamsAt (proof : Expr) (numParams : Nat) (paramPos : List Nat) :
Creates an E-matching theorem for a theorem with proof `proof`, `numParams` parameters, and the given set of patterns.
Pattern variables are represented using de Bruijn indices.
-/
def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams : Nat) (proof : Expr) (patterns : List Expr) : MetaM EMatchTheorem := do
def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams : Nat) (proof : Expr) (patterns : List Expr) (kind : TheoremKind): MetaM EMatchTheorem := do
let (patterns, symbols, bvarFound) NormalizePattern.main patterns
if symbols.isEmpty then
throwError "invalid pattern for `{← origin.pp}`{indentD (patterns.map ppPattern)}\nthe pattern does not contain constant symbols for indexing"
@@ -501,7 +544,7 @@ def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams
throwError "invalid pattern(s) for `{← origin.pp}`{indentD pats}\nthe following theorem parameters cannot be instantiated:{indentD (← ppParamsAt proof numParams pos)}"
return {
proof, patterns, numParams, symbols
levelParams, origin
levelParams, origin, kind
}
private def getProofFor (declName : Name) : CoreM Expr := do
@@ -514,8 +557,8 @@ private def getProofFor (declName : Name) : CoreM Expr := do
Creates an E-matching theorem for `declName` with `numParams` parameters, and the given set of patterns.
Pattern variables are represented using de Bruijn indices.
-/
def mkEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr) : MetaM EMatchTheorem := do
mkEMatchTheoremCore (.decl declName) #[] numParams ( getProofFor declName) patterns
def mkEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr) (kind : TheoremKind) : MetaM EMatchTheorem := do
mkEMatchTheoremCore (.decl declName) #[] numParams ( getProofFor declName) patterns kind
/--
Given a theorem with proof `proof` and type of the form `∀ (a_1 ... a_n), lhs = rhs`,
@@ -535,15 +578,15 @@ def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof :
trace[grind.debug.ematch.pattern] "mkEMatchEqTheoremCore: after preprocessing: {pat}, {← normalize pat}"
let pats := splitWhileForbidden (pat.abstract xs)
return (xs.size, pats)
mkEMatchTheoremCore origin levelParams numParams proof patterns
mkEMatchTheoremCore origin levelParams numParams proof patterns (if useLhs then .eqLhs else .eqRhs)
def mkEMatchEqBwdTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) : MetaM EMatchTheorem := do
let (numParams, patterns) forallTelescopeReducing ( inferType proof) fun xs type => do
let_expr f@Eq α lhs rhs := type
| throwError "invalid E-matching `` theorem, conclusion must be an equality{indentExpr type}"
| throwError "invalid E-matching `←=` theorem, conclusion must be an equality{indentExpr type}"
let pat preprocessPattern (mkEqBwdPattern f.constLevels! α lhs rhs)
return (xs.size, [pat.abstract xs])
mkEMatchTheoremCore origin levelParams numParams proof patterns
mkEMatchTheoremCore origin levelParams numParams proof patterns .eqBwd
/--
Given theorem with name `declName` and type of the form `∀ (a_1 ... a_n), lhs = rhs`,
@@ -559,8 +602,8 @@ def mkEMatchEqTheorem (declName : Name) (normalizePattern := true) (useLhs : Boo
Adds an E-matching theorem to the environment.
See `mkEMatchTheorem`.
-/
def addEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr) : MetaM Unit := do
ematchTheoremsExt.add ( mkEMatchTheorem declName numParams patterns)
def addEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr) (kind : TheoremKind) : MetaM Unit := do
ematchTheoremsExt.add ( mkEMatchTheorem declName numParams patterns kind)
/--
Adds an E-matching equality theorem to the environment.
@@ -573,28 +616,6 @@ def addEMatchEqTheorem (declName : Name) : MetaM Unit := do
def getEMatchTheorems : CoreM EMatchTheorems :=
return ematchTheoremsExt.getState ( getEnv)
inductive TheoremKind where
| eqLhs | eqRhs | eqBoth | eqBwd | fwd | bwd | default
deriving Inhabited, BEq, Repr
private def TheoremKind.toAttribute : TheoremKind String
| .eqLhs => "[grind =]"
| .eqRhs => "[grind =_]"
| .eqBoth => "[grind _=_]"
| .eqBwd => "[grind ←=]"
| .fwd => "[grind →]"
| .bwd => "[grind ←]"
| .default => "[grind]"
private def TheoremKind.explainFailure : TheoremKind String
| .eqLhs => "failed to find pattern in the left-hand side of the theorem's conclusion"
| .eqRhs => "failed to find pattern in the right-hand side of the theorem's conclusion"
| .eqBoth => unreachable! -- eqBoth is a macro
| .eqBwd => "failed to use theorem's conclusion as a pattern"
| .fwd => "failed to find patterns in the antecedents of the theorem"
| .bwd => "failed to find patterns in the theorem's conclusion"
| .default => "failed to find patterns"
/-- Returns the types of `xs` that are propositions. -/
private def getPropTypes (xs : Array Expr) : MetaM (Array Expr) :=
xs.filterMapM fun x => do
@@ -702,7 +723,7 @@ where
trace[grind.ematch.pattern] "{← origin.pp}: {patterns.map ppPattern}"
return some {
proof, patterns, numParams, symbols
levelParams, origin
levelParams, origin, kind
}
def mkEMatchTheoremForDecl (declName : Name) (thmKind : TheoremKind) : MetaM EMatchTheorem := do

View File

@@ -0,0 +1,75 @@
opaque f : Nat Nat
/--
error: the modifier `usr` is only relevant in parameters for `grind only`
-/
#guard_msgs (error) in
@[grind usr]
theorem fthm : f (f x) = f x := sorry
/--
info: [grind.ematch.pattern] fthm: [f (f #0)]
-/
#guard_msgs (info) in
set_option trace.grind.ematch.pattern true in
example : f (f (f x)) = f x := by
grind only [fthm]
/--
info: [grind.ematch.instance] fthm: f (f (f x)) = f (f x)
[grind.ematch.instance] fthm: f (f x) = f x
-/
#guard_msgs (info) in
set_option trace.grind.ematch.instance true in
example : f (f (f x)) = f x := by
grind only [fthm]
#guard_msgs (info) in
-- should not instantiate anything using pattern `f (f #0)`
set_option trace.grind.ematch.instance true in
example : f x = x := by
fail_if_success grind only [fthm]
sorry
/--
error: the modifier `usr` is only relevant in parameters for `grind only`
-/
#guard_msgs (error) in
example : f (f (f x)) = f x := by
grind [usr fthm]
/--
error: invalid use of `usr` modifier, `fthm` does not have patterns specified with the command `grind_pattern`
-/
#guard_msgs (error) in
example : f (f (f x)) = f x := by
grind only [usr fthm]
grind_pattern fthm => f x
example : f (f (f x)) = f x := by
grind only [usr fthm]
#guard_msgs (info) in
-- should not instantiate anything using pattern `f (f #0)`
set_option trace.grind.ematch.instance true in
example : f x = x := by
fail_if_success grind only [fthm]
sorry
/--
info: [grind.ematch.instance] fthm: f (f x) = f x
[grind.ematch.instance] fthm: f (f (f x)) = f (f x)
-/
#guard_msgs (info) in
set_option trace.grind.ematch.instance true in
example : f x = x := by
fail_if_success grind only [usr fthm]
sorry
/--
error: the modifier `usr` is only relevant in parameters for `grind only`
-/
#guard_msgs (error) in
example : f (f (f x)) = f x := by
grind [usr fthm]