mirror of
https://github.com/leanprover/lean4.git
synced 2026-03-17 18:34:06 +00:00
Revert pass fusion, has perf issues
This commit is contained in:
@@ -14,20 +14,23 @@ Authors: Joachim Breitner
|
||||
#include "kernel/expr.h"
|
||||
|
||||
/*
|
||||
Fused single-pass instantiateMVars with improved sharing.
|
||||
This module provides a two-pass variant of `instantiateMVars` with improved sharing.
|
||||
|
||||
One class, two modes:
|
||||
- Outer mode (m_scope == 0, fvar_subst empty): resolves direct mvar assignments
|
||||
with write-back. When encountering a resolvable delayed mvar application,
|
||||
fires delayed resolution inline (switching to delayed mode).
|
||||
- Delayed mode (m_scope > 0, fvar_subst active): carries fvar substitution,
|
||||
resolves nested delayed mvars. No direct mvar resolution needed (outer mode
|
||||
already normalized all direct chains via write-back).
|
||||
Pass 1 (`instantiate_direct_fn`):
|
||||
Standard `instantiateMVars`-like traversal that resolves direct mvar assignments
|
||||
with write-back and a single persistent cache. For delayed assignments, it
|
||||
pre-normalizes the pending value (resolving its direct chain) but leaves the
|
||||
delayed mvar application in the expression. Unassigned mvars are left in place.
|
||||
|
||||
The shared m_cache allows nodes not behind delayed mvars to be traversed exactly
|
||||
once. In delayed mode, cache reads use the `use_global` condition (!has_fvar &&
|
||||
!has_expr_mvar) to safely read scope-independent results from m_cache. Cache
|
||||
writes go to m_cache when m_result_scope == 0, otherwise to the scoped cache.
|
||||
Pass 2 (`instantiate_delayed_fn`):
|
||||
Fused traversal that resolves delayed assignments by carrying a fvar substitution.
|
||||
Since pass 1 has pre-normalized all direct chains, each pending value is compact
|
||||
and visited once, avoiding the O(n³) sharing loss that occurs when the fused
|
||||
approach must also chase direct chains. Unassigned mvars are left as-is (matching
|
||||
the original `instantiateMVars` behavior).
|
||||
|
||||
The combination preserves sharing (O(n²) output for the pathological nested-delayed
|
||||
case) while avoiding the separate `replace_fvars` calls of the standard approach.
|
||||
*/
|
||||
|
||||
namespace lean {
|
||||
@@ -116,83 +119,54 @@ public:
|
||||
};
|
||||
|
||||
/* ============================================================================
|
||||
Fused instantiation: resolves direct and delayed mvar assignments in a
|
||||
single traversal. Outer mode handles direct assignments with write-back;
|
||||
delayed mode carries fvar substitution for delayed assignment resolution.
|
||||
Pass 1: Resolve direct mvar assignments with write-back.
|
||||
For delayed assignments, pre-normalize the pending value but leave the
|
||||
delayed mvar application in the expression.
|
||||
============================================================================ */
|
||||
|
||||
struct fvar_subst_entry {
|
||||
unsigned depth;
|
||||
unsigned scope;
|
||||
expr value;
|
||||
};
|
||||
|
||||
class instantiate_fused_fn {
|
||||
struct key_hasher {
|
||||
std::size_t operator()(std::pair<lean_object *, unsigned> const & p) const {
|
||||
return hash((size_t)p.first >> 3, p.second);
|
||||
}
|
||||
};
|
||||
|
||||
struct cache_entry { expr result; unsigned scope_level; unsigned scope_gen; };
|
||||
|
||||
typedef lean::unordered_map<std::pair<lean_object *, unsigned>, cache_entry, key_hasher> scoped_cache;
|
||||
|
||||
class instantiate_direct_fn {
|
||||
metavar_ctx & m_mctx;
|
||||
instantiate_lmvars_all_fn m_level_fn;
|
||||
|
||||
/* Direct mvar normalization tracking (outer mode). */
|
||||
name_set m_already_normalized;
|
||||
/* Set of delayed-assigned mvars whose pending value is assigned and
|
||||
mvar-free after normalization. Used by pass 2 as a guard: only resolve
|
||||
delayed assignments when the pending mvar is in this set, matching
|
||||
the original instantiateMVars behavior. */
|
||||
name_set m_resolvable_delayed;
|
||||
/* Set to true when any delayed assignment is encountered, even if not
|
||||
resolvable. Pass 2 is needed for write-back normalization in that case. */
|
||||
bool m_has_delayed;
|
||||
lean::unordered_map<lean_object *, expr> m_cache;
|
||||
std::vector<expr> m_saved;
|
||||
|
||||
/* Resolvable delayed mvars: pending value is assigned and mvar-free
|
||||
after full resolution. Built during outer-mode traversal. */
|
||||
name_set m_resolvable_delayed;
|
||||
|
||||
/* Shared cache: used in outer mode (ptr → expr) and readable in
|
||||
delayed mode for scope-independent expressions. */
|
||||
lean::unordered_map<lean_object *, expr> m_cache;
|
||||
|
||||
/* Scoped cache: (ptr, depth) → cache_entry with generation-based
|
||||
staleness detection. Used in delayed mode for scope-dependent results. */
|
||||
scoped_cache m_scoped_cache;
|
||||
|
||||
/* Fvar substitution state (delayed mode). */
|
||||
name_hash_map<fvar_subst_entry> m_fvar_subst;
|
||||
unsigned m_depth;
|
||||
|
||||
/* Scope tracking with generation-based staleness. */
|
||||
std::vector<unsigned> m_scope_gens;
|
||||
unsigned m_gen_counter;
|
||||
unsigned m_scope;
|
||||
|
||||
/* After visit() returns, holds the maximum fvar-substitution scope that
|
||||
contributed to the result. Used to decide cache placement. */
|
||||
unsigned m_result_scope;
|
||||
|
||||
bool in_delayed_mode() const { return m_scope > 0; }
|
||||
|
||||
level visit_level(level const & l) {
|
||||
if (in_delayed_mode()) return l; /* levels already resolved by outer mode */
|
||||
return m_level_fn(l);
|
||||
}
|
||||
|
||||
levels visit_levels(levels const & ls) {
|
||||
if (in_delayed_mode()) return ls;
|
||||
buffer<level> lsNew;
|
||||
for (auto const & l : ls)
|
||||
lsNew.push_back(visit_level(l));
|
||||
return levels(lsNew);
|
||||
}
|
||||
|
||||
/* ---- Direct mvar assignment (outer mode) ---- */
|
||||
inline expr cache(expr const & e, expr r, bool shared) {
|
||||
if (shared) {
|
||||
m_cache.insert(mk_pair(e.raw(), r));
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
optional<expr> get_direct_assignment(name const & mid) {
|
||||
/* Get and normalize a direct mvar assignment. Write back the normalized value. */
|
||||
optional<expr> get_assignment(name const & mid) {
|
||||
option_ref<expr> r = get_mvar_assignment(m_mctx, mid);
|
||||
if (!r) return optional<expr>();
|
||||
if (!r) {
|
||||
return optional<expr>();
|
||||
}
|
||||
expr a(r.get_val());
|
||||
if (!has_mvar(a) || m_already_normalized.contains(mid))
|
||||
if (!has_mvar(a) || m_already_normalized.contains(mid)) {
|
||||
return optional<expr>(a);
|
||||
}
|
||||
m_already_normalized.insert(mid);
|
||||
expr a_new = visit(a);
|
||||
if (!is_eqp(a, a_new)) {
|
||||
@@ -202,110 +176,6 @@ class instantiate_fused_fn {
|
||||
return optional<expr>(a_new);
|
||||
}
|
||||
|
||||
/* ---- Delayed-mode mvar assignment ---- */
|
||||
|
||||
/* In delayed mode, direct mvar chains are already resolved by outer-mode
|
||||
write-back. We just need to visit the value to apply fvar substitution
|
||||
and resolve nested delayed mvars. No write-back in delayed mode. */
|
||||
optional<expr> get_delayed_mode_assignment(name const & mid) {
|
||||
option_ref<expr> r = get_mvar_assignment(m_mctx, mid);
|
||||
if (!r) return optional<expr>();
|
||||
expr a(r.get_val());
|
||||
if (!has_mvar(a) && !has_fvar(a))
|
||||
return optional<expr>(a);
|
||||
return optional<expr>(visit(a));
|
||||
}
|
||||
|
||||
/* ---- Resolvability check ---- */
|
||||
|
||||
/* Check whether a normalized value would be mvar-free after full resolution.
|
||||
Uses m_resolvable_delayed to check inner delayed mvars. After outer-mode
|
||||
normalization, remaining mvars are either unassigned or delayed-assigned. */
|
||||
bool is_value_resolvable(expr const & e) {
|
||||
if (!has_expr_mvar(e)) return true;
|
||||
switch (e.kind()) {
|
||||
case expr_kind::BVar: case expr_kind::Lit: case expr_kind::FVar:
|
||||
case expr_kind::Sort: case expr_kind::Const:
|
||||
return true;
|
||||
case expr_kind::MVar:
|
||||
return false;
|
||||
case expr_kind::App: {
|
||||
expr const & f = get_app_fn(e);
|
||||
if (is_mvar(f)) {
|
||||
name const & mid = mvar_name(f);
|
||||
option_ref<delayed_assignment> d = get_delayed_mvar_assignment(m_mctx, mid);
|
||||
if (!d) return false;
|
||||
array_ref<expr> fvars(cnstr_get(d.get_val().raw(), 0), true);
|
||||
if (fvars.size() > get_app_num_args(e)) return false;
|
||||
name mid_pending(cnstr_get(d.get_val().raw(), 1), true);
|
||||
if (!m_resolvable_delayed.contains(mid_pending)) return false;
|
||||
expr const * curr = &e;
|
||||
while (is_app(*curr)) {
|
||||
if (!is_value_resolvable(app_arg(*curr))) return false;
|
||||
curr = &app_fn(*curr);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return is_value_resolvable(app_fn(e)) && is_value_resolvable(app_arg(e));
|
||||
}
|
||||
case expr_kind::Lambda: case expr_kind::Pi:
|
||||
return is_value_resolvable(binding_domain(e)) && is_value_resolvable(binding_body(e));
|
||||
case expr_kind::Let:
|
||||
return is_value_resolvable(let_type(e)) && is_value_resolvable(let_value(e))
|
||||
&& is_value_resolvable(let_body(e));
|
||||
case expr_kind::MData:
|
||||
return is_value_resolvable(mdata_expr(e));
|
||||
case expr_kind::Proj:
|
||||
return is_value_resolvable(proj_expr(e));
|
||||
}
|
||||
lean_unreachable();
|
||||
}
|
||||
|
||||
/* Pre-normalize a delayed assignment's pending value and record
|
||||
whether it is resolvable. Called during outer-mode traversal. */
|
||||
void normalize_delayed_pending(name const & mid_pending) {
|
||||
if (auto val = get_direct_assignment(mid_pending)) {
|
||||
if (is_value_resolvable(*val)) {
|
||||
m_resolvable_delayed.insert(mid_pending);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* ---- Fvar substitution (delayed mode) ---- */
|
||||
|
||||
optional<expr> lookup_fvar(name const & fid) {
|
||||
auto it = m_fvar_subst.find(fid);
|
||||
if (it == m_fvar_subst.end())
|
||||
return optional<expr>();
|
||||
m_result_scope = std::max(m_result_scope, it->second.scope);
|
||||
unsigned d = m_depth - it->second.depth;
|
||||
if (d == 0)
|
||||
return optional<expr>(it->second.value);
|
||||
return optional<expr>(lift_loose_bvars(it->second.value, d));
|
||||
}
|
||||
|
||||
/* ---- Scoped cache (delayed mode) ---- */
|
||||
|
||||
optional<expr> scoped_cache_lookup(lean_object * ptr) {
|
||||
auto key = mk_pair(ptr, m_depth);
|
||||
auto it = m_scoped_cache.find(key);
|
||||
if (it == m_scoped_cache.end()) return {};
|
||||
auto & entry = it->second;
|
||||
if ((entry.scope_level == 0 || entry.scope_level == m_scope) &&
|
||||
m_scope_gens[entry.scope_level] == entry.scope_gen) {
|
||||
m_result_scope = std::max(m_result_scope, entry.scope_level);
|
||||
return optional<expr>(entry.result);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
void scoped_cache_insert(lean_object * ptr, expr const & result) {
|
||||
auto key = mk_pair(ptr, m_depth);
|
||||
m_scoped_cache[key] = { result, m_result_scope, m_scope_gens[m_result_scope] };
|
||||
}
|
||||
|
||||
/* ---- App visitors ---- */
|
||||
|
||||
expr visit_app_default(expr const & e) {
|
||||
buffer<expr> args;
|
||||
expr const * curr = &e;
|
||||
@@ -340,8 +210,300 @@ class instantiate_fused_fn {
|
||||
return apply_beta(f_new, args.size(), args.data(), preserve_data, zeta);
|
||||
}
|
||||
|
||||
/* Fire delayed resolution: push fvar substitution scope, visit pending
|
||||
value, pop scope. Uses generation-based staleness for the scoped cache. */
|
||||
/* Check whether a normalized value would be mvar-free after full resolution.
|
||||
Uses m_resolvable_delayed to check inner delayed mvars. After pass 1
|
||||
normalization, remaining mvars are either unassigned or delayed-assigned. */
|
||||
bool is_value_resolvable(expr const & e) {
|
||||
if (!has_expr_mvar(e)) return true;
|
||||
switch (e.kind()) {
|
||||
case expr_kind::BVar: case expr_kind::Lit: case expr_kind::FVar:
|
||||
case expr_kind::Sort: case expr_kind::Const:
|
||||
return true;
|
||||
case expr_kind::MVar:
|
||||
/* Bare mvar after pass 1 normalization: not directly assigned. */
|
||||
return false;
|
||||
case expr_kind::App: {
|
||||
expr const & f = get_app_fn(e);
|
||||
if (is_mvar(f)) {
|
||||
/* Mvar app after pass 1: must be delayed-assigned or unassigned. */
|
||||
name const & mid = mvar_name(f);
|
||||
option_ref<delayed_assignment> d = get_delayed_mvar_assignment(m_mctx, mid);
|
||||
if (!d) return false;
|
||||
array_ref<expr> fvars(cnstr_get(d.get_val().raw(), 0), true);
|
||||
if (fvars.size() > get_app_num_args(e)) return false;
|
||||
name mid_pending(cnstr_get(d.get_val().raw(), 1), true);
|
||||
if (!m_resolvable_delayed.contains(mid_pending)) return false;
|
||||
/* Also check args for unresolvable mvars. */
|
||||
expr const * curr = &e;
|
||||
while (is_app(*curr)) {
|
||||
if (!is_value_resolvable(app_arg(*curr))) return false;
|
||||
curr = &app_fn(*curr);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return is_value_resolvable(app_fn(e)) && is_value_resolvable(app_arg(e));
|
||||
}
|
||||
case expr_kind::Lambda: case expr_kind::Pi:
|
||||
return is_value_resolvable(binding_domain(e)) && is_value_resolvable(binding_body(e));
|
||||
case expr_kind::Let:
|
||||
return is_value_resolvable(let_type(e)) && is_value_resolvable(let_value(e))
|
||||
&& is_value_resolvable(let_body(e));
|
||||
case expr_kind::MData:
|
||||
return is_value_resolvable(mdata_expr(e));
|
||||
case expr_kind::Proj:
|
||||
return is_value_resolvable(proj_expr(e));
|
||||
}
|
||||
lean_unreachable();
|
||||
}
|
||||
|
||||
/* Pre-normalize the pending value of a delayed assignment and record
|
||||
whether it is resolvable (assigned and mvar-free after full resolution).
|
||||
Inner delayed assignments are processed first (via recursive normalization),
|
||||
so m_resolvable_delayed is already populated for them. */
|
||||
void normalize_delayed_pending(name const & mid_pending) {
|
||||
m_has_delayed = true;
|
||||
if (auto val = get_assignment(mid_pending)) {
|
||||
if (is_value_resolvable(*val)) {
|
||||
m_resolvable_delayed.insert(mid_pending);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
expr visit_app(expr const & e) {
|
||||
expr const & f = get_app_fn(e);
|
||||
if (!is_mvar(f)) {
|
||||
return visit_app_default(e);
|
||||
}
|
||||
name const & mid = mvar_name(f);
|
||||
/* Direct assignment takes precedence. */
|
||||
if (auto f_new = get_assignment(mid)) {
|
||||
buffer<expr> args;
|
||||
return visit_args_and_beta(*f_new, e, args);
|
||||
}
|
||||
/* Check delayed assignment and pre-normalize pending. */
|
||||
option_ref<delayed_assignment> d = get_delayed_mvar_assignment(m_mctx, mid);
|
||||
if (d) {
|
||||
name mid_pending(cnstr_get(d.get_val().raw(), 1), true);
|
||||
normalize_delayed_pending(mid_pending);
|
||||
}
|
||||
/* Leave the (possibly delayed) mvar in place, just visit args. */
|
||||
return visit_mvar_app_args(e);
|
||||
}
|
||||
|
||||
expr visit_mvar(expr const & e) {
|
||||
name const & mid = mvar_name(e);
|
||||
if (auto r = get_assignment(mid)) {
|
||||
return *r;
|
||||
}
|
||||
/* Not directly assigned. Check if delayed-assigned and pre-normalize. */
|
||||
option_ref<delayed_assignment> d = get_delayed_mvar_assignment(m_mctx, mid);
|
||||
if (d) {
|
||||
name mid_pending(cnstr_get(d.get_val().raw(), 1), true);
|
||||
normalize_delayed_pending(mid_pending);
|
||||
}
|
||||
return e; /* leave mvar in place */
|
||||
}
|
||||
|
||||
public:
|
||||
instantiate_direct_fn(metavar_ctx & mctx):m_mctx(mctx), m_level_fn(mctx), m_has_delayed(false) {}
|
||||
name_set const & resolvable_delayed() const { return m_resolvable_delayed; }
|
||||
bool has_delayed() const { return m_has_delayed; }
|
||||
|
||||
expr visit(expr const & e) {
|
||||
if (!has_mvar(e))
|
||||
return e;
|
||||
bool shared = false;
|
||||
if (is_shared(e)) {
|
||||
auto it = m_cache.find(e.raw());
|
||||
if (it != m_cache.end()) {
|
||||
return it->second;
|
||||
}
|
||||
shared = true;
|
||||
}
|
||||
|
||||
switch (e.kind()) {
|
||||
case expr_kind::BVar:
|
||||
case expr_kind::Lit: case expr_kind::FVar:
|
||||
lean_unreachable();
|
||||
case expr_kind::Sort:
|
||||
return cache(e, update_sort(e, visit_level(sort_level(e))), shared);
|
||||
case expr_kind::Const:
|
||||
return cache(e, update_const(e, visit_levels(const_levels(e))), shared);
|
||||
case expr_kind::MVar:
|
||||
return visit_mvar(e);
|
||||
case expr_kind::MData:
|
||||
return cache(e, update_mdata(e, visit(mdata_expr(e))), shared);
|
||||
case expr_kind::Proj:
|
||||
return cache(e, update_proj(e, visit(proj_expr(e))), shared);
|
||||
case expr_kind::App:
|
||||
return cache(e, visit_app(e), shared);
|
||||
case expr_kind::Pi: case expr_kind::Lambda:
|
||||
return cache(e, update_binding(e, visit(binding_domain(e)), visit(binding_body(e))), shared);
|
||||
case expr_kind::Let:
|
||||
return cache(e, update_let(e, visit(let_type(e)), visit(let_value(e)), visit(let_body(e))), shared);
|
||||
}
|
||||
}
|
||||
|
||||
expr operator()(expr const & e) { return visit(e); }
|
||||
};
|
||||
|
||||
/* ============================================================================
|
||||
Pass 2: Resolve delayed assignments with fused fvar substitution.
|
||||
Direct mvar chains have been pre-resolved by pass 1.
|
||||
|
||||
Uses a flat (ptr, depth)-keyed cache with generation-based staleness.
|
||||
Each visit_delayed scope gets a unique generation number; cache entries
|
||||
record the scope level and generation at insertion. Validity is O(1):
|
||||
entry valid iff level <= m_scope && m_scope_gens[level] == entry.scope_gen.
|
||||
============================================================================ */
|
||||
|
||||
struct fvar_subst_entry {
|
||||
unsigned depth;
|
||||
unsigned scope;
|
||||
expr value;
|
||||
};
|
||||
|
||||
class instantiate_delayed_fn {
|
||||
struct key_hasher {
|
||||
std::size_t operator()(std::pair<lean_object *, unsigned> const & p) const {
|
||||
return hash((size_t)p.first >> 3, p.second);
|
||||
}
|
||||
};
|
||||
|
||||
struct cache_entry { expr result; unsigned scope_level; unsigned scope_gen; };
|
||||
|
||||
typedef lean::unordered_map<std::pair<lean_object *, unsigned>, cache_entry, key_hasher> flat_cache;
|
||||
|
||||
metavar_ctx & m_mctx;
|
||||
name_set const & m_resolvable_delayed;
|
||||
name_hash_map<fvar_subst_entry> m_fvar_subst;
|
||||
unsigned m_depth;
|
||||
|
||||
/* Single flat cache with generation-based staleness detection. */
|
||||
flat_cache m_cache;
|
||||
std::vector<unsigned> m_scope_gens; /* m_scope_gens[level] = generation */
|
||||
unsigned m_gen_counter;
|
||||
unsigned m_scope;
|
||||
|
||||
/* After visit() returns, this holds the maximum fvar-substitution
|
||||
scope that contributed to the result — i.e., the outermost scope at which the
|
||||
result is valid and can be cached. Updated monotonically (via max) through
|
||||
the save/reset/restore pattern in visit(). */
|
||||
unsigned m_result_scope;
|
||||
|
||||
/* Global cache for fvar-free expressions — scope-independent. */
|
||||
lean::unordered_map<lean_object *, expr> m_global_cache;
|
||||
|
||||
/* Write-back support: when fvar_subst is empty, normalize and write back
|
||||
mvar assignments to match the original instantiateMVars mctx side effects.
|
||||
Downstream code (e.g. MutualDef.mkInitialUsedFVarsMap) reads stored
|
||||
assignments and expects them to be normalized. */
|
||||
name_set m_already_normalized;
|
||||
std::vector<expr> m_saved;
|
||||
|
||||
bool fvar_subst_empty() const {
|
||||
return m_fvar_subst.empty();
|
||||
}
|
||||
|
||||
optional<expr> lookup_fvar(name const & fid) {
|
||||
auto it = m_fvar_subst.find(fid);
|
||||
if (it == m_fvar_subst.end())
|
||||
return optional<expr>();
|
||||
m_result_scope = std::max(m_result_scope, it->second.scope);
|
||||
unsigned d = m_depth - it->second.depth;
|
||||
if (d == 0)
|
||||
return optional<expr>(it->second.value);
|
||||
return optional<expr>(lift_loose_bvars(it->second.value, d));
|
||||
}
|
||||
|
||||
/* Cache lookup — O(1) with generation-based staleness check.
|
||||
An entry at scope_level 0 (no fvar dependency) is valid at any scope.
|
||||
An entry at scope_level > 0 is only valid at exactly that scope level,
|
||||
because an inner scope may shadow the fvars it depends on. */
|
||||
optional<expr> cache_lookup(lean_object * ptr) {
|
||||
auto key = mk_pair(ptr, m_depth);
|
||||
auto it = m_cache.find(key);
|
||||
if (it == m_cache.end()) return {};
|
||||
auto & entry = it->second;
|
||||
if ((entry.scope_level == 0 || entry.scope_level == m_scope) &&
|
||||
m_scope_gens[entry.scope_level] == entry.scope_gen) {
|
||||
m_result_scope = std::max(m_result_scope, entry.scope_level);
|
||||
return optional<expr>(entry.result);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
void cache_insert(lean_object * ptr, expr const & result) {
|
||||
auto key = mk_pair(ptr, m_depth);
|
||||
m_cache[key] = { result, m_result_scope, m_scope_gens[m_result_scope] };
|
||||
}
|
||||
|
||||
/* Get a direct mvar assignment. Visit it to resolve delayed mvars
|
||||
and apply the fvar substitution.
|
||||
When fvar_subst is empty, normalize and write back the result to
|
||||
the mctx. This matches the original instantiateMVars behavior:
|
||||
downstream code (e.g. MutualDef.mkInitialUsedFVarsMap) reads stored
|
||||
assignments and expects inner delayed assignments to be resolved.
|
||||
When fvar_subst is non-empty, no write-back (values contain
|
||||
fvar-substituted terms not suitable for the mctx). */
|
||||
optional<expr> get_assignment(name const & mid) {
|
||||
option_ref<expr> r = get_mvar_assignment(m_mctx, mid);
|
||||
if (!r)
|
||||
return optional<expr>();
|
||||
expr a(r.get_val());
|
||||
if (fvar_subst_empty()) {
|
||||
if (!has_mvar(a))
|
||||
return optional<expr>(a);
|
||||
if (m_already_normalized.contains(mid))
|
||||
return optional<expr>(a);
|
||||
m_already_normalized.insert(mid);
|
||||
expr a_new = visit(a);
|
||||
if (!is_eqp(a, a_new)) {
|
||||
m_saved.push_back(a);
|
||||
assign_mvar(m_mctx, mid, a_new);
|
||||
}
|
||||
return optional<expr>(a_new);
|
||||
} else {
|
||||
if (!has_mvar(a) && !has_fvar(a))
|
||||
return optional<expr>(a);
|
||||
return optional<expr>(visit(a));
|
||||
}
|
||||
}
|
||||
|
||||
expr visit_app_default(expr const & e) {
|
||||
buffer<expr> args;
|
||||
expr const * curr = &e;
|
||||
while (is_app(*curr)) {
|
||||
args.push_back(visit(app_arg(*curr)));
|
||||
curr = &app_fn(*curr);
|
||||
}
|
||||
lean_assert(!is_mvar(*curr));
|
||||
expr f = visit(*curr);
|
||||
return mk_rev_app(f, args.size(), args.data());
|
||||
}
|
||||
|
||||
expr visit_mvar_app_args(expr const & e) {
|
||||
buffer<expr> args;
|
||||
expr const * curr = &e;
|
||||
while (is_app(*curr)) {
|
||||
args.push_back(visit(app_arg(*curr)));
|
||||
curr = &app_fn(*curr);
|
||||
}
|
||||
lean_assert(is_mvar(*curr));
|
||||
return mk_rev_app(*curr, args.size(), args.data());
|
||||
}
|
||||
|
||||
expr visit_args_and_beta(expr const & f_new, expr const & e, buffer<expr> & args) {
|
||||
expr const * curr = &e;
|
||||
while (is_app(*curr)) {
|
||||
args.push_back(visit(app_arg(*curr)));
|
||||
curr = &app_fn(*curr);
|
||||
}
|
||||
bool preserve_data = false;
|
||||
bool zeta = true;
|
||||
return apply_beta(f_new, args.size(), args.data(), preserve_data, zeta);
|
||||
}
|
||||
|
||||
expr visit_delayed(array_ref<expr> const & fvars, name const & mid_pending,
|
||||
expr const & e, buffer<expr> & args) {
|
||||
expr const * curr = &e;
|
||||
@@ -369,188 +531,125 @@ class instantiate_fused_fn {
|
||||
m_fvar_subst[fid] = {m_depth, m_scope, args[args.size() - 1 - i]};
|
||||
}
|
||||
|
||||
/* Bump generation so stale entries at this scope level are detected. */
|
||||
/* Push: bump generation so stale entries at this scope level are detected. */
|
||||
m_gen_counter++;
|
||||
if (m_scope >= m_scope_gens.size())
|
||||
m_scope_gens.push_back(m_gen_counter);
|
||||
else
|
||||
m_scope_gens[m_scope] = m_gen_counter;
|
||||
|
||||
/* Visit the pending value. In delayed mode, mk_mvar(mid_pending) will
|
||||
resolve via get_delayed_mode_assignment which visits the stored
|
||||
(already-normalized) value to apply fvar substitution. */
|
||||
expr val_new = visit(mk_mvar(mid_pending));
|
||||
|
||||
/* Pop scope. */
|
||||
/* Pop: just decrement scope — stale entries are detected by generation mismatch. */
|
||||
m_scope--;
|
||||
|
||||
/* Restore the fvar substitution. */
|
||||
for (auto & se : saved_entries) {
|
||||
if (!se.had_old)
|
||||
if (!se.had_old) {
|
||||
m_fvar_subst.erase(se.key);
|
||||
else
|
||||
} else {
|
||||
m_fvar_subst[se.key] = se.old;
|
||||
}
|
||||
}
|
||||
|
||||
/* Use apply_beta instead of mk_rev_app: pass 1's beta-reduction may have
|
||||
changed delayed mvar arguments (e.g., substituting a bvar with a concrete
|
||||
value), so the resolved pending value may be a lambda that needs beta-
|
||||
reduction with the extra args, matching the original's behavior. */
|
||||
bool preserve_data = false;
|
||||
bool zeta = true;
|
||||
return apply_beta(val_new, extra_count, args.data(), preserve_data, zeta);
|
||||
}
|
||||
|
||||
/* ---- Outer-mode app: resolve direct mvars, fire delayed inline ---- */
|
||||
|
||||
expr visit_app_outer(expr const & e) {
|
||||
expr visit_app(expr const & e) {
|
||||
expr const & f = get_app_fn(e);
|
||||
if (!is_mvar(f))
|
||||
if (!is_mvar(f)) {
|
||||
return visit_app_default(e);
|
||||
}
|
||||
name const & mid = mvar_name(f);
|
||||
/* Direct assignment takes precedence. */
|
||||
if (auto f_new = get_direct_assignment(mid)) {
|
||||
if (auto f_new = get_assignment(mid)) {
|
||||
buffer<expr> args;
|
||||
return visit_args_and_beta(*f_new, e, args);
|
||||
}
|
||||
/* Check delayed assignment. */
|
||||
option_ref<delayed_assignment> d = get_delayed_mvar_assignment(m_mctx, mid);
|
||||
if (!d)
|
||||
if (!d) {
|
||||
return visit_mvar_app_args(e);
|
||||
}
|
||||
array_ref<expr> fvars(cnstr_get(d.get_val().raw(), 0), true);
|
||||
name mid_pending(cnstr_get(d.get_val().raw(), 1), true);
|
||||
/* Pre-normalize the pending value and check resolvability. */
|
||||
normalize_delayed_pending(mid_pending);
|
||||
if (fvars.size() > get_app_num_args(e))
|
||||
if (fvars.size() > get_app_num_args(e)) {
|
||||
return visit_mvar_app_args(e);
|
||||
if (!m_resolvable_delayed.contains(mid_pending))
|
||||
}
|
||||
/* Match standard instantiateMVars: only resolve the delayed assignment
|
||||
when the pending value was determined to be resolvable by pass 1
|
||||
(assigned and mvar-free after normalization). */
|
||||
if (!m_resolvable_delayed.contains(mid_pending)) {
|
||||
/* Still normalize the pending value for mctx write-back side effects.
|
||||
The original instantiateMVars always normalizes the pending value
|
||||
(via get_assignment(mid_pending)) even when it can't resolve.
|
||||
Downstream code like MutualDef.mkInitialUsedFVarsMap reads stored
|
||||
assignments and relies on inner delayed assignments being resolved. */
|
||||
if (fvar_subst_empty()) {
|
||||
(void)get_assignment(mid_pending);
|
||||
}
|
||||
return visit_mvar_app_args(e);
|
||||
/* Fire delayed resolution inline. */
|
||||
}
|
||||
buffer<expr> args;
|
||||
return visit_delayed(fvars, mid_pending, e, args);
|
||||
}
|
||||
|
||||
/* ---- Delayed-mode app: apply fvar subst, resolve nested delayed ---- */
|
||||
|
||||
expr visit_app_delayed(expr const & e) {
|
||||
expr const & f = get_app_fn(e);
|
||||
if (!is_mvar(f))
|
||||
return visit_app_default(e);
|
||||
name const & mid = mvar_name(f);
|
||||
/* In delayed mode, direct chains are already resolved. Check assignment
|
||||
to apply fvar substitution to the stored value. */
|
||||
if (auto f_new = get_delayed_mode_assignment(mid)) {
|
||||
buffer<expr> args;
|
||||
return visit_args_and_beta(*f_new, e, args);
|
||||
}
|
||||
/* Check delayed assignment for nested resolution. */
|
||||
option_ref<delayed_assignment> d = get_delayed_mvar_assignment(m_mctx, mid);
|
||||
if (!d)
|
||||
return visit_mvar_app_args(e);
|
||||
array_ref<expr> fvars(cnstr_get(d.get_val().raw(), 0), true);
|
||||
name mid_pending(cnstr_get(d.get_val().raw(), 1), true);
|
||||
if (fvars.size() > get_app_num_args(e))
|
||||
return visit_mvar_app_args(e);
|
||||
if (!m_resolvable_delayed.contains(mid_pending))
|
||||
return visit_mvar_app_args(e);
|
||||
buffer<expr> args;
|
||||
return visit_delayed(fvars, mid_pending, e, args);
|
||||
}
|
||||
|
||||
/* ---- Mvar visitors ---- */
|
||||
|
||||
expr visit_mvar_outer(expr const & e) {
|
||||
expr visit_mvar(expr const & e) {
|
||||
name const & mid = mvar_name(e);
|
||||
if (auto r = get_direct_assignment(mid))
|
||||
if (auto r = get_assignment(mid)) {
|
||||
return *r;
|
||||
/* Not directly assigned. Check if delayed-assigned and pre-normalize. */
|
||||
option_ref<delayed_assignment> d = get_delayed_mvar_assignment(m_mctx, mid);
|
||||
if (d) {
|
||||
name mid_pending(cnstr_get(d.get_val().raw(), 1), true);
|
||||
normalize_delayed_pending(mid_pending);
|
||||
}
|
||||
return e;
|
||||
}
|
||||
|
||||
expr visit_mvar_delayed(expr const & e) {
|
||||
name const & mid = mvar_name(e);
|
||||
if (auto r = get_delayed_mode_assignment(mid))
|
||||
return *r;
|
||||
return e;
|
||||
}
|
||||
|
||||
expr visit_fvar(expr const & e) {
|
||||
name const & fid = fvar_name(e);
|
||||
if (auto r = lookup_fvar(fid))
|
||||
if (auto r = lookup_fvar(fid)) {
|
||||
return *r;
|
||||
}
|
||||
return e;
|
||||
}
|
||||
|
||||
/* ---- Outer-mode visit ---- */
|
||||
|
||||
expr visit_outer(expr const & e) {
|
||||
if (!has_mvar(e))
|
||||
return e;
|
||||
bool shared = false;
|
||||
if (is_shared(e)) {
|
||||
auto it = m_cache.find(e.raw());
|
||||
if (it != m_cache.end())
|
||||
return it->second;
|
||||
shared = true;
|
||||
}
|
||||
|
||||
expr r;
|
||||
switch (e.kind()) {
|
||||
case expr_kind::BVar:
|
||||
case expr_kind::Lit: case expr_kind::FVar:
|
||||
lean_unreachable();
|
||||
case expr_kind::Sort:
|
||||
r = update_sort(e, visit_level(sort_level(e)));
|
||||
break;
|
||||
case expr_kind::Const:
|
||||
r = update_const(e, visit_levels(const_levels(e)));
|
||||
break;
|
||||
case expr_kind::MVar:
|
||||
return visit_mvar_outer(e); /* don't cache mvar results */
|
||||
case expr_kind::MData:
|
||||
r = update_mdata(e, visit(mdata_expr(e)));
|
||||
break;
|
||||
case expr_kind::Proj:
|
||||
r = update_proj(e, visit(proj_expr(e)));
|
||||
break;
|
||||
case expr_kind::App:
|
||||
r = visit_app_outer(e);
|
||||
break;
|
||||
case expr_kind::Pi: case expr_kind::Lambda:
|
||||
r = update_binding(e, visit(binding_domain(e)), visit(binding_body(e)));
|
||||
break;
|
||||
case expr_kind::Let:
|
||||
r = update_let(e, visit(let_type(e)), visit(let_value(e)), visit(let_body(e)));
|
||||
break;
|
||||
}
|
||||
if (shared)
|
||||
m_cache.insert(mk_pair(e.raw(), r));
|
||||
return r;
|
||||
public:
|
||||
instantiate_delayed_fn(metavar_ctx & mctx, name_set const & resolvable_delayed)
|
||||
: m_mctx(mctx), m_resolvable_delayed(resolvable_delayed),
|
||||
m_depth(0), m_gen_counter(0), m_scope(0), m_result_scope(0) {
|
||||
m_scope_gens.push_back(0); /* scope 0 has generation 0 */
|
||||
}
|
||||
|
||||
/* ---- Delayed-mode visit ---- */
|
||||
expr visit(expr const & e) {
|
||||
if (fvar_subst_empty()) {
|
||||
if (!has_mvar(e))
|
||||
return e;
|
||||
} else {
|
||||
if (!has_mvar(e) && !has_fvar(e))
|
||||
return e;
|
||||
}
|
||||
|
||||
expr visit_delayed_mode(expr const & e) {
|
||||
if (!has_mvar(e) && !has_fvar(e))
|
||||
return e;
|
||||
|
||||
/* use_global: expression has no fvars and no expr mvars, so its result
|
||||
is scope-independent and can be read from/written to the shared cache. */
|
||||
bool use_global = !has_fvar(e) && !has_expr_mvar(e);
|
||||
bool shared = false;
|
||||
if (is_shared(e)) {
|
||||
if (use_global) {
|
||||
auto it = m_cache.find(e.raw());
|
||||
if (it != m_cache.end())
|
||||
auto it = m_global_cache.find(e.raw());
|
||||
if (it != m_global_cache.end())
|
||||
return it->second;
|
||||
} else {
|
||||
if (auto r = scoped_cache_lookup(e.raw()))
|
||||
if (auto r = cache_lookup(e.raw()))
|
||||
return *r;
|
||||
}
|
||||
shared = true;
|
||||
}
|
||||
|
||||
/* Save and reset the result scope for this subtree.
|
||||
After computing, cache_insert uses m_result_scope to place the entry
|
||||
at the outermost valid scope level. Then we restore the parent's
|
||||
watermark, taking the max with our contribution. */
|
||||
unsigned saved_result_scope = m_result_scope;
|
||||
m_result_scope = 0;
|
||||
|
||||
@@ -569,7 +668,7 @@ class instantiate_fused_fn {
|
||||
r = update_const(e, visit_levels(const_levels(e)));
|
||||
break;
|
||||
case expr_kind::MVar:
|
||||
r = visit_mvar_delayed(e);
|
||||
r = visit_mvar(e);
|
||||
goto done; /* mvar results are not (ptr, depth)-cacheable */
|
||||
case expr_kind::MData:
|
||||
r = update_mdata(e, visit(mdata_expr(e)));
|
||||
@@ -578,7 +677,7 @@ class instantiate_fused_fn {
|
||||
r = update_proj(e, visit(proj_expr(e)));
|
||||
break;
|
||||
case expr_kind::App:
|
||||
r = visit_app_delayed(e);
|
||||
r = visit_app(e);
|
||||
break;
|
||||
case expr_kind::Pi: case expr_kind::Lambda: {
|
||||
expr d = visit(binding_domain(e));
|
||||
@@ -599,10 +698,10 @@ class instantiate_fused_fn {
|
||||
}
|
||||
}
|
||||
if (shared) {
|
||||
if (m_result_scope == 0)
|
||||
m_cache.insert(mk_pair(e.raw(), r));
|
||||
if (use_global)
|
||||
m_global_cache.insert(mk_pair(e.raw(), r));
|
||||
else
|
||||
scoped_cache_insert(e.raw(), r);
|
||||
cache_insert(e.raw(), r);
|
||||
}
|
||||
|
||||
done:
|
||||
@@ -610,37 +709,54 @@ class instantiate_fused_fn {
|
||||
return r;
|
||||
}
|
||||
|
||||
public:
|
||||
instantiate_fused_fn(metavar_ctx & mctx)
|
||||
: m_mctx(mctx), m_level_fn(mctx),
|
||||
m_depth(0), m_gen_counter(0), m_scope(0), m_result_scope(0) {
|
||||
m_scope_gens.push_back(0);
|
||||
level visit_level(level const & l) {
|
||||
/* Pass 2 does not handle level mvars — pass 1 already resolved them.
|
||||
But we still need this for the visit_levels call in update_sort/update_const.
|
||||
Since levels have no fvars, we can just return them as-is. */
|
||||
return l;
|
||||
}
|
||||
|
||||
expr visit(expr const & e) {
|
||||
if (in_delayed_mode())
|
||||
return visit_delayed_mode(e);
|
||||
else
|
||||
return visit_outer(e);
|
||||
levels visit_levels(levels const & ls) {
|
||||
return ls;
|
||||
}
|
||||
|
||||
expr operator()(expr const & e) { return visit(e); }
|
||||
};
|
||||
|
||||
/* ============================================================================
|
||||
Entry points.
|
||||
Entry points: run pass 1 then pass 2.
|
||||
============================================================================ */
|
||||
|
||||
extern "C" LEAN_EXPORT object * lean_instantiate_expr_mvars_all(object * m, object * e) {
|
||||
static object * run_instantiate_all(object * m, object * e) {
|
||||
metavar_ctx mctx(m);
|
||||
expr e_new = instantiate_fused_fn(mctx)(expr(e));
|
||||
|
||||
/* Pass 1: resolve direct mvar assignments, pre-normalize pending values. */
|
||||
instantiate_direct_fn pass1(mctx);
|
||||
expr e1 = pass1(expr(e));
|
||||
|
||||
/* Pass 2: resolve delayed assignments with fused fvar substitution.
|
||||
Skip if pass 1 found no delayed assignments at all — the expression
|
||||
has no delayed mvars that need resolution or write-back. */
|
||||
expr e2;
|
||||
if (!pass1.has_delayed()) {
|
||||
e2 = e1;
|
||||
} else {
|
||||
instantiate_delayed_fn pass2(mctx, pass1.resolvable_delayed());
|
||||
e2 = pass2(e1);
|
||||
}
|
||||
|
||||
/* (mctx, expr) */
|
||||
object * r = alloc_cnstr(0, 2, 0);
|
||||
cnstr_set(r, 0, mctx.steal());
|
||||
cnstr_set(r, 1, e_new.steal());
|
||||
cnstr_set(r, 1, e2.steal());
|
||||
return r;
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT object * lean_instantiate_expr_mvars_all(object * m, object * e) {
|
||||
return run_instantiate_all(m, e);
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT object * lean_instantiate_expr_mvars_all_sharing(object * m, object * e) {
|
||||
return lean_instantiate_expr_mvars_all(m, e);
|
||||
return run_instantiate_all(m, e);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user