Compare commits

...

18 Commits

Author SHA1 Message Date
Sofia Rodrigues
8810bbe140 fix: comment 2025-12-15 09:53:30 -03:00
Sofia Rodrigues
349c860b8b fix: function names 2025-12-15 09:53:30 -03:00
Sofia Rodrigues
6a424ee4e6 fix: notes and concurrently 2025-12-15 09:53:30 -03:00
Sofia Rodrigues
07c5465052 feat: split async function 2025-12-15 09:53:30 -03:00
Sofia Rodrigues
7ff00b14af fix: remove fork function and added notes 2025-12-15 09:53:30 -03:00
Sofia Rodrigues
debdf61d73 fix: test 2025-12-15 09:53:30 -03:00
Sofia Rodrigues
6d834b29fb feat: countAliveTokens and background 2025-12-15 09:53:30 -03:00
Sofia Rodrigues
ee8ad4e679 test: fix 2025-12-15 09:53:30 -03:00
Sofia Rodrigues
d3c70a6cc0 fix: concurrently 2025-12-15 09:53:30 -03:00
Sofia Rodrigues
b796cd7714 fix: name and remove backgroudn 2025-12-15 09:53:30 -03:00
Sofia Rodrigues
12b646e538 test: async context test 2025-12-15 09:53:30 -03:00
Sofia Rodrigues
ab292870fa feat: add selector.cancelled function 2025-12-15 09:53:30 -03:00
Sofia Rodrigues
60cf9f0cc4 fix: comments 2025-12-15 09:53:30 -03:00
Sofia Rodrigues
c33dcf5c1e feat: add contextual monad 2025-12-15 09:53:30 -03:00
Sofia Rodrigues
7da031df86 fix: tests 2025-12-15 09:53:30 -03:00
Sofia Rodrigues
81a0c28828 fix: comment 2025-12-15 09:53:30 -03:00
Sofia Rodrigues
1ead9eb8ac fix: comments 2025-12-15 09:53:29 -03:00
Sofia Rodrigues
014b4c8a90 feat: context 2025-12-15 09:53:29 -03:00
8 changed files with 1593 additions and 14 deletions

View File

@@ -7,6 +7,7 @@ module
prelude
public import Std.Internal.Async.Basic
public import Std.Internal.Async.ContextAsync
public import Std.Internal.Async.Timer
public import Std.Internal.Async.TCP
public import Std.Internal.Async.UDP

View File

@@ -0,0 +1,273 @@
/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Sofia Rodrigues
-/
module
prelude
public import Std.Time
public import Std.Internal.UV
public import Std.Internal.Async.Basic
public import Std.Internal.Async.Timer
public import Std.Sync.CancellationContext
public section
/-!
This module contains the implementation of `ContextAsync`, a monad for asynchronous computations with
cooperative cancellation support that must be explicitly checked for and cancelled explicitly.
-/
namespace Std
namespace Internal
namespace IO
namespace Async
/--
An asynchronous computation with cooperative cancellation support via a `CancellationContext`. `ContextAsync α`
is equivalent to `ReaderT CancellationContext Async α`, providing a `CancellationContext` value to async
computations.
-/
abbrev ContextAsync (α : Type) := ReaderT CancellationContext Async α
namespace ContextAsync
/--
Runs a `ContextAsync` computation with a given context. See also `ContextAsync.run` for running with a new
context that automatically cancels after execution.
-/
@[inline]
protected def runIn (ctx : CancellationContext) (x : ContextAsync α) : Async α :=
x ctx
/--
Runs a `ContextAsync` computation with a new context that cancels after the execution of the computation.
See also `ContextAsync.runIn` for running with an existing context.
-/
@[inline]
protected def run (x : ContextAsync α) : Async α := do
let ctx CancellationContext.new
x ctx <* ctx.cancel .cancel
/--
Returns the current context for inspection or to pass to other functions.
-/
@[inline]
def getContext : ContextAsync CancellationContext :=
fun ctx => pure ctx
/--
Checks if the current context is cancelled. Returns `true` if the context (or any ancestor) has been cancelled.
Long-running operations should periodically check this and exit gracefully when cancelled.
-/
@[inline]
def isCancelled : ContextAsync Bool := do
let ctx getContext
ctx.isCancelled
/--
Gets the cancellation reason if the context is cancelled. Returns `some reason` if cancelled, `none` otherwise,
allowing you to distinguish between different cancellation types.
-/
@[inline]
def getCancellationReason : ContextAsync (Option CancellationReason) := do
let ctx getContext
ctx.getCancellationReason
/--
Cancels the current context with the given reason, cascading to all child contexts.
Cancellation is cooperative, operations must explicitly check `isCancelled` or use `awaitCancellation` to respond.
-/
@[inline]
def cancel (reason : CancellationReason) : ContextAsync Unit := do
let ctx getContext
ctx.cancel reason
/--
Returns a selector that completes when the current context is cancelled.
-/
@[inline]
def doneSelector : ContextAsync (Selector Unit) := do
let ctx getContext
return ctx.doneSelector
/--
Waits for the current context to be cancelled.
-/
@[inline]
def awaitCancellation : ContextAsync Unit := do
let ctx getContext
let task ctx.done
await task
/--
Runs two computations concurrently and returns both results. Each computation runs in its own child context;
if either fails or is cancelled, both are cancelled immediately and the exception is propagated.
-/
@[inline, specialize]
def concurrently (x : ContextAsync α) (y : ContextAsync β)
(prio := Task.Priority.default) : ContextAsync (α × β) := do
let ctx getContext
let concurrentCtx ctx.fork
let childCtx1 concurrentCtx.fork
let childCtx2 concurrentCtx.fork
let result Async.concurrently
(try x childCtx1 catch err => do concurrentCtx.cancel .cancel; throw err finally childCtx1.cancel .cancel)
(try y childCtx2 catch err => do concurrentCtx.cancel .cancel; throw err finally childCtx2.cancel .cancel)
prio
concurrentCtx.cancel .cancel
return result
/--
Runs two computations concurrently and returns the result of the first to complete. Each computation runs
in its own child context; when either completes, the other is cancelled immediately.
-/
@[inline, specialize]
def race [Inhabited α] (x : ContextAsync α) (y : ContextAsync α)
(prio := Task.Priority.default) : ContextAsync α := do
let parent getContext
let ctx1 CancellationContext.fork parent
let ctx2 CancellationContext.fork parent
let task1 async (x ctx1) prio
let task2 async (y ctx2) prio
let result Async.race
(await task1 <* ctx2.cancel .cancel)
(await task2 <* ctx1.cancel .cancel)
prio
pure result
/--
Runs all computations concurrently and collects results in the same order. Each runs in its own child context;
if any computation fails, all others are cancelled and the exception is propagated.
-/
@[inline, specialize]
def concurrentlyAll (xs : Array (ContextAsync α))
(prio := Task.Priority.default) : ContextAsync (Array α) := do
let ctx getContext
let concurrentCtx ctx.fork
let tasks : Array (AsyncTask α) xs.mapM fun ctxAsync => do
let childCtx concurrentCtx.fork
async (prio := prio)
(try
ctxAsync childCtx
catch err => do
concurrentCtx.cancel .cancel
throw err
finally
childCtx.cancel .cancel)
let result tasks.mapM await
return result
/--
Launches a `ContextAsync` computation in the background, discarding its result.
The computation runs independently in the background in its own child context. The parent computation does not wait
for background tasks to complete. This means that if the parent finishes its execution it will cause
the cancellation of the background functions. See also `disown` for launching tasks that continue independently
even after parent cancellation.
-/
@[inline, specialize]
def background (action : ContextAsync α) (prio := Task.Priority.default) : ContextAsync Unit := do
let ctx getContext
let childCtx ctx.fork
Async.background (action childCtx *> childCtx.cancel .cancel) prio
/--
Launches a `ContextAsync` computation in the background, discarding its result. It's similar to `background`,
but the child context is not automatically cancelled when the action completes. This allows the disowned
computation to continue running independently, even if the parent context is cancelled. The child context
will remain alive as long as the computation needs it. See also `background` for launching tasks that are
cancelled when the parent finishes.
-/
@[inline, specialize]
def disown (action : ContextAsync α) (prio := Task.Priority.default) : ContextAsync Unit := do
let childCtx CancellationContext.new
Async.background (action childCtx) prio
/--
Runs all computations concurrently and returns the first result. Each computation runs in its own child context;
when the first completes successfully, all others are cancelled immediately.
-/
def raceAll [ForM ContextAsync c (ContextAsync α)] (xs : c)
(prio := Task.Priority.default) : ContextAsync α := do
let parent getContext
let promise IO.Promise.new
ForM.forM xs fun x => do
let ctx CancellationContext.fork parent
let task async (x ctx) prio
background do
try
let result await task
promise.resolve (.ok result)
catch e =>
discard $ promise.resolve (.error e)
let result await promise
parent.cancel .cancel
Async.ofExcept result
/--
Launches a `ContextAsync` computation as an asynchronous task with a forked child context.
The child context is automatically cancelled when the task completes or fails.
-/
@[inline, specialize]
def async (x : ContextAsync α) (prio := Task.Priority.default) : ContextAsync (AsyncTask α) :=
fun ctx => do
let childCtx ctx.fork
Async.async (try x childCtx finally childCtx.cancel .cancel) prio
instance : MonadAsync AsyncTask ContextAsync where
async x prio := ContextAsync.async x prio
instance : Functor ContextAsync where
map f x := fun ctx => f <$> x ctx
instance : Monad ContextAsync where
pure a := fun _ => pure a
bind x f := fun ctx => x ctx >>= fun a => f a ctx
instance : MonadLift IO ContextAsync where
monadLift x := fun _ => Async.ofIOTask (Task.pure <$> x)
instance : MonadLift BaseIO ContextAsync where
monadLift x := fun _ => liftM (m := Async) x
instance : MonadExcept IO.Error ContextAsync where
throw e := fun _ => throw e
tryCatch x h := fun ctx => tryCatch (x ctx) (fun e => h e ctx)
instance : MonadFinally ContextAsync where
tryFinally' x f := fun ctx =>
tryFinally' (x ctx) (fun opt => f opt ctx)
instance [Inhabited α] : Inhabited (ContextAsync α) where
default := fun _ => default
instance : MonadAwait AsyncTask ContextAsync where
await t := fun _ => await t
end ContextAsync
/--
Returns a selector that completes when the current context is cancelled.
This is useful for selecting on cancellation alongside other asynchronous operations.
-/
def Selector.cancelled : ContextAsync (Selector Unit) := do
ContextAsync.doneSelector
end Async
end IO
end Internal
end Std

View File

@@ -16,5 +16,6 @@ public import Std.Sync.Notify
public import Std.Sync.Broadcast
public import Std.Sync.StreamMap
public import Std.Sync.CancellationToken
public import Std.Sync.CancellationContext
@[expose] public section

View File

@@ -0,0 +1,152 @@
/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Sofia Rodrigues
-/
module
prelude
public import Std.Data
public import Init.System.Promise
public import Init.Data.Queue
public import Std.Sync.Mutex
public import Std.Sync.CancellationToken
public import Std.Internal.Async.Select
public section
/-!
This module provides a tree-structured cancellation context called `CancellationToken` where cancelling a parent
automatically cancels all child contexts.
-/
namespace Std
open Std.Internal.IO.Async
structure CancellationContext.State where
/--
Map of token IDs to optional tokens and their children.
-/
tokens : TreeMap UInt64 (CancellationToken × Array UInt64) := .empty
/--
Next available ID
-/
id : UInt64 := 1
/--
A cancellation context that allows multiple consumers to wait until cancellation is requested. Forms
a tree structure where cancelling a parent cancels all children.
-/
structure CancellationContext where
state : Std.Mutex CancellationContext.State
token : CancellationToken
id : UInt64
namespace CancellationContext
/--
Creates a new root cancellation context.
-/
def new : BaseIO CancellationContext := do
let token Std.CancellationToken.new
return {
state := Std.Mutex.new { tokens := .empty |>.insert 0 (token, #[]) },
token,
id := 0
}
/--
Forks a child context from a parent. If the parent is already cancelled, returns the parent context.
Otherwise, creates a new child that will be cancelled when the parent is cancelled.
-/
def fork (root : CancellationContext) : BaseIO CancellationContext := do
root.state.atomically do
if root.token.isCancelled then
return root
let token Std.CancellationToken.new
let st get
let newId := st.id
set { st with
id := newId + 1,
tokens := st.tokens.insert newId (token, #[])
|>.modify root.id (.map (·) (.push · newId))
}
return { state := root.state, token, id := newId }
/--
Recursively cancels a context and all its children with the given reason.
-/
private partial def cancelChildren (state : CancellationContext.State) (id : UInt64) (reason : CancellationReason) : BaseIO CancellationContext.State := do
let mut state := state
let some (token, children) := state.tokens.get? id
| return state
for tokenId in children do
state cancelChildren state tokenId reason
token.cancel reason
pure { state with tokens := state.tokens.erase id }
/--
Cancels this context and all child contexts with the given reason.
-/
def cancel (x : CancellationContext) (reason : CancellationReason) : BaseIO Unit := do
if x.token.isCancelled then
return
x.state.atomically do
let st get
let st cancelChildren st x.id reason
set st
/--
Checks if the context is cancelled.
-/
@[inline]
def isCancelled (x : CancellationContext) : BaseIO Bool := do
x.token.isCancelled
/--
Returns the cancellation reason if the context is cancelled.
-/
@[inline]
def getCancellationReason (x : CancellationContext) : BaseIO (Option CancellationReason) := do
x.token.getCancellationReason
/--
Waits for cancellation. Returns a task that completes when the context is cancelled.
-/
@[inline]
def done (x : CancellationContext) : IO (AsyncTask Unit) :=
x.token.wait
/--
Creates a selector that waits for cancellation.
-/
@[inline]
def doneSelector (x : CancellationContext) : Selector Unit :=
x.token.selector
private partial def countAliveTokensRec (state : CancellationContext.State) (id : UInt64) : Nat :=
match state.tokens.get? id with
| none => 0
| some (_, children) => 1 + children.foldl (fun acc childId => acc + countAliveTokensRec state childId) 0
/--
Counts the number of alive (non-cancelled) tokens in the context tree, including
this context and all its descendants.
-/
def countAliveTokens (x : CancellationContext) : BaseIO Nat := do
x.state.atomically do
let st get
return countAliveTokensRec st x.id
end CancellationContext
end Std

View File

@@ -23,6 +23,38 @@ that a cancellation has occurred.
namespace Std
open Std.Internal.IO.Async
/--
Reasons for cancellation.
-/
inductive CancellationReason where
/--
Cancelled due to a deadline or timeout
-/
| deadline
/--
Cancelled due to shutdown
-/
| shutdown
/--
Explicitly cancelled
-/
| cancel
/--
Custom cancellation reason
-/
| custom (msg : String)
deriving Repr, BEq
instance : ToString CancellationReason where
toString
| .deadline => "deadline"
| .shutdown => "shutdown"
| .cancel => "cancel"
| .custom msg => s!"custom(\"{msg}\")"
inductive CancellationToken.Consumer where
| normal (promise : IO.Promise Unit)
| select (finished : Waiter Unit)
@@ -44,9 +76,9 @@ The central state structure for a `CancellationToken`.
-/
structure CancellationToken.State where
/--
Whether this token has been cancelled.
The cancellation reason if cancelled, none otherwise.
-/
cancelled : Bool
reason : Option CancellationReason
/--
Consumers that are blocked waiting for cancellation.
@@ -63,24 +95,24 @@ structure CancellationToken where
namespace CancellationToken
/--
Create a new cancellation token.
Creates a new cancellation token.
-/
def new : BaseIO CancellationToken := do
return { state := Std.Mutex.new { cancelled := false, consumers := } }
return { state := Std.Mutex.new { reason := none, consumers := } }
/--
Cancel the token, notifying all currently waiting consumers with `true`.
Cancels the token with the given reason, notifying all currently waiting consumers.
Once cancelled, the token remains cancelled.
-/
def cancel (x : CancellationToken) : BaseIO Unit := do
def cancel (x : CancellationToken) (reason : CancellationReason := .cancel) : BaseIO Unit := do
x.state.atomically do
let mut st get
if st.cancelled then
if st.reason.isSome then
return
let mut remainingConsumers := st.consumers
st := { cancelled := true, consumers := }
st := { reason := some reason, consumers := }
while true do
if let some (consumer, rest) := remainingConsumers.dequeue? then
@@ -92,21 +124,29 @@ def cancel (x : CancellationToken) : BaseIO Unit := do
set st
/--
Check if the token is cancelled.
Checks if the token is cancelled.
-/
def isCancelled (x : CancellationToken) : BaseIO Bool := do
x.state.atomically do
let st get
return st.cancelled
return st.reason.isSome
/--
Wait for cancellation. Returns a task that completes when cancelled,
Gets the cancellation reason if the token is cancelled.
-/
def getCancellationReason (x : CancellationToken) : BaseIO (Option CancellationReason) := do
x.state.atomically do
let st get
return st.reason
/--
Waits for cancellation. Returns a task that completes when cancelled.
-/
def wait (x : CancellationToken) : IO (AsyncTask Unit) :=
x.state.atomically do
let st get
if st.cancelled then
if st.reason.isSome then
return Task.pure (.ok ())
let promise IO.Promise.new
@@ -118,7 +158,7 @@ def wait (x : CancellationToken) : IO (AsyncTask Unit) :=
| none => throw (IO.userError "cancellation token dropped")
/--
Creates a selector that waits for cancellation
Creates a selector that waits for cancellation.
-/
def selector (token : CancellationToken) : Selector Unit := {
tryFn := do
@@ -131,7 +171,7 @@ def selector (token : CancellationToken) : Selector Unit := {
token.state.atomically do
let st get
if st.cancelled then
if st.reason.isSome then
discard <| waiter.race (return false) (fun promise => do
promise.resolve (.ok ())
return true)

View File

@@ -0,0 +1,163 @@
import Std.Internal.Async
import Std.Sync
open Std.Internal.IO Async
-- Test basic cancellation with default reason
def testBasicCancellationWithReason : Async Unit := do
let token Std.CancellationToken.new
assert! not ( token.isCancelled)
token.cancel
assert! ( token.isCancelled)
let reason token.getCancellationReason
assert! reason == some .cancel
#eval testBasicCancellationWithReason.block
-- Test cancellation with deadline reason
def testDeadlineReason : Async Unit := do
let token Std.CancellationToken.new
assert! not ( token.isCancelled)
token.cancel .deadline
assert! ( token.isCancelled)
let reason token.getCancellationReason
assert! reason == some .deadline
#eval testDeadlineReason.block
-- Test cancellation with shutdown reason
def testShutdownReason : Async Unit := do
let token Std.CancellationToken.new
token.cancel .shutdown
let reason token.getCancellationReason
assert! reason == some .shutdown
#eval testShutdownReason.block
-- Test cancellation with custom reason
def testCustomReason : Async Unit := do
let token Std.CancellationToken.new
token.cancel (.custom "connection timeout")
let reason token.getCancellationReason
assert! reason == some (.custom "connection timeout")
#eval testCustomReason.block
-- Test that uncancelled token has no reason
def testUncancelledNoReason : Async Unit := do
let token Std.CancellationToken.new
let reason token.getCancellationReason
assert! reason == none
#eval testUncancelledNoReason.block
-- Test context cancellation with reason
def testContextCancellation : Async Unit := do
let ctx Std.CancellationContext.new
assert! not ( ctx.isCancelled)
ctx.cancel .shutdown
assert! ( ctx.isCancelled)
let reason ctx.token.getCancellationReason
assert! reason == some .shutdown
#eval testContextCancellation.block
-- Test context tree with different reasons
def testContextTreeReasons : Async Unit := do
let root Std.CancellationContext.new
let child1 root.fork
let child2 root.fork
let grandchild child1.fork
-- Cancel root with shutdown reason
root.cancel .shutdown
-- All should be cancelled
assert! ( root.isCancelled)
assert! ( child1.isCancelled)
assert! ( child2.isCancelled)
assert! ( grandchild.isCancelled)
-- All should have the shutdown reason (propagated from root)
assert! ( root.token.getCancellationReason) == some .shutdown
assert! ( child1.token.getCancellationReason) == some .shutdown
assert! ( child2.token.getCancellationReason) == some .shutdown
assert! ( grandchild.token.getCancellationReason) == some .shutdown
#eval testContextTreeReasons.block
-- Test child cancellation doesn't affect parent
def testChildCancellationIndependent : Async Unit := do
let root Std.CancellationContext.new
let child root.fork
-- Cancel child with deadline
child.cancel .deadline
-- Child should be cancelled with deadline reason
assert! ( child.isCancelled)
assert! ( child.token.getCancellationReason) == some .deadline
-- Parent should still be active
assert! not ( root.isCancelled)
assert! ( root.token.getCancellationReason) == none
#eval testChildCancellationIndependent.block
-- Test selector with reason
def testSelectorWithReason : Async Unit := do
let token Std.CancellationToken.new
let completed Std.Mutex.new false
let reasonRef Std.Mutex.new none
let task async do
Selectable.one #[.case token.selector (fun _ => pure ())]
completed.atomically (set true)
reasonRef.atomically (set ( token.getCancellationReason))
assert! not ( completed.atomically get)
token.cancel .deadline
await task
assert! ( completed.atomically get)
assert! ( reasonRef.atomically get) == some Std.CancellationReason.deadline
#eval testSelectorWithReason.block
-- Test wait with reason
def testWaitWithReason : Async Unit := do
let token Std.CancellationToken.new
let task async do
let _ await ( token.wait)
token.getCancellationReason
Async.sleep 10
token.cancel (.custom "test reason")
let reason await task
assert! reason == some (.custom "test reason")
#eval testWaitWithReason.block
-- Test multiple cancellations (first one wins)
def testMultipleCancellations : Async Unit := do
let token Std.CancellationToken.new
token.cancel .deadline
token.cancel .shutdown -- This should be ignored
let reason token.getCancellationReason
assert! reason == some .deadline -- First reason should persist
#eval testMultipleCancellations.block

View File

@@ -0,0 +1,251 @@
import Std.Internal.Async
import Std.Sync
open Std.Internal.IO Async
/-- Test basic tree cancellation -/
partial def testCancelTree : IO Unit := do
let mutex Std.Mutex.new 0
let context Std.CancellationContext.new
Async.block do
let rec loop (x : Nat) (parent : Std.CancellationContext) : Async Unit := do
match x with
| 0 => do
await ( parent.done)
mutex.atomically (modify (· + 1))
| n + 1 => do
background (loop n ( parent.fork))
background (loop n ( parent.fork))
await ( parent.done)
mutex.atomically (modify (· + 1))
background (loop 3 context)
Async.sleep 500
context.cancel .cancel
Async.sleep 1000
assert! ( context.countAliveTokens) == 0
let size mutex.atomically get
IO.println s!"cancelled {size}"
/--
info: cancelled 15
-/
#guard_msgs in
#eval testCancelTree
/-- Test cancellation with different reasons -/
def testCancellationReasons : IO Unit := do
let ctx Std.CancellationContext.new
let (reason1, reason2, reason3, reason4) Async.block do
-- Test with .cancel reason
let ctx1 ctx.fork
ctx1.cancel .cancel
let some reason1 ctx1.getCancellationReason | return (none, none, none, none)
-- Test with .deadline reason
let ctx2 ctx.fork
ctx2.cancel .deadline
let some reason2 ctx2.getCancellationReason | return (none, none, none, none)
-- Test with .shutdown reason
let ctx3 ctx.fork
ctx3.cancel .shutdown
let some reason3 ctx3.getCancellationReason | return (none, none, none, none)
-- Test with custom reason
let ctx4 ctx.fork
ctx4.cancel (.custom "test error")
let some reason4 ctx4.getCancellationReason | return (none, none, none, none)
return (some reason1, some reason2, some reason3, some reason4)
if let some r1 := reason1 then IO.println s!"Reason 1: {r1}"
if let some r2 := reason2 then IO.println s!"Reason 2: {r2}"
if let some r3 := reason3 then IO.println s!"Reason 3: {r3}"
if let some r4 := reason4 then IO.println s!"Reason 4: {r4}"
assert! ( ctx.countAliveTokens) == 1
/--
info: Reason 1: cancel
Reason 2: deadline
Reason 3: shutdown
Reason 4: custom("test error")
-/
#guard_msgs in
#eval testCancellationReasons
/-- Test cancellation propagates reason to children -/
def testReasonPropagation : IO Unit := do
let (parentReason, child1Reason, child2Reason, grandchildReason) Async.block do
let parent Std.CancellationContext.new
let child1 parent.fork
let child2 parent.fork
let grandchild child1.fork
parent.cancel (.custom "parent cancelled")
Async.sleep 100
let some parentReason parent.getCancellationReason | return (none, none, none, none)
let some child1Reason child1.getCancellationReason | return (none, none, none, none)
let some child2Reason child2.getCancellationReason | return (none, none, none, none)
let some grandchildReason grandchild.getCancellationReason | return (none, none, none, none)
return (some parentReason, some child1Reason, some child2Reason, some grandchildReason)
if let some r := parentReason then IO.println s!"Parent: {r}"
if let some r := child1Reason then IO.println s!"Child1: {r}"
if let some r := child2Reason then IO.println s!"Child2: {r}"
if let some r := grandchildReason then IO.println s!"Grandchild: {r}"
/--
info: Parent: custom("parent cancelled")
Child1: custom("parent cancelled")
Child2: custom("parent cancelled")
Grandchild: custom("parent cancelled")
-/
#guard_msgs in
#eval testReasonPropagation
/-- Test cancellation in the middle of work -/
def testCancelInMiddle : IO Unit := do
let counter Std.Mutex.new 0
let cancelledCounter Std.Mutex.new 0
let (finalCount, cancelledCount) Async.block do
let context Std.CancellationContext.new
-- Worker that does work until cancelled
let worker (ctx : Std.CancellationContext) : Async Unit := do
for _ in [0:100] do
if ctx.isCancelled then
cancelledCounter.atomically (modify (· + 1))
break
counter.atomically (modify (· + 1))
Async.sleep 10
-- Start 5 workers
for _ in [0:5] do
background (worker context)
-- Let them run for a bit, then cancel
Async.sleep 200
context.cancel .deadline
-- Wait for them to finish
Async.sleep 500
let finalCount counter.atomically get
let cancelledCount cancelledCounter.atomically get
return (finalCount, cancelledCount)
IO.println s!"Completed {finalCount} iterations before cancellation"
IO.println s!"{cancelledCount} workers detected cancellation"
/-- Test cancellation before forking -/
def testCancelBeforeFork : IO Unit := do
let (isSame, isChildCancelled) Async.block do
let ctx Std.CancellationContext.new
ctx.cancel .cancel
-- Fork after cancellation should return same context
let child ctx.fork
let isSame := ctx.id == child.id
let isChildCancelled child.isCancelled
return (isSame, isChildCancelled)
IO.println s!"Same context: {isSame}, Child cancelled: {isChildCancelled}"
/--
info: Same context: true, Child cancelled: true
-/
#guard_msgs in
#eval testCancelBeforeFork
/-- Test deep tree cancellation with reason -/
partial def testDeepTreeCancellation : IO Unit := do
let depths Std.Mutex.new ([] : List (Nat × Std.CancellationReason))
let (count, allSameReason) Async.block do
let root Std.CancellationContext.new
let rec makeTree (depth : Nat) (ctx : Std.CancellationContext) : Async Unit := do
if depth == 0 then
await ( ctx.done)
if let some reason ctx.getCancellationReason then
depths.atomically (modify (·.cons (depth, reason)))
else
let child1 ctx.fork
let child2 ctx.fork
background (makeTree (depth - 1) child1)
background (makeTree (depth - 1) child2)
await ( ctx.done)
if let some reason ctx.getCancellationReason then
depths.atomically (modify (·.cons (depth, reason)))
background (makeTree 4 root)
Async.sleep 200
root.cancel (.custom "deep tree cancel")
Async.sleep 500
let results depths.atomically get
let count := results.length
let allSameReason := results.all fun (_, r) => r == .custom "deep tree cancel"
return (count, allSameReason)
IO.println s!"Cancelled {count} nodes, all with same reason: {allSameReason}"
/--
info: Cancelled 31 nodes, all with same reason: true
-/
#guard_msgs in
#eval testDeepTreeCancellation
/-- Test counting alive tokens -/
def testCountAliveTokens : IO Unit := do
let (count0, count1, count2, count3, count4) Async.block do
let root Std.CancellationContext.new
let count0 root.countAliveTokens -- Root only
-- Fork 3 children
let child1 root.fork
let child2 root.fork
let _child3 root.fork
let count1 root.countAliveTokens -- Root + 3 children = 4
-- Cancel one child (and its subtree)
child1.cancel .cancel
Async.sleep 100
let count2 root.countAliveTokens -- Root + 2 children = 3
-- Fork a grandchild from child2
let _grandchild child2.fork
let count3 root.countAliveTokens -- Root + 2 children + 1 grandchild = 4
-- Cancel root (should cancel everything)
root.cancel .cancel
Async.sleep 100
let count4 root.countAliveTokens -- All cancelled = 0
return (count0, count1, count2, count3, count4)
IO.println s!"Initial (root only): {count0}"
IO.println s!"After forking 3 children: {count1}"
IO.println s!"After cancelling 1 child: {count2}"
IO.println s!"After forking grandchild: {count3}"
IO.println s!"After cancelling root: {count4}"
/--
info: Initial (root only): 1
After forking 3 children: 4
After cancelling 1 child: 3
After forking grandchild: 4
After cancelling root: 0
-/
#guard_msgs in
#eval testCountAliveTokens

View File

@@ -0,0 +1,698 @@
import Std.Internal.Async
import Std.Sync
open Std.Internal.IO Async
/-- Test ContextAsync cancellation check -/
def testIsCancelled : IO Unit := do
let (before, after) Async.block do
ContextAsync.run do
let before ContextAsync.isCancelled
ContextAsync.cancel .cancel
Async.sleep 50
let after ContextAsync.isCancelled
return (before, after)
IO.println s!"Before: {before}, After: {after}"
/--
info: Before: false, After: true
-/
#guard_msgs in
#eval testIsCancelled
/-- Test ContextAsync cancellation reason -/
def testGetCancellationReason : IO Unit := do
let res Async.block do
ContextAsync.run do
ContextAsync.cancel (.custom "test reason")
Async.sleep 50
let some reason ContextAsync.getCancellationReason
| return "ERROR: No reason found"
return s!"Reason: {reason}"
IO.println res
/--
info: Reason: custom("test reason")
-/
#guard_msgs in
#eval testGetCancellationReason
/-- Test awaitCancellation -/
def testAwaitCancellation : IO Unit := do
let received Std.Mutex.new false
Async.block do
let started Std.Mutex.new false
ContextAsync.run do
discard <| ContextAsync.concurrently
(do
started.atomically (set true)
ContextAsync.awaitCancellation
received.atomically (set true))
(do
-- Wait for task to start
while !( started.atomically get) do
Async.sleep 10
Async.sleep 100
ContextAsync.cancel .shutdown)
Async.sleep 200
let _ received.atomically get
IO.println "Cancellation received"
def testSelectorCancellationFail : IO Unit := do
let received Std.Mutex.new false
let result Async.block do
let ctx Std.CancellationContext.new
let started Std.Mutex.new false
let result do
try
ContextAsync.runIn ctx do
discard <| ContextAsync.concurrently
(do
started.atomically (set true)
let res Selectable.one #[
.case ( ContextAsync.doneSelector) (fun _ => pure true),
.case ( Selector.sleep 2000) (fun _ => pure false)
]
received.atomically (set res))
(do
throw (.userError "failed")
return ())
return Except.ok ()
catch err =>
return Except.error err
Async.sleep 500
return result
let _ received.atomically get
IO.println "Cancellation received"
if let Except.error err := result then
throw err
/--
info: Cancellation received
---
error: failed
-/
#guard_msgs in
#eval testSelectorCancellationFail
/-- Test concurrently with both tasks succeeding -/
def testConcurrently : IO Unit := do
let (a, b) Async.block do
ContextAsync.run do
ContextAsync.concurrently
(do
Async.sleep 100
return 42)
(do
Async.sleep 150
return "hello")
IO.println s!"Results: {a}, {b}"
/--
info: Results: 42, hello
-/
#guard_msgs in
#eval testConcurrently
/-- Test race with first task winning -/
def testRace : IO Unit := do
let result Async.block do
ContextAsync.run do
ContextAsync.race
(do
Async.sleep 50
return "fast")
(do
Async.sleep 200
return "slow")
IO.println s!"Winner: {result}"
/--
info: Winner: fast
-/
#guard_msgs in
#eval testRace
/-- Test concurrentlyAll -/
def testConcurrentlyAll : IO Unit := do
let results Async.block do
ContextAsync.run do
let tasks := #[
(do Async.sleep 50; return 1),
(do Async.sleep 100; return 2),
(do Async.sleep 75; return 3)
]
ContextAsync.concurrentlyAll tasks
IO.println s!"All results: {results}"
/--
info: All results: #[1, 2, 3]
-/
#guard_msgs in
#eval testConcurrentlyAll
/-- Test background task with cancellation -/
def testBackground : IO Unit := do
let counter Std.Mutex.new 0
Async.block do
ContextAsync.run do
discard <| ContextAsync.concurrently
(do
for _ in [0:10] do
if ContextAsync.isCancelled then
break
counter.atomically (modify (· + 1))
Async.sleep 50)
(do
-- Let it run for a bit
Async.sleep 150
ContextAsync.cancel .cancel)
Async.sleep 200
let final counter.atomically get
IO.println s!"Counter reached: {final}"
/-- Test fork cancellation isolation -/
def testForkCancellation : IO Unit := do
let parent Std.CancellationContext.new
let childCancelled Std.Mutex.new false
let parentCancelled Std.Mutex.new false
Async.block do
ContextAsync.runIn parent do
discard <| ContextAsync.concurrentlyAll #[
(do
let child ContextAsync.getContext
Async.sleep 100
child.cancel .cancel
childCancelled.atomically (set true)),
(do
Async.sleep 200
if parent.isCancelled then
parentCancelled.atomically (set true))
]
let childWasCancelled childCancelled.atomically get
let parentWasCancelled parentCancelled.atomically get
IO.println s!"Child cancelled: {childWasCancelled}, Parent cancelled: {parentWasCancelled}"
/--
info: Child cancelled: true, Parent cancelled: false
-/
#guard_msgs in
#eval testForkCancellation
/-- Test doneSelector -/
partial def testNestedFork : IO Unit := do
let res Async.block do
ContextAsync.run do
let ctx ContextAsync.getContext
let sel ContextAsync.doneSelector
let (_, result) ContextAsync.concurrently
(do
Async.sleep 100
ctx.cancel .deadline)
(Selectable.one #[.case sel (fun _ => pure true)])
return result
IO.println s!"Done selector triggered: {res}"
/--
info: Done selector triggered: true
-/
#guard_msgs in
#eval testNestedFork
/-- Test Selector.cancelled -/
def testSelectorCancelled : IO Unit := do
let res Async.block do
ContextAsync.run do
let ctx ContextAsync.getContext
let sel Selector.cancelled
let (_, result) ContextAsync.concurrently
(do
Async.sleep 150
ctx.cancel .shutdown)
(Selectable.one #[.case sel (fun _ => pure true)])
return result
IO.println s!"Selector.cancelled triggered: {res}"
/--
info: Selector.cancelled triggered: true
-/
#guard_msgs in
#eval testSelectorCancelled
/-- Test MonadLift instances -/
def testMonadLift : IO Unit := do
let (msg1, msg2) Async.block do
ContextAsync.run do
-- Lift from IO
let msg1 : String := "From IO"
-- Lift from BaseIO
let msg2 : String := "From BaseIO"
-- Lift from Async
let _ (Async.sleep 50 : Async Unit)
return (msg1, msg2)
IO.println msg1
IO.println msg2
IO.println "All lifts work"
/--
info: From IO
From BaseIO
All lifts work
-/
#guard_msgs in
#eval testMonadLift
/-- Test exception handling in ContextAsync -/
def testExceptionHandling : IO Unit := do
let res Async.block do
ContextAsync.run do
try
throw (IO.userError "test error")
return "Should not reach here"
catch e =>
return s!"Caught: {e}"
IO.println res
/--
info: Caught: test error
-/
#guard_msgs in
#eval testExceptionHandling
/-- Test tryFinally in ContextAsync -/
def testTryFinally : IO Unit := do
let cleaned Std.Mutex.new false
Async.block do
ContextAsync.run do
try
ContextAsync.cancel .cancel
ContextAsync.awaitCancellation
finally
cleaned.atomically (set true)
let wasCleanedUp cleaned.atomically get
IO.println s!"Cleanup ran: {wasCleanedUp}"
/--
info: Cleanup ran: true
-/
#guard_msgs in
#eval testTryFinally
/-- Test race with cancellation -/
def testRaceWithCancellation : IO Unit := do
let ctx Std.CancellationContext.new
let leftCancelled Std.Mutex.new false
let rightCancelled Std.Mutex.new false
Async.block do
ContextAsync.runIn ctx do
let _ ContextAsync.race
(do
try
Async.sleep 500
return "left"
finally
if ContextAsync.isCancelled then
leftCancelled.atomically (set true))
(do
try
Async.sleep 50
return "right"
finally
if ContextAsync.isCancelled then
rightCancelled.atomically (set true))
Async.sleep 1000
let left leftCancelled.atomically get
let right rightCancelled.atomically get
IO.println s!"Left cancelled: {left}, Right cancelled: {right}"
/--
info: Left cancelled: true, Right cancelled: false
-/
#guard_msgs in
#eval testRaceWithCancellation
/-- Test complex concurrent workflow -/
def testComplexWorkflow : IO Unit := do
let results Std.Mutex.new ([] : List String)
Async.block do
ContextAsync.run do
-- Run multiple concurrent operations
let (a, b) ContextAsync.concurrently
(do
Async.sleep 50
results.atomically (modify ("A"::·))
return 1)
(do
Async.sleep 75
results.atomically (modify ("B"::·))
return 2)
-- Additional concurrent task
discard <| ContextAsync.concurrently
(do
Async.sleep 100
results.atomically (modify ("BG"::·)))
(do
Async.sleep 200
results.atomically (modify (s!"Sum:{a+b}"::·)))
let final results.atomically get
IO.println s!"Results: {final.reverse}"
/--
info: Results: [A, B, BG, Sum:3]
-/
#guard_msgs in
#eval testComplexWorkflow
def testConcurrentlyAllException : IO Unit := do
let ref IO.mkRef ""
try
Async.block do
ContextAsync.run do
let tasks := #[
(do
Async.sleep 1000
if ContextAsync.isCancelled then
ref.set "cancelled"
return
else
ref.set "not cancelled"
Async.sleep 500
if ContextAsync.isCancelled then
ref.modify (· ++ ", cancelled")
else
ref.modify (· ++ ", not cancelled")),
(do
Async.sleep 250
throw (IO.userError "Error: Hello"))
]
discard <| ContextAsync.concurrentlyAll tasks
finally
IO.println ( ref.get)
/--
info: cancelled
---
error: Error: Hello
-/
#guard_msgs in
#eval testConcurrentlyAllException
/-- Test that tasks in ContextAsync.run are not cancelled when run completes -/
def test0 : IO Unit := do
let ref IO.mkRef false
Async.block do
ContextAsync.run do
Async.sleep 100
if ContextAsync.isCancelled then
ref.set true
IO.sleep 200
IO.println s!"{← ref.get}"
/--
info: false
-/
#guard_msgs in
#eval test0
/-- Test that background tasks are cancelled when ContextAsync.run completes -/
def test1 : IO Unit := do
let ref IO.mkRef false
Async.block do
ContextAsync.run do
ContextAsync.background do
Async.sleep 100
if ContextAsync.isCancelled then
ref.set true
IO.sleep 200
IO.println s!"{← ref.get}"
/--
info: true
-/
#guard_msgs in
#eval test1
/-- Test that nested background tasks (ContextAsync.background in ContextAsync.background) are cancelled -/
def test2 : IO Unit := do
let ref IO.mkRef false
Async.block do
ContextAsync.run do
ContextAsync.background do
ContextAsync.background do
Async.sleep 100
if ContextAsync.isCancelled then
ref.set true
IO.sleep 200
IO.println s!"{← ref.get}"
/--
info: true
-/
#guard_msgs in
#eval test2
/-- Test that ContextAsync.background in Async.background is cancelled -/
def test2' : IO Unit := do
let ref IO.mkRef false
Async.block do
ContextAsync.run do
Async.background do
ContextAsync.background do
Async.sleep 100
if ContextAsync.isCancelled then
ref.set true
IO.sleep 200
IO.println s!"{← ref.get}"
/--
info: true
-/
#guard_msgs in
#eval test2'
/-- Test that Async.background in ContextAsync.background is cancelled -/
def test2'' : IO Unit := do
let ref IO.mkRef false
Async.block do
ContextAsync.run do
ContextAsync.background do
Async.background do
Async.sleep 100
if ContextAsync.isCancelled then
ref.set true
IO.sleep 200
IO.println s!"{← ref.get}"
/--
info: true
-/
#guard_msgs in
#eval test2''
/-- Test concurrently with first task succeeding immediately, others checking cancellation -/
def testConcurrentlySuccessWithCancellation : IO Unit := do
let task2Cancelled Std.Mutex.new false
let task3Cancelled Std.Mutex.new false
let results Async.block do
ContextAsync.run do
ContextAsync.concurrentlyAll #[
(do
return "first"),
(do
-- Second task waits and checks for cancellation
let res Selectable.one #[
.case ( ContextAsync.doneSelector) (fun _ => pure true),
.case ( Selector.sleep 500) (fun _ => pure false)
]
task2Cancelled.atomically (set (res))
return "second"),
(do
let res Selectable.one #[
.case ( ContextAsync.doneSelector) (fun _ => pure true),
.case ( Selector.sleep 500) (fun _ => pure false)
]
task3Cancelled.atomically (set (res))
return "third")
]
let t2 task2Cancelled.atomically get
let t3 task3Cancelled.atomically get
IO.println s!"Results: {results}"
IO.println s!"Task2 cancelled: {t2}, Task3 cancelled: {t3}"
/--
info: Results: #[first, second, third]
Task2 cancelled: false, Task3 cancelled: false
-/
#guard_msgs in
#eval testConcurrentlySuccessWithCancellation
/-- Test concurrently with first task failing, others checking for cancellation -/
def testConcurrentlyFailWithCancellation : IO Unit := do
let task2Cancelled Std.Mutex.new false
let task3Cancelled Std.Mutex.new false
let results Async.block do
ContextAsync.run do
try
let result ContextAsync.concurrentlyAll #[
(do
-- First task fails immediately
throw (IO.userError "first task failed")),
(do
-- Second task waits and checks for cancellation
let res Selectable.one #[
.case ( ContextAsync.doneSelector) (fun _ => pure true),
.case ( Selector.sleep 2000) (fun _ => pure false)
]
task2Cancelled.atomically (set (res))
return "second"),
(do
let res Selectable.one #[
.case ( ContextAsync.doneSelector) (fun _ => pure true),
.case ( Selector.sleep 2000) (fun _ => pure false)
]
task3Cancelled.atomically (set (res))
return "third")
]
return Except.ok result
catch e =>
Async.sleep 500
return Except.error e
let t2 task2Cancelled.atomically get
let t3 task3Cancelled.atomically get
match results with
| .ok results => IO.println s!"Results: {results}"
| .error e => IO.println s!"Error: {e}"
IO.println s!"Task2 cancelled: {t2}, Task3 cancelled: {t3}"
/--
info: Error: first task failed
Task2 cancelled: true, Task3 cancelled: true
-/
#guard_msgs in
#eval testConcurrentlyFailWithCancellation
/-- Test concurrently with both tasks succeeding, checking cancellation status -/
def testConcurrentlySuccessWithCancellation2Tasks : IO Unit := do
let task2Cancelled Std.Mutex.new false
let (r1, r2) Async.block do
ContextAsync.run do
ContextAsync.concurrently
(do return "first")
(do
-- Second task waits and checks for cancellation
let res Selectable.one #[
.case ( ContextAsync.doneSelector) (fun _ => pure true),
.case ( Selector.sleep 500) (fun _ => pure false)
]
task2Cancelled.atomically (set res)
return "second")
let t2 task2Cancelled.atomically get
IO.println s!"Results: {r1}, {r2}"
IO.println s!"Task2 cancelled: {t2}"
/--
info: Results: first, second
Task2 cancelled: false
-/
#guard_msgs in
#eval testConcurrentlySuccessWithCancellation2Tasks
/-- Test concurrently with first task failing, second task checking for cancellation -/
def testConcurrentlyFailWithCancellation2Tasks : IO Unit := do
let task2Cancelled Std.Mutex.new false
try
Async.block do
ContextAsync.run do
let (_ : (String × String)) ContextAsync.concurrently
(do
-- First task fails immediately
throw (IO.userError "first task failed") : ContextAsync String)
(do
-- Second task waits and checks for cancellation
let res Selectable.one #[
.case ( ContextAsync.doneSelector) (fun _ => pure true),
.case ( Selector.sleep 2000) (fun _ => pure false)
]
task2Cancelled.atomically (set res)
return "second")
catch e =>
IO.sleep 500
let t2 task2Cancelled.atomically get
IO.println s!"Error: {e}"
IO.println s!"Task2 cancelled: {t2}"
/--
info: Error: first task failed
Task2 cancelled: true
-/
#guard_msgs in
#eval testConcurrentlyFailWithCancellation2Tasks