Compare commits

...

2 Commits

Author SHA1 Message Date
Leonardo de Moura
b7300b30df fix : double reset bug at ResetReuse 2024-04-29 10:43:14 -07:00
Leonardo de Moura
2bf9168eb4 chore: cleanup and coding convention 2024-04-29 10:22:08 -07:00
4 changed files with 152 additions and 67 deletions

View File

@@ -9,9 +9,10 @@ import Lean.Compiler.IR.CompilerM
import Lean.Compiler.IR.LiveVars
namespace Lean.IR.ExplicitRC
/-! Insert explicit RC instructions. So, it assumes the input code does not contain `inc` nor `dec` instructions.
This transformation is applied before lower level optimizations
that introduce the instructions `release` and `set`
/-!
Insert explicit RC instructions. So, it assumes the input code does not contain `inc` nor `dec` instructions.
This transformation is applied before lower level optimizations
that introduce the instructions `release` and `set`
-/
structure VarInfo where

View File

@@ -9,21 +9,24 @@ import Lean.Compiler.IR.LiveVars
import Lean.Compiler.IR.Format
namespace Lean.IR.ResetReuse
/-! Remark: the insertResetReuse transformation is applied before we have
inserted `inc/dec` instructions, and performed lower level optimizations
that introduce the instructions `release` and `set`. -/
/-!
Remark: the insertResetReuse transformation is applied before we have
inserted `inc/dec` instructions, and performed lower level optimizations
that introduce the instructions `release` and `set`.
-/
/-! Remark: the functions `S`, `D` and `R` defined here implement the
corresponding functions in the paper "Counting Immutable Beans"
/-!
Remark: the functions `S`, `D` and `R` defined here implement the
corresponding functions in the paper "Counting Immutable Beans"
Here are the main differences:
- We use the State monad to manage the generation of fresh variable names.
- Support for join points, and `uset` and `sset` instructions for unboxed data.
- `D` uses the auxiliary function `Dmain`.
- `Dmain` returns a pair `(b, found)` to avoid quadratic behavior when checking
the last occurrence of the variable `x`.
- Because we have join points in the actual implementation, a variable may be live even if it
does not occur in a function body. See example at `livevars.lean`.
Here are the main differences:
- We use the State monad to manage the generation of fresh variable names.
- Support for join points, and `uset` and `sset` instructions for unboxed data.
- `D` uses the auxiliary function `Dmain`.
- `Dmain` returns a pair `(b, found)` to avoid quadratic behavior when checking
the last occurrence of the variable `x`.
- Because we have join points in the actual implementation, a variable may be live even if it
does not occur in a function body. See example at `livevars.lean`.
-/
private def mayReuse (c₁ c₂ : CtorInfo) : Bool :=
@@ -33,39 +36,68 @@ private def mayReuse (c₁ c₂ : CtorInfo) : Bool :=
because it produces counterintuitive behavior. -/
c₁.name.getPrefix == c₂.name.getPrefix
/--
Replace `ctor` applications with `reuse` applications if compatible.
`w` contains the "memory cell" being reused.
-/
private partial def S (w : VarId) (c : CtorInfo) : FnBody FnBody
| FnBody.vdecl x t v@(Expr.ctor c' ys) b =>
| .vdecl x t v@(.ctor c' ys) b =>
if mayReuse c c' then
let updtCidx := c.cidx != c'.cidx
FnBody.vdecl x t (Expr.reuse w c' updtCidx ys) b
.vdecl x t (.reuse w c' updtCidx ys) b
else
FnBody.vdecl x t v (S w c b)
| FnBody.jdecl j ys v b =>
.vdecl x t v (S w c b)
| .jdecl j ys v b =>
let v' := S w c v
if v == v' then FnBody.jdecl j ys v (S w c b)
else FnBody.jdecl j ys v' b
| FnBody.case tid x xType alts => FnBody.case tid x xType <| alts.map fun alt => alt.modifyBody (S w c)
if v == v' then
.jdecl j ys v (S w c b)
else
.jdecl j ys v' b
| .case tid x xType alts =>
.case tid x xType <| alts.map fun alt => alt.modifyBody (S w c)
| b =>
if b.isTerminal then b
if b.isTerminal then
b
else let
(instr, b) := b.split
instr.setBody (S w c b)
structure Context where
lctx : LocalContext := {}
/--
Contains all variables in `cases` statements in the current path.
We use this information to prevent double-reset in code such as
```
case x_i : obj of
Prod.mk →
case x_i : obj of
Prod.mk →
...
```
-/
casesVars : PHashSet VarId := {}
/-- We use `Context` to track join points in scope. -/
abbrev M := ReaderT LocalContext (StateT Index Id)
abbrev M := ReaderT Context (StateT Index Id)
private def mkFresh : M VarId := do
let idx getModify (fun n => n + 1)
pure { idx := idx }
let idx getModify fun n => n + 1
return { idx := idx }
/--
Helper function for applying `S`. We only introduce a `reset` if we managed
to replace a `ctor` withe `reuse` in `b`.
-/
private def tryS (x : VarId) (c : CtorInfo) (b : FnBody) : M FnBody := do
let w mkFresh
let b' := S w c b
if b == b' then pure b
else pure $ FnBody.vdecl w IRType.object (Expr.reset c.size x) b'
if b == b' then
return b
else
return .vdecl w IRType.object (.reset c.size x) b'
private def Dfinalize (x : VarId) (c : CtorInfo) : FnBody × Bool M FnBody
| (b, true) => pure b
| (b, true) => return b
| (b, false) => tryS x c b
private def argsContainsVar (ys : Array Arg) (x : VarId) : Bool :=
@@ -75,75 +107,85 @@ private def argsContainsVar (ys : Array Arg) (x : VarId) : Bool :=
private def isCtorUsing (b : FnBody) (x : VarId) : Bool :=
match b with
| (FnBody.vdecl _ _ (Expr.ctor _ ys) _) => argsContainsVar ys x
| .vdecl _ _ (.ctor _ ys) _ => argsContainsVar ys x
| _ => false
/-- Given `Dmain b`, the resulting pair `(new_b, flag)` contains the new body `new_b`,
and `flag == true` if `x` is live in `b`.
/--
Given `Dmain b`, the resulting pair `(new_b, flag)` contains the new body `new_b`,
and `flag == true` if `x` is live in `b`.
Note that, in the function `D` defined in the paper, for each `let x := e; F`,
`D` checks whether `x` is live in `F` or not. This is great for clarity but it
is expensive: `O(n^2)` where `n` is the size of the function body. -/
private partial def Dmain (x : VarId) (c : CtorInfo) : FnBody M (FnBody × Bool)
| e@(FnBody.case tid y yType alts) => do
let ctx read
if e.hasLiveVar ctx x then do
Note that, in the function `D` defined in the paper, for each `let x := e; F`,
`D` checks whether `x` is live in `F` or not. This is great for clarity but it
is expensive: `O(n^2)` where `n` is the size of the function body. -/
private partial def Dmain (x : VarId) (c : CtorInfo) (e : FnBody) : M (FnBody × Bool) := do
match e with
| .case tid y yType alts =>
if e.hasLiveVar ( read).lctx x then
/- If `x` is live in `e`, we recursively process each branch. -/
let alts alts.mapM fun alt => alt.mmodifyBody fun b => Dmain x c b >>= Dfinalize x c
pure (FnBody.case tid y yType alts, true)
else pure (e, false)
| FnBody.jdecl j ys v b => do
let (b, found) withReader (fun ctx => ctx.addJP j ys v) (Dmain x c b)
return (.case tid y yType alts, true)
else
return (e, false)
| .jdecl j ys v b =>
let (b, found) withReader (fun ctx => { ctx with lctx := ctx.lctx.addJP j ys v }) (Dmain x c b)
let (v, _ /- found' -/) Dmain x c v
/- If `found' == true`, then `Dmain b` must also have returned `(b, true)` since
we assume the IR does not have dead join points. So, if `x` is live in `j` (i.e., `v`),
then it must also live in `b` since `j` is reachable from `b` with a `jmp`.
On the other hand, `x` may be live in `b` but dead in `j` (i.e., `v`). -/
pure (FnBody.jdecl j ys v b, found)
| e => do
let ctx read
return (.jdecl j ys v b, found)
| e =>
if e.isTerminal then
pure (e, e.hasLiveVar ctx x)
return (e, e.hasLiveVar ( read).lctx x)
else do
let (instr, b) := e.split
if isCtorUsing instr x then
/- If the scrutinee `x` (the one that is providing memory) is being
stored in a constructor, then reuse will probably not be able to reuse memory at runtime.
It may work only if the new cell is consumed, but we ignore this case. -/
pure (e, true)
return (e, true)
else
let (b, found) Dmain x c b
/- Remark: it is fine to use `hasFreeVar` instead of `hasLiveVar`
since `instr` is not a `FnBody.jmp` (it is not a terminal) nor it is a `FnBody.jdecl`. -/
since `instr` is not a `FnBody.jmp` (it is not a terminal) nor
it is a `FnBody.jdecl`. -/
if found || !instr.hasFreeVar x then
pure (instr.setBody b, found)
return (instr.setBody b, found)
else
let b tryS x c b
pure (instr.setBody b, true)
return (instr.setBody b, true)
private def D (x : VarId) (c : CtorInfo) (b : FnBody) : M FnBody :=
Dmain x c b >>= Dfinalize x c
partial def R : FnBody M FnBody
| FnBody.case tid x xType alts => do
partial def R (e : FnBody) : M FnBody := do
match e with
| .case tid x xType alts =>
let alreadyFound := ( read).casesVars.contains x
withReader (fun ctx => { ctx with casesVars := ctx.casesVars.insert x }) do
let alts alts.mapM fun alt => do
let alt alt.mmodifyBody R
match alt with
| Alt.ctor c b =>
if c.isScalar then pure alt
else Alt.ctor c <$> D x c b
| _ => pure alt
pure $ FnBody.case tid x xType alts
| FnBody.jdecl j ys v b => do
| .ctor c b =>
if c.isScalar || alreadyFound then
-- If `alreadyFound`, then we don't try to reuse memory cell to avoid
-- double reset.
return alt
else
.ctor c <$> D x c b
| _ => return alt
return .case tid x xType alts
| .jdecl j ys v b =>
let v R v
let b withReader (fun ctx => ctx.addJP j ys v) (R b)
pure $ FnBody.jdecl j ys v b
| e => do
if e.isTerminal then pure e
else do
let b withReader (fun ctx => { ctx with lctx := ctx.lctx.addJP j ys v }) (R b)
return .jdecl j ys v b
| e =>
if e.isTerminal then
return e
else
let (instr, b) := e.split
let b R b
pure (instr.setBody b)
return instr.setBody b
end ResetReuse
@@ -151,7 +193,7 @@ open ResetReuse
def Decl.insertResetReuse (d : Decl) : Decl :=
match d with
| .fdecl (body := b) ..=>
| .fdecl (body := b) .. =>
let nextIndex := d.maxIndex + 1
let bNew := (R b {}).run' nextIndex
d.updateBody! bNew

View File

@@ -0,0 +1,4 @@
set_option trace.compiler.ir.reset_reuse true in
def applyProjectionRules (projs : Array ((α × β) × γ)) (newName : γ) :
Array ((α × β) × γ) :=
projs.map fun proj => { proj with 2 := newName, 1.2 := proj.1.2 }

View File

@@ -0,0 +1,38 @@
[reset_reuse]
def Array.mapMUnsafe.map._at.applyProjectionRules._spec_1._rarg (x_1 : obj) (x_2 : usize) (x_3 : usize) (x_4 : obj) : obj :=
let x_5 : u8 := USize.decLt x_3 x_2;
case x_5 : obj of
Bool.false →
ret x_4
Bool.true →
let x_6 : obj := Array.uget ◾ x_4 x_3 ◾;
let x_7 : obj := 0;
let x_8 : obj := Array.uset ◾ x_4 x_3 x_7 ◾;
let x_9 : obj := proj[0] x_6;
case x_9 : obj of
Prod.mk →
case x_9 : obj of
Prod.mk →
let x_10 : obj := proj[0] x_9;
let x_11 : obj := proj[1] x_9;
let x_18 : obj := reset[2] x_9;
let x_12 : obj := reuse x_18 in ctor_0[Prod.mk] x_10 x_11;
let x_13 : obj := ctor_0[Prod.mk] x_12 x_1;
let x_14 : usize := 1;
let x_15 : usize := USize.add x_3 x_14;
let x_16 : obj := Array.uset ◾ x_8 x_3 x_13 ◾;
let x_17 : obj := Array.mapMUnsafe.map._at.applyProjectionRules._spec_1._rarg x_1 x_2 x_15 x_16;
ret x_17
def Array.mapMUnsafe.map._at.applyProjectionRules._spec_1 (x_1 : ◾) (x_2 : ◾) (x_3 : ◾) : obj :=
let x_4 : obj := pap Array.mapMUnsafe.map._at.applyProjectionRules._spec_1._rarg;
ret x_4
def applyProjectionRules._rarg (x_1 : obj) (x_2 : obj) : obj :=
let x_3 : obj := Array.size ◾ x_1;
let x_4 : usize := USize.ofNat x_3;
let x_5 : usize := 0;
let x_6 : obj := Array.mapMUnsafe.map._at.applyProjectionRules._spec_1._rarg x_2 x_4 x_5 x_1;
ret x_6
def applyProjectionRules (x_1 : ◾) (x_2 : ◾) (x_3 : ◾) : obj :=
let x_4 : obj := pap applyProjectionRules._rarg;
ret x_4