Compare commits

...

4 Commits

Author SHA1 Message Date
Leonardo de Moura
25db17ac63 fix: realizer for matcher equation theorems 2024-03-28 19:34:44 -07:00
Leonardo de Moura
c6a625d41e chore: move test to correct directory 2024-03-28 19:27:41 -07:00
Leonardo de Moura
4585ad9878 fix: realize matcher equational theorems, and use public names 2024-03-28 19:19:10 -07:00
Leonardo de Moura
cb163fd32a fix: reserved name resolution 2024-03-28 18:12:09 -07:00
4 changed files with 101 additions and 47 deletions

View File

@@ -649,10 +649,13 @@ where
private partial def mkEquationsFor (matchDeclName : Name) : MetaM MatchEqns := withLCtx {} {} do
trace[Meta.Match.matchEqs] "mkEquationsFor '{matchDeclName}'"
withConfig (fun c => { c with etaStruct := .none }) do
let baseName := mkPrivateName ( getEnv) matchDeclName
let baseName := matchDeclName
let splitterName := baseName ++ `splitter
let constInfo getConstInfo matchDeclName
let us := constInfo.levelParams.map mkLevelParam
let some matchInfo getMatcherInfo? matchDeclName | throwError "'{matchDeclName}' is not a matcher function"
-- `alreadyDeclared` is `true` if matcher equations were defined in an imported module
let alreadyDeclared := ( getEnv).contains splitterName
let numDiscrEqs := getNumEqsFromDiscrInfos matchInfo.discrInfos
forallTelescopeReducing constInfo.type fun xs matchResultType => do
let mut eqnNames := #[]
@@ -687,52 +690,59 @@ private partial def mkEquationsFor (matchDeclName : Name) : MetaM MatchEqns :=
for discr in discrs.toArray.reverse, pattern in patterns.reverse do
notAlt mkArrow ( mkEqHEq discr pattern) notAlt
notAlt mkForallFVars (discrs ++ ys) notAlt
/- Recall that when we use the `h : discr`, the alternative type depends on the discriminant.
Thus, we need to create new `alts`. -/
withNewAlts numDiscrEqs discrs patterns alts fun alts => do
let alt := alts[i]!
let lhs := mkAppN (mkConst constInfo.name us) (params ++ #[motive] ++ patterns ++ alts)
let rhs := mkAppN alt rhsArgs
let thmType mkEq lhs rhs
let thmType hs.foldrM (init := thmType) (mkArrow · ·)
let thmType mkForallFVars (params ++ #[motive] ++ ys ++ alts) thmType
let thmType unfoldNamedPattern thmType
let thmVal proveCondEqThm matchDeclName thmType
addDecl <| Declaration.thmDecl {
name := thmName
levelParams := constInfo.levelParams
type := thmType
value := thmVal
}
return (notAlt, splitterAltType, splitterAltNumParam, argMask)
if alreadyDeclared then
-- If the matcher equations and splitter have already been declared, the only
-- values we are `notAlt` and `splitterAltNumParam`. This code is a bit hackish.
return (notAlt, default, splitterAltNumParam, default)
else
/- Recall that when we use the `h : discr`, the alternative type depends on the discriminant.
Thus, we need to create new `alts`. -/
withNewAlts numDiscrEqs discrs patterns alts fun alts => do
let alt := alts[i]!
let lhs := mkAppN (mkConst constInfo.name us) (params ++ #[motive] ++ patterns ++ alts)
let rhs := mkAppN alt rhsArgs
let thmType mkEq lhs rhs
let thmType hs.foldrM (init := thmType) (mkArrow · ·)
let thmType mkForallFVars (params ++ #[motive] ++ ys ++ alts) thmType
let thmType unfoldNamedPattern thmType
let thmVal proveCondEqThm matchDeclName thmType
addDecl <| Declaration.thmDecl {
name := thmName
levelParams := constInfo.levelParams
type := thmType
value := thmVal
}
return (notAlt, splitterAltType, splitterAltNumParam, argMask)
notAlts := notAlts.push notAlt
splitterAltTypes := splitterAltTypes.push splitterAltType
splitterAltNumParams := splitterAltNumParams.push splitterAltNumParam
altArgMasks := altArgMasks.push argMask
trace[Meta.Match.matchEqs] "splitterAltType: {splitterAltType}"
idx := idx + 1
-- Define splitter with conditional/refined alternatives
withSplitterAlts splitterAltTypes fun altsNew => do
let splitterParams := params.toArray ++ #[motive] ++ discrs.toArray ++ altsNew
let splitterType mkForallFVars splitterParams matchResultType
trace[Meta.Match.matchEqs] "splitterType: {splitterType}"
let template := mkAppN (mkConst constInfo.name us) (params ++ #[motive] ++ discrs ++ alts)
let template deltaExpand template (· == constInfo.name)
let template := template.headBeta
let splitterVal mkLambdaFVars splitterParams ( mkSplitterProof matchDeclName template alts altsNew splitterAltNumParams altArgMasks)
let splitterName := baseName ++ `splitter
addAndCompile <| Declaration.defnDecl {
name := splitterName
levelParams := constInfo.levelParams
type := splitterType
value := splitterVal
hints := .abbrev
safety := .safe
}
setInlineAttribute splitterName
let result := { eqnNames, splitterName, splitterAltNumParams }
registerMatchEqns matchDeclName result
return result
if alreadyDeclared then
return { eqnNames, splitterName, splitterAltNumParams }
else
-- Define splitter with conditional/refined alternatives
withSplitterAlts splitterAltTypes fun altsNew => do
let splitterParams := params.toArray ++ #[motive] ++ discrs.toArray ++ altsNew
let splitterType mkForallFVars splitterParams matchResultType
trace[Meta.Match.matchEqs] "splitterType: {splitterType}"
let template := mkAppN (mkConst constInfo.name us) (params ++ #[motive] ++ discrs ++ alts)
let template deltaExpand template (· == constInfo.name)
let template := template.headBeta
let splitterVal mkLambdaFVars splitterParams ( mkSplitterProof matchDeclName template alts altsNew splitterAltNumParams altArgMasks)
addAndCompile <| Declaration.defnDecl {
name := splitterName
levelParams := constInfo.levelParams
type := splitterType
value := splitterVal
hints := .abbrev
safety := .safe
}
setInlineAttribute splitterName
let result := { eqnNames, splitterName, splitterAltNumParams }
registerMatchEqns matchDeclName result
return result
/- See header at `MatchEqsExt.lean` -/
@[export lean_get_match_equations_for]
@@ -743,4 +753,23 @@ def getEquationsForImpl (matchDeclName : Name) : MetaM MatchEqns := do
builtin_initialize registerTraceClass `Meta.Match.matchEqs
private def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do
unless ( isMatcher declName) do
return none
let result getEquationsForImpl declName
return some result.eqnNames
builtin_initialize
registerGetEqnsFn getEqnsFor?
/-
We register `foo.match_<idx>.splitter` as a reserved name, but
we do not install a realizer. The `splitter` will be generated by the
`foo.match_<idx>.eq_<idx>` realizer.
-/
builtin_initialize registerReservedNamePredicate fun env n =>
match n with
| .str p "splitter" => isMatcherCore env p
| _ => false
end Lean.Meta.Match

View File

@@ -93,17 +93,20 @@ def getRevAliases (env : Environment) (e : Name) : List Name :=
/-! # Global name resolution -/
namespace ResolveName
private def containsDeclOrReserved (env : Environment) (declName : Name) : Bool :=
env.contains declName || isReservedName env declName
/-- Check whether `ns ++ id` is a valid namespace name and/or there are aliases names `ns ++ id`. -/
private def resolveQualifiedName (env : Environment) (ns : Name) (id : Name) : List Name :=
let resolvedId := ns ++ id
-- We ignore protected aliases if `id` is atomic.
let resolvedIds := getAliases env resolvedId (skipProtected := id.isAtomic)
if env.contains resolvedId && (!id.isAtomic || !isProtected env resolvedId) then
if (containsDeclOrReserved env resolvedId && (!id.isAtomic || !isProtected env resolvedId)) then
resolvedId :: resolvedIds
else
-- Check whether environment contains the private version. That is, `_private.<module_name>.ns.id`.
let resolvedIdPrv := mkPrivateName env resolvedId
if env.contains resolvedIdPrv then resolvedIdPrv :: resolvedIds
if containsDeclOrReserved env resolvedIdPrv then resolvedIdPrv :: resolvedIds
else resolvedIds
/-- Check surrounding namespaces -/
@@ -119,12 +122,12 @@ private def resolveExact (env : Environment) (id : Name) : Option Name :=
if id.isAtomic then none
else
let resolvedId := id.replacePrefix rootNamespace Name.anonymous
if env.contains resolvedId then some resolvedId
if containsDeclOrReserved env resolvedId then some resolvedId
else
-- We also allow `_root` when accessing private declarations.
-- If we change our minds, we should just replace `resolvedId` with `id`
let resolvedIdPrv := mkPrivateName env resolvedId
if env.contains resolvedIdPrv then some resolvedIdPrv
if containsDeclOrReserved env resolvedIdPrv then some resolvedIdPrv
else none
/-- Check `OpenDecl`s -/
@@ -171,9 +174,9 @@ def resolveGlobalName (env : Environment) (ns : Name) (openDecls : List OpenDecl
match resolveExact env id with
| some newId => [(newId, projs)]
| none =>
let resolvedIds := if env.contains id || isReservedName env id then [id] else []
let resolvedIds := if containsDeclOrReserved env id then [id] else []
let idPrv := mkPrivateName env id
let resolvedIds := if env.contains idPrv || isReservedName env idPrv then [idPrv] ++ resolvedIds else resolvedIds
let resolvedIds := if containsDeclOrReserved env idPrv then [idPrv] ++ resolvedIds else resolvedIds
let resolvedIds := resolveOpenDecls env id openDecls resolvedIds
let resolvedIds := getAliases env id (skipProtected := id.isAtomic) ++ resolvedIds
match resolvedIds with

View File

@@ -26,9 +26,16 @@ set_option trace.Meta.Match.matchEqs true
test% f.match_1
#check f.match_1.splitter
/--
error: 'g.match_1.splitter' is a reserved name
-/
#guard_msgs (error) in
def g.match_1.splitter := 4
test% g.match_1
#check g.match_1.eq_1
#check g.match_1.eq_2
#check g.match_1.splitter
def bla.splitter := 5 -- ok

View File

@@ -0,0 +1,15 @@
namespace Nat
def dist (n m : Nat) :=
n - m + (m - n)
example (n m : Nat) : dist n m = dist m n := by
simp [dist.eq_def, Nat.add_comm]
example (n m : Nat) : dist n m = dist m n := by
simp [Nat.dist.eq_def, Nat.add_comm]
theorem dist_comm (n m : Nat) : dist n m = dist m n := by
simp [dist.eq_def, Nat.add_comm]
theorem dist_self (n : Nat) : dist n n = 0 := by simp [dist.eq_def]