mirror of
https://github.com/leanprover/lean4.git
synced 2026-03-17 18:34:06 +00:00
perf: compute delayed assignment resolvability via post-pass fixpoint
Move resolvability computation from pass 1's inline tracking (which stores a bool alongside each cache entry, increasing entry size from 16 to 24 bytes) to a separate fixpoint step between pass 1 and pass 2. Pass 1 now only collects a lightweight head→pending mapping for delayed assignments encountered during traversal. After pass 1 completes, the fixpoint analyzes each pending value's normalized form to determine which delayed assignments are fully resolvable. This keeps pass 1's cache entries at their minimal size, avoiding the memory pressure that caused bv_decide regressions. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -120,38 +120,24 @@ public:
|
||||
|
||||
/* ============================================================================
|
||||
Pass 1: Resolve direct mvar assignments with write-back.
|
||||
For delayed assignments, pre-normalize the pending value but leave the
|
||||
delayed mvar application in the expression. Tracks resolvability of
|
||||
each subexpression (whether all remaining mvars are delayed-assigned
|
||||
with resolvable pending values) as a side product of the traversal,
|
||||
cached alongside the result to avoid redundant tree walks.
|
||||
For delayed assignments, pre-normalize the pending value (resolving its
|
||||
direct chains) but leave the delayed mvar application in the expression.
|
||||
Unassigned mvars are left in place.
|
||||
============================================================================ */
|
||||
|
||||
class instantiate_direct_fn {
|
||||
metavar_ctx & m_mctx;
|
||||
instantiate_lmvars_all_fn m_level_fn;
|
||||
name_set m_already_normalized;
|
||||
/* Set of delayed-assigned mvars whose pending value is assigned and
|
||||
fully resolvable after pass 1 normalization. Used by pass 2 as a
|
||||
guard: only resolve delayed assignments when the pending mvar is
|
||||
in this set, matching the original instantiateMVars behavior. */
|
||||
name_set m_resolvable_delayed;
|
||||
/* Tracks which normalized mvars have resolvable values, so that
|
||||
when get_assignment returns early via m_already_normalized, we
|
||||
can correctly set m_result_resolvable. */
|
||||
name_set m_normalized_resolvable;
|
||||
/* Set to true when any delayed assignment is encountered, even if not
|
||||
resolvable. Pass 2 is needed for write-back normalization in that case. */
|
||||
bool m_has_delayed;
|
||||
/* Mapping from delayed-assigned mvar head to its pending mvar name.
|
||||
Collected during traversal; used after pass 1 to compute the set of
|
||||
resolvable delayed assignments without adding per-entry overhead. */
|
||||
name_hash_map<name> m_delayed_head_to_pending;
|
||||
|
||||
/* After visit() returns, indicates whether the result would be mvar-free
|
||||
after full delayed-assignment resolution (pass 2). Updated via AND
|
||||
through the save/restore pattern: if any child is not resolvable,
|
||||
the parent is not resolvable. Cached alongside the result expression. */
|
||||
bool m_result_resolvable;
|
||||
|
||||
struct cache_entry { expr result; bool resolvable; };
|
||||
lean::unordered_map<lean_object *, cache_entry> m_cache;
|
||||
lean::unordered_map<lean_object *, expr> m_cache;
|
||||
std::vector<expr> m_saved;
|
||||
|
||||
level visit_level(level const & l) {
|
||||
@@ -167,35 +153,23 @@ class instantiate_direct_fn {
|
||||
|
||||
inline expr cache(expr const & e, expr r, bool shared) {
|
||||
if (shared) {
|
||||
m_cache.insert(mk_pair(e.raw(), cache_entry{r, m_result_resolvable}));
|
||||
m_cache.insert(mk_pair(e.raw(), r));
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
/* Get and normalize a direct mvar assignment. Write back the normalized value.
|
||||
Sets m_result_resolvable as a side effect: true if the normalized value
|
||||
would be mvar-free after full delayed-assignment resolution. */
|
||||
/* Get and normalize a direct mvar assignment. Write back the normalized value. */
|
||||
optional<expr> get_assignment(name const & mid) {
|
||||
option_ref<expr> r = get_mvar_assignment(m_mctx, mid);
|
||||
if (!r) {
|
||||
return optional<expr>();
|
||||
}
|
||||
expr a(r.get_val());
|
||||
if (!has_mvar(a)) {
|
||||
/* No mvars at all — trivially resolvable. m_result_resolvable stays true. */
|
||||
return optional<expr>(a);
|
||||
}
|
||||
if (m_already_normalized.contains(mid)) {
|
||||
/* Use cached resolvability from the first normalization visit. */
|
||||
if (!m_normalized_resolvable.contains(mid))
|
||||
m_result_resolvable = false;
|
||||
if (!has_mvar(a) || m_already_normalized.contains(mid)) {
|
||||
return optional<expr>(a);
|
||||
}
|
||||
m_already_normalized.insert(mid);
|
||||
expr a_new = visit(a);
|
||||
/* m_result_resolvable was set by visit(a). Cache it. */
|
||||
if (m_result_resolvable)
|
||||
m_normalized_resolvable.insert(mid);
|
||||
if (!is_eqp(a, a_new)) {
|
||||
m_saved.push_back(a);
|
||||
assign_mvar(m_mctx, mid, a_new);
|
||||
@@ -252,41 +226,13 @@ class instantiate_direct_fn {
|
||||
option_ref<delayed_assignment> d = get_delayed_mvar_assignment(m_mctx, mid);
|
||||
if (d) {
|
||||
m_has_delayed = true;
|
||||
array_ref<expr> fvars(cnstr_get(d.get_val().raw(), 0), true);
|
||||
name mid_pending(cnstr_get(d.get_val().raw(), 1), true);
|
||||
|
||||
/* Pre-normalize the pending value and check resolvability.
|
||||
Save/restore m_result_resolvable so the pending check
|
||||
doesn't pollute the outer resolvability tracking.
|
||||
m_result_resolvable (set by get_assignment's visit) correctly
|
||||
tracks whether all delayed mvars in the value are themselves
|
||||
resolvable — no has_expr_mvar check needed. */
|
||||
bool saved = m_result_resolvable;
|
||||
m_result_resolvable = true;
|
||||
bool pending_resolvable = false;
|
||||
if (auto val = get_assignment(mid_pending)) {
|
||||
pending_resolvable = m_result_resolvable;
|
||||
}
|
||||
m_result_resolvable = saved;
|
||||
|
||||
if (pending_resolvable) {
|
||||
m_resolvable_delayed.insert(mid_pending);
|
||||
}
|
||||
|
||||
/* Visit args — their resolvability contributes to the parent. */
|
||||
expr result = visit_mvar_app_args(e);
|
||||
|
||||
/* The delayed mvar app is resolvable only if: pending is resolvable,
|
||||
sufficient args for fvar substitution, and all args are resolvable. */
|
||||
if (!pending_resolvable || fvars.size() > get_app_num_args(e)) {
|
||||
m_result_resolvable = false;
|
||||
}
|
||||
/* Otherwise m_result_resolvable was set by visit_mvar_app_args
|
||||
(which visits and ANDs the args' resolvability). */
|
||||
return result;
|
||||
m_delayed_head_to_pending.insert(mk_pair(mid, mid_pending));
|
||||
/* Pre-normalize the pending value so pass 2 finds it ready. */
|
||||
(void)get_assignment(mid_pending);
|
||||
return visit_mvar_app_args(e);
|
||||
}
|
||||
/* Not delayed: unassigned mvar — not resolvable. */
|
||||
m_result_resolvable = false;
|
||||
/* Not delayed: unassigned mvar. */
|
||||
return visit_mvar_app_args(e);
|
||||
}
|
||||
|
||||
@@ -300,95 +246,227 @@ class instantiate_direct_fn {
|
||||
if (d) {
|
||||
m_has_delayed = true;
|
||||
name mid_pending(cnstr_get(d.get_val().raw(), 1), true);
|
||||
|
||||
/* Pre-normalize the pending value. Even though a bare mvar
|
||||
can't resolve (no args for fvar subst), the same mid_pending
|
||||
may appear in a delayed mvar app elsewhere that does have
|
||||
enough args, so we must record its resolvability. */
|
||||
bool saved = m_result_resolvable;
|
||||
m_result_resolvable = true;
|
||||
if (auto val = get_assignment(mid_pending)) {
|
||||
if (m_result_resolvable)
|
||||
m_resolvable_delayed.insert(mid_pending);
|
||||
}
|
||||
m_result_resolvable = saved;
|
||||
m_delayed_head_to_pending.insert(mk_pair(mid, mid_pending));
|
||||
/* Pre-normalize the pending value so pass 2 finds it ready. */
|
||||
(void)get_assignment(mid_pending);
|
||||
}
|
||||
/* Bare mvar after pass 1: not resolvable (no args for fvar subst). */
|
||||
m_result_resolvable = false;
|
||||
return e;
|
||||
}
|
||||
|
||||
public:
|
||||
instantiate_direct_fn(metavar_ctx & mctx)
|
||||
: m_mctx(mctx), m_level_fn(mctx), m_has_delayed(false), m_result_resolvable(true) {}
|
||||
name_set const & resolvable_delayed() const { return m_resolvable_delayed; }
|
||||
: m_mctx(mctx), m_level_fn(mctx), m_has_delayed(false) {}
|
||||
bool has_delayed() const { return m_has_delayed; }
|
||||
name_hash_map<name> const & head_to_pending() const { return m_delayed_head_to_pending; }
|
||||
|
||||
expr visit(expr const & e) {
|
||||
if (!has_mvar(e)) {
|
||||
/* No mvars: trivially resolvable. AND with true is identity,
|
||||
so don't touch m_result_resolvable — preserves sibling contributions. */
|
||||
if (!has_mvar(e))
|
||||
return e;
|
||||
}
|
||||
bool shared = false;
|
||||
if (is_shared(e)) {
|
||||
auto it = m_cache.find(e.raw());
|
||||
if (it != m_cache.end()) {
|
||||
/* AND cached resolvability with current state to preserve
|
||||
sibling contributions (e.g., domain's resolvability when
|
||||
visiting body of a Pi). */
|
||||
m_result_resolvable = m_result_resolvable && it->second.resolvable;
|
||||
return it->second.result;
|
||||
return it->second;
|
||||
}
|
||||
shared = true;
|
||||
}
|
||||
|
||||
/* Save and reset m_result_resolvable. Each subtree computes its own
|
||||
resolvability independently. After all children are visited,
|
||||
m_result_resolvable reflects the combined resolvability.
|
||||
We restore the parent's flag via AND at the end. */
|
||||
bool saved_resolvable = m_result_resolvable;
|
||||
m_result_resolvable = true;
|
||||
|
||||
expr r;
|
||||
switch (e.kind()) {
|
||||
case expr_kind::BVar:
|
||||
case expr_kind::Lit: case expr_kind::FVar:
|
||||
lean_unreachable();
|
||||
case expr_kind::Sort:
|
||||
r = cache(e, update_sort(e, visit_level(sort_level(e))), shared);
|
||||
break;
|
||||
return cache(e, update_sort(e, visit_level(sort_level(e))), shared);
|
||||
case expr_kind::Const:
|
||||
r = cache(e, update_const(e, visit_levels(const_levels(e))), shared);
|
||||
break;
|
||||
return cache(e, update_const(e, visit_levels(const_levels(e))), shared);
|
||||
case expr_kind::MVar:
|
||||
r = visit_mvar(e);
|
||||
goto done; /* mvar results depend on mctx state, skip caching */
|
||||
return visit_mvar(e);
|
||||
case expr_kind::MData:
|
||||
r = cache(e, update_mdata(e, visit(mdata_expr(e))), shared);
|
||||
break;
|
||||
return cache(e, update_mdata(e, visit(mdata_expr(e))), shared);
|
||||
case expr_kind::Proj:
|
||||
r = cache(e, update_proj(e, visit(proj_expr(e))), shared);
|
||||
break;
|
||||
return cache(e, update_proj(e, visit(proj_expr(e))), shared);
|
||||
case expr_kind::App:
|
||||
r = cache(e, visit_app(e), shared);
|
||||
break;
|
||||
return cache(e, visit_app(e), shared);
|
||||
case expr_kind::Pi: case expr_kind::Lambda:
|
||||
r = cache(e, update_binding(e, visit(binding_domain(e)), visit(binding_body(e))), shared);
|
||||
break;
|
||||
return cache(e, update_binding(e, visit(binding_domain(e)), visit(binding_body(e))), shared);
|
||||
case expr_kind::Let:
|
||||
r = cache(e, update_let(e, visit(let_type(e)), visit(let_value(e)), visit(let_body(e))), shared);
|
||||
break;
|
||||
return cache(e, update_let(e, visit(let_type(e)), visit(let_value(e)), visit(let_body(e))), shared);
|
||||
}
|
||||
done:
|
||||
/* Propagate: parent is resolvable only if this subtree is too. */
|
||||
m_result_resolvable = saved_resolvable && m_result_resolvable;
|
||||
return r;
|
||||
}
|
||||
|
||||
expr operator()(expr const & e) { return visit(e); }
|
||||
};
|
||||
|
||||
/* ============================================================================
   Resolvability computation (between pass 1 and pass 2).
   Determines which delayed assignments can be fully resolved by pass 2.
   A pending mvar is resolvable if:
   1. It is directly assigned, AND
   2. Its assigned value (normalized by pass 1) has no remaining mvars,
      OR every remaining mvar is a delayed-assigned head in application
      position that is applied to at least as many arguments as its delayed
      assignment has fvars, and whose own pending value is resolvable.
   Uses fixpoint iteration over the dependency graph.
   ============================================================================ */
|
||||
|
||||
struct pending_info {
|
||||
bool assigned;
|
||||
bool trivially_resolvable; /* assigned and has_expr_mvar is false */
|
||||
bool has_unresolvable_mvar; /* has a mvar that can never be resolved by pass 2 */
|
||||
name_set delayed_deps; /* delayed head names appearing with enough args */
|
||||
};
|
||||
|
||||
/* Analyze an expression for mvar occurrences that affect resolvability.
|
||||
Sets info.has_unresolvable_mvar if any mvar is found that pass 2 cannot resolve.
|
||||
Adds delayed head names to info.delayed_deps for mvars in app position with
|
||||
enough arguments (these might be resolvable depending on the fixpoint). */
|
||||
static void analyze_pending_mvars(
|
||||
expr const & e,
|
||||
name_hash_map<name> const & head_to_pending,
|
||||
metavar_ctx & mctx,
|
||||
pending_info & info)
|
||||
{
|
||||
if (!has_expr_mvar(e) || info.has_unresolvable_mvar) return;
|
||||
|
||||
switch (e.kind()) {
|
||||
case expr_kind::MVar:
|
||||
/* Bare mvar — pass 2's visit_mvar only checks direct assignments,
|
||||
which pass 1 already resolved. So this mvar is stuck. */
|
||||
info.has_unresolvable_mvar = true;
|
||||
return;
|
||||
case expr_kind::App: {
|
||||
expr const & f = get_app_fn(e);
|
||||
if (is_mvar(f)) {
|
||||
name const & mid = mvar_name(f);
|
||||
/* Check if this is a known delayed head. */
|
||||
auto it = head_to_pending.find(mid);
|
||||
if (it == head_to_pending.end()) {
|
||||
/* Not a delayed head — unassigned mvar left by pass 1. */
|
||||
info.has_unresolvable_mvar = true;
|
||||
return;
|
||||
}
|
||||
/* Check arg count against fvar count. */
|
||||
option_ref<delayed_assignment> d = get_delayed_mvar_assignment(mctx, mid);
|
||||
if (!d) {
|
||||
info.has_unresolvable_mvar = true;
|
||||
return;
|
||||
}
|
||||
array_ref<expr> fvars(cnstr_get(d.get_val().raw(), 0), true);
|
||||
if (fvars.size() > get_app_num_args(e)) {
|
||||
/* Not enough args — pass 2 can't resolve this. */
|
||||
info.has_unresolvable_mvar = true;
|
||||
return;
|
||||
}
|
||||
/* Record this as a dependency on the delayed head. */
|
||||
info.delayed_deps.insert(mid);
|
||||
/* Also check the arguments for unresolvable mvars. */
|
||||
expr const * curr = &e;
|
||||
while (is_app(*curr)) {
|
||||
analyze_pending_mvars(app_arg(*curr), head_to_pending, mctx, info);
|
||||
if (info.has_unresolvable_mvar) return;
|
||||
curr = &app_fn(*curr);
|
||||
}
|
||||
return;
|
||||
}
|
||||
/* Non-mvar app head — recurse normally. */
|
||||
analyze_pending_mvars(app_fn(e), head_to_pending, mctx, info);
|
||||
if (info.has_unresolvable_mvar) return;
|
||||
analyze_pending_mvars(app_arg(e), head_to_pending, mctx, info);
|
||||
return;
|
||||
}
|
||||
case expr_kind::Lambda: case expr_kind::Pi:
|
||||
analyze_pending_mvars(binding_domain(e), head_to_pending, mctx, info);
|
||||
if (info.has_unresolvable_mvar) return;
|
||||
analyze_pending_mvars(binding_body(e), head_to_pending, mctx, info);
|
||||
return;
|
||||
case expr_kind::Let:
|
||||
analyze_pending_mvars(let_type(e), head_to_pending, mctx, info);
|
||||
if (info.has_unresolvable_mvar) return;
|
||||
analyze_pending_mvars(let_value(e), head_to_pending, mctx, info);
|
||||
if (info.has_unresolvable_mvar) return;
|
||||
analyze_pending_mvars(let_body(e), head_to_pending, mctx, info);
|
||||
return;
|
||||
case expr_kind::MData:
|
||||
analyze_pending_mvars(mdata_expr(e), head_to_pending, mctx, info);
|
||||
return;
|
||||
case expr_kind::Proj:
|
||||
analyze_pending_mvars(proj_expr(e), head_to_pending, mctx, info);
|
||||
return;
|
||||
default:
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/* Compute the set of pending mvars whose delayed assignments pass 2 may
   resolve.

   `head_to_pending` maps each delayed-assigned head mvar to its pending mvar.
   A pending mvar ends up in the result when it is directly assigned, its
   value contains no mvar flagged as unresolvable by analyze_pending_mvars,
   and the pending mvar of every delayed head it depends on is itself in the
   result. The set is grown from the mvar-free values until nothing changes,
   so cyclic dependencies are never marked resolvable.

   Returns an empty set when `head_to_pending` is empty. */
static name_set compute_resolvable_delayed(
    metavar_ctx & mctx,
    name_hash_map<name> const & head_to_pending)
{
    name_set resolvable;
    if (head_to_pending.empty())
        return resolvable;

    /* Step 1: summarise each distinct pending mvar exactly once. */
    name_hash_map<pending_info> infos;
    for (auto const & entry : head_to_pending) {
        name const & pending = entry.second;
        if (infos.find(pending) != infos.end())
            continue; /* several heads may share one pending mvar */

        pending_info summary = {false, false, false, {}};
        option_ref<expr> assignment = get_mvar_assignment(mctx, pending);
        if (assignment) {
            summary.assigned = true;
            expr val(assignment.get_val());
            if (has_expr_mvar(val)) {
                analyze_pending_mvars(val, head_to_pending, mctx, summary);
            } else {
                summary.trivially_resolvable = true;
            }
        }
        infos[pending] = std::move(summary);
    }

    /* Step 2a: seed with the values that contain no mvars at all. */
    for (auto const & entry : infos) {
        if (entry.second.trivially_resolvable)
            resolvable.insert(entry.first);
    }

    /* True iff every delayed head `info` depends on has a pending mvar that
       is already known to be resolvable. */
    auto deps_ready = [&](pending_info & info) {
        bool ready = true;
        info.delayed_deps.for_each([&](name const & dep_head) {
            if (!ready)
                return;
            auto dep = head_to_pending.find(dep_head);
            ready = dep != head_to_pending.end() && resolvable.contains(dep->second);
        });
        return ready;
    };

    /* Step 2b: keep sweeping until a full round adds nothing new. */
    bool progressed;
    do {
        progressed = false;
        for (auto const & entry : head_to_pending) {
            name const & pending = entry.second;
            if (resolvable.contains(pending))
                continue;
            auto found = infos.find(pending);
            if (found == infos.end())
                continue;
            pending_info & info = found->second;
            if (info.assigned && !info.has_unresolvable_mvar && deps_ready(info)) {
                resolvable.insert(pending);
                progressed = true;
            }
        }
    } while (progressed);

    return resolvable;
}
|
||||
|
||||
/* ============================================================================
|
||||
Pass 2: Resolve delayed assignments with fused fvar substitution.
|
||||
Direct mvar chains have been pre-resolved by pass 1.
|
||||
@@ -624,10 +702,10 @@ class instantiate_delayed_fn {
|
||||
if (fvars.size() > get_app_num_args(e)) {
|
||||
return visit_mvar_app_args(e);
|
||||
}
|
||||
/* Only resolve the delayed assignment when pass 1 determined
|
||||
the pending value is fully resolvable (assigned and all nested
|
||||
delayed mvars also resolvable). This matches the original
|
||||
instantiateMVars behavior. */
|
||||
/* Only resolve the delayed assignment when the resolvability
|
||||
computation determined the pending value is fully resolvable
|
||||
(assigned and all nested delayed mvars also resolvable). This
|
||||
matches the original instantiateMVars behavior. */
|
||||
if (!m_resolvable_delayed.contains(mid_pending)) {
|
||||
/* Still normalize the pending value for mctx write-back when
|
||||
fvar_subst is empty. Downstream code (MutualDef.mkInitialUsedFVarsMap)
|
||||
@@ -783,7 +861,8 @@ static object * run_instantiate_all(object * m, object * e) {
|
||||
if (!pass1.has_delayed()) {
|
||||
e2 = e1;
|
||||
} else {
|
||||
instantiate_delayed_fn pass2(mctx, pass1.resolvable_delayed());
|
||||
name_set resolvable = compute_resolvable_delayed(mctx, pass1.head_to_pending());
|
||||
instantiate_delayed_fn pass2(mctx, resolvable);
|
||||
e2 = pass2(e1);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user