Compare commits

...

3 Commits

Author SHA1 Message Date
Leonardo de Moura
c6885094d9 test: 2025-12-23 13:59:49 -08:00
Leonardo de Moura
7b49793b0a test: PersistentArray.forM with start position 2025-12-23 13:59:49 -08:00
Leonardo de Moura
f422c53884 feat: PersistentArray.forM with initial position 2025-12-23 13:59:49 -08:00
2 changed files with 46 additions and 1 deletions

View File

@@ -277,9 +277,25 @@ instance [Monad m] : ForIn m (PersistentArray α) α where
| node cs => cs.forM (fun c => forMAux f c)
| leaf vs => vs.forM f
@[specialize] def forM (t : PersistentArray α) (f : α m PUnit) : m PUnit :=
@[specialize] def forMFrom0 (t : PersistentArray α) (f : α m PUnit) : m PUnit :=
forMAux f t.root *> t.tail.forM f
@[specialize] private partial def forFromMAux (f : α m PUnit) : PersistentArrayNode α USize USize m PUnit
| node cs, i, shift => do
let j := (div2Shift i shift).toNat
forFromMAux f cs[j]! (mod2Shift i shift) (shift - initShift)
cs.forM (start := j+1) (forMAux f)
| leaf vs, i, _ => vs.forM (start := i.toNat) f
@[specialize] def forM (t : PersistentArray α) (f : α m PUnit) (start : Nat := 0) : m PUnit := do
if start == 0 then
forMFrom0 t f
else if start >= t.tailOff then
t.tail.forM (start := start - t.tailOff) f
else do
forFromMAux f t.root (USize.ofNat start) t.shift
t.tail.forM f
end
@[inline] def foldl (t : PersistentArray α) (f : β α β) (init : β) (start : Nat := 0) : β :=

View File

@@ -0,0 +1,29 @@
import Lean.Data.PersistentArray
/-!
Test `PersistentArray.forM` with nonzero start position.
-/
def mk (n : Nat) : Lean.PersistentArray Nat :=
List.range n |>.toPArray'
def sum1 (start : Nat) (s : List Nat) : Nat :=
let (_, s) := StateT.run (m := Id) (s.drop start |>.forM fun val => modify (· + val)) 0
s
def sum2 (start : Nat) (s : Lean.PArray Nat) : Nat :=
let (_, s) := StateT.run (m := Id) (s.forM (start := start) (fun val => modify (· + val))) 0
s
def check (s₁ : List Nat) : IO Unit := do
let s₂ := s₁.toPArray'
let n := s₂.size
for i in *...n do
unless sum1 i s₁ == sum2 i s₂ do
throw <| .userError "failed"
IO.println "ok"
#eval check (List.range 10)
#eval check (List.range 0)
#eval check (List.range 2000)
#eval check (List.replicate 1000 1)
#eval check (List.replicate 10 2)