perf: revert to fused pass 2, keep cached resolvability_checker

This reverts commit 93f75deaf3's pass 2 changes (replace_fvars approach)
while keeping its cached resolvability_checker. The fused fvar substitution
in pass 2 is subquadratic on the delayed assignment benchmark (74ms vs
2116ms at n=500), but causes OOM on the larger elab_bench/bv_decide_rewriter
test due to sharing loss when the same subexpression appears in different
fvar_subst contexts.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Joachim Breitner
2026-03-04 07:14:25 +00:00
parent d72b0d5d02
commit f0b69dc841

View File

@@ -26,11 +26,11 @@ Between passes, a `resolvability_checker` determines which delayed assignments c
be fully resolved (assigned, mvar-free after resolution, sufficient arguments).
Pass 2 (`instantiate_delayed_fn`):
Resolves delayed assignments using `replace_fvars`, matching the original
`instantiateMVars` approach. Pending values are visited once (cached/written
back), then fvar substitution is applied mechanically per-site. This preserves
sharing of the visited pending value across multiple occurrences of the same
delayed mvar.
Fused traversal that resolves delayed assignments by carrying a fvar substitution.
Since pass 1 has pre-normalized all direct chains, each pending value is compact
and visited once, avoiding the O(n³) sharing loss that occurs when the fused
approach must also chase direct chains. Unassigned mvars are left as-is (matching
the original `instantiateMVars` behavior).
*/
namespace lean {
@@ -42,9 +42,6 @@ extern "C" object * lean_assign_mvar(obj_arg mctx, obj_arg mid, obj_arg val);
typedef object_ref metavar_ctx;
typedef object_ref delayed_assignment;
/* Forward declaration — defined in instantiate_mvars.cpp */
expr replace_fvars(expr const & e, array_ref<expr> const & fvars, expr const * rev_args);
static void assign_lmvar(metavar_ctx & mctx, name const & mid, level const & l) {
object * r = lean_assign_lmvar(mctx.steal(), mid.to_obj_arg(), l.to_obj_arg());
mctx.set_box(r);
@@ -418,45 +415,126 @@ public:
};
/* ============================================================================
Pass 2: Resolve delayed assignments using replace_fvars.
Direct mvar chains have been pre-resolved by pass 1, and the resolvability
checker has determined which delayed assignments can be fully resolved.
Pass 2: Resolve delayed assignments with fused fvar substitution.
Direct mvar chains have been pre-resolved by pass 1.
Uses the same replace_fvars approach as the original instantiateMVars:
pending values are visited once (cached/written back), then fvar substitution
is applied mechanically per-site via replace_fvars. This preserves sharing
of the visited pending value across multiple occurrences of the same delayed
mvar.
Uses a flat (ptr, depth)-keyed cache with generation-based staleness.
Each visit_delayed scope gets a unique generation number; cache entries
record the scope level and generation at insertion. Validity is O(1):
entry valid iff level <= m_scope && m_scope_gens[level] == entry.scope_gen.
============================================================================ */
/* One binding of the active fvar substitution (keyed by fvar name in m_fvar_subst). */
struct fvar_subst_entry {
unsigned depth; /* binder depth (m_depth) at insertion; lookups lift loose bvars by the difference */
unsigned scope; /* scope level (m_scope) at insertion; folded into m_result_scope on lookup */
expr value; /* replacement expression for the fvar */
};
class instantiate_delayed_fn {
/* Hashes a (expr pointer, binder depth) cache key. */
struct key_hasher {
std::size_t operator()(std::pair<lean_object *, unsigned> const & p) const {
/* NOTE(review): the >> 3 presumably discards pointer bits that are always
   zero due to object alignment — confirm allocator alignment guarantee. */
return hash((size_t)p.first >> 3, p.second);
}
};
/* A cached visit result together with the scope level it depends on and the
   generation of that level at insertion time; cache_lookup uses these for an
   O(1) staleness check. */
struct cache_entry { expr result; unsigned scope_level; unsigned scope_gen; };
typedef lean::unordered_map<std::pair<lean_object *, unsigned>, cache_entry, key_hasher> flat_cache;
metavar_ctx & m_mctx;
name_set const & m_resolvable_delayed;
lean::unordered_map<lean_object *, expr> m_cache;
name_hash_map<fvar_subst_entry> m_fvar_subst;
unsigned m_depth;
/* Single flat cache with generation-based staleness detection. */
flat_cache m_cache;
std::vector<unsigned> m_scope_gens; /* m_scope_gens[level] = generation */
unsigned m_gen_counter;
unsigned m_scope;
/* After visit() returns, this holds the maximum fvar-substitution
scope that contributed to the result — i.e., the outermost scope at which the
result is valid and can be cached. Updated monotonically (via max) through
the save/reset/restore pattern in visit(). */
unsigned m_result_scope;
/* Global cache for fvar-free expressions — scope-independent. */
lean::unordered_map<lean_object *, expr> m_global_cache;
/* Write-back support: when fvar_subst is empty, normalize and write back
mvar assignments to match the original instantiateMVars mctx side effects.
Downstream code (e.g. MutualDef.mkInitialUsedFVarsMap) reads stored
assignments and expects them to be normalized. */
name_set m_already_normalized;
std::vector<expr> m_saved;
/* Get a direct mvar assignment. Visit it to resolve inner delayed mvars.
Normalize and write back the result to the mctx. This matches the
original instantiateMVars behavior: downstream code (e.g.
MutualDef.mkInitialUsedFVarsMap) reads stored assignments and expects
inner delayed assignments to be resolved. */
/* True when no fvar substitution is active, i.e. we are not currently inside
   the resolution of a delayed assignment. */
bool fvar_subst_empty() const {
return m_fvar_subst.empty();
}
/* Look up `fid` in the active fvar substitution.
   Returns none when the fvar is not being substituted. On a hit, raises
   m_result_scope to the entry's scope and lifts loose bvars by the number
   of binders entered since the entry was recorded. */
optional<expr> lookup_fvar(name const & fid) {
    auto entry_it = m_fvar_subst.find(fid);
    if (entry_it == m_fvar_subst.end())
        return optional<expr>();
    fvar_subst_entry const & entry = entry_it->second;
    if (entry.scope > m_result_scope)
        m_result_scope = entry.scope;
    unsigned lift = m_depth - entry.depth;
    if (lift == 0)
        return optional<expr>(entry.value);
    return optional<expr>(lift_loose_bvars(entry.value, lift));
}
/* Cache lookup with an O(1) generation-based staleness check.
   An entry recorded at scope_level 0 has no fvar dependency and is valid at
   any scope; an entry at scope_level > 0 is valid only at exactly that scope
   level, because an inner scope may shadow the fvars it depends on. In both
   cases the generation stored in the entry must still match the current
   generation of that scope level. */
optional<expr> cache_lookup(lean_object * ptr) {
    auto it = m_cache.find(mk_pair(ptr, m_depth));
    if (it == m_cache.end())
        return {};
    cache_entry const & ce = it->second;
    bool level_valid = ce.scope_level == 0 || ce.scope_level == m_scope;
    if (level_valid && m_scope_gens[ce.scope_level] == ce.scope_gen) {
        m_result_scope = std::max(m_result_scope, ce.scope_level);
        return optional<expr>(ce.result);
    }
    return {};
}
void cache_insert(lean_object * ptr, expr const & result) {
auto key = mk_pair(ptr, m_depth);
/* Record the entry at m_result_scope — the outermost scope level that
   contributed to `result` — together with that level's current generation,
   so cache_lookup can later detect staleness in O(1). */
m_cache[key] = { result, m_result_scope, m_scope_gens[m_result_scope] };
}
/* Get a direct mvar assignment. Visit it to resolve delayed mvars
and apply the fvar substitution.
When fvar_subst is empty, normalize and write back the result to
the mctx. This matches the original instantiateMVars behavior:
downstream code (e.g. MutualDef.mkInitialUsedFVarsMap) reads stored
assignments and expects inner delayed assignments to be resolved.
When fvar_subst is non-empty, no write-back (values contain
fvar-substituted terms not suitable for the mctx). */
optional<expr> get_assignment(name const & mid) {
option_ref<expr> r = get_mvar_assignment(m_mctx, mid);
if (!r)
return optional<expr>();
expr a(r.get_val());
if (!has_mvar(a))
return optional<expr>(a);
if (m_already_normalized.contains(mid))
return optional<expr>(a);
m_already_normalized.insert(mid);
expr a_new = visit(a);
if (!is_eqp(a, a_new)) {
m_saved.push_back(a);
assign_mvar(m_mctx, mid, a_new);
if (fvar_subst_empty()) {
if (!has_mvar(a))
return optional<expr>(a);
if (m_already_normalized.contains(mid))
return optional<expr>(a);
m_already_normalized.insert(mid);
expr a_new = visit(a);
if (!is_eqp(a, a_new)) {
m_saved.push_back(a);
assign_mvar(m_mctx, mid, a_new);
}
return optional<expr>(a_new);
} else {
if (!has_mvar(a) && !has_fvar(a))
return optional<expr>(a);
return optional<expr>(visit(a));
}
return optional<expr>(a_new);
}
expr visit_app_default(expr const & e) {
@@ -500,20 +578,54 @@ class instantiate_delayed_fn {
args.push_back(visit(app_arg(*curr)));
curr = &app_fn(*curr);
}
/* Get and visit the pending value (resolving inner delayed assignments).
Uses get_assignment for write-back normalization. */
optional<expr> val = get_assignment(mid_pending);
lean_assert(val); /* resolvability checker verified this is assigned */
/* Replace the delayed assignment's fvars with the visited arguments,
matching the original instantiateMVars approach. */
size_t fvar_count = fvars.size();
expr val_new = replace_fvars(*val, fvars, args.data() + (args.size() - fvar_count));
/* Use apply_beta for extra args: the fvar-substituted pending value may
be a lambda (e.g. from assertAfter), and extra args should be beta-
reduced into it rather than left as a redex. */
size_t extra_count = args.size() - fvar_count;
/* Save and extend the fvar substitution. */
struct saved_entry { name key; bool had_old; fvar_subst_entry old; };
std::vector<saved_entry> saved_entries;
saved_entries.reserve(fvar_count);
m_scope++;
for (size_t i = 0; i < fvar_count; i++) {
name const & fid = fvar_name(fvars[i]);
auto old_it = m_fvar_subst.find(fid);
if (old_it != m_fvar_subst.end()) {
saved_entries.push_back({fid, true, old_it->second});
} else {
saved_entries.push_back({fid, false, {0, 0, expr()}});
}
m_fvar_subst[fid] = {m_depth, m_scope, args[args.size() - 1 - i]};
}
/* Push: bump generation so stale entries at this scope level are detected. */
m_gen_counter++;
if (m_scope >= m_scope_gens.size())
m_scope_gens.push_back(m_gen_counter);
else
m_scope_gens[m_scope] = m_gen_counter;
expr val_new = visit(mk_mvar(mid_pending));
/* Pop: just decrement scope — stale entries are detected by generation mismatch. */
m_scope--;
/* Restore the fvar substitution. */
for (auto & se : saved_entries) {
if (!se.had_old) {
m_fvar_subst.erase(se.key);
} else {
m_fvar_subst[se.key] = se.old;
}
}
/* Use apply_beta instead of mk_rev_app: pass 1's beta-reduction may have
changed delayed mvar arguments (e.g., substituting a bvar with a concrete
value), so the resolved pending value may be a lambda that needs beta-
reduction with the extra args, matching the original's behavior. */
bool preserve_data = false;
bool zeta = true;
return apply_beta(val_new, args.size() - fvar_count, args.data(), preserve_data, zeta);
return apply_beta(val_new, extra_count, args.data(), preserve_data, zeta);
}
expr visit_app(expr const & e) {
@@ -542,11 +654,13 @@ class instantiate_delayed_fn {
(assigned and all nested delayed mvars also resolvable). This
matches the original instantiateMVars behavior. */
if (!m_resolvable_delayed.contains(mid_pending)) {
/* Normalize the pending value for mctx write-back.
Downstream code (MutualDef.mkInitialUsedFVarsMap) reads stored
assignments and relies on inner delayed assignments being
resolved even when the outer one cannot be. */
(void)get_assignment(mid_pending);
/* Still normalize the pending value for mctx write-back when
fvar_subst is empty. Downstream code (MutualDef.mkInitialUsedFVarsMap)
reads stored assignments and relies on inner delayed assignments
being resolved even when the outer one cannot be. */
if (fvar_subst_empty()) {
(void)get_assignment(mid_pending);
}
return visit_mvar_app_args(e);
}
buffer<expr> args;
@@ -561,49 +675,116 @@ class instantiate_delayed_fn {
return e;
}
inline expr cache(expr const & e, expr r, bool shared) {
if (shared) {
m_cache.insert(mk_pair(e.raw(), r));
expr visit_fvar(expr const & e) {
name const & fid = fvar_name(e);
if (auto r = lookup_fvar(fid)) {
return *r;
}
return r;
return e;
}
public:
instantiate_delayed_fn(metavar_ctx & mctx, name_set const & resolvable_delayed)
: m_mctx(mctx), m_resolvable_delayed(resolvable_delayed) {}
: m_mctx(mctx), m_resolvable_delayed(resolvable_delayed),
m_depth(0), m_gen_counter(0), m_scope(0), m_result_scope(0) {
m_scope_gens.push_back(0); /* scope 0 has generation 0 */
}
expr visit(expr const & e) {
if (!has_mvar(e))
return e;
if (fvar_subst_empty()) {
if (!has_mvar(e))
return e;
} else {
if (!has_mvar(e) && !has_fvar(e))
return e;
}
bool use_global = !has_fvar(e) && !has_expr_mvar(e);
bool shared = false;
if (is_shared(e)) {
auto it = m_cache.find(e.raw());
if (it != m_cache.end()) {
return it->second;
if (use_global) {
auto it = m_global_cache.find(e.raw());
if (it != m_global_cache.end())
return it->second;
} else {
if (auto r = cache_lookup(e.raw()))
return *r;
}
shared = true;
}
/* Save and reset the result scope for this subtree.
After computing, cache_insert uses m_result_scope to place the entry
at the outermost valid scope level. Then we restore the parent's
watermark, taking the max with our contribution. */
unsigned saved_result_scope = m_result_scope;
m_result_scope = 0;
expr r;
switch (e.kind()) {
case expr_kind::BVar:
case expr_kind::Lit: case expr_kind::FVar:
case expr_kind::Lit:
lean_unreachable();
case expr_kind::Sort: case expr_kind::Const:
/* Levels already resolved by pass 1. */
return e;
case expr_kind::FVar:
r = visit_fvar(e);
goto done; /* skip caching for fvars */
case expr_kind::Sort:
r = update_sort(e, visit_level(sort_level(e)));
break;
case expr_kind::Const:
r = update_const(e, visit_levels(const_levels(e)));
break;
case expr_kind::MVar:
return visit_mvar(e);
r = visit_mvar(e);
goto done; /* mvar results are not (ptr, depth)-cacheable */
case expr_kind::MData:
return cache(e, update_mdata(e, visit(mdata_expr(e))), shared);
r = update_mdata(e, visit(mdata_expr(e)));
break;
case expr_kind::Proj:
return cache(e, update_proj(e, visit(proj_expr(e))), shared);
r = update_proj(e, visit(proj_expr(e)));
break;
case expr_kind::App:
return cache(e, visit_app(e), shared);
case expr_kind::Pi: case expr_kind::Lambda:
return cache(e, update_binding(e, visit(binding_domain(e)), visit(binding_body(e))), shared);
case expr_kind::Let:
return cache(e, update_let(e, visit(let_type(e)), visit(let_value(e)), visit(let_body(e))), shared);
r = visit_app(e);
break;
case expr_kind::Pi: case expr_kind::Lambda: {
expr d = visit(binding_domain(e));
m_depth++;
expr b = visit(binding_body(e));
m_depth--;
r = update_binding(e, d, b);
break;
}
case expr_kind::Let: {
expr t = visit(let_type(e));
expr v = visit(let_value(e));
m_depth++;
expr b = visit(let_body(e));
m_depth--;
r = update_let(e, t, v, b);
break;
}
}
if (shared) {
if (use_global)
m_global_cache.insert(mk_pair(e.raw(), r));
else
cache_insert(e.raw(), r);
}
done:
m_result_scope = std::max(saved_result_scope, m_result_scope);
return r;
}
level visit_level(level const & l) {
/* Pass 2 does not handle level mvars — pass 1 already resolved them.
   This exists only for the visit_level/visit_levels calls reached through
   update_sort/update_const. Since levels contain no fvars, the fvar
   substitution cannot apply and `l` is returned unchanged. */
return l;
}
levels visit_levels(levels const & ls) {
/* Same reasoning as visit_level: level mvars were resolved by pass 1 and
   levels contain no fvars, so the list is returned as-is. */
return ls;
}
expr operator()(expr const & e) { return visit(e); }
@@ -620,7 +801,7 @@ static object * run_instantiate_all(object * m, object * e) {
instantiate_direct_fn pass1(mctx);
expr e1 = pass1(expr(e));
/* Pass 2: resolve delayed assignments using replace_fvars.
/* Pass 2: resolve delayed assignments with fused fvar substitution.
Skip if pass 1 found no delayed assignments at all — the expression
has no delayed mvars that need resolution or write-back. */
expr e2;