mirror of
https://github.com/leanprover/lean4.git
synced 2026-03-17 10:24:07 +00:00
refactor: simplify mkBackwardRuleForSplit and fix fvar alt handling
Inline `withSplitterAndLocals` into `mkBackwardRuleForSplit` and replace it with a single `SplitInfo.splitWith` call. Eta-expand alts in `withAbstract` so `matcherApp.transform` can `instantiateLambda` them directly (no patching needed), then eta-reduce when computing `abstractProg` to avoid expensive higher-order unification in backward rule patterns. Extract `mkGoal` and `extractProgFromGoal` as top-level helpers, removing `replaceProgInGoal`. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -66,46 +66,48 @@ For `ite`/`dite`, introduces `c : Prop`, `dec : Decidable c`, `t : mα` (or `t :
|
||||
For `matcher`, introduces discriminant fvars and alternative fvars, builds a non-dependent
|
||||
motive `fun _ ... _ => mα`, and adjusts matcher universe levels.
|
||||
|
||||
The abstract `SplitInfo` satisfies `abstractInfo.toExpr = abstractProgram`.
|
||||
The abstract `SplitInfo` satisfies `abstractInfo.expr = abstractProgram`.
|
||||
|
||||
For `matcher`, the abstract `MatcherApp` stores fvar alts. Callers that need the original
|
||||
lambda alts (e.g. for `splitWith`/`matcherApp.transform`) should patch them back:
|
||||
`{ abstractMatcherApp with alts := origMatcherApp.alts }`.
|
||||
For `matcher`, the abstract `MatcherApp` stores eta-expanded fvar alts so that
|
||||
`splitWith`/`matcherApp.transform` can `instantiateLambda` them directly (no patching needed).
|
||||
Since eta-expanded alts like `fun n => alt n` can cause expensive higher-order unification in
|
||||
backward rule patterns, callers building backward rules should eta-reduce them first (e.g.
|
||||
via `Expr.eta` on the alt arguments of `abstractInfo.expr`).
|
||||
-/
|
||||
def withAbstract {n} {α} [MonadLiftT MetaM n] [MonadControlT MetaM n] [Monad n] [Inhabited α]
|
||||
(info : SplitInfo) (mα : Expr)
|
||||
(info : SplitInfo) (resTy : Expr)
|
||||
(k : (abstractInfo : SplitInfo) → (splitFVars : Array Expr) → n α) : n α :=
|
||||
match info with
|
||||
| .ite _ =>
|
||||
withLocalDeclD `c (mkSort 0) fun c =>
|
||||
withLocalDeclD `dec (mkApp (mkConst ``Decidable) c) fun dec =>
|
||||
withLocalDeclD `t mα fun t =>
|
||||
withLocalDeclD `e mα fun e => do
|
||||
let u ← liftMetaM <| getLevel mα
|
||||
k (.ite <| mkApp5 (mkConst ``_root_.ite [u]) mα c dec t e) #[c, dec, t, e]
|
||||
withLocalDeclD `t resTy fun t =>
|
||||
withLocalDeclD `e resTy fun e => do
|
||||
let u ← liftMetaM <| getLevel resTy
|
||||
k (.ite <| mkApp5 (mkConst ``_root_.ite [u]) resTy c dec t e) #[c, dec, t, e]
|
||||
| .dite _ =>
|
||||
withLocalDeclD `c (mkSort 0) fun c =>
|
||||
withLocalDeclD `dec (mkApp (mkConst ``Decidable) c) fun dec => do
|
||||
let tTy ← liftMetaM <| mkArrow c mα
|
||||
let eTy ← liftMetaM <| mkArrow (mkNot c) mα
|
||||
let tTy ← liftMetaM <| mkArrow c resTy
|
||||
let eTy ← liftMetaM <| mkArrow (mkNot c) resTy
|
||||
withLocalDeclD `t tTy fun t =>
|
||||
withLocalDeclD `e eTy fun e => do
|
||||
let u ← liftMetaM <| getLevel mα
|
||||
k (.dite <| mkApp5 (mkConst ``_root_.dite [u]) mα c dec t e) #[c, dec, t, e]
|
||||
let u ← liftMetaM <| getLevel resTy
|
||||
k (.dite <| mkApp5 (mkConst ``_root_.dite [u]) resTy c dec t e) #[c, dec, t, e]
|
||||
| .matcher matcherApp => do
|
||||
let discrNamesTypes ← matcherApp.discrs.mapIdxM fun i discr => do
|
||||
return ((`discr).appendIndexAfter (i+1), ← liftMetaM <| inferType discr)
|
||||
withLocalDeclsDND discrNamesTypes fun discrs => do
|
||||
-- Non-dependent motive: fun _ ... _ => mα
|
||||
let motive ← liftMetaM <| lambdaTelescope matcherApp.motive fun motiveArgs _ =>
|
||||
mkLambdaFVars motiveArgs mα
|
||||
mkLambdaFVars motiveArgs resTy
|
||||
-- The matcher's universe levels include a `uElimPos` slot for the elimination target level.
|
||||
-- Our abstract motive `fun _ ... _ => mα` may target a different level than the original
|
||||
-- dependent motive, so we update `matcherLevels[uElimPos]` to `getLevel mα`.
|
||||
let matcherLevels ← match matcherApp.uElimPos? with
|
||||
| none => pure matcherApp.matcherLevels
|
||||
| some pos => do
|
||||
let uElim ← liftMetaM <| getLevel mα
|
||||
let uElim ← liftMetaM <| getLevel resTy
|
||||
pure <| matcherApp.matcherLevels.set! pos uElim
|
||||
-- Build partial application to infer alt types
|
||||
let matcherPartial := mkAppN (mkConst matcherApp.matcherName matcherLevels.toList) matcherApp.params
|
||||
@@ -114,12 +116,13 @@ def withAbstract {n} {α} [MonadLiftT MetaM n] [MonadControlT MetaM n] [Monad n]
|
||||
let origAltTypes ← liftMetaM <| inferArgumentTypesN matcherApp.alts.size matcherPartial
|
||||
let altNamesTypes := origAltTypes.mapIdx fun i ty => ((`alt).appendIndexAfter (i+1), ty)
|
||||
withLocalDeclsDND altNamesTypes fun alts => do
|
||||
-- Eta-expand fvar alts so `splitWith`/`matcherApp.transform` can `instantiateLambda` them.
|
||||
let abstractMatcherApp : MatcherApp := {
|
||||
matcherApp with
|
||||
matcherLevels := matcherLevels
|
||||
discrs := discrs
|
||||
motive := motive
|
||||
alts := alts
|
||||
alts := (← liftMetaM <| alts.mapM etaExpand)
|
||||
remaining := #[]
|
||||
}
|
||||
k (.matcher abstractMatcherApp) (discrs ++ alts)
|
||||
|
||||
@@ -155,6 +155,28 @@ meta def SpecTheoremsNew.findSpecs (database : SpecTheoremsNew) (e : Expr) :
|
||||
|
||||
end Lean.Elab.Tactic.Do.SpecAttr
|
||||
|
||||
|
||||
-- Normalize universe levels in an expression so that `max u v` and `max v u` have a canonical
|
||||
-- representation. This is needed because backward rule pattern matching is structural and
|
||||
-- level expressions from different sources (e.g., instance synthesis, type inference) may have
|
||||
-- different but equivalent `max` orderings.
|
||||
meta def normalizeLevelsExpr (e : Expr) : CoreM Expr :=
|
||||
Core.transform e (pre := fun e => do
|
||||
match e with
|
||||
| .sort u => return .done <| e.updateSort! u.normalize
|
||||
| .const _ us => return .done <| e.updateConst! (us.map Level.normalize)
|
||||
| _ => return .continue)
|
||||
|
||||
/-- Build goal: `P ⊢ₛ wp⟦prog⟧ Q ss...`. Meant to be partially applied for convenience. -/
|
||||
private meta def mkGoal (u v : Level) (m σs ps instWP α : Expr) (ss : Array Expr) (P Q : Expr) (prog : Expr) : Expr :=
|
||||
mkApp3 (mkConst ``SPred.entails [u]) σs P
|
||||
(mkAppN (mkApp4 (mkConst ``PredTrans.apply [u]) ps α
|
||||
(mkApp5 (mkConst ``WP.wp [u, v]) m ps instWP α prog) Q) ss)
|
||||
|
||||
/-- Extract the program from a goal built by `mkGoal`. -/
|
||||
private meta def extractProgFromGoal (goal : Expr) : Expr :=
|
||||
goal.getArg! 2 |>.getArg! 2 |>.getArg! 4
|
||||
|
||||
/--
|
||||
Create a backward rule for the `SpecTheoremNew` that was looked up in the database.
|
||||
In order for the backward rule to apply, we need to instantiate both `m` and `ps` with the ones
|
||||
@@ -217,18 +239,6 @@ prf : ∀ (α : Type) (x : StateT Nat Id α) (β : Type) (f : α → StateT Nat
|
||||
We are still investigating how to get rid of more unfolding overhead, such as for `wp` and
|
||||
`List.rec`.
|
||||
-/
|
||||
|
||||
-- Normalize universe levels in an expression so that `max u v` and `max v u` have a canonical
|
||||
-- representation. This is needed because backward rule pattern matching is structural and
|
||||
-- level expressions from different sources (e.g., instance synthesis, type inference) may have
|
||||
-- different but equivalent `max` orderings.
|
||||
meta def normalizeLevelsExpr (e : Expr) : CoreM Expr :=
|
||||
Core.transform e (pre := fun e => do
|
||||
match e with
|
||||
| .sort u => return .done <| e.updateSort! u.normalize
|
||||
| .const _ us => return .done <| e.updateConst! (us.map Level.normalize)
|
||||
| _ => return .continue)
|
||||
|
||||
meta def mkBackwardRuleFromSpec (specThm : SpecTheoremNew) (m σs ps instWP : Expr) (excessArgs : Array Expr) : SymM BackwardRule := do
|
||||
let preprocessExpr : Expr → SymM Expr := shareCommon <=< liftMetaM ∘ unfoldReducible
|
||||
-- Create a backward rule for the spec we look up in the database.
|
||||
@@ -396,121 +406,55 @@ meta def mkBackwardRuleFromSimpSpec (specThm : SpecTheoremNew) (m σs ps instWP
|
||||
let expr ← normalizeLevelsExpr res.expr
|
||||
mkBackwardRuleFromExpr expr res.paramNames.toList
|
||||
|
||||
open Lean.Elab.Tactic.Do in
|
||||
/--
|
||||
CPS helper that introduces local fvars for a split and provides a uniform splitter.
|
||||
|
||||
Uses `SplitInfo.withAbstract` to open a context where discriminants and alternatives
|
||||
are fvars. Then provides a `splitFn` that builds the splitting proof:
|
||||
- For `ite`/`dite`: uses `dite` as the splitter (condition proofs as params).
|
||||
- For `matcher`: delegates to `SplitInfo.splitWith` / `matcherApp.transform`.
|
||||
|
||||
The continuation `k` receives `abstractProg` (the abstract program expression, i.e.
|
||||
`abstractInfo.expr`), `splitFVars` (fvars to abstract over), and `splitFn` (builds
|
||||
the splitting proof).
|
||||
-/
|
||||
private meta def withSplitterAndLocals {α} [Inhabited α] (splitInfo : SplitInfo) (mα : Expr)
|
||||
(k : (abstractProg : Expr) → (splitFVars : Array Expr) →
|
||||
(splitFn : (goal : Expr) → (onAlt : Nat → Array Expr → Expr → MetaM Expr) → MetaM Expr) →
|
||||
SymM α) : SymM α := do
|
||||
splitInfo.withAbstract mα fun abstractInfo splitFVars => do
|
||||
let abstractProg := match abstractInfo with
|
||||
| .ite e | .dite e => e
|
||||
| .matcher matcherApp => matcherApp.toExpr
|
||||
let splitFn : (goal : Expr) → (onAlt : Nat → Array Expr → Expr → MetaM Expr) → MetaM Expr :=
|
||||
match abstractInfo with
|
||||
| .ite e | .dite e =>
|
||||
let c := e.getArg! 1
|
||||
let dec := e.getArg! 2
|
||||
fun goal onAlt => do
|
||||
let ht ← withLocalDecl `h .default c fun h => do
|
||||
mkLambdaFVars #[h] (← onAlt 0 #[h] goal)
|
||||
let he ← withLocalDecl `h .default (mkNot c) fun h => do
|
||||
mkLambdaFVars #[h] (← onAlt 1 #[h] goal)
|
||||
return mkApp5 (mkConst ``dite [0]) goal c dec ht he
|
||||
| .matcher abstractMatcherApp =>
|
||||
fun goal onAlt => do
|
||||
-- Patch fvar alts back to original lambda alts for splitWith/matcherApp.transform
|
||||
let .matcher origMatcherApp := splitInfo | unreachable!
|
||||
let splitMatcherApp := { abstractMatcherApp with alts := origMatcherApp.alts }
|
||||
SplitInfo.splitWith (.matcher splitMatcherApp) goal
|
||||
(fun _name expAltType idx params => onAlt idx params expAltType)
|
||||
(useSplitter := true)
|
||||
k abstractProg splitFVars splitFn
|
||||
|
||||
open Lean.Elab.Tactic.Do in
|
||||
/--
|
||||
Creates a reusable backward rule for splitting `ite`, `dite`, or matchers.
|
||||
|
||||
For `ite`/`dite`, proves a theorem of the following form:
|
||||
```
|
||||
example {m} {σ} {ps} [WP m (.arg σ ps)]
|
||||
{α} {c : Prop} [Decidable c] {t e : m α} {s : σ} {P : Assertion ps} {Q : PostCond α (.arg σ ps)}
|
||||
(hthen : c → P ⊢ₛ wp⟦t⟧ Q s) (helse : ¬c → P ⊢ₛ wp⟦e⟧ Q s)
|
||||
: P ⊢ₛ wp⟦ite c t e⟧ Q s
|
||||
```
|
||||
For matchers, the hypothesis types are discovered via `rwIfOrMatcher` after `withSplitterAndLocals`
|
||||
opens the splitter telescope. No branching on split kind is needed in this function.
|
||||
Uses `SplitInfo.withAbstract` to open fvars for the split, then `SplitInfo.splitWith`
|
||||
to build the splitting proof. Hypothesis types are discovered via `rwIfOrMatcher` inside
|
||||
the splitter telescope.
|
||||
-/
|
||||
meta def mkBackwardRuleForSplit (splitInfo : SplitInfo) (m σs ps instWP : Expr) (excessArgs : Array Expr) : SymM BackwardRule := do
|
||||
let preprocessExpr : Expr → SymM Expr := shareCommon <=< liftMetaM ∘ unfoldReducible
|
||||
-- Extract the program expression from a goal-shaped expression.
|
||||
let extractProgFromGoal (goal : Expr) : Expr :=
|
||||
goal.getArg! 2 |>.getArg! 2 |>.getArg! 4
|
||||
-- Replace the program expression in a goal-shaped expression:
|
||||
-- SPred.entails σs P (mkAppN (PredTrans.apply ps α (WP.wp m ps instWP α PROG) Q) ss)
|
||||
-- Built by `goalWithProg`, so the structure is fixed.
|
||||
let replaceProgInGoal (goal newProg : Expr) : Expr :=
|
||||
let goalArgs := goal.getAppArgs -- [σs, P, wpApplyQss]
|
||||
let waqArgs := goalArgs[2]!.getAppArgs -- [ps, α, wp, Q, ss...]
|
||||
let wpArgs := waqArgs[2]!.getAppArgs -- [m, ps, instWP, α, prog]
|
||||
let wp' := mkAppN waqArgs[2]!.getAppFn (wpArgs.set! 4 newProg)
|
||||
let waq' := mkAppN goalArgs[2]!.getAppFn (waqArgs.set! 2 wp')
|
||||
mkAppN goal.getAppFn (goalArgs.set! 2 waq')
|
||||
let prf ← do
|
||||
let us := instWP.getAppFn.constLevels!
|
||||
let u := us[0]!
|
||||
let v := us[1]!
|
||||
let us := instWP.getAppFn.constLevels!
|
||||
let u := us[0]!
|
||||
let v := us[1]!
|
||||
let prf ←
|
||||
withLocalDeclD `α (mkSort u.succ) fun α => do
|
||||
let mα ← preprocessExpr <| mkApp m α
|
||||
withSplitterAndLocals splitInfo mα fun abstractProg splitFVars splitFn => do
|
||||
splitInfo.withAbstract mα fun abstractInfo splitFVars => do
|
||||
-- Eta-reduce alts so the backward rule pattern uses clean fvar alts, avoiding expensive
|
||||
-- higher-order unification. The alts are eta-expanded in `withAbstract` so that
|
||||
-- `splitWith`/`matcherApp.transform` can `instantiateLambda` them.
|
||||
let abstractProg := match abstractInfo with
|
||||
| .ite e | .dite e => e
|
||||
| .matcher matcherApp =>
|
||||
{ matcherApp with alts := matcherApp.alts.map Expr.eta }.toExpr
|
||||
let excessArgNamesTypes ← excessArgs.mapM fun arg => return (`s, ← Meta.inferType arg)
|
||||
withLocalDeclsDND excessArgNamesTypes fun ss => do
|
||||
withLocalDeclD `P (← preprocessExpr <| mkApp (mkConst ``SPred [u]) σs) fun P => do
|
||||
withLocalDeclD `Q (← preprocessExpr <| mkApp2 (mkConst ``PostCond [u]) α ps) fun Q => do
|
||||
let goalWithProg prog :=
|
||||
let wp := mkApp5 (mkConst ``WP.wp [u, v]) m ps instWP α prog
|
||||
let wpApplyQ := mkApp4 (mkConst ``PredTrans.apply [u]) ps α wp Q
|
||||
let wpApplyQ := mkAppN wpApplyQ ss
|
||||
mkApp3 (mkConst ``SPred.entails [u]) σs P wpApplyQ
|
||||
let goal := goalWithProg abstractProg
|
||||
-- We will have one subgoal per alt. We need to introduce these subgoals als local hypotheses.
|
||||
-- For match, it's difficult to know the exact locals of these subgoals before splitting.
|
||||
-- Hence we leave the exact shape of the subgoals as metavariables and fill them when
|
||||
-- constructing the proof in the alt body.
|
||||
let numAlts := match splitInfo with
|
||||
| .ite _ | .dite _ => 2
|
||||
| .matcher matcherApp => matcherApp.alts.size
|
||||
let mut subgoals := #[]
|
||||
for _ in [:numAlts] do
|
||||
subgoals := subgoals.push (← liftMetaM <| mkFreshExprMVar (some (mkSort 0)))
|
||||
let mkGoal := mkGoal u v m σs ps instWP α ss P Q
|
||||
-- Subgoal types are synthetic opaque metavariables, filled in the `splitWith` callback below.
|
||||
-- Synthetic opaque so that `rwIfOrMatcher`'s `assumption` tactic cannot assign them.
|
||||
let subgoals ← splitInfo.altInfos.mapM fun _ =>
|
||||
liftMetaM <| mkFreshExprSyntheticOpaqueMVar (mkSort 0)
|
||||
let namedSubgoals := subgoals.mapIdx fun i mv => ((`h).appendIndexAfter (i+1), mv)
|
||||
withLocalDeclsDND namedSubgoals fun subgoalHyps => do
|
||||
-- For ite/dite, bodyType = goal. For matchers, bodyType = expAltType from
|
||||
-- splitWith (discriminant fvars replaced by patterns via motive substitution).
|
||||
let prf ← liftMetaM <| splitFn goal fun idx params bodyType => do
|
||||
let res ← rwIfOrMatcher idx (extractProgFromGoal bodyType)
|
||||
if res.proof?.isNone then
|
||||
throwError "mkBackwardRuleForSplit: rwIfOrMatcher failed for alt {idx}"
|
||||
-- Assign the metavariable of the discovered subgoal
|
||||
let hypBodyType := replaceProgInGoal bodyType res.expr
|
||||
let hypType ← mkForallFVars params hypBodyType
|
||||
subgoals[idx]!.mvarId!.assign hypType
|
||||
let hypProof := mkAppN subgoalHyps[idx]! params
|
||||
let context ← withLocalDecl `e .default mα fun e =>
|
||||
mkLambdaFVars #[e] (replaceProgInGoal bodyType e)
|
||||
let res ← Simp.mkCongrArg context res
|
||||
res.mkEqMPR hypProof
|
||||
let prf ← liftMetaM <|
|
||||
abstractInfo.splitWith
|
||||
(useSplitter := true)
|
||||
(mkGoal abstractProg)
|
||||
(fun _name bodyType idx altFVars => do
|
||||
let prog := extractProgFromGoal bodyType
|
||||
let res ← rwIfOrMatcher idx prog
|
||||
if res.proof?.isNone then
|
||||
throwError "mkBackwardRuleForSplit: rwIfOrMatcher failed for alt {idx}\n{indentExpr prog}"
|
||||
let boundFVars := altFVars.all
|
||||
subgoals[idx]!.mvarId!.assign (← mkForallFVars boundFVars (mkGoal res.expr))
|
||||
let context ← withLocalDecl `e .default mα fun e =>
|
||||
mkLambdaFVars #[e] (mkGoal e)
|
||||
(← Simp.mkCongrArg context res).mkEqMPR (mkAppN subgoalHyps[idx]! boundFVars))
|
||||
mkLambdaFVars (#[α] ++ splitFVars ++ ss ++ #[P, Q] ++ subgoalHyps) prf
|
||||
let prf ← instantiateMVars prf
|
||||
let res ← abstractMVars prf
|
||||
|
||||
Reference in New Issue
Block a user