Compare commits

...

5 Commits

Author SHA1 Message Date
Leonardo de Moura
404a5b7351 fix: isRfl must ignore contextDependent; add unit tests
isRfl was using `matches .rfl` which only matched when ALL fields
have default values. With the new contextDependent field, .rfl false true
(cd=true) no longer matched, causing Theorems.rewrite to incorrectly
return cd .rfl results instead of continuing to the next theorem.

Fix: match `.rfl false _` (done=false, ignore cd).

Add unit tests for mkEqTransResult, andThen, orElse cd propagation.
2026-03-19 15:54:29 -04:00
Leonardo de Moura
96d15f5d48 test: add dependent forall test for contextDependent cache
Exercises `simpForall'` with `withFreshTransientCache` — the body
`n + 2 = 2 + n` is simplified context-dependently inside the binder.
The binder type `Nat` hits persistent cache on second traversal.
2026-03-19 15:41:33 -04:00
Leonardo de Moura
0a8121e609 test: add test for failed cd discharger propagation
When  fails (no matching hypothesis), it returns
`.failed true`. This cd propagates through the rewrite result, so
`n + 2` lands in the transient cache even though no rewrite occurred.
2026-03-19 15:35:05 -04:00
Leonardo de Moura
399f9a6717 test: add systematic contextDependent cache tests
Tests cover:
1. Ground evaluation → persistent cache hit
2. Conditional rewrite with dischargeAssumption → transient only
3. Congruence with mixed cd/non-cd sub-results → cd propagates
4. Arrow/implication → cd propagates through domain simplification
5. Lambda/funext → cd propagates through body under binder
6. Control flow (ite) → cd propagates through condition
2026-03-19 15:29:39 -04:00
Leonardo de Moura
81180ba129 feat: add contextDependent to Sym.simp Result with two-tier cache
Replace the coarse `wellBehavedMethods` flag with per-result
`contextDependent : Bool` tracking. Split the simp cache into
`persistentCache` (context-independent results, survives binder entry)
and `transientCache` (context-dependent results, cleared on binder entry).

Propagate `contextDependent` through all combinators (congruence,
transitivity, control flow, arrows, rewriting). The invariant:
when combining sub-results, `cd` is the disjunction of all sub-results'
flags — including `.rfl` results, since `simp` might take a completely
different code path in another local context.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-19 15:13:38 -04:00
27 changed files with 752 additions and 282 deletions

View File

@@ -23,7 +23,7 @@ open Lean.Meta.Tactic.Cbv
let evalResult cbvEntry lhs
match evalResult with
| .rfl .. => return ()
| .step e' proof _ =>
| .step e' proof _ _ =>
updateLhs e' proof
end Lean.Elab.Tactic.Conv

View File

@@ -48,20 +48,23 @@ public def mkCongr (e : Expr) (f a : Expr) (fr : Result) (ar : Result) (_ : e =
let β inferType e
let v getLevel β
return mkApp2 (mkConst declName [u, v]) α β
-- Propagate `contextDependent` from both sub-results. If either `simp f` or `simp a`
-- was context-dependent, the result for `f a` is too: in another local context, either
-- sub-expression might simplify differently, changing the overall result.
match fr, ar with
| .rfl _, .rfl _ => return .rfl
| .step f' hf _, .rfl _ =>
| .rfl _ cd₁, .rfl _ cd₂ => return mkRflResultCD (cd₁ || cd₂)
| .step f' hf _ cd₁, .rfl _ cd₂ =>
let e' mkAppS f' a
let h := mkApp4 ( mkCongrPrefix ``congrFun') f f' hf a
return .step e' h
| .rfl _, .step a' ha _ =>
return .step e' h (contextDependent := cd₁ || cd₂)
| .rfl _ cd₁, .step a' ha _ cd₂ =>
let e' mkAppS f a'
let h := mkApp4 ( mkCongrPrefix ``congrArg) a a' f ha
return .step e' h
| .step f' hf _, .step a' ha _ =>
return .step e' h (contextDependent := cd₁ || cd₂)
| .step f' hf _ cd₁, .step a' ha _ cd₂ =>
let e' mkAppS f' a'
let h := mkApp6 ( mkCongrPrefix ``congr) f f' a a' hf ha
return .step e' h
return .step e' h (contextDependent := cd₁ || cd₂)
/--
Returns a proof using `congrFun`
@@ -69,7 +72,8 @@ Returns a proof using `congrFun`
congrFun.{u, v} {α : Sort u} {β : α → Sort v} {f g : (x : α) → β x} (h : f = g) (a : α) : f a = g a
```
-/
def mkCongrFun (e : Expr) (f a : Expr) (f' : Expr) (hf : Expr) (_ : e = .app f a) (done := false) : SymM Result := do
def mkCongrFun (e : Expr) (f a : Expr) (f' : Expr) (hf : Expr) (_ : e = .app f a)
(done := false) (contextDependent := false) : SymM Result := do
let .forallE x _ βx _ whnfD ( inferType f)
| throwError "failed to build congruence proof, function expected{indentExpr f}"
let α inferType a
@@ -78,7 +82,7 @@ def mkCongrFun (e : Expr) (f a : Expr) (f' : Expr) (hf : Expr) (_ : e = .app f a
let β := Lean.mkLambda x .default α βx
let e' mkAppS f' a
let h := mkApp6 (mkConst ``congrFun [u, v]) α β f f' hf a
return .step e' h done
return .step e' h done contextDependent
/--
Handles simplification of over-applied function terms.
@@ -122,9 +126,12 @@ public def simpOverApplied (e : Expr) (numArgs : Nat) (simpFn : Expr → SimpM R
mkCongr e f a fr .rfl h
else
mkCongr e f a fr ( simp a) h
-- Dependent argument: can't simplify `a` independently.
-- Propagate `cd` from the function result: if `fr` was context-dependent,
-- the whole expression is, since `fr` might differ in another context.
else match fr with
| .rfl _ => return .rfl
| .step f' hf _ => mkCongrFun e f a f' hf h
| .rfl _ cd => return mkRflResultCD cd
| .step f' hf _ cd => mkCongrFun e f a f' hf h (contextDependent := cd)
| _ => unreachable!
visit e numArgs
@@ -160,8 +167,8 @@ public def propagateOverApplied (e : Expr) (numArgs : Nat) (simpFn : Expr → Si
| .app f a =>
let r visit f i
match r with
| .rfl _ => return r
| .step f' hf done => mkCongrFun e f a f' hf h done
| .rfl _ _ => return r
| .step f' hf done cd => mkCongrFun e f a f' hf h done cd
| _ => unreachable!
visit e numArgs
@@ -242,30 +249,32 @@ where
else
let .app f a := e | unreachable!
let (hf, fType) go (i-1) f
-- Propagate `cd` from both sub-results. Even when both return `.rfl`,
-- either might succeed in a different context, changing the result.
match hf, ( simp a) with
| .rfl _, .rfl _ => return (.rfl, default)
| .step f' hf _, .rfl _ =>
| .rfl _ cd₁, .rfl _ cd₂ => return (mkRflResultCD (cd₁ || cd₂), default)
| .step f' hf _ cd₁, .rfl _ cd₂ =>
let .forallE _ α β _ whnfToForall fType | unreachable!
let e' mkAppS f' a
let u getLevel α
let v getLevel β
let h := mkApp6 (mkConst ``congrFun' [u, v]) α β f f' hf a
return (.step e' h, β)
| .rfl _, .step a' ha _ =>
return (.step e' h (contextDependent := cd₁ || cd₂), β)
| .rfl _ cd₁, .step a' ha _ cd₂ =>
let fType getFnType f (i-1)
let .forallE _ α β _ whnfToForall fType | unreachable!
let e' mkAppS f a'
let u getLevel α
let v getLevel β
let h := mkApp6 (mkConst ``congrArg [u, v]) α β a a' f ha
return (.step e' h, β)
| .step f' hf _, .step a' ha _ =>
return (.step e' h (contextDependent := cd₁ || cd₂), β)
| .step f' hf _ cd₁, .step a' ha _ cd₂ =>
let .forallE _ α β _ whnfToForall fType | unreachable!
let e' mkAppS f' a'
let u getLevel α
let v getLevel β
let h := mkApp8 (mkConst ``congr [u, v]) α β f f' a a' hf ha
return (.step e' h, β)
return (.step e' h (contextDependent := cd₁ || cd₂), β)
/--
Simplifies arguments of a function application with interlaced rewritable/fixed arguments.
@@ -292,9 +301,10 @@ where
let fr go (i - 1) f (by omega)
if rewritable[i - 1] then
mkCongr e f a fr ( simp a) h
-- Fixed (non-rewritable) argument: propagate `cd` from the function result.
else match fr with
| .rfl _ => return .rfl
| .step f' hf _ => mkCongrFun e f a f' hf h
| .rfl _ cd => return mkRflResultCD cd
| .step f' hf _ cd => mkCongrFun e f a f' hf h (contextDependent := cd)
| _ => unreachable!
/--
@@ -379,11 +389,11 @@ def simpUsingCongrThm (e : Expr) (thm : CongrTheorem) : SimpM Result := do
| .eq =>
subst := subst.push arg
match argResults[j]! with
| .rfl _ =>
| .rfl _ _ =>
let h mkEqRefl arg
proof := mkApp2 proof arg h
subst := subst.push arg |>.push h
| .step arg' h _ =>
| .step arg' h _ _ =>
proof := mkApp2 proof arg' h
subst := subst.push arg' |>.push h
type := type.bindingBody!.bindingBody!
@@ -394,29 +404,38 @@ def simpUsingCongrThm (e : Expr) (thm : CongrTheorem) : SimpM Result := do
let hasCast := argKinds.any (· matches .cast)
let rhs if hasCast then Simp.removeUnnecessaryCasts rhs else pure rhs
let rhs share rhs
return .step rhs proof
-- The result is context-dependent if any simplified argument was.
let cd := argResults.any (·.isContextDependent)
return .step rhs proof (contextDependent := cd)
/-
Recursively simplifies arguments of kind `.eq`. The array `argResults` is initialized lazily
as soon as the simplifier returns a non-`rfl` result for some arguments.
`numEqs` is the number of `.eq` arguments found so far.
-/
let rec simpEqArgs (e : Expr) (i : Nat) (numEqs : Nat) (argResults : Array Result) : SimpM Result := do
-- `anyCD` tracks cd from `.rfl` results that `pushResult` may not store
-- (due to lazy initialization of `argResults`). Without this, cd from `.rfl`
-- sub-results would be silently dropped.
let rec simpEqArgs (e : Expr) (i : Nat) (numEqs : Nat) (argResults : Array Result)
(anyCD : Bool) : SimpM Result := do
match e with
| .app f a =>
match argKinds[i]! with
| .subsingletonInst
| .fixed => simpEqArgs f (i-1) numEqs argResults
| .cast => simpEqArgs f (i-1) numEqs argResults
| .eq => simpEqArgs f (i-1) (numEqs+1) (pushResult argResults numEqs ( simp a))
| .fixed => simpEqArgs f (i-1) numEqs argResults anyCD
| .cast => simpEqArgs f (i-1) numEqs argResults anyCD
| .eq =>
let r simp a
simpEqArgs f (i-1) (numEqs+1) (pushResult argResults numEqs r) (anyCD || r.isContextDependent)
| _ => unreachable!
| _ =>
if argResults.isEmpty then
return .rfl
return mkRflResultCD anyCD
else
mkNonRflResult argResults.reverse
let r mkNonRflResult argResults.reverse
return if anyCD && !r.isContextDependent then r.withContextDependent else r
let numArgs := e.getAppNumArgs
if numArgs > argKinds.size then
simpOverApplied e (numArgs - argKinds.size) (simpEqArgs · (argKinds.size - 1) 0 #[])
simpOverApplied e (numArgs - argKinds.size) (simpEqArgs · (argKinds.size - 1) 0 #[] false)
else if numArgs < argKinds.size then
/-
**Note**: under-applied case. This can be optimized, but this case is so
@@ -424,7 +443,7 @@ def simpUsingCongrThm (e : Expr) (thm : CongrTheorem) : SimpM Result := do
-/
simpOverApplied e e.getAppNumArgs (fun _ => return .rfl)
else
simpEqArgs e (argKinds.size - 1) 0 #[]
simpEqArgs e (argKinds.size - 1) 0 #[] false
/--
Main entry point for simplifying function application arguments.
@@ -461,10 +480,11 @@ public def simpAppArgRange (e : Expr) (start stop : Nat) : SimpM Result := do
match h : e with
| .app f a =>
let fr visit f i
-- Argument outside the [start, stop) range: propagate `cd` from function result.
let skip : SimpM Result := do
match fr with
| .rfl _ => return .rfl
| .step f' hf _ => mkCongrFun e f a f' hf h
| .rfl _ cd => return mkRflResultCD cd
| .step f' hf _ cd => mkCongrFun e f a f' hf h (contextDependent := cd)
if i < stop then
let .forallE _ α β _ whnfD ( inferType f) | unreachable!
if !β.hasLooseBVars then

View File

@@ -23,26 +23,31 @@ def simpIte : Simproc := fun e => do
if numArgs < 5 then return .rfl (done := true)
propagateOverApplied e (numArgs - 5) fun e => do
let_expr f@ite α c _ a b := e | return .rfl
-- **cd propagation**: `cd` from `simp c` is propagated to ALL branches.
-- When `cd = true`, `simp c` might produce a different result in another context
-- (e.g., a conditional rewrite could change `c`). This means the entire `ite`
-- result might differ — even branches like `isTrueExpr c` that seem ground-determined,
-- because in another context `simp c` might not return `.rfl` at all.
match ( simp c) with
| .rfl _ =>
| .rfl _ cd =>
if ( isTrueExpr c) then
return .step a <| mkApp3 (mkConst ``ite_true f.constLevels!) α a b
return .step a (mkApp3 (mkConst ``ite_true f.constLevels!) α a b) (contextDependent := cd)
else if ( isFalseExpr c) then
return .step b <| mkApp3 (mkConst ``ite_false f.constLevels!) α a b
return .step b (mkApp3 (mkConst ``ite_false f.constLevels!) α a b) (contextDependent := cd)
else
return .rfl (done := true)
| .step c' h _ =>
return mkRflResult (done := true) (contextDependent := cd)
| .step c' h _ cd =>
if ( isTrueExpr c') then
return .step a <| mkApp (e.replaceFn ``ite_cond_eq_true) h
return .step a (mkApp (e.replaceFn ``ite_cond_eq_true) h) (contextDependent := cd)
else if ( isFalseExpr c') then
return .step b <| mkApp (e.replaceFn ``ite_cond_eq_false) h
return .step b (mkApp (e.replaceFn ``ite_cond_eq_false) h) (contextDependent := cd)
else
let .some inst' trySynthInstance (mkApp (mkConst ``Decidable) c') | return .rfl
let inst' shareCommon inst'
let e' := e.getBoundedAppFn 4
let e' mkAppS₄ e' c' inst' a b
let h' := mkApp3 (e.replaceFn ``Sym.ite_cond_congr) c' inst' h
return .step e' h' (done := true)
return .step e' h' (done := true) (contextDependent := cd)
/--
Simplifies a dependent `if-then-else` expression.
@@ -52,25 +57,26 @@ def simpDIte : Simproc := fun e => do
if numArgs < 5 then return .rfl (done := true)
propagateOverApplied e (numArgs - 5) fun e => do
let_expr f@dite α c _ a b := e | return .rfl
-- See `simpIte` for why `cd` is propagated to all branches.
match ( simp c) with
| .rfl _ =>
| .rfl _ cd =>
if ( isTrueExpr c) then
let a' share <| a.betaRev #[mkConst ``True.intro]
return .step a' <| mkApp3 (mkConst ``dite_true f.constLevels!) α a b
return .step a' (mkApp3 (mkConst ``dite_true f.constLevels!) α a b) (contextDependent := cd)
else if ( isFalseExpr c) then
let b' share <| b.betaRev #[mkConst ``not_false]
return .step b' <| mkApp3 (mkConst ``dite_false f.constLevels!) α a b
return .step b' (mkApp3 (mkConst ``dite_false f.constLevels!) α a b) (contextDependent := cd)
else
return .rfl (done := true)
| .step c' h _ =>
return mkRflResult (done := true) (contextDependent := cd)
| .step c' h _ cd =>
if ( isTrueExpr c') then
let h' shareCommon <| mkOfEqTrueCore c h
let a share <| a.betaRev #[h']
return .step a <| mkApp (e.replaceFn ``dite_cond_eq_true) h
return .step a (mkApp (e.replaceFn ``dite_cond_eq_true) h) (contextDependent := cd)
else if ( isFalseExpr c') then
let h' shareCommon <| mkOfEqFalseCore c h
let b share <| b.betaRev #[h']
return .step b <| mkApp (e.replaceFn ``dite_cond_eq_false) h
return .step b (mkApp (e.replaceFn ``dite_cond_eq_false) h) (contextDependent := cd)
else
let .some inst' trySynthInstance (mkApp (mkConst ``Decidable) c') | return .rfl
let inst' shareCommon inst'
@@ -80,7 +86,7 @@ def simpDIte : Simproc := fun e => do
let b share <| mkLambda `h .default (mkNot c') (b.betaRev #[mkApp4 (mkConst ``Eq.mpr_not) c c' h (mkBVar 0)])
let e' mkAppS₄ e' c' inst' a b
let h' := mkApp3 (e.replaceFn ``Sym.dite_cond_congr) c' inst' h
return .step e' h' (done := true)
return .step e' h' (done := true) (contextDependent := cd)
/--
Simplifies a `cond` expression (aka Boolean `if-then-else`).
@@ -90,24 +96,25 @@ public def simpCond : Simproc := fun e => do
if numArgs < 4 then return .rfl (done := true)
propagateOverApplied e (numArgs - 4) fun e => do
let_expr f@cond α c a b := e | return .rfl
-- See `simpIte` for why `cd` is propagated to all branches.
match ( simp c) with
| .rfl _ =>
| .rfl _ cd =>
if isSameExpr c ( getBoolTrueExpr) then
return .step a <| mkApp3 (mkConst ``cond_true f.constLevels!) α a b
return .step a (mkApp3 (mkConst ``cond_true f.constLevels!) α a b) (contextDependent := cd)
else if isSameExpr c ( getBoolFalseExpr) then
return .step b <| mkApp3 (mkConst ``cond_false f.constLevels!) α a b
return .step b (mkApp3 (mkConst ``cond_false f.constLevels!) α a b) (contextDependent := cd)
else
return .rfl (done := true)
| .step c' h _ =>
return mkRflResult (done := true) (contextDependent := cd)
| .step c' h _ cd =>
if isSameExpr c' ( getBoolTrueExpr) then
return .step a <| mkApp (e.replaceFn ``Sym.cond_cond_eq_true) h
return .step a (mkApp (e.replaceFn ``Sym.cond_cond_eq_true) h) (contextDependent := cd)
else if isSameExpr c' ( getBoolFalseExpr) then
return .step b <| mkApp (e.replaceFn ``Sym.cond_cond_eq_false) h
return .step b (mkApp (e.replaceFn ``Sym.cond_cond_eq_false) h) (contextDependent := cd)
else
let e' := e.getBoundedAppFn 3
let e' mkAppS₃ e' c' a b
let h' := mkApp2 (e.replaceFn ``Sym.cond_cond_congr) c' h
return .step e' h' (done := true)
return .step e' h' (done := true) (contextDependent := cd)
/--
Simplifies a `match`-expression.
@@ -123,7 +130,7 @@ def simpMatch (declName : Name) : Simproc := fun e => do
let r simpAppArgRange e start stop
match r with
| .step .. => return r
| _ => return .rfl (done := true)
| .rfl _ cd => return mkRflResult (done := true) (contextDependent := cd)
/--
Simplifies control-flow expressions such as `if-then-else` and `match` expressions.

View File

@@ -32,36 +32,40 @@ When simplifying `x / x`, the discharger must prove `x ≠ 0` to apply this rule
Dischargers work by:
1. Attempting to simplify the side condition to `True`
2. If successful, extracting a proof from the simplification result
3. Returning `none` if the condition cannot be discharged
3. Returning `.failed` if the condition cannot be discharged
This integrates naturally with `Simproc`-based simplification.
## Important
When using dischargers that access new local declarations introduced when
visiting binders, it is the user's responsibility to set `wellBehavedMethods := false`.
This setting will instruct `simp` to discard the cache after visiting the binder's body.
Each discharge result also carries a `contextDependent` flag indicating whether
the discharge used context-dependent information (e.g., local hypotheses).
This enables the simplifier's two-tier cache to correctly handle context-dependent results.
-/
/-- Result of a discharge attempt. -/
public inductive DischargeResult where
/-- Discharge failed. If `contextDependent = true`, it might succeed in another local context. -/
| failed (contextDependent : Bool := false)
/-- Discharge succeeded with the given proof. -/
| solved (proof : Expr) (contextDependent : Bool := false)
/--
A discharger attempts to prove propositions that arise as side conditions during rewriting.
Given a proposition `e : Prop`, returns:
- `some proof` if `e` can be proven
- `none` if `e` cannot be discharged
- `.solved proof` if `e` can be proven
- `.failed` otherwise
**Usage**: Dischargers are used by the simplifier when applying conditional rewrite rules.
Both carry a `contextDependent` flag indicating whether context-dependent
information was used during the attempt.
-/
public abbrev Discharger := Expr SimpM (Option Expr)
public abbrev Discharger := Expr SimpM DischargeResult
def resultToOptionProof (e : Expr) (result : Result) : Option Expr :=
def resultToDischargeResult (e : Expr) (result : Result) : DischargeResult :=
match result with
| .rfl _ => none
| .step e' h _ =>
| .rfl _ cd => .failed cd
| .step e' h _ cd =>
if e'.isTrue then
some <| mkOfEqTrueCore e h
.solved (mkOfEqTrueCore e h) cd
else
none
.failed cd
/--
Converts a simplification procedure into a discharger.
@@ -73,7 +77,7 @@ a proof of the original proposition.
**Algorithm**:
1. Apply the simproc to the side condition `e`
2. If `e` simplifies to `True` (via proof `h : e = True`), return `ofEqTrue h : e`
3. Otherwise, return `none` (cannot discharge)
3. Otherwise, return `.failed` (cannot discharge)
**Parameters**:
- `p`: A simplification procedure to use for discharging conditions
@@ -82,7 +86,7 @@ a proof of the original proposition.
then `mkDischargerFromSimproc p` returns `ofEqTrue h : 5 < 10`.
-/
public def mkDischargerFromSimproc (p : Simproc) : Discharger := fun e => do
return resultToOptionProof e ( p e)
return resultToDischargeResult e ( p e)
/--
The default discharger uses the simplifier itself to discharge side conditions.
@@ -99,16 +103,16 @@ infinite recursion.
-/
public def dischargeSimpSelf : Discharger := fun e => do
if ( readThe Context).dischargeDepth > ( getConfig).maxDischargeDepth then
return none
return .failed
withoutModifyingCache do
withTheReader Context (fun ctx => { ctx with dischargeDepth := ctx.dischargeDepth + 1 }) do
return resultToOptionProof e ( simp e)
return resultToDischargeResult e ( simp e)
/--
A discharger that fails to prove any side condition.
This is used when conditional rewrite rules should not be applied. It immediately
returns `none` for all propositions, effectively disabling conditional rewriting.
returns `.failed` for all propositions, effectively disabling conditional rewriting.
**Use cases**:
- Testing: Isolating unconditional rewriting behavior
@@ -116,6 +120,17 @@ returns `none` for all propositions, effectively disabling conditional rewriting
- Controlled rewriting: Explicitly disabling conditional rules in specific contexts
-/
public def dischargeNone : Discharger := fun _ =>
return none
return .failed
/--
A discharger for testing that proves side conditions using hypotheses from the
local context. Always context-dependent: available hypotheses change when entering binders.
-/
public def dischargeAssumption : Discharger := fun e => do
for localDecl in ( getLCtx) do
if localDecl.isAuxDecl then continue
if ( isDefEq localDecl.type e) then
return .solved localDecl.toExpr true
return .failed true
end Lean.Meta.Sym.Simp

View File

@@ -65,7 +65,7 @@ Operations dispatch on the type expression directly. It assumes non-standard ins
def skipIfUnchanged (e : Expr) (result : Result) : Result :=
match result with
| .step e' _ _ => if isSameExpr e e' then .rfl else result
| .step e' _ _ cd => if isSameExpr e e' then mkRflResultCD cd else result
| _ => result
abbrev evalUnary [ToExpr α] (toValue? : Expr Option α) (op : α α) (a : Expr) : SimpM Result := do

View File

@@ -98,40 +98,52 @@ partial def simpArrows (e : Expr) (infos : List ArrowInfo) (simpBody : Simproc)
| [] => return (( simpBody e), [])
| info :: infos' =>
let_expr f@Arrow p q := e | return (( simpBody e), infos)
-- **cd propagation**: `cd` from `simp p` and `simp q` (via recursive `simpArrows`)
-- must be propagated through ALL branches. Even when `p` is already False/True
-- and hasn't changed, `cd = true` means `simp p` might take a different code path
-- in another local context (e.g., a conditional rewrite could succeed), which would
-- lead to a completely different result for the arrow expression.
let p_r simp p
if ( isFalseExpr (p_r.getResultExpr p)) && info.v.isZero then
match p_r with
| .rfl _ => return (.step ( getTrueExpr) (mkApp (mkConst ``false_arrow) q), [])
| .step _ h _ => return (.step ( getTrueExpr) (mkApp3 (mkConst ``false_arrow_congr) p q h), [])
| .rfl _ cd => return (.step ( getTrueExpr) (mkApp (mkConst ``false_arrow) q) (contextDependent := cd), [])
| .step _ h _ cd => return (.step ( getTrueExpr) (mkApp3 (mkConst ``false_arrow_congr) p q h) (contextDependent := cd), [])
let (q_r, infos') simpArrows q infos' simpBody
if ( isTrueExpr (q_r.getResultExpr q)) then
let cd := p_r.isContextDependent || q_r.isContextDependent
match q_r with
| .rfl _ => return (.step ( getTrueExpr) (mkApp (mkConst ``arrow_true [info.u]) p), [])
| .step _ h _ => return (.step ( getTrueExpr) (mkApp3 (mkConst ``arrow_true_congr [info.u]) p q h), [])
| .rfl _ _ => return (.step ( getTrueExpr) (mkApp (mkConst ``arrow_true [info.u]) p) (contextDependent := cd), [])
| .step _ h _ _ => return (.step ( getTrueExpr) (mkApp3 (mkConst ``arrow_true_congr [info.u]) p q h) (contextDependent := cd), [])
match p_r, q_r with
| .rfl _, .rfl _ =>
| .rfl _ cd₁, .rfl _ cd₂ =>
let cd := cd₁ || cd₂
if ( isTrueExpr p) && info.v.isZero then
return (.step q (mkApp (mkConst ``true_arrow) q), infos')
return (.step q (mkApp (mkConst ``true_arrow) q) (contextDependent := cd), infos')
else
return (.rfl, infos)
| .step p' h _, .rfl _ =>
return (mkRflResultCD cd, infos)
| .step p' h _ cd₁, .rfl _ cd₂ =>
let cd := cd₁ || cd₂
if ( isTrueExpr p') && info.v.isZero then
return (.step q (mkApp3 (mkConst ``true_arrow_congr_left) p q h), infos')
return (.step q (mkApp3 (mkConst ``true_arrow_congr_left) p q h) (contextDependent := cd), infos')
else
let e' mkAppS₂ f p' q
return (.step e' <| mkApp4 (mkConst ``arrow_congr_left f.constLevels!) p p' q h, info :: infos')
| .rfl _, .step q' h _ =>
let proof := mkApp4 (mkConst ``arrow_congr_left f.constLevels!) p p' q h
return (.step e' proof (contextDependent := cd), info :: infos')
| .rfl _ cd₁, .step q' h _ cd₂ =>
let cd := cd₁ || cd₂
if ( isTrueExpr p) && info.v.isZero then
return (.step q' (mkApp3 (mkConst ``true_arrow_congr_right) q q' h), infos')
return (.step q' (mkApp3 (mkConst ``true_arrow_congr_right) q q' h) (contextDependent := cd), infos')
else
let e' mkAppS₂ f p q'
return (.step e' <| mkApp4 (mkConst ``arrow_congr_right f.constLevels!) p q q' h, info :: infos')
| .step p' h₁ _, .step q' h₂ _ =>
let proof := mkApp4 (mkConst ``arrow_congr_right f.constLevels!) p q q' h
return (.step e' proof (contextDependent := cd), info :: infos')
| .step p' h₁ _ cd₁, .step q' h₂ _ cd₂ =>
if ( isTrueExpr p') && info.v.isZero then
return (.step q' (mkApp5 (mkConst ``true_arrow_congr) p q q' h₁ h₂), infos')
return (.step q' (mkApp5 (mkConst ``true_arrow_congr) p q q' h₁ h₂) (contextDependent := cd₁ || cd₂), infos')
else
let e' mkAppS₂ f p' q'
return (.step e' <| mkApp6 (mkConst ``arrow_congr f.constLevels!) p p' q q' h₁ h₂, info :: infos')
let proof := mkApp6 (mkConst ``arrow_congr f.constLevels!) p p' q q' h₁ h₂
return (.step e' proof (contextDependent := cd₁ || cd₂), info :: infos')
/--
Simplifies a telescope of non-dependent arrows `p₁ → p₂ → ... → pₙ → q` by:
@@ -159,55 +171,60 @@ result as fully simplified to prevent `simpArrow` from being applied.
public def simpArrowTelescope (simpBody : Simproc := simp) : Simproc := fun e => do
unless e.isArrow do return .rfl -- not applicable
let { arrow, infos, v } toArrow e
let (.step arrow' h _, infos) simpArrows arrow infos simpBody | return .rfl (done := true)
match ( simpArrows arrow infos simpBody) with
| (.rfl _ cd, _) => return mkRflResult (done := true) (contextDependent := cd)
| (.step arrow' h _ cd, infos) =>
let e' toForall arrow' infos
let α := mkSort v
let v1 := v.succ
let h := mkApp6 (mkConst ``Eq.trans [v1]) α e arrow arrow' (mkApp2 (mkConst ``Eq.refl [v1]) α arrow) h
let h := mkApp6 (mkConst ``Eq.trans [v1]) α e arrow' e' h (mkApp2 (mkConst ``Eq.refl [v1]) α e')
return .step e' h (done := true)
return .step e' h (done := true) (contextDependent := cd)
public def simpArrow (e : Expr) : SimpM Result := do
let p := e.bindingDomain!
let q := e.bindingBody!
-- Propagate `cd` from both domain and codomain: if either sub-simplification
-- was context-dependent, `simp` might take a different path in another context.
match ( simp p), ( simp q) with
| .rfl _, .rfl _ =>
return .rfl
| .step p' h _, .rfl _ =>
| .rfl _ cd₁, .rfl _ cd₂ =>
return mkRflResultCD (cd₁ || cd₂)
| .step p' h _ cd₁, .rfl _ cd₂ =>
let u getLevel p
let v getLevel q
let e' e.updateForallS! p' q
return .step e' <| mkApp4 (mkConst ``implies_congr_left [u, v]) p p' q h
| .rfl _, .step q' h _ =>
return .step e' (mkApp4 (mkConst ``implies_congr_left [u, v]) p p' q h) (contextDependent := cd₁ || cd₂)
| .rfl _ cd₁, .step q' h _ cd₂ =>
let u getLevel p
let v getLevel q
let e' e.updateForallS! p q'
return .step e' <| mkApp4 (mkConst ``implies_congr_right [u, v]) p q q' h
| .step p' h₁ _, .step q' h₂ _ =>
return .step e' (mkApp4 (mkConst ``implies_congr_right [u, v]) p q q' h) (contextDependent := cd₁ || cd₂)
| .step p' h₁ _ cd₁, .step q' h₂ _ cd₂ =>
let u getLevel p
let v getLevel q
let e' e.updateForallS! p' q'
return .step e' <| mkApp6 (mkConst ``implies_congr [u, v]) p p' q q' h₁ h₂
return .step e' (mkApp6 (mkConst ``implies_congr [u, v]) p p' q q' h₁ h₂) (contextDependent := cd₁ || cd₂)
public def simpForall' (simpArrow : Simproc) (simpBody : Simproc) (e : Expr) : SimpM Result := do
if e.isArrow then
simpArrow e
else if ( isProp e) then
let n := getForallTelescopeSize e.bindingBody! 1
forallBoundedTelescope e n fun xs b => withoutModifyingCacheIfNotWellBehaved do
forallBoundedTelescope e n fun xs b => withFreshTransientCache do
main xs ( shareCommon b)
else
return .rfl
where
main (xs : Array Expr) (b : Expr) : SimpM Result := do
-- Propagate `cd` from the body: in another context the body might simplify differently.
match ( simpBody b) with
| .rfl _ => return .rfl
| .step b' h _ =>
| .rfl _ cd => return mkRflResultCD cd
| .step b' h _ cd =>
let h mkLambdaFVars xs h
let e' shareCommon ( mkForallFVars xs b')
-- **Note**: consider caching the forall-congr theorems
let hcongr mkForallCongrFor xs
return .step e' (mkApp3 hcongr ( mkLambdaFVars xs b) ( mkLambdaFVars xs b') h)
return .step e' (mkApp3 hcongr ( mkLambdaFVars xs b) ( mkLambdaFVars xs b') h) (contextDependent := cd)
-- **Note**: Optimize if this is quadratic in practice
getForallTelescopeSize (e : Expr) (n : Nat) : Nat :=

View File

@@ -44,8 +44,8 @@ Converts a `Simp.Result` value into `SimpGoalResult`.
public def Simp.Result.toSimpGoalResult (result : Simp.Result) (mvarId : MVarId) : SymM SimpGoalResult := do
let decl mvarId.getDecl
match result with
| .rfl _ => return .noProgress
| .step target' h _ =>
| .rfl _ _ => return .noProgress
| .step target' h _ _ =>
let mvarNew mkFreshExprSyntheticOpaqueMVar target' decl.userName
let u getLevel decl.type
let h := mkApp4 (mkConst ``Eq.mpr [u]) decl.type target' h mvarNew

View File

@@ -325,21 +325,22 @@ where
match e with
| .app f a =>
let (rf, fType) go f (i-1)
-- Propagate `cd` from both function and argument sub-results.
let r match rf, ( simp a) with
| .rfl _, .rfl _ =>
pure .rfl
| .step f' hf _, .rfl _ =>
| .rfl _ cd₁, .rfl _ cd₂ =>
pure (mkRflResultCD (cd₁ || cd₂))
| .step f' hf _ cd₁, .rfl _ cd₂ =>
let e' mkAppS f' a
let h := mkApp4 ( mkCongrPrefix ``congrFun' fType i) f f' hf a
pure <| .step e' h
| .rfl _, .step a' ha _ =>
pure <| .step e' h (contextDependent := cd₁ || cd₂)
| .rfl _ cd₁, .step a' ha _ cd₂ =>
let e' mkAppS f a'
let h := mkApp4 ( mkCongrPrefix ``congrArg fType i) a a' f ha
pure <| .step e' h
| .step f' hf _, .step a' ha _ =>
pure <| .step e' h (contextDependent := cd₁ || cd₂)
| .step f' hf _ cd₁, .step a' ha _ cd₂ =>
let e' mkAppS f' a'
let h := mkApp6 ( mkCongrPrefix ``congr fType i) f f' a a' hf ha
pure <| .step e' h
pure <| .step e' h (contextDependent := cd₁ || cd₂)
return (r, fType.bindingBody!)
| .lam .. => return ( simpBody e, fType)
| _ => unreachable!
@@ -383,15 +384,15 @@ def simpHaveCore (e : Expr) (simpBody : Simproc) : SimpM SimpHaveResult := do
let e₂ := r.e
let { fnUnivs, argUnivs } getUnivs r.fType
match ( simpBetaApp e₂ r.fType fnUnivs argUnivs simpBody) with
| .rfl _ => return { result := .rfl, α := r.α, u := r.u }
| .step e₃ h _ =>
| .rfl _ cd => return { result := mkRflResultCD cd, α := r.α, u := r.u }
| .step e₃ h _ cd =>
let h₁ := mkApp6 (mkConst ``Eq.trans [r.u]) r.α e₁ e₂ e₃ r.h h
let e₄ toHave e₃ r.varDeps
let eq := mkApp3 (mkConst ``Eq [r.u]) r.α e₃ e₄
let h₂ := mkApp2 (mkConst ``Eq.refl [r.u]) r.α e₃
let h₂ := mkExpectedPropHint h₂ eq
let h := mkApp6 (mkConst ``Eq.trans [r.u]) r.α e₁ e₃ e₄ h₁ h₂
return { result := .step e₄ h, α := r.α, u := r.u }
return { result := .step e₄ h (contextDependent := cd), α := r.α, u := r.u }
/--
Simplify a `have`-telescope.
@@ -411,21 +412,21 @@ avoiding quadratic behavior from multiple passes.
public def simpHaveAndZetaUnused (e₁ : Expr) (simpBody : Simproc) : SimpM Result := do
let r simpHaveCore e₁ simpBody
match r.result with
| .rfl _ =>
| .rfl _ cd =>
let e₂ zetaUnused e₁
if isSameExpr e₁ e₂ then
return .rfl
return mkRflResultCD cd
else
let h := mkApp2 (mkConst ``Eq.refl [r.u]) r.α e₂
return .step e₂ h
| .step e₂ h _ =>
return .step e₂ h (contextDependent := cd)
| .step e₂ h _ cd =>
let e₃ zetaUnused e₂
if isSameExpr e₂ e₃ then
return r.result
else
let h := mkApp6 (mkConst ``Eq.trans [r.u]) r.α e₁ e₂ e₃ h
(mkApp2 (mkConst ``Eq.refl [r.u]) r.α e₃)
return .step e₃ h
return .step e₃ h (contextDependent := cd)
public def simpLet' (simpBody : Simproc) (e : Expr) : SimpM Result := do
if !e.letNondep! then

View File

@@ -46,17 +46,18 @@ def mkFunextFor (xs : Array Expr) (β : Expr) : MetaM Expr := do
return result
public def simpLambda' (simpBody : Simproc) (e : Expr) : SimpM Result := do
lambdaTelescope e fun xs b => withoutModifyingCacheIfNotWellBehaved do
lambdaTelescope e fun xs b => withFreshTransientCache do
main xs ( shareCommon b)
where
main (xs : Array Expr) (b : Expr) : SimpM Result := do
-- Propagate `cd` from the body: in another context the body might simplify differently.
match ( simpBody b) with
| .rfl _ => return .rfl
| .step b' h _ =>
| .rfl _ cd => return mkRflResultCD cd
| .step b' h _ cd =>
let h mkLambdaFVars xs h
let e' shareCommon ( mkLambdaFVars xs b')
let funext getFunext xs b
return .step e' (mkApp3 funext e e' h)
return .step e' (mkApp3 funext e e' h) (contextDependent := cd)
getFunext (xs : Array Expr) (b : Expr) : SimpM Expr := do
let key inferType e

View File

@@ -11,7 +11,10 @@ import Lean.Meta.Sym.Simp.Simproc
import Lean.Meta.Sym.Simp.App
import Lean.Meta.Sym.Simp.Have
import Lean.Meta.Sym.Simp.Forall
namespace Lean.Meta.Sym.Simp
builtin_initialize registerTraceClass `sym.simp.debug.cache
open Internal
def simpStep : Simproc := fun e => do
@@ -20,17 +23,21 @@ def simpStep : Simproc := fun e => do
| .proj .. =>
throwError "unexpected kernel projection term during simplification{indentExpr e}\npre-process and fold them as projection applications"
| .mdata m b =>
-- Propagate `cd` from inner term through the mdata wrapper.
let r simp b
match r with
| .rfl _ => return .rfl
| .step b' h _ => return .step ( mkMDataS m b') h
| .rfl _ cd => return mkRflResultCD cd
| .step b' h _ cd => return .step ( mkMDataS m b') h (contextDependent := cd)
| .lam .. => simpLambda e
| .forallE .. => simpForall e
| .letE .. => simpLet e
| .app .. => simpAppArgs e
abbrev cacheResult (e : Expr) (r : Result) : SimpM Result := do
modify fun s => { s with cache := s.cache.insert { expr := e } r }
if r.isContextDependent then
modify fun s => { s with transientCache := s.transientCache.insert { expr := e } r }
else
modify fun s => { s with persistentCache := s.persistentCache.insert { expr := e } r }
return r
@[export lean_sym_simp]
@@ -38,7 +45,12 @@ def simpImpl (e₁ : Expr) : SimpM Result := withIncRecDepth do
let numSteps := ( get).numSteps
if numSteps >= ( getConfig).maxSteps then
throwError "`simp` failed: maximum number of steps exceeded"
if let some result := ( getCache).find? { expr := e₁ } then
let key : ExprPtr := { expr := e₁ }
if let some result := ( get).persistentCache.find? key then
trace[sym.simp.debug.cache] "persistent cache hit: {e₁}"
return result
if let some result := ( get).transientCache.find? key then
trace[sym.simp.debug.cache] "transient cache hit: {e₁}"
return result
let numSteps := numSteps + 1
if numSteps % 1000 == 0 then
@@ -46,12 +58,16 @@ def simpImpl (e₁ : Expr) : SimpM Result := withIncRecDepth do
modify fun s => { s with numSteps }
let r₁ pre e₁
match r₁ with
| .rfl true | .step _ _ true => cacheResult e₁ r₁
| .step e₂ h₁ false => cacheResult e₁ ( mkEqTransResult e₁ e₂ h₁ ( simp e₂))
| .rfl false =>
| .rfl true _ | .step _ _ true _ => cacheResult e₁ r₁
| .step e₂ h₁ false cd₁ => cacheResult e₁ ( mkEqTransResult e₁ e₂ h₁ ( simp e₂) cd₁)
| .rfl false cd₁ =>
let r₂ (simpStep >> post) e₁
-- If `pre` was context-dependent (cd₁ = true) but returned `.rfl`, it might
-- succeed in another context. Propagate cd₁ so the cached result for `e₁`
-- lands in the transient cache and gets re-evaluated after binder entry.
let r₂ := if cd₁ && !r₂.isContextDependent then r₂.withContextDependent else r₂
match r₂ with
| .rfl _ | .step _ _ true => cacheResult e₁ r₂
| .step e₂ h₁ false => cacheResult e₁ ( mkEqTransResult e₁ e₂ h₁ ( simp e₂))
| .rfl _ _ | .step _ _ true _ => cacheResult e₁ r₂
| .step e₂ h₁ false cd₁ => cacheResult e₁ ( mkEqTransResult e₁ e₂ h₁ ( simp e₂) cd₁)
end Lean.Meta.Sym.Simp

View File

@@ -9,25 +9,31 @@ public import Lean.Meta.Sym.Simp.SimpM
public import Lean.Meta.Sym.InferType
namespace Lean.Meta.Sym.Simp
public abbrev Result.isRfl (result : Result) : Bool :=
result matches .rfl
public abbrev Result.isRfl : Result Bool
| .rfl false _ => true
| _ => false
public def mkEqTrans (e₁ : Expr) (e₂ : Expr) (h₁ : Expr) (e₃ : Expr) (h₂ : Expr) : SymM Expr := do
let α Sym.inferType e₁
let u Sym.getLevel α
return mkApp6 (mkConst ``Eq.trans [u]) α e₁ e₂ e₃ h₁ h₂
public abbrev mkEqTransResult (e₁ : Expr) (e₂ : Expr) (h₁ : Expr) (r₂ : Result) : SymM Result :=
/-- Chains two simplification steps via `Eq.trans`.
`cd₁` is the `contextDependent` flag from the first step (whose proof is `h₁`).
The output is context-dependent if either step was: in another local context,
either step might produce a different result, changing the whole chain. -/
public abbrev mkEqTransResult (e₁ : Expr) (e₂ : Expr) (h₁ : Expr) (r₂ : Result)
(cd₁ : Bool := false) : SymM Result :=
match r₂ with
| .rfl done => return .step e₂ h₁ done
| .step e₃ h₂ done => return .step e₃ ( mkEqTrans e₁ e₂ h₁ e₃ h₂) done
| .rfl done cd₂ => return .step e₂ h₁ done (cd₁ || cd₂)
| .step e₃ h₂ done cd₂ => return .step e₃ ( mkEqTrans e₁ e₂ h₁ e₃ h₂) done (cd₁ || cd₂)
public def Result.markAsDone : Result Result
| .rfl _ => .rfl true
| .step e h _ => .step e h true
| .rfl _ cd => .rfl true cd
| .step e h _ cd => .step e h true cd
public def Result.getResultExpr : Expr Result Expr
| e, .rfl _ => e
| _, .step e _ _ => e
| e, .rfl _ _ => e
| _, .step e _ _ _ => e
end Lean.Meta.Sym.Simp

View File

@@ -40,6 +40,10 @@ public def Theorem.rewrite (thm : Theorem) (e : Expr) (d : Discharger := dischar
-- **Note**: Potential optimization: check whether pattern covers all variables.
let mut args := result.args.toVector
let us result.us.mapM instantiateLevelMVars
-- Track whether any discharger used context-dependent information.
-- If so, the result is context-dependent: in another context, the discharger
-- might succeed/fail differently, changing whether the rewrite applies.
let mut isCD := false
for h : i in *...args.size do
let arg := args[i]
if let .mvar mvarId := arg then
@@ -48,13 +52,16 @@ public def Theorem.rewrite (thm : Theorem) (e : Expr) (d : Discharger := dischar
args := args.set i arg
else
let decl mvarId.getDecl
if let some val d decl.type then
match ( d decl.type) with
| .failed cd =>
isCD := isCD || cd
-- Failed to discharge hypothesis.
return mkRflResultCD isCD
| .solved val cd =>
isCD := isCD || cd
let val instantiateMVarsS val
mvarId.assign val
args := args.set i val
else
-- **Note**: Failed to discharge hypothesis.
return .rfl
else if arg.hasMVar then
let arg instantiateMVarsS arg
args := args.set i arg
@@ -63,20 +70,25 @@ public def Theorem.rewrite (thm : Theorem) (e : Expr) (d : Discharger := dischar
let rhs share rhs
let expr instantiateRevBetaS rhs args.toArray
if isSameExpr e expr then
return .rfl
return mkRflResultCD isCD
else
return .step expr proof
return .step expr proof (contextDependent := isCD)
else
return .rfl
public def Theorems.rewrite (thms : Theorems) (d : Discharger := dischargeNone) : Simproc := fun e => do
-- Track `cd` across all attempted theorems. If theorem A fails with cd=true
-- and theorem B succeeds with cd=false, the result is still cd=true: in another
-- context A might succeed (with higher priority) and produce a different result.
let mut anyCD := false
for (thm, numExtra) in thms.getMatchWithExtra e do
let result if numExtra == 0 then
thm.rewrite e d
else
simpOverApplied e numExtra (thm.rewrite · d)
anyCD := anyCD || result.isContextDependent
if !result.isRfl then
return result
return .rfl
return if anyCD && !result.isContextDependent then result.withContextDependent else result
return mkRflResultCD anyCD
end Lean.Meta.Sym.Simp

View File

@@ -144,16 +144,53 @@ The `done` flag affects:
The flag is orthogonal to caching: both `.rfl` and `.step` results are cached
regardless of the `done` flag, and cached results are always treated as final.
## Context-dependent results
The `contextDependent` flag tracks whether the result depends on the local context
(e.g., hypotheses introduced when entering binders). Context-dependent results are
stored in a transient cache that is cleared when entering binders, while
context-independent results persist across binder entry and across `simp` invocations.
**Propagation rule**: when combining sub-results (congruence, transitivity, etc.),
`contextDependent` is the disjunction of all sub-results' flags. This includes
`.rfl` results: if `simp` returned `.rfl (contextDependent := true)`, it means
`simp` might produce a *different* result in another local context (e.g., a conditional
rewrite could succeed), so all downstream results must be marked context-dependent.
-/
inductive Result where
/-- No change. If `done = true`, skip remaining simplification steps for this term. -/
| rfl (done : Bool := false)
| rfl (done : Bool := false) (contextDependent : Bool := false)
/--
Simplified to `e'` with proof `proof : e = e'`.
If `done = true`, skip recursive simplification of `e'`. -/
| step (e' : Expr) (proof : Expr) (done : Bool := false)
| step (e' : Expr) (proof : Expr) (done : Bool := false) (contextDependent : Bool := false)
deriving Inhabited
/--
Pre-computed `.rfl` results to avoid dynamic memory allocation.
Each combination of `done` and `contextDependent` maps to a compile-time constant.
-/
public def mkRflResult (done : Bool := false) (contextDependent : Bool := false) : Result :=
match done, contextDependent with
| false, false => .rfl
| false, true => .rfl false true
| true, false => .rfl true
| true, true => .rfl true true
/-- Like `mkRflResult` with `done := false`. -/
public def mkRflResultCD (contextDependent : Bool) : Result :=
if contextDependent then .rfl false true else .rfl
/-- Returns `true` if this result depends on the local context (e.g., hypotheses). -/
public abbrev Result.isContextDependent : Result Bool
| .rfl _ cd | .step _ _ _ cd => cd
/-- Marks a result as context-dependent. -/
public def Result.withContextDependent : Result Result
| .rfl done _ => .rfl done true
| .step e h done _ => .step e h done true
private opaque MethodsRefPointed : NonemptyType.{0}
def MethodsRef : Type := MethodsRefPointed.type
instance : Nonempty MethodsRef := by exact MethodsRefPointed.property
@@ -175,10 +212,15 @@ structure State where
/-- Number of steps performed so far. -/
numSteps := 0
/--
Cache of previously simplified expressions to avoid redundant work.
**Note**: Consider moving to `SymM.State`
Cache for context-independent results. Survives across binder entry and
across `simp` invocations within a `sym =>` block.
-/
cache : Cache := {}
persistentCache : Cache := {}
/--
Cache for context-dependent results. Cleared when entering binders
(where new hypotheses may change the result).
-/
transientCache : Cache := {}
/-- Cache for generated funext theorems -/
funext : PHashMap ExprPtr Expr := {}
@@ -193,13 +235,6 @@ abbrev Simproc := Expr → SimpM Result
structure Methods where
pre : Simproc := fun _ => return .rfl
post : Simproc := fun _ => return .rfl
/--
`wellBehavedMethods` must **not** be set to `true` IF their behavior
depends on new hypotheses in the local context. For example, for applying
conditional rewrite rules.
Reason: it would prevent us from aggressively caching `simp` results.
-/
wellBehavedMethods : Bool := true
deriving Inhabited
unsafe def Methods.toMethodsRefImpl (m : Methods) : MethodsRef :=
@@ -217,12 +252,15 @@ opaque MethodsRef.toMethods (m : MethodsRef) : Methods
def getMethods : SimpM Methods :=
return MethodsRef.toMethods ( read)
/-- Runs a `SimpM` computation with the given theorems, configuration, and initial state -/
def SimpM.run (x : SimpM α) (methods : Methods := {}) (config : Config := {}) (s : State := {}) : SymM (α × State) := do
/-- Runs a `SimpM` computation with the given methods, configuration, and state.
The `transientCache` is always reset (context-dependent results don't survive across
invocations). The `persistentCache` and `funext` cache are preserved from `s`. -/
def SimpM.run (x : SimpM α) (methods : Methods := {}) (config : Config := {})
(s : State := {}) : SymM (α × State) := do
let initialLCtxSize := ( getLCtx).decls.size
x methods.toMethodsRef { initialLCtxSize, config } |>.run s
x methods.toMethodsRef { initialLCtxSize, config } |>.run { s with transientCache := {}, numSteps := 0 }
/-- Runs a `SimpM` computation with the given theorems and configuration. -/
/-- Runs a `SimpM` computation with the given methods and configuration. -/
def SimpM.run' (x : SimpM α) (methods : Methods := {}) (config : Config := {}) : SymM α := do
let initialLCtxSize := ( getLCtx).decls.size
x methods.toMethodsRef { initialLCtxSize, config } |>.run' {}
@@ -234,22 +272,25 @@ opaque simp : Simproc
def getConfig : SimpM Config :=
return ( readThe Context).config
abbrev getCache : SimpM Cache :=
return ( get).cache
abbrev pre : Simproc := fun e => do
( getMethods).pre e
abbrev post : Simproc := fun e => do
( getMethods).post e
/-- Saves and restores both caches and funext. Used by dischargers. -/
abbrev withoutModifyingCache (k : SimpM α) : SimpM α := do
let cache getCache
let persistentCache := ( get).persistentCache
let transientCache := ( get).transientCache
let funext := ( get).funext
try k finally modify fun s => { s with cache, funext }
try k finally modify fun s => { s with persistentCache, transientCache, funext }
abbrev withoutModifyingCacheIfNotWellBehaved (k : SimpM α) : SimpM α := do
if ( getMethods).wellBehavedMethods then k else withoutModifyingCache k
/-- Saves and restores the transient cache and funext, leaving
the persistent cache untouched. Used when entering binders. -/
abbrev withFreshTransientCache (k : SimpM α) : SimpM α := do
let transientCache := ( get).transientCache
let funext := ( get).funext
try k finally modify fun s => { s with transientCache, funext }
end Simp

View File

@@ -11,9 +11,15 @@ namespace Lean.Meta.Sym.Simp
public abbrev Simproc.andThen (f g : Simproc) : Simproc := fun e₁ => do
let r f e₁
match r with
| .step _ _ true | .rfl true => return r
| .rfl false => g e₁
| .step e₂ h₁ false => mkEqTransResult e₁ e₂ h₁ ( g e₂)
| .step _ _ true _ | .rfl true _ => return r
-- Propagate cd₁: if `f` was context-dependent but returned `.rfl`, in another
-- context `f` might succeed and the whole `andThen` would take a different path.
| .rfl false cd₁ =>
let r₂ g e₁
return if cd₁ && !r₂.isContextDependent then r₂.withContextDependent else r₂
-- `cd₁` from `f` is threaded into `mkEqTransResult` so the combined result
-- is context-dependent if either `f` or `g` was.
| .step e₂ h₁ false cd₁ => mkEqTransResult e₁ e₂ h₁ ( g e₂) cd₁
public instance : AndThen Simproc where
andThen f g := Simproc.andThen f (g ())
@@ -21,8 +27,12 @@ public instance : AndThen Simproc where
public abbrev Simproc.orElse (f g : Simproc) : Simproc := fun e₁ => do
let r f e₁
match r with
| .step _ _ _ | .rfl true => return r
| .rfl false => g e₁
| .step _ _ _ _ | .rfl true _ => return r
-- Propagate cd₁: if `f` was context-dependent but returned `.rfl`, in another
-- context `f` might succeed and the `orElse` would return `f`'s result instead.
| .rfl false cd₁ =>
let r₂ g e₁
return if cd₁ && !r₂.isContextDependent then r₂.withContextDependent else r₂
public instance : OrElse Simproc where
orElse f g := Simproc.orElse f (g ())

View File

@@ -17,14 +17,14 @@ if it reduces to `True`, returns `True` immediately without evaluating the right
builtin_cbv_simproc simpOr (@Or _ _) := fun e => do
let_expr Or a b := e | return .rfl
match ( simp a) with
| .rfl _ =>
| .rfl _ _ =>
if ( isTrueExpr a) then
return .step ( getTrueExpr) (mkApp (mkConst ``true_or) b) (done := true)
else if ( isFalseExpr a) then
return .step b (mkApp (mkConst ``false_or) b)
else
return .rfl
| .step a' ha _ =>
| .step a' ha _ _ =>
if ( isTrueExpr a') then
return .step ( getTrueExpr) (mkApp (e.replaceFn ``Sym.or_eq_true_left) ha) (done := true)
else if ( isFalseExpr a') then
@@ -37,14 +37,14 @@ if it reduces to `False`, returns `False` immediately without evaluating the rig
builtin_cbv_simproc simpAnd (@And _ _) := fun e => do
let_expr And a b := e | return .rfl
match ( simp a) with
| .rfl _ =>
| .rfl _ _ =>
if ( isFalseExpr a) then
return .step ( getFalseExpr) (mkApp (mkConst ``false_and) b) (done := true)
else if ( isTrueExpr a) then
return .step b (mkApp (mkConst ``true_and) b)
else
return .rfl
| .step a' ha _ =>
| .step a' ha _ _ =>
if ( isFalseExpr a') then
return .step ( getFalseExpr) (mkApp (e.replaceFn ``Sym.and_eq_false_left) ha) (done := true)
else if ( isTrueExpr a') then

View File

@@ -262,14 +262,14 @@ def cbvSimprocDispatch (tree : DiscrTree CbvSimprocEntry)
let simprocName := (privateToUserName entry.declName).replacePrefix `Lean.Meta.Sym.Simp .anonymous |>.replacePrefix `Lean.Meta.Tactic.Cbv .anonymous
let result withTraceNode `Meta.Tactic.cbv.simprocs (fun
| .ok (Result.step e' ..) => return m!"simproc {simprocName}:{indentExpr e}\n==>{indentExpr e'}"
| .ok (Result.rfl true) => return m!"simproc {simprocName}: done{indentExpr e}"
| .ok (Result.rfl true _) => return m!"simproc {simprocName}: done{indentExpr e}"
| .ok _ => return m!"simproc {simprocName}: no change"
| .error err => return m!"simproc {simprocName}: {err.toMessageData}") do
if numExtra == 0 then
entry.proc e
else
simpOverApplied e numExtra entry.proc
if result matches .step _ _ _ then
if result matches .step _ _ _ _ then
return result
if result matches .rfl (done := true) then
return result

View File

@@ -71,15 +71,24 @@ def matchIteDecidableCongr (f α c inst a b c' h inst' : Expr) (fallback : SimpM
/-- Simplify the `Decidable` instance, then try `simpIteDecidable`. -/
def simpAndMatchIteDecidable (f α c inst a b : Expr) (fallback : SimpM Result) : SimpM Result := do
-- Propagate cd from `simp inst`: in another context the instance might simplify differently.
match ( simp inst) with
| .rfl _ => matchIteDecidable f α c inst a b inst fallback
| .step inst' _ _ => matchIteDecidable f α c inst a b inst' fallback
| .rfl _ cd =>
let r matchIteDecidable f α c inst a b inst fallback
return if cd && !r.isContextDependent then r.withContextDependent else r
| .step inst' _ _ cd =>
let r matchIteDecidable f α c inst a b inst' fallback
return if cd && !r.isContextDependent then r.withContextDependent else r
/-- Like `simpAndMatchIteDecidable`, but for the congruence case where `c` was simplified to `c'`. -/
def simpAndMatchIteDecidableCongr (f α c inst a b c' h inst' : Expr) (fallback : SimpM Result) : SimpM Result := do
match ( simp inst') with
| .rfl _ => matchIteDecidableCongr f α c inst a b c' h inst' fallback
| .step inst'' _ _ => matchIteDecidableCongr f α c inst a b c' h inst'' fallback
| .rfl _ cd =>
let r matchIteDecidableCongr f α c inst a b c' h inst' fallback
return if cd && !r.isContextDependent then r.withContextDependent else r
| .step inst'' _ _ cd =>
let r matchIteDecidableCongr f α c inst a b c' h inst'' fallback
return if cd && !r.isContextDependent then r.withContextDependent else r
/-- Like `simpIte` but also evaluates `Decidable.decide` when the condition does not
reduce to `True`/`False` directly. -/
@@ -88,19 +97,20 @@ builtin_cbv_simproc ↓ simpIteCbv (@ite _ _ _ _ _) := fun e => do
if numArgs < 5 then return .rfl (done := true)
propagateOverApplied e (numArgs - 5) fun e => do
let_expr f@ite α c inst a b := e | return .rfl
-- See Sym.Simp.ControlFlow.simpIte for why cd is propagated to all branches.
match ( simp c) with
| .rfl _ =>
| .rfl _ cd =>
if ( isTrueExpr c) then
return .step a <| mkApp3 (mkConst ``ite_true f.constLevels!) α a b
return .step a (mkApp3 (mkConst ``ite_true f.constLevels!) α a b) (contextDependent := cd)
else if ( isFalseExpr c) then
return .step b <| mkApp3 (mkConst ``ite_false f.constLevels!) α a b
return .step b (mkApp3 (mkConst ``ite_false f.constLevels!) α a b) (contextDependent := cd)
else
simpAndMatchIteDecidable f α c inst a b do return .rfl (done := true)
| .step c' h _ =>
simpAndMatchIteDecidable f α c inst a b do return mkRflResult (done := true) (contextDependent := cd)
| .step c' h _ cd =>
if ( isTrueExpr c') then
return .step a <| mkApp (e.replaceFn ``ite_cond_eq_true) h
return .step a (mkApp (e.replaceFn ``ite_cond_eq_true) h) (contextDependent := cd)
else if ( isFalseExpr c') then
return .step b <| mkApp (e.replaceFn ``ite_cond_eq_false) h
return .step b (mkApp (e.replaceFn ``ite_cond_eq_false) h) (contextDependent := cd)
else
-- If we got stuck with simplifying `p` then let's try evaluating the original isntance
simpAndMatchIteDecidable f α c inst a b do
@@ -111,7 +121,7 @@ builtin_cbv_simproc ↓ simpIteCbv (@ite _ _ _ _ _) := fun e => do
let e' := e.getBoundedAppFn 4
let e' mkAppS₄ e' c' inst' a b
let h' := mkApp3 (e.replaceFn ``Sym.ite_cond_congr) c' inst' h
return .step e' h' (done := true)
return .step e' h' (done := true) (contextDependent := cd)
/-- Reduce `dite` by matching the `Decidable` instance for `isTrue`/`isFalse`. -/
def matchDIteDecidable (f α c inst a b instToMatch : Expr) (fallback : SimpM Result) : SimpM Result := do
@@ -140,14 +150,22 @@ def matchDIteDecidableCongr (f α c inst a b c' h inst' : Expr) (fallback : Simp
/-- Simplify the `Decidable` instance, then try `simpDIteDecidable`. -/
def simpAndMatchDIteDecidable (f α c inst a b : Expr) (fallback : SimpM Result) : SimpM Result := do
match ( simp inst) with
| .rfl _ => matchDIteDecidable f α c inst a b inst fallback
| .step inst' _ _ => matchDIteDecidable f α c inst a b inst' fallback
| .rfl _ cd =>
let r matchDIteDecidable f α c inst a b inst fallback
return if cd && !r.isContextDependent then r.withContextDependent else r
| .step inst' _ _ cd =>
let r matchDIteDecidable f α c inst a b inst' fallback
return if cd && !r.isContextDependent then r.withContextDependent else r
/-- Like `simpAndMatchDIteDecidable`, but for the congruence case where `c` was simplified to `c'`. -/
def simpAndMatchDIteDecidableCongr (f α c inst a b c' h inst' : Expr) (fallback : SimpM Result) : SimpM Result := do
match ( simp inst') with
| .rfl _ => matchDIteDecidableCongr f α c inst a b c' h inst' fallback
| .step inst'' _ _ => matchDIteDecidableCongr f α c inst a b c' h inst'' fallback
| .rfl _ cd =>
let r matchDIteDecidableCongr f α c inst a b c' h inst' fallback
return if cd && !r.isContextDependent then r.withContextDependent else r
| .step inst'' _ _ cd =>
let r matchDIteDecidableCongr f α c inst a b c' h inst'' fallback
return if cd && !r.isContextDependent then r.withContextDependent else r
/-- Like `simpDIte` but also evaluates `Decidable.decide` when the condition does not
reduce to `True`/`False` directly. -/
@@ -157,24 +175,24 @@ builtin_cbv_simproc ↓ simpDIteCbv (@dite _ _ _ _ _) := fun e => do
propagateOverApplied e (numArgs - 5) fun e => do
let_expr f@dite α c inst a b := e | return .rfl
match ( simp c) with
| .rfl _ =>
| .rfl _ cd =>
if ( isTrueExpr c) then
let a' share <| a.betaRev #[mkConst ``True.intro]
return .step a' <| mkApp3 (mkConst ``dite_true f.constLevels!) α a b
return .step a' (mkApp3 (mkConst ``dite_true f.constLevels!) α a b) (contextDependent := cd)
else if ( isFalseExpr c) then
let b' share <| b.betaRev #[mkConst ``not_false]
return .step b' <| mkApp3 (mkConst ``dite_false f.constLevels!) α a b
return .step b' (mkApp3 (mkConst ``dite_false f.constLevels!) α a b) (contextDependent := cd)
else
simpAndMatchDIteDecidable f α c inst a b do return .rfl (done := true)
| .step c' h _ =>
simpAndMatchDIteDecidable f α c inst a b do return mkRflResult (done := true) (contextDependent := cd)
| .step c' h _ cd =>
if ( isTrueExpr c') then
let h' shareCommon <| mkOfEqTrueCore c h
let a share <| a.betaRev #[h']
return .step a <| mkApp (e.replaceFn ``dite_cond_eq_true) h
return .step a (mkApp (e.replaceFn ``dite_cond_eq_true) h) (contextDependent := cd)
else if ( isFalseExpr c') then
let h' shareCommon <| mkOfEqFalseCore c h
let b share <| b.betaRev #[h']
return .step b <| mkApp (e.replaceFn ``dite_cond_eq_false) h
return .step b (mkApp (e.replaceFn ``dite_cond_eq_false) h) (contextDependent := cd)
else
-- If we get stuck after simplifying `p` to `p'`, then we try to evaluate the original instance
simpAndMatchDIteDecidable f α c inst a b do
@@ -187,7 +205,7 @@ builtin_cbv_simproc ↓ simpDIteCbv (@dite _ _ _ _ _) := fun e => do
let b share <| mkLambda `h .default (mkNot c') (b.betaRev #[mkApp4 (mkConst ``Eq.mpr_not) c c' h (mkBVar 0)])
let e' mkAppS₄ e' c' inst' a b
let h' := mkApp3 (e.replaceFn ``Sym.dite_cond_congr) c' inst' h
return .step e' h' (done := true)
return .step e' h' (done := true) (contextDependent := cd)
/-- Reduce `decide` by matching the `Decidable` instance for `isTrue`/`isFalse`. -/
def matchDecideDecidable (p inst instToMatch : Expr) (fallback : SimpM Result) : SimpM Result := do
@@ -210,14 +228,22 @@ def matchDecideDecidableCongr (p p' h inst inst' : Expr) (fallback : SimpM Resul
/-- Simplify the `Decidable` instance, then try `simpDecideByInst`. -/
def simpAndMatchDecideDecidable (p inst : Expr) (fallback : SimpM Result) : SimpM Result := do
match ( simp inst) with
| .rfl _ => matchDecideDecidable p inst inst fallback
| .step inst' _ _ => matchDecideDecidable p inst inst' fallback
| .rfl _ cd =>
let r matchDecideDecidable p inst inst fallback
return if cd && !r.isContextDependent then r.withContextDependent else r
| .step inst' _ _ cd =>
let r matchDecideDecidable p inst inst' fallback
return if cd && !r.isContextDependent then r.withContextDependent else r
/-- Like `simpDecideByInstWithFallback`, but for the case where `p` was simplified to `p'`. -/
def simpAndMatchDecideDecidableCongr (p p' h inst inst' : Expr) (fallback : SimpM Result) : SimpM Result := do
match ( simp inst') with
| .rfl _ => matchDecideDecidableCongr p p' h inst inst' fallback
| .step inst'' _ _ => matchDecideDecidableCongr p p' h inst inst'' fallback
| .rfl _ cd =>
let r matchDecideDecidableCongr p p' h inst inst' fallback
return if cd && !r.isContextDependent then r.withContextDependent else r
| .step inst'' _ _ cd =>
let r matchDecideDecidableCongr p p' h inst inst'' fallback
return if cd && !r.isContextDependent then r.withContextDependent else r
/-- Simplify `Decidable.decide` by simplifying the proposition and reducing the instance.
@@ -232,18 +258,18 @@ builtin_cbv_simproc ↓ simpDecideCbv (@Decidable.decide _ _) := fun e => do
propagateOverApplied e (numArgs - 2) fun e => do
let_expr Decidable.decide p inst := e | return .rfl
match ( simp p) with
| .rfl _ =>
| .rfl _ cd =>
if ( isTrueExpr p) then
return .step ( getBoolTrueExpr) (mkApp (mkConst ``decide_true) inst)
return .step ( getBoolTrueExpr) (mkApp (mkConst ``decide_true) inst) (contextDependent := cd)
else if ( isFalseExpr p) then
return .step ( getBoolFalseExpr) (mkApp (mkConst ``decide_false) inst)
return .step ( getBoolFalseExpr) (mkApp (mkConst ``decide_false) inst) (contextDependent := cd)
else
simpAndMatchDecideDecidable p inst do return .rfl (done := true)
| .step p' hp _ =>
simpAndMatchDecideDecidable p inst do return mkRflResult (done := true) (contextDependent := cd)
| .step p' hp _ cd =>
if ( isTrueExpr p') then
return .step ( getBoolTrueExpr) <| mkApp3 (mkConst ``Sym.decide_prop_eq_true) p inst hp
return .step ( getBoolTrueExpr) (mkApp3 (mkConst ``Sym.decide_prop_eq_true) p inst hp) (contextDependent := cd)
else if ( isFalseExpr p') then
return .step ( getBoolFalseExpr) <| mkApp3 (mkConst ``Sym.decide_prop_eq_false) p inst hp
return .step ( getBoolFalseExpr) (mkApp3 (mkConst ``Sym.decide_prop_eq_false) p inst hp) (contextDependent := cd)
else
let inst' trySynthComputableInstance p'
let inst' := inst'.getD <| mkApp4 (mkConst ``decidable_of_decidable_of_eq) p p' inst hp
@@ -251,7 +277,7 @@ builtin_cbv_simproc ↓ simpDecideCbv (@Decidable.decide _ _) := fun e => do
let res := (mkConst ``Decidable.decide)
let res shareCommon res
let res mkAppS₂ res p' inst'
return .step res (mkApp5 (mkConst ``Decidable.decide.congr_simp) p p' hp inst inst') (done := true)
return .step res (mkApp5 (mkConst ``Decidable.decide.congr_simp) p p' hp inst inst') (done := true) (contextDependent := cd)
end Lean.Meta.Sym.Simp

View File

@@ -198,20 +198,20 @@ def handleProj : Simproc := fun e => do
let Expr.proj typeName idx struct := e | return .rfl
withTraceNode `Debug.Meta.Tactic.cbv.reduce (fun
| .ok (Result.step e' ..) => return m!"proj `{typeName}`.{idx}:{indentExpr e}\n==>{indentExpr e'}"
| .ok (Result.rfl true) => return m!"proj `{typeName}`.{idx}: stuck{indentExpr e}"
| .ok (Result.rfl true _) => return m!"proj `{typeName}`.{idx}: stuck{indentExpr e}"
| .ok _ => return m!"proj `{typeName}`.{idx}: no change"
| .error err => return m!"proj `{typeName}`.{idx}: {err.toMessageData}") do
-- We recursively simplify the projection
let res simp struct
match res with
| .rfl _ =>
| .rfl _ _ =>
let some reduced withCbvOpaqueGuard <| reduceProj? <| .proj typeName idx struct | do
return .rfl (done := true)
-- TODO: Figure if we can share this term incrementally
let reduced Sym.share reduced
return .step reduced ( Sym.mkEqRefl reduced)
| .step e' proof _ =>
| .step e' proof _ _ =>
let type Sym.inferType e'
let congrArgFun := Lean.mkLambda `x .default type <| .proj typeName idx <| .bvar 0
let congrArgFunType inferType congrArgFun
@@ -250,8 +250,8 @@ def simplifyAppFn : Simproc := fun e => do
else
let res simp fn
match res with
| .rfl _ => return res
| .step e' proof _ =>
| .rfl _ _ => return res
| .step e' proof _ _ =>
let newType Sym.inferType e'
let congrArgFun := Lean.mkLambda `x .default newType (mkAppN (.bvar 0) e.getAppArgs)
let newValue mkAppNS e' e.getAppArgs
@@ -356,8 +356,8 @@ public def cbvGoal (mvarId : MVarId) (simplifyTarget : Bool := true) (fvarIdsToS
| .error err => return m!"hypothesis `{localDecl.userName}`: {err.toMessageData}") do
cbvCore type config
match result with
| .rfl _ => pure ()
| .step type' proof _ =>
| .rfl _ _ => pure ()
| .step type' proof _ _ =>
if type'.isFalse then
let u getLevel type
mvarIdNew.assign ( mkFalseElim ( mvarIdNew.getType) (mkApp4 (mkConst ``Eq.mp [u]) type type' proof (mkFVar fvarId)))
@@ -374,8 +374,8 @@ public def cbvGoal (mvarId : MVarId) (simplifyTarget : Bool := true) (fvarIdsToS
| .error err => return m!"target: {err.toMessageData}") do
cbvCore target config
match result with
| .rfl _ => pure ()
| .step target' proof _ =>
| .rfl _ _ => pure ()
| .step target' proof _ _ =>
if target'.isTrue then
mvarIdNew.assign ( mkOfEqTrue proof)
return none
@@ -417,8 +417,8 @@ public def cbvDecideGoal (m : MVarId) : MetaM Unit := do
else
throwError "`decide_cbv` failed: could not reduce the expression to a boolean value; got stuck at: {indentExpr e}"
match result with
| .rfl _ => checkResult lhs (m.refl)
| .step e' proof _ => checkResult e' (m.assign proof)
| .rfl _ _ => checkResult lhs (m.refl)
| .step e' proof _ _ => checkResult e' (m.assign proof)
end Lean.Meta.Tactic.Cbv

View File

@@ -9,13 +9,13 @@ namespace SimpBench
def getProofSize (r : Sym.Simp.Result) : MetaM Nat :=
match r with
| .rfl _ => return 0
| .step _ p _ => p.numObjs
| .rfl _ _ => return 0
| .step _ p _ _ => p.numObjs
def checkWithKernel (r : Sym.Simp.Result) : MetaM Float := do
match r with
| .rfl _ => return 0.0
| .step _ p _ =>
| .rfl _ _ => return 0.0
| .step _ p _ _ =>
let p := ShareCommon.shareCommon' p
let startTime IO.monoNanosNow
Meta.checkWithKernel p
@@ -36,8 +36,8 @@ def simp (e : Expr) : MetaM (Sym.Simp.Result × Float) := Sym.SymM.run do
let endTime IO.monoNanosNow
-- logInfo e
-- match r with
-- | .rfl => logInfo "rfl"
-- | .step e' h => logInfo e'; logInfo h; check h
-- | .rfl _ _ => logInfo "rfl"
-- | .step e' h _ _ => logInfo e'; logInfo h; check h
let timeMs := (endTime - startTime).toFloat / 1000000.0
return (r, timeMs)
@@ -57,8 +57,8 @@ def ppExample (e : Expr) (info := false) : MetaM Unit := do
IO.println ( ppExpr e)
IO.println "====>"
match ( simp e).1 with
| .rfl _ => IO.println "<no change>"
| .step e' h _ =>
| .rfl _ _ => IO.println "<no change>"
| .step e' h _ _ =>
IO.println ( ppExpr e')
IO.println "Proof:"
if info then

View File

@@ -9,13 +9,13 @@ namespace SimpBench
def getProofSize (r : Sym.Simp.Result) : MetaM Nat := do
match r with
| .rfl _ => return 0
| .step _ p _ => (ShareCommon.shareCommon' p).numObjs
| .rfl _ _ => return 0
| .step _ p _ _ => (ShareCommon.shareCommon' p).numObjs
def checkWithKernel (r : Sym.Simp.Result) : MetaM Float := do
match r with
| .rfl _ => return 0.0
| .step _ p _ =>
| .rfl _ _ => return 0.0
| .step _ p _ _ =>
let p := ShareCommon.shareCommon' p
let startTime IO.monoNanosNow
Meta.checkWithKernel p
@@ -37,8 +37,8 @@ def simp (e : Expr) : MetaM (Sym.Simp.Result × Float) := Sym.SymM.run do
let timeMs := (endTime - startTime).toFloat / 1000000.0
-- logInfo e
-- match r with
-- | .rfl _ => logInfo "rfl"
-- | .step e' h _ =>
-- | .rfl _ _ => logInfo "rfl"
-- | .step e' h _ _ =>
-- logInfo e'; logInfo h
return (r, timeMs)
@@ -48,8 +48,8 @@ def ppExample (e : Expr) (info := false) : MetaM Unit := do
IO.println ( ppExpr e)
IO.println "====>"
match ( simp e).1 with
| .rfl _ => IO.println "<no change>"
| .step e' h _ =>
| .rfl _ _ => IO.println "<no change>"
| .step e' h _ _ =>
IO.println ( ppExpr e')
IO.println "Proof:"
if info then

View File

@@ -9,13 +9,13 @@ namespace SimpBench
def getProofSize (r : Sym.Simp.Result) : MetaM Nat := do
match r with
| .rfl _ => return 0
| .step _ p _ => (ShareCommon.shareCommon' p).numObjs
| .rfl _ _ => return 0
| .step _ p _ _ => (ShareCommon.shareCommon' p).numObjs
def checkWithKernel (r : Sym.Simp.Result) : MetaM Float := do
match r with
| .rfl _ => return 0.0
| .step _ p _ =>
| .rfl _ _ => return 0.0
| .step _ p _ _ =>
let p := ShareCommon.shareCommon' p
let startTime IO.monoNanosNow
Meta.checkWithKernel p
@@ -37,8 +37,8 @@ def simp (e : Expr) : MetaM (Sym.Simp.Result × Float) := Sym.SymM.run do
let timeMs := (endTime - startTime).toFloat / 1000000.0
-- logInfo e
-- match r with
-- | .rfl _ => logInfo "rfl"
-- | .step e' h _ =>
-- | .rfl _ _ => logInfo "rfl"
-- | .step e' h _ _ =>
-- logInfo e'; logInfo h
return (r, timeMs)
@@ -48,8 +48,8 @@ def ppExample (e : Expr) : MetaM Unit := do
IO.println ( ppExpr e)
IO.println "====>"
match ( simp e).1 with
| .rfl _ => IO.println "<no change>"
| .step e' h _ =>
| .rfl _ _ => IO.println "<no change>"
| .step e' h _ _ =>
IO.println ( ppExpr e')
IO.println ( ppExpr h)
IO.println ""

View File

@@ -9,13 +9,13 @@ def implies (p q : Prop) := p → q
def getProofSize (r : Sym.Simp.Result) : MetaM Nat := do
match r with
| .rfl _ => return 0
| .step _ p _ => (ShareCommon.shareCommon' p).numObjs
| .rfl _ _ => return 0
| .step _ p _ _ => (ShareCommon.shareCommon' p).numObjs
def checkWithKernel (r : Sym.Simp.Result) : MetaM Float := do
match r with
| .rfl _ => return 0.0
| .step _ p _ =>
| .rfl _ _ => return 0.0
| .step _ p _ _ =>
let p := ShareCommon.shareCommon' p
let startTime IO.monoNanosNow
Meta.checkWithKernel p
@@ -40,8 +40,8 @@ def simp (e : Expr) (arrowTelescope : Bool) : MetaM (Sym.Simp.Result × Float) :
let timeMs := (endTime - startTime).toFloat / 1000000.0
-- logInfo e
-- match r with
-- | .rfl _ => logInfo "rfl"
-- | .step e' h _ =>
-- | .rfl _ _ => logInfo "rfl"
-- | .step e' h _ _ =>
-- logInfo e'; logInfo h
return (r, timeMs)
@@ -50,8 +50,8 @@ def ppExample (e : Expr) (arrowTelescope : Bool) (info := false) : MetaM Unit :=
IO.println ( ppExpr e)
IO.println "====>"
match ( simp e arrowTelescope).1 with
| .rfl _ => IO.println "<no change>"
| .step e' h _ =>
| .rfl _ _ => IO.println "<no change>"
| .step e' h _ _ =>
IO.println ( ppExpr e')
IO.println "Proof:"
if info then

194
tests/elab/sym_simp_cd.lean Normal file
View File

@@ -0,0 +1,194 @@
/-
Test for `contextDependent` two-tier cache in Sym.simp.
Uses `dischargeAssumption` (context-dependent) to verify:
- Context-independent results land in persistent cache and survive across invocations.
- Context-dependent results land in transient cache and are re-computed on second invocation.
-/
import Lean
open Lean Elab Tactic Meta
/-- Invoke simp twice on the same goal, threading the persistent cache. -/
elab "sym_simp_twice" "[" declNames:ident,* "]" : tactic => do
let rewrite Sym.mkSimprocFor ( declNames.getElems.mapM fun s => realizeGlobalConstNoOverload s.raw) Sym.Simp.dischargeAssumption
let methods : Sym.Simp.Methods := {
pre := Sym.Simp.simpControl.andThen Sym.Simp.simpArrowTelescope
post := Sym.Simp.evalGround.andThen rewrite
}
liftMetaTactic1 fun mvarId => Sym.SymM.run do
let mvarId Sym.preprocessMVar mvarId
let target := ( mvarId.getDecl).type
-- First invocation: builds the cache from scratch
let (_, state) Sym.Simp.SimpM.run (Sym.Simp.simp target) methods
trace[sym.simp.debug.cache] "second traversal"
-- Second invocation: persistent cache carries over, transient cache is cleared
let (result, _) Sym.Simp.SimpM.run (Sym.Simp.simp target) methods (s := state)
( result.toSimpGoalResult mvarId).toOption
-- Test 1: Ground evaluation is context-independent.
-- The second invocation should hit the persistent cache for the whole expression.
set_option trace.sym.simp.debug.cache true in
/--
trace: [sym.simp.debug.cache] second traversal
[sym.simp.debug.cache] persistent cache hit: 2 + 3 = 5
-/
#guard_msgs in
example : 2 + 3 = 5 := by
sym_simp_twice []
-- Test 2: Conditional rewrite using a hypothesis is context-dependent.
-- `dischargeAssumption` uses local hypothesis `h : 0 < n`, so the result is context-dependent
-- and lands in the transient cache. On the second invocation, the transient cache is
-- cleared, so there should be NO persistent cache hit for the overall expression.
-- Only context-independent sub-expressions (literals, fvars) get persistent cache hits.
theorem Nat.add_comm_of_pos (a b : Nat) (_h : 0 < a) : a + b = b + a := Nat.add_comm a b
set_option trace.sym.simp.debug.cache true in
/--
trace: [sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
[sym.simp.debug.cache] second traversal
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
-/
#guard_msgs in
example (n : Nat) (h : 0 < n) : n + 2 = 2 + n := by
sym_simp_twice [Nat.add_comm_of_pos]
-- Test 3: Congruence — cd propagates through function application.
-- `n + 2` rewrites context-dependently (cd=true), `3 + 4` evaluates ground (cd=false).
-- The congruence combines both, so the overall result is cd=true.
-- On second traversal: ground sub-expressions (`3 + 4`, `7`) hit persistent cache,
-- but cd-tainted expressions (`2 + n`, `2 + n + 7`) are only in transient.
set_option trace.sym.simp.debug.cache true in
/--
trace: [sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
[sym.simp.debug.cache] transient cache hit: (2 + n) * 7
[sym.simp.debug.cache] second traversal
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 3 + 4
[sym.simp.debug.cache] transient cache hit: 2 + n
[sym.simp.debug.cache] persistent cache hit: 7
[sym.simp.debug.cache] transient cache hit: (2 + n) * 7
-/
#guard_msgs in
example (n : Nat) (h : 0 < n) : (n + 2) * (3 + 4) = (2 + n) * 7 := by
sym_simp_twice [Nat.add_comm_of_pos]
-- Similar to previous test, but `Nat.add_comm_of_pos` is not applicable, but discharger must return `cd := true`.
set_option trace.sym.simp.debug.cache true in
/--
trace: [sym.simp.debug.cache] transient cache hit: n + 2
[sym.simp.debug.cache] transient cache hit: (n + 2) * 7
[sym.simp.debug.cache] second traversal
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: 3 + 4
[sym.simp.debug.cache] transient cache hit: n + 2
[sym.simp.debug.cache] persistent cache hit: 7
[sym.simp.debug.cache] transient cache hit: (n + 2) * 7
-/
#guard_msgs in
example (n : Nat) : (n + 2) * (3 + 4) = (n + 2) * 7 := by
sym_simp_twice [Nat.add_comm_of_pos]
-- Test 4: Arrow — cd propagates through implication.
-- The hypothesis `n + 2 = 2 + n` is simplified context-dependently to `True`.
-- `True → True` simplifies to `True`. The whole result is cd=true.
-- `True` hits persistent cache; `2 + n` is only in transient.
set_option trace.sym.simp.debug.cache true in
/--
trace: [sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
[sym.simp.debug.cache] second traversal
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
[sym.simp.debug.cache] persistent cache hit: True
-/
#guard_msgs in
set_option linter.unusedVariables false in
example (n : Nat) (h : 0 < n) : (n + 2 = 2 + n) True := by
sym_simp_twice [Nat.add_comm_of_pos]
-- Test 5: Lambda — cd propagates through funext.
-- Body `n + 2` is simplified context-dependently inside the binder.
-- `withFreshTransientCache` clears the transient cache on binder entry.
-- The lambda result `fun x => 2 + n` is only in transient.
set_option trace.sym.simp.debug.cache true in
/--
trace: [sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: fun x => 2 + n
[sym.simp.debug.cache] second traversal
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: fun x => 2 + n
-/
#guard_msgs in
example (n : Nat) (_h : 0 < n) : (fun _ : Nat => n + 2) = (fun _ : Nat => 2 + n) := by
sym_simp_twice [Nat.add_comm_of_pos]
-- Test 6: Control flow — cd propagates through `ite` condition.
-- The condition `n + 2 = 2 + n` is simplified context-dependently.
-- The `ite` result inherits cd, and `1` (ground) is in persistent cache.
set_option trace.sym.simp.debug.cache true in
/--
trace: [sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
[sym.simp.debug.cache] persistent cache hit: 1
[sym.simp.debug.cache] second traversal
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
[sym.simp.debug.cache] persistent cache hit: 1
[sym.simp.debug.cache] persistent cache hit: 1
-/
#guard_msgs in
example (n : Nat) (h : 0 < n) : (if n + 2 = 2 + n then 1 else 0) = 1 := by
sym_simp_twice [Nat.add_comm_of_pos]
-- Test 7: Dependent forall — body cd under binder with `withFreshTransientCache`.
-- Simplifying `∀ (m : Nat), n + 2 = 2 + n` enters a binder (for `m`).
-- The transient cache is cleared on binder entry (`withFreshTransientCache`).
-- The body uses a cd rewrite, so the overall result is cd=true.
-- After "second traversal": `Nat` (the binder type) hits persistent cache.
set_option trace.sym.simp.debug.cache true in
/--
trace: [sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
[sym.simp.debug.cache] second traversal
[sym.simp.debug.cache] persistent cache hit: Nat
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
-/
#guard_msgs in
set_option linter.unusedVariables false in
example (n : Nat) (h : 0 < n) : (_ : Nat), n + 2 = 2 + n := by
sym_simp_twice [Nat.add_comm_of_pos]

View File

@@ -0,0 +1,104 @@
/-
Unit tests for `contextDependent` propagation in Sym.simp combinators.
-/
import Lean
open Lean Meta
private def e := mkConst ``True
private def rflProof := mkApp2 (mkConst ``Eq.refl [1]) (mkSort 0) e
-- mkRflResultCD
#eval do assert! !(Sym.Simp.mkRflResultCD false).isContextDependent
#eval do assert! (Sym.Simp.mkRflResultCD true).isContextDependent
-- isRfl checks done=false, ignores contextDependent
#eval do assert! (Sym.Simp.mkRflResultCD false).isRfl
#eval do assert! (Sym.Simp.mkRflResultCD true).isRfl
#eval do assert! !(Sym.Simp.mkRflResult (done := true)).isRfl
-- mkRflResult
#eval do assert! !(Sym.Simp.mkRflResult).isContextDependent
#eval do assert! (Sym.Simp.mkRflResult (contextDependent := true)).isContextDependent
#eval do assert! (Sym.Simp.mkRflResult (done := true) (contextDependent := true)).isContextDependent
-- Result.withContextDependent
#eval do assert! (Sym.Simp.Result.rfl).withContextDependent.isContextDependent
#eval do assert! (Sym.Simp.Result.step e rflProof).withContextDependent.isContextDependent
-- mkEqTransResult: cd₁=true, r₂ cd=false → combined cd=true
#eval show MetaM Unit from Sym.SymM.run do
let r Sym.Simp.mkEqTransResult e e rflProof (.step e rflProof) (cd₁ := true)
assert! r.isContextDependent
-- mkEqTransResult: cd₁=false, r₂ cd=true → combined cd=true
#eval show MetaM Unit from Sym.SymM.run do
let r Sym.Simp.mkEqTransResult e e rflProof (.step e rflProof false true)
assert! r.isContextDependent
-- mkEqTransResult: cd₁=false, r₂ cd=false → combined cd=false
#eval show MetaM Unit from Sym.SymM.run do
let r Sym.Simp.mkEqTransResult e e rflProof (.step e rflProof)
assert! !r.isContextDependent
-- mkEqTransResult: cd₁=true, r₂=rfl cd=false → combined cd=true
#eval show MetaM Unit from Sym.SymM.run do
let r Sym.Simp.mkEqTransResult e e rflProof (.rfl) (cd₁ := true)
assert! r.isContextDependent
-- andThen: f returns .rfl (cd=true), g returns .step (cd=false)
-- cd from f propagates: in another context f might succeed, changing the path.
#eval show MetaM Unit from Sym.SymM.run do
let f : Sym.Simp.Simproc := fun _ => return .rfl false true
let g : Sym.Simp.Simproc := fun _ => return .step e rflProof
let r Sym.Simp.SimpM.run' ((f.andThen g) e)
assert! r.isContextDependent
-- andThen: f returns .rfl (cd=false), g returns .step (cd=true) → cd from g
#eval show MetaM Unit from Sym.SymM.run do
let f : Sym.Simp.Simproc := fun _ => return .rfl
let g : Sym.Simp.Simproc := fun _ => return .step e rflProof false true
let r Sym.Simp.SimpM.run' ((f.andThen g) e)
assert! r.isContextDependent
-- andThen: f .rfl (cd=false), g .step (cd=false) → no cd
#eval show MetaM Unit from Sym.SymM.run do
let f : Sym.Simp.Simproc := fun _ => return .rfl
let g : Sym.Simp.Simproc := fun _ => return .step e rflProof
let r Sym.Simp.SimpM.run' ((f.andThen g) e)
assert! !r.isContextDependent
-- andThen transitivity: f step (cd=true), g step (cd=false)
-- Exercises mkEqTransResult through andThen.
#eval show MetaM Unit from Sym.SymM.run do
let f : Sym.Simp.Simproc := fun _ => return .step e rflProof false true
let g : Sym.Simp.Simproc := fun _ => return .step e rflProof
let r Sym.Simp.SimpM.run' ((f.andThen g) e)
assert! r.isContextDependent
-- andThen transitivity: f step (cd=false), g step (cd=true)
#eval show MetaM Unit from Sym.SymM.run do
let f : Sym.Simp.Simproc := fun _ => return .step e rflProof
let g : Sym.Simp.Simproc := fun _ => return .step e rflProof false true
let r Sym.Simp.SimpM.run' ((f.andThen g) e)
assert! r.isContextDependent
-- orElse: f returns .rfl (cd=true), g returns .step (cd=false)
-- cd from f propagates: in another context f might succeed.
#eval show MetaM Unit from Sym.SymM.run do
let f : Sym.Simp.Simproc := fun _ => return .rfl false true
let g : Sym.Simp.Simproc := fun _ => return .step e rflProof
let r Sym.Simp.SimpM.run' ((f.orElse g) e)
assert! r.isContextDependent
-- orElse: f returns .rfl (cd=false), g returns .step (cd=true) → cd from g
#eval show MetaM Unit from Sym.SymM.run do
let f : Sym.Simp.Simproc := fun _ => return .rfl
let g : Sym.Simp.Simproc := fun _ => return .step e rflProof false true
let r Sym.Simp.SimpM.run' ((f.orElse g) e)
assert! r.isContextDependent
-- orElse: f returns .step → returned directly, g not called
#eval show MetaM Unit from Sym.SymM.run do
let f : Sym.Simp.Simproc := fun _ => return .step e rflProof false true
let g : Sym.Simp.Simproc := fun _ => unreachable!
let r Sym.Simp.SimpM.run' ((f.orElse g) e)
assert! r.isContextDependent

View File

@@ -13,8 +13,8 @@ def runProblem (n : Nat) : MetaM Unit := do
let endTime IO.monoNanosNow
let ms := (endTime - startTime).toFloat / 1000000.0
match executed with
| .rfl _ => IO.println s!"goal_{n}: {ms} ms"
| .step _ proof _ =>
| .rfl _ _ => IO.println s!"goal_{n}: {ms} ms"
| .step _ proof _ _ =>
let startTime IO.monoNanosNow
Meta.checkWithKernel proof
let endTime IO.monoNanosNow

View File

@@ -154,8 +154,8 @@ def runSingleTest (n : Nat) : MetaM Unit := do
let endTime IO.monoNanosNow
let ms := (endTime - startTime).toFloat / 1000000.0
match executed with
| .rfl _ => IO.println s!"goal_{n}: {ms} ms"
| .step e' proof _ =>
| .rfl _ _ => IO.println s!"goal_{n}: {ms} ms"
| .step e' proof _ _ =>
let startTime IO.monoNanosNow
Meta.checkWithKernel proof
let endTime IO.monoNanosNow

View File

@@ -35,8 +35,8 @@ def runProblem (n : Nat) : MetaM Unit := do
let endTime IO.monoNanosNow
let ms := (endTime - startTime).toFloat / 1000000.0
match executed with
| .rfl _ => IO.println s!"mergeSort_{n}: {ms} ms (rfl)"
| .step _ proof _ =>
| .rfl _ _ => IO.println s!"mergeSort_{n}: {ms} ms (rfl)"
| .step _ proof _ _ =>
let startTime IO.monoNanosNow
Meta.checkWithKernel proof
let endTime IO.monoNanosNow