Compare commits

...

2 Commits

Author SHA1 Message Date
Kim Morrison
12ec9a4bb1 test 2025-11-25 23:05:30 +01:00
Kim Morrison
efcc679296 feat: try? uses parallelism 2025-11-25 05:02:48 +01:00
4 changed files with 122 additions and 25 deletions

View File

@@ -55,6 +55,9 @@ syntax (name := tryTrace) "try?" optConfig : tactic
/-- Helper internal tactic for implementing the tactic `try?`. -/
syntax (name := attemptAll) "attempt_all " withPosition((ppDedent(ppLine) colGe "| " tacticSeq)+) : tactic
/-- Helper internal tactic for implementing the tactic `try?` with parallel execution. -/
syntax (name := attemptAllPar) "attempt_all_par " withPosition((ppDedent(ppLine) colGe "| " tacticSeq)+) : tactic
/-- Helper internal tactic used to implement `evalSuggest` in `try?` -/
syntax (name := tryResult) "try_suggestions " tactic* : tactic

View File

@@ -194,24 +194,24 @@ def parIterGreedy {α : Type} (jobs : List (CoreM α)) :=
/--
Runs a list of CoreM computations in parallel and collects results in the original order,
including the state after each task completes.
including the saved state after each task completes.
Unlike `parIter`, this waits for all tasks to complete and returns results
in the same order as the input list, not in completion order.
Results are wrapped in `Except Exception (α × Core.State)` so that errors in individual
Results are wrapped in `Except Exception (α × Core.SavedState)` so that errors in individual
tasks don't stop the collection - you can observe all results including which tasks failed.
The final CoreM state is restored to the initial state (before tasks ran).
-/
def par {α : Type} (jobs : List (CoreM α)) : CoreM (List (Except Exception (α × Core.State))) := do
def par {α : Type} (jobs : List (CoreM α)) : CoreM (List (Except Exception (α × Core.SavedState))) := do
let initialState get
let tasks jobs.mapM asTask'
let mut results := []
for task in tasks do
let resultWithState observing do
let result task.get
pure (result, ( get))
pure (result, ( saveState))
results := resultWithState :: results
set initialState
return results.reverse
@@ -261,25 +261,24 @@ open Std.Iterators
/--
Runs a list of MetaM computations in parallel and collects results in the original order,
including the state after each task completes.
including the saved state after each task completes.
Unlike `parIter`, this waits for all tasks to complete and returns results
in the same order as the input list, not in completion order.
Results are wrapped in `Except Exception (α × Meta.State)` so that errors in individual
Results are wrapped in `Except Exception (α × Meta.SavedState)` so that errors in individual
tasks don't stop the collection - you can observe all results including which tasks failed.
The final MetaM state is restored to the initial state (before tasks ran).
Note: Only Meta.State is captured/reverted, not Core.State or IO effects.
-/
def par {α : Type} (jobs : List (MetaM α)) : MetaM (List (Except Exception (α × Meta.State))) := do
def par {α : Type} (jobs : List (MetaM α)) : MetaM (List (Except Exception (α × Meta.SavedState))) := do
let initialState get
let tasks jobs.mapM asTask'
let mut results := []
for task in tasks do
let resultWithState observing do
let result task.get
pure (result, ( get))
pure (result, ( saveState))
results := resultWithState :: results
set initialState
return results.reverse
@@ -465,27 +464,24 @@ def parIterGreedy {α : Type} (jobs : List (TermElabM α)) :=
/--
Runs a list of TermElabM computations in parallel and collects results in the original order,
including the state after each task completes.
including the saved state after each task completes.
Unlike `parIter`, this waits for all tasks to complete and returns results
in the same order as the input list, not in completion order.
Results are wrapped in `Except Exception (α × Term.State)` so that errors in individual
Results are wrapped in `Except Exception (α × Term.SavedState)` so that errors in individual
tasks don't stop the collection - you can observe all results including which tasks failed.
The final TermElabM state is restored to the initial state (before tasks ran).
Note: Only Term.State is captured/reverted, not Meta.State, Core.State or IO effects.
-/
def par {α : Type} (jobs : List (TermElabM α)) : TermElabM (List (Except Exception (α × Term.State))) := do
def par {α : Type} (jobs : List (TermElabM α)) : TermElabM (List (Except Exception (α × Term.SavedState))) := do
let initialState get
let tasks jobs.mapM asTask'
let mut results := []
for task in tasks do
-- Note: We use try/catch instead of `observing` here because TermElabM's `observing`
-- returns `TermElabResult` (not `Except`), which includes SavedState that we don't need.
try
let result task.get
let taskState get
let taskState saveState
results := .ok (result, taskState) :: results
catch e =>
results := .error e :: results
@@ -605,27 +601,24 @@ def parIterGreedy {α : Type} (jobs : List (TacticM α)) :=
/--
Runs a list of TacticM computations in parallel and collects results in the original order,
including the state after each task completes.
including the saved state after each task completes.
Unlike `parIter`, this waits for all tasks to complete and returns results
in the same order as the input list, not in completion order.
Results are wrapped in `Except Exception (α × Tactic.State)` so that errors in individual
Results are wrapped in `Except Exception (α × Tactic.SavedState)` so that errors in individual
tasks don't stop the collection - you can observe all results including which tasks failed.
The final TacticM state is restored to the initial state (before tasks ran).
Note: Only Tactic.State is captured/reverted, not Term.State, Meta.State, Core.State or IO effects.
-/
def par {α : Type} (jobs : List (TacticM α)) : TacticM (List (Except Exception (α × Tactic.State))) := do
def par {α : Type} (jobs : List (TacticM α)) : TacticM (List (Except Exception (α × Tactic.SavedState))) := do
let initialState get
let tasks jobs.mapM asTask'
let mut results := []
for task in tasks do
-- Note: We use try/catch instead of `observing` here because TacticM's `observing`
-- (inherited from TermElabM) returns `TermElabResult`, not `Except`.
try
let result task.get
let taskState get
let taskState Tactic.saveState
results := .ok (result, taskState) :: results
catch e =>
results := .error e :: results

View File

@@ -10,6 +10,7 @@ public import Lean.Meta.Tactic.Try
public import Lean.Elab.Tactic.SimpTrace
public import Lean.Elab.Tactic.LibrarySearch
public import Lean.Elab.Tactic.Grind.Main
public import Lean.Elab.Parallel
meta import Lean.Elab.Command
public section
namespace Lean.Elab.Tactic
@@ -697,6 +698,39 @@ where
else
throwError "`attempt_all` failed"
/-- `evalSuggest` for `attempt_all_par` tactic (parallel version). -/
private partial def evalSuggestAttemptAllPar (tacs : Array (TSyntax ``Parser.Tactic.tacticSeq)) : TryTacticM (TSyntax `tactic) := do
unless ( read).terminal do
throwError "invalid occurrence of `attempt_all_par` in non-terminal position for `try?` script{indentD (← read).root}"
let ctx read
-- Create jobs that each try one tactic and return the suggestion
let jobs : List (TacticM (TSyntax `tactic)) := tacs.toList.map fun tacSeq =>
withOriginalHeartbeats (evalSuggestTacticSeq tacSeq) ctx
-- Run all jobs in parallel - par returns (result, SavedState) for each
let results TacticM.par jobs
-- Collect successful results (maintaining order)
let mut acc : Array (TSyntax `tactic) := #[]
let mut firstSaved? : Option SavedState := none
for result in results do
match result with
| .ok (tac, s) =>
trace[try.debug] "`attempt_all_par` argument succeeded{indentD tac}"
acc := appendSuggestion acc tac
if firstSaved?.isNone then
firstSaved? := some s
| .error _ => pure ()
-- Restore first successful state and return suggestions
if let some saved := firstSaved? then
saved.restore
mkTrySuggestions acc
else
throwError "`attempt_all_par` failed"
private partial def evalSuggestDefault (tac : TSyntax `tactic) : TryTacticM (TSyntax `tactic) := do
let kind := tac.raw.getKind
match ( getEvalFns kind) with
@@ -743,6 +777,7 @@ private partial def evalSuggestImpl : TryTactic := fun tac => do
| `(tactic| ($tac:tacticSeq)) => evalSuggestTacticSeq tac
| `(tactic| try $tac:tacticSeq) => evalSuggestTry tac
| `(tactic| attempt_all $[| $tacs]*) => evalSuggestAttemptAll tacs
| `(tactic| attempt_all_par $[| $tacs]*) => evalSuggestAttemptAllPar tacs
| _ =>
let k := tac.raw.getKind
if k == ``Parser.Tactic.seq1 then
@@ -910,7 +945,7 @@ private unsafe def mkTryEvalSuggestStxUnsafe (goal : MVarId) (info : Try.Info) :
let simp mkSimpStx
let grind mkGrindStx info
let atomic `(tactic| attempt_all | $simple:tactic | $simp:tactic | $grind:tactic | simp_all)
let atomic `(tactic| attempt_all_par | $simple:tactic | $simp:tactic | $grind:tactic | simp_all)
let atomicSuggestions mkAtomicWithSuggestionsStx
let funInds mkAllFunIndStx info atomic
let inds mkAllIndStx info atomic
@@ -934,7 +969,7 @@ private unsafe def mkTryEvalSuggestStxUnsafe (goal : MVarId) (info : Try.Info) :
if userTactics.isEmpty then
`(tactic| first | $atomic:tactic | $atomicSuggestions:tactic | $funInds:tactic | $inds:tactic | $extra:tactic)
else
let userAttemptAll `(tactic| attempt_all $[| $userTactics:tactic]*)
let userAttemptAll `(tactic| attempt_all_par $[| $userTactics:tactic]*)
`(tactic| first | $atomic:tactic | $atomicSuggestions:tactic | $funInds:tactic | $inds:tactic | $extra:tactic | $userAttemptAll:tactic)
@[implemented_by mkTryEvalSuggestStxUnsafe]

View File

@@ -0,0 +1,66 @@
/-
Test that try? runs user suggestion tactics in parallel via attempt_all_par.
This test uses IO.stdGenRef (a builtin_initialize ref) to demonstrate parallelism:
- Tactic 1 (high prio): waits 1000ms, then checks if the random seed was changed
- Tactic 2 (low prio): immediately sets the seed to a magic value, then succeeds
If sequential: Tactic 1 executes first, waits, seed unchanged, fails.
Then tactic 2 executes, sets seed, succeeds. Only one suggestion.
If parallel: Both tactics start together. Tactic 2 sets seed immediately.
Tactic 1 waits 1000ms, sees changed seed, succeeds. Two suggestions.
-/
module
public import Lean
public meta import Lean.Elab.Tactic.Try
open Lean Meta Elab Tactic Try
-- A goal that built-in tactics won't solve
inductive ParallelTestGoal : Prop where
| mk : ParallelTestGoal
-- Magic seed value to signal parallelism
meta def magicSeed : Nat := 314159265
-- Tactic that waits, then checks if seed was changed
elab "wait_and_check_seed" : tactic => do
IO.sleep 1000
let gen IO.stdGenRef.get
let expected := mkStdGen magicSeed
if gen.s1 == expected.s1 && gen.s2 == expected.s2 then
evalTactic ( `(tactic| exact ParallelTestGoal.mk))
else
throwError "seed not changed (sequential execution detected)"
-- Tactic that immediately sets seed and succeeds
elab "set_seed_and_succeed" : tactic => do
IO.setRandSeed magicSeed
evalTactic ( `(tactic| exact ParallelTestGoal.mk))
-- Register both tactics as user suggestions
-- High priority tactic: reset seed first (to ensure clean state), then return waiting tactic
@[local try_suggestion 900]
meta def waitAndCheckSolver (_goal : MVarId) (_info : Try.Info) : MetaM (Array (TSyntax `tactic)) := do
-- Reset to a different seed to ensure we're testing actual communication
IO.setRandSeed 0
return #[ `(tactic| wait_and_check_seed)]
-- Low priority tactic returns the seed-setting tactic
@[local try_suggestion 800]
meta def setFlagSolver (_goal : MVarId) (_info : Try.Info) : MetaM (Array (TSyntax `tactic)) := do
return #[ `(tactic| set_seed_and_succeed)]
-- If parallel: both tactics succeed (tactic 1 sees seed change after waiting)
-- If sequential: only tactic 2 succeeds (tactic 1 sees unchanged seed)
--
-- EXPECTED ON MASTER (sequential): Only one suggestion
-- EXPECTED ON try_par (parallel): Two suggestions
/--
info: Try these:
[apply] wait_and_check_seed
[apply] set_seed_and_succeed
-/
#guard_msgs in
example : ParallelTestGoal := by
try?