fix: stack-based scope cache for pass 2 fvar substitution

The flat cache in pass 2's `instantiate_delayed_fn` used a scope validity
check (`entry.scope_level == 0 || entry.scope_level == m_scope`) that
rejected cache entries at intermediate scope levels, causing 0 cache hits
and 236M+ redundant visits on bv_decide_rewriter.

Replace the single-entry cache with a stack of entries per key, ordered by
scope level (innermost at back). Lookup accepts only exact scope matches.
Insert stores the result at each scope in [result_scope, current_scope], so
a result computed at the current scope is also found by later lookups at any
still-live scope down to result_scope.
This handles fvar shadowing and late-binding correctly without special-case
logic, and allows cross-scope cache reuse when safe.

Add a test for the late-bind scenario (fvar first seen unsubstituted, then
substituted at a higher scope).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Joachim Breitner
2026-03-04 08:07:14 +00:00
parent f0b69dc841
commit bc0324b381
2 changed files with 133 additions and 16 deletions

View File

@@ -437,16 +437,26 @@ class instantiate_delayed_fn {
}
};
struct cache_entry { expr result; unsigned scope_level; unsigned scope_gen; };
struct cache_entry {
expr result;
unsigned scope_level; /* scope at which this entry was stored */
unsigned scope_gen; /* generation of scope_level at store time */
unsigned result_scope; /* original m_result_scope for propagation */
};
typedef lean::unordered_map<std::pair<lean_object *, unsigned>, cache_entry, key_hasher> flat_cache;
typedef lean::unordered_map<std::pair<lean_object *, unsigned>,
std::vector<cache_entry>, key_hasher> flat_cache;
metavar_ctx & m_mctx;
name_set const & m_resolvable_delayed;
name_hash_map<fvar_subst_entry> m_fvar_subst;
unsigned m_depth;
/* Single flat cache with generation-based staleness detection. */
/* Flat cache mapping (ptr, depth) to a stack of entries ordered by scope
(innermost/highest scope at the back). The stack-based approach handles
both fvar shadowing and late-binding of fvars without special-case logic:
- Lookup: drop stale entries from the back, accept only exact scope match.
- Insert: store entries for each scope in [result_scope, current_scope]. */
flat_cache m_cache;
std::vector<unsigned> m_scope_gens; /* m_scope_gens[level] = generation */
unsigned m_gen_counter;
@@ -483,26 +493,48 @@ class instantiate_delayed_fn {
return optional<expr>(lift_loose_bvars(it->second.value, d));
}
/* Cache lookup — O(1) with generation-based staleness check.
An entry at scope_level 0 (no fvar dependency) is valid at any scope.
An entry at scope_level > 0 is only valid at exactly that scope level,
because an inner scope may shadow the fvars it depends on. */
/* Cache lookup: scan the entry stack from the back (innermost scope first),
drop stale entries, and accept only an exact scope match.
This is always correct: an entry at scope S is only returned when
m_scope == S, so shadowed or late-bound fvars cannot produce stale hits. */
optional<expr> cache_lookup(lean_object * ptr) {
auto key = mk_pair(ptr, m_depth);
auto it = m_cache.find(key);
if (it == m_cache.end()) return {};
auto & entry = it->second;
if ((entry.scope_level == 0 || entry.scope_level == m_scope) &&
m_scope_gens[entry.scope_level] == entry.scope_gen) {
m_result_scope = std::max(m_result_scope, entry.scope_level);
return optional<expr>(entry.result);
auto & stack = it->second;
/* Drop stale entries from the back. In a LIFO scope structure,
if scope S is stale, all scopes > S were popped earlier. */
while (!stack.empty()) {
auto & top = stack.back();
if (top.scope_level <= m_scope &&
m_scope_gens[top.scope_level] == top.scope_gen) {
/* First valid entry — accept only if at current scope. */
if (top.scope_level == m_scope) {
m_result_scope = std::max(m_result_scope, top.result_scope);
return optional<expr>(top.result);
}
return {}; /* valid but at a lower scope — miss */
}
stack.pop_back();
}
return {};
}
/* Cache insert: store the result at each scope level in
[m_result_scope, m_scope], so that lookups at any of those scopes
will find an entry with an exact scope match. */
void cache_insert(lean_object * ptr, expr const & result) {
auto key = mk_pair(ptr, m_depth);
m_cache[key] = { result, m_result_scope, m_scope_gens[m_result_scope] };
auto & stack = m_cache[key];
/* Drop entries at scope_level >= m_result_scope — they are
superseded by the new result (or stale from a popped scope). */
while (!stack.empty() && stack.back().scope_level >= m_result_scope) {
stack.pop_back();
}
/* Push entries for each scope in [m_result_scope, m_scope]. */
for (unsigned s = m_result_scope; s <= m_scope; s++) {
stack.push_back({result, s, m_scope_gens[s], m_result_scope});
}
}
/* Get a direct mvar assignment. Visit it to resolve delayed mvars

View File

@@ -64,7 +64,6 @@ private def mkShadowTest : MetaM Expr := do
private def mkExpected : Expr :=
let nat := mkConst ``Nat
let succ := mkConst ``Nat.succ
let pairTy := mkApp2 (mkConst ``Prod [.succ .zero, .succ .zero]) nat nat
-- #0 refers to the lambda-bound `a`
let succ_a := mkApp succ (.bvar 0)
let succ_succ_a := mkApp succ succ_a
@@ -83,9 +82,95 @@ run_meta do
let saved saveState
let eOrig instantiateMVarsOriginal root
saved.restore
check "instantiateMVarsOriginal" eOrig
check "instantiateMVarsOriginal (shadow)" eOrig
let saved saveState
let eNew instantiateAllMVars root
saved.restore
check "instantiateAllMVars" eNew
check "instantiateAllMVars (shadow)" eNew
/-
Test: an fvar first seen unsubstituted, then substituted at a higher scope.
A shared subexpression `succ_y := Nat.succ y_fvar` is used both:
- directly in the body of d1 (where y is NOT bound), and
- inside d2's pending value (where y IS bound).
?root := fun (a : Nat) => ?d1 a
?d1 delayed [x] := ?body
?body := Prod.mk succ_y (?d2 succ_y) ← succ_y shared
?d2 delayed [y] := ?inner ← y is NOW bound
?inner := succ_y ← same shared object
Expected result:
fun (a : Nat) => (Nat.succ y_fvar, Nat.succ (Nat.succ y_fvar))
At scope 1 (d1), x → a. Visit body:
- succ_y: y is NOT in fvar_subst. Result is succ_y unchanged.
- ?d2 succ_y: arg succ_y visited → succ_y. Then d2 at scope 2 with y → succ_y.
- Visit ?inner = succ_y. y IS in fvar_subst → Nat.succ succ_y = Nat.succ (Nat.succ y_fvar).
A buggy cache would return the scope-1 result (succ_y unchanged) at scope 2,
producing (Nat.succ y_fvar, Nat.succ y_fvar) instead.
-/
/-- Builds the metavariable graph for the late-bind scenario and returns
`(root, y_fvar)`: the root metavariable expression together with the free
variable `y` that the expected result still mentions.

The same `succ_y` object is used three times: as the first pair component,
as the argument of `?d2`, and as the assignment of `?inner`. -/
private def mkLateBindTest : MetaM (Expr × Expr) := do
  let nat := mkConst ``Nat
  withLocalDeclD `x nat fun x_fvar =>
    withLocalDeclD `y nat fun y_fvar => do
      -- shared object referencing y_fvar (NOT x_fvar)
      let succ_y := mkApp (mkConst ``Nat.succ) y_fvar

      -- ?inner := succ_y
      let inner ← mkFreshExprMVar nat
      inner.mvarId!.assign succ_y

      -- ?d2 delayed [y_fvar] := ?inner
      let d2_ty ← mkArrow nat nat
      let d2 ← mkFreshExprMVar d2_ty (kind := .syntheticOpaque)
      assignDelayedMVar d2.mvarId! #[y_fvar] inner.mvarId!

      -- ?body := ⟨succ_y, ?d2 succ_y⟩
      let pairTy := mkApp2 (mkConst ``Prod [.succ .zero, .succ .zero]) nat nat
      let body ← mkFreshExprMVar pairTy
      body.mvarId!.assign
        (mkApp4 (mkConst ``Prod.mk [.succ .zero, .succ .zero]) nat nat
          succ_y (mkApp d2 succ_y))

      -- ?d1 delayed [x_fvar] := ?body
      let d1_ty ← mkArrow nat pairTy
      let d1 ← mkFreshExprMVar d1_ty (kind := .syntheticOpaque)
      assignDelayedMVar d1.mvarId! #[x_fvar] body.mvarId!

      -- ?root := fun (a : Nat) => ?d1 a
      let rootTy ← mkArrow nat pairTy
      let root ← mkFreshExprMVar rootTy
      root.mvarId!.assign (Lean.mkLambda `a .default nat (mkApp d1 (.bvar 0)))
      return (root, y_fvar)
-- Expected result of the late-bind test:
--   fun (a : Nat) => (Nat.succ y_fvar, Nat.succ (Nat.succ y_fvar))
-- The first component keeps `y_fvar` once under `Nat.succ`; the second has
-- `Nat.succ` applied twice.
private def mkExpectedLateBind (y_fvar : Expr) : Expr :=
  let natTy := mkConst ``Nat
  let succFn := mkConst ``Nat.succ
  let pairMk := mkConst ``Prod.mk [.succ .zero, .succ .zero]
  let once := mkApp succFn y_fvar
  let twice := mkApp succFn once
  Lean.mkLambda `a .default natTy (mkApp4 pairMk natTy natTy once twice)
/-- Compares `result` against the expected late-bind expression built from
`y_fvar` and throws an error tagged with `label` when they differ. -/
private def checkLateBind (label : String) (result : Expr) (y_fvar : Expr) : MetaM Unit := do
  let expected := mkExpectedLateBind y_fvar
  if result != expected then
    throwError "{label}: expected {expected}, got {result}"
-- Run the late-bind scenario through both instantiation entry points.
-- The state is saved before and restored after each call (presumably so the
-- second run sees the same unmodified metavariable assignments — confirm).
run_meta do
  let (root, y_fvar) ← mkLateBindTest

  let saved ← saveState
  let eOrig ← instantiateMVarsOriginal root
  saved.restore
  checkLateBind "instantiateMVarsOriginal (late-bind)" eOrig y_fvar

  let saved ← saveState
  let eNew ← instantiateAllMVars root
  saved.restore
  checkLateBind "instantiateAllMVars (late-bind)" eNew y_fvar