Compare commits

...

36 Commits

Author SHA1 Message Date
Henrik Böving
118709ee6c chore: update stage0 2024-08-07 20:24:35 +02:00
Henrik Böving
c2575107f2 refactor: windows does not have getline 2024-08-07 19:25:18 +02:00
Henrik Böving
4db828d885 chore: move test 2024-08-07 19:25:18 +02:00
Henrik Böving
3f16f339e7 chore: use guard_msgs 2024-08-07 19:25:18 +02:00
Henrik Böving
2771296ca5 fix: keep trying to read after mdata size 2024-08-07 19:25:17 +02:00
Henrik Böving
5e337872ce test: NULL byte safety of Handle.putStr and Handle.getLine 2024-08-07 19:25:17 +02:00
Henrik Böving
329fa6309b fix: Handle.putStr for strings containing NULL bytes. 2024-08-07 19:25:17 +02:00
Henrik Böving
370488a9ff fix: Handle.getLine for lines containing NULL bytes 2024-08-07 19:25:17 +02:00
Henrik Böving
df38da8e09 feat: add lean_mk_string_from_bytes to object.h 2024-08-07 19:25:17 +02:00
Henrik Böving
a2b93d6c18 feat: read entire files in one system call 2024-08-07 19:25:17 +02:00
Markus Himmel
63c4de5fea chore: update stage0 2024-08-07 18:24:42 +02:00
Markus Himmel
3b14642c42 chore: build Lake again 2024-08-07 18:24:42 +02:00
Markus Himmel
d52da36e68 chore: update stage0 2024-08-07 18:24:42 +02:00
Markus Himmel
bf82965eec chore: avoid builing Lake 2024-08-07 18:24:42 +02:00
Markus Himmel
4bac74c4ac chore: switch to Std.HashMap and Std.HashSet almost everywhere 2024-08-07 18:24:42 +02:00
Henrik Böving
8b9d27de31 chore: Revert "feat: Revamp file reading and writing" (#4948)
Reverts leanprover/lean4#4906
2024-08-07 16:00:45 +00:00
Henrik Böving
d15f0335a9 feat: setup Std.Sat with definitions of SAT and CNF (#4933)
Step 1 out of approximately 7 to upstream LeanSAT.

---------

Co-authored-by: Tobias Grosser <tobias@grosser.es>
Co-authored-by: Markus Himmel <markus@lean-fro.org>
2024-08-07 15:44:46 +00:00
Sebastian Ullrich
240ebff549 chore: Windows needs more LEAN_EXPORTs (#4941) 2024-08-07 17:13:13 +02:00
Sebastian Ullrich
a29bca7f00 chore: CI: placate linter 2024-08-07 16:52:18 +02:00
Tobias Grosser
313f6b3c74 chore: name variables in Data/BitVec consistently (#4930)
This change canonicalizes the BitVec variable names to `x y z : BitVec`
instead of alternative namings such as `s t : BitVec` or `a b : BitVec`.
Variable names that carry semantic meaning such as `(msbs : BitVec w)
(lsb : Bool)` remain untouched.

This is purely a naming change to make our bitvector proofs more
consistent and polish the (auto-generated) documentation as a very small
step towards polishing the documentation of the BitVec library in Lean.

---------

Co-authored-by: AnotherAlexHere <153999274+AnotherAlexHere@users.noreply.github.com>
2024-08-07 13:43:15 +00:00
Markus Himmel
43fa46412d feat: deprecated variants of hash map query methods (#4943)
#4917 will expose users of the `Lean` API to the renaming of the hash
map query methods. This PR aims to make the transition easier by adding
deprecated functions with the old names.
2024-08-07 13:36:19 +00:00
Henrik Böving
234704e304 feat: upstream utilities around Array, Bool and Prod from LeanSAT (#4945)
Co-authored-by: Kim Morrison <kim@tqft.net>
2024-08-07 12:32:40 +00:00
Sebastian Ullrich
12a714a6f9 chore: CI: fix rebase command 2024-08-07 14:27:53 +02:00
Sebastian Ullrich
cdc7ed0224 chore: CI: fix rebase command 2024-08-07 14:21:43 +02:00
Sebastian Ullrich
217abdf97a chore: CI: fix rebase command 2024-08-07 14:15:18 +02:00
Sebastian Ullrich
490a2b4bf9 chore: CI: fix rebase command 2024-08-07 14:05:00 +02:00
Sebastian Ullrich
84d45deb10 chore: CI: fix rebase 2024-08-07 14:02:57 +02:00
Sebastian Ullrich
f46d216e18 chore: CI: !rebase PR comment command 2024-08-07 13:53:17 +02:00
Tobias Grosser
cc42a17931 feat: add ushiftRight_*_distrib theorems (#4667) 2024-08-07 10:43:54 +00:00
Siddharth
e106be19dd feat: sshiftRight bitblasting (#4889)
We follow the same strategy as
https://github.com/leanprover/lean4/pull/4872,
https://github.com/leanprover/lean4/pull/4571, and implement bitblasting
theorems for `sshiftRight`.

---------

Co-authored-by: Tobias Grosser <tobias@grosser.es>
2024-08-07 10:33:56 +00:00
Sebastian Ullrich
1efd6657d4 test: unflakify test cases (#4940)
With the recent unification of server and cmdline processing,
`IO.Process` tests that previously broke the server because they
directly wrote to stdout are now flaky on the cmdline because
elaboration and reporting are happening in separate threads. By removing
direct writes to stdout, the race condition is removed and the file can
actually be edited in the language server as well again.
2024-08-07 09:34:29 +00:00
Henrik Böving
473b34561d feat: Revamp file reading and writing (#4906)
This PR:
- changes the implementation of `readBinFile` and `readFile` to only
require two system calls (`stat` + `read`) instead of one `read` per
1024 byte chunk.
- fixes a bug where `Handle.getLine` would get tripped up by a NUL
character in the line and cut the string off. This is caused by the fact
that the original implementation uses `strlen` and `lean_mk_string`
which is the backer of `mk_string` does so as well.
- fixes a bug where `Handle.putStr` and thus by extension `writeFile`
would get tripped up by a NUL char in the line and cut the string off.
Cause here is the use of `fputs` when a NUL char is possible.

Closes: #4891 
Closes: #3546
Closes: #3741
2024-08-07 07:39:15 +00:00
Marc Huisinga
574066b30b fix: language server windows issues (#4821)
This PR resolves two language server bugs that especially affect Windows
users:
1. Editing the header could result in the watchdog not correctly
restarting the file worker (#3786, #3787), which would lead to the file
seemingly being processed forever.
- The cause of this issue was a race condition in the watchdog that was
accidentally introduced as far back as #1884: In specific circumstances,
the watchdog will attempt forwarding a message to the file worker after
the process has exited due to a changed header, but before the file
worker exiting has been noticed by the watchdog (which will then restart
the file worker). In this case, the watchdog would mark the file worker
as having crashed and not look at its exit code to restart the file
worker, but instead treat it like a crashed file worker that will only
be restarted when editing the file again. Not inspecting the exit code
of the file worker when it crashed from forwarding a message from the
file worker is necessary since we do not restart the file worker until
another notification from the client arrives, and so we would read the
same crash exit code over and over again in the main loop of the
watchdog if we did not remove it from our list of file workers that we
listen to.
- This PR resolves this issue by distinguishing between "crashes when
forwarding messages to the file worker" and "crashes when forwarding
messages from the file worker". In the former case, we still inspect the
exit code of the file worker and potentially restart it if the imports
changed, whereas in the latter case, we stop inspecting the exit code of
the file worker. This is correct because the latter case is exactly the
one where we need to stop inspecting the exit code but where a crash
cannot occur as a result of a changed header, whereas the former case is
exactly the one where we still need to inspect the exit code after a
crash to ensure that we restart the file worker in case it exited
because the header changed.
- At some point in the future, it would be nice to revamp the
concurrency model of the watchdog entirely now that we have all those
fancy concurrency primitives that were not available four years ago when
the watchdog was first written.

2. On an especially slow Windows machine, we found that starting the
language server would sometimes not succeed at all because reading from
the stdin pipe in the watchdog produced an EINVAL error, which was in
turn caused by an NT "pipe empty" error.
- After lots of debugging, @Kha found that Lake accidentally passes its
stdin to Git because it does not explicitly set the `stdin` field to
`null` when spawning the process.
- Changing this fixes the issue, which suggests that Git may mutate the
pipe we pass to it to be non-blocking, which then causes a "pipe empty"
error in the watchdog when we also attempt to read from that same pipe.
- I'm still very uncertain why we only saw this issue on one
particularly slow machine and not across the whole eco system.

This PR also resolves an issue where we would not correctly emit
messages that we received while the file worker is being restarted to
the corresponding file worker after the restart.

Closes #3786, closes #3787.

---------

Co-authored-by: Sebastian Ullrich <sebasti@nullri.ch>
2024-08-07 06:19:33 +00:00
Kim Morrison
1e6d617aad chore: minor fixes to release checklist (#4937) 2024-08-07 01:09:35 +00:00
Leonardo de Moura
c17a4ddc94 perf: skip betaReduceLetRecApps if it is not needed (#4936) 2024-08-07 00:57:35 +00:00
Leonardo de Moura
5be4f5e30c perf: skip eraseRecAppSyntaxExpr if it is not needed (#4935) 2024-08-07 00:29:50 +00:00
320 changed files with 1371 additions and 586 deletions

30
.github/workflows/rebase-on-comment.yml vendored Normal file
View File

@@ -0,0 +1,30 @@
# As the PR author, use `!rebase` to rebase a commit
name: Rebase on Comment
on:
issue_comment:
types: [created]
jobs:
rebase:
if: github.event.issue.pull_request != '' && github.event.comment.body == '!rebase' && github.event.comment.user.login == github.event.issue.user.login
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
ref: refs/pull/${{ github.event.issue.number }}/head
- name: Rebase PR branch onto base branch
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
PR_NUMBER="${{ github.event.issue.number }}"
API_URL="https://api.github.com/repos/${{ github.repository }}/pulls/$PR_NUMBER"
PR_DETAILS="$(curl -s -H "Authorization: token $GITHUB_TOKEN" $API_URL)"
BASE_REF="$(echo $PR_DETAILS | jq -r .base.ref)"
git checkout -b working-branch
git fetch origin $BASE_REF
git rebase origin/$BASE_REF
git push origin refs/pull/${{ github.event.issue.number }}/head --force-with-lease

View File

@@ -152,22 +152,26 @@ We'll use `v4.7.0-rc1` as the intended release version in this example.
This will add a list of all the commits since the last stable version.
- Delete "update stage0" commits, and anything with a completely inscrutable commit message.
- Next, we will move a curated list of downstream repos to the release candidate.
- This assumes that there is already a *reviewed* branch `bump/v4.7.0` on each repository
containing the required adaptations (or no adaptations are required).
The preparation of this branch is beyond the scope of this document.
- This assumes that for each repository either:
* There is already a *reviewed* branch `bump/v4.7.0` containing the required adaptations.
The preparation of this branch is beyond the scope of this document.
* The repository does not need any changes to move to the new version.
- For each of the target repositories:
- Checkout the `bump/v4.7.0` branch.
- Verify that the `lean-toolchain` is set to the nightly from which the release candidate was created.
- `git merge origin/master`
- Change the `lean-toolchain` to `leanprover/lean4:v4.7.0-rc1`
- In `lakefile.lean`, change any dependencies which were using `nightly-testing` or `bump/v4.7.0` branches
back to `master` or `main`, and run `lake update` for those dependencies.
- Run `lake build` to ensure that dependencies are found (but it's okay to stop it after a moment).
- `git commit`
- `git push`
- Open a PR from `bump/v4.7.0` to `master`, and either merge it yourself after CI, if appropriate,
or notify the maintainers that it is ready to go.
- Once this PR has been merged, tag `master` with `v4.7.0-rc1` and push this tag.
- If the repository does not need any changes (i.e. `bump/v4.7.0` does not exist) then create
a new PR updating `lean-toolchain` to `leanprover/lean4:v4.7.0-rc1` and running `lake update`.
- Otherwise:
- Checkout the `bump/v4.7.0` branch.
- Verify that the `lean-toolchain` is set to the nightly from which the release candidate was created.
- `git merge origin/master`
- Change the `lean-toolchain` to `leanprover/lean4:v4.7.0-rc1`
- In `lakefile.lean`, change any dependencies which were using `nightly-testing` or `bump/v4.7.0` branches
back to `master` or `main`, and run `lake update` for those dependencies.
- Run `lake build` to ensure that dependencies are found (but it's okay to stop it after a moment).
- `git commit`
- `git push`
- Open a PR from `bump/v4.7.0` to `master`, and either merge it yourself after CI, if appropriate,
or notify the maintainers that it is ready to go.
- Once the PR has been merged, tag `master` with `v4.7.0-rc1` and push this tag.
- We do this for the same list of repositories as for stable releases, see above.
As above, there are dependencies between these, and so the process above is iterative.
It greatly helps if you can merge the `bump/v4.7.0` PRs yourself!

View File

@@ -7,6 +7,7 @@ prelude
import Init.Data.Nat.MinMax
import Init.Data.Nat.Lemmas
import Init.Data.List.Monadic
import Init.Data.List.Nat.Range
import Init.Data.Fin.Basic
import Init.Data.Array.Mem
import Init.TacticsExtra
@@ -336,6 +337,10 @@ theorem not_mem_nil (a : α) : ¬ a ∈ #[] := nofun
/-- # get lemmas -/
theorem lt_of_getElem {x : α} {a : Array α} {idx : Nat} {hidx : idx < a.size} (_ : a[idx] = x) :
idx < a.size :=
hidx
theorem getElem?_mem {l : Array α} {i : Fin l.size} : l[i] l := by
erw [Array.mem_def, getElem_eq_data_getElem]
apply List.get_mem
@@ -505,6 +510,13 @@ theorem size_eq_length_data (as : Array α) : as.size = as.data.length := rfl
simp only [mkEmpty_eq, size_push] at *
omega
@[simp] theorem data_range (n : Nat) : (range n).data = List.range n := by
induction n <;> simp_all [range, Nat.fold, flip, List.range_succ]
@[simp]
theorem getElem_range {n : Nat} {x : Nat} (h : x < (Array.range n).size) : (Array.range n)[x] = x := by
simp [getElem_eq_data_getElem]
set_option linter.deprecated false in
@[simp] theorem reverse_data (a : Array α) : a.reverse.data = a.data.reverse := by
let rec go (as : Array α) (i j hj)
@@ -707,13 +719,22 @@ theorem mapIdx_spec (as : Array α) (f : Fin as.size → α → β)
unfold modify modifyM Id.run
split <;> simp
theorem get_modify {arr : Array α} {x i} (h : i < arr.size) :
(arr.modify x f).get i, by simp [h] =
if x = i then f (arr.get i, h) else arr.get i, h := by
simp [modify, modifyM, Id.run]; split
· simp [get_set _ _ _ h]; split <;> simp [*]
theorem getElem_modify {as : Array α} {x i} (h : i < as.size) :
(as.modify x f)[i]'(by simp [h]) = if x = i then f as[i] else as[i] := by
simp only [modify, modifyM, get_eq_getElem, Id.run, Id.pure_eq]
split
· simp only [Id.bind_eq, get_set _ _ _ h]; split <;> simp [*]
· rw [if_neg (mt (by rintro rfl; exact h) _)]
theorem getElem_modify_self {as : Array α} {i : Nat} (h : i < as.size) (f : α α) :
(as.modify i f)[i]'(by simp [h]) = f as[i] := by
simp [getElem_modify h]
theorem getElem_modify_of_ne {as : Array α} {i : Nat} (hj : j < as.size)
(f : α α) (h : i j) :
(as.modify i f)[j]'(by rwa [size_modify]) = as[j] := by
simp [getElem_modify hj, h]
/-! ### filter -/
@[simp] theorem filter_data (p : α Bool) (l : Array α) :

View File

@@ -42,8 +42,8 @@ Bitvectors have decidable equality. This should be used via the instance `Decida
-- We manually derive the `DecidableEq` instances for `BitVec` because
-- we want to have builtin support for bit-vector literals, and we
-- need a name for this function to implement `canUnfoldAtMatcher` at `WHNF.lean`.
def BitVec.decEq (a b : BitVec n) : Decidable (a = b) :=
match a, b with
def BitVec.decEq (x y : BitVec n) : Decidable (x = y) :=
match x, y with
| n, m =>
if h : n = m then
isTrue (h rfl)
@@ -69,9 +69,9 @@ protected def ofNat (n : Nat) (i : Nat) : BitVec n where
instance instOfNat : OfNat (BitVec n) i where ofNat := .ofNat n i
instance natCastInst : NatCast (BitVec w) := BitVec.ofNat w
/-- Given a bitvector `a`, return the underlying `Nat`. This is O(1) because `BitVec` is a
/-- Given a bitvector `x`, return the underlying `Nat`. This is O(1) because `BitVec` is a
(zero-cost) wrapper around a `Nat`. -/
protected def toNat (a : BitVec n) : Nat := a.toFin.val
protected def toNat (x : BitVec n) : Nat := x.toFin.val
/-- Return the bound in terms of toNat. -/
theorem isLt (x : BitVec w) : x.toNat < 2^w := x.toFin.isLt
@@ -123,18 +123,18 @@ section getXsb
@[inline] def getMsb (x : BitVec w) (i : Nat) : Bool := i < w && getLsb x (w-1-i)
/-- Return most-significant bit in bitvector. -/
@[inline] protected def msb (a : BitVec n) : Bool := getMsb a 0
@[inline] protected def msb (x : BitVec n) : Bool := getMsb x 0
end getXsb
section Int
/-- Interpret the bitvector as an integer stored in two's complement form. -/
protected def toInt (a : BitVec n) : Int :=
if 2 * a.toNat < 2^n then
a.toNat
protected def toInt (x : BitVec n) : Int :=
if 2 * x.toNat < 2^n then
x.toNat
else
(a.toNat : Int) - (2^n : Nat)
(x.toNat : Int) - (2^n : Nat)
/-- The `BitVec` with value `(2^n + (i mod 2^n)) mod 2^n`. -/
protected def ofInt (n : Nat) (i : Int) : BitVec n := .ofNatLt (i % (Int.ofNat (2^n))).toNat (by
@@ -215,7 +215,7 @@ instance : Neg (BitVec n) := ⟨.neg⟩
/--
Return the absolute value of a signed bitvector.
-/
protected def abs (s : BitVec n) : BitVec n := if s.msb then .neg s else s
protected def abs (x : BitVec n) : BitVec n := if x.msb then .neg x else x
/--
Multiplication for bit vectors. This can be interpreted as either signed or unsigned negation
@@ -262,12 +262,12 @@ sdiv 5#4 -2 = -2#4
sdiv (-7#4) (-2) = 3#4
```
-/
def sdiv (s t : BitVec n) : BitVec n :=
match s.msb, t.msb with
| false, false => udiv s t
| false, true => .neg (udiv s (.neg t))
| true, false => .neg (udiv (.neg s) t)
| true, true => udiv (.neg s) (.neg t)
def sdiv (x y : BitVec n) : BitVec n :=
match x.msb, y.msb with
| false, false => udiv x y
| false, true => .neg (udiv x (.neg y))
| true, false => .neg (udiv (.neg x) y)
| true, true => udiv (.neg x) (.neg y)
/--
Signed division for bit vectors using SMTLIB rules for division by zero.
@@ -276,40 +276,40 @@ Specifically, `smtSDiv x 0 = if x >= 0 then -1 else 1`
SMT-Lib name: `bvsdiv`.
-/
def smtSDiv (s t : BitVec n) : BitVec n :=
match s.msb, t.msb with
| false, false => smtUDiv s t
| false, true => .neg (smtUDiv s (.neg t))
| true, false => .neg (smtUDiv (.neg s) t)
| true, true => smtUDiv (.neg s) (.neg t)
def smtSDiv (x y : BitVec n) : BitVec n :=
match x.msb, y.msb with
| false, false => smtUDiv x y
| false, true => .neg (smtUDiv x (.neg y))
| true, false => .neg (smtUDiv (.neg x) y)
| true, true => smtUDiv (.neg x) (.neg y)
/--
Remainder for signed division rounding to zero.
SMT_Lib name: `bvsrem`.
-/
def srem (s t : BitVec n) : BitVec n :=
match s.msb, t.msb with
| false, false => umod s t
| false, true => umod s (.neg t)
| true, false => .neg (umod (.neg s) t)
| true, true => .neg (umod (.neg s) (.neg t))
def srem (x y : BitVec n) : BitVec n :=
match x.msb, y.msb with
| false, false => umod x y
| false, true => umod x (.neg y)
| true, false => .neg (umod (.neg x) y)
| true, true => .neg (umod (.neg x) (.neg y))
/--
Remainder for signed division rounded to negative infinity.
SMT_Lib name: `bvsmod`.
-/
def smod (s t : BitVec m) : BitVec m :=
match s.msb, t.msb with
| false, false => umod s t
def smod (x y : BitVec m) : BitVec m :=
match x.msb, y.msb with
| false, false => umod x y
| false, true =>
let u := umod s (.neg t)
(if u = .zero m then u else .add u t)
let u := umod x (.neg y)
(if u = .zero m then u else .add u y)
| true, false =>
let u := umod (.neg s) t
(if u = .zero m then u else .sub t u)
| true, true => .neg (umod (.neg s) (.neg t))
let u := umod (.neg x) y
(if u = .zero m then u else .sub y u)
| true, true => .neg (umod (.neg x) (.neg y))
end arithmetic
@@ -373,8 +373,8 @@ end relations
section cast
/-- `cast eq i` embeds `i` into an equal `BitVec` type. -/
@[inline] def cast (eq : n = m) (i : BitVec n) : BitVec m := .ofNatLt i.toNat (eq i.isLt)
/-- `cast eq x` embeds `x` into an equal `BitVec` type. -/
@[inline] def cast (eq : n = m) (x : BitVec n) : BitVec m := .ofNatLt x.toNat (eq x.isLt)
@[simp] theorem cast_ofNat {n m : Nat} (h : n = m) (x : Nat) :
cast h (BitVec.ofNat n x) = BitVec.ofNat m x := by
@@ -391,7 +391,7 @@ Extraction of bits `start` to `start + len - 1` from a bit vector of size `n` to
new bitvector of size `len`. If `start + len > n`, then the vector will be zero-padded in the
high bits.
-/
def extractLsb' (start len : Nat) (a : BitVec n) : BitVec len := .ofNat _ (a.toNat >>> start)
def extractLsb' (start len : Nat) (x : BitVec n) : BitVec len := .ofNat _ (x.toNat >>> start)
/--
Extraction of bits `hi` (inclusive) down to `lo` (inclusive) from a bit vector of size `n` to
@@ -399,12 +399,12 @@ yield a new bitvector of size `hi - lo + 1`.
SMT-Lib name: `extract`.
-/
def extractLsb (hi lo : Nat) (a : BitVec n) : BitVec (hi - lo + 1) := extractLsb' lo _ a
def extractLsb (hi lo : Nat) (x : BitVec n) : BitVec (hi - lo + 1) := extractLsb' lo _ x
/--
A version of `zeroExtend` that requires a proof, but is a noop.
-/
def zeroExtend' {n w : Nat} (le : n w) (x : BitVec n) : BitVec w :=
def zeroExtend' {n w : Nat} (le : n w) (x : BitVec n) : BitVec w :=
x.toNat#'(by
apply Nat.lt_of_lt_of_le x.isLt
exact Nat.pow_le_pow_of_le_right (by trivial) le)
@@ -413,8 +413,8 @@ def zeroExtend' {n w : Nat} (le : n ≤ w) (x : BitVec n) : BitVec w :=
`shiftLeftZeroExtend x n` returns `zeroExtend (w+n) x <<< n` without
needing to compute `x % 2^(2+n)`.
-/
def shiftLeftZeroExtend (msbs : BitVec w) (m : Nat) : BitVec (w+m) :=
let shiftLeftLt {x : Nat} (p : x < 2^w) (m : Nat) : x <<< m < 2^(w+m) := by
def shiftLeftZeroExtend (msbs : BitVec w) (m : Nat) : BitVec (w + m) :=
let shiftLeftLt {x : Nat} (p : x < 2^w) (m : Nat) : x <<< m < 2^(w + m) := by
simp [Nat.shiftLeft_eq, Nat.pow_add]
apply Nat.mul_lt_mul_of_pos_right p
exact (Nat.two_pow_pos m)
@@ -502,24 +502,24 @@ instance : Complement (BitVec w) := ⟨.not⟩
/--
Left shift for bit vectors. The low bits are filled with zeros. As a numeric operation, this is
equivalent to `a * 2^s`, modulo `2^n`.
equivalent to `x * 2^s`, modulo `2^n`.
SMT-Lib name: `bvshl` except this operator uses a `Nat` shift value.
-/
protected def shiftLeft (a : BitVec n) (s : Nat) : BitVec n := BitVec.ofNat n (a.toNat <<< s)
protected def shiftLeft (x : BitVec n) (s : Nat) : BitVec n := BitVec.ofNat n (x.toNat <<< s)
instance : HShiftLeft (BitVec w) Nat (BitVec w) := .shiftLeft
/--
(Logical) right shift for bit vectors. The high bits are filled with zeros.
As a numeric operation, this is equivalent to `a / 2^s`, rounding down.
As a numeric operation, this is equivalent to `x / 2^s`, rounding down.
SMT-Lib name: `bvlshr` except this operator uses a `Nat` shift value.
-/
def ushiftRight (a : BitVec n) (s : Nat) : BitVec n :=
(a.toNat >>> s)#'(by
let a, lt := a
def ushiftRight (x : BitVec n) (s : Nat) : BitVec n :=
(x.toNat >>> s)#'(by
let x, lt := x
simp only [BitVec.toNat, Nat.shiftRight_eq_div_pow, Nat.div_lt_iff_lt_mul (Nat.two_pow_pos s)]
rw [Nat.mul_one a]
rw [Nat.mul_one x]
exact Nat.mul_lt_mul_of_lt_of_le' lt (Nat.two_pow_pos s) (Nat.le_refl 1))
instance : HShiftRight (BitVec w) Nat (BitVec w) := .ushiftRight
@@ -527,15 +527,24 @@ instance : HShiftRight (BitVec w) Nat (BitVec w) := ⟨.ushiftRight⟩
/--
Arithmetic right shift for bit vectors. The high bits are filled with the
most-significant bit.
As a numeric operation, this is equivalent to `a.toInt >>> s`.
As a numeric operation, this is equivalent to `x.toInt >>> s`.
SMT-Lib name: `bvashr` except this operator uses a `Nat` shift value.
-/
def sshiftRight (a : BitVec n) (s : Nat) : BitVec n := .ofInt n (a.toInt >>> s)
def sshiftRight (x : BitVec n) (s : Nat) : BitVec n := .ofInt n (x.toInt >>> s)
instance {n} : HShiftLeft (BitVec m) (BitVec n) (BitVec m) := fun x y => x <<< y.toNat
instance {n} : HShiftRight (BitVec m) (BitVec n) (BitVec m) := fun x y => x >>> y.toNat
/--
Arithmetic right shift for bit vectors. The high bits are filled with the
most-significant bit.
As a numeric operation, this is equivalent to `a.toInt >>> s.toNat`.
SMT-Lib name: `bvashr`.
-/
def sshiftRight' (a : BitVec n) (s : BitVec m) : BitVec n := a.sshiftRight s.toNat
/-- Auxiliary function for `rotateLeft`, which does not take into account the case where
the rotation amount is greater than the bitvector width. -/
def rotateLeftAux (x : BitVec w) (n : Nat) : BitVec w :=

View File

@@ -289,18 +289,18 @@ theorem sle_eq_carry (x y : BitVec w) :
A recurrence that describes multiplication as repeated addition.
Is useful for bitblasting multiplication.
-/
def mulRec (l r : BitVec w) (s : Nat) : BitVec w :=
let cur := if r.getLsb s then (l <<< s) else 0
def mulRec (x y : BitVec w) (s : Nat) : BitVec w :=
let cur := if y.getLsb s then (x <<< s) else 0
match s with
| 0 => cur
| s + 1 => mulRec l r s + cur
| s + 1 => mulRec x y s + cur
theorem mulRec_zero_eq (l r : BitVec w) :
mulRec l r 0 = if r.getLsb 0 then l else 0 := by
theorem mulRec_zero_eq (x y : BitVec w) :
mulRec x y 0 = if y.getLsb 0 then x else 0 := by
simp [mulRec]
theorem mulRec_succ_eq (l r : BitVec w) (s : Nat) :
mulRec l r (s + 1) = mulRec l r s + if r.getLsb (s + 1) then (l <<< (s + 1)) else 0 := rfl
theorem mulRec_succ_eq (x y : BitVec w) (s : Nat) :
mulRec x y (s + 1) = mulRec x y s + if y.getLsb (s + 1) then (x <<< (s + 1)) else 0 := rfl
/--
Recurrence lemma: truncating to `i+1` bits and then zero extending to `w`
@@ -326,29 +326,29 @@ theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow (x : BitVec w
by_cases hi : x.getLsb i <;> simp [hi] <;> omega
/--
Recurrence lemma: multiplying `l` with the first `s` bits of `r` is the
same as truncating `r` to `s` bits, then zero extending to the original length,
Recurrence lemma: multiplying `x` with the first `s` bits of `y` is the
same as truncating `y` to `s` bits, then zero extending to the original length,
and performing the multplication. -/
theorem mulRec_eq_mul_signExtend_truncate (l r : BitVec w) (s : Nat) :
mulRec l r s = l * ((r.truncate (s + 1)).zeroExtend w) := by
theorem mulRec_eq_mul_signExtend_truncate (x y : BitVec w) (s : Nat) :
mulRec x y s = x * ((y.truncate (s + 1)).zeroExtend w) := by
induction s
case zero =>
simp only [mulRec_zero_eq, ofNat_eq_ofNat, Nat.reduceAdd]
by_cases r.getLsb 0
case pos hr =>
simp only [hr, reduceIte, truncate, zeroExtend_one_eq_ofBool_getLsb_zero,
hr, ofBool_true, ofNat_eq_ofNat]
by_cases y.getLsb 0
case pos hy =>
simp only [hy, reduceIte, truncate, zeroExtend_one_eq_ofBool_getLsb_zero,
ofBool_true, ofNat_eq_ofNat]
rw [zeroExtend_ofNat_one_eq_ofNat_one_of_lt (by omega)]
simp
case neg hr =>
simp [hr, zeroExtend_one_eq_ofBool_getLsb_zero]
case neg hy =>
simp [hy, zeroExtend_one_eq_ofBool_getLsb_zero]
case succ s' hs =>
rw [mulRec_succ_eq, hs]
have heq :
(if r.getLsb (s' + 1) = true then l <<< (s' + 1) else 0) =
(l * (r &&& (BitVec.twoPow w (s' + 1)))) := by
(if y.getLsb (s' + 1) = true then x <<< (s' + 1) else 0) =
(x * (y &&& (BitVec.twoPow w (s' + 1)))) := by
simp only [ofNat_eq_ofNat, and_twoPow]
by_cases hr : r.getLsb (s' + 1) <;> simp [hr]
by_cases hy : y.getLsb (s' + 1) <;> simp [hy]
rw [heq, BitVec.mul_add, zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow]
theorem getLsb_mul (x y : BitVec w) (i : Nat) :
@@ -429,6 +429,67 @@ theorem shiftLeft_eq_shiftLeftRec (x : BitVec w₁) (y : BitVec w₂) :
· simp [of_length_zero]
· simp [shiftLeftRec_eq]
/- ### Arithmetic shift right (sshiftRight) recurrence -/
/--
`sshiftRightRec x y n` shifts `x` arithmetically/signed to the right by the first `n` bits of `y`.
The theorem `sshiftRight_eq_sshiftRightRec` proves the equivalence of `(x.sshiftRight y)` and `sshiftRightRec`.
Together with equations `sshiftRightRec_zero`, `sshiftRightRec_succ`,
this allows us to unfold `sshiftRight` into a circuit for bitblasting.
-/
def sshiftRightRec (x : BitVec w₁) (y : BitVec w₂) (n : Nat) : BitVec w₁ :=
let shiftAmt := (y &&& (twoPow w₂ n))
match n with
| 0 => x.sshiftRight' shiftAmt
| n + 1 => (sshiftRightRec x y n).sshiftRight' shiftAmt
@[simp]
theorem sshiftRightRec_zero_eq (x : BitVec w₁) (y : BitVec w₂) :
sshiftRightRec x y 0 = x.sshiftRight' (y &&& 1#w₂) := by
simp only [sshiftRightRec, twoPow_zero]
@[simp]
theorem sshiftRightRec_succ_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) :
sshiftRightRec x y (n + 1) = (sshiftRightRec x y n).sshiftRight' (y &&& twoPow w₂ (n + 1)) := by
simp [sshiftRightRec]
/--
If `y &&& z = 0`, `x.sshiftRight (y ||| z) = (x.sshiftRight y).sshiftRight z`.
This follows as `y &&& z = 0` implies `y ||| z = y + z`,
and thus `x.sshiftRight (y ||| z) = x.sshiftRight (y + z) = (x.sshiftRight y).sshiftRight z`.
-/
theorem sshiftRight'_or_of_and_eq_zero {x : BitVec w₁} {y z : BitVec w₂}
(h : y &&& z = 0#w₂) :
x.sshiftRight' (y ||| z) = (x.sshiftRight' y).sshiftRight' z := by
simp [sshiftRight', add_eq_or_of_and_eq_zero _ _ h,
toNat_add_of_and_eq_zero h, sshiftRight_add]
theorem sshiftRightRec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) :
sshiftRightRec x y n = x.sshiftRight' ((y.truncate (n + 1)).zeroExtend w₂) := by
induction n generalizing x y
case zero =>
ext i
simp [twoPow_zero, Nat.reduceAdd, and_one_eq_zeroExtend_ofBool_getLsb, truncate_one]
case succ n ih =>
simp only [sshiftRightRec_succ_eq, and_twoPow, ih]
by_cases h : y.getLsb (n + 1)
· rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true h,
sshiftRight'_or_of_and_eq_zero (by simp), h]
simp
· rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false (i := n + 1)
(by simp [h])]
simp [h]
/--
Show that `x.sshiftRight y` can be written in terms of `sshiftRightRec`.
This can be unfolded in terms of `sshiftRightRec_zero_eq`, `sshiftRightRec_succ_eq` for bitblasting.
-/
theorem sshiftRight_eq_sshiftRightRec (x : BitVec w₁) (y : BitVec w₂) :
(x.sshiftRight' y).getLsb i = (sshiftRightRec x y (w₂ - 1)).getLsb i := by
rcases w₂ with rfl | w₂
· simp [of_length_zero]
· simp [sshiftRightRec_eq]
/- ### Logical shift right (ushiftRight) recurrence for bitblasting -/
/--

View File

@@ -23,7 +23,7 @@ theorem ofFin_eq_ofNat : @BitVec.ofFin w (Fin.mk x lt) = BitVec.ofNat w x := by
simp only [BitVec.ofNat, Fin.ofNat', lt, Nat.mod_eq_of_lt]
/-- Prove equality of bitvectors in terms of nat operations. -/
theorem eq_of_toNat_eq {n} : {i j : BitVec n}, i.toNat = j.toNat i = j
theorem eq_of_toNat_eq {n} : {x y : BitVec n}, x.toNat = y.toNat x = y
| _, _, _, _, rfl => rfl
@[simp] theorem val_toFin (x : BitVec w) : x.toFin.val = x.toNat := rfl
@@ -228,12 +228,12 @@ theorem toNat_ge_of_msb_true {x : BitVec n} (p : BitVec.msb x = true) : x.toNat
/-! ### toInt/ofInt -/
/-- Prove equality of bitvectors in terms of nat operations. -/
theorem toInt_eq_toNat_cond (i : BitVec n) :
i.toInt =
if 2*i.toNat < 2^n then
(i.toNat : Int)
theorem toInt_eq_toNat_cond (x : BitVec n) :
x.toInt =
if 2*x.toNat < 2^n then
(x.toNat : Int)
else
(i.toNat : Int) - (2^n : Nat) :=
(x.toNat : Int) - (2^n : Nat) :=
rfl
theorem msb_eq_false_iff_two_mul_lt (x : BitVec w) : x.msb = false 2 * x.toNat < 2^w := by
@@ -260,13 +260,13 @@ theorem toInt_eq_toNat_bmod (x : BitVec n) : x.toInt = Int.bmod x.toNat (2^n) :=
omega
/-- Prove equality of bitvectors in terms of nat operations. -/
theorem eq_of_toInt_eq {i j : BitVec n} : i.toInt = j.toInt i = j := by
theorem eq_of_toInt_eq {x y : BitVec n} : x.toInt = y.toInt x = y := by
intro eq
simp [toInt_eq_toNat_cond] at eq
apply eq_of_toNat_eq
revert eq
have _ilt := i.isLt
have _jlt := j.isLt
have _xlt := x.isLt
have _ylt := y.isLt
split <;> split <;> omega
theorem toInt_inj (x y : BitVec n) : x.toInt = y.toInt x = y :=
@@ -733,6 +733,21 @@ theorem getLsb_shiftLeft' {x : BitVec w₁} {y : BitVec w₂} {i : Nat} :
getLsb (x >>> i) j = getLsb x (i+j) := by
unfold getLsb ; simp
theorem ushiftRight_xor_distrib (x y : BitVec w) (n : Nat) :
(x ^^^ y) >>> n = (x >>> n) ^^^ (y >>> n) := by
ext
simp
theorem ushiftRight_and_distrib (x y : BitVec w) (n : Nat) :
(x &&& y) >>> n = (x >>> n) &&& (y >>> n) := by
ext
simp
theorem ushiftRight_or_distrib (x y : BitVec w) (n : Nat) :
(x ||| y) >>> n = (x >>> n) ||| (y >>> n) := by
ext
simp
@[simp]
theorem ushiftRight_zero_eq (x : BitVec w) : x >>> 0 = x := by
simp [bv_toNat]
@@ -786,7 +801,7 @@ theorem sshiftRight_eq_of_msb_true {x : BitVec w} {s : Nat} (h : x.msb = true) :
· rw [Nat.shiftRight_eq_div_pow]
apply Nat.lt_of_le_of_lt (Nat.div_le_self _ _) (by omega)
theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) :
@[simp] theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) :
getLsb (x.sshiftRight s) i =
(!decide (w i) && if s + i < w then x.getLsb (s + i) else x.msb) := by
rcases hmsb : x.msb with rfl | rfl
@@ -807,6 +822,41 @@ theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) :
Nat.not_lt, decide_eq_true_eq]
omega
/-- The msb after arithmetic shifting right equals the original msb. -/
theorem sshiftRight_msb_eq_msb {n : Nat} {x : BitVec w} :
(x.sshiftRight n).msb = x.msb := by
rw [msb_eq_getLsb_last, getLsb_sshiftRight, msb_eq_getLsb_last]
by_cases hw₀ : w = 0
· simp [hw₀]
· simp only [show ¬(w w - 1) by omega, decide_False, Bool.not_false, Bool.true_and,
ite_eq_right_iff]
intros h
simp [show n = 0 by omega]
@[simp] theorem sshiftRight_zero {x : BitVec w} : x.sshiftRight 0 = x := by
ext i
simp
theorem sshiftRight_add {x : BitVec w} {m n : Nat} :
x.sshiftRight (m + n) = (x.sshiftRight m).sshiftRight n := by
ext i
simp only [getLsb_sshiftRight, Nat.add_assoc]
by_cases h₁ : w (i : Nat)
· simp [h₁]
· simp only [h₁, decide_False, Bool.not_false, Bool.true_and]
by_cases h₂ : n + i < w
· simp [h₂]
· simp only [h₂, reduceIte]
by_cases h₃ : m + (n + i) < w
· simp [h₃]
omega
· simp [h₃, sshiftRight_msb_eq_msb]
/-! ### sshiftRight reductions from BitVec to Nat -/
@[simp]
theorem sshiftRight_eq' (x : BitVec w) : x.sshiftRight' y = x.sshiftRight y.toNat := rfl
/-! ### signExtend -/
/-- Equation theorem for `Int.sub` when both arguments are `Int.ofNat` -/
@@ -869,15 +919,15 @@ theorem append_def (x : BitVec v) (y : BitVec w) :
(x ++ y).toNat = x.toNat <<< n ||| y.toNat :=
rfl
@[simp] theorem getLsb_append {v : BitVec n} {w : BitVec m} :
getLsb (v ++ w) i = bif i < m then getLsb w i else getLsb v (i - m) := by
@[simp] theorem getLsb_append {x : BitVec n} {y : BitVec m} :
getLsb (x ++ y) i = bif i < m then getLsb y i else getLsb x (i - m) := by
simp only [append_def, getLsb_or, getLsb_shiftLeftZeroExtend, getLsb_zeroExtend']
by_cases h : i < m
· simp [h]
· simp [h]; simp_all
@[simp] theorem getMsb_append {v : BitVec n} {w : BitVec m} :
getMsb (v ++ w) i = bif n i then getMsb w (i - n) else getMsb v i := by
@[simp] theorem getMsb_append {x : BitVec n} {y : BitVec m} :
getMsb (x ++ y) i = bif n i then getMsb y (i - n) else getMsb x i := by
simp [append_def]
by_cases h : n i
· simp [h]

View File

@@ -438,6 +438,24 @@ Added for confluence between `if_true_left` and `ite_false_same` on
-/
@[simp] theorem eq_true_imp_eq_false : (b:Bool), (b = true b = false) (b = false) := by decide
/-! ### forall -/
theorem forall_bool' {p : Bool Prop} (b : Bool) : ( x, p x) p b p !b :=
fun h h _, h _, fun h₁, h₂ x by cases b <;> cases x <;> assumption
@[simp]
theorem forall_bool {p : Bool Prop} : ( b, p b) p false p true :=
forall_bool' false
/-! ### exists -/
theorem exists_bool' {p : Bool Prop} (b : Bool) : ( x, p x) p b p !b :=
fun x, hx by cases x <;> cases b <;> first | exact .inl _ | exact .inr _,
fun h by cases h <;> exact _, _
@[simp]
theorem exists_bool {p : Bool Prop} : ( b, p b) p false p true :=
exists_bool' false
/-! ### cond -/

View File

@@ -5,9 +5,18 @@ Author: Leonardo de Moura
-/
prelude
import Init.SimpLemmas
import Init.NotationExtra
instance [BEq α] [BEq β] [LawfulBEq α] [LawfulBEq β] : LawfulBEq (α × β) where
eq_of_beq {a b} (h : a.1 == b.1 && a.2 == b.2) := by
cases a; cases b
refine congr (congrArg _ (eq_of_beq ?_)) (eq_of_beq ?_) <;> simp_all
rfl {a} := by cases a; simp [BEq.beq, LawfulBEq.rfl]
@[simp]
protected theorem Prod.forall {p : α × β Prop} : ( x, p x) a b, p (a, b) :=
fun h a b h (a, b), fun h a, b h a b
@[simp]
protected theorem Prod.exists {p : α × β Prop} : ( x, p x) a b, p (a, b) :=
fun a, b, h a, b, h, fun a, b, h a, b, h

View File

@@ -470,31 +470,23 @@ def withFile (fn : FilePath) (mode : Mode) (f : Handle → IO α) : IO α :=
def Handle.putStrLn (h : Handle) (s : String) : IO Unit :=
h.putStr (s.push '\n')
partial def Handle.readBinToEnd (h : Handle) : IO ByteArray := do
partial def Handle.readBinToEndInto (h : Handle) (buf : ByteArray) : IO ByteArray := do
let rec loop (acc : ByteArray) : IO ByteArray := do
let buf h.read 1024
if buf.isEmpty then
return acc
else
loop (acc ++ buf)
loop ByteArray.empty
loop buf
partial def Handle.readToEnd (h : Handle) : IO String := do
let rec loop (s : String) := do
let line h.getLine
if line.isEmpty then
return s
else
loop (s ++ line)
loop ""
partial def Handle.readBinToEnd (h : Handle) : IO ByteArray := do
h.readBinToEndInto .empty
def readBinFile (fname : FilePath) : IO ByteArray := do
let h Handle.mk fname Mode.read
h.readBinToEnd
def readFile (fname : FilePath) : IO String := do
let h Handle.mk fname Mode.read
h.readToEnd
def Handle.readToEnd (h : Handle) : IO String := do
let data h.readBinToEnd
match String.fromUTF8? data with
| some s => return s
| none => throw <| .userError s!"Tried to read from handle containing non UTF-8 data."
partial def lines (fname : FilePath) : IO (Array String) := do
let h Handle.mk fname Mode.read
@@ -600,6 +592,28 @@ end System.FilePath
namespace IO
namespace FS
def readBinFile (fname : FilePath) : IO ByteArray := do
-- Requires metadata so defined after metadata
let mdata fname.metadata
let size := mdata.byteSize.toUSize
let handle IO.FS.Handle.mk fname .read
let buf
if size > 0 then
handle.read mdata.byteSize.toUSize
else
pure <| ByteArray.mkEmpty 0
handle.readBinToEndInto buf
def readFile (fname : FilePath) : IO String := do
let data readBinFile fname
match String.fromUTF8? data with
| some s => return s
| none => throw <| .userError s!"Tried to read file '{fname}' containing non UTF-8 data."
end FS
def withStdin [Monad m] [MonadFinally m] [MonadLiftT BaseIO m] (h : FS.Stream) (x : m α) : m α := do
let prev setStdin h
try x finally discard <| setStdin prev

View File

@@ -53,7 +53,7 @@ structure AttributeImpl extends AttributeImplCore where
erase (decl : Name) : AttrM Unit := throwError "attribute cannot be erased"
deriving Inhabited
builtin_initialize attributeMapRef : IO.Ref (HashMap Name AttributeImpl) IO.mkRef {}
builtin_initialize attributeMapRef : IO.Ref (Std.HashMap Name AttributeImpl) IO.mkRef {}
/-- Low level attribute registration function. -/
def registerBuiltinAttribute (attr : AttributeImpl) : IO Unit := do
@@ -296,7 +296,7 @@ end EnumAttributes
-/
abbrev AttributeImplBuilder := Name List DataValue Except String AttributeImpl
abbrev AttributeImplBuilderTable := HashMap Name AttributeImplBuilder
abbrev AttributeImplBuilderTable := Std.HashMap Name AttributeImplBuilder
builtin_initialize attributeImplBuilderTableRef : IO.Ref AttributeImplBuilderTable IO.mkRef {}
@@ -307,7 +307,7 @@ def registerAttributeImplBuilder (builderId : Name) (builder : AttributeImplBuil
def mkAttributeImplOfBuilder (builderId ref : Name) (args : List DataValue) : IO AttributeImpl := do
let table attributeImplBuilderTableRef.get
match table.find? builderId with
match table[builderId]? with
| none => throw (IO.userError ("unknown attribute implementation builder '" ++ toString builderId ++ "'"))
| some builder => IO.ofExcept <| builder ref args
@@ -317,7 +317,7 @@ inductive AttributeExtensionOLeanEntry where
structure AttributeExtensionState where
newEntries : List AttributeExtensionOLeanEntry := []
map : HashMap Name AttributeImpl
map : Std.HashMap Name AttributeImpl
deriving Inhabited
abbrev AttributeExtension := PersistentEnvExtension AttributeExtensionOLeanEntry (AttributeExtensionOLeanEntry × AttributeImpl) AttributeExtensionState
@@ -348,7 +348,7 @@ private def AttributeExtension.addImported (es : Array (Array AttributeExtension
let map es.foldlM
(fun map entries =>
entries.foldlM
(fun (map : HashMap Name AttributeImpl) entry => do
(fun (map : Std.HashMap Name AttributeImpl) entry => do
let attrImpl mkAttributeImplOfEntry ctx.env ctx.opts entry
return map.insert attrImpl.name attrImpl)
map)
@@ -378,7 +378,7 @@ def getBuiltinAttributeNames : IO (List Name) :=
def getBuiltinAttributeImpl (attrName : Name) : IO AttributeImpl := do
let m attributeMapRef.get
match m.find? attrName with
match m[attrName]? with
| some attr => pure attr
| none => throw (IO.userError ("unknown attribute '" ++ toString attrName ++ "'"))
@@ -396,7 +396,7 @@ def getAttributeNames (env : Environment) : List Name :=
def getAttributeImpl (env : Environment) (attrName : Name) : Except String AttributeImpl :=
let m := (attributeExtension.getState env).map
match m.find? attrName with
match m[attrName]? with
| some attr => pure attr
| none => throw ("unknown attribute '" ++ toString attrName ++ "'")

View File

@@ -26,9 +26,9 @@ instance : Hashable Key := ⟨getHash⟩
end OwnedSet
open OwnedSet (Key) in
abbrev OwnedSet := HashMap Key Unit
def OwnedSet.insert (s : OwnedSet) (k : OwnedSet.Key) : OwnedSet := HashMap.insert s k ()
def OwnedSet.contains (s : OwnedSet) (k : OwnedSet.Key) : Bool := HashMap.contains s k
abbrev OwnedSet := Std.HashMap Key Unit
def OwnedSet.insert (s : OwnedSet) (k : OwnedSet.Key) : OwnedSet := Std.HashMap.insert s k ()
def OwnedSet.contains (s : OwnedSet) (k : OwnedSet.Key) : Bool := Std.HashMap.contains s k
/-! We perform borrow inference in a block of mutually recursive functions.
Join points are viewed as local functions, and are identified using
@@ -49,7 +49,7 @@ instance : Hashable Key := ⟨getHash⟩
end ParamMap
open ParamMap (Key)
abbrev ParamMap := HashMap Key (Array Param)
abbrev ParamMap := Std.HashMap Key (Array Param)
def ParamMap.fmt (map : ParamMap) : Format :=
let fmts := map.fold (fun fmt k ps =>
@@ -109,7 +109,7 @@ partial def visitFnBody (fn : FunId) (paramMap : ParamMap) : FnBody → FnBody
| FnBody.jdecl j _ v b =>
let v := visitFnBody fn paramMap v
let b := visitFnBody fn paramMap b
match paramMap.find? (ParamMap.Key.jp fn j) with
match paramMap[ParamMap.Key.jp fn j]? with
| some ys => FnBody.jdecl j ys v b
| none => unreachable!
| FnBody.case tid x xType alts =>
@@ -125,7 +125,7 @@ def visitDecls (decls : Array Decl) (paramMap : ParamMap) : Array Decl :=
decls.map fun decl => match decl with
| Decl.fdecl f _ ty b info =>
let b := visitFnBody f paramMap b
match paramMap.find? (ParamMap.Key.decl f) with
match paramMap[ParamMap.Key.decl f]? with
| some xs => Decl.fdecl f xs ty b info
| none => unreachable!
| other => other
@@ -178,7 +178,7 @@ def isOwned (x : VarId) : M Bool := do
/-- Updates `map[k]` using the current set of `owned` variables. -/
def updateParamMap (k : ParamMap.Key) : M Unit := do
let s get
match s.paramMap.find? k with
match s.paramMap[k]? with
| some ps => do
let ps ps.mapM fun (p : Param) => do
if !p.borrow then pure p
@@ -192,7 +192,7 @@ def updateParamMap (k : ParamMap.Key) : M Unit := do
def getParamInfo (k : ParamMap.Key) : M (Array Param) := do
let s get
match s.paramMap.find? k with
match s.paramMap[k]? with
| some ps => pure ps
| none =>
match k with

View File

@@ -11,6 +11,7 @@ import Lean.Compiler.IR.Basic
import Lean.Compiler.IR.CompilerM
import Lean.Compiler.IR.FreeVars
import Lean.Compiler.IR.ElimDeadVars
import Lean.Data.AssocList
namespace Lean.IR.ExplicitBoxing
/-!

View File

@@ -152,7 +152,7 @@ def getFunctionSummary? (env : Environment) (fid : FunId) : Option Value :=
| some modIdx => findAtSorted? (functionSummariesExt.getModuleEntries env modIdx) fid
| none => functionSummariesExt.getState env |>.find? fid
abbrev Assignment := HashMap VarId Value
abbrev Assignment := Std.HashMap VarId Value
structure InterpContext where
currFnIdx : Nat := 0
@@ -172,7 +172,7 @@ def findVarValue (x : VarId) : M Value := do
let ctx read
let s get
let assignment := s.assignments[ctx.currFnIdx]!
return assignment.findD x bot
return assignment.getD x bot
def findArgValue (arg : Arg) : M Value :=
match arg with
@@ -303,7 +303,7 @@ partial def elimDeadAux (assignment : Assignment) : FnBody → FnBody
| FnBody.vdecl x t e b => FnBody.vdecl x t e (elimDeadAux assignment b)
| FnBody.jdecl j ys v b => FnBody.jdecl j ys (elimDeadAux assignment v) (elimDeadAux assignment b)
| FnBody.case tid x xType alts =>
let v := assignment.findD x bot
let v := assignment.getD x bot
let alts := alts.map fun alt =>
match alt with
| Alt.ctor i b => Alt.ctor i <| if containsCtor v i then elimDeadAux assignment b else FnBody.unreachable

View File

@@ -252,7 +252,7 @@ def throwUnknownVar {α : Type} (x : VarId) : M α :=
def getJPParams (j : JoinPointId) : M (Array Param) := do
let ctx read;
match ctx.jpMap.find? j with
match ctx.jpMap[j]? with
| some ps => pure ps
| none => throw "unknown join point"

View File

@@ -4,7 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Siddharth Bhat
-/
prelude
import Lean.Data.HashMap
import Lean.Runtime
import Lean.Compiler.NameMangling
import Lean.Compiler.ExportAttr
@@ -65,8 +64,8 @@ structure Context (llvmctx : LLVM.Context) where
llvmmodule : LLVM.Module llvmctx
structure State (llvmctx : LLVM.Context) where
var2val : HashMap VarId (LLVM.LLVMType llvmctx × LLVM.Value llvmctx)
jp2bb : HashMap JoinPointId (LLVM.BasicBlock llvmctx)
var2val : Std.HashMap VarId (LLVM.LLVMType llvmctx × LLVM.Value llvmctx)
jp2bb : Std.HashMap JoinPointId (LLVM.BasicBlock llvmctx)
abbrev Error := String
@@ -84,7 +83,7 @@ def addJpTostate (jp : JoinPointId) (bb : LLVM.BasicBlock llvmctx) : M llvmctx U
def emitJp (jp : JoinPointId) : M llvmctx (LLVM.BasicBlock llvmctx) := do
let state get
match state.jp2bb.find? jp with
match state.jp2bb[jp]? with
| .some bb => return bb
| .none => throw s!"unable to find join point {jp}"
@@ -531,7 +530,7 @@ def emitFnDecls : M llvmctx Unit := do
def emitLhsSlot_ (x : VarId) : M llvmctx (LLVM.LLVMType llvmctx × LLVM.Value llvmctx) := do
let state get
match state.var2val.find? x with
match state.var2val[x]? with
| .some v => return v
| .none => throw s!"unable to find variable {x}"
@@ -1029,7 +1028,7 @@ def emitTailCall (builder : LLVM.Builder llvmctx) (f : FunId) (v : Expr) : M llv
def emitJmp (builder : LLVM.Builder llvmctx) (jp : JoinPointId) (xs : Array Arg) : M llvmctx Unit := do
let llvmctx read
let ps match llvmctx.jpMap.find? jp with
let ps match llvmctx.jpMap[jp]? with
| some ps => pure ps
| none => throw s!"Unknown join point {jp}"
unless xs.size == ps.size do throw s!"Invalid goto, mismatched sizes between arguments, formal parameters."

View File

@@ -51,8 +51,8 @@ end CollectUsedDecls
def collectUsedDecls (env : Environment) (decl : Decl) (used : NameSet := {}) : NameSet :=
(CollectUsedDecls.collectDecl decl env).run' used
abbrev VarTypeMap := HashMap VarId IRType
abbrev JPParamsMap := HashMap JoinPointId (Array Param)
abbrev VarTypeMap := Std.HashMap VarId IRType
abbrev JPParamsMap := Std.HashMap JoinPointId (Array Param)
namespace CollectMaps
abbrev Collector := (VarTypeMap × JPParamsMap) (VarTypeMap × JPParamsMap)

View File

@@ -10,7 +10,7 @@ import Lean.Compiler.IR.FreeVars
namespace Lean.IR.ExpandResetReuse
/-- Mapping from variable to projections -/
abbrev ProjMap := HashMap VarId Expr
abbrev ProjMap := Std.HashMap VarId Expr
namespace CollectProjMap
abbrev Collector := ProjMap ProjMap
@[inline] def collectVDecl (x : VarId) (v : Expr) : Collector := fun m =>
@@ -148,20 +148,20 @@ def setFields (y : VarId) (zs : Array Arg) (b : FnBody) : FnBody :=
def isSelfSet (ctx : Context) (x : VarId) (i : Nat) (y : Arg) : Bool :=
match y with
| Arg.var y =>
match ctx.projMap.find? y with
match ctx.projMap[y]? with
| some (Expr.proj j w) => j == i && w == x
| _ => false
| _ => false
/-- Given `uset x[i] := y`, return true iff `y := uproj[i] x` -/
def isSelfUSet (ctx : Context) (x : VarId) (i : Nat) (y : VarId) : Bool :=
match ctx.projMap.find? y with
match ctx.projMap[y]? with
| some (Expr.uproj j w) => j == i && w == x
| _ => false
/-- Given `sset x[n, i] := y`, return true iff `y := sproj[n, i] x` -/
def isSelfSSet (ctx : Context) (x : VarId) (n : Nat) (i : Nat) (y : VarId) : Bool :=
match ctx.projMap.find? y with
match ctx.projMap[y]? with
| some (Expr.sproj m j w) => n == m && j == i && w == x
| _ => false

View File

@@ -64,34 +64,34 @@ instance : AddMessageContext CompilerM where
def getType (fvarId : FVarId) : CompilerM Expr := do
let lctx := ( get).lctx
if let some decl := lctx.letDecls.find? fvarId then
if let some decl := lctx.letDecls[fvarId]? then
return decl.type
else if let some decl := lctx.params.find? fvarId then
else if let some decl := lctx.params[fvarId]? then
return decl.type
else if let some decl := lctx.funDecls.find? fvarId then
else if let some decl := lctx.funDecls[fvarId]? then
return decl.type
else
throwError "unknown free variable {fvarId.name}"
def getBinderName (fvarId : FVarId) : CompilerM Name := do
let lctx := ( get).lctx
if let some decl := lctx.letDecls.find? fvarId then
if let some decl := lctx.letDecls[fvarId]? then
return decl.binderName
else if let some decl := lctx.params.find? fvarId then
else if let some decl := lctx.params[fvarId]? then
return decl.binderName
else if let some decl := lctx.funDecls.find? fvarId then
else if let some decl := lctx.funDecls[fvarId]? then
return decl.binderName
else
throwError "unknown free variable {fvarId.name}"
def findParam? (fvarId : FVarId) : CompilerM (Option Param) :=
return ( get).lctx.params.find? fvarId
return ( get).lctx.params[fvarId]?
def findLetDecl? (fvarId : FVarId) : CompilerM (Option LetDecl) :=
return ( get).lctx.letDecls.find? fvarId
return ( get).lctx.letDecls[fvarId]?
def findFunDecl? (fvarId : FVarId) : CompilerM (Option FunDecl) :=
return ( get).lctx.funDecls.find? fvarId
return ( get).lctx.funDecls[fvarId]?
def findLetValue? (fvarId : FVarId) : CompilerM (Option LetValue) := do
let some { value, .. } findLetDecl? fvarId | return none
@@ -166,7 +166,7 @@ it is a free variable, a type (or type former), or `lcErased`.
`Check.lean` contains a substitution validator.
-/
abbrev FVarSubst := HashMap FVarId Expr
abbrev FVarSubst := Std.HashMap FVarId Expr
/--
Replace the free variables in `e` using the given substitution.
@@ -190,7 +190,7 @@ where
go (e : Expr) : Expr :=
if e.hasFVar then
match e with
| .fvar fvarId => match s.find? fvarId with
| .fvar fvarId => match s[fvarId]? with
| some e => if translator then e else go e
| none => e
| .lit .. | .const .. | .sort .. | .mvar .. | .bvar .. => e
@@ -224,7 +224,7 @@ That is, it is not a type (or type former), nor `lcErased`. Recall that a valid
expressions that are free variables, `lcErased`, or type formers.
-/
private partial def normFVarImp (s : FVarSubst) (fvarId : FVarId) (translator : Bool) : NormFVarResult :=
match s.find? fvarId with
match s[fvarId]? with
| some (.fvar fvarId') =>
if translator then
.fvar fvarId'
@@ -246,7 +246,7 @@ private partial def normArgImp (s : FVarSubst) (arg : Arg) (translator : Bool) :
match arg with
| .erased => arg
| .fvar fvarId =>
match s.find? fvarId with
match s[fvarId]? with
| some (.fvar fvarId') =>
let arg' := .fvar fvarId'
if translator then arg' else normArgImp s arg' translator

View File

@@ -268,7 +268,7 @@ def getFunctionSummary? (env : Environment) (fid : Name) : Option Value :=
A map from variable identifiers to the `Value` produced by the abstract
interpreter for them.
-/
abbrev Assignment := HashMap FVarId Value
abbrev Assignment := Std.HashMap FVarId Value
/--
The context of `InterpM`.
@@ -332,7 +332,7 @@ If none is available return `Value.bot`.
-/
def findVarValue (var : FVarId) : InterpM Value := do
let assignment getAssignment
return assignment.findD var .bot
return assignment.getD var .bot
/--
Find the value of `arg` using the logic of `findVarValue`.
@@ -547,13 +547,13 @@ where
| .jp decl k | .fun decl k =>
return code.updateFun! ( decl.updateValue ( go decl.value)) ( go k)
| .cases cs =>
let discrVal := assignment.findD cs.discr .bot
let discrVal := assignment.getD cs.discr .bot
let processAlt typ alt := do
match alt with
| .alt ctor args body =>
if discrVal.containsCtor ctor then
let filter param := do
if let some val := assignment.find? param.fvarId then
if let some val := assignment[param.fvarId]? then
if let some literal val.getLiteral then
return some (param, literal)
return none

View File

@@ -62,7 +62,7 @@ structure State where
Whenever there is function application `f a₁ ... aₙ`, where `f` is in `decls`, `f` is not `main`, and
we visit with the abstract values assigned to `aᵢ`, but first we record the visit here.
-/
visited : HashSet (Name × Array AbsValue) := {}
visited : Std.HashSet (Name × Array AbsValue) := {}
/--
Bitmask containing the result, i.e., which parameters of `main` are fixed.
We initialize it with `true` everywhere.

View File

@@ -59,7 +59,7 @@ structure FloatState where
/--
A map from identifiers of declarations to their current decision.
-/
decision : HashMap FVarId Decision
decision : Std.HashMap FVarId Decision
/--
A map from decisions (excluding `unknown`) to the declarations with
these decisions (in correct order). Basically:
@@ -67,7 +67,7 @@ structure FloatState where
- Which declarations do we move into a certain arm
- Which declarations do we move into the default arm
-/
newArms : HashMap Decision (List CodeDecl)
newArms : Std.HashMap Decision (List CodeDecl)
/--
Use to collect relevant declarations for the floating mechanism.
@@ -116,8 +116,8 @@ up to this point, with respect to `cs`. The initial decisions are:
- `arm` or `default` if we see the declaration only being used in exactly one cases arm
- `unknown` otherwise
-/
def initialDecisions (cs : Cases) : BaseFloatM (HashMap FVarId Decision) := do
let mut map := mkHashMap ( read).decls.length
def initialDecisions (cs : Cases) : BaseFloatM (Std.HashMap FVarId Decision) := do
let mut map := Std.HashMap.empty ( read).decls.length
let folder val acc := do
if let .let decl := val then
if ( ignore? decl) then
@@ -130,25 +130,25 @@ def initialDecisions (cs : Cases) : BaseFloatM (HashMap FVarId Decision) := do
(_, map) goCases cs |>.run map
return map
where
goFVar (plannedDecision : Decision) (var : FVarId) : StateRefT (HashMap FVarId Decision) BaseFloatM Unit := do
if let some decision := ( get).find? var then
goFVar (plannedDecision : Decision) (var : FVarId) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit := do
if let some decision := ( get)[var]? then
if decision == .unknown then
modify fun s => s.insert var plannedDecision
else if decision != plannedDecision then
modify fun s => s.insert var .dont
-- otherwise we already have the proper decision
goAlt (alt : Alt) : StateRefT (HashMap FVarId Decision) BaseFloatM Unit :=
goAlt (alt : Alt) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit :=
forFVarM (goFVar (.ofAlt alt)) alt
goCases (cs : Cases) : StateRefT (HashMap FVarId Decision) BaseFloatM Unit :=
goCases (cs : Cases) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit :=
cs.alts.forM goAlt
/--
Compute the initial new arms. This will just set up a map from all arms of
`cs` to empty `Array`s, plus one additional entry for `dont`.
-/
def initialNewArms (cs : Cases) : HashMap Decision (List CodeDecl) := Id.run do
let mut map := mkHashMap (cs.alts.size + 1)
def initialNewArms (cs : Cases) : Std.HashMap Decision (List CodeDecl) := Id.run do
let mut map := Std.HashMap.empty (cs.alts.size + 1)
map := map.insert .dont []
cs.alts.foldr (init := map) fun val acc => acc.insert (.ofAlt val) []
@@ -170,7 +170,7 @@ respectively but since `z` can't be moved we don't want that to move `x` and `y`
-/
def dontFloat (decl : CodeDecl) : FloatM Unit := do
forFVarM goFVar decl
modify fun s => { s with newArms := s.newArms.insert .dont (decl :: s.newArms.find! .dont) }
modify fun s => { s with newArms := s.newArms.insert .dont (decl :: s.newArms[Decision.dont]!) }
where
goFVar (fvar : FVarId) : FloatM Unit := do
if ( get).decision.contains fvar then
@@ -223,12 +223,12 @@ Will:
If we are at `y` `x` is still marked to be moved but we don't want that.
-/
def float (decl : CodeDecl) : FloatM Unit := do
let arm := ( get).decision.find! decl.fvarId
let arm := ( get).decision[decl.fvarId]!
forFVarM (goFVar · arm) decl
modify fun s => { s with newArms := s.newArms.insert arm (decl :: s.newArms.find! arm) }
modify fun s => { s with newArms := s.newArms.insert arm (decl :: s.newArms[arm]!) }
where
goFVar (fvar : FVarId) (arm : Decision) : FloatM Unit := do
let some decision := ( get).decision.find? fvar | return ()
let some decision := ( get).decision[fvar]? | return ()
if decision != arm then
modify fun s => { s with decision := s.decision.insert fvar .dont }
else if decision == .unknown then
@@ -249,7 +249,7 @@ where
-/
goCases : FloatM Unit := do
for decl in ( read).decls do
let currentDecision := ( get).decision.find! decl.fvarId
let currentDecision := ( get).decision[decl.fvarId]!
if currentDecision == .unknown then
/-
If the decision is still unknown by now this means `decl` is
@@ -284,10 +284,10 @@ where
newArms := initialNewArms cs
}
let (_, res) goCases |>.run base
let remainders := res.newArms.find! .dont
let remainders := res.newArms[Decision.dont]!
let altMapper alt := do
let decision := .ofAlt alt
let newCode := res.newArms.find! decision
let decision := Decision.ofAlt alt
let newCode := res.newArms[decision]!
trace[Compiler.floatLetIn] "Size of code that was pushed into arm: {repr decision} {newCode.length}"
let fused withNewScope do
go (attachCodeDecls newCode.toArray alt.getCode)

View File

@@ -29,7 +29,7 @@ structure CandidateInfo where
The set of candidates that rely on this candidate to be a join point.
For a more detailed explanation see the documentation of `find`
-/
associated : HashSet FVarId
associated : Std.HashSet FVarId
deriving Inhabited
/--
@@ -39,14 +39,14 @@ structure FindState where
/--
All current join point candidates accessible by their `FVarId`.
-/
candidates : HashMap FVarId CandidateInfo := .empty
candidates : Std.HashMap FVarId CandidateInfo := .empty
/--
The `FVarId`s of all `fun` declarations that were declared within the
current `fun`.
-/
scope : HashSet FVarId := .empty
scope : Std.HashSet FVarId := .empty
abbrev ReplaceCtx := HashMap FVarId Name
abbrev ReplaceCtx := Std.HashMap FVarId Name
abbrev FindM := ReaderT (Option FVarId) StateRefT FindState ScopeM
abbrev ReplaceM := ReaderT ReplaceCtx CompilerM
@@ -55,7 +55,7 @@ abbrev ReplaceM := ReaderT ReplaceCtx CompilerM
Attempt to find a join point candidate by its `FVarId`.
-/
private def findCandidate? (fvarId : FVarId) : FindM (Option CandidateInfo) := do
return ( get).candidates.find? fvarId
return ( get).candidates[fvarId]?
/--
Erase a join point candidate as well as all the ones that depend on it
@@ -69,7 +69,7 @@ private partial def eraseCandidate (fvarId : FVarId) : FindM Unit := do
/--
Combinator for modifying the candidates in `FindM`.
-/
private def modifyCandidates (f : HashMap FVarId CandidateInfo HashMap FVarId CandidateInfo) : FindM Unit :=
private def modifyCandidates (f : Std.HashMap FVarId CandidateInfo Std.HashMap FVarId CandidateInfo) : FindM Unit :=
modify (fun state => {state with candidates := f state.candidates })
/--
@@ -196,7 +196,7 @@ where
return code
| _, _ => return Code.updateLet! code decl ( go k)
| .fun decl k =>
if let some replacement := ( read).find? decl.fvarId then
if let some replacement := ( read)[decl.fvarId]? then
let newDecl := { decl with
binderName := replacement,
value := ( go decl.value)
@@ -244,7 +244,7 @@ structure ExtendState where
to `Param`s. The free variables in this map are the once that the context
of said join point will be extended by by passing in the respective parameter.
-/
fvarMap : HashMap FVarId (HashMap FVarId Param) := {}
fvarMap : Std.HashMap FVarId (Std.HashMap FVarId Param) := {}
/--
The monad for the `extendJoinPointContext` pass.
@@ -262,7 +262,7 @@ otherwise just return `fvar`.
def replaceFVar (fvar : FVarId) : ExtendM FVarId := do
if ( read).candidates.contains fvar then
if let some currentJp := ( read).currentJp? then
if let some replacement := ( get).fvarMap.find! currentJp |>.find? fvar then
if let some replacement := ( get).fvarMap[currentJp]![fvar]? then
return replacement.fvarId
return fvar
@@ -313,7 +313,7 @@ This is necessary if:
-/
def extendByIfNecessary (fvar : FVarId) : ExtendM Unit := do
if let some currentJp := ( read).currentJp? then
let mut translator := ( get).fvarMap.find! currentJp
let mut translator := ( get).fvarMap[currentJp]!
let candidates := ( read).candidates
if !( isInScope fvar) && !translator.contains fvar && candidates.contains fvar then
let typ getType fvar
@@ -337,7 +337,7 @@ of `j.2` in `j.1`.
-/
def mergeJpContextIfNecessary (jp : FVarId) : ExtendM Unit := do
if ( read).currentJp?.isSome then
let additionalArgs := ( get).fvarMap.find! jp |>.toArray
let additionalArgs := ( get).fvarMap[jp]!.toArray
for (fvar, _) in additionalArgs do
extendByIfNecessary fvar
@@ -405,7 +405,7 @@ where
| .jp decl k =>
let decl withNewJpScope decl do
let value go decl.value
let additionalParams := ( get).fvarMap.find! decl.fvarId |>.toArray |>.map Prod.snd
let additionalParams := ( get).fvarMap[decl.fvarId]!.toArray |>.map Prod.snd
let newType := additionalParams.foldr (init := decl.type) (fun val acc => .forallE val.binderName val.type acc .default)
decl.update newType (additionalParams ++ decl.params) value
mergeJpContextIfNecessary decl.fvarId
@@ -426,7 +426,7 @@ where
return Code.updateCases! code cs.resultType discr alts
| .jmp fn args =>
let mut newArgs args.mapM (mapFVarM goFVar)
let additionalArgs := ( get).fvarMap.find! fn |>.toArray |>.map Prod.fst
let additionalArgs := ( get).fvarMap[fn]!.toArray |>.map Prod.fst
if let some _currentJp := ( read).currentJp? then
let f := fun arg => do
return .fvar ( goFVar arg)
@@ -545,7 +545,7 @@ where
if let some knownArgs := ( get).jpJmpArgs.find? fn then
let mut newArgs := knownArgs
for (param, arg) in decl.params.zip args do
if let some knownVal := newArgs.find? param.fvarId then
if let some knownVal := newArgs[param.fvarId]? then
if arg.toExpr != knownVal then
newArgs := newArgs.erase param.fvarId
modify fun s => { s with jpJmpArgs := s.jpJmpArgs.insert fn newArgs }

View File

@@ -13,9 +13,9 @@ namespace Lean.Compiler.LCNF
LCNF local context.
-/
structure LCtx where
params : HashMap FVarId Param := {}
letDecls : HashMap FVarId LetDecl := {}
funDecls : HashMap FVarId FunDecl := {}
params : Std.HashMap FVarId Param := {}
letDecls : Std.HashMap FVarId LetDecl := {}
funDecls : Std.HashMap FVarId FunDecl := {}
deriving Inhabited
def LCtx.addParam (lctx : LCtx) (param : Param) : LCtx :=

View File

@@ -30,7 +30,7 @@ structure State where
/-- Counter for generating new (normalized) universe parameter names. -/
nextIdx : Nat := 1
/-- Mapping from existing universe parameter names to the new ones. -/
map : HashMap Name Level := {}
map : Std.HashMap Name Level := {}
/-- Parameters that have been normalized. -/
paramNames : Array Name := #[]
@@ -49,7 +49,7 @@ partial def normLevel (u : Level) : M Level := do
| .max v w => return u.updateMax! ( normLevel v) ( normLevel w)
| .imax v w => return u.updateIMax! ( normLevel v) ( normLevel w)
| .mvar _ => unreachable!
| .param n => match ( get).map.find? n with
| .param n => match ( get).map[n]? with
| some u => return u
| none =>
let u := Level.param <| (`u).appendIndexAfter ( get).nextIdx

View File

@@ -31,9 +31,9 @@ def sortedBySize : Probe Decl (Nat × Decl) := fun decls =>
if sz₁ == sz₂ then Name.lt decl₁.name decl₂.name else sz₁ < sz₂
def countUnique [ToString α] [BEq α] [Hashable α] : Probe α (α × Nat) := fun data => do
let mut map := HashMap.empty
let mut map := Std.HashMap.empty
for d in data do
if let some count := map.find? d then
if let some count := map[d]? then
map := map.insert d (count + 1)
else
map := map.insert d 1

View File

@@ -40,7 +40,7 @@ structure FunDeclInfoMap where
/--
Mapping from local function name to inlining information.
-/
map : HashMap FVarId FunDeclInfo := {}
map : Std.HashMap FVarId FunDeclInfo := {}
deriving Inhabited
def FunDeclInfoMap.format (s : FunDeclInfoMap) : CompilerM Format := do
@@ -56,7 +56,7 @@ Add new occurrence for the local function with binder name `key`.
def FunDeclInfoMap.add (s : FunDeclInfoMap) (fvarId : FVarId) : FunDeclInfoMap :=
match s with
| { map } =>
match map.find? fvarId with
match map[fvarId]? with
| some .once => { map := map.insert fvarId .many }
| none => { map := map.insert fvarId .once }
| _ => { map }
@@ -67,7 +67,7 @@ Add new occurrence for the local function occurring as an argument for another f
def FunDeclInfoMap.addHo (s : FunDeclInfoMap) (fvarId : FVarId) : FunDeclInfoMap :=
match s with
| { map } =>
match map.find? fvarId with
match map[fvarId]? with
| some .once | none => { map := map.insert fvarId .many }
| _ => { map }

View File

@@ -173,7 +173,7 @@ Execute `x` with `fvarId` set as `mustInline`.
After execution the original setting is restored.
-/
def withAddMustInline (fvarId : FVarId) (x : SimpM α) : SimpM α := do
let saved? := ( get).funDeclInfoMap.map.find? fvarId
let saved? := ( get).funDeclInfoMap.map[fvarId]?
try
addMustInline fvarId
x
@@ -185,7 +185,7 @@ Return true if the given local function declaration or join point id is marked a
`once` or `mustInline`. We use this information to decide whether to inline them.
-/
def isOnceOrMustInline (fvarId : FVarId) : SimpM Bool := do
match ( get).funDeclInfoMap.map.find? fvarId with
match ( get).funDeclInfoMap.map[fvarId]? with
| some .once | some .mustInline => return true
| _ => return false

View File

@@ -199,9 +199,9 @@ structure State where
/-- Cache from Lean regular expression to LCNF argument. -/
cache : PHashMap Expr Arg := {}
/-- `toLCNFType` cache -/
typeCache : HashMap Expr Expr := {}
typeCache : Std.HashMap Expr Expr := {}
/-- isTypeFormerType cache -/
isTypeFormerTypeCache : HashMap Expr Bool := {}
isTypeFormerTypeCache : Std.HashMap Expr Bool := {}
/-- LCNF sequence, we chain it to create a LCNF `Code` object. -/
seq : Array Element := #[]
/--
@@ -257,7 +257,7 @@ private partial def isTypeFormerType (type : Expr) : M Bool := do
| .true => return true
| .false => return false
| .undef =>
if let some result := ( get).isTypeFormerTypeCache.find? type then
if let some result := ( get).isTypeFormerTypeCache[type]? then
return result
let result liftMetaM <| Meta.isTypeFormerType type
modify fun s => { s with isTypeFormerTypeCache := s.isTypeFormerTypeCache.insert type result }
@@ -305,7 +305,7 @@ def applyToAny (type : Expr) : M Expr := do
| _ => none
def toLCNFType (type : Expr) : M Expr := do
match ( get).typeCache.find? type with
match ( get).typeCache[type]? with
| some type' => return type'
| none =>
let type' liftMetaM <| LCNF.toLCNFType type

View File

@@ -270,16 +270,3 @@ def ofListWith (l : List (α × β)) (f : β → β → β) : HashMap α β :=
| some v => m.insert p.fst $ f v p.snd)
end Lean.HashMap
/--
Groups all elements `x`, `y` in `xs` with `key x == key y` into the same array
`(xs.groupByKey key).find! (key x)`. Groups preserve the relative order of elements in `xs`.
-/
def Array.groupByKey [BEq α] [Hashable α] (key : β α) (xs : Array β)
: Lean.HashMap α (Array β) := Id.run do
let mut groups :=
for x in xs do
let group := groups.findD (key x) #[]
groups := groups.erase (key x) -- make `group` referentially unique
groups := groups.insert (key x) (group.push x)
return groups

View File

@@ -150,7 +150,7 @@ instance : FromJson RefInfo where
pure { definition?, usages }
/-- References from a single module/file -/
def ModuleRefs := HashMap RefIdent RefInfo
def ModuleRefs := Std.HashMap RefIdent RefInfo
instance : ToJson ModuleRefs where
toJson m := Json.mkObj <| m.toList.map fun (ident, info) => (ident.toJson.compress, toJson info)
@@ -158,7 +158,7 @@ instance : ToJson ModuleRefs where
instance : FromJson ModuleRefs where
fromJson? j := do
let node j.getObj?
node.foldM (init := HashMap.empty) fun m k v =>
node.foldM (init := Std.HashMap.empty) fun m k v =>
return m.insert ( RefIdent.fromJson? ( Json.parse k)) ( fromJson? v)
/--

View File

@@ -4,7 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
-/
prelude
import Lean.Data.HashSet
import Std.Data.HashSet.Basic
import Lean.Data.RBMap
import Lean.Data.RBTree
import Lean.Data.SSet
@@ -64,14 +64,14 @@ abbrev insert (s : NameSSet) (n : Name) : NameSSet := SSet.insert s n
abbrev contains (s : NameSSet) (n : Name) : Bool := SSet.contains s n
end NameSSet
def NameHashSet := HashSet Name
def NameHashSet := Std.HashSet Name
namespace NameHashSet
@[inline] def empty : NameHashSet := HashSet.empty
@[inline] def empty : NameHashSet := Std.HashSet.empty
instance : EmptyCollection NameHashSet := empty
instance : Inhabited NameHashSet := {}
def insert (s : NameHashSet) (n : Name) := HashSet.insert s n
def contains (s : NameHashSet) (n : Name) : Bool := HashSet.contains s n
def insert (s : NameHashSet) (n : Name) := Std.HashSet.insert s n
def contains (s : NameHashSet) (n : Name) : Bool := Std.HashSet.contains s n
end NameHashSet
def MacroScopesView.isPrefixOf (v₁ v₂ : MacroScopesView) : Bool :=

View File

@@ -4,7 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Data.HashMap
import Std.Data.HashMap.Basic
import Lean.Data.PersistentHashMap
universe u v w w'
@@ -28,7 +28,7 @@ namespace Lean
-/
structure SMap (α : Type u) (β : Type v) [BEq α] [Hashable α] where
stage₁ : Bool := true
map₁ : HashMap α β := {}
map₁ : Std.HashMap α β := {}
map₂ : PHashMap α β := {}
namespace SMap
@@ -37,7 +37,7 @@ variable {α : Type u} {β : Type v} [BEq α] [Hashable α]
instance : Inhabited (SMap α β) := {}
def empty : SMap α β := {}
@[inline] def fromHashMap (m : HashMap α β) (stage₁ := true) : SMap α β :=
@[inline] def fromHashMap (m : Std.HashMap α β) (stage₁ := true) : SMap α β :=
{ map₁ := m, stage₁ := stage₁ }
@[specialize] def insert : SMap α β α β SMap α β
@@ -49,8 +49,8 @@ def empty : SMap α β := {}
| false, m₁, m₂, k, v => false, m₁, m₂.insert k v
@[specialize] def find? : SMap α β α Option β
| true, m₁, _, k => m₁.find? k
| false, m₁, m₂, k => (m₂.find? k).orElse fun _ => m₁.find? k
| true, m₁, _, k => m₁[k]?
| false, m₁, m₂, k => (m₂.find? k).orElse fun _ => m₁[k]?
@[inline] def findD (m : SMap α β) (a : α) (b₀ : β) : β :=
(m.find? a).getD b₀
@@ -67,8 +67,8 @@ def empty : SMap α β := {}
/-- Similar to `find?`, but searches for result in the hashmap first.
So, the result is correct only if we never "overwrite" `map₁` entries using `map₂`. -/
@[specialize] def find?' : SMap α β α Option β
| true, m₁, _, k => m₁.find? k
| false, m₁, m₂, k => (m₁.find? k).orElse fun _ => m₂.find? k
| true, m₁, _, k => m₁[k]?
| false, m₁, m₂, k => m₁[k]?.orElse fun _ => m₂.find? k
def forM [Monad m] (s : SMap α β) (f : α β m PUnit) : m PUnit := do
s.map₁.forM f
@@ -96,7 +96,7 @@ def fold {σ : Type w} (f : σα → β → σ) (init : σ) (m : SMap α β
m.map₂.foldl f $ m.map₁.fold f init
def numBuckets (m : SMap α β) : Nat :=
m.map₁.numBuckets
Std.HashMap.Internal.numBuckets m.map₁
def toList (m : SMap α β) : List (α × β) :=
m.fold (init := []) fun es a b => (a, b)::es

View File

@@ -201,12 +201,12 @@ def mkMessageAux (ctx : Context) (ref : Syntax) (msgData : MessageData) (severit
private def addTraceAsMessagesCore (ctx : Context) (log : MessageLog) (traceState : TraceState) : MessageLog := Id.run do
if traceState.traces.isEmpty then return log
let mut traces : HashMap (String.Pos × String.Pos) (Array MessageData) :=
let mut traces : Std.HashMap (String.Pos × String.Pos) (Array MessageData) :=
for traceElem in traceState.traces do
let ref := replaceRef traceElem.ref ctx.ref
let pos := ref.getPos?.getD 0
let endPos := ref.getTailPos?.getD pos
traces := traces.insert (pos, endPos) <| traces.findD (pos, endPos) #[] |>.push traceElem.msg
traces := traces.insert (pos, endPos) <| traces.getD (pos, endPos) #[] |>.push traceElem.msg
let mut log := log
let traces' := traces.toArray.qsort fun ((a, _), _) ((b, _), _) => a < b
for ((pos, endPos), traceMsg) in traces' do

View File

@@ -630,7 +630,7 @@ private def replaceIndFVarsWithConsts (views : Array InductiveView) (indFVars :
let type := type.replace fun e =>
if !e.isFVar then
none
else match indFVar2Const.find? e with
else match indFVar2Const[e]? with
| none => none
| some c => mkAppN c (params.extract 0 numVars)
instantiateMVars ( mkForallFVars params type)

View File

@@ -425,7 +425,7 @@ private def applyRefMap (e : Expr) (map : ExprMap Expr) : Expr :=
e.replace fun e =>
match patternWithRef? e with
| some _ => some e -- stop `e` already has annotation
| none => match map.find? e with
| none => match map[e]? with
| some eWithRef => some eWithRef -- stop `e` found annotation
| none => none -- continue

View File

@@ -164,8 +164,11 @@ def addNonRec (preDef : PreDefinition) (applyAttrAfterCompilation := true) (all
/--
Eliminate recursive application annotations containing syntax. These annotations are used by the well-founded recursion module
to produce better error messages. -/
def eraseRecAppSyntaxExpr (e : Expr) : CoreM Expr :=
Core.transform e (post := fun e => pure <| TransformStep.done <| if (getRecAppSyntax? e).isSome then e.mdataExpr! else e)
def eraseRecAppSyntaxExpr (e : Expr) : CoreM Expr := do
if e.find? hasRecAppSyntax |>.isSome then
Core.transform e (post := fun e => pure <| TransformStep.done <| if hasRecAppSyntax e then e.mdataExpr! else e)
else
return e
def eraseRecAppSyntax (preDef : PreDefinition) : CoreM PreDefinition :=
return { preDef with value := ( eraseRecAppSyntaxExpr preDef.value) }

View File

@@ -69,12 +69,15 @@ private def ensureNoUnassignedMVarsAtPreDef (preDef : PreDefinition) : TermElabM
This method beta-reduces them to make sure they can be eliminated by the well-founded recursion module. -/
private def betaReduceLetRecApps (preDefs : Array PreDefinition) : MetaM (Array PreDefinition) :=
preDefs.mapM fun preDef => do
let value Core.transform preDef.value fun e => do
if e.isApp && e.getAppFn.isLambda && e.getAppArgs.all fun arg => arg.getAppFn.isConst && preDefs.any fun preDef => preDef.declName == arg.getAppFn.constName! then
return .visit e.headBeta
else
return .continue
return { preDef with value }
if preDef.value.find? (fun e => e.isConst && preDefs.any fun preDef => preDef.declName == e.constName!) |>.isSome then
let value Core.transform preDef.value fun e => do
if e.isApp && e.getAppFn.isLambda && e.getAppArgs.all fun arg => arg.getAppFn.isConst && preDefs.any fun preDef => preDef.declName == arg.getAppFn.constName! then
return .visit e.headBeta
else
return .continue
return { preDef with value }
else
return preDef
private def addAsAxioms (preDefs : Array PreDefinition) : TermElabM Unit := do
for preDef in preDefs do

View File

@@ -146,7 +146,7 @@ See issue #837 for an example where we can show termination using the index of a
we don't get the desired definitional equalities.
-/
def nonIndicesFirst (recArgInfos : Array RecArgInfo) : Array RecArgInfo := Id.run do
let mut indicesPos : HashSet Nat := {}
let mut indicesPos : Std.HashSet Nat := {}
for recArgInfo in recArgInfos do
for pos in recArgInfo.indicesPos do
indicesPos := indicesPos.insert pos

View File

@@ -596,7 +596,7 @@ private partial def compileStxMatch (discrs : List Term) (alts : List Alt) : Ter
`(have __discr := $discr; $stx)
| _, _ => unreachable!
abbrev IdxSet := HashSet Nat
abbrev IdxSet := Std.HashSet Nat
private partial def hasNoErrorIfUnused : Syntax Bool
| `(no_error_if_unused% $_) => true

View File

@@ -11,27 +11,35 @@ namespace Lean
private def recAppKey := `_recApp
/--
We store the syntax at recursive applications to be able to generate better error messages
when performing well-founded and structural recursion.
We store the syntax at recursive applications to be able to generate better error messages
when performing well-founded and structural recursion.
-/
def mkRecAppWithSyntax (e : Expr) (stx : Syntax) : Expr :=
mkMData (KVMap.empty.insert recAppKey (DataValue.ofSyntax stx)) e
mkMData (KVMap.empty.insert recAppKey (.ofSyntax stx)) e
/--
Retrieve (if available) the syntax object attached to a recursive application.
Retrieve (if available) the syntax object attached to a recursive application.
-/
def getRecAppSyntax? (e : Expr) : Option Syntax :=
match e with
| Expr.mdata d _ =>
| .mdata d _ =>
match d.find recAppKey with
| some (DataValue.ofSyntax stx) => some stx
| _ => none
| _ => none
/--
Checks if the `MData` is for a recursive applciation.
Checks if the `MData` is for a recursive applciation.
-/
def MData.isRecApp (d : MData) : Bool :=
d.contains recAppKey
/--
Return `true` if `getRecAppSyntax? e` is a `some`.
-/
def hasRecAppSyntax (e : Expr) : Bool :=
match e with
| .mdata d _ => d.isRecApp
| _ => false
end Lean

View File

@@ -445,13 +445,13 @@ private def expandParentFields (s : Struct) : TermElabM Struct := do
| _ => throwErrorAt ref "failed to access field '{fieldName}' in parent structure"
| _ => return field
private abbrev FieldMap := HashMap Name Fields
private abbrev FieldMap := Std.HashMap Name Fields
private def mkFieldMap (fields : Fields) : TermElabM FieldMap :=
fields.foldlM (init := {}) fun fieldMap field =>
match field.lhs with
| .fieldName _ fieldName :: _ =>
match fieldMap.find? fieldName with
match fieldMap[fieldName]? with
| some (prevField::restFields) =>
if field.isSimple || prevField.isSimple then
throwErrorAt field.ref "field '{fieldName}' has already been specified"

View File

@@ -246,7 +246,7 @@ private def getSomeSyntheticMVarsRef : TermElabM Syntax := do
private def throwStuckAtUniverseCnstr : TermElabM Unit := do
-- This code assumes `entries` is not empty. Note that `processPostponed` uses `exceptionOnFailure` to guarantee this property
let entries getPostponed
let mut found : HashSet (Level × Level) := {}
let mut found : Std.HashSet (Level × Level) := {}
let mut uniqueEntries := #[]
for entry in entries do
let mut lhs := entry.lhs

View File

@@ -8,8 +8,6 @@ import Init.Omega.Constraint
import Lean.Elab.Tactic.Omega.OmegaM
import Lean.Elab.Tactic.Omega.MinNatAbs
open Lean (HashMap HashSet)
namespace Lean.Elab.Tactic.Omega
initialize Lean.registerTraceClass `omega
@@ -167,11 +165,11 @@ structure Problem where
/-- The number of variables in the problem. -/
numVars : Nat := 0
/-- The current constraints, indexed by their coefficients. -/
constraints : HashMap Coeffs Fact :=
constraints : Std.HashMap Coeffs Fact :=
/--
The coefficients for which `constraints` contains an exact constraint (i.e. an equality).
-/
equalities : HashSet Coeffs :=
equalities : Std.HashSet Coeffs :=
/--
Equations that have already been used to eliminate variables,
along with the variable which was removed, and its coefficient (either `1` or `-1`).
@@ -251,7 +249,7 @@ combining it with any existing constraints for the same coefficients.
def addConstraint (p : Problem) : Fact Problem
| f@x, s, j =>
if p.possible then
match p.constraints.find? x with
match p.constraints[x]? with
| none =>
match s with
| .trivial => p
@@ -313,7 +311,7 @@ After solving, the variable will have been eliminated from all constraints.
def solveEasyEquality (p : Problem) (c : Coeffs) : Problem :=
let i := c.findIdx? (·.natAbs = 1) |>.getD 0 -- findIdx? is always some
let sign := c.get i |> Int.sign
match p.constraints.find? c with
match p.constraints[c]? with
| some f =>
let init :=
{ assumptions := p.assumptions
@@ -335,7 +333,7 @@ After solving the easy equality,
the minimum lexicographic value of `(c.minNatAbs, c.maxNatAbs)` will have been reduced.
-/
def dealWithHardEquality (p : Problem) (c : Coeffs) : OmegaM Problem :=
match p.constraints.find? c with
match p.constraints[c]? with
| some _, some r, some r', j => do
let m := c.minNatAbs + 1
-- We have to store the valid value of the newly introduced variable in the atoms.
@@ -479,7 +477,7 @@ def fourierMotzkinData (p : Problem) : Array FourierMotzkinData := Id.run do
let n := p.numVars
let mut data : Array FourierMotzkinData :=
(List.range p.numVars).foldl (fun a i => a.push { var := i}) #[]
for (_, f@xs, s, _) in p.constraints.toList do -- We could make a forIn instance for HashMap
for (_, f@xs, s, _) in p.constraints do
for i in [0:n] do
let x := Coeffs.get xs i
data := data.modify i fun d =>

View File

@@ -58,7 +58,7 @@ structure MetaProblem where
-/
disjunctions : List Expr := []
/-- Facts which have already been processed; we keep these to avoid duplicates. -/
processedFacts : HashSet Expr :=
processedFacts : Std.HashSet Expr :=
/-- Construct the `rfl` proof that `lc.eval atoms = e`. -/
def mkEvalRflProof (e : Expr) (lc : LinearCombo) : OmegaM Expr := do
@@ -80,7 +80,7 @@ def mkCoordinateEvalAtomsEq (e : Expr) (n : Nat) : OmegaM Expr := do
mkEqTrans eq ( mkEqSymm (mkApp2 (.const ``LinearCombo.coordinate_eval []) n atoms))
/-- Construct the linear combination (and its associated proof and new facts) for an atom. -/
def mkAtomLinearCombo (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
def mkAtomLinearCombo (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
let (n, facts) lookup e
return LinearCombo.coordinate n, mkCoordinateEvalAtomsEq e n, facts.getD
@@ -94,9 +94,9 @@ Gives a small (10%) speedup in testing.
I tried using a pointer based cache,
but there was never enough subexpression sharing to make it effective.
-/
partial def asLinearCombo (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
partial def asLinearCombo (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
let cache get
match cache.find? e with
match cache.get? e with
| some (lc, prf) =>
trace[omega] "Found in cache: {e}"
return (lc, prf, )
@@ -120,7 +120,7 @@ We also transform the expression as we descend into it:
* pushing coercions: `↑(x + y)`, `↑(x * y)`, `↑(x / k)`, `↑(x % k)`, `↑k`
* unfolding `emod`: `x % k` → `x - x / k`
-/
partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
trace[omega] "processing {e}"
match groundInt? e with
| some i =>
@@ -142,7 +142,7 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
mkEqTrans
( mkAppM ``Int.add_congr #[ prf₁, prf₂])
( mkEqSymm add_eval)
pure (l₁ + l₂, prf, facts₁.merge facts₂)
pure (l₁ + l₂, prf, facts₁.union facts₂)
| (``HSub.hSub, #[_, _, _, _, e₁, e₂]) => do
let (l₁, prf₁, facts₁) asLinearCombo e₁
let (l₂, prf₂, facts₂) asLinearCombo e₂
@@ -152,7 +152,7 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
mkEqTrans
( mkAppM ``Int.sub_congr #[ prf₁, prf₂])
( mkEqSymm sub_eval)
pure (l₁ - l₂, prf, facts₁.merge facts₂)
pure (l₁ - l₂, prf, facts₁.union facts₂)
| (``Neg.neg, #[_, _, e']) => do
let (l, prf, facts) asLinearCombo e'
let prf' : OmegaM Expr := do
@@ -178,7 +178,7 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
mkEqTrans
( mkAppM ``Int.mul_congr #[ xprf, yprf])
( mkEqSymm mul_eval)
pure (some (LinearCombo.mul xl yl, prf, xfacts.merge yfacts), true)
pure (some (LinearCombo.mul xl yl, prf, xfacts.union yfacts), true)
else
pure (none, false)
match r? with
@@ -235,7 +235,7 @@ where
Apply a rewrite rule to an expression, and interpret the result as a `LinearCombo`.
(We're not rewriting any subexpressions here, just the top level, for efficiency.)
-/
rewrite (lhs rw : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
rewrite (lhs rw : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
trace[omega] "rewriting {lhs} via {rw} : {← inferType rw}"
match ( inferType rw).eq? with
| some (_, _lhs', rhs) =>
@@ -243,7 +243,7 @@ where
let prf' : OmegaM Expr := do mkEqTrans rw ( prf)
pure (lc, prf', facts)
| none => panic! "Invalid rewrite rule in 'asLinearCombo'"
handleNatCast (e i n : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
handleNatCast (e i n : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
match n with
| .fvar h =>
if let some v h.getValue? then
@@ -296,7 +296,7 @@ where
| (``Fin.val, #[n, x]) =>
handleFinVal e i n x
| _ => mkAtomLinearCombo e
handleFinVal (e i n x : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
handleFinVal (e i n x : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
match x with
| .fvar h =>
if let some v h.getValue? then
@@ -342,7 +342,7 @@ We solve equalities as they are discovered, as this often results in an earlier
-/
def addIntEquality (p : MetaProblem) (h x : Expr) : OmegaM MetaProblem := do
let (lc, prf, facts) asLinearCombo x
let newFacts : HashSet Expr := facts.fold (init := ) fun s e =>
let newFacts : Std.HashSet Expr := facts.fold (init := ) fun s e =>
if p.processedFacts.contains e then s else s.insert e
trace[omega] "Adding proof of {lc} = 0"
pure <|
@@ -358,7 +358,7 @@ We solve equalities as they are discovered, as this often results in an earlier
-/
def addIntInequality (p : MetaProblem) (h y : Expr) : OmegaM MetaProblem := do
let (lc, prf, facts) asLinearCombo y
let newFacts : HashSet Expr := facts.fold (init := ) fun s e =>
let newFacts : Std.HashSet Expr := facts.fold (init := ) fun s e =>
if p.processedFacts.contains e then s else s.insert e
trace[omega] "Adding proof of {lc} ≥ 0"
pure <|
@@ -590,7 +590,7 @@ where
-- We sort the constraints; otherwise the order is dependent on details of the hashing
-- and this can cause test suite output churn
prettyConstraints (names : Array String) (constraints : HashMap Coeffs Fact) : String :=
prettyConstraints (names : Array String) (constraints : Std.HashMap Coeffs Fact) : String :=
constraints.toList
|>.toArray
|>.qsort (·.1 < ·.1)
@@ -615,7 +615,7 @@ where
(if Int.natAbs c = 1 then names[i]! else s!"{c.natAbs}*{names[i]!}"))
|> String.join
mentioned (atoms : Array Expr) (constraints : HashMap Coeffs Fact) : MetaM (Array Bool) := do
mentioned (atoms : Array Expr) (constraints : Std.HashMap Coeffs Fact) : MetaM (Array Bool) := do
let initMask := Array.mkArray atoms.size false
return constraints.fold (init := initMask) fun mask coeffs _ =>
coeffs.enum.foldl (init := mask) fun mask (i, c) =>

View File

@@ -10,6 +10,8 @@ import Init.Omega.Logic
import Init.Data.BitVec.Basic
import Lean.Meta.AppBuilder
import Lean.Meta.Canonicalizer
import Std.Data.HashMap.Basic
import Std.Data.HashSet.Basic
/-!
# The `OmegaM` state monad.
@@ -52,7 +54,7 @@ structure Context where
/-- The internal state for the `OmegaM` monad, recording previously encountered atoms. -/
structure State where
/-- The atoms up-to-defeq encountered so far. -/
atoms : HashMap Expr Nat := {}
atoms : Std.HashMap Expr Nat := {}
/-- An intermediate layer in the `OmegaM` monad. -/
abbrev OmegaM' := StateRefT State (ReaderT Context CanonM)
@@ -60,7 +62,7 @@ abbrev OmegaM' := StateRefT State (ReaderT Context CanonM)
/--
Cache of expressions that have been visited, and their reflection as a linear combination.
-/
def Cache : Type := HashMap Expr (LinearCombo × OmegaM' Expr)
def Cache : Type := Std.HashMap Expr (LinearCombo × OmegaM' Expr)
/--
The `OmegaM` monad maintains two pieces of state:
@@ -71,7 +73,7 @@ abbrev OmegaM := StateRefT Cache OmegaM'
/-- Run a computation in the `OmegaM` monad, starting with no recorded atoms. -/
def OmegaM.run (m : OmegaM α) (cfg : OmegaConfig) : MetaM α :=
m.run' HashMap.empty |>.run' {} { cfg } |>.run'
m.run' Std.HashMap.empty |>.run' {} { cfg } |>.run'
/-- Retrieve the user-specified configuration options. -/
def cfg : OmegaM OmegaConfig := do pure ( read).cfg
@@ -162,11 +164,11 @@ def mkEqReflWithExpectedType (a b : Expr) : MetaM Expr := do
Analyzes a newly recorded atom,
returning a collection of interesting facts about it that should be added to the context.
-/
def analyzeAtom (e : Expr) : OmegaM (HashSet Expr) := do
def analyzeAtom (e : Expr) : OmegaM (Std.HashSet Expr) := do
match e.getAppFnArgs with
| (``Nat.cast, #[.const ``Int [], _, e']) =>
-- Casts of natural numbers are non-negative.
let mut r := HashSet.empty.insert (Expr.app (.const ``Int.ofNat_nonneg []) e')
let mut r := Std.HashSet.empty.insert (Expr.app (.const ``Int.ofNat_nonneg []) e')
match ( cfg).splitNatSub, e'.getAppFnArgs with
| true, (``HSub.hSub, #[_, _, _, _, a, b]) =>
-- `((a - b : Nat) : Int)` gives a dichotomy
@@ -188,7 +190,7 @@ def analyzeAtom (e : Expr) : OmegaM (HashSet Expr) := do
let ne_zero := mkApp3 (.const ``Ne [1]) (.const ``Int []) k (toExpr (0 : Int))
let pos := mkApp4 (.const ``LT.lt [0]) (.const ``Int []) (.const ``Int.instLTInt [])
(toExpr (0 : Int)) k
pure <| HashSet.empty.insert
pure <| Std.HashSet.empty.insert
(mkApp3 (.const ``Int.mul_ediv_self_le []) x k ( mkDecideProof ne_zero)) |>.insert
(mkApp3 (.const ``Int.lt_mul_ediv_self_add []) x k ( mkDecideProof pos))
| (``HMod.hMod, #[_, _, _, _, x, k]) =>
@@ -200,7 +202,7 @@ def analyzeAtom (e : Expr) : OmegaM (HashSet Expr) := do
let b_pos := mkApp4 (.const ``LT.lt [0]) (.const ``Int []) (.const ``Int.instLTInt [])
(toExpr (0 : Int)) b
let pow_pos := mkApp3 (.const ``Lean.Omega.Int.pos_pow_of_pos []) b exp ( mkDecideProof b_pos)
pure <| HashSet.empty.insert
pure <| Std.HashSet.empty.insert
(mkApp3 (.const ``Int.emod_nonneg []) x k
(mkApp3 (.const ``Int.ne_of_gt []) k (toExpr (0 : Int)) pow_pos)) |>.insert
(mkApp3 (.const ``Int.emod_lt_of_pos []) x k pow_pos)
@@ -214,7 +216,7 @@ def analyzeAtom (e : Expr) : OmegaM (HashSet Expr) := do
(toExpr (0 : Nat)) b
let pow_pos := mkApp3 (.const ``Nat.pos_pow_of_pos []) b exp ( mkDecideProof b_pos)
let cast_pos := mkApp2 (.const ``Int.ofNat_pos_of_pos []) k' pow_pos
pure <| HashSet.empty.insert
pure <| Std.HashSet.empty.insert
(mkApp3 (.const ``Int.emod_nonneg []) x k
(mkApp3 (.const ``Int.ne_of_gt []) k (toExpr (0 : Int)) cast_pos)) |>.insert
(mkApp3 (.const ``Int.emod_lt_of_pos []) x k cast_pos)
@@ -222,18 +224,18 @@ def analyzeAtom (e : Expr) : OmegaM (HashSet Expr) := do
| (``Nat.cast, #[.const ``Int [], _, x']) =>
-- Since we push coercions inside `%`, we need to record here that
-- `(x : Int) % (y : Int)` is non-negative.
pure <| HashSet.empty.insert (mkApp2 (.const ``Int.emod_ofNat_nonneg []) x' k)
pure <| Std.HashSet.empty.insert (mkApp2 (.const ``Int.emod_ofNat_nonneg []) x' k)
| _ => pure
| _ => pure
| (``Min.min, #[_, _, x, y]) =>
pure <| HashSet.empty.insert (mkApp2 (.const ``Int.min_le_left []) x y) |>.insert
pure <| Std.HashSet.empty.insert (mkApp2 (.const ``Int.min_le_left []) x y) |>.insert
(mkApp2 (.const ``Int.min_le_right []) x y)
| (``Max.max, #[_, _, x, y]) =>
pure <| HashSet.empty.insert (mkApp2 (.const ``Int.le_max_left []) x y) |>.insert
pure <| Std.HashSet.empty.insert (mkApp2 (.const ``Int.le_max_left []) x y) |>.insert
(mkApp2 (.const ``Int.le_max_right []) x y)
| (``ite, #[α, i, dec, t, e]) =>
if α == (.const ``Int []) then
pure <| HashSet.empty.insert <| mkApp5 (.const ``ite_disjunction [0]) α i dec t e
pure <| Std.HashSet.empty.insert <| mkApp5 (.const ``ite_disjunction [0]) α i dec t e
else
pure {}
| _ => pure
@@ -248,10 +250,10 @@ Return its index, and, if it is new, a collection of interesting facts about the
* for each new atom of the form `((a - b : Nat) : Int)`, the fact:
`b ≤ a ∧ ((a - b : Nat) : Int) = a - b a < b ∧ ((a - b : Nat) : Int) = 0`
-/
def lookup (e : Expr) : OmegaM (Nat × Option (HashSet Expr)) := do
def lookup (e : Expr) : OmegaM (Nat × Option (Std.HashSet Expr)) := do
let c getThe State
let e canon e
match c.atoms.find? e with
match c.atoms[e]? with
| some i => return (i, none)
| none =>
trace[omega] "New atom: {e}"

View File

@@ -7,7 +7,6 @@ prelude
import Init.Control.StateRef
import Init.Data.Array.BinSearch
import Init.Data.Stream
import Lean.Data.HashMap
import Lean.ImportingFlag
import Lean.Data.SMap
import Lean.Declaration
@@ -134,7 +133,7 @@ structure Environment where
the field `constants`. These auxiliary constants are invisible to the Lean kernel and elaborator.
Only the code generator uses them.
-/
const2ModIdx : HashMap Name ModuleIdx
const2ModIdx : Std.HashMap Name ModuleIdx
/--
Mapping from constant name to `ConstantInfo`. It contains all constants (definitions, theorems, axioms, etc)
that have been already type checked by the kernel.
@@ -205,7 +204,7 @@ private def getTrustLevel (env : Environment) : UInt32 :=
env.header.trustLevel
def getModuleIdxFor? (env : Environment) (declName : Name) : Option ModuleIdx :=
env.const2ModIdx.find? declName
env.const2ModIdx[declName]?
def isConstructor (env : Environment) (declName : Name) : Bool :=
match env.find? declName with
@@ -721,7 +720,7 @@ def writeModule (env : Environment) (fname : System.FilePath) : IO Unit := do
Construct a mapping from persistent extension name to entension index at the array of persistent extensions.
We only consider extensions starting with index `>= startingAt`.
-/
def mkExtNameMap (startingAt : Nat) : IO (HashMap Name Nat) := do
def mkExtNameMap (startingAt : Nat) : IO (Std.HashMap Name Nat) := do
let descrs persistentEnvExtensionsRef.get
let mut result := {}
for h : i in [startingAt : descrs.size] do
@@ -742,7 +741,7 @@ private def setImportedEntries (env : Environment) (mods : Array ModuleData) (st
have : modIdx < mods.size := h.upper
let mod := mods[modIdx]
for (extName, entries) in mod.entries do
if let some entryIdx := extNameIdx.find? extName then
if let some entryIdx := extNameIdx[extName]? then
env := extDescrs[entryIdx]!.toEnvExtension.modifyState env fun s => { s with importedEntries := s.importedEntries.set! modIdx entries }
return env
@@ -790,9 +789,9 @@ structure ImportState where
moduleData : Array ModuleData := #[]
regions : Array CompactedRegion := #[]
def throwAlreadyImported (s : ImportState) (const2ModIdx : HashMap Name ModuleIdx) (modIdx : Nat) (cname : Name) : IO α := do
def throwAlreadyImported (s : ImportState) (const2ModIdx : Std.HashMap Name ModuleIdx) (modIdx : Nat) (cname : Name) : IO α := do
let modName := s.moduleNames[modIdx]!
let constModName := s.moduleNames[const2ModIdx[cname].get!.toNat]!
let constModName := s.moduleNames[const2ModIdx[cname]!.toNat]!
throw <| IO.userError s!"import {modName} failed, environment already contains '{cname}' from {constModName}"
abbrev ImportStateM := StateRefT ImportState IO
@@ -856,21 +855,21 @@ def finalizeImport (s : ImportState) (imports : Array Import) (opts : Options) (
(leakEnv := false) : IO Environment := do
let numConsts := s.moduleData.foldl (init := 0) fun numConsts mod =>
numConsts + mod.constants.size + mod.extraConstNames.size
let mut const2ModIdx : HashMap Name ModuleIdx := mkHashMap (capacity := numConsts)
let mut constantMap : HashMap Name ConstantInfo := mkHashMap (capacity := numConsts)
let mut const2ModIdx : Std.HashMap Name ModuleIdx := Std.HashMap.empty (capacity := numConsts)
let mut constantMap : Std.HashMap Name ConstantInfo := Std.HashMap.empty (capacity := numConsts)
for h:modIdx in [0:s.moduleData.size] do
let mod := s.moduleData[modIdx]'h.upper
for cname in mod.constNames, cinfo in mod.constants do
match constantMap.insertIfNew cname cinfo with
| (constantMap', cinfoPrev?) =>
match constantMap.getThenInsertIfNew? cname cinfo with
| (cinfoPrev?, constantMap') =>
constantMap := constantMap'
if let some cinfoPrev := cinfoPrev? then
-- Recall that the map has not been modified when `cinfoPrev? = some _`.
unless equivInfo cinfoPrev cinfo do
throwAlreadyImported s const2ModIdx modIdx cname
const2ModIdx := const2ModIdx.insertIfNew cname modIdx |>.1
const2ModIdx := const2ModIdx.insertIfNew cname modIdx
for cname in mod.extraConstNames do
const2ModIdx := const2ModIdx.insertIfNew cname modIdx |>.1
const2ModIdx := const2ModIdx.insertIfNew cname modIdx
let constants : ConstMap := SMap.fromHashMap constantMap false
let exts mkInitialExtensionStates
let mut env : Environment := {
@@ -936,7 +935,7 @@ builtin_initialize namespacesExt : SimplePersistentEnvExtension Name NameSSet
6.18% of the runtime is here. It was 9.31% before the `HashMap` optimization.
-/
let capacity := as.foldl (init := 0) fun r e => r + e.size
let map : HashMap Name Unit := mkHashMap capacity
let map : Std.HashMap Name Unit := Std.HashMap.empty capacity
let map := mkStateFromImportedEntries (fun map name => map.insert name ()) map as
SMap.fromHashMap map |>.switch
addEntryFn := fun s n => s.insert n

View File

@@ -8,6 +8,7 @@ import Init.Data.Hashable
import Lean.Data.KVMap
import Lean.Data.SMap
import Lean.Level
import Std.Data.HashSet.Basic
namespace Lean
@@ -244,7 +245,7 @@ def FVarIdSet.insert (s : FVarIdSet) (fvarId : FVarId) : FVarIdSet :=
A set of unique free variable identifiers implemented using hashtables.
Hashtables are faster than red-black trees if they are used linearly.
They are not persistent data-structures. -/
def FVarIdHashSet := HashSet FVarId
def FVarIdHashSet := Std.HashSet FVarId
deriving Inhabited, EmptyCollection
/--
@@ -1388,11 +1389,11 @@ def mkDecIsTrue (pred proof : Expr) :=
def mkDecIsFalse (pred proof : Expr) :=
mkAppB (mkConst `Decidable.isFalse) pred proof
abbrev ExprMap (α : Type) := HashMap Expr α
abbrev ExprMap (α : Type) := Std.HashMap Expr α
abbrev PersistentExprMap (α : Type) := PHashMap Expr α
abbrev SExprMap (α : Type) := SMap Expr α
abbrev ExprSet := HashSet Expr
abbrev ExprSet := Std.HashSet Expr
abbrev PersistentExprSet := PHashSet Expr
abbrev PExprSet := PersistentExprSet
@@ -1417,7 +1418,7 @@ instance : ToString ExprStructEq := ⟨fun e => toString e.val⟩
end ExprStructEq
abbrev ExprStructMap (α : Type) := HashMap ExprStructEq α
abbrev ExprStructMap (α : Type) := Std.HashMap ExprStructEq α
abbrev PersistentExprStructMap (α : Type) := PHashMap ExprStructEq α
namespace Expr

View File

@@ -31,7 +31,7 @@ namespace Lean
abbrev LabelExtension := SimpleScopedEnvExtension Name (Array Name)
/-- The collection of all current `LabelExtension`s, indexed by name. -/
abbrev LabelExtensionMap := HashMap Name LabelExtension
abbrev LabelExtensionMap := Std.HashMap Name LabelExtension
/-- Store the current `LabelExtension`s. -/
builtin_initialize labelExtensionMapRef : IO.Ref LabelExtensionMap IO.mkRef {}
@@ -88,7 +88,7 @@ macro (name := _root_.Lean.Parser.Command.registerLabelAttr)
/-- When `attrName` is an attribute created using `register_labelled_attr`,
return the names of all declarations labelled using that attribute. -/
def labelled (attrName : Name) : CoreM (Array Name) := do
match ( labelExtensionMapRef.get).find? attrName with
match ( labelExtensionMapRef.get)[attrName]? with
| none => throwError "No extension named {attrName}"
| some ext => pure <| ext.getState ( getEnv)

View File

@@ -5,8 +5,6 @@ Authors: Leonardo de Moura
-/
prelude
import Init.Data.Array.QSort
import Lean.Data.HashMap
import Lean.Data.HashSet
import Lean.Data.PersistentHashMap
import Lean.Data.PersistentHashSet
import Lean.Hygiene
@@ -614,9 +612,9 @@ where
end Level
abbrev LevelMap (α : Type) := HashMap Level α
abbrev LevelMap (α : Type) := Std.HashMap Level α
abbrev PersistentLevelMap (α : Type) := PHashMap Level α
abbrev LevelSet := HashSet Level
abbrev LevelSet := Std.HashSet Level
abbrev PersistentLevelSet := PHashSet Level
abbrev PLevelSet := PersistentLevelSet

View File

@@ -34,7 +34,7 @@ def constructorNameAsVariable : Linter where
| return
let infoTrees := ( get).infoState.trees.toArray
let warnings : IO.Ref (Lean.HashMap String.Range (Syntax × Name × Name)) IO.mkRef {}
let warnings : IO.Ref (Std.HashMap String.Range (Syntax × Name × Name)) IO.mkRef {}
for tree in infoTrees do
tree.visitM' (preNode := fun ci info _ => do

View File

@@ -149,7 +149,7 @@ def checkDecl : SimpleHandler := fun stx => do
lintField rest[1][0] stx[1] "computed field"
else if rest.getKind == ``«structure» then
unless rest[5][2].isNone do
let redecls : HashSet String.Pos :=
let redecls : Std.HashSet String.Pos :=
( get).infoState.trees.foldl (init := {}) fun s tree =>
tree.foldInfo (init := s) fun _ info s =>
if let .ofFieldRedeclInfo info := info then

View File

@@ -270,14 +270,14 @@ pointer identity and does not store the objects, so it is important not to store
pointer to an object in the map, or it can be freed and reused, resulting in incorrect behavior.
Returns `true` if the object was not already in the set. -/
unsafe def insertObjImpl {α : Type} (set : IO.Ref (HashSet USize)) (a : α) : IO Bool := do
unsafe def insertObjImpl {α : Type} (set : IO.Ref (Std.HashSet USize)) (a : α) : IO Bool := do
if ( set.get).contains (ptrAddrUnsafe a) then
return false
set.modify (·.insert (ptrAddrUnsafe a))
return true
@[inherit_doc insertObjImpl, implemented_by insertObjImpl]
opaque insertObj {α : Type} (set : IO.Ref (HashSet USize)) (a : α) : IO Bool
opaque insertObj {α : Type} (set : IO.Ref (Std.HashSet USize)) (a : α) : IO Bool
/--
Collects into `fvarUses` all `fvar`s occurring in the `Expr`s in `assignments`.
@@ -285,8 +285,8 @@ This implementation respects subterm sharing in both the `PersistentHashMap` and
to ensure that pointer-equal subobjects are not visited multiple times, which is important
in practice because these expressions are very frequently highly shared.
-/
partial def visitAssignments (set : IO.Ref (HashSet USize))
(fvarUses : IO.Ref (HashSet FVarId))
partial def visitAssignments (set : IO.Ref (Std.HashSet USize))
(fvarUses : IO.Ref (Std.HashSet FVarId))
(assignments : Array (PersistentHashMap MVarId Expr)) : IO Unit := do
MonadCacheT.run do
for assignment in assignments do
@@ -316,8 +316,8 @@ where
/-- Given `aliases` as a map from an alias to what it aliases, we get the original
term by recursion. This has no cycle detection, so if `aliases` contains a loop
then this function will recurse infinitely. -/
partial def followAliases (aliases : HashMap FVarId FVarId) (x : FVarId) : FVarId :=
match aliases.find? x with
partial def followAliases (aliases : Std.HashMap FVarId FVarId) (x : FVarId) : FVarId :=
match aliases[x]? with
| none => x
| some y => followAliases aliases y
@@ -343,17 +343,17 @@ structure References where
the spans for `foo`, `bar`, and `baz`. Global definitions are always treated as used.
(It would be nice to be able to detect unused global definitions but this requires more
information than the linter framework can provide.) -/
constDecls : HashSet String.Range := .empty
constDecls : Std.HashSet String.Range := .empty
/-- The collection of all local declarations, organized by the span of the declaration.
We collapse all declarations declared at the same position into a single record using
`FVarDefinition.aliases`. -/
fvarDefs : HashMap String.Range FVarDefinition := .empty
fvarDefs : Std.HashMap String.Range FVarDefinition := .empty
/-- The set of `FVarId`s that are used directly. These may or may not be aliases. -/
fvarUses : HashSet FVarId := .empty
fvarUses : Std.HashSet FVarId := .empty
/-- A mapping from alias to original FVarId. We don't guarantee that the value is not itself
an alias, but we use `followAliases` when adding new elements to try to avoid long chains. -/
-- TODO: use a `UnionFind` data structure here
fvarAliases : HashMap FVarId FVarId := .empty
fvarAliases : Std.HashMap FVarId FVarId := .empty
/-- Collection of all `MetavarContext`s following the execution of a tactic. We trawl these
if needed to find additional `fvarUses`. -/
assignments : Array (PersistentHashMap MVarId Expr) := #[]
@@ -391,7 +391,7 @@ def collectReferences (infoTrees : Array Elab.InfoTree) (cmdStxRange : String.Ra
if s.startsWith "_" then return
-- Record this either as a new `fvarDefs`, or an alias of an existing one
modify fun s =>
if let some ref := s.fvarDefs.find? range then
if let some ref := s.fvarDefs[range]? then
{ s with fvarDefs := s.fvarDefs.insert range { ref with aliases := ref.aliases.push id } }
else
{ s with fvarDefs := s.fvarDefs.insert range { userName := ldecl.userName, stx, opts, aliases := #[id] } }
@@ -444,7 +444,7 @@ def unusedVariables : Linter where
-- Resolve all recursive references in `fvarAliases`.
-- At this point everything in `fvarAliases` is guaranteed not to be itself an alias,
-- and should point to some element of `FVarDefinition.aliases` in `s.fvarDefs`
let fvarAliases : HashMap FVarId FVarId := s.fvarAliases.fold (init := {}) fun m id baseId =>
let fvarAliases : Std.HashMap FVarId FVarId := s.fvarAliases.fold (init := {}) fun m id baseId =>
m.insert id (followAliases s.fvarAliases baseId)
-- Collect all non-alias fvars corresponding to `fvarUses` by resolving aliases in the list.
@@ -461,7 +461,7 @@ def unusedVariables : Linter where
let fvarUses fvarUsesRef.get
-- If any of the `fvar`s corresponding to this declaration is (an alias of) a variable in
-- `fvarUses`, then it is used
if aliases.any fun id => fvarUses.contains (fvarAliases.findD id id) then continue
if aliases.any fun id => fvarUses.contains (fvarAliases.getD id id) then continue
-- If this is a global declaration then it is (potentially) used after the command
if s.constDecls.contains range then continue
@@ -496,7 +496,7 @@ def unusedVariables : Linter where
initializedMVars := true
let fvarUses fvarUsesRef.get
-- Redo the initial check because `fvarUses` could be bigger now
if aliases.any fun id => fvarUses.contains (fvarAliases.findD id id) then continue
if aliases.any fun id => fvarUses.contains (fvarAliases.getD id id) then continue
-- If we made it this far then the variable is unused and not ignored
unused := unused.push (declStx, userName)

View File

@@ -16,8 +16,8 @@ structure State where
nextParamIdx : Nat := 0
paramNames : Array Name := #[]
fvars : Array Expr := #[]
lmap : HashMap LMVarId Level := {}
emap : HashMap MVarId Expr := {}
lmap : Std.HashMap LMVarId Level := {}
emap : Std.HashMap MVarId Expr := {}
abstractLevels : Bool -- whether to abstract level mvars
abbrev M := StateM State
@@ -54,7 +54,7 @@ private partial def abstractLevelMVars (u : Level) : M Level := do
if depth != s.mctx.depth then
return u -- metavariables from lower depths are treated as constants
else
match s.lmap.find? mvarId with
match s.lmap[mvarId]? with
| some u => pure u
| none =>
let paramId := Name.mkNum `_abstMVar s.nextParamIdx
@@ -87,7 +87,7 @@ partial def abstractExprMVars (e : Expr) : M Expr := do
if e != eNew then
abstractExprMVars eNew
else
match ( get).emap.find? mvarId with
match ( get).emap[mvarId]? with
| some e =>
return e
| none =>

View File

@@ -5,9 +5,9 @@ Authors: Leonardo de Moura
-/
prelude
import Lean.Util.ShareCommon
import Lean.Data.HashMap
import Lean.Meta.Basic
import Lean.Meta.FunInfo
import Std.Data.HashMap.Raw
namespace Lean.Meta
namespace Canonicalizer
@@ -47,12 +47,12 @@ State for the `CanonM` monad.
-/
structure State where
/-- Mapping from `Expr` to hash. -/
-- We use `HashMapImp` to ensure we don't have to tag `State` as `unsafe`.
cache : HashMapImp ExprVisited UInt64 := mkHashMapImp
-- We use `HashMap.Raw` to ensure we don't have to tag `State` as `unsafe`.
cache : Std.HashMap.Raw ExprVisited UInt64 := Std.HashMap.Raw.empty
/--
Given a hashcode `k` and `keyToExprs.find? h = some es`, we have that all `es` have hashcode `k`, and
are not definitionally equal modulo the transparency setting used. -/
keyToExprs : HashMap UInt64 (List Expr) := mkHashMap
keyToExprs : Std.HashMap UInt64 (List Expr) :=
instance : Inhabited State where
default := {}
@@ -70,7 +70,7 @@ def CanonM.run (x : CanonM α) (transparency := TransparencyMode.instances) (s :
StateRefT'.run (x transparency) s
private partial def mkKey (e : Expr) : CanonM UInt64 := do
if let some hash := unsafe ( get).cache.find? { e } then
if let some hash := unsafe ( get).cache.get? { e } then
return hash
else
let key match e with
@@ -107,7 +107,7 @@ private partial def mkKey (e : Expr) : CanonM UInt64 := do
return mixHash ( mkKey v) ( mkKey b)
| .proj _ i s =>
return mixHash i.toUInt64 ( mkKey s)
unsafe modify fun { cache, keyToExprs} => { keyToExprs, cache := cache.insert { e } key |>.1 }
unsafe modify fun { cache, keyToExprs} => { keyToExprs, cache := cache.insert { e } key }
return key
/--
@@ -116,7 +116,7 @@ private partial def mkKey (e : Expr) : CanonM UInt64 := do
def canon (e : Expr) : CanonM Expr := do
let k mkKey e
-- Find all expressions canonicalized before that have the same key.
if let some es' := unsafe ( get).keyToExprs.find? k then
if let some es' := unsafe ( get).keyToExprs[k]? then
withTransparency ( read) do
for e' in es' do
-- Found an expression `e'` that is definitionally equal to `e` and share the same key.

View File

@@ -127,7 +127,7 @@ abbrev ClosureM := ReaderT Context $ StateRefT State MetaM
pure u
else
let s get
match s.visitedLevel.find? u with
match s.visitedLevel[u]? with
| some v => pure v
| none => do
let v f u
@@ -139,7 +139,7 @@ abbrev ClosureM := ReaderT Context $ StateRefT State MetaM
pure e
else
let s get
match s.visitedExpr.find? e with
match s.visitedExpr.get? e with
| some r => pure r
| none =>
let r f e

View File

@@ -52,14 +52,14 @@ which appear in the type and local context of `mvarId`, as well as the
metavariables which *those* metavariables depend on, etc.
-/
partial def _root_.Lean.MVarId.getMVarDependencies (mvarId : MVarId) (includeDelayed := false) :
MetaM (HashSet MVarId) :=
MetaM (Std.HashSet MVarId) :=
(·.snd) <$> (go mvarId).run {}
where
/-- Auxiliary definition for `getMVarDependencies`. -/
addMVars (e : Expr) : StateRefT (HashSet MVarId) MetaM Unit := do
addMVars (e : Expr) : StateRefT (Std.HashSet MVarId) MetaM Unit := do
let mvars getMVars e
let mut s get
set ({} : HashSet MVarId) -- Ensure that `s` is not shared.
set ({} : Std.HashSet MVarId) -- Ensure that `s` is not shared.
for mvarId in mvars do
if pure includeDelayed <||> notM (mvarId.isDelayedAssigned) then
s := s.insert mvarId
@@ -67,7 +67,7 @@ where
mvars.forM go
/-- Auxiliary definition for `getMVarDependencies`. -/
go (mvarId : MVarId) : StateRefT (HashSet MVarId) MetaM Unit :=
go (mvarId : MVarId) : StateRefT (Std.HashSet MVarId) MetaM Unit :=
withIncRecDepth do
let mdecl mvarId.getDecl
addMVars mdecl.type

View File

@@ -695,7 +695,7 @@ def throwOutOfScopeFVar : CheckAssignmentM α :=
throw <| Exception.internal outOfScopeExceptionId
private def findCached? (e : Expr) : CheckAssignmentM (Option Expr) := do
return ( get).cache.find? e
return ( get).cache.get? e
private def cache (e r : Expr) : CheckAssignmentM Unit := do
modify fun s => { s with cache := s.cache.insert e r }

View File

@@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
-/
prelude
import Lean.Data.AssocList
import Lean.HeadIndex
import Lean.Meta.Basic

View File

@@ -286,7 +286,7 @@ private structure Trie (α : Type) where
/-- Index of trie matching star. -/
star : TrieIndex
/-- Following matches based on key of trie. -/
children : HashMap Key TrieIndex
children : Std.HashMap Key TrieIndex
/-- Lazy entries at this trie that are not processed. -/
pending : Array (LazyEntry α) := #[]
deriving Inhabited
@@ -318,7 +318,7 @@ structure LazyDiscrTree (α : Type) where
/-- Backing array of trie entries. Should be owned by this trie. -/
tries : Array (LazyDiscrTree.Trie α) := #[default]
/-- Map from discriminator trie roots to the index. -/
roots : Lean.HashMap LazyDiscrTree.Key LazyDiscrTree.TrieIndex := {}
roots : Std.HashMap LazyDiscrTree.Key LazyDiscrTree.TrieIndex := {}
namespace LazyDiscrTree
@@ -445,9 +445,9 @@ private def addLazyEntryToTrie (i:TrieIndex) (e : LazyEntry α) : MatchM α Unit
modify (·.modify i (·.pushPending e))
private def evalLazyEntry (config : WhnfCoreConfig)
(p : Array α × TrieIndex × HashMap Key TrieIndex)
(p : Array α × TrieIndex × Std.HashMap Key TrieIndex)
(entry : LazyEntry α)
: MatchM α (Array α × TrieIndex × HashMap Key TrieIndex) := do
: MatchM α (Array α × TrieIndex × Std.HashMap Key TrieIndex) := do
let (values, starIdx, children) := p
let (todo, lctx, v) := entry
if todo.isEmpty then
@@ -465,7 +465,7 @@ private def evalLazyEntry (config : WhnfCoreConfig)
addLazyEntryToTrie starIdx (todo, lctx, v)
pure (values, starIdx, children)
else
match children.find? k with
match children[k]? with
| none =>
let children := children.insert k ( newTrie (todo, lctx, v))
pure (values, starIdx, children)
@@ -478,16 +478,16 @@ This evaluates all lazy entries in a trie and updates `values`, `starIdx`, and `
accordingly.
-/
private partial def evalLazyEntries (config : WhnfCoreConfig)
(values : Array α) (starIdx : TrieIndex) (children : HashMap Key TrieIndex)
(values : Array α) (starIdx : TrieIndex) (children : Std.HashMap Key TrieIndex)
(entries : Array (LazyEntry α)) :
MatchM α (Array α × TrieIndex × HashMap Key TrieIndex) := do
MatchM α (Array α × TrieIndex × Std.HashMap Key TrieIndex) := do
let mut values := values
let mut starIdx := starIdx
let mut children := children
entries.foldlM (init := (values, starIdx, children)) (evalLazyEntry config)
private def evalNode (c : TrieIndex) :
MatchM α (Array α × TrieIndex × HashMap Key TrieIndex) := do
MatchM α (Array α × TrieIndex × Std.HashMap Key TrieIndex) := do
let .node vs star cs pending := (get).get! c
if pending.size = 0 then
pure (vs, star, cs)
@@ -508,7 +508,7 @@ def dropKeyAux (next : TrieIndex) (rest : List Key) :
| [] =>
modify (·.set! next {values := #[], star, children})
| k :: r => do
let next := if k == .star then star else children.findD k 0
let next := if k == .star then star else children.getD k 0
dropKeyAux next r
/--
@@ -519,7 +519,7 @@ def dropKey (t : LazyDiscrTree α) (path : List LazyDiscrTree.Key) : MetaM (Lazy
match path with
| [] => pure t
| rootKey :: rest => do
let idx := t.roots.findD rootKey 0
let idx := t.roots.getD rootKey 0
Prod.snd <$> runMatch t (dropKeyAux idx rest)
/--
@@ -628,7 +628,7 @@ private partial def getMatchLoop (cases : Array PartialMatch) (result : MatchRes
else
cases.push { todo, score := ca.score, c := star }
let pushNonStar (k : Key) (args : Array Expr) (cases : Array PartialMatch) :=
match cs.find? k with
match cs[k]? with
| none => cases
| some c => cases.push { todo := todo ++ args, score := ca.score + 1, c }
let cases := pushStar cases
@@ -650,8 +650,8 @@ private partial def getMatchLoop (cases : Array PartialMatch) (result : MatchRes
cases |> pushNonStar k args
getMatchLoop cases result
private def getStarResult (root : Lean.HashMap Key TrieIndex) : MatchM α (MatchResult α) :=
match root.find? .star with
private def getStarResult (root : Std.HashMap Key TrieIndex) : MatchM α (MatchResult α) :=
match root[Key.star]? with
| none =>
pure <| {}
| some idx => do
@@ -661,16 +661,16 @@ private def getStarResult (root : Lean.HashMap Key TrieIndex) : MatchM α (Match
/-
Add partial match to cases if discriminator tree root map has potential matches.
-/
private def pushRootCase (r : Lean.HashMap Key TrieIndex) (k : Key) (args : Array Expr)
private def pushRootCase (r : Std.HashMap Key TrieIndex) (k : Key) (args : Array Expr)
(cases : Array PartialMatch) : Array PartialMatch :=
match r.find? k with
match r[k]? with
| none => cases
| some c => cases.push { todo := args, score := 1, c }
/--
Find values that match `e` in `root`.
-/
private def getMatchCore (root : Lean.HashMap Key TrieIndex) (e : Expr) :
private def getMatchCore (root : Std.HashMap Key TrieIndex) (e : Expr) :
MatchM α (MatchResult α) := do
let result getStarResult root
let (k, args) MatchClone.getMatchKeyArgs e (root := true) ( read)
@@ -701,7 +701,7 @@ of elements using concurrent functions for generating entries.
-/
private structure PreDiscrTree (α : Type) where
/-- Maps keys to index in tries array. -/
roots : HashMap Key Nat := {}
roots : Std.HashMap Key Nat := {}
/-- Lazy entries for root of trie. -/
tries : Array (Array (LazyEntry α)) := #[]
deriving Inhabited
@@ -711,7 +711,7 @@ namespace PreDiscrTree
private def modifyAt (d : PreDiscrTree α) (k : Key)
(f : Array (LazyEntry α) Array (LazyEntry α)) : PreDiscrTree α :=
let { roots, tries } := d
match roots.find? k with
match roots[k]? with
| .none =>
let roots := roots.insert k tries.size
{ roots, tries := tries.push (f #[]) }

View File

@@ -68,7 +68,7 @@ where
loop lhss alts minors
structure State where
used : HashSet Nat := {} -- used alternatives
used : Std.HashSet Nat := {} -- used alternatives
counterExamples : List (List Example) := []
/-- Return true if the given (sub-)problem has been solved. -/

View File

@@ -28,17 +28,17 @@ such as `contradiction`.
-/
private def _root_.Lean.MVarId.contradictionQuick (mvarId : MVarId) : MetaM Bool := do
mvarId.withContext do
let mut posMap : HashMap Expr FVarId := {}
let mut negMap : HashMap Expr FVarId := {}
let mut posMap : Std.HashMap Expr FVarId := {}
let mut negMap : Std.HashMap Expr FVarId := {}
for localDecl in ( getLCtx) do
unless localDecl.isImplementationDetail do
if let some p matchNot? localDecl.type then
if let some pFVarId := posMap.find? p then
if let some pFVarId := posMap[p]? then
mvarId.assign ( mkAbsurd ( mvarId.getType) (mkFVar pFVarId) localDecl.toExpr)
return true
negMap := negMap.insert p localDecl.fvarId
if ( isProp localDecl.type) then
if let some nFVarId := negMap.find? localDecl.type then
if let some nFVarId := negMap[localDecl.type]? then
mvarId.assign ( mkAbsurd ( mvarId.getType) localDecl.toExpr (mkFVar nFVarId))
return true
posMap := posMap.insert localDecl.type localDecl.fvarId

View File

@@ -97,8 +97,8 @@ namespace MkTableKey
structure State where
nextIdx : Nat := 0
lmap : HashMap LMVarId Level := {}
emap : HashMap MVarId Expr := {}
lmap : Std.HashMap LMVarId Level := {}
emap : Std.HashMap MVarId Expr := {}
mctx : MetavarContext
abbrev M := StateM State
@@ -120,7 +120,7 @@ partial def normLevel (u : Level) : M Level := do
return u
else
let s get
match ( get).lmap.find? mvarId with
match ( get).lmap[mvarId]? with
| some u' => pure u'
| none =>
let u' := mkLevelParam <| Name.mkNum `_tc s.nextIdx
@@ -145,7 +145,7 @@ partial def normExpr (e : Expr) : M Expr := do
return e
else
let s get
match s.emap.find? mvarId with
match s.emap[mvarId]? with
| some e' => pure e'
| none => do
let e' := mkFVar { name := Name.mkNum `_tc s.nextIdx }
@@ -186,7 +186,7 @@ structure State where
result? : Option AbstractMVarsResult := none
generatorStack : Array GeneratorNode := #[]
resumeStack : Array (ConsumerNode × Answer) := #[]
tableEntries : HashMap Expr TableEntry := {}
tableEntries : Std.HashMap Expr TableEntry := {}
abbrev SynthM := ReaderT Context $ StateRefT State MetaM
@@ -265,7 +265,7 @@ def newSubgoal (mctx : MetavarContext) (key : Expr) (mvar : Expr) (waiter : Wait
pure ((), m!"new goal {key}")
def findEntry? (key : Expr) : SynthM (Option TableEntry) := do
return ( get).tableEntries.find? key
return ( get).tableEntries[key]?
def getEntry (key : Expr) : SynthM TableEntry := do
match ( findEntry? key) with
@@ -553,7 +553,7 @@ def generate : SynthM Unit := do
/- See comment at `typeHasMVars` -/
if backward.synthInstance.canonInstances.get ( getOptions) then
unless gNode.typeHasMVars do
if let some entry := ( get).tableEntries.find? key then
if let some entry := ( get).tableEntries[key]? then
if entry.answers.any fun answer => answer.result.numMVars == 0 then
/-
We already have an answer that:

View File

@@ -66,9 +66,9 @@ inductive PreExpr
def toACExpr (op l r : Expr) : MetaM (Array Expr × ACExpr) := do
let (preExpr, vars)
toPreExpr (mkApp2 op l r)
|>.run HashSet.empty
|>.run Std.HashSet.empty
let vars := vars.toArray.insertionSort Expr.lt
let varMap := vars.foldl (fun xs x => xs.insert x xs.size) HashMap.empty |>.find!
let varMap := vars.foldl (fun xs x => xs.insert x xs.size) Std.HashMap.empty |>.get!
return (vars, toACExpr varMap preExpr)
where

View File

@@ -290,7 +290,7 @@ structure RewriteResultConfig where
side : SideConditions := .solveByElim
mctx : MetavarContext
def takeListAux (cfg : RewriteResultConfig) (seen : HashMap String Unit) (acc : Array RewriteResult)
def takeListAux (cfg : RewriteResultConfig) (seen : Std.HashMap String Unit) (acc : Array RewriteResult)
(xs : List ((Expr Name) × Bool × Nat)) : MetaM (Array RewriteResult) := do
let mut seen := seen
let mut acc := acc

View File

@@ -420,12 +420,12 @@ def mkSimpExt (name : Name := by exact decl_name%) : IO SimpExtension :=
| .toUnfoldThms n thms => d.registerDeclToUnfoldThms n thms
}
abbrev SimpExtensionMap := HashMap Name SimpExtension
abbrev SimpExtensionMap := Std.HashMap Name SimpExtension
builtin_initialize simpExtensionMapRef : IO.Ref SimpExtensionMap IO.mkRef {}
def getSimpExtension? (attrName : Name) : IO (Option SimpExtension) :=
return ( simpExtensionMapRef.get).find? attrName
return ( simpExtensionMapRef.get)[attrName]?
/-- Auxiliary method for adding a global declaration to a `SimpTheorems` datastructure. -/
def SimpTheorems.addConst (s : SimpTheorems) (declName : Name) (post := true) (inv := false) (prio : Nat := eval_prio default) : MetaM SimpTheorems := do

View File

@@ -19,8 +19,8 @@ It contains:
- The actual procedure associated with a name.
-/
structure BuiltinSimprocs where
keys : HashMap Name (Array SimpTheoremKey) := {}
procs : HashMap Name (Sum Simproc DSimproc) := {}
keys : Std.HashMap Name (Array SimpTheoremKey) := {}
procs : Std.HashMap Name (Sum Simproc DSimproc) := {}
deriving Inhabited
/--
@@ -37,7 +37,7 @@ structure SimprocDecl where
deriving Inhabited
structure SimprocDeclExtState where
builtin : HashMap Name (Array SimpTheoremKey)
builtin : Std.HashMap Name (Array SimpTheoremKey)
newEntries : PHashMap Name (Array SimpTheoremKey) := {}
deriving Inhabited
@@ -65,7 +65,7 @@ def getSimprocDeclKeys? (declName : Name) : CoreM (Option (Array SimpTheoremKey)
if let some keys := keys? then
return some keys
else
return (simprocDeclExt.getState env).builtin.find? declName
return (simprocDeclExt.getState env).builtin[declName]?
def isBuiltinSimproc (declName : Name) : CoreM Bool := do
let s := simprocDeclExt.getState ( getEnv)
@@ -160,7 +160,7 @@ def Simprocs.addCore (s : Simprocs) (keys : Array SimpTheoremKey) (declName : Na
Implements attributes `builtin_simproc` and `builtin_sevalproc`.
-/
def addSimprocBuiltinAttrCore (ref : IO.Ref Simprocs) (declName : Name) (post : Bool) (proc : Sum Simproc DSimproc) : IO Unit := do
let some keys := ( builtinSimprocDeclsRef.get).keys.find? declName |
let some keys := ( builtinSimprocDeclsRef.get).keys[declName]? |
throw (IO.userError "invalid [builtin_simproc] attribute, '{declName}' is not a builtin simproc")
ref.modify fun s => s.addCore keys declName post proc
@@ -176,7 +176,7 @@ def Simprocs.add (s : Simprocs) (declName : Name) (post : Bool) : CoreM Simprocs
getSimprocFromDecl declName
catch e =>
if ( isBuiltinSimproc declName) then
let some proc := ( builtinSimprocDeclsRef.get).procs.find? declName
let some proc := ( builtinSimprocDeclsRef.get).procs[declName]?
| throwError "invalid [simproc] attribute, '{declName}' is not a simproc"
pure proc
else
@@ -384,7 +384,7 @@ def mkSimprocAttr (attrName : Name) (attrDescr : String) (ext : SimprocExtension
erase := eraseSimprocAttr ext
}
abbrev SimprocExtensionMap := HashMap Name SimprocExtension
abbrev SimprocExtensionMap := Std.HashMap Name SimprocExtension
builtin_initialize simprocExtensionMapRef : IO.Ref SimprocExtensionMap IO.mkRef {}
@@ -438,7 +438,7 @@ def getSEvalSimprocs : CoreM Simprocs :=
return simprocSEvalExtension.getState ( getEnv)
def getSimprocExtensionCore? (attrName : Name) : IO (Option SimprocExtension) :=
return ( simprocExtensionMapRef.get).find? attrName
return ( simprocExtensionMapRef.get)[attrName]?
def simpAttrNameToSimprocAttrName (attrName : Name) : Name :=
if attrName == `simp then `simprocAttr

View File

@@ -512,7 +512,7 @@ def mkCongrSimp? (f : Expr) : SimpM (Option CongrTheorem) := do
if kinds.all fun k => match k with | CongrArgKind.fixed => true | CongrArgKind.eq => true | _ => false then
/- See remark above. -/
return none
match ( get).congrCache.find? f with
match ( get).congrCache[f]? with
| some thm? => return thm?
| none =>
let thm? mkCongrSimpCore? f info kinds

View File

@@ -903,7 +903,7 @@ structure State where
mctx : MetavarContext
nextMacroScope : MacroScope
ngen : NameGenerator
cache : HashMap ExprStructEq Expr := {}
cache : Std.HashMap ExprStructEq Expr := {}
structure Context where
mainModule : Name
@@ -1319,7 +1319,7 @@ structure State where
mctx : MetavarContext
paramNames : Array Name := #[]
nextParamIdx : Nat
cache : HashMap ExprStructEq Expr := {}
cache : Std.HashMap ExprStructEq Expr := {}
abbrev M := ReaderT Context <| StateM State
@@ -1328,7 +1328,7 @@ instance : MonadMCtx M where
modifyMCtx f := modify fun s => { s with mctx := f s.mctx }
instance : MonadCache ExprStructEq Expr M where
findCached? e := return ( get).cache.find? e
findCached? e := return ( get).cache[e]?
cache e v := modify fun s => { s with cache := s.cache.insert e v }
partial def mkParamName : M Name := do

View File

@@ -131,7 +131,7 @@ structure ParserCacheEntry where
structure ParserCache where
tokenCache : TokenCacheEntry
parserCache : HashMap ParserCacheKey ParserCacheEntry
parserCache : Std.HashMap ParserCacheKey ParserCacheEntry
def initCacheForInput (input : String) : ParserCache where
tokenCache := { startPos := input.endPos + ' ' /- make sure it is not a valid position -/ }
@@ -418,7 +418,7 @@ place if there was an error.
-/
def withCacheFn (parserName : Name) (p : ParserFn) : ParserFn := fun c s => Id.run do
let key := c.toCacheableParserContext, parserName, s.pos
if let some r := s.cache.parserCache.find? key then
if let some r := s.cache.parserCache[key]? then
-- TODO: turn this into a proper trace once we have these in the parser
--dbg_trace "parser cache hit: {parserName}:{s.pos} -> {r.stx}"
return s.stxStack.push r.stx, r.lhsPrec, r.newPos, s.cache, r.errorMsg, s.recoveredErrors

View File

@@ -198,10 +198,10 @@ def isHBinOp (e : Expr) : Bool := Id.run do
def replaceLPsWithVars (e : Expr) : MetaM Expr := do
if !e.hasLevelParam then return e
let lps := collectLevelParams {} e |>.params
let mut replaceMap : HashMap Name Level := {}
let mut replaceMap : Std.HashMap Name Level := {}
for lp in lps do replaceMap := replaceMap.insert lp ( mkFreshLevelMVar)
return e.replaceLevel fun
| Level.param n .. => replaceMap.find! n
| Level.param n .. => replaceMap[n]!
| l => if !l.hasParam then some l else none
def isDefEqAssigning (t s : Expr) : MetaM Bool := do

View File

@@ -29,7 +29,7 @@ namespace Lean.Environment
namespace Replay
structure Context where
newConstants : HashMap Name ConstantInfo
newConstants : Std.HashMap Name ConstantInfo
structure State where
env : Environment
@@ -73,7 +73,7 @@ and add it to the environment.
-/
partial def replayConstant (name : Name) : M Unit := do
if isTodo name then
let some ci := ( read).newConstants.find? name | unreachable!
let some ci := ( read).newConstants[name]? | unreachable!
replayConstants ci.getUsedConstantsAsSet
-- Check that this name is still pending: a mutual block may have taken care of it.
if ( get).pending.contains name then
@@ -89,13 +89,13 @@ partial def replayConstant (name : Name) : M Unit := do
| .inductInfo info =>
let lparams := info.levelParams
let nparams := info.numParams
let all info.all.mapM fun n => do pure <| (( read).newConstants.find! n)
let all info.all.mapM fun n => do pure <| (( read).newConstants[n]!)
for o in all do
modify fun s =>
{ s with remaining := s.remaining.erase o.name, pending := s.pending.erase o.name }
let ctorInfo all.mapM fun ci => do
pure (ci, ci.inductiveVal!.ctors.mapM fun n => do
pure (( read).newConstants.find! n))
pure (( read).newConstants[n]!))
-- Make sure we are really finished with the constructors.
for (_, ctors) in ctorInfo do
for ctor in ctors do
@@ -129,7 +129,7 @@ when we replayed the inductives.
-/
def checkPostponedConstructors : M Unit := do
for ctor in ( get).postponedConstructors do
match ( get).env.constants.find? ctor, ( read).newConstants.find? ctor with
match ( get).env.constants.find? ctor, ( read).newConstants[ctor]? with
| some (.ctorInfo info), some (.ctorInfo info') =>
if ! (info == info') then throw <| IO.userError s!"Invalid constructor {ctor}"
| _, _ => throw <| IO.userError s!"No such constructor {ctor}"
@@ -140,7 +140,7 @@ when we replayed the inductives.
-/
def checkPostponedRecursors : M Unit := do
for ctor in ( get).postponedRecursors do
match ( get).env.constants.find? ctor, ( read).newConstants.find? ctor with
match ( get).env.constants.find? ctor, ( read).newConstants[ctor]? with
| some (.recInfo info), some (.recInfo info') =>
if ! (info == info') then throw <| IO.userError s!"Invalid recursor {ctor}"
| _, _ => throw <| IO.userError s!"No such recursor {ctor}"
@@ -155,7 +155,7 @@ open Replay
Throws a `IO.userError` if the kernel rejects a constant,
or if there are malformed recursors or constructors for inductive types.
-/
def replay (newConstants : HashMap Name ConstantInfo) (env : Environment) : IO Environment := do
def replay (newConstants : Std.HashMap Name ConstantInfo) (env : Environment) : IO Environment := do
let mut remaining : NameSet :=
for (n, ci) in newConstants.toList do
-- We skip unsafe constants, and also partial constants.

View File

@@ -81,7 +81,7 @@ open Elab
open Meta
open FuzzyMatching
abbrev EligibleHeaderDecls := HashMap Name ConstantInfo
abbrev EligibleHeaderDecls := Std.HashMap Name ConstantInfo
/-- Cached header declarations for which `allowCompletion headerEnv decl` is true. -/
builtin_initialize eligibleHeaderDeclsRef : IO.Ref (Option EligibleHeaderDecls)

View File

@@ -316,7 +316,7 @@ partial def handleDocumentHighlight (p : DocumentHighlightParams)
let refs : Lsp.ModuleRefs findModuleRefs text trees |>.toLspModuleRefs
let mut ranges := #[]
for ident in refs.findAt p.position do
if let some info := refs.find? ident then
if let some info := refs.get? ident then
if let some definitionRange, _ := info.definition? then
ranges := ranges.push definitionRange
ranges := ranges.append <| info.usages.map (·.range)

View File

@@ -93,20 +93,20 @@ def toLspRefInfo (i : RefInfo) : BaseIO Lsp.RefInfo := do
end RefInfo
/-- All references from within a module for all identifiers used in a single module. -/
def ModuleRefs := HashMap RefIdent RefInfo
def ModuleRefs := Std.HashMap RefIdent RefInfo
namespace ModuleRefs
/-- Adds `ref` to the `RefInfo` corresponding to `ref.ident` in `self`. See `RefInfo.addRef`. -/
def addRef (self : ModuleRefs) (ref : Reference) : ModuleRefs :=
let refInfo := self.findD ref.ident RefInfo.empty
let refInfo := self.getD ref.ident RefInfo.empty
self.insert ref.ident (refInfo.addRef ref)
/-- Converts `refs` to a JSON-serializable `Lsp.ModuleRefs`. -/
def toLspModuleRefs (refs : ModuleRefs) : BaseIO Lsp.ModuleRefs := do
let refs refs.toList.mapM fun (k, v) => do
return (k, v.toLspRefInfo)
return HashMap.ofList refs
return Std.HashMap.ofList refs
end ModuleRefs
@@ -261,7 +261,7 @@ all identifiers that are being collapsed into one.
-/
partial def combineIdents (trees : Array InfoTree) (refs : Array Reference) : Array Reference := Id.run do
-- Deduplicate definitions based on their exact range
let mut posMap : HashMap Lsp.Range RefIdent := HashMap.empty
let mut posMap : Std.HashMap Lsp.Range RefIdent := Std.HashMap.empty
for ref in refs do
if let { ident, range, isBinder := true, .. } := ref then
posMap := posMap.insert range ident
@@ -277,17 +277,17 @@ partial def combineIdents (trees : Array InfoTree) (refs : Array Reference) : Ar
refs' := refs'.push ref
refs'
where
useConstRepresentatives (idMap : HashMap RefIdent RefIdent)
: HashMap RefIdent RefIdent := Id.run do
useConstRepresentatives (idMap : Std.HashMap RefIdent RefIdent)
: Std.HashMap RefIdent RefIdent := Id.run do
let insertIntoClass classesById id :=
let representative := findCanonicalRepresentative idMap id
let «class» := classesById.findD representative
let «class» := classesById.getD representative
let classesById := classesById.erase representative -- make `«class»` referentially unique
let «class» := «class».insert id
classesById.insert representative «class»
-- collect equivalence classes
let mut classesById : HashMap RefIdent (HashSet RefIdent) :=
let mut classesById : Std.HashMap RefIdent (Std.HashSet RefIdent) :=
for id, baseId in idMap.toArray do
classesById := insertIntoClass classesById id
classesById := insertIntoClass classesById baseId
@@ -310,17 +310,17 @@ where
r := r.insert id bestRepresentative
return r
findCanonicalRepresentative (idMap : HashMap RefIdent RefIdent) (id : RefIdent) : RefIdent := Id.run do
findCanonicalRepresentative (idMap : Std.HashMap RefIdent RefIdent) (id : RefIdent) : RefIdent := Id.run do
let mut canonicalRepresentative := id
while idMap.contains canonicalRepresentative do
canonicalRepresentative := idMap.find! canonicalRepresentative
canonicalRepresentative := idMap[canonicalRepresentative]!
return canonicalRepresentative
buildIdMap posMap := Id.run <| StateT.run' (s := HashMap.empty) do
buildIdMap posMap := Id.run <| StateT.run' (s := Std.HashMap.empty) do
-- map fvar defs to overlapping fvar defs/uses
for ref in refs do
let baseId := ref.ident
if let some id := posMap.find? ref.range then
if let some id := posMap[ref.range]? then
insertIdMap id baseId
-- apply `FVarAliasInfo`
@@ -346,11 +346,11 @@ are added to the `aliases` of the representative of the group.
Yields to separate groups for declaration and usages if `allowSimultaneousBinderUse` is set.
-/
def dedupReferences (refs : Array Reference) (allowSimultaneousBinderUse := false) : Array Reference := Id.run do
let mut refsByIdAndRange : HashMap (RefIdent × Option Bool × Lsp.Range) Reference := HashMap.empty
let mut refsByIdAndRange : Std.HashMap (RefIdent × Option Bool × Lsp.Range) Reference := Std.HashMap.empty
for ref in refs do
let isBinder := if allowSimultaneousBinderUse then some ref.isBinder else none
let key := (ref.ident, isBinder, ref.range)
refsByIdAndRange := match refsByIdAndRange[key] with
refsByIdAndRange := match refsByIdAndRange[key]? with
| some ref' => refsByIdAndRange.insert key { ref' with aliases := ref'.aliases ++ ref.aliases }
| none => refsByIdAndRange.insert key ref
@@ -371,21 +371,21 @@ def findModuleRefs (text : FileMap) (trees : Array InfoTree) (localVars : Bool :
refs := refs.filter fun
| { ident := RefIdent.fvar .., .. } => false
| _ => true
refs.foldl (init := HashMap.empty) fun m ref => m.addRef ref
refs.foldl (init := Std.HashMap.empty) fun m ref => m.addRef ref
/-! # Collecting and maintaining reference info from different sources -/
/-- References from ilean files and current ilean information from file workers. -/
structure References where
/-- References loaded from ilean files -/
ileans : HashMap Name (System.FilePath × Lsp.ModuleRefs)
ileans : Std.HashMap Name (System.FilePath × Lsp.ModuleRefs)
/-- References from workers, overriding the corresponding ilean files -/
workers : HashMap Name (Nat × Lsp.ModuleRefs)
workers : Std.HashMap Name (Nat × Lsp.ModuleRefs)
namespace References
/-- No ilean files, no information from workers. -/
def empty : References := { ileans := HashMap.empty, workers := HashMap.empty }
def empty : References := { ileans := Std.HashMap.empty, workers := Std.HashMap.empty }
/-- Adds the contents of an ilean file `ilean` at `path` to `self`. -/
def addIlean (self : References) (path : System.FilePath) (ilean : Ilean) : References :=
@@ -404,13 +404,13 @@ Replaces the current references with `refs` if `version` is newer than the curre
in `refs` and otherwise merges the reference data if `version` is equal to the current version.
-/
def updateWorkerRefs (self : References) (name : Name) (version : Nat) (refs : Lsp.ModuleRefs) : References := Id.run do
if let some (currVersion, _) := self.workers.find? name then
if let some (currVersion, _) := self.workers[name]? then
if version > currVersion then
return { self with workers := self.workers.insert name (version, refs) }
if version == currVersion then
let current := self.workers.findD name (version, HashMap.empty)
let current := self.workers.getD name (version, Std.HashMap.empty)
let merged := refs.fold (init := current.snd) fun m ident info =>
m.findD ident Lsp.RefInfo.empty |>.merge info |> m.insert ident
m.getD ident Lsp.RefInfo.empty |>.merge info |> m.insert ident
return { self with workers := self.workers.insert name (version, merged) }
return self
@@ -419,7 +419,7 @@ Replaces the worker references in `self` with the `refs` of the worker managing
if `version` is newer than the current version managed in `refs`.
-/
def finalizeWorkerRefs (self : References) (name : Name) (version : Nat) (refs : Lsp.ModuleRefs) : References := Id.run do
if let some (currVersion, _) := self.workers.find? name then
if let some (currVersion, _) := self.workers[name]? then
if version < currVersion then
return self
return { self with workers := self.workers.insert name (version, refs) }
@@ -429,8 +429,8 @@ def removeWorkerRefs (self : References) (name : Name) : References :=
{ self with workers := self.workers.erase name }
/-- Yields a map from all modules to all of their references. -/
def allRefs (self : References) : HashMap Name Lsp.ModuleRefs :=
let ileanRefs := self.ileans.toArray.foldl (init := HashMap.empty) fun m (name, _, refs) => m.insert name refs
def allRefs (self : References) : Std.HashMap Name Lsp.ModuleRefs :=
let ileanRefs := self.ileans.toArray.foldl (init := Std.HashMap.empty) fun m (name, _, refs) => m.insert name refs
self.workers.toArray.foldl (init := ileanRefs) fun m (name, _, refs) => m.insert name refs
/--
@@ -445,12 +445,12 @@ def allRefsFor
let refsToCheck := match ident with
| RefIdent.const .. => self.allRefs.toArray
| RefIdent.fvar identModule .. =>
match self.allRefs.find? identModule with
match self.allRefs[identModule]? with
| none => #[]
| some refs => #[(identModule, refs)]
let mut result := #[]
for (module, refs) in refsToCheck do
let some info := refs.find? ident
let some info := refs.get? ident
| continue
let some path srcSearchPath.findModuleWithExt "lean" module
| continue
@@ -462,13 +462,13 @@ def allRefsFor
/-- Yields all references in `module` at `pos`. -/
def findAt (self : References) (module : Name) (pos : Lsp.Position) (includeStop := false) : Array RefIdent := Id.run do
if let some refs := self.allRefs.find? module then
if let some refs := self.allRefs[module]? then
return refs.findAt pos includeStop
#[]
/-- Yields the first reference in `module` at `pos`. -/
def findRange? (self : References) (module : Name) (pos : Lsp.Position) (includeStop := false) : Option Range := do
let refs self.allRefs.find? module
let refs self.allRefs[module]?
refs.findRange? pos includeStop
/-- Location and parent declaration of a reference. -/

View File

@@ -90,6 +90,10 @@ section Utils
| crashed (e : IO.Error)
| ioError (e : IO.Error)
inductive CrashOrigin
| fileWorkerToClientForwarding
| clientToFileWorkerForwarding
inductive WorkerState where
/-- The watchdog can detect a crashed file worker in two places: When trying to send a message
to the file worker and when reading a request reply.
@@ -98,7 +102,7 @@ section Utils
that are in-flight are errored. Upon receiving the next packet for that file worker, the file
worker is restarted and the packet is forwarded to it. If the crash was detected while writing
a packet, we queue that packet until the next packet for the file worker arrives. -/
| crashed (queuedMsgs : Array JsonRpc.Message)
| crashed (queuedMsgs : Array JsonRpc.Message) (origin : CrashOrigin)
| running
abbrev PendingRequestMap := RBMap RequestID JsonRpc.Message compare
@@ -136,6 +140,11 @@ section FileWorker
for id, _ in pendingRequests do
hError.writeLspResponseError { id := id, code := code, message := msg }
def queuedMsgs (fw : FileWorker) : Array JsonRpc.Message :=
match fw.state with
| .running => #[]
| .crashed queuedMsgs _ => queuedMsgs
end FileWorker
end FileWorker
@@ -404,10 +413,23 @@ section ServerM
return
eraseFileWorker uri
def handleCrash (uri : DocumentUri) (queuedMsgs : Array JsonRpc.Message) : ServerM Unit := do
def handleCrash (uri : DocumentUri) (queuedMsgs : Array JsonRpc.Message) (origin: CrashOrigin) : ServerM Unit := do
let some fw findFileWorker? uri
| return
updateFileWorkers { fw with state := WorkerState.crashed queuedMsgs }
updateFileWorkers { fw with state := WorkerState.crashed queuedMsgs origin }
def tryDischargeQueuedMessages (uri : DocumentUri) (queuedMsgs : Array JsonRpc.Message) : ServerM Unit := do
let some fw findFileWorker? uri
| throwServerError "Cannot find file worker for '{uri}'."
let mut crashedMsgs := #[]
-- Try to discharge all queued msgs, tracking the ones that we can't discharge
for msg in queuedMsgs do
try
fw.stdin.writeLspMessage msg
catch _ =>
crashedMsgs := crashedMsgs.push msg
if ¬ crashedMsgs.isEmpty then
handleCrash uri crashedMsgs .clientToFileWorkerForwarding
/-- Tries to write a message, sets the state of the FileWorker to `crashed` if it does not succeed
and restarts the file worker if the `crashed` flag was already set. Just logs an error if
@@ -423,7 +445,7 @@ section ServerM
let some fw findFileWorker? uri
| return
match fw.state with
| WorkerState.crashed queuedMsgs =>
| WorkerState.crashed queuedMsgs _ =>
let mut queuedMsgs := queuedMsgs
if queueFailedMessage then
queuedMsgs := queuedMsgs.push msg
@@ -432,17 +454,7 @@ section ServerM
-- restart the crashed FileWorker
eraseFileWorker uri
startFileWorker fw.doc
let some newFw findFileWorker? uri
| throwServerError "Cannot find file worker for '{uri}'."
let mut crashedMsgs := #[]
-- try to discharge all queued msgs, tracking the ones that we can't discharge
for msg in queuedMsgs do
try
newFw.stdin.writeLspMessage msg
catch _ =>
crashedMsgs := crashedMsgs.push msg
if ¬ crashedMsgs.isEmpty then
handleCrash uri crashedMsgs
tryDischargeQueuedMessages uri queuedMsgs
| WorkerState.running =>
let initialQueuedMsgs :=
if queueFailedMessage then
@@ -452,7 +464,7 @@ section ServerM
try
fw.stdin.writeLspMessage msg
catch _ =>
handleCrash uri initialQueuedMsgs
handleCrash uri initialQueuedMsgs .clientToFileWorkerForwarding
/--
Sends a notification to the file worker identified by `uri` that its dependency `staleDependency`
@@ -638,7 +650,7 @@ def handleCallHierarchyOutgoingCalls (p : CallHierarchyOutgoingCallsParams)
let references ( read).references.get
let some refs := references.allRefs.find? module
let some refs := references.allRefs[module]?
| return #[]
let items refs.toArray.filterMapM fun ident, info => do
@@ -702,9 +714,9 @@ def handlePrepareRename (p : PrepareRenameParams) : ServerM (Option Range) := do
def handleRename (p : RenameParams) : ServerM Lsp.WorkspaceEdit := do
if (String.toName p.newName).isAnonymous then
throwServerError s!"Can't rename: `{p.newName}` is not an identifier"
let mut refs : HashMap DocumentUri (RBMap Lsp.Position Lsp.Position compare) :=
let mut refs : Std.HashMap DocumentUri (RBMap Lsp.Position Lsp.Position compare) :=
for { uri, range } in ( handleReference { p with context.includeDeclaration := true }) do
refs := refs.insert uri <| (refs.findD uri ).insert range.start range.end
refs := refs.insert uri <| (refs.getD uri ).insert range.start range.end
-- We have to filter the list of changes to put the ranges in order and
-- remove any duplicates or overlapping ranges, or else the rename will not apply
let changes := refs.fold (init := ) fun changes uri map => Id.run do
@@ -955,7 +967,16 @@ section MainLoop
let workers st.fileWorkersRef.get
let mut workerTasks := #[]
for (_, fw) in workers do
if let WorkerState.running := fw.state then
-- When the forwarding task crashes, its return value will be stuck at
-- `WorkerEvent.crashed _`.
-- We want to handle this event only once, not over and over again,
-- so once the state becomes `WorkerState.crashed _ .fileWorkerToClientForwarding`
-- as a result of `WorkerEvent.crashed _`, we stop handling this event until
-- eventually the file worker is restarted by a notification from the client.
-- We do not want to filter the forwarding task in case of
-- `WorkerState.crashed _ .clientToFileWorkerForwarding`, since the forwarding task
-- exit code may still contain valuable information in this case (e.g. that the imports changed).
if !(fw.state matches WorkerState.crashed _ .fileWorkerToClientForwarding) then
workerTasks := workerTasks.push <| fw.commTask.map (ServerEvent.workerEvent fw)
let ev IO.waitAny (clientTask :: workerTasks.toList)
@@ -984,13 +1005,16 @@ section MainLoop
| WorkerEvent.ioError e =>
throwServerError s!"IO error while processing events for {fw.doc.uri}: {e}"
| WorkerEvent.crashed _ =>
handleCrash fw.doc.uri #[]
handleCrash fw.doc.uri fw.queuedMsgs .fileWorkerToClientForwarding
mainLoop clientTask
| WorkerEvent.terminated =>
throwServerError <| "Internal server error: got termination event for worker that "
++ "should have been removed"
| .importsChanged =>
let uri := fw.doc.uri
let queuedMsgs := fw.queuedMsgs
startFileWorker fw.doc
tryDischargeQueuedMessages uri queuedMsgs
mainLoop clientTask
end MainLoop

View File

@@ -5,7 +5,7 @@ Authors: David Thrane Christiansen
-/
prelude
import Init.Data
import Lean.Data.HashMap
import Std.Data.HashMap.Basic
import Init.Omega
namespace Lean.Diff
@@ -57,7 +57,7 @@ structure Histogram.Entry (α : Type u) (lsize rsize : Nat) where
/-- A histogram for arrays maps each element to a count and, if applicable, an index.-/
def Histogram (α : Type u) (lsize rsize : Nat) [BEq α] [Hashable α] :=
Lean.HashMap α (Histogram.Entry α lsize rsize)
Std.HashMap α (Histogram.Entry α lsize rsize)
section
@@ -67,7 +67,7 @@ variable [BEq α] [Hashable α]
/-- Add an element from the left array to a histogram -/
def Histogram.addLeft (histogram : Histogram α lsize rsize) (index : Fin lsize) (val : α)
: Histogram α lsize rsize :=
match histogram.find? val with
match histogram.get? val with
| none => histogram.insert val {
leftCount := 1, leftIndex := some index,
leftWF := by simp,
@@ -81,7 +81,7 @@ def Histogram.addLeft (histogram : Histogram α lsize rsize) (index : Fin lsize)
/-- Add an element from the right array to a histogram -/
def Histogram.addRight (histogram : Histogram α lsize rsize) (index : Fin rsize) (val : α)
: Histogram α lsize rsize :=
match histogram.find? val with
match histogram.get? val with
| none => histogram.insert val {
leftCount := 0, leftIndex := none,
leftWF := by simp,

View File

@@ -28,7 +28,7 @@ structure State where
Set of visited subterms that satisfy the predicate `p`.
We have to use this set to make sure `f` is applied at most once of each subterm that satisfies `p`.
-/
checked : HashSet Expr
checked : Std.HashSet Expr
unsafe def initCache : State := {
visited := mkArray cacheSize.toNat (cast lcProof ())

View File

@@ -5,14 +5,15 @@ Authors: Leonardo de Moura
-/
prelude
import Lean.Expr
import Std.Data.HashMap.Raw
namespace Lean
structure HasConstCache (declNames : Array Name) where
cache : HashMapImp Expr Bool := mkHashMapImp
cache : Std.HashMap.Raw Expr Bool := Std.HashMap.Raw.empty
unsafe def HasConstCache.containsUnsafe (e : Expr) : StateM (HasConstCache declNames) Bool := do
if let some r := ( get).cache.find? (beq := ptrEq) e then
if let some r := ( get).cache.get? (beq := ptrEq) e then
return r
else
match e with
@@ -26,7 +27,7 @@ unsafe def HasConstCache.containsUnsafe (e : Expr) : StateM (HasConstCache declN
| _ => return false
where
cache (e : Expr) (r : Bool) : StateM (HasConstCache declNames) Bool := do
modify fun cache => cache.insert (beq := ptrEq) e r |>.1
modify fun cache => cache.insert (beq := ptrEq) e r
return r
/--

View File

@@ -5,7 +5,7 @@ Authors: Leonardo de Moura
-/
prelude
import Init.Control.StateRef
import Lean.Data.HashMap
import Std.Data.HashMap.Basic
namespace Lean
/-- Interface for caching results. -/
@@ -36,15 +36,15 @@ instance {α β ε : Type} {m : Type → Type} [MonadCache α β m] [Monad m] :
/-- Adapter for implementing `MonadCache` interface using `HashMap`s.
We just have to specify how to extract/modify the `HashMap`. -/
class MonadHashMapCacheAdapter (α β : Type) (m : Type Type) [BEq α] [Hashable α] where
getCache : m (HashMap α β)
modifyCache : (HashMap α β HashMap α β) m Unit
getCache : m (Std.HashMap α β)
modifyCache : (Std.HashMap α β Std.HashMap α β) m Unit
namespace MonadHashMapCacheAdapter
@[always_inline, inline]
def findCached? {α β : Type} {m : Type Type} [BEq α] [Hashable α] [Monad m] [MonadHashMapCacheAdapter α β m] (a : α) : m (Option β) := do
let c getCache
pure (c.find? a)
pure (c.get? a)
@[always_inline, inline]
def cache {α β : Type} {m : Type Type} [BEq α] [Hashable α] [MonadHashMapCacheAdapter α β m] (a : α) (b : β) : m Unit :=
@@ -56,7 +56,7 @@ instance {α β : Type} {m : Type → Type} [BEq α] [Hashable α] [Monad m] [Mo
end MonadHashMapCacheAdapter
def MonadCacheT {ω} (α β : Type) (m : Type Type) [STWorld ω m] [BEq α] [Hashable α] := StateRefT (HashMap α β) m
def MonadCacheT {ω} (α β : Type) (m : Type Type) [STWorld ω m] [BEq α] [Hashable α] := StateRefT (Std.HashMap α β) m
namespace MonadCacheT
@@ -67,7 +67,7 @@ instance : MonadHashMapCacheAdapter α β (MonadCacheT α β m) where
modifyCache f := (modify f : StateRefT' ..)
@[inline] def run {σ} (x : MonadCacheT α β m σ) : m σ :=
x.run' mkHashMap
x.run' Std.HashMap.empty
instance : Monad (MonadCacheT α β m) := inferInstanceAs (Monad (StateRefT' _ _ _))
instance : MonadLift m (MonadCacheT α β m) := inferInstanceAs (MonadLift m (StateRefT' _ _ _))
@@ -80,7 +80,7 @@ instance [Alternative m] : Alternative (MonadCacheT α β m) := inferInstanceAs
end MonadCacheT
/- Similar to `MonadCacheT`, but using `StateT` instead of `StateRefT` -/
def MonadStateCacheT (α β : Type) (m : Type Type) [BEq α] [Hashable α] := StateT (HashMap α β) m
def MonadStateCacheT (α β : Type) (m : Type Type) [BEq α] [Hashable α] := StateT (Std.HashMap α β) m
namespace MonadStateCacheT
@@ -91,7 +91,7 @@ instance : MonadHashMapCacheAdapter α β (MonadStateCacheT α β m) where
modifyCache f := (modify f : StateT ..)
@[always_inline, inline] def run {σ} (x : MonadStateCacheT α β m σ) : m σ :=
x.run' mkHashMap
x.run' Std.HashMap.empty
instance : Monad (MonadStateCacheT α β m) := inferInstanceAs (Monad (StateT _ _))
instance : MonadLift m (MonadStateCacheT α β m) := inferInstanceAs (MonadLift m (StateT _ _))

View File

@@ -96,9 +96,9 @@ deriving FromJson, ToJson
/-- Thread with maps necessary for computing max sharing indices -/
structure ThreadWithMaps extends Thread where
stringMap : HashMap String Nat := {}
funcMap : HashMap Nat Nat := {}
stackMap : HashMap (Nat × Option Nat) Nat := {}
stringMap : Std.HashMap String Nat := {}
funcMap : Std.HashMap Nat Nat := {}
stackMap : Std.HashMap (Nat × Option Nat) Nat := {}
/-- Last timestamp encountered: stop time of preceding sibling, or else start time of parent. -/
lastTime : Float := 0
@@ -123,7 +123,7 @@ where
if pp then
funcName := s!"{funcName}: {← msg.format}"
let strIdx modifyGet fun thread =>
if let some idx := thread.stringMap.find? funcName then
if let some idx := thread.stringMap[funcName]? then
(idx, thread)
else
(thread.stringMap.size, { thread with
@@ -131,7 +131,7 @@ where
stringMap := thread.stringMap.insert funcName thread.stringMap.size })
let category := categories.findIdx? (·.name == data.cls.getRoot.toString) |>.getD 0
let funcIdx modifyGet fun thread =>
if let some idx := thread.funcMap.find? strIdx then
if let some idx := thread.funcMap[strIdx]? then
(idx, thread)
else
(thread.funcMap.size, { thread with
@@ -151,7 +151,7 @@ where
funcMap := thread.funcMap.insert strIdx thread.funcMap.size })
let frameIdx := funcIdx
let stackIdx modifyGet fun thread =>
if let some idx := thread.stackMap.find? (frameIdx, parentStackIdx?) then
if let some idx := thread.stackMap[(frameIdx, parentStackIdx?)]? then
(idx, thread)
else
(thread.stackMap.size, { thread with
@@ -222,7 +222,7 @@ def Profile.export (name : String) (startTime : Milliseconds) (traceState : Trac
structure ThreadWithCollideMaps extends ThreadWithMaps where
/-- Max sharing map for samples -/
sampleMap : HashMap Nat Nat := {}
sampleMap : Std.HashMap Nat Nat := {}
/--
Adds samples from `add` to `thread`, increasing the weight of existing samples with identical stacks
@@ -237,7 +237,7 @@ where
let oldStackIdx := add.samples.stack[oldSampleIdx]!
let stackIdx collideStacks oldStackIdx
modify fun thread =>
if let some idx := thread.sampleMap.find? stackIdx then
if let some idx := thread.sampleMap[stackIdx]? then
-- imperative to preserve linear use of arrays here!
let t1, t2, t3, samples, t5, t6, t7, t8, t9, t10, o2, o3, o4, o5, o6 := thread
let s1, s2, weight, s3, s4 := samples
@@ -265,7 +265,7 @@ where
let oldStrIdx := add.funcTable.name[oldFuncIdx]!
let strIdx getStrIdx add.stringArray[oldStrIdx]!
let funcIdx modifyGet fun thread =>
if let some idx := thread.funcMap.find? strIdx then
if let some idx := thread.funcMap[strIdx]? then
(idx, thread)
else
(thread.funcMap.size, { thread with
@@ -284,7 +284,7 @@ where
funcMap := thread.funcMap.insert strIdx thread.funcMap.size })
let frameIdx := funcIdx
modifyGet fun thread =>
if let some idx := thread.stackMap.find? (frameIdx, parentStackIdx?) then
if let some idx := thread.stackMap[(frameIdx, parentStackIdx?)]? then
(idx, thread)
else
(thread.stackMap.size,
@@ -302,7 +302,7 @@ where
t1,t2, t3, t4, t5, stackTable, t7, t8, t9, t10, o2, o3, stackMap, o5, o6)
getStrIdx (s : String) :=
modifyGet fun thread =>
if let some idx := thread.stringMap.find? s then
if let some idx := thread.stringMap[s]? then
(idx, thread)
else
(thread.stringMap.size, { thread with

View File

@@ -5,8 +5,8 @@ Authors: Leonardo de Moura
-/
prelude
import Init.Data.Hashable
import Lean.Data.HashSet
import Lean.Data.HashMap
import Std.Data.HashSet.Basic
import Std.Data.HashMap.Basic
namespace Lean
@@ -23,33 +23,33 @@ unsafe instance : BEq (Ptr α) where
Set of pointers. It is a low-level auxiliary datastructure used for traversing DAGs.
-/
unsafe def PtrSet (α : Type) :=
HashSet (Ptr α)
Std.HashSet (Ptr α)
unsafe def mkPtrSet {α : Type} (capacity : Nat := 64) : PtrSet α :=
mkHashSet capacity
Std.HashSet.empty capacity
unsafe abbrev PtrSet.insert (s : PtrSet α) (a : α) : PtrSet α :=
HashSet.insert s { value := a }
Std.HashSet.insert s { value := a }
unsafe abbrev PtrSet.contains (s : PtrSet α) (a : α) : Bool :=
HashSet.contains s { value := a }
Std.HashSet.contains s { value := a }
/--
Map of pointers. It is a low-level auxiliary datastructure used for traversing DAGs.
-/
unsafe def PtrMap (α : Type) (β : Type) :=
HashMap (Ptr α) β
Std.HashMap (Ptr α) β
unsafe def mkPtrMap {α β : Type} (capacity : Nat := 64) : PtrMap α β :=
mkHashMap capacity
Std.HashMap.empty capacity
unsafe abbrev PtrMap.insert (s : PtrMap α β) (a : α) (b : β) : PtrMap α β :=
HashMap.insert s { value := a } b
Std.HashMap.insert s { value := a } b
unsafe abbrev PtrMap.contains (s : PtrMap α β) (a : α) : Bool :=
HashMap.contains s { value := a }
Std.HashMap.contains s { value := a }
unsafe abbrev PtrMap.find? (s : PtrMap α β) (a : α) : Option β :=
HashMap.find? s { value := a }
Std.HashMap.get? s { value := a }
end Lean

View File

@@ -5,7 +5,7 @@ Authors: Leonardo de Moura
-/
prelude
import Init.Data.List.Control
import Lean.Data.HashMap
import Std.Data.HashMap.Basic
namespace Lean.SCC
/-!
Very simple implementation of Tarjan's SCC algorithm.
@@ -25,7 +25,7 @@ structure Data where
structure State where
stack : List α := []
nextIndex : Nat := 0
data : HashMap α Data := {}
data : Std.HashMap α Data := {}
sccs : List (List α) := []
abbrev M := StateM (State α)
@@ -35,7 +35,7 @@ variable {α : Type} [BEq α] [Hashable α]
private def getDataOf (a : α) : M α Data := do
let s get
match s.data.find? a with
match s.data[a]? with
| some d => pure d
| none => pure {}
@@ -52,7 +52,7 @@ private def push (a : α) : M α Unit :=
private def modifyDataOf (a : α) (f : Data Data) : M α Unit :=
modify fun s => { s with
data := match s.data.find? a with
data := match s.data[a]? with
| none => s.data
| some d => s.data.insert a (f d)
}

View File

@@ -67,7 +67,7 @@ structure TraceState where
traces : PersistentArray TraceElem := {}
deriving Inhabited
builtin_initialize inheritedTraceOptions : IO.Ref (HashSet Name) IO.mkRef
builtin_initialize inheritedTraceOptions : IO.Ref (Std.HashSet Name) IO.mkRef
class MonadTrace (m : Type Type) where
modifyTraceState : (TraceState TraceState) m Unit
@@ -88,7 +88,7 @@ def printTraces : m Unit := do
def resetTraceState : m Unit :=
modifyTraceState (fun _ => {})
private def checkTraceOption (inherited : HashSet Name) (opts : Options) (cls : Name) : Bool :=
private def checkTraceOption (inherited : Std.HashSet Name) (opts : Options) (cls : Name) : Bool :=
!opts.isEmpty && go (`trace ++ cls)
where
go (opt : Name) : Bool :=

View File

@@ -5,3 +5,4 @@ Authors: Sebastian Ullrich
-/
prelude
import Std.Data
import Std.Sat

View File

@@ -112,6 +112,10 @@ Tries to retrieve the mapping for the given key, returning `none` if no such map
@[inline] def get? (m : HashMap α β) (a : α) : Option β :=
DHashMap.Const.get? m.inner a
@[deprecated get? "Use `m[a]?` or `m.get? a` instead", inherit_doc get?]
def find? (m : HashMap α β) (a : α) : Option β :=
m.get? a
@[inline, inherit_doc DHashMap.contains] def contains (m : HashMap α β)
(a : α) : Bool :=
m.inner.contains a
@@ -135,6 +139,10 @@ Retrieves the mapping for the given key. Ensures that such a mapping exists by r
(fallback : β) : β :=
DHashMap.Const.getD m.inner a fallback
@[deprecated getD, inherit_doc getD]
def findD (m : HashMap α β) (a : α) (fallback : β) : β :=
m.getD a fallback
/--
The notation `m[a]!` is preferred over calling this function directly.
@@ -143,6 +151,10 @@ Tries to retrieve the mapping for the given key, panicking if no such mapping is
@[inline] def get! [Inhabited β] (m : HashMap α β) (a : α) : β :=
DHashMap.Const.get! m.inner a
@[deprecated get! "Use `m[a]!` or `m.get! a` instead", inherit_doc get!]
def find! [Inhabited β] (m : HashMap α β) (a : α) : Option β :=
m.get! a
instance [BEq α] [Hashable α] : GetElem? (HashMap α β) α β (fun m a => a m) where
getElem m a h := m.get a h
getElem? m a := m.get? a
@@ -236,3 +248,16 @@ instance [BEq α] [Hashable α] [Repr α] [Repr β] : Repr (HashMap α β) where
end Unverified
end Std.HashMap
/--
Groups all elements `x`, `y` in `xs` with `key x == key y` into the same array
`(xs.groupByKey key).find! (key x)`. Groups preserve the relative order of elements in `xs`.
-/
def Array.groupByKey [BEq α] [Hashable α] (key : β α) (xs : Array β)
: Std.HashMap α (Array β) := Id.run do
let mut groups :=
for x in xs do
let group := groups.getD (key x) #[]
groups := groups.erase (key x) -- make `group` referentially unique
groups := groups.insert (key x) (group.push x)
return groups

View File

@@ -69,8 +69,9 @@ instance : EmptyCollection (Raw α β) where
instance : Inhabited (Raw α β) where
default :=
@[inline, inherit_doc DHashMap.Raw.insert] def insert [BEq α] [Hashable α] (m : Raw α β) (a : α)
(b : β) : Raw α β :=
set_option linter.unusedVariables false in
@[inline, inherit_doc DHashMap.Raw.insert] def insert [beq : BEq α] [Hashable α] (m : Raw α β)
(a : α) (b : β) : Raw α β :=
m.inner.insert a b
@[inline, inherit_doc DHashMap.Raw.insertIfNew] def insertIfNew [BEq α] [Hashable α] (m : Raw α β)
@@ -93,12 +94,13 @@ instance : Inhabited (Raw α β) where
let previous, r := DHashMap.Raw.Const.getThenInsertIfNew? m.inner a b
previous, r
set_option linter.unusedVariables false in
/--
The notation `m[a]?` is preferred over calling this function directly.
Tries to retrieve the mapping for the given key, returning `none` if no such mapping is present.
-/
@[inline] def get? [BEq α] [Hashable α] (m : Raw α β) (a : α) : Option β :=
@[inline] def get? [beq : BEq α] [Hashable α] (m : Raw α β) (a : α) : Option β :=
DHashMap.Raw.Const.get? m.inner a
@[inline, inherit_doc DHashMap.Raw.contains] def contains [BEq α] [Hashable α] (m : Raw α β)

View File

@@ -184,6 +184,10 @@ in the collection will be present in the returned hash set.
@[inline] def ofList [BEq α] [Hashable α] (l : List α) : HashSet α :=
HashMap.unitOfList l
/-- Computes the union of the given hash sets. -/
@[inline] def union [BEq α] [Hashable α] (m₁ m₂ : HashSet α) : HashSet α :=
m₂.fold (init := m₁) fun acc x => acc.insert x
/--
Returns the number of buckets in the internal representation of the hash set. This function may
be useful for things like monitoring system health, but it should be considered an internal

7
src/Std/Sat.lean Normal file
View File

@@ -0,0 +1,7 @@
/-
Copyright (c) 2024 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Henrik Böving
-/
prelude
import Std.Sat.CNF

10
src/Std/Sat/CNF.lean Normal file
View File

@@ -0,0 +1,10 @@
/-
Copyright (c) 2024 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Henrik Böving
-/
prelude
import Std.Sat.CNF.Basic
import Std.Sat.CNF.Literal
import Std.Sat.CNF.Relabel
import Std.Sat.CNF.RelabelFin

190
src/Std/Sat/CNF/Basic.lean Normal file
View File

@@ -0,0 +1,190 @@
/-
Copyright (c) 2024 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Kim Morrison
-/
prelude
import Init.Data.List.Lemmas
import Init.Data.List.Impl
import Std.Sat.CNF.Literal
namespace Std
namespace Sat
/--
A clause in a CNF.
The literal `(i, b)` is satisfied if the assignment to `i` agrees with `b`.
-/
abbrev CNF.Clause (α : Type u) : Type u := List (Literal α)
/--
A CNF formula.
Literals are identified by members of `α`.
-/
abbrev CNF (α : Type u) : Type u := List (CNF.Clause α)
namespace CNF
/--
Evaluating a `Clause` with respect to an assignment `a`.
-/
def Clause.eval (a : α Bool) (c : Clause α) : Bool := c.any fun (i, n) => a i == n
@[simp] theorem Clause.eval_nil (a : α Bool) : Clause.eval a [] = false := rfl
@[simp] theorem Clause.eval_cons (a : α Bool) :
Clause.eval a (i :: c) = (a i.1 == i.2 || Clause.eval a c) := rfl
/--
Evaluating a `CNF` formula with respect to an assignment `a`.
-/
def eval (a : α Bool) (f : CNF α) : Bool := f.all fun c => c.eval a
@[simp] theorem eval_nil (a : α Bool) : eval a [] = true := rfl
@[simp] theorem eval_cons (a : α Bool) : eval a (c :: f) = (c.eval a && eval a f) := rfl
@[simp] theorem eval_append (a : α Bool) (f1 f2 : CNF α) :
eval a (f1 ++ f2) = (eval a f1 && eval a f2) := List.all_append
def Sat (a : α Bool) (f : CNF α) : Prop := eval a f = true
def Unsat (f : CNF α) : Prop := a, eval a f = false
theorem sat_def (a : α Bool) (f : CNF α) : Sat a f (eval a f = true) := by rfl
theorem unsat_def (f : CNF α) : Unsat f ( a, eval a f = false) := by rfl
@[simp] theorem not_unsat_nil : ¬Unsat ([] : CNF α) :=
fun h => by simp [unsat_def] at h
@[simp] theorem sat_nil {assign : α Bool} : Sat assign ([] : CNF α) := by
simp [sat_def]
@[simp] theorem unsat_nil_cons {g : CNF α} : Unsat ([] :: g) := by
simp [unsat_def]
namespace Clause
/--
Variable `v` occurs in `Clause` `c`.
-/
def Mem (v : α) (c : Clause α) : Prop := (v, false) c (v, true) c
instance {v : α} {c : Clause α} [DecidableEq α] : Decidable (Mem v c) :=
inferInstanceAs <| Decidable (_ _)
@[simp] theorem not_mem_nil {v : α} : ¬Mem v ([] : Clause α) := by simp [Mem]
@[simp] theorem mem_cons {v : α} : Mem v (l :: c : Clause α) (v = l.1 Mem v c) := by
rcases l with b, (_|_)
· simp [Mem, or_assoc]
· simp [Mem]
rw [or_left_comm]
theorem mem_of (h : (v, p) c) : Mem v c := by
cases p
· left; exact h
· right; exact h
theorem eval_congr (a1 a2 : α Bool) (c : Clause α) (hw : i, Mem i c a1 i = a2 i) :
eval a1 c = eval a2 c := by
induction c
case nil => rfl
case cons i c ih =>
simp only [eval_cons]
rw [ih, hw]
· rcases i with b, (_|_) <;> simp [Mem]
· intro j h
apply hw
rcases h with h | h
· left
apply List.mem_cons_of_mem _ h
· right
apply List.mem_cons_of_mem _ h
end Clause
/--
Variable `v` occurs in `CNF` formula `f`.
-/
def Mem (v : α) (f : CNF α) : Prop := c, c f c.Mem v
instance {v : α} {f : CNF α} [DecidableEq α] : Decidable (Mem v f) :=
inferInstanceAs <| Decidable ( _, _)
theorem any_not_isEmpty_iff_exists_mem {f : CNF α} :
(List.any f fun c => !List.isEmpty c) = true v, Mem v f := by
simp only [List.any_eq_true, Bool.not_eq_true', List.isEmpty_false_iff_exists_mem, Mem,
Clause.Mem]
constructor
. intro h
rcases h with clause, hclause1, hclause2
rcases hclause2 with lit, hlit
exists lit.fst, clause
constructor
. assumption
. rcases lit with _, _ | _ <;> simp_all
. intro h
rcases h with lit, clause, hclause1, hclause2
exists clause
constructor
. assumption
. cases hclause2 with
| inl hl => exact Exists.intro _ hl
| inr hr => exact Exists.intro _ hr
@[simp] theorem not_exists_mem : (¬ v, Mem v f) n, f = List.replicate n [] := by
simp only [ any_not_isEmpty_iff_exists_mem]
simp only [List.any_eq_true, Bool.not_eq_true', not_exists, not_and, Bool.not_eq_false]
induction f with
| nil =>
simp only [List.not_mem_nil, List.isEmpty_iff, false_implies, forall_const, true_iff]
exact 0, rfl
| cons c f ih =>
simp_all [ih, List.isEmpty_iff]
constructor
· rintro rfl, n, rfl
exact n+1, rfl
· rintro n, h
cases n
· simp at h
· simp_all only [List.replicate, List.cons.injEq, true_and]
exact _, rfl
instance {f : CNF α} [DecidableEq α] : Decidable ( v, Mem v f) :=
decidable_of_iff (f.any fun c => !c.isEmpty) any_not_isEmpty_iff_exists_mem
@[simp] theorem not_mem_nil {v : α} : ¬Mem v ([] : CNF α) := by simp [Mem]
@[simp] theorem mem_cons {v : α} {c} {f : CNF α} :
Mem v (c :: f : CNF α) (Clause.Mem v c Mem v f) := by simp [Mem]
theorem mem_of (h : c f) (w : Clause.Mem v c) : Mem v f := by
apply Exists.intro c
constructor <;> assumption
@[simp] theorem mem_append {v : α} {f1 f2 : CNF α} : Mem v (f1 ++ f2) Mem v f1 Mem v f2 := by
simp [Mem, List.mem_append]
constructor
· rintro c, (mf1 | mf2), mc
· left
exact c, mf1, mc
· right
exact c, mf2, mc
· rintro (c, mf1, mc | c, mf2, mc)
· exact c, Or.inl mf1, mc
· exact c, Or.inr mf2, mc
theorem eval_congr (a1 a2 : α Bool) (f : CNF α) (hw : v, Mem v f a1 v = a2 v) :
eval a1 f = eval a2 f := by
induction f
case nil => rfl
case cons c x ih =>
simp only [eval_cons]
rw [ih, Clause.eval_congr] <;>
· intro i h
apply hw
simp [h]
end CNF
end Sat
end Std

View File

@@ -0,0 +1,38 @@
/-
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Josh Clune
-/
prelude
import Init.Data.Hashable
import Init.Data.ToString
namespace Std
namespace Sat
/--
CNF literals identified by some type `α`. The `Bool` is the polarity of the literal.
`true` means positive polarity.
-/
abbrev Literal (α : Type u) := α × Bool
namespace Literal
/--
Flip the polarity of `l`.
-/
def negate (l : Literal α) : Literal α := (l.1, not l.2)
/--
Output `l` as a DIMACS literal identifier.
-/
def dimacs [ToString α] (l : Literal α) : String :=
if l.2 then
s!"{l.1}"
else
s!"-{l.1}"
end Literal
end Sat
end Std

View File

@@ -0,0 +1,123 @@
/-
Copyright (c) 2024 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Kim Morrison
-/
prelude
import Std.Sat.CNF.Basic
namespace Std
namespace Sat
namespace CNF
namespace Clause
/--
Change the literal type in a `Clause` from `α` to `β` by using `r`.
-/
def relabel (r : α β) (c : Clause α) : Clause β := c.map (fun (i, n) => (r i, n))
@[simp] theorem eval_relabel {r : α β} {a : β Bool} {c : Clause α} :
(relabel r c).eval a = c.eval (a r) := by
induction c <;> simp_all [relabel]
@[simp] theorem relabel_id' : relabel (id : α α) = id := by funext; simp [relabel]
theorem relabel_congr {c : Clause α} {r1 r2 : α β} (hw : v, Mem v c r1 v = r2 v) :
relabel r1 c = relabel r2 c := by
simp only [relabel]
rw [List.map_congr_left]
intro v, p h
congr
apply hw _ (mem_of h)
-- We need the unapplied equality later.
@[simp] theorem relabel_relabel' : relabel r1 relabel r2 = relabel (r1 r2) := by
funext i
simp only [Function.comp_apply, relabel, List.map_map]
rfl
end Clause
/-! ### Relabelling
It is convenient to be able to construct a CNF using a more complicated literal type,
but eventually we need to embed in `Nat`.
-/
/--
Change the literal type in a `CNF` formula from `α` to `β` by using `r`.
-/
def relabel (r : α β) (f : CNF α) : CNF β := f.map (Clause.relabel r)
@[simp] theorem relabel_nil {r : α β} : relabel r [] = [] := by simp [relabel]
@[simp] theorem relabel_cons {r : α β} : relabel r (c :: f) = (c.relabel r) :: relabel r f := by
simp [relabel]
@[simp] theorem eval_relabel (r : α β) (a : β Bool) (f : CNF α) :
(relabel r f).eval a = f.eval (a r) := by
induction f <;> simp_all
@[simp] theorem relabel_append : relabel r (f1 ++ f2) = relabel r f1 ++ relabel r f2 :=
List.map_append _ _ _
@[simp] theorem relabel_relabel : relabel r1 (relabel r2 f) = relabel (r1 r2) f := by
simp only [relabel, List.map_map, Clause.relabel_relabel']
@[simp] theorem relabel_id : relabel id x = x := by simp [relabel]
theorem relabel_congr {f : CNF α} {r1 r2 : α β} (hw : v, Mem v f r1 v = r2 v) :
relabel r1 f = relabel r2 f := by
dsimp only [relabel]
rw [List.map_congr_left]
intro c h
apply Clause.relabel_congr
intro v m
exact hw _ (mem_of h m)
theorem sat_relabel {f : CNF α} (h : Sat (r1 r2) f) : Sat r1 (relabel r2 f) := by
simp_all [sat_def]
theorem unsat_relabel {f : CNF α} (r : α β) (h : Unsat f) :
Unsat (relabel r f) := by
simp_all [unsat_def]
theorem nonempty_or_impossible (f : CNF α) : Nonempty α n, f = List.replicate n [] := by
induction f with
| nil => exact Or.inr 0, rfl
| cons c x ih => match c with
| [] => cases ih with
| inl h => left; exact h
| inr h =>
obtain n, rfl := h
right
exact n + 1, rfl
| a, b :: c => exact Or.inl a
theorem unsat_relabel_iff {f : CNF α} {r : α β}
(hw : {v1 v2}, Mem v1 f Mem v2 f r v1 = r v2 v1 = v2) :
Unsat (relabel r f) Unsat f := by
rcases nonempty_or_impossible f with (a₀ | n, rfl)
· refine fun h => ?_, unsat_relabel r
have em := Classical.propDecidable
let g : β α := fun b =>
if h : a, Mem a f r a = b then h.choose else a₀
have h' := unsat_relabel g h
suffices w : relabel g (relabel r f) = f by
rwa [w] at h'
have : a, Mem a f g (r a) = a := by
intro v h
dsimp [g]
rw [dif_pos v, h, rfl]
apply hw _ h
· exact (Exists.choose_spec (v, h, rfl : a', Mem a' f r a' = r v)).2
· exact (Exists.choose_spec (v, h, rfl : a', Mem a' f r a' = r v)).1
rw [relabel_relabel, relabel_congr, relabel_id]
exact this
· cases n <;> simp [unsat_def, List.replicate_succ]
end CNF
end Sat
end Std

View File

@@ -0,0 +1,138 @@
/-
Copyright (c) 2024 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Kim Morrison
-/
prelude
import Init.Data.List.Nat.Basic
import Std.Sat.CNF.Relabel
namespace Std
namespace Sat
namespace CNF
/--
Obtain the literal with the largest identifier in `c`.
-/
def Clause.maxLiteral (c : Clause Nat) : Option Nat := (c.map (·.1)) |>.maximum?
theorem Clause.of_maxLiteral_eq_some (c : Clause Nat) (h : c.maxLiteral = some maxLit) :
lit, Mem lit c lit maxLit := by
intro lit hlit
simp only [maxLiteral, List.maximum?_eq_some_iff', List.mem_map, forall_exists_index, and_imp,
forall_apply_eq_imp_iff₂] at h
simp only [Mem] at hlit
rcases h with _, hbar
cases hlit
all_goals
have := hbar (lit, _) (by assumption)
omega
theorem Clause.maxLiteral_eq_some_of_mem (c : Clause Nat) (h : Mem l c) :
maxLit, c.maxLiteral = some maxLit := by
dsimp [Mem] at h
cases h <;> rename_i h
all_goals
have h1 := List.ne_nil_of_mem h
have h2 := not_congr <| @List.maximum?_eq_none_iff _ (c.map (·.1)) _
simp [ Option.ne_none_iff_exists', h1, h2, maxLiteral]
theorem Clause.of_maxLiteral_eq_none (c : Clause Nat) (h : c.maxLiteral = none) :
lit, ¬Mem lit c := by
intro lit hlit
simp only [maxLiteral, List.maximum?_eq_none_iff, List.map_eq_nil] at h
simp only [h, not_mem_nil] at hlit
/--
Obtain the literal with the largest identifier in `f`.
-/
def maxLiteral (f : CNF Nat) : Option Nat :=
List.filterMap Clause.maxLiteral f |>.maximum?
theorem of_maxLiteral_eq_some' (f : CNF Nat) (h : f.maxLiteral = some maxLit) :
clause, clause f clause.maxLiteral = some localMax localMax maxLit := by
intro clause hclause1 hclause2
simp [maxLiteral, List.maximum?_eq_some_iff'] at h
rcases h with _, hclause3
apply hclause3 localMax clause hclause1 hclause2
theorem of_maxLiteral_eq_some (f : CNF Nat) (h : f.maxLiteral = some maxLit) :
lit, Mem lit f lit maxLit := by
intro lit hlit
dsimp [Mem] at hlit
rcases hlit with clause, hclause1, hclause2
rcases Clause.maxLiteral_eq_some_of_mem clause hclause2 with localMax, hlocal
have h1 := of_maxLiteral_eq_some' f h clause hclause1 hlocal
have h2 := Clause.of_maxLiteral_eq_some clause hlocal lit hclause2
omega
theorem of_maxLiteral_eq_none (f : CNF Nat) (h : f.maxLiteral = none) :
lit, ¬Mem lit f := by
intro lit hlit
simp only [maxLiteral, List.maximum?_eq_none_iff] at h
dsimp [Mem] at hlit
rcases hlit with clause, hclause1, hclause2
have := Clause.of_maxLiteral_eq_none clause (List.forall_none_of_filterMap_eq_nil h clause hclause1) lit
contradiction
/--
An upper bound for the amount of distinct literals in `f`.
-/
def numLiterals (f : CNF Nat) :=
match f.maxLiteral with
| none => 0
| some n => n + 1
theorem lt_numLiterals {f : CNF Nat} (h : Mem v f) : v < numLiterals f := by
dsimp [numLiterals]
split <;> rename_i h2
. exfalso
apply of_maxLiteral_eq_none f h2 v h
. have := of_maxLiteral_eq_some f h2 v h
omega
theorem numLiterals_pos {f : CNF Nat} (h : Mem v f) : 0 < numLiterals f :=
Nat.lt_of_le_of_lt (Nat.zero_le _) (lt_numLiterals h)
/--
Relabel `f` to a `CNF` formula with a known upper bound for its literals.
This operation might be useful when e.g. using the literals to index into an array of known size
without conducting bounds checks.
-/
def relabelFin (f : CNF Nat) : CNF (Fin f.numLiterals) :=
if h : v, Mem v f then
let n := f.numLiterals
f.relabel fun i =>
if w : i < n then
-- This branch will always hold
i, w
else
0, numLiterals_pos h.choose_spec
else
List.replicate f.length []
@[simp] theorem unsat_relabelFin {f : CNF Nat} : Unsat f.relabelFin Unsat f := by
dsimp [relabelFin]
split <;> rename_i h
· apply unsat_relabel_iff
intro a b ma mb
replace ma := lt_numLiterals ma
replace mb := lt_numLiterals mb
split <;> rename_i a_lt
· simp
· contradiction
· cases f with
| nil => simp
| cons c g =>
simp only [not_exists_mem] at h
obtain n, h := h
cases n with
| zero => simp at h
| succ n => simp_all [List.replicate_succ]
end CNF
end Sat
end Std

View File

@@ -23,8 +23,8 @@ structure Module where
instance : Hashable Module where hash m := hash m.keyName
instance : BEq Module where beq m n := m.keyName == n.keyName
abbrev ModuleSet := HashSet Module
@[inline] def ModuleSet.empty : ModuleSet := HashSet.empty
abbrev ModuleSet := Std.HashSet Module
@[inline] def ModuleSet.empty : ModuleSet := Std.HashSet.empty
abbrev OrdModuleSet := OrdHashSet Module
@[inline] def OrdModuleSet.empty : OrdModuleSet := OrdHashSet.empty

View File

@@ -247,8 +247,8 @@ hydrate_opaque_type OpaquePackage Package
instance : Hashable Package where hash pkg := hash pkg.config.name
instance : BEq Package where beq p1 p2 := p1.config.name == p2.config.name
abbrev PackageSet := HashSet Package
@[inline] def PackageSet.empty : PackageSet := HashSet.empty
abbrev PackageSet := Std.HashSet Package
@[inline] def PackageSet.empty : PackageSet := Std.HashSet.empty
abbrev OrdPackageSet := OrdHashSet Package
@[inline] def OrdPackageSet.empty : OrdPackageSet := OrdHashSet.empty

View File

@@ -23,11 +23,11 @@ namespace Lake
deriving instance BEq, Hashable for Import
/- Cache for the imported header environment of Lake configuration files. -/
initialize importEnvCache : IO.Ref (HashMap (Array Import) Environment) IO.mkRef {}
initialize importEnvCache : IO.Ref (Std.HashMap (Array Import) Environment) IO.mkRef {}
/-- Like `importModules`, but fetch the resulting import state from the cache if possible. -/
def importModulesUsingCache (imports : Array Import) (opts : Options) (trustLevel : UInt32) : IO Environment := do
if let some env := ( importEnvCache.get).find? imports then
if let some env := ( importEnvCache.get)[imports]? then
return env
let env importModules imports opts trustLevel
importEnvCache.modify (·.insert imports env)
@@ -120,7 +120,7 @@ def importConfigFileCore (olean : FilePath) (leanOpts : Options) : IO Environmen
let extNameIdx mkExtNameMap 0
let env := mod.entries.foldl (init := env) fun env (extName, ents) =>
if lakeExts.contains extName then
match extNameIdx.find? extName with
match extNameIdx[extName]? with
| some entryIdx => ents.foldl extDescrs[entryIdx]!.addEntry env
| none => env
else

View File

@@ -3,7 +3,7 @@ Copyright (c) 2022 Mac Malone. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Mac Malone
-/
import Lean.Data.HashSet
import Std.Data.HashSet.Basic
open Lean
@@ -11,13 +11,13 @@ namespace Lake
/-- A `HashSet` that preserves insertion order. -/
structure OrdHashSet (α) [Hashable α] [BEq α] where
toHashSet : HashSet α
toHashSet : Std.HashSet α
toArray : Array α
namespace OrdHashSet
variable [Hashable α] [BEq α]
instance : Coe (OrdHashSet α) (HashSet α) := toHashSet
instance : Coe (OrdHashSet α) (Std.HashSet α) := toHashSet
def empty : OrdHashSet α :=
.empty, .empty

View File

@@ -57,6 +57,7 @@ def testProc (args : IO.Process.SpawnArgs) : BaseIO Bool :=
EIO.catchExceptions (h := fun _ => pure false) do
let child IO.Process.spawn {
args with
stdin := IO.Process.Stdio.null
stdout := IO.Process.Stdio.null
stderr := IO.Process.Stdio.null
}

View File

@@ -7,13 +7,14 @@ Author: Leonardo de Moura
#pragma once
#include <stddef.h>
#include <stdint.h>
#include <lean/lean.h>
namespace lean {
void init_thread_heap();
void * alloc(size_t sz);
void dealloc(void * o, size_t sz);
void add_heartbeats(uint64_t count);
uint64_t get_num_heartbeats();
LEAN_EXPORT void * alloc(size_t sz);
LEAN_EXPORT void dealloc(void * o, size_t sz);
LEAN_EXPORT void add_heartbeats(uint64_t count);
LEAN_EXPORT uint64_t get_num_heartbeats();
void initialize_alloc();
void finalize_alloc();
}

View File

@@ -485,43 +485,36 @@ extern "C" LEAN_EXPORT obj_res lean_io_prim_handle_write(b_obj_arg h, b_obj_arg
}
}
/*
Handle.getLine : (@& Handle) → IO Unit
The line returned by `lean_io_prim_handle_get_line`
is truncated at the first '\0' character and the
rest of the line is discarded. */
/* Handle.getLine : (@& Handle) → IO Unit */
extern "C" LEAN_EXPORT obj_res lean_io_prim_handle_get_line(b_obj_arg h, obj_arg /* w */) {
FILE * fp = io_get_handle(h);
const int buf_sz = 64;
char buf_str[buf_sz]; // NOLINT
std::string result;
bool first = true;
while (true) {
char * out = std::fgets(buf_str, buf_sz, fp);
if (out != nullptr) {
if (strlen(buf_str) < buf_sz-1 || buf_str[buf_sz-2] == '\n') {
if (first) {
return io_result_mk_ok(mk_string(out));
} else {
result.append(out);
return io_result_mk_ok(mk_string(result));
}
}
result.append(out);
} else if (std::feof(fp)) {
clearerr(fp);
return io_result_mk_ok(mk_string(result));
} else {
return io_result_mk_error(decode_io_error(errno, nullptr));
int c; // Note: int, not char, required to handle EOF
while ((c = std::fgetc(fp)) != EOF) {
result.push_back(c);
if (c == '\n') {
break;
}
first = false;
}
if (std::ferror(fp)) {
return io_result_mk_error(decode_io_error(errno, nullptr));
} else if (std::feof(fp)) {
clearerr(fp);
return io_result_mk_ok(mk_string(result));
} else {
obj_res ret = io_result_mk_ok(mk_string(result));
return ret;
}
}
/* Handle.putStr : (@& Handle) → (@& String) → IO Unit */
extern "C" LEAN_EXPORT obj_res lean_io_prim_handle_put_str(b_obj_arg h, b_obj_arg s, obj_arg /* w */) {
FILE * fp = io_get_handle(h);
if (std::fputs(lean_string_cstr(s), fp) != EOF) {
usize n = lean_string_size(s) - 1; // - 1 to ignore the terminal NULL byte.
usize m = std::fwrite(lean_string_cstr(s), 1, n, fp);
if (m == n) {
return io_result_mk_ok(box(0));
} else {
return io_result_mk_error(decode_io_error(errno, nullptr));

Some files were not shown because too many files have changed in this diff Show More