perf: integrate resolvability tracking into pass 1 traversal cache

Replace the uncached `is_value_resolvable` tree walk with integrated
resolvability tracking in pass 1's traversal. Each cache entry now
stores both the result expression and a resolvability flag, eliminating
redundant O(n) walks that lost sharing on large expressions.

Key changes:
- `m_result_resolvable` flag propagated via AND through save/restore
  pattern, cached alongside expression results
- `m_normalized_resolvable` name_set tracks per-mvar resolvability
  for the `m_already_normalized` early-return path in `get_assignment`
- Cache hits and mvar-free early returns AND into (not overwrite)
  the current resolvability state to preserve sibling contributions
- Pass 2 write-back for non-resolvable delayed assignments restored
  (needed by MutualDef.mkInitialUsedFVarsMap)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Joachim Breitner
2026-03-03 13:55:41 +00:00
parent f0562fac73
commit ff6cd50b8d

View File

@@ -121,7 +121,10 @@ public:
/* ============================================================================
Pass 1: Resolve direct mvar assignments with write-back.
For delayed assignments, pre-normalize the pending value but leave the
delayed mvar application in the expression.
delayed mvar application in the expression. Tracks resolvability of
each subexpression (whether all remaining mvars are delayed-assigned
with resolvable pending values) as a side product of the traversal,
cached alongside the result to avoid redundant tree walks.
============================================================================ */
class instantiate_direct_fn {
@@ -129,14 +132,26 @@ class instantiate_direct_fn {
instantiate_lmvars_all_fn m_level_fn;
name_set m_already_normalized;
/* Set of delayed-assigned mvars whose pending value is assigned and
mvar-free after normalization. Used by pass 2 as a guard: only resolve
delayed assignments when the pending mvar is in this set, matching
the original instantiateMVars behavior. */
fully resolvable after pass 1 normalization. Used by pass 2 as a
guard: only resolve delayed assignments when the pending mvar is
in this set, matching the original instantiateMVars behavior. */
name_set m_resolvable_delayed;
/* Tracks which normalized mvars have resolvable values, so that
when get_assignment returns early via m_already_normalized, we
can correctly set m_result_resolvable. */
name_set m_normalized_resolvable;
/* Set to true when any delayed assignment is encountered, even if not
resolvable. Pass 2 is needed for write-back normalization in that case. */
bool m_has_delayed;
lean::unordered_map<lean_object *, expr> m_cache;
/* After visit() returns, indicates whether the result would be mvar-free
after full delayed-assignment resolution (pass 2). Updated via AND
through the save/restore pattern: if any child is not resolvable,
the parent is not resolvable. Cached alongside the result expression. */
bool m_result_resolvable;
struct cache_entry { expr result; bool resolvable; };
lean::unordered_map<lean_object *, cache_entry> m_cache;
std::vector<expr> m_saved;
level visit_level(level const & l) {
@@ -152,23 +167,35 @@ class instantiate_direct_fn {
inline expr cache(expr const & e, expr r, bool shared) {
if (shared) {
m_cache.insert(mk_pair(e.raw(), r));
m_cache.insert(mk_pair(e.raw(), cache_entry{r, m_result_resolvable}));
}
return r;
}
/* Get and normalize a direct mvar assignment. Write back the normalized value. */
/* Get and normalize a direct mvar assignment. Write back the normalized value.
Sets m_result_resolvable as a side effect: true if the normalized value
would be mvar-free after full delayed-assignment resolution. */
optional<expr> get_assignment(name const & mid) {
option_ref<expr> r = get_mvar_assignment(m_mctx, mid);
if (!r) {
return optional<expr>();
}
expr a(r.get_val());
if (!has_mvar(a) || m_already_normalized.contains(mid)) {
if (!has_mvar(a)) {
/* No mvars at all — trivially resolvable. m_result_resolvable stays true. */
return optional<expr>(a);
}
if (m_already_normalized.contains(mid)) {
/* Use cached resolvability from the first normalization visit. */
if (!m_normalized_resolvable.contains(mid))
m_result_resolvable = false;
return optional<expr>(a);
}
m_already_normalized.insert(mid);
expr a_new = visit(a);
/* m_result_resolvable was set by visit(a). Cache it. */
if (m_result_resolvable)
m_normalized_resolvable.insert(mid);
if (!is_eqp(a, a_new)) {
m_saved.push_back(a);
assign_mvar(m_mctx, mid, a_new);
@@ -210,65 +237,6 @@ class instantiate_direct_fn {
return apply_beta(f_new, args.size(), args.data(), preserve_data, zeta);
}
/* Check whether a normalized value would be mvar-free after full resolution.
Uses m_resolvable_delayed to check inner delayed mvars. After pass 1
normalization, remaining mvars are either unassigned or delayed-assigned.
NOTE: this is a pure query — it never mutates the mctx or any member state,
and it recurses over the whole expression without caching, so shared subterms
are re-walked (the cost this commit's replacement is meant to avoid). */
bool is_value_resolvable(expr const & e) {
/* Fast path: no expr mvars anywhere below `e`, so it is trivially resolvable. */
if (!has_expr_mvar(e)) return true;
switch (e.kind()) {
case expr_kind::BVar: case expr_kind::Lit: case expr_kind::FVar:
case expr_kind::Sort: case expr_kind::Const:
/* Leaves (and Sort/Const, whose only children are levels) carry no expr mvars. */
return true;
case expr_kind::MVar:
/* Bare mvar after pass 1 normalization: not directly assigned. */
return false;
case expr_kind::App: {
expr const & f = get_app_fn(e);
if (is_mvar(f)) {
/* Mvar app after pass 1: must be delayed-assigned or unassigned. */
name const & mid = mvar_name(f);
option_ref<delayed_assignment> d = get_delayed_mvar_assignment(m_mctx, mid);
if (!d) return false;
/* Field 0 of the delayed assignment: the fvars to substitute. There must
be at least as many applied args as fvars for the subst to go through. */
array_ref<expr> fvars(cnstr_get(d.get_val().raw(), 0), true);
if (fvars.size() > get_app_num_args(e)) return false;
/* Field 1: the pending mvar whose value replaces the app head. It must
already have been classified resolvable (inner delayed assignments are
normalized first, so m_resolvable_delayed is populated for them). */
name mid_pending(cnstr_get(d.get_val().raw(), 1), true);
if (!m_resolvable_delayed.contains(mid_pending)) return false;
/* Also check args for unresolvable mvars. */
expr const * curr = &e;
while (is_app(*curr)) {
if (!is_value_resolvable(app_arg(*curr))) return false;
curr = &app_fn(*curr);
}
return true;
}
/* Ordinary application: resolvable iff both sides are. */
return is_value_resolvable(app_fn(e)) && is_value_resolvable(app_arg(e));
}
case expr_kind::Lambda: case expr_kind::Pi:
return is_value_resolvable(binding_domain(e)) && is_value_resolvable(binding_body(e));
case expr_kind::Let:
return is_value_resolvable(let_type(e)) && is_value_resolvable(let_value(e))
&& is_value_resolvable(let_body(e));
case expr_kind::MData:
return is_value_resolvable(mdata_expr(e));
case expr_kind::Proj:
return is_value_resolvable(proj_expr(e));
}
lean_unreachable();
}
/* Pre-normalize the pending value of a delayed assignment and record
whether it is resolvable (assigned and mvar-free after full resolution).
Inner delayed assignments are processed first (via recursive normalization),
so m_resolvable_delayed is already populated for them.
Side effects: always sets m_has_delayed (pass 2 is needed even when the
pending value is unassigned or unresolvable); inserts into
m_resolvable_delayed only when the pending mvar is assigned AND its
normalized value passes is_value_resolvable. */
void normalize_delayed_pending(name const & mid_pending) {
m_has_delayed = true;
/* get_assignment also normalizes and writes back the pending value;
if the pending mvar is unassigned we record nothing. */
if (auto val = get_assignment(mid_pending)) {
if (is_value_resolvable(*val)) {
m_resolvable_delayed.insert(mid_pending);
}
}
}
expr visit_app(expr const & e) {
expr const & f = get_app_fn(e);
if (!is_mvar(f)) {
@@ -280,13 +248,45 @@ class instantiate_direct_fn {
buffer<expr> args;
return visit_args_and_beta(*f_new, e, args);
}
/* Check delayed assignment and pre-normalize pending. */
/* Check for delayed assignment. */
option_ref<delayed_assignment> d = get_delayed_mvar_assignment(m_mctx, mid);
if (d) {
m_has_delayed = true;
array_ref<expr> fvars(cnstr_get(d.get_val().raw(), 0), true);
name mid_pending(cnstr_get(d.get_val().raw(), 1), true);
normalize_delayed_pending(mid_pending);
/* Pre-normalize the pending value and check resolvability.
Save/restore m_result_resolvable so the pending check
doesn't pollute the outer resolvability tracking.
m_result_resolvable (set by get_assignment's visit) correctly
tracks whether all delayed mvars in the value are themselves
resolvable — no has_expr_mvar check needed. */
bool saved = m_result_resolvable;
m_result_resolvable = true;
bool pending_resolvable = false;
if (auto val = get_assignment(mid_pending)) {
pending_resolvable = m_result_resolvable;
}
m_result_resolvable = saved;
if (pending_resolvable) {
m_resolvable_delayed.insert(mid_pending);
}
/* Visit args — their resolvability contributes to the parent. */
expr result = visit_mvar_app_args(e);
/* The delayed mvar app is resolvable only if: pending is resolvable,
sufficient args for fvar substitution, and all args are resolvable. */
if (!pending_resolvable || fvars.size() > get_app_num_args(e)) {
m_result_resolvable = false;
}
/* Otherwise m_result_resolvable was set by visit_mvar_app_args
(which visits and ANDs the args' resolvability). */
return result;
}
/* Leave the (possibly delayed) mvar in place, just visit args. */
/* Not delayed: unassigned mvar — not resolvable. */
m_result_resolvable = false;
return visit_mvar_app_args(e);
}
@@ -295,53 +295,95 @@ class instantiate_direct_fn {
if (auto r = get_assignment(mid)) {
return *r;
}
/* Not directly assigned. Check if delayed-assigned and pre-normalize. */
/* Not directly assigned. Check if delayed-assigned. */
option_ref<delayed_assignment> d = get_delayed_mvar_assignment(m_mctx, mid);
if (d) {
m_has_delayed = true;
name mid_pending(cnstr_get(d.get_val().raw(), 1), true);
normalize_delayed_pending(mid_pending);
/* Pre-normalize the pending value. Even though a bare mvar
can't resolve (no args for fvar subst), the same mid_pending
may appear in a delayed mvar app elsewhere that does have
enough args, so we must record its resolvability. */
bool saved = m_result_resolvable;
m_result_resolvable = true;
if (auto val = get_assignment(mid_pending)) {
if (m_result_resolvable)
m_resolvable_delayed.insert(mid_pending);
}
m_result_resolvable = saved;
}
return e; /* leave mvar in place */
/* Bare mvar after pass 1: not resolvable (no args for fvar subst). */
m_result_resolvable = false;
return e;
}
public:
instantiate_direct_fn(metavar_ctx & mctx):m_mctx(mctx), m_level_fn(mctx), m_has_delayed(false) {}
instantiate_direct_fn(metavar_ctx & mctx)
: m_mctx(mctx), m_level_fn(mctx), m_has_delayed(false), m_result_resolvable(true) {}
name_set const & resolvable_delayed() const { return m_resolvable_delayed; }
bool has_delayed() const { return m_has_delayed; }
expr visit(expr const & e) {
if (!has_mvar(e))
if (!has_mvar(e)) {
/* No mvars: trivially resolvable. AND with true is identity,
so don't touch m_result_resolvable — preserves sibling contributions. */
return e;
}
bool shared = false;
if (is_shared(e)) {
auto it = m_cache.find(e.raw());
if (it != m_cache.end()) {
return it->second;
/* AND cached resolvability with current state to preserve
sibling contributions (e.g., domain's resolvability when
visiting body of a Pi). */
m_result_resolvable = m_result_resolvable && it->second.resolvable;
return it->second.result;
}
shared = true;
}
/* Save and reset m_result_resolvable. Each subtree computes its own
resolvability independently. After all children are visited,
m_result_resolvable reflects the combined resolvability.
We restore the parent's flag via AND at the end. */
bool saved_resolvable = m_result_resolvable;
m_result_resolvable = true;
expr r;
switch (e.kind()) {
case expr_kind::BVar:
case expr_kind::Lit: case expr_kind::FVar:
lean_unreachable();
case expr_kind::Sort:
return cache(e, update_sort(e, visit_level(sort_level(e))), shared);
r = cache(e, update_sort(e, visit_level(sort_level(e))), shared);
break;
case expr_kind::Const:
return cache(e, update_const(e, visit_levels(const_levels(e))), shared);
r = cache(e, update_const(e, visit_levels(const_levels(e))), shared);
break;
case expr_kind::MVar:
return visit_mvar(e);
r = visit_mvar(e);
goto done; /* mvar results depend on mctx state, skip caching */
case expr_kind::MData:
return cache(e, update_mdata(e, visit(mdata_expr(e))), shared);
r = cache(e, update_mdata(e, visit(mdata_expr(e))), shared);
break;
case expr_kind::Proj:
return cache(e, update_proj(e, visit(proj_expr(e))), shared);
r = cache(e, update_proj(e, visit(proj_expr(e))), shared);
break;
case expr_kind::App:
return cache(e, visit_app(e), shared);
r = cache(e, visit_app(e), shared);
break;
case expr_kind::Pi: case expr_kind::Lambda:
return cache(e, update_binding(e, visit(binding_domain(e)), visit(binding_body(e))), shared);
r = cache(e, update_binding(e, visit(binding_domain(e)), visit(binding_body(e))), shared);
break;
case expr_kind::Let:
return cache(e, update_let(e, visit(let_type(e)), visit(let_value(e)), visit(let_body(e))), shared);
r = cache(e, update_let(e, visit(let_type(e)), visit(let_value(e)), visit(let_body(e))), shared);
break;
}
done:
/* Propagate: parent is resolvable only if this subtree is too. */
m_result_resolvable = saved_resolvable && m_result_resolvable;
return r;
}
expr operator()(expr const & e) { return visit(e); }
@@ -582,15 +624,15 @@ class instantiate_delayed_fn {
if (fvars.size() > get_app_num_args(e)) {
return visit_mvar_app_args(e);
}
/* Match standard instantiateMVars: only resolve the delayed assignment
when the pending value was determined to be resolvable by pass 1
(assigned and mvar-free after normalization). */
/* Only resolve the delayed assignment when pass 1 determined
the pending value is fully resolvable (assigned and all nested
delayed mvars also resolvable). This matches the original
instantiateMVars behavior. */
if (!m_resolvable_delayed.contains(mid_pending)) {
/* Still normalize the pending value for mctx write-back side effects.
The original instantiateMVars always normalizes the pending value
(via get_assignment(mid_pending)) even when it can't resolve.
Downstream code like MutualDef.mkInitialUsedFVarsMap reads stored
assignments and relies on inner delayed assignments being resolved. */
/* Still normalize the pending value for mctx write-back when
fvar_subst is empty. Downstream code (MutualDef.mkInitialUsedFVarsMap)
reads stored assignments and relies on inner delayed assignments
being resolved even when the outer one cannot be. */
if (fvar_subst_empty()) {
(void)get_assignment(mid_pending);
}