Compare commits

..

6 Commits

Author SHA1 Message Date
Kim Morrison
9d8dc1a556 missing theorem 2025-04-10 09:27:55 +10:00
Kim Morrison
a5ddc8f4f9 oops, merge problems 2025-04-10 09:26:29 +10:00
Kim Morrison
72a63d87c9 Merge remote-tracking branch 'origin/master' into UIntX.ofInt 2025-04-10 09:23:25 +10:00
Kim Morrison
c99eade8cf fix merge 2025-04-09 23:55:07 +10:00
Kim Morrison
3cea6eb7ad fix merge 2025-04-09 23:54:27 +10:00
Kim Morrison
5b9adfed13 feat: UIntX.ofInt 2025-04-09 23:51:56 +10:00
583 changed files with 1320 additions and 3661 deletions

View File

@@ -46,7 +46,7 @@ jobs:
CCACHE_DIR: ${{ github.workspace }}/.ccache
CCACHE_COMPRESS: true
# current cache limit
CCACHE_MAXSIZE: 400M
CCACHE_MAXSIZE: 600M
# squelch error message about missing nixpkgs channel
NIX_BUILD_SHELL: bash
LSAN_OPTIONS: max_leaks=10

View File

@@ -256,18 +256,17 @@ jobs:
"llvm-url": "https://github.com/leanprover/lean-llvm/releases/download/15.0.1/lean-llvm-aarch64-linux-gnu.tar.zst",
"prepare-llvm": "../script/prepare-llvm-linux.sh lean-llvm*"
},
// Started running out of memory building expensive modules, a 2GB heap is just not that much even before fragmentation
//{
// "name": "Linux 32bit",
// "os": "ubuntu-latest",
// // Use 32bit on stage0 and stage1 to keep oleans compatible
// "CMAKE_OPTIONS": "-DSTAGE0_USE_GMP=OFF -DSTAGE0_LEAN_EXTRA_CXX_FLAGS='-m32' -DSTAGE0_LEANC_OPTS='-m32' -DSTAGE0_MMAP=OFF -DUSE_GMP=OFF -DLEAN_EXTRA_CXX_FLAGS='-m32' -DLEANC_OPTS='-m32' -DMMAP=OFF -DLEAN_INSTALL_SUFFIX=-linux_x86 -DCMAKE_LIBRARY_PATH=/usr/lib/i386-linux-gnu/ -DSTAGE0_CMAKE_LIBRARY_PATH=/usr/lib/i386-linux-gnu/ -DPKG_CONFIG_EXECUTABLE=/usr/bin/i386-linux-gnu-pkg-config",
// "cmultilib": true,
// "release": true,
// "check-level": 2,
// "cross": true,
// "shell": "bash -euxo pipefail {0}"
//}
{
"name": "Linux 32bit",
"os": "ubuntu-latest",
// Use 32bit on stage0 and stage1 to keep oleans compatible
"CMAKE_OPTIONS": "-DSTAGE0_USE_GMP=OFF -DSTAGE0_LEAN_EXTRA_CXX_FLAGS='-m32' -DSTAGE0_LEANC_OPTS='-m32' -DSTAGE0_MMAP=OFF -DUSE_GMP=OFF -DLEAN_EXTRA_CXX_FLAGS='-m32' -DLEANC_OPTS='-m32' -DMMAP=OFF -DLEAN_INSTALL_SUFFIX=-linux_x86 -DCMAKE_LIBRARY_PATH=/usr/lib/i386-linux-gnu/ -DSTAGE0_CMAKE_LIBRARY_PATH=/usr/lib/i386-linux-gnu/ -DPKG_CONFIG_EXECUTABLE=/usr/bin/i386-linux-gnu-pkg-config",
"cmultilib": true,
"release": true,
"check-level": 2,
"cross": true,
"shell": "bash -euxo pipefail {0}"
}
// {
// "name": "Web Assembly",
// "os": "ubuntu-latest",

View File

@@ -59,9 +59,6 @@ instance : Monad (OptionT m) where
pure := OptionT.pure
bind := OptionT.bind
instance {m : Type u Type v} [Pure m] : Inhabited (OptionT m α) where
default := pure (f:=m) default
/--
Recovers from failures. Typically used via the `<|>` operator.
-/

View File

@@ -34,6 +34,7 @@ import Init.Data.Stream
import Init.Data.Prod
import Init.Data.AC
import Init.Data.Queue
import Init.Data.Channel
import Init.Data.Sum
import Init.Data.BEq
import Init.Data.Subtype

View File

@@ -288,17 +288,6 @@ theorem count_flatMap {α} [BEq β] {xs : Array α} {f : α → Array β} {x :
rcases xs with xs
simp [List.count_flatMap, countP_flatMap, Function.comp_def]
theorem countP_replace {a b : α} {xs : Array α} {p : α Bool} :
(xs.replace a b).countP p =
if xs.contains a then xs.countP p + (if p b then 1 else 0) - (if p a then 1 else 0) else xs.countP p := by
rcases xs with xs
simp [List.countP_replace]
theorem count_replace {a b c : α} {xs : Array α} :
(xs.replace a b).count c =
if xs.contains a then xs.count c + (if b == c then 1 else 0) - (if a == c then 1 else 0) else xs.count c := by
simp [count_eq_countP, countP_replace]
-- FIXME these theorems can be restored once `List.erase` and `Array.erase` have been related.
-- theorem count_erase (a b : α) (l : Array α) : count a (l.erase b) = count a l - if b == a then 1 else 0 := by

View File

@@ -446,13 +446,11 @@ theorem findIdx?_eq_none_iff {xs : Array α} {p : α → Bool} :
rcases xs with xs
simp
@[simp]
theorem findIdx?_isSome {xs : Array α} {p : α Bool} :
(xs.findIdx? p).isSome = xs.any p := by
rcases xs with xs
simp [List.findIdx?_isSome]
@[simp]
theorem findIdx?_isNone {xs : Array α} {p : α Bool} :
(xs.findIdx? p).isNone = xs.all (¬p ·) := by
rcases xs with xs
@@ -593,18 +591,6 @@ theorem findFinIdx?_eq_some_iff {xs : Array α} {p : α → Bool} {i : Fin xs.si
· rintro h, w
exact i, i.2, h, fun j hji => w j, by omega hji, rfl
@[simp]
theorem isSome_findFinIdx? {xs : Array α} {p : α Bool} :
(xs.findFinIdx? p).isSome = xs.any p := by
rcases xs with xs
simp
@[simp]
theorem isNone_findFinIdx? {xs : Array α} {p : α Bool} :
(xs.findFinIdx? p).isNone = xs.all (fun x => ¬ p x) := by
rcases xs with xs
simp
@[simp] theorem findFinIdx?_subtype {p : α Prop} {xs : Array { x // p x }}
{f : { x // p x } Bool} {g : α Bool} (hf : x h, f x, h = g x) :
xs.findFinIdx? f = (xs.unattach.findFinIdx? g).map (fun i => i.cast (by simp)) := by
@@ -650,20 +636,6 @@ The lemmas below should be made consistent with those for `findIdx?` (and proved
rcases xs with xs
simp [List.idxOf?_eq_none_iff]
@[simp]
theorem isSome_idxOf? [BEq α] [LawfulBEq α] {xs : Array α} {a : α} :
(xs.idxOf? a).isSome a xs := by
rcases xs with xs
simp
@[simp]
theorem isNone_idxOf? [BEq α] [LawfulBEq α] {xs : Array α} {a : α} :
(xs.idxOf? a).isNone = ¬ a xs := by
rcases xs with xs
simp
/-! ### finIdxOf?
The verification API for `finIdxOf?` is still incomplete.
@@ -686,16 +658,4 @@ theorem idxOf?_eq_map_finIdxOf?_val [BEq α] {xs : Array α} {a : α} :
rcases xs with xs
simp [List.finIdxOf?_eq_some_iff]
@[simp]
theorem isSome_finIdxOf? [BEq α] [LawfulBEq α] {xs : Array α} {a : α} :
(xs.finIdxOf? a).isSome a xs := by
rcases xs with xs
simp
@[simp]
theorem isNone_finIdxOf? [BEq α] [LawfulBEq α] {xs : Array α} {a : α} :
(xs.finIdxOf? a).isNone = ¬ a xs := by
rcases xs with xs
simp
end Array

View File

@@ -65,20 +65,4 @@ theorem swap_perm {xs : Array α} {i j : Nat} (h₁ : i < xs.size) (h₂ : j < x
simp only [swap, perm_iff_toList_perm, toList_set]
apply set_set_perm
namespace Perm
set_option linter.indexVariables false in
theorem extract {xs ys : Array α} (h : xs ~ ys) {lo hi : Nat}
(wlo : i, i < lo xs[i]? = ys[i]?) (whi : i, hi i xs[i]? = ys[i]?) :
(xs.extract lo (hi + 1)) ~ (ys.extract lo (hi + 1)) := by
rcases xs with xs
rcases ys with ys
simp_all only [perm_toArray, List.getElem?_toArray, List.extract_toArray,
List.extract_eq_drop_take]
apply List.Perm.take_of_getElem? (w := fun i h => by simpa using whi (lo + i) (by omega))
apply List.Perm.drop_of_getElem? (w := wlo)
exact h
end Perm
end Array

View File

@@ -227,20 +227,6 @@ SMT-LIB name: `bvmul`.
protected def mul (x y : BitVec n) : BitVec n := BitVec.ofNat n (x.toNat * y.toNat)
instance : Mul (BitVec n) := .mul
/--
Raises a bitvector to a natural number power. Usually accessed via the `^` operator.
Note that this is currently an inefficient implementation,
and should be replaced via an `@[extern]` with a native implementation.
See https://github.com/leanprover/lean4/issues/7887.
-/
protected def pow (x : BitVec n) (y : Nat) : BitVec n :=
match y with
| 0 => 1
| y + 1 => x.pow y * x
instance : Pow (BitVec n) Nat where
pow x y := x.pow y
/--
Unsigned division of bitvectors using the Lean convention where division by zero returns zero.
Usually accessed via the `/` operator.

View File

@@ -3653,13 +3653,6 @@ theorem mul_def {n} {x y : BitVec n} : x * y = (ofFin <| x.toFin * y.toFin) := b
@[simp, bitvec_to_nat] theorem toNat_mul (x y : BitVec n) : (x * y).toNat = (x.toNat * y.toNat) % 2 ^ n := rfl
@[simp] theorem toFin_mul (x y : BitVec n) : (x * y).toFin = (x.toFin * y.toFin) := rfl
theorem ofNat_mul {n} (x y : Nat) : BitVec.ofNat n (x * y) = BitVec.ofNat n x * BitVec.ofNat n y := by
apply eq_of_toNat_eq
simp [BitVec.ofNat, Fin.ofNat'_mul]
theorem ofNat_mul_ofNat {n} (x y : Nat) : BitVec.ofNat n x * BitVec.ofNat n y = BitVec.ofNat n (x * y) :=
(ofNat_mul x y).symm
protected theorem mul_comm (x y : BitVec w) : x * y = y * x := by
apply eq_of_toFin_eq; simpa using Fin.mul_comm ..
instance : Std.Commutative (fun (x y : BitVec w) => x * y) := BitVec.mul_comm
@@ -3753,22 +3746,6 @@ theorem setWidth_mul (x y : BitVec w) (h : i ≤ w) :
have dvd : 2^i 2^w := Nat.pow_dvd_pow _ h
simp [bitvec_to_nat, h, Nat.mod_mod_of_dvd _ dvd]
/-! ### pow -/
@[simp]
protected theorem pow_zero {x : BitVec w} : x ^ 0 = 1#w := rfl
protected theorem pow_succ {x : BitVec w} : x ^ (n + 1) = x ^ n * x := rfl
@[simp]
protected theorem pow_one {x : BitVec w} : x ^ 1 = x := by simp [BitVec.pow_succ]
protected theorem pow_add {x : BitVec w} {n m : Nat}: x ^ (n + m) = (x ^ n) * (x ^ m):= by
induction m with
| zero => simp
| succ m ih =>
rw [ Nat.add_assoc, BitVec.pow_succ, ih, BitVec.mul_assoc, BitVec.pow_succ]
/-! ### le and lt -/
@[bitvec_to_nat] theorem le_def {x y : BitVec n} :

149
src/Init/Data/Channel.lean Normal file
View File

@@ -0,0 +1,149 @@
/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Gabriel Ebner
-/
prelude
import Init.Data.Queue
import Init.System.Promise
import Init.System.Mutex
set_option linter.deprecated false
namespace IO
/--
Internal state of an `Channel`.
We maintain the invariant that at all times either `consumers` or `values` is empty.
-/
@[deprecated "Use Std.Channel.State from Std.Sync.Channel instead" (since := "2024-12-02")]
structure Channel.State (α : Type) where
values : Std.Queue α :=
consumers : Std.Queue (Promise (Option α)) :=
closed := false
deriving Inhabited
/--
FIFO channel with unbounded buffer, where `recv?` returns a `Task`.
A channel can be closed. Once it is closed, all `send`s are ignored, and
`recv?` returns `none` once the queue is empty.
-/
@[deprecated "Use Std.Channel from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel (α : Type) : Type := Mutex (Channel.State α)
instance : Nonempty (Channel α) :=
inferInstanceAs (Nonempty (Mutex _))
/-- Creates a new `Channel`. -/
@[deprecated "Use Std.Channel.new from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel.new : BaseIO (Channel α) :=
Mutex.new {}
/--
Sends a message on an `Channel`.
This function does not block.
-/
@[deprecated "Use Std.Channel.send from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel.send (ch : Channel α) (v : α) : BaseIO Unit :=
ch.atomically do
let st get
if st.closed then return
if let some (consumer, consumers) := st.consumers.dequeue? then
consumer.resolve (some v)
set { st with consumers }
else
set { st with values := st.values.enqueue v }
/--
Closes an `Channel`.
-/
@[deprecated "Use Std.Channel.close from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel.close (ch : Channel α) : BaseIO Unit :=
ch.atomically do
let st get
for consumer in st.consumers.toArray do consumer.resolve none
set { st with closed := true, consumers := }
/--
Receives a message, without blocking.
The returned task waits for the message.
Every message is only received once.
Returns `none` if the channel is closed and the queue is empty.
-/
@[deprecated "Use Std.Channel.recv? from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel.recv? (ch : Channel α) : BaseIO (Task (Option α)) :=
ch.atomically do
let st get
if let some (a, values) := st.values.dequeue? then
set { st with values }
return .pure a
else if !st.closed then
let promise Promise.new
set { st with consumers := st.consumers.enqueue promise }
return promise.result
else
return .pure none
/--
`ch.forAsync f` calls `f` for every messages received on `ch`.
Note that if this function is called twice, each `forAsync` only gets half the messages.
-/
@[deprecated "Use Std.Channel.forAsync from Std.Sync.Channel instead" (since := "2024-12-02")]
partial def Channel.forAsync (f : α BaseIO Unit) (ch : Channel α)
(prio : Task.Priority := .default) : BaseIO (Task Unit) := do
BaseIO.bindTask (prio := prio) ( ch.recv?) fun
| none => return .pure ()
| some v => do f v; ch.forAsync f prio
/--
Receives all currently queued messages from the channel.
Those messages are dequeued and will not be returned by `recv?`.
-/
@[deprecated "Use Std.Channel.recvAllCurrent from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel.recvAllCurrent (ch : Channel α) : BaseIO (Array α) :=
ch.atomically do
modifyGet fun st => (st.values.toArray, { st with values := })
/-- Type tag for synchronous (blocking) operations on a `Channel`. -/
@[deprecated "Use Std.Channel.Sync from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel.Sync := Channel
/--
Accesses synchronous (blocking) version of channel operations.
For example, `ch.sync.recv?` blocks until the next message,
and `for msg in ch.sync do ...` iterates synchronously over the channel.
These functions should only be used in dedicated threads.
-/
@[deprecated "Use Std.Channel.sync from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel.sync (ch : Channel α) : Channel.Sync α := ch
/--
Synchronously receives a message from the channel.
Every message is only received once.
Returns `none` if the channel is closed and the queue is empty.
-/
@[deprecated "Use Std.Channel.Sync.recv? from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel.Sync.recv? (ch : Channel.Sync α) : BaseIO (Option α) := do
IO.wait ( Channel.recv? ch)
@[deprecated "Use Std.Channel.Sync.forIn from Std.Sync.Channel instead" (since := "2024-12-02")]
private partial def Channel.Sync.forIn [Monad m] [MonadLiftT BaseIO m]
(ch : Channel.Sync α) (f : α β m (ForInStep β)) : β m β := fun b => do
match ch.recv? with
| some a =>
match f a b with
| .done b => pure b
| .yield b => ch.forIn f b
| none => pure b
/-- `for msg in ch.sync do ...` receives all messages in the channel until it is closed. -/
instance [MonadLiftT BaseIO m] : ForIn m (Channel.Sync α) α where
forIn ch b f := ch.forIn f b

View File

@@ -976,16 +976,6 @@ theorem coe_sub_iff_lt {a b : Fin n} : (↑(a - b) : Nat) = n + a - b ↔ a < b
/-! ### mul -/
theorem ofNat'_mul [NeZero n] (x : Nat) (y : Fin n) :
Fin.ofNat' n x * y = Fin.ofNat' n (x * y.val) := by
apply Fin.eq_of_val_eq
simp [Fin.ofNat', Fin.mul_def]
theorem mul_ofNat' [NeZero n] (x : Fin n) (y : Nat) :
x * Fin.ofNat' n y = Fin.ofNat' n (x.val * y) := by
apply Fin.eq_of_val_eq
simp [Fin.ofNat', Fin.mul_def]
theorem val_mul {n : Nat} : a b : Fin n, (a * b).val = a.val * b.val % n
| _, _, _, _ => rfl

View File

@@ -420,8 +420,6 @@ instance : IntCast Int where intCast n := n
protected def Int.cast {R : Type u} [IntCast R] : Int R :=
IntCast.intCast
@[simp] theorem Int.cast_eq (x : Int) : Int.cast x = x := rfl
-- see the notes about coercions into arbitrary types in the module doc-string
instance [IntCast R] : CoeTail Int R where coe := Int.cast

View File

@@ -2145,11 +2145,6 @@ theorem bmod_pos (x : Int) (m : Nat) (p : x % m < (m + 1) / 2) : bmod x m = x %
theorem bmod_neg (x : Int) (m : Nat) (p : x % m (m + 1) / 2) : bmod x m = (x % m) - m := by
simp [bmod_def, Int.not_lt.mpr p]
theorem bmod_eq_emod (x : Int) (m : Nat) : bmod x m = x % m - if x % m (m + 1) / 2 then m else 0 := by
split
· rwa [bmod_neg]
· rw [bmod_pos] <;> simp_all
@[simp]
theorem bmod_one_is_zero (x : Int) : Int.bmod x 1 = 0 := by
simp [Int.bmod]
@@ -2378,43 +2373,6 @@ theorem bmod_neg_bmod : bmod (-(bmod x n)) n = bmod (-x) n := by
apply (bmod_add_cancel_right x).mp
rw [Int.add_left_neg, add_bmod_bmod, Int.add_left_neg]
theorem bmod_neg_iff {m : Nat} {x : Int} (h2 : -m x) (h1 : x < m) :
(x.bmod m) < 0 (-(m / 2) x x < 0) ((m + 1) / 2 x) := by
simp only [Int.bmod_def]
by_cases xpos : 0 x
· rw [Int.emod_eq_of_lt xpos (by omega)]; omega
· rw [Int.add_emod_self.symm, Int.emod_eq_of_lt (by omega) (by omega)]; omega
theorem bmod_eq_self_of_le {n : Int} {m : Nat} (hn' : -(m / 2) n) (hn : n < (m + 1) / 2) :
n.bmod m = n := by
rw [ Int.sub_eq_zero]
have := le_bmod (x := n) (m := m) (by omega)
have := bmod_lt (x := n) (m := m) (by omega)
apply eq_zero_of_dvd_of_natAbs_lt_natAbs Int.dvd_bmod_sub_self
omega
theorem bmod_bmod_of_dvd {a : Int} {n m : Nat} (hnm : n m) :
(a.bmod m).bmod n = a.bmod n := by
rw [ Int.sub_eq_iff_eq_add.2 (bmod_add_bdiv a m).symm]
obtain k, rfl := hnm
simp [Int.mul_assoc]
theorem bmod_eq_self_of_le_mul_two {x : Int} {y : Nat} (hle : -y x * 2) (hlt : x * 2 < y) :
x.bmod y = x := by
apply bmod_eq_self_of_le (by omega) (by omega)
theorem dvd_iff_bmod_eq_zero {a : Nat} {b : Int} : (a : Int) b b.bmod a = 0 := by
rw [dvd_iff_emod_eq_zero, bmod]
split <;> rename_i h
· rfl
· simp only [Int.not_lt] at h
match a with
| 0 => omega
| a + 1 =>
have : b % (a+1) < a + 1 := emod_lt b (by omega)
simp_all
omega
/-! Helper theorems for `dvd` simproc -/
protected theorem dvd_eq_true_of_mod_eq_zero {a b : Int} (h : b % a == 0) : (a b) = True := by

View File

@@ -377,11 +377,6 @@ theorem toNat_of_nonpos : ∀ {z : Int}, z ≤ 0 → z.toNat = 0
@[simp] theorem negSucc_add_one_eq_neg_ofNat_iff {a b : Nat} : -[a+1] + 1 = - (b : Int) a = b := by
rw [eq_comm, neg_ofNat_eq_negSucc_add_one_iff, eq_comm]
protected theorem sub_eq_iff_eq_add {b a c : Int} : a - b = c a = c + b := by
refine fun h => ?_, fun h => ?_ <;> subst h <;> simp
protected theorem sub_eq_iff_eq_add' {b a c : Int} : a - b = c a = b + c := by
rw [Int.sub_eq_iff_eq_add, Int.add_comm]
/- ## add/sub injectivity -/
@[simp] protected theorem add_left_inj {i j : Int} (k : Int) : (i + k = j + k) i = j := by

View File

@@ -19,6 +19,9 @@ namespace Int
@[simp] theorem natCast_le_zero : {n : Nat} (n : Int) 0 n = 0 := by omega
protected theorem sub_eq_iff_eq_add {b a c : Int} : a - b = c a = c + b := by omega
protected theorem sub_eq_iff_eq_add' {b a c : Int} : a - b = c a = b + c := by omega
@[simp] protected theorem neg_nonpos_iff (i : Int) : -i 0 0 i := by omega
@[simp] theorem zero_le_ofNat (n : Nat) : 0 ((no_index (OfNat.ofNat n)) : Int) :=
@@ -78,6 +81,15 @@ theorem eq_ofNat_toNat {a : Int} : a = a.toNat ↔ 0 ≤ a := by omega
theorem toNat_le_toNat {n m : Int} (h : n m) : n.toNat m.toNat := by omega
theorem toNat_lt_toNat {n m : Int} (hn : 0 < m) : n.toNat < m.toNat n < m := by omega
/-! ### natAbs -/
theorem eq_zero_of_dvd_of_natAbs_lt_natAbs {d n : Int} (h : d n) (h₁ : n.natAbs < d.natAbs) :
n = 0 := by
obtain a, rfl := h
rw [natAbs_mul] at h₁
suffices ¬ 0 < a.natAbs by simp [Int.natAbs_eq_zero.1 (Nat.eq_zero_of_not_pos this)]
exact fun h => Nat.lt_irrefl _ (Nat.lt_of_le_of_lt (Nat.le_mul_of_pos_right d.natAbs h) h₁)
/-! ### min and max -/
@[simp] protected theorem min_assoc : (a b c : Int), min (min a b) c = min a (min b c) := by omega
@@ -116,6 +128,33 @@ protected theorem sub_min_sub_left (a b c : Int) : min (a - b) (a - c) = a - max
protected theorem sub_max_sub_left (a b c : Int) : max (a - b) (a - c) = a - min b c := by omega
/-! ### bmod -/
theorem bmod_neg_iff {m : Nat} {x : Int} (h2 : -m x) (h1 : x < m) :
(x.bmod m) < 0 (-(m / 2) x x < 0) ((m + 1) / 2 x) := by
simp only [Int.bmod_def]
by_cases xpos : 0 x
· rw [Int.emod_eq_of_lt xpos (by omega)]; omega
· rw [Int.add_emod_self.symm, Int.emod_eq_of_lt (by omega) (by omega)]; omega
theorem bmod_eq_self_of_le {n : Int} {m : Nat} (hn' : -(m / 2) n) (hn : n < (m + 1) / 2) :
n.bmod m = n := by
rw [ Int.sub_eq_zero]
have := le_bmod (x := n) (m := m) (by omega)
have := bmod_lt (x := n) (m := m) (by omega)
apply eq_zero_of_dvd_of_natAbs_lt_natAbs Int.dvd_bmod_sub_self
omega
theorem bmod_bmod_of_dvd {a : Int} {n m : Nat} (hnm : n m) :
(a.bmod m).bmod n = a.bmod n := by
rw [ Int.sub_eq_iff_eq_add.2 (bmod_add_bdiv a m).symm]
obtain k, rfl := hnm
simp [Int.mul_assoc]
theorem bmod_eq_self_of_le_mul_two {x : Int} {y : Nat} (hle : -y x * 2) (hlt : x * 2 < y) :
x.bmod y = x := by
apply bmod_eq_self_of_le (by omega) (by omega)
theorem mul_le_mul_of_natAbs_le {x y : Int} {s t : Nat} (hx : x.natAbs s) (hy : y.natAbs t) :
x * y s * t := by
by_cases 0 < s 0 < t

View File

@@ -329,9 +329,9 @@ protected theorem le_iff_lt_add_one {a b : Int} : a ≤ b ↔ a < b + 1 := by
/- ### min and max -/
@[grind =] protected theorem min_def (n m : Int) : min n m = if n m then n else m := rfl
protected theorem min_def (n m : Int) : min n m = if n m then n else m := rfl
@[grind =] protected theorem max_def (n m : Int) : max n m = if n m then m else n := rfl
protected theorem max_def (n m : Int) : max n m = if n m then m else n := rfl
@[simp] protected theorem neg_min_neg (a b : Int) : min (-a) (-b) = -max a b := by
rw [Int.min_def, Int.max_def]
@@ -562,16 +562,6 @@ theorem natAbs_sub_of_nonneg_of_le {a b : Int} (h₁ : 0 ≤ b) (h₂ : b ≤ a)
· rwa [ Int.ofNat_le, natAbs_of_nonneg h₁, natAbs_of_nonneg (Int.le_trans h₁ h₂)]
· exact Int.sub_nonneg_of_le h₂
theorem eq_zero_of_dvd_of_natAbs_lt_natAbs {d n : Int} (h : d n) (h₁ : n.natAbs < d.natAbs) :
n = 0 := by
let a, ha := h
subst ha
rw [natAbs_mul] at h₁
suffices ¬ 0 < a.natAbs by simp [Int.natAbs_eq_zero.1 (Nat.eq_zero_of_not_pos this)]
refine fun h => Nat.lt_irrefl _ (Nat.lt_of_le_of_lt ?_ h₁)
rw (occs := [1]) [ Nat.mul_one d.natAbs]
exact Nat.mul_le_mul (Nat.le_refl _) h
/-! ### toNat -/
theorem toNat_eq_max : a : Int, (toNat a : Int) = max a 0
@@ -613,14 +603,6 @@ theorem toNat_mul {a b : Int} (ha : 0 ≤ a) (hb : 0 ≤ b) : (a * b).toNat = a.
match a, b, eq_ofNat_of_zero_le ha, eq_ofNat_of_zero_le hb with
| _, _, _, rfl, _, rfl => rfl
/--
Variant of `Int.toNat_sub` taking non-negativity hypotheses,
rather than expecting the arguments to be casts of natural numbers.
-/
theorem toNat_sub'' {a b : Int} (ha : 0 a) (hb : 0 b) : (a - b).toNat = a.toNat - b.toNat :=
match a, b, eq_ofNat_of_zero_le ha, eq_ofNat_of_zero_le hb with
| _, _, _, rfl, _, rfl => toNat_sub _ _
theorem toNat_add_nat {a : Int} (ha : 0 a) (n : Nat) : (a + n).toNat = a.toNat + n :=
match a, eq_ofNat_of_zero_le ha with | _, _, rfl => rfl

View File

@@ -477,7 +477,7 @@ theorem attach_filterMap {l : List α} {f : α → Option β} :
· simp only [h]
rfl
rw [ih]
simp only [map_filterMap, Option.map_pbind, Option.map_some]
simp only [map_filterMap, Option.map_pbind, Option.map_some']
rfl
· simp only [Option.pbind_eq_some_iff] at h
obtain a, h, w := h

View File

@@ -95,10 +95,10 @@ theorem findSome?_eq_some_iff {f : α → Option β} {l : List α} {b : β} :
| cons x xs ih =>
simp [guard, findSome?, find?]
split <;> rename_i h
· simp only [Option.guard_eq_some_iff] at h
· simp only [Option.guard_eq_some] at h
obtain rfl, h := h
simp [h]
· simp only [Option.guard_eq_none_iff] at h
· simp only [Option.guard_eq_none] at h
simp [ih, h]
theorem find?_eq_findSome?_guard {l : List α} : find? p l = findSome? (Option.guard fun x => p x) l :=
@@ -700,7 +700,6 @@ theorem findIdx?_eq_none_iff {xs : List α} {p : α → Bool} :
simp only [findIdx?_cons]
split <;> simp_all [cond_eq_if]
@[simp]
theorem findIdx?_isSome {xs : List α} {p : α Bool} :
(xs.findIdx? p).isSome = xs.any p := by
induction xs with
@@ -709,7 +708,6 @@ theorem findIdx?_isSome {xs : List α} {p : α → Bool} :
simp only [findIdx?_cons]
split <;> simp_all
@[simp]
theorem findIdx?_isNone {xs : List α} {p : α Bool} :
(xs.findIdx? p).isNone = xs.all (¬p ·) := by
induction xs with
@@ -770,7 +768,7 @@ theorem findIdx?_eq_some_iff_getElem {xs : List α} {p : α → Bool} {i : Nat}
not_and, Classical.not_forall, Bool.not_eq_false]
intros
refine 0, zero_lt_succ i, _
· simp only [Option.map_eq_some_iff, ih, Bool.not_eq_true, length_cons]
· simp only [Option.map_eq_some', ih, Bool.not_eq_true, length_cons]
constructor
· rintro a, h, h₁, h₂, rfl
refine Nat.succ_lt_succ_iff.mpr h, by simpa, fun j hj => ?_
@@ -826,7 +824,7 @@ abbrev findIdx?_of_eq_none := @of_findIdx?_eq_none
(xs ++ ys : List α).findIdx? p =
(xs.findIdx? p).or ((ys.findIdx? p).map fun i => i + xs.length) := by
induction xs with simp
| cons _ _ _ => split <;> simp_all [Option.map_or, Option.map_map]; rfl
| cons _ _ _ => split <;> simp_all [Option.map_or', Option.map_map]; rfl
theorem findIdx?_flatten {l : List (List α)} {p : α Bool} :
l.flatten.findIdx? p =
@@ -986,24 +984,6 @@ theorem findFinIdx?_eq_some_iff {xs : List α} {p : α → Bool} {i : Fin xs.len
· rintro h, w
exact i, i.2, h, fun j hji => w j, by omega hji, rfl
@[simp]
theorem isSome_findFinIdx? {l : List α} {p : α Bool} :
(l.findFinIdx? p).isSome = l.any p := by
induction l with
| nil => simp
| cons x xs ih =>
simp only [findFinIdx?_cons]
split <;> simp_all
@[simp]
theorem isNone_findFinIdx? {l : List α} {p : α Bool} :
(l.findFinIdx? p).isNone = l.all (fun x => ¬ p x) := by
induction l with
| nil => simp
| cons x xs ih =>
simp only [findFinIdx?_cons]
split <;> simp_all
@[simp] theorem findFinIdx?_subtype {p : α Prop} {l : List { x // p x }}
{f : { x // p x } Bool} {g : α Bool} (hf : x h, f x, h = g x) :
l.findFinIdx? f = (l.unattach.findFinIdx? g).map (fun i => i.cast (by simp)) := by
@@ -1104,24 +1084,6 @@ theorem idxOf?_eq_map_finIdxOf?_val [BEq α] {xs : List α} {a : α} :
l.finIdxOf? a = some i l[i] = a j (_ : j < i), ¬l[j] = a := by
simp only [finIdxOf?, findFinIdx?_eq_some_iff, beq_iff_eq]
@[simp]
theorem isSome_finIdxOf? [BEq α] [LawfulBEq α] {l : List α} {a : α} :
(l.finIdxOf? a).isSome a l := by
induction l with
| nil => simp
| cons x xs ih =>
simp only [finIdxOf?_cons]
split <;> simp_all [@eq_comm _ x a]
@[simp]
theorem isNone_finIdxOf? [BEq α] [LawfulBEq α] {l : List α} {a : α} :
(l.finIdxOf? a).isNone = ¬ a l := by
induction l with
| nil => simp
| cons x xs ih =>
simp only [finIdxOf?_cons]
split <;> simp_all [@eq_comm _ x a]
/-! ### idxOf?
The verification API for `idxOf?` is still incomplete.
@@ -1147,25 +1109,6 @@ theorem idxOf?_cons [BEq α] {a : α} {xs : List α} {b : α} :
@[deprecated idxOf?_eq_none_iff (since := "2025-01-29")]
abbrev indexOf?_eq_none_iff := @idxOf?_eq_none_iff
@[simp]
theorem isSome_idxOf? [BEq α] [LawfulBEq α] {l : List α} {a : α} :
(l.idxOf? a).isSome a l := by
induction l with
| nil => simp
| cons x xs ih =>
simp only [idxOf?_cons]
split <;> simp_all [@eq_comm _ x a]
@[simp]
theorem isNone_idxOf? [BEq α] [LawfulBEq α] {l : List α} {a : α} :
(l.idxOf? a).isNone = ¬ a l := by
induction l with
| nil => simp
| cons x xs ih =>
simp only [idxOf?_cons]
split <;> simp_all [@eq_comm _ x a]
/-! ### lookup -/
section lookup

View File

@@ -105,7 +105,7 @@ abbrev length_eq_zero := @length_eq_zero_iff
theorem eq_nil_iff_length_eq_zero : l = [] length l = 0 :=
length_eq_zero_iff.symm
@[grind ] theorem length_pos_of_mem {a : α} : {l : List α}, a l 0 < length l
@[grind] theorem length_pos_of_mem {a : α} : {l : List α}, a l 0 < length l
| _::_, _ => Nat.zero_lt_succ _
theorem exists_mem_of_length_pos : {l : List α}, 0 < length l a, a l
@@ -185,7 +185,7 @@ theorem singleton_inj {α : Type _} {a b : α} : [a] = [b] ↔ a = b := by
We simplify `l.get i` to `l[i.1]'i.2` and `l.get? i` to `l[i]?`.
-/
@[simp, grind =]
@[simp, grind]
theorem get_eq_getElem {l : List α} {i : Fin l.length} : l.get i = l[i.1]'i.2 := rfl
set_option linter.deprecated false in
@@ -225,7 +225,7 @@ theorem get?_eq_getElem? {l : List α} {i : Nat} : l.get? i = l[i]? := by
We simplify `l[i]!` to `(l[i]?).getD default`.
-/
@[simp, grind =]
@[simp, grind]
theorem getElem!_eq_getElem?_getD [Inhabited α] {l : List α} {i : Nat} :
l[i]! = (l[i]?).getD (default : α) := by
simp only [getElem!_def]
@@ -235,16 +235,16 @@ theorem getElem!_eq_getElem?_getD [Inhabited α] {l : List α} {i : Nat} :
/-! ### getElem? and getElem -/
@[simp, grind =] theorem getElem?_nil {i : Nat} : ([] : List α)[i]? = none := rfl
@[simp, grind] theorem getElem?_nil {i : Nat} : ([] : List α)[i]? = none := rfl
theorem getElem_cons {l : List α} (w : i < (a :: l).length) :
(a :: l)[i] =
if h : i = 0 then a else l[i-1]'(match i, h with | i+1, _ => succ_lt_succ_iff.mp w) := by
cases i <;> simp
@[grind =] theorem getElem?_cons_zero {l : List α} : (a::l)[0]? = some a := rfl
@[grind] theorem getElem?_cons_zero {l : List α} : (a::l)[0]? = some a := rfl
@[simp, grind =] theorem getElem?_cons_succ {l : List α} : (a::l)[i+1]? = l[i]? := rfl
@[simp, grind] theorem getElem?_cons_succ {l : List α} : (a::l)[i+1]? = l[i]? := rfl
theorem getElem?_cons : (a :: l)[i]? = if i = 0 then some a else l[i-1]? := by
cases i <;> simp [getElem?_cons_zero]
@@ -337,7 +337,7 @@ We simplify away `getD`, replacing `getD l n a` with `(l[n]?).getD a`.
Because of this, there is only minimal API for `getD`.
-/
@[simp, grind =]
@[simp, grind]
theorem getD_eq_getElem?_getD {l : List α} {i : Nat} {a : α} : getD l i a = (l[i]?).getD a := by
simp [getD]

View File

@@ -339,7 +339,7 @@ theorem getElem?_mapIdx_go : ∀ {l : List α} {acc : Array β} {i : Nat},
if h : i < acc.size then some acc[i] else Option.map (f i) l[i - acc.size]?
| [], acc, i => by
simp only [mapIdx.go, Array.toListImpl_eq, getElem?_def, Array.length_toList,
Array.getElem_toList, length_nil, Nat.not_lt_zero, reduceDIte, Option.map_none]
Array.getElem_toList, length_nil, Nat.not_lt_zero, reduceDIte, Option.map_none']
| a :: l, acc, i => by
rw [mapIdx.go, getElem?_mapIdx_go]
simp only [Array.size_push]

View File

@@ -31,33 +31,6 @@ theorem count_set [BEq α] {a b : α} {l : List α} {i : Nat} (h : i < l.length)
(l.set i a).count b = l.count b - (if l[i] == b then 1 else 0) + (if a == b then 1 else 0) := by
simp [count_eq_countP, countP_set, h]
theorem countP_replace [BEq α] [LawfulBEq α] {a b : α} {l : List α} {p : α Bool} :
(l.replace a b).countP p =
if l.contains a then l.countP p + (if p b then 1 else 0) - (if p a then 1 else 0) else l.countP p := by
induction l with
| nil => simp
| cons x l ih =>
simp [replace_cons]
split <;> rename_i h
· simp at h
simp [h, ih, countP_cons]
omega
· simp only [beq_eq_false_iff_ne, ne_eq] at h
simp only [countP_cons, ih, contains_eq_mem, decide_eq_true_eq, mem_cons, h, false_or]
split <;> rename_i h'
· by_cases h'' : p a
· have : countP p l > 0 := countP_pos_iff.mpr a, h', h''
simp [h'']
omega
· simp [h'']
omega
· omega
theorem count_replace [BEq α] [LawfulBEq α] {a b c : α} {l : List α} :
(l.replace a b).count c =
if l.contains a then l.count c + (if b == c then 1 else 0) - (if a == c then 1 else 0) else l.count c := by
simp [count_eq_countP, countP_replace]
/--
The number of elements satisfying a predicate in a sublist is at least the number of elements satisfying the predicate in the list,
minus the difference in the lengths.

View File

@@ -54,23 +54,4 @@ theorem set_set_perm {as : List α} {i j : Nat} (h₁ : i < as.length) (h₂ : j
subst t
apply set_set_perm' _ _ (by omega)
namespace Perm
/-- Variant of `List.Perm.take` specifying the the permutation is constant after `i` elementwise. -/
theorem take_of_getElem? {l₁ l₂ : List α} (h : l₁ ~ l₂) {i : Nat} (w : j, i j l₁[j]? = l₂[j]?) :
l₁.take i ~ l₂.take i := by
refine h.take (Perm.of_eq ?_)
ext1 j
simpa using w (i + j) (by omega)
/-- Variant of `List.Perm.drop` specifying the the permutation is constant before `i` elementwise. -/
theorem drop_of_getElem? {l₁ l₂ : List α} (h : l₁ ~ l₂) {i : Nat} (w : j, j < i l₁[j]? = l₂[j]?) :
l₁.drop i ~ l₂.drop i := by
refine h.drop (Perm.of_eq ?_)
ext1
simp only [getElem?_take]
split <;> simp_all
end Perm
end List

View File

@@ -536,22 +536,4 @@ theorem perm_insertIdx {α} (x : α) (l : List α) {i} (h : i ≤ l.length) :
simp only [insertIdx, modifyTailIdx]
refine .trans (.cons _ (ih (Nat.le_of_succ_le_succ h))) (.swap ..)
namespace Perm
theorem take {l₁ l₂ : List α} (h : l₁ ~ l₂) {n : Nat} (w : l₁.drop n ~ l₂.drop n) :
l₁.take n ~ l₂.take n := by
classical
rw [perm_iff_count] at h w
rw [ take_append_drop n l₁, take_append_drop n l₂] at h
simpa only [count_append, w, Nat.add_right_cancel_iff] using h
theorem drop {l₁ l₂ : List α} (h : l₁ ~ l₂) {n : Nat} (w : l₁.take n ~ l₂.take n) :
l₁.drop n ~ l₂.drop n := by
classical
rw [perm_iff_count] at h w
rw [ take_append_drop n l₁, take_append_drop n l₂] at h
simpa only [count_append, w, Nat.add_left_cancel_iff] using h
end Perm
end List

View File

@@ -6,7 +6,6 @@ Authors: Floris van Doorn, Leonardo de Moura
prelude
import Init.SimpLemmas
import Init.Data.NeZero
import Init.Grind.Tactics
set_option linter.missingDocs true -- keep it documented
universe u
@@ -868,7 +867,7 @@ Examples:
-/
protected abbrev min (n m : Nat) := min n m
@[grind =] protected theorem min_def {n m : Nat} : min n m = if n m then n else m := rfl
protected theorem min_def {n m : Nat} : min n m = if n m then n else m := rfl
instance : Max Nat := maxOfLe
@@ -885,7 +884,7 @@ Examples:
-/
protected abbrev max (n m : Nat) := max n m
@[grind =] protected theorem max_def {n m : Nat} : max n m = if n m then m else n := rfl
protected theorem max_def {n m : Nat} : max n m = if n m then m else n := rfl
/-! # Auxiliary theorems for well-founded recursion -/

View File

@@ -141,11 +141,11 @@ theorem toList_attach (o : Option α) :
cases o <;> simp
theorem attach_map {o : Option α} (f : α β) :
(o.map f).attach = o.attach.map (fun x, h => f x, map_eq_some_iff.2 _, h, rfl) := by
(o.map f).attach = o.attach.map (fun x, h => f x, map_eq_some.2 _, h, rfl) := by
cases o <;> simp
theorem attachWith_map {o : Option α} (f : α β) {P : β Prop} {H : (b : β), o.map f = some b P b} :
(o.map f).attachWith P H = (o.attachWith (P f) (fun _ h => H _ (map_eq_some_iff.2 _, h, rfl))).map
(o.map f).attachWith P H = (o.attachWith (P f) (fun _ h => H _ (map_eq_some.2 _, h, rfl))).map
fun x, h => f x, h := by
cases o <;> simp
@@ -174,7 +174,7 @@ theorem map_attach_eq_attachWith {o : Option α} {p : α → Prop} (f : ∀ a, o
theorem attach_bind {o : Option α} {f : α Option β} :
(o.bind f).attach =
o.attach.bind fun x, h => (f x).attach.map fun y, h' => y, bind_eq_some_iff.2 _, h, h' := by
o.attach.bind fun x, h => (f x).attach.map fun y, h' => y, bind_eq_some.2 _, h, h' := by
cases o <;> simp
theorem bind_attach {o : Option α} {f : {x // o = some x} Option β} :

View File

@@ -231,12 +231,13 @@ def merge (fn : ααα) : Option α → Option α → Option α
@[simp] theorem getD_none : getD none a = a := rfl
@[simp] theorem getD_some : getD (some a) b = a := rfl
@[simp] theorem map_none (f : α β) : none.map f = none := rfl
@[simp] theorem map_some (a) (f : α β) : (some a).map f = some (f a) := rfl
@[simp] theorem map_none' (f : α β) : none.map f = none := rfl
@[simp] theorem map_some' (a) (f : α β) : (some a).map f = some (f a) := rfl
@[simp] theorem none_bind (f : α Option β) : none.bind f = none := rfl
@[simp] theorem some_bind (a) (f : α Option β) : (some a).bind f = f a := rfl
/--
A case analysis function for `Option`.

View File

@@ -39,24 +39,18 @@ This is not an instance because it is not definitionally equal to the standard i
Try to use the Boolean comparisons `Option.isNone` or `Option.isSome` instead.
-/
@[inline] def decidableEqNone {o : Option α} : Decidable (o = none) :=
@[inline] def decidable_eq_none {o : Option α} : Decidable (o = none) :=
decidable_of_decidable_of_iff isNone_iff_eq_none
@[deprecated decidableEqNone (since := "2025-04-10"), inline]
def decidable_eq_none {o : Option α} : Decidable (o = none) :=
decidableEqNone
instance {p : α Prop} [DecidablePred p] : o : Option α, Decidable ( a, a o p a)
| none => isTrue nofun
| some a =>
if h : p a then isTrue fun _ e => some_inj.1 e h
else isFalse <| mt (· _ rfl) h
instance decidableForallMem {p : α Prop} [DecidablePred p] :
o : Option α, Decidable ( a, a o p a)
| none => isTrue nofun
| some a =>
if h : p a then isTrue fun _ e => some_inj.1 e h
else isFalse <| mt (· _ rfl) h
instance decidableExistsMem {p : α Prop} [DecidablePred p] :
o : Option α, Decidable (Exists fun a => a o p a)
| none => isFalse nofun
| some a => if h : p a then isTrue _, rfl, h else isFalse fun _, rfl, hn => h hn
instance {p : α Prop} [DecidablePred p] : o : Option α, Decidable (Exists fun a => a o p a)
| none => isFalse nofun
| some a => if h : p a then isTrue _, rfl, h else isFalse fun _, rfl, hn => h hn
/--
Given an optional value and a function that can be applied when the value is `some`, returns the

View File

@@ -17,8 +17,6 @@ theorem mem_iff {a : α} {b : Option α} : a ∈ b ↔ b = some a := .rfl
theorem mem_some {a b : α} : a some b b = a := by simp
theorem mem_some_iff {a b : α} : a some b b = a := mem_some
theorem mem_some_self (a : α) : a some a := rfl
theorem some_ne_none (x : α) : some x none := nofun
@@ -31,9 +29,6 @@ protected theorem «exists» {p : Option α → Prop} :
fun | none, hx => .inl hx | some x, hx => .inr x, hx,
fun | .inl h => _, h | .inr _, hx => _, hx
theorem eq_none_or_eq_some (a : Option α) : a = none x, a = some x :=
Option.exists.mp exists_eq'
theorem get_mem : {o : Option α} (h : isSome o), o.get h o
| some _, _ => rfl
@@ -93,9 +88,6 @@ set_option Elab.async false
theorem eq_none_iff_forall_ne_some : o = none a, o some a := by
cases o <;> simp
theorem eq_none_iff_forall_some_ne : o = none a, some a o := by
cases o <;> simp
theorem eq_none_iff_forall_not_mem : o = none a, a o :=
eq_none_iff_forall_ne_some
@@ -110,30 +102,9 @@ theorem isSome_of_mem {x : Option α} {y : α} (h : y ∈ x) : x.isSome := by
theorem isSome_of_eq_some {x : Option α} {y : α} (h : x = some y) : x.isSome := by
cases x <;> trivial
@[simp] theorem isSome_eq_false_iff : isSome a = false a.isNone = true := by
@[simp] theorem not_isSome : isSome a = false a.isNone = true := by
cases a <;> simp
@[simp] theorem isNone_eq_false_iff : isNone a = false a.isSome = true := by
cases a <;> simp
@[simp]
theorem not_isSome (a : Option α) : (!a.isSome) = a.isNone := by
cases a <;> simp
@[simp]
theorem not_comp_isSome : (! ·) @Option.isSome α = Option.isNone := by
funext
simp
@[simp]
theorem not_isNone (a : Option α) : (!a.isNone) = a.isSome := by
cases a <;> simp
@[simp]
theorem not_comp_isNone : (!·) @Option.isNone α = Option.isSome := by
funext x
simp
theorem eq_some_iff_get_eq : o = some a h : o.isSome, o.get h = a := by
cases o <;> simp
@@ -175,25 +146,17 @@ abbrev ball_ne_none := @forall_ne_none
@[simp] theorem bind_eq_bind : bind = @Option.bind α β := rfl
@[simp] theorem orElse_eq_orElse : HOrElse.hOrElse = @Option.orElse α := rfl
@[simp] theorem bind_some (x : Option α) : x.bind some = x := by cases x <;> rfl
@[simp] theorem bind_none (x : Option α) : x.bind (fun _ => none (α := β)) = none := by
cases x <;> rfl
theorem bind_eq_some_iff : x.bind f = some b a, x = some a f a = some b := by
theorem bind_eq_some : x.bind f = some b a, x = some a f a = some b := by
cases x <;> simp
@[deprecated bind_eq_some_iff (since := "2025-04-10")]
abbrev bind_eq_some := @bind_eq_some_iff
@[simp] theorem bind_eq_none_iff {o : Option α} {f : α Option β} :
@[simp] theorem bind_eq_none {o : Option α} {f : α Option β} :
o.bind f = none a, o = some a f a = none := by cases o <;> simp
@[deprecated bind_eq_none_iff (since := "2025-04-10")]
abbrev bind_eq_none := @bind_eq_none_iff
theorem bind_eq_none' {o : Option α} {f : α Option β} :
o.bind f = none b a, o = some a f a some b := by
cases o <;> simp [eq_none_iff_forall_ne_some]
@@ -230,67 +193,50 @@ theorem isSome_apply_of_isSome_bind {α β : Type _} {x : Option α} {f : α
(isSome_apply_of_isSome_bind h) := by
cases x <;> trivial
theorem join_eq_some_iff : x.join = some a x = some (some a) := by
simp [bind_eq_some_iff]
@[deprecated join_eq_some_iff (since := "2025-04-10")]
abbrev join_eq_some := @join_eq_some_iff
theorem join_eq_some : x.join = some a x = some (some a) := by
simp [bind_eq_some]
theorem join_ne_none : x.join none z, x = some (some z) := by
simp only [ne_none_iff_exists', join_eq_some_iff, iff_self]
simp only [ne_none_iff_exists', join_eq_some, iff_self]
theorem join_ne_none' : ¬x.join = none z, x = some (some z) :=
join_ne_none
theorem join_eq_none_iff : o.join = none o = none o = some none :=
theorem join_eq_none : o.join = none o = none o = some none :=
match o with | none | some none | some (some _) => by simp
@[deprecated join_eq_none_iff (since := "2025-04-10")]
abbrev join_eq_none := @join_eq_none_iff
theorem bind_id_eq_join {x : Option (Option α)} : x.bind id = x.join := rfl
@[simp] theorem map_eq_map : Functor.map f = Option.map f := rfl
@[deprecated map_none (since := "2025-04-10")]
abbrev map_none' := @map_none
theorem map_none : f <$> none = none := rfl
@[deprecated map_some (since := "2025-04-10")]
abbrev map_some' := @map_some
theorem map_some : f <$> some a = some (f a) := rfl
@[simp] theorem map_eq_some_iff : x.map f = some b a, x = some a f a = b := by
@[simp] theorem map_eq_some' : x.map f = some b a, x = some a f a = b := by cases x <;> simp
theorem map_eq_some : f <$> x = some b a, x = some a f a = b := map_eq_some'
@[simp] theorem map_eq_none' : x.map f = none x = none := by
cases x <;> simp [map_none', map_some', eq_self_iff_true]
theorem isSome_map {x : Option α} : (f <$> x).isSome = x.isSome := by
cases x <;> simp
@[deprecated map_eq_some_iff (since := "2025-04-10")]
abbrev map_eq_some := @map_eq_some_iff
@[deprecated map_eq_some_iff (since := "2025-04-10")]
abbrev map_eq_some' := @map_eq_some_iff
@[simp] theorem map_eq_none_iff : x.map f = none x = none := by
cases x <;> simp [map_none, map_some, eq_self_iff_true]
@[deprecated map_eq_none_iff (since := "2025-04-10")]
abbrev map_eq_none := @map_eq_none_iff
@[deprecated map_eq_none_iff (since := "2025-04-10")]
abbrev map_eq_none' := @map_eq_none_iff
@[simp] theorem isSome_map {x : Option α} : (x.map f).isSome = x.isSome := by
@[simp] theorem isSome_map' {x : Option α} : (x.map f).isSome = x.isSome := by
cases x <;> simp
@[deprecated isSome_map (since := "2025-04-10")]
abbrev isSome_map' := @isSome_map
@[simp] theorem isNone_map {x : Option α} : (x.map f).isNone = x.isNone := by
@[simp] theorem isNone_map' {x : Option α} : (x.map f).isNone = x.isNone := by
cases x <;> simp
theorem map_eq_none : f <$> x = none x = none := map_eq_none'
theorem map_eq_bind {x : Option α} : x.map f = x.bind (some f) := by
cases x <;> simp [Option.bind]
theorem map_congr {x : Option α} (h : a, x = some a f a = g a) :
x.map f = x.map g := by
cases x <;> simp only [map_none, map_some, h]
cases x <;> simp only [map_none', map_some', h]
@[simp] theorem map_id_fun {α : Type u} : Option.map (id : α α) = id := by
funext; simp [map_id]
@@ -308,7 +254,7 @@ theorem get_map {f : α → β} {o : Option α} {h : (o.map f).isSome} :
@[simp] theorem map_map (h : β γ) (g : α β) (x : Option α) :
(x.map g).map h = x.map (h g) := by
cases x <;> simp only [map_none, map_some, ··]
cases x <;> simp only [map_none', map_some', ··]
theorem comp_map (h : β γ) (g : α β) (x : Option α) : x.map (h g) = (x.map g).map h :=
(map_map ..).symm
@@ -316,7 +262,7 @@ theorem comp_map (h : β → γ) (g : α → β) (x : Option α) : x.map (h ∘
@[simp] theorem map_comp_map (f : α β) (g : β γ) :
Option.map g Option.map f = Option.map (g f) := by funext x; simp
theorem mem_map_of_mem (g : α β) (h : a x) : g a Option.map g x := h.symm map_some ..
theorem mem_map_of_mem (g : α β) (h : a x) : g a Option.map g x := h.symm map_some' ..
theorem map_inj_right {f : α β} {o o' : Option α} (w : x y, f x = f y x = y) :
o.map f = o'.map f o = o' := by
@@ -346,14 +292,11 @@ theorem isSome_of_isSome_filter (p : α → Bool) (o : Option α) (h : (o.filter
@[deprecated isSome_of_isSome_filter (since := "2025-03-18")]
abbrev isSome_filter_of_isSome := @isSome_of_isSome_filter
@[simp] theorem filter_eq_none_iff {o : Option α} {p : α Bool} :
@[simp] theorem filter_eq_none {o : Option α} {p : α Bool} :
o.filter p = none a, o = some a ¬ p a := by
cases o <;> simp [filter_some]
@[deprecated filter_eq_none_iff (since := "2025-04-10")]
abbrev filter_eq_none := @filter_eq_none_iff
@[simp] theorem filter_eq_some_iff {o : Option α} {p : α Bool} :
@[simp] theorem filter_eq_some {o : Option α} {p : α Bool} :
o.filter p = some a o = some a p a := by
cases o with
| none => simp
@@ -367,9 +310,6 @@ abbrev filter_eq_none := @filter_eq_none_iff
rintro rfl
simpa using h
@[deprecated filter_eq_some_iff (since := "2025-04-10")]
abbrev filter_eq_some := @filter_eq_some_iff
theorem mem_filter_iff {p : α Bool} {a : α} {o : Option α} :
a o.filter p a o p a := by
simp
@@ -443,43 +383,29 @@ theorem join_join {x : Option (Option (Option α))} : x.join.join = (x.map join)
cases x <;> simp
theorem mem_of_mem_join {a : α} {x : Option (Option α)} (h : a x.join) : some a x :=
h.symm join_eq_some_iff.1 h
h.symm join_eq_some.1 h
@[simp] theorem some_orElse (a : α) (f) : (some a).orElse f = some a := rfl
@[simp] theorem some_orElse (a : α) (x : Option α) : (some a <|> x) = some a := rfl
@[simp] theorem none_orElse (f : Unit Option α) : none.orElse f = f () := rfl
@[simp] theorem none_orElse (x : Option α) : (none <|> x) = x := rfl
@[simp] theorem orElse_none (x : Option α) : x.orElse (fun _ => none) = x := by cases x <;> rfl
@[simp] theorem orElse_none (x : Option α) : (x <|> none) = x := by cases x <;> rfl
theorem orElse_eq_some_iff (o : Option α) (f) (x : α) :
(o.orElse f) = some x o = some x o = none f () = some x := by
cases o <;> simp
theorem orElse_eq_none_iff (o : Option α) (f) : (o.orElse f) = none o = none f () = none := by
cases o <;> simp
theorem map_orElse {x : Option α} {y} :
(x.orElse y).map f = (x.map f).orElse (fun _ => (y ()).map f) := by
theorem map_orElse {x y : Option α} : (x <|> y).map f = (x.map f <|> y.map f) := by
cases x <;> simp
@[simp] theorem guard_eq_some_iff [DecidablePred p] : guard p a = some b a = b p a :=
@[simp] theorem guard_eq_some [DecidablePred p] : guard p a = some b a = b p a :=
if h : p a then by simp [Option.guard, h] else by simp [Option.guard, h]
@[deprecated guard_eq_some_iff (since := "2025-04-10")]
abbrev guard_eq_some := @guard_eq_some_iff
@[simp] theorem isSome_guard [DecidablePred p] : (Option.guard p a).isSome p a :=
if h : p a then by simp [Option.guard, h] else by simp [Option.guard, h]
@[deprecated isSome_guard (since := "2025-03-18")]
abbrev guard_isSome := @isSome_guard
@[simp] theorem guard_eq_none_iff [DecidablePred p] : Option.guard p a = none ¬ p a :=
@[simp] theorem guard_eq_none [DecidablePred p] : Option.guard p a = none ¬ p a :=
if h : p a then by simp [Option.guard, h] else by simp [Option.guard, h]
@[deprecated guard_eq_none_iff (since := "2025-04-10")]
abbrev guard_eq_none := @guard_eq_none_iff
@[simp] theorem guard_pos [DecidablePred p] (h : p a) : Option.guard p a = some a := by
simp [Option.guard, h]
@@ -549,22 +475,6 @@ theorem liftOrGet_none_right {f} {a : Option α} : merge f a none = a :=
theorem liftOrGet_some_some {f} {a b : α} : merge f (some a) (some b) = f a b :=
merge_some_some
instance commutative_merge (f : α α α) [Std.Commutative f] :
Std.Commutative (merge f) :=
fun a b by cases a <;> cases b <;> simp [merge, Std.Commutative.comm]
instance associative_merge (f : α α α) [Std.Associative f] :
Std.Associative (merge f) :=
fun a b c by cases a <;> cases b <;> cases c <;> simp [merge, Std.Associative.assoc]
instance idempotentOp_merge (f : α α α) [Std.IdempotentOp f] :
Std.IdempotentOp (merge f) :=
fun a by cases a <;> simp [merge, Std.IdempotentOp.idempotent]
instance lawfulIdentity_merge (f : α α α) : Std.LawfulIdentity (merge f) none where
left_id a := by cases a <;> simp [merge]
right_id a := by cases a <;> simp [merge]
@[simp] theorem elim_none (x : β) (f : α β) : none.elim x f = x := rfl
@[simp] theorem elim_some (x : β) (f : α β) (a : α) : (some a).elim x f = f a := rfl
@@ -625,18 +535,12 @@ theorem or_eq_bif : or o o' = bif o.isSome then o else o' := by
@[simp] theorem isNone_or : (or o o').isNone = (o.isNone && o'.isNone) := by
cases o <;> rfl
@[simp] theorem or_eq_none_iff : or o o' = none o = none o' = none := by
@[simp] theorem or_eq_none : or o o' = none o = none o' = none := by
cases o <;> simp
@[deprecated or_eq_none_iff (since := "2025-04-10")]
abbrev or_eq_none := @or_eq_none_iff
@[simp] theorem or_eq_some_iff : or o o' = some a o = some a (o = none o' = some a) := by
@[simp] theorem or_eq_some : or o o' = some a o = some a (o = none o' = some a) := by
cases o <;> simp
@[deprecated or_eq_some_iff (since := "2025-04-10")]
abbrev or_eq_some := @or_eq_some_iff
theorem or_assoc : or (or o₁ o₂) o₃ = or o₁ (or o₂ o₃) := by
cases o₁ <;> cases o₂ <;> rfl
instance : Std.Associative (or (α := α)) := @or_assoc _
@@ -660,11 +564,11 @@ instance : Std.IdempotentOp (or (α := α)) := ⟨@or_self _⟩
theorem or_eq_orElse : or o o' = o.orElse (fun _ => o') := by
cases o <;> rfl
theorem map_or : (or o o').map f = (o.map f).or (o'.map f) := by
theorem map_or : f <$> or o o' = (f <$> o).or (f <$> o') := by
cases o <;> rfl
@[deprecated map_or (since := "2025-04-10")]
abbrev map_or' := @map_or
theorem map_or' : (or o o').map f = (o.map f).or (o'.map f) := by
cases o <;> rfl
theorem or_of_isSome {o o' : Option α} (h : o.isSome) : o.or o' = o := by
match o, h with
@@ -900,7 +804,7 @@ theorem map_pmap {p : α → Prop} (g : β → γ) (f : ∀ a, p a → β) (o H)
theorem pmap_map (o : Option α) (f : α β) {p : β Prop} (g : b, p b γ) (H) :
pmap g (o.map f) H =
pmap (fun a h => g (f a) h) o (fun a m => H (f a) (map_eq_some_iff.2 _, m, rfl)) := by
pmap (fun a h => g (f a) h) o (fun a m => H (f a) (map_eq_some.2 _, m, rfl)) := by
cases o <;> simp
theorem pmap_pred_congr {α : Type u}

View File

@@ -239,17 +239,6 @@ Examples:
@[extern "lean_int8_div"]
protected def Int8.div (a b : Int8) : Int8 := BitVec.sdiv a.toBitVec b.toBitVec
/--
The power operation, raising an 8-bit signed integer to a natural number power,
wrapping around on overflow. Usually accessed via the `^` operator.
This function is currently *not* overridden at runtime with an efficient implementation,
and should be used with caution. See https://github.com/leanprover/lean4/issues/7887.
-/
protected def Int8.pow (x : Int8) (n : Nat) : Int8 :=
match n with
| 0 => 1
| n + 1 => Int8.mul (Int8.pow x n) x
/--
The modulo operator for 8-bit signed integers, which computes the remainder when dividing one
integer by another with the T-rounding convention used by `Int8.div`. Usually accessed via the `%`
operator.
@@ -377,7 +366,6 @@ instance : Inhabited Int8 where
instance : Add Int8 := Int8.add
instance : Sub Int8 := Int8.sub
instance : Mul Int8 := Int8.mul
instance : Pow Int8 Nat := Int8.pow
instance : Mod Int8 := Int8.mod
instance : Div Int8 := Int8.div
instance : LT Int8 := Int8.lt
@@ -610,17 +598,6 @@ Examples:
@[extern "lean_int16_div"]
protected def Int16.div (a b : Int16) : Int16 := BitVec.sdiv a.toBitVec b.toBitVec
/--
The power operation, raising a 16-bit signed integer to a natural number power,
wrapping around on overflow. Usually accessed via the `^` operator.
This function is currently *not* overridden at runtime with an efficient implementation,
and should be used with caution. See https://github.com/leanprover/lean4/issues/7887.
-/
protected def Int16.pow (x : Int16) (n : Nat) : Int16 :=
match n with
| 0 => 1
| n + 1 => Int16.mul (Int16.pow x n) x
/--
The modulo operator for 16-bit signed integers, which computes the remainder when dividing one
integer by another with the T-rounding convention used by `Int16.div`. Usually accessed via the `%`
operator.
@@ -748,7 +725,6 @@ instance : Inhabited Int16 where
instance : Add Int16 := Int16.add
instance : Sub Int16 := Int16.sub
instance : Mul Int16 := Int16.mul
instance : Pow Int16 Nat := Int16.pow
instance : Mod Int16 := Int16.mod
instance : Div Int16 := Int16.div
instance : LT Int16 := Int16.lt
@@ -997,17 +973,6 @@ Examples:
@[extern "lean_int32_div"]
protected def Int32.div (a b : Int32) : Int32 := BitVec.sdiv a.toBitVec b.toBitVec
/--
The power operation, raising a 32-bit signed integer to a natural number power,
wrapping around on overflow. Usually accessed via the `^` operator.
This function is currently *not* overridden at runtime with an efficient implementation,
and should be used with caution. See https://github.com/leanprover/lean4/issues/7887.
-/
protected def Int32.pow (x : Int32) (n : Nat) : Int32 :=
match n with
| 0 => 1
| n + 1 => Int32.mul (Int32.pow x n) x
/--
The modulo operator for 32-bit signed integers, which computes the remainder when dividing one
integer by another with the T-rounding convention used by `Int32.div`. Usually accessed via the `%`
operator.
@@ -1135,7 +1100,6 @@ instance : Inhabited Int32 where
instance : Add Int32 := Int32.add
instance : Sub Int32 := Int32.sub
instance : Mul Int32 := Int32.mul
instance : Pow Int32 Nat := Int32.pow
instance : Mod Int32 := Int32.mod
instance : Div Int32 := Int32.div
instance : LT Int32 := Int32.lt
@@ -1404,17 +1368,6 @@ Examples:
@[extern "lean_int64_div"]
protected def Int64.div (a b : Int64) : Int64 := BitVec.sdiv a.toBitVec b.toBitVec
/--
The power operation, raising a 64-bit signed integer to a natural number power,
wrapping around on overflow. Usually accessed via the `^` operator.
This function is currently *not* overridden at runtime with an efficient implementation,
and should be used with caution. See https://github.com/leanprover/lean4/issues/7887.
-/
protected def Int64.pow (x : Int64) (n : Nat) : Int64 :=
match n with
| 0 => 1
| n + 1 => Int64.mul (Int64.pow x n) x
/--
The modulo operator for 64-bit signed integers, which computes the remainder when dividing one
integer by another with the T-rounding convention used by `Int64.div`. Usually accessed via the `%`
operator.
@@ -1542,7 +1495,6 @@ instance : Inhabited Int64 where
instance : Add Int64 := Int64.add
instance : Sub Int64 := Int64.sub
instance : Mul Int64 := Int64.mul
instance : Pow Int64 Nat := Int64.pow
instance : Mod Int64 := Int64.mod
instance : Div Int64 := Int64.div
instance : LT Int64 := Int64.lt
@@ -1794,17 +1746,6 @@ Examples:
@[extern "lean_isize_div"]
protected def ISize.div (a b : ISize) : ISize := BitVec.sdiv a.toBitVec b.toBitVec
/--
The power operation, raising a word-sized signed integer to a natural number power,
wrapping around on overflow. Usually accessed via the `^` operator.
This function is currently *not* overridden at runtime with an efficient implementation,
and should be used with caution. See https://github.com/leanprover/lean4/issues/7887.
-/
protected def ISize.pow (x : ISize) (n : Nat) : ISize :=
match n with
| 0 => 1
| n + 1 => ISize.mul (ISize.pow x n) x
/--
The modulo operator for word-sized signed integers, which computes the remainder when dividing one
integer by another with the T-rounding convention used by `ISize.div`. Usually accessed via the `%`
operator.
@@ -1934,7 +1875,6 @@ instance : Inhabited ISize where
instance : Add ISize := ISize.add
instance : Sub ISize := ISize.sub
instance : Mul ISize := ISize.mul
instance : Pow ISize Nat := ISize.pow
instance : Mod ISize := ISize.mod
instance : Div ISize := ISize.div
instance : LT ISize := ISize.lt

View File

@@ -2625,17 +2625,6 @@ instance : Std.LawfulCommIdentity (α := ISize) (· * ·) 1 where
@[simp] theorem Int64.zero_mul {a : Int64} : 0 * a = 0 := Int64.toBitVec_inj.1 BitVec.zero_mul
@[simp] theorem ISize.zero_mul {a : ISize} : 0 * a = 0 := ISize.toBitVec_inj.1 BitVec.zero_mul
@[simp] protected theorem Int8.pow_zero (x : Int8) : x ^ 0 = 1 := rfl
protected theorem Int8.pow_succ (x : Int8) (n : Nat) : x ^ (n + 1) = x ^ n * x := rfl
@[simp] protected theorem Int16.pow_zero (x : Int16) : x ^ 0 = 1 := rfl
protected theorem Int16.pow_succ (x : Int16) (n : Nat) : x ^ (n + 1) = x ^ n * x := rfl
@[simp] protected theorem Int32.pow_zero (x : Int32) : x ^ 0 = 1 := rfl
protected theorem Int32.pow_succ (x : Int32) (n : Nat) : x ^ (n + 1) = x ^ n * x := rfl
@[simp] protected theorem Int64.pow_zero (x : Int64) : x ^ 0 = 1 := rfl
protected theorem Int64.pow_succ (x : Int64) (n : Nat) : x ^ (n + 1) = x ^ n * x := rfl
@[simp] protected theorem ISize.pow_zero (x : ISize) : x ^ 0 = 1 := rfl
protected theorem ISize.pow_succ (x : ISize) (n : Nat) : x ^ (n + 1) = x ^ n * x := rfl
protected theorem Int8.mul_add {a b c : Int8} : a * (b + c) = a * b + a * c :=
Int8.toBitVec_inj.1 BitVec.mul_add
protected theorem Int16.mul_add {a b c : Int16} : a * (b + c) = a * b + a * c :=

View File

@@ -58,17 +58,6 @@ This function is overridden at runtime with an efficient implementation.
@[extern "lean_uint8_div"]
protected def UInt8.div (a b : UInt8) : UInt8 := BitVec.udiv a.toBitVec b.toBitVec
/--
The power operation, raising an 8-bit unsigned integer to a natural number power,
wrapping around on overflow. Usually accessed via the `^` operator.
This function is currently *not* overridden at runtime with an efficient implementation,
and should be used with caution. See https://github.com/leanprover/lean4/issues/7887.
-/
protected def UInt8.pow (x : UInt8) (n : Nat) : UInt8 :=
match n with
| 0 => 1
| n + 1 => UInt8.mul (UInt8.pow x n) x
/--
The modulo operator for 8-bit unsigned integers, which computes the remainder when dividing one
integer by another. Usually accessed via the `%` operator.
@@ -143,7 +132,6 @@ protected def UInt8.le (a b : UInt8) : Prop := a.toBitVec ≤ b.toBitVec
instance : Add UInt8 := UInt8.add
instance : Sub UInt8 := UInt8.sub
instance : Mul UInt8 := UInt8.mul
instance : Pow UInt8 Nat := UInt8.pow
instance : Mod UInt8 := UInt8.mod
set_option linter.deprecated false in
@@ -270,17 +258,6 @@ This function is overridden at runtime with an efficient implementation.
@[extern "lean_uint16_div"]
protected def UInt16.div (a b : UInt16) : UInt16 := BitVec.udiv a.toBitVec b.toBitVec
/--
The power operation, raising a 16-bit unsigned integer to a natural number power,
wrapping around on overflow. Usually accessed via the `^` operator.
This function is currently *not* overridden at runtime with an efficient implementation,
and should be used with caution. See https://github.com/leanprover/lean4/issues/7887.
-/
protected def UInt16.pow (x : UInt16) (n : Nat) : UInt16 :=
match n with
| 0 => 1
| n + 1 => UInt16.mul (UInt16.pow x n) x
/--
The modulo operator for 16-bit unsigned integers, which computes the remainder when dividing one
integer by another. Usually accessed via the `%` operator.
@@ -318,7 +295,7 @@ This function is overridden at runtime with an efficient implementation.
@[extern "lean_uint16_lor"]
protected def UInt16.lor (a b : UInt16) : UInt16 := a.toBitVec ||| b.toBitVec
/--
Bitwise exclusive or for 16-bit unsigned integers. Usually accessed via the `^^^` operator.
Bitwise exclusive or for 8-bit unsigned integers. Usually accessed via the `^^^` operator.
Each bit of the resulting integer is set if exactly one of the corresponding bits of both input
integers are set.
@@ -355,7 +332,6 @@ protected def UInt16.le (a b : UInt16) : Prop := a.toBitVec ≤ b.toBitVec
instance : Add UInt16 := UInt16.add
instance : Sub UInt16 := UInt16.sub
instance : Mul UInt16 := UInt16.mul
instance : Pow UInt16 Nat := UInt16.pow
instance : Mod UInt16 := UInt16.mod
set_option linter.deprecated false in
@@ -484,17 +460,6 @@ This function is overridden at runtime with an efficient implementation.
@[extern "lean_uint32_div"]
protected def UInt32.div (a b : UInt32) : UInt32 := BitVec.udiv a.toBitVec b.toBitVec
/--
The power operation, raising a 32-bit unsigned integer to a natural number power,
wrapping around on overflow. Usually accessed via the `^` operator.
This function is currently *not* overridden at runtime with an efficient implementation,
and should be used with caution. See https://github.com/leanprover/lean4/issues/7887.
-/
protected def UInt32.pow (x : UInt32) (n : Nat) : UInt32 :=
match n with
| 0 => 1
| n + 1 => UInt32.mul (UInt32.pow x n) x
/--
The modulo operator for 32-bit unsigned integers, which computes the remainder when dividing one
integer by another. Usually accessed via the `%` operator.
@@ -569,7 +534,6 @@ protected def UInt32.le (a b : UInt32) : Prop := a.toBitVec ≤ b.toBitVec
instance : Add UInt32 := UInt32.add
instance : Sub UInt32 := UInt32.sub
instance : Mul UInt32 := UInt32.mul
instance : Pow UInt32 Nat := UInt32.pow
instance : Mod UInt32 := UInt32.mod
set_option linter.deprecated false in
@@ -660,17 +624,6 @@ This function is overridden at runtime with an efficient implementation.
@[extern "lean_uint64_div"]
protected def UInt64.div (a b : UInt64) : UInt64 := BitVec.udiv a.toBitVec b.toBitVec
/--
The power operation, raising a 64-bit unsigned integer to a natural number power,
wrapping around on overflow. Usually accessed via the `^` operator.
This function is currently *not* overridden at runtime with an efficient implementation,
and should be used with caution. See https://github.com/leanprover/lean4/issues/7887.
-/
protected def UInt64.pow (x : UInt64) (n : Nat) : UInt64 :=
match n with
| 0 => 1
| n + 1 => UInt64.mul (UInt64.pow x n) x
/--
The modulo operator for 64-bit unsigned integers, which computes the remainder when dividing one
integer by another. Usually accessed via the `%` operator.
@@ -745,7 +698,6 @@ protected def UInt64.le (a b : UInt64) : Prop := a.toBitVec ≤ b.toBitVec
instance : Add UInt64 := UInt64.add
instance : Sub UInt64 := UInt64.sub
instance : Mul UInt64 := UInt64.mul
instance : Pow UInt64 Nat := UInt64.pow
instance : Mod UInt64 := UInt64.mod
set_option linter.deprecated false in
@@ -766,7 +718,7 @@ This function is overridden at runtime with an efficient implementation.
@[extern "lean_uint64_complement"]
protected def UInt64.complement (a : UInt64) : UInt64 := ~~~a.toBitVec
/--
Negation of 64-bit unsigned integers, computed modulo `UInt64.size`.
Negation of 32-bit unsigned integers, computed modulo `UInt64.size`.
`UInt64.neg a` is equivalent to `18_446_744_073_709_551_615 - a + 1`.
@@ -867,17 +819,6 @@ This function is overridden at runtime with an efficient implementation.
@[extern "lean_usize_div"]
protected def USize.div (a b : USize) : USize := a.toBitVec / b.toBitVec
/--
The power operation, raising a word-sized unsigned integer to a natural number power,
wrapping around on overflow. Usually accessed via the `^` operator.
This function is currently *not* overridden at runtime with an efficient implementation,
and should be used with caution. See https://github.com/leanprover/lean4/issues/7887.
-/
protected def USize.pow (x : USize) (n : Nat) : USize :=
match n with
| 0 => 1
| n + 1 => USize.mul (USize.pow x n) x
/--
The modulo operator for word-sized unsigned integers, which computes the remainder when dividing one
integer by another. Usually accessed via the `%` operator.
@@ -1011,7 +952,6 @@ def USize.toUInt64 (a : USize) : UInt64 :=
UInt64.ofNatLT a.toBitVec.toNat (Nat.lt_of_lt_of_le a.toBitVec.isLt USize.size_le)
instance : Mul USize := USize.mul
instance : Pow USize Nat := USize.pow
instance : Mod USize := USize.mod
set_option linter.deprecated false in

View File

@@ -2767,17 +2767,6 @@ instance : Std.LawfulCommIdentity (α := USize) (· * ·) 1 where
@[simp] theorem UInt64.zero_mul {a : UInt64} : 0 * a = 0 := UInt64.toBitVec_inj.1 BitVec.zero_mul
@[simp] theorem USize.zero_mul {a : USize} : 0 * a = 0 := USize.toBitVec_inj.1 BitVec.zero_mul
@[simp] protected theorem UInt8.pow_zero (x : UInt8) : x ^ 0 = 1 := rfl
protected theorem UInt8.pow_succ (x : UInt8) (n : Nat) : x ^ (n + 1) = x ^ n * x := rfl
@[simp] protected theorem UInt16.pow_zero (x : UInt16) : x ^ 0 = 1 := rfl
protected theorem UInt16.pow_succ (x : UInt16) (n : Nat) : x ^ (n + 1) = x ^ n * x := rfl
@[simp] protected theorem UInt32.pow_zero (x : UInt32) : x ^ 0 = 1 := rfl
protected theorem UInt32.pow_succ (x : UInt32) (n : Nat) : x ^ (n + 1) = x ^ n * x := rfl
@[simp] protected theorem UInt64.pow_zero (x : UInt64) : x ^ 0 = 1 := rfl
protected theorem UInt64.pow_succ (x : UInt64) (n : Nat) : x ^ (n + 1) = x ^ n * x := rfl
@[simp] protected theorem USize.pow_zero (x : USize) : x ^ 0 = 1 := rfl
protected theorem USize.pow_succ (x : USize) (n : Nat) : x ^ (n + 1) = x ^ n * x := rfl
protected theorem UInt8.mul_add {a b c : UInt8} : a * (b + c) = a * b + a * c :=
UInt8.toBitVec_inj.1 BitVec.mul_add
protected theorem UInt16.mul_add {a b c : UInt16} : a * (b + c) = a * b + a * c :=

View File

@@ -239,15 +239,4 @@ theorem count_flatMap {α} [BEq β] {xs : Vector α n} {f : α → Vector β m}
rcases xs with xs, rfl
simp [Array.count_flatMap, Function.comp_def]
theorem countP_replace {a b : α} {xs : Vector α n} {p : α Bool} :
(xs.replace a b).countP p =
if xs.contains a then xs.countP p + (if p b then 1 else 0) - (if p a then 1 else 0) else xs.countP p := by
rcases xs with xs, rfl
simp [Array.countP_replace]
theorem count_replace {a b c : α} {xs : Vector α n} :
(xs.replace a b).count c =
if xs.contains a then xs.count c + (if b == c then 1 else 0) - (if a == c then 1 else 0) else xs.count c := by
simp [count_eq_countP, countP_replace]
end count

View File

@@ -294,18 +294,6 @@ theorem find?_eq_some_iff_getElem {xs : Vector α n} {p : α → Bool} {b : α}
subst w
simp
@[simp]
theorem isSome_findFinIdx? {xs : Vector α n} {p : α Bool} :
(xs.findFinIdx? p).isSome = xs.any p := by
rcases xs with xs, rfl
simp
@[simp]
theorem isNone_findFinIdx? {xs : Vector α n} {p : α Bool} :
(xs.findFinIdx? p).isNone = xs.all (fun x => ¬ p x) := by
rcases xs with xs, rfl
simp
@[simp] theorem findFinIdx?_subtype {p : α Prop} {xs : Vector { x // p x } n}
{f : { x // p x } Bool} {g : α Bool} (hf : x h, f x, h = g x) :
xs.findFinIdx? f = xs.unattach.findFinIdx? g := by

View File

@@ -1511,7 +1511,7 @@ theorem map_eq_iff {f : α → β} {as : Vector α n} {bs : Vector β n} :
if h : i < as.size then
simpa [h, h'] using w i h
else
rw [getElem?_neg, getElem?_neg, Option.map_none] <;> omega
rw [getElem?_neg, getElem?_neg, Option.map_none'] <;> omega
@[simp] theorem map_set {f : α β} {xs : Vector α n} {i : Nat} {h : i < n} {a : α} :
(xs.set i a).map f = (xs.map f).set i (f a) (by simpa using h) := by

View File

@@ -5,30 +5,14 @@ Authors: Kim Morrison
-/
prelude
import Init.Data.Zero
import Init.Data.Int.DivMod.Lemmas
import Init.TacticsExtra
/-!
# A monolithic commutative ring typeclass for internal use in `grind`.
The `Lean.Grind.CommRing` class will be used to convert expressions into the internal representation via polynomials,
with coefficients expressed via `OfNat` and `Neg`.
The `IsCharP α p` typeclass expresses that the ring has characteristic `p`,
i.e. that a coefficient `OfNat.ofNat x : α` is zero if and only if `x % p = 0` (in `Nat`).
See
```
theorem ofNat_ext_iff {x y : Nat} : OfNat.ofNat (α := α) x = OfNat.ofNat (α := α) y ↔ x % p = y % p
theorem ofNat_emod (x : Nat) : OfNat.ofNat (α := α) (x % p) = OfNat.ofNat x
theorem ofNat_eq_iff_of_lt {x y : Nat} (h₁ : x < p) (h₂ : y < p) :
OfNat.ofNat (α := α) x = OfNat.ofNat (α := α) y ↔ x = y
```
-/
namespace Lean.Grind
class CommRing (α : Type u) extends Add α, Mul α, Neg α, Sub α, HPow α Nat α where
[ofNat : n, OfNat α n]
class CommRing (α : Type u) extends Add α, Zero α, Mul α, One α, Neg α where
add_assoc : a b c : α, a + b + c = a + (b + c)
add_comm : a b : α, a + b = b + a
add_zero : a : α, a + 0 = a
@@ -38,31 +22,11 @@ class CommRing (α : Type u) extends Add α, Mul α, Neg α, Sub α, HPow α Nat
mul_one : a : α, a * 1 = a
left_distrib : a b c : α, a * (b + c) = a * b + a * c
zero_mul : a : α, 0 * a = 0
sub_eq_add_neg : a b : α, a - b = a + -b
pow_zero : a : α, a ^ 0 = 1
pow_succ : a : α, n : Nat, a ^ (n + 1) = (a ^ n) * a
ofNat_succ : a : Nat, OfNat.ofNat (α := α) (a + 1) = OfNat.ofNat a + 1 := by intros; rfl
-- This is a low-priority instance, to avoid conflicts with existing `OfNat` instances.
attribute [instance 100] CommRing.ofNat
namespace CommRing
variable {α : Type u} [CommRing α]
instance : NatCast α where
natCast n := OfNat.ofNat n
theorem natCast_zero : ((0 : Nat) : α) = 0 := rfl
theorem ofNat_eq_natCast (n : Nat) : OfNat.ofNat n = (n : α) := rfl
theorem ofNat_add (a b : Nat) : OfNat.ofNat (α := α) (a + b) = OfNat.ofNat a + OfNat.ofNat b := by
induction b with
| zero => simp [Nat.add_zero, add_zero]
| succ b ih => rw [Nat.add_succ, ofNat_succ, ih, ofNat_succ b, add_assoc]
theorem natCast_succ (n : Nat) : ((n + 1 : Nat) : α) = ((n : α) + 1) := ofNat_add _ _
theorem zero_add (a : α) : 0 + a = a := by
rw [add_comm, add_zero]
@@ -78,204 +42,6 @@ theorem right_distrib (a b c : α) : (a + b) * c = a * c + b * c := by
theorem mul_zero (a : α) : a * 0 = 0 := by
rw [mul_comm, zero_mul]
theorem ofNat_mul (a b : Nat) : OfNat.ofNat (α := α) (a * b) = OfNat.ofNat a * OfNat.ofNat b := by
induction b with
| zero => simp [Nat.mul_zero, mul_zero]
| succ a ih => rw [Nat.mul_succ, ofNat_add, ih, ofNat_add, left_distrib, mul_one]
theorem add_left_inj {a b : α} (c : α) : a + c = b + c a = b :=
fun h => by simpa [add_assoc, add_neg_cancel, add_zero] using (congrArg (· + -c) h),
fun g => congrArg (· + c) g
theorem add_right_inj (a b c : α) : a + b = a + c b = c := by
rw [add_comm a b, add_comm a c, add_left_inj]
theorem neg_zero : (-0 : α) = 0 := by
rw [ add_left_inj 0, neg_add_cancel, add_zero]
theorem neg_neg (a : α) : -(-a) = a := by
rw [ add_left_inj (-a), neg_add_cancel, add_neg_cancel]
theorem neg_eq_zero (a : α) : -a = 0 a = 0 :=
fun h => by
replace h := congrArg (-·) h
simpa [neg_neg, neg_zero] using h,
fun h => by rw [h, neg_zero]
theorem neg_add (a b : α) : -(a + b) = -a + -b := by
rw [ add_left_inj (a + b), neg_add_cancel, add_assoc (-a), add_comm a b, add_assoc (-b),
neg_add_cancel, zero_add, neg_add_cancel]
theorem neg_sub (a b : α) : -(a - b) = b - a := by
rw [sub_eq_add_neg, neg_add, neg_neg, sub_eq_add_neg, add_comm]
theorem sub_self (a : α) : a - a = 0 := by
rw [sub_eq_add_neg, add_neg_cancel]
instance : IntCast α where
intCast n := match n with
| Int.ofNat n => OfNat.ofNat n
| Int.negSucc n => -OfNat.ofNat (n + 1)
theorem intCast_zero : ((0 : Int) : α) = 0 := rfl
theorem intCast_one : ((1 : Int) : α) = 1 := rfl
theorem intCast_neg_one : ((-1 : Int) : α) = -1 := rfl
theorem intCast_ofNat (n : Nat) : ((n : Int) : α) = (n : α) := rfl
theorem intCast_ofNat_add_one (n : Nat) : ((n + 1 : Int) : α) = (n : α) + 1 := ofNat_add _ _
theorem intCast_negSucc (n : Nat) : ((-(n + 1) : Int) : α) = -((n : α) + 1) := congrArg (- ·) (ofNat_add _ _)
theorem intCast_neg (x : Int) : ((-x : Int) : α) = - (x : α) :=
match x with
| (0 : Nat) => neg_zero.symm
| (n + 1 : Nat) => by
rw [Int.natCast_add, Int.cast_ofNat_Int, intCast_negSucc, intCast_ofNat_add_one]
| -((n : Nat) + 1) => by
rw [Int.neg_neg, intCast_ofNat_add_one, intCast_negSucc, neg_neg]
theorem intCast_nat_add {x y : Nat} : ((x + y : Int) : α) = ((x : α) + (y : α)) := ofNat_add _ _
theorem intCast_nat_sub {x y : Nat} (h : x y) : (((x - y : Nat) : Int) : α) = ((x : α) - (y : α)) := by
induction x with
| zero =>
have : y = 0 := by omega
simp [this, intCast_zero, natCast_zero, sub_eq_add_neg, zero_add, neg_zero]
| succ x ih =>
by_cases h : x + 1 = y
· simp [h, intCast_zero, sub_self]
· have : ((x + 1 - y : Nat) : Int) = (x - y : Nat) + 1 := by omega
rw [this, intCast_ofNat_add_one]
specialize ih (by omega)
rw [intCast_ofNat] at ih
rw [ih, natCast_succ, sub_eq_add_neg, sub_eq_add_neg, add_assoc, add_comm _ 1, add_assoc]
theorem intCast_add (x y : Int) : ((x + y : Int) : α) = ((x : α) + (y : α)) :=
match x, y with
| (x : Nat), (y : Nat) => ofNat_add _ _
| (x : Nat), (-(y + 1 : Nat)) => by
by_cases h : x y + 1
· have : (x + -(y+1 : Nat) : Int) = ((x - (y + 1) : Nat) : Int) := by omega
rw [this, intCast_neg, intCast_nat_sub h, intCast_ofNat, intCast_ofNat, sub_eq_add_neg]
· have : (x + -(y+1 : Nat) : Int) = (-(y + 1 - x : Nat) : Int) := by omega
rw [this, intCast_neg, intCast_nat_sub (by omega), intCast_ofNat, intCast_neg, intCast_ofNat,
neg_sub, sub_eq_add_neg]
| (-(x + 1 : Nat)), (y : Nat) => by
by_cases h : y x+ 1
· have : (-(x+1 : Nat) + y : Int) = ((y - (x + 1) : Nat) : Int) := by omega
rw [this, intCast_neg, intCast_nat_sub h, intCast_ofNat, intCast_ofNat, sub_eq_add_neg, add_comm]
· have : (-(x+1 : Nat) + y : Int) = (-(x + 1 - y : Nat) : Int) := by omega
rw [this, intCast_neg, intCast_nat_sub (by omega), intCast_ofNat, intCast_neg, intCast_ofNat,
neg_sub, sub_eq_add_neg, add_comm]
| (-(x + 1 : Nat)), (-(y + 1 : Nat)) => by
rw [ Int.neg_add, intCast_neg, intCast_nat_add, neg_add, intCast_neg, intCast_neg, intCast_ofNat, intCast_ofNat]
theorem intCast_sub (x y : Int) : ((x - y : Int) : α) = ((x : α) - (y : α)) := by
rw [Int.sub_eq_add_neg, intCast_add, intCast_neg, sub_eq_add_neg]
theorem neg_eq_neg_one_mul (a : α) : -a = (-1) * a := by
rw [ add_left_inj a, neg_add_cancel]
conv => rhs; arg 2; rw [ one_mul a]
rw [ right_distrib, intCast_neg_one, intCast_one (α := α)]
simp [ intCast_add, intCast_zero, zero_mul]
theorem neg_mul (a b : α) : (-a) * b = -(a * b) := by
rw [neg_eq_neg_one_mul a, neg_eq_neg_one_mul (a * b), mul_assoc]
theorem mul_neg (a b : α) : a * (-b) = -(a * b) := by
rw [mul_comm, neg_mul, mul_comm]
theorem intCast_nat_mul (x y : Nat) : ((x * y : Int) : α) = ((x : α) * (y : α)) := ofNat_mul _ _
theorem intCast_mul (x y : Int) : ((x * y : Int) : α) = ((x : α) * (y : α)) :=
match x, y with
| (x : Nat), (y : Nat) => ofNat_mul _ _
| (x : Nat), (-(y + 1 : Nat)) => by
rw [Int.mul_neg, intCast_neg, intCast_nat_mul, intCast_neg, mul_neg, intCast_ofNat, intCast_ofNat]
| (-(x + 1 : Nat)), (y : Nat) => by
rw [Int.neg_mul, intCast_neg, intCast_nat_mul, intCast_neg, neg_mul, intCast_ofNat, intCast_ofNat]
| (-(x + 1 : Nat)), (-(y + 1 : Nat)) => by
rw [Int.neg_mul_neg, intCast_neg, intCast_neg, neg_mul, mul_neg, neg_neg, intCast_nat_mul,
intCast_ofNat, intCast_ofNat]
end CommRing
open CommRing
class IsCharP (α : Type u) [CommRing α] (p : Nat) where
ofNat_eq_zero_iff (p) : (x : Nat), OfNat.ofNat (α := α) x = 0 x % p = 0
namespace IsCharP
variable (p) {α : Type u} [CommRing α] [IsCharP α p]
theorem natCast_eq_zero_iff (x : Nat) : (x : α) = 0 x % p = 0 :=
ofNat_eq_zero_iff p x
theorem intCast_eq_zero_iff (x : Int) : (x : α) = 0 x % p = 0 :=
match x with
| (x : Nat) => by
have := ofNat_eq_zero_iff (α := α) p (x := x)
rw [Int.ofNat_mod_ofNat]
norm_cast
| -(x + 1 : Nat) => by
rw [Int.neg_emod, Int.ofNat_mod_ofNat, intCast_neg, intCast_ofNat, neg_eq_zero]
have := ofNat_eq_zero_iff (α := α) p (x := x + 1)
rw [ofNat_eq_natCast] at this
rw [this]
simp only [Int.ofNat_dvd]
simp only [ Nat.dvd_iff_mod_eq_zero, Int.natAbs_ofNat, Int.natCast_add,
Int.cast_ofNat_Int, ite_eq_left_iff]
by_cases h : p x + 1
· simp [h]
· simp only [h, not_false_eq_true, Int.natCast_add, Int.cast_ofNat_Int,
forall_const, false_iff, ne_eq]
by_cases w : p = 0
· simp [w]
omega
· have : ((x + 1) % p) < p := Nat.mod_lt _ (by omega)
omega
theorem intCast_ext_iff {x y : Int} : (x : α) = (y : α) x % p = y % p := by
constructor
· intro h
replace h : ((x - y : Int) : α) = 0 := by rw [intCast_sub, h, sub_self]
exact Int.emod_eq_emod_iff_emod_sub_eq_zero.mpr ((intCast_eq_zero_iff p _).mp h)
· intro h
have : ((x - y : Int) : α) = 0 :=
(intCast_eq_zero_iff p _).mpr (by rw [Int.sub_emod, h, Int.sub_self, Int.zero_emod])
replace this := congrArg (· + (y : α)) this
simpa [intCast_sub, zero_add, sub_eq_add_neg, add_assoc, neg_add_cancel, add_zero] using this
theorem ofNat_ext_iff {x y : Nat} : OfNat.ofNat (α := α) x = OfNat.ofNat (α := α) y x % p = y % p := by
have := intCast_ext_iff (α := α) p (x := x) (y := y)
simp only [intCast_ofNat, Int.ofNat_emod] at this
simp only [ofNat_eq_natCast]
norm_cast at this
theorem ofNat_ext {x y : Nat} (h : x % p = y % p) : OfNat.ofNat (α := α) x = OfNat.ofNat (α := α) y := (ofNat_ext_iff p).mpr h
theorem natCast_ext {x y : Nat} (h : x % p = y % p) : (x : α) = (y : α) := ofNat_ext _ h
theorem natCast_ext_iff {x y : Nat} : (x : α) = (y : α) x % p = y % p :=
ofNat_ext_iff p
theorem intCast_emod (x : Int) : ((x % p : Int) : α) = (x : α) := by
rw [intCast_ext_iff p, Int.emod_emod]
theorem natCast_emod (x : Nat) : ((x % p : Nat) : α) = (x : α) := by
simp only [ intCast_ofNat]
rw [Int.ofNat_emod, intCast_emod]
theorem ofNat_emod (x : Nat) : OfNat.ofNat (α := α) (x % p) = OfNat.ofNat x :=
natCast_emod _ _
theorem ofNat_eq_zero_iff_of_lt {x : Nat} (h : x < p) : OfNat.ofNat (α := α) x = 0 x = 0 := by
rw [ofNat_eq_zero_iff p, Nat.mod_eq_of_lt h]
theorem ofNat_eq_iff_of_lt {x y : Nat} (h₁ : x < p) (h₂ : y < p) :
OfNat.ofNat (α := α) x = OfNat.ofNat (α := α) y x = y := by
rw [ofNat_ext_iff p, Nat.mod_eq_of_lt h₁, Nat.mod_eq_of_lt h₂]
theorem natCast_eq_zero_iff_of_lt {x : Nat} (h : x < p) : (x : α) = 0 x = 0 := by
rw [natCast_eq_zero_iff p, Nat.mod_eq_of_lt h]
theorem natCast_eq_iff_of_lt {x y : Nat} (h₁ : x < p) (h₂ : y < p) :
(x : α) = (y : α) x = y := by
rw [natCast_ext_iff p, Nat.mod_eq_of_lt h₁, Nat.mod_eq_of_lt h₂]
end IsCharP
end Lean.Grind

View File

@@ -19,12 +19,5 @@ instance : CommRing (BitVec w) where
mul_one := BitVec.mul_one
left_distrib _ _ _ := BitVec.mul_add
zero_mul _ := BitVec.zero_mul
sub_eq_add_neg := BitVec.sub_eq_add_neg
pow_zero _ := BitVec.pow_zero
pow_succ _ _ := BitVec.pow_succ
ofNat_succ x := BitVec.ofNat_add x 1
instance : IsCharP (BitVec w) (2 ^ w) where
ofNat_eq_zero_iff {x} := by simp [BitVec.ofInt, BitVec.toNat_eq]
end Lean.Grind

View File

@@ -19,12 +19,5 @@ instance : CommRing Int where
mul_one := Int.mul_one
left_distrib := Int.mul_add
zero_mul := Int.zero_mul
pow_zero _ := rfl
pow_succ _ _ := rfl
ofNat_succ _ := rfl
sub_eq_add_neg _ _ := Int.sub_eq_add_neg
instance : IsCharP Int 0 where
ofNat_eq_zero_iff {x} := by erw [Int.ofNat_eq_zero]; simp
end Lean.Grind

View File

@@ -9,9 +9,6 @@ import Init.Data.SInt.Lemmas
namespace Lean.Grind
instance : IntCast Int8 where
intCast x := Int8.ofInt x
instance : CommRing Int8 where
add_assoc := Int8.add_assoc
add_comm := Int8.add_comm
@@ -22,20 +19,6 @@ instance : CommRing Int8 where
mul_one := Int8.mul_one
left_distrib _ _ _ := Int8.mul_add
zero_mul _ := Int8.zero_mul
sub_eq_add_neg := Int8.sub_eq_add_neg
pow_zero := Int8.pow_zero
pow_succ := Int8.pow_succ
ofNat_succ x := Int8.ofNat_add x 1
instance : IsCharP Int8 (2 ^ 8) where
ofNat_eq_zero_iff {x} := by
have : OfNat.ofNat x = Int8.ofInt x := rfl
rw [this]
simp [Int8.ofInt_eq_iff_bmod_eq_toInt,
Int.dvd_iff_bmod_eq_zero, Nat.dvd_iff_mod_eq_zero, Int.ofNat_dvd_right]
instance : IntCast Int16 where
intCast x := Int16.ofInt x
instance : CommRing Int16 where
add_assoc := Int16.add_assoc
@@ -47,20 +30,6 @@ instance : CommRing Int16 where
mul_one := Int16.mul_one
left_distrib _ _ _ := Int16.mul_add
zero_mul _ := Int16.zero_mul
sub_eq_add_neg := Int16.sub_eq_add_neg
pow_zero := Int16.pow_zero
pow_succ := Int16.pow_succ
ofNat_succ x := Int16.ofNat_add x 1
instance : IsCharP Int16 (2 ^ 16) where
ofNat_eq_zero_iff {x} := by
have : OfNat.ofNat x = Int16.ofInt x := rfl
rw [this]
simp [Int16.ofInt_eq_iff_bmod_eq_toInt,
Int.dvd_iff_bmod_eq_zero, Nat.dvd_iff_mod_eq_zero, Int.ofNat_dvd_right]
instance : IntCast Int32 where
intCast x := Int32.ofInt x
instance : CommRing Int32 where
add_assoc := Int32.add_assoc
@@ -72,20 +41,6 @@ instance : CommRing Int32 where
mul_one := Int32.mul_one
left_distrib _ _ _ := Int32.mul_add
zero_mul _ := Int32.zero_mul
sub_eq_add_neg := Int32.sub_eq_add_neg
pow_zero := Int32.pow_zero
pow_succ := Int32.pow_succ
ofNat_succ x := Int32.ofNat_add x 1
instance : IsCharP Int32 (2 ^ 32) where
ofNat_eq_zero_iff {x} := by
have : OfNat.ofNat x = Int32.ofInt x := rfl
rw [this]
simp [Int32.ofInt_eq_iff_bmod_eq_toInt,
Int.dvd_iff_bmod_eq_zero, Nat.dvd_iff_mod_eq_zero, Int.ofNat_dvd_right]
instance : IntCast Int64 where
intCast x := Int64.ofInt x
instance : CommRing Int64 where
add_assoc := Int64.add_assoc
@@ -97,20 +52,6 @@ instance : CommRing Int64 where
mul_one := Int64.mul_one
left_distrib _ _ _ := Int64.mul_add
zero_mul _ := Int64.zero_mul
sub_eq_add_neg := Int64.sub_eq_add_neg
pow_zero := Int64.pow_zero
pow_succ := Int64.pow_succ
ofNat_succ x := Int64.ofNat_add x 1
instance : IsCharP Int64 (2 ^ 64) where
ofNat_eq_zero_iff {x} := by
have : OfNat.ofNat x = Int64.ofInt x := rfl
rw [this]
simp [Int64.ofInt_eq_iff_bmod_eq_toInt,
Int.dvd_iff_bmod_eq_zero, Nat.dvd_iff_mod_eq_zero, Int.ofNat_dvd_right]
instance : IntCast ISize where
intCast x := ISize.ofInt x
instance : CommRing ISize where
add_assoc := ISize.add_assoc
@@ -122,18 +63,5 @@ instance : CommRing ISize where
mul_one := ISize.mul_one
left_distrib _ _ _ := ISize.mul_add
zero_mul _ := ISize.zero_mul
sub_eq_add_neg := ISize.sub_eq_add_neg
pow_zero := ISize.pow_zero
pow_succ := ISize.pow_succ
ofNat_succ x := ISize.ofNat_add x 1
open System.Platform (numBits)
instance : IsCharP ISize (2 ^ numBits) where
ofNat_eq_zero_iff {x} := by
have : OfNat.ofNat x = ISize.ofInt x := rfl
rw [this]
simp [ISize.ofInt_eq_iff_bmod_eq_toInt,
Int.dvd_iff_bmod_eq_zero, Nat.dvd_iff_mod_eq_zero, Int.ofNat_dvd_right]
end Lean.Grind

View File

@@ -7,53 +7,6 @@ prelude
import Init.Grind.CommRing.Basic
import Init.Data.UInt.Lemmas
namespace UInt8
/-- Variant of `UInt8.ofNat_mod_size` replacing `2 ^ 8` with `256`.-/
theorem ofNat_mod_size' : ofNat (x % 256) = ofNat x := ofNat_mod_size
instance : IntCast UInt8 where
intCast x := UInt8.ofInt x
end UInt8
namespace UInt16
/-- Variant of `UInt16.ofNat_mod_size` replacing `2 ^ 16` with `65536`.-/
theorem ofNat_mod_size' : ofNat (x % 65536) = ofNat x := ofNat_mod_size
instance : IntCast UInt16 where
intCast x := UInt16.ofInt x
end UInt16
namespace UInt32
/-- Variant of `UInt32.ofNat_mod_size` replacing `2 ^ 32` with `4294967296`.-/
theorem ofNat_mod_size' : ofNat (x % 4294967296) = ofNat x := ofNat_mod_size
instance : IntCast UInt32 where
intCast x := UInt32.ofInt x
end UInt32
namespace UInt64
/-- Variant of `UInt64.ofNat_mod_size` replacing `2 ^ 64` with `18446744073709551616`.-/
theorem ofNat_mod_size' : ofNat (x % 18446744073709551616) = ofNat x := ofNat_mod_size
instance : IntCast UInt64 where
intCast x := UInt64.ofInt x
end UInt64
namespace USize
instance : IntCast USize where
intCast x := USize.ofInt x
end USize
namespace Lean.Grind
instance : CommRing UInt8 where
@@ -66,15 +19,6 @@ instance : CommRing UInt8 where
mul_one := UInt8.mul_one
left_distrib _ _ _ := UInt8.mul_add
zero_mul _ := UInt8.zero_mul
sub_eq_add_neg := UInt8.sub_eq_add_neg
pow_zero := UInt8.pow_zero
pow_succ := UInt8.pow_succ
ofNat_succ x := UInt8.ofNat_add x 1
instance : IsCharP UInt8 (2 ^ 8) where
ofNat_eq_zero_iff {x} := by
have : OfNat.ofNat x = UInt8.ofNat x := rfl
simp [this, UInt8.ofNat_eq_iff_mod_eq_toNat]
instance : CommRing UInt16 where
add_assoc := UInt16.add_assoc
@@ -86,15 +30,6 @@ instance : CommRing UInt16 where
mul_one := UInt16.mul_one
left_distrib _ _ _ := UInt16.mul_add
zero_mul _ := UInt16.zero_mul
sub_eq_add_neg := UInt16.sub_eq_add_neg
pow_zero := UInt16.pow_zero
pow_succ := UInt16.pow_succ
ofNat_succ x := UInt16.ofNat_add x 1
instance : IsCharP UInt16 (2 ^ 16) where
ofNat_eq_zero_iff {x} := by
have : OfNat.ofNat x = UInt16.ofNat x := rfl
simp [this, UInt16.ofNat_eq_iff_mod_eq_toNat]
instance : CommRing UInt32 where
add_assoc := UInt32.add_assoc
@@ -106,15 +41,6 @@ instance : CommRing UInt32 where
mul_one := UInt32.mul_one
left_distrib _ _ _ := UInt32.mul_add
zero_mul _ := UInt32.zero_mul
sub_eq_add_neg := UInt32.sub_eq_add_neg
pow_zero := UInt32.pow_zero
pow_succ := UInt32.pow_succ
ofNat_succ x := UInt32.ofNat_add x 1
instance : IsCharP UInt32 (2 ^ 32) where
ofNat_eq_zero_iff {x} := by
have : OfNat.ofNat x = UInt32.ofNat x := rfl
simp [this, UInt32.ofNat_eq_iff_mod_eq_toNat]
instance : CommRing UInt64 where
add_assoc := UInt64.add_assoc
@@ -126,15 +52,6 @@ instance : CommRing UInt64 where
mul_one := UInt64.mul_one
left_distrib _ _ _ := UInt64.mul_add
zero_mul _ := UInt64.zero_mul
sub_eq_add_neg := UInt64.sub_eq_add_neg
pow_zero := UInt64.pow_zero
pow_succ := UInt64.pow_succ
ofNat_succ x := UInt64.ofNat_add x 1
instance : IsCharP UInt64 (2 ^ 64) where
ofNat_eq_zero_iff {x} := by
have : OfNat.ofNat x = UInt64.ofNat x := rfl
simp [this, UInt64.ofNat_eq_iff_mod_eq_toNat]
instance : CommRing USize where
add_assoc := USize.add_assoc
@@ -146,16 +63,5 @@ instance : CommRing USize where
mul_one := USize.mul_one
left_distrib _ _ _ := USize.mul_add
zero_mul _ := USize.zero_mul
sub_eq_add_neg := USize.sub_eq_add_neg
pow_zero := USize.pow_zero
pow_succ := USize.pow_succ
ofNat_succ x := USize.ofNat_add x 1
open System.Platform
instance : IsCharP USize (2 ^ numBits) where
ofNat_eq_zero_iff {x} := by
have : OfNat.ofNat x = USize.ofNat x := rfl
simp [this, USize.ofNat_eq_iff_mod_eq_toNat]
end Lean.Grind

View File

@@ -74,11 +74,11 @@ theorem bne_eq_decide_not_eq {_ : BEq α} [LawfulBEq α] [DecidableEq α] (a b :
theorem xor_eq (a b : Bool) : (a ^^ b) = (a != b) := by
rfl
theorem natCast_eq [NatCast α] (a : Nat) : (Nat.cast a : α) = (NatCast.natCast a : α) := rfl
theorem natCast_div (a b : Nat) : (NatCast.natCast (a / b) : Int) = (NatCast.natCast a) / (NatCast.natCast b) := rfl
theorem natCast_mod (a b : Nat) : (NatCast.natCast (a % b) : Int) = (NatCast.natCast a) % (NatCast.natCast b) := rfl
theorem natCast_add (a b : Nat) : (NatCast.natCast (a + b : Nat) : Int) = (NatCast.natCast a : Int) + (NatCast.natCast b : Int) := rfl
theorem natCast_mul (a b : Nat) : (NatCast.natCast (a * b : Nat) : Int) = (NatCast.natCast a : Int) * (NatCast.natCast b : Int) := rfl
theorem natCast_div (a b : Nat) : ((a / b) : Int) = a / b := by
rfl
theorem natCast_mod (a b : Nat) : ((a % b) : Int) = a % b := by
rfl
theorem Nat.pow_one (a : Nat) : a ^ 1 = a := by
simp
@@ -153,10 +153,8 @@ init_grind_norm
Int.emod_neg Int.ediv_neg
Int.ediv_zero Int.emod_zero
Int.ediv_one Int.emod_one
natCast_eq natCast_div natCast_mod
natCast_add natCast_mul
Int.natCast_add Int.natCast_mul Int.natCast_pow
Int.natCast_zero natCast_div natCast_mod
Int.pow_zero Int.pow_one
-- GT GE
ge_eq gt_eq

View File

@@ -70,11 +70,6 @@ structure Config where
canonHeartbeats : Nat := 1000
/-- If `ext` is `true`, `grind` uses extensionality theorems available in the environment. -/
ext : Bool := true
/--
If `funext` is `true`, `grind` creates new opportunities for applying function extensionality by case-splitting
on equalities between lambda expressions.
-/
funext : Bool := true
/-- If `verbose` is `false`, additional diagnostics information is not collected. -/
verbose : Bool := true
/-- If `clean` is `true`, `grind` uses `expose_names` and only generates accessible names. -/

View File

@@ -44,13 +44,6 @@ register_builtin_option Elab.async : Bool := {
`Lean.Command.State.snapshotTasks`."
}
register_builtin_option Elab.inServer : Bool := {
defValue := false
descr := "true if elaboration is being run inside the Lean language server\
\n\
\nThis option is set by the file worker and should not be modified otherwise."
}
/-- Performance option used by cmdline driver. -/
register_builtin_option internal.cmdlineSnapshots : Bool := {
defValue := false

View File

@@ -111,7 +111,6 @@ inductive Message where
| response (id : RequestID) (result : Json)
/-- A non-successful response. -/
| responseError (id : RequestID) (code : ErrorCode) (message : String) (data? : Option Json)
deriving Inhabited
def Batch := Array Message

View File

@@ -510,7 +510,6 @@ where go := do
let oldCmds? := oldSnap?.map fun old =>
if old.newStx.isOfKind nullKind then old.newStx.getArgs else #[old.newStx]
let cmdPromises cmds.mapM fun _ => IO.Promise.new
let cancelTk? := ( read).cancelTk?
snap.new.resolve <| .ofTyped {
diagnostics := .empty
macroDecl := decl
@@ -518,7 +517,7 @@ where go := do
newNextMacroScope := nextMacroScope
hasTraces
next := Array.zipWith (fun cmdPromise cmd =>
{ stx? := some cmd, task := cmdPromise.resultD default, cancelTk? }) cmdPromises cmds
{ stx? := some cmd, task := cmdPromise.resultD default }) cmdPromises cmds
: MacroExpandedSnapshot
}
-- After the first command whose syntax tree changed, we must disable

View File

@@ -6,7 +6,6 @@ Authors: Leonardo de Moura, Sebastian Ullrich
prelude
import Lean.Parser.Module
import Lean.Util.Paths
import Lean.CoreM
namespace Lean.Elab
@@ -22,16 +21,9 @@ def processHeader (header : Syntax) (opts : Options) (messages : MessageLog)
(inputCtx : Parser.InputContext) (trustLevel : UInt32 := 0)
(plugins : Array System.FilePath := #[]) (leakEnv := false)
: IO (Environment × MessageLog) := do
let level := if experimental.module.get opts then
if Elab.inServer.get opts then
.server
else
.exported
else
.private
try
let env
importModules (leakEnv := leakEnv) (loadExts := true) (level := level) (headerToImports header) opts trustLevel plugins
importModules (leakEnv := leakEnv) (loadExts := true) (headerToImports header) opts trustLevel plugins
pure (env, messages)
catch e =>
let env mkEmptyEnvironment

View File

@@ -165,8 +165,6 @@ private def elabHeaders (views : Array DefView) (expandedDeclIds : Array ExpandD
-- no syntax guard to store, we already did the necessary checks
oldBodySnap? := guard reuseBody *> pure .missing, old.bodySnap
if oldBodySnap?.isNone then
-- NOTE: this will eagerly cancel async tasks not associated with an inner snapshot, most
-- importantly kernel checking and compilation of the top-level declaration
old.bodySnap.cancelRec
oldTacSnap? := do
guard reuseTac
@@ -219,7 +217,6 @@ private def elabHeaders (views : Array DefView) (expandedDeclIds : Array ExpandD
return newHeader
if let some snap := view.headerSnap? then
let (tacStx?, newTacTask?) mkTacTask view.value tacPromise
let cancelTk? := ( readThe Core.Context).cancelTk?
let bodySnap := {
stx? := view.value
reportingRange? :=
@@ -230,8 +227,6 @@ private def elabHeaders (views : Array DefView) (expandedDeclIds : Array ExpandD
else
getBodyTerm? view.value |>.getD view.value |>.getRange?
task := bodyPromise.resultD default
-- We should not cancel the entire body early if we have tactics
cancelTk? := guard newTacTask?.isNone *> cancelTk?
}
snap.new.resolve <| some {
diagnostics :=
@@ -274,8 +269,7 @@ where
:= do
if let some e := getBodyTerm? body then
if let `(by $tacs*) := e then
let cancelTk? := ( readThe Core.Context).cancelTk?
return (e, some { stx? := mkNullNode tacs, task := tacPromise.resultD default, cancelTk? })
return (e, some { stx? := mkNullNode tacs, task := tacPromise.resultD default })
tacPromise.resolve default
return (none, none)
@@ -438,7 +432,8 @@ private def elabFunValues (headers : Array DefViewElabHeader) (vars : Array Expr
snap.new.resolve <| some old
reusableResult? := some (old.value, old.state)
else
-- make sure to cancel any async tasks that may still be running (e.g. kernel and codegen)
-- NOTE: this will eagerly cancel async tasks not associated with an inner snapshot, most
-- importantly kernel checking and compilation of the top-level declaration
old.val.cancelRec
let (val, state) withRestoreOrSaveFull reusableResult? header.tacSnap? do
@@ -1163,7 +1158,7 @@ is error-free and contains no syntactical `sorry`s.
-/
private def logGoalsAccomplishedSnapshotTask (views : Array DefView)
(defsParsedSnap : DefsParsedSnapshot) : TermElabM Unit := do
if Lean.Elab.inServer.get ( getOptions) then
if Lean.internal.cmdlineSnapshots.get ( getOptions) then
-- Skip 'goals accomplished' task if we are on the command line.
-- These messages are only used in the language server.
return
@@ -1202,7 +1197,6 @@ private def logGoalsAccomplishedSnapshotTask (views : Array DefView)
-- Use first line of the mutual block to avoid covering the progress of the whole mutual block
reportingRange? := ( getRef).getPos?.map fun pos => pos, pos
task := logGoalsAccomplishedTask
cancelTk? := none
}
end Term
@@ -1241,10 +1235,9 @@ def elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do
} }
if snap.old?.isSome && (view.headerSnap?.bind (·.old?)).isNone then
snap.old?.forM (·.val.cancelRec)
let cancelTk? := ( read).cancelTk?
defs := defs.push {
fullHeaderRef
headerProcessedSnap := { stx? := d, task := headerPromise.resultD default, cancelTk? }
headerProcessedSnap := { stx? := d, task := headerPromise.resultD default }
}
reusedAllHeaders := reusedAllHeaders && view.headerSnap?.any (·.old?.isSome)
views := views.push view

View File

@@ -109,11 +109,11 @@ structure InductiveView where
/-- Elaborated header for an inductive type before fvars for each inductive are added to the local context. -/
structure PreElabHeaderResult where
view : InductiveView
lctx : LocalContext
localInsts : LocalInstances
levelNames : List Name
numParams : Nat
params : Array Expr
type : Expr
/-- The parameters in the header's initial local context. Used for adding fvar alias terminfo. -/
origParams : Array Expr
deriving Inhabited
/-- The elaborated header with the `indFVar` registered for this inductive type. -/
@@ -228,12 +228,16 @@ private def checkClass (rs : Array PreElabHeaderResult) : TermElabM Unit := do
throwErrorAt r.view.ref "invalid inductive type, mutual classes are not supported"
private def checkNumParams (rs : Array PreElabHeaderResult) : TermElabM Nat := do
let numParams := rs[0]!.numParams
let numParams := rs[0]!.params.size
for r in rs do
unless r.numParams == numParams do
unless r.params.size == numParams do
throwErrorAt r.view.ref "invalid inductive type, number of parameters mismatch in mutually inductive datatypes"
return numParams
private def mkTypeFor (r : PreElabHeaderResult) : TermElabM Expr := do
withLCtx r.lctx r.localInsts do
mkForallFVars r.params r.type
/--
Execute `k` with updated binder information for `xs`. Any `x` that is explicit becomes implicit.
-/
@@ -272,7 +276,7 @@ private def checkHeaders (rs : Array PreElabHeaderResult) (numParams : Nat) (i :
checkHeaders rs numParams (i+1) type
where
checkHeader (r : PreElabHeaderResult) (numParams : Nat) (firstType? : Option Expr) : TermElabM Expr := do
let type := r.type
let type mkTypeFor r
match firstType? with
| none => return type
| some firstType =>
@@ -302,8 +306,7 @@ private def elabHeadersAux (views : Array InductiveView) (i : Nat) (acc : Array
let params Term.addAutoBoundImplicits params (view.declId.getTailPos? (canonicalOnly := true))
trace[Elab.inductive] "header params: {params}, type: {type}"
let levelNames Term.getLevelNames
let type mkForallFVars params type
return acc.push { levelNames, numParams := params.size, type, view, origParams := params }
return acc.push { lctx := ( getLCtx), localInsts := ( getLocalInstances), levelNames, params, type, view }
elabHeadersAux views (i+1) acc
else
return acc
@@ -323,21 +326,21 @@ private def elabHeaders (views : Array InductiveView) : TermElabM (Array PreElab
/--
Create a local declaration for each inductive type in `rs`, and execute `x params indFVars`, where `params` are the inductive type parameters and
`indFVars` are the new local declarations.
We use the parameters of rs[0].
We use the local context/instances and parameters of rs[0].
Note that this method is executed after we executed `checkHeaders` and established all
parameters are compatible.
-/
private def withInductiveLocalDecls (rs : Array PreElabHeaderResult) (x : Array Expr Array Expr TermElabM α) : TermElabM α := do
let r0 := rs[0]!
forallBoundedTelescope r0.type r0.numParams fun params _ => withRef r0.view.ref do
let namesAndTypes rs.mapM fun r => do
let type mkTypeFor r
pure (r.view.declName, r.view.shortDeclName, type)
let r0 := rs[0]!
let params := r0.params
withLCtx r0.lctx r0.localInsts <| withRef r0.view.ref do
let rec loop (i : Nat) (indFVars : Array Expr) := do
if h : i < rs.size then
let r := rs[i]
for param in params, origParam in r.origParams do
if let .fvar origFVar := origParam then
Elab.pushInfoLeaf <| .ofFVarAliasInfo { id := param.fvarId!, baseId := origFVar, userName := param.fvarId!.getUserName }
withAuxDecl r.view.shortDeclName r.type r.view.declName fun indFVar =>
loop (i+1) (indFVars.push indFVar)
if h : i < namesAndTypes.size then
let (declName, shortDeclName, type) := namesAndTypes[i]
withAuxDecl shortDeclName type declName fun indFVar => loop (i+1) (indFVars.push indFVar)
else
x params indFVars
loop 0 #[]
@@ -356,6 +359,26 @@ private def ElabHeaderResult.checkLevelNames (rs : Array PreElabHeaderResult) :
unless r.levelNames == levelNames do
throwErrorAt r.view.ref "invalid inductive type, universe parameters mismatch in mutually inductive datatypes"
/--
We need to work inside a single local context across all the inductive types, so we need to update the `ElabHeaderResult`s
so that resultant types refer to the fvars in `params`, the parameters for `rs[0]!` specifically.
Also updates the local contexts and local instances in each header.
-/
private def updateElabHeaderTypes (params : Array Expr) (rs : Array PreElabHeaderResult) (indFVars : Array Expr) : TermElabM (Array ElabHeaderResult) := do
rs.mapIdxM fun i r => do
/-
At this point, because of `withInductiveLocalDecls`, the only fvars that are in context are the ones related to the first inductive type.
Because of this, we need to replace the fvars present in each inductive type's header of the mutual block with those of the first inductive.
However, some mvars may still be uninstantiated there, and might hide some of the old fvars.
As such we first need to synthesize all possible mvars at this stage, instantiate them in the header types and only
then replace the parameters' fvars in the header type.
See issue #3242 (`https://github.com/leanprover/lean4/issues/3242`)
-/
let type instantiateMVars r.type
let type := type.replaceFVars r.params params
pure { r with lctx := getLCtx, localInsts := getLocalInstances, type := type, indFVar := indFVars[i]! }
private def getArity (indType : InductiveType) : MetaM Nat :=
forallTelescopeReducing indType.type fun xs _ => return xs.size
@@ -855,7 +878,7 @@ private def mkInductiveDecl (vars : Array Expr) (elabs : Array InductiveElabStep
trace[Elab.inductive] "level names: {allUserLevelNames}"
let res withInductiveLocalDecls rs fun params indFVars => do
trace[Elab.inductive] "indFVars: {indFVars}"
let rs := Array.zipWith (fun r indFVar => { r with indFVar : ElabHeaderResult }) rs indFVars
let rs updateElabHeaderTypes params rs indFVars
let mut indTypesArray : Array InductiveType := #[]
let mut elabs' := #[]
for h : i in [:views.size] do
@@ -863,7 +886,8 @@ private def mkInductiveDecl (vars : Array Expr) (elabs : Array InductiveElabStep
let r := rs[i]!
let elab' elabs[i]!.elabCtors rs r params
elabs' := elabs'.push elab'
indTypesArray := indTypesArray.push { name := r.view.declName, type := r.type, ctors := elab'.ctors }
let type mkForallFVars params r.type
indTypesArray := indTypesArray.push { name := r.view.declName, type, ctors := elab'.ctors }
Term.synthesizeSyntheticMVarsNoPostponing
let numExplicitParams fixedIndicesToParams params.size indTypesArray indFVars
trace[Elab.inductive] "numExplicitParams: {numExplicitParams}"

View File

@@ -1260,8 +1260,12 @@ private def addDefaults (levelParams : List Name) (params : Array Expr) (replace
let fieldInfos := ( get).fields
let lctx instantiateLCtxMVars ( getLCtx)
/- The parameters `params` for the auxiliary "default value" definitions must be marked as implicit, and all others as explicit. -/
let lctx := params.foldl (init := lctx) fun (lctx : LocalContext) (p : Expr) =>
lctx.setBinderInfo p.fvarId! BinderInfo.implicit
let lctx :=
params.foldl (init := lctx) fun (lctx : LocalContext) (p : Expr) =>
if p.isFVar then
lctx.setBinderInfo p.fvarId! BinderInfo.implicit
else
lctx
let parentFVarIds := fieldInfos |>.filter (·.kind.isParent) |>.map (·.fvar.fvarId!)
let fields := fieldInfos |>.filter (!·.kind.isParent)
withLCtx lctx ( getLocalInstances) do

View File

@@ -165,7 +165,6 @@ structure EvalTacticFailure where
state : SavedState
partial def evalTactic (stx : Syntax) : TacticM Unit := do
checkSystem "tactic execution"
profileitM Exception "tactic execution" (decl := stx.getKind) ( getOptions) <|
withRef stx <| withIncRecDepth <| withFreshMacroScope <| match stx with
| .node _ k _ =>
@@ -241,7 +240,6 @@ where
snap.old?.forM (·.val.cancelRec)
let promise IO.Promise.new
-- Store new unfolding in the snapshot tree
let cancelTk? := ( readThe Core.Context).cancelTk?
snap.new.resolve {
stx := stx'
diagnostics := .empty
@@ -251,7 +249,7 @@ where
state? := ( Tactic.saveState)
moreSnaps := #[]
}
next := #[{ stx? := stx', task := promise.resultD default, cancelTk? }]
next := #[{ stx? := stx', task := promise.resultD default }]
}
-- Update `tacSnap?` to old unfolding
withTheReader Term.Context ({ · with tacSnap? := some {

View File

@@ -78,14 +78,13 @@ where
let next IO.Promise.new
let finished IO.Promise.new
let inner IO.Promise.new
let cancelTk? := ( readThe Core.Context).cancelTk?
snap.new.resolve {
desc := tac.getKind.toString
diagnostics := .empty
stx := tac
inner? := some { stx? := tac, task := inner.resultD default, cancelTk? }
finished := { stx? := tac, task := finished.resultD default, cancelTk? }
next := #[{ stx? := stxs, task := next.resultD default, cancelTk? }]
inner? := some { stx? := tac, task := inner.resultD default }
finished := { stx? := tac, task := finished.resultD default }
next := #[{ stx? := stxs, task := next.resultD default }]
}
-- Run `tac` in a fresh info tree state and store resulting state in snapshot for
-- incremental reporting, then add back saved trees. Here we rely on `evalTactic`

View File

@@ -286,15 +286,14 @@ where
-- them, eventually put each of them back in `Context.tacSnap?` in `applyAltStx`
let finished IO.Promise.new
let altPromises altStxs.mapM fun _ => IO.Promise.new
let cancelTk? := ( readThe Core.Context).cancelTk?
tacSnap.new.resolve {
-- save all relevant syntax here for comparison with next document version
stx := mkNullNode altStxs
diagnostics := .empty
inner? := none
finished := { stx? := mkNullNode altStxs, reportingRange? := none, task := finished.resultD default, cancelTk? }
finished := { stx? := mkNullNode altStxs, reportingRange? := none, task := finished.resultD default }
next := Array.zipWith
(fun stx prom => { stx? := some stx, task := prom.resultD default, cancelTk? })
(fun stx prom => { stx? := some stx, task := prom.resultD default })
altStxs altPromises
}
goWithIncremental <| altPromises.mapIdx fun i prom => {

View File

@@ -1878,14 +1878,13 @@ where
go todo (autos.push auto)
/--
Similar to `addAutoBoundImplicits`, but converts all metavariables into free variables.
It uses `mkForallFVars` + `forallBoundedTelescope` to convert metavariables into free variables.
Similar to `autoBoundImplicits`, but immediately if the resulting array of expressions contains metavariables,
it immediately uses `mkForallFVars` + `forallBoundedTelescope` to convert them into free variables.
The type `type` is modified during the process if type depends on `xs`.
We use this method to simplify the conversion of code using `autoBoundImplicitsOld` to `autoBoundImplicits`.
-/
def addAutoBoundImplicits' (xs : Array Expr) (type : Expr) (k : Array Expr Expr TermElabM α) (inlayHintPos? : Option String.Pos := none) : TermElabM α := do
let xs addAutoBoundImplicits xs inlayHintPos?
def addAutoBoundImplicits' (xs : Array Expr) (type : Expr) (k : Array Expr Expr TermElabM α) : TermElabM α := do
let xs addAutoBoundImplicits xs none
if xs.all (·.isFVar) then
k xs type
else

View File

@@ -122,15 +122,13 @@ end TagDeclarationExtension
structure MapDeclarationExtension (α : Type) extends PersistentEnvExtension (Name × α) (Name × α) (NameMap α)
deriving Inhabited
def mkMapDeclarationExtension (name : Name := by exact decl_name%)
(exportEntriesFn : NameMap α Array (Name × α) := (·.toArray)) : IO (MapDeclarationExtension α) :=
def mkMapDeclarationExtension (name : Name := by exact decl_name%) : IO (MapDeclarationExtension α) :=
.mk <$> registerPersistentEnvExtension {
name := name,
mkInitial := pure {}
addImportedFn := fun _ => pure {}
addEntryFn := fun s (n, v) => s.insert n v
saveEntriesFn := fun s => s.toArray
exportEntriesFn
exportEntriesFn := fun s => s.toArray
asyncMode := .async
replay? := some fun _ newState newConsts s =>
newConsts.foldl (init := s) fun s c =>
@@ -147,11 +145,10 @@ def insert (ext : MapDeclarationExtension α) (env : Environment) (declName : Na
assert! env.asyncMayContain declName
ext.addEntry env (declName, val)
def find? [Inhabited α] (ext : MapDeclarationExtension α) (env : Environment) (declName : Name)
(includeServer := false) : Option α :=
def find? [Inhabited α] (ext : MapDeclarationExtension α) (env : Environment) (declName : Name) : Option α :=
match env.getModuleIdxFor? declName with
| some modIdx =>
match (ext.getModuleEntries (includeServer := includeServer) env modIdx).binSearch (declName, default) (fun a b => Name.quickLt a.1 b.1) with
match (ext.getModuleEntries env modIdx).binSearch (declName, default) (fun a b => Name.quickLt a.1 b.1) with
| some e => some e.2
| none => none
| none => (ext.findStateAsync env declName).find? declName

View File

@@ -506,12 +506,6 @@ structure Environment where
-/
base : Kernel.Environment
/--
Additional imported environment extension state for use in the language server. This field is
identical to `base.extensions` in other contexts. Access via
`getModuleEntries (includeServer := true)`.
-/
private serverBaseExts : Array EnvExtensionState := base.extensions
/--
Kernel environment task that is fulfilled when all asynchronously elaborated declarations are
finished, containing the resulting environment. Also collects the environment extension state of
all environment branches that contributed contained declarations.
@@ -542,12 +536,6 @@ structure Environment where
`findAsyncCore?`/`findStateAsync`; see there.
-/
private allRealizations : Task (NameMap AsyncConst) := .pure {}
/--
Indicates whether the environment is being used in an exported context, i.e. whether it should
provide access to only the data to be imported by other modules participating in the module
system.
-/
isExporting : Bool := false
deriving Nonempty
namespace Environment
@@ -561,10 +549,6 @@ def ofKernelEnv (env : Kernel.Environment) : Environment :=
def toKernelEnv (env : Environment) : Kernel.Environment :=
env.checked.get
/-- Updates `Environment.isExporting`. -/
def setExporting (env : Environment) (isExporting : Bool) : Environment :=
{ env with isExporting }
/-- Consistently updates synchronous and asynchronous parts of the environment without blocking. -/
private def modifyCheckedAsync (env : Environment) (f : Kernel.Environment Kernel.Environment) : Environment :=
{ env with checked := env.checked.map (sync := true) f, base := f env.base }
@@ -1395,15 +1379,7 @@ structure PersistentEnvExtension (α : Type) (β : Type) (σ : Type) where
name : Name
addImportedFn : Array (Array α) ImportM σ
addEntryFn : σ β σ
/-- Function to transform state into data that should always be imported into other modules. -/
exportEntriesFn : σ Array α
/--
Function to transform state into data that should be imported into other modules when the module
system is disabled. When it is enabled, the data is loaded only in the language server and
accessible via `getModuleEntries (includeServer := true)`. Conventionally, this is a superset of
the data returned by `exportEntriesFn`.
-/
saveEntriesFn : σ Array α
statsFn : σ Format
instance {α σ} [Inhabited σ] : Inhabited (PersistentEnvExtensionState α σ) :=
@@ -1416,21 +1392,14 @@ instance {α β σ} [Inhabited σ] : Inhabited (PersistentEnvExtension α β σ)
addImportedFn := fun _ => default,
addEntryFn := fun s _ => s,
exportEntriesFn := fun _ => #[],
saveEntriesFn := fun _ => #[],
statsFn := fun _ => Format.nil
}
namespace PersistentEnvExtension
/--
Returns the data saved by `ext.exportEntriesFn/saveEntriesFn` when `m` was elaborated. See docs on
the functions for details.
-/
def getModuleEntries {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ)
(env : Environment) (m : ModuleIdx) (includeServer := false) : Array α :=
let exts := if includeServer then env.serverBaseExts else env.base.extensions
-- safety: as in `getStateUnsafe`
unsafe (ext.toEnvExtension.getStateImpl exts).importedEntries[m]!
def getModuleEntries {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ) (env : Environment) (m : ModuleIdx) : Array α :=
-- `importedEntries` is identical on all environment branches, so `local` is always sufficient
(ext.toEnvExtension.getState (asyncMode := .local) env).importedEntries[m]!
def addEntry {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (b : β) : Environment :=
ext.toEnvExtension.modifyState env fun s =>
@@ -1467,14 +1436,10 @@ structure PersistentEnvExtensionDescr (α β σ : Type) where
addImportedFn : Array (Array α) ImportM σ
addEntryFn : σ β σ
exportEntriesFn : σ Array α
saveEntriesFn : σ Array α := exportEntriesFn
statsFn : σ Format := fun _ => Format.nil
asyncMode : EnvExtension.AsyncMode := .mainOnly
replay? : Option (ReplayFn σ) := none
attribute [inherit_doc PersistentEnvExtension.exportEntriesFn] PersistentEnvExtensionDescr.exportEntriesFn
attribute [inherit_doc PersistentEnvExtension.saveEntriesFn] PersistentEnvExtensionDescr.saveEntriesFn
unsafe def registerPersistentEnvExtensionUnsafe {α β σ : Type} [Inhabited σ] (descr : PersistentEnvExtensionDescr α β σ) : IO (PersistentEnvExtension α β σ) := do
let pExts persistentEnvExtensionsRef.get
if pExts.any (fun ext => ext.name == descr.name) then throw (IO.userError s!"invalid environment extension, '{descr.name}' has already been used")
@@ -1493,7 +1458,6 @@ unsafe def registerPersistentEnvExtensionUnsafe {α β σ : Type} [Inhabited σ]
addImportedFn := descr.addImportedFn,
addEntryFn := descr.addEntryFn,
exportEntriesFn := descr.exportEntriesFn,
saveEntriesFn := descr.saveEntriesFn,
statsFn := descr.statsFn
}
persistentEnvExtensionsRef.modify fun pExts => pExts.push (unsafeCast pExt)
@@ -1502,30 +1466,10 @@ unsafe def registerPersistentEnvExtensionUnsafe {α β σ : Type} [Inhabited σ]
@[implemented_by registerPersistentEnvExtensionUnsafe]
opaque registerPersistentEnvExtension {α β σ : Type} [Inhabited σ] (descr : PersistentEnvExtensionDescr α β σ) : IO (PersistentEnvExtension α β σ)
/--
Stores each given module data in the respective file name. Objects shared with prior parts are not
duplicated. Thus the data cannot be loaded with individual `readModuleData` calls but must loaded by
passing (a prefix of) the file names to `readModuleDataParts`. `mod` is used to determine an
arbitrary but deterministic base address for `mmap`.
-/
@[extern "lean_save_module_data_parts"]
opaque saveModuleDataParts (mod : @& Name) (parts : Array (System.FilePath × ModuleData)) : IO Unit
/--
Loads the module data from the given file names. The files must be (a prefix of) the result of a
`saveModuleDataParts` call.
-/
@[extern "lean_read_module_data_parts"]
opaque readModuleDataParts (fnames : @& Array System.FilePath) : IO (Array (ModuleData × CompactedRegion))
def saveModuleData (fname : System.FilePath) (mod : Name) (data : ModuleData) : IO Unit :=
saveModuleDataParts mod #[(fname, data)]
def readModuleData (fname : @& System.FilePath) : IO (ModuleData × CompactedRegion) := do
let parts readModuleDataParts #[fname]
assert! parts.size == 1
let some part := parts[0]? | unreachable!
return part
@[extern "lean_save_module_data"]
opaque saveModuleData (fname : @& System.FilePath) (mod : @& Name) (data : @& ModuleData) : IO Unit
@[extern "lean_read_module_data"]
opaque readModuleData (fname : @& System.FilePath) : IO (ModuleData × CompactedRegion)
/--
Free compacted regions of imports. No live references to imported objects may exist at the time of invocation; in
@@ -1549,22 +1493,7 @@ unsafe def Environment.freeRegions (env : Environment) : IO Unit :=
TODO: statically check for this. -/
env.header.regions.forM CompactedRegion.free
/-- The level of information to save/load. Each level includes all previous ones. -/
inductive OLeanLevel where
/-- Information from exported contexts. -/
| exported
/-- Environment extension state for the language server. -/
| server
/-- Private module data. -/
| «private»
deriving DecidableEq
def OLeanLevel.adjustFileName (base : System.FilePath) : OLeanLevel System.FilePath
| .exported => base
| .server => base.addExtension "server"
| .private => base.addExtension "private"
def mkModuleData (env : Environment) (level : OLeanLevel := .private) : IO ModuleData := do
def mkModuleData (env : Environment) : IO ModuleData := do
let pExts persistentEnvExtensionsRef.get
let entries := pExts.map fun pExt => Id.run do
-- get state from `checked` at the end if `async`; it would otherwise panic
@@ -1572,37 +1501,19 @@ def mkModuleData (env : Environment) (level : OLeanLevel := .private) : IO Modul
if asyncMode matches .async then
asyncMode := .sync
let state := pExt.getState (asyncMode := asyncMode) env
(pExt.name, if level = .exported then pExt.exportEntriesFn state else pExt.saveEntriesFn state)
(pExt.name, pExt.exportEntriesFn state)
let kenv := env.toKernelEnv
let env := env.setExporting (level != .private)
let constants := kenv.constants.foldStage2 (fun cs _ c => cs.push c) #[]
--let constNames := kenv.constants.foldStage2 (fun names name _ => names.push name) #[]
-- not all kernel constants may be exported
-- TODO: does not include cstage* constants from the old codegen
--let constants := constNames.filterMap env.find?
let constNames := constants.map (·.name)
let constNames := kenv.constants.foldStage2 (fun names name _ => names.push name) #[]
let constants := kenv.constants.foldStage2 (fun cs _ c => cs.push c) #[]
return {
imports := env.header.imports
extraConstNames := env.checked.get.extraConstNames.toArray
constNames, constants, entries
}
register_builtin_option experimental.module : Bool := {
defValue := false
descr := "Enable module system (experimental)"
}
@[export lean_write_module]
def writeModule (env : Environment) (fname : System.FilePath) (split := false) : IO Unit := do
if split then
let mkPart (level : OLeanLevel) :=
return (level.adjustFileName fname, ( mkModuleData env level))
saveModuleDataParts env.mainModule #[
( mkPart .exported),
( mkPart .server),
( mkPart .private)]
else
saveModuleData fname env.mainModule ( mkModuleData env)
def writeModule (env : Environment) (fname : System.FilePath) : IO Unit := do
saveModuleData fname env.mainModule ( mkModuleData env)
/--
Construct a mapping from persistent extension name to extension index at the array of persistent extensions.
@@ -1616,9 +1527,10 @@ def mkExtNameMap (startingAt : Nat) : IO (Std.HashMap Name Nat) := do
result := result.insert descr.name i
return result
private def setImportedEntries (states : Array EnvExtensionState) (mods : Array ModuleData)
(startingAt : Nat := 0) : IO (Array EnvExtensionState) := do
let mut states := states
private def setImportedEntries (env : Environment) (mods : Array ModuleData) (startingAt : Nat := 0) : IO Environment := do
-- We work directly on the states array instead of `env` as `Environment.modifyState` introduces
-- significant overhead on such frequent calls
let mut states := env.base.extensions
let extDescrs persistentEnvExtensionsRef.get
/- For extensions starting at `startingAt`, ensure their `importedEntries` array have size `mods.size`. -/
for extDescr in extDescrs[startingAt:] do
@@ -1634,7 +1546,7 @@ private def setImportedEntries (states : Array EnvExtensionState) (mods : Array
-- safety: as in `modifyState`
states := unsafe extDescrs[entryIdx]!.toEnvExtension.modifyStateImpl states fun s =>
{ s with importedEntries := s.importedEntries.set! modIdx entries }
return states
return env.setCheckedSync { env.base with extensions := states }
/--
"Forward declaration" needed for updating the attribute table with user-defined attributes.
@@ -1673,7 +1585,7 @@ where
-- This branch is executed when `pExtDescrs[i]` is the extension associated with the `init` attribute, and
-- a user-defined persistent extension is imported.
-- Thus, we invoke `setImportedEntries` to update the array `importedEntries` with the entries for the new extensions.
env := env.setCheckedSync { env.base with extensions := ( setImportedEntries env.base.extensions mods prevSize) }
env setImportedEntries env mods prevSize
-- See comment at `updateEnvAttributesRef`
env updateEnvAttributes env
loop (i + 1) env
@@ -1684,7 +1596,7 @@ structure ImportState where
moduleNameSet : NameHashSet := {}
moduleNames : Array Name := #[]
moduleData : Array ModuleData := #[]
parts : Array (Array (ModuleData × CompactedRegion)) := #[]
regions : Array CompactedRegion := #[]
def throwAlreadyImported (s : ImportState) (const2ModIdx : Std.HashMap Name ModuleIdx) (modIdx : Nat) (cname : Name) : IO α := do
let modName := s.moduleNames[modIdx]!
@@ -1696,8 +1608,7 @@ abbrev ImportStateM := StateRefT ImportState IO
@[inline] nonrec def ImportStateM.run (x : ImportStateM α) (s : ImportState := {}) : IO (α × ImportState) :=
x.run s
partial def importModulesCore (imports : Array Import) (level := OLeanLevel.private) :
ImportStateM Unit := do
partial def importModulesCore (imports : Array Import) : ImportStateM Unit := do
for i in imports do
if i.runtimeOnly || ( get).moduleNameSet.contains i.module then
continue
@@ -1705,22 +1616,12 @@ partial def importModulesCore (imports : Array Import) (level := OLeanLevel.priv
let mFile findOLean i.module
unless ( mFile.pathExists) do
throw <| IO.userError s!"object file '{mFile}' of module {i.module} does not exist"
let mut fnames := #[mFile]
if level != OLeanLevel.exported then
let sFile := OLeanLevel.server.adjustFileName mFile
if ( sFile.pathExists) then
fnames := fnames.push sFile
if level == OLeanLevel.private then
let pFile := OLeanLevel.private.adjustFileName mFile
if ( pFile.pathExists) then
fnames := fnames.push pFile
let parts readModuleDataParts fnames
let some (mod, _) := parts[if level = .exported then 0 else parts.size - 1]? | unreachable!
importModulesCore (level := level) mod.imports
let (mod, region) readModuleData mFile
importModulesCore mod.imports
modify fun s => { s with
moduleData := s.moduleData.push mod
regions := s.regions.push region
moduleNames := s.moduleNames.push i.module
parts := s.parts.push parts
}
/--
@@ -1784,16 +1685,14 @@ def finalizeImport (s : ImportState) (imports : Array Import) (opts : Options) (
extensions := exts
header := {
trustLevel, imports
regions := s.parts.flatMap (·.map (·.2))
regions := s.regions
moduleNames := s.moduleNames
moduleData := s.moduleData
}
}
realizedImportedConsts? := none
}
env := env.setCheckedSync { env.base with extensions := ( setImportedEntries env.base.extensions s.moduleData) }
let serverData := s.parts.filterMap fun parts => (parts[1]? <|> parts[0]?).map Prod.fst
env := { env with serverBaseExts := ( setImportedEntries env.base.extensions serverData) }
env setImportedEntries env s.moduleData
if leakEnv then
/- Mark persistent a first time before `finalizePersistenExtensions`, which
avoids costly MT markings when e.g. an interpreter closure (which
@@ -1840,13 +1739,13 @@ environment's constant map can be accessed without `loadExts`, many functions th
-/
def importModules (imports : Array Import) (opts : Options) (trustLevel : UInt32 := 0)
(plugins : Array System.FilePath := #[]) (leakEnv := false) (loadExts := false)
(level := OLeanLevel.private) : IO Environment := profileitIO "import" opts do
: IO Environment := profileitIO "import" opts do
for imp in imports do
if imp.module matches .anonymous then
throw <| IO.userError "import failed, trying to import module with anonymous name"
withImporting do
plugins.forM Lean.loadPlugin
let (_, s) importModulesCore (level := level) imports |>.run
let (_, s) importModulesCore imports |>.run
finalizeImport (leakEnv := leakEnv) (loadExts := loadExts) s imports opts trustLevel
/--

View File

@@ -93,17 +93,18 @@ structure SnapshotTask (α : Type) where
Cancellation token that can be set by the server to cancel the task when it detects the results
are not needed anymore.
-/
cancelTk? : Option IO.CancelToken
cancelTk? : Option IO.CancelToken := none
/-- Underlying task producing the snapshot. -/
task : Task α
deriving Nonempty, Inhabited
/-- Creates a snapshot task from the syntax processed by the task and a `BaseIO` action. -/
def SnapshotTask.ofIO (stx? : Option Syntax) (cancelTk? : Option IO.CancelToken)
def SnapshotTask.ofIO (stx? : Option Syntax)
(reportingRange? : Option String.Range := defaultReportingRange? stx?) (act : BaseIO α) :
BaseIO (SnapshotTask α) := do
return {
stx?, reportingRange?, cancelTk?
stx?
reportingRange?
task := ( BaseIO.asTask act)
}
@@ -113,7 +114,6 @@ def SnapshotTask.finished (stx? : Option Syntax) (a : α) : SnapshotTask α wher
-- irrelevant when already finished
reportingRange? := none
task := .pure a
cancelTk? := none
/-- Transforms a task's output without changing the processed syntax. -/
def SnapshotTask.map (t : SnapshotTask α) (f : α β) (stx? : Option Syntax := t.stx?)

View File

@@ -397,7 +397,7 @@ where
diagnostics := oldProcessed.diagnostics
result? := some {
cmdState := oldProcSuccess.cmdState
firstCmdSnap := { stx? := none, task := prom.result!, cancelTk? := cancelTk } } }
firstCmdSnap := { stx? := none, task := prom.result! } } }
else
return .finished newStx oldProcessed) } }
else return old
@@ -450,7 +450,7 @@ where
processHeader (stx : Syntax) (parserState : Parser.ModuleParserState) :
LeanProcessingM (SnapshotTask HeaderProcessedSnapshot) := do
let ctx read
SnapshotTask.ofIO stx none (some 0, ctx.input.endPos) <|
SnapshotTask.ofIO stx (some 0, ctx.input.endPos) <|
ReaderT.run (r := ctx) <| -- re-enter reader in new task
withHeaderExceptions (α := HeaderProcessedSnapshot) ({ · with result? := none }) do
let setup match ( setupImports stx) with
@@ -507,7 +507,7 @@ where
infoTree? := cmdState.infoState.trees[0]!
result? := some {
cmdState
firstCmdSnap := { stx? := none, task := prom.result!, cancelTk? := cancelTk }
firstCmdSnap := { stx? := none, task := prom.result! }
}
}
@@ -523,19 +523,17 @@ where
-- from `old`
if let some oldNext := old.nextCmdSnap? then do
let newProm IO.Promise.new
let cancelTk IO.CancelToken.new
-- can reuse range, syntax unchanged
BaseIO.chainTask (sync := true) old.resultSnap.task fun oldResult =>
-- also wait on old command parse snapshot as parsing is cheap and may allow for
-- elaboration reuse
BaseIO.chainTask (sync := true) oldNext.task fun oldNext => do
let cancelTk IO.CancelToken.new
parseCmd oldNext newParserState oldResult.cmdState newProm sync cancelTk ctx
prom.resolve <| { old with nextCmdSnap? := some {
stx? := none
reportingRange? := some newParserState.pos, ctx.input.endPos
task := newProm.result!
cancelTk? := cancelTk
} }
task := newProm.result! } }
else prom.resolve old -- terminal command, we're done!
-- fast path, do not even start new task for this snapshot (see [Incremental Parsing])
@@ -617,16 +615,15 @@ where
})
let diagnostics Snapshot.Diagnostics.ofMessageLog msgLog
-- use per-command cancellation token for elaboration so that cancellation of further commands
-- does not affect current command
-- use per-command cancellation token for elaboration so that
let elabCmdCancelTk IO.CancelToken.new
prom.resolve {
diagnostics, nextCmdSnap?
stx := stx', parserState := parserState'
elabSnap := { stx? := stx', task := elabPromise.result!, cancelTk? := some elabCmdCancelTk }
resultSnap := { stx? := stx', reportingRange? := initRange?, task := resultPromise.result!, cancelTk? := none }
infoTreeSnap := { stx? := stx', reportingRange? := initRange?, task := finishedPromise.result!, cancelTk? := none }
reportSnap := { stx? := none, reportingRange? := initRange?, task := reportPromise.result!, cancelTk? := none }
resultSnap := { stx? := stx', reportingRange? := initRange?, task := resultPromise.result! }
infoTreeSnap := { stx? := stx', reportingRange? := initRange?, task := finishedPromise.result! }
reportSnap := { stx? := none, reportingRange? := initRange?, task := reportPromise.result! }
}
let cmdState doElab stx cmdState beginPos
{ old? := old?.map fun old => old.stx, old.elabSnap, new := elabPromise }
@@ -668,8 +665,8 @@ where
-- We want to trace all of `CommandParsedSnapshot` but `traceTask` is part of it, so let's
-- create a temporary snapshot tree containing all tasks but it
let snaps := #[
{ stx? := stx', task := elabPromise.result!.map (sync := true) toSnapshotTree, cancelTk? := none },
{ stx? := stx', task := resultPromise.result!.map (sync := true) toSnapshotTree, cancelTk? := none }] ++
{ stx? := stx', task := elabPromise.result!.map (sync := true) toSnapshotTree },
{ stx? := stx', task := resultPromise.result!.map (sync := true) toSnapshotTree }] ++
cmdState.snapshotTasks
let tree := SnapshotTree.mk { diagnostics := .empty } snaps
BaseIO.bindTask ( tree.waitAll) fun _ => do
@@ -693,7 +690,6 @@ where
stx? := none
reportingRange? := initRange?
task := traceTask
cancelTk? := none
}
if let some next := next? then
-- We're definitely off the fast-forwarding path now

View File

@@ -2279,7 +2279,6 @@ def realizeConst (forConst : Name) (constName : Name) (realize : MetaM Unit) :
initHeartbeats := ( IO.getNumHeartbeats)
}
let (env, exTask, dyn) env.realizeConst forConst constName (realizeAndReport coreCtx)
-- Realizations cannot be cancelled as their result is shared across elaboration runs
let exAct Core.wrapAsyncAsSnapshot (cancelTk? := none) fun
| none => return
| some ex => do
@@ -2287,7 +2286,6 @@ def realizeConst (forConst : Name) (constName : Name) (realize : MetaM Unit) :
Core.logSnapshotTask {
stx? := none
task := ( BaseIO.mapTask (t := exTask) exAct)
cancelTk? := none
}
if let some res := dyn.get? RealizeConstantResult then
let mut snap := res.snap

View File

@@ -52,8 +52,6 @@ builtin_initialize registerTraceClass `grind.split.candidate
builtin_initialize registerTraceClass `grind.split.resolved
builtin_initialize registerTraceClass `grind.beta
builtin_initialize registerTraceClass `grind.mbtc
builtin_initialize registerTraceClass `grind.ext
builtin_initialize registerTraceClass `grind.ext.candidate
/-! Trace options for `grind` developers -/
builtin_initialize registerTraceClass `grind.debug
@@ -78,6 +76,5 @@ builtin_initialize registerTraceClass `grind.debug.proveEq
builtin_initialize registerTraceClass `grind.debug.pushNewFact
builtin_initialize registerTraceClass `grind.debug.ematch.activate
builtin_initialize registerTraceClass `grind.debug.appMap
builtin_initialize registerTraceClass `grind.debug.ext
end Lean

View File

@@ -22,17 +22,50 @@ namespace Lean
builtin_initialize registerTraceClass `grind.cutsat
builtin_initialize registerTraceClass `grind.cutsat.model
builtin_initialize registerTraceClass `grind.cutsat.subst
builtin_initialize registerTraceClass `grind.cutsat.eq
builtin_initialize registerTraceClass `grind.cutsat.eq.unsat (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.eq.trivial (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.assert
builtin_initialize registerTraceClass `grind.cutsat.assert.trivial
builtin_initialize registerTraceClass `grind.cutsat.assert.unsat
builtin_initialize registerTraceClass `grind.cutsat.assert.store
builtin_initialize registerTraceClass `grind.cutsat.assert.dvd
builtin_initialize registerTraceClass `grind.cutsat.dvd
builtin_initialize registerTraceClass `grind.cutsat.dvd.update (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.dvd.unsat (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.dvd.trivial (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.dvd.solve (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.dvd.solve.combine (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.dvd.solve.elim (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.internalize
builtin_initialize registerTraceClass `grind.cutsat.internalize.term (inherited := true)
builtin_initialize registerTraceClass `grind.debug.cutsat.subst
builtin_initialize registerTraceClass `grind.cutsat.assert.le
builtin_initialize registerTraceClass `grind.cutsat.le
builtin_initialize registerTraceClass `grind.cutsat.le.unsat (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.le.trivial (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.le.lower (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.le.upper (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.assign
builtin_initialize registerTraceClass `grind.cutsat.conflict
builtin_initialize registerTraceClass `grind.cutsat.diseq
builtin_initialize registerTraceClass `grind.cutsat.diseq.trivial (inherited := true)
builtin_initialize registerTraceClass `grind.debug.cutsat.eq
builtin_initialize registerTraceClass `grind.debug.cutsat.dvd.le
builtin_initialize registerTraceClass `grind.debug.cutsat.diseq
builtin_initialize registerTraceClass `grind.debug.cutsat.diseq.split
builtin_initialize registerTraceClass `grind.debug.cutsat.backtrack
builtin_initialize registerTraceClass `grind.debug.cutsat.search
builtin_initialize registerTraceClass `grind.debug.cutsat.search.split (inherited := true)
builtin_initialize registerTraceClass `grind.debug.cutsat.search.assign (inherited := true)
builtin_initialize registerTraceClass `grind.debug.cutsat.search.conflict (inherited := true)
builtin_initialize registerTraceClass `grind.debug.cutsat.search.backtrack (inherited := true)
builtin_initialize registerTraceClass `grind.debug.cutsat.cooper
builtin_initialize registerTraceClass `grind.debug.cutsat.cooper.diseq
builtin_initialize registerTraceClass `grind.debug.cutsat.conflict
builtin_initialize registerTraceClass `grind.debug.cutsat.assign
builtin_initialize registerTraceClass `grind.debug.cutsat.subst
builtin_initialize registerTraceClass `grind.debug.cutsat.getBestLower
builtin_initialize registerTraceClass `grind.debug.cutsat.nat
builtin_initialize registerTraceClass `grind.debug.cutsat.proof
builtin_initialize registerTraceClass `grind.debug.cutsat.internalize
builtin_initialize registerTraceClass `grind.debug.cutsat.markTerm
builtin_initialize registerTraceClass `grind.debug.cutsat.natCast
end Lean

View File

@@ -34,7 +34,7 @@ def DvdCnstr.applyEq (a : Int) (x : Var) (c₁ : EqCnstr) (b : Int) (c₂ : DvdC
let q := c₂.p
let d := Int.ofNat (a * c₂.d).natAbs
let p := (q.mul a |>.combine (p.mul (-b)))
trace[grind.debug.cutsat.subst] "{← getVar x}, {← c₁.pp}, {← c₂.pp}"
trace[grind.cutsat.subst] "{← getVar x}, {← c₁.pp}, {← c₂.pp}"
return { d, p, h := .subst x c₁ c₂ }
partial def DvdCnstr.applySubsts (c : DvdCnstr) : GoalM DvdCnstr := withIncRecDepth do
@@ -46,20 +46,21 @@ partial def DvdCnstr.applySubsts (c : DvdCnstr) : GoalM DvdCnstr := withIncRecDe
/-- Asserts divisibility constraint. -/
partial def DvdCnstr.assert (c : DvdCnstr) : GoalM Unit := withIncRecDepth do
if ( inconsistent) then return ()
trace[grind.cutsat.assert] "{← c.pp}"
trace[grind.cutsat.dvd] "{← c.pp}"
let c c.norm.applySubsts
if c.isUnsat then
trace[grind.cutsat.assert.unsat] "{← c.pp}"
trace[grind.cutsat.dvd.unsat] "{← c.pp}"
setInconsistent (.dvd c)
return ()
if c.isTrivial then
trace[grind.cutsat.assert.trivial] "{← c.pp}"
trace[grind.cutsat.dvd.trivial] "{← c.pp}"
return ()
let d₁ := c.d
let .add a₁ x p₁ := c.p | c.throwUnexpected
if ( c.satisfied) == .false then
resetAssignmentFrom x
if let some c' := ( get').dvds[x]! then
trace[grind.cutsat.dvd.solve] "{← c.pp}, {← c'.pp}"
let d₂ := c'.d
let .add a₂ _ p₂ := c'.p | c'.throwUnexpected
let (d, α, β) := gcdExt (a₁*d₂) (a₂*d₁)
@@ -74,14 +75,16 @@ partial def DvdCnstr.assert (c : DvdCnstr) : GoalM Unit := withIncRecDepth do
let α_d₂_p₁ := p₁.mul (α*d₂)
let β_d₁_p₂ := p₂.mul (β*d₁)
let combine := { d := d₁*d₂, p := .add d x (α_d₂_p₁.combine β_d₁_p₂), h := .solveCombine c c' : DvdCnstr }
trace[grind.cutsat.dvd.solve.combine] "{← combine.pp}"
modify' fun s => { s with dvds := s.dvds.set x none}
combine.assert
let a₂_p₁ := p₁.mul a₂
let a₁_p₂ := p₂.mul (-a₁)
let elim := { d, p := a₂_p₁.combine a₁_p₂, h := .solveElim c c' : DvdCnstr }
trace[grind.cutsat.dvd.solve.elim] "{← elim.pp}"
elim.assert
else
trace[grind.cutsat.assert.store] "{← c.pp}"
trace[grind.cutsat.dvd.update] "{← c.pp}"
c.p.updateOccs
modify' fun s => { s with dvds := s.dvds.set x (some c) }
@@ -94,6 +97,7 @@ def propagateIntDvd (e : Expr) : GoalM Unit := do
if ( isEqTrue e) then
let p toPoly b
let c := { d, p, h := .core e : DvdCnstr }
trace[grind.cutsat.assert.dvd] "{← c.pp}"
c.assert
else if ( isEqFalse e) then
pushNewFact <| mkApp4 (mkConst ``Int.Linear.of_not_dvd) a b reflBoolTrue (mkOfEqFalseCore e ( mkEqFalseProof e))

View File

@@ -38,12 +38,12 @@ def DiseqCnstr.applyEq (a : Int) (x : Var) (c₁ : EqCnstr) (b : Int) (c₂ : Di
let p := c₁.p
let q := c₂.p
let p := p.mul b |>.combine (q.mul (-a))
trace[grind.debug.cutsat.subst] "{← getVar x}, {← c₁.pp}, {← c₂.pp}"
trace[grind.cutsat.subst] "{← getVar x}, {← c₁.pp}, {← c₂.pp}"
return { p, h := .subst x c₁ c₂ }
partial def DiseqCnstr.applySubsts (c : DiseqCnstr) : GoalM DiseqCnstr := withIncRecDepth do
let some (x, c₁, p) c.p.substVar | return c
trace[grind.debug.cutsat.subst] "{← getVar x}, {← c.pp}, {← c₁.pp}"
trace[grind.cutsat.subst] "{← getVar x}, {← c.pp}, {← c₁.pp}"
applySubsts { p, h := .subst x c₁ c }
/--
@@ -68,11 +68,10 @@ def DiseqCnstr.assert (c : DiseqCnstr) : GoalM Unit := do
trace[grind.cutsat.assert] "{← c.pp}"
let c c.norm.applySubsts
if c.p.isUnsatDiseq then
trace[grind.cutsat.assert.unsat] "{← c.pp}"
setInconsistent (.diseq c)
return ()
if c.isTrivial then
trace[grind.cutsat.assert.trivial] "{← c.pp}"
trace[grind.cutsat.diseq.trivial] "{← c.pp}"
return ()
let k := c.p.gcdCoeffs c.p.getConst
let c := if k == 1 then
@@ -83,7 +82,7 @@ def DiseqCnstr.assert (c : DiseqCnstr) : GoalM Unit := do
return ()
let .add _ x _ := c.p | c.throwUnexpected
c.p.updateOccs
trace[grind.cutsat.assert.store] "{← c.pp}"
trace[grind.cutsat.diseq] "{← c.pp}"
modify' fun s => { s with diseqs := s.diseqs.modify x (·.push c) }
if ( c.satisfied) == .false then
resetAssignmentFrom x
@@ -109,7 +108,7 @@ where
partial def EqCnstr.applySubsts (c : EqCnstr) : GoalM EqCnstr := withIncRecDepth do
let some (x, c₁, p) c.p.substVar | return c
trace[grind.debug.cutsat.subst] "{← getVar x}, {← c.pp}, {← c₁.pp}"
trace[grind.cutsat.subst] "{← getVar x}, {← c.pp}, {← c₁.pp}"
applySubsts { p, h := .subst x c₁ c : EqCnstr }
private def updateDvdCnstr (a : Int) (x : Var) (c : EqCnstr) (y : Var) : GoalM Unit := do
@@ -198,11 +197,10 @@ def EqCnstr.assertImpl (c : EqCnstr) : GoalM Unit := do
trace[grind.cutsat.assert] "{← c.pp}"
let c c.norm.applySubsts
if c.p.isUnsatEq then
trace[grind.cutsat.assert.unsat] "{← c.pp}"
setInconsistent (.eq c)
return ()
if c.isTrivial then
trace[grind.cutsat.assert.trivial] "{← c.pp}"
trace[grind.cutsat.eq.trivial] "{← c.pp}"
return ()
let k := c.p.gcdCoeffs'
if c.p.getConst % k > 0 then
@@ -212,9 +210,9 @@ def EqCnstr.assertImpl (c : EqCnstr) : GoalM Unit := do
c
else
{ p := c.p.div k, h := .divCoeffs c }
trace[grind.cutsat.eq] "{← c.pp}"
let some (k, x) := c.p.pickVarToElim? | c.throwUnexpected
trace[grind.debug.cutsat.subst] ">> {← getVar x}, {← c.pp}"
trace[grind.cutsat.assert.store] "{← c.pp}"
modify' fun s => { s with
elimEqs := s.elimEqs.set x (some c)
elimStack := x :: s.elimStack
@@ -254,6 +252,7 @@ private def processNewNatEq (a b : Expr) : GoalM Unit := do
@[export lean_process_cutsat_eq]
def processNewEqImpl (a b : Expr) : GoalM Unit := do
trace[grind.debug.cutsat.eq] "{a} = {b}"
match ( foreignTerm? a), ( foreignTerm? b) with
| none, none => processNewIntEq a b
| some .nat, some .nat => processNewNatEq a b
@@ -272,6 +271,7 @@ private def processNewIntLitEq (a ke : Expr) : GoalM Unit := do
@[export lean_process_cutsat_eq_lit]
def processNewEqLitImpl (a ke : Expr) : GoalM Unit := do
trace[grind.debug.cutsat.eq] "{a} = {ke}"
match ( foreignTerm? a) with
| none => processNewIntLitEq a ke
| some .nat => processNewNatEq a ke
@@ -294,10 +294,12 @@ private def processNewNatDiseq (a b : Expr) : GoalM Unit := do
let rhs' toLinearExpr ( rhs.denoteAsIntExpr ctx) gen
let p := lhs'.sub rhs' |>.norm
let c := { p, h := .coreNat a b lhs rhs lhs' rhs' : DiseqCnstr }
trace[grind.debug.cutsat.nat] "{← c.pp}"
c.assert
@[export lean_process_cutsat_diseq]
def processNewDiseqImpl (a b : Expr) : GoalM Unit := do
trace[grind.debug.cutsat.diseq] "{a} ≠ {b}"
match ( foreignTerm? a), ( foreignTermOrLit? b) with
| none, none => processNewIntDiseq a b
| some .nat, some .nat => processNewNatDiseq a b
@@ -340,7 +342,7 @@ private def isForbiddenParent (parent? : Option Expr) (k : SupportedTermKind) :
private def internalizeInt (e : Expr) : GoalM Unit := do
if ( hasVar e) then return ()
let p toPoly e
trace[grind.debug.cutsat.internalize] "{aquote e}:= {← p.pp}"
trace[grind.cutsat.internalize] "{aquote e}:= {← p.pp}"
let x mkVar e
if p == .add 1 x (.num 0) then
-- It is pointless to assert `x = x`
@@ -401,8 +403,9 @@ private def internalizeNat (e : Expr) : GoalM Unit := do
let e'' : Int.Linear.Expr toLinearExpr e'' gen
let p := e''.norm
let natCast_e shareCommon (mkIntNatCast e)
trace[grind.cutsat.internalize] "natCast: {natCast_e}"
internalize natCast_e gen
trace[grind.debug.cutsat.internalize] "{aquote natCast_e}:= {← p.pp}"
trace[grind.cutsat.internalize] "{aquote natCast_e}:= {← p.pp}"
let x mkVar natCast_e
modify' fun s => { s with foreignDef := s.foreignDef.insert { expr := e } x }
let c := { p := .add (-1) x p, h := .defnNat e' x e'' : EqCnstr }

View File

@@ -24,6 +24,7 @@ def mkForeignVar (e : Expr) (t : ForeignType) : GoalM Var := do
foreignVars := s.foreignVars.insert t (vars.push e)
foreignVarMap := s.foreignVarMap.insert { expr := e} (x, t)
}
trace[grind.debug.cutsat.markTerm] "mkForeignVar: {e}"
markAsCutsatTerm e
return x

View File

@@ -100,24 +100,24 @@ where
@[export lean_grind_cutsat_assert_le]
def LeCnstr.assertImpl (c : LeCnstr) : GoalM Unit := do
if ( inconsistent) then return ()
trace[grind.cutsat.assert] "{← c.pp}"
let c c.norm.applySubsts
if c.isUnsat then
trace[grind.cutsat.assert.unsat] "{← c.pp}"
trace[grind.cutsat.le.unsat] "{← c.pp}"
setInconsistent (.le c)
return ()
if c.isTrivial then
trace[grind.cutsat.assert.trivial] "{← c.pp}"
trace[grind.cutsat.le.trivial] "{← c.pp}"
return ()
let .add a x _ := c.p | c.throwUnexpected
if ( findEq c) then
return ()
let c refineWithDiseq c
trace[grind.cutsat.assert.store] "{← c.pp}"
if a < 0 then
trace[grind.cutsat.le.lower] "{← c.pp}"
c.p.updateOccs
modify' fun s => { s with lowers := s.lowers.modify x (·.push c) }
else
trace[grind.cutsat.le.upper] "{← c.pp}"
c.p.updateOccs
modify' fun s => { s with uppers := s.uppers.modify x (·.push c) }
if ( c.satisfied) == .false then
@@ -145,6 +145,7 @@ def propagateIntLe (e : Expr) (eqTrue : Bool) : GoalM Unit := do
pure { p, h := .core e : LeCnstr }
else
pure { p := p.mul (-1) |>.addConst 1, h := .coreNeg e p : LeCnstr }
trace[grind.cutsat.assert.le] "{← c.pp}"
c.assert
def propagateNatLe (e : Expr) (eqTrue : Bool) : GoalM Unit := do
@@ -154,6 +155,7 @@ def propagateNatLe (e : Expr) (eqTrue : Bool) : GoalM Unit := do
let lhs' toLinearExpr ( lhs.denoteAsIntExpr ctx) gen
let rhs' toLinearExpr ( rhs.denoteAsIntExpr ctx) gen
let p := lhs'.sub rhs' |>.norm
trace[grind.debug.cutsat.nat] "{← p.pp}"
let c if eqTrue then
pure { p, h := .coreNat e lhs rhs lhs' rhs' : LeCnstr }
else

View File

@@ -136,7 +136,9 @@ def assertDenoteAsIntNonneg (e : Expr) : GoalM Unit := withIncRecDepth do
let lhs' : Int.Linear.Expr := .num 0
let rhs' toLinearExpr ( rhs.denoteAsIntExpr ctx) gen
let p := lhs'.sub rhs' |>.norm
trace[grind.debug.cutsat.nat] "{← p.pp}"
let c := { p, h := .denoteAsIntNonneg rhs rhs' : LeCnstr }
trace[grind.cutsat.assert.le] "{← c.pp}"
c.assert
/--
@@ -147,6 +149,7 @@ def assertNatCast (e : Expr) (x : Var) : GoalM Unit := do
let_expr NatCast.natCast _ inst a := e | return ()
let_expr instNatCastInt := inst | return ()
if ( get').foreignDef.contains { expr := a } then return ()
trace[grind.debug.cutsat.natCast] "{a}"
let n mkForeignVar a .nat
let p := .add (-1) x (.num 0)
let c := { p, h := .denoteAsIntNonneg (.var n) (.var x) : LeCnstr}

View File

@@ -301,6 +301,7 @@ partial def LeCnstr.toExprProof (c' : LeCnstr) : ProofM Expr := caching c' do
( getContext) ( mkPolyDecl p₁) ( mkPolyDecl p₂) ( mkPolyDecl c₃.p) (toExpr c₃.d) (toExpr s.k) (toExpr coeff) ( mkPolyDecl c'.p) ( s.toExprProof) reflBoolTrue
partial def DiseqCnstr.toExprProof (c' : DiseqCnstr) : ProofM Expr := caching c' do
trace[grind.debug.cutsat.proof] "{← c'.pp}"
match c'.h with
| .core0 a zero =>
mkDiseqProof a zero
@@ -351,6 +352,7 @@ partial def CooperSplit.toExprProof (s : CooperSplit) : ProofM Expr := caching s
-- `pred` is an expressions of the form `cooper_*_split ...` with type `Nat → Prop`
let mut k := n
let mut result := base -- `OrOver k (cooper_*_splti)
trace[grind.debug.cutsat.proof] "orOver_cases {n}"
result := mkApp3 (mkConst ``Int.Linear.orOver_cases) (toExpr (n-1)) pred result
for (fvarId, c) in hs do
let type := mkApp pred (toExpr (k-1))
@@ -363,6 +365,7 @@ partial def CooperSplit.toExprProof (s : CooperSplit) : ProofM Expr := caching s
return result
partial def UnsatProof.toExprProofCore (h : UnsatProof) : ProofM Expr := do
trace[grind.debug.cutsat.proof] "{← h.pp}"
match h with
| .le c =>
return mkApp4 (mkConst ``Int.Linear.le_unsat) ( getContext) ( mkPolyDecl c.p) reflBoolTrue ( c.toExprProof)
@@ -389,6 +392,7 @@ def UnsatProof.toExprProof (h : UnsatProof) : GoalM Expr := do
withProofContext do h.toExprProofCore
def setInconsistent (h : UnsatProof) : GoalM Unit := do
trace[grind.debug.cutsat.conflict] "setInconsistent [{← inconsistent}]: {← h.pp}"
if ( get').caseSplits then
-- Let the search procedure in `SearchM` resolve the conflict.
modify' fun s => { s with conflict? := some h }

View File

@@ -27,6 +27,7 @@ def CooperSplit.assert (cs : CooperSplit) : GoalM Unit := do
let p₁' := p.mul b |>.combine (q.mul (-a))
let p₁' := p₁'.addConst <| if left then b*k else (-a)*k
let c₁' := { p := p₁', h := .cooper cs : LeCnstr }
trace[grind.debug.cutsat.cooper] "{← c₁'.pp}"
c₁'.assert
if ( inconsistent) then return ()
let d := if left then a else b
@@ -34,6 +35,7 @@ def CooperSplit.assert (cs : CooperSplit) : GoalM Unit := do
let p₂' := if left then p else q
let p₂' := p₂'.addConst k
let c₂' := { d, p := p₂', h := .cooper₁ cs : DvdCnstr }
trace[grind.debug.cutsat.cooper] "dvd₁: {← c₂'.pp}"
c₂'.assert
if ( inconsistent) then return ()
let some c₃ := c₃? | return ()
@@ -49,6 +51,7 @@ def CooperSplit.assert (cs : CooperSplit) : GoalM Unit := do
let p₃' := q.mul (-c) |>.combine (s.mul b)
let p₃' := p₃'.addConst (-c*k)
{ d := b*d, p := p₃', h := .cooper₂ cs : DvdCnstr }
trace[grind.debug.cutsat.cooper] "dvd₂: {← c₃'.pp}"
c₃'.assert
private def checkIsNextVar (x : Var) : GoalM Unit := do
@@ -56,7 +59,7 @@ private def checkIsNextVar (x : Var) : GoalM Unit := do
throwError "`grind` internal error, assigning variable out of order"
private def traceAssignment (x : Var) (v : Rat) : GoalM Unit := do
trace[grind.debug.cutsat.search.assign] "{quoteIfArithTerm (← getVar x)} := {v}"
trace[grind.cutsat.assign] "{quoteIfArithTerm (← getVar x)} := {v}"
private def setAssignment (x : Var) (v : Rat) : GoalM Unit := do
checkIsNextVar x
@@ -85,6 +88,7 @@ where
modify' fun s => { s with assignment := s.assignment.set x 0 }
let some v c.p.eval? | c.throwUnexpected
let v := (-v) / a
trace[grind.debug.cutsat.assign] "{← getVar x}, {← c.pp}, {v}"
traceAssignment x v
modify' fun s => { s with assignment := s.assignment.set x v }
go xs
@@ -104,6 +108,7 @@ def tightUsingDvd (c : LeCnstr) (dvd? : Option DvdCnstr) : GoalM LeCnstr := do
let b₂ := c.p.getConst
if (b₂ - b₁) % d != 0 then
let b₂' := b₁ - d * ((b₁ - b₂) / d)
trace[grind.debug.cutsat.dvd.le] "[pos] {← c.pp}, {← dvd.pp}, {b₂'}"
let p := c.p.addConst (b₂'-b₂)
return { p, h := .dvdTight dvd c }
if eqCoeffs dvd.p c.p true then
@@ -111,6 +116,7 @@ def tightUsingDvd (c : LeCnstr) (dvd? : Option DvdCnstr) : GoalM LeCnstr := do
let b₂ := c.p.getConst
if (b₂ - b₁) % d != 0 then
let b₂' := b₁ - d * ((b₁ - b₂) / d)
trace[grind.debug.cutsat.dvd.le] "[neg] {← c.pp}, {← dvd.pp}, {b₂'}"
let p := c.p.addConst (b₂'-b₂)
return { p, h := .negDvdTight dvd c }
return c
@@ -128,6 +134,7 @@ def getBestLower? (x : Var) (dvd? : Option DvdCnstr) : GoalM (Option (Rat × LeC
let .add k _ p := c.p | c.throwUnexpected
let some v p.eval? | c.throwUnexpected
let lower' := v / (-k)
trace[grind.debug.cutsat.getBestLower] "k: {k}, x: {x}, p: {repr p}, v: {v}, best?: {best?.map (·.1)}, c: {← c.pp}"
if let some (lower, _) := best? then
if lower' > lower then
best? := some (lower', c)
@@ -207,7 +214,7 @@ def DvdCnstr.getSolutions? (c : DvdCnstr) : SearchM (Option DvdSolution) := do
return some { d, b := -b*a' }
def resolveDvdConflict (c : DvdCnstr) : GoalM Unit := do
trace[grind.debug.cutsat.search.conflict] "{← c.pp}"
trace[grind.cutsat.conflict] "{← c.pp}"
let d := c.d
let .add a _ p := c.p | c.throwUnexpected
{ d := a.gcd d, p, h := .elim c : DvdCnstr }.assert
@@ -293,7 +300,7 @@ partial def findRatVal (lower upper : Rat) (diseqVals : Array (Rat × DiseqCnstr
v
def resolveRealLowerUpperConflict (c₁ c₂ : LeCnstr) : GoalM Bool := do
trace[grind.debug.cutsat.search.conflict] "{← c₁.pp}, {← c₂.pp}"
trace[grind.cutsat.conflict] "{← c₁.pp}, {← c₂.pp}"
let .add a₁ _ p₁ := c₁.p | c₁.throwUnexpected
let .add a₂ _ p₂ := c₂.p | c₂.throwUnexpected
let p := p₁.mul a₂.natAbs |>.combine (p₂.mul a₁.natAbs)
@@ -306,7 +313,7 @@ def resolveRealLowerUpperConflict (c₁ c₂ : LeCnstr) : GoalM Bool := do
{ p, h := .combine c₁ c₂ : LeCnstr }
else
{ p := p.div k, h := .combineDivCoeffs c₁ c₂ k : LeCnstr }
trace[grind.debug.cutsat.search.conflict] "resolved: {← c.pp}"
trace[grind.cutsat.conflict] "resolved: {← c.pp}"
c.assert
return true
@@ -323,7 +330,7 @@ def resolveCooperUnary (pred : CooperSplitPred) : SearchM Bool := do
return true
def resolveCooperPred (pred : CooperSplitPred) : SearchM Unit := do
trace[grind.debug.cutsat.search.conflict] "[{pred.numCases}]: {← pred.pp}"
trace[grind.cutsat.conflict] "[{pred.numCases}]: {← pred.pp}"
if ( resolveCooperUnary pred) then
return
let n := pred.numCases
@@ -340,11 +347,11 @@ def resolveCooperDvd (c₁ c₂ : LeCnstr) (c₃ : DvdCnstr) : SearchM Unit := d
def DiseqCnstr.split (c : DiseqCnstr) : SearchM LeCnstr := do
let fvarId if let some fvarId := ( get').diseqSplits.find? c.p then
trace[grind.debug.cutsat.search.split] "{← c.pp}, reusing {fvarId.name}"
trace[grind.debug.cutsat.diseq.split] "{← c.pp}, reusing {fvarId.name}"
pure fvarId
else
let fvarId mkCase (.diseq c)
trace[grind.debug.cutsat.search.split] "{← c.pp}, {fvarId.name}"
trace[grind.debug.cutsat.diseq.split] "{← c.pp}, {fvarId.name}"
modify' fun s => { s with diseqSplits := s.diseqSplits.insert c.p fvarId }
pure fvarId
let p₂ := c.p.addConst 1
@@ -421,6 +428,7 @@ def processVar (x : Var) : SearchM Unit := do
setAssignment x v
| some (lower, c₁), some (upper, c₂) =>
trace[grind.debug.cutsat.search] "{lower} ≤ {lower.ceil} ≤ {quoteIfArithTerm (← getVar x)} ≤ {upper.floor} ≤ {upper}"
trace[grind.debug.cutsat.getBestLower] "lower: {lower}, c₁: {← c₁.pp}"
if lower > upper then
let .true resolveRealLowerUpperConflict c₁ c₂
| throwError "`grind` internal error, conflict resolution failed"
@@ -464,42 +472,43 @@ private def findCase (decVars : FVarIdSet) : SearchM Case := do
if decVars.contains case.fvarId then
return case
-- Conflict does not depend on this case.
trace[grind.debug.cutsat.search.backtrack] "skipping {case.fvarId.name}"
trace[grind.debug.cutsat.backtrack] "skipping {case.fvarId.name}"
unreachable!
private def union (vs₁ vs₂ : FVarIdSet) : FVarIdSet :=
vs₁.fold (init := vs₂) (·.insert ·)
def resolveConflict (h : UnsatProof) : SearchM Unit := do
trace[grind.debug.cutsat.search.backtrack] "resolve conflict, decision stack: {(← get).cases.toList.map fun c => c.fvarId.name}"
trace[grind.debug.cutsat.backtrack] "resolve conflict, decision stack: {(← get).cases.toList.map fun c => c.fvarId.name}"
let decVars := h.collectDecVars.run ( get).decVars
trace[grind.debug.cutsat.search.backtrack] "dec vars: {decVars.toList.map (·.name)}"
trace[grind.debug.cutsat.backtrack] "dec vars: {decVars.toList.map (·.name)}"
if decVars.isEmpty then
trace[grind.debug.cutsat.search.backtrack] "close goal: {← h.pp}"
trace[grind.debug.cutsat.backtrack] "close goal: {← h.pp}"
closeGoal ( h.toExprProof)
return ()
let c findCase decVars
modify' fun _ => c.saved
trace[grind.debug.cutsat.search.backtrack] "backtracking {c.fvarId.name}"
trace[grind.debug.cutsat.backtrack] "backtracking {c.fvarId.name}"
let decVars := decVars.erase c.fvarId
match c.kind with
| .diseq c₁ =>
let decVars := decVars.toArray
let p' := c₁.p.mul (-1) |>.addConst 1
let c' := { p := p', h := .ofDiseqSplit c₁ c.fvarId h decVars : LeCnstr }
trace[grind.debug.cutsat.search.backtrack] "resolved diseq split: {← c'.pp}"
trace[grind.debug.cutsat.backtrack] "resolved diseq split: {← c'.pp}"
c'.assert
| .cooper pred hs decVars' =>
let decVars' := union decVars decVars'
let n := pred.numCases
let hs := hs.push (c.fvarId, h)
trace[grind.debug.cutsat.search.backtrack] "cooper #{hs.size + 1}, {← pred.pp}, {hs.map fun p => p.1.name}"
trace[grind.debug.cutsat.backtrack] "cooper #{hs.size + 1}, {← pred.pp}, {hs.map fun p => p.1.name}"
let s if hs.size + 1 < n then
let fvarId mkCase (.cooper pred hs decVars')
pure { pred, k := n - hs.size - 1, h := .dec fvarId : CooperSplit }
else
let decVars' := decVars'.toArray
trace[grind.debug.cutsat.search.backtrack] "cooper last case, {← pred.pp}, dec vars: {decVars'.map (·.name)}"
trace[grind.debug.cutsat.backtrack] "cooper last case, {← pred.pp}, dec vars: {decVars'.map (·.name)}"
trace[grind.debug.cutsat.proof] "CooperSplit.last"
pure { pred, k := 0, h := .last hs decVars' : CooperSplit }
s.assert

View File

@@ -74,6 +74,7 @@ def mkCase (kind : CaseKind) : SearchM FVarId := do
decVars := s.decVars.insert fvarId
}
modify' fun s => { s with caseSplits := true }
trace[grind.debug.cutsat.backtrack] "mkCase fvarId: {fvarId.name}"
return fvarId
end Lean.Meta.Grind.Arith.Cutsat

View File

@@ -16,7 +16,7 @@ def mkVarImpl (expr : Expr) : GoalM Var := do
if let some var := ( get').varMap.find? { expr } then
return var
let var : Var := ( get').vars.size
trace[grind.debug.cutsat.internalize] "{expr} ↦ #{var}"
trace[grind.cutsat.internalize.term] "{expr} ↦ #{var}"
modify' fun s => { s with
vars := s.vars.push expr
varMap := s.varMap.insert { expr } var
@@ -27,6 +27,7 @@ def mkVarImpl (expr : Expr) : GoalM Var := do
occurs := s.occurs.push {}
elimEqs := s.elimEqs.push none
}
trace[grind.debug.cutsat.markTerm] "mkVar: {expr}"
markAsCutsatTerm expr
assertNatCast expr var
assertDenoteAsIntNonneg expr

View File

@@ -71,7 +71,6 @@ def isArithTerm (e : Expr) : Bool :=
| HMul.hMul _ _ _ _ _ _ => true
| HDiv.hDiv _ _ _ _ _ _ => true
| HMod.hMod _ _ _ _ _ _ => true
| HPow.hPow _ _ _ _ _ _ => true
| Neg.neg _ _ _ => true
| OfNat.ofNat _ _ _ => true
| _ => false

View File

@@ -35,7 +35,6 @@ def instantiateExtTheorem (thm : Ext.ExtTheorem) (e : Expr) : GoalM Unit := with
if proof'.hasMVar || prop'.hasMVar then
reportIssue! "failed to apply extensionality theorem `{thm.declName}` for {indentExpr e}\nresulting terms contain metavariables"
return ()
trace[grind.ext] "{prop'}"
addNewRawFact proof' prop' (( getGeneration e) + 1)
end Lean.Meta.Grind

View File

@@ -114,13 +114,4 @@ def propagateForallPropDown (e : Expr) : GoalM Unit := do
-- (a → b) = True → b = False → a = False
pushEqFalse a <| mkApp4 (mkConst ``Grind.eq_false_of_imp_eq_true) a b ( mkEqTrueProof e) ( mkEqFalseProof b)
builtin_grind_propagator propagateExistsDown Exists := fun e => do
if ( isEqFalse e) then
let_expr f@Exists α p := e | return ()
let u := f.constLevels!
let notP := mkApp (mkConst ``Not) (mkApp p (.bvar 0) |>.headBeta)
let prop := mkForall `x .default α notP
let proof := mkApp3 (mkConst ``forall_not_of_not_exists u) α p (mkOfEqFalseCore e ( mkEqFalseProof e))
addNewRawFact proof prop ( getGeneration e)
end Lean.Meta.Grind

View File

@@ -50,6 +50,11 @@ private def updateAppMap (e : Expr) : GoalM Unit := do
s.appMap.insert key [e]
}
/-- Inserts `e` into the list of case-split candidates. -/
private def addSplitCandidate (e : Expr) : GoalM Unit := do
trace_goal[grind.split.candidate] "{e}"
modify fun s => { s with split.candidates := e :: s.split.candidates }
private def forbiddenSplitTypes := [``Eq, ``HEq, ``True, ``False]
/-- Returns `true` if `e` is of the form `@Eq Prop a b` -/
@@ -202,65 +207,6 @@ private def propagateUnitLike (a : Expr) (generation : Nat) : GoalM Unit := do
internalize unit generation
pushEq a unit <| ( mkEqRefl unit)
/-- Returns `true` if we can ignore `ext` for functions occurring as arguments of a `declName`-application. -/
private def extParentsToIgnore (declName : Name) : Bool :=
declName == ``Eq || declName == ``HEq || declName == ``dite || declName == ``ite
|| declName == ``Exists || declName == ``Subtype
/--
Given a term `e` that occurs as the argument at position `i` of an `f`-application `parent?`,
we consider `e` as a candidate for case-splitting. For every other argument `e'` that also appears
at position `i` in an `f`-application and has the same type as `e`, we add the case-split candidate `e = e'`.
When performing the case split, we consider the following two cases:
- `e = e'`, which may introduce a new congruence between the corresponding `f`-applications.
- `¬(e = e')`, which may trigger extensionality theorems for the type of `e`.
This feature enables `grind` to solve examples such as:
```lean
example (f : (Nat → Nat) → Nat) : a = b → f (fun x => a + x) = f (fun x => b + x) := by
grind
```
-/
private def addSplitCandidatesForExt (e : Expr) (generation : Nat) (parent? : Option Expr := none) : GoalM Unit := do
let some parent := parent? | return ()
unless parent.isApp do return ()
let f := parent.getAppFn
if let .const declName _ := f then
if extParentsToIgnore declName then return ()
let type inferType e
-- Remark: we currently do not perform function extensionality on functions that produce a type that is not a proposition.
-- We may add an option to enable that in the future.
let u? typeFormerTypeLevel type
if u? != .none && u? != some .zero then return ()
let mut i := parent.getAppNumArgs
let mut it := parent
repeat
if !it.isApp then return ()
i := i - 1
let arg := it.appArg!
if isSameExpr arg e then
found f i type
it := it.appFn!
where
found (f : Expr) (i : Nat) (type : Expr) : GoalM Unit := do
trace[grind.debug.ext] "{f}, {i}, {e}"
let others := ( get).termsAt.find? (f, i) |>.getD []
for (e', type') in others do
if ( withDefault <| isDefEq type type') then
let eq := mkApp3 (mkConst ``Eq [ getLevel type]) type e e'
let eq shareCommon eq
internalize eq generation
trace_goal[grind.ext.candidate] "{eq}"
addSplitCandidate eq
modify fun s => { s with termsAt := s.termsAt.insert (f, i) ((e, type) :: others) }
return ()
/-- Applies `addSplitCandidatesForExt` if `funext` is enabled. -/
private def addSplitCandidatesForFunext (e : Expr) (generation : Nat) (parent? : Option Expr := none) : GoalM Unit := do
unless ( getConfig).funext do return ()
addSplitCandidatesForExt e generation parent?
@[export lean_grind_internalize]
private partial def internalizeImpl (e : Expr) (generation : Nat) (parent? : Option Expr := none) : GoalM Unit := withIncRecDepth do
if ( alreadyInternalized e) then
@@ -283,10 +229,7 @@ private partial def internalizeImpl (e : Expr) (generation : Nat) (parent? : Opt
| .fvar .. =>
mkENode' e generation
checkAndAddSplitCandidate e
| .letE .. =>
mkENode' e generation
| .lam .. =>
addSplitCandidatesForFunext e generation parent?
| .letE .. | .lam .. =>
mkENode' e generation
| .forallE _ d b _ =>
mkENode' e generation

View File

@@ -44,8 +44,6 @@ private def mkCandidateKey (a b : Expr) : Expr × Expr :=
/-- Model-based theory combination. -/
def mbtc (ctx : MBTC.Context) : GoalM Bool := do
unless ( getConfig).mbtc do return false
-- It is pointless to run `mbtc` if maximum number of splits has been reached.
if ( checkMaxCaseSplit) then return false
let mut map : Map := {}
let mut candidates : Candidates := {}
for ({ expr := e }, _) in ( get).enodes do
@@ -62,9 +60,8 @@ def mbtc (ctx : MBTC.Context) : GoalM Bool := do
unless others.any (isSameExpr arg ·) do
for other in others do
if ( ctx.eqAssignment arg other) then
if ( hasSameType arg other) then
let k := mkCandidateKey arg other
candidates := candidates.insert k
let k := mkCandidateKey arg other
candidates := candidates.insert k
map := map.insert (f, i) (arg :: others)
else
map := map.insert (f, i) [arg]
@@ -79,19 +76,13 @@ def mbtc (ctx : MBTC.Context) : GoalM Bool := do
b₁.lt b₂
else
a₁.lt a₂
let eqs result.filterMapM fun (a, b) => do
let eqs result.mapM fun (a, b) => do
let eq mkEq a b
trace[grind.mbtc] "{eq}"
let eq shareCommon ( canon eq)
if ( isKnownCaseSplit eq) then
return none
else
internalize eq (Nat.max ( getGeneration a) ( getGeneration b))
return some eq
if eqs.isEmpty then
return false
for eq in eqs do
addSplitCandidate eq
internalize eq (Nat.max ( getGeneration a) ( getGeneration b))
return eq
modify fun s => { s with split.candidates := s.split.candidates ++ eqs.toList }
return true
def mbtcTac (ctx : MBTC.Context) : GrindTactic := fun goal => do

View File

@@ -76,12 +76,11 @@ private def checkIffStatus (e a b : Expr) : GoalM CaseSplitStatus := do
/-- Returns `true` is `c` is congruent to a case-split that was already performed. -/
private def isCongrToPrevSplit (c : Expr) : GoalM Bool := do
unless c.isApp do return false
( get).split.resolved.foldM (init := false) fun flag { expr := c' } => do
if flag then
return true
else
return c'.isApp && isCongruent ( get).enodes c c'
return isCongruent ( get).enodes c c'
private def checkForallStatus (e : Expr) : GoalM CaseSplitStatus := do
if ( isEqTrue e) then

View File

@@ -480,8 +480,6 @@ structure Split.State where
candidates : List Expr := []
/-- Number of splits performed to get to this goal. -/
num : Nat := 0
/-- Case-splits that have been inserted at `candidates` at some point. -/
added : PHashSet ENodeKey := {}
/-- Case-splits that have already been performed, or that do not have to be performed anymore. -/
resolved : PHashSet ENodeKey := {}
/--
@@ -531,13 +529,6 @@ structure Goal where
arith : Arith.State := {}
/-- State of the clean name generator. -/
clean : Clean.State := {}
/--
Mapping from pairs `(f, i)` to a list of `(e, type)`.
The meaning is: `e : type` is lambda expression that occurs at argument `i` of an `f`-application.
We use this information to add case-splits for triggering extensionality theorems.
See `addSplitCandidatesForExt`.
-/
termsAt : PHashMap (Expr × Nat) (List (Expr × Expr)) := {}
deriving Inhabited
def Goal.admit (goal : Goal) : MetaM Unit :=
@@ -1162,14 +1153,6 @@ partial def Goal.getEqcs (goal : Goal) : List (List Expr) := Id.run do
def getEqcs : GoalM (List (List Expr)) :=
return ( get).getEqcs
/--
Returns `true` if `e` has been already added to the case-split list at one point.
Remark: this function returns `true` even if the split has already been resolved
and is not in the list anymore.
-/
def isKnownCaseSplit (e : Expr) : GoalM Bool :=
return ( get).split.added.contains { expr := e }
/-- Returns `true` if `e` is a case-split that does not need to be performed anymore. -/
def isResolvedCaseSplit (e : Expr) : GoalM Bool :=
return ( get).split.resolved.contains { expr := e }
@@ -1184,15 +1167,6 @@ def markCaseSplitAsResolved (e : Expr) : GoalM Unit := do
trace_goal[grind.split.resolved] "{e}"
modify fun s => { s with split.resolved := s.split.resolved.insert { expr := e } }
/-- Inserts `e` into the list of case-split candidates if it was not inserted before. -/
def addSplitCandidate (e : Expr) : GoalM Unit := do
unless ( isKnownCaseSplit e) do
trace_goal[grind.split.candidate] "{e}"
modify fun s => { s with
split.added := s.split.added.insert { expr := e }
split.candidates := e :: s.split.candidates
}
/--
Returns extensionality theorems for the given type if available.
If `Config.ext` is `false`, the result is `#[]`.

View File

@@ -97,8 +97,7 @@ def _root_.Lean.MVarId.clearAuxDecls (mvarId : MVarId) : MetaM MVarId := mvarId.
try
mvarId mvarId.clear fvarId
catch _ =>
let userName := ( fvarId.getDecl).userName
throwTacticEx `grind mvarId m!"the goal mentions the declaration `{userName}`, which is being defined. To avoid circular reasoning, try rewriting the goal to eliminate `{userName}` before using `grind`."
throwTacticEx `grind.clear_aux_decls mvarId "failed to clear local auxiliary declaration"
return mvarId
/--

View File

@@ -29,6 +29,16 @@ def withTypeAscription (d : Delab) (cond : Bool := true) : Delab := do
else
return stx
/--
If `pp.tagAppFns` is set, then `d` is evaluated with the delaborated head constant as the ref.
-/
def withFnRefWhenTagAppFns (d : Delab) : Delab := do
if ( getExpr).getAppFn.isConst && ( getPPOption getPPTagAppFns) then
let head withNaryFn delab
withRef head <| d
else
d
/--
Wraps the identifier (or identifier with explicit universe levels) with `@` if `pp.analysis.blockImplicit` is set to true.
-/
@@ -124,18 +134,6 @@ def delabConst : Delab := do
else
return stx
/--
If `pp.tagAppFns` is set, and if the current expression is a constant application,
then `d` is evaluated with the head constant delaborated with `delabConst` as the ref.
-/
def withFnRefWhenTagAppFns (d : Delab) : Delab := do
if ( getExpr).getAppFn.isConst && ( getPPOption getPPTagAppFns) then
-- delabConst in `pp.tagAppFns` mode annotates the term.
let head withNaryFn delabConst
withRef head <| d
else
d
def withMDataOptions [Inhabited α] (x : DelabM α) : DelabM α := do
match getExpr with
| Expr.mdata m .. =>

View File

@@ -207,7 +207,7 @@ This option can only be set on the command line, not in the lakefile or via `set
stickyInteractiveDiagnostics ++ docInteractiveDiagnostics
|>.map (·.toDiagnostic)
let notification := mkPublishDiagnosticsNotification doc.meta diagnostics
ctx.chanOut.sync.send notification
ctx.chanOut.send notification
open Language in
/--
@@ -239,7 +239,7 @@ This option can only be set on the command line, not in the lakefile or via `set
publishDiagnostics ctx doc
-- This will overwrite existing ilean info for the file, in case something
-- went wrong during the incremental updates.
ctx.chanOut.sync.send ( mkIleanInfoFinalNotification doc.meta st.allInfoTrees)
ctx.chanOut.send ( mkIleanInfoFinalNotification doc.meta st.allInfoTrees)
return ()
where
/--
@@ -312,7 +312,7 @@ This option can only be set on the command line, not in the lakefile or via `set
if let some itree := node.element.infoTree? then
let mut newInfoTrees := ( get).newInfoTrees.push itree
if ( get).hasBlocked then
ctx.chanOut.sync.send ( mkIleanInfoUpdateNotification doc.meta newInfoTrees)
ctx.chanOut.send ( mkIleanInfoUpdateNotification doc.meta newInfoTrees)
newInfoTrees := #[]
modify fun st => { st with newInfoTrees, allInfoTrees := st.allInfoTrees.push itree }
@@ -329,7 +329,7 @@ This option can only be set on the command line, not in the lakefile or via `set
| none => rs.push r
let ranges := ranges.map (·.toLspRange doc.meta.text)
let notifs := ranges.map ({ range := ·, kind := .processing })
ctx.chanOut.sync.send <| mkFileProgressNotification doc.meta notifs
ctx.chanOut.send <| mkFileProgressNotification doc.meta notifs
end Elab
@@ -389,9 +389,9 @@ def setupImports
severity? := DiagnosticSeverity.information
message := stderrLine
}
chanOut.sync.send <| mkPublishDiagnosticsNotification meta #[progressDiagnostic]
chanOut.send <| mkPublishDiagnosticsNotification meta #[progressDiagnostic]
-- clear progress notifications in the end
chanOut.sync.send <| mkPublishDiagnosticsNotification meta #[]
chanOut.send <| mkPublishDiagnosticsNotification meta #[]
match fileSetupResult.kind with
| .importsOutOfDate =>
return .error {
@@ -413,8 +413,6 @@ def setupImports
-- default to async elaboration; see also `Elab.async` docs
let opts := Elab.async.setIfNotSet opts true
let opts := Elab.inServer.set opts true
return .ok {
mainModuleName := meta.mod
opts
@@ -525,7 +523,7 @@ section ServerRequests
(freshRequestId, freshRequestId + 1)
let responseTask ctx.initPendingServerRequest responseType freshRequestId
let r : JsonRpc.Request paramType := freshRequestId, method, param
ctx.chanOut.sync.send r
ctx.chanOut.send r
return responseTask
def sendUntypedServerRequest
@@ -679,7 +677,7 @@ section MessageHandling
let availableImports ImportCompletion.collectAvailableImports
let lastRequestTimestampMs IO.monoMsNow
let completions := ImportCompletion.find text st.doc.initSnap.stx params availableImports
ctx.chanOut.sync.send <| .response id (toJson completions)
ctx.chanOut.send <| .response id (toJson completions)
pure { availableImports, lastRequestTimestampMs : AvailableImportsCache }
| some task => ServerTask.IO.mapTaskCostly (t := task) fun (result : Except Error AvailableImportsCache) => do
@@ -689,7 +687,7 @@ section MessageHandling
availableImports ImportCompletion.collectAvailableImports
lastRequestTimestampMs := timestampNowMs
let completions := ImportCompletion.find text st.doc.initSnap.stx params availableImports
ctx.chanOut.sync.send <| .response id (toJson completions)
ctx.chanOut.send <| .response id (toJson completions)
pure { availableImports, lastRequestTimestampMs : AvailableImportsCache }
def handleStatefulPreRequestSpecialCases (id : RequestID) (method : String) (params : Json) : WorkerM Bool := do
@@ -701,7 +699,7 @@ section MessageHandling
| "$/lean/rpc/connect" =>
let ps parseParams RpcConnectParams params
let resp handleRpcConnect ps
ctx.chanOut.sync.send <| .response id (toJson resp)
ctx.chanOut.send <| .response id (toJson resp)
return true
| "textDocument/completion" =>
let params parseParams CompletionParams params
@@ -714,7 +712,7 @@ section MessageHandling
| _ =>
return false
catch e =>
ctx.chanOut.sync.send <| .responseError id .internalError (toString e) none
ctx.chanOut.send <| .responseError id .internalError (toString e) none
return true
open Widget RequestM Language in
@@ -836,7 +834,7 @@ section MessageHandling
emitResponse ctx (isComplete := false) <| e.toLspResponseError id
where
emitResponse (ctx : WorkerContext) (m : JsonRpc.Message) (isComplete : Bool) : IO Unit := do
ctx.chanOut.sync.send m
ctx.chanOut.send m
let timestamp IO.monoMsNow
ctx.modifyPartialHandler method fun h => { h with
requestsInFlight := h.requestsInFlight - 1

View File

@@ -41,7 +41,6 @@ structure InfoPopup where
doc : Option String
deriving Inhabited, RpcEncodable
open PrettyPrinter.Delaborator in
/-- Given elaborator info for a particular subexpression. Produce the `InfoPopup`.
The intended usage of this is for the infoview to pass the `InfoWithCtx` which
@@ -54,10 +53,12 @@ def makePopup : WithRpcRef InfoWithCtx → RequestM (RequestTask InfoPopup)
| some type => some <$> ppExprTagged type
| none => pure none
let exprExplicit? match i.info with
| Elab.Info.ofTermInfo ti =>
some <$> ppExprForPopup ti.expr (explicit := true)
| Elab.Info.ofDelabTermInfo { toTermInfo := ti, explicit, ..} =>
some <$> ppExprForPopup ti.expr (explicit := explicit)
| Elab.Info.ofTermInfo ti
| Elab.Info.ofDelabTermInfo { toTermInfo := ti, explicit := true, ..} =>
some <$> ppExprTaggedWithoutTopLevelHighlight ti.expr (explicit := true)
| Elab.Info.ofDelabTermInfo { toTermInfo := ti, explicit := false, ..} =>
-- Keep the top-level tag so that users can also see the explicit version of the term on an additional hover.
some <$> ppExprTagged ti.expr (explicit := false)
| Elab.Info.ofFieldInfo fi => pure <| some <| TaggedText.text fi.fieldName.toString
| _ => pure none
return {
@@ -66,26 +67,11 @@ def makePopup : WithRpcRef InfoWithCtx → RequestM (RequestTask InfoPopup)
doc := i.info.docString? : InfoPopup
}
where
maybeWithoutTopLevelHighlight : Bool CodeWithInfos CodeWithInfos
| true, .tag _ tt => tt
| _, tt => tt
ppExprForPopup (e : Expr) (explicit : Bool := false) : MetaM CodeWithInfos := do
let mut e := e
-- When hovering over a metavariable, we want to see its value, even if `pp.instantiateMVars` is false.
if explicit && e.isMVar then
if let some e' getExprMVarAssignment? e.mvarId! then
e := e'
-- When `explicit` is false, keep the top-level tag so that users can also see the explicit version of the term on an additional hover.
maybeWithoutTopLevelHighlight explicit <$> ppExprTagged e do
if explicit then
withOptionAtCurrPos pp.tagAppFns.name true do
withOptionAtCurrPos pp.explicit.name true do
withOptionAtCurrPos pp.mvars.anonymous.name true do
delabApp
else
withOptionAtCurrPos pp.proofs.name true do
withOptionAtCurrPos pp.sorrySource.name true do
delab
ppExprTaggedWithoutTopLevelHighlight (e : Expr) (explicit : Bool) : MetaM CodeWithInfos := do
let pp ppExprTagged e (explicit := explicit)
return match pp with
| .tag _ tt => tt
| tt => tt
builtin_initialize
registerBuiltinRpcProcedure

View File

@@ -56,36 +56,6 @@ elab_rules : tactic
-- can't use a naked promise in `initialize` as marking it persistent would block
initialize unblockedCancelTk : IO.CancelToken IO.CancelToken.new
/--
Waits for `unblock` to be called, which is expected to happen in a subsequent document version that
does not invalidate this tactic. Complains if cancellation token was set before unblocking, i.e. if
the tactic was invalidated after all.
-/
scoped syntax "wait_for_unblock" : tactic
@[incremental]
elab_rules : tactic
| `(tactic| wait_for_unblock) => do
let ctx readThe Core.Context
let some cancelTk := ctx.cancelTk? | unreachable!
dbg_trace "blocked!"
log "blocked"
let ctx readThe Elab.Term.Context
let some tacSnap := ctx.tacSnap? | unreachable!
tacSnap.new.resolve {
diagnostics := ( Language.Snapshot.Diagnostics.ofMessageLog ( Core.getMessageLog))
stx := default
finished := default
}
while true do
if ( unblockedCancelTk.isSet) then
break
IO.sleep 30
if ( cancelTk.isSet) then
IO.eprintln "cancelled!"
log "cancelled (should never be visible)"
/--
Spawns a `logSnapshotTask` that waits for `unblock` to be called, which is expected to happen in a
subsequent document version that does not invalidate this tactic. Complains if cancellation token
@@ -113,10 +83,6 @@ scoped elab "unblock" : tactic => do
dbg_trace "unblocking!"
unblockedCancelTk.set
/--
Like `wait_for_cancel_once` but does the waiting in a separate task and waits for its
cancellation.
-/
scoped syntax "wait_for_cancel_once_async" : tactic
@[incremental]
elab_rules : tactic
@@ -144,35 +110,3 @@ elab_rules : tactic
dbg_trace "blocked!"
log "blocked"
/--
Like `wait_for_cancel_once_async` but waits for the main thread's cancellation token. This is useful
to test main thread cancellation in non-incremental contexts because we otherwise wouldn't be able
to send out the "blocked" message from there.
-/
scoped syntax "wait_for_main_cancel_once_async" : tactic
@[incremental]
elab_rules : tactic
| `(tactic| wait_for_main_cancel_once_async) => do
let prom IO.Promise.new
if let some t := ( onceRef.modifyGet (fun old => (old, old.getD prom.result!))) then
IO.wait t
return
let some cancelTk := ( readThe Core.Context).cancelTk? | unreachable!
let act Elab.Term.wrapAsyncAsSnapshot (cancelTk? := none) fun _ => do
let ctx readThe Core.Context
-- TODO: `CancelToken` should probably use `Promise`
while true do
if ( cancelTk.isSet) then
break
IO.sleep 30
IO.eprintln "cancelled!"
log "cancelled (should never be visible)"
prom.resolve ()
Core.checkInterrupted
let t BaseIO.asTask (act ())
Core.logSnapshotTask { stx? := none, task := t, cancelTk? := cancelTk }
dbg_trace "blocked!"
log "blocked"

View File

@@ -73,15 +73,24 @@ where
}
TaggedText.tag t (go subTt)
open PrettyPrinter Delaborator in
/--
Pretty prints the expression `e` using delaborator `delab`, returning an object that represents
the pretty printed syntax paired with information needed to support hovers.
-/
def ppExprTagged (e : Expr) (delab : Delab := Delaborator.delab) : MetaM CodeWithInfos := do
def ppExprTagged (e : Expr) (explicit : Bool := false) : MetaM CodeWithInfos := do
if pp.raw.get ( getOptions) then
let e if getPPInstantiateMVars ( getOptions) then instantiateMVars e else pure e
return .text (toString e)
return .text (toString ( instantiateMVars e))
let delab := open PrettyPrinter.Delaborator in
if explicit then
withOptionAtCurrPos pp.tagAppFns.name true do
withOptionAtCurrPos pp.explicit.name true do
withOptionAtCurrPos pp.mvars.anonymous.name true do
delabApp
else
withOptionAtCurrPos pp.proofs.name true do
withOptionAtCurrPos pp.sorrySource.name true do
delab
let mut e := e
-- When hovering over a metavariable, we want to see its value, even if `pp.instantiateMVars` is false.
if explicit && e.isMVar then
if let some e' getExprMVarAssignment? e.mvarId! then
e := e'
let fmt, infos PrettyPrinter.ppExprWithInfos e (delab := delab)
let tt := TaggedText.prettyTagged fmt
let ctx := {

View File

@@ -215,7 +215,7 @@ theorem modify_eq_alter [BEq α] [LawfulBEq α] {a : α} {f : β a → β a} {l
modify a f l = alter a (·.map f) l := by
induction l
· rfl
· next ih => simp only [modify, beq_iff_eq, alter, Option.map_some, ih]
· next ih => simp only [modify, beq_iff_eq, alter, Option.map_some', ih]
namespace Const
@@ -235,7 +235,7 @@ theorem modify_eq_alter [BEq α] [EquivBEq α] {a : α} {f : β → β} {l : Ass
modify a f l = alter a (·.map f) l := by
induction l
· rfl
· next ih => simp only [modify, beq_iff_eq, alter, Option.map_some, ih]
· next ih => simp only [modify, beq_iff_eq, alter, Option.map_some', ih]
end Const

View File

@@ -630,7 +630,7 @@ theorem Const.getThenInsertIfNew?_eq_insertIfNewₘ [BEq α] [Hashable α] (m :
(a : α) (b : β) : (Const.getThenInsertIfNew? m a b).2 = m.insertIfNewₘ a b := by
rw [getThenInsertIfNew?, insertIfNewₘ, containsₘ, bucket]
dsimp only [Array.ugetElem_eq_getElem, Array.uset]
split <;> simp_all [consₘ, updateBucket, List.containsKey_eq_isSome_getValue?, -Option.isSome_eq_false_iff]
split <;> simp_all [consₘ, updateBucket, List.containsKey_eq_isSome_getValue?, -Option.not_isSome]
theorem Const.getThenInsertIfNew?_eq_get?ₘ [BEq α] [Hashable α] (m : Raw₀ α (fun _ => β)) (a : α)
(b : β) : (Const.getThenInsertIfNew? m a b).1 = Const.get?ₘ m a := by

View File

@@ -964,7 +964,7 @@ theorem isHashSelf_filterMapₘ [BEq α] [Hashable α] [ReflBEq α] [LawfulHasha
IsHashSelf (m.filterMapₘ f).1.buckets := by
refine h.buckets_hash_self.updateAllBuckets (fun l p hp => ?_)
have hp := AssocList.toList_filterMap.mem_iff.1 hp
simp only [mem_filterMap, Option.map_eq_some_iff] at hp
simp only [mem_filterMap, Option.map_eq_some'] at hp
obtain p, hkv, d, -, rfl := hp
exact containsKey_of_mem hkv

View File

@@ -1031,7 +1031,7 @@ theorem ordered_filterMap [Ord α] {t : Impl α β} {h} {f : (a : α) → β a
simp only [Ordered, toListModel_filterMap]
apply ho.filterMap
intro e f hef e' he' f' hf'
simp only [Option.map_eq_some_iff] at he' hf'
simp only [Option.map_eq_some'] at he' hf'
obtain _, _, rfl := he'
obtain _, _, rfl := hf'
exact hef

View File

@@ -169,7 +169,7 @@ theorem getValue?_eq_getEntry? [BEq α] {l : List ((_ : α) × β)} {a : α} :
· next k v l ih =>
cases h : k == a
· rw [getEntry?_cons_of_false h, getValue?_cons_of_false h, ih]
· rw [getEntry?_cons_of_true h, getValue?_cons_of_true h, Option.map_some]
· rw [getEntry?_cons_of_true h, getValue?_cons_of_true h, Option.map_some']
theorem getValue?_congr [BEq α] [PartialEquivBEq α] {l : List ((_ : α) × β)} {a b : α}
(h : a == b) : getValue? a l = getValue? b l := by
@@ -338,7 +338,7 @@ theorem getEntry?_eq_none [BEq α] {l : List ((a : α) × β a)} {a : α} :
@[simp]
theorem getValue?_eq_none {β : Type v} [BEq α] {l : List ((_ : α) × β)} {a : α} :
getValue? a l = none containsKey a l = false := by
rw [getValue?_eq_getEntry?, Option.map_eq_none_iff, getEntry?_eq_none]
rw [getValue?_eq_getEntry?, Option.map_eq_none', getEntry?_eq_none]
theorem containsKey_eq_isSome_getValue? {β : Type v} [BEq α] {l : List ((_ : α) × β)} {a : α} :
containsKey a l = (getValue? a l).isSome := by
@@ -613,7 +613,7 @@ theorem getKey?_eq_getEntry? [BEq α] {l : List ((a : α) × β a)} {a : α} :
· next k v l ih =>
cases h : k == a
· rw [getEntry?_cons_of_false h, getKey?_cons_of_false h, ih]
· rw [getEntry?_cons_of_true h, getKey?_cons_of_true h, Option.map_some]
· rw [getEntry?_cons_of_true h, getKey?_cons_of_true h, Option.map_some']
theorem fst_mem_keys_of_mem [BEq α] [EquivBEq α] {a : (a : α) × β a} {l : List ((a : α) × β a)}
(hm : a l) : a.1 keys l :=
@@ -1515,7 +1515,7 @@ theorem containsKey_insertEntryIfNew [BEq α] [PartialEquivBEq α] {l : List ((a
cases h : k == a
· simp
· rw [containsKey_eq_isSome_getEntry?, getEntry?_congr h]
simp [-Option.not_isSome]
simp
theorem containsKey_insertEntryIfNew_self [BEq α] [EquivBEq α] {l : List ((a : α) × β a)} {k : α}
{v : β k} : containsKey k (insertEntryIfNew k v l) := by
@@ -2098,7 +2098,7 @@ theorem containsKey_append_of_not_contains_right [BEq α] {l l' : List ((a : α)
@[simp]
theorem getValue?_append {β : Type v} [BEq α] {l l' : List ((_ : α) × β)} {a : α} :
getValue? a (l ++ l') = (getValue? a l).or (getValue? a l') := by
simp [getValue?_eq_getEntry?, Option.map_or]
simp [getValue?_eq_getEntry?, Option.map_or']
theorem getValue?_append_of_containsKey_eq_false {β : Type v} [BEq α] {l l' : List ((_ : α) × β)}
{a : α} (h : containsKey a l' = false) : getValue? a (l ++ l') = getValue? a l := by
@@ -2233,7 +2233,7 @@ theorem mem_map_toProd_iff_mem {β : Type v} {k : α} {v : β} {l : List ((_ :
theorem mem_iff_getValue?_eq_some [BEq α] [LawfulBEq α] {β : Type v} {k : α} {v : β}
{l : List ((_ : α) × β)} (h : DistinctKeys l) :
k, v l getValue? k l = some v := by
simp only [mem_iff_getEntry?_eq_some h, getValue?_eq_getEntry?, Option.map_eq_some_iff]
simp only [mem_iff_getEntry?_eq_some h, getValue?_eq_getEntry?, Option.map_eq_some']
constructor
· intro h
exists k, v
@@ -2256,7 +2256,7 @@ theorem find?_map_toProd_eq_some_iff_getKey?_eq_some_and_getValue?_eq_some [BEq
| nil => simp
| cons hd tl ih =>
simp only [List.map_cons, List.find?_cons_eq_some, Prod.mk.injEq, Bool.not_eq_eq_eq_not,
Bool.not_true, Option.map_eq_some_iff, getKey?, cond_eq_if, getValue?]
Bool.not_true, Option.map_eq_some', getKey?, cond_eq_if, getValue?]
by_cases hdfst_k: hd.fst == k
· simp only [hdfst_k, true_and, Bool.true_eq_false, false_and, or_false, reduceIte,
Option.some.injEq]
@@ -2271,7 +2271,7 @@ theorem mem_iff_getKey?_eq_some_and_getValue?_eq_some [BEq α] [EquivBEq α]
theorem getValue?_eq_some_iff_exists_beq_and_mem_toList {β : Type v} [BEq α] [EquivBEq α]
{l : List ((_ : α) × β)} {k: α} {v : β} (h : DistinctKeys l) :
getValue? k l = some v k', (k == k') = true (k', v) l.map (fun x => (x.fst, x.snd)) := by
simp only [getValue?_eq_getEntry?, Option.map_eq_some_iff, mem_map_toProd_iff_mem,
simp only [getValue?_eq_getEntry?, Option.map_eq_some', mem_map_toProd_iff_mem,
mem_iff_getEntry?_eq_some h]
constructor
· intro h'
@@ -2635,7 +2635,7 @@ theorem getKey?_insertList_of_mem [BEq α] [EquivBEq α]
rcases List.mem_map.1 mem with k, v, pair_mem, rfl
rw [getKey?_eq_getEntry?, getEntry?_insertList distinct_l distinct_toInsert,
getEntry?_of_mem (DistinctKeys.def.2 distinct_toInsert) k_beq pair_mem, Option.some_or,
Option.map_some]
Option.map_some']
theorem getKey_insertList_of_contains_eq_false [BEq α] [EquivBEq α]
{l toInsert : List ((a : α) × β a)} {k : α}
@@ -3074,7 +3074,7 @@ theorem getKey?_insertListIfNewUnit_of_contains_eq_false_of_contains_eq_false [B
(h': containsKey k l = false) (h : toInsert.contains k = false) :
getKey? k (insertListIfNewUnit l toInsert) = none := by
rw [getKey?_eq_getEntry?,
getEntry?_insertListIfNewUnit_of_contains_eq_false h, Option.map_eq_none_iff, getEntry?_eq_none]
getEntry?_insertListIfNewUnit_of_contains_eq_false h, Option.map_eq_none', getEntry?_eq_none]
exact h'
theorem getKey?_insertListIfNewUnit_of_contains_eq_false_of_mem [BEq α] [EquivBEq α]
@@ -3083,8 +3083,8 @@ theorem getKey?_insertListIfNewUnit_of_contains_eq_false_of_mem [BEq α] [EquivB
(mem' : containsKey k l = false)
(distinct : toInsert.Pairwise (fun a b => (a == b) = false)) (mem : k toInsert) :
getKey? k' (insertListIfNewUnit l toInsert) = some k := by
simp only [getKey?_eq_getEntry?, getEntry?_insertListIfNewUnit, Option.map_eq_some_iff,
Option.or_eq_some_iff, getEntry?_eq_none]
simp only [getKey?_eq_getEntry?, getEntry?_insertListIfNewUnit, Option.map_eq_some',
Option.or_eq_some, getEntry?_eq_none]
exists k, ()
simp only [and_true]
right
@@ -4301,8 +4301,8 @@ theorem minEntry?_eq_some_iff [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd α]
theorem minKey?_eq_some_iff_getKey?_eq_self_and_forall [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd α]
{k} {l : List ((a : α) × β a)} (hd : DistinctKeys l) :
minKey? l = some k getKey? k l = some k k' : α, containsKey k' l (compare k k').isLE := by
simp only [minKey?, Option.map_eq_some_iff, minEntry?_eq_some_iff _ hd]
simp only [getKey?_eq_getEntry?, Option.map_eq_some_iff, getEntry?_eq_some_iff hd]
simp only [minKey?, Option.map_eq_some', minEntry?_eq_some_iff _ hd]
simp only [getKey?_eq_getEntry?, Option.map_eq_some', getEntry?_eq_some_iff hd]
apply Iff.intro
· rintro _, hm, hcmp, rfl
exact _, BEq.refl, hm, rfl, hcmp
@@ -4312,7 +4312,7 @@ theorem minKey?_eq_some_iff_getKey?_eq_self_and_forall [Ord α] [TransOrd α] [B
theorem minKey?_eq_some_iff_mem_and_forall [Ord α] [LawfulEqOrd α] [TransOrd α] [BEq α]
[LawfulBEqOrd α] {k} {l : List ((a : α) × β a)} (hd : DistinctKeys l) :
minKey? l = some k containsKey k l k' : α, containsKey k' l (compare k k').isLE := by
simp only [minKey?, Option.map_eq_some_iff, minEntry?_eq_some_iff _ hd]
simp only [minKey?, Option.map_eq_some', minEntry?_eq_some_iff _ hd]
apply Iff.intro
· rintro _, hm, hcmp, rfl
exact containsKey_of_mem hm, hcmp
@@ -4361,7 +4361,7 @@ theorem isNone_minKey?_eq_isEmpty [Ord α] {l : List ((a : α) × β a)} :
theorem isSome_minEntry?_eq_not_isEmpty [Ord α] {l : List ((a : α) × β a)} :
(minEntry? l).isSome = !l.isEmpty := by
rw [ Bool.not_inj_iff, Bool.not_not, Bool.eq_iff_iff, Bool.not_eq_true', Option.isSome_eq_false_iff,
rw [ Bool.not_inj_iff, Bool.not_not, Bool.eq_iff_iff, Bool.not_eq_true', Option.not_isSome,
Option.isNone_iff_eq_none]
apply minEntry?_eq_none_iff_isEmpty
@@ -4384,7 +4384,7 @@ theorem minEntry?_map [Ord α] (l : List ((a : α) × β a)) (f : (a : α) × β
simp only [minEntry?, List.min?]
cases l <;> try rfl
rename_i e es
simp only [List.map_cons, Option.map_some, Option.some.injEq]
simp only [List.map_cons, Option.map_some', Option.some.injEq]
rw [ List.foldr_reverse, List.foldr_reverse, List.map_reverse]
induction es.reverse with
| nil => rfl
@@ -4442,7 +4442,7 @@ theorem minEntry?_insertEntry [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd α]
cases h : containsKey k l
· simp only [cond_false, minEntry?_cons, Option.some.injEq]
rfl
· rw [cond_true, minEntry?_replaceEntry hl, Option.map_eq_some_iff]
· rw [cond_true, minEntry?_replaceEntry hl, Option.map_eq_some']
have := isSome_minEntry?_of_contains _
simp only [Option.isSome_iff_exists] at this
obtain a, ha := this
@@ -4538,7 +4538,7 @@ theorem minKey?_bind_getKey? [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd α]
theorem containsKey_minKey? [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd α] {l : List ((a : α) × β a)}
(hd : DistinctKeys l) {km} (hkm : minKey? l = some km) :
containsKey km l := by
simp only [minKey?, Option.map_eq_some_iff, minEntry?_eq_some_iff _ hd] at hkm
simp only [minKey?, Option.map_eq_some', minEntry?_eq_some_iff _ hd] at hkm
obtain e, hm, _, rfl := hkm
exact containsKey_of_mem hm
@@ -4710,7 +4710,7 @@ theorem minKey?_alterKey_eq_self [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd
minKey? (alterKey k f l) = some k
(f (getValueCast? k l)).isSome k', containsKey k' l (compare k k').isLE := by
simp only [minKey?_eq_some_iff_getKey?_eq_self_and_forall hd.alterKey, getKey?_alterKey _ hd,
beq_self_eq_true, reduceIte, ite_eq_left_iff, Bool.not_eq_true, Option.isSome_eq_false_iff,
beq_self_eq_true, reduceIte, ite_eq_left_iff, Bool.not_eq_true, Option.not_isSome,
Option.isNone_iff_eq_none, reduceCtorEq, imp_false, Option.isSome_iff_ne_none,
containsKey_alterKey hd, beq_iff_eq, Bool.ite_eq_true_distrib, and_congr_right_iff]
intro hf
@@ -4753,18 +4753,18 @@ theorem minKey?_modifyKey_eq_minKey? [Ord α] [TransOrd α] [BEq α] [LawfulBEqO
simp only [minKey?_modifyKey hd]
cases minKey? l
· rfl
· simp only [beq_iff_eq, Option.map_some, Option.some.injEq, ite_eq_right_iff]
· simp only [beq_iff_eq, Option.map_some', Option.some.injEq, ite_eq_right_iff]
exact Eq.symm
theorem isSome_minKey?_modifyKey [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd α] {k f}
{l : List ((_ : α) × β)} :
(modifyKey k f l |> minKey?).isSome = !l.isEmpty := by
simp [Option.isSome_map, isSome_minKey?_eq_not_isEmpty, isEmpty_modifyKey]
simp [Option.isSome_map', isSome_minKey?_eq_not_isEmpty, isEmpty_modifyKey]
theorem isSome_minKey?_modifyKey_eq_isSome [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd α] {k f}
{l : List ((_ : α) × β)} :
(modifyKey k f l |> minKey?).isSome = (minKey? l).isSome := by
simp [Option.isSome_map, isSome_minKey?_eq_not_isEmpty, isEmpty_modifyKey]
simp [Option.isSome_map', isSome_minKey?_eq_not_isEmpty, isEmpty_modifyKey]
theorem minKey?_modifyKey_beq [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd α] {k f km kmm}
{l : List ((_ : α) × β)} (hd : DistinctKeys l) (hkm : minKey? l = some km)
@@ -4784,7 +4784,7 @@ theorem minKey?_alterKey_eq_self [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd
(f (getValue? k l)).isSome k', containsKey k' l (compare k k').isLE := by
simp only [minKey?_eq_some_iff_getKey?_eq_self_and_forall hd.constAlterKey, getKey?_alterKey _ hd,
compare_eq_iff_beq, compare_self, reduceIte, ite_eq_left_iff, Bool.not_eq_true,
Option.isSome_eq_false_iff, Option.isNone_iff_eq_none, reduceCtorEq, imp_false,
Option.not_isSome, Option.isNone_iff_eq_none, reduceCtorEq, imp_false,
Option.isSome_iff_ne_none, containsKey_alterKey hd, Bool.ite_eq_true_distrib,
and_congr_right_iff]
intro hf

View File

@@ -1,720 +1,137 @@
/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Henrik Böving
Authors: Gabriel Ebner
-/
prelude
import Init.System.Promise
import Init.Data.Queue
import Std.Sync.Mutex
/-!
This module contains the implementation of `Std.Channel`. `Std.Channel` is a multi-producer
multi-consumer FIFO channel that offers both bounded and unbounded buffering as well as synchronous
and asynchronous APIs.
Additionally `Std.CloseableChannel` is provided in case closing the channel is of interest.
The two are distinct as the non closable `Std.Channel` can never throw errors which makes
for cleaner code.
-/
namespace Std
namespace CloseableChannel
/--
Internal state of an `Channel`.
We maintain the invariant that at all times either `consumers` or `values` is empty.
-/
structure Channel.State (α : Type) where
values : Std.Queue α :=
consumers : Std.Queue (IO.Promise (Option α)) :=
closed := false
deriving Inhabited
/--
Errors that may be thrown while interacting with the channel API.
FIFO channel with unbounded buffer, where `recv?` returns a `Task`.
A channel can be closed. Once it is closed, all `send`s are ignored, and
`recv?` returns `none` once the queue is empty.
-/
inductive Error where
/--
Tried to send to a closed channel.
-/
| closed
/--
Tried to close an already closed channel.
-/
| alreadyClosed
deriving Repr, DecidableEq, Hashable
def Channel (α : Type) : Type := Mutex (Channel.State α)
instance : ToString Error where
toString
| .closed => "trying to send on an already closed channel"
| .alreadyClosed => "trying to close an already closed channel"
instance : Nonempty (Channel α) :=
inferInstanceAs (Nonempty (Mutex _))
instance : MonadLift (EIO Error) IO where
monadLift x := EIO.toIO (.userError <| toString ·) x
/-- Creates a new `Channel`. -/
def Channel.new : BaseIO (Channel α) :=
Mutex.new {}
/--
The central state structure for an unbounded channel, maintains the following invariants:
1. `values = ∅ consumers = ∅`
2. `closed = true → consumers = ∅`
Sends a message on an `Channel`.
This function does not block.
-/
private structure Unbounded.State (α : Type) where
/--
Values pushed into the channel that are waiting to be consumed.
-/
values : Std.Queue α
/--
Consumers that are blocked on a producer providing them a value. The `IO.Promise` will be
resolved to `none` if the channel closes.
-/
consumers : Std.Queue (IO.Promise (Option α))
/--
Whether the channel is closed already.
-/
closed : Bool
deriving Nonempty
private structure Unbounded (α : Type) where
state : Mutex (Unbounded.State α)
deriving Nonempty
namespace Unbounded
private def new : BaseIO (Unbounded α) := do
return {
state := Mutex.new {
values :=
consumers :=
closed := false
}
}
private def trySend (ch : Unbounded α) (v : α) : BaseIO Bool := do
ch.state.atomically do
def Channel.send (ch : Channel α) (v : α) : BaseIO Unit :=
ch.atomically do
let st get
if st.closed then
return false
else if let some (consumer, consumers) := st.consumers.dequeue? then
if st.closed then return
if let some (consumer, consumers) := st.consumers.dequeue? then
consumer.resolve (some v)
set { st with consumers }
return true
else
set { st with values := st.values.enqueue v }
return true
private def send (ch : Unbounded α) (v : α) : BaseIO (Task (Except Error Unit)) := do
if Unbounded.trySend ch v then
return .pure <| .ok ()
else
return .pure <| .error .closed
private def close (ch : Unbounded α) : EIO Error Unit := do
ch.state.atomically do
let st get
if st.closed then throw .alreadyClosed
for consumer in st.consumers.toArray do consumer.resolve none
set { st with consumers := , closed := true }
return ()
private def isClosed (ch : Unbounded α) : BaseIO Bool :=
ch.state.atomically do
return ( get).closed
private def tryRecv' : AtomicT (Unbounded.State α) BaseIO (Option α) := do
let st get
if let some (a, values) := st.values.dequeue? then
set { st with values }
return some a
else
return none
private def tryRecv (ch : Unbounded α) : BaseIO (Option α) :=
ch.state.atomically do
tryRecv'
private def recv (ch : Unbounded α) : BaseIO (Task (Option α)) := do
ch.state.atomically do
if let some val tryRecv' then
return .pure <| some val
else if ( get).closed then
return .pure none
else
let promise IO.Promise.new
modify fun st => { st with consumers := st.consumers.enqueue promise }
return promise.result?.map (sync := true) (·.bind id)
end Unbounded
/--
The central state structure for a zero buffer channel, maintains the following invariants:
1. `producers = ∅ consumers = ∅`
2. `closed = true → consumers = ∅`
Closes an `Channel`.
-/
private structure Zero.State (α : Type) where
/--
Producers that are blocked on a consumer taking their value.
-/
producers : Std.Queue (α × IO.Promise Bool)
/--
Consumers that are blocked on a producer providing them a value. The `IO.Promise` will be resolved
to `none` if the channel closes.
-/
consumers : Std.Queue (IO.Promise (Option α))
/--
Whether the channel is closed already.
-/
closed : Bool
private structure Zero (α : Type) where
state : Mutex (Zero.State α)
namespace Zero
private def new : BaseIO (Zero α) := do
return {
state := Mutex.new {
producers :=
consumers :=
closed := false
}
}
def Channel.close (ch : Channel α) : BaseIO Unit :=
ch.atomically do
let st get
for consumer in st.consumers.toArray do consumer.resolve none
set { st with closed := true, consumers := }
/--
Precondition: The channel must not be closed.
Receives a message, without blocking.
The returned task waits for the message.
Every message is only received once.
Returns `none` if the channel is closed and the queue is empty.
-/
private def trySend' (v : α) : AtomicT (Zero.State α) BaseIO Bool := do
let st get
if let some (consumer, consumers) := st.consumers.dequeue? then
consumer.resolve (some v)
set { st with consumers }
return true
else
return false
private def trySend (ch : Zero α) (v : α) : BaseIO Bool := do
ch.state.atomically do
if ( get).closed then
return false
else
trySend' v
private def send (ch : Zero α) (v : α) : BaseIO (Task (Except Error Unit)) := do
ch.state.atomically do
if ( get).closed then
return .pure <| .error .closed
else if trySend' v then
return .pure <| .ok ()
else
let promise IO.Promise.new
modify fun st => { st with producers := st.producers.enqueue (v, promise) }
return promise.result?.map (sync := true)
fun
| none | some false => .error .closed
| some true => .ok ()
private def close (ch : Zero α) : EIO Error Unit := do
ch.state.atomically do
def Channel.recv? (ch : Channel α) : BaseIO (Task (Option α)) :=
ch.atomically do
let st get
if st.closed then throw .alreadyClosed
for consumer in st.consumers.toArray do consumer.resolve none
set { st with consumers := , closed := true }
return ()
private def isClosed (ch : Zero α) : BaseIO Bool :=
ch.state.atomically do
return ( get).closed
private def tryRecv' : AtomicT (Zero.State α) BaseIO (Option α) := do
let st get
if let some ((val, promise), producers) := st.producers.dequeue? then
set { st with producers }
promise.resolve true
return some val
else
return none
private def tryRecv (ch : Zero α) : BaseIO (Option α) := do
ch.state.atomically do
tryRecv'
private def recv (ch : Zero α) : BaseIO (Task (Option α)) := do
ch.state.atomically do
let st get
if let some val tryRecv' then
return .pure <| some val
if let some (a, values) := st.values.dequeue? then
set { st with values }
return .pure a
else if !st.closed then
let promise IO.Promise.new
set { st with consumers := st.consumers.enqueue promise }
return promise.result?.map (sync := true) (·.bind id)
else
return .pure <| none
end Zero
/--
The central state structure for a bounded channel, maintains the following invariants:
1. `0 < capacity`
2. `0 < bufCount → consumers = ∅`
3. `bufCount < capacity → producers = ∅`
4. `producers = ∅ consumers = ∅`, implied by 1, 2 and 3.
5. `bufCount` corresponds to the amount of slots in `buf` that are `some`.
6. `sendIdx = (recvIdx + bufCount) % capacity`. However all four of these values still get tracked
as there is potential to make a non-blocking send lock-free in the future with this approach.
7. `closed = true → consumers = ∅`
While it (currently) lacks the partial lock-freeness of go channels, the protocol is based on
[Go channels on steroids](https://docs.google.com/document/d/1yIAYmbvL3JxOKOjuCyon7JhW4cSv1wy5hC0ApeGMV9s/pub)
as well as its [implementation](https://go.dev/src/runtime/chan.go).
-/
private structure Bounded.State (α : Type) where
/--
Producers that are blocked on a consumer taking their value as there was no buffer space
available when they tried to enqueue.
-/
producers : Std.Queue (IO.Promise Bool)
/--
Consumers that are blocked on a producer providing them a value, as there was no value
enqueued when they tried to dequeue. The `IO.Promise` will be resolved to `false` if the channel
closes.
-/
consumers : Std.Queue (IO.Promise Bool)
/--
The capacity of the buffer space.
-/
capacity : Nat
/--
The buffer space for the channel, slots with `some v` contain a value that is waiting for
consumption, the slots with `none` are free for enqueueing.
Note that this is a `Vector` of `IO.Ref (Option α)` as the `buf` itself is shared across threads
and would thus keep getting copied if it was a `Vector (Option α)` instead.
-/
buf : Vector (IO.Ref (Option α)) capacity
/--
How many slots in `buf` are currently used, this is used to disambiguate between an empty and a
full buffer without sacrificing a slot for indicating that.
-/
bufCount : Nat
/--
The slot in `buf` that the next send will happen to.
-/
sendIdx : Nat
hsend : sendIdx < capacity
/--
The slot in `buf` that the next receive will happen from.
-/
recvIdx : Nat
hrecv : recvIdx < capacity
/--
Whether the channel is closed already.
-/
closed : Bool
private structure Bounded (α : Type) where
state : Mutex (Bounded.State α)
namespace Bounded
private def new (capacity : Nat) (hcap : 0 < capacity) : BaseIO (Bounded α) := do
return {
state := Mutex.new {
producers :=
consumers :=
capacity := capacity
buf := Vector.range capacity |>.mapM (fun _ => IO.mkRef none)
bufCount := 0
sendIdx := 0
hsend := hcap
recvIdx := 0
hrecv := hcap
closed := false
}
}
@[inline]
private def incMod (idx : Nat) (cap : Nat) : Nat :=
if idx + 1 = cap then
0
else
idx + 1
private theorem incMod_lt {idx cap : Nat} (h : idx < cap) : incMod idx cap < cap := by
unfold incMod
split <;> omega
/--
Precondition: The channel must not be closed.
-/
private def trySend' (v : α) : AtomicT (Bounded.State α) BaseIO Bool := do
let mut st get
if st.bufCount = st.capacity then
return false
else
st.buf[st.sendIdx]'st.hsend |>.set (some v)
st := { st with
bufCount := st.bufCount + 1
sendIdx := incMod st.sendIdx st.capacity
hsend := incMod_lt st.hsend
}
if let some (consumer, consumers) := st.consumers.dequeue? then
consumer.resolve true
st := { st with consumers }
set st
return true
private def trySend (ch : Bounded α) (v : α) : BaseIO Bool := do
ch.state.atomically do
if ( get).closed then
return false
else
trySend' v
private partial def send (ch : Bounded α) (v : α) : BaseIO (Task (Except Error Unit)) := do
ch.state.atomically do
if ( get).closed then
return .pure <| .error .closed
else if trySend' v then
return .pure <| .ok ()
else
let promise IO.Promise.new
modify fun st => { st with producers := st.producers.enqueue promise }
BaseIO.bindTask promise.result? fun res => do
if res.getD false then
Bounded.send ch v
else
return .pure <| .error .closed
private def close (ch : Bounded α) : EIO Error Unit := do
ch.state.atomically do
let st get
if st.closed then throw .alreadyClosed
for consumer in st.consumers.toArray do consumer.resolve false
set { st with consumers := , closed := true }
return ()
private def isClosed (ch : Bounded α) : BaseIO Bool :=
ch.state.atomically do
return ( get).closed
private def tryRecv' : AtomicT (Bounded.State α) BaseIO (Option α) := do
let st get
if st.bufCount == 0 then
return none
else
let val st.buf[st.recvIdx]'st.hrecv |>.swap none
let nextRecvIdx := incMod st.recvIdx st.capacity
set { st with
bufCount := st.bufCount - 1
recvIdx := nextRecvIdx,
hrecv := incMod_lt st.hrecv
}
return val
private def tryRecv (ch : Bounded α) : BaseIO (Option α) :=
ch.state.atomically do
tryRecv'
private partial def recv (ch : Bounded α) : BaseIO (Task (Option α)) := do
ch.state.atomically do
if let some val tryRecv' then
let st get
if let some (producer, producers) := ( get).producers.dequeue? then
producer.resolve true
set { st with producers }
return .pure <| some val
else if ( get).closed then
return .pure none
else
let promise IO.Promise.new
modify fun st => { st with consumers := st.consumers.enqueue promise }
BaseIO.bindTask promise.result? fun res => do
if res.getD false then
Bounded.recv ch
else
return .pure none
end Bounded
/--
This type represents all flavors of channels that we have available.
`ch.forAsync f` calls `f` for every messages received on `ch`.
Note that if this function is called twice, each `forAsync` only gets half the messages.
-/
private inductive Flavors (α : Type) where
| unbounded (ch : Unbounded α)
| zero (ch : Zero α)
| bounded (ch : Bounded α)
deriving Nonempty
end CloseableChannel
/--
A multi-producer multi-consumer FIFO channel that offers both bounded and unbounded buffering
and an asynchronous API, to switch into synchronous mode use `CloseableChannel.sync`.
Additionally `Std.CloseableChannel` can be closed if necessary, unlike `Std.Channel`.
This introduces a need for error handling in some cases, thus it is usually easier to use
`Std.Channel` if applicable.
-/
def CloseableChannel (α : Type) : Type := CloseableChannel.Flavors α
/--
A multi-producer multi-consumer FIFO channel that offers both bounded and unbounded buffering
and a synchronous API. This type acts as a convenient layer to use a channel in a blocking fashion
and is not actually different from the original channel.
Additionally `Std.CloseableChannel.Sync` can be closed if necessary, unlike `Std.Channel.Sync`.
This introduces the need to handle errors in some cases, thus it is usually easier to use
`Std.Channel` if applicable.
-/
def CloseableChannel.Sync (α : Type) : Type := CloseableChannel α
instance : Nonempty (CloseableChannel α) :=
inferInstanceAs (Nonempty (CloseableChannel.Flavors α))
instance : Nonempty (CloseableChannel.Sync α) :=
inferInstanceAs (Nonempty (CloseableChannel α))
namespace CloseableChannel
/--
Create a new channel, if:
- `capacity` is `none` it will be unbounded (the default)
- `capacity` is `some 0` it will always force a rendezvous between sender and receiver
- `capacity` is `some n` with `n > 0` it will use a buffer of size `n` and begin blocking once it
is filled
-/
def new (capacity : Option Nat := none) : BaseIO (CloseableChannel α) := do
match capacity with
| none => return .unbounded ( CloseableChannel.Unbounded.new)
| some 0 => return .zero ( CloseableChannel.Zero.new)
| some (n + 1) => return .bounded ( CloseableChannel.Bounded.new (n + 1) (by omega))
/--
Try to send a value to the channel, if this can be completed right away without blocking return
`true`, otherwise don't send the value and return `false`.
-/
def trySend (ch : CloseableChannel α) (v : α) : BaseIO Bool :=
match ch with
| .unbounded ch => CloseableChannel.Unbounded.trySend ch v
| .zero ch => CloseableChannel.Zero.trySend ch v
| .bounded ch => CloseableChannel.Bounded.trySend ch v
/--
Send a value through the channel, returning a task that will resolve once the transmission could be
completed. Note that the task may resolve to `Except.error` if the channel was closed before it
could be completed.
-/
def send (ch : CloseableChannel α) (v : α) : BaseIO (Task (Except Error Unit)) :=
match ch with
| .unbounded ch => CloseableChannel.Unbounded.send ch v
| .zero ch => CloseableChannel.Zero.send ch v
| .bounded ch => CloseableChannel.Bounded.send ch v
/--
Closes the channel, returns `Except.ok` when called the first time, otherwise `Except.error`.
When a channel is closed:
- no new values can be sent successfully anymore
- all blocked consumers are resolved to `none` (as no new messages can be sent they will never
resolve)
- if there are already values waiting to be received they can still be received by subsequent `recv`
calls
-/
def close (ch : CloseableChannel α) : EIO Error Unit :=
match ch with
| .unbounded ch => CloseableChannel.Unbounded.close ch
| .zero ch => CloseableChannel.Zero.close ch
| .bounded ch => CloseableChannel.Bounded.close ch
/--
Return `true` if the channel is closed.
-/
def isClosed (ch : CloseableChannel α) : BaseIO Bool :=
match ch with
| .unbounded ch => CloseableChannel.Unbounded.isClosed ch
| .zero ch => CloseableChannel.Zero.isClosed ch
| .bounded ch => CloseableChannel.Bounded.isClosed ch
/--
Try to receive a value from the channel, if this can be completed right away without blocking return
`some value`, otherwise return `none`.
-/
def tryRecv (ch : CloseableChannel α) : BaseIO (Option α) :=
match ch with
| .unbounded ch => CloseableChannel.Unbounded.tryRecv ch
| .zero ch => CloseableChannel.Zero.tryRecv ch
| .bounded ch => CloseableChannel.Bounded.tryRecv ch
/--
Receive a value from the channel, returning a task that will resolve once the transmission could be
completed. Note that the task may resolve to `none` if the channel was closed before it could be
completed.
-/
def recv (ch : CloseableChannel α) : BaseIO (Task (Option α)) :=
match ch with
| .unbounded ch => CloseableChannel.Unbounded.recv ch
| .zero ch => CloseableChannel.Zero.recv ch
| .bounded ch => CloseableChannel.Bounded.recv ch
/--
`ch.forAsync f` calls `f` for every message received on `ch`.
Note that if this function is called twice, each message will only arrive at exactly one invocation.
-/
partial def forAsync (f : α BaseIO Unit) (ch : CloseableChannel α)
partial def Channel.forAsync (f : α BaseIO Unit) (ch : Channel α)
(prio : Task.Priority := .default) : BaseIO (Task Unit) := do
BaseIO.bindTask (prio := prio) ( ch.recv) fun
BaseIO.bindTask (prio := prio) ( ch.recv?) fun
| none => return .pure ()
| some v => do f v; ch.forAsync f prio
/--
This function is a no-op and just a convenient way to expose the synchronous API of the channel.
Receives all currently queued messages from the channel.
Those messages are dequeued and will not be returned by `recv?`.
-/
@[inline]
def sync (ch : CloseableChannel α) : CloseableChannel.Sync α := ch
def Channel.recvAllCurrent (ch : Channel α) : BaseIO (Array α) :=
ch.atomically do
modifyGet fun st => (st.values.toArray, { st with values := })
namespace Sync
@[inherit_doc CloseableChannel.new, inline]
def new (capacity : Option Nat := none) : BaseIO (Sync α) := CloseableChannel.new capacity
@[inherit_doc CloseableChannel.trySend, inline]
def trySend (ch : Sync α) (v : α) : BaseIO Bool := CloseableChannel.trySend ch v
/-- Type tag for synchronous (blocking) operations on a `Channel`. -/
def Channel.Sync := Channel
/--
Send a value through the channel, blocking until the transmission could be completed. Note that this
function may throw an error when trying to send to an already closed channel.
Accesses synchronous (blocking) version of channel operations.
For example, `ch.sync.recv?` blocks until the next message,
and `for msg in ch.sync do ...` iterates synchronously over the channel.
These functions should only be used in dedicated threads.
-/
def send (ch : Sync α) (v : α) : EIO Error Unit := do
EIO.ofExcept ( IO.wait ( CloseableChannel.send ch v))
@[inherit_doc CloseableChannel.close, inline]
def close (ch : Sync α) : EIO Error Unit := CloseableChannel.close ch
@[inherit_doc CloseableChannel.isClosed, inline]
def isClosed (ch : Sync α) : BaseIO Bool := CloseableChannel.isClosed ch
@[inherit_doc CloseableChannel.tryRecv, inline]
def tryRecv (ch : Sync α) : BaseIO (Option α) := CloseableChannel.tryRecv ch
def Channel.sync (ch : Channel α) : Channel.Sync α := ch
/--
Receive a value from the channel, blocking unitl the transmission could be completed. Note that the
return value may be `none` if the channel was closed before it could be completed.
-/
def recv (ch : Sync α) : BaseIO (Option α) := do
IO.wait ( CloseableChannel.recv ch)
Synchronously receives a message from the channel.
private partial def forIn [Monad m] [MonadLiftT BaseIO m]
(ch : Sync α) (f : α β m (ForInStep β)) : β m β := fun b => do
match ch.recv with
| some a =>
match f a b with
| .done b => pure b
| .yield b => ch.forIn f b
| none => pure b
Every message is only received once.
Returns `none` if the channel is closed and the queue is empty.
-/
def Channel.Sync.recv? (ch : Channel.Sync α) : BaseIO (Option α) := do
IO.wait ( Channel.recv? ch)
private partial def Channel.Sync.forIn [Monad m] [MonadLiftT BaseIO m]
(ch : Channel.Sync α) (f : α β m (ForInStep β)) : β m β := fun b => do
match ch.recv? with
| some a =>
match f a b with
| .done b => pure b
| .yield b => ch.forIn f b
| none => pure b
/-- `for msg in ch.sync do ...` receives all messages in the channel until it is closed. -/
instance [MonadLiftT BaseIO m] : ForIn m (Sync α) α where
instance [MonadLiftT BaseIO m] : ForIn m (Channel.Sync α) α where
forIn ch b f := ch.forIn f b
end Sync
end CloseableChannel
/--
A multi-producer multi-consumer FIFO channel that offers both bounded and unbounded buffering
and an asynchronous API, to switch into synchronous mode use `Channel.sync`.
If a channel needs to be closed to indicate some sort of completion event use `Std.CloseableChannel`
instead. Note that `Std.CloseableChannel` introduces a need for error handling in some cases, thus
`Std.Channel` is usually easier to use if applicable.
-/
structure Channel (α : Type) where
private mk ::
private inner : CloseableChannel α
deriving Nonempty
/--
A multi-producer multi-consumer FIFO channel that offers both bounded and unbounded buffering
and a synchronous API. This type acts as a convenient layer to use a channel in a blocking fashion
and is not actually different from the original channel.
If a channel needs to be closed to indicate some sort of completion event use
`Std.CloseableChannel.Sync` instead. Note that `Std.CloseableChannel.Sync` introduces a need for error
handling in some cases, thus `Std.Channel.Sync` is usually easier to use if applicable.
-/
def Channel.Sync (α : Type) : Type := Channel α
instance : Nonempty (Channel.Sync α) :=
inferInstanceAs (Nonempty (Channel α))
namespace Channel
@[inherit_doc CloseableChannel.new, inline]
def new (capacity : Option Nat := none) : BaseIO (Channel α) := do
return CloseableChannel.new capacity
@[inherit_doc CloseableChannel.trySend, inline]
def trySend (ch : Channel α) (v : α) : BaseIO Bool :=
CloseableChannel.trySend ch.inner v
/--
Send a value through the channel, returning a task that will resolve once the transmission could be
completed.
-/
def send (ch : Channel α) (v : α) : BaseIO (Task Unit) := do
BaseIO.bindTask (sync := true) ( CloseableChannel.send ch.inner v)
fun
| .ok .. => return .pure ()
| .error .. => unreachable!
@[inherit_doc CloseableChannel.tryRecv, inline]
def tryRecv (ch : Channel α) : BaseIO (Option α) :=
CloseableChannel.tryRecv ch.inner
@[inherit_doc CloseableChannel.recv]
def recv [Inhabited α] (ch : Channel α) : BaseIO (Task α) := do
BaseIO.bindTask (sync := true) ( CloseableChannel.recv ch.inner)
fun
| some val => return .pure val
| none => unreachable!
@[inherit_doc CloseableChannel.forAsync]
partial def forAsync [Inhabited α] (f : α BaseIO Unit) (ch : Channel α)
(prio : Task.Priority := .default) : BaseIO (Task Unit) := do
BaseIO.bindTask (prio := prio) ( ch.recv) fun v => do f v; ch.forAsync f prio
@[inherit_doc CloseableChannel.sync, inline]
def sync (ch : Channel α) : Channel.Sync α := ch
namespace Sync
@[inherit_doc Channel.new, inline]
def new (capacity : Option Nat := none) : BaseIO (Sync α) := Channel.new capacity
@[inherit_doc Channel.trySend, inline]
def trySend (ch : Sync α) (v : α) : BaseIO Bool := Channel.trySend ch v
/--
Send a value through the channel, blocking until the transmission could be completed.
-/
def send (ch : Sync α) (v : α) : BaseIO Unit := do
IO.wait ( Channel.send ch v)
@[inherit_doc Channel.tryRecv, inline]
def tryRecv (ch : Sync α) : BaseIO (Option α) := Channel.tryRecv ch
/--
Receive a value from the channel, blocking unitl the transmission could be completed.
-/
def recv [Inhabited α] (ch : Sync α) : BaseIO α := do
IO.wait ( Channel.recv ch)
private partial def forIn [Inhabited α] [Monad m] [MonadLiftT BaseIO m]
(ch : Sync α) (f : α β m (ForInStep β)) : β m β := fun b => do
let a ch.recv
match f a b with
| .done b => pure b
| .yield b => ch.forIn f b
/-- `for msg in ch.sync do ...` receives all messages in the channel until it is closed. -/
instance [Inhabited α] [MonadLiftT BaseIO m] : ForIn m (Sync α) α where
forIn ch b f := ch.forIn f b
end Sync
end Channel
end Std

View File

@@ -246,116 +246,113 @@ instance : Hashable (BVExpr w) where
def decEq : DecidableEq (BVExpr w) := fun l r =>
withPtrEqDecEq l r fun _ =>
if h : hash l hash r then
.isFalse (ne_of_apply_ne hash h)
else
match l with
| .var lidx =>
match r with
| .var ridx =>
if h : lidx = ridx then .isTrue (by simp [h]) else .isFalse (by simp [h])
| .const .. | .extract .. | .bin .. | .un .. | .append .. | .replicate .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .const lval =>
match r with
| .const rval =>
if h : lval = rval then .isTrue (by simp [h]) else .isFalse (by simp [h])
| .var .. | .extract .. | .bin .. | .un .. | .append .. | .replicate .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .extract (w := lw) lstart _ lexpr =>
match r with
| .extract (w := rw) rstart _ rexpr =>
if h1 : lw = rw lstart = rstart then
match decEq (h1.left lexpr) rexpr with
| .isTrue h2 => .isTrue (by cases h1.left; simp_all)
| .isFalse h2 => .isFalse (by cases h1.left; simp_all)
else
.isFalse (by simp_all)
| .var .. | .const .. | .bin .. | .un .. | .append .. | .replicate .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .bin llhs lop lrhs =>
match r with
| .bin rlhs rop rrhs =>
if h1 : lop = rop then
match decEq llhs rlhs, decEq lrhs rrhs with
| .isTrue h2, .isTrue h3 => .isTrue (by simp [h1, h2, h3])
| .isFalse h2, _ => .isFalse (by simp [h2])
| _, .isFalse h3 => .isFalse (by simp [h3])
else
.isFalse (by simp [h1])
| .const .. | .var .. | .extract .. | .un .. | .append .. | .replicate .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .un lop lexpr =>
match r with
| .un rop rexpr =>
if h1 : lop = rop then
match decEq lexpr rexpr with
| .isTrue h2 => .isTrue (by simp [h1, h2])
| .isFalse h2 => .isFalse (by simp [h2])
else
.isFalse (by simp [h1])
| .const .. | .var .. | .extract .. | .bin .. | .append .. | .replicate .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .append (l := ll) (r := lr) llhs lrhs lh =>
match r with
| .append (l := rl) (r := rr) rlhs rrhs rh =>
if h1 : ll = rl lr = rr then
match decEq (h1.left llhs) rlhs, decEq (h1.right lrhs) rrhs with
| .isTrue h2, .isTrue h3 => .isTrue (by cases h1.left; cases h1.right; simp [h2, h3])
| .isFalse h2, _ => .isFalse (by cases h1.left; cases h1.right; simp [h2])
| _, .isFalse h3 => .isFalse (by cases h1.left; cases h1.right; simp [h3])
else
.isFalse (by simp; omega)
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .replicate .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .replicate (w := lw) ln lexpr lh =>
match r with
| .replicate (w := rw) rn rexpr rh =>
if h1 : ln = rn lw = rw then
match decEq (h1.right lexpr) rexpr with
| .isTrue h2 => .isTrue (by cases h1.left; cases h1.right; simp [h2])
| .isFalse h2 => .isFalse (by cases h1.left; cases h1.right; simp [h2])
else
.isFalse (by simp; omega)
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .append .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
match l with
| .var lidx =>
match r with
| .var ridx =>
if h : lidx = ridx then .isTrue (by simp [h]) else .isFalse (by simp [h])
| .const .. | .extract .. | .bin .. | .un .. | .append .. | .replicate .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .const lval =>
match r with
| .const rval =>
if h : lval = rval then .isTrue (by simp [h]) else .isFalse (by simp [h])
| .var .. | .extract .. | .bin .. | .un .. | .append .. | .replicate .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .extract (w := lw) lstart _ lexpr =>
match r with
| .extract (w := rw) rstart _ rexpr =>
if h1 : lw = rw lstart = rstart then
match decEq (h1.left lexpr) rexpr with
| .isTrue h2 => .isTrue (by cases h1.left; simp_all)
| .isFalse h2 => .isFalse (by cases h1.left; simp_all)
else
.isFalse (by simp_all)
| .var .. | .const .. | .bin .. | .un .. | .append .. | .replicate .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .bin llhs lop lrhs =>
match r with
| .bin rlhs rop rrhs =>
if h1 : lop = rop then
match decEq llhs rlhs, decEq lrhs rrhs with
| .isTrue h2, .isTrue h3 => .isTrue (by simp [h1, h2, h3])
| .isFalse h2, _ => .isFalse (by simp [h2])
| _, .isFalse h3 => .isFalse (by simp [h3])
else
.isFalse (by simp [h1])
| .const .. | .var .. | .extract .. | .un .. | .append .. | .replicate .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .un lop lexpr =>
match r with
| .un rop rexpr =>
if h1 : lop = rop then
match decEq lexpr rexpr with
| .isTrue h2 => .isTrue (by simp [h1, h2])
| .isFalse h2 => .isFalse (by simp [h2])
else
.isFalse (by simp [h1])
| .const .. | .var .. | .extract .. | .bin .. | .append .. | .replicate .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .append (l := ll) (r := lr) llhs lrhs lh =>
match r with
| .append (l := rl) (r := rr) rlhs rrhs rh =>
if h1 : ll = rl lr = rr then
match decEq (h1.left llhs) rlhs, decEq (h1.right lrhs) rrhs with
| .isTrue h2, .isTrue h3 => .isTrue (by cases h1.left; cases h1.right; simp [h2, h3])
| .isFalse h2, _ => .isFalse (by cases h1.left; cases h1.right; simp [h2])
| _, .isFalse h3 => .isFalse (by cases h1.left; cases h1.right; simp [h3])
else
.isFalse (by simp; omega)
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .replicate .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .replicate (w := lw) ln lexpr lh =>
match r with
| .replicate (w := rw) rn rexpr rh =>
if h1 : ln = rn lw = rw then
match decEq (h1.right lexpr) rexpr with
| .isTrue h2 => .isTrue (by cases h1.left; cases h1.right; simp [h2])
| .isFalse h2 => .isFalse (by cases h1.left; cases h1.right; simp [h2])
else
.isFalse (by simp; omega)
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .append .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .shiftLeft (n := lw) llhs lrhs =>
match r with
| .shiftLeft (n := rw) rlhs rrhs =>
if h1 : lw = rw then
match decEq llhs rlhs, decEq (h1 lrhs) rrhs with
| .isTrue h2, .isTrue h3 => .isTrue (by cases h1; simp [h2, h3])
| .isFalse h2, _ => .isFalse (by cases h1; simp [h2])
| _, .isFalse h3 => .isFalse (by cases h1; simp [h3])
else
.isFalse (by simp [h1])
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .append .. | .replicate ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .shiftRight (n := lw) llhs lrhs =>
match r with
| .shiftRight (n := rw) rlhs rrhs =>
if h1 : lw = rw then
match decEq llhs rlhs, decEq (h1 lrhs) rrhs with
| .isTrue h2, .isTrue h3 => .isTrue (by cases h1; simp [h2, h3])
| .isFalse h2, _ => .isFalse (by cases h1; simp [h2])
| _, .isFalse h3 => .isFalse (by cases h1; simp [h3])
else
.isFalse (by simp [h1])
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .append .. | .replicate ..
|.shiftLeft .. | .arithShiftRight .. => .isFalse (by simp)
| .arithShiftRight (n := lw) llhs lrhs =>
match r with
| .arithShiftRight (n := rw) rlhs rrhs =>
if h1 : lw = rw then
match decEq llhs rlhs, decEq (h1 lrhs) rrhs with
| .isTrue h2, .isTrue h3 => .isTrue (by cases h1; simp [h2, h3])
| .isFalse h2, _ => .isFalse (by cases h1; simp [h2])
| _, .isFalse h3 => .isFalse (by cases h1; simp [h3])
else
.isFalse (by simp [h1])
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .append .. | .replicate ..
| .shiftRight .. | .shiftLeft .. => .isFalse (by simp)
| .shiftLeft (n := lw) llhs lrhs =>
match r with
| .shiftLeft (n := rw) rlhs rrhs =>
if h1 : lw = rw then
match decEq llhs rlhs, decEq (h1 lrhs) rrhs with
| .isTrue h2, .isTrue h3 => .isTrue (by cases h1; simp [h2, h3])
| .isFalse h2, _ => .isFalse (by cases h1; simp [h2])
| _, .isFalse h3 => .isFalse (by cases h1; simp [h3])
else
.isFalse (by simp [h1])
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .append .. | .replicate ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .shiftRight (n := lw) llhs lrhs =>
match r with
| .shiftRight (n := rw) rlhs rrhs =>
if h1 : lw = rw then
match decEq llhs rlhs, decEq (h1 lrhs) rrhs with
| .isTrue h2, .isTrue h3 => .isTrue (by cases h1; simp [h2, h3])
| .isFalse h2, _ => .isFalse (by cases h1; simp [h2])
| _, .isFalse h3 => .isFalse (by cases h1; simp [h3])
else
.isFalse (by simp [h1])
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .append .. | .replicate ..
|.shiftLeft .. | .arithShiftRight .. => .isFalse (by simp)
| .arithShiftRight (n := lw) llhs lrhs =>
match r with
| .arithShiftRight (n := rw) rlhs rrhs =>
if h1 : lw = rw then
match decEq llhs rlhs, decEq (h1 lrhs) rrhs with
| .isTrue h2, .isTrue h3 => .isTrue (by cases h1; simp [h2, h3])
| .isFalse h2, _ => .isFalse (by cases h1; simp [h2])
| _, .isFalse h3 => .isFalse (by cases h1; simp [h3])
else
.isFalse (by simp [h1])
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .append .. | .replicate ..
| .shiftRight .. | .shiftLeft .. => .isFalse (by simp)
instance : DecidableEq (BVExpr w) := decEq

View File

@@ -9,7 +9,6 @@ Author: Leonardo de Moura
#include <utility>
#include <unordered_map>
#include "kernel/replace_fn.h"
#include "util/alloc.h"
namespace lean {
@@ -19,7 +18,7 @@ class replace_rec_fn {
return hash((size_t)p.first >> 3, p.second);
}
};
lean::unordered_map<std::pair<lean_object *, unsigned>, expr, key_hasher> m_cache;
std::unordered_map<std::pair<lean_object *, unsigned>, expr, key_hasher> m_cache;
std::function<optional<expr>(expr const &, unsigned)> m_f;
bool m_use_cache;
@@ -85,7 +84,7 @@ expr replace(expr const & e, std::function<optional<expr>(expr const &, unsigned
}
class replace_fn {
lean::unordered_map<lean_object *, expr> m_cache;
std::unordered_map<lean_object *, expr> m_cache;
lean_object * m_f;
expr save_result(expr const & e, expr const & r, bool shared) {

View File

@@ -21,7 +21,6 @@ Authors: Leonardo de Moura, Gabriel Ebner, Sebastian Ullrich
#include "runtime/io.h"
#include "runtime/compact.h"
#include "runtime/buffer.h"
#include "runtime/array_ref.h"
#include "util/io.h"
#include "util/name_map.h"
#include "library/module.h"
@@ -77,73 +76,47 @@ struct olean_header {
// make sure we don't have any padding bytes, which also ensures `data` is properly aligned
static_assert(sizeof(olean_header) == 5 + 1 + 1 + 33 + 40 + sizeof(size_t), "olean_header must be packed");
extern "C" LEAN_EXPORT object * lean_save_module_data_parts(b_obj_arg mod, b_obj_arg oparts, object *) {
#ifdef LEAN_WINDOWS
uint32_t pid = GetCurrentProcessId();
#else
uint32_t pid = getpid();
#endif
// Derive a base address that is uniformly distributed by deterministic, and should most likely
// work for `mmap` on all interesting platforms
// NOTE: an overlapping/non-compatible base address does not prevent the module from being imported,
// merely from using `mmap` for that
// Let's start with a hash of the module name. Note that while our string hash is a dubious 32-bit
// algorithm, the mixing of multiple `Name` parts seems to result in a nicely distributed 64-bit
// output
size_t base_addr = name(mod, true).hash();
// x86-64 user space is currently limited to the lower 47 bits
// https://en.wikipedia.org/wiki/X86-64#Virtual_address_space_details
// On Linux at least, the stack grows down from ~0x7fff... followed by shared libraries, so reserve
// a bit of space for them (0x7fff...-0x7f00... = 1TB)
base_addr = base_addr % 0x7f0000000000;
// `mmap` addresses must be page-aligned. The default (non-huge) page size on x86-64 is 4KB.
// `MapViewOfFileEx` addresses must be aligned to the "memory allocation granularity", which is 64KB.
const size_t ALIGN = 1LL<<16;
base_addr = base_addr & ~(ALIGN - 1);
object_compactor compactor(reinterpret_cast<void *>(base_addr));
array_ref<pair_ref<string_ref, object_ref>> parts(oparts, true);
std::vector<std::string> tmp_fnames;
for (auto const & part : parts) {
std::string olean_fn = part.fst().to_std_string();
try {
// we first write to a temp file and then move it to the correct path (possibly deleting an older file)
// so that we neither expose partially-written files nor modify possibly memory-mapped files
std::string olean_tmp_fn = olean_fn + ".tmp." + std::to_string(pid);
tmp_fnames.push_back(olean_tmp_fn);
std::ofstream out(olean_tmp_fn, std::ios_base::binary);
if (compactor.size() % ALIGN != 0) {
compactor.alloc(ALIGN - (compactor.size() % ALIGN));
}
size_t file_offset = compactor.size();
compactor.alloc(sizeof(olean_header));
olean_header header = {};
// see/sync with file format description above
header.base_addr = base_addr + file_offset;
strncpy(header.lean_version, get_short_version_string().c_str(), sizeof(header.lean_version));
strncpy(header.githash, LEAN_GITHASH, sizeof(header.githash));
out.write(reinterpret_cast<char *>(&header), sizeof(header));
compactor(part.snd().raw());
if (out.fail()) {
throw exception((sstream() << "failed to create file '" << olean_fn << "'").str());
}
out.write(static_cast<char const *>(compactor.data()) + file_offset + sizeof(olean_header), compactor.size() - file_offset - sizeof(olean_header));
out.close();
} catch (exception & ex) {
return io_result_mk_error((sstream() << "failed to write '" << olean_fn << "': " << ex.what()).str());
extern "C" LEAN_EXPORT object * lean_save_module_data(b_obj_arg fname, b_obj_arg mod, b_obj_arg mdata, object *) {
std::string olean_fn(string_cstr(fname));
// we first write to a temp file and then move it to the correct path (possibly deleting an older file)
// so that we neither expose partially-written files nor modify possibly memory-mapped files
std::string olean_tmp_fn = olean_fn + ".tmp";
try {
std::ofstream out(olean_tmp_fn, std::ios_base::binary);
if (out.fail()) {
return io_result_mk_error((sstream() << "failed to create file '" << olean_fn << "'").str());
}
}
for (unsigned i = 0; i < parts.size(); i++) {
std::string olean_fn = parts[i].fst().to_std_string();
while (std::rename(tmp_fnames[i].c_str(), olean_fn.c_str()) != 0) {
// Derive a base address that is uniformly distributed by deterministic, and should most likely
// work for `mmap` on all interesting platforms
// NOTE: an overlapping/non-compatible base address does not prevent the module from being imported,
// merely from using `mmap` for that
// Let's start with a hash of the module name. Note that while our string hash is a dubious 32-bit
// algorithm, the mixing of multiple `Name` parts seems to result in a nicely distributed 64-bit
// output
size_t base_addr = name(mod, true).hash();
// x86-64 user space is currently limited to the lower 47 bits
// https://en.wikipedia.org/wiki/X86-64#Virtual_address_space_details
// On Linux at least, the stack grows down from ~0x7fff... followed by shared libraries, so reserve
// a bit of space for them (0x7fff...-0x7f00... = 1TB)
base_addr = base_addr % 0x7f0000000000;
// `mmap` addresses must be page-aligned. The default (non-huge) page size on x86-64 is 4KB.
// `MapViewOfFileEx` addresses must be aligned to the "memory allocation granularity", which is 64KB.
base_addr = base_addr & ~((1LL<<16) - 1);
object_compactor compactor(reinterpret_cast<void *>(base_addr + offsetof(olean_header, data)));
compactor(mdata);
// see/sync with file format description above
olean_header header = {};
header.base_addr = base_addr;
strncpy(header.lean_version, get_short_version_string().c_str(), sizeof(header.lean_version));
strncpy(header.githash, LEAN_GITHASH, sizeof(header.githash));
out.write(reinterpret_cast<char *>(&header), sizeof(header));
out.write(static_cast<char const *>(compactor.data()), compactor.size());
out.close();
while (std::rename(olean_tmp_fn.c_str(), olean_fn.c_str()) != 0) {
#ifdef LEAN_WINDOWS
if (errno == EEXIST) {
// Memory-mapped files can be deleted starting with Windows 10 using "POSIX semantics"
@@ -163,148 +136,94 @@ extern "C" LEAN_EXPORT object * lean_save_module_data_parts(b_obj_arg mod, b_obj
#endif
return io_result_mk_error((sstream() << "failed to write '" << olean_fn << "': " << errno << " " << strerror(errno)).str());
}
return io_result_mk_ok(box(0));
} catch (exception & ex) {
return io_result_mk_error((sstream() << "failed to write '" << olean_fn << "': " << ex.what()).str());
}
return io_result_mk_ok(box(0));
}
struct module_file {
std::string m_fname;
std::ifstream m_in;
char * m_base_addr;
size_t m_size;
char * m_buffer;
std::function<void()> m_free_data;
};
extern "C" LEAN_EXPORT object * lean_read_module_data_parts(b_obj_arg ofnames, object *) {
array_ref<string_ref> fnames(ofnames, true);
// first read in all headers
std::vector<module_file> files;
for (auto const & fname : fnames) {
std::string olean_fn = fname.to_std_string();
try {
std::ifstream in(olean_fn, std::ios_base::binary);
if (in.fail()) {
return io_result_mk_error((sstream() << "failed to open file '" << olean_fn << "'").str());
}
/* Get file size */
in.seekg(0, in.end);
size_t size = in.tellg();
in.seekg(0);
olean_header default_header = {};
olean_header header;
if (!in.read(reinterpret_cast<char *>(&header), sizeof(header))
|| memcmp(header.marker, default_header.marker, sizeof(header.marker)) != 0) {
return io_result_mk_error((sstream() << "failed to read file '" << olean_fn << "', invalid header").str());
}
in.seekg(0);
if (header.version != default_header.version || header.flags != default_header.flags
#ifdef LEAN_CHECK_OLEAN_VERSION
|| strncmp(header.githash, LEAN_GITHASH, sizeof(header.githash)) != 0
#endif
) {
return io_result_mk_error((sstream() << "failed to read file '" << olean_fn << "', incompatible header").str());
}
char * base_addr = reinterpret_cast<char *>(header.base_addr);
files.push_back({olean_fn, std::move(in), base_addr, size, nullptr, nullptr});
} catch (exception & ex) {
return io_result_mk_error((sstream() << "failed to read '" << olean_fn << "': " << ex.what()).str());
extern "C" LEAN_EXPORT object * lean_read_module_data(object * fname, object *) {
std::string olean_fn(string_cstr(fname));
try {
std::ifstream in(olean_fn, std::ios_base::binary);
if (in.fail()) {
return io_result_mk_error((sstream() << "failed to open file '" << olean_fn << "'").str());
}
}
/* Get file size */
in.seekg(0, in.end);
size_t size = in.tellg();
in.seekg(0);
#ifndef LEAN_MMAP
bool is_mmap = false;
#else
// now try mmapping *all* files
bool is_mmap = true;
for (auto & file : files) {
std::string const & olean_fn = file.m_fname;
char * base_addr = file.m_base_addr;
try {
olean_header default_header = {};
olean_header header;
if (!in.read(reinterpret_cast<char *>(&header), sizeof(header))
|| memcmp(header.marker, default_header.marker, sizeof(header.marker)) != 0) {
return io_result_mk_error((sstream() << "failed to read file '" << olean_fn << "', invalid header").str());
}
if (header.version != default_header.version || header.flags != default_header.flags
#ifdef LEAN_CHECK_OLEAN_VERSION
|| strncmp(header.githash, LEAN_GITHASH, sizeof(header.githash)) != 0
#endif
) {
return io_result_mk_error((sstream() << "failed to read file '" << olean_fn << "', incompatible header").str());
}
char * base_addr = reinterpret_cast<char *>(header.base_addr);
char * buffer = nullptr;
bool is_mmap = false;
std::function<void()> free_data;
#ifdef LEAN_WINDOWS
// `FILE_SHARE_DELETE` is necessary to allow the file to (be marked to) be deleted while in use
HANDLE h_olean_fn = CreateFile(olean_fn.c_str(), GENERIC_READ, FILE_SHARE_READ | FILE_SHARE_DELETE, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
if (h_olean_fn == INVALID_HANDLE_VALUE) {
return io_result_mk_error((sstream() << "failed to open '" << olean_fn << "': " << GetLastError()).str());
// `FILE_SHARE_DELETE` is necessary to allow the file to (be marked to) be deleted while in use
HANDLE h_olean_fn = CreateFile(olean_fn.c_str(), GENERIC_READ, FILE_SHARE_READ | FILE_SHARE_DELETE, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
if (h_olean_fn == INVALID_HANDLE_VALUE) {
return io_result_mk_error((sstream() << "failed to open '" << olean_fn << "': " << GetLastError()).str());
}
HANDLE h_map = CreateFileMapping(h_olean_fn, NULL, PAGE_READONLY, 0, 0, NULL);
if (h_olean_fn == NULL) {
return io_result_mk_error((sstream() << "failed to map '" << olean_fn << "': " << GetLastError()).str());
}
buffer = static_cast<char *>(MapViewOfFileEx(h_map, FILE_MAP_READ, 0, 0, 0, base_addr));
free_data = [=]() {
if (buffer) {
lean_always_assert(UnmapViewOfFile(base_addr));
}
HANDLE h_map = CreateFileMapping(h_olean_fn, NULL, PAGE_READONLY, 0, 0, NULL);
if (h_olean_fn == NULL) {
return io_result_mk_error((sstream() << "failed to map '" << olean_fn << "': " << GetLastError()).str());
}
char * buffer = static_cast<char *>(MapViewOfFileEx(h_map, FILE_MAP_READ, 0, 0, 0, base_addr));
lean_always_assert(CloseHandle(h_map));
lean_always_assert(CloseHandle(h_olean_fn));
if (!buffer) {
is_mmap = false;
break;
}
file.m_free_data = [=]() {
lean_always_assert(UnmapViewOfFile(base_addr));
};
#else
int fd = open(olean_fn.c_str(), O_RDONLY);
if (fd == -1) {
return io_result_mk_error((sstream() << "failed to open '" << olean_fn << "': " << strerror(errno)).str());
}
char * buffer = static_cast<char *>(mmap(base_addr, file.m_size, PROT_READ, MAP_PRIVATE, fd, 0));
if (buffer == MAP_FAILED) {
is_mmap = false;
break;
}
close(fd);
size_t size = file.m_size;
file.m_free_data = [=]() {
lean_always_assert(munmap(buffer, size) == 0);
};
#endif
if (buffer == base_addr) {
file.m_buffer = buffer;
} else {
is_mmap = false;
break;
}
} catch (exception & ex) {
return io_result_mk_error((sstream() << "failed to read '" << olean_fn << "': " << ex.what()).str());
}
}
#endif
// if *any* file failed to mmap, read all of them into a single big allocation so that offsets
// between them are unchanged
if (!is_mmap) {
for (auto & file : files) {
if (file.m_free_data) {
file.m_free_data();
file.m_free_data = {};
}
}
size_t big_size = files[files.size()-1].m_base_addr + files[files.size()-1].m_size - files[0].m_base_addr;
char * big_buffer = static_cast<char *>(malloc(big_size));
for (auto & file : files) {
std::string const & olean_fn = file.m_fname;
try {
file.m_buffer = big_buffer + (file.m_base_addr - files[0].m_base_addr);
file.m_in.read(file.m_buffer, file.m_size);
if (!file.m_in) {
return io_result_mk_error((sstream() << "failed to read file '" << olean_fn << "'").str());
}
file.m_in.close();
} catch (exception & ex) {
return io_result_mk_error((sstream() << "failed to read '" << olean_fn << "': " << ex.what()).str());
}
}
files[0].m_free_data = [=]() {
free_sized(big_buffer, big_size);
};
}
#else
int fd = open(olean_fn.c_str(), O_RDONLY);
if (fd == -1) {
return io_result_mk_error((sstream() << "failed to open '" << olean_fn << "': " << strerror(errno)).str());
}
#ifdef LEAN_MMAP
buffer = static_cast<char *>(mmap(base_addr, size, PROT_READ, MAP_PRIVATE, fd, 0));
#endif
close(fd);
free_data = [=]() {
if (buffer != MAP_FAILED) {
lean_always_assert(munmap(buffer, size) == 0);
}
};
#endif
if (buffer && buffer == base_addr) {
buffer += sizeof(olean_header);
is_mmap = true;
} else {
#ifdef LEAN_MMAP
free_data();
#endif
buffer = static_cast<char *>(malloc(size - sizeof(olean_header)));
free_data = [=]() {
free_sized(buffer, size - sizeof(olean_header));
};
in.read(buffer, size - sizeof(olean_header));
if (!in) {
return io_result_mk_error((sstream() << "failed to read file '" << olean_fn << "'").str());
}
}
in.close();
std::vector<object_ref> res;
for (auto & file : files) {
compacted_region * region =
new compacted_region(file.m_size - sizeof(olean_header), file.m_buffer + sizeof(olean_header), static_cast<char *>(file.m_base_addr) + sizeof(olean_header), is_mmap, file.m_free_data);
new compacted_region(size - sizeof(olean_header), buffer, base_addr + sizeof(olean_header), is_mmap, free_data);
#if defined(__has_feature)
#if __has_feature(address_sanitizer)
// do not report as leak
@@ -315,9 +234,18 @@ extern "C" LEAN_EXPORT object * lean_read_module_data_parts(b_obj_arg ofnames, o
object * mod_region = alloc_cnstr(0, 2, 0);
cnstr_set(mod_region, 0, mod);
cnstr_set(mod_region, 1, box_size_t(reinterpret_cast<size_t>(region)));
res.push_back(object_ref(mod_region));
return io_result_mk_ok(mod_region);
} catch (exception & ex) {
return io_result_mk_error((sstream() << "failed to read '" << olean_fn << "': " << ex.what()).str());
}
return io_result_mk_ok(to_array(res));
}
/*
@[export lean.write_module_core]
def writeModule (env : Environment) (fname : String) : IO Unit := */
extern "C" object * lean_write_module(object * env, object * fname, object *);
void write_module(elab_environment const & env, std::string const & olean_fn) {
consume_io_result(lean_write_module(env.to_obj_arg(), mk_string(olean_fn), io_mk_world()));
}
}

View File

@@ -339,10 +339,7 @@ tag_counter_manager g_tag_counter_manager;
void object_compactor::operator()(object * o) {
lean_assert(m_todo.empty());
// allocate for root address, see end of function
// NOTE: we must store an offset instead of the pointer itself as `m_begin` may have been
// reallocated in the meantime
size_t root_offset =
static_cast<char *>(alloc(sizeof(object_offset))) - static_cast<char *>(m_begin);
alloc(sizeof(object_offset));
if (!lean_is_scalar(o)) {
m_todo.push_back(o);
while (!m_todo.empty()) {
@@ -374,8 +371,7 @@ void object_compactor::operator()(object * o) {
}
m_tmp.clear();
}
object_offset * root = reinterpret_cast<object_offset *>(static_cast<char *>(m_begin) + root_offset);
*root = to_offset(o);
*static_cast<object_offset *>(m_begin) = to_offset(o);
}
compacted_region::compacted_region(size_t sz, void * data, void * base_addr, bool is_mmap, std::function<void()> free_data):

View File

@@ -31,6 +31,7 @@ class LEAN_EXPORT object_compactor {
size_t capacity() const { return static_cast<char*>(m_capacity) - static_cast<char*>(m_begin); }
void save(object * o, object * new_o);
void save_max_sharing(object * o, object * new_o, size_t new_o_sz);
void * alloc(size_t sz);
object_offset to_offset(object * o);
void insert_terminator(object * o);
object * copy_object(object * o);
@@ -53,7 +54,6 @@ public:
void operator()(object * o);
size_t size() const { return static_cast<char*>(m_end) - static_cast<char*>(m_begin); }
void const * data() const { return m_begin; }
void * alloc(size_t sz);
};
class LEAN_EXPORT compacted_region {

View File

@@ -6,8 +6,9 @@ Author: Leonardo de Moura
*/
#pragma once
#include <vector>
#include <unordered_map>
#include <unordered_set>
#include "runtime/object_ref.h"
#include "util/alloc.h"
namespace lean {
extern "C" LEAN_EXPORT uint8 lean_sharecommon_eq(b_obj_arg o1, b_obj_arg o2);
@@ -31,9 +32,9 @@ protected:
We use `m_cache` to ensure we do **not** traverse a DAG as a tree.
We use pointer equality for this collection.
*/
lean::unordered_map<lean_object *, lean_object *> m_cache;
std::unordered_map<lean_object *, lean_object *> m_cache;
/* Set of maximally shared terms. AKA hash-consing table. */
lean::unordered_set<lean_object *, set_hash, set_eq> m_set;
std::unordered_set<lean_object *, set_hash, set_eq> m_set;
/*
If `true`, `check_cache` will also check `m_set`.
This is useful when the input term may contain terms that have already

View File

@@ -1,44 +0,0 @@
/*
Copyright (c) 2025 Lean FRO. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Sebastian Ullrich
*/
#pragma once
#include <lean/config.h>
#include <unordered_map>
#include <unordered_set>
#ifdef LEAN_MIMALLOC
#include <lean/mimalloc.h>
#endif
namespace lean {
// We do not override `new` to avoid FFI issues for users but use `lean::allocator`
// explicitly where using the custom allocator is important.
#ifdef LEAN_MIMALLOC
template<class T> using allocator = mi_stl_allocator<T>;
#else
template<class T> using allocator = std::allocator<T>;
#endif
// `unordered_map/set` allocates per insert, so specializing to the custom allocator can
// save significant time for maps with frequent inserts.
template<
class Key,
class T,
class Hash = std::hash<Key>,
class KeyEqual = std::equal_to<Key>,
class Allocator = lean::allocator<std::pair<const Key, T>>
> using unordered_map = std::unordered_map<Key, T, Hash, KeyEqual, Allocator>;
template<
class Key,
class Hash = std::hash<Key>,
class KeyEqual = std::equal_to<Key>,
class Allocator = lean::allocator<Key>
> using unordered_set = std::unordered_set<Key, Hash, KeyEqual, Allocator>;
};

View File

@@ -452,14 +452,6 @@ static void report_task_get_blocked_time(std::chrono::nanoseconds d) {
}
}
/*
@[export lean.write_module_core]
def writeModule (env : Environment) (fname : String) (splitExporting : Bool) : IO Unit := */
extern "C" object * lean_write_module(object * env, object * fname, bool split_exporting, object *);
static void write_module(elab_environment const & env, std::string const & olean_fn, bool split_exporting) {
consume_io_result(lean_write_module(env.to_obj_arg(), mk_string(olean_fn), split_exporting, io_mk_world()));
}
extern "C" LEAN_EXPORT int lean_main(int argc, char ** argv) {
#ifdef LEAN_EMSCRIPTEN
// When running in command-line mode under Node.js, we make system directories available in the virtual filesystem.
@@ -775,7 +767,7 @@ extern "C" LEAN_EXPORT int lean_main(int argc, char ** argv) {
}
if (olean_fn && ok) {
time_task t(".olean serialization", opts);
write_module(env, *olean_fn, opts.get_bool({"experimental", "module"}));
write_module(env, *olean_fn);
}
if (c_output && ok) {

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Some files were not shown because too many files have changed in this diff Show More