mirror of
https://github.com/leanprover/lean4.git
synced 2026-04-05 03:34:08 +00:00
Compare commits
4 Commits
grind_none
...
hbv/select
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c640b03851 | ||
|
|
3c29552b1e | ||
|
|
8216a667ee | ||
|
|
fbb6f7c2f0 |
@@ -8,3 +8,4 @@ import Std.Internal.Async.Basic
|
||||
import Std.Internal.Async.Timer
|
||||
import Std.Internal.Async.TCP
|
||||
import Std.Internal.Async.UDP
|
||||
import Std.Internal.Async.Select
|
||||
|
||||
159
src/Std/Internal/Async/Select.lean
Normal file
159
src/Std/Internal/Async/Select.lean
Normal file
@@ -0,0 +1,159 @@
|
||||
/-
|
||||
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Henrik Böving
|
||||
-/
|
||||
prelude
|
||||
import Init.Data.Array.Basic
|
||||
import Init.Data.Random
|
||||
import Std.Internal.Async.Basic
|
||||
|
||||
/-!
|
||||
This module contains the implementation of a fair and data-loss free IO multiplexing primitive.
|
||||
The main entrypoint for users is `Selectable.one` and the various functions to produces
|
||||
`Selector`s from other modules.
|
||||
-/
|
||||
|
||||
namespace Std
|
||||
namespace Internal
|
||||
namespace IO
|
||||
namespace Async
|
||||
|
||||
/--
|
||||
The core data structure for racing on winning a `Selectable.one` if multiple event sources are ready
|
||||
at the same time. A `Task` can try to finish the waiter by calling `Waiter.race`.
|
||||
-/
|
||||
structure Waiter (α : Type) where
|
||||
private mk ::
|
||||
private finished : IO.Ref Bool
|
||||
promise : IO.Promise (Except IO.Error α)
|
||||
|
||||
/--
|
||||
Create a fresh `Waiter`.
|
||||
-/
|
||||
def Waiter.new : BaseIO (Waiter α) := do
|
||||
return { finished := ← IO.mkRef false, promise := ← IO.Promise.new }
|
||||
|
||||
/--
|
||||
Swap out the `IO.Promise` within the `Waiter`. Note that the part which determines whether the
|
||||
`Waiter` is finished is not swapped out.
|
||||
-/
|
||||
def Waiter.withPromise (w : Waiter α) (p : IO.Promise (Except IO.Error β)) : Waiter β :=
|
||||
Waiter.mk w.finished p
|
||||
|
||||
/--
|
||||
Try to atomically finish the `Waiter`. If the race for finishing it is won, `win` is executed
|
||||
with the internal `IO.Promise` of the `Waiter`. This promise must under all circumstances be
|
||||
resolved by `win`. If the race is lost some cleanup work can be done in `loose`.
|
||||
-/
|
||||
@[specialize]
|
||||
def Waiter.race [Monad m] [MonadLiftT (ST IO.RealWorld) m] (w : Waiter α)
|
||||
(loose : m β) (win : IO.Promise (Except IO.Error α) → m β) : m β := do
|
||||
let first ← w.finished.modifyGet fun s => (s == false, true)
|
||||
if first then
|
||||
win w.promise
|
||||
else
|
||||
loose
|
||||
|
||||
/--
|
||||
An event source that can be multiplexed using `Selectable.one`, see the documentation of
|
||||
`Selectable.one` for how the protocol of communicating with a `Selector` works.
|
||||
-/
|
||||
structure Selector (α : Type) where
|
||||
/--
|
||||
Try to get a piece of data from the event source in a non blocking fashion, returning `some` if
|
||||
data is available and `none` otherwise.
|
||||
-/
|
||||
tryFn : IO (Option α)
|
||||
/--
|
||||
Register a `Waiter` with the event source. Once data is available on the event source it should
|
||||
attempt to call `Waiter.race` and resolve the `Waiter`'s promise if it wins. It is crucial that
|
||||
data is never actually consumed from the event source unless `Waiter.race` wins in order to
|
||||
prevent data loss.
|
||||
-/
|
||||
registerFn : Waiter α → IO Unit
|
||||
/--
|
||||
A cleanup function that will be called once any `Selector` won the `Selectable.one` race.
|
||||
-/
|
||||
unregisterFn : IO Unit
|
||||
|
||||
/--
|
||||
An event source together with a continuation to call on data obtained from that event source,
|
||||
usually used together in conjunction with `Selectable.one`.
|
||||
-/
|
||||
structure Selectable (α : Type) where
|
||||
case ::
|
||||
{β : Type}
|
||||
/--
|
||||
The event source.
|
||||
-/
|
||||
selector : Selector β
|
||||
/--
|
||||
The continuation to call on results of the event source.
|
||||
-/
|
||||
cont : β → IO (AsyncTask α)
|
||||
|
||||
private def shuffleIt {α : Type u} (xs : Array α) (gen : StdGen) : Array α :=
|
||||
go xs gen 0
|
||||
where
|
||||
go (xs : Array α) (gen : StdGen) (i : Nat) : Array α :=
|
||||
if _ : i < xs.size - 1 then
|
||||
let (j, gen) := randNat gen i (xs.size - 1)
|
||||
let xs := xs.swapIfInBounds i j
|
||||
go xs gen (i + 1)
|
||||
else
|
||||
xs
|
||||
|
||||
/--
|
||||
Perform fair and data-loss free multiplexing on the `Selectable`s in `selectables`.
|
||||
|
||||
The protocol for this is as follows:
|
||||
1. Shuffle `selectables` randomly.
|
||||
2. Run `Selector.tryFn` for each element in `selectables`. If any of them succeeds, run
|
||||
`Selectable.cont` on the result and return right away, otherwise continue.
|
||||
3. Register a `Waiter` with each `Selector` using `Selector.registerFn`. Once the `Waiter` is
|
||||
resolved by a `Selector` run `Selector.unregisterFn` for all `Selectors`s, then the
|
||||
`Selectable.cont` of the `Selector` that won and return the produced task.
|
||||
-/
|
||||
partial def Selectable.one (selectables : Array (Selectable α)) : IO (AsyncTask α) := do
|
||||
let seed := UInt64.toNat (ByteArray.toUInt64LE! (← IO.getRandomBytes 8))
|
||||
let gen := mkStdGen seed
|
||||
let selectables := shuffleIt selectables gen
|
||||
for selectable in selectables do
|
||||
if let some val ← selectable.selector.tryFn then
|
||||
return ← selectable.cont val
|
||||
|
||||
let finished ← IO.mkRef false
|
||||
let promise ← IO.Promise.new
|
||||
|
||||
for selectable in selectables do
|
||||
let waiterPromise ← IO.Promise.new
|
||||
let waiter := Waiter.mk finished waiterPromise
|
||||
selectable.selector.registerFn waiter
|
||||
|
||||
IO.chainTask (t := waiterPromise.result?) fun res? => do
|
||||
match res? with
|
||||
| none =>
|
||||
/-
|
||||
If we get `none` that means the waiterPromise was dropped, usually due to cancellation. In
|
||||
this situation just do nothing.
|
||||
-/
|
||||
return ()
|
||||
| some res =>
|
||||
try
|
||||
let res ← IO.ofExcept res
|
||||
|
||||
for selectable in selectables do
|
||||
selectable.selector.unregisterFn
|
||||
|
||||
let contRes ← selectable.cont res
|
||||
discard <| contRes.mapIO (promise.resolve <| .ok ·)
|
||||
catch e =>
|
||||
promise.resolve (.error e)
|
||||
|
||||
return AsyncTask.ofPromise promise
|
||||
|
||||
end Async
|
||||
end IO
|
||||
end Internal
|
||||
end Std
|
||||
@@ -5,8 +5,8 @@ Authors: Sofia Rodrigues
|
||||
-/
|
||||
prelude
|
||||
import Std.Time
|
||||
import Std.Internal.UV
|
||||
import Std.Internal.Async.Basic
|
||||
import Std.Internal.UV.TCP
|
||||
import Std.Internal.Async.Select
|
||||
import Std.Net.Addr
|
||||
|
||||
namespace Std
|
||||
@@ -130,6 +130,40 @@ socket is not supported. Instead, we recommend binding multiple sockets to the s
|
||||
def recv? (s : Client) (size : UInt64) : IO (AsyncTask (Option ByteArray)) :=
|
||||
AsyncTask.ofPromise <$> s.native.recv? size
|
||||
|
||||
/--
|
||||
Create a `Selector` that resolves once `s` has at max `size` bytes of data available and provides
|
||||
that data. Note that calling this function starts the waiting data and may thus not be called
|
||||
concurrently with `recv?`.
|
||||
-/
|
||||
def recvSelector (s : TCP.Socket.Client) (size : UInt64) : IO (Selector (Option ByteArray)) := do
|
||||
let readableWaiter ← s.native.waitReadable
|
||||
return {
|
||||
tryFn := do
|
||||
if ← readableWaiter.isResolved then
|
||||
-- We know that this read should not block
|
||||
let res ← (← s.recv? size).block
|
||||
return some res
|
||||
else
|
||||
return none
|
||||
registerFn waiter := do
|
||||
-- If we get cancelled the promise will be dropped so prepare for that
|
||||
discard <| IO.mapTask (t := readableWaiter.result?) fun res => do
|
||||
match res with
|
||||
| none => return ()
|
||||
| some res =>
|
||||
let loose := return ()
|
||||
let win promise := do
|
||||
try
|
||||
discard <| IO.ofExcept res
|
||||
-- We know that this read should not block
|
||||
let res ← (← s.recv? size).block
|
||||
promise.resolve (.ok res)
|
||||
catch e =>
|
||||
promise.resolve (.error e)
|
||||
waiter.race loose win
|
||||
unregisterFn := s.native.cancelRecv
|
||||
}
|
||||
|
||||
/--
|
||||
Shuts down the write side of the client socket.
|
||||
-/
|
||||
|
||||
@@ -5,8 +5,8 @@ Authors: Henrik Böving
|
||||
-/
|
||||
prelude
|
||||
import Std.Time
|
||||
import Std.Internal.UV
|
||||
import Std.Internal.Async.Basic
|
||||
import Std.Internal.UV.Timer
|
||||
import Std.Internal.Async.Select
|
||||
|
||||
|
||||
namespace Std
|
||||
@@ -65,6 +65,26 @@ If:
|
||||
def stop (s : Sleep) : IO Unit :=
|
||||
s.native.stop
|
||||
|
||||
/--
|
||||
Create a `Selector` that resolves once `s` has finished. Note that calling this function starts `s`
|
||||
if it hasn't already started.
|
||||
-/
|
||||
def selector (s : Sleep) : IO (Selector Unit) := do
|
||||
let sleepWaiter ← s.wait
|
||||
return {
|
||||
tryFn := do
|
||||
if ← IO.hasFinished sleepWaiter then
|
||||
return some ()
|
||||
else
|
||||
return none
|
||||
registerFn waiter := do
|
||||
discard <| AsyncTask.mapIO (x := sleepWaiter) fun _ => do
|
||||
let loose := return ()
|
||||
let win promise := promise.resolve (.ok ())
|
||||
waiter.race loose win
|
||||
unregisterFn := pure ()
|
||||
}
|
||||
|
||||
end Sleep
|
||||
|
||||
/--
|
||||
@@ -74,6 +94,13 @@ def sleep (duration : Std.Time.Millisecond.Offset) : IO (AsyncTask Unit) := do
|
||||
let sleeper ← Sleep.mk duration
|
||||
sleeper.wait
|
||||
|
||||
/--
|
||||
Return a `Selector` that resolves after `duration`.
|
||||
-/
|
||||
def Selector.sleep (duration : Std.Time.Millisecond.Offset) : IO (Selector Unit) := do
|
||||
let sleeper ← Sleep.mk duration
|
||||
sleeper.selector
|
||||
|
||||
/--
|
||||
`Interval` can be used to repeatedly wait for some duration like a clock.
|
||||
The underlying timer has millisecond resolution.
|
||||
|
||||
@@ -54,6 +54,22 @@ socket is not supported. Instead, we recommend binding multiple sockets to the s
|
||||
@[extern "lean_uv_tcp_recv"]
|
||||
opaque recv? (socket : @& Socket) (size : UInt64) : IO (IO.Promise (Except IO.Error (Option ByteArray)))
|
||||
|
||||
/--
|
||||
Return an `IO.Promise` that resolves to `true` once `socket` has data available for reading or to
|
||||
`false` if `socket` is closed before that. Note that calling this function twice on the same
|
||||
`Socket` or in parallel with `recv?` is not supported.
|
||||
-/
|
||||
@[extern "lean_uv_tcp_wait_readable"]
|
||||
opaque waitReadable (socket : @& Socket) : IO (IO.Promise (Except IO.Error Bool))
|
||||
|
||||
/--
|
||||
Cancel a receive in the form of `recv?` or `waitReadable` if there is currently one pending.
|
||||
This is will resolve their returned `IO.Promise` to `none`. Note that his function is dangerous as
|
||||
improper use can cause data loss and is as such not exposed to the top level API.
|
||||
-/
|
||||
@[extern "lean_uv_tcp_cancel_recv"]
|
||||
opaque cancelRecv (socket : @& Socket) : IO Unit
|
||||
|
||||
/--
|
||||
Binds a TCP socket to a specific address.
|
||||
-/
|
||||
|
||||
@@ -7,6 +7,7 @@ prelude
|
||||
import Init.System.Promise
|
||||
import Init.Data.Queue
|
||||
import Std.Sync.Mutex
|
||||
import Std.Internal.Async.Select
|
||||
|
||||
/-!
|
||||
This module contains the implementation of `Std.Channel`. `Std.Channel` is a multi-producer
|
||||
@@ -44,6 +45,23 @@ instance : ToString Error where
|
||||
instance : MonadLift (EIO Error) IO where
|
||||
monadLift x := EIO.toIO (.userError <| toString ·) x
|
||||
|
||||
open Internal.IO.Async in
|
||||
private inductive Consumer (α : Type) where
|
||||
| normal (promise : IO.Promise (Option α))
|
||||
| select (finished : Waiter (Option α))
|
||||
|
||||
private def Consumer.resolve (c : Consumer α) (x : Option α) : BaseIO Bool := do
|
||||
match c with
|
||||
| .normal promise =>
|
||||
promise.resolve x
|
||||
return true
|
||||
| .select waiter =>
|
||||
let loose := return false
|
||||
let win promise := do
|
||||
promise.resolve (.ok x)
|
||||
return true
|
||||
waiter.race loose win
|
||||
|
||||
/--
|
||||
The central state structure for an unbounded channel, maintains the following invariants:
|
||||
1. `values = ∅ ∨ consumers = ∅`
|
||||
@@ -58,7 +76,7 @@ private structure Unbounded.State (α : Type) where
|
||||
Consumers that are blocked on a producer providing them a value. The `IO.Promise` will be
|
||||
resolved to `none` if the channel closes.
|
||||
-/
|
||||
consumers : Std.Queue (IO.Promise (Option α))
|
||||
consumers : Std.Queue (Consumer α)
|
||||
/--
|
||||
Whether the channel is closed already.
|
||||
-/
|
||||
@@ -85,12 +103,18 @@ private def trySend (ch : Unbounded α) (v : α) : BaseIO Bool := do
|
||||
let st ← get
|
||||
if st.closed then
|
||||
return false
|
||||
else if let some (consumer, consumers) := st.consumers.dequeue? then
|
||||
consumer.resolve (some v)
|
||||
set { st with consumers }
|
||||
return true
|
||||
else
|
||||
set { st with values := st.values.enqueue v }
|
||||
while true do
|
||||
let st ← get
|
||||
if let some (consumer, consumers) := st.consumers.dequeue? then
|
||||
let success ← consumer.resolve (some v)
|
||||
set { st with consumers }
|
||||
if success then
|
||||
break
|
||||
else
|
||||
set { st with values := st.values.enqueue v }
|
||||
break
|
||||
|
||||
return true
|
||||
|
||||
private def send (ch : Unbounded α) (v : α) : BaseIO (Task (Except Error Unit)) := do
|
||||
@@ -103,7 +127,8 @@ private def close (ch : Unbounded α) : EIO Error Unit := do
|
||||
ch.state.atomically do
|
||||
let st ← get
|
||||
if st.closed then throw .alreadyClosed
|
||||
for consumer in st.consumers.toArray do consumer.resolve none
|
||||
for consumer in st.consumers.toArray do
|
||||
discard <| consumer.resolve none
|
||||
set { st with consumers := ∅, closed := true }
|
||||
return ()
|
||||
|
||||
@@ -111,7 +136,8 @@ private def isClosed (ch : Unbounded α) : BaseIO Bool :=
|
||||
ch.state.atomically do
|
||||
return (← get).closed
|
||||
|
||||
private def tryRecv' : AtomicT (Unbounded.State α) BaseIO (Option α) := do
|
||||
private def tryRecv' [Monad m] [MonadLiftT (ST IO.RealWorld) m] :
|
||||
AtomicT (Unbounded.State α) m (Option α) := do
|
||||
let st ← get
|
||||
if let some (a, values) := st.values.dequeue? then
|
||||
set { st with values }
|
||||
@@ -131,9 +157,43 @@ private def recv (ch : Unbounded α) : BaseIO (Task (Option α)) := do
|
||||
return .pure none
|
||||
else
|
||||
let promise ← IO.Promise.new
|
||||
modify fun st => { st with consumers := st.consumers.enqueue promise }
|
||||
modify fun st => { st with consumers := st.consumers.enqueue (.normal promise) }
|
||||
return promise.result?.map (sync := true) (·.bind id)
|
||||
|
||||
@[inline]
|
||||
private def recvReady' [Monad m] [MonadLiftT (ST IO.RealWorld) m] :
|
||||
AtomicT (Unbounded.State α) m Bool := do
|
||||
let st ← get
|
||||
return !st.values.isEmpty || st.closed
|
||||
|
||||
open Internal.IO.Async in
|
||||
private def recvSelector (ch : Unbounded α) : Selector (Option α) :=
|
||||
{
|
||||
tryFn := do
|
||||
ch.state.atomically do
|
||||
if ← recvReady' then
|
||||
let val ← tryRecv'
|
||||
return some val
|
||||
else
|
||||
return none
|
||||
|
||||
registerFn waiter := do
|
||||
ch.state.atomically do
|
||||
-- We did drop the lock between `tryFn` and now so maybe ready?
|
||||
if ← recvReady' then
|
||||
let loose := return ()
|
||||
let win promise := do
|
||||
-- We know we are ready so the value by this is fine
|
||||
promise.resolve (.ok (← tryRecv'))
|
||||
|
||||
waiter.race loose win
|
||||
else
|
||||
modify fun st => { st with consumers := st.consumers.enqueue (.select waiter) }
|
||||
|
||||
-- TODO: gc
|
||||
unregisterFn := return ()
|
||||
}
|
||||
|
||||
end Unbounded
|
||||
|
||||
/--
|
||||
@@ -150,7 +210,7 @@ private structure Zero.State (α : Type) where
|
||||
Consumers that are blocked on a producer providing them a value. The `IO.Promise` will be resolved
|
||||
to `none` if the channel closes.
|
||||
-/
|
||||
consumers : Std.Queue (IO.Promise (Option α))
|
||||
consumers : Std.Queue (Consumer α)
|
||||
/--
|
||||
Whether the channel is closed already.
|
||||
-/
|
||||
@@ -174,13 +234,17 @@ private def new : BaseIO (Zero α) := do
|
||||
Precondition: The channel must not be closed.
|
||||
-/
|
||||
private def trySend' (v : α) : AtomicT (Zero.State α) BaseIO Bool := do
|
||||
let st ← get
|
||||
if let some (consumer, consumers) := st.consumers.dequeue? then
|
||||
consumer.resolve (some v)
|
||||
set { st with consumers }
|
||||
return true
|
||||
else
|
||||
return false
|
||||
while true do
|
||||
let st ← get
|
||||
if let some (consumer, consumers) := st.consumers.dequeue? then
|
||||
let success ← consumer.resolve (some v)
|
||||
set { st with consumers }
|
||||
if success then
|
||||
break
|
||||
else
|
||||
return false
|
||||
|
||||
return true
|
||||
|
||||
private def trySend (ch : Zero α) (v : α) : BaseIO Bool := do
|
||||
ch.state.atomically do
|
||||
@@ -207,7 +271,8 @@ private def close (ch : Zero α) : EIO Error Unit := do
|
||||
ch.state.atomically do
|
||||
let st ← get
|
||||
if st.closed then throw .alreadyClosed
|
||||
for consumer in st.consumers.toArray do consumer.resolve none
|
||||
for consumer in st.consumers.toArray do
|
||||
discard <| consumer.resolve none
|
||||
set { st with consumers := ∅, closed := true }
|
||||
return ()
|
||||
|
||||
@@ -215,7 +280,8 @@ private def isClosed (ch : Zero α) : BaseIO Bool :=
|
||||
ch.state.atomically do
|
||||
return (← get).closed
|
||||
|
||||
private def tryRecv' : AtomicT (Zero.State α) BaseIO (Option α) := do
|
||||
private def tryRecv' [Monad m] [MonadLiftT (ST IO.RealWorld) m] [MonadLiftT BaseIO m] :
|
||||
AtomicT (Zero.State α) m (Option α) := do
|
||||
let st ← get
|
||||
if let some ((val, promise), producers) := st.producers.dequeue? then
|
||||
set { st with producers }
|
||||
@@ -235,11 +301,45 @@ private def recv (ch : Zero α) : BaseIO (Task (Option α)) := do
|
||||
return .pure <| some val
|
||||
else if !st.closed then
|
||||
let promise ← IO.Promise.new
|
||||
set { st with consumers := st.consumers.enqueue promise }
|
||||
set { st with consumers := st.consumers.enqueue (.normal promise) }
|
||||
return promise.result?.map (sync := true) (·.bind id)
|
||||
else
|
||||
return .pure <| none
|
||||
|
||||
@[inline]
|
||||
private def recvReady' [Monad m] [MonadLiftT (ST IO.RealWorld) m] :
|
||||
AtomicT (Zero.State α) m Bool := do
|
||||
let st ← get
|
||||
return !st.producers.isEmpty || st.closed
|
||||
|
||||
open Internal.IO.Async in
|
||||
private def recvSelector (ch : Zero α) : Selector (Option α) :=
|
||||
{
|
||||
tryFn := do
|
||||
ch.state.atomically do
|
||||
if ← recvReady' then
|
||||
let val ← tryRecv'
|
||||
return some val
|
||||
else
|
||||
return none
|
||||
|
||||
registerFn waiter := do
|
||||
ch.state.atomically do
|
||||
-- We did drop the lock between `tryFn` and now so maybe ready?
|
||||
if ← recvReady' then
|
||||
let loose := return ()
|
||||
let win promise := do
|
||||
-- We know we are ready so the value by this is fine
|
||||
promise.resolve (.ok (← tryRecv'))
|
||||
|
||||
waiter.race loose win
|
||||
else
|
||||
modify fun st => { st with consumers := st.consumers.enqueue (.select waiter) }
|
||||
|
||||
-- TODO: gc
|
||||
unregisterFn := return ()
|
||||
}
|
||||
|
||||
end Zero
|
||||
|
||||
/--
|
||||
@@ -390,7 +490,8 @@ private def isClosed (ch : Bounded α) : BaseIO Bool :=
|
||||
ch.state.atomically do
|
||||
return (← get).closed
|
||||
|
||||
private def tryRecv' : AtomicT (Bounded.State α) BaseIO (Option α) := do
|
||||
private def tryRecv' [Monad m] [MonadLiftT (ST IO.RealWorld) m] [MonadLiftT BaseIO m] :
|
||||
AtomicT (Bounded.State α) m (Option α) := do
|
||||
let mut st ← get
|
||||
if st.bufCount == 0 then
|
||||
return none
|
||||
@@ -430,6 +531,61 @@ private partial def recv (ch : Bounded α) : BaseIO (Task (Option α)) := do
|
||||
else
|
||||
return .pure none
|
||||
|
||||
@[inline]
|
||||
private def recvReady' [Monad m] [MonadLiftT (ST IO.RealWorld) m] :
|
||||
AtomicT (Bounded.State α) m Bool := do
|
||||
let st ← get
|
||||
return st.bufCount != 0 || st.closed
|
||||
|
||||
open Internal.IO.Async in
|
||||
private partial def recvSelector (ch : Bounded α) : Selector (Option α) :=
|
||||
{
|
||||
tryFn := do
|
||||
ch.state.atomically do
|
||||
if ← recvReady' then
|
||||
let val ← tryRecv'
|
||||
return some val
|
||||
else
|
||||
return none
|
||||
|
||||
registerFn := registerAux ch
|
||||
|
||||
-- TODO: gc
|
||||
unregisterFn := return ()
|
||||
}
|
||||
where
|
||||
registerAux (ch : Bounded α) (waiter : Waiter (Option α)) : IO Unit := do
|
||||
ch.state.atomically do
|
||||
-- We did drop the lock between `tryFn` and now so maybe ready?
|
||||
if ← recvReady' then
|
||||
let loose := return ()
|
||||
let win promise := do
|
||||
-- We know we are ready so the value by this is fine
|
||||
promise.resolve (.ok (← tryRecv'))
|
||||
|
||||
waiter.race loose win
|
||||
else
|
||||
let promise ← IO.Promise.new
|
||||
modify fun st => { st with consumers := st.consumers.enqueue promise }
|
||||
|
||||
IO.chainTask promise.result? fun res? => do
|
||||
match res? with
|
||||
| none => return ()
|
||||
| some res =>
|
||||
if res then
|
||||
registerAux ch waiter
|
||||
else
|
||||
-- if we loose we must trigger the next promise (if available) to avoid deadlocking
|
||||
let loose := do
|
||||
ch.state.atomically do
|
||||
let st ← get
|
||||
if let some (consumer, consumers) := st.consumers.dequeue? then
|
||||
consumer.resolve true
|
||||
set { st with consumers }
|
||||
|
||||
let win promise := promise.resolve (.ok none)
|
||||
waiter.race loose win
|
||||
|
||||
end Bounded
|
||||
|
||||
/--
|
||||
@@ -551,6 +707,18 @@ def recv (ch : CloseableChannel α) : BaseIO (Task (Option α)) :=
|
||||
| .zero ch => CloseableChannel.Zero.recv ch
|
||||
| .bounded ch => CloseableChannel.Bounded.recv ch
|
||||
|
||||
open Internal.IO.Async in
|
||||
/--
|
||||
Create a `Selector` that resolves once `ch` has data available and provides that that data.
|
||||
In particular if `ch` is closed while waiting on this `Selector` and no data is available already
|
||||
this will resolve to `none`.
|
||||
-/
|
||||
def recvSelector (ch : CloseableChannel α) : Selector (Option α) :=
|
||||
match ch with
|
||||
| .unbounded ch => CloseableChannel.Unbounded.recvSelector ch
|
||||
| .zero ch => CloseableChannel.Zero.recvSelector ch
|
||||
| .bounded ch => CloseableChannel.Bounded.recvSelector ch
|
||||
|
||||
/--
|
||||
`ch.forAsync f` calls `f` for every message received on `ch`.
|
||||
|
||||
@@ -674,6 +842,29 @@ def recv [Inhabited α] (ch : Channel α) : BaseIO (Task α) := do
|
||||
| some val => return .pure val
|
||||
| none => unreachable!
|
||||
|
||||
open Internal.IO.Async in
|
||||
/--
|
||||
Create a `Selector` that resolves once `ch` has data available and provides that that data.
|
||||
-/
|
||||
def recvSelector [Inhabited α] (ch : Channel α) : Selector α :=
|
||||
let sel := CloseableChannel.recvSelector ch.inner
|
||||
{
|
||||
tryFn := ch.tryRecv
|
||||
registerFn waiter := do
|
||||
let original := waiter.promise
|
||||
let intermediate ← IO.Promise.new
|
||||
let waiter := waiter.withPromise intermediate
|
||||
sel.registerFn waiter
|
||||
IO.chainTask (sync := true) intermediate.result?
|
||||
fun
|
||||
| none => return ()
|
||||
| some res =>
|
||||
-- `res` can only be `.err` or `.ok some` as we are in a non closeable channel.
|
||||
original.resolve (res.map Option.get!)
|
||||
|
||||
unregisterFn := sel.unregisterFn
|
||||
}
|
||||
|
||||
@[inherit_doc CloseableChannel.forAsync]
|
||||
partial def forAsync [Inhabited α] (f : α → BaseIO Unit) (ch : Channel α)
|
||||
(prio : Task.Priority := .default) : BaseIO (Task Unit) := do
|
||||
|
||||
@@ -227,7 +227,7 @@ extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_recv(b_obj_arg socket, uint64_t
|
||||
// Locking early prevents potential parallelism issues setting the byte_array.
|
||||
event_loop_lock(&global_ev);
|
||||
|
||||
if (tcp_socket->m_byte_array != nullptr) {
|
||||
if (tcp_socket->m_promise_read != nullptr) {
|
||||
event_loop_unlock(&global_ev);
|
||||
return lean_io_result_mk_error(lean_decode_uv_error(UV_EALREADY, nullptr));
|
||||
}
|
||||
@@ -295,6 +295,102 @@ extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_recv(b_obj_arg socket, uint64_t
|
||||
return lean_io_result_mk_ok(promise);
|
||||
}
|
||||
|
||||
/* Std.Internal.UV.TCP.Socket.waitReadable (socket : @& Socket) : IO (IO.Promise (Except IO.Error Bool)) */
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_wait_readable(b_obj_arg socket, obj_arg /* w */) {
|
||||
lean_uv_tcp_socket_object* tcp_socket = lean_to_uv_tcp_socket(socket);
|
||||
|
||||
event_loop_lock(&global_ev);
|
||||
|
||||
if (tcp_socket->m_promise_read != nullptr) {
|
||||
event_loop_unlock(&global_ev);
|
||||
return lean_io_result_mk_error(lean_decode_uv_error(UV_EALREADY, nullptr));
|
||||
}
|
||||
|
||||
lean_object* promise = lean_promise_new();
|
||||
mark_mt(promise);
|
||||
|
||||
tcp_socket->m_promise_read = promise;
|
||||
|
||||
// The event loop owns the socket.
|
||||
lean_inc(socket);
|
||||
lean_inc(promise);
|
||||
|
||||
int result = uv_read_start((uv_stream_t*)tcp_socket->m_uv_tcp, [](uv_handle_t* handle, size_t suggested_size, uv_buf_t* buf) {
|
||||
// According to libuv documentation if we do this we do not loose data and a UV_ENOBUFS will
|
||||
// be triggered in the read cb.
|
||||
buf->base = NULL;
|
||||
buf->len = 0;
|
||||
}, [](uv_stream_t* stream, ssize_t nread, const uv_buf_t* buf) {
|
||||
uv_read_stop(stream);
|
||||
|
||||
lean_uv_tcp_socket_object* tcp_socket = lean_to_uv_tcp_socket((lean_object*)stream->data);
|
||||
lean_object* promise = tcp_socket->m_promise_read;
|
||||
|
||||
tcp_socket->m_promise_read = nullptr;
|
||||
|
||||
if (nread == UV_ENOBUFS) {
|
||||
lean_promise_resolve(mk_except_ok(lean_box(1)), promise);
|
||||
} else if (nread == UV_EOF) {
|
||||
lean_promise_resolve(mk_except_ok(lean_box(0)), promise);
|
||||
} else if (nread < 0) {
|
||||
lean_promise_resolve(mk_except_err(lean_decode_uv_error(nread, nullptr)), promise);
|
||||
} else {
|
||||
// This branch should be dead, we cannot receive a value >= 0 according to docs.
|
||||
lean_always_assert(false);
|
||||
}
|
||||
|
||||
lean_dec(promise);
|
||||
|
||||
// The event loop does not own the object anymore.
|
||||
lean_dec((lean_object*)stream->data);
|
||||
});
|
||||
|
||||
if (result < 0) {
|
||||
tcp_socket->m_promise_read = nullptr;
|
||||
|
||||
event_loop_unlock(&global_ev);
|
||||
|
||||
lean_dec(promise); // The structure does not own it.
|
||||
lean_dec(promise); // We are not going to return it.
|
||||
lean_dec(socket);
|
||||
|
||||
return lean_io_result_mk_error(lean_decode_uv_error(result, nullptr));
|
||||
}
|
||||
|
||||
event_loop_unlock(&global_ev);
|
||||
|
||||
return lean_io_result_mk_ok(promise);
|
||||
}
|
||||
|
||||
/* Std.Internal.UV.TCP.Socket.cancelRecv (socket : @& Socket) : IO Unit */
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_cancel_recv(b_obj_arg socket, obj_arg /* w */) {
|
||||
lean_uv_tcp_socket_object* tcp_socket = lean_to_uv_tcp_socket(socket);
|
||||
|
||||
event_loop_lock(&global_ev);
|
||||
|
||||
if (tcp_socket->m_promise_read == nullptr) {
|
||||
event_loop_unlock(&global_ev);
|
||||
return lean_io_result_mk_ok(lean_box(0));
|
||||
}
|
||||
|
||||
uv_read_stop((uv_stream_t*)tcp_socket->m_uv_tcp);
|
||||
|
||||
lean_object* promise = tcp_socket->m_promise_read;
|
||||
lean_dec(promise);
|
||||
tcp_socket->m_promise_read = nullptr;
|
||||
|
||||
lean_object* byte_array = tcp_socket->m_byte_array;
|
||||
if (byte_array != nullptr) {
|
||||
lean_dec(byte_array);
|
||||
tcp_socket->m_byte_array = nullptr;
|
||||
}
|
||||
|
||||
lean_dec((lean_object*)tcp_socket);
|
||||
|
||||
event_loop_unlock(&global_ev);
|
||||
return lean_io_result_mk_ok(lean_box(0));
|
||||
}
|
||||
|
||||
/* Std.Internal.UV.TCP.Socket.bind (socket : @& Socket) (addr : @& SocketAddress) : IO Unit */
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_bind(b_obj_arg socket, b_obj_arg addr, obj_arg /* w */) {
|
||||
lean_uv_tcp_socket_object* tcp_socket = lean_to_uv_tcp_socket(socket);
|
||||
|
||||
@@ -42,6 +42,8 @@ extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_new(obj_arg /* w */);
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_connect(b_obj_arg socket, b_obj_arg addr, obj_arg /* w */);
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_send(b_obj_arg socket, obj_arg data, obj_arg /* w */);
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_recv(b_obj_arg socket, uint64_t buffer_size, obj_arg /* w */);
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_wait_readable(b_obj_arg socket, obj_arg /* w */);
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_cancel_recv(b_obj_arg socket, obj_arg /* w */);
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_bind(b_obj_arg socket, b_obj_arg addr, obj_arg /* w */);
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_listen(b_obj_arg socket, int32_t backlog, obj_arg /* w */);
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_accept(b_obj_arg socket, obj_arg /* w */);
|
||||
|
||||
99
tests/lean/run/async_select_channel.lean
Normal file
99
tests/lean/run/async_select_channel.lean
Normal file
@@ -0,0 +1,99 @@
|
||||
import Std.Sync.Channel
|
||||
|
||||
open Std Internal IO Async
|
||||
|
||||
namespace A
|
||||
|
||||
def testReceiver (ch1 ch2 : Std.Channel Nat) (count : Nat) : IO (AsyncTask Nat) := do
|
||||
go ch1 ch2 count 0
|
||||
where
|
||||
go (ch1 ch2 : Std.Channel Nat) (count : Nat) (acc : Nat) : IO (AsyncTask Nat) := do
|
||||
match count with
|
||||
| 0 => return AsyncTask.pure acc
|
||||
| count + 1 =>
|
||||
Selectable.one #[
|
||||
.case ch1.recvSelector fun data => go ch1 ch2 count (acc + data),
|
||||
.case ch2.recvSelector fun data => go ch1 ch2 count (acc + data),
|
||||
]
|
||||
|
||||
def testIt (capacity : Option Nat) : IO Bool := do
|
||||
let amount := 1000
|
||||
let messages := Array.range amount
|
||||
let ch1 ← Std.Channel.new capacity
|
||||
let ch2 ← Std.Channel.new capacity
|
||||
let recvTask ← testReceiver ch1 ch2 amount
|
||||
|
||||
for msg in messages do
|
||||
if (← IO.rand 0 1) = 0 then
|
||||
ch1.sync.send msg
|
||||
else
|
||||
ch2.sync.send msg
|
||||
|
||||
let acc ← recvTask.block
|
||||
return acc == messages.sum
|
||||
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval testIt none
|
||||
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval testIt (some 0)
|
||||
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval testIt (some 1)
|
||||
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval testIt (some 128)
|
||||
|
||||
end A
|
||||
|
||||
namespace B
|
||||
|
||||
def testReceiver (ch1 ch2 : Std.CloseableChannel Nat) (count : Nat) : IO (AsyncTask Nat) := do
|
||||
go ch1 ch2 count 0
|
||||
where
|
||||
go (ch1 ch2 : Std.CloseableChannel Nat) (count : Nat) (acc : Nat) : IO (AsyncTask Nat) := do
|
||||
match count with
|
||||
| 0 => return AsyncTask.pure acc
|
||||
| count + 1 =>
|
||||
Selectable.one #[
|
||||
.case ch1.recvSelector fun data => go ch1 ch2 count (acc + data.getD 0),
|
||||
.case ch2.recvSelector fun data => go ch1 ch2 count (acc + data.getD 0),
|
||||
]
|
||||
|
||||
def testIt (capacity : Option Nat) : IO Bool := do
|
||||
let amount := 1000
|
||||
let messages := Array.range amount
|
||||
let ch1 ← Std.CloseableChannel.new capacity
|
||||
let ch2 ← Std.CloseableChannel.new capacity
|
||||
let recvTask ← testReceiver ch1 ch2 amount
|
||||
|
||||
for msg in messages do
|
||||
if (← IO.rand 0 1) = 0 then
|
||||
ch1.sync.send msg
|
||||
else
|
||||
ch2.sync.send msg
|
||||
|
||||
let acc ← recvTask.block
|
||||
return acc == messages.sum
|
||||
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval testIt none
|
||||
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval testIt (some 0)
|
||||
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval testIt (some 1)
|
||||
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval testIt (some 128)
|
||||
|
||||
end B
|
||||
50
tests/lean/run/async_select_socket.lean
Normal file
50
tests/lean/run/async_select_socket.lean
Normal file
@@ -0,0 +1,50 @@
|
||||
import Std.Internal.Async.Timer
|
||||
import Std.Internal.Async.TCP
|
||||
|
||||
open Std Internal IO Async
|
||||
|
||||
def testClient (addr : Net.SocketAddress) : IO (AsyncTask String) := do
|
||||
let client ← TCP.Socket.Client.mk
|
||||
(← client.connect addr).bindIO fun _ => do
|
||||
Selectable.one #[
|
||||
.case (← Selector.sleep 1000) fun _ => return AsyncTask.pure "Timeout",
|
||||
.case (← client.recvSelector 4096) fun data? => do
|
||||
if let some data := data? then
|
||||
return AsyncTask.pure <| String.fromUTF8! data
|
||||
else
|
||||
return AsyncTask.pure "Closed"
|
||||
]
|
||||
|
||||
def test (serverFn : TCP.Socket.Server → IO (AsyncTask Unit)) (addr : Net.SocketAddress) :
|
||||
IO Unit := do
|
||||
let server ← TCP.Socket.Server.mk
|
||||
server.bind addr
|
||||
server.listen 1
|
||||
let serverTask ← serverFn server
|
||||
let clientTask ← testClient addr
|
||||
serverTask.block
|
||||
IO.println (← clientTask.block)
|
||||
|
||||
def testServerSend (server : TCP.Socket.Server) : IO (AsyncTask Unit) := do
|
||||
(← server.accept).bindIO fun client => do
|
||||
client.send (String.toUTF8 "Success")
|
||||
|
||||
def testServerTimeout (server : TCP.Socket.Server) : IO (AsyncTask Unit) := do
|
||||
(← server.accept).bindIO fun client => do
|
||||
(← Async.sleep 1500).bindIO fun _ => do
|
||||
client.shutdown
|
||||
|
||||
def testServerClose (server : TCP.Socket.Server) : IO (AsyncTask Unit) := do
|
||||
(← server.accept).bindIO fun client => client.shutdown
|
||||
|
||||
/-- info: Success -/
|
||||
#guard_msgs in
|
||||
#eval test testServerSend (Net.SocketAddressV4.mk (.ofParts 127 0 0 1) 7070)
|
||||
|
||||
/-- info: Closed -/
|
||||
#guard_msgs in
|
||||
#eval test testServerClose (Net.SocketAddressV4.mk (.ofParts 127 0 0 1) 7071)
|
||||
|
||||
/-- info: Timeout -/
|
||||
#guard_msgs in
|
||||
#eval test testServerTimeout (Net.SocketAddressV4.mk (.ofParts 127 0 0 1) 7072)
|
||||
52
tests/lean/run/async_select_timer.lean
Normal file
52
tests/lean/run/async_select_timer.lean
Normal file
@@ -0,0 +1,52 @@
|
||||
import Std.Internal.Async.Timer
|
||||
|
||||
open Std Internal IO Async
|
||||
|
||||
def test1 : IO (AsyncTask Nat) := do
|
||||
let s1 ← Sleep.mk 1000
|
||||
let s2 ← Sleep.mk 1500
|
||||
Selectable.one #[
|
||||
.case (← s2.selector) fun _ => return AsyncTask.pure 2,
|
||||
.case (← s1.selector) fun _ => return AsyncTask.pure 1,
|
||||
]
|
||||
|
||||
/-- info: 1 -/
|
||||
#guard_msgs in
|
||||
#eval show IO _ from do
|
||||
let task ← test1
|
||||
IO.ofExcept task.get
|
||||
|
||||
def test2Helper (dur : Time.Millisecond.Offset) : IO (AsyncTask Nat) := do
|
||||
Selectable.one #[
|
||||
.case (← Selector.sleep dur) fun _ => return AsyncTask.pure 1,
|
||||
.case (← Selector.sleep dur) fun _ => return AsyncTask.pure 2,
|
||||
.case (← Selector.sleep dur) fun _ => return AsyncTask.pure 3,
|
||||
.case (← Selector.sleep dur) fun _ => return AsyncTask.pure 4,
|
||||
.case (← Selector.sleep dur) fun _ => return AsyncTask.pure 5,
|
||||
.case (← Selector.sleep dur) fun _ => return AsyncTask.pure 6,
|
||||
.case (← Selector.sleep dur) fun _ => return AsyncTask.pure 7,
|
||||
.case (← Selector.sleep dur) fun _ => return AsyncTask.pure 8,
|
||||
.case (← Selector.sleep dur) fun _ => return AsyncTask.pure 9,
|
||||
.case (← Selector.sleep dur) fun _ => return AsyncTask.pure 10,
|
||||
]
|
||||
|
||||
def test2 (dur : Time.Millisecond.Offset) : IO Bool := do
|
||||
let r1 ← IO.ofExcept (← test2Helper dur).get
|
||||
let r2 ← IO.ofExcept (← test2Helper dur).get
|
||||
let r3 ← IO.ofExcept (← test2Helper dur).get
|
||||
let r4 ← IO.ofExcept (← test2Helper dur).get
|
||||
let r5 ← IO.ofExcept (← test2Helper dur).get
|
||||
let r6 ← IO.ofExcept (← test2Helper dur).get
|
||||
let r7 ← IO.ofExcept (← test2Helper dur).get
|
||||
let r8 ← IO.ofExcept (← test2Helper dur).get
|
||||
let r9 ← IO.ofExcept (← test2Helper dur).get
|
||||
let r10 ← IO.ofExcept (← test2Helper dur).get
|
||||
return #[r2, r3, r4, r5, r6, r7, r8, r9, r10].any (· != r1)
|
||||
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval test2 100
|
||||
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval test2 0
|
||||
Reference in New Issue
Block a user