Compare commits

...

2 Commits

Author SHA1 Message Date
Leonardo de Moura
90c6034cfc proto expr experiment 2022-08-22 08:55:27 -07:00
Leonardo de Moura
a5761e0166 checkpoint 2022-08-22 08:53:25 -07:00
4 changed files with 192 additions and 35 deletions

View File

@@ -5,6 +5,7 @@ Authors: Leonardo de Moura
-/
import Lean.Compiler.CompilerM
import Lean.Compiler.Decl
import Lean.Compiler.PExpr
namespace Lean.Compiler
@@ -19,30 +20,25 @@ abbrev M := StateRefT State CompilerM
mutual
partial def visitCases (casesInfo : CasesInfo) (cases : Expr) : M Expr := do
let mut args := cases.getAppArgs
for i in casesInfo.altsRange do
args args.modifyM i visitLambda
return mkAppN cases.getAppFn args
partial def visitLambda (xs : Array Expr) (e : Expr) : M PExpr :=
PExpr.visitLambda xs e visitLet
partial def visitLambda (e : Expr) : M Expr :=
withNewScope do
let (as, e) Compiler.visitLambdaCore e
let e mkLetUsingScope ( visitLet e as)
mkLambda as e
partial def visitLet (e : Expr) (xs : Array Expr): M Expr := do
partial def visitLet (xs : Array Expr) (e : Expr) : M (Array Expr × PExpr) := do
let saved get
try go e xs finally set saved
try go xs e finally set saved
where
go (e : Expr) (xs : Array Expr) : M Expr := do
go (xs : Array Expr) (e : Expr) : M (Array Expr × PExpr) := do
match e with
| .letE binderName type value body nonDep =>
let mut value := value.instantiateRev xs
if value.isLambda then
value visitLambda value
let value' if value.isLambda then
( visitLambda xs value).toExpr
else
pure <| value.instantiateRev xs
if value'.isLambda then
trace[Meta.debug] "nested lambda\n{value.instantiateRev xs}\n====>\n{value'}"
let value := value'
match ( get).map.find? value with
| some x => go body (xs.push x)
| some x => go (xs.push x) body
| none =>
let type := type.instantiateRev xs
let x mkLetDecl binderName type value nonDep
@@ -50,13 +46,12 @@ where
-- We currently don't eliminate common join points because we want to prevent
-- jumps to out-of-scope join points.
modify fun s => { s with map := s.map.insert value x }
go body (xs.push x)
go (xs.push x) body
| _ =>
let e := e.instantiateRev xs
if let some casesInfo isCasesApp? e then
visitCases casesInfo e
return (xs, PExpr.visitCases xs casesInfo e visitLambda)
else
return e
return (xs, e.instantiateRev xs)
end
@@ -66,6 +61,6 @@ end CSE
Common sub-expression elimination
-/
def Decl.cse (decl : Decl) : CoreM Decl :=
decl.mapValue fun value => CSE.visitLambda value |>.run' {}
decl.mapValue fun value => do ( CSE.visitLambda #[] value |>.run' {}).toExpr
end Lean.Compiler

View File

@@ -68,11 +68,16 @@ def compileStage1Impl (declNames : Array Name) : CoreM (Array Decl) := do
checkpoint `cse decls
let decls decls.mapM (·.simp)
checkpoint `simp decls
let mut decls := decls
profileitM Exception "compiler cse" ( getOptions) do
let mut decls := decls
for _ in [:1000] do
decls decls.mapM (·.cse)
-- let decls ← decls.mapM (·.cse)
-- checkpoint `cse decls
saveStage1Decls decls
decls.forM fun decl => do trace[Compiler.stat] "{decl.name}: {← getLCNFSize decl.value}"
return decls
-- checkpoint `cse decls
saveStage1Decls decls
decls.forM fun decl => do trace[Compiler.stat] "{decl.name}: {← getLCNFSize decl.value}"
return decls
/--
Run the code generation pipeline for all declarations in `declNames`

View File

@@ -0,0 +1,151 @@
/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Compiler.CompilerM
namespace Lean.Compiler
/--
Proto expressions are an artifact to minimize the performance overhead of the locally nameless approach.
-/
structure PExpr where
private expr : Expr
deriving Inhabited
/-- Every `Expr` is a valid `PExpr`. -/
instance : Coe Expr PExpr where
coe e := { expr := e }
namespace PExpr
def mkPBinding' (x : Expr) (body : PExpr) : PExpr :=
.lam `_pbinding x body.expr .default
def mkPBinding (xs : Array Expr) (body : PExpr) : PExpr :=
xs.foldr (init := body) mkPBinding'
def mkPBindingFrom (start : Nat) (xs : Array Expr) (body : PExpr) : PExpr :=
go xs.size body
where
go (i : Nat) (body : PExpr) : PExpr :=
if i > start then
go (i - 1) (mkPBinding' xs[i - 1]! body)
else
body
@[inline] private def isPBinding? (e : PExpr) : Option (Expr × PExpr) :=
if let .lam `_pbinding x body _ := e.expr then
some (x, body)
else
none
instance : ToMessageData PExpr where
toMessageData e := e.expr
def visitLambda' (xs : Array Expr) (e : Expr) : CompilerM (Array Expr × Expr) :=
go e xs
where
go (e : Expr) (xs : Array Expr) := do
if let .lam binderName type body binderInfo := e then
let type := type.instantiateRev xs
let x mkLocalDecl binderName type binderInfo
go body (xs.push x)
else
return (xs, e)
@[inline] def visitLambdaImp (xs : Array Expr) (e : Expr) (visitBody : Array Expr Expr CompilerM (Array Expr × PExpr)) : CompilerM PExpr := do
let start := xs.size
let (xs, e) visitLambda' xs e
let (xs, e) visitBody xs e
return mkPBindingFrom start xs e
class VisitLambda (m : Type Type) where
visitLambda : Array Expr Expr (Array Expr Expr m (Array Expr × PExpr)) m PExpr
export VisitLambda (visitLambda)
instance : VisitLambda CompilerM where
visitLambda := visitLambdaImp
instance [VisitLambda m] : VisitLambda (ReaderT ρ m) where
visitLambda := fun xs e f ctx => visitLambda xs e (f · · ctx)
instance [VisitLambda m] : VisitLambda (StateRefT' ω ρ m) :=
inferInstanceAs (VisitLambda (ReaderT _ _))
@[inline] def visitCases [Monad m] (xs : Array Expr) (casesInfo : CasesInfo) (cases : Expr) (visitAlt : Array Expr Expr m PExpr) : m PExpr := do
let mut args := cases.getAppArgs
for i in [:args.size] do
let arg := args[i]!
let arg if casesInfo.altsRange.start i && i < casesInfo.altsRange.stop then
pure ( visitAlt xs arg).expr
else
pure <| arg.instantiateRev xs
args := args.set! i arg
return mkAppN cases.getAppFn args
abbrev UsedSet := FVarIdHashSet
partial def toExpr' (env : Environment) (lctx : LocalContext) (e : PExpr) : Expr :=
let e := filter e {} |>.1
go #[] e
where
filter (e : PExpr) (s : UsedSet) : PExpr × UsedSet :=
if let some (x, e) := isPBinding? e then
let (e, s) := filter e s
match lctx.get! x.fvarId! with
| .cdecl .. => (mkPBinding' x e, s)
| .ldecl (value := value) .. =>
if s.contains x.fvarId! then
(mkPBinding' x e, collectLetFVars s value)
else
(e, s)
else if let some casesInfo := isCasesApp?' env e.expr then
Id.run do
let mut args := e.expr.getAppArgs
let mut s := s
for i in [:args.size] do
let arg := args[i]!
let arg if casesInfo.altsRange.start i && i < casesInfo.altsRange.stop then
let (arg, s') := filter arg s
s := s'
pure arg.expr
else
s := collectLetFVars s arg
pure arg
args := args.set! i arg
return (mkAppN e.expr.getAppFn args, s)
else
(e, collectLetFVars s e.expr)
go (xs : Array Expr) (e : PExpr) : Expr :=
if let some (x, e) := isPBinding? e then
match lctx.get! x.fvarId! with
| .cdecl (userName := userName) (type := type) (bi := bi) .. =>
let type := type.abstract xs
.lam userName type (go (xs.push x) e) bi
| .ldecl (userName := userName) (type := type) (value := value) .. =>
let type := type.abstract xs
let value := value.abstract xs
.letE userName type value (go (xs.push x) e) true
else if let some casesInfo := isCasesApp?' env e.expr then
Id.run do
let mut args := e.expr.getAppArgs
for i in [:args.size] do
let arg := args[i]!
let arg if casesInfo.altsRange.start i && i < casesInfo.altsRange.stop then
go xs arg
else
arg.abstract xs
args := args.set! i arg
return mkAppN e.expr.getAppFn args
else
e.expr.abstract xs
partial def toExpr [Monad m] [MonadEnv m] [MonadLCtx m] (e : PExpr) : m Expr :=
return toExpr' ( getEnv) ( getLCtx) e
end PExpr
end Lean.Compiler

View File

@@ -57,13 +57,13 @@ structure CasesInfo where
altNumParams : Array Nat
motivePos : Nat
private def getCasesOnInductiveVal? (declName : Name) : CoreM (Option InductiveVal) := do
unless isCasesOnRecursor ( getEnv) declName do return none
let .inductInfo val getConstInfo declName.getPrefix | return none
private def getCasesOnInductiveVal? (env : Environment) (declName : Name) : Option InductiveVal := Id.run do
unless isCasesOnRecursor env declName do return none
let some (.inductInfo val) := env.find? declName.getPrefix | return none
return some val
def getCasesInfo? (declName : Name) : CoreM (Option CasesInfo) := do
let some val getCasesOnInductiveVal? declName | return none
def getCasesInfo?' (env : Environment) (declName : Name) : Option CasesInfo := Id.run do
let some val getCasesOnInductiveVal? env declName | return none
let numParams := val.numParams
let motivePos := numParams
let arity := numParams + 1 /- motive -/ + val.numIndices + 1 /- major -/ + val.numCtors
@@ -72,10 +72,13 @@ def getCasesInfo? (declName : Name) : CoreM (Option CasesInfo) := do
let discrsRange := { start := numParams + 1, stop := majorPos + 1 }
let altsRange := { start := majorPos + 1, stop := arity }
let altNumParams val.ctors.toArray.mapM fun ctor => do
let .ctorInfo ctorVal getConstInfo ctor | unreachable!
let some (.ctorInfo ctorVal) := env.find? ctor | unreachable!
return ctorVal.numFields
return some { numParams, motivePos, arity, discrsRange, altsRange, altNumParams }
def getCasesInfo? [Monad m] [MonadEnv m] (declName : Name) : m (Option CasesInfo) :=
return getCasesInfo?' ( getEnv) declName
def CasesInfo.geNumDiscrs (casesInfo : CasesInfo) : Nat :=
casesInfo.discrsRange.stop - casesInfo.discrsRange.start
@@ -87,14 +90,17 @@ where
| .lam n b d bi => .lam n b (go d) bi
| _ => typeNew
def isCasesApp? (e : Expr) : CoreM (Option CasesInfo) := do
def isCasesApp?' (env : Environment) (e : Expr) : Option CasesInfo := Id.run do
let .const declName _ := e.getAppFn | return none
if let some info getCasesInfo? declName then
if let some info getCasesInfo?' env declName then
assert! info.arity == e.getAppNumArgs
return some info
else
return none
def isCasesApp? [Monad m] [MonadEnv m] (e : Expr) : m (Option CasesInfo) :=
return isCasesApp?' ( getEnv) e
def getCtorArity? (declName : Name) : CoreM (Option Nat) := do
let .ctorInfo val getConstInfo declName | return none
return val.numParams + val.numFields