feat: well-founded definitions irreducible by default (#4061)

we keep running into examples where working with well-founded recursion
is slow because defeq checks (which are all over the place, including
failing ones that are back-tracked) unfold well-founded definitions.

The definition of a function defined by well-founded recursion should be
an implementation detail that should only be peeked inside by the
equation generator and the functional induction generator.

We now mark the mutual recursive function as irreducible (if the user
did not
set a flag explicitly), and use `withAtLeastTransparency .all` when
producing
the equations.

Proofs can be fixed by using rewriting, or – a bit blunt, but nice for
adjusting
existing proofs – using `unseal` (a.k.a. `attribute [local
semireducible]`).

Mathlib performance does not change a whole lot:

http://speed.lean-fro.org/mathlib4/compare/08b82265-75db-4a28-b12b-08751b9ad04a/to/16f46d5e-28b1-41c4-a107-a6f6594841f8
Build instructions -0.126 %, four modules with significant instructions
decrease.

To reduce impact, these definitions were changed:

* `Nat.mod`, to make `1 % n` reduce definitionally, so that `1` as a
`Fin 2` literal
works nicely. Theorems with larger `Fin` literals tend to need a `unseal
Nat.modCore`
   https://github.com/leanprover/lean4/pull/4098
* `List.ofFn` rewritten to be structurally recursive and not go via
`Array.ofFn`:
   https://github.com/leanprover-community/batteries/pull/784

Alternative designs explored were

 * Making `WellFounded.fix` irreducible. 
 
One benefit is that recursive functions with equal definitions (possibly
after
instantiating fixed parameters) are defeq; this is used in mathlib to
relate

[`OrdinalApprox.gfpApprox`](https://leanprover-community.github.io/mathlib4_docs/Mathlib/SetTheory/Ordinal/FixedPointApproximants.html#OrdinalApprox.gfpApprox)
with `.lfpApprox`.
   
   But the downside is that one cannot use `unseal` in a
targeted way, being explicit in which recursive function needs to be
reducible here.

And in cases where Lean does unwanted unfolding, we’d still unfold the
recursive
definition once to expose `WellFounded.fix`, leading to large terms for
often no good
   reason.

* Defining `WellFounded.fix` to unroll defintionally once before hitting
a irreducible
`WellFounded.fixF`. This was explored in #4002. It shares most of the
ups and downs
with the previous variant, with the additional neat benefit that
function calls that
do not lead to recursive cases (e.g. a `[]` base case) reduce nicely.
This means that
   the majority of existing `rfl` proofs continue to work.

Issue #4051, which demonstrates how badly things can go if wf recursive
functions can be
unrolled, showed that making the recursive function irreducible there
leads to noticeably
faster elaboration than making `WellFounded.fix` irreducible; this is
good evidence that
the present PR is the way to go. 

This fixes https://github.com/leanprover/lean4/issues/3988

---------

Co-authored-by: Leonardo de Moura <leomoura@amazon.com>
This commit is contained in:
Joachim Breitner
2024-05-10 08:45:21 +02:00
committed by GitHub
parent ca6437df71
commit 39286862e3
20 changed files with 228 additions and 47 deletions

View File

@@ -11,7 +11,18 @@ of each version.
v4.9.0 (development in progress)
---------
v4.8.0
* Functions defined by well-founded recursion are now marked as
`@[irreducible]`, which should prevent expensive and often unfruitful
unfolding of such definitions.
Existing proofs that hold by definitional equality (e.g. `rfl`) can be
rewritten to explictly unfold the function definition (using `simp`,
`unfold`, `rw`), or the recursive function can be temporariliy made
semireducible (using `unseal f in` before the command) or the function
definition itself can be marked as `@[semireducible]` to get the previous
behavor.
v4.8.0
---------
* **Executables configured with `supportInterpreter := true` on Windows should now be run via `lake exe` to function properly.**

View File

@@ -11,6 +11,9 @@ import Init.ByCases
import Init.Conv
import Init.Omega
-- Remove after the next stage0 update
set_option allowUnsafeReducibility true
namespace Fin
/-- If you actually have an element of `Fin n`, then the `n` is always positive -/
@@ -205,6 +208,7 @@ theorem val_add_one {n : Nat} (i : Fin (n + 1)) :
| .inl h => cases Fin.eq_of_val_eq h; simp
| .inr h => simpa [Fin.ne_of_lt h] using val_add_one_of_lt h
unseal Nat.modCore in
@[simp] theorem val_two {n : Nat} : (2 : Fin (n + 3)).val = 2 := rfl
theorem add_one_pos (i : Fin (n + 1)) (h : i < Fin.last n) : (0 : Fin (n + 1)) < i + 1 := by
@@ -239,6 +243,7 @@ theorem succ_ne_zero {n} : ∀ k : Fin n, Fin.succ k ≠ 0
@[simp] theorem succ_zero_eq_one : Fin.succ (0 : Fin (n + 1)) = 1 := rfl
unseal Nat.modCore in
/-- Version of `succ_one_eq_two` to be used by `dsimp` -/
@[simp] theorem succ_one_eq_two : Fin.succ (1 : Fin (n + 2)) = 2 := rfl
@@ -390,6 +395,7 @@ theorem castSucc_lt_last (a : Fin n) : castSucc a < last n := a.is_lt
@[simp] theorem castSucc_zero : castSucc (0 : Fin (n + 1)) = 0 := rfl
unseal Nat.modCore in
@[simp] theorem castSucc_one {n : Nat} : castSucc (1 : Fin (n + 2)) = 1 := rfl
/-- `castSucc i` is positive when `i` is positive -/

View File

@@ -14,6 +14,8 @@ import Init.RCases
# Lemmas about integer division needed to bootstrap `omega`.
-/
-- Remove after the next stage0 update
set_option allowUnsafeReducibility true
open Nat (succ)
@@ -142,12 +144,14 @@ theorem eq_one_of_mul_eq_one_left {a b : Int} (H : 0 ≤ b) (H' : a * b = 1) : b
| ofNat _ => show ofNat _ = _ by simp
| -[_+1] => show -ofNat _ = _ by simp
unseal Nat.div in
@[simp] protected theorem div_zero : a : Int, div a 0 = 0
| ofNat _ => show ofNat _ = _ by simp
| -[_+1] => rfl
@[simp] theorem zero_fdiv (b : Int) : fdiv 0 b = 0 := by cases b <;> rfl
unseal Nat.div in
@[simp] protected theorem fdiv_zero : a : Int, fdiv a 0 = 0
| 0 => rfl
| succ _ => rfl
@@ -765,11 +769,13 @@ theorem ediv_eq_ediv_of_mul_eq_mul {a b c d : Int}
| (n:Nat) => congrArg ofNat (Nat.div_one _)
| -[n+1] => by simp [Int.div, neg_ofNat_succ]; rfl
unseal Nat.div in
@[simp] protected theorem div_neg : a b : Int, a.div (-b) = -(a.div b)
| ofNat m, 0 => show ofNat (m / 0) = -(m / 0) by rw [Nat.div_zero]; rfl
| ofNat m, -[n+1] | -[m+1], succ n => (Int.neg_neg _).symm
| ofNat m, succ n | -[m+1], 0 | -[m+1], -[n+1] => rfl
unseal Nat.div in
@[simp] protected theorem neg_div : a b : Int, (-a).div b = -(a.div b)
| 0, n => by simp [Int.neg_zero]
| succ m, (n:Nat) | -[m+1], 0 | -[m+1], -[n+1] => rfl
@@ -938,6 +944,7 @@ theorem fdiv_nonneg {a b : Int} (Ha : 0 ≤ a) (Hb : 0 ≤ b) : 0 ≤ a.fdiv b :
match a, b, eq_ofNat_of_zero_le Ha, eq_ofNat_of_zero_le Hb with
| _, _, _, rfl, _, rfl => ofNat_fdiv .. ofNat_zero_le _
unseal Nat.div in
theorem fdiv_nonpos : {a b : Int}, 0 a b 0 a.fdiv b 0
| 0, 0, _, _ | 0, -[_+1], _, _ | succ _, 0, _, _ | succ _, -[_+1], _, _ => _

View File

@@ -50,7 +50,10 @@ noncomputable def div2Induction {motive : Nat → Sort u}
apply hyp
exact Nat.div_lt_self n_pos (Nat.le_refl _)
@[simp] theorem zero_and (x : Nat) : 0 &&& x = 0 := by rfl
@[simp] theorem zero_and (x : Nat) : 0 &&& x = 0 := by
simp only [HAnd.hAnd, AndOp.and, land]
unfold bitwise
simp
@[simp] theorem and_zero (x : Nat) : x &&& 0 = 0 := by
simp only [HAnd.hAnd, AndOp.and, land]

View File

@@ -37,11 +37,11 @@ def gcd (m n : @& Nat) : Nat :=
termination_by m
decreasing_by simp_wf; apply mod_lt _ (zero_lt_of_ne_zero _); assumption
@[simp] theorem gcd_zero_left (y : Nat) : gcd 0 y = y :=
rfl
@[simp] theorem gcd_zero_left (y : Nat) : gcd 0 y = y := by
rw [gcd]; rfl
theorem gcd_succ (x y : Nat) : gcd (succ x) y = gcd (y % succ x) (succ x) :=
rfl
theorem gcd_succ (x y : Nat) : gcd (succ x) y = gcd (y % succ x) (succ x) := by
rw [gcd]; rfl
@[simp] theorem gcd_one_left (n : Nat) : gcd 1 n = 1 := by
rw [gcd_succ, mod_one]
@@ -64,7 +64,7 @@ instance : Std.IdempotentOp gcd := ⟨gcd_self⟩
theorem gcd_rec (m n : Nat) : gcd m n = gcd (n % m) m :=
match m with
| 0 => by have := (mod_zero n).symm; rwa [gcd_zero_right]
| 0 => by have := (mod_zero n).symm; rwa [gcd, gcd_zero_right]
| _ + 1 => by simp [gcd_succ]
@[elab_as_elim] theorem gcd.induction {P : Nat Nat Prop} (m n : Nat)

View File

@@ -677,6 +677,10 @@ protected theorem pow_lt_pow_iff_right {a n m : Nat} (h : 1 < a) :
/-! ### log2 -/
@[simp]
theorem log2_zero : Nat.log2 0 = 0 := by
simp [Nat.log2]
theorem le_log2 (h : n 0) : k n.log2 2 ^ k n := by
match k with
| 0 => simp [show 1 n from Nat.pos_of_ne_zero h]
@@ -697,7 +701,7 @@ theorem log2_self_le (h : n ≠ 0) : 2 ^ n.log2 ≤ n := (le_log2 h).1 (Nat.le_r
theorem lt_log2_self : n < 2 ^ (n.log2 + 1) :=
match n with
| 0 => Nat.zero_lt_two
| 0 => by simp
| n+1 => (log2_lt n.succ_ne_zero).1 (Nat.le_refl _)
/-! ### dvd -/

View File

@@ -81,7 +81,7 @@ private partial def mkProof (declName : Name) (type : Expr) : MetaM Expr := do
let (_, mvarId) main.mvarId!.intros
let rec go (mvarId : MVarId) : MetaM Unit := do
trace[Elab.definition.wf.eqns] "step\n{MessageData.ofGoal mvarId}"
if ( tryURefl mvarId) then
if withAtLeastTransparency .all (tryURefl mvarId) then
return ()
else if ( tryContradiction mvarId) then
return ()

View File

@@ -132,12 +132,15 @@ def wfRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
return { unaryPreDef with value }
trace[Elab.definition.wf] ">> {preDefNonRec.declName} :=\n{preDefNonRec.value}"
let preDefs preDefs.mapM fun d => eraseRecAppSyntax d
if ( isOnlyOneUnaryDef preDefs fixedPrefixSize) then
addNonRec preDefNonRec (applyAttrAfterCompilation := false)
else
withEnableInfoTree false do
-- Do not complain if the user sets @[semireducible], which usually is a noop,
-- we recognize that below and then do not set @[irreducible]
withOptions (allowUnsafeReducibility.set · true) do
if ( isOnlyOneUnaryDef preDefs fixedPrefixSize) then
addNonRec preDefNonRec (applyAttrAfterCompilation := false)
addNonRecPreDefs fixedPrefixSize argsPacker preDefs preDefNonRec
else
withEnableInfoTree false do
addNonRec preDefNonRec (applyAttrAfterCompilation := false)
addNonRecPreDefs fixedPrefixSize argsPacker preDefs preDefNonRec
-- We create the `_unsafe_rec` before we abstract nested proofs.
-- Reason: the nested proofs may be referring to the _unsafe_rec.
addAndCompilePartialRec preDefs
@@ -146,6 +149,10 @@ def wfRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
for preDef in preDefs do
markAsRecursive preDef.declName
applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation
-- Unless the user asks for something else, mark the definition as irreducible
unless preDef.modifiers.attrs.any fun a =>
a.name = `semireducible || a.name = `reducible || a.name = `semireducible do
setIrreducibleAttribute preDef.declName
builtin_initialize registerTraceClass `Elab.definition.wf

View File

@@ -184,4 +184,9 @@ def isIrreducible [Monad m] [MonadEnv m] (declName : Name) : m Bool := do
| .irreducible => return true
| _ => return false
/-- Set the given declaration as `[irreducible]` -/
def setIrreducibleAttribute [Monad m] [MonadEnv m] (declName : Name) : m Unit := do
setReducibilityStatus declName ReducibilityStatus.irreducible
end Lean

View File

@@ -1,5 +1,7 @@
#include "util/options.h"
// please auto update stage0
namespace lean {
options get_default_options() {
options opts;

View File

@@ -6,4 +6,4 @@ def f := #[true].any id 0 USize.size
-- `native_decide` used to prove `false` here, due to a bug in `Array.anyMUnsafe`.
example : f = true := by native_decide
example : f = true := by simp (config := { decide := true }) [f, Array.any, Array.anyM]
example : f = true := by simp (config := { decide := true }) [f, Array.any, Array.anyM, Array.anyM.loop]

View File

@@ -26,7 +26,9 @@ def onlyZeros : Tree → Prop
| .node [] => True
| .node (x::s) => onlyZeros x onlyZeros (.node s)
/-- Pattern-matching on `OnlyZeros` works despite `below` and `brecOn` not being generated. -/
unseal onlyZeros in
/-- Pattern-matching on `OnlyZeros` works despite `below` and `brecOn` not being generated
if we make `onlyZeros` semireducible-/
def toFixPoint : OnlyZeros t onlyZeros t
| .leaf => rfl
| .node [] _ => True.intro

View File

@@ -28,6 +28,7 @@ info: [reduction] unfolded declarations (max: 1725, num: 4):
Acc.rec ↦ 754use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
-/
#guard_msgs in
unseal ack in
set_option diagnostics.threshold 500 in
set_option diagnostics true in
theorem ex : ack 3 2 = 29 :=

View File

@@ -8,7 +8,7 @@ def f (x : Nat) : Nat := by
#eval f 10
example : f x.succ = 2 * f x := rfl
example : f x.succ = 2 * f x := by rw [f]; rfl
end Ex1
namespace Ex2

View File

@@ -22,7 +22,7 @@ termination_by (x, y)
example (x y : Nat) : f x y > 0 := by
induction x, y with
| zero_zero => decide
| zero_zero => simp [f]
| succ_zero x ih => simp [f, ih]
| zero_succ y ih => simp [f, ih]
| succ_succ x y ih => simp [f, ih]

View File

@@ -25,7 +25,7 @@ in the list, ignoring delays
theorem length_toList (l : LazyList α) : l.toList.length = l.length := by
match l with
| nil => rfl
| nil => simp [length_toList]
| cons a as => simp [length_toList as]
| delayed as => simp [length_toList as.get]

View File

@@ -15,23 +15,21 @@ termination_by l => l.length
decreasing_by
all_goals sorry
theorem len_nil : len ([] : List α) = 0 := by
simp [len]
-- The `simp [len]` above generated the following equation theorems for len
-- The equational theorems are
#check @len.eq_1
#check @len.eq_2
#check @len.eq_3 -- It is conditional, and may be tricky to use.
#check @len.eq_def
theorem len_nil : len ([] : List α) = 0 := by
simp [len]
theorem len_1 (a : α) : len [a] = 1 := by
simp [len]
theorem len_2 (a b : α) (bs : List α) : len (a::b::bs) = 1 + len (b::bs) := by
conv => lhs; unfold len
rfl
-- The `unfold` tactic above generated the following theorem
#check @len.eq_def
cases bs <;> simp [splitList, len_1]
theorem len_cons (a : α) (as : List α) : len (a::as) = 1 + len as := by
cases as with
@@ -41,7 +39,7 @@ theorem len_cons (a : α) (as : List α) : len (a::as) = 1 + len as := by
theorem listlen : l : List α, l.length = len l := by
intro l
induction l with
| nil => rfl
| nil => simp [len]
| cons h t ih =>
simp [List.length, len_cons, ih]
rw [Nat.add_comm]

View File

@@ -34,23 +34,21 @@ def len : List α → Nat
len fst + len snd
termination_by xs => xs.length
theorem len_nil : len ([] : List α) = 0 := by
simp [len]
-- The `simp [len]` above generated the following equation theorems for len
-- The equational theorems are
#check @len.eq_1
#check @len.eq_2
#check @len.eq_3
#check @len.eq_def
theorem len_nil : len ([] : List α) = 0 := by
simp [len]
theorem len_1 (a : α) : len [a] = 1 := by
simp [len]
theorem len_2 (a b : α) (bs : List α) : len (a::b::bs) = 1 + len (b::bs) := by
conv => lhs; unfold len
rfl
-- The `unfold` tactic above generated the following theorem
#check @len.eq_def
simp [len, splitList]
theorem len_cons (a : α) (as : List α) : len (a::as) = 1 + len as := by
cases as with
@@ -60,7 +58,7 @@ theorem len_cons (a : α) (as : List α) : len (a::as) = 1 + len as := by
theorem listlen : l : List α, l.length = len l := by
intro l
induction l with
| nil => rfl
| nil => simp [len_nil]
| cons h t ih =>
simp [List.length, len_cons, ih]
rw [Nat.add_comm]
@@ -85,23 +83,21 @@ decreasing_by
subst h₂
simp_arith [eq_of_heq h₃] at this |- ; simp [this]
theorem len_nil : len ([] : List α) = 0 := by
simp [len]
-- The `simp [len]` above generated the following equation theorems for len
-- The equational theorems are
#check @len.eq_1
#check @len.eq_2
#check @len.eq_3
#check @len.eq_def
theorem len_nil : len ([] : List α) = 0 := by
simp [len]
theorem len_1 (a : α) : len [a] = 1 := by
simp [len]
theorem len_2 (a b : α) (bs : List α) : len (a::b::bs) = 1 + len (b::bs) := by
conv => lhs; unfold len
rfl
-- The `unfold` tactic above generated the following theorem
#check @len.eq_def
simp [len, splitList]
theorem len_cons (a : α) (as : List α) : len (a::as) = 1 + len as := by
cases as with
@@ -111,7 +107,7 @@ theorem len_cons (a : α) (as : List α) : len (a::as) = 1 + len as := by
theorem listlen : l : List α, l.length = len l := by
intro l
induction l with
| nil => rfl
| nil => simp [len_nil]
| cons h t ih =>
simp [List.length, len_cons, ih]
rw [Nat.add_comm]

139
tests/lean/run/wfirred.lean Normal file
View File

@@ -0,0 +1,139 @@
/-!
Tests that definitions by well-founded recursion are irreducible.
-/
def foo : Nat Nat
| 0 => 0
| n+1 => foo n
termination_by n => n
/--
error: type mismatch
rfl
has type
foo 0 = foo 0 : Prop
but is expected to have type
foo 0 = 0 : Prop
-/
#guard_msgs in
example : foo 0 = 0 := rfl
/--
error: type mismatch
rfl
has type
foo (n + 1) = foo (n + 1) : Prop
but is expected to have type
foo (n + 1) = foo n : Prop
-/
#guard_msgs in
example : foo (n+1) = foo n := rfl
-- This succeeding is a bug or misfeature in the rfl tactic, using the kernel defeq check
#guard_msgs in
example : foo 0 = 0 := by rfl
-- It only works on closed terms:
/--
error: The rfl tactic failed. Possible reasons:
- The goal is not a reflexive relation (neither `=` nor a relation with a @[refl] lemma).
- The arguments of the relation are not equal.
Try using the reflexivitiy lemma for your relation explicitly, e.g. `exact Eq.rfl`.
n : Nat
⊢ foo (n + 1) = foo n
-/
#guard_msgs in
example : foo (n+1) = foo n := by rfl
section Unsealed
unseal foo
example : foo 0 = 0 := rfl
example : foo 0 = 0 := by rfl
example : foo (n+1) = foo n := rfl
example : foo (n+1) = foo n := by rfl
end Unsealed
--should be sealed again here
/--
error: type mismatch
rfl
has type
foo 0 = foo 0 : Prop
but is expected to have type
foo 0 = 0 : Prop
-/
#guard_msgs in
example : foo 0 = 0 := rfl
def bar : Nat Nat
| 0 => 0
| n+1 => bar n
termination_by n => n
-- Once unsealed, the full internals are visible. This allows one to prove, for example
/--
error: type mismatch
rfl
has type
foo = foo : Prop
but is expected to have type
foo = bar : Prop
-/
#guard_msgs in
example : foo = bar := rfl
unseal foo bar in
example : foo = bar := rfl
-- Attributes on the definition take precedence
@[semireducible] def baz : Nat Nat
| 0 => 0
| n+1 => baz n
termination_by n => n
example : baz 0 = 0 := rfl
seal baz in
/--
error: type mismatch
rfl
has type
baz 0 = baz 0 : Prop
but is expected to have type
baz 0 = 0 : Prop
-/
#guard_msgs in
example : baz 0 = 0 := rfl
example : baz 0 = 0 := rfl
@[reducible] def quux : Nat Nat
| 0 => 0
| n+1 => quux n
termination_by n => n
example : quux 0 = 0 := rfl
set_option allowUnsafeReducibility true in
seal quux in
/--
error: type mismatch
rfl
has type
quux 0 = quux 0 : Prop
but is expected to have type
quux 0 = 0 : Prop
-/
#guard_msgs in
example : quux 0 = 0 := rfl
example : quux 0 = 0 := rfl

View File

@@ -1,4 +1,4 @@
def f : Nat → Nat :=
@[irreducible] def f : Nat → Nat :=
f.proof_1.fix fun n a =>
if h : n = 0 then 1
else