mirror of
https://github.com/leanprover/lean4.git
synced 2026-03-17 18:34:06 +00:00
chore(library/init,runtime,library/compiler): add fix primitive back
The new `partial def`s allow us to define `fix` in Lean, but the Lean implementation is not as efficient as the native one. The native implementation in C++ uses weak pointers to avoid a closure allocation at every recursive invocation. This commit also fixes the `fixCore` helper functions, which were broken after we switched to camelCase. We have updated the test `fix1.lean` to demonstrate that the native implementation is faster. Here are the numbers on my desktop: ``` ./run.sh fix1.lean 24 721420279 Time for 'native fix': 816ms 721420279 Time for 'fix in lean': 1.34s ```
This commit is contained in:
@@ -6,3 +6,4 @@ Authors: Leonardo de Moura
|
||||
prelude
|
||||
import init.core init.control init.data.basic
|
||||
import init.coe init.wf init.data init.io init.util
|
||||
import init.fix
|
||||
|
||||
80
library/init/fix.lean
Normal file
80
library/init/fix.lean
Normal file
@@ -0,0 +1,80 @@
|
||||
/-
|
||||
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import init.data.uint
|
||||
universe u
|
||||
|
||||
def bfix1 {α β : Type u} (base : α → β) (rec : (α → β) → α → β) : Nat → α → β
|
||||
| 0 a := base a
|
||||
| (n+1) a := rec (bfix1 n) a
|
||||
|
||||
@[extern cpp inline "lean::fixpoint(#4, #5)"]
|
||||
def fixCore1 {α β : Type u} (base : @& (α → β)) (rec : (α → β) → α → β) : α → β :=
|
||||
bfix1 base rec usizeSz
|
||||
|
||||
@[inline] def fixCore {α β : Type u} (base : @& (α → β)) (rec : (α → β) → α → β) : α → β :=
|
||||
fixCore1 base rec
|
||||
|
||||
@[inline] def fix1 {α β : Type u} [Inhabited β] (rec : (α → β) → α → β) : α → β :=
|
||||
fixCore1 (λ _, default β) rec
|
||||
|
||||
@[inline] def fix {α β : Type u} [Inhabited β] (rec : (α → β) → α → β) : α → β :=
|
||||
fixCore1 (λ _, default β) rec
|
||||
|
||||
def bfix2 {α₁ α₂ β : Type u} (base : α₁ → α₂ → β) (rec : (α₁ → α₂ → β) → α₁ → α₂ → β) : Nat → α₁ → α₂ → β
|
||||
| 0 a₁ a₂ := base a₁ a₂
|
||||
| (n+1) a₁ a₂ := rec (bfix2 n) a₁ a₂
|
||||
|
||||
@[extern cpp inline "lean::fixpoint2(#5, #6, #7)"]
|
||||
def fixCore2 {α₁ α₂ β : Type u} (base : α₁ → α₂ → β) (rec : (α₁ → α₂ → β) → α₁ → α₂ → β) : α₁ → α₂ → β :=
|
||||
bfix2 base rec usizeSz
|
||||
|
||||
@[inline] def fix2 {α₁ α₂ β : Type u} [Inhabited β] (rec : (α₁ → α₂ → β) → α₁ → α₂ → β) : α₁ → α₂ → β :=
|
||||
fixCore2 (λ _ _, default β) rec
|
||||
|
||||
def bfix3 {α₁ α₂ α₃ β : Type u} (base : α₁ → α₂ → α₃ → β) (rec : (α₁ → α₂ → α₃ → β) → α₁ → α₂ → α₃ → β) : Nat → α₁ → α₂ → α₃ → β
|
||||
| 0 a₁ a₂ a₃ := base a₁ a₂ a₃
|
||||
| (n+1) a₁ a₂ a₃ := rec (bfix3 n) a₁ a₂ a₃
|
||||
|
||||
@[extern cpp inline "lean::fixpoint3(#6, #7, #8, #9)"]
|
||||
def fixCore3 {α₁ α₂ α₃ β : Type u} (base : α₁ → α₂ → α₃ → β) (rec : (α₁ → α₂ → α₃ → β) → α₁ → α₂ → α₃ → β) : α₁ → α₂ → α₃ → β :=
|
||||
bfix3 base rec usizeSz
|
||||
|
||||
@[inline] def fix3 {α₁ α₂ α₃ β : Type u} [Inhabited β] (rec : (α₁ → α₂ → α₃ → β) → α₁ → α₂ → α₃ → β) : α₁ → α₂ → α₃ → β :=
|
||||
fixCore3 (λ _ _ _, default β) rec
|
||||
|
||||
def bfix4 {α₁ α₂ α₃ α₄ β : Type u} (base : α₁ → α₂ → α₃ → α₄ → β) (rec : (α₁ → α₂ → α₃ → α₄ → β) → α₁ → α₂ → α₃ → α₄ → β) : Nat → α₁ → α₂ → α₃ → α₄ → β
|
||||
| 0 a₁ a₂ a₃ a₄ := base a₁ a₂ a₃ a₄
|
||||
| (n+1) a₁ a₂ a₃ a₄ := rec (bfix4 n) a₁ a₂ a₃ a₄
|
||||
|
||||
@[extern cpp inline "lean::fixpoint4(#7, #8, #9, #10, #11)"]
|
||||
def fixCore4 {α₁ α₂ α₃ α₄ β : Type u} (base : α₁ → α₂ → α₃ → α₄ → β) (rec : (α₁ → α₂ → α₃ → α₄ → β) → α₁ → α₂ → α₃ → α₄ → β) : α₁ → α₂ → α₃ → α₄ → β :=
|
||||
bfix4 base rec usizeSz
|
||||
|
||||
@[inline] def fix4 {α₁ α₂ α₃ α₄ β : Type u} [Inhabited β] (rec : (α₁ → α₂ → α₃ → α₄ → β) → α₁ → α₂ → α₃ → α₄ → β) : α₁ → α₂ → α₃ → α₄ → β :=
|
||||
fixCore4 (λ _ _ _ _, default β) rec
|
||||
|
||||
def bfix5 {α₁ α₂ α₃ α₄ α₅ β : Type u} (base : α₁ → α₂ → α₃ → α₄ → α₅ → β) (rec : (α₁ → α₂ → α₃ → α₄ → α₅ → β) → α₁ → α₂ → α₃ → α₄ → α₅ → β) : Nat → α₁ → α₂ → α₃ → α₄ → α₅ → β
|
||||
| 0 a₁ a₂ a₃ a₄ a₅ := base a₁ a₂ a₃ a₄ a₅
|
||||
| (n+1) a₁ a₂ a₃ a₄ a₅ := rec (bfix5 n) a₁ a₂ a₃ a₄ a₅
|
||||
|
||||
@[extern cpp inline "lean::fixpoint5(#8, #9, #10, #11, #12, #13)"]
|
||||
def fixCore5 {α₁ α₂ α₃ α₄ α₅ β : Type u} (base : α₁ → α₂ → α₃ → α₄ → α₅ → β) (rec : (α₁ → α₂ → α₃ → α₄ → α₅ → β) → α₁ → α₂ → α₃ → α₄ → α₅ → β) : α₁ → α₂ → α₃ → α₄ → α₅ → β :=
|
||||
bfix5 base rec usizeSz
|
||||
|
||||
@[inline] def fix5 {α₁ α₂ α₃ α₄ α₅ β : Type u} [Inhabited β] (rec : (α₁ → α₂ → α₃ → α₄ → α₅ → β) → α₁ → α₂ → α₃ → α₄ → α₅ → β) : α₁ → α₂ → α₃ → α₄ → α₅ → β :=
|
||||
fixCore5 (λ _ _ _ _ _, default β) rec
|
||||
|
||||
def bfix6 {α₁ α₂ α₃ α₄ α₅ α₆ β : Type u} (base : α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β) (rec : (α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β) → α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β) : Nat → α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β
|
||||
| 0 a₁ a₂ a₃ a₄ a₅ a₆ := base a₁ a₂ a₃ a₄ a₅ a₆
|
||||
| (n+1) a₁ a₂ a₃ a₄ a₅ a₆ := rec (bfix6 n) a₁ a₂ a₃ a₄ a₅ a₆
|
||||
|
||||
@[extern cpp inline "lean::fixpoint6(#9, #10, #11, #12, #13, #14, #15)"]
|
||||
def fixCore6 {α₁ α₂ α₃ α₄ α₅ α₆ β : Type u} (base : α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β) (rec : (α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β) → α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β) : α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β :=
|
||||
bfix6 base rec usizeSz
|
||||
|
||||
@[inline] def fix6 {α₁ α₂ α₃ α₄ α₅ α₆ β : Type u} [Inhabited β] (rec : (α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β) → α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β) : α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β :=
|
||||
fixCore6 (λ _ _ _ _ _ _, default β) rec
|
||||
@@ -6,7 +6,7 @@ Author: Sebastian Ullrich
|
||||
Recursion monad transformer
|
||||
-/
|
||||
prelude
|
||||
import init.control.reader init.lean.parser.parsec
|
||||
import init.control.reader init.lean.parser.parsec init.fix
|
||||
|
||||
namespace Lean.Parser
|
||||
|
||||
@@ -23,13 +23,10 @@ local attribute [reducible] RecT
|
||||
@[inline] def recurse (a : α) : RecT α δ m δ :=
|
||||
λ f, f a
|
||||
|
||||
@[specialize] private partial def runAux : m δ → (α → RecT α δ m δ) → α → m δ
|
||||
| b rec a := rec a (runAux b rec)
|
||||
|
||||
/-- Execute `x`, executing `rec a` whenever `recurse a` is called.
|
||||
After `maxRec` recursion steps, `base` is executed instead. -/
|
||||
@[inline] protected def run (x : RecT α δ m β) (base : Unit → m δ) (rec : α → RecT α δ m δ) : m β :=
|
||||
x.run (runAux (base ()) rec)
|
||||
@[inline] protected def run (x : RecT α δ m β) (base : α → m δ) (rec : α → RecT α δ m δ) : m β :=
|
||||
x (fixCore base (λ a f, rec f a))
|
||||
|
||||
@[inline] protected def runParsec {γ : Type} [MonadParsec γ m] (x : RecT α δ m β) (rec : α → RecT α δ m δ) : m β :=
|
||||
RecT.run x (λ _, MonadParsec.error "RecT.runParsec: no progress") rec
|
||||
|
||||
@@ -1377,6 +1377,46 @@ class csimp_fn {
|
||||
return mk_app(mk_constant(get_nat_add_name()), arg, mk_lit(literal(nat(1))));
|
||||
}
|
||||
|
||||
/*
|
||||
Replace `fixCore<n> f a_1 ... a_m`
|
||||
with `fixCore<m> f a_1 ... a_m` whenever `n < m`.
|
||||
This optimization is for writing reusable/generic code. For
|
||||
example, we cannot write an efficient `rec_t` monad transformer
|
||||
without it because we don't know the arity of `m A` when we write `rec_t`.
|
||||
Remark: the runtime provides a small set of `fixCore<i>` implementations (`i in [1, 6]`).
|
||||
This method does nothing if `m > 6`. */
|
||||
expr visit_fix_core(expr const & e, unsigned n) {
|
||||
if (m_before_erasure) return visit_app_default(e);
|
||||
buffer<expr> args;
|
||||
expr fn = get_app_args(e, args);
|
||||
lean_assert(is_constant(fn) && is_fix_core(const_name(fn)));
|
||||
unsigned arity =
|
||||
n + /* α_1 ... α_n Type arguments */
|
||||
1 + /* β : Type */
|
||||
1 + /* (base : α_1 → ... → α_n → β) */
|
||||
1 + /* (rec : (α_1 → ... → α_n → β) → α_1 → ... → α_n → β) */
|
||||
n; /* α_1 → ... → α_n */
|
||||
if (args.size() <= arity) return visit_app_default(e);
|
||||
/* This `fixCore<n>` application is an overapplication.
|
||||
The `fixCore<n>` is implemented by the runtime, and the result
|
||||
is a closure. This is bad for performance. We should
|
||||
replace it with `fixCore<m>` (if the runtime contains one) */
|
||||
unsigned num_extra = args.size() - arity;
|
||||
unsigned m = n + num_extra;
|
||||
optional<expr> fix_core_m = mk_enf_fix_core(m);
|
||||
if (!fix_core_m) return visit_app_default(e);
|
||||
buffer<expr> new_args;
|
||||
/* Add α_1 ... α_m and β (the `m` type arguments of the replacement `fixCore<m>`, plus the result type) */
|
||||
for (unsigned i = 0; i < m+1; i++) {
|
||||
new_args.push_back(mk_enf_neutral());
|
||||
}
|
||||
/* `(base : α_1 → ... → α_n → β)` is not used in the runtime primitive.
|
||||
So, we replace it with a neutral value :) */
|
||||
new_args.push_back(mk_enf_neutral());
|
||||
new_args.append(args.size() - n - 2, args.data() + n + 2);
|
||||
return mk_app(*fix_core_m, new_args);
|
||||
}
|
||||
|
||||
expr visit_app(expr const & e, bool is_let_val) {
|
||||
if (is_cases_on_app(env(), e)) {
|
||||
return visit_cases(e, is_let_val);
|
||||
@@ -1417,6 +1457,8 @@ class csimp_fn {
|
||||
return mk_lit(literal(nat(0)));
|
||||
} else if (optional<expr> r = try_inline(fn, e, is_let_val)) {
|
||||
return *r;
|
||||
} else if (optional<unsigned> i = is_fix_core(n)) {
|
||||
return visit_fix_core(e, *i);
|
||||
} else {
|
||||
return visit_app_default(e);
|
||||
}
|
||||
|
||||
@@ -516,15 +516,17 @@ optional<nat> get_num_lit_ext(expr const & e) {
|
||||
optional<unsigned> is_fix_core(name const & n) {
|
||||
if (!n.is_atomic() || !n.is_string()) return optional<unsigned>();
|
||||
string_ref const & r = n.get_string();
|
||||
if (r.length() != 10) return optional<unsigned>();
|
||||
if (r.length() != 8) return optional<unsigned>();
|
||||
char const * s = r.data();
|
||||
if (std::strncmp(s, "fix_core_", 9) != 0 || !std::isdigit(s[9])) return optional<unsigned>();
|
||||
return optional<unsigned>(s[9] - '0');
|
||||
if (std::strncmp(s, "fixCore", 7) != 0 || !std::isdigit(s[7])) return optional<unsigned>();
|
||||
return optional<unsigned>(s[7] - '0');
|
||||
}
|
||||
|
||||
optional<expr> mk_enf_fix_core(unsigned n) {
|
||||
if (n == 0 || n > 6) return none_expr();
|
||||
return some_expr(mk_constant(name("fix_core").append_after(n)));
|
||||
std::ostringstream s;
|
||||
s << "fixCore" << n;
|
||||
return some_expr(mk_constant(name(s.str())));
|
||||
}
|
||||
|
||||
void initialize_compiler_util() {
|
||||
|
||||
@@ -1756,6 +1756,107 @@ object * array_push(obj_arg a, obj_arg v) {
|
||||
return r;
|
||||
}
|
||||
|
||||
// =======================================
|
||||
// fixpoint
|
||||
|
||||
static inline object * ptr_to_weak_ptr(object * p) {
|
||||
return reinterpret_cast<object*>(reinterpret_cast<uintptr_t>(p) | 1);
|
||||
}
|
||||
|
||||
static inline object * weak_ptr_to_ptr(object * w) {
|
||||
return reinterpret_cast<object*>((reinterpret_cast<uintptr_t>(w) >> 1) << 1);
|
||||
}
|
||||
|
||||
obj_res fixpoint_aux(obj_arg rec, obj_arg weak_k, obj_arg a) {
|
||||
object * k = weak_ptr_to_ptr(weak_k);
|
||||
inc(k);
|
||||
return apply_2(rec, k, a);
|
||||
}
|
||||
|
||||
obj_res fixpoint(obj_arg rec, obj_arg a) {
|
||||
object * k = alloc_closure(fixpoint_aux, 2);
|
||||
inc(rec);
|
||||
closure_set(k, 0, rec);
|
||||
closure_set(k, 1, ptr_to_weak_ptr(k));
|
||||
object * r = apply_2(rec, k, a);
|
||||
return r;
|
||||
}
|
||||
|
||||
obj_res fixpoint_aux2(obj_arg rec, obj_arg weak_k, obj_arg a1, obj_arg a2) {
|
||||
object * k = weak_ptr_to_ptr(weak_k);
|
||||
inc(k);
|
||||
return apply_3(rec, k, a1, a2);
|
||||
}
|
||||
|
||||
obj_res fixpoint2(obj_arg rec, obj_arg a1, obj_arg a2) {
|
||||
object * k = alloc_closure(fixpoint_aux2, 2);
|
||||
inc(rec);
|
||||
closure_set(k, 0, rec);
|
||||
closure_set(k, 1, ptr_to_weak_ptr(k));
|
||||
object * r = apply_3(rec, k, a1, a2);
|
||||
return r;
|
||||
}
|
||||
|
||||
obj_res fixpoint_aux3(obj_arg rec, obj_arg weak_k, obj_arg a1, obj_arg a2, obj_arg a3) {
|
||||
object * k = weak_ptr_to_ptr(weak_k);
|
||||
inc(k);
|
||||
return apply_4(rec, k, a1, a2, a3);
|
||||
}
|
||||
|
||||
obj_res fixpoint3(obj_arg rec, obj_arg a1, obj_arg a2, obj_arg a3) {
|
||||
object * k = alloc_closure(fixpoint_aux3, 2);
|
||||
inc(rec);
|
||||
closure_set(k, 0, rec);
|
||||
closure_set(k, 1, ptr_to_weak_ptr(k));
|
||||
object * r = apply_4(rec, k, a1, a2, a3);
|
||||
return r;
|
||||
}
|
||||
|
||||
obj_res fixpoint_aux4(obj_arg rec, obj_arg weak_k, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4) {
|
||||
object * k = weak_ptr_to_ptr(weak_k);
|
||||
inc(k);
|
||||
return apply_5(rec, k, a1, a2, a3, a4);
|
||||
}
|
||||
|
||||
obj_res fixpoint4(obj_arg rec, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4) {
|
||||
object * k = alloc_closure(fixpoint_aux4, 2);
|
||||
inc(rec);
|
||||
closure_set(k, 0, rec);
|
||||
closure_set(k, 1, ptr_to_weak_ptr(k));
|
||||
object * r = apply_5(rec, k, a1, a2, a3, a4);
|
||||
return r;
|
||||
}
|
||||
|
||||
obj_res fixpoint_aux5(obj_arg rec, obj_arg weak_k, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4, obj_arg a5) {
|
||||
object * k = weak_ptr_to_ptr(weak_k);
|
||||
inc(k);
|
||||
return apply_6(rec, k, a1, a2, a3, a4, a5);
|
||||
}
|
||||
|
||||
obj_res fixpoint5(obj_arg rec, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4, obj_arg a5) {
|
||||
object * k = alloc_closure(fixpoint_aux5, 2);
|
||||
inc(rec);
|
||||
closure_set(k, 0, rec);
|
||||
closure_set(k, 1, ptr_to_weak_ptr(k));
|
||||
object * r = apply_6(rec, k, a1, a2, a3, a4, a5);
|
||||
return r;
|
||||
}
|
||||
|
||||
obj_res fixpoint_aux6(obj_arg rec, obj_arg weak_k, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4, obj_arg a5, obj_arg a6) {
|
||||
object * k = weak_ptr_to_ptr(weak_k);
|
||||
inc(k);
|
||||
return apply_7(rec, k, a1, a2, a3, a4, a5, a6);
|
||||
}
|
||||
|
||||
obj_res fixpoint6(obj_arg rec, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4, obj_arg a5, obj_arg a6) {
|
||||
object * k = alloc_closure(fixpoint_aux6, 2);
|
||||
inc(rec);
|
||||
closure_set(k, 0, rec);
|
||||
closure_set(k, 1, ptr_to_weak_ptr(k));
|
||||
object * r = apply_7(rec, k, a1, a2, a3, a4, a5, a6);
|
||||
return r;
|
||||
}
|
||||
|
||||
// =======================================
|
||||
// Debugging helper functions
|
||||
|
||||
|
||||
@@ -652,6 +652,16 @@ inline obj_res alloc_closure(object*(*fun)(object *, object *, object *, object
|
||||
return alloc_closure(reinterpret_cast<void*>(fun), 8, num_fixed);
|
||||
}
|
||||
|
||||
// =======================================
|
||||
// Fixpoint
|
||||
|
||||
obj_res fixpoint(obj_arg rec, obj_arg a);
|
||||
obj_res fixpoint2(obj_arg rec, obj_arg a1, obj_arg a2);
|
||||
obj_res fixpoint3(obj_arg rec, obj_arg a1, obj_arg a2, obj_arg a3);
|
||||
obj_res fixpoint4(obj_arg rec, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4);
|
||||
obj_res fixpoint5(obj_arg rec, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4, obj_arg a5);
|
||||
obj_res fixpoint6(obj_arg rec, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4, obj_arg a5, obj_arg a6);
|
||||
|
||||
// =======================================
|
||||
// Array of objects
|
||||
|
||||
|
||||
21
tests/playground/fix1.lean
Normal file
21
tests/playground/fix1.lean
Normal file
@@ -0,0 +1,21 @@
|
||||
def foo (rec : Nat → Nat → Nat) : Nat → Nat → Nat
|
||||
| 0 a := a
|
||||
| (n+1) a := rec n a + a + rec n (a+1)
|
||||
|
||||
partial def fix' (f: (Nat → Nat → Nat) → (Nat → Nat → Nat)) : Nat → Nat → Nat
|
||||
| a b := f fix' a b
|
||||
|
||||
def prof {α : Type} (msg : String) (p : IO α) : IO α :=
|
||||
let msg := "Time for '" ++ msg ++ "':" in
|
||||
timeit msg p
|
||||
|
||||
def fix_test (n : Nat) : IO Unit :=
|
||||
IO.println (fix foo n 10)
|
||||
|
||||
def fix'_test (n : Nat) : IO Unit :=
|
||||
IO.println (fix' foo n 10)
|
||||
|
||||
def main (xs : List String) : IO Unit :=
|
||||
prof "native fix" (fix_test xs.head.toNat) *>
|
||||
prof "fix in lean" (fix'_test xs.head.toNat) *>
|
||||
pure ()
|
||||
Reference in New Issue
Block a user