Compare commits

...

1 Commits

Author SHA1 Message Date
Leonardo de Moura
fd97addf43 feat: improve case-split heuristic in grind
This PR improves the case-split heuristics in `grind`.
In this PR, we do not increment the number of case splits in the first
case. The idea is to leverage non-chronological backtracking: if the first case
is solved using a proof that doesn't depend on the case hypothesis, we backtrack and close
the original goal directly. In this scenario, the case-split was "free", it didn't contribute
to the proof. By not counting it, we allow deeper exploration when case-splits turn out to be
irrelevant.
The new heuristic addresses the second example in #11545
2025-12-11 11:31:49 +01:00
3 changed files with 64 additions and 10 deletions

View File

@@ -185,13 +185,10 @@ where
match cs with
| [] =>
modify fun s => { s with split.candidates := cs'.reverse }
if let .some _ numCases isRec _ := c? then
let numSplits := ( get).split.num
-- We only increase the number of splits if there is more than one case or it is recursive.
let numSplits := if numCases > 1 || isRec then numSplits + 1 else numSplits
if let .some .. := c? then
-- Remark: we reset `numEmatch` after each case split.
-- We should consider other strategies in the future.
modify fun s => { s with split.num := numSplits, ematch.num := 0 }
modify fun s => { s with ematch.num := 0 }
return c?
| c::cs =>
if !( checkAnchorRefs c) then
@@ -422,10 +419,24 @@ def splitCore (c : SplitInfo) (numCases : Nat) (isRec : Bool)
pure 0
return (mvarIds, numDigits)
let numSubgoals := mvarIds.length
let subgoals := mvarIds.mapIdx fun i mvarId => { goal with
mvarId
split.trace := { expr := cExpr, i, num := numSubgoals, source := c.source } :: goal.split.trace
}
/-
**Split counter heuristic**: We do not increment `numSplits` for the first case (`i = 0`)
of a non-recursive split. This leverages non-chronological backtracking: if the first case
is solved using a proof that doesn't depend on the case hypothesis, we backtrack and close
the original goal directly. In this scenario, the case-split was "free", it didn't contribute
to the proof. By not counting it, we allow deeper exploration when case-splits turn out to be
irrelevant.
For recursive types or subsequent cases (`i > 0`), we always increment the counter since
these represent genuine branches in the proof search.
-/
let subgoals := mvarIds.mapIdx fun i mvarId =>
let numSplits := goal.split.num
let numSplits := if i > 0 || isRec then numSplits + 1 else numSplits
{ goal with
mvarId
split.num := numSplits
split.trace := { expr := cExpr, i, num := numSubgoals, source := c.source } :: goal.split.trace }
let mut seqNew : Array (List (TSyntax `grind)) := #[]
let mut stuckNew : Array Goal := #[]
for subgoal in subgoals do

View File

@@ -25,7 +25,7 @@ open List
/--
error: `grind` failed
case grind.1.1.1.1.1.1.1.1.1
case grind.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1
α : Type
inst : DecidableEq α
l₁ l₂ : List α
@@ -66,6 +66,44 @@ left_8 : l₁ ~ l₁.diff l₂
right_8 : ∀ (a : α), count a l₁ = count a (l₁.diff l₂)
left_9 : l₁ ~ l₂
right_9 : ∀ (a : α), count a l₁ = count a l₂
left_10 : filter p l₁ ~ filter p (l₁.diff l₂ ++ l₂)
right_10 : ∀ (a : α), count a (filter p l₁) = count a (filter p (l₁.diff l₂ ++ l₂))
left_11 : filter p (l₁.diff l₂ ++ l₂) ~ filter p l₁
right_11 : ∀ (a : α), count a (filter p (l₁.diff l₂ ++ l₂)) = count a (filter p l₁)
left_12 : l₁.diff l₂ ++ l₂ ~ l₂ ++ (l₁.diff l₂ ++ l₂)
right_12 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₂ ++ (l₁.diff l₂ ++ l₂))
left_13 : l₁.diff l₂ ++ l₂ ~ l₁.diff l₂ ++ l₂ ++ l₁
right_13 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₁.diff l₂ ++ l₂ ++ l₁)
left_14 : l₁.diff l₂ ++ l₂ ~ l₁.diff l₂ ++ l₂ ++ (l₁.diff l₂ ++ l₂)
right_14 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₁.diff l₂ ++ l₂ ++ (l₁.diff l₂ ++ l₂))
left_15 : l₁.diff l₂ ++ l₂ ~ l₁.diff l₂ ++ l₂ ++ l₂
right_15 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₁.diff l₂ ++ l₂ ++ l₂)
left_16 : l₁.diff l₂ ++ l₂ ~ l₁.diff l₂ ++ l₂ ++ l₁.diff l₂
right_16 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₁.diff l₂ ++ l₂ ++ l₁.diff l₂)
left_17 : l₁.diff l₂ ++ l₂ ~ l₁ ++ (l₁.diff l₂ ++ l₂)
right_17 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₁ ++ (l₁.diff l₂ ++ l₂))
left_18 : filter p (l₁.diff l₂ ++ l₂) ~ filter p (l₁.diff l₂)
right_18 : ∀ (a : α), count a (filter p (l₁.diff l₂ ++ l₂)) = count a (filter p (l₁.diff l₂))
left_19 : filter p (l₁.diff l₂) ~ filter p (l₁.diff l₂ ++ l₂)
right_19 : ∀ (a : α), count a (filter p (l₁.diff l₂)) = count a (filter p (l₁.diff l₂ ++ l₂))
left_20 : (filter p (l₁.diff l₂ ++ l₂)).Subperm (filter p l₁)
right_20 : (filter p (l₁.diff l₂ ++ l₂)).Subperm (filter p (l₁.diff l₂ ++ l₂))
left_21 : l₁.diff l₂ ++ l₂ ++ l₁.diff l₂ ~ l₁.diff l₂ ++ l₂
right_21 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂ ++ l₁.diff l₂) = count a (l₁.diff l₂ ++ l₂)
left_22 : l₁.diff l₂ ++ l₂ ++ l₁ ~ l₁.diff l₂ ++ l₂
right_22 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂ ++ l₁) = count a (l₁.diff l₂ ++ l₂)
left_23 : l₁.diff l₂ ++ l₂ ++ l₂ ~ l₁.diff l₂ ++ l₂
right_23 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂ ++ l₂) = count a (l₁.diff l₂ ++ l₂)
left_24 : l₁.diff l₂ ++ l₂ ++ (l₁.diff l₂ ++ l₂) ~ l₁.diff l₂ ++ l₂
right_24 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂ ++ (l₁.diff l₂ ++ l₂)) = count a (l₁.diff l₂ ++ l₂)
left_25 : l₁ ++ (l₁.diff l₂ ++ l₂) ~ l₁.diff l₂ ++ l₂
right_25 : ∀ (a : α), count a (l₁ ++ (l₁.diff l₂ ++ l₂)) = count a (l₁.diff l₂ ++ l₂)
left_26 : l₂ ++ (l₁.diff l₂ ++ l₂) ~ l₁.diff l₂ ++ l₂
right_26 : ∀ (a : α), count a (l₂ ++ (l₁.diff l₂ ++ l₂)) = count a (l₁.diff l₂ ++ l₂)
left_27 : l₁.diff l₂ ++ (l₁.diff l₂ ++ l₂) ~ l₁.diff l₂ ++ l₂
right_27 : ∀ (a : α), count a (l₁.diff l₂ ++ (l₁.diff l₂ ++ l₂)) = count a (l₁.diff l₂ ++ l₂)
left_28 : l₁.diff l₂ ++ l₂ ~ l₁.diff l₂ ++ (l₁.diff l₂ ++ l₂)
right_28 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₁.diff l₂ ++ (l₁.diff l₂ ++ l₂))
⊢ False
-/
#guard_msgs in

View File

@@ -0,0 +1,5 @@
example (a b : Nat) (f g : Nat Nat)
(hf : ( i a, f i f (i + 1)) f 0 = 0)
(hg : ( i b, g i g (i + 1)) g 0 = 0 g b = 0) :
g (a + b - a) = 0 := by
grind