first attempt

This commit is contained in:
Henrik Böving
2026-03-06 16:09:33 +00:00
parent 0f2532f683
commit 5ab45e12ca

View File

@@ -56,7 +56,21 @@ inductive Key where
end ParamMap
abbrev ParamMap := Std.HashMap ParamMap.Key (Array (Param .impure))
structure ParamMap where
map : Std.HashMap ParamMap.Key (Array (Param .impure)) := {}
userDefinedBorrows : Std.HashSet FVarId := {}
namespace ParamMap
@[inline]
def insert (pm : ParamMap) (k : Key) (ps : Array (Param .impure)) : ParamMap :=
{ pm with map := pm.map.insert k ps }
@[inline]
def erase (pm : ParamMap) (k : Key) : ParamMap :=
{ pm with map := pm.map.erase k }
end ParamMap
/-- Mark parameters that take a reference as borrow -/
def initBorrow (ps : Array (Param .impure)) : Array (Param .impure) :=
@@ -73,7 +87,12 @@ where
match decl.value with
| .code code =>
let exported := isExport ( getEnv) decl.name
modify fun m => m.insert (.decl decl.name) (initParamsIfNotExported exported decl.params)
modify fun m =>
{ m with
map := m.map.insert (.decl decl.name) (initParamsIfNotExported exported decl.params),
userDefinedBorrows := decl.params.foldl (init := m.userDefinedBorrows) fun acc p =>
if p.borrow then acc.insert p.fvarId else acc
}
goCode decl.name code
| .extern .. => return ()
@@ -86,7 +105,12 @@ where
goCode (declName : Name) (code : Code .impure) : InitM Unit := do
match code with
| .jp decl k =>
modify fun m => m.insert (.jp declName decl.fvarId) (initParams decl.params)
modify fun m =>
{ m with
map := m.map.insert (.jp declName decl.fvarId) (initParams decl.params),
userDefinedBorrows := decl.params.foldl (init := m.userDefinedBorrows) fun acc p =>
if p.borrow then acc.insert p.fvarId else acc
}
goCode declName decl.value
goCode declName k
| .cases cs => cs.alts.forM (·.forCodeM (goCode declName))
@@ -102,7 +126,7 @@ partial def apply (decls : Array (Decl .impure)) (map : ParamMap) : CompilerM (A
match decl.value with
| .code code =>
let code go decl.name code
let newParams updateParams decl.params map[ParamMap.Key.decl decl.name]!
let newParams updateParams decl.params map.map[ParamMap.Key.decl decl.name]!
return { decl with value := .code code, params := newParams }
| _ => return decl
where
@@ -116,7 +140,7 @@ where
go (declName : Name) (code : Code .impure) : CompilerM (Code .impure) := do
match code with
| .jp decl k =>
let ps updateParams decl.params map[ParamMap.Key.jp declName decl.fvarId]!
let ps updateParams decl.params map.map[ParamMap.Key.jp declName decl.fvarId]!
let decl decl.update decl.type ps ( go declName decl.value)
return code.updateFun! decl ( go declName k)
| .cases cs => return code.updateAlts! <| cs.alts.mapMonoM (·.mapCodeM (go declName))
@@ -195,6 +219,20 @@ def OwnReason.toString (reason : OwnReason) : CompilerM String := do
| .jpArgPropagation jpFVar => return s!"backward propagation from JP {← PP.ppFVar jpFVar}"
| .jpTailCallPreservation jpFVar => return s!"JP tail call preservation {← PP.ppFVar jpFVar}"
def OwnReason.isForced (reason : OwnReason) : Bool :=
match reason with
| .constructorArg .. => false
| .functionCallArg .. => false
| .fvarCall .. => false
| .partialApplication .. => false
| .resetReuse .. => false -- TODO: think deeper
| .jpArgPropagation .. => false -- TODO: think deeper
| .constructorResult .. => true
| .projectionPropagation .. => true
| .functionCallResult .. => true
| .tailCallPreservation .. => true
| .jpTailCallPreservation .. => true
/--
Infer the borrowing annotations in a SCC through dataflow analysis.
-/
@@ -215,8 +253,11 @@ where
ownFVar (fvarId : FVarId) (reason : OwnReason) : InferM Unit := do
unless ( get).owned.contains fvarId do
trace[Compiler.inferBorrow] "own {← PP.run <| PP.ppFVar fvarId}: {← reason.toString}"
modify fun s => { s with owned := s.owned.insert fvarId, modified := true }
if !reason.isForced && ( get).paramMap.userDefinedBorrows.contains fvarId then
trace[Compiler.inferBorrow] "user annotation blocked owning {← PP.run <| PP.ppFVar fvarId}: {← reason.toString}"
else
trace[Compiler.inferBorrow] "own {← PP.run <| PP.ppFVar fvarId}: {← reason.toString}"
modify fun s => { s with owned := s.owned.insert fvarId, modified := true }
ownArg (reason : OwnReason) (a : Arg .impure) : InferM Unit := do
a.forFVarM (ownFVar · reason)
@@ -237,7 +278,7 @@ where
/-- Updates `map[k]` using the current set of `owned` variables. -/
updateParamMap (k : ParamMap.Key) : InferM Unit := do
if let some ps := ( get).paramMap[k]? then
if let some ps := ( get).paramMap.map[k]? then
-- This is to ensure linearity over ps in the following code, if you know how to make this
-- linear in a nice fashion please make a PR
modify fun s => { s with paramMap := s.paramMap.erase k }
@@ -252,7 +293,7 @@ where
modify fun s => { s with paramMap := s.paramMap.insert k ps }
getParamInfo (k : ParamMap.Key) : InferM (Array (Param .impure)) := do
match ( get).paramMap[k]? with
match ( get).paramMap.map[k]? with
| some ps => return ps
| none =>
let .decl fn := k | unreachable!