Compare commits

...

1 Commits

Author SHA1 Message Date
Leonardo de Moura
b8faf416af perf: add PersistentHashMap.findKeyD and PersistentHashSet.findD
This PR implements `PersistentHashMap.findKeyD` and
`PersistentHashSet.findD`. The motivation is avoid two memory
allocations (`Prod.mk` and `Option.some`) when the collections
contains the key.
2026-01-05 11:53:18 -08:00
3 changed files with 52 additions and 17 deletions

View File

@@ -193,6 +193,28 @@ partial def findEntryAux [BEq α] : Node α β → USize → α → Option (α
def findEntry? {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β α Option (α × β)
| { root }, k => findEntryAux root (hash k |>.toUSize) k
partial def findKeyDAtAux [BEq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) (i : Nat) (k : α) (k₀ : α) : α :=
if h : i < keys.size then
let k' := keys[i]
if k == k' then k'
else findKeyDAtAux keys vals heq (i+1) k k₀
else k₀
partial def findKeyDAux [BEq α] : Node α β USize α α α
| .entries entries, h, k, k₀ =>
let j := (mod2Shift h shift).toNat
match entries[j]! with
| .null => k₀
| .ref node => findKeyDAux node (div2Shift h shift) k k₀
| .entry k' _ => if k == k' then k' else k₀
| .collision keys vals heq, _, k, k₀ => findKeyDAtAux keys vals heq 0 k k₀
/--
A more efficient `m.findEntry? a |>.map (·.1) |>.getD a₀`
-/
@[inline] def findKeyD {_ : BEq α} {_ : Hashable α} (m : PersistentHashMap α β) (a : α) (a₀ : α) : α :=
findKeyDAux m.root (hash a |>.toUSize) a a₀
partial def containsAtAux [BEq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) (i : Nat) (k : α) : Bool :=
if h : i < keys.size then
let k' := keys[i]

View File

@@ -45,6 +45,9 @@ variable {_ : BEq α} {_ : Hashable α}
| some (a, _) => some a
| none => none
@[inline] def findD (s : PersistentHashSet α) (a : α) (a₀ : α) : α :=
s.set.findKeyD a a₀
@[inline] def contains (s : PersistentHashSet α) (a : α) : Bool :=
s.set.contains a

View File

@@ -69,20 +69,24 @@ private structure State where
private abbrev M := StateM State
private def dummy : AlphaKey := { expr := mkConst `__dummy__}
private def save (e : Expr) (r : Expr) : M Expr := do
if let some r := ( get).set.find? { expr := r } then
let r := r.expr
modify fun { set, map } => {
set
map := map.insert { expr := e } r
}
return r
else
let prev := ( get).set.findD { expr := r } dummy
if isSameExpr prev.expr dummy.expr then
-- `r` is new
modify fun { set, map } => {
set := set.insert { expr := r }
map := map.insert { expr := e } r |>.insert { expr := r } r
}
return r
else
let r := prev.expr
modify fun { set, map } => {
set
map := map.insert { expr := e } r
}
return r
private abbrev visit (e : Expr) (k : M Expr) : M Expr := do
/-
@@ -96,10 +100,12 @@ private abbrev visit (e : Expr) (k : M Expr) : M Expr := do
**Note**: The input may contain sub-expressions that have already been processed and are
already maximally shared.
-/
if let some r := ( get).set.find? { expr := e } then
return r.expr
else
let prev := ( get).set.findD { expr := e } dummy
if isSameExpr prev.expr dummy.expr then
-- `e` has not been hash-consed before
save e ( k)
else
return prev.expr
private def go (e : Expr) : M Expr := do
match e with
@@ -125,17 +131,21 @@ private def go (e : Expr) : M Expr := do
(e, set)
private def saveInc (e : Expr) : AlphaShareCommonM Expr := do
if let some r := ( get).set.find? { expr := e } then
return r.expr
else
let prev := ( get).set.findD { expr := e } dummy
if isSameExpr prev.expr dummy.expr then
-- `e` is new
modify fun { set := set } => { set := set.insert { expr := e } }
return e
else
return prev.expr
@[inline] private def visitInc (e : Expr) (k : AlphaShareCommonM Expr) : AlphaShareCommonM Expr := do
if let some r := ( get).set.find? { expr := e } then
return r.expr
else
let prev := ( get).set.findD { expr := e } dummy
if isSameExpr prev.expr dummy.expr then
-- `e` has now been cached before
saveInc ( k)
else
return prev.expr
/--
Incremental variant of `shareCommonAlpha` for expressions constructed from already-shared subterms.