Compare commits

..

52 Commits

Author SHA1 Message Date
Kim Morrison
d3c8b31d71 cleanup 2025-04-14 03:00:13 +02:00
Kim Morrison
14ed646c3e chore: updates to List/Array.Perm API 2025-04-14 02:59:06 +02:00
Lean stage0 autoupdater
c5e20c980c chore: update stage0 2025-04-13 23:32:03 +00:00
Leonardo de Moura
cd5b495573 feat: add [grind ext] attribute (#7949)
This PR adds the attribute `[grind ext]`. It is used to select which
`[ext]` theorems should be used by `grind`. The option `grind +extAll`
instructs `grind` to use all `[ext]` theorems available in the
environment.
After update stage0, we need to add the builtin `[grind ext]`
annotations to key theorems such as `funext`.
2025-04-13 22:08:36 +00:00
Leonardo de Moura
2337b95676 feat: improve case split heuristics in grind (#7946)
This PR improves the case split heuristics in `grind`.
2025-04-13 17:57:56 +00:00
Sebastian Ullrich
973f521c46 chore: fix cmake install exclude patterns (#7941) 2025-04-13 12:32:55 +00:00
Sebastian Ullrich
069456ea9c chore: disable flaky test 2025-04-13 13:18:05 +02:00
Kim Morrison
aa2cae8801 feat: List/Array/Vector.count_replace lemmas (#7938)
This PR adds lemmas about `List/Array/Vector.countP/count` interacting
with `replace`. (Specializing to `_self` and `_ne` lemmas doesn't seem
useful, as there will still be an `if` on the RHS.)
2025-04-13 03:10:19 +00:00
Leonardo de Moura
f513c35742 feat: lookahead in grind (#7937)
This PR implements a lookahead feature to reduce the size of the search
space in `grind`. It is currently effective only for arithmetic atoms.
2025-04-13 03:01:47 +00:00
Kim Morrison
d7cc0fd754 chore: add grind annotations for Nat/Int.min/max (#7934) 2025-04-13 01:48:17 +00:00
Kim Morrison
5f8847151d chore: generalize List.Perm.take (#7936)
Thanks @b-mehta for these generalizations.

---------

Co-authored-by: Bhavik Mehta <bhavikmehta8@gmail.com>
2025-04-13 01:45:48 +00:00
Kim Morrison
8bc9c4f154 chore: cleanup 'if normalization' grind example (#7935) 2025-04-13 01:09:38 +00:00
Henrik Böving
dd7ca772d8 refactor: more complete channel implementation for Std.Channel (#7819)
This PR extends `Std.Channel` to provide a full sync and async API, as
well as unbounded, zero sized and bounded channels.

A few notes on the implementation:
- the bounded channel is inspired by [Go channels on
steroids](https://docs.google.com/document/d/1yIAYmbvL3JxOKOjuCyon7JhW4cSv1wy5hC0ApeGMV9s/pub)
though currently doesn't do any of the lock-free optimizations
- @mhuisi convinced me that having a non-closable channel may be a good
idea as this alleviates the need for error handling which is very
annoying when working with `Task`. This does complicate the API a little
bit and I'm not quite sure whether this is a choice we want users to
give. An alternative to this would be to just write `send!` that panics
on sending to a closed channel (receiving from a closed channel is not
an error), this is for example the behavior that golang goes with.
2025-04-12 21:02:24 +00:00
Lean stage0 autoupdater
85a0232e87 chore: update stage0 2025-04-12 11:07:22 +00:00
Sebastian Ullrich
8ea6465e6d chore: CI: disable Linux 32bit (#7924)
A 2GB heap is just not that much even before fragmentation
2025-04-12 09:29:13 +00:00
Leonardo de Moura
38ed4346c2 chore: improve grind.clear_aux_decls error message (#7931)
cc @kim-em
2025-04-12 02:39:51 +00:00
Leonardo de Moura
2657f4e62c chore: move test to correct directory (#7932) 2025-04-11 19:46:47 -07:00
Leonardo de Moura
d4767a08b0 chore: another grind fixed test (#7930)
cc @kim-em
2025-04-11 19:43:35 -07:00
Leonardo de Moura
f562e72e59 chore: move test (#7921)
This test is easy for `grind`, we just need to annotate `Nat.min_def`.
2025-04-12 01:40:54 +00:00
Leonardo de Moura
5a6d45817d fix: nontermination in grind (#7928)
This PR fixes a nontermination issue in `grind`.
2025-04-11 21:06:07 +00:00
Leonardo de Moura
264095be7f fix: missing propagation and split filter in grind (#7926)
This PR fixes two issues that were preventing `grind` to solve
`getElem?_eq_some_iff`.
1. Missing propagation rule for `Exists p = False`
2. Missing conditions at `isCongrToPrevSplit` a filter for discarding
unnecessary case-splits.
2025-04-11 19:26:50 +00:00
Sebastian Ullrich
0669a04704 chore: CI: limit CCACHE_SIZE to 400MB (#7922) 2025-04-11 17:09:16 +00:00
Sebastian Ullrich
5cd352588c perf: use mimalloc with important C++ hash maps (#7868)
`unordered_map`/`unordered_set` does an allocation per insert, use
mimalloc for them for important hash maps
2025-04-11 16:23:33 +00:00
Henrik Böving
e9cc776f22 perf: bv_decide DecidableEq fast path using hash comparison (#7920)
This PR introduces a fast path based on comparing the (cached) hash
value to the `DecidableEq` instance of the core expression data type in
`bv_decide`'s bitblaster.

As we use a good hash function ™️ this should allow us to short
circuit to "not equal" quicker (if appropriate) than currently as we
will often not have to traverse all the way down to the actual conflict.
This in turn should speed up traversing of bucket chains during hash
collisions.
2025-04-11 15:00:41 +00:00
Lean stage0 autoupdater
e79fef15df chore: update stage0 2025-04-11 14:12:34 +00:00
Sebastian Ullrich
c672934f11 chore: add "Init size" benchmark (#7918) 2025-04-11 13:15:27 +00:00
Sebastian Ullrich
582877d2d3 feat: environment extension data can be split into .olean.server (#7914)
This PR adds a function hook `PersistentEnvExtension.saveEntriesFn` that
can be used to store server-only metadata such as position information
and docstrings that should not affect (re)builds.
2025-04-11 13:06:19 +00:00
Marc Huisinga
39ce3d14f4 test: make test deterministic (#7916) 2025-04-11 11:16:16 +00:00
Kim Morrison
32758aa712 feat: lemmas about permutations (#7912)
This PR adds `List.Perm.take/drop`, and `Array.Perm.extract`,
restricting permutations to sublist / subarrays when they are constant
elsewhere.
2025-04-11 08:13:58 +00:00
Kim Morrison
0f6e35dc63 feat: missing List/Array/Vector lemmas about isSome_idxOf? and relatives (#7913)
This PR adds some missing `List/Array/Vector lemmas` about
`isSome_idxOf?`, `isSome_finIdxOf?`, `isSome_findFinIdx?,
`isSome_findIdx?` and the corresponding `isNone` versions.
2025-04-11 07:45:46 +00:00
Kim Morrison
2528188dde chore: add failing grind test (#7910)
Adds a currently failing test, for a `grind` improvement.
2025-04-11 03:22:56 +00:00
Leonardo de Moura
1cdadfd47a chore: cleanup grind cutsat trace messages (#7908) 2025-04-11 00:52:18 +00:00
Kyle Miller
e07c59c831 fix: eliminate panic when inductive has autoparam parameter with underdetermined type (#7905)
This PR fixes an issue introduced bug #6125 where an `inductive` or
`structure` with an autoimplicit parameter with a type that has a
metavariable would lead to a panic. Closes #7788.

This was due to switching from `Term.addAutoBoundImplicits'` to
`Term.addAutoBoundImplicits` and not properly handling metavariables in
the parameters list. To fix this, now the inductive type headers record
the abstracted type and the number of parameters, rather than record the
parameters, the type, the local context, and the local instances. A
benefit to this over `Term.addAutoBoundImplicits'` is that the type's
parameters do not appear twice in the local context.
2025-04-11 00:19:53 +00:00
Leonardo de Moura
cbd38ceadd fix: mbtc and cast issue in grind (#7907)
This PR fixes two bugs in `grind`. 
1. Model-based theory combination was creating type incorrect terms.
2. `Nat.cast` vs `NatCast.natCast` issue during normalization.
2025-04-10 22:46:56 +00:00
Kyle Miller
c46f1e941c fix: sorry in Infoview shouldn't show module name (#7813)
This PR fixes an issue where `let n : Nat := sorry` in the Infoview
pretty prints as ``n : ℕ := sorry `«Foo:17:17»``. This was caused by
top-level expressions being pretty printed with the same rules as
Infoview hovers. Closes #6715. Refactors `Lean.Widget.ppExprTagged`; now
it takes a delaborator, and downstream users should configure their own
pretty printer option overrides if necessary if they used the `explicit`
argument (see `Lean.Widget.makePopup.ppExprForPopup` for an example).
Breaking change: `ppExprTagged` does not set `pp.proofs` on the root
expression.
2025-04-10 21:47:07 +00:00
Markus Himmel
cf3b257ccd chore: Option cleanup (#7897)
This PR cleans up the `Option` development, upstreaming some results
from mathlib in the process.

Notable changes:
- the name `<op>_eq_some_iff` is preferred over `<op>_eq_some`
- the `simp` normal form for `<$>` is `Option.map`, for `>>=` is
`Option.bind` and for `<|>` is `Option.orElse` (for the former two, this
was already true before this PR). All further lemmas about these
operations are now stated only in terms of
`Option.map`/`Option.bind`/`Option.orElse`. Previously, in some cases
both versions were available, with a prime used to disambiguate (the
primed version was usually the "non-ascii-art" version). Now, there are
no lemmas about the ascii-art versions besides the ones turning them
into the non-ascii-art operations, and there is only one version of
every lemma, about the non-ascii-art operation, and named without a
prime.
2025-04-10 18:53:30 +00:00
Kyle Miller
09ab15dc6d fix: remove infinite loop in withFnRefWhenTagAppFns (#7904)
This PR fixes an oversight in `withFnRefWhenTagAppFns` that causes an
infinite loop when the expression is a constant. This affected pretty
printing of zero-field structures when `pp.tagAppFns` was true (used by
docgen and verso). Closes #7898.
2025-04-10 17:16:29 +00:00
Sebastian Ullrich
e631efd817 feat: introduce Elab.inServer option (#7902)
This PR introduces a dedicated option for checking whether elaborators
are running in the language server.
2025-04-10 14:51:37 +00:00
Sebastian Graf
d2f4ce0158 fix: Add Inhabited instance for OptionT (#7901)
This PR adds `instance [Pure f] : Inhabited (OptionT f α)`, so that
`Inhabited (OptionT Id Empty)` synthesizes.

Co-authored-by: Sebastian Graf <sg@lean-fro.org>
2025-04-10 14:49:03 +00:00
Sebastian Ullrich
69536808ca feat: read/writeModuleDataParts API for serialization with cross-file sharing (#7854)
This PR introduces fundamental API to distribute module data across
multiple files in preparation for the module system.
2025-04-10 13:32:24 +00:00
Markus Himmel
3d5dd15de4 chore: move bmod results from LemmasAux.lean to DivMod/Lemmas.lean (#7899)
This PR shuffles some results about integers around to make sure that
all material that currently exists about `Int.bmod` is located in
`DivMod/Lemmas.lean` and not downstream of that.
2025-04-10 12:07:11 +00:00
Lean stage0 autoupdater
91c245663b chore: update stage0 2025-04-10 12:26:07 +00:00
Sebastian Ullrich
1421b6145e fix: cancellation of synchronous part of previous elaboration (#7882)
This PR fixes a regression where elaboration of a previous document
version is not cancelled on changes to the document.

Done by removing the default from `SnapshotTask.cancelTk?` and
consistently passing the current thread's token for synchronous
elaboration steps.
2025-04-10 11:43:41 +00:00
Kim Morrison
bffa642ad6 feat: Lean.Grind.IsCharP (#7870)
This PR adds a mixin typeclass for `Lean.Grind.CommRing` recording the
characteristic of the ring, and constructs instances for `Int`, `IntX`,
`UIntX`, and `BitVec`.
2025-04-10 08:36:42 +00:00
Kim Morrison
deef1c2739 feat: BitVec.pow and Pow (BitVec w) Nat (#7893)
This PR adds `BitVec.pow` and `Pow (BitVec w) Nat`. The implementation
is the naive one, and should later be replaced by an `@[extern]`. This
is tracked at https://github.com/leanprover/lean4/issues/7887.
2025-04-10 05:21:30 +00:00
Kim Morrison
acf42bd30b chore: add simp lemma Int.cast x = x for x : Int (#7891)
This PR adds the rfl simp lemma `Int.cast x = x` for `x : Int`.
2025-04-10 02:35:06 +00:00
Leonardo de Moura
4947215325 feat: improve funext support in grind (#7892)
This PR improves the support for `funext` in `grind`. We will push
another PR to minimize the number of case-splits later.
2025-04-10 01:57:27 +00:00
Kim Morrison
6e7209dfa3 chore: add Int.dvd_iff_bmod_eq_zero (#7890)
This PR adds missing lemmas about `Int.bmod`, parallel to lemmas about
the other `mod` variants.
2025-04-10 01:36:42 +00:00
Kim Morrison
97a00b3881 chore: variant of Int.toNat_sub (#7889)
This PR adds `Int.toNat_sub''` a variant of `Int.toNat_sub` taking
inequality hypotheses, rather than expecting the arguments to be casts
of natural numbers. This is parallel to the existing `toNat_add` and
`toNat_mul`.
2025-04-10 01:34:48 +00:00
Kim Morrison
d758b4c862 chore: Fin.ofNat'_mul, analogous to existing add lemmas (#7888)
This PR adds `Fin.ofNat'_mul` and `Fin.mul_ofNat'`, parallel to the
existing lemmas about `add`.
2025-04-10 01:32:47 +00:00
Kim Morrison
61d7716ad8 feat: UIntX.pow and Pow UIntX Nat instances (#7886)
This PR adds `UIntX.pow` and `Pow UIntX Nat` instances, and similarly
for signed fixed-width integers. These are currently only the naive
implementation, and will need to be subsequently replaced via
`@[extern]` with fast implementations (tracked at #7887).
2025-04-10 00:27:48 +00:00
Kim Morrison
05f16ed279 feat: UIntX.ofInt (#7880)
This PR adds the functions `UIntX.ofInt`, and basic lemmas.
2025-04-09 23:50:29 +00:00
602 changed files with 4158 additions and 1411 deletions

View File

@@ -46,7 +46,7 @@ jobs:
CCACHE_DIR: ${{ github.workspace }}/.ccache
CCACHE_COMPRESS: true
# current cache limit
CCACHE_MAXSIZE: 600M
CCACHE_MAXSIZE: 400M
# squelch error message about missing nixpkgs channel
NIX_BUILD_SHELL: bash
LSAN_OPTIONS: max_leaks=10

View File

@@ -256,17 +256,18 @@ jobs:
"llvm-url": "https://github.com/leanprover/lean-llvm/releases/download/15.0.1/lean-llvm-aarch64-linux-gnu.tar.zst",
"prepare-llvm": "../script/prepare-llvm-linux.sh lean-llvm*"
},
{
"name": "Linux 32bit",
"os": "ubuntu-latest",
// Use 32bit on stage0 and stage1 to keep oleans compatible
"CMAKE_OPTIONS": "-DSTAGE0_USE_GMP=OFF -DSTAGE0_LEAN_EXTRA_CXX_FLAGS='-m32' -DSTAGE0_LEANC_OPTS='-m32' -DSTAGE0_MMAP=OFF -DUSE_GMP=OFF -DLEAN_EXTRA_CXX_FLAGS='-m32' -DLEANC_OPTS='-m32' -DMMAP=OFF -DLEAN_INSTALL_SUFFIX=-linux_x86 -DCMAKE_LIBRARY_PATH=/usr/lib/i386-linux-gnu/ -DSTAGE0_CMAKE_LIBRARY_PATH=/usr/lib/i386-linux-gnu/ -DPKG_CONFIG_EXECUTABLE=/usr/bin/i386-linux-gnu-pkg-config",
"cmultilib": true,
"release": true,
"check-level": 2,
"cross": true,
"shell": "bash -euxo pipefail {0}"
}
// Started running out of memory building expensive modules, a 2GB heap is just not that much even before fragmentation
//{
// "name": "Linux 32bit",
// "os": "ubuntu-latest",
// // Use 32bit on stage0 and stage1 to keep oleans compatible
// "CMAKE_OPTIONS": "-DSTAGE0_USE_GMP=OFF -DSTAGE0_LEAN_EXTRA_CXX_FLAGS='-m32' -DSTAGE0_LEANC_OPTS='-m32' -DSTAGE0_MMAP=OFF -DUSE_GMP=OFF -DLEAN_EXTRA_CXX_FLAGS='-m32' -DLEANC_OPTS='-m32' -DMMAP=OFF -DLEAN_INSTALL_SUFFIX=-linux_x86 -DCMAKE_LIBRARY_PATH=/usr/lib/i386-linux-gnu/ -DSTAGE0_CMAKE_LIBRARY_PATH=/usr/lib/i386-linux-gnu/ -DPKG_CONFIG_EXECUTABLE=/usr/bin/i386-linux-gnu-pkg-config",
// "cmultilib": true,
// "release": true,
// "check-level": 2,
// "cross": true,
// "shell": "bash -euxo pipefail {0}"
//}
// {
// "name": "Web Assembly",
// "os": "ubuntu-latest",

View File

@@ -780,12 +780,11 @@ add_custom_target(clean-olean
DEPENDS clean-stdlib)
install(DIRECTORY "${CMAKE_BINARY_DIR}/lib/" DESTINATION lib
PATTERN temp
PATTERN "*.export"
PATTERN "*.hash"
PATTERN "*.trace"
PATTERN "*.rsp"
EXCLUDE)
PATTERN temp EXCLUDE
PATTERN "*.export" EXCLUDE
PATTERN "*.hash" EXCLUDE
PATTERN "*.trace" EXCLUDE
PATTERN "*.rsp" EXCLUDE)
# symlink source into expected installation location for go-to-definition, if file system allows it
file(MAKE_DIRECTORY ${CMAKE_BINARY_DIR}/src)

View File

@@ -59,6 +59,9 @@ instance : Monad (OptionT m) where
pure := OptionT.pure
bind := OptionT.bind
instance {m : Type u Type v} [Pure m] : Inhabited (OptionT m α) where
default := pure (f:=m) default
/--
Recovers from failures. Typically used via the `<|>` operator.
-/

View File

@@ -34,7 +34,6 @@ import Init.Data.Stream
import Init.Data.Prod
import Init.Data.AC
import Init.Data.Queue
import Init.Data.Channel
import Init.Data.Sum
import Init.Data.BEq
import Init.Data.Subtype

View File

@@ -288,6 +288,17 @@ theorem count_flatMap {α} [BEq β] {xs : Array α} {f : α → Array β} {x :
rcases xs with xs
simp [List.count_flatMap, countP_flatMap, Function.comp_def]
theorem countP_replace {a b : α} {xs : Array α} {p : α Bool} :
(xs.replace a b).countP p =
if xs.contains a then xs.countP p + (if p b then 1 else 0) - (if p a then 1 else 0) else xs.countP p := by
rcases xs with xs
simp [List.countP_replace]
theorem count_replace {a b c : α} {xs : Array α} :
(xs.replace a b).count c =
if xs.contains a then xs.count c + (if b == c then 1 else 0) - (if a == c then 1 else 0) else xs.count c := by
simp [count_eq_countP, countP_replace]
-- FIXME these theorems can be restored once `List.erase` and `Array.erase` have been related.
-- theorem count_erase (a b : α) (l : Array α) : count a (l.erase b) = count a l - if b == a then 1 else 0 := by

View File

@@ -446,11 +446,13 @@ theorem findIdx?_eq_none_iff {xs : Array α} {p : α → Bool} :
rcases xs with xs
simp
@[simp]
theorem findIdx?_isSome {xs : Array α} {p : α Bool} :
(xs.findIdx? p).isSome = xs.any p := by
rcases xs with xs
simp [List.findIdx?_isSome]
@[simp]
theorem findIdx?_isNone {xs : Array α} {p : α Bool} :
(xs.findIdx? p).isNone = xs.all (¬p ·) := by
rcases xs with xs
@@ -591,6 +593,18 @@ theorem findFinIdx?_eq_some_iff {xs : Array α} {p : α → Bool} {i : Fin xs.si
· rintro h, w
exact i, i.2, h, fun j hji => w j, by omega hji, rfl
@[simp]
theorem isSome_findFinIdx? {xs : Array α} {p : α Bool} :
(xs.findFinIdx? p).isSome = xs.any p := by
rcases xs with xs
simp
@[simp]
theorem isNone_findFinIdx? {xs : Array α} {p : α Bool} :
(xs.findFinIdx? p).isNone = xs.all (fun x => ¬ p x) := by
rcases xs with xs
simp
@[simp] theorem findFinIdx?_subtype {p : α Prop} {xs : Array { x // p x }}
{f : { x // p x } Bool} {g : α Bool} (hf : x h, f x, h = g x) :
xs.findFinIdx? f = (xs.unattach.findFinIdx? g).map (fun i => i.cast (by simp)) := by
@@ -636,6 +650,20 @@ The lemmas below should be made consistent with those for `findIdx?` (and proved
rcases xs with xs
simp [List.idxOf?_eq_none_iff]
@[simp]
theorem isSome_idxOf? [BEq α] [LawfulBEq α] {xs : Array α} {a : α} :
(xs.idxOf? a).isSome a xs := by
rcases xs with xs
simp
@[simp]
theorem isNone_idxOf? [BEq α] [LawfulBEq α] {xs : Array α} {a : α} :
(xs.idxOf? a).isNone = ¬ a xs := by
rcases xs with xs
simp
/-! ### finIdxOf?
The verification API for `finIdxOf?` is still incomplete.
@@ -658,4 +686,16 @@ theorem idxOf?_eq_map_finIdxOf?_val [BEq α] {xs : Array α} {a : α} :
rcases xs with xs
simp [List.finIdxOf?_eq_some_iff]
@[simp]
theorem isSome_finIdxOf? [BEq α] [LawfulBEq α] {xs : Array α} {a : α} :
(xs.finIdxOf? a).isSome a xs := by
rcases xs with xs
simp
@[simp]
theorem isNone_finIdxOf? [BEq α] [LawfulBEq α] {xs : Array α} {a : α} :
(xs.finIdxOf? a).isNone = ¬ a xs := by
rcases xs with xs
simp
end Array

View File

@@ -53,6 +53,12 @@ instance : Trans (Perm (α := α)) (Perm (α := α)) (Perm (α := α)) where
theorem perm_comm {xs ys : Array α} : xs ~ ys ys ~ xs := Perm.symm, Perm.symm
theorem Perm.mem_iff {a : α} {xs ys : Array α} (p : xs ~ ys) : a xs a ys := by
rcases xs with xs
rcases ys with ys
simp at p
simpa using p.mem_iff
theorem Perm.push (x y : α) {xs ys : Array α} (p : xs ~ ys) :
(xs.push x).push y ~ (ys.push y).push x := by
cases xs; cases ys
@@ -65,4 +71,20 @@ theorem swap_perm {xs : Array α} {i j : Nat} (h₁ : i < xs.size) (h₂ : j < x
simp only [swap, perm_iff_toList_perm, toList_set]
apply set_set_perm
namespace Perm
set_option linter.indexVariables false in
theorem extract {xs ys : Array α} (h : xs ~ ys) {lo hi : Nat}
(wlo : i, i < lo xs[i]? = ys[i]?) (whi : i, hi i xs[i]? = ys[i]?) :
(xs.extract lo hi) ~ (ys.extract lo hi) := by
rcases xs with xs
rcases ys with ys
simp_all only [perm_toArray, List.getElem?_toArray, List.extract_toArray,
List.extract_eq_drop_take]
apply List.Perm.take_of_getElem? (w := fun i h => by simpa using whi (lo + i) (by omega))
apply List.Perm.drop_of_getElem? (w := wlo)
exact h
end Perm
end Array

View File

@@ -227,6 +227,20 @@ SMT-LIB name: `bvmul`.
protected def mul (x y : BitVec n) : BitVec n := BitVec.ofNat n (x.toNat * y.toNat)
instance : Mul (BitVec n) := .mul
/--
Raises a bitvector to a natural number power. Usually accessed via the `^` operator.
Note that this is currently an inefficient implementation,
and should be replaced via an `@[extern]` with a native implementation.
See https://github.com/leanprover/lean4/issues/7887.
-/
protected def pow (x : BitVec n) (y : Nat) : BitVec n :=
match y with
| 0 => 1
| y + 1 => x.pow y * x
instance : Pow (BitVec n) Nat where
pow x y := x.pow y
/--
Unsigned division of bitvectors using the Lean convention where division by zero returns zero.
Usually accessed via the `/` operator.

View File

@@ -3653,6 +3653,13 @@ theorem mul_def {n} {x y : BitVec n} : x * y = (ofFin <| x.toFin * y.toFin) := b
@[simp, bitvec_to_nat] theorem toNat_mul (x y : BitVec n) : (x * y).toNat = (x.toNat * y.toNat) % 2 ^ n := rfl
@[simp] theorem toFin_mul (x y : BitVec n) : (x * y).toFin = (x.toFin * y.toFin) := rfl
theorem ofNat_mul {n} (x y : Nat) : BitVec.ofNat n (x * y) = BitVec.ofNat n x * BitVec.ofNat n y := by
apply eq_of_toNat_eq
simp [BitVec.ofNat, Fin.ofNat'_mul]
theorem ofNat_mul_ofNat {n} (x y : Nat) : BitVec.ofNat n x * BitVec.ofNat n y = BitVec.ofNat n (x * y) :=
(ofNat_mul x y).symm
protected theorem mul_comm (x y : BitVec w) : x * y = y * x := by
apply eq_of_toFin_eq; simpa using Fin.mul_comm ..
instance : Std.Commutative (fun (x y : BitVec w) => x * y) := BitVec.mul_comm
@@ -3746,6 +3753,22 @@ theorem setWidth_mul (x y : BitVec w) (h : i ≤ w) :
have dvd : 2^i 2^w := Nat.pow_dvd_pow _ h
simp [bitvec_to_nat, h, Nat.mod_mod_of_dvd _ dvd]
/-! ### pow -/
@[simp]
protected theorem pow_zero {x : BitVec w} : x ^ 0 = 1#w := rfl
protected theorem pow_succ {x : BitVec w} : x ^ (n + 1) = x ^ n * x := rfl
@[simp]
protected theorem pow_one {x : BitVec w} : x ^ 1 = x := by simp [BitVec.pow_succ]
protected theorem pow_add {x : BitVec w} {n m : Nat}: x ^ (n + m) = (x ^ n) * (x ^ m):= by
induction m with
| zero => simp
| succ m ih =>
rw [ Nat.add_assoc, BitVec.pow_succ, ih, BitVec.mul_assoc, BitVec.pow_succ]
/-! ### le and lt -/
@[bitvec_to_nat] theorem le_def {x y : BitVec n} :

View File

@@ -1,149 +0,0 @@
/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Gabriel Ebner
-/
prelude
import Init.Data.Queue
import Init.System.Promise
import Init.System.Mutex
set_option linter.deprecated false
namespace IO
/--
Internal state of an `Channel`.
We maintain the invariant that at all times either `consumers` or `values` is empty.
-/
@[deprecated "Use Std.Channel.State from Std.Sync.Channel instead" (since := "2024-12-02")]
structure Channel.State (α : Type) where
values : Std.Queue α :=
consumers : Std.Queue (Promise (Option α)) :=
closed := false
deriving Inhabited
/--
FIFO channel with unbounded buffer, where `recv?` returns a `Task`.
A channel can be closed. Once it is closed, all `send`s are ignored, and
`recv?` returns `none` once the queue is empty.
-/
@[deprecated "Use Std.Channel from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel (α : Type) : Type := Mutex (Channel.State α)
instance : Nonempty (Channel α) :=
inferInstanceAs (Nonempty (Mutex _))
/-- Creates a new `Channel`. -/
@[deprecated "Use Std.Channel.new from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel.new : BaseIO (Channel α) :=
Mutex.new {}
/--
Sends a message on an `Channel`.
This function does not block.
-/
@[deprecated "Use Std.Channel.send from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel.send (ch : Channel α) (v : α) : BaseIO Unit :=
ch.atomically do
let st get
if st.closed then return
if let some (consumer, consumers) := st.consumers.dequeue? then
consumer.resolve (some v)
set { st with consumers }
else
set { st with values := st.values.enqueue v }
/--
Closes an `Channel`.
-/
@[deprecated "Use Std.Channel.close from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel.close (ch : Channel α) : BaseIO Unit :=
ch.atomically do
let st get
for consumer in st.consumers.toArray do consumer.resolve none
set { st with closed := true, consumers := }
/--
Receives a message, without blocking.
The returned task waits for the message.
Every message is only received once.
Returns `none` if the channel is closed and the queue is empty.
-/
@[deprecated "Use Std.Channel.recv? from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel.recv? (ch : Channel α) : BaseIO (Task (Option α)) :=
ch.atomically do
let st get
if let some (a, values) := st.values.dequeue? then
set { st with values }
return .pure a
else if !st.closed then
let promise Promise.new
set { st with consumers := st.consumers.enqueue promise }
return promise.result
else
return .pure none
/--
`ch.forAsync f` calls `f` for every messages received on `ch`.
Note that if this function is called twice, each `forAsync` only gets half the messages.
-/
@[deprecated "Use Std.Channel.forAsync from Std.Sync.Channel instead" (since := "2024-12-02")]
partial def Channel.forAsync (f : α BaseIO Unit) (ch : Channel α)
(prio : Task.Priority := .default) : BaseIO (Task Unit) := do
BaseIO.bindTask (prio := prio) ( ch.recv?) fun
| none => return .pure ()
| some v => do f v; ch.forAsync f prio
/--
Receives all currently queued messages from the channel.
Those messages are dequeued and will not be returned by `recv?`.
-/
@[deprecated "Use Std.Channel.recvAllCurrent from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel.recvAllCurrent (ch : Channel α) : BaseIO (Array α) :=
ch.atomically do
modifyGet fun st => (st.values.toArray, { st with values := })
/-- Type tag for synchronous (blocking) operations on a `Channel`. -/
@[deprecated "Use Std.Channel.Sync from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel.Sync := Channel
/--
Accesses synchronous (blocking) version of channel operations.
For example, `ch.sync.recv?` blocks until the next message,
and `for msg in ch.sync do ...` iterates synchronously over the channel.
These functions should only be used in dedicated threads.
-/
@[deprecated "Use Std.Channel.sync from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel.sync (ch : Channel α) : Channel.Sync α := ch
/--
Synchronously receives a message from the channel.
Every message is only received once.
Returns `none` if the channel is closed and the queue is empty.
-/
@[deprecated "Use Std.Channel.Sync.recv? from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel.Sync.recv? (ch : Channel.Sync α) : BaseIO (Option α) := do
IO.wait ( Channel.recv? ch)
@[deprecated "Use Std.Channel.Sync.forIn from Std.Sync.Channel instead" (since := "2024-12-02")]
private partial def Channel.Sync.forIn [Monad m] [MonadLiftT BaseIO m]
(ch : Channel.Sync α) (f : α β m (ForInStep β)) : β m β := fun b => do
match ch.recv? with
| some a =>
match f a b with
| .done b => pure b
| .yield b => ch.forIn f b
| none => pure b
/-- `for msg in ch.sync do ...` receives all messages in the channel until it is closed. -/
instance [MonadLiftT BaseIO m] : ForIn m (Channel.Sync α) α where
forIn ch b f := ch.forIn f b

View File

@@ -976,6 +976,16 @@ theorem coe_sub_iff_lt {a b : Fin n} : (↑(a - b) : Nat) = n + a - b ↔ a < b
/-! ### mul -/
theorem ofNat'_mul [NeZero n] (x : Nat) (y : Fin n) :
Fin.ofNat' n x * y = Fin.ofNat' n (x * y.val) := by
apply Fin.eq_of_val_eq
simp [Fin.ofNat', Fin.mul_def]
theorem mul_ofNat' [NeZero n] (x : Fin n) (y : Nat) :
x * Fin.ofNat' n y = Fin.ofNat' n (x.val * y) := by
apply Fin.eq_of_val_eq
simp [Fin.ofNat', Fin.mul_def]
theorem val_mul {n : Nat} : a b : Fin n, (a * b).val = a.val * b.val % n
| _, _, _, _ => rfl

View File

@@ -420,6 +420,8 @@ instance : IntCast Int where intCast n := n
protected def Int.cast {R : Type u} [IntCast R] : Int R :=
IntCast.intCast
@[simp] theorem Int.cast_eq (x : Int) : Int.cast x = x := rfl
-- see the notes about coercions into arbitrary types in the module doc-string
instance [IntCast R] : CoeTail Int R where coe := Int.cast

View File

@@ -2145,6 +2145,11 @@ theorem bmod_pos (x : Int) (m : Nat) (p : x % m < (m + 1) / 2) : bmod x m = x %
theorem bmod_neg (x : Int) (m : Nat) (p : x % m (m + 1) / 2) : bmod x m = (x % m) - m := by
simp [bmod_def, Int.not_lt.mpr p]
theorem bmod_eq_emod (x : Int) (m : Nat) : bmod x m = x % m - if x % m (m + 1) / 2 then m else 0 := by
split
· rwa [bmod_neg]
· rw [bmod_pos] <;> simp_all
@[simp]
theorem bmod_one_is_zero (x : Int) : Int.bmod x 1 = 0 := by
simp [Int.bmod]
@@ -2373,6 +2378,43 @@ theorem bmod_neg_bmod : bmod (-(bmod x n)) n = bmod (-x) n := by
apply (bmod_add_cancel_right x).mp
rw [Int.add_left_neg, add_bmod_bmod, Int.add_left_neg]
theorem bmod_neg_iff {m : Nat} {x : Int} (h2 : -m x) (h1 : x < m) :
(x.bmod m) < 0 (-(m / 2) x x < 0) ((m + 1) / 2 x) := by
simp only [Int.bmod_def]
by_cases xpos : 0 x
· rw [Int.emod_eq_of_lt xpos (by omega)]; omega
· rw [Int.add_emod_self.symm, Int.emod_eq_of_lt (by omega) (by omega)]; omega
theorem bmod_eq_self_of_le {n : Int} {m : Nat} (hn' : -(m / 2) n) (hn : n < (m + 1) / 2) :
n.bmod m = n := by
rw [ Int.sub_eq_zero]
have := le_bmod (x := n) (m := m) (by omega)
have := bmod_lt (x := n) (m := m) (by omega)
apply eq_zero_of_dvd_of_natAbs_lt_natAbs Int.dvd_bmod_sub_self
omega
theorem bmod_bmod_of_dvd {a : Int} {n m : Nat} (hnm : n m) :
(a.bmod m).bmod n = a.bmod n := by
rw [ Int.sub_eq_iff_eq_add.2 (bmod_add_bdiv a m).symm]
obtain k, rfl := hnm
simp [Int.mul_assoc]
theorem bmod_eq_self_of_le_mul_two {x : Int} {y : Nat} (hle : -y x * 2) (hlt : x * 2 < y) :
x.bmod y = x := by
apply bmod_eq_self_of_le (by omega) (by omega)
theorem dvd_iff_bmod_eq_zero {a : Nat} {b : Int} : (a : Int) b b.bmod a = 0 := by
rw [dvd_iff_emod_eq_zero, bmod]
split <;> rename_i h
· rfl
· simp only [Int.not_lt] at h
match a with
| 0 => omega
| a + 1 =>
have : b % (a+1) < a + 1 := emod_lt b (by omega)
simp_all
omega
/-! Helper theorems for `dvd` simproc -/
protected theorem dvd_eq_true_of_mod_eq_zero {a b : Int} (h : b % a == 0) : (a b) = True := by

View File

@@ -377,6 +377,11 @@ theorem toNat_of_nonpos : ∀ {z : Int}, z ≤ 0 → z.toNat = 0
@[simp] theorem negSucc_add_one_eq_neg_ofNat_iff {a b : Nat} : -[a+1] + 1 = - (b : Int) a = b := by
rw [eq_comm, neg_ofNat_eq_negSucc_add_one_iff, eq_comm]
protected theorem sub_eq_iff_eq_add {b a c : Int} : a - b = c a = c + b := by
refine fun h => ?_, fun h => ?_ <;> subst h <;> simp
protected theorem sub_eq_iff_eq_add' {b a c : Int} : a - b = c a = b + c := by
rw [Int.sub_eq_iff_eq_add, Int.add_comm]
/- ## add/sub injectivity -/
@[simp] protected theorem add_left_inj {i j : Int} (k : Int) : (i + k = j + k) i = j := by

View File

@@ -19,9 +19,6 @@ namespace Int
@[simp] theorem natCast_le_zero : {n : Nat} (n : Int) 0 n = 0 := by omega
protected theorem sub_eq_iff_eq_add {b a c : Int} : a - b = c a = c + b := by omega
protected theorem sub_eq_iff_eq_add' {b a c : Int} : a - b = c a = b + c := by omega
@[simp] protected theorem neg_nonpos_iff (i : Int) : -i 0 0 i := by omega
@[simp] theorem zero_le_ofNat (n : Nat) : 0 ((no_index (OfNat.ofNat n)) : Int) :=
@@ -81,15 +78,6 @@ theorem eq_ofNat_toNat {a : Int} : a = a.toNat ↔ 0 ≤ a := by omega
theorem toNat_le_toNat {n m : Int} (h : n m) : n.toNat m.toNat := by omega
theorem toNat_lt_toNat {n m : Int} (hn : 0 < m) : n.toNat < m.toNat n < m := by omega
/-! ### natAbs -/
theorem eq_zero_of_dvd_of_natAbs_lt_natAbs {d n : Int} (h : d n) (h₁ : n.natAbs < d.natAbs) :
n = 0 := by
obtain a, rfl := h
rw [natAbs_mul] at h₁
suffices ¬ 0 < a.natAbs by simp [Int.natAbs_eq_zero.1 (Nat.eq_zero_of_not_pos this)]
exact fun h => Nat.lt_irrefl _ (Nat.lt_of_le_of_lt (Nat.le_mul_of_pos_right d.natAbs h) h₁)
/-! ### min and max -/
@[simp] protected theorem min_assoc : (a b c : Int), min (min a b) c = min a (min b c) := by omega
@@ -128,33 +116,6 @@ protected theorem sub_min_sub_left (a b c : Int) : min (a - b) (a - c) = a - max
protected theorem sub_max_sub_left (a b c : Int) : max (a - b) (a - c) = a - min b c := by omega
/-! ### bmod -/
theorem bmod_neg_iff {m : Nat} {x : Int} (h2 : -m x) (h1 : x < m) :
(x.bmod m) < 0 (-(m / 2) x x < 0) ((m + 1) / 2 x) := by
simp only [Int.bmod_def]
by_cases xpos : 0 x
· rw [Int.emod_eq_of_lt xpos (by omega)]; omega
· rw [Int.add_emod_self.symm, Int.emod_eq_of_lt (by omega) (by omega)]; omega
theorem bmod_eq_self_of_le {n : Int} {m : Nat} (hn' : -(m / 2) n) (hn : n < (m + 1) / 2) :
n.bmod m = n := by
rw [ Int.sub_eq_zero]
have := le_bmod (x := n) (m := m) (by omega)
have := bmod_lt (x := n) (m := m) (by omega)
apply eq_zero_of_dvd_of_natAbs_lt_natAbs Int.dvd_bmod_sub_self
omega
theorem bmod_bmod_of_dvd {a : Int} {n m : Nat} (hnm : n m) :
(a.bmod m).bmod n = a.bmod n := by
rw [ Int.sub_eq_iff_eq_add.2 (bmod_add_bdiv a m).symm]
obtain k, rfl := hnm
simp [Int.mul_assoc]
theorem bmod_eq_self_of_le_mul_two {x : Int} {y : Nat} (hle : -y x * 2) (hlt : x * 2 < y) :
x.bmod y = x := by
apply bmod_eq_self_of_le (by omega) (by omega)
theorem mul_le_mul_of_natAbs_le {x y : Int} {s t : Nat} (hx : x.natAbs s) (hy : y.natAbs t) :
x * y s * t := by
by_cases 0 < s 0 < t

View File

@@ -329,9 +329,9 @@ protected theorem le_iff_lt_add_one {a b : Int} : a ≤ b ↔ a < b + 1 := by
/- ### min and max -/
protected theorem min_def (n m : Int) : min n m = if n m then n else m := rfl
@[grind =] protected theorem min_def (n m : Int) : min n m = if n m then n else m := rfl
protected theorem max_def (n m : Int) : max n m = if n m then m else n := rfl
@[grind =] protected theorem max_def (n m : Int) : max n m = if n m then m else n := rfl
@[simp] protected theorem neg_min_neg (a b : Int) : min (-a) (-b) = -max a b := by
rw [Int.min_def, Int.max_def]
@@ -562,6 +562,16 @@ theorem natAbs_sub_of_nonneg_of_le {a b : Int} (h₁ : 0 ≤ b) (h₂ : b ≤ a)
· rwa [ Int.ofNat_le, natAbs_of_nonneg h₁, natAbs_of_nonneg (Int.le_trans h₁ h₂)]
· exact Int.sub_nonneg_of_le h₂
theorem eq_zero_of_dvd_of_natAbs_lt_natAbs {d n : Int} (h : d n) (h₁ : n.natAbs < d.natAbs) :
n = 0 := by
let a, ha := h
subst ha
rw [natAbs_mul] at h₁
suffices ¬ 0 < a.natAbs by simp [Int.natAbs_eq_zero.1 (Nat.eq_zero_of_not_pos this)]
refine fun h => Nat.lt_irrefl _ (Nat.lt_of_le_of_lt ?_ h₁)
rw (occs := [1]) [ Nat.mul_one d.natAbs]
exact Nat.mul_le_mul (Nat.le_refl _) h
/-! ### toNat -/
theorem toNat_eq_max : a : Int, (toNat a : Int) = max a 0
@@ -599,6 +609,18 @@ theorem toNat_add {a b : Int} (ha : 0 ≤ a) (hb : 0 ≤ b) : (a + b).toNat = a.
match a, b, eq_ofNat_of_zero_le ha, eq_ofNat_of_zero_le hb with
| _, _, _, rfl, _, rfl => rfl
theorem toNat_mul {a b : Int} (ha : 0 a) (hb : 0 b) : (a * b).toNat = a.toNat * b.toNat :=
match a, b, eq_ofNat_of_zero_le ha, eq_ofNat_of_zero_le hb with
| _, _, _, rfl, _, rfl => rfl
/--
Variant of `Int.toNat_sub` taking non-negativity hypotheses,
rather than expecting the arguments to be casts of natural numbers.
-/
theorem toNat_sub'' {a b : Int} (ha : 0 a) (hb : 0 b) : (a - b).toNat = a.toNat - b.toNat :=
match a, b, eq_ofNat_of_zero_le ha, eq_ofNat_of_zero_le hb with
| _, _, _, rfl, _, rfl => toNat_sub _ _
theorem toNat_add_nat {a : Int} (ha : 0 a) (n : Nat) : (a + n).toNat = a.toNat + n :=
match a, eq_ofNat_of_zero_le ha with | _, _, rfl => rfl

View File

@@ -477,7 +477,7 @@ theorem attach_filterMap {l : List α} {f : α → Option β} :
· simp only [h]
rfl
rw [ih]
simp only [map_filterMap, Option.map_pbind, Option.map_some']
simp only [map_filterMap, Option.map_pbind, Option.map_some]
rfl
· simp only [Option.pbind_eq_some_iff] at h
obtain a, h, w := h

View File

@@ -95,10 +95,10 @@ theorem findSome?_eq_some_iff {f : α → Option β} {l : List α} {b : β} :
| cons x xs ih =>
simp [guard, findSome?, find?]
split <;> rename_i h
· simp only [Option.guard_eq_some] at h
· simp only [Option.guard_eq_some_iff] at h
obtain rfl, h := h
simp [h]
· simp only [Option.guard_eq_none] at h
· simp only [Option.guard_eq_none_iff] at h
simp [ih, h]
theorem find?_eq_findSome?_guard {l : List α} : find? p l = findSome? (Option.guard fun x => p x) l :=
@@ -700,6 +700,7 @@ theorem findIdx?_eq_none_iff {xs : List α} {p : α → Bool} :
simp only [findIdx?_cons]
split <;> simp_all [cond_eq_if]
@[simp]
theorem findIdx?_isSome {xs : List α} {p : α Bool} :
(xs.findIdx? p).isSome = xs.any p := by
induction xs with
@@ -708,6 +709,7 @@ theorem findIdx?_isSome {xs : List α} {p : α → Bool} :
simp only [findIdx?_cons]
split <;> simp_all
@[simp]
theorem findIdx?_isNone {xs : List α} {p : α Bool} :
(xs.findIdx? p).isNone = xs.all (¬p ·) := by
induction xs with
@@ -768,7 +770,7 @@ theorem findIdx?_eq_some_iff_getElem {xs : List α} {p : α → Bool} {i : Nat}
not_and, Classical.not_forall, Bool.not_eq_false]
intros
refine 0, zero_lt_succ i, _
· simp only [Option.map_eq_some', ih, Bool.not_eq_true, length_cons]
· simp only [Option.map_eq_some_iff, ih, Bool.not_eq_true, length_cons]
constructor
· rintro a, h, h₁, h₂, rfl
refine Nat.succ_lt_succ_iff.mpr h, by simpa, fun j hj => ?_
@@ -824,7 +826,7 @@ abbrev findIdx?_of_eq_none := @of_findIdx?_eq_none
(xs ++ ys : List α).findIdx? p =
(xs.findIdx? p).or ((ys.findIdx? p).map fun i => i + xs.length) := by
induction xs with simp
| cons _ _ _ => split <;> simp_all [Option.map_or', Option.map_map]; rfl
| cons _ _ _ => split <;> simp_all [Option.map_or, Option.map_map]; rfl
theorem findIdx?_flatten {l : List (List α)} {p : α Bool} :
l.flatten.findIdx? p =
@@ -984,6 +986,24 @@ theorem findFinIdx?_eq_some_iff {xs : List α} {p : α → Bool} {i : Fin xs.len
· rintro h, w
exact i, i.2, h, fun j hji => w j, by omega hji, rfl
@[simp]
theorem isSome_findFinIdx? {l : List α} {p : α Bool} :
(l.findFinIdx? p).isSome = l.any p := by
induction l with
| nil => simp
| cons x xs ih =>
simp only [findFinIdx?_cons]
split <;> simp_all
@[simp]
theorem isNone_findFinIdx? {l : List α} {p : α Bool} :
(l.findFinIdx? p).isNone = l.all (fun x => ¬ p x) := by
induction l with
| nil => simp
| cons x xs ih =>
simp only [findFinIdx?_cons]
split <;> simp_all
@[simp] theorem findFinIdx?_subtype {p : α Prop} {l : List { x // p x }}
{f : { x // p x } Bool} {g : α Bool} (hf : x h, f x, h = g x) :
l.findFinIdx? f = (l.unattach.findFinIdx? g).map (fun i => i.cast (by simp)) := by
@@ -1084,6 +1104,24 @@ theorem idxOf?_eq_map_finIdxOf?_val [BEq α] {xs : List α} {a : α} :
l.finIdxOf? a = some i l[i] = a j (_ : j < i), ¬l[j] = a := by
simp only [finIdxOf?, findFinIdx?_eq_some_iff, beq_iff_eq]
@[simp]
theorem isSome_finIdxOf? [BEq α] [LawfulBEq α] {l : List α} {a : α} :
(l.finIdxOf? a).isSome a l := by
induction l with
| nil => simp
| cons x xs ih =>
simp only [finIdxOf?_cons]
split <;> simp_all [@eq_comm _ x a]
@[simp]
theorem isNone_finIdxOf? [BEq α] [LawfulBEq α] {l : List α} {a : α} :
(l.finIdxOf? a).isNone = ¬ a l := by
induction l with
| nil => simp
| cons x xs ih =>
simp only [finIdxOf?_cons]
split <;> simp_all [@eq_comm _ x a]
/-! ### idxOf?
The verification API for `idxOf?` is still incomplete.
@@ -1109,6 +1147,25 @@ theorem idxOf?_cons [BEq α] {a : α} {xs : List α} {b : α} :
@[deprecated idxOf?_eq_none_iff (since := "2025-01-29")]
abbrev indexOf?_eq_none_iff := @idxOf?_eq_none_iff
@[simp]
theorem isSome_idxOf? [BEq α] [LawfulBEq α] {l : List α} {a : α} :
(l.idxOf? a).isSome a l := by
induction l with
| nil => simp
| cons x xs ih =>
simp only [idxOf?_cons]
split <;> simp_all [@eq_comm _ x a]
@[simp]
theorem isNone_idxOf? [BEq α] [LawfulBEq α] {l : List α} {a : α} :
(l.idxOf? a).isNone = ¬ a l := by
induction l with
| nil => simp
| cons x xs ih =>
simp only [idxOf?_cons]
split <;> simp_all [@eq_comm _ x a]
/-! ### lookup -/
section lookup

View File

@@ -105,7 +105,7 @@ abbrev length_eq_zero := @length_eq_zero_iff
theorem eq_nil_iff_length_eq_zero : l = [] length l = 0 :=
length_eq_zero_iff.symm
@[grind] theorem length_pos_of_mem {a : α} : {l : List α}, a l 0 < length l
@[grind ] theorem length_pos_of_mem {a : α} : {l : List α}, a l 0 < length l
| _::_, _ => Nat.zero_lt_succ _
theorem exists_mem_of_length_pos : {l : List α}, 0 < length l a, a l
@@ -185,7 +185,7 @@ theorem singleton_inj {α : Type _} {a b : α} : [a] = [b] ↔ a = b := by
We simplify `l.get i` to `l[i.1]'i.2` and `l.get? i` to `l[i]?`.
-/
@[simp, grind]
@[simp, grind =]
theorem get_eq_getElem {l : List α} {i : Fin l.length} : l.get i = l[i.1]'i.2 := rfl
set_option linter.deprecated false in
@@ -225,7 +225,7 @@ theorem get?_eq_getElem? {l : List α} {i : Nat} : l.get? i = l[i]? := by
We simplify `l[i]!` to `(l[i]?).getD default`.
-/
@[simp, grind]
@[simp, grind =]
theorem getElem!_eq_getElem?_getD [Inhabited α] {l : List α} {i : Nat} :
l[i]! = (l[i]?).getD (default : α) := by
simp only [getElem!_def]
@@ -235,16 +235,16 @@ theorem getElem!_eq_getElem?_getD [Inhabited α] {l : List α} {i : Nat} :
/-! ### getElem? and getElem -/
@[simp, grind] theorem getElem?_nil {i : Nat} : ([] : List α)[i]? = none := rfl
@[simp, grind =] theorem getElem?_nil {i : Nat} : ([] : List α)[i]? = none := rfl
theorem getElem_cons {l : List α} (w : i < (a :: l).length) :
(a :: l)[i] =
if h : i = 0 then a else l[i-1]'(match i, h with | i+1, _ => succ_lt_succ_iff.mp w) := by
cases i <;> simp
@[grind] theorem getElem?_cons_zero {l : List α} : (a::l)[0]? = some a := rfl
@[grind =] theorem getElem?_cons_zero {l : List α} : (a::l)[0]? = some a := rfl
@[simp, grind] theorem getElem?_cons_succ {l : List α} : (a::l)[i+1]? = l[i]? := rfl
@[simp, grind =] theorem getElem?_cons_succ {l : List α} : (a::l)[i+1]? = l[i]? := rfl
theorem getElem?_cons : (a :: l)[i]? = if i = 0 then some a else l[i-1]? := by
cases i <;> simp [getElem?_cons_zero]
@@ -337,7 +337,7 @@ We simplify away `getD`, replacing `getD l n a` with `(l[n]?).getD a`.
Because of this, there is only minimal API for `getD`.
-/
@[simp, grind]
@[simp, grind =]
theorem getD_eq_getElem?_getD {l : List α} {i : Nat} {a : α} : getD l i a = (l[i]?).getD a := by
simp [getD]

View File

@@ -339,7 +339,7 @@ theorem getElem?_mapIdx_go : ∀ {l : List α} {acc : Array β} {i : Nat},
if h : i < acc.size then some acc[i] else Option.map (f i) l[i - acc.size]?
| [], acc, i => by
simp only [mapIdx.go, Array.toListImpl_eq, getElem?_def, Array.length_toList,
Array.getElem_toList, length_nil, Nat.not_lt_zero, reduceDIte, Option.map_none']
Array.getElem_toList, length_nil, Nat.not_lt_zero, reduceDIte, Option.map_none]
| a :: l, acc, i => by
rw [mapIdx.go, getElem?_mapIdx_go]
simp only [Array.size_push]

View File

@@ -31,6 +31,33 @@ theorem count_set [BEq α] {a b : α} {l : List α} {i : Nat} (h : i < l.length)
(l.set i a).count b = l.count b - (if l[i] == b then 1 else 0) + (if a == b then 1 else 0) := by
simp [count_eq_countP, countP_set, h]
theorem countP_replace [BEq α] [LawfulBEq α] {a b : α} {l : List α} {p : α Bool} :
(l.replace a b).countP p =
if l.contains a then l.countP p + (if p b then 1 else 0) - (if p a then 1 else 0) else l.countP p := by
induction l with
| nil => simp
| cons x l ih =>
simp [replace_cons]
split <;> rename_i h
· simp at h
simp [h, ih, countP_cons]
omega
· simp only [beq_eq_false_iff_ne, ne_eq] at h
simp only [countP_cons, ih, contains_eq_mem, decide_eq_true_eq, mem_cons, h, false_or]
split <;> rename_i h'
· by_cases h'' : p a
· have : countP p l > 0 := countP_pos_iff.mpr a, h', h''
simp [h'']
omega
· simp [h'']
omega
· omega
theorem count_replace [BEq α] [LawfulBEq α] {a b c : α} {l : List α} :
(l.replace a b).count c =
if l.contains a then l.count c + (if b == c then 1 else 0) - (if a == c then 1 else 0) else l.count c := by
simp [count_eq_countP, countP_replace]
/--
The number of elements satisfying a predicate in a sublist is at least the number of elements satisfying the predicate in the list,
minus the difference in the lengths.

View File

@@ -54,4 +54,23 @@ theorem set_set_perm {as : List α} {i j : Nat} (h₁ : i < as.length) (h₂ : j
subst t
apply set_set_perm' _ _ (by omega)
namespace Perm
/-- Variant of `List.Perm.take` specifying the the permutation is constant after `i` elementwise. -/
theorem take_of_getElem? {l₁ l₂ : List α} (h : l₁ ~ l₂) {i : Nat} (w : j, i j l₁[j]? = l₂[j]?) :
l₁.take i ~ l₂.take i := by
refine h.take (Perm.of_eq ?_)
ext1 j
simpa using w (i + j) (by omega)
/-- Variant of `List.Perm.drop` specifying the the permutation is constant before `i` elementwise. -/
theorem drop_of_getElem? {l₁ l₂ : List α} (h : l₁ ~ l₂) {i : Nat} (w : j, j < i l₁[j]? = l₂[j]?) :
l₁.drop i ~ l₂.drop i := by
refine h.drop (Perm.of_eq ?_)
ext1
simp only [getElem?_take]
split <;> simp_all
end Perm
end List

View File

@@ -6,6 +6,7 @@ Authors: Leonardo de Moura, Jeremy Avigad, Mario Carneiro
prelude
import Init.Data.List.Pairwise
import Init.Data.List.Erase
import Init.Data.List.Find
/-!
# List Permutations
@@ -178,7 +179,7 @@ theorem Perm.singleton_eq (h : [a] ~ l) : [a] = l := singleton_perm.mp h
theorem singleton_perm_singleton {a b : α} : [a] ~ [b] a = b := by simp
theorem perm_cons_erase [DecidableEq α] {a : α} {l : List α} (h : a l) : l ~ a :: l.erase a :=
theorem perm_cons_erase [BEq α] [LawfulBEq α] {a : α} {l : List α} (h : a l) : l ~ a :: l.erase a :=
let _, _, _, e₁, e₂ := exists_erase_eq h
e₂ e₁ perm_middle
@@ -268,7 +269,7 @@ theorem countP_eq_countP_filter_add (l : List α) (p q : α → Bool) :
l.countP p = (l.filter q).countP p + (l.filter fun a => !q a).countP p :=
countP_append .. Perm.countP_eq _ (filter_append_perm _ _).symm
theorem Perm.count_eq [DecidableEq α] {l₁ l₂ : List α} (p : l₁ ~ l₂) (a) :
theorem Perm.count_eq [BEq α] {l₁ l₂ : List α} (p : l₁ ~ l₂) (a) :
count a l₁ = count a l₂ := p.countP_eq _
/-
@@ -369,9 +370,9 @@ theorem perm_append_right_iff {l₁ l₂ : List α} (l) : l₁ ++ l ~ l₂ ++ l
refine fun p => ?_, .append_right _
exact (perm_append_left_iff _).1 <| perm_append_comm.trans <| p.trans perm_append_comm
section DecidableEq
section LawfulBEq
variable [DecidableEq α]
variable [BEq α] [LawfulBEq α]
theorem Perm.erase (a : α) {l₁ l₂ : List α} (p : l₁ ~ l₂) : l₁.erase a ~ l₂.erase a :=
if h₁ : a l₁ then
@@ -387,6 +388,11 @@ theorem cons_perm_iff_perm_erase {a : α} {l₁ l₂ : List α} :
have : a l₂ := h.subset mem_cons_self
exact this, (h.trans <| perm_cons_erase this).cons_inv
end LawfulBEq
section DecidableEq
variable [DecidableEq α]
theorem perm_iff_count {l₁ l₂ : List α} : l₁ ~ l₂ a, count a l₁ = count a l₂ := by
refine Perm.count_eq, fun H => ?_
induction l₁ generalizing l₂ with
@@ -536,4 +542,22 @@ theorem perm_insertIdx {α} (x : α) (l : List α) {i} (h : i ≤ l.length) :
simp only [insertIdx, modifyTailIdx]
refine .trans (.cons _ (ih (Nat.le_of_succ_le_succ h))) (.swap ..)
namespace Perm
theorem take {l₁ l₂ : List α} (h : l₁ ~ l₂) {n : Nat} (w : l₁.drop n ~ l₂.drop n) :
l₁.take n ~ l₂.take n := by
classical
rw [perm_iff_count] at h w
rw [ take_append_drop n l₁, take_append_drop n l₂] at h
simpa only [count_append, w, Nat.add_right_cancel_iff] using h
theorem drop {l₁ l₂ : List α} (h : l₁ ~ l₂) {n : Nat} (w : l₁.take n ~ l₂.take n) :
l₁.drop n ~ l₂.drop n := by
classical
rw [perm_iff_count] at h w
rw [ take_append_drop n l₁, take_append_drop n l₂] at h
simpa only [count_append, w, Nat.add_left_cancel_iff] using h
end Perm
end List

View File

@@ -6,6 +6,7 @@ Authors: Floris van Doorn, Leonardo de Moura
prelude
import Init.SimpLemmas
import Init.Data.NeZero
import Init.Grind.Tactics
set_option linter.missingDocs true -- keep it documented
universe u
@@ -867,7 +868,7 @@ Examples:
-/
protected abbrev min (n m : Nat) := min n m
protected theorem min_def {n m : Nat} : min n m = if n m then n else m := rfl
@[grind =] protected theorem min_def {n m : Nat} : min n m = if n m then n else m := rfl
instance : Max Nat := maxOfLe
@@ -884,7 +885,7 @@ Examples:
-/
protected abbrev max (n m : Nat) := max n m
protected theorem max_def {n m : Nat} : max n m = if n m then m else n := rfl
@[grind =] protected theorem max_def {n m : Nat} : max n m = if n m then m else n := rfl
/-! # Auxiliary theorems for well-founded recursion -/

View File

@@ -141,11 +141,11 @@ theorem toList_attach (o : Option α) :
cases o <;> simp
theorem attach_map {o : Option α} (f : α β) :
(o.map f).attach = o.attach.map (fun x, h => f x, map_eq_some.2 _, h, rfl) := by
(o.map f).attach = o.attach.map (fun x, h => f x, map_eq_some_iff.2 _, h, rfl) := by
cases o <;> simp
theorem attachWith_map {o : Option α} (f : α β) {P : β Prop} {H : (b : β), o.map f = some b P b} :
(o.map f).attachWith P H = (o.attachWith (P f) (fun _ h => H _ (map_eq_some.2 _, h, rfl))).map
(o.map f).attachWith P H = (o.attachWith (P f) (fun _ h => H _ (map_eq_some_iff.2 _, h, rfl))).map
fun x, h => f x, h := by
cases o <;> simp
@@ -174,7 +174,7 @@ theorem map_attach_eq_attachWith {o : Option α} {p : α → Prop} (f : ∀ a, o
theorem attach_bind {o : Option α} {f : α Option β} :
(o.bind f).attach =
o.attach.bind fun x, h => (f x).attach.map fun y, h' => y, bind_eq_some.2 _, h, h' := by
o.attach.bind fun x, h => (f x).attach.map fun y, h' => y, bind_eq_some_iff.2 _, h, h' := by
cases o <;> simp
theorem bind_attach {o : Option α} {f : {x // o = some x} Option β} :

View File

@@ -231,13 +231,12 @@ def merge (fn : ααα) : Option α → Option α → Option α
@[simp] theorem getD_none : getD none a = a := rfl
@[simp] theorem getD_some : getD (some a) b = a := rfl
@[simp] theorem map_none' (f : α β) : none.map f = none := rfl
@[simp] theorem map_some' (a) (f : α β) : (some a).map f = some (f a) := rfl
@[simp] theorem map_none (f : α β) : none.map f = none := rfl
@[simp] theorem map_some (a) (f : α β) : (some a).map f = some (f a) := rfl
@[simp] theorem none_bind (f : α Option β) : none.bind f = none := rfl
@[simp] theorem some_bind (a) (f : α Option β) : (some a).bind f = f a := rfl
/--
A case analysis function for `Option`.

View File

@@ -39,18 +39,24 @@ This is not an instance because it is not definitionally equal to the standard i
Try to use the Boolean comparisons `Option.isNone` or `Option.isSome` instead.
-/
@[inline] def decidable_eq_none {o : Option α} : Decidable (o = none) :=
@[inline] def decidableEqNone {o : Option α} : Decidable (o = none) :=
decidable_of_decidable_of_iff isNone_iff_eq_none
instance {p : α Prop} [DecidablePred p] : o : Option α, Decidable ( a, a o p a)
| none => isTrue nofun
| some a =>
if h : p a then isTrue fun _ e => some_inj.1 e h
else isFalse <| mt (· _ rfl) h
@[deprecated decidableEqNone (since := "2025-04-10"), inline]
def decidable_eq_none {o : Option α} : Decidable (o = none) :=
decidableEqNone
instance {p : α Prop} [DecidablePred p] : o : Option α, Decidable (Exists fun a => a o p a)
| none => isFalse nofun
| some a => if h : p a then isTrue _, rfl, h else isFalse fun _, rfl, hn => h hn
instance decidableForallMem {p : α Prop} [DecidablePred p] :
o : Option α, Decidable ( a, a o p a)
| none => isTrue nofun
| some a =>
if h : p a then isTrue fun _ e => some_inj.1 e h
else isFalse <| mt (· _ rfl) h
instance decidableExistsMem {p : α Prop} [DecidablePred p] :
o : Option α, Decidable (Exists fun a => a o p a)
| none => isFalse nofun
| some a => if h : p a then isTrue _, rfl, h else isFalse fun _, rfl, hn => h hn
/--
Given an optional value and a function that can be applied when the value is `some`, returns the

View File

@@ -17,6 +17,8 @@ theorem mem_iff {a : α} {b : Option α} : a ∈ b ↔ b = some a := .rfl
theorem mem_some {a b : α} : a some b b = a := by simp
theorem mem_some_iff {a b : α} : a some b b = a := mem_some
theorem mem_some_self (a : α) : a some a := rfl
theorem some_ne_none (x : α) : some x none := nofun
@@ -29,6 +31,9 @@ protected theorem «exists» {p : Option α → Prop} :
fun | none, hx => .inl hx | some x, hx => .inr x, hx,
fun | .inl h => _, h | .inr _, hx => _, hx
theorem eq_none_or_eq_some (a : Option α) : a = none x, a = some x :=
Option.exists.mp exists_eq'
theorem get_mem : {o : Option α} (h : isSome o), o.get h o
| some _, _ => rfl
@@ -88,6 +93,9 @@ set_option Elab.async false
theorem eq_none_iff_forall_ne_some : o = none a, o some a := by
cases o <;> simp
theorem eq_none_iff_forall_some_ne : o = none a, some a o := by
cases o <;> simp
theorem eq_none_iff_forall_not_mem : o = none a, a o :=
eq_none_iff_forall_ne_some
@@ -102,9 +110,30 @@ theorem isSome_of_mem {x : Option α} {y : α} (h : y ∈ x) : x.isSome := by
theorem isSome_of_eq_some {x : Option α} {y : α} (h : x = some y) : x.isSome := by
cases x <;> trivial
@[simp] theorem not_isSome : isSome a = false a.isNone = true := by
@[simp] theorem isSome_eq_false_iff : isSome a = false a.isNone = true := by
cases a <;> simp
@[simp] theorem isNone_eq_false_iff : isNone a = false a.isSome = true := by
cases a <;> simp
@[simp]
theorem not_isSome (a : Option α) : (!a.isSome) = a.isNone := by
cases a <;> simp
@[simp]
theorem not_comp_isSome : (! ·) @Option.isSome α = Option.isNone := by
funext
simp
@[simp]
theorem not_isNone (a : Option α) : (!a.isNone) = a.isSome := by
cases a <;> simp
@[simp]
theorem not_comp_isNone : (!·) @Option.isNone α = Option.isSome := by
funext x
simp
theorem eq_some_iff_get_eq : o = some a h : o.isSome, o.get h = a := by
cases o <;> simp
@@ -146,17 +175,25 @@ abbrev ball_ne_none := @forall_ne_none
@[simp] theorem bind_eq_bind : bind = @Option.bind α β := rfl
@[simp] theorem orElse_eq_orElse : HOrElse.hOrElse = @Option.orElse α := rfl
@[simp] theorem bind_some (x : Option α) : x.bind some = x := by cases x <;> rfl
@[simp] theorem bind_none (x : Option α) : x.bind (fun _ => none (α := β)) = none := by
cases x <;> rfl
theorem bind_eq_some : x.bind f = some b a, x = some a f a = some b := by
theorem bind_eq_some_iff : x.bind f = some b a, x = some a f a = some b := by
cases x <;> simp
@[simp] theorem bind_eq_none {o : Option α} {f : α Option β} :
@[deprecated bind_eq_some_iff (since := "2025-04-10")]
abbrev bind_eq_some := @bind_eq_some_iff
@[simp] theorem bind_eq_none_iff {o : Option α} {f : α Option β} :
o.bind f = none a, o = some a f a = none := by cases o <;> simp
@[deprecated bind_eq_none_iff (since := "2025-04-10")]
abbrev bind_eq_none := @bind_eq_none_iff
theorem bind_eq_none' {o : Option α} {f : α Option β} :
o.bind f = none b a, o = some a f a some b := by
cases o <;> simp [eq_none_iff_forall_ne_some]
@@ -193,50 +230,67 @@ theorem isSome_apply_of_isSome_bind {α β : Type _} {x : Option α} {f : α
(isSome_apply_of_isSome_bind h) := by
cases x <;> trivial
theorem join_eq_some : x.join = some a x = some (some a) := by
simp [bind_eq_some]
theorem join_eq_some_iff : x.join = some a x = some (some a) := by
simp [bind_eq_some_iff]
@[deprecated join_eq_some_iff (since := "2025-04-10")]
abbrev join_eq_some := @join_eq_some_iff
theorem join_ne_none : x.join none z, x = some (some z) := by
simp only [ne_none_iff_exists', join_eq_some, iff_self]
simp only [ne_none_iff_exists', join_eq_some_iff, iff_self]
theorem join_ne_none' : ¬x.join = none z, x = some (some z) :=
join_ne_none
theorem join_eq_none : o.join = none o = none o = some none :=
theorem join_eq_none_iff : o.join = none o = none o = some none :=
match o with | none | some none | some (some _) => by simp
@[deprecated join_eq_none_iff (since := "2025-04-10")]
abbrev join_eq_none := @join_eq_none_iff
theorem bind_id_eq_join {x : Option (Option α)} : x.bind id = x.join := rfl
@[simp] theorem map_eq_map : Functor.map f = Option.map f := rfl
theorem map_none : f <$> none = none := rfl
@[deprecated map_none (since := "2025-04-10")]
abbrev map_none' := @map_none
theorem map_some : f <$> some a = some (f a) := rfl
@[deprecated map_some (since := "2025-04-10")]
abbrev map_some' := @map_some
@[simp] theorem map_eq_some' : x.map f = some b a, x = some a f a = b := by cases x <;> simp
theorem map_eq_some : f <$> x = some b a, x = some a f a = b := map_eq_some'
@[simp] theorem map_eq_none' : x.map f = none x = none := by
cases x <;> simp [map_none', map_some', eq_self_iff_true]
theorem isSome_map {x : Option α} : (f <$> x).isSome = x.isSome := by
@[simp] theorem map_eq_some_iff : x.map f = some b a, x = some a f a = b := by
cases x <;> simp
@[simp] theorem isSome_map' {x : Option α} : (x.map f).isSome = x.isSome := by
@[deprecated map_eq_some_iff (since := "2025-04-10")]
abbrev map_eq_some := @map_eq_some_iff
@[deprecated map_eq_some_iff (since := "2025-04-10")]
abbrev map_eq_some' := @map_eq_some_iff
@[simp] theorem map_eq_none_iff : x.map f = none x = none := by
cases x <;> simp [map_none, map_some, eq_self_iff_true]
@[deprecated map_eq_none_iff (since := "2025-04-10")]
abbrev map_eq_none := @map_eq_none_iff
@[deprecated map_eq_none_iff (since := "2025-04-10")]
abbrev map_eq_none' := @map_eq_none_iff
@[simp] theorem isSome_map {x : Option α} : (x.map f).isSome = x.isSome := by
cases x <;> simp
@[simp] theorem isNone_map' {x : Option α} : (x.map f).isNone = x.isNone := by
cases x <;> simp
@[deprecated isSome_map (since := "2025-04-10")]
abbrev isSome_map' := @isSome_map
theorem map_eq_none : f <$> x = none x = none := map_eq_none'
@[simp] theorem isNone_map {x : Option α} : (x.map f).isNone = x.isNone := by
cases x <;> simp
theorem map_eq_bind {x : Option α} : x.map f = x.bind (some f) := by
cases x <;> simp [Option.bind]
theorem map_congr {x : Option α} (h : a, x = some a f a = g a) :
x.map f = x.map g := by
cases x <;> simp only [map_none', map_some', h]
cases x <;> simp only [map_none, map_some, h]
@[simp] theorem map_id_fun {α : Type u} : Option.map (id : α α) = id := by
funext; simp [map_id]
@@ -254,7 +308,7 @@ theorem get_map {f : α → β} {o : Option α} {h : (o.map f).isSome} :
@[simp] theorem map_map (h : β γ) (g : α β) (x : Option α) :
(x.map g).map h = x.map (h g) := by
cases x <;> simp only [map_none', map_some', ··]
cases x <;> simp only [map_none, map_some, ··]
theorem comp_map (h : β γ) (g : α β) (x : Option α) : x.map (h g) = (x.map g).map h :=
(map_map ..).symm
@@ -262,7 +316,7 @@ theorem comp_map (h : β → γ) (g : α → β) (x : Option α) : x.map (h ∘
@[simp] theorem map_comp_map (f : α β) (g : β γ) :
Option.map g Option.map f = Option.map (g f) := by funext x; simp
theorem mem_map_of_mem (g : α β) (h : a x) : g a Option.map g x := h.symm map_some' ..
theorem mem_map_of_mem (g : α β) (h : a x) : g a Option.map g x := h.symm map_some ..
theorem map_inj_right {f : α β} {o o' : Option α} (w : x y, f x = f y x = y) :
o.map f = o'.map f o = o' := by
@@ -292,11 +346,14 @@ theorem isSome_of_isSome_filter (p : α → Bool) (o : Option α) (h : (o.filter
@[deprecated isSome_of_isSome_filter (since := "2025-03-18")]
abbrev isSome_filter_of_isSome := @isSome_of_isSome_filter
@[simp] theorem filter_eq_none {o : Option α} {p : α Bool} :
@[simp] theorem filter_eq_none_iff {o : Option α} {p : α Bool} :
o.filter p = none a, o = some a ¬ p a := by
cases o <;> simp [filter_some]
@[simp] theorem filter_eq_some {o : Option α} {p : α Bool} :
@[deprecated filter_eq_none_iff (since := "2025-04-10")]
abbrev filter_eq_none := @filter_eq_none_iff
@[simp] theorem filter_eq_some_iff {o : Option α} {p : α Bool} :
o.filter p = some a o = some a p a := by
cases o with
| none => simp
@@ -310,6 +367,9 @@ abbrev isSome_filter_of_isSome := @isSome_of_isSome_filter
rintro rfl
simpa using h
@[deprecated filter_eq_some_iff (since := "2025-04-10")]
abbrev filter_eq_some := @filter_eq_some_iff
theorem mem_filter_iff {p : α Bool} {a : α} {o : Option α} :
a o.filter p a o p a := by
simp
@@ -383,29 +443,43 @@ theorem join_join {x : Option (Option (Option α))} : x.join.join = (x.map join)
cases x <;> simp
theorem mem_of_mem_join {a : α} {x : Option (Option α)} (h : a x.join) : some a x :=
h.symm join_eq_some.1 h
h.symm join_eq_some_iff.1 h
@[simp] theorem some_orElse (a : α) (x : Option α) : (some a <|> x) = some a := rfl
@[simp] theorem some_orElse (a : α) (f) : (some a).orElse f = some a := rfl
@[simp] theorem none_orElse (x : Option α) : (none <|> x) = x := rfl
@[simp] theorem none_orElse (f : Unit Option α) : none.orElse f = f () := rfl
@[simp] theorem orElse_none (x : Option α) : (x <|> none) = x := by cases x <;> rfl
@[simp] theorem orElse_none (x : Option α) : x.orElse (fun _ => none) = x := by cases x <;> rfl
theorem map_orElse {x y : Option α} : (x <|> y).map f = (x.map f <|> y.map f) := by
theorem orElse_eq_some_iff (o : Option α) (f) (x : α) :
(o.orElse f) = some x o = some x o = none f () = some x := by
cases o <;> simp
theorem orElse_eq_none_iff (o : Option α) (f) : (o.orElse f) = none o = none f () = none := by
cases o <;> simp
theorem map_orElse {x : Option α} {y} :
(x.orElse y).map f = (x.map f).orElse (fun _ => (y ()).map f) := by
cases x <;> simp
@[simp] theorem guard_eq_some [DecidablePred p] : guard p a = some b a = b p a :=
@[simp] theorem guard_eq_some_iff [DecidablePred p] : guard p a = some b a = b p a :=
if h : p a then by simp [Option.guard, h] else by simp [Option.guard, h]
@[deprecated guard_eq_some_iff (since := "2025-04-10")]
abbrev guard_eq_some := @guard_eq_some_iff
@[simp] theorem isSome_guard [DecidablePred p] : (Option.guard p a).isSome p a :=
if h : p a then by simp [Option.guard, h] else by simp [Option.guard, h]
@[deprecated isSome_guard (since := "2025-03-18")]
abbrev guard_isSome := @isSome_guard
@[simp] theorem guard_eq_none [DecidablePred p] : Option.guard p a = none ¬ p a :=
@[simp] theorem guard_eq_none_iff [DecidablePred p] : Option.guard p a = none ¬ p a :=
if h : p a then by simp [Option.guard, h] else by simp [Option.guard, h]
@[deprecated guard_eq_none_iff (since := "2025-04-10")]
abbrev guard_eq_none := @guard_eq_none_iff
@[simp] theorem guard_pos [DecidablePred p] (h : p a) : Option.guard p a = some a := by
simp [Option.guard, h]
@@ -475,6 +549,22 @@ theorem liftOrGet_none_right {f} {a : Option α} : merge f a none = a :=
theorem liftOrGet_some_some {f} {a b : α} : merge f (some a) (some b) = f a b :=
merge_some_some
instance commutative_merge (f : α α α) [Std.Commutative f] :
Std.Commutative (merge f) :=
fun a b by cases a <;> cases b <;> simp [merge, Std.Commutative.comm]
instance associative_merge (f : α α α) [Std.Associative f] :
Std.Associative (merge f) :=
fun a b c by cases a <;> cases b <;> cases c <;> simp [merge, Std.Associative.assoc]
instance idempotentOp_merge (f : α α α) [Std.IdempotentOp f] :
Std.IdempotentOp (merge f) :=
fun a by cases a <;> simp [merge, Std.IdempotentOp.idempotent]
instance lawfulIdentity_merge (f : α α α) : Std.LawfulIdentity (merge f) none where
left_id a := by cases a <;> simp [merge]
right_id a := by cases a <;> simp [merge]
@[simp] theorem elim_none (x : β) (f : α β) : none.elim x f = x := rfl
@[simp] theorem elim_some (x : β) (f : α β) (a : α) : (some a).elim x f = f a := rfl
@@ -535,12 +625,18 @@ theorem or_eq_bif : or o o' = bif o.isSome then o else o' := by
@[simp] theorem isNone_or : (or o o').isNone = (o.isNone && o'.isNone) := by
cases o <;> rfl
@[simp] theorem or_eq_none : or o o' = none o = none o' = none := by
@[simp] theorem or_eq_none_iff : or o o' = none o = none o' = none := by
cases o <;> simp
@[simp] theorem or_eq_some : or o o' = some a o = some a (o = none o' = some a) := by
@[deprecated or_eq_none_iff (since := "2025-04-10")]
abbrev or_eq_none := @or_eq_none_iff
@[simp] theorem or_eq_some_iff : or o o' = some a o = some a (o = none o' = some a) := by
cases o <;> simp
@[deprecated or_eq_some_iff (since := "2025-04-10")]
abbrev or_eq_some := @or_eq_some_iff
theorem or_assoc : or (or o₁ o₂) o₃ = or o₁ (or o₂ o₃) := by
cases o₁ <;> cases o₂ <;> rfl
instance : Std.Associative (or (α := α)) := @or_assoc _
@@ -564,11 +660,11 @@ instance : Std.IdempotentOp (or (α := α)) := ⟨@or_self _⟩
theorem or_eq_orElse : or o o' = o.orElse (fun _ => o') := by
cases o <;> rfl
theorem map_or : f <$> or o o' = (f <$> o).or (f <$> o') := by
theorem map_or : (or o o').map f = (o.map f).or (o'.map f) := by
cases o <;> rfl
theorem map_or' : (or o o').map f = (o.map f).or (o'.map f) := by
cases o <;> rfl
@[deprecated map_or (since := "2025-04-10")]
abbrev map_or' := @map_or
theorem or_of_isSome {o o' : Option α} (h : o.isSome) : o.or o' = o := by
match o, h with
@@ -804,7 +900,7 @@ theorem map_pmap {p : α → Prop} (g : β → γ) (f : ∀ a, p a → β) (o H)
theorem pmap_map (o : Option α) (f : α β) {p : β Prop} (g : b, p b γ) (H) :
pmap g (o.map f) H =
pmap (fun a h => g (f a) h) o (fun a m => H (f a) (map_eq_some.2 _, m, rfl)) := by
pmap (fun a h => g (f a) h) o (fun a m => H (f a) (map_eq_some_iff.2 _, m, rfl)) := by
cases o <;> simp
theorem pmap_pred_congr {α : Type u}

View File

@@ -20,6 +20,9 @@ def UInt8.mk (bitVec : BitVec 8) : UInt8 :=
def UInt8.ofNatCore (n : Nat) (h : n < UInt8.size) : UInt8 :=
UInt8.ofNatLT n h
/-- Converts an `Int` to a `UInt8` by taking the (non-negative remainder of the division by `2 ^ 8`. -/
def UInt8.ofInt (x : Int) : UInt8 := ofNat (x % 2 ^ 8).toNat
/--
Adds two 8-bit unsigned integers, wrapping around on overflow. Usually accessed via the `+`
operator.
@@ -229,6 +232,9 @@ def UInt16.mk (bitVec : BitVec 16) : UInt16 :=
def UInt16.ofNatCore (n : Nat) (h : n < UInt16.size) : UInt16 :=
UInt16.ofNatLT n h
/-- Converts an `Int` to a `UInt16` by taking the (non-negative remainder of the division by `2 ^ 16`. -/
def UInt16.ofInt (x : Int) : UInt16 := ofNat (x % 2 ^ 16).toNat
/--
Adds two 16-bit unsigned integers, wrapping around on overflow. Usually accessed via the `+`
operator.
@@ -440,6 +446,9 @@ def UInt32.mk (bitVec : BitVec 32) : UInt32 :=
def UInt32.ofNatCore (n : Nat) (h : n < UInt32.size) : UInt32 :=
UInt32.ofNatLT n h
/-- Converts an `Int` to a `UInt32` by taking the (non-negative remainder of the division by `2 ^ 32`. -/
def UInt32.ofInt (x : Int) : UInt32 := ofNat (x % 2 ^ 32).toNat
/--
Adds two 32-bit unsigned integers, wrapping around on overflow. Usually accessed via the `+`
operator.
@@ -613,6 +622,9 @@ def UInt64.mk (bitVec : BitVec 64) : UInt64 :=
def UInt64.ofNatCore (n : Nat) (h : n < UInt64.size) : UInt64 :=
UInt64.ofNatLT n h
/-- Converts an `Int` to a `UInt64` by taking the (non-negative remainder of the division by `2 ^ 64`. -/
def UInt64.ofInt (x : Int) : UInt64 := ofNat (x % 2 ^ 64).toNat
/--
Adds two 64-bit unsigned integers, wrapping around on overflow. Usually accessed via the `+`
operator.
@@ -822,6 +834,9 @@ def USize.mk (bitVec : BitVec System.Platform.numBits) : USize :=
def USize.ofNatCore (n : Nat) (h : n < USize.size) : USize :=
USize.ofNatLT n h
/-- Converts an `Int` to a `USize` by taking the (non-negative remainder of the division by `2 ^ numBits`. -/
def USize.ofInt (x : Int) : USize := ofNat (x % 2 ^ System.Platform.numBits).toNat
@[simp] theorem USize.le_size : 2 ^ 32 USize.size := by cases USize.size_eq <;> simp_all
@[simp] theorem USize.size_le : USize.size 2 ^ 64 := by cases USize.size_eq <;> simp_all

View File

@@ -286,6 +286,17 @@ declare_uint_theorems USize System.Platform.numBits
theorem USize.toNat_ofNat_of_lt_32 {n : Nat} (h : n < 4294967296) : toNat (ofNat n) = n :=
toNat_ofNat_of_lt (Nat.lt_of_lt_of_le h USize.le_size)
theorem UInt8.ofNat_mod_size : ofNat (x % 2 ^ 8) = ofNat x := by
simp [ofNat, BitVec.ofNat, Fin.ofNat']
theorem UInt16.ofNat_mod_size : ofNat (x % 2 ^ 16) = ofNat x := by
simp [ofNat, BitVec.ofNat, Fin.ofNat']
theorem UInt32.ofNat_mod_size : ofNat (x % 2 ^ 32) = ofNat x := by
simp [ofNat, BitVec.ofNat, Fin.ofNat']
theorem UInt64.ofNat_mod_size : ofNat (x % 2 ^ 64) = ofNat x := by
simp [ofNat, BitVec.ofNat, Fin.ofNat']
theorem USize.ofNat_mod_size : ofNat (x % 2 ^ System.Platform.numBits) = ofNat x := by
simp [ofNat, BitVec.ofNat, Fin.ofNat']
theorem UInt8.lt_ofNat_iff {n : UInt8} {m : Nat} (h : m < size) : n < ofNat m n.toNat < m := by
rw [lt_iff_toNat_lt, toNat_ofNat_of_lt' h]
theorem UInt8.ofNat_lt_iff {n : UInt8} {m : Nat} (h : m < size) : ofNat m < n m < n.toNat := by
@@ -2081,6 +2092,23 @@ theorem USize.ofNat_eq_iff_mod_eq_toNat (a : Nat) (b : USize) : USize.ofNat a =
USize.ofNatLT (a % b) (Nat.mod_lt_of_lt ha) = USize.ofNatLT a ha % USize.ofNatLT b hb := by
simp [USize.ofNatLT_eq_ofNat, USize.ofNat_mod ha hb]
@[simp] theorem UInt8.ofInt_one : ofInt 1 = 1 := rfl
@[simp] theorem UInt8.ofInt_neg_one : ofInt (-1) = -1 := rfl
@[simp] theorem UInt16.ofInt_one : ofInt 1 = 1 := rfl
@[simp] theorem UInt16.ofInt_neg_one : ofInt (-1) = -1 := rfl
@[simp] theorem UInt32.ofInt_one : ofInt 1 = 1 := rfl
@[simp] theorem UInt32.ofInt_neg_one : ofInt (-1) = -1 := rfl
@[simp] theorem UInt64.ofInt_one : ofInt 1 = 1 := rfl
@[simp] theorem UInt64.ofInt_neg_one : ofInt (-1) = -1 := rfl
@[simp] theorem USize.ofInt_one : ofInt 1 = 1 := by
rcases System.Platform.numBits_eq with h | h <;>
· apply USize.toNat_inj.mp
simp_all [USize.ofInt, USize.ofNat, size, toNat]
@[simp] theorem USize.ofInt_neg_one : ofInt (-1) = -1 := by
rcases System.Platform.numBits_eq with h | h <;>
· apply USize.toNat_inj.mp
simp_all [USize.ofInt, USize.ofNat, size, toNat]
@[simp] theorem UInt8.ofNat_add (a b : Nat) : UInt8.ofNat (a + b) = UInt8.ofNat a + UInt8.ofNat b := by
simp [UInt8.ofNat_eq_iff_mod_eq_toNat]
@[simp] theorem UInt16.ofNat_add (a b : Nat) : UInt16.ofNat (a + b) = UInt16.ofNat a + UInt16.ofNat b := by
@@ -2092,6 +2120,70 @@ theorem USize.ofNat_eq_iff_mod_eq_toNat (a : Nat) (b : USize) : USize.ofNat a =
@[simp] theorem USize.ofNat_add (a b : Nat) : USize.ofNat (a + b) = USize.ofNat a + USize.ofNat b := by
simp [USize.ofNat_eq_iff_mod_eq_toNat]
@[simp] theorem UInt8.ofInt_add (x y : Int) : ofInt (x + y) = ofInt x + ofInt y := by
dsimp only [UInt8.ofInt]
rw [Int.add_emod]
have h₁ : 0 x % 2 ^ 8 := Int.emod_nonneg _ (by decide)
have h₂ : 0 y % 2 ^ 8 := Int.emod_nonneg _ (by decide)
have h₃ : 0 x % 2 ^ 8 + y % 2 ^ 8 := Int.add_nonneg h₁ h₂
rw [Int.toNat_emod h₃ (by decide), Int.toNat_add h₁ h₂]
have : (2 ^ 8 : Int).toNat = 2 ^ 8 := rfl
rw [this, UInt8.ofNat_mod_size, UInt8.ofNat_add]
@[simp] theorem UInt16.ofInt_add (x y : Int) : UInt16.ofInt (x + y) = UInt16.ofInt x + UInt16.ofInt y := by
dsimp only [UInt16.ofInt]
rw [Int.add_emod]
have h₁ : 0 x % 2 ^ 16 := Int.emod_nonneg _ (by decide)
have h₂ : 0 y % 2 ^ 16 := Int.emod_nonneg _ (by decide)
have h₃ : 0 x % 2 ^ 16 + y % 2 ^ 16 := Int.add_nonneg h₁ h₂
rw [Int.toNat_emod h₃ (by decide), Int.toNat_add h₁ h₂]
have : (2 ^ 16 : Int).toNat = 2 ^ 16 := rfl
rw [this, UInt16.ofNat_mod_size, UInt16.ofNat_add]
@[simp] theorem UInt32.ofInt_add (x y : Int) : UInt32.ofInt (x + y) = UInt32.ofInt x + UInt32.ofInt y := by
dsimp only [UInt32.ofInt]
rw [Int.add_emod]
have h₁ : 0 x % 2 ^ 32 := Int.emod_nonneg _ (by decide)
have h₂ : 0 y % 2 ^ 32 := Int.emod_nonneg _ (by decide)
have h₃ : 0 x % 2 ^ 32 + y % 2 ^ 32 := Int.add_nonneg h₁ h₂
rw [Int.toNat_emod h₃ (by decide), Int.toNat_add h₁ h₂]
have : (2 ^ 32 : Int).toNat = 2 ^ 32 := rfl
rw [this, UInt32.ofNat_mod_size, UInt32.ofNat_add]
@[simp] theorem UInt64.ofInt_add (x y : Int) : UInt64.ofInt (x + y) = UInt64.ofInt x + UInt64.ofInt y := by
dsimp only [UInt64.ofInt]
rw [Int.add_emod]
have h₁ : 0 x % 2 ^ 64 := Int.emod_nonneg _ (by decide)
have h₂ : 0 y % 2 ^ 64 := Int.emod_nonneg _ (by decide)
have h₃ : 0 x % 2 ^ 64 + y % 2 ^ 64 := Int.add_nonneg h₁ h₂
rw [Int.toNat_emod h₃ (by decide), Int.toNat_add h₁ h₂]
have : (2 ^ 64 : Int).toNat = 2 ^ 64 := rfl
rw [this, UInt64.ofNat_mod_size, UInt64.ofNat_add]
namespace System.Platform
theorem two_pow_numBits_nonneg : 0 (2 ^ System.Platform.numBits : Int) := by
rcases System.Platform.numBits_eq with h | h <;>
· rw [h]
decide
theorem two_pow_numBits_ne_zero : (2 ^ System.Platform.numBits : Int) 0 := by
rcases System.Platform.numBits_eq with h | h <;>
· rw [h]
decide
end System.Platform
open System.Platform in
@[simp] theorem USize.ofInt_add (x y : Int) : USize.ofInt (x + y) = USize.ofInt x + USize.ofInt y := by
dsimp only [USize.ofInt]
rw [Int.add_emod]
have h₁ : 0 x % 2 ^ numBits := Int.emod_nonneg _ two_pow_numBits_ne_zero
have h₂ : 0 y % 2 ^ numBits := Int.emod_nonneg _ two_pow_numBits_ne_zero
have h₃ : 0 x % 2 ^ numBits + y % 2 ^ numBits := Int.add_nonneg h₁ h₂
rw [Int.toNat_emod h₃ two_pow_numBits_nonneg, Int.toNat_add h₁ h₂]
have : (2 ^ numBits : Int).toNat = 2 ^ numBits := by
rcases System.Platform.numBits_eq with h | h <;>
· rw [h]
decide
rw [this, USize.ofNat_mod_size, USize.ofNat_add]
@[simp] theorem UInt8.ofNatLT_add {a b : Nat} (hab : a + b < 2 ^ 8) :
UInt8.ofNatLT (a + b) hab = UInt8.ofNatLT a (Nat.lt_of_add_right_lt hab) + UInt8.ofNatLT b (Nat.lt_of_add_left_lt hab) := by
simp [UInt8.ofNatLT_eq_ofNat]
@@ -2176,6 +2268,56 @@ theorem USize.ofNatLT_sub {a b : Nat} (ha : a < 2 ^ System.Platform.numBits) (ha
@[simp] theorem USize.ofNat_mul (a b : Nat) : USize.ofNat (a * b) = USize.ofNat a * USize.ofNat b := by
simp [USize.ofNat_eq_iff_mod_eq_toNat]
@[simp] theorem UInt8.ofInt_mul (x y : Int) : ofInt (x * y) = ofInt x * ofInt y := by
dsimp only [UInt8.ofInt]
rw [Int.mul_emod]
have h₁ : 0 x % 2 ^ 8 := Int.emod_nonneg _ (by decide)
have h₂ : 0 y % 2 ^ 8 := Int.emod_nonneg _ (by decide)
have h₃ : 0 (x % 2 ^ 8) * (y % 2 ^ 8) := Int.mul_nonneg h₁ h₂
rw [Int.toNat_emod h₃ (by decide), Int.toNat_mul h₁ h₂]
have : (2 ^ 8 : Int).toNat = 2 ^ 8 := rfl
rw [this, UInt8.ofNat_mod_size, UInt8.ofNat_mul]
@[simp] theorem UInt16.ofInt_mul (x y : Int) : ofInt (x * y) = ofInt x * ofInt y := by
dsimp only [UInt16.ofInt]
rw [Int.mul_emod]
have h₁ : 0 x % 2 ^ 16 := Int.emod_nonneg _ (by decide)
have h₂ : 0 y % 2 ^ 16 := Int.emod_nonneg _ (by decide)
have h₃ : 0 (x % 2 ^ 16) * (y % 2 ^ 16) := Int.mul_nonneg h₁ h₂
rw [Int.toNat_emod h₃ (by decide), Int.toNat_mul h₁ h₂]
have : (2 ^ 16 : Int).toNat = 2 ^ 16 := rfl
rw [this, UInt16.ofNat_mod_size, UInt16.ofNat_mul]
@[simp] theorem UInt32.ofInt_mul (x y : Int) : ofInt (x * y) = ofInt x * ofInt y := by
dsimp only [UInt32.ofInt]
rw [Int.mul_emod]
have h₁ : 0 x % 2 ^ 32 := Int.emod_nonneg _ (by decide)
have h₂ : 0 y % 2 ^ 32 := Int.emod_nonneg _ (by decide)
have h₃ : 0 (x % 2 ^ 32) * (y % 2 ^ 32) := Int.mul_nonneg h₁ h₂
rw [Int.toNat_emod h₃ (by decide), Int.toNat_mul h₁ h₂]
have : (2 ^ 32 : Int).toNat = 2 ^ 32 := rfl
rw [this, UInt32.ofNat_mod_size, UInt32.ofNat_mul]
@[simp] theorem UInt64.ofInt_mul (x y : Int) : ofInt (x * y) = ofInt x * ofInt y := by
dsimp only [UInt64.ofInt]
rw [Int.mul_emod]
have h₁ : 0 x % 2 ^ 64 := Int.emod_nonneg _ (by decide)
have h₂ : 0 y % 2 ^ 64 := Int.emod_nonneg _ (by decide)
have h₃ : 0 (x % 2 ^ 64) * (y % 2 ^ 64) := Int.mul_nonneg h₁ h₂
rw [Int.toNat_emod h₃ (by decide), Int.toNat_mul h₁ h₂]
have : (2 ^ 64 : Int).toNat = 2 ^ 64 := rfl
rw [this, UInt64.ofNat_mod_size, UInt64.ofNat_mul]
open System.Platform in
@[simp] theorem USize.ofInt_mul (x y : Int) : ofInt (x * y) = ofInt x * ofInt y := by
dsimp only [USize.ofInt]
rw [Int.mul_emod]
have h₁ : 0 x % 2 ^ numBits := Int.emod_nonneg _ two_pow_numBits_ne_zero
have h₂ : 0 y % 2 ^ numBits := Int.emod_nonneg _ two_pow_numBits_ne_zero
have h₃ : 0 (x % 2 ^ numBits) * (y % 2 ^ numBits) := Int.mul_nonneg h₁ h₂
rw [Int.toNat_emod h₃ two_pow_numBits_nonneg, Int.toNat_mul h₁ h₂]
have : (2 ^ numBits : Int).toNat = 2 ^ numBits := by
rcases System.Platform.numBits_eq with h | h <;>
· rw [h]
decide
rw [this, USize.ofNat_mod_size, USize.ofNat_mul]
@[simp] theorem UInt8.ofNatLT_mul {a b : Nat} (ha : a < 2 ^ 8) (hb : b < 2 ^ 8) (hab : a * b < 2 ^ 8) :
UInt8.ofNatLT (a * b) hab = UInt8.ofNatLT a ha * UInt8.ofNatLT b hb := by
simp [UInt8.ofNatLT_eq_ofNat]
@@ -2467,6 +2609,17 @@ protected theorem USize.neg_add {a b : USize} : - (a + b) = -a - b := USize.toBi
@[simp] protected theorem USize.neg_sub {a b : USize} : -(a - b) = b - a := by
rw [USize.sub_eq_add_neg, USize.neg_add, USize.sub_neg, USize.add_comm, USize.sub_eq_add_neg]
@[simp] protected theorem UInt8.ofInt_neg (x : Int) : ofInt (-x) = -ofInt x := by
rw [Int.neg_eq_neg_one_mul, ofInt_mul, ofInt_neg_one, UInt8.neg_eq_neg_one_mul]
@[simp] protected theorem UInt16.ofInt_neg (x : Int) : ofInt (-x) = -ofInt x := by
rw [Int.neg_eq_neg_one_mul, ofInt_mul, ofInt_neg_one, UInt16.neg_eq_neg_one_mul]
@[simp] protected theorem UInt32.ofInt_neg (x : Int) : ofInt (-x) = -ofInt x := by
rw [Int.neg_eq_neg_one_mul, ofInt_mul, ofInt_neg_one, UInt32.neg_eq_neg_one_mul]
@[simp] protected theorem UInt64.ofInt_neg (x : Int) : ofInt (-x) = -ofInt x := by
rw [Int.neg_eq_neg_one_mul, ofInt_mul, ofInt_neg_one, UInt64.neg_eq_neg_one_mul]
@[simp] protected theorem USize.ofInt_neg (x : Int) : ofInt (-x) = -ofInt x := by
rw [Int.neg_eq_neg_one_mul, ofInt_mul, ofInt_neg_one, USize.neg_eq_neg_one_mul]
@[simp] protected theorem UInt8.add_left_inj {a b : UInt8} (c : UInt8) : (a + c = b + c) a = b := by
simp [ UInt8.toBitVec_inj]
@[simp] protected theorem UInt16.add_left_inj {a b : UInt16} (c : UInt16) : (a + c = b + c) a = b := by

View File

@@ -239,4 +239,15 @@ theorem count_flatMap {α} [BEq β] {xs : Vector α n} {f : α → Vector β m}
rcases xs with xs, rfl
simp [Array.count_flatMap, Function.comp_def]
theorem countP_replace {a b : α} {xs : Vector α n} {p : α Bool} :
(xs.replace a b).countP p =
if xs.contains a then xs.countP p + (if p b then 1 else 0) - (if p a then 1 else 0) else xs.countP p := by
rcases xs with xs, rfl
simp [Array.countP_replace]
theorem count_replace {a b c : α} {xs : Vector α n} :
(xs.replace a b).count c =
if xs.contains a then xs.count c + (if b == c then 1 else 0) - (if a == c then 1 else 0) else xs.count c := by
simp [count_eq_countP, countP_replace]
end count

View File

@@ -294,6 +294,18 @@ theorem find?_eq_some_iff_getElem {xs : Vector α n} {p : α → Bool} {b : α}
subst w
simp
@[simp]
theorem isSome_findFinIdx? {xs : Vector α n} {p : α Bool} :
(xs.findFinIdx? p).isSome = xs.any p := by
rcases xs with xs, rfl
simp
@[simp]
theorem isNone_findFinIdx? {xs : Vector α n} {p : α Bool} :
(xs.findFinIdx? p).isNone = xs.all (fun x => ¬ p x) := by
rcases xs with xs, rfl
simp
@[simp] theorem findFinIdx?_subtype {p : α Prop} {xs : Vector { x // p x } n}
{f : { x // p x } Bool} {g : α Bool} (hf : x h, f x, h = g x) :
xs.findFinIdx? f = xs.unattach.findFinIdx? g := by

View File

@@ -1511,7 +1511,7 @@ theorem map_eq_iff {f : α → β} {as : Vector α n} {bs : Vector β n} :
if h : i < as.size then
simpa [h, h'] using w i h
else
rw [getElem?_neg, getElem?_neg, Option.map_none'] <;> omega
rw [getElem?_neg, getElem?_neg, Option.map_none] <;> omega
@[simp] theorem map_set {f : α β} {xs : Vector α n} {i : Nat} {h : i < n} {a : α} :
(xs.set i a).map f = (xs.map f).set i (f a) (by simpa using h) := by

View File

@@ -5,14 +5,30 @@ Authors: Kim Morrison
-/
prelude
import Init.Data.Zero
import Init.Data.Int.DivMod.Lemmas
import Init.TacticsExtra
/-!
# A monolithic commutative ring typeclass for internal use in `grind`.
The `Lean.Grind.CommRing` class will be used to convert expressions into the internal representation via polynomials,
with coefficients expressed via `OfNat` and `Neg`.
The `IsCharP α p` typeclass expresses that the ring has characteristic `p`,
i.e. that a coefficient `OfNat.ofNat x : α` is zero if and only if `x % p = 0` (in `Nat`).
See
```
theorem ofNat_ext_iff {x y : Nat} : OfNat.ofNat (α := α) x = OfNat.ofNat (α := α) y ↔ x % p = y % p
theorem ofNat_emod (x : Nat) : OfNat.ofNat (α := α) (x % p) = OfNat.ofNat x
theorem ofNat_eq_iff_of_lt {x y : Nat} (h₁ : x < p) (h₂ : y < p) :
OfNat.ofNat (α := α) x = OfNat.ofNat (α := α) y ↔ x = y
```
-/
namespace Lean.Grind
class CommRing (α : Type u) extends Add α, Zero α, Mul α, One α, Neg α where
class CommRing (α : Type u) extends Add α, Mul α, Neg α, Sub α, HPow α Nat α where
[ofNat : n, OfNat α n]
add_assoc : a b c : α, a + b + c = a + (b + c)
add_comm : a b : α, a + b = b + a
add_zero : a : α, a + 0 = a
@@ -22,11 +38,31 @@ class CommRing (α : Type u) extends Add α, Zero α, Mul α, One α, Neg α whe
mul_one : a : α, a * 1 = a
left_distrib : a b c : α, a * (b + c) = a * b + a * c
zero_mul : a : α, 0 * a = 0
sub_eq_add_neg : a b : α, a - b = a + -b
pow_zero : a : α, a ^ 0 = 1
pow_succ : a : α, n : Nat, a ^ (n + 1) = (a ^ n) * a
ofNat_succ : a : Nat, OfNat.ofNat (α := α) (a + 1) = OfNat.ofNat a + 1 := by intros; rfl
-- This is a low-priority instance, to avoid conflicts with existing `OfNat` instances.
attribute [instance 100] CommRing.ofNat
namespace CommRing
variable {α : Type u} [CommRing α]
instance : NatCast α where
natCast n := OfNat.ofNat n
theorem natCast_zero : ((0 : Nat) : α) = 0 := rfl
theorem ofNat_eq_natCast (n : Nat) : OfNat.ofNat n = (n : α) := rfl
theorem ofNat_add (a b : Nat) : OfNat.ofNat (α := α) (a + b) = OfNat.ofNat a + OfNat.ofNat b := by
induction b with
| zero => simp [Nat.add_zero, add_zero]
| succ b ih => rw [Nat.add_succ, ofNat_succ, ih, ofNat_succ b, add_assoc]
theorem natCast_succ (n : Nat) : ((n + 1 : Nat) : α) = ((n : α) + 1) := ofNat_add _ _
theorem zero_add (a : α) : 0 + a = a := by
rw [add_comm, add_zero]
@@ -42,6 +78,204 @@ theorem right_distrib (a b c : α) : (a + b) * c = a * c + b * c := by
theorem mul_zero (a : α) : a * 0 = 0 := by
rw [mul_comm, zero_mul]
theorem ofNat_mul (a b : Nat) : OfNat.ofNat (α := α) (a * b) = OfNat.ofNat a * OfNat.ofNat b := by
induction b with
| zero => simp [Nat.mul_zero, mul_zero]
| succ a ih => rw [Nat.mul_succ, ofNat_add, ih, ofNat_add, left_distrib, mul_one]
theorem add_left_inj {a b : α} (c : α) : a + c = b + c a = b :=
fun h => by simpa [add_assoc, add_neg_cancel, add_zero] using (congrArg (· + -c) h),
fun g => congrArg (· + c) g
theorem add_right_inj (a b c : α) : a + b = a + c b = c := by
rw [add_comm a b, add_comm a c, add_left_inj]
theorem neg_zero : (-0 : α) = 0 := by
rw [ add_left_inj 0, neg_add_cancel, add_zero]
theorem neg_neg (a : α) : -(-a) = a := by
rw [ add_left_inj (-a), neg_add_cancel, add_neg_cancel]
theorem neg_eq_zero (a : α) : -a = 0 a = 0 :=
fun h => by
replace h := congrArg (-·) h
simpa [neg_neg, neg_zero] using h,
fun h => by rw [h, neg_zero]
theorem neg_add (a b : α) : -(a + b) = -a + -b := by
rw [ add_left_inj (a + b), neg_add_cancel, add_assoc (-a), add_comm a b, add_assoc (-b),
neg_add_cancel, zero_add, neg_add_cancel]
theorem neg_sub (a b : α) : -(a - b) = b - a := by
rw [sub_eq_add_neg, neg_add, neg_neg, sub_eq_add_neg, add_comm]
theorem sub_self (a : α) : a - a = 0 := by
rw [sub_eq_add_neg, add_neg_cancel]
instance : IntCast α where
intCast n := match n with
| Int.ofNat n => OfNat.ofNat n
| Int.negSucc n => -OfNat.ofNat (n + 1)
theorem intCast_zero : ((0 : Int) : α) = 0 := rfl
theorem intCast_one : ((1 : Int) : α) = 1 := rfl
theorem intCast_neg_one : ((-1 : Int) : α) = -1 := rfl
theorem intCast_ofNat (n : Nat) : ((n : Int) : α) = (n : α) := rfl
theorem intCast_ofNat_add_one (n : Nat) : ((n + 1 : Int) : α) = (n : α) + 1 := ofNat_add _ _
theorem intCast_negSucc (n : Nat) : ((-(n + 1) : Int) : α) = -((n : α) + 1) := congrArg (- ·) (ofNat_add _ _)
theorem intCast_neg (x : Int) : ((-x : Int) : α) = - (x : α) :=
match x with
| (0 : Nat) => neg_zero.symm
| (n + 1 : Nat) => by
rw [Int.natCast_add, Int.cast_ofNat_Int, intCast_negSucc, intCast_ofNat_add_one]
| -((n : Nat) + 1) => by
rw [Int.neg_neg, intCast_ofNat_add_one, intCast_negSucc, neg_neg]
theorem intCast_nat_add {x y : Nat} : ((x + y : Int) : α) = ((x : α) + (y : α)) := ofNat_add _ _
theorem intCast_nat_sub {x y : Nat} (h : x y) : (((x - y : Nat) : Int) : α) = ((x : α) - (y : α)) := by
induction x with
| zero =>
have : y = 0 := by omega
simp [this, intCast_zero, natCast_zero, sub_eq_add_neg, zero_add, neg_zero]
| succ x ih =>
by_cases h : x + 1 = y
· simp [h, intCast_zero, sub_self]
· have : ((x + 1 - y : Nat) : Int) = (x - y : Nat) + 1 := by omega
rw [this, intCast_ofNat_add_one]
specialize ih (by omega)
rw [intCast_ofNat] at ih
rw [ih, natCast_succ, sub_eq_add_neg, sub_eq_add_neg, add_assoc, add_comm _ 1, add_assoc]
theorem intCast_add (x y : Int) : ((x + y : Int) : α) = ((x : α) + (y : α)) :=
match x, y with
| (x : Nat), (y : Nat) => ofNat_add _ _
| (x : Nat), (-(y + 1 : Nat)) => by
by_cases h : x y + 1
· have : (x + -(y+1 : Nat) : Int) = ((x - (y + 1) : Nat) : Int) := by omega
rw [this, intCast_neg, intCast_nat_sub h, intCast_ofNat, intCast_ofNat, sub_eq_add_neg]
· have : (x + -(y+1 : Nat) : Int) = (-(y + 1 - x : Nat) : Int) := by omega
rw [this, intCast_neg, intCast_nat_sub (by omega), intCast_ofNat, intCast_neg, intCast_ofNat,
neg_sub, sub_eq_add_neg]
| (-(x + 1 : Nat)), (y : Nat) => by
by_cases h : y x+ 1
· have : (-(x+1 : Nat) + y : Int) = ((y - (x + 1) : Nat) : Int) := by omega
rw [this, intCast_neg, intCast_nat_sub h, intCast_ofNat, intCast_ofNat, sub_eq_add_neg, add_comm]
· have : (-(x+1 : Nat) + y : Int) = (-(x + 1 - y : Nat) : Int) := by omega
rw [this, intCast_neg, intCast_nat_sub (by omega), intCast_ofNat, intCast_neg, intCast_ofNat,
neg_sub, sub_eq_add_neg, add_comm]
| (-(x + 1 : Nat)), (-(y + 1 : Nat)) => by
rw [ Int.neg_add, intCast_neg, intCast_nat_add, neg_add, intCast_neg, intCast_neg, intCast_ofNat, intCast_ofNat]
theorem intCast_sub (x y : Int) : ((x - y : Int) : α) = ((x : α) - (y : α)) := by
rw [Int.sub_eq_add_neg, intCast_add, intCast_neg, sub_eq_add_neg]
theorem neg_eq_neg_one_mul (a : α) : -a = (-1) * a := by
rw [ add_left_inj a, neg_add_cancel]
conv => rhs; arg 2; rw [ one_mul a]
rw [ right_distrib, intCast_neg_one, intCast_one (α := α)]
simp [ intCast_add, intCast_zero, zero_mul]
theorem neg_mul (a b : α) : (-a) * b = -(a * b) := by
rw [neg_eq_neg_one_mul a, neg_eq_neg_one_mul (a * b), mul_assoc]
theorem mul_neg (a b : α) : a * (-b) = -(a * b) := by
rw [mul_comm, neg_mul, mul_comm]
theorem intCast_nat_mul (x y : Nat) : ((x * y : Int) : α) = ((x : α) * (y : α)) := ofNat_mul _ _
theorem intCast_mul (x y : Int) : ((x * y : Int) : α) = ((x : α) * (y : α)) :=
match x, y with
| (x : Nat), (y : Nat) => ofNat_mul _ _
| (x : Nat), (-(y + 1 : Nat)) => by
rw [Int.mul_neg, intCast_neg, intCast_nat_mul, intCast_neg, mul_neg, intCast_ofNat, intCast_ofNat]
| (-(x + 1 : Nat)), (y : Nat) => by
rw [Int.neg_mul, intCast_neg, intCast_nat_mul, intCast_neg, neg_mul, intCast_ofNat, intCast_ofNat]
| (-(x + 1 : Nat)), (-(y + 1 : Nat)) => by
rw [Int.neg_mul_neg, intCast_neg, intCast_neg, neg_mul, mul_neg, neg_neg, intCast_nat_mul,
intCast_ofNat, intCast_ofNat]
end CommRing
open CommRing
class IsCharP (α : Type u) [CommRing α] (p : Nat) where
ofNat_eq_zero_iff (p) : (x : Nat), OfNat.ofNat (α := α) x = 0 x % p = 0
namespace IsCharP
variable (p) {α : Type u} [CommRing α] [IsCharP α p]
theorem natCast_eq_zero_iff (x : Nat) : (x : α) = 0 x % p = 0 :=
ofNat_eq_zero_iff p x
theorem intCast_eq_zero_iff (x : Int) : (x : α) = 0 x % p = 0 :=
match x with
| (x : Nat) => by
have := ofNat_eq_zero_iff (α := α) p (x := x)
rw [Int.ofNat_mod_ofNat]
norm_cast
| -(x + 1 : Nat) => by
rw [Int.neg_emod, Int.ofNat_mod_ofNat, intCast_neg, intCast_ofNat, neg_eq_zero]
have := ofNat_eq_zero_iff (α := α) p (x := x + 1)
rw [ofNat_eq_natCast] at this
rw [this]
simp only [Int.ofNat_dvd]
simp only [ Nat.dvd_iff_mod_eq_zero, Int.natAbs_ofNat, Int.natCast_add,
Int.cast_ofNat_Int, ite_eq_left_iff]
by_cases h : p x + 1
· simp [h]
· simp only [h, not_false_eq_true, Int.natCast_add, Int.cast_ofNat_Int,
forall_const, false_iff, ne_eq]
by_cases w : p = 0
· simp [w]
omega
· have : ((x + 1) % p) < p := Nat.mod_lt _ (by omega)
omega
theorem intCast_ext_iff {x y : Int} : (x : α) = (y : α) x % p = y % p := by
constructor
· intro h
replace h : ((x - y : Int) : α) = 0 := by rw [intCast_sub, h, sub_self]
exact Int.emod_eq_emod_iff_emod_sub_eq_zero.mpr ((intCast_eq_zero_iff p _).mp h)
· intro h
have : ((x - y : Int) : α) = 0 :=
(intCast_eq_zero_iff p _).mpr (by rw [Int.sub_emod, h, Int.sub_self, Int.zero_emod])
replace this := congrArg (· + (y : α)) this
simpa [intCast_sub, zero_add, sub_eq_add_neg, add_assoc, neg_add_cancel, add_zero] using this
theorem ofNat_ext_iff {x y : Nat} : OfNat.ofNat (α := α) x = OfNat.ofNat (α := α) y x % p = y % p := by
have := intCast_ext_iff (α := α) p (x := x) (y := y)
simp only [intCast_ofNat, Int.ofNat_emod] at this
simp only [ofNat_eq_natCast]
norm_cast at this
theorem ofNat_ext {x y : Nat} (h : x % p = y % p) : OfNat.ofNat (α := α) x = OfNat.ofNat (α := α) y := (ofNat_ext_iff p).mpr h
theorem natCast_ext {x y : Nat} (h : x % p = y % p) : (x : α) = (y : α) := ofNat_ext _ h
theorem natCast_ext_iff {x y : Nat} : (x : α) = (y : α) x % p = y % p :=
ofNat_ext_iff p
theorem intCast_emod (x : Int) : ((x % p : Int) : α) = (x : α) := by
rw [intCast_ext_iff p, Int.emod_emod]
theorem natCast_emod (x : Nat) : ((x % p : Nat) : α) = (x : α) := by
simp only [ intCast_ofNat]
rw [Int.ofNat_emod, intCast_emod]
theorem ofNat_emod (x : Nat) : OfNat.ofNat (α := α) (x % p) = OfNat.ofNat x :=
natCast_emod _ _
theorem ofNat_eq_zero_iff_of_lt {x : Nat} (h : x < p) : OfNat.ofNat (α := α) x = 0 x = 0 := by
rw [ofNat_eq_zero_iff p, Nat.mod_eq_of_lt h]
theorem ofNat_eq_iff_of_lt {x y : Nat} (h₁ : x < p) (h₂ : y < p) :
OfNat.ofNat (α := α) x = OfNat.ofNat (α := α) y x = y := by
rw [ofNat_ext_iff p, Nat.mod_eq_of_lt h₁, Nat.mod_eq_of_lt h₂]
theorem natCast_eq_zero_iff_of_lt {x : Nat} (h : x < p) : (x : α) = 0 x = 0 := by
rw [natCast_eq_zero_iff p, Nat.mod_eq_of_lt h]
theorem natCast_eq_iff_of_lt {x y : Nat} (h₁ : x < p) (h₂ : y < p) :
(x : α) = (y : α) x = y := by
rw [natCast_ext_iff p, Nat.mod_eq_of_lt h₁, Nat.mod_eq_of_lt h₂]
end IsCharP
end Lean.Grind

View File

@@ -19,5 +19,12 @@ instance : CommRing (BitVec w) where
mul_one := BitVec.mul_one
left_distrib _ _ _ := BitVec.mul_add
zero_mul _ := BitVec.zero_mul
sub_eq_add_neg := BitVec.sub_eq_add_neg
pow_zero _ := BitVec.pow_zero
pow_succ _ _ := BitVec.pow_succ
ofNat_succ x := BitVec.ofNat_add x 1
instance : IsCharP (BitVec w) (2 ^ w) where
ofNat_eq_zero_iff {x} := by simp [BitVec.ofInt, BitVec.toNat_eq]
end Lean.Grind

View File

@@ -19,5 +19,12 @@ instance : CommRing Int where
mul_one := Int.mul_one
left_distrib := Int.mul_add
zero_mul := Int.zero_mul
pow_zero _ := rfl
pow_succ _ _ := rfl
ofNat_succ _ := rfl
sub_eq_add_neg _ _ := Int.sub_eq_add_neg
instance : IsCharP Int 0 where
ofNat_eq_zero_iff {x} := by erw [Int.ofNat_eq_zero]; simp
end Lean.Grind

View File

@@ -9,6 +9,9 @@ import Init.Data.SInt.Lemmas
namespace Lean.Grind
instance : IntCast Int8 where
intCast x := Int8.ofInt x
instance : CommRing Int8 where
add_assoc := Int8.add_assoc
add_comm := Int8.add_comm
@@ -19,6 +22,20 @@ instance : CommRing Int8 where
mul_one := Int8.mul_one
left_distrib _ _ _ := Int8.mul_add
zero_mul _ := Int8.zero_mul
sub_eq_add_neg := Int8.sub_eq_add_neg
pow_zero := Int8.pow_zero
pow_succ := Int8.pow_succ
ofNat_succ x := Int8.ofNat_add x 1
instance : IsCharP Int8 (2 ^ 8) where
ofNat_eq_zero_iff {x} := by
have : OfNat.ofNat x = Int8.ofInt x := rfl
rw [this]
simp [Int8.ofInt_eq_iff_bmod_eq_toInt,
Int.dvd_iff_bmod_eq_zero, Nat.dvd_iff_mod_eq_zero, Int.ofNat_dvd_right]
instance : IntCast Int16 where
intCast x := Int16.ofInt x
instance : CommRing Int16 where
add_assoc := Int16.add_assoc
@@ -30,6 +47,20 @@ instance : CommRing Int16 where
mul_one := Int16.mul_one
left_distrib _ _ _ := Int16.mul_add
zero_mul _ := Int16.zero_mul
sub_eq_add_neg := Int16.sub_eq_add_neg
pow_zero := Int16.pow_zero
pow_succ := Int16.pow_succ
ofNat_succ x := Int16.ofNat_add x 1
instance : IsCharP Int16 (2 ^ 16) where
ofNat_eq_zero_iff {x} := by
have : OfNat.ofNat x = Int16.ofInt x := rfl
rw [this]
simp [Int16.ofInt_eq_iff_bmod_eq_toInt,
Int.dvd_iff_bmod_eq_zero, Nat.dvd_iff_mod_eq_zero, Int.ofNat_dvd_right]
instance : IntCast Int32 where
intCast x := Int32.ofInt x
instance : CommRing Int32 where
add_assoc := Int32.add_assoc
@@ -41,6 +72,20 @@ instance : CommRing Int32 where
mul_one := Int32.mul_one
left_distrib _ _ _ := Int32.mul_add
zero_mul _ := Int32.zero_mul
sub_eq_add_neg := Int32.sub_eq_add_neg
pow_zero := Int32.pow_zero
pow_succ := Int32.pow_succ
ofNat_succ x := Int32.ofNat_add x 1
instance : IsCharP Int32 (2 ^ 32) where
ofNat_eq_zero_iff {x} := by
have : OfNat.ofNat x = Int32.ofInt x := rfl
rw [this]
simp [Int32.ofInt_eq_iff_bmod_eq_toInt,
Int.dvd_iff_bmod_eq_zero, Nat.dvd_iff_mod_eq_zero, Int.ofNat_dvd_right]
instance : IntCast Int64 where
intCast x := Int64.ofInt x
instance : CommRing Int64 where
add_assoc := Int64.add_assoc
@@ -52,6 +97,20 @@ instance : CommRing Int64 where
mul_one := Int64.mul_one
left_distrib _ _ _ := Int64.mul_add
zero_mul _ := Int64.zero_mul
sub_eq_add_neg := Int64.sub_eq_add_neg
pow_zero := Int64.pow_zero
pow_succ := Int64.pow_succ
ofNat_succ x := Int64.ofNat_add x 1
instance : IsCharP Int64 (2 ^ 64) where
ofNat_eq_zero_iff {x} := by
have : OfNat.ofNat x = Int64.ofInt x := rfl
rw [this]
simp [Int64.ofInt_eq_iff_bmod_eq_toInt,
Int.dvd_iff_bmod_eq_zero, Nat.dvd_iff_mod_eq_zero, Int.ofNat_dvd_right]
instance : IntCast ISize where
intCast x := ISize.ofInt x
instance : CommRing ISize where
add_assoc := ISize.add_assoc
@@ -63,5 +122,18 @@ instance : CommRing ISize where
mul_one := ISize.mul_one
left_distrib _ _ _ := ISize.mul_add
zero_mul _ := ISize.zero_mul
sub_eq_add_neg := ISize.sub_eq_add_neg
pow_zero := ISize.pow_zero
pow_succ := ISize.pow_succ
ofNat_succ x := ISize.ofNat_add x 1
open System.Platform (numBits)
instance : IsCharP ISize (2 ^ numBits) where
ofNat_eq_zero_iff {x} := by
have : OfNat.ofNat x = ISize.ofInt x := rfl
rw [this]
simp [ISize.ofInt_eq_iff_bmod_eq_toInt,
Int.dvd_iff_bmod_eq_zero, Nat.dvd_iff_mod_eq_zero, Int.ofNat_dvd_right]
end Lean.Grind

View File

@@ -7,6 +7,53 @@ prelude
import Init.Grind.CommRing.Basic
import Init.Data.UInt.Lemmas
namespace UInt8
/-- Variant of `UInt8.ofNat_mod_size` replacing `2 ^ 8` with `256`.-/
theorem ofNat_mod_size' : ofNat (x % 256) = ofNat x := ofNat_mod_size
instance : IntCast UInt8 where
intCast x := UInt8.ofInt x
end UInt8
namespace UInt16
/-- Variant of `UInt16.ofNat_mod_size` replacing `2 ^ 16` with `65536`.-/
theorem ofNat_mod_size' : ofNat (x % 65536) = ofNat x := ofNat_mod_size
instance : IntCast UInt16 where
intCast x := UInt16.ofInt x
end UInt16
namespace UInt32
/-- Variant of `UInt32.ofNat_mod_size` replacing `2 ^ 32` with `4294967296`.-/
theorem ofNat_mod_size' : ofNat (x % 4294967296) = ofNat x := ofNat_mod_size
instance : IntCast UInt32 where
intCast x := UInt32.ofInt x
end UInt32
namespace UInt64
/-- Variant of `UInt64.ofNat_mod_size` replacing `2 ^ 64` with `18446744073709551616`.-/
theorem ofNat_mod_size' : ofNat (x % 18446744073709551616) = ofNat x := ofNat_mod_size
instance : IntCast UInt64 where
intCast x := UInt64.ofInt x
end UInt64
namespace USize
instance : IntCast USize where
intCast x := USize.ofInt x
end USize
namespace Lean.Grind
instance : CommRing UInt8 where
@@ -19,6 +66,15 @@ instance : CommRing UInt8 where
mul_one := UInt8.mul_one
left_distrib _ _ _ := UInt8.mul_add
zero_mul _ := UInt8.zero_mul
sub_eq_add_neg := UInt8.sub_eq_add_neg
pow_zero := UInt8.pow_zero
pow_succ := UInt8.pow_succ
ofNat_succ x := UInt8.ofNat_add x 1
instance : IsCharP UInt8 (2 ^ 8) where
ofNat_eq_zero_iff {x} := by
have : OfNat.ofNat x = UInt8.ofNat x := rfl
simp [this, UInt8.ofNat_eq_iff_mod_eq_toNat]
instance : CommRing UInt16 where
add_assoc := UInt16.add_assoc
@@ -30,6 +86,15 @@ instance : CommRing UInt16 where
mul_one := UInt16.mul_one
left_distrib _ _ _ := UInt16.mul_add
zero_mul _ := UInt16.zero_mul
sub_eq_add_neg := UInt16.sub_eq_add_neg
pow_zero := UInt16.pow_zero
pow_succ := UInt16.pow_succ
ofNat_succ x := UInt16.ofNat_add x 1
instance : IsCharP UInt16 (2 ^ 16) where
ofNat_eq_zero_iff {x} := by
have : OfNat.ofNat x = UInt16.ofNat x := rfl
simp [this, UInt16.ofNat_eq_iff_mod_eq_toNat]
instance : CommRing UInt32 where
add_assoc := UInt32.add_assoc
@@ -41,6 +106,15 @@ instance : CommRing UInt32 where
mul_one := UInt32.mul_one
left_distrib _ _ _ := UInt32.mul_add
zero_mul _ := UInt32.zero_mul
sub_eq_add_neg := UInt32.sub_eq_add_neg
pow_zero := UInt32.pow_zero
pow_succ := UInt32.pow_succ
ofNat_succ x := UInt32.ofNat_add x 1
instance : IsCharP UInt32 (2 ^ 32) where
ofNat_eq_zero_iff {x} := by
have : OfNat.ofNat x = UInt32.ofNat x := rfl
simp [this, UInt32.ofNat_eq_iff_mod_eq_toNat]
instance : CommRing UInt64 where
add_assoc := UInt64.add_assoc
@@ -52,6 +126,15 @@ instance : CommRing UInt64 where
mul_one := UInt64.mul_one
left_distrib _ _ _ := UInt64.mul_add
zero_mul _ := UInt64.zero_mul
sub_eq_add_neg := UInt64.sub_eq_add_neg
pow_zero := UInt64.pow_zero
pow_succ := UInt64.pow_succ
ofNat_succ x := UInt64.ofNat_add x 1
instance : IsCharP UInt64 (2 ^ 64) where
ofNat_eq_zero_iff {x} := by
have : OfNat.ofNat x = UInt64.ofNat x := rfl
simp [this, UInt64.ofNat_eq_iff_mod_eq_toNat]
instance : CommRing USize where
add_assoc := USize.add_assoc
@@ -63,5 +146,16 @@ instance : CommRing USize where
mul_one := USize.mul_one
left_distrib _ _ _ := USize.mul_add
zero_mul _ := USize.zero_mul
sub_eq_add_neg := USize.sub_eq_add_neg
pow_zero := USize.pow_zero
pow_succ := USize.pow_succ
ofNat_succ x := USize.ofNat_add x 1
open System.Platform
instance : IsCharP USize (2 ^ numBits) where
ofNat_eq_zero_iff {x} := by
have : OfNat.ofNat x = USize.ofNat x := rfl
simp [this, USize.ofNat_eq_iff_mod_eq_toNat]
end Lean.Grind

View File

@@ -165,4 +165,9 @@ theorem of_decide_eq_false {p : Prop} {_ : Decidable p} : decide p = false → p
theorem decide_eq_true {p : Prop} {_ : Decidable p} : p = True decide p = true := by simp
theorem decide_eq_false {p : Prop} {_ : Decidable p} : p = False decide p = false := by simp
/-! Lookahead -/
theorem of_lookahead (p : Prop) (h : (¬ p) False) : p = True := by
simp at h; simp [h]
end Lean.Grind

View File

@@ -74,11 +74,11 @@ theorem bne_eq_decide_not_eq {_ : BEq α} [LawfulBEq α] [DecidableEq α] (a b :
theorem xor_eq (a b : Bool) : (a ^^ b) = (a != b) := by
rfl
theorem natCast_div (a b : Nat) : ((a / b) : Int) = a / b := by
rfl
theorem natCast_mod (a b : Nat) : ((a % b) : Int) = a % b := by
rfl
theorem natCast_eq [NatCast α] (a : Nat) : (Nat.cast a : α) = (NatCast.natCast a : α) := rfl
theorem natCast_div (a b : Nat) : (NatCast.natCast (a / b) : Int) = (NatCast.natCast a) / (NatCast.natCast b) := rfl
theorem natCast_mod (a b : Nat) : (NatCast.natCast (a % b) : Int) = (NatCast.natCast a) % (NatCast.natCast b) := rfl
theorem natCast_add (a b : Nat) : (NatCast.natCast (a + b : Nat) : Int) = (NatCast.natCast a : Int) + (NatCast.natCast b : Int) := rfl
theorem natCast_mul (a b : Nat) : (NatCast.natCast (a * b : Nat) : Int) = (NatCast.natCast a : Int) * (NatCast.natCast b : Int) := rfl
theorem Nat.pow_one (a : Nat) : a ^ 1 = a := by
simp
@@ -153,8 +153,10 @@ init_grind_norm
Int.emod_neg Int.ediv_neg
Int.ediv_zero Int.emod_zero
Int.ediv_one Int.emod_one
Int.natCast_add Int.natCast_mul Int.natCast_pow
Int.natCast_zero natCast_div natCast_mod
natCast_eq natCast_div natCast_mod
natCast_add natCast_mul
Int.pow_zero Int.pow_one
-- GT GE
ge_eq gt_eq

View File

@@ -25,7 +25,8 @@ syntax grindUsr := &"usr "
syntax grindCases := &"cases "
syntax grindCasesEager := atomic(&"cases" &"eager ")
syntax grindIntro := &"intro "
syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindRL <|> grindLR <|> grindUsr <|> grindCasesEager <|> grindCases <|> grindIntro
syntax grindExt := &"ext "
syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindRL <|> grindLR <|> grindUsr <|> grindCasesEager <|> grindCases <|> grindIntro <|> grindExt
syntax (name := grind) "grind" (grindMod)? : attr
end Attr
end Lean.Parser
@@ -68,8 +69,17 @@ structure Config where
failures : Nat := 1
/-- Maximum number of heartbeats (in thousands) the canonicalizer can spend per definitional equality test. -/
canonHeartbeats : Nat := 1000
/-- If `ext` is `true`, `grind` uses extensionality theorems available in the environment. -/
/-- If `ext` is `true`, `grind` uses extensionality theorems that have been marked with `[grind ext]`. -/
ext : Bool := true
/-- If `extAll` is `true`, `grind` uses any extensionality theorems available in the environment. -/
extAll : Bool := false
/--
If `funext` is `true`, `grind` creates new opportunities for applying function extensionality by case-splitting
on equalities between lambda expressions.
-/
funext : Bool := true
/-- TODO -/
lookahead : Bool := true
/-- If `verbose` is `false`, additional diagnostics information is not collected. -/
verbose : Bool := true
/-- If `clean` is `true`, `grind` uses `expose_names` and only generates accessible names. -/

View File

@@ -44,6 +44,13 @@ register_builtin_option Elab.async : Bool := {
`Lean.Command.State.snapshotTasks`."
}
register_builtin_option Elab.inServer : Bool := {
defValue := false
descr := "true if elaboration is being run inside the Lean language server\
\n\
\nThis option is set by the file worker and should not be modified otherwise."
}
/-- Performance option used by cmdline driver. -/
register_builtin_option internal.cmdlineSnapshots : Bool := {
defValue := false

View File

@@ -111,6 +111,7 @@ inductive Message where
| response (id : RequestID) (result : Json)
/-- A non-successful response. -/
| responseError (id : RequestID) (code : ErrorCode) (message : String) (data? : Option Json)
deriving Inhabited
def Batch := Array Message

View File

@@ -510,6 +510,7 @@ where go := do
let oldCmds? := oldSnap?.map fun old =>
if old.newStx.isOfKind nullKind then old.newStx.getArgs else #[old.newStx]
let cmdPromises cmds.mapM fun _ => IO.Promise.new
let cancelTk? := ( read).cancelTk?
snap.new.resolve <| .ofTyped {
diagnostics := .empty
macroDecl := decl
@@ -517,7 +518,7 @@ where go := do
newNextMacroScope := nextMacroScope
hasTraces
next := Array.zipWith (fun cmdPromise cmd =>
{ stx? := some cmd, task := cmdPromise.resultD default }) cmdPromises cmds
{ stx? := some cmd, task := cmdPromise.resultD default, cancelTk? }) cmdPromises cmds
: MacroExpandedSnapshot
}
-- After the first command whose syntax tree changed, we must disable

View File

@@ -6,6 +6,7 @@ Authors: Leonardo de Moura, Sebastian Ullrich
prelude
import Lean.Parser.Module
import Lean.Util.Paths
import Lean.CoreM
namespace Lean.Elab
@@ -21,9 +22,16 @@ def processHeader (header : Syntax) (opts : Options) (messages : MessageLog)
(inputCtx : Parser.InputContext) (trustLevel : UInt32 := 0)
(plugins : Array System.FilePath := #[]) (leakEnv := false)
: IO (Environment × MessageLog) := do
let level := if experimental.module.get opts then
if Elab.inServer.get opts then
.server
else
.exported
else
.private
try
let env
importModules (leakEnv := leakEnv) (loadExts := true) (headerToImports header) opts trustLevel plugins
importModules (leakEnv := leakEnv) (loadExts := true) (level := level) (headerToImports header) opts trustLevel plugins
pure (env, messages)
catch e =>
let env mkEmptyEnvironment

View File

@@ -165,6 +165,8 @@ private def elabHeaders (views : Array DefView) (expandedDeclIds : Array ExpandD
-- no syntax guard to store, we already did the necessary checks
oldBodySnap? := guard reuseBody *> pure .missing, old.bodySnap
if oldBodySnap?.isNone then
-- NOTE: this will eagerly cancel async tasks not associated with an inner snapshot, most
-- importantly kernel checking and compilation of the top-level declaration
old.bodySnap.cancelRec
oldTacSnap? := do
guard reuseTac
@@ -217,6 +219,7 @@ private def elabHeaders (views : Array DefView) (expandedDeclIds : Array ExpandD
return newHeader
if let some snap := view.headerSnap? then
let (tacStx?, newTacTask?) mkTacTask view.value tacPromise
let cancelTk? := ( readThe Core.Context).cancelTk?
let bodySnap := {
stx? := view.value
reportingRange? :=
@@ -227,6 +230,8 @@ private def elabHeaders (views : Array DefView) (expandedDeclIds : Array ExpandD
else
getBodyTerm? view.value |>.getD view.value |>.getRange?
task := bodyPromise.resultD default
-- We should not cancel the entire body early if we have tactics
cancelTk? := guard newTacTask?.isNone *> cancelTk?
}
snap.new.resolve <| some {
diagnostics :=
@@ -269,7 +274,8 @@ where
:= do
if let some e := getBodyTerm? body then
if let `(by $tacs*) := e then
return (e, some { stx? := mkNullNode tacs, task := tacPromise.resultD default })
let cancelTk? := ( readThe Core.Context).cancelTk?
return (e, some { stx? := mkNullNode tacs, task := tacPromise.resultD default, cancelTk? })
tacPromise.resolve default
return (none, none)
@@ -432,8 +438,7 @@ private def elabFunValues (headers : Array DefViewElabHeader) (vars : Array Expr
snap.new.resolve <| some old
reusableResult? := some (old.value, old.state)
else
-- NOTE: this will eagerly cancel async tasks not associated with an inner snapshot, most
-- importantly kernel checking and compilation of the top-level declaration
-- make sure to cancel any async tasks that may still be running (e.g. kernel and codegen)
old.val.cancelRec
let (val, state) withRestoreOrSaveFull reusableResult? header.tacSnap? do
@@ -1158,7 +1163,7 @@ is error-free and contains no syntactical `sorry`s.
-/
private def logGoalsAccomplishedSnapshotTask (views : Array DefView)
(defsParsedSnap : DefsParsedSnapshot) : TermElabM Unit := do
if Lean.internal.cmdlineSnapshots.get ( getOptions) then
if Lean.Elab.inServer.get ( getOptions) then
-- Skip 'goals accomplished' task if we are on the command line.
-- These messages are only used in the language server.
return
@@ -1197,6 +1202,7 @@ private def logGoalsAccomplishedSnapshotTask (views : Array DefView)
-- Use first line of the mutual block to avoid covering the progress of the whole mutual block
reportingRange? := ( getRef).getPos?.map fun pos => pos, pos
task := logGoalsAccomplishedTask
cancelTk? := none
}
end Term
@@ -1235,9 +1241,10 @@ def elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do
} }
if snap.old?.isSome && (view.headerSnap?.bind (·.old?)).isNone then
snap.old?.forM (·.val.cancelRec)
let cancelTk? := ( read).cancelTk?
defs := defs.push {
fullHeaderRef
headerProcessedSnap := { stx? := d, task := headerPromise.resultD default }
headerProcessedSnap := { stx? := d, task := headerPromise.resultD default, cancelTk? }
}
reusedAllHeaders := reusedAllHeaders && view.headerSnap?.any (·.old?.isSome)
views := views.push view

View File

@@ -109,11 +109,11 @@ structure InductiveView where
/-- Elaborated header for an inductive type before fvars for each inductive are added to the local context. -/
structure PreElabHeaderResult where
view : InductiveView
lctx : LocalContext
localInsts : LocalInstances
levelNames : List Name
params : Array Expr
numParams : Nat
type : Expr
/-- The parameters in the header's initial local context. Used for adding fvar alias terminfo. -/
origParams : Array Expr
deriving Inhabited
/-- The elaborated header with the `indFVar` registered for this inductive type. -/
@@ -228,16 +228,12 @@ private def checkClass (rs : Array PreElabHeaderResult) : TermElabM Unit := do
throwErrorAt r.view.ref "invalid inductive type, mutual classes are not supported"
private def checkNumParams (rs : Array PreElabHeaderResult) : TermElabM Nat := do
let numParams := rs[0]!.params.size
let numParams := rs[0]!.numParams
for r in rs do
unless r.params.size == numParams do
unless r.numParams == numParams do
throwErrorAt r.view.ref "invalid inductive type, number of parameters mismatch in mutually inductive datatypes"
return numParams
private def mkTypeFor (r : PreElabHeaderResult) : TermElabM Expr := do
withLCtx r.lctx r.localInsts do
mkForallFVars r.params r.type
/--
Execute `k` with updated binder information for `xs`. Any `x` that is explicit becomes implicit.
-/
@@ -276,7 +272,7 @@ private def checkHeaders (rs : Array PreElabHeaderResult) (numParams : Nat) (i :
checkHeaders rs numParams (i+1) type
where
checkHeader (r : PreElabHeaderResult) (numParams : Nat) (firstType? : Option Expr) : TermElabM Expr := do
let type mkTypeFor r
let type := r.type
match firstType? with
| none => return type
| some firstType =>
@@ -306,7 +302,8 @@ private def elabHeadersAux (views : Array InductiveView) (i : Nat) (acc : Array
let params Term.addAutoBoundImplicits params (view.declId.getTailPos? (canonicalOnly := true))
trace[Elab.inductive] "header params: {params}, type: {type}"
let levelNames Term.getLevelNames
return acc.push { lctx := ( getLCtx), localInsts := ( getLocalInstances), levelNames, params, type, view }
let type mkForallFVars params type
return acc.push { levelNames, numParams := params.size, type, view, origParams := params }
elabHeadersAux views (i+1) acc
else
return acc
@@ -326,21 +323,21 @@ private def elabHeaders (views : Array InductiveView) : TermElabM (Array PreElab
/--
Create a local declaration for each inductive type in `rs`, and execute `x params indFVars`, where `params` are the inductive type parameters and
`indFVars` are the new local declarations.
We use the local context/instances and parameters of rs[0].
We use the parameters of rs[0].
Note that this method is executed after we executed `checkHeaders` and established all
parameters are compatible.
-/
private def withInductiveLocalDecls (rs : Array PreElabHeaderResult) (x : Array Expr Array Expr TermElabM α) : TermElabM α := do
let namesAndTypes rs.mapM fun r => do
let type mkTypeFor r
pure (r.view.declName, r.view.shortDeclName, type)
let r0 := rs[0]!
let params := r0.params
withLCtx r0.lctx r0.localInsts <| withRef r0.view.ref do
let r0 := rs[0]!
forallBoundedTelescope r0.type r0.numParams fun params _ => withRef r0.view.ref do
let rec loop (i : Nat) (indFVars : Array Expr) := do
if h : i < namesAndTypes.size then
let (declName, shortDeclName, type) := namesAndTypes[i]
withAuxDecl shortDeclName type declName fun indFVar => loop (i+1) (indFVars.push indFVar)
if h : i < rs.size then
let r := rs[i]
for param in params, origParam in r.origParams do
if let .fvar origFVar := origParam then
Elab.pushInfoLeaf <| .ofFVarAliasInfo { id := param.fvarId!, baseId := origFVar, userName := param.fvarId!.getUserName }
withAuxDecl r.view.shortDeclName r.type r.view.declName fun indFVar =>
loop (i+1) (indFVars.push indFVar)
else
x params indFVars
loop 0 #[]
@@ -359,26 +356,6 @@ private def ElabHeaderResult.checkLevelNames (rs : Array PreElabHeaderResult) :
unless r.levelNames == levelNames do
throwErrorAt r.view.ref "invalid inductive type, universe parameters mismatch in mutually inductive datatypes"
/--
We need to work inside a single local context across all the inductive types, so we need to update the `ElabHeaderResult`s
so that resultant types refer to the fvars in `params`, the parameters for `rs[0]!` specifically.
Also updates the local contexts and local instances in each header.
-/
private def updateElabHeaderTypes (params : Array Expr) (rs : Array PreElabHeaderResult) (indFVars : Array Expr) : TermElabM (Array ElabHeaderResult) := do
rs.mapIdxM fun i r => do
/-
At this point, because of `withInductiveLocalDecls`, the only fvars that are in context are the ones related to the first inductive type.
Because of this, we need to replace the fvars present in each inductive type's header of the mutual block with those of the first inductive.
However, some mvars may still be uninstantiated there, and might hide some of the old fvars.
As such we first need to synthesize all possible mvars at this stage, instantiate them in the header types and only
then replace the parameters' fvars in the header type.
See issue #3242 (`https://github.com/leanprover/lean4/issues/3242`)
-/
let type instantiateMVars r.type
let type := type.replaceFVars r.params params
pure { r with lctx := getLCtx, localInsts := getLocalInstances, type := type, indFVar := indFVars[i]! }
private def getArity (indType : InductiveType) : MetaM Nat :=
forallTelescopeReducing indType.type fun xs _ => return xs.size
@@ -878,7 +855,7 @@ private def mkInductiveDecl (vars : Array Expr) (elabs : Array InductiveElabStep
trace[Elab.inductive] "level names: {allUserLevelNames}"
let res withInductiveLocalDecls rs fun params indFVars => do
trace[Elab.inductive] "indFVars: {indFVars}"
let rs updateElabHeaderTypes params rs indFVars
let rs := Array.zipWith (fun r indFVar => { r with indFVar : ElabHeaderResult }) rs indFVars
let mut indTypesArray : Array InductiveType := #[]
let mut elabs' := #[]
for h : i in [:views.size] do
@@ -886,8 +863,7 @@ private def mkInductiveDecl (vars : Array Expr) (elabs : Array InductiveElabStep
let r := rs[i]!
let elab' elabs[i]!.elabCtors rs r params
elabs' := elabs'.push elab'
let type mkForallFVars params r.type
indTypesArray := indTypesArray.push { name := r.view.declName, type, ctors := elab'.ctors }
indTypesArray := indTypesArray.push { name := r.view.declName, type := r.type, ctors := elab'.ctors }
Term.synthesizeSyntheticMVarsNoPostponing
let numExplicitParams fixedIndicesToParams params.size indTypesArray indFVars
trace[Elab.inductive] "numExplicitParams: {numExplicitParams}"

View File

@@ -1260,12 +1260,8 @@ private def addDefaults (levelParams : List Name) (params : Array Expr) (replace
let fieldInfos := ( get).fields
let lctx instantiateLCtxMVars ( getLCtx)
/- The parameters `params` for the auxiliary "default value" definitions must be marked as implicit, and all others as explicit. -/
let lctx :=
params.foldl (init := lctx) fun (lctx : LocalContext) (p : Expr) =>
if p.isFVar then
lctx.setBinderInfo p.fvarId! BinderInfo.implicit
else
lctx
let lctx := params.foldl (init := lctx) fun (lctx : LocalContext) (p : Expr) =>
lctx.setBinderInfo p.fvarId! BinderInfo.implicit
let parentFVarIds := fieldInfos |>.filter (·.kind.isParent) |>.map (·.fvar.fvarId!)
let fields := fieldInfos |>.filter (!·.kind.isParent)
withLCtx lctx ( getLocalInstances) do

View File

@@ -165,6 +165,7 @@ structure EvalTacticFailure where
state : SavedState
partial def evalTactic (stx : Syntax) : TacticM Unit := do
checkSystem "tactic execution"
profileitM Exception "tactic execution" (decl := stx.getKind) ( getOptions) <|
withRef stx <| withIncRecDepth <| withFreshMacroScope <| match stx with
| .node _ k _ =>
@@ -240,6 +241,7 @@ where
snap.old?.forM (·.val.cancelRec)
let promise IO.Promise.new
-- Store new unfolding in the snapshot tree
let cancelTk? := ( readThe Core.Context).cancelTk?
snap.new.resolve {
stx := stx'
diagnostics := .empty
@@ -249,7 +251,7 @@ where
state? := ( Tactic.saveState)
moreSnaps := #[]
}
next := #[{ stx? := stx', task := promise.resultD default }]
next := #[{ stx? := stx', task := promise.resultD default, cancelTk? }]
}
-- Update `tacSnap?` to old unfolding
withTheReader Term.Context ({ · with tacSnap? := some {

View File

@@ -78,13 +78,14 @@ where
let next IO.Promise.new
let finished IO.Promise.new
let inner IO.Promise.new
let cancelTk? := ( readThe Core.Context).cancelTk?
snap.new.resolve {
desc := tac.getKind.toString
diagnostics := .empty
stx := tac
inner? := some { stx? := tac, task := inner.resultD default }
finished := { stx? := tac, task := finished.resultD default }
next := #[{ stx? := stxs, task := next.resultD default }]
inner? := some { stx? := tac, task := inner.resultD default, cancelTk? }
finished := { stx? := tac, task := finished.resultD default, cancelTk? }
next := #[{ stx? := stxs, task := next.resultD default, cancelTk? }]
}
-- Run `tac` in a fresh info tree state and store resulting state in snapshot for
-- incremental reporting, then add back saved trees. Here we rely on `evalTactic`

View File

@@ -89,6 +89,8 @@ def elabGrindParams (params : Grind.Params) (ps : TSyntaxArray ``Parser.Tactic.
params withRef p <| addEMatchTheorem params ctor .default
else
throwError "invalid use of `intro` modifier, `{declName}` is not an inductive predicate"
| .ext =>
throwError "`[grind ext]` cannot be set using parameters"
| .infer =>
if let some declName Grind.isCasesAttrCandidate? declName false then
params := { params with casesTypes := params.casesTypes.insert declName false }

View File

@@ -286,14 +286,15 @@ where
-- them, eventually put each of them back in `Context.tacSnap?` in `applyAltStx`
let finished IO.Promise.new
let altPromises altStxs.mapM fun _ => IO.Promise.new
let cancelTk? := ( readThe Core.Context).cancelTk?
tacSnap.new.resolve {
-- save all relevant syntax here for comparison with next document version
stx := mkNullNode altStxs
diagnostics := .empty
inner? := none
finished := { stx? := mkNullNode altStxs, reportingRange? := none, task := finished.resultD default }
finished := { stx? := mkNullNode altStxs, reportingRange? := none, task := finished.resultD default, cancelTk? }
next := Array.zipWith
(fun stx prom => { stx? := some stx, task := prom.resultD default })
(fun stx prom => { stx? := some stx, task := prom.resultD default, cancelTk? })
altStxs altPromises
}
goWithIncremental <| altPromises.mapIdx fun i prom => {

View File

@@ -1878,13 +1878,14 @@ where
go todo (autos.push auto)
/--
Similar to `autoBoundImplicits`, but immediately if the resulting array of expressions contains metavariables,
it immediately uses `mkForallFVars` + `forallBoundedTelescope` to convert them into free variables.
Similar to `addAutoBoundImplicits`, but converts all metavariables into free variables.
It uses `mkForallFVars` + `forallBoundedTelescope` to convert metavariables into free variables.
The type `type` is modified during the process if type depends on `xs`.
We use this method to simplify the conversion of code using `autoBoundImplicitsOld` to `autoBoundImplicits`.
-/
def addAutoBoundImplicits' (xs : Array Expr) (type : Expr) (k : Array Expr Expr TermElabM α) : TermElabM α := do
let xs addAutoBoundImplicits xs none
def addAutoBoundImplicits' (xs : Array Expr) (type : Expr) (k : Array Expr Expr TermElabM α) (inlayHintPos? : Option String.Pos := none) : TermElabM α := do
let xs addAutoBoundImplicits xs inlayHintPos?
if xs.all (·.isFVar) then
k xs type
else

View File

@@ -122,13 +122,15 @@ end TagDeclarationExtension
structure MapDeclarationExtension (α : Type) extends PersistentEnvExtension (Name × α) (Name × α) (NameMap α)
deriving Inhabited
def mkMapDeclarationExtension (name : Name := by exact decl_name%) : IO (MapDeclarationExtension α) :=
def mkMapDeclarationExtension (name : Name := by exact decl_name%)
(exportEntriesFn : NameMap α Array (Name × α) := (·.toArray)) : IO (MapDeclarationExtension α) :=
.mk <$> registerPersistentEnvExtension {
name := name,
mkInitial := pure {}
addImportedFn := fun _ => pure {}
addEntryFn := fun s (n, v) => s.insert n v
exportEntriesFn := fun s => s.toArray
saveEntriesFn := fun s => s.toArray
exportEntriesFn
asyncMode := .async
replay? := some fun _ newState newConsts s =>
newConsts.foldl (init := s) fun s c =>
@@ -145,10 +147,11 @@ def insert (ext : MapDeclarationExtension α) (env : Environment) (declName : Na
assert! env.asyncMayContain declName
ext.addEntry env (declName, val)
def find? [Inhabited α] (ext : MapDeclarationExtension α) (env : Environment) (declName : Name) : Option α :=
def find? [Inhabited α] (ext : MapDeclarationExtension α) (env : Environment) (declName : Name)
(includeServer := false) : Option α :=
match env.getModuleIdxFor? declName with
| some modIdx =>
match (ext.getModuleEntries env modIdx).binSearch (declName, default) (fun a b => Name.quickLt a.1 b.1) with
match (ext.getModuleEntries (includeServer := includeServer) env modIdx).binSearch (declName, default) (fun a b => Name.quickLt a.1 b.1) with
| some e => some e.2
| none => none
| none => (ext.findStateAsync env declName).find? declName

View File

@@ -506,6 +506,12 @@ structure Environment where
-/
base : Kernel.Environment
/--
Additional imported environment extension state for use in the language server. This field is
identical to `base.extensions` in other contexts. Access via
`getModuleEntries (includeServer := true)`.
-/
private serverBaseExts : Array EnvExtensionState := base.extensions
/--
Kernel environment task that is fulfilled when all asynchronously elaborated declarations are
finished, containing the resulting environment. Also collects the environment extension state of
all environment branches that contributed contained declarations.
@@ -536,6 +542,12 @@ structure Environment where
`findAsyncCore?`/`findStateAsync`; see there.
-/
private allRealizations : Task (NameMap AsyncConst) := .pure {}
/--
Indicates whether the environment is being used in an exported context, i.e. whether it should
provide access to only the data to be imported by other modules participating in the module
system.
-/
isExporting : Bool := false
deriving Nonempty
namespace Environment
@@ -549,6 +561,10 @@ def ofKernelEnv (env : Kernel.Environment) : Environment :=
def toKernelEnv (env : Environment) : Kernel.Environment :=
env.checked.get
/-- Updates `Environment.isExporting`. -/
def setExporting (env : Environment) (isExporting : Bool) : Environment :=
{ env with isExporting }
/-- Consistently updates synchronous and asynchronous parts of the environment without blocking. -/
private def modifyCheckedAsync (env : Environment) (f : Kernel.Environment Kernel.Environment) : Environment :=
{ env with checked := env.checked.map (sync := true) f, base := f env.base }
@@ -1379,7 +1395,15 @@ structure PersistentEnvExtension (α : Type) (β : Type) (σ : Type) where
name : Name
addImportedFn : Array (Array α) ImportM σ
addEntryFn : σ β σ
/-- Function to transform state into data that should always be imported into other modules. -/
exportEntriesFn : σ Array α
/--
Function to transform state into data that should be imported into other modules when the module
system is disabled. When it is enabled, the data is loaded only in the language server and
accessible via `getModuleEntries (includeServer := true)`. Conventionally, this is a superset of
the data returned by `exportEntriesFn`.
-/
saveEntriesFn : σ Array α
statsFn : σ Format
instance {α σ} [Inhabited σ] : Inhabited (PersistentEnvExtensionState α σ) :=
@@ -1392,14 +1416,21 @@ instance {α β σ} [Inhabited σ] : Inhabited (PersistentEnvExtension α β σ)
addImportedFn := fun _ => default,
addEntryFn := fun s _ => s,
exportEntriesFn := fun _ => #[],
saveEntriesFn := fun _ => #[],
statsFn := fun _ => Format.nil
}
namespace PersistentEnvExtension
def getModuleEntries {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ) (env : Environment) (m : ModuleIdx) : Array α :=
-- `importedEntries` is identical on all environment branches, so `local` is always sufficient
(ext.toEnvExtension.getState (asyncMode := .local) env).importedEntries[m]!
/--
Returns the data saved by `ext.exportEntriesFn/saveEntriesFn` when `m` was elaborated. See docs on
the functions for details.
-/
def getModuleEntries {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ)
(env : Environment) (m : ModuleIdx) (includeServer := false) : Array α :=
let exts := if includeServer then env.serverBaseExts else env.base.extensions
-- safety: as in `getStateUnsafe`
unsafe (ext.toEnvExtension.getStateImpl exts).importedEntries[m]!
def addEntry {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (b : β) : Environment :=
ext.toEnvExtension.modifyState env fun s =>
@@ -1436,10 +1467,14 @@ structure PersistentEnvExtensionDescr (α β σ : Type) where
addImportedFn : Array (Array α) ImportM σ
addEntryFn : σ β σ
exportEntriesFn : σ Array α
saveEntriesFn : σ Array α := exportEntriesFn
statsFn : σ Format := fun _ => Format.nil
asyncMode : EnvExtension.AsyncMode := .mainOnly
replay? : Option (ReplayFn σ) := none
attribute [inherit_doc PersistentEnvExtension.exportEntriesFn] PersistentEnvExtensionDescr.exportEntriesFn
attribute [inherit_doc PersistentEnvExtension.saveEntriesFn] PersistentEnvExtensionDescr.saveEntriesFn
unsafe def registerPersistentEnvExtensionUnsafe {α β σ : Type} [Inhabited σ] (descr : PersistentEnvExtensionDescr α β σ) : IO (PersistentEnvExtension α β σ) := do
let pExts persistentEnvExtensionsRef.get
if pExts.any (fun ext => ext.name == descr.name) then throw (IO.userError s!"invalid environment extension, '{descr.name}' has already been used")
@@ -1458,6 +1493,7 @@ unsafe def registerPersistentEnvExtensionUnsafe {α β σ : Type} [Inhabited σ]
addImportedFn := descr.addImportedFn,
addEntryFn := descr.addEntryFn,
exportEntriesFn := descr.exportEntriesFn,
saveEntriesFn := descr.saveEntriesFn,
statsFn := descr.statsFn
}
persistentEnvExtensionsRef.modify fun pExts => pExts.push (unsafeCast pExt)
@@ -1466,10 +1502,30 @@ unsafe def registerPersistentEnvExtensionUnsafe {α β σ : Type} [Inhabited σ]
@[implemented_by registerPersistentEnvExtensionUnsafe]
opaque registerPersistentEnvExtension {α β σ : Type} [Inhabited σ] (descr : PersistentEnvExtensionDescr α β σ) : IO (PersistentEnvExtension α β σ)
@[extern "lean_save_module_data"]
opaque saveModuleData (fname : @& System.FilePath) (mod : @& Name) (data : @& ModuleData) : IO Unit
@[extern "lean_read_module_data"]
opaque readModuleData (fname : @& System.FilePath) : IO (ModuleData × CompactedRegion)
/--
Stores each given module data in the respective file name. Objects shared with prior parts are not
duplicated. Thus the data cannot be loaded with individual `readModuleData` calls but must loaded by
passing (a prefix of) the file names to `readModuleDataParts`. `mod` is used to determine an
arbitrary but deterministic base address for `mmap`.
-/
@[extern "lean_save_module_data_parts"]
opaque saveModuleDataParts (mod : @& Name) (parts : Array (System.FilePath × ModuleData)) : IO Unit
/--
Loads the module data from the given file names. The files must be (a prefix of) the result of a
`saveModuleDataParts` call.
-/
@[extern "lean_read_module_data_parts"]
opaque readModuleDataParts (fnames : @& Array System.FilePath) : IO (Array (ModuleData × CompactedRegion))
def saveModuleData (fname : System.FilePath) (mod : Name) (data : ModuleData) : IO Unit :=
saveModuleDataParts mod #[(fname, data)]
def readModuleData (fname : @& System.FilePath) : IO (ModuleData × CompactedRegion) := do
let parts readModuleDataParts #[fname]
assert! parts.size == 1
let some part := parts[0]? | unreachable!
return part
/--
Free compacted regions of imports. No live references to imported objects may exist at the time of invocation; in
@@ -1493,7 +1549,22 @@ unsafe def Environment.freeRegions (env : Environment) : IO Unit :=
TODO: statically check for this. -/
env.header.regions.forM CompactedRegion.free
def mkModuleData (env : Environment) : IO ModuleData := do
/-- The level of information to save/load. Each level includes all previous ones. -/
inductive OLeanLevel where
/-- Information from exported contexts. -/
| exported
/-- Environment extension state for the language server. -/
| server
/-- Private module data. -/
| «private»
deriving DecidableEq
def OLeanLevel.adjustFileName (base : System.FilePath) : OLeanLevel System.FilePath
| .exported => base
| .server => base.addExtension "server"
| .private => base.addExtension "private"
def mkModuleData (env : Environment) (level : OLeanLevel := .private) : IO ModuleData := do
let pExts persistentEnvExtensionsRef.get
let entries := pExts.map fun pExt => Id.run do
-- get state from `checked` at the end if `async`; it would otherwise panic
@@ -1501,19 +1572,37 @@ def mkModuleData (env : Environment) : IO ModuleData := do
if asyncMode matches .async then
asyncMode := .sync
let state := pExt.getState (asyncMode := asyncMode) env
(pExt.name, pExt.exportEntriesFn state)
(pExt.name, if level = .exported then pExt.exportEntriesFn state else pExt.saveEntriesFn state)
let kenv := env.toKernelEnv
let constNames := kenv.constants.foldStage2 (fun names name _ => names.push name) #[]
let constants := kenv.constants.foldStage2 (fun cs _ c => cs.push c) #[]
let env := env.setExporting (level != .private)
let constants := kenv.constants.foldStage2 (fun cs _ c => cs.push c) #[]
--let constNames := kenv.constants.foldStage2 (fun names name _ => names.push name) #[]
-- not all kernel constants may be exported
-- TODO: does not include cstage* constants from the old codegen
--let constants := constNames.filterMap env.find?
let constNames := constants.map (·.name)
return {
imports := env.header.imports
extraConstNames := env.checked.get.extraConstNames.toArray
constNames, constants, entries
}
register_builtin_option experimental.module : Bool := {
defValue := false
descr := "Enable module system (experimental)"
}
@[export lean_write_module]
def writeModule (env : Environment) (fname : System.FilePath) : IO Unit := do
saveModuleData fname env.mainModule ( mkModuleData env)
def writeModule (env : Environment) (fname : System.FilePath) (split := false) : IO Unit := do
if split then
let mkPart (level : OLeanLevel) :=
return (level.adjustFileName fname, ( mkModuleData env level))
saveModuleDataParts env.mainModule #[
( mkPart .exported),
( mkPart .server),
( mkPart .private)]
else
saveModuleData fname env.mainModule ( mkModuleData env)
/--
Construct a mapping from persistent extension name to extension index at the array of persistent extensions.
@@ -1527,10 +1616,9 @@ def mkExtNameMap (startingAt : Nat) : IO (Std.HashMap Name Nat) := do
result := result.insert descr.name i
return result
private def setImportedEntries (env : Environment) (mods : Array ModuleData) (startingAt : Nat := 0) : IO Environment := do
-- We work directly on the states array instead of `env` as `Environment.modifyState` introduces
-- significant overhead on such frequent calls
let mut states := env.base.extensions
private def setImportedEntries (states : Array EnvExtensionState) (mods : Array ModuleData)
(startingAt : Nat := 0) : IO (Array EnvExtensionState) := do
let mut states := states
let extDescrs persistentEnvExtensionsRef.get
/- For extensions starting at `startingAt`, ensure their `importedEntries` array have size `mods.size`. -/
for extDescr in extDescrs[startingAt:] do
@@ -1546,7 +1634,7 @@ private def setImportedEntries (env : Environment) (mods : Array ModuleData) (st
-- safety: as in `modifyState`
states := unsafe extDescrs[entryIdx]!.toEnvExtension.modifyStateImpl states fun s =>
{ s with importedEntries := s.importedEntries.set! modIdx entries }
return env.setCheckedSync { env.base with extensions := states }
return states
/--
"Forward declaration" needed for updating the attribute table with user-defined attributes.
@@ -1585,7 +1673,7 @@ where
-- This branch is executed when `pExtDescrs[i]` is the extension associated with the `init` attribute, and
-- a user-defined persistent extension is imported.
-- Thus, we invoke `setImportedEntries` to update the array `importedEntries` with the entries for the new extensions.
env setImportedEntries env mods prevSize
env := env.setCheckedSync { env.base with extensions := ( setImportedEntries env.base.extensions mods prevSize) }
-- See comment at `updateEnvAttributesRef`
env updateEnvAttributes env
loop (i + 1) env
@@ -1596,7 +1684,7 @@ structure ImportState where
moduleNameSet : NameHashSet := {}
moduleNames : Array Name := #[]
moduleData : Array ModuleData := #[]
regions : Array CompactedRegion := #[]
parts : Array (Array (ModuleData × CompactedRegion)) := #[]
def throwAlreadyImported (s : ImportState) (const2ModIdx : Std.HashMap Name ModuleIdx) (modIdx : Nat) (cname : Name) : IO α := do
let modName := s.moduleNames[modIdx]!
@@ -1608,7 +1696,8 @@ abbrev ImportStateM := StateRefT ImportState IO
@[inline] nonrec def ImportStateM.run (x : ImportStateM α) (s : ImportState := {}) : IO (α × ImportState) :=
x.run s
partial def importModulesCore (imports : Array Import) : ImportStateM Unit := do
partial def importModulesCore (imports : Array Import) (level := OLeanLevel.private) :
ImportStateM Unit := do
for i in imports do
if i.runtimeOnly || ( get).moduleNameSet.contains i.module then
continue
@@ -1616,12 +1705,22 @@ partial def importModulesCore (imports : Array Import) : ImportStateM Unit := do
let mFile findOLean i.module
unless ( mFile.pathExists) do
throw <| IO.userError s!"object file '{mFile}' of module {i.module} does not exist"
let (mod, region) readModuleData mFile
importModulesCore mod.imports
let mut fnames := #[mFile]
if level != OLeanLevel.exported then
let sFile := OLeanLevel.server.adjustFileName mFile
if ( sFile.pathExists) then
fnames := fnames.push sFile
if level == OLeanLevel.private then
let pFile := OLeanLevel.private.adjustFileName mFile
if ( pFile.pathExists) then
fnames := fnames.push pFile
let parts readModuleDataParts fnames
let some (mod, _) := parts[if level = .exported then 0 else parts.size - 1]? | unreachable!
importModulesCore (level := level) mod.imports
modify fun s => { s with
moduleData := s.moduleData.push mod
regions := s.regions.push region
moduleNames := s.moduleNames.push i.module
parts := s.parts.push parts
}
/--
@@ -1685,14 +1784,16 @@ def finalizeImport (s : ImportState) (imports : Array Import) (opts : Options) (
extensions := exts
header := {
trustLevel, imports
regions := s.regions
regions := s.parts.flatMap (·.map (·.2))
moduleNames := s.moduleNames
moduleData := s.moduleData
}
}
realizedImportedConsts? := none
}
env setImportedEntries env s.moduleData
env := env.setCheckedSync { env.base with extensions := ( setImportedEntries env.base.extensions s.moduleData) }
let serverData := s.parts.filterMap fun parts => (parts[1]? <|> parts[0]?).map Prod.fst
env := { env with serverBaseExts := ( setImportedEntries env.base.extensions serverData) }
if leakEnv then
/- Mark persistent a first time before `finalizePersistenExtensions`, which
avoids costly MT markings when e.g. an interpreter closure (which
@@ -1739,13 +1840,13 @@ environment's constant map can be accessed without `loadExts`, many functions th
-/
def importModules (imports : Array Import) (opts : Options) (trustLevel : UInt32 := 0)
(plugins : Array System.FilePath := #[]) (leakEnv := false) (loadExts := false)
: IO Environment := profileitIO "import" opts do
(level := OLeanLevel.private) : IO Environment := profileitIO "import" opts do
for imp in imports do
if imp.module matches .anonymous then
throw <| IO.userError "import failed, trying to import module with anonymous name"
withImporting do
plugins.forM Lean.loadPlugin
let (_, s) importModulesCore imports |>.run
let (_, s) importModulesCore (level := level) imports |>.run
finalizeImport (leakEnv := leakEnv) (loadExts := loadExts) s imports opts trustLevel
/--

View File

@@ -93,18 +93,17 @@ structure SnapshotTask (α : Type) where
Cancellation token that can be set by the server to cancel the task when it detects the results
are not needed anymore.
-/
cancelTk? : Option IO.CancelToken := none
cancelTk? : Option IO.CancelToken
/-- Underlying task producing the snapshot. -/
task : Task α
deriving Nonempty, Inhabited
/-- Creates a snapshot task from the syntax processed by the task and a `BaseIO` action. -/
def SnapshotTask.ofIO (stx? : Option Syntax)
def SnapshotTask.ofIO (stx? : Option Syntax) (cancelTk? : Option IO.CancelToken)
(reportingRange? : Option String.Range := defaultReportingRange? stx?) (act : BaseIO α) :
BaseIO (SnapshotTask α) := do
return {
stx?
reportingRange?
stx?, reportingRange?, cancelTk?
task := ( BaseIO.asTask act)
}
@@ -114,6 +113,7 @@ def SnapshotTask.finished (stx? : Option Syntax) (a : α) : SnapshotTask α wher
-- irrelevant when already finished
reportingRange? := none
task := .pure a
cancelTk? := none
/-- Transforms a task's output without changing the processed syntax. -/
def SnapshotTask.map (t : SnapshotTask α) (f : α β) (stx? : Option Syntax := t.stx?)

View File

@@ -397,7 +397,7 @@ where
diagnostics := oldProcessed.diagnostics
result? := some {
cmdState := oldProcSuccess.cmdState
firstCmdSnap := { stx? := none, task := prom.result! } } }
firstCmdSnap := { stx? := none, task := prom.result!, cancelTk? := cancelTk } } }
else
return .finished newStx oldProcessed) } }
else return old
@@ -450,7 +450,7 @@ where
processHeader (stx : Syntax) (parserState : Parser.ModuleParserState) :
LeanProcessingM (SnapshotTask HeaderProcessedSnapshot) := do
let ctx read
SnapshotTask.ofIO stx (some 0, ctx.input.endPos) <|
SnapshotTask.ofIO stx none (some 0, ctx.input.endPos) <|
ReaderT.run (r := ctx) <| -- re-enter reader in new task
withHeaderExceptions (α := HeaderProcessedSnapshot) ({ · with result? := none }) do
let setup match ( setupImports stx) with
@@ -507,7 +507,7 @@ where
infoTree? := cmdState.infoState.trees[0]!
result? := some {
cmdState
firstCmdSnap := { stx? := none, task := prom.result! }
firstCmdSnap := { stx? := none, task := prom.result!, cancelTk? := cancelTk }
}
}
@@ -523,17 +523,19 @@ where
-- from `old`
if let some oldNext := old.nextCmdSnap? then do
let newProm IO.Promise.new
let cancelTk IO.CancelToken.new
-- can reuse range, syntax unchanged
BaseIO.chainTask (sync := true) old.resultSnap.task fun oldResult =>
-- also wait on old command parse snapshot as parsing is cheap and may allow for
-- elaboration reuse
BaseIO.chainTask (sync := true) oldNext.task fun oldNext => do
let cancelTk IO.CancelToken.new
parseCmd oldNext newParserState oldResult.cmdState newProm sync cancelTk ctx
prom.resolve <| { old with nextCmdSnap? := some {
stx? := none
reportingRange? := some newParserState.pos, ctx.input.endPos
task := newProm.result! } }
task := newProm.result!
cancelTk? := cancelTk
} }
else prom.resolve old -- terminal command, we're done!
-- fast path, do not even start new task for this snapshot (see [Incremental Parsing])
@@ -615,15 +617,16 @@ where
})
let diagnostics Snapshot.Diagnostics.ofMessageLog msgLog
-- use per-command cancellation token for elaboration so that
-- use per-command cancellation token for elaboration so that cancellation of further commands
-- does not affect current command
let elabCmdCancelTk IO.CancelToken.new
prom.resolve {
diagnostics, nextCmdSnap?
stx := stx', parserState := parserState'
elabSnap := { stx? := stx', task := elabPromise.result!, cancelTk? := some elabCmdCancelTk }
resultSnap := { stx? := stx', reportingRange? := initRange?, task := resultPromise.result! }
infoTreeSnap := { stx? := stx', reportingRange? := initRange?, task := finishedPromise.result! }
reportSnap := { stx? := none, reportingRange? := initRange?, task := reportPromise.result! }
resultSnap := { stx? := stx', reportingRange? := initRange?, task := resultPromise.result!, cancelTk? := none }
infoTreeSnap := { stx? := stx', reportingRange? := initRange?, task := finishedPromise.result!, cancelTk? := none }
reportSnap := { stx? := none, reportingRange? := initRange?, task := reportPromise.result!, cancelTk? := none }
}
let cmdState doElab stx cmdState beginPos
{ old? := old?.map fun old => old.stx, old.elabSnap, new := elabPromise }
@@ -665,8 +668,8 @@ where
-- We want to trace all of `CommandParsedSnapshot` but `traceTask` is part of it, so let's
-- create a temporary snapshot tree containing all tasks but it
let snaps := #[
{ stx? := stx', task := elabPromise.result!.map (sync := true) toSnapshotTree },
{ stx? := stx', task := resultPromise.result!.map (sync := true) toSnapshotTree }] ++
{ stx? := stx', task := elabPromise.result!.map (sync := true) toSnapshotTree, cancelTk? := none },
{ stx? := stx', task := resultPromise.result!.map (sync := true) toSnapshotTree, cancelTk? := none }] ++
cmdState.snapshotTasks
let tree := SnapshotTree.mk { diagnostics := .empty } snaps
BaseIO.bindTask ( tree.waitAll) fun _ => do
@@ -690,6 +693,7 @@ where
stx? := none
reportingRange? := initRange?
task := traceTask
cancelTk? := none
}
if let some next := next? then
-- We're definitely off the fast-forwarding path now

View File

@@ -2279,6 +2279,7 @@ def realizeConst (forConst : Name) (constName : Name) (realize : MetaM Unit) :
initHeartbeats := ( IO.getNumHeartbeats)
}
let (env, exTask, dyn) env.realizeConst forConst constName (realizeAndReport coreCtx)
-- Realizations cannot be cancelled as their result is shared across elaboration runs
let exAct Core.wrapAsyncAsSnapshot (cancelTk? := none) fun
| none => return
| some ex => do
@@ -2286,6 +2287,7 @@ def realizeConst (forConst : Name) (constName : Name) (realize : MetaM Unit) :
Core.logSnapshotTask {
stx? := none
task := ( BaseIO.mapTask (t := exTask) exAct)
cancelTk? := none
}
if let some res := dyn.get? RealizeConstantResult then
let mut snap := res.snap

View File

@@ -62,6 +62,15 @@ This is triggered by `attribute [-ext] name`.
def ExtTheorems.eraseCore (d : ExtTheorems) (declName : Name) : ExtTheorems :=
{ d with erased := d.erased.insert declName }
/-- Returns `true` if `d` contains theorem with name `declName`. -/
def ExtTheorems.contains (d : ExtTheorems) (declName : Name) : Bool :=
d.tree.containsValueP (·.declName == declName) && !d.erased.contains declName
/-- Returns `true` if `declName` is tagged with `[ext]` attribute. -/
def isExtTheorem (declName : Name) : CoreM Bool := do
let extTheorems := extExtension.getState ( getEnv)
return extTheorems.contains declName
/--
Erases a name marked as a `ext` attribute.
Check that it does in fact have the `ext` attribute by making sure it names a `ExtTheorem`
@@ -69,7 +78,7 @@ found somewhere in the state's tree, and is not erased.
-/
def ExtTheorems.erase [Monad m] [MonadError m] (d : ExtTheorems) (declName : Name) :
m ExtTheorems := do
unless d.tree.containsValueP (·.declName == declName) && !d.erased.contains declName do
unless d.contains declName do
throwError "'{declName}' does not have [ext] attribute"
return d.eraseCore declName

View File

@@ -30,6 +30,7 @@ import Lean.Meta.Tactic.Grind.MatchCond
import Lean.Meta.Tactic.Grind.MatchDiscrOnly
import Lean.Meta.Tactic.Grind.Diseq
import Lean.Meta.Tactic.Grind.MBTC
import Lean.Meta.Tactic.Grind.Lookahead
namespace Lean
@@ -52,6 +53,12 @@ builtin_initialize registerTraceClass `grind.split.candidate
builtin_initialize registerTraceClass `grind.split.resolved
builtin_initialize registerTraceClass `grind.beta
builtin_initialize registerTraceClass `grind.mbtc
builtin_initialize registerTraceClass `grind.ext
builtin_initialize registerTraceClass `grind.ext.candidate
builtin_initialize registerTraceClass `grind.lookahead
builtin_initialize registerTraceClass `grind.lookahead.add (inherited := true)
builtin_initialize registerTraceClass `grind.lookahead.try (inherited := true)
builtin_initialize registerTraceClass `grind.lookahead.assert (inherited := true)
/-! Trace options for `grind` developers -/
builtin_initialize registerTraceClass `grind.debug
@@ -76,5 +83,6 @@ builtin_initialize registerTraceClass `grind.debug.proveEq
builtin_initialize registerTraceClass `grind.debug.pushNewFact
builtin_initialize registerTraceClass `grind.debug.ematch.activate
builtin_initialize registerTraceClass `grind.debug.appMap
builtin_initialize registerTraceClass `grind.debug.ext
end Lean

View File

@@ -22,50 +22,17 @@ namespace Lean
builtin_initialize registerTraceClass `grind.cutsat
builtin_initialize registerTraceClass `grind.cutsat.model
builtin_initialize registerTraceClass `grind.cutsat.subst
builtin_initialize registerTraceClass `grind.cutsat.eq
builtin_initialize registerTraceClass `grind.cutsat.eq.unsat (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.eq.trivial (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.assert
builtin_initialize registerTraceClass `grind.cutsat.assert.dvd
builtin_initialize registerTraceClass `grind.cutsat.dvd
builtin_initialize registerTraceClass `grind.cutsat.dvd.update (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.dvd.unsat (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.dvd.trivial (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.dvd.solve (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.dvd.solve.combine (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.dvd.solve.elim (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.internalize
builtin_initialize registerTraceClass `grind.cutsat.internalize.term (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.assert.trivial
builtin_initialize registerTraceClass `grind.cutsat.assert.unsat
builtin_initialize registerTraceClass `grind.cutsat.assert.store
builtin_initialize registerTraceClass `grind.cutsat.assert.le
builtin_initialize registerTraceClass `grind.cutsat.le
builtin_initialize registerTraceClass `grind.cutsat.le.unsat (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.le.trivial (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.le.lower (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.le.upper (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.assign
builtin_initialize registerTraceClass `grind.cutsat.conflict
builtin_initialize registerTraceClass `grind.cutsat.diseq
builtin_initialize registerTraceClass `grind.cutsat.diseq.trivial (inherited := true)
builtin_initialize registerTraceClass `grind.debug.cutsat.eq
builtin_initialize registerTraceClass `grind.debug.cutsat.dvd.le
builtin_initialize registerTraceClass `grind.debug.cutsat.diseq
builtin_initialize registerTraceClass `grind.debug.cutsat.diseq.split
builtin_initialize registerTraceClass `grind.debug.cutsat.backtrack
builtin_initialize registerTraceClass `grind.debug.cutsat.search
builtin_initialize registerTraceClass `grind.debug.cutsat.cooper
builtin_initialize registerTraceClass `grind.debug.cutsat.cooper.diseq
builtin_initialize registerTraceClass `grind.debug.cutsat.conflict
builtin_initialize registerTraceClass `grind.debug.cutsat.assign
builtin_initialize registerTraceClass `grind.debug.cutsat.subst
builtin_initialize registerTraceClass `grind.debug.cutsat.getBestLower
builtin_initialize registerTraceClass `grind.debug.cutsat.nat
builtin_initialize registerTraceClass `grind.debug.cutsat.proof
builtin_initialize registerTraceClass `grind.debug.cutsat.search
builtin_initialize registerTraceClass `grind.debug.cutsat.search.split (inherited := true)
builtin_initialize registerTraceClass `grind.debug.cutsat.search.assign (inherited := true)
builtin_initialize registerTraceClass `grind.debug.cutsat.search.conflict (inherited := true)
builtin_initialize registerTraceClass `grind.debug.cutsat.search.backtrack (inherited := true)
builtin_initialize registerTraceClass `grind.debug.cutsat.internalize
builtin_initialize registerTraceClass `grind.debug.cutsat.markTerm
builtin_initialize registerTraceClass `grind.debug.cutsat.natCast
end Lean

View File

@@ -34,7 +34,7 @@ def DvdCnstr.applyEq (a : Int) (x : Var) (c₁ : EqCnstr) (b : Int) (c₂ : DvdC
let q := c₂.p
let d := Int.ofNat (a * c₂.d).natAbs
let p := (q.mul a |>.combine (p.mul (-b)))
trace[grind.cutsat.subst] "{← getVar x}, {← c₁.pp}, {← c₂.pp}"
trace[grind.debug.cutsat.subst] "{← getVar x}, {← c₁.pp}, {← c₂.pp}"
return { d, p, h := .subst x c₁ c₂ }
partial def DvdCnstr.applySubsts (c : DvdCnstr) : GoalM DvdCnstr := withIncRecDepth do
@@ -46,21 +46,20 @@ partial def DvdCnstr.applySubsts (c : DvdCnstr) : GoalM DvdCnstr := withIncRecDe
/-- Asserts divisibility constraint. -/
partial def DvdCnstr.assert (c : DvdCnstr) : GoalM Unit := withIncRecDepth do
if ( inconsistent) then return ()
trace[grind.cutsat.dvd] "{← c.pp}"
trace[grind.cutsat.assert] "{← c.pp}"
let c c.norm.applySubsts
if c.isUnsat then
trace[grind.cutsat.dvd.unsat] "{← c.pp}"
trace[grind.cutsat.assert.unsat] "{← c.pp}"
setInconsistent (.dvd c)
return ()
if c.isTrivial then
trace[grind.cutsat.dvd.trivial] "{← c.pp}"
trace[grind.cutsat.assert.trivial] "{← c.pp}"
return ()
let d₁ := c.d
let .add a₁ x p₁ := c.p | c.throwUnexpected
if ( c.satisfied) == .false then
resetAssignmentFrom x
if let some c' := ( get').dvds[x]! then
trace[grind.cutsat.dvd.solve] "{← c.pp}, {← c'.pp}"
let d₂ := c'.d
let .add a₂ _ p₂ := c'.p | c'.throwUnexpected
let (d, α, β) := gcdExt (a₁*d₂) (a₂*d₁)
@@ -75,16 +74,14 @@ partial def DvdCnstr.assert (c : DvdCnstr) : GoalM Unit := withIncRecDepth do
let α_d₂_p₁ := p₁.mul (α*d₂)
let β_d₁_p₂ := p₂.mul (β*d₁)
let combine := { d := d₁*d₂, p := .add d x (α_d₂_p₁.combine β_d₁_p₂), h := .solveCombine c c' : DvdCnstr }
trace[grind.cutsat.dvd.solve.combine] "{← combine.pp}"
modify' fun s => { s with dvds := s.dvds.set x none}
combine.assert
let a₂_p₁ := p₁.mul a₂
let a₁_p₂ := p₂.mul (-a₁)
let elim := { d, p := a₂_p₁.combine a₁_p₂, h := .solveElim c c' : DvdCnstr }
trace[grind.cutsat.dvd.solve.elim] "{← elim.pp}"
elim.assert
else
trace[grind.cutsat.dvd.update] "{← c.pp}"
trace[grind.cutsat.assert.store] "{← c.pp}"
c.p.updateOccs
modify' fun s => { s with dvds := s.dvds.set x (some c) }
@@ -97,7 +94,6 @@ def propagateIntDvd (e : Expr) : GoalM Unit := do
if ( isEqTrue e) then
let p toPoly b
let c := { d, p, h := .core e : DvdCnstr }
trace[grind.cutsat.assert.dvd] "{← c.pp}"
c.assert
else if ( isEqFalse e) then
pushNewFact <| mkApp4 (mkConst ``Int.Linear.of_not_dvd) a b reflBoolTrue (mkOfEqFalseCore e ( mkEqFalseProof e))

View File

@@ -38,12 +38,12 @@ def DiseqCnstr.applyEq (a : Int) (x : Var) (c₁ : EqCnstr) (b : Int) (c₂ : Di
let p := c₁.p
let q := c₂.p
let p := p.mul b |>.combine (q.mul (-a))
trace[grind.cutsat.subst] "{← getVar x}, {← c₁.pp}, {← c₂.pp}"
trace[grind.debug.cutsat.subst] "{← getVar x}, {← c₁.pp}, {← c₂.pp}"
return { p, h := .subst x c₁ c₂ }
partial def DiseqCnstr.applySubsts (c : DiseqCnstr) : GoalM DiseqCnstr := withIncRecDepth do
let some (x, c₁, p) c.p.substVar | return c
trace[grind.cutsat.subst] "{← getVar x}, {← c.pp}, {← c₁.pp}"
trace[grind.debug.cutsat.subst] "{← getVar x}, {← c.pp}, {← c₁.pp}"
applySubsts { p, h := .subst x c₁ c }
/--
@@ -68,10 +68,11 @@ def DiseqCnstr.assert (c : DiseqCnstr) : GoalM Unit := do
trace[grind.cutsat.assert] "{← c.pp}"
let c c.norm.applySubsts
if c.p.isUnsatDiseq then
trace[grind.cutsat.assert.unsat] "{← c.pp}"
setInconsistent (.diseq c)
return ()
if c.isTrivial then
trace[grind.cutsat.diseq.trivial] "{← c.pp}"
trace[grind.cutsat.assert.trivial] "{← c.pp}"
return ()
let k := c.p.gcdCoeffs c.p.getConst
let c := if k == 1 then
@@ -82,7 +83,7 @@ def DiseqCnstr.assert (c : DiseqCnstr) : GoalM Unit := do
return ()
let .add _ x _ := c.p | c.throwUnexpected
c.p.updateOccs
trace[grind.cutsat.diseq] "{← c.pp}"
trace[grind.cutsat.assert.store] "{← c.pp}"
modify' fun s => { s with diseqs := s.diseqs.modify x (·.push c) }
if ( c.satisfied) == .false then
resetAssignmentFrom x
@@ -108,7 +109,7 @@ where
partial def EqCnstr.applySubsts (c : EqCnstr) : GoalM EqCnstr := withIncRecDepth do
let some (x, c₁, p) c.p.substVar | return c
trace[grind.cutsat.subst] "{← getVar x}, {← c.pp}, {← c₁.pp}"
trace[grind.debug.cutsat.subst] "{← getVar x}, {← c.pp}, {← c₁.pp}"
applySubsts { p, h := .subst x c₁ c : EqCnstr }
private def updateDvdCnstr (a : Int) (x : Var) (c : EqCnstr) (y : Var) : GoalM Unit := do
@@ -197,10 +198,11 @@ def EqCnstr.assertImpl (c : EqCnstr) : GoalM Unit := do
trace[grind.cutsat.assert] "{← c.pp}"
let c c.norm.applySubsts
if c.p.isUnsatEq then
trace[grind.cutsat.assert.unsat] "{← c.pp}"
setInconsistent (.eq c)
return ()
if c.isTrivial then
trace[grind.cutsat.eq.trivial] "{← c.pp}"
trace[grind.cutsat.assert.trivial] "{← c.pp}"
return ()
let k := c.p.gcdCoeffs'
if c.p.getConst % k > 0 then
@@ -210,9 +212,9 @@ def EqCnstr.assertImpl (c : EqCnstr) : GoalM Unit := do
c
else
{ p := c.p.div k, h := .divCoeffs c }
trace[grind.cutsat.eq] "{← c.pp}"
let some (k, x) := c.p.pickVarToElim? | c.throwUnexpected
trace[grind.debug.cutsat.subst] ">> {← getVar x}, {← c.pp}"
trace[grind.cutsat.assert.store] "{← c.pp}"
modify' fun s => { s with
elimEqs := s.elimEqs.set x (some c)
elimStack := x :: s.elimStack
@@ -252,7 +254,6 @@ private def processNewNatEq (a b : Expr) : GoalM Unit := do
@[export lean_process_cutsat_eq]
def processNewEqImpl (a b : Expr) : GoalM Unit := do
trace[grind.debug.cutsat.eq] "{a} = {b}"
match ( foreignTerm? a), ( foreignTerm? b) with
| none, none => processNewIntEq a b
| some .nat, some .nat => processNewNatEq a b
@@ -271,7 +272,6 @@ private def processNewIntLitEq (a ke : Expr) : GoalM Unit := do
@[export lean_process_cutsat_eq_lit]
def processNewEqLitImpl (a ke : Expr) : GoalM Unit := do
trace[grind.debug.cutsat.eq] "{a} = {ke}"
match ( foreignTerm? a) with
| none => processNewIntLitEq a ke
| some .nat => processNewNatEq a ke
@@ -294,12 +294,10 @@ private def processNewNatDiseq (a b : Expr) : GoalM Unit := do
let rhs' toLinearExpr ( rhs.denoteAsIntExpr ctx) gen
let p := lhs'.sub rhs' |>.norm
let c := { p, h := .coreNat a b lhs rhs lhs' rhs' : DiseqCnstr }
trace[grind.debug.cutsat.nat] "{← c.pp}"
c.assert
@[export lean_process_cutsat_diseq]
def processNewDiseqImpl (a b : Expr) : GoalM Unit := do
trace[grind.debug.cutsat.diseq] "{a} ≠ {b}"
match ( foreignTerm? a), ( foreignTermOrLit? b) with
| none, none => processNewIntDiseq a b
| some .nat, some .nat => processNewNatDiseq a b
@@ -342,7 +340,7 @@ private def isForbiddenParent (parent? : Option Expr) (k : SupportedTermKind) :
private def internalizeInt (e : Expr) : GoalM Unit := do
if ( hasVar e) then return ()
let p toPoly e
trace[grind.cutsat.internalize] "{aquote e}:= {← p.pp}"
trace[grind.debug.cutsat.internalize] "{aquote e}:= {← p.pp}"
let x mkVar e
if p == .add 1 x (.num 0) then
-- It is pointless to assert `x = x`
@@ -403,9 +401,8 @@ private def internalizeNat (e : Expr) : GoalM Unit := do
let e'' : Int.Linear.Expr toLinearExpr e'' gen
let p := e''.norm
let natCast_e shareCommon (mkIntNatCast e)
trace[grind.cutsat.internalize] "natCast: {natCast_e}"
internalize natCast_e gen
trace[grind.cutsat.internalize] "{aquote natCast_e}:= {← p.pp}"
trace[grind.debug.cutsat.internalize] "{aquote natCast_e}:= {← p.pp}"
let x mkVar natCast_e
modify' fun s => { s with foreignDef := s.foreignDef.insert { expr := e } x }
let c := { p := .add (-1) x p, h := .defnNat e' x e'' : EqCnstr }

View File

@@ -24,7 +24,6 @@ def mkForeignVar (e : Expr) (t : ForeignType) : GoalM Var := do
foreignVars := s.foreignVars.insert t (vars.push e)
foreignVarMap := s.foreignVarMap.insert { expr := e} (x, t)
}
trace[grind.debug.cutsat.markTerm] "mkForeignVar: {e}"
markAsCutsatTerm e
return x

View File

@@ -100,24 +100,24 @@ where
@[export lean_grind_cutsat_assert_le]
def LeCnstr.assertImpl (c : LeCnstr) : GoalM Unit := do
if ( inconsistent) then return ()
trace[grind.cutsat.assert] "{← c.pp}"
let c c.norm.applySubsts
if c.isUnsat then
trace[grind.cutsat.le.unsat] "{← c.pp}"
trace[grind.cutsat.assert.unsat] "{← c.pp}"
setInconsistent (.le c)
return ()
if c.isTrivial then
trace[grind.cutsat.le.trivial] "{← c.pp}"
trace[grind.cutsat.assert.trivial] "{← c.pp}"
return ()
let .add a x _ := c.p | c.throwUnexpected
if ( findEq c) then
return ()
let c refineWithDiseq c
trace[grind.cutsat.assert.store] "{← c.pp}"
if a < 0 then
trace[grind.cutsat.le.lower] "{← c.pp}"
c.p.updateOccs
modify' fun s => { s with lowers := s.lowers.modify x (·.push c) }
else
trace[grind.cutsat.le.upper] "{← c.pp}"
c.p.updateOccs
modify' fun s => { s with uppers := s.uppers.modify x (·.push c) }
if ( c.satisfied) == .false then
@@ -145,7 +145,6 @@ def propagateIntLe (e : Expr) (eqTrue : Bool) : GoalM Unit := do
pure { p, h := .core e : LeCnstr }
else
pure { p := p.mul (-1) |>.addConst 1, h := .coreNeg e p : LeCnstr }
trace[grind.cutsat.assert.le] "{← c.pp}"
c.assert
def propagateNatLe (e : Expr) (eqTrue : Bool) : GoalM Unit := do
@@ -155,7 +154,6 @@ def propagateNatLe (e : Expr) (eqTrue : Bool) : GoalM Unit := do
let lhs' toLinearExpr ( lhs.denoteAsIntExpr ctx) gen
let rhs' toLinearExpr ( rhs.denoteAsIntExpr ctx) gen
let p := lhs'.sub rhs' |>.norm
trace[grind.debug.cutsat.nat] "{← p.pp}"
let c if eqTrue then
pure { p, h := .coreNat e lhs rhs lhs' rhs' : LeCnstr }
else

View File

@@ -136,9 +136,7 @@ def assertDenoteAsIntNonneg (e : Expr) : GoalM Unit := withIncRecDepth do
let lhs' : Int.Linear.Expr := .num 0
let rhs' toLinearExpr ( rhs.denoteAsIntExpr ctx) gen
let p := lhs'.sub rhs' |>.norm
trace[grind.debug.cutsat.nat] "{← p.pp}"
let c := { p, h := .denoteAsIntNonneg rhs rhs' : LeCnstr }
trace[grind.cutsat.assert.le] "{← c.pp}"
c.assert
/--
@@ -149,7 +147,6 @@ def assertNatCast (e : Expr) (x : Var) : GoalM Unit := do
let_expr NatCast.natCast _ inst a := e | return ()
let_expr instNatCastInt := inst | return ()
if ( get').foreignDef.contains { expr := a } then return ()
trace[grind.debug.cutsat.natCast] "{a}"
let n mkForeignVar a .nat
let p := .add (-1) x (.num 0)
let c := { p, h := .denoteAsIntNonneg (.var n) (.var x) : LeCnstr}

View File

@@ -301,7 +301,6 @@ partial def LeCnstr.toExprProof (c' : LeCnstr) : ProofM Expr := caching c' do
( getContext) ( mkPolyDecl p₁) ( mkPolyDecl p₂) ( mkPolyDecl c₃.p) (toExpr c₃.d) (toExpr s.k) (toExpr coeff) ( mkPolyDecl c'.p) ( s.toExprProof) reflBoolTrue
partial def DiseqCnstr.toExprProof (c' : DiseqCnstr) : ProofM Expr := caching c' do
trace[grind.debug.cutsat.proof] "{← c'.pp}"
match c'.h with
| .core0 a zero =>
mkDiseqProof a zero
@@ -352,7 +351,6 @@ partial def CooperSplit.toExprProof (s : CooperSplit) : ProofM Expr := caching s
-- `pred` is an expressions of the form `cooper_*_split ...` with type `Nat → Prop`
let mut k := n
let mut result := base -- `OrOver k (cooper_*_splti)
trace[grind.debug.cutsat.proof] "orOver_cases {n}"
result := mkApp3 (mkConst ``Int.Linear.orOver_cases) (toExpr (n-1)) pred result
for (fvarId, c) in hs do
let type := mkApp pred (toExpr (k-1))
@@ -365,7 +363,6 @@ partial def CooperSplit.toExprProof (s : CooperSplit) : ProofM Expr := caching s
return result
partial def UnsatProof.toExprProofCore (h : UnsatProof) : ProofM Expr := do
trace[grind.debug.cutsat.proof] "{← h.pp}"
match h with
| .le c =>
return mkApp4 (mkConst ``Int.Linear.le_unsat) ( getContext) ( mkPolyDecl c.p) reflBoolTrue ( c.toExprProof)
@@ -392,7 +389,6 @@ def UnsatProof.toExprProof (h : UnsatProof) : GoalM Expr := do
withProofContext do h.toExprProofCore
def setInconsistent (h : UnsatProof) : GoalM Unit := do
trace[grind.debug.cutsat.conflict] "setInconsistent [{← inconsistent}]: {← h.pp}"
if ( get').caseSplits then
-- Let the search procedure in `SearchM` resolve the conflict.
modify' fun s => { s with conflict? := some h }

View File

@@ -27,7 +27,6 @@ def CooperSplit.assert (cs : CooperSplit) : GoalM Unit := do
let p₁' := p.mul b |>.combine (q.mul (-a))
let p₁' := p₁'.addConst <| if left then b*k else (-a)*k
let c₁' := { p := p₁', h := .cooper cs : LeCnstr }
trace[grind.debug.cutsat.cooper] "{← c₁'.pp}"
c₁'.assert
if ( inconsistent) then return ()
let d := if left then a else b
@@ -35,7 +34,6 @@ def CooperSplit.assert (cs : CooperSplit) : GoalM Unit := do
let p₂' := if left then p else q
let p₂' := p₂'.addConst k
let c₂' := { d, p := p₂', h := .cooper₁ cs : DvdCnstr }
trace[grind.debug.cutsat.cooper] "dvd₁: {← c₂'.pp}"
c₂'.assert
if ( inconsistent) then return ()
let some c₃ := c₃? | return ()
@@ -51,7 +49,6 @@ def CooperSplit.assert (cs : CooperSplit) : GoalM Unit := do
let p₃' := q.mul (-c) |>.combine (s.mul b)
let p₃' := p₃'.addConst (-c*k)
{ d := b*d, p := p₃', h := .cooper₂ cs : DvdCnstr }
trace[grind.debug.cutsat.cooper] "dvd₂: {← c₃'.pp}"
c₃'.assert
private def checkIsNextVar (x : Var) : GoalM Unit := do
@@ -59,7 +56,7 @@ private def checkIsNextVar (x : Var) : GoalM Unit := do
throwError "`grind` internal error, assigning variable out of order"
private def traceAssignment (x : Var) (v : Rat) : GoalM Unit := do
trace[grind.cutsat.assign] "{quoteIfArithTerm (← getVar x)} := {v}"
trace[grind.debug.cutsat.search.assign] "{quoteIfArithTerm (← getVar x)} := {v}"
private def setAssignment (x : Var) (v : Rat) : GoalM Unit := do
checkIsNextVar x
@@ -88,7 +85,6 @@ where
modify' fun s => { s with assignment := s.assignment.set x 0 }
let some v c.p.eval? | c.throwUnexpected
let v := (-v) / a
trace[grind.debug.cutsat.assign] "{← getVar x}, {← c.pp}, {v}"
traceAssignment x v
modify' fun s => { s with assignment := s.assignment.set x v }
go xs
@@ -108,7 +104,6 @@ def tightUsingDvd (c : LeCnstr) (dvd? : Option DvdCnstr) : GoalM LeCnstr := do
let b₂ := c.p.getConst
if (b₂ - b₁) % d != 0 then
let b₂' := b₁ - d * ((b₁ - b₂) / d)
trace[grind.debug.cutsat.dvd.le] "[pos] {← c.pp}, {← dvd.pp}, {b₂'}"
let p := c.p.addConst (b₂'-b₂)
return { p, h := .dvdTight dvd c }
if eqCoeffs dvd.p c.p true then
@@ -116,7 +111,6 @@ def tightUsingDvd (c : LeCnstr) (dvd? : Option DvdCnstr) : GoalM LeCnstr := do
let b₂ := c.p.getConst
if (b₂ - b₁) % d != 0 then
let b₂' := b₁ - d * ((b₁ - b₂) / d)
trace[grind.debug.cutsat.dvd.le] "[neg] {← c.pp}, {← dvd.pp}, {b₂'}"
let p := c.p.addConst (b₂'-b₂)
return { p, h := .negDvdTight dvd c }
return c
@@ -134,7 +128,6 @@ def getBestLower? (x : Var) (dvd? : Option DvdCnstr) : GoalM (Option (Rat × LeC
let .add k _ p := c.p | c.throwUnexpected
let some v p.eval? | c.throwUnexpected
let lower' := v / (-k)
trace[grind.debug.cutsat.getBestLower] "k: {k}, x: {x}, p: {repr p}, v: {v}, best?: {best?.map (·.1)}, c: {← c.pp}"
if let some (lower, _) := best? then
if lower' > lower then
best? := some (lower', c)
@@ -214,7 +207,7 @@ def DvdCnstr.getSolutions? (c : DvdCnstr) : SearchM (Option DvdSolution) := do
return some { d, b := -b*a' }
def resolveDvdConflict (c : DvdCnstr) : GoalM Unit := do
trace[grind.cutsat.conflict] "{← c.pp}"
trace[grind.debug.cutsat.search.conflict] "{← c.pp}"
let d := c.d
let .add a _ p := c.p | c.throwUnexpected
{ d := a.gcd d, p, h := .elim c : DvdCnstr }.assert
@@ -300,7 +293,7 @@ partial def findRatVal (lower upper : Rat) (diseqVals : Array (Rat × DiseqCnstr
v
def resolveRealLowerUpperConflict (c₁ c₂ : LeCnstr) : GoalM Bool := do
trace[grind.cutsat.conflict] "{← c₁.pp}, {← c₂.pp}"
trace[grind.debug.cutsat.search.conflict] "{← c₁.pp}, {← c₂.pp}"
let .add a₁ _ p₁ := c₁.p | c₁.throwUnexpected
let .add a₂ _ p₂ := c₂.p | c₂.throwUnexpected
let p := p₁.mul a₂.natAbs |>.combine (p₂.mul a₁.natAbs)
@@ -313,7 +306,7 @@ def resolveRealLowerUpperConflict (c₁ c₂ : LeCnstr) : GoalM Bool := do
{ p, h := .combine c₁ c₂ : LeCnstr }
else
{ p := p.div k, h := .combineDivCoeffs c₁ c₂ k : LeCnstr }
trace[grind.cutsat.conflict] "resolved: {← c.pp}"
trace[grind.debug.cutsat.search.conflict] "resolved: {← c.pp}"
c.assert
return true
@@ -330,7 +323,7 @@ def resolveCooperUnary (pred : CooperSplitPred) : SearchM Bool := do
return true
def resolveCooperPred (pred : CooperSplitPred) : SearchM Unit := do
trace[grind.cutsat.conflict] "[{pred.numCases}]: {← pred.pp}"
trace[grind.debug.cutsat.search.conflict] "[{pred.numCases}]: {← pred.pp}"
if ( resolveCooperUnary pred) then
return
let n := pred.numCases
@@ -347,11 +340,11 @@ def resolveCooperDvd (c₁ c₂ : LeCnstr) (c₃ : DvdCnstr) : SearchM Unit := d
def DiseqCnstr.split (c : DiseqCnstr) : SearchM LeCnstr := do
let fvarId if let some fvarId := ( get').diseqSplits.find? c.p then
trace[grind.debug.cutsat.diseq.split] "{← c.pp}, reusing {fvarId.name}"
trace[grind.debug.cutsat.search.split] "{← c.pp}, reusing {fvarId.name}"
pure fvarId
else
let fvarId mkCase (.diseq c)
trace[grind.debug.cutsat.diseq.split] "{← c.pp}, {fvarId.name}"
trace[grind.debug.cutsat.search.split] "{← c.pp}, {fvarId.name}"
modify' fun s => { s with diseqSplits := s.diseqSplits.insert c.p fvarId }
pure fvarId
let p₂ := c.p.addConst 1
@@ -428,7 +421,6 @@ def processVar (x : Var) : SearchM Unit := do
setAssignment x v
| some (lower, c₁), some (upper, c₂) =>
trace[grind.debug.cutsat.search] "{lower} ≤ {lower.ceil} ≤ {quoteIfArithTerm (← getVar x)} ≤ {upper.floor} ≤ {upper}"
trace[grind.debug.cutsat.getBestLower] "lower: {lower}, c₁: {← c₁.pp}"
if lower > upper then
let .true resolveRealLowerUpperConflict c₁ c₂
| throwError "`grind` internal error, conflict resolution failed"
@@ -472,43 +464,42 @@ private def findCase (decVars : FVarIdSet) : SearchM Case := do
if decVars.contains case.fvarId then
return case
-- Conflict does not depend on this case.
trace[grind.debug.cutsat.backtrack] "skipping {case.fvarId.name}"
trace[grind.debug.cutsat.search.backtrack] "skipping {case.fvarId.name}"
unreachable!
private def union (vs₁ vs₂ : FVarIdSet) : FVarIdSet :=
vs₁.fold (init := vs₂) (·.insert ·)
def resolveConflict (h : UnsatProof) : SearchM Unit := do
trace[grind.debug.cutsat.backtrack] "resolve conflict, decision stack: {(← get).cases.toList.map fun c => c.fvarId.name}"
trace[grind.debug.cutsat.search.backtrack] "resolve conflict, decision stack: {(← get).cases.toList.map fun c => c.fvarId.name}"
let decVars := h.collectDecVars.run ( get).decVars
trace[grind.debug.cutsat.backtrack] "dec vars: {decVars.toList.map (·.name)}"
trace[grind.debug.cutsat.search.backtrack] "dec vars: {decVars.toList.map (·.name)}"
if decVars.isEmpty then
trace[grind.debug.cutsat.backtrack] "close goal: {← h.pp}"
trace[grind.debug.cutsat.search.backtrack] "close goal: {← h.pp}"
closeGoal ( h.toExprProof)
return ()
let c findCase decVars
modify' fun _ => c.saved
trace[grind.debug.cutsat.backtrack] "backtracking {c.fvarId.name}"
trace[grind.debug.cutsat.search.backtrack] "backtracking {c.fvarId.name}"
let decVars := decVars.erase c.fvarId
match c.kind with
| .diseq c₁ =>
let decVars := decVars.toArray
let p' := c₁.p.mul (-1) |>.addConst 1
let c' := { p := p', h := .ofDiseqSplit c₁ c.fvarId h decVars : LeCnstr }
trace[grind.debug.cutsat.backtrack] "resolved diseq split: {← c'.pp}"
trace[grind.debug.cutsat.search.backtrack] "resolved diseq split: {← c'.pp}"
c'.assert
| .cooper pred hs decVars' =>
let decVars' := union decVars decVars'
let n := pred.numCases
let hs := hs.push (c.fvarId, h)
trace[grind.debug.cutsat.backtrack] "cooper #{hs.size + 1}, {← pred.pp}, {hs.map fun p => p.1.name}"
trace[grind.debug.cutsat.search.backtrack] "cooper #{hs.size + 1}, {← pred.pp}, {hs.map fun p => p.1.name}"
let s if hs.size + 1 < n then
let fvarId mkCase (.cooper pred hs decVars')
pure { pred, k := n - hs.size - 1, h := .dec fvarId : CooperSplit }
else
let decVars' := decVars'.toArray
trace[grind.debug.cutsat.backtrack] "cooper last case, {← pred.pp}, dec vars: {decVars'.map (·.name)}"
trace[grind.debug.cutsat.proof] "CooperSplit.last"
trace[grind.debug.cutsat.search.backtrack] "cooper last case, {← pred.pp}, dec vars: {decVars'.map (·.name)}"
pure { pred, k := 0, h := .last hs decVars' : CooperSplit }
s.assert

View File

@@ -74,7 +74,6 @@ def mkCase (kind : CaseKind) : SearchM FVarId := do
decVars := s.decVars.insert fvarId
}
modify' fun s => { s with caseSplits := true }
trace[grind.debug.cutsat.backtrack] "mkCase fvarId: {fvarId.name}"
return fvarId
end Lean.Meta.Grind.Arith.Cutsat

View File

@@ -16,7 +16,7 @@ def mkVarImpl (expr : Expr) : GoalM Var := do
if let some var := ( get').varMap.find? { expr } then
return var
let var : Var := ( get').vars.size
trace[grind.cutsat.internalize.term] "{expr} ↦ #{var}"
trace[grind.debug.cutsat.internalize] "{expr} ↦ #{var}"
modify' fun s => { s with
vars := s.vars.push expr
varMap := s.varMap.insert { expr } var
@@ -27,7 +27,6 @@ def mkVarImpl (expr : Expr) : GoalM Var := do
occurs := s.occurs.push {}
elimEqs := s.elimEqs.push none
}
trace[grind.debug.cutsat.markTerm] "mkVar: {expr}"
markAsCutsatTerm expr
assertNatCast expr var
assertDenoteAsIntNonneg expr

View File

@@ -71,6 +71,7 @@ def isArithTerm (e : Expr) : Bool :=
| HMul.hMul _ _ _ _ _ _ => true
| HDiv.hDiv _ _ _ _ _ _ => true
| HMod.hMod _ _ _ _ _ _ => true
| HPow.hPow _ _ _ _ _ _ => true
| Neg.neg _ _ _ => true
| OfNat.ofNat _ _ _ => true
| _ => false

View File

@@ -6,6 +6,7 @@ Authors: Leonardo de Moura
prelude
import Lean.Meta.Tactic.Grind.EMatchTheorem
import Lean.Meta.Tactic.Grind.Cases
import Lean.Meta.Tactic.Grind.ExtAttr
namespace Lean.Meta.Grind
@@ -14,6 +15,7 @@ inductive AttrKind where
| cases (eager : Bool)
| intro
| infer
| ext
/-- Return theorem kind for `stx` of the form `Attr.grindThmMod` -/
def getAttrKindCore (stx : Syntax) : CoreM AttrKind := do
@@ -34,6 +36,7 @@ def getAttrKindCore (stx : Syntax) : CoreM AttrKind := do
| `(Parser.Attr.grindMod| cases) => return .cases false
| `(Parser.Attr.grindMod| cases eager) => return .cases true
| `(Parser.Attr.grindMod| intro) => return .intro
| `(Parser.Attr.grindMod| ext) => return .ext
| _ => throwError "unexpected `grind` theorem kind: `{stx}`"
/-- Return theorem kind for `stx` of the form `(Attr.grindMod)?` -/
@@ -78,6 +81,7 @@ builtin_initialize
addEMatchAttr ctor attrKind .default
else
throwError "invalid `[grind intro]`, `{declName}` is not an inductive predicate"
| .ext => addExtAttr declName attrKind
| .infer =>
if let some declName isCasesAttrCandidate? declName false then
addCasesAttr declName false attrKind
@@ -91,6 +95,8 @@ builtin_initialize
erase := fun declName => MetaM.run' do
if ( isCasesAttrCandidate declName false) then
eraseCasesAttr declName
else if ( isExtTheorem declName) then
eraseExtAttr declName
else
eraseEMatchAttr declName
}

View File

@@ -35,6 +35,7 @@ def instantiateExtTheorem (thm : Ext.ExtTheorem) (e : Expr) : GoalM Unit := with
if proof'.hasMVar || prop'.hasMVar then
reportIssue! "failed to apply extensionality theorem `{thm.declName}` for {indentExpr e}\nresulting terms contain metavariables"
return ()
trace[grind.ext] "{prop'}"
addNewRawFact proof' prop' (( getGeneration e) + 1)
end Lean.Meta.Grind

View File

@@ -0,0 +1,43 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Ext
namespace Lean.Meta.Grind
/-! Grind extensionality attribute to mark which `[ext]` theorems should be used. -/
/-- Extensionality theorems that can be used by `grind` -/
abbrev ExtTheorems := PHashSet Name
builtin_initialize extTheoremsExt : SimpleScopedEnvExtension Name ExtTheorems
registerSimpleScopedEnvExtension {
initial := {}
addEntry := fun s declName => s.insert declName
}
def validateExtAttr (declName : Name) : CoreM Unit := do
unless ( Ext.isExtTheorem declName) do
throwError "invalid `[grind ext]`, `{declName}` is tagged with `[ext]`"
def addExtAttr (declName : Name) (attrKind : AttributeKind) : CoreM Unit := do
validateExtAttr declName
extTheoremsExt.add declName attrKind
private def eraseDecl (s : ExtTheorems) (declName : Name) : CoreM ExtTheorems := do
if s.contains declName then
return s.erase declName
else
throwError "`{declName}` is not marked with the `[grind ext]` attribute"
def eraseExtAttr (declName : Name) : CoreM Unit := do
let s := extTheoremsExt.getState ( getEnv)
let s eraseDecl s declName
modifyEnv fun env => extTheoremsExt.modifyState env fun _ => s
def isExtTheorem (declName : Name) : CoreM Bool := do
return extTheoremsExt.getState ( getEnv) |>.contains declName
end Lean.Meta.Grind

View File

@@ -114,4 +114,13 @@ def propagateForallPropDown (e : Expr) : GoalM Unit := do
-- (a → b) = True → b = False → a = False
pushEqFalse a <| mkApp4 (mkConst ``Grind.eq_false_of_imp_eq_true) a b ( mkEqTrueProof e) ( mkEqFalseProof b)
builtin_grind_propagator propagateExistsDown Exists := fun e => do
if ( isEqFalse e) then
let_expr f@Exists α p := e | return ()
let u := f.constLevels!
let notP := mkApp (mkConst ``Not) (mkApp p (.bvar 0) |>.headBeta)
let prop := mkForall `x .default α notP
let proof := mkApp3 (mkConst ``forall_not_of_not_exists u) α p (mkOfEqFalseCore e ( mkEqFalseProof e))
addNewRawFact proof prop ( getGeneration e)
end Lean.Meta.Grind

View File

@@ -50,11 +50,6 @@ private def updateAppMap (e : Expr) : GoalM Unit := do
s.appMap.insert key [e]
}
/-- Inserts `e` into the list of case-split candidates. -/
private def addSplitCandidate (e : Expr) : GoalM Unit := do
trace_goal[grind.split.candidate] "{e}"
modify fun s => { s with split.candidates := e :: s.split.candidates }
private def forbiddenSplitTypes := [``Eq, ``HEq, ``True, ``False]
/-- Returns `true` if `e` is of the form `@Eq Prop a b` -/
@@ -67,10 +62,10 @@ private def checkAndAddSplitCandidate (e : Expr) : GoalM Unit := do
match e with
| .app .. =>
if ( getConfig).splitIte && (e.isIte || e.isDIte) then
addSplitCandidate e
addSplitCandidate (.default e)
return ()
if isMorallyIff e then
addSplitCandidate e
addSplitCandidate (.default e)
return ()
if ( getConfig).splitMatch then
if ( isMatcherApp e) then
@@ -79,7 +74,7 @@ private def checkAndAddSplitCandidate (e : Expr) : GoalM Unit := do
-- and consequently don't need to be split.
return ()
else
addSplitCandidate e
addSplitCandidate (.default e)
return ()
let .const declName _ := e.getAppFn | return ()
if forbiddenSplitTypes.contains declName then
@@ -87,16 +82,21 @@ private def checkAndAddSplitCandidate (e : Expr) : GoalM Unit := do
unless ( isInductivePredicate declName) do
return ()
if ( get).split.casesTypes.isSplit declName then
addSplitCandidate e
addSplitCandidate (.default e)
else if ( getConfig).splitIndPred then
addSplitCandidate e
addSplitCandidate (.default e)
| .fvar .. =>
let .const declName _ := ( whnfD ( inferType e)).getAppFn | return ()
if ( get).split.casesTypes.isSplit declName then
addSplitCandidate e
addSplitCandidate (.default e)
| .forallE _ d _ _ =>
if Arith.isRelevantPred d || ( getConfig).splitImp then
addSplitCandidate e
if ( getConfig).splitImp then
addSplitCandidate (.default e)
else if Arith.isRelevantPred d then
if ( getConfig).lookahead then
addLookaheadCandidate (.default e)
else
addSplitCandidate (.default e)
| _ => pure ()
/--
@@ -172,7 +172,7 @@ private def activateTheoremPatterns (fName : Name) (generation : Nat) : GoalM Un
modify fun s => { s with ematch.thmMap := thmMap }
let appMap := ( get).appMap
for thm in thms do
trace[grind.debug.ematch.activate] "`{fName}` => `{thm.origin.key}`"
trace_goal[grind.debug.ematch.activate] "`{fName}` => `{thm.origin.key}`"
unless ( get).ematch.thmMap.isErased thm.origin do
let symbols := thm.symbols.filter fun sym => !appMap.contains sym
let thm := { thm with symbols }
@@ -207,6 +207,68 @@ private def propagateUnitLike (a : Expr) (generation : Nat) : GoalM Unit := do
internalize unit generation
pushEq a unit <| ( mkEqRefl unit)
/-- Returns `true` if we can ignore `ext` for functions occurring as arguments of a `declName`-application. -/
private def extParentsToIgnore (declName : Name) : Bool :=
declName == ``Eq || declName == ``HEq || declName == ``dite || declName == ``ite
|| declName == ``Exists || declName == ``Subtype
/--
Given a term `arg` that occurs as the argument at position `i` of an `f`-application `parent?`,
we consider `arg` as a candidate for case-splitting. For every other argument `arg'` that also appears
at position `i` in an `f`-application and has the same type as `e`, we add the case-split candidate `arg = arg'`.
When performing the case split, we consider the following two cases:
- `arg = arg'`, which may introduce a new congruence between the corresponding `f`-applications.
- `¬(arg = arg')`, which may trigger extensionality theorems for the type of `arg`.
This feature enables `grind` to solve examples such as:
```lean
example (f : (Nat → Nat) → Nat) : a = b → f (fun x => a + x) = f (fun x => b + x) := by
grind
```
-/
private def addSplitCandidatesForExt (arg : Expr) (generation : Nat) (parent? : Option Expr := none) : GoalM Unit := do
let some parent := parent? | return ()
unless parent.isApp do return ()
let f := parent.getAppFn
if let .const declName _ := f then
if extParentsToIgnore declName then return ()
let type inferType arg
-- Remark: we currently do not perform function extensionality on functions that produce a type that is not a proposition.
-- We may add an option to enable that in the future.
let u? typeFormerTypeLevel type
if u? != .none && u? != some .zero then return ()
let mut i := parent.getAppNumArgs
let mut it := parent
repeat
if !it.isApp then return ()
i := i - 1
if isSameExpr arg it.appArg! then
found f i type parent
it := it.appFn!
where
found (f : Expr) (i : Nat) (type : Expr) (parent : Expr) : GoalM Unit := do
trace_goal[grind.debug.ext] "{f}, {i}, {arg}"
let others := ( get).split.argsAt.find? (f, i) |>.getD []
for other in others do
if ( withDefault <| isDefEq type other.type) then
let eq := mkApp3 (mkConst ``Eq [ getLevel type]) type arg other.arg
let eq shareCommon eq
internalize eq generation
trace_goal[grind.ext.candidate] "{eq}"
-- We do not use lookahead here because it is too incomplete.
-- if (← getConfig).lookahead then
-- addLookaheadCandidate (.arg other.app parent i eq)
-- else
addSplitCandidate (.arg other.app parent i eq)
modify fun s => { s with split.argsAt := s.split.argsAt.insert (f, i) ({ arg, type, app := parent } :: others) }
return ()
/-- Applies `addSplitCandidatesForExt` if `funext` is enabled. -/
private def addSplitCandidatesForFunext (arg : Expr) (generation : Nat) (parent? : Option Expr := none) : GoalM Unit := do
unless ( getConfig).funext do return ()
addSplitCandidatesForExt arg generation parent?
@[export lean_grind_internalize]
private partial def internalizeImpl (e : Expr) (generation : Nat) (parent? : Option Expr := none) : GoalM Unit := withIncRecDepth do
if ( alreadyInternalized e) then
@@ -229,7 +291,10 @@ private partial def internalizeImpl (e : Expr) (generation : Nat) (parent? : Opt
| .fvar .. =>
mkENode' e generation
checkAndAddSplitCandidate e
| .letE .. | .lam .. =>
| .letE .. =>
mkENode' e generation
| .lam .. =>
addSplitCandidatesForFunext e generation parent?
mkENode' e generation
| .forallE _ d b _ =>
mkENode' e generation

View File

@@ -178,6 +178,12 @@ private def isEagerCasesCandidate (goal : Goal) (type : Expr) : Bool := Id.run d
let .const declName _ := type.getAppFn | return false
return goal.split.casesTypes.isEagerSplit declName
/-- Returns `true` if `type` is an inductive type with at most one constructor. -/
private def isCheapInductive (type : Expr) : CoreM Bool := do
let .const declName _ := type.getAppFn | return false
let .inductInfo info getConstInfo declName | return false
return info.numCtors <= 1
private def applyCases? (goal : Goal) (fvarId : FVarId) : GrindM (Option (List Goal)) := goal.mvarId.withContext do
/-
Remark: we used to use `whnfD`. This was a mistake, we don't want to unfold user-defined abstractions.
@@ -185,6 +191,9 @@ private def applyCases? (goal : Goal) (fvarId : FVarId) : GrindM (Option (List G
-/
let type whnf ( fvarId.getType)
if isEagerCasesCandidate goal type then
if ( cheapCasesOnly) then
unless ( isCheapInductive type) do
return none
if let .const declName _ := type.getAppFn then
saveCases declName true
let mvarIds cases goal.mvarId (mkFVar fvarId)
@@ -205,7 +214,7 @@ private def exfalsoIfNotProp (goal : Goal) : MetaM Goal := goal.mvarId.withConte
return { goal with mvarId := ( goal.mvarId.exfalso) }
/-- Introduce new hypotheses (and apply `by_contra`) until goal is of the form `... ⊢ False` -/
partial def intros (generation : Nat) : GrindTactic' := fun goal => do
partial def intros (generation : Nat) : GrindTactic' := fun goal => do
let rec go (goal : Goal) : StateRefT (Array Goal) GrindM Unit := do
if goal.inconsistent then
return ()

View File

@@ -0,0 +1,101 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Intro
import Lean.Meta.Tactic.Grind.Arith
import Lean.Meta.Tactic.Grind.Split
import Lean.Meta.Tactic.Grind.EMatch
namespace Lean.Meta.Grind
private partial def solve (generation : Nat) (goal : Goal) : GrindM Bool := do
cont ( intros generation goal)
where
cont (goals : List Goal) : GrindM Bool := do
match goals with
| [] => return true
| [goal] => loop goal
| _ => throwError "`grind` lookahead internal error, unexpected number of goals"
loop (goal : Goal) : GrindM Bool := withIncRecDepth do
if goal.inconsistent then
return true
else if let some goals assertNext goal then
cont goals
else if let some goals Arith.check goal then
cont goals
else if let some goals splitNext goal then
cont goals
else if let some goals ematchAndAssert goal then
cont goals
else
return false
private def tryLookahead (e : Expr) : GoalM Bool := do
-- TODO: if `e` is an arithmetic expression, we can avoid creating an auxiliary goal.
-- We can assert it directly to the arithmetic module.
-- Remark: We can simplify this code because the lookahead only really worked for arithmetic.
trace_goal[grind.lookahead.try] "{e}"
let proof? withoutModifyingState do
let goal get
let tag goal.mvarId.getTag
let target mkArrow (mkNot e) ( getFalseExpr)
let mvar mkFreshExprMVar target .syntheticOpaque tag
let gen getGeneration e
if ( solve gen { goal with mvarId := mvar.mvarId! }) then
return some ( instantiateMVars mvar)
else
return none
if let some proof := proof? then
trace[grind.lookahead.assert] "{e}"
pushEqTrue e <| mkApp2 (mkConst ``Grind.of_lookahead) e proof
processNewFacts
return true
else
return false
private def withLookaheadConfig (x : GrindM α) : GrindM α := do
withTheReader Grind.Context
(fun ctx => { ctx with config.qlia := true, cheapCases := true })
x
def lookahead : GrindTactic := fun goal => do
unless ( getConfig).lookahead do
return none
if goal.split.lookaheads.isEmpty then
return none
withLookaheadConfig do
let (progress, goal) GoalM.run goal do
let mut postponed := []
let mut progress := false
let infos := ( get).split.lookaheads
modify fun s => { s with split.lookaheads := [] }
for info in infos do
if ( isInconsistent) then
return true
match ( checkSplitStatus info) with
| .resolved => progress := true
| .ready _ _ true
| .notReady => postponed := info :: postponed
| .ready _ _ false =>
if ( tryLookahead info.getExpr) then
progress := true
else
postponed := info :: postponed
if progress then
modify fun s => { s with
split.lookaheads := s.split.lookaheads ++ postponed.reverse
}
return true
else
return false
if progress then
return some [goal]
else
return none
end Lean.Meta.Grind

View File

@@ -33,17 +33,26 @@ structure MBTC.Context where
-/
eqAssignment : Expr Expr GoalM Bool
private abbrev Map := Std.HashMap (Expr × Nat) (List Expr)
private abbrev Candidates := Std.HashSet (Expr × Expr)
private def mkCandidateKey (a b : Expr) : Expr × Expr :=
if a.lt b then
(a, b)
private structure ArgInfo where
arg : Expr
app : Expr
private abbrev Map := Std.HashMap (Expr × Nat) (List ArgInfo)
private abbrev Candidates := Std.HashSet SplitInfo
private def mkCandidate (a b : ArgInfo) (i : Nat) : GoalM SplitInfo := do
let (lhs, rhs) := if a.arg.lt b.arg then
(a.arg, b.arg)
else
(b, a)
(b.arg, a.arg)
let eq mkEq lhs rhs
let eq shareCommon ( canon eq)
return .arg a.app b.app i eq
/-- Model-based theory combination. -/
def mbtc (ctx : MBTC.Context) : GoalM Bool := do
unless ( getConfig).mbtc do return false
-- It is pointless to run `mbtc` if maximum number of splits has been reached.
if ( checkMaxCaseSplit) then return false
let mut map : Map := {}
let mut candidates : Candidates := {}
for ({ expr := e }, _) in ( get).enodes do
@@ -56,33 +65,33 @@ def mbtc (ctx : MBTC.Context) : GoalM Bool := do
let some arg getRoot? arg | pure ()
if ( ctx.hasTheoryVar arg) then
trace[grind.debug.mbtc] "{arg} @ {f}:{i}"
if let some others := map[(f, i)]? then
unless others.any (isSameExpr arg ·) do
for other in others do
if ( ctx.eqAssignment arg other) then
let k := mkCandidateKey arg other
candidates := candidates.insert k
map := map.insert (f, i) (arg :: others)
let argInfo : ArgInfo := { arg, app := e }
if let some otherInfos := map[(f, i)]? then
unless otherInfos.any fun info => isSameExpr arg info.arg do
for otherInfo in otherInfos do
if ( ctx.eqAssignment arg otherInfo.arg) then
if ( hasSameType arg otherInfo.arg) then
candidates := candidates.insert ( mkCandidate argInfo otherInfo i)
map := map.insert (f, i) (argInfo :: otherInfos)
else
map := map.insert (f, i) [arg]
map := map.insert (f, i) [argInfo]
i := i + 1
if candidates.isEmpty then
return false
if ( get).split.num > ( getConfig).splits then
reportIssue "skipping `mbtc`, maximum number of splits has been reached `(splits := {(← getConfig).splits})`"
return false
let result := candidates.toArray.qsort fun (a₁, b₁) (a₂, b₂) =>
if isSameExpr a₁ a₂ then
b₁.lt b₂
else
a₁.lt a₂
let eqs result.mapM fun (a, b) => do
let eq mkEq a b
trace[grind.mbtc] "{eq}"
let eq shareCommon ( canon eq)
let result := candidates.toArray.qsort fun c₁ c₂ => c₁.lt c₂
let result result.filterMapM fun info => do
if ( isKnownCaseSplit info) then
return none
let .arg a b _ eq := info | return none
internalize eq (Nat.max ( getGeneration a) ( getGeneration b))
return eq
modify fun s => { s with split.candidates := s.split.candidates ++ eqs.toList }
return some info
if result.isEmpty then
return false
for info in result do
addSplitCandidate info
return true
def mbtcTac (ctx : MBTC.Context) : GrindTactic := fun goal => do

View File

@@ -9,18 +9,6 @@ import Lean.Meta.Tactic.Grind.Simp
namespace Lean.Meta.Grind
/--
Helper function for executing `x` with a fresh `newFacts` and without modifying
the goal state.
-/
private def withoutModifyingState (x : GoalM α) : GoalM α := do
let saved get
modify fun goal => { goal with newFacts := {} }
try
x
finally
set saved
/--
If `e` has not been internalized yet, instantiate metavariables, unfold reducible, canonicalize,
and internalize the result.

View File

@@ -8,6 +8,7 @@ import Lean.Meta.Tactic.Grind.Combinators
import Lean.Meta.Tactic.Grind.Split
import Lean.Meta.Tactic.Grind.EMatch
import Lean.Meta.Tactic.Grind.Arith
import Lean.Meta.Tactic.Grind.Lookahead
namespace Lean.Meta.Grind
@@ -62,6 +63,8 @@ def trySplit : Goal → M Bool := applyTac splitNext
def tryArith : Goal M Bool := applyTac Arith.check
def tryLookahead : Goal M Bool := applyTac lookahead
def tryMBTC : Goal M Bool := applyTac Arith.Cutsat.mbtcTac
def maxNumFailuresReached : M Bool := do
@@ -81,6 +84,8 @@ partial def main (fallback : Fallback) : M Unit := do
continue
if ( tryEmatch goal) then
continue
if ( tryLookahead goal) then
continue
if ( trySplit goal) then
continue
if ( tryMBTC goal) then

View File

@@ -11,14 +11,14 @@ import Lean.Meta.Tactic.Grind.CasesMatch
namespace Lean.Meta.Grind
inductive CaseSplitStatus where
inductive SplitStatus where
| resolved
| notReady
| ready (numCases : Nat) (isRec := false)
| ready (numCases : Nat) (isRec := false) (tryPostpone := false)
deriving Inhabited, BEq, Repr
/-- Given `c`, the condition of an `if-then-else`, check whether we need to case-split on the `if-then-else` or not -/
private def checkIteCondStatus (c : Expr) : GoalM CaseSplitStatus := do
private def checkIteCondStatus (c : Expr) : GoalM SplitStatus := do
if ( isEqTrue c <||> isEqFalse c) then
return .resolved
else
@@ -28,7 +28,7 @@ private def checkIteCondStatus (c : Expr) : GoalM CaseSplitStatus := do
Given `e` of the form `a b`, check whether we are ready to case-split on `e`.
That is, `e` is `True`, but neither `a` nor `b` is `True`."
-/
private def checkDisjunctStatus (e a b : Expr) : GoalM CaseSplitStatus := do
private def checkDisjunctStatus (e a b : Expr) : GoalM SplitStatus := do
if ( isEqTrue e) then
if ( isEqTrue a <||> isEqTrue b) then
return .resolved
@@ -43,7 +43,7 @@ private def checkDisjunctStatus (e a b : Expr) : GoalM CaseSplitStatus := do
Given `e` of the form `a ∧ b`, check whether we are ready to case-split on `e`.
That is, `e` is `False`, but neither `a` nor `b` is `False`.
-/
private def checkConjunctStatus (e a b : Expr) : GoalM CaseSplitStatus := do
private def checkConjunctStatus (e a b : Expr) : GoalM SplitStatus := do
if ( isEqTrue e) then
return .resolved
else if ( isEqFalse e) then
@@ -60,7 +60,7 @@ There are two cases:
1- `e` is `True`, but neither both `a` and `b` are `True`, nor both `a` and `b` are `False`.
2- `e` is `False`, but neither `a` is `True` and `b` is `False`, nor `a` is `False` and `b` is `True`.
-/
private def checkIffStatus (e a b : Expr) : GoalM CaseSplitStatus := do
private def checkIffStatus (e a b : Expr) : GoalM SplitStatus := do
if ( isEqTrue e) then
if ( (isEqTrue a <&&> isEqTrue b) <||> (isEqFalse a <&&> isEqFalse b)) then
return .resolved
@@ -76,13 +76,14 @@ private def checkIffStatus (e a b : Expr) : GoalM CaseSplitStatus := do
/-- Returns `true` is `c` is congruent to a case-split that was already performed. -/
private def isCongrToPrevSplit (c : Expr) : GoalM Bool := do
unless c.isApp do return false
( get).split.resolved.foldM (init := false) fun flag { expr := c' } => do
if flag then
return true
else
return isCongruent ( get).enodes c c'
return c'.isApp && isCongruent ( get).enodes c c'
private def checkForallStatus (e : Expr) : GoalM CaseSplitStatus := do
private def checkForallStatus (e : Expr) : GoalM SplitStatus := do
if ( isEqTrue e) then
let .forallE _ p q _ := e | return .resolved
if ( isEqTrue p <||> isEqFalse p) then
@@ -96,7 +97,7 @@ private def checkForallStatus (e : Expr) : GoalM CaseSplitStatus := do
else
return .notReady
private def checkCaseSplitStatus (e : Expr) : GoalM CaseSplitStatus := do
private def checkDefaultSplitStatus (e : Expr) : GoalM SplitStatus := do
match_expr e with
| Or a b => checkDisjunctStatus e a b
| And a b => checkConjunctStatus e a b
@@ -132,9 +133,37 @@ private def checkCaseSplitStatus (e : Expr) : GoalM CaseSplitStatus := do
return .ready info.ctors.length info.isRec
return .notReady
def checkSplitInfoArgStatus (a b : Expr) (eq : Expr) : GoalM SplitStatus := do
if ( isEqTrue eq <||> isEqFalse eq) then return .resolved
let is := ( get).split.argPosMap[(a, b)]? |>.getD []
let mut j := a.getAppNumArgs
let mut it_a := a
let mut it_b := b
repeat
unless it_a.isApp && it_b.isApp do return .ready 2
j := j - 1
if j is then
let arg_a := it_a.appArg!
let arg_b := it_b.appArg!
unless ( isEqv arg_a arg_b) do
trace_goal[grind.split] "may be irrelevant\na: {a}\nb: {b}\neq: {eq}\narg_a: {arg_a}\narg_b: {arg_b}, gen: {← getGeneration eq}"
/-
We tried to return `.notReady` because we would not be able to derive a congruence, but
`grind_ite.lean` breaks when this heuristic is used. TODO: understand better why.
-/
return .ready 2 (tryPostpone := true)
it_a := it_a.appFn!
it_b := it_b.appFn!
return .ready 2
def checkSplitStatus (s : SplitInfo) : GoalM SplitStatus := do
match s with
| .default e => checkDefaultSplitStatus e
| .arg a b _ eq => checkSplitInfoArgStatus a b eq
private inductive SplitCandidate where
| none
| some (c : Expr) (numCases : Nat) (isRec : Bool)
| some (c : SplitInfo) (numCases : Nat) (isRec : Bool) (tryPostpone : Bool)
/-- Returns the next case-split to be performed. It uses a very simple heuristic. -/
private def selectNextSplit? : GoalM SplitCandidate := do
@@ -142,11 +171,11 @@ private def selectNextSplit? : GoalM SplitCandidate := do
if ( checkMaxCaseSplit) then return .none
go ( get).split.candidates .none []
where
go (cs : List Expr) (c? : SplitCandidate) (cs' : List Expr) : GoalM SplitCandidate := do
go (cs : List SplitInfo) (c? : SplitCandidate) (cs' : List SplitInfo) : GoalM SplitCandidate := do
match cs with
| [] =>
modify fun s => { s with split.candidates := cs'.reverse }
if let .some _ numCases isRec := c? then
if let .some _ numCases isRec _ := c? then
let numSplits := ( get).split.num
-- We only increase the number of splits if there is more than one case or it is recursive.
let numSplits := if numCases > 1 || isRec then numSplits + 1 else numSplits
@@ -155,22 +184,28 @@ where
modify fun s => { s with split.num := numSplits, ematch.num := 0 }
return c?
| c::cs =>
trace_goal[grind.debug.split] "checking: {c}"
match ( checkCaseSplitStatus c) with
trace_goal[grind.debug.split] "checking: {c.getExpr}"
match ( checkSplitStatus c) with
| .notReady => go cs c? (c::cs')
| .resolved => go cs c? cs'
| .ready numCases isRec =>
match c? with
| .none => go cs (.some c numCases isRec) cs'
| .some c' numCases' _ =>
| .ready numCases isRec tryPostpone =>
if ( cheapCasesOnly) && numCases > 1 then
go cs c? (c::cs')
else match c? with
| .none => go cs (.some c numCases isRec tryPostpone) cs'
| .some c' numCases' _ tryPostpone' =>
let isBetter : GoalM Bool := do
if numCases == 1 && !isRec && numCases' > 1 then
if tryPostpone' && !tryPostpone then
return true
if ( getGeneration c) < ( getGeneration c') then
else if tryPostpone && !tryPostpone' then
return false
else if numCases == 1 && !isRec && numCases' > 1 then
return true
if ( getGeneration c.getExpr) < ( getGeneration c'.getExpr) then
return true
return numCases < numCases'
if ( isBetter) then
go cs (.some c numCases isRec) (c'::cs')
go cs (.some c numCases isRec tryPostpone) (c'::cs')
else
go cs c? (c::cs')
@@ -192,6 +227,7 @@ private def mkCasesMajor (c : Expr) : GoalM Expr := do
else
-- model-based theory combination split
return mkGrindEM c
| Not e => return mkGrindEM e
| _ =>
if let .forallE _ p _ _ := c then
return mkGrindEM p
@@ -212,8 +248,9 @@ and returns a new list of goals if successful.
-/
def splitNext : GrindTactic := fun goal => do
let (goals?, _) GoalM.run goal do
let .some c numCases isRec selectNextSplit?
let .some c numCases isRec _ selectNextSplit?
| return none
let c := c.getExpr
let gen getGeneration c
let genNew := if numCases > 1 || isRec then gen+1 else gen
markCaseSplitAsResolved c

View File

@@ -17,6 +17,7 @@ import Lean.Meta.Tactic.Util
import Lean.Meta.Tactic.Ext
import Lean.Meta.Tactic.Grind.ENodeKey
import Lean.Meta.Tactic.Grind.Attr
import Lean.Meta.Tactic.Grind.ExtAttr
import Lean.Meta.Tactic.Grind.Cases
import Lean.Meta.Tactic.Grind.Arith.Types
import Lean.Meta.Tactic.Grind.EMatchTheorem
@@ -58,6 +59,15 @@ structure Context where
simprocs : Array Simp.Simprocs
mainDeclName : Name
config : Grind.Config
/--
If `cheapCases` is `true`, `grind` only applies `cases` to types that contain
at most one minor premise.
Recall that `grind` applies `cases` when introducing types tagged with `[grind cases eager]`,
and at `Split.lean`
Remark: We add this option to implement the `lookahead` feature, we don't want to create several subgoals
when performing lookahead.
-/
cheapCases : Bool := false
/-- Key for the congruence theorem cache. -/
structure CongrTheoremCacheKey where
@@ -164,6 +174,9 @@ def getNatZeroExpr : GrindM Expr := do
def getMainDeclName : GrindM Name :=
return ( readThe Context).mainDeclName
def cheapCasesOnly : GrindM Bool :=
return ( readThe Context).cheapCases
def saveEMatchTheorem (thm : EMatchTheorem) : GrindM Unit := do
if ( getConfig).trace then
modify fun s => { s with trace.thms := s.trace.thms.insert { origin := thm.origin, kind := thm.kind } }
@@ -472,22 +485,72 @@ structure EMatch.State where
matchEqNames : PHashSet Name := {}
deriving Inhabited
/-- Case-split information. -/
inductive SplitInfo where
| /--
Term `e` may be an inductive predicate, `match`-expression, `if`-expression, implication, etc.
-/
default (e : Expr)
| /--
Given applications `a` and `b`, case-split on whether the corresponding
`i`-th arguments are equal or not. The split is only performed if all other
arguments are already known to be equal or are also tagged as split candidates.
-/
arg (a b : Expr) (i : Nat) (eq : Expr)
deriving BEq, Hashable, Inhabited
def SplitInfo.getExpr : SplitInfo Expr
| .default (.forallE _ d _ _) => d
| .default e => e
| .arg _ _ _ eq => eq
def SplitInfo.lt : SplitInfo SplitInfo Bool
| .default e₁, .default e₂ => e₁.lt e₂
| .arg _ _ _ e₁, .arg _ _ _ e₂ => e₁.lt e₂
| .default _, .arg .. => true
| .arg .., .default _ => false
/-- Argument `arg : type` of an application `app` in `SplitInfo`. -/
structure SplitArg where
arg : Expr
type : Expr
app : Expr
/-- Case splitting related fields for the `grind` goal. -/
structure Split.State where
/-- Inductive datatypes marked for case-splitting -/
casesTypes : CasesTypes := {}
/-- Case-split candidates. -/
candidates : List Expr := []
/-- Number of splits performed to get to this goal. -/
num : Nat := 0
num : Nat := 0
/-- Inductive datatypes marked for case-splitting -/
casesTypes : CasesTypes := {}
/-- Case-split candidates. -/
candidates : List SplitInfo := []
/-- Case-splits that have been inserted at `candidates` at some point. -/
added : Std.HashSet SplitInfo := {}
/-- Case-splits that have already been performed, or that do not have to be performed anymore. -/
resolved : PHashSet ENodeKey := {}
resolved : PHashSet ENodeKey := {}
/--
Sequence of cases steps that generated this goal. We only use this information for diagnostics.
Remark: `casesTrace.length ≥ numSplits` because we don't increase the counter for `cases`
applications that generated only 1 subgoal.
-/
trace : List CaseTrace := []
trace : List CaseTrace := []
/-- Lookahead "case-splits". -/
lookaheads : List SplitInfo := []
/--
A mapping `(a, b) ↦ is` s.t. for each `SplitInfo.arg a b i eq`
in `candidates` or `lookaheads` we have `i ∈ is`.
We use this information to decide whether the split/lookahead is "ready"
to be tried or not.
-/
argPosMap : Std.HashMap (Expr × Expr) (List Nat) := {}
/--
Mapping from pairs `(f, i)` to a list of arguments.
Each argument occurs as the `i`-th of an `f`-application.
We use this information to add splits/lookaheads for
triggering extensionality theorems and model-based theory combination.
See `addSplitCandidatesForExt`.
-/
argsAt : PHashMap (Expr × Nat) (List SplitArg) := {}
deriving Inhabited
/-- Clean name generator. -/
@@ -510,7 +573,7 @@ structure Goal where
-/
appMap : PHashMap HeadIndex (List Expr) := {}
/-- Equations and propositions to be processed. -/
newFacts : Array NewFact := #[]
newFacts : Array NewFact := #[]
/-- `inconsistent := true` if `ENode`s for `True` and `False` are in the same equivalence class. -/
inconsistent : Bool := false
/-- Next unique index for creating ENodes -/
@@ -518,17 +581,17 @@ structure Goal where
/-- new facts to be preprocessed and then asserted. -/
newRawFacts : Std.Queue NewRawFact :=
/-- Asserted facts -/
facts : PArray Expr := {}
facts : PArray Expr := {}
/-- Cached extensionality theorems for types. -/
extThms : PHashMap ENodeKey (Array Ext.ExtTheorem) := {}
extThms : PHashMap ENodeKey (Array Ext.ExtTheorem) := {}
/-- State of the E-matching module. -/
ematch : EMatch.State
ematch : EMatch.State
/-- State of the case-splitting module. -/
split : Split.State := {}
split : Split.State := {}
/-- State of arithmetic procedures. -/
arith : Arith.State := {}
arith : Arith.State := {}
/-- State of the clean name generator. -/
clean : Clean.State := {}
clean : Clean.State := {}
deriving Inhabited
def Goal.admit (goal : Goal) : MetaM Unit :=
@@ -1153,6 +1216,14 @@ partial def Goal.getEqcs (goal : Goal) : List (List Expr) := Id.run do
def getEqcs : GoalM (List (List Expr)) :=
return ( get).getEqcs
/--
Returns `true` if `s` has been already added to the case-split list at one point.
Remark: this function returns `true` even if the split has already been resolved
and is not in the list anymore.
-/
def isKnownCaseSplit (s : SplitInfo) : GoalM Bool :=
return ( get).split.added.contains s
/-- Returns `true` if `e` is a case-split that does not need to be performed anymore. -/
def isResolvedCaseSplit (e : Expr) : GoalM Bool :=
return ( get).split.resolved.contains { expr := e }
@@ -1167,16 +1238,38 @@ def markCaseSplitAsResolved (e : Expr) : GoalM Unit := do
trace_goal[grind.split.resolved] "{e}"
modify fun s => { s with split.resolved := s.split.resolved.insert { expr := e } }
private def updateSplitArgPosMap (sinfo : SplitInfo) : GoalM Unit := do
let .arg a b i _ := sinfo | return ()
let key := (a, b)
let is := ( get).split.argPosMap[key]? |>.getD []
modify fun s => { s with
split.argPosMap := s.split.argPosMap.insert key (i :: is)
}
/-- Inserts `e` into the list of case-split candidates if it was not inserted before. -/
def addSplitCandidate (sinfo : SplitInfo) : GoalM Unit := do
unless ( isKnownCaseSplit sinfo) do
trace_goal[grind.split.candidate] "{sinfo.getExpr}"
modify fun s => { s with
split.added := s.split.added.insert sinfo
split.candidates := sinfo :: s.split.candidates
}
updateSplitArgPosMap sinfo
/--
Returns extensionality theorems for the given type if available.
If `Config.ext` is `false`, the result is `#[]`.
-/
def getExtTheorems (type : Expr) : GoalM (Array Ext.ExtTheorem) := do
unless ( getConfig).ext do return #[]
unless ( getConfig).ext || ( getConfig).extAll do return #[]
if let some thms := ( get).extThms.find? { expr := type } then
return thms
else
let thms Ext.getExtTheorems type
let thms if ( getConfig).extAll then
pure thms
else
thms.filterM fun thm => isExtTheorem thm.declName
modify fun s => { s with extThms := s.extThms.insert { expr := type } thms }
return thms
@@ -1188,4 +1281,22 @@ def synthesizeInstanceAndAssign (x type : Expr) : MetaM Bool := do
let .some val trySynthInstance type | return false
isDefEq x val
/-- Add a new lookahead candidate. -/
def addLookaheadCandidate (sinfo : SplitInfo) : GoalM Unit := do
trace_goal[grind.lookahead.add] "{sinfo.getExpr}"
modify fun s => { s with split.lookaheads := sinfo :: s.split.lookaheads }
updateSplitArgPosMap sinfo
/--
Helper function for executing `x` with a fresh `newFacts` and without modifying
the goal state.
-/
def withoutModifyingState (x : GoalM α) : GoalM α := do
let saved get
modify fun goal => { goal with newFacts := {} }
try
x
finally
set saved
end Lean.Meta.Grind

View File

@@ -97,7 +97,8 @@ def _root_.Lean.MVarId.clearAuxDecls (mvarId : MVarId) : MetaM MVarId := mvarId.
try
mvarId mvarId.clear fvarId
catch _ =>
throwTacticEx `grind.clear_aux_decls mvarId "failed to clear local auxiliary declaration"
let userName := ( fvarId.getDecl).userName
throwTacticEx `grind mvarId m!"the goal mentions the declaration `{userName}`, which is being defined. To avoid circular reasoning, try rewriting the goal to eliminate `{userName}` before using `grind`."
return mvarId
/--

View File

@@ -29,16 +29,6 @@ def withTypeAscription (d : Delab) (cond : Bool := true) : Delab := do
else
return stx
/--
If `pp.tagAppFns` is set, then `d` is evaluated with the delaborated head constant as the ref.
-/
def withFnRefWhenTagAppFns (d : Delab) : Delab := do
if ( getExpr).getAppFn.isConst && ( getPPOption getPPTagAppFns) then
let head withNaryFn delab
withRef head <| d
else
d
/--
Wraps the identifier (or identifier with explicit universe levels) with `@` if `pp.analysis.blockImplicit` is set to true.
-/
@@ -134,6 +124,18 @@ def delabConst : Delab := do
else
return stx
/--
If `pp.tagAppFns` is set, and if the current expression is a constant application,
then `d` is evaluated with the head constant delaborated with `delabConst` as the ref.
-/
def withFnRefWhenTagAppFns (d : Delab) : Delab := do
if ( getExpr).getAppFn.isConst && ( getPPOption getPPTagAppFns) then
-- delabConst in `pp.tagAppFns` mode annotates the term.
let head withNaryFn delabConst
withRef head <| d
else
d
def withMDataOptions [Inhabited α] (x : DelabM α) : DelabM α := do
match getExpr with
| Expr.mdata m .. =>

View File

@@ -207,7 +207,7 @@ This option can only be set on the command line, not in the lakefile or via `set
stickyInteractiveDiagnostics ++ docInteractiveDiagnostics
|>.map (·.toDiagnostic)
let notification := mkPublishDiagnosticsNotification doc.meta diagnostics
ctx.chanOut.send notification
ctx.chanOut.sync.send notification
open Language in
/--
@@ -239,7 +239,7 @@ This option can only be set on the command line, not in the lakefile or via `set
publishDiagnostics ctx doc
-- This will overwrite existing ilean info for the file, in case something
-- went wrong during the incremental updates.
ctx.chanOut.send ( mkIleanInfoFinalNotification doc.meta st.allInfoTrees)
ctx.chanOut.sync.send ( mkIleanInfoFinalNotification doc.meta st.allInfoTrees)
return ()
where
/--
@@ -312,7 +312,7 @@ This option can only be set on the command line, not in the lakefile or via `set
if let some itree := node.element.infoTree? then
let mut newInfoTrees := ( get).newInfoTrees.push itree
if ( get).hasBlocked then
ctx.chanOut.send ( mkIleanInfoUpdateNotification doc.meta newInfoTrees)
ctx.chanOut.sync.send ( mkIleanInfoUpdateNotification doc.meta newInfoTrees)
newInfoTrees := #[]
modify fun st => { st with newInfoTrees, allInfoTrees := st.allInfoTrees.push itree }
@@ -329,7 +329,7 @@ This option can only be set on the command line, not in the lakefile or via `set
| none => rs.push r
let ranges := ranges.map (·.toLspRange doc.meta.text)
let notifs := ranges.map ({ range := ·, kind := .processing })
ctx.chanOut.send <| mkFileProgressNotification doc.meta notifs
ctx.chanOut.sync.send <| mkFileProgressNotification doc.meta notifs
end Elab
@@ -389,9 +389,9 @@ def setupImports
severity? := DiagnosticSeverity.information
message := stderrLine
}
chanOut.send <| mkPublishDiagnosticsNotification meta #[progressDiagnostic]
chanOut.sync.send <| mkPublishDiagnosticsNotification meta #[progressDiagnostic]
-- clear progress notifications in the end
chanOut.send <| mkPublishDiagnosticsNotification meta #[]
chanOut.sync.send <| mkPublishDiagnosticsNotification meta #[]
match fileSetupResult.kind with
| .importsOutOfDate =>
return .error {
@@ -413,6 +413,8 @@ def setupImports
-- default to async elaboration; see also `Elab.async` docs
let opts := Elab.async.setIfNotSet opts true
let opts := Elab.inServer.set opts true
return .ok {
mainModuleName := meta.mod
opts
@@ -523,7 +525,7 @@ section ServerRequests
(freshRequestId, freshRequestId + 1)
let responseTask ctx.initPendingServerRequest responseType freshRequestId
let r : JsonRpc.Request paramType := freshRequestId, method, param
ctx.chanOut.send r
ctx.chanOut.sync.send r
return responseTask
def sendUntypedServerRequest
@@ -677,7 +679,7 @@ section MessageHandling
let availableImports ImportCompletion.collectAvailableImports
let lastRequestTimestampMs IO.monoMsNow
let completions := ImportCompletion.find text st.doc.initSnap.stx params availableImports
ctx.chanOut.send <| .response id (toJson completions)
ctx.chanOut.sync.send <| .response id (toJson completions)
pure { availableImports, lastRequestTimestampMs : AvailableImportsCache }
| some task => ServerTask.IO.mapTaskCostly (t := task) fun (result : Except Error AvailableImportsCache) => do
@@ -687,7 +689,7 @@ section MessageHandling
availableImports ImportCompletion.collectAvailableImports
lastRequestTimestampMs := timestampNowMs
let completions := ImportCompletion.find text st.doc.initSnap.stx params availableImports
ctx.chanOut.send <| .response id (toJson completions)
ctx.chanOut.sync.send <| .response id (toJson completions)
pure { availableImports, lastRequestTimestampMs : AvailableImportsCache }
def handleStatefulPreRequestSpecialCases (id : RequestID) (method : String) (params : Json) : WorkerM Bool := do
@@ -699,7 +701,7 @@ section MessageHandling
| "$/lean/rpc/connect" =>
let ps parseParams RpcConnectParams params
let resp handleRpcConnect ps
ctx.chanOut.send <| .response id (toJson resp)
ctx.chanOut.sync.send <| .response id (toJson resp)
return true
| "textDocument/completion" =>
let params parseParams CompletionParams params
@@ -712,7 +714,7 @@ section MessageHandling
| _ =>
return false
catch e =>
ctx.chanOut.send <| .responseError id .internalError (toString e) none
ctx.chanOut.sync.send <| .responseError id .internalError (toString e) none
return true
open Widget RequestM Language in
@@ -834,7 +836,7 @@ section MessageHandling
emitResponse ctx (isComplete := false) <| e.toLspResponseError id
where
emitResponse (ctx : WorkerContext) (m : JsonRpc.Message) (isComplete : Bool) : IO Unit := do
ctx.chanOut.send m
ctx.chanOut.sync.send m
let timestamp IO.monoMsNow
ctx.modifyPartialHandler method fun h => { h with
requestsInFlight := h.requestsInFlight - 1

View File

@@ -41,6 +41,7 @@ structure InfoPopup where
doc : Option String
deriving Inhabited, RpcEncodable
open PrettyPrinter.Delaborator in
/-- Given elaborator info for a particular subexpression. Produce the `InfoPopup`.
The intended usage of this is for the infoview to pass the `InfoWithCtx` which
@@ -53,12 +54,10 @@ def makePopup : WithRpcRef InfoWithCtx → RequestM (RequestTask InfoPopup)
| some type => some <$> ppExprTagged type
| none => pure none
let exprExplicit? match i.info with
| Elab.Info.ofTermInfo ti
| Elab.Info.ofDelabTermInfo { toTermInfo := ti, explicit := true, ..} =>
some <$> ppExprTaggedWithoutTopLevelHighlight ti.expr (explicit := true)
| Elab.Info.ofDelabTermInfo { toTermInfo := ti, explicit := false, ..} =>
-- Keep the top-level tag so that users can also see the explicit version of the term on an additional hover.
some <$> ppExprTagged ti.expr (explicit := false)
| Elab.Info.ofTermInfo ti =>
some <$> ppExprForPopup ti.expr (explicit := true)
| Elab.Info.ofDelabTermInfo { toTermInfo := ti, explicit, ..} =>
some <$> ppExprForPopup ti.expr (explicit := explicit)
| Elab.Info.ofFieldInfo fi => pure <| some <| TaggedText.text fi.fieldName.toString
| _ => pure none
return {
@@ -67,11 +66,26 @@ def makePopup : WithRpcRef InfoWithCtx → RequestM (RequestTask InfoPopup)
doc := i.info.docString? : InfoPopup
}
where
ppExprTaggedWithoutTopLevelHighlight (e : Expr) (explicit : Bool) : MetaM CodeWithInfos := do
let pp ppExprTagged e (explicit := explicit)
return match pp with
| .tag _ tt => tt
| tt => tt
maybeWithoutTopLevelHighlight : Bool CodeWithInfos CodeWithInfos
| true, .tag _ tt => tt
| _, tt => tt
ppExprForPopup (e : Expr) (explicit : Bool := false) : MetaM CodeWithInfos := do
let mut e := e
-- When hovering over a metavariable, we want to see its value, even if `pp.instantiateMVars` is false.
if explicit && e.isMVar then
if let some e' getExprMVarAssignment? e.mvarId! then
e := e'
-- When `explicit` is false, keep the top-level tag so that users can also see the explicit version of the term on an additional hover.
maybeWithoutTopLevelHighlight explicit <$> ppExprTagged e do
if explicit then
withOptionAtCurrPos pp.tagAppFns.name true do
withOptionAtCurrPos pp.explicit.name true do
withOptionAtCurrPos pp.mvars.anonymous.name true do
delabApp
else
withOptionAtCurrPos pp.proofs.name true do
withOptionAtCurrPos pp.sorrySource.name true do
delab
builtin_initialize
registerBuiltinRpcProcedure

View File

@@ -56,6 +56,36 @@ elab_rules : tactic
-- can't use a naked promise in `initialize` as marking it persistent would block
initialize unblockedCancelTk : IO.CancelToken IO.CancelToken.new
/--
Waits for `unblock` to be called, which is expected to happen in a subsequent document version that
does not invalidate this tactic. Complains if cancellation token was set before unblocking, i.e. if
the tactic was invalidated after all.
-/
scoped syntax "wait_for_unblock" : tactic
@[incremental]
elab_rules : tactic
| `(tactic| wait_for_unblock) => do
let ctx readThe Core.Context
let some cancelTk := ctx.cancelTk? | unreachable!
dbg_trace "blocked!"
log "blocked"
let ctx readThe Elab.Term.Context
let some tacSnap := ctx.tacSnap? | unreachable!
tacSnap.new.resolve {
diagnostics := ( Language.Snapshot.Diagnostics.ofMessageLog ( Core.getMessageLog))
stx := default
finished := default
}
while true do
if ( unblockedCancelTk.isSet) then
break
IO.sleep 30
if ( cancelTk.isSet) then
IO.eprintln "cancelled!"
log "cancelled (should never be visible)"
/--
Spawns a `logSnapshotTask` that waits for `unblock` to be called, which is expected to happen in a
subsequent document version that does not invalidate this tactic. Complains if cancellation token
@@ -83,6 +113,10 @@ scoped elab "unblock" : tactic => do
dbg_trace "unblocking!"
unblockedCancelTk.set
/--
Like `wait_for_cancel_once` but does the waiting in a separate task and waits for its
cancellation.
-/
scoped syntax "wait_for_cancel_once_async" : tactic
@[incremental]
elab_rules : tactic
@@ -110,3 +144,35 @@ elab_rules : tactic
dbg_trace "blocked!"
log "blocked"
/--
Like `wait_for_cancel_once_async` but waits for the main thread's cancellation token. This is useful
to test main thread cancellation in non-incremental contexts because we otherwise wouldn't be able
to send out the "blocked" message from there.
-/
scoped syntax "wait_for_main_cancel_once_async" : tactic
@[incremental]
elab_rules : tactic
| `(tactic| wait_for_main_cancel_once_async) => do
let prom IO.Promise.new
if let some t := ( onceRef.modifyGet (fun old => (old, old.getD prom.result!))) then
IO.wait t
return
let some cancelTk := ( readThe Core.Context).cancelTk? | unreachable!
let act Elab.Term.wrapAsyncAsSnapshot (cancelTk? := none) fun _ => do
let ctx readThe Core.Context
-- TODO: `CancelToken` should probably use `Promise`
while true do
if ( cancelTk.isSet) then
break
IO.sleep 30
IO.eprintln "cancelled!"
log "cancelled (should never be visible)"
prom.resolve ()
Core.checkInterrupted
let t BaseIO.asTask (act ())
Core.logSnapshotTask { stx? := none, task := t, cancelTk? := cancelTk }
dbg_trace "blocked!"
log "blocked"

View File

@@ -73,24 +73,15 @@ where
}
TaggedText.tag t (go subTt)
def ppExprTagged (e : Expr) (explicit : Bool := false) : MetaM CodeWithInfos := do
open PrettyPrinter Delaborator in
/--
Pretty prints the expression `e` using delaborator `delab`, returning an object that represents
the pretty printed syntax paired with information needed to support hovers.
-/
def ppExprTagged (e : Expr) (delab : Delab := Delaborator.delab) : MetaM CodeWithInfos := do
if pp.raw.get ( getOptions) then
return .text (toString ( instantiateMVars e))
let delab := open PrettyPrinter.Delaborator in
if explicit then
withOptionAtCurrPos pp.tagAppFns.name true do
withOptionAtCurrPos pp.explicit.name true do
withOptionAtCurrPos pp.mvars.anonymous.name true do
delabApp
else
withOptionAtCurrPos pp.proofs.name true do
withOptionAtCurrPos pp.sorrySource.name true do
delab
let mut e := e
-- When hovering over a metavariable, we want to see its value, even if `pp.instantiateMVars` is false.
if explicit && e.isMVar then
if let some e' getExprMVarAssignment? e.mvarId! then
e := e'
let e if getPPInstantiateMVars ( getOptions) then instantiateMVars e else pure e
return .text (toString e)
let fmt, infos PrettyPrinter.ppExprWithInfos e (delab := delab)
let tt := TaggedText.prettyTagged fmt
let ctx := {

View File

@@ -215,7 +215,7 @@ theorem modify_eq_alter [BEq α] [LawfulBEq α] {a : α} {f : β a → β a} {l
modify a f l = alter a (·.map f) l := by
induction l
· rfl
· next ih => simp only [modify, beq_iff_eq, alter, Option.map_some', ih]
· next ih => simp only [modify, beq_iff_eq, alter, Option.map_some, ih]
namespace Const
@@ -235,7 +235,7 @@ theorem modify_eq_alter [BEq α] [EquivBEq α] {a : α} {f : β → β} {l : Ass
modify a f l = alter a (·.map f) l := by
induction l
· rfl
· next ih => simp only [modify, beq_iff_eq, alter, Option.map_some', ih]
· next ih => simp only [modify, beq_iff_eq, alter, Option.map_some, ih]
end Const

View File

@@ -630,7 +630,7 @@ theorem Const.getThenInsertIfNew?_eq_insertIfNewₘ [BEq α] [Hashable α] (m :
(a : α) (b : β) : (Const.getThenInsertIfNew? m a b).2 = m.insertIfNewₘ a b := by
rw [getThenInsertIfNew?, insertIfNewₘ, containsₘ, bucket]
dsimp only [Array.ugetElem_eq_getElem, Array.uset]
split <;> simp_all [consₘ, updateBucket, List.containsKey_eq_isSome_getValue?, -Option.not_isSome]
split <;> simp_all [consₘ, updateBucket, List.containsKey_eq_isSome_getValue?, -Option.isSome_eq_false_iff]
theorem Const.getThenInsertIfNew?_eq_get?ₘ [BEq α] [Hashable α] (m : Raw₀ α (fun _ => β)) (a : α)
(b : β) : (Const.getThenInsertIfNew? m a b).1 = Const.get?ₘ m a := by

View File

@@ -964,7 +964,7 @@ theorem isHashSelf_filterMapₘ [BEq α] [Hashable α] [ReflBEq α] [LawfulHasha
IsHashSelf (m.filterMapₘ f).1.buckets := by
refine h.buckets_hash_self.updateAllBuckets (fun l p hp => ?_)
have hp := AssocList.toList_filterMap.mem_iff.1 hp
simp only [mem_filterMap, Option.map_eq_some'] at hp
simp only [mem_filterMap, Option.map_eq_some_iff] at hp
obtain p, hkv, d, -, rfl := hp
exact containsKey_of_mem hkv

View File

@@ -1031,7 +1031,7 @@ theorem ordered_filterMap [Ord α] {t : Impl α β} {h} {f : (a : α) → β a
simp only [Ordered, toListModel_filterMap]
apply ho.filterMap
intro e f hef e' he' f' hf'
simp only [Option.map_eq_some'] at he' hf'
simp only [Option.map_eq_some_iff] at he' hf'
obtain _, _, rfl := he'
obtain _, _, rfl := hf'
exact hef

View File

@@ -169,7 +169,7 @@ theorem getValue?_eq_getEntry? [BEq α] {l : List ((_ : α) × β)} {a : α} :
· next k v l ih =>
cases h : k == a
· rw [getEntry?_cons_of_false h, getValue?_cons_of_false h, ih]
· rw [getEntry?_cons_of_true h, getValue?_cons_of_true h, Option.map_some']
· rw [getEntry?_cons_of_true h, getValue?_cons_of_true h, Option.map_some]
theorem getValue?_congr [BEq α] [PartialEquivBEq α] {l : List ((_ : α) × β)} {a b : α}
(h : a == b) : getValue? a l = getValue? b l := by
@@ -338,7 +338,7 @@ theorem getEntry?_eq_none [BEq α] {l : List ((a : α) × β a)} {a : α} :
@[simp]
theorem getValue?_eq_none {β : Type v} [BEq α] {l : List ((_ : α) × β)} {a : α} :
getValue? a l = none containsKey a l = false := by
rw [getValue?_eq_getEntry?, Option.map_eq_none', getEntry?_eq_none]
rw [getValue?_eq_getEntry?, Option.map_eq_none_iff, getEntry?_eq_none]
theorem containsKey_eq_isSome_getValue? {β : Type v} [BEq α] {l : List ((_ : α) × β)} {a : α} :
containsKey a l = (getValue? a l).isSome := by
@@ -613,7 +613,7 @@ theorem getKey?_eq_getEntry? [BEq α] {l : List ((a : α) × β a)} {a : α} :
· next k v l ih =>
cases h : k == a
· rw [getEntry?_cons_of_false h, getKey?_cons_of_false h, ih]
· rw [getEntry?_cons_of_true h, getKey?_cons_of_true h, Option.map_some']
· rw [getEntry?_cons_of_true h, getKey?_cons_of_true h, Option.map_some]
theorem fst_mem_keys_of_mem [BEq α] [EquivBEq α] {a : (a : α) × β a} {l : List ((a : α) × β a)}
(hm : a l) : a.1 keys l :=
@@ -1515,7 +1515,7 @@ theorem containsKey_insertEntryIfNew [BEq α] [PartialEquivBEq α] {l : List ((a
cases h : k == a
· simp
· rw [containsKey_eq_isSome_getEntry?, getEntry?_congr h]
simp
simp [-Option.not_isSome]
theorem containsKey_insertEntryIfNew_self [BEq α] [EquivBEq α] {l : List ((a : α) × β a)} {k : α}
{v : β k} : containsKey k (insertEntryIfNew k v l) := by
@@ -2098,7 +2098,7 @@ theorem containsKey_append_of_not_contains_right [BEq α] {l l' : List ((a : α)
@[simp]
theorem getValue?_append {β : Type v} [BEq α] {l l' : List ((_ : α) × β)} {a : α} :
getValue? a (l ++ l') = (getValue? a l).or (getValue? a l') := by
simp [getValue?_eq_getEntry?, Option.map_or']
simp [getValue?_eq_getEntry?, Option.map_or]
theorem getValue?_append_of_containsKey_eq_false {β : Type v} [BEq α] {l l' : List ((_ : α) × β)}
{a : α} (h : containsKey a l' = false) : getValue? a (l ++ l') = getValue? a l := by
@@ -2233,7 +2233,7 @@ theorem mem_map_toProd_iff_mem {β : Type v} {k : α} {v : β} {l : List ((_ :
theorem mem_iff_getValue?_eq_some [BEq α] [LawfulBEq α] {β : Type v} {k : α} {v : β}
{l : List ((_ : α) × β)} (h : DistinctKeys l) :
k, v l getValue? k l = some v := by
simp only [mem_iff_getEntry?_eq_some h, getValue?_eq_getEntry?, Option.map_eq_some']
simp only [mem_iff_getEntry?_eq_some h, getValue?_eq_getEntry?, Option.map_eq_some_iff]
constructor
· intro h
exists k, v
@@ -2256,7 +2256,7 @@ theorem find?_map_toProd_eq_some_iff_getKey?_eq_some_and_getValue?_eq_some [BEq
| nil => simp
| cons hd tl ih =>
simp only [List.map_cons, List.find?_cons_eq_some, Prod.mk.injEq, Bool.not_eq_eq_eq_not,
Bool.not_true, Option.map_eq_some', getKey?, cond_eq_if, getValue?]
Bool.not_true, Option.map_eq_some_iff, getKey?, cond_eq_if, getValue?]
by_cases hdfst_k: hd.fst == k
· simp only [hdfst_k, true_and, Bool.true_eq_false, false_and, or_false, reduceIte,
Option.some.injEq]
@@ -2271,7 +2271,7 @@ theorem mem_iff_getKey?_eq_some_and_getValue?_eq_some [BEq α] [EquivBEq α]
theorem getValue?_eq_some_iff_exists_beq_and_mem_toList {β : Type v} [BEq α] [EquivBEq α]
{l : List ((_ : α) × β)} {k: α} {v : β} (h : DistinctKeys l) :
getValue? k l = some v k', (k == k') = true (k', v) l.map (fun x => (x.fst, x.snd)) := by
simp only [getValue?_eq_getEntry?, Option.map_eq_some', mem_map_toProd_iff_mem,
simp only [getValue?_eq_getEntry?, Option.map_eq_some_iff, mem_map_toProd_iff_mem,
mem_iff_getEntry?_eq_some h]
constructor
· intro h'
@@ -2635,7 +2635,7 @@ theorem getKey?_insertList_of_mem [BEq α] [EquivBEq α]
rcases List.mem_map.1 mem with k, v, pair_mem, rfl
rw [getKey?_eq_getEntry?, getEntry?_insertList distinct_l distinct_toInsert,
getEntry?_of_mem (DistinctKeys.def.2 distinct_toInsert) k_beq pair_mem, Option.some_or,
Option.map_some']
Option.map_some]
theorem getKey_insertList_of_contains_eq_false [BEq α] [EquivBEq α]
{l toInsert : List ((a : α) × β a)} {k : α}
@@ -3074,7 +3074,7 @@ theorem getKey?_insertListIfNewUnit_of_contains_eq_false_of_contains_eq_false [B
(h': containsKey k l = false) (h : toInsert.contains k = false) :
getKey? k (insertListIfNewUnit l toInsert) = none := by
rw [getKey?_eq_getEntry?,
getEntry?_insertListIfNewUnit_of_contains_eq_false h, Option.map_eq_none', getEntry?_eq_none]
getEntry?_insertListIfNewUnit_of_contains_eq_false h, Option.map_eq_none_iff, getEntry?_eq_none]
exact h'
theorem getKey?_insertListIfNewUnit_of_contains_eq_false_of_mem [BEq α] [EquivBEq α]
@@ -3083,8 +3083,8 @@ theorem getKey?_insertListIfNewUnit_of_contains_eq_false_of_mem [BEq α] [EquivB
(mem' : containsKey k l = false)
(distinct : toInsert.Pairwise (fun a b => (a == b) = false)) (mem : k toInsert) :
getKey? k' (insertListIfNewUnit l toInsert) = some k := by
simp only [getKey?_eq_getEntry?, getEntry?_insertListIfNewUnit, Option.map_eq_some',
Option.or_eq_some, getEntry?_eq_none]
simp only [getKey?_eq_getEntry?, getEntry?_insertListIfNewUnit, Option.map_eq_some_iff,
Option.or_eq_some_iff, getEntry?_eq_none]
exists k, ()
simp only [and_true]
right
@@ -4301,8 +4301,8 @@ theorem minEntry?_eq_some_iff [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd α]
theorem minKey?_eq_some_iff_getKey?_eq_self_and_forall [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd α]
{k} {l : List ((a : α) × β a)} (hd : DistinctKeys l) :
minKey? l = some k getKey? k l = some k k' : α, containsKey k' l (compare k k').isLE := by
simp only [minKey?, Option.map_eq_some', minEntry?_eq_some_iff _ hd]
simp only [getKey?_eq_getEntry?, Option.map_eq_some', getEntry?_eq_some_iff hd]
simp only [minKey?, Option.map_eq_some_iff, minEntry?_eq_some_iff _ hd]
simp only [getKey?_eq_getEntry?, Option.map_eq_some_iff, getEntry?_eq_some_iff hd]
apply Iff.intro
· rintro _, hm, hcmp, rfl
exact _, BEq.refl, hm, rfl, hcmp
@@ -4312,7 +4312,7 @@ theorem minKey?_eq_some_iff_getKey?_eq_self_and_forall [Ord α] [TransOrd α] [B
theorem minKey?_eq_some_iff_mem_and_forall [Ord α] [LawfulEqOrd α] [TransOrd α] [BEq α]
[LawfulBEqOrd α] {k} {l : List ((a : α) × β a)} (hd : DistinctKeys l) :
minKey? l = some k containsKey k l k' : α, containsKey k' l (compare k k').isLE := by
simp only [minKey?, Option.map_eq_some', minEntry?_eq_some_iff _ hd]
simp only [minKey?, Option.map_eq_some_iff, minEntry?_eq_some_iff _ hd]
apply Iff.intro
· rintro _, hm, hcmp, rfl
exact containsKey_of_mem hm, hcmp
@@ -4361,7 +4361,7 @@ theorem isNone_minKey?_eq_isEmpty [Ord α] {l : List ((a : α) × β a)} :
theorem isSome_minEntry?_eq_not_isEmpty [Ord α] {l : List ((a : α) × β a)} :
(minEntry? l).isSome = !l.isEmpty := by
rw [ Bool.not_inj_iff, Bool.not_not, Bool.eq_iff_iff, Bool.not_eq_true', Option.not_isSome,
rw [ Bool.not_inj_iff, Bool.not_not, Bool.eq_iff_iff, Bool.not_eq_true', Option.isSome_eq_false_iff,
Option.isNone_iff_eq_none]
apply minEntry?_eq_none_iff_isEmpty
@@ -4384,7 +4384,7 @@ theorem minEntry?_map [Ord α] (l : List ((a : α) × β a)) (f : (a : α) × β
simp only [minEntry?, List.min?]
cases l <;> try rfl
rename_i e es
simp only [List.map_cons, Option.map_some', Option.some.injEq]
simp only [List.map_cons, Option.map_some, Option.some.injEq]
rw [ List.foldr_reverse, List.foldr_reverse, List.map_reverse]
induction es.reverse with
| nil => rfl
@@ -4442,7 +4442,7 @@ theorem minEntry?_insertEntry [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd α]
cases h : containsKey k l
· simp only [cond_false, minEntry?_cons, Option.some.injEq]
rfl
· rw [cond_true, minEntry?_replaceEntry hl, Option.map_eq_some']
· rw [cond_true, minEntry?_replaceEntry hl, Option.map_eq_some_iff]
have := isSome_minEntry?_of_contains _
simp only [Option.isSome_iff_exists] at this
obtain a, ha := this
@@ -4538,7 +4538,7 @@ theorem minKey?_bind_getKey? [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd α]
theorem containsKey_minKey? [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd α] {l : List ((a : α) × β a)}
(hd : DistinctKeys l) {km} (hkm : minKey? l = some km) :
containsKey km l := by
simp only [minKey?, Option.map_eq_some', minEntry?_eq_some_iff _ hd] at hkm
simp only [minKey?, Option.map_eq_some_iff, minEntry?_eq_some_iff _ hd] at hkm
obtain e, hm, _, rfl := hkm
exact containsKey_of_mem hm
@@ -4710,7 +4710,7 @@ theorem minKey?_alterKey_eq_self [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd
minKey? (alterKey k f l) = some k
(f (getValueCast? k l)).isSome k', containsKey k' l (compare k k').isLE := by
simp only [minKey?_eq_some_iff_getKey?_eq_self_and_forall hd.alterKey, getKey?_alterKey _ hd,
beq_self_eq_true, reduceIte, ite_eq_left_iff, Bool.not_eq_true, Option.not_isSome,
beq_self_eq_true, reduceIte, ite_eq_left_iff, Bool.not_eq_true, Option.isSome_eq_false_iff,
Option.isNone_iff_eq_none, reduceCtorEq, imp_false, Option.isSome_iff_ne_none,
containsKey_alterKey hd, beq_iff_eq, Bool.ite_eq_true_distrib, and_congr_right_iff]
intro hf
@@ -4753,18 +4753,18 @@ theorem minKey?_modifyKey_eq_minKey? [Ord α] [TransOrd α] [BEq α] [LawfulBEqO
simp only [minKey?_modifyKey hd]
cases minKey? l
· rfl
· simp only [beq_iff_eq, Option.map_some', Option.some.injEq, ite_eq_right_iff]
· simp only [beq_iff_eq, Option.map_some, Option.some.injEq, ite_eq_right_iff]
exact Eq.symm
theorem isSome_minKey?_modifyKey [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd α] {k f}
{l : List ((_ : α) × β)} :
(modifyKey k f l |> minKey?).isSome = !l.isEmpty := by
simp [Option.isSome_map', isSome_minKey?_eq_not_isEmpty, isEmpty_modifyKey]
simp [Option.isSome_map, isSome_minKey?_eq_not_isEmpty, isEmpty_modifyKey]
theorem isSome_minKey?_modifyKey_eq_isSome [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd α] {k f}
{l : List ((_ : α) × β)} :
(modifyKey k f l |> minKey?).isSome = (minKey? l).isSome := by
simp [Option.isSome_map', isSome_minKey?_eq_not_isEmpty, isEmpty_modifyKey]
simp [Option.isSome_map, isSome_minKey?_eq_not_isEmpty, isEmpty_modifyKey]
theorem minKey?_modifyKey_beq [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd α] {k f km kmm}
{l : List ((_ : α) × β)} (hd : DistinctKeys l) (hkm : minKey? l = some km)
@@ -4784,7 +4784,7 @@ theorem minKey?_alterKey_eq_self [Ord α] [TransOrd α] [BEq α] [LawfulBEqOrd
(f (getValue? k l)).isSome k', containsKey k' l (compare k k').isLE := by
simp only [minKey?_eq_some_iff_getKey?_eq_self_and_forall hd.constAlterKey, getKey?_alterKey _ hd,
compare_eq_iff_beq, compare_self, reduceIte, ite_eq_left_iff, Bool.not_eq_true,
Option.not_isSome, Option.isNone_iff_eq_none, reduceCtorEq, imp_false,
Option.isSome_eq_false_iff, Option.isNone_iff_eq_none, reduceCtorEq, imp_false,
Option.isSome_iff_ne_none, containsKey_alterKey hd, Bool.ite_eq_true_distrib,
and_congr_right_iff]
intro hf

View File

@@ -1,137 +1,720 @@
/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Gabriel Ebner
Authors: Henrik Böving
-/
prelude
import Init.System.Promise
import Init.Data.Queue
import Std.Sync.Mutex
/-!
This module contains the implementation of `Std.Channel`. `Std.Channel` is a multi-producer
multi-consumer FIFO channel that offers both bounded and unbounded buffering as well as synchronous
and asynchronous APIs.
Additionally `Std.CloseableChannel` is provided in case closing the channel is of interest.
The two are distinct as the non closable `Std.Channel` can never throw errors which makes
for cleaner code.
-/
namespace Std
/--
Internal state of an `Channel`.
We maintain the invariant that at all times either `consumers` or `values` is empty.
-/
structure Channel.State (α : Type) where
values : Std.Queue α :=
consumers : Std.Queue (IO.Promise (Option α)) :=
closed := false
deriving Inhabited
namespace CloseableChannel
/--
FIFO channel with unbounded buffer, where `recv?` returns a `Task`.
A channel can be closed. Once it is closed, all `send`s are ignored, and
`recv?` returns `none` once the queue is empty.
Errors that may be thrown while interacting with the channel API.
-/
def Channel (α : Type) : Type := Mutex (Channel.State α)
inductive Error where
/--
Tried to send to a closed channel.
-/
| closed
/--
Tried to close an already closed channel.
-/
| alreadyClosed
deriving Repr, DecidableEq, Hashable
instance : Nonempty (Channel α) :=
inferInstanceAs (Nonempty (Mutex _))
instance : ToString Error where
toString
| .closed => "trying to send on an already closed channel"
| .alreadyClosed => "trying to close an already closed channel"
/-- Creates a new `Channel`. -/
def Channel.new : BaseIO (Channel α) :=
Mutex.new {}
instance : MonadLift (EIO Error) IO where
monadLift x := EIO.toIO (.userError <| toString ·) x
/--
Sends a message on an `Channel`.
This function does not block.
The central state structure for an unbounded channel, maintains the following invariants:
1. `values = ∅ consumers = ∅`
2. `closed = true → consumers = ∅`
-/
def Channel.send (ch : Channel α) (v : α) : BaseIO Unit :=
ch.atomically do
private structure Unbounded.State (α : Type) where
/--
Values pushed into the channel that are waiting to be consumed.
-/
values : Std.Queue α
/--
Consumers that are blocked on a producer providing them a value. The `IO.Promise` will be
resolved to `none` if the channel closes.
-/
consumers : Std.Queue (IO.Promise (Option α))
/--
Whether the channel is closed already.
-/
closed : Bool
deriving Nonempty
private structure Unbounded (α : Type) where
state : Mutex (Unbounded.State α)
deriving Nonempty
namespace Unbounded
private def new : BaseIO (Unbounded α) := do
return {
state := Mutex.new {
values :=
consumers :=
closed := false
}
}
private def trySend (ch : Unbounded α) (v : α) : BaseIO Bool := do
ch.state.atomically do
let st get
if st.closed then return
if let some (consumer, consumers) := st.consumers.dequeue? then
if st.closed then
return false
else if let some (consumer, consumers) := st.consumers.dequeue? then
consumer.resolve (some v)
set { st with consumers }
return true
else
set { st with values := st.values.enqueue v }
return true
/--
Closes an `Channel`.
-/
def Channel.close (ch : Channel α) : BaseIO Unit :=
ch.atomically do
private def send (ch : Unbounded α) (v : α) : BaseIO (Task (Except Error Unit)) := do
if Unbounded.trySend ch v then
return .pure <| .ok ()
else
return .pure <| .error .closed
private def close (ch : Unbounded α) : EIO Error Unit := do
ch.state.atomically do
let st get
if st.closed then throw .alreadyClosed
for consumer in st.consumers.toArray do consumer.resolve none
set { st with closed := true, consumers := }
set { st with consumers := , closed := true }
return ()
private def isClosed (ch : Unbounded α) : BaseIO Bool :=
ch.state.atomically do
return ( get).closed
private def tryRecv' : AtomicT (Unbounded.State α) BaseIO (Option α) := do
let st get
if let some (a, values) := st.values.dequeue? then
set { st with values }
return some a
else
return none
private def tryRecv (ch : Unbounded α) : BaseIO (Option α) :=
ch.state.atomically do
tryRecv'
private def recv (ch : Unbounded α) : BaseIO (Task (Option α)) := do
ch.state.atomically do
if let some val tryRecv' then
return .pure <| some val
else if ( get).closed then
return .pure none
else
let promise IO.Promise.new
modify fun st => { st with consumers := st.consumers.enqueue promise }
return promise.result?.map (sync := true) (·.bind id)
end Unbounded
/--
Receives a message, without blocking.
The returned task waits for the message.
Every message is only received once.
Returns `none` if the channel is closed and the queue is empty.
The central state structure for a zero buffer channel, maintains the following invariants:
1. `producers = ∅ consumers = ∅`
2. `closed = true → consumers = ∅`
-/
def Channel.recv? (ch : Channel α) : BaseIO (Task (Option α)) :=
ch.atomically do
private structure Zero.State (α : Type) where
/--
Producers that are blocked on a consumer taking their value.
-/
producers : Std.Queue (α × IO.Promise Bool)
/--
Consumers that are blocked on a producer providing them a value. The `IO.Promise` will be resolved
to `none` if the channel closes.
-/
consumers : Std.Queue (IO.Promise (Option α))
/--
Whether the channel is closed already.
-/
closed : Bool
private structure Zero (α : Type) where
state : Mutex (Zero.State α)
namespace Zero
private def new : BaseIO (Zero α) := do
return {
state := Mutex.new {
producers :=
consumers :=
closed := false
}
}
/--
Precondition: The channel must not be closed.
-/
private def trySend' (v : α) : AtomicT (Zero.State α) BaseIO Bool := do
let st get
if let some (consumer, consumers) := st.consumers.dequeue? then
consumer.resolve (some v)
set { st with consumers }
return true
else
return false
private def trySend (ch : Zero α) (v : α) : BaseIO Bool := do
ch.state.atomically do
if ( get).closed then
return false
else
trySend' v
private def send (ch : Zero α) (v : α) : BaseIO (Task (Except Error Unit)) := do
ch.state.atomically do
if ( get).closed then
return .pure <| .error .closed
else if trySend' v then
return .pure <| .ok ()
else
let promise IO.Promise.new
modify fun st => { st with producers := st.producers.enqueue (v, promise) }
return promise.result?.map (sync := true)
fun
| none | some false => .error .closed
| some true => .ok ()
private def close (ch : Zero α) : EIO Error Unit := do
ch.state.atomically do
let st get
if let some (a, values) := st.values.dequeue? then
set { st with values }
return .pure a
if st.closed then throw .alreadyClosed
for consumer in st.consumers.toArray do consumer.resolve none
set { st with consumers := , closed := true }
return ()
private def isClosed (ch : Zero α) : BaseIO Bool :=
ch.state.atomically do
return ( get).closed
private def tryRecv' : AtomicT (Zero.State α) BaseIO (Option α) := do
let st get
if let some ((val, promise), producers) := st.producers.dequeue? then
set { st with producers }
promise.resolve true
return some val
else
return none
private def tryRecv (ch : Zero α) : BaseIO (Option α) := do
ch.state.atomically do
tryRecv'
private def recv (ch : Zero α) : BaseIO (Task (Option α)) := do
ch.state.atomically do
let st get
if let some val tryRecv' then
return .pure <| some val
else if !st.closed then
let promise IO.Promise.new
set { st with consumers := st.consumers.enqueue promise }
return promise.result?.map (sync := true) (·.bind id)
else
return .pure none
return .pure <| none
end Zero
/--
`ch.forAsync f` calls `f` for every messages received on `ch`.
The central state structure for a bounded channel, maintains the following invariants:
1. `0 < capacity`
2. `0 < bufCount → consumers = ∅`
3. `bufCount < capacity → producers = ∅`
4. `producers = ∅ consumers = ∅`, implied by 1, 2 and 3.
5. `bufCount` corresponds to the amount of slots in `buf` that are `some`.
6. `sendIdx = (recvIdx + bufCount) % capacity`. However all four of these values still get tracked
as there is potential to make a non-blocking send lock-free in the future with this approach.
7. `closed = true → consumers = ∅`
Note that if this function is called twice, each `forAsync` only gets half the messages.
While it (currently) lacks the partial lock-freeness of go channels, the protocol is based on
[Go channels on steroids](https://docs.google.com/document/d/1yIAYmbvL3JxOKOjuCyon7JhW4cSv1wy5hC0ApeGMV9s/pub)
as well as its [implementation](https://go.dev/src/runtime/chan.go).
-/
partial def Channel.forAsync (f : α BaseIO Unit) (ch : Channel α)
private structure Bounded.State (α : Type) where
/--
Producers that are blocked on a consumer taking their value as there was no buffer space
available when they tried to enqueue.
-/
producers : Std.Queue (IO.Promise Bool)
/--
Consumers that are blocked on a producer providing them a value, as there was no value
enqueued when they tried to dequeue. The `IO.Promise` will be resolved to `false` if the channel
closes.
-/
consumers : Std.Queue (IO.Promise Bool)
/--
The capacity of the buffer space.
-/
capacity : Nat
/--
The buffer space for the channel, slots with `some v` contain a value that is waiting for
consumption, the slots with `none` are free for enqueueing.
Note that this is a `Vector` of `IO.Ref (Option α)` as the `buf` itself is shared across threads
and would thus keep getting copied if it was a `Vector (Option α)` instead.
-/
buf : Vector (IO.Ref (Option α)) capacity
/--
How many slots in `buf` are currently used, this is used to disambiguate between an empty and a
full buffer without sacrificing a slot for indicating that.
-/
bufCount : Nat
/--
The slot in `buf` that the next send will happen to.
-/
sendIdx : Nat
hsend : sendIdx < capacity
/--
The slot in `buf` that the next receive will happen from.
-/
recvIdx : Nat
hrecv : recvIdx < capacity
/--
Whether the channel is closed already.
-/
closed : Bool
private structure Bounded (α : Type) where
state : Mutex (Bounded.State α)
namespace Bounded
private def new (capacity : Nat) (hcap : 0 < capacity) : BaseIO (Bounded α) := do
return {
state := Mutex.new {
producers :=
consumers :=
capacity := capacity
buf := Vector.range capacity |>.mapM (fun _ => IO.mkRef none)
bufCount := 0
sendIdx := 0
hsend := hcap
recvIdx := 0
hrecv := hcap
closed := false
}
}
@[inline]
private def incMod (idx : Nat) (cap : Nat) : Nat :=
if idx + 1 = cap then
0
else
idx + 1
private theorem incMod_lt {idx cap : Nat} (h : idx < cap) : incMod idx cap < cap := by
unfold incMod
split <;> omega
/--
Precondition: The channel must not be closed.
-/
private def trySend' (v : α) : AtomicT (Bounded.State α) BaseIO Bool := do
let mut st get
if st.bufCount = st.capacity then
return false
else
st.buf[st.sendIdx]'st.hsend |>.set (some v)
st := { st with
bufCount := st.bufCount + 1
sendIdx := incMod st.sendIdx st.capacity
hsend := incMod_lt st.hsend
}
if let some (consumer, consumers) := st.consumers.dequeue? then
consumer.resolve true
st := { st with consumers }
set st
return true
private def trySend (ch : Bounded α) (v : α) : BaseIO Bool := do
ch.state.atomically do
if ( get).closed then
return false
else
trySend' v
private partial def send (ch : Bounded α) (v : α) : BaseIO (Task (Except Error Unit)) := do
ch.state.atomically do
if ( get).closed then
return .pure <| .error .closed
else if trySend' v then
return .pure <| .ok ()
else
let promise IO.Promise.new
modify fun st => { st with producers := st.producers.enqueue promise }
BaseIO.bindTask promise.result? fun res => do
if res.getD false then
Bounded.send ch v
else
return .pure <| .error .closed
private def close (ch : Bounded α) : EIO Error Unit := do
ch.state.atomically do
let st get
if st.closed then throw .alreadyClosed
for consumer in st.consumers.toArray do consumer.resolve false
set { st with consumers := , closed := true }
return ()
private def isClosed (ch : Bounded α) : BaseIO Bool :=
ch.state.atomically do
return ( get).closed
private def tryRecv' : AtomicT (Bounded.State α) BaseIO (Option α) := do
let st get
if st.bufCount == 0 then
return none
else
let val st.buf[st.recvIdx]'st.hrecv |>.swap none
let nextRecvIdx := incMod st.recvIdx st.capacity
set { st with
bufCount := st.bufCount - 1
recvIdx := nextRecvIdx,
hrecv := incMod_lt st.hrecv
}
return val
private def tryRecv (ch : Bounded α) : BaseIO (Option α) :=
ch.state.atomically do
tryRecv'
private partial def recv (ch : Bounded α) : BaseIO (Task (Option α)) := do
ch.state.atomically do
if let some val tryRecv' then
let st get
if let some (producer, producers) := ( get).producers.dequeue? then
producer.resolve true
set { st with producers }
return .pure <| some val
else if ( get).closed then
return .pure none
else
let promise IO.Promise.new
modify fun st => { st with consumers := st.consumers.enqueue promise }
BaseIO.bindTask promise.result? fun res => do
if res.getD false then
Bounded.recv ch
else
return .pure none
end Bounded
/--
This type represents all flavors of channels that we have available.
-/
private inductive Flavors (α : Type) where
| unbounded (ch : Unbounded α)
| zero (ch : Zero α)
| bounded (ch : Bounded α)
deriving Nonempty
end CloseableChannel
/--
A multi-producer multi-consumer FIFO channel that offers both bounded and unbounded buffering
and an asynchronous API, to switch into synchronous mode use `CloseableChannel.sync`.
Additionally `Std.CloseableChannel` can be closed if necessary, unlike `Std.Channel`.
This introduces a need for error handling in some cases, thus it is usually easier to use
`Std.Channel` if applicable.
-/
def CloseableChannel (α : Type) : Type := CloseableChannel.Flavors α
/--
A multi-producer multi-consumer FIFO channel that offers both bounded and unbounded buffering
and a synchronous API. This type acts as a convenient layer to use a channel in a blocking fashion
and is not actually different from the original channel.
Additionally `Std.CloseableChannel.Sync` can be closed if necessary, unlike `Std.Channel.Sync`.
This introduces the need to handle errors in some cases, thus it is usually easier to use
`Std.Channel` if applicable.
-/
def CloseableChannel.Sync (α : Type) : Type := CloseableChannel α
instance : Nonempty (CloseableChannel α) :=
inferInstanceAs (Nonempty (CloseableChannel.Flavors α))
instance : Nonempty (CloseableChannel.Sync α) :=
inferInstanceAs (Nonempty (CloseableChannel α))
namespace CloseableChannel
/--
Create a new channel, if:
- `capacity` is `none` it will be unbounded (the default)
- `capacity` is `some 0` it will always force a rendezvous between sender and receiver
- `capacity` is `some n` with `n > 0` it will use a buffer of size `n` and begin blocking once it
is filled
-/
def new (capacity : Option Nat := none) : BaseIO (CloseableChannel α) := do
match capacity with
| none => return .unbounded ( CloseableChannel.Unbounded.new)
| some 0 => return .zero ( CloseableChannel.Zero.new)
| some (n + 1) => return .bounded ( CloseableChannel.Bounded.new (n + 1) (by omega))
/--
Try to send a value to the channel, if this can be completed right away without blocking return
`true`, otherwise don't send the value and return `false`.
-/
def trySend (ch : CloseableChannel α) (v : α) : BaseIO Bool :=
match ch with
| .unbounded ch => CloseableChannel.Unbounded.trySend ch v
| .zero ch => CloseableChannel.Zero.trySend ch v
| .bounded ch => CloseableChannel.Bounded.trySend ch v
/--
Send a value through the channel, returning a task that will resolve once the transmission could be
completed. Note that the task may resolve to `Except.error` if the channel was closed before it
could be completed.
-/
def send (ch : CloseableChannel α) (v : α) : BaseIO (Task (Except Error Unit)) :=
match ch with
| .unbounded ch => CloseableChannel.Unbounded.send ch v
| .zero ch => CloseableChannel.Zero.send ch v
| .bounded ch => CloseableChannel.Bounded.send ch v
/--
Closes the channel, returns `Except.ok` when called the first time, otherwise `Except.error`.
When a channel is closed:
- no new values can be sent successfully anymore
- all blocked consumers are resolved to `none` (as no new messages can be sent they will never
resolve)
- if there are already values waiting to be received they can still be received by subsequent `recv`
calls
-/
def close (ch : CloseableChannel α) : EIO Error Unit :=
match ch with
| .unbounded ch => CloseableChannel.Unbounded.close ch
| .zero ch => CloseableChannel.Zero.close ch
| .bounded ch => CloseableChannel.Bounded.close ch
/--
Return `true` if the channel is closed.
-/
def isClosed (ch : CloseableChannel α) : BaseIO Bool :=
match ch with
| .unbounded ch => CloseableChannel.Unbounded.isClosed ch
| .zero ch => CloseableChannel.Zero.isClosed ch
| .bounded ch => CloseableChannel.Bounded.isClosed ch
/--
Try to receive a value from the channel, if this can be completed right away without blocking return
`some value`, otherwise return `none`.
-/
def tryRecv (ch : CloseableChannel α) : BaseIO (Option α) :=
match ch with
| .unbounded ch => CloseableChannel.Unbounded.tryRecv ch
| .zero ch => CloseableChannel.Zero.tryRecv ch
| .bounded ch => CloseableChannel.Bounded.tryRecv ch
/--
Receive a value from the channel, returning a task that will resolve once the transmission could be
completed. Note that the task may resolve to `none` if the channel was closed before it could be
completed.
-/
def recv (ch : CloseableChannel α) : BaseIO (Task (Option α)) :=
match ch with
| .unbounded ch => CloseableChannel.Unbounded.recv ch
| .zero ch => CloseableChannel.Zero.recv ch
| .bounded ch => CloseableChannel.Bounded.recv ch
/--
`ch.forAsync f` calls `f` for every message received on `ch`.
Note that if this function is called twice, each message will only arrive at exactly one invocation.
-/
partial def forAsync (f : α BaseIO Unit) (ch : CloseableChannel α)
(prio : Task.Priority := .default) : BaseIO (Task Unit) := do
BaseIO.bindTask (prio := prio) ( ch.recv?) fun
BaseIO.bindTask (prio := prio) ( ch.recv) fun
| none => return .pure ()
| some v => do f v; ch.forAsync f prio
/--
Receives all currently queued messages from the channel.
Those messages are dequeued and will not be returned by `recv?`.
This function is a no-op and just a convenient way to expose the synchronous API of the channel.
-/
def Channel.recvAllCurrent (ch : Channel α) : BaseIO (Array α) :=
ch.atomically do
modifyGet fun st => (st.values.toArray, { st with values := })
@[inline]
def sync (ch : CloseableChannel α) : CloseableChannel.Sync α := ch
/-- Type tag for synchronous (blocking) operations on a `Channel`. -/
def Channel.Sync := Channel
namespace Sync
@[inherit_doc CloseableChannel.new, inline]
def new (capacity : Option Nat := none) : BaseIO (Sync α) := CloseableChannel.new capacity
@[inherit_doc CloseableChannel.trySend, inline]
def trySend (ch : Sync α) (v : α) : BaseIO Bool := CloseableChannel.trySend ch v
/--
Accesses synchronous (blocking) version of channel operations.
For example, `ch.sync.recv?` blocks until the next message,
and `for msg in ch.sync do ...` iterates synchronously over the channel.
These functions should only be used in dedicated threads.
Send a value through the channel, blocking until the transmission could be completed. Note that this
function may throw an error when trying to send to an already closed channel.
-/
def Channel.sync (ch : Channel α) : Channel.Sync α := ch
def send (ch : Sync α) (v : α) : EIO Error Unit := do
EIO.ofExcept ( IO.wait ( CloseableChannel.send ch v))
@[inherit_doc CloseableChannel.close, inline]
def close (ch : Sync α) : EIO Error Unit := CloseableChannel.close ch
@[inherit_doc CloseableChannel.isClosed, inline]
def isClosed (ch : Sync α) : BaseIO Bool := CloseableChannel.isClosed ch
@[inherit_doc CloseableChannel.tryRecv, inline]
def tryRecv (ch : Sync α) : BaseIO (Option α) := CloseableChannel.tryRecv ch
/--
Synchronously receives a message from the channel.
Every message is only received once.
Returns `none` if the channel is closed and the queue is empty.
Receive a value from the channel, blocking unitl the transmission could be completed. Note that the
return value may be `none` if the channel was closed before it could be completed.
-/
def Channel.Sync.recv? (ch : Channel.Sync α) : BaseIO (Option α) := do
IO.wait ( Channel.recv? ch)
def recv (ch : Sync α) : BaseIO (Option α) := do
IO.wait ( CloseableChannel.recv ch)
private partial def Channel.Sync.forIn [Monad m] [MonadLiftT BaseIO m]
(ch : Channel.Sync α) (f : α β m (ForInStep β)) : β m β := fun b => do
match ch.recv? with
| some a =>
match f a b with
| .done b => pure b
| .yield b => ch.forIn f b
| none => pure b
private partial def forIn [Monad m] [MonadLiftT BaseIO m]
(ch : Sync α) (f : α β m (ForInStep β)) : β m β := fun b => do
match ch.recv with
| some a =>
match f a b with
| .done b => pure b
| .yield b => ch.forIn f b
| none => pure b
/-- `for msg in ch.sync do ...` receives all messages in the channel until it is closed. -/
instance [MonadLiftT BaseIO m] : ForIn m (Channel.Sync α) α where
instance [MonadLiftT BaseIO m] : ForIn m (Sync α) α where
forIn ch b f := ch.forIn f b
end Sync
end CloseableChannel
/--
A multi-producer multi-consumer FIFO channel that offers both bounded and unbounded buffering
and an asynchronous API, to switch into synchronous mode use `Channel.sync`.
If a channel needs to be closed to indicate some sort of completion event use `Std.CloseableChannel`
instead. Note that `Std.CloseableChannel` introduces a need for error handling in some cases, thus
`Std.Channel` is usually easier to use if applicable.
-/
structure Channel (α : Type) where
private mk ::
private inner : CloseableChannel α
deriving Nonempty
/--
A multi-producer multi-consumer FIFO channel that offers both bounded and unbounded buffering
and a synchronous API. This type acts as a convenient layer to use a channel in a blocking fashion
and is not actually different from the original channel.
If a channel needs to be closed to indicate some sort of completion event use
`Std.CloseableChannel.Sync` instead. Note that `Std.CloseableChannel.Sync` introduces a need for error
handling in some cases, thus `Std.Channel.Sync` is usually easier to use if applicable.
-/
def Channel.Sync (α : Type) : Type := Channel α
instance : Nonempty (Channel.Sync α) :=
inferInstanceAs (Nonempty (Channel α))
namespace Channel
@[inherit_doc CloseableChannel.new, inline]
def new (capacity : Option Nat := none) : BaseIO (Channel α) := do
return CloseableChannel.new capacity
@[inherit_doc CloseableChannel.trySend, inline]
def trySend (ch : Channel α) (v : α) : BaseIO Bool :=
CloseableChannel.trySend ch.inner v
/--
Send a value through the channel, returning a task that will resolve once the transmission could be
completed.
-/
def send (ch : Channel α) (v : α) : BaseIO (Task Unit) := do
BaseIO.bindTask (sync := true) ( CloseableChannel.send ch.inner v)
fun
| .ok .. => return .pure ()
| .error .. => unreachable!
@[inherit_doc CloseableChannel.tryRecv, inline]
def tryRecv (ch : Channel α) : BaseIO (Option α) :=
CloseableChannel.tryRecv ch.inner
@[inherit_doc CloseableChannel.recv]
def recv [Inhabited α] (ch : Channel α) : BaseIO (Task α) := do
BaseIO.bindTask (sync := true) ( CloseableChannel.recv ch.inner)
fun
| some val => return .pure val
| none => unreachable!
@[inherit_doc CloseableChannel.forAsync]
partial def forAsync [Inhabited α] (f : α BaseIO Unit) (ch : Channel α)
(prio : Task.Priority := .default) : BaseIO (Task Unit) := do
BaseIO.bindTask (prio := prio) ( ch.recv) fun v => do f v; ch.forAsync f prio
@[inherit_doc CloseableChannel.sync, inline]
def sync (ch : Channel α) : Channel.Sync α := ch
namespace Sync
@[inherit_doc Channel.new, inline]
def new (capacity : Option Nat := none) : BaseIO (Sync α) := Channel.new capacity
@[inherit_doc Channel.trySend, inline]
def trySend (ch : Sync α) (v : α) : BaseIO Bool := Channel.trySend ch v
/--
Send a value through the channel, blocking until the transmission could be completed.
-/
def send (ch : Sync α) (v : α) : BaseIO Unit := do
IO.wait ( Channel.send ch v)
@[inherit_doc Channel.tryRecv, inline]
def tryRecv (ch : Sync α) : BaseIO (Option α) := Channel.tryRecv ch
/--
Receive a value from the channel, blocking unitl the transmission could be completed.
-/
def recv [Inhabited α] (ch : Sync α) : BaseIO α := do
IO.wait ( Channel.recv ch)
private partial def forIn [Inhabited α] [Monad m] [MonadLiftT BaseIO m]
(ch : Sync α) (f : α β m (ForInStep β)) : β m β := fun b => do
let a ch.recv
match f a b with
| .done b => pure b
| .yield b => ch.forIn f b
/-- `for msg in ch.sync do ...` receives all messages in the channel until it is closed. -/
instance [Inhabited α] [MonadLiftT BaseIO m] : ForIn m (Sync α) α where
forIn ch b f := ch.forIn f b
end Sync
end Channel
end Std

View File

@@ -246,113 +246,116 @@ instance : Hashable (BVExpr w) where
def decEq : DecidableEq (BVExpr w) := fun l r =>
withPtrEqDecEq l r fun _ =>
match l with
| .var lidx =>
match r with
| .var ridx =>
if h : lidx = ridx then .isTrue (by simp [h]) else .isFalse (by simp [h])
| .const .. | .extract .. | .bin .. | .un .. | .append .. | .replicate .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .const lval =>
match r with
| .const rval =>
if h : lval = rval then .isTrue (by simp [h]) else .isFalse (by simp [h])
| .var .. | .extract .. | .bin .. | .un .. | .append .. | .replicate .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .extract (w := lw) lstart _ lexpr =>
match r with
| .extract (w := rw) rstart _ rexpr =>
if h1 : lw = rw lstart = rstart then
match decEq (h1.left lexpr) rexpr with
| .isTrue h2 => .isTrue (by cases h1.left; simp_all)
| .isFalse h2 => .isFalse (by cases h1.left; simp_all)
else
.isFalse (by simp_all)
| .var .. | .const .. | .bin .. | .un .. | .append .. | .replicate .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .bin llhs lop lrhs =>
match r with
| .bin rlhs rop rrhs =>
if h1 : lop = rop then
match decEq llhs rlhs, decEq lrhs rrhs with
| .isTrue h2, .isTrue h3 => .isTrue (by simp [h1, h2, h3])
| .isFalse h2, _ => .isFalse (by simp [h2])
| _, .isFalse h3 => .isFalse (by simp [h3])
else
.isFalse (by simp [h1])
| .const .. | .var .. | .extract .. | .un .. | .append .. | .replicate .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .un lop lexpr =>
match r with
| .un rop rexpr =>
if h1 : lop = rop then
match decEq lexpr rexpr with
| .isTrue h2 => .isTrue (by simp [h1, h2])
| .isFalse h2 => .isFalse (by simp [h2])
else
.isFalse (by simp [h1])
| .const .. | .var .. | .extract .. | .bin .. | .append .. | .replicate .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .append (l := ll) (r := lr) llhs lrhs lh =>
match r with
| .append (l := rl) (r := rr) rlhs rrhs rh =>
if h1 : ll = rl lr = rr then
match decEq (h1.left llhs) rlhs, decEq (h1.right lrhs) rrhs with
| .isTrue h2, .isTrue h3 => .isTrue (by cases h1.left; cases h1.right; simp [h2, h3])
| .isFalse h2, _ => .isFalse (by cases h1.left; cases h1.right; simp [h2])
| _, .isFalse h3 => .isFalse (by cases h1.left; cases h1.right; simp [h3])
else
.isFalse (by simp; omega)
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .replicate .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .replicate (w := lw) ln lexpr lh =>
match r with
| .replicate (w := rw) rn rexpr rh =>
if h1 : ln = rn lw = rw then
match decEq (h1.right lexpr) rexpr with
| .isTrue h2 => .isTrue (by cases h1.left; cases h1.right; simp [h2])
| .isFalse h2 => .isFalse (by cases h1.left; cases h1.right; simp [h2])
else
.isFalse (by simp; omega)
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .append .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
if h : hash l hash r then
.isFalse (ne_of_apply_ne hash h)
else
match l with
| .var lidx =>
match r with
| .var ridx =>
if h : lidx = ridx then .isTrue (by simp [h]) else .isFalse (by simp [h])
| .const .. | .extract .. | .bin .. | .un .. | .append .. | .replicate .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .const lval =>
match r with
| .const rval =>
if h : lval = rval then .isTrue (by simp [h]) else .isFalse (by simp [h])
| .var .. | .extract .. | .bin .. | .un .. | .append .. | .replicate .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .extract (w := lw) lstart _ lexpr =>
match r with
| .extract (w := rw) rstart _ rexpr =>
if h1 : lw = rw lstart = rstart then
match decEq (h1.left lexpr) rexpr with
| .isTrue h2 => .isTrue (by cases h1.left; simp_all)
| .isFalse h2 => .isFalse (by cases h1.left; simp_all)
else
.isFalse (by simp_all)
| .var .. | .const .. | .bin .. | .un .. | .append .. | .replicate .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .bin llhs lop lrhs =>
match r with
| .bin rlhs rop rrhs =>
if h1 : lop = rop then
match decEq llhs rlhs, decEq lrhs rrhs with
| .isTrue h2, .isTrue h3 => .isTrue (by simp [h1, h2, h3])
| .isFalse h2, _ => .isFalse (by simp [h2])
| _, .isFalse h3 => .isFalse (by simp [h3])
else
.isFalse (by simp [h1])
| .const .. | .var .. | .extract .. | .un .. | .append .. | .replicate .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .un lop lexpr =>
match r with
| .un rop rexpr =>
if h1 : lop = rop then
match decEq lexpr rexpr with
| .isTrue h2 => .isTrue (by simp [h1, h2])
| .isFalse h2 => .isFalse (by simp [h2])
else
.isFalse (by simp [h1])
| .const .. | .var .. | .extract .. | .bin .. | .append .. | .replicate .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .append (l := ll) (r := lr) llhs lrhs lh =>
match r with
| .append (l := rl) (r := rr) rlhs rrhs rh =>
if h1 : ll = rl lr = rr then
match decEq (h1.left llhs) rlhs, decEq (h1.right lrhs) rrhs with
| .isTrue h2, .isTrue h3 => .isTrue (by cases h1.left; cases h1.right; simp [h2, h3])
| .isFalse h2, _ => .isFalse (by cases h1.left; cases h1.right; simp [h2])
| _, .isFalse h3 => .isFalse (by cases h1.left; cases h1.right; simp [h3])
else
.isFalse (by simp; omega)
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .replicate .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .replicate (w := lw) ln lexpr lh =>
match r with
| .replicate (w := rw) rn rexpr rh =>
if h1 : ln = rn lw = rw then
match decEq (h1.right lexpr) rexpr with
| .isTrue h2 => .isTrue (by cases h1.left; cases h1.right; simp [h2])
| .isFalse h2 => .isFalse (by cases h1.left; cases h1.right; simp [h2])
else
.isFalse (by simp; omega)
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .append .. | .shiftLeft ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .shiftLeft (n := lw) llhs lrhs =>
match r with
| .shiftLeft (n := rw) rlhs rrhs =>
if h1 : lw = rw then
match decEq llhs rlhs, decEq (h1 lrhs) rrhs with
| .isTrue h2, .isTrue h3 => .isTrue (by cases h1; simp [h2, h3])
| .isFalse h2, _ => .isFalse (by cases h1; simp [h2])
| _, .isFalse h3 => .isFalse (by cases h1; simp [h3])
else
.isFalse (by simp [h1])
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .append .. | .replicate ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .shiftRight (n := lw) llhs lrhs =>
match r with
| .shiftRight (n := rw) rlhs rrhs =>
if h1 : lw = rw then
match decEq llhs rlhs, decEq (h1 lrhs) rrhs with
| .isTrue h2, .isTrue h3 => .isTrue (by cases h1; simp [h2, h3])
| .isFalse h2, _ => .isFalse (by cases h1; simp [h2])
| _, .isFalse h3 => .isFalse (by cases h1; simp [h3])
else
.isFalse (by simp [h1])
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .append .. | .replicate ..
|.shiftLeft .. | .arithShiftRight .. => .isFalse (by simp)
| .arithShiftRight (n := lw) llhs lrhs =>
match r with
| .arithShiftRight (n := rw) rlhs rrhs =>
if h1 : lw = rw then
match decEq llhs rlhs, decEq (h1 lrhs) rrhs with
| .isTrue h2, .isTrue h3 => .isTrue (by cases h1; simp [h2, h3])
| .isFalse h2, _ => .isFalse (by cases h1; simp [h2])
| _, .isFalse h3 => .isFalse (by cases h1; simp [h3])
else
.isFalse (by simp [h1])
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .append .. | .replicate ..
| .shiftRight .. | .shiftLeft .. => .isFalse (by simp)
| .shiftLeft (n := lw) llhs lrhs =>
match r with
| .shiftLeft (n := rw) rlhs rrhs =>
if h1 : lw = rw then
match decEq llhs rlhs, decEq (h1 lrhs) rrhs with
| .isTrue h2, .isTrue h3 => .isTrue (by cases h1; simp [h2, h3])
| .isFalse h2, _ => .isFalse (by cases h1; simp [h2])
| _, .isFalse h3 => .isFalse (by cases h1; simp [h3])
else
.isFalse (by simp [h1])
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .append .. | .replicate ..
| .shiftRight .. | .arithShiftRight .. => .isFalse (by simp)
| .shiftRight (n := lw) llhs lrhs =>
match r with
| .shiftRight (n := rw) rlhs rrhs =>
if h1 : lw = rw then
match decEq llhs rlhs, decEq (h1 lrhs) rrhs with
| .isTrue h2, .isTrue h3 => .isTrue (by cases h1; simp [h2, h3])
| .isFalse h2, _ => .isFalse (by cases h1; simp [h2])
| _, .isFalse h3 => .isFalse (by cases h1; simp [h3])
else
.isFalse (by simp [h1])
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .append .. | .replicate ..
|.shiftLeft .. | .arithShiftRight .. => .isFalse (by simp)
| .arithShiftRight (n := lw) llhs lrhs =>
match r with
| .arithShiftRight (n := rw) rlhs rrhs =>
if h1 : lw = rw then
match decEq llhs rlhs, decEq (h1 lrhs) rrhs with
| .isTrue h2, .isTrue h3 => .isTrue (by cases h1; simp [h2, h3])
| .isFalse h2, _ => .isFalse (by cases h1; simp [h2])
| _, .isFalse h3 => .isFalse (by cases h1; simp [h3])
else
.isFalse (by simp [h1])
| .const .. | .var .. | .extract .. | .bin .. | .un .. | .append .. | .replicate ..
| .shiftRight .. | .shiftLeft .. => .isFalse (by simp)
instance : DecidableEq (BVExpr w) := decEq

View File

@@ -9,6 +9,7 @@ Author: Leonardo de Moura
#include <utility>
#include <unordered_map>
#include "kernel/replace_fn.h"
#include "util/alloc.h"
namespace lean {
@@ -18,7 +19,7 @@ class replace_rec_fn {
return hash((size_t)p.first >> 3, p.second);
}
};
std::unordered_map<std::pair<lean_object *, unsigned>, expr, key_hasher> m_cache;
lean::unordered_map<std::pair<lean_object *, unsigned>, expr, key_hasher> m_cache;
std::function<optional<expr>(expr const &, unsigned)> m_f;
bool m_use_cache;
@@ -84,7 +85,7 @@ expr replace(expr const & e, std::function<optional<expr>(expr const &, unsigned
}
class replace_fn {
std::unordered_map<lean_object *, expr> m_cache;
lean::unordered_map<lean_object *, expr> m_cache;
lean_object * m_f;
expr save_result(expr const & e, expr const & r, bool shared) {

View File

@@ -1,6 +1,9 @@
#!/usr/bin/env bash
set -euxo pipefail
exit 0 # TODO: flaky test disabled
# test disabled
LAKE=${LAKE:-../../.lake/build/bin/lake}
./clean.sh

View File

@@ -21,6 +21,7 @@ Authors: Leonardo de Moura, Gabriel Ebner, Sebastian Ullrich
#include "runtime/io.h"
#include "runtime/compact.h"
#include "runtime/buffer.h"
#include "runtime/array_ref.h"
#include "util/io.h"
#include "util/name_map.h"
#include "library/module.h"
@@ -76,47 +77,73 @@ struct olean_header {
// make sure we don't have any padding bytes, which also ensures `data` is properly aligned
static_assert(sizeof(olean_header) == 5 + 1 + 1 + 33 + 40 + sizeof(size_t), "olean_header must be packed");
extern "C" LEAN_EXPORT object * lean_save_module_data(b_obj_arg fname, b_obj_arg mod, b_obj_arg mdata, object *) {
std::string olean_fn(string_cstr(fname));
// we first write to a temp file and then move it to the correct path (possibly deleting an older file)
// so that we neither expose partially-written files nor modify possibly memory-mapped files
std::string olean_tmp_fn = olean_fn + ".tmp";
try {
std::ofstream out(olean_tmp_fn, std::ios_base::binary);
if (out.fail()) {
return io_result_mk_error((sstream() << "failed to create file '" << olean_fn << "'").str());
extern "C" LEAN_EXPORT object * lean_save_module_data_parts(b_obj_arg mod, b_obj_arg oparts, object *) {
#ifdef LEAN_WINDOWS
uint32_t pid = GetCurrentProcessId();
#else
uint32_t pid = getpid();
#endif
// Derive a base address that is uniformly distributed by deterministic, and should most likely
// work for `mmap` on all interesting platforms
// NOTE: an overlapping/non-compatible base address does not prevent the module from being imported,
// merely from using `mmap` for that
// Let's start with a hash of the module name. Note that while our string hash is a dubious 32-bit
// algorithm, the mixing of multiple `Name` parts seems to result in a nicely distributed 64-bit
// output
size_t base_addr = name(mod, true).hash();
// x86-64 user space is currently limited to the lower 47 bits
// https://en.wikipedia.org/wiki/X86-64#Virtual_address_space_details
// On Linux at least, the stack grows down from ~0x7fff... followed by shared libraries, so reserve
// a bit of space for them (0x7fff...-0x7f00... = 1TB)
base_addr = base_addr % 0x7f0000000000;
// `mmap` addresses must be page-aligned. The default (non-huge) page size on x86-64 is 4KB.
// `MapViewOfFileEx` addresses must be aligned to the "memory allocation granularity", which is 64KB.
const size_t ALIGN = 1LL<<16;
base_addr = base_addr & ~(ALIGN - 1);
object_compactor compactor(reinterpret_cast<void *>(base_addr));
array_ref<pair_ref<string_ref, object_ref>> parts(oparts, true);
std::vector<std::string> tmp_fnames;
for (auto const & part : parts) {
std::string olean_fn = part.fst().to_std_string();
try {
// we first write to a temp file and then move it to the correct path (possibly deleting an older file)
// so that we neither expose partially-written files nor modify possibly memory-mapped files
std::string olean_tmp_fn = olean_fn + ".tmp." + std::to_string(pid);
tmp_fnames.push_back(olean_tmp_fn);
std::ofstream out(olean_tmp_fn, std::ios_base::binary);
if (compactor.size() % ALIGN != 0) {
compactor.alloc(ALIGN - (compactor.size() % ALIGN));
}
size_t file_offset = compactor.size();
compactor.alloc(sizeof(olean_header));
olean_header header = {};
// see/sync with file format description above
header.base_addr = base_addr + file_offset;
strncpy(header.lean_version, get_short_version_string().c_str(), sizeof(header.lean_version));
strncpy(header.githash, LEAN_GITHASH, sizeof(header.githash));
out.write(reinterpret_cast<char *>(&header), sizeof(header));
compactor(part.snd().raw());
if (out.fail()) {
throw exception((sstream() << "failed to create file '" << olean_fn << "'").str());
}
out.write(static_cast<char const *>(compactor.data()) + file_offset + sizeof(olean_header), compactor.size() - file_offset - sizeof(olean_header));
out.close();
} catch (exception & ex) {
return io_result_mk_error((sstream() << "failed to write '" << olean_fn << "': " << ex.what()).str());
}
}
// Derive a base address that is uniformly distributed by deterministic, and should most likely
// work for `mmap` on all interesting platforms
// NOTE: an overlapping/non-compatible base address does not prevent the module from being imported,
// merely from using `mmap` for that
// Let's start with a hash of the module name. Note that while our string hash is a dubious 32-bit
// algorithm, the mixing of multiple `Name` parts seems to result in a nicely distributed 64-bit
// output
size_t base_addr = name(mod, true).hash();
// x86-64 user space is currently limited to the lower 47 bits
// https://en.wikipedia.org/wiki/X86-64#Virtual_address_space_details
// On Linux at least, the stack grows down from ~0x7fff... followed by shared libraries, so reserve
// a bit of space for them (0x7fff...-0x7f00... = 1TB)
base_addr = base_addr % 0x7f0000000000;
// `mmap` addresses must be page-aligned. The default (non-huge) page size on x86-64 is 4KB.
// `MapViewOfFileEx` addresses must be aligned to the "memory allocation granularity", which is 64KB.
base_addr = base_addr & ~((1LL<<16) - 1);
object_compactor compactor(reinterpret_cast<void *>(base_addr + offsetof(olean_header, data)));
compactor(mdata);
// see/sync with file format description above
olean_header header = {};
header.base_addr = base_addr;
strncpy(header.lean_version, get_short_version_string().c_str(), sizeof(header.lean_version));
strncpy(header.githash, LEAN_GITHASH, sizeof(header.githash));
out.write(reinterpret_cast<char *>(&header), sizeof(header));
out.write(static_cast<char const *>(compactor.data()), compactor.size());
out.close();
while (std::rename(olean_tmp_fn.c_str(), olean_fn.c_str()) != 0) {
for (unsigned i = 0; i < parts.size(); i++) {
std::string olean_fn = parts[i].fst().to_std_string();
while (std::rename(tmp_fnames[i].c_str(), olean_fn.c_str()) != 0) {
#ifdef LEAN_WINDOWS
if (errno == EEXIST) {
// Memory-mapped files can be deleted starting with Windows 10 using "POSIX semantics"
@@ -136,94 +163,148 @@ extern "C" LEAN_EXPORT object * lean_save_module_data(b_obj_arg fname, b_obj_arg
#endif
return io_result_mk_error((sstream() << "failed to write '" << olean_fn << "': " << errno << " " << strerror(errno)).str());
}
return io_result_mk_ok(box(0));
} catch (exception & ex) {
return io_result_mk_error((sstream() << "failed to write '" << olean_fn << "': " << ex.what()).str());
}
return io_result_mk_ok(box(0));
}
extern "C" LEAN_EXPORT object * lean_read_module_data(object * fname, object *) {
std::string olean_fn(string_cstr(fname));
try {
std::ifstream in(olean_fn, std::ios_base::binary);
if (in.fail()) {
return io_result_mk_error((sstream() << "failed to open file '" << olean_fn << "'").str());
}
/* Get file size */
in.seekg(0, in.end);
size_t size = in.tellg();
in.seekg(0);
struct module_file {
std::string m_fname;
std::ifstream m_in;
char * m_base_addr;
size_t m_size;
char * m_buffer;
std::function<void()> m_free_data;
};
olean_header default_header = {};
olean_header header;
if (!in.read(reinterpret_cast<char *>(&header), sizeof(header))
|| memcmp(header.marker, default_header.marker, sizeof(header.marker)) != 0) {
return io_result_mk_error((sstream() << "failed to read file '" << olean_fn << "', invalid header").str());
}
if (header.version != default_header.version || header.flags != default_header.flags
#ifdef LEAN_CHECK_OLEAN_VERSION
|| strncmp(header.githash, LEAN_GITHASH, sizeof(header.githash)) != 0
#endif
) {
return io_result_mk_error((sstream() << "failed to read file '" << olean_fn << "', incompatible header").str());
}
char * base_addr = reinterpret_cast<char *>(header.base_addr);
char * buffer = nullptr;
bool is_mmap = false;
std::function<void()> free_data;
#ifdef LEAN_WINDOWS
// `FILE_SHARE_DELETE` is necessary to allow the file to (be marked to) be deleted while in use
HANDLE h_olean_fn = CreateFile(olean_fn.c_str(), GENERIC_READ, FILE_SHARE_READ | FILE_SHARE_DELETE, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
if (h_olean_fn == INVALID_HANDLE_VALUE) {
return io_result_mk_error((sstream() << "failed to open '" << olean_fn << "': " << GetLastError()).str());
}
HANDLE h_map = CreateFileMapping(h_olean_fn, NULL, PAGE_READONLY, 0, 0, NULL);
if (h_olean_fn == NULL) {
return io_result_mk_error((sstream() << "failed to map '" << olean_fn << "': " << GetLastError()).str());
}
buffer = static_cast<char *>(MapViewOfFileEx(h_map, FILE_MAP_READ, 0, 0, 0, base_addr));
free_data = [=]() {
if (buffer) {
lean_always_assert(UnmapViewOfFile(base_addr));
extern "C" LEAN_EXPORT object * lean_read_module_data_parts(b_obj_arg ofnames, object *) {
array_ref<string_ref> fnames(ofnames, true);
// first read in all headers
std::vector<module_file> files;
for (auto const & fname : fnames) {
std::string olean_fn = fname.to_std_string();
try {
std::ifstream in(olean_fn, std::ios_base::binary);
if (in.fail()) {
return io_result_mk_error((sstream() << "failed to open file '" << olean_fn << "'").str());
}
/* Get file size */
in.seekg(0, in.end);
size_t size = in.tellg();
in.seekg(0);
olean_header default_header = {};
olean_header header;
if (!in.read(reinterpret_cast<char *>(&header), sizeof(header))
|| memcmp(header.marker, default_header.marker, sizeof(header.marker)) != 0) {
return io_result_mk_error((sstream() << "failed to read file '" << olean_fn << "', invalid header").str());
}
in.seekg(0);
if (header.version != default_header.version || header.flags != default_header.flags
#ifdef LEAN_CHECK_OLEAN_VERSION
|| strncmp(header.githash, LEAN_GITHASH, sizeof(header.githash)) != 0
#endif
) {
return io_result_mk_error((sstream() << "failed to read file '" << olean_fn << "', incompatible header").str());
}
char * base_addr = reinterpret_cast<char *>(header.base_addr);
files.push_back({olean_fn, std::move(in), base_addr, size, nullptr, nullptr});
} catch (exception & ex) {
return io_result_mk_error((sstream() << "failed to read '" << olean_fn << "': " << ex.what()).str());
}
}
#ifndef LEAN_MMAP
bool is_mmap = false;
#else
// now try mmapping *all* files
bool is_mmap = true;
for (auto & file : files) {
std::string const & olean_fn = file.m_fname;
char * base_addr = file.m_base_addr;
try {
#ifdef LEAN_WINDOWS
// `FILE_SHARE_DELETE` is necessary to allow the file to (be marked to) be deleted while in use
HANDLE h_olean_fn = CreateFile(olean_fn.c_str(), GENERIC_READ, FILE_SHARE_READ | FILE_SHARE_DELETE, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
if (h_olean_fn == INVALID_HANDLE_VALUE) {
return io_result_mk_error((sstream() << "failed to open '" << olean_fn << "': " << GetLastError()).str());
}
HANDLE h_map = CreateFileMapping(h_olean_fn, NULL, PAGE_READONLY, 0, 0, NULL);
if (h_olean_fn == NULL) {
return io_result_mk_error((sstream() << "failed to map '" << olean_fn << "': " << GetLastError()).str());
}
char * buffer = static_cast<char *>(MapViewOfFileEx(h_map, FILE_MAP_READ, 0, 0, 0, base_addr));
lean_always_assert(CloseHandle(h_map));
lean_always_assert(CloseHandle(h_olean_fn));
};
#else
int fd = open(olean_fn.c_str(), O_RDONLY);
if (fd == -1) {
return io_result_mk_error((sstream() << "failed to open '" << olean_fn << "': " << strerror(errno)).str());
}
#ifdef LEAN_MMAP
buffer = static_cast<char *>(mmap(base_addr, size, PROT_READ, MAP_PRIVATE, fd, 0));
#endif
close(fd);
free_data = [=]() {
if (buffer != MAP_FAILED) {
lean_always_assert(munmap(buffer, size) == 0);
if (!buffer) {
is_mmap = false;
break;
}
};
#endif
if (buffer && buffer == base_addr) {
buffer += sizeof(olean_header);
is_mmap = true;
} else {
#ifdef LEAN_MMAP
free_data();
#endif
buffer = static_cast<char *>(malloc(size - sizeof(olean_header)));
free_data = [=]() {
free_sized(buffer, size - sizeof(olean_header));
file.m_free_data = [=]() {
lean_always_assert(UnmapViewOfFile(base_addr));
};
in.read(buffer, size - sizeof(olean_header));
if (!in) {
return io_result_mk_error((sstream() << "failed to read file '" << olean_fn << "'").str());
#else
int fd = open(olean_fn.c_str(), O_RDONLY);
if (fd == -1) {
return io_result_mk_error((sstream() << "failed to open '" << olean_fn << "': " << strerror(errno)).str());
}
char * buffer = static_cast<char *>(mmap(base_addr, file.m_size, PROT_READ, MAP_PRIVATE, fd, 0));
if (buffer == MAP_FAILED) {
is_mmap = false;
break;
}
close(fd);
size_t size = file.m_size;
file.m_free_data = [=]() {
lean_always_assert(munmap(buffer, size) == 0);
};
#endif
if (buffer == base_addr) {
file.m_buffer = buffer;
} else {
is_mmap = false;
break;
}
} catch (exception & ex) {
return io_result_mk_error((sstream() << "failed to read '" << olean_fn << "': " << ex.what()).str());
}
}
#endif
// if *any* file failed to mmap, read all of them into a single big allocation so that offsets
// between them are unchanged
if (!is_mmap) {
for (auto & file : files) {
if (file.m_free_data) {
file.m_free_data();
file.m_free_data = {};
}
}
in.close();
size_t big_size = files[files.size()-1].m_base_addr + files[files.size()-1].m_size - files[0].m_base_addr;
char * big_buffer = static_cast<char *>(malloc(big_size));
for (auto & file : files) {
std::string const & olean_fn = file.m_fname;
try {
file.m_buffer = big_buffer + (file.m_base_addr - files[0].m_base_addr);
file.m_in.read(file.m_buffer, file.m_size);
if (!file.m_in) {
return io_result_mk_error((sstream() << "failed to read file '" << olean_fn << "'").str());
}
file.m_in.close();
} catch (exception & ex) {
return io_result_mk_error((sstream() << "failed to read '" << olean_fn << "': " << ex.what()).str());
}
}
files[0].m_free_data = [=]() {
free_sized(big_buffer, big_size);
};
}
std::vector<object_ref> res;
for (auto & file : files) {
compacted_region * region =
new compacted_region(size - sizeof(olean_header), buffer, base_addr + sizeof(olean_header), is_mmap, free_data);
new compacted_region(file.m_size - sizeof(olean_header), file.m_buffer + sizeof(olean_header), static_cast<char *>(file.m_base_addr) + sizeof(olean_header), is_mmap, file.m_free_data);
#if defined(__has_feature)
#if __has_feature(address_sanitizer)
// do not report as leak
@@ -234,18 +315,9 @@ extern "C" LEAN_EXPORT object * lean_read_module_data(object * fname, object *)
object * mod_region = alloc_cnstr(0, 2, 0);
cnstr_set(mod_region, 0, mod);
cnstr_set(mod_region, 1, box_size_t(reinterpret_cast<size_t>(region)));
return io_result_mk_ok(mod_region);
} catch (exception & ex) {
return io_result_mk_error((sstream() << "failed to read '" << olean_fn << "': " << ex.what()).str());
res.push_back(object_ref(mod_region));
}
}
/*
@[export lean.write_module_core]
def writeModule (env : Environment) (fname : String) : IO Unit := */
extern "C" object * lean_write_module(object * env, object * fname, object *);
void write_module(elab_environment const & env, std::string const & olean_fn) {
consume_io_result(lean_write_module(env.to_obj_arg(), mk_string(olean_fn), io_mk_world()));
return io_result_mk_ok(to_array(res));
}
}

Some files were not shown because too many files have changed in this diff Show More