Compare commits

...

1 Commits

Author SHA1 Message Date
Leonardo de Moura
ffd692e888 feat: preserve instantiation order at finish?
This PR ensures the generated `instantiate` tactic instantiates the
theorems using the same order used by `finish?`
2025-10-22 10:26:49 -07:00
2 changed files with 33 additions and 13 deletions

View File

@@ -12,14 +12,31 @@ import Lean.Meta.Tactic.Grind.EMatchTheoremParam
import Lean.Meta.Tactic.Grind.MarkNestedSubsingletons
namespace Lean.Meta.Grind.Action
/-
**Note**: The unique IDs created to instantiate theorems have the form `<prefix>.<num>`,
where `<num>` corresponds to the instantiation order within a particular proof branch.
Thus, by sorting the collected theorems using their corresponding unique IDs,
we can construct an `instantiate` tactic that performs the instantiations using
the original order.
**Note**: It is unclear at this point whether this is a good strategy or not.
The order in which things are asserted affects the proof found by `grind`.
Thus, preserving the original order should intuitively help ensure that the generated
tactic script for the continuation still closes the goal when combined with the
generated `instantiate` tactic. However, it does not guarantee that the
script can be successfully replayed, since we are filtering out instantiations that do
not appear in the final proof term. Recall that a theorem instance may
contribute to the proof search even if it does not appear in the final proof term.
-/
structure CollectState where
visited : Std.HashSet ExprPtr := {}
collectedThms : Std.HashSet (Origin × EMatchTheoremKind) := {}
thms : Array EMatchTheorem := #[]
idAndThms : Array (Name × EMatchTheorem) := #[]
def collect (e : Expr) (map : EMatch.InstanceMap) : Array EMatchTheorem :=
def collect (e : Expr) (map : EMatch.InstanceMap) : Array (Name × EMatchTheorem) :=
let (_, s) := go e |>.run {}
s.thms
s.idAndThms
where
go (e : Expr) : StateM CollectState Unit := do
if isMarkedSubsingletonApp e then
@@ -35,7 +52,7 @@ where
if let some thm := map[uniqueId]? then
let key := (thm.origin, thm.kind)
unless ( get).collectedThms.contains key do
modify fun s => { s with collectedThms := s.collectedThms.insert key, thms := s.thms.push thm }
modify fun s => { s with collectedThms := s.collectedThms.insert key, idAndThms := s.idAndThms.push (uniqueId, thm) }
match e with
| .lam _ d b _
| .forallE _ d b _ => go d; go b
@@ -93,7 +110,10 @@ public def instantiate' : Action := fun goal kna kp => do
| .closed seq =>
if ( getConfig).trace then
let proof instantiateMVars (mkMVar goal.mvarId)
let usedThms := collect proof map
let usedIdAndThms := collect proof map
-- **Note**: See note above. We want to sort here to reproduce the original instantiation order.
let usedIdAndThms := usedIdAndThms.qsort fun (id₁, _) (id₂, _) => id₁.lt id₂
let usedThms := usedIdAndThms.map (·.2)
let newSeq mkNewSeq goal usedThms seq (approx := false)
if ( checkSeqAt saved? goal newSeq) then
return .closed newSeq

View File

@@ -147,7 +147,7 @@ example (m : IndexMap α β) (a : α) (h : a ∈ m) :
info: Try this:
[apply] ⏎
instantiate only [= mem_indices_of_mem, insert]
instantiate only [= getElem?_neg, = getElem?_pos, =_ HashMap.contains_iff_mem]
instantiate only [=_ HashMap.contains_iff_mem, = getElem?_neg, = getElem?_pos]
cases #4ed2
next =>
cases #ffdf
@@ -179,7 +179,7 @@ example (m : IndexMap α β) (a a' : α) (b : β) :
info: Try this:
[apply] ⏎
instantiate only [= mem_indices_of_mem, insert]
instantiate only [= getElem?_neg, = getElem?_pos, =_ HashMap.contains_iff_mem]
instantiate only [=_ HashMap.contains_iff_mem, = getElem?_neg, = getElem?_pos]
cases #4ed2
next =>
cases #ffdf
@@ -247,19 +247,19 @@ info: Try this:
instantiate only [= Array.getElem_set]
next =>
instantiate only
instantiate only [= Array.getElem_push, size, = HashMap.getElem_insert, = HashMap.mem_insert]
instantiate only [size, = HashMap.mem_insert, = HashMap.getElem_insert, = Array.getElem_push]
next =>
instantiate only [= getElem_def, = mem_indices_of_mem]
instantiate only [= mem_indices_of_mem, = getElem_def]
instantiate only [usr getElem_indices_lt]
instantiate only [size]
cases #ffdf
next =>
instantiate only [=_ WF]
instantiate only [= Array.getElem_set, = getElem?_neg, = getElem?_pos]
instantiate only [= getElem?_neg, = getElem?_pos, = Array.getElem_set]
instantiate only [WF']
next =>
instantiate only
instantiate only [= Array.getElem_push, = HashMap.mem_insert, = HashMap.getElem_insert]
instantiate only [= HashMap.mem_insert, = HashMap.getElem_insert, = Array.getElem_push]
-/
#guard_msgs in
example (m : IndexMap α β) (a a' : α) (b : β) (h : a' m.insert a b) :
@@ -298,8 +298,8 @@ example (m : IndexMap α β) (a a' : α) (b : β) (h : a' ∈ m.insert a b) :
/--
info: Try this:
[apply] ⏎
instantiate only [insert, = mem_indices_of_mem, findIdx]
instantiate only [= getElem?_pos, = getElem?_neg]
instantiate only [findIdx, insert, = mem_indices_of_mem]
instantiate only [= getElem?_neg, = getElem?_pos]
cases #1bba
next => instantiate only [findIdx]
next =>