Compare commits

...

1 Commits

Author SHA1 Message Date
Leonardo de Moura
707ce7ac12 feat: add [grind norm] and [grind unfold] attributes
This PR adds the attributes `[grind norm]` and `[grind unfold]` for
controlling the `grind` normalizer/preprocessor.

The `norm` modifier instructs `grind` to use a theorem as a normalization rule. That is,
the theorem is applied during the preprocessing step.
This feature is meant for advanced users who understand how the preprocessor and `grind`'s search
procedure interact with each other.
New users can still benefit from this feature by restricting its use to theorems that completely
eliminate a symbol from the goal. Example:
```lean
theorem max_def : max n m = if n ≤ m then m else n
```
For a negative example, consider:
```lean
opaque f : Int → Int → Int → Int
theorem fax1 : f x 0 1 = 1 := sorry
theorem fax2 : f 1 x 1 = 1 := sorry
attribute [grind norm] fax1
attribute [grind =] fax2

example (h : c = 1) : f c 0 c = 1 := by
  grind -- fails
```
In this example, `fax1` is a normalization rule, but it is not applicable to the input goal since
`f c 0 c` is not an instance of `f x 0 1`. However, `f c 0 c` matches the pattern `f 1 x 1` modulo
the equality `c = 1`. Thus, `grind` instantiates `fax2` with `x := 0`, producing the equality
`f 1 0 1 = 1`, which the normalizer simplifies to `True`. As a result, nothing useful is learned.
In the future, we plan to include linters to automatically detect issues like these.
Example:
```lean
opaque f : Nat → Nat
opaque g : Nat → Nat

@[grind norm] axiom fax : f x = x + 2
@[grind norm ←] axiom fg : f x = g x

example : f x ≥ 2 := by grind
example : f x ≥ g x := by grind
example : f x + g x ≥ 4 := by grind
```

The `unfold` modifier instructs `grind` to unfold the given definition during the preprocessing step.
Example:
```lean
@[grind unfold] def h (x : Nat) := 2 * x
example : 6 ∣ 3*h x := by grind
```
2025-12-22 19:36:35 -08:00
7 changed files with 112 additions and 27 deletions

View File

@@ -198,6 +198,55 @@ Given an application `f a₁ a₂ … aₙ`, when `funCC := true`,
-/
syntax grindFunCC := &"funCC"
/--
The `norm` modifier instructs `grind` to use a theorem as a normalization rule. That is,
the theorem is applied during the preprocessing step.
This feature is meant for advanced users who understand how the preprocessor and `grind`'s search
procedure interact with each other.
New users can still benefit from this feature by restricting its use to theorems that completely
eliminate a symbol from the goal. Example:
```
theorem max_def : max n m = if n ≤ m then m else n
```
For a negative example, consider:
```
opaque f : Int → Int → Int → Int
theorem fax1 : f x 0 1 = 1 := sorry
theorem fax2 : f 1 x 1 = 1 := sorry
attribute [grind norm] fax1
attribute [grind =] fax2
example (h : c = 1) : f c 0 c = 1 := by
grind -- fails
```
In this example, `fax1` is a normalization rule, but it is not applicable to the input goal since
`f c 0 c` is not an instance of `f x 0 1`. However, `f c 0 c` matches the pattern `f 1 x 1` modulo
the equality `c = 1`. Thus, `grind` instantiates `fax2` with `x := 0`, producing the equality
`f 1 0 1 = 1`, which the normalizer simplifies to `True`. As a result, nothing useful is learned.
In the future, we plan to include linters to automatically detect issues like these.
Example:
```
opaque f : Nat → Nat
opaque g : Nat → Nat
@[grind norm] axiom fax : f x = x + 2
@[grind norm ←] axiom fg : f x = g x
example : f x ≥ 2 := by grind
example : f x ≥ g x := by grind
example : f x + g x ≥ 4 := by grind
```
-/
syntax grindNorm := &"norm" (Tactic.simpPre <|> Tactic.simpPost)? patternIgnore("" <|> "<- ")?
/--
The `unfold` modifier instructs `grind` to unfold the given definition during the preprocessing step.
Example:
```
@[grind unfold] def h (x : Nat) := 2 * x
example : 6 3*h x := by grind
```
-/
syntax grindUnfold := &"unfold"
/--
`symbol <prio>` sets the priority of a constant for `grind`s pattern-selection
procedure. `grind` prefers patterns that contain higher-priority symbols.
Example:
@@ -224,7 +273,7 @@ syntax grindMod :=
grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd
<|> grindFwd <|> grindRL <|> grindLR <|> grindUsr <|> grindCasesEager
<|> grindCases <|> grindIntro <|> grindExt <|> grindGen <|> grindSym <|> grindInj
<|> grindFunCC <|> grindDef
<|> grindFunCC <|> grindNorm <|> grindUnfold <|> grindDef
/--
Marks a theorem or definition for use by the `grind` tactic.

View File

@@ -245,7 +245,7 @@ where
elabEMatchTheorem declName (.default false) minIndexable
else
return thms.toArray
| .cases _ | .intro | .inj | .ext | .symbol _ | .funCC =>
| .cases _ | .intro | .inj | .ext | .symbol _ | .funCC | .norm .. | .unfold =>
throwError "invalid modifier"
def logAnchor (c : SplitInfo) : TermElabM Unit := do

View File

@@ -151,7 +151,7 @@ def processTermParam (params : Grind.Params)
checkNoRevert params
let kind if let some mod := mod? then Grind.getAttrKindCore mod else pure .infer
let kind match kind with
| .ematch .user | .cases _ | .intro | .inj | .ext | .symbol _ | .funCC =>
| .ematch .user | .cases _ | .intro | .inj | .ext | .symbol _ | .funCC | .norm .. | .unfold =>
throwError "invalid `grind` parameter, only global declarations are allowed with this kind of modifier"
| .ematch kind => pure kind
| .infer => pure <| .default false
@@ -266,6 +266,8 @@ def processParam (params : Grind.Params)
params := { params with symPrios := params.symPrios.insert declName prio }
| .funCC =>
params := params.insertFunCC declName
| .norm .. => throwError "normalization theorems should be registered using the `@[grind norm]` attribute"
| .unfold => throwError "declarations to be unfolded during normalization should be registered using the `@[grind unfold]` attribute"
return params
/--

View File

@@ -8,10 +8,13 @@ prelude
public import Lean.Meta.Tactic.Grind.Injective
public import Lean.Meta.Tactic.Grind.Cases
public import Lean.Meta.Tactic.Grind.ExtAttr
public import Lean.Meta.Tactic.Simp.Attr
import Lean.ExtraModUses
public section
namespace Lean.Meta.Grind
builtin_initialize normExt : SimpExtension mkSimpExt
inductive AttrKind where
| ematch (k : EMatchTheoremKind)
| cases (eager : Bool)
@@ -21,6 +24,8 @@ inductive AttrKind where
| symbol (prio : Nat)
| inj
| funCC
| norm (post : Bool) (inv : Bool)
| unfold
/-- Return theorem kind for `stx` of the form `Attr.grindThmMod` -/
def getAttrKindCore (stx : Syntax) : CoreM AttrKind := do
@@ -47,6 +52,13 @@ def getAttrKindCore (stx : Syntax) : CoreM AttrKind := do
| `(Parser.Attr.grindMod|ext) => return .ext
| `(Parser.Attr.grindMod|inj) => return .inj
| `(Parser.Attr.grindMod|funCC) => return .funCC
| `(Parser.Attr.grindMod|norm) => return .norm true false
| `(Parser.Attr.grindMod|norm ) => return .norm true false
| `(Parser.Attr.grindMod|norm ) => return .norm (post := false) false
| `(Parser.Attr.grindMod|norm ) => return .norm true true
| `(Parser.Attr.grindMod|norm ) => return .norm true true
| `(Parser.Attr.grindMod|norm ) => return .norm (post := false) true
| `(Parser.Attr.grindMod|unfold) => return .unfold
| `(Parser.Attr.grindMod|symbol $prio:prio) =>
let some prio := prio.raw.isNatLit? | throwErrorAt prio "priority expected"
return .symbol prio
@@ -158,6 +170,15 @@ private def mkGrindAttr (attrName : Name) (minIndexable : Bool) (showInfo : Bool
unless attrName == `grind do
throwError "symbol priorities must be set using the default `[grind]` attribute"
addSymbolPriorityAttr declName attrKind prio
| .norm post inv =>
unless attrName == `grind do
throwError "normalizer must be set using the default `[grind]` attribute"
addSimpTheorem normExt declName (post := post) (inv := inv) attrKind (eval_prio default)
| .unfold =>
unless attrName == `grind do
throwError "declaration to unfold must be set using the default `[grind]` attribute"
unless ( addDeclToUnfold normExt declName (post := false) (inv := false) (prio := eval_prio default) (attrKind := attrKind)) do
throwError "cannot mark declaration to be unfolded by `grind`"
| .cases eager => ext.addCasesAttr declName eager attrKind
| .funCC => ext.addFunCCAttr declName attrKind
| .ext => ext.addExtAttr declName attrKind

View File

@@ -18,11 +18,6 @@ import Init.Grind.Norm
public section
namespace Lean.Meta.Grind
/-
TODO: group into a `grind` extension object
-/
builtin_initialize normExt : SimpExtension mkSimpExt
def registerNormTheorems (preDeclNames : Array Name) (postDeclNames : Array Name) : MetaM Unit := do
let thms normExt.getTheorems
unless thms.lemmaNames.isEmpty do

View File

@@ -4,15 +4,35 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Tactic.Simp.Simproc
public section
namespace Lean.Meta
open Simp
/--
Marks `declName` to be unfolded in the given `SimpExtension`.
-/
def addDeclToUnfold (ext : SimpExtension) (declName : Name) (post inv : Bool) (prio : Nat) (attrKind : AttributeKind) : MetaM Bool := do
if getOriginalConstKind? ( getEnv) declName == some .defn then
if inv then
throwError m!"Invalid `←` modifier: `{.ofConstName declName}` is a declaration name to be unfolded"
++ .note m!"The simplifier will automatically unfold definitions marked with the `[simp]` \
attribute, but it will not \"refold\" them"
if ( Simp.ignoreEquations declName) then
ext.add (SimpEntry.toUnfold declName) attrKind
else if let some eqns getEqnsFor? declName then
for eqn in eqns do
addSimpTheorem ext eqn post (inv := false) attrKind prio
ext.add (SimpEntry.toUnfoldThms declName eqns) attrKind
if ( Simp.unfoldEvenWithEqns declName) then
ext.add (SimpEntry.toUnfold declName) attrKind
else
ext.add (SimpEntry.toUnfold declName) attrKind
return true
else
return false
def mkSimpAttr (attrName : Name) (attrDescr : String) (ext : SimpExtension)
(ref : Name := by exact decl_name%) : IO Unit :=
registerBuiltinAttribute {
@@ -32,22 +52,7 @@ def mkSimpAttr (attrName : Name) (attrDescr : String) (ext : SimpExtension)
let prio getAttrParamOptPrio stx[3]
if ( isProp info.sig.get.type) then
addSimpTheorem ext declName post (inv := inv) attrKind prio
else if getOriginalConstKind? ( getEnv) declName == some .defn then
if inv then
throwError m!"Invalid `←` modifier: `{.ofConstName declName}` is a declaration name to be unfolded"
++ .note m!"The simplifier will automatically unfold definitions marked with the `[simp]` \
attribute, but it will not \"refold\" them"
if ( Simp.ignoreEquations declName) then
ext.add (SimpEntry.toUnfold declName) attrKind
else if let some eqns getEqnsFor? declName then
for eqn in eqns do
addSimpTheorem ext eqn post (inv := false) attrKind prio
ext.add (SimpEntry.toUnfoldThms declName eqns) attrKind
if ( Simp.unfoldEvenWithEqns declName) then
ext.add (SimpEntry.toUnfold declName) attrKind
else
ext.add (SimpEntry.toUnfold declName) attrKind
else
else unless ( addDeclToUnfold ext declName post inv prio attrKind) do
throwError m!"Cannot add `simp` attribute to `{.ofConstName declName}`: It is not a proposition nor a definition (to unfold)"
++ .note m!"The `[simp]` attribute can be added to lemmas that should be automatically used by the simplifier \
and to definitions that the simplifier should automatically unfold"

View File

@@ -0,0 +1,13 @@
opaque f : Nat Nat
opaque g : Nat Nat
@[grind norm] axiom fax : f x = x + 2
@[grind norm ] axiom fg : f x = g x
example : f x 2 := by grind
example : f x g x := by grind
example : f x + g x 4 := by grind
@[grind unfold] def h (x : Nat) := 2 * x
example : 2 h x := by grind