feat: add withEarlyReturnNewDo variants for new do elaborator (#12881)

This PR adds `Invariant.withEarlyReturnNewDo`,
`StringInvariant.withEarlyReturnNewDo`, and
`StringSliceInvariant.withEarlyReturnNewDo` which use `Prod` instead of
`MProd` for the state tuple, matching the new do elaborator's output.
The existing `withEarlyReturn` definitions are reverted to `MProd` for
backwards compatibility with the legacy do elaborator. Tests and
invariant suggestions are updated to use the `NewDo` variants.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Sebastian Graf
2026-03-11 18:44:34 +08:00
committed by GitHub
parent 17807e1cbe
commit 734566088f
7 changed files with 67 additions and 29 deletions

View File

@@ -408,20 +408,20 @@ public def suggestInvariant (vcs : Array MVarId) (inv : MVarId) : TacticM Term :
--
-- Finally, build the syntax for the suggestion. It's a giant configuration space mess, because
-- 1. Generally want to suggest something using `⇓ ⟨xs, letMuts⟩ => ...`, i.e. `PostCond.noThrow`.
-- 2. However, on early return we want to suggest something using `Invariant.withEarlyReturn`.
-- 2. However, on early return we want to suggest something using `Invariant.withEarlyReturnNewDo`.
-- 3. When there are non-`False` failure conditions, we cannot suggest `⇓ ⟨xs, letMuts⟩ => ...`.
-- We might be able to suggest `⇓? ⟨xs, letMuts⟩ => ...` (`True` failure condition),
-- or `post⟨...⟩` (more than 0 failure handlers, but ending in `PUnit.unit`), and fall back to
-- `by exact ⟨...⟩` (not ending in `PUnit.unit`).
-- 4. Similarly for the `onExcept` argument of `Invariant.withEarlyReturn`.
-- 4. Similarly for the `onExcept` argument of `Invariant.withEarlyReturnNewDo`.
-- Hence the spaghetti code.
--
if let some (ρ, σ) hasEarlyReturn vcs inv letMutsTy then
-- logWarning m!"Found early return for {inv}!"
-- Suggest an invariant using `Invariant.withEarlyReturn`.
-- Suggest an invariant using `Invariant.withEarlyReturnNewDo`.
if let some (success, onReturn, failureConds) := suggestion? then
-- First construct `onContinue` and `onReturn` clause and simplify them according to the
-- definition of `Invariant.withEarlyReturn`.
-- definition of `Invariant.withEarlyReturnNewDo`.
let (onContinue, onReturn) withLocalDeclD `xs (mkApp2 (mkConst ``List.Cursor us) α l) fun xs =>
withLocalDeclD `r ρ fun r =>
withLocalDeclD `letMuts σ fun letMuts => do
@@ -439,21 +439,21 @@ public def suggestInvariant (vcs : Array MVarId) (inv : MVarId) : TacticM Term :
if failureConds.points.isEmpty then
match failureConds.default with
| .false | .punit =>
`(Invariant.withEarlyReturn (onReturn := fun r letMuts => $onReturn) (onContinue := fun xs letMuts => $onContinue))
`(Invariant.withEarlyReturnNewDo (onReturn := fun r letMuts => $onReturn) (onContinue := fun xs letMuts => $onContinue))
-- we handle the following two cases here rather than through
-- `postCondWithMultipleConditions` below because that would insert a superfluous `by exact _`.
| .true =>
`(Invariant.withEarlyReturn (onReturn := fun r letMuts => $onReturn) (onContinue := fun xs letMuts => $onContinue (onExcept := ExceptConds.true)))
`(Invariant.withEarlyReturnNewDo (onReturn := fun r letMuts => $onReturn) (onContinue := fun xs letMuts => $onContinue (onExcept := ExceptConds.true)))
| .other e =>
`(Invariant.withEarlyReturn (onReturn := fun r letMuts => $onReturn) (onContinue := fun xs letMuts => $onContinue (onExcept := $( Lean.PrettyPrinter.delab e))))
`(Invariant.withEarlyReturnNewDo (onReturn := fun r letMuts => $onReturn) (onContinue := fun xs letMuts => $onContinue (onExcept := $( Lean.PrettyPrinter.delab e))))
else -- need the postcondition long form as `onExcept` arg
let mut terms : Array Term := #[]
for point in failureConds.points do
terms := terms.push ( Lean.PrettyPrinter.delab point)
let onExcept postCondWithMultipleConditions terms failureConds.default
`(Invariant.withEarlyReturn (onReturn := fun r letMuts => $onReturn) (onContinue := fun xs letMuts => $onContinue) (onExcept := $onExcept))
`(Invariant.withEarlyReturnNewDo (onReturn := fun r letMuts => $onReturn) (onContinue := fun xs letMuts => $onContinue) (onExcept := $onExcept))
else -- No suggestion. Just fill in `_`.
`(Invariant.withEarlyReturn (onReturn := fun r letMuts => _) (onContinue := fun xs letMuts => _))
`(Invariant.withEarlyReturnNewDo (onReturn := fun r letMuts => _) (onContinue := fun xs letMuts => _))
else if let some (success, _, failureConds) := suggestion? then
-- No early return, but we do have a suggestion.
withLocalDeclD `xs (mkApp2 (mkConst ``List.Cursor us) α l) fun xs =>

View File

@@ -711,6 +711,18 @@ won't need to prove anything about the bogus case where the loop has returned ea
another iteration of the loop body.
-/
abbrev Invariant.withEarlyReturn {α} {xs : List α} {γ : Type (max u₁ u₂)}
(onContinue : List.Cursor xs β Assertion ps)
(onReturn : γ β Assertion ps)
(onExcept : ExceptConds ps := ExceptConds.false) :
Invariant xs (MProd (Option γ) β) ps :=
fun xs, x, b => spred(
(x = none onContinue xs b)
( r, x = some r xs.suffix = [] onReturn r b)),
onExcept
/-- Like `Invariant.withEarlyReturn`, but for the new `do` elaborator which uses `Prod`
instead of `MProd` for the state tuple. -/
abbrev Invariant.withEarlyReturnNewDo {α} {xs : List α} {γ : Type (max u₁ u₂)}
(onContinue : List.Cursor xs β Assertion ps)
(onReturn : γ β Assertion ps)
(onExcept : ExceptConds ps := ExceptConds.false) :
@@ -2039,6 +2051,19 @@ abbrev StringInvariant.withEarlyReturn {s : String}
( r, x = some r pos = s.endPos onReturn r b)),
onExcept
/-- Like `StringInvariant.withEarlyReturn`, but for the new `do` elaborator which uses `Prod`
instead of `MProd` for the state tuple. -/
abbrev StringInvariant.withEarlyReturnNewDo {s : String}
(onContinue : s.Pos β Assertion ps)
(onReturn : γ β Assertion ps)
(onExcept : ExceptConds ps := ExceptConds.false) :
StringInvariant s (Prod (Option γ) β) ps
:=
fun pos, x, b => spred(
(x = none onContinue pos b)
( r, x = some r pos = s.endPos onReturn r b)),
onExcept
@[spec]
theorem Spec.forIn_string
{s : String} {init : β} {f : Char β m (ForInStep β)}
@@ -2111,6 +2136,19 @@ abbrev StringSliceInvariant.withEarlyReturn {s : String.Slice}
( r, x = some r pos = s.endPos onReturn r b)),
onExcept
/-- Like `StringSliceInvariant.withEarlyReturn`, but for the new `do` elaborator which uses `Prod`
instead of `MProd` for the state tuple. -/
abbrev StringSliceInvariant.withEarlyReturnNewDo {s : String.Slice}
(onContinue : s.Pos β Assertion ps)
(onReturn : γ β Assertion ps)
(onExcept : ExceptConds ps := ExceptConds.false) :
StringSliceInvariant s (Prod (Option γ) β) ps
:=
fun pos, x, b => spred(
(x = none onContinue pos b)
( r, x = some r pos = s.endPos onReturn r b)),
onExcept
@[spec]
theorem Spec.forIn_stringSlice
{s : String.Slice} {init : β} {f : Char β m (ForInStep β)}

View File

@@ -522,7 +522,7 @@ example (p : Nat → Prop) [DecidablePred p] (n : Nat) :
apply Id.of_wp_run_eq h
mvcgen
case inv1 =>
exact Invariant.withEarlyReturn
exact Invariant.withEarlyReturnNewDo
(onReturn := fun ret _ => ret = false ¬ i < n, p i)
(onContinue := fun xs _ => i, i xs.prefix p i)
all_goals simp_all [-Classical.not_forall]; try grind

View File

@@ -105,7 +105,7 @@ def nodup (l : List Int) : Bool := Id.run do
info: Try this:
[apply] invariants
·
Invariant.withEarlyReturn (onReturn := fun r letMuts => ⌜(r = true ↔ l.Nodup) ∧ l.Nodup⌝) (onContinue :=
Invariant.withEarlyReturnNewDo (onReturn := fun r letMuts => ⌜(r = true ↔ l.Nodup) ∧ l.Nodup⌝) (onContinue :=
fun xs letMuts => ⌜xs.prefix = [] ∧ letMuts = ∅ xs.suffix = [] ∧ l.Nodup⌝)
-/
#guard_msgs (info) in
@@ -132,14 +132,14 @@ def nodup_twice (l : List Int) : Bool := Id.run do
info: Try this:
[apply] invariants
·
Invariant.withEarlyReturn (onReturn := fun r letMuts =>
Invariant.withEarlyReturnNewDo (onReturn := fun r letMuts =>
spred({ down := r = true ↔ l.Nodup } ∧ Prod.fst ?inv2 ({ «prefix» := [], suffix := l, property := ⋯ }, none, ∅)))
(onContinue := fun xs letMuts =>
spred({ down := xs.prefix = [] ∧ letMuts = ∅ }
⌜xs.suffix = []⌝ ∧
{ down := True } ∧ Prod.fst ?inv2 ({ «prefix» := [], suffix := l, property := ⋯ }, none, ∅)))
·
Invariant.withEarlyReturn (onReturn := fun r letMuts => ⌜(r = true ↔ l.Nodup) ∧ l.Nodup⌝) (onContinue :=
Invariant.withEarlyReturnNewDo (onReturn := fun r letMuts => ⌜(r = true ↔ l.Nodup) ∧ l.Nodup⌝) (onContinue :=
fun xs letMuts => ⌜xs.prefix = [] ∧ letMuts = ∅ xs.suffix = [] ∧ l.Nodup⌝)
-/
#guard_msgs (info) in
@@ -190,7 +190,7 @@ def mkFreshN_early_return (n : Nat) : AppM (List Nat) := do
info: Try this:
[apply] invariants
·
Invariant.withEarlyReturn (onReturn := fun r letMuts => ⌜r.Nodup ∧ letMuts.toList.Nodup⌝) (onContinue :=
Invariant.withEarlyReturnNewDo (onReturn := fun r letMuts => ⌜r.Nodup ∧ letMuts.toList.Nodup⌝) (onContinue :=
fun xs letMuts => ⌜xs.prefix = [] ∧ letMuts = acc✝ xs.suffix = [] ∧ letMuts.toList.Nodup⌝)
-/
#guard_msgs (info) in
@@ -207,7 +207,7 @@ def earlyReturnWithoutLetMut (l : List Int) : Bool := Id.run do
info: Try this:
[apply] invariants
·
Invariant.withEarlyReturn (onReturn := fun r letMuts => ⌜r = true⌝) (onContinue := fun xs letMuts =>
Invariant.withEarlyReturnNewDo (onReturn := fun r letMuts => ⌜r = true⌝) (onContinue := fun xs letMuts =>
⌜xs.prefix = [] xs.suffix = []⌝)
-/
#guard_msgs (info) in
@@ -273,7 +273,7 @@ def polyNodup [Monad m] (l : List Int) : m Bool := do
info: Try this:
[apply] invariants
·
Invariant.withEarlyReturn (onReturn := fun r letMuts => ⌜(r = true ↔ l.Nodup) ∧ l.Nodup⌝) (onContinue :=
Invariant.withEarlyReturnNewDo (onReturn := fun r letMuts => ⌜(r = true ↔ l.Nodup) ∧ l.Nodup⌝) (onContinue :=
fun xs letMuts => ⌜xs.prefix = [] ∧ letMuts = seen✝ xs.suffix = [] ∧ l.Nodup⌝)
-/
#guard_msgs (info) in

View File

@@ -20,7 +20,7 @@ theorem nodup_correct_vanilla (l : List Int) : nodup l ↔ l.Nodup := by
apply Id.of_wp_run_eq h
mvcgen
case inv1 =>
exact Invariant.withEarlyReturn
exact Invariant.withEarlyReturnNewDo
(onReturn := fun ret seen => ret = false ¬l.Nodup)
(onContinue := fun traversalState seen =>
( x, x seen x traversalState.prefix) traversalState.prefix.Nodup)
@@ -30,7 +30,7 @@ theorem nodup_correct_invariants (l : List Int) : nodup l ↔ l.Nodup := by
generalize h : nodup l = r
apply Id.of_wp_run_eq h
mvcgen invariants
· Invariant.withEarlyReturn
· Invariant.withEarlyReturnNewDo
(onReturn := fun ret seen => ret = false ¬l.Nodup) -- minimal indentation here is part of the test
(onContinue := fun traversalState seen =>
( x, x seen x traversalState.prefix) traversalState.prefix.Nodup)
@@ -40,7 +40,7 @@ theorem nodup_correct_invariants_with_pretac (l : List Int) : nodup l ↔ l.Nodu
generalize h : nodup l = r
apply Id.of_wp_run_eq h
mvcgen invariants
· Invariant.withEarlyReturn
· Invariant.withEarlyReturnNewDo
(onReturn := fun ret seen => ret = false ¬l.Nodup)
(onContinue := fun traversalState seen =>
( x, x seen x traversalState.prefix) traversalState.prefix.Nodup)
@@ -51,7 +51,7 @@ theorem nodup_correct_invariants_with_cases (l : List Int) : nodup l ↔ l.Nodup
apply Id.of_wp_run_eq h
mvcgen
invariants
· Invariant.withEarlyReturn
· Invariant.withEarlyReturnNewDo
(onReturn := fun ret seen => ret = false ¬l.Nodup)
(onContinue := fun traversalState seen =>
( x, x seen x traversalState.prefix) traversalState.prefix.Nodup)
@@ -67,7 +67,7 @@ theorem nodup_correct_invariants_with_pretac_cases (l : List Int) : nodup l ↔
apply Id.of_wp_run_eq h
mvcgen
invariants
· Invariant.withEarlyReturn
· Invariant.withEarlyReturnNewDo
(onReturn := fun ret seen => ret = false ¬l.Nodup)
(onContinue := fun traversalState seen =>
( x, x seen x traversalState.prefix) traversalState.prefix.Nodup)
@@ -81,7 +81,7 @@ theorem nodup_correct_invariants_with_cases_error (l : List Int) : nodup l ↔ l
apply Id.of_wp_run_eq h
mvcgen
invariants
· Invariant.withEarlyReturn
· Invariant.withEarlyReturnNewDo
(onReturn := fun ret seen => ret = false ¬l.Nodup)
(onContinue := fun traversalState seen =>
( x, x seen x traversalState.prefix) traversalState.prefix.Nodup)
@@ -138,11 +138,11 @@ theorem nodup_twice_correct_invariants_with (l : List Int) : nodup_twice l ↔ l
apply Id.of_wp_run_eq h
mvcgen
invariants
· Invariant.withEarlyReturn
· Invariant.withEarlyReturnNewDo
(onReturn := fun ret seen => ret = false ¬l.Nodup)
(onContinue := fun traversalState seen =>
( x, x seen x traversalState.prefix) traversalState.prefix.Nodup)
· Invariant.withEarlyReturn
· Invariant.withEarlyReturnNewDo
(onReturn := fun ret seen => ret = false ¬l.Nodup)
(onContinue := fun traversalState seen =>
( x, x seen x traversalState.prefix) traversalState.prefix.Nodup)
@@ -153,11 +153,11 @@ theorem nodup_twice_correct_invariants_multiple_with (l : List Int) : nodup_twic
apply Id.of_wp_run_eq h
mvcgen
invariants
· Invariant.withEarlyReturn
· Invariant.withEarlyReturnNewDo
(onReturn := fun ret seen => ret = false ¬l.Nodup)
(onContinue := fun traversalState seen =>
( x, x seen x traversalState.prefix) traversalState.prefix.Nodup)
· Invariant.withEarlyReturn
· Invariant.withEarlyReturnNewDo
(onReturn := fun ret seen => ret = false ¬l.Nodup)
(onContinue := fun traversalState seen =>
( x, x seen x traversalState.prefix) traversalState.prefix.Nodup)
@@ -174,7 +174,7 @@ theorem nodup_twice_missing_one_invariant (l : List Int) : nodup_twice l ↔ l.N
apply Id.of_wp_run_eq h
mvcgen
invariants
· Invariant.withEarlyReturn
· Invariant.withEarlyReturnNewDo
(onReturn := fun ret seen => ret = false ¬l.Nodup)
(onContinue := fun traversalState seen =>
( x, x seen x traversalState.prefix) traversalState.prefix.Nodup)

View File

@@ -25,7 +25,7 @@ theorem nodup_correct (l : List Int) : nodup l ↔ l.Nodup := by
apply Id.of_wp_run_eq h
mvcgen
case inv1 =>
exact Invariant.withEarlyReturn
exact Invariant.withEarlyReturnNewDo
(onReturn := fun ret seen => ret = false ¬l.Nodup)
(onContinue := fun traversalState seen =>
( x, x seen x traversalState.prefix) traversalState.prefix.Nodup)

View File

@@ -69,7 +69,7 @@ theorem pairsSumToZero_correct (l : List Int) : pairsSumToZero l ↔ l.ExistsPai
mvcgen
case inv1 =>
exact Invariant.withEarlyReturn
exact Invariant.withEarlyReturnNewDo
(onReturn := fun r b => r = true l.ExistsPair (fun a b => a + b = 0))
(onContinue := fun traversalState seen =>
( x, x seen x traversalState.prefix) ¬traversalState.prefix.ExistsPair (fun a b => a + b = 0))