Compare commits

...

5 Commits

Author SHA1 Message Date
Mario Carneiro
c12204144c double lookup 2024-04-23 23:33:50 -04:00
Mario Carneiro
3fa93531e4 unboxed tuples 2024-04-23 23:03:32 -04:00
Mario Carneiro
91441a4d60 defunctionalization 2024-04-23 21:13:53 -04:00
Mario Carneiro
9a7a98d4b6 see if CPS helps 2024-04-23 20:05:36 -04:00
Scott Morrison
5f3404ae53 chore: fix (PersistentHashMap.insert k v).size
correctier

more thorough test

correction
2024-04-23 20:31:23 +10:00
2 changed files with 65 additions and 23 deletions

View File

@@ -62,6 +62,8 @@ inductive IsCollisionNode : Node α β → Prop where
abbrev CollisionNode (α β) := { n : Node α β // IsCollisionNode n }
instance : Inhabited (CollisionNode α β) := Node.collision #[] #[] rfl, .mk _ _ _
inductive IsEntriesNode : Node α β Prop where
| mk (entries : Array (Entry α β (Node α β))) : IsEntriesNode (Node.entries entries)
@@ -87,6 +89,7 @@ partial def insertAtCollisionNodeAux [BEq α] : CollisionNode α β → Nat →
Node.collision (keys.push k) (vals.push v) (size_push heq k v), IsCollisionNode.mk _ _ _
| Node.entries _, h, _, _, _ => nomatch h
/-- Inserts a key-value pair into a CollisionNode, also returning whether an existing value was replaced. -/
def insertAtCollisionNode [BEq α] : CollisionNode α β α β CollisionNode α β :=
fun n k v => insertAtCollisionNodeAux n 0 k v
@@ -101,36 +104,41 @@ def mkCollisionNode (k₁ : α) (v₁ : β) (k₂ : α) (v₂ : β) : Node α β
let vs := (vs.push v₁).push v₂
Node.collision ks vs rfl
/--
Inserts a key-value pair into a node, returning the new node,
along with a `Bool` indicating whether an existing value was replaced.
-/
partial def insertAux [BEq α] [Hashable α] : Node α β USize USize α β Node α β
| Node.collision keys vals heq, _, depth, k, v =>
let newNode := insertAtCollisionNode Node.collision keys vals heq, IsCollisionNode.mk _ _ _ k v
if depth >= maxDepth || getCollisionNodeSize newNode < maxCollisions then newNode.val
else match newNode with
| Node.entries _, h => nomatch h
| Node.collision keys vals heq, _ =>
let rec traverse (i : Nat) (entries : Node α β) : Node α β :=
if h : i < keys.size then
let k := keys[i]
have : i < vals.size := heq h
let v := vals[i]
let h := hash k |>.toUSize
let h := div2Shift h (shift * (depth - 1))
traverse (i+1) (insertAux entries h depth k v)
else
entries
traverse 0 mkEmptyEntries
else
let Node.collision keys vals heq, _ := newNode
let rec traverse (i : Nat) (entries : Node α β) : Node α β :=
if h : i < keys.size then
let k := keys[i]
have : i < vals.size := heq h
let v := vals[i]
let h := hash k |>.toUSize
let h := div2Shift h (shift * (depth - 1))
traverse (i+1) (insertAux entries h depth k v)
else
entries
traverse 0 mkEmptyEntries
| Node.entries entries, h, depth, k, v =>
let j := (mod2Shift h shift).toNat
Node.entries $ entries.modify j fun entry =>
match entry with
| Entry.null => Entry.entry k v
| Entry.ref node => Entry.ref $ insertAux node (div2Shift h shift) (depth+1) k v
-- We can't use `entries.modify` here, as we need to return `replaced`.
-- To ensure linearity, we use `swapAt!`.
let (entry, entries') := entries.swapAt! j .null
match entry with
| Entry.null =>
Node.entries $ entries'.set! j (.entry k v)
| Entry.ref node =>
let newNode := insertAux node (div2Shift h shift) (depth+1) k v
Node.entries $ entries'.set! j (.ref newNode)
| Entry.entry k' v' =>
if k == k' then Entry.entry k v
else Entry.ref $ mkCollisionNode k' v' k v
def insert {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β α β PersistentHashMap α β
| { root := n, size := sz }, k, v => { root := insertAux n (hash k |>.toUSize) 1 k v, size := sz + 1 }
if k == k' then Node.entries $ entries'.set! j (.entry k v)
else Node.entries $ entries'.set! j (.ref <| mkCollisionNode k' v' k v)
partial def findAtAux [BEq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) (i : Nat) (k : α) : Option β :=
if h : i < keys.size then
@@ -204,6 +212,13 @@ partial def containsAux [BEq α] : Node α β → USize → α → Bool
def contains [BEq α] [Hashable α] : PersistentHashMap α β α Bool
| { root := n, .. }, k => containsAux n (hash k |>.toUSize) k
def insert {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β α β PersistentHashMap α β
| { root := n, size := sz }, k, v =>
let hash := hash k |>.toUSize
let replaced := containsAux n hash k
let node := insertAux n hash 1 k v
{ root := node, size := if replaced then sz else sz + 1 }
partial def isUnaryEntries (a : Array (Entry α β (Node α β))) (i : Nat) (acc : Option (α × β)) : Option (α × β) :=
if h : i < a.size then
match a[i] with

27
tests/lean/run/3029.lean Normal file
View File

@@ -0,0 +1,27 @@
import Lean.Data.PersistentHashMap
open Lean
example : ((PersistentHashMap.empty : PersistentHashMap Nat Nat)
|> (·.insert 1 1)
|> (·.insert 1 1)
|> (·.size)) = 1 := by native_decide
example : ((PersistentHashMap.empty : PersistentHashMap Nat Nat)
|> (·.insert 1 1)
|> (·.insert 2 1)
|> (·.size)) = 2 := by native_decide
/-- Inserts the pairs (i * n, i * n) for all `i < k`. -/
def insertPairs (k n : Nat) (m : PersistentHashMap Nat Nat) : PersistentHashMap Nat Nat :=
(List.range k).foldl (init := m) fun m i => m.insert (n * i) (n * i)
/-- Inserts `0, 1, 2, 3, ..., 2^(j-1)`, and then `0, 2, 4, ..., 2^(j-1)`, and so on. -/
def insertPows (j : Nat) (m : PersistentHashMap Nat Nat) : PersistentHashMap Nat Nat :=
(List.range j).foldl (init := m) fun m i => insertPairs (2^(j-i)) (2^i) m
/-- As for `insertPows`, but backwards. -/
def insertPows' (j : Nat) (m : PersistentHashMap Nat Nat) : PersistentHashMap Nat Nat :=
(List.range j).reverse.foldl (init := m) fun m i => insertPairs (2^(j-i)) (2^i) m
example : (insertPows 12 (PersistentHashMap.empty : PersistentHashMap Nat Nat)).size = 4096 := by native_decide
example : (insertPows' 12 (PersistentHashMap.empty : PersistentHashMap Nat Nat)).size = 4096 := by native_decide