Compare commits

...

3 Commits

Author SHA1 Message Date
Leonardo de Moura
85c14269b5 chore: missing include 2024-08-04 11:13:05 -07:00
Leonardo de Moura
c5a7037593 perf: add `lean_instantiate_level_mvars"
The new code is not active yet because of bootstrapping issues.
2024-08-04 11:00:25 -07:00
Leonardo de Moura
2f42232daa chore: export MetavarContext API 2024-08-04 09:35:05 -07:00
6 changed files with 132 additions and 3 deletions

View File

@@ -336,6 +336,8 @@ structure MetavarContext where
For more information about delayed abstraction, see the docstring for `DelayedMetavarAssignment`. -/
dAssignment : PersistentHashMap MVarId DelayedMetavarAssignment := {}
instance : Inhabited MetavarContext := {}
/-- A monad with a stateful metavariable context, defining `getMCtx` and `modifyMCtx`. -/
class MonadMCtx (m : Type Type) where
getMCtx : m MetavarContext
@@ -358,15 +360,27 @@ abbrev setMCtx [MonadMCtx m] (mctx : MetavarContext) : m Unit :=
abbrev getLevelMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : LMVarId) : m (Option Level) :=
return ( getMCtx).lAssignment.find? mvarId
@[export lean_get_lmvar_assignment]
def getLevelMVarAssignmentExp (m : MetavarContext) (mvarId : LMVarId) : Option Level :=
m.lAssignment.find? mvarId
def MetavarContext.getExprAssignmentCore? (m : MetavarContext) (mvarId : MVarId) : Option Expr :=
m.eAssignment.find? mvarId
@[export lean_get_mvar_assignment]
def MetavarContext.getExprAssignmentExp (m : MetavarContext) (mvarId : MVarId) : Option Expr :=
m.eAssignment.find? mvarId
def getExprMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : MVarId) : m (Option Expr) :=
return ( getMCtx).getExprAssignmentCore? mvarId
def MetavarContext.getDelayedMVarAssignmentCore? (mctx : MetavarContext) (mvarId : MVarId) : Option DelayedMetavarAssignment :=
mctx.dAssignment.find? mvarId
@[export lean_get_delayed_mvar_assignment]
def MetavarContext.getDelayedMVarAssignmentExp (mctx : MetavarContext) (mvarId : MVarId) : Option DelayedMetavarAssignment :=
mctx.dAssignment.find? mvarId
def getDelayedMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : MVarId) : m (Option DelayedMetavarAssignment) :=
return ( getMCtx).getDelayedMVarAssignmentCore? mvarId
@@ -478,6 +492,10 @@ def hasAssignableMVar [Monad m] [MonadMCtx m] : Expr → m Bool
def assignLevelMVar [MonadMCtx m] (mvarId : LMVarId) (val : Level) : m Unit :=
modifyMCtx fun m => { m with lAssignment := m.lAssignment.insert mvarId val }
@[export lean_assign_lmvar]
def assignLevelMVarExp (m : MetavarContext) (mvarId : LMVarId) (val : Level) : MetavarContext :=
{ m with lAssignment := m.lAssignment.insert mvarId val }
/--
Add `mvarId := x` to the metavariable assignment.
This method does not check whether `mvarId` is already assigned, nor it checks whether
@@ -487,6 +505,10 @@ This is a low-level API, and it is safer to use `isDefEq (mkMVar mvarId) x`.
def _root_.Lean.MVarId.assign [MonadMCtx m] (mvarId : MVarId) (val : Expr) : m Unit :=
modifyMCtx fun m => { m with eAssignment := m.eAssignment.insert mvarId val }
@[export lean_assign_mvar]
def assignExp (m : MetavarContext) (mvarId : MVarId) (val : Expr) : MetavarContext :=
{ m with eAssignment := m.eAssignment.insert mvarId val }
/--
Add a delayed assignment for the given metavariable. You must make sure that
the metavariable is not already assigned or delayed-assigned.
@@ -516,6 +538,9 @@ To avoid this term eta-expanded term, we apply beta-reduction when instantiating
This operation is performed at `instantiateExprMVars`, `elimMVarDeps`, and `levelMVarToParam`.
-/
@[extern "lean_instantiate_level_mvars"]
opaque instantiateLevelMVarsImp (mctx : MetavarContext) (l : Level) : MetavarContext × Level
partial def instantiateLevelMVars [Monad m] [MonadMCtx m] : Level m Level
| lvl@(Level.succ lvl₁) => return Level.updateSucc! lvl ( instantiateLevelMVars lvl₁)
| lvl@(Level.max lvl₁ lvl₂) => return Level.updateMax! lvl ( instantiateLevelMVars lvl₁) ( instantiateLevelMVars lvl₂)
@@ -531,6 +556,9 @@ partial def instantiateLevelMVars [Monad m] [MonadMCtx m] : Level → m Level
| none => pure lvl
| lvl => pure lvl
@[extern "lean_instantiate_expr_mvars"]
opaque instantiateExprMVarsImp (mctx : MetavarContext) (e : Expr) : MetavarContext × Expr
/-- instantiateExprMVars main function -/
partial def instantiateExprMVars [Monad m] [MonadMCtx m] [STWorld ω m] [MonadLiftT (ST ω) m] (e : Expr) : MonadCacheT ExprStructEq Expr m Expr :=
if !e.hasMVar then
@@ -792,8 +820,6 @@ def localDeclDependsOnPred [Monad m] [MonadMCtx m] (localDecl : LocalDecl) (pf :
namespace MetavarContext
instance : Inhabited MetavarContext := {}
@[export lean_mk_metavar_ctx]
def mkMetavarContext : Unit MetavarContext := fun _ => {}

View File

@@ -2,4 +2,4 @@ add_library(kernel OBJECT level.cpp expr.cpp expr_eq_fn.cpp
for_each_fn.cpp replace_fn.cpp abstract.cpp instantiate.cpp
local_ctx.cpp declaration.cpp environment.cpp type_checker.cpp
init_module.cpp expr_cache.cpp equiv_manager.cpp quot.cpp
inductive.cpp trace.cpp)
inductive.cpp trace.cpp instantiate_mvars.cpp)

View File

@@ -0,0 +1,95 @@
/*
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
*/
#include <unordered_map>
#include "runtime/option_ref.h"
#include "kernel/instantiate.h"
#include "kernel/abstract.h"
/*
This module is not used by the kernel. It just provides an efficient implementation of
`instantiateExprMVars`
*/
namespace lean {
extern "C" object * lean_get_lmvar_assignment(obj_arg mctx, obj_arg mid);
extern "C" object * lean_assign_lmvar(obj_arg mctx, obj_arg mid, obj_arg val);
typedef object_ref metavar_ctx;
void assign_lmvar(metavar_ctx & mctx, name const & mid, level const & l) {
object * r = lean_assign_lmvar(mctx.steal(), mid.to_obj_arg(), l.to_obj_arg());
mctx.set_box(r);
}
option_ref<level> get_lmvar_assignment(metavar_ctx & mctx, name const & mid) {
return option_ref<level>(lean_get_lmvar_assignment(mctx.to_obj_arg(), mid.to_obj_arg()));
}
class instantiate_lmvar_fn {
metavar_ctx & m_mctx;
std::unordered_map<lean_object *, lean_object *> m_cache;
inline level cache(level const & l, level && r, bool shared) {
if (shared) {
m_cache.insert(mk_pair(l.raw(), r.raw()));
}
return r;
}
public:
instantiate_lmvar_fn(metavar_ctx & mctx):m_mctx(mctx) {}
level visit(level const & l) {
if (!has_mvar(l))
return l;
bool shared = false;
if (is_shared(l)) {
auto it = m_cache.find(l.raw());
if (it != m_cache.end()) {
return level(it->second, true);
}
shared = true;
}
switch (l.kind()) {
case level_kind::Succ:
return cache(l, update_succ(l, visit(succ_of(l))), shared);
case level_kind::Max: case level_kind::IMax:
return cache(l, update_max(l, visit(level_lhs(l)), visit(level_rhs(l))), shared);
case level_kind::Zero: case level_kind::Param:
lean_unreachable();
case level_kind::MVar: {
option_ref<level> r = get_lmvar_assignment(m_mctx, mvar_id(l));
if (!r) {
return l;
} else {
level a(r.get_val());
if (!has_mvar(a)) {
return a;
} else {
level a_new = visit(a);
if (!is_eqp(a, a_new)) {
assign_lmvar(m_mctx, mvar_id(l), a_new);
}
return a_new;
}
}
}}
}
level operator()(level const & l) { return visit(l); }
};
extern "C" LEAN_EXPORT object * lean_instantiate_level_mvars(object * m, object * l) {
metavar_ctx mctx(m);
level l_new = instantiate_lmvar_fn(mctx)(level(l));
object * r = alloc_cnstr(0, 2, 0);
cnstr_set(r, 0, mctx.steal());
cnstr_set(r, 1, l_new.steal());
return r;
}
extern "C" LEAN_EXPORT object * lean_instantiate_expr_mvars(object *, object *) {
lean_internal_panic("not implemented yet");
}
}

View File

@@ -82,6 +82,8 @@ inline bool operator!=(level const & l1, level const & l2) { return !operator==(
struct level_hash { unsigned operator()(level const & n) const { return n.hash(); } };
struct level_eq { bool operator()(level const & n1, level const & n2) const { return n1 == n2; } };
inline bool is_shared(level const & l) { return !is_exclusive(l.raw()); }
inline optional<level> none_level() { return optional<level>(); }
inline optional<level> some_level(level const & e) { return optional<level>(e); }
inline optional<level> some_level(level && e) { return optional<level>(std::forward<level>(e)); }

View File

@@ -35,6 +35,10 @@ public:
s.m_obj = box(0);
return *this;
}
void set_box(object * o) {
lean_assert(is_scalar(m_obj));
m_obj = o;
}
object * raw() const { return m_obj; }
object * steal() { object * r = m_obj; m_obj = box(0); return r; }
object * to_obj_arg() const { inc(m_obj); return m_obj; }

View File

@@ -28,6 +28,8 @@ public:
explicit operator bool() const { return !is_scalar(raw()); }
optional<T> get() const { return *this ? some(static_cast<T const &>(cnstr_get_ref(*this, 0))) : optional<T>(); }
T get_val() const { lean_assert(*this); return static_cast<T const &>(cnstr_get_ref(*this, 0)); }
/** \brief Structural equality. */
friend bool operator==(option_ref const & o1, option_ref const & o2) {
return o1.get() == o2.get();