Compare commits

...

31 Commits

Author SHA1 Message Date
Joachim Breitner
035b886ae7 Merge branch 'master' of https://github.com/leanprover/lean4 into joachim/splitter-via-match 2025-12-04 11:17:02 +01:00
Joachim Breitner
4988c6e302 Merge branch 'joachim/bench-nat-match' into joachim/splitter-via-match 2025-12-04 09:05:18 +01:00
Joachim Breitner
caf137cf55 Merge branch 'master' of https://github.com/leanprover/lean4 into joachim/bench-nat-match 2025-12-04 09:00:05 +01:00
Joachim Breitner
d1d23b80dd test: add big match on nat lit benchmarks
This PR adds two benchmarks for elaborating match statements of many
`Nat` literals, one without and one with splitter generation.
2025-12-04 08:58:19 +01:00
Joachim Breitner
9bb4009689 Remove unused code, move common code to helper modules 2025-11-21 10:31:59 +01:00
Joachim Breitner
2247c98e77 Merge branch 'joachim/reducibility-replay' into joachim/splitter-via-match 2025-11-21 10:11:09 +01:00
Joachim Breitner
f828fd517b feat: allow setting reducibilityCoreExt in async contexts
This PR allows setting reducibilityCoreExt in async contexts (e.g. when
using `mkSparseCasesOn` in a realizable definition)
2025-11-21 10:07:29 +01:00
Joachim Breitner
5aebd9682c Don’t add splitter to matcherExt 2025-11-21 10:06:35 +01:00
Joachim Breitner
230fb8bf6e Reduce diff 2025-11-21 09:59:05 +01:00
Joachim Breitner
a1854d14ae Avoid needsSplitter field 2025-11-20 10:02:44 +01:00
Joachim Breitner
84fcfab170 Merge branch 'master' of https://github.com/leanprover/lean4 into joachim/splitter-via-match 2025-11-20 09:50:13 +01:00
Joachim Breitner
423db875c7 Update test 2025-11-19 09:31:15 +01:00
Joachim Breitner
039339a759 Heed overlaps in new code 2025-11-19 09:20:03 +01:00
Joachim Breitner
95d7bbedaa Merge branch 'master' of https://github.com/leanprover/lean4 into joachim/splitter-via-match 2025-11-18 15:56:04 +01:00
Joachim Breitner
c685c19111 Merge commit 'f6e580ccf8997d6ba65424377dc2f5e65d676153^' into joachim/splitter-via-match 2025-11-18 15:50:59 +01:00
Joachim Breitner
c6a449e813 Merge commit 'be6457284a200ad8936f97a57c88dcd8f8c1164e' into joachim/splitter-via-match 2025-11-18 15:38:17 +01:00
Joachim Breitner
73be06937a Merge branch 'joachim/splitter-refactor' into joachim/splitter-via-match 2025-11-18 10:48:50 +01:00
Joachim Breitner
6ebb738bd8 refactor: extract functionality from Match.MatchEqs
This PR extracts two modules from `Match.MatchEqs`, in preparation of #11220
and to use the module system to draw clear boundaries between concerns
here.
2025-11-18 10:46:40 +01:00
Joachim Breitner
e69c1da44d Merge branch 'nightly-with-mathlib' of https://github.com/leanprover/lean4 into joachim/splitter-via-match 2025-11-18 10:34:22 +01:00
Joachim Breitner
0b9db11eeb Spurious newline 2025-11-17 19:03:32 +01:00
Joachim Breitner
ab44217947 More dead code 2025-11-17 19:01:30 +01:00
Joachim Breitner
c8d1b464a0 Not needed 2025-11-17 18:56:52 +01:00
Joachim Breitner
6d7b71e3f7 Merge branch 'joachim/realizeConst_withDeclNameForAuxNaming' into joachim/splitter-via-match 2025-11-17 18:56:41 +01:00
Joachim Breitner
577fb913b3 Add test case 2025-11-17 18:56:35 +01:00
Joachim Breitner
9efa4831c7 fix: let realizeConst run withDeclNameForAuxNaming
This PRs lets `realizeConst` use `withDeclNameForAuxNaming` so that
auxilary definitions created there get non-clashing names.
2025-11-17 18:47:33 +01:00
Joachim Breitner
04c56b31c9 Remove dead code 2025-11-17 18:46:06 +01:00
Joachim Breitner
05521dea02 Seems to work 2025-11-17 18:33:32 +01:00
Joachim Breitner
1cb7527e30 progress 2025-11-17 17:51:27 +01:00
Joachim Breitner
e8d867b670 Add file 2025-11-17 17:01:22 +01:00
Joachim Breitner
6b29b12009 stash 2025-11-17 16:58:38 +01:00
Joachim Breitner
a1155dea35 Stash 2025-11-17 14:39:06 +01:00
14 changed files with 178 additions and 314 deletions

View File

@@ -8,10 +8,12 @@ module
prelude
public import Lean.Elab.PreDefinition.FixedParams
import Lean.Elab.PreDefinition.EqnsUtils
import Lean.Meta.Tactic.Split
import Lean.Meta.Tactic.CasesOnStuckLHS
import Lean.Meta.Tactic.Delta
import Lean.Meta.Tactic.Simp.Main
import Lean.Meta.Tactic.Delta
import Lean.Meta.Tactic.CasesOnStuckLHS
import Lean.Meta.Tactic.Split
namespace Lean.Elab
open Meta

View File

@@ -213,7 +213,9 @@ public def mkCasesOnSameCtor (declName : Name) (indName : Name) : MetaM Unit :=
numDiscrs := info.numIndices + 3
altInfos
uElimPos? := some 0
discrInfos := #[{}, {}, {}]}
discrInfos := #[{}, {}, {}]
overlaps := {}
}
-- Compare attributes with `mkMatcherAuxDefinition`
withExporting (isExporting := !isPrivateName declName) do

View File

@@ -319,7 +319,7 @@ public partial def mkBelowMatcher (matcherApp : MatcherApp) (belowParams : Array
(ctx : RecursionContext) (transformAlt : RecursionContext Expr MetaM Expr) :
MetaM (Option (Expr × MetaM Unit)) :=
withTraceNode `Meta.IndPredBelow.match (return m!"{exceptEmoji ·} {matcherApp.toExpr} and {belowParams}") do
let mut input getMkMatcherInputInContext matcherApp
let mut input getMkMatcherInputInContext matcherApp (unfoldNamed := false)
let mut discrs := matcherApp.discrs
let mut matchTypeAdd := #[] -- #[(discrIdx, ), ...]
let mut i := discrs.size

View File

@@ -150,6 +150,11 @@ structure Alt where
After we perform additional case analysis, their types become definitionally equal.
-/
cnstrs : List (Expr × Expr)
/--
Indices of previous alternatives that this alternative expects a not-that-proofs.
(When producing a splitter, and in the future also for source-level overlap hypotheses.)
-/
notAltIdxs : Array Nat
deriving Inhabited
namespace Alt

View File

@@ -12,7 +12,11 @@ public import Lean.Meta.GeneralizeTelescope
public import Lean.Meta.Match.Basic
public import Lean.Meta.Match.MatcherApp.Basic
public import Lean.Meta.Match.MVarRenaming
public import Lean.Meta.Match.MVarRenaming
import Lean.Meta.Match.SimpH
import Lean.Meta.Match.SolveOverlap
import Lean.Meta.HasNotBit
import Lean.Meta.Match.NamedPatterns
public section
@@ -92,34 +96,62 @@ where
/-- Given a list of `AltLHS`, create a minor premise for each one, convert them into `Alt`, and then execute `k` -/
private def withAlts {α} (motive : Expr) (discrs : Array Expr) (discrInfos : Array DiscrInfo)
(lhss : List AltLHS) (k : List Alt Array Expr Array AltParamInfo MetaM α) : MetaM α :=
loop lhss [] #[] #[]
(lhss : List AltLHS) (isSplitter : Option Overlaps)
(k : List Alt Array Expr Array AltParamInfo MetaM α) : MetaM α :=
loop lhss [] #[] #[] #[]
where
mkMinorType (xs : Array Expr) (lhs : AltLHS) : MetaM Expr :=
mkSplitterHyps (idx : Nat) (lhs : AltLHS) (notAlts : Array Expr) : MetaM (Array Expr × Array Nat) := do
withExistingLocalDecls lhs.fvarDecls do
let patterns lhs.patterns.toArray.mapM (Pattern.toExpr · (annotate := true))
let mut hs := #[]
let mut notAltIdxs := #[]
for overlappingIdx in isSplitter.get!.overlapping idx do
let notAlt := notAlts[overlappingIdx]!
let h instantiateForall notAlt patterns
if let some h simpH? h patterns.size then
notAltIdxs := notAltIdxs.push overlappingIdx
hs := hs.push h
trace[Meta.Match.debug] "hs for {lhs.ref}: {hs}"
return (hs, notAltIdxs)
mkMinorType (xs : Array Expr) (lhs : AltLHS) (notAltHs : Array Expr): MetaM Expr :=
withExistingLocalDecls lhs.fvarDecls do
let args lhs.patterns.toArray.mapM (Pattern.toExpr · (annotate := true))
let minorType := mkAppN motive args
withEqs discrs args discrInfos fun eqs => do
mkForallFVars (xs ++ eqs) minorType
let minorType mkForallFVars eqs minorType
let minorType mkArrowN notAltHs minorType
mkForallFVars xs minorType
loop (lhss : List AltLHS) (alts : List Alt) (minors : Array Expr) (altInfos : Array AltParamInfo) : MetaM α := do
mkNotAlt (xs : Array Expr) (lhs : AltLHS) : MetaM Expr := do
withExistingLocalDecls lhs.fvarDecls do
let mut notAlt := mkConst ``False
for discr in discrs.reverse, pattern in lhs.patterns.reverse do
notAlt mkArrow ( mkEqHEq discr ( pattern.toExpr)) notAlt
notAlt mkForallFVars (discrs ++ xs) notAlt
return notAlt
loop (lhss : List AltLHS) (alts : List Alt) (minors : Array Expr) (altInfos : Array AltParamInfo) (notAlts : Array Expr) : MetaM α := do
match lhss with
| [] => k alts.reverse minors altInfos
| lhs::lhss =>
let xs := lhs.fvarDecls.toArray.map LocalDecl.toExpr
let minorType mkMinorType xs lhs
let hasParams := !xs.isEmpty || discrInfos.any fun info => info.hName?.isSome
let minorType := if hasParams then minorType else mkSimpleThunkType minorType
let idx := alts.length
let xs := lhs.fvarDecls.toArray.map LocalDecl.toExpr
let (notAltHs, notAltIdxs) if isSplitter.isSome then mkSplitterHyps idx lhs notAlts else pure (#[], #[])
let minorType mkMinorType xs lhs notAltHs
let notAlt mkNotAlt xs lhs
let hasParams := !xs.isEmpty || !notAltHs.isEmpty || discrInfos.any fun info => info.hName?.isSome
let minorType := if hasParams then minorType else mkSimpleThunkType minorType
let minorName := (`h).appendIndexAfter (idx+1)
trace[Meta.Match.debug] "minor premise {minorName} : {minorType}"
withLocalDeclD minorName minorType fun minor => do
let rhs := if hasParams then mkAppN minor xs else mkApp minor (mkConst `Unit.unit)
let minors := minors.push minor
let altInfos := altInfos.push { numFields := xs.size, numOverlaps := 0, hasUnitThunk := !hasParams }
let altInfos := altInfos.push { numFields := xs.size, numOverlaps := notAltHs.size, hasUnitThunk := !hasParams }
let fvarDecls lhs.fvarDecls.mapM instantiateLocalDeclMVars
let alts := { ref := lhs.ref, idx := idx, rhs := rhs, fvarDecls := fvarDecls, patterns := lhs.patterns, cnstrs := [] } :: alts
loop lhss alts minors altInfos
let alt := { ref := lhs.ref, idx := idx, rhs := rhs, fvarDecls := fvarDecls, patterns := lhs.patterns, cnstrs := [], notAltIdxs := notAltIdxs }
let alts := alt :: alts
loop lhss alts minors altInfos (notAlts.push notAlt)
structure State where
/-- Used alternatives -/
@@ -338,7 +370,7 @@ where
return (p, (lhs, rhs) :: cnstrs)
/--
Solve pending alternative constraints.
Solve pending alternative constraints and overlap assumptions.
If all constraints can be solved perform assignment `mvarId := alt.rhs`, else throw error.
-/
private partial def solveCnstrs (mvarId : MVarId) (alt : Alt) : StateRefT State MetaM Unit := do
@@ -350,13 +382,19 @@ where
| none =>
let alt filterTrivialCnstrs alt
if alt.cnstrs.isEmpty then
let eType inferType alt.rhs
let targetType mvarId.getType
unless ( isDefEqGuarded targetType eType) do
trace[Meta.Match.match] "assignGoalOf failed {eType} =?= {targetType}"
throwErrorAt alt.ref "Dependent elimination failed: Type mismatch when solving this alternative: it {← mkHasTypeButIsExpectedMsg eType targetType}"
mvarId.assign alt.rhs
modify fun s => { s with used := s.used.insert alt.idx }
mvarId.withContext do
let eType inferType alt.rhs
let (notAltsMVarIds, _, eType) forallMetaBoundedTelescope eType alt.notAltIdxs.size
unless notAltsMVarIds.size = alt.notAltIdxs.size do
throwErrorAt alt.ref "Incorrect number of overlap hypotheses in the right-hand-side, expected {alt.notAltIdxs.size}:{indentExpr eType}"
let targetType mvarId.getType
unless ( isDefEqGuarded targetType eType) do
trace[Meta.Match.match] "assignGoalOf failed {eType} =?= {targetType}"
throwErrorAt alt.ref "Dependent elimination failed: Type mismatch when solving this alternative: it {← mkHasTypeButIsExpectedMsg eType targetType}"
for notAltMVarId in notAltsMVarIds do
solveOverlap notAltMVarId.mvarId!
mvarId.assign (mkAppN alt.rhs notAltsMVarIds)
modify fun s => { s with used := s.used.insert alt.idx }
else
trace[Meta.Match.match] "alt has unsolved cnstrs:\n{← alt.toMessageData}"
let mut msg := m!"Dependent match elimination failed: Could not solve constraints"
@@ -636,7 +674,7 @@ private def processConstructor (p : Problem) : MetaM (Array Problem) := do
| .var _ :: _ => expandVarIntoCtor alt ctorName
| .inaccessible _ :: _ => processInaccessibleAsCtor alt ctorName
| _ => unreachable!
return { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples }
return { p with mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples }
else
-- A catch-all case
let subst := subgoal.subst
@@ -647,7 +685,7 @@ private def processConstructor (p : Problem) : MetaM (Array Problem) := do
| .ctor .. :: _ => false
| _ => true
let newAlts := newAlts.map fun alt => alt.applyFVarSubst subst
return { mvarId := subgoal.mvarId, alts := newAlts, vars := newVars, examples := examples }
return { p with mvarId := subgoal.mvarId, alts := newAlts, vars := newVars, examples := examples }
private def processNonVariable (p : Problem) : MetaM Problem := withGoalOf p do
let x :: xs := p.vars | unreachable!
@@ -708,7 +746,7 @@ private def processValue (p : Problem) : MetaM (Array Problem) := do
alt.replaceFVarId fvarId value
| _ => unreachable!
let newVars := xs.map fun x => x.applyFVarSubst subst
return { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples }
return { p with mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples }
else
-- else branch for value
let newAlts := p.alts.filter isFirstPatternVar
@@ -764,7 +802,7 @@ private def processArrayLit (p : Problem) : MetaM (Array Problem) := do
let α getArrayArgType <| subst.apply x
expandVarIntoArrayLit { alt with patterns := ps } fvarId α size
| _ => unreachable!
return { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples }
return { p with mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples }
else
-- else branch
let newAlts := p.alts.filter isFirstPatternVar
@@ -1018,7 +1056,7 @@ private builtin_initialize matcherExt : EnvExtension (PHashMap MatcherKey Name)
/-- Similar to `mkAuxDefinition`, but uses the cache `matcherExt`.
It also returns an Boolean that indicates whether a new matcher function was added to the environment or not. -/
def mkMatcherAuxDefinition (name : Name) (type : Expr) (value : Expr) : MetaM (Expr × Option (MatcherInfo MetaM Unit)) := do
def mkMatcherAuxDefinition (name : Name) (type : Expr) (value : Expr) (isSplitter : Bool) : MetaM (Expr × Option (MatcherInfo MetaM Unit)) := do
trace[Meta.Match.debug] "{name} : {type} := {value}"
let compile := bootstrap.genMatcherCode.get ( getOptions)
let result Closure.mkValueTypeClosure type value (zetaDelta := false)
@@ -1026,10 +1064,12 @@ def mkMatcherAuxDefinition (name : Name) (type : Expr) (value : Expr) : MetaM (E
let mkMatcherConst name :=
mkAppN (mkConst name result.levelArgs.toList) result.exprArgs
let key := { value := result.value, compile, isPrivate := env.header.isModule && isPrivateName name }
let mut nameNew? := (matcherExt.getState env).find? key
if nameNew?.isNone && key.isPrivate then
-- private contexts may reuse public matchers
nameNew? := (matcherExt.getState env).find? { key with isPrivate := false }
let mut nameNew? := none
unless isSplitter do
nameNew? := (matcherExt.getState env).find? key
if nameNew?.isNone && key.isPrivate then
-- private contexts may reuse public matchers
nameNew? := (matcherExt.getState env).find? { key with isPrivate := false }
match nameNew? with
| some nameNew => return (mkMatcherConst nameNew, none)
| none =>
@@ -1040,8 +1080,9 @@ def mkMatcherAuxDefinition (name : Name) (type : Expr) (value : Expr) : MetaM (E
-- matcher bodies should always be exported, if not private anyway
withExporting do
addDecl decl
modifyEnv fun env => matcherExt.modifyState env fun s => s.insert key name
addMatcherInfo name mi
unless isSplitter do
modifyEnv fun env => matcherExt.modifyState env fun s => s.insert key name
addMatcherInfo name mi
setInlineAttribute name
enableRealizationsForConst name
if compile then
@@ -1053,6 +1094,7 @@ structure MkMatcherInput where
matchType : Expr
discrInfos : Array DiscrInfo
lhss : List AltLHS
isSplitter : Option Overlaps := none
def MkMatcherInput.numDiscrs (m : MkMatcherInput) :=
m.discrInfos.size
@@ -1093,7 +1135,7 @@ The generated matcher has the structure described at `MatcherInfo`. The motive a
where `v` is a universe parameter or 0 if `B[a_1, ..., a_n]` is a proposition.
-/
def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor input do
let matcherName, matchType, discrInfos, lhss := input
let {matcherName, matchType, discrInfos, lhss, isSplitter} := input
let numDiscrs := discrInfos.size
checkNumPatterns numDiscrs lhss
forallBoundedTelescope matchType numDiscrs fun discrs matchTypeBody => do
@@ -1116,7 +1158,7 @@ def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor
| negSucc n => succ n
```
which is defined **before** `Int.decLt` -/
let (matcher, addMatcher) mkMatcherAuxDefinition matcherName type val
let (matcher, addMatcher) mkMatcherAuxDefinition matcherName type val (isSplitter := input.isSplitter.isSome)
trace[Meta.Match.debug] "matcher levels: {matcher.getAppFn.constLevels!}, uElim: {uElimGen}"
let uElimPos? getUElimPos? matcher.getAppFn.constLevels! uElimGen
discard <| isLevelDefEq uElimGen uElim
@@ -1152,7 +1194,7 @@ def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor
let isEqMask eqs.mapM fun eq => return ( inferType eq).isEq
return (mvarType, isEqMask)
trace[Meta.Match.debug] "target: {mvarType}"
withAlts motive discrs discrInfos lhss fun alts minors altInfos => do
withAlts motive discrs discrInfos lhss isSplitter fun alts minors altInfos => do
let mvar mkFreshExprMVar mvarType
trace[Meta.Match.debug] "goal\n{mvar.mvarId!}"
let examples := discrs'.toList.map fun discr => Example.var discr.fvarId!
@@ -1176,7 +1218,7 @@ def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor
else
let mvarType := mkAppN motive discrs
trace[Meta.Match.debug] "target: {mvarType}"
withAlts motive discrs discrInfos lhss fun alts minors altInfos => do
withAlts motive discrs discrInfos lhss isSplitter fun alts minors altInfos => do
let mvar mkFreshExprMVar mvarType
let examples := discrs.toList.map fun discr => Example.var discr.fvarId!
let (_, s) (process { mvarId := mvar.mvarId!, vars := discrs.toList, alts := alts, examples := examples }).run {}
@@ -1185,7 +1227,7 @@ def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor
let val mkLambdaFVars args mvar
mkMatcher type val altInfos s
def getMkMatcherInputInContext (matcherApp : MatcherApp) : MetaM MkMatcherInput := do
def getMkMatcherInputInContext (matcherApp : MatcherApp) (unfoldNamed : Bool) : MetaM MkMatcherInput := do
let matcherName := matcherApp.matcherName
let some matcherInfo getMatcherInfo? matcherName
| throwError "Internal error during match expression elaboration: Could not find a matcher named `{matcherName}`"
@@ -1204,6 +1246,7 @@ def getMkMatcherInputInContext (matcherApp : MatcherApp) : MetaM MkMatcherInput
let lhss forallBoundedTelescope matcherType (some matcherApp.alts.size) fun alts _ =>
alts.mapM fun alt => do
let ty inferType alt
let ty if unfoldNamed then unfoldNamedPattern ty else pure ty
forallTelescope ty fun xs body => do
let xs xs.filterM fun x => dependsOn body x.fvarId!
body.withApp fun _ args => do
@@ -1217,18 +1260,17 @@ def getMkMatcherInputInContext (matcherApp : MatcherApp) : MetaM MkMatcherInput
return { matcherName, matchType, discrInfos := matcherInfo.discrInfos, lhss := lhss.toList }
/-- This function is only used for testing purposes -/
def withMkMatcherInput (matcherName : Name) (k : MkMatcherInput MetaM α) : MetaM α := do
def withMkMatcherInput (matcherName : Name) (unfoldNamed : Bool) (k : MkMatcherInput MetaM α) : MetaM α := do
let some matcherInfo getMatcherInfo? matcherName
| throwError "Internal error during match expression elaboration: Could not find a matcher named `{matcherName}`"
| throwError "withMkMatcherInput: {.ofConstName matcherName} is not a matcher"
let matcherConst getConstInfo matcherName
forallBoundedTelescope matcherConst.type (some matcherInfo.arity) fun xs _ => do
let matcherApp mkConstWithLevelParams matcherConst.name
let matcherApp := mkAppN matcherApp xs
let some matcherApp matchMatcherApp? matcherApp
| throwError "Internal error during match expression elaboration: Could not find a matcher app named `{matcherApp}`"
let mkMatcherInput getMkMatcherInputInContext matcherApp
k mkMatcherInput
forallBoundedTelescope matcherConst.type matcherInfo.arity fun xs _ => do
let matcherApp mkConstWithLevelParams matcherConst.name
let matcherApp := mkAppN matcherApp xs
let some matcherApp matchMatcherApp? matcherApp
| throwError "withMkMatcherInput: {.ofConstName matcherName} does not produce a matcher application"
let mkMatcherInput getMkMatcherInputInContext matcherApp unfoldNamed
k mkMatcherInput
end Match

View File

@@ -110,220 +110,6 @@ where
(throwError "failed to generate equality theorems for `match` expression `{matchDeclName}`\n{MessageData.ofGoal mvarId}")
subgoals.forM (go · (depth+1))
/-- Construct new local declarations `xs` with types `altTypes`, and then execute `f xs` -/
private partial def withSplitterAlts (altTypes : Array Expr) (f : Array Expr MetaM α) : MetaM α := do
let rec go (i : Nat) (xs : Array Expr) : MetaM α := do
if h : i < altTypes.size then
let hName := (`h).appendIndexAfter (i+1)
withLocalDeclD hName altTypes[i] fun x =>
go (i+1) (xs.push x)
else
f xs
go 0 #[]
private abbrev ConvertM := ReaderT (FVarIdMap (Expr × AltParamInfo × Array Bool)) $ StateRefT (Array MVarId) MetaM
/--
Construct a proof for the splitter generated by `mkEquationsFor`.
The proof uses the definition of the `match`-declaration as a template (argument `template`).
- `alts` are free variables corresponding to alternatives of the `match` auxiliary declaration being processed.
- `altNews` are the new free variables which contains additional hypotheses that ensure they are only used
when the previous overlapping alternatives are not applicable.
- `altInfos` refers to the splitter -/
private partial def mkSplitterProof (matchDeclName : Name) (template : Expr) (alts altsNew : Array Expr)
(altInfos : Array AltParamInfo) (altArgMasks : Array (Array Bool)) : MetaM Expr := do
trace[Meta.Match.matchEqs] "proof template: {template}"
let map := mkMap
let (proof, mvarIds) convertTemplate template |>.run map |>.run #[]
trace[Meta.Match.matchEqs] "splitter proof: {proof}"
for mvarId in mvarIds do
let mvarId mvarId.tryClearMany (alts.map (·.fvarId!))
solveOverlap mvarId
instantiateMVars proof
where
mkMap : FVarIdMap (Expr × AltParamInfo × Array Bool) := Id.run do
let mut m := {}
for alt in alts, altNew in altsNew, altInfo in altInfos, argMask in altArgMasks do
m := m.insert alt.fvarId! (altNew, altInfo, argMask)
return m
trimFalseTrail (argMask : Array Bool) : Array Bool :=
if argMask.isEmpty then
argMask
else if !argMask.back! then
trimFalseTrail argMask.pop
else
argMask
/--
Auxiliary function used at `convertTemplate` to decide whether to use `convertCastEqRec`.
See `convertCastEqRec`. -/
isCastEqRec (e : Expr) : ConvertM Bool := do
-- TODO: we do not handle `Eq.rec` since we never found an example that needed it.
-- If we find one we must extend `convertCastEqRec`.
unless e.isAppOf ``Eq.ndrec do return false
unless e.getAppNumArgs > 6 do return false
for arg in e.getAppArgs[6...*] do
if arg.isFVar && ( read).contains arg.fvarId! then
return true
return true
/--
Auxiliary function used at `convertTemplate`. It is needed when the auxiliary `match` declaration had to refine the type of its
minor premises during dependent pattern match. For an example, consider
```
inductive Foo : Nat → Type _
| nil : Foo 0
| cons (t: Foo l): Foo l
def Foo.bar (t₁: Foo l₁): Foo l₂ → Bool
| cons s₁ => t₁.bar s₁
| _ => false
attribute [simp] Foo.bar
```
The auxiliary `Foo.bar.match_1` is of the form
```
def Foo.bar.match_1.{u_1} : {l₂ : Nat} →
(t₂ : Foo l₂) →
(motive : Foo l₂ → Sort u_1) →
(t₂ : Foo l₂) → ((s₁ : Foo l₂) → motive (Foo.cons s₁)) → ((x : Foo l₂) → motive x) → motive t₂ :=
fun {l₂} t₂ motive t₂_1 h_1 h_2 =>
(fun t₂_2 =>
Foo.casesOn (motive := fun a x => l₂ = a → t₂_1 ≍ x → motive t₂_1) t₂_2
(fun h =>
Eq.ndrec (motive := fun {l₂} =>
(t₂ t₂ : Foo l₂) →
(motive : Foo l₂ → Sort u_1) →
((s₁ : Foo l₂) → motive (Foo.cons s₁)) → ((x : Foo l₂) → motive x) → t₂ ≍ Foo.nil → motive t₂)
(fun t₂ t₂ motive h_1 h_2 h => Eq.symm (eq_of_heq h) ▸ h_2 Foo.nil) (Eq.symm h) t₂ t₂_1 motive h_1 h_2) --- HERE
fun {l} t h =>
Eq.ndrec (motive := fun {l} => (t : Foo l) → t₂_1 ≍ Foo.cons t → motive t₂_1)
(fun t h => Eq.symm (eq_of_heq h) ▸ h_1 t) h t)
t₂_1 (Eq.refl l₂) (HEq.refl t₂_1)
```
The `HERE` comment marks the place where the type of `Foo.bar.match_1` minor premises `h_1` and `h_2` is being "refined"
using `Eq.ndrec`.
This function will adjust the motive and minor premise of the `Eq.ndrec` to reflect the new minor premises used in the
corresponding splitter theorem.
We may have to extend this function to handle `Eq.rec` too.
This function was added to address issue #1179
-/
convertCastEqRec (e : Expr) : ConvertM Expr := do
assert! ( isCastEqRec e)
e.withApp fun f args => do
let mut argsNew := args
let mut isAlt := #[]
for i in 6...args.size do
let arg := argsNew[i]!
if arg.isFVar then
match ( read).get? arg.fvarId! with
| some (altNew, _, _) =>
argsNew := argsNew.set! i altNew
trace[Meta.Match.matchEqs] "arg: {arg} : {← inferType arg}, altNew: {altNew} : {← inferType altNew}"
isAlt := isAlt.push true
| none =>
argsNew := argsNew.set! i ( convertTemplate arg)
isAlt := isAlt.push false
else
argsNew := argsNew.set! i ( convertTemplate arg)
isAlt := isAlt.push false
assert! isAlt.size == args.size - 6
let rhs := args[4]!
let motive := args[2]!
-- Construct new motive using the splitter theorem minor premise types.
let motiveNew lambdaTelescope motive fun motiveArgs body => do
unless motiveArgs.size == 1 do
throwError "unexpected `Eq.ndrec` motive while creating splitter/eliminator theorem for `{matchDeclName}`, expected lambda with 1 binder{indentExpr motive}"
let x := motiveArgs[0]!
forallTelescopeReducing body fun motiveTypeArgs resultType => do
unless motiveTypeArgs.size >= isAlt.size do
throwError "unexpected `Eq.ndrec` motive while creating splitter/eliminator theorem for `{matchDeclName}`, expected arrow with at least #{isAlt.size} binders{indentExpr body}"
let rec go (i : Nat) (motiveTypeArgsNew : Array Expr) : ConvertM Expr := do
assert! motiveTypeArgsNew.size == i
if h : i < motiveTypeArgs.size then
let motiveTypeArg := motiveTypeArgs[i]
if i < isAlt.size && isAlt[i]! then
let altNew := argsNew[6+i]! -- Recall that `Eq.ndrec` has 6 arguments
let altTypeNew inferType altNew
trace[Meta.Match.matchEqs] "altNew: {altNew} : {altTypeNew}"
-- Replace `rhs` with `x` (the lambda binder in the motive)
let mut altTypeNewAbst := ( kabstract altTypeNew rhs).instantiate1 x
-- Replace args[6...(6+i)] with `motiveTypeArgsNew`
for j in *...i do
altTypeNewAbst := ( kabstract altTypeNewAbst argsNew[6+j]!).instantiate1 motiveTypeArgsNew[j]!
let localDecl motiveTypeArg.fvarId!.getDecl
withLocalDecl localDecl.userName localDecl.binderInfo altTypeNewAbst fun motiveTypeArgNew =>
go (i+1) (motiveTypeArgsNew.push motiveTypeArgNew)
else
go (i+1) (motiveTypeArgsNew.push motiveTypeArg)
else
mkLambdaFVars motiveArgs ( mkForallFVars motiveTypeArgsNew resultType)
go 0 #[]
trace[Meta.Match.matchEqs] "new motive: {motiveNew}"
unless ( isTypeCorrect motiveNew) do
throwError "failed to construct new type correct motive for `Eq.ndrec` while creating splitter/eliminator theorem for `{matchDeclName}`{indentExpr motiveNew}"
argsNew := argsNew.set! 2 motiveNew
-- Construct the new minor premise for the `Eq.ndrec` application.
-- First, we use `eqRecNewPrefix` to infer the new minor premise binders for `Eq.ndrec`
let eqRecNewPrefix := mkAppN f argsNew[*...3] -- `Eq.ndrec` minor premise is the fourth argument.
let .forallE _ minorTypeNew .. whnf ( inferType eqRecNewPrefix) | unreachable!
trace[Meta.Match.matchEqs] "new minor type: {minorTypeNew}"
let minor := args[3]!
let minorNew forallBoundedTelescope minorTypeNew isAlt.size fun minorArgsNew _ => do
let mut minorBodyNew := minor
-- We have to extend the mapping to make sure `convertTemplate` can "fix" occurrences of the refined minor premises
let mut m read
for h : i in *...isAlt.size do
if isAlt[i] then
-- `convertTemplate` will correct occurrences of the alternative
let alt := args[6+i]! -- Recall that `Eq.ndrec` has 6 arguments
let some (_, numParams, argMask) := m.get? alt.fvarId! | unreachable!
-- We add a new entry to `m` to make sure `convertTemplate` will correct the occurrences of the alternative
m := m.insert minorArgsNew[i]!.fvarId! (minorArgsNew[i]!, numParams, argMask)
unless minorBodyNew.isLambda do
throwError "unexpected `Eq.ndrec` minor premise while creating splitter/eliminator theorem for `{matchDeclName}`, expected lambda with at least #{isAlt.size} binders{indentExpr minor}"
minorBodyNew := minorBodyNew.bindingBody!
minorBodyNew := minorBodyNew.instantiateRev minorArgsNew
trace[Meta.Match.matchEqs] "minor premise new body before convertTemplate:{indentExpr minorBodyNew}"
minorBodyNew withReader (fun _ => m) <| convertTemplate minorBodyNew
trace[Meta.Match.matchEqs] "minor premise new body after convertTemplate:{indentExpr minorBodyNew}"
mkLambdaFVars minorArgsNew minorBodyNew
unless ( isTypeCorrect minorNew) do
throwError "failed to construct new type correct minor premise for `Eq.ndrec` while creating splitter/eliminator theorem for `{matchDeclName}`{indentExpr minorNew}"
argsNew := argsNew.set! 3 minorNew
-- trace[Meta.Match.matchEqs] "argsNew: {argsNew}"
trace[Meta.Match.matchEqs] "found cast target {e}"
return mkAppN f argsNew
convertTemplate (e : Expr) : ConvertM Expr :=
transform e fun e => do
if ( isCastEqRec e) then
return .done ( convertCastEqRec e)
else
let Expr.fvar fvarId .. := e.getAppFn | return .continue
let some (altNew, altParamInfo, argMask) := ( read).get? fvarId | return .continue
trace[Meta.Match.matchEqs] ">> argMask: {argMask}, altParamInfo: {repr altParamInfo}, e: {e}, alsNew: {altNew}, "
if altParamInfo.hasUnitThunk then
let eNew := mkApp altNew (mkConst ``Unit.unit)
return TransformStep.done eNew
let mut newArgs := #[]
let argMask := trimFalseTrail argMask
unless e.getAppNumArgs argMask.size do
throwError "unexpected occurrence of `match`-expression alternative (aka minor premise) while creating splitter/eliminator theorem for `{matchDeclName}`, minor premise is partially applied{indentExpr e}\npossible solution if you are matching on inductive families: add its indices as additional discriminants"
for arg in e.getAppArgs, includeArg in argMask do
if includeArg then
newArgs := newArgs.push arg
let eNew := mkAppN altNew newArgs
let (mvars, _, _) forallMetaBoundedTelescope ( inferType eNew) altParamInfo.numOverlaps (kind := MetavarKind.syntheticOpaque)
modify fun s => s ++ (mvars.map (·.mvarId!))
let eNew := mkAppN eNew mvars
return TransformStep.done eNew
/--
Create new alternatives (aka minor premises) by replacing `discrs` with `patterns` at `alts`.
Recall that `alts` depends on `discrs` when `numDiscrEqs > 0`, where `numDiscrEqs` is the number of discriminants
@@ -364,13 +150,15 @@ def getEquationsForImpl (matchDeclName : Name) : MetaM MatchEqns := do
-- `realizeConst` as well as for looking up the resultant environment extension state via
-- `getState`.
realizeConst matchDeclName splitterName (go baseName splitterName)
return matchEqnsExt.getState (asyncMode := .async .asyncEnv) (asyncDecl := splitterName) ( getEnv) |>.map.find! matchDeclName
match matchEqnsExt.getState (asyncMode := .async .asyncEnv) (asyncDecl := splitterName) ( getEnv) |>.map.find? matchDeclName with
| some eqns => return eqns
| none => throwError "failed to retrieve match equations for `{matchDeclName}` after realization"
where go baseName splitterName := withConfig (fun c => { c with etaStruct := .none }) do
let constInfo getConstInfo matchDeclName
let us := constInfo.levelParams.map mkLevelParam
let some matchInfo getMatcherInfo? matchDeclName | throwError "`{matchDeclName}` is not a matcher function"
let numDiscrEqs := getNumEqsFromDiscrInfos matchInfo.discrInfos
forallTelescopeReducing constInfo.type fun xs matchResultType => do
forallTelescopeReducing constInfo.type fun xs _matchResultType => do
let mut eqnNames := #[]
let params := xs[*...matchInfo.numParams]
let motive := xs[matchInfo.getMotivePos]!
@@ -379,16 +167,15 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no
let discrs := xs[firstDiscrIdx...(firstDiscrIdx + matchInfo.numDiscrs)]
let mut notAlts := #[]
let mut idx := 1
let mut splitterAltTypes := #[]
let mut splitterAltInfos := #[]
let mut altArgMasks := #[] -- masks produced by `forallAltTelescope`
for i in *...alts.size do
let altInfo := matchInfo.altInfos[i]!
let thmName := Name.str baseName eqnThmSuffixBase |>.appendIndexAfter idx
eqnNames := eqnNames.push thmName
let (notAlt, splitterAltType, splitterAltInfo, argMask)
let (notAlt, splitterAltInfo, argMask)
forallAltTelescope ( inferType alts[i]!) altInfo numDiscrEqs
fun ys eqs rhsArgs argMask altResultType => do
fun ys _eqs rhsArgs argMask altResultType => do
let patterns := altResultType.getAppArgs
let mut hs := #[]
for overlappedBy in matchInfo.overlaps.overlapping i do
@@ -397,15 +184,7 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no
if let some h simpH? h patterns.size then
hs := hs.push h
trace[Meta.Match.matchEqs] "hs: {hs}"
let splitterAltType mkForallFVars eqs altResultType
let splitterAltType mkArrowN hs splitterAltType
let splitterAltType mkForallFVars ys splitterAltType
let hasUnitThunk := splitterAltType == altResultType
let splitterAltType if hasUnitThunk then
mkArrow (mkConst ``Unit) splitterAltType
else
pure splitterAltType
let splitterAltType unfoldNamedPattern splitterAltType
let hasUnitThunk := ys.isEmpty && hs.isEmpty && numDiscrEqs = 0
let splitterAltInfo := { numFields := ys.size, numOverlaps := hs.size, hasUnitThunk }
-- Create a proposition for representing terms that do not match `patterns`
let mut notAlt := mkConst ``False
@@ -429,38 +208,38 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no
type := thmType
value := thmVal
}
return (notAlt, splitterAltType, splitterAltInfo, argMask)
return (notAlt, splitterAltInfo, argMask)
notAlts := notAlts.push notAlt
splitterAltTypes := splitterAltTypes.push splitterAltType
splitterAltInfos := splitterAltInfos.push splitterAltInfo
altArgMasks := altArgMasks.push argMask
trace[Meta.Match.matchEqs] "splitterAltType: {splitterAltType}"
idx := idx + 1
-- Define splitter with conditional/refined alternatives
withSplitterAlts splitterAltTypes fun altsNew => do
let splitterParams := params.toArray ++ #[motive] ++ discrs.toArray ++ altsNew
let splitterType mkForallFVars splitterParams matchResultType
trace[Meta.Match.matchEqs] "splitterType: {splitterType}"
let splitterVal
if ( isDefEq splitterType constInfo.type) then
pure <| mkConst constInfo.name us
else
let template := mkAppN (mkConst constInfo.name us) (params ++ #[motive] ++ discrs ++ alts)
let template deltaExpand template (· == constInfo.name)
let template := template.headBeta
mkLambdaFVars splitterParams ( mkSplitterProof matchDeclName template alts altsNew splitterAltInfos altArgMasks)
let splitterMatchInfo : MatcherInfo := { matchInfo with altInfos := splitterAltInfos }
let needsSplitter := !matchInfo.overlaps.isEmpty || (constInfo.type.find? (isNamedPattern )).isSome
if needsSplitter then
withMkMatcherInput matchDeclName (unfoldNamed := true) fun matcherInput => do
let matcherInput := { matcherInput with
matcherName := splitterName
isSplitter := some matchInfo.overlaps
}
let res Match.mkMatcher matcherInput
res.addMatcher -- TODO: Do not set matcherinfo for the splitter!
else
assert! matchInfo.altInfos == splitterAltInfos
-- This match statement does not need a splitter, we can use itself for that.
-- (We still have to generate a declaration to satisfy the realizable constant)
addAndCompile <| Declaration.defnDecl {
name := splitterName
levelParams := constInfo.levelParams
type := splitterType
value := splitterVal
type := constInfo.type
value := mkConst matchDeclName us
hints := .abbrev
safety := .safe
}
setInlineAttribute splitterName
let splitterMatchInfo := { matchInfo with altInfos := splitterAltInfos }
let result := { eqnNames, splitterName, splitterMatchInfo }
registerMatchEqns matchDeclName result
let result := { eqnNames, splitterName, splitterMatchInfo }
registerMatchEqns matchDeclName result
/- We generate the equations and splitter on demand, and do not save them on .olean files. -/
builtin_initialize matchCongrEqnsExt : EnvExtension (PHashMap Name (Array Name))

View File

@@ -67,6 +67,7 @@ def matchMatcherApp? [Monad m] [MonadEnv m] [MonadError m] (e : Expr) (alsoCases
matcherName := declName
matcherLevels := declLevels.toArray
uElimPos?, discrInfos, params, motive, discrs, alts, remaining, altInfos
overlaps := {} -- CasesOn constructor have no overlaps
}
return none

View File

@@ -23,6 +23,9 @@ structure Overlaps where
map : Std.HashMap Nat (Std.TreeSet Nat) := {}
deriving Inhabited, Repr
def Overlaps.isEmpty (o : Overlaps) : Bool :=
o.map.isEmpty
def Overlaps.insert (o : Overlaps) (overlapping overlapped : Nat) : Overlaps where
map := o.map.alter overlapped fun s? => some ((s?.getD {}).insert overlapping)
@@ -41,29 +44,32 @@ structure AltParamInfo where
numOverlaps : Nat
/-- Whether this alternatie has an artifcial `Unit` parameter -/
hasUnitThunk : Bool
deriving Inhabited, Repr
deriving Inhabited, Repr, BEq
/--
A "matcher" auxiliary declaration has the following structure:
- `numParams` parameters
- motive
- `numDiscrs` discriminators (aka major premises)
- `altInfos.size` alternatives (aka minor premises) with parameter structure information
- `uElimPos?` is `some pos` when the matcher can eliminate in different universe levels, and
`pos` is the position of the universe level parameter that specifies the elimination universe.
It is `none` if the matcher only eliminates into `Prop`.
- `overlaps` indicates which alternatives may overlap another
Information about the structure of a matcher declaration
-/
structure MatcherInfo where
/-- Number of parameters -/
numParams : Nat
/-- Number of discriminants -/
numDiscrs : Nat
/-- Parameter structure information for each alternative -/
altInfos : Array AltParamInfo
/--
`uElimPos?` is `some pos` when the matcher can eliminate in different universe levels, and
`pos` is the position of the universe level parameter that specifies the elimination universe.
It is `none` if the matcher only eliminates into `Prop`.
-/
uElimPos? : Option Nat
/--
`discrInfos[i] = { hName? := some h }` if the i-th discriminant was annotated with `h :`.
`discrInfos[i] = { hName? := some h }` if the i-th discriminant was annotated with `h :`.
-/
discrInfos : Array DiscrInfo
overlaps : Overlaps := {}
/--
(Conservative approximation of) which alternatives may overlap another.
-/
overlaps : Overlaps
deriving Inhabited, Repr
@[expose] def MatcherInfo.numAlts (info : MatcherInfo) : Nat :=

View File

@@ -1,5 +1,7 @@
#include "util/options.h"
// please update stage0
namespace lean {
options get_default_options() {
options opts;

View File

@@ -38,8 +38,8 @@ info: Vec.match_on_same_ctor.{u_1, u} {α : Type u}
/--
info: Vec.match_on_same_ctor.splitter.{u_1, u} {α : Type u}
{motive : {a : Nat} → (t t_1 : Vec α a) → t.ctorIdx = t_1.ctorIdx → Sort u_1} {a✝ : Nat} (t t✝ : Vec α a✝)
(h : t.ctorIdx = t✝.ctorIdx) (h_1 : Unit → motive nil nil ⋯)
(h_2 : (a : α) → (n : Nat) → (a_1 : Vec α n) → (a' : α) → (a'_1 : Vec α n) → motive (cons a a_1) (cons a' a'_1) ⋯) :
(h : t.ctorIdx = t✝.ctorIdx) (nil : Unit → motive nil nil ⋯)
(cons : (a : α) → {n : Nat} → (a_1 : Vec α n) → (a' : α) → (a'_1 : Vec α n) → motive (cons a a_1) (cons a' a'_1) ⋯) :
motive t t✝ h
-/
#guard_msgs in

View File

@@ -11,7 +11,7 @@ info: private def myTest.match_1.splitter.{u_1} : (motive : List Bool → Sort u
(x : List Bool) →
((x_1 : Bool) → (xs : List Bool) → x = x_1 :: xs → motive (x_1 :: xs)) → (x = [] → motive []) → motive x :=
fun motive x h_1 h_2 =>
List.casesOn (motive := fun x_1 => x = x_1 → motive x_1) x h_2 (fun head tail => h_1 head tail) ⋯
(fun x_1 => List.casesOn (motive := fun x_2 => x = x_2 → motive x_2) x_1 h_2 fun head tail => h_1 head tail) x
-/
#guard_msgs in
#print myTest.match_1.splitter

View File

@@ -1,7 +1,9 @@
import Lean
set_option linter.unusedVariables false
def checkWithMkMatcherInput (matcher : Lean.Name) : Lean.MetaM Unit :=
Lean.Meta.Match.withMkMatcherInput matcher fun input => do
Lean.Meta.Match.withMkMatcherInput matcher (unfoldNamed := false) fun input => do
let res Lean.Meta.Match.mkMatcher input
let origMatcher Lean.getConstInfo matcher
if not <| input.matcherName == matcher then

View File

@@ -9,6 +9,25 @@ def simple : Lean.Expr → Bool
| .sort _ => true
| _ => false
/--
info: def simple.match_1.{u_1} : (motive : Expr → Sort u_1) →
(x : Expr) → ((u : Level) → motive (sort u)) → ((x : Expr) → motive x) → motive x :=
fun motive x h_1 h_2 => simple._sparseCasesOn_1 x (fun u => h_1 u) fun h => h_2 x
-/
#guard_msgs in
#print simple.match_1
-- Check that the splitter re-uses the sparseCasesOn generated for the matcher:
/--
info: private def simple.match_1.splitter.{u_1} : (motive : Expr → Sort u_1) →
(x : Expr) →
((u : Level) → motive (sort u)) → ((x : Expr) → (∀ (u : Level), x = sort u → False) → motive x) → motive x :=
fun motive x h_1 h_2 => simple._sparseCasesOn_1 x (fun u => h_1 u) fun h => h_2 x ⋯
-/
#guard_msgs in
#print simple.match_1.splitter
def expensive : Lean.Expr Lean.Expr Bool
| .app (.app (.sort 1) (.sort 1)) (.sort 1), .app (.app (.sort 1) (.sort 1)) (.sort 1) => false
| _, _ => true
@@ -49,6 +68,7 @@ info: expensive.match_1.splitter.{u_1} (motive : Expr → Expr → Sort u_1) (x
-/
#guard_msgs in
#check expensive.match_1.splitter
/--
info: expensive.match_1.eq_1.{u_1} (motive : Expr → Expr → Sort u_1)
(h_1 :

View File

@@ -1,3 +1,6 @@
-- set_option trace.Meta.Match.match true
-- set_option trace.Meta.Match.matchEqs true
def f (xs : List Nat) : Nat :=
match xs with
| [] => 1