Compare commits

...

1 Commits

Author SHA1 Message Date
Leonardo de Moura
2c0ba24588 fix: simp must not use the forward version of an user-specified backward theorem
closes #4290
2024-06-04 15:32:42 -07:00
2 changed files with 108 additions and 33 deletions

View File

@@ -48,11 +48,33 @@ def Origin.key : Origin → Name
| .stx id _ => id
| .other name => name
instance : BEq Origin := (·.key == ·.key)
instance : Hashable Origin := (hash ·.key)
instance : LT Origin := (·.key.lt ·.key)
instance : BEq Origin where
beq a b := match a, b with
| .decl declName₁ _ inv₁, .decl declName₂ _ inv₂ =>
/- Remark: we must distinguish `thm` from `←thm`. See issue #4290. -/
declName₁ == declName₂ && inv₁ == inv₂
| .decl .., _ => false
| _, .decl .. => false
| a, b => a.key == b.key
instance : Hashable Origin where
hash a := match a with
| .decl declName _ true => mixHash (hash declName) 11
| .decl declName _ false => mixHash (hash declName) 13
| a => hash a.key
def Origin.lt : Origin Origin Bool
| .decl declName₁ _ inv₁, .decl declName₂ _ inv₂ =>
Name.lt declName₁ declName₂ || (declName₁ == declName₂ && !inv₁ && inv₂)
| .decl .., _ => false
| _, .decl .. => true
| a, b => Name.lt a.key b.key
instance : LT Origin where
lt a b := a.lt b
instance (a b : Origin) : Decidable (a < b) :=
inferInstanceAs (Decidable (a.key.lt b.key = true))
inferInstanceAs (Decidable (a.lt b))
/-
Note: we want to use iota reduction when indexing instances. Otherwise,
@@ -179,7 +201,35 @@ structure SimpTheorems where
/-- Configuration for the discrimination tree. -/
def simpDtConfig : WhnfCoreConfig := { iota := false, proj := .no, zetaDelta := false }
partial def SimpTheorems.eraseCore (d : SimpTheorems) (thmId : Origin) : SimpTheorems :=
let d := { d with erased := d.erased.insert thmId, lemmaNames := d.lemmaNames.erase thmId }
if let .decl declName .. := thmId then
let d := { d with toUnfold := d.toUnfold.erase declName }
if let some thms := d.toUnfoldThms.find? declName then
let dummy := true
thms.foldl (init := d) (eraseCore · <| .decl · dummy (inv := false))
else
d
else
d
private def eraseIfExists (d : SimpTheorems) (thmId : Origin) : SimpTheorems :=
if d.lemmaNames.contains thmId then
d.eraseCore thmId
else
d
/--
If `e` is a backwards theorem `← thm`, we must ensure the forward theorem is erased
from `d`. See issue #4290
-/
private def eraseFwdIfBwd (d : SimpTheorems) (e : SimpTheorem) : SimpTheorems :=
match e.origin with
| .decl declName post true => eraseIfExists d (.decl declName post false)
| _ => d
def addSimpTheoremEntry (d : SimpTheorems) (e : SimpTheorem) : SimpTheorems :=
let d := eraseFwdIfBwd d e
if e.post then
{ d with post := d.post.insertCore e.keys e, lemmaNames := updateLemmaNames d.lemmaNames }
else
@@ -209,18 +259,6 @@ def SimpTheorems.isLemma (d : SimpTheorems) (thmId : Origin) : Bool :=
def SimpTheorems.registerDeclToUnfoldThms (d : SimpTheorems) (declName : Name) (eqThms : Array Name) : SimpTheorems :=
{ d with toUnfoldThms := d.toUnfoldThms.insert declName eqThms }
partial def SimpTheorems.eraseCore (d : SimpTheorems) (thmId : Origin) : SimpTheorems :=
let d := { d with erased := d.erased.insert thmId, lemmaNames := d.lemmaNames.erase thmId }
if let .decl declName .. := thmId then
let d := { d with toUnfold := d.toUnfold.erase declName }
if let some thms := d.toUnfoldThms.find? declName then
let dummy := true
thms.foldl (init := d) (eraseCore · <| .decl · dummy dummy)
else
d
else
d
def SimpTheorems.erase [Monad m] [MonadLog m] [AddMessageContext m] [MonadOptions m]
(d : SimpTheorems) (thmId : Origin) : m SimpTheorems := do
unless d.isLemma thmId ||
@@ -232,15 +270,17 @@ def SimpTheorems.erase [Monad m] [MonadLog m] [AddMessageContext m] [MonadOption
return d.eraseCore thmId
private partial def isPerm : Expr Expr MetaM Bool
| Expr.app f₁ a₁, Expr.app f₂ a₂ => isPerm f₁ f₂ <&&> isPerm a₁ a₂
| Expr.mdata _ s, t => isPerm s t
| s, Expr.mdata _ t => isPerm s t
| s@(Expr.mvar ..), t@(Expr.mvar ..) => isDefEq s t
| Expr.forallE n₁ d₁ b₁ _, Expr.forallE _ d₂ b₂ _ => isPerm d₁ d₂ <&&> withLocalDeclD n₁ d₁ fun x => isPerm (b₁.instantiate1 x) (b₂.instantiate1 x)
| Expr.lam n₁ d₁ b₁ _, Expr.lam _ d₂ b₂ _ => isPerm d₁ d₂ <&&> withLocalDeclD n₁ d₁ fun x => isPerm (b₁.instantiate1 x) (b₂.instantiate1 x)
| Expr.letE n₁ t₁ v b₁ _, Expr.letE _ t₂ v b₂ _ =>
| .app f₁ a₁, .app f₂ a₂ => isPerm f₁ f₂ <&&> isPerm a₁ a₂
| .mdata _ s, t => isPerm s t
| s, .mdata _ t => isPerm s t
| s@(.mvar ..), t@(.mvar ..) => isDefEq s t
| .forallE n₁ d₁ b₁ _, .forallE _ d₂ b₂ _ =>
isPerm d₁ d₂ <&&> withLocalDeclD n₁ d₁ fun x => isPerm (b₁.instantiate1 x) (b₂.instantiate1 x)
| .lam n₁ d b₁ _, .lam _ d b₂ _ =>
isPerm d₁ d₂ <&&> withLocalDeclD n₁ d₁ fun x => isPerm (b₁.instantiate1 x) (b₂.instantiate1 x)
| .letE n₁ t₁ v₁ b₁ _, .letE _ t₂ v₂ b₂ _ =>
isPerm t₁ t₂ <&&> isPerm v₁ v₂ <&&> withLetDecl n₁ t₁ v₁ fun x => isPerm (b₁.instantiate1 x) (b₂.instantiate1 x)
| Expr.proj _ i₁ b₁, Expr.proj _ i₂ b₂ => pure (i₁ == i₂) <&&> isPerm b₁ b₂
| .proj _ i₁ b₁, .proj _ i₂ b₂ => pure (i₁ == i₂) <&&> isPerm b₁ b₂
| s, t => return s == t
private def checkBadRewrite (lhs rhs : Expr) : MetaM Unit := do
@@ -337,7 +377,9 @@ private def mkSimpTheoremCore (origin : Origin) (e : Expr) (levelParams : Array
private def mkSimpTheoremsFromConst (declName : Name) (post : Bool) (inv : Bool) (prio : Nat) : MetaM (Array SimpTheorem) := do
let cinfo getConstInfo declName
let val := mkConst declName (cinfo.levelParams.map mkLevelParam)
let us := cinfo.levelParams.map mkLevelParam
let origin := .decl declName post inv
let val := mkConst declName us
withReducible do
let type inferType val
checkTypeIsProp type
@@ -345,10 +387,10 @@ private def mkSimpTheoremsFromConst (declName : Name) (post : Bool) (inv : Bool)
let mut r := #[]
for (val, type) in ( preprocess val type inv (isGlobal := true)) do
let auxName mkAuxLemma cinfo.levelParams type val
r := r.push <| ( mkSimpTheoremCore (.decl declName post inv) (mkConst auxName (cinfo.levelParams.map mkLevelParam)) #[] (mkConst auxName) post prio (noIndexAtArgs := false))
r := r.push <| ( mkSimpTheoremCore origin (mkConst auxName us) #[] (mkConst auxName) post prio (noIndexAtArgs := false))
return r
else
return #[ mkSimpTheoremCore (.decl declName post inv) (mkConst declName (cinfo.levelParams.map mkLevelParam)) #[] (mkConst declName) post prio (noIndexAtArgs := false)]
return #[ mkSimpTheoremCore origin (mkConst declName us) #[] (mkConst declName) post prio (noIndexAtArgs := false)]
inductive SimpEntry where
| thm : SimpTheorem SimpEntry
@@ -366,16 +408,15 @@ def addSimpTheorem (ext : SimpExtension) (declName : Name) (post : Bool) (inv :
for simpThm in simpThms do
ext.add (SimpEntry.thm simpThm) attrKind
def mkSimpExt (name : Name := by exact decl_name%) : IO SimpExtension :=
registerSimpleScopedEnvExtension {
name := name
initial := {}
addEntry := fun d e =>
match e with
| SimpEntry.thm e => addSimpTheoremEntry d e
| SimpEntry.toUnfold n => d.addDeclToUnfoldCore n
| SimpEntry.toUnfoldThms n thms => d.registerDeclToUnfoldThms n thms
| .thm e => addSimpTheoremEntry d e
| .toUnfold n => d.addDeclToUnfoldCore n
| .toUnfoldThms n thms => d.registerDeclToUnfoldThms n thms
}
abbrev SimpExtensionMap := HashMap Name SimpExtension
@@ -397,7 +438,7 @@ def SimpTheorem.getValue (simpThm : SimpTheorem) : MetaM Expr := do
if info.levelParams.isEmpty then
return simpThm.proof
else
return simpThm.proof.updateConst! ( info.levelParams.mapM (fun _ => mkFreshLevelMVar))
return simpThm.proof.updateConst! ( info.levelParams.mapM fun _ => mkFreshLevelMVar)
else
let us simpThm.levelParams.mapM fun _ => mkFreshLevelMVar
return simpThm.proof.instantiateLevelParamsArray simpThm.levelParams us
@@ -464,6 +505,6 @@ def SimpTheoremsArray.isDeclToUnfold (thmsArray : SimpTheoremsArray) (declName :
thmsArray.any fun thms => thms.isDeclToUnfold declName
def SimpTheoremsArray.isLetDeclToUnfold (thmsArray : SimpTheoremsArray) (fvarId : FVarId) : Bool :=
thmsArray.any fun thms => thms.isLetDeclToUnfold fvarId
thmsArray.any fun thms => thms.isLetDeclToUnfold fvarId
end Lean.Meta

34
tests/lean/run/4290.lean Normal file
View File

@@ -0,0 +1,34 @@
def foo : Nat := 0
def bar : Nat := 0
@[simp] theorem foo_eq_bar : foo = bar := rfl
example : foo = bar := by simp [ foo_eq_bar]
example : foo = bar + 1 := by
simp [ foo_eq_bar]
guard_target = foo = foo + 1
sorry
def a : Nat := 0
def b : Nat := 0
def c : Nat := 0
@[simp] theorem abc : a = b a = c := And.intro rfl rfl
example : a = b := by simp [ abc]
example : a = c := by simp [ abc]
example : a = c + 1 := by
simp [ abc]
guard_target = a = a + 1
sorry
opaque d : Nat
opaque e : Nat
@[simp] theorem de : d = e := sorry
example : d = e := by simp [ de]
example : d = e := by simp [de]
example : d = e + 1 := by
simp [de]
guard_target = d = d + 1
sorry