mirror of
https://github.com/leanprover/lean4.git
synced 2026-03-17 10:24:07 +00:00
first attempt
This commit is contained in:
@@ -56,7 +56,21 @@ inductive Key where
|
||||
|
||||
end ParamMap
|
||||
|
||||
abbrev ParamMap := Std.HashMap ParamMap.Key (Array (Param .impure))
|
||||
structure ParamMap where
|
||||
map : Std.HashMap ParamMap.Key (Array (Param .impure)) := {}
|
||||
userDefinedBorrows : Std.HashSet FVarId := {}
|
||||
|
||||
namespace ParamMap
|
||||
|
||||
@[inline]
|
||||
def insert (pm : ParamMap) (k : Key) (ps : Array (Param .impure)) : ParamMap :=
|
||||
{ pm with map := pm.map.insert k ps }
|
||||
|
||||
@[inline]
|
||||
def erase (pm : ParamMap) (k : Key) : ParamMap :=
|
||||
{ pm with map := pm.map.erase k }
|
||||
|
||||
end ParamMap
|
||||
|
||||
/-- Mark parameters that take a reference as borrow -/
|
||||
def initBorrow (ps : Array (Param .impure)) : Array (Param .impure) :=
|
||||
@@ -73,7 +87,12 @@ where
|
||||
match decl.value with
|
||||
| .code code =>
|
||||
let exported := isExport (← getEnv) decl.name
|
||||
modify fun m => m.insert (.decl decl.name) (initParamsIfNotExported exported decl.params)
|
||||
modify fun m =>
|
||||
{ m with
|
||||
map := m.map.insert (.decl decl.name) (initParamsIfNotExported exported decl.params),
|
||||
userDefinedBorrows := decl.params.foldl (init := m.userDefinedBorrows) fun acc p =>
|
||||
if p.borrow then acc.insert p.fvarId else acc
|
||||
}
|
||||
goCode decl.name code
|
||||
| .extern .. => return ()
|
||||
|
||||
@@ -86,7 +105,12 @@ where
|
||||
goCode (declName : Name) (code : Code .impure) : InitM Unit := do
|
||||
match code with
|
||||
| .jp decl k =>
|
||||
modify fun m => m.insert (.jp declName decl.fvarId) (initParams decl.params)
|
||||
modify fun m =>
|
||||
{ m with
|
||||
map := m.map.insert (.jp declName decl.fvarId) (initParams decl.params),
|
||||
userDefinedBorrows := decl.params.foldl (init := m.userDefinedBorrows) fun acc p =>
|
||||
if p.borrow then acc.insert p.fvarId else acc
|
||||
}
|
||||
goCode declName decl.value
|
||||
goCode declName k
|
||||
| .cases cs => cs.alts.forM (·.forCodeM (goCode declName))
|
||||
@@ -102,7 +126,7 @@ partial def apply (decls : Array (Decl .impure)) (map : ParamMap) : CompilerM (A
|
||||
match decl.value with
|
||||
| .code code =>
|
||||
let code ← go decl.name code
|
||||
let newParams ← updateParams decl.params map[ParamMap.Key.decl decl.name]!
|
||||
let newParams ← updateParams decl.params map.map[ParamMap.Key.decl decl.name]!
|
||||
return { decl with value := .code code, params := newParams }
|
||||
| _ => return decl
|
||||
where
|
||||
@@ -116,7 +140,7 @@ where
|
||||
go (declName : Name) (code : Code .impure) : CompilerM (Code .impure) := do
|
||||
match code with
|
||||
| .jp decl k =>
|
||||
let ps ← updateParams decl.params map[ParamMap.Key.jp declName decl.fvarId]!
|
||||
let ps ← updateParams decl.params map.map[ParamMap.Key.jp declName decl.fvarId]!
|
||||
let decl ← decl.update decl.type ps (← go declName decl.value)
|
||||
return code.updateFun! decl (← go declName k)
|
||||
| .cases cs => return code.updateAlts! <| ← cs.alts.mapMonoM (·.mapCodeM (go declName))
|
||||
@@ -195,6 +219,20 @@ def OwnReason.toString (reason : OwnReason) : CompilerM String := do
|
||||
| .jpArgPropagation jpFVar => return s!"backward propagation from JP {← PP.ppFVar jpFVar}"
|
||||
| .jpTailCallPreservation jpFVar => return s!"JP tail call preservation {← PP.ppFVar jpFVar}"
|
||||
|
||||
def OwnReason.isForced (reason : OwnReason) : Bool :=
|
||||
match reason with
|
||||
| .constructorArg .. => false
|
||||
| .functionCallArg .. => false
|
||||
| .fvarCall .. => false
|
||||
| .partialApplication .. => false
|
||||
| .resetReuse .. => false -- TODO: think deeper
|
||||
| .jpArgPropagation .. => false -- TODO: think deeper
|
||||
| .constructorResult .. => true
|
||||
| .projectionPropagation .. => true
|
||||
| .functionCallResult .. => true
|
||||
| .tailCallPreservation .. => true
|
||||
| .jpTailCallPreservation .. => true
|
||||
|
||||
/--
|
||||
Infer the borrowing annotations in a SCC through dataflow analysis.
|
||||
-/
|
||||
@@ -215,8 +253,11 @@ where
|
||||
|
||||
ownFVar (fvarId : FVarId) (reason : OwnReason) : InferM Unit := do
|
||||
unless (← get).owned.contains fvarId do
|
||||
trace[Compiler.inferBorrow] "own {← PP.run <| PP.ppFVar fvarId}: {← reason.toString}"
|
||||
modify fun s => { s with owned := s.owned.insert fvarId, modified := true }
|
||||
if !reason.isForced && (← get).paramMap.userDefinedBorrows.contains fvarId then
|
||||
trace[Compiler.inferBorrow] "user annotation blocked owning {← PP.run <| PP.ppFVar fvarId}: {← reason.toString}"
|
||||
else
|
||||
trace[Compiler.inferBorrow] "own {← PP.run <| PP.ppFVar fvarId}: {← reason.toString}"
|
||||
modify fun s => { s with owned := s.owned.insert fvarId, modified := true }
|
||||
|
||||
ownArg (reason : OwnReason) (a : Arg .impure) : InferM Unit := do
|
||||
a.forFVarM (ownFVar · reason)
|
||||
@@ -237,7 +278,7 @@ where
|
||||
|
||||
/-- Updates `map[k]` using the current set of `owned` variables. -/
|
||||
updateParamMap (k : ParamMap.Key) : InferM Unit := do
|
||||
if let some ps := (← get).paramMap[k]? then
|
||||
if let some ps := (← get).paramMap.map[k]? then
|
||||
-- This is to ensure linearity over ps in the following code, if you know how to make this
|
||||
-- linear in a nice fashion please make a PR
|
||||
modify fun s => { s with paramMap := s.paramMap.erase k }
|
||||
@@ -252,7 +293,7 @@ where
|
||||
modify fun s => { s with paramMap := s.paramMap.insert k ps }
|
||||
|
||||
getParamInfo (k : ParamMap.Key) : InferM (Array (Param .impure)) := do
|
||||
match (← get).paramMap[k]? with
|
||||
match (← get).paramMap.map[k]? with
|
||||
| some ps => return ps
|
||||
| none =>
|
||||
let .decl fn := k | unreachable!
|
||||
|
||||
Reference in New Issue
Block a user