Compare commits

...

1 Commits

Author SHA1 Message Date
Leonardo de Moura
f2e438e44d fix: simp caching
closes #3943

TODO: remove `cache` field from `Simp.Result`.
2024-04-22 13:52:46 -07:00
3 changed files with 78 additions and 32 deletions

View File

@@ -241,7 +241,17 @@ def getSimpLetCase (n : Name) (t : Expr) (b : Expr) : MetaM SimpLetCase := do
else
return SimpLetCase.dep
def withNewLemmas {α} (xs : Array Expr) (f : SimpM α) : SimpM α := do
/--
We use `withNewlemmas` whenever updating the local context.
We use `withFreshCache` because the local context affects `simp` rewrites
even when `contextual := false`.
For example, the `discharger` may inspect the current local context. The default
discharger does that when applying equational theorems, and the user may
use `(discharger := assumption)` or `(discharger := omega)`.
If the `wishFreshCache` introduces performance issues, we can design a better solution
for the default discharger which is used most of the time.
-/
def withNewLemmas {α} (xs : Array Expr) (f : SimpM α) : SimpM α := withFreshCache do
if ( getConfig).contextual then
let mut s getSimpTheorems
let mut updated := false
@@ -250,7 +260,7 @@ def withNewLemmas {α} (xs : Array Expr) (f : SimpM α) : SimpM α := do
s s.addTheorem (.fvar x.fvarId!) x
updated := true
if updated then
withSimpTheorems s f
withTheReader Context (fun ctx => { ctx with simpTheorems := s }) f
else
f
else
@@ -304,30 +314,27 @@ def simpArrow (e : Expr) : SimpM Result := do
trace[Debug.Meta.Tactic.simp] "arrow [{(← getConfig).contextual}] {p} [{← isProp p}] -> {q} [{← isProp q}]"
if ( pure ( getConfig).contextual <&&> isProp p <&&> isProp q) then
trace[Debug.Meta.Tactic.simp] "ctx arrow {rp.expr} -> {q}"
withLocalDeclD e.bindingName! rp.expr fun h => do
let s getSimpTheorems
let s s.addTheorem (.fvar h.fvarId!) h
withSimpTheorems s do
let rq simp q
match rq.proof? with
| none => mkImpCongr e rp rq
| some hq =>
let hq mkLambdaFVars #[h] hq
/-
We use the default reducibility setting at `mkImpDepCongrCtx` and `mkImpCongrCtx` because they use the theorems
```lean
@implies_dep_congr_ctx : ∀ {p₁ p₂ q₁ : Prop}, p₁ = p₂ → ∀ {q₂ : p₂ → Prop}, (∀ (h : p₂), q₁ = q₂ h) → (p₁ → q₁) = ∀ (h : p₂), q₂ h
@implies_congr_ctx : ∀ {p₁ p₂ q₁ q₂ : Prop}, p₁ = p₂ → (p₂ → q₁ = q₂) → (p₁ → q₁) = (p₂ → q₂)
```
And the proofs may be from `rfl` theorems which are now omitted. Moreover, we cannot establish that the two
terms are definitionally equal using `withReducible`.
TODO (better solution): provide the problematic implicit arguments explicitly. It is more efficient and avoids this
problem.
-/
if rq.expr.containsFVar h.fvarId! then
return { expr := ( mkForallFVars #[h] rq.expr), proof? := ( withDefault <| mkImpDepCongrCtx ( rp.getProof) hq) }
else
return { expr := e.updateForallE! rp.expr rq.expr, proof? := ( withDefault <| mkImpCongrCtx ( rp.getProof) hq) }
withLocalDeclD e.bindingName! rp.expr fun h => withNewLemmas #[h] do
let rq simp q
match rq.proof? with
| none => mkImpCongr e rp rq
| some hq =>
let hq mkLambdaFVars #[h] hq
/-
We use the default reducibility setting at `mkImpDepCongrCtx` and `mkImpCongrCtx` because they use the theorems
```lean
@implies_dep_congr_ctx : ∀ {p₁ p₂ q₁ : Prop}, p₁ = p₂ → ∀ {q₂ : p₂ → Prop}, (∀ (h : p₂), q₁ = q₂ h) → (p₁ → q₁) = ∀ (h : p₂), q₂ h
@implies_congr_ctx : ∀ {p₁ p₂ q₁ q₂ : Prop}, p₁ = p₂ → (p₂ → q₁ = q₂) → (p₁ → q₁) = (p₂ → q₂)
```
And the proofs may be from `rfl` theorems which are now omitted. Moreover, we cannot establish that the two
terms are definitionally equal using `withReducible`.
TODO (better solution): provide the problematic implicit arguments explicitly. It is more efficient and avoids this
problem.
-/
if rq.expr.containsFVar h.fvarId! then
return { expr := ( mkForallFVars #[h] rq.expr), proof? := ( withDefault <| mkImpDepCongrCtx ( rp.getProof) hq) }
else
return { expr := e.updateForallE! rp.expr rq.expr, proof? := ( withDefault <| mkImpCongrCtx ( rp.getProof) hq) }
else
mkImpCongr e rp ( simp q)
@@ -389,7 +396,7 @@ def simpLet (e : Expr) : SimpM Result := do
| SimpLetCase.dep => return { expr := ( dsimp e) }
| SimpLetCase.nondep =>
let rv simp v
withLocalDeclD n t fun x => do
withLocalDeclD n t fun x => withNewLemmas #[x] do
let bx := b.instantiate1 x
let rbx simp bx
let hb? match rbx.proof? with
@@ -402,7 +409,7 @@ def simpLet (e : Expr) : SimpM Result := do
| _, some h => return { expr := e', proof? := some ( mkLetCongr ( rv.getProof) h) }
| SimpLetCase.nondepDepVar =>
let v' dsimp v
withLocalDeclD n t fun x => do
withLocalDeclD n t fun x => withNewLemmas #[x] do
let bx := b.instantiate1 x
let rbx simp bx
let e' := mkLet n t v' ( rbx.expr.abstractM #[x])

View File

@@ -21,7 +21,11 @@ structure Result where
/-- A proof that `$e = $expr`, where the simplified expression is on the RHS.
If `none`, the proof is assumed to be `refl`. -/
proof? : Option Expr := none
/-- If `cache := true` the result is cached. -/
/--
If `cache := true` the result is cached.
Warning: we will remove this field in the future. It is currently used by
`arith := true`, but we can now refactor the code to avoid the hack.
-/
cache : Bool := true
deriving Inhabited
@@ -284,9 +288,6 @@ Save current cache, reset it, execute `x`, and then restore original cache.
modify fun s => { s with cache := {} }
try x finally modify fun s => { s with cache := cacheSaved }
@[inline] def withSimpTheorems (s : SimpTheoremsArray) (x : SimpM α) : SimpM α := do
withFreshCache <| withTheReader Context (fun ctx => { ctx with simpTheorems := s }) x
@[inline] def withDischarger (discharge? : Expr SimpM (Option Expr)) (x : SimpM α) : SimpM α :=
withFreshCache <| withReader (fun r => { MethodsRef.toMethods r with discharge? }.toMethodsRef) x

38
tests/lean/run/3943.lean Normal file
View File

@@ -0,0 +1,38 @@
example (f : Nat Nat) : (if f x = 0 then f x else f x + 1) + f x = y := by
simp (config := { contextual := true })
guard_target = (if f x = 0 then 0 else f x + 1) + f x = y
sorry
example (f : Nat Nat) : f x = 0 f x + 1 = y := by
simp (config := { contextual := true })
guard_target = f x = 0 1 = y
sorry
example (f : Nat Nat) : let _ : f x = 0 := sorry; f x + 1 = y := by
simp (config := { contextual := true, zeta := false })
guard_target = let _ : f x = 0 := sorry; 1 = y
sorry
def overlap : Nat Nat
| 0 => 0
| 1 => 1
| n+1 => overlap n
example : (if (n = 0 False) then overlap (n+1) else overlap (n+1)) = overlap n := by
simp only [overlap]
guard_target = (if (n = 0 False) then overlap n else overlap (n+1)) = overlap n
sorry
opaque p : Nat Bool
opaque g : Nat Nat
@[simp] theorem g_eq (h : p x) : g x = x := sorry
example : (if p x then g x else g x + 1) + g x = y := by
simp (discharger := assumption)
guard_target = (if p x then x else g x + 1) + g x = y
sorry
example : (let _ : p x := sorry; g x + 1 = y) g x = y := by
simp (config := { zeta := false }) (discharger := assumption)
guard_target = (let _ : p x := sorry; x + 1 = y) g x = y
sorry