mirror of
https://github.com/leanprover/lean4.git
synced 2026-03-17 18:34:06 +00:00
fix(library/init/lean/compiler/ir/resetreuse): must use livevars instead of freevars
The file badreset contains two functions where the new `reset/reuse` insertion procedure implemented in Lean produces better results than the one implemented in C++. cc @kha
This commit is contained in:
@@ -81,7 +81,7 @@ end IsLive
|
||||
|
||||
Recall that we say that a join point `j` is free in `b` if `b` contains
|
||||
`FnBody.jmp j ys` and `j` is not local. -/
|
||||
def FnBody.isLive (b : FnBody) (ctx : Context) (x : VarId) : Bool :=
|
||||
def FnBody.hasLiveVar (b : FnBody) (ctx : Context) (x : VarId) : Bool :=
|
||||
(IsLive.visitFnBody x.idx b).run' ctx
|
||||
|
||||
end IR
|
||||
|
||||
@@ -5,8 +5,9 @@ Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import init.control.state
|
||||
import init.control.reader
|
||||
import init.lean.compiler.ir.basic
|
||||
import init.lean.compiler.ir.freevars
|
||||
import init.lean.compiler.ir.livevars
|
||||
|
||||
namespace Lean
|
||||
namespace IR
|
||||
@@ -20,9 +21,11 @@ namespace IR
|
||||
Here are the main differences:
|
||||
- We use the State monad to manage the generation of fresh variable names.
|
||||
- Support for join points, and `uset` and `sset` instructions for unboxed data.
|
||||
- `R` uses the `flatten` and `reshape` idiom.
|
||||
- `D` returns a pair `(b, found)` to avoid quadratic behavior when checking
|
||||
the last occurrence of the variable `x`
|
||||
- `D` uses the auxiliary function `Dmain`.
|
||||
- `Dmain` returns a pair `(b, found)` to avoid quadratic behavior when checking
|
||||
the last occurrence of the variable `x`.
|
||||
- Because we have join points in the actual implementation, a variable may be live even if it
|
||||
does not occur in a function body. See example at `livevars.lean`.
|
||||
-/
|
||||
|
||||
private def mayReuse (c₁ c₂ : CtorInfo) : Bool :=
|
||||
@@ -50,7 +53,8 @@ private partial def S (w : VarId) (c : CtorInfo) : FnBody → FnBody
|
||||
(instr, b) := b.split in
|
||||
instr <;> S b
|
||||
|
||||
abbrev M := State Index
|
||||
/- We use `Context` to track join points in scope. -/
|
||||
abbrev M := ReaderT Context (StateT Index Id)
|
||||
local attribute [instance] monadInhabited
|
||||
|
||||
private def mkFresh : M VarId :=
|
||||
@@ -68,56 +72,76 @@ private def Dfinalize (x : VarId) (c : CtorInfo) : FnBody × Bool → M FnBody
|
||||
| (b, false) := tryS x c b
|
||||
|
||||
private partial def Dmain (x : VarId) (c : CtorInfo) : FnBody → M (FnBody × Bool)
|
||||
| b@(FnBody.case tid y alts) :=
|
||||
if b.hasFreeVar x then do
|
||||
| e@(FnBody.case tid y alts) := do
|
||||
ctx ← read,
|
||||
if e.hasLiveVar ctx x then do
|
||||
/- If `x` is live in `e`, we recursively process each branch. -/
|
||||
alts ← alts.hmmap $ λ alt, alt.mmodifyBody (λ b, Dmain b >>= Dfinalize x c),
|
||||
pure (FnBody.case tid y alts, true)
|
||||
else
|
||||
pure (b, false)
|
||||
| e :=
|
||||
else pure (e, false)
|
||||
| e@(FnBody.jdecl j ys t v b) := do
|
||||
(b, _) ← adaptReader (λ ctx : Context, ctx.addDecl e) (Dmain b),
|
||||
(v, found) ← Dmain v,
|
||||
/- If `found == true`, then `Dmain b` must also have returned `(b, true)` since
|
||||
we assume the IR does not have dead join points. So, if `x` is live in `j`,
|
||||
then it must also live in `b` since `j` is reachable from `b` with a `jmp`. -/
|
||||
pure (FnBody.jdecl j ys t v b, found)
|
||||
| e := do
|
||||
ctx ← read,
|
||||
if e.isTerminal then
|
||||
pure (e, e.hasFreeVar x)
|
||||
pure (e, e.hasLiveVar ctx x)
|
||||
else do
|
||||
let (instr, b) := e.split,
|
||||
(b, found) ← Dmain b,
|
||||
/- Remark: it is fine to use `hasFreeVar` instead of `hasLiveVar`
|
||||
since `instr` is not a `FnBody.jmp` (it is not a terminal) nor it is a `FnBody.jdecl`. -/
|
||||
if found || !instr.hasFreeVar x then
|
||||
pure (instr <;> b, found)
|
||||
else do
|
||||
b ← tryS x c b,
|
||||
pure (instr <;> b, true)
|
||||
|
||||
/- Auxiliary function used to implement an additional heuristic at `D`. -/
|
||||
private partial def hasCtorUsing (x : VarId) : FnBody → Bool
|
||||
| (FnBody.vdecl x _ (Expr.ctor _ ys) b) :=
|
||||
ys.any (λ arg, Arg.hasFreeVar arg x) || hasCtorUsing b
|
||||
| b := !b.isTerminal && hasCtorUsing b.body
|
||||
ys.any (λ arg, match arg with
|
||||
| Arg.var y := x == y
|
||||
| _ := false)
|
||||
|| hasCtorUsing b
|
||||
| (FnBody.jdecl _ _ _ v b) := hasCtorUsing v || hasCtorUsing b
|
||||
| b := !b.isTerminal && hasCtorUsing b.body
|
||||
|
||||
private def D (x : VarId) (c : CtorInfo) (b : FnBody) : M FnBody :=
|
||||
/- If the scrutinee `x` (the one that is providing memory) is being
|
||||
stored in a constructor, then reuse will probably not work.
|
||||
stored in a constructor, then reuse will probably not be able to reuse memory at runtime.
|
||||
It may work only if the new cell is consumed, but we ignore this case. -/
|
||||
if hasCtorUsing x b then pure b
|
||||
else Dmain x c b >>= Dfinalize x c
|
||||
|
||||
private partial def R : FnBody → M FnBody
|
||||
| b := do
|
||||
let (bs, term) := b.flatten,
|
||||
bs ← mmodifyJPs bs R,
|
||||
match term with
|
||||
| FnBody.case tid x alts := do
|
||||
| (FnBody.case tid x alts) := do
|
||||
alts ← alts.hmmap $ λ alt, do {
|
||||
alt ← alt.mmodifyBody R,
|
||||
match alt with
|
||||
| Alt.ctor c b := Alt.ctor c <$> D x c b
|
||||
| _ := pure alt
|
||||
},
|
||||
let term := FnBody.case tid x alts,
|
||||
pure $ reshape bs term
|
||||
| other := pure $ reshape bs term
|
||||
pure $ FnBody.case tid x alts
|
||||
| e@(FnBody.jdecl j ys t v b) := do
|
||||
v ← R v,
|
||||
b ← adaptReader (λ ctx : Context, ctx.addDecl e) (R b),
|
||||
pure $ FnBody.jdecl j ys t v b
|
||||
| e := do
|
||||
if e.isTerminal then pure e
|
||||
else do
|
||||
let (instr, b) := e.split,
|
||||
b ← R b,
|
||||
pure (instr <;> b)
|
||||
|
||||
def Decl.insertResetReuse : Decl → Decl
|
||||
| d@(Decl.fdecl f xs t b) :=
|
||||
let nextIndex := d.maxIndex + 1 in
|
||||
let b := (R b).run' nextIndex in
|
||||
let b := (R b {}).run' nextIndex in
|
||||
Decl.fdecl f xs t b
|
||||
| other := other
|
||||
|
||||
|
||||
@@ -1,12 +1,23 @@
|
||||
@[noinline] def g (x : Nat × Nat) := x
|
||||
|
||||
set_option trace.compiler.boxed true
|
||||
set_option trace.compiler.lambda_pure true
|
||||
|
||||
@[noinline] def f (b : Bool) (x : Nat × Nat) : (Nat × Nat) × (Nat × Nat) :=
|
||||
let done (y : Nat × Nat) := (g (g (g y)), x) in
|
||||
let done (y : Nat × Nat) := (g (g (g x)), y) in
|
||||
match b with
|
||||
| true := match x with | (a, b) := done (a, 0)
|
||||
| false := match x with | (a, b) := done (0, b)
|
||||
|
||||
@[noinline] def h {α : Type} (x : Nat × α) := x.1
|
||||
|
||||
def tst2 (p : Nat × (Except Nat Nat)) : Nat × Nat :=
|
||||
match p with
|
||||
| (a, b) :=
|
||||
let done (x : Nat) := (h p + 1, x) in
|
||||
match b with
|
||||
| Except.ok v := done v
|
||||
| Except.error w := done w
|
||||
|
||||
def main (xs : List String) : IO Unit :=
|
||||
IO.println $ f true (xs.head.toNat, xs.tail.head.toNat)
|
||||
|
||||
Reference in New Issue
Block a user