refactor: simplify mkBackwardRuleForSplit and fix fvar alt handling

Inline `withSplitterAndLocals` into `mkBackwardRuleForSplit` and replace it
with a single `SplitInfo.splitWith` call. Eta-expand alts in `withAbstract`
so `matcherApp.transform` can `instantiateLambda` them directly (no patching
needed), then eta-reduce when computing `abstractProg` to avoid expensive
higher-order unification in backward rule patterns. Extract `mkGoal` and
`extractProgFromGoal` as top-level helpers, removing `replaceProgInGoal`.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Sebastian Graf
2026-03-12 11:08:51 +00:00
parent 4261fbe8ce
commit 0007ffa16a
2 changed files with 75 additions and 128 deletions

View File

@@ -66,46 +66,48 @@ For `ite`/`dite`, introduces `c : Prop`, `dec : Decidable c`, `t : mα` (or `t :
For `matcher`, introduces discriminant fvars and alternative fvars, builds a non-dependent
motive `fun _ ... _ => mα`, and adjusts matcher universe levels.
The abstract `SplitInfo` satisfies `abstractInfo.toExpr = abstractProgram`.
The abstract `SplitInfo` satisfies `abstractInfo.expr = abstractProgram`.
For `matcher`, the abstract `MatcherApp` stores fvar alts. Callers that need the original
lambda alts (e.g. for `splitWith`/`matcherApp.transform`) should patch them back:
`{ abstractMatcherApp with alts := origMatcherApp.alts }`.
For `matcher`, the abstract `MatcherApp` stores eta-expanded fvar alts so that
`splitWith`/`matcherApp.transform` can `instantiateLambda` them directly (no patching needed).
Since eta-expanded alts like `fun n => alt n` can cause expensive higher-order unification in
backward rule patterns, callers building backward rules should eta-reduce them first (e.g.
via `Expr.eta` on the alt arguments of `abstractInfo.expr`).
-/
def withAbstract {n} {α} [MonadLiftT MetaM n] [MonadControlT MetaM n] [Monad n] [Inhabited α]
(info : SplitInfo) (mα : Expr)
(info : SplitInfo) (resTy : Expr)
(k : (abstractInfo : SplitInfo) (splitFVars : Array Expr) n α) : n α :=
match info with
| .ite _ =>
withLocalDeclD `c (mkSort 0) fun c =>
withLocalDeclD `dec (mkApp (mkConst ``Decidable) c) fun dec =>
withLocalDeclD `t mα fun t =>
withLocalDeclD `e mα fun e => do
let u liftMetaM <| getLevel mα
k (.ite <| mkApp5 (mkConst ``_root_.ite [u]) mα c dec t e) #[c, dec, t, e]
withLocalDeclD `t resTy fun t =>
withLocalDeclD `e resTy fun e => do
let u liftMetaM <| getLevel resTy
k (.ite <| mkApp5 (mkConst ``_root_.ite [u]) resTy c dec t e) #[c, dec, t, e]
| .dite _ =>
withLocalDeclD `c (mkSort 0) fun c =>
withLocalDeclD `dec (mkApp (mkConst ``Decidable) c) fun dec => do
let tTy liftMetaM <| mkArrow c mα
let eTy liftMetaM <| mkArrow (mkNot c) mα
let tTy liftMetaM <| mkArrow c resTy
let eTy liftMetaM <| mkArrow (mkNot c) resTy
withLocalDeclD `t tTy fun t =>
withLocalDeclD `e eTy fun e => do
let u liftMetaM <| getLevel mα
k (.dite <| mkApp5 (mkConst ``_root_.dite [u]) mα c dec t e) #[c, dec, t, e]
let u liftMetaM <| getLevel resTy
k (.dite <| mkApp5 (mkConst ``_root_.dite [u]) resTy c dec t e) #[c, dec, t, e]
| .matcher matcherApp => do
let discrNamesTypes matcherApp.discrs.mapIdxM fun i discr => do
return ((`discr).appendIndexAfter (i+1), liftMetaM <| inferType discr)
withLocalDeclsDND discrNamesTypes fun discrs => do
-- Non-dependent motive: fun _ ... _ => mα
let motive liftMetaM <| lambdaTelescope matcherApp.motive fun motiveArgs _ =>
mkLambdaFVars motiveArgs mα
mkLambdaFVars motiveArgs resTy
-- The matcher's universe levels include a `uElimPos` slot for the elimination target level.
-- Our abstract motive `fun _ ... _ => mα` may target a different level than the original
-- dependent motive, so we update `matcherLevels[uElimPos]` to `getLevel mα`.
let matcherLevels match matcherApp.uElimPos? with
| none => pure matcherApp.matcherLevels
| some pos => do
let uElim liftMetaM <| getLevel mα
let uElim liftMetaM <| getLevel resTy
pure <| matcherApp.matcherLevels.set! pos uElim
-- Build partial application to infer alt types
let matcherPartial := mkAppN (mkConst matcherApp.matcherName matcherLevels.toList) matcherApp.params
@@ -114,12 +116,13 @@ def withAbstract {n} {α} [MonadLiftT MetaM n] [MonadControlT MetaM n] [Monad n]
let origAltTypes liftMetaM <| inferArgumentTypesN matcherApp.alts.size matcherPartial
let altNamesTypes := origAltTypes.mapIdx fun i ty => ((`alt).appendIndexAfter (i+1), ty)
withLocalDeclsDND altNamesTypes fun alts => do
-- Eta-expand fvar alts so `splitWith`/`matcherApp.transform` can `instantiateLambda` them.
let abstractMatcherApp : MatcherApp := {
matcherApp with
matcherLevels := matcherLevels
discrs := discrs
motive := motive
alts := alts
alts := ( liftMetaM <| alts.mapM etaExpand)
remaining := #[]
}
k (.matcher abstractMatcherApp) (discrs ++ alts)

View File

@@ -155,6 +155,28 @@ meta def SpecTheoremsNew.findSpecs (database : SpecTheoremsNew) (e : Expr) :
end Lean.Elab.Tactic.Do.SpecAttr
-- Normalize universe levels in an expression so that `max u v` and `max v u` have a canonical
-- representation. This is needed because backward rule pattern matching is structural and
-- level expressions from different sources (e.g., instance synthesis, type inference) may have
-- different but equivalent `max` orderings.
meta def normalizeLevelsExpr (e : Expr) : CoreM Expr :=
Core.transform e (pre := fun e => do
match e with
| .sort u => return .done <| e.updateSort! u.normalize
| .const _ us => return .done <| e.updateConst! (us.map Level.normalize)
| _ => return .continue)
/-- Build goal: `P ⊢ₛ wp⟦prog⟧ Q ss...`. Meant to be partially applied for convenience. -/
private meta def mkGoal (u v : Level) (m σs ps instWP α : Expr) (ss : Array Expr) (P Q : Expr) (prog : Expr) : Expr :=
mkApp3 (mkConst ``SPred.entails [u]) σs P
(mkAppN (mkApp4 (mkConst ``PredTrans.apply [u]) ps α
(mkApp5 (mkConst ``WP.wp [u, v]) m ps instWP α prog) Q) ss)
/-- Extract the program from a goal built by `mkGoal`. -/
private meta def extractProgFromGoal (goal : Expr) : Expr :=
goal.getArg! 2 |>.getArg! 2 |>.getArg! 4
/--
Create a backward rule for the `SpecTheoremNew` that was looked up in the database.
In order for the backward rule to apply, we need to instantiate both `m` and `ps` with the ones
@@ -217,18 +239,6 @@ prf : ∀ (α : Type) (x : StateT Nat Id α) (β : Type) (f : α → StateT Nat
We are still investigating how to get rid of more unfolding overhead, such as for `wp` and
`List.rec`.
-/
-- Normalize universe levels in an expression so that `max u v` and `max v u` have a canonical
-- representation. This is needed because backward rule pattern matching is structural and
-- level expressions from different sources (e.g., instance synthesis, type inference) may have
-- different but equivalent `max` orderings.
meta def normalizeLevelsExpr (e : Expr) : CoreM Expr :=
Core.transform e (pre := fun e => do
match e with
| .sort u => return .done <| e.updateSort! u.normalize
| .const _ us => return .done <| e.updateConst! (us.map Level.normalize)
| _ => return .continue)
meta def mkBackwardRuleFromSpec (specThm : SpecTheoremNew) (m σs ps instWP : Expr) (excessArgs : Array Expr) : SymM BackwardRule := do
let preprocessExpr : Expr SymM Expr := shareCommon <=< liftMetaM unfoldReducible
-- Create a backward rule for the spec we look up in the database.
@@ -396,121 +406,55 @@ meta def mkBackwardRuleFromSimpSpec (specThm : SpecTheoremNew) (m σs ps instWP
let expr normalizeLevelsExpr res.expr
mkBackwardRuleFromExpr expr res.paramNames.toList
open Lean.Elab.Tactic.Do in
/--
CPS helper that introduces local fvars for a split and provides a uniform splitter.
Uses `SplitInfo.withAbstract` to open a context where discriminants and alternatives
are fvars. Then provides a `splitFn` that builds the splitting proof:
- For `ite`/`dite`: uses `dite` as the splitter (condition proofs as params).
- For `matcher`: delegates to `SplitInfo.splitWith` / `matcherApp.transform`.
The continuation `k` receives `abstractProg` (the abstract program expression, i.e.
`abstractInfo.expr`), `splitFVars` (fvars to abstract over), and `splitFn` (builds
the splitting proof).
-/
private meta def withSplitterAndLocals {α} [Inhabited α] (splitInfo : SplitInfo) (mα : Expr)
(k : (abstractProg : Expr) (splitFVars : Array Expr)
(splitFn : (goal : Expr) (onAlt : Nat Array Expr Expr MetaM Expr) MetaM Expr)
SymM α) : SymM α := do
splitInfo.withAbstract mα fun abstractInfo splitFVars => do
let abstractProg := match abstractInfo with
| .ite e | .dite e => e
| .matcher matcherApp => matcherApp.toExpr
let splitFn : (goal : Expr) (onAlt : Nat Array Expr Expr MetaM Expr) MetaM Expr :=
match abstractInfo with
| .ite e | .dite e =>
let c := e.getArg! 1
let dec := e.getArg! 2
fun goal onAlt => do
let ht withLocalDecl `h .default c fun h => do
mkLambdaFVars #[h] ( onAlt 0 #[h] goal)
let he withLocalDecl `h .default (mkNot c) fun h => do
mkLambdaFVars #[h] ( onAlt 1 #[h] goal)
return mkApp5 (mkConst ``dite [0]) goal c dec ht he
| .matcher abstractMatcherApp =>
fun goal onAlt => do
-- Patch fvar alts back to original lambda alts for splitWith/matcherApp.transform
let .matcher origMatcherApp := splitInfo | unreachable!
let splitMatcherApp := { abstractMatcherApp with alts := origMatcherApp.alts }
SplitInfo.splitWith (.matcher splitMatcherApp) goal
(fun _name expAltType idx params => onAlt idx params expAltType)
(useSplitter := true)
k abstractProg splitFVars splitFn
open Lean.Elab.Tactic.Do in
/--
Creates a reusable backward rule for splitting `ite`, `dite`, or matchers.
For `ite`/`dite`, proves a theorem of the following form:
```
example {m} {σ} {ps} [WP m (.arg σ ps)]
{α} {c : Prop} [Decidable c] {t e : m α} {s : σ} {P : Assertion ps} {Q : PostCond α (.arg σ ps)}
(hthen : c → P ⊢ₛ wp⟦t⟧ Q s) (helse : ¬c → P ⊢ₛ wp⟦e⟧ Q s)
: P ⊢ₛ wp⟦ite c t e⟧ Q s
```
For matchers, the hypothesis types are discovered via `rwIfOrMatcher` after `withSplitterAndLocals`
opens the splitter telescope. No branching on split kind is needed in this function.
Uses `SplitInfo.withAbstract` to open fvars for the split, then `SplitInfo.splitWith`
to build the splitting proof. Hypothesis types are discovered via `rwIfOrMatcher` inside
the splitter telescope.
-/
meta def mkBackwardRuleForSplit (splitInfo : SplitInfo) (m σs ps instWP : Expr) (excessArgs : Array Expr) : SymM BackwardRule := do
let preprocessExpr : Expr SymM Expr := shareCommon <=< liftMetaM unfoldReducible
-- Extract the program expression from a goal-shaped expression.
let extractProgFromGoal (goal : Expr) : Expr :=
goal.getArg! 2 |>.getArg! 2 |>.getArg! 4
-- Replace the program expression in a goal-shaped expression:
-- SPred.entails σs P (mkAppN (PredTrans.apply ps α (WP.wp m ps instWP α PROG) Q) ss)
-- Built by `goalWithProg`, so the structure is fixed.
let replaceProgInGoal (goal newProg : Expr) : Expr :=
let goalArgs := goal.getAppArgs -- [σs, P, wpApplyQss]
let waqArgs := goalArgs[2]!.getAppArgs -- [ps, α, wp, Q, ss...]
let wpArgs := waqArgs[2]!.getAppArgs -- [m, ps, instWP, α, prog]
let wp' := mkAppN waqArgs[2]!.getAppFn (wpArgs.set! 4 newProg)
let waq' := mkAppN goalArgs[2]!.getAppFn (waqArgs.set! 2 wp')
mkAppN goal.getAppFn (goalArgs.set! 2 waq')
let prf do
let us := instWP.getAppFn.constLevels!
let u := us[0]!
let v := us[1]!
let us := instWP.getAppFn.constLevels!
let u := us[0]!
let v := us[1]!
let prf
withLocalDeclD `α (mkSort u.succ) fun α => do
let mα preprocessExpr <| mkApp m α
withSplitterAndLocals splitInfo mα fun abstractProg splitFVars splitFn => do
splitInfo.withAbstract mα fun abstractInfo splitFVars => do
-- Eta-reduce alts so the backward rule pattern uses clean fvar alts, avoiding expensive
-- higher-order unification. The alts are eta-expanded in `withAbstract` so that
-- `splitWith`/`matcherApp.transform` can `instantiateLambda` them.
let abstractProg := match abstractInfo with
| .ite e | .dite e => e
| .matcher matcherApp =>
{ matcherApp with alts := matcherApp.alts.map Expr.eta }.toExpr
let excessArgNamesTypes excessArgs.mapM fun arg => return (`s, Meta.inferType arg)
withLocalDeclsDND excessArgNamesTypes fun ss => do
withLocalDeclD `P ( preprocessExpr <| mkApp (mkConst ``SPred [u]) σs) fun P => do
withLocalDeclD `Q ( preprocessExpr <| mkApp2 (mkConst ``PostCond [u]) α ps) fun Q => do
let goalWithProg prog :=
let wp := mkApp5 (mkConst ``WP.wp [u, v]) m ps instWP α prog
let wpApplyQ := mkApp4 (mkConst ``PredTrans.apply [u]) ps α wp Q
let wpApplyQ := mkAppN wpApplyQ ss
mkApp3 (mkConst ``SPred.entails [u]) σs P wpApplyQ
let goal := goalWithProg abstractProg
-- We will have one subgoal per alt. We need to introduce these subgoals als local hypotheses.
-- For match, it's difficult to know the exact locals of these subgoals before splitting.
-- Hence we leave the exact shape of the subgoals as metavariables and fill them when
-- constructing the proof in the alt body.
let numAlts := match splitInfo with
| .ite _ | .dite _ => 2
| .matcher matcherApp => matcherApp.alts.size
let mut subgoals := #[]
for _ in [:numAlts] do
subgoals := subgoals.push ( liftMetaM <| mkFreshExprMVar (some (mkSort 0)))
let mkGoal := mkGoal u v m σs ps instWP α ss P Q
-- Subgoal types are synthetic opaque metavariables, filled in the `splitWith` callback below.
-- Synthetic opaque so that `rwIfOrMatcher`'s `assumption` tactic cannot assign them.
let subgoals splitInfo.altInfos.mapM fun _ =>
liftMetaM <| mkFreshExprSyntheticOpaqueMVar (mkSort 0)
let namedSubgoals := subgoals.mapIdx fun i mv => ((`h).appendIndexAfter (i+1), mv)
withLocalDeclsDND namedSubgoals fun subgoalHyps => do
-- For ite/dite, bodyType = goal. For matchers, bodyType = expAltType from
-- splitWith (discriminant fvars replaced by patterns via motive substitution).
let prf liftMetaM <| splitFn goal fun idx params bodyType => do
let res rwIfOrMatcher idx (extractProgFromGoal bodyType)
if res.proof?.isNone then
throwError "mkBackwardRuleForSplit: rwIfOrMatcher failed for alt {idx}"
-- Assign the metavariable of the discovered subgoal
let hypBodyType := replaceProgInGoal bodyType res.expr
let hypType mkForallFVars params hypBodyType
subgoals[idx]!.mvarId!.assign hypType
let hypProof := mkAppN subgoalHyps[idx]! params
let context withLocalDecl `e .default mα fun e =>
mkLambdaFVars #[e] (replaceProgInGoal bodyType e)
let res Simp.mkCongrArg context res
res.mkEqMPR hypProof
let prf liftMetaM <|
abstractInfo.splitWith
(useSplitter := true)
(mkGoal abstractProg)
(fun _name bodyType idx altFVars => do
let prog := extractProgFromGoal bodyType
let res rwIfOrMatcher idx prog
if res.proof?.isNone then
throwError "mkBackwardRuleForSplit: rwIfOrMatcher failed for alt {idx}\n{indentExpr prog}"
let boundFVars := altFVars.all
subgoals[idx]!.mvarId!.assign ( mkForallFVars boundFVars (mkGoal res.expr))
let context withLocalDecl `e .default mα fun e =>
mkLambdaFVars #[e] (mkGoal e)
( Simp.mkCongrArg context res).mkEqMPR (mkAppN subgoalHyps[idx]! boundFVars))
mkLambdaFVars (#[α] ++ splitFVars ++ ss ++ #[P, Q] ++ subgoalHyps) prf
let prf instantiateMVars prf
let res abstractMVars prf