feat: add dite and matcher splitting to sym-based MVCGen

This PR generalizes the VCGen split handling from ite-only to ite, dite,
and arbitrary matchers. `withAbstractSplit` provides a uniform CPS
interface: it introduces abstract fvars for the split components and a
splitting function that the two-pass backward rule construction in
`mkBackwardRuleForSplit` uses — pass 1 discovers hypothesis types via
`rwIfOrMatcher`, pass 2 builds the proof.

For `ite`, branches have type `mα`; for `dite`, branches have dependent
types `(h : c) → mα` / `(h : ¬c) → mα`; for matchers, the splitter
with equality hypotheses and `Eq.refl` witnesses is used.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Sebastian Graf
2026-03-11 13:08:19 +00:00
parent aae827cb4c
commit f70abe08a5
8 changed files with 358 additions and 36 deletions

View File

@@ -1,6 +1,9 @@
import Cases.AddSubCancel
import Cases.AddSubCancelDeep
import Cases.AddSubCancelSimp
import Cases.DiteSplit
import Cases.GetThrowSet
import Cases.MatchSplit
import Cases.MatchSplitState
import Cases.PurePrecond
import Cases.ReaderState

View File

@@ -0,0 +1,42 @@
import Lean
import VCGen
open Lean Meta Elab Tactic Sym Std Do SpecAttr
namespace DiteSplit
set_option mvcgen.warning false
abbrev M := ExceptT String <| StateM Nat
@[spec high]
theorem Spec.throw_M {e : String} :
Q.2.1 e throw (m := M) e Q := by
mvcgen
@[spec high]
theorem Spec.set_M {s : Nat} :
fun _ => Q.1 s set (m := M) s Q := by
mvcgen
@[spec high]
theorem Spec.get_M :
fun s => Q.1 s s get (m := M) Q := by
mvcgen
def step (v : Nat) : M Unit := do
let s get
if _ : s > v then
throw "s is too large"
set (s + v)
let s get
set (s - v)
def loop (n : Nat) : M Unit := do
match n with
| 0 => pure ()
| n+1 => step n; loop n
def Goal (n : Nat) : Prop := fun s => s = 0 loop n _ s => s = 0
end DiteSplit

View File

@@ -0,0 +1,40 @@
import Lean
import VCGen
open Lean Meta Elab Tactic Sym Std Do SpecAttr
namespace MatchSplit
set_option mvcgen.warning false
abbrev M := ExceptT String <| StateM Nat
@[spec high]
theorem Spec.throw_M {e : String} :
Q.2.1 e throw (m := M) e Q := by
mvcgen
@[spec high]
theorem Spec.set_M {s : Nat} :
fun _ => Q.1 s set (m := M) s Q := by
mvcgen
@[spec high]
theorem Spec.get_M :
fun s => Q.1 s s get (m := M) Q := by
mvcgen
def step (v : Nat) : M Unit := do
let s get
match v with
| 0 => throw "v is zero"
| n+1 => set (s + n + 1); let s get; set (s - n)
def loop (n : Nat) : M Unit := do
match n with
| 0 => pure ()
| n+1 => step n; loop n
def Goal (n : Nat) : Prop := fun s => s = 0 loop n _ s => s = n
end MatchSplit

View File

@@ -0,0 +1,41 @@
import Lean
import VCGen
open Lean Meta Elab Tactic Sym Std Do SpecAttr
namespace MatchSplitState
set_option mvcgen.warning false
abbrev M := ExceptT String <| StateM Nat
@[spec high]
theorem Spec.throw_M {e : String} :
Q.2.1 e throw (m := M) e Q := by
mvcgen
@[spec high]
theorem Spec.set_M {s : Nat} :
fun _ => Q.1 s set (m := M) s Q := by
mvcgen
@[spec high]
theorem Spec.get_M :
fun s => Q.1 s s get (m := M) Q := by
mvcgen
/-- Matches on state `s` — the discriminant IS the excess state arg. -/
def step : M Unit := do
let s get
match s with
| 0 => throw "s is zero"
| n+1 => set n
def loop (n : Nat) : M Unit := do
match n with
| 0 => pure ()
| n+1 => step; loop n
def Goal (n : Nat) : Prop := fun s => s = n loop n _ s => s = 0
end MatchSplitState

View File

@@ -19,7 +19,7 @@ lean_lib Cases where
@[default_target]
lean_lib VCGenBench where
roots := #[`vcgen_add_sub_cancel, `vcgen_add_sub_cancel_deep, `vcgen_add_sub_cancel_simp,
`vcgen_get_throw_set, `vcgen_pure_precond, `vcgen_reader_state]
`vcgen_get_throw_set, `vcgen_pure_precond, `vcgen_reader_state, `vcgen_match_split]
moreLeanArgs := #["--tstack=100000000"]
@[default_target]

View File

@@ -398,54 +398,220 @@ meta def mkBackwardRuleFromSimpSpec (specThm : SpecTheoremNew) (m σs ps instWP
open Lean.Elab.Tactic.Do in
/--
Creates a reusable backward rule for `ite`. It proves a theorem of the following form:
CPS helper that introduces local fvars for a split and provides a uniform splitter.
The continuation `k` receives:
- `rawProg`: the abstract program expression (e.g. `ite mα c dec t e` or `matcher discrs alts`)
- `splitFVars`: the fvars introduced for the split (to be abstracted over in the final lemma)
- `splitFn`: a function that, given a goal expression and an `onAlt` callback, constructs the
splitting proof. The `onAlt` callback receives the alt index, an array of params (condition
proofs for ite, pattern vars + equality proofs for matchers) that `rwIfOrMatcher` can find via
`findLocalDeclWithType?` / `assumption`, and the per-alt body type from the splitter telescope.
For ite/dite, the body type equals the original goal.
For matchers, the body type has discriminant fvars substituted by patterns (cf. 99c83b9c),
so that excess state args coinciding with discriminants are properly replaced.
-/
private meta def withSplitterAndLocals {α} [Inhabited α] (splitInfo : SplitInfo) (mα : Expr) (v : Level)
(k : (rawProg : Expr) (splitFVars : Array Expr)
(splitFn : (goal : Expr) (onAlt : Nat Array Expr Expr MetaM Expr) MetaM Expr)
SymM α) : SymM α := do
match splitInfo with
| .ite _ =>
withLocalDeclD `c (mkSort 0) fun c => do
withLocalDeclD `dec (mkApp (mkConst ``Decidable) c) fun dec => do
withLocalDeclD `t mα fun t => do
withLocalDeclD `e mα fun e => do
let rawProg := mkApp5 (mkConst ``ite [v.succ]) mα c dec t e
let splitFVars := #[c, dec, t, e]
let splitFn (goal : Expr) (onAlt : Nat Array Expr Expr MetaM Expr) : MetaM Expr := do
let ht withLocalDecl `h .default c fun h => do
mkLambdaFVars #[h] ( onAlt 0 #[h] goal)
let he withLocalDecl `h .default (mkNot c) fun h => do
mkLambdaFVars #[h] ( onAlt 1 #[h] goal)
return mkApp5 (mkConst ``dite [0]) goal c dec ht he
k rawProg splitFVars splitFn
| .dite _ =>
withLocalDeclD `c (mkSort 0) fun c => do
withLocalDeclD `dec (mkApp (mkConst ``Decidable) c) fun dec => do
let tTy liftMetaM <| mkArrow c mα
let eTy liftMetaM <| mkArrow (mkNot c) mα
withLocalDeclD `t tTy fun t => do
withLocalDeclD `e eTy fun e => do
let rawProg := mkApp5 (mkConst ``dite [v.succ]) mα c dec t e
let splitFVars := #[c, dec, t, e]
let splitFn (goal : Expr) (onAlt : Nat Array Expr Expr MetaM Expr) : MetaM Expr := do
let ht withLocalDecl `h .default c fun h => do
mkLambdaFVars #[h] ( onAlt 0 #[h] goal)
let he withLocalDecl `h .default (mkNot c) fun h => do
mkLambdaFVars #[h] ( onAlt 1 #[h] goal)
return mkApp5 (mkConst ``dite [0]) goal c dec ht he
k rawProg splitFVars splitFn
| .matcher matcherApp => do
-- Introduce fvar discriminants
let discrNamesTypes matcherApp.discrs.mapIdxM fun i discr => do
return ((`discr).appendIndexAfter (i+1), Sym.inferType discr)
withLocalDeclsDND discrNamesTypes fun discrs => do
-- Non-dependent motive: fun _ ... _ => mα
let motive liftMetaM <| lambdaTelescope matcherApp.motive fun motiveArgs _ =>
mkLambdaFVars motiveArgs mα
-- Adjust eliminator level for the abstract motive
let matcherLevels match matcherApp.uElimPos? with
| none => pure matcherApp.matcherLevels
| some pos => do
let uElim liftMetaM <| getLevel mα
pure <| matcherApp.matcherLevels.set! pos uElim
-- Build matcher partial application (params + motive + discrs, without alts)
let matcherPartial := mkAppN (mkConst matcherApp.matcherName matcherLevels.toList) matcherApp.params
let matcherPartial := mkApp matcherPartial motive
let matcherPartial := mkAppN matcherPartial discrs
-- Infer alt types and introduce fvar alts
let origAltTypes liftMetaM <| inferArgumentTypesN matcherApp.alts.size matcherPartial
let altNamesTypes := origAltTypes.mapIdx fun i ty => ((`alt).appendIndexAfter (i+1), ty)
withLocalDeclsDND altNamesTypes fun alts => do
let rawProg := mkAppN matcherPartial alts
let splitFVars := discrs ++ alts
-- Build the splitting function using the splitter (bypasses MatcherApp.transform)
let splitFn (goal : Expr) (onAlt : Nat Array Expr Expr MetaM Expr) : MetaM Expr := do
-- Build splitter motive with equality hypotheses.
-- The motive substitutes discriminant fvars in the goal body (cf. 99c83b9c), so that
-- excess state args that coincide with discriminants are properly replaced by patterns:
-- fun d₁...dₙ => (discr₁ = d₁) → ... → (discrₙ = dₙ) → goal[discr₁:=d₁, ..., discrₙ:=dₙ]
let splitterMotive do
let discrNamesTypes discrs.mapM fun d => return (`d, inferType d)
withLocalDeclsDND discrNamesTypes fun ds => do
-- Substitute discriminant fvars with motive bound vars
let mut body := (goal.abstract discrs).instantiateRev ds.reverse
for i in (List.range discrs.size).reverse do
let discrTy inferType ds[i]!
let lvl getLevel discrTy
let eqType := mkApp3 (mkConst ``Eq [lvl]) discrTy discrs[i]! ds[i]!
body mkArrow eqType body
mkLambdaFVars ds body
-- Set eliminator level to 0 (Prop)
let splitterLevels match matcherApp.uElimPos? with
| none => pure matcherApp.matcherLevels
| some pos => pure <| matcherApp.matcherLevels.set! pos .zero
-- Look up splitter
let matchEqns Match.getEquationsFor matcherApp.matcherName
let splitter := matchEqns.splitterName
-- Build splitter partial application
let splitterPartial := mkAppN (mkConst splitter splitterLevels.toList) matcherApp.params
let splitterPartial := mkApp splitterPartial splitterMotive
let splitterPartial := mkAppN splitterPartial discrs
-- Infer splitter alt types
let splitterAltTypes inferArgumentTypesN matcherApp.alts.size splitterPartial
-- Number of forall params per alt: structural params + equality params
let numTeleParams := matcherApp.altNumParams.map (· + discrs.size)
-- Build each splitter alt
let mut splitterAlts := #[]
for idx in [:matcherApp.alts.size] do
let altType := splitterAltTypes[idx]!
let numParams := numTeleParams[idx]!
let alt forallBoundedTelescope altType (some numParams) fun xs bodyType => do
let prf onAlt idx xs bodyType
mkLambdaFVars xs prf
splitterAlts := splitterAlts.push alt
-- Combine: splitter params motive discrs alts... Eq.refls...
let mut prf := mkAppN splitterPartial splitterAlts
for discr in discrs do
let discrTy inferType discr
let lvl getLevel discrTy
prf := mkApp prf (mkApp2 (mkConst ``Eq.refl [lvl]) discrTy discr)
return prf
k rawProg splitFVars splitFn
open Lean.Elab.Tactic.Do in
/--
Creates a reusable backward rule for splitting `ite`, `dite`, or matchers.
For `ite`/`dite`, proves a theorem of the following form:
```
example {m} {σ} {ps} [WP m (.arg σ ps)] -- These are fixed. The other arguments are parameters of the rule:
example {m} {σ} {ps} [WP m (.arg σ ps)]
{α} {c : Prop} [Decidable c] {t e : m α} {s : σ} {P : Assertion ps} {Q : PostCond α (.arg σ ps)}
(hthen : P ⊢ₛ wp⟦t⟧ Q s) (helse : P ⊢ₛ wp⟦e⟧ Q s)
(hthen : c → P ⊢ₛ wp⟦t⟧ Q s) (helse : ¬c → P ⊢ₛ wp⟦e⟧ Q s)
: P ⊢ₛ wp⟦ite c t e⟧ Q s
```
For matchers, the hypothesis types are discovered via `rwIfOrMatcher` after `withSplitterAndLocals`
opens the splitter telescope. No branching on split kind is needed in this function.
-/
meta def mkBackwardRuleForIte (m σs ps instWP : Expr) (excessArgs : Array Expr) : SymM BackwardRule := do
meta def mkBackwardRuleForSplit (splitInfo : SplitInfo) (m σs ps instWP : Expr) (excessArgs : Array Expr) : SymM BackwardRule := do
let preprocessExpr : Expr SymM Expr := shareCommon <=< liftMetaM unfoldReducible
-- Replace the program expression in a goal-shaped expression:
-- SPred.entails σs P (mkAppN (PredTrans.apply ps α (WP.wp m ps instWP α PROG) Q) ss)
-- Built by `goalWithProg`, so the structure is fixed.
let replaceProgInGoal (goal newProg : Expr) : Expr :=
let goalArgs := goal.getAppArgs -- [σs, P, wpApplyQss]
let waqArgs := goalArgs[2]!.getAppArgs -- [ps, α, wp, Q, ss...]
let wpArgs := waqArgs[2]!.getAppArgs -- [m, ps, instWP, α, prog]
let wp' := mkAppN waqArgs[2]!.getAppFn (wpArgs.set! 4 newProg)
let waq' := mkAppN goalArgs[2]!.getAppFn (waqArgs.set! 2 wp')
mkAppN goal.getAppFn (goalArgs.set! 2 waq')
let extractProgFromGoal (goal : Expr) : Expr :=
goal.getAppArgs[2]!.getAppArgs[2]!.getAppArgs[4]!
let prf do
let us := instWP.getAppFn.constLevels!
let u := us[0]!
let v := us[1]!
withLocalDeclD `α (mkSort u.succ) fun α => do
let mα preprocessExpr <| mkApp m α
withLocalDeclD `c (mkSort 0) fun c => do
withLocalDeclD `dec (mkApp (mkConst ``Decidable) c) fun dec => do
withLocalDeclD `t mα fun t => do
withLocalDeclD `e mα fun e => do
let prog preprocessExpr (mkApp5 (mkConst ``ite [v.succ]) mα c dec t e)
let excessArgNamesTypes excessArgs.mapM fun arg =>
return (`s, Meta.inferType arg)
withSplitterAndLocals splitInfo mα v fun rawProg splitFVars splitFn => do
let excessArgNamesTypes excessArgs.mapM fun arg => return (`s, Meta.inferType arg)
withLocalDeclsDND excessArgNamesTypes fun ss => do
-- When an excess arg equals a matcher discriminant, reuse the discriminant fvar.
-- Combined with the motive substitution in `withSplitterAndLocals`, this ensures
-- the per-alt body types have discriminants replaced by patterns in excess arg positions.
let (ssEffective, ssToAbstract) match splitInfo with
| .matcher matcherApp =>
let numDiscrs := matcherApp.discrs.size
let mut ssEff := ss
let mut ssAbs := #[]
for i in [:excessArgs.size] do
let mut shared := false
for j in [:numDiscrs] do
if excessArgs[i]! == matcherApp.discrs[j]! then
ssEff := ssEff.set! i splitFVars[j]!
shared := true
break
if !shared then
ssAbs := ssAbs.push ss[i]!
pure (ssEff, ssAbs)
| _ => pure (ss, ss)
withLocalDeclD `P ( preprocessExpr <| mkApp (mkConst ``SPred [u]) σs) fun P => do
withLocalDeclD `Q ( preprocessExpr <| mkApp2 (mkConst ``PostCond [u]) α ps) fun Q => do
let goalWithProg prog :=
let wp := mkApp5 (mkConst ``WP.wp [u, v]) m ps instWP α prog
let wpApplyQ := mkApp4 (mkConst ``PredTrans.apply [u]) ps α wp Q -- wp⟦prog⟧ Q
let wpApplyQ := mkAppN wpApplyQ ss -- wp⟦prog⟧ Q s₁ ... sₙ
let wpApplyQ := mkApp4 (mkConst ``PredTrans.apply [u]) ps α wp Q
let wpApplyQ := mkAppN wpApplyQ ssEffective
mkApp3 (mkConst ``SPred.entails [u]) σs P wpApplyQ
let thenType mkArrow c (goalWithProg t)
withLocalDeclD `hthen ( preprocessExpr thenType) fun hthen => do
let elseType mkArrow (mkNot c) (goalWithProg e)
withLocalDeclD `helse ( preprocessExpr elseType) fun helse => do
let onAlt (hc : Expr) (hcase : Expr) := do
let res rwIfWith hc prog
-- When `rw` fails, it returns `proof? := none`. We throw an error.
let goal := goalWithProg rawProg
-- Pass 1: Discover hypothesis types.
-- The callback receives `bodyType` from the splitter telescope. For matchers with
-- motive substitution, this has discriminants replaced by patterns, so we extract the
-- program from `bodyType` (not the original `rawProg`) and build hypotheses from it.
let hypTypesRef IO.mkRef (#[] : Array Expr)
let _ liftMetaM <| splitFn goal fun idx params bodyType => do
let rawProg' := extractProgFromGoal bodyType
let res rwIfOrMatcher idx rawProg'
let hypBodyType := replaceProgInGoal bodyType res.expr
let hypType mkForallFVars params hypBodyType
hypTypesRef.modify (·.push hypType)
return bodyType
let hypTypes hypTypesRef.get
let hypNamesTypes := hypTypes.mapIdx fun i ty => ((`h).appendIndexAfter (i+1), ty)
withLocalDeclsDND hypNamesTypes fun hyps => do
-- Pass 2: Build proof.
-- `context` is built per-alt from `bodyType` so that the congruence proof
-- matches the per-alt goal (with discriminants substituted, cf. 99c83b9c).
let prf liftMetaM <| splitFn goal fun idx params bodyType => do
let rawProg' := extractProgFromGoal bodyType
let res rwIfOrMatcher idx rawProg'
if res.proof?.isNone then
throwError "`rwIfWith` failed to rewrite {indentExpr e}."
-- context = fun e => P ⊢ₛ wp⟦e⟧ Q s₁ ... sₙ
let context withLocalDecl `e .default mα fun e => mkLambdaFVars #[e] (goalWithProg e)
throwError "mkBackwardRuleForSplit: rwIfOrMatcher failed for alt {idx}"
let hypProof := mkAppN hyps[idx]! params
let context withLocalDecl `e .default mα fun e =>
mkLambdaFVars #[e] (replaceProgInGoal bodyType e)
let res Simp.mkCongrArg context res
res.mkEqMPR hcase
let ht withLocalDecl `h .default c fun h => do mkLambdaFVars #[h] ( onAlt h (mkApp hthen h))
let he withLocalDecl `h .default (mkNot c) fun h => do mkLambdaFVars #[h] ( onAlt h (mkApp helse h))
let prf := mkApp5 (mkConst ``dite [0]) (goalWithProg prog) c dec ht he
mkLambdaFVars (#[α, c, dec, t, e] ++ ss ++ #[P, Q, hthen, helse]) prf
res.mkEqMPR hypProof
mkLambdaFVars (#[α] ++ splitFVars ++ ssToAbstract ++ #[P, Q] ++ hyps) prf
let prf instantiateMVars prf
let res abstractMVars prf
let expr normalizeLevelsExpr res.expr
@@ -516,12 +682,15 @@ meta def mkBackwardRuleFromSpecCached (specThm : SpecTheoremNew) (m σs ps instW
return res
open Lean.Elab.Tactic.Do in
/-- See the documentation for `SpecTheoremNew.mkBackwardRuleForIte` for more details. -/
/-- Creates and caches a backward rule for splitting `ite`, `dite`, or matchers. -/
meta def mkBackwardRuleFromSplitInfoCached (splitInfo : SplitInfo) (m σs ps instWP : Expr) (excessArgs : Array Expr) : _root_.VCGenM BackwardRule := do
unless splitInfo matches .ite .. do throwError "Only `ite` is currently supported for splitting."
let mkRuleSlow := mkBackwardRuleForIte m σs ps instWP excessArgs
let cacheKey := match splitInfo with
| .ite .. => ``ite
| .dite .. => ``dite
| .matcher matcherApp => matcherApp.matcherName
let mkRuleSlow := mkBackwardRuleForSplit splitInfo m σs ps instWP excessArgs
let s get
let (res, splitBackwardRuleCache) s.splitBackwardRuleCache.getDM (``ite, m, excessArgs.size) mkRuleSlow
let (res, splitBackwardRuleCache) s.splitBackwardRuleCache.getDM (cacheKey, m, excessArgs.size) mkRuleSlow
set { s with splitBackwardRuleCache }
return res
@@ -743,9 +912,8 @@ meta def solve (goal : MVarId) : VCGenM SolveResult := goal.withContext do
let goal goal.replaceTargetDefEq target
return .goals [goal]
-- Hard-code match splitting for `ite` for now.
if f.isAppOf ``ite then
let some info Lean.Elab.Tactic.Do.getSplitInfo? e | return .noStrategyForProgram e
-- Split ite/dite/match
if let some info liftMetaM <| Lean.Elab.Tactic.Do.getSplitInfo? e then
let rule mkBackwardRuleFromSplitInfoCached info m σs ps instWP excessArgs
let ApplyResult.goals goals rule.apply goal
| throwError "Failed to apply split rule for {indentExpr e}"

View File

@@ -18,6 +18,9 @@ Each case exercises a different aspect of the VC generation:
- `GetThrowSet`: Exception handling with `ExceptT`/`StateM`
- `PurePrecond`: Pure hypotheses `⌜φ⌝` in preconditions
- `ReaderState`: `ReaderT`/`StateM` combination
- `DiteSplit`: Dependent if-then-else (`if h : cond then ...`)
- `MatchSplit`: Pattern matching in monadic programs
- `MatchSplitState`: Match on state variable (discriminant = excess state arg)
-/
open Lean Parser Meta Elab Tactic Sym Std Do SpecAttr
@@ -42,3 +45,12 @@ open PurePrecond in
open ReaderState in
#eval runBenchUsingTactic ``Goal [``loop, ``step] `(tactic| mvcgen') `(tactic| sorry) [10]
open DiteSplit in
#eval runBenchUsingTactic ``Goal [``loop, ``step] `(tactic| mvcgen') `(tactic| sorry) [10]
open MatchSplit in
#eval runBenchUsingTactic ``Goal [``loop, ``step] `(tactic| mvcgen') `(tactic| sorry) [10]
open MatchSplitState in
#eval runBenchUsingTactic ``Goal [``loop, ``step] `(tactic| mvcgen') `(tactic| grind) [10]

View File

@@ -0,0 +1,16 @@
/-
Copyright (c) 2026 Lean FRO LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Sebastian Graf
-/
import Cases.MatchSplit
import Driver
open Lean Parser Meta Elab Tactic Sym Std Do SpecAttr
open MatchSplit
set_option maxRecDepth 10000
set_option maxHeartbeats 10000000
#eval runBenchUsingTactic ``Goal [``loop, ``step] `(tactic| mvcgen') `(tactic| sorry)
[100, 500, 1000]