Compare commits

...

1 Commits

Author SHA1 Message Date
Leonardo de Moura
011ddbab0b feat: pattern normalization in the grind tactic
This PR ensures patterns provided by users are normalized. See new
test to understand why this is needed.
2025-01-05 11:18:57 -08:00
6 changed files with 78 additions and 17 deletions

View File

@@ -37,9 +37,10 @@ def isOffsetPattern? (pat : Expr) : Option (Expr × Nat) := Id.run do
let .lit (.natVal k) := k | none
return some (pat, k)
def preprocessPattern (pat : Expr) : MetaM Expr := do
def preprocessPattern (pat : Expr) (normalizePattern := true) : MetaM Expr := do
let pat instantiateMVars pat
let pat unfoldReducible pat
let pat if normalizePattern then normalize pat else pure pat
let pat detectOffsets pat
let pat foldProjs pat
return pat
@@ -424,12 +425,15 @@ def mkEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr) :
/--
Given theorem with name `declName` and type of the form `∀ (a_1 ... a_n), lhs = rhs`,
creates an E-matching pattern for it using `addEMatchTheorem n [lhs]`
If `normalizePattern` is true, it applies the `grind` simplification theorems and simprocs to the
pattern.
-/
def mkEMatchEqTheorem (declName : Name) : MetaM EMatchTheorem := do
def mkEMatchEqTheorem (declName : Name) (normalizePattern := true) : MetaM EMatchTheorem := do
let info getConstInfo declName
let (numParams, patterns) forallTelescopeReducing info.type fun xs type => do
let_expr Eq _ lhs _ := type | throwError "invalid E-matching equality theorem, conclusion must be an equality{indentExpr type}"
let lhs preprocessPattern lhs
let lhs preprocessPattern lhs normalizePattern
return (xs.size, [lhs.abstract xs])
mkEMatchTheorem declName numParams patterns

View File

@@ -71,7 +71,8 @@ private partial def addMatchEqns (f : Expr) (generation : Nat) : GoalM Unit := d
if ( get).matchEqNames.contains declName then return ()
modify fun s => { s with matchEqNames := s.matchEqNames.insert declName }
for eqn in ( Match.getEquationsFor declName).eqnNames do
activateTheorem ( mkEMatchEqTheorem eqn) generation
-- We disable pattern normalization to prevent the `match`-expression to be reduced.
activateTheorem ( mkEMatchEqTheorem eqn (normalizePattern := false)) generation
private partial def activateTheoremPatterns (fName : Name) (generation : Nat) : GoalM Unit := do
if let some (thms, thmMap) := ( get).thmMap.retrieve? fName then

View File

@@ -6,7 +6,6 @@ Authors: Leonardo de Moura
prelude
import Init.Grind.Lemmas
import Lean.Meta.Tactic.Util
import Lean.Meta.Tactic.Simp.Simproc
import Lean.Meta.Tactic.Grind.RevertAll
import Lean.Meta.Tactic.Grind.PropagatorAttr
import Lean.Meta.Tactic.Grind.Proj
@@ -15,7 +14,7 @@ import Lean.Meta.Tactic.Grind.Util
import Lean.Meta.Tactic.Grind.Inv
import Lean.Meta.Tactic.Grind.Intro
import Lean.Meta.Tactic.Grind.EMatch
import Lean.Meta.Tactic.Grind.DoNotSimp
import Lean.Meta.Tactic.Grind.SimpUtil
namespace Lean.Meta.Grind
@@ -35,21 +34,12 @@ def mkMethods (fallback : Fallback) : CoreM Methods := do
prop e
}
private def getGrindSimprocs : MetaM Simprocs := do
let s grindNormSimprocExt.getSimprocs
let s addDoNotSimp s
return s
def GrindM.run (x : GrindM α) (mainDeclName : Name) (config : Grind.Config) (fallback : Fallback) : MetaM α := do
let scState := ShareCommon.State.mk _
let (falseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``False)
let (trueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``True)
let thms grindNormExt.getTheorems
let simprocs := #[( getGrindSimprocs), ( Simp.getSEvalSimprocs)]
let simp Simp.mkContext
(config := { arith := true })
(simpTheorems := #[thms])
(congrTheorems := ( getSimpCongrTheorems))
let simprocs Grind.getSimprocs
let simp Grind.getSimpContext
x ( mkMethods fallback).toMethodsRef { mainDeclName, config, simprocs, simp } |>.run' { scState, trueExpr, falseExpr }
private def mkGoal (mvarId : MVarId) : GrindM Goal := do

View File

@@ -0,0 +1,32 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Simp.Simproc
import Lean.Meta.Tactic.Grind.Simp
import Lean.Meta.Tactic.Grind.DoNotSimp
namespace Lean.Meta.Grind
/-- Returns the array of simprocs used by `grind`. -/
protected def getSimprocs : MetaM (Array Simprocs) := do
let s grindNormSimprocExt.getSimprocs
let s addDoNotSimp s
return #[s, ( Simp.getSEvalSimprocs)]
/-- Returns the simplification context used by `grind`. -/
protected def getSimpContext : MetaM Simp.Context := do
let thms grindNormExt.getTheorems
Simp.mkContext
(config := { arith := true })
(simpTheorems := #[thms])
(congrTheorems := ( getSimpCongrTheorems))
@[export lean_grind_normalize]
def normalizeImp (e : Expr) : MetaM Expr := do
let (r, _) Meta.simp e ( Grind.getSimpContext) ( Grind.getSimprocs)
return r.expr
end Lean.Meta.Grind

View File

@@ -140,4 +140,11 @@ def normalizeLevels (e : Expr) : CoreM Expr := do
| _ => return .continue
Core.transform e (pre := pre)
/--
Normalizes the given expression using the `grind` simplification theorems and simprocs.
This function is used for normalzing E-matching patterns. Note that it does not return a proof.
-/
@[extern "lean_grind_normalize"] -- forward definition
opaque normalize (e : Expr) : MetaM Expr
end Lean.Meta.Grind

View File

@@ -66,3 +66,30 @@ info: [grind.ematch.instance] fx: f a (f a a) = a
#guard_msgs (info) in
example : a = b₁ c = f b₁ b₂ f a c a a = b₂ False := by
grind
namespace pattern_normalization
opaque g : Nat Nat
@[grind_norm] theorem gthm : g x = x := sorry
opaque f : Nat Nat Nat
theorem fthm : f (g x) x = x := sorry
-- The following pattern should be normalized by `grind`. Otherwise, we will not find any instance during E-matching.
/--
info: [grind.ematch.pattern] fthm: [f #0 #0]
-/
#guard_msgs (info) in
grind_pattern fthm => f (g x) x
/--
info: [grind.assert] f x y = b
[grind.assert] y = x
[grind.assert] ¬b = x
[grind.ematch.instance] fthm: f (g y) y = y
[grind.assert] f y y = y
-/
#guard_msgs (info) in
set_option trace.grind.assert true in
example : f (g x) y = b y = x b = x := by
grind
end pattern_normalization