Compare commits

...

6 Commits

Author SHA1 Message Date
Joachim Breitner
0671533df4 Revert "Partial attempt of doing the same to wf-rec"
This reverts commit d6403106a9.
2025-01-27 14:12:27 +01:00
Joachim Breitner
d6403106a9 Partial attempt of doing the same to wf-rec 2025-01-27 14:12:19 +01:00
Joachim Breitner
b448384d4e Tests 2025-01-27 14:08:36 +01:00
Joachim Breitner
3852debdc0 Refactor 2025-01-27 13:48:41 +01:00
Joachim Breitner
1419d36d93 Use mkEqnProofCore 2025-01-27 12:12:30 +01:00
Joachim Breitner
a1b082085e fix: more robust equational theorems generation
fixes #6786
2025-01-27 11:59:19 +01:00
4 changed files with 264 additions and 113 deletions

View File

@@ -308,6 +308,115 @@ def whnfReducibleLHS? (mvarId : MVarId) : MetaM (Option MVarId) := mvarId.withCo
def tryContradiction (mvarId : MVarId) : MetaM Bool := do
mvarId.contradictionCore { genDiseq := true }
/--
Returns the type of the unfold theorem, as the starting point for calculating the equational
types.
-/
private def unfoldThmType (declName : Name) : MetaM Expr := do
if let some unfoldThm getUnfoldEqnFor? declName (nonRec := false) then
let info getConstInfo unfoldThm
pure info.type
else
let info getConstInfoDefn declName
let us := info.levelParams.map mkLevelParam
lambdaTelescope (cleanupAnnotations := true) info.value fun xs body => do
let type mkEq (mkAppN (Lean.mkConst declName us) xs) body
mkForallFVars xs type
private def unfoldLHS (declName : Name) (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
if let some unfoldThm getUnfoldEqnFor? declName (nonRec := false) then
-- Recursive definition: Use unfolding lemma
let mut mvarId := mvarId
let target mvarId.getType'
let some (_, lhs, rhs) := target.eq? | throwError "unfoldLHS: Unexpected target {target}"
unless lhs.isAppOf declName do throwError "unfoldLHS: Unexpected LHS {lhs}"
let h := mkAppN (.const unfoldThm lhs.getAppFn.constLevels!) lhs.getAppArgs
let some (_, _, lhsNew) := ( inferType h).eq? | unreachable!
let targetNew mkEq lhsNew rhs
let mvarNew mkFreshExprSyntheticOpaqueMVar targetNew
mvarId.assign ( mkEqTrans h mvarNew)
return mvarNew.mvarId!
else
-- Else use delta reduction
deltaLHS mvarId
private partial def mkEqnProof (declName : Name) (type : Expr) : MetaM Expr := do
trace[Elab.definition.eqns] "proving: {type}"
withNewMCtxDepth do
let main mkFreshExprSyntheticOpaqueMVar type
let (_, mvarId) main.mvarId!.intros
-- Try rfl before deltaLHS to avoid `id` checkpoints in the proof, which would make
-- the lemma ineligible for dsimp
unless withAtLeastTransparency .all (tryURefl mvarId) do
go ( unfoldLHS declName mvarId)
instantiateMVars main
where
/--
The core loop of proving an equation. Assumes that the function call on the left-hand side has
already been unfolded, using whatever method applies to the current function definition strategy.
Currently used for non-recursive functions and partial fixpoints; maybe later well-founded
recursion and structural recursion can and should use this too.
-/
go (mvarId : MVarId) : MetaM Unit := do
trace[Elab.definition.eqns] "step\n{MessageData.ofGoal mvarId}"
if withAtLeastTransparency .all (tryURefl mvarId) then
return ()
else if ( tryContradiction mvarId) then
return ()
else if let some mvarId simpMatch? mvarId then
go mvarId
else if let some mvarId simpIf? mvarId then
go mvarId
else if let some mvarId whnfReducibleLHS? mvarId then
go mvarId
else
let ctx Simp.mkContext (config := { dsimp := false })
match ( simpTargetStar mvarId ctx (simprocs := {})).1 with
| TacticResultCNM.closed => return ()
| TacticResultCNM.modified mvarId => go mvarId
| TacticResultCNM.noChange =>
if let some mvarIds casesOnStuckLHS? mvarId then
mvarIds.forM go
else if let some mvarIds splitTarget? mvarId then
mvarIds.forM go
else
throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}"
/--
Generate equations for `declName`.
This unfolds the function application on the LHS (using an unfold theorem, if present, or else by
delta-reduction), calculates the types for the equational theorems using `mkEqnTypes`, and then
proves them using `mkEqnProof`.
This is currently used for non-recursive functions and for functions defined by partial_fixpoint.
-/
def mkEqns (declName : Name) : MetaM (Array Name) := do
let info getConstInfoDefn declName
let us := info.levelParams.map mkLevelParam
withOptions (tactic.hygienic.set · false) do
let target unfoldThmType declName
let eqnTypes withNewMCtxDepth <|
forallTelescope (cleanupAnnotations := true) target fun xs target => do
let goal mkFreshExprSyntheticOpaqueMVar target
withReducible do
mkEqnTypes #[] goal.mvarId!
let mut thmNames := #[]
for h : i in [: eqnTypes.size] do
let type := eqnTypes[i]
trace[Elab.definition.eqns] "eqnType[{i}]: {eqnTypes[i]}"
let name := (Name.str declName eqnThmSuffixBase).appendIndexAfter (i+1)
thmNames := thmNames.push name
let value mkEqnProof declName type
let (type, value) removeUnusedEqnHypotheses type value
addDecl <| Declaration.thmDecl {
name, type, value
levelParams := info.levelParams
}
return thmNames
/--
Auxiliary method for `mkUnfoldEq`. The structure is based on `mkEqnTypes`.
`mvarId` is the goal to be proved. It is a goal of the form

View File

@@ -33,71 +33,12 @@ private def mkSimpleEqThm (declName : Name) (suffix := Name.mkSimple unfoldThmSu
else
return none
private partial def mkProof (declName : Name) (type : Expr) : MetaM Expr := do
trace[Elab.definition.eqns] "proving: {type}"
withNewMCtxDepth do
let main mkFreshExprSyntheticOpaqueMVar type
let (_, mvarId) main.mvarId!.intros
let rec go (mvarId : MVarId) : MetaM Unit := do
trace[Elab.definition.eqns] "step\n{MessageData.ofGoal mvarId}"
if withAtLeastTransparency .all (tryURefl mvarId) then
return ()
else if ( tryContradiction mvarId) then
return ()
else if let some mvarId simpMatch? mvarId then
go mvarId
else if let some mvarId simpIf? mvarId then
go mvarId
else if let some mvarId whnfReducibleLHS? mvarId then
go mvarId
else
let ctx Simp.mkContext (config := { dsimp := false })
match ( simpTargetStar mvarId ctx (simprocs := {})).1 with
| TacticResultCNM.closed => return ()
| TacticResultCNM.modified mvarId => go mvarId
| TacticResultCNM.noChange =>
if let some mvarIds casesOnStuckLHS? mvarId then
mvarIds.forM go
else if let some mvarIds splitTarget? mvarId then
mvarIds.forM go
else
throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}"
-- Try rfl before deltaLHS to avoid `id` checkpoints in the proof, which would make
-- the lemma ineligible for dsimp
unless withAtLeastTransparency .all (tryURefl mvarId) do
go ( deltaLHS mvarId)
instantiateMVars main
def mkEqns (declName : Name) (info : DefinitionVal) : MetaM (Array Name) :=
withOptions (tactic.hygienic.set · false) do
let baseName := declName
let eqnTypes withNewMCtxDepth <| lambdaTelescope (cleanupAnnotations := true) info.value fun xs body => do
let us := info.levelParams.map mkLevelParam
let target mkEq (mkAppN (Lean.mkConst declName us) xs) body
let goal mkFreshExprSyntheticOpaqueMVar target
withReducible do
mkEqnTypes #[] goal.mvarId!
let mut thmNames := #[]
for h : i in [: eqnTypes.size] do
let type := eqnTypes[i]
trace[Elab.definition.eqns] "eqnType[{i}]: {eqnTypes[i]}"
let name := (Name.str baseName eqnThmSuffixBase).appendIndexAfter (i+1)
thmNames := thmNames.push name
let value mkProof declName type
let (type, value) removeUnusedEqnHypotheses type value
addDecl <| Declaration.thmDecl {
name, type, value
levelParams := info.levelParams
}
return thmNames
def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do
if ( isRecursiveDefinition declName) then
return none
if let some (.defnInfo info) := ( getEnv).find? declName then
if ( getEnv).contains declName then
if backward.eqns.nonrecursive.get ( getOptions) then
mkEqns declName info
mkEqns declName
else
let o mkSimpleEqThm declName
return o.map (#[·])

View File

@@ -23,6 +23,18 @@ structure EqnInfo extends EqnInfoCore where
fixedPrefixSize : Nat
deriving Inhabited
builtin_initialize eqnInfoExt : MapDeclarationExtension EqnInfo mkMapDeclarationExtension
def registerEqnsInfo (preDefs : Array PreDefinition) (declNameNonRec : Name) (fixedPrefixSize : Nat) : MetaM Unit := do
preDefs.forM fun preDef => ensureEqnReservedNamesAvailable preDef.declName
unless preDefs.all fun p => p.kind.isTheorem do
unless ( preDefs.allM fun p => isProp p.type) do
let declNames := preDefs.map (·.declName)
modifyEnv fun env =>
preDefs.foldl (init := env) fun env preDef =>
eqnInfoExt.insert env preDef.declName { preDef with
declNames, declNameNonRec, fixedPrefixSize }
private def deltaLHSUntilFix (declName declNameNonRec : Name) (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
let target mvarId.getType'
let some (_, lhs, rhs) := target.eq? | throwTacticEx `deltaLHSUntilFix mvarId "equality expected"
@@ -53,62 +65,50 @@ private def rwFixEq (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
mvarId.assign ( mkEqTrans h mvarNew)
return mvarNew.mvarId!
private partial def mkProof (declName : Name) (declNameNonRec : Name) (type : Expr) : MetaM Expr := do
trace[Elab.definition.partialFixpoint] "proving: {type}"
withNewMCtxDepth do
let main mkFreshExprSyntheticOpaqueMVar type
let (_, mvarId) main.mvarId!.intros
let mvarId deltaLHSUntilFix declName declNameNonRec mvarId
let mvarId rwFixEq mvarId
if withAtLeastTransparency .all (tryURefl mvarId) then
instantiateMVars main
else
throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}"
def mkEqns (declName : Name) (info : EqnInfo) : MetaM (Array Name) :=
/-- Generate the "unfold" lemma for `declName`. -/
def mkUnfoldEq (declName : Name) (info : EqnInfo) : MetaM Name := withLCtx {} {} do
withOptions (tactic.hygienic.set · false) do
let baseName := declName
let eqnTypes withNewMCtxDepth <| lambdaTelescope (cleanupAnnotations := true) info.value fun xs body => do
let us := info.levelParams.map mkLevelParam
let target mkEq (mkAppN (Lean.mkConst declName us) xs) body
let goal mkFreshExprSyntheticOpaqueMVar target
withReducible do
mkEqnTypes info.declNames goal.mvarId!
let mut thmNames := #[]
for h : i in [: eqnTypes.size] do
let type := eqnTypes[i]
trace[Elab.definition.partialFixpoint] "{eqnTypes[i]}"
let name := (Name.str baseName eqnThmSuffixBase).appendIndexAfter (i+1)
thmNames := thmNames.push name
let value mkProof declName info.declNameNonRec type
let (type, value) removeUnusedEqnHypotheses type value
addDecl <| Declaration.thmDecl {
name, type, value
levelParams := info.levelParams
}
return thmNames
builtin_initialize eqnInfoExt : MapDeclarationExtension EqnInfo mkMapDeclarationExtension
def registerEqnsInfo (preDefs : Array PreDefinition) (declNameNonRec : Name) (fixedPrefixSize : Nat) : MetaM Unit := do
preDefs.forM fun preDef => ensureEqnReservedNamesAvailable preDef.declName
unless preDefs.all fun p => p.kind.isTheorem do
unless ( preDefs.allM fun p => isProp p.type) do
let declNames := preDefs.map (·.declName)
modifyEnv fun env =>
preDefs.foldl (init := env) fun env preDef =>
eqnInfoExt.insert env preDef.declName { preDef with
declNames, declNameNonRec, fixedPrefixSize }
def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do
if let some info := eqnInfoExt.find? ( getEnv) declName then
mkEqns declName info
else
return none
let baseName := declName
lambdaTelescope info.value fun xs body => do
let us := info.levelParams.map mkLevelParam
let type mkEq (mkAppN (Lean.mkConst declName us) xs) body
let goal withNewMCtxDepth do
try
let goal mkFreshExprSyntheticOpaqueMVar type
let mvarId := goal.mvarId!
trace[Elab.definition.partialFixpoint] "mkUnfoldEq start:{mvarId}"
let mvarId deltaLHSUntilFix declName info.declNameNonRec mvarId
trace[Elab.definition.partialFixpoint] "mkUnfoldEq after deltaLHS:{mvarId}"
let mvarId rwFixEq mvarId
trace[Elab.definition.partialFixpoint] "mkUnfoldEq after rwFixEq:{mvarId}"
withAtLeastTransparency .all <|
withOptions (smartUnfolding.set · false) <|
mvarId.refl
trace[Elab.definition.partialFixpoint] "mkUnfoldEq rfl succeeded"
instantiateMVars goal
catch e =>
throwError "failed to generate unfold theorem for '{declName}':\n{e.toMessageData}"
let type mkForallFVars xs type
let value mkLambdaFVars xs goal
let name := Name.str baseName unfoldThmSuffix
addDecl <| Declaration.thmDecl {
name, type, value
levelParams := info.levelParams
}
return name
def getUnfoldFor? (declName : Name) : MetaM (Option Name) := do
let name := Name.str declName unfoldThmSuffix
let env getEnv
Eqns.getUnfoldFor? declName fun _ => eqnInfoExt.find? env declName |>.map (·.toEqnInfoCore)
if env.contains name then return name
let some info := eqnInfoExt.find? env declName | return none
return some ( mkUnfoldEq declName info)
def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do
if let some _ := eqnInfoExt.find? ( getEnv) declName then
mkEqns declName
else
return none
builtin_initialize
registerGetEqnsFn getEqnsFor?

View File

@@ -0,0 +1,101 @@
def find42 : Nat Bool
| 42 => true
| n => find42 (n + 1)
partial_fixpoint
/--
info: find42.eq_def (x✝ : Nat) :
find42 x✝ =
match x✝ with
| 42 => true
| n => find42 (n + 1)
-/
#guard_msgs in
#check find42.eq_def
/--
info: equations:
theorem find42.eq_1 : find42 42 = true
theorem find42.eq_2 : ∀ (x : Nat), (x = 42 → False) → find42 x = find42 (x + 1)
-/
#guard_msgs in
#print equations find42
mutual
def find99 : Nat Bool
| 99 => true
| n => find23 (n + 1)
partial_fixpoint
def find23 : Nat Bool
| 23 => true
| n => find99 (n + 1)
partial_fixpoint
end
/--
info: find99.eq_def (x✝ : Nat) :
find99 x✝ =
match x✝ with
| 99 => true
| n => find23 (n + 1)
-/
#guard_msgs in
#check find99.eq_def
/--
info: find23.eq_def (x✝ : Nat) :
find23 x✝ =
match x✝ with
| 23 => true
| n => find99 (n + 1)
-/
#guard_msgs in
#check find23.eq_def
/--
info: equations:
theorem find99.eq_1 : find99 99 = true
theorem find99.eq_2 : ∀ (x : Nat), (x = 99 → False) → find99 x = find23 (x + 1)
-/
#guard_msgs in
#print equations find99
/--
info: equations:
theorem find23.eq_1 : find23 23 = true
theorem find23.eq_2 : ∀ (x : Nat), (x = 23 → False) → find23 x = find99 (x + 1)
-/
#guard_msgs in
#print equations find23
mutual
def g (i j : Nat) : Nat :=
if i < 5 then 0 else
match j with
| Nat.zero => 1
| Nat.succ j => h i j
partial_fixpoint
def h (i j : Nat) : Nat :=
match j with
| 0 => g i 0
| Nat.succ j => g i j
partial_fixpoint
end
/--
info: equations:
theorem g.eq_1 : ∀ (i : Nat), g i Nat.zero = if i < 5 then 0 else 1
theorem g.eq_2 : ∀ (i j_2 : Nat), g i j_2.succ = if i < 5 then 0 else h i j_2
-/
#guard_msgs in
#print equations g
/--
info: equations:
theorem h.eq_1 : ∀ (i : Nat), h i 0 = g i 0
theorem h.eq_2 : ∀ (i j_2 : Nat), h i j_2.succ = g i j_2
-/
#guard_msgs in
#print equations h