Compare commits

...

2 Commits

Author SHA1 Message Date
Leonardo de Moura
d881b43ca7 test: update sym_simp_cd test for perm theorem support
Replace `Nat.add_comm_of_pos` (a permutation theorem) with `f_idem`
(a non-perm conditional rewrite: `a > 0 → f (f a) = f a`).
Update all examples and expected trace output accordingly.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 17:14:37 -07:00
Leonardo de Moura
be9d2644d6 feat: add permutation theorem support to Sym.simp
This PR prevents `Sym.simp` from looping on permutation theorems like
`∀ x y, x + y = y + x`.

- Add `perm : Bool` field to `Theorem`
- Add `isPerm` that checks if LHS and RHS are the same structure with
  pattern variables (de Bruijn indices) rearranged via a consistent
  bijection. Uses `ReaderT` (offset), `StateT` (fwd/bwd maps),
  `ExceptT` (failure).
- Compute `perm` in `mkTheoremFromDecl` / `mkTheoremFromExpr`
- In `Theorem.rewrite`, when `perm` is true, only apply the rewrite if
  the result is strictly less than the input (using `acLt`)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 16:33:47 -07:00
4 changed files with 156 additions and 95 deletions

View File

@@ -9,6 +9,7 @@ public import Lean.Meta.Sym.Simp.Simproc
public import Lean.Meta.Sym.Simp.Theorems
public import Lean.Meta.Sym.Simp.App
public import Lean.Meta.Sym.Simp.Discharger
import Lean.Meta.ACLt
import Lean.Meta.Sym.InstantiateS
import Lean.Meta.Sym.InstantiateMVarsS
import Init.Data.Range.Polymorphic.Iterators
@@ -71,10 +72,16 @@ public def Theorem.rewrite (thm : Theorem) (e : Expr) (d : Discharger := dischar
let expr instantiateRevBetaS rhs args.toArray
if isSameExpr e expr then
return mkRflResultCD isCD
else if !( checkPerm thm.perm e expr) then
return mkRflResultCD isCD
else
return .step expr proof (contextDependent := isCD)
else
return .rfl
where
checkPerm (perm : Bool) (e result : Expr) : MetaM Bool := do
if !perm then return true
acLt result e
public def Theorems.rewrite (thms : Theorems) (d : Discharger := dischargeNone) : Simproc := fun e => do
-- Track `cd` across all attempted theorems. If theorem A fails with cd=true

View File

@@ -10,6 +10,7 @@ public import Lean.Meta.DiscrTree
import Lean.Meta.Sym.Simp.DiscrTree
import Lean.Meta.AppBuilder
import Lean.ExtraModUses
import Init.Omega
public section
namespace Lean.Meta.Sym.Simp
@@ -26,6 +27,10 @@ structure Theorem where
pattern : Pattern
/-- Right-hand side of the equation. -/
rhs : Expr
/-- If `true`, the theorem is a permutation rule (e.g., `x + y = y + x`).
Rewriting is only applied when the result is strictly less than the input
(using `acLt`), preventing infinite loops. -/
perm : Bool := false
deriving Inhabited
instance : BEq Theorem where
@@ -45,6 +50,49 @@ def Theorems.getMatch (thms : Theorems) (e : Expr) : Array Theorem :=
def Theorems.getMatchWithExtra (thms : Theorems) (e : Expr) : Array (Theorem × Nat) :=
Sym.getMatchWithExtra thms.thms e
/--
Check whether `lhs` and `rhs` (with `numVars` pattern variables represented as `.bvar` indices
`≥ 0` before any binder entry) are permutations of each other — same structure with only
pattern variable indices rearranged via a consistent bijection.
Bvars with index `< offset` are "local" (introduced by binders inside the pattern) and must
match exactly. Bvars with index `≥ offset` are pattern variables and may be permuted,
but the mapping must be a bijection.
Simplified compared to `Meta.simp`'s `isPerm`:
- Uses de Bruijn indices instead of metavariables
- No `.proj` (folded into applications) or `.letE` (zeta-expanded) cases
-/
private abbrev IsPermM := ReaderT Nat $ StateT (Array (Option Nat)) $ Except Unit
private partial def isPermAux (a b : Expr) : IsPermM Unit := do
match a, b with
| .bvar i, .bvar j =>
let offset read
if i < offset && j < offset then
unless i == j do throw ()
else if i >= offset && j >= offset then
let pi := i - offset
let pj := j - offset
let fwd get
if h : pi >= fwd.size then throw () else
match fwd[pi] with
| none =>
-- Check injectivity: pj must not already be a target of another mapping
if fwd.contains (some pj) then throw ()
set (fwd.set pi (some pj))
| some pj' => unless pj == pj' do throw ()
else throw ()
| .app f₁ a₁, .app f₂ a₂ => isPermAux f₁ f₂; isPermAux a₁ a₂
| .mdata _ s, t => isPermAux s t
| s, .mdata _ t => isPermAux s t
| .forallE _ d₁ b₁ _, .forallE _ d₂ b₂ _ => isPermAux d₁ d₂; withReader (· + 1) (isPermAux b₁ b₂)
| .lam _ d₁ b₁ _, .lam _ d₂ b₂ _ => isPermAux d₁ d₂; withReader (· + 1) (isPermAux b₁ b₂)
| s, t => unless s == t do throw ()
def isPerm (numVars : Nat) (lhs rhs : Expr) : Bool :=
((isPermAux lhs rhs).run 0 |>.run (Array.replicate numVars none)) matches .ok _
/-- Describes how a theorem's conclusion was adapted to an equality for use in `Sym.simp`. -/
private inductive EqAdaptation where
/-- Already an equality `lhs = rhs`. Proof is used as-is. -/
@@ -99,13 +147,15 @@ where
def mkTheoremFromDecl (declName : Name) : MetaM Theorem := do
let (pattern, (rhs, adaptation)) mkPatternFromDeclWithKey declName selectEqKey
let expr wrapProof pattern.varTypes.size (mkConst declName) adaptation
return { expr, pattern, rhs }
let perm := isPerm pattern.varTypes.size pattern.pattern rhs
return { expr, pattern, rhs, perm }
/-- Create a `Theorem` from a proof expression. Handles equalities, `¬`, `↔`, and propositions. -/
def mkTheoremFromExpr (e : Expr) : MetaM Theorem := do
let (pattern, (rhs, adaptation)) mkPatternFromExprWithKey e [] selectEqKey
let expr wrapProof pattern.varTypes.size e adaptation
return { expr, pattern, rhs }
let perm := isPerm pattern.varTypes.size pattern.pattern rhs
return { expr, pattern, rhs, perm }
/--
Environment extension storing a set of `Sym.Simp` theorems.

View File

@@ -42,153 +42,120 @@ example : 2 + 3 = 5 := by
-- and lands in the transient cache. On the second invocation, the transient cache is
-- cleared, so there should be NO persistent cache hit for the overall expression.
-- Only context-independent sub-expressions (literals, fvars) get persistent cache hits.
theorem Nat.add_comm_of_pos (a b : Nat) (_h : 0 < a) : a + b = b + a := Nat.add_comm a b
opaque f : Nat Nat
axiom f_idem (a : Nat) (_h : 0 < a) : f (f a) = f a
set_option trace.sym.simp.debug.cache true in
/--
trace: [sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
trace: [sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] transient cache hit: f (f n)
[sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] second traversal
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
[sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] transient cache hit: f (f n)
[sym.simp.debug.cache] persistent cache hit: f n
-/
#guard_msgs in
example (n : Nat) (h : 0 < n) : n + 2 = 2 + n := by
sym_simp_twice [Nat.add_comm_of_pos]
example (n : Nat) (h : 0 < n) : f (f n) = f (f (f n)) := by
sym_simp_twice [f_idem]
-- Test 3: Congruence — cd propagates through function application.
-- `n + 2` rewrites context-dependently (cd=true), `3 + 4` evaluates ground (cd=false).
-- The congruence combines both, so the overall result is cd=true.
-- On second traversal: ground sub-expressions (`3 + 4`, `7`) hit persistent cache,
-- but cd-tainted expressions (`2 + n`, `2 + n + 7`) are only in transient.
set_option trace.sym.simp.debug.cache true in
/--
trace: [sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
[sym.simp.debug.cache] transient cache hit: (2 + n) * 7
trace: [sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] persistent cache hit: f n + 7
[sym.simp.debug.cache] second traversal
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] persistent cache hit: 3 + 4
[sym.simp.debug.cache] transient cache hit: 2 + n
[sym.simp.debug.cache] persistent cache hit: 7
[sym.simp.debug.cache] transient cache hit: (2 + n) * 7
[sym.simp.debug.cache] persistent cache hit: f n + 7
[sym.simp.debug.cache] persistent cache hit: f n + 7
-/
#guard_msgs in
example (n : Nat) (h : 0 < n) : (n + 2) * (3 + 4) = (2 + n) * 7 := by
sym_simp_twice [Nat.add_comm_of_pos]
example (n : Nat) (h : 0 < n) : f (f n) + (3 + 4) = f n + 7 := by
sym_simp_twice [f_idem]
-- Similar to previous test, but `Nat.add_comm_of_pos` is not applicable, but discharger must return `cd := true`.
-- Similar to previous test, but `f_idem` is not applicable (no hypothesis), but discharger must return `cd := true`.
set_option trace.sym.simp.debug.cache true in
/--
trace: [sym.simp.debug.cache] transient cache hit: n + 2
[sym.simp.debug.cache] transient cache hit: (n + 2) * 7
trace: [sym.simp.debug.cache] transient cache hit: f (f n)
[sym.simp.debug.cache] transient cache hit: f (f n) + 7
[sym.simp.debug.cache] second traversal
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] persistent cache hit: 3 + 4
[sym.simp.debug.cache] transient cache hit: n + 2
[sym.simp.debug.cache] transient cache hit: f (f n)
[sym.simp.debug.cache] persistent cache hit: 7
[sym.simp.debug.cache] transient cache hit: (n + 2) * 7
[sym.simp.debug.cache] transient cache hit: f (f n) + 7
-/
#guard_msgs in
example (n : Nat) : (n + 2) * (3 + 4) = (n + 2) * 7 := by
sym_simp_twice [Nat.add_comm_of_pos]
example (n : Nat) : f (f n) + (3 + 4) = f (f n) + 7 := by
sym_simp_twice [f_idem]
-- Test 4: Arrow — cd propagates through implication.
-- The hypothesis `n + 2 = 2 + n` is simplified context-dependently to `True`.
-- `True → True` simplifies to `True`. The whole result is cd=true.
-- `True` hits persistent cache; `2 + n` is only in transient.
set_option trace.sym.simp.debug.cache true in
/--
trace: [sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
trace: [sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] second traversal
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
[sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] persistent cache hit: True
-/
#guard_msgs in
set_option linter.unusedVariables false in
example (n : Nat) (h : 0 < n) : (n + 2 = 2 + n) True := by
sym_simp_twice [Nat.add_comm_of_pos]
example (n : Nat) (h : 0 < n) : (f (f n) = f n) True := by
sym_simp_twice [f_idem]
-- Test 5: Lambda — cd propagates through funext.
-- Body `n + 2` is simplified context-dependently inside the binder.
-- `withFreshTransientCache` clears the transient cache on binder entry.
-- The lambda result `fun x => 2 + n` is only in transient.
set_option trace.sym.simp.debug.cache true in
/--
trace: [sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: fun x => 2 + n
trace: [sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] persistent cache hit: fun x => f n
[sym.simp.debug.cache] second traversal
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: fun x => 2 + n
[sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] persistent cache hit: fun x => f n
[sym.simp.debug.cache] persistent cache hit: fun x => f n
-/
#guard_msgs in
example (n : Nat) (_h : 0 < n) : (fun _ : Nat => n + 2) = (fun _ : Nat => 2 + n) := by
sym_simp_twice [Nat.add_comm_of_pos]
example (n : Nat) (_h : 0 < n) : (fun _ : Nat => f (f n)) = (fun _ : Nat => f n) := by
sym_simp_twice [f_idem]
-- Test 6: Control flow — cd propagates through `ite` condition.
-- The condition `n + 2 = 2 + n` is simplified context-dependently.
-- The `ite` result inherits cd, and `1` (ground) is in persistent cache.
set_option trace.sym.simp.debug.cache true in
/--
trace: [sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
trace: [sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] persistent cache hit: 1
[sym.simp.debug.cache] second traversal
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
[sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] persistent cache hit: 1
[sym.simp.debug.cache] persistent cache hit: 1
-/
#guard_msgs in
example (n : Nat) (h : 0 < n) : (if n + 2 = 2 + n then 1 else 0) = 1 := by
sym_simp_twice [Nat.add_comm_of_pos]
example (n : Nat) (h : 0 < n) : (if f (f n) = f n then 1 else 0) = 1 := by
sym_simp_twice [f_idem]
-- Test 7: Dependent forall — body cd under binder with `withFreshTransientCache`.
-- Simplifying `∀ (m : Nat), n + 2 = 2 + n` enters a binder (for `m`).
-- The transient cache is cleared on binder entry (`withFreshTransientCache`).
-- The body uses a cd rewrite, so the overall result is cd=true.
-- After "second traversal": `Nat` (the binder type) hits persistent cache.
set_option trace.sym.simp.debug.cache true in
/--
trace: [sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
trace: [sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] transient cache hit: f (f n)
[sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] second traversal
[sym.simp.debug.cache] persistent cache hit: Nat
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
[sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] persistent cache hit: f n
[sym.simp.debug.cache] transient cache hit: f (f n)
[sym.simp.debug.cache] persistent cache hit: f n
-/
#guard_msgs in
set_option linter.unusedVariables false in
example (n : Nat) (h : 0 < n) : (_ : Nat), n + 2 = 2 + n := by
sym_simp_twice [Nat.add_comm_of_pos]
example (n : Nat) (h : 0 < n) : (_ : Nat), f (f n) = f (f (f n)) := by
sym_simp_twice [f_idem]

View File

@@ -0,0 +1,37 @@
import Lean
/-! Tests for permutation theorem support in `Sym.simp` -/
-- Nat.add_comm is a permutation theorem: x + y = y + x
-- Without perm support, `simp` with this theorem would loop.
register_sym_simp commSimp where
post := ground >> rewrite [Nat.add_comm]
-- This should terminate: Nat.add_comm is detected as perm,
-- and only applied when result < input.
example (x y : Nat) : x + y = y + x := by
sym =>
simp commSimp
-- Combining perm with non-perm theorems
register_sym_simp commZeroSimp where
post := ground >> rewrite [Nat.add_comm, Nat.zero_add, Nat.add_zero]
example (x y : Nat) : 0 + (x + y) = y + x := by
sym =>
simp commZeroSimp
-- Verify perm doesn't interfere with non-perm theorems
register_sym_simp nonPermSimp where
post := ground >> rewrite [Nat.zero_add]
example (x : Nat) : 0 + x = x := by
sym =>
simp nonPermSimp
register_sym_simp simple where
post := ground
example (x y z w : Nat) : x + y + z + w = w + (z + y) + x := by
sym => simp simple [Nat.add_comm, Nat.add_assoc, Nat.add_left_comm]