Compare commits

...

3 Commits

Author SHA1 Message Date
Kim Morrison
d6fc6e6b45 perf: parallelize rw? tactic
Use `MetaM.parIterWithCancel` to try all candidate rewrites in parallel
while preserving deterministic result ordering. When an rfl-closeable
result is found (and `stopAtRfl` is true), or the maximum number of
results is reached, remaining tasks are cancelled.

This removes the old sequential `takeListAux` implementation along with
the heartbeat-based early termination and `RewriteResultConfig` structure.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-12-02 18:39:38 +11:00
Kim Morrison
272f0f5db3 feat: add chunked variants of parIterWithCancel
Add `parIterWithCancelChunked` functions for CoreM, MetaM, TermElabM, and TacticM that support chunking jobs into groups to reduce task creation overhead.

The original `parIterWithCancel` functions remain unchanged for backward compatibility. The new chunked variants accept `maxTasks` and `minChunkSize` parameters to control parallelism.

This enables PRs that use `parIterWithCancel` (like parallel library search and rewrites) to benefit from chunking by switching to the new `parIterWithCancelChunked` function with `maxTasks := 128`.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-12-02 18:35:46 +11:00
Kim Morrison
1d3fda4130 feat: add chunking support to par and par' in Lean.Elab.Parallel
This PR adds optional chunking support to the `par` and `par'` functions in
`Lean.Elab.Parallel` for CoreM, MetaM, TermElabM, and TacticM. This reduces
task creation overhead when there are many small jobs by grouping them into
chunks that run sequentially within each parallel task.

New optional parameters:
- `maxTasks : Nat := 0` - Maximum number of parallel tasks (0 = no limit)
- `minChunkSize : Nat := 1` - Minimum jobs per chunk

Example: With 1000 jobs and `maxTasks := 128, minChunkSize := 8`:
- Chunk size = max(8, ceil(1000/128)) = 8
- Creates ~125 parallel tasks instead of 1000

Default behavior (maxTasks = 0) is unchanged - one task per job.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-12-02 18:22:40 +11:00
3 changed files with 696 additions and 132 deletions


@@ -59,8 +59,38 @@ as tasks complete, unlike `par`/`par'` which restore the initial state after col
Iterators do not have `Finite` instances, as we cannot prove termination from the available
information. For consumers that require `Finite` (like `.toList`), use `.allowNontermination.toList`.
## Chunking support
The `par`, `par'`, `parIter`, and `parIterWithCancel` functions support optional chunking to reduce
task creation overhead when there are many small jobs. Pass `maxTasks` to limit the number of parallel
tasks created; jobs will be grouped into chunks that run sequentially within each task.
- `maxTasks = 0` (default): No chunking, one task per job (original behavior)
- `maxTasks > 0`: Auto-compute chunk size to limit task count
- `minChunkSize`: Minimum jobs per chunk (default 1)
Example: With 1000 jobs and `maxTasks := 128, minChunkSize := 8`, chunk size = 8, creating ~125 tasks.
-/
/-- Split a list into chunks of at most `chunkSize` elements. -/
def toChunks {α : Type} (xs : List α) (chunkSize : Nat) : List (List α) :=
if h : chunkSize ≤ 1 then xs.map ([·])
else go xs [] (Nat.lt_of_not_le h)
where
go (remaining : List α) (acc : List (List α)) (hc : 1 < chunkSize) : List (List α) :=
if _h : remaining.length ≤ chunkSize then
(remaining :: acc).reverse
else
go (remaining.drop chunkSize) (remaining.take chunkSize :: acc) hc
termination_by remaining.length
decreasing_by simp_wf; omega
/-- Compute chunk size given job count, max tasks, and minimum chunk size. -/
def computeChunkSize (numJobs maxTasks minChunkSize : Nat) : Nat :=
if maxTasks = 0 then 1
else max minChunkSize ((numJobs + maxTasks - 1) / maxTasks)
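The chunking arithmetic can be spot-checked with a few `#eval`s (a sketch, assuming the `computeChunkSize` and `toChunks` definitions above are in scope):

```lean
-- Spot checks for the chunking arithmetic (sketch; assumes the
-- definitions above are in scope).
#eval computeChunkSize 1000 128 8
-- max 8 ⌈1000 / 128⌉ = max 8 8 = 8, so ~125 chunks of 8 jobs
#eval computeChunkSize 1000 16 1
-- ⌈1000 / 16⌉ = 63 jobs per chunk
#eval computeChunkSize 1000 0 8
-- maxTasks = 0 disables chunking: chunk size 1
#eval toChunks [1, 2, 3, 4, 5] 2
-- [[1, 2], [3, 4], [5]]
```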
public section
namespace Std.Iterators
@@ -125,6 +155,38 @@ namespace Lean.Core.CoreM
open Std.Iterators
/--
Internal state for an iterator over chunked tasks for CoreM.
Yields individual results while internally managing chunk boundaries.
-/
private structure ChunkedTaskIterator (α : Type) where
chunkTasks : List (Task (CoreM (List (Except Exception α))))
currentResults : List (Except Exception α)
private instance {α : Type} : Iterator (ChunkedTaskIterator α) CoreM (Except Exception α) where
IsPlausibleStep _
| .yield _ _ => True
| .skip _ => True -- Allow skip for empty chunks
| .done => True
step it := do
match it.internalState.currentResults with
| r :: rest =>
pure <| .deflate ⟨.yield (toIterM { it.internalState with currentResults := rest } CoreM (Except Exception α)) r, trivial⟩
| [] =>
match it.internalState.chunkTasks with
| [] => pure <| .deflate ⟨.done, trivial⟩
| task :: rest =>
try
let chunkResults ← task.get
match chunkResults with
| [] =>
-- Empty chunk, skip to try next
pure <| .deflate ⟨.skip (toIterM { chunkTasks := rest, currentResults := [] } CoreM (Except Exception α)), trivial⟩
| r :: rs =>
pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := rs } CoreM (Except Exception α)) r, trivial⟩
catch e =>
pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := [] } CoreM (Except Exception α)) (.error e), trivial⟩
/--
Runs a list of CoreM computations in parallel and returns:
* a combined cancellation hook for all tasks, and
@@ -150,6 +212,29 @@ def parIterWithCancel {α : Type} (jobs : List (CoreM α)) := do
pure (Except.error e)
return (combinedCancel, iterWithErrors)
/--
Runs a list of CoreM computations in parallel with chunking and returns:
* a combined cancellation hook for all tasks, and
* an iterator that yields results in original order.
Unlike `parIterWithCancel`, this groups jobs into chunks to reduce task overhead.
Each chunk runs its jobs sequentially, but chunks run in parallel.
**Parameters:**
- `maxTasks`: Maximum number of parallel tasks (chunks). Default 0 means one task per job.
- `minChunkSize`: Minimum jobs per chunk. Default 1.
-/
def parIterWithCancelChunked {α : Type} (jobs : List (CoreM α))
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) := do
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
let chunks := toChunks jobs chunkSize
let chunkJobs : List (CoreM (List (Except Exception α))) :=
chunks.map fun (chunk : List (CoreM α)) => chunk.mapM (observing ·)
let (cancels, tasks) := (← chunkJobs.mapM asTask).unzip
let combinedCancel := cancels.forM id
let flatIter := toIterM (ChunkedTaskIterator.mk tasks []) CoreM (Except Exception α)
return (combinedCancel, flatIter)
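A hypothetical usage sketch (not part of this diff; `runChunkedJobs` is made up for illustration), using the `.allowNontermination.toList` pattern mentioned in the module docs:

```lean
-- Hypothetical sketch: run jobs in at most 128 parallel tasks and
-- collect all results in the original order. The cancellation hook
-- could be invoked instead to stop chunks that are still running.
def runChunkedJobs (jobs : List (CoreM Nat)) :
    CoreM (List (Except Exception Nat)) := do
  let (_cancel, iter) ← parIterWithCancelChunked jobs (maxTasks := 128)
  iter.allowNontermination.toList
```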
/--
Runs a list of CoreM computations in parallel (without cancellation hook).
@@ -192,19 +277,9 @@ Returns an iterator that yields results in completion order, wrapped in `Except
def parIterGreedy {α : Type} (jobs : List (CoreM α)) :=
(·.2) <$> parIterGreedyWithCancel jobs
/--
Runs a list of CoreM computations in parallel and collects results in the original order,
including the saved state after each task completes.
Unlike `parIter`, this waits for all tasks to complete and returns results
in the same order as the input list, not in completion order.
Results are wrapped in `Except Exception (α × Core.SavedState)` so that errors in individual
tasks don't stop the collection - you can observe all results including which tasks failed.
The final CoreM state is restored to the initial state (before tasks ran).
-/
def par {α : Type} (jobs : List (CoreM α)) : CoreM (List (Except Exception (α × Core.SavedState))) := do
/-- Internal: run jobs in parallel without chunking, returning state. -/
private def parCore {α : Type} (jobs : List (CoreM α)) :
CoreM (List (Except Exception (α × Core.SavedState))) := do
let initialState ← get
let tasks ← jobs.mapM asTask'
let mut results := []
@@ -218,13 +293,47 @@ def par {α : Type} (jobs : List (CoreM α)) : CoreM (List (Except Exception (α
/--
Runs a list of CoreM computations in parallel and collects results in the original order,
discarding state information.
including the saved state after each task completes.
Unlike `par`, this doesn't return state information from tasks.
Unlike `parIter`, this waits for all tasks to complete and returns results
in the same order as the input list, not in completion order.
Results are wrapped in `Except Exception (α × Core.SavedState)` so that errors in individual
tasks don't stop the collection - you can observe all results including which tasks failed.
The final CoreM state is restored to the initial state (before tasks ran).
**Chunking:** Pass `maxTasks > 0` to limit parallel tasks by grouping jobs into chunks.
-/
def par' {α : Type} (jobs : List (CoreM α)) : CoreM (List (Except Exception α)) := do
def par {α : Type} (jobs : List (CoreM α))
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) :
CoreM (List (Except Exception (α × Core.SavedState))) := do
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
if chunkSize ≤ 1 then
parCore jobs
else
let initialState ← get
let chunks := toChunks jobs chunkSize
let chunkJobs := chunks.map fun chunk => do
let mut results := []
for job in chunk do
let r ← observing do
let a ← job
pure (a, ← saveState)
results := r :: results
pure results.reverse
let chunkResults ← parCore chunkJobs
set initialState
let mut allResults := []
for chunkResult in chunkResults do
match chunkResult with
| .ok (jobResults, _) => allResults := allResults ++ jobResults
| .error e => allResults := allResults ++ [.error e]
return allResults
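A hypothetical usage sketch for the chunked `par` (the function name `sumSuccesses` and the job type are made up for illustration):

```lean
-- Hypothetical sketch: sum the successful results, ignoring the saved
-- states and any failed jobs.
def sumSuccesses (jobs : List (CoreM Nat)) : CoreM Nat := do
  let results ← par jobs (maxTasks := 64)
  let mut total := 0
  for r in results do
    if let .ok (v, _) := r then
      total := total + v
  return total
```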
/-- Internal: run jobs in parallel without chunking, discarding state. -/
private def parCore' {α : Type} (jobs : List (CoreM α)) :
CoreM (List (Except Exception α)) := do
let initialState ← get
let tasks ← jobs.mapM asTask'
let mut results := []
@@ -237,6 +346,40 @@ def par' {α : Type} (jobs : List (CoreM α)) : CoreM (List (Except Exception α
set initialState
return results.reverse
/--
Runs a list of CoreM computations in parallel and collects results in the original order,
discarding state information.
Unlike `par`, this doesn't return state information from tasks.
The final CoreM state is restored to the initial state (before tasks ran).
**Chunking:** Pass `maxTasks > 0` to limit parallel tasks by grouping jobs into chunks.
-/
def par' {α : Type} (jobs : List (CoreM α))
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) :
CoreM (List (Except Exception α)) := do
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
if chunkSize ≤ 1 then
parCore' jobs
else
let initialState ← get
let chunks := toChunks jobs chunkSize
let chunkJobs := chunks.map fun chunk => do
let mut results := []
for job in chunk do
let r ← observing job
results := r :: results
pure results.reverse
let chunkResults ← parCore' chunkJobs
set initialState
let mut allResults := []
for chunkResult in chunkResults do
match chunkResult with
| .ok jobResults => allResults := allResults ++ jobResults
| .error e => allResults := allResults ++ [.error e]
return allResults
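And a hypothetical sketch for the chunked `par'` (the name `countTrue` is made up for illustration), for when only the values matter:

```lean
-- Hypothetical sketch: count how many jobs succeeded with `true`;
-- failed jobs surface as `.error` and are not counted.
def countTrue (jobs : List (CoreM Bool)) : CoreM Nat := do
  let results ← par' jobs (maxTasks := 64) (minChunkSize := 4)
  return results.countP fun r => r matches .ok true
```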
/--
Runs a list of CoreM computations in parallel and returns the first successful result
(by completion order, not list order).
@@ -260,18 +403,43 @@ namespace Lean.Meta.MetaM
open Std.Iterators
/--
Runs a list of MetaM computations in parallel and collects results in the original order,
including the saved state after each task completes.
Unlike `parIter`, this waits for all tasks to complete and returns results
in the same order as the input list, not in completion order.
Results are wrapped in `Except Exception (α × Meta.SavedState)` so that errors in individual
tasks don't stop the collection - you can observe all results including which tasks failed.
The final MetaM state is restored to the initial state (before tasks ran).
Internal state for an iterator over chunked tasks for MetaM.
Yields individual results while internally managing chunk boundaries.
-/
def par {α : Type} (jobs : List (MetaM α)) : MetaM (List (Except Exception (α × Meta.SavedState))) := do
structure ChunkedTaskIterator (α : Type) where
chunkTasks : List (Task (MetaM (List (Except Exception α))))
currentResults : List (Except Exception α)
instance {α : Type} : Iterator (ChunkedTaskIterator α) MetaM (Except Exception α) where
IsPlausibleStep _
| .yield _ _ => True
| .skip _ => True
| .done => True
step it := do
match it.internalState.currentResults with
| r :: rest =>
pure <| .deflate ⟨.yield (toIterM { it.internalState with currentResults := rest } MetaM (Except Exception α)) r, trivial⟩
| [] =>
match it.internalState.chunkTasks with
| [] => pure <| .deflate ⟨.done, trivial⟩
| task :: rest =>
try
let chunkResults ← task.get
match chunkResults with
| [] =>
pure <| .deflate ⟨.skip (toIterM { chunkTasks := rest, currentResults := [] } MetaM (Except Exception α)), trivial⟩
| r :: rs =>
pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := rs } MetaM (Except Exception α)) r, trivial⟩
catch e =>
pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := [] } MetaM (Except Exception α)) (.error e), trivial⟩
instance {α : Type} {n : Type → Type u} [Monad n] [MonadLiftT MetaM n] :
IteratorLoopPartial (ChunkedTaskIterator α) MetaM n :=
.defaultImplementation
/-- Internal: run jobs in parallel without chunking, returning state. -/
private def parCore {α : Type} (jobs : List (MetaM α)) :
MetaM (List (Except Exception (α × Meta.SavedState))) := do
let initialState ← get
let tasks ← jobs.mapM asTask'
let mut results := []
@@ -283,15 +451,9 @@ def par {α : Type} (jobs : List (MetaM α)) : MetaM (List (Except Exception (α
set initialState
return results.reverse
/--
Runs a list of MetaM computations in parallel and collects results in the original order,
discarding state information.
Unlike `par`, this doesn't return state information from tasks.
The final MetaM state is restored to the initial state (before tasks ran).
-/
def par' {α : Type} (jobs : List (MetaM α)) : MetaM (List (Except Exception α)) := do
/-- Internal: run jobs in parallel without chunking, discarding state. -/
private def parCore' {α : Type} (jobs : List (MetaM α)) :
MetaM (List (Except Exception α)) := do
let initialState ← get
let tasks ← jobs.mapM asTask'
let mut results := []
@@ -304,6 +466,80 @@ def par' {α : Type} (jobs : List (MetaM α)) : MetaM (List (Except Exception α
set initialState
return results.reverse
/--
Runs a list of MetaM computations in parallel and collects results in the original order,
including the saved state after each task completes.
Unlike `parIter`, this waits for all tasks to complete and returns results
in the same order as the input list, not in completion order.
Results are wrapped in `Except Exception (α × Meta.SavedState)` so that errors in individual
tasks don't stop the collection - you can observe all results including which tasks failed.
The final MetaM state is restored to the initial state (before tasks ran).
**Chunking:** Pass `maxTasks > 0` to limit parallel tasks by grouping jobs into chunks.
-/
def par {α : Type} (jobs : List (MetaM α))
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) :
MetaM (List (Except Exception (α × Meta.SavedState))) := do
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
if chunkSize ≤ 1 then
parCore jobs
else
let initialState ← get
let chunks := toChunks jobs chunkSize
let chunkJobs := chunks.map fun chunk => do
let mut results := []
for job in chunk do
let r ← observing do
let a ← job
pure (a, ← saveState)
results := r :: results
pure results.reverse
let chunkResults ← parCore chunkJobs
set initialState
let mut allResults := []
for chunkResult in chunkResults do
match chunkResult with
| .ok (jobResults, _) => allResults := allResults ++ jobResults
| .error e => allResults := allResults ++ [.error e]
return allResults
/--
Runs a list of MetaM computations in parallel and collects results in the original order,
discarding state information.
Unlike `par`, this doesn't return state information from tasks.
The final MetaM state is restored to the initial state (before tasks ran).
**Chunking:** Pass `maxTasks > 0` to limit parallel tasks by grouping jobs into chunks.
-/
def par' {α : Type} (jobs : List (MetaM α))
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) :
MetaM (List (Except Exception α)) := do
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
if chunkSize ≤ 1 then
parCore' jobs
else
let initialState ← get
let chunks := toChunks jobs chunkSize
let chunkJobs := chunks.map fun chunk => do
let mut results := []
for job in chunk do
let r ← observing job
results := r :: results
pure results.reverse
let chunkResults ← parCore' chunkJobs
set initialState
let mut allResults := []
for chunkResult in chunkResults do
match chunkResult with
| .ok jobResults => allResults := allResults ++ jobResults
| .error e => allResults := allResults ++ [.error e]
return allResults
/--
Runs a list of MetaM computations in parallel and returns:
* a combined cancellation hook for all tasks, and
@@ -321,7 +557,6 @@ The iterator will terminate after all jobs complete (assuming they all do comple
def parIterWithCancel {α : Type} (jobs : List (MetaM α)) := do
let (cancels, tasks) := (← jobs.mapM asTask).unzip
let combinedCancel := cancels.forM id
-- Create iterator that processes tasks sequentially
let iterWithErrors := tasks.iter.mapM fun (task : Task (MetaM α)) => do
try
let result ← task.get
@@ -330,6 +565,29 @@ def parIterWithCancel {α : Type} (jobs : List (MetaM α)) := do
pure (Except.error e)
return (combinedCancel, iterWithErrors)
/--
Runs a list of MetaM computations in parallel with chunking and returns:
* a combined cancellation hook for all tasks, and
* an iterator that yields results in original order.
Unlike `parIterWithCancel`, this groups jobs into chunks to reduce task overhead.
Each chunk runs its jobs sequentially, but chunks run in parallel.
**Parameters:**
- `maxTasks`: Maximum number of parallel tasks (chunks). Default 0 means one task per job.
- `minChunkSize`: Minimum jobs per chunk. Default 1.
-/
def parIterWithCancelChunked {α : Type} (jobs : List (MetaM α))
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) := do
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
let chunks := toChunks jobs chunkSize
let chunkJobs : List (MetaM (List (Except Exception α))) :=
chunks.map fun (chunk : List (MetaM α)) => chunk.mapM (observing ·)
let (cancels, tasks) := (← chunkJobs.mapM asTask).unzip
let combinedCancel := cancels.forM id
let flatIter := toIterM (ChunkedTaskIterator.mk tasks []) MetaM (Except Exception α)
return (combinedCancel, flatIter)
/--
Runs a list of MetaM computations in parallel (without cancellation hook).
@@ -394,6 +652,37 @@ namespace Lean.Elab.Term.TermElabM
open Std.Iterators
/--
Internal state for an iterator over chunked tasks for TermElabM.
Yields individual results while internally managing chunk boundaries.
-/
private structure ChunkedTaskIterator (α : Type) where
chunkTasks : List (Task (TermElabM (List (Except Exception α))))
currentResults : List (Except Exception α)
private instance {α : Type} : Iterator (ChunkedTaskIterator α) TermElabM (Except Exception α) where
IsPlausibleStep _
| .yield _ _ => True
| .skip _ => True
| .done => True
step it := do
match it.internalState.currentResults with
| r :: rest =>
pure <| .deflate ⟨.yield (toIterM { it.internalState with currentResults := rest } TermElabM (Except Exception α)) r, trivial⟩
| [] =>
match it.internalState.chunkTasks with
| [] => pure <| .deflate ⟨.done, trivial⟩
| task :: rest =>
try
let chunkResults ← task.get
match chunkResults with
| [] =>
pure <| .deflate ⟨.skip (toIterM { chunkTasks := rest, currentResults := [] } TermElabM (Except Exception α)), trivial⟩
| r :: rs =>
pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := rs } TermElabM (Except Exception α)) r, trivial⟩
catch e =>
pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := [] } TermElabM (Except Exception α)) (.error e), trivial⟩
/--
Runs a list of TermElabM computations in parallel and returns:
* a combined cancellation hook for all tasks, and
@@ -411,7 +700,6 @@ The iterator will terminate after all jobs complete (assuming they all do comple
def parIterWithCancel {α : Type} (jobs : List (TermElabM α)) := do
let (cancels, tasks) := (← jobs.mapM asTask).unzip
let combinedCancel := cancels.forM id
-- Create iterator that processes tasks sequentially
let iterWithErrors := tasks.iter.mapM fun (task : Task (TermElabM α)) => do
try
let result ← task.get
@@ -420,6 +708,34 @@ def parIterWithCancel {α : Type} (jobs : List (TermElabM α)) := do
pure (Except.error e)
return (combinedCancel, iterWithErrors)
/--
Runs a list of TermElabM computations in parallel with chunking and returns:
* a combined cancellation hook for all tasks, and
* an iterator that yields results in original order.
Unlike `parIterWithCancel`, this groups jobs into chunks to reduce task overhead.
Each chunk runs its jobs sequentially, but chunks run in parallel.
**Parameters:**
- `maxTasks`: Maximum number of parallel tasks (chunks). Default 0 means one task per job.
- `minChunkSize`: Minimum jobs per chunk. Default 1.
-/
def parIterWithCancelChunked {α : Type} (jobs : List (TermElabM α))
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) := do
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
let chunks := toChunks jobs chunkSize
let chunkJobs : List (TermElabM (List (Except Exception α))) :=
chunks.map fun (chunk : List (TermElabM α)) => chunk.mapM fun job => do
try
let a ← job
pure (.ok a)
catch e =>
pure (.error e)
let (cancels, tasks) := (← chunkJobs.mapM asTask).unzip
let combinedCancel := cancels.forM id
let flatIter := toIterM (ChunkedTaskIterator.mk tasks []) TermElabM (Except Exception α)
return (combinedCancel, flatIter)
/--
Runs a list of TermElabM computations in parallel (without cancellation hook).
@@ -462,19 +778,9 @@ Returns an iterator that yields results in completion order, wrapped in `Except
def parIterGreedy {α : Type} (jobs : List (TermElabM α)) :=
(·.2) <$> parIterGreedyWithCancel jobs
/--
Runs a list of TermElabM computations in parallel and collects results in the original order,
including the saved state after each task completes.
Unlike `parIter`, this waits for all tasks to complete and returns results
in the same order as the input list, not in completion order.
Results are wrapped in `Except Exception (α × Term.SavedState)` so that errors in individual
tasks don't stop the collection - you can observe all results including which tasks failed.
The final TermElabM state is restored to the initial state (before tasks ran).
-/
def par {α : Type} (jobs : List (TermElabM α)) : TermElabM (List (Except Exception (α × Term.SavedState))) := do
/-- Internal: run jobs in parallel without chunking, returning state. -/
private def parCore {α : Type} (jobs : List (TermElabM α)) :
TermElabM (List (Except Exception (α × Term.SavedState))) := do
let initialState ← get
let tasks ← jobs.mapM asTask'
let mut results := []
@@ -488,15 +794,9 @@ def par {α : Type} (jobs : List (TermElabM α)) : TermElabM (List (Except Excep
set initialState
return results.reverse
/--
Runs a list of TermElabM computations in parallel and collects results in the original order,
discarding state information.
Unlike `par`, this doesn't return state information from tasks.
The final TermElabM state is restored to the initial state (before tasks ran).
-/
def par' {α : Type} (jobs : List (TermElabM α)) : TermElabM (List (Except Exception α)) := do
/-- Internal: run jobs in parallel without chunking, discarding state. -/
private def parCore' {α : Type} (jobs : List (TermElabM α)) :
TermElabM (List (Except Exception α)) := do
let initialState ← get
let tasks ← jobs.mapM asTask'
let mut results := []
@@ -509,6 +809,86 @@ def par' {α : Type} (jobs : List (TermElabM α)) : TermElabM (List (Except Exce
set initialState
return results.reverse
/--
Runs a list of TermElabM computations in parallel and collects results in the original order,
including the saved state after each task completes.
Unlike `parIter`, this waits for all tasks to complete and returns results
in the same order as the input list, not in completion order.
Results are wrapped in `Except Exception (α × Term.SavedState)` so that errors in individual
tasks don't stop the collection - you can observe all results including which tasks failed.
The final TermElabM state is restored to the initial state (before tasks ran).
**Chunking:** Pass `maxTasks > 0` to limit parallel tasks by grouping jobs into chunks.
-/
def par {α : Type} (jobs : List (TermElabM α))
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) :
TermElabM (List (Except Exception (α × Term.SavedState))) := do
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
if chunkSize ≤ 1 then
parCore jobs
else
let initialState ← get
let chunks := toChunks jobs chunkSize
-- Each chunk processes its jobs sequentially, collecting Except results
let chunkJobs := chunks.map fun chunk => do
let mut results : List (Except Exception (α × Term.SavedState)) := []
for job in chunk do
try
let a ← job
let s ← saveState
results := .ok (a, s) :: results
catch e =>
results := .error e :: results
pure results.reverse
let chunkResults ← parCore' chunkJobs
set initialState
let mut allResults := []
for chunkResult in chunkResults do
match chunkResult with
| .ok jobResults => allResults := allResults ++ jobResults
| .error e => allResults := allResults ++ [.error e]
return allResults
/--
Runs a list of TermElabM computations in parallel and collects results in the original order,
discarding state information.
Unlike `par`, this doesn't return state information from tasks.
The final TermElabM state is restored to the initial state (before tasks ran).
**Chunking:** Pass `maxTasks > 0` to limit parallel tasks by grouping jobs into chunks.
-/
def par' {α : Type} (jobs : List (TermElabM α))
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) :
TermElabM (List (Except Exception α)) := do
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
if chunkSize ≤ 1 then
parCore' jobs
else
let initialState ← get
let chunks := toChunks jobs chunkSize
let chunkJobs := chunks.map fun chunk => do
let mut results : List (Except Exception α) := []
for job in chunk do
try
let a ← job
results := .ok a :: results
catch e =>
results := .error e :: results
pure results.reverse
let chunkResults ← parCore' chunkJobs
set initialState
let mut allResults := []
for chunkResult in chunkResults do
match chunkResult with
| .ok jobResults => allResults := allResults ++ jobResults
| .error e => allResults := allResults ++ [.error e]
return allResults
/--
Runs a list of TermElabM computations in parallel and returns the first successful result
(by completion order, not list order).
@@ -531,6 +911,37 @@ namespace Lean.Elab.Tactic.TacticM
open Std.Iterators
/--
Internal state for an iterator over chunked tasks for TacticM.
Yields individual results while internally managing chunk boundaries.
-/
private structure ChunkedTaskIterator (α : Type) where
chunkTasks : List (Task (TacticM (List (Except Exception α))))
currentResults : List (Except Exception α)
private instance {α : Type} : Iterator (ChunkedTaskIterator α) TacticM (Except Exception α) where
IsPlausibleStep _
| .yield _ _ => True
| .skip _ => True
| .done => True
step it := do
match it.internalState.currentResults with
| r :: rest =>
pure <| .deflate ⟨.yield (toIterM { it.internalState with currentResults := rest } TacticM (Except Exception α)) r, trivial⟩
| [] =>
match it.internalState.chunkTasks with
| [] => pure <| .deflate ⟨.done, trivial⟩
| task :: rest =>
try
let chunkResults ← task.get
match chunkResults with
| [] =>
pure <| .deflate ⟨.skip (toIterM { chunkTasks := rest, currentResults := [] } TacticM (Except Exception α)), trivial⟩
| r :: rs =>
pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := rs } TacticM (Except Exception α)) r, trivial⟩
catch e =>
pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := [] } TacticM (Except Exception α)) (.error e), trivial⟩
/--
Runs a list of TacticM computations in parallel and returns:
* a combined cancellation hook for all tasks, and
@@ -548,7 +959,6 @@ The iterator will terminate after all jobs complete (assuming they all do comple
def parIterWithCancel {α : Type} (jobs : List (TacticM α)) := do
let (cancels, tasks) := (← jobs.mapM asTask).unzip
let combinedCancel := cancels.forM id
-- Create iterator that processes tasks sequentially
let iterWithErrors := tasks.iter.mapM fun (task : Task (TacticM α)) => do
try
let result ← task.get
@@ -557,6 +967,34 @@ def parIterWithCancel {α : Type} (jobs : List (TacticM α)) := do
pure (Except.error e)
return (combinedCancel, iterWithErrors)
/--
Runs a list of TacticM computations in parallel with chunking and returns:
* a combined cancellation hook for all tasks, and
* an iterator that yields results in original order.
Unlike `parIterWithCancel`, this groups jobs into chunks to reduce task overhead.
Each chunk runs its jobs sequentially, but chunks run in parallel.
**Parameters:**
- `maxTasks`: Maximum number of parallel tasks (chunks). Default 0 means one task per job.
- `minChunkSize`: Minimum jobs per chunk. Default 1.
-/
def parIterWithCancelChunked {α : Type} (jobs : List (TacticM α))
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) := do
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
let chunks := toChunks jobs chunkSize
let chunkJobs : List (TacticM (List (Except Exception α))) :=
chunks.map fun (chunk : List (TacticM α)) => chunk.mapM fun job => do
try
let a ← job
pure (.ok a)
catch e =>
pure (.error e)
let (cancels, tasks) := (← chunkJobs.mapM asTask).unzip
let combinedCancel := cancels.forM id
let flatIter := toIterM (ChunkedTaskIterator.mk tasks []) TacticM (Except Exception α)
return (combinedCancel, flatIter)
/--
Runs a list of TacticM computations in parallel (without cancellation hook).
@@ -599,19 +1037,9 @@ Returns an iterator that yields results in completion order, wrapped in `Except
def parIterGreedy {α : Type} (jobs : List (TacticM α)) :=
(·.2) <$> parIterGreedyWithCancel jobs
/--
Runs a list of TacticM computations in parallel and collects results in the original order,
including the saved state after each task completes.
Unlike `parIter`, this waits for all tasks to complete and returns results
in the same order as the input list, not in completion order.
Results are wrapped in `Except Exception (α × Tactic.SavedState)` so that errors in individual
tasks don't stop the collection - you can observe all results including which tasks failed.
The final TacticM state is restored to the initial state (before tasks ran).
-/
def par {α : Type} (jobs : List (TacticM α)) : TacticM (List (Except Exception (α × Tactic.SavedState))) := do
/-- Internal: run jobs in parallel without chunking, returning state. -/
private def parCore {α : Type} (jobs : List (TacticM α)) :
TacticM (List (Except Exception (α × Tactic.SavedState))) := do
let initialState ← get
let tasks ← jobs.mapM asTask'
let mut results := []
@@ -625,15 +1053,9 @@ def par {α : Type} (jobs : List (TacticM α)) : TacticM (List (Except Exception
set initialState
return results.reverse
/--
Runs a list of TacticM computations in parallel and collects results in the original order,
discarding state information.
Unlike `par`, this doesn't return state information from tasks.
The final TacticM state is restored to the initial state (before tasks ran).
-/
def par' {α : Type} (jobs : List (TacticM α)) : TacticM (List (Except Exception α)) := do
/-- Internal: run jobs in parallel without chunking, discarding state. -/
private def parCore' {α : Type} (jobs : List (TacticM α)) :
TacticM (List (Except Exception α)) := do
let initialState ← get
let tasks ← jobs.mapM asTask'
let mut results := []
@@ -646,6 +1068,86 @@ def par' {α : Type} (jobs : List (TacticM α)) : TacticM (List (Except Exceptio
set initialState
return results.reverse
/--
Runs a list of TacticM computations in parallel and collects results in the original order,
including the saved state after each task completes.
Unlike `parIter`, this waits for all tasks to complete and returns results
in the same order as the input list, not in completion order.
Results are wrapped in `Except Exception (α × Tactic.SavedState)` so that errors in individual
tasks don't stop the collection - you can observe all results including which tasks failed.
The final TacticM state is restored to the initial state (before tasks ran).
**Chunking:** Pass `maxTasks > 0` to limit parallel tasks by grouping jobs into chunks.
-/
def par {α : Type} (jobs : List (TacticM α))
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) :
TacticM (List (Except Exception (α × Tactic.SavedState))) := do
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
if chunkSize ≤ 1 then
parCore jobs
else
let initialState ← get
let chunks := toChunks jobs chunkSize
-- Each chunk processes its jobs sequentially, collecting Except results
let chunkJobs := chunks.map fun chunk => do
let mut results : List (Except Exception (α × Tactic.SavedState)) := []
for job in chunk do
try
let a ← job
let s ← Tactic.saveState
results := .ok (a, s) :: results
catch e =>
results := .error e :: results
pure results.reverse
let chunkResults ← parCore' chunkJobs
set initialState
let mut allResults := []
for chunkResult in chunkResults do
match chunkResult with
| .ok jobResults => allResults := allResults ++ jobResults
| .error e => allResults := allResults ++ [.error e]
return allResults
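The `computeChunkSize` helper referenced above is defined elsewhere in this file; a minimal sketch of the arithmetic it is described as performing (ceiling division capped below by `minChunkSize`; the real implementation may differ in detail) is:

```lean
-- Hypothetical sketch of the chunk-size arithmetic; `computeChunkSizeSketch`
-- is an illustrative name, not the actual definition from this file.
def computeChunkSizeSketch (nJobs maxTasks minChunkSize : Nat) : Nat :=
  if maxTasks == 0 then 1  -- no task limit: one job per task
  else Nat.max minChunkSize ((nJobs + maxTasks - 1) / maxTasks)

-- With 1000 jobs, `maxTasks := 128`, `minChunkSize := 8`:
#eval computeChunkSizeSketch 1000 128 8  -- 8, i.e. 125 parallel tasks
```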
/--
Runs a list of TacticM computations in parallel and collects results in the original order,
discarding state information.
Unlike `par`, this doesn't return state information from tasks.
The final TacticM state is restored to the initial state (before tasks ran).
**Chunking:** Pass `maxTasks > 0` to limit parallel tasks by grouping jobs into chunks.
-/
def par' {α : Type} (jobs : List (TacticM α))
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) :
TacticM (List (Except Exception α)) := do
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
if chunkSize ≤ 1 then
parCore' jobs
else
let initialState ← get
let chunks := toChunks jobs chunkSize
let chunkJobs := chunks.map fun chunk => do
let mut results : List (Except Exception α) := []
for job in chunk do
try
let a ← job
results := .ok a :: results
catch e =>
results := .error e :: results
pure results.reverse
let chunkResults ← parCore' chunkJobs
set initialState
let mut allResults := []
for chunkResult in chunkResults do
match chunkResult with
| .ok jobResults => allResults := allResults ++ jobResults
| .error e => allResults := allResults ++ [.error e]
return allResults
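A minimal usage sketch of the chunked `par'` (illustrative only; the job bodies here are trivial `pure` computations):

```lean
-- Run 1000 cheap jobs in at most 128 parallel tasks; each task
-- processes its chunk of jobs sequentially.
def exampleChunkedPar' : TacticM (List (Except Exception Nat)) := do
  let jobs := (List.range 1000).map fun i => pure i
  par' jobs (maxTasks := 128) (minChunkSize := 8)
```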
/--
Runs a list of TacticM computations in parallel and returns the first successful result
(by completion order, not list order).


@@ -13,6 +13,7 @@ public import Lean.Meta.Tactic.Refl
public import Lean.Meta.Tactic.SolveByElim
public import Lean.Meta.Tactic.TryThis
public import Lean.Util.Heartbeats
public import Lean.Elab.Parallel
public section
@@ -286,61 +287,44 @@ def RewriteResult.addSuggestion (ref : Syntax) (r : RewriteResult)
(type? := r.newGoal.toLOption) (origSpan? := getRef)
(checkState? := checkState?.getD ( saveState))
structure RewriteResultConfig where
stopAtRfl : Bool
max : Nat
minHeartbeats : Nat
goal : MVarId
target : Expr
side : SideConditions := .solveByElim
mctx : MetavarContext
/--
Find lemmas which can rewrite the goal.
def takeListAux (cfg : RewriteResultConfig) (seen : Std.HashMap String Unit) (acc : Array RewriteResult)
(xs : List ((Expr ⊕ Name) × Bool × Nat)) : MetaM (Array RewriteResult) := do
let mut seen := seen
let mut acc := acc
for (lem, symm, weight) in xs do
if (← getRemainingHeartbeats) < cfg.minHeartbeats then
return acc
if acc.size ≥ cfg.max then
return acc
let res ←
withoutModifyingState <| withMCtx cfg.mctx do
rwLemma cfg.mctx cfg.goal cfg.target cfg.side lem symm weight
match res with
| none => continue
| some r =>
let s ← withoutModifyingState <| withMCtx r.mctx r.ppResult
if seen.contains s then
continue
let rfl? ← dischargableWithRfl? r.mctx r.result.eNew
if cfg.stopAtRfl then
if rfl? then
return #[r]
else
seen := seen.insert s ()
acc := acc.push r
else
seen := seen.insert s ()
acc := acc.push r
return acc
/-- Find lemmas which can rewrite the goal. -/
Runs all candidates in parallel, iterates through results in order.
Cancels remaining tasks and returns immediately if `stopAtRfl` is true and
an rfl-closeable result is found. Collects up to `max` unique results.
-/
def findRewrites (hyps : Array (Expr × Bool × Nat))
(moduleRef : LazyDiscrTree.ModuleDiscrTreeRef (Name × RwDirection))
(goal : MVarId) (target : Expr)
(forbidden : NameSet := ∅) (side : SideConditions := .solveByElim)
(stopAtRfl : Bool) (max : Nat := 20)
(leavePercentHeartbeats : Nat := 10) : MetaM (List RewriteResult) := do
(stopAtRfl : Bool) (max : Nat := 20) : MetaM (List RewriteResult) := do
let mctx ← getMCtx
let candidates ← rewriteCandidates hyps moduleRef target forbidden
let minHeartbeats : Nat ←
if (← getMaxHeartbeats) = 0 then
pure 0
else
pure <| leavePercentHeartbeats * (← getRemainingHeartbeats) / 100
let cfg : RewriteResultConfig :=
{ stopAtRfl, minHeartbeats, max, mctx, goal, target, side }
return (← takeListAux cfg {} (Array.mkEmpty max) candidates.toList).toList
-- Create parallel jobs for each candidate
let jobs := candidates.toList.map fun (lem, symm, weight) => do
withoutModifyingState <| withMCtx mctx do
let some r ← rwLemma mctx goal target side lem symm weight
| return none
let s ← withoutModifyingState <| withMCtx r.mctx r.ppResult
return some (r, s)
let (cancel, iter) ← MetaM.parIterWithCancelChunked jobs (maxTasks := 128)
let mut seen : Std.HashMap String Unit := {}
let mut acc : Array RewriteResult := Array.mkEmpty max
for result in iter.allowNontermination do
if acc.size ≥ max then
cancel
break
match result with
| .error _ => continue
| .ok none => continue
| .ok (some (r, s)) =>
if seen.contains s then continue
seen := seen.insert s ()
if stopAtRfl && r.rfl? then
cancel
return [r]
acc := acc.push r
return acc.toList
end Lean.Meta.Rewrites


@@ -0,0 +1,78 @@
import Lean.Elab.Parallel
/-!
# Tests for chunked parallel execution
Tests for the chunking support in `Lean.Elab.Parallel`.
-/
open Lean Core
/-! ## CoreM tests -/
/-- Test that par with maxTasks=0 (default) works like original -/
def testCoreMParDefault : CoreM Unit := do
let jobs := (List.range 10).map fun i => pure i
let results ← CoreM.par jobs
let values := results.filterMap fun r =>
match r with
| .ok (v, _) => some v
| .error _ => none
assert! values == List.range 10
/-- Test that par with chunking produces same results -/
def testCoreMParChunked : CoreM Unit := do
let jobs := (List.range 20).map fun i => pure i
let results ← CoreM.par jobs (maxTasks := 4) (minChunkSize := 2)
let values := results.filterMap fun r =>
match r with
| .ok (v, _) => some v
| .error _ => none
-- Results should be in original order
assert! values == List.range 20
/-- Test that par' with chunking works -/
def testCoreMPar'Chunked : CoreM Unit := do
let jobs := (List.range 15).map fun i => pure (i * 2)
let results ← CoreM.par' jobs (maxTasks := 3) (minChunkSize := 2)
let values := results.filterMap fun r =>
match r with
| .ok v => some v
| .error _ => none
assert! values == (List.range 15).map (· * 2)
/-- Test error handling in chunks -/
def testCoreMParChunkedErrors : CoreM Unit := do
let jobs := (List.range 10).map fun i =>
if i == 5 then throwError "error at 5"
else pure i
let results ← CoreM.par' jobs (maxTasks := 3) (minChunkSize := 2)
-- Should have 9 successes and 1 error
let successes := results.filter fun r => match r with | .ok _ => true | .error _ => false
let errors := results.filter fun r => match r with | .ok _ => false | .error _ => true
assert! successes.length == 9
assert! errors.length == 1
#eval do
let env ← importModules #[{ module := `Init }] {} 0
let (_, _) ← testCoreMParDefault.toIO { fileName := "", fileMap := default } { env }
IO.println "testCoreMParDefault passed"
#eval do
let env ← importModules #[{ module := `Init }] {} 0
let (_, _) ← testCoreMParChunked.toIO { fileName := "", fileMap := default } { env }
IO.println "testCoreMParChunked passed"
#eval do
let env ← importModules #[{ module := `Init }] {} 0
let (_, _) ← testCoreMPar'Chunked.toIO { fileName := "", fileMap := default } { env }
IO.println "testCoreMPar'Chunked passed"
#eval do
let env ← importModules #[{ module := `Init }] {} 0
let (_, _) ← testCoreMParChunkedErrors.toIO { fileName := "", fileMap := default } { env }
IO.println "testCoreMParChunkedErrors passed"
/-! ## All tests passed -/
#eval IO.println "All parallel_chunked tests passed!"
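The suite above exercises the CoreM variants; a further edge case worth covering, sketched here in the same pattern (this test is hypothetical, not part of the original suite), is the empty job list:

```lean
/-- Hypothetical extra test: chunking with an empty job list should
return an empty result list rather than erroring. -/
def testCoreMParChunkedEmpty : CoreM Unit := do
  let jobs : List (CoreM Nat) := []
  let results ← CoreM.par' jobs (maxTasks := 4) (minChunkSize := 2)
  assert! results.isEmpty
```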