Compare commits

..

4 Commits

Author SHA1 Message Date
Leonardo de Moura
8f6ba0e107 chore: document change 2024-08-06 16:44:23 -07:00
Leonardo de Moura
3439180a3c perf: combine lean_inc_heartbeat calls at expr_eq_fn 2024-08-06 16:40:41 -07:00
Leonardo de Moura
01322069e3 perf: use while for nested applications at expr_eq_fn 2024-08-06 15:55:02 -07:00
Leonardo de Moura
c9d9baa3c8 perf: faster check_system at expr_eq_fn 2024-08-06 15:48:10 -07:00
370 changed files with 663 additions and 1807 deletions

View File

@@ -5,7 +5,7 @@ Some notes on how to debug Lean, which may also be applicable to debugging Lean
## Tracing
In `CoreM` and derived monads, we use `trace[traceCls] "msg with {interpolations}"` to fill the structured trace viewable with `set_option trace.traceCls true`.
In `CoreM` and derived monads, we use `trace![traceCls] "msg with {interpolations}"` to fill the structured trace viewable with `set_option trace.traceCls true`.
New trace classes have to be registered using `registerTraceClass` first.
Notable trace classes:
@@ -22,9 +22,7 @@ Notable trace classes:
In pure contexts or when execution is aborted before the messages are finally printed, one can instead use the term `dbg_trace "msg with {interpolations}"; val` (`;` can also be replaced by a newline), which will print the message to stderr before evaluating `val`. `dbgTraceVal val` can be used as a shorthand for `dbg_trace "{val}"; val`.
Note that if the return value is not actually used, the trace code is silently dropped as well.
By default, such stderr output is buffered and shown as messages after a command has been elaborated, which is necessary to ensure deterministic ordering of messages under parallelism.
If Lean aborts the process before it can finish the command or takes too long to do that, using `-DstderrAsMessages=false` avoids this buffering and shows `dbg_trace` output (but not `trace`s or other diagnostics) immediately.
In the language server, stderr output is buffered and shown as messages after a command has been elaborated, unless the option `server.stderrAsMessages` is deactivated.
## Debuggers

View File

@@ -152,26 +152,22 @@ We'll use `v4.7.0-rc1` as the intended release version in this example.
This will add a list of all the commits since the last stable version.
- Delete "update stage0" commits, and anything with a completely inscrutable commit message.
- Next, we will move a curated list of downstream repos to the release candidate.
- This assumes that for each repository either:
* There is already a *reviewed* branch `bump/v4.7.0` containing the required adaptations.
The preparation of this branch is beyond the scope of this document.
* The repository does not need any changes to move to the new version.
- This assumes that there is already a *reviewed* branch `bump/v4.7.0` on each repository
containing the required adaptations (or no adaptations are required).
The preparation of this branch is beyond the scope of this document.
- For each of the target repositories:
- If the repository does not need any changes (i.e. `bump/v4.7.0` does not exist) then create
a new PR updating `lean-toolchain` to `leanprover/lean4:v4.7.0-rc1` and running `lake update`.
- Otherwise:
- Checkout the `bump/v4.7.0` branch.
- Verify that the `lean-toolchain` is set to the nightly from which the release candidate was created.
- `git merge origin/master`
- Change the `lean-toolchain` to `leanprover/lean4:v4.7.0-rc1`
- In `lakefile.lean`, change any dependencies which were using `nightly-testing` or `bump/v4.7.0` branches
back to `master` or `main`, and run `lake update` for those dependencies.
- Run `lake build` to ensure that dependencies are found (but it's okay to stop it after a moment).
- `git commit`
- `git push`
- Open a PR from `bump/v4.7.0` to `master`, and either merge it yourself after CI, if appropriate,
or notify the maintainers that it is ready to go.
- Once the PR has been merged, tag `master` with `v4.7.0-rc1` and push this tag.
- Checkout the `bump/v4.7.0` branch.
- Verify that the `lean-toolchain` is set to the nightly from which the release candidate was created.
- `git merge origin/master`
- Change the `lean-toolchain` to `leanprover/lean4:v4.7.0-rc1`
- In `lakefile.lean`, change any dependencies which were using `nightly-testing` or `bump/v4.7.0` branches
back to `master` or `main`, and run `lake update` for those dependencies.
- Run `lake build` to ensure that dependencies are found (but it's okay to stop it after a moment).
- `git commit`
- `git push`
- Open a PR from `bump/v4.7.0` to `master`, and either merge it yourself after CI, if appropriate,
or notify the maintainers that it is ready to go.
- Once this PR has been merged, tag `master` with `v4.7.0-rc1` and push this tag.
- We do this for the same list of repositories as for stable releases, see above.
As above, there are dependencies between these, and so the process above is iterative.
It greatly helps if you can merge the `bump/v4.7.0` PRs yourself!

View File

@@ -27,9 +27,9 @@ Setting up a basic parallelized release build:
git clone https://github.com/leanprover/lean4
cd lean4
cmake --preset release
make -C build/release -j$(nproc || sysctl -n hw.logicalcpu)
make -C build/release -j$(nproc) # see below for macOS
```
You can replace `$(nproc || sysctl -n hw.logicalcpu)` with the desired parallelism amount.
You can replace `$(nproc)`, which is not available on macOS and some alternative shells, with the desired parallelism amount.
The above commands will compile the Lean library and binaries into the
`stage1` subfolder; see below for details.

View File

@@ -7,7 +7,6 @@ prelude
import Init.Data.Nat.MinMax
import Init.Data.Nat.Lemmas
import Init.Data.List.Monadic
import Init.Data.List.Nat.Range
import Init.Data.Fin.Basic
import Init.Data.Array.Mem
import Init.TacticsExtra
@@ -337,10 +336,6 @@ theorem not_mem_nil (a : α) : ¬ a ∈ #[] := nofun
/-- # get lemmas -/
theorem lt_of_getElem {x : α} {a : Array α} {idx : Nat} {hidx : idx < a.size} (_ : a[idx] = x) :
idx < a.size :=
hidx
theorem getElem?_mem {l : Array α} {i : Fin l.size} : l[i] l := by
erw [Array.mem_def, getElem_eq_data_getElem]
apply List.get_mem
@@ -510,13 +505,6 @@ theorem size_eq_length_data (as : Array α) : as.size = as.data.length := rfl
simp only [mkEmpty_eq, size_push] at *
omega
@[simp] theorem data_range (n : Nat) : (range n).data = List.range n := by
induction n <;> simp_all [range, Nat.fold, flip, List.range_succ]
@[simp]
theorem getElem_range {n : Nat} {x : Nat} (h : x < (Array.range n).size) : (Array.range n)[x] = x := by
simp [getElem_eq_data_getElem]
set_option linter.deprecated false in
@[simp] theorem reverse_data (a : Array α) : a.reverse.data = a.data.reverse := by
let rec go (as : Array α) (i j hj)
@@ -719,22 +707,13 @@ theorem mapIdx_spec (as : Array α) (f : Fin as.size → α → β)
unfold modify modifyM Id.run
split <;> simp
theorem getElem_modify {as : Array α} {x i} (h : i < as.size) :
(as.modify x f)[i]'(by simp [h]) = if x = i then f as[i] else as[i] := by
simp only [modify, modifyM, get_eq_getElem, Id.run, Id.pure_eq]
split
· simp only [Id.bind_eq, get_set _ _ _ h]; split <;> simp [*]
theorem get_modify {arr : Array α} {x i} (h : i < arr.size) :
(arr.modify x f).get i, by simp [h] =
if x = i then f (arr.get i, h) else arr.get i, h := by
simp [modify, modifyM, Id.run]; split
· simp [get_set _ _ _ h]; split <;> simp [*]
· rw [if_neg (mt (by rintro rfl; exact h) _)]
theorem getElem_modify_self {as : Array α} {i : Nat} (h : i < as.size) (f : α α) :
(as.modify i f)[i]'(by simp [h]) = f as[i] := by
simp [getElem_modify h]
theorem getElem_modify_of_ne {as : Array α} {i : Nat} (hj : j < as.size)
(f : α α) (h : i j) :
(as.modify i f)[j]'(by rwa [size_modify]) = as[j] := by
simp [getElem_modify hj, h]
/-! ### filter -/
@[simp] theorem filter_data (p : α Bool) (l : Array α) :

View File

@@ -42,8 +42,8 @@ Bitvectors have decidable equality. This should be used via the instance `Decida
-- We manually derive the `DecidableEq` instances for `BitVec` because
-- we want to have builtin support for bit-vector literals, and we
-- need a name for this function to implement `canUnfoldAtMatcher` at `WHNF.lean`.
def BitVec.decEq (x y : BitVec n) : Decidable (x = y) :=
match x, y with
def BitVec.decEq (a b : BitVec n) : Decidable (a = b) :=
match a, b with
| n, m =>
if h : n = m then
isTrue (h rfl)
@@ -69,9 +69,9 @@ protected def ofNat (n : Nat) (i : Nat) : BitVec n where
instance instOfNat : OfNat (BitVec n) i where ofNat := .ofNat n i
instance natCastInst : NatCast (BitVec w) := BitVec.ofNat w
/-- Given a bitvector `x`, return the underlying `Nat`. This is O(1) because `BitVec` is a
/-- Given a bitvector `a`, return the underlying `Nat`. This is O(1) because `BitVec` is a
(zero-cost) wrapper around a `Nat`. -/
protected def toNat (x : BitVec n) : Nat := x.toFin.val
protected def toNat (a : BitVec n) : Nat := a.toFin.val
/-- Return the bound in terms of toNat. -/
theorem isLt (x : BitVec w) : x.toNat < 2^w := x.toFin.isLt
@@ -123,18 +123,18 @@ section getXsb
@[inline] def getMsb (x : BitVec w) (i : Nat) : Bool := i < w && getLsb x (w-1-i)
/-- Return most-significant bit in bitvector. -/
@[inline] protected def msb (x : BitVec n) : Bool := getMsb x 0
@[inline] protected def msb (a : BitVec n) : Bool := getMsb a 0
end getXsb
section Int
/-- Interpret the bitvector as an integer stored in two's complement form. -/
protected def toInt (x : BitVec n) : Int :=
if 2 * x.toNat < 2^n then
x.toNat
protected def toInt (a : BitVec n) : Int :=
if 2 * a.toNat < 2^n then
a.toNat
else
(x.toNat : Int) - (2^n : Nat)
(a.toNat : Int) - (2^n : Nat)
/-- The `BitVec` with value `(2^n + (i mod 2^n)) mod 2^n`. -/
protected def ofInt (n : Nat) (i : Int) : BitVec n := .ofNatLt (i % (Int.ofNat (2^n))).toNat (by
@@ -215,7 +215,7 @@ instance : Neg (BitVec n) := ⟨.neg⟩
/--
Return the absolute value of a signed bitvector.
-/
protected def abs (x : BitVec n) : BitVec n := if x.msb then .neg x else x
protected def abs (s : BitVec n) : BitVec n := if s.msb then .neg s else s
/--
Multiplication for bit vectors. This can be interpreted as either signed or unsigned negation
@@ -262,12 +262,12 @@ sdiv 5#4 -2 = -2#4
sdiv (-7#4) (-2) = 3#4
```
-/
def sdiv (x y : BitVec n) : BitVec n :=
match x.msb, y.msb with
| false, false => udiv x y
| false, true => .neg (udiv x (.neg y))
| true, false => .neg (udiv (.neg x) y)
| true, true => udiv (.neg x) (.neg y)
def sdiv (s t : BitVec n) : BitVec n :=
match s.msb, t.msb with
| false, false => udiv s t
| false, true => .neg (udiv s (.neg t))
| true, false => .neg (udiv (.neg s) t)
| true, true => udiv (.neg s) (.neg t)
/--
Signed division for bit vectors using SMTLIB rules for division by zero.
@@ -276,40 +276,40 @@ Specifically, `smtSDiv x 0 = if x >= 0 then -1 else 1`
SMT-Lib name: `bvsdiv`.
-/
def smtSDiv (x y : BitVec n) : BitVec n :=
match x.msb, y.msb with
| false, false => smtUDiv x y
| false, true => .neg (smtUDiv x (.neg y))
| true, false => .neg (smtUDiv (.neg x) y)
| true, true => smtUDiv (.neg x) (.neg y)
def smtSDiv (s t : BitVec n) : BitVec n :=
match s.msb, t.msb with
| false, false => smtUDiv s t
| false, true => .neg (smtUDiv s (.neg t))
| true, false => .neg (smtUDiv (.neg s) t)
| true, true => smtUDiv (.neg s) (.neg t)
/--
Remainder for signed division rounding to zero.
SMT_Lib name: `bvsrem`.
-/
def srem (x y : BitVec n) : BitVec n :=
match x.msb, y.msb with
| false, false => umod x y
| false, true => umod x (.neg y)
| true, false => .neg (umod (.neg x) y)
| true, true => .neg (umod (.neg x) (.neg y))
def srem (s t : BitVec n) : BitVec n :=
match s.msb, t.msb with
| false, false => umod s t
| false, true => umod s (.neg t)
| true, false => .neg (umod (.neg s) t)
| true, true => .neg (umod (.neg s) (.neg t))
/--
Remainder for signed division rounded to negative infinity.
SMT_Lib name: `bvsmod`.
-/
def smod (x y : BitVec m) : BitVec m :=
match x.msb, y.msb with
| false, false => umod x y
def smod (s t : BitVec m) : BitVec m :=
match s.msb, t.msb with
| false, false => umod s t
| false, true =>
let u := umod x (.neg y)
(if u = .zero m then u else .add u y)
let u := umod s (.neg t)
(if u = .zero m then u else .add u t)
| true, false =>
let u := umod (.neg x) y
(if u = .zero m then u else .sub y u)
| true, true => .neg (umod (.neg x) (.neg y))
let u := umod (.neg s) t
(if u = .zero m then u else .sub t u)
| true, true => .neg (umod (.neg s) (.neg t))
end arithmetic
@@ -373,8 +373,8 @@ end relations
section cast
/-- `cast eq x` embeds `x` into an equal `BitVec` type. -/
@[inline] def cast (eq : n = m) (x : BitVec n) : BitVec m := .ofNatLt x.toNat (eq x.isLt)
/-- `cast eq i` embeds `i` into an equal `BitVec` type. -/
@[inline] def cast (eq : n = m) (i : BitVec n) : BitVec m := .ofNatLt i.toNat (eq i.isLt)
@[simp] theorem cast_ofNat {n m : Nat} (h : n = m) (x : Nat) :
cast h (BitVec.ofNat n x) = BitVec.ofNat m x := by
@@ -391,7 +391,7 @@ Extraction of bits `start` to `start + len - 1` from a bit vector of size `n` to
new bitvector of size `len`. If `start + len > n`, then the vector will be zero-padded in the
high bits.
-/
def extractLsb' (start len : Nat) (x : BitVec n) : BitVec len := .ofNat _ (x.toNat >>> start)
def extractLsb' (start len : Nat) (a : BitVec n) : BitVec len := .ofNat _ (a.toNat >>> start)
/--
Extraction of bits `hi` (inclusive) down to `lo` (inclusive) from a bit vector of size `n` to
@@ -399,12 +399,12 @@ yield a new bitvector of size `hi - lo + 1`.
SMT-Lib name: `extract`.
-/
def extractLsb (hi lo : Nat) (x : BitVec n) : BitVec (hi - lo + 1) := extractLsb' lo _ x
def extractLsb (hi lo : Nat) (a : BitVec n) : BitVec (hi - lo + 1) := extractLsb' lo _ a
/--
A version of `zeroExtend` that requires a proof, but is a noop.
-/
def zeroExtend' {n w : Nat} (le : n w) (x : BitVec n) : BitVec w :=
def zeroExtend' {n w : Nat} (le : n w) (x : BitVec n) : BitVec w :=
x.toNat#'(by
apply Nat.lt_of_lt_of_le x.isLt
exact Nat.pow_le_pow_of_le_right (by trivial) le)
@@ -413,8 +413,8 @@ def zeroExtend' {n w : Nat} (le : n ≤ w) (x : BitVec n) : BitVec w :=
`shiftLeftZeroExtend x n` returns `zeroExtend (w+n) x <<< n` without
needing to compute `x % 2^(2+n)`.
-/
def shiftLeftZeroExtend (msbs : BitVec w) (m : Nat) : BitVec (w + m) :=
let shiftLeftLt {x : Nat} (p : x < 2^w) (m : Nat) : x <<< m < 2^(w + m) := by
def shiftLeftZeroExtend (msbs : BitVec w) (m : Nat) : BitVec (w+m) :=
let shiftLeftLt {x : Nat} (p : x < 2^w) (m : Nat) : x <<< m < 2^(w+m) := by
simp [Nat.shiftLeft_eq, Nat.pow_add]
apply Nat.mul_lt_mul_of_pos_right p
exact (Nat.two_pow_pos m)
@@ -502,24 +502,24 @@ instance : Complement (BitVec w) := ⟨.not⟩
/--
Left shift for bit vectors. The low bits are filled with zeros. As a numeric operation, this is
equivalent to `x * 2^s`, modulo `2^n`.
equivalent to `a * 2^s`, modulo `2^n`.
SMT-Lib name: `bvshl` except this operator uses a `Nat` shift value.
-/
protected def shiftLeft (x : BitVec n) (s : Nat) : BitVec n := BitVec.ofNat n (x.toNat <<< s)
protected def shiftLeft (a : BitVec n) (s : Nat) : BitVec n := BitVec.ofNat n (a.toNat <<< s)
instance : HShiftLeft (BitVec w) Nat (BitVec w) := .shiftLeft
/--
(Logical) right shift for bit vectors. The high bits are filled with zeros.
As a numeric operation, this is equivalent to `x / 2^s`, rounding down.
As a numeric operation, this is equivalent to `a / 2^s`, rounding down.
SMT-Lib name: `bvlshr` except this operator uses a `Nat` shift value.
-/
def ushiftRight (x : BitVec n) (s : Nat) : BitVec n :=
(x.toNat >>> s)#'(by
let x, lt := x
def ushiftRight (a : BitVec n) (s : Nat) : BitVec n :=
(a.toNat >>> s)#'(by
let a, lt := a
simp only [BitVec.toNat, Nat.shiftRight_eq_div_pow, Nat.div_lt_iff_lt_mul (Nat.two_pow_pos s)]
rw [Nat.mul_one x]
rw [Nat.mul_one a]
exact Nat.mul_lt_mul_of_lt_of_le' lt (Nat.two_pow_pos s) (Nat.le_refl 1))
instance : HShiftRight (BitVec w) Nat (BitVec w) := .ushiftRight
@@ -527,24 +527,15 @@ instance : HShiftRight (BitVec w) Nat (BitVec w) := ⟨.ushiftRight⟩
/--
Arithmetic right shift for bit vectors. The high bits are filled with the
most-significant bit.
As a numeric operation, this is equivalent to `x.toInt >>> s`.
As a numeric operation, this is equivalent to `a.toInt >>> s`.
SMT-Lib name: `bvashr` except this operator uses a `Nat` shift value.
-/
def sshiftRight (x : BitVec n) (s : Nat) : BitVec n := .ofInt n (x.toInt >>> s)
def sshiftRight (a : BitVec n) (s : Nat) : BitVec n := .ofInt n (a.toInt >>> s)
instance {n} : HShiftLeft (BitVec m) (BitVec n) (BitVec m) := fun x y => x <<< y.toNat
instance {n} : HShiftRight (BitVec m) (BitVec n) (BitVec m) := fun x y => x >>> y.toNat
/--
Arithmetic right shift for bit vectors. The high bits are filled with the
most-significant bit.
As a numeric operation, this is equivalent to `a.toInt >>> s.toNat`.
SMT-Lib name: `bvashr`.
-/
def sshiftRight' (a : BitVec n) (s : BitVec m) : BitVec n := a.sshiftRight s.toNat
/-- Auxiliary function for `rotateLeft`, which does not take into account the case where
the rotation amount is greater than the bitvector width. -/
def rotateLeftAux (x : BitVec w) (n : Nat) : BitVec w :=

View File

@@ -289,18 +289,18 @@ theorem sle_eq_carry (x y : BitVec w) :
A recurrence that describes multiplication as repeated addition.
Is useful for bitblasting multiplication.
-/
def mulRec (x y : BitVec w) (s : Nat) : BitVec w :=
let cur := if y.getLsb s then (x <<< s) else 0
def mulRec (l r : BitVec w) (s : Nat) : BitVec w :=
let cur := if r.getLsb s then (l <<< s) else 0
match s with
| 0 => cur
| s + 1 => mulRec x y s + cur
| s + 1 => mulRec l r s + cur
theorem mulRec_zero_eq (x y : BitVec w) :
mulRec x y 0 = if y.getLsb 0 then x else 0 := by
theorem mulRec_zero_eq (l r : BitVec w) :
mulRec l r 0 = if r.getLsb 0 then l else 0 := by
simp [mulRec]
theorem mulRec_succ_eq (x y : BitVec w) (s : Nat) :
mulRec x y (s + 1) = mulRec x y s + if y.getLsb (s + 1) then (x <<< (s + 1)) else 0 := rfl
theorem mulRec_succ_eq (l r : BitVec w) (s : Nat) :
mulRec l r (s + 1) = mulRec l r s + if r.getLsb (s + 1) then (l <<< (s + 1)) else 0 := rfl
/--
Recurrence lemma: truncating to `i+1` bits and then zero extending to `w`
@@ -326,29 +326,29 @@ theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow (x : BitVec w
by_cases hi : x.getLsb i <;> simp [hi] <;> omega
/--
Recurrence lemma: multiplying `x` with the first `s` bits of `y` is the
same as truncating `y` to `s` bits, then zero extending to the original length,
Recurrence lemma: multiplying `l` with the first `s` bits of `r` is the
same as truncating `r` to `s` bits, then zero extending to the original length,
and performing the multplication. -/
theorem mulRec_eq_mul_signExtend_truncate (x y : BitVec w) (s : Nat) :
mulRec x y s = x * ((y.truncate (s + 1)).zeroExtend w) := by
theorem mulRec_eq_mul_signExtend_truncate (l r : BitVec w) (s : Nat) :
mulRec l r s = l * ((r.truncate (s + 1)).zeroExtend w) := by
induction s
case zero =>
simp only [mulRec_zero_eq, ofNat_eq_ofNat, Nat.reduceAdd]
by_cases y.getLsb 0
case pos hy =>
simp only [hy, reduceIte, truncate, zeroExtend_one_eq_ofBool_getLsb_zero,
ofBool_true, ofNat_eq_ofNat]
by_cases r.getLsb 0
case pos hr =>
simp only [hr, reduceIte, truncate, zeroExtend_one_eq_ofBool_getLsb_zero,
hr, ofBool_true, ofNat_eq_ofNat]
rw [zeroExtend_ofNat_one_eq_ofNat_one_of_lt (by omega)]
simp
case neg hy =>
simp [hy, zeroExtend_one_eq_ofBool_getLsb_zero]
case neg hr =>
simp [hr, zeroExtend_one_eq_ofBool_getLsb_zero]
case succ s' hs =>
rw [mulRec_succ_eq, hs]
have heq :
(if y.getLsb (s' + 1) = true then x <<< (s' + 1) else 0) =
(x * (y &&& (BitVec.twoPow w (s' + 1)))) := by
(if r.getLsb (s' + 1) = true then l <<< (s' + 1) else 0) =
(l * (r &&& (BitVec.twoPow w (s' + 1)))) := by
simp only [ofNat_eq_ofNat, and_twoPow]
by_cases hy : y.getLsb (s' + 1) <;> simp [hy]
by_cases hr : r.getLsb (s' + 1) <;> simp [hr]
rw [heq, BitVec.mul_add, zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow]
theorem getLsb_mul (x y : BitVec w) (i : Nat) :
@@ -429,67 +429,6 @@ theorem shiftLeft_eq_shiftLeftRec (x : BitVec w₁) (y : BitVec w₂) :
· simp [of_length_zero]
· simp [shiftLeftRec_eq]
/- ### Arithmetic shift right (sshiftRight) recurrence -/
/--
`sshiftRightRec x y n` shifts `x` arithmetically/signed to the right by the first `n` bits of `y`.
The theorem `sshiftRight_eq_sshiftRightRec` proves the equivalence of `(x.sshiftRight y)` and `sshiftRightRec`.
Together with equations `sshiftRightRec_zero`, `sshiftRightRec_succ`,
this allows us to unfold `sshiftRight` into a circuit for bitblasting.
-/
def sshiftRightRec (x : BitVec w₁) (y : BitVec w₂) (n : Nat) : BitVec w₁ :=
let shiftAmt := (y &&& (twoPow w₂ n))
match n with
| 0 => x.sshiftRight' shiftAmt
| n + 1 => (sshiftRightRec x y n).sshiftRight' shiftAmt
@[simp]
theorem sshiftRightRec_zero_eq (x : BitVec w₁) (y : BitVec w₂) :
sshiftRightRec x y 0 = x.sshiftRight' (y &&& 1#w₂) := by
simp only [sshiftRightRec, twoPow_zero]
@[simp]
theorem sshiftRightRec_succ_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) :
sshiftRightRec x y (n + 1) = (sshiftRightRec x y n).sshiftRight' (y &&& twoPow w₂ (n + 1)) := by
simp [sshiftRightRec]
/--
If `y &&& z = 0`, `x.sshiftRight (y ||| z) = (x.sshiftRight y).sshiftRight z`.
This follows as `y &&& z = 0` implies `y ||| z = y + z`,
and thus `x.sshiftRight (y ||| z) = x.sshiftRight (y + z) = (x.sshiftRight y).sshiftRight z`.
-/
theorem sshiftRight'_or_of_and_eq_zero {x : BitVec w₁} {y z : BitVec w₂}
(h : y &&& z = 0#w₂) :
x.sshiftRight' (y ||| z) = (x.sshiftRight' y).sshiftRight' z := by
simp [sshiftRight', add_eq_or_of_and_eq_zero _ _ h,
toNat_add_of_and_eq_zero h, sshiftRight_add]
theorem sshiftRightRec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) :
sshiftRightRec x y n = x.sshiftRight' ((y.truncate (n + 1)).zeroExtend w₂) := by
induction n generalizing x y
case zero =>
ext i
simp [twoPow_zero, Nat.reduceAdd, and_one_eq_zeroExtend_ofBool_getLsb, truncate_one]
case succ n ih =>
simp only [sshiftRightRec_succ_eq, and_twoPow, ih]
by_cases h : y.getLsb (n + 1)
· rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true h,
sshiftRight'_or_of_and_eq_zero (by simp), h]
simp
· rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false (i := n + 1)
(by simp [h])]
simp [h]
/--
Show that `x.sshiftRight y` can be written in terms of `sshiftRightRec`.
This can be unfolded in terms of `sshiftRightRec_zero_eq`, `sshiftRightRec_succ_eq` for bitblasting.
-/
theorem sshiftRight_eq_sshiftRightRec (x : BitVec w₁) (y : BitVec w₂) :
(x.sshiftRight' y).getLsb i = (sshiftRightRec x y (w₂ - 1)).getLsb i := by
rcases w₂ with rfl | w₂
· simp [of_length_zero]
· simp [sshiftRightRec_eq]
/- ### Logical shift right (ushiftRight) recurrence for bitblasting -/
/--

View File

@@ -23,7 +23,7 @@ theorem ofFin_eq_ofNat : @BitVec.ofFin w (Fin.mk x lt) = BitVec.ofNat w x := by
simp only [BitVec.ofNat, Fin.ofNat', lt, Nat.mod_eq_of_lt]
/-- Prove equality of bitvectors in terms of nat operations. -/
theorem eq_of_toNat_eq {n} : {x y : BitVec n}, x.toNat = y.toNat x = y
theorem eq_of_toNat_eq {n} : {i j : BitVec n}, i.toNat = j.toNat i = j
| _, _, _, _, rfl => rfl
@[simp] theorem val_toFin (x : BitVec w) : x.toFin.val = x.toNat := rfl
@@ -228,12 +228,12 @@ theorem toNat_ge_of_msb_true {x : BitVec n} (p : BitVec.msb x = true) : x.toNat
/-! ### toInt/ofInt -/
/-- Prove equality of bitvectors in terms of nat operations. -/
theorem toInt_eq_toNat_cond (x : BitVec n) :
x.toInt =
if 2*x.toNat < 2^n then
(x.toNat : Int)
theorem toInt_eq_toNat_cond (i : BitVec n) :
i.toInt =
if 2*i.toNat < 2^n then
(i.toNat : Int)
else
(x.toNat : Int) - (2^n : Nat) :=
(i.toNat : Int) - (2^n : Nat) :=
rfl
theorem msb_eq_false_iff_two_mul_lt (x : BitVec w) : x.msb = false 2 * x.toNat < 2^w := by
@@ -260,13 +260,13 @@ theorem toInt_eq_toNat_bmod (x : BitVec n) : x.toInt = Int.bmod x.toNat (2^n) :=
omega
/-- Prove equality of bitvectors in terms of nat operations. -/
theorem eq_of_toInt_eq {x y : BitVec n} : x.toInt = y.toInt x = y := by
theorem eq_of_toInt_eq {i j : BitVec n} : i.toInt = j.toInt i = j := by
intro eq
simp [toInt_eq_toNat_cond] at eq
apply eq_of_toNat_eq
revert eq
have _xlt := x.isLt
have _ylt := y.isLt
have _ilt := i.isLt
have _jlt := j.isLt
split <;> split <;> omega
theorem toInt_inj (x y : BitVec n) : x.toInt = y.toInt x = y :=
@@ -507,13 +507,6 @@ theorem or_assoc (x y z : BitVec w) :
x ||| y ||| z = x ||| (y ||| z) := by
ext i
simp [Bool.or_assoc]
instance : Std.Associative (α := BitVec n) (· ||| ·) := BitVec.or_assoc
theorem or_comm (x y : BitVec w) :
x ||| y = y ||| x := by
ext i
simp [Bool.or_comm]
instance : Std.Commutative (fun (x y : BitVec w) => x ||| y) := BitVec.or_comm
/-! ### and -/
@@ -545,13 +538,11 @@ theorem and_assoc (x y z : BitVec w) :
x &&& y &&& z = x &&& (y &&& z) := by
ext i
simp [Bool.and_assoc]
instance : Std.Associative (α := BitVec n) (· &&& ·) := BitVec.and_assoc
theorem and_comm (x y : BitVec w) :
x &&& y = y &&& x := by
ext i
simp [Bool.and_comm]
instance : Std.Commutative (fun (x y : BitVec w) => x &&& y) := BitVec.and_comm
/-! ### xor -/
@@ -577,13 +568,6 @@ theorem xor_assoc (x y z : BitVec w) :
x ^^^ y ^^^ z = x ^^^ (y ^^^ z) := by
ext i
simp [Bool.xor_assoc]
instance : Std.Associative (fun (x y : BitVec w) => x ^^^ y) := BitVec.xor_assoc
theorem xor_comm (x y : BitVec w) :
x ^^^ y = y ^^^ x := by
ext i
simp [Bool.xor_comm]
instance : Std.Commutative (fun (x y : BitVec w) => x ^^^ y) := BitVec.xor_comm
/-! ### not -/
@@ -749,21 +733,6 @@ theorem getLsb_shiftLeft' {x : BitVec w₁} {y : BitVec w₂} {i : Nat} :
getLsb (x >>> i) j = getLsb x (i+j) := by
unfold getLsb ; simp
theorem ushiftRight_xor_distrib (x y : BitVec w) (n : Nat) :
(x ^^^ y) >>> n = (x >>> n) ^^^ (y >>> n) := by
ext
simp
theorem ushiftRight_and_distrib (x y : BitVec w) (n : Nat) :
(x &&& y) >>> n = (x >>> n) &&& (y >>> n) := by
ext
simp
theorem ushiftRight_or_distrib (x y : BitVec w) (n : Nat) :
(x ||| y) >>> n = (x >>> n) ||| (y >>> n) := by
ext
simp
@[simp]
theorem ushiftRight_zero_eq (x : BitVec w) : x >>> 0 = x := by
simp [bv_toNat]
@@ -817,7 +786,7 @@ theorem sshiftRight_eq_of_msb_true {x : BitVec w} {s : Nat} (h : x.msb = true) :
· rw [Nat.shiftRight_eq_div_pow]
apply Nat.lt_of_le_of_lt (Nat.div_le_self _ _) (by omega)
@[simp] theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) :
theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) :
getLsb (x.sshiftRight s) i =
(!decide (w i) && if s + i < w then x.getLsb (s + i) else x.msb) := by
rcases hmsb : x.msb with rfl | rfl
@@ -838,41 +807,6 @@ theorem sshiftRight_eq_of_msb_true {x : BitVec w} {s : Nat} (h : x.msb = true) :
Nat.not_lt, decide_eq_true_eq]
omega
/-- The msb after arithmetic shifting right equals the original msb. -/
theorem sshiftRight_msb_eq_msb {n : Nat} {x : BitVec w} :
(x.sshiftRight n).msb = x.msb := by
rw [msb_eq_getLsb_last, getLsb_sshiftRight, msb_eq_getLsb_last]
by_cases hw₀ : w = 0
· simp [hw₀]
· simp only [show ¬(w w - 1) by omega, decide_False, Bool.not_false, Bool.true_and,
ite_eq_right_iff]
intros h
simp [show n = 0 by omega]
@[simp] theorem sshiftRight_zero {x : BitVec w} : x.sshiftRight 0 = x := by
ext i
simp
theorem sshiftRight_add {x : BitVec w} {m n : Nat} :
x.sshiftRight (m + n) = (x.sshiftRight m).sshiftRight n := by
ext i
simp only [getLsb_sshiftRight, Nat.add_assoc]
by_cases h₁ : w (i : Nat)
· simp [h₁]
· simp only [h₁, decide_False, Bool.not_false, Bool.true_and]
by_cases h₂ : n + i < w
· simp [h₂]
· simp only [h₂, reduceIte]
by_cases h₃ : m + (n + i) < w
· simp [h₃]
omega
· simp [h₃, sshiftRight_msb_eq_msb]
/-! ### sshiftRight reductions from BitVec to Nat -/
@[simp]
theorem sshiftRight_eq' (x : BitVec w) : x.sshiftRight' y = x.sshiftRight y.toNat := rfl
/-! ### signExtend -/
/-- Equation theorem for `Int.sub` when both arguments are `Int.ofNat` -/
@@ -935,15 +869,15 @@ theorem append_def (x : BitVec v) (y : BitVec w) :
(x ++ y).toNat = x.toNat <<< n ||| y.toNat :=
rfl
@[simp] theorem getLsb_append {x : BitVec n} {y : BitVec m} :
getLsb (x ++ y) i = bif i < m then getLsb y i else getLsb x (i - m) := by
@[simp] theorem getLsb_append {v : BitVec n} {w : BitVec m} :
getLsb (v ++ w) i = bif i < m then getLsb w i else getLsb v (i - m) := by
simp only [append_def, getLsb_or, getLsb_shiftLeftZeroExtend, getLsb_zeroExtend']
by_cases h : i < m
· simp [h]
· simp [h]; simp_all
@[simp] theorem getMsb_append {x : BitVec n} {y : BitVec m} :
getMsb (x ++ y) i = bif n i then getMsb y (i - n) else getMsb x i := by
@[simp] theorem getMsb_append {v : BitVec n} {w : BitVec m} :
getMsb (v ++ w) i = bif n i then getMsb w (i - n) else getMsb v i := by
simp [append_def]
by_cases h : n i
· simp [h]

View File

@@ -438,24 +438,6 @@ Added for confluence between `if_true_left` and `ite_false_same` on
-/
@[simp] theorem eq_true_imp_eq_false : (b:Bool), (b = true b = false) (b = false) := by decide
/-! ### forall -/
theorem forall_bool' {p : Bool Prop} (b : Bool) : ( x, p x) p b p !b :=
fun h h _, h _, fun h₁, h₂ x by cases b <;> cases x <;> assumption
@[simp]
theorem forall_bool {p : Bool Prop} : ( b, p b) p false p true :=
forall_bool' false
/-! ### exists -/
theorem exists_bool' {p : Bool Prop} (b : Bool) : ( x, p x) p b p !b :=
fun x, hx by cases x <;> cases b <;> first | exact .inl _ | exact .inr _,
fun h by cases h <;> exact _, _
@[simp]
theorem exists_bool {p : Bool Prop} : ( b, p b) p false p true :=
exists_bool' false
/-! ### cond -/

View File

@@ -354,7 +354,7 @@ theorem erase_eq_iff [LawfulBEq α] {a : α} {l : List α} :
rw [erase_of_not_mem]
simp_all
theorem Nodup.erase_eq_filter [LawfulBEq α] {l} (d : Nodup l) (a : α) : l.erase a = l.filter (· != a) := by
theorem Nodup.erase_eq_filter [BEq α] [LawfulBEq α] {l} (d : Nodup l) (a : α) : l.erase a = l.filter (· != a) := by
induction d with
| nil => rfl
| cons m _n ih =>
@@ -367,13 +367,13 @@ theorem Nodup.erase_eq_filter [LawfulBEq α] {l} (d : Nodup l) (a : α) : l.eras
simpa [@eq_comm α] using m
· simp [beq_false_of_ne h, ih, h]
theorem Nodup.mem_erase_iff [LawfulBEq α] {a : α} (d : Nodup l) : a l.erase b a b a l := by
theorem Nodup.mem_erase_iff [BEq α] [LawfulBEq α] {a : α} (d : Nodup l) : a l.erase b a b a l := by
rw [Nodup.erase_eq_filter d, mem_filter, and_comm, bne_iff_ne]
theorem Nodup.not_mem_erase [LawfulBEq α] {a : α} (h : Nodup l) : a l.erase a := fun H => by
theorem Nodup.not_mem_erase [BEq α] [LawfulBEq α] {a : α} (h : Nodup l) : a l.erase a := fun H => by
simpa using ((Nodup.mem_erase_iff h).mp H).left
theorem Nodup.erase [LawfulBEq α] (a : α) : Nodup l Nodup (l.erase a) :=
theorem Nodup.erase [BEq α] [LawfulBEq α] (a : α) : Nodup l Nodup (l.erase a) :=
Nodup.sublist <| erase_sublist _ _
end erase

View File

@@ -5,18 +5,9 @@ Author: Leonardo de Moura
-/
prelude
import Init.SimpLemmas
import Init.NotationExtra
instance [BEq α] [BEq β] [LawfulBEq α] [LawfulBEq β] : LawfulBEq (α × β) where
eq_of_beq {a b} (h : a.1 == b.1 && a.2 == b.2) := by
cases a; cases b
refine congr (congrArg _ (eq_of_beq ?_)) (eq_of_beq ?_) <;> simp_all
rfl {a} := by cases a; simp [BEq.beq, LawfulBEq.rfl]
@[simp]
protected theorem Prod.forall {p : α × β Prop} : ( x, p x) a b, p (a, b) :=
fun h a b h (a, b), fun h a, b h a b
@[simp]
protected theorem Prod.exists {p : α × β Prop} : ( x, p x) a b, p (a, b) :=
fun a, b, h a, b, h, fun a, b, h a, b, h

View File

@@ -78,10 +78,7 @@ end Elab.Tactic.Ext
end Lean
attribute [ext] Prod PProd Sigma PSigma
attribute [ext] funext propext Subtype.eq Array.ext
attribute [ext] funext propext Subtype.eq
@[ext] protected theorem PUnit.ext (x y : PUnit) : x = y := rfl
protected theorem Unit.ext (x y : Unit) : x = y := rfl
@[ext] protected theorem Thunk.ext : {a b : Thunk α} a.get = b.get a = b
| {..}, {..}, heq => congrArg _ <| funext fun _ => heq

View File

@@ -470,23 +470,31 @@ def withFile (fn : FilePath) (mode : Mode) (f : Handle → IO α) : IO α :=
def Handle.putStrLn (h : Handle) (s : String) : IO Unit :=
h.putStr (s.push '\n')
partial def Handle.readBinToEndInto (h : Handle) (buf : ByteArray) : IO ByteArray := do
partial def Handle.readBinToEnd (h : Handle) : IO ByteArray := do
let rec loop (acc : ByteArray) : IO ByteArray := do
let buf h.read 1024
if buf.isEmpty then
return acc
else
loop (acc ++ buf)
loop buf
loop ByteArray.empty
partial def Handle.readBinToEnd (h : Handle) : IO ByteArray := do
h.readBinToEndInto .empty
partial def Handle.readToEnd (h : Handle) : IO String := do
let rec loop (s : String) := do
let line h.getLine
if line.isEmpty then
return s
else
loop (s ++ line)
loop ""
def Handle.readToEnd (h : Handle) : IO String := do
let data h.readBinToEnd
match String.fromUTF8? data with
| some s => return s
| none => throw <| .userError s!"Tried to read from handle containing non UTF-8 data."
def readBinFile (fname : FilePath) : IO ByteArray := do
let h Handle.mk fname Mode.read
h.readBinToEnd
def readFile (fname : FilePath) : IO String := do
let h Handle.mk fname Mode.read
h.readToEnd
partial def lines (fname : FilePath) : IO (Array String) := do
let h Handle.mk fname Mode.read
@@ -592,28 +600,6 @@ end System.FilePath
namespace IO
namespace FS
def readBinFile (fname : FilePath) : IO ByteArray := do
-- Requires metadata so defined after metadata
let mdata fname.metadata
let size := mdata.byteSize.toUSize
let handle IO.FS.Handle.mk fname .read
let buf
if size > 0 then
handle.read mdata.byteSize.toUSize
else
pure <| ByteArray.mkEmpty 0
handle.readBinToEndInto buf
def readFile (fname : FilePath) : IO String := do
let data readBinFile fname
match String.fromUTF8? data with
| some s => return s
| none => throw <| .userError s!"Tried to read file '{fname}' containing non UTF-8 data."
end FS
def withStdin [Monad m] [MonadFinally m] [MonadLiftT BaseIO m] (h : FS.Stream) (x : m α) : m α := do
let prev setStdin h
try x finally discard <| setStdin prev

View File

@@ -68,7 +68,6 @@ noncomputable def recursion {C : α → Sort v} (a : α) (h : ∀ x, (∀ y, r y
induction (apply hwf a) with
| intro x₁ _ ih => exact h x₁ ih
include hwf in
theorem induction {C : α Prop} (a : α) (h : x, ( y, r y x C y) C x) : C a :=
recursion hwf a h

View File

@@ -53,7 +53,7 @@ structure AttributeImpl extends AttributeImplCore where
erase (decl : Name) : AttrM Unit := throwError "attribute cannot be erased"
deriving Inhabited
builtin_initialize attributeMapRef : IO.Ref (Std.HashMap Name AttributeImpl) IO.mkRef {}
builtin_initialize attributeMapRef : IO.Ref (HashMap Name AttributeImpl) IO.mkRef {}
/-- Low level attribute registration function. -/
def registerBuiltinAttribute (attr : AttributeImpl) : IO Unit := do
@@ -296,7 +296,7 @@ end EnumAttributes
-/
abbrev AttributeImplBuilder := Name List DataValue Except String AttributeImpl
abbrev AttributeImplBuilderTable := Std.HashMap Name AttributeImplBuilder
abbrev AttributeImplBuilderTable := HashMap Name AttributeImplBuilder
builtin_initialize attributeImplBuilderTableRef : IO.Ref AttributeImplBuilderTable IO.mkRef {}
@@ -307,7 +307,7 @@ def registerAttributeImplBuilder (builderId : Name) (builder : AttributeImplBuil
def mkAttributeImplOfBuilder (builderId ref : Name) (args : List DataValue) : IO AttributeImpl := do
let table attributeImplBuilderTableRef.get
match table[builderId]? with
match table.find? builderId with
| none => throw (IO.userError ("unknown attribute implementation builder '" ++ toString builderId ++ "'"))
| some builder => IO.ofExcept <| builder ref args
@@ -317,7 +317,7 @@ inductive AttributeExtensionOLeanEntry where
structure AttributeExtensionState where
newEntries : List AttributeExtensionOLeanEntry := []
map : Std.HashMap Name AttributeImpl
map : HashMap Name AttributeImpl
deriving Inhabited
abbrev AttributeExtension := PersistentEnvExtension AttributeExtensionOLeanEntry (AttributeExtensionOLeanEntry × AttributeImpl) AttributeExtensionState
@@ -348,7 +348,7 @@ private def AttributeExtension.addImported (es : Array (Array AttributeExtension
let map es.foldlM
(fun map entries =>
entries.foldlM
(fun (map : Std.HashMap Name AttributeImpl) entry => do
(fun (map : HashMap Name AttributeImpl) entry => do
let attrImpl mkAttributeImplOfEntry ctx.env ctx.opts entry
return map.insert attrImpl.name attrImpl)
map)
@@ -378,7 +378,7 @@ def getBuiltinAttributeNames : IO (List Name) :=
def getBuiltinAttributeImpl (attrName : Name) : IO AttributeImpl := do
let m attributeMapRef.get
match m[attrName]? with
match m.find? attrName with
| some attr => pure attr
| none => throw (IO.userError ("unknown attribute '" ++ toString attrName ++ "'"))
@@ -396,7 +396,7 @@ def getAttributeNames (env : Environment) : List Name :=
def getAttributeImpl (env : Environment) (attrName : Name) : Except String AttributeImpl :=
let m := (attributeExtension.getState env).map
match m[attrName]? with
match m.find? attrName with
| some attr => pure attr
| none => throw ("unknown attribute '" ++ toString attrName ++ "'")

View File

@@ -26,9 +26,9 @@ instance : Hashable Key := ⟨getHash⟩
end OwnedSet
open OwnedSet (Key) in
abbrev OwnedSet := Std.HashMap Key Unit
def OwnedSet.insert (s : OwnedSet) (k : OwnedSet.Key) : OwnedSet := Std.HashMap.insert s k ()
def OwnedSet.contains (s : OwnedSet) (k : OwnedSet.Key) : Bool := Std.HashMap.contains s k
abbrev OwnedSet := HashMap Key Unit
def OwnedSet.insert (s : OwnedSet) (k : OwnedSet.Key) : OwnedSet := HashMap.insert s k ()
def OwnedSet.contains (s : OwnedSet) (k : OwnedSet.Key) : Bool := HashMap.contains s k
/-! We perform borrow inference in a block of mutually recursive functions.
Join points are viewed as local functions, and are identified using
@@ -49,7 +49,7 @@ instance : Hashable Key := ⟨getHash⟩
end ParamMap
open ParamMap (Key)
abbrev ParamMap := Std.HashMap Key (Array Param)
abbrev ParamMap := HashMap Key (Array Param)
def ParamMap.fmt (map : ParamMap) : Format :=
let fmts := map.fold (fun fmt k ps =>
@@ -109,7 +109,7 @@ partial def visitFnBody (fn : FunId) (paramMap : ParamMap) : FnBody → FnBody
| FnBody.jdecl j _ v b =>
let v := visitFnBody fn paramMap v
let b := visitFnBody fn paramMap b
match paramMap[ParamMap.Key.jp fn j]? with
match paramMap.find? (ParamMap.Key.jp fn j) with
| some ys => FnBody.jdecl j ys v b
| none => unreachable!
| FnBody.case tid x xType alts =>
@@ -125,7 +125,7 @@ def visitDecls (decls : Array Decl) (paramMap : ParamMap) : Array Decl :=
decls.map fun decl => match decl with
| Decl.fdecl f _ ty b info =>
let b := visitFnBody f paramMap b
match paramMap[ParamMap.Key.decl f]? with
match paramMap.find? (ParamMap.Key.decl f) with
| some xs => Decl.fdecl f xs ty b info
| none => unreachable!
| other => other
@@ -178,7 +178,7 @@ def isOwned (x : VarId) : M Bool := do
/-- Updates `map[k]` using the current set of `owned` variables. -/
def updateParamMap (k : ParamMap.Key) : M Unit := do
let s get
match s.paramMap[k]? with
match s.paramMap.find? k with
| some ps => do
let ps ps.mapM fun (p : Param) => do
if !p.borrow then pure p
@@ -192,7 +192,7 @@ def updateParamMap (k : ParamMap.Key) : M Unit := do
def getParamInfo (k : ParamMap.Key) : M (Array Param) := do
let s get
match s.paramMap[k]? with
match s.paramMap.find? k with
| some ps => pure ps
| none =>
match k with

View File

@@ -11,7 +11,6 @@ import Lean.Compiler.IR.Basic
import Lean.Compiler.IR.CompilerM
import Lean.Compiler.IR.FreeVars
import Lean.Compiler.IR.ElimDeadVars
import Lean.Data.AssocList
namespace Lean.IR.ExplicitBoxing
/-!

View File

@@ -152,7 +152,7 @@ def getFunctionSummary? (env : Environment) (fid : FunId) : Option Value :=
| some modIdx => findAtSorted? (functionSummariesExt.getModuleEntries env modIdx) fid
| none => functionSummariesExt.getState env |>.find? fid
abbrev Assignment := Std.HashMap VarId Value
abbrev Assignment := HashMap VarId Value
structure InterpContext where
currFnIdx : Nat := 0
@@ -172,7 +172,7 @@ def findVarValue (x : VarId) : M Value := do
let ctx read
let s get
let assignment := s.assignments[ctx.currFnIdx]!
return assignment.getD x bot
return assignment.findD x bot
def findArgValue (arg : Arg) : M Value :=
match arg with
@@ -303,7 +303,7 @@ partial def elimDeadAux (assignment : Assignment) : FnBody → FnBody
| FnBody.vdecl x t e b => FnBody.vdecl x t e (elimDeadAux assignment b)
| FnBody.jdecl j ys v b => FnBody.jdecl j ys (elimDeadAux assignment v) (elimDeadAux assignment b)
| FnBody.case tid x xType alts =>
let v := assignment.getD x bot
let v := assignment.findD x bot
let alts := alts.map fun alt =>
match alt with
| Alt.ctor i b => Alt.ctor i <| if containsCtor v i then elimDeadAux assignment b else FnBody.unreachable

View File

@@ -96,13 +96,8 @@ def shouldExport (n : Name) : Bool :=
-- libleanshared to avoid Windows symbol limit
!(`Lean.Compiler.LCNF).isPrefixOf n &&
!(`Lean.IR).isPrefixOf n &&
-- Lean.Server.findModuleRefs is used in SubVerso, and the contents of RequestM are used by the
-- full Verso as well as anything else that extends the LSP server.
(!(`Lean.Server.Watchdog).isPrefixOf n) &&
(!(`Lean.Server.ImportCompletion).isPrefixOf n) &&
(!(`Lean.Server.Completion).isPrefixOf n)
-- Lean.Server.findModuleRefs is used in Verso
(!(`Lean.Server).isPrefixOf n || n == `Lean.Server.findModuleRefs)
def emitFnDeclAux (decl : Decl) (cppBaseName : String) (isExternal : Bool) : M Unit := do
let ps := decl.params
@@ -257,7 +252,7 @@ def throwUnknownVar {α : Type} (x : VarId) : M α :=
def getJPParams (j : JoinPointId) : M (Array Param) := do
let ctx read;
match ctx.jpMap[j]? with
match ctx.jpMap.find? j with
| some ps => pure ps
| none => throw "unknown join point"

View File

@@ -65,8 +65,8 @@ structure Context (llvmctx : LLVM.Context) where
llvmmodule : LLVM.Module llvmctx
structure State (llvmctx : LLVM.Context) where
var2val : Std.HashMap VarId (LLVM.LLVMType llvmctx × LLVM.Value llvmctx)
jp2bb : Std.HashMap JoinPointId (LLVM.BasicBlock llvmctx)
var2val : HashMap VarId (LLVM.LLVMType llvmctx × LLVM.Value llvmctx)
jp2bb : HashMap JoinPointId (LLVM.BasicBlock llvmctx)
abbrev Error := String
@@ -84,7 +84,7 @@ def addJpTostate (jp : JoinPointId) (bb : LLVM.BasicBlock llvmctx) : M llvmctx U
def emitJp (jp : JoinPointId) : M llvmctx (LLVM.BasicBlock llvmctx) := do
let state get
match state.jp2bb[jp]? with
match state.jp2bb.find? jp with
| .some bb => return bb
| .none => throw s!"unable to find join point {jp}"
@@ -531,7 +531,7 @@ def emitFnDecls : M llvmctx Unit := do
def emitLhsSlot_ (x : VarId) : M llvmctx (LLVM.LLVMType llvmctx × LLVM.Value llvmctx) := do
let state get
match state.var2val[x]? with
match state.var2val.find? x with
| .some v => return v
| .none => throw s!"unable to find variable {x}"
@@ -1029,7 +1029,7 @@ def emitTailCall (builder : LLVM.Builder llvmctx) (f : FunId) (v : Expr) : M llv
def emitJmp (builder : LLVM.Builder llvmctx) (jp : JoinPointId) (xs : Array Arg) : M llvmctx Unit := do
let llvmctx read
let ps match llvmctx.jpMap[jp]? with
let ps match llvmctx.jpMap.find? jp with
| some ps => pure ps
| none => throw s!"Unknown join point {jp}"
unless xs.size == ps.size do throw s!"Invalid goto, mismatched sizes between arguments, formal parameters."

View File

@@ -51,8 +51,8 @@ end CollectUsedDecls
def collectUsedDecls (env : Environment) (decl : Decl) (used : NameSet := {}) : NameSet :=
(CollectUsedDecls.collectDecl decl env).run' used
abbrev VarTypeMap := Std.HashMap VarId IRType
abbrev JPParamsMap := Std.HashMap JoinPointId (Array Param)
abbrev VarTypeMap := HashMap VarId IRType
abbrev JPParamsMap := HashMap JoinPointId (Array Param)
namespace CollectMaps
abbrev Collector := (VarTypeMap × JPParamsMap) (VarTypeMap × JPParamsMap)

View File

@@ -10,7 +10,7 @@ import Lean.Compiler.IR.FreeVars
namespace Lean.IR.ExpandResetReuse
/-- Mapping from variable to projections -/
abbrev ProjMap := Std.HashMap VarId Expr
abbrev ProjMap := HashMap VarId Expr
namespace CollectProjMap
abbrev Collector := ProjMap ProjMap
@[inline] def collectVDecl (x : VarId) (v : Expr) : Collector := fun m =>
@@ -148,20 +148,20 @@ def setFields (y : VarId) (zs : Array Arg) (b : FnBody) : FnBody :=
def isSelfSet (ctx : Context) (x : VarId) (i : Nat) (y : Arg) : Bool :=
match y with
| Arg.var y =>
match ctx.projMap[y]? with
match ctx.projMap.find? y with
| some (Expr.proj j w) => j == i && w == x
| _ => false
| _ => false
/-- Given `uset x[i] := y`, return true iff `y := uproj[i] x` -/
def isSelfUSet (ctx : Context) (x : VarId) (i : Nat) (y : VarId) : Bool :=
match ctx.projMap[y]? with
match ctx.projMap.find? y with
| some (Expr.uproj j w) => j == i && w == x
| _ => false
/-- Given `sset x[n, i] := y`, return true iff `y := sproj[n, i] x` -/
def isSelfSSet (ctx : Context) (x : VarId) (n : Nat) (i : Nat) (y : VarId) : Bool :=
match ctx.projMap[y]? with
match ctx.projMap.find? y with
| some (Expr.sproj m j w) => n == m && j == i && w == x
| _ => false

View File

@@ -64,34 +64,34 @@ instance : AddMessageContext CompilerM where
def getType (fvarId : FVarId) : CompilerM Expr := do
let lctx := ( get).lctx
if let some decl := lctx.letDecls[fvarId]? then
if let some decl := lctx.letDecls.find? fvarId then
return decl.type
else if let some decl := lctx.params[fvarId]? then
else if let some decl := lctx.params.find? fvarId then
return decl.type
else if let some decl := lctx.funDecls[fvarId]? then
else if let some decl := lctx.funDecls.find? fvarId then
return decl.type
else
throwError "unknown free variable {fvarId.name}"
def getBinderName (fvarId : FVarId) : CompilerM Name := do
let lctx := ( get).lctx
if let some decl := lctx.letDecls[fvarId]? then
if let some decl := lctx.letDecls.find? fvarId then
return decl.binderName
else if let some decl := lctx.params[fvarId]? then
else if let some decl := lctx.params.find? fvarId then
return decl.binderName
else if let some decl := lctx.funDecls[fvarId]? then
else if let some decl := lctx.funDecls.find? fvarId then
return decl.binderName
else
throwError "unknown free variable {fvarId.name}"
def findParam? (fvarId : FVarId) : CompilerM (Option Param) :=
return ( get).lctx.params[fvarId]?
return ( get).lctx.params.find? fvarId
def findLetDecl? (fvarId : FVarId) : CompilerM (Option LetDecl) :=
return ( get).lctx.letDecls[fvarId]?
return ( get).lctx.letDecls.find? fvarId
def findFunDecl? (fvarId : FVarId) : CompilerM (Option FunDecl) :=
return ( get).lctx.funDecls[fvarId]?
return ( get).lctx.funDecls.find? fvarId
def findLetValue? (fvarId : FVarId) : CompilerM (Option LetValue) := do
let some { value, .. } findLetDecl? fvarId | return none
@@ -166,7 +166,7 @@ it is a free variable, a type (or type former), or `lcErased`.
`Check.lean` contains a substitution validator.
-/
abbrev FVarSubst := Std.HashMap FVarId Expr
abbrev FVarSubst := HashMap FVarId Expr
/--
Replace the free variables in `e` using the given substitution.
@@ -190,7 +190,7 @@ where
go (e : Expr) : Expr :=
if e.hasFVar then
match e with
| .fvar fvarId => match s[fvarId]? with
| .fvar fvarId => match s.find? fvarId with
| some e => if translator then e else go e
| none => e
| .lit .. | .const .. | .sort .. | .mvar .. | .bvar .. => e
@@ -224,7 +224,7 @@ That is, it is not a type (or type former), nor `lcErased`. Recall that a valid
expressions that are free variables, `lcErased`, or type formers.
-/
private partial def normFVarImp (s : FVarSubst) (fvarId : FVarId) (translator : Bool) : NormFVarResult :=
match s[fvarId]? with
match s.find? fvarId with
| some (.fvar fvarId') =>
if translator then
.fvar fvarId'
@@ -246,7 +246,7 @@ private partial def normArgImp (s : FVarSubst) (arg : Arg) (translator : Bool) :
match arg with
| .erased => arg
| .fvar fvarId =>
match s[fvarId]? with
match s.find? fvarId with
| some (.fvar fvarId') =>
let arg' := .fvar fvarId'
if translator then arg' else normArgImp s arg' translator

View File

@@ -268,7 +268,7 @@ def getFunctionSummary? (env : Environment) (fid : Name) : Option Value :=
A map from variable identifiers to the `Value` produced by the abstract
interpreter for them.
-/
abbrev Assignment := Std.HashMap FVarId Value
abbrev Assignment := HashMap FVarId Value
/--
The context of `InterpM`.
@@ -332,7 +332,7 @@ If none is available return `Value.bot`.
-/
def findVarValue (var : FVarId) : InterpM Value := do
let assignment getAssignment
return assignment.getD var .bot
return assignment.findD var .bot
/--
Find the value of `arg` using the logic of `findVarValue`.
@@ -547,13 +547,13 @@ where
| .jp decl k | .fun decl k =>
return code.updateFun! ( decl.updateValue ( go decl.value)) ( go k)
| .cases cs =>
let discrVal := assignment.getD cs.discr .bot
let discrVal := assignment.findD cs.discr .bot
let processAlt typ alt := do
match alt with
| .alt ctor args body =>
if discrVal.containsCtor ctor then
let filter param := do
if let some val := assignment[param.fvarId]? then
if let some val := assignment.find? param.fvarId then
if let some literal val.getLiteral then
return some (param, literal)
return none

View File

@@ -62,7 +62,7 @@ structure State where
Whenever there is function application `f a₁ ... aₙ`, where `f` is in `decls`, `f` is not `main`, and
we visit with the abstract values assigned to `aᵢ`, but first we record the visit here.
-/
visited : Std.HashSet (Name × Array AbsValue) := {}
visited : HashSet (Name × Array AbsValue) := {}
/--
Bitmask containing the result, i.e., which parameters of `main` are fixed.
We initialize it with `true` everywhere.

View File

@@ -59,7 +59,7 @@ structure FloatState where
/--
A map from identifiers of declarations to their current decision.
-/
decision : Std.HashMap FVarId Decision
decision : HashMap FVarId Decision
/--
A map from decisions (excluding `unknown`) to the declarations with
these decisions (in correct order). Basically:
@@ -67,7 +67,7 @@ structure FloatState where
- Which declarations do we move into a certain arm
- Which declarations do we move into the default arm
-/
newArms : Std.HashMap Decision (List CodeDecl)
newArms : HashMap Decision (List CodeDecl)
/--
Use to collect relevant declarations for the floating mechanism.
@@ -116,8 +116,8 @@ up to this point, with respect to `cs`. The initial decisions are:
- `arm` or `default` if we see the declaration only being used in exactly one cases arm
- `unknown` otherwise
-/
def initialDecisions (cs : Cases) : BaseFloatM (Std.HashMap FVarId Decision) := do
let mut map := Std.HashMap.empty ( read).decls.length
def initialDecisions (cs : Cases) : BaseFloatM (HashMap FVarId Decision) := do
let mut map := mkHashMap ( read).decls.length
let folder val acc := do
if let .let decl := val then
if ( ignore? decl) then
@@ -130,25 +130,25 @@ def initialDecisions (cs : Cases) : BaseFloatM (Std.HashMap FVarId Decision) :=
(_, map) goCases cs |>.run map
return map
where
goFVar (plannedDecision : Decision) (var : FVarId) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit := do
if let some decision := ( get)[var]? then
goFVar (plannedDecision : Decision) (var : FVarId) : StateRefT (HashMap FVarId Decision) BaseFloatM Unit := do
if let some decision := ( get).find? var then
if decision == .unknown then
modify fun s => s.insert var plannedDecision
else if decision != plannedDecision then
modify fun s => s.insert var .dont
-- otherwise we already have the proper decision
goAlt (alt : Alt) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit :=
goAlt (alt : Alt) : StateRefT (HashMap FVarId Decision) BaseFloatM Unit :=
forFVarM (goFVar (.ofAlt alt)) alt
goCases (cs : Cases) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit :=
goCases (cs : Cases) : StateRefT (HashMap FVarId Decision) BaseFloatM Unit :=
cs.alts.forM goAlt
/--
Compute the initial new arms. This will just set up a map from all arms of
`cs` to empty `Array`s, plus one additional entry for `dont`.
-/
def initialNewArms (cs : Cases) : Std.HashMap Decision (List CodeDecl) := Id.run do
let mut map := Std.HashMap.empty (cs.alts.size + 1)
def initialNewArms (cs : Cases) : HashMap Decision (List CodeDecl) := Id.run do
let mut map := mkHashMap (cs.alts.size + 1)
map := map.insert .dont []
cs.alts.foldr (init := map) fun val acc => acc.insert (.ofAlt val) []
@@ -170,7 +170,7 @@ respectively but since `z` can't be moved we don't want that to move `x` and `y`
-/
def dontFloat (decl : CodeDecl) : FloatM Unit := do
forFVarM goFVar decl
modify fun s => { s with newArms := s.newArms.insert .dont (decl :: s.newArms[Decision.dont]!) }
modify fun s => { s with newArms := s.newArms.insert .dont (decl :: s.newArms.find! .dont) }
where
goFVar (fvar : FVarId) : FloatM Unit := do
if ( get).decision.contains fvar then
@@ -223,12 +223,12 @@ Will:
If we are at `y` `x` is still marked to be moved but we don't want that.
-/
def float (decl : CodeDecl) : FloatM Unit := do
let arm := ( get).decision[decl.fvarId]!
let arm := ( get).decision.find! decl.fvarId
forFVarM (goFVar · arm) decl
modify fun s => { s with newArms := s.newArms.insert arm (decl :: s.newArms[arm]!) }
modify fun s => { s with newArms := s.newArms.insert arm (decl :: s.newArms.find! arm) }
where
goFVar (fvar : FVarId) (arm : Decision) : FloatM Unit := do
let some decision := ( get).decision[fvar]? | return ()
let some decision := ( get).decision.find? fvar | return ()
if decision != arm then
modify fun s => { s with decision := s.decision.insert fvar .dont }
else if decision == .unknown then
@@ -249,7 +249,7 @@ where
-/
goCases : FloatM Unit := do
for decl in ( read).decls do
let currentDecision := ( get).decision[decl.fvarId]!
let currentDecision := ( get).decision.find! decl.fvarId
if currentDecision == .unknown then
/-
If the decision is still unknown by now this means `decl` is
@@ -284,10 +284,10 @@ where
newArms := initialNewArms cs
}
let (_, res) goCases |>.run base
let remainders := res.newArms[Decision.dont]!
let remainders := res.newArms.find! .dont
let altMapper alt := do
let decision := Decision.ofAlt alt
let newCode := res.newArms[decision]!
let decision := .ofAlt alt
let newCode := res.newArms.find! decision
trace[Compiler.floatLetIn] "Size of code that was pushed into arm: {repr decision} {newCode.length}"
let fused withNewScope do
go (attachCodeDecls newCode.toArray alt.getCode)

View File

@@ -29,7 +29,7 @@ structure CandidateInfo where
The set of candidates that rely on this candidate to be a join point.
For a more detailed explanation see the documentation of `find`
-/
associated : Std.HashSet FVarId
associated : HashSet FVarId
deriving Inhabited
/--
@@ -39,14 +39,14 @@ structure FindState where
/--
All current join point candidates accessible by their `FVarId`.
-/
candidates : Std.HashMap FVarId CandidateInfo := .empty
candidates : HashMap FVarId CandidateInfo := .empty
/--
The `FVarId`s of all `fun` declarations that were declared within the
current `fun`.
-/
scope : Std.HashSet FVarId := .empty
scope : HashSet FVarId := .empty
abbrev ReplaceCtx := Std.HashMap FVarId Name
abbrev ReplaceCtx := HashMap FVarId Name
abbrev FindM := ReaderT (Option FVarId) StateRefT FindState ScopeM
abbrev ReplaceM := ReaderT ReplaceCtx CompilerM
@@ -55,7 +55,7 @@ abbrev ReplaceM := ReaderT ReplaceCtx CompilerM
Attempt to find a join point candidate by its `FVarId`.
-/
private def findCandidate? (fvarId : FVarId) : FindM (Option CandidateInfo) := do
return ( get).candidates[fvarId]?
return ( get).candidates.find? fvarId
/--
Erase a join point candidate as well as all the ones that depend on it
@@ -69,7 +69,7 @@ private partial def eraseCandidate (fvarId : FVarId) : FindM Unit := do
/--
Combinator for modifying the candidates in `FindM`.
-/
private def modifyCandidates (f : Std.HashMap FVarId CandidateInfo Std.HashMap FVarId CandidateInfo) : FindM Unit :=
private def modifyCandidates (f : HashMap FVarId CandidateInfo HashMap FVarId CandidateInfo) : FindM Unit :=
modify (fun state => {state with candidates := f state.candidates })
/--
@@ -196,7 +196,7 @@ where
return code
| _, _ => return Code.updateLet! code decl ( go k)
| .fun decl k =>
if let some replacement := ( read)[decl.fvarId]? then
if let some replacement := ( read).find? decl.fvarId then
let newDecl := { decl with
binderName := replacement,
value := ( go decl.value)
@@ -244,7 +244,7 @@ structure ExtendState where
to `Param`s. The free variables in this map are the once that the context
of said join point will be extended by by passing in the respective parameter.
-/
fvarMap : Std.HashMap FVarId (Std.HashMap FVarId Param) := {}
fvarMap : HashMap FVarId (HashMap FVarId Param) := {}
/--
The monad for the `extendJoinPointContext` pass.
@@ -262,7 +262,7 @@ otherwise just return `fvar`.
def replaceFVar (fvar : FVarId) : ExtendM FVarId := do
if ( read).candidates.contains fvar then
if let some currentJp := ( read).currentJp? then
if let some replacement := ( get).fvarMap[currentJp]![fvar]? then
if let some replacement := ( get).fvarMap.find! currentJp |>.find? fvar then
return replacement.fvarId
return fvar
@@ -313,7 +313,7 @@ This is necessary if:
-/
def extendByIfNecessary (fvar : FVarId) : ExtendM Unit := do
if let some currentJp := ( read).currentJp? then
let mut translator := ( get).fvarMap[currentJp]!
let mut translator := ( get).fvarMap.find! currentJp
let candidates := ( read).candidates
if !( isInScope fvar) && !translator.contains fvar && candidates.contains fvar then
let typ getType fvar
@@ -337,7 +337,7 @@ of `j.2` in `j.1`.
-/
def mergeJpContextIfNecessary (jp : FVarId) : ExtendM Unit := do
if ( read).currentJp?.isSome then
let additionalArgs := ( get).fvarMap[jp]!.toArray
let additionalArgs := ( get).fvarMap.find! jp |>.toArray
for (fvar, _) in additionalArgs do
extendByIfNecessary fvar
@@ -405,7 +405,7 @@ where
| .jp decl k =>
let decl withNewJpScope decl do
let value go decl.value
let additionalParams := ( get).fvarMap[decl.fvarId]!.toArray |>.map Prod.snd
let additionalParams := ( get).fvarMap.find! decl.fvarId |>.toArray |>.map Prod.snd
let newType := additionalParams.foldr (init := decl.type) (fun val acc => .forallE val.binderName val.type acc .default)
decl.update newType (additionalParams ++ decl.params) value
mergeJpContextIfNecessary decl.fvarId
@@ -426,7 +426,7 @@ where
return Code.updateCases! code cs.resultType discr alts
| .jmp fn args =>
let mut newArgs args.mapM (mapFVarM goFVar)
let additionalArgs := ( get).fvarMap[fn]!.toArray |>.map Prod.fst
let additionalArgs := ( get).fvarMap.find! fn |>.toArray |>.map Prod.fst
if let some _currentJp := ( read).currentJp? then
let f := fun arg => do
return .fvar ( goFVar arg)
@@ -545,7 +545,7 @@ where
if let some knownArgs := ( get).jpJmpArgs.find? fn then
let mut newArgs := knownArgs
for (param, arg) in decl.params.zip args do
if let some knownVal := newArgs[param.fvarId]? then
if let some knownVal := newArgs.find? param.fvarId then
if arg.toExpr != knownVal then
newArgs := newArgs.erase param.fvarId
modify fun s => { s with jpJmpArgs := s.jpJmpArgs.insert fn newArgs }

View File

@@ -13,9 +13,9 @@ namespace Lean.Compiler.LCNF
LCNF local context.
-/
structure LCtx where
params : Std.HashMap FVarId Param := {}
letDecls : Std.HashMap FVarId LetDecl := {}
funDecls : Std.HashMap FVarId FunDecl := {}
params : HashMap FVarId Param := {}
letDecls : HashMap FVarId LetDecl := {}
funDecls : HashMap FVarId FunDecl := {}
deriving Inhabited
def LCtx.addParam (lctx : LCtx) (param : Param) : LCtx :=

View File

@@ -30,7 +30,7 @@ structure State where
/-- Counter for generating new (normalized) universe parameter names. -/
nextIdx : Nat := 1
/-- Mapping from existing universe parameter names to the new ones. -/
map : Std.HashMap Name Level := {}
map : HashMap Name Level := {}
/-- Parameters that have been normalized. -/
paramNames : Array Name := #[]
@@ -49,7 +49,7 @@ partial def normLevel (u : Level) : M Level := do
| .max v w => return u.updateMax! ( normLevel v) ( normLevel w)
| .imax v w => return u.updateIMax! ( normLevel v) ( normLevel w)
| .mvar _ => unreachable!
| .param n => match ( get).map[n]? with
| .param n => match ( get).map.find? n with
| some u => return u
| none =>
let u := Level.param <| (`u).appendIndexAfter ( get).nextIdx

View File

@@ -31,9 +31,9 @@ def sortedBySize : Probe Decl (Nat × Decl) := fun decls =>
if sz₁ == sz₂ then Name.lt decl₁.name decl₂.name else sz₁ < sz₂
def countUnique [ToString α] [BEq α] [Hashable α] : Probe α (α × Nat) := fun data => do
let mut map := Std.HashMap.empty
let mut map := HashMap.empty
for d in data do
if let some count := map[d]? then
if let some count := map.find? d then
map := map.insert d (count + 1)
else
map := map.insert d 1

View File

@@ -40,7 +40,7 @@ structure FunDeclInfoMap where
/--
Mapping from local function name to inlining information.
-/
map : Std.HashMap FVarId FunDeclInfo := {}
map : HashMap FVarId FunDeclInfo := {}
deriving Inhabited
def FunDeclInfoMap.format (s : FunDeclInfoMap) : CompilerM Format := do
@@ -56,7 +56,7 @@ Add new occurrence for the local function with binder name `key`.
def FunDeclInfoMap.add (s : FunDeclInfoMap) (fvarId : FVarId) : FunDeclInfoMap :=
match s with
| { map } =>
match map[fvarId]? with
match map.find? fvarId with
| some .once => { map := map.insert fvarId .many }
| none => { map := map.insert fvarId .once }
| _ => { map }
@@ -67,7 +67,7 @@ Add new occurrence for the local function occurring as an argument for another f
def FunDeclInfoMap.addHo (s : FunDeclInfoMap) (fvarId : FVarId) : FunDeclInfoMap :=
match s with
| { map } =>
match map[fvarId]? with
match map.find? fvarId with
| some .once | none => { map := map.insert fvarId .many }
| _ => { map }

View File

@@ -173,7 +173,7 @@ Execute `x` with `fvarId` set as `mustInline`.
After execution the original setting is restored.
-/
def withAddMustInline (fvarId : FVarId) (x : SimpM α) : SimpM α := do
let saved? := ( get).funDeclInfoMap.map[fvarId]?
let saved? := ( get).funDeclInfoMap.map.find? fvarId
try
addMustInline fvarId
x
@@ -185,7 +185,7 @@ Return true if the given local function declaration or join point id is marked a
`once` or `mustInline`. We use this information to decide whether to inline them.
-/
def isOnceOrMustInline (fvarId : FVarId) : SimpM Bool := do
match ( get).funDeclInfoMap.map[fvarId]? with
match ( get).funDeclInfoMap.map.find? fvarId with
| some .once | some .mustInline => return true
| _ => return false

View File

@@ -199,9 +199,9 @@ structure State where
/-- Cache from Lean regular expression to LCNF argument. -/
cache : PHashMap Expr Arg := {}
/-- `toLCNFType` cache -/
typeCache : Std.HashMap Expr Expr := {}
typeCache : HashMap Expr Expr := {}
/-- isTypeFormerType cache -/
isTypeFormerTypeCache : Std.HashMap Expr Bool := {}
isTypeFormerTypeCache : HashMap Expr Bool := {}
/-- LCNF sequence, we chain it to create a LCNF `Code` object. -/
seq : Array Element := #[]
/--
@@ -257,7 +257,7 @@ private partial def isTypeFormerType (type : Expr) : M Bool := do
| .true => return true
| .false => return false
| .undef =>
if let some result := ( get).isTypeFormerTypeCache[type]? then
if let some result := ( get).isTypeFormerTypeCache.find? type then
return result
let result liftMetaM <| Meta.isTypeFormerType type
modify fun s => { s with isTypeFormerTypeCache := s.isTypeFormerTypeCache.insert type result }
@@ -305,7 +305,7 @@ def applyToAny (type : Expr) : M Expr := do
| _ => none
def toLCNFType (type : Expr) : M Expr := do
match ( get).typeCache[type]? with
match ( get).typeCache.find? type with
| some type' => return type'
| none =>
let type' liftMetaM <| LCNF.toLCNFType type

View File

@@ -6,8 +6,6 @@ Author: Leonardo de Moura
prelude
import Init.Data.Nat.Power2
import Lean.Data.AssocList
import Std.Data.HashMap.Basic
import Std.Data.HashMap.Raw
namespace Lean
def HashMapBucket (α : Type u) (β : Type v) :=
@@ -271,11 +269,17 @@ def ofListWith (l : List (α × β)) (f : β → β → β) : HashMap α β :=
| none => m.insert p.fst p.snd
| some v => m.insert p.fst $ f v p.snd)
attribute [deprecated Std.HashMap] HashMap
attribute [deprecated Std.HashMap.Raw] HashMapImp
attribute [deprecated Std.HashMap.Raw.empty] mkHashMapImp
attribute [deprecated Std.HashMap.empty] mkHashMap
attribute [deprecated Std.HashMap.empty] HashMap.empty
attribute [deprecated Std.HashMap.ofList] HashMap.ofList
end Lean.HashMap
/--
Groups all elements `x`, `y` in `xs` with `key x == key y` into the same array
`(xs.groupByKey key).find! (key x)`. Groups preserve the relative order of elements in `xs`.
-/
def Array.groupByKey [BEq α] [Hashable α] (key : β α) (xs : Array β)
: Lean.HashMap α (Array β) := Id.run do
let mut groups :=
for x in xs do
let group := groups.findD (key x) #[]
groups := groups.erase (key x) -- make `group` referentially unique
groups := groups.insert (key x) (group.push x)
return groups

View File

@@ -6,8 +6,6 @@ Author: Leonardo de Moura
prelude
import Init.Data.Nat.Power2
import Init.Data.List.Control
import Std.Data.HashSet.Basic
import Std.Data.HashSet.Raw
namespace Lean
universe u v w
@@ -219,9 +217,3 @@ def insertMany [ForIn Id ρ α] (s : HashSet α) (as : ρ) : HashSet α := Id.ru
def merge {α : Type u} [BEq α] [Hashable α] (s t : HashSet α) : HashSet α :=
t.fold (init := s) fun s a => s.insert a
-- We don't use `insertMany` here because it gives weird universes.
attribute [deprecated Std.HashSet] HashSet
attribute [deprecated Std.HashSet.Raw] HashSetImp
attribute [deprecated Std.HashSet.Raw.empty] mkHashSetImp
attribute [deprecated Std.HashSet.empty] mkHashSet
attribute [deprecated Std.HashSet.empty] HashSet.empty

View File

@@ -150,7 +150,7 @@ instance : FromJson RefInfo where
pure { definition?, usages }
/-- References from a single module/file -/
def ModuleRefs := Std.HashMap RefIdent RefInfo
def ModuleRefs := HashMap RefIdent RefInfo
instance : ToJson ModuleRefs where
toJson m := Json.mkObj <| m.toList.map fun (ident, info) => (ident.toJson.compress, toJson info)
@@ -158,7 +158,7 @@ instance : ToJson ModuleRefs where
instance : FromJson ModuleRefs where
fromJson? j := do
let node j.getObj?
node.foldM (init := Std.HashMap.empty) fun m k v =>
node.foldM (init := HashMap.empty) fun m k v =>
return m.insert ( RefIdent.fromJson? ( Json.parse k)) ( fromJson? v)
/--

View File

@@ -4,7 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
-/
prelude
import Std.Data.HashSet.Basic
import Lean.Data.HashSet
import Lean.Data.RBMap
import Lean.Data.RBTree
@@ -65,14 +64,14 @@ abbrev insert (s : NameSSet) (n : Name) : NameSSet := SSet.insert s n
abbrev contains (s : NameSSet) (n : Name) : Bool := SSet.contains s n
end NameSSet
def NameHashSet := Std.HashSet Name
def NameHashSet := HashSet Name
namespace NameHashSet
@[inline] def empty : NameHashSet := Std.HashSet.empty
@[inline] def empty : NameHashSet := HashSet.empty
instance : EmptyCollection NameHashSet := empty
instance : Inhabited NameHashSet := {}
def insert (s : NameHashSet) (n : Name) := Std.HashSet.insert s n
def contains (s : NameHashSet) (n : Name) : Bool := Std.HashSet.contains s n
def insert (s : NameHashSet) (n : Name) := HashSet.insert s n
def contains (s : NameHashSet) (n : Name) : Bool := HashSet.contains s n
end NameHashSet
def MacroScopesView.isPrefixOf (v₁ v₂ : MacroScopesView) : Bool :=

View File

@@ -4,7 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Std.Data.HashMap.Basic
import Lean.Data.HashMap
import Lean.Data.PersistentHashMap
universe u v w w'
@@ -29,7 +28,7 @@ namespace Lean
-/
structure SMap (α : Type u) (β : Type v) [BEq α] [Hashable α] where
stage₁ : Bool := true
map₁ : Std.HashMap α β := {}
map₁ : HashMap α β := {}
map₂ : PHashMap α β := {}
namespace SMap
@@ -38,7 +37,7 @@ variable {α : Type u} {β : Type v} [BEq α] [Hashable α]
instance : Inhabited (SMap α β) := {}
def empty : SMap α β := {}
@[inline] def fromHashMap (m : Std.HashMap α β) (stage₁ := true) : SMap α β :=
@[inline] def fromHashMap (m : HashMap α β) (stage₁ := true) : SMap α β :=
{ map₁ := m, stage₁ := stage₁ }
@[specialize] def insert : SMap α β α β SMap α β
@@ -50,8 +49,8 @@ def empty : SMap α β := {}
| false, m₁, m₂, k, v => false, m₁, m₂.insert k v
@[specialize] def find? : SMap α β α Option β
| true, m₁, _, k => m₁[k]?
| false, m₁, m₂, k => (m₂.find? k).orElse fun _ => m₁[k]?
| true, m₁, _, k => m₁.find? k
| false, m₁, m₂, k => (m₂.find? k).orElse fun _ => m₁.find? k
@[inline] def findD (m : SMap α β) (a : α) (b₀ : β) : β :=
(m.find? a).getD b₀
@@ -68,8 +67,8 @@ def empty : SMap α β := {}
/-- Similar to `find?`, but searches for result in the hashmap first.
So, the result is correct only if we never "overwrite" `map₁` entries using `map₂`. -/
@[specialize] def find?' : SMap α β α Option β
| true, m₁, _, k => m₁[k]?
| false, m₁, m₂, k => m₁[k]?.orElse fun _ => m₂.find? k
| true, m₁, _, k => m₁.find? k
| false, m₁, m₂, k => (m₁.find? k).orElse fun _ => m₂.find? k
def forM [Monad m] (s : SMap α β) (f : α β m PUnit) : m PUnit := do
s.map₁.forM f
@@ -97,7 +96,7 @@ def fold {σ : Type w} (f : σα → β → σ) (init : σ) (m : SMap α β
m.map₂.foldl f $ m.map₁.fold f init
def numBuckets (m : SMap α β) : Nat :=
Std.HashMap.Internal.numBuckets m.map₁
m.map₁.numBuckets
def toList (m : SMap α β) : List (α × β) :=
m.fold (init := []) fun es a b => (a, b)::es

View File

@@ -541,7 +541,7 @@ mutual
/--
Process a `fType` of the form `(x : A) → B x`.
This method assume `fType` is a function type -/
private partial def processExplicitArg (argName : Name) : M Expr := do
private partial def processExplictArg (argName : Name) : M Expr := do
match ( get).args with
| arg::args =>
if ( anyNamedArgDependsOnCurrent) then
@@ -586,16 +586,6 @@ mutual
| Except.ok tacticSyntax =>
-- TODO(Leo): does this work correctly for tactic sequences?
let tacticBlock `(by $(tacticSyntax))
/-
We insert position information from the current ref into `stx` everywhere, simulating this being
a tactic script inserted by the user, which ensures error messages and logging will always be attributed
to this application rather than sometimes being placed at position (1,0) in the file.
Placing position information on `by` syntax alone is not sufficient since incrementality
(in particular, `Lean.Elab.Term.withReuseContext`) controls the ref to avoid leakage of outside data.
Note that `tacticSyntax` contains no position information itself, since it is erased by `Lean.Elab.Term.quoteAutoTactic`.
-/
let info := ( getRef).getHeadInfo
let tacticBlock := tacticBlock.raw.rewriteBottomUp (·.setInfo info)
let argNew := Arg.stx tacticBlock
propagateExpectedType argNew
elabAndAddNewArg argName argNew
@@ -625,7 +615,7 @@ mutual
This method assume `fType` is a function type -/
private partial def processImplicitArg (argName : Name) : M Expr := do
if ( read).explicit then
processExplicitArg argName
processExplictArg argName
else
addImplicitArg argName
@@ -634,7 +624,7 @@ mutual
This method assume `fType` is a function type -/
private partial def processStrictImplicitArg (argName : Name) : M Expr := do
if ( read).explicit then
processExplicitArg argName
processExplictArg argName
else if ( hasArgsToProcess) then
addImplicitArg argName
else
@@ -653,7 +643,7 @@ mutual
addNewArg argName arg
main
else
processExplicitArg argName
processExplictArg argName
else
let arg mkFreshExprMVar ( getArgExpectedType) MetavarKind.synthetic
addInstMVar arg.mvarId!
@@ -678,7 +668,7 @@ mutual
| .implicit => processImplicitArg binderName
| .instImplicit => processInstImplicitArg binderName
| .strictImplicit => processStrictImplicitArg binderName
| _ => processExplicitArg binderName
| _ => processExplictArg binderName
else if ( hasArgsToProcess) then
synthesizePendingAndNormalizeFunType
main

View File

@@ -502,16 +502,6 @@ def elabRunMeta : CommandElab := fun stx =>
addDocString declName ( getDocStringText doc)
| _ => throwUnsupportedSyntax
@[builtin_command_elab Lean.Parser.Command.include] def elabInclude : CommandElab
| `(Lean.Parser.Command.include| include $ids*) => do
let vars := ( getScope).varDecls.concatMap getBracketedBinderIds
for id in ids do
unless vars.contains id.getId do
throwError "invalid 'include', variable '{id}' has not been declared in the current scope"
modifyScope fun sc =>
{ sc with includedVars := sc.includedVars ++ ids.toList.map (·.getId) }
| _ => throwUnsupportedSyntax
@[builtin_command_elab Parser.Command.exit] def elabExit : CommandElab := fun _ =>
logWarning "using 'exit' to interrupt Lean"

View File

@@ -257,52 +257,31 @@ partial def hasCDot : Syntax → Bool
Return `some` if succeeded expanding `·` notation occurring in
the given syntax. Otherwise, return `none`.
Examples:
- `· + 1` => `fun x => x + 1`
- `f · · b` => `fun x1 x2 => f x1 x2 b` -/
- `· + 1` => `fun _a_1 => _a_1 + 1`
- `f · · b` => `fun _a_1 _a_2 => f _a_1 _a_2 b` -/
partial def expandCDot? (stx : Term) : MacroM (Option Term) := do
if hasCDot stx then
withFreshMacroScope do
let mut (newStx, binders) (go stx).run #[]
if binders.size == 1 then
-- It is nicer using `x` over `x1` if there's only a single binder.
let x1 := binders[0]!
let x := mkIdentFrom x1 ( MonadQuotation.addMacroScope `x) (canonical := true)
binders := binders.set! 0 x
newStx newStx.replaceM fun s => pure (if s == x1 then x else none)
`(fun $binders* => $(newStx))
let (newStx, binders) (go stx).run #[]
`(fun $binders* => $(newStx))
else
pure none
where
/--
Auxiliary function for expanding the `·` notation.
The extra state `Array Syntax` contains the new binder names.
If `stx` is a `·`, we create a fresh identifier, store it in the
extra state, and return it. Otherwise, we just return `stx`.
-/
Auxiliary function for expanding the `·` notation.
The extra state `Array Syntax` contains the new binder names.
If `stx` is a `·`, we create a fresh identifier, store in the
extra state, and return it. Otherwise, we just return `stx`. -/
go : Syntax StateT (Array Ident) MacroM Syntax
| stx@`(($(_))) => pure stx
| stx@`(·) => do
let name MonadQuotation.addMacroScope <| Name.mkSimple s!"x{(← get).size + 1}"
let id := mkIdentFrom stx name (canonical := true)
modify (fun s => s.push id)
pure id
| stx => match stx with
| .node _ k args => do
let args
if k == choiceKind then
if args.isEmpty then
return stx
let s get
let args' args.mapM (fun arg => go arg |>.run s)
let s' := args'[0]!.2
unless args'.all (fun (_, s'') => s''.size == s'.size) do
Macro.throwErrorAt stx "Ambiguous notation in cdot function has different numbers of '·' arguments in each alternative."
set s'
pure <| args'.map Prod.fst
else
args.mapM go
return .node (.fromRef stx (canonical := true)) k args
| _ => pure stx
| stx@`(($(_))) => pure stx
| stx@`(·) => withFreshMacroScope do
let id mkFreshIdent stx (canonical := true)
modify (·.push id)
pure id
| stx => match stx with
| .node _ k args => do
let args args.mapM go
return .node (.fromRef stx (canonical := true)) k args
| _ => pure stx
/--
Helper method for elaborating terms such as `(.+.)` where a constant name is expected.

View File

@@ -51,8 +51,6 @@ structure Scope where
even if they do not work with binders per se.
-/
varDecls : Array (TSyntax ``Parser.Term.bracketedBinder) := #[]
/-- `include`d section variable names -/
includedVars : List Name := []
/--
Globally unique internal identifiers for the `varDecls`.
There is one identifier per variable introduced by the binders
@@ -203,12 +201,12 @@ def mkMessageAux (ctx : Context) (ref : Syntax) (msgData : MessageData) (severit
private def addTraceAsMessagesCore (ctx : Context) (log : MessageLog) (traceState : TraceState) : MessageLog := Id.run do
if traceState.traces.isEmpty then return log
let mut traces : Std.HashMap (String.Pos × String.Pos) (Array MessageData) :=
let mut traces : HashMap (String.Pos × String.Pos) (Array MessageData) :=
for traceElem in traceState.traces do
let ref := replaceRef traceElem.ref ctx.ref
let pos := ref.getPos?.getD 0
let endPos := ref.getTailPos?.getD pos
traces := traces.insert (pos, endPos) <| traces.getD (pos, endPos) #[] |>.push traceElem.msg
traces := traces.insert (pos, endPos) <| traces.findD (pos, endPos) #[] |>.push traceElem.msg
let mut log := log
let traces' := traces.toArray.qsort fun ((a, _), _) ((b, _), _) => a < b
for ((pos, endPos), traceMsg) in traces' do

View File

@@ -630,7 +630,7 @@ private def replaceIndFVarsWithConsts (views : Array InductiveView) (indFVars :
let type := type.replace fun e =>
if !e.isFVar then
none
else match indFVar2Const[e]? with
else match indFVar2Const.find? e with
| none => none
| some c => mkAppN c (params.extract 0 numVars)
instantiateMVars ( mkForallFVars params type)

View File

@@ -425,7 +425,7 @@ private def applyRefMap (e : Expr) (map : ExprMap Expr) : Expr :=
e.replace fun e =>
match patternWithRef? e with
| some _ => some e -- stop `e` already has annotation
| none => match map[e]? with
| none => match map.find? e with
| some eWithRef => some eWithRef -- stop `e` found annotation
| none => none -- continue

View File

@@ -327,45 +327,7 @@ def instantiateMVarsProfiling (e : Expr) : MetaM Expr := do
profileitM Exception s!"instantiate metavars" ( getOptions) do
instantiateMVars e
/--
Runs `k` with a restricted local context where only section variables from `vars` are included that
* are directly referenced in any `headers`,
* are included in `includedVars` (via the `include` command),
* are directly referenced in any variable included by these rules, OR
* are instance-implicit variables that only reference section variables included by these rules.
-/
private def withHeaderSecVars {α} (vars : Array Expr) (includedVars : List Name) (headers : Array DefViewElabHeader)
(k : Array Expr TermElabM α) : TermElabM α := do
let (_, used) collectUsed.run {}
let (lctx, localInsts, vars) removeUnused vars used
withLCtx lctx localInsts <| k vars
where
collectUsed : StateRefT CollectFVars.State MetaM Unit := do
-- directly referenced in headers
headers.forM (·.type.collectFVars)
-- included by `include`
vars.forM fun var => do
let ldecl getFVarLocalDecl var
if includedVars.contains ldecl.userName then
modify (·.add ldecl.fvarId)
-- transitively referenced
get >>= (·.addDependencies) >>= set
-- instances (`addDependencies` unnecessary as by definition they may only reference variables
-- already included)
vars.forM fun var => do
let ldecl getFVarLocalDecl var
let st get
if ldecl.binderInfo.isInstImplicit && ( getFVars ldecl.type).all st.fvarSet.contains then
modify (·.add ldecl.fvarId)
getFVars (e : Expr) : MetaM (Array FVarId) :=
(·.2.fvarIds) <$> e.collectFVars.run {}
register_builtin_option deprecated.oldSectionVars : Bool := {
defValue := false
descr := "re-enable deprecated behavior of including exactly the section variables used in a declaration"
}
private def elabFunValues (headers : Array DefViewElabHeader) (vars : Array Expr) (includedVars : List Name) : TermElabM (Array Expr) :=
private def elabFunValues (headers : Array DefViewElabHeader) : TermElabM (Array Expr) :=
headers.mapM fun header => do
let mut reusableResult? := none
if let some snap := header.bodySnap? then
@@ -380,7 +342,6 @@ private def elabFunValues (headers : Array DefViewElabHeader) (vars : Array Expr
withReuseContext header.value do
withDeclName header.declName <| withLevelNames header.levelNames do
let valStx liftMacroM <| declValToTerm header.value
(if header.kind.isTheorem && !deprecated.oldSectionVars.get ( getOptions) then withHeaderSecVars vars includedVars #[header] else fun x => x #[]) fun vars => do
forallBoundedTelescope header.type header.numParams fun xs type => do
-- Add new info nodes for new fvars. The server will detect all fvars of a binder by the binder's source location.
for i in [0:header.binderIds.size] do
@@ -392,20 +353,7 @@ private def elabFunValues (headers : Array DefViewElabHeader) (vars : Array Expr
-- NOTE: without this `instantiatedMVars`, `mkLambdaFVars` may leave around a redex that
-- leads to more section variables being included than necessary
let val instantiateMVarsProfiling val
let val mkLambdaFVars xs val
unless header.type.hasSorry || val.hasSorry do
for var in vars do
unless header.type.containsFVar var.fvarId! ||
val.containsFVar var.fvarId! ||
( vars.anyM (fun v => return ( v.fvarId!.getType).containsFVar var.fvarId!)) do
let varDecl var.fvarId!.getDecl
let var := if varDecl.userName.hasMacroScopes && varDecl.binderInfo.isInstImplicit then
m!"[{varDecl.type}]".group
else
var
logWarningAt header.ref m!"included section variable '{var}' is not used in \
'{header.declName}', consider excluding it"
return val
mkLambdaFVars xs val
if let some snap := header.bodySnap? then
snap.new.resolve <| some {
diagnostics :=
@@ -956,7 +904,7 @@ partial def checkForHiddenUnivLevels (allUserLevelNames : List Name) (preDefs :
for preDef in preDefs do
checkPreDef preDef
def elabMutualDef (vars : Array Expr) (includedVars : List Name) (views : Array DefView) : TermElabM Unit :=
def elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit :=
if isExample views then
withoutModifyingEnv do
-- save correct environment in info tree
@@ -977,7 +925,7 @@ where
addLocalVarInfo view.declId funFVar
let values
try
let values elabFunValues headers vars includedVars
let values elabFunValues headers
Term.synthesizeSyntheticMVarsNoPostponing
values.mapM (instantiateMVarsProfiling ·)
catch ex =>
@@ -987,7 +935,7 @@ where
let letRecsToLift getLetRecsToLift
let letRecsToLift letRecsToLift.mapM instantiateMVarsAtLetRecToLift
checkLetRecsToLiftTypes funFVars letRecsToLift
(if headers.all (·.kind.isTheorem) && !deprecated.oldSectionVars.get ( getOptions) then withHeaderSecVars vars includedVars headers else withUsed vars headers values letRecsToLift) fun vars => do
withUsed vars headers values letRecsToLift fun vars => do
let preDefs MutualClosure.main vars headers funFVars values letRecsToLift
for preDef in preDefs do
trace[Elab.definition] "{preDef.declName} : {preDef.type} :=\n{preDef.value}"
@@ -1058,8 +1006,7 @@ def elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do
if let some snap := snap? then
-- no non-fatal diagnostics at this point
snap.new.resolve <| .ofTyped { defs, diagnostics := .empty : DefsParsedSnapshot }
let includedVars := ( getScope).includedVars
runTermElabM fun vars => Term.elabMutualDef vars includedVars views
runTermElabM fun vars => Term.elabMutualDef vars views
builtin_initialize
registerTraceClass `Elab.definition.mkClosure

View File

@@ -164,11 +164,8 @@ def addNonRec (preDef : PreDefinition) (applyAttrAfterCompilation := true) (all
/--
Eliminate recursive application annotations containing syntax. These annotations are used by the well-founded recursion module
to produce better error messages. -/
def eraseRecAppSyntaxExpr (e : Expr) : CoreM Expr := do
if e.find? hasRecAppSyntax |>.isSome then
Core.transform e (post := fun e => pure <| TransformStep.done <| if hasRecAppSyntax e then e.mdataExpr! else e)
else
return e
def eraseRecAppSyntaxExpr (e : Expr) : CoreM Expr :=
Core.transform e (post := fun e => pure <| TransformStep.done <| if (getRecAppSyntax? e).isSome then e.mdataExpr! else e)
def eraseRecAppSyntax (preDef : PreDefinition) : CoreM PreDefinition :=
return { preDef with value := ( eraseRecAppSyntaxExpr preDef.value) }

View File

@@ -69,15 +69,12 @@ private def ensureNoUnassignedMVarsAtPreDef (preDef : PreDefinition) : TermElabM
This method beta-reduces them to make sure they can be eliminated by the well-founded recursion module. -/
private def betaReduceLetRecApps (preDefs : Array PreDefinition) : MetaM (Array PreDefinition) :=
preDefs.mapM fun preDef => do
if preDef.value.find? (fun e => e.isConst && preDefs.any fun preDef => preDef.declName == e.constName!) |>.isSome then
let value Core.transform preDef.value fun e => do
if e.isApp && e.getAppFn.isLambda && e.getAppArgs.all fun arg => arg.getAppFn.isConst && preDefs.any fun preDef => preDef.declName == arg.getAppFn.constName! then
return .visit e.headBeta
else
return .continue
return { preDef with value }
else
return preDef
let value Core.transform preDef.value fun e => do
if e.isApp && e.getAppFn.isLambda && e.getAppArgs.all fun arg => arg.getAppFn.isConst && preDefs.any fun preDef => preDef.declName == arg.getAppFn.constName! then
return .visit e.headBeta
else
return .continue
return { preDef with value }
private def addAsAxioms (preDefs : Array PreDefinition) : TermElabM Unit := do
for preDef in preDefs do

View File

@@ -146,7 +146,7 @@ See issue #837 for an example where we can show termination using the index of a
we don't get the desired definitional equalities.
-/
def nonIndicesFirst (recArgInfos : Array RecArgInfo) : Array RecArgInfo := Id.run do
let mut indicesPos : Std.HashSet Nat := {}
let mut indicesPos : HashSet Nat := {}
for recArgInfo in recArgInfos do
for pos in recArgInfo.indicesPos do
indicesPos := indicesPos.insert pos

View File

@@ -596,7 +596,7 @@ private partial def compileStxMatch (discrs : List Term) (alts : List Alt) : Ter
`(have __discr := $discr; $stx)
| _, _ => unreachable!
abbrev IdxSet := Std.HashSet Nat
abbrev IdxSet := HashSet Nat
private partial def hasNoErrorIfUnused : Syntax Bool
| `(no_error_if_unused% $_) => true

View File

@@ -11,35 +11,27 @@ namespace Lean
private def recAppKey := `_recApp
/--
We store the syntax at recursive applications to be able to generate better error messages
when performing well-founded and structural recursion.
We store the syntax at recursive applications to be able to generate better error messages
when performing well-founded and structural recursion.
-/
def mkRecAppWithSyntax (e : Expr) (stx : Syntax) : Expr :=
mkMData (KVMap.empty.insert recAppKey (.ofSyntax stx)) e
mkMData (KVMap.empty.insert recAppKey (DataValue.ofSyntax stx)) e
/--
Retrieve (if available) the syntax object attached to a recursive application.
Retrieve (if available) the syntax object attached to a recursive application.
-/
def getRecAppSyntax? (e : Expr) : Option Syntax :=
match e with
| .mdata d _ =>
| Expr.mdata d _ =>
match d.find recAppKey with
| some (DataValue.ofSyntax stx) => some stx
| _ => none
| _ => none
/--
Checks if the `MData` is for a recursive applciation.
Checks if the `MData` is for a recursive applciation.
-/
def MData.isRecApp (d : MData) : Bool :=
d.contains recAppKey
/--
Return `true` if `getRecAppSyntax? e` is a `some`.
-/
def hasRecAppSyntax (e : Expr) : Bool :=
match e with
| .mdata d _ => d.isRecApp
| _ => false
end Lean

View File

@@ -445,13 +445,13 @@ private def expandParentFields (s : Struct) : TermElabM Struct := do
| _ => throwErrorAt ref "failed to access field '{fieldName}' in parent structure"
| _ => return field
private abbrev FieldMap := Std.HashMap Name Fields
private abbrev FieldMap := HashMap Name Fields
private def mkFieldMap (fields : Fields) : TermElabM FieldMap :=
fields.foldlM (init := {}) fun fieldMap field =>
match field.lhs with
| .fieldName _ fieldName :: _ =>
match fieldMap[fieldName]? with
match fieldMap.find? fieldName with
| some (prevField::restFields) =>
if field.isSimple || prevField.isSimple then
throwErrorAt field.ref "field '{fieldName}' has already been specified"
@@ -677,10 +677,6 @@ private partial def elabStruct (s : Struct) (expectedType? : Option Expr) : Term
| .error err => throwError err
| .ok tacticSyntax =>
let stx `(by $tacticSyntax)
-- See comment in `Lean.Elab.Term.ElabAppArgs.processExplicitArg` about `tacticSyntax`.
-- We add info to get reliable positions for messages from evaluating the tactic script.
let info := field.ref.getHeadInfo
let stx := stx.raw.rewriteBottomUp (·.setInfo info)
cont ( elabTermEnsuringType stx (d.getArg! 0).consumeTypeAnnotations) field
| _ =>
if bi == .instImplicit then

View File

@@ -246,7 +246,7 @@ private def getSomeSyntheticMVarsRef : TermElabM Syntax := do
private def throwStuckAtUniverseCnstr : TermElabM Unit := do
-- This code assumes `entries` is not empty. Note that `processPostponed` uses `exceptionOnFailure` to guarantee this property
let entries getPostponed
let mut found : Std.HashSet (Level × Level) := {}
let mut found : HashSet (Level × Level) := {}
let mut uniqueEntries := #[]
for entry in entries do
let mut lhs := entry.lhs

View File

@@ -8,6 +8,8 @@ import Init.Omega.Constraint
import Lean.Elab.Tactic.Omega.OmegaM
import Lean.Elab.Tactic.Omega.MinNatAbs
open Lean (HashMap HashSet)
namespace Lean.Elab.Tactic.Omega
initialize Lean.registerTraceClass `omega
@@ -165,11 +167,11 @@ structure Problem where
/-- The number of variables in the problem. -/
numVars : Nat := 0
/-- The current constraints, indexed by their coefficients. -/
constraints : Std.HashMap Coeffs Fact :=
constraints : HashMap Coeffs Fact :=
/--
The coefficients for which `constraints` contains an exact constraint (i.e. an equality).
-/
equalities : Std.HashSet Coeffs :=
equalities : HashSet Coeffs :=
/--
Equations that have already been used to eliminate variables,
along with the variable which was removed, and its coefficient (either `1` or `-1`).
@@ -249,7 +251,7 @@ combining it with any existing constraints for the same coefficients.
def addConstraint (p : Problem) : Fact Problem
| f@x, s, j =>
if p.possible then
match p.constraints[x]? with
match p.constraints.find? x with
| none =>
match s with
| .trivial => p
@@ -311,7 +313,7 @@ After solving, the variable will have been eliminated from all constraints.
def solveEasyEquality (p : Problem) (c : Coeffs) : Problem :=
let i := c.findIdx? (·.natAbs = 1) |>.getD 0 -- findIdx? is always some
let sign := c.get i |> Int.sign
match p.constraints[c]? with
match p.constraints.find? c with
| some f =>
let init :=
{ assumptions := p.assumptions
@@ -333,7 +335,7 @@ After solving the easy equality,
the minimum lexicographic value of `(c.minNatAbs, c.maxNatAbs)` will have been reduced.
-/
def dealWithHardEquality (p : Problem) (c : Coeffs) : OmegaM Problem :=
match p.constraints[c]? with
match p.constraints.find? c with
| some _, some r, some r', j => do
let m := c.minNatAbs + 1
-- We have to store the valid value of the newly introduced variable in the atoms.
@@ -477,7 +479,7 @@ def fourierMotzkinData (p : Problem) : Array FourierMotzkinData := Id.run do
let n := p.numVars
let mut data : Array FourierMotzkinData :=
(List.range p.numVars).foldl (fun a i => a.push { var := i}) #[]
for (_, f@xs, s, _) in p.constraints do
for (_, f@xs, s, _) in p.constraints.toList do -- We could make a forIn instance for HashMap
for i in [0:n] do
let x := Coeffs.get xs i
data := data.modify i fun d =>

View File

@@ -58,7 +58,7 @@ structure MetaProblem where
-/
disjunctions : List Expr := []
/-- Facts which have already been processed; we keep these to avoid duplicates. -/
processedFacts : Std.HashSet Expr :=
processedFacts : HashSet Expr :=
/-- Construct the `rfl` proof that `lc.eval atoms = e`. -/
def mkEvalRflProof (e : Expr) (lc : LinearCombo) : OmegaM Expr := do
@@ -80,7 +80,7 @@ def mkCoordinateEvalAtomsEq (e : Expr) (n : Nat) : OmegaM Expr := do
mkEqTrans eq ( mkEqSymm (mkApp2 (.const ``LinearCombo.coordinate_eval []) n atoms))
/-- Construct the linear combination (and its associated proof and new facts) for an atom. -/
def mkAtomLinearCombo (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
def mkAtomLinearCombo (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
let (n, facts) lookup e
return LinearCombo.coordinate n, mkCoordinateEvalAtomsEq e n, facts.getD
@@ -94,9 +94,9 @@ Gives a small (10%) speedup in testing.
I tried using a pointer based cache,
but there was never enough subexpression sharing to make it effective.
-/
partial def asLinearCombo (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
partial def asLinearCombo (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
let cache get
match cache.get? e with
match cache.find? e with
| some (lc, prf) =>
trace[omega] "Found in cache: {e}"
return (lc, prf, )
@@ -120,7 +120,7 @@ We also transform the expression as we descend into it:
* pushing coercions: `↑(x + y)`, `↑(x * y)`, `↑(x / k)`, `↑(x % k)`, `↑k`
* unfolding `emod`: `x % k` → `x - x / k`
-/
partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
trace[omega] "processing {e}"
match groundInt? e with
| some i =>
@@ -142,7 +142,7 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
mkEqTrans
( mkAppM ``Int.add_congr #[ prf₁, prf₂])
( mkEqSymm add_eval)
pure (l₁ + l₂, prf, facts₁.union facts₂)
pure (l₁ + l₂, prf, facts₁.merge facts₂)
| (``HSub.hSub, #[_, _, _, _, e₁, e₂]) => do
let (l₁, prf₁, facts₁) asLinearCombo e₁
let (l₂, prf₂, facts₂) asLinearCombo e₂
@@ -152,7 +152,7 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
mkEqTrans
( mkAppM ``Int.sub_congr #[ prf₁, prf₂])
( mkEqSymm sub_eval)
pure (l₁ - l₂, prf, facts₁.union facts₂)
pure (l₁ - l₂, prf, facts₁.merge facts₂)
| (``Neg.neg, #[_, _, e']) => do
let (l, prf, facts) asLinearCombo e'
let prf' : OmegaM Expr := do
@@ -178,7 +178,7 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
mkEqTrans
( mkAppM ``Int.mul_congr #[ xprf, yprf])
( mkEqSymm mul_eval)
pure (some (LinearCombo.mul xl yl, prf, xfacts.union yfacts), true)
pure (some (LinearCombo.mul xl yl, prf, xfacts.merge yfacts), true)
else
pure (none, false)
match r? with
@@ -235,7 +235,7 @@ where
Apply a rewrite rule to an expression, and interpret the result as a `LinearCombo`.
(We're not rewriting any subexpressions here, just the top level, for efficiency.)
-/
rewrite (lhs rw : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
rewrite (lhs rw : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
trace[omega] "rewriting {lhs} via {rw} : {← inferType rw}"
match ( inferType rw).eq? with
| some (_, _lhs', rhs) =>
@@ -243,7 +243,7 @@ where
let prf' : OmegaM Expr := do mkEqTrans rw ( prf)
pure (lc, prf', facts)
| none => panic! "Invalid rewrite rule in 'asLinearCombo'"
handleNatCast (e i n : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
handleNatCast (e i n : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
match n with
| .fvar h =>
if let some v h.getValue? then
@@ -296,7 +296,7 @@ where
| (``Fin.val, #[n, x]) =>
handleFinVal e i n x
| _ => mkAtomLinearCombo e
handleFinVal (e i n x : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
handleFinVal (e i n x : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
match x with
| .fvar h =>
if let some v h.getValue? then
@@ -342,7 +342,7 @@ We solve equalities as they are discovered, as this often results in an earlier
-/
def addIntEquality (p : MetaProblem) (h x : Expr) : OmegaM MetaProblem := do
let (lc, prf, facts) asLinearCombo x
let newFacts : Std.HashSet Expr := facts.fold (init := ) fun s e =>
let newFacts : HashSet Expr := facts.fold (init := ) fun s e =>
if p.processedFacts.contains e then s else s.insert e
trace[omega] "Adding proof of {lc} = 0"
pure <|
@@ -358,7 +358,7 @@ We solve equalities as they are discovered, as this often results in an earlier
-/
def addIntInequality (p : MetaProblem) (h y : Expr) : OmegaM MetaProblem := do
let (lc, prf, facts) asLinearCombo y
let newFacts : Std.HashSet Expr := facts.fold (init := ) fun s e =>
let newFacts : HashSet Expr := facts.fold (init := ) fun s e =>
if p.processedFacts.contains e then s else s.insert e
trace[omega] "Adding proof of {lc} ≥ 0"
pure <|
@@ -590,7 +590,7 @@ where
-- We sort the constraints; otherwise the order is dependent on details of the hashing
-- and this can cause test suite output churn
prettyConstraints (names : Array String) (constraints : Std.HashMap Coeffs Fact) : String :=
prettyConstraints (names : Array String) (constraints : HashMap Coeffs Fact) : String :=
constraints.toList
|>.toArray
|>.qsort (·.1 < ·.1)
@@ -615,7 +615,7 @@ where
(if Int.natAbs c = 1 then names[i]! else s!"{c.natAbs}*{names[i]!}"))
|> String.join
mentioned (atoms : Array Expr) (constraints : Std.HashMap Coeffs Fact) : MetaM (Array Bool) := do
mentioned (atoms : Array Expr) (constraints : HashMap Coeffs Fact) : MetaM (Array Bool) := do
let initMask := Array.mkArray atoms.size false
return constraints.fold (init := initMask) fun mask coeffs _ =>
coeffs.enum.foldl (init := mask) fun mask (i, c) =>

View File

@@ -10,8 +10,6 @@ import Init.Omega.Logic
import Init.Data.BitVec.Basic
import Lean.Meta.AppBuilder
import Lean.Meta.Canonicalizer
import Std.Data.HashMap.Basic
import Std.Data.HashSet.Basic
/-!
# The `OmegaM` state monad.
@@ -54,7 +52,7 @@ structure Context where
/-- The internal state for the `OmegaM` monad, recording previously encountered atoms. -/
structure State where
/-- The atoms up-to-defeq encountered so far. -/
atoms : Std.HashMap Expr Nat := {}
atoms : HashMap Expr Nat := {}
/-- An intermediate layer in the `OmegaM` monad. -/
abbrev OmegaM' := StateRefT State (ReaderT Context CanonM)
@@ -62,7 +60,7 @@ abbrev OmegaM' := StateRefT State (ReaderT Context CanonM)
/--
Cache of expressions that have been visited, and their reflection as a linear combination.
-/
def Cache : Type := Std.HashMap Expr (LinearCombo × OmegaM' Expr)
def Cache : Type := HashMap Expr (LinearCombo × OmegaM' Expr)
/--
The `OmegaM` monad maintains two pieces of state:
@@ -73,7 +71,7 @@ abbrev OmegaM := StateRefT Cache OmegaM'
/-- Run a computation in the `OmegaM` monad, starting with no recorded atoms. -/
def OmegaM.run (m : OmegaM α) (cfg : OmegaConfig) : MetaM α :=
m.run' Std.HashMap.empty |>.run' {} { cfg } |>.run'
m.run' HashMap.empty |>.run' {} { cfg } |>.run'
/-- Retrieve the user-specified configuration options. -/
def cfg : OmegaM OmegaConfig := do pure ( read).cfg
@@ -164,11 +162,11 @@ def mkEqReflWithExpectedType (a b : Expr) : MetaM Expr := do
Analyzes a newly recorded atom,
returning a collection of interesting facts about it that should be added to the context.
-/
def analyzeAtom (e : Expr) : OmegaM (Std.HashSet Expr) := do
def analyzeAtom (e : Expr) : OmegaM (HashSet Expr) := do
match e.getAppFnArgs with
| (``Nat.cast, #[.const ``Int [], _, e']) =>
-- Casts of natural numbers are non-negative.
let mut r := Std.HashSet.empty.insert (Expr.app (.const ``Int.ofNat_nonneg []) e')
let mut r := HashSet.empty.insert (Expr.app (.const ``Int.ofNat_nonneg []) e')
match ( cfg).splitNatSub, e'.getAppFnArgs with
| true, (``HSub.hSub, #[_, _, _, _, a, b]) =>
-- `((a - b : Nat) : Int)` gives a dichotomy
@@ -190,7 +188,7 @@ def analyzeAtom (e : Expr) : OmegaM (Std.HashSet Expr) := do
let ne_zero := mkApp3 (.const ``Ne [1]) (.const ``Int []) k (toExpr (0 : Int))
let pos := mkApp4 (.const ``LT.lt [0]) (.const ``Int []) (.const ``Int.instLTInt [])
(toExpr (0 : Int)) k
pure <| Std.HashSet.empty.insert
pure <| HashSet.empty.insert
(mkApp3 (.const ``Int.mul_ediv_self_le []) x k ( mkDecideProof ne_zero)) |>.insert
(mkApp3 (.const ``Int.lt_mul_ediv_self_add []) x k ( mkDecideProof pos))
| (``HMod.hMod, #[_, _, _, _, x, k]) =>
@@ -202,7 +200,7 @@ def analyzeAtom (e : Expr) : OmegaM (Std.HashSet Expr) := do
let b_pos := mkApp4 (.const ``LT.lt [0]) (.const ``Int []) (.const ``Int.instLTInt [])
(toExpr (0 : Int)) b
let pow_pos := mkApp3 (.const ``Lean.Omega.Int.pos_pow_of_pos []) b exp ( mkDecideProof b_pos)
pure <| Std.HashSet.empty.insert
pure <| HashSet.empty.insert
(mkApp3 (.const ``Int.emod_nonneg []) x k
(mkApp3 (.const ``Int.ne_of_gt []) k (toExpr (0 : Int)) pow_pos)) |>.insert
(mkApp3 (.const ``Int.emod_lt_of_pos []) x k pow_pos)
@@ -216,7 +214,7 @@ def analyzeAtom (e : Expr) : OmegaM (Std.HashSet Expr) := do
(toExpr (0 : Nat)) b
let pow_pos := mkApp3 (.const ``Nat.pos_pow_of_pos []) b exp ( mkDecideProof b_pos)
let cast_pos := mkApp2 (.const ``Int.ofNat_pos_of_pos []) k' pow_pos
pure <| Std.HashSet.empty.insert
pure <| HashSet.empty.insert
(mkApp3 (.const ``Int.emod_nonneg []) x k
(mkApp3 (.const ``Int.ne_of_gt []) k (toExpr (0 : Int)) cast_pos)) |>.insert
(mkApp3 (.const ``Int.emod_lt_of_pos []) x k cast_pos)
@@ -224,18 +222,18 @@ def analyzeAtom (e : Expr) : OmegaM (Std.HashSet Expr) := do
| (``Nat.cast, #[.const ``Int [], _, x']) =>
-- Since we push coercions inside `%`, we need to record here that
-- `(x : Int) % (y : Int)` is non-negative.
pure <| Std.HashSet.empty.insert (mkApp2 (.const ``Int.emod_ofNat_nonneg []) x' k)
pure <| HashSet.empty.insert (mkApp2 (.const ``Int.emod_ofNat_nonneg []) x' k)
| _ => pure
| _ => pure
| (``Min.min, #[_, _, x, y]) =>
pure <| Std.HashSet.empty.insert (mkApp2 (.const ``Int.min_le_left []) x y) |>.insert
pure <| HashSet.empty.insert (mkApp2 (.const ``Int.min_le_left []) x y) |>.insert
(mkApp2 (.const ``Int.min_le_right []) x y)
| (``Max.max, #[_, _, x, y]) =>
pure <| Std.HashSet.empty.insert (mkApp2 (.const ``Int.le_max_left []) x y) |>.insert
pure <| HashSet.empty.insert (mkApp2 (.const ``Int.le_max_left []) x y) |>.insert
(mkApp2 (.const ``Int.le_max_right []) x y)
| (``ite, #[α, i, dec, t, e]) =>
if α == (.const ``Int []) then
pure <| Std.HashSet.empty.insert <| mkApp5 (.const ``ite_disjunction [0]) α i dec t e
pure <| HashSet.empty.insert <| mkApp5 (.const ``ite_disjunction [0]) α i dec t e
else
pure {}
| _ => pure
@@ -250,10 +248,10 @@ Return its index, and, if it is new, a collection of interesting facts about the
* for each new atom of the form `((a - b : Nat) : Int)`, the fact:
`b ≤ a ∧ ((a - b : Nat) : Int) = a - b a < b ∧ ((a - b : Nat) : Int) = 0`
-/
def lookup (e : Expr) : OmegaM (Nat × Option (Std.HashSet Expr)) := do
def lookup (e : Expr) : OmegaM (Nat × Option (HashSet Expr)) := do
let c getThe State
let e canon e
match c.atoms[e]? with
match c.atoms.find? e with
| some i => return (i, none)
| none =>
trace[omega] "New atom: {e}"

View File

@@ -7,8 +7,8 @@ prelude
import Init.Control.StateRef
import Init.Data.Array.BinSearch
import Init.Data.Stream
import Lean.ImportingFlag
import Lean.Data.HashMap
import Lean.ImportingFlag
import Lean.Data.SMap
import Lean.Declaration
import Lean.LocalContext
@@ -134,7 +134,7 @@ structure Environment where
the field `constants`. These auxiliary constants are invisible to the Lean kernel and elaborator.
Only the code generator uses them.
-/
const2ModIdx : Std.HashMap Name ModuleIdx
const2ModIdx : HashMap Name ModuleIdx
/--
Mapping from constant name to `ConstantInfo`. It contains all constants (definitions, theorems, axioms, etc)
that have been already type checked by the kernel.
@@ -205,7 +205,7 @@ private def getTrustLevel (env : Environment) : UInt32 :=
env.header.trustLevel
def getModuleIdxFor? (env : Environment) (declName : Name) : Option ModuleIdx :=
env.const2ModIdx[declName]?
env.const2ModIdx.find? declName
def isConstructor (env : Environment) (declName : Name) : Bool :=
match env.find? declName with
@@ -721,7 +721,7 @@ def writeModule (env : Environment) (fname : System.FilePath) : IO Unit := do
Construct a mapping from persistent extension name to entension index at the array of persistent extensions.
We only consider extensions starting with index `>= startingAt`.
-/
def mkExtNameMap (startingAt : Nat) : IO (Std.HashMap Name Nat) := do
def mkExtNameMap (startingAt : Nat) : IO (HashMap Name Nat) := do
let descrs persistentEnvExtensionsRef.get
let mut result := {}
for h : i in [startingAt : descrs.size] do
@@ -742,7 +742,7 @@ private def setImportedEntries (env : Environment) (mods : Array ModuleData) (st
have : modIdx < mods.size := h.upper
let mod := mods[modIdx]
for (extName, entries) in mod.entries do
if let some entryIdx := extNameIdx[extName]? then
if let some entryIdx := extNameIdx.find? extName then
env := extDescrs[entryIdx]!.toEnvExtension.modifyState env fun s => { s with importedEntries := s.importedEntries.set! modIdx entries }
return env
@@ -790,9 +790,9 @@ structure ImportState where
moduleData : Array ModuleData := #[]
regions : Array CompactedRegion := #[]
def throwAlreadyImported (s : ImportState) (const2ModIdx : Std.HashMap Name ModuleIdx) (modIdx : Nat) (cname : Name) : IO α := do
def throwAlreadyImported (s : ImportState) (const2ModIdx : HashMap Name ModuleIdx) (modIdx : Nat) (cname : Name) : IO α := do
let modName := s.moduleNames[modIdx]!
let constModName := s.moduleNames[const2ModIdx[cname]!.toNat]!
let constModName := s.moduleNames[const2ModIdx[cname].get!.toNat]!
throw <| IO.userError s!"import {modName} failed, environment already contains '{cname}' from {constModName}"
abbrev ImportStateM := StateRefT ImportState IO
@@ -856,21 +856,21 @@ def finalizeImport (s : ImportState) (imports : Array Import) (opts : Options) (
(leakEnv := false) : IO Environment := do
let numConsts := s.moduleData.foldl (init := 0) fun numConsts mod =>
numConsts + mod.constants.size + mod.extraConstNames.size
let mut const2ModIdx : Std.HashMap Name ModuleIdx := Std.HashMap.empty (capacity := numConsts)
let mut constantMap : Std.HashMap Name ConstantInfo := Std.HashMap.empty (capacity := numConsts)
let mut const2ModIdx : HashMap Name ModuleIdx := mkHashMap (capacity := numConsts)
let mut constantMap : HashMap Name ConstantInfo := mkHashMap (capacity := numConsts)
for h:modIdx in [0:s.moduleData.size] do
let mod := s.moduleData[modIdx]'h.upper
for cname in mod.constNames, cinfo in mod.constants do
match constantMap.getThenInsertIfNew? cname cinfo with
| (cinfoPrev?, constantMap') =>
match constantMap.insertIfNew cname cinfo with
| (constantMap', cinfoPrev?) =>
constantMap := constantMap'
if let some cinfoPrev := cinfoPrev? then
-- Recall that the map has not been modified when `cinfoPrev? = some _`.
unless equivInfo cinfoPrev cinfo do
throwAlreadyImported s const2ModIdx modIdx cname
const2ModIdx := const2ModIdx.insertIfNew cname modIdx
const2ModIdx := const2ModIdx.insertIfNew cname modIdx |>.1
for cname in mod.extraConstNames do
const2ModIdx := const2ModIdx.insertIfNew cname modIdx
const2ModIdx := const2ModIdx.insertIfNew cname modIdx |>.1
let constants : ConstMap := SMap.fromHashMap constantMap false
let exts mkInitialExtensionStates
let mut env : Environment := {
@@ -936,7 +936,7 @@ builtin_initialize namespacesExt : SimplePersistentEnvExtension Name NameSSet
6.18% of the runtime is here. It was 9.31% before the `HashMap` optimization.
-/
let capacity := as.foldl (init := 0) fun r e => r + e.size
let map : Std.HashMap Name Unit := Std.HashMap.empty capacity
let map : HashMap Name Unit := mkHashMap capacity
let map := mkStateFromImportedEntries (fun map name => map.insert name ()) map as
SMap.fromHashMap map |>.switch
addEntryFn := fun s n => s.insert n

View File

@@ -8,7 +8,6 @@ import Init.Data.Hashable
import Lean.Data.KVMap
import Lean.Data.SMap
import Lean.Level
import Std.Data.HashSet.Basic
namespace Lean
@@ -245,7 +244,7 @@ def FVarIdSet.insert (s : FVarIdSet) (fvarId : FVarId) : FVarIdSet :=
A set of unique free variable identifiers implemented using hashtables.
Hashtables are faster than red-black trees if they are used linearly.
They are not persistent data-structures. -/
def FVarIdHashSet := Std.HashSet FVarId
def FVarIdHashSet := HashSet FVarId
deriving Inhabited, EmptyCollection
/--
@@ -1389,11 +1388,11 @@ def mkDecIsTrue (pred proof : Expr) :=
def mkDecIsFalse (pred proof : Expr) :=
mkAppB (mkConst `Decidable.isFalse) pred proof
abbrev ExprMap (α : Type) := Std.HashMap Expr α
abbrev ExprMap (α : Type) := HashMap Expr α
abbrev PersistentExprMap (α : Type) := PHashMap Expr α
abbrev SExprMap (α : Type) := SMap Expr α
abbrev ExprSet := Std.HashSet Expr
abbrev ExprSet := HashSet Expr
abbrev PersistentExprSet := PHashSet Expr
abbrev PExprSet := PersistentExprSet
@@ -1418,7 +1417,7 @@ instance : ToString ExprStructEq := ⟨fun e => toString e.val⟩
end ExprStructEq
abbrev ExprStructMap (α : Type) := Std.HashMap ExprStructEq α
abbrev ExprStructMap (α : Type) := HashMap ExprStructEq α
abbrev PersistentExprStructMap (α : Type) := PHashMap ExprStructEq α
namespace Expr

View File

@@ -31,7 +31,7 @@ namespace Lean
abbrev LabelExtension := SimpleScopedEnvExtension Name (Array Name)
/-- The collection of all current `LabelExtension`s, indexed by name. -/
abbrev LabelExtensionMap := Std.HashMap Name LabelExtension
abbrev LabelExtensionMap := HashMap Name LabelExtension
/-- Store the current `LabelExtension`s. -/
builtin_initialize labelExtensionMapRef : IO.Ref LabelExtensionMap IO.mkRef {}
@@ -88,7 +88,7 @@ macro (name := _root_.Lean.Parser.Command.registerLabelAttr)
/-- When `attrName` is an attribute created using `register_labelled_attr`,
return the names of all declarations labelled using that attribute. -/
def labelled (attrName : Name) : CoreM (Array Name) := do
match ( labelExtensionMapRef.get)[attrName]? with
match ( labelExtensionMapRef.get).find? attrName with
| none => throwError "No extension named {attrName}"
| some ext => pure <| ext.getState ( getEnv)

View File

@@ -519,17 +519,12 @@ where
-- definitely resolved in `doElab` task
let elabPromise IO.Promise.new
let finishedPromise IO.Promise.new
-- (Try to) use last line of command as range for final snapshot task. This ensures we do not
-- retract the progress bar to a previous position in case the command support incremental
-- reporting but has significant work after resolving its last incremental promise, such as
-- final type checking; if it does not support incrementality, `elabSnap` constructed in
-- `parseCmd` and containing the entire range of the command will determine the reported
-- progress and be resolved effectively at the same time as this snapshot task, so `tailPos` is
-- irrelevant in this case.
let endRange? := stx.getTailPos?.map fun pos => pos, pos
let finishedSnap := { range? := endRange?, task := finishedPromise.result }
let tacticCache old?.map (·.data.tacticCache) |>.getDM (IO.mkRef {})
let finishedSnap
doElab stx cmdState beginPos
{ old? := old?.map fun old => old.data.stx, old.data.elabSnap, new := elabPromise }
tacticCache
ctx
let minimalSnapshots := internal.minimalSnapshots.get cmdState.scopes.head!.opts
let next? if Parser.isTerminalCommand stx then pure none
@@ -541,31 +536,35 @@ where
stx := .missing
parserState := {}
elabSnap := { range? := stx.getRange?, task := elabPromise.result }
finishedSnap := { range? := none, task := finishedPromise.result.map fun finishedSnap => {
finishedSnap := .pure {
diagnostics := finishedSnap.diagnostics
infoTree? := none
cmdState := {
env := initEnv
maxRecDepth := 0
}
}}
}
tacticCache
} else {
diagnostics, stx, parserState, tacticCache
elabSnap := { range? := stx.getRange?, task := elabPromise.result }
finishedSnap
finishedSnap := .pure finishedSnap
}
prom.resolve <| .mk (nextCmdSnap? := next?.map ({ range? := some parserState.pos, ctx.input.endPos, task := ·.result })) data
doElab stx cmdState beginPos
{ old? := old?.map fun old => old.data.stx, old.data.elabSnap, new := elabPromise }
finishedPromise tacticCache ctx
if let some next := next? then
parseCmd none parserState finishedSnap.get.cmdState initEnv next ctx
parseCmd none parserState finishedSnap.cmdState initEnv next ctx
doElab (stx : Syntax) (cmdState : Command.State) (beginPos : String.Pos)
(snap : SnapshotBundle DynamicSnapshot) (finishedPromise : IO.Promise CommandFinishedSnapshot)
(tacticCache : IO.Ref Tactic.Cache) : LeanProcessingM Unit := do
(snap : SnapshotBundle DynamicSnapshot) (tacticCache : IO.Ref Tactic.Cache) :
LeanProcessingM CommandFinishedSnapshot := do
let ctx read
-- (Try to) use last line of command as range for final snapshot task. This ensures we do not
-- retract the progress bar to a previous position in case the command support incremental
-- reporting but has significant work after resolving its last incremental promise, such as
-- final type checking; if it does not support incrementality, `elabSnap` constructed in
-- `parseCmd` and containing the entire range of the command will determine the reported
-- progress and be resolved effectively at the same time as this snapshot task, so `tailPos` is
-- irrelevant in this case.
let scope := cmdState.scopes.head!
let cmdStateRef IO.mkRef { cmdState with messages := .empty }
/-
@@ -601,7 +600,7 @@ where
let cmdState := { cmdState with messages }
-- definitely resolve eventually
snap.new.resolve <| .ofTyped { diagnostics := .empty : SnapshotLeaf }
finishedPromise.resolve {
return {
diagnostics := ( Snapshot.Diagnostics.ofMessageLog cmdState.messages)
infoTree? := some cmdState.infoState.trees[0]!
cmdState

View File

@@ -614,9 +614,9 @@ where
end Level
abbrev LevelMap (α : Type) := Std.HashMap Level α
abbrev LevelMap (α : Type) := HashMap Level α
abbrev PersistentLevelMap (α : Type) := PHashMap Level α
abbrev LevelSet := Std.HashSet Level
abbrev LevelSet := HashSet Level
abbrev PersistentLevelSet := PHashSet Level
abbrev PLevelSet := PersistentLevelSet

View File

@@ -34,7 +34,7 @@ def constructorNameAsVariable : Linter where
| return
let infoTrees := ( get).infoState.trees.toArray
let warnings : IO.Ref (Std.HashMap String.Range (Syntax × Name × Name)) IO.mkRef {}
let warnings : IO.Ref (Lean.HashMap String.Range (Syntax × Name × Name)) IO.mkRef {}
for tree in infoTrees do
tree.visitM' (preNode := fun ci info _ => do

View File

@@ -149,7 +149,7 @@ def checkDecl : SimpleHandler := fun stx => do
lintField rest[1][0] stx[1] "computed field"
else if rest.getKind == ``«structure» then
unless rest[5][2].isNone do
let redecls : Std.HashSet String.Pos :=
let redecls : HashSet String.Pos :=
( get).infoState.trees.foldl (init := {}) fun s tree =>
tree.foldInfo (init := s) fun _ info s =>
if let .ofFieldRedeclInfo info := info then

View File

@@ -270,14 +270,14 @@ pointer identity and does not store the objects, so it is important not to store
pointer to an object in the map, or it can be freed and reused, resulting in incorrect behavior.
Returns `true` if the object was not already in the set. -/
unsafe def insertObjImpl {α : Type} (set : IO.Ref (Std.HashSet USize)) (a : α) : IO Bool := do
unsafe def insertObjImpl {α : Type} (set : IO.Ref (HashSet USize)) (a : α) : IO Bool := do
if ( set.get).contains (ptrAddrUnsafe a) then
return false
set.modify (·.insert (ptrAddrUnsafe a))
return true
@[inherit_doc insertObjImpl, implemented_by insertObjImpl]
opaque insertObj {α : Type} (set : IO.Ref (Std.HashSet USize)) (a : α) : IO Bool
opaque insertObj {α : Type} (set : IO.Ref (HashSet USize)) (a : α) : IO Bool
/--
Collects into `fvarUses` all `fvar`s occurring in the `Expr`s in `assignments`.
@@ -285,8 +285,8 @@ This implementation respects subterm sharing in both the `PersistentHashMap` and
to ensure that pointer-equal subobjects are not visited multiple times, which is important
in practice because these expressions are very frequently highly shared.
-/
partial def visitAssignments (set : IO.Ref (Std.HashSet USize))
(fvarUses : IO.Ref (Std.HashSet FVarId))
partial def visitAssignments (set : IO.Ref (HashSet USize))
(fvarUses : IO.Ref (HashSet FVarId))
(assignments : Array (PersistentHashMap MVarId Expr)) : IO Unit := do
MonadCacheT.run do
for assignment in assignments do
@@ -316,8 +316,8 @@ where
/-- Given `aliases` as a map from an alias to what it aliases, we get the original
term by recursion. This has no cycle detection, so if `aliases` contains a loop
then this function will recurse infinitely. -/
partial def followAliases (aliases : Std.HashMap FVarId FVarId) (x : FVarId) : FVarId :=
match aliases[x]? with
partial def followAliases (aliases : HashMap FVarId FVarId) (x : FVarId) : FVarId :=
match aliases.find? x with
| none => x
| some y => followAliases aliases y
@@ -343,17 +343,17 @@ structure References where
the spans for `foo`, `bar`, and `baz`. Global definitions are always treated as used.
(It would be nice to be able to detect unused global definitions but this requires more
information than the linter framework can provide.) -/
constDecls : Std.HashSet String.Range := .empty
constDecls : HashSet String.Range := .empty
/-- The collection of all local declarations, organized by the span of the declaration.
We collapse all declarations declared at the same position into a single record using
`FVarDefinition.aliases`. -/
fvarDefs : Std.HashMap String.Range FVarDefinition := .empty
fvarDefs : HashMap String.Range FVarDefinition := .empty
/-- The set of `FVarId`s that are used directly. These may or may not be aliases. -/
fvarUses : Std.HashSet FVarId := .empty
fvarUses : HashSet FVarId := .empty
/-- A mapping from alias to original FVarId. We don't guarantee that the value is not itself
an alias, but we use `followAliases` when adding new elements to try to avoid long chains. -/
-- TODO: use a `UnionFind` data structure here
fvarAliases : Std.HashMap FVarId FVarId := .empty
fvarAliases : HashMap FVarId FVarId := .empty
/-- Collection of all `MetavarContext`s following the execution of a tactic. We trawl these
if needed to find additional `fvarUses`. -/
assignments : Array (PersistentHashMap MVarId Expr) := #[]
@@ -391,7 +391,7 @@ def collectReferences (infoTrees : Array Elab.InfoTree) (cmdStxRange : String.Ra
if s.startsWith "_" then return
-- Record this either as a new `fvarDefs`, or an alias of an existing one
modify fun s =>
if let some ref := s.fvarDefs[range]? then
if let some ref := s.fvarDefs.find? range then
{ s with fvarDefs := s.fvarDefs.insert range { ref with aliases := ref.aliases.push id } }
else
{ s with fvarDefs := s.fvarDefs.insert range { userName := ldecl.userName, stx, opts, aliases := #[id] } }
@@ -444,7 +444,7 @@ def unusedVariables : Linter where
-- Resolve all recursive references in `fvarAliases`.
-- At this point everything in `fvarAliases` is guaranteed not to be itself an alias,
-- and should point to some element of `FVarDefinition.aliases` in `s.fvarDefs`
let fvarAliases : Std.HashMap FVarId FVarId := s.fvarAliases.fold (init := {}) fun m id baseId =>
let fvarAliases : HashMap FVarId FVarId := s.fvarAliases.fold (init := {}) fun m id baseId =>
m.insert id (followAliases s.fvarAliases baseId)
-- Collect all non-alias fvars corresponding to `fvarUses` by resolving aliases in the list.
@@ -461,7 +461,7 @@ def unusedVariables : Linter where
let fvarUses fvarUsesRef.get
-- If any of the `fvar`s corresponding to this declaration is (an alias of) a variable in
-- `fvarUses`, then it is used
if aliases.any fun id => fvarUses.contains (fvarAliases.getD id id) then continue
if aliases.any fun id => fvarUses.contains (fvarAliases.findD id id) then continue
-- If this is a global declaration then it is (potentially) used after the command
if s.constDecls.contains range then continue
@@ -496,7 +496,7 @@ def unusedVariables : Linter where
initializedMVars := true
let fvarUses fvarUsesRef.get
-- Redo the initial check because `fvarUses` could be bigger now
if aliases.any fun id => fvarUses.contains (fvarAliases.getD id id) then continue
if aliases.any fun id => fvarUses.contains (fvarAliases.findD id id) then continue
-- If we made it this far then the variable is unused and not ignored
unused := unused.push (declStx, userName)

View File

@@ -16,8 +16,8 @@ structure State where
nextParamIdx : Nat := 0
paramNames : Array Name := #[]
fvars : Array Expr := #[]
lmap : Std.HashMap LMVarId Level := {}
emap : Std.HashMap MVarId Expr := {}
lmap : HashMap LMVarId Level := {}
emap : HashMap MVarId Expr := {}
abstractLevels : Bool -- whether to abstract level mvars
abbrev M := StateM State
@@ -54,7 +54,7 @@ private partial def abstractLevelMVars (u : Level) : M Level := do
if depth != s.mctx.depth then
return u -- metavariables from lower depths are treated as constants
else
match s.lmap[mvarId]? with
match s.lmap.find? mvarId with
| some u => pure u
| none =>
let paramId := Name.mkNum `_abstMVar s.nextParamIdx
@@ -87,7 +87,7 @@ partial def abstractExprMVars (e : Expr) : M Expr := do
if e != eNew then
abstractExprMVars eNew
else
match ( get).emap[mvarId]? with
match ( get).emap.find? mvarId with
| some e =>
return e
| none =>

View File

@@ -4,11 +4,10 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Data.HashMap
import Lean.Util.ShareCommon
import Lean.Data.HashMap
import Lean.Meta.Basic
import Lean.Meta.FunInfo
import Std.Data.HashMap.Raw
namespace Lean.Meta
namespace Canonicalizer
@@ -48,12 +47,12 @@ State for the `CanonM` monad.
-/
structure State where
/-- Mapping from `Expr` to hash. -/
-- We use `HashMap.Raw` to ensure we don't have to tag `State` as `unsafe`.
cache : Std.HashMap.Raw ExprVisited UInt64 := Std.HashMap.Raw.empty
-- We use `HashMapImp` to ensure we don't have to tag `State` as `unsafe`.
cache : HashMapImp ExprVisited UInt64 := mkHashMapImp
/--
Given a hashcode `k` and `keyToExprs.find? h = some es`, we have that all `es` have hashcode `k`, and
are not definitionally equal modulo the transparency setting used. -/
keyToExprs : Std.HashMap UInt64 (List Expr) :=
keyToExprs : HashMap UInt64 (List Expr) := mkHashMap
instance : Inhabited State where
default := {}
@@ -71,7 +70,7 @@ def CanonM.run (x : CanonM α) (transparency := TransparencyMode.instances) (s :
StateRefT'.run (x transparency) s
private partial def mkKey (e : Expr) : CanonM UInt64 := do
if let some hash := unsafe ( get).cache.get? { e } then
if let some hash := unsafe ( get).cache.find? { e } then
return hash
else
let key match e with
@@ -108,7 +107,7 @@ private partial def mkKey (e : Expr) : CanonM UInt64 := do
return mixHash ( mkKey v) ( mkKey b)
| .proj _ i s =>
return mixHash i.toUInt64 ( mkKey s)
unsafe modify fun { cache, keyToExprs} => { keyToExprs, cache := cache.insert { e } key }
unsafe modify fun { cache, keyToExprs} => { keyToExprs, cache := cache.insert { e } key |>.1 }
return key
/--
@@ -117,7 +116,7 @@ private partial def mkKey (e : Expr) : CanonM UInt64 := do
def canon (e : Expr) : CanonM Expr := do
let k mkKey e
-- Find all expressions canonicalized before that have the same key.
if let some es' := unsafe ( get).keyToExprs[k]? then
if let some es' := unsafe ( get).keyToExprs.find? k then
withTransparency ( read) do
for e' in es' do
-- Found an expression `e'` that is definitionally equal to `e` and share the same key.

View File

@@ -127,7 +127,7 @@ abbrev ClosureM := ReaderT Context $ StateRefT State MetaM
pure u
else
let s get
match s.visitedLevel[u]? with
match s.visitedLevel.find? u with
| some v => pure v
| none => do
let v f u
@@ -139,7 +139,7 @@ abbrev ClosureM := ReaderT Context $ StateRefT State MetaM
pure e
else
let s get
match s.visitedExpr.get? e with
match s.visitedExpr.find? e with
| some r => pure r
| none =>
let r f e

View File

@@ -52,14 +52,14 @@ which appear in the type and local context of `mvarId`, as well as the
metavariables which *those* metavariables depend on, etc.
-/
partial def _root_.Lean.MVarId.getMVarDependencies (mvarId : MVarId) (includeDelayed := false) :
MetaM (Std.HashSet MVarId) :=
MetaM (HashSet MVarId) :=
(·.snd) <$> (go mvarId).run {}
where
/-- Auxiliary definition for `getMVarDependencies`. -/
addMVars (e : Expr) : StateRefT (Std.HashSet MVarId) MetaM Unit := do
addMVars (e : Expr) : StateRefT (HashSet MVarId) MetaM Unit := do
let mvars getMVars e
let mut s get
set ({} : Std.HashSet MVarId) -- Ensure that `s` is not shared.
set ({} : HashSet MVarId) -- Ensure that `s` is not shared.
for mvarId in mvars do
if pure includeDelayed <||> notM (mvarId.isDelayedAssigned) then
s := s.insert mvarId
@@ -67,7 +67,7 @@ where
mvars.forM go
/-- Auxiliary definition for `getMVarDependencies`. -/
go (mvarId : MVarId) : StateRefT (Std.HashSet MVarId) MetaM Unit :=
go (mvarId : MVarId) : StateRefT (HashSet MVarId) MetaM Unit :=
withIncRecDepth do
let mdecl mvarId.getDecl
addMVars mdecl.type

View File

@@ -695,7 +695,7 @@ def throwOutOfScopeFVar : CheckAssignmentM α :=
throw <| Exception.internal outOfScopeExceptionId
private def findCached? (e : Expr) : CheckAssignmentM (Option Expr) := do
return ( get).cache.get? e
return ( get).cache.find? e
private def cache (e r : Expr) : CheckAssignmentM Unit := do
modify fun s => { s with cache := s.cache.insert e r }

View File

@@ -4,7 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
-/
prelude
import Lean.Data.AssocList
import Lean.HeadIndex
import Lean.Meta.Basic

View File

@@ -286,7 +286,7 @@ private structure Trie (α : Type) where
/-- Index of trie matching star. -/
star : TrieIndex
/-- Following matches based on key of trie. -/
children : Std.HashMap Key TrieIndex
children : HashMap Key TrieIndex
/-- Lazy entries at this trie that are not processed. -/
pending : Array (LazyEntry α) := #[]
deriving Inhabited
@@ -318,7 +318,7 @@ structure LazyDiscrTree (α : Type) where
/-- Backing array of trie entries. Should be owned by this trie. -/
tries : Array (LazyDiscrTree.Trie α) := #[default]
/-- Map from discriminator trie roots to the index. -/
roots : Std.HashMap LazyDiscrTree.Key LazyDiscrTree.TrieIndex := {}
roots : Lean.HashMap LazyDiscrTree.Key LazyDiscrTree.TrieIndex := {}
namespace LazyDiscrTree
@@ -445,9 +445,9 @@ private def addLazyEntryToTrie (i:TrieIndex) (e : LazyEntry α) : MatchM α Unit
modify (·.modify i (·.pushPending e))
private def evalLazyEntry (config : WhnfCoreConfig)
(p : Array α × TrieIndex × Std.HashMap Key TrieIndex)
(p : Array α × TrieIndex × HashMap Key TrieIndex)
(entry : LazyEntry α)
: MatchM α (Array α × TrieIndex × Std.HashMap Key TrieIndex) := do
: MatchM α (Array α × TrieIndex × HashMap Key TrieIndex) := do
let (values, starIdx, children) := p
let (todo, lctx, v) := entry
if todo.isEmpty then
@@ -465,7 +465,7 @@ private def evalLazyEntry (config : WhnfCoreConfig)
addLazyEntryToTrie starIdx (todo, lctx, v)
pure (values, starIdx, children)
else
match children[k]? with
match children.find? k with
| none =>
let children := children.insert k ( newTrie (todo, lctx, v))
pure (values, starIdx, children)
@@ -478,16 +478,16 @@ This evaluates all lazy entries in a trie and updates `values`, `starIdx`, and `
accordingly.
-/
private partial def evalLazyEntries (config : WhnfCoreConfig)
(values : Array α) (starIdx : TrieIndex) (children : Std.HashMap Key TrieIndex)
(values : Array α) (starIdx : TrieIndex) (children : HashMap Key TrieIndex)
(entries : Array (LazyEntry α)) :
MatchM α (Array α × TrieIndex × Std.HashMap Key TrieIndex) := do
MatchM α (Array α × TrieIndex × HashMap Key TrieIndex) := do
let mut values := values
let mut starIdx := starIdx
let mut children := children
entries.foldlM (init := (values, starIdx, children)) (evalLazyEntry config)
private def evalNode (c : TrieIndex) :
MatchM α (Array α × TrieIndex × Std.HashMap Key TrieIndex) := do
MatchM α (Array α × TrieIndex × HashMap Key TrieIndex) := do
let .node vs star cs pending := (get).get! c
if pending.size = 0 then
pure (vs, star, cs)
@@ -508,7 +508,7 @@ def dropKeyAux (next : TrieIndex) (rest : List Key) :
| [] =>
modify (·.set! next {values := #[], star, children})
| k :: r => do
let next := if k == .star then star else children.getD k 0
let next := if k == .star then star else children.findD k 0
dropKeyAux next r
/--
@@ -519,7 +519,7 @@ def dropKey (t : LazyDiscrTree α) (path : List LazyDiscrTree.Key) : MetaM (Lazy
match path with
| [] => pure t
| rootKey :: rest => do
let idx := t.roots.getD rootKey 0
let idx := t.roots.findD rootKey 0
Prod.snd <$> runMatch t (dropKeyAux idx rest)
/--
@@ -628,7 +628,7 @@ private partial def getMatchLoop (cases : Array PartialMatch) (result : MatchRes
else
cases.push { todo, score := ca.score, c := star }
let pushNonStar (k : Key) (args : Array Expr) (cases : Array PartialMatch) :=
match cs[k]? with
match cs.find? k with
| none => cases
| some c => cases.push { todo := todo ++ args, score := ca.score + 1, c }
let cases := pushStar cases
@@ -650,8 +650,8 @@ private partial def getMatchLoop (cases : Array PartialMatch) (result : MatchRes
cases |> pushNonStar k args
getMatchLoop cases result
private def getStarResult (root : Std.HashMap Key TrieIndex) : MatchM α (MatchResult α) :=
match root[Key.star]? with
private def getStarResult (root : Lean.HashMap Key TrieIndex) : MatchM α (MatchResult α) :=
match root.find? .star with
| none =>
pure <| {}
| some idx => do
@@ -661,16 +661,16 @@ private def getStarResult (root : Std.HashMap Key TrieIndex) : MatchM α (MatchR
/-
Add partial match to cases if discriminator tree root map has potential matches.
-/
private def pushRootCase (r : Std.HashMap Key TrieIndex) (k : Key) (args : Array Expr)
private def pushRootCase (r : Lean.HashMap Key TrieIndex) (k : Key) (args : Array Expr)
(cases : Array PartialMatch) : Array PartialMatch :=
match r[k]? with
match r.find? k with
| none => cases
| some c => cases.push { todo := args, score := 1, c }
/--
Find values that match `e` in `root`.
-/
private def getMatchCore (root : Std.HashMap Key TrieIndex) (e : Expr) :
private def getMatchCore (root : Lean.HashMap Key TrieIndex) (e : Expr) :
MatchM α (MatchResult α) := do
let result getStarResult root
let (k, args) MatchClone.getMatchKeyArgs e (root := true) ( read)
@@ -701,7 +701,7 @@ of elements using concurrent functions for generating entries.
-/
private structure PreDiscrTree (α : Type) where
/-- Maps keys to index in tries array. -/
roots : Std.HashMap Key Nat := {}
roots : HashMap Key Nat := {}
/-- Lazy entries for root of trie. -/
tries : Array (Array (LazyEntry α)) := #[]
deriving Inhabited
@@ -711,7 +711,7 @@ namespace PreDiscrTree
private def modifyAt (d : PreDiscrTree α) (k : Key)
(f : Array (LazyEntry α) Array (LazyEntry α)) : PreDiscrTree α :=
let { roots, tries } := d
match roots[k]? with
match roots.find? k with
| .none =>
let roots := roots.insert k tries.size
{ roots, tries := tries.push (f #[]) }

View File

@@ -68,7 +68,7 @@ where
loop lhss alts minors
structure State where
used : Std.HashSet Nat := {} -- used alternatives
used : HashSet Nat := {} -- used alternatives
counterExamples : List (List Example) := []
/-- Return true if the given (sub-)problem has been solved. -/

View File

@@ -28,17 +28,17 @@ such as `contradiction`.
-/
private def _root_.Lean.MVarId.contradictionQuick (mvarId : MVarId) : MetaM Bool := do
mvarId.withContext do
let mut posMap : Std.HashMap Expr FVarId := {}
let mut negMap : Std.HashMap Expr FVarId := {}
let mut posMap : HashMap Expr FVarId := {}
let mut negMap : HashMap Expr FVarId := {}
for localDecl in ( getLCtx) do
unless localDecl.isImplementationDetail do
if let some p matchNot? localDecl.type then
if let some pFVarId := posMap[p]? then
if let some pFVarId := posMap.find? p then
mvarId.assign ( mkAbsurd ( mvarId.getType) (mkFVar pFVarId) localDecl.toExpr)
return true
negMap := negMap.insert p localDecl.fvarId
if ( isProp localDecl.type) then
if let some nFVarId := negMap[localDecl.type]? then
if let some nFVarId := negMap.find? localDecl.type then
mvarId.assign ( mkAbsurd ( mvarId.getType) localDecl.toExpr (mkFVar nFVarId))
return true
posMap := posMap.insert localDecl.type localDecl.fvarId

View File

@@ -97,8 +97,8 @@ namespace MkTableKey
structure State where
nextIdx : Nat := 0
lmap : Std.HashMap LMVarId Level := {}
emap : Std.HashMap MVarId Expr := {}
lmap : HashMap LMVarId Level := {}
emap : HashMap MVarId Expr := {}
mctx : MetavarContext
abbrev M := StateM State
@@ -120,7 +120,7 @@ partial def normLevel (u : Level) : M Level := do
return u
else
let s get
match ( get).lmap[mvarId]? with
match ( get).lmap.find? mvarId with
| some u' => pure u'
| none =>
let u' := mkLevelParam <| Name.mkNum `_tc s.nextIdx
@@ -145,7 +145,7 @@ partial def normExpr (e : Expr) : M Expr := do
return e
else
let s get
match s.emap[mvarId]? with
match s.emap.find? mvarId with
| some e' => pure e'
| none => do
let e' := mkFVar { name := Name.mkNum `_tc s.nextIdx }
@@ -186,7 +186,7 @@ structure State where
result? : Option AbstractMVarsResult := none
generatorStack : Array GeneratorNode := #[]
resumeStack : Array (ConsumerNode × Answer) := #[]
tableEntries : Std.HashMap Expr TableEntry := {}
tableEntries : HashMap Expr TableEntry := {}
abbrev SynthM := ReaderT Context $ StateRefT State MetaM
@@ -265,7 +265,7 @@ def newSubgoal (mctx : MetavarContext) (key : Expr) (mvar : Expr) (waiter : Wait
pure ((), m!"new goal {key}")
def findEntry? (key : Expr) : SynthM (Option TableEntry) := do
return ( get).tableEntries[key]?
return ( get).tableEntries.find? key
def getEntry (key : Expr) : SynthM TableEntry := do
match ( findEntry? key) with
@@ -553,7 +553,7 @@ def generate : SynthM Unit := do
/- See comment at `typeHasMVars` -/
if backward.synthInstance.canonInstances.get ( getOptions) then
unless gNode.typeHasMVars do
if let some entry := ( get).tableEntries[key]? then
if let some entry := ( get).tableEntries.find? key then
if entry.answers.any fun answer => answer.result.numMVars == 0 then
/-
We already have an answer that:

View File

@@ -66,9 +66,9 @@ inductive PreExpr
def toACExpr (op l r : Expr) : MetaM (Array Expr × ACExpr) := do
let (preExpr, vars)
toPreExpr (mkApp2 op l r)
|>.run Std.HashSet.empty
|>.run HashSet.empty
let vars := vars.toArray.insertionSort Expr.lt
let varMap := vars.foldl (fun xs x => xs.insert x xs.size) Std.HashMap.empty |>.get!
let varMap := vars.foldl (fun xs x => xs.insert x xs.size) HashMap.empty |>.find!
return (vars, toACExpr varMap preExpr)
where

View File

@@ -290,7 +290,7 @@ structure RewriteResultConfig where
side : SideConditions := .solveByElim
mctx : MetavarContext
def takeListAux (cfg : RewriteResultConfig) (seen : Std.HashMap String Unit) (acc : Array RewriteResult)
def takeListAux (cfg : RewriteResultConfig) (seen : HashMap String Unit) (acc : Array RewriteResult)
(xs : List ((Expr Name) × Bool × Nat)) : MetaM (Array RewriteResult) := do
let mut seen := seen
let mut acc := acc

View File

@@ -420,12 +420,12 @@ def mkSimpExt (name : Name := by exact decl_name%) : IO SimpExtension :=
| .toUnfoldThms n thms => d.registerDeclToUnfoldThms n thms
}
abbrev SimpExtensionMap := Std.HashMap Name SimpExtension
abbrev SimpExtensionMap := HashMap Name SimpExtension
builtin_initialize simpExtensionMapRef : IO.Ref SimpExtensionMap IO.mkRef {}
def getSimpExtension? (attrName : Name) : IO (Option SimpExtension) :=
return ( simpExtensionMapRef.get)[attrName]?
return ( simpExtensionMapRef.get).find? attrName
/-- Auxiliary method for adding a global declaration to a `SimpTheorems` datastructure. -/
def SimpTheorems.addConst (s : SimpTheorems) (declName : Name) (post := true) (inv := false) (prio : Nat := eval_prio default) : MetaM SimpTheorems := do

View File

@@ -19,8 +19,8 @@ It contains:
- The actual procedure associated with a name.
-/
structure BuiltinSimprocs where
keys : Std.HashMap Name (Array SimpTheoremKey) := {}
procs : Std.HashMap Name (Sum Simproc DSimproc) := {}
keys : HashMap Name (Array SimpTheoremKey) := {}
procs : HashMap Name (Sum Simproc DSimproc) := {}
deriving Inhabited
/--
@@ -37,7 +37,7 @@ structure SimprocDecl where
deriving Inhabited
structure SimprocDeclExtState where
builtin : Std.HashMap Name (Array SimpTheoremKey)
builtin : HashMap Name (Array SimpTheoremKey)
newEntries : PHashMap Name (Array SimpTheoremKey) := {}
deriving Inhabited
@@ -65,7 +65,7 @@ def getSimprocDeclKeys? (declName : Name) : CoreM (Option (Array SimpTheoremKey)
if let some keys := keys? then
return some keys
else
return (simprocDeclExt.getState env).builtin[declName]?
return (simprocDeclExt.getState env).builtin.find? declName
def isBuiltinSimproc (declName : Name) : CoreM Bool := do
let s := simprocDeclExt.getState ( getEnv)
@@ -160,7 +160,7 @@ def Simprocs.addCore (s : Simprocs) (keys : Array SimpTheoremKey) (declName : Na
Implements attributes `builtin_simproc` and `builtin_sevalproc`.
-/
def addSimprocBuiltinAttrCore (ref : IO.Ref Simprocs) (declName : Name) (post : Bool) (proc : Sum Simproc DSimproc) : IO Unit := do
let some keys := ( builtinSimprocDeclsRef.get).keys[declName]? |
let some keys := ( builtinSimprocDeclsRef.get).keys.find? declName |
throw (IO.userError "invalid [builtin_simproc] attribute, '{declName}' is not a builtin simproc")
ref.modify fun s => s.addCore keys declName post proc
@@ -176,7 +176,7 @@ def Simprocs.add (s : Simprocs) (declName : Name) (post : Bool) : CoreM Simprocs
getSimprocFromDecl declName
catch e =>
if ( isBuiltinSimproc declName) then
let some proc := ( builtinSimprocDeclsRef.get).procs[declName]?
let some proc := ( builtinSimprocDeclsRef.get).procs.find? declName
| throwError "invalid [simproc] attribute, '{declName}' is not a simproc"
pure proc
else
@@ -384,7 +384,7 @@ def mkSimprocAttr (attrName : Name) (attrDescr : String) (ext : SimprocExtension
erase := eraseSimprocAttr ext
}
abbrev SimprocExtensionMap := Std.HashMap Name SimprocExtension
abbrev SimprocExtensionMap := HashMap Name SimprocExtension
builtin_initialize simprocExtensionMapRef : IO.Ref SimprocExtensionMap IO.mkRef {}
@@ -438,7 +438,7 @@ def getSEvalSimprocs : CoreM Simprocs :=
return simprocSEvalExtension.getState ( getEnv)
def getSimprocExtensionCore? (attrName : Name) : IO (Option SimprocExtension) :=
return ( simprocExtensionMapRef.get)[attrName]?
return ( simprocExtensionMapRef.get).find? attrName
def simpAttrNameToSimprocAttrName (attrName : Name) : Name :=
if attrName == `simp then `simprocAttr

View File

@@ -512,7 +512,7 @@ def mkCongrSimp? (f : Expr) : SimpM (Option CongrTheorem) := do
if kinds.all fun k => match k with | CongrArgKind.fixed => true | CongrArgKind.eq => true | _ => false then
/- See remark above. -/
return none
match ( get).congrCache[f]? with
match ( get).congrCache.find? f with
| some thm? => return thm?
| none =>
let thm? mkCongrSimpCore? f info kinds

View File

@@ -903,7 +903,7 @@ structure State where
mctx : MetavarContext
nextMacroScope : MacroScope
ngen : NameGenerator
cache : Std.HashMap ExprStructEq Expr := {}
cache : HashMap ExprStructEq Expr := {}
structure Context where
mainModule : Name
@@ -1319,7 +1319,7 @@ structure State where
mctx : MetavarContext
paramNames : Array Name := #[]
nextParamIdx : Nat
cache : Std.HashMap ExprStructEq Expr := {}
cache : HashMap ExprStructEq Expr := {}
abbrev M := ReaderT Context <| StateM State
@@ -1328,7 +1328,7 @@ instance : MonadMCtx M where
modifyMCtx f := modify fun s => { s with mctx := f s.mctx }
instance : MonadCache ExprStructEq Expr M where
findCached? e := return ( get).cache[e]?
findCached? e := return ( get).cache.find? e
cache e v := modify fun s => { s with cache := s.cache.insert e v }
partial def mkParamName : M Name := do

View File

@@ -242,10 +242,10 @@ def «structure» := leading_parser
@[builtin_command_parser] def noncomputableSection := leading_parser
"noncomputable " >> "section" >> optional (ppSpace >> checkColGt >> ident)
/--
A `section`/`end` pair delimits the scope of `variable`, `include, `open`, `set_option`, and `local`
commands. Sections can be nested. `section <id>` provides a label to the section that has to appear
with the matching `end`. In either case, the `end` can be omitted, in which case the section is
closed at the end of the file.
A `section`/`end` pair delimits the scope of `variable`, `open`, `set_option`, and `local` commands.
Sections can be nested. `section <id>` provides a label to the section that has to appear with the
matching `end`. In either case, the `end` can be omitted, in which case the section is closed at the
end of the file.
-/
@[builtin_command_parser] def «section» := leading_parser
"section" >> optional (ppSpace >> checkColGt >> ident)
@@ -274,12 +274,12 @@ with `end <id>`. The `end` command is optional at the end of a file.
@[builtin_command_parser] def «end» := leading_parser
"end" >> optional (ppSpace >> checkColGt >> ident)
/-- Declares one or more typed variables, or modifies whether already-declared variables are
implicit.
implicit.
Introduces variables that can be used in definitions within the same `namespace` or `section` block.
When a definition mentions a variable, Lean will add it as an argument of the definition. This is
useful in particular when writing many definitions that have parameters in common (see below for an
example).
When a definition mentions a variable, Lean will add it as an argument of the definition. The
`variable` command is also able to add typeclass parameters. This is useful in particular when
writing many definitions that have parameters in common (see below for an example).
Variable declarations have the same flexibility as regular function paramaters. In particular they
can be [explicit, implicit][binder docs], or [instance implicit][tpil classes] (in which case they
@@ -287,22 +287,17 @@ can be anonymous). This can be changed, for instance one can turn explicit varia
implicit one with `variable {x}`. Note that currently, you should avoid changing how variables are
bound and declare new variables at the same time; see [issue 2789] for more on this topic.
In *theorem bodies* (i.e. proofs), variables are not included based on usage in order to ensure that
changes to the proof cannot change the statement of the overall theorem. Instead, variables are only
available to the proof if they have been mentioned in the theorem header or in an `include` command
or are instance implicit and depend only on such variables.
See [*Variables and Sections* from Theorem Proving in Lean][tpil vars] for a more detailed
discussion.
[tpil vars]:
https://lean-lang.org/theorem_proving_in_lean4/dependent_type_theory.html#variables-and-sections
(Variables and Sections on Theorem Proving in Lean) [tpil classes]:
https://lean-lang.org/theorem_proving_in_lean4/type_classes.html (Type classes on Theorem Proving in
Lean) [binder docs]:
https://leanprover-community.github.io/mathlib4_docs/Lean/Expr.html#Lean.BinderInfo (Documentation
for the BinderInfo type) [issue 2789]: https://github.com/leanprover/lean4/issues/2789 (Issue 2789
on github)
[tpil vars]: https://lean-lang.org/theorem_proving_in_lean4/dependent_type_theory.html#variables-and-sections
(Variables and Sections on Theorem Proving in Lean)
[tpil classes]: https://lean-lang.org/theorem_proving_in_lean4/type_classes.html
(Type classes on Theorem Proving in Lean)
[binder docs]: https://leanprover-community.github.io/mathlib4_docs/Lean/Expr.html#Lean.BinderInfo
(Documentation for the BinderInfo type)
[issue 2789]: https://github.com/leanprover/lean4/issues/2789
(Issue 2789 on github)
## Examples
@@ -373,24 +368,6 @@ namespace Logger
end Logger
```
The following example demonstrates availability of variables in proofs:
```lean
variable
{α : Type} -- available in the proof as indirectly mentioned through `a`
[ToString α] -- available in the proof as `α` is included
(a : α) -- available in the proof as mentioned in the header
{β : Type} -- not available in the proof
[ToString β] -- not available in the proof
theorem ex : a = a := rfl
```
After elaboration of the proof, the following warning will be generated to highlight the unused
hypothesis:
```
included section variable '[ToString α]' is not used in 'ex', consider excluding it
```
In such cases, the offending variable declaration should be moved down or into a section so that
only theorems that do depend on it follow it until the end of the section.
-/
@[builtin_command_parser] def «variable» := leading_parser
"variable" >> many1 (ppSpace >> checkColGt >> Term.bracketedBinder)
@@ -726,13 +703,8 @@ list, so it should be brief.
@[builtin_command_parser] def genInjectiveTheorems := leading_parser
"gen_injective_theorems% " >> ident
/--
`include eeny meeny` instructs Lean to include the section `variable`s `eeny` and `meeny` in all
declarations in the remainder of the current section, differing from the default behavior of
conditionally including variables based on use in the declaration header. `include` is usually
followed by the `in` combinator to limit the inclusion to the subsequent declaration.
-/
@[builtin_command_parser] def «include» := leading_parser "include " >> many1 ident
/-- To be implemented. -/
@[builtin_command_parser] def «include» := leading_parser "include " >> many1 (checkColGt >> ident)
/-- No-op parser used as syntax kind for attaching remaining whitespace at the end of the input. -/
@[run_builtin_parser_attribute_hooks] def eoi : Parser := leading_parser ""

View File

@@ -131,7 +131,7 @@ structure ParserCacheEntry where
structure ParserCache where
tokenCache : TokenCacheEntry
parserCache : Std.HashMap ParserCacheKey ParserCacheEntry
parserCache : HashMap ParserCacheKey ParserCacheEntry
def initCacheForInput (input : String) : ParserCache where
tokenCache := { startPos := input.endPos + ' ' /- make sure it is not a valid position -/ }
@@ -418,7 +418,7 @@ place if there was an error.
-/
def withCacheFn (parserName : Name) (p : ParserFn) : ParserFn := fun c s => Id.run do
let key := c.toCacheableParserContext, parserName, s.pos
if let some r := s.cache.parserCache[key]? then
if let some r := s.cache.parserCache.find? key then
-- TODO: turn this into a proper trace once we have these in the parser
--dbg_trace "parser cache hit: {parserName}:{s.pos} -> {r.stx}"
return s.stxStack.push r.stx, r.lhsPrec, r.newPos, s.cache, r.errorMsg, s.recoveredErrors

View File

@@ -123,12 +123,6 @@ unsafe def mkDelabAttribute : IO (KeyedDeclsAttribute Delab) :=
} `Lean.PrettyPrinter.Delaborator.delabAttribute
@[builtin_init mkDelabAttribute] opaque delabAttribute : KeyedDeclsAttribute Delab
macro "app_delab" id:ident : attr => do
match Macro.resolveGlobalName id.getId with
| [] => Macro.throwErrorAt id s!"unknown declaration '{id.getId}'"
| [(c, [])] => `(attr| delab $(mkIdentFrom (canonical := true) id (`app ++ c)))
| _ => Macro.throwErrorAt id s!"ambiguous declaration '{id.getId}'"
def getExprKind : DelabM Name := do
let e getExpr
pure $ match e with

View File

@@ -198,10 +198,10 @@ def isHBinOp (e : Expr) : Bool := Id.run do
def replaceLPsWithVars (e : Expr) : MetaM Expr := do
if !e.hasLevelParam then return e
let lps := collectLevelParams {} e |>.params
let mut replaceMap : Std.HashMap Name Level := {}
let mut replaceMap : HashMap Name Level := {}
for lp in lps do replaceMap := replaceMap.insert lp ( mkFreshLevelMVar)
return e.replaceLevel fun
| Level.param n .. => replaceMap[n]!
| Level.param n .. => replaceMap.find! n
| l => if !l.hasParam then some l else none
def isDefEqAssigning (t s : Expr) : MetaM Bool := do

View File

@@ -29,7 +29,7 @@ namespace Lean.Environment
namespace Replay
structure Context where
newConstants : Std.HashMap Name ConstantInfo
newConstants : HashMap Name ConstantInfo
structure State where
env : Environment
@@ -73,7 +73,7 @@ and add it to the environment.
-/
partial def replayConstant (name : Name) : M Unit := do
if isTodo name then
let some ci := ( read).newConstants[name]? | unreachable!
let some ci := ( read).newConstants.find? name | unreachable!
replayConstants ci.getUsedConstantsAsSet
-- Check that this name is still pending: a mutual block may have taken care of it.
if ( get).pending.contains name then
@@ -89,13 +89,13 @@ partial def replayConstant (name : Name) : M Unit := do
| .inductInfo info =>
let lparams := info.levelParams
let nparams := info.numParams
let all info.all.mapM fun n => do pure <| (( read).newConstants[n]!)
let all info.all.mapM fun n => do pure <| (( read).newConstants.find! n)
for o in all do
modify fun s =>
{ s with remaining := s.remaining.erase o.name, pending := s.pending.erase o.name }
let ctorInfo all.mapM fun ci => do
pure (ci, ci.inductiveVal!.ctors.mapM fun n => do
pure (( read).newConstants[n]!))
pure (( read).newConstants.find! n))
-- Make sure we are really finished with the constructors.
for (_, ctors) in ctorInfo do
for ctor in ctors do
@@ -129,7 +129,7 @@ when we replayed the inductives.
-/
def checkPostponedConstructors : M Unit := do
for ctor in ( get).postponedConstructors do
match ( get).env.constants.find? ctor, ( read).newConstants[ctor]? with
match ( get).env.constants.find? ctor, ( read).newConstants.find? ctor with
| some (.ctorInfo info), some (.ctorInfo info') =>
if ! (info == info') then throw <| IO.userError s!"Invalid constructor {ctor}"
| _, _ => throw <| IO.userError s!"No such constructor {ctor}"
@@ -140,7 +140,7 @@ when we replayed the inductives.
-/
def checkPostponedRecursors : M Unit := do
for ctor in ( get).postponedRecursors do
match ( get).env.constants.find? ctor, ( read).newConstants[ctor]? with
match ( get).env.constants.find? ctor, ( read).newConstants.find? ctor with
| some (.recInfo info), some (.recInfo info') =>
if ! (info == info') then throw <| IO.userError s!"Invalid recursor {ctor}"
| _, _ => throw <| IO.userError s!"No such recursor {ctor}"
@@ -155,7 +155,7 @@ open Replay
Throws a `IO.userError` if the kernel rejects a constant,
or if there are malformed recursors or constructors for inductive types.
-/
def replay (newConstants : Std.HashMap Name ConstantInfo) (env : Environment) : IO Environment := do
def replay (newConstants : HashMap Name ConstantInfo) (env : Environment) : IO Environment := do
let mut remaining : NameSet :=
for (n, ci) in newConstants.toList do
-- We skip unsafe constants, and also partial constants.

View File

@@ -81,7 +81,7 @@ open Elab
open Meta
open FuzzyMatching
abbrev EligibleHeaderDecls := Std.HashMap Name ConstantInfo
abbrev EligibleHeaderDecls := HashMap Name ConstantInfo
/-- Cached header declarations for which `allowCompletion headerEnv decl` is true. -/
builtin_initialize eligibleHeaderDeclsRef : IO.Ref (Option EligibleHeaderDecls)

View File

@@ -316,7 +316,7 @@ partial def handleDocumentHighlight (p : DocumentHighlightParams)
let refs : Lsp.ModuleRefs findModuleRefs text trees |>.toLspModuleRefs
let mut ranges := #[]
for ident in refs.findAt p.position do
if let some info := refs.get? ident then
if let some info := refs.find? ident then
if let some definitionRange, _ := info.definition? then
ranges := ranges.push definitionRange
ranges := ranges.append <| info.usages.map (·.range)

View File

@@ -93,20 +93,20 @@ def toLspRefInfo (i : RefInfo) : BaseIO Lsp.RefInfo := do
end RefInfo
/-- All references from within a module for all identifiers used in a single module. -/
def ModuleRefs := Std.HashMap RefIdent RefInfo
def ModuleRefs := HashMap RefIdent RefInfo
namespace ModuleRefs
/-- Adds `ref` to the `RefInfo` corresponding to `ref.ident` in `self`. See `RefInfo.addRef`. -/
def addRef (self : ModuleRefs) (ref : Reference) : ModuleRefs :=
let refInfo := self.getD ref.ident RefInfo.empty
let refInfo := self.findD ref.ident RefInfo.empty
self.insert ref.ident (refInfo.addRef ref)
/-- Converts `refs` to a JSON-serializable `Lsp.ModuleRefs`. -/
def toLspModuleRefs (refs : ModuleRefs) : BaseIO Lsp.ModuleRefs := do
let refs refs.toList.mapM fun (k, v) => do
return (k, v.toLspRefInfo)
return Std.HashMap.ofList refs
return HashMap.ofList refs
end ModuleRefs
@@ -261,7 +261,7 @@ all identifiers that are being collapsed into one.
-/
partial def combineIdents (trees : Array InfoTree) (refs : Array Reference) : Array Reference := Id.run do
-- Deduplicate definitions based on their exact range
let mut posMap : Std.HashMap Lsp.Range RefIdent := Std.HashMap.empty
let mut posMap : HashMap Lsp.Range RefIdent := HashMap.empty
for ref in refs do
if let { ident, range, isBinder := true, .. } := ref then
posMap := posMap.insert range ident
@@ -277,17 +277,17 @@ partial def combineIdents (trees : Array InfoTree) (refs : Array Reference) : Ar
refs' := refs'.push ref
refs'
where
useConstRepresentatives (idMap : Std.HashMap RefIdent RefIdent)
: Std.HashMap RefIdent RefIdent := Id.run do
useConstRepresentatives (idMap : HashMap RefIdent RefIdent)
: HashMap RefIdent RefIdent := Id.run do
let insertIntoClass classesById id :=
let representative := findCanonicalRepresentative idMap id
let «class» := classesById.getD representative
let «class» := classesById.findD representative
let classesById := classesById.erase representative -- make `«class»` referentially unique
let «class» := «class».insert id
classesById.insert representative «class»
-- collect equivalence classes
let mut classesById : Std.HashMap RefIdent (Std.HashSet RefIdent) :=
let mut classesById : HashMap RefIdent (HashSet RefIdent) :=
for id, baseId in idMap.toArray do
classesById := insertIntoClass classesById id
classesById := insertIntoClass classesById baseId
@@ -310,17 +310,17 @@ where
r := r.insert id bestRepresentative
return r
findCanonicalRepresentative (idMap : Std.HashMap RefIdent RefIdent) (id : RefIdent) : RefIdent := Id.run do
findCanonicalRepresentative (idMap : HashMap RefIdent RefIdent) (id : RefIdent) : RefIdent := Id.run do
let mut canonicalRepresentative := id
while idMap.contains canonicalRepresentative do
canonicalRepresentative := idMap[canonicalRepresentative]!
canonicalRepresentative := idMap.find! canonicalRepresentative
return canonicalRepresentative
buildIdMap posMap := Id.run <| StateT.run' (s := Std.HashMap.empty) do
buildIdMap posMap := Id.run <| StateT.run' (s := HashMap.empty) do
-- map fvar defs to overlapping fvar defs/uses
for ref in refs do
let baseId := ref.ident
if let some id := posMap[ref.range]? then
if let some id := posMap.find? ref.range then
insertIdMap id baseId
-- apply `FVarAliasInfo`
@@ -346,11 +346,11 @@ are added to the `aliases` of the representative of the group.
Yields to separate groups for declaration and usages if `allowSimultaneousBinderUse` is set.
-/
def dedupReferences (refs : Array Reference) (allowSimultaneousBinderUse := false) : Array Reference := Id.run do
let mut refsByIdAndRange : Std.HashMap (RefIdent × Option Bool × Lsp.Range) Reference := Std.HashMap.empty
let mut refsByIdAndRange : HashMap (RefIdent × Option Bool × Lsp.Range) Reference := HashMap.empty
for ref in refs do
let isBinder := if allowSimultaneousBinderUse then some ref.isBinder else none
let key := (ref.ident, isBinder, ref.range)
refsByIdAndRange := match refsByIdAndRange[key]? with
refsByIdAndRange := match refsByIdAndRange[key] with
| some ref' => refsByIdAndRange.insert key { ref' with aliases := ref'.aliases ++ ref.aliases }
| none => refsByIdAndRange.insert key ref
@@ -371,21 +371,21 @@ def findModuleRefs (text : FileMap) (trees : Array InfoTree) (localVars : Bool :
refs := refs.filter fun
| { ident := RefIdent.fvar .., .. } => false
| _ => true
refs.foldl (init := Std.HashMap.empty) fun m ref => m.addRef ref
refs.foldl (init := HashMap.empty) fun m ref => m.addRef ref
/-! # Collecting and maintaining reference info from different sources -/
/-- References from ilean files and current ilean information from file workers. -/
structure References where
/-- References loaded from ilean files -/
ileans : Std.HashMap Name (System.FilePath × Lsp.ModuleRefs)
ileans : HashMap Name (System.FilePath × Lsp.ModuleRefs)
/-- References from workers, overriding the corresponding ilean files -/
workers : Std.HashMap Name (Nat × Lsp.ModuleRefs)
workers : HashMap Name (Nat × Lsp.ModuleRefs)
namespace References
/-- No ilean files, no information from workers. -/
def empty : References := { ileans := Std.HashMap.empty, workers := Std.HashMap.empty }
def empty : References := { ileans := HashMap.empty, workers := HashMap.empty }
/-- Adds the contents of an ilean file `ilean` at `path` to `self`. -/
def addIlean (self : References) (path : System.FilePath) (ilean : Ilean) : References :=
@@ -404,13 +404,13 @@ Replaces the current references with `refs` if `version` is newer than the curre
in `refs` and otherwise merges the reference data if `version` is equal to the current version.
-/
def updateWorkerRefs (self : References) (name : Name) (version : Nat) (refs : Lsp.ModuleRefs) : References := Id.run do
if let some (currVersion, _) := self.workers[name]? then
if let some (currVersion, _) := self.workers.find? name then
if version > currVersion then
return { self with workers := self.workers.insert name (version, refs) }
if version == currVersion then
let current := self.workers.getD name (version, Std.HashMap.empty)
let current := self.workers.findD name (version, HashMap.empty)
let merged := refs.fold (init := current.snd) fun m ident info =>
m.getD ident Lsp.RefInfo.empty |>.merge info |> m.insert ident
m.findD ident Lsp.RefInfo.empty |>.merge info |> m.insert ident
return { self with workers := self.workers.insert name (version, merged) }
return self
@@ -419,7 +419,7 @@ Replaces the worker references in `self` with the `refs` of the worker managing
if `version` is newer than the current version managed in `refs`.
-/
def finalizeWorkerRefs (self : References) (name : Name) (version : Nat) (refs : Lsp.ModuleRefs) : References := Id.run do
if let some (currVersion, _) := self.workers[name]? then
if let some (currVersion, _) := self.workers.find? name then
if version < currVersion then
return self
return { self with workers := self.workers.insert name (version, refs) }
@@ -429,8 +429,8 @@ def removeWorkerRefs (self : References) (name : Name) : References :=
{ self with workers := self.workers.erase name }
/-- Yields a map from all modules to all of their references. -/
def allRefs (self : References) : Std.HashMap Name Lsp.ModuleRefs :=
let ileanRefs := self.ileans.toArray.foldl (init := Std.HashMap.empty) fun m (name, _, refs) => m.insert name refs
def allRefs (self : References) : HashMap Name Lsp.ModuleRefs :=
let ileanRefs := self.ileans.toArray.foldl (init := HashMap.empty) fun m (name, _, refs) => m.insert name refs
self.workers.toArray.foldl (init := ileanRefs) fun m (name, _, refs) => m.insert name refs
/--
@@ -445,12 +445,12 @@ def allRefsFor
let refsToCheck := match ident with
| RefIdent.const .. => self.allRefs.toArray
| RefIdent.fvar identModule .. =>
match self.allRefs[identModule]? with
match self.allRefs.find? identModule with
| none => #[]
| some refs => #[(identModule, refs)]
let mut result := #[]
for (module, refs) in refsToCheck do
let some info := refs.get? ident
let some info := refs.find? ident
| continue
let some path srcSearchPath.findModuleWithExt "lean" module
| continue
@@ -462,13 +462,13 @@ def allRefsFor
/-- Yields all references in `module` at `pos`. -/
def findAt (self : References) (module : Name) (pos : Lsp.Position) (includeStop := false) : Array RefIdent := Id.run do
if let some refs := self.allRefs[module]? then
if let some refs := self.allRefs.find? module then
return refs.findAt pos includeStop
#[]
/-- Yields the first reference in `module` at `pos`. -/
def findRange? (self : References) (module : Name) (pos : Lsp.Position) (includeStop := false) : Option Range := do
let refs self.allRefs[module]?
let refs self.allRefs.find? module
refs.findRange? pos includeStop
/-- Location and parent declaration of a reference. -/

View File

@@ -90,10 +90,6 @@ section Utils
| crashed (e : IO.Error)
| ioError (e : IO.Error)
inductive CrashOrigin
| fileWorkerToClientForwarding
| clientToFileWorkerForwarding
inductive WorkerState where
/-- The watchdog can detect a crashed file worker in two places: When trying to send a message
to the file worker and when reading a request reply.
@@ -102,7 +98,7 @@ section Utils
that are in-flight are errored. Upon receiving the next packet for that file worker, the file
worker is restarted and the packet is forwarded to it. If the crash was detected while writing
a packet, we queue that packet until the next packet for the file worker arrives. -/
| crashed (queuedMsgs : Array JsonRpc.Message) (origin : CrashOrigin)
| crashed (queuedMsgs : Array JsonRpc.Message)
| running
abbrev PendingRequestMap := RBMap RequestID JsonRpc.Message compare
@@ -140,11 +136,6 @@ section FileWorker
for id, _ in pendingRequests do
hError.writeLspResponseError { id := id, code := code, message := msg }
def queuedMsgs (fw : FileWorker) : Array JsonRpc.Message :=
match fw.state with
| .running => #[]
| .crashed queuedMsgs _ => queuedMsgs
end FileWorker
end FileWorker
@@ -413,23 +404,10 @@ section ServerM
return
eraseFileWorker uri
def handleCrash (uri : DocumentUri) (queuedMsgs : Array JsonRpc.Message) (origin: CrashOrigin) : ServerM Unit := do
def handleCrash (uri : DocumentUri) (queuedMsgs : Array JsonRpc.Message) : ServerM Unit := do
let some fw findFileWorker? uri
| return
updateFileWorkers { fw with state := WorkerState.crashed queuedMsgs origin }
def tryDischargeQueuedMessages (uri : DocumentUri) (queuedMsgs : Array JsonRpc.Message) : ServerM Unit := do
let some fw findFileWorker? uri
| throwServerError "Cannot find file worker for '{uri}'."
let mut crashedMsgs := #[]
-- Try to discharge all queued msgs, tracking the ones that we can't discharge
for msg in queuedMsgs do
try
fw.stdin.writeLspMessage msg
catch _ =>
crashedMsgs := crashedMsgs.push msg
if ¬ crashedMsgs.isEmpty then
handleCrash uri crashedMsgs .clientToFileWorkerForwarding
updateFileWorkers { fw with state := WorkerState.crashed queuedMsgs }
/-- Tries to write a message, sets the state of the FileWorker to `crashed` if it does not succeed
and restarts the file worker if the `crashed` flag was already set. Just logs an error if
@@ -445,7 +423,7 @@ section ServerM
let some fw findFileWorker? uri
| return
match fw.state with
| WorkerState.crashed queuedMsgs _ =>
| WorkerState.crashed queuedMsgs =>
let mut queuedMsgs := queuedMsgs
if queueFailedMessage then
queuedMsgs := queuedMsgs.push msg
@@ -454,7 +432,17 @@ section ServerM
-- restart the crashed FileWorker
eraseFileWorker uri
startFileWorker fw.doc
tryDischargeQueuedMessages uri queuedMsgs
let some newFw findFileWorker? uri
| throwServerError "Cannot find file worker for '{uri}'."
let mut crashedMsgs := #[]
-- try to discharge all queued msgs, tracking the ones that we can't discharge
for msg in queuedMsgs do
try
newFw.stdin.writeLspMessage msg
catch _ =>
crashedMsgs := crashedMsgs.push msg
if ¬ crashedMsgs.isEmpty then
handleCrash uri crashedMsgs
| WorkerState.running =>
let initialQueuedMsgs :=
if queueFailedMessage then
@@ -464,7 +452,7 @@ section ServerM
try
fw.stdin.writeLspMessage msg
catch _ =>
handleCrash uri initialQueuedMsgs .clientToFileWorkerForwarding
handleCrash uri initialQueuedMsgs
/--
Sends a notification to the file worker identified by `uri` that its dependency `staleDependency`
@@ -650,7 +638,7 @@ def handleCallHierarchyOutgoingCalls (p : CallHierarchyOutgoingCallsParams)
let references ( read).references.get
let some refs := references.allRefs[module]?
let some refs := references.allRefs.find? module
| return #[]
let items refs.toArray.filterMapM fun ident, info => do
@@ -714,9 +702,9 @@ def handlePrepareRename (p : PrepareRenameParams) : ServerM (Option Range) := do
def handleRename (p : RenameParams) : ServerM Lsp.WorkspaceEdit := do
if (String.toName p.newName).isAnonymous then
throwServerError s!"Can't rename: `{p.newName}` is not an identifier"
let mut refs : Std.HashMap DocumentUri (RBMap Lsp.Position Lsp.Position compare) :=
let mut refs : HashMap DocumentUri (RBMap Lsp.Position Lsp.Position compare) :=
for { uri, range } in ( handleReference { p with context.includeDeclaration := true }) do
refs := refs.insert uri <| (refs.getD uri ).insert range.start range.end
refs := refs.insert uri <| (refs.findD uri ).insert range.start range.end
-- We have to filter the list of changes to put the ranges in order and
-- remove any duplicates or overlapping ranges, or else the rename will not apply
let changes := refs.fold (init := ) fun changes uri map => Id.run do
@@ -967,16 +955,7 @@ section MainLoop
let workers st.fileWorkersRef.get
let mut workerTasks := #[]
for (_, fw) in workers do
-- When the forwarding task crashes, its return value will be stuck at
-- `WorkerEvent.crashed _`.
-- We want to handle this event only once, not over and over again,
-- so once the state becomes `WorkerState.crashed _ .fileWorkerToClientForwarding`
-- as a result of `WorkerEvent.crashed _`, we stop handling this event until
-- eventually the file worker is restarted by a notification from the client.
-- We do not want to filter the forwarding task in case of
-- `WorkerState.crashed _ .clientToFileWorkerForwarding`, since the forwarding task
-- exit code may still contain valuable information in this case (e.g. that the imports changed).
if !(fw.state matches WorkerState.crashed _ .fileWorkerToClientForwarding) then
if let WorkerState.running := fw.state then
workerTasks := workerTasks.push <| fw.commTask.map (ServerEvent.workerEvent fw)
let ev IO.waitAny (clientTask :: workerTasks.toList)
@@ -1005,16 +984,13 @@ section MainLoop
| WorkerEvent.ioError e =>
throwServerError s!"IO error while processing events for {fw.doc.uri}: {e}"
| WorkerEvent.crashed _ =>
handleCrash fw.doc.uri fw.queuedMsgs .fileWorkerToClientForwarding
handleCrash fw.doc.uri #[]
mainLoop clientTask
| WorkerEvent.terminated =>
throwServerError <| "Internal server error: got termination event for worker that "
++ "should have been removed"
| .importsChanged =>
let uri := fw.doc.uri
let queuedMsgs := fw.queuedMsgs
startFileWorker fw.doc
tryDischargeQueuedMessages uri queuedMsgs
mainLoop clientTask
end MainLoop

View File

@@ -6,7 +6,6 @@ Authors: David Thrane Christiansen
prelude
import Init.Data
import Lean.Data.HashMap
import Std.Data.HashMap.Basic
import Init.Omega
namespace Lean.Diff
@@ -58,7 +57,7 @@ structure Histogram.Entry (α : Type u) (lsize rsize : Nat) where
/-- A histogram for arrays maps each element to a count and, if applicable, an index.-/
def Histogram (α : Type u) (lsize rsize : Nat) [BEq α] [Hashable α] :=
Std.HashMap α (Histogram.Entry α lsize rsize)
Lean.HashMap α (Histogram.Entry α lsize rsize)
section
@@ -68,7 +67,7 @@ variable [BEq α] [Hashable α]
/-- Add an element from the left array to a histogram -/
def Histogram.addLeft (histogram : Histogram α lsize rsize) (index : Fin lsize) (val : α)
: Histogram α lsize rsize :=
match histogram.get? val with
match histogram.find? val with
| none => histogram.insert val {
leftCount := 1, leftIndex := some index,
leftWF := by simp,
@@ -82,7 +81,7 @@ def Histogram.addLeft (histogram : Histogram α lsize rsize) (index : Fin lsize)
/-- Add an element from the right array to a histogram -/
def Histogram.addRight (histogram : Histogram α lsize rsize) (index : Fin rsize) (val : α)
: Histogram α lsize rsize :=
match histogram.get? val with
match histogram.find? val with
| none => histogram.insert val {
leftCount := 0, leftIndex := none,
leftWF := by simp,

View File

@@ -28,7 +28,7 @@ structure State where
Set of visited subterms that satisfy the predicate `p`.
We have to use this set to make sure `f` is applied at most once of each subterm that satisfies `p`.
-/
checked : Std.HashSet Expr
checked : HashSet Expr
unsafe def initCache : State := {
visited := mkArray cacheSize.toNat (cast lcProof ())

View File

@@ -5,15 +5,14 @@ Authors: Leonardo de Moura
-/
prelude
import Lean.Expr
import Std.Data.HashMap.Raw
namespace Lean
structure HasConstCache (declNames : Array Name) where
cache : Std.HashMap.Raw Expr Bool := Std.HashMap.Raw.empty
cache : HashMapImp Expr Bool := mkHashMapImp
unsafe def HasConstCache.containsUnsafe (e : Expr) : StateM (HasConstCache declNames) Bool := do
if let some r := ( get).cache.get? (beq := ptrEq) e then
if let some r := ( get).cache.find? (beq := ptrEq) e then
return r
else
match e with
@@ -27,7 +26,7 @@ unsafe def HasConstCache.containsUnsafe (e : Expr) : StateM (HasConstCache declN
| _ => return false
where
cache (e : Expr) (r : Bool) : StateM (HasConstCache declNames) Bool := do
modify fun cache => cache.insert (beq := ptrEq) e r
modify fun cache => cache.insert (beq := ptrEq) e r |>.1
return r
/--

View File

@@ -6,7 +6,6 @@ Authors: Leonardo de Moura
prelude
import Init.Control.StateRef
import Lean.Data.HashMap
import Std.Data.HashMap.Basic
namespace Lean
/-- Interface for caching results. -/
@@ -37,15 +36,15 @@ instance {α β ε : Type} {m : Type → Type} [MonadCache α β m] [Monad m] :
/-- Adapter for implementing `MonadCache` interface using `HashMap`s.
We just have to specify how to extract/modify the `HashMap`. -/
class MonadHashMapCacheAdapter (α β : Type) (m : Type Type) [BEq α] [Hashable α] where
getCache : m (Std.HashMap α β)
modifyCache : (Std.HashMap α β Std.HashMap α β) m Unit
getCache : m (HashMap α β)
modifyCache : (HashMap α β HashMap α β) m Unit
namespace MonadHashMapCacheAdapter
@[always_inline, inline]
def findCached? {α β : Type} {m : Type Type} [BEq α] [Hashable α] [Monad m] [MonadHashMapCacheAdapter α β m] (a : α) : m (Option β) := do
let c getCache
pure (c.get? a)
pure (c.find? a)
@[always_inline, inline]
def cache {α β : Type} {m : Type Type} [BEq α] [Hashable α] [MonadHashMapCacheAdapter α β m] (a : α) (b : β) : m Unit :=
@@ -57,7 +56,7 @@ instance {α β : Type} {m : Type → Type} [BEq α] [Hashable α] [Monad m] [Mo
end MonadHashMapCacheAdapter
def MonadCacheT {ω} (α β : Type) (m : Type Type) [STWorld ω m] [BEq α] [Hashable α] := StateRefT (Std.HashMap α β) m
def MonadCacheT {ω} (α β : Type) (m : Type Type) [STWorld ω m] [BEq α] [Hashable α] := StateRefT (HashMap α β) m
namespace MonadCacheT
@@ -68,7 +67,7 @@ instance : MonadHashMapCacheAdapter α β (MonadCacheT α β m) where
modifyCache f := (modify f : StateRefT' ..)
@[inline] def run {σ} (x : MonadCacheT α β m σ) : m σ :=
x.run' Std.HashMap.empty
x.run' mkHashMap
instance : Monad (MonadCacheT α β m) := inferInstanceAs (Monad (StateRefT' _ _ _))
instance : MonadLift m (MonadCacheT α β m) := inferInstanceAs (MonadLift m (StateRefT' _ _ _))
@@ -81,7 +80,7 @@ instance [Alternative m] : Alternative (MonadCacheT α β m) := inferInstanceAs
end MonadCacheT
/- Similar to `MonadCacheT`, but using `StateT` instead of `StateRefT` -/
def MonadStateCacheT (α β : Type) (m : Type Type) [BEq α] [Hashable α] := StateT (Std.HashMap α β) m
def MonadStateCacheT (α β : Type) (m : Type Type) [BEq α] [Hashable α] := StateT (HashMap α β) m
namespace MonadStateCacheT
@@ -92,7 +91,7 @@ instance : MonadHashMapCacheAdapter α β (MonadStateCacheT α β m) where
modifyCache f := (modify f : StateT ..)
@[always_inline, inline] def run {σ} (x : MonadStateCacheT α β m σ) : m σ :=
x.run' Std.HashMap.empty
x.run' mkHashMap
instance : Monad (MonadStateCacheT α β m) := inferInstanceAs (Monad (StateT _ _))
instance : MonadLift m (MonadStateCacheT α β m) := inferInstanceAs (MonadLift m (StateT _ _))

View File

@@ -96,9 +96,9 @@ deriving FromJson, ToJson
/-- Thread with maps necessary for computing max sharing indices -/
structure ThreadWithMaps extends Thread where
stringMap : Std.HashMap String Nat := {}
funcMap : Std.HashMap Nat Nat := {}
stackMap : Std.HashMap (Nat × Option Nat) Nat := {}
stringMap : HashMap String Nat := {}
funcMap : HashMap Nat Nat := {}
stackMap : HashMap (Nat × Option Nat) Nat := {}
/-- Last timestamp encountered: stop time of preceding sibling, or else start time of parent. -/
lastTime : Float := 0
@@ -123,7 +123,7 @@ where
if pp then
funcName := s!"{funcName}: {← msg.format}"
let strIdx modifyGet fun thread =>
if let some idx := thread.stringMap[funcName]? then
if let some idx := thread.stringMap.find? funcName then
(idx, thread)
else
(thread.stringMap.size, { thread with
@@ -131,7 +131,7 @@ where
stringMap := thread.stringMap.insert funcName thread.stringMap.size })
let category := categories.findIdx? (·.name == data.cls.getRoot.toString) |>.getD 0
let funcIdx modifyGet fun thread =>
if let some idx := thread.funcMap[strIdx]? then
if let some idx := thread.funcMap.find? strIdx then
(idx, thread)
else
(thread.funcMap.size, { thread with
@@ -151,7 +151,7 @@ where
funcMap := thread.funcMap.insert strIdx thread.funcMap.size })
let frameIdx := funcIdx
let stackIdx modifyGet fun thread =>
if let some idx := thread.stackMap[(frameIdx, parentStackIdx?)]? then
if let some idx := thread.stackMap.find? (frameIdx, parentStackIdx?) then
(idx, thread)
else
(thread.stackMap.size, { thread with
@@ -222,7 +222,7 @@ def Profile.export (name : String) (startTime : Milliseconds) (traceState : Trac
structure ThreadWithCollideMaps extends ThreadWithMaps where
/-- Max sharing map for samples -/
sampleMap : Std.HashMap Nat Nat := {}
sampleMap : HashMap Nat Nat := {}
/--
Adds samples from `add` to `thread`, increasing the weight of existing samples with identical stacks
@@ -237,7 +237,7 @@ where
let oldStackIdx := add.samples.stack[oldSampleIdx]!
let stackIdx collideStacks oldStackIdx
modify fun thread =>
if let some idx := thread.sampleMap[stackIdx]? then
if let some idx := thread.sampleMap.find? stackIdx then
-- imperative to preserve linear use of arrays here!
let t1, t2, t3, samples, t5, t6, t7, t8, t9, t10, o2, o3, o4, o5, o6 := thread
let s1, s2, weight, s3, s4 := samples
@@ -265,7 +265,7 @@ where
let oldStrIdx := add.funcTable.name[oldFuncIdx]!
let strIdx getStrIdx add.stringArray[oldStrIdx]!
let funcIdx modifyGet fun thread =>
if let some idx := thread.funcMap[strIdx]? then
if let some idx := thread.funcMap.find? strIdx then
(idx, thread)
else
(thread.funcMap.size, { thread with
@@ -284,7 +284,7 @@ where
funcMap := thread.funcMap.insert strIdx thread.funcMap.size })
let frameIdx := funcIdx
modifyGet fun thread =>
if let some idx := thread.stackMap[(frameIdx, parentStackIdx?)]? then
if let some idx := thread.stackMap.find? (frameIdx, parentStackIdx?) then
(idx, thread)
else
(thread.stackMap.size,
@@ -302,7 +302,7 @@ where
t1,t2, t3, t4, t5, stackTable, t7, t8, t9, t10, o2, o3, stackMap, o5, o6)
getStrIdx (s : String) :=
modifyGet fun thread =>
if let some idx := thread.stringMap[s]? then
if let some idx := thread.stringMap.find? s then
(idx, thread)
else
(thread.stringMap.size, { thread with

View File

@@ -7,8 +7,6 @@ prelude
import Init.Data.Hashable
import Lean.Data.HashSet
import Lean.Data.HashMap
import Std.Data.HashSet.Basic
import Std.Data.HashMap.Basic
namespace Lean
@@ -25,33 +23,33 @@ unsafe instance : BEq (Ptr α) where
Set of pointers. It is a low-level auxiliary datastructure used for traversing DAGs.
-/
unsafe def PtrSet (α : Type) :=
Std.HashSet (Ptr α)
HashSet (Ptr α)
unsafe def mkPtrSet {α : Type} (capacity : Nat := 64) : PtrSet α :=
Std.HashSet.empty capacity
mkHashSet capacity
unsafe abbrev PtrSet.insert (s : PtrSet α) (a : α) : PtrSet α :=
Std.HashSet.insert s { value := a }
HashSet.insert s { value := a }
unsafe abbrev PtrSet.contains (s : PtrSet α) (a : α) : Bool :=
Std.HashSet.contains s { value := a }
HashSet.contains s { value := a }
/--
Map of pointers. It is a low-level auxiliary datastructure used for traversing DAGs.
-/
unsafe def PtrMap (α : Type) (β : Type) :=
Std.HashMap (Ptr α) β
HashMap (Ptr α) β
unsafe def mkPtrMap {α β : Type} (capacity : Nat := 64) : PtrMap α β :=
Std.HashMap.empty capacity
mkHashMap capacity
unsafe abbrev PtrMap.insert (s : PtrMap α β) (a : α) (b : β) : PtrMap α β :=
Std.HashMap.insert s { value := a } b
HashMap.insert s { value := a } b
unsafe abbrev PtrMap.contains (s : PtrMap α β) (a : α) : Bool :=
Std.HashMap.contains s { value := a }
HashMap.contains s { value := a }
unsafe abbrev PtrMap.find? (s : PtrMap α β) (a : α) : Option β :=
Std.HashMap.get? s { value := a }
HashMap.find? s { value := a }
end Lean

View File

@@ -6,7 +6,6 @@ Authors: Leonardo de Moura
prelude
import Init.Data.List.Control
import Lean.Data.HashMap
import Std.Data.HashMap.Basic
namespace Lean.SCC
/-!
Very simple implementation of Tarjan's SCC algorithm.
@@ -26,7 +25,7 @@ structure Data where
structure State where
stack : List α := []
nextIndex : Nat := 0
data : Std.HashMap α Data := {}
data : HashMap α Data := {}
sccs : List (List α) := []
abbrev M := StateM (State α)
@@ -36,7 +35,7 @@ variable {α : Type} [BEq α] [Hashable α]
private def getDataOf (a : α) : M α Data := do
let s get
match s.data[a]? with
match s.data.find? a with
| some d => pure d
| none => pure {}
@@ -53,7 +52,7 @@ private def push (a : α) : M α Unit :=
private def modifyDataOf (a : α) (f : Data Data) : M α Unit :=
modify fun s => { s with
data := match s.data[a]? with
data := match s.data.find? a with
| none => s.data
| some d => s.data.insert a (f d)
}

View File

@@ -13,7 +13,6 @@ import Lean.Data.PersistentHashSet
open ShareCommon
namespace Lean.ShareCommon
set_option linter.deprecated false in
def objectFactory :=
StateFactory.mk {
Map := HashMap, mkMap := (mkHashMap ·), mapFind? := (·.find?), mapInsert := (·.insert)

View File

@@ -67,7 +67,7 @@ structure TraceState where
traces : PersistentArray TraceElem := {}
deriving Inhabited
builtin_initialize inheritedTraceOptions : IO.Ref (Std.HashSet Name) IO.mkRef
builtin_initialize inheritedTraceOptions : IO.Ref (HashSet Name) IO.mkRef
class MonadTrace (m : Type Type) where
modifyTraceState : (TraceState TraceState) m Unit
@@ -88,7 +88,7 @@ def printTraces : m Unit := do
def resetTraceState : m Unit :=
modifyTraceState (fun _ => {})
private def checkTraceOption (inherited : Std.HashSet Name) (opts : Options) (cls : Name) : Bool :=
private def checkTraceOption (inherited : HashSet Name) (opts : Options) (cls : Name) : Bool :=
!opts.isEmpty && go (`trace ++ cls)
where
go (opt : Name) : Bool :=

View File

@@ -5,4 +5,3 @@ Authors: Sebastian Ullrich
-/
prelude
import Std.Data
import Std.Sat

View File

@@ -20,7 +20,7 @@ open Std.DHashMap.Internal.List
universe u v
variable {α : Type u} {β : α Type v}
variable {α : Type u} {β : α Type v} [BEq α] [Hashable α]
namespace Std.DHashMap.Internal
@@ -41,8 +41,6 @@ theorem Raw.buckets_emptyc {i : Nat} {h} :
( : Raw α β).buckets[i]'h = AssocList.nil :=
buckets_empty
variable [BEq α] [Hashable α]
@[simp]
theorem buckets_empty {c} {i : Nat} {h} :
(empty c : DHashMap α β).1.buckets[i]'h = AssocList.nil := by
@@ -57,9 +55,7 @@ end empty
namespace Raw₀
variable [BEq α] [Hashable α]
variable (m : Raw₀ α β) (h : m.1.WF)
set_option deprecated.oldSectionVars true
/-- Internal implementation detail of the hash map -/
scoped macro "wf_trivial" : tactic => `(tactic|

View File

@@ -75,7 +75,6 @@ namespace Raw
open Internal.Raw₀ Internal.Raw
variable {m : Raw α β} (h : m.WF)
set_option deprecated.oldSectionVars true
@[simp]
theorem isEmpty_empty {c} : (empty c : Raw α β).isEmpty := by

View File

@@ -112,10 +112,6 @@ Tries to retrieve the mapping for the given key, returning `none` if no such map
@[inline] def get? (m : HashMap α β) (a : α) : Option β :=
DHashMap.Const.get? m.inner a
@[deprecated get? "Use `m[a]?` or `m.get? a` instead", inherit_doc get?]
def find? (m : HashMap α β) (a : α) : Option β :=
m.get? a
@[inline, inherit_doc DHashMap.contains] def contains (m : HashMap α β)
(a : α) : Bool :=
m.inner.contains a
@@ -139,10 +135,6 @@ Retrieves the mapping for the given key. Ensures that such a mapping exists by r
(fallback : β) : β :=
DHashMap.Const.getD m.inner a fallback
@[deprecated getD, inherit_doc getD]
def findD (m : HashMap α β) (a : α) (fallback : β) : β :=
m.getD a fallback
/--
The notation `m[a]!` is preferred over calling this function directly.
@@ -151,10 +143,6 @@ Tries to retrieve the mapping for the given key, panicking if no such mapping is
@[inline] def get! [Inhabited β] (m : HashMap α β) (a : α) : β :=
DHashMap.Const.get! m.inner a
@[deprecated get! "Use `m[a]!` or `m.get! a` instead", inherit_doc get!]
def find! [Inhabited β] (m : HashMap α β) (a : α) : Option β :=
m.get! a
instance [BEq α] [Hashable α] : GetElem? (HashMap α β) α β (fun m a => a m) where
getElem m a h := m.get a h
getElem? m a := m.get? a
@@ -248,16 +236,3 @@ instance [BEq α] [Hashable α] [Repr α] [Repr β] : Repr (HashMap α β) where
end Unverified
end Std.HashMap
/--
Groups all elements `x`, `y` in `xs` with `key x == key y` into the same array
`(xs.groupByKey key).find! (key x)`. Groups preserve the relative order of elements in `xs`.
-/
def Array.groupByKey [BEq α] [Hashable α] (key : β α) (xs : Array β)
: Std.HashMap α (Array β) := Id.run do
let mut groups :=
for x in xs do
let group := groups.getD (key x) #[]
groups := groups.erase (key x) -- make `group` referentially unique
groups := groups.insert (key x) (group.push x)
return groups

Some files were not shown because too many files have changed in this diff Show More