chore(library/init,runtime,library/compiler): add fix primitive back

The new `partial def`s allow us to define `fix` in Lean, but the Lean
implementation is not as efficient as the native one. The native
implementation in C++ uses weak pointers to avoid allocating a closure
at every recursive invocation.

This commit also fixes the `fixCore` helper functions that were broken
after we switched to camelCase.

We have updated the test `fix1.lean` to demonstrate that the native
implementation is faster. Here are the numbers measured on my desktop:

```
./run.sh fix1.lean 24
721420279
Time for 'native fix': 816ms
721420279
Time for 'fix in lean': 1.34s
```
This commit is contained in:
Leonardo de Moura
2019-03-27 17:13:53 -07:00
parent cd21793b53
commit 42fbe3c18c
8 changed files with 264 additions and 10 deletions

View File

@@ -6,3 +6,4 @@ Authors: Leonardo de Moura
prelude
import init.core init.control init.data.basic
import init.coe init.wf init.data init.io init.util
import init.fix

80
library/init/fix.lean Normal file
View File

@@ -0,0 +1,80 @@
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import init.data.uint
universe u
/-- Bounded fixpoint for unary functions: `bfix1 base rec n a` unfolds `rec` at most
`n` times and falls back to `base` once the counter reaches `0`. -/
def bfix1 {α β : Type u} (base : α → β) (rec : (α → β) → α → β) : Nat → α → β
| 0 a := base a
| (n+1) a := rec (bfix1 n) a
/-- Unary fixpoint primitive. The reference definition is `bfix1` bounded by `usizeSz`;
compiled code uses the native `lean::fixpoint` instead, which is given only `rec`
and the argument (positions `#4` and `#5`), not `base`. -/
@[extern cpp inline "lean::fixpoint(#4, #5)"]
def fixCore1 {α β : Type u} (base : @& (α → β)) (rec : (α → β) → α → β) : α → β :=
bfix1 base rec usizeSz
/-- Arity-agnostic name for the fixpoint primitive; simply forwards to `fixCore1`. -/
@[inline] def fixCore {α β : Type u} (base : @& (α → β)) (rec : (α → β) → α → β) : α → β :=
fixCore1 base rec
/-- Unary fixpoint of `rec`, using the constant function `default β` as the base case
passed to `fixCore1`. -/
@[inline] def fix1 {α β : Type u} [Inhabited β] (rec : (α → β) → α → β) : α → β :=
fixCore1 (λ _, default β) rec
/-- Fixpoint of `rec`; same definition as `fix1` (base case `default β`, delegating to
`fixCore1`). -/
@[inline] def fix {α β : Type u} [Inhabited β] (rec : (α → β) → α → β) : α → β :=
fixCore1 (λ _, default β) rec
/-- Bounded fixpoint for binary functions: unfolds `rec` at most `n` times, then
falls back to `base`. -/
def bfix2 {α₁ α₂ β : Type u} (base : α₁ → α₂ → β) (rec : (α₁ → α₂ → β) → α₁ → α₂ → β) : Nat → α₁ → α₂ → β
| 0 a₁ a₂ := base a₁ a₂
| (n+1) a₁ a₂ := rec (bfix2 n) a₁ a₂
/-- Binary fixpoint primitive. The reference definition is `bfix2` bounded by `usizeSz`;
compiled code uses the native `lean::fixpoint2`, which is given `rec` and the two
arguments (positions `#5`–`#7`), not `base`. -/
@[extern cpp inline "lean::fixpoint2(#5, #6, #7)"]
def fixCore2 {α₁ α₂ β : Type u} (base : α₁ → α₂ → β) (rec : (α₁ → α₂ → β) → α₁ → α₂ → β) : α₁ → α₂ → β :=
bfix2 base rec usizeSz
/-- Binary fixpoint of `rec`, using the constant function `default β` as the base case
passed to `fixCore2`. -/
@[inline] def fix2 {α₁ α₂ β : Type u} [Inhabited β] (rec : (α₁ → α₂ → β) → α₁ → α₂ → β) : α₁ → α₂ → β :=
fixCore2 (λ _ _, default β) rec
/-- Bounded fixpoint for ternary functions: unfolds `rec` at most `n` times, then
falls back to `base`. -/
def bfix3 {α₁ α₂ α₃ β : Type u} (base : α₁ → α₂ → α₃ → β) (rec : (α₁ → α₂ → α₃ → β) → α₁ → α₂ → α₃ → β) : Nat → α₁ → α₂ → α₃ → β
| 0 a₁ a₂ a₃ := base a₁ a₂ a₃
| (n+1) a₁ a₂ a₃ := rec (bfix3 n) a₁ a₂ a₃
/-- Ternary fixpoint primitive. The reference definition is `bfix3` bounded by `usizeSz`;
compiled code uses the native `lean::fixpoint3`, which is given `rec` and the three
arguments (positions `#6`–`#9`), not `base`. -/
@[extern cpp inline "lean::fixpoint3(#6, #7, #8, #9)"]
def fixCore3 {α₁ α₂ α₃ β : Type u} (base : α₁ → α₂ → α₃ → β) (rec : (α₁ → α₂ → α₃ → β) → α₁ → α₂ → α₃ → β) : α₁ → α₂ → α₃ → β :=
bfix3 base rec usizeSz
/-- Ternary fixpoint of `rec`, using the constant function `default β` as the base case
passed to `fixCore3`. -/
@[inline] def fix3 {α₁ α₂ α₃ β : Type u} [Inhabited β] (rec : (α₁ → α₂ → α₃ → β) → α₁ → α₂ → α₃ → β) : α₁ → α₂ → α₃ → β :=
fixCore3 (λ _ _ _, default β) rec
/-- Bounded fixpoint for 4-ary functions: unfolds `rec` at most `n` times, then
falls back to `base`. -/
def bfix4 {α₁ α₂ α₃ α₄ β : Type u} (base : α₁ → α₂ → α₃ → α₄ → β) (rec : (α₁ → α₂ → α₃ → α₄ → β) → α₁ → α₂ → α₃ → α₄ → β) : Nat → α₁ → α₂ → α₃ → α₄ → β
| 0 a₁ a₂ a₃ a₄ := base a₁ a₂ a₃ a₄
| (n+1) a₁ a₂ a₃ a₄ := rec (bfix4 n) a₁ a₂ a₃ a₄
/-- 4-ary fixpoint primitive. The reference definition is `bfix4` bounded by `usizeSz`;
compiled code uses the native `lean::fixpoint4`, which is given `rec` and the four
arguments (positions `#7`–`#11`), not `base`. -/
@[extern cpp inline "lean::fixpoint4(#7, #8, #9, #10, #11)"]
def fixCore4 {α₁ α₂ α₃ α₄ β : Type u} (base : α₁ → α₂ → α₃ → α₄ → β) (rec : (α₁ → α₂ → α₃ → α₄ → β) → α₁ → α₂ → α₃ → α₄ → β) : α₁ → α₂ → α₃ → α₄ → β :=
bfix4 base rec usizeSz
/-- 4-ary fixpoint of `rec`, using the constant function `default β` as the base case
passed to `fixCore4`. -/
@[inline] def fix4 {α₁ α₂ α₃ α₄ β : Type u} [Inhabited β] (rec : (α₁ → α₂ → α₃ → α₄ → β) → α₁ → α₂ → α₃ → α₄ → β) : α₁ → α₂ → α₃ → α₄ → β :=
fixCore4 (λ _ _ _ _, default β) rec
/-- Bounded fixpoint for 5-ary functions: unfolds `rec` at most `n` times, then
falls back to `base`. -/
def bfix5 {α₁ α₂ α₃ α₄ α₅ β : Type u} (base : α₁ → α₂ → α₃ → α₄ → α₅ → β) (rec : (α₁ → α₂ → α₃ → α₄ → α₅ → β) → α₁ → α₂ → α₃ → α₄ → α₅ → β) : Nat → α₁ → α₂ → α₃ → α₄ → α₅ → β
| 0 a₁ a₂ a₃ a₄ a₅ := base a₁ a₂ a₃ a₄ a₅
| (n+1) a₁ a₂ a₃ a₄ a₅ := rec (bfix5 n) a₁ a₂ a₃ a₄ a₅
/-- 5-ary fixpoint primitive. The reference definition is `bfix5` bounded by `usizeSz`;
compiled code uses the native `lean::fixpoint5`, which is given `rec` and the five
arguments (positions `#8`–`#13`), not `base`. -/
@[extern cpp inline "lean::fixpoint5(#8, #9, #10, #11, #12, #13)"]
def fixCore5 {α₁ α₂ α₃ α₄ α₅ β : Type u} (base : α₁ → α₂ → α₃ → α₄ → α₅ → β) (rec : (α₁ → α₂ → α₃ → α₄ → α₅ → β) → α₁ → α₂ → α₃ → α₄ → α₅ → β) : α₁ → α₂ → α₃ → α₄ → α₅ → β :=
bfix5 base rec usizeSz
/-- 5-ary fixpoint of `rec`, using the constant function `default β` as the base case
passed to `fixCore5`. -/
@[inline] def fix5 {α₁ α₂ α₃ α₄ α₅ β : Type u} [Inhabited β] (rec : (α₁ → α₂ → α₃ → α₄ → α₅ → β) → α₁ → α₂ → α₃ → α₄ → α₅ → β) : α₁ → α₂ → α₃ → α₄ → α₅ → β :=
fixCore5 (λ _ _ _ _ _, default β) rec
/-- Bounded fixpoint for 6-ary functions: unfolds `rec` at most `n` times, then
falls back to `base`. -/
def bfix6 {α₁ α₂ α₃ α₄ α₅ α₆ β : Type u} (base : α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β) (rec : (α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β) → α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β) : Nat → α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β
| 0 a₁ a₂ a₃ a₄ a₅ a₆ := base a₁ a₂ a₃ a₄ a₅ a₆
| (n+1) a₁ a₂ a₃ a₄ a₅ a₆ := rec (bfix6 n) a₁ a₂ a₃ a₄ a₅ a₆
/-- 6-ary fixpoint primitive. The reference definition is `bfix6` bounded by `usizeSz`;
compiled code uses the native `lean::fixpoint6`, which is given `rec` and the six
arguments (positions `#9`–`#15`), not `base`. -/
@[extern cpp inline "lean::fixpoint6(#9, #10, #11, #12, #13, #14, #15)"]
def fixCore6 {α₁ α₂ α₃ α₄ α₅ α₆ β : Type u} (base : α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β) (rec : (α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β) → α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β) : α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β :=
bfix6 base rec usizeSz
/-- 6-ary fixpoint of `rec`, using the constant function `default β` as the base case
passed to `fixCore6`. -/
@[inline] def fix6 {α₁ α₂ α₃ α₄ α₅ α₆ β : Type u} [Inhabited β] (rec : (α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β) → α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β) : α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β :=
fixCore6 (λ _ _ _ _ _ _, default β) rec

View File

@@ -6,7 +6,7 @@ Author: Sebastian Ullrich
Recursion monad transformer
-/
prelude
import init.control.reader init.lean.parser.parsec
import init.control.reader init.lean.parser.parsec init.fix
namespace Lean.Parser
@@ -23,13 +23,10 @@ local attribute [reducible] RecT
/-- `recurse a` is the `RecT` action that hands `a` to the function `f` it is run with
and returns that function's result. -/
@[inline] def recurse (a : α) : RecT α δ m δ :=
λ f, f a
/-- Ties the recursive knot for `RecT`: runs `rec a`, supplying `runAux b rec` itself as
the handler for nested `recurse` calls. The first argument `b` is only threaded through
to the recursive call. -/
@[specialize] private partial def runAux : m δ → (α → RecT α δ m δ) → α → m δ
| b rec a := rec a (runAux b rec)
/-- Execute `x`, executing `rec a` whenever `recurse a` is called.
After `maxRec` recursion steps, `base` is executed instead. -/
-- NOTE(review): two definitions of `run` follow. This listing looks like a diff rendering
-- without +/- markers: the first takes `base : Unit → m δ` and uses `runAux`, the second
-- takes `base : α → m δ` and uses `fixCore`. Confirm which one is current.
@[inline] protected def run (x : RecT α δ m β) (base : Unit → m δ) (rec : α → RecT α δ m δ) : m β :=
x.run (runAux (base ()) rec)
@[inline] protected def run (x : RecT α δ m β) (base : α → m δ) (rec : α → RecT α δ m δ) : m β :=
x (fixCore base (λ a f, rec f a))
/-- Variant of `RecT.run` for `MonadParsec` monads: the base case raises the parser
error "RecT.runParsec: no progress". -/
@[inline] protected def runParsec {γ : Type} [MonadParsec γ m] (x : RecT α δ m β) (rec : α RecT α δ m δ) : m β :=
RecT.run x (λ _, MonadParsec.error "RecT.runParsec: no progress") rec

View File

@@ -1377,6 +1377,46 @@ class csimp_fn {
return mk_app(mk_constant(get_nat_add_name()), arg, mk_lit(literal(nat(1))));
}
/*
Replace `fixCore<n> f a_1 ... a_m`
with `fixCore<m> f a_1 ... a_m` whenever `n < m`.
This optimization is for writing reusable/generic code. For
example, we cannot write an efficient `rec_t` monad transformer
without it because we don't know the arity of `m A` when we write `rec_t`.
Remark: the runtime provides a small set of `fixCore<i>` implementations (`i in [1, 6]`).
This method does nothing if `m > 6`.

`e` is the application being visited and `n` is the arity encoded in the
name of its head constant. Applications that are not overapplied, and all
applications while `m_before_erasure` is set, go through `visit_app_default`. */
expr visit_fix_core(expr const & e, unsigned n) {
if (m_before_erasure) return visit_app_default(e);
buffer<expr> args;
expr fn = get_app_args(e, args);
lean_assert(is_constant(fn) && is_fix_core(const_name(fn)));
/* Number of arguments of a fully applied `fixCore<n>`: 2*n + 3. */
unsigned arity =
n + /* α_1 ... α_n Type arguments */
1 + /* β : Type */
1 + /* (base : α_1 → ... → α_n → β) */
1 + /* (rec : (α_1 → ... → α_n → β) → α_1 → ... → α_n → β) */
n; /* α_1 → ... → α_n */
if (args.size() <= arity) return visit_app_default(e);
/* This `fixCore<n>` application is an overapplication.
The `fixCore<n>` is implemented by the runtime, and the result
is a closure. This is bad for performance. We should
replace it with `fixCore<m>` (if the runtime contains one) */
unsigned num_extra = args.size() - arity;
unsigned m = n + num_extra;
/* `mk_enf_fix_core` yields none when no `fixCore<m>` constant is available. */
optional<expr> fix_core_m = mk_enf_fix_core(m);
if (!fix_core_m) return visit_app_default(e);
buffer<expr> new_args;
/* Add the m + 1 type arguments of `fixCore<m>` (α_1 ... α_m and β) as neutral values */
for (unsigned i = 0; i < m+1; i++) {
new_args.push_back(mk_enf_neutral());
}
/* `(base : α_1 → ... → α_n → β)` is not used in the runtime primitive.
So, we replace it with a neutral value :) */
new_args.push_back(mk_enf_neutral());
/* Skip the original n + 1 type arguments and `base` (indices 0 .. n+1), then copy
`rec` (index n+2) followed by every value argument, including the extra ones. */
new_args.append(args.size() - n - 2, args.data() + n + 2);
return mk_app(*fix_core_m, new_args);
}
expr visit_app(expr const & e, bool is_let_val) {
if (is_cases_on_app(env(), e)) {
return visit_cases(e, is_let_val);
@@ -1417,6 +1457,8 @@ class csimp_fn {
return mk_lit(literal(nat(0)));
} else if (optional<expr> r = try_inline(fn, e, is_let_val)) {
return *r;
} else if (optional<unsigned> i = is_fix_core(n)) {
return visit_fix_core(e, *i);
} else {
return visit_app_default(e);
}

View File

@@ -516,15 +516,17 @@ optional<nat> get_num_lit_ext(expr const & e) {
/* Decode the arity digit of a fixpoint-primitive constant name.
   Returns none unless `n` is an atomic string name; on a match the result is
   the numeric value of the single decimal digit that follows the prefix.
   NOTE(review): two variants of each check appear back to back below
   (`fix_core_<i>`: length 10, 9-character prefix; `fixCore<i>`: length 8,
   7-character prefix). This looks like a diff rendering that lost its +/-
   markers, with the snake_case lines being the superseded ones. Confirm
   against the repository before treating this text as compilable. */
optional<unsigned> is_fix_core(name const & n) {
if (!n.is_atomic() || !n.is_string()) return optional<unsigned>();
string_ref const & r = n.get_string();
if (r.length() != 10) return optional<unsigned>();
if (r.length() != 8) return optional<unsigned>();
char const * s = r.data();
if (std::strncmp(s, "fix_core_", 9) != 0 || !std::isdigit(s[9])) return optional<unsigned>();
return optional<unsigned>(s[9] - '0');
if (std::strncmp(s, "fixCore", 7) != 0 || !std::isdigit(s[7])) return optional<unsigned>();
return optional<unsigned>(s[7] - '0');
}
/* Build the constant for the fixpoint primitive of arity `n`.
   Returns none when `n` is 0 or greater than 6.
   NOTE(review): two ways of building the name appear below (`fix_core` with
   `append_after(n)`, and an `ostringstream` producing `fixCore<n>`). This
   looks like a diff rendering that lost its +/- markers, with the first
   `return` being the superseded line. Confirm against the repository. */
optional<expr> mk_enf_fix_core(unsigned n) {
if (n == 0 || n > 6) return none_expr();
return some_expr(mk_constant(name("fix_core").append_after(n)));
std::ostringstream s;
s << "fixCore" << n;
return some_expr(mk_constant(name(s.str())));
}
void initialize_compiler_util() {

View File

@@ -1756,6 +1756,107 @@ object * array_push(obj_arg a, obj_arg v) {
return r;
}
// =======================================
// fixpoint
/* Tag `p` as a "weak" pointer by setting the least-significant bit of its address. */
static inline object * ptr_to_weak_ptr(object * p) {
    uintptr_t const bits = reinterpret_cast<uintptr_t>(p);
    return reinterpret_cast<object*>(bits | 1);
}
/* Inverse of `ptr_to_weak_ptr`: clear the least-significant bit of `w` to
   recover the untagged object address. */
static inline object * weak_ptr_to_ptr(object * w) {
    uintptr_t const tagged = reinterpret_cast<uintptr_t>(w);
    uintptr_t const untagged = tagged & ~static_cast<uintptr_t>(1);
    return reinterpret_cast<object*>(untagged);
}
/* Code of the self-referential closure created by `fixpoint`.
   `rec` and `weak_k` are the closure's two fixed arguments (slots 0 and 1);
   `a` is the argument of the recursive call. The closure is recovered from
   its tagged self pointer, `inc`ed, and passed to `rec` together with `a`. */
obj_res fixpoint_aux(obj_arg rec, obj_arg weak_k, obj_arg a) {
    object * self = weak_ptr_to_ptr(weak_k);
    inc(self);
    return apply_2(rec, self, a);
}
/* Native unary fixpoint: computes `rec k a`, where `k` is a closure over
   `fixpoint_aux` that re-enters `rec` on every recursive call.
   Slot 0 of the closure holds `rec` (after `inc`), slot 1 holds the closure's
   own address tagged by `ptr_to_weak_ptr`.
   NOTE(review): presumably the tag keeps the self-reference from being
   reference-counted; confirm against the runtime's object representation. */
obj_res fixpoint(obj_arg rec, obj_arg a) {
    object * self = alloc_closure(fixpoint_aux, 2);
    inc(rec);
    closure_set(self, 0, rec);
    closure_set(self, 1, ptr_to_weak_ptr(self));
    return apply_2(rec, self, a);
}
/* Code of the self-referential closure created by `fixpoint2`.
   `rec` and `weak_k` are the closure's fixed arguments; `a1`, `a2` are the
   arguments of the recursive call. The closure is recovered from its tagged
   self pointer, `inc`ed, and passed to `rec` ahead of the arguments. */
obj_res fixpoint_aux2(obj_arg rec, obj_arg weak_k, obj_arg a1, obj_arg a2) {
    object * self = weak_ptr_to_ptr(weak_k);
    inc(self);
    return apply_3(rec, self, a1, a2);
}
/* Native binary fixpoint: computes `rec k a1 a2`, where `k` is a closure over
   `fixpoint_aux2` whose slot 0 holds `rec` (after `inc`) and whose slot 1
   holds the closure's own address tagged by `ptr_to_weak_ptr`. */
obj_res fixpoint2(obj_arg rec, obj_arg a1, obj_arg a2) {
    object * self = alloc_closure(fixpoint_aux2, 2);
    inc(rec);
    closure_set(self, 0, rec);
    closure_set(self, 1, ptr_to_weak_ptr(self));
    return apply_3(rec, self, a1, a2);
}
/* Code of the self-referential closure created by `fixpoint3`.
   `rec` and `weak_k` are the closure's fixed arguments; `a1`..`a3` are the
   arguments of the recursive call. The closure is recovered from its tagged
   self pointer, `inc`ed, and passed to `rec` ahead of the arguments. */
obj_res fixpoint_aux3(obj_arg rec, obj_arg weak_k, obj_arg a1, obj_arg a2, obj_arg a3) {
    object * self = weak_ptr_to_ptr(weak_k);
    inc(self);
    return apply_4(rec, self, a1, a2, a3);
}
/* Native ternary fixpoint: computes `rec k a1 a2 a3`, where `k` is a closure
   over `fixpoint_aux3` whose slot 0 holds `rec` (after `inc`) and whose slot 1
   holds the closure's own address tagged by `ptr_to_weak_ptr`. */
obj_res fixpoint3(obj_arg rec, obj_arg a1, obj_arg a2, obj_arg a3) {
    object * self = alloc_closure(fixpoint_aux3, 2);
    inc(rec);
    closure_set(self, 0, rec);
    closure_set(self, 1, ptr_to_weak_ptr(self));
    return apply_4(rec, self, a1, a2, a3);
}
/* Code of the self-referential closure created by `fixpoint4`.
   `rec` and `weak_k` are the closure's fixed arguments; `a1`..`a4` are the
   arguments of the recursive call. The closure is recovered from its tagged
   self pointer, `inc`ed, and passed to `rec` ahead of the arguments. */
obj_res fixpoint_aux4(obj_arg rec, obj_arg weak_k, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4) {
    object * self = weak_ptr_to_ptr(weak_k);
    inc(self);
    return apply_5(rec, self, a1, a2, a3, a4);
}
/* Native 4-ary fixpoint: computes `rec k a1 ... a4`, where `k` is a closure
   over `fixpoint_aux4` whose slot 0 holds `rec` (after `inc`) and whose slot 1
   holds the closure's own address tagged by `ptr_to_weak_ptr`. */
obj_res fixpoint4(obj_arg rec, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4) {
    object * self = alloc_closure(fixpoint_aux4, 2);
    inc(rec);
    closure_set(self, 0, rec);
    closure_set(self, 1, ptr_to_weak_ptr(self));
    return apply_5(rec, self, a1, a2, a3, a4);
}
/* Code of the self-referential closure created by `fixpoint5`.
   `rec` and `weak_k` are the closure's fixed arguments; `a1`..`a5` are the
   arguments of the recursive call. The closure is recovered from its tagged
   self pointer, `inc`ed, and passed to `rec` ahead of the arguments. */
obj_res fixpoint_aux5(obj_arg rec, obj_arg weak_k, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4, obj_arg a5) {
    object * self = weak_ptr_to_ptr(weak_k);
    inc(self);
    return apply_6(rec, self, a1, a2, a3, a4, a5);
}
/* Native 5-ary fixpoint: computes `rec k a1 ... a5`, where `k` is a closure
   over `fixpoint_aux5` whose slot 0 holds `rec` (after `inc`) and whose slot 1
   holds the closure's own address tagged by `ptr_to_weak_ptr`. */
obj_res fixpoint5(obj_arg rec, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4, obj_arg a5) {
    object * self = alloc_closure(fixpoint_aux5, 2);
    inc(rec);
    closure_set(self, 0, rec);
    closure_set(self, 1, ptr_to_weak_ptr(self));
    return apply_6(rec, self, a1, a2, a3, a4, a5);
}
/* Code of the self-referential closure created by `fixpoint6`.
   `rec` and `weak_k` are the closure's fixed arguments; `a1`..`a6` are the
   arguments of the recursive call. The closure is recovered from its tagged
   self pointer, `inc`ed, and passed to `rec` ahead of the arguments. */
obj_res fixpoint_aux6(obj_arg rec, obj_arg weak_k, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4, obj_arg a5, obj_arg a6) {
    object * self = weak_ptr_to_ptr(weak_k);
    inc(self);
    return apply_7(rec, self, a1, a2, a3, a4, a5, a6);
}
/* Native 6-ary fixpoint: computes `rec k a1 ... a6`, where `k` is a closure
   over `fixpoint_aux6` whose slot 0 holds `rec` (after `inc`) and whose slot 1
   holds the closure's own address tagged by `ptr_to_weak_ptr`. */
obj_res fixpoint6(obj_arg rec, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4, obj_arg a5, obj_arg a6) {
    object * self = alloc_closure(fixpoint_aux6, 2);
    inc(rec);
    closure_set(self, 0, rec);
    closure_set(self, 1, ptr_to_weak_ptr(self));
    return apply_7(rec, self, a1, a2, a3, a4, a5, a6);
}
// =======================================
// Debugging helper functions

View File

@@ -652,6 +652,16 @@ inline obj_res alloc_closure(object*(*fun)(object *, object *, object *, object
return alloc_closure(reinterpret_cast<void*>(fun), 8, num_fixed);
}
// =======================================
// Fixpoint
// Native fixpoint primitives for functions of arity 1 to 6.
// Each takes the functional `rec` followed by the arguments of the call and
// returns the result of applying `rec` to a self-referential closure and
// those arguments.
obj_res fixpoint(obj_arg rec, obj_arg a);
obj_res fixpoint2(obj_arg rec, obj_arg a1, obj_arg a2);
obj_res fixpoint3(obj_arg rec, obj_arg a1, obj_arg a2, obj_arg a3);
obj_res fixpoint4(obj_arg rec, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4);
obj_res fixpoint5(obj_arg rec, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4, obj_arg a5);
obj_res fixpoint6(obj_arg rec, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4, obj_arg a5, obj_arg a6);
// =======================================
// Array of objects

View File

@@ -0,0 +1,21 @@
/-- Benchmark functional: open-recursive in `rec`. At `0` it returns the accumulator `a`;
at `n+1` it makes two recursive calls at `n` (with `a` and `a+1`) and adds `a`. -/
def foo (rec : Nat → Nat → Nat) : Nat → Nat → Nat
| 0 a := a
| (n+1) a := rec n a + a + rec n (a+1)
/-- Fixpoint combinator for binary `Nat` functions written directly in Lean with
`partial def`; used as the baseline against the native `fix`. -/
partial def fix' (f: (Nat → Nat → Nat) → (Nat → Nat → Nat)) : Nat → Nat → Nat
| a b := f fix' a b
/-- Run `p` under `timeit`, labelling the measurement `Time for '<msg>':`. -/
def prof {α : Type} (msg : String) (p : IO α) : IO α :=
timeit ("Time for '" ++ msg ++ "':") p
/-- Print `fix foo n 10`, i.e. the benchmark functional `foo` tied with `fix`,
starting from accumulator `10`. -/
def fix_test (n : Nat) : IO Unit :=
IO.println (fix foo n 10)
/-- Print `fix' foo n 10`, i.e. the same computation as `fix_test` but tied with the
`partial def` combinator `fix'`. -/
def fix'_test (n : Nat) : IO Unit :=
IO.println (fix' foo n 10)
/-- Entry point: takes the problem size from the first argument (`xs.head.toNat`) and
times `fix_test` and then `fix'_test` on it, labelled "native fix" and "fix in lean". -/
def main (xs : List String) : IO Unit :=
prof "native fix" (fix_test xs.head.toNat) *>
prof "fix in lean" (fix'_test xs.head.toNat) *>
pure ()