Compare commits

...

2 Commits

Author SHA1 Message Date
Leonardo de Moura
d198a7a0b4 chore: typo 2024-07-17 19:09:09 -07:00
Leonardo de Moura
0070b06e10 perf: ensure Expr.replaceExpr preserve DAG structure in Exprs 2024-07-17 18:54:18 -07:00
2 changed files with 40 additions and 50 deletions

View File

@@ -6,6 +6,7 @@ Authors: Leonardo de Moura
prelude
import Init.Data.Hashable
import Lean.Data.HashSet
import Lean.Data.HashMap
namespace Lean
@@ -33,4 +34,22 @@ unsafe abbrev PtrSet.insert (s : PtrSet α) (a : α) : PtrSet α :=
unsafe abbrev PtrSet.contains (s : PtrSet α) (a : α) : Bool :=
HashSet.contains s { value := a }
/--
Map of pointers. It is a low-level auxiliary datastructure used for traversing DAGs.
-/
unsafe def PtrMap (α : Type) (β : Type) :=
HashMap (Ptr α) β
unsafe def mkPtrMap {α β : Type} (capacity : Nat := 64) : PtrMap α β :=
mkHashMap capacity
unsafe abbrev PtrMap.insert (s : PtrMap α β) (a : α) (b : β) : PtrMap α β :=
HashMap.insert s { value := a } b
unsafe abbrev PtrMap.contains (s : PtrMap α β) (a : α) : Bool :=
HashMap.contains s { value := a }
unsafe abbrev PtrMap.find? (s : PtrMap α β) (a : α) : Option β :=
HashMap.find? s { value := a }
end Lean

View File

@@ -5,74 +5,45 @@ Authors: Leonardo de Moura, Gabriel Ebner, Sebastian Ullrich
-/
prelude
import Lean.Expr
import Lean.Util.PtrSet
namespace Lean
namespace Expr
namespace ReplaceImpl
structure Cache where
size : USize
-- First `size` elements are the keys.
-- Second `size` elements are the results.
keysResults : Array NonScalar -- Either Expr or Unit (disjoint memory representation)
unsafe def Cache.new (e : Expr) : Cache :=
-- scale size with approximate number of subterms up to 8k
-- make sure size is coprime with power of two for collision avoidance
let size := (1 <<< min (max e.approxDepth.toUSize 1) 13) - 1
{ size, keysResults := mkArray (2 * size).toNat (unsafeCast ()) }
unsafe abbrev ReplaceM := StateM (PtrMap Expr Expr)
@[inline]
unsafe def Cache.keyIdx (c : Cache) (key : Expr) : USize :=
ptrAddrUnsafe key % c.size
@[inline]
unsafe def Cache.resultIdx (c : Cache) (key : Expr) : USize :=
c.keyIdx key + c.size
@[inline]
unsafe def Cache.hasResultFor (c : Cache) (key : Expr) : Bool :=
have : (c.keyIdx key).toNat < c.keysResults.size := lcProof
ptrEq (unsafeCast key) c.keysResults[c.keyIdx key]
@[inline]
unsafe def Cache.getResultFor (c : Cache) (key : Expr) : Expr :=
have : (c.resultIdx key).toNat < c.keysResults.size := lcProof
unsafeCast c.keysResults[c.resultIdx key]
unsafe def Cache.store (c : Cache) (key result : Expr) : Cache :=
{ c with keysResults := c.keysResults
|>.uset (c.keyIdx key) (unsafeCast key) lcProof
|>.uset (c.resultIdx key) (unsafeCast result) lcProof }
abbrev ReplaceM := StateM Cache
@[inline]
unsafe def cache (key : Expr) (result : Expr) : ReplaceM Expr := do
modify (·.store key result)
unsafe def cache (key : Expr) (exclusive : Bool) (result : Expr) : ReplaceM Expr := do
unless exclusive do
modify (·.insert key result)
pure result
@[specialize]
unsafe def replaceUnsafeM (f? : Expr Option Expr) (e : Expr) : ReplaceM Expr := do
let rec @[specialize] visit (e : Expr) := do
if ( get).hasResultFor e then
return ( get).getResultFor e
else match f? e with
| some eNew => cache e eNew
-- TODO: We need better control over RC operations to ensure
-- the following (unsafe) optimization is correctly applied.
let excl := isExclusiveUnsafe e
unless excl do
if let some result := ( get).find? e then
return result
match f? e with
| some eNew => cache e excl eNew
| none => match e with
| Expr.forallE _ d b _ => cache e <| e.updateForallE! ( visit d) ( visit b)
| Expr.lam _ d b _ => cache e <| e.updateLambdaE! ( visit d) ( visit b)
| Expr.mdata _ b => cache e <| e.updateMData! ( visit b)
| Expr.letE _ t v b _ => cache e <| e.updateLet! ( visit t) ( visit v) ( visit b)
| Expr.app f a => cache e <| e.updateApp! ( visit f) ( visit a)
| Expr.proj _ _ b => cache e <| e.updateProj! ( visit b)
| e => pure e
| .forallE _ d b _ => cache e excl <| e.updateForallE! ( visit d) ( visit b)
| .lam _ d b _ => cache e excl <| e.updateLambdaE! ( visit d) ( visit b)
| .mdata _ b => cache e excl <| e.updateMData! ( visit b)
| .letE _ t v b _ => cache e excl <| e.updateLet! ( visit t) ( visit v) ( visit b)
| .app f a => cache e excl <| e.updateApp! ( visit f) ( visit a)
| .proj _ _ b => cache e excl <| e.updateProj! ( visit b)
| e => return e
visit e
@[inline]
unsafe def replaceUnsafe (f? : Expr Option Expr) (e : Expr) : Expr :=
(replaceUnsafeM f? e).run' (Cache.new e)
(replaceUnsafeM f? e).run' mkPtrMap
end ReplaceImpl