mirror of
https://github.com/leanprover/lean4.git
synced 2026-03-17 10:24:07 +00:00
refactor: replace flat Array Expr with TransformAltFVars in MatcherApp.transform (#12902)
This PR introduces a `TransformAltFVars` structure to replace the flat `Array Expr` parameter in the `onAlt` callback of `MatcherApp.transform`. The new structure gives callers structured access to the different kinds of fvars introduced in matcher alternative telescopes: constructor fields, overlap parameters, discriminant equations, and extra equations from `addEqualities`. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -192,7 +192,8 @@ where
|
||||
-- Last resort: Split match
|
||||
trace[Elab.Tactic.Do.vcgen] "split match: {e}"
|
||||
burnOne
|
||||
return ← info.splitWith goal.toExpr (useSplitter := true) fun altSuff expAltType idx params => do
|
||||
return ← info.splitWith goal.toExpr (useSplitter := true) fun altSuff expAltType idx altFVars => do
|
||||
let params := altFVars.altParams
|
||||
burnOne
|
||||
let some goal := parseMGoal? expAltType
|
||||
| throwError "Bug in `mvcgen`: Expected alt type {expAltType} could not be parsed as an MGoal."
|
||||
@@ -253,8 +254,8 @@ where
|
||||
mkFreshExprSyntheticOpaqueMVar hypsTy (name.appendIndexAfter i)
|
||||
|
||||
let (joinPrf, joinGoal) ← forallBoundedTelescope joinTy numJoinParams fun joinParams _body => do
|
||||
let φ ← info.splitWith (mkSort .zero) fun _suff _expAltType idx altParams =>
|
||||
return mkAppN hypsMVars[idx]! (joinParams ++ altParams)
|
||||
let φ ← info.splitWith (mkSort .zero) fun _suff _expAltType idx altFVars =>
|
||||
return mkAppN hypsMVars[idx]! (joinParams ++ altFVars.altParams)
|
||||
withLocalDecl (← mkFreshUserName `h) .default φ fun h => do
|
||||
-- NB: `mkJoinGoal` is not quite `goal.withNewProg` because we only take 4 args and clear
|
||||
-- the stateful hypothesis of the goal.
|
||||
|
||||
@@ -54,27 +54,27 @@ def altInfos (info : SplitInfo) : Array (Nat × Expr) := match info with
|
||||
def splitWith
|
||||
{n} [MonadLiftT MetaM n] [MonadControlT MetaM n] [Monad n] [MonadError n] [MonadEnv n] [MonadLog n]
|
||||
[AddMessageContext n] [MonadOptions n]
|
||||
(info : SplitInfo) (resTy : Expr) (onAlt : Name → Expr → Nat → Array Expr → n Expr) (useSplitter := false) : n Expr := match info with
|
||||
(info : SplitInfo) (resTy : Expr) (onAlt : Name → Expr → Nat → MatcherApp.TransformAltFVars → n Expr) (useSplitter := false) : n Expr := match info with
|
||||
| ite e => do
|
||||
let u ← getLevel resTy
|
||||
let c := e.getArg! 1
|
||||
let h := e.getArg! 2
|
||||
if useSplitter then -- dite is the "splitter" for ite
|
||||
let n ← liftMetaM <| mkFreshUserName `h
|
||||
let t ← withLocalDecl n .default c fun h => do mkLambdaFVars #[h] (← onAlt `isTrue resTy 0 #[])
|
||||
let e ← withLocalDecl n .default (mkNot c) fun h => do mkLambdaFVars #[h] (← onAlt `isFalse resTy 1 #[])
|
||||
let t ← withLocalDecl n .default c fun h => do mkLambdaFVars #[h] (← onAlt `isTrue resTy 0 { fields := #[h] })
|
||||
let e ← withLocalDecl n .default (mkNot c) fun h => do mkLambdaFVars #[h] (← onAlt `isFalse resTy 1 { fields := #[h] })
|
||||
return mkApp5 (mkConst ``_root_.dite [u]) resTy c h t e
|
||||
else
|
||||
let t ← onAlt `isTrue resTy 0 #[]
|
||||
let e ← onAlt `isFalse resTy 1 #[]
|
||||
let t ← onAlt `isTrue resTy 0 { fields := #[] }
|
||||
let e ← onAlt `isFalse resTy 1 { fields := #[] }
|
||||
return mkApp5 (mkConst ``_root_.ite [u]) resTy c h t e
|
||||
| dite e => do
|
||||
let u ← getLevel resTy
|
||||
let c := e.getArg! 1
|
||||
let h := e.getArg! 2
|
||||
let n ← liftMetaM <| mkFreshUserName `h
|
||||
let t ← withLocalDecl n .default c fun h => do mkLambdaFVars #[h] (← onAlt `isTrue resTy 0 #[h])
|
||||
let e ← withLocalDecl n .default (mkNot c) fun h => do mkLambdaFVars #[h] (← onAlt `isFalse resTy 1 #[h])
|
||||
let t ← withLocalDecl n .default c fun h => do mkLambdaFVars #[h] (← onAlt `isTrue resTy 0 { args := #[h], fields := #[h] })
|
||||
let e ← withLocalDecl n .default (mkNot c) fun h => do mkLambdaFVars #[h] (← onAlt `isFalse resTy 1 { args := #[h], fields := #[h] })
|
||||
return mkApp5 (mkConst ``_root_.dite [u]) resTy c h t e
|
||||
| matcher matcherApp => do
|
||||
let mask := matcherApp.discrs.map (·.isFVar)
|
||||
@@ -83,8 +83,8 @@ def splitWith
|
||||
(·.toExpr) <$> matcherApp.transform
|
||||
(useSplitter := useSplitter) (addEqualities := useSplitter) -- (freshenNames := true)
|
||||
(onMotive := fun xs _body => pure (absMotiveBody.instantiateRev (Array.mask mask xs)))
|
||||
(onAlt := fun idx expAltType params _alt => do
|
||||
onAlt ((`h).appendIndexAfter (idx+1)) expAltType idx params)
|
||||
(onAlt := fun idx expAltType altFVars _alt => do
|
||||
onAlt ((`h).appendIndexAfter (idx+1)) expAltType idx altFVars)
|
||||
|
||||
def simpDiscrs? (info : SplitInfo) (e : Expr) : SimpM (Option Simp.Result) := match info with
|
||||
| dite _ | ite _ => return none -- Tricky because we need to simultaneously rewrite `[Decidable c]`
|
||||
|
||||
@@ -201,6 +201,47 @@ private def forallAltTelescope'
|
||||
fun ys args _mask _bodyType => k ys args
|
||||
) k
|
||||
|
||||
/--
|
||||
Fvars/exprs introduced in the telescope of a matcher alternative during `transform`.
|
||||
|
||||
* `args` are the values passed to `instantiateLambda` on the original alt. They usually
|
||||
coincide with `fields`, but may include non-fvar values (e.g. `Unit.unit` for thunked alts).
|
||||
* `fields` are the constructor-field fvars (proper fvar subset of `args`).
|
||||
* `overlaps` are overlap-parameter fvars (splitter path only, for non-`casesOn` splitters).
|
||||
* `discrEqs` are discriminant-equation fvars from the matcher's own type (`numDiscrEqs`).
|
||||
* `extraEqs` are equation fvars added by the `addEqualities` flag.
|
||||
|
||||
**Example.** `transform` with `addEqualities := true` on a `Nat.casesOn` application
|
||||
`Nat.casesOn (motive := …) n alt₀ alt₁` opens alt telescopes:
|
||||
```
|
||||
Alt 0 (zero): (heq : n = Nat.zero) → motive Nat.zero
|
||||
⟹ { args := #[], fields := #[], extraEqs := #[heq] }
|
||||
|
||||
Alt 1 (succ): (k : Nat) → (heq : n = Nat.succ k) → motive (Nat.succ k)
|
||||
⟹ { args := #[k], fields := #[k], extraEqs := #[heq] }
|
||||
```
|
||||
-/
|
||||
structure TransformAltFVars where
|
||||
/-- Arguments for `instantiateLambda` on the original alternative (see example above).
|
||||
May include non-fvar values like `Unit.unit` for thunked alternatives. -/
|
||||
args : Array Expr := #[]
|
||||
/-- Constructor field fvars, i.e. the proper fvar subset of `args` (see example above). -/
|
||||
fields : Array Expr
|
||||
/-- Overlap parameter fvars (non-casesOn splitters only). -/
|
||||
overlaps : Array Expr := #[]
|
||||
/-- Discriminant equation fvars from the matcher's own type (`numDiscrEqs`). -/
|
||||
discrEqs : Array Expr := #[]
|
||||
/-- Extra equation fvars added by `addEqualities` (see `heq` in the example above). -/
|
||||
extraEqs : Array Expr := #[]
|
||||
|
||||
/-- The `altParams` that were used for `instantiateLambda alt altParams` inside `transform`. -/
|
||||
def TransformAltFVars.altParams (fvars : TransformAltFVars) : Array Expr :=
|
||||
fvars.args ++ fvars.discrEqs
|
||||
|
||||
/-- All proper fvars in binding order, matching the lambdas that `transform` wraps around the alt result. -/
|
||||
def TransformAltFVars.all (fvars : TransformAltFVars) : Array Expr :=
|
||||
fvars.fields ++ fvars.overlaps ++ fvars.discrEqs ++ fvars.extraEqs
|
||||
|
||||
/--
|
||||
Performs a possibly type-changing transformation to a `MatcherApp`.
|
||||
|
||||
@@ -229,7 +270,7 @@ def transform
|
||||
(addEqualities : Bool := false)
|
||||
(onParams : Expr → n Expr := pure)
|
||||
(onMotive : Array Expr → Expr → n Expr := fun _ e => pure e)
|
||||
(onAlt : Nat → Expr → Array Expr → Expr → n Expr := fun _ _ _ e => pure e)
|
||||
(onAlt : Nat → Expr → TransformAltFVars → Expr → n Expr := fun _ _ _ e => pure e)
|
||||
(onRemaining : Array Expr → n (Array Expr) := pure) :
|
||||
n MatcherApp := do
|
||||
|
||||
@@ -331,7 +372,7 @@ def transform
|
||||
let altParams := args ++ ys3
|
||||
let alt ← try instantiateLambda alt altParams
|
||||
catch _ => throwError "unexpected matcher application, insufficient number of parameters in alternative"
|
||||
let alt' ← onAlt altIdx altType altParams alt
|
||||
let alt' ← onAlt altIdx altType { args, fields := ys, overlaps := ys2, discrEqs := ys3, extraEqs := ys4 } alt
|
||||
mkLambdaFVars (ys ++ ys2 ++ ys3 ++ ys4) alt'
|
||||
if splitterAltInfo.hasUnitThunk then
|
||||
-- The splitter expects a thunked alternative, but we don't want the `x : Unit` to be in
|
||||
@@ -372,7 +413,7 @@ def transform
|
||||
let names ← lambdaTelescope alt fun xs _ => xs.mapM (·.fvarId!.getUserName)
|
||||
withUserNames xs names do
|
||||
let alt ← instantiateLambda alt xs
|
||||
let alt' ← onAlt altIdx altType xs alt
|
||||
let alt' ← onAlt altIdx altType { args := xs, fields := xs, extraEqs := ys4 } alt
|
||||
mkLambdaFVars (xs ++ ys4) alt'
|
||||
alts' := alts'.push alt'
|
||||
|
||||
@@ -446,7 +487,7 @@ def inferMatchType (matcherApp : MatcherApp) : MetaM MatcherApp := do
|
||||
}
|
||||
mkArrowN extraParams typeMatcherApp.toExpr
|
||||
)
|
||||
(onAlt := fun _altIdx expAltType _altParams alt => do
|
||||
(onAlt := fun _altIdx expAltType _altFVars alt => do
|
||||
let altType ← inferType alt
|
||||
let eq ← mkEq expAltType altType
|
||||
let proof ← mkFreshExprSyntheticOpaqueMVar eq
|
||||
|
||||
@@ -315,7 +315,7 @@ partial def foldAndCollect (oldIH newIH : FVarId) (isRecCall : Expr → Option E
|
||||
-- statement and the inferred alt types
|
||||
let dummyGoal := mkConst ``True []
|
||||
mkArrow eTypeAbst dummyGoal)
|
||||
(onAlt := fun _altIdx altType _altParams alt => do
|
||||
(onAlt := fun _altIdx altType _altFVars alt => do
|
||||
lambdaTelescope1 alt fun oldIH' alt => do
|
||||
forallBoundedTelescope altType (some 1) fun newIH' _goal' => do
|
||||
let #[newIH'] := newIH' | unreachable!
|
||||
@@ -333,7 +333,7 @@ partial def foldAndCollect (oldIH newIH : FVarId) (isRecCall : Expr → Option E
|
||||
(onMotive := fun _motiveArgs motiveBody => do
|
||||
let some (_extra, body) := motiveBody.arrow? | throwError "motive not an arrow"
|
||||
M.eval (foldAndCollect oldIH newIH isRecCall body))
|
||||
(onAlt := fun _altIdx altType _altParams alt => do
|
||||
(onAlt := fun _altIdx altType _altFVars alt => do
|
||||
lambdaTelescope1 alt fun oldIH' alt => do
|
||||
-- We don't have suitable newIH around here, but we don't care since
|
||||
-- we just want to fold calls. So lets create a fake one.
|
||||
@@ -691,7 +691,7 @@ partial def buildInductionBody (toErase toClear : Array FVarId) (goal : Expr)
|
||||
(addEqualities := true)
|
||||
(onParams := (foldAndCollect oldIH newIH isRecCall ·))
|
||||
(onMotive := fun xs _body => pure (absMotiveBody.beta (Array.mask mask xs)))
|
||||
(onAlt := fun altIdx expAltType _altParams alt => M2.branch do
|
||||
(onAlt := fun altIdx expAltType _altFVars alt => M2.branch do
|
||||
lambdaTelescope1 alt fun oldIH' alt => do
|
||||
forallBoundedTelescope expAltType (some 1) fun newIH' goal' => do
|
||||
let #[newIH'] := newIH' | unreachable!
|
||||
@@ -714,7 +714,7 @@ partial def buildInductionBody (toErase toClear : Array FVarId) (goal : Expr)
|
||||
(addEqualities := true)
|
||||
(onParams := (foldAndCollect oldIH newIH isRecCall ·))
|
||||
(onMotive := fun xs _body => pure (absMotiveBody.beta (Array.mask mask xs)))
|
||||
(onAlt := fun altIdx expAltType _altParams alt => M2.branch do
|
||||
(onAlt := fun altIdx expAltType _altFVars alt => M2.branch do
|
||||
withRewrittenMotiveArg expAltType (rwMatcher altIdx) fun expAltType' =>
|
||||
buildInductionBody toErase toClear expAltType' oldIH newIH isRecCall alt)
|
||||
return matcherApp'.toExpr
|
||||
|
||||
Reference in New Issue
Block a user