Compare commits

...

5 Commits

Author SHA1 Message Date
Leonardo de Moura
545be17d0d test: evalAndSuggest 2025-04-05 17:37:12 -07:00
Leonardo de Moura
49dfe5776b feat: extensible evalAndSuggest 2025-04-05 17:36:56 -07:00
Leonardo de Moura
27e815616b feat: add [try_tactic] attribute 2025-04-05 15:32:12 -07:00
Leonardo de Moura
1ee5ddc55a feat: add TryTactic 2025-04-05 14:56:09 -07:00
Leonardo de Moura
291feaa7ae chore: M => TryTacticM 2025-04-05 14:51:24 -07:00
2 changed files with 96 additions and 23 deletions

View File

@@ -216,19 +216,26 @@ structure Ctx where
terminal : Bool
config : Try.Config
abbrev M := ReaderT Ctx TacticM
abbrev TryTacticM := ReaderT Ctx TacticM
abbrev TryTactic := TSyntax `tactic TryTacticM (TSyntax `tactic)
instance : MonadBacktrack SavedState M where
instance : MonadBacktrack SavedState TryTacticM where
saveState := fun _ => saveState
restoreState s := fun _ => restoreState s
abbrev withNonTerminal (x : M α) : M α :=
abbrev withNonTerminal (x : TryTacticM α) : TryTacticM α :=
withReader (fun c => { c with terminal := false}) x
-- TODO: polymorphic `Tactic.focus`
abbrev focus (x : M α) : M α := fun ctx => Tactic.focus (x ctx)
builtin_initialize tryTacticElabAttribute : KeyedDeclsAttribute TryTactic do
unsafe mkElabAttribute TryTactic `builtin_try_tactic `try_tactic `Lean.Parser.Tactic `Lean.Elab.Tactic.Try.TryTactic "try_tactic"
def observing (x : M α) : M (TacticResult α) := do
private def getEvalFns (kind : SyntaxNodeKind) : CoreM (List (KeyedDeclsAttribute.AttributeEntry TryTactic)) := do
return tryTacticElabAttribute.getEntries ( getEnv) kind
-- TODO: polymorphic `Tactic.focus`
abbrev focus (x : TryTacticM α) : TryTacticM α := fun ctx => Tactic.focus (x ctx)
def observing (x : TryTacticM α) : TryTacticM (TacticResult α) := do
let s saveState
try
let e x
@@ -271,7 +278,7 @@ private def merge? (tac1 tac2 : TSyntax `tactic) : Option (TSyntax `tactic) :=
else
none
private def mergeAll? (tacs : Array (TSyntax `tactic)) : M (Option (TSyntax `tactic)) := do
private def mergeAll? (tacs : Array (TSyntax `tactic)) : TryTacticM (Option (TSyntax `tactic)) := do
if !( read).config.merge || tacs.isEmpty then
return none
let tac0 := tacs[0]!
@@ -304,7 +311,7 @@ private def isOnlyAndNonOnly (tacs2 : Array (TSyntax `tactic)) : Bool := Id.run
else
return false
private def mkChainResult (tac1 : TSyntax `tactic) (tacss2 : Array (TSyntax `tactic)) : M (TSyntax `tactic) := do
private def mkChainResult (tac1 : TSyntax `tactic) (tacss2 : Array (TSyntax `tactic)) : TryTacticM (TSyntax `tactic) := do
let tacss2 := tacss2.map getSuggestionsCore
if ( isTracingEnabledFor `try.debug) then
trace[try.debug] "mkChainResultCore tac1{indentD tac1}"
@@ -343,7 +350,7 @@ private def mkChainResult (tac1 : TSyntax `tactic) (tacss2 : Array (TSyntax `tac
(_, acc) go tacss2 0 [] none |>.run acc
mkTrySuggestions acc
where
go (tacss2 : Array (Array (TSyntax `tactic))) (i : Nat) (acc : List (TSyntax `tactic)) (kind? : Option SyntaxNodeKind) : StateT (Array (TSyntax `tactic)) M Unit := do
go (tacss2 : Array (Array (TSyntax `tactic))) (i : Nat) (acc : List (TSyntax `tactic)) (kind? : Option SyntaxNodeKind) : StateT (Array (TSyntax `tactic)) TryTacticM Unit := do
if ( get).size > ( read).config.max then
return ()
else if h : i < tacss2.size then
@@ -371,7 +378,7 @@ where
$tacs2*)
modify (·.push tac)
private def evalSuggestGrindTrace (tac : TSyntax `tactic) : M (TSyntax `tactic) := do
private def evalSuggestGrindTrace : TryTactic := fun tac => do
match tac with
| `(tactic| grind? $configStx:optConfig $[only%$only]? $[ [$params:grindParam,*] ]? $[on_failure $fallback?]?) =>
let config elabGrindConfig configStx
@@ -386,7 +393,7 @@ private def evalSuggestGrindTrace (tac : TSyntax `tactic) : M (TSyntax `tactic)
return tac
| _ => throwUnsupportedSyntax
private def evalSuggestSimpTrace (tac : TSyntax `tactic) : M (TSyntax `tactic) := do ( getMainGoal).withContext do
private def evalSuggestSimpTrace : TryTactic := fun tac => do ( getMainGoal).withContext do
match tac with
| `(tactic| simp? $_:optConfig $[only%$only]? $[[$args,*]]? $(loc)?) =>
let tac simpTraceToSimp tac
@@ -401,10 +408,10 @@ private def evalSuggestSimpTrace (tac : TSyntax `tactic) : M (TSyntax `tactic) :
| _ => throwUnsupportedSyntax
@[extern "lean_eval_suggest_tactic"] -- forward definition to avoid mutual block
opaque evalSuggest (tac : TSyntax `tactic) : M (TSyntax `tactic)
opaque evalSuggest : TryTactic
/-- `evalSuggest` for `tac1 <;> tac2` -/
private def evalSuggestChain (tac1 tac2 : TSyntax `tactic) : M (TSyntax `tactic) := focus do
private def evalSuggestChain (tac1 tac2 : TSyntax `tactic) : TryTacticM (TSyntax `tactic) := focus do
unless ( read).terminal do
throwError "invalid `<;>` occurrence in non-terminal position for `try?` script{indentD (← read).root}"
let tac1 withNonTerminal do evalSuggest tac1
@@ -422,7 +429,7 @@ private def evalSuggestChain (tac1 tac2 : TSyntax `tactic) : M (TSyntax `tactic)
mkChainResult tac1 tac2s
/-- `evalSuggest` for a sequence of tactics. -/
private def evalSuggestSeq (tacs : Array (TSyntax `tactic)) : M (TSyntax `tactic) := do
private def evalSuggestSeq (tacs : Array (TSyntax `tactic)) : TryTacticM (TSyntax `tactic) := do
if ( read).terminal then
let mut result := #[]
for i in [:tacs.size - 1] do
@@ -433,10 +440,10 @@ private def evalSuggestSeq (tacs : Array (TSyntax `tactic)) : M (TSyntax `tactic
else
mkSeq ( tacs.mapM evalSuggest) (terminal := false)
private def evalSuggestSeqCore (tacs : Array Syntax) : M (TSyntax `tactic) := do
private def evalSuggestSeqCore (tacs : Array Syntax) : TryTacticM (TSyntax `tactic) := do
evalSuggestSeq (tacs.map fun tac => tac)
private def evalSuggestTacticSeq (s : TSyntax ``Parser.Tactic.tacticSeq) : M (TSyntax `tactic) := do
private def evalSuggestTacticSeq (s : TSyntax ``Parser.Tactic.tacticSeq) : TryTacticM (TSyntax `tactic) := do
let tacs match s with
| `(tacticSeq| { $t;* }) => pure t.getElems
| `(tacticSeq| $t;*) => pure t.getElems
@@ -444,30 +451,30 @@ private def evalSuggestTacticSeq (s : TSyntax ``Parser.Tactic.tacticSeq) : M (TS
evalSuggestSeq tacs
/-- `evalSuggest` for `first` tactic. -/
private partial def evalSuggestFirst (tacs : Array (TSyntax ``Parser.Tactic.tacticSeq)) : M (TSyntax `tactic) := do
private partial def evalSuggestFirst (tacs : Array (TSyntax ``Parser.Tactic.tacticSeq)) : TryTacticM (TSyntax `tactic) := do
if tacs.size == 0 then
throwError "`first` expects at least one argument"
go 0
where
go (i : Nat) : M (TSyntax `tactic) := do
go (i : Nat) : TryTacticM (TSyntax `tactic) := do
if i = tacs.size - 1 then
evalSuggestTacticSeq tacs[i]!
else
evalSuggestTacticSeq tacs[i]! <|> go (i+1)
/-- `evalSuggest` for `try` tactic. -/
private partial def evalSuggestTry (tac : TSyntax ``Parser.Tactic.tacticSeq) : M (TSyntax `tactic) := do
private partial def evalSuggestTry (tac : TSyntax ``Parser.Tactic.tacticSeq) : TryTacticM (TSyntax `tactic) := do
(do evalSuggestTacticSeq tac)
<|>
`(tactic| skip)
/-- `evalSuggest` for `attempt_all` tactic. -/
private partial def evalSuggestAttemptAll (tacs : Array (TSyntax ``Parser.Tactic.tacticSeq)) : M (TSyntax `tactic) := do
private partial def evalSuggestAttemptAll (tacs : Array (TSyntax ``Parser.Tactic.tacticSeq)) : TryTacticM (TSyntax `tactic) := do
unless ( read).terminal do
throwError "invalid occurrence of `attempt_all` in non-terminal position for `try?` script{indentD (← read).root}"
go 0 none #[]
where
go (i : Nat) (saved? : Option SavedState) (acc : Array (TSyntax `tactic)) : M (TSyntax `tactic) := do
go (i : Nat) (saved? : Option SavedState) (acc : Array (TSyntax `tactic)) : TryTacticM (TSyntax `tactic) := do
-- Remark: we considered using `acc.size < (← read).config.max` here to truncate the search,
-- but it had a negative effect when using `<;>`. We could miss a preferred solution `induction e <;> grind`
-- because only a subset of the goals were solved by simpler tactics such as `rfl` and `simp`.
@@ -485,10 +492,45 @@ where
else
throwError "`attempt_all` failed"
private partial def evalSuggestDefault (tac : TSyntax `tactic) : TryTacticM (TSyntax `tactic) := do
let kind := tac.raw.getKind
match ( getEvalFns kind) with
| [] => evalSuggestAtomic tac -- lift regular tactic
| evalFns => eval ( Tactic.saveState) evalFns #[]
where
throwExs (failures : Array EvalTacticFailure) : TryTacticM (TSyntax `tactic) := do
if h : 0 < failures.size then
let fail := failures[failures.size - 1]
fail.state.restore (restoreInfo := true)
throw fail.exception
else
throwErrorAt tac "unexpected syntax {indentD tac}"
eval (s : SavedState) (evalFns : List _) (failures : Array EvalTacticFailure) : TryTacticM (TSyntax `tactic) := do
match evalFns with
| [] => throwExs failures
| evalFn::evalFns =>
try
withTheReader Tactic.Context ({ · with elaborator := evalFn.declName }) do
evalFn.value tac
catch ex => match ex with
| .error .. =>
let failures := failures.push ex, Tactic.saveState
s.restore (restoreInfo := true); eval s evalFns failures
| .internal id _ =>
if id == unsupportedSyntaxExceptionId then
s.restore (restoreInfo := true); eval s evalFns failures
else if id == abortTacticExceptionId then
let failures := failures.push ex, Tactic.saveState
s.restore (restoreInfo := true); eval s evalFns failures
else
throw ex
-- `evalSuggest` implementation
@[export lean_eval_suggest_tactic]
private partial def evalSuggestImpl (tac : TSyntax `tactic) : M (TSyntax `tactic) := do
private partial def evalSuggestImpl : TryTactic := fun tac => do
trace[try.debug] "{tac}"
-- TODO: Implement builtin cases using `[builtin_try_tactic]` after update-stage0
match tac with
| `(tactic| $tac1 <;> $tac2) => evalSuggestChain tac1 tac2
| `(tactic| first $[| $tacs]*) => evalSuggestFirst tacs
@@ -507,7 +549,7 @@ private partial def evalSuggestImpl (tac : TSyntax `tactic) : M (TSyntax `tactic
else if k == ``Parser.Tactic.exact? then
evalSuggestExact
else
evalSuggestAtomic tac
evalSuggestDefault tac
if ( read).terminal then
unless ( getGoals).isEmpty do
throwError "unsolved goals"

View File

@@ -0,0 +1,31 @@
import Lean
open Lean Meta Elab Tactic Try
-- Install a `TryTactic` handler for `assumption`
@[try_tactic assumption]
def evalTryApply : TryTactic := fun tac => do
-- We just use the default implementation, but return a different tactic.
evalAssumption tac
`(tactic| (trace "worked"; assumption))
/-- info: Try this: · trace "worked"; assumption -/
#guard_msgs (info) in
example (h : False) : False := by
try? (max := 1) -- at most one solution
-- `try?` uses `evalAndSuggest` the attribute `[try_tactic]` is used to extend `evalAndSuggest`.
-- Let's define our own `try?` that uses `evalAndSuggest`
elab stx:"my_try?" : tactic => do
-- Things to try
let toTry `(tactic| attempt_all | assumption | apply True | rfl)
evalAndSuggest stx toTry
/--
info: Try these:
• · trace "worked"; assumption
• rfl
-/
#guard_msgs (info) in
example (a : Nat) (h : a = a) : a = a := by
my_try?