mirror of
https://github.com/leanprover/lean4.git
synced 2026-03-17 18:34:06 +00:00
Compare commits
20 Commits
57df23f27e
...
sofia/sync
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
73e799f773 | ||
|
|
06352ca77e | ||
|
|
18e0fd073f | ||
|
|
6304cf6393 | ||
|
|
5068ade9a2 | ||
|
|
c3353e98c7 | ||
|
|
21e189b408 | ||
|
|
bfa3d30b20 | ||
|
|
00eda8a6d8 | ||
|
|
31bb12b529 | ||
|
|
898150d4f8 | ||
|
|
dda0c18b92 | ||
|
|
a4c87e0e90 | ||
|
|
d2698ca637 | ||
|
|
9ea7ebd3db | ||
|
|
aed8293a0a | ||
|
|
380698aa1a | ||
|
|
67e7690eaa | ||
|
|
b8449007db | ||
|
|
4d477e7784 |
@@ -13,3 +13,6 @@ public import Std.Sync.RecursiveMutex
|
||||
public import Std.Sync.Barrier
|
||||
public import Std.Sync.SharedMutex
|
||||
public import Std.Sync.Notify
|
||||
public import Std.Sync.Broadcast
|
||||
|
||||
@[expose] public section
|
||||
|
||||
644
src/Std/Sync/Broadcast.lean
Normal file
644
src/Std/Sync/Broadcast.lean
Normal file
@@ -0,0 +1,644 @@
|
||||
/-
|
||||
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Sofia Rodrigues
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Std.Data
|
||||
public import Init.System.Promise
|
||||
public import Init.Data.Queue
|
||||
public import Init.Data.Vector
|
||||
public import Std.Sync.Mutex
|
||||
public import Std.Internal.Async.Select
|
||||
public import Std.Internal.Async.IO
|
||||
|
||||
public section
|
||||
|
||||
namespace Std
|
||||
|
||||
open Std.Internal.Async.IO
|
||||
open Std.Internal.IO.Async
|
||||
|
||||
/-!
|
||||
The `Std.Sync.Broadcast` module implements a broadcasting primitive for sending values
|
||||
to multiple consumers. It maintains a queue of values and supports both synchronous
|
||||
and asynchronous waiting.
|
||||
|
||||
This module is heavily inspired by `Std.Sync.Channel` as well as
|
||||
[tokio’s broadcast implementation](https://github.com/tokio-rs/tokio/blob/master/tokio/src/sync/broadcast.rs).
|
||||
-/
|
||||
|
||||
/--
|
||||
Errors that may be thrown while interacting with the broadcast channel API.
|
||||
-/
|
||||
inductive Broadcast.Error where
|
||||
/--
|
||||
Tried to send to a closed broadcast channel.
|
||||
-/
|
||||
| closed
|
||||
|
||||
/--
|
||||
Tried to close an already closed broadcast channel.
|
||||
-/
|
||||
| alreadyClosed
|
||||
|
||||
/--
|
||||
Tried to unsubscribe a channel that already is not part of it.
|
||||
-/
|
||||
| notSubscribed
|
||||
|
||||
deriving Repr, DecidableEq, Hashable
|
||||
|
||||
instance instToStringBroadcastError : ToString Broadcast.Error where
|
||||
toString
|
||||
| .closed => "attempted to send on an already closed channel"
|
||||
| .alreadyClosed => "attempted to close an already closed broadcast channel"
|
||||
| .notSubscribed => "receiver not subscribed in a broadcast channel"
|
||||
|
||||
instance instMonadLiftBroadcastIO : MonadLift (EIO Broadcast.Error) IO where
|
||||
monadLift x := EIO.toIO (.userError <| toString ·) x
|
||||
|
||||
private structure Broadcast.Consumer (α : Type) where
|
||||
promise : IO.Promise Bool
|
||||
waiter : Option (Internal.IO.Async.Waiter (Option α))
|
||||
|
||||
private def Broadcast.Consumer.resolve (c : Broadcast.Consumer α) (b : Bool) : BaseIO Unit :=
|
||||
c.promise.resolve b
|
||||
|
||||
private structure Slot (α : Type) where
|
||||
value : Option α
|
||||
pos : Nat
|
||||
remaining : Nat
|
||||
deriving Inhabited, Repr
|
||||
|
||||
private structure Bounded.State (α : Type) where
|
||||
/--
|
||||
Queue of producers blocked waiting for buffer space to become available.
|
||||
-/
|
||||
producers : Std.Queue (IO.Promise Bool)
|
||||
|
||||
/--
|
||||
Queue of consumers blocked waiting for new messages to be broadcast.
|
||||
-/
|
||||
waiters : Std.Queue (Broadcast.Consumer α)
|
||||
|
||||
/--
|
||||
Maximum number of messages that can be buffered before producers block.
|
||||
-/
|
||||
capacity : { x : Nat // 0 < x }
|
||||
|
||||
/--
|
||||
Current number of messages stored in the circular buffer.
|
||||
-/
|
||||
size : Nat
|
||||
|
||||
/--
|
||||
Circular buffer storing broadcast messages accessible to all receivers.
|
||||
-/
|
||||
buffer : Vector (IO.Ref (Slot α)) capacity
|
||||
|
||||
/--
|
||||
Index where the next message will be written in the circular buffer.
|
||||
-/
|
||||
write : Fin capacity
|
||||
|
||||
/--
|
||||
Index of the oldest message still available for lagging receivers.
|
||||
-/
|
||||
read : Fin capacity
|
||||
|
||||
/--
|
||||
Maps receiver IDs to their current position in the message sequence.
|
||||
-/
|
||||
receivers : Std.TreeMap Nat Nat
|
||||
|
||||
/--
|
||||
Counter for assigning unique IDs to new receivers.
|
||||
-/
|
||||
nextId : Nat
|
||||
|
||||
/--
|
||||
Whether the channel has been closed, preventing new messages.
|
||||
-/
|
||||
closed : Bool
|
||||
|
||||
/--
|
||||
Global message sequence number for the next message to be sent.
|
||||
-/
|
||||
pos : Nat
|
||||
|
||||
/--
|
||||
A channel that can create `Bounded.Receiver` and send messages.
|
||||
-/
|
||||
private structure Bounded (α : Type) where
|
||||
state : Mutex (Bounded.State α)
|
||||
|
||||
/--
|
||||
A channel that can receive messages from `Bounded`.
|
||||
-/
|
||||
private structure Bounded.Receiver (α : Type) where
|
||||
state : Mutex (Bounded.State α)
|
||||
id : Nat
|
||||
|
||||
namespace Bounded
|
||||
|
||||
/--
|
||||
Creates a new `Bounded` channel.
|
||||
-/
|
||||
private def new {α} (capacity : Nat := 16) (h : capacity > 0 := by decide) : BaseIO (Bounded α) := do
|
||||
return { state := ← Mutex.new {
|
||||
producers := .empty
|
||||
waiters := .empty
|
||||
buffer := ← Vector.mapM (fun _ => IO.mkRef { pos := 0, value := none, remaining := 0 }) (Vector.replicate capacity ())
|
||||
receivers := .empty
|
||||
nextId := 0
|
||||
closed := false
|
||||
pos := 0
|
||||
size := 0
|
||||
read := ⟨0, h⟩
|
||||
write := ⟨0, h⟩
|
||||
capacity := ⟨capacity, h⟩
|
||||
}}
|
||||
|
||||
/--
|
||||
Subscribes a new `Receiver` in the `Bounded` channel.
|
||||
-/
|
||||
private def subscribe (bd : Bounded α) : IO (Receiver α) := do
|
||||
let id ← bd.state.atomically do
|
||||
modifyGet fun state =>
|
||||
let id := state.nextId
|
||||
(id, { state with nextId := id + 1, receivers := state.receivers.insert id state.pos })
|
||||
return { state := bd.state, id }
|
||||
|
||||
/--
|
||||
Returns true if the buffer contains no elements.
|
||||
-/
|
||||
private def isEmpty [Monad m] [MonadLiftT (ST IO.RealWorld) m] : AtomicT (Bounded.State α) m Bool := do
|
||||
let mut st ← get
|
||||
return st.size = 0
|
||||
|
||||
/--
|
||||
Returns true if the buffer is at full capacity.
|
||||
-/
|
||||
private def isFull : AtomicT (Bounded.State α) BaseIO Bool := do
|
||||
let mut st ← get
|
||||
return st.size ≥ st.capacity
|
||||
|
||||
/--
|
||||
Enqueues an element to the back of the circular buffer.
|
||||
If the buffer is full, the oldest element (at front) is overwritten.
|
||||
-/
|
||||
private def enqueue (value : α) (st : Bounded.State α) : BaseIO (Bounded.State α) := do
|
||||
let tailRef := st.buffer.get st.write
|
||||
|
||||
tailRef.set { pos := st.pos, remaining := st.receivers.size, value := some value }
|
||||
let write : Fin st.capacity := @Fin.ofNat _ ⟨Nat.ne_zero_iff_zero_lt.mpr st.capacity.property⟩ (st.write + 1)
|
||||
let size := st.size + 1
|
||||
let pos := st.pos + 1
|
||||
|
||||
return { st with write, size, pos }
|
||||
|
||||
/--
|
||||
Dequeues an element from the front of the circular buffer.
|
||||
Returns none if the buffer is empty.
|
||||
-/
|
||||
private def dequeue (st: State α) : State α :=
|
||||
let size := st.size - 1
|
||||
let read : Fin st.capacity := @Fin.ofNat _ ⟨Nat.ne_zero_iff_zero_lt.mpr st.capacity.property⟩ (st.read + 1)
|
||||
|
||||
{ st with read, size }
|
||||
|
||||
/--
|
||||
Peeks at the element at the front of the buffer without removing it.
|
||||
Returns none if the buffer is empty.
|
||||
-/
|
||||
private def getSlot
|
||||
[Monad m] [MonadLiftT (ST IO.RealWorld) m] (place : Nat) :
|
||||
AtomicT (Bounded.State α) m (IO.Ref (Slot α)) := do
|
||||
let st ← get
|
||||
let idx := (@Fin.ofNat st.capacity ⟨Nat.ne_zero_of_lt st.capacity.property⟩ place)
|
||||
return st.buffer.get idx
|
||||
|
||||
/--
|
||||
Subscribes a new `Receiver` in the `Bounded` channel.
|
||||
-/
|
||||
private def trySend' (v : α) : AtomicT (Bounded.State α) BaseIO (Option Nat) := do
|
||||
if ← isFull then
|
||||
return none
|
||||
else
|
||||
let st ← enqueue v (← get)
|
||||
let waiters := st.waiters
|
||||
set ({ st with waiters := ∅ })
|
||||
|
||||
for consumer in waiters.toArray do
|
||||
discard <| consumer.resolve true
|
||||
|
||||
return some st.receivers.size
|
||||
|
||||
private def trySend (ch : Bounded α) (v : α) : BaseIO (Option Nat) := do
|
||||
ch.state.atomically do
|
||||
if (← get).closed then
|
||||
return none
|
||||
else if (← get).receivers.isEmpty then
|
||||
return (some 0)
|
||||
else
|
||||
trySend' v
|
||||
|
||||
private partial def send (ch : Bounded α) (v : α) : BaseIO (Task (Except Broadcast.Error Nat)) := do
|
||||
ch.state.atomically do
|
||||
if (← get).closed then
|
||||
return .pure <| .error .closed
|
||||
else if (← get).receivers.isEmpty then
|
||||
return .pure <| .ok 0
|
||||
else if let some receivers ← trySend' v then
|
||||
return .pure <| .ok receivers
|
||||
else
|
||||
let promise ← IO.Promise.new
|
||||
modify fun st => { st with producers := st.producers.enqueue promise }
|
||||
|
||||
BaseIO.bindTask promise.result? fun res => do
|
||||
if res.getD false then
|
||||
Bounded.send ch v
|
||||
else
|
||||
return .pure <| .error .closed
|
||||
|
||||
private def close (ch : Bounded α) : EIO Broadcast.Error Unit := do
|
||||
ch.state.atomically do
|
||||
let st ← get
|
||||
|
||||
if st.closed then
|
||||
throw .alreadyClosed
|
||||
|
||||
for consumer in st.waiters.toArray do
|
||||
consumer.resolve false
|
||||
|
||||
set { st with waiters := ∅, closed := true }
|
||||
return ()
|
||||
|
||||
private def isClosed (ch : Bounded α) : BaseIO Bool :=
|
||||
ch.state.atomically do
|
||||
return (← get).closed
|
||||
|
||||
namespace Receiver
|
||||
|
||||
private def getSlotValue [Monad m] [MonadLiftT (ST IO.RealWorld) m]
|
||||
(slot : IO.Ref (Slot α)) (next : Nat) : AtomicT (Bounded.State α) m (Option α × Bool) :=
|
||||
slot.modifyGet fun slot =>
|
||||
if next != slot.pos then
|
||||
((none, false), slot)
|
||||
else if slot.remaining == 1 then
|
||||
((slot.value, true), { slot with value := none, remaining := 0 })
|
||||
else
|
||||
((slot.value, false), { slot with remaining := slot.remaining - 1 })
|
||||
|
||||
private def getValueByPosition [Monad m] [MonadLiftT (ST IO.RealWorld) m]
|
||||
[MonadLiftT BaseIO m] (next : Nat) : AtomicT (Bounded.State α) m (Option α) := do
|
||||
let mut st ← get
|
||||
|
||||
if ← isEmpty then
|
||||
return none
|
||||
|
||||
let id := next % st.capacity
|
||||
let slot ← getSlot id
|
||||
|
||||
let (some val, shouldDequeue) ← getSlotValue slot next
|
||||
| return none
|
||||
|
||||
if shouldDequeue then
|
||||
st := dequeue st
|
||||
|
||||
if let some (producer, producers) := st.producers.dequeue? then
|
||||
producer.resolve true
|
||||
st := { st with producers }
|
||||
|
||||
set st
|
||||
return some val
|
||||
|
||||
/--
|
||||
Unsubscribes a `Receiver` from the `Bounded` channel.
|
||||
-/
|
||||
private def unsubscribe (bd : Bounded.Receiver α) : IO Unit := do
|
||||
let id ← bd.state.atomically do
|
||||
let st ← get
|
||||
|
||||
let some next := st.receivers.get? bd.id
|
||||
| return Except.error Broadcast.Error.notSubscribed
|
||||
|
||||
let mut currentSt := st
|
||||
let mut currentNext := next
|
||||
|
||||
while currentNext < currentSt.pos ∧ currentSt.size > 0 do
|
||||
let some _val ← getValueByPosition currentNext | break
|
||||
|
||||
currentSt ← get
|
||||
currentNext := currentNext + 1
|
||||
|
||||
set { currentSt with receivers := currentSt.receivers.erase bd.id }
|
||||
|
||||
pure <| .ok ()
|
||||
|
||||
match id with
|
||||
| .error res => throw (.userError (toString res))
|
||||
| .ok _ => pure ()
|
||||
|
||||
private def tryRecv'
|
||||
[Monad m] [MonadLiftT (ST IO.RealWorld) m] [MonadLiftT BaseIO m]
|
||||
(receiverId : Nat) : AtomicT (Bounded.State α) m (Option α) := do
|
||||
let st ← get
|
||||
|
||||
let some next := st.receivers[receiverId]?
|
||||
| return none
|
||||
|
||||
if let some val ← getValueByPosition next then
|
||||
modify ({ · with receivers := st.receivers.modify receiverId (· + 1) })
|
||||
return some val
|
||||
else
|
||||
return none
|
||||
|
||||
private def tryRecv (ch : Bounded.Receiver α) : BaseIO (Option α) :=
|
||||
ch.state.atomically (tryRecv' ch.id)
|
||||
|
||||
private partial def recv (ch : Bounded.Receiver α) : BaseIO (Task (Option α)) := do
|
||||
ch.state.atomically do
|
||||
if ¬ (← get).receivers.contains ch.id then
|
||||
return .pure none
|
||||
else if let some val ← tryRecv' ch.id then
|
||||
return .pure <| some val
|
||||
else if (← get).closed then
|
||||
return .pure none
|
||||
else
|
||||
let promise ← IO.Promise.new
|
||||
modify fun st => { st with waiters := st.waiters.enqueue ⟨promise, none⟩ }
|
||||
BaseIO.bindTask promise.result? fun res => do
|
||||
if res.getD false then
|
||||
Bounded.Receiver.recv ch
|
||||
else
|
||||
return .pure none
|
||||
|
||||
private partial def forAsync
|
||||
(f : α → BaseIO Unit) (ch : Bounded.Receiver α)
|
||||
(prio : Task.Priority := .default) :
|
||||
BaseIO (Task Unit) := do
|
||||
BaseIO.bindTask (prio := prio) (← ch.recv) fun
|
||||
| none => return .pure ()
|
||||
| some v => do f v; forAsync f ch prio
|
||||
|
||||
@[inline]
|
||||
private def recvReady'
|
||||
[Monad m] [MonadLiftT (ST IO.RealWorld) m] [MonadLiftT IO m] [MonadLiftT BaseIO m]
|
||||
(receiverId : Nat) : AtomicT (State α) m Bool := do
|
||||
let st ← get
|
||||
|
||||
if st.closed then
|
||||
return true
|
||||
|
||||
let some next := st.receivers.get? receiverId
|
||||
| return false
|
||||
|
||||
if st.size = 0 then
|
||||
return false
|
||||
else
|
||||
let id := next % st.capacity
|
||||
let slot ← getSlot id
|
||||
let slotVal ← slot.get
|
||||
return slotVal.pos = next
|
||||
|
||||
open Internal.IO.Async in
|
||||
private partial def recvSelector (ch : Bounded.Receiver α) : Selector (Option α) where
|
||||
tryFn := do
|
||||
ch.state.atomically do
|
||||
if ← recvReady' ch.id then
|
||||
let val ← tryRecv' ch.id
|
||||
return some val
|
||||
else
|
||||
return none
|
||||
|
||||
registerFn waiter := registerAux ch waiter
|
||||
|
||||
unregisterFn := do
|
||||
ch.state.atomically do
|
||||
let st ← get
|
||||
let waiters ← st.waiters.filterM fun c => do
|
||||
match c.waiter with
|
||||
| some waiter => return !(← waiter.checkFinished)
|
||||
| none => return true
|
||||
|
||||
set { st with waiters }
|
||||
where
|
||||
registerAux (ch : Bounded.Receiver α) (waiter : Waiter (Option α)) : IO Unit := do
|
||||
ch.state.atomically do
|
||||
if ← recvReady' ch.id then
|
||||
let lose := do
|
||||
let st ← get
|
||||
if let some (waiter, waiters) := st.waiters.dequeue? then
|
||||
waiter.resolve true
|
||||
set { st with waiters }
|
||||
let win promise := do
|
||||
promise.resolve (.ok (← tryRecv' ch.id))
|
||||
|
||||
waiter.race lose win
|
||||
else
|
||||
let promise ← IO.Promise.new
|
||||
modify fun st => { st with waiters := st.waiters.enqueue ⟨promise, some waiter⟩ }
|
||||
|
||||
IO.chainTask promise.result? fun res? => do
|
||||
match res? with
|
||||
| none => return ()
|
||||
| some res =>
|
||||
if res then
|
||||
registerAux ch waiter
|
||||
else
|
||||
let lose := return ()
|
||||
let win promise := promise.resolve (.ok none)
|
||||
waiter.race lose win
|
||||
|
||||
end Receiver
|
||||
end Bounded
|
||||
|
||||
/--
|
||||
A multi-subscriber broadcast that delivers each message to all current subscribers.
|
||||
Supports only bounded buffering and an asynchronous API; to switch into
|
||||
synchronous mode use `Broadcast.sync`.
|
||||
|
||||
Unlike `Std.Channel`, each message is received by **every** subscriber instead of just one.
|
||||
Subscribers only receive messages sent after they have subscribed (unless otherwise specified).
|
||||
-/
|
||||
structure Broadcast (α : Type) where
|
||||
private mk ::
|
||||
private inner : Bounded α
|
||||
|
||||
/--
|
||||
A receiver for a `Broadcast` channel that can asynchronously receive messages.
|
||||
Each receiver gets a copy of every message sent to the broadcast channel after
|
||||
the receiver was created. Multiple receivers can exist for the same broadcast,
|
||||
and each will receive all messages independently.
|
||||
-/
|
||||
structure Broadcast.Receiver (α : Type) where
|
||||
private mk ::
|
||||
private inner : Bounded.Receiver α
|
||||
|
||||
namespace Broadcast
|
||||
|
||||
/--
|
||||
Creates a new broadcast channel.
|
||||
-/
|
||||
@[inline]
|
||||
def new {α} (capacity : Nat := 16) (h : capacity > 0 := by decide) : BaseIO (Broadcast α) := do
|
||||
return ⟨← Bounded.new capacity h⟩
|
||||
|
||||
/--
|
||||
Try to send a value to the broadcast channel, if this can be completed right away without blocking return
|
||||
`true`, otherwise don't send the value and return `false`.
|
||||
-/
|
||||
@[inline]
|
||||
def trySend (ch : Broadcast α) (v : α) : BaseIO (Option Nat) :=
|
||||
ch.inner.trySend v
|
||||
|
||||
/--
|
||||
Subscribes a new `Receiver` from the `Broadcast` channel.
|
||||
-/
|
||||
@[inline]
|
||||
def subscribe (ch : Broadcast α) : IO (Broadcast.Receiver α) := do
|
||||
Broadcast.Receiver.mk <$> ch.inner.subscribe
|
||||
|
||||
/--
|
||||
Closes a `Broadcast` channel.
|
||||
-/
|
||||
@[inline]
|
||||
def close (ch : Broadcast α) : IO Unit := do
|
||||
ch.inner.close
|
||||
|
||||
/--
|
||||
Send a value through the broadcast channel, returning a task that will resolve once the transmission
|
||||
could be completed.
|
||||
-/
|
||||
@[inline]
|
||||
def send (ch : Broadcast α) (v : α) : BaseIO (Task (Except IO.Error Nat)) := do
|
||||
BaseIO.bindTask (sync := true) (← ch.inner.send v)
|
||||
fun
|
||||
| .ok res => return .pure <| .ok res
|
||||
| .error err => return .pure <| .error (toString err)
|
||||
|
||||
namespace Receiver
|
||||
|
||||
/--
|
||||
Try to receive a value from the broadcast receiver, if a message is available right away
|
||||
return `some value`, otherwise return `none` without blocking.
|
||||
-/
|
||||
@[inline]
|
||||
def tryRecv (ch : Broadcast.Receiver α) : BaseIO (Option α) :=
|
||||
Std.Bounded.Receiver.tryRecv ch.inner
|
||||
|
||||
/--
|
||||
Receive a value from the broadcast receiver, returning a task that will resolve with
|
||||
the next available message. This will block until a message is available.
|
||||
-/
|
||||
@[inline]
|
||||
def recv [Inhabited α] (ch : Broadcast.Receiver α) : BaseIO (Task (Option α)) := do
|
||||
Std.Bounded.Receiver.recv ch.inner
|
||||
|
||||
open Internal.IO.Async in
|
||||
|
||||
/--
|
||||
Creates a `Selector` that resolves once the broadcast channel `ch` has data available and provides that that data.
|
||||
-/
|
||||
@[inline]
|
||||
def recvSelector [Inhabited α] (ch : Broadcast.Receiver α) : Selector (Option α) :=
|
||||
Bounded.Receiver.recvSelector ch.inner
|
||||
|
||||
/--
|
||||
Unsubscribes a `Receiver` from the `Broadcast` channel.
|
||||
-/
|
||||
@[inline]
|
||||
def unsubscribe (ch : Broadcast.Receiver α) : IO Unit := do
|
||||
ch.inner.unsubscribe
|
||||
|
||||
/--
|
||||
`ch.forAsync f` calls `f` for every message received on `ch`.
|
||||
|
||||
Note that if this function is called twice, each message will only arrive at exactly one invocation.
|
||||
-/
|
||||
partial def forAsync (f : α → BaseIO Unit) (ch : Broadcast.Receiver α)
|
||||
(prio : Task.Priority := .default) : BaseIO (Task Unit) := do
|
||||
ch.inner.forAsync f prio
|
||||
|
||||
instance [Inhabited α] : AsyncStream (Broadcast.Receiver α) (Option α) where
|
||||
next channel := channel.recvSelector
|
||||
stop channel := channel.unsubscribe
|
||||
|
||||
instance [Inhabited α] : AsyncRead (Broadcast.Receiver α) (Option α) where
|
||||
read receiver := Internal.IO.Async.Async.ofIOTask receiver.recv
|
||||
|
||||
instance [Inhabited α] : AsyncWrite (Broadcast α) α where
|
||||
write receiver x := do
|
||||
let task ← receiver.send x
|
||||
discard <| Async.ofTask <| task
|
||||
|
||||
end Receiver
|
||||
|
||||
/--
|
||||
A multi-subscriber broadcast that delivers each message to all current subscribers.
|
||||
Supports only bounded buffering and an asynchronous API.
|
||||
|
||||
It's the sync version of `Broadcast`.
|
||||
-/
|
||||
@[expose] def Sync (α : Type) : Type := Broadcast α
|
||||
|
||||
/--
|
||||
A receiver for a `Broadcast` channel that can asynchronously receive messages.
|
||||
Each receiver gets a copy of every message sent to the broadcast channel after
|
||||
the receiver was created. Multiple receivers can exist for the same broadcast,
|
||||
and each will receive all messages independently.
|
||||
|
||||
It's the sync version of `Broadcast.Receiver`.
|
||||
-/
|
||||
@[expose] def Sync.Receiver (α : Type) : Type := Broadcast.Receiver α
|
||||
|
||||
namespace Sync
|
||||
|
||||
@[inherit_doc Broadcast.new, inline]
|
||||
def new (capacity : Nat := 16) (h : capacity > 0 := by decide) : BaseIO (Sync α) :=
|
||||
Broadcast.new capacity h
|
||||
|
||||
@[inherit_doc Broadcast.trySend, inline]
|
||||
def trySend (ch : Sync α) (v : α) : BaseIO (Option Nat) :=
|
||||
Broadcast.trySend ch v
|
||||
|
||||
/--
|
||||
Send a value through the channel, blocking until the transmission could be completed.
|
||||
-/
|
||||
@[inline]
|
||||
def send (ch : Sync α) (v : α) : IO Nat := do
|
||||
IO.ofExcept =<< IO.wait (← Broadcast.send ch v)
|
||||
|
||||
namespace Receiver
|
||||
|
||||
@[inherit_doc Broadcast.Receiver.tryRecv, inline]
|
||||
def tryRecv (ch : Sync.Receiver α) : BaseIO (Option α) := Broadcast.Receiver.tryRecv ch
|
||||
|
||||
/--
|
||||
Receive a value from the channel, blocking until the transmission could be completed.
|
||||
-/
|
||||
def recv [Inhabited α] (ch : Sync.Receiver α) : BaseIO (Option α) := do
|
||||
IO.wait (← Broadcast.Receiver.recv ch)
|
||||
|
||||
partial def forIn [Inhabited α] [Monad m] [MonadLiftT BaseIO m]
|
||||
(ch : Sync.Receiver α) (f : α → β → m (ForInStep β)) : β → m β := fun b => do
|
||||
let a ← ch.recv
|
||||
match a with
|
||||
| none => pure b
|
||||
| some a =>
|
||||
match ← f a b with
|
||||
| .done b => pure b
|
||||
| .yield b => ch.forIn f b
|
||||
|
||||
/-- `for msg in ch.sync do ...` receives all messages in the channel until it is closed. -/
|
||||
instance [Inhabited α] [MonadLiftT BaseIO m] : ForIn m (Sync.Receiver α) α where
|
||||
forIn ch b f := Receiver.forIn ch f b
|
||||
|
||||
end Receiver
|
||||
end Sync
|
||||
end Broadcast
|
||||
end Std
|
||||
391
tests/lean/run/broadcast.lean
Normal file
391
tests/lean/run/broadcast.lean
Normal file
@@ -0,0 +1,391 @@
|
||||
import Std.Internal.Async
|
||||
import Std.Sync
|
||||
|
||||
open Std.Internal.IO Async
|
||||
|
||||
-- Test tryRecv with empty channel
|
||||
def tryRecvEmpty : Async Unit := do
|
||||
let channel ← Std.Broadcast.new (capacity := 4) (α := Nat)
|
||||
let subs ← channel.subscribe
|
||||
|
||||
let result ← subs.tryRecv
|
||||
assert! result.isNone
|
||||
|
||||
#eval tryRecvEmpty.block
|
||||
|
||||
-- Test tryRecv with messages available
|
||||
def tryRecvWithMessages : Async Unit := do
|
||||
let channel ← Std.Broadcast.new (capacity := 4)
|
||||
let subs ← channel.subscribe
|
||||
|
||||
discard <| await (← channel.send 42)
|
||||
discard <| await (← channel.send 100)
|
||||
|
||||
let msg1 ← subs.tryRecv
|
||||
let msg2 ← subs.tryRecv
|
||||
let msg3 ← subs.tryRecv
|
||||
|
||||
assert! msg1 == some 42
|
||||
assert! msg2 == some 100
|
||||
assert! msg3.isNone
|
||||
|
||||
#eval tryRecvWithMessages.block
|
||||
|
||||
-- Test unsubscribe functionality
|
||||
def testUnsubscribe : Async Unit := do
|
||||
let channel ← Std.Broadcast.new (capacity := 4)
|
||||
let subs1 ← channel.subscribe
|
||||
let subs2 ← channel.subscribe
|
||||
|
||||
-- Send before unsubscribe
|
||||
discard <| await (← channel.send 1)
|
||||
|
||||
-- Unsubscribe subs1
|
||||
subs1.unsubscribe
|
||||
|
||||
-- Send after unsubscribe
|
||||
discard <| await (← channel.send 2)
|
||||
|
||||
-- subs1 should not receive the second message
|
||||
let msg1 ← await (← subs1.recv)
|
||||
let result ← subs1.tryRecv
|
||||
|
||||
-- subs2 should receive both messages
|
||||
let msg2 ← await (← subs2.recv)
|
||||
let msg3 ← await (← subs2.recv)
|
||||
|
||||
assert! msg1 == none
|
||||
assert! result.isNone -- No more messages for unsubscribed
|
||||
assert! msg2 == some 1
|
||||
assert! msg3 == some 2
|
||||
|
||||
#eval testUnsubscribe.block
|
||||
|
||||
def testUnsubscribeUnblock : Async Unit := do
|
||||
let channel ← Std.Broadcast.new (capacity := 4)
|
||||
|
||||
let subs1 ← channel.subscribe
|
||||
let subs2 ← channel.subscribe
|
||||
|
||||
-- Add 4 messages, so it reaches the limit.
|
||||
for i in [0:4] do
|
||||
assert! (← channel.trySend i).isSome
|
||||
|
||||
-- Mark subs1 messages as read
|
||||
for i in [0:10] do
|
||||
if i < 4 then
|
||||
assert! (← subs1.tryRecv) = some i
|
||||
else
|
||||
assert! (← subs1.tryRecv) = none
|
||||
|
||||
-- Mark 2 messages as read so it cleans 2 messages
|
||||
assert! (← subs2.tryRecv).isSome
|
||||
assert! (← subs2.tryRecv).isSome
|
||||
|
||||
assert! (← channel.trySend 5).isSome
|
||||
assert! (← channel.trySend 5).isSome
|
||||
assert! not (← channel.trySend 6).isSome
|
||||
|
||||
-- It unsubscribe and mark all subs2 messages as read.
|
||||
subs2.unsubscribe
|
||||
|
||||
-- Create a new subscriber to verify channel still works
|
||||
let subs3 ← channel.subscribe
|
||||
|
||||
-- Send one more message that the new subscriber should receive
|
||||
assert! (← channel.trySend 8).isSome
|
||||
|
||||
-- subs1 should be able to receive the messages sent after it last read:
|
||||
-- the two 5's and the 8
|
||||
let subs1Msg1 ← subs1.tryRecv
|
||||
let subs1Msg2 ← subs1.tryRecv
|
||||
let subs1Msg3 ← subs1.tryRecv
|
||||
let subs1Msg4 ← subs1.tryRecv -- should be none
|
||||
|
||||
assert! subs1Msg1 == some 5
|
||||
assert! subs1Msg2 == some 5
|
||||
assert! subs1Msg3 == some 8
|
||||
assert! subs1Msg4.isNone
|
||||
|
||||
-- The new subscriber should only get the most recent message
|
||||
let msg ← subs3.tryRecv
|
||||
assert! msg == some 8
|
||||
|
||||
-- No more messages should be available for the new subscriber
|
||||
let noMsg ← subs3.tryRecv
|
||||
assert! noMsg.isNone
|
||||
|
||||
-- Verify unsubscribed subs2 can't receive anything
|
||||
let subs2NoMsg ← subs2.tryRecv
|
||||
assert! subs2NoMsg.isNone
|
||||
|
||||
#eval testUnsubscribeUnblock.block
|
||||
|
||||
def unsubscribedCannotReceive : Async Unit := do
|
||||
let channel ← Std.Broadcast.new
|
||||
|
||||
let subs1 ← channel.subscribe
|
||||
let subs2 ← channel.subscribe
|
||||
|
||||
discard <| await (← channel.send 1)
|
||||
discard <| await (← channel.send 2)
|
||||
|
||||
let msg1 ← await (← subs1.recv)
|
||||
let msg2 ← await (← subs1.recv)
|
||||
let msg3 ← await (← subs2.recv)
|
||||
let msg4 ← await (← subs2.recv)
|
||||
|
||||
assert! msg1 == some 1
|
||||
assert! msg2 == some 2
|
||||
|
||||
assert! msg3 == some 1
|
||||
assert! msg4 == some 2
|
||||
|
||||
#eval unsubscribedCannotReceive.block
|
||||
|
||||
def fullBuffer : Async Unit := do
|
||||
let channel ← Std.Broadcast.new (capacity := 4)
|
||||
|
||||
let subs1 ← channel.subscribe
|
||||
|
||||
for i in [0:5] do
|
||||
if not (← channel.trySend i).isSome then
|
||||
assert! i == 4
|
||||
|
||||
#eval fullBuffer.block
|
||||
|
||||
def noSubscribers : Async Unit := do
|
||||
let channel ← Std.Broadcast.new (capacity := 4)
|
||||
|
||||
assert! (← channel.trySend 0) == some 0
|
||||
|
||||
#eval noSubscribers.block
|
||||
|
||||
-- Test unsubscribe during message consumption
|
||||
def testUnsubscribeDuringConsumption : Async Unit := do
|
||||
let channel ← Std.Broadcast.new (capacity := 4)
|
||||
let subs1 ← channel.subscribe
|
||||
let subs2 ← channel.subscribe
|
||||
|
||||
-- Send several messages
|
||||
for i in [0:4] do
|
||||
discard <| await (← channel.send i)
|
||||
|
||||
-- subs1 reads first message then unsubscribes
|
||||
let msg1 ← await (← subs1.recv)
|
||||
subs1.unsubscribe
|
||||
|
||||
-- subs2 should still be able to read all messages
|
||||
let msgs2 ← [0, 1, 2, 3].mapM (fun _ => subs2.recv >>= fun r => await r)
|
||||
|
||||
assert! msg1 == some 0
|
||||
assert! msgs2 == [some 0, some 1, some 2, some 3]
|
||||
|
||||
-- subs1 should have no more messages available
|
||||
let result ← subs1.tryRecv
|
||||
assert! result.isNone
|
||||
|
||||
-- Test mixed send and trySend operations
|
||||
def testMixedSendOperations : Async Unit := do
|
||||
let channel ← Std.Broadcast.new (capacity := 3)
|
||||
let subs ← channel.subscribe
|
||||
|
||||
-- Use trySend
|
||||
assert! (← channel.trySend 1).isSome
|
||||
|
||||
-- Use regular send
|
||||
discard <| await (← channel.send 2)
|
||||
|
||||
-- Use trySend again
|
||||
assert! (← channel.trySend 3).isSome
|
||||
|
||||
-- Buffer should be full now
|
||||
assert! not (← channel.trySend 4).isSome
|
||||
|
||||
-- Verify all messages received correctly
|
||||
let msgs ← [1, 2, 3].mapM (fun _ => subs.recv >>= fun r => await r)
|
||||
assert! msgs == [some 1, some 2, some 3]
|
||||
|
||||
#eval testMixedSendOperations.block
|
||||
|
||||
-- Test recv on closed channel with no pending messages
|
||||
def testRecvOnClosedEmpty : Async Unit := do
|
||||
let channel ← Std.Broadcast.new (capacity := 4) (α := Nat)
|
||||
let subs ← channel.subscribe
|
||||
|
||||
channel.close
|
||||
|
||||
-- tryRecv should return none immediately
|
||||
let result ← subs.tryRecv
|
||||
assert! result.isNone
|
||||
|
||||
#eval testRecvOnClosedEmpty.block
|
||||
|
||||
-- Test recv block
|
||||
def testRecvOnEmpty : Async Unit := do
|
||||
let channel ← Std.Broadcast.new (capacity := 4) (α := Nat)
|
||||
let subs ← channel.subscribe
|
||||
|
||||
let recv ← subs.recv
|
||||
|
||||
assert! (← IO.getTaskState recv) == IO.TaskState.waiting
|
||||
|
||||
let result ← await (← channel.send 3)
|
||||
|
||||
assert! (← IO.getTaskState recv) == IO.TaskState.finished
|
||||
|
||||
assert! recv.get == some 3
|
||||
|
||||
#eval testRecvOnEmpty.block
|
||||
|
||||
-- Test recv
|
||||
def recvConditions : Async Unit := do
|
||||
let channel ← Std.Broadcast.new (capacity := 16) (α := Nat)
|
||||
let subs1 ← channel.subscribe
|
||||
let subs2 ← channel.subscribe
|
||||
let subs3 ← channel.subscribe
|
||||
|
||||
discard <| EAsync.ofETask (← channel.send 1)
|
||||
discard <| EAsync.ofETask (← channel.send 2)
|
||||
discard <| EAsync.ofETask (← channel.send 3)
|
||||
|
||||
channel.close
|
||||
|
||||
let recv ← subs1.recv
|
||||
assert! (← IO.getTaskState recv) == IO.TaskState.finished
|
||||
assert! recv.get == some 1
|
||||
|
||||
let recv ← subs1.recv
|
||||
assert! (← IO.getTaskState recv) == IO.TaskState.finished
|
||||
assert! recv.get == some 2
|
||||
|
||||
let recv ← subs1.recv
|
||||
assert! (← IO.getTaskState recv) == IO.TaskState.finished
|
||||
assert! recv.get == some 3
|
||||
|
||||
let recv ← subs1.recv
|
||||
assert! (← IO.getTaskState recv) == IO.TaskState.finished
|
||||
assert! recv.get == none
|
||||
|
||||
let recv ← subs1.recv
|
||||
assert! (← IO.getTaskState recv) == IO.TaskState.finished
|
||||
assert! recv.get == none
|
||||
|
||||
let recv ← subs2.recv
|
||||
assert! (← IO.getTaskState recv) == IO.TaskState.finished
|
||||
assert! recv.get == some 1
|
||||
|
||||
let recv ← subs2.recv
|
||||
assert! (← IO.getTaskState recv) == IO.TaskState.finished
|
||||
assert! recv.get == some 2
|
||||
|
||||
let recv ← subs2.recv
|
||||
assert! (← IO.getTaskState recv) == IO.TaskState.finished
|
||||
assert! recv.get == some 3
|
||||
|
||||
let recv ← subs2.recv
|
||||
assert! (← IO.getTaskState recv) == IO.TaskState.finished
|
||||
assert! recv.get == none
|
||||
|
||||
let recv ← subs2.recv
|
||||
assert! (← IO.getTaskState recv) == IO.TaskState.finished
|
||||
assert! recv.get == none
|
||||
|
||||
subs3.unsubscribe
|
||||
|
||||
let recv ← subs3.recv
|
||||
assert! (← IO.getTaskState recv) == IO.TaskState.finished
|
||||
assert! recv.get == none
|
||||
|
||||
#eval recvConditions.block
|
||||
|
||||
-- Test selectables
|
||||
def selectableConditions : Async Unit := do
|
||||
let channel1 ← Std.Channel.new
|
||||
|
||||
let channel ← Std.Broadcast.new (capacity := 16) (α := Nat)
|
||||
let subs1 ← channel.subscribe
|
||||
let subs2 ← channel.subscribe
|
||||
let subs3 ← channel.subscribe
|
||||
|
||||
discard <| EAsync.ofETask (← channel.send 1)
|
||||
discard <| EAsync.ofETask (← channel.send 2)
|
||||
discard <| EAsync.ofETask (← channel.send 3)
|
||||
|
||||
channel.close
|
||||
|
||||
let recv ← Async.toIO <| Selectable.one #[
|
||||
.case subs1.recvSelector pure,
|
||||
.case channel1.recvSelector pure
|
||||
]
|
||||
|
||||
assert! (← IO.getTaskState recv) == IO.TaskState.finished
|
||||
assert! (← IO.ofExcept recv.get) == some 1
|
||||
|
||||
let recv ← Async.toIO <| Selectable.one #[
|
||||
.case subs1.recvSelector pure,
|
||||
.case channel1.recvSelector pure
|
||||
]
|
||||
|
||||
assert! (← IO.getTaskState recv) == IO.TaskState.finished
|
||||
assert! (← IO.ofExcept recv.get) == some 2
|
||||
|
||||
let recv ← Async.toIO <| Selectable.one #[
|
||||
.case subs1.recvSelector pure,
|
||||
.case channel1.recvSelector pure
|
||||
]
|
||||
|
||||
assert! (← IO.getTaskState recv) == IO.TaskState.finished
|
||||
assert! (← IO.ofExcept recv.get) == some 3
|
||||
|
||||
let recv ← Async.toIO <| Selectable.one #[
|
||||
.case subs1.recvSelector pure,
|
||||
.case channel1.recvSelector pure
|
||||
]
|
||||
|
||||
assert! (← IO.getTaskState recv) == IO.TaskState.finished
|
||||
assert! (← IO.ofExcept recv.get) == none
|
||||
|
||||
let recv ← Async.toIO <| Selectable.one #[
|
||||
.case subs2.recvSelector pure,
|
||||
.case channel1.recvSelector pure
|
||||
]
|
||||
|
||||
assert! (← IO.getTaskState recv) == IO.TaskState.finished
|
||||
assert! (← IO.ofExcept recv.get) == some 1
|
||||
|
||||
let recv ← Async.toIO <| Selectable.one #[
|
||||
.case subs2.recvSelector pure,
|
||||
.case channel1.recvSelector pure
|
||||
]
|
||||
|
||||
assert! (← IO.getTaskState recv) == IO.TaskState.finished
|
||||
assert! (← IO.ofExcept recv.get) == some 2
|
||||
|
||||
let recv ← Async.toIO <| Selectable.one #[
|
||||
.case subs2.recvSelector pure,
|
||||
.case channel1.recvSelector pure
|
||||
]
|
||||
|
||||
assert! (← IO.getTaskState recv) == IO.TaskState.finished
|
||||
assert! (← IO.ofExcept recv.get) == some 3
|
||||
|
||||
let recv ← Async.toIO <| Selectable.one #[
|
||||
.case subs2.recvSelector pure,
|
||||
.case channel1.recvSelector pure
|
||||
]
|
||||
|
||||
assert! (← IO.getTaskState recv) == IO.TaskState.finished
|
||||
assert! (← IO.ofExcept recv.get) == none
|
||||
|
||||
subs3.unsubscribe
|
||||
|
||||
let recv ← Async.toIO <| Selectable.one #[
|
||||
.case subs3.recvSelector pure,
|
||||
.case channel1.recvSelector pure
|
||||
]
|
||||
|
||||
assert! (← IO.getTaskState recv) == IO.TaskState.finished
|
||||
assert! (← IO.ofExcept recv.get) == none
|
||||
|
||||
#eval selectableConditions.block
|
||||
Reference in New Issue
Block a user