Compare commits

...

1 Commits

Author SHA1 Message Date
Leonardo de Moura
71402d1f36 fix: performance issue when elaborating match-expressions with many literals
This PR fixes a performance issue that occurs when generating equation
lemmas for functions that use match-expressions containing several
literals. This issue was exposed by #9322 and arises from a combination of factors:

1. Literal values are compiled into a chain of dependent if-then-else expressions.
2. Dependent if-then-else expressions are significantly more expensive to simplify than regular ones.
3. The `split` tactic selects a target, splits it, and then invokes `simp` on the resulting subgoals. Moreover, `simp` traverses the entire goal bottom-up and does not stop after reaching the target.

This PR addresses the issue by introducing a custom simproc that
avoids recursively simplifying nested if-then-else expressions. It
does **not** alter the user-facing behavior of the `split` tactic
because such a change would be highly disruptive. Instead, the PR adds
a new flag, `backward.split` to control the behavior of the
user-facing `split` tactic. It is currently set to `true`, i.e., the
old behavior is still the default one. In a future PR, we should set
this flag to `false` by default and begin repairing all affected
proofs.

closes #9322
2025-07-14 20:34:32 -07:00
9 changed files with 277 additions and 45 deletions

View File

@@ -144,6 +144,12 @@ theorem ite_congr {x y u v : α} {s : Decidable b} [Decidable c]
| inl h => rw [if_pos h]; subst b; rw [if_pos h]; exact h₂ h
| inr h => rw [if_neg h]; subst b; rw [if_neg h]; exact h₃ h
theorem ite_cond_congr {α} {b c : Prop} {s : Decidable b} [Decidable c] {x y : α}
(h₁ : b = c) : ite b x y = ite c x y := by
cases Decidable.em c with
| inl h => rw [if_pos h]; subst b; rw [if_pos h]
| inr h => rw [if_neg h]; subst b; rw [if_neg h]
theorem Eq.mpr_prop {p q : Prop} (h₁ : p = q) (h₂ : q) : p := h₁ h₂
theorem Eq.mpr_not {p q : Prop} (h₁ : p = q) (h₂ : ¬q) : ¬p := h₁ h₂
@@ -158,6 +164,13 @@ theorem dite_congr {_ : Decidable b} [Decidable c]
| inl h => rw [dif_pos h]; subst b; rw [dif_pos h]; exact h₂ h
| inr h => rw [dif_neg h]; subst b; rw [dif_neg h]; exact h₃ h
theorem dite_cond_congr {α} {b c : Prop} {s : Decidable b} [Decidable c]
{x : b α} {y : ¬ b α} (h₁ : b = c) :
dite b x y = dite c (fun h => x (h₁.mpr_prop h)) (fun h => y (h₁.mpr_not h)) := by
cases Decidable.em c with
| inl h => rw [dif_pos h]; subst b; rw [dif_pos h]
| inr h => rw [dif_neg h]; subst b; rw [dif_neg h]
@[simp] theorem ne_eq (a b : α) : (a b) = ¬(a = b) := rfl
norm_cast_add_elim ne_eq
@[simp] theorem ite_true (a b : α) : (if True then a else b) = a := rfl

View File

@@ -44,8 +44,12 @@ def simpMatch? (mvarId : MVarId) : MetaM (Option MVarId) := do
let mvarId' Split.simpMatchTarget mvarId
if mvarId != mvarId' then return some mvarId' else return none
def simpIf? (mvarId : MVarId) : MetaM (Option MVarId) := do
let mvarId' simpIfTarget mvarId (useDecide := true)
/--
Simplify `if-then-expression`s in the goal target.
If `useNewSemantics` is `true`, the flag `backward.split` is ignored.
-/
def simpIf? (mvarId : MVarId) (useNewSemantics := false) : MetaM (Option MVarId) := do
let mvarId' simpIfTarget mvarId (useDecide := true) (useNewSemantics := useNewSemantics)
if mvarId != mvarId' then return some mvarId' else return none
private def findMatchToSplit? (deepRecursiveSplit : Bool) (env : Environment) (e : Expr)
@@ -369,7 +373,7 @@ private partial def mkEqnProof (declName : Name) (type : Expr) (tryRefl : Bool)
return ()
else if let some mvarId simpMatch? mvarId then
go mvarId
else if let some mvarId simpIf? mvarId then
else if let some mvarId simpIf? mvarId (useNewSemantics := true) then
go mvarId
else if let some mvarId whnfReducibleLHS? mvarId then
go mvarId
@@ -381,7 +385,7 @@ private partial def mkEqnProof (declName : Name) (type : Expr) (tryRefl : Bool)
| TacticResultCNM.noChange =>
if let some mvarIds casesOnStuckLHS? mvarId then
mvarIds.forM go
else if let some mvarIds splitTarget? mvarId then
else if let some mvarIds splitTarget? mvarId (useNewSemantics := true) then
mvarIds.forM go
else
throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}"
@@ -459,7 +463,7 @@ partial def mkUnfoldProof (declName : Name) (mvarId : MVarId) : MetaM Unit := do
if let some mvarId simpMatch? mvarId then
return ( go mvarId)
if let some mvarIds splitTarget? mvarId (splitIte := false) then
if let some mvarIds splitTarget? mvarId (splitIte := false) (useNewSemantics := true) then
return ( mvarIds.forM go)
if ( tryContradiction mvarId) then

View File

@@ -48,7 +48,7 @@ where
else if let some mvarId simpMatch? mvarId then
trace[Elab.definition.structural.eqns] "simpMatch? succeeded"
go mvarId
else if let some mvarId simpIf? mvarId then
else if let some mvarId simpIf? mvarId (useNewSemantics := true) then
trace[Elab.definition.structural.eqns] "simpIf? succeeded"
go mvarId
else
@@ -66,7 +66,7 @@ where
else if let some mvarIds casesOnStuckLHS? mvarId then
trace[Elab.definition.structural.eqns] "casesOnStuckLHS? succeeded"
mvarIds.forM go
else if let some mvarIds splitTarget? mvarId then
else if let some mvarIds splitTarget? mvarId (useNewSemantics := true) then
trace[Elab.definition.structural.eqns] "splitTarget? succeeded"
mvarIds.forM go
else

View File

@@ -49,7 +49,7 @@ private partial def mkUnfoldProof (declName : Name) (mvarId : MVarId) : MetaM Un
else if let some mvarId simpMatch? mvarId then
trace[Elab.definition.wf.eqns] "simpMatch!"
mkUnfoldProof declName mvarId
else if let some mvarId simpIf? mvarId then
else if let some mvarId simpIf? mvarId (useNewSemantics := true) then
trace[Elab.definition.wf.eqns] "simpIf!"
mkUnfoldProof declName mvarId
else
@@ -63,7 +63,7 @@ private partial def mkUnfoldProof (declName : Name) (mvarId : MVarId) : MetaM Un
if let some mvarIds casesOnStuckLHS? mvarId then
trace[Elab.definition.wf.eqns] "case split into {mvarIds.size} goals"
mvarIds.forM (mkUnfoldProof declName)
else if let some mvarIds splitTarget? mvarId then
else if let some mvarIds splitTarget? mvarId (useNewSemantics := true) then
trace[Elab.definition.wf.eqns] "splitTarget into {mvarIds.length} goals"
mvarIds.forM (mkUnfoldProof declName)
else

View File

@@ -411,11 +411,11 @@ where
<|>
(casesOnStuckLHS mvarId)
<|>
(do let mvarId' simpIfTarget mvarId (useDecide := true)
(do let mvarId' simpIfTarget mvarId (useDecide := true) (useNewSemantics := true)
if mvarId' == mvarId then throwError "simpIf failed"
return #[mvarId'])
<|>
(do if let some (s₁, s₂) splitIfTarget? mvarId then
(do if let some (s₁, s₂) splitIfTarget? mvarId (useNewSemantics := true) then
let mvarId₁ trySubst s₁.mvarId s₁.fvarId
return #[mvarId₁, s₂.mvarId]
else

View File

@@ -280,12 +280,16 @@ end Split
open Split
partial def splitTarget? (mvarId : MVarId) (splitIte := true) : MetaM (Option (List MVarId)) := commitWhenSome? do mvarId.withContext do
/--
Splits an `if-then-else` of `match`-expression in the goal target.
If `useNewSemantics` is `true`, the flag `backward.split` is ignored. Recall this flag only affects the split of `if-then-else` expressions.
-/
partial def splitTarget? (mvarId : MVarId) (splitIte := true) (useNewSemantics := false) : MetaM (Option (List MVarId)) := commitWhenSome? do mvarId.withContext do
let target instantiateMVars ( mvarId.getType)
let rec go (badCases : ExprSet) : MetaM (Option (List MVarId)) := do
if let some e findSplit? target (if splitIte then .both else .match) badCases then
if e.isIte || e.isDIte then
return ( splitIfTarget? mvarId).map fun (s₁, s₂) => [s₁.mvarId, s₂.mvarId]
return ( splitIfTarget? mvarId (useNewSemantics := useNewSemantics)).map fun (s₁, s₂) => [s₁.mvarId, s₂.mvarId]
else
try
splitMatch mvarId e

View File

@@ -126,11 +126,18 @@ private partial def findIfToSplit? (e : Expr) : MetaM (Option (Expr × Expr)) :=
else
return none
namespace SplitIf
register_builtin_option backward.split : Bool := {
defValue := true
group := "backward compatibility"
descr := "use the old semantics for the `split` tactic where nested `if-then-else` terms could be simplified too"
}
namespace SplitIf
/--
Default `Simp.Context` for `simpIf` methods. It contains all congruence theorems, but
just the rewriting rules for reducing `if` expressions. -/
The `Simp.Context` that used to be used with `simpIf` methods. It contains all congruence theorems, but
just the rewriting rules for reducing `if` expressions.
This function is only used when the old `split` tactic behavior is enabled.
-/
def getSimpContext : MetaM Simp.Context := do
let mut s : SimpTheorems := {}
s s.addConst ``if_pos
@@ -138,7 +145,16 @@ def getSimpContext : MetaM Simp.Context := do
s s.addConst ``dif_pos
s s.addConst ``dif_neg
Simp.mkContext
(simpTheorems := #[s])
(simpTheorems := #[s])
(congrTheorems := ( getSimpCongrTheorems))
(config := { Simp.neutralConfig with dsimp := false, letToHave := true })
/--
Default `Simp.Context` for `simpIf` methods. It contains all congruence theorems, but
without rewriting rules. We use simprocs to reduce the if-then-else terms -/
private def getSimpContext' : MetaM Simp.Context := do
Simp.mkContext
(simpTheorems := {})
(congrTheorems := ( getSimpCongrTheorems))
(config := { Simp.neutralConfig with dsimp := false, letToHave := true })
@@ -177,6 +193,67 @@ private def discharge? (numIndices : Nat) (useDecide : Bool) : Simp.Discharge :=
else
return none
private def reduceIte' (numIndices : Nat) (useDecideBool : Bool) : Simp.Simproc := fun e => do
let_expr f@ite α c i tb eb e | return .continue
let us := f.constLevels!
if let some h discharge? numIndices useDecideBool c then
let h := mkApp6 (mkConst ``if_pos us) c i h α tb eb
return .done { expr := tb, proof? := some h }
else if let some h discharge? numIndices useDecideBool (mkNot c) then
let h := mkApp6 (mkConst ``if_neg us) c i h α tb eb
return .done { expr := eb, proof? := some h }
else
-- `split` may have selected an `if-then-else` nested in `c`.
let r Simp.simp c
if r.expr == c then
return .done { expr := e }
else
let c' := r.expr
let dec := mkApp (mkConst ``Decidable) c'
let .some i' trySynthInstance dec | return .done { expr := e }
let h := mkApp8 (mkConst ``ite_cond_congr us) α c c' i i' tb eb ( r.getProof)
let e' := mkApp5 f α c' i' tb eb
return .done { expr := e', proof? := some h }
private def getBinderName (e : Expr) : MetaM Name := do
let .lam n _ _ _ := e | mkFreshUserName `h
return n
private def reduceDIte' (numIndices : Nat) (useDecideBool : Bool) : Simp.Simproc := fun e => do
let_expr f@dite α c i tb eb e | return .continue
let us := f.constLevels!
if let some h discharge? numIndices useDecideBool c then
let e' := mkApp tb h |>.headBeta
let h := mkApp6 (mkConst ``dif_pos us) c i h α tb eb
return .done { expr := e', proof? := some h }
else if let some h discharge? numIndices useDecideBool (mkNot c) then
let e' := mkApp eb h |>.headBeta
let h := mkApp6 (mkConst ``dif_neg us) c i h α tb eb
return .done { expr := e', proof? := some h }
else
-- `split` may have selected an `if-then-else` nested in `c`.
let r Simp.simp c
if r.expr == c then
return .done { expr := e }
else
let c' := r.expr
let h r.getProof
let dec := mkApp (mkConst ``Decidable) c'
let .some i' trySynthInstance dec | return .done { expr := e }
let tb' := mkApp tb (mkApp4 (mkConst ``Eq.mpr_prop) c c' h (mkBVar 0)) |>.headBeta
let tb' := mkLambda ( getBinderName tb) .default c' tb'
let eb' := mkApp eb (mkApp4 (mkConst ``Eq.mpr_not) c c' h (mkBVar 0)) |>.headBeta
let eb' := mkLambda ( getBinderName eb) .default (mkNot c') eb'
let e' := mkApp5 f α c' i' tb' eb'
let h := mkApp8 (mkConst ``dite_cond_congr us) α c c' i i' tb eb h
return .done { expr := e', proof? := some h }
private def getSimprocs (numIndices : Nat) (useDecide : Bool) : MetaM (Array Simprocs) := do
let s : Simprocs := {}
let s := s.addCore #[.const ``ite 5, .star, .star, .star, .star, .star] ``reduceIte' (post := false) (.inl <| reduceIte' numIndices useDecide)
let s := s.addCore #[.const ``dite 5, .star, .star, .star, .star, .star] ``reduceDIte' (post := false) (.inl <| reduceDIte' numIndices useDecide)
return #[s]
def mkDischarge? (useDecide := false) : MetaM Simp.Discharge :=
return discharge? ( getLCtx).numIndices useDecide
@@ -196,44 +273,68 @@ end SplitIf
open SplitIf
def simpIfTarget (mvarId : MVarId) (useDecide := false) : MetaM MVarId := do
let mut ctx getSimpContext
if let (some mvarId', _) simpTarget mvarId ctx {} ( mvarId.withContext <| mkDischarge? useDecide) (mayCloseGoal := false) then
private def getNumIndices (mvarId : MVarId) : MetaM Nat :=
mvarId.withContext do return ( getLCtx).numIndices
/--
Simplify the `if-then-else` targeted by the `split` tactic. If `useNewSemantics` is `true`, the flag
`backward.split` is ignored.
-/
def simpIfTarget (mvarId : MVarId) (useDecide := false) (useNewSemantics := false) : MetaM MVarId := do
if useNewSemantics || !backward.split.get ( getOptions) then
let ctx getSimpContext'
let numIndices getNumIndices mvarId
let s getSimprocs numIndices useDecide
let (some mvarId', _) simpTarget mvarId ctx s (mayCloseGoal := false) | unreachable!
return mvarId'
else
unreachable!
let mut ctx getSimpContext
let (some mvarId', _) simpTarget mvarId ctx {} ( mvarId.withContext <| mkDischarge? useDecide) (mayCloseGoal := false) | unreachable!
return mvarId'
def simpIfLocalDecl (mvarId : MVarId) (fvarId : FVarId) : MetaM MVarId := do
let mut ctx getSimpContext
if let (some (_, mvarId'), _) simpLocalDecl mvarId fvarId ctx {} ( mvarId.withContext <| mkDischarge?) (mayCloseGoal := false) then
/--
Simplify the `if-then-else` targeted by the `split` tactic. If `useNewSemantics` is `true`, the flag
`backward.split` is ignored.
-/
def simpIfLocalDecl (mvarId : MVarId) (fvarId : FVarId) (useNewSemantics := false) : MetaM MVarId := do
if useNewSemantics || !backward.split.get ( getOptions) then
let ctx getSimpContext'
let numIndices getNumIndices mvarId
let s getSimprocs numIndices (useDecide := false)
let (some (_, mvarId'), _) simpLocalDecl mvarId fvarId ctx s (mayCloseGoal := false) | unreachable!
return mvarId'
else
unreachable!
let mut ctx getSimpContext
let (some (_, mvarId'), _) simpLocalDecl mvarId fvarId ctx {} ( mvarId.withContext <| mkDischarge?) (mayCloseGoal := false) | unreachable!
return mvarId'
def splitIfTarget? (mvarId : MVarId) (hName? : Option Name := none) : MetaM (Option (ByCasesSubgoal × ByCasesSubgoal)) := commitWhenSome? do
if let some (s₁, s₂) splitIfAt? mvarId ( mvarId.getType) hName? then
let mvarId₁ simpIfTarget s₁.mvarId
let mvarId₂ simpIfTarget s₂.mvarId
/--
Split an `if-then-else` in the goal target.
If `useNewSemantics` is `true`, the flag `backward.split` is ignored.
-/
def splitIfTarget? (mvarId : MVarId) (hName? : Option Name := none) (useNewSemantics := false) : MetaM (Option (ByCasesSubgoal × ByCasesSubgoal)) := commitWhenSome? do
let some (s₁, s₂) splitIfAt? mvarId ( mvarId.getType) hName? | return none
let mvarId₁ simpIfTarget s₁.mvarId (useNewSemantics := useNewSemantics)
let mvarId₂ simpIfTarget s₂.mvarId (useNewSemantics := useNewSemantics)
if s₁.mvarId == mvarId₁ && s₂.mvarId == mvarId₂ then
trace[split.failure] "`split` tactic failed to simplify target using new hypotheses Goals:\n{mvarId₁}\n{mvarId₂}"
return none
else
return some ({ s₁ with mvarId := mvarId₁ }, { s₂ with mvarId := mvarId₂ })
/--
Split an `if-then-else` in the hypothesis `fvarId`.
-/
def splitIfLocalDecl? (mvarId : MVarId) (fvarId : FVarId) (hName? : Option Name := none) : MetaM (Option (MVarId × MVarId)) := commitWhenSome? do
mvarId.withContext do
let some (s₁, s₂) splitIfAt? mvarId ( inferType (mkFVar fvarId)) hName? | return none
let mvarId₁ simpIfLocalDecl s₁.mvarId fvarId
let mvarId₂ simpIfLocalDecl s₂.mvarId fvarId
if s₁.mvarId == mvarId₁ && s₂.mvarId == mvarId₂ then
trace[split.failure] "`split` tactic failed to simplify target using new hypotheses Goals:\n{mvarId₁}\n{mvarId₂}"
return none
else
return some ({ s₁ with mvarId := mvarId₁ }, { s₂ with mvarId := mvarId₂ })
else
return none
def splitIfLocalDecl? (mvarId : MVarId) (fvarId : FVarId) (hName? : Option Name := none) : MetaM (Option (MVarId × MVarId)) := commitWhenSome? do
mvarId.withContext do
if let some (s₁, s₂) splitIfAt? mvarId ( inferType (mkFVar fvarId)) hName? then
let mvarId₁ simpIfLocalDecl s₁.mvarId fvarId
let mvarId₂ simpIfLocalDecl s₂.mvarId fvarId
if s₁.mvarId == mvarId₁ && s₂.mvarId == mvarId₂ then
trace[split.failure] "`split` tactic failed to simplify target using new hypotheses Goals:\n{mvarId₁}\n{mvarId₂}"
return none
else
return some (mvarId₁, mvarId₂)
else
return none
return some (mvarId₁, mvarId₂)
builtin_initialize registerTraceClass `Meta.Tactic.splitIf

56
tests/lean/run/9322.lean Normal file
View File

@@ -0,0 +1,56 @@
structure CompoundName where
module : String
name : String
structure T where
name : CompoundName
placeholder : Nat
/--
Value whose size is bounded by a constant offset of another value.
-/
abbrev Bounded (α : Type _) [SizeOf α] {β} [SizeOf β] (e : β) (c : Int) := { a : α // sizeOf a sizeOf e + c }
def getArg (e : T) : Bounded T e (-1) := sorry
mutual
inductive A where
| a : A
| b : C A
inductive C
| c : A C
end
mutual
def mkA (x : T) : A :=
match x.name with
| CompoundName.mk "LongName" "a" => .a
| CompoundName.mk "LongName" "b" => .a
| CompoundName.mk "LongName" "c" => .a
| CompoundName.mk "LongName" "d" => .a
| CompoundName.mk "LongName" "e" => .a
| CompoundName.mk "LongName" "f" => .a
| CompoundName.mk "LongName" "g" => .a
| CompoundName.mk "LongName" "h" => .a
| CompoundName.mk "LongName" "i" => .a
| CompoundName.mk "LongName" "j" => .a
| CompoundName.mk "LongName" "k" => .a
| CompoundName.mk "LongName" "z" =>
let y, _ := getArg x
.b (mkC y)
| _ => sorry
termination_by sizeOf x
def mkC (x : T) : C :=
match x.name with
| CompoundName.mk "LongName" "b" =>
let y, _ := getArg x
.c (mkA y)
| f => sorry
termination_by sizeOf x
end

View File

@@ -8,3 +8,57 @@ example (b : Bool) : (if (if b then true else true) then 1 else 2) = 1 := by
guard_target = (if true = true then 1 else 2) = 1
guard_hyp h : ¬b = true
simp
example (b : Bool) : (if h : (if b then true else true) then 1 else 2) = 1 := by
split
next h' =>
guard_target = (if h : true = true then 1 else 2) = 1
guard_hyp h' : b = true
simp
next h' =>
guard_target = (if h : true = true then 1 else 2) = 1
guard_hyp h' : ¬b = true
simp
opaque f (a : Nat) (h : a > 0) : Nat
axiom fax : f a h = a
example : (if h : (if true then a > 0 else False) then f a (by grind) else a) = a := by
split
next =>
split
next => simp [fax]
next => simp
next => simp
set_option backward.split false
example (b : Bool) : (if (if b then true else true) then 1 else 2) = 1 := by
split
next h =>
guard_target = (if true = true then 1 else 2) = 1
guard_hyp h : b = true
simp
next h =>
guard_target = (if true = true then 1 else 2) = 1
guard_hyp h : ¬b = true
simp
example (b : Bool) : (if h : (if b then true else true) then 1 else 2) = 1 := by
split
next h' =>
guard_target = (if h : true = true then 1 else 2) = 1
guard_hyp h' : b = true
simp
next h' =>
guard_target = (if h : true = true then 1 else 2) = 1
guard_hyp h' : ¬b = true
simp
example : (if h : (if true then a > 0 else False) then f a (by grind) else a) = a := by
split
next =>
split
next => simp [fax]
next => simp
next => simp