mirror of
https://github.com/leanprover/lean4.git
synced 2026-03-27 15:24:17 +00:00
Compare commits
52 Commits
UIntX.pow
...
perm_updat
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d3c8b31d71 | ||
|
|
14ed646c3e | ||
|
|
c5e20c980c | ||
|
|
cd5b495573 | ||
|
|
2337b95676 | ||
|
|
973f521c46 | ||
|
|
069456ea9c | ||
|
|
aa2cae8801 | ||
|
|
f513c35742 | ||
|
|
d7cc0fd754 | ||
|
|
5f8847151d | ||
|
|
8bc9c4f154 | ||
|
|
dd7ca772d8 | ||
|
|
85a0232e87 | ||
|
|
8ea6465e6d | ||
|
|
38ed4346c2 | ||
|
|
2657f4e62c | ||
|
|
d4767a08b0 | ||
|
|
f562e72e59 | ||
|
|
5a6d45817d | ||
|
|
264095be7f | ||
|
|
0669a04704 | ||
|
|
5cd352588c | ||
|
|
e9cc776f22 | ||
|
|
e79fef15df | ||
|
|
c672934f11 | ||
|
|
582877d2d3 | ||
|
|
39ce3d14f4 | ||
|
|
32758aa712 | ||
|
|
0f6e35dc63 | ||
|
|
2528188dde | ||
|
|
1cdadfd47a | ||
|
|
e07c59c831 | ||
|
|
cbd38ceadd | ||
|
|
c46f1e941c | ||
|
|
cf3b257ccd | ||
|
|
09ab15dc6d | ||
|
|
e631efd817 | ||
|
|
d2f4ce0158 | ||
|
|
69536808ca | ||
|
|
3d5dd15de4 | ||
|
|
91c245663b | ||
|
|
1421b6145e | ||
|
|
bffa642ad6 | ||
|
|
deef1c2739 | ||
|
|
acf42bd30b | ||
|
|
4947215325 | ||
|
|
6e7209dfa3 | ||
|
|
97a00b3881 | ||
|
|
d758b4c862 | ||
|
|
61d7716ad8 | ||
|
|
05f16ed279 |
2
.github/workflows/build-template.yml
vendored
2
.github/workflows/build-template.yml
vendored
@@ -46,7 +46,7 @@ jobs:
|
||||
CCACHE_DIR: ${{ github.workspace }}/.ccache
|
||||
CCACHE_COMPRESS: true
|
||||
# current cache limit
|
||||
CCACHE_MAXSIZE: 600M
|
||||
CCACHE_MAXSIZE: 400M
|
||||
# squelch error message about missing nixpkgs channel
|
||||
NIX_BUILD_SHELL: bash
|
||||
LSAN_OPTIONS: max_leaks=10
|
||||
|
||||
23
.github/workflows/ci.yml
vendored
23
.github/workflows/ci.yml
vendored
@@ -256,17 +256,18 @@ jobs:
|
||||
"llvm-url": "https://github.com/leanprover/lean-llvm/releases/download/15.0.1/lean-llvm-aarch64-linux-gnu.tar.zst",
|
||||
"prepare-llvm": "../script/prepare-llvm-linux.sh lean-llvm*"
|
||||
},
|
||||
{
|
||||
"name": "Linux 32bit",
|
||||
"os": "ubuntu-latest",
|
||||
// Use 32bit on stage0 and stage1 to keep oleans compatible
|
||||
"CMAKE_OPTIONS": "-DSTAGE0_USE_GMP=OFF -DSTAGE0_LEAN_EXTRA_CXX_FLAGS='-m32' -DSTAGE0_LEANC_OPTS='-m32' -DSTAGE0_MMAP=OFF -DUSE_GMP=OFF -DLEAN_EXTRA_CXX_FLAGS='-m32' -DLEANC_OPTS='-m32' -DMMAP=OFF -DLEAN_INSTALL_SUFFIX=-linux_x86 -DCMAKE_LIBRARY_PATH=/usr/lib/i386-linux-gnu/ -DSTAGE0_CMAKE_LIBRARY_PATH=/usr/lib/i386-linux-gnu/ -DPKG_CONFIG_EXECUTABLE=/usr/bin/i386-linux-gnu-pkg-config",
|
||||
"cmultilib": true,
|
||||
"release": true,
|
||||
"check-level": 2,
|
||||
"cross": true,
|
||||
"shell": "bash -euxo pipefail {0}"
|
||||
}
|
||||
// Started running out of memory building expensive modules, a 2GB heap is just not that much even before fragmentation
|
||||
//{
|
||||
// "name": "Linux 32bit",
|
||||
// "os": "ubuntu-latest",
|
||||
// // Use 32bit on stage0 and stage1 to keep oleans compatible
|
||||
// "CMAKE_OPTIONS": "-DSTAGE0_USE_GMP=OFF -DSTAGE0_LEAN_EXTRA_CXX_FLAGS='-m32' -DSTAGE0_LEANC_OPTS='-m32' -DSTAGE0_MMAP=OFF -DUSE_GMP=OFF -DLEAN_EXTRA_CXX_FLAGS='-m32' -DLEANC_OPTS='-m32' -DMMAP=OFF -DLEAN_INSTALL_SUFFIX=-linux_x86 -DCMAKE_LIBRARY_PATH=/usr/lib/i386-linux-gnu/ -DSTAGE0_CMAKE_LIBRARY_PATH=/usr/lib/i386-linux-gnu/ -DPKG_CONFIG_EXECUTABLE=/usr/bin/i386-linux-gnu-pkg-config",
|
||||
// "cmultilib": true,
|
||||
// "release": true,
|
||||
// "check-level": 2,
|
||||
// "cross": true,
|
||||
// "shell": "bash -euxo pipefail {0}"
|
||||
//}
|
||||
// {
|
||||
// "name": "Web Assembly",
|
||||
// "os": "ubuntu-latest",
|
||||
|
||||
@@ -780,12 +780,11 @@ add_custom_target(clean-olean
|
||||
DEPENDS clean-stdlib)
|
||||
|
||||
install(DIRECTORY "${CMAKE_BINARY_DIR}/lib/" DESTINATION lib
|
||||
PATTERN temp
|
||||
PATTERN "*.export"
|
||||
PATTERN "*.hash"
|
||||
PATTERN "*.trace"
|
||||
PATTERN "*.rsp"
|
||||
EXCLUDE)
|
||||
PATTERN temp EXCLUDE
|
||||
PATTERN "*.export" EXCLUDE
|
||||
PATTERN "*.hash" EXCLUDE
|
||||
PATTERN "*.trace" EXCLUDE
|
||||
PATTERN "*.rsp" EXCLUDE)
|
||||
|
||||
# symlink source into expected installation location for go-to-definition, if file system allows it
|
||||
file(MAKE_DIRECTORY ${CMAKE_BINARY_DIR}/src)
|
||||
|
||||
@@ -59,6 +59,9 @@ instance : Monad (OptionT m) where
|
||||
pure := OptionT.pure
|
||||
bind := OptionT.bind
|
||||
|
||||
instance {m : Type u → Type v} [Pure m] : Inhabited (OptionT m α) where
|
||||
default := pure (f:=m) default
|
||||
|
||||
/--
|
||||
Recovers from failures. Typically used via the `<|>` operator.
|
||||
-/
|
||||
|
||||
@@ -34,7 +34,6 @@ import Init.Data.Stream
|
||||
import Init.Data.Prod
|
||||
import Init.Data.AC
|
||||
import Init.Data.Queue
|
||||
import Init.Data.Channel
|
||||
import Init.Data.Sum
|
||||
import Init.Data.BEq
|
||||
import Init.Data.Subtype
|
||||
|
||||
@@ -288,6 +288,17 @@ theorem count_flatMap {α} [BEq β] {xs : Array α} {f : α → Array β} {x :
|
||||
rcases xs with ⟨xs⟩
|
||||
simp [List.count_flatMap, countP_flatMap, Function.comp_def]
|
||||
|
||||
theorem countP_replace {a b : α} {xs : Array α} {p : α → Bool} :
|
||||
(xs.replace a b).countP p =
|
||||
if xs.contains a then xs.countP p + (if p b then 1 else 0) - (if p a then 1 else 0) else xs.countP p := by
|
||||
rcases xs with ⟨xs⟩
|
||||
simp [List.countP_replace]
|
||||
|
||||
theorem count_replace {a b c : α} {xs : Array α} :
|
||||
(xs.replace a b).count c =
|
||||
if xs.contains a then xs.count c + (if b == c then 1 else 0) - (if a == c then 1 else 0) else xs.count c := by
|
||||
simp [count_eq_countP, countP_replace]
|
||||
|
||||
-- FIXME these theorems can be restored once `List.erase` and `Array.erase` have been related.
|
||||
|
||||
-- theorem count_erase (a b : α) (l : Array α) : count a (l.erase b) = count a l - if b == a then 1 else 0 := by
|
||||
|
||||
@@ -446,11 +446,13 @@ theorem findIdx?_eq_none_iff {xs : Array α} {p : α → Bool} :
|
||||
rcases xs with ⟨xs⟩
|
||||
simp
|
||||
|
||||
@[simp]
|
||||
theorem findIdx?_isSome {xs : Array α} {p : α → Bool} :
|
||||
(xs.findIdx? p).isSome = xs.any p := by
|
||||
rcases xs with ⟨xs⟩
|
||||
simp [List.findIdx?_isSome]
|
||||
|
||||
@[simp]
|
||||
theorem findIdx?_isNone {xs : Array α} {p : α → Bool} :
|
||||
(xs.findIdx? p).isNone = xs.all (¬p ·) := by
|
||||
rcases xs with ⟨xs⟩
|
||||
@@ -591,6 +593,18 @@ theorem findFinIdx?_eq_some_iff {xs : Array α} {p : α → Bool} {i : Fin xs.si
|
||||
· rintro ⟨h, w⟩
|
||||
exact ⟨i, ⟨i.2, h, fun j hji => w ⟨j, by omega⟩ hji⟩, rfl⟩
|
||||
|
||||
@[simp]
|
||||
theorem isSome_findFinIdx? {xs : Array α} {p : α → Bool} :
|
||||
(xs.findFinIdx? p).isSome = xs.any p := by
|
||||
rcases xs with ⟨xs⟩
|
||||
simp
|
||||
|
||||
@[simp]
|
||||
theorem isNone_findFinIdx? {xs : Array α} {p : α → Bool} :
|
||||
(xs.findFinIdx? p).isNone = xs.all (fun x => ¬ p x) := by
|
||||
rcases xs with ⟨xs⟩
|
||||
simp
|
||||
|
||||
@[simp] theorem findFinIdx?_subtype {p : α → Prop} {xs : Array { x // p x }}
|
||||
{f : { x // p x } → Bool} {g : α → Bool} (hf : ∀ x h, f ⟨x, h⟩ = g x) :
|
||||
xs.findFinIdx? f = (xs.unattach.findFinIdx? g).map (fun i => i.cast (by simp)) := by
|
||||
@@ -636,6 +650,20 @@ The lemmas below should be made consistent with those for `findIdx?` (and proved
|
||||
rcases xs with ⟨xs⟩
|
||||
simp [List.idxOf?_eq_none_iff]
|
||||
|
||||
@[simp]
|
||||
theorem isSome_idxOf? [BEq α] [LawfulBEq α] {xs : Array α} {a : α} :
|
||||
(xs.idxOf? a).isSome ↔ a ∈ xs := by
|
||||
rcases xs with ⟨xs⟩
|
||||
simp
|
||||
|
||||
@[simp]
|
||||
theorem isNone_idxOf? [BEq α] [LawfulBEq α] {xs : Array α} {a : α} :
|
||||
(xs.idxOf? a).isNone = ¬ a ∈ xs := by
|
||||
rcases xs with ⟨xs⟩
|
||||
simp
|
||||
|
||||
|
||||
|
||||
/-! ### finIdxOf?
|
||||
|
||||
The verification API for `finIdxOf?` is still incomplete.
|
||||
@@ -658,4 +686,16 @@ theorem idxOf?_eq_map_finIdxOf?_val [BEq α] {xs : Array α} {a : α} :
|
||||
rcases xs with ⟨xs⟩
|
||||
simp [List.finIdxOf?_eq_some_iff]
|
||||
|
||||
@[simp]
|
||||
theorem isSome_finIdxOf? [BEq α] [LawfulBEq α] {xs : Array α} {a : α} :
|
||||
(xs.finIdxOf? a).isSome ↔ a ∈ xs := by
|
||||
rcases xs with ⟨xs⟩
|
||||
simp
|
||||
|
||||
@[simp]
|
||||
theorem isNone_finIdxOf? [BEq α] [LawfulBEq α] {xs : Array α} {a : α} :
|
||||
(xs.finIdxOf? a).isNone = ¬ a ∈ xs := by
|
||||
rcases xs with ⟨xs⟩
|
||||
simp
|
||||
|
||||
end Array
|
||||
|
||||
@@ -53,6 +53,12 @@ instance : Trans (Perm (α := α)) (Perm (α := α)) (Perm (α := α)) where
|
||||
|
||||
theorem perm_comm {xs ys : Array α} : xs ~ ys ↔ ys ~ xs := ⟨Perm.symm, Perm.symm⟩
|
||||
|
||||
theorem Perm.mem_iff {a : α} {xs ys : Array α} (p : xs ~ ys) : a ∈ xs ↔ a ∈ ys := by
|
||||
rcases xs with ⟨xs⟩
|
||||
rcases ys with ⟨ys⟩
|
||||
simp at p
|
||||
simpa using p.mem_iff
|
||||
|
||||
theorem Perm.push (x y : α) {xs ys : Array α} (p : xs ~ ys) :
|
||||
(xs.push x).push y ~ (ys.push y).push x := by
|
||||
cases xs; cases ys
|
||||
@@ -65,4 +71,20 @@ theorem swap_perm {xs : Array α} {i j : Nat} (h₁ : i < xs.size) (h₂ : j < x
|
||||
simp only [swap, perm_iff_toList_perm, toList_set]
|
||||
apply set_set_perm
|
||||
|
||||
namespace Perm
|
||||
|
||||
set_option linter.indexVariables false in
|
||||
theorem extract {xs ys : Array α} (h : xs ~ ys) {lo hi : Nat}
|
||||
(wlo : ∀ i, i < lo → xs[i]? = ys[i]?) (whi : ∀ i, hi ≤ i → xs[i]? = ys[i]?) :
|
||||
(xs.extract lo hi) ~ (ys.extract lo hi) := by
|
||||
rcases xs with ⟨xs⟩
|
||||
rcases ys with ⟨ys⟩
|
||||
simp_all only [perm_toArray, List.getElem?_toArray, List.extract_toArray,
|
||||
List.extract_eq_drop_take]
|
||||
apply List.Perm.take_of_getElem? (w := fun i h => by simpa using whi (lo + i) (by omega))
|
||||
apply List.Perm.drop_of_getElem? (w := wlo)
|
||||
exact h
|
||||
|
||||
end Perm
|
||||
|
||||
end Array
|
||||
|
||||
@@ -227,6 +227,20 @@ SMT-LIB name: `bvmul`.
|
||||
protected def mul (x y : BitVec n) : BitVec n := BitVec.ofNat n (x.toNat * y.toNat)
|
||||
instance : Mul (BitVec n) := ⟨.mul⟩
|
||||
|
||||
/--
|
||||
Raises a bitvector to a natural number power. Usually accessed via the `^` operator.
|
||||
|
||||
Note that this is currently an inefficient implementation,
|
||||
and should be replaced via an `@[extern]` with a native implementation.
|
||||
See https://github.com/leanprover/lean4/issues/7887.
|
||||
-/
|
||||
protected def pow (x : BitVec n) (y : Nat) : BitVec n :=
|
||||
match y with
|
||||
| 0 => 1
|
||||
| y + 1 => x.pow y * x
|
||||
instance : Pow (BitVec n) Nat where
|
||||
pow x y := x.pow y
|
||||
|
||||
/--
|
||||
Unsigned division of bitvectors using the Lean convention where division by zero returns zero.
|
||||
Usually accessed via the `/` operator.
|
||||
|
||||
@@ -3653,6 +3653,13 @@ theorem mul_def {n} {x y : BitVec n} : x * y = (ofFin <| x.toFin * y.toFin) := b
|
||||
@[simp, bitvec_to_nat] theorem toNat_mul (x y : BitVec n) : (x * y).toNat = (x.toNat * y.toNat) % 2 ^ n := rfl
|
||||
@[simp] theorem toFin_mul (x y : BitVec n) : (x * y).toFin = (x.toFin * y.toFin) := rfl
|
||||
|
||||
theorem ofNat_mul {n} (x y : Nat) : BitVec.ofNat n (x * y) = BitVec.ofNat n x * BitVec.ofNat n y := by
|
||||
apply eq_of_toNat_eq
|
||||
simp [BitVec.ofNat, Fin.ofNat'_mul]
|
||||
|
||||
theorem ofNat_mul_ofNat {n} (x y : Nat) : BitVec.ofNat n x * BitVec.ofNat n y = BitVec.ofNat n (x * y) :=
|
||||
(ofNat_mul x y).symm
|
||||
|
||||
protected theorem mul_comm (x y : BitVec w) : x * y = y * x := by
|
||||
apply eq_of_toFin_eq; simpa using Fin.mul_comm ..
|
||||
instance : Std.Commutative (fun (x y : BitVec w) => x * y) := ⟨BitVec.mul_comm⟩
|
||||
@@ -3746,6 +3753,22 @@ theorem setWidth_mul (x y : BitVec w) (h : i ≤ w) :
|
||||
have dvd : 2^i ∣ 2^w := Nat.pow_dvd_pow _ h
|
||||
simp [bitvec_to_nat, h, Nat.mod_mod_of_dvd _ dvd]
|
||||
|
||||
/-! ### pow -/
|
||||
|
||||
@[simp]
|
||||
protected theorem pow_zero {x : BitVec w} : x ^ 0 = 1#w := rfl
|
||||
|
||||
protected theorem pow_succ {x : BitVec w} : x ^ (n + 1) = x ^ n * x := rfl
|
||||
|
||||
@[simp]
|
||||
protected theorem pow_one {x : BitVec w} : x ^ 1 = x := by simp [BitVec.pow_succ]
|
||||
|
||||
protected theorem pow_add {x : BitVec w} {n m : Nat}: x ^ (n + m) = (x ^ n) * (x ^ m):= by
|
||||
induction m with
|
||||
| zero => simp
|
||||
| succ m ih =>
|
||||
rw [← Nat.add_assoc, BitVec.pow_succ, ih, BitVec.mul_assoc, BitVec.pow_succ]
|
||||
|
||||
/-! ### le and lt -/
|
||||
|
||||
@[bitvec_to_nat] theorem le_def {x y : BitVec n} :
|
||||
|
||||
@@ -1,149 +0,0 @@
|
||||
/-
|
||||
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Gabriel Ebner
|
||||
-/
|
||||
prelude
|
||||
import Init.Data.Queue
|
||||
import Init.System.Promise
|
||||
import Init.System.Mutex
|
||||
|
||||
set_option linter.deprecated false
|
||||
|
||||
namespace IO
|
||||
|
||||
/--
|
||||
Internal state of an `Channel`.
|
||||
|
||||
We maintain the invariant that at all times either `consumers` or `values` is empty.
|
||||
-/
|
||||
@[deprecated "Use Std.Channel.State from Std.Sync.Channel instead" (since := "2024-12-02")]
|
||||
structure Channel.State (α : Type) where
|
||||
values : Std.Queue α := ∅
|
||||
consumers : Std.Queue (Promise (Option α)) := ∅
|
||||
closed := false
|
||||
deriving Inhabited
|
||||
|
||||
/--
|
||||
FIFO channel with unbounded buffer, where `recv?` returns a `Task`.
|
||||
|
||||
A channel can be closed. Once it is closed, all `send`s are ignored, and
|
||||
`recv?` returns `none` once the queue is empty.
|
||||
-/
|
||||
@[deprecated "Use Std.Channel from Std.Sync.Channel instead" (since := "2024-12-02")]
|
||||
def Channel (α : Type) : Type := Mutex (Channel.State α)
|
||||
|
||||
instance : Nonempty (Channel α) :=
|
||||
inferInstanceAs (Nonempty (Mutex _))
|
||||
|
||||
/-- Creates a new `Channel`. -/
|
||||
@[deprecated "Use Std.Channel.new from Std.Sync.Channel instead" (since := "2024-12-02")]
|
||||
def Channel.new : BaseIO (Channel α) :=
|
||||
Mutex.new {}
|
||||
|
||||
/--
|
||||
Sends a message on an `Channel`.
|
||||
|
||||
This function does not block.
|
||||
-/
|
||||
@[deprecated "Use Std.Channel.send from Std.Sync.Channel instead" (since := "2024-12-02")]
|
||||
def Channel.send (ch : Channel α) (v : α) : BaseIO Unit :=
|
||||
ch.atomically do
|
||||
let st ← get
|
||||
if st.closed then return
|
||||
if let some (consumer, consumers) := st.consumers.dequeue? then
|
||||
consumer.resolve (some v)
|
||||
set { st with consumers }
|
||||
else
|
||||
set { st with values := st.values.enqueue v }
|
||||
|
||||
/--
|
||||
Closes an `Channel`.
|
||||
-/
|
||||
@[deprecated "Use Std.Channel.close from Std.Sync.Channel instead" (since := "2024-12-02")]
|
||||
def Channel.close (ch : Channel α) : BaseIO Unit :=
|
||||
ch.atomically do
|
||||
let st ← get
|
||||
for consumer in st.consumers.toArray do consumer.resolve none
|
||||
set { st with closed := true, consumers := ∅ }
|
||||
|
||||
/--
|
||||
Receives a message, without blocking.
|
||||
The returned task waits for the message.
|
||||
Every message is only received once.
|
||||
|
||||
Returns `none` if the channel is closed and the queue is empty.
|
||||
-/
|
||||
@[deprecated "Use Std.Channel.recv? from Std.Sync.Channel instead" (since := "2024-12-02")]
|
||||
def Channel.recv? (ch : Channel α) : BaseIO (Task (Option α)) :=
|
||||
ch.atomically do
|
||||
let st ← get
|
||||
if let some (a, values) := st.values.dequeue? then
|
||||
set { st with values }
|
||||
return .pure a
|
||||
else if !st.closed then
|
||||
let promise ← Promise.new
|
||||
set { st with consumers := st.consumers.enqueue promise }
|
||||
return promise.result
|
||||
else
|
||||
return .pure none
|
||||
|
||||
/--
|
||||
`ch.forAsync f` calls `f` for every messages received on `ch`.
|
||||
|
||||
Note that if this function is called twice, each `forAsync` only gets half the messages.
|
||||
-/
|
||||
@[deprecated "Use Std.Channel.forAsync from Std.Sync.Channel instead" (since := "2024-12-02")]
|
||||
partial def Channel.forAsync (f : α → BaseIO Unit) (ch : Channel α)
|
||||
(prio : Task.Priority := .default) : BaseIO (Task Unit) := do
|
||||
BaseIO.bindTask (prio := prio) (← ch.recv?) fun
|
||||
| none => return .pure ()
|
||||
| some v => do f v; ch.forAsync f prio
|
||||
|
||||
/--
|
||||
Receives all currently queued messages from the channel.
|
||||
|
||||
Those messages are dequeued and will not be returned by `recv?`.
|
||||
-/
|
||||
@[deprecated "Use Std.Channel.recvAllCurrent from Std.Sync.Channel instead" (since := "2024-12-02")]
|
||||
def Channel.recvAllCurrent (ch : Channel α) : BaseIO (Array α) :=
|
||||
ch.atomically do
|
||||
modifyGet fun st => (st.values.toArray, { st with values := ∅ })
|
||||
|
||||
/-- Type tag for synchronous (blocking) operations on a `Channel`. -/
|
||||
@[deprecated "Use Std.Channel.Sync from Std.Sync.Channel instead" (since := "2024-12-02")]
|
||||
def Channel.Sync := Channel
|
||||
|
||||
/--
|
||||
Accesses synchronous (blocking) version of channel operations.
|
||||
|
||||
For example, `ch.sync.recv?` blocks until the next message,
|
||||
and `for msg in ch.sync do ...` iterates synchronously over the channel.
|
||||
These functions should only be used in dedicated threads.
|
||||
-/
|
||||
@[deprecated "Use Std.Channel.sync from Std.Sync.Channel instead" (since := "2024-12-02")]
|
||||
def Channel.sync (ch : Channel α) : Channel.Sync α := ch
|
||||
|
||||
/--
|
||||
Synchronously receives a message from the channel.
|
||||
|
||||
Every message is only received once.
|
||||
Returns `none` if the channel is closed and the queue is empty.
|
||||
-/
|
||||
@[deprecated "Use Std.Channel.Sync.recv? from Std.Sync.Channel instead" (since := "2024-12-02")]
|
||||
def Channel.Sync.recv? (ch : Channel.Sync α) : BaseIO (Option α) := do
|
||||
IO.wait (← Channel.recv? ch)
|
||||
|
||||
@[deprecated "Use Std.Channel.Sync.forIn from Std.Sync.Channel instead" (since := "2024-12-02")]
|
||||
private partial def Channel.Sync.forIn [Monad m] [MonadLiftT BaseIO m]
|
||||
(ch : Channel.Sync α) (f : α → β → m (ForInStep β)) : β → m β := fun b => do
|
||||
match ← ch.recv? with
|
||||
| some a =>
|
||||
match ← f a b with
|
||||
| .done b => pure b
|
||||
| .yield b => ch.forIn f b
|
||||
| none => pure b
|
||||
|
||||
/-- `for msg in ch.sync do ...` receives all messages in the channel until it is closed. -/
|
||||
instance [MonadLiftT BaseIO m] : ForIn m (Channel.Sync α) α where
|
||||
forIn ch b f := ch.forIn f b
|
||||
@@ -976,6 +976,16 @@ theorem coe_sub_iff_lt {a b : Fin n} : (↑(a - b) : Nat) = n + a - b ↔ a < b
|
||||
|
||||
/-! ### mul -/
|
||||
|
||||
theorem ofNat'_mul [NeZero n] (x : Nat) (y : Fin n) :
|
||||
Fin.ofNat' n x * y = Fin.ofNat' n (x * y.val) := by
|
||||
apply Fin.eq_of_val_eq
|
||||
simp [Fin.ofNat', Fin.mul_def]
|
||||
|
||||
theorem mul_ofNat' [NeZero n] (x : Fin n) (y : Nat) :
|
||||
x * Fin.ofNat' n y = Fin.ofNat' n (x.val * y) := by
|
||||
apply Fin.eq_of_val_eq
|
||||
simp [Fin.ofNat', Fin.mul_def]
|
||||
|
||||
theorem val_mul {n : Nat} : ∀ a b : Fin n, (a * b).val = a.val * b.val % n
|
||||
| ⟨_, _⟩, ⟨_, _⟩ => rfl
|
||||
|
||||
|
||||
@@ -420,6 +420,8 @@ instance : IntCast Int where intCast n := n
|
||||
protected def Int.cast {R : Type u} [IntCast R] : Int → R :=
|
||||
IntCast.intCast
|
||||
|
||||
@[simp] theorem Int.cast_eq (x : Int) : Int.cast x = x := rfl
|
||||
|
||||
-- see the notes about coercions into arbitrary types in the module doc-string
|
||||
instance [IntCast R] : CoeTail Int R where coe := Int.cast
|
||||
|
||||
|
||||
@@ -2145,6 +2145,11 @@ theorem bmod_pos (x : Int) (m : Nat) (p : x % m < (m + 1) / 2) : bmod x m = x %
|
||||
theorem bmod_neg (x : Int) (m : Nat) (p : x % m ≥ (m + 1) / 2) : bmod x m = (x % m) - m := by
|
||||
simp [bmod_def, Int.not_lt.mpr p]
|
||||
|
||||
theorem bmod_eq_emod (x : Int) (m : Nat) : bmod x m = x % m - if x % m ≥ (m + 1) / 2 then m else 0 := by
|
||||
split
|
||||
· rwa [bmod_neg]
|
||||
· rw [bmod_pos] <;> simp_all
|
||||
|
||||
@[simp]
|
||||
theorem bmod_one_is_zero (x : Int) : Int.bmod x 1 = 0 := by
|
||||
simp [Int.bmod]
|
||||
@@ -2373,6 +2378,43 @@ theorem bmod_neg_bmod : bmod (-(bmod x n)) n = bmod (-x) n := by
|
||||
apply (bmod_add_cancel_right x).mp
|
||||
rw [Int.add_left_neg, ← add_bmod_bmod, Int.add_left_neg]
|
||||
|
||||
theorem bmod_neg_iff {m : Nat} {x : Int} (h2 : -m ≤ x) (h1 : x < m) :
|
||||
(x.bmod m) < 0 ↔ (-(m / 2) ≤ x ∧ x < 0) ∨ ((m + 1) / 2 ≤ x) := by
|
||||
simp only [Int.bmod_def]
|
||||
by_cases xpos : 0 ≤ x
|
||||
· rw [Int.emod_eq_of_lt xpos (by omega)]; omega
|
||||
· rw [Int.add_emod_self.symm, Int.emod_eq_of_lt (by omega) (by omega)]; omega
|
||||
|
||||
theorem bmod_eq_self_of_le {n : Int} {m : Nat} (hn' : -(m / 2) ≤ n) (hn : n < (m + 1) / 2) :
|
||||
n.bmod m = n := by
|
||||
rw [← Int.sub_eq_zero]
|
||||
have := le_bmod (x := n) (m := m) (by omega)
|
||||
have := bmod_lt (x := n) (m := m) (by omega)
|
||||
apply eq_zero_of_dvd_of_natAbs_lt_natAbs Int.dvd_bmod_sub_self
|
||||
omega
|
||||
|
||||
theorem bmod_bmod_of_dvd {a : Int} {n m : Nat} (hnm : n ∣ m) :
|
||||
(a.bmod m).bmod n = a.bmod n := by
|
||||
rw [← Int.sub_eq_iff_eq_add.2 (bmod_add_bdiv a m).symm]
|
||||
obtain ⟨k, rfl⟩ := hnm
|
||||
simp [Int.mul_assoc]
|
||||
|
||||
theorem bmod_eq_self_of_le_mul_two {x : Int} {y : Nat} (hle : -y ≤ x * 2) (hlt : x * 2 < y) :
|
||||
x.bmod y = x := by
|
||||
apply bmod_eq_self_of_le (by omega) (by omega)
|
||||
|
||||
theorem dvd_iff_bmod_eq_zero {a : Nat} {b : Int} : (a : Int) ∣ b ↔ b.bmod a = 0 := by
|
||||
rw [dvd_iff_emod_eq_zero, bmod]
|
||||
split <;> rename_i h
|
||||
· rfl
|
||||
· simp only [Int.not_lt] at h
|
||||
match a with
|
||||
| 0 => omega
|
||||
| a + 1 =>
|
||||
have : b % (a+1) < a + 1 := emod_lt b (by omega)
|
||||
simp_all
|
||||
omega
|
||||
|
||||
/-! Helper theorems for `dvd` simproc -/
|
||||
|
||||
protected theorem dvd_eq_true_of_mod_eq_zero {a b : Int} (h : b % a == 0) : (a ∣ b) = True := by
|
||||
|
||||
@@ -377,6 +377,11 @@ theorem toNat_of_nonpos : ∀ {z : Int}, z ≤ 0 → z.toNat = 0
|
||||
@[simp] theorem negSucc_add_one_eq_neg_ofNat_iff {a b : Nat} : -[a+1] + 1 = - (b : Int) ↔ a = b := by
|
||||
rw [eq_comm, neg_ofNat_eq_negSucc_add_one_iff, eq_comm]
|
||||
|
||||
protected theorem sub_eq_iff_eq_add {b a c : Int} : a - b = c ↔ a = c + b := by
|
||||
refine ⟨fun h => ?_, fun h => ?_⟩ <;> subst h <;> simp
|
||||
protected theorem sub_eq_iff_eq_add' {b a c : Int} : a - b = c ↔ a = b + c := by
|
||||
rw [Int.sub_eq_iff_eq_add, Int.add_comm]
|
||||
|
||||
/- ## add/sub injectivity -/
|
||||
|
||||
@[simp] protected theorem add_left_inj {i j : Int} (k : Int) : (i + k = j + k) ↔ i = j := by
|
||||
|
||||
@@ -19,9 +19,6 @@ namespace Int
|
||||
|
||||
@[simp] theorem natCast_le_zero : {n : Nat} → (n : Int) ≤ 0 ↔ n = 0 := by omega
|
||||
|
||||
protected theorem sub_eq_iff_eq_add {b a c : Int} : a - b = c ↔ a = c + b := by omega
|
||||
protected theorem sub_eq_iff_eq_add' {b a c : Int} : a - b = c ↔ a = b + c := by omega
|
||||
|
||||
@[simp] protected theorem neg_nonpos_iff (i : Int) : -i ≤ 0 ↔ 0 ≤ i := by omega
|
||||
|
||||
@[simp] theorem zero_le_ofNat (n : Nat) : 0 ≤ ((no_index (OfNat.ofNat n)) : Int) :=
|
||||
@@ -81,15 +78,6 @@ theorem eq_ofNat_toNat {a : Int} : a = a.toNat ↔ 0 ≤ a := by omega
|
||||
theorem toNat_le_toNat {n m : Int} (h : n ≤ m) : n.toNat ≤ m.toNat := by omega
|
||||
theorem toNat_lt_toNat {n m : Int} (hn : 0 < m) : n.toNat < m.toNat ↔ n < m := by omega
|
||||
|
||||
/-! ### natAbs -/
|
||||
|
||||
theorem eq_zero_of_dvd_of_natAbs_lt_natAbs {d n : Int} (h : d ∣ n) (h₁ : n.natAbs < d.natAbs) :
|
||||
n = 0 := by
|
||||
obtain ⟨a, rfl⟩ := h
|
||||
rw [natAbs_mul] at h₁
|
||||
suffices ¬ 0 < a.natAbs by simp [Int.natAbs_eq_zero.1 (Nat.eq_zero_of_not_pos this)]
|
||||
exact fun h => Nat.lt_irrefl _ (Nat.lt_of_le_of_lt (Nat.le_mul_of_pos_right d.natAbs h) h₁)
|
||||
|
||||
/-! ### min and max -/
|
||||
|
||||
@[simp] protected theorem min_assoc : ∀ (a b c : Int), min (min a b) c = min a (min b c) := by omega
|
||||
@@ -128,33 +116,6 @@ protected theorem sub_min_sub_left (a b c : Int) : min (a - b) (a - c) = a - max
|
||||
|
||||
protected theorem sub_max_sub_left (a b c : Int) : max (a - b) (a - c) = a - min b c := by omega
|
||||
|
||||
/-! ### bmod -/
|
||||
|
||||
theorem bmod_neg_iff {m : Nat} {x : Int} (h2 : -m ≤ x) (h1 : x < m) :
|
||||
(x.bmod m) < 0 ↔ (-(m / 2) ≤ x ∧ x < 0) ∨ ((m + 1) / 2 ≤ x) := by
|
||||
simp only [Int.bmod_def]
|
||||
by_cases xpos : 0 ≤ x
|
||||
· rw [Int.emod_eq_of_lt xpos (by omega)]; omega
|
||||
· rw [Int.add_emod_self.symm, Int.emod_eq_of_lt (by omega) (by omega)]; omega
|
||||
|
||||
theorem bmod_eq_self_of_le {n : Int} {m : Nat} (hn' : -(m / 2) ≤ n) (hn : n < (m + 1) / 2) :
|
||||
n.bmod m = n := by
|
||||
rw [← Int.sub_eq_zero]
|
||||
have := le_bmod (x := n) (m := m) (by omega)
|
||||
have := bmod_lt (x := n) (m := m) (by omega)
|
||||
apply eq_zero_of_dvd_of_natAbs_lt_natAbs Int.dvd_bmod_sub_self
|
||||
omega
|
||||
|
||||
theorem bmod_bmod_of_dvd {a : Int} {n m : Nat} (hnm : n ∣ m) :
|
||||
(a.bmod m).bmod n = a.bmod n := by
|
||||
rw [← Int.sub_eq_iff_eq_add.2 (bmod_add_bdiv a m).symm]
|
||||
obtain ⟨k, rfl⟩ := hnm
|
||||
simp [Int.mul_assoc]
|
||||
|
||||
theorem bmod_eq_self_of_le_mul_two {x : Int} {y : Nat} (hle : -y ≤ x * 2) (hlt : x * 2 < y) :
|
||||
x.bmod y = x := by
|
||||
apply bmod_eq_self_of_le (by omega) (by omega)
|
||||
|
||||
theorem mul_le_mul_of_natAbs_le {x y : Int} {s t : Nat} (hx : x.natAbs ≤ s) (hy : y.natAbs ≤ t) :
|
||||
x * y ≤ s * t := by
|
||||
by_cases 0 < s ∧ 0 < t
|
||||
|
||||
@@ -329,9 +329,9 @@ protected theorem le_iff_lt_add_one {a b : Int} : a ≤ b ↔ a < b + 1 := by
|
||||
|
||||
/- ### min and max -/
|
||||
|
||||
protected theorem min_def (n m : Int) : min n m = if n ≤ m then n else m := rfl
|
||||
@[grind =] protected theorem min_def (n m : Int) : min n m = if n ≤ m then n else m := rfl
|
||||
|
||||
protected theorem max_def (n m : Int) : max n m = if n ≤ m then m else n := rfl
|
||||
@[grind =] protected theorem max_def (n m : Int) : max n m = if n ≤ m then m else n := rfl
|
||||
|
||||
@[simp] protected theorem neg_min_neg (a b : Int) : min (-a) (-b) = -max a b := by
|
||||
rw [Int.min_def, Int.max_def]
|
||||
@@ -562,6 +562,16 @@ theorem natAbs_sub_of_nonneg_of_le {a b : Int} (h₁ : 0 ≤ b) (h₂ : b ≤ a)
|
||||
· rwa [← Int.ofNat_le, natAbs_of_nonneg h₁, natAbs_of_nonneg (Int.le_trans h₁ h₂)]
|
||||
· exact Int.sub_nonneg_of_le h₂
|
||||
|
||||
theorem eq_zero_of_dvd_of_natAbs_lt_natAbs {d n : Int} (h : d ∣ n) (h₁ : n.natAbs < d.natAbs) :
|
||||
n = 0 := by
|
||||
let ⟨a, ha⟩ := h
|
||||
subst ha
|
||||
rw [natAbs_mul] at h₁
|
||||
suffices ¬ 0 < a.natAbs by simp [Int.natAbs_eq_zero.1 (Nat.eq_zero_of_not_pos this)]
|
||||
refine fun h => Nat.lt_irrefl _ (Nat.lt_of_le_of_lt ?_ h₁)
|
||||
rw (occs := [1]) [← Nat.mul_one d.natAbs]
|
||||
exact Nat.mul_le_mul (Nat.le_refl _) h
|
||||
|
||||
/-! ### toNat -/
|
||||
|
||||
theorem toNat_eq_max : ∀ a : Int, (toNat a : Int) = max a 0
|
||||
@@ -599,6 +609,18 @@ theorem toNat_add {a b : Int} (ha : 0 ≤ a) (hb : 0 ≤ b) : (a + b).toNat = a.
|
||||
match a, b, eq_ofNat_of_zero_le ha, eq_ofNat_of_zero_le hb with
|
||||
| _, _, ⟨_, rfl⟩, ⟨_, rfl⟩ => rfl
|
||||
|
||||
theorem toNat_mul {a b : Int} (ha : 0 ≤ a) (hb : 0 ≤ b) : (a * b).toNat = a.toNat * b.toNat :=
|
||||
match a, b, eq_ofNat_of_zero_le ha, eq_ofNat_of_zero_le hb with
|
||||
| _, _, ⟨_, rfl⟩, ⟨_, rfl⟩ => rfl
|
||||
|
||||
/--
|
||||
Variant of `Int.toNat_sub` taking non-negativity hypotheses,
|
||||
rather than expecting the arguments to be casts of natural numbers.
|
||||
-/
|
||||
theorem toNat_sub'' {a b : Int} (ha : 0 ≤ a) (hb : 0 ≤ b) : (a - b).toNat = a.toNat - b.toNat :=
|
||||
match a, b, eq_ofNat_of_zero_le ha, eq_ofNat_of_zero_le hb with
|
||||
| _, _, ⟨_, rfl⟩, ⟨_, rfl⟩ => toNat_sub _ _
|
||||
|
||||
theorem toNat_add_nat {a : Int} (ha : 0 ≤ a) (n : Nat) : (a + n).toNat = a.toNat + n :=
|
||||
match a, eq_ofNat_of_zero_le ha with | _, ⟨_, rfl⟩ => rfl
|
||||
|
||||
|
||||
@@ -477,7 +477,7 @@ theorem attach_filterMap {l : List α} {f : α → Option β} :
|
||||
· simp only [h]
|
||||
rfl
|
||||
rw [ih]
|
||||
simp only [map_filterMap, Option.map_pbind, Option.map_some']
|
||||
simp only [map_filterMap, Option.map_pbind, Option.map_some]
|
||||
rfl
|
||||
· simp only [Option.pbind_eq_some_iff] at h
|
||||
obtain ⟨a, h, w⟩ := h
|
||||
|
||||
@@ -95,10 +95,10 @@ theorem findSome?_eq_some_iff {f : α → Option β} {l : List α} {b : β} :
|
||||
| cons x xs ih =>
|
||||
simp [guard, findSome?, find?]
|
||||
split <;> rename_i h
|
||||
· simp only [Option.guard_eq_some] at h
|
||||
· simp only [Option.guard_eq_some_iff] at h
|
||||
obtain ⟨rfl, h⟩ := h
|
||||
simp [h]
|
||||
· simp only [Option.guard_eq_none] at h
|
||||
· simp only [Option.guard_eq_none_iff] at h
|
||||
simp [ih, h]
|
||||
|
||||
theorem find?_eq_findSome?_guard {l : List α} : find? p l = findSome? (Option.guard fun x => p x) l :=
|
||||
@@ -700,6 +700,7 @@ theorem findIdx?_eq_none_iff {xs : List α} {p : α → Bool} :
|
||||
simp only [findIdx?_cons]
|
||||
split <;> simp_all [cond_eq_if]
|
||||
|
||||
@[simp]
|
||||
theorem findIdx?_isSome {xs : List α} {p : α → Bool} :
|
||||
(xs.findIdx? p).isSome = xs.any p := by
|
||||
induction xs with
|
||||
@@ -708,6 +709,7 @@ theorem findIdx?_isSome {xs : List α} {p : α → Bool} :
|
||||
simp only [findIdx?_cons]
|
||||
split <;> simp_all
|
||||
|
||||
@[simp]
|
||||
theorem findIdx?_isNone {xs : List α} {p : α → Bool} :
|
||||
(xs.findIdx? p).isNone = xs.all (¬p ·) := by
|
||||
induction xs with
|
||||
@@ -768,7 +770,7 @@ theorem findIdx?_eq_some_iff_getElem {xs : List α} {p : α → Bool} {i : Nat}
|
||||
not_and, Classical.not_forall, Bool.not_eq_false]
|
||||
intros
|
||||
refine ⟨0, zero_lt_succ i, ‹_›⟩
|
||||
· simp only [Option.map_eq_some', ih, Bool.not_eq_true, length_cons]
|
||||
· simp only [Option.map_eq_some_iff, ih, Bool.not_eq_true, length_cons]
|
||||
constructor
|
||||
· rintro ⟨a, ⟨⟨h, h₁, h₂⟩, rfl⟩⟩
|
||||
refine ⟨Nat.succ_lt_succ_iff.mpr h, by simpa, fun j hj => ?_⟩
|
||||
@@ -824,7 +826,7 @@ abbrev findIdx?_of_eq_none := @of_findIdx?_eq_none
|
||||
(xs ++ ys : List α).findIdx? p =
|
||||
(xs.findIdx? p).or ((ys.findIdx? p).map fun i => i + xs.length) := by
|
||||
induction xs with simp
|
||||
| cons _ _ _ => split <;> simp_all [Option.map_or', Option.map_map]; rfl
|
||||
| cons _ _ _ => split <;> simp_all [Option.map_or, Option.map_map]; rfl
|
||||
|
||||
theorem findIdx?_flatten {l : List (List α)} {p : α → Bool} :
|
||||
l.flatten.findIdx? p =
|
||||
@@ -984,6 +986,24 @@ theorem findFinIdx?_eq_some_iff {xs : List α} {p : α → Bool} {i : Fin xs.len
|
||||
· rintro ⟨h, w⟩
|
||||
exact ⟨i, ⟨i.2, h, fun j hji => w ⟨j, by omega⟩ hji⟩, rfl⟩
|
||||
|
||||
@[simp]
|
||||
theorem isSome_findFinIdx? {l : List α} {p : α → Bool} :
|
||||
(l.findFinIdx? p).isSome = l.any p := by
|
||||
induction l with
|
||||
| nil => simp
|
||||
| cons x xs ih =>
|
||||
simp only [findFinIdx?_cons]
|
||||
split <;> simp_all
|
||||
|
||||
@[simp]
|
||||
theorem isNone_findFinIdx? {l : List α} {p : α → Bool} :
|
||||
(l.findFinIdx? p).isNone = l.all (fun x => ¬ p x) := by
|
||||
induction l with
|
||||
| nil => simp
|
||||
| cons x xs ih =>
|
||||
simp only [findFinIdx?_cons]
|
||||
split <;> simp_all
|
||||
|
||||
@[simp] theorem findFinIdx?_subtype {p : α → Prop} {l : List { x // p x }}
|
||||
{f : { x // p x } → Bool} {g : α → Bool} (hf : ∀ x h, f ⟨x, h⟩ = g x) :
|
||||
l.findFinIdx? f = (l.unattach.findFinIdx? g).map (fun i => i.cast (by simp)) := by
|
||||
@@ -1084,6 +1104,24 @@ theorem idxOf?_eq_map_finIdxOf?_val [BEq α] {xs : List α} {a : α} :
|
||||
l.finIdxOf? a = some i ↔ l[i] = a ∧ ∀ j (_ : j < i), ¬l[j] = a := by
|
||||
simp only [finIdxOf?, findFinIdx?_eq_some_iff, beq_iff_eq]
|
||||
|
||||
@[simp]
|
||||
theorem isSome_finIdxOf? [BEq α] [LawfulBEq α] {l : List α} {a : α} :
|
||||
(l.finIdxOf? a).isSome ↔ a ∈ l := by
|
||||
induction l with
|
||||
| nil => simp
|
||||
| cons x xs ih =>
|
||||
simp only [finIdxOf?_cons]
|
||||
split <;> simp_all [@eq_comm _ x a]
|
||||
|
||||
@[simp]
|
||||
theorem isNone_finIdxOf? [BEq α] [LawfulBEq α] {l : List α} {a : α} :
|
||||
(l.finIdxOf? a).isNone = ¬ a ∈ l := by
|
||||
induction l with
|
||||
| nil => simp
|
||||
| cons x xs ih =>
|
||||
simp only [finIdxOf?_cons]
|
||||
split <;> simp_all [@eq_comm _ x a]
|
||||
|
||||
/-! ### idxOf?
|
||||
|
||||
The verification API for `idxOf?` is still incomplete.
|
||||
@@ -1109,6 +1147,25 @@ theorem idxOf?_cons [BEq α] {a : α} {xs : List α} {b : α} :
|
||||
@[deprecated idxOf?_eq_none_iff (since := "2025-01-29")]
|
||||
abbrev indexOf?_eq_none_iff := @idxOf?_eq_none_iff
|
||||
|
||||
@[simp]
|
||||
theorem isSome_idxOf? [BEq α] [LawfulBEq α] {l : List α} {a : α} :
|
||||
(l.idxOf? a).isSome ↔ a ∈ l := by
|
||||
induction l with
|
||||
| nil => simp
|
||||
| cons x xs ih =>
|
||||
simp only [idxOf?_cons]
|
||||
split <;> simp_all [@eq_comm _ x a]
|
||||
|
||||
@[simp]
|
||||
theorem isNone_idxOf? [BEq α] [LawfulBEq α] {l : List α} {a : α} :
|
||||
(l.idxOf? a).isNone = ¬ a ∈ l := by
|
||||
induction l with
|
||||
| nil => simp
|
||||
| cons x xs ih =>
|
||||
simp only [idxOf?_cons]
|
||||
split <;> simp_all [@eq_comm _ x a]
|
||||
|
||||
|
||||
/-! ### lookup -/
|
||||
|
||||
section lookup
|
||||
|
||||
@@ -105,7 +105,7 @@ abbrev length_eq_zero := @length_eq_zero_iff
|
||||
theorem eq_nil_iff_length_eq_zero : l = [] ↔ length l = 0 :=
|
||||
length_eq_zero_iff.symm
|
||||
|
||||
@[grind] theorem length_pos_of_mem {a : α} : ∀ {l : List α}, a ∈ l → 0 < length l
|
||||
@[grind →] theorem length_pos_of_mem {a : α} : ∀ {l : List α}, a ∈ l → 0 < length l
|
||||
| _::_, _ => Nat.zero_lt_succ _
|
||||
|
||||
theorem exists_mem_of_length_pos : ∀ {l : List α}, 0 < length l → ∃ a, a ∈ l
|
||||
@@ -185,7 +185,7 @@ theorem singleton_inj {α : Type _} {a b : α} : [a] = [b] ↔ a = b := by
|
||||
We simplify `l.get i` to `l[i.1]'i.2` and `l.get? i` to `l[i]?`.
|
||||
-/
|
||||
|
||||
@[simp, grind]
|
||||
@[simp, grind =]
|
||||
theorem get_eq_getElem {l : List α} {i : Fin l.length} : l.get i = l[i.1]'i.2 := rfl
|
||||
|
||||
set_option linter.deprecated false in
|
||||
@@ -225,7 +225,7 @@ theorem get?_eq_getElem? {l : List α} {i : Nat} : l.get? i = l[i]? := by
|
||||
We simplify `l[i]!` to `(l[i]?).getD default`.
|
||||
-/
|
||||
|
||||
@[simp, grind]
|
||||
@[simp, grind =]
|
||||
theorem getElem!_eq_getElem?_getD [Inhabited α] {l : List α} {i : Nat} :
|
||||
l[i]! = (l[i]?).getD (default : α) := by
|
||||
simp only [getElem!_def]
|
||||
@@ -235,16 +235,16 @@ theorem getElem!_eq_getElem?_getD [Inhabited α] {l : List α} {i : Nat} :
|
||||
|
||||
/-! ### getElem? and getElem -/
|
||||
|
||||
@[simp, grind] theorem getElem?_nil {i : Nat} : ([] : List α)[i]? = none := rfl
|
||||
@[simp, grind =] theorem getElem?_nil {i : Nat} : ([] : List α)[i]? = none := rfl
|
||||
|
||||
theorem getElem_cons {l : List α} (w : i < (a :: l).length) :
|
||||
(a :: l)[i] =
|
||||
if h : i = 0 then a else l[i-1]'(match i, h with | i+1, _ => succ_lt_succ_iff.mp w) := by
|
||||
cases i <;> simp
|
||||
|
||||
@[grind] theorem getElem?_cons_zero {l : List α} : (a::l)[0]? = some a := rfl
|
||||
@[grind =] theorem getElem?_cons_zero {l : List α} : (a::l)[0]? = some a := rfl
|
||||
|
||||
@[simp, grind] theorem getElem?_cons_succ {l : List α} : (a::l)[i+1]? = l[i]? := rfl
|
||||
@[simp, grind =] theorem getElem?_cons_succ {l : List α} : (a::l)[i+1]? = l[i]? := rfl
|
||||
|
||||
theorem getElem?_cons : (a :: l)[i]? = if i = 0 then some a else l[i-1]? := by
|
||||
cases i <;> simp [getElem?_cons_zero]
|
||||
@@ -337,7 +337,7 @@ We simplify away `getD`, replacing `getD l n a` with `(l[n]?).getD a`.
|
||||
Because of this, there is only minimal API for `getD`.
|
||||
-/
|
||||
|
||||
@[simp, grind]
|
||||
@[simp, grind =]
|
||||
theorem getD_eq_getElem?_getD {l : List α} {i : Nat} {a : α} : getD l i a = (l[i]?).getD a := by
|
||||
simp [getD]
|
||||
|
||||
|
||||
@@ -339,7 +339,7 @@ theorem getElem?_mapIdx_go : ∀ {l : List α} {acc : Array β} {i : Nat},
|
||||
if h : i < acc.size then some acc[i] else Option.map (f i) l[i - acc.size]?
|
||||
| [], acc, i => by
|
||||
simp only [mapIdx.go, Array.toListImpl_eq, getElem?_def, Array.length_toList,
|
||||
← Array.getElem_toList, length_nil, Nat.not_lt_zero, ↓reduceDIte, Option.map_none']
|
||||
← Array.getElem_toList, length_nil, Nat.not_lt_zero, ↓reduceDIte, Option.map_none]
|
||||
| a :: l, acc, i => by
|
||||
rw [mapIdx.go, getElem?_mapIdx_go]
|
||||
simp only [Array.size_push]
|
||||
|
||||
@@ -31,6 +31,33 @@ theorem count_set [BEq α] {a b : α} {l : List α} {i : Nat} (h : i < l.length)
|
||||
(l.set i a).count b = l.count b - (if l[i] == b then 1 else 0) + (if a == b then 1 else 0) := by
|
||||
simp [count_eq_countP, countP_set, h]
|
||||
|
||||
theorem countP_replace [BEq α] [LawfulBEq α] {a b : α} {l : List α} {p : α → Bool} :
|
||||
(l.replace a b).countP p =
|
||||
if l.contains a then l.countP p + (if p b then 1 else 0) - (if p a then 1 else 0) else l.countP p := by
|
||||
induction l with
|
||||
| nil => simp
|
||||
| cons x l ih =>
|
||||
simp [replace_cons]
|
||||
split <;> rename_i h
|
||||
· simp at h
|
||||
simp [h, ih, countP_cons]
|
||||
omega
|
||||
· simp only [beq_eq_false_iff_ne, ne_eq] at h
|
||||
simp only [countP_cons, ih, contains_eq_mem, decide_eq_true_eq, mem_cons, h, false_or]
|
||||
split <;> rename_i h'
|
||||
· by_cases h'' : p a
|
||||
· have : countP p l > 0 := countP_pos_iff.mpr ⟨a, h', h''⟩
|
||||
simp [h'']
|
||||
omega
|
||||
· simp [h'']
|
||||
omega
|
||||
· omega
|
||||
|
||||
theorem count_replace [BEq α] [LawfulBEq α] {a b c : α} {l : List α} :
|
||||
(l.replace a b).count c =
|
||||
if l.contains a then l.count c + (if b == c then 1 else 0) - (if a == c then 1 else 0) else l.count c := by
|
||||
simp [count_eq_countP, countP_replace]
|
||||
|
||||
/--
|
||||
The number of elements satisfying a predicate in a sublist is at least the number of elements satisfying the predicate in the list,
|
||||
minus the difference in the lengths.
|
||||
|
||||
@@ -54,4 +54,23 @@ theorem set_set_perm {as : List α} {i j : Nat} (h₁ : i < as.length) (h₂ : j
|
||||
subst t
|
||||
apply set_set_perm' _ _ (by omega)
|
||||
|
||||
namespace Perm
|
||||
|
||||
/-- Variant of `List.Perm.take` specifying the the permutation is constant after `i` elementwise. -/
|
||||
theorem take_of_getElem? {l₁ l₂ : List α} (h : l₁ ~ l₂) {i : Nat} (w : ∀ j, i ≤ j → l₁[j]? = l₂[j]?) :
|
||||
l₁.take i ~ l₂.take i := by
|
||||
refine h.take (Perm.of_eq ?_)
|
||||
ext1 j
|
||||
simpa using w (i + j) (by omega)
|
||||
|
||||
/-- Variant of `List.Perm.drop` specifying the the permutation is constant before `i` elementwise. -/
|
||||
theorem drop_of_getElem? {l₁ l₂ : List α} (h : l₁ ~ l₂) {i : Nat} (w : ∀ j, j < i → l₁[j]? = l₂[j]?) :
|
||||
l₁.drop i ~ l₂.drop i := by
|
||||
refine h.drop (Perm.of_eq ?_)
|
||||
ext1
|
||||
simp only [getElem?_take]
|
||||
split <;> simp_all
|
||||
|
||||
end Perm
|
||||
|
||||
end List
|
||||
|
||||
@@ -6,6 +6,7 @@ Authors: Leonardo de Moura, Jeremy Avigad, Mario Carneiro
|
||||
prelude
|
||||
import Init.Data.List.Pairwise
|
||||
import Init.Data.List.Erase
|
||||
import Init.Data.List.Find
|
||||
|
||||
/-!
|
||||
# List Permutations
|
||||
@@ -178,7 +179,7 @@ theorem Perm.singleton_eq (h : [a] ~ l) : [a] = l := singleton_perm.mp h
|
||||
|
||||
theorem singleton_perm_singleton {a b : α} : [a] ~ [b] ↔ a = b := by simp
|
||||
|
||||
theorem perm_cons_erase [DecidableEq α] {a : α} {l : List α} (h : a ∈ l) : l ~ a :: l.erase a :=
|
||||
theorem perm_cons_erase [BEq α] [LawfulBEq α] {a : α} {l : List α} (h : a ∈ l) : l ~ a :: l.erase a :=
|
||||
let ⟨_, _, _, e₁, e₂⟩ := exists_erase_eq h
|
||||
e₂ ▸ e₁ ▸ perm_middle
|
||||
|
||||
@@ -268,7 +269,7 @@ theorem countP_eq_countP_filter_add (l : List α) (p q : α → Bool) :
|
||||
l.countP p = (l.filter q).countP p + (l.filter fun a => !q a).countP p :=
|
||||
countP_append .. ▸ Perm.countP_eq _ (filter_append_perm _ _).symm
|
||||
|
||||
theorem Perm.count_eq [DecidableEq α] {l₁ l₂ : List α} (p : l₁ ~ l₂) (a) :
|
||||
theorem Perm.count_eq [BEq α] {l₁ l₂ : List α} (p : l₁ ~ l₂) (a) :
|
||||
count a l₁ = count a l₂ := p.countP_eq _
|
||||
|
||||
/-
|
||||
@@ -369,9 +370,9 @@ theorem perm_append_right_iff {l₁ l₂ : List α} (l) : l₁ ++ l ~ l₂ ++ l
|
||||
refine ⟨fun p => ?_, .append_right _⟩
|
||||
exact (perm_append_left_iff _).1 <| perm_append_comm.trans <| p.trans perm_append_comm
|
||||
|
||||
section DecidableEq
|
||||
section LawfulBEq
|
||||
|
||||
variable [DecidableEq α]
|
||||
variable [BEq α] [LawfulBEq α]
|
||||
|
||||
theorem Perm.erase (a : α) {l₁ l₂ : List α} (p : l₁ ~ l₂) : l₁.erase a ~ l₂.erase a :=
|
||||
if h₁ : a ∈ l₁ then
|
||||
@@ -387,6 +388,11 @@ theorem cons_perm_iff_perm_erase {a : α} {l₁ l₂ : List α} :
|
||||
have : a ∈ l₂ := h.subset mem_cons_self
|
||||
exact ⟨this, (h.trans <| perm_cons_erase this).cons_inv⟩
|
||||
|
||||
end LawfulBEq
|
||||
section DecidableEq
|
||||
|
||||
variable [DecidableEq α]
|
||||
|
||||
theorem perm_iff_count {l₁ l₂ : List α} : l₁ ~ l₂ ↔ ∀ a, count a l₁ = count a l₂ := by
|
||||
refine ⟨Perm.count_eq, fun H => ?_⟩
|
||||
induction l₁ generalizing l₂ with
|
||||
@@ -536,4 +542,22 @@ theorem perm_insertIdx {α} (x : α) (l : List α) {i} (h : i ≤ l.length) :
|
||||
simp only [insertIdx, modifyTailIdx]
|
||||
refine .trans (.cons _ (ih (Nat.le_of_succ_le_succ h))) (.swap ..)
|
||||
|
||||
namespace Perm
|
||||
|
||||
theorem take {l₁ l₂ : List α} (h : l₁ ~ l₂) {n : Nat} (w : l₁.drop n ~ l₂.drop n) :
|
||||
l₁.take n ~ l₂.take n := by
|
||||
classical
|
||||
rw [perm_iff_count] at h w ⊢
|
||||
rw [← take_append_drop n l₁, ← take_append_drop n l₂] at h
|
||||
simpa only [count_append, w, Nat.add_right_cancel_iff] using h
|
||||
|
||||
theorem drop {l₁ l₂ : List α} (h : l₁ ~ l₂) {n : Nat} (w : l₁.take n ~ l₂.take n) :
|
||||
l₁.drop n ~ l₂.drop n := by
|
||||
classical
|
||||
rw [perm_iff_count] at h w ⊢
|
||||
rw [← take_append_drop n l₁, ← take_append_drop n l₂] at h
|
||||
simpa only [count_append, w, Nat.add_left_cancel_iff] using h
|
||||
|
||||
end Perm
|
||||
|
||||
end List
|
||||
|
||||
@@ -6,6 +6,7 @@ Authors: Floris van Doorn, Leonardo de Moura
|
||||
prelude
|
||||
import Init.SimpLemmas
|
||||
import Init.Data.NeZero
|
||||
import Init.Grind.Tactics
|
||||
|
||||
set_option linter.missingDocs true -- keep it documented
|
||||
universe u
|
||||
@@ -867,7 +868,7 @@ Examples:
|
||||
-/
|
||||
protected abbrev min (n m : Nat) := min n m
|
||||
|
||||
protected theorem min_def {n m : Nat} : min n m = if n ≤ m then n else m := rfl
|
||||
@[grind =] protected theorem min_def {n m : Nat} : min n m = if n ≤ m then n else m := rfl
|
||||
|
||||
instance : Max Nat := maxOfLe
|
||||
|
||||
@@ -884,7 +885,7 @@ Examples:
|
||||
-/
|
||||
protected abbrev max (n m : Nat) := max n m
|
||||
|
||||
protected theorem max_def {n m : Nat} : max n m = if n ≤ m then m else n := rfl
|
||||
@[grind =] protected theorem max_def {n m : Nat} : max n m = if n ≤ m then m else n := rfl
|
||||
|
||||
|
||||
/-! # Auxiliary theorems for well-founded recursion -/
|
||||
|
||||
@@ -141,11 +141,11 @@ theorem toList_attach (o : Option α) :
|
||||
cases o <;> simp
|
||||
|
||||
theorem attach_map {o : Option α} (f : α → β) :
|
||||
(o.map f).attach = o.attach.map (fun ⟨x, h⟩ => ⟨f x, map_eq_some.2 ⟨_, h, rfl⟩⟩) := by
|
||||
(o.map f).attach = o.attach.map (fun ⟨x, h⟩ => ⟨f x, map_eq_some_iff.2 ⟨_, h, rfl⟩⟩) := by
|
||||
cases o <;> simp
|
||||
|
||||
theorem attachWith_map {o : Option α} (f : α → β) {P : β → Prop} {H : ∀ (b : β), o.map f = some b → P b} :
|
||||
(o.map f).attachWith P H = (o.attachWith (P ∘ f) (fun _ h => H _ (map_eq_some.2 ⟨_, h, rfl⟩))).map
|
||||
(o.map f).attachWith P H = (o.attachWith (P ∘ f) (fun _ h => H _ (map_eq_some_iff.2 ⟨_, h, rfl⟩))).map
|
||||
fun ⟨x, h⟩ => ⟨f x, h⟩ := by
|
||||
cases o <;> simp
|
||||
|
||||
@@ -174,7 +174,7 @@ theorem map_attach_eq_attachWith {o : Option α} {p : α → Prop} (f : ∀ a, o
|
||||
|
||||
theorem attach_bind {o : Option α} {f : α → Option β} :
|
||||
(o.bind f).attach =
|
||||
o.attach.bind fun ⟨x, h⟩ => (f x).attach.map fun ⟨y, h'⟩ => ⟨y, bind_eq_some.2 ⟨_, h, h'⟩⟩ := by
|
||||
o.attach.bind fun ⟨x, h⟩ => (f x).attach.map fun ⟨y, h'⟩ => ⟨y, bind_eq_some_iff.2 ⟨_, h, h'⟩⟩ := by
|
||||
cases o <;> simp
|
||||
|
||||
theorem bind_attach {o : Option α} {f : {x // o = some x} → Option β} :
|
||||
|
||||
@@ -231,13 +231,12 @@ def merge (fn : α → α → α) : Option α → Option α → Option α
|
||||
@[simp] theorem getD_none : getD none a = a := rfl
|
||||
@[simp] theorem getD_some : getD (some a) b = a := rfl
|
||||
|
||||
@[simp] theorem map_none' (f : α → β) : none.map f = none := rfl
|
||||
@[simp] theorem map_some' (a) (f : α → β) : (some a).map f = some (f a) := rfl
|
||||
@[simp] theorem map_none (f : α → β) : none.map f = none := rfl
|
||||
@[simp] theorem map_some (a) (f : α → β) : (some a).map f = some (f a) := rfl
|
||||
|
||||
@[simp] theorem none_bind (f : α → Option β) : none.bind f = none := rfl
|
||||
@[simp] theorem some_bind (a) (f : α → Option β) : (some a).bind f = f a := rfl
|
||||
|
||||
|
||||
/--
|
||||
A case analysis function for `Option`.
|
||||
|
||||
|
||||
@@ -39,18 +39,24 @@ This is not an instance because it is not definitionally equal to the standard i
|
||||
|
||||
Try to use the Boolean comparisons `Option.isNone` or `Option.isSome` instead.
|
||||
-/
|
||||
@[inline] def decidable_eq_none {o : Option α} : Decidable (o = none) :=
|
||||
@[inline] def decidableEqNone {o : Option α} : Decidable (o = none) :=
|
||||
decidable_of_decidable_of_iff isNone_iff_eq_none
|
||||
|
||||
instance {p : α → Prop} [DecidablePred p] : ∀ o : Option α, Decidable (∀ a, a ∈ o → p a)
|
||||
| none => isTrue nofun
|
||||
| some a =>
|
||||
if h : p a then isTrue fun _ e => some_inj.1 e ▸ h
|
||||
else isFalse <| mt (· _ rfl) h
|
||||
@[deprecated decidableEqNone (since := "2025-04-10"), inline]
|
||||
def decidable_eq_none {o : Option α} : Decidable (o = none) :=
|
||||
decidableEqNone
|
||||
|
||||
instance {p : α → Prop} [DecidablePred p] : ∀ o : Option α, Decidable (Exists fun a => a ∈ o ∧ p a)
|
||||
| none => isFalse nofun
|
||||
| some a => if h : p a then isTrue ⟨_, rfl, h⟩ else isFalse fun ⟨_, ⟨rfl, hn⟩⟩ => h hn
|
||||
instance decidableForallMem {p : α → Prop} [DecidablePred p] :
|
||||
∀ o : Option α, Decidable (∀ a, a ∈ o → p a)
|
||||
| none => isTrue nofun
|
||||
| some a =>
|
||||
if h : p a then isTrue fun _ e => some_inj.1 e ▸ h
|
||||
else isFalse <| mt (· _ rfl) h
|
||||
|
||||
instance decidableExistsMem {p : α → Prop} [DecidablePred p] :
|
||||
∀ o : Option α, Decidable (Exists fun a => a ∈ o ∧ p a)
|
||||
| none => isFalse nofun
|
||||
| some a => if h : p a then isTrue ⟨_, rfl, h⟩ else isFalse fun ⟨_, ⟨rfl, hn⟩⟩ => h hn
|
||||
|
||||
/--
|
||||
Given an optional value and a function that can be applied when the value is `some`, returns the
|
||||
|
||||
@@ -17,6 +17,8 @@ theorem mem_iff {a : α} {b : Option α} : a ∈ b ↔ b = some a := .rfl
|
||||
|
||||
theorem mem_some {a b : α} : a ∈ some b ↔ b = a := by simp
|
||||
|
||||
theorem mem_some_iff {a b : α} : a ∈ some b ↔ b = a := mem_some
|
||||
|
||||
theorem mem_some_self (a : α) : a ∈ some a := rfl
|
||||
|
||||
theorem some_ne_none (x : α) : some x ≠ none := nofun
|
||||
@@ -29,6 +31,9 @@ protected theorem «exists» {p : Option α → Prop} :
|
||||
⟨fun | ⟨none, hx⟩ => .inl hx | ⟨some x, hx⟩ => .inr ⟨x, hx⟩,
|
||||
fun | .inl h => ⟨_, h⟩ | .inr ⟨_, hx⟩ => ⟨_, hx⟩⟩
|
||||
|
||||
theorem eq_none_or_eq_some (a : Option α) : a = none ∨ ∃ x, a = some x :=
|
||||
Option.exists.mp exists_eq'
|
||||
|
||||
theorem get_mem : ∀ {o : Option α} (h : isSome o), o.get h ∈ o
|
||||
| some _, _ => rfl
|
||||
|
||||
@@ -88,6 +93,9 @@ set_option Elab.async false
|
||||
theorem eq_none_iff_forall_ne_some : o = none ↔ ∀ a, o ≠ some a := by
|
||||
cases o <;> simp
|
||||
|
||||
theorem eq_none_iff_forall_some_ne : o = none ↔ ∀ a, some a ≠ o := by
|
||||
cases o <;> simp
|
||||
|
||||
theorem eq_none_iff_forall_not_mem : o = none ↔ ∀ a, a ∉ o :=
|
||||
eq_none_iff_forall_ne_some
|
||||
|
||||
@@ -102,9 +110,30 @@ theorem isSome_of_mem {x : Option α} {y : α} (h : y ∈ x) : x.isSome := by
|
||||
theorem isSome_of_eq_some {x : Option α} {y : α} (h : x = some y) : x.isSome := by
|
||||
cases x <;> trivial
|
||||
|
||||
@[simp] theorem not_isSome : isSome a = false ↔ a.isNone = true := by
|
||||
@[simp] theorem isSome_eq_false_iff : isSome a = false ↔ a.isNone = true := by
|
||||
cases a <;> simp
|
||||
|
||||
@[simp] theorem isNone_eq_false_iff : isNone a = false ↔ a.isSome = true := by
|
||||
cases a <;> simp
|
||||
|
||||
@[simp]
|
||||
theorem not_isSome (a : Option α) : (!a.isSome) = a.isNone := by
|
||||
cases a <;> simp
|
||||
|
||||
@[simp]
|
||||
theorem not_comp_isSome : (! ·) ∘ @Option.isSome α = Option.isNone := by
|
||||
funext
|
||||
simp
|
||||
|
||||
@[simp]
|
||||
theorem not_isNone (a : Option α) : (!a.isNone) = a.isSome := by
|
||||
cases a <;> simp
|
||||
|
||||
@[simp]
|
||||
theorem not_comp_isNone : (!·) ∘ @Option.isNone α = Option.isSome := by
|
||||
funext x
|
||||
simp
|
||||
|
||||
theorem eq_some_iff_get_eq : o = some a ↔ ∃ h : o.isSome, o.get h = a := by
|
||||
cases o <;> simp
|
||||
|
||||
@@ -146,17 +175,25 @@ abbrev ball_ne_none := @forall_ne_none
|
||||
|
||||
@[simp] theorem bind_eq_bind : bind = @Option.bind α β := rfl
|
||||
|
||||
@[simp] theorem orElse_eq_orElse : HOrElse.hOrElse = @Option.orElse α := rfl
|
||||
|
||||
@[simp] theorem bind_some (x : Option α) : x.bind some = x := by cases x <;> rfl
|
||||
|
||||
@[simp] theorem bind_none (x : Option α) : x.bind (fun _ => none (α := β)) = none := by
|
||||
cases x <;> rfl
|
||||
|
||||
theorem bind_eq_some : x.bind f = some b ↔ ∃ a, x = some a ∧ f a = some b := by
|
||||
theorem bind_eq_some_iff : x.bind f = some b ↔ ∃ a, x = some a ∧ f a = some b := by
|
||||
cases x <;> simp
|
||||
|
||||
@[simp] theorem bind_eq_none {o : Option α} {f : α → Option β} :
|
||||
@[deprecated bind_eq_some_iff (since := "2025-04-10")]
|
||||
abbrev bind_eq_some := @bind_eq_some_iff
|
||||
|
||||
@[simp] theorem bind_eq_none_iff {o : Option α} {f : α → Option β} :
|
||||
o.bind f = none ↔ ∀ a, o = some a → f a = none := by cases o <;> simp
|
||||
|
||||
@[deprecated bind_eq_none_iff (since := "2025-04-10")]
|
||||
abbrev bind_eq_none := @bind_eq_none_iff
|
||||
|
||||
theorem bind_eq_none' {o : Option α} {f : α → Option β} :
|
||||
o.bind f = none ↔ ∀ b a, o = some a → f a ≠ some b := by
|
||||
cases o <;> simp [eq_none_iff_forall_ne_some]
|
||||
@@ -193,50 +230,67 @@ theorem isSome_apply_of_isSome_bind {α β : Type _} {x : Option α} {f : α →
|
||||
(isSome_apply_of_isSome_bind h) := by
|
||||
cases x <;> trivial
|
||||
|
||||
theorem join_eq_some : x.join = some a ↔ x = some (some a) := by
|
||||
simp [bind_eq_some]
|
||||
theorem join_eq_some_iff : x.join = some a ↔ x = some (some a) := by
|
||||
simp [bind_eq_some_iff]
|
||||
|
||||
@[deprecated join_eq_some_iff (since := "2025-04-10")]
|
||||
abbrev join_eq_some := @join_eq_some_iff
|
||||
|
||||
theorem join_ne_none : x.join ≠ none ↔ ∃ z, x = some (some z) := by
|
||||
simp only [ne_none_iff_exists', join_eq_some, iff_self]
|
||||
simp only [ne_none_iff_exists', join_eq_some_iff, iff_self]
|
||||
|
||||
theorem join_ne_none' : ¬x.join = none ↔ ∃ z, x = some (some z) :=
|
||||
join_ne_none
|
||||
|
||||
theorem join_eq_none : o.join = none ↔ o = none ∨ o = some none :=
|
||||
theorem join_eq_none_iff : o.join = none ↔ o = none ∨ o = some none :=
|
||||
match o with | none | some none | some (some _) => by simp
|
||||
|
||||
@[deprecated join_eq_none_iff (since := "2025-04-10")]
|
||||
abbrev join_eq_none := @join_eq_none_iff
|
||||
|
||||
theorem bind_id_eq_join {x : Option (Option α)} : x.bind id = x.join := rfl
|
||||
|
||||
@[simp] theorem map_eq_map : Functor.map f = Option.map f := rfl
|
||||
|
||||
theorem map_none : f <$> none = none := rfl
|
||||
@[deprecated map_none (since := "2025-04-10")]
|
||||
abbrev map_none' := @map_none
|
||||
|
||||
theorem map_some : f <$> some a = some (f a) := rfl
|
||||
@[deprecated map_some (since := "2025-04-10")]
|
||||
abbrev map_some' := @map_some
|
||||
|
||||
@[simp] theorem map_eq_some' : x.map f = some b ↔ ∃ a, x = some a ∧ f a = b := by cases x <;> simp
|
||||
|
||||
theorem map_eq_some : f <$> x = some b ↔ ∃ a, x = some a ∧ f a = b := map_eq_some'
|
||||
|
||||
@[simp] theorem map_eq_none' : x.map f = none ↔ x = none := by
|
||||
cases x <;> simp [map_none', map_some', eq_self_iff_true]
|
||||
|
||||
theorem isSome_map {x : Option α} : (f <$> x).isSome = x.isSome := by
|
||||
@[simp] theorem map_eq_some_iff : x.map f = some b ↔ ∃ a, x = some a ∧ f a = b := by
|
||||
cases x <;> simp
|
||||
|
||||
@[simp] theorem isSome_map' {x : Option α} : (x.map f).isSome = x.isSome := by
|
||||
@[deprecated map_eq_some_iff (since := "2025-04-10")]
|
||||
abbrev map_eq_some := @map_eq_some_iff
|
||||
|
||||
@[deprecated map_eq_some_iff (since := "2025-04-10")]
|
||||
abbrev map_eq_some' := @map_eq_some_iff
|
||||
|
||||
@[simp] theorem map_eq_none_iff : x.map f = none ↔ x = none := by
|
||||
cases x <;> simp [map_none, map_some, eq_self_iff_true]
|
||||
|
||||
@[deprecated map_eq_none_iff (since := "2025-04-10")]
|
||||
abbrev map_eq_none := @map_eq_none_iff
|
||||
|
||||
@[deprecated map_eq_none_iff (since := "2025-04-10")]
|
||||
abbrev map_eq_none' := @map_eq_none_iff
|
||||
|
||||
@[simp] theorem isSome_map {x : Option α} : (x.map f).isSome = x.isSome := by
|
||||
cases x <;> simp
|
||||
|
||||
@[simp] theorem isNone_map' {x : Option α} : (x.map f).isNone = x.isNone := by
|
||||
cases x <;> simp
|
||||
@[deprecated isSome_map (since := "2025-04-10")]
|
||||
abbrev isSome_map' := @isSome_map
|
||||
|
||||
theorem map_eq_none : f <$> x = none ↔ x = none := map_eq_none'
|
||||
@[simp] theorem isNone_map {x : Option α} : (x.map f).isNone = x.isNone := by
|
||||
cases x <;> simp
|
||||
|
||||
theorem map_eq_bind {x : Option α} : x.map f = x.bind (some ∘ f) := by
|
||||
cases x <;> simp [Option.bind]
|
||||
|
||||
theorem map_congr {x : Option α} (h : ∀ a, x = some a → f a = g a) :
|
||||
x.map f = x.map g := by
|
||||
cases x <;> simp only [map_none', map_some', h]
|
||||
cases x <;> simp only [map_none, map_some, h]
|
||||
|
||||
@[simp] theorem map_id_fun {α : Type u} : Option.map (id : α → α) = id := by
|
||||
funext; simp [map_id]
|
||||
@@ -254,7 +308,7 @@ theorem get_map {f : α → β} {o : Option α} {h : (o.map f).isSome} :
|
||||
|
||||
@[simp] theorem map_map (h : β → γ) (g : α → β) (x : Option α) :
|
||||
(x.map g).map h = x.map (h ∘ g) := by
|
||||
cases x <;> simp only [map_none', map_some', ·∘·]
|
||||
cases x <;> simp only [map_none, map_some, ·∘·]
|
||||
|
||||
theorem comp_map (h : β → γ) (g : α → β) (x : Option α) : x.map (h ∘ g) = (x.map g).map h :=
|
||||
(map_map ..).symm
|
||||
@@ -262,7 +316,7 @@ theorem comp_map (h : β → γ) (g : α → β) (x : Option α) : x.map (h ∘
|
||||
@[simp] theorem map_comp_map (f : α → β) (g : β → γ) :
|
||||
Option.map g ∘ Option.map f = Option.map (g ∘ f) := by funext x; simp
|
||||
|
||||
theorem mem_map_of_mem (g : α → β) (h : a ∈ x) : g a ∈ Option.map g x := h.symm ▸ map_some' ..
|
||||
theorem mem_map_of_mem (g : α → β) (h : a ∈ x) : g a ∈ Option.map g x := h.symm ▸ map_some ..
|
||||
|
||||
theorem map_inj_right {f : α → β} {o o' : Option α} (w : ∀ x y, f x = f y → x = y) :
|
||||
o.map f = o'.map f ↔ o = o' := by
|
||||
@@ -292,11 +346,14 @@ theorem isSome_of_isSome_filter (p : α → Bool) (o : Option α) (h : (o.filter
|
||||
@[deprecated isSome_of_isSome_filter (since := "2025-03-18")]
|
||||
abbrev isSome_filter_of_isSome := @isSome_of_isSome_filter
|
||||
|
||||
@[simp] theorem filter_eq_none {o : Option α} {p : α → Bool} :
|
||||
@[simp] theorem filter_eq_none_iff {o : Option α} {p : α → Bool} :
|
||||
o.filter p = none ↔ ∀ a, o = some a → ¬ p a := by
|
||||
cases o <;> simp [filter_some]
|
||||
|
||||
@[simp] theorem filter_eq_some {o : Option α} {p : α → Bool} :
|
||||
@[deprecated filter_eq_none_iff (since := "2025-04-10")]
|
||||
abbrev filter_eq_none := @filter_eq_none_iff
|
||||
|
||||
@[simp] theorem filter_eq_some_iff {o : Option α} {p : α → Bool} :
|
||||
o.filter p = some a ↔ o = some a ∧ p a := by
|
||||
cases o with
|
||||
| none => simp
|
||||
@@ -310,6 +367,9 @@ abbrev isSome_filter_of_isSome := @isSome_of_isSome_filter
|
||||
rintro rfl
|
||||
simpa using h
|
||||
|
||||
@[deprecated filter_eq_some_iff (since := "2025-04-10")]
|
||||
abbrev filter_eq_some := @filter_eq_some_iff
|
||||
|
||||
theorem mem_filter_iff {p : α → Bool} {a : α} {o : Option α} :
|
||||
a ∈ o.filter p ↔ a ∈ o ∧ p a := by
|
||||
simp
|
||||
@@ -383,29 +443,43 @@ theorem join_join {x : Option (Option (Option α))} : x.join.join = (x.map join)
|
||||
cases x <;> simp
|
||||
|
||||
theorem mem_of_mem_join {a : α} {x : Option (Option α)} (h : a ∈ x.join) : some a ∈ x :=
|
||||
h.symm ▸ join_eq_some.1 h
|
||||
h.symm ▸ join_eq_some_iff.1 h
|
||||
|
||||
@[simp] theorem some_orElse (a : α) (x : Option α) : (some a <|> x) = some a := rfl
|
||||
@[simp] theorem some_orElse (a : α) (f) : (some a).orElse f = some a := rfl
|
||||
|
||||
@[simp] theorem none_orElse (x : Option α) : (none <|> x) = x := rfl
|
||||
@[simp] theorem none_orElse (f : Unit → Option α) : none.orElse f = f () := rfl
|
||||
|
||||
@[simp] theorem orElse_none (x : Option α) : (x <|> none) = x := by cases x <;> rfl
|
||||
@[simp] theorem orElse_none (x : Option α) : x.orElse (fun _ => none) = x := by cases x <;> rfl
|
||||
|
||||
theorem map_orElse {x y : Option α} : (x <|> y).map f = (x.map f <|> y.map f) := by
|
||||
theorem orElse_eq_some_iff (o : Option α) (f) (x : α) :
|
||||
(o.orElse f) = some x ↔ o = some x ∨ o = none ∧ f () = some x := by
|
||||
cases o <;> simp
|
||||
|
||||
theorem orElse_eq_none_iff (o : Option α) (f) : (o.orElse f) = none ↔ o = none ∧ f () = none := by
|
||||
cases o <;> simp
|
||||
|
||||
theorem map_orElse {x : Option α} {y} :
|
||||
(x.orElse y).map f = (x.map f).orElse (fun _ => (y ()).map f) := by
|
||||
cases x <;> simp
|
||||
|
||||
@[simp] theorem guard_eq_some [DecidablePred p] : guard p a = some b ↔ a = b ∧ p a :=
|
||||
@[simp] theorem guard_eq_some_iff [DecidablePred p] : guard p a = some b ↔ a = b ∧ p a :=
|
||||
if h : p a then by simp [Option.guard, h] else by simp [Option.guard, h]
|
||||
|
||||
@[deprecated guard_eq_some_iff (since := "2025-04-10")]
|
||||
abbrev guard_eq_some := @guard_eq_some_iff
|
||||
|
||||
@[simp] theorem isSome_guard [DecidablePred p] : (Option.guard p a).isSome ↔ p a :=
|
||||
if h : p a then by simp [Option.guard, h] else by simp [Option.guard, h]
|
||||
|
||||
@[deprecated isSome_guard (since := "2025-03-18")]
|
||||
abbrev guard_isSome := @isSome_guard
|
||||
|
||||
@[simp] theorem guard_eq_none [DecidablePred p] : Option.guard p a = none ↔ ¬ p a :=
|
||||
@[simp] theorem guard_eq_none_iff [DecidablePred p] : Option.guard p a = none ↔ ¬ p a :=
|
||||
if h : p a then by simp [Option.guard, h] else by simp [Option.guard, h]
|
||||
|
||||
@[deprecated guard_eq_none_iff (since := "2025-04-10")]
|
||||
abbrev guard_eq_none := @guard_eq_none_iff
|
||||
|
||||
@[simp] theorem guard_pos [DecidablePred p] (h : p a) : Option.guard p a = some a := by
|
||||
simp [Option.guard, h]
|
||||
|
||||
@@ -475,6 +549,22 @@ theorem liftOrGet_none_right {f} {a : Option α} : merge f a none = a :=
|
||||
theorem liftOrGet_some_some {f} {a b : α} : merge f (some a) (some b) = f a b :=
|
||||
merge_some_some
|
||||
|
||||
instance commutative_merge (f : α → α → α) [Std.Commutative f] :
|
||||
Std.Commutative (merge f) :=
|
||||
⟨fun a b ↦ by cases a <;> cases b <;> simp [merge, Std.Commutative.comm]⟩
|
||||
|
||||
instance associative_merge (f : α → α → α) [Std.Associative f] :
|
||||
Std.Associative (merge f) :=
|
||||
⟨fun a b c ↦ by cases a <;> cases b <;> cases c <;> simp [merge, Std.Associative.assoc]⟩
|
||||
|
||||
instance idempotentOp_merge (f : α → α → α) [Std.IdempotentOp f] :
|
||||
Std.IdempotentOp (merge f) :=
|
||||
⟨fun a ↦ by cases a <;> simp [merge, Std.IdempotentOp.idempotent]⟩
|
||||
|
||||
instance lawfulIdentity_merge (f : α → α → α) : Std.LawfulIdentity (merge f) none where
|
||||
left_id a := by cases a <;> simp [merge]
|
||||
right_id a := by cases a <;> simp [merge]
|
||||
|
||||
@[simp] theorem elim_none (x : β) (f : α → β) : none.elim x f = x := rfl
|
||||
|
||||
@[simp] theorem elim_some (x : β) (f : α → β) (a : α) : (some a).elim x f = f a := rfl
|
||||
@@ -535,12 +625,18 @@ theorem or_eq_bif : or o o' = bif o.isSome then o else o' := by
|
||||
@[simp] theorem isNone_or : (or o o').isNone = (o.isNone && o'.isNone) := by
|
||||
cases o <;> rfl
|
||||
|
||||
@[simp] theorem or_eq_none : or o o' = none ↔ o = none ∧ o' = none := by
|
||||
@[simp] theorem or_eq_none_iff : or o o' = none ↔ o = none ∧ o' = none := by
|
||||
cases o <;> simp
|
||||
|
||||
@[simp] theorem or_eq_some : or o o' = some a ↔ o = some a ∨ (o = none ∧ o' = some a) := by
|
||||
@[deprecated or_eq_none_iff (since := "2025-04-10")]
|
||||
abbrev or_eq_none := @or_eq_none_iff
|
||||
|
||||
@[simp] theorem or_eq_some_iff : or o o' = some a ↔ o = some a ∨ (o = none ∧ o' = some a) := by
|
||||
cases o <;> simp
|
||||
|
||||
@[deprecated or_eq_some_iff (since := "2025-04-10")]
|
||||
abbrev or_eq_some := @or_eq_some_iff
|
||||
|
||||
theorem or_assoc : or (or o₁ o₂) o₃ = or o₁ (or o₂ o₃) := by
|
||||
cases o₁ <;> cases o₂ <;> rfl
|
||||
instance : Std.Associative (or (α := α)) := ⟨@or_assoc _⟩
|
||||
@@ -564,11 +660,11 @@ instance : Std.IdempotentOp (or (α := α)) := ⟨@or_self _⟩
|
||||
theorem or_eq_orElse : or o o' = o.orElse (fun _ => o') := by
|
||||
cases o <;> rfl
|
||||
|
||||
theorem map_or : f <$> or o o' = (f <$> o).or (f <$> o') := by
|
||||
theorem map_or : (or o o').map f = (o.map f).or (o'.map f) := by
|
||||
cases o <;> rfl
|
||||
|
||||
theorem map_or' : (or o o').map f = (o.map f).or (o'.map f) := by
|
||||
cases o <;> rfl
|
||||
@[deprecated map_or (since := "2025-04-10")]
|
||||
abbrev map_or' := @map_or
|
||||
|
||||
theorem or_of_isSome {o o' : Option α} (h : o.isSome) : o.or o' = o := by
|
||||
match o, h with
|
||||
@@ -804,7 +900,7 @@ theorem map_pmap {p : α → Prop} (g : β → γ) (f : ∀ a, p a → β) (o H)
|
||||
|
||||
theorem pmap_map (o : Option α) (f : α → β) {p : β → Prop} (g : ∀ b, p b → γ) (H) :
|
||||
pmap g (o.map f) H =
|
||||
pmap (fun a h => g (f a) h) o (fun a m => H (f a) (map_eq_some.2 ⟨_, m, rfl⟩)) := by
|
||||
pmap (fun a h => g (f a) h) o (fun a m => H (f a) (map_eq_some_iff.2 ⟨_, m, rfl⟩)) := by
|
||||
cases o <;> simp
|
||||
|
||||
theorem pmap_pred_congr {α : Type u}
|
||||
|
||||
@@ -20,6 +20,9 @@ def UInt8.mk (bitVec : BitVec 8) : UInt8 :=
|
||||
def UInt8.ofNatCore (n : Nat) (h : n < UInt8.size) : UInt8 :=
|
||||
UInt8.ofNatLT n h
|
||||
|
||||
/-- Converts an `Int` to a `UInt8` by taking the (non-negative remainder of the division by `2 ^ 8`. -/
|
||||
def UInt8.ofInt (x : Int) : UInt8 := ofNat (x % 2 ^ 8).toNat
|
||||
|
||||
/--
|
||||
Adds two 8-bit unsigned integers, wrapping around on overflow. Usually accessed via the `+`
|
||||
operator.
|
||||
@@ -229,6 +232,9 @@ def UInt16.mk (bitVec : BitVec 16) : UInt16 :=
|
||||
def UInt16.ofNatCore (n : Nat) (h : n < UInt16.size) : UInt16 :=
|
||||
UInt16.ofNatLT n h
|
||||
|
||||
/-- Converts an `Int` to a `UInt16` by taking the (non-negative remainder of the division by `2 ^ 16`. -/
|
||||
def UInt16.ofInt (x : Int) : UInt16 := ofNat (x % 2 ^ 16).toNat
|
||||
|
||||
/--
|
||||
Adds two 16-bit unsigned integers, wrapping around on overflow. Usually accessed via the `+`
|
||||
operator.
|
||||
@@ -440,6 +446,9 @@ def UInt32.mk (bitVec : BitVec 32) : UInt32 :=
|
||||
def UInt32.ofNatCore (n : Nat) (h : n < UInt32.size) : UInt32 :=
|
||||
UInt32.ofNatLT n h
|
||||
|
||||
/-- Converts an `Int` to a `UInt32` by taking the (non-negative remainder of the division by `2 ^ 32`. -/
|
||||
def UInt32.ofInt (x : Int) : UInt32 := ofNat (x % 2 ^ 32).toNat
|
||||
|
||||
/--
|
||||
Adds two 32-bit unsigned integers, wrapping around on overflow. Usually accessed via the `+`
|
||||
operator.
|
||||
@@ -613,6 +622,9 @@ def UInt64.mk (bitVec : BitVec 64) : UInt64 :=
|
||||
def UInt64.ofNatCore (n : Nat) (h : n < UInt64.size) : UInt64 :=
|
||||
UInt64.ofNatLT n h
|
||||
|
||||
/-- Converts an `Int` to a `UInt64` by taking the (non-negative remainder of the division by `2 ^ 64`. -/
|
||||
def UInt64.ofInt (x : Int) : UInt64 := ofNat (x % 2 ^ 64).toNat
|
||||
|
||||
/--
|
||||
Adds two 64-bit unsigned integers, wrapping around on overflow. Usually accessed via the `+`
|
||||
operator.
|
||||
@@ -822,6 +834,9 @@ def USize.mk (bitVec : BitVec System.Platform.numBits) : USize :=
|
||||
def USize.ofNatCore (n : Nat) (h : n < USize.size) : USize :=
|
||||
USize.ofNatLT n h
|
||||
|
||||
/-- Converts an `Int` to a `USize` by taking the (non-negative remainder of the division by `2 ^ numBits`. -/
|
||||
def USize.ofInt (x : Int) : USize := ofNat (x % 2 ^ System.Platform.numBits).toNat
|
||||
|
||||
@[simp] theorem USize.le_size : 2 ^ 32 ≤ USize.size := by cases USize.size_eq <;> simp_all
|
||||
@[simp] theorem USize.size_le : USize.size ≤ 2 ^ 64 := by cases USize.size_eq <;> simp_all
|
||||
|
||||
|
||||
@@ -286,6 +286,17 @@ declare_uint_theorems USize System.Platform.numBits
|
||||
theorem USize.toNat_ofNat_of_lt_32 {n : Nat} (h : n < 4294967296) : toNat (ofNat n) = n :=
|
||||
toNat_ofNat_of_lt (Nat.lt_of_lt_of_le h USize.le_size)
|
||||
|
||||
theorem UInt8.ofNat_mod_size : ofNat (x % 2 ^ 8) = ofNat x := by
|
||||
simp [ofNat, BitVec.ofNat, Fin.ofNat']
|
||||
theorem UInt16.ofNat_mod_size : ofNat (x % 2 ^ 16) = ofNat x := by
|
||||
simp [ofNat, BitVec.ofNat, Fin.ofNat']
|
||||
theorem UInt32.ofNat_mod_size : ofNat (x % 2 ^ 32) = ofNat x := by
|
||||
simp [ofNat, BitVec.ofNat, Fin.ofNat']
|
||||
theorem UInt64.ofNat_mod_size : ofNat (x % 2 ^ 64) = ofNat x := by
|
||||
simp [ofNat, BitVec.ofNat, Fin.ofNat']
|
||||
theorem USize.ofNat_mod_size : ofNat (x % 2 ^ System.Platform.numBits) = ofNat x := by
|
||||
simp [ofNat, BitVec.ofNat, Fin.ofNat']
|
||||
|
||||
theorem UInt8.lt_ofNat_iff {n : UInt8} {m : Nat} (h : m < size) : n < ofNat m ↔ n.toNat < m := by
|
||||
rw [lt_iff_toNat_lt, toNat_ofNat_of_lt' h]
|
||||
theorem UInt8.ofNat_lt_iff {n : UInt8} {m : Nat} (h : m < size) : ofNat m < n ↔ m < n.toNat := by
|
||||
@@ -2081,6 +2092,23 @@ theorem USize.ofNat_eq_iff_mod_eq_toNat (a : Nat) (b : USize) : USize.ofNat a =
|
||||
USize.ofNatLT (a % b) (Nat.mod_lt_of_lt ha) = USize.ofNatLT a ha % USize.ofNatLT b hb := by
|
||||
simp [USize.ofNatLT_eq_ofNat, USize.ofNat_mod ha hb]
|
||||
|
||||
@[simp] theorem UInt8.ofInt_one : ofInt 1 = 1 := rfl
|
||||
@[simp] theorem UInt8.ofInt_neg_one : ofInt (-1) = -1 := rfl
|
||||
@[simp] theorem UInt16.ofInt_one : ofInt 1 = 1 := rfl
|
||||
@[simp] theorem UInt16.ofInt_neg_one : ofInt (-1) = -1 := rfl
|
||||
@[simp] theorem UInt32.ofInt_one : ofInt 1 = 1 := rfl
|
||||
@[simp] theorem UInt32.ofInt_neg_one : ofInt (-1) = -1 := rfl
|
||||
@[simp] theorem UInt64.ofInt_one : ofInt 1 = 1 := rfl
|
||||
@[simp] theorem UInt64.ofInt_neg_one : ofInt (-1) = -1 := rfl
|
||||
@[simp] theorem USize.ofInt_one : ofInt 1 = 1 := by
|
||||
rcases System.Platform.numBits_eq with h | h <;>
|
||||
· apply USize.toNat_inj.mp
|
||||
simp_all [USize.ofInt, USize.ofNat, size, toNat]
|
||||
@[simp] theorem USize.ofInt_neg_one : ofInt (-1) = -1 := by
|
||||
rcases System.Platform.numBits_eq with h | h <;>
|
||||
· apply USize.toNat_inj.mp
|
||||
simp_all [USize.ofInt, USize.ofNat, size, toNat]
|
||||
|
||||
@[simp] theorem UInt8.ofNat_add (a b : Nat) : UInt8.ofNat (a + b) = UInt8.ofNat a + UInt8.ofNat b := by
|
||||
simp [UInt8.ofNat_eq_iff_mod_eq_toNat]
|
||||
@[simp] theorem UInt16.ofNat_add (a b : Nat) : UInt16.ofNat (a + b) = UInt16.ofNat a + UInt16.ofNat b := by
|
||||
@@ -2092,6 +2120,70 @@ theorem USize.ofNat_eq_iff_mod_eq_toNat (a : Nat) (b : USize) : USize.ofNat a =
|
||||
@[simp] theorem USize.ofNat_add (a b : Nat) : USize.ofNat (a + b) = USize.ofNat a + USize.ofNat b := by
|
||||
simp [USize.ofNat_eq_iff_mod_eq_toNat]
|
||||
|
||||
@[simp] theorem UInt8.ofInt_add (x y : Int) : ofInt (x + y) = ofInt x + ofInt y := by
|
||||
dsimp only [UInt8.ofInt]
|
||||
rw [Int.add_emod]
|
||||
have h₁ : 0 ≤ x % 2 ^ 8 := Int.emod_nonneg _ (by decide)
|
||||
have h₂ : 0 ≤ y % 2 ^ 8 := Int.emod_nonneg _ (by decide)
|
||||
have h₃ : 0 ≤ x % 2 ^ 8 + y % 2 ^ 8 := Int.add_nonneg h₁ h₂
|
||||
rw [Int.toNat_emod h₃ (by decide), Int.toNat_add h₁ h₂]
|
||||
have : (2 ^ 8 : Int).toNat = 2 ^ 8 := rfl
|
||||
rw [this, UInt8.ofNat_mod_size, UInt8.ofNat_add]
|
||||
@[simp] theorem UInt16.ofInt_add (x y : Int) : UInt16.ofInt (x + y) = UInt16.ofInt x + UInt16.ofInt y := by
|
||||
dsimp only [UInt16.ofInt]
|
||||
rw [Int.add_emod]
|
||||
have h₁ : 0 ≤ x % 2 ^ 16 := Int.emod_nonneg _ (by decide)
|
||||
have h₂ : 0 ≤ y % 2 ^ 16 := Int.emod_nonneg _ (by decide)
|
||||
have h₃ : 0 ≤ x % 2 ^ 16 + y % 2 ^ 16 := Int.add_nonneg h₁ h₂
|
||||
rw [Int.toNat_emod h₃ (by decide), Int.toNat_add h₁ h₂]
|
||||
have : (2 ^ 16 : Int).toNat = 2 ^ 16 := rfl
|
||||
rw [this, UInt16.ofNat_mod_size, UInt16.ofNat_add]
|
||||
@[simp] theorem UInt32.ofInt_add (x y : Int) : UInt32.ofInt (x + y) = UInt32.ofInt x + UInt32.ofInt y := by
|
||||
dsimp only [UInt32.ofInt]
|
||||
rw [Int.add_emod]
|
||||
have h₁ : 0 ≤ x % 2 ^ 32 := Int.emod_nonneg _ (by decide)
|
||||
have h₂ : 0 ≤ y % 2 ^ 32 := Int.emod_nonneg _ (by decide)
|
||||
have h₃ : 0 ≤ x % 2 ^ 32 + y % 2 ^ 32 := Int.add_nonneg h₁ h₂
|
||||
rw [Int.toNat_emod h₃ (by decide), Int.toNat_add h₁ h₂]
|
||||
have : (2 ^ 32 : Int).toNat = 2 ^ 32 := rfl
|
||||
rw [this, UInt32.ofNat_mod_size, UInt32.ofNat_add]
|
||||
@[simp] theorem UInt64.ofInt_add (x y : Int) : UInt64.ofInt (x + y) = UInt64.ofInt x + UInt64.ofInt y := by
|
||||
dsimp only [UInt64.ofInt]
|
||||
rw [Int.add_emod]
|
||||
have h₁ : 0 ≤ x % 2 ^ 64 := Int.emod_nonneg _ (by decide)
|
||||
have h₂ : 0 ≤ y % 2 ^ 64 := Int.emod_nonneg _ (by decide)
|
||||
have h₃ : 0 ≤ x % 2 ^ 64 + y % 2 ^ 64 := Int.add_nonneg h₁ h₂
|
||||
rw [Int.toNat_emod h₃ (by decide), Int.toNat_add h₁ h₂]
|
||||
have : (2 ^ 64 : Int).toNat = 2 ^ 64 := rfl
|
||||
rw [this, UInt64.ofNat_mod_size, UInt64.ofNat_add]
|
||||
|
||||
namespace System.Platform
|
||||
|
||||
theorem two_pow_numBits_nonneg : 0 ≤ (2 ^ System.Platform.numBits : Int) := by
|
||||
rcases System.Platform.numBits_eq with h | h <;>
|
||||
· rw [h]
|
||||
decide
|
||||
theorem two_pow_numBits_ne_zero : (2 ^ System.Platform.numBits : Int) ≠ 0 := by
|
||||
rcases System.Platform.numBits_eq with h | h <;>
|
||||
· rw [h]
|
||||
decide
|
||||
|
||||
end System.Platform
|
||||
|
||||
open System.Platform in
|
||||
@[simp] theorem USize.ofInt_add (x y : Int) : USize.ofInt (x + y) = USize.ofInt x + USize.ofInt y := by
|
||||
dsimp only [USize.ofInt]
|
||||
rw [Int.add_emod]
|
||||
have h₁ : 0 ≤ x % 2 ^ numBits := Int.emod_nonneg _ two_pow_numBits_ne_zero
|
||||
have h₂ : 0 ≤ y % 2 ^ numBits := Int.emod_nonneg _ two_pow_numBits_ne_zero
|
||||
have h₃ : 0 ≤ x % 2 ^ numBits + y % 2 ^ numBits := Int.add_nonneg h₁ h₂
|
||||
rw [Int.toNat_emod h₃ two_pow_numBits_nonneg, Int.toNat_add h₁ h₂]
|
||||
have : (2 ^ numBits : Int).toNat = 2 ^ numBits := by
|
||||
rcases System.Platform.numBits_eq with h | h <;>
|
||||
· rw [h]
|
||||
decide
|
||||
rw [this, USize.ofNat_mod_size, USize.ofNat_add]
|
||||
|
||||
@[simp] theorem UInt8.ofNatLT_add {a b : Nat} (hab : a + b < 2 ^ 8) :
|
||||
UInt8.ofNatLT (a + b) hab = UInt8.ofNatLT a (Nat.lt_of_add_right_lt hab) + UInt8.ofNatLT b (Nat.lt_of_add_left_lt hab) := by
|
||||
simp [UInt8.ofNatLT_eq_ofNat]
|
||||
@@ -2176,6 +2268,56 @@ theorem USize.ofNatLT_sub {a b : Nat} (ha : a < 2 ^ System.Platform.numBits) (ha
|
||||
@[simp] theorem USize.ofNat_mul (a b : Nat) : USize.ofNat (a * b) = USize.ofNat a * USize.ofNat b := by
|
||||
simp [USize.ofNat_eq_iff_mod_eq_toNat]
|
||||
|
||||
@[simp] theorem UInt8.ofInt_mul (x y : Int) : ofInt (x * y) = ofInt x * ofInt y := by
|
||||
dsimp only [UInt8.ofInt]
|
||||
rw [Int.mul_emod]
|
||||
have h₁ : 0 ≤ x % 2 ^ 8 := Int.emod_nonneg _ (by decide)
|
||||
have h₂ : 0 ≤ y % 2 ^ 8 := Int.emod_nonneg _ (by decide)
|
||||
have h₃ : 0 ≤ (x % 2 ^ 8) * (y % 2 ^ 8) := Int.mul_nonneg h₁ h₂
|
||||
rw [Int.toNat_emod h₃ (by decide), Int.toNat_mul h₁ h₂]
|
||||
have : (2 ^ 8 : Int).toNat = 2 ^ 8 := rfl
|
||||
rw [this, UInt8.ofNat_mod_size, UInt8.ofNat_mul]
|
||||
@[simp] theorem UInt16.ofInt_mul (x y : Int) : ofInt (x * y) = ofInt x * ofInt y := by
|
||||
dsimp only [UInt16.ofInt]
|
||||
rw [Int.mul_emod]
|
||||
have h₁ : 0 ≤ x % 2 ^ 16 := Int.emod_nonneg _ (by decide)
|
||||
have h₂ : 0 ≤ y % 2 ^ 16 := Int.emod_nonneg _ (by decide)
|
||||
have h₃ : 0 ≤ (x % 2 ^ 16) * (y % 2 ^ 16) := Int.mul_nonneg h₁ h₂
|
||||
rw [Int.toNat_emod h₃ (by decide), Int.toNat_mul h₁ h₂]
|
||||
have : (2 ^ 16 : Int).toNat = 2 ^ 16 := rfl
|
||||
rw [this, UInt16.ofNat_mod_size, UInt16.ofNat_mul]
|
||||
@[simp] theorem UInt32.ofInt_mul (x y : Int) : ofInt (x * y) = ofInt x * ofInt y := by
|
||||
dsimp only [UInt32.ofInt]
|
||||
rw [Int.mul_emod]
|
||||
have h₁ : 0 ≤ x % 2 ^ 32 := Int.emod_nonneg _ (by decide)
|
||||
have h₂ : 0 ≤ y % 2 ^ 32 := Int.emod_nonneg _ (by decide)
|
||||
have h₃ : 0 ≤ (x % 2 ^ 32) * (y % 2 ^ 32) := Int.mul_nonneg h₁ h₂
|
||||
rw [Int.toNat_emod h₃ (by decide), Int.toNat_mul h₁ h₂]
|
||||
have : (2 ^ 32 : Int).toNat = 2 ^ 32 := rfl
|
||||
rw [this, UInt32.ofNat_mod_size, UInt32.ofNat_mul]
|
||||
@[simp] theorem UInt64.ofInt_mul (x y : Int) : ofInt (x * y) = ofInt x * ofInt y := by
|
||||
dsimp only [UInt64.ofInt]
|
||||
rw [Int.mul_emod]
|
||||
have h₁ : 0 ≤ x % 2 ^ 64 := Int.emod_nonneg _ (by decide)
|
||||
have h₂ : 0 ≤ y % 2 ^ 64 := Int.emod_nonneg _ (by decide)
|
||||
have h₃ : 0 ≤ (x % 2 ^ 64) * (y % 2 ^ 64) := Int.mul_nonneg h₁ h₂
|
||||
rw [Int.toNat_emod h₃ (by decide), Int.toNat_mul h₁ h₂]
|
||||
have : (2 ^ 64 : Int).toNat = 2 ^ 64 := rfl
|
||||
rw [this, UInt64.ofNat_mod_size, UInt64.ofNat_mul]
|
||||
open System.Platform in
|
||||
@[simp] theorem USize.ofInt_mul (x y : Int) : ofInt (x * y) = ofInt x * ofInt y := by
|
||||
dsimp only [USize.ofInt]
|
||||
rw [Int.mul_emod]
|
||||
have h₁ : 0 ≤ x % 2 ^ numBits := Int.emod_nonneg _ two_pow_numBits_ne_zero
|
||||
have h₂ : 0 ≤ y % 2 ^ numBits := Int.emod_nonneg _ two_pow_numBits_ne_zero
|
||||
have h₃ : 0 ≤ (x % 2 ^ numBits) * (y % 2 ^ numBits) := Int.mul_nonneg h₁ h₂
|
||||
rw [Int.toNat_emod h₃ two_pow_numBits_nonneg, Int.toNat_mul h₁ h₂]
|
||||
have : (2 ^ numBits : Int).toNat = 2 ^ numBits := by
|
||||
rcases System.Platform.numBits_eq with h | h <;>
|
||||
· rw [h]
|
||||
decide
|
||||
rw [this, USize.ofNat_mod_size, USize.ofNat_mul]
|
||||
|
||||
@[simp] theorem UInt8.ofNatLT_mul {a b : Nat} (ha : a < 2 ^ 8) (hb : b < 2 ^ 8) (hab : a * b < 2 ^ 8) :
|
||||
UInt8.ofNatLT (a * b) hab = UInt8.ofNatLT a ha * UInt8.ofNatLT b hb := by
|
||||
simp [UInt8.ofNatLT_eq_ofNat]
|
||||
@@ -2467,6 +2609,17 @@ protected theorem USize.neg_add {a b : USize} : - (a + b) = -a - b := USize.toBi
|
||||
@[simp] protected theorem USize.neg_sub {a b : USize} : -(a - b) = b - a := by
|
||||
rw [USize.sub_eq_add_neg, USize.neg_add, USize.sub_neg, USize.add_comm, ← USize.sub_eq_add_neg]
|
||||
|
||||
@[simp] protected theorem UInt8.ofInt_neg (x : Int) : ofInt (-x) = -ofInt x := by
|
||||
rw [Int.neg_eq_neg_one_mul, ofInt_mul, ofInt_neg_one, ← UInt8.neg_eq_neg_one_mul]
|
||||
@[simp] protected theorem UInt16.ofInt_neg (x : Int) : ofInt (-x) = -ofInt x := by
|
||||
rw [Int.neg_eq_neg_one_mul, ofInt_mul, ofInt_neg_one, ← UInt16.neg_eq_neg_one_mul]
|
||||
@[simp] protected theorem UInt32.ofInt_neg (x : Int) : ofInt (-x) = -ofInt x := by
|
||||
rw [Int.neg_eq_neg_one_mul, ofInt_mul, ofInt_neg_one, ← UInt32.neg_eq_neg_one_mul]
|
||||
@[simp] protected theorem UInt64.ofInt_neg (x : Int) : ofInt (-x) = -ofInt x := by
|
||||
rw [Int.neg_eq_neg_one_mul, ofInt_mul, ofInt_neg_one, ← UInt64.neg_eq_neg_one_mul]
|
||||
@[simp] protected theorem USize.ofInt_neg (x : Int) : ofInt (-x) = -ofInt x := by
|
||||
rw [Int.neg_eq_neg_one_mul, ofInt_mul, ofInt_neg_one, ← USize.neg_eq_neg_one_mul]
|
||||
|
||||
@[simp] protected theorem UInt8.add_left_inj {a b : UInt8} (c : UInt8) : (a + c = b + c) ↔ a = b := by
|
||||
simp [← UInt8.toBitVec_inj]
|
||||
@[simp] protected theorem UInt16.add_left_inj {a b : UInt16} (c : UInt16) : (a + c = b + c) ↔ a = b := by
|
||||
|
||||
@@ -239,4 +239,15 @@ theorem count_flatMap {α} [BEq β] {xs : Vector α n} {f : α → Vector β m}
|
||||
rcases xs with ⟨xs, rfl⟩
|
||||
simp [Array.count_flatMap, Function.comp_def]
|
||||
|
||||
theorem countP_replace {a b : α} {xs : Vector α n} {p : α → Bool} :
|
||||
(xs.replace a b).countP p =
|
||||
if xs.contains a then xs.countP p + (if p b then 1 else 0) - (if p a then 1 else 0) else xs.countP p := by
|
||||
rcases xs with ⟨xs, rfl⟩
|
||||
simp [Array.countP_replace]
|
||||
|
||||
theorem count_replace {a b c : α} {xs : Vector α n} :
|
||||
(xs.replace a b).count c =
|
||||
if xs.contains a then xs.count c + (if b == c then 1 else 0) - (if a == c then 1 else 0) else xs.count c := by
|
||||
simp [count_eq_countP, countP_replace]
|
||||
|
||||
end count
|
||||
|
||||
@@ -294,6 +294,18 @@ theorem find?_eq_some_iff_getElem {xs : Vector α n} {p : α → Bool} {b : α}
|
||||
subst w
|
||||
simp
|
||||
|
||||
@[simp]
|
||||
theorem isSome_findFinIdx? {xs : Vector α n} {p : α → Bool} :
|
||||
(xs.findFinIdx? p).isSome = xs.any p := by
|
||||
rcases xs with ⟨xs, rfl⟩
|
||||
simp
|
||||
|
||||
@[simp]
|
||||
theorem isNone_findFinIdx? {xs : Vector α n} {p : α → Bool} :
|
||||
(xs.findFinIdx? p).isNone = xs.all (fun x => ¬ p x) := by
|
||||
rcases xs with ⟨xs, rfl⟩
|
||||
simp
|
||||
|
||||
@[simp] theorem findFinIdx?_subtype {p : α → Prop} {xs : Vector { x // p x } n}
|
||||
{f : { x // p x } → Bool} {g : α → Bool} (hf : ∀ x h, f ⟨x, h⟩ = g x) :
|
||||
xs.findFinIdx? f = xs.unattach.findFinIdx? g := by
|
||||
|
||||
@@ -1511,7 +1511,7 @@ theorem map_eq_iff {f : α → β} {as : Vector α n} {bs : Vector β n} :
|
||||
if h : i < as.size then
|
||||
simpa [h, h'] using w i h
|
||||
else
|
||||
rw [getElem?_neg, getElem?_neg, Option.map_none'] <;> omega
|
||||
rw [getElem?_neg, getElem?_neg, Option.map_none] <;> omega
|
||||
|
||||
@[simp] theorem map_set {f : α → β} {xs : Vector α n} {i : Nat} {h : i < n} {a : α} :
|
||||
(xs.set i a).map f = (xs.map f).set i (f a) (by simpa using h) := by
|
||||
|
||||
@@ -5,14 +5,30 @@ Authors: Kim Morrison
|
||||
-/
|
||||
prelude
|
||||
import Init.Data.Zero
|
||||
import Init.Data.Int.DivMod.Lemmas
|
||||
import Init.TacticsExtra
|
||||
|
||||
/-!
|
||||
# A monolithic commutative ring typeclass for internal use in `grind`.
|
||||
|
||||
The `Lean.Grind.CommRing` class will be used to convert expressions into the internal representation via polynomials,
|
||||
with coefficients expressed via `OfNat` and `Neg`.
|
||||
|
||||
The `IsCharP α p` typeclass expresses that the ring has characteristic `p`,
|
||||
i.e. that a coefficient `OfNat.ofNat x : α` is zero if and only if `x % p = 0` (in `Nat`).
|
||||
See
|
||||
```
|
||||
theorem ofNat_ext_iff {x y : Nat} : OfNat.ofNat (α := α) x = OfNat.ofNat (α := α) y ↔ x % p = y % p
|
||||
theorem ofNat_emod (x : Nat) : OfNat.ofNat (α := α) (x % p) = OfNat.ofNat x
|
||||
theorem ofNat_eq_iff_of_lt {x y : Nat} (h₁ : x < p) (h₂ : y < p) :
|
||||
OfNat.ofNat (α := α) x = OfNat.ofNat (α := α) y ↔ x = y
|
||||
```
|
||||
-/
|
||||
|
||||
namespace Lean.Grind
|
||||
|
||||
class CommRing (α : Type u) extends Add α, Zero α, Mul α, One α, Neg α where
|
||||
class CommRing (α : Type u) extends Add α, Mul α, Neg α, Sub α, HPow α Nat α where
|
||||
[ofNat : ∀ n, OfNat α n]
|
||||
add_assoc : ∀ a b c : α, a + b + c = a + (b + c)
|
||||
add_comm : ∀ a b : α, a + b = b + a
|
||||
add_zero : ∀ a : α, a + 0 = a
|
||||
@@ -22,11 +38,31 @@ class CommRing (α : Type u) extends Add α, Zero α, Mul α, One α, Neg α whe
|
||||
mul_one : ∀ a : α, a * 1 = a
|
||||
left_distrib : ∀ a b c : α, a * (b + c) = a * b + a * c
|
||||
zero_mul : ∀ a : α, 0 * a = 0
|
||||
sub_eq_add_neg : ∀ a b : α, a - b = a + -b
|
||||
pow_zero : ∀ a : α, a ^ 0 = 1
|
||||
pow_succ : ∀ a : α, ∀ n : Nat, a ^ (n + 1) = (a ^ n) * a
|
||||
ofNat_succ : ∀ a : Nat, OfNat.ofNat (α := α) (a + 1) = OfNat.ofNat a + 1 := by intros; rfl
|
||||
|
||||
-- This is a low-priority instance, to avoid conflicts with existing `OfNat` instances.
|
||||
attribute [instance 100] CommRing.ofNat
|
||||
|
||||
namespace CommRing
|
||||
|
||||
variable {α : Type u} [CommRing α]
|
||||
|
||||
instance : NatCast α where
|
||||
natCast n := OfNat.ofNat n
|
||||
|
||||
theorem natCast_zero : ((0 : Nat) : α) = 0 := rfl
|
||||
theorem ofNat_eq_natCast (n : Nat) : OfNat.ofNat n = (n : α) := rfl
|
||||
|
||||
theorem ofNat_add (a b : Nat) : OfNat.ofNat (α := α) (a + b) = OfNat.ofNat a + OfNat.ofNat b := by
|
||||
induction b with
|
||||
| zero => simp [Nat.add_zero, add_zero]
|
||||
| succ b ih => rw [Nat.add_succ, ofNat_succ, ih, ofNat_succ b, add_assoc]
|
||||
|
||||
theorem natCast_succ (n : Nat) : ((n + 1 : Nat) : α) = ((n : α) + 1) := ofNat_add _ _
|
||||
|
||||
theorem zero_add (a : α) : 0 + a = a := by
|
||||
rw [add_comm, add_zero]
|
||||
|
||||
@@ -42,6 +78,204 @@ theorem right_distrib (a b c : α) : (a + b) * c = a * c + b * c := by
|
||||
theorem mul_zero (a : α) : a * 0 = 0 := by
|
||||
rw [mul_comm, zero_mul]
|
||||
|
||||
theorem ofNat_mul (a b : Nat) : OfNat.ofNat (α := α) (a * b) = OfNat.ofNat a * OfNat.ofNat b := by
|
||||
induction b with
|
||||
| zero => simp [Nat.mul_zero, mul_zero]
|
||||
| succ a ih => rw [Nat.mul_succ, ofNat_add, ih, ofNat_add, left_distrib, mul_one]
|
||||
|
||||
theorem add_left_inj {a b : α} (c : α) : a + c = b + c ↔ a = b :=
|
||||
⟨fun h => by simpa [add_assoc, add_neg_cancel, add_zero] using (congrArg (· + -c) h),
|
||||
fun g => congrArg (· + c) g⟩
|
||||
|
||||
theorem add_right_inj (a b c : α) : a + b = a + c ↔ b = c := by
|
||||
rw [add_comm a b, add_comm a c, add_left_inj]
|
||||
|
||||
theorem neg_zero : (-0 : α) = 0 := by
|
||||
rw [← add_left_inj 0, neg_add_cancel, add_zero]
|
||||
|
||||
theorem neg_neg (a : α) : -(-a) = a := by
|
||||
rw [← add_left_inj (-a), neg_add_cancel, add_neg_cancel]
|
||||
|
||||
theorem neg_eq_zero (a : α) : -a = 0 ↔ a = 0 :=
|
||||
⟨fun h => by
|
||||
replace h := congrArg (-·) h
|
||||
simpa [neg_neg, neg_zero] using h,
|
||||
fun h => by rw [h, neg_zero]⟩
|
||||
|
||||
theorem neg_add (a b : α) : -(a + b) = -a + -b := by
|
||||
rw [← add_left_inj (a + b), neg_add_cancel, add_assoc (-a), add_comm a b, ← add_assoc (-b),
|
||||
neg_add_cancel, zero_add, neg_add_cancel]
|
||||
|
||||
theorem neg_sub (a b : α) : -(a - b) = b - a := by
|
||||
rw [sub_eq_add_neg, neg_add, neg_neg, sub_eq_add_neg, add_comm]
|
||||
|
||||
theorem sub_self (a : α) : a - a = 0 := by
|
||||
rw [sub_eq_add_neg, add_neg_cancel]
|
||||
|
||||
instance : IntCast α where
|
||||
intCast n := match n with
|
||||
| Int.ofNat n => OfNat.ofNat n
|
||||
| Int.negSucc n => -OfNat.ofNat (n + 1)
|
||||
|
||||
theorem intCast_zero : ((0 : Int) : α) = 0 := rfl
|
||||
theorem intCast_one : ((1 : Int) : α) = 1 := rfl
|
||||
theorem intCast_neg_one : ((-1 : Int) : α) = -1 := rfl
|
||||
theorem intCast_ofNat (n : Nat) : ((n : Int) : α) = (n : α) := rfl
|
||||
theorem intCast_ofNat_add_one (n : Nat) : ((n + 1 : Int) : α) = (n : α) + 1 := ofNat_add _ _
|
||||
theorem intCast_negSucc (n : Nat) : ((-(n + 1) : Int) : α) = -((n : α) + 1) := congrArg (- ·) (ofNat_add _ _)
|
||||
theorem intCast_neg (x : Int) : ((-x : Int) : α) = - (x : α) :=
|
||||
match x with
|
||||
| (0 : Nat) => neg_zero.symm
|
||||
| (n + 1 : Nat) => by
|
||||
rw [Int.natCast_add, Int.cast_ofNat_Int, intCast_negSucc, intCast_ofNat_add_one]
|
||||
| -((n : Nat) + 1) => by
|
||||
rw [Int.neg_neg, intCast_ofNat_add_one, intCast_negSucc, neg_neg]
|
||||
theorem intCast_nat_add {x y : Nat} : ((x + y : Int) : α) = ((x : α) + (y : α)) := ofNat_add _ _
|
||||
theorem intCast_nat_sub {x y : Nat} (h : x ≥ y) : (((x - y : Nat) : Int) : α) = ((x : α) - (y : α)) := by
|
||||
induction x with
|
||||
| zero =>
|
||||
have : y = 0 := by omega
|
||||
simp [this, intCast_zero, natCast_zero, sub_eq_add_neg, zero_add, neg_zero]
|
||||
| succ x ih =>
|
||||
by_cases h : x + 1 = y
|
||||
· simp [h, intCast_zero, sub_self]
|
||||
· have : ((x + 1 - y : Nat) : Int) = (x - y : Nat) + 1 := by omega
|
||||
rw [this, intCast_ofNat_add_one]
|
||||
specialize ih (by omega)
|
||||
rw [intCast_ofNat] at ih
|
||||
rw [ih, natCast_succ, sub_eq_add_neg, sub_eq_add_neg, add_assoc, add_comm _ 1, ← add_assoc]
|
||||
theorem intCast_add (x y : Int) : ((x + y : Int) : α) = ((x : α) + (y : α)) :=
|
||||
match x, y with
|
||||
| (x : Nat), (y : Nat) => ofNat_add _ _
|
||||
| (x : Nat), (-(y + 1 : Nat)) => by
|
||||
by_cases h : x ≥ y + 1
|
||||
· have : (x + -(y+1 : Nat) : Int) = ((x - (y + 1) : Nat) : Int) := by omega
|
||||
rw [this, intCast_neg, intCast_nat_sub h, intCast_ofNat, intCast_ofNat, sub_eq_add_neg]
|
||||
· have : (x + -(y+1 : Nat) : Int) = (-(y + 1 - x : Nat) : Int) := by omega
|
||||
rw [this, intCast_neg, intCast_nat_sub (by omega), intCast_ofNat, intCast_neg, intCast_ofNat,
|
||||
neg_sub, sub_eq_add_neg]
|
||||
| (-(x + 1 : Nat)), (y : Nat) => by
|
||||
by_cases h : y ≥ x+ 1
|
||||
· have : (-(x+1 : Nat) + y : Int) = ((y - (x + 1) : Nat) : Int) := by omega
|
||||
rw [this, intCast_neg, intCast_nat_sub h, intCast_ofNat, intCast_ofNat, sub_eq_add_neg, add_comm]
|
||||
· have : (-(x+1 : Nat) + y : Int) = (-(x + 1 - y : Nat) : Int) := by omega
|
||||
rw [this, intCast_neg, intCast_nat_sub (by omega), intCast_ofNat, intCast_neg, intCast_ofNat,
|
||||
neg_sub, sub_eq_add_neg, add_comm]
|
||||
| (-(x + 1 : Nat)), (-(y + 1 : Nat)) => by
|
||||
rw [← Int.neg_add, intCast_neg, intCast_nat_add, neg_add, intCast_neg, intCast_neg, intCast_ofNat, intCast_ofNat]
|
||||
theorem intCast_sub (x y : Int) : ((x - y : Int) : α) = ((x : α) - (y : α)) := by
|
||||
rw [Int.sub_eq_add_neg, intCast_add, intCast_neg, sub_eq_add_neg]
|
||||
|
||||
theorem neg_eq_neg_one_mul (a : α) : -a = (-1) * a := by
|
||||
rw [← add_left_inj a, neg_add_cancel]
|
||||
conv => rhs; arg 2; rw [← one_mul a]
|
||||
rw [← right_distrib, ← intCast_neg_one, ← intCast_one (α := α)]
|
||||
simp [← intCast_add, intCast_zero, zero_mul]
|
||||
|
||||
theorem neg_mul (a b : α) : (-a) * b = -(a * b) := by
|
||||
rw [neg_eq_neg_one_mul a, neg_eq_neg_one_mul (a * b), mul_assoc]
|
||||
|
||||
theorem mul_neg (a b : α) : a * (-b) = -(a * b) := by
|
||||
rw [mul_comm, neg_mul, mul_comm]
|
||||
|
||||
theorem intCast_nat_mul (x y : Nat) : ((x * y : Int) : α) = ((x : α) * (y : α)) := ofNat_mul _ _
|
||||
theorem intCast_mul (x y : Int) : ((x * y : Int) : α) = ((x : α) * (y : α)) :=
|
||||
match x, y with
|
||||
| (x : Nat), (y : Nat) => ofNat_mul _ _
|
||||
| (x : Nat), (-(y + 1 : Nat)) => by
|
||||
rw [Int.mul_neg, intCast_neg, intCast_nat_mul, intCast_neg, mul_neg, intCast_ofNat, intCast_ofNat]
|
||||
| (-(x + 1 : Nat)), (y : Nat) => by
|
||||
rw [Int.neg_mul, intCast_neg, intCast_nat_mul, intCast_neg, neg_mul, intCast_ofNat, intCast_ofNat]
|
||||
| (-(x + 1 : Nat)), (-(y + 1 : Nat)) => by
|
||||
rw [Int.neg_mul_neg, intCast_neg, intCast_neg, neg_mul, mul_neg, neg_neg, intCast_nat_mul,
|
||||
intCast_ofNat, intCast_ofNat]
|
||||
|
||||
end CommRing
|
||||
|
||||
open CommRing
|
||||
|
||||
class IsCharP (α : Type u) [CommRing α] (p : Nat) where
|
||||
ofNat_eq_zero_iff (p) : ∀ (x : Nat), OfNat.ofNat (α := α) x = 0 ↔ x % p = 0
|
||||
|
||||
namespace IsCharP
|
||||
|
||||
variable (p) {α : Type u} [CommRing α] [IsCharP α p]
|
||||
|
||||
theorem natCast_eq_zero_iff (x : Nat) : (x : α) = 0 ↔ x % p = 0 :=
|
||||
ofNat_eq_zero_iff p x
|
||||
|
||||
theorem intCast_eq_zero_iff (x : Int) : (x : α) = 0 ↔ x % p = 0 :=
|
||||
match x with
|
||||
| (x : Nat) => by
|
||||
have := ofNat_eq_zero_iff (α := α) p (x := x)
|
||||
rw [Int.ofNat_mod_ofNat]
|
||||
norm_cast
|
||||
| -(x + 1 : Nat) => by
|
||||
rw [Int.neg_emod, Int.ofNat_mod_ofNat, intCast_neg, intCast_ofNat, neg_eq_zero]
|
||||
have := ofNat_eq_zero_iff (α := α) p (x := x + 1)
|
||||
rw [ofNat_eq_natCast] at this
|
||||
rw [this]
|
||||
simp only [Int.ofNat_dvd]
|
||||
simp only [← Nat.dvd_iff_mod_eq_zero, Int.natAbs_ofNat, Int.natCast_add,
|
||||
Int.cast_ofNat_Int, ite_eq_left_iff]
|
||||
by_cases h : p ∣ x + 1
|
||||
· simp [h]
|
||||
· simp only [h, not_false_eq_true, Int.natCast_add, Int.cast_ofNat_Int,
|
||||
forall_const, false_iff, ne_eq]
|
||||
by_cases w : p = 0
|
||||
· simp [w]
|
||||
omega
|
||||
· have : ((x + 1) % p) < p := Nat.mod_lt _ (by omega)
|
||||
omega
|
||||
|
||||
theorem intCast_ext_iff {x y : Int} : (x : α) = (y : α) ↔ x % p = y % p := by
|
||||
constructor
|
||||
· intro h
|
||||
replace h : ((x - y : Int) : α) = 0 := by rw [intCast_sub, h, sub_self]
|
||||
exact Int.emod_eq_emod_iff_emod_sub_eq_zero.mpr ((intCast_eq_zero_iff p _).mp h)
|
||||
· intro h
|
||||
have : ((x - y : Int) : α) = 0 :=
|
||||
(intCast_eq_zero_iff p _).mpr (by rw [Int.sub_emod, h, Int.sub_self, Int.zero_emod])
|
||||
replace this := congrArg (· + (y : α)) this
|
||||
simpa [intCast_sub, zero_add, sub_eq_add_neg, add_assoc, neg_add_cancel, add_zero] using this
|
||||
|
||||
theorem ofNat_ext_iff {x y : Nat} : OfNat.ofNat (α := α) x = OfNat.ofNat (α := α) y ↔ x % p = y % p := by
|
||||
have := intCast_ext_iff (α := α) p (x := x) (y := y)
|
||||
simp only [intCast_ofNat, ← Int.ofNat_emod] at this
|
||||
simp only [ofNat_eq_natCast]
|
||||
norm_cast at this
|
||||
|
||||
theorem ofNat_ext {x y : Nat} (h : x % p = y % p) : OfNat.ofNat (α := α) x = OfNat.ofNat (α := α) y := (ofNat_ext_iff p).mpr h
|
||||
|
||||
theorem natCast_ext {x y : Nat} (h : x % p = y % p) : (x : α) = (y : α) := ofNat_ext _ h
|
||||
|
||||
theorem natCast_ext_iff {x y : Nat} : (x : α) = (y : α) ↔ x % p = y % p :=
|
||||
ofNat_ext_iff p
|
||||
|
||||
theorem intCast_emod (x : Int) : ((x % p : Int) : α) = (x : α) := by
|
||||
rw [intCast_ext_iff p, Int.emod_emod]
|
||||
|
||||
theorem natCast_emod (x : Nat) : ((x % p : Nat) : α) = (x : α) := by
|
||||
simp only [← intCast_ofNat]
|
||||
rw [Int.ofNat_emod, intCast_emod]
|
||||
|
||||
theorem ofNat_emod (x : Nat) : OfNat.ofNat (α := α) (x % p) = OfNat.ofNat x :=
|
||||
natCast_emod _ _
|
||||
|
||||
theorem ofNat_eq_zero_iff_of_lt {x : Nat} (h : x < p) : OfNat.ofNat (α := α) x = 0 ↔ x = 0 := by
|
||||
rw [ofNat_eq_zero_iff p, Nat.mod_eq_of_lt h]
|
||||
|
||||
theorem ofNat_eq_iff_of_lt {x y : Nat} (h₁ : x < p) (h₂ : y < p) :
|
||||
OfNat.ofNat (α := α) x = OfNat.ofNat (α := α) y ↔ x = y := by
|
||||
rw [ofNat_ext_iff p, Nat.mod_eq_of_lt h₁, Nat.mod_eq_of_lt h₂]
|
||||
|
||||
theorem natCast_eq_zero_iff_of_lt {x : Nat} (h : x < p) : (x : α) = 0 ↔ x = 0 := by
|
||||
rw [natCast_eq_zero_iff p, Nat.mod_eq_of_lt h]
|
||||
|
||||
theorem natCast_eq_iff_of_lt {x y : Nat} (h₁ : x < p) (h₂ : y < p) :
|
||||
(x : α) = (y : α) ↔ x = y := by
|
||||
rw [natCast_ext_iff p, Nat.mod_eq_of_lt h₁, Nat.mod_eq_of_lt h₂]
|
||||
|
||||
end IsCharP
|
||||
|
||||
end Lean.Grind
|
||||
|
||||
@@ -19,5 +19,12 @@ instance : CommRing (BitVec w) where
|
||||
mul_one := BitVec.mul_one
|
||||
left_distrib _ _ _ := BitVec.mul_add
|
||||
zero_mul _ := BitVec.zero_mul
|
||||
sub_eq_add_neg := BitVec.sub_eq_add_neg
|
||||
pow_zero _ := BitVec.pow_zero
|
||||
pow_succ _ _ := BitVec.pow_succ
|
||||
ofNat_succ x := BitVec.ofNat_add x 1
|
||||
|
||||
instance : IsCharP (BitVec w) (2 ^ w) where
|
||||
ofNat_eq_zero_iff {x} := by simp [BitVec.ofInt, BitVec.toNat_eq]
|
||||
|
||||
end Lean.Grind
|
||||
|
||||
@@ -19,5 +19,12 @@ instance : CommRing Int where
|
||||
mul_one := Int.mul_one
|
||||
left_distrib := Int.mul_add
|
||||
zero_mul := Int.zero_mul
|
||||
pow_zero _ := rfl
|
||||
pow_succ _ _ := rfl
|
||||
ofNat_succ _ := rfl
|
||||
sub_eq_add_neg _ _ := Int.sub_eq_add_neg
|
||||
|
||||
instance : IsCharP Int 0 where
|
||||
ofNat_eq_zero_iff {x} := by erw [Int.ofNat_eq_zero]; simp
|
||||
|
||||
end Lean.Grind
|
||||
|
||||
@@ -9,6 +9,9 @@ import Init.Data.SInt.Lemmas
|
||||
|
||||
namespace Lean.Grind
|
||||
|
||||
instance : IntCast Int8 where
|
||||
intCast x := Int8.ofInt x
|
||||
|
||||
instance : CommRing Int8 where
|
||||
add_assoc := Int8.add_assoc
|
||||
add_comm := Int8.add_comm
|
||||
@@ -19,6 +22,20 @@ instance : CommRing Int8 where
|
||||
mul_one := Int8.mul_one
|
||||
left_distrib _ _ _ := Int8.mul_add
|
||||
zero_mul _ := Int8.zero_mul
|
||||
sub_eq_add_neg := Int8.sub_eq_add_neg
|
||||
pow_zero := Int8.pow_zero
|
||||
pow_succ := Int8.pow_succ
|
||||
ofNat_succ x := Int8.ofNat_add x 1
|
||||
|
||||
instance : IsCharP Int8 (2 ^ 8) where
|
||||
ofNat_eq_zero_iff {x} := by
|
||||
have : OfNat.ofNat x = Int8.ofInt x := rfl
|
||||
rw [this]
|
||||
simp [Int8.ofInt_eq_iff_bmod_eq_toInt,
|
||||
← Int.dvd_iff_bmod_eq_zero, ← Nat.dvd_iff_mod_eq_zero, Int.ofNat_dvd_right]
|
||||
|
||||
instance : IntCast Int16 where
|
||||
intCast x := Int16.ofInt x
|
||||
|
||||
instance : CommRing Int16 where
|
||||
add_assoc := Int16.add_assoc
|
||||
@@ -30,6 +47,20 @@ instance : CommRing Int16 where
|
||||
mul_one := Int16.mul_one
|
||||
left_distrib _ _ _ := Int16.mul_add
|
||||
zero_mul _ := Int16.zero_mul
|
||||
sub_eq_add_neg := Int16.sub_eq_add_neg
|
||||
pow_zero := Int16.pow_zero
|
||||
pow_succ := Int16.pow_succ
|
||||
ofNat_succ x := Int16.ofNat_add x 1
|
||||
|
||||
instance : IsCharP Int16 (2 ^ 16) where
|
||||
ofNat_eq_zero_iff {x} := by
|
||||
have : OfNat.ofNat x = Int16.ofInt x := rfl
|
||||
rw [this]
|
||||
simp [Int16.ofInt_eq_iff_bmod_eq_toInt,
|
||||
← Int.dvd_iff_bmod_eq_zero, ← Nat.dvd_iff_mod_eq_zero, Int.ofNat_dvd_right]
|
||||
|
||||
instance : IntCast Int32 where
|
||||
intCast x := Int32.ofInt x
|
||||
|
||||
instance : CommRing Int32 where
|
||||
add_assoc := Int32.add_assoc
|
||||
@@ -41,6 +72,20 @@ instance : CommRing Int32 where
|
||||
mul_one := Int32.mul_one
|
||||
left_distrib _ _ _ := Int32.mul_add
|
||||
zero_mul _ := Int32.zero_mul
|
||||
sub_eq_add_neg := Int32.sub_eq_add_neg
|
||||
pow_zero := Int32.pow_zero
|
||||
pow_succ := Int32.pow_succ
|
||||
ofNat_succ x := Int32.ofNat_add x 1
|
||||
|
||||
instance : IsCharP Int32 (2 ^ 32) where
|
||||
ofNat_eq_zero_iff {x} := by
|
||||
have : OfNat.ofNat x = Int32.ofInt x := rfl
|
||||
rw [this]
|
||||
simp [Int32.ofInt_eq_iff_bmod_eq_toInt,
|
||||
← Int.dvd_iff_bmod_eq_zero, ← Nat.dvd_iff_mod_eq_zero, Int.ofNat_dvd_right]
|
||||
|
||||
instance : IntCast Int64 where
|
||||
intCast x := Int64.ofInt x
|
||||
|
||||
instance : CommRing Int64 where
|
||||
add_assoc := Int64.add_assoc
|
||||
@@ -52,6 +97,20 @@ instance : CommRing Int64 where
|
||||
mul_one := Int64.mul_one
|
||||
left_distrib _ _ _ := Int64.mul_add
|
||||
zero_mul _ := Int64.zero_mul
|
||||
sub_eq_add_neg := Int64.sub_eq_add_neg
|
||||
pow_zero := Int64.pow_zero
|
||||
pow_succ := Int64.pow_succ
|
||||
ofNat_succ x := Int64.ofNat_add x 1
|
||||
|
||||
instance : IsCharP Int64 (2 ^ 64) where
|
||||
ofNat_eq_zero_iff {x} := by
|
||||
have : OfNat.ofNat x = Int64.ofInt x := rfl
|
||||
rw [this]
|
||||
simp [Int64.ofInt_eq_iff_bmod_eq_toInt,
|
||||
← Int.dvd_iff_bmod_eq_zero, ← Nat.dvd_iff_mod_eq_zero, Int.ofNat_dvd_right]
|
||||
|
||||
instance : IntCast ISize where
|
||||
intCast x := ISize.ofInt x
|
||||
|
||||
instance : CommRing ISize where
|
||||
add_assoc := ISize.add_assoc
|
||||
@@ -63,5 +122,18 @@ instance : CommRing ISize where
|
||||
mul_one := ISize.mul_one
|
||||
left_distrib _ _ _ := ISize.mul_add
|
||||
zero_mul _ := ISize.zero_mul
|
||||
sub_eq_add_neg := ISize.sub_eq_add_neg
|
||||
pow_zero := ISize.pow_zero
|
||||
pow_succ := ISize.pow_succ
|
||||
ofNat_succ x := ISize.ofNat_add x 1
|
||||
|
||||
open System.Platform (numBits)
|
||||
|
||||
instance : IsCharP ISize (2 ^ numBits) where
|
||||
ofNat_eq_zero_iff {x} := by
|
||||
have : OfNat.ofNat x = ISize.ofInt x := rfl
|
||||
rw [this]
|
||||
simp [ISize.ofInt_eq_iff_bmod_eq_toInt,
|
||||
← Int.dvd_iff_bmod_eq_zero, ← Nat.dvd_iff_mod_eq_zero, Int.ofNat_dvd_right]
|
||||
|
||||
end Lean.Grind
|
||||
|
||||
@@ -7,6 +7,53 @@ prelude
|
||||
import Init.Grind.CommRing.Basic
|
||||
import Init.Data.UInt.Lemmas
|
||||
|
||||
|
||||
namespace UInt8
|
||||
|
||||
/-- Variant of `UInt8.ofNat_mod_size` replacing `2 ^ 8` with `256`.-/
|
||||
theorem ofNat_mod_size' : ofNat (x % 256) = ofNat x := ofNat_mod_size
|
||||
|
||||
instance : IntCast UInt8 where
|
||||
intCast x := UInt8.ofInt x
|
||||
|
||||
end UInt8
|
||||
|
||||
namespace UInt16
|
||||
|
||||
/-- Variant of `UInt16.ofNat_mod_size` replacing `2 ^ 16` with `65536`.-/
|
||||
theorem ofNat_mod_size' : ofNat (x % 65536) = ofNat x := ofNat_mod_size
|
||||
|
||||
instance : IntCast UInt16 where
|
||||
intCast x := UInt16.ofInt x
|
||||
|
||||
end UInt16
|
||||
|
||||
namespace UInt32
|
||||
|
||||
/-- Variant of `UInt32.ofNat_mod_size` replacing `2 ^ 32` with `4294967296`.-/
|
||||
theorem ofNat_mod_size' : ofNat (x % 4294967296) = ofNat x := ofNat_mod_size
|
||||
|
||||
instance : IntCast UInt32 where
|
||||
intCast x := UInt32.ofInt x
|
||||
|
||||
end UInt32
|
||||
|
||||
namespace UInt64
|
||||
|
||||
/-- Variant of `UInt64.ofNat_mod_size` replacing `2 ^ 64` with `18446744073709551616`.-/
|
||||
theorem ofNat_mod_size' : ofNat (x % 18446744073709551616) = ofNat x := ofNat_mod_size
|
||||
|
||||
instance : IntCast UInt64 where
|
||||
intCast x := UInt64.ofInt x
|
||||
|
||||
end UInt64
|
||||
|
||||
namespace USize
|
||||
|
||||
instance : IntCast USize where
|
||||
intCast x := USize.ofInt x
|
||||
|
||||
end USize
|
||||
namespace Lean.Grind
|
||||
|
||||
instance : CommRing UInt8 where
|
||||
@@ -19,6 +66,15 @@ instance : CommRing UInt8 where
|
||||
mul_one := UInt8.mul_one
|
||||
left_distrib _ _ _ := UInt8.mul_add
|
||||
zero_mul _ := UInt8.zero_mul
|
||||
sub_eq_add_neg := UInt8.sub_eq_add_neg
|
||||
pow_zero := UInt8.pow_zero
|
||||
pow_succ := UInt8.pow_succ
|
||||
ofNat_succ x := UInt8.ofNat_add x 1
|
||||
|
||||
instance : IsCharP UInt8 (2 ^ 8) where
|
||||
ofNat_eq_zero_iff {x} := by
|
||||
have : OfNat.ofNat x = UInt8.ofNat x := rfl
|
||||
simp [this, UInt8.ofNat_eq_iff_mod_eq_toNat]
|
||||
|
||||
instance : CommRing UInt16 where
|
||||
add_assoc := UInt16.add_assoc
|
||||
@@ -30,6 +86,15 @@ instance : CommRing UInt16 where
|
||||
mul_one := UInt16.mul_one
|
||||
left_distrib _ _ _ := UInt16.mul_add
|
||||
zero_mul _ := UInt16.zero_mul
|
||||
sub_eq_add_neg := UInt16.sub_eq_add_neg
|
||||
pow_zero := UInt16.pow_zero
|
||||
pow_succ := UInt16.pow_succ
|
||||
ofNat_succ x := UInt16.ofNat_add x 1
|
||||
|
||||
instance : IsCharP UInt16 (2 ^ 16) where
|
||||
ofNat_eq_zero_iff {x} := by
|
||||
have : OfNat.ofNat x = UInt16.ofNat x := rfl
|
||||
simp [this, UInt16.ofNat_eq_iff_mod_eq_toNat]
|
||||
|
||||
instance : CommRing UInt32 where
|
||||
add_assoc := UInt32.add_assoc
|
||||
@@ -41,6 +106,15 @@ instance : CommRing UInt32 where
|
||||
mul_one := UInt32.mul_one
|
||||
left_distrib _ _ _ := UInt32.mul_add
|
||||
zero_mul _ := UInt32.zero_mul
|
||||
sub_eq_add_neg := UInt32.sub_eq_add_neg
|
||||
pow_zero := UInt32.pow_zero
|
||||
pow_succ := UInt32.pow_succ
|
||||
ofNat_succ x := UInt32.ofNat_add x 1
|
||||
|
||||
instance : IsCharP UInt32 (2 ^ 32) where
|
||||
ofNat_eq_zero_iff {x} := by
|
||||
have : OfNat.ofNat x = UInt32.ofNat x := rfl
|
||||
simp [this, UInt32.ofNat_eq_iff_mod_eq_toNat]
|
||||
|
||||
instance : CommRing UInt64 where
|
||||
add_assoc := UInt64.add_assoc
|
||||
@@ -52,6 +126,15 @@ instance : CommRing UInt64 where
|
||||
mul_one := UInt64.mul_one
|
||||
left_distrib _ _ _ := UInt64.mul_add
|
||||
zero_mul _ := UInt64.zero_mul
|
||||
sub_eq_add_neg := UInt64.sub_eq_add_neg
|
||||
pow_zero := UInt64.pow_zero
|
||||
pow_succ := UInt64.pow_succ
|
||||
ofNat_succ x := UInt64.ofNat_add x 1
|
||||
|
||||
instance : IsCharP UInt64 (2 ^ 64) where
|
||||
ofNat_eq_zero_iff {x} := by
|
||||
have : OfNat.ofNat x = UInt64.ofNat x := rfl
|
||||
simp [this, UInt64.ofNat_eq_iff_mod_eq_toNat]
|
||||
|
||||
instance : CommRing USize where
|
||||
add_assoc := USize.add_assoc
|
||||
@@ -63,5 +146,16 @@ instance : CommRing USize where
|
||||
mul_one := USize.mul_one
|
||||
left_distrib _ _ _ := USize.mul_add
|
||||
zero_mul _ := USize.zero_mul
|
||||
sub_eq_add_neg := USize.sub_eq_add_neg
|
||||
pow_zero := USize.pow_zero
|
||||
pow_succ := USize.pow_succ
|
||||
ofNat_succ x := USize.ofNat_add x 1
|
||||
|
||||
open System.Platform
|
||||
|
||||
instance : IsCharP USize (2 ^ numBits) where
|
||||
ofNat_eq_zero_iff {x} := by
|
||||
have : OfNat.ofNat x = USize.ofNat x := rfl
|
||||
simp [this, USize.ofNat_eq_iff_mod_eq_toNat]
|
||||
|
||||
end Lean.Grind
|
||||
|
||||
@@ -165,4 +165,9 @@ theorem of_decide_eq_false {p : Prop} {_ : Decidable p} : decide p = false → p
|
||||
theorem decide_eq_true {p : Prop} {_ : Decidable p} : p = True → decide p = true := by simp
|
||||
theorem decide_eq_false {p : Prop} {_ : Decidable p} : p = False → decide p = false := by simp
|
||||
|
||||
/-! Lookahead -/
|
||||
|
||||
theorem of_lookahead (p : Prop) (h : (¬ p) → False) : p = True := by
|
||||
simp at h; simp [h]
|
||||
|
||||
end Lean.Grind
|
||||
|
||||
@@ -74,11 +74,11 @@ theorem bne_eq_decide_not_eq {_ : BEq α} [LawfulBEq α] [DecidableEq α] (a b :
|
||||
theorem xor_eq (a b : Bool) : (a ^^ b) = (a != b) := by
|
||||
rfl
|
||||
|
||||
theorem natCast_div (a b : Nat) : (↑(a / b) : Int) = ↑a / ↑b := by
|
||||
rfl
|
||||
|
||||
theorem natCast_mod (a b : Nat) : (↑(a % b) : Int) = ↑a % ↑b := by
|
||||
rfl
|
||||
theorem natCast_eq [NatCast α] (a : Nat) : (Nat.cast a : α) = (NatCast.natCast a : α) := rfl
|
||||
theorem natCast_div (a b : Nat) : (NatCast.natCast (a / b) : Int) = (NatCast.natCast a) / (NatCast.natCast b) := rfl
|
||||
theorem natCast_mod (a b : Nat) : (NatCast.natCast (a % b) : Int) = (NatCast.natCast a) % (NatCast.natCast b) := rfl
|
||||
theorem natCast_add (a b : Nat) : (NatCast.natCast (a + b : Nat) : Int) = (NatCast.natCast a : Int) + (NatCast.natCast b : Int) := rfl
|
||||
theorem natCast_mul (a b : Nat) : (NatCast.natCast (a * b : Nat) : Int) = (NatCast.natCast a : Int) * (NatCast.natCast b : Int) := rfl
|
||||
|
||||
theorem Nat.pow_one (a : Nat) : a ^ 1 = a := by
|
||||
simp
|
||||
@@ -153,8 +153,10 @@ init_grind_norm
|
||||
Int.emod_neg Int.ediv_neg
|
||||
Int.ediv_zero Int.emod_zero
|
||||
Int.ediv_one Int.emod_one
|
||||
Int.natCast_add Int.natCast_mul Int.natCast_pow
|
||||
Int.natCast_zero natCast_div natCast_mod
|
||||
|
||||
natCast_eq natCast_div natCast_mod
|
||||
natCast_add natCast_mul
|
||||
|
||||
Int.pow_zero Int.pow_one
|
||||
-- GT GE
|
||||
ge_eq gt_eq
|
||||
|
||||
@@ -25,7 +25,8 @@ syntax grindUsr := &"usr "
|
||||
syntax grindCases := &"cases "
|
||||
syntax grindCasesEager := atomic(&"cases" &"eager ")
|
||||
syntax grindIntro := &"intro "
|
||||
syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindRL <|> grindLR <|> grindUsr <|> grindCasesEager <|> grindCases <|> grindIntro
|
||||
syntax grindExt := &"ext "
|
||||
syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindRL <|> grindLR <|> grindUsr <|> grindCasesEager <|> grindCases <|> grindIntro <|> grindExt
|
||||
syntax (name := grind) "grind" (grindMod)? : attr
|
||||
end Attr
|
||||
end Lean.Parser
|
||||
@@ -68,8 +69,17 @@ structure Config where
|
||||
failures : Nat := 1
|
||||
/-- Maximum number of heartbeats (in thousands) the canonicalizer can spend per definitional equality test. -/
|
||||
canonHeartbeats : Nat := 1000
|
||||
/-- If `ext` is `true`, `grind` uses extensionality theorems available in the environment. -/
|
||||
/-- If `ext` is `true`, `grind` uses extensionality theorems that have been marked with `[grind ext]`. -/
|
||||
ext : Bool := true
|
||||
/-- If `extAll` is `true`, `grind` uses any extensionality theorems available in the environment. -/
|
||||
extAll : Bool := false
|
||||
/--
|
||||
If `funext` is `true`, `grind` creates new opportunities for applying function extensionality by case-splitting
|
||||
on equalities between lambda expressions.
|
||||
-/
|
||||
funext : Bool := true
|
||||
/-- TODO -/
|
||||
lookahead : Bool := true
|
||||
/-- If `verbose` is `false`, additional diagnostics information is not collected. -/
|
||||
verbose : Bool := true
|
||||
/-- If `clean` is `true`, `grind` uses `expose_names` and only generates accessible names. -/
|
||||
|
||||
@@ -44,6 +44,13 @@ register_builtin_option Elab.async : Bool := {
|
||||
`Lean.Command.State.snapshotTasks`."
|
||||
}
|
||||
|
||||
register_builtin_option Elab.inServer : Bool := {
|
||||
defValue := false
|
||||
descr := "true if elaboration is being run inside the Lean language server\
|
||||
\n\
|
||||
\nThis option is set by the file worker and should not be modified otherwise."
|
||||
}
|
||||
|
||||
/-- Performance option used by cmdline driver. -/
|
||||
register_builtin_option internal.cmdlineSnapshots : Bool := {
|
||||
defValue := false
|
||||
|
||||
@@ -111,6 +111,7 @@ inductive Message where
|
||||
| response (id : RequestID) (result : Json)
|
||||
/-- A non-successful response. -/
|
||||
| responseError (id : RequestID) (code : ErrorCode) (message : String) (data? : Option Json)
|
||||
deriving Inhabited
|
||||
|
||||
def Batch := Array Message
|
||||
|
||||
|
||||
@@ -510,6 +510,7 @@ where go := do
|
||||
let oldCmds? := oldSnap?.map fun old =>
|
||||
if old.newStx.isOfKind nullKind then old.newStx.getArgs else #[old.newStx]
|
||||
let cmdPromises ← cmds.mapM fun _ => IO.Promise.new
|
||||
let cancelTk? := (← read).cancelTk?
|
||||
snap.new.resolve <| .ofTyped {
|
||||
diagnostics := .empty
|
||||
macroDecl := decl
|
||||
@@ -517,7 +518,7 @@ where go := do
|
||||
newNextMacroScope := nextMacroScope
|
||||
hasTraces
|
||||
next := Array.zipWith (fun cmdPromise cmd =>
|
||||
{ stx? := some cmd, task := cmdPromise.resultD default }) cmdPromises cmds
|
||||
{ stx? := some cmd, task := cmdPromise.resultD default, cancelTk? }) cmdPromises cmds
|
||||
: MacroExpandedSnapshot
|
||||
}
|
||||
-- After the first command whose syntax tree changed, we must disable
|
||||
|
||||
@@ -6,6 +6,7 @@ Authors: Leonardo de Moura, Sebastian Ullrich
|
||||
prelude
|
||||
import Lean.Parser.Module
|
||||
import Lean.Util.Paths
|
||||
import Lean.CoreM
|
||||
|
||||
namespace Lean.Elab
|
||||
|
||||
@@ -21,9 +22,16 @@ def processHeader (header : Syntax) (opts : Options) (messages : MessageLog)
|
||||
(inputCtx : Parser.InputContext) (trustLevel : UInt32 := 0)
|
||||
(plugins : Array System.FilePath := #[]) (leakEnv := false)
|
||||
: IO (Environment × MessageLog) := do
|
||||
let level := if experimental.module.get opts then
|
||||
if Elab.inServer.get opts then
|
||||
.server
|
||||
else
|
||||
.exported
|
||||
else
|
||||
.private
|
||||
try
|
||||
let env ←
|
||||
importModules (leakEnv := leakEnv) (loadExts := true) (headerToImports header) opts trustLevel plugins
|
||||
importModules (leakEnv := leakEnv) (loadExts := true) (level := level) (headerToImports header) opts trustLevel plugins
|
||||
pure (env, messages)
|
||||
catch e =>
|
||||
let env ← mkEmptyEnvironment
|
||||
|
||||
@@ -165,6 +165,8 @@ private def elabHeaders (views : Array DefView) (expandedDeclIds : Array ExpandD
|
||||
-- no syntax guard to store, we already did the necessary checks
|
||||
oldBodySnap? := guard reuseBody *> pure ⟨.missing, old.bodySnap⟩
|
||||
if oldBodySnap?.isNone then
|
||||
-- NOTE: this will eagerly cancel async tasks not associated with an inner snapshot, most
|
||||
-- importantly kernel checking and compilation of the top-level declaration
|
||||
old.bodySnap.cancelRec
|
||||
oldTacSnap? := do
|
||||
guard reuseTac
|
||||
@@ -217,6 +219,7 @@ private def elabHeaders (views : Array DefView) (expandedDeclIds : Array ExpandD
|
||||
return newHeader
|
||||
if let some snap := view.headerSnap? then
|
||||
let (tacStx?, newTacTask?) ← mkTacTask view.value tacPromise
|
||||
let cancelTk? := (← readThe Core.Context).cancelTk?
|
||||
let bodySnap := {
|
||||
stx? := view.value
|
||||
reportingRange? :=
|
||||
@@ -227,6 +230,8 @@ private def elabHeaders (views : Array DefView) (expandedDeclIds : Array ExpandD
|
||||
else
|
||||
getBodyTerm? view.value |>.getD view.value |>.getRange?
|
||||
task := bodyPromise.resultD default
|
||||
-- We should not cancel the entire body early if we have tactics
|
||||
cancelTk? := guard newTacTask?.isNone *> cancelTk?
|
||||
}
|
||||
snap.new.resolve <| some {
|
||||
diagnostics :=
|
||||
@@ -269,7 +274,8 @@ where
|
||||
:= do
|
||||
if let some e := getBodyTerm? body then
|
||||
if let `(by $tacs*) := e then
|
||||
return (e, some { stx? := mkNullNode tacs, task := tacPromise.resultD default })
|
||||
let cancelTk? := (← readThe Core.Context).cancelTk?
|
||||
return (e, some { stx? := mkNullNode tacs, task := tacPromise.resultD default, cancelTk? })
|
||||
tacPromise.resolve default
|
||||
return (none, none)
|
||||
|
||||
@@ -432,8 +438,7 @@ private def elabFunValues (headers : Array DefViewElabHeader) (vars : Array Expr
|
||||
snap.new.resolve <| some old
|
||||
reusableResult? := some (old.value, old.state)
|
||||
else
|
||||
-- NOTE: this will eagerly cancel async tasks not associated with an inner snapshot, most
|
||||
-- importantly kernel checking and compilation of the top-level declaration
|
||||
-- make sure to cancel any async tasks that may still be running (e.g. kernel and codegen)
|
||||
old.val.cancelRec
|
||||
|
||||
let (val, state) ← withRestoreOrSaveFull reusableResult? header.tacSnap? do
|
||||
@@ -1158,7 +1163,7 @@ is error-free and contains no syntactical `sorry`s.
|
||||
-/
|
||||
private def logGoalsAccomplishedSnapshotTask (views : Array DefView)
|
||||
(defsParsedSnap : DefsParsedSnapshot) : TermElabM Unit := do
|
||||
if Lean.internal.cmdlineSnapshots.get (← getOptions) then
|
||||
if Lean.Elab.inServer.get (← getOptions) then
|
||||
-- Skip 'goals accomplished' task if we are on the command line.
|
||||
-- These messages are only used in the language server.
|
||||
return
|
||||
@@ -1197,6 +1202,7 @@ private def logGoalsAccomplishedSnapshotTask (views : Array DefView)
|
||||
-- Use first line of the mutual block to avoid covering the progress of the whole mutual block
|
||||
reportingRange? := (← getRef).getPos?.map fun pos => ⟨pos, pos⟩
|
||||
task := logGoalsAccomplishedTask
|
||||
cancelTk? := none
|
||||
}
|
||||
|
||||
end Term
|
||||
@@ -1235,9 +1241,10 @@ def elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do
|
||||
} }
|
||||
if snap.old?.isSome && (view.headerSnap?.bind (·.old?)).isNone then
|
||||
snap.old?.forM (·.val.cancelRec)
|
||||
let cancelTk? := (← read).cancelTk?
|
||||
defs := defs.push {
|
||||
fullHeaderRef
|
||||
headerProcessedSnap := { stx? := d, task := headerPromise.resultD default }
|
||||
headerProcessedSnap := { stx? := d, task := headerPromise.resultD default, cancelTk? }
|
||||
}
|
||||
reusedAllHeaders := reusedAllHeaders && view.headerSnap?.any (·.old?.isSome)
|
||||
views := views.push view
|
||||
|
||||
@@ -109,11 +109,11 @@ structure InductiveView where
|
||||
/-- Elaborated header for an inductive type before fvars for each inductive are added to the local context. -/
|
||||
structure PreElabHeaderResult where
|
||||
view : InductiveView
|
||||
lctx : LocalContext
|
||||
localInsts : LocalInstances
|
||||
levelNames : List Name
|
||||
params : Array Expr
|
||||
numParams : Nat
|
||||
type : Expr
|
||||
/-- The parameters in the header's initial local context. Used for adding fvar alias terminfo. -/
|
||||
origParams : Array Expr
|
||||
deriving Inhabited
|
||||
|
||||
/-- The elaborated header with the `indFVar` registered for this inductive type. -/
|
||||
@@ -228,16 +228,12 @@ private def checkClass (rs : Array PreElabHeaderResult) : TermElabM Unit := do
|
||||
throwErrorAt r.view.ref "invalid inductive type, mutual classes are not supported"
|
||||
|
||||
private def checkNumParams (rs : Array PreElabHeaderResult) : TermElabM Nat := do
|
||||
let numParams := rs[0]!.params.size
|
||||
let numParams := rs[0]!.numParams
|
||||
for r in rs do
|
||||
unless r.params.size == numParams do
|
||||
unless r.numParams == numParams do
|
||||
throwErrorAt r.view.ref "invalid inductive type, number of parameters mismatch in mutually inductive datatypes"
|
||||
return numParams
|
||||
|
||||
private def mkTypeFor (r : PreElabHeaderResult) : TermElabM Expr := do
|
||||
withLCtx r.lctx r.localInsts do
|
||||
mkForallFVars r.params r.type
|
||||
|
||||
/--
|
||||
Execute `k` with updated binder information for `xs`. Any `x` that is explicit becomes implicit.
|
||||
-/
|
||||
@@ -276,7 +272,7 @@ private def checkHeaders (rs : Array PreElabHeaderResult) (numParams : Nat) (i :
|
||||
checkHeaders rs numParams (i+1) type
|
||||
where
|
||||
checkHeader (r : PreElabHeaderResult) (numParams : Nat) (firstType? : Option Expr) : TermElabM Expr := do
|
||||
let type ← mkTypeFor r
|
||||
let type := r.type
|
||||
match firstType? with
|
||||
| none => return type
|
||||
| some firstType =>
|
||||
@@ -306,7 +302,8 @@ private def elabHeadersAux (views : Array InductiveView) (i : Nat) (acc : Array
|
||||
let params ← Term.addAutoBoundImplicits params (view.declId.getTailPos? (canonicalOnly := true))
|
||||
trace[Elab.inductive] "header params: {params}, type: {type}"
|
||||
let levelNames ← Term.getLevelNames
|
||||
return acc.push { lctx := (← getLCtx), localInsts := (← getLocalInstances), levelNames, params, type, view }
|
||||
let type ← mkForallFVars params type
|
||||
return acc.push { levelNames, numParams := params.size, type, view, origParams := params }
|
||||
elabHeadersAux views (i+1) acc
|
||||
else
|
||||
return acc
|
||||
@@ -326,21 +323,21 @@ private def elabHeaders (views : Array InductiveView) : TermElabM (Array PreElab
|
||||
/--
|
||||
Create a local declaration for each inductive type in `rs`, and execute `x params indFVars`, where `params` are the inductive type parameters and
|
||||
`indFVars` are the new local declarations.
|
||||
We use the local context/instances and parameters of rs[0].
|
||||
We use the parameters of rs[0].
|
||||
Note that this method is executed after we executed `checkHeaders` and established all
|
||||
parameters are compatible.
|
||||
-/
|
||||
private def withInductiveLocalDecls (rs : Array PreElabHeaderResult) (x : Array Expr → Array Expr → TermElabM α) : TermElabM α := do
|
||||
let namesAndTypes ← rs.mapM fun r => do
|
||||
let type ← mkTypeFor r
|
||||
pure (r.view.declName, r.view.shortDeclName, type)
|
||||
let r0 := rs[0]!
|
||||
let params := r0.params
|
||||
withLCtx r0.lctx r0.localInsts <| withRef r0.view.ref do
|
||||
let r0 := rs[0]!
|
||||
forallBoundedTelescope r0.type r0.numParams fun params _ => withRef r0.view.ref do
|
||||
let rec loop (i : Nat) (indFVars : Array Expr) := do
|
||||
if h : i < namesAndTypes.size then
|
||||
let (declName, shortDeclName, type) := namesAndTypes[i]
|
||||
withAuxDecl shortDeclName type declName fun indFVar => loop (i+1) (indFVars.push indFVar)
|
||||
if h : i < rs.size then
|
||||
let r := rs[i]
|
||||
for param in params, origParam in r.origParams do
|
||||
if let .fvar origFVar := origParam then
|
||||
Elab.pushInfoLeaf <| .ofFVarAliasInfo { id := param.fvarId!, baseId := origFVar, userName := ← param.fvarId!.getUserName }
|
||||
withAuxDecl r.view.shortDeclName r.type r.view.declName fun indFVar =>
|
||||
loop (i+1) (indFVars.push indFVar)
|
||||
else
|
||||
x params indFVars
|
||||
loop 0 #[]
|
||||
@@ -359,26 +356,6 @@ private def ElabHeaderResult.checkLevelNames (rs : Array PreElabHeaderResult) :
|
||||
unless r.levelNames == levelNames do
|
||||
throwErrorAt r.view.ref "invalid inductive type, universe parameters mismatch in mutually inductive datatypes"
|
||||
|
||||
/--
|
||||
We need to work inside a single local context across all the inductive types, so we need to update the `ElabHeaderResult`s
|
||||
so that resultant types refer to the fvars in `params`, the parameters for `rs[0]!` specifically.
|
||||
Also updates the local contexts and local instances in each header.
|
||||
-/
|
||||
private def updateElabHeaderTypes (params : Array Expr) (rs : Array PreElabHeaderResult) (indFVars : Array Expr) : TermElabM (Array ElabHeaderResult) := do
|
||||
rs.mapIdxM fun i r => do
|
||||
/-
|
||||
At this point, because of `withInductiveLocalDecls`, the only fvars that are in context are the ones related to the first inductive type.
|
||||
Because of this, we need to replace the fvars present in each inductive type's header of the mutual block with those of the first inductive.
|
||||
However, some mvars may still be uninstantiated there, and might hide some of the old fvars.
|
||||
As such we first need to synthesize all possible mvars at this stage, instantiate them in the header types and only
|
||||
then replace the parameters' fvars in the header type.
|
||||
|
||||
See issue #3242 (`https://github.com/leanprover/lean4/issues/3242`)
|
||||
-/
|
||||
let type ← instantiateMVars r.type
|
||||
let type := type.replaceFVars r.params params
|
||||
pure { r with lctx := ← getLCtx, localInsts := ← getLocalInstances, type := type, indFVar := indFVars[i]! }
|
||||
|
||||
private def getArity (indType : InductiveType) : MetaM Nat :=
|
||||
forallTelescopeReducing indType.type fun xs _ => return xs.size
|
||||
|
||||
@@ -878,7 +855,7 @@ private def mkInductiveDecl (vars : Array Expr) (elabs : Array InductiveElabStep
|
||||
trace[Elab.inductive] "level names: {allUserLevelNames}"
|
||||
let res ← withInductiveLocalDecls rs fun params indFVars => do
|
||||
trace[Elab.inductive] "indFVars: {indFVars}"
|
||||
let rs ← updateElabHeaderTypes params rs indFVars
|
||||
let rs := Array.zipWith (fun r indFVar => { r with indFVar : ElabHeaderResult }) rs indFVars
|
||||
let mut indTypesArray : Array InductiveType := #[]
|
||||
let mut elabs' := #[]
|
||||
for h : i in [:views.size] do
|
||||
@@ -886,8 +863,7 @@ private def mkInductiveDecl (vars : Array Expr) (elabs : Array InductiveElabStep
|
||||
let r := rs[i]!
|
||||
let elab' ← elabs[i]!.elabCtors rs r params
|
||||
elabs' := elabs'.push elab'
|
||||
let type ← mkForallFVars params r.type
|
||||
indTypesArray := indTypesArray.push { name := r.view.declName, type, ctors := elab'.ctors }
|
||||
indTypesArray := indTypesArray.push { name := r.view.declName, type := r.type, ctors := elab'.ctors }
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let numExplicitParams ← fixedIndicesToParams params.size indTypesArray indFVars
|
||||
trace[Elab.inductive] "numExplicitParams: {numExplicitParams}"
|
||||
|
||||
@@ -1260,12 +1260,8 @@ private def addDefaults (levelParams : List Name) (params : Array Expr) (replace
|
||||
let fieldInfos := (← get).fields
|
||||
let lctx ← instantiateLCtxMVars (← getLCtx)
|
||||
/- The parameters `params` for the auxiliary "default value" definitions must be marked as implicit, and all others as explicit. -/
|
||||
let lctx :=
|
||||
params.foldl (init := lctx) fun (lctx : LocalContext) (p : Expr) =>
|
||||
if p.isFVar then
|
||||
lctx.setBinderInfo p.fvarId! BinderInfo.implicit
|
||||
else
|
||||
lctx
|
||||
let lctx := params.foldl (init := lctx) fun (lctx : LocalContext) (p : Expr) =>
|
||||
lctx.setBinderInfo p.fvarId! BinderInfo.implicit
|
||||
let parentFVarIds := fieldInfos |>.filter (·.kind.isParent) |>.map (·.fvar.fvarId!)
|
||||
let fields := fieldInfos |>.filter (!·.kind.isParent)
|
||||
withLCtx lctx (← getLocalInstances) do
|
||||
|
||||
@@ -165,6 +165,7 @@ structure EvalTacticFailure where
|
||||
state : SavedState
|
||||
|
||||
partial def evalTactic (stx : Syntax) : TacticM Unit := do
|
||||
checkSystem "tactic execution"
|
||||
profileitM Exception "tactic execution" (decl := stx.getKind) (← getOptions) <|
|
||||
withRef stx <| withIncRecDepth <| withFreshMacroScope <| match stx with
|
||||
| .node _ k _ =>
|
||||
@@ -240,6 +241,7 @@ where
|
||||
snap.old?.forM (·.val.cancelRec)
|
||||
let promise ← IO.Promise.new
|
||||
-- Store new unfolding in the snapshot tree
|
||||
let cancelTk? := (← readThe Core.Context).cancelTk?
|
||||
snap.new.resolve {
|
||||
stx := stx'
|
||||
diagnostics := .empty
|
||||
@@ -249,7 +251,7 @@ where
|
||||
state? := (← Tactic.saveState)
|
||||
moreSnaps := #[]
|
||||
}
|
||||
next := #[{ stx? := stx', task := promise.resultD default }]
|
||||
next := #[{ stx? := stx', task := promise.resultD default, cancelTk? }]
|
||||
}
|
||||
-- Update `tacSnap?` to old unfolding
|
||||
withTheReader Term.Context ({ · with tacSnap? := some {
|
||||
|
||||
@@ -78,13 +78,14 @@ where
|
||||
let next ← IO.Promise.new
|
||||
let finished ← IO.Promise.new
|
||||
let inner ← IO.Promise.new
|
||||
let cancelTk? := (← readThe Core.Context).cancelTk?
|
||||
snap.new.resolve {
|
||||
desc := tac.getKind.toString
|
||||
diagnostics := .empty
|
||||
stx := tac
|
||||
inner? := some { stx? := tac, task := inner.resultD default }
|
||||
finished := { stx? := tac, task := finished.resultD default }
|
||||
next := #[{ stx? := stxs, task := next.resultD default }]
|
||||
inner? := some { stx? := tac, task := inner.resultD default, cancelTk? }
|
||||
finished := { stx? := tac, task := finished.resultD default, cancelTk? }
|
||||
next := #[{ stx? := stxs, task := next.resultD default, cancelTk? }]
|
||||
}
|
||||
-- Run `tac` in a fresh info tree state and store resulting state in snapshot for
|
||||
-- incremental reporting, then add back saved trees. Here we rely on `evalTactic`
|
||||
|
||||
@@ -89,6 +89,8 @@ def elabGrindParams (params : Grind.Params) (ps : TSyntaxArray ``Parser.Tactic.
|
||||
params ← withRef p <| addEMatchTheorem params ctor .default
|
||||
else
|
||||
throwError "invalid use of `intro` modifier, `{declName}` is not an inductive predicate"
|
||||
| .ext =>
|
||||
throwError "`[grind ext]` cannot be set using parameters"
|
||||
| .infer =>
|
||||
if let some declName ← Grind.isCasesAttrCandidate? declName false then
|
||||
params := { params with casesTypes := params.casesTypes.insert declName false }
|
||||
|
||||
@@ -286,14 +286,15 @@ where
|
||||
-- them, eventually put each of them back in `Context.tacSnap?` in `applyAltStx`
|
||||
let finished ← IO.Promise.new
|
||||
let altPromises ← altStxs.mapM fun _ => IO.Promise.new
|
||||
let cancelTk? := (← readThe Core.Context).cancelTk?
|
||||
tacSnap.new.resolve {
|
||||
-- save all relevant syntax here for comparison with next document version
|
||||
stx := mkNullNode altStxs
|
||||
diagnostics := .empty
|
||||
inner? := none
|
||||
finished := { stx? := mkNullNode altStxs, reportingRange? := none, task := finished.resultD default }
|
||||
finished := { stx? := mkNullNode altStxs, reportingRange? := none, task := finished.resultD default, cancelTk? }
|
||||
next := Array.zipWith
|
||||
(fun stx prom => { stx? := some stx, task := prom.resultD default })
|
||||
(fun stx prom => { stx? := some stx, task := prom.resultD default, cancelTk? })
|
||||
altStxs altPromises
|
||||
}
|
||||
goWithIncremental <| altPromises.mapIdx fun i prom => {
|
||||
|
||||
@@ -1878,13 +1878,14 @@ where
|
||||
go todo (autos.push auto)
|
||||
|
||||
/--
|
||||
Similar to `autoBoundImplicits`, but immediately if the resulting array of expressions contains metavariables,
|
||||
it immediately uses `mkForallFVars` + `forallBoundedTelescope` to convert them into free variables.
|
||||
Similar to `addAutoBoundImplicits`, but converts all metavariables into free variables.
|
||||
|
||||
It uses `mkForallFVars` + `forallBoundedTelescope` to convert metavariables into free variables.
|
||||
The type `type` is modified during the process if type depends on `xs`.
|
||||
We use this method to simplify the conversion of code using `autoBoundImplicitsOld` to `autoBoundImplicits`.
|
||||
-/
|
||||
def addAutoBoundImplicits' (xs : Array Expr) (type : Expr) (k : Array Expr → Expr → TermElabM α) : TermElabM α := do
|
||||
let xs ← addAutoBoundImplicits xs none
|
||||
def addAutoBoundImplicits' (xs : Array Expr) (type : Expr) (k : Array Expr → Expr → TermElabM α) (inlayHintPos? : Option String.Pos := none) : TermElabM α := do
|
||||
let xs ← addAutoBoundImplicits xs inlayHintPos?
|
||||
if xs.all (·.isFVar) then
|
||||
k xs type
|
||||
else
|
||||
|
||||
@@ -122,13 +122,15 @@ end TagDeclarationExtension
|
||||
structure MapDeclarationExtension (α : Type) extends PersistentEnvExtension (Name × α) (Name × α) (NameMap α)
|
||||
deriving Inhabited
|
||||
|
||||
def mkMapDeclarationExtension (name : Name := by exact decl_name%) : IO (MapDeclarationExtension α) :=
|
||||
def mkMapDeclarationExtension (name : Name := by exact decl_name%)
|
||||
(exportEntriesFn : NameMap α → Array (Name × α) := (·.toArray)) : IO (MapDeclarationExtension α) :=
|
||||
.mk <$> registerPersistentEnvExtension {
|
||||
name := name,
|
||||
mkInitial := pure {}
|
||||
addImportedFn := fun _ => pure {}
|
||||
addEntryFn := fun s (n, v) => s.insert n v
|
||||
exportEntriesFn := fun s => s.toArray
|
||||
saveEntriesFn := fun s => s.toArray
|
||||
exportEntriesFn
|
||||
asyncMode := .async
|
||||
replay? := some fun _ newState newConsts s =>
|
||||
newConsts.foldl (init := s) fun s c =>
|
||||
@@ -145,10 +147,11 @@ def insert (ext : MapDeclarationExtension α) (env : Environment) (declName : Na
|
||||
assert! env.asyncMayContain declName
|
||||
ext.addEntry env (declName, val)
|
||||
|
||||
def find? [Inhabited α] (ext : MapDeclarationExtension α) (env : Environment) (declName : Name) : Option α :=
|
||||
def find? [Inhabited α] (ext : MapDeclarationExtension α) (env : Environment) (declName : Name)
|
||||
(includeServer := false) : Option α :=
|
||||
match env.getModuleIdxFor? declName with
|
||||
| some modIdx =>
|
||||
match (ext.getModuleEntries env modIdx).binSearch (declName, default) (fun a b => Name.quickLt a.1 b.1) with
|
||||
match (ext.getModuleEntries (includeServer := includeServer) env modIdx).binSearch (declName, default) (fun a b => Name.quickLt a.1 b.1) with
|
||||
| some e => some e.2
|
||||
| none => none
|
||||
| none => (ext.findStateAsync env declName).find? declName
|
||||
|
||||
@@ -506,6 +506,12 @@ structure Environment where
|
||||
-/
|
||||
base : Kernel.Environment
|
||||
/--
|
||||
Additional imported environment extension state for use in the language server. This field is
|
||||
identical to `base.extensions` in other contexts. Access via
|
||||
`getModuleEntries (includeServer := true)`.
|
||||
-/
|
||||
private serverBaseExts : Array EnvExtensionState := base.extensions
|
||||
/--
|
||||
Kernel environment task that is fulfilled when all asynchronously elaborated declarations are
|
||||
finished, containing the resulting environment. Also collects the environment extension state of
|
||||
all environment branches that contributed contained declarations.
|
||||
@@ -536,6 +542,12 @@ structure Environment where
|
||||
`findAsyncCore?`/`findStateAsync`; see there.
|
||||
-/
|
||||
private allRealizations : Task (NameMap AsyncConst) := .pure {}
|
||||
/--
|
||||
Indicates whether the environment is being used in an exported context, i.e. whether it should
|
||||
provide access to only the data to be imported by other modules participating in the module
|
||||
system.
|
||||
-/
|
||||
isExporting : Bool := false
|
||||
deriving Nonempty
|
||||
|
||||
namespace Environment
|
||||
@@ -549,6 +561,10 @@ def ofKernelEnv (env : Kernel.Environment) : Environment :=
|
||||
def toKernelEnv (env : Environment) : Kernel.Environment :=
|
||||
env.checked.get
|
||||
|
||||
/-- Updates `Environment.isExporting`. -/
|
||||
def setExporting (env : Environment) (isExporting : Bool) : Environment :=
|
||||
{ env with isExporting }
|
||||
|
||||
/-- Consistently updates synchronous and asynchronous parts of the environment without blocking. -/
|
||||
private def modifyCheckedAsync (env : Environment) (f : Kernel.Environment → Kernel.Environment) : Environment :=
|
||||
{ env with checked := env.checked.map (sync := true) f, base := f env.base }
|
||||
@@ -1379,7 +1395,15 @@ structure PersistentEnvExtension (α : Type) (β : Type) (σ : Type) where
|
||||
name : Name
|
||||
addImportedFn : Array (Array α) → ImportM σ
|
||||
addEntryFn : σ → β → σ
|
||||
/-- Function to transform state into data that should always be imported into other modules. -/
|
||||
exportEntriesFn : σ → Array α
|
||||
/--
|
||||
Function to transform state into data that should be imported into other modules when the module
|
||||
system is disabled. When it is enabled, the data is loaded only in the language server and
|
||||
accessible via `getModuleEntries (includeServer := true)`. Conventionally, this is a superset of
|
||||
the data returned by `exportEntriesFn`.
|
||||
-/
|
||||
saveEntriesFn : σ → Array α
|
||||
statsFn : σ → Format
|
||||
|
||||
instance {α σ} [Inhabited σ] : Inhabited (PersistentEnvExtensionState α σ) :=
|
||||
@@ -1392,14 +1416,21 @@ instance {α β σ} [Inhabited σ] : Inhabited (PersistentEnvExtension α β σ)
|
||||
addImportedFn := fun _ => default,
|
||||
addEntryFn := fun s _ => s,
|
||||
exportEntriesFn := fun _ => #[],
|
||||
saveEntriesFn := fun _ => #[],
|
||||
statsFn := fun _ => Format.nil
|
||||
}
|
||||
|
||||
namespace PersistentEnvExtension
|
||||
|
||||
def getModuleEntries {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ) (env : Environment) (m : ModuleIdx) : Array α :=
|
||||
-- `importedEntries` is identical on all environment branches, so `local` is always sufficient
|
||||
(ext.toEnvExtension.getState (asyncMode := .local) env).importedEntries[m]!
|
||||
/--
|
||||
Returns the data saved by `ext.exportEntriesFn/saveEntriesFn` when `m` was elaborated. See docs on
|
||||
the functions for details.
|
||||
-/
|
||||
def getModuleEntries {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ)
|
||||
(env : Environment) (m : ModuleIdx) (includeServer := false) : Array α :=
|
||||
let exts := if includeServer then env.serverBaseExts else env.base.extensions
|
||||
-- safety: as in `getStateUnsafe`
|
||||
unsafe (ext.toEnvExtension.getStateImpl exts).importedEntries[m]!
|
||||
|
||||
def addEntry {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (b : β) : Environment :=
|
||||
ext.toEnvExtension.modifyState env fun s =>
|
||||
@@ -1436,10 +1467,14 @@ structure PersistentEnvExtensionDescr (α β σ : Type) where
|
||||
addImportedFn : Array (Array α) → ImportM σ
|
||||
addEntryFn : σ → β → σ
|
||||
exportEntriesFn : σ → Array α
|
||||
saveEntriesFn : σ → Array α := exportEntriesFn
|
||||
statsFn : σ → Format := fun _ => Format.nil
|
||||
asyncMode : EnvExtension.AsyncMode := .mainOnly
|
||||
replay? : Option (ReplayFn σ) := none
|
||||
|
||||
attribute [inherit_doc PersistentEnvExtension.exportEntriesFn] PersistentEnvExtensionDescr.exportEntriesFn
|
||||
attribute [inherit_doc PersistentEnvExtension.saveEntriesFn] PersistentEnvExtensionDescr.saveEntriesFn
|
||||
|
||||
unsafe def registerPersistentEnvExtensionUnsafe {α β σ : Type} [Inhabited σ] (descr : PersistentEnvExtensionDescr α β σ) : IO (PersistentEnvExtension α β σ) := do
|
||||
let pExts ← persistentEnvExtensionsRef.get
|
||||
if pExts.any (fun ext => ext.name == descr.name) then throw (IO.userError s!"invalid environment extension, '{descr.name}' has already been used")
|
||||
@@ -1458,6 +1493,7 @@ unsafe def registerPersistentEnvExtensionUnsafe {α β σ : Type} [Inhabited σ]
|
||||
addImportedFn := descr.addImportedFn,
|
||||
addEntryFn := descr.addEntryFn,
|
||||
exportEntriesFn := descr.exportEntriesFn,
|
||||
saveEntriesFn := descr.saveEntriesFn,
|
||||
statsFn := descr.statsFn
|
||||
}
|
||||
persistentEnvExtensionsRef.modify fun pExts => pExts.push (unsafeCast pExt)
|
||||
@@ -1466,10 +1502,30 @@ unsafe def registerPersistentEnvExtensionUnsafe {α β σ : Type} [Inhabited σ]
|
||||
@[implemented_by registerPersistentEnvExtensionUnsafe]
|
||||
opaque registerPersistentEnvExtension {α β σ : Type} [Inhabited σ] (descr : PersistentEnvExtensionDescr α β σ) : IO (PersistentEnvExtension α β σ)
|
||||
|
||||
@[extern "lean_save_module_data"]
|
||||
opaque saveModuleData (fname : @& System.FilePath) (mod : @& Name) (data : @& ModuleData) : IO Unit
|
||||
@[extern "lean_read_module_data"]
|
||||
opaque readModuleData (fname : @& System.FilePath) : IO (ModuleData × CompactedRegion)
|
||||
/--
|
||||
Stores each given module data in the respective file name. Objects shared with prior parts are not
|
||||
duplicated. Thus the data cannot be loaded with individual `readModuleData` calls but must loaded by
|
||||
passing (a prefix of) the file names to `readModuleDataParts`. `mod` is used to determine an
|
||||
arbitrary but deterministic base address for `mmap`.
|
||||
-/
|
||||
@[extern "lean_save_module_data_parts"]
|
||||
opaque saveModuleDataParts (mod : @& Name) (parts : Array (System.FilePath × ModuleData)) : IO Unit
|
||||
|
||||
/--
|
||||
Loads the module data from the given file names. The files must be (a prefix of) the result of a
|
||||
`saveModuleDataParts` call.
|
||||
-/
|
||||
@[extern "lean_read_module_data_parts"]
|
||||
opaque readModuleDataParts (fnames : @& Array System.FilePath) : IO (Array (ModuleData × CompactedRegion))
|
||||
|
||||
def saveModuleData (fname : System.FilePath) (mod : Name) (data : ModuleData) : IO Unit :=
|
||||
saveModuleDataParts mod #[(fname, data)]
|
||||
|
||||
def readModuleData (fname : @& System.FilePath) : IO (ModuleData × CompactedRegion) := do
|
||||
let parts ← readModuleDataParts #[fname]
|
||||
assert! parts.size == 1
|
||||
let some part := parts[0]? | unreachable!
|
||||
return part
|
||||
|
||||
/--
|
||||
Free compacted regions of imports. No live references to imported objects may exist at the time of invocation; in
|
||||
@@ -1493,7 +1549,22 @@ unsafe def Environment.freeRegions (env : Environment) : IO Unit :=
|
||||
TODO: statically check for this. -/
|
||||
env.header.regions.forM CompactedRegion.free
|
||||
|
||||
def mkModuleData (env : Environment) : IO ModuleData := do
|
||||
/-- The level of information to save/load. Each level includes all previous ones. -/
|
||||
inductive OLeanLevel where
|
||||
/-- Information from exported contexts. -/
|
||||
| exported
|
||||
/-- Environment extension state for the language server. -/
|
||||
| server
|
||||
/-- Private module data. -/
|
||||
| «private»
|
||||
deriving DecidableEq
|
||||
|
||||
def OLeanLevel.adjustFileName (base : System.FilePath) : OLeanLevel → System.FilePath
|
||||
| .exported => base
|
||||
| .server => base.addExtension "server"
|
||||
| .private => base.addExtension "private"
|
||||
|
||||
def mkModuleData (env : Environment) (level : OLeanLevel := .private) : IO ModuleData := do
|
||||
let pExts ← persistentEnvExtensionsRef.get
|
||||
let entries := pExts.map fun pExt => Id.run do
|
||||
-- get state from `checked` at the end if `async`; it would otherwise panic
|
||||
@@ -1501,19 +1572,37 @@ def mkModuleData (env : Environment) : IO ModuleData := do
|
||||
if asyncMode matches .async then
|
||||
asyncMode := .sync
|
||||
let state := pExt.getState (asyncMode := asyncMode) env
|
||||
(pExt.name, pExt.exportEntriesFn state)
|
||||
(pExt.name, if level = .exported then pExt.exportEntriesFn state else pExt.saveEntriesFn state)
|
||||
let kenv := env.toKernelEnv
|
||||
let constNames := kenv.constants.foldStage2 (fun names name _ => names.push name) #[]
|
||||
let constants := kenv.constants.foldStage2 (fun cs _ c => cs.push c) #[]
|
||||
let env := env.setExporting (level != .private)
|
||||
let constants := kenv.constants.foldStage2 (fun cs _ c => cs.push c) #[]
|
||||
--let constNames := kenv.constants.foldStage2 (fun names name _ => names.push name) #[]
|
||||
-- not all kernel constants may be exported
|
||||
-- TODO: does not include cstage* constants from the old codegen
|
||||
--let constants := constNames.filterMap env.find?
|
||||
let constNames := constants.map (·.name)
|
||||
return {
|
||||
imports := env.header.imports
|
||||
extraConstNames := env.checked.get.extraConstNames.toArray
|
||||
constNames, constants, entries
|
||||
}
|
||||
|
||||
register_builtin_option experimental.module : Bool := {
|
||||
defValue := false
|
||||
descr := "Enable module system (experimental)"
|
||||
}
|
||||
|
||||
@[export lean_write_module]
|
||||
def writeModule (env : Environment) (fname : System.FilePath) : IO Unit := do
|
||||
saveModuleData fname env.mainModule (← mkModuleData env)
|
||||
def writeModule (env : Environment) (fname : System.FilePath) (split := false) : IO Unit := do
|
||||
if split then
|
||||
let mkPart (level : OLeanLevel) :=
|
||||
return (level.adjustFileName fname, (← mkModuleData env level))
|
||||
saveModuleDataParts env.mainModule #[
|
||||
(← mkPart .exported),
|
||||
(← mkPart .server),
|
||||
(← mkPart .private)]
|
||||
else
|
||||
saveModuleData fname env.mainModule (← mkModuleData env)
|
||||
|
||||
/--
|
||||
Construct a mapping from persistent extension name to extension index at the array of persistent extensions.
|
||||
@@ -1527,10 +1616,9 @@ def mkExtNameMap (startingAt : Nat) : IO (Std.HashMap Name Nat) := do
|
||||
result := result.insert descr.name i
|
||||
return result
|
||||
|
||||
private def setImportedEntries (env : Environment) (mods : Array ModuleData) (startingAt : Nat := 0) : IO Environment := do
|
||||
-- We work directly on the states array instead of `env` as `Environment.modifyState` introduces
|
||||
-- significant overhead on such frequent calls
|
||||
let mut states := env.base.extensions
|
||||
private def setImportedEntries (states : Array EnvExtensionState) (mods : Array ModuleData)
|
||||
(startingAt : Nat := 0) : IO (Array EnvExtensionState) := do
|
||||
let mut states := states
|
||||
let extDescrs ← persistentEnvExtensionsRef.get
|
||||
/- For extensions starting at `startingAt`, ensure their `importedEntries` array have size `mods.size`. -/
|
||||
for extDescr in extDescrs[startingAt:] do
|
||||
@@ -1546,7 +1634,7 @@ private def setImportedEntries (env : Environment) (mods : Array ModuleData) (st
|
||||
-- safety: as in `modifyState`
|
||||
states := unsafe extDescrs[entryIdx]!.toEnvExtension.modifyStateImpl states fun s =>
|
||||
{ s with importedEntries := s.importedEntries.set! modIdx entries }
|
||||
return env.setCheckedSync { env.base with extensions := states }
|
||||
return states
|
||||
|
||||
/--
|
||||
"Forward declaration" needed for updating the attribute table with user-defined attributes.
|
||||
@@ -1585,7 +1673,7 @@ where
|
||||
-- This branch is executed when `pExtDescrs[i]` is the extension associated with the `init` attribute, and
|
||||
-- a user-defined persistent extension is imported.
|
||||
-- Thus, we invoke `setImportedEntries` to update the array `importedEntries` with the entries for the new extensions.
|
||||
env ← setImportedEntries env mods prevSize
|
||||
env := env.setCheckedSync { env.base with extensions := (← setImportedEntries env.base.extensions mods prevSize) }
|
||||
-- See comment at `updateEnvAttributesRef`
|
||||
env ← updateEnvAttributes env
|
||||
loop (i + 1) env
|
||||
@@ -1596,7 +1684,7 @@ structure ImportState where
|
||||
moduleNameSet : NameHashSet := {}
|
||||
moduleNames : Array Name := #[]
|
||||
moduleData : Array ModuleData := #[]
|
||||
regions : Array CompactedRegion := #[]
|
||||
parts : Array (Array (ModuleData × CompactedRegion)) := #[]
|
||||
|
||||
def throwAlreadyImported (s : ImportState) (const2ModIdx : Std.HashMap Name ModuleIdx) (modIdx : Nat) (cname : Name) : IO α := do
|
||||
let modName := s.moduleNames[modIdx]!
|
||||
@@ -1608,7 +1696,8 @@ abbrev ImportStateM := StateRefT ImportState IO
|
||||
@[inline] nonrec def ImportStateM.run (x : ImportStateM α) (s : ImportState := {}) : IO (α × ImportState) :=
|
||||
x.run s
|
||||
|
||||
partial def importModulesCore (imports : Array Import) : ImportStateM Unit := do
|
||||
partial def importModulesCore (imports : Array Import) (level := OLeanLevel.private) :
|
||||
ImportStateM Unit := do
|
||||
for i in imports do
|
||||
if i.runtimeOnly || (← get).moduleNameSet.contains i.module then
|
||||
continue
|
||||
@@ -1616,12 +1705,22 @@ partial def importModulesCore (imports : Array Import) : ImportStateM Unit := do
|
||||
let mFile ← findOLean i.module
|
||||
unless (← mFile.pathExists) do
|
||||
throw <| IO.userError s!"object file '{mFile}' of module {i.module} does not exist"
|
||||
let (mod, region) ← readModuleData mFile
|
||||
importModulesCore mod.imports
|
||||
let mut fnames := #[mFile]
|
||||
if level != OLeanLevel.exported then
|
||||
let sFile := OLeanLevel.server.adjustFileName mFile
|
||||
if (← sFile.pathExists) then
|
||||
fnames := fnames.push sFile
|
||||
if level == OLeanLevel.private then
|
||||
let pFile := OLeanLevel.private.adjustFileName mFile
|
||||
if (← pFile.pathExists) then
|
||||
fnames := fnames.push pFile
|
||||
let parts ← readModuleDataParts fnames
|
||||
let some (mod, _) := parts[if level = .exported then 0 else parts.size - 1]? | unreachable!
|
||||
importModulesCore (level := level) mod.imports
|
||||
modify fun s => { s with
|
||||
moduleData := s.moduleData.push mod
|
||||
regions := s.regions.push region
|
||||
moduleNames := s.moduleNames.push i.module
|
||||
parts := s.parts.push parts
|
||||
}
|
||||
|
||||
/--
|
||||
@@ -1685,14 +1784,16 @@ def finalizeImport (s : ImportState) (imports : Array Import) (opts : Options) (
|
||||
extensions := exts
|
||||
header := {
|
||||
trustLevel, imports
|
||||
regions := s.regions
|
||||
regions := s.parts.flatMap (·.map (·.2))
|
||||
moduleNames := s.moduleNames
|
||||
moduleData := s.moduleData
|
||||
}
|
||||
}
|
||||
realizedImportedConsts? := none
|
||||
}
|
||||
env ← setImportedEntries env s.moduleData
|
||||
env := env.setCheckedSync { env.base with extensions := (← setImportedEntries env.base.extensions s.moduleData) }
|
||||
let serverData := s.parts.filterMap fun parts => (parts[1]? <|> parts[0]?).map Prod.fst
|
||||
env := { env with serverBaseExts := (← setImportedEntries env.base.extensions serverData) }
|
||||
if leakEnv then
|
||||
/- Mark persistent a first time before `finalizePersistenExtensions`, which
|
||||
avoids costly MT markings when e.g. an interpreter closure (which
|
||||
@@ -1739,13 +1840,13 @@ environment's constant map can be accessed without `loadExts`, many functions th
|
||||
-/
|
||||
def importModules (imports : Array Import) (opts : Options) (trustLevel : UInt32 := 0)
|
||||
(plugins : Array System.FilePath := #[]) (leakEnv := false) (loadExts := false)
|
||||
: IO Environment := profileitIO "import" opts do
|
||||
(level := OLeanLevel.private) : IO Environment := profileitIO "import" opts do
|
||||
for imp in imports do
|
||||
if imp.module matches .anonymous then
|
||||
throw <| IO.userError "import failed, trying to import module with anonymous name"
|
||||
withImporting do
|
||||
plugins.forM Lean.loadPlugin
|
||||
let (_, s) ← importModulesCore imports |>.run
|
||||
let (_, s) ← importModulesCore (level := level) imports |>.run
|
||||
finalizeImport (leakEnv := leakEnv) (loadExts := loadExts) s imports opts trustLevel
|
||||
|
||||
/--
|
||||
|
||||
@@ -93,18 +93,17 @@ structure SnapshotTask (α : Type) where
|
||||
Cancellation token that can be set by the server to cancel the task when it detects the results
|
||||
are not needed anymore.
|
||||
-/
|
||||
cancelTk? : Option IO.CancelToken := none
|
||||
cancelTk? : Option IO.CancelToken
|
||||
/-- Underlying task producing the snapshot. -/
|
||||
task : Task α
|
||||
deriving Nonempty, Inhabited
|
||||
|
||||
/-- Creates a snapshot task from the syntax processed by the task and a `BaseIO` action. -/
|
||||
def SnapshotTask.ofIO (stx? : Option Syntax)
|
||||
def SnapshotTask.ofIO (stx? : Option Syntax) (cancelTk? : Option IO.CancelToken)
|
||||
(reportingRange? : Option String.Range := defaultReportingRange? stx?) (act : BaseIO α) :
|
||||
BaseIO (SnapshotTask α) := do
|
||||
return {
|
||||
stx?
|
||||
reportingRange?
|
||||
stx?, reportingRange?, cancelTk?
|
||||
task := (← BaseIO.asTask act)
|
||||
}
|
||||
|
||||
@@ -114,6 +113,7 @@ def SnapshotTask.finished (stx? : Option Syntax) (a : α) : SnapshotTask α wher
|
||||
-- irrelevant when already finished
|
||||
reportingRange? := none
|
||||
task := .pure a
|
||||
cancelTk? := none
|
||||
|
||||
/-- Transforms a task's output without changing the processed syntax. -/
|
||||
def SnapshotTask.map (t : SnapshotTask α) (f : α → β) (stx? : Option Syntax := t.stx?)
|
||||
|
||||
@@ -397,7 +397,7 @@ where
|
||||
diagnostics := oldProcessed.diagnostics
|
||||
result? := some {
|
||||
cmdState := oldProcSuccess.cmdState
|
||||
firstCmdSnap := { stx? := none, task := prom.result! } } }
|
||||
firstCmdSnap := { stx? := none, task := prom.result!, cancelTk? := cancelTk } } }
|
||||
else
|
||||
return .finished newStx oldProcessed) } }
|
||||
else return old
|
||||
@@ -450,7 +450,7 @@ where
|
||||
processHeader (stx : Syntax) (parserState : Parser.ModuleParserState) :
|
||||
LeanProcessingM (SnapshotTask HeaderProcessedSnapshot) := do
|
||||
let ctx ← read
|
||||
SnapshotTask.ofIO stx (some ⟨0, ctx.input.endPos⟩) <|
|
||||
SnapshotTask.ofIO stx none (some ⟨0, ctx.input.endPos⟩) <|
|
||||
ReaderT.run (r := ctx) <| -- re-enter reader in new task
|
||||
withHeaderExceptions (α := HeaderProcessedSnapshot) ({ · with result? := none }) do
|
||||
let setup ← match (← setupImports stx) with
|
||||
@@ -507,7 +507,7 @@ where
|
||||
infoTree? := cmdState.infoState.trees[0]!
|
||||
result? := some {
|
||||
cmdState
|
||||
firstCmdSnap := { stx? := none, task := prom.result! }
|
||||
firstCmdSnap := { stx? := none, task := prom.result!, cancelTk? := cancelTk }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -523,17 +523,19 @@ where
|
||||
-- from `old`
|
||||
if let some oldNext := old.nextCmdSnap? then do
|
||||
let newProm ← IO.Promise.new
|
||||
let cancelTk ← IO.CancelToken.new
|
||||
-- can reuse range, syntax unchanged
|
||||
BaseIO.chainTask (sync := true) old.resultSnap.task fun oldResult =>
|
||||
-- also wait on old command parse snapshot as parsing is cheap and may allow for
|
||||
-- elaboration reuse
|
||||
BaseIO.chainTask (sync := true) oldNext.task fun oldNext => do
|
||||
let cancelTk ← IO.CancelToken.new
|
||||
parseCmd oldNext newParserState oldResult.cmdState newProm sync cancelTk ctx
|
||||
prom.resolve <| { old with nextCmdSnap? := some {
|
||||
stx? := none
|
||||
reportingRange? := some ⟨newParserState.pos, ctx.input.endPos⟩
|
||||
task := newProm.result! } }
|
||||
task := newProm.result!
|
||||
cancelTk? := cancelTk
|
||||
} }
|
||||
else prom.resolve old -- terminal command, we're done!
|
||||
|
||||
-- fast path, do not even start new task for this snapshot (see [Incremental Parsing])
|
||||
@@ -615,15 +617,16 @@ where
|
||||
})
|
||||
let diagnostics ← Snapshot.Diagnostics.ofMessageLog msgLog
|
||||
|
||||
-- use per-command cancellation token for elaboration so that
|
||||
-- use per-command cancellation token for elaboration so that cancellation of further commands
|
||||
-- does not affect current command
|
||||
let elabCmdCancelTk ← IO.CancelToken.new
|
||||
prom.resolve {
|
||||
diagnostics, nextCmdSnap?
|
||||
stx := stx', parserState := parserState'
|
||||
elabSnap := { stx? := stx', task := elabPromise.result!, cancelTk? := some elabCmdCancelTk }
|
||||
resultSnap := { stx? := stx', reportingRange? := initRange?, task := resultPromise.result! }
|
||||
infoTreeSnap := { stx? := stx', reportingRange? := initRange?, task := finishedPromise.result! }
|
||||
reportSnap := { stx? := none, reportingRange? := initRange?, task := reportPromise.result! }
|
||||
resultSnap := { stx? := stx', reportingRange? := initRange?, task := resultPromise.result!, cancelTk? := none }
|
||||
infoTreeSnap := { stx? := stx', reportingRange? := initRange?, task := finishedPromise.result!, cancelTk? := none }
|
||||
reportSnap := { stx? := none, reportingRange? := initRange?, task := reportPromise.result!, cancelTk? := none }
|
||||
}
|
||||
let cmdState ← doElab stx cmdState beginPos
|
||||
{ old? := old?.map fun old => ⟨old.stx, old.elabSnap⟩, new := elabPromise }
|
||||
@@ -665,8 +668,8 @@ where
|
||||
-- We want to trace all of `CommandParsedSnapshot` but `traceTask` is part of it, so let's
|
||||
-- create a temporary snapshot tree containing all tasks but it
|
||||
let snaps := #[
|
||||
{ stx? := stx', task := elabPromise.result!.map (sync := true) toSnapshotTree },
|
||||
{ stx? := stx', task := resultPromise.result!.map (sync := true) toSnapshotTree }] ++
|
||||
{ stx? := stx', task := elabPromise.result!.map (sync := true) toSnapshotTree, cancelTk? := none },
|
||||
{ stx? := stx', task := resultPromise.result!.map (sync := true) toSnapshotTree, cancelTk? := none }] ++
|
||||
cmdState.snapshotTasks
|
||||
let tree := SnapshotTree.mk { diagnostics := .empty } snaps
|
||||
BaseIO.bindTask (← tree.waitAll) fun _ => do
|
||||
@@ -690,6 +693,7 @@ where
|
||||
stx? := none
|
||||
reportingRange? := initRange?
|
||||
task := traceTask
|
||||
cancelTk? := none
|
||||
}
|
||||
if let some next := next? then
|
||||
-- We're definitely off the fast-forwarding path now
|
||||
|
||||
@@ -2279,6 +2279,7 @@ def realizeConst (forConst : Name) (constName : Name) (realize : MetaM Unit) :
|
||||
initHeartbeats := (← IO.getNumHeartbeats)
|
||||
}
|
||||
let (env, exTask, dyn) ← env.realizeConst forConst constName (realizeAndReport coreCtx)
|
||||
-- Realizations cannot be cancelled as their result is shared across elaboration runs
|
||||
let exAct ← Core.wrapAsyncAsSnapshot (cancelTk? := none) fun
|
||||
| none => return
|
||||
| some ex => do
|
||||
@@ -2286,6 +2287,7 @@ def realizeConst (forConst : Name) (constName : Name) (realize : MetaM Unit) :
|
||||
Core.logSnapshotTask {
|
||||
stx? := none
|
||||
task := (← BaseIO.mapTask (t := exTask) exAct)
|
||||
cancelTk? := none
|
||||
}
|
||||
if let some res := dyn.get? RealizeConstantResult then
|
||||
let mut snap := res.snap
|
||||
|
||||
@@ -62,6 +62,15 @@ This is triggered by `attribute [-ext] name`.
|
||||
def ExtTheorems.eraseCore (d : ExtTheorems) (declName : Name) : ExtTheorems :=
|
||||
{ d with erased := d.erased.insert declName }
|
||||
|
||||
/-- Returns `true` if `d` contains theorem with name `declName`. -/
|
||||
def ExtTheorems.contains (d : ExtTheorems) (declName : Name) : Bool :=
|
||||
d.tree.containsValueP (·.declName == declName) && !d.erased.contains declName
|
||||
|
||||
/-- Returns `true` if `declName` is tagged with `[ext]` attribute. -/
|
||||
def isExtTheorem (declName : Name) : CoreM Bool := do
|
||||
let extTheorems := extExtension.getState (← getEnv)
|
||||
return extTheorems.contains declName
|
||||
|
||||
/--
|
||||
Erases a name marked as a `ext` attribute.
|
||||
Check that it does in fact have the `ext` attribute by making sure it names a `ExtTheorem`
|
||||
@@ -69,7 +78,7 @@ found somewhere in the state's tree, and is not erased.
|
||||
-/
|
||||
def ExtTheorems.erase [Monad m] [MonadError m] (d : ExtTheorems) (declName : Name) :
|
||||
m ExtTheorems := do
|
||||
unless d.tree.containsValueP (·.declName == declName) && !d.erased.contains declName do
|
||||
unless d.contains declName do
|
||||
throwError "'{declName}' does not have [ext] attribute"
|
||||
return d.eraseCore declName
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ import Lean.Meta.Tactic.Grind.MatchCond
|
||||
import Lean.Meta.Tactic.Grind.MatchDiscrOnly
|
||||
import Lean.Meta.Tactic.Grind.Diseq
|
||||
import Lean.Meta.Tactic.Grind.MBTC
|
||||
import Lean.Meta.Tactic.Grind.Lookahead
|
||||
|
||||
namespace Lean
|
||||
|
||||
@@ -52,6 +53,12 @@ builtin_initialize registerTraceClass `grind.split.candidate
|
||||
builtin_initialize registerTraceClass `grind.split.resolved
|
||||
builtin_initialize registerTraceClass `grind.beta
|
||||
builtin_initialize registerTraceClass `grind.mbtc
|
||||
builtin_initialize registerTraceClass `grind.ext
|
||||
builtin_initialize registerTraceClass `grind.ext.candidate
|
||||
builtin_initialize registerTraceClass `grind.lookahead
|
||||
builtin_initialize registerTraceClass `grind.lookahead.add (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.lookahead.try (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.lookahead.assert (inherited := true)
|
||||
|
||||
/-! Trace options for `grind` developers -/
|
||||
builtin_initialize registerTraceClass `grind.debug
|
||||
@@ -76,5 +83,6 @@ builtin_initialize registerTraceClass `grind.debug.proveEq
|
||||
builtin_initialize registerTraceClass `grind.debug.pushNewFact
|
||||
builtin_initialize registerTraceClass `grind.debug.ematch.activate
|
||||
builtin_initialize registerTraceClass `grind.debug.appMap
|
||||
builtin_initialize registerTraceClass `grind.debug.ext
|
||||
|
||||
end Lean
|
||||
|
||||
@@ -22,50 +22,17 @@ namespace Lean
|
||||
|
||||
builtin_initialize registerTraceClass `grind.cutsat
|
||||
builtin_initialize registerTraceClass `grind.cutsat.model
|
||||
builtin_initialize registerTraceClass `grind.cutsat.subst
|
||||
builtin_initialize registerTraceClass `grind.cutsat.eq
|
||||
builtin_initialize registerTraceClass `grind.cutsat.eq.unsat (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.cutsat.eq.trivial (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.cutsat.assert
|
||||
builtin_initialize registerTraceClass `grind.cutsat.assert.dvd
|
||||
builtin_initialize registerTraceClass `grind.cutsat.dvd
|
||||
builtin_initialize registerTraceClass `grind.cutsat.dvd.update (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.cutsat.dvd.unsat (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.cutsat.dvd.trivial (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.cutsat.dvd.solve (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.cutsat.dvd.solve.combine (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.cutsat.dvd.solve.elim (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.cutsat.internalize
|
||||
builtin_initialize registerTraceClass `grind.cutsat.internalize.term (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.cutsat.assert.trivial
|
||||
builtin_initialize registerTraceClass `grind.cutsat.assert.unsat
|
||||
builtin_initialize registerTraceClass `grind.cutsat.assert.store
|
||||
|
||||
builtin_initialize registerTraceClass `grind.cutsat.assert.le
|
||||
builtin_initialize registerTraceClass `grind.cutsat.le
|
||||
builtin_initialize registerTraceClass `grind.cutsat.le.unsat (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.cutsat.le.trivial (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.cutsat.le.lower (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.cutsat.le.upper (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.cutsat.assign
|
||||
builtin_initialize registerTraceClass `grind.cutsat.conflict
|
||||
|
||||
builtin_initialize registerTraceClass `grind.cutsat.diseq
|
||||
builtin_initialize registerTraceClass `grind.cutsat.diseq.trivial (inherited := true)
|
||||
|
||||
builtin_initialize registerTraceClass `grind.debug.cutsat.eq
|
||||
builtin_initialize registerTraceClass `grind.debug.cutsat.dvd.le
|
||||
builtin_initialize registerTraceClass `grind.debug.cutsat.diseq
|
||||
builtin_initialize registerTraceClass `grind.debug.cutsat.diseq.split
|
||||
builtin_initialize registerTraceClass `grind.debug.cutsat.backtrack
|
||||
builtin_initialize registerTraceClass `grind.debug.cutsat.search
|
||||
builtin_initialize registerTraceClass `grind.debug.cutsat.cooper
|
||||
builtin_initialize registerTraceClass `grind.debug.cutsat.cooper.diseq
|
||||
builtin_initialize registerTraceClass `grind.debug.cutsat.conflict
|
||||
builtin_initialize registerTraceClass `grind.debug.cutsat.assign
|
||||
builtin_initialize registerTraceClass `grind.debug.cutsat.subst
|
||||
builtin_initialize registerTraceClass `grind.debug.cutsat.getBestLower
|
||||
builtin_initialize registerTraceClass `grind.debug.cutsat.nat
|
||||
builtin_initialize registerTraceClass `grind.debug.cutsat.proof
|
||||
builtin_initialize registerTraceClass `grind.debug.cutsat.search
|
||||
builtin_initialize registerTraceClass `grind.debug.cutsat.search.split (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.debug.cutsat.search.assign (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.debug.cutsat.search.conflict (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.debug.cutsat.search.backtrack (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.debug.cutsat.internalize
|
||||
builtin_initialize registerTraceClass `grind.debug.cutsat.markTerm
|
||||
builtin_initialize registerTraceClass `grind.debug.cutsat.natCast
|
||||
|
||||
end Lean
|
||||
|
||||
@@ -34,7 +34,7 @@ def DvdCnstr.applyEq (a : Int) (x : Var) (c₁ : EqCnstr) (b : Int) (c₂ : DvdC
|
||||
let q := c₂.p
|
||||
let d := Int.ofNat (a * c₂.d).natAbs
|
||||
let p := (q.mul a |>.combine (p.mul (-b)))
|
||||
trace[grind.cutsat.subst] "{← getVar x}, {← c₁.pp}, {← c₂.pp}"
|
||||
trace[grind.debug.cutsat.subst] "{← getVar x}, {← c₁.pp}, {← c₂.pp}"
|
||||
return { d, p, h := .subst x c₁ c₂ }
|
||||
|
||||
partial def DvdCnstr.applySubsts (c : DvdCnstr) : GoalM DvdCnstr := withIncRecDepth do
|
||||
@@ -46,21 +46,20 @@ partial def DvdCnstr.applySubsts (c : DvdCnstr) : GoalM DvdCnstr := withIncRecDe
|
||||
/-- Asserts divisibility constraint. -/
|
||||
partial def DvdCnstr.assert (c : DvdCnstr) : GoalM Unit := withIncRecDepth do
|
||||
if (← inconsistent) then return ()
|
||||
trace[grind.cutsat.dvd] "{← c.pp}"
|
||||
trace[grind.cutsat.assert] "{← c.pp}"
|
||||
let c ← c.norm.applySubsts
|
||||
if c.isUnsat then
|
||||
trace[grind.cutsat.dvd.unsat] "{← c.pp}"
|
||||
trace[grind.cutsat.assert.unsat] "{← c.pp}"
|
||||
setInconsistent (.dvd c)
|
||||
return ()
|
||||
if c.isTrivial then
|
||||
trace[grind.cutsat.dvd.trivial] "{← c.pp}"
|
||||
trace[grind.cutsat.assert.trivial] "{← c.pp}"
|
||||
return ()
|
||||
let d₁ := c.d
|
||||
let .add a₁ x p₁ := c.p | c.throwUnexpected
|
||||
if (← c.satisfied) == .false then
|
||||
resetAssignmentFrom x
|
||||
if let some c' := (← get').dvds[x]! then
|
||||
trace[grind.cutsat.dvd.solve] "{← c.pp}, {← c'.pp}"
|
||||
let d₂ := c'.d
|
||||
let .add a₂ _ p₂ := c'.p | c'.throwUnexpected
|
||||
let (d, α, β) := gcdExt (a₁*d₂) (a₂*d₁)
|
||||
@@ -75,16 +74,14 @@ partial def DvdCnstr.assert (c : DvdCnstr) : GoalM Unit := withIncRecDepth do
|
||||
let α_d₂_p₁ := p₁.mul (α*d₂)
|
||||
let β_d₁_p₂ := p₂.mul (β*d₁)
|
||||
let combine := { d := d₁*d₂, p := .add d x (α_d₂_p₁.combine β_d₁_p₂), h := .solveCombine c c' : DvdCnstr }
|
||||
trace[grind.cutsat.dvd.solve.combine] "{← combine.pp}"
|
||||
modify' fun s => { s with dvds := s.dvds.set x none}
|
||||
combine.assert
|
||||
let a₂_p₁ := p₁.mul a₂
|
||||
let a₁_p₂ := p₂.mul (-a₁)
|
||||
let elim := { d, p := a₂_p₁.combine a₁_p₂, h := .solveElim c c' : DvdCnstr }
|
||||
trace[grind.cutsat.dvd.solve.elim] "{← elim.pp}"
|
||||
elim.assert
|
||||
else
|
||||
trace[grind.cutsat.dvd.update] "{← c.pp}"
|
||||
trace[grind.cutsat.assert.store] "{← c.pp}"
|
||||
c.p.updateOccs
|
||||
modify' fun s => { s with dvds := s.dvds.set x (some c) }
|
||||
|
||||
@@ -97,7 +94,6 @@ def propagateIntDvd (e : Expr) : GoalM Unit := do
|
||||
if (← isEqTrue e) then
|
||||
let p ← toPoly b
|
||||
let c := { d, p, h := .core e : DvdCnstr }
|
||||
trace[grind.cutsat.assert.dvd] "{← c.pp}"
|
||||
c.assert
|
||||
else if (← isEqFalse e) then
|
||||
pushNewFact <| mkApp4 (mkConst ``Int.Linear.of_not_dvd) a b reflBoolTrue (mkOfEqFalseCore e (← mkEqFalseProof e))
|
||||
|
||||
@@ -38,12 +38,12 @@ def DiseqCnstr.applyEq (a : Int) (x : Var) (c₁ : EqCnstr) (b : Int) (c₂ : Di
|
||||
let p := c₁.p
|
||||
let q := c₂.p
|
||||
let p := p.mul b |>.combine (q.mul (-a))
|
||||
trace[grind.cutsat.subst] "{← getVar x}, {← c₁.pp}, {← c₂.pp}"
|
||||
trace[grind.debug.cutsat.subst] "{← getVar x}, {← c₁.pp}, {← c₂.pp}"
|
||||
return { p, h := .subst x c₁ c₂ }
|
||||
|
||||
partial def DiseqCnstr.applySubsts (c : DiseqCnstr) : GoalM DiseqCnstr := withIncRecDepth do
|
||||
let some (x, c₁, p) ← c.p.substVar | return c
|
||||
trace[grind.cutsat.subst] "{← getVar x}, {← c.pp}, {← c₁.pp}"
|
||||
trace[grind.debug.cutsat.subst] "{← getVar x}, {← c.pp}, {← c₁.pp}"
|
||||
applySubsts { p, h := .subst x c₁ c }
|
||||
|
||||
/--
|
||||
@@ -68,10 +68,11 @@ def DiseqCnstr.assert (c : DiseqCnstr) : GoalM Unit := do
|
||||
trace[grind.cutsat.assert] "{← c.pp}"
|
||||
let c ← c.norm.applySubsts
|
||||
if c.p.isUnsatDiseq then
|
||||
trace[grind.cutsat.assert.unsat] "{← c.pp}"
|
||||
setInconsistent (.diseq c)
|
||||
return ()
|
||||
if c.isTrivial then
|
||||
trace[grind.cutsat.diseq.trivial] "{← c.pp}"
|
||||
trace[grind.cutsat.assert.trivial] "{← c.pp}"
|
||||
return ()
|
||||
let k := c.p.gcdCoeffs c.p.getConst
|
||||
let c := if k == 1 then
|
||||
@@ -82,7 +83,7 @@ def DiseqCnstr.assert (c : DiseqCnstr) : GoalM Unit := do
|
||||
return ()
|
||||
let .add _ x _ := c.p | c.throwUnexpected
|
||||
c.p.updateOccs
|
||||
trace[grind.cutsat.diseq] "{← c.pp}"
|
||||
trace[grind.cutsat.assert.store] "{← c.pp}"
|
||||
modify' fun s => { s with diseqs := s.diseqs.modify x (·.push c) }
|
||||
if (← c.satisfied) == .false then
|
||||
resetAssignmentFrom x
|
||||
@@ -108,7 +109,7 @@ where
|
||||
|
||||
partial def EqCnstr.applySubsts (c : EqCnstr) : GoalM EqCnstr := withIncRecDepth do
|
||||
let some (x, c₁, p) ← c.p.substVar | return c
|
||||
trace[grind.cutsat.subst] "{← getVar x}, {← c.pp}, {← c₁.pp}"
|
||||
trace[grind.debug.cutsat.subst] "{← getVar x}, {← c.pp}, {← c₁.pp}"
|
||||
applySubsts { p, h := .subst x c₁ c : EqCnstr }
|
||||
|
||||
private def updateDvdCnstr (a : Int) (x : Var) (c : EqCnstr) (y : Var) : GoalM Unit := do
|
||||
@@ -197,10 +198,11 @@ def EqCnstr.assertImpl (c : EqCnstr) : GoalM Unit := do
|
||||
trace[grind.cutsat.assert] "{← c.pp}"
|
||||
let c ← c.norm.applySubsts
|
||||
if c.p.isUnsatEq then
|
||||
trace[grind.cutsat.assert.unsat] "{← c.pp}"
|
||||
setInconsistent (.eq c)
|
||||
return ()
|
||||
if c.isTrivial then
|
||||
trace[grind.cutsat.eq.trivial] "{← c.pp}"
|
||||
trace[grind.cutsat.assert.trivial] "{← c.pp}"
|
||||
return ()
|
||||
let k := c.p.gcdCoeffs'
|
||||
if c.p.getConst % k > 0 then
|
||||
@@ -210,9 +212,9 @@ def EqCnstr.assertImpl (c : EqCnstr) : GoalM Unit := do
|
||||
c
|
||||
else
|
||||
{ p := c.p.div k, h := .divCoeffs c }
|
||||
trace[grind.cutsat.eq] "{← c.pp}"
|
||||
let some (k, x) := c.p.pickVarToElim? | c.throwUnexpected
|
||||
trace[grind.debug.cutsat.subst] ">> {← getVar x}, {← c.pp}"
|
||||
trace[grind.cutsat.assert.store] "{← c.pp}"
|
||||
modify' fun s => { s with
|
||||
elimEqs := s.elimEqs.set x (some c)
|
||||
elimStack := x :: s.elimStack
|
||||
@@ -252,7 +254,6 @@ private def processNewNatEq (a b : Expr) : GoalM Unit := do
|
||||
|
||||
@[export lean_process_cutsat_eq]
|
||||
def processNewEqImpl (a b : Expr) : GoalM Unit := do
|
||||
trace[grind.debug.cutsat.eq] "{a} = {b}"
|
||||
match (← foreignTerm? a), (← foreignTerm? b) with
|
||||
| none, none => processNewIntEq a b
|
||||
| some .nat, some .nat => processNewNatEq a b
|
||||
@@ -271,7 +272,6 @@ private def processNewIntLitEq (a ke : Expr) : GoalM Unit := do
|
||||
|
||||
@[export lean_process_cutsat_eq_lit]
|
||||
def processNewEqLitImpl (a ke : Expr) : GoalM Unit := do
|
||||
trace[grind.debug.cutsat.eq] "{a} = {ke}"
|
||||
match (← foreignTerm? a) with
|
||||
| none => processNewIntLitEq a ke
|
||||
| some .nat => processNewNatEq a ke
|
||||
@@ -294,12 +294,10 @@ private def processNewNatDiseq (a b : Expr) : GoalM Unit := do
|
||||
let rhs' ← toLinearExpr (← rhs.denoteAsIntExpr ctx) gen
|
||||
let p := lhs'.sub rhs' |>.norm
|
||||
let c := { p, h := .coreNat a b lhs rhs lhs' rhs' : DiseqCnstr }
|
||||
trace[grind.debug.cutsat.nat] "{← c.pp}"
|
||||
c.assert
|
||||
|
||||
@[export lean_process_cutsat_diseq]
|
||||
def processNewDiseqImpl (a b : Expr) : GoalM Unit := do
|
||||
trace[grind.debug.cutsat.diseq] "{a} ≠ {b}"
|
||||
match (← foreignTerm? a), (← foreignTermOrLit? b) with
|
||||
| none, none => processNewIntDiseq a b
|
||||
| some .nat, some .nat => processNewNatDiseq a b
|
||||
@@ -342,7 +340,7 @@ private def isForbiddenParent (parent? : Option Expr) (k : SupportedTermKind) :
|
||||
private def internalizeInt (e : Expr) : GoalM Unit := do
|
||||
if (← hasVar e) then return ()
|
||||
let p ← toPoly e
|
||||
trace[grind.cutsat.internalize] "{aquote e}:= {← p.pp}"
|
||||
trace[grind.debug.cutsat.internalize] "{aquote e}:= {← p.pp}"
|
||||
let x ← mkVar e
|
||||
if p == .add 1 x (.num 0) then
|
||||
-- It is pointless to assert `x = x`
|
||||
@@ -403,9 +401,8 @@ private def internalizeNat (e : Expr) : GoalM Unit := do
|
||||
let e'' : Int.Linear.Expr ← toLinearExpr e'' gen
|
||||
let p := e''.norm
|
||||
let natCast_e ← shareCommon (mkIntNatCast e)
|
||||
trace[grind.cutsat.internalize] "natCast: {natCast_e}"
|
||||
internalize natCast_e gen
|
||||
trace[grind.cutsat.internalize] "{aquote natCast_e}:= {← p.pp}"
|
||||
trace[grind.debug.cutsat.internalize] "{aquote natCast_e}:= {← p.pp}"
|
||||
let x ← mkVar natCast_e
|
||||
modify' fun s => { s with foreignDef := s.foreignDef.insert { expr := e } x }
|
||||
let c := { p := .add (-1) x p, h := .defnNat e' x e'' : EqCnstr }
|
||||
|
||||
@@ -24,7 +24,6 @@ def mkForeignVar (e : Expr) (t : ForeignType) : GoalM Var := do
|
||||
foreignVars := s.foreignVars.insert t (vars.push e)
|
||||
foreignVarMap := s.foreignVarMap.insert { expr := e} (x, t)
|
||||
}
|
||||
trace[grind.debug.cutsat.markTerm] "mkForeignVar: {e}"
|
||||
markAsCutsatTerm e
|
||||
return x
|
||||
|
||||
|
||||
@@ -100,24 +100,24 @@ where
|
||||
@[export lean_grind_cutsat_assert_le]
|
||||
def LeCnstr.assertImpl (c : LeCnstr) : GoalM Unit := do
|
||||
if (← inconsistent) then return ()
|
||||
trace[grind.cutsat.assert] "{← c.pp}"
|
||||
let c ← c.norm.applySubsts
|
||||
if c.isUnsat then
|
||||
trace[grind.cutsat.le.unsat] "{← c.pp}"
|
||||
trace[grind.cutsat.assert.unsat] "{← c.pp}"
|
||||
setInconsistent (.le c)
|
||||
return ()
|
||||
if c.isTrivial then
|
||||
trace[grind.cutsat.le.trivial] "{← c.pp}"
|
||||
trace[grind.cutsat.assert.trivial] "{← c.pp}"
|
||||
return ()
|
||||
let .add a x _ := c.p | c.throwUnexpected
|
||||
if (← findEq c) then
|
||||
return ()
|
||||
let c ← refineWithDiseq c
|
||||
trace[grind.cutsat.assert.store] "{← c.pp}"
|
||||
if a < 0 then
|
||||
trace[grind.cutsat.le.lower] "{← c.pp}"
|
||||
c.p.updateOccs
|
||||
modify' fun s => { s with lowers := s.lowers.modify x (·.push c) }
|
||||
else
|
||||
trace[grind.cutsat.le.upper] "{← c.pp}"
|
||||
c.p.updateOccs
|
||||
modify' fun s => { s with uppers := s.uppers.modify x (·.push c) }
|
||||
if (← c.satisfied) == .false then
|
||||
@@ -145,7 +145,6 @@ def propagateIntLe (e : Expr) (eqTrue : Bool) : GoalM Unit := do
|
||||
pure { p, h := .core e : LeCnstr }
|
||||
else
|
||||
pure { p := p.mul (-1) |>.addConst 1, h := .coreNeg e p : LeCnstr }
|
||||
trace[grind.cutsat.assert.le] "{← c.pp}"
|
||||
c.assert
|
||||
|
||||
def propagateNatLe (e : Expr) (eqTrue : Bool) : GoalM Unit := do
|
||||
@@ -155,7 +154,6 @@ def propagateNatLe (e : Expr) (eqTrue : Bool) : GoalM Unit := do
|
||||
let lhs' ← toLinearExpr (← lhs.denoteAsIntExpr ctx) gen
|
||||
let rhs' ← toLinearExpr (← rhs.denoteAsIntExpr ctx) gen
|
||||
let p := lhs'.sub rhs' |>.norm
|
||||
trace[grind.debug.cutsat.nat] "{← p.pp}"
|
||||
let c ← if eqTrue then
|
||||
pure { p, h := .coreNat e lhs rhs lhs' rhs' : LeCnstr }
|
||||
else
|
||||
|
||||
@@ -136,9 +136,7 @@ def assertDenoteAsIntNonneg (e : Expr) : GoalM Unit := withIncRecDepth do
|
||||
let lhs' : Int.Linear.Expr := .num 0
|
||||
let rhs' ← toLinearExpr (← rhs.denoteAsIntExpr ctx) gen
|
||||
let p := lhs'.sub rhs' |>.norm
|
||||
trace[grind.debug.cutsat.nat] "{← p.pp}"
|
||||
let c := { p, h := .denoteAsIntNonneg rhs rhs' : LeCnstr }
|
||||
trace[grind.cutsat.assert.le] "{← c.pp}"
|
||||
c.assert
|
||||
|
||||
/--
|
||||
@@ -149,7 +147,6 @@ def assertNatCast (e : Expr) (x : Var) : GoalM Unit := do
|
||||
let_expr NatCast.natCast _ inst a := e | return ()
|
||||
let_expr instNatCastInt := inst | return ()
|
||||
if (← get').foreignDef.contains { expr := a } then return ()
|
||||
trace[grind.debug.cutsat.natCast] "{a}"
|
||||
let n ← mkForeignVar a .nat
|
||||
let p := .add (-1) x (.num 0)
|
||||
let c := { p, h := .denoteAsIntNonneg (.var n) (.var x) : LeCnstr}
|
||||
|
||||
@@ -301,7 +301,6 @@ partial def LeCnstr.toExprProof (c' : LeCnstr) : ProofM Expr := caching c' do
|
||||
(← getContext) (← mkPolyDecl p₁) (← mkPolyDecl p₂) (← mkPolyDecl c₃.p) (toExpr c₃.d) (toExpr s.k) (toExpr coeff) (← mkPolyDecl c'.p) (← s.toExprProof) reflBoolTrue
|
||||
|
||||
partial def DiseqCnstr.toExprProof (c' : DiseqCnstr) : ProofM Expr := caching c' do
|
||||
trace[grind.debug.cutsat.proof] "{← c'.pp}"
|
||||
match c'.h with
|
||||
| .core0 a zero =>
|
||||
mkDiseqProof a zero
|
||||
@@ -352,7 +351,6 @@ partial def CooperSplit.toExprProof (s : CooperSplit) : ProofM Expr := caching s
|
||||
-- `pred` is an expressions of the form `cooper_*_split ...` with type `Nat → Prop`
|
||||
let mut k := n
|
||||
let mut result := base -- `OrOver k (cooper_*_splti)
|
||||
trace[grind.debug.cutsat.proof] "orOver_cases {n}"
|
||||
result := mkApp3 (mkConst ``Int.Linear.orOver_cases) (toExpr (n-1)) pred result
|
||||
for (fvarId, c) in hs do
|
||||
let type := mkApp pred (toExpr (k-1))
|
||||
@@ -365,7 +363,6 @@ partial def CooperSplit.toExprProof (s : CooperSplit) : ProofM Expr := caching s
|
||||
return result
|
||||
|
||||
partial def UnsatProof.toExprProofCore (h : UnsatProof) : ProofM Expr := do
|
||||
trace[grind.debug.cutsat.proof] "{← h.pp}"
|
||||
match h with
|
||||
| .le c =>
|
||||
return mkApp4 (mkConst ``Int.Linear.le_unsat) (← getContext) (← mkPolyDecl c.p) reflBoolTrue (← c.toExprProof)
|
||||
@@ -392,7 +389,6 @@ def UnsatProof.toExprProof (h : UnsatProof) : GoalM Expr := do
|
||||
withProofContext do h.toExprProofCore
|
||||
|
||||
def setInconsistent (h : UnsatProof) : GoalM Unit := do
|
||||
trace[grind.debug.cutsat.conflict] "setInconsistent [{← inconsistent}]: {← h.pp}"
|
||||
if (← get').caseSplits then
|
||||
-- Let the search procedure in `SearchM` resolve the conflict.
|
||||
modify' fun s => { s with conflict? := some h }
|
||||
|
||||
@@ -27,7 +27,6 @@ def CooperSplit.assert (cs : CooperSplit) : GoalM Unit := do
|
||||
let p₁' := p.mul b |>.combine (q.mul (-a))
|
||||
let p₁' := p₁'.addConst <| if left then b*k else (-a)*k
|
||||
let c₁' := { p := p₁', h := .cooper cs : LeCnstr }
|
||||
trace[grind.debug.cutsat.cooper] "{← c₁'.pp}"
|
||||
c₁'.assert
|
||||
if (← inconsistent) then return ()
|
||||
let d := if left then a else b
|
||||
@@ -35,7 +34,6 @@ def CooperSplit.assert (cs : CooperSplit) : GoalM Unit := do
|
||||
let p₂' := if left then p else q
|
||||
let p₂' := p₂'.addConst k
|
||||
let c₂' := { d, p := p₂', h := .cooper₁ cs : DvdCnstr }
|
||||
trace[grind.debug.cutsat.cooper] "dvd₁: {← c₂'.pp}"
|
||||
c₂'.assert
|
||||
if (← inconsistent) then return ()
|
||||
let some c₃ := c₃? | return ()
|
||||
@@ -51,7 +49,6 @@ def CooperSplit.assert (cs : CooperSplit) : GoalM Unit := do
|
||||
let p₃' := q.mul (-c) |>.combine (s.mul b)
|
||||
let p₃' := p₃'.addConst (-c*k)
|
||||
{ d := b*d, p := p₃', h := .cooper₂ cs : DvdCnstr }
|
||||
trace[grind.debug.cutsat.cooper] "dvd₂: {← c₃'.pp}"
|
||||
c₃'.assert
|
||||
|
||||
private def checkIsNextVar (x : Var) : GoalM Unit := do
|
||||
@@ -59,7 +56,7 @@ private def checkIsNextVar (x : Var) : GoalM Unit := do
|
||||
throwError "`grind` internal error, assigning variable out of order"
|
||||
|
||||
private def traceAssignment (x : Var) (v : Rat) : GoalM Unit := do
|
||||
trace[grind.cutsat.assign] "{quoteIfArithTerm (← getVar x)} := {v}"
|
||||
trace[grind.debug.cutsat.search.assign] "{quoteIfArithTerm (← getVar x)} := {v}"
|
||||
|
||||
private def setAssignment (x : Var) (v : Rat) : GoalM Unit := do
|
||||
checkIsNextVar x
|
||||
@@ -88,7 +85,6 @@ where
|
||||
modify' fun s => { s with assignment := s.assignment.set x 0 }
|
||||
let some v ← c.p.eval? | c.throwUnexpected
|
||||
let v := (-v) / a
|
||||
trace[grind.debug.cutsat.assign] "{← getVar x}, {← c.pp}, {v}"
|
||||
traceAssignment x v
|
||||
modify' fun s => { s with assignment := s.assignment.set x v }
|
||||
go xs
|
||||
@@ -108,7 +104,6 @@ def tightUsingDvd (c : LeCnstr) (dvd? : Option DvdCnstr) : GoalM LeCnstr := do
|
||||
let b₂ := c.p.getConst
|
||||
if (b₂ - b₁) % d != 0 then
|
||||
let b₂' := b₁ - d * ((b₁ - b₂) / d)
|
||||
trace[grind.debug.cutsat.dvd.le] "[pos] {← c.pp}, {← dvd.pp}, {b₂'}"
|
||||
let p := c.p.addConst (b₂'-b₂)
|
||||
return { p, h := .dvdTight dvd c }
|
||||
if eqCoeffs dvd.p c.p true then
|
||||
@@ -116,7 +111,6 @@ def tightUsingDvd (c : LeCnstr) (dvd? : Option DvdCnstr) : GoalM LeCnstr := do
|
||||
let b₂ := c.p.getConst
|
||||
if (b₂ - b₁) % d != 0 then
|
||||
let b₂' := b₁ - d * ((b₁ - b₂) / d)
|
||||
trace[grind.debug.cutsat.dvd.le] "[neg] {← c.pp}, {← dvd.pp}, {b₂'}"
|
||||
let p := c.p.addConst (b₂'-b₂)
|
||||
return { p, h := .negDvdTight dvd c }
|
||||
return c
|
||||
@@ -134,7 +128,6 @@ def getBestLower? (x : Var) (dvd? : Option DvdCnstr) : GoalM (Option (Rat × LeC
|
||||
let .add k _ p := c.p | c.throwUnexpected
|
||||
let some v ← p.eval? | c.throwUnexpected
|
||||
let lower' := v / (-k)
|
||||
trace[grind.debug.cutsat.getBestLower] "k: {k}, x: {x}, p: {repr p}, v: {v}, best?: {best?.map (·.1)}, c: {← c.pp}"
|
||||
if let some (lower, _) := best? then
|
||||
if lower' > lower then
|
||||
best? := some (lower', c)
|
||||
@@ -214,7 +207,7 @@ def DvdCnstr.getSolutions? (c : DvdCnstr) : SearchM (Option DvdSolution) := do
|
||||
return some { d, b := -b*a' }
|
||||
|
||||
def resolveDvdConflict (c : DvdCnstr) : GoalM Unit := do
|
||||
trace[grind.cutsat.conflict] "{← c.pp}"
|
||||
trace[grind.debug.cutsat.search.conflict] "{← c.pp}"
|
||||
let d := c.d
|
||||
let .add a _ p := c.p | c.throwUnexpected
|
||||
{ d := a.gcd d, p, h := .elim c : DvdCnstr }.assert
|
||||
@@ -300,7 +293,7 @@ partial def findRatVal (lower upper : Rat) (diseqVals : Array (Rat × DiseqCnstr
|
||||
v
|
||||
|
||||
def resolveRealLowerUpperConflict (c₁ c₂ : LeCnstr) : GoalM Bool := do
|
||||
trace[grind.cutsat.conflict] "{← c₁.pp}, {← c₂.pp}"
|
||||
trace[grind.debug.cutsat.search.conflict] "{← c₁.pp}, {← c₂.pp}"
|
||||
let .add a₁ _ p₁ := c₁.p | c₁.throwUnexpected
|
||||
let .add a₂ _ p₂ := c₂.p | c₂.throwUnexpected
|
||||
let p := p₁.mul a₂.natAbs |>.combine (p₂.mul a₁.natAbs)
|
||||
@@ -313,7 +306,7 @@ def resolveRealLowerUpperConflict (c₁ c₂ : LeCnstr) : GoalM Bool := do
|
||||
{ p, h := .combine c₁ c₂ : LeCnstr }
|
||||
else
|
||||
{ p := p.div k, h := .combineDivCoeffs c₁ c₂ k : LeCnstr }
|
||||
trace[grind.cutsat.conflict] "resolved: {← c.pp}"
|
||||
trace[grind.debug.cutsat.search.conflict] "resolved: {← c.pp}"
|
||||
c.assert
|
||||
return true
|
||||
|
||||
@@ -330,7 +323,7 @@ def resolveCooperUnary (pred : CooperSplitPred) : SearchM Bool := do
|
||||
return true
|
||||
|
||||
def resolveCooperPred (pred : CooperSplitPred) : SearchM Unit := do
|
||||
trace[grind.cutsat.conflict] "[{pred.numCases}]: {← pred.pp}"
|
||||
trace[grind.debug.cutsat.search.conflict] "[{pred.numCases}]: {← pred.pp}"
|
||||
if (← resolveCooperUnary pred) then
|
||||
return
|
||||
let n := pred.numCases
|
||||
@@ -347,11 +340,11 @@ def resolveCooperDvd (c₁ c₂ : LeCnstr) (c₃ : DvdCnstr) : SearchM Unit := d
|
||||
|
||||
def DiseqCnstr.split (c : DiseqCnstr) : SearchM LeCnstr := do
|
||||
let fvarId ← if let some fvarId := (← get').diseqSplits.find? c.p then
|
||||
trace[grind.debug.cutsat.diseq.split] "{← c.pp}, reusing {fvarId.name}"
|
||||
trace[grind.debug.cutsat.search.split] "{← c.pp}, reusing {fvarId.name}"
|
||||
pure fvarId
|
||||
else
|
||||
let fvarId ← mkCase (.diseq c)
|
||||
trace[grind.debug.cutsat.diseq.split] "{← c.pp}, {fvarId.name}"
|
||||
trace[grind.debug.cutsat.search.split] "{← c.pp}, {fvarId.name}"
|
||||
modify' fun s => { s with diseqSplits := s.diseqSplits.insert c.p fvarId }
|
||||
pure fvarId
|
||||
let p₂ := c.p.addConst 1
|
||||
@@ -428,7 +421,6 @@ def processVar (x : Var) : SearchM Unit := do
|
||||
setAssignment x v
|
||||
| some (lower, c₁), some (upper, c₂) =>
|
||||
trace[grind.debug.cutsat.search] "{lower} ≤ {lower.ceil} ≤ {quoteIfArithTerm (← getVar x)} ≤ {upper.floor} ≤ {upper}"
|
||||
trace[grind.debug.cutsat.getBestLower] "lower: {lower}, c₁: {← c₁.pp}"
|
||||
if lower > upper then
|
||||
let .true ← resolveRealLowerUpperConflict c₁ c₂
|
||||
| throwError "`grind` internal error, conflict resolution failed"
|
||||
@@ -472,43 +464,42 @@ private def findCase (decVars : FVarIdSet) : SearchM Case := do
|
||||
if decVars.contains case.fvarId then
|
||||
return case
|
||||
-- Conflict does not depend on this case.
|
||||
trace[grind.debug.cutsat.backtrack] "skipping {case.fvarId.name}"
|
||||
trace[grind.debug.cutsat.search.backtrack] "skipping {case.fvarId.name}"
|
||||
unreachable!
|
||||
|
||||
private def union (vs₁ vs₂ : FVarIdSet) : FVarIdSet :=
|
||||
vs₁.fold (init := vs₂) (·.insert ·)
|
||||
|
||||
def resolveConflict (h : UnsatProof) : SearchM Unit := do
|
||||
trace[grind.debug.cutsat.backtrack] "resolve conflict, decision stack: {(← get).cases.toList.map fun c => c.fvarId.name}"
|
||||
trace[grind.debug.cutsat.search.backtrack] "resolve conflict, decision stack: {(← get).cases.toList.map fun c => c.fvarId.name}"
|
||||
let decVars := h.collectDecVars.run (← get).decVars
|
||||
trace[grind.debug.cutsat.backtrack] "dec vars: {decVars.toList.map (·.name)}"
|
||||
trace[grind.debug.cutsat.search.backtrack] "dec vars: {decVars.toList.map (·.name)}"
|
||||
if decVars.isEmpty then
|
||||
trace[grind.debug.cutsat.backtrack] "close goal: {← h.pp}"
|
||||
trace[grind.debug.cutsat.search.backtrack] "close goal: {← h.pp}"
|
||||
closeGoal (← h.toExprProof)
|
||||
return ()
|
||||
let c ← findCase decVars
|
||||
modify' fun _ => c.saved
|
||||
trace[grind.debug.cutsat.backtrack] "backtracking {c.fvarId.name}"
|
||||
trace[grind.debug.cutsat.search.backtrack] "backtracking {c.fvarId.name}"
|
||||
let decVars := decVars.erase c.fvarId
|
||||
match c.kind with
|
||||
| .diseq c₁ =>
|
||||
let decVars := decVars.toArray
|
||||
let p' := c₁.p.mul (-1) |>.addConst 1
|
||||
let c' := { p := p', h := .ofDiseqSplit c₁ c.fvarId h decVars : LeCnstr }
|
||||
trace[grind.debug.cutsat.backtrack] "resolved diseq split: {← c'.pp}"
|
||||
trace[grind.debug.cutsat.search.backtrack] "resolved diseq split: {← c'.pp}"
|
||||
c'.assert
|
||||
| .cooper pred hs decVars' =>
|
||||
let decVars' := union decVars decVars'
|
||||
let n := pred.numCases
|
||||
let hs := hs.push (c.fvarId, h)
|
||||
trace[grind.debug.cutsat.backtrack] "cooper #{hs.size + 1}, {← pred.pp}, {hs.map fun p => p.1.name}"
|
||||
trace[grind.debug.cutsat.search.backtrack] "cooper #{hs.size + 1}, {← pred.pp}, {hs.map fun p => p.1.name}"
|
||||
let s ← if hs.size + 1 < n then
|
||||
let fvarId ← mkCase (.cooper pred hs decVars')
|
||||
pure { pred, k := n - hs.size - 1, h := .dec fvarId : CooperSplit }
|
||||
else
|
||||
let decVars' := decVars'.toArray
|
||||
trace[grind.debug.cutsat.backtrack] "cooper last case, {← pred.pp}, dec vars: {decVars'.map (·.name)}"
|
||||
trace[grind.debug.cutsat.proof] "CooperSplit.last"
|
||||
trace[grind.debug.cutsat.search.backtrack] "cooper last case, {← pred.pp}, dec vars: {decVars'.map (·.name)}"
|
||||
pure { pred, k := 0, h := .last hs decVars' : CooperSplit }
|
||||
s.assert
|
||||
|
||||
|
||||
@@ -74,7 +74,6 @@ def mkCase (kind : CaseKind) : SearchM FVarId := do
|
||||
decVars := s.decVars.insert fvarId
|
||||
}
|
||||
modify' fun s => { s with caseSplits := true }
|
||||
trace[grind.debug.cutsat.backtrack] "mkCase fvarId: {fvarId.name}"
|
||||
return fvarId
|
||||
|
||||
end Lean.Meta.Grind.Arith.Cutsat
|
||||
|
||||
@@ -16,7 +16,7 @@ def mkVarImpl (expr : Expr) : GoalM Var := do
|
||||
if let some var := (← get').varMap.find? { expr } then
|
||||
return var
|
||||
let var : Var := (← get').vars.size
|
||||
trace[grind.cutsat.internalize.term] "{expr} ↦ #{var}"
|
||||
trace[grind.debug.cutsat.internalize] "{expr} ↦ #{var}"
|
||||
modify' fun s => { s with
|
||||
vars := s.vars.push expr
|
||||
varMap := s.varMap.insert { expr } var
|
||||
@@ -27,7 +27,6 @@ def mkVarImpl (expr : Expr) : GoalM Var := do
|
||||
occurs := s.occurs.push {}
|
||||
elimEqs := s.elimEqs.push none
|
||||
}
|
||||
trace[grind.debug.cutsat.markTerm] "mkVar: {expr}"
|
||||
markAsCutsatTerm expr
|
||||
assertNatCast expr var
|
||||
assertDenoteAsIntNonneg expr
|
||||
|
||||
@@ -71,6 +71,7 @@ def isArithTerm (e : Expr) : Bool :=
|
||||
| HMul.hMul _ _ _ _ _ _ => true
|
||||
| HDiv.hDiv _ _ _ _ _ _ => true
|
||||
| HMod.hMod _ _ _ _ _ _ => true
|
||||
| HPow.hPow _ _ _ _ _ _ => true
|
||||
| Neg.neg _ _ _ => true
|
||||
| OfNat.ofNat _ _ _ => true
|
||||
| _ => false
|
||||
|
||||
@@ -6,6 +6,7 @@ Authors: Leonardo de Moura
|
||||
prelude
|
||||
import Lean.Meta.Tactic.Grind.EMatchTheorem
|
||||
import Lean.Meta.Tactic.Grind.Cases
|
||||
import Lean.Meta.Tactic.Grind.ExtAttr
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
@@ -14,6 +15,7 @@ inductive AttrKind where
|
||||
| cases (eager : Bool)
|
||||
| intro
|
||||
| infer
|
||||
| ext
|
||||
|
||||
/-- Return theorem kind for `stx` of the form `Attr.grindThmMod` -/
|
||||
def getAttrKindCore (stx : Syntax) : CoreM AttrKind := do
|
||||
@@ -34,6 +36,7 @@ def getAttrKindCore (stx : Syntax) : CoreM AttrKind := do
|
||||
| `(Parser.Attr.grindMod| cases) => return .cases false
|
||||
| `(Parser.Attr.grindMod| cases eager) => return .cases true
|
||||
| `(Parser.Attr.grindMod| intro) => return .intro
|
||||
| `(Parser.Attr.grindMod| ext) => return .ext
|
||||
| _ => throwError "unexpected `grind` theorem kind: `{stx}`"
|
||||
|
||||
/-- Return theorem kind for `stx` of the form `(Attr.grindMod)?` -/
|
||||
@@ -78,6 +81,7 @@ builtin_initialize
|
||||
addEMatchAttr ctor attrKind .default
|
||||
else
|
||||
throwError "invalid `[grind intro]`, `{declName}` is not an inductive predicate"
|
||||
| .ext => addExtAttr declName attrKind
|
||||
| .infer =>
|
||||
if let some declName ← isCasesAttrCandidate? declName false then
|
||||
addCasesAttr declName false attrKind
|
||||
@@ -91,6 +95,8 @@ builtin_initialize
|
||||
erase := fun declName => MetaM.run' do
|
||||
if (← isCasesAttrCandidate declName false) then
|
||||
eraseCasesAttr declName
|
||||
else if (← isExtTheorem declName) then
|
||||
eraseExtAttr declName
|
||||
else
|
||||
eraseEMatchAttr declName
|
||||
}
|
||||
|
||||
@@ -35,6 +35,7 @@ def instantiateExtTheorem (thm : Ext.ExtTheorem) (e : Expr) : GoalM Unit := with
|
||||
if proof'.hasMVar || prop'.hasMVar then
|
||||
reportIssue! "failed to apply extensionality theorem `{thm.declName}` for {indentExpr e}\nresulting terms contain metavariables"
|
||||
return ()
|
||||
trace[grind.ext] "{prop'}"
|
||||
addNewRawFact proof' prop' ((← getGeneration e) + 1)
|
||||
|
||||
end Lean.Meta.Grind
|
||||
|
||||
43
src/Lean/Meta/Tactic/Grind/ExtAttr.lean
Normal file
43
src/Lean/Meta/Tactic/Grind/ExtAttr.lean
Normal file
@@ -0,0 +1,43 @@
|
||||
/-
|
||||
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.Meta.Tactic.Ext
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
/-! Grind extensionality attribute to mark which `[ext]` theorems should be used. -/
|
||||
|
||||
/-- Extensionality theorems that can be used by `grind` -/
|
||||
abbrev ExtTheorems := PHashSet Name
|
||||
|
||||
builtin_initialize extTheoremsExt : SimpleScopedEnvExtension Name ExtTheorems ←
|
||||
registerSimpleScopedEnvExtension {
|
||||
initial := {}
|
||||
addEntry := fun s declName => s.insert declName
|
||||
}
|
||||
|
||||
def validateExtAttr (declName : Name) : CoreM Unit := do
|
||||
unless (← Ext.isExtTheorem declName) do
|
||||
throwError "invalid `[grind ext]`, `{declName}` is tagged with `[ext]`"
|
||||
|
||||
def addExtAttr (declName : Name) (attrKind : AttributeKind) : CoreM Unit := do
|
||||
validateExtAttr declName
|
||||
extTheoremsExt.add declName attrKind
|
||||
|
||||
private def eraseDecl (s : ExtTheorems) (declName : Name) : CoreM ExtTheorems := do
|
||||
if s.contains declName then
|
||||
return s.erase declName
|
||||
else
|
||||
throwError "`{declName}` is not marked with the `[grind ext]` attribute"
|
||||
|
||||
def eraseExtAttr (declName : Name) : CoreM Unit := do
|
||||
let s := extTheoremsExt.getState (← getEnv)
|
||||
let s ← eraseDecl s declName
|
||||
modifyEnv fun env => extTheoremsExt.modifyState env fun _ => s
|
||||
|
||||
def isExtTheorem (declName : Name) : CoreM Bool := do
|
||||
return extTheoremsExt.getState (← getEnv) |>.contains declName
|
||||
|
||||
end Lean.Meta.Grind
|
||||
@@ -114,4 +114,13 @@ def propagateForallPropDown (e : Expr) : GoalM Unit := do
|
||||
-- (a → b) = True → b = False → a = False
|
||||
pushEqFalse a <| mkApp4 (mkConst ``Grind.eq_false_of_imp_eq_true) a b (← mkEqTrueProof e) (← mkEqFalseProof b)
|
||||
|
||||
builtin_grind_propagator propagateExistsDown ↓Exists := fun e => do
|
||||
if (← isEqFalse e) then
|
||||
let_expr f@Exists α p := e | return ()
|
||||
let u := f.constLevels!
|
||||
let notP := mkApp (mkConst ``Not) (mkApp p (.bvar 0) |>.headBeta)
|
||||
let prop := mkForall `x .default α notP
|
||||
let proof := mkApp3 (mkConst ``forall_not_of_not_exists u) α p (mkOfEqFalseCore e (← mkEqFalseProof e))
|
||||
addNewRawFact proof prop (← getGeneration e)
|
||||
|
||||
end Lean.Meta.Grind
|
||||
|
||||
@@ -50,11 +50,6 @@ private def updateAppMap (e : Expr) : GoalM Unit := do
|
||||
s.appMap.insert key [e]
|
||||
}
|
||||
|
||||
/-- Inserts `e` into the list of case-split candidates. -/
|
||||
private def addSplitCandidate (e : Expr) : GoalM Unit := do
|
||||
trace_goal[grind.split.candidate] "{e}"
|
||||
modify fun s => { s with split.candidates := e :: s.split.candidates }
|
||||
|
||||
private def forbiddenSplitTypes := [``Eq, ``HEq, ``True, ``False]
|
||||
|
||||
/-- Returns `true` if `e` is of the form `@Eq Prop a b` -/
|
||||
@@ -67,10 +62,10 @@ private def checkAndAddSplitCandidate (e : Expr) : GoalM Unit := do
|
||||
match e with
|
||||
| .app .. =>
|
||||
if (← getConfig).splitIte && (e.isIte || e.isDIte) then
|
||||
addSplitCandidate e
|
||||
addSplitCandidate (.default e)
|
||||
return ()
|
||||
if isMorallyIff e then
|
||||
addSplitCandidate e
|
||||
addSplitCandidate (.default e)
|
||||
return ()
|
||||
if (← getConfig).splitMatch then
|
||||
if (← isMatcherApp e) then
|
||||
@@ -79,7 +74,7 @@ private def checkAndAddSplitCandidate (e : Expr) : GoalM Unit := do
|
||||
-- and consequently don't need to be split.
|
||||
return ()
|
||||
else
|
||||
addSplitCandidate e
|
||||
addSplitCandidate (.default e)
|
||||
return ()
|
||||
let .const declName _ := e.getAppFn | return ()
|
||||
if forbiddenSplitTypes.contains declName then
|
||||
@@ -87,16 +82,21 @@ private def checkAndAddSplitCandidate (e : Expr) : GoalM Unit := do
|
||||
unless (← isInductivePredicate declName) do
|
||||
return ()
|
||||
if (← get).split.casesTypes.isSplit declName then
|
||||
addSplitCandidate e
|
||||
addSplitCandidate (.default e)
|
||||
else if (← getConfig).splitIndPred then
|
||||
addSplitCandidate e
|
||||
addSplitCandidate (.default e)
|
||||
| .fvar .. =>
|
||||
let .const declName _ := (← whnfD (← inferType e)).getAppFn | return ()
|
||||
if (← get).split.casesTypes.isSplit declName then
|
||||
addSplitCandidate e
|
||||
addSplitCandidate (.default e)
|
||||
| .forallE _ d _ _ =>
|
||||
if Arith.isRelevantPred d || (← getConfig).splitImp then
|
||||
addSplitCandidate e
|
||||
if (← getConfig).splitImp then
|
||||
addSplitCandidate (.default e)
|
||||
else if Arith.isRelevantPred d then
|
||||
if (← getConfig).lookahead then
|
||||
addLookaheadCandidate (.default e)
|
||||
else
|
||||
addSplitCandidate (.default e)
|
||||
| _ => pure ()
|
||||
|
||||
/--
|
||||
@@ -172,7 +172,7 @@ private def activateTheoremPatterns (fName : Name) (generation : Nat) : GoalM Un
|
||||
modify fun s => { s with ematch.thmMap := thmMap }
|
||||
let appMap := (← get).appMap
|
||||
for thm in thms do
|
||||
trace[grind.debug.ematch.activate] "`{fName}` => `{thm.origin.key}`"
|
||||
trace_goal[grind.debug.ematch.activate] "`{fName}` => `{thm.origin.key}`"
|
||||
unless (← get).ematch.thmMap.isErased thm.origin do
|
||||
let symbols := thm.symbols.filter fun sym => !appMap.contains sym
|
||||
let thm := { thm with symbols }
|
||||
@@ -207,6 +207,68 @@ private def propagateUnitLike (a : Expr) (generation : Nat) : GoalM Unit := do
|
||||
internalize unit generation
|
||||
pushEq a unit <| (← mkEqRefl unit)
|
||||
|
||||
/-- Returns `true` if we can ignore `ext` for functions occurring as arguments of a `declName`-application. -/
|
||||
private def extParentsToIgnore (declName : Name) : Bool :=
|
||||
declName == ``Eq || declName == ``HEq || declName == ``dite || declName == ``ite
|
||||
|| declName == ``Exists || declName == ``Subtype
|
||||
|
||||
/--
|
||||
Given a term `arg` that occurs as the argument at position `i` of an `f`-application `parent?`,
|
||||
we consider `arg` as a candidate for case-splitting. For every other argument `arg'` that also appears
|
||||
at position `i` in an `f`-application and has the same type as `e`, we add the case-split candidate `arg = arg'`.
|
||||
|
||||
When performing the case split, we consider the following two cases:
|
||||
- `arg = arg'`, which may introduce a new congruence between the corresponding `f`-applications.
|
||||
- `¬(arg = arg')`, which may trigger extensionality theorems for the type of `arg`.
|
||||
|
||||
This feature enables `grind` to solve examples such as:
|
||||
```lean
|
||||
example (f : (Nat → Nat) → Nat) : a = b → f (fun x => a + x) = f (fun x => b + x) := by
|
||||
grind
|
||||
```
|
||||
-/
|
||||
private def addSplitCandidatesForExt (arg : Expr) (generation : Nat) (parent? : Option Expr := none) : GoalM Unit := do
|
||||
let some parent := parent? | return ()
|
||||
unless parent.isApp do return ()
|
||||
let f := parent.getAppFn
|
||||
if let .const declName _ := f then
|
||||
if extParentsToIgnore declName then return ()
|
||||
let type ← inferType arg
|
||||
-- Remark: we currently do not perform function extensionality on functions that produce a type that is not a proposition.
|
||||
-- We may add an option to enable that in the future.
|
||||
let u? ← typeFormerTypeLevel type
|
||||
if u? != .none && u? != some .zero then return ()
|
||||
let mut i := parent.getAppNumArgs
|
||||
let mut it := parent
|
||||
repeat
|
||||
if !it.isApp then return ()
|
||||
i := i - 1
|
||||
if isSameExpr arg it.appArg! then
|
||||
found f i type parent
|
||||
it := it.appFn!
|
||||
where
|
||||
found (f : Expr) (i : Nat) (type : Expr) (parent : Expr) : GoalM Unit := do
|
||||
trace_goal[grind.debug.ext] "{f}, {i}, {arg}"
|
||||
let others := (← get).split.argsAt.find? (f, i) |>.getD []
|
||||
for other in others do
|
||||
if (← withDefault <| isDefEq type other.type) then
|
||||
let eq := mkApp3 (mkConst ``Eq [← getLevel type]) type arg other.arg
|
||||
let eq ← shareCommon eq
|
||||
internalize eq generation
|
||||
trace_goal[grind.ext.candidate] "{eq}"
|
||||
-- We do not use lookahead here because it is too incomplete.
|
||||
-- if (← getConfig).lookahead then
|
||||
-- addLookaheadCandidate (.arg other.app parent i eq)
|
||||
-- else
|
||||
addSplitCandidate (.arg other.app parent i eq)
|
||||
modify fun s => { s with split.argsAt := s.split.argsAt.insert (f, i) ({ arg, type, app := parent } :: others) }
|
||||
return ()
|
||||
|
||||
/-- Applies `addSplitCandidatesForExt` if `funext` is enabled. -/
|
||||
private def addSplitCandidatesForFunext (arg : Expr) (generation : Nat) (parent? : Option Expr := none) : GoalM Unit := do
|
||||
unless (← getConfig).funext do return ()
|
||||
addSplitCandidatesForExt arg generation parent?
|
||||
|
||||
@[export lean_grind_internalize]
|
||||
private partial def internalizeImpl (e : Expr) (generation : Nat) (parent? : Option Expr := none) : GoalM Unit := withIncRecDepth do
|
||||
if (← alreadyInternalized e) then
|
||||
@@ -229,7 +291,10 @@ private partial def internalizeImpl (e : Expr) (generation : Nat) (parent? : Opt
|
||||
| .fvar .. =>
|
||||
mkENode' e generation
|
||||
checkAndAddSplitCandidate e
|
||||
| .letE .. | .lam .. =>
|
||||
| .letE .. =>
|
||||
mkENode' e generation
|
||||
| .lam .. =>
|
||||
addSplitCandidatesForFunext e generation parent?
|
||||
mkENode' e generation
|
||||
| .forallE _ d b _ =>
|
||||
mkENode' e generation
|
||||
|
||||
@@ -178,6 +178,12 @@ private def isEagerCasesCandidate (goal : Goal) (type : Expr) : Bool := Id.run d
|
||||
let .const declName _ := type.getAppFn | return false
|
||||
return goal.split.casesTypes.isEagerSplit declName
|
||||
|
||||
/-- Returns `true` if `type` is an inductive type with at most one constructor. -/
|
||||
private def isCheapInductive (type : Expr) : CoreM Bool := do
|
||||
let .const declName _ := type.getAppFn | return false
|
||||
let .inductInfo info ← getConstInfo declName | return false
|
||||
return info.numCtors <= 1
|
||||
|
||||
private def applyCases? (goal : Goal) (fvarId : FVarId) : GrindM (Option (List Goal)) := goal.mvarId.withContext do
|
||||
/-
|
||||
Remark: we used to use `whnfD`. This was a mistake, we don't want to unfold user-defined abstractions.
|
||||
@@ -185,6 +191,9 @@ private def applyCases? (goal : Goal) (fvarId : FVarId) : GrindM (Option (List G
|
||||
-/
|
||||
let type ← whnf (← fvarId.getType)
|
||||
if isEagerCasesCandidate goal type then
|
||||
if (← cheapCasesOnly) then
|
||||
unless (← isCheapInductive type) do
|
||||
return none
|
||||
if let .const declName _ := type.getAppFn then
|
||||
saveCases declName true
|
||||
let mvarIds ← cases goal.mvarId (mkFVar fvarId)
|
||||
@@ -205,7 +214,7 @@ private def exfalsoIfNotProp (goal : Goal) : MetaM Goal := goal.mvarId.withConte
|
||||
return { goal with mvarId := (← goal.mvarId.exfalso) }
|
||||
|
||||
/-- Introduce new hypotheses (and apply `by_contra`) until goal is of the form `... ⊢ False` -/
|
||||
partial def intros (generation : Nat) : GrindTactic' := fun goal => do
|
||||
partial def intros (generation : Nat) : GrindTactic' := fun goal => do
|
||||
let rec go (goal : Goal) : StateRefT (Array Goal) GrindM Unit := do
|
||||
if goal.inconsistent then
|
||||
return ()
|
||||
|
||||
101
src/Lean/Meta/Tactic/Grind/Lookahead.lean
Normal file
101
src/Lean/Meta/Tactic/Grind/Lookahead.lean
Normal file
@@ -0,0 +1,101 @@
|
||||
/-
|
||||
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.Meta.Tactic.Grind.Types
|
||||
import Lean.Meta.Tactic.Grind.Intro
|
||||
import Lean.Meta.Tactic.Grind.Arith
|
||||
import Lean.Meta.Tactic.Grind.Split
|
||||
import Lean.Meta.Tactic.Grind.EMatch
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
private partial def solve (generation : Nat) (goal : Goal) : GrindM Bool := do
|
||||
cont (← intros generation goal)
|
||||
where
|
||||
cont (goals : List Goal) : GrindM Bool := do
|
||||
match goals with
|
||||
| [] => return true
|
||||
| [goal] => loop goal
|
||||
| _ => throwError "`grind` lookahead internal error, unexpected number of goals"
|
||||
|
||||
loop (goal : Goal) : GrindM Bool := withIncRecDepth do
|
||||
if goal.inconsistent then
|
||||
return true
|
||||
else if let some goals ← assertNext goal then
|
||||
cont goals
|
||||
else if let some goals ← Arith.check goal then
|
||||
cont goals
|
||||
else if let some goals ← splitNext goal then
|
||||
cont goals
|
||||
else if let some goals ← ematchAndAssert goal then
|
||||
cont goals
|
||||
else
|
||||
return false
|
||||
|
||||
private def tryLookahead (e : Expr) : GoalM Bool := do
|
||||
-- TODO: if `e` is an arithmetic expression, we can avoid creating an auxiliary goal.
|
||||
-- We can assert it directly to the arithmetic module.
|
||||
-- Remark: We can simplify this code because the lookahead only really worked for arithmetic.
|
||||
trace_goal[grind.lookahead.try] "{e}"
|
||||
let proof? ← withoutModifyingState do
|
||||
let goal ← get
|
||||
let tag ← goal.mvarId.getTag
|
||||
let target ← mkArrow (mkNot e) (← getFalseExpr)
|
||||
let mvar ← mkFreshExprMVar target .syntheticOpaque tag
|
||||
let gen ← getGeneration e
|
||||
if (← solve gen { goal with mvarId := mvar.mvarId! }) then
|
||||
return some (← instantiateMVars mvar)
|
||||
else
|
||||
return none
|
||||
if let some proof := proof? then
|
||||
trace[grind.lookahead.assert] "{e}"
|
||||
pushEqTrue e <| mkApp2 (mkConst ``Grind.of_lookahead) e proof
|
||||
processNewFacts
|
||||
return true
|
||||
else
|
||||
return false
|
||||
|
||||
private def withLookaheadConfig (x : GrindM α) : GrindM α := do
|
||||
withTheReader Grind.Context
|
||||
(fun ctx => { ctx with config.qlia := true, cheapCases := true })
|
||||
x
|
||||
|
||||
def lookahead : GrindTactic := fun goal => do
|
||||
unless (← getConfig).lookahead do
|
||||
return none
|
||||
if goal.split.lookaheads.isEmpty then
|
||||
return none
|
||||
withLookaheadConfig do
|
||||
let (progress, goal) ← GoalM.run goal do
|
||||
let mut postponed := []
|
||||
let mut progress := false
|
||||
let infos := (← get).split.lookaheads
|
||||
modify fun s => { s with split.lookaheads := [] }
|
||||
for info in infos do
|
||||
if (← isInconsistent) then
|
||||
return true
|
||||
match (← checkSplitStatus info) with
|
||||
| .resolved => progress := true
|
||||
| .ready _ _ true
|
||||
| .notReady => postponed := info :: postponed
|
||||
| .ready _ _ false =>
|
||||
if (← tryLookahead info.getExpr) then
|
||||
progress := true
|
||||
else
|
||||
postponed := info :: postponed
|
||||
if progress then
|
||||
modify fun s => { s with
|
||||
split.lookaheads := s.split.lookaheads ++ postponed.reverse
|
||||
}
|
||||
return true
|
||||
else
|
||||
return false
|
||||
if progress then
|
||||
return some [goal]
|
||||
else
|
||||
return none
|
||||
|
||||
end Lean.Meta.Grind
|
||||
@@ -33,17 +33,26 @@ structure MBTC.Context where
|
||||
-/
|
||||
eqAssignment : Expr → Expr → GoalM Bool
|
||||
|
||||
private abbrev Map := Std.HashMap (Expr × Nat) (List Expr)
|
||||
private abbrev Candidates := Std.HashSet (Expr × Expr)
|
||||
private def mkCandidateKey (a b : Expr) : Expr × Expr :=
|
||||
if a.lt b then
|
||||
(a, b)
|
||||
private structure ArgInfo where
|
||||
arg : Expr
|
||||
app : Expr
|
||||
|
||||
private abbrev Map := Std.HashMap (Expr × Nat) (List ArgInfo)
|
||||
private abbrev Candidates := Std.HashSet SplitInfo
|
||||
private def mkCandidate (a b : ArgInfo) (i : Nat) : GoalM SplitInfo := do
|
||||
let (lhs, rhs) := if a.arg.lt b.arg then
|
||||
(a.arg, b.arg)
|
||||
else
|
||||
(b, a)
|
||||
(b.arg, a.arg)
|
||||
let eq ← mkEq lhs rhs
|
||||
let eq ← shareCommon (← canon eq)
|
||||
return .arg a.app b.app i eq
|
||||
|
||||
/-- Model-based theory combination. -/
|
||||
def mbtc (ctx : MBTC.Context) : GoalM Bool := do
|
||||
unless (← getConfig).mbtc do return false
|
||||
-- It is pointless to run `mbtc` if maximum number of splits has been reached.
|
||||
if (← checkMaxCaseSplit) then return false
|
||||
let mut map : Map := {}
|
||||
let mut candidates : Candidates := {}
|
||||
for ({ expr := e }, _) in (← get).enodes do
|
||||
@@ -56,33 +65,33 @@ def mbtc (ctx : MBTC.Context) : GoalM Bool := do
|
||||
let some arg ← getRoot? arg | pure ()
|
||||
if (← ctx.hasTheoryVar arg) then
|
||||
trace[grind.debug.mbtc] "{arg} @ {f}:{i}"
|
||||
if let some others := map[(f, i)]? then
|
||||
unless others.any (isSameExpr arg ·) do
|
||||
for other in others do
|
||||
if (← ctx.eqAssignment arg other) then
|
||||
let k := mkCandidateKey arg other
|
||||
candidates := candidates.insert k
|
||||
map := map.insert (f, i) (arg :: others)
|
||||
let argInfo : ArgInfo := { arg, app := e }
|
||||
if let some otherInfos := map[(f, i)]? then
|
||||
unless otherInfos.any fun info => isSameExpr arg info.arg do
|
||||
for otherInfo in otherInfos do
|
||||
if (← ctx.eqAssignment arg otherInfo.arg) then
|
||||
if (← hasSameType arg otherInfo.arg) then
|
||||
candidates := candidates.insert (← mkCandidate argInfo otherInfo i)
|
||||
map := map.insert (f, i) (argInfo :: otherInfos)
|
||||
else
|
||||
map := map.insert (f, i) [arg]
|
||||
map := map.insert (f, i) [argInfo]
|
||||
i := i + 1
|
||||
if candidates.isEmpty then
|
||||
return false
|
||||
if (← get).split.num > (← getConfig).splits then
|
||||
reportIssue "skipping `mbtc`, maximum number of splits has been reached `(splits := {(← getConfig).splits})`"
|
||||
return false
|
||||
let result := candidates.toArray.qsort fun (a₁, b₁) (a₂, b₂) =>
|
||||
if isSameExpr a₁ a₂ then
|
||||
b₁.lt b₂
|
||||
else
|
||||
a₁.lt a₂
|
||||
let eqs ← result.mapM fun (a, b) => do
|
||||
let eq ← mkEq a b
|
||||
trace[grind.mbtc] "{eq}"
|
||||
let eq ← shareCommon (← canon eq)
|
||||
let result := candidates.toArray.qsort fun c₁ c₂ => c₁.lt c₂
|
||||
let result ← result.filterMapM fun info => do
|
||||
if (← isKnownCaseSplit info) then
|
||||
return none
|
||||
let .arg a b _ eq := info | return none
|
||||
internalize eq (Nat.max (← getGeneration a) (← getGeneration b))
|
||||
return eq
|
||||
modify fun s => { s with split.candidates := s.split.candidates ++ eqs.toList }
|
||||
return some info
|
||||
if result.isEmpty then
|
||||
return false
|
||||
for info in result do
|
||||
addSplitCandidate info
|
||||
return true
|
||||
|
||||
def mbtcTac (ctx : MBTC.Context) : GrindTactic := fun goal => do
|
||||
|
||||
@@ -9,18 +9,6 @@ import Lean.Meta.Tactic.Grind.Simp
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
/--
|
||||
Helper function for executing `x` with a fresh `newFacts` and without modifying
|
||||
the goal state.
|
||||
-/
|
||||
private def withoutModifyingState (x : GoalM α) : GoalM α := do
|
||||
let saved ← get
|
||||
modify fun goal => { goal with newFacts := {} }
|
||||
try
|
||||
x
|
||||
finally
|
||||
set saved
|
||||
|
||||
/--
|
||||
If `e` has not been internalized yet, instantiate metavariables, unfold reducible, canonicalize,
|
||||
and internalize the result.
|
||||
|
||||
@@ -8,6 +8,7 @@ import Lean.Meta.Tactic.Grind.Combinators
|
||||
import Lean.Meta.Tactic.Grind.Split
|
||||
import Lean.Meta.Tactic.Grind.EMatch
|
||||
import Lean.Meta.Tactic.Grind.Arith
|
||||
import Lean.Meta.Tactic.Grind.Lookahead
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
@@ -62,6 +63,8 @@ def trySplit : Goal → M Bool := applyTac splitNext
|
||||
|
||||
def tryArith : Goal → M Bool := applyTac Arith.check
|
||||
|
||||
def tryLookahead : Goal → M Bool := applyTac lookahead
|
||||
|
||||
def tryMBTC : Goal → M Bool := applyTac Arith.Cutsat.mbtcTac
|
||||
|
||||
def maxNumFailuresReached : M Bool := do
|
||||
@@ -81,6 +84,8 @@ partial def main (fallback : Fallback) : M Unit := do
|
||||
continue
|
||||
if (← tryEmatch goal) then
|
||||
continue
|
||||
if (← tryLookahead goal) then
|
||||
continue
|
||||
if (← trySplit goal) then
|
||||
continue
|
||||
if (← tryMBTC goal) then
|
||||
|
||||
@@ -11,14 +11,14 @@ import Lean.Meta.Tactic.Grind.CasesMatch
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
inductive CaseSplitStatus where
|
||||
inductive SplitStatus where
|
||||
| resolved
|
||||
| notReady
|
||||
| ready (numCases : Nat) (isRec := false)
|
||||
| ready (numCases : Nat) (isRec := false) (tryPostpone := false)
|
||||
deriving Inhabited, BEq, Repr
|
||||
|
||||
/-- Given `c`, the condition of an `if-then-else`, check whether we need to case-split on the `if-then-else` or not -/
|
||||
private def checkIteCondStatus (c : Expr) : GoalM CaseSplitStatus := do
|
||||
private def checkIteCondStatus (c : Expr) : GoalM SplitStatus := do
|
||||
if (← isEqTrue c <||> isEqFalse c) then
|
||||
return .resolved
|
||||
else
|
||||
@@ -28,7 +28,7 @@ private def checkIteCondStatus (c : Expr) : GoalM CaseSplitStatus := do
|
||||
Given `e` of the form `a ∨ b`, check whether we are ready to case-split on `e`.
|
||||
That is, `e` is `True`, but neither `a` nor `b` is `True`."
|
||||
-/
|
||||
private def checkDisjunctStatus (e a b : Expr) : GoalM CaseSplitStatus := do
|
||||
private def checkDisjunctStatus (e a b : Expr) : GoalM SplitStatus := do
|
||||
if (← isEqTrue e) then
|
||||
if (← isEqTrue a <||> isEqTrue b) then
|
||||
return .resolved
|
||||
@@ -43,7 +43,7 @@ private def checkDisjunctStatus (e a b : Expr) : GoalM CaseSplitStatus := do
|
||||
Given `e` of the form `a ∧ b`, check whether we are ready to case-split on `e`.
|
||||
That is, `e` is `False`, but neither `a` nor `b` is `False`.
|
||||
-/
|
||||
private def checkConjunctStatus (e a b : Expr) : GoalM CaseSplitStatus := do
|
||||
private def checkConjunctStatus (e a b : Expr) : GoalM SplitStatus := do
|
||||
if (← isEqTrue e) then
|
||||
return .resolved
|
||||
else if (← isEqFalse e) then
|
||||
@@ -60,7 +60,7 @@ There are two cases:
|
||||
1- `e` is `True`, but neither both `a` and `b` are `True`, nor both `a` and `b` are `False`.
|
||||
2- `e` is `False`, but neither `a` is `True` and `b` is `False`, nor `a` is `False` and `b` is `True`.
|
||||
-/
|
||||
private def checkIffStatus (e a b : Expr) : GoalM CaseSplitStatus := do
|
||||
private def checkIffStatus (e a b : Expr) : GoalM SplitStatus := do
|
||||
if (← isEqTrue e) then
|
||||
if (← (isEqTrue a <&&> isEqTrue b) <||> (isEqFalse a <&&> isEqFalse b)) then
|
||||
return .resolved
|
||||
@@ -76,13 +76,14 @@ private def checkIffStatus (e a b : Expr) : GoalM CaseSplitStatus := do
|
||||
|
||||
/-- Returns `true` is `c` is congruent to a case-split that was already performed. -/
|
||||
private def isCongrToPrevSplit (c : Expr) : GoalM Bool := do
|
||||
unless c.isApp do return false
|
||||
(← get).split.resolved.foldM (init := false) fun flag { expr := c' } => do
|
||||
if flag then
|
||||
return true
|
||||
else
|
||||
return isCongruent (← get).enodes c c'
|
||||
return c'.isApp && isCongruent (← get).enodes c c'
|
||||
|
||||
private def checkForallStatus (e : Expr) : GoalM CaseSplitStatus := do
|
||||
private def checkForallStatus (e : Expr) : GoalM SplitStatus := do
|
||||
if (← isEqTrue e) then
|
||||
let .forallE _ p q _ := e | return .resolved
|
||||
if (← isEqTrue p <||> isEqFalse p) then
|
||||
@@ -96,7 +97,7 @@ private def checkForallStatus (e : Expr) : GoalM CaseSplitStatus := do
|
||||
else
|
||||
return .notReady
|
||||
|
||||
private def checkCaseSplitStatus (e : Expr) : GoalM CaseSplitStatus := do
|
||||
private def checkDefaultSplitStatus (e : Expr) : GoalM SplitStatus := do
|
||||
match_expr e with
|
||||
| Or a b => checkDisjunctStatus e a b
|
||||
| And a b => checkConjunctStatus e a b
|
||||
@@ -132,9 +133,37 @@ private def checkCaseSplitStatus (e : Expr) : GoalM CaseSplitStatus := do
|
||||
return .ready info.ctors.length info.isRec
|
||||
return .notReady
|
||||
|
||||
def checkSplitInfoArgStatus (a b : Expr) (eq : Expr) : GoalM SplitStatus := do
|
||||
if (← isEqTrue eq <||> isEqFalse eq) then return .resolved
|
||||
let is := (← get).split.argPosMap[(a, b)]? |>.getD []
|
||||
let mut j := a.getAppNumArgs
|
||||
let mut it_a := a
|
||||
let mut it_b := b
|
||||
repeat
|
||||
unless it_a.isApp && it_b.isApp do return .ready 2
|
||||
j := j - 1
|
||||
if j ∉ is then
|
||||
let arg_a := it_a.appArg!
|
||||
let arg_b := it_b.appArg!
|
||||
unless (← isEqv arg_a arg_b) do
|
||||
trace_goal[grind.split] "may be irrelevant\na: {a}\nb: {b}\neq: {eq}\narg_a: {arg_a}\narg_b: {arg_b}, gen: {← getGeneration eq}"
|
||||
/-
|
||||
We tried to return `.notReady` because we would not be able to derive a congruence, but
|
||||
`grind_ite.lean` breaks when this heuristic is used. TODO: understand better why.
|
||||
-/
|
||||
return .ready 2 (tryPostpone := true)
|
||||
it_a := it_a.appFn!
|
||||
it_b := it_b.appFn!
|
||||
return .ready 2
|
||||
|
||||
def checkSplitStatus (s : SplitInfo) : GoalM SplitStatus := do
|
||||
match s with
|
||||
| .default e => checkDefaultSplitStatus e
|
||||
| .arg a b _ eq => checkSplitInfoArgStatus a b eq
|
||||
|
||||
private inductive SplitCandidate where
|
||||
| none
|
||||
| some (c : Expr) (numCases : Nat) (isRec : Bool)
|
||||
| some (c : SplitInfo) (numCases : Nat) (isRec : Bool) (tryPostpone : Bool)
|
||||
|
||||
/-- Returns the next case-split to be performed. It uses a very simple heuristic. -/
|
||||
private def selectNextSplit? : GoalM SplitCandidate := do
|
||||
@@ -142,11 +171,11 @@ private def selectNextSplit? : GoalM SplitCandidate := do
|
||||
if (← checkMaxCaseSplit) then return .none
|
||||
go (← get).split.candidates .none []
|
||||
where
|
||||
go (cs : List Expr) (c? : SplitCandidate) (cs' : List Expr) : GoalM SplitCandidate := do
|
||||
go (cs : List SplitInfo) (c? : SplitCandidate) (cs' : List SplitInfo) : GoalM SplitCandidate := do
|
||||
match cs with
|
||||
| [] =>
|
||||
modify fun s => { s with split.candidates := cs'.reverse }
|
||||
if let .some _ numCases isRec := c? then
|
||||
if let .some _ numCases isRec _ := c? then
|
||||
let numSplits := (← get).split.num
|
||||
-- We only increase the number of splits if there is more than one case or it is recursive.
|
||||
let numSplits := if numCases > 1 || isRec then numSplits + 1 else numSplits
|
||||
@@ -155,22 +184,28 @@ where
|
||||
modify fun s => { s with split.num := numSplits, ematch.num := 0 }
|
||||
return c?
|
||||
| c::cs =>
|
||||
trace_goal[grind.debug.split] "checking: {c}"
|
||||
match (← checkCaseSplitStatus c) with
|
||||
trace_goal[grind.debug.split] "checking: {c.getExpr}"
|
||||
match (← checkSplitStatus c) with
|
||||
| .notReady => go cs c? (c::cs')
|
||||
| .resolved => go cs c? cs'
|
||||
| .ready numCases isRec =>
|
||||
match c? with
|
||||
| .none => go cs (.some c numCases isRec) cs'
|
||||
| .some c' numCases' _ =>
|
||||
| .ready numCases isRec tryPostpone =>
|
||||
if (← cheapCasesOnly) && numCases > 1 then
|
||||
go cs c? (c::cs')
|
||||
else match c? with
|
||||
| .none => go cs (.some c numCases isRec tryPostpone) cs'
|
||||
| .some c' numCases' _ tryPostpone' =>
|
||||
let isBetter : GoalM Bool := do
|
||||
if numCases == 1 && !isRec && numCases' > 1 then
|
||||
if tryPostpone' && !tryPostpone then
|
||||
return true
|
||||
if (← getGeneration c) < (← getGeneration c') then
|
||||
else if tryPostpone && !tryPostpone' then
|
||||
return false
|
||||
else if numCases == 1 && !isRec && numCases' > 1 then
|
||||
return true
|
||||
if (← getGeneration c.getExpr) < (← getGeneration c'.getExpr) then
|
||||
return true
|
||||
return numCases < numCases'
|
||||
if (← isBetter) then
|
||||
go cs (.some c numCases isRec) (c'::cs')
|
||||
go cs (.some c numCases isRec tryPostpone) (c'::cs')
|
||||
else
|
||||
go cs c? (c::cs')
|
||||
|
||||
@@ -192,6 +227,7 @@ private def mkCasesMajor (c : Expr) : GoalM Expr := do
|
||||
else
|
||||
-- model-based theory combination split
|
||||
return mkGrindEM c
|
||||
| Not e => return mkGrindEM e
|
||||
| _ =>
|
||||
if let .forallE _ p _ _ := c then
|
||||
return mkGrindEM p
|
||||
@@ -212,8 +248,9 @@ and returns a new list of goals if successful.
|
||||
-/
|
||||
def splitNext : GrindTactic := fun goal => do
|
||||
let (goals?, _) ← GoalM.run goal do
|
||||
let .some c numCases isRec ← selectNextSplit?
|
||||
let .some c numCases isRec _ ← selectNextSplit?
|
||||
| return none
|
||||
let c := c.getExpr
|
||||
let gen ← getGeneration c
|
||||
let genNew := if numCases > 1 || isRec then gen+1 else gen
|
||||
markCaseSplitAsResolved c
|
||||
|
||||
@@ -17,6 +17,7 @@ import Lean.Meta.Tactic.Util
|
||||
import Lean.Meta.Tactic.Ext
|
||||
import Lean.Meta.Tactic.Grind.ENodeKey
|
||||
import Lean.Meta.Tactic.Grind.Attr
|
||||
import Lean.Meta.Tactic.Grind.ExtAttr
|
||||
import Lean.Meta.Tactic.Grind.Cases
|
||||
import Lean.Meta.Tactic.Grind.Arith.Types
|
||||
import Lean.Meta.Tactic.Grind.EMatchTheorem
|
||||
@@ -58,6 +59,15 @@ structure Context where
|
||||
simprocs : Array Simp.Simprocs
|
||||
mainDeclName : Name
|
||||
config : Grind.Config
|
||||
/--
|
||||
If `cheapCases` is `true`, `grind` only applies `cases` to types that contain
|
||||
at most one minor premise.
|
||||
Recall that `grind` applies `cases` when introducing types tagged with `[grind cases eager]`,
|
||||
and at `Split.lean`
|
||||
Remark: We add this option to implement the `lookahead` feature, we don't want to create several subgoals
|
||||
when performing lookahead.
|
||||
-/
|
||||
cheapCases : Bool := false
|
||||
|
||||
/-- Key for the congruence theorem cache. -/
|
||||
structure CongrTheoremCacheKey where
|
||||
@@ -164,6 +174,9 @@ def getNatZeroExpr : GrindM Expr := do
|
||||
def getMainDeclName : GrindM Name :=
|
||||
return (← readThe Context).mainDeclName
|
||||
|
||||
def cheapCasesOnly : GrindM Bool :=
|
||||
return (← readThe Context).cheapCases
|
||||
|
||||
def saveEMatchTheorem (thm : EMatchTheorem) : GrindM Unit := do
|
||||
if (← getConfig).trace then
|
||||
modify fun s => { s with trace.thms := s.trace.thms.insert { origin := thm.origin, kind := thm.kind } }
|
||||
@@ -472,22 +485,72 @@ structure EMatch.State where
|
||||
matchEqNames : PHashSet Name := {}
|
||||
deriving Inhabited
|
||||
|
||||
/-- Case-split information. -/
|
||||
inductive SplitInfo where
|
||||
| /--
|
||||
Term `e` may be an inductive predicate, `match`-expression, `if`-expression, implication, etc.
|
||||
-/
|
||||
default (e : Expr)
|
||||
| /--
|
||||
Given applications `a` and `b`, case-split on whether the corresponding
|
||||
`i`-th arguments are equal or not. The split is only performed if all other
|
||||
arguments are already known to be equal or are also tagged as split candidates.
|
||||
-/
|
||||
arg (a b : Expr) (i : Nat) (eq : Expr)
|
||||
deriving BEq, Hashable, Inhabited
|
||||
|
||||
def SplitInfo.getExpr : SplitInfo → Expr
|
||||
| .default (.forallE _ d _ _) => d
|
||||
| .default e => e
|
||||
| .arg _ _ _ eq => eq
|
||||
|
||||
def SplitInfo.lt : SplitInfo → SplitInfo → Bool
|
||||
| .default e₁, .default e₂ => e₁.lt e₂
|
||||
| .arg _ _ _ e₁, .arg _ _ _ e₂ => e₁.lt e₂
|
||||
| .default _, .arg .. => true
|
||||
| .arg .., .default _ => false
|
||||
|
||||
/-- Argument `arg : type` of an application `app` in `SplitInfo`. -/
|
||||
structure SplitArg where
|
||||
arg : Expr
|
||||
type : Expr
|
||||
app : Expr
|
||||
|
||||
/-- Case splitting related fields for the `grind` goal. -/
|
||||
structure Split.State where
|
||||
/-- Inductive datatypes marked for case-splitting -/
|
||||
casesTypes : CasesTypes := {}
|
||||
/-- Case-split candidates. -/
|
||||
candidates : List Expr := []
|
||||
/-- Number of splits performed to get to this goal. -/
|
||||
num : Nat := 0
|
||||
num : Nat := 0
|
||||
/-- Inductive datatypes marked for case-splitting -/
|
||||
casesTypes : CasesTypes := {}
|
||||
/-- Case-split candidates. -/
|
||||
candidates : List SplitInfo := []
|
||||
/-- Case-splits that have been inserted at `candidates` at some point. -/
|
||||
added : Std.HashSet SplitInfo := {}
|
||||
/-- Case-splits that have already been performed, or that do not have to be performed anymore. -/
|
||||
resolved : PHashSet ENodeKey := {}
|
||||
resolved : PHashSet ENodeKey := {}
|
||||
/--
|
||||
Sequence of cases steps that generated this goal. We only use this information for diagnostics.
|
||||
Remark: `casesTrace.length ≥ numSplits` because we don't increase the counter for `cases`
|
||||
applications that generated only 1 subgoal.
|
||||
-/
|
||||
trace : List CaseTrace := []
|
||||
trace : List CaseTrace := []
|
||||
/-- Lookahead "case-splits". -/
|
||||
lookaheads : List SplitInfo := []
|
||||
/--
|
||||
A mapping `(a, b) ↦ is` s.t. for each `SplitInfo.arg a b i eq`
|
||||
in `candidates` or `lookaheads` we have `i ∈ is`.
|
||||
We use this information to decide whether the split/lookahead is "ready"
|
||||
to be tried or not.
|
||||
-/
|
||||
argPosMap : Std.HashMap (Expr × Expr) (List Nat) := {}
|
||||
/--
|
||||
Mapping from pairs `(f, i)` to a list of arguments.
|
||||
Each argument occurs as the `i`-th of an `f`-application.
|
||||
We use this information to add splits/lookaheads for
|
||||
triggering extensionality theorems and model-based theory combination.
|
||||
See `addSplitCandidatesForExt`.
|
||||
-/
|
||||
argsAt : PHashMap (Expr × Nat) (List SplitArg) := {}
|
||||
deriving Inhabited
|
||||
|
||||
/-- Clean name generator. -/
|
||||
@@ -510,7 +573,7 @@ structure Goal where
|
||||
-/
|
||||
appMap : PHashMap HeadIndex (List Expr) := {}
|
||||
/-- Equations and propositions to be processed. -/
|
||||
newFacts : Array NewFact := #[]
|
||||
newFacts : Array NewFact := #[]
|
||||
/-- `inconsistent := true` if `ENode`s for `True` and `False` are in the same equivalence class. -/
|
||||
inconsistent : Bool := false
|
||||
/-- Next unique index for creating ENodes -/
|
||||
@@ -518,17 +581,17 @@ structure Goal where
|
||||
/-- new facts to be preprocessed and then asserted. -/
|
||||
newRawFacts : Std.Queue NewRawFact := ∅
|
||||
/-- Asserted facts -/
|
||||
facts : PArray Expr := {}
|
||||
facts : PArray Expr := {}
|
||||
/-- Cached extensionality theorems for types. -/
|
||||
extThms : PHashMap ENodeKey (Array Ext.ExtTheorem) := {}
|
||||
extThms : PHashMap ENodeKey (Array Ext.ExtTheorem) := {}
|
||||
/-- State of the E-matching module. -/
|
||||
ematch : EMatch.State
|
||||
ematch : EMatch.State
|
||||
/-- State of the case-splitting module. -/
|
||||
split : Split.State := {}
|
||||
split : Split.State := {}
|
||||
/-- State of arithmetic procedures. -/
|
||||
arith : Arith.State := {}
|
||||
arith : Arith.State := {}
|
||||
/-- State of the clean name generator. -/
|
||||
clean : Clean.State := {}
|
||||
clean : Clean.State := {}
|
||||
deriving Inhabited
|
||||
|
||||
def Goal.admit (goal : Goal) : MetaM Unit :=
|
||||
@@ -1153,6 +1216,14 @@ partial def Goal.getEqcs (goal : Goal) : List (List Expr) := Id.run do
|
||||
def getEqcs : GoalM (List (List Expr)) :=
|
||||
return (← get).getEqcs
|
||||
|
||||
/--
|
||||
Returns `true` if `s` has been already added to the case-split list at one point.
|
||||
Remark: this function returns `true` even if the split has already been resolved
|
||||
and is not in the list anymore.
|
||||
-/
|
||||
def isKnownCaseSplit (s : SplitInfo) : GoalM Bool :=
|
||||
return (← get).split.added.contains s
|
||||
|
||||
/-- Returns `true` if `e` is a case-split that does not need to be performed anymore. -/
|
||||
def isResolvedCaseSplit (e : Expr) : GoalM Bool :=
|
||||
return (← get).split.resolved.contains { expr := e }
|
||||
@@ -1167,16 +1238,38 @@ def markCaseSplitAsResolved (e : Expr) : GoalM Unit := do
|
||||
trace_goal[grind.split.resolved] "{e}"
|
||||
modify fun s => { s with split.resolved := s.split.resolved.insert { expr := e } }
|
||||
|
||||
private def updateSplitArgPosMap (sinfo : SplitInfo) : GoalM Unit := do
|
||||
let .arg a b i _ := sinfo | return ()
|
||||
let key := (a, b)
|
||||
let is := (← get).split.argPosMap[key]? |>.getD []
|
||||
modify fun s => { s with
|
||||
split.argPosMap := s.split.argPosMap.insert key (i :: is)
|
||||
}
|
||||
|
||||
/-- Inserts `e` into the list of case-split candidates if it was not inserted before. -/
|
||||
def addSplitCandidate (sinfo : SplitInfo) : GoalM Unit := do
|
||||
unless (← isKnownCaseSplit sinfo) do
|
||||
trace_goal[grind.split.candidate] "{sinfo.getExpr}"
|
||||
modify fun s => { s with
|
||||
split.added := s.split.added.insert sinfo
|
||||
split.candidates := sinfo :: s.split.candidates
|
||||
}
|
||||
updateSplitArgPosMap sinfo
|
||||
|
||||
/--
|
||||
Returns extensionality theorems for the given type if available.
|
||||
If `Config.ext` is `false`, the result is `#[]`.
|
||||
-/
|
||||
def getExtTheorems (type : Expr) : GoalM (Array Ext.ExtTheorem) := do
|
||||
unless (← getConfig).ext do return #[]
|
||||
unless (← getConfig).ext || (← getConfig).extAll do return #[]
|
||||
if let some thms := (← get).extThms.find? { expr := type } then
|
||||
return thms
|
||||
else
|
||||
let thms ← Ext.getExtTheorems type
|
||||
let thms ← if (← getConfig).extAll then
|
||||
pure thms
|
||||
else
|
||||
thms.filterM fun thm => isExtTheorem thm.declName
|
||||
modify fun s => { s with extThms := s.extThms.insert { expr := type } thms }
|
||||
return thms
|
||||
|
||||
@@ -1188,4 +1281,22 @@ def synthesizeInstanceAndAssign (x type : Expr) : MetaM Bool := do
|
||||
let .some val ← trySynthInstance type | return false
|
||||
isDefEq x val
|
||||
|
||||
/-- Add a new lookahead candidate. -/
|
||||
def addLookaheadCandidate (sinfo : SplitInfo) : GoalM Unit := do
|
||||
trace_goal[grind.lookahead.add] "{sinfo.getExpr}"
|
||||
modify fun s => { s with split.lookaheads := sinfo :: s.split.lookaheads }
|
||||
updateSplitArgPosMap sinfo
|
||||
|
||||
/--
|
||||
Helper function for executing `x` with a fresh `newFacts` and without modifying
|
||||
the goal state.
|
||||
-/
|
||||
def withoutModifyingState (x : GoalM α) : GoalM α := do
|
||||
let saved ← get
|
||||
modify fun goal => { goal with newFacts := {} }
|
||||
try
|
||||
x
|
||||
finally
|
||||
set saved
|
||||
|
||||
end Lean.Meta.Grind
|
||||
|
||||
@@ -97,7 +97,8 @@ def _root_.Lean.MVarId.clearAuxDecls (mvarId : MVarId) : MetaM MVarId := mvarId.
|
||||
try
|
||||
mvarId ← mvarId.clear fvarId
|
||||
catch _ =>
|
||||
throwTacticEx `grind.clear_aux_decls mvarId "failed to clear local auxiliary declaration"
|
||||
let userName := (← fvarId.getDecl).userName
|
||||
throwTacticEx `grind mvarId m!"the goal mentions the declaration `{userName}`, which is being defined. To avoid circular reasoning, try rewriting the goal to eliminate `{userName}` before using `grind`."
|
||||
return mvarId
|
||||
|
||||
/--
|
||||
|
||||
@@ -29,16 +29,6 @@ def withTypeAscription (d : Delab) (cond : Bool := true) : Delab := do
|
||||
else
|
||||
return stx
|
||||
|
||||
/--
|
||||
If `pp.tagAppFns` is set, then `d` is evaluated with the delaborated head constant as the ref.
|
||||
-/
|
||||
def withFnRefWhenTagAppFns (d : Delab) : Delab := do
|
||||
if (← getExpr).getAppFn.isConst && (← getPPOption getPPTagAppFns) then
|
||||
let head ← withNaryFn delab
|
||||
withRef head <| d
|
||||
else
|
||||
d
|
||||
|
||||
/--
|
||||
Wraps the identifier (or identifier with explicit universe levels) with `@` if `pp.analysis.blockImplicit` is set to true.
|
||||
-/
|
||||
@@ -134,6 +124,18 @@ def delabConst : Delab := do
|
||||
else
|
||||
return stx
|
||||
|
||||
/--
|
||||
If `pp.tagAppFns` is set, and if the current expression is a constant application,
|
||||
then `d` is evaluated with the head constant delaborated with `delabConst` as the ref.
|
||||
-/
|
||||
def withFnRefWhenTagAppFns (d : Delab) : Delab := do
|
||||
if (← getExpr).getAppFn.isConst && (← getPPOption getPPTagAppFns) then
|
||||
-- delabConst in `pp.tagAppFns` mode annotates the term.
|
||||
let head ← withNaryFn delabConst
|
||||
withRef head <| d
|
||||
else
|
||||
d
|
||||
|
||||
def withMDataOptions [Inhabited α] (x : DelabM α) : DelabM α := do
|
||||
match ← getExpr with
|
||||
| Expr.mdata m .. =>
|
||||
|
||||
@@ -207,7 +207,7 @@ This option can only be set on the command line, not in the lakefile or via `set
|
||||
stickyInteractiveDiagnostics ++ docInteractiveDiagnostics
|
||||
|>.map (·.toDiagnostic)
|
||||
let notification := mkPublishDiagnosticsNotification doc.meta diagnostics
|
||||
ctx.chanOut.send notification
|
||||
ctx.chanOut.sync.send notification
|
||||
|
||||
open Language in
|
||||
/--
|
||||
@@ -239,7 +239,7 @@ This option can only be set on the command line, not in the lakefile or via `set
|
||||
publishDiagnostics ctx doc
|
||||
-- This will overwrite existing ilean info for the file, in case something
|
||||
-- went wrong during the incremental updates.
|
||||
ctx.chanOut.send (← mkIleanInfoFinalNotification doc.meta st.allInfoTrees)
|
||||
ctx.chanOut.sync.send (← mkIleanInfoFinalNotification doc.meta st.allInfoTrees)
|
||||
return ()
|
||||
where
|
||||
/--
|
||||
@@ -312,7 +312,7 @@ This option can only be set on the command line, not in the lakefile or via `set
|
||||
if let some itree := node.element.infoTree? then
|
||||
let mut newInfoTrees := (← get).newInfoTrees.push itree
|
||||
if (← get).hasBlocked then
|
||||
ctx.chanOut.send (← mkIleanInfoUpdateNotification doc.meta newInfoTrees)
|
||||
ctx.chanOut.sync.send (← mkIleanInfoUpdateNotification doc.meta newInfoTrees)
|
||||
newInfoTrees := #[]
|
||||
modify fun st => { st with newInfoTrees, allInfoTrees := st.allInfoTrees.push itree }
|
||||
|
||||
@@ -329,7 +329,7 @@ This option can only be set on the command line, not in the lakefile or via `set
|
||||
| none => rs.push r
|
||||
let ranges := ranges.map (·.toLspRange doc.meta.text)
|
||||
let notifs := ranges.map ({ range := ·, kind := .processing })
|
||||
ctx.chanOut.send <| mkFileProgressNotification doc.meta notifs
|
||||
ctx.chanOut.sync.send <| mkFileProgressNotification doc.meta notifs
|
||||
|
||||
end Elab
|
||||
|
||||
@@ -389,9 +389,9 @@ def setupImports
|
||||
severity? := DiagnosticSeverity.information
|
||||
message := stderrLine
|
||||
}
|
||||
chanOut.send <| mkPublishDiagnosticsNotification meta #[progressDiagnostic]
|
||||
chanOut.sync.send <| mkPublishDiagnosticsNotification meta #[progressDiagnostic]
|
||||
-- clear progress notifications in the end
|
||||
chanOut.send <| mkPublishDiagnosticsNotification meta #[]
|
||||
chanOut.sync.send <| mkPublishDiagnosticsNotification meta #[]
|
||||
match fileSetupResult.kind with
|
||||
| .importsOutOfDate =>
|
||||
return .error {
|
||||
@@ -413,6 +413,8 @@ def setupImports
|
||||
-- default to async elaboration; see also `Elab.async` docs
|
||||
let opts := Elab.async.setIfNotSet opts true
|
||||
|
||||
let opts := Elab.inServer.set opts true
|
||||
|
||||
return .ok {
|
||||
mainModuleName := meta.mod
|
||||
opts
|
||||
@@ -523,7 +525,7 @@ section ServerRequests
|
||||
(freshRequestId, freshRequestId + 1)
|
||||
let responseTask ← ctx.initPendingServerRequest responseType freshRequestId
|
||||
let r : JsonRpc.Request paramType := ⟨freshRequestId, method, param⟩
|
||||
ctx.chanOut.send r
|
||||
ctx.chanOut.sync.send r
|
||||
return responseTask
|
||||
|
||||
def sendUntypedServerRequest
|
||||
@@ -677,7 +679,7 @@ section MessageHandling
|
||||
let availableImports ← ImportCompletion.collectAvailableImports
|
||||
let lastRequestTimestampMs ← IO.monoMsNow
|
||||
let completions := ImportCompletion.find text st.doc.initSnap.stx params availableImports
|
||||
ctx.chanOut.send <| .response id (toJson completions)
|
||||
ctx.chanOut.sync.send <| .response id (toJson completions)
|
||||
pure { availableImports, lastRequestTimestampMs : AvailableImportsCache }
|
||||
|
||||
| some task => ServerTask.IO.mapTaskCostly (t := task) fun (result : Except Error AvailableImportsCache) => do
|
||||
@@ -687,7 +689,7 @@ section MessageHandling
|
||||
availableImports ← ImportCompletion.collectAvailableImports
|
||||
lastRequestTimestampMs := timestampNowMs
|
||||
let completions := ImportCompletion.find text st.doc.initSnap.stx params availableImports
|
||||
ctx.chanOut.send <| .response id (toJson completions)
|
||||
ctx.chanOut.sync.send <| .response id (toJson completions)
|
||||
pure { availableImports, lastRequestTimestampMs : AvailableImportsCache }
|
||||
|
||||
def handleStatefulPreRequestSpecialCases (id : RequestID) (method : String) (params : Json) : WorkerM Bool := do
|
||||
@@ -699,7 +701,7 @@ section MessageHandling
|
||||
| "$/lean/rpc/connect" =>
|
||||
let ps ← parseParams RpcConnectParams params
|
||||
let resp ← handleRpcConnect ps
|
||||
ctx.chanOut.send <| .response id (toJson resp)
|
||||
ctx.chanOut.sync.send <| .response id (toJson resp)
|
||||
return true
|
||||
| "textDocument/completion" =>
|
||||
let params ← parseParams CompletionParams params
|
||||
@@ -712,7 +714,7 @@ section MessageHandling
|
||||
| _ =>
|
||||
return false
|
||||
catch e =>
|
||||
ctx.chanOut.send <| .responseError id .internalError (toString e) none
|
||||
ctx.chanOut.sync.send <| .responseError id .internalError (toString e) none
|
||||
return true
|
||||
|
||||
open Widget RequestM Language in
|
||||
@@ -834,7 +836,7 @@ section MessageHandling
|
||||
emitResponse ctx (isComplete := false) <| e.toLspResponseError id
|
||||
where
|
||||
emitResponse (ctx : WorkerContext) (m : JsonRpc.Message) (isComplete : Bool) : IO Unit := do
|
||||
ctx.chanOut.send m
|
||||
ctx.chanOut.sync.send m
|
||||
let timestamp ← IO.monoMsNow
|
||||
ctx.modifyPartialHandler method fun h => { h with
|
||||
requestsInFlight := h.requestsInFlight - 1
|
||||
|
||||
@@ -41,6 +41,7 @@ structure InfoPopup where
|
||||
doc : Option String
|
||||
deriving Inhabited, RpcEncodable
|
||||
|
||||
open PrettyPrinter.Delaborator in
|
||||
/-- Given elaborator info for a particular subexpression. Produce the `InfoPopup`.
|
||||
|
||||
The intended usage of this is for the infoview to pass the `InfoWithCtx` which
|
||||
@@ -53,12 +54,10 @@ def makePopup : WithRpcRef InfoWithCtx → RequestM (RequestTask InfoPopup)
|
||||
| some type => some <$> ppExprTagged type
|
||||
| none => pure none
|
||||
let exprExplicit? ← match i.info with
|
||||
| Elab.Info.ofTermInfo ti
|
||||
| Elab.Info.ofDelabTermInfo { toTermInfo := ti, explicit := true, ..} =>
|
||||
some <$> ppExprTaggedWithoutTopLevelHighlight ti.expr (explicit := true)
|
||||
| Elab.Info.ofDelabTermInfo { toTermInfo := ti, explicit := false, ..} =>
|
||||
-- Keep the top-level tag so that users can also see the explicit version of the term on an additional hover.
|
||||
some <$> ppExprTagged ti.expr (explicit := false)
|
||||
| Elab.Info.ofTermInfo ti =>
|
||||
some <$> ppExprForPopup ti.expr (explicit := true)
|
||||
| Elab.Info.ofDelabTermInfo { toTermInfo := ti, explicit, ..} =>
|
||||
some <$> ppExprForPopup ti.expr (explicit := explicit)
|
||||
| Elab.Info.ofFieldInfo fi => pure <| some <| TaggedText.text fi.fieldName.toString
|
||||
| _ => pure none
|
||||
return {
|
||||
@@ -67,11 +66,26 @@ def makePopup : WithRpcRef InfoWithCtx → RequestM (RequestTask InfoPopup)
|
||||
doc := ← i.info.docString? : InfoPopup
|
||||
}
|
||||
where
|
||||
ppExprTaggedWithoutTopLevelHighlight (e : Expr) (explicit : Bool) : MetaM CodeWithInfos := do
|
||||
let pp ← ppExprTagged e (explicit := explicit)
|
||||
return match pp with
|
||||
| .tag _ tt => tt
|
||||
| tt => tt
|
||||
maybeWithoutTopLevelHighlight : Bool → CodeWithInfos → CodeWithInfos
|
||||
| true, .tag _ tt => tt
|
||||
| _, tt => tt
|
||||
ppExprForPopup (e : Expr) (explicit : Bool := false) : MetaM CodeWithInfos := do
|
||||
let mut e := e
|
||||
-- When hovering over a metavariable, we want to see its value, even if `pp.instantiateMVars` is false.
|
||||
if explicit && e.isMVar then
|
||||
if let some e' ← getExprMVarAssignment? e.mvarId! then
|
||||
e := e'
|
||||
-- When `explicit` is false, keep the top-level tag so that users can also see the explicit version of the term on an additional hover.
|
||||
maybeWithoutTopLevelHighlight explicit <$> ppExprTagged e do
|
||||
if explicit then
|
||||
withOptionAtCurrPos pp.tagAppFns.name true do
|
||||
withOptionAtCurrPos pp.explicit.name true do
|
||||
withOptionAtCurrPos pp.mvars.anonymous.name true do
|
||||
delabApp
|
||||
else
|
||||
withOptionAtCurrPos pp.proofs.name true do
|
||||
withOptionAtCurrPos pp.sorrySource.name true do
|
||||
delab
|
||||
|
||||
builtin_initialize
|
||||
registerBuiltinRpcProcedure
|
||||
|
||||
@@ -56,6 +56,36 @@ elab_rules : tactic
|
||||
-- can't use a naked promise in `initialize` as marking it persistent would block
|
||||
initialize unblockedCancelTk : IO.CancelToken ← IO.CancelToken.new
|
||||
|
||||
/--
|
||||
Waits for `unblock` to be called, which is expected to happen in a subsequent document version that
|
||||
does not invalidate this tactic. Complains if cancellation token was set before unblocking, i.e. if
|
||||
the tactic was invalidated after all.
|
||||
-/
|
||||
scoped syntax "wait_for_unblock" : tactic
|
||||
@[incremental]
|
||||
elab_rules : tactic
|
||||
| `(tactic| wait_for_unblock) => do
|
||||
let ctx ← readThe Core.Context
|
||||
let some cancelTk := ctx.cancelTk? | unreachable!
|
||||
|
||||
dbg_trace "blocked!"
|
||||
log "blocked"
|
||||
let ctx ← readThe Elab.Term.Context
|
||||
let some tacSnap := ctx.tacSnap? | unreachable!
|
||||
tacSnap.new.resolve {
|
||||
diagnostics := (← Language.Snapshot.Diagnostics.ofMessageLog (← Core.getMessageLog))
|
||||
stx := default
|
||||
finished := default
|
||||
}
|
||||
|
||||
while true do
|
||||
if (← unblockedCancelTk.isSet) then
|
||||
break
|
||||
IO.sleep 30
|
||||
if (← cancelTk.isSet) then
|
||||
IO.eprintln "cancelled!"
|
||||
log "cancelled (should never be visible)"
|
||||
|
||||
/--
|
||||
Spawns a `logSnapshotTask` that waits for `unblock` to be called, which is expected to happen in a
|
||||
subsequent document version that does not invalidate this tactic. Complains if cancellation token
|
||||
@@ -83,6 +113,10 @@ scoped elab "unblock" : tactic => do
|
||||
dbg_trace "unblocking!"
|
||||
unblockedCancelTk.set
|
||||
|
||||
/--
|
||||
Like `wait_for_cancel_once` but does the waiting in a separate task and waits for its
|
||||
cancellation.
|
||||
-/
|
||||
scoped syntax "wait_for_cancel_once_async" : tactic
|
||||
@[incremental]
|
||||
elab_rules : tactic
|
||||
@@ -110,3 +144,35 @@ elab_rules : tactic
|
||||
|
||||
dbg_trace "blocked!"
|
||||
log "blocked"
|
||||
|
||||
/--
|
||||
Like `wait_for_cancel_once_async` but waits for the main thread's cancellation token. This is useful
|
||||
to test main thread cancellation in non-incremental contexts because we otherwise wouldn't be able
|
||||
to send out the "blocked" message from there.
|
||||
-/
|
||||
scoped syntax "wait_for_main_cancel_once_async" : tactic
|
||||
@[incremental]
|
||||
elab_rules : tactic
|
||||
| `(tactic| wait_for_main_cancel_once_async) => do
|
||||
let prom ← IO.Promise.new
|
||||
if let some t := (← onceRef.modifyGet (fun old => (old, old.getD prom.result!))) then
|
||||
IO.wait t
|
||||
return
|
||||
|
||||
let some cancelTk := (← readThe Core.Context).cancelTk? | unreachable!
|
||||
let act ← Elab.Term.wrapAsyncAsSnapshot (cancelTk? := none) fun _ => do
|
||||
let ctx ← readThe Core.Context
|
||||
-- TODO: `CancelToken` should probably use `Promise`
|
||||
while true do
|
||||
if (← cancelTk.isSet) then
|
||||
break
|
||||
IO.sleep 30
|
||||
IO.eprintln "cancelled!"
|
||||
log "cancelled (should never be visible)"
|
||||
prom.resolve ()
|
||||
Core.checkInterrupted
|
||||
let t ← BaseIO.asTask (act ())
|
||||
Core.logSnapshotTask { stx? := none, task := t, cancelTk? := cancelTk }
|
||||
|
||||
dbg_trace "blocked!"
|
||||
log "blocked"
|
||||
|
||||
@@ -73,24 +73,15 @@ where
|
||||
}
|
||||
TaggedText.tag t (go subTt)
|
||||
|
||||
def ppExprTagged (e : Expr) (explicit : Bool := false) : MetaM CodeWithInfos := do
|
||||
open PrettyPrinter Delaborator in
|
||||
/--
|
||||
Pretty prints the expression `e` using delaborator `delab`, returning an object that represents
|
||||
the pretty printed syntax paired with information needed to support hovers.
|
||||
-/
|
||||
def ppExprTagged (e : Expr) (delab : Delab := Delaborator.delab) : MetaM CodeWithInfos := do
|
||||
if pp.raw.get (← getOptions) then
|
||||
return .text (toString (← instantiateMVars e))
|
||||
let delab := open PrettyPrinter.Delaborator in
|
||||
if explicit then
|
||||
withOptionAtCurrPos pp.tagAppFns.name true do
|
||||
withOptionAtCurrPos pp.explicit.name true do
|
||||
withOptionAtCurrPos pp.mvars.anonymous.name true do
|
||||
delabApp
|
||||
else
|
||||
withOptionAtCurrPos pp.proofs.name true do
|
||||
withOptionAtCurrPos pp.sorrySource.name true do
|
||||
delab
|
||||
let mut e := e
|
||||
-- When hovering over a metavariable, we want to see its value, even if `pp.instantiateMVars` is false.
|
||||
if explicit && e.isMVar then
|
||||
if let some e' ← getExprMVarAssignment? e.mvarId! then
|
||||
e := e'
|
||||
let e ← if getPPInstantiateMVars (← getOptions) then instantiateMVars e else pure e
|
||||
return .text (toString e)
|
||||
let ⟨fmt, infos⟩ ← PrettyPrinter.ppExprWithInfos e (delab := delab)
|
||||
let tt := TaggedText.prettyTagged fmt
|
||||
let ctx := {
|
||||
|
||||
@@ -215,7 +215,7 @@ theorem modify_eq_alter [BEq α] [LawfulBEq α] {a : α} {f : β a → β a} {l
|
||||
modify a f l = alter a (·.map f) l := by
|
||||
induction l
|
||||
· rfl
|
||||
· next ih => simp only [modify, beq_iff_eq, alter, Option.map_some', ih]
|
||||
· next ih => simp only [modify, beq_iff_eq, alter, Option.map_some, ih]
|
||||
|
||||
namespace Const
|
||||
|
||||
@@ -235,7 +235,7 @@ theorem modify_eq_alter [BEq α] [EquivBEq α] {a : α} {f : β → β} {l : Ass
|
||||
modify a f l = alter a (·.map f) l := by
|
||||
induction l
|
||||
· rfl
|
||||
· next ih => simp only [modify, beq_iff_eq, alter, Option.map_some', ih]
|
||||
· next ih => simp only [modify, beq_iff_eq, alter, Option.map_some, ih]
|
||||
|
||||
end Const
|
||||
|
||||
|
||||
@@ -630,7 +630,7 @@ theorem Const.getThenInsertIfNew?_eq_insertIfNewₘ [BEq α] [Hashable α] (m :
|
||||
(a : α) (b : β) : (Const.getThenInsertIfNew? m a b).2 = m.insertIfNewₘ a b := by
|
||||
rw [getThenInsertIfNew?, insertIfNewₘ, containsₘ, bucket]
|
||||
dsimp only [Array.ugetElem_eq_getElem, Array.uset]
|
||||
split <;> simp_all [consₘ, updateBucket, List.containsKey_eq_isSome_getValue?, -Option.not_isSome]
|
||||
split <;> simp_all [consₘ, updateBucket, List.containsKey_eq_isSome_getValue?, -Option.isSome_eq_false_iff]
|
||||
|
||||
theorem Const.getThenInsertIfNew?_eq_get?ₘ [BEq α] [Hashable α] (m : Raw₀ α (fun _ => β)) (a : α)
|
||||
(b : β) : (Const.getThenInsertIfNew? m a b).1 = Const.get?ₘ m a := by
|
||||
|
||||
@@ -964,7 +964,7 @@ theorem isHashSelf_filterMapₘ [BEq α] [Hashable α] [ReflBEq α] [LawfulHasha
|
||||
IsHashSelf (m.filterMapₘ f).1.buckets := by
|
||||
refine h.buckets_hash_self.updateAllBuckets (fun l p hp => ?_)
|
||||
have hp := AssocList.toList_filterMap.mem_iff.1 hp
|
||||
simp only [mem_filterMap, Option.map_eq_some'] at hp
|
||||
simp only [mem_filterMap, Option.map_eq_some_iff] at hp
|
||||
obtain ⟨p, ⟨hkv, ⟨d, ⟨-, rfl⟩⟩⟩⟩ := hp
|
||||
exact containsKey_of_mem hkv
|
||||
|
||||
|
||||
@@ -1031,7 +1031,7 @@ theorem ordered_filterMap [Ord α] {t : Impl α β} {h} {f : (a : α) → β a
|
||||
simp only [Ordered, toListModel_filterMap]
|
||||
apply ho.filterMap
|
||||
intro e f hef e' he' f' hf'
|
||||
simp only [Option.map_eq_some'] at he' hf'
|
||||
simp only [Option.map_eq_some_iff] at he' hf'
|
||||
obtain ⟨_, _, rfl⟩ := he'
|
||||
obtain ⟨_, _, rfl⟩ := hf'
|
||||
exact hef
|
||||
|
||||
@@ -169,7 +169,7 @@ theorem getValue?_eq_getEntry? [BEq α] {l : List ((_ : α) × β)} {a : α} :
|
||||
· next k v l ih =>
|
||||
cases h : k == a
|
||||
· rw [getEntry?_cons_of_false h, getValue?_cons_of_false h, ih]
|
||||
· rw [getEntry?_cons_of_true h, getValue?_cons_of_true h, Option.map_some']
|
||||
· rw [getEntry?_cons_of_true h, getValue?_cons_of_true h, Option.map_some]
|
||||
|
||||
theorem getValue?_congr [BEq α] [PartialEquivBEq α] {l : List ((_ : α) × β)} {a b : α}
|
||||
(h : a == b) : getValue? a l = getValue? b l := by
|
||||
@@ -338,7 +338,7 @@ theorem getEntry?_eq_none [BEq α] {l : List ((a : α) × β a)} {a : α} :
|
||||
@[simp]
|
||||
theorem getValue?_eq_none {β : Type v} [BEq α] {l : List ((_ : α) × β)} {a : α} :
|
||||
getValue? a l = none ↔ containsKey a l = false := by
|
||||
rw [getValue?_eq_getEntry?, Option.map_eq_none', getEntry?_eq_none]
|
||||
rw [getValue?_eq_getEntry?, Option.map_eq_none_iff, getEntry?_eq_none]
|
||||
|
||||
theorem containsKey_eq_isSome_getValue? {β : Type v} [BEq α] {l : List ((_ : α) × β)} {a : α} :
|
||||
containsKey a l = (getValue? a l).isSome := by
|
||||
@@ -613,7 +613,7 @@ theorem getKey?_eq_getEntry? [BEq α] {l : List ((a : α) × β a)} {a : α} :
|
||||
· next k v l ih =>
|
||||
cases h : k == a
|
||||
· rw [getEntry?_cons_of_false h, getKey?_cons_of_false h, ih]
|
||||
· rw [getEntry?_cons_of_true h, getKey?_cons_of_true h, Option.map_some']
|
||||
· rw [getEntry?_cons_of_true h, getKey?_cons_of_true h, Option.map_some]
|
||||
|
||||
theorem fst_mem_keys_of_mem [BEq α] [EquivBEq α] {a : (a : α) × β a} {l : List ((a : α) × β a)}
|
||||
(hm : a ∈ l) : a.1 ∈ keys l :=
|
||||
@@ -1515,7 +1515,7 @@ theorem containsKey_insertEntryIfNew [BEq α] [PartialEquivBEq α] {l : List ((a
|
||||
cases h : k == a
|
||||
· simp
|
||||
· rw [containsKey_eq_isSome_getEntry?, getEntry?_congr h]
|
||||
simp
|
||||
simp [-Option.not_isSome]
|
||||
|
||||
theorem containsKey_insertEntryIfNew_self [BEq α] [EquivBEq α] {l : List ((a : α) × β a)} {k : α}
|
||||
{v : β k} : containsKey k (insertEntryIfNew k v l) := by
|
||||
@@ -2098,7 +2098,7 @@ theorem containsKey_append_of_not_contains_right [BEq α] {l l' : List ((a : α)
|
||||
@[simp]
|
||||
theorem getValue?_append {β : Type v} [BEq α] {l l' : List ((_ : α) × β)} {a : α} :
|
||||
getValue? a (l ++ l') = (getValue? a l).or (getValue? a l') := by
|
||||
simp [getValue?_eq_getEntry?, Option.map_or']
|
||||
simp [getValue?_eq_getEntry?, Option.map_or]
|
||||
|
||||
theorem getValue?_append_of_containsKey_eq_false {β : Type v} [BEq α] {l l' : List ((_ : α) × β)}
|
||||
{a : α} (h : containsKey a l' = false) : getValue? a (l ++ l') = getValue? a l := by
|
||||
@@ -2233,7 +2233,7 @@ theorem mem_map_toProd_iff_mem {β : Type v} {k : α} {v : β} {l : List ((_ :
|
||||
theorem mem_iff_getValue?_eq_some [BEq α] [LawfulBEq α] {β : Type v} {k : α} {v : β}
|
||||
{l : List ((_ : α) × β)} (h : DistinctKeys l) :
|
||||
⟨k, v⟩ ∈ l ↔ getValue? k l = some v := by
|
||||
simp only [mem_iff_getEntry?_eq_some h, getValue?_eq_getEntry?, Option.map_eq_some']
|
||||
simp only [mem_iff_getEntry?_eq_some h, getValue?_eq_getEntry?, Option.map_eq_some_iff]
|
||||
constructor
|
||||
· intro h
|
||||
exists ⟨k, v⟩
|
||||
@@ -2256,7 +2256,7 @@ theorem find?_map_toProd_eq_some_iff_getKey?_eq_some_and_getValue?_eq_some [BEq
|
||||
| nil => simp
|
||||
| cons hd tl ih =>
|
||||
simp only [List.map_cons, List.find?_cons_eq_some, Prod.mk.injEq, Bool.not_eq_eq_eq_not,
|
||||
Bool.not_true, Option.map_eq_some', getKey?, cond_eq_if, getValue?]
|
||||
Bool.not_true, Option.map_eq_some_iff, getKey?, cond_eq_if, getValue?]
|
||||
by_cases hdfst_k: hd.fst == k
|
||||
· simp only [hdfst_k, true_and, Bool.true_eq_false, false_and, or_false, ↓reduceIte,
|
||||
Option.some.injEq]
|
||||
@@ -2271,7 +2271,7 @@ theorem mem_iff_getKey?_eq_some_and_getValue?_eq_some [BEq α] [EquivBEq α]
|
||||
theorem getValue?_eq_some_iff_exists_beq_and_mem_toList {β : Type v} [BEq α] [EquivBEq α]
|
||||
{l : List ((_ : α) × β)} {k: α} {v : β} (h : DistinctKeys l) :
|
||||
getValue? k l = some v ↔ ∃ k', (k == k') = true ∧ (k', v) ∈ l.map (fun x => (x.fst, x.snd)) := by
|
||||
simp only [getValue?_eq_getEntry?, Option.map_eq_some', ← mem_map_toProd_iff_mem,
|
||||
simp only [getValue?_eq_getEntry?, Option.map_eq_some_iff, ← mem_map_toProd_iff_mem,
|
||||
mem_iff_getEntry?_eq_some h]
|
||||
constructor
|
||||
· intro h'
|
||||
@@ -2635,7 +2635,7 @@ theorem getKey?_insertList_of_mem [BEq α] [EquivBEq α]
|
||||
rcases List.mem_map.1 mem with ⟨⟨k, v⟩, pair_mem, rfl⟩
|
||||
rw [getKey?_eq_getEntry?, getEntry?_insertList distinct_l distinct_toInsert,
|
||||
getEntry?_of_mem (DistinctKeys.def.2 distinct_toInsert) k_beq pair_mem, Option.some_or,
|
||||
Option.map_some']
|
||||
Option.map_some]
|
||||
|
||||
theorem getKey_insertList_of_contains_eq_false [BEq α] [EquivBEq α]
|
||||
{l toInsert : List ((a : α) × β a)} {k : α}
|
||||
@@ -3074,7 +3074,7 @@ theorem getKey?_insertListIfNewUnit_of_contains_eq_false_of_contains_eq_false [B
|
||||
(h': containsKey k l = false) (h : toInsert.contains k = false) :
|
||||
getKey? k (insertListIfNewUnit l toInsert) = none := by
|
||||
rw [getKey?_eq_getEntry?,
|
||||
getEntry?_insertListIfNewUnit_of_contains_eq_false h, Option.map_eq_none', getEntry?_eq_none]
|
||||
getEntry?_insertListIfNewUnit_of_contains_eq_false h, Option.map_eq_none_iff, getEntry?_eq_none]
|
||||
exact h'
|
||||
|
||||
theorem getKey?_insertListIfNewUnit_of_contains_eq_false_of_mem [BEq α] [EquivBEq α]
|
||||
@@ -3083,8 +3083,8 @@ theorem getKey?_insertListIfNewUnit_of_contains_eq_false_of_mem [BEq α] [EquivB
|
||||
(mem' : containsKey k l = false)
|
||||
(distinct : toInsert.Pairwise (fun a b => (a == b) = false)) (mem : k ∈ toInsert) :
|
||||
getKey? k' (insertListIfNewUnit l toInsert) = some k := by
|
||||
simp only [getKey?_eq_getEntry?, getEntry?_insertListIfNewUnit, Option.map_eq_some',
|
||||
Option.or_eq_some, getEntry?_eq_none]
|
||||
simp only [getKey?_eq_getEntry?, getEntry?_insertListIfNewUnit, Option.map_eq_some_iff,
|
||||
Option.or_eq_some_iff, getEntry?_eq_none]
|
||||
exists ⟨k, ()⟩
|
||||
simp only [and_true]
|
||||
right
|
||||
@@ -4301,8 +4301,8 @@ theorem minEntry?_eq_some_iff [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd α]
|
||||
theorem minKey?_eq_some_iff_getKey?_eq_self_and_forall [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd α]
|
||||
{k} {l : List ((a : α) × β a)} (hd : DistinctKeys l) :
|
||||
minKey? l = some k ↔ getKey? k l = some k ∧ ∀ k' : α, containsKey k' l → (compare k k').isLE := by
|
||||
simp only [minKey?, Option.map_eq_some', minEntry?_eq_some_iff _ hd]
|
||||
simp only [getKey?_eq_getEntry?, Option.map_eq_some', getEntry?_eq_some_iff hd]
|
||||
simp only [minKey?, Option.map_eq_some_iff, minEntry?_eq_some_iff _ hd]
|
||||
simp only [getKey?_eq_getEntry?, Option.map_eq_some_iff, getEntry?_eq_some_iff hd]
|
||||
apply Iff.intro
|
||||
· rintro ⟨_, ⟨hm, hcmp⟩, rfl⟩
|
||||
exact ⟨⟨_, ⟨BEq.refl, hm⟩, rfl⟩, hcmp⟩
|
||||
@@ -4312,7 +4312,7 @@ theorem minKey?_eq_some_iff_getKey?_eq_self_and_forall [Ord α] [TransOrd α] [B
|
||||
theorem minKey?_eq_some_iff_mem_and_forall [Ord α] [LawfulEqOrd α] [TransOrd α] [BEq α]
|
||||
[LawfulBEqOrd α] {k} {l : List ((a : α) × β a)} (hd : DistinctKeys l) :
|
||||
minKey? l = some k ↔ containsKey k l ∧ ∀ k' : α, containsKey k' l → (compare k k').isLE := by
|
||||
simp only [minKey?, Option.map_eq_some', minEntry?_eq_some_iff _ hd]
|
||||
simp only [minKey?, Option.map_eq_some_iff, minEntry?_eq_some_iff _ hd]
|
||||
apply Iff.intro
|
||||
· rintro ⟨_, ⟨hm, hcmp⟩, rfl⟩
|
||||
exact ⟨containsKey_of_mem hm, hcmp⟩
|
||||
@@ -4361,7 +4361,7 @@ theorem isNone_minKey?_eq_isEmpty [Ord α] {l : List ((a : α) × β a)} :
|
||||
|
||||
theorem isSome_minEntry?_eq_not_isEmpty [Ord α] {l : List ((a : α) × β a)} :
|
||||
(minEntry? l).isSome = !l.isEmpty := by
|
||||
rw [← Bool.not_inj_iff, Bool.not_not, Bool.eq_iff_iff, Bool.not_eq_true', Option.not_isSome,
|
||||
rw [← Bool.not_inj_iff, Bool.not_not, Bool.eq_iff_iff, Bool.not_eq_true', Option.isSome_eq_false_iff,
|
||||
Option.isNone_iff_eq_none]
|
||||
apply minEntry?_eq_none_iff_isEmpty
|
||||
|
||||
@@ -4384,7 +4384,7 @@ theorem minEntry?_map [Ord α] (l : List ((a : α) × β a)) (f : (a : α) × β
|
||||
simp only [minEntry?, List.min?]
|
||||
cases l <;> try rfl
|
||||
rename_i e es
|
||||
simp only [List.map_cons, Option.map_some', Option.some.injEq]
|
||||
simp only [List.map_cons, Option.map_some, Option.some.injEq]
|
||||
rw [← List.foldr_reverse, ← List.foldr_reverse, ← List.map_reverse]
|
||||
induction es.reverse with
|
||||
| nil => rfl
|
||||
@@ -4442,7 +4442,7 @@ theorem minEntry?_insertEntry [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd α]
|
||||
cases h : containsKey k l
|
||||
· simp only [cond_false, minEntry?_cons, Option.some.injEq]
|
||||
rfl
|
||||
· rw [cond_true, minEntry?_replaceEntry hl, Option.map_eq_some']
|
||||
· rw [cond_true, minEntry?_replaceEntry hl, Option.map_eq_some_iff]
|
||||
have := isSome_minEntry?_of_contains ‹_›
|
||||
simp only [Option.isSome_iff_exists] at this
|
||||
obtain ⟨a, ha⟩ := this
|
||||
@@ -4538,7 +4538,7 @@ theorem minKey?_bind_getKey? [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd α]
|
||||
theorem containsKey_minKey? [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd α] {l : List ((a : α) × β a)}
|
||||
(hd : DistinctKeys l) {km} (hkm : minKey? l = some km) :
|
||||
containsKey km l := by
|
||||
simp only [minKey?, Option.map_eq_some', minEntry?_eq_some_iff _ hd] at hkm
|
||||
simp only [minKey?, Option.map_eq_some_iff, minEntry?_eq_some_iff _ hd] at hkm
|
||||
obtain ⟨e, ⟨hm, _⟩, rfl⟩ := hkm
|
||||
exact containsKey_of_mem hm
|
||||
|
||||
@@ -4710,7 +4710,7 @@ theorem minKey?_alterKey_eq_self [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd
|
||||
minKey? (alterKey k f l) = some k ↔
|
||||
(f (getValueCast? k l)).isSome ∧ ∀ k', containsKey k' l → (compare k k').isLE := by
|
||||
simp only [minKey?_eq_some_iff_getKey?_eq_self_and_forall hd.alterKey, getKey?_alterKey _ hd,
|
||||
beq_self_eq_true, ↓reduceIte, ite_eq_left_iff, Bool.not_eq_true, Option.not_isSome,
|
||||
beq_self_eq_true, ↓reduceIte, ite_eq_left_iff, Bool.not_eq_true, Option.isSome_eq_false_iff,
|
||||
Option.isNone_iff_eq_none, reduceCtorEq, imp_false, ← Option.isSome_iff_ne_none,
|
||||
containsKey_alterKey hd, beq_iff_eq, Bool.ite_eq_true_distrib, and_congr_right_iff]
|
||||
intro hf
|
||||
@@ -4753,18 +4753,18 @@ theorem minKey?_modifyKey_eq_minKey? [Ord α] [TransOrd α] [BEq α] [LawfulBEqO
|
||||
simp only [minKey?_modifyKey hd]
|
||||
cases minKey? l
|
||||
· rfl
|
||||
· simp only [beq_iff_eq, Option.map_some', Option.some.injEq, ite_eq_right_iff]
|
||||
· simp only [beq_iff_eq, Option.map_some, Option.some.injEq, ite_eq_right_iff]
|
||||
exact Eq.symm
|
||||
|
||||
theorem isSome_minKey?_modifyKey [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd α] {k f}
|
||||
{l : List ((_ : α) × β)} :
|
||||
(modifyKey k f l |> minKey?).isSome = !l.isEmpty := by
|
||||
simp [Option.isSome_map', isSome_minKey?_eq_not_isEmpty, isEmpty_modifyKey]
|
||||
simp [Option.isSome_map, isSome_minKey?_eq_not_isEmpty, isEmpty_modifyKey]
|
||||
|
||||
theorem isSome_minKey?_modifyKey_eq_isSome [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd α] {k f}
|
||||
{l : List ((_ : α) × β)} :
|
||||
(modifyKey k f l |> minKey?).isSome = (minKey? l).isSome := by
|
||||
simp [Option.isSome_map', isSome_minKey?_eq_not_isEmpty, isEmpty_modifyKey]
|
||||
simp [Option.isSome_map, isSome_minKey?_eq_not_isEmpty, isEmpty_modifyKey]
|
||||
|
||||
theorem minKey?_modifyKey_beq [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd α] {k f km kmm}
|
||||
{l : List ((_ : α) × β)} (hd : DistinctKeys l) (hkm : minKey? l = some km)
|
||||
@@ -4784,7 +4784,7 @@ theorem minKey?_alterKey_eq_self [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd
|
||||
(f (getValue? k l)).isSome ∧ ∀ k', containsKey k' l → (compare k k').isLE := by
|
||||
simp only [minKey?_eq_some_iff_getKey?_eq_self_and_forall hd.constAlterKey, getKey?_alterKey _ hd,
|
||||
← compare_eq_iff_beq, compare_self, ↓reduceIte, ite_eq_left_iff, Bool.not_eq_true,
|
||||
Option.not_isSome, Option.isNone_iff_eq_none, reduceCtorEq, imp_false,
|
||||
Option.isSome_eq_false_iff, Option.isNone_iff_eq_none, reduceCtorEq, imp_false,
|
||||
← Option.isSome_iff_ne_none, containsKey_alterKey hd, Bool.ite_eq_true_distrib,
|
||||
and_congr_right_iff]
|
||||
intro hf
|
||||
|
||||
@@ -1,137 +1,720 @@
|
||||
/-
|
||||
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
||||
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Gabriel Ebner
|
||||
Authors: Henrik Böving
|
||||
-/
|
||||
prelude
|
||||
import Init.System.Promise
|
||||
import Init.Data.Queue
|
||||
import Std.Sync.Mutex
|
||||
|
||||
/-!
|
||||
This module contains the implementation of `Std.Channel`. `Std.Channel` is a multi-producer
|
||||
multi-consumer FIFO channel that offers both bounded and unbounded buffering as well as synchronous
|
||||
and asynchronous APIs.
|
||||
|
||||
Additionally `Std.CloseableChannel` is provided in case closing the channel is of interest.
|
||||
The two are distinct as the non closable `Std.Channel` can never throw errors which makes
|
||||
for cleaner code.
|
||||
-/
|
||||
|
||||
namespace Std
|
||||
|
||||
/--
|
||||
Internal state of an `Channel`.
|
||||
|
||||
We maintain the invariant that at all times either `consumers` or `values` is empty.
|
||||
-/
|
||||
structure Channel.State (α : Type) where
|
||||
values : Std.Queue α := ∅
|
||||
consumers : Std.Queue (IO.Promise (Option α)) := ∅
|
||||
closed := false
|
||||
deriving Inhabited
|
||||
namespace CloseableChannel
|
||||
|
||||
/--
|
||||
FIFO channel with unbounded buffer, where `recv?` returns a `Task`.
|
||||
|
||||
A channel can be closed. Once it is closed, all `send`s are ignored, and
|
||||
`recv?` returns `none` once the queue is empty.
|
||||
Errors that may be thrown while interacting with the channel API.
|
||||
-/
|
||||
def Channel (α : Type) : Type := Mutex (Channel.State α)
|
||||
inductive Error where
|
||||
/--
|
||||
Tried to send to a closed channel.
|
||||
-/
|
||||
| closed
|
||||
/--
|
||||
Tried to close an already closed channel.
|
||||
-/
|
||||
| alreadyClosed
|
||||
deriving Repr, DecidableEq, Hashable
|
||||
|
||||
instance : Nonempty (Channel α) :=
|
||||
inferInstanceAs (Nonempty (Mutex _))
|
||||
instance : ToString Error where
|
||||
toString
|
||||
| .closed => "trying to send on an already closed channel"
|
||||
| .alreadyClosed => "trying to close an already closed channel"
|
||||
|
||||
/-- Creates a new `Channel`. -/
|
||||
def Channel.new : BaseIO (Channel α) :=
|
||||
Mutex.new {}
|
||||
instance : MonadLift (EIO Error) IO where
|
||||
monadLift x := EIO.toIO (.userError <| toString ·) x
|
||||
|
||||
/--
|
||||
Sends a message on an `Channel`.
|
||||
|
||||
This function does not block.
|
||||
The central state structure for an unbounded channel, maintains the following invariants:
|
||||
1. `values = ∅ ∨ consumers = ∅`
|
||||
2. `closed = true → consumers = ∅`
|
||||
-/
|
||||
def Channel.send (ch : Channel α) (v : α) : BaseIO Unit :=
|
||||
ch.atomically do
|
||||
private structure Unbounded.State (α : Type) where
|
||||
/--
|
||||
Values pushed into the channel that are waiting to be consumed.
|
||||
-/
|
||||
values : Std.Queue α
|
||||
/--
|
||||
Consumers that are blocked on a producer providing them a value. The `IO.Promise` will be
|
||||
resolved to `none` if the channel closes.
|
||||
-/
|
||||
consumers : Std.Queue (IO.Promise (Option α))
|
||||
/--
|
||||
Whether the channel is closed already.
|
||||
-/
|
||||
closed : Bool
|
||||
deriving Nonempty
|
||||
|
||||
private structure Unbounded (α : Type) where
|
||||
state : Mutex (Unbounded.State α)
|
||||
deriving Nonempty
|
||||
|
||||
namespace Unbounded

/-- Create a fresh, open, unbounded channel with no queued values and no blocked consumers. -/
private def new : BaseIO (Unbounded α) := do
  return {
    state := ← Mutex.new {
      values := ∅
      consumers := ∅
      closed := false
    }
  }

/--
Try to send `v`. This only fails (returns `false`) if the channel is closed: an open unbounded
channel either hands the value directly to a blocked consumer or enqueues it.
-/
private def trySend (ch : Unbounded α) (v : α) : BaseIO Bool := do
  ch.state.atomically do
    let st ← get
    if st.closed then
      return false
    else if let some (consumer, consumers) := st.consumers.dequeue? then
      -- Invariant 1: a waiting consumer implies an empty value queue, so hand over directly.
      consumer.resolve (some v)
      set { st with consumers }
      return true
    else
      set { st with values := st.values.enqueue v }
      return true

/-- Sending never blocks on an unbounded channel, so the returned task is already resolved. -/
private def send (ch : Unbounded α) (v : α) : BaseIO (Task (Except Error Unit)) := do
  if ← Unbounded.trySend ch v then
    return .pure <| .ok ()
  else
    return .pure <| .error .closed

/--
Close the channel, throwing `Error.alreadyClosed` on a second close. All blocked consumers are
resolved with `none`, as no further value can ever arrive for them.
-/
private def close (ch : Unbounded α) : EIO Error Unit := do
  ch.state.atomically do
    let st ← get
    if st.closed then throw .alreadyClosed
    for consumer in st.consumers.toArray do consumer.resolve none
    -- Re-establishes invariant 2: a closed channel has no blocked consumers.
    set { st with consumers := ∅, closed := true }
    return ()

/-- Return `true` if the channel has been closed. -/
private def isClosed (ch : Unbounded α) : BaseIO Bool :=
  ch.state.atomically do
    return (← get).closed

/-- Dequeue a buffered value if one is available, `none` otherwise. -/
private def tryRecv' : AtomicT (Unbounded.State α) BaseIO (Option α) := do
  let st ← get
  if let some (a, values) := st.values.dequeue? then
    set { st with values }
    return some a
  else
    return none

/-- Non-blocking receive: `some a` if a value was already buffered, `none` otherwise. -/
private def tryRecv (ch : Unbounded α) : BaseIO (Option α) :=
  ch.state.atomically do
    tryRecv'

/--
Receive a value as a task. The task resolves immediately if a value is buffered or the channel is
closed; otherwise a promise is registered that a future `send` or `close` will resolve.
-/
private def recv (ch : Unbounded α) : BaseIO (Task (Option α)) := do
  ch.state.atomically do
    if let some val ← tryRecv' then
      return .pure <| some val
    else if (← get).closed then
      return .pure none
    else
      let promise ← IO.Promise.new
      modify fun st => { st with consumers := st.consumers.enqueue promise }
      -- `result?` is `none` if the promise is dropped; collapse that case into `none` as well.
      return promise.result?.map (sync := true) (·.bind id)

end Unbounded
/--
The central state structure for a zero buffer channel, maintains the following invariants:
1. `producers = ∅ ∨ consumers = ∅`
2. `closed = true → consumers = ∅`
-/
private structure Zero.State (α : Type) where
  /--
  Producers that are blocked on a consumer taking their value.
  -/
  producers : Std.Queue (α × IO.Promise Bool)
  /--
  Consumers that are blocked on a producer providing them a value. The `IO.Promise` will be resolved
  to `none` if the channel closes.
  -/
  consumers : Std.Queue (IO.Promise (Option α))
  /--
  Whether the channel is closed already.
  -/
  closed : Bool

/--
A rendezvous channel: every send must pair up with a matching receive.
-/
private structure Zero (α : Type) where
  state : Mutex (Zero.State α)
namespace Zero

/-- Create a fresh, open, rendezvous channel with no blocked producers or consumers. -/
private def new : BaseIO (Zero α) := do
  return {
    state := ← Mutex.new {
      producers := ∅
      consumers := ∅
      closed := false
    }
  }

/--
Hand `v` directly to a blocked consumer if there is one, returning `true` on success.

Precondition: The channel must not be closed.
-/
private def trySend' (v : α) : AtomicT (Zero.State α) BaseIO Bool := do
  let st ← get
  if let some (consumer, consumers) := st.consumers.dequeue? then
    consumer.resolve (some v)
    set { st with consumers }
    return true
  else
    return false

/-- Non-blocking send: succeeds only if a consumer is already waiting on an open channel. -/
private def trySend (ch : Zero α) (v : α) : BaseIO Bool := do
  ch.state.atomically do
    if (← get).closed then
      return false
    else
      trySend' v

/--
Send `v` as a task: resolves immediately on rendezvous with a waiting consumer, otherwise the task
blocks until a consumer arrives or the channel is closed.
-/
private def send (ch : Zero α) (v : α) : BaseIO (Task (Except Error Unit)) := do
  ch.state.atomically do
    if (← get).closed then
      return .pure <| .error .closed
    else if ← trySend' v then
      return .pure <| .ok ()
    else
      let promise ← IO.Promise.new
      modify fun st => { st with producers := st.producers.enqueue (v, promise) }
      -- `none` means the promise was dropped, `some false` that the channel closed underneath us.
      return promise.result?.map (sync := true)
        fun
          | none | some false => .error .closed
          | some true => .ok ()

/--
Close the channel, throwing `Error.alreadyClosed` on a second close. All blocked consumers are
resolved with `none`.
-/
private def close (ch : Zero α) : EIO Error Unit := do
  ch.state.atomically do
    let st ← get
    if st.closed then throw .alreadyClosed
    for consumer in st.consumers.toArray do consumer.resolve none
    set { st with consumers := ∅, closed := true }
    return ()

/-- Return `true` if the channel has been closed. -/
private def isClosed (ch : Zero α) : BaseIO Bool :=
  ch.state.atomically do
    return (← get).closed

/-- Take a value from a blocked producer if there is one, unblocking that producer with `true`. -/
private def tryRecv' : AtomicT (Zero.State α) BaseIO (Option α) := do
  let st ← get
  if let some ((val, promise), producers) := st.producers.dequeue? then
    set { st with producers }
    promise.resolve true
    return some val
  else
    return none

/-- Non-blocking receive: succeeds only if a producer is already blocked on this channel. -/
private def tryRecv (ch : Zero α) : BaseIO (Option α) := do
  ch.state.atomically do
    tryRecv'

/--
Receive a value as a task. Resolves immediately on rendezvous with a blocked producer or if the
channel is already closed; otherwise a promise is registered that a future `send` or `close`
resolves.
-/
private def recv (ch : Zero α) : BaseIO (Task (Option α)) := do
  ch.state.atomically do
    let st ← get
    if let some val ← tryRecv' then
      return .pure <| some val
    else if !st.closed then
      let promise ← IO.Promise.new
      set { st with consumers := st.consumers.enqueue promise }
      return promise.result?.map (sync := true) (·.bind id)
    else
      return .pure none

end Zero
/--
The central state structure for a bounded channel, maintains the following invariants:
1. `0 < capacity`
2. `0 < bufCount → consumers = ∅`
3. `bufCount < capacity → producers = ∅`
4. `producers = ∅ ∨ consumers = ∅`, implied by 1, 2 and 3.
5. `bufCount` corresponds to the amount of slots in `buf` that are `some`.
6. `sendIdx = (recvIdx + bufCount) % capacity`. However all four of these values still get tracked
   as there is potential to make a non-blocking send lock-free in the future with this approach.
7. `closed = true → consumers = ∅`

While it (currently) lacks the partial lock-freeness of go channels, the protocol is based on
[Go channels on steroids](https://docs.google.com/document/d/1yIAYmbvL3JxOKOjuCyon7JhW4cSv1wy5hC0ApeGMV9s/pub)
as well as its [implementation](https://go.dev/src/runtime/chan.go).
-/
private structure Bounded.State (α : Type) where
  /--
  Producers that are blocked on a consumer taking their value as there was no buffer space
  available when they tried to enqueue.
  -/
  producers : Std.Queue (IO.Promise Bool)
  /--
  Consumers that are blocked on a producer providing them a value, as there was no value
  enqueued when they tried to dequeue. The `IO.Promise` will be resolved to `false` if the channel
  closes.
  -/
  consumers : Std.Queue (IO.Promise Bool)
  /--
  The capacity of the buffer space.
  -/
  capacity : Nat
  /--
  The buffer space for the channel, slots with `some v` contain a value that is waiting for
  consumption, the slots with `none` are free for enqueueing.

  Note that this is a `Vector` of `IO.Ref (Option α)` as the `buf` itself is shared across threads
  and would thus keep getting copied if it was a `Vector (Option α)` instead.
  -/
  buf : Vector (IO.Ref (Option α)) capacity
  /--
  How many slots in `buf` are currently used, this is used to disambiguate between an empty and a
  full buffer without sacrificing a slot for indicating that.
  -/
  bufCount : Nat
  /--
  The slot in `buf` that the next send will happen to.
  -/
  sendIdx : Nat
  hsend : sendIdx < capacity
  /--
  The slot in `buf` that the next receive will happen from.
  -/
  recvIdx : Nat
  hrecv : recvIdx < capacity
  /--
  Whether the channel is closed already.
  -/
  closed : Bool

/--
A channel with a fixed-size ring buffer: sends block once the buffer is full.
-/
private structure Bounded (α : Type) where
  state : Mutex (Bounded.State α)
namespace Bounded

/-- Create a fresh, open, bounded channel with an empty ring buffer of size `capacity`. -/
private def new (capacity : Nat) (hcap : 0 < capacity) : BaseIO (Bounded α) := do
  return {
    state := ← Mutex.new {
      producers := ∅
      consumers := ∅
      capacity := capacity
      buf := ← Vector.range capacity |>.mapM (fun _ => IO.mkRef none)
      bufCount := 0
      sendIdx := 0
      hsend := hcap
      recvIdx := 0
      hrecv := hcap
      closed := false
    }
  }

/-- Increment `idx` modulo `cap` without performing an actual division. -/
@[inline]
private def incMod (idx : Nat) (cap : Nat) : Nat :=
  if idx + 1 = cap then
    0
  else
    idx + 1

private theorem incMod_lt {idx cap : Nat} (h : idx < cap) : incMod idx cap < cap := by
  unfold incMod
  split <;> omega

/--
Write `v` into the ring buffer if a slot is free, waking up one blocked consumer if present.

Precondition: The channel must not be closed.
-/
private def trySend' (v : α) : AtomicT (Bounded.State α) BaseIO Bool := do
  let mut st ← get
  if st.bufCount = st.capacity then
    return false
  else
    st.buf[st.sendIdx]'st.hsend |>.set (some v)
    st := { st with
      bufCount := st.bufCount + 1
      sendIdx := incMod st.sendIdx st.capacity
      hsend := incMod_lt st.hsend
    }

    if let some (consumer, consumers) := st.consumers.dequeue? then
      consumer.resolve true
      st := { st with consumers }

    set st

    return true

/-- Non-blocking send: succeeds iff the channel is open and the buffer is not full. -/
private def trySend (ch : Bounded α) (v : α) : BaseIO Bool := do
  ch.state.atomically do
    if (← get).closed then
      return false
    else
      trySend' v

/--
Send `v` as a task: resolves immediately if buffer space is available, otherwise the task blocks
until a consumer frees a slot or the channel is closed.
-/
private partial def send (ch : Bounded α) (v : α) : BaseIO (Task (Except Error Unit)) := do
  ch.state.atomically do
    if (← get).closed then
      return .pure <| .error .closed
    else if ← trySend' v then
      return .pure <| .ok ()
    else
      let promise ← IO.Promise.new
      modify fun st => { st with producers := st.producers.enqueue promise }
      BaseIO.bindTask promise.result? fun res => do
        -- `true` only signals that a slot may be free now; retry the entire send.
        if res.getD false then
          Bounded.send ch v
        else
          return .pure <| .error .closed

/--
Close the channel, throwing `Error.alreadyClosed` on a second close. All blocked consumers are
resolved with `false`.
-/
private def close (ch : Bounded α) : EIO Error Unit := do
  ch.state.atomically do
    let st ← get
    if st.closed then throw .alreadyClosed
    for consumer in st.consumers.toArray do consumer.resolve false
    set { st with consumers := ∅, closed := true }
    return ()

/-- Return `true` if the channel has been closed. -/
private def isClosed (ch : Bounded α) : BaseIO Bool :=
  ch.state.atomically do
    return (← get).closed

/-- Take the oldest buffered value out of the ring buffer if one exists. -/
private def tryRecv' : AtomicT (Bounded.State α) BaseIO (Option α) := do
  let st ← get
  if st.bufCount == 0 then
    return none
  else
    let val ← st.buf[st.recvIdx]'st.hrecv |>.swap none
    let nextRecvIdx := incMod st.recvIdx st.capacity
    set { st with
      bufCount := st.bufCount - 1
      recvIdx := nextRecvIdx,
      hrecv := incMod_lt st.hrecv
    }
    return val

/-- Non-blocking receive: `some a` if a value was buffered, `none` otherwise. -/
private def tryRecv (ch : Bounded α) : BaseIO (Option α) :=
  ch.state.atomically do
    tryRecv'

/--
Receive a value as a task. On a successful dequeue one blocked producer is woken so it can retry
its send; if the buffer is empty and the channel is open, the task blocks until a value arrives.
-/
private partial def recv (ch : Bounded α) : BaseIO (Task (Option α)) := do
  ch.state.atomically do
    if let some val ← tryRecv' then
      let st ← get
      if let some (producer, producers) := st.producers.dequeue? then
        producer.resolve true
        set { st with producers }
      return .pure <| some val
    else if (← get).closed then
      return .pure none
    else
      let promise ← IO.Promise.new
      modify fun st => { st with consumers := st.consumers.enqueue promise }
      BaseIO.bindTask promise.result? fun res => do
        if res.getD false then
          Bounded.recv ch
        else
          return .pure none

end Bounded
/--
This type represents all flavors of channels that we have available.
-/
private inductive Flavors (α : Type) where
  | unbounded (ch : Unbounded α)
  | zero (ch : Zero α)
  | bounded (ch : Bounded α)
  deriving Nonempty

end CloseableChannel
/--
A multi-producer multi-consumer FIFO channel that offers both bounded and unbounded buffering
and an asynchronous API, to switch into synchronous mode use `CloseableChannel.sync`.

Additionally `Std.CloseableChannel` can be closed if necessary, unlike `Std.Channel`.
This introduces a need for error handling in some cases, thus it is usually easier to use
`Std.Channel` if applicable.
-/
def CloseableChannel (α : Type) : Type := CloseableChannel.Flavors α

/--
A multi-producer multi-consumer FIFO channel that offers both bounded and unbounded buffering
and a synchronous API. This type acts as a convenient layer to use a channel in a blocking fashion
and is not actually different from the original channel.

Additionally `Std.CloseableChannel.Sync` can be closed if necessary, unlike `Std.Channel.Sync`.
This introduces the need to handle errors in some cases, thus it is usually easier to use
`Std.Channel` if applicable.
-/
def CloseableChannel.Sync (α : Type) : Type := CloseableChannel α

instance : Nonempty (CloseableChannel α) :=
  inferInstanceAs (Nonempty (CloseableChannel.Flavors α))

instance : Nonempty (CloseableChannel.Sync α) :=
  inferInstanceAs (Nonempty (CloseableChannel α))
namespace CloseableChannel

/--
Create a new channel, if:
- `capacity` is `none` it will be unbounded (the default)
- `capacity` is `some 0` it will always force a rendezvous between sender and receiver
- `capacity` is `some n` with `n > 0` it will use a buffer of size `n` and begin blocking once it
  is filled
-/
def new (capacity : Option Nat := none) : BaseIO (CloseableChannel α) := do
  match capacity with
  | none => return .unbounded (← CloseableChannel.Unbounded.new)
  | some 0 => return .zero (← CloseableChannel.Zero.new)
  | some (n + 1) => return .bounded (← CloseableChannel.Bounded.new (n + 1) (by omega))

/--
Try to send a value to the channel, if this can be completed right away without blocking return
`true`, otherwise don't send the value and return `false`.
-/
def trySend (ch : CloseableChannel α) (v : α) : BaseIO Bool :=
  match ch with
  | .unbounded ch => CloseableChannel.Unbounded.trySend ch v
  | .zero ch => CloseableChannel.Zero.trySend ch v
  | .bounded ch => CloseableChannel.Bounded.trySend ch v

/--
Send a value through the channel, returning a task that will resolve once the transmission could be
completed. Note that the task may resolve to `Except.error` if the channel was closed before it
could be completed.
-/
def send (ch : CloseableChannel α) (v : α) : BaseIO (Task (Except Error Unit)) :=
  match ch with
  | .unbounded ch => CloseableChannel.Unbounded.send ch v
  | .zero ch => CloseableChannel.Zero.send ch v
  | .bounded ch => CloseableChannel.Bounded.send ch v

/--
Closes the channel, returns `Except.ok` when called the first time, otherwise `Except.error`.
When a channel is closed:
- no new values can be sent successfully anymore
- all blocked consumers are resolved to `none` (as no new messages can be sent they will never
  resolve)
- if there are already values waiting to be received they can still be received by subsequent `recv`
  calls
-/
def close (ch : CloseableChannel α) : EIO Error Unit :=
  match ch with
  | .unbounded ch => CloseableChannel.Unbounded.close ch
  | .zero ch => CloseableChannel.Zero.close ch
  | .bounded ch => CloseableChannel.Bounded.close ch

/--
Return `true` if the channel is closed.
-/
def isClosed (ch : CloseableChannel α) : BaseIO Bool :=
  match ch with
  | .unbounded ch => CloseableChannel.Unbounded.isClosed ch
  | .zero ch => CloseableChannel.Zero.isClosed ch
  | .bounded ch => CloseableChannel.Bounded.isClosed ch

/--
Try to receive a value from the channel, if this can be completed right away without blocking return
`some value`, otherwise return `none`.
-/
def tryRecv (ch : CloseableChannel α) : BaseIO (Option α) :=
  match ch with
  | .unbounded ch => CloseableChannel.Unbounded.tryRecv ch
  | .zero ch => CloseableChannel.Zero.tryRecv ch
  | .bounded ch => CloseableChannel.Bounded.tryRecv ch

/--
Receive a value from the channel, returning a task that will resolve once the transmission could be
completed. Note that the task may resolve to `none` if the channel was closed before it could be
completed.
-/
def recv (ch : CloseableChannel α) : BaseIO (Task (Option α)) :=
  match ch with
  | .unbounded ch => CloseableChannel.Unbounded.recv ch
  | .zero ch => CloseableChannel.Zero.recv ch
  | .bounded ch => CloseableChannel.Bounded.recv ch

/--
`ch.forAsync f` calls `f` for every message received on `ch`.

Note that if this function is called twice, each message will only arrive at exactly one invocation.
-/
partial def forAsync (f : α → BaseIO Unit) (ch : CloseableChannel α)
    (prio : Task.Priority := .default) : BaseIO (Task Unit) := do
  BaseIO.bindTask (prio := prio) (← ch.recv) fun
    | none => return .pure ()
    | some v => do f v; ch.forAsync f prio

/--
Accesses synchronous (blocking) version of channel operations.

For example, `ch.sync.recv` blocks until the next message,
and `for msg in ch.sync do ...` iterates synchronously over the channel.
These functions should only be used in dedicated threads.
-/
@[inline]
def sync (ch : CloseableChannel α) : CloseableChannel.Sync α := ch

namespace Sync

@[inherit_doc CloseableChannel.new, inline]
def new (capacity : Option Nat := none) : BaseIO (Sync α) := CloseableChannel.new capacity

@[inherit_doc CloseableChannel.trySend, inline]
def trySend (ch : Sync α) (v : α) : BaseIO Bool := CloseableChannel.trySend ch v

/--
Send a value through the channel, blocking until the transmission could be completed. Note that this
function may throw an error when trying to send to an already closed channel.
-/
def send (ch : Sync α) (v : α) : EIO Error Unit := do
  EIO.ofExcept (← IO.wait (← CloseableChannel.send ch v))

@[inherit_doc CloseableChannel.close, inline]
def close (ch : Sync α) : EIO Error Unit := CloseableChannel.close ch

@[inherit_doc CloseableChannel.isClosed, inline]
def isClosed (ch : Sync α) : BaseIO Bool := CloseableChannel.isClosed ch

@[inherit_doc CloseableChannel.tryRecv, inline]
def tryRecv (ch : Sync α) : BaseIO (Option α) := CloseableChannel.tryRecv ch

/--
Receive a value from the channel, blocking until the transmission could be completed. Note that the
return value may be `none` if the channel was closed before it could be completed. Every message is
only received once.
-/
def recv (ch : Sync α) : BaseIO (Option α) := do
  IO.wait (← CloseableChannel.recv ch)

/-- Blocking fold over all messages until the channel closes; drives the `ForIn` instance below. -/
private partial def forIn [Monad m] [MonadLiftT BaseIO m]
    (ch : Sync α) (f : α → β → m (ForInStep β)) : β → m β := fun b => do
  match ← ch.recv with
  | some a =>
    match ← f a b with
    | .done b => pure b
    | .yield b => ch.forIn f b
  | none => pure b

/-- `for msg in ch.sync do ...` receives all messages in the channel until it is closed. -/
instance [MonadLiftT BaseIO m] : ForIn m (Sync α) α where
  forIn ch b f := ch.forIn f b

end Sync

end CloseableChannel
/--
A multi-producer multi-consumer FIFO channel that offers both bounded and unbounded buffering
and an asynchronous API, to switch into synchronous mode use `Channel.sync`.

If a channel needs to be closed to indicate some sort of completion event use `Std.CloseableChannel`
instead. Note that `Std.CloseableChannel` introduces a need for error handling in some cases, thus
`Std.Channel` is usually easier to use if applicable.
-/
structure Channel (α : Type) where
  private mk ::
  private inner : CloseableChannel α
  deriving Nonempty

/--
A multi-producer multi-consumer FIFO channel that offers both bounded and unbounded buffering
and a synchronous API. This type acts as a convenient layer to use a channel in a blocking fashion
and is not actually different from the original channel.

If a channel needs to be closed to indicate some sort of completion event use
`Std.CloseableChannel.Sync` instead. Note that `Std.CloseableChannel.Sync` introduces a need for error
handling in some cases, thus `Std.Channel.Sync` is usually easier to use if applicable.
-/
def Channel.Sync (α : Type) : Type := Channel α

instance : Nonempty (Channel.Sync α) :=
  inferInstanceAs (Nonempty (Channel α))
namespace Channel

@[inherit_doc CloseableChannel.new, inline]
def new (capacity : Option Nat := none) : BaseIO (Channel α) := do
  return ⟨← CloseableChannel.new capacity⟩

@[inherit_doc CloseableChannel.trySend, inline]
def trySend (ch : Channel α) (v : α) : BaseIO Bool :=
  CloseableChannel.trySend ch.inner v

/--
Send a value through the channel, returning a task that will resolve once the transmission could be
completed.
-/
def send (ch : Channel α) (v : α) : BaseIO (Task Unit) := do
  -- A `Channel` can never be closed, so the inner send cannot fail with `.closed`.
  BaseIO.bindTask (sync := true) (← CloseableChannel.send ch.inner v)
    fun
      | .ok .. => return .pure ()
      | .error .. => unreachable!

@[inherit_doc CloseableChannel.tryRecv, inline]
def tryRecv (ch : Channel α) : BaseIO (Option α) :=
  CloseableChannel.tryRecv ch.inner

@[inherit_doc CloseableChannel.recv]
def recv [Inhabited α] (ch : Channel α) : BaseIO (Task α) := do
  -- A `Channel` can never be closed, so the inner receive always eventually yields a value.
  BaseIO.bindTask (sync := true) (← CloseableChannel.recv ch.inner)
    fun
      | some val => return .pure val
      | none => unreachable!

@[inherit_doc CloseableChannel.forAsync]
partial def forAsync [Inhabited α] (f : α → BaseIO Unit) (ch : Channel α)
    (prio : Task.Priority := .default) : BaseIO (Task Unit) := do
  BaseIO.bindTask (prio := prio) (← ch.recv) fun v => do f v; ch.forAsync f prio

@[inherit_doc CloseableChannel.sync, inline]
def sync (ch : Channel α) : Channel.Sync α := ch

namespace Sync

@[inherit_doc Channel.new, inline]
def new (capacity : Option Nat := none) : BaseIO (Sync α) := Channel.new capacity

@[inherit_doc Channel.trySend, inline]
def trySend (ch : Sync α) (v : α) : BaseIO Bool := Channel.trySend ch v

/--
Send a value through the channel, blocking until the transmission could be completed.
-/
def send (ch : Sync α) (v : α) : BaseIO Unit := do
  IO.wait (← Channel.send ch v)

@[inherit_doc Channel.tryRecv, inline]
def tryRecv (ch : Sync α) : BaseIO (Option α) := Channel.tryRecv ch

/--
Receive a value from the channel, blocking until the transmission could be completed.
-/
def recv [Inhabited α] (ch : Sync α) : BaseIO α := do
  IO.wait (← Channel.recv ch)

/-- Blocking fold over all received messages; drives the `ForIn` instance below. -/
private partial def forIn [Inhabited α] [Monad m] [MonadLiftT BaseIO m]
    (ch : Sync α) (f : α → β → m (ForInStep β)) : β → m β := fun b => do
  let a ← ch.recv
  match ← f a b with
  | .done b => pure b
  | .yield b => ch.forIn f b

/-- `for msg in ch.sync do ...` receives all messages in the channel until it is closed. -/
instance [Inhabited α] [MonadLiftT BaseIO m] : ForIn m (Sync α) α where
  forIn ch b f := ch.forIn f b

end Sync

end Channel

end Std
||||
@@ -246,113 +246,116 @@ instance : Hashable (BVExpr w) where
|
||||
|
||||
def decEq : DecidableEq (BVExpr w) := fun l r =>
|
||||
withPtrEqDecEq l r fun _ =>
|
||||
match l with
|
||||
| .var lidx =>
|
||||
match r with
|
||||
| .var ridx =>
|
||||
if h : lidx = ridx then .isTrue (by simp [h]) else .isFalse (by simp [h])
|
||||
| .const .. | .extract .. | .bin .. | .un .. | .append .. | .replicate .. | .shiftLeft ..
|
||||
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
|
||||
| .const lval =>
|
||||
match r with
|
||||
| .const rval =>
|
||||
if h : lval = rval then .isTrue (by simp [h]) else .isFalse (by simp [h])
|
||||
| .var .. | .extract .. | .bin .. | .un .. | .append .. | .replicate .. | .shiftLeft ..
|
||||
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
|
||||
| .extract (w := lw) lstart _ lexpr =>
|
||||
match r with
|
||||
| .extract (w := rw) rstart _ rexpr =>
|
||||
if h1 : lw = rw ∧ lstart = rstart then
|
||||
match decEq (h1.left ▸ lexpr) rexpr with
|
||||
| .isTrue h2 => .isTrue (by cases h1.left; simp_all)
|
||||
| .isFalse h2 => .isFalse (by cases h1.left; simp_all)
|
||||
else
|
||||
.isFalse (by simp_all)
|
||||
| .var .. | .const .. | .bin .. | .un .. | .append .. | .replicate .. | .shiftLeft ..
|
||||
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
|
||||
| .bin llhs lop lrhs =>
|
||||
match r with
|
||||
| .bin rlhs rop rrhs =>
|
||||
if h1 : lop = rop then
|
||||
match decEq llhs rlhs, decEq lrhs rrhs with
|
||||
| .isTrue h2, .isTrue h3 => .isTrue (by simp [h1, h2, h3])
|
||||
| .isFalse h2, _ => .isFalse (by simp [h2])
|
||||
| _, .isFalse h3 => .isFalse (by simp [h3])
|
||||
else
|
||||
.isFalse (by simp [h1])
|
||||
| .const .. | .var .. | .extract .. | .un .. | .append .. | .replicate .. | .shiftLeft ..
|
||||
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
|
||||
| .un lop lexpr =>
|
||||
match r with
|
||||
| .un rop rexpr =>
|
||||
if h1 : lop = rop then
|
||||
match decEq lexpr rexpr with
|
||||
| .isTrue h2 => .isTrue (by simp [h1, h2])
|
||||
| .isFalse h2 => .isFalse (by simp [h2])
|
||||
else
|
||||
.isFalse (by simp [h1])
|
||||
| .const .. | .var .. | .extract .. | .bin .. | .append .. | .replicate .. | .shiftLeft ..
|
||||
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
|
||||
| .append (l := ll) (r := lr) llhs lrhs lh =>
|
||||
match r with
|
||||
| .append (l := rl) (r := rr) rlhs rrhs rh =>
|
||||
if h1 : ll = rl ∧ lr = rr then
|
||||
match decEq (h1.left ▸ llhs) rlhs, decEq (h1.right ▸ lrhs) rrhs with
|
||||
| .isTrue h2, .isTrue h3 => .isTrue (by cases h1.left; cases h1.right; simp [h2, h3])
|
||||
| .isFalse h2, _ => .isFalse (by cases h1.left; cases h1.right; simp [h2])
|
||||
| _, .isFalse h3 => .isFalse (by cases h1.left; cases h1.right; simp [h3])
|
||||
else
|
||||
.isFalse (by simp; omega)
|
||||
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .replicate .. | .shiftLeft ..
|
||||
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
|
||||
| .replicate (w := lw) ln lexpr lh =>
|
||||
match r with
|
||||
| .replicate (w := rw) rn rexpr rh =>
|
||||
if h1 : ln = rn ∧ lw = rw then
|
||||
match decEq (h1.right ▸ lexpr) rexpr with
|
||||
| .isTrue h2 => .isTrue (by cases h1.left; cases h1.right; simp [h2])
|
||||
| .isFalse h2 => .isFalse (by cases h1.left; cases h1.right; simp [h2])
|
||||
else
|
||||
.isFalse (by simp; omega)
|
||||
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .append .. | .shiftLeft ..
|
||||
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
|
||||
if h : hash l ≠ hash r then
|
||||
.isFalse (ne_of_apply_ne hash h)
|
||||
else
|
||||
match l with
|
||||
| .var lidx =>
|
||||
match r with
|
||||
| .var ridx =>
|
||||
if h : lidx = ridx then .isTrue (by simp [h]) else .isFalse (by simp [h])
|
||||
| .const .. | .extract .. | .bin .. | .un .. | .append .. | .replicate .. | .shiftLeft ..
|
||||
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
|
||||
| .const lval =>
|
||||
match r with
|
||||
| .const rval =>
|
||||
if h : lval = rval then .isTrue (by simp [h]) else .isFalse (by simp [h])
|
||||
| .var .. | .extract .. | .bin .. | .un .. | .append .. | .replicate .. | .shiftLeft ..
|
||||
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
|
||||
| .extract (w := lw) lstart _ lexpr =>
|
||||
match r with
|
||||
| .extract (w := rw) rstart _ rexpr =>
|
||||
if h1 : lw = rw ∧ lstart = rstart then
|
||||
match decEq (h1.left ▸ lexpr) rexpr with
|
||||
| .isTrue h2 => .isTrue (by cases h1.left; simp_all)
|
||||
| .isFalse h2 => .isFalse (by cases h1.left; simp_all)
|
||||
else
|
||||
.isFalse (by simp_all)
|
||||
| .var .. | .const .. | .bin .. | .un .. | .append .. | .replicate .. | .shiftLeft ..
|
||||
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
|
||||
| .bin llhs lop lrhs =>
|
||||
match r with
|
||||
| .bin rlhs rop rrhs =>
|
||||
if h1 : lop = rop then
|
||||
match decEq llhs rlhs, decEq lrhs rrhs with
|
||||
| .isTrue h2, .isTrue h3 => .isTrue (by simp [h1, h2, h3])
|
||||
| .isFalse h2, _ => .isFalse (by simp [h2])
|
||||
| _, .isFalse h3 => .isFalse (by simp [h3])
|
||||
else
|
||||
.isFalse (by simp [h1])
|
||||
| .const .. | .var .. | .extract .. | .un .. | .append .. | .replicate .. | .shiftLeft ..
|
||||
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
|
||||
| .un lop lexpr =>
|
||||
match r with
|
||||
| .un rop rexpr =>
|
||||
if h1 : lop = rop then
|
||||
match decEq lexpr rexpr with
|
||||
| .isTrue h2 => .isTrue (by simp [h1, h2])
|
||||
| .isFalse h2 => .isFalse (by simp [h2])
|
||||
else
|
||||
.isFalse (by simp [h1])
|
||||
| .const .. | .var .. | .extract .. | .bin .. | .append .. | .replicate .. | .shiftLeft ..
|
||||
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
|
||||
| .append (l := ll) (r := lr) llhs lrhs lh =>
|
||||
match r with
|
||||
| .append (l := rl) (r := rr) rlhs rrhs rh =>
|
||||
if h1 : ll = rl ∧ lr = rr then
|
||||
match decEq (h1.left ▸ llhs) rlhs, decEq (h1.right ▸ lrhs) rrhs with
|
||||
| .isTrue h2, .isTrue h3 => .isTrue (by cases h1.left; cases h1.right; simp [h2, h3])
|
||||
| .isFalse h2, _ => .isFalse (by cases h1.left; cases h1.right; simp [h2])
|
||||
| _, .isFalse h3 => .isFalse (by cases h1.left; cases h1.right; simp [h3])
|
||||
else
|
||||
.isFalse (by simp; omega)
|
||||
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .replicate .. | .shiftLeft ..
|
||||
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
|
||||
| .replicate (w := lw) ln lexpr lh =>
|
||||
match r with
|
||||
| .replicate (w := rw) rn rexpr rh =>
|
||||
if h1 : ln = rn ∧ lw = rw then
|
||||
match decEq (h1.right ▸ lexpr) rexpr with
|
||||
| .isTrue h2 => .isTrue (by cases h1.left; cases h1.right; simp [h2])
|
||||
| .isFalse h2 => .isFalse (by cases h1.left; cases h1.right; simp [h2])
|
||||
else
|
||||
.isFalse (by simp; omega)
|
||||
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .append .. | .shiftLeft ..
|
||||
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
|
||||
|
||||
| .shiftLeft (n := lw) llhs lrhs =>
|
||||
match r with
|
||||
| .shiftLeft (n := rw) rlhs rrhs =>
|
||||
if h1 : lw = rw then
|
||||
match decEq llhs rlhs, decEq (h1 ▸ lrhs) rrhs with
|
||||
| .isTrue h2, .isTrue h3 => .isTrue (by cases h1; simp [h2, h3])
|
||||
| .isFalse h2, _ => .isFalse (by cases h1; simp [h2])
|
||||
| _, .isFalse h3 => .isFalse (by cases h1; simp [h3])
|
||||
else
|
||||
.isFalse (by simp [h1])
|
||||
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .append .. | .replicate ..
|
||||
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
|
||||
| .shiftRight (n := lw) llhs lrhs =>
|
||||
match r with
|
||||
| .shiftRight (n := rw) rlhs rrhs =>
|
||||
if h1 : lw = rw then
|
||||
match decEq llhs rlhs, decEq (h1 ▸ lrhs) rrhs with
|
||||
| .isTrue h2, .isTrue h3 => .isTrue (by cases h1; simp [h2, h3])
|
||||
| .isFalse h2, _ => .isFalse (by cases h1; simp [h2])
|
||||
| _, .isFalse h3 => .isFalse (by cases h1; simp [h3])
|
||||
else
|
||||
.isFalse (by simp [h1])
|
||||
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .append .. | .replicate ..
|
||||
|.shiftLeft .. | .arithShiftRight .. => .isFalse (by simp)
|
||||
| .arithShiftRight (n := lw) llhs lrhs =>
|
||||
match r with
|
||||
| .arithShiftRight (n := rw) rlhs rrhs =>
|
||||
if h1 : lw = rw then
|
||||
match decEq llhs rlhs, decEq (h1 ▸ lrhs) rrhs with
|
||||
| .isTrue h2, .isTrue h3 => .isTrue (by cases h1; simp [h2, h3])
|
||||
| .isFalse h2, _ => .isFalse (by cases h1; simp [h2])
|
||||
| _, .isFalse h3 => .isFalse (by cases h1; simp [h3])
|
||||
else
|
||||
.isFalse (by simp [h1])
|
||||
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .append .. | .replicate ..
|
||||
| .shiftRight .. | .shiftLeft .. => .isFalse (by simp)
|
||||
| .shiftLeft (n := lw) llhs lrhs =>
|
||||
match r with
|
||||
| .shiftLeft (n := rw) rlhs rrhs =>
|
||||
if h1 : lw = rw then
|
||||
match decEq llhs rlhs, decEq (h1 ▸ lrhs) rrhs with
|
||||
| .isTrue h2, .isTrue h3 => .isTrue (by cases h1; simp [h2, h3])
|
||||
| .isFalse h2, _ => .isFalse (by cases h1; simp [h2])
|
||||
| _, .isFalse h3 => .isFalse (by cases h1; simp [h3])
|
||||
else
|
||||
.isFalse (by simp [h1])
|
||||
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .append .. | .replicate ..
|
||||
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
|
||||
| .shiftRight (n := lw) llhs lrhs =>
|
||||
match r with
|
||||
| .shiftRight (n := rw) rlhs rrhs =>
|
||||
if h1 : lw = rw then
|
||||
match decEq llhs rlhs, decEq (h1 ▸ lrhs) rrhs with
|
||||
| .isTrue h2, .isTrue h3 => .isTrue (by cases h1; simp [h2, h3])
|
||||
| .isFalse h2, _ => .isFalse (by cases h1; simp [h2])
|
||||
| _, .isFalse h3 => .isFalse (by cases h1; simp [h3])
|
||||
else
|
||||
.isFalse (by simp [h1])
|
||||
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .append .. | .replicate ..
|
||||
|.shiftLeft .. | .arithShiftRight .. => .isFalse (by simp)
|
||||
| .arithShiftRight (n := lw) llhs lrhs =>
|
||||
match r with
|
||||
| .arithShiftRight (n := rw) rlhs rrhs =>
|
||||
if h1 : lw = rw then
|
||||
match decEq llhs rlhs, decEq (h1 ▸ lrhs) rrhs with
|
||||
| .isTrue h2, .isTrue h3 => .isTrue (by cases h1; simp [h2, h3])
|
||||
| .isFalse h2, _ => .isFalse (by cases h1; simp [h2])
|
||||
| _, .isFalse h3 => .isFalse (by cases h1; simp [h3])
|
||||
else
|
||||
.isFalse (by simp [h1])
|
||||
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .append .. | .replicate ..
|
||||
| .shiftRight .. | .shiftLeft .. => .isFalse (by simp)
|
||||
|
||||
|
||||
instance : DecidableEq (BVExpr w) := decEq
|
||||
|
||||
@@ -9,6 +9,7 @@ Author: Leonardo de Moura
|
||||
#include <utility>
|
||||
#include <unordered_map>
|
||||
#include "kernel/replace_fn.h"
|
||||
#include "util/alloc.h"
|
||||
|
||||
namespace lean {
|
||||
|
||||
@@ -18,7 +19,7 @@ class replace_rec_fn {
|
||||
return hash((size_t)p.first >> 3, p.second);
|
||||
}
|
||||
};
|
||||
std::unordered_map<std::pair<lean_object *, unsigned>, expr, key_hasher> m_cache;
|
||||
lean::unordered_map<std::pair<lean_object *, unsigned>, expr, key_hasher> m_cache;
|
||||
std::function<optional<expr>(expr const &, unsigned)> m_f;
|
||||
bool m_use_cache;
|
||||
|
||||
@@ -84,7 +85,7 @@ expr replace(expr const & e, std::function<optional<expr>(expr const &, unsigned
|
||||
}
|
||||
|
||||
class replace_fn {
|
||||
std::unordered_map<lean_object *, expr> m_cache;
|
||||
lean::unordered_map<lean_object *, expr> m_cache;
|
||||
lean_object * m_f;
|
||||
|
||||
expr save_result(expr const & e, expr const & r, bool shared) {
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euxo pipefail
|
||||
|
||||
exit 0 # TODO: flaky test disabled
|
||||
|
||||
# test disabled
|
||||
LAKE=${LAKE:-../../.lake/build/bin/lake}
|
||||
|
||||
./clean.sh
|
||||
|
||||
@@ -21,6 +21,7 @@ Authors: Leonardo de Moura, Gabriel Ebner, Sebastian Ullrich
|
||||
#include "runtime/io.h"
|
||||
#include "runtime/compact.h"
|
||||
#include "runtime/buffer.h"
|
||||
#include "runtime/array_ref.h"
|
||||
#include "util/io.h"
|
||||
#include "util/name_map.h"
|
||||
#include "library/module.h"
|
||||
@@ -76,47 +77,73 @@ struct olean_header {
|
||||
// make sure we don't have any padding bytes, which also ensures `data` is properly aligned
|
||||
static_assert(sizeof(olean_header) == 5 + 1 + 1 + 33 + 40 + sizeof(size_t), "olean_header must be packed");
|
||||
|
||||
extern "C" LEAN_EXPORT object * lean_save_module_data(b_obj_arg fname, b_obj_arg mod, b_obj_arg mdata, object *) {
|
||||
std::string olean_fn(string_cstr(fname));
|
||||
// we first write to a temp file and then move it to the correct path (possibly deleting an older file)
|
||||
// so that we neither expose partially-written files nor modify possibly memory-mapped files
|
||||
std::string olean_tmp_fn = olean_fn + ".tmp";
|
||||
try {
|
||||
std::ofstream out(olean_tmp_fn, std::ios_base::binary);
|
||||
if (out.fail()) {
|
||||
return io_result_mk_error((sstream() << "failed to create file '" << olean_fn << "'").str());
|
||||
extern "C" LEAN_EXPORT object * lean_save_module_data_parts(b_obj_arg mod, b_obj_arg oparts, object *) {
|
||||
#ifdef LEAN_WINDOWS
|
||||
uint32_t pid = GetCurrentProcessId();
|
||||
#else
|
||||
uint32_t pid = getpid();
|
||||
#endif
|
||||
// Derive a base address that is uniformly distributed by deterministic, and should most likely
|
||||
// work for `mmap` on all interesting platforms
|
||||
// NOTE: an overlapping/non-compatible base address does not prevent the module from being imported,
|
||||
// merely from using `mmap` for that
|
||||
|
||||
// Let's start with a hash of the module name. Note that while our string hash is a dubious 32-bit
|
||||
// algorithm, the mixing of multiple `Name` parts seems to result in a nicely distributed 64-bit
|
||||
// output
|
||||
size_t base_addr = name(mod, true).hash();
|
||||
// x86-64 user space is currently limited to the lower 47 bits
|
||||
// https://en.wikipedia.org/wiki/X86-64#Virtual_address_space_details
|
||||
// On Linux at least, the stack grows down from ~0x7fff... followed by shared libraries, so reserve
|
||||
// a bit of space for them (0x7fff...-0x7f00... = 1TB)
|
||||
base_addr = base_addr % 0x7f0000000000;
|
||||
// `mmap` addresses must be page-aligned. The default (non-huge) page size on x86-64 is 4KB.
|
||||
// `MapViewOfFileEx` addresses must be aligned to the "memory allocation granularity", which is 64KB.
|
||||
const size_t ALIGN = 1LL<<16;
|
||||
base_addr = base_addr & ~(ALIGN - 1);
|
||||
|
||||
object_compactor compactor(reinterpret_cast<void *>(base_addr));
|
||||
|
||||
array_ref<pair_ref<string_ref, object_ref>> parts(oparts, true);
|
||||
std::vector<std::string> tmp_fnames;
|
||||
for (auto const & part : parts) {
|
||||
std::string olean_fn = part.fst().to_std_string();
|
||||
try {
|
||||
// we first write to a temp file and then move it to the correct path (possibly deleting an older file)
|
||||
// so that we neither expose partially-written files nor modify possibly memory-mapped files
|
||||
std::string olean_tmp_fn = olean_fn + ".tmp." + std::to_string(pid);
|
||||
tmp_fnames.push_back(olean_tmp_fn);
|
||||
|
||||
std::ofstream out(olean_tmp_fn, std::ios_base::binary);
|
||||
|
||||
if (compactor.size() % ALIGN != 0) {
|
||||
compactor.alloc(ALIGN - (compactor.size() % ALIGN));
|
||||
}
|
||||
size_t file_offset = compactor.size();
|
||||
|
||||
compactor.alloc(sizeof(olean_header));
|
||||
olean_header header = {};
|
||||
// see/sync with file format description above
|
||||
header.base_addr = base_addr + file_offset;
|
||||
strncpy(header.lean_version, get_short_version_string().c_str(), sizeof(header.lean_version));
|
||||
strncpy(header.githash, LEAN_GITHASH, sizeof(header.githash));
|
||||
out.write(reinterpret_cast<char *>(&header), sizeof(header));
|
||||
|
||||
compactor(part.snd().raw());
|
||||
|
||||
if (out.fail()) {
|
||||
throw exception((sstream() << "failed to create file '" << olean_fn << "'").str());
|
||||
}
|
||||
out.write(static_cast<char const *>(compactor.data()) + file_offset + sizeof(olean_header), compactor.size() - file_offset - sizeof(olean_header));
|
||||
out.close();
|
||||
} catch (exception & ex) {
|
||||
return io_result_mk_error((sstream() << "failed to write '" << olean_fn << "': " << ex.what()).str());
|
||||
}
|
||||
}
|
||||
|
||||
// Derive a base address that is uniformly distributed by deterministic, and should most likely
|
||||
// work for `mmap` on all interesting platforms
|
||||
// NOTE: an overlapping/non-compatible base address does not prevent the module from being imported,
|
||||
// merely from using `mmap` for that
|
||||
|
||||
// Let's start with a hash of the module name. Note that while our string hash is a dubious 32-bit
|
||||
// algorithm, the mixing of multiple `Name` parts seems to result in a nicely distributed 64-bit
|
||||
// output
|
||||
size_t base_addr = name(mod, true).hash();
|
||||
// x86-64 user space is currently limited to the lower 47 bits
|
||||
// https://en.wikipedia.org/wiki/X86-64#Virtual_address_space_details
|
||||
// On Linux at least, the stack grows down from ~0x7fff... followed by shared libraries, so reserve
|
||||
// a bit of space for them (0x7fff...-0x7f00... = 1TB)
|
||||
base_addr = base_addr % 0x7f0000000000;
|
||||
// `mmap` addresses must be page-aligned. The default (non-huge) page size on x86-64 is 4KB.
|
||||
// `MapViewOfFileEx` addresses must be aligned to the "memory allocation granularity", which is 64KB.
|
||||
base_addr = base_addr & ~((1LL<<16) - 1);
|
||||
|
||||
object_compactor compactor(reinterpret_cast<void *>(base_addr + offsetof(olean_header, data)));
|
||||
compactor(mdata);
|
||||
|
||||
// see/sync with file format description above
|
||||
olean_header header = {};
|
||||
header.base_addr = base_addr;
|
||||
strncpy(header.lean_version, get_short_version_string().c_str(), sizeof(header.lean_version));
|
||||
strncpy(header.githash, LEAN_GITHASH, sizeof(header.githash));
|
||||
out.write(reinterpret_cast<char *>(&header), sizeof(header));
|
||||
out.write(static_cast<char const *>(compactor.data()), compactor.size());
|
||||
out.close();
|
||||
while (std::rename(olean_tmp_fn.c_str(), olean_fn.c_str()) != 0) {
|
||||
for (unsigned i = 0; i < parts.size(); i++) {
|
||||
std::string olean_fn = parts[i].fst().to_std_string();
|
||||
while (std::rename(tmp_fnames[i].c_str(), olean_fn.c_str()) != 0) {
|
||||
#ifdef LEAN_WINDOWS
|
||||
if (errno == EEXIST) {
|
||||
// Memory-mapped files can be deleted starting with Windows 10 using "POSIX semantics"
|
||||
@@ -136,94 +163,148 @@ extern "C" LEAN_EXPORT object * lean_save_module_data(b_obj_arg fname, b_obj_arg
|
||||
#endif
|
||||
return io_result_mk_error((sstream() << "failed to write '" << olean_fn << "': " << errno << " " << strerror(errno)).str());
|
||||
}
|
||||
return io_result_mk_ok(box(0));
|
||||
} catch (exception & ex) {
|
||||
return io_result_mk_error((sstream() << "failed to write '" << olean_fn << "': " << ex.what()).str());
|
||||
}
|
||||
return io_result_mk_ok(box(0));
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT object * lean_read_module_data(object * fname, object *) {
|
||||
std::string olean_fn(string_cstr(fname));
|
||||
try {
|
||||
std::ifstream in(olean_fn, std::ios_base::binary);
|
||||
if (in.fail()) {
|
||||
return io_result_mk_error((sstream() << "failed to open file '" << olean_fn << "'").str());
|
||||
}
|
||||
/* Get file size */
|
||||
in.seekg(0, in.end);
|
||||
size_t size = in.tellg();
|
||||
in.seekg(0);
|
||||
struct module_file {
|
||||
std::string m_fname;
|
||||
std::ifstream m_in;
|
||||
char * m_base_addr;
|
||||
size_t m_size;
|
||||
char * m_buffer;
|
||||
std::function<void()> m_free_data;
|
||||
};
|
||||
|
||||
olean_header default_header = {};
|
||||
olean_header header;
|
||||
if (!in.read(reinterpret_cast<char *>(&header), sizeof(header))
|
||||
|| memcmp(header.marker, default_header.marker, sizeof(header.marker)) != 0) {
|
||||
return io_result_mk_error((sstream() << "failed to read file '" << olean_fn << "', invalid header").str());
|
||||
}
|
||||
if (header.version != default_header.version || header.flags != default_header.flags
|
||||
#ifdef LEAN_CHECK_OLEAN_VERSION
|
||||
|| strncmp(header.githash, LEAN_GITHASH, sizeof(header.githash)) != 0
|
||||
#endif
|
||||
) {
|
||||
return io_result_mk_error((sstream() << "failed to read file '" << olean_fn << "', incompatible header").str());
|
||||
}
|
||||
char * base_addr = reinterpret_cast<char *>(header.base_addr);
|
||||
char * buffer = nullptr;
|
||||
bool is_mmap = false;
|
||||
std::function<void()> free_data;
|
||||
#ifdef LEAN_WINDOWS
|
||||
// `FILE_SHARE_DELETE` is necessary to allow the file to (be marked to) be deleted while in use
|
||||
HANDLE h_olean_fn = CreateFile(olean_fn.c_str(), GENERIC_READ, FILE_SHARE_READ | FILE_SHARE_DELETE, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
|
||||
if (h_olean_fn == INVALID_HANDLE_VALUE) {
|
||||
return io_result_mk_error((sstream() << "failed to open '" << olean_fn << "': " << GetLastError()).str());
|
||||
}
|
||||
HANDLE h_map = CreateFileMapping(h_olean_fn, NULL, PAGE_READONLY, 0, 0, NULL);
|
||||
if (h_olean_fn == NULL) {
|
||||
return io_result_mk_error((sstream() << "failed to map '" << olean_fn << "': " << GetLastError()).str());
|
||||
}
|
||||
buffer = static_cast<char *>(MapViewOfFileEx(h_map, FILE_MAP_READ, 0, 0, 0, base_addr));
|
||||
free_data = [=]() {
|
||||
if (buffer) {
|
||||
lean_always_assert(UnmapViewOfFile(base_addr));
|
||||
extern "C" LEAN_EXPORT object * lean_read_module_data_parts(b_obj_arg ofnames, object *) {
|
||||
array_ref<string_ref> fnames(ofnames, true);
|
||||
|
||||
// first read in all headers
|
||||
std::vector<module_file> files;
|
||||
for (auto const & fname : fnames) {
|
||||
std::string olean_fn = fname.to_std_string();
|
||||
try {
|
||||
std::ifstream in(olean_fn, std::ios_base::binary);
|
||||
if (in.fail()) {
|
||||
return io_result_mk_error((sstream() << "failed to open file '" << olean_fn << "'").str());
|
||||
}
|
||||
/* Get file size */
|
||||
in.seekg(0, in.end);
|
||||
size_t size = in.tellg();
|
||||
in.seekg(0);
|
||||
|
||||
olean_header default_header = {};
|
||||
olean_header header;
|
||||
if (!in.read(reinterpret_cast<char *>(&header), sizeof(header))
|
||||
|| memcmp(header.marker, default_header.marker, sizeof(header.marker)) != 0) {
|
||||
return io_result_mk_error((sstream() << "failed to read file '" << olean_fn << "', invalid header").str());
|
||||
}
|
||||
in.seekg(0);
|
||||
if (header.version != default_header.version || header.flags != default_header.flags
|
||||
#ifdef LEAN_CHECK_OLEAN_VERSION
|
||||
|| strncmp(header.githash, LEAN_GITHASH, sizeof(header.githash)) != 0
|
||||
#endif
|
||||
) {
|
||||
return io_result_mk_error((sstream() << "failed to read file '" << olean_fn << "', incompatible header").str());
|
||||
}
|
||||
char * base_addr = reinterpret_cast<char *>(header.base_addr);
|
||||
files.push_back({olean_fn, std::move(in), base_addr, size, nullptr, nullptr});
|
||||
} catch (exception & ex) {
|
||||
return io_result_mk_error((sstream() << "failed to read '" << olean_fn << "': " << ex.what()).str());
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef LEAN_MMAP
|
||||
bool is_mmap = false;
|
||||
#else
|
||||
// now try mmapping *all* files
|
||||
bool is_mmap = true;
|
||||
for (auto & file : files) {
|
||||
std::string const & olean_fn = file.m_fname;
|
||||
char * base_addr = file.m_base_addr;
|
||||
try {
|
||||
#ifdef LEAN_WINDOWS
|
||||
// `FILE_SHARE_DELETE` is necessary to allow the file to (be marked to) be deleted while in use
|
||||
HANDLE h_olean_fn = CreateFile(olean_fn.c_str(), GENERIC_READ, FILE_SHARE_READ | FILE_SHARE_DELETE, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
|
||||
if (h_olean_fn == INVALID_HANDLE_VALUE) {
|
||||
return io_result_mk_error((sstream() << "failed to open '" << olean_fn << "': " << GetLastError()).str());
|
||||
}
|
||||
HANDLE h_map = CreateFileMapping(h_olean_fn, NULL, PAGE_READONLY, 0, 0, NULL);
|
||||
if (h_olean_fn == NULL) {
|
||||
return io_result_mk_error((sstream() << "failed to map '" << olean_fn << "': " << GetLastError()).str());
|
||||
}
|
||||
char * buffer = static_cast<char *>(MapViewOfFileEx(h_map, FILE_MAP_READ, 0, 0, 0, base_addr));
|
||||
lean_always_assert(CloseHandle(h_map));
|
||||
lean_always_assert(CloseHandle(h_olean_fn));
|
||||
};
|
||||
#else
|
||||
int fd = open(olean_fn.c_str(), O_RDONLY);
|
||||
if (fd == -1) {
|
||||
return io_result_mk_error((sstream() << "failed to open '" << olean_fn << "': " << strerror(errno)).str());
|
||||
}
|
||||
#ifdef LEAN_MMAP
|
||||
buffer = static_cast<char *>(mmap(base_addr, size, PROT_READ, MAP_PRIVATE, fd, 0));
|
||||
#endif
|
||||
close(fd);
|
||||
free_data = [=]() {
|
||||
if (buffer != MAP_FAILED) {
|
||||
lean_always_assert(munmap(buffer, size) == 0);
|
||||
if (!buffer) {
|
||||
is_mmap = false;
|
||||
break;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
if (buffer && buffer == base_addr) {
|
||||
buffer += sizeof(olean_header);
|
||||
is_mmap = true;
|
||||
} else {
|
||||
#ifdef LEAN_MMAP
|
||||
free_data();
|
||||
#endif
|
||||
buffer = static_cast<char *>(malloc(size - sizeof(olean_header)));
|
||||
free_data = [=]() {
|
||||
free_sized(buffer, size - sizeof(olean_header));
|
||||
file.m_free_data = [=]() {
|
||||
lean_always_assert(UnmapViewOfFile(base_addr));
|
||||
};
|
||||
in.read(buffer, size - sizeof(olean_header));
|
||||
if (!in) {
|
||||
return io_result_mk_error((sstream() << "failed to read file '" << olean_fn << "'").str());
|
||||
#else
|
||||
int fd = open(olean_fn.c_str(), O_RDONLY);
|
||||
if (fd == -1) {
|
||||
return io_result_mk_error((sstream() << "failed to open '" << olean_fn << "': " << strerror(errno)).str());
|
||||
}
|
||||
char * buffer = static_cast<char *>(mmap(base_addr, file.m_size, PROT_READ, MAP_PRIVATE, fd, 0));
|
||||
if (buffer == MAP_FAILED) {
|
||||
is_mmap = false;
|
||||
break;
|
||||
}
|
||||
close(fd);
|
||||
size_t size = file.m_size;
|
||||
file.m_free_data = [=]() {
|
||||
lean_always_assert(munmap(buffer, size) == 0);
|
||||
};
|
||||
#endif
|
||||
if (buffer == base_addr) {
|
||||
file.m_buffer = buffer;
|
||||
} else {
|
||||
is_mmap = false;
|
||||
break;
|
||||
}
|
||||
} catch (exception & ex) {
|
||||
return io_result_mk_error((sstream() << "failed to read '" << olean_fn << "': " << ex.what()).str());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// if *any* file failed to mmap, read all of them into a single big allocation so that offsets
|
||||
// between them are unchanged
|
||||
if (!is_mmap) {
|
||||
for (auto & file : files) {
|
||||
if (file.m_free_data) {
|
||||
file.m_free_data();
|
||||
file.m_free_data = {};
|
||||
}
|
||||
}
|
||||
in.close();
|
||||
|
||||
size_t big_size = files[files.size()-1].m_base_addr + files[files.size()-1].m_size - files[0].m_base_addr;
|
||||
char * big_buffer = static_cast<char *>(malloc(big_size));
|
||||
for (auto & file : files) {
|
||||
std::string const & olean_fn = file.m_fname;
|
||||
try {
|
||||
file.m_buffer = big_buffer + (file.m_base_addr - files[0].m_base_addr);
|
||||
file.m_in.read(file.m_buffer, file.m_size);
|
||||
if (!file.m_in) {
|
||||
return io_result_mk_error((sstream() << "failed to read file '" << olean_fn << "'").str());
|
||||
}
|
||||
file.m_in.close();
|
||||
} catch (exception & ex) {
|
||||
return io_result_mk_error((sstream() << "failed to read '" << olean_fn << "': " << ex.what()).str());
|
||||
}
|
||||
}
|
||||
files[0].m_free_data = [=]() {
|
||||
free_sized(big_buffer, big_size);
|
||||
};
|
||||
}
|
||||
|
||||
std::vector<object_ref> res;
|
||||
for (auto & file : files) {
|
||||
compacted_region * region =
|
||||
new compacted_region(size - sizeof(olean_header), buffer, base_addr + sizeof(olean_header), is_mmap, free_data);
|
||||
new compacted_region(file.m_size - sizeof(olean_header), file.m_buffer + sizeof(olean_header), static_cast<char *>(file.m_base_addr) + sizeof(olean_header), is_mmap, file.m_free_data);
|
||||
#if defined(__has_feature)
|
||||
#if __has_feature(address_sanitizer)
|
||||
// do not report as leak
|
||||
@@ -234,18 +315,9 @@ extern "C" LEAN_EXPORT object * lean_read_module_data(object * fname, object *)
|
||||
object * mod_region = alloc_cnstr(0, 2, 0);
|
||||
cnstr_set(mod_region, 0, mod);
|
||||
cnstr_set(mod_region, 1, box_size_t(reinterpret_cast<size_t>(region)));
|
||||
return io_result_mk_ok(mod_region);
|
||||
} catch (exception & ex) {
|
||||
return io_result_mk_error((sstream() << "failed to read '" << olean_fn << "': " << ex.what()).str());
|
||||
|
||||
res.push_back(object_ref(mod_region));
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
@[export lean.write_module_core]
|
||||
def writeModule (env : Environment) (fname : String) : IO Unit := */
|
||||
extern "C" object * lean_write_module(object * env, object * fname, object *);
|
||||
|
||||
void write_module(elab_environment const & env, std::string const & olean_fn) {
|
||||
consume_io_result(lean_write_module(env.to_obj_arg(), mk_string(olean_fn), io_mk_world()));
|
||||
return io_result_mk_ok(to_array(res));
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user