Compare commits

...

4 Commits

Author SHA1 Message Date
Leonardo de Moura
747f0efa97 fix: instantiate_lmvars_fn 2024-08-05 09:59:56 -07:00
Leonardo de Moura
2a0310dfbd feat: add options preserve_data and zeta to apply_beta 2024-08-05 09:59:56 -07:00
Leonardo de Moura
f76f10d0c6 perf: add lean_instantiate_expr_mvars 2024-08-05 09:59:40 -07:00
Leonardo de Moura
bf010e153a perf: use lean_instantiate_level_mvars
implemented in C/C++.
TODO: same for `instantiateExprMVars`
2024-08-04 15:59:41 -07:00
4 changed files with 313 additions and 38 deletions

View File

@@ -541,20 +541,10 @@ This operation is performed at `instantiateExprMVars`, `elimMVarDeps`, and `leve
@[extern "lean_instantiate_level_mvars"]
opaque instantiateLevelMVarsImp (mctx : MetavarContext) (l : Level) : MetavarContext × Level
partial def instantiateLevelMVars [Monad m] [MonadMCtx m] : Level m Level
| lvl@(Level.succ lvl₁) => return Level.updateSucc! lvl ( instantiateLevelMVars lvl₁)
| lvl@(Level.max lvl₁ lvl₂) => return Level.updateMax! lvl ( instantiateLevelMVars lvl₁) ( instantiateLevelMVars lvl₂)
| lvl@(Level.imax lvl₁ lvl₂) => return Level.updateIMax! lvl ( instantiateLevelMVars lvl₁) ( instantiateLevelMVars lvl₂)
| lvl@(Level.mvar mvarId) => do
match ( getLevelMVarAssignment? mvarId) with
| some newLvl =>
if !newLvl.hasMVar then pure newLvl
else do
let newLvl' instantiateLevelMVars newLvl
assignLevelMVar mvarId newLvl'
pure newLvl'
| none => pure lvl
| lvl => pure lvl
partial def instantiateLevelMVars [Monad m] [MonadMCtx m] (l : Level) : m Level := do
let (mctx, lNew) := instantiateLevelMVarsImp ( getMCtx) l
setMCtx mctx
return lNew
@[extern "lean_instantiate_expr_mvars"]
opaque instantiateExprMVarsImp (mctx : MetavarContext) (e : Expr) : MetavarContext × Expr

View File

@@ -165,22 +165,38 @@ bool is_head_beta(expr const & t) {
return is_app(t) && is_lambda(get_app_fn(t));
}
expr apply_beta(expr f, unsigned num_args, expr const * args) {
if (num_args == 0) {
return f;
} else if (!is_lambda(f)) {
return mk_rev_app(f, num_args, args);
} else {
unsigned m = 1;
while (is_lambda(binding_body(f)) && m < num_args) {
f = binding_body(f);
m++;
static expr apply_beta_rec(expr e, unsigned i, unsigned num_rev_args, expr const * rev_args, bool preserve_data, bool zeta) {
if (is_lambda(e)) {
if (i + 1 < num_rev_args) {
return apply_beta_rec(binding_body(e), i+1, num_rev_args, rev_args, preserve_data, zeta);
} else {
return instantiate(binding_body(e), num_rev_args, rev_args);
}
lean_assert(m <= num_args);
return mk_rev_app(instantiate(binding_body(f), m, args + (num_args - m)), num_args - m, args);
} else if (is_let(e)) {
if (zeta && i < num_rev_args) {
return apply_beta_rec(instantiate(let_body(e), let_value(e)), i, num_rev_args, rev_args, preserve_data, zeta);
} else {
unsigned n = num_rev_args - i;
return mk_rev_app(instantiate(e, i, rev_args + n), n, rev_args);
}
} else if (is_mdata(e)) {
if (preserve_data) {
unsigned n = num_rev_args - i;
return mk_rev_app(instantiate(e, i, rev_args + n), n, rev_args);
} else {
return apply_beta_rec(mdata_expr(e), i, num_rev_args, rev_args, preserve_data, zeta);
}
} else {
unsigned n = num_rev_args - i;
return mk_rev_app(instantiate(e, i, rev_args + n), n, rev_args);
}
}
expr apply_beta(expr f, unsigned num_rev_args, expr const * rev_args, bool preserve_data, bool zeta) {
if (num_rev_args == 0) return f;
return apply_beta_rec(f, 0, num_rev_args, rev_args, preserve_data, zeta);
}
expr head_beta_reduce(expr const & t) {
if (!is_head_beta(t)) {
return t;

View File

@@ -24,7 +24,7 @@ inline expr instantiate_rev(expr const & e, buffer<expr> const & s) {
return instantiate_rev(e, s.size(), s.data());
}
expr apply_beta(expr f, unsigned num_rev_args, expr const * rev_args);
expr apply_beta(expr f, unsigned num_rev_args, expr const * rev_args, bool preserve_data = true, bool zeta = false);
bool is_head_beta(expr const & t);
expr head_beta_reduce(expr const & t);
/* If `e` is of the form `(fun x, t) a` return `head_beta_const_fn(t)` if `t` does not depend on `x`,

View File

@@ -4,10 +4,13 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
*/
#include <vector>
#include <unordered_map>
#include "util/name_set.h"
#include "runtime/option_ref.h"
#include "runtime/array_ref.h"
#include "kernel/instantiate.h"
#include "kernel/abstract.h"
#include "kernel/replace_fn.h"
/*
This module is not used by the kernel. It just provides an efficient implementation of
@@ -15,7 +18,6 @@ This module is not used by the kernel. It just provides an efficient implementat
*/
namespace lean {
extern "C" object * lean_get_lmvar_assignment(obj_arg mctx, obj_arg mid);
extern "C" object * lean_assign_lmvar(obj_arg mctx, obj_arg mid, obj_arg val);
@@ -29,18 +31,19 @@ option_ref<level> get_lmvar_assignment(metavar_ctx & mctx, name const & mid) {
return option_ref<level>(lean_get_lmvar_assignment(mctx.to_obj_arg(), mid.to_obj_arg()));
}
class instantiate_lmvar_fn {
class instantiate_lmvars_fn {
metavar_ctx & m_mctx;
std::unordered_map<lean_object *, lean_object *> m_cache;
std::unordered_map<lean_object *, level> m_cache;
std::vector<level> m_saved; // Helper vector to prevent values from being garbagge collected
inline level cache(level const & l, level && r, bool shared) {
inline level cache(level const & l, level r, bool shared) {
if (shared) {
m_cache.insert(mk_pair(l.raw(), r.raw()));
m_cache.insert(mk_pair(l.raw(), r));
}
return r;
}
public:
instantiate_lmvar_fn(metavar_ctx & mctx):m_mctx(mctx) {}
instantiate_lmvars_fn(metavar_ctx & mctx):m_mctx(mctx) {}
level visit(level const & l) {
if (!has_mvar(l))
return l;
@@ -48,7 +51,7 @@ public:
if (is_shared(l)) {
auto it = m_cache.find(l.raw());
if (it != m_cache.end()) {
return level(it->second, true);
return it->second;
}
shared = true;
}
@@ -70,6 +73,12 @@ public:
} else {
level a_new = visit(a);
if (!is_eqp(a, a_new)) {
/*
We save `a` to ensure it will not be garbage collected
after we update `mctx`. This is necessary because `m_cache`
may contain references to its subterms.
*/
m_saved.push_back(a);
assign_lmvar(m_mctx, mvar_id(l), a_new);
}
return a_new;
@@ -82,14 +91,274 @@ public:
extern "C" LEAN_EXPORT object * lean_instantiate_level_mvars(object * m, object * l) {
metavar_ctx mctx(m);
level l_new = instantiate_lmvar_fn(mctx)(level(l));
level l_new = instantiate_lmvars_fn(mctx)(level(l));
object * r = alloc_cnstr(0, 2, 0);
cnstr_set(r, 0, mctx.steal());
cnstr_set(r, 1, l_new.steal());
return r;
}
extern "C" LEAN_EXPORT object * lean_instantiate_expr_mvars(object *, object *) {
lean_internal_panic("not implemented yet");
extern "C" object * lean_get_mvar_assignment(obj_arg mctx, obj_arg mid);
extern "C" object * lean_get_delayed_mvar_assignment(obj_arg mctx, obj_arg mid);
extern "C" object * lean_assign_mvar(obj_arg mctx, obj_arg mid, obj_arg val);
typedef object_ref delayed_assignment;
void assign_mvar(metavar_ctx & mctx, name const & mid, expr const & e) {
object * r = lean_assign_mvar(mctx.steal(), mid.to_obj_arg(), e.to_obj_arg());
mctx.set_box(r);
}
option_ref<expr> get_mvar_assignment(metavar_ctx & mctx, name const & mid) {
return option_ref<expr>(lean_get_mvar_assignment(mctx.to_obj_arg(), mid.to_obj_arg()));
}
option_ref<delayed_assignment> get_delayed_mvar_assignment(metavar_ctx & mctx, name const & mid) {
return option_ref<delayed_assignment>(lean_get_delayed_mvar_assignment(mctx.to_obj_arg(), mid.to_obj_arg()));
}
expr replace_fvars(expr const & e, array_ref<expr> const & fvars, expr const * rev_args) {
size_t sz = fvars.size();
if (sz == 0)
return e;
return replace(e, [=](expr const & m, unsigned offset) -> optional<expr> {
if (!has_fvar(m))
return some_expr(m); // expression m does not contain free variables
if (is_fvar(m)) {
size_t i = sz;
name const & fid = fvar_name(m);
while (i > 0) {
--i;
if (fvar_name(fvars[i]) == fid) {
return some_expr(lift_loose_bvars(rev_args[sz - i - 1], offset));
}
}
}
return none_expr();
});
}
class instantiate_mvars_fn {
metavar_ctx & m_mctx;
instantiate_lmvars_fn m_level_fn;
name_set m_already_normalized; // Store metavariables whose assignment has already been normalized.
std::unordered_map<lean_object *, expr> m_cache;
std::vector<expr> m_saved; // Helper vector to prevent values from being garbagge collected
level visit_level(level const & l) {
return m_level_fn(l);
}
levels visit_levels(levels const & ls) {
buffer<level> lsNew;
for (auto const & l : ls)
lsNew.push_back(visit_level(l));
return levels(lsNew);
}
inline expr cache(expr const & e, expr r, bool shared) {
if (shared) {
m_cache.insert(mk_pair(e.raw(), r));
}
return r;
}
optional<expr> get_assignment(name const & mid) {
option_ref<expr> r = get_mvar_assignment(m_mctx, mid);
if (!r) {
return optional<expr>();
} else {
expr a(r.get_val());
if (!has_mvar(a) || m_already_normalized.contains(mid)) {
return optional<expr>(a);
} else {
m_already_normalized.insert(mid);
expr a_new = visit(a);
if (!is_eqp(a, a_new)) {
/*
We save `a` to ensure it will not be garbage collected
after we update `mctx`. This is necessary because `m_cache`
may contain references to its subterms.
*/
m_saved.push_back(a);
assign_mvar(m_mctx, mid, a_new);
}
return optional<expr>(a_new);
}
}
}
/*
Given `e` of the form `f a_1 ... a_n` where `f` is not a metavariable,
instantiate metavariables.
*/
expr visit_app_default(expr const & e) {
if (is_app(e)) {
return update_app(e, visit_app_default(app_fn(e)), visit(app_arg(e)));
} else {
lean_assert(!is_mvar(e));
return visit(e);
}
}
/*
Given `e` of the form `?m a_1 ... a_n`, return new application where
the metavariables in the arguments `a_i` have been instantiated.
*/
expr visit_mvar_app_args(expr const & e) {
if (is_app(e)) {
return update_app(e, visit_app_default(app_fn(e)), visit(app_arg(e)));
} else {
lean_assert(is_mvar(e));
return e;
}
}
/*
Given `e` of the form `f a_1 ... a_n`, return new application `f_new a_1' ... a_n'`
where `a_i'` is `visit(a_i)`. `args` is an accumulator for the new arguments.
*/
expr visit_args_and_beta(expr const & f_new, expr const & e, buffer<expr> & args) {
if (is_app(e)) {
args.push_back(visit(app_arg(e)));
return visit_args_and_beta(f_new, app_fn(e), args);
} else {
/*
Some of the arguments in `args` are irrelevant after we beta
reduce. Also, it may be a bug to not instantiate them, since they
may depend on free variables that are not in the context (see
issue #4375). So we pass `useZeta := true` to ensure that they are
instantiated.
*/
bool preserve_data = false;
bool zeta = true;
return apply_beta(f_new, args.size(), args.data(), preserve_data, zeta);
}
}
/*
Helper function for delayed assignment case at `visit_app`.
`e` is a term of the form `?m t1 t2 t3`
Moreover, `?m` is delayed assigned
`?m #[x, y] := g x y`
where, `fvars := #[x, y]` and `val := g x y`.
`args` is an accumulator for `e`'s arguments.
We want to return `g t1' t2' t3'` where
`ti'`s are `visit(ti)`.
*/
expr visit_delayed(array_ref<expr> const & fvars, expr const & val, expr const & e, buffer<expr> & args) {
if (is_app(e)) {
args.push_back(visit(app_arg(e)));
return visit_delayed(fvars, val, app_fn(e), args);
} else {
expr val_new = replace_fvars(val, fvars, args.data() + (args.size() - fvars.size()));
return mk_rev_app(val_new, args.size() - fvars.size(), args.data());
}
}
expr visit_app(expr const & e) {
expr const & f = get_app_fn(e);
if (!is_mvar(f)) {
return visit_app_default(e);
} else {
name const & mid = mvar_name(f);
option_ref<delayed_assignment> d = get_delayed_mvar_assignment(m_mctx, mid);
if (!d) {
// mvar is not delayed assigned
expr f_new = visit(f);
if (is_eqp(f, f_new)) {
return visit_mvar_app_args(e);
} else {
buffer<expr> args;
return visit_args_and_beta(f_new, e, args);
}
} else {
/*
Apply "delayed substitution" (i.e., delayed assignment + application).
That is, `f` is some metavariable `?m`, that is delayed assigned to `val`.
If after instantiating `val`, we obtain `newVal`, and `newVal` does not contain
metavariables, we replace the free variables `fvars` in `newVal` with the first
`fvars.size` elements of `args`.
*/
array_ref<expr> fvars(cnstr_get(d.get_val().raw(), 0), true);
name mid_pending(cnstr_get(d.get_val().raw(), 1), true);
if (fvars.size() > get_app_num_args(e)) {
/*
We don't have sufficient arguments for instantiating the free variables `fvars`.
This can only happen if a tactic or elaboration function is not implemented correctly.
We decided to not use `panic!` here and report it as an error in the frontend
when we are checking for unassigned metavariables in an elaborated term. */
return visit_mvar_app_args(e);
}
optional<expr> val = get_assignment(mid_pending);
if (!val)
// mid_pending has not been assigned yet.
return visit_mvar_app_args(e);
if (has_expr_mvar(*val))
// mid_pending has been assigned, but assignment contains mvars.
return visit_mvar_app_args(e);
buffer<expr> args;
return visit_delayed(fvars, *val, e, args);
}
}
}
expr visit_mvar(expr const & e) {
name const & mid = mvar_name(e);
if (auto r = get_assignment(mid)) {
return *r;
} else {
return e;
}
}
public:
instantiate_mvars_fn(metavar_ctx & mctx):m_mctx(mctx), m_level_fn(mctx) {}
expr visit(expr const & e) {
if (!has_mvar(e))
return e;
bool shared = false;
if (is_shared(e)) {
auto it = m_cache.find(e.raw());
if (it != m_cache.end()) {
return it->second;
}
shared = true;
}
switch (e.kind()) {
case expr_kind::BVar:
case expr_kind::Lit: case expr_kind::FVar:
lean_unreachable();
case expr_kind::Sort:
return cache(e, update_sort(e, visit_level(sort_level(e))), shared);
case expr_kind::Const:
return cache(e, update_const(e, visit_levels(const_levels(e))), shared);
case expr_kind::MVar:
return visit_mvar(e);
case expr_kind::MData:
return cache(e, update_mdata(e, visit(mdata_expr(e))), shared);
case expr_kind::Proj:
return cache(e, update_proj(e, visit(proj_expr(e))), shared);
case expr_kind::App:
return cache(e, visit_app(e), shared);
case expr_kind::Pi: case expr_kind::Lambda:
return cache(e, update_binding(e, visit(binding_domain(e)), visit(binding_body(e))), shared);
case expr_kind::Let:
return cache(e, update_let(e, visit(let_type(e)), visit(let_value(e)), visit(let_body(e))), shared);
}
}
expr operator()(expr const & e) { return visit(e); }
};
extern "C" LEAN_EXPORT object * lean_instantiate_expr_mvars(object * m, object * e) {
metavar_ctx mctx(m);
expr e_new = instantiate_mvars_fn(mctx)(expr(e));
object * r = alloc_cnstr(0, 2, 0);
cnstr_set(r, 0, mctx.steal());
cnstr_set(r, 1, e_new.steal());
return r;
}
}