diff --git a/src/Lean/Elab/PreDefinition/Basic.lean b/src/Lean/Elab/PreDefinition/Basic.lean index 812e56198b..2df9ad63ff 100644 --- a/src/Lean/Elab/PreDefinition/Basic.lean +++ b/src/Lean/Elab/PreDefinition/Basic.lean @@ -11,7 +11,6 @@ public import Lean.Util.NumApps public import Lean.Meta.Eqns public import Lean.Elab.RecAppSyntax public import Lean.Elab.DefView - public section namespace Lean.Elab diff --git a/src/Lean/MetavarContext.lean b/src/Lean/MetavarContext.lean index 63ebeb43ed..137bfbb085 100644 --- a/src/Lean/MetavarContext.lean +++ b/src/Lean/MetavarContext.lean @@ -400,6 +400,12 @@ def MetavarContext.getDelayedMVarAssignmentCore? (mctx : MetavarContext) (mvarId def MetavarContext.getDelayedMVarAssignmentExp (mctx : MetavarContext) (mvarId : MVarId) : Option DelayedMetavarAssignment := mctx.dAssignment.find? mvarId +@[export lean_delayed_mvar_assignment_fvars] +def DelayedMetavarAssignment.fvarsExp (d : DelayedMetavarAssignment) : Array Expr := d.fvars + +@[export lean_delayed_mvar_assignment_mvar_id_pending] +def DelayedMetavarAssignment.mvarIdPendingExp (d : DelayedMetavarAssignment) : MVarId := d.mvarIdPending + def getDelayedMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : MVarId) : m (Option DelayedMetavarAssignment) := return (← getMCtx).getDelayedMVarAssignmentCore? mvarId diff --git a/src/kernel/CMakeLists.txt b/src/kernel/CMakeLists.txt index 913f167235..ffeb321412 100644 --- a/src/kernel/CMakeLists.txt +++ b/src/kernel/CMakeLists.txt @@ -18,5 +18,4 @@ add_library( quot.cpp inductive.cpp trace.cpp - instantiate_mvars.cpp ) diff --git a/src/kernel/instantiate_mvars.cpp b/src/kernel/instantiate_mvars.cpp deleted file mode 100644 index b49d74aa5c..0000000000 --- a/src/kernel/instantiate_mvars.cpp +++ /dev/null @@ -1,373 +0,0 @@ -/* -Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. -Released under Apache 2.0 license as described in the file LICENSE. - -Authors: Leonardo de Moura -*/ -#include -#include -#include "util/name_set.h" -#include "runtime/option_ref.h" -#include "runtime/array_ref.h" -#include "kernel/instantiate.h" -#include "kernel/replace_fn.h" - -/* -This module is not used by the kernel. It just provides an efficient implementation of -`instantiateExprMVars` -*/ - -namespace lean { -extern "C" object * lean_get_lmvar_assignment(obj_arg mctx, obj_arg mid); -extern "C" object * lean_assign_lmvar(obj_arg mctx, obj_arg mid, obj_arg val); - -typedef object_ref metavar_ctx; -void assign_lmvar(metavar_ctx & mctx, name const & mid, level const & l) { - object * r = lean_assign_lmvar(mctx.steal(), mid.to_obj_arg(), l.to_obj_arg()); - mctx.set_box(r); -} - -option_ref get_lmvar_assignment(metavar_ctx & mctx, name const & mid) { - return option_ref(lean_get_lmvar_assignment(mctx.to_obj_arg(), mid.to_obj_arg())); -} - -class instantiate_lmvars_fn { - metavar_ctx & m_mctx; - lean::unordered_map m_cache; - std::vector m_saved; // Helper vector to prevent values from being garbage collected - - inline level cache(level const & l, level r, bool shared) { - if (shared) { - m_cache.insert(mk_pair(l.raw(), r)); - } - return r; - } -public: - instantiate_lmvars_fn(metavar_ctx & mctx):m_mctx(mctx) {} - level visit(level const & l) { - if (!has_mvar(l)) - return l; - bool shared = false; - if (is_shared(l)) { - auto it = m_cache.find(l.raw()); - if (it != m_cache.end()) { - return it->second; - } - shared = true; - } - switch (l.kind()) { - case level_kind::Succ: - return cache(l, update_succ(l, visit(succ_of(l))), shared); - case level_kind::Max: case level_kind::IMax: - return cache(l, update_max(l, visit(level_lhs(l)), visit(level_rhs(l))), shared); - case level_kind::Zero: case level_kind::Param: - lean_unreachable(); - case level_kind::MVar: { - option_ref r = get_lmvar_assignment(m_mctx, mvar_id(l)); - if (!r) { - return l; - } else { - level a(r.get_val()); - if (!has_mvar(a)) { - return a; - } else { - level a_new = visit(a); - if (!is_eqp(a, a_new)) { - /* - We save `a` to ensure it will not be garbage collected - after we update `mctx`. This is necessary because `m_cache` - may contain references to its subterms. - */ - m_saved.push_back(a); - assign_lmvar(m_mctx, mvar_id(l), a_new); - } - return a_new; - } - } - }} - } - level operator()(level const & l) { return visit(l); } -}; - -extern "C" LEAN_EXPORT object * lean_instantiate_level_mvars(object * m, object * l) { - metavar_ctx mctx(m); - level l_new = instantiate_lmvars_fn(mctx)(level(l)); - object * r = alloc_cnstr(0, 2, 0); - cnstr_set(r, 0, mctx.steal()); - cnstr_set(r, 1, l_new.steal()); - return r; -} - -extern "C" object * lean_get_mvar_assignment(obj_arg mctx, obj_arg mid); -extern "C" object * lean_get_delayed_mvar_assignment(obj_arg mctx, obj_arg mid); -extern "C" object * lean_assign_mvar(obj_arg mctx, obj_arg mid, obj_arg val); -typedef object_ref delayed_assignment; - -void assign_mvar(metavar_ctx & mctx, name const & mid, expr const & e) { - object * r = lean_assign_mvar(mctx.steal(), mid.to_obj_arg(), e.to_obj_arg()); - mctx.set_box(r); -} - -option_ref get_mvar_assignment(metavar_ctx & mctx, name const & mid) { - return option_ref(lean_get_mvar_assignment(mctx.to_obj_arg(), mid.to_obj_arg())); -} - -option_ref get_delayed_mvar_assignment(metavar_ctx & mctx, name const & mid) { - return option_ref(lean_get_delayed_mvar_assignment(mctx.to_obj_arg(), mid.to_obj_arg())); -} - -expr replace_fvars(expr const & e, array_ref const & fvars, expr const * rev_args) { - size_t sz = fvars.size(); - if (sz == 0) - return e; - return replace(e, [=](expr const & m, unsigned offset) -> optional { - if (!has_fvar(m)) - return some_expr(m); // expression m does not contain free variables - if (is_fvar(m)) { - size_t i = sz; - name const & fid = fvar_name(m); - while (i > 0) { - --i; - if (fvar_name(fvars[i]) == fid) { - return some_expr(lift_loose_bvars(rev_args[sz - i - 1], offset)); - } - } - } - return none_expr(); - }); -} - -class instantiate_mvars_fn { - metavar_ctx & m_mctx; - instantiate_lmvars_fn m_level_fn; - name_set m_already_normalized; // Store metavariables whose assignment has already been normalized. - lean::unordered_map m_cache; - std::vector m_saved; // Helper vector to prevent values from being garbage collected - - level visit_level(level const & l) { - return m_level_fn(l); - } - - levels visit_levels(levels const & ls) { - buffer lsNew; - for (auto const & l : ls) - lsNew.push_back(visit_level(l)); - return levels(lsNew); - } - - inline expr cache(expr const & e, expr r, bool shared) { - if (shared) { - m_cache.insert(mk_pair(e.raw(), r)); - } - return r; - } - - optional get_assignment(name const & mid) { - option_ref r = get_mvar_assignment(m_mctx, mid); - if (!r) { - return optional(); - } else { - expr a(r.get_val()); - if (!has_mvar(a) || m_already_normalized.contains(mid)) { - return optional(a); - } else { - m_already_normalized.insert(mid); - expr a_new = visit(a); - if (!is_eqp(a, a_new)) { - /* - We save `a` to ensure it will not be garbage collected - after we update `mctx`. This is necessary because `m_cache` - may contain references to its subterms. - */ - m_saved.push_back(a); - assign_mvar(m_mctx, mid, a_new); - } - return optional(a_new); - } - } - } - - /* - Given `e` of the form `f a_1 ... a_n` where `f` is not a metavariable, - instantiate metavariables. - */ - expr visit_app_default(expr const & e) { - buffer args; - expr const * curr = &e; - while (is_app(*curr)) { - args.push_back(visit(app_arg(*curr))); - curr = &app_fn(*curr); - } - lean_assert(!is_mvar(*curr)); - expr f = visit(*curr); - return mk_rev_app(f, args.size(), args.data()); - } - - /* - Given `e` of the form `?m a_1 ... a_n`, return new application where - the metavariables in the arguments `a_i` have been instantiated. - */ - expr visit_mvar_app_args(expr const & e) { - buffer args; - expr const * curr = &e; - while (is_app(*curr)) { - args.push_back(visit(app_arg(*curr))); - curr = &app_fn(*curr); - } - lean_assert(is_mvar(*curr)); - return mk_rev_app(*curr, args.size(), args.data()); - } - - /* - Given `e` of the form `f a_1 ... a_n`, return new application `f_new a_1' ... a_n'` - where `a_i'` is `visit(a_i)`. `args` is an accumulator for the new arguments. - */ - expr visit_args_and_beta(expr const & f_new, expr const & e, buffer & args) { - expr const * curr = &e; - while (is_app(*curr)) { - args.push_back(visit(app_arg(*curr))); - curr = &app_fn(*curr); - } - /* - Some of the arguments in `args` are irrelevant after we beta - reduce. Also, it may be a bug to not instantiate them, since they - may depend on free variables that are not in the context (see - issue #4375). So we pass `useZeta := true` to ensure that they are - instantiated. - */ - bool preserve_data = false; - bool zeta = true; - return apply_beta(f_new, args.size(), args.data(), preserve_data, zeta); - } - - /* - Helper function for delayed assignment case at `visit_app`. - `e` is a term of the form `?m t1 t2 t3` - Moreover, `?m` is delayed assigned - `?m #[x, y] := g x y` - where, `fvars := #[x, y]` and `val := g x y`. - `args` is an accumulator for `e`'s arguments. - - We want to return `g t1' t2' t3'` where - `ti'`s are `visit(ti)`. - */ - expr visit_delayed(array_ref const & fvars, expr const & val, expr const & e, buffer & args) { - expr const * curr = &e; - while (is_app(*curr)) { - args.push_back(visit(app_arg(*curr))); - curr = &app_fn(*curr); - } - expr val_new = replace_fvars(val, fvars, args.data() + (args.size() - fvars.size())); - return mk_rev_app(val_new, args.size() - fvars.size(), args.data()); - } - - expr visit_app(expr const & e) { - expr const & f = get_app_fn(e); - if (!is_mvar(f)) { - return visit_app_default(e); - } else { - name const & mid = mvar_name(f); - /* - Regular assignments take precedence over delayed ones. - When an error occurs, Lean assigns `sorry` to unassigned metavariables. - The idea is to ensure we can submit the declaration to the kernel and proceed. - Some of the metavariables may have been delayed assigned. - */ - if (auto f_new = get_assignment(mid)) { - // `f` is an assigned metavariable. - buffer args; - return visit_args_and_beta(*f_new, e, args); - } - option_ref d = get_delayed_mvar_assignment(m_mctx, mid); - if (!d) { - // mvar is not delayed assigned - return visit_mvar_app_args(e); - } - /* - Apply "delayed substitution" (i.e., delayed assignment + application). - That is, `f` is some metavariable `?m`, that is delayed assigned to `val`. - If after instantiating `val`, we obtain `newVal`, and `newVal` does not contain - metavariables, we replace the free variables `fvars` in `newVal` with the first - `fvars.size` elements of `args`. - */ - array_ref fvars(cnstr_get(d.get_val().raw(), 0), true); - name mid_pending(cnstr_get(d.get_val().raw(), 1), true); - if (fvars.size() > get_app_num_args(e)) { - /* - We don't have sufficient arguments for instantiating the free variables `fvars`. - This can only happen if a tactic or elaboration function is not implemented correctly. - We decided to not use `panic!` here and report it as an error in the frontend - when we are checking for unassigned metavariables in an elaborated term. */ - return visit_mvar_app_args(e); - } - optional val = get_assignment(mid_pending); - if (!val) - // mid_pending has not been assigned yet. - return visit_mvar_app_args(e); - if (has_expr_mvar(*val)) - // mid_pending has been assigned, but assignment contains mvars. - return visit_mvar_app_args(e); - buffer args; - return visit_delayed(fvars, *val, e, args); - } - } - - expr visit_mvar(expr const & e) { - name const & mid = mvar_name(e); - if (auto r = get_assignment(mid)) { - return *r; - } else { - return e; - } - } - -public: - instantiate_mvars_fn(metavar_ctx & mctx):m_mctx(mctx), m_level_fn(mctx) {} - - expr visit(expr const & e) { - if (!has_mvar(e)) - return e; - bool shared = false; - if (is_shared(e)) { - auto it = m_cache.find(e.raw()); - if (it != m_cache.end()) { - return it->second; - } - shared = true; - } - - switch (e.kind()) { - case expr_kind::BVar: - case expr_kind::Lit: case expr_kind::FVar: - lean_unreachable(); - case expr_kind::Sort: - return cache(e, update_sort(e, visit_level(sort_level(e))), shared); - case expr_kind::Const: - return cache(e, update_const(e, visit_levels(const_levels(e))), shared); - case expr_kind::MVar: - return visit_mvar(e); - case expr_kind::MData: - return cache(e, update_mdata(e, visit(mdata_expr(e))), shared); - case expr_kind::Proj: - return cache(e, update_proj(e, visit(proj_expr(e))), shared); - case expr_kind::App: - return cache(e, visit_app(e), shared); - case expr_kind::Pi: case expr_kind::Lambda: - return cache(e, update_binding(e, visit(binding_domain(e)), visit(binding_body(e))), shared); - case expr_kind::Let: - return cache(e, update_let(e, visit(let_type(e)), visit(let_value(e)), visit(let_body(e))), shared); - } - } - - expr operator()(expr const & e) { return visit(e); } -}; - -extern "C" LEAN_EXPORT object * lean_instantiate_expr_mvars(object * m, object * e) { - metavar_ctx mctx(m); - expr e_new = instantiate_mvars_fn(mctx)(expr(e)); - object * r = alloc_cnstr(0, 2, 0); - cnstr_set(r, 0, mctx.steal()); - cnstr_set(r, 1, e_new.steal()); - return r; -} -} diff --git a/src/library/CMakeLists.txt b/src/library/CMakeLists.txt index 222202b1de..d778c326f3 100644 --- a/src/library/CMakeLists.txt +++ b/src/library/CMakeLists.txt @@ -20,4 +20,5 @@ add_library( init_attribute.cpp llvm.cpp ir_interpreter.cpp + instantiate_mvars.cpp ) diff --git a/src/library/instantiate_mvars.cpp b/src/library/instantiate_mvars.cpp new file mode 100644 index 0000000000..e7b73909e4 --- /dev/null +++ b/src/library/instantiate_mvars.cpp @@ -0,0 +1,750 @@ +/* +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Authors: Joachim Breitner +*/ +#include +#include +#include "util/name_set.h" +#include "util/name_hash_map.h" +#include "runtime/option_ref.h" +#include "runtime/array_ref.h" +#include "kernel/instantiate.h" +#include "kernel/expr.h" +#include "library/scope_cache.h" + +/* +This module provides an implementation of `instantiateMVars` with linear +complexity in the presence of nested delayed-assigned metavariables and +improved sharing. It proceeds in two passes. + +Terminology (for this file): + +* Direct MVar: an MVar that is not delayed-assigned. +* Pending MVar: the direct MVar stored in a `DelayedMetavarAssignment`. +* Assigned MVar: a direct MVar with an assignment, or a delayed-assigned MVar + with an assigned pending MVar. +* MVar DAG: the directed acyclic graph of MVars reachable from the expression. +* Resolvable MVar: an MVar where all MVars reachable from it (including itself) + are assigned. +* Updateable MVar: an assigned direct MVar, or a delayed-assigned MVar that is + resolvable but not reachable from any other resolvable delayed-assigned MVar. + +In the MVar DAG, the updateable delayed-assigned MVars form a cut with only +assigned MVars behind it and no resolvable delayed-assigned MVars before it. + +Pass 1 (`instantiate_direct_fn`): + Traverses all MVars and expressions reachable from the initial expression and + * instantiates all updateable direct MVars, updating their assignment with + its instantiation, + * instantiates all level MVars, + * determines if there are any updateable delayed-assigned MVars. + +Pass 2 (`instantiate_delayed_fn`): + Only run if there are updateable delayed-assigned MVars. Has an "outer" and + an "inner" mode, depending on whether it has crossed the updateable-MVar cut. + + In outer mode (empty substitution), all MVars are either unassigned direct + MVars (left alone), non-updateable delayed-assigned MVars (pending MVar + traversed in outer mode and updated with the result), or updateable + delayed-assigned MVars. + + When a delayed-assigned MVar is encountered, its MVar DAG is explored (via + `is_resolvable_pending`) to determine if it is resolvable (and thus + updateable). Results are cached across invocations. + + If it is updateable, the substitution is initialized from its arguments and + traversal continues with the value of its pending MVar in inner mode. + + In inner mode (non-empty substitution), all encountered delayed-assigned + MVars are, by construction, resolvable but not updateable. The substitution + is carried along and extended as we cross such MVars. Pending MVars of these + delayed-assigned MVars are NOT updated with the result (as the result is + valid only for this substitution, not in general). + + Applying the substitution in one go, rather than instantiating each + delayed-assigned MVar on its own from inside out, avoids the quadratic + overhead of that approach when there are long chains of delayed-assigned + MVars. + + A special-crafted caching data structure, the `scope_cache`, ensures that + sharing is preserved even across different delayed-assigned MVars (and hence + with different substitutions), when possible. +*/ + +namespace lean { +extern "C" object * lean_get_lmvar_assignment(obj_arg mctx, obj_arg mid); +extern "C" object * lean_assign_lmvar(obj_arg mctx, obj_arg mid, obj_arg val); +extern "C" object * lean_get_mvar_assignment(obj_arg mctx, obj_arg mid); +extern "C" object * lean_get_delayed_mvar_assignment(obj_arg mctx, obj_arg mid); +extern "C" object * lean_delayed_mvar_assignment_fvars(obj_arg d); +extern "C" object * lean_delayed_mvar_assignment_mvar_id_pending(obj_arg d); +extern "C" object * lean_assign_mvar(obj_arg mctx, obj_arg mid, obj_arg val); +typedef object_ref metavar_ctx; +typedef object_ref delayed_assignment; + +static void assign_lmvar(metavar_ctx & mctx, name const & mid, level const & l) { + object * r = lean_assign_lmvar(mctx.steal(), mid.to_obj_arg(), l.to_obj_arg()); + mctx.set_box(r); +} + +static option_ref get_lmvar_assignment(metavar_ctx & mctx, name const & mid) { + return option_ref(lean_get_lmvar_assignment(mctx.to_obj_arg(), mid.to_obj_arg())); +} + +static void assign_mvar(metavar_ctx & mctx, name const & mid, expr const & e) { + object * r = lean_assign_mvar(mctx.steal(), mid.to_obj_arg(), e.to_obj_arg()); + mctx.set_box(r); +} + +static option_ref get_mvar_assignment(metavar_ctx & mctx, name const & mid) { + return option_ref(lean_get_mvar_assignment(mctx.to_obj_arg(), mid.to_obj_arg())); +} + +static option_ref get_delayed_mvar_assignment(metavar_ctx & mctx, name const & mid) { + return option_ref(lean_get_delayed_mvar_assignment(mctx.to_obj_arg(), mid.to_obj_arg())); +} + +static array_ref delayed_assignment_fvars(delayed_assignment const & d) { + return array_ref(lean_delayed_mvar_assignment_fvars(d.to_obj_arg())); +} + +static name delayed_assignment_mvar_id_pending(delayed_assignment const & d) { + return name(lean_delayed_mvar_assignment_mvar_id_pending(d.to_obj_arg())); +} + +/* Level metavariable instantiation. */ +class instantiate_lmvars_all_fn { + metavar_ctx & m_mctx; + lean::unordered_map m_cache; + std::vector m_saved; + + inline level cache(level const & l, level r, bool shared) { + if (shared) { + m_cache.insert(mk_pair(l.raw(), r)); + } + return r; + } +public: + instantiate_lmvars_all_fn(metavar_ctx & mctx):m_mctx(mctx) {} + level visit(level const & l) { + if (!has_mvar(l)) + return l; + bool shared = false; + if (is_shared(l)) { + auto it = m_cache.find(l.raw()); + if (it != m_cache.end()) { + return it->second; + } + shared = true; + } + switch (l.kind()) { + case level_kind::Succ: + return cache(l, update_succ(l, visit(succ_of(l))), shared); + case level_kind::Max: case level_kind::IMax: + return cache(l, update_max(l, visit(level_lhs(l)), visit(level_rhs(l))), shared); + case level_kind::Zero: case level_kind::Param: + lean_unreachable(); + case level_kind::MVar: { + option_ref r = get_lmvar_assignment(m_mctx, mvar_id(l)); + if (!r) { + return l; + } else { + level a(r.get_val()); + if (!has_mvar(a)) { + return a; + } else { + level a_new = visit(a); + if (!is_eqp(a, a_new)) { + m_saved.push_back(a); + assign_lmvar(m_mctx, mvar_id(l), a_new); + } + return a_new; + } + } + }} + } + level operator()(level const & l) { return visit(l); } +}; + +/* ============================================================================ + Pass 1: Instantiate updateable direct MVars with write-back. + For delayed-assigned MVars, pre-normalize the pending MVar's value + (resolving its direct MVar chains) but leave the delayed-assigned MVar + application in the expression. Also instantiates level MVars. + Unassigned MVars are left in place. + ============================================================================ */ + +class instantiate_direct_fn { + metavar_ctx & m_mctx; + instantiate_lmvars_all_fn m_level_fn; + name_set m_already_normalized; + /* Set to true when a delayed-assigned MVar with an assigned pending MVar + is encountered. Pass 2 is needed to resolve or write back such MVars. */ + bool m_has_updateable_delayed; + + lean::unordered_map m_cache; + std::vector m_saved; + + level visit_level(level const & l) { + return m_level_fn(l); + } + + levels visit_levels(levels const & ls) { + return map_reuse(ls, [&](level const & l) { return visit_level(l); }); + } + + inline expr cache(expr const & e, expr r, bool shared) { + if (shared) { + m_cache.insert(mk_pair(e.raw(), r)); + } + return r; + } + + /* Get and normalize an updateable direct MVar's assignment. Write back the + normalized value. */ + optional get_assignment(name const & mid) { + option_ref r = get_mvar_assignment(m_mctx, mid); + if (!r) { + return optional(); + } + expr a(r.get_val()); + if (!has_mvar(a) || m_already_normalized.contains(mid)) { + return optional(a); + } + m_already_normalized.insert(mid); + expr a_new = visit(a); + if (!is_eqp(a, a_new)) { + m_saved.push_back(a); + assign_mvar(m_mctx, mid, a_new); + } + return optional(a_new); + } + + /* Visit an application whose head is not an MVar, preserving sharing of + application prefixes (e.g. for `f a b c`, if only `c` changes, + the nodes `f a` and `f a b` are pointer-shared with the original). */ + expr visit_nonmvar_app(expr const & e) { + expr new_a = visit(app_arg(e)); + expr const & fn = app_fn(e); + expr new_f = is_app(fn) ? visit_nonmvar_app(fn) : visit(fn); + return update_app(e, new_f, new_a); + } + + /* Collect the application spine into a buffer and apply beta reduction. */ + expr visit_app_beta(expr const & f_new, expr const & e) { + buffer args; + expr const * curr = &e; + while (is_app(*curr)) { + args.push_back(visit(app_arg(*curr))); + curr = &app_fn(*curr); + } + bool preserve_data = false; + bool zeta = true; + return apply_beta(f_new, args.size(), args.data(), preserve_data, zeta); + } + + expr visit_app(expr const & e) { + expr const & f = get_app_fn(e); + if (!is_mvar(f)) { + return visit_nonmvar_app(e); + } + name const & mid = mvar_name(f); + /* Direct MVar assignment takes precedence. */ + if (auto f_new = get_assignment(mid)) { + return visit_app_beta(*f_new, e); + } + /* Check for delayed-assigned MVar. */ + option_ref d = get_delayed_mvar_assignment(m_mctx, mid); + if (d) { + /* Pre-normalize the pending MVar's value so pass 2 finds it ready. + Only trigger pass 2 if the pending MVar is actually assigned; + unassigned pending MVars will clearly fail the resolvability check. */ + name mid_pending = delayed_assignment_mvar_id_pending(d.get_val()); + if (get_assignment(mid_pending)) + m_has_updateable_delayed = true; + } + /* Unresolved MVar head: visit structurally, preserving app prefix sharing. */ + return visit_nonmvar_app(e); + } + + expr visit_mvar(expr const & e) { + name const & mid = mvar_name(e); + if (auto r = get_assignment(mid)) { + return *r; + } + /* Not a direct MVar with assignment. Check if delayed-assigned. */ + option_ref d = get_delayed_mvar_assignment(m_mctx, mid); + if (d) { + name mid_pending = delayed_assignment_mvar_id_pending(d.get_val()); + if (get_assignment(mid_pending)) + m_has_updateable_delayed = true; + } + return e; + } + +public: + instantiate_direct_fn(metavar_ctx & mctx) + : m_mctx(mctx), m_level_fn(mctx), m_has_updateable_delayed(false) {} + bool has_updateable_delayed() const { return m_has_updateable_delayed; } + + expr visit(expr const & e) { + if (!has_mvar(e)) + return e; + bool shared = false; + if (is_shared(e)) { + auto it = m_cache.find(e.raw()); + if (it != m_cache.end()) { + return it->second; + } + shared = true; + } + + switch (e.kind()) { + case expr_kind::BVar: + case expr_kind::Lit: case expr_kind::FVar: + lean_unreachable(); + case expr_kind::Sort: + return cache(e, update_sort(e, visit_level(sort_level(e))), shared); + case expr_kind::Const: + return cache(e, update_const(e, visit_levels(const_levels(e))), shared); + case expr_kind::MVar: + return visit_mvar(e); + case expr_kind::MData: + return cache(e, update_mdata(e, visit(mdata_expr(e))), shared); + case expr_kind::Proj: + return cache(e, update_proj(e, visit(proj_expr(e))), shared); + case expr_kind::App: + return cache(e, visit_app(e), shared); + case expr_kind::Pi: case expr_kind::Lambda: + return cache(e, update_binding(e, visit(binding_domain(e)), visit(binding_body(e))), shared); + case expr_kind::Let: + return cache(e, update_let(e, visit(let_type(e)), visit(let_value(e)), visit(let_body(e))), shared); + } + } + + expr operator()(expr const & e) { return visit(e); } +}; + +/* ============================================================================ + Pass 2: Resolve delayed-assigned MVars with fused fvar substitution. + Direct MVar chains have been pre-resolved by pass 1. + + Write-back behavior: + + Delayed-assigned MVars form a dependency tree: each delayed-assigned MVar's + pending MVar value may reference other delayed-assigned MVars. Some subtrees + of this tree are fully resolvable (all delayed-assigned MVars within are + resolvable), while others are not. + + Pass 2 fully resolves every maximal resolvable subtree. The roots of these + subtrees — updateable delayed-assigned MVars that are resolvable but whose + parent in the tree is not — form the updateable-MVar cut through the + dependency tree. Above the cut sit non-resolvable delayed-assigned MVars; + below the cut, everything is resolved. + + Pass 2 writes back the normalized pending MVar values of delayed-assigned + MVars above the cut (the non-resolvable ones whose children may have been + resolved). This is exactly the right set: these MVars are visited in outer + mode (empty fvar substitution), so their normalized values are suitable for + storing in the mctx. MVars below the cut are visited in inner mode + (non-empty substitution, fvars replaced by arguments), so their intermediate + values cannot be written back. + ============================================================================ */ + +struct fvar_subst_entry { + unsigned depth; + unsigned scope; + expr value; +}; + +class instantiate_delayed_fn { + metavar_ctx & m_mctx; + name_hash_map m_fvar_subst; + unsigned m_depth; + + /* Scope-aware cache for (ptr, depth) → expr with lazy staleness detection. */ + struct key_hasher { + std::size_t operator()(std::pair const & p) const { + return hash((size_t)p.first >> 3, p.second); + } + }; + typedef std::pair cache_key; + scope_cache m_cache; + + /* After visit() returns, this holds the maximum fvar-substitution + scope that contributed to the result — i.e., the outermost scope at which the + result is valid and can be cached. Updated monotonically (via max) through + the save/reset/restore pattern in visit(). */ + unsigned m_result_scope; + + /* Write-back support: in outer mode, normalize and write back direct MVar + assignments. Downstream code (e.g. MutualDef.mkInitialUsedFVarsMap) reads + stored assignments and expects inner delayed-assigned MVars to be resolved. */ + name_set m_already_normalized; + std::vector m_saved; + + /* Resolvability caches — persistent across all delayed-assigned MVar + resolutions. A pending MVar is resolvable if its assigned value + (normalized by pass 1) would become MVar-free after resolution: all + remaining MVars must be delayed-assigned MVars in app position with + enough arguments, whose own pending MVars are also resolvable. */ + lean::unordered_map m_resolvable_expr_cache; + name_hash_map m_resolvable_pending_cache; /* 0 = in-progress, 1 = yes, 2 = no */ + + bool is_resolvable_pending(name const & pending) { + auto it = m_resolvable_pending_cache.find(pending); + if (it != m_resolvable_pending_cache.end()) + return it->second == 1; + /* Mark in-progress (cycle guard — shouldn't happen). */ + m_resolvable_pending_cache[pending] = 0; + option_ref r = get_mvar_assignment(m_mctx, pending); + if (!r) { + m_resolvable_pending_cache[pending] = 2; + return false; + } + bool ok = is_resolvable_expr(expr(r.get_val())); + m_resolvable_pending_cache[pending] = ok ? 1 : 2; + return ok; + } + + bool is_resolvable_expr(expr const & e) { + if (!has_expr_mvar(e)) return true; + if (is_shared(e)) { + auto it = m_resolvable_expr_cache.find(e.raw()); + if (it != m_resolvable_expr_cache.end()) + return it->second; + } + bool r = is_resolvable_expr_core(e); + if (is_shared(e)) + m_resolvable_expr_cache[e.raw()] = r; + return r; + } + + bool is_resolvable_expr_core(expr const & e) { + switch (e.kind()) { + case expr_kind::MVar: + /* Bare MVar — direct MVar assignments were resolved by pass 1. Stuck. */ + return false; + case expr_kind::App: { + expr const & f = get_app_fn(e); + if (is_mvar(f)) { + name const & mid = mvar_name(f); + option_ref d = + get_delayed_mvar_assignment(m_mctx, mid); + if (!d) return false; + array_ref fvars = delayed_assignment_fvars(d.get_val()); + if (fvars.size() > get_app_num_args(e)) + return false; /* not enough args */ + name mid_pending = delayed_assignment_mvar_id_pending(d.get_val()); + if (!is_resolvable_pending(mid_pending)) + return false; + /* Check args too. */ + expr const * curr = &e; + while (is_app(*curr)) { + if (!is_resolvable_expr(app_arg(*curr))) + return false; + curr = &app_fn(*curr); + } + return true; + } + return is_resolvable_expr(app_fn(e)) && is_resolvable_expr(app_arg(e)); + } + case expr_kind::Lambda: case expr_kind::Pi: + return is_resolvable_expr(binding_domain(e)) && + is_resolvable_expr(binding_body(e)); + case expr_kind::Let: + return is_resolvable_expr(let_type(e)) && + is_resolvable_expr(let_value(e)) && + is_resolvable_expr(let_body(e)); + case expr_kind::MData: + return is_resolvable_expr(mdata_expr(e)); + case expr_kind::Proj: + return is_resolvable_expr(proj_expr(e)); + default: + return true; + } + } + + /* Outer mode: no fvar substitution active; inner mode: inside a + resolvable delayed-assigned MVar with fvars mapped to arguments. */ + bool in_outer_mode() const { + return m_fvar_subst.empty(); + } + + optional lookup_fvar(name const & fid) { + auto it = m_fvar_subst.find(fid); + if (it == m_fvar_subst.end()) + return optional(); + m_result_scope = std::max(m_result_scope, it->second.scope); + unsigned d = m_depth - it->second.depth; + if (d == 0) + return optional(it->second.value); + return optional(lift_loose_bvars(it->second.value, d)); + } + + /* Get a direct MVar assignment. Visit it to resolve delayed-assigned + MVars and apply the fvar substitution. + In outer mode, normalize and write back the result to the mctx. + Downstream code (e.g. MutualDef.mkInitialUsedFVarsMap) reads stored + assignments and expects inner delayed-assigned MVars to be resolved. + In inner mode, no write-back: the result contains fvar-substituted + terms not suitable for the mctx. */ + optional get_assignment(name const & mid) { + option_ref r = get_mvar_assignment(m_mctx, mid); + if (!r) + return optional(); + expr a(r.get_val()); + if (in_outer_mode()) { + if (m_already_normalized.contains(mid)) + return optional(a); + m_already_normalized.insert(mid); + expr a_new = visit(a); + if (!is_eqp(a, a_new)) { + m_saved.push_back(a); + assign_mvar(m_mctx, mid, a_new); + } + return optional(a_new); + } else { + return optional(visit(a)); + } + } + + expr visit_delayed(array_ref const & fvars, name const & mid_pending, + expr const & e) { + buffer args; + expr const * curr = &e; + while (is_app(*curr)) { + args.push_back(visit(app_arg(*curr))); + curr = &app_fn(*curr); + } + + size_t fvar_count = fvars.size(); + size_t extra_count = args.size() - fvar_count; + + /* Push a new scope and extend the fvar substitution. */ + m_cache.push(); + struct saved_entry { name key; bool had_old; fvar_subst_entry old; }; + std::vector saved_entries; + saved_entries.reserve(fvar_count); + for (size_t i = 0; i < fvar_count; i++) { + name const & fid = fvar_name(fvars[i]); + auto old_it = m_fvar_subst.find(fid); + if (old_it != m_fvar_subst.end()) { + saved_entries.push_back({fid, true, old_it->second}); + } else { + saved_entries.push_back({fid, false, {0, 0, expr()}}); + } + m_fvar_subst[fid] = {m_depth, m_cache.scope(), args[args.size() - 1 - i]}; + } + + /* Get the pending MVar's value directly — it must be assigned (pass 1 + pre-normalized it). No write-back: we are in inner mode. */ + option_ref pending_val = get_mvar_assignment(m_mctx, mid_pending); + lean_assert(!!pending_val); + expr val_new = visit(expr(pending_val.get_val())); + + /* Pop scope; stale entries are detected by generation mismatch on lookup. */ + m_cache.pop(); + + /* The result no longer depends on the popped scope — all fvars from + that scope have been substituted. Clamp result_scope to the current + scope so that cache entries for this result (and ancestors) are not + spuriously invalidated. Dependencies on outer scopes (from fvars of + enclosing delayed MVars) are preserved since they are ≤ m_scope. */ + m_result_scope = std::min(m_result_scope, m_cache.scope()); + + /* Restore the fvar substitution. */ + for (auto & se : saved_entries) { + if (!se.had_old) { + m_fvar_subst.erase(se.key); + } else { + m_fvar_subst[se.key] = se.old; + } + } + + /* Use apply_beta instead of mk_rev_app: pass 1's beta-reduction may have + changed delayed-assigned MVar arguments (e.g., substituting a bvar with a + concrete value), so the resolved pending MVar value may be a lambda that + needs beta-reduction with the extra args. */ + bool preserve_data = false; + bool zeta = true; + return apply_beta(val_new, extra_count, args.data(), preserve_data, zeta); + } + + /* Visit an application whose head is not an MVar, preserving sharing of + application prefixes. */ + expr visit_nonmvar_app(expr const & e) { + expr new_a = visit(app_arg(e)); + expr const & fn = app_fn(e); + expr new_f = is_app(fn) ? visit_nonmvar_app(fn) : visit(fn); + return update_app(e, new_f, new_a); + } + + expr visit_app(expr const & e) { + expr const & f = get_app_fn(e); + if (!is_mvar(f)) { + return visit_nonmvar_app(e); + } + name const & mid = mvar_name(f); + /* Direct MVar assignments were resolved by pass 1. */ + lean_assert(!get_mvar_assignment(m_mctx, mid)); + /* Check for delayed-assigned MVar. */ + option_ref d = get_delayed_mvar_assignment(m_mctx, mid); + if (!d) { + return visit_nonmvar_app(e); + } + array_ref fvars = delayed_assignment_fvars(d.get_val()); + name mid_pending = delayed_assignment_mvar_id_pending(d.get_val()); + if (fvars.size() > get_app_num_args(e)) { + return visit_nonmvar_app(e); + } + if (is_resolvable_pending(mid_pending)) { + /* Updateable delayed-assigned MVar: cross the cut into inner mode. */ + return visit_delayed(fvars, mid_pending, e); + } else { + /* Non-resolvable delayed-assigned MVars only appear in outer mode: + inside a resolvable subtree all nested delayed-assigned MVars are + resolvable too. */ + lean_assert(in_outer_mode()); + /* Normalize the pending MVar's value for mctx write-back + (see write-back comment above). */ + (void)get_assignment(mid_pending); + return visit_nonmvar_app(e); + } + } + + expr visit_fvar(expr const & e) { + name const & fid = fvar_name(e); + if (auto r = lookup_fvar(fid)) { + return *r; + } + return e; + } + +public: + instantiate_delayed_fn(metavar_ctx & mctx) + : m_mctx(mctx), m_depth(0), m_result_scope(0) {} + + expr visit(expr const & e) { + if ((!has_fvar(e) || in_outer_mode()) && !has_expr_mvar(e)) + return e; + + bool shared = false; + if (is_shared(e)) { + if (auto r = m_cache.lookup(cache_key(e.raw(), m_depth), m_result_scope)) + return *r; + shared = true; + } + + /* Save and reset the result scope for this subtree. + After computing, cache_insert uses m_result_scope to place the entry + at the outermost valid scope level. Then we restore the parent's + watermark, taking the max with our contribution. */ + unsigned saved_result_scope = m_result_scope; + m_result_scope = 0; + + expr r; + switch (e.kind()) { + case expr_kind::BVar: + case expr_kind::Lit: + lean_unreachable(); + case expr_kind::FVar: + r = visit_fvar(e); + goto done; /* skip caching for fvars */ + case expr_kind::Sort: + case expr_kind::Const: + /* Sorts and Consts have no fvars and no expr MVars (level MVars + were resolved by pass 1), so the early exit above catches them. */ + lean_unreachable(); + case expr_kind::MVar: + /* Bare MVars in pass 2 are unassigned direct MVars: direct MVar + assignments were resolved by pass 1, and resolvable pending MVar + values contain no bare unassigned MVars. They only appear in + outer mode (at the top level or during write-back normalization). */ + lean_assert(in_outer_mode()); + lean_assert(!get_mvar_assignment(m_mctx, mvar_name(e))); + r = e; + goto done; + case expr_kind::MData: + r = update_mdata(e, visit(mdata_expr(e))); + break; + case expr_kind::Proj: + r = update_proj(e, visit(proj_expr(e))); + break; + case expr_kind::App: + r = visit_app(e); + break; + case expr_kind::Pi: case expr_kind::Lambda: { + expr d = visit(binding_domain(e)); + m_depth++; + expr b = visit(binding_body(e)); + m_depth--; + r = update_binding(e, d, b); + break; + } + case expr_kind::Let: { + expr t = visit(let_type(e)); + expr v = visit(let_value(e)); + m_depth++; + expr b = visit(let_body(e)); + m_depth--; + r = update_let(e, t, v, b); + break; + } + } + if (shared) { + r = m_cache.insert(cache_key(e.raw(), m_depth), r, m_result_scope); + } + + done: + m_result_scope = std::max(saved_result_scope, m_result_scope); + return r; + } + + expr operator()(expr const & e) { return visit(e); } +}; + +/* ============================================================================ + Entry points: run pass 1 then pass 2. + ============================================================================ */ + +static object * run_instantiate_all(object * m, object * e) { + metavar_ctx mctx(m); + + /* Pass 1: instantiate updateable direct MVars, pre-normalize pending MVar values. */ + instantiate_direct_fn pass1(mctx); + expr e1 = pass1(expr(e)); + + /* Pass 2: resolve delayed-assigned MVars with fused fvar substitution. + Skip if pass 1 found no delayed-assigned MVars with assigned pending + MVars — none need resolution or write-back. */ + expr e2; + if (!pass1.has_updateable_delayed()) { + e2 = e1; + } else { + instantiate_delayed_fn pass2(mctx); + e2 = pass2(e1); + } + + /* (mctx, expr) */ + object * r = alloc_cnstr(0, 2, 0); + cnstr_set(r, 0, mctx.steal()); + cnstr_set(r, 1, e2.steal()); + return r; +} + +extern "C" LEAN_EXPORT object * lean_instantiate_level_mvars(object * m, object * l) { + metavar_ctx mctx(m); + level l_new = instantiate_lmvars_all_fn(mctx)(level(l)); + object * r = alloc_cnstr(0, 2, 0); + cnstr_set(r, 0, mctx.steal()); + cnstr_set(r, 1, l_new.steal()); + return r; +} + +extern "C" LEAN_EXPORT object * lean_instantiate_expr_mvars(object * m, object * e) { + return run_instantiate_all(m, e); +} +} diff --git a/src/library/scope_cache.h b/src/library/scope_cache.h new file mode 100644 index 0000000000..3b0d6114bd --- /dev/null +++ b/src/library/scope_cache.h @@ -0,0 +1,176 @@ +/* +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Authors: Joachim Breitner +*/ +#pragma once +#include +#include +#include "util/alloc.h" +#include "runtime/optional.h" + +namespace lean { + +/* +Conceptually, the scope cache is a stack of `Key → (Value × Scope)` hashmaps. +The `Scope` is a counter indicating the lowest position in the stack for which +the entry is valid. + +Its purpose is to provide caching for an operation that: + * maintains scopes (e.g. local contexts, substitutions). Higher stack + positions correspond to inner, more local scopes. + * For a given key, the result may depend on all or part of that scope. + * At lookup time, it is not known whether the value for a key will depend on + all or part of the scope, so only entries for the current innermost scope + are considered. + * At insert time, it is known which outermost scope the result depends on + (the "result scope"), and the result is valid for all scopes between that + and the innermost scope. + +The operations are: + * push(): push a new (empty) hashmap onto the stack. + * pop(): pop the top hashmap from the stack. + * scope(): current size of the stack (i.e. the index of the innermost scope). + * lookup(key, result_scope): look up in the top hashmap, returning the value + and propagating its result scope into `result_scope` via max. + * insert(key, value, result_scope): insert a key-value pair into the hashmaps + in the stack at depths in the range from `result_scope` to `scope()`. If it + encounters an existing value along the way, uses and returns that value for + improved sharing. + +The implementation inverts the data structure to a hashmap of stacks for +efficiency. It uses a generation counter to assign a unique identifier to each +scope, and maintains a persistent linked list of these to represent the current +scope stack. Cache entries are not touched upon pop(); instead they are lazily +cleaned up when accessed (the `rewind` operation). Upon insert, instead of +duplicating the entry for all valid scopes, it stores one entry with the range +of scopes it is valid for. +*/ +template> +class scope_cache { + struct scope_gen_node { + unsigned gen; + scope_gen_node * tail; /* parent scope, or nullptr for scope -1 */ + }; + + struct cache_entry { + Value result; + unsigned scope_level; /* scope at which this entry is (currently) valid */ + scope_gen_node * scope_gens; /* snapshot of scope_gens list at store time */ + unsigned result_scope; /* maximum scope that contributed to the result */ + }; + + typedef lean::unordered_map, Hash> cache_map; + + cache_map m_cache; + std::deque m_gen_arena; + scope_gen_node * m_scope_gens_list; + unsigned m_gen_counter; + unsigned m_scope; + +public: + scope_cache() : m_scope_gens_list(nullptr), m_gen_counter(0), m_scope(0) { + m_gen_arena.push_back({0, nullptr}); + m_scope_gens_list = &m_gen_arena.back(); + } + + unsigned scope() const { return m_scope; } + + /* Enter a new scope. Bumps the generation counter so that stale + entries at the new scope level are detected on lookup. */ + void push() { + m_scope++; + m_gen_counter++; + m_gen_arena.push_back({m_gen_counter, m_scope_gens_list}); + m_scope_gens_list = &m_gen_arena.back(); + } + + /* Leave the current scope. Follows the tail of the persistent + generation list back to the parent scope. */ + void pop() { + m_scope--; + m_scope_gens_list = m_scope_gens_list->tail; + } + +private: + /* Lazily clean up the top of a per-key entry stack: degrade entries + whose scopes were popped and evict entries that are stale due to + popped result scopes or scope re-entry. After rewind, either the + stack is empty or its top entry satisfies scope_level <= m_scope + with a matching scope generation. */ + void rewind(std::vector & stack) { + while (!stack.empty()) { + auto & top = stack.back(); + /* Discard entries whose result depends on popped scopes. */ + if (top.result_scope > m_scope) { + stack.pop_back(); + continue; + } + /* Degrade: follow tail pointers for scopes that were popped. */ + while (top.scope_level > m_scope) { + top.scope_gens = top.scope_gens->tail; + top.scope_level--; + } + /* Check generation at scope_level. When scope_level < m_scope, + walk the current list down to scope_level first. */ + scope_gen_node * current_node = m_scope_gens_list; + for (unsigned i = m_scope; i > top.scope_level; i--) + current_node = current_node->tail; + if (top.scope_gens->gen == current_node->gen) return; + /* Generation mismatch: scope was re-entered. Walk both lists + in lockstep to find a valid level or exhaust to result_scope. */ + scope_gen_node * entry_node = top.scope_gens; + unsigned level = top.scope_level; + while (level > top.result_scope) { + entry_node = entry_node->tail; + current_node = current_node->tail; + level--; + if (entry_node->gen == current_node->gen) { + top.scope_level = level; + top.scope_gens = entry_node; + return; /* now scope_level < m_scope */ + } + } + /* No valid level found → discard. */ + stack.pop_back(); + } + } + +public: + /* Look up a cached result for the given key at the current scope. + On hit, updates `result_scope = max(result_scope, entry.result_scope)` + and returns the cached result. On miss, returns none. */ + optional lookup(Key const & key, unsigned & result_scope) { + auto it = m_cache.find(key); + if (it == m_cache.end()) return {}; + auto & stack = it->second; + rewind(stack); + if (stack.empty()) return {}; + auto & top = stack.back(); + if (top.scope_level != m_scope) return {}; + result_scope = std::max(result_scope, top.result_scope); + return optional(top.result); + } + + /* Insert a result for the given key at the current scope. + `result_scope` is the maximum scope that contributed to the result; + the entry is only valid when all scopes up to result_scope are unchanged. + If a valid entry with the same `result_scope` already exists, its value + is reused for sharing; the returned reference is the stored value. */ + Value const & insert(Key const & key, Value const & result, unsigned result_scope) { + auto & stack = m_cache[key]; + rewind(stack); + Value shared = result; + if (!stack.empty() && stack.back().result_scope == result_scope) { + shared = stack.back().result; + } + while (!stack.empty() && stack.back().scope_level >= result_scope) { + stack.pop_back(); + } + stack.push_back({std::move(shared), m_scope, m_scope_gens_list, result_scope}); + return stack.back().result; + } +}; + +} diff --git a/tests/elab/1179b.lean.out.expected b/tests/elab/1179b.lean.out.expected index d9985ac3a8..ba4c34d35d 100644 --- a/tests/elab/1179b.lean.out.expected +++ b/tests/elab/1179b.lean.out.expected @@ -2,15 +2,13 @@ def Foo.bar.match_1.{u_1} : {l₂ : Nat} → (motive : Foo l₂ → Sort u_1) → (t₂ : Foo l₂) → ((s₁ : Foo l₂) → motive s₁.cons) → ((x : Foo l₂) → motive x) → motive t₂ := fun {l₂} motive t₂ h_1 h_2 => - (fun t₂_1 => - Foo.bar._sparseCasesOn_1 (motive := fun a x => l₂ = a → t₂ ≍ x → motive t₂) t₂_1 - (fun {l} t h => - Eq.ndrec (motive := fun {l} => (t : Foo l) → t₂ ≍ t.cons → motive t₂) - (fun t h => Eq.symm (eq_of_heq h) ▸ h_1 t) h t) - fun h h_3 => - Eq.ndrec (motive := fun a => (t₂_2 : Foo a) → Nat.hasNotBit 2 t₂_2.ctorIdx → t₂ ≍ t₂_2 → motive t₂) - (fun t₂_2 h h_4 => - Eq.ndrec (motive := fun t₂_3 => Nat.hasNotBit 2 t₂_3.ctorIdx → motive t₂) (fun h => h_2 t₂) (eq_of_heq h_4) - h) - h_3 t₂_1 h) - t₂ (Eq.refl l₂) (HEq.refl t₂) + Foo.bar._sparseCasesOn_1 (motive := fun a x => l₂ = a → t₂ ≍ x → motive t₂) t₂ + (fun {l} t h => + Eq.ndrec (motive := fun {l} => (t : Foo l) → t₂ ≍ t.cons → motive t₂) (fun t h => Eq.symm (eq_of_heq h) ▸ h_1 t) h + t) + (fun h h_3 => + Eq.ndrec (motive := fun a => (t₂_1 : Foo a) → Nat.hasNotBit 2 t₂_1.ctorIdx → t₂ ≍ t₂_1 → motive t₂) + (fun t₂_1 h h_4 => + Eq.ndrec (motive := fun t₂_2 => Nat.hasNotBit 2 t₂_2.ctorIdx → motive t₂) (fun h => h_2 t₂) (eq_of_heq h_4) h) + h_3 t₂ h) + (Eq.refl l₂) (HEq.refl t₂) diff --git a/tests/elab/depElim1.lean.out.expected b/tests/elab/depElim1.lean.out.expected index 9a84846e5b..6104b68b92 100644 --- a/tests/elab/depElim1.lean.out.expected +++ b/tests/elab/depElim1.lean.out.expected @@ -24,29 +24,28 @@ def elimTest2.{u_1, u_2} : (α : Type u_1) → (x : α) → (xs : Vec α n) → (y : α) → (ys : Vec α n) → motive (n + 1) (Vec.cons x xs) (Vec.cons y ys)) → motive n xs ys := fun α motive n xs ys h_1 h_2 => - (fun xs_1 => - Vec.casesOn (motive := fun a x => n = a → xs ≍ x → motive n xs ys) xs_1 - (fun h => - Eq.ndrec (motive := fun n => (xs ys : Vec α n) → xs ≍ Vec.nil → motive n xs ys) - (fun xs ys h => - ⋯ ▸ - Vec.casesOn (motive := fun a x => 0 = a → ys ≍ x → motive 0 Vec.nil ys) ys (fun h h_3 => ⋯ ▸ h_1 ()) - (fun {n} a a_1 h => False.elim ⋯) ⋯ ⋯) - ⋯ xs ys) - fun {n_1} a a_1 h => - Eq.ndrec (motive := fun n => (xs ys : Vec α n) → xs ≍ Vec.cons a a_1 → motive n xs ys) - (fun xs ys h => - ⋯ ▸ - Vec.casesOn (motive := fun a_2 x => n_1 + 1 = a_2 → ys ≍ x → motive (n_1 + 1) (Vec.cons a a_1) ys) ys - (fun h => False.elim ⋯) - (fun {n} a_2 a_3 h => - n_1.elimOffset n 1 h fun x => - Eq.ndrec (motive := fun {n} => - (a_4 : Vec α n) → ys ≍ Vec.cons a_2 a_4 → motive (n_1 + 1) (Vec.cons a a_1) ys) - (fun a_4 h => ⋯ ▸ h_2 n_1 a a_1 a_2 a_4) x a_3) - ⋯ ⋯) - ⋯ xs ys) - xs ⋯ ⋯ + Vec.casesOn (motive := fun a x => n = a → xs ≍ x → motive n xs ys) xs + (fun h => + Eq.ndrec (motive := fun n => (xs ys : Vec α n) → xs ≍ Vec.nil → motive n xs ys) + (fun xs ys h => + ⋯ ▸ + Vec.casesOn (motive := fun a x => 0 = a → ys ≍ x → motive 0 Vec.nil ys) ys (fun h h_3 => ⋯ ▸ h_1 ()) + (fun {n} a a_1 h => False.elim ⋯) ⋯ ⋯) + ⋯ xs ys) + (fun {n_1} a a_1 h => + Eq.ndrec (motive := fun n => (xs ys : Vec α n) → xs ≍ Vec.cons a a_1 → motive n xs ys) + (fun xs ys h => + ⋯ ▸ + Vec.casesOn (motive := fun a_2 x => n_1 + 1 = a_2 → ys ≍ x → motive (n_1 + 1) (Vec.cons a a_1) ys) ys + (fun h => False.elim ⋯) + (fun {n} a_2 a_3 h => + n_1.elimOffset n 1 h fun x => + Eq.ndrec (motive := fun {n} => + (a_4 : Vec α n) → ys ≍ Vec.cons a_2 a_4 → motive (n_1 + 1) (Vec.cons a a_1) ys) + (fun a_4 h => ⋯ ▸ h_2 n_1 a a_1 a_2 a_4) x a_3) + ⋯ ⋯) + ⋯ xs ys) + ⋯ ⋯ elimTest3 : forall (α : Type.{u_1}) (β : Type.{u_2}) (motive : (List.{u_1} α) -> (List.{u_2} β) -> Sort.{u_3}) (x : List.{u_1} α) (y : List.{u_2} β), (Unit -> (motive (List.nil.{u_1} α) (List.nil.{u_2} β))) -> (forall (a : α) (b : β), motive (List.cons.{u_1} α a (List.nil.{u_1} α)) (List.cons.{u_2} β b (List.nil.{u_2} β))) -> (forall (a₁ : α) (a₂ : α) (as : List.{u_1} α) (b₁ : β) (b₂ : β) (bs : List.{u_2} β), motive (List.cons.{u_1} α a₁ (List.cons.{u_1} α a₂ as)) (List.cons.{u_2} β b₁ (List.cons.{u_2} β b₂ bs))) -> (forall (as : List.{u_1} α) (bs : List.{u_2} β), motive as bs) -> (motive x y) def elimTest3.{u_1, u_2, u_3} : (α : Type u_1) → (β : Type u_2) → diff --git a/tests/elab/scopeCacheProofs.lean b/tests/elab/scopeCacheProofs.lean new file mode 100644 index 0000000000..9390dd5213 --- /dev/null +++ b/tests/elab/scopeCacheProofs.lean @@ -0,0 +1,1690 @@ +module + +/-! +# Scope Cache: Specification and Verification + +This file was mostly generated by Claude and verified by Lean. If the proofs +here break due to upstream changes and fixing them is a hassle, it is fine to +just delete this test file. + +This file models the `scope_cache` data structure from `src/library/scope_cache.h` +and verifies that the optimized implementation refines the simple specification. + +## Structure + +- `Spec`: The specification — an eager entry-list model where stale entries are + cleaned immediately on pop. No generation counters needed. +- `Imp`: The implementation — a lazy entry-list model with generation-based + staleness detection (matching the C++ `scope_cache`). +- `Tests`: Exhaustive testing of both models on all bounded operation sequences. +- `Proofs`: Formal proof that Imp refines Spec via a simulation invariant. +-/ + +namespace ScopeCache + +/-! ## Specification: Eager entry-list model + +Each key independently maintains a list of entries. Each entry has: +- `result`: the cached value +- `scopeLevel`: the scope at which it was inserted +- `resultScope`: the outermost scope the result depends on + +The entry covers scope levels `resultScope` through `scopeLevel`. +Entries are non-overlapping and sorted by `scopeLevel`. + +On pop, entries whose `resultScope` exceeds the new scope are discarded. +This is the "eager" cleanup — no lazy staleness detection needed. +-/ +namespace Spec + +structure Entry where + result : Nat + scopeLevel : Nat + resultScope : Nat + deriving Repr, DecidableEq, BEq + +structure State where + entries : List Entry + scopeN : Nat + deriving Repr, DecidableEq, BEq + +def empty : State := { entries := [], scopeN := 0 } + +def push (s : State) : State := + { s with scopeN := s.scopeN + 1 } + +def pop (s : State) : State := + let newScope := s.scopeN - 1 + { entries := (s.entries.filter fun e => e.resultScope ≤ newScope).map fun e => + if e.scopeLevel > newScope then { e with scopeLevel := newScope } else e + scopeN := newScope } + +def lookup (s : State) : Option (Nat × Nat) := + match s.entries.getLast? with + | some top => if top.scopeLevel == s.scopeN then some (top.result, top.resultScope) else none + | none => none + +def insert (s : State) (value : Nat) (resultScope : Nat) : State × Nat := + let shared := match s.entries.getLast? with + | some top => if top.resultScope == resultScope then top.result else value + | none => value + let entries' := s.entries.filter fun e => e.scopeLevel < resultScope + let newEntry : Entry := { result := shared, scopeLevel := s.scopeN, resultScope } + ({ s with entries := entries' ++ [newEntry] }, shared) + +end Spec + +/-! ## Implementation: Lazy entry-list model with generation counters + +Models the C++ `scope_cache` for a single key. Entries are NOT cleaned on pop; +instead, they are lazily validated via generation counters when accessed (rewind). +-/ +namespace Imp + +abbrev GenList := List Nat + +structure Entry where + result : Nat + scopeLevel : Nat + scopeGens : GenList -- snapshot of gen list at insert time + resultScope : Nat + deriving Repr, DecidableEq, BEq, Inhabited + +structure State where + entries : List Entry + scopeN : Nat + genCounter : Nat + currentGens : GenList -- head = gen for scopeN + deriving Repr + +def empty : State := + { entries := [], scopeN := 0, genCounter := 0, currentGens := [0] } + +def push (s : State) : State := + { s with + scopeN := s.scopeN + 1 + genCounter := s.genCounter + 1 + currentGens := (s.genCounter + 1) :: s.currentGens } + +def pop (s : State) : State := + { s with + scopeN := s.scopeN - 1 + currentGens := s.currentGens.tail } + +/-- Search from `level` down to `resultScope` for a level where entry and current gens match. + Returns the matching level and the entry's remaining gen list at that level. -/ +def findValidLevel (eGens cGens : GenList) (level resultScope : Nat) : + Option (Nat × GenList) := + match eGens.head?, cGens.head? with + | some eg, some cg => + if eg == cg then some (level, eGens) + else if h : level ≤ resultScope then none + else findValidLevel eGens.tail cGens.tail (level - 1) resultScope + | _, _ => none +termination_by level - resultScope + +/-- Rewind: clean up the entry stack from the back. -/ +def rewind (entries : List Entry) (scopeN : Nat) (currentGens : GenList) : List Entry := + go entries.reverse [] +where + go : List Entry → List Entry → List Entry + | [], acc => acc + | top :: rest, acc => + if top.resultScope > scopeN then + go rest acc + else + let degradedLevel := min top.scopeLevel scopeN + let degradedGens := top.scopeGens.drop (top.scopeLevel - degradedLevel) + let currentGensAligned := currentGens.drop (scopeN - degradedLevel) + match findValidLevel degradedGens currentGensAligned degradedLevel top.resultScope with + | some (lvl, eGens) => + rest.reverse ++ [{ top with scopeLevel := lvl, scopeGens := eGens }] ++ acc + | none => go rest acc + +def lookup (s : State) : Option (Nat × Nat) := + let entries' := rewind s.entries s.scopeN s.currentGens + match entries'.getLast? with + | some top => if top.scopeLevel == s.scopeN then some (top.result, top.resultScope) else none + | none => none + +def insert (s : State) (value : Nat) (resultScope : Nat) : State × Nat := + let entries' := rewind s.entries s.scopeN s.currentGens + let shared := match entries'.getLast? with + | some top => if top.resultScope == resultScope then top.result else value + | none => value + let entries'' := entries'.filter fun e => e.scopeLevel < resultScope + let newEntry : Entry := { + result := shared, scopeLevel := s.scopeN, scopeGens := s.currentGens, resultScope + } + ({ s with entries := entries'' ++ [newEntry] }, shared) + +/-- Erase generation info from an Imp entry to get a Spec entry. -/ +def Entry.toSpec (e : Entry) : Spec.Entry := + { result := e.result, scopeLevel := e.scopeLevel, resultScope := e.resultScope } + +end Imp + +/-! ## Tests -/ +namespace Tests + +-- Basic Spec tests +#eval do + let s := Spec.empty + assert! Spec.lookup s == none + let (s, _) := Spec.insert s 100 0 + assert! Spec.lookup s == some (100, 0) + let s := Spec.push s + assert! Spec.lookup s == none -- entry at scope 0, not at scope 1 + let (s, v) := Spec.insert s 200 0 + assert! v == 100 -- sharing + assert! Spec.lookup s == some (100, 0) + let s := Spec.pop s + assert! Spec.lookup s == some (100, 0) + return () + +-- Basic Imp tests +#eval do + let s := Imp.empty + assert! Imp.lookup s == none + let (s, _) := Imp.insert s 100 0 + assert! Imp.lookup s == some (100, 0) + let s := Imp.push s + assert! Imp.lookup s == none + let (s, v) := Imp.insert s 200 0 + assert! v == 100 -- sharing + assert! Imp.lookup s == some (100, 0) + let s := Imp.pop s + assert! Imp.lookup s == some (100, 0) + return () + +-- Re-entry test +#eval do + let s := Imp.empty + let s := Imp.push s + let (s, _) := Imp.insert s 100 1 + assert! Imp.lookup s == some (100, 1) + let s := Imp.pop s + let s := Imp.push s -- re-enter scope 1 + assert! Imp.lookup s == none -- stale! + return () + +-- Degradation test +#eval do + let s := Imp.empty + let s := Imp.push s + let s := Imp.push s + let s := Imp.push s -- scope 3 + let (s, _) := Imp.insert s 100 1 + assert! Imp.lookup s == some (100, 1) + let s := Imp.pop s -- scope 2 + assert! Imp.lookup s == some (100, 1) + let s := Imp.pop s -- scope 1 + assert! Imp.lookup s == some (100, 1) + return () + +-- WalkDown test: all scopes re-entered, walkDown must iterate to scope 0 +#eval do + let s := Imp.empty + let s := Imp.push s -- scope 1 + let s := Imp.push s -- scope 2 + let s := Imp.push s -- scope 3 + let (s, _) := Imp.insert s 100 0 + let s := Imp.pop s -- scope 2 + let s := Imp.pop s -- scope 1 + let s := Imp.pop s -- scope 0 + let s := Imp.push s -- scope 1 (re-enter) + let s := Imp.push s -- scope 2 (re-enter) + let s := Imp.push s -- scope 3 (re-enter) + -- Entry should survive at level 0 (only scope 0 unchanged) + let (s, v) := Imp.insert s 200 0 + assert! v == 100 -- sharing: rewind found entry at level 0 + assert! Imp.lookup s == some (100, 0) + return () + +/-! ### Exhaustive verification + +Verify that Spec and Imp produce identical observable results for ALL +operation sequences up to a bounded depth and length. -/ + +partial def verifySpecImp (maxOps maxScope : Nat) : IO Unit := + go maxOps Spec.empty Imp.empty 0 [] +where + go (fuel : Nat) (spec : Spec.State) (imp : Imp.State) (depth : Nat) + (trace : List String) : IO Unit := do + if fuel == 0 then return + -- push + if depth < maxScope then + let spec' := Spec.push spec + let imp' := Imp.push imp + if Spec.lookup spec' != Imp.lookup imp' then + throw <| IO.userError s!"push lookup mismatch after {trace}" + go (fuel - 1) spec' imp' (depth + 1) (trace ++ ["push"]) + -- pop + if depth > 0 then + let spec' := Spec.pop spec + let imp' := Imp.pop imp + if Spec.lookup spec' != Imp.lookup imp' then + throw <| IO.userError s!"pop lookup mismatch after {trace}" + go (fuel - 1) spec' imp' (depth - 1) (trace ++ ["pop"]) + -- inserts + for v in [0, 1, 2] do + for rs in List.range (depth + 1) do + let (spec', sv) := Spec.insert spec v rs + let (imp', iv) := Imp.insert imp v rs + if sv != iv then + throw <| IO.userError s!"insert sharing mismatch after {trace}: insert {v} {rs}" + if Spec.lookup spec' != Imp.lookup imp' then + throw <| IO.userError s!"insert lookup mismatch after {trace}: insert {v} {rs}" + go (fuel - 1) spec' imp' depth (trace ++ [s!"ins({v},{rs})"]) + +#eval do + verifySpecImp 4 3 + return () + +end Tests + +/-! ## Formal Proofs + +We prove that for any sequence of operations starting from empty, +Spec and Imp produce identical lookup results and insert-sharing values. + +### Approach + +We define a simulation invariant `SimInv` relating Spec and Imp states: +the Imp entries, after rewinding, match the Spec entries (modulo generation info). +We prove: +1. `SimInv` holds for empty states +2. Each operation preserves `SimInv` +3. `SimInv` implies lookup equivalence and insert-sharing equivalence + +Key rewind properties are stated as separate lemmas. +-/ +namespace Proofs + +/-! ### Rewind helper lemmas + +These lemmas capture the essential properties of the `rewind` function needed +for the simulation proof. They express how rewind behaves under push, pop, +and when applied to freshly-inserted entries. -/ + +/-- After Imp.insert, the entries are `filtered ++ [newEntry]` where newEntry + has current scope and gens. Rewind at the same scope is identity because + rewind processes from the back and immediately finds the valid last entry. -/ +theorem rewind_after_insert (filtered : List Imp.Entry) (newEntry : Imp.Entry) + (scopeN : Nat) (gens : Imp.GenList) + (hSL : newEntry.scopeLevel = scopeN) + (hGens : newEntry.scopeGens = gens) + (hRS : newEntry.resultScope ≤ scopeN) + (hGensPos : gens.length > 0) : + Imp.rewind (filtered ++ [newEntry]) scopeN gens = filtered ++ [newEntry] := by + subst hSL; subst hGens + simp only [Imp.rewind] + rw [List.reverse_append, List.reverse_cons, List.reverse_nil, List.nil_append, + List.singleton_append] + rw [Imp.rewind.go] + -- resultScope ≤ scopeLevel, so not discarded + simp only [show (newEntry.resultScope > newEntry.scopeLevel) = False from + eq_false (Nat.not_lt.mpr hRS), ↓reduceIte] + -- degradedLevel = scopeLevel, drops are zero + simp only [Nat.min_self, Nat.sub_self, List.drop_zero] + -- findValidLevel with same eGens and cGens: head? values are equal + obtain ⟨g, gl, hgl⟩ := List.exists_cons_of_length_pos hGensPos + conv => lhs; rw [hgl, Imp.findValidLevel]; simp + rw [← hgl] + +/-- When the first gen check fails due to a fresh gen g, findValidLevel at (n+1) + with g::cGens reduces to findValidLevel at n with eGens.tail and cGens. -/ +theorem findValidLevel_skip_fresh (eGens cGens : Imp.GenList) (n g resultScope : Nat) + (hFreshHead : ∀ k < eGens.length, eGens[k]! < g) (hRS : resultScope ≤ n) : + Imp.findValidLevel eGens (g :: cGens) (n + 1) resultScope = + Imp.findValidLevel eGens.tail cGens n resultScope := by + cases eGens with + | nil => + -- eGens empty → both sides none + simp [Imp.findValidLevel] + | cons v vs => + -- eGens = v :: vs, head? = some v, tail = vs + have hvg : v < g := by have := hFreshHead 0 (by simp); simpa using this + rw [Imp.findValidLevel] + simp [show (v == g) = false from by simp [Nat.ne_of_lt hvg], + show ¬(n + 1 ≤ resultScope) from Nat.not_le.mpr (Nat.lt_succ_of_le hRS)] + +/-- Push rewind equivalence: rewind at (n+1, g::gens) equals rewind at (n, gens) + when g is fresh (strictly greater than all gens in all entries). -/ +theorem rewind_push_equiv (entries : List Imp.Entry) (n g : Nat) (gens : Imp.GenList) + (hFresh : ∀ e ∈ entries, ∀ k < e.scopeGens.length, e.scopeGens[k]! < g) + (hRSSL : ∀ e ∈ entries, e.resultScope ≤ e.scopeLevel) : + Imp.rewind entries (n + 1) (g :: gens) = Imp.rewind entries n gens := by + simp only [Imp.rewind] + suffices ∀ revEntries acc, + (∀ e ∈ revEntries, ∀ k < e.scopeGens.length, e.scopeGens[k]! < g) → + (∀ e ∈ revEntries, e.resultScope ≤ e.scopeLevel) → + Imp.rewind.go (n + 1) (g :: gens) revEntries acc = + Imp.rewind.go n gens revEntries acc by + exact this entries.reverse [] (fun e he => hFresh e (List.mem_reverse.mp he)) + (fun e he => hRSSL e (List.mem_reverse.mp he)) + intro revEntries acc hFreshRev hRSRev + induction revEntries generalizing acc with + | nil => simp [Imp.rewind.go] + | cons top rest ih => + have hTF := hFreshRev top (List.mem_cons_self ..) + have hTR := hRSRev top (List.mem_cons_self ..) + have hRF := fun e he => hFreshRev e (List.mem_cons_of_mem _ he) + have hRR := fun e he => hRSRev e (List.mem_cons_of_mem _ he) + simp only [Imp.rewind.go] + -- Helper: degraded gens freshness (all gens in dropped list are still < g) + have hDGFresh : ∀ d, ∀ k < (top.scopeGens.drop d).length, + (top.scopeGens.drop d)[k]! < g := by + intro d k hk + simp only [List.length_drop] at hk + have hkd : k + d < top.scopeGens.length := Nat.add_lt_of_lt_sub hk + have h1 := hTF (k + d) hkd + rw [show top.scopeGens[k + d]! = (top.scopeGens.drop d)[k]! from by + simp [List.getElem!_eq_getElem?_getD, List.getElem?_drop, Nat.add_comm]] at h1 + exact h1 + -- Main case split: first handle LHS (scope n+1), then RHS (scope n) + by_cases hGN1 : top.resultScope > n + 1 + · -- resultScope > n+1: both discard (also > n) + have hGN : top.resultScope > n := Nat.lt_of_le_of_lt (Nat.le_succ n) hGN1 + simp only [hGN1, hGN, ↓reduceIte]; exact ih acc hRF hRR + · simp only [show ¬(top.resultScope > n + 1) from hGN1, ↓reduceIte] + by_cases hGN : top.resultScope > n + · -- resultScope = n+1: discarded at (n), findValidLevel returns none at (n+1) + simp only [hGN, ↓reduceIte] + have hEq : top.resultScope = n + 1 := Nat.le_antisymm (Nat.not_lt.mp hGN1) hGN + have hSL : top.scopeLevel ≥ n + 1 := hEq ▸ hTR + simp only [show min top.scopeLevel (n + 1) = n + 1 from Nat.min_eq_right hSL, + Nat.sub_self, List.drop_zero] + rw [hEq]; cases hDG : top.scopeGens.drop (top.scopeLevel - (n + 1)) with + | nil => simp [Imp.findValidLevel]; exact ih acc hRF hRR + | cons v vs => + rw [Imp.findValidLevel]; simp only [List.head?_cons] + have hvg : v < g := by + have := hDGFresh (top.scopeLevel - (n + 1)) 0 (by simp [hDG]) + simpa [hDG] using this + simp [show (v == g) = false from by simp [Nat.ne_of_lt hvg]] + exact ih acc hRF hRR + · -- resultScope ≤ n: both not discarded + simp only [show ¬(top.resultScope > n) from hGN, ↓reduceIte] + have hRSN : top.resultScope ≤ n := Nat.not_lt.mp hGN + by_cases hSL : top.scopeLevel ≤ n + · -- scopeLevel ≤ n: identical processing + simp only [show min top.scopeLevel (n + 1) = top.scopeLevel from + Nat.min_eq_left (Nat.le_succ_of_le hSL), + show min top.scopeLevel n = top.scopeLevel from Nat.min_eq_left hSL, + Nat.sub_self, List.drop_zero] + have hDrop : List.drop (n + 1 - top.scopeLevel) (g :: gens) = + List.drop (n - top.scopeLevel) gens := by + rw [show n + 1 - top.scopeLevel = (n - top.scopeLevel) + 1 from by grind]; rfl + rw [hDrop] + -- Both sides now match on the same findValidLevel call + split + · rfl -- some case: identical + · exact ih acc hRF hRR -- none case: induction hypothesis + · -- scopeLevel > n: use findValidLevel_skip_fresh + have hSLgt : top.scopeLevel > n := Nat.not_le.mp hSL + simp only [show min top.scopeLevel (n + 1) = n + 1 from + Nat.min_eq_right (Nat.succ_le_of_lt hSLgt), + show min top.scopeLevel n = n from Nat.min_eq_right (Nat.le_of_lt hSLgt), + Nat.sub_self, List.drop_zero] + have hTailEq : (List.drop (top.scopeLevel - (n + 1)) top.scopeGens).tail = + List.drop (top.scopeLevel - n) top.scopeGens := by + rw [List.tail_drop]; congr 1; grind + rw [findValidLevel_skip_fresh _ _ _ _ _ (hDGFresh _) hRSN, hTailEq] + split + · rfl + · exact ih acc hRF hRR + +/-- findValidLevel returns gens that are a suffix (via drop) of the input eGens. -/ +theorem findValidLevel_gens_suffix {eGens cGens : Imp.GenList} {level resultScope lvl : Nat} + {gs : Imp.GenList} + (h : Imp.findValidLevel eGens cGens level resultScope = some (lvl, gs)) : + ∃ d, gs = eGens.drop d := by + induction eGens generalizing level cGens with + | nil => simp [Imp.findValidLevel] at h + | cons v vs ih => + cases cGens with + | nil => simp [Imp.findValidLevel] at h + | cons w ws => + rw [Imp.findValidLevel] at h + simp only [List.head?_cons] at h + split at h + · simp at h; exact ⟨0, by simp [h.2]⟩ + · split at h + · simp at h + · obtain ⟨d, hd⟩ := ih h + exact ⟨d + 1, by rw [hd]; rfl⟩ + +/-- findValidLevel returns a level ≤ the input level. -/ +theorem findValidLevel_lvl_le {eGens cGens : Imp.GenList} {level resultScope lvl : Nat} + {gs : Imp.GenList} + (h : Imp.findValidLevel eGens cGens level resultScope = some (lvl, gs)) : + lvl ≤ level := by + induction eGens generalizing level cGens with + | nil => simp [Imp.findValidLevel] at h + | cons v vs ih => + cases cGens with + | nil => simp [Imp.findValidLevel] at h + | cons w ws => + rw [Imp.findValidLevel] at h + simp only [List.head?_cons] at h + split at h + · simp at h; exact h.1 ▸ Nat.le_refl _ + · split at h + · simp at h + · exact Nat.le_trans (ih h) (Nat.sub_le level 1) + +/-- findValidLevel returns a level ≥ resultScope. -/ +theorem findValidLevel_rs_le {eGens cGens : Imp.GenList} {level resultScope lvl : Nat} + {gs : Imp.GenList} + (hLR : resultScope ≤ level) + (h : Imp.findValidLevel eGens cGens level resultScope = some (lvl, gs)) : + resultScope ≤ lvl := by + induction eGens generalizing level cGens with + | nil => simp [Imp.findValidLevel] at h + | cons v vs ih => + cases cGens with + | nil => simp [Imp.findValidLevel] at h + | cons w ws => + rw [Imp.findValidLevel] at h + simp only [List.head?_cons] at h + split at h + · simp at h; exact h.1 ▸ hLR + · split at h + · simp at h + · exact ih (by grind) h + +/-- When findValidLevel matches at the starting level, it returns the input gens unchanged. -/ +theorem findValidLevel_match_at_level {eGens cGens : Imp.GenList} {level rs : Nat} + {gs : Imp.GenList} + (h : Imp.findValidLevel eGens cGens level rs = some (level, gs)) : + gs = eGens := by + cases eGens with + | nil => simp [Imp.findValidLevel] at h + | cons v vs => + cases cGens with + | nil => simp [Imp.findValidLevel] at h + | cons w ws => + rw [Imp.findValidLevel] at h + simp only [List.head?_cons] at h + split at h + · simp at h; exact h.symm + · split at h + · simp at h + · -- Recursive case: fvl(vs, ws, level-1, rs) returns (level, gs). + -- But findValidLevel_lvl_le gives level ≤ level - 1, contradicting level ≥ 1. + rename_i _ hlrs + have := findValidLevel_lvl_le h + omega + +/-- The output gens of findValidLevel is a drop-suffix of the input eGens. -/ +theorem findValidLevel_output_eq_input_drop + {eGens cGens : Imp.GenList} {level rs lvl : Nat} {gs : Imp.GenList} + (h : Imp.findValidLevel eGens cGens level rs = some (lvl, gs)) : + gs = eGens.drop (level - lvl) := by + induction eGens generalizing cGens level with + | nil => simp [Imp.findValidLevel] at h + | cons v vs ih => + cases cGens with + | nil => simp [Imp.findValidLevel] at h + | cons w ws => + rw [Imp.findValidLevel] at h + simp only [List.head?_cons] at h + split at h + · simp at h; obtain ⟨hlvl, hgs⟩ := h + subst hlvl; subst hgs; simp + · split at h + · simp at h + · have hrec := ih h + have hlvl_lt : lvl < level := by + have := findValidLevel_lvl_le h; omega + rw [hrec, show level - lvl = (level - 1 - lvl) + 1 from by omega, + List.drop_succ_cons] + +/-- When findValidLevel succeeds and gensSuffix holds between the original entry gens + and the current gens list, the output gens are aligned at the correct offset. -/ +theorem findValidLevel_aligned + {eGens cGens : Imp.GenList} {level rs lvl : Nat} {gs : Imp.GenList} + (hfvl : Imp.findValidLevel eGens cGens level rs = some (lvl, gs)) + {origGens gens : Imp.GenList} {d offset : Nat} + (hE : eGens = origGens.drop d) (hC : cGens = gens.drop offset) + (hGensSuffix : ∀ i < origGens.length, ∀ j < gens.length, + origGens[i]! = gens[j]! → origGens.drop i = gens.drop j) : + gs = gens.drop (offset + (level - lvl)) := by + induction eGens generalizing cGens level d offset with + | nil => simp [Imp.findValidLevel] at hfvl + | cons v vs ih => + cases cGens with + | nil => simp [Imp.findValidLevel] at hfvl + | cons w ws => + rw [Imp.findValidLevel] at hfvl + simp only [List.head?_cons] at hfvl + split at hfvl + · -- Match: v == w + rename_i hvw + simp at hfvl + obtain ⟨hlvl, hgs⟩ := hfvl + subst hlvl; subst hgs; simp + -- From hE: v :: vs = origGens.drop d, so origGens[d]! = v + have hd : d < origGens.length := by + have : (origGens.drop d).length > 0 := by rw [← hE]; simp + simp [List.length_drop] at this; omega + have ho : offset < gens.length := by + have : (gens.drop offset).length > 0 := by rw [← hC]; simp + simp [List.length_drop] at this; omega + have hv : origGens[d]! = v := by + have : (origGens.drop d)[0]! = v := by rw [← hE]; rfl + rwa [List.getElem!_eq_getElem?_getD, List.getElem?_drop, Nat.add_zero, + ← List.getElem!_eq_getElem?_getD] at this + have hw : gens[offset]! = w := by + have : (gens.drop offset)[0]! = w := by rw [← hC]; rfl + rwa [List.getElem!_eq_getElem?_getD, List.getElem?_drop, Nat.add_zero, + ← List.getElem!_eq_getElem?_getD] at this + have heq : origGens[d]! = gens[offset]! := by + rw [hv, hw]; exact beq_iff_eq.mp hvw + rw [← hGensSuffix d hd offset ho heq, ← hE] + · split at hfvl + · simp at hfvl + · -- Recurse: fvl vs ws (level - 1) rs + rename_i hne hlrs + have hE' : vs = origGens.drop (d + 1) := by + have h := congrArg List.tail hE + simp only [List.tail_cons, List.tail_drop] at h + exact h + have hC' : ws = gens.drop (offset + 1) := by + have h := congrArg List.tail hC + simp only [List.tail_cons, List.tail_drop] at h + exact h + have hlvlle := findValidLevel_lvl_le hfvl + have hih := ih hfvl hE' hC' + rw [hih]; congr 1 + omega + +/-- Every entry in rewind output either comes from the input unchanged or has its + scopeGens as a suffix (via drop) of some input entry's scopeGens. -/ +theorem rewind_mem {entries : List Imp.Entry} {n : Nat} {gens : Imp.GenList} + {e : Imp.Entry} (he : e ∈ Imp.rewind entries n gens) : + ∃ e' ∈ entries, e.result = e'.result ∧ e.resultScope = e'.resultScope ∧ + (∃ d, e.scopeGens = e'.scopeGens.drop d) := by + simp only [Imp.rewind] at he + suffices ∀ revEntries acc, + (∀ e ∈ revEntries, e ∈ entries) → + (∀ e ∈ acc, ∃ e' ∈ entries, e.result = e'.result ∧ e.resultScope = e'.resultScope ∧ + ∃ d, e.scopeGens = e'.scopeGens.drop d) → + e ∈ Imp.rewind.go n gens revEntries acc → + ∃ e' ∈ entries, e.result = e'.result ∧ e.resultScope = e'.resultScope ∧ + ∃ d, e.scopeGens = e'.scopeGens.drop d by + exact this entries.reverse [] + (fun e he => List.mem_reverse.mp he) (fun _ h => nomatch h) he + intro revEntries acc hRev hAcc + induction revEntries generalizing acc with + | nil => + simp [Imp.rewind.go]; intro hmem; exact hAcc e hmem + | cons top rest ih => + intro he' + have hTopMem := hRev top (List.mem_cons_self ..) + have hRestMem := fun e he => hRev e (List.mem_cons_of_mem _ he) + simp only [Imp.rewind.go] at he' + split at he' + · exact ih acc hRestMem hAcc he' + · split at he' + · -- Found valid: rest.reverse ++ [modified_top] ++ acc + rename_i _ lvl eGens hfvl + rcases List.mem_append.mp he' with h1 | h2 + · rcases List.mem_append.mp h1 with hr | hm + · exact ⟨e, hRev e (List.mem_cons_of_mem _ (List.mem_reverse.mp hr)), + rfl, rfl, ⟨0, by simp⟩⟩ + · simp only [List.mem_cons, List.mem_nil_iff, or_false] at hm + subst hm + obtain ⟨d, hd⟩ := findValidLevel_gens_suffix hfvl + exact ⟨top, hTopMem, rfl, rfl, + ⟨top.scopeLevel - min top.scopeLevel n + d, by rw [hd, List.drop_drop]⟩⟩ + · exact hAcc e h2 + · exact ih acc hRestMem hAcc he' + +/-- Entries in rewind output have their gen values bounded by the input entries' gen values. -/ +theorem rewind_genBound {entries : List Imp.Entry} {n : Nat} {gens : Imp.GenList} + {bound : Nat} + (hBound : ∀ e ∈ entries, ∀ k < e.scopeGens.length, e.scopeGens[k]! ≤ bound) + {e : Imp.Entry} (he : e ∈ Imp.rewind entries n gens) : + ∀ k < e.scopeGens.length, e.scopeGens[k]! ≤ bound := by + obtain ⟨e', he', _, _, ⟨d, hd⟩⟩ := rewind_mem he + intro k hk; rw [hd] at hk ⊢ + have hkd : k + d < e'.scopeGens.length := by + simp [List.length_drop] at hk; grind + have h1 := hBound e' he' (k + d) hkd + rwa [show e'.scopeGens[k + d]! = (e'.scopeGens.drop d)[k]! from by + simp [List.getElem!_eq_getElem?_getD, List.getElem?_drop, Nat.add_comm]] at h1 + +/-- Entries in rewind output have resultScope ≤ scopeLevel. -/ +theorem rewind_rsLeSl {entries : List Imp.Entry} {n : Nat} {gens : Imp.GenList} + (hRSSL : ∀ e ∈ entries, e.resultScope ≤ e.scopeLevel) + {e : Imp.Entry} (he : e ∈ Imp.rewind entries n gens) : + e.resultScope ≤ e.scopeLevel := by + simp only [Imp.rewind] at he + suffices ∀ revEntries acc, + (∀ e ∈ revEntries, e.resultScope ≤ e.scopeLevel) → + (∀ e ∈ acc, e.resultScope ≤ e.scopeLevel) → + e ∈ Imp.rewind.go n gens revEntries acc → e.resultScope ≤ e.scopeLevel by + exact this entries.reverse [] (fun e he => hRSSL e (List.mem_reverse.mp he)) + (fun _ h => nomatch h) he + intro revEntries acc hRev hAcc + induction revEntries generalizing acc with + | nil => simp [Imp.rewind.go]; intro hmem; exact hAcc e hmem + | cons top rest ih => + intro he' + have hRestRev := fun e he => hRev e (List.mem_cons_of_mem _ he) + simp only [Imp.rewind.go] at he' + split at he' + · exact ih acc hRestRev hAcc he' + · split at he' + · rename_i _ lvl eGens hfvl + rcases List.mem_append.mp he' with h1 | h2 + · rcases List.mem_append.mp h1 with hr | hm + · exact hRev _ (List.mem_cons_of_mem _ (List.mem_reverse.mp hr)) + · simp only [List.mem_cons, List.mem_nil_iff, or_false] at hm + subst hm; simp only + exact findValidLevel_rs_le (by simp [Nat.min_def]; split <;> grind) hfvl + · exact hAcc e h2 + · exact ih acc hRestRev hAcc he' + +/-- Entries are well-ordered: each entry's scopeLevel < every later entry's resultScope. + This holds for entries produced by Imp.insert (which filters scopeLevel < rs). -/ +def Imp.EntryOrdered (entries : List Imp.Entry) : Prop := + entries.Pairwise (fun a b => a.scopeLevel < b.resultScope) + +/-- Rewind preserves entry ordering. -/ +theorem rewind_entryOrdered {entries : List Imp.Entry} {n : Nat} {gens : Imp.GenList} + (hOrd : Imp.EntryOrdered entries) : + Imp.EntryOrdered (Imp.rewind entries n gens) := by + simp only [Imp.rewind, Imp.EntryOrdered] + suffices ∀ revEntries acc, + (revEntries.reverse ++ acc).Pairwise (fun a b => a.scopeLevel < b.resultScope) → + acc.Pairwise (fun a b => a.scopeLevel < b.resultScope) → + (Imp.rewind.go n gens revEntries acc).Pairwise + (fun a b => a.scopeLevel < b.resultScope) by + exact this entries.reverse [] + (by simpa using hOrd) List.Pairwise.nil + intro revEntries acc hAll hAcc + induction revEntries generalizing acc with + | nil => simpa [Imp.rewind.go] + | cons top rest ih => + simp only [Imp.rewind.go] + -- Extract ordering facts from hAll + rw [List.reverse_cons, List.append_assoc, List.singleton_append] at hAll + have hP := List.pairwise_append.mp hAll + have hOrdPre := hP.1 + have hOrdTA := hP.2.1 + have hCross := hP.2.2 + have hTopAcc := List.pairwise_cons.mp hOrdTA + split + · -- top.resultScope > n: skip top + apply ih acc + · rw [List.pairwise_append] + exact ⟨hOrdPre, hTopAcc.2, + fun a ha b hb => hCross a ha b (List.mem_cons_of_mem _ hb)⟩ + · exact hAcc + · -- top.resultScope ≤ n + split + · -- findValidLevel returns some (lvl, eGens) + rename_i _ lvl eGens hfvl + -- Result is rest.reverse ++ [modified_top] ++ acc + -- modified_top has same resultScope, scopeLevel ≤ top.scopeLevel + have hLvlLe : lvl ≤ min top.scopeLevel n := + findValidLevel_lvl_le hfvl + have hLvlLeSL : lvl ≤ top.scopeLevel := + Nat.le_trans hLvlLe (Nat.min_le_left ..) + rw [List.append_assoc, List.pairwise_append] + refine ⟨hOrdPre, ?_, ?_⟩ + · -- Pairwise for [modified_top] ++ acc + rw [List.singleton_append, List.pairwise_cons] + exact ⟨fun b hb => Nat.lt_of_le_of_lt hLvlLeSL (hTopAcc.1 b hb), hTopAcc.2⟩ + · -- ∀ a ∈ rest.reverse, ∀ b ∈ [modified_top] ++ acc, a.sl < b.rs + intro a ha b hb + rw [List.singleton_append, List.mem_cons] at hb + cases hb with + | inl heq => + subst heq; simp only + exact hCross a ha top (List.mem_cons_self ..) + | inr hb' => exact hCross a ha b (List.mem_cons_of_mem _ hb') + · -- findValidLevel returns none: skip top + apply ih acc + · rw [List.pairwise_append] + exact ⟨hOrdPre, hTopAcc.2, + fun a ha b hb => hCross a ha b (List.mem_cons_of_mem _ hb)⟩ + · exact hAcc + + +/-- When all entries in revEntries are trivially valid at the given scope + (findValidLevel matches immediately at their scopeLevel), rewind.go is identity. -/ +theorem rewind_go_all_valid (revEntries acc : List Imp.Entry) + (scope : Nat) (gens : Imp.GenList) + (hRS : ∀ e ∈ revEntries, e.resultScope ≤ scope) + (hSL : ∀ e ∈ revEntries, e.scopeLevel ≤ scope) + (hRSSL : ∀ e ∈ revEntries, e.resultScope ≤ e.scopeLevel) + (hValid : ∀ e ∈ revEntries, + Imp.findValidLevel e.scopeGens (gens.drop (scope - e.scopeLevel)) + e.scopeLevel e.resultScope = some (e.scopeLevel, e.scopeGens)) : + Imp.rewind.go scope gens revEntries acc = revEntries.reverse ++ acc := by + induction revEntries generalizing acc with + | nil => simp [Imp.rewind.go] + | cons top rest ih => + simp only [Imp.rewind.go] + have hTopRS := hRS top (List.mem_cons_self ..) + have hTopSL := hSL top (List.mem_cons_self ..) + have hTopRSSL := hRSSL top (List.mem_cons_self ..) + have hTopValid := hValid top (List.mem_cons_self ..) + simp only [show ¬(top.resultScope > scope) from Nat.not_lt.mpr hTopRS, ↓reduceIte] + simp only [show min top.scopeLevel scope = top.scopeLevel from Nat.min_eq_left hTopSL, + Nat.sub_self, List.drop_zero] at hTopValid ⊢ + rw [hTopValid] + simp [List.reverse_cons] + +/-! ### Simulation invariant -/ + +/-- The simulation invariant relates Spec and Imp states. -/ +structure SimInv (spec : Spec.State) (imp : Imp.State) : Prop where + scopeEq : spec.scopeN = imp.scopeN + entriesEq : spec.entries = + (Imp.rewind imp.entries imp.scopeN imp.currentGens).map Imp.Entry.toSpec + gensLen : imp.currentGens.length = imp.scopeN + 1 + -- All generation values in entries and currentGens are ≤ genCounter + genBound : ∀ e ∈ imp.entries, ∀ k < e.scopeGens.length, e.scopeGens[k]! ≤ imp.genCounter + curGenBound : ∀ k < imp.currentGens.length, imp.currentGens[k]! ≤ imp.genCounter + -- Entries have resultScope ≤ scopeLevel (from insert semantics) + rsLeSl : ∀ e ∈ imp.entries, e.resultScope ≤ e.scopeLevel + -- Entries are well-ordered: each entry's scopeLevel < next entry's resultScope + entryOrdered : Imp.EntryOrdered imp.entries + -- Current gens are all distinct (each push creates a unique gen value) + gensNodup : imp.currentGens.Nodup + -- Gen values act as unique scope identifiers: matching gens imply matching suffixes + gensSuffix : ∀ e ∈ imp.entries, ∀ i < e.scopeGens.length, + ∀ j < imp.currentGens.length, e.scopeGens[i]! = imp.currentGens[j]! → + e.scopeGens.drop i = imp.currentGens.drop j + +/-- SimInv holds for empty states. -/ +theorem simInv_empty : SimInv Spec.empty Imp.empty where + scopeEq := by simp [Spec.empty, Imp.empty] + entriesEq := by simp [Spec.empty, Imp.empty, Imp.rewind, Imp.rewind.go] + gensLen := by simp [Imp.empty] + genBound := by simp [Imp.empty] + curGenBound := by simp [Imp.empty] + rsLeSl := by simp [Imp.empty] + entryOrdered := List.Pairwise.nil + gensNodup := by simp [Imp.empty, List.Nodup] + gensSuffix := by simp [Imp.empty] + +/-! ### Observational equivalence from SimInv -/ + +/-- Lookup equivalence given SimInv. -/ +theorem lookup_equiv (h : SimInv spec imp) : + Spec.lookup spec = Imp.lookup imp := by + simp only [Spec.lookup, Imp.lookup] + rw [h.entriesEq, h.scopeEq] + cases hrw : (Imp.rewind imp.entries imp.scopeN imp.currentGens).getLast? with + | none => simp [List.getLast?_map, hrw] + | some e => simp [List.getLast?_map, hrw, Imp.Entry.toSpec] + +/-- Insert-sharing equivalence given SimInv. -/ +theorem insert_sharing_equiv (h : SimInv spec imp) : + (Spec.insert spec v rs).2 = (Imp.insert imp v rs).2 := by + simp only [Spec.insert, Imp.insert] + rw [h.entriesEq] + cases hrw : (Imp.rewind imp.entries imp.scopeN imp.currentGens).getLast? with + | none => simp [List.getLast?_map, hrw] + | some e => simp [List.getLast?_map, hrw, Imp.Entry.toSpec] + +/-! ### SimInv preservation -/ + +/-- Push preserves SimInv. -/ +theorem push_simInv (h : SimInv spec imp) : + SimInv (Spec.push spec) (Imp.push imp) where + scopeEq := by simp [Spec.push, Imp.push, h.scopeEq] + entriesEq := by + simp only [Spec.push, Imp.push] + have hFresh : ∀ e ∈ imp.entries, + ∀ k < e.scopeGens.length, e.scopeGens[k]! < imp.genCounter + 1 := by + intro e he k hk + exact Nat.lt_succ_of_le (h.genBound e he k hk) + rw [rewind_push_equiv imp.entries imp.scopeN + (imp.genCounter + 1) imp.currentGens hFresh h.rsLeSl] + exact h.entriesEq + gensLen := by simp [Imp.push, h.gensLen] + genBound := by + intro e he k hk + simp only [Imp.push] at he + exact Nat.le_succ_of_le (h.genBound e he k hk) + curGenBound := by + intro k hk + simp only [Imp.push] at hk ⊢ + cases k with + | zero => simp + | succ k => + simp only [List.getElem!_cons_succ] + exact Nat.le_succ_of_le (h.curGenBound k (by + simp only [List.length_cons] at hk; exact Nat.lt_of_succ_lt_succ hk)) + rsLeSl := by + intro e he; simp only [Imp.push] at he; exact h.rsLeSl e he + entryOrdered := by simp only [Imp.push]; exact h.entryOrdered + gensNodup := by + simp only [Imp.push, List.nodup_cons] + refine ⟨?_, h.gensNodup⟩ + intro hmem + have ⟨k, hk, heq⟩ := List.mem_iff_getElem.mp hmem + have hle := h.curGenBound k hk + simp only [List.getElem!_eq_getElem?_getD, List.getElem?_eq_getElem hk, + Option.getD_some] at hle + -- hle : currentGens[k] ≤ genCounter, heq : currentGens[k] = genCounter + 1 + grind + gensSuffix := by + intro e he i hi j hj heq + simp only [Imp.push] at he hj heq ⊢ + cases j with + | zero => + -- gens[0] is the fresh gen (genCounter + 1), can't match any entry gen + simp only [List.getElem!_cons_zero] at heq + exact absurd heq.symm (Nat.ne_of_gt (Nat.lt_succ_of_le (h.genBound e he i hi))) + | succ j => + simp only [List.getElem!_cons_succ] at heq + exact h.gensSuffix e he i hi j (by + simp only [List.length_cons] at hj; exact Nat.lt_of_succ_lt_succ hj) heq + +/-- Key lemma: an entry with scopeGens aligned to gens (via suffix property) is + trivially valid at its scopeLevel. -/ +theorem fvl_self_match (e : Imp.Entry) (gens : Imp.GenList) (scope : Nat) + (_hSL : e.scopeLevel ≤ scope) + (hSuffix : e.scopeGens = gens.drop (scope - e.scopeLevel)) + (hLen : e.scopeGens.length > 0) : + Imp.findValidLevel e.scopeGens (gens.drop (scope - e.scopeLevel)) + e.scopeLevel e.resultScope = some (e.scopeLevel, e.scopeGens) := by + obtain ⟨g, gs, hgs⟩ := List.exists_cons_of_length_pos hLen + have hgs' : gens.drop (scope - e.scopeLevel) = g :: gs := hSuffix ▸ hgs + rw [hgs, hgs', Imp.findValidLevel] + simp + +/-- For entries with sl ≤ n-1, the fvl calls at scope n and scope n-1 are identical + because gens.drop(n - sl) = gens.tail.drop(n-1 - sl). -/ +theorem fvl_drop_shift (gens : Imp.GenList) (n sl : Nat) (hSL : sl ≤ n - 1) (hn : n > 0) : + gens.drop (n - sl) = gens.tail.drop (n - 1 - sl) := by + have h1 : n - sl = (n - 1 - sl) + 1 := by grind + rw [h1, List.drop_tail] + +/-- When fvl succeeds at scope n with level = n (first gens match), + and dg = gens (from gensSuffix), fvl at scope n-1 matches at n-1. -/ +theorem fvl_self_tail (gens : Imp.GenList) (n rs : Nat) (hn : n > 0) + (_hRS : rs ≤ n - 1) (hGensLen : gens.length = n + 1) : + Imp.findValidLevel gens.tail gens.tail (n - 1) rs = some (n - 1, gens.tail) := by + have hTailLen : gens.tail.length = n := by simp [List.length_tail, hGensLen] + obtain ⟨g, gs, hgs⟩ := List.exists_cons_of_length_pos (by rw [hTailLen]; exact hn : gens.tail.length > 0) + rw [hgs, Imp.findValidLevel]; simp + +/-- When fvl at scope n succeeds for an entry, fvl at scope n-1 produces + the degraded version. Takes fvl success as a premise. -/ +theorem fvl_pop_entry (e : Imp.Entry) (n : Nat) (gens : Imp.GenList) + (hn : n > 0) (hRS : e.resultScope ≤ n - 1) + (hGensSuffix : ∀ i < e.scopeGens.length, ∀ j < gens.length, + e.scopeGens[i]! = gens[j]! → e.scopeGens.drop i = gens.drop j) + (hGensLen : gens.length = n + 1) + (lvl : Nat) (eGens : Imp.GenList) + (hfvl : Imp.findValidLevel (e.scopeGens.drop (e.scopeLevel - min e.scopeLevel n)) + (gens.drop (n - min e.scopeLevel n)) (min e.scopeLevel n) e.resultScope = + some (lvl, eGens)) : + Imp.findValidLevel (e.scopeGens.drop (e.scopeLevel - min e.scopeLevel (n - 1))) + (gens.tail.drop ((n - 1) - min e.scopeLevel (n - 1))) + (min e.scopeLevel (n - 1)) e.resultScope = + some (min lvl (n - 1), + if lvl ≤ n - 1 then eGens + else eGens.tail) := by + by_cases hSL : e.scopeLevel ≤ n - 1 + · -- Case: sl ≤ n-1. Both fvl calls are identical. + have hSLn : e.scopeLevel ≤ n := Nat.le_trans hSL (Nat.pred_le n) + rw [Nat.min_eq_left hSLn, Nat.sub_self, List.drop_zero] at hfvl + rw [Nat.min_eq_left hSL, Nat.sub_self, List.drop_zero, + ← fvl_drop_shift gens n e.scopeLevel hSL hn] + have hlvl_le : lvl ≤ n - 1 := Nat.le_trans (findValidLevel_lvl_le hfvl) hSL + simp [Nat.min_eq_left hlvl_le, hlvl_le, hfvl] + · -- Case: sl ≥ n. Unfold one step of fvl at scope n. + have hSLge : e.scopeLevel ≥ n := by omega + rw [Nat.min_eq_right hSLge, show n - n = 0 from Nat.sub_self n, + List.drop_zero] at hfvl + rw [Nat.min_eq_right (by omega : e.scopeLevel ≥ n - 1), + show (n - 1) - (n - 1) = 0 from Nat.sub_self (n - 1), List.drop_zero] + -- Relate the two degraded gens lists + have hdrop_shift : e.scopeGens.drop (e.scopeLevel - (n - 1)) = + (e.scopeGens.drop (e.scopeLevel - n)).tail := by + rw [List.tail_drop]; congr 1; omega + rw [hdrop_shift] + -- dg := e.scopeGens.drop(sl - n) must be non-empty (fvl succeeded on it) + obtain ⟨v, vs, hvvs⟩ : ∃ v vs, e.scopeGens.drop (e.scopeLevel - n) = v :: vs := by + cases hdg : e.scopeGens.drop (e.scopeLevel - n) with + | nil => simp [hdg, Imp.findValidLevel] at hfvl + | cons v vs => exact ⟨v, vs, rfl⟩ + -- gens must be non-empty + obtain ⟨w, ws, hwws⟩ : ∃ w ws, gens = w :: ws := by + cases gens with + | nil => simp at hGensLen + | cons w ws => exact ⟨w, ws, rfl⟩ + rw [hvvs, hwws] at hfvl ⊢ + simp only [List.tail_cons] + -- Unfold fvl at scope n: fvl(v :: vs, w :: ws, n, rs) + rw [Imp.findValidLevel] at hfvl + simp only [List.head?_cons] at hfvl + split at hfvl + · -- Sub-case: v == w (match at level n) + rename_i heq + simp at hfvl + obtain ⟨hlvl_eq, hout_eq⟩ := hfvl + subst hlvl_eq; subst hout_eq + -- lvl = n, eGens = v :: vs + simp only [show ¬(n ≤ n - 1) from by omega, ↓reduceIte, + show min n (n - 1) = n - 1 from Nat.min_eq_right (Nat.pred_le n)] + -- By gensSuffix: e.scopeGens.drop(sl-n) = gens, so v :: vs = w :: ws + have hveq : v = w := by simpa using heq + have hi : e.scopeLevel - n < e.scopeGens.length := by + have : (e.scopeGens.drop (e.scopeLevel - n)).length > 0 := by rw [hvvs]; simp + simp [List.length_drop] at this; omega + have hval : e.scopeGens[e.scopeLevel - n]! = gens[0]! := by + have h1 : (e.scopeGens.drop (e.scopeLevel - n))[0]! = v := by + rw [hvvs]; simp + rw [show (e.scopeGens.drop (e.scopeLevel - n))[0]! = e.scopeGens[e.scopeLevel - n]! from by + simp [List.getElem!_eq_getElem?_getD, List.getElem?_drop]] at h1 + have h2 : gens[0]! = w := by rw [hwws]; simp + rw [h1, h2, hveq] + have hdg_gens : e.scopeGens.drop (e.scopeLevel - n) = gens := + hGensSuffix (e.scopeLevel - n) hi 0 (by rw [hGensLen]; omega) hval + -- Goal: fvl vs ws (n-1) rs = some(n-1, (v :: vs).tail) + -- Since drop(sl-n) = gens, and drop(sl-n) = v :: vs, and gens = w :: ws: + -- v :: vs = w :: ws, so vs = ws and v = w + -- Also (v :: vs).tail = vs = ws = gens.tail + have hvs_ws : vs = ws := by + have h := hdg_gens; rw [hvvs, hwws] at h; exact List.cons.inj h |>.2 + simp only [List.tail_cons, hvs_ws] + -- Goal: fvl ws ws (n-1) rs = some(n-1, ws) + have hws_eq : ws = gens.tail := by rw [hwws]; rfl + rw [hws_eq] + exact fvl_self_tail gens n e.resultScope hn hRS hGensLen + · -- Sub-case: v ≠ w + rename_i hne + split at hfvl + · -- n ≤ rs: impossible since rs ≤ n-1 + omega + · -- n > rs: fvl recurses to fvl(vs, ws, n-1, rs) = some(lvl, eGens) + -- This IS the scope n-1 call, and lvl ≤ n-1 + simp only [List.tail_cons] at hfvl + have hlvl_le : lvl ≤ n - 1 := findValidLevel_lvl_le hfvl + simp [Nat.min_eq_left hlvl_le, hlvl_le, hfvl] + +/-- If fvl at scope n returns none, fvl at scope n-1 also returns none. -/ +theorem fvl_pop_none (e : Imp.Entry) (n : Nat) (gens : Imp.GenList) + (hn : n > 0) (hRS : e.resultScope ≤ n - 1) + (hfvl : Imp.findValidLevel (e.scopeGens.drop (e.scopeLevel - min e.scopeLevel n)) + (gens.drop (n - min e.scopeLevel n)) (min e.scopeLevel n) e.resultScope = none) : + Imp.findValidLevel (e.scopeGens.drop (e.scopeLevel - min e.scopeLevel (n - 1))) + (gens.tail.drop ((n - 1) - min e.scopeLevel (n - 1))) + (min e.scopeLevel (n - 1)) e.resultScope = none := by + by_cases hSL : e.scopeLevel ≤ n - 1 + · -- sl ≤ n-1: both calls are identical + have hSLn : e.scopeLevel ≤ n := Nat.le_trans hSL (Nat.pred_le n) + rw [Nat.min_eq_left hSLn, Nat.sub_self, List.drop_zero] at hfvl + rw [Nat.min_eq_left hSL, Nat.sub_self, List.drop_zero, + ← fvl_drop_shift gens n e.scopeLevel hSL hn] + exact hfvl + · -- sl ≥ n: unfold one step + have hSLge : e.scopeLevel ≥ n := by omega + rw [Nat.min_eq_right hSLge, show n - n = 0 from Nat.sub_self n, + List.drop_zero] at hfvl + rw [Nat.min_eq_right (by omega : e.scopeLevel ≥ n - 1), + show (n - 1) - (n - 1) = 0 from Nat.sub_self (n - 1), List.drop_zero] + have hdrop_shift : e.scopeGens.drop (e.scopeLevel - (n - 1)) = + (e.scopeGens.drop (e.scopeLevel - n)).tail := by + rw [List.tail_drop]; congr 1; omega + rw [hdrop_shift] + cases hdg : e.scopeGens.drop (e.scopeLevel - n) with + | nil => simp [Imp.findValidLevel] + | cons v vs => + simp only [List.tail_cons] + -- Goal: fvl(vs, gens.tail, n-1, rs) = none + -- hfvl: fvl(dg, gens, n, rs) = none where dg = v :: vs + rw [hdg] at hfvl + cases hg : gens with + | nil => + -- fvl with empty cGens always returns none + simp only [List.tail_nil] + cases vs with + | nil => rw [Imp.findValidLevel]; rfl + | cons a as => rw [Imp.findValidLevel]; simp + | cons w ws => + simp only [List.tail_cons] + rw [hg] at hfvl; rw [Imp.findValidLevel] at hfvl + simp only [List.head?_cons] at hfvl + split at hfvl + · -- v == w: fvl returns some, contradicts hfvl = none + exact absurd hfvl (by simp) + · split at hfvl + · -- n ≤ rs: impossible since rs ≤ n-1 and n > 0 + omega + · -- Recursive: fvl(vs, ws, n-1, rs) = none + simp only [List.tail_cons] at hfvl + exact hfvl + +/-- When all entries have rs ≤ k and sl ≤ k, filter+degrade is the identity. -/ +private theorem filter_degrade_spec_id (l : List Spec.Entry) (k : Nat) + (hRS : ∀ e ∈ l, e.resultScope ≤ k) + (hSL : ∀ e ∈ l, e.scopeLevel ≤ k) : + List.map (fun e => if e.scopeLevel > k then + { result := e.result, scopeLevel := k, resultScope := e.resultScope } else e) + (l.filter (fun e => decide (e.resultScope ≤ k))) = l := by + induction l with + | nil => simp + | cons hd tl ih => + simp only [List.filter_cons, + show decide (hd.resultScope ≤ k) = true from decide_eq_true_eq.mpr (hRS hd (List.mem_cons_self ..)), + ↓reduceIte, List.map_cons, + show ¬(hd.scopeLevel > k) from Nat.not_lt.mpr (hSL hd (List.mem_cons_self ..)), + ↓reduceIte] + exact congrArg (hd :: ·) (ih (fun e he => hRS e (List.mem_cons_of_mem _ he)) + (fun e he => hSL e (List.mem_cons_of_mem _ he))) + +/-- Entries' gen lists are mutually consistent: earlier entries' gens are drop-suffixes + of later entries' gens. This is preserved trivially by push/pop (which don't touch entries) + and established by insert (where all surviving entries are aligned with currentGens). -/ +private def GensConsistent (entries : List Imp.Entry) : Prop := + entries.Pairwise (fun a b => a.scopeGens = b.scopeGens.drop (b.scopeLevel - a.scopeLevel)) + +/-- Pop rewind equivalence: rewind at (n-1, gens.tail) relates to + filter+degrade of rewind at (n, gens). -/ +theorem rewind_pop_equiv (entries : List Imp.Entry) (n : Nat) (gens : Imp.GenList) + (hn : n > 0) + (hGensLen : gens.length = n + 1) + (hRSSL : ∀ e ∈ entries, e.resultScope ≤ e.scopeLevel) + (hOrd : Imp.EntryOrdered entries) + (hGensSuffix : ∀ e ∈ entries, ∀ i < e.scopeGens.length, + ∀ j < gens.length, e.scopeGens[i]! = gens[j]! → + e.scopeGens.drop i = gens.drop j) + (hGC : GensConsistent entries) : + (Imp.rewind entries (n - 1) gens.tail).map Imp.Entry.toSpec = + List.map + (fun e => if e.scopeLevel > n - 1 then + { e with scopeLevel := n - 1 : Spec.Entry } else e) + ((Imp.rewind entries n gens).map Imp.Entry.toSpec |>.filter + fun e => decide (e.resultScope ≤ n - 1)) := by + simp only [Imp.rewind] + -- GensConsistent is a Pairwise property on entries. In the suffices, we track + -- GensConsistent for revEntries.reverse (sub-list of entries) plus EntryOrdered. + -- In the early-stop case, we derive alignment for rest entries from GensConsistent + -- (inter-entry gens relationship) + gensSuffix (entry-to-currentGens relationship). + suffices ∀ revEntries acc1 acc2, + (∀ e ∈ revEntries, e ∈ entries) → + (∀ e ∈ revEntries, e.resultScope ≤ e.scopeLevel) → + (revEntries.reverse ++ acc1).Pairwise (fun a b => a.scopeLevel < b.resultScope) → + (∀ e ∈ revEntries, ∀ i < e.scopeGens.length, + ∀ j < gens.length, e.scopeGens[i]! = gens[j]! → + e.scopeGens.drop i = gens.drop j) → + revEntries.reverse.Pairwise + (fun a b => a.scopeGens = b.scopeGens.drop (b.scopeLevel - a.scopeLevel)) → + acc2.map Imp.Entry.toSpec = + List.map (fun e => if e.scopeLevel > n - 1 then + { e with scopeLevel := n - 1 : Spec.Entry } else e) + (acc1.map Imp.Entry.toSpec |>.filter fun e => decide (e.resultScope ≤ n - 1)) → + (Imp.rewind.go (n - 1) gens.tail revEntries acc2).map Imp.Entry.toSpec = + List.map (fun e => if e.scopeLevel > n - 1 then + { e with scopeLevel := n - 1 : Spec.Entry } else e) + ((Imp.rewind.go n gens revEntries acc1).map Imp.Entry.toSpec |>.filter + fun e => decide (e.resultScope ≤ n - 1)) by + exact this entries.reverse [] [] + (fun e he => List.mem_reverse.mp he) + (fun e he => hRSSL e (List.mem_reverse.mp he)) + (by simpa using hOrd) + (fun e he => hGensSuffix e (List.mem_reverse.mp he)) + (by simpa using hGC) + (by simp) + intro revEntries acc1 acc2 hMem hRSSLrev hOrdRev hGSrev hGCrev hAccEq + induction revEntries generalizing acc1 acc2 with + | nil => simp [Imp.rewind.go]; exact hAccEq + | cons top rest ih => + have hTopMem := hMem top (List.mem_cons_self ..) + have hTopRSSL := hRSSLrev top (List.mem_cons_self ..) + have hRestMem := fun e he => hMem e (List.mem_cons_of_mem _ he) + have hRestRSSL := fun e he => hRSSLrev e (List.mem_cons_of_mem _ he) + have hRestGS := fun e he => hGSrev e (List.mem_cons_of_mem _ he) + have hTopGS := hGSrev top (List.mem_cons_self ..) + -- GensConsistent for rest.reverse: sub-list of (top :: rest).reverse = rest.reverse ++ [top] + rw [List.reverse_cons] at hGCrev + have hRestGC := (List.pairwise_append.mp hGCrev).1 + -- Cross from GensConsistent: ∀ e ∈ rest, e.gens = top.gens.drop(top.sl - e.sl) + have hGCcross := (List.pairwise_append.mp hGCrev).2.2 + -- Extract ordering + rw [List.reverse_cons, List.append_assoc, List.singleton_append] at hOrdRev + have hP := List.pairwise_append.mp hOrdRev + have hOrdPre := hP.1 + have hOrdTA := hP.2.1 + have hCross := hP.2.2 + have hTopAcc := List.pairwise_cons.mp hOrdTA + simp only [Imp.rewind.go] + -- Case 1: top.resultScope > n + by_cases hGN : top.resultScope > n + · -- Both scopes skip (rs > n implies rs > n-1) + have hGN1 : top.resultScope > n - 1 := by omega + simp only [hGN, hGN1, ↓reduceIte] + exact ih acc1 acc2 hRestMem hRestRSSL + (by rw [List.pairwise_append]; exact ⟨hOrdPre, hTopAcc.2, + fun a ha b hb => hCross a ha b (List.mem_cons_of_mem _ hb)⟩) + hRestGS hRestGC hAccEq + · simp only [show ¬(top.resultScope > n) from hGN, ↓reduceIte] + -- Case 2: top.resultScope = n (rs > n-1 but rs ≤ n) + by_cases hGN1 : top.resultScope > n - 1 + · -- Scope n processes, scope n-1 skips + simp only [hGN1, ↓reduceIte] + -- At scope n, fvl is called + split + · -- fvl at scope n succeeds: result has rs = n which gets filtered + rename_i lvl eGens hfvl_n + -- The entry has rs > n-1, so it gets filtered out + have hRSn : top.resultScope = n := by omega + -- rest.reverse ++ [modified_top] ++ acc1 + -- After map toSpec and filter (rs ≤ n-1): modified_top has rs = n, filtered out + -- Prefix entries (rest.reverse) have scopeLevel < top.resultScope = n + -- so scopeLevel ≤ n-1, so their toSpec entries don't get capped + -- modified_top has rs = n, so it gets filtered out on RHS. + -- All entries in rest have sl < top.rs = n (by EntryOrdered/hCross), + -- and rs ≤ sl < n (by hRSSL). Use IH. + -- Key: need acc equivalence for the new accumulators. + -- New acc1' = modified_top :: acc1. + -- modified_top.toSpec.resultScope = n > n-1, so filtered out. + -- modified_top entry + have hNewAccEq : acc2.map Imp.Entry.toSpec = + List.map (fun e => if e.scopeLevel > n - 1 then + { e with scopeLevel := n - 1 : Spec.Entry } else e) + (({ result := top.result, scopeLevel := lvl, + scopeGens := eGens, resultScope := top.resultScope } :: + acc1).map Imp.Entry.toSpec |>.filter + fun e => decide (e.resultScope ≤ n - 1)) := by + simp only [List.map_cons, List.filter_cons, Imp.Entry.toSpec, hRSn] + have : ¬(n ≤ n - 1) := by omega + simp only [this, decide_false] + exact hAccEq + -- Ordering: entries in rest have sl < top.rs = n. + -- mt.resultScope = n, and lvl ≤ n (from fvl_lvl_le). So sl_rest < n = mt.rs. + -- Ordering: entries in rest have sl < top.rs = n = mt.rs + have hOrdNew : (rest.reverse ++ + { result := top.result, scopeLevel := lvl, + scopeGens := eGens, resultScope := top.resultScope } :: acc1).Pairwise + (fun a b => a.scopeLevel < b.resultScope) := by + rw [List.pairwise_append]; refine ⟨hOrdPre, ?_, ?_⟩ + · have hlvl : lvl ≤ top.scopeLevel := + Nat.le_trans (findValidLevel_lvl_le hfvl_n) (Nat.min_le_left ..) + exact List.pairwise_cons.mpr + ⟨fun a ha => by + exact Nat.lt_of_le_of_lt hlvl (hTopAcc.1 a ha), + hTopAcc.2⟩ + · intro a ha b hb + rcases List.mem_cons.mp hb with rfl | hb' + · simp only; exact hRSn ▸ hCross a ha top (List.mem_cons_self ..) + · exact hCross a ha b (List.mem_cons_of_mem _ hb') + -- The IH gives rewind.go(n, gens, rest, mt::acc1) on the RHS. + -- The goal has rest.reverse ++ [mt] ++ acc1 (early stop result). + -- Show they're equal via rewind_go_all_valid. + -- Derive alignment for rest entries from GensConsistent + top's fvl + gensSuffix + -- Step 1: top.gens.drop(top.sl - n) = gens + have hTopSLgeN : n ≤ top.scopeLevel := by + have := hTopRSSL; rw [hRSn] at this; exact this + have hMinEq : min top.scopeLevel n = n := Nat.min_eq_right hTopSLgeN + -- fvl at (n, n) can only match at level n + have hLvlN : lvl = n := by + have h1 := findValidLevel_lvl_le hfvl_n + have h2 := findValidLevel_rs_le + (show top.resultScope ≤ min top.scopeLevel n by + simp [hRSn, hMinEq]) hfvl_n + rw [hMinEq] at h1; rw [hRSn] at h2; omega + -- top.gens.drop(top.sl - n) = gens + -- eGens = top.gens.drop(top.sl - n) since fvl matched at level n + have hfvl_n' : Imp.findValidLevel + (top.scopeGens.drop (top.scopeLevel - n)) + (gens.drop (n - n)) n top.resultScope = some (n, eGens) := by + rwa [hMinEq, hLvlN] at hfvl_n + have hEGens : eGens = top.scopeGens.drop (top.scopeLevel - n) := + findValidLevel_match_at_level hfvl_n' + -- eGens = gens (from findValidLevel_aligned) + have hEGensGens : eGens = gens := by + have := findValidLevel_aligned hfvl_n rfl rfl hTopGS + rw [hMinEq, Nat.sub_self, hLvlN, Nat.sub_self, Nat.zero_add, List.drop_zero] at this + exact this + -- So top.gens.drop(top.sl - n) = gens + have hTopAligned : top.scopeGens.drop (top.scopeLevel - n) = gens := by + rw [← hEGensGens, hEGens] + -- Step 2: derive alignment for each rest entry + have hRestValid : ∀ e ∈ rest, + Imp.findValidLevel e.scopeGens (gens.drop (n - e.scopeLevel)) + e.scopeLevel e.resultScope = some (e.scopeLevel, e.scopeGens) := by + intro e he + have heSL : e.scopeLevel ≤ n := by + have := hCross e (List.mem_reverse.mpr he) top (List.mem_cons_self ..) + have := hRSSLrev e (List.mem_cons_of_mem _ he) + omega + -- From GensConsistent: e.gens = top.gens.drop(top.sl - e.sl) + have heGC := hGCcross e (List.mem_reverse.mpr he) top (List.mem_cons_self ..) + -- Derive e.gens = gens.drop(n - e.sl) using List.drop_drop + have heAligned : e.scopeGens = gens.drop (n - e.scopeLevel) := by + rw [heGC, show top.scopeLevel - e.scopeLevel = + (top.scopeLevel - n) + (n - e.scopeLevel) from by omega, + ← List.drop_drop, hTopAligned] + exact fvl_self_match e gens n heSL heAligned + (by rw [heAligned, List.length_drop, hGensLen]; omega) + -- Use IH + rewind_go_all_valid to close the goal + have hRestRS : ∀ e ∈ rest, e.resultScope ≤ n := + fun e he => by + have := hCross e (List.mem_reverse.mpr he) top (List.mem_cons_self ..) + have := hRestRSSL e he; omega + have hRestSL : ∀ e ∈ rest, e.scopeLevel ≤ n := + fun e he => by + have := hCross e (List.mem_reverse.mpr he) top (List.mem_cons_self ..) + omega + have hRGAV := rewind_go_all_valid rest + ({ result := top.result, scopeLevel := lvl, + scopeGens := eGens, resultScope := top.resultScope } :: acc1) + n gens hRestRS hRestSL hRestRSSL hRestValid + have hIH := ih + ({ result := top.result, scopeLevel := lvl, + scopeGens := eGens, resultScope := top.resultScope } :: acc1) + acc2 hRestMem hRestRSSL hOrdNew hRestGS hRestGC hNewAccEq + rw [hRGAV] at hIH + simp only [List.append_assoc, List.singleton_append] at hIH ⊢ + exact hIH + · -- fvl at scope n fails: both skip + exact ih acc1 acc2 hRestMem hRestRSSL + (by rw [List.pairwise_append]; exact ⟨hOrdPre, hTopAcc.2, + fun a ha b hb => hCross a ha b (List.mem_cons_of_mem _ hb)⟩) + hRestGS hRestGC hAccEq + · -- Case 3: top.resultScope ≤ n-1 (both scopes process) + have hRS : top.resultScope ≤ n - 1 := by omega + simp only [show ¬(top.resultScope > n - 1) from hGN1, ↓reduceIte] + -- `split` splits the first match (LHS = scope n-1 fvl) + split + · -- scope n-1 fvl succeeds + rename_i lvl1 eGens1 hfvl_n1 + -- Now split on scope n fvl (RHS) + split + · -- scope n fvl also succeeds + rename_i lvl eGens hfvl_n + -- Relate (lvl1, eGens1) to (lvl, eGens) via fvl_pop_entry + have hRel := fvl_pop_entry top n gens hn hRS hTopGS hGensLen lvl eGens hfvl_n + rw [hRel] at hfvl_n1 + simp only [Option.some.injEq, Prod.mk.injEq] at hfvl_n1 + obtain ⟨hLvl1, hEGens1⟩ := hfvl_n1 + subst hLvl1; subst hEGens1 + -- Simplify filter/map on empty list and singleton + simp only [List.map_append, List.map_cons, List.map_nil, + List.filter_append, List.filter_cons, + Imp.Entry.toSpec, List.filter_nil] + -- rest entries: filter+degrade is identity (sl < top.rs ≤ n-1) + have hRestRS : ∀ e ∈ rest.reverse.map Imp.Entry.toSpec, + e.resultScope ≤ n - 1 := by + intro e he + obtain ⟨e', he', rfl⟩ := List.mem_map.mp he + have := hCross e' he' top (List.mem_cons_self ..) + have := hRSSLrev e' (List.mem_cons_of_mem _ (List.mem_reverse.mp he')) + simp [Imp.Entry.toSpec]; omega + have hRestSL : ∀ e ∈ rest.reverse.map Imp.Entry.toSpec, + e.scopeLevel ≤ n - 1 := by + intro e he + obtain ⟨e', he', rfl⟩ := List.mem_map.mp he + have := hCross e' he' top (List.mem_cons_self ..) + simp [Imp.Entry.toSpec]; omega + rw [filter_degrade_spec_id _ (n - 1) hRestRS hRestSL, + show decide (top.resultScope ≤ n - 1) = true from + decide_eq_true_eq.mpr hRS] + simp only [↓reduceIte, List.map_cons, List.map_nil] + congr 1 + · congr 1 + -- [{sl = min lvl (n-1)}] = [degrade {sl = lvl}] + by_cases hlvl : lvl > n - 1 + · simp [hlvl, show min lvl (n - 1) = n - 1 from Nat.min_eq_right (by omega)] + · simp [hlvl, show min lvl (n - 1) = lvl from Nat.min_eq_left (by omega)] + · -- scope n fails but scope n-1 succeeds: impossible + rename_i hfvl_n_none + exact absurd hfvl_n1 (by rw [fvl_pop_none top n gens hn hRS hfvl_n_none]; simp) + · -- scope n-1 fvl fails + rename_i hfvl_n1_none + -- Now split on scope n fvl (RHS) + split + · -- scope n succeeds but scope n-1 fails: impossible + rename_i lvl eGens hfvl_n + exfalso + rw [fvl_pop_entry top n gens hn hRS hTopGS hGensLen lvl eGens hfvl_n] + at hfvl_n1_none + exact absurd hfvl_n1_none (by simp) + · -- Both fail: both skip, use IH + exact ih acc1 acc2 hRestMem hRestRSSL + (by rw [List.pairwise_append]; exact ⟨hOrdPre, hTopAcc.2, + fun a ha b hb => hCross a ha b (List.mem_cons_of_mem _ hb)⟩) + hRestGS hRestGC hAccEq + +/-- Pop preserves SimInv. -/ +theorem pop_simInv (h : SimInv spec imp) (hpos : spec.scopeN > 0) + (hgc : GensConsistent imp.entries) : + SimInv (Spec.pop spec) (Imp.pop imp) where + scopeEq := by simp [Spec.pop, Imp.pop, h.scopeEq] + entriesEq := by + simp only [Spec.pop, Imp.pop] + rw [h.entriesEq, h.scopeEq] + exact (rewind_pop_equiv imp.entries imp.scopeN imp.currentGens + (h.scopeEq ▸ hpos) h.gensLen h.rsLeSl h.entryOrdered h.gensSuffix hgc).symm + gensLen := by + simp only [Imp.pop, List.length_tail, h.gensLen] + have : imp.scopeN ≥ 1 := h.scopeEq ▸ hpos + rw [Nat.succ_sub_one, Nat.sub_add_cancel this] + genBound := by + intro e he k hk + simp only [Imp.pop] at he + exact h.genBound e he k hk + curGenBound := by + intro k hk + simp only [Imp.pop] at hk ⊢ + have : k + 1 < imp.currentGens.length := by + rw [List.length_tail] at hk; grind + rw [show imp.currentGens.tail[k]! = imp.currentGens[k + 1]! from by + simp [List.getElem!_eq_getElem?_getD, List.getElem?_tail]] + exact h.curGenBound (k + 1) this + rsLeSl := by + intro e he; simp only [Imp.pop] at he; exact h.rsLeSl e he + entryOrdered := by simp only [Imp.pop]; exact h.entryOrdered + gensNodup := by + simp only [Imp.pop] + exact h.gensNodup.sublist (List.tail_sublist _) + gensSuffix := by + intro e he i hi j hj heq + simp only [Imp.pop] at he hj heq ⊢ + have hj1 : j + 1 < imp.currentGens.length := by + rw [List.length_tail] at hj; grind + rw [show imp.currentGens.tail[j]! = imp.currentGens[j + 1]! from by + simp [List.getElem!_eq_getElem?_getD, List.getElem?_tail]] at heq + rw [show imp.currentGens.tail.drop j = imp.currentGens.drop (j + 1) from by + rw [List.drop_tail]] + exact h.gensSuffix e he i hi (j + 1) hj1 heq + +private theorem filter_map_toSpec (p : Nat → Bool) (entries : List Imp.Entry) : + (entries.map Imp.Entry.toSpec).filter (fun e => p e.scopeLevel) = + (entries.filter (fun e => p e.scopeLevel)).map Imp.Entry.toSpec := by + induction entries with + | nil => simp + | cons hd tl ih => + simp only [List.map_cons, List.filter_cons, Imp.Entry.toSpec] + split <;> simp_all [Imp.Entry.toSpec] + +/-- Insert preserves SimInv. -/ +theorem insert_simInv (h : SimInv spec imp) (hrs : rs ≤ spec.scopeN) : + SimInv (Spec.insert spec v rs).1 (Imp.insert imp v rs).1 where + scopeEq := by simp [Spec.insert, Imp.insert, h.scopeEq] + entriesEq := by + simp only [Spec.insert, Imp.insert] + rw [h.entriesEq, h.scopeEq] + have hGensPos : imp.currentGens.length > 0 := by + rw [h.gensLen]; exact Nat.zero_lt_succ _ + rw [rewind_after_insert _ _ imp.scopeN imp.currentGens rfl rfl + (h.scopeEq ▸ hrs) hGensPos] + rw [List.map_append, List.map_cons, List.map_nil] + congr 1 + · exact filter_map_toSpec (· < rs) _ + · simp only [List.cons_eq_cons, and_true, Imp.Entry.toSpec] + congr 1 + cases hL : (Imp.rewind imp.entries imp.scopeN imp.currentGens).getLast? with + | none => simp [List.getLast?_map, hL] + | some e => simp [List.getLast?_map, hL, Imp.Entry.toSpec] + gensLen := by simp [Imp.insert, h.gensLen] + genBound := by + intro e he k hk + simp only [Imp.insert, List.mem_append, List.mem_cons, List.mem_nil_iff, or_false] at he + rcases he with hmem | heq + · -- From filtered rewind entries — these come from imp.entries via rewind + exact rewind_genBound h.genBound (List.mem_filter.mp hmem).1 k hk + · -- The new entry has scopeGens = currentGens + subst heq + exact h.curGenBound k hk + curGenBound := by + intro k hk + simp only [Imp.insert] at hk ⊢ + exact h.curGenBound k hk + rsLeSl := by + intro e he + simp only [Imp.insert, List.mem_append, List.mem_cons, List.mem_nil_iff, or_false] at he + rcases he with hmem | heq + · exact rewind_rsLeSl h.rsLeSl (List.mem_filter.mp hmem).1 + · subst heq; exact h.scopeEq ▸ hrs + entryOrdered := by + simp only [Imp.insert, Imp.EntryOrdered] + apply List.pairwise_append.mpr + refine ⟨(rewind_entryOrdered h.entryOrdered).filter _, ?_, ?_⟩ + · -- [newEntry].Pairwise R — trivially true for singleton + exact List.pairwise_cons.mpr ⟨fun b hb => (List.not_mem_nil hb).elim, List.Pairwise.nil⟩ + · -- ∀ a ∈ filtered, ∀ b ∈ [newEntry], R a b + intro a ha b hb + simp only [List.mem_singleton] at hb + subst hb; simp only + exact of_decide_eq_true (List.mem_filter.mp ha).2 + gensNodup := by simp only [Imp.insert]; exact h.gensNodup + gensSuffix := by + intro e he i hi j hj heq + simp only [Imp.insert, List.mem_append, List.mem_cons, List.mem_nil_iff, or_false] at he + simp only [Imp.insert] at hj heq ⊢ + rcases he with hmem | heq_e + · -- Entry from filtered rewind: its gens come from some original entry + have he_rw := (List.mem_filter.mp hmem).1 + obtain ⟨e', he', _, _, ⟨d, hd⟩⟩ := rewind_mem he_rw + -- Convert to facts about e'.scopeGens + have hi' : i + d < e'.scopeGens.length := by + rw [hd] at hi; simp only [List.length_drop] at hi; grind + have heq' : e'.scopeGens[i + d]! = imp.currentGens[j]! := by + rw [hd] at heq + simp only [List.getElem!_eq_getElem?_getD, List.getElem?_drop] at heq + rw [Nat.add_comm] at heq + simp only [List.getElem!_eq_getElem?_getD] + exact heq + have hsuf := h.gensSuffix e' he' (i + d) hi' j hj heq' + -- hsuf : e'.scopeGens.drop (i + d) = imp.currentGens.drop j + -- Goal: e.scopeGens.drop i = imp.currentGens.drop j + rw [hd, List.drop_drop, Nat.add_comm] + exact hsuf + · -- The new entry: scopeGens = imp.currentGens + subst heq_e; simp only at hi heq ⊢ + -- heq : currentGens[i]! = currentGens[j]! + -- Use Nodup to conclude i = j, then drop i = drop j + have heqi : imp.currentGens[i]? = imp.currentGens[j]? := by + simp only [List.getElem!_eq_getElem?_getD, + List.getElem?_eq_getElem hi, List.getElem?_eq_getElem hj, + Option.getD_some] at heq + simp only [List.getElem?_eq_getElem hi, List.getElem?_eq_getElem hj] + exact congrArg some heq + have hij := (List.getElem?_inj hi h.gensNodup).mp heqi + subst hij; rfl + +/-! ### Reachable states and main theorems -/ + +/-- States reachable from empty via matching operations. -/ +inductive Reachable : Spec.State → Imp.State → Prop where + | empty : Reachable Spec.empty Imp.empty + | push : Reachable s i → Reachable (Spec.push s) (Imp.push i) + | pop : Reachable s i → s.scopeN > 0 → + Reachable (Spec.pop s) (Imp.pop i) + | insert : Reachable s i → rs ≤ s.scopeN → + Reachable (Spec.insert s v rs).1 (Imp.insert i v rs).1 + +private theorem push_gensConsistent (hgc : GensConsistent imp.entries) : + GensConsistent (Imp.push imp).entries := by + simp only [Imp.push]; exact hgc + +private theorem pop_gensConsistent (hgc : GensConsistent imp.entries) : + GensConsistent (Imp.pop imp).entries := by + simp only [Imp.pop]; exact hgc + +/-- When all entries are aligned with a common gens list and ordered, GensConsistent holds. -/ +private theorem aligned_to_gensConsistent + {entries : List Imp.Entry} {g : Imp.GenList} {N : Nat} + (hAlign : ∀ e ∈ entries, e.scopeGens = g.drop (N - e.scopeLevel)) + (hOrd : entries.Pairwise (fun a b => a.scopeLevel < b.resultScope)) + (hRSSL : ∀ e ∈ entries, e.resultScope ≤ e.scopeLevel) + (hSLle : ∀ e ∈ entries, e.scopeLevel ≤ N) : + GensConsistent entries := by + unfold GensConsistent + induction entries with + | nil => exact List.Pairwise.nil + | cons hd tl ih => + apply List.pairwise_cons.mpr + constructor + · intro b hb + rw [hAlign hd (List.mem_cons_self ..), hAlign b (List.mem_cons_of_mem _ hb), + List.drop_drop]; congr 1 + have := (List.pairwise_cons.mp hOrd).1 b hb + have := hRSSL b (List.mem_cons_of_mem _ hb) + have := hSLle b (List.mem_cons_of_mem _ hb) + omega + · exact ih (fun e he => hAlign e (List.mem_cons_of_mem _ he)) + (List.pairwise_cons.mp hOrd).2 + (fun e he => hRSSL e (List.mem_cons_of_mem _ he)) + (fun e he => hSLle e (List.mem_cons_of_mem _ he)) + +/-- After rewind, all surviving entries have gens aligned with currentGens. -/ +private theorem rewind_gens_aligned + (hGC : GensConsistent entries) + (hGS : ∀ e ∈ entries, ∀ i < e.scopeGens.length, ∀ j < gens.length, + e.scopeGens[i]! = gens[j]! → e.scopeGens.drop i = gens.drop j) + (hRSSL : ∀ e ∈ entries, e.resultScope ≤ e.scopeLevel) + (hOrd : Imp.EntryOrdered entries) : + ∀ e ∈ Imp.rewind entries n gens, e.scopeGens = gens.drop (n - e.scopeLevel) := by + simp only [Imp.rewind] + suffices ∀ revEntries acc, + (∀ e ∈ revEntries, e ∈ entries) → + revEntries.reverse.Pairwise + (fun a b => a.scopeGens = b.scopeGens.drop (b.scopeLevel - a.scopeLevel)) → + revEntries.reverse.Pairwise (fun a b => a.scopeLevel < b.resultScope) → + (∀ e ∈ acc, e.scopeGens = gens.drop (n - e.scopeLevel)) → + ∀ e ∈ Imp.rewind.go n gens revEntries acc, + e.scopeGens = gens.drop (n - e.scopeLevel) by + exact this entries.reverse [] + (fun e he => List.mem_reverse.mp he) + (by simpa using hGC) + (by simpa using hOrd) + (by simp) + intro revEntries acc hMem hGCrev hOrdRev hAccAligned + induction revEntries generalizing acc with + | nil => simp [Imp.rewind.go]; exact hAccAligned + | cons top rest ih => + have hTopMem := hMem top (List.mem_cons_self ..) + have hRestMem := fun e he => hMem e (List.mem_cons_of_mem _ he) + -- Decompose Pairwise on rest.reverse ++ [top] + rw [List.reverse_cons] at hGCrev hOrdRev + have hGCparts := List.pairwise_append.mp hGCrev + have hRestGC := hGCparts.1 + have hGCcross := hGCparts.2.2 + have hOrdParts := List.pairwise_append.mp hOrdRev + have hRestOrd := hOrdParts.1 + have hOrdCross := hOrdParts.2.2 + simp only [Imp.rewind.go] + by_cases hGN : top.resultScope > n + · -- rs > n: discard + simp only [hGN, ↓reduceIte] + exact ih acc hRestMem hRestGC hRestOrd hAccAligned + · simp only [show ¬(top.resultScope > n) from hGN, ↓reduceIte] + split + · -- fvl succeeds: early stop, result = rest.reverse ++ [modified_top] ++ acc + rename_i lvl eGens hfvl + -- Establish key facts about the fvl result before intro + have hTopRSSL := hRSSL top hTopMem + have hRSleMin : top.resultScope ≤ min top.scopeLevel n := + Nat.le_min.mpr ⟨hTopRSSL, Nat.not_lt.mp hGN⟩ + have hRSle := findValidLevel_rs_le hRSleMin hfvl + have hEGensDrop := findValidLevel_output_eq_input_drop hfvl + have hFVLle := findValidLevel_lvl_le hfvl + have hEGensTop : eGens = top.scopeGens.drop (top.scopeLevel - lvl) := by + rw [hEGensDrop, List.drop_drop]; congr 1; omega + have hEGensGens := findValidLevel_aligned hfvl rfl rfl (hGS top hTopMem) + have hEGensGens' : eGens = gens.drop (n - lvl) := by + rw [hEGensGens]; congr 1; omega + have hTopDrop : top.scopeGens.drop (top.scopeLevel - lvl) = + gens.drop (n - lvl) := by + rw [← hEGensGens', hEGensTop] + intro e he + -- result = (rest.reverse ++ [modified_top]) ++ acc + rcases List.mem_append.mp he with hLeft | hAcc + · rcases List.mem_append.mp hLeft with hRest | hSingle + · -- e ∈ rest.reverse: derive alignment from GensConsistent + fvl result + have heGC := hGCcross e hRest top (List.mem_singleton.mpr rfl) + have heOrd := hOrdCross e hRest top (List.mem_singleton.mpr rfl) + -- e.gens = gens.drop(n - e.sl) via List.drop_drop + rw [heGC, show top.scopeLevel - e.scopeLevel = + (top.scopeLevel - lvl) + (lvl - e.scopeLevel) from by omega, + ← List.drop_drop, hTopDrop, List.drop_drop] + congr 1; omega + · -- e ∈ [modified_top]: from findValidLevel_aligned + rw [List.mem_singleton.mp hSingle] + simp only [show (n - min top.scopeLevel n) + (min top.scopeLevel n - lvl) = + n - lvl from by omega] at hEGensGens + exact hEGensGens + · -- e ∈ acc: from hypothesis + exact hAccAligned e hAcc + · -- fvl fails: continue with rest + exact ih acc hRestMem hRestGC hRestOrd hAccAligned + +private theorem insert_gensConsistent (h : SimInv spec imp) + (hgc : GensConsistent imp.entries) (hrs : rs ≤ spec.scopeN) : + GensConsistent (Imp.insert imp v rs).1.entries := by + have hRA := rewind_gens_aligned (n := imp.scopeN) (gens := imp.currentGens) + hgc h.gensSuffix h.rsLeSl h.entryOrdered + -- After insert, entries = filter(sl < rs, rewind) ++ [newEntry] + -- All are aligned with currentGens and GensConsistent follows from aligned_to_gensConsistent + change GensConsistent (Imp.insert imp v rs).1.entries + unfold Imp.insert + simp only + apply aligned_to_gensConsistent (N := imp.scopeN) (g := imp.currentGens) + · -- alignment: all entries have gens = currentGens.drop(scopeN - sl) + intro e he + rcases List.mem_append.mp he with hmem | hSingle + · exact hRA e (List.mem_filter.mp hmem).1 + · rw [List.mem_singleton.mp hSingle]; simp [Nat.sub_self] + · -- ordering: Pairwise (fun a b => a.sl < b.rs) + apply List.pairwise_append.mpr + refine ⟨(rewind_entryOrdered h.entryOrdered).filter _, ?_, ?_⟩ + · exact List.pairwise_cons.mpr ⟨fun b hb => by simp at hb, List.Pairwise.nil⟩ + · intro a ha b hb + rw [List.mem_singleton.mp hb] + exact of_decide_eq_true (List.mem_filter.mp ha).2 + · -- rsLeSl + intro e he + rcases List.mem_append.mp he with hmem | hSingle + · exact rewind_rsLeSl h.rsLeSl (List.mem_filter.mp hmem).1 + · rw [List.mem_singleton.mp hSingle]; exact h.scopeEq ▸ hrs + · -- sl ≤ scopeN + intro e he + rcases List.mem_append.mp he with hmem | hSingle + · have := of_decide_eq_true (List.mem_filter.mp hmem).2 + have := h.scopeEq ▸ hrs; omega + · rw [List.mem_singleton.mp hSingle]; simp + +/-- All reachable states satisfy the simulation invariant and gens consistency. -/ +private theorem reachable_simInv_gc : Reachable s i → SimInv s i ∧ GensConsistent i.entries := by + intro h + induction h with + | empty => exact ⟨simInv_empty, List.Pairwise.nil⟩ + | push _ ih => + exact ⟨push_simInv ih.1, push_gensConsistent ih.2⟩ + | pop _ hpos ih => + exact ⟨pop_simInv ih.1 hpos ih.2, pop_gensConsistent ih.2⟩ + | insert _ hrs ih => + exact ⟨insert_simInv ih.1 hrs, insert_gensConsistent ih.1 ih.2 hrs⟩ + +/-- All reachable states satisfy the simulation invariant. -/ +theorem reachable_simInv : Reachable s i → SimInv s i := + fun h => (reachable_simInv_gc h).1 + +/-- Main theorem: lookup is equivalent for all reachable states. -/ +theorem reachable_lookup_equiv (h : Reachable s i) : + Spec.lookup s = Imp.lookup i := + lookup_equiv (reachable_simInv h) + +/-- Main theorem: insert sharing is equivalent for all reachable states. -/ +theorem reachable_insert_equiv (h : Reachable s i) : + (Spec.insert s v rs).2 = (Imp.insert i v rs).2 := + insert_sharing_equiv (reachable_simInv h) + +end Proofs + +end ScopeCache