From 5ab45e12ca2ab0712132e1009a33fcb7084c40de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20B=C3=B6ving?= Date: Fri, 6 Mar 2026 16:09:33 +0000 Subject: [PATCH] first attempt --- src/Lean/Compiler/LCNF/InferBorrow.lean | 59 +++++++++++++++++++++---- 1 file changed, 50 insertions(+), 9 deletions(-) diff --git a/src/Lean/Compiler/LCNF/InferBorrow.lean b/src/Lean/Compiler/LCNF/InferBorrow.lean index 6c6d7fa423..63d8e3b538 100644 --- a/src/Lean/Compiler/LCNF/InferBorrow.lean +++ b/src/Lean/Compiler/LCNF/InferBorrow.lean @@ -56,7 +56,21 @@ inductive Key where end ParamMap -abbrev ParamMap := Std.HashMap ParamMap.Key (Array (Param .impure)) +structure ParamMap where + map : Std.HashMap ParamMap.Key (Array (Param .impure)) := {} + userDefinedBorrows : Std.HashSet FVarId := {} + +namespace ParamMap + +@[inline] +def insert (pm : ParamMap) (k : Key) (ps : Array (Param .impure)) : ParamMap := + { pm with map := pm.map.insert k ps } + +@[inline] +def erase (pm : ParamMap) (k : Key) : ParamMap := + { pm with map := pm.map.erase k } + +end ParamMap /-- Mark parameters that take a reference as borrow -/ def initBorrow (ps : Array (Param .impure)) : Array (Param .impure) := @@ -73,7 +87,12 @@ where match decl.value with | .code code => let exported := isExport (← getEnv) decl.name - modify fun m => m.insert (.decl decl.name) (initParamsIfNotExported exported decl.params) + modify fun m => + { m with + map := m.map.insert (.decl decl.name) (initParamsIfNotExported exported decl.params), + userDefinedBorrows := decl.params.foldl (init := m.userDefinedBorrows) fun acc p => + if p.borrow then acc.insert p.fvarId else acc + } goCode decl.name code | .extern .. => return () @@ -86,7 +105,12 @@ where goCode (declName : Name) (code : Code .impure) : InitM Unit := do match code with | .jp decl k => - modify fun m => m.insert (.jp declName decl.fvarId) (initParams decl.params) + modify fun m => + { m with + map := m.map.insert (.jp declName decl.fvarId) (initParams decl.params), + userDefinedBorrows := decl.params.foldl (init := m.userDefinedBorrows) fun acc p => + if p.borrow then acc.insert p.fvarId else acc + } goCode declName decl.value goCode declName k | .cases cs => cs.alts.forM (·.forCodeM (goCode declName)) @@ -102,7 +126,7 @@ partial def apply (decls : Array (Decl .impure)) (map : ParamMap) : CompilerM (A match decl.value with | .code code => let code ← go decl.name code - let newParams ← updateParams decl.params map[ParamMap.Key.decl decl.name]! + let newParams ← updateParams decl.params map.map[ParamMap.Key.decl decl.name]! return { decl with value := .code code, params := newParams } | _ => return decl where @@ -116,7 +140,7 @@ where go (declName : Name) (code : Code .impure) : CompilerM (Code .impure) := do match code with | .jp decl k => - let ps ← updateParams decl.params map[ParamMap.Key.jp declName decl.fvarId]! + let ps ← updateParams decl.params map.map[ParamMap.Key.jp declName decl.fvarId]! let decl ← decl.update decl.type ps (← go declName decl.value) return code.updateFun! decl (← go declName k) | .cases cs => return code.updateAlts! <| ← cs.alts.mapMonoM (·.mapCodeM (go declName)) @@ -195,6 +219,20 @@ def OwnReason.toString (reason : OwnReason) : CompilerM String := do | .jpArgPropagation jpFVar => return s!"backward propagation from JP {← PP.ppFVar jpFVar}" | .jpTailCallPreservation jpFVar => return s!"JP tail call preservation {← PP.ppFVar jpFVar}" +def OwnReason.isForced (reason : OwnReason) : Bool := + match reason with + | .constructorArg .. => false + | .functionCallArg .. => false + | .fvarCall .. => false + | .partialApplication .. => false + | .resetReuse .. => false -- TODO: think deeper + | .jpArgPropagation .. => false -- TODO: think deeper + | .constructorResult .. => true + | .projectionPropagation .. => true + | .functionCallResult .. => true + | .tailCallPreservation .. => true + | .jpTailCallPreservation .. => true + /-- Infer the borrowing annotations in a SCC through dataflow analysis. -/ @@ -215,8 +253,11 @@ where ownFVar (fvarId : FVarId) (reason : OwnReason) : InferM Unit := do unless (← get).owned.contains fvarId do - trace[Compiler.inferBorrow] "own {← PP.run <| PP.ppFVar fvarId}: {← reason.toString}" - modify fun s => { s with owned := s.owned.insert fvarId, modified := true } + if !reason.isForced && (← get).paramMap.userDefinedBorrows.contains fvarId then + trace[Compiler.inferBorrow] "user annotation blocked owning {← PP.run <| PP.ppFVar fvarId}: {← reason.toString}" + else + trace[Compiler.inferBorrow] "own {← PP.run <| PP.ppFVar fvarId}: {← reason.toString}" + modify fun s => { s with owned := s.owned.insert fvarId, modified := true } ownArg (reason : OwnReason) (a : Arg .impure) : InferM Unit := do a.forFVarM (ownFVar · reason) @@ -237,7 +278,7 @@ where /-- Updates `map[k]` using the current set of `owned` variables. -/ updateParamMap (k : ParamMap.Key) : InferM Unit := do - if let some ps := (← get).paramMap[k]? then + if let some ps := (← get).paramMap.map[k]? then -- This is to ensure linearity over ps in the following code, if you know how to make this -- linear in a nice fashion please make a PR modify fun s => { s with paramMap := s.paramMap.erase k } @@ -252,7 +293,7 @@ where modify fun s => { s with paramMap := s.paramMap.insert k ps } getParamInfo (k : ParamMap.Key) : InferM (Array (Param .impure)) := do - match (← get).paramMap[k]? with + match (← get).paramMap.map[k]? with | some ps => return ps | none => let .decl fn := k | unreachable!