mirror of
https://github.com/leanprover/lean4.git
synced 2026-03-25 14:24:08 +00:00
Compare commits
3 Commits
inferInsta
...
parallel-r
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d6fc6e6b45 | ||
|
|
272f0f5db3 | ||
|
|
1d3fda4130 |
@@ -59,8 +59,38 @@ as tasks complete, unlike `par`/`par'` which restore the initial state after col
|
||||
|
||||
Iterators do not have `Finite` instances, as we cannot prove termination from the available
|
||||
information. For consumers that require `Finite` (like `.toList`), use `.allowNontermination.toList`.
|
||||
|
||||
## Chunking support
|
||||
|
||||
The `par`, `par'`, `parIter`, and `parIterWithCancel` functions support optional chunking to reduce
|
||||
task creation overhead when there are many small jobs. Pass `maxTasks` to limit the number of parallel
|
||||
tasks created; jobs will be grouped into chunks that run sequentially within each task.
|
||||
|
||||
- `maxTasks = 0` (default): No chunking, one task per job (original behavior)
|
||||
- `maxTasks > 0`: Auto-compute chunk size to limit task count
|
||||
- `minChunkSize`: Minimum jobs per chunk (default 1)
|
||||
|
||||
Example: With 1000 jobs and `maxTasks := 128, minChunkSize := 8`, chunk size = 8, creating ~125 tasks.
|
||||
-/
|
||||
|
||||
/-- Split a list into chunks of at most `chunkSize` elements. -/
|
||||
def toChunks {α : Type} (xs : List α) (chunkSize : Nat) : List (List α) :=
|
||||
if h : chunkSize ≤ 1 then xs.map ([·])
|
||||
else go xs [] (Nat.lt_of_not_le h)
|
||||
where
|
||||
go (remaining : List α) (acc : List (List α)) (hc : 1 < chunkSize) : List (List α) :=
|
||||
if _h : remaining.length ≤ chunkSize then
|
||||
(remaining :: acc).reverse
|
||||
else
|
||||
go (remaining.drop chunkSize) (remaining.take chunkSize :: acc) hc
|
||||
termination_by remaining.length
|
||||
decreasing_by simp_wf; omega
|
||||
|
||||
/-- Compute chunk size given job count, max tasks, and minimum chunk size. -/
|
||||
def computeChunkSize (numJobs maxTasks minChunkSize : Nat) : Nat :=
|
||||
if maxTasks = 0 then 1
|
||||
else max minChunkSize ((numJobs + maxTasks - 1) / maxTasks)
|
||||
|
||||
public section
|
||||
|
||||
namespace Std.Iterators
|
||||
@@ -125,6 +155,38 @@ namespace Lean.Core.CoreM
|
||||
|
||||
open Std.Iterators
|
||||
|
||||
/--
|
||||
Internal state for an iterator over chunked tasks for CoreM.
|
||||
Yields individual results while internally managing chunk boundaries.
|
||||
-/
|
||||
private structure ChunkedTaskIterator (α : Type) where
|
||||
chunkTasks : List (Task (CoreM (List (Except Exception α))))
|
||||
currentResults : List (Except Exception α)
|
||||
|
||||
private instance {α : Type} : Iterator (ChunkedTaskIterator α) CoreM (Except Exception α) where
|
||||
IsPlausibleStep _
|
||||
| .yield _ _ => True
|
||||
| .skip _ => True -- Allow skip for empty chunks
|
||||
| .done => True
|
||||
step it := do
|
||||
match it.internalState.currentResults with
|
||||
| r :: rest =>
|
||||
pure <| .deflate ⟨.yield (toIterM { it.internalState with currentResults := rest } CoreM (Except Exception α)) r, trivial⟩
|
||||
| [] =>
|
||||
match it.internalState.chunkTasks with
|
||||
| [] => pure <| .deflate ⟨.done, trivial⟩
|
||||
| task :: rest =>
|
||||
try
|
||||
let chunkResults ← task.get
|
||||
match chunkResults with
|
||||
| [] =>
|
||||
-- Empty chunk, skip to try next
|
||||
pure <| .deflate ⟨.skip (toIterM { chunkTasks := rest, currentResults := [] } CoreM (Except Exception α)), trivial⟩
|
||||
| r :: rs =>
|
||||
pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := rs } CoreM (Except Exception α)) r, trivial⟩
|
||||
catch e =>
|
||||
pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := [] } CoreM (Except Exception α)) (.error e), trivial⟩
|
||||
|
||||
/--
|
||||
Runs a list of CoreM computations in parallel and returns:
|
||||
* a combined cancellation hook for all tasks, and
|
||||
@@ -150,6 +212,29 @@ def parIterWithCancel {α : Type} (jobs : List (CoreM α)) := do
|
||||
pure (Except.error e)
|
||||
return (combinedCancel, iterWithErrors)
|
||||
|
||||
/--
|
||||
Runs a list of CoreM computations in parallel with chunking and returns:
|
||||
* a combined cancellation hook for all tasks, and
|
||||
* an iterator that yields results in original order.
|
||||
|
||||
Unlike `parIterWithCancel`, this groups jobs into chunks to reduce task overhead.
|
||||
Each chunk runs its jobs sequentially, but chunks run in parallel.
|
||||
|
||||
**Parameters:**
|
||||
- `maxTasks`: Maximum number of parallel tasks (chunks). Default 0 means one task per job.
|
||||
- `minChunkSize`: Minimum jobs per chunk. Default 1.
|
||||
-/
|
||||
def parIterWithCancelChunked {α : Type} (jobs : List (CoreM α))
|
||||
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) := do
|
||||
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
|
||||
let chunks := toChunks jobs chunkSize
|
||||
let chunkJobs : List (CoreM (List (Except Exception α))) :=
|
||||
chunks.map fun (chunk : List (CoreM α)) => chunk.mapM (observing ·)
|
||||
let (cancels, tasks) := (← chunkJobs.mapM asTask).unzip
|
||||
let combinedCancel := cancels.forM id
|
||||
let flatIter := toIterM (ChunkedTaskIterator.mk tasks []) CoreM (Except Exception α)
|
||||
return (combinedCancel, flatIter)
|
||||
|
||||
/--
|
||||
Runs a list of CoreM computations in parallel (without cancellation hook).
|
||||
|
||||
@@ -192,19 +277,9 @@ Returns an iterator that yields results in completion order, wrapped in `Except
|
||||
def parIterGreedy {α : Type} (jobs : List (CoreM α)) :=
|
||||
(·.2) <$> parIterGreedyWithCancel jobs
|
||||
|
||||
/--
|
||||
Runs a list of CoreM computations in parallel and collects results in the original order,
|
||||
including the saved state after each task completes.
|
||||
|
||||
Unlike `parIter`, this waits for all tasks to complete and returns results
|
||||
in the same order as the input list, not in completion order.
|
||||
|
||||
Results are wrapped in `Except Exception (α × Core.SavedState)` so that errors in individual
|
||||
tasks don't stop the collection - you can observe all results including which tasks failed.
|
||||
|
||||
The final CoreM state is restored to the initial state (before tasks ran).
|
||||
-/
|
||||
def par {α : Type} (jobs : List (CoreM α)) : CoreM (List (Except Exception (α × Core.SavedState))) := do
|
||||
/-- Internal: run jobs in parallel without chunking, returning state. -/
|
||||
private def parCore {α : Type} (jobs : List (CoreM α)) :
|
||||
CoreM (List (Except Exception (α × Core.SavedState))) := do
|
||||
let initialState ← get
|
||||
let tasks ← jobs.mapM asTask'
|
||||
let mut results := []
|
||||
@@ -218,13 +293,47 @@ def par {α : Type} (jobs : List (CoreM α)) : CoreM (List (Except Exception (α
|
||||
|
||||
/--
|
||||
Runs a list of CoreM computations in parallel and collects results in the original order,
|
||||
discarding state information.
|
||||
including the saved state after each task completes.
|
||||
|
||||
Unlike `par`, this doesn't return state information from tasks.
|
||||
Unlike `parIter`, this waits for all tasks to complete and returns results
|
||||
in the same order as the input list, not in completion order.
|
||||
|
||||
Results are wrapped in `Except Exception (α × Core.SavedState)` so that errors in individual
|
||||
tasks don't stop the collection - you can observe all results including which tasks failed.
|
||||
|
||||
The final CoreM state is restored to the initial state (before tasks ran).
|
||||
|
||||
**Chunking:** Pass `maxTasks > 0` to limit parallel tasks by grouping jobs into chunks.
|
||||
-/
|
||||
def par' {α : Type} (jobs : List (CoreM α)) : CoreM (List (Except Exception α)) := do
|
||||
def par {α : Type} (jobs : List (CoreM α))
|
||||
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) :
|
||||
CoreM (List (Except Exception (α × Core.SavedState))) := do
|
||||
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
|
||||
if chunkSize ≤ 1 then
|
||||
parCore jobs
|
||||
else
|
||||
let initialState ← get
|
||||
let chunks := toChunks jobs chunkSize
|
||||
let chunkJobs := chunks.map fun chunk => do
|
||||
let mut results := []
|
||||
for job in chunk do
|
||||
let r ← observing do
|
||||
let a ← job
|
||||
pure (a, ← saveState)
|
||||
results := r :: results
|
||||
pure results.reverse
|
||||
let chunkResults ← parCore chunkJobs
|
||||
set initialState
|
||||
let mut allResults := []
|
||||
for chunkResult in chunkResults do
|
||||
match chunkResult with
|
||||
| .ok (jobResults, _) => allResults := allResults ++ jobResults
|
||||
| .error e => allResults := allResults ++ [.error e]
|
||||
return allResults
|
||||
|
||||
/-- Internal: run jobs in parallel without chunking, discarding state. -/
|
||||
private def parCore' {α : Type} (jobs : List (CoreM α)) :
|
||||
CoreM (List (Except Exception α)) := do
|
||||
let initialState ← get
|
||||
let tasks ← jobs.mapM asTask'
|
||||
let mut results := []
|
||||
@@ -237,6 +346,40 @@ def par' {α : Type} (jobs : List (CoreM α)) : CoreM (List (Except Exception α
|
||||
set initialState
|
||||
return results.reverse
|
||||
|
||||
/--
|
||||
Runs a list of CoreM computations in parallel and collects results in the original order,
|
||||
discarding state information.
|
||||
|
||||
Unlike `par`, this doesn't return state information from tasks.
|
||||
|
||||
The final CoreM state is restored to the initial state (before tasks ran).
|
||||
|
||||
**Chunking:** Pass `maxTasks > 0` to limit parallel tasks by grouping jobs into chunks.
|
||||
-/
|
||||
def par' {α : Type} (jobs : List (CoreM α))
|
||||
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) :
|
||||
CoreM (List (Except Exception α)) := do
|
||||
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
|
||||
if chunkSize ≤ 1 then
|
||||
parCore' jobs
|
||||
else
|
||||
let initialState ← get
|
||||
let chunks := toChunks jobs chunkSize
|
||||
let chunkJobs := chunks.map fun chunk => do
|
||||
let mut results := []
|
||||
for job in chunk do
|
||||
let r ← observing job
|
||||
results := r :: results
|
||||
pure results.reverse
|
||||
let chunkResults ← parCore' chunkJobs
|
||||
set initialState
|
||||
let mut allResults := []
|
||||
for chunkResult in chunkResults do
|
||||
match chunkResult with
|
||||
| .ok jobResults => allResults := allResults ++ jobResults
|
||||
| .error e => allResults := allResults ++ [.error e]
|
||||
return allResults
|
||||
|
||||
/--
|
||||
Runs a list of CoreM computations in parallel and returns the first successful result
|
||||
(by completion order, not list order).
|
||||
@@ -260,18 +403,43 @@ namespace Lean.Meta.MetaM
|
||||
open Std.Iterators
|
||||
|
||||
/--
|
||||
Runs a list of MetaM computations in parallel and collects results in the original order,
|
||||
including the saved state after each task completes.
|
||||
|
||||
Unlike `parIter`, this waits for all tasks to complete and returns results
|
||||
in the same order as the input list, not in completion order.
|
||||
|
||||
Results are wrapped in `Except Exception (α × Meta.SavedState)` so that errors in individual
|
||||
tasks don't stop the collection - you can observe all results including which tasks failed.
|
||||
|
||||
The final MetaM state is restored to the initial state (before tasks ran).
|
||||
Internal state for an iterator over chunked tasks for MetaM.
|
||||
Yields individual results while internally managing chunk boundaries.
|
||||
-/
|
||||
def par {α : Type} (jobs : List (MetaM α)) : MetaM (List (Except Exception (α × Meta.SavedState))) := do
|
||||
structure ChunkedTaskIterator (α : Type) where
|
||||
chunkTasks : List (Task (MetaM (List (Except Exception α))))
|
||||
currentResults : List (Except Exception α)
|
||||
|
||||
instance {α : Type} : Iterator (ChunkedTaskIterator α) MetaM (Except Exception α) where
|
||||
IsPlausibleStep _
|
||||
| .yield _ _ => True
|
||||
| .skip _ => True
|
||||
| .done => True
|
||||
step it := do
|
||||
match it.internalState.currentResults with
|
||||
| r :: rest =>
|
||||
pure <| .deflate ⟨.yield (toIterM { it.internalState with currentResults := rest } MetaM (Except Exception α)) r, trivial⟩
|
||||
| [] =>
|
||||
match it.internalState.chunkTasks with
|
||||
| [] => pure <| .deflate ⟨.done, trivial⟩
|
||||
| task :: rest =>
|
||||
try
|
||||
let chunkResults ← task.get
|
||||
match chunkResults with
|
||||
| [] =>
|
||||
pure <| .deflate ⟨.skip (toIterM { chunkTasks := rest, currentResults := [] } MetaM (Except Exception α)), trivial⟩
|
||||
| r :: rs =>
|
||||
pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := rs } MetaM (Except Exception α)) r, trivial⟩
|
||||
catch e =>
|
||||
pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := [] } MetaM (Except Exception α)) (.error e), trivial⟩
|
||||
|
||||
instance {α : Type} {n : Type → Type u} [Monad n] [MonadLiftT MetaM n] :
|
||||
IteratorLoopPartial (ChunkedTaskIterator α) MetaM n :=
|
||||
.defaultImplementation
|
||||
|
||||
/-- Internal: run jobs in parallel without chunking, returning state. -/
|
||||
private def parCore {α : Type} (jobs : List (MetaM α)) :
|
||||
MetaM (List (Except Exception (α × Meta.SavedState))) := do
|
||||
let initialState ← get
|
||||
let tasks ← jobs.mapM asTask'
|
||||
let mut results := []
|
||||
@@ -283,15 +451,9 @@ def par {α : Type} (jobs : List (MetaM α)) : MetaM (List (Except Exception (α
|
||||
set initialState
|
||||
return results.reverse
|
||||
|
||||
/--
|
||||
Runs a list of MetaM computations in parallel and collects results in the original order,
|
||||
discarding state information.
|
||||
|
||||
Unlike `par`, this doesn't return state information from tasks.
|
||||
|
||||
The final MetaM state is restored to the initial state (before tasks ran).
|
||||
-/
|
||||
def par' {α : Type} (jobs : List (MetaM α)) : MetaM (List (Except Exception α)) := do
|
||||
/-- Internal: run jobs in parallel without chunking, discarding state. -/
|
||||
private def parCore' {α : Type} (jobs : List (MetaM α)) :
|
||||
MetaM (List (Except Exception α)) := do
|
||||
let initialState ← get
|
||||
let tasks ← jobs.mapM asTask'
|
||||
let mut results := []
|
||||
@@ -304,6 +466,80 @@ def par' {α : Type} (jobs : List (MetaM α)) : MetaM (List (Except Exception α
|
||||
set initialState
|
||||
return results.reverse
|
||||
|
||||
/--
|
||||
Runs a list of MetaM computations in parallel and collects results in the original order,
|
||||
including the saved state after each task completes.
|
||||
|
||||
Unlike `parIter`, this waits for all tasks to complete and returns results
|
||||
in the same order as the input list, not in completion order.
|
||||
|
||||
Results are wrapped in `Except Exception (α × Meta.SavedState)` so that errors in individual
|
||||
tasks don't stop the collection - you can observe all results including which tasks failed.
|
||||
|
||||
The final MetaM state is restored to the initial state (before tasks ran).
|
||||
|
||||
**Chunking:** Pass `maxTasks > 0` to limit parallel tasks by grouping jobs into chunks.
|
||||
-/
|
||||
def par {α : Type} (jobs : List (MetaM α))
|
||||
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) :
|
||||
MetaM (List (Except Exception (α × Meta.SavedState))) := do
|
||||
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
|
||||
if chunkSize ≤ 1 then
|
||||
parCore jobs
|
||||
else
|
||||
let initialState ← get
|
||||
let chunks := toChunks jobs chunkSize
|
||||
let chunkJobs := chunks.map fun chunk => do
|
||||
let mut results := []
|
||||
for job in chunk do
|
||||
let r ← observing do
|
||||
let a ← job
|
||||
pure (a, ← saveState)
|
||||
results := r :: results
|
||||
pure results.reverse
|
||||
let chunkResults ← parCore chunkJobs
|
||||
set initialState
|
||||
let mut allResults := []
|
||||
for chunkResult in chunkResults do
|
||||
match chunkResult with
|
||||
| .ok (jobResults, _) => allResults := allResults ++ jobResults
|
||||
| .error e => allResults := allResults ++ [.error e]
|
||||
return allResults
|
||||
|
||||
/--
|
||||
Runs a list of MetaM computations in parallel and collects results in the original order,
|
||||
discarding state information.
|
||||
|
||||
Unlike `par`, this doesn't return state information from tasks.
|
||||
|
||||
The final MetaM state is restored to the initial state (before tasks ran).
|
||||
|
||||
**Chunking:** Pass `maxTasks > 0` to limit parallel tasks by grouping jobs into chunks.
|
||||
-/
|
||||
def par' {α : Type} (jobs : List (MetaM α))
|
||||
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) :
|
||||
MetaM (List (Except Exception α)) := do
|
||||
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
|
||||
if chunkSize ≤ 1 then
|
||||
parCore' jobs
|
||||
else
|
||||
let initialState ← get
|
||||
let chunks := toChunks jobs chunkSize
|
||||
let chunkJobs := chunks.map fun chunk => do
|
||||
let mut results := []
|
||||
for job in chunk do
|
||||
let r ← observing job
|
||||
results := r :: results
|
||||
pure results.reverse
|
||||
let chunkResults ← parCore' chunkJobs
|
||||
set initialState
|
||||
let mut allResults := []
|
||||
for chunkResult in chunkResults do
|
||||
match chunkResult with
|
||||
| .ok jobResults => allResults := allResults ++ jobResults
|
||||
| .error e => allResults := allResults ++ [.error e]
|
||||
return allResults
|
||||
|
||||
/--
|
||||
Runs a list of MetaM computations in parallel and returns:
|
||||
* a combined cancellation hook for all tasks, and
|
||||
@@ -321,7 +557,6 @@ The iterator will terminate after all jobs complete (assuming they all do comple
|
||||
def parIterWithCancel {α : Type} (jobs : List (MetaM α)) := do
|
||||
let (cancels, tasks) := (← jobs.mapM asTask).unzip
|
||||
let combinedCancel := cancels.forM id
|
||||
-- Create iterator that processes tasks sequentially
|
||||
let iterWithErrors := tasks.iter.mapM fun (task : Task (MetaM α)) => do
|
||||
try
|
||||
let result ← task.get
|
||||
@@ -330,6 +565,29 @@ def parIterWithCancel {α : Type} (jobs : List (MetaM α)) := do
|
||||
pure (Except.error e)
|
||||
return (combinedCancel, iterWithErrors)
|
||||
|
||||
/--
|
||||
Runs a list of MetaM computations in parallel with chunking and returns:
|
||||
* a combined cancellation hook for all tasks, and
|
||||
* an iterator that yields results in original order.
|
||||
|
||||
Unlike `parIterWithCancel`, this groups jobs into chunks to reduce task overhead.
|
||||
Each chunk runs its jobs sequentially, but chunks run in parallel.
|
||||
|
||||
**Parameters:**
|
||||
- `maxTasks`: Maximum number of parallel tasks (chunks). Default 0 means one task per job.
|
||||
- `minChunkSize`: Minimum jobs per chunk. Default 1.
|
||||
-/
|
||||
def parIterWithCancelChunked {α : Type} (jobs : List (MetaM α))
|
||||
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) := do
|
||||
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
|
||||
let chunks := toChunks jobs chunkSize
|
||||
let chunkJobs : List (MetaM (List (Except Exception α))) :=
|
||||
chunks.map fun (chunk : List (MetaM α)) => chunk.mapM (observing ·)
|
||||
let (cancels, tasks) := (← chunkJobs.mapM asTask).unzip
|
||||
let combinedCancel := cancels.forM id
|
||||
let flatIter := toIterM (ChunkedTaskIterator.mk tasks []) MetaM (Except Exception α)
|
||||
return (combinedCancel, flatIter)
|
||||
|
||||
/--
|
||||
Runs a list of MetaM computations in parallel (without cancellation hook).
|
||||
|
||||
@@ -394,6 +652,37 @@ namespace Lean.Elab.Term.TermElabM
|
||||
|
||||
open Std.Iterators
|
||||
|
||||
/--
|
||||
Internal state for an iterator over chunked tasks for TermElabM.
|
||||
Yields individual results while internally managing chunk boundaries.
|
||||
-/
|
||||
private structure ChunkedTaskIterator (α : Type) where
|
||||
chunkTasks : List (Task (TermElabM (List (Except Exception α))))
|
||||
currentResults : List (Except Exception α)
|
||||
|
||||
private instance {α : Type} : Iterator (ChunkedTaskIterator α) TermElabM (Except Exception α) where
|
||||
IsPlausibleStep _
|
||||
| .yield _ _ => True
|
||||
| .skip _ => True
|
||||
| .done => True
|
||||
step it := do
|
||||
match it.internalState.currentResults with
|
||||
| r :: rest =>
|
||||
pure <| .deflate ⟨.yield (toIterM { it.internalState with currentResults := rest } TermElabM (Except Exception α)) r, trivial⟩
|
||||
| [] =>
|
||||
match it.internalState.chunkTasks with
|
||||
| [] => pure <| .deflate ⟨.done, trivial⟩
|
||||
| task :: rest =>
|
||||
try
|
||||
let chunkResults ← task.get
|
||||
match chunkResults with
|
||||
| [] =>
|
||||
pure <| .deflate ⟨.skip (toIterM { chunkTasks := rest, currentResults := [] } TermElabM (Except Exception α)), trivial⟩
|
||||
| r :: rs =>
|
||||
pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := rs } TermElabM (Except Exception α)) r, trivial⟩
|
||||
catch e =>
|
||||
pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := [] } TermElabM (Except Exception α)) (.error e), trivial⟩
|
||||
|
||||
/--
|
||||
Runs a list of TermElabM computations in parallel and returns:
|
||||
* a combined cancellation hook for all tasks, and
|
||||
@@ -411,7 +700,6 @@ The iterator will terminate after all jobs complete (assuming they all do comple
|
||||
def parIterWithCancel {α : Type} (jobs : List (TermElabM α)) := do
|
||||
let (cancels, tasks) := (← jobs.mapM asTask).unzip
|
||||
let combinedCancel := cancels.forM id
|
||||
-- Create iterator that processes tasks sequentially
|
||||
let iterWithErrors := tasks.iter.mapM fun (task : Task (TermElabM α)) => do
|
||||
try
|
||||
let result ← task.get
|
||||
@@ -420,6 +708,34 @@ def parIterWithCancel {α : Type} (jobs : List (TermElabM α)) := do
|
||||
pure (Except.error e)
|
||||
return (combinedCancel, iterWithErrors)
|
||||
|
||||
/--
|
||||
Runs a list of TermElabM computations in parallel with chunking and returns:
|
||||
* a combined cancellation hook for all tasks, and
|
||||
* an iterator that yields results in original order.
|
||||
|
||||
Unlike `parIterWithCancel`, this groups jobs into chunks to reduce task overhead.
|
||||
Each chunk runs its jobs sequentially, but chunks run in parallel.
|
||||
|
||||
**Parameters:**
|
||||
- `maxTasks`: Maximum number of parallel tasks (chunks). Default 0 means one task per job.
|
||||
- `minChunkSize`: Minimum jobs per chunk. Default 1.
|
||||
-/
|
||||
def parIterWithCancelChunked {α : Type} (jobs : List (TermElabM α))
|
||||
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) := do
|
||||
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
|
||||
let chunks := toChunks jobs chunkSize
|
||||
let chunkJobs : List (TermElabM (List (Except Exception α))) :=
|
||||
chunks.map fun (chunk : List (TermElabM α)) => chunk.mapM fun job => do
|
||||
try
|
||||
let a ← job
|
||||
pure (.ok a)
|
||||
catch e =>
|
||||
pure (.error e)
|
||||
let (cancels, tasks) := (← chunkJobs.mapM asTask).unzip
|
||||
let combinedCancel := cancels.forM id
|
||||
let flatIter := toIterM (ChunkedTaskIterator.mk tasks []) TermElabM (Except Exception α)
|
||||
return (combinedCancel, flatIter)
|
||||
|
||||
/--
|
||||
Runs a list of TermElabM computations in parallel (without cancellation hook).
|
||||
|
||||
@@ -462,19 +778,9 @@ Returns an iterator that yields results in completion order, wrapped in `Except
|
||||
def parIterGreedy {α : Type} (jobs : List (TermElabM α)) :=
|
||||
(·.2) <$> parIterGreedyWithCancel jobs
|
||||
|
||||
/--
|
||||
Runs a list of TermElabM computations in parallel and collects results in the original order,
|
||||
including the saved state after each task completes.
|
||||
|
||||
Unlike `parIter`, this waits for all tasks to complete and returns results
|
||||
in the same order as the input list, not in completion order.
|
||||
|
||||
Results are wrapped in `Except Exception (α × Term.SavedState)` so that errors in individual
|
||||
tasks don't stop the collection - you can observe all results including which tasks failed.
|
||||
|
||||
The final TermElabM state is restored to the initial state (before tasks ran).
|
||||
-/
|
||||
def par {α : Type} (jobs : List (TermElabM α)) : TermElabM (List (Except Exception (α × Term.SavedState))) := do
|
||||
/-- Internal: run jobs in parallel without chunking, returning state. -/
|
||||
private def parCore {α : Type} (jobs : List (TermElabM α)) :
|
||||
TermElabM (List (Except Exception (α × Term.SavedState))) := do
|
||||
let initialState ← get
|
||||
let tasks ← jobs.mapM asTask'
|
||||
let mut results := []
|
||||
@@ -488,15 +794,9 @@ def par {α : Type} (jobs : List (TermElabM α)) : TermElabM (List (Except Excep
|
||||
set initialState
|
||||
return results.reverse
|
||||
|
||||
/--
|
||||
Runs a list of TermElabM computations in parallel and collects results in the original order,
|
||||
discarding state information.
|
||||
|
||||
Unlike `par`, this doesn't return state information from tasks.
|
||||
|
||||
The final TermElabM state is restored to the initial state (before tasks ran).
|
||||
-/
|
||||
def par' {α : Type} (jobs : List (TermElabM α)) : TermElabM (List (Except Exception α)) := do
|
||||
/-- Internal: run jobs in parallel without chunking, discarding state. -/
|
||||
private def parCore' {α : Type} (jobs : List (TermElabM α)) :
|
||||
TermElabM (List (Except Exception α)) := do
|
||||
let initialState ← get
|
||||
let tasks ← jobs.mapM asTask'
|
||||
let mut results := []
|
||||
@@ -509,6 +809,86 @@ def par' {α : Type} (jobs : List (TermElabM α)) : TermElabM (List (Except Exce
|
||||
set initialState
|
||||
return results.reverse
|
||||
|
||||
/--
|
||||
Runs a list of TermElabM computations in parallel and collects results in the original order,
|
||||
including the saved state after each task completes.
|
||||
|
||||
Unlike `parIter`, this waits for all tasks to complete and returns results
|
||||
in the same order as the input list, not in completion order.
|
||||
|
||||
Results are wrapped in `Except Exception (α × Term.SavedState)` so that errors in individual
|
||||
tasks don't stop the collection - you can observe all results including which tasks failed.
|
||||
|
||||
The final TermElabM state is restored to the initial state (before tasks ran).
|
||||
|
||||
**Chunking:** Pass `maxTasks > 0` to limit parallel tasks by grouping jobs into chunks.
|
||||
-/
|
||||
def par {α : Type} (jobs : List (TermElabM α))
|
||||
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) :
|
||||
TermElabM (List (Except Exception (α × Term.SavedState))) := do
|
||||
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
|
||||
if chunkSize ≤ 1 then
|
||||
parCore jobs
|
||||
else
|
||||
let initialState ← get
|
||||
let chunks := toChunks jobs chunkSize
|
||||
-- Each chunk processes its jobs sequentially, collecting Except results
|
||||
let chunkJobs := chunks.map fun chunk => do
|
||||
let mut results : List (Except Exception (α × Term.SavedState)) := []
|
||||
for job in chunk do
|
||||
try
|
||||
let a ← job
|
||||
let s ← saveState
|
||||
results := .ok (a, s) :: results
|
||||
catch e =>
|
||||
results := .error e :: results
|
||||
pure results.reverse
|
||||
let chunkResults ← parCore' chunkJobs
|
||||
set initialState
|
||||
let mut allResults := []
|
||||
for chunkResult in chunkResults do
|
||||
match chunkResult with
|
||||
| .ok jobResults => allResults := allResults ++ jobResults
|
||||
| .error e => allResults := allResults ++ [.error e]
|
||||
return allResults
|
||||
|
||||
/--
|
||||
Runs a list of TermElabM computations in parallel and collects results in the original order,
|
||||
discarding state information.
|
||||
|
||||
Unlike `par`, this doesn't return state information from tasks.
|
||||
|
||||
The final TermElabM state is restored to the initial state (before tasks ran).
|
||||
|
||||
**Chunking:** Pass `maxTasks > 0` to limit parallel tasks by grouping jobs into chunks.
|
||||
-/
|
||||
def par' {α : Type} (jobs : List (TermElabM α))
|
||||
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) :
|
||||
TermElabM (List (Except Exception α)) := do
|
||||
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
|
||||
if chunkSize ≤ 1 then
|
||||
parCore' jobs
|
||||
else
|
||||
let initialState ← get
|
||||
let chunks := toChunks jobs chunkSize
|
||||
let chunkJobs := chunks.map fun chunk => do
|
||||
let mut results : List (Except Exception α) := []
|
||||
for job in chunk do
|
||||
try
|
||||
let a ← job
|
||||
results := .ok a :: results
|
||||
catch e =>
|
||||
results := .error e :: results
|
||||
pure results.reverse
|
||||
let chunkResults ← parCore' chunkJobs
|
||||
set initialState
|
||||
let mut allResults := []
|
||||
for chunkResult in chunkResults do
|
||||
match chunkResult with
|
||||
| .ok jobResults => allResults := allResults ++ jobResults
|
||||
| .error e => allResults := allResults ++ [.error e]
|
||||
return allResults
|
||||
|
||||
/--
|
||||
Runs a list of TermElabM computations in parallel and returns the first successful result
|
||||
(by completion order, not list order).
|
||||
@@ -531,6 +911,37 @@ namespace Lean.Elab.Tactic.TacticM
|
||||
|
||||
open Std.Iterators
|
||||
|
||||
/--
|
||||
Internal state for an iterator over chunked tasks for TacticM.
|
||||
Yields individual results while internally managing chunk boundaries.
|
||||
-/
|
||||
private structure ChunkedTaskIterator (α : Type) where
|
||||
chunkTasks : List (Task (TacticM (List (Except Exception α))))
|
||||
currentResults : List (Except Exception α)
|
||||
|
||||
private instance {α : Type} : Iterator (ChunkedTaskIterator α) TacticM (Except Exception α) where
|
||||
IsPlausibleStep _
|
||||
| .yield _ _ => True
|
||||
| .skip _ => True
|
||||
| .done => True
|
||||
step it := do
|
||||
match it.internalState.currentResults with
|
||||
| r :: rest =>
|
||||
pure <| .deflate ⟨.yield (toIterM { it.internalState with currentResults := rest } TacticM (Except Exception α)) r, trivial⟩
|
||||
| [] =>
|
||||
match it.internalState.chunkTasks with
|
||||
| [] => pure <| .deflate ⟨.done, trivial⟩
|
||||
| task :: rest =>
|
||||
try
|
||||
let chunkResults ← task.get
|
||||
match chunkResults with
|
||||
| [] =>
|
||||
pure <| .deflate ⟨.skip (toIterM { chunkTasks := rest, currentResults := [] } TacticM (Except Exception α)), trivial⟩
|
||||
| r :: rs =>
|
||||
pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := rs } TacticM (Except Exception α)) r, trivial⟩
|
||||
catch e =>
|
||||
pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := [] } TacticM (Except Exception α)) (.error e), trivial⟩
|
||||
|
||||
/--
|
||||
Runs a list of TacticM computations in parallel and returns:
|
||||
* a combined cancellation hook for all tasks, and
|
||||
@@ -548,7 +959,6 @@ The iterator will terminate after all jobs complete (assuming they all do comple
|
||||
def parIterWithCancel {α : Type} (jobs : List (TacticM α)) := do
|
||||
let (cancels, tasks) := (← jobs.mapM asTask).unzip
|
||||
let combinedCancel := cancels.forM id
|
||||
-- Create iterator that processes tasks sequentially
|
||||
let iterWithErrors := tasks.iter.mapM fun (task : Task (TacticM α)) => do
|
||||
try
|
||||
let result ← task.get
|
||||
@@ -557,6 +967,34 @@ def parIterWithCancel {α : Type} (jobs : List (TacticM α)) := do
|
||||
pure (Except.error e)
|
||||
return (combinedCancel, iterWithErrors)
|
||||
|
||||
/--
|
||||
Runs a list of TacticM computations in parallel with chunking and returns:
|
||||
* a combined cancellation hook for all tasks, and
|
||||
* an iterator that yields results in original order.
|
||||
|
||||
Unlike `parIterWithCancel`, this groups jobs into chunks to reduce task overhead.
|
||||
Each chunk runs its jobs sequentially, but chunks run in parallel.
|
||||
|
||||
**Parameters:**
|
||||
- `maxTasks`: Maximum number of parallel tasks (chunks). Default 0 means one task per job.
|
||||
- `minChunkSize`: Minimum jobs per chunk. Default 1.
|
||||
-/
|
||||
def parIterWithCancelChunked {α : Type} (jobs : List (TacticM α))
|
||||
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) := do
|
||||
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
|
||||
let chunks := toChunks jobs chunkSize
|
||||
let chunkJobs : List (TacticM (List (Except Exception α))) :=
|
||||
chunks.map fun (chunk : List (TacticM α)) => chunk.mapM fun job => do
|
||||
try
|
||||
let a ← job
|
||||
pure (.ok a)
|
||||
catch e =>
|
||||
pure (.error e)
|
||||
let (cancels, tasks) := (← chunkJobs.mapM asTask).unzip
|
||||
let combinedCancel := cancels.forM id
|
||||
let flatIter := toIterM (ChunkedTaskIterator.mk tasks []) TacticM (Except Exception α)
|
||||
return (combinedCancel, flatIter)
|
||||
|
||||
/--
|
||||
Runs a list of TacticM computations in parallel (without cancellation hook).
|
||||
|
||||
@@ -599,19 +1037,9 @@ Returns an iterator that yields results in completion order, wrapped in `Except
|
||||
def parIterGreedy {α : Type} (jobs : List (TacticM α)) :=
|
||||
(·.2) <$> parIterGreedyWithCancel jobs
|
||||
|
||||
/--
|
||||
Runs a list of TacticM computations in parallel and collects results in the original order,
|
||||
including the saved state after each task completes.
|
||||
|
||||
Unlike `parIter`, this waits for all tasks to complete and returns results
|
||||
in the same order as the input list, not in completion order.
|
||||
|
||||
Results are wrapped in `Except Exception (α × Tactic.SavedState)` so that errors in individual
|
||||
tasks don't stop the collection - you can observe all results including which tasks failed.
|
||||
|
||||
The final TacticM state is restored to the initial state (before tasks ran).
|
||||
-/
|
||||
def par {α : Type} (jobs : List (TacticM α)) : TacticM (List (Except Exception (α × Tactic.SavedState))) := do
|
||||
/-- Internal: run jobs in parallel without chunking, returning state. -/
|
||||
private def parCore {α : Type} (jobs : List (TacticM α)) :
|
||||
TacticM (List (Except Exception (α × Tactic.SavedState))) := do
|
||||
let initialState ← get
|
||||
let tasks ← jobs.mapM asTask'
|
||||
let mut results := []
|
||||
@@ -625,15 +1053,9 @@ def par {α : Type} (jobs : List (TacticM α)) : TacticM (List (Except Exception
|
||||
set initialState
|
||||
return results.reverse
|
||||
|
||||
/--
|
||||
Runs a list of TacticM computations in parallel and collects results in the original order,
|
||||
discarding state information.
|
||||
|
||||
Unlike `par`, this doesn't return state information from tasks.
|
||||
|
||||
The final TacticM state is restored to the initial state (before tasks ran).
|
||||
-/
|
||||
def par' {α : Type} (jobs : List (TacticM α)) : TacticM (List (Except Exception α)) := do
|
||||
/-- Internal: run jobs in parallel without chunking, discarding state. -/
|
||||
private def parCore' {α : Type} (jobs : List (TacticM α)) :
|
||||
TacticM (List (Except Exception α)) := do
|
||||
let initialState ← get
|
||||
let tasks ← jobs.mapM asTask'
|
||||
let mut results := []
|
||||
@@ -646,6 +1068,86 @@ def par' {α : Type} (jobs : List (TacticM α)) : TacticM (List (Except Exceptio
|
||||
set initialState
|
||||
return results.reverse
|
||||
|
||||
/--
|
||||
Runs a list of TacticM computations in parallel and collects results in the original order,
|
||||
including the saved state after each task completes.
|
||||
|
||||
Unlike `parIter`, this waits for all tasks to complete and returns results
|
||||
in the same order as the input list, not in completion order.
|
||||
|
||||
Results are wrapped in `Except Exception (α × Tactic.SavedState)` so that errors in individual
|
||||
tasks don't stop the collection - you can observe all results including which tasks failed.
|
||||
|
||||
The final TacticM state is restored to the initial state (before tasks ran).
|
||||
|
||||
**Chunking:** Pass `maxTasks > 0` to limit parallel tasks by grouping jobs into chunks.
|
||||
-/
|
||||
def par {α : Type} (jobs : List (TacticM α))
|
||||
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) :
|
||||
TacticM (List (Except Exception (α × Tactic.SavedState))) := do
|
||||
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
|
||||
if chunkSize ≤ 1 then
|
||||
parCore jobs
|
||||
else
|
||||
let initialState ← get
|
||||
let chunks := toChunks jobs chunkSize
|
||||
-- Each chunk processes its jobs sequentially, collecting Except results
|
||||
let chunkJobs := chunks.map fun chunk => do
|
||||
let mut results : List (Except Exception (α × Tactic.SavedState)) := []
|
||||
for job in chunk do
|
||||
try
|
||||
let a ← job
|
||||
let s ← Tactic.saveState
|
||||
results := .ok (a, s) :: results
|
||||
catch e =>
|
||||
results := .error e :: results
|
||||
pure results.reverse
|
||||
let chunkResults ← parCore' chunkJobs
|
||||
set initialState
|
||||
let mut allResults := []
|
||||
for chunkResult in chunkResults do
|
||||
match chunkResult with
|
||||
| .ok jobResults => allResults := allResults ++ jobResults
|
||||
| .error e => allResults := allResults ++ [.error e]
|
||||
return allResults
|
||||
|
||||
/--
|
||||
Runs a list of TacticM computations in parallel and collects results in the original order,
|
||||
discarding state information.
|
||||
|
||||
Unlike `par`, this doesn't return state information from tasks.
|
||||
|
||||
The final TacticM state is restored to the initial state (before tasks ran).
|
||||
|
||||
**Chunking:** Pass `maxTasks > 0` to limit parallel tasks by grouping jobs into chunks.
|
||||
-/
|
||||
def par' {α : Type} (jobs : List (TacticM α))
|
||||
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) :
|
||||
TacticM (List (Except Exception α)) := do
|
||||
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
|
||||
if chunkSize ≤ 1 then
|
||||
parCore' jobs
|
||||
else
|
||||
let initialState ← get
|
||||
let chunks := toChunks jobs chunkSize
|
||||
let chunkJobs := chunks.map fun chunk => do
|
||||
let mut results : List (Except Exception α) := []
|
||||
for job in chunk do
|
||||
try
|
||||
let a ← job
|
||||
results := .ok a :: results
|
||||
catch e =>
|
||||
results := .error e :: results
|
||||
pure results.reverse
|
||||
let chunkResults ← parCore' chunkJobs
|
||||
set initialState
|
||||
let mut allResults := []
|
||||
for chunkResult in chunkResults do
|
||||
match chunkResult with
|
||||
| .ok jobResults => allResults := allResults ++ jobResults
|
||||
| .error e => allResults := allResults ++ [.error e]
|
||||
return allResults
|
||||
|
||||
/--
|
||||
Runs a list of TacticM computations in parallel and returns the first successful result
|
||||
(by completion order, not list order).
|
||||
|
||||
@@ -13,6 +13,7 @@ public import Lean.Meta.Tactic.Refl
|
||||
public import Lean.Meta.Tactic.SolveByElim
|
||||
public import Lean.Meta.Tactic.TryThis
|
||||
public import Lean.Util.Heartbeats
|
||||
public import Lean.Elab.Parallel
|
||||
|
||||
public section
|
||||
|
||||
@@ -286,61 +287,44 @@ def RewriteResult.addSuggestion (ref : Syntax) (r : RewriteResult)
|
||||
(type? := r.newGoal.toLOption) (origSpan? := ← getRef)
|
||||
(checkState? := checkState?.getD (← saveState))
|
||||
|
||||
structure RewriteResultConfig where
|
||||
stopAtRfl : Bool
|
||||
max : Nat
|
||||
minHeartbeats : Nat
|
||||
goal : MVarId
|
||||
target : Expr
|
||||
side : SideConditions := .solveByElim
|
||||
mctx : MetavarContext
|
||||
/--
|
||||
Find lemmas which can rewrite the goal.
|
||||
|
||||
def takeListAux (cfg : RewriteResultConfig) (seen : Std.HashMap String Unit) (acc : Array RewriteResult)
|
||||
(xs : List ((Expr ⊕ Name) × Bool × Nat)) : MetaM (Array RewriteResult) := do
|
||||
let mut seen := seen
|
||||
let mut acc := acc
|
||||
for (lem, symm, weight) in xs do
|
||||
if (← getRemainingHeartbeats) < cfg.minHeartbeats then
|
||||
return acc
|
||||
if acc.size ≥ cfg.max then
|
||||
return acc
|
||||
let res ←
|
||||
withoutModifyingState <| withMCtx cfg.mctx do
|
||||
rwLemma cfg.mctx cfg.goal cfg.target cfg.side lem symm weight
|
||||
match res with
|
||||
| none => continue
|
||||
| some r =>
|
||||
let s ← withoutModifyingState <| withMCtx r.mctx r.ppResult
|
||||
if seen.contains s then
|
||||
continue
|
||||
let rfl? ← dischargableWithRfl? r.mctx r.result.eNew
|
||||
if cfg.stopAtRfl then
|
||||
if rfl? then
|
||||
return #[r]
|
||||
else
|
||||
seen := seen.insert s ()
|
||||
acc := acc.push r
|
||||
else
|
||||
seen := seen.insert s ()
|
||||
acc := acc.push r
|
||||
return acc
|
||||
|
||||
/-- Find lemmas which can rewrite the goal. -/
|
||||
Runs all candidates in parallel, iterates through results in order.
|
||||
Cancels remaining tasks and returns immediately if `stopAtRfl` is true and
|
||||
an rfl-closeable result is found. Collects up to `max` unique results.
|
||||
-/
|
||||
def findRewrites (hyps : Array (Expr × Bool × Nat))
|
||||
(moduleRef : LazyDiscrTree.ModuleDiscrTreeRef (Name × RwDirection))
|
||||
(goal : MVarId) (target : Expr)
|
||||
(forbidden : NameSet := ∅) (side : SideConditions := .solveByElim)
|
||||
(stopAtRfl : Bool) (max : Nat := 20)
|
||||
(leavePercentHeartbeats : Nat := 10) : MetaM (List RewriteResult) := do
|
||||
(stopAtRfl : Bool) (max : Nat := 20) : MetaM (List RewriteResult) := do
|
||||
let mctx ← getMCtx
|
||||
let candidates ← rewriteCandidates hyps moduleRef target forbidden
|
||||
let minHeartbeats : Nat ←
|
||||
if (← getMaxHeartbeats) = 0 then
|
||||
pure 0
|
||||
else
|
||||
pure <| leavePercentHeartbeats * (← getRemainingHeartbeats) / 100
|
||||
let cfg : RewriteResultConfig :=
|
||||
{ stopAtRfl, minHeartbeats, max, mctx, goal, target, side }
|
||||
return (← takeListAux cfg {} (Array.mkEmpty max) candidates.toList).toList
|
||||
-- Create parallel jobs for each candidate
|
||||
let jobs := candidates.toList.map fun (lem, symm, weight) => do
|
||||
withoutModifyingState <| withMCtx mctx do
|
||||
let some r ← rwLemma mctx goal target side lem symm weight
|
||||
| return none
|
||||
let s ← withoutModifyingState <| withMCtx r.mctx r.ppResult
|
||||
return some (r, s)
|
||||
let (cancel, iter) ← MetaM.parIterWithCancelChunked jobs (maxTasks := 128)
|
||||
let mut seen : Std.HashMap String Unit := {}
|
||||
let mut acc : Array RewriteResult := Array.mkEmpty max
|
||||
for result in iter.allowNontermination do
|
||||
if acc.size ≥ max then
|
||||
cancel
|
||||
break
|
||||
match result with
|
||||
| .error _ => continue
|
||||
| .ok none => continue
|
||||
| .ok (some (r, s)) =>
|
||||
if seen.contains s then continue
|
||||
seen := seen.insert s ()
|
||||
if stopAtRfl && r.rfl? then
|
||||
cancel
|
||||
return [r]
|
||||
acc := acc.push r
|
||||
return acc.toList
|
||||
|
||||
end Lean.Meta.Rewrites
|
||||
|
||||
78
tests/lean/run/parallel_chunked.lean
Normal file
78
tests/lean/run/parallel_chunked.lean
Normal file
@@ -0,0 +1,78 @@
|
||||
import Lean.Elab.Parallel
|
||||
|
||||
/-!
|
||||
# Tests for chunked parallel execution
|
||||
|
||||
Tests for the chunking support in `Lean.Elab.Parallel`.
|
||||
-/
|
||||
|
||||
open Lean Core
|
||||
|
||||
/-! ## CoreM tests -/
|
||||
|
||||
/-- Test that par with maxTasks=0 (default) works like original -/
|
||||
def testCoreMParDefault : CoreM Unit := do
|
||||
let jobs := (List.range 10).map fun i => pure i
|
||||
let results ← CoreM.par jobs
|
||||
let values := results.filterMap fun r =>
|
||||
match r with
|
||||
| .ok (v, _) => some v
|
||||
| .error _ => none
|
||||
assert! values == List.range 10
|
||||
|
||||
/-- Test that par with chunking produces same results -/
|
||||
def testCoreMParChunked : CoreM Unit := do
|
||||
let jobs := (List.range 20).map fun i => pure i
|
||||
let results ← CoreM.par jobs (maxTasks := 4) (minChunkSize := 2)
|
||||
let values := results.filterMap fun r =>
|
||||
match r with
|
||||
| .ok (v, _) => some v
|
||||
| .error _ => none
|
||||
-- Results should be in original order
|
||||
assert! values == List.range 20
|
||||
|
||||
/-- Test that par' with chunking works -/
|
||||
def testCoreMPar'Chunked : CoreM Unit := do
|
||||
let jobs := (List.range 15).map fun i => pure (i * 2)
|
||||
let results ← CoreM.par' jobs (maxTasks := 3) (minChunkSize := 2)
|
||||
let values := results.filterMap fun r =>
|
||||
match r with
|
||||
| .ok v => some v
|
||||
| .error _ => none
|
||||
assert! values == (List.range 15).map (· * 2)
|
||||
|
||||
/-- Test error handling in chunks -/
|
||||
def testCoreMParChunkedErrors : CoreM Unit := do
|
||||
let jobs := (List.range 10).map fun i =>
|
||||
if i == 5 then throwError "error at 5"
|
||||
else pure i
|
||||
let results ← CoreM.par' jobs (maxTasks := 3) (minChunkSize := 2)
|
||||
-- Should have 9 successes and 1 error
|
||||
let successes := results.filter fun r => match r with | .ok _ => true | .error _ => false
|
||||
let errors := results.filter fun r => match r with | .ok _ => false | .error _ => true
|
||||
assert! successes.length == 9
|
||||
assert! errors.length == 1
|
||||
|
||||
#eval do
|
||||
let env ← importModules #[{ module := `Init }] {} 0
|
||||
let (_, _) ← testCoreMParDefault.toIO { fileName := "", fileMap := default } { env }
|
||||
IO.println "testCoreMParDefault passed"
|
||||
|
||||
#eval do
|
||||
let env ← importModules #[{ module := `Init }] {} 0
|
||||
let (_, _) ← testCoreMParChunked.toIO { fileName := "", fileMap := default } { env }
|
||||
IO.println "testCoreMParChunked passed"
|
||||
|
||||
#eval do
|
||||
let env ← importModules #[{ module := `Init }] {} 0
|
||||
let (_, _) ← testCoreMPar'Chunked.toIO { fileName := "", fileMap := default } { env }
|
||||
IO.println "testCoreMPar'Chunked passed"
|
||||
|
||||
#eval do
|
||||
let env ← importModules #[{ module := `Init }] {} 0
|
||||
let (_, _) ← testCoreMParChunkedErrors.toIO { fileName := "", fileMap := default } { env }
|
||||
IO.println "testCoreMParChunkedErrors passed"
|
||||
|
||||
/-! ## All tests passed -/
|
||||
|
||||
#eval IO.println "All parallel_chunked tests passed!"
|
||||
Reference in New Issue
Block a user