Compare commits

...

1 Commits

Author SHA1 Message Date
Sebastian Graf
f81340c3ca perf: persist grind preprocessing caches across calls
This PR persists per-call traversal caches in `markNestedSubsingletons`,
`canonImpl`, and `unfoldReducible` across grind preprocessing calls. Previously
these steps created fresh HashMaps each invocation, re-traversing all shared
subexpressions. For workloads that generate many facts with growing shared
structure (e.g. nested Nat subtraction chains), total work was O(n²). Persisting
the caches and adding strategic `shareCommon` calls to restore pointer identity
reduces `grind canon` time by ~94% and `grind mark subsingleton` time by ~74%
on the included benchmark at n=100.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 18:47:44 +00:00
5 changed files with 73 additions and 6 deletions

View File

@@ -242,7 +242,9 @@ private def normOfNatArgs? (args : Array Expr) : MetaM (Option (Array Expr)) :=
@[export lean_grind_canon]
partial def canonImpl (e : Expr) : GoalM Expr := do profileitM Exception "grind canon" (← getOptions) do
trace_goal[grind.debug.canon] "{e}"
visit e |>.run' {}
let (r, cache') ← (visit e).run (← get').visitCache
modify' fun s => { s with visitCache := cache' }
return r
where
visit (e : Expr) : StateRefT (Std.HashMap ExprPtr Expr) GoalM Expr := do
unless e.isApp || e.isForall do return e

View File

@@ -12,6 +12,14 @@ import Lean.Meta.Tactic.Grind.Util
public section
namespace Lean.Meta.Grind
/-- Cached variant of `Sym.unfoldReducible` that persists the `transformWithCache` cache across calls,
ensuring pointer stability for shared sub-expressions. -/
def unfoldReducibleCached (e : Expr) : GrindM Expr := do
let cache := (← get).unfoldReducibleCache
let (e', cache') ← Meta.transformWithCache e cache (pre := fun e => Sym.unfoldReducibleStep e)
modify fun s => { s with unfoldReducibleCache := cache' }
return e'
private abbrev M := StateRefT (Std.HashMap ExprPtr Expr) GrindM
def isMarkedSubsingletonConst (e : Expr) : Bool := Id.run do
@@ -41,7 +49,9 @@ Recall that the congruence closure module has special support for them.
-- TODO: consider other subsingletons in the future? We decided to not support them to avoid the overhead of
-- synthesizing `Subsingleton` instances.
partial def markNestedSubsingletons (e : Expr) : GrindM Expr := do profileitM Exception "grind mark subsingleton" (← getOptions) do
visit e |>.run' {}
let (r, cache') ← (visit e).run (← get).markSubsingletonCache
modify fun s => { s with markSubsingletonCache := cache' }
return r
where
visit (e : Expr) : M Expr := do
if isMarkedSubsingletonApp e then
@@ -103,7 +113,7 @@ where
-/
/- We must also apply beta-reduction to improve the effectiveness of the congruence closure procedure. -/
let e ← Core.betaReduce e
let e ← Sym.unfoldReducible e
let e ← unfoldReducibleCached e
/- We must mask proofs occurring in `prop` too. -/
let e ← visit e
let e ← eraseIrrelevantMData e
@@ -123,6 +133,8 @@ def markProof (e : Expr) : GrindM Expr := do
if e.isAppOf ``Grind.nestedProof then
return e -- `e` is already marked
else
markNestedProof e |>.run' {}
let (r, cache') ← (markNestedProof e).run (← get).markSubsingletonCache
modify fun s => { s with markSubsingletonCache := cache' }
return r
end Lean.Meta.Grind

View File

@@ -58,8 +58,9 @@ def preprocessImpl (e : Expr) : GoalM Simp.Result := do
let e' ← instantiateMVars r.expr
-- Remark: `simpCore` unfolds reducible constants, but it does not consistently visit all possible subterms.
-- So, we must use the following `unfoldReducible` step. It is non-op in most cases
let e' ← Sym.unfoldReducible e'
let e' ← unfoldReducibleCached e'
let e' ← abstractNestedProofs e'
let e' ← shareCommon e'
let e' ← markNestedSubsingletons e'
let e' ← eraseIrrelevantMData e'
let e' ← foldProjs e'
@@ -70,6 +71,7 @@ def preprocessImpl (e : Expr) : GoalM Simp.Result := do
let r' ← replacePreMatchCond e'
let r ← r.mkEqTrans r'
let e' := r'.expr
let e' ← shareCommon e'
let e' ← canon e'
let e' ← shareCommon e'
trace_goal[grind.simp] "{e}\n===>\n{e'}"
@@ -98,6 +100,6 @@ but ensures assumptions made by `grind` are satisfied.
-/
def preprocessLight (e : Expr) : GoalM Expr := do
let e ← instantiateMVars e
shareCommon (← canon (← normalizeLevels (← foldProjs (← eraseIrrelevantMData (← markNestedSubsingletons (← Sym.unfoldReducible e))))))
shareCommon (← canon (← normalizeLevels (← foldProjs (← eraseIrrelevantMData (← markNestedSubsingletons (← unfoldReducibleCached e))))))
end Lean.Meta.Grind

View File

@@ -232,6 +232,10 @@ structure State where
Cached anchors (aka stable hash codes) for terms in the `grind` state.
-/
anchors : PHashMap ExprPtr UInt64 := {}
/-- Persistent cache for `markNestedSubsingletons` and `markProof` traversals. -/
markSubsingletonCache : Std.HashMap ExprPtr Expr := {}
/-- Persistent cache for `unfoldReducible` via `transformWithCache`, ensuring pointer stability. -/
unfoldReducibleCache : Std.HashMap ExprStructEq Expr := {}
instance : Nonempty State :=
.intro {}
@@ -712,6 +716,8 @@ structure Canon.State where
canon : PHashMap Expr Expr := {}
proofCanon : PHashMap Expr Expr := {}
canonArg : PHashMap CanonArgKey Expr := {}
/-- Persistent cache for `canonImpl` visit traversals. -/
visitCache : Std.HashMap ExprPtr Expr := {}
deriving Inhabited
/-- Trace information for a case split. -/

View File

@@ -0,0 +1,45 @@
import Lean
/-!
Regression test: `grind` on nested Nat subtraction chains.
After `simp only [Goal, loop]`, the goal becomes:
```
post (s₁ + s₂ - s₂ + s₂ - s₂ + ⋯ + s₂ - s₂) s₂
```
with `n` nested `(+ s₂ - s₂)` operations. Grind handles each Nat subtraction by
generating a `natCast_sub` fact, processed inside-out. The i-th fact has expression
size O(i). Before persistent caches were added, the preprocessing steps
`markNestedSubsingletons` and `canonImpl` used fresh per-call caches (traversing
all O(i) subexpressions each time), giving O(n²) total work.
This test checks that the scaling from n=25 to n=100 (4x) is at most 10x in wall-clock
time, which passes for ≤ O(n^1.7) but fails for O(n²) (which would give ~16x).
-/
def loop : Nat → (Nat → Nat → Prop) → (Nat → Nat → Prop)
| 0, post => post
| n+1, post => fun s₁ s₂ => loop n post (s₁ + s₂ - s₂) s₂
def Goal (n : Nat) : Prop := ∀ post s₁ s₂, post s₁ s₂ → loop n post s₁ s₂
set_option maxRecDepth 10000
set_option maxHeartbeats 10000000
open Lean Elab Command in
elab "#test_grind_scaling" : command => do
let solveAt (n : Nat) : CommandElabM Nat := do
let start ← IO.monoNanosNow
elabCommand (← `(command|
example : Goal $(Syntax.mkNumLit (toString n)) := by intros; simp only [Goal, loop]; grind
))
let stop ← IO.monoNanosNow
return stop - start
let t_small ← solveAt 25
let t_large ← solveAt 100
let ratio := t_large.toFloat / t_small.toFloat
-- Linear: expect ~4x for 4x problem size. Quadratic: would be ~16x.
-- Use 10x as threshold (generous for noise, catches quadratic).
if ratio > 10.0 then
throwError "grind preprocessing scaling regression: 100/25 time ratio is {ratio}x (expected < 10x for linear scaling)"
#test_grind_scaling