Compare commits

...

3 Commits

Author SHA1 Message Date
Leonardo de Moura
48ab0eb8d3 chore: fix tests 2025-02-06 13:00:17 -08:00
Leonardo de Moura
7899420b47 fix: preserve insertion order in Try/Collect.lean 2025-02-06 12:59:32 -08:00
Leonardo de Moura
2053745893 feat: try? composite suggestions 2025-02-06 12:02:12 -08:00
6 changed files with 113 additions and 29 deletions

View File

@@ -27,7 +27,10 @@ namespace Lean.Parser.Tactic
syntax (name := tryTrace) "try?" optConfig : tactic
/-- Helper tactic for implementing the tactic `try?`. -/
/-- Helper internal tactic for implementing the tactic `try?`. -/
syntax (name := attemptAll) "attempt_all " withPosition((ppDedent(ppLine) colGe "| " tacticSeq)+) : tactic
/-- Helper internal tactic used to implement `evalSuggest` in `try?` -/
syntax (name := tryResult) "try_suggestions " tactic* : tactic
end Lean.Parser.Tactic

View File

@@ -13,11 +13,6 @@ import Lean.Elab.Tactic.Config
import Lean.Elab.Tactic.SimpTrace
import Lean.Elab.Tactic.Grind
namespace Lean.Parser.Tactic
/-- Internal tactic used to implement `evalSuggest` -/
syntax (name := tryResult) "try_suggestions " tactic* : tactic
end Lean.Parser.Tactic
namespace Lean.Elab.Tactic
open Meta
/-!
@@ -52,7 +47,7 @@ private def appendSeqResult (suggestionSeqs : Array (Array (TSyntax `tactic))) (
/-- Returns a tactic representing all given suggestions `tacs`. -/
private def mkTrySuggestions (tacs : Array (TSyntax `tactic)) : TacticM (TSyntax `tactic) := do
if tacs.isEmpty then
throwError "`mkSuggestions` failed"
throwError "`mkTrySuggestions` failed"
else if tacs.size == 1 then
return tacs[0]!
else
@@ -130,16 +125,38 @@ private def getKindsSolvedAll (tacss : Array (Array (TSyntax `tactic))) : Array
r := r.push k
return r
private def mkChainResultCore (tac1 : TSyntax `tactic) (tacs2 : Array (TSyntax `tactic)) : TacticM (Array (TSyntax `tactic)) := do
let tacs2 := tacs2.map getSuggestionsCore
private def peekOne (tac1 : TSyntax `tactic) (tacss2 : Array (Array (TSyntax `tactic))) : TacticM (TSyntax `tactic) := do
let mut tacs2 := #[]
for s in tacss2 do
if s.isEmpty then
tacs2 := tacs2.push ( `(tactic| · sorry))
else
tacs2 := tacs2.push ( `(tactic| · $(s[0]!):tactic))
`(tactic| · $tac1:tactic
$tacs2*)
private def mkChainResultCore (tac1 : TSyntax `tactic) (tacss2 : Array (TSyntax `tactic)) : TacticM (Array (TSyntax `tactic)) := do
let tacss2 := tacss2.map getSuggestionsCore
if ( isTracingEnabledFor `try.debug) then
trace[try.debug] "mkChainResultCore tac1{indentD tac1}"
let mut i : Nat := 0
for tacs2 in tacss2 do
i := i + 1
trace[try.debug] "goal #{i} tactics"
for tac2 in tacs2 do
trace[try.debug] " {tac2}"
trace[try.debug] "mkChainResult -----"
let mut acc := #[]
let solvedAll := getTacsSolvedAll tacs2
let solvedAll := getTacsSolvedAll tacss2
for tac2 in solvedAll do
acc := acc.push ( `(tactic| $tac1 <;> $tac2))
let tacs2 := eraseTacs tacs2 solvedAll
let tacss2 := eraseTacs tacss2 solvedAll
-- TODO: mixed cases
trace[Meta.debug] "CHAIN tacs2: {tacs2}"
trace[Meta.debug] "CHAIN kinds: {getKindsSolvedAll tacs2}"
trace[try.debug] "kinds: {getKindsSolvedAll tacss2}"
if (!acc.isEmpty && tacss2.all fun s => !s.isEmpty)
-- We only include partial solutions if there are no other solutions.
|| (acc.isEmpty && tacss2.any fun s => !s.isEmpty) then
acc := acc.push <| ( peekOne tac1 tacss2)
return acc
private def mkChainResult (tac1 : TSyntax `tactic) (tacs2 : Array (TSyntax `tactic)) : TacticM (TSyntax `tactic) := do
@@ -178,6 +195,7 @@ private def evalSuggestGrindTrace (tac : TSyntax `tactic) : TacticM (TSyntax `ta
let trace evalGrindCore tac config only params fallback?
let tac grindTraceToGrind tac
let tac' mkGrindOnly configStx fallback? trace
trace[try.debug] "`grind` succeeded"
mkTrySuggestions #[tac, tac']
| _ => throwUnsupportedSyntax
@@ -188,6 +206,7 @@ private def evalSuggestSimpTrace (tac : TSyntax `tactic) : TacticM (TSyntax `tac
let { ctx, simprocs, .. } mkSimpContext tac (eraseLocal := false)
let stats simpLocation ctx (simprocs := simprocs) none <| (loc.map expandLocation).getD (.targets #[] true)
let tac' mkSimpCallStx tac stats.usedTheorems
trace[try.debug] "`simp` succeeded"
mkTrySuggestions #[tac, tac']
| _ => throwUnsupportedSyntax
@@ -215,11 +234,14 @@ private def evalSuggestChain (tac1 tac2 : TSyntax `tactic) : TacticM (TSyntax `t
let goals getGoals
setGoals []
let mut tac2s := #[]
let mut i : Nat := 0
for goal in goals do
setGoals [goal]
let tac2' (evalSuggest tac2) <|> `(tactic| sorry)
let tac2' : TSyntax `tactic (evalSuggest tac2) <|> `(tactic| sorry)
i := i + 1
trace[try.debug] "`<;>` goal #{i}, tactic{indentD tac2'}"
unless ( getGoals).isEmpty do
throwError "unsolved goals, `<;>` in `try?` requires all goals to be solved"
throwError "unsolved goals, `<;>` in `try?` requires all goals to be solved{indentD tac2}\n{goalsToMessageData (← getGoals)}"
tac2s := tac2s.push tac2'
if tac2s.all isSorry then
throwError "`<;>` failed"
@@ -269,8 +291,11 @@ where
go (i : Nat) (saved? : Option SavedState) (acc : Array (TSyntax `tactic)) : TacticM (TSyntax `tactic) := do
if i < tacs.size then
match ( observing (evalSuggestTacticSeq tacs[i]!)) with
| .ok tac s => go (i+1) (saved? <|> some s) (appendSuggestion acc tac)
| _ => go (i+1) saved? acc
| .ok tac s =>
trace[try.debug] "`attempt_all` argument succeeded{indentD tac}"
go (i+1) (saved? <|> some s) (appendSuggestion acc tac)
| _ =>
go (i+1) saved? acc
else
if let some saved := saved? then
saved.restore
@@ -281,6 +306,7 @@ where
-- `evalSuggest` implementation
@[export lean_eval_suggest_tactic]
private partial def evalSuggestImpl (tac : TSyntax `tactic) : TacticM (TSyntax `tactic) := do
trace[try.debug] "{tac}"
match tac with
| `(tactic| $tac1 <;> $tac2) => evalSuggestChain tac1 tac2
| `(tactic| first $[| $tacs]*) => evalSuggestFirst tacs
@@ -343,17 +369,17 @@ private def setGrindParams (tac : TSyntax `tactic) (params : Array (TSyntax ``Pa
tac.raw.setArg 3 (mkNullNode paramsStx)
/-- Given a set of declaration names, returns `grind` parameters of the form `= <declName>` -/
private def mkGrindEqnParams (declNames : Std.HashSet Name) : MetaM (Array (TSyntax ``Parser.Tactic.grindParam)) := do
declNames.toArray.mapM fun declName => do
private def mkGrindEqnParams (declNames : Array Name) : MetaM (Array (TSyntax ``Parser.Tactic.grindParam)) := do
declNames.mapM fun declName => do
`(Parser.Tactic.grindParam| = $( toIdent declName))
private def mkGrindStx (info : Try.Info) : MetaM (TSyntax `tactic) := do
let grind `(tactic| grind?)
let mut tacs := #[grind]
unless info.eqnCandidates.isEmpty do
tacs := tacs.push (setGrindParams grind ( mkGrindEqnParams info.eqnCandidates))
tacs := tacs.push (setGrindParams grind ( mkGrindEqnParams info.eqnCandidates.elems))
unless info.unfoldCandidates.isEmpty do
tacs := tacs.push (setGrindParams grind ( mkGrindEqnParams info.unfoldCandidates))
tacs := tacs.push (setGrindParams grind ( mkGrindEqnParams info.unfoldCandidates.elems))
mkFirstStx tacs
/-! Other generators -/
@@ -400,7 +426,7 @@ where
`(tactic| induction $terms,* using $indFn <;> $cont)
private def mkAllFunIndStx (info : Try.Info) (cont : TSyntax `tactic) : MetaM (TSyntax `tactic) := do
let tacs info.funIndCandidates.toArray.mapM (mkFunIndStx · cont)
let tacs info.funIndCandidates.elems.mapM (mkFunIndStx · cont)
mkFirstStx tacs
/-! Main code -/

View File

@@ -12,6 +12,7 @@ builtin_initialize registerTraceClass `try
builtin_initialize registerTraceClass `try.collect
builtin_initialize registerTraceClass `try.collect.funInd
builtin_initialize registerTraceClass `try.debug
builtin_initialize registerTraceClass `try.debug.funInd
end Lean

View File

@@ -21,19 +21,35 @@ structure FunIndCandidate where
majors : Array FVarId
deriving Hashable, BEq
/-- `Set` with insertion order preserved. -/
structure OrdSet (α : Type) [Hashable α] [BEq α] where
elems : Array α := #[]
set : Std.HashSet α := {}
deriving Inhabited
def OrdSet.insert {_ : Hashable α} {_ : BEq α} (s : OrdSet α) (a : α) : OrdSet α :=
if s.set.contains a then
s
else
let { elems, set } := s
{ elems := elems.push a, set := set.insert a }
def OrdSet.isEmpty {_ : Hashable α} {_ : BEq α} (s : OrdSet α) : Bool :=
s.elems.isEmpty
structure Result where
/-- All constant symbols occurring in the gal. -/
allConsts : Std.HashSet Name := {}
allConsts : OrdSet Name := {}
/-- Unfolding candiates. -/
unfoldCandidates : Std.HashSet Name := {}
unfoldCandidates : OrdSet Name := {}
/-- Equation function candiates. -/
eqnCandidates : Std.HashSet Name := {}
eqnCandidates : OrdSet Name := {}
/-- Function induction candidates. -/
funIndCandidates : Std.HashSet FunIndCandidate := {}
funIndCandidates : OrdSet FunIndCandidate := {}
/-- Induction candidates. -/
indCandidates : Array InductionCandidate := #[]
/-- Relevant declarations by `libSearch` -/
libSearchResults : Std.HashSet (Name × Grind.EMatchTheoremKind) := {}
libSearchResults : OrdSet (Name × Grind.EMatchTheoremKind) := {}
structure Context where
config : Try.Config

View File

@@ -205,7 +205,15 @@ def evalExpr (e : Expr) : EvalM Val := do
@[grind] theorem UnaryOp.simplify_eval (op : UnaryOp) : (op.simplify a).eval σ = (Expr.una op a).eval σ := by
grind [UnaryOp.simplify.eq_def]
/-- info: Try this: (induction e using Expr.simplify.induct) <;> grind -/
/--
info: Try these:
• (induction e using Expr.simplify.induct) <;> grind
• ·
induction e using Expr.simplify.induct
· grind only [Expr.simplify, BinOp.simplify, Expr.eval, BinaryOp.simplify_eval]
· grind only [UnaryOp.simplify_eval, UnaryOp.simplify, Expr.simplify, Expr.eval]
· simp
-/
#guard_msgs (info) in
example (e : Expr) : e.simplify.eval σ = e.eval σ := by
try?
@@ -304,7 +312,20 @@ theorem State.cons_le_of_eq (h₁ : σ' ≼ σ) (h₂ : σ.find? x = some v) : (
@[grind] theorem State.join_le_left_of (h : σ₁ σ₂) (σ₃ : State) : σ₁.join σ₃ σ₂ := by
grind
/-- info: Try this: (induction σ₁, σ₂ using State.join.induct) <;> grind -/
/--
info: Try these:
• (induction σ₁, σ₂ using State.join.induct) <;> grind
• ·
induction σ₁, σ₂ using State.join.induct
·
grind only [State.join_le_left, State.find?, State.join, State.join_le_left_of, State.le, = State.find?_nil,
State.bot_le, State.le_refl]
·
grind only [State.join, State.join_le_left, State.length_erase_le, State.find?, State.join_le_left_of, State.le, =
State.find?_erase_eq, State.erase_le, State.le_refl, cases Or]
· grind only [State.join, State.join_le_left, State.length_erase_le, State.join_le_left_of, State.le, State.erase_le]
· grind only [State.join, State.join_le_left, State.length_erase_le, State.join_le_left_of, State.le, State.erase_le]
-/
#guard_msgs (info) in
example (σ₁ σ₂ : State) : σ₁.join σ₂ σ₂ := by
try?

View File

@@ -1,4 +1,5 @@
set_option grind.warning false
%reset_grind_attrs
/--
info: Try these:
@@ -97,3 +98,19 @@ example : app (app as bs) cs = app as (app bs cs) := by
intro _ _ _
-- `as`, `bs`, and `cs` now have inaccessible names.
try?
def concat : List α α List α
| .nil, b => .cons b .nil
| .cons a as, b => .cons a (concat as b)
attribute [simp] concat
/--
info: Try this: ·
induction as, a using concat.induct
· rfl
· simp_all
-/
#guard_msgs (info) in
example (as : List α) (a : α) : concat as a = as ++ [a] := by
try?