mirror of
https://github.com/leanprover/lean4.git
synced 2026-03-17 18:34:06 +00:00
perf: revert to fused pass 2, keep cached resolvability_checker
This reverts commit 93f75deaf3's pass 2 changes (replace_fvars approach) while keeping its cached resolvability_checker. The fused fvar substitution in pass 2 is subquadratic on the delayed assignment benchmark (74ms vs 2116ms at n=500), but causes OOM on the larger elab_bench/bv_decide_rewriter test due to sharing loss when the same subexpression appears in different fvar_subst contexts. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -26,11 +26,11 @@ Between passes, a `resolvability_checker` determines which delayed assignments c
|
||||
be fully resolved (assigned, mvar-free after resolution, sufficient arguments).
|
||||
|
||||
Pass 2 (`instantiate_delayed_fn`):
|
||||
Resolves delayed assignments using `replace_fvars`, matching the original
|
||||
`instantiateMVars` approach. Pending values are visited once (cached/written
|
||||
back), then fvar substitution is applied mechanically per-site. This preserves
|
||||
sharing of the visited pending value across multiple occurrences of the same
|
||||
delayed mvar.
|
||||
Fused traversal that resolves delayed assignments by carrying a fvar substitution.
|
||||
Since pass 1 has pre-normalized all direct chains, each pending value is compact
|
||||
and visited once, avoiding the O(n³) sharing loss that occurs when the fused
|
||||
approach must also chase direct chains. Unassigned mvars are left as-is (matching
|
||||
the original `instantiateMVars` behavior).
|
||||
*/
|
||||
|
||||
namespace lean {
|
||||
@@ -42,9 +42,6 @@ extern "C" object * lean_assign_mvar(obj_arg mctx, obj_arg mid, obj_arg val);
|
||||
typedef object_ref metavar_ctx;
|
||||
typedef object_ref delayed_assignment;
|
||||
|
||||
/* Forward declaration — defined in instantiate_mvars.cpp */
|
||||
expr replace_fvars(expr const & e, array_ref<expr> const & fvars, expr const * rev_args);
|
||||
|
||||
static void assign_lmvar(metavar_ctx & mctx, name const & mid, level const & l) {
|
||||
object * r = lean_assign_lmvar(mctx.steal(), mid.to_obj_arg(), l.to_obj_arg());
|
||||
mctx.set_box(r);
|
||||
@@ -418,45 +415,126 @@ public:
|
||||
};
|
||||
|
||||
/* ============================================================================
|
||||
Pass 2: Resolve delayed assignments using replace_fvars.
|
||||
Direct mvar chains have been pre-resolved by pass 1, and the resolvability
|
||||
checker has determined which delayed assignments can be fully resolved.
|
||||
Pass 2: Resolve delayed assignments with fused fvar substitution.
|
||||
Direct mvar chains have been pre-resolved by pass 1.
|
||||
|
||||
Uses the same replace_fvars approach as the original instantiateMVars:
|
||||
pending values are visited once (cached/written back), then fvar substitution
|
||||
is applied mechanically per-site via replace_fvars. This preserves sharing
|
||||
of the visited pending value across multiple occurrences of the same delayed
|
||||
mvar.
|
||||
Uses a flat (ptr, depth)-keyed cache with generation-based staleness.
|
||||
Each visit_delayed scope gets a unique generation number; cache entries
|
||||
record the scope level and generation at insertion. Validity is O(1):
|
||||
entry valid iff (level == 0 || level == m_scope) && m_scope_gens[level] == entry.scope_gen.
|
||||
============================================================================ */
|
||||
|
||||
/* One entry of the fvar substitution carried by pass 2. Entries are stored in
   `m_fvar_subst`, keyed by fvar name, and are installed as
   `{m_depth, m_scope, <argument>}` when a delayed assignment is entered. */
struct fvar_subst_entry {
|
||||
/* Value of `m_depth` when the entry was installed; `lookup_fvar` lifts
   `value` by the difference to the current `m_depth`. */
unsigned depth;
|
||||
/* Value of `m_scope` when the entry was installed; `lookup_fvar` folds it
   into `m_result_scope` (via max). */
unsigned scope;
|
||||
/* Expression substituted for the fvar. */
expr value;
|
||||
};
|
||||
|
||||
class instantiate_delayed_fn {
|
||||
/* Hash functor for the (object pointer, unsigned) keys of `flat_cache`. */
struct key_hasher {
|
||||
std::size_t operator()(std::pair<lean_object *, unsigned> const & p) const {
|
||||
/* The pointer is shifted right by 3 bits before being combined with the
   second component — presumably because object addresses are 8-byte
   aligned, so the low bits carry no information (TODO confirm). */
return hash((size_t)p.first >> 3, p.second);
|
||||
}
|
||||
};
|
||||
|
||||
struct cache_entry { expr result; unsigned scope_level; unsigned scope_gen; };
|
||||
|
||||
typedef lean::unordered_map<std::pair<lean_object *, unsigned>, cache_entry, key_hasher> flat_cache;
|
||||
|
||||
metavar_ctx & m_mctx;
|
||||
name_set const & m_resolvable_delayed;
|
||||
lean::unordered_map<lean_object *, expr> m_cache;
|
||||
name_hash_map<fvar_subst_entry> m_fvar_subst;
|
||||
unsigned m_depth;
|
||||
|
||||
/* Single flat cache with generation-based staleness detection. */
|
||||
flat_cache m_cache;
|
||||
std::vector<unsigned> m_scope_gens; /* m_scope_gens[level] = generation */
|
||||
unsigned m_gen_counter;
|
||||
unsigned m_scope;
|
||||
|
||||
/* After visit() returns, this holds the maximum fvar-substitution
|
||||
scope that contributed to the result — i.e., the outermost scope at which the
|
||||
result is valid and can be cached. Updated monotonically (via max) through
|
||||
the save/reset/restore pattern in visit(). */
|
||||
unsigned m_result_scope;
|
||||
|
||||
/* Global cache for fvar-free expressions — scope-independent. */
|
||||
lean::unordered_map<lean_object *, expr> m_global_cache;
|
||||
|
||||
/* Write-back support: when fvar_subst is empty, normalize and write back
|
||||
mvar assignments to match the original instantiateMVars mctx side effects.
|
||||
Downstream code (e.g. MutualDef.mkInitialUsedFVarsMap) reads stored
|
||||
assignments and expects them to be normalized. */
|
||||
name_set m_already_normalized;
|
||||
std::vector<expr> m_saved;
|
||||
|
||||
/* Get a direct mvar assignment. Visit it to resolve inner delayed mvars.
|
||||
Normalize and write back the result to the mctx. This matches the
|
||||
original instantiateMVars behavior: downstream code (e.g.
|
||||
MutualDef.mkInitialUsedFVarsMap) reads stored assignments and expects
|
||||
inner delayed assignments to be resolved. */
|
||||
/* Returns true iff `m_fvar_subst` currently holds no entries. */
bool fvar_subst_empty() const {
|
||||
return m_fvar_subst.empty();
|
||||
}
|
||||
|
||||
/* Look up `fid` in the current fvar substitution.
   Returns an empty optional when `fid` has no entry. Otherwise returns the
   entry's value, lifted by the number of binders entered since the entry was
   installed (`m_depth - entry.depth`), and records the entry's scope in
   `m_result_scope` (via max). */
optional<expr> lookup_fvar(name const & fid) {
|
||||
auto it = m_fvar_subst.find(fid);
|
||||
if (it == m_fvar_subst.end())
|
||||
return optional<expr>();
|
||||
/* The result now depends on the scope that installed this entry. */
m_result_scope = std::max(m_result_scope, it->second.scope);
|
||||
/* Binder-depth difference between the installation site and here. */
unsigned d = m_depth - it->second.depth;
|
||||
if (d == 0)
|
||||
return optional<expr>(it->second.value);
|
||||
/* Lift the value's loose bvars by `d`. */
return optional<expr>(lift_loose_bvars(it->second.value, d));
|
||||
}
|
||||
|
||||
/* Cache lookup — O(1) with generation-based staleness check.
|
||||
An entry at scope_level 0 (no fvar dependency) is valid at any scope.
|
||||
An entry at scope_level > 0 is only valid at exactly that scope level,
|
||||
because an inner scope may shadow the fvars it depends on. */
|
||||
/* Look up `ptr` in `m_cache` under the key (ptr, current m_depth).
   Returns the cached result only if the entry was recorded at scope level 0
   or at the current `m_scope`, and the generation stored for that level in
   `m_scope_gens` still equals the generation recorded in the entry.
   On a hit, the entry's scope level is folded into `m_result_scope` (via max).
   Returns an empty optional on a miss or a rejected entry. */
optional<expr> cache_lookup(lean_object * ptr) {
|
||||
auto key = mk_pair(ptr, m_depth);
|
||||
auto it = m_cache.find(key);
|
||||
if (it == m_cache.end()) return {};
|
||||
auto & entry = it->second;
|
||||
/* Scope-level check first, then the generation check for that level. */
if ((entry.scope_level == 0 || entry.scope_level == m_scope) &&
|
||||
m_scope_gens[entry.scope_level] == entry.scope_gen) {
|
||||
/* Reusing the entry makes the caller's result depend on its scope too. */
m_result_scope = std::max(m_result_scope, entry.scope_level);
|
||||
return optional<expr>(entry.result);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
/* Store `result` in `m_cache` under the key (ptr, current m_depth),
   overwriting any existing entry for that key. The entry is tagged with the
   current `m_result_scope` and the generation currently recorded for that
   scope level in `m_scope_gens`. */
void cache_insert(lean_object * ptr, expr const & result) {
|
||||
auto key = mk_pair(ptr, m_depth);
|
||||
m_cache[key] = { result, m_result_scope, m_scope_gens[m_result_scope] };
|
||||
}
|
||||
|
||||
/* Get a direct mvar assignment. Visit it to resolve delayed mvars
|
||||
and apply the fvar substitution.
|
||||
When fvar_subst is empty, normalize and write back the result to
|
||||
the mctx. This matches the original instantiateMVars behavior:
|
||||
downstream code (e.g. MutualDef.mkInitialUsedFVarsMap) reads stored
|
||||
assignments and expects inner delayed assignments to be resolved.
|
||||
When fvar_subst is non-empty, no write-back (values contain
|
||||
fvar-substituted terms not suitable for the mctx). */
|
||||
optional<expr> get_assignment(name const & mid) {
|
||||
option_ref<expr> r = get_mvar_assignment(m_mctx, mid);
|
||||
if (!r)
|
||||
return optional<expr>();
|
||||
expr a(r.get_val());
|
||||
if (!has_mvar(a))
|
||||
return optional<expr>(a);
|
||||
if (m_already_normalized.contains(mid))
|
||||
return optional<expr>(a);
|
||||
m_already_normalized.insert(mid);
|
||||
expr a_new = visit(a);
|
||||
if (!is_eqp(a, a_new)) {
|
||||
m_saved.push_back(a);
|
||||
assign_mvar(m_mctx, mid, a_new);
|
||||
if (fvar_subst_empty()) {
|
||||
if (!has_mvar(a))
|
||||
return optional<expr>(a);
|
||||
if (m_already_normalized.contains(mid))
|
||||
return optional<expr>(a);
|
||||
m_already_normalized.insert(mid);
|
||||
expr a_new = visit(a);
|
||||
if (!is_eqp(a, a_new)) {
|
||||
m_saved.push_back(a);
|
||||
assign_mvar(m_mctx, mid, a_new);
|
||||
}
|
||||
return optional<expr>(a_new);
|
||||
} else {
|
||||
if (!has_mvar(a) && !has_fvar(a))
|
||||
return optional<expr>(a);
|
||||
return optional<expr>(visit(a));
|
||||
}
|
||||
return optional<expr>(a_new);
|
||||
}
|
||||
|
||||
expr visit_app_default(expr const & e) {
|
||||
@@ -500,20 +578,54 @@ class instantiate_delayed_fn {
|
||||
args.push_back(visit(app_arg(*curr)));
|
||||
curr = &app_fn(*curr);
|
||||
}
|
||||
/* Get and visit the pending value (resolving inner delayed assignments).
|
||||
Uses get_assignment for write-back normalization. */
|
||||
optional<expr> val = get_assignment(mid_pending);
|
||||
lean_assert(val); /* resolvability checker verified this is assigned */
|
||||
/* Replace the delayed assignment's fvars with the visited arguments,
|
||||
matching the original instantiateMVars approach. */
|
||||
|
||||
size_t fvar_count = fvars.size();
|
||||
expr val_new = replace_fvars(*val, fvars, args.data() + (args.size() - fvar_count));
|
||||
/* Use apply_beta for extra args: the fvar-substituted pending value may
|
||||
be a lambda (e.g. from assertAfter), and extra args should be beta-
|
||||
reduced into it rather than left as a redex. */
|
||||
size_t extra_count = args.size() - fvar_count;
|
||||
|
||||
/* Save and extend the fvar substitution. */
|
||||
struct saved_entry { name key; bool had_old; fvar_subst_entry old; };
|
||||
std::vector<saved_entry> saved_entries;
|
||||
saved_entries.reserve(fvar_count);
|
||||
m_scope++;
|
||||
for (size_t i = 0; i < fvar_count; i++) {
|
||||
name const & fid = fvar_name(fvars[i]);
|
||||
auto old_it = m_fvar_subst.find(fid);
|
||||
if (old_it != m_fvar_subst.end()) {
|
||||
saved_entries.push_back({fid, true, old_it->second});
|
||||
} else {
|
||||
saved_entries.push_back({fid, false, {0, 0, expr()}});
|
||||
}
|
||||
m_fvar_subst[fid] = {m_depth, m_scope, args[args.size() - 1 - i]};
|
||||
}
|
||||
|
||||
/* Push: bump generation so stale entries at this scope level are detected. */
|
||||
m_gen_counter++;
|
||||
if (m_scope >= m_scope_gens.size())
|
||||
m_scope_gens.push_back(m_gen_counter);
|
||||
else
|
||||
m_scope_gens[m_scope] = m_gen_counter;
|
||||
|
||||
expr val_new = visit(mk_mvar(mid_pending));
|
||||
|
||||
/* Pop: just decrement scope — stale entries are detected by generation mismatch. */
|
||||
m_scope--;
|
||||
|
||||
/* Restore the fvar substitution. */
|
||||
for (auto & se : saved_entries) {
|
||||
if (!se.had_old) {
|
||||
m_fvar_subst.erase(se.key);
|
||||
} else {
|
||||
m_fvar_subst[se.key] = se.old;
|
||||
}
|
||||
}
|
||||
|
||||
/* Use apply_beta instead of mk_rev_app: pass 1's beta-reduction may have
|
||||
changed delayed mvar arguments (e.g., substituting a bvar with a concrete
|
||||
value), so the resolved pending value may be a lambda that needs beta-
|
||||
reduction with the extra args, matching the original's behavior. */
|
||||
bool preserve_data = false;
|
||||
bool zeta = true;
|
||||
return apply_beta(val_new, args.size() - fvar_count, args.data(), preserve_data, zeta);
|
||||
return apply_beta(val_new, extra_count, args.data(), preserve_data, zeta);
|
||||
}
|
||||
|
||||
expr visit_app(expr const & e) {
|
||||
@@ -542,11 +654,13 @@ class instantiate_delayed_fn {
|
||||
(assigned and all nested delayed mvars also resolvable). This
|
||||
matches the original instantiateMVars behavior. */
|
||||
if (!m_resolvable_delayed.contains(mid_pending)) {
|
||||
/* Normalize the pending value for mctx write-back.
|
||||
Downstream code (MutualDef.mkInitialUsedFVarsMap) reads stored
|
||||
assignments and relies on inner delayed assignments being
|
||||
resolved even when the outer one cannot be. */
|
||||
(void)get_assignment(mid_pending);
|
||||
/* Still normalize the pending value for mctx write-back when
|
||||
fvar_subst is empty. Downstream code (MutualDef.mkInitialUsedFVarsMap)
|
||||
reads stored assignments and relies on inner delayed assignments
|
||||
being resolved even when the outer one cannot be. */
|
||||
if (fvar_subst_empty()) {
|
||||
(void)get_assignment(mid_pending);
|
||||
}
|
||||
return visit_mvar_app_args(e);
|
||||
}
|
||||
buffer<expr> args;
|
||||
@@ -561,49 +675,116 @@ class instantiate_delayed_fn {
|
||||
return e;
|
||||
}
|
||||
|
||||
inline expr cache(expr const & e, expr r, bool shared) {
|
||||
if (shared) {
|
||||
m_cache.insert(mk_pair(e.raw(), r));
|
||||
expr visit_fvar(expr const & e) {
|
||||
name const & fid = fvar_name(e);
|
||||
if (auto r = lookup_fvar(fid)) {
|
||||
return *r;
|
||||
}
|
||||
return r;
|
||||
return e;
|
||||
}
|
||||
|
||||
public:
|
||||
instantiate_delayed_fn(metavar_ctx & mctx, name_set const & resolvable_delayed)
|
||||
: m_mctx(mctx), m_resolvable_delayed(resolvable_delayed) {}
|
||||
: m_mctx(mctx), m_resolvable_delayed(resolvable_delayed),
|
||||
m_depth(0), m_gen_counter(0), m_scope(0), m_result_scope(0) {
|
||||
m_scope_gens.push_back(0); /* scope 0 has generation 0 */
|
||||
}
|
||||
|
||||
expr visit(expr const & e) {
|
||||
if (!has_mvar(e))
|
||||
return e;
|
||||
if (fvar_subst_empty()) {
|
||||
if (!has_mvar(e))
|
||||
return e;
|
||||
} else {
|
||||
if (!has_mvar(e) && !has_fvar(e))
|
||||
return e;
|
||||
}
|
||||
|
||||
bool use_global = !has_fvar(e) && !has_expr_mvar(e);
|
||||
bool shared = false;
|
||||
if (is_shared(e)) {
|
||||
auto it = m_cache.find(e.raw());
|
||||
if (it != m_cache.end()) {
|
||||
return it->second;
|
||||
if (use_global) {
|
||||
auto it = m_global_cache.find(e.raw());
|
||||
if (it != m_global_cache.end())
|
||||
return it->second;
|
||||
} else {
|
||||
if (auto r = cache_lookup(e.raw()))
|
||||
return *r;
|
||||
}
|
||||
shared = true;
|
||||
}
|
||||
|
||||
/* Save and reset the result scope for this subtree.
|
||||
After computing, cache_insert uses m_result_scope to place the entry
|
||||
at the outermost valid scope level. Then we restore the parent's
|
||||
watermark, taking the max with our contribution. */
|
||||
unsigned saved_result_scope = m_result_scope;
|
||||
m_result_scope = 0;
|
||||
|
||||
expr r;
|
||||
switch (e.kind()) {
|
||||
case expr_kind::BVar:
|
||||
case expr_kind::Lit: case expr_kind::FVar:
|
||||
case expr_kind::Lit:
|
||||
lean_unreachable();
|
||||
case expr_kind::Sort: case expr_kind::Const:
|
||||
/* Levels already resolved by pass 1. */
|
||||
return e;
|
||||
case expr_kind::FVar:
|
||||
r = visit_fvar(e);
|
||||
goto done; /* skip caching for fvars */
|
||||
case expr_kind::Sort:
|
||||
r = update_sort(e, visit_level(sort_level(e)));
|
||||
break;
|
||||
case expr_kind::Const:
|
||||
r = update_const(e, visit_levels(const_levels(e)));
|
||||
break;
|
||||
case expr_kind::MVar:
|
||||
return visit_mvar(e);
|
||||
r = visit_mvar(e);
|
||||
goto done; /* mvar results are not (ptr, depth)-cacheable */
|
||||
case expr_kind::MData:
|
||||
return cache(e, update_mdata(e, visit(mdata_expr(e))), shared);
|
||||
r = update_mdata(e, visit(mdata_expr(e)));
|
||||
break;
|
||||
case expr_kind::Proj:
|
||||
return cache(e, update_proj(e, visit(proj_expr(e))), shared);
|
||||
r = update_proj(e, visit(proj_expr(e)));
|
||||
break;
|
||||
case expr_kind::App:
|
||||
return cache(e, visit_app(e), shared);
|
||||
case expr_kind::Pi: case expr_kind::Lambda:
|
||||
return cache(e, update_binding(e, visit(binding_domain(e)), visit(binding_body(e))), shared);
|
||||
case expr_kind::Let:
|
||||
return cache(e, update_let(e, visit(let_type(e)), visit(let_value(e)), visit(let_body(e))), shared);
|
||||
r = visit_app(e);
|
||||
break;
|
||||
case expr_kind::Pi: case expr_kind::Lambda: {
|
||||
expr d = visit(binding_domain(e));
|
||||
m_depth++;
|
||||
expr b = visit(binding_body(e));
|
||||
m_depth--;
|
||||
r = update_binding(e, d, b);
|
||||
break;
|
||||
}
|
||||
case expr_kind::Let: {
|
||||
expr t = visit(let_type(e));
|
||||
expr v = visit(let_value(e));
|
||||
m_depth++;
|
||||
expr b = visit(let_body(e));
|
||||
m_depth--;
|
||||
r = update_let(e, t, v, b);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (shared) {
|
||||
if (use_global)
|
||||
m_global_cache.insert(mk_pair(e.raw(), r));
|
||||
else
|
||||
cache_insert(e.raw(), r);
|
||||
}
|
||||
|
||||
done:
|
||||
m_result_scope = std::max(saved_result_scope, m_result_scope);
|
||||
return r;
|
||||
}
|
||||
|
||||
/* Identity on levels: returns `l` unchanged. */
level visit_level(level const & l) {
|
||||
/* Pass 2 does not handle level mvars — pass 1 already resolved them.
|
||||
But visit() still calls this (and visit_levels) when rebuilding Sort/Const nodes via update_sort/update_const.
|
||||
Since levels have no fvars, we can just return them as-is. */
|
||||
return l;
|
||||
}
|
||||
|
||||
/* Identity on level lists: returns `ls` unchanged. Called by visit() when
   rebuilding a Const node via update_const. */
levels visit_levels(levels const & ls) {
|
||||
return ls;
|
||||
}
|
||||
|
||||
/* Function-call entry point: forwards to visit(e). */
expr operator()(expr const & e) { return visit(e); }
|
||||
@@ -620,7 +801,7 @@ static object * run_instantiate_all(object * m, object * e) {
|
||||
instantiate_direct_fn pass1(mctx);
|
||||
expr e1 = pass1(expr(e));
|
||||
|
||||
/* Pass 2: resolve delayed assignments using replace_fvars.
|
||||
/* Pass 2: resolve delayed assignments with fused fvar substitution.
|
||||
Skip if pass 1 found no delayed assignments at all — the expression
|
||||
has no delayed mvars that need resolution or write-back. */
|
||||
expr e2;
|
||||
|
||||
Reference in New Issue
Block a user