Compare commits

...

2 Commits

Author SHA1 Message Date
Joachim Breitner
7053950e33 Save before commit 2025-06-19 11:32:58 +02:00
Joachim Breitner
11ee5bf3b1 perf: check simp cache in simpLoop (experiment)
This PR changes where `simp` is consulting its own cache, to avoid
repeated simplification.

This is an experiment.
2025-06-19 11:11:01 +02:00
6 changed files with 97 additions and 29 deletions

View File

@@ -748,6 +748,10 @@ def cacheResult (e : Expr) (cfg : Config) (r : Result) : SimpM Result := do
partial def simpLoop (e : Expr) : SimpM Result := withIncRecDepth do
let cfg getConfig
if cfg.memoize then
let cache := ( get).cache
if let some result := cache.find? e then
return result
if ( get).numSteps > cfg.maxSteps then
throwError "simp failed, maximum number of steps exceeded"
else
@@ -784,16 +788,8 @@ def simpImpl (e : Expr) : SimpM Result := withIncRecDepth do
checkSystem "simp"
if ( isProof e) then
return { expr := e }
go
where
go : SimpM Result := do
let cfg getConfig
if cfg.memoize then
let cache := ( get).cache
if let some result := cache.find? e then
return result
trace[Meta.Tactic.simp.heads] "{repr e.toHeadIndex}"
simpLoop e
trace[Meta.Tactic.simp.heads] "{repr e.toHeadIndex}"
simpLoop e
@[inline] private def withSimpContext (ctx : Context) (x : MetaM α) : MetaM α := do
withConfig (fun c => { c with etaStruct := ctx.config.etaStruct }) <|

View File

@@ -15,17 +15,17 @@ Please use `termination_by` to specify a decreasing measure.
decreasing_by.lean:75:13-77:3: error: unexpected token 'end'; expected '{' or tactic
decreasing_by.lean:75:0-75:13: error: unsolved goals
n m : Nat
⊢ Prod.Lex (fun a₁ a₂ => a₁ < a₂) (fun a₁ a₂ => a₁ < a₂) (n, dec2 m) (n, m)
⊢ Prod.Lex (fun x1 x2 => x1 < x2) (fun x1 x2 => x1 < x2) (n, dec2 m) (n, m)
n m : Nat
⊢ Prod.Lex (fun a₁ a₂ => a₁ < a₂) (fun a₁ a₂ => a₁ < a₂) (dec1 n, 100) (n, m)
⊢ Prod.Lex (fun x1 x2 => x1 < x2) (fun x1 x2 => x1 < x2) (dec1 n, 100) (n, m)
decreasing_by.lean:85:0-85:22: error: unsolved goals
case a
n m : Nat
⊢ Prod.Lex (fun a₁ a₂ => a₁ < a₂) (fun a₁ a₂ => a₁ < a₂) (n, dec2 m) (n, m)
⊢ Prod.Lex (fun x1 x2 => x1 < x2) (fun x1 x2 => x1 < x2) (n, dec2 m) (n, m)
n m : Nat
⊢ Prod.Lex (fun a₁ a₂ => a₁ < a₂) (fun a₁ a₂ => a₁ < a₂) (dec1 n, 100) (n, m)
⊢ Prod.Lex (fun x1 x2 => x1 < x2) (fun x1 x2 => x1 < x2) (dec1 n, 100) (n, m)
decreasing_by.lean:93:0-94:22: error: Could not find a decreasing measure.
The basic measures relate at each recursive call as follows:
(<, ≤, =: relation proved, ? all proofs failed, _: no proof attempted)
@@ -35,7 +35,7 @@ The basic measures relate at each recursive call as follows:
Please use `termination_by` to specify a decreasing measure.
decreasing_by.lean:104:0-106:17: error: unsolved goals
n m : Nat
⊢ Prod.Lex (fun a₁ a₂ => a₁ < a₂) (fun a₁ a₂ => a₁ < a₂) (dec1 n, 100) (n, m)
⊢ Prod.Lex (fun x1 x2 => x1 < x2) (fun x1 x2 => x1 < x2) (dec1 n, 100) (n, m)
decreasing_by.lean:114:0-117:17: error: Could not find a decreasing measure.
The basic measures relate at each recursive call as follows:
(<, ≤, =: relation proved, ? all proofs failed, _: no proof attempted)

View File

@@ -13,10 +13,10 @@ trace: [diag] Diagnostics
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
---
trace: [simp] Diagnostics
[simp] used theorems (max: 59, num: 1):
[simp] ack.eq_3 ↦ 59
[simp] tried theorems (max: 59, num: 1):
[simp] ack.eq_3 ↦ 59, succeeded: 59
[simp] used theorems (max: 57, num: 1):
[simp] ack.eq_3 ↦ 57
[simp] tried theorems (max: 57, num: 1):
[simp] ack.eq_3 ↦ 57, succeeded: 57
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
---
trace: [diag] Diagnostics

View File

@@ -0,0 +1,72 @@
/-!
Checks that the simp cache is consulted within `simpLoop`, not just in `simpMain`
-/
axiom testSorry : α
opaque a : Nat
opaque b : Nat
opaque c : Nat
opaque f : Nat Nat
opaque P : Nat Prop
theorem ab : a = b := testSorry
theorem bc : b = c := testSorry
theorem ba : b = a := testSorry
theorem fafb : f a = f b := testSorry
set_option trace.Meta.Tactic.simp.rewrite true
-- This trace should only mention one `bc` rewrite, not two.
/--
trace: [Meta.Tactic.simp.rewrite] bc:1000:
b
==>
c
[Meta.Tactic.simp.rewrite] h:1000:
P c
==>
True
[Meta.Tactic.simp.rewrite] ab:1000:
a
==>
b
[Meta.Tactic.simp.rewrite] h:1000:
P c
==>
True
[Meta.Tactic.simp.rewrite] and_self:1000:
True ∧ True
==>
True
-/
#guard_msgs in
example (h : P c) : P b P a := by simp [ab, bc, h]
-- Almost the same goal, but ordered differently.
/--
trace: [Meta.Tactic.simp.rewrite] ab:1000:
a
==>
b
[Meta.Tactic.simp.rewrite] bc:1000:
b
==>
c
[Meta.Tactic.simp.rewrite] h:1000:
P c
==>
True
[Meta.Tactic.simp.rewrite] h:1000:
P c
==>
True
[Meta.Tactic.simp.rewrite] and_self:1000:
True ∧ True
==>
True
-/
#guard_msgs in
example (h : P c) : P a P b := by simp [ab, bc, h]

View File

@@ -12,8 +12,8 @@ trace: [simp] Diagnostics
[simp] used theorems (max: 50, num: 2):
[simp] f_eq ↦ 50
[simp] q_eq ↦ 50
[simp] tried theorems (max: 101, num: 2):
[simp] f_eq ↦ 101, succeeded: 50
[simp] tried theorems (max: 51, num: 2):
[simp] f_eq ↦ 51, succeeded: 50
[simp] q_eq ↦ 50, succeeded: 50
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
-/
@@ -33,13 +33,13 @@ def ack : Nat → Nat → Nat
/--
trace: [simp] Diagnostics
[simp] used theorems (max: 1201, num: 3):
[simp] ack.eq_3 ↦ 1201
[simp] Nat.reduceAdd (builtin simproc) ↦ 771
[simp] ack.eq_1 ↦ 768
[simp] tried theorems (max: 1973, num: 2):
[simp] ack.eq_3 ↦ 1973, succeeded: 1201
[simp] ack.eq_1 ↦ 768, succeeded: 768
[simp] used theorems (max: 1193, num: 3):
[simp] ack.eq_3 ↦ 1193
[simp] Nat.reduceAdd (builtin simproc) ↦ 508
[simp] ack.eq_1 ↦ 508
[simp] tried theorems (max: 1705, num: 2):
[simp] ack.eq_3 ↦ 1705, succeeded: 1193
[simp] ack.eq_1 ↦ 508, succeeded: 508
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
---
error: tactic 'simp' failed, nested error:

View File

@@ -16,7 +16,7 @@ x✝ :
(y : (_ : Nat) ×' Tree α) →
(invImage (fun x => PSigma.casesOn x fun n t => (n, t)) Prod.instWellFoundedRelation).1 y ⟨n.succ, { cs := cs }⟩ →
Tree α
⊢ Prod.Lex (fun a₁ a₂ => a₁ < a₂) (fun a₁ a₂ => sizeOf a₁ < sizeOf a₂)
⊢ Prod.Lex (fun x1 x2 => x1 < x2) (fun a₁ a₂ => sizeOf a₁ < sizeOf a₂)
(n, { cs := List.map (fun x => x✝ ⟨n + 1, x.val⟩ ⋯) cs.attach }) (n.succ, { cs := cs })
-/
#guard_msgs(trace) in