mirror of https://github.com/leanprover/lean4.git
synced 2026-03-26 14:54:15 +00:00

Compare commits: expr_eq_ex...expr_eq_pe (4 commits)

| SHA1 |
|---|
| 8f6ba0e107 |
| 3439180a3c |
| 01322069e3 |
| c9d9baa3c8 |
@@ -5,7 +5,7 @@ Some notes on how to debug Lean, which may also be applicable to debugging Lean

## Tracing

In `CoreM` and derived monads, we use `trace[traceCls] "msg with {interpolations}"` to fill the structured trace viewable with `set_option trace.traceCls true`.
In `CoreM` and derived monads, we use `trace![traceCls] "msg with {interpolations}"` to fill the structured trace viewable with `set_option trace.traceCls true`.
New trace classes have to be registered using `registerTraceClass` first.

Notable trace classes:
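As a minimal sketch of the registration-then-use pattern described above (`myDebug` is a hypothetical trace class name, not one that exists in the repository):

```lean
import Lean
open Lean

-- New trace classes must be registered before use.
initialize registerTraceClass `myDebug

def foo : CoreM Unit := do
  -- Fills the structured trace; shown when `set_option trace.myDebug true` is active.
  trace[myDebug] "interesting value: {2 + 2}"
```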
@@ -22,9 +22,7 @@ Notable trace classes:

In pure contexts or when execution is aborted before the messages are finally printed, one can instead use the term `dbg_trace "msg with {interpolations}"; val` (`;` can also be replaced by a newline), which will print the message to stderr before evaluating `val`. `dbgTraceVal val` can be used as a shorthand for `dbg_trace "{val}"; val`.
Note that if the return value is not actually used, the trace code is silently dropped as well.

By default, such stderr output is buffered and shown as messages after a command has been elaborated, which is necessary to ensure deterministic ordering of messages under parallelism.
If Lean aborts the process before it can finish the command or takes too long to do that, using `-DstderrAsMessages=false` avoids this buffering and shows `dbg_trace` output (but not `trace`s or other diagnostics) immediately.
In the language server, stderr output is buffered and shown as messages after a command has been elaborated, unless the option `server.stderrAsMessages` is deactivated.

## Debuggers
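A small sketch of `dbg_trace` and `dbgTraceVal` in a pure function (hypothetical example, not from the repository):

```lean
-- `dbg_trace` prints to stderr when (and only when) the value is actually evaluated.
def sum3 (xs : List Nat) : Nat :=
  dbg_trace "input: {xs}"
  xs.foldl (· + ·) 0

-- `dbgTraceVal` additionally prints the result itself before returning it.
#eval dbgTraceVal (sum3 [1, 2, 3])
```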
@@ -152,26 +152,22 @@ We'll use `v4.7.0-rc1` as the intended release version in this example.

  This will add a list of all the commits since the last stable version.
- Delete "update stage0" commits, and anything with a completely inscrutable commit message.
- Next, we will move a curated list of downstream repos to the release candidate.
  - This assumes that for each repository either:
    * There is already a *reviewed* branch `bump/v4.7.0` containing the required adaptations.
      The preparation of this branch is beyond the scope of this document.
    * The repository does not need any changes to move to the new version.
  - This assumes that there is already a *reviewed* branch `bump/v4.7.0` on each repository
    containing the required adaptations (or no adaptations are required).
    The preparation of this branch is beyond the scope of this document.
  - For each of the target repositories:
    - If the repository does not need any changes (i.e. `bump/v4.7.0` does not exist) then create
      a new PR updating `lean-toolchain` to `leanprover/lean4:v4.7.0-rc1` and running `lake update`.
    - Otherwise:
      - Checkout the `bump/v4.7.0` branch.
      - Verify that the `lean-toolchain` is set to the nightly from which the release candidate was created.
      - `git merge origin/master`
      - Change the `lean-toolchain` to `leanprover/lean4:v4.7.0-rc1`
      - In `lakefile.lean`, change any dependencies which were using `nightly-testing` or `bump/v4.7.0` branches
        back to `master` or `main`, and run `lake update` for those dependencies.
      - Run `lake build` to ensure that dependencies are found (but it's okay to stop it after a moment).
      - `git commit`
      - `git push`
      - Open a PR from `bump/v4.7.0` to `master`, and either merge it yourself after CI, if appropriate,
        or notify the maintainers that it is ready to go.
      - Once the PR has been merged, tag `master` with `v4.7.0-rc1` and push this tag.
    - Checkout the `bump/v4.7.0` branch.
    - Verify that the `lean-toolchain` is set to the nightly from which the release candidate was created.
    - `git merge origin/master`
    - Change the `lean-toolchain` to `leanprover/lean4:v4.7.0-rc1`
    - In `lakefile.lean`, change any dependencies which were using `nightly-testing` or `bump/v4.7.0` branches
      back to `master` or `main`, and run `lake update` for those dependencies.
    - Run `lake build` to ensure that dependencies are found (but it's okay to stop it after a moment).
    - `git commit`
    - `git push`
    - Open a PR from `bump/v4.7.0` to `master`, and either merge it yourself after CI, if appropriate,
      or notify the maintainers that it is ready to go.
    - Once this PR has been merged, tag `master` with `v4.7.0-rc1` and push this tag.
- We do this for the same list of repositories as for stable releases, see above.
  As above, there are dependencies between these, and so the process above is iterative.
  It greatly helps if you can merge the `bump/v4.7.0` PRs yourself!
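For a single repository that does have a `bump/v4.7.0` branch, the per-repo steps above amount to roughly the following command sequence (illustrative outline only; the `lakefile.lean` edit is manual, and the commit message is a placeholder):

```
git checkout bump/v4.7.0
git merge origin/master
echo leanprover/lean4:v4.7.0-rc1 > lean-toolchain
# manually edit lakefile.lean: point nightly-testing/bump/v4.7.0 dependencies back to master/main
lake update
lake build        # okay to interrupt once dependencies are found
git commit -am "chore: move to v4.7.0-rc1"
git push
```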
@@ -27,9 +27,9 @@ Setting up a basic parallelized release build:

git clone https://github.com/leanprover/lean4
cd lean4
cmake --preset release
make -C build/release -j$(nproc || sysctl -n hw.logicalcpu)
make -C build/release -j$(nproc) # see below for macOS
```
You can replace `$(nproc || sysctl -n hw.logicalcpu)` with the desired parallelism amount.
You can replace `$(nproc)`, which is not available on macOS and some alternative shells, with the desired parallelism amount.

The above commands will compile the Lean library and binaries into the
`stage1` subfolder; see below for details.
@@ -7,7 +7,6 @@ prelude

import Init.Data.Nat.MinMax
import Init.Data.Nat.Lemmas
import Init.Data.List.Monadic
import Init.Data.List.Nat.Range
import Init.Data.Fin.Basic
import Init.Data.Array.Mem
import Init.TacticsExtra
@@ -337,10 +336,6 @@ theorem not_mem_nil (a : α) : ¬ a ∈ #[] := nofun

/-- # get lemmas -/

theorem lt_of_getElem {x : α} {a : Array α} {idx : Nat} {hidx : idx < a.size} (_ : a[idx] = x) :
    idx < a.size :=
  hidx

theorem getElem?_mem {l : Array α} {i : Fin l.size} : l[i] ∈ l := by
  erw [Array.mem_def, getElem_eq_data_getElem]
  apply List.get_mem
@@ -510,13 +505,6 @@ theorem size_eq_length_data (as : Array α) : as.size = as.data.length := rfl

  simp only [mkEmpty_eq, size_push] at *
  omega

@[simp] theorem data_range (n : Nat) : (range n).data = List.range n := by
  induction n <;> simp_all [range, Nat.fold, flip, List.range_succ]

@[simp]
theorem getElem_range {n : Nat} {x : Nat} (h : x < (Array.range n).size) : (Array.range n)[x] = x := by
  simp [getElem_eq_data_getElem]

set_option linter.deprecated false in
@[simp] theorem reverse_data (a : Array α) : a.reverse.data = a.data.reverse := by
  let rec go (as : Array α) (i j hj)
@@ -719,22 +707,13 @@ theorem mapIdx_spec (as : Array α) (f : Fin as.size → α → β)

  unfold modify modifyM Id.run
  split <;> simp

theorem getElem_modify {as : Array α} {x i} (h : i < as.size) :
    (as.modify x f)[i]'(by simp [h]) = if x = i then f as[i] else as[i] := by
  simp only [modify, modifyM, get_eq_getElem, Id.run, Id.pure_eq]
  split
  · simp only [Id.bind_eq, get_set _ _ _ h]; split <;> simp [*]
theorem get_modify {arr : Array α} {x i} (h : i < arr.size) :
    (arr.modify x f).get ⟨i, by simp [h]⟩ =
      if x = i then f (arr.get ⟨i, h⟩) else arr.get ⟨i, h⟩ := by
  simp [modify, modifyM, Id.run]; split
  · simp [get_set _ _ _ h]; split <;> simp [*]
  · rw [if_neg (mt (by rintro rfl; exact h) ‹_›)]

theorem getElem_modify_self {as : Array α} {i : Nat} (h : i < as.size) (f : α → α) :
    (as.modify i f)[i]'(by simp [h]) = f as[i] := by
  simp [getElem_modify h]

theorem getElem_modify_of_ne {as : Array α} {i : Nat} (hj : j < as.size)
    (f : α → α) (h : i ≠ j) :
    (as.modify i f)[j]'(by rwa [size_modify]) = as[j] := by
  simp [getElem_modify hj, h]
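Computationally, the `modify` lemmas in this hunk say that exactly one index is updated and every other entry is unchanged; a quick sketch (hypothetical values):

```lean
-- Index 1 gets `f` applied (getElem_modify_self); indices 0 and 2 are untouched
-- (getElem_modify_of_ne).
#eval #[10, 20, 30].modify 1 (· + 5)   -- #[10, 25, 30]
```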
/-! ### filter -/

@[simp] theorem filter_data (p : α → Bool) (l : Array α) :
@@ -42,8 +42,8 @@ Bitvectors have decidable equality. This should be used via the instance `Decida

-- We manually derive the `DecidableEq` instances for `BitVec` because
-- we want to have builtin support for bit-vector literals, and we
-- need a name for this function to implement `canUnfoldAtMatcher` at `WHNF.lean`.
def BitVec.decEq (x y : BitVec n) : Decidable (x = y) :=
  match x, y with
def BitVec.decEq (a b : BitVec n) : Decidable (a = b) :=
  match a, b with
  | ⟨n⟩, ⟨m⟩ =>
    if h : n = m then
      isTrue (h ▸ rfl)
@@ -69,9 +69,9 @@ protected def ofNat (n : Nat) (i : Nat) : BitVec n where

instance instOfNat : OfNat (BitVec n) i where ofNat := .ofNat n i
instance natCastInst : NatCast (BitVec w) := ⟨BitVec.ofNat w⟩

/-- Given a bitvector `x`, return the underlying `Nat`. This is O(1) because `BitVec` is a
/-- Given a bitvector `a`, return the underlying `Nat`. This is O(1) because `BitVec` is a
(zero-cost) wrapper around a `Nat`. -/
protected def toNat (x : BitVec n) : Nat := x.toFin.val
protected def toNat (a : BitVec n) : Nat := a.toFin.val

/-- Return the bound in terms of toNat. -/
theorem isLt (x : BitVec w) : x.toNat < 2^w := x.toFin.isLt
@@ -123,18 +123,18 @@ section getXsb

@[inline] def getMsb (x : BitVec w) (i : Nat) : Bool := i < w && getLsb x (w-1-i)

/-- Return most-significant bit in bitvector. -/
@[inline] protected def msb (x : BitVec n) : Bool := getMsb x 0
@[inline] protected def msb (a : BitVec n) : Bool := getMsb a 0

end getXsb

section Int

/-- Interpret the bitvector as an integer stored in two's complement form. -/
protected def toInt (x : BitVec n) : Int :=
  if 2 * x.toNat < 2^n then
    x.toNat
protected def toInt (a : BitVec n) : Int :=
  if 2 * a.toNat < 2^n then
    a.toNat
  else
    (x.toNat : Int) - (2^n : Nat)
    (a.toNat : Int) - (2^n : Nat)

/-- The `BitVec` with value `(2^n + (i mod 2^n)) mod 2^n`. -/
protected def ofInt (n : Nat) (i : Int) : BitVec n := .ofNatLt (i % (Int.ofNat (2^n))).toNat (by
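A quick sanity check of the two's complement reading of `toInt`/`ofInt` above, at width 4 (a sketch assuming these definitions are in scope):

```lean
-- 2 * 7 = 14 < 16, so the value is read as non-negative.
#eval (7#4).toInt                  -- 7
-- 2 * 9 = 18 ≥ 16, so the value is read as 9 - 16.
#eval (9#4).toInt                  -- -7
-- ofInt wraps modulo 2^n: (-7) mod 16 = 9.
#eval (BitVec.ofInt 4 (-7)).toNat  -- 9
```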
@@ -215,7 +215,7 @@ instance : Neg (BitVec n) := ⟨.neg⟩

/--
Return the absolute value of a signed bitvector.
-/
protected def abs (x : BitVec n) : BitVec n := if x.msb then .neg x else x
protected def abs (s : BitVec n) : BitVec n := if s.msb then .neg s else s

/--
Multiplication for bit vectors. This can be interpreted as either signed or unsigned multiplication
@@ -262,12 +262,12 @@ sdiv 5#4 -2 = -2#4

sdiv (-7#4) (-2) = 3#4
```
-/
def sdiv (x y : BitVec n) : BitVec n :=
  match x.msb, y.msb with
  | false, false => udiv x y
  | false, true  => .neg (udiv x (.neg y))
  | true,  false => .neg (udiv (.neg x) y)
  | true,  true  => udiv (.neg x) (.neg y)
def sdiv (s t : BitVec n) : BitVec n :=
  match s.msb, t.msb with
  | false, false => udiv s t
  | false, true  => .neg (udiv s (.neg t))
  | true,  false => .neg (udiv (.neg s) t)
  | true,  true  => udiv (.neg s) (.neg t)

/--
Signed division for bit vectors using SMTLIB rules for division by zero.
@@ -276,40 +276,40 @@ Specifically, `smtSDiv x 0 = if x >= 0 then -1 else 1`

SMT-Lib name: `bvsdiv`.
-/
def smtSDiv (x y : BitVec n) : BitVec n :=
  match x.msb, y.msb with
  | false, false => smtUDiv x y
  | false, true  => .neg (smtUDiv x (.neg y))
  | true,  false => .neg (smtUDiv (.neg x) y)
  | true,  true  => smtUDiv (.neg x) (.neg y)
def smtSDiv (s t : BitVec n) : BitVec n :=
  match s.msb, t.msb with
  | false, false => smtUDiv s t
  | false, true  => .neg (smtUDiv s (.neg t))
  | true,  false => .neg (smtUDiv (.neg s) t)
  | true,  true  => smtUDiv (.neg s) (.neg t)

/--
Remainder for signed division rounding to zero.

SMT-Lib name: `bvsrem`.
-/
def srem (x y : BitVec n) : BitVec n :=
  match x.msb, y.msb with
  | false, false => umod x y
  | false, true  => umod x (.neg y)
  | true,  false => .neg (umod (.neg x) y)
  | true,  true  => .neg (umod (.neg x) (.neg y))
def srem (s t : BitVec n) : BitVec n :=
  match s.msb, t.msb with
  | false, false => umod s t
  | false, true  => umod s (.neg t)
  | true,  false => .neg (umod (.neg s) t)
  | true,  true  => .neg (umod (.neg s) (.neg t))

/--
Remainder for signed division rounded to negative infinity.

SMT-Lib name: `bvsmod`.
-/
def smod (x y : BitVec m) : BitVec m :=
  match x.msb, y.msb with
  | false, false => umod x y
def smod (s t : BitVec m) : BitVec m :=
  match s.msb, t.msb with
  | false, false => umod s t
  | false, true =>
    let u := umod x (.neg y)
    (if u = .zero m then u else .add u y)
    let u := umod s (.neg t)
    (if u = .zero m then u else .add u t)
  | true, false =>
    let u := umod (.neg x) y
    (if u = .zero m then u else .sub y u)
  | true, true => .neg (umod (.neg x) (.neg y))
    let u := umod (.neg s) t
    (if u = .zero m then u else .sub t u)
  | true, true => .neg (umod (.neg s) (.neg t))

end arithmetic
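The three signed operations above differ only in how signs are handled; a sketch at width 4 (assuming the definitions from this file are in scope; `-7#4` is `9#4` and `2#4` is positive):

```lean
-- sdiv rounds toward zero: -7 / 2 = -3.
#eval ((-7#4).sdiv 2#4).toInt  -- -3
-- srem takes the sign of the dividend: -7 - 2 * (-3) = -1.
#eval ((-7#4).srem 2#4).toInt  -- -1
-- smod rounds toward negative infinity, so the result follows the divisor's sign.
#eval ((-7#4).smod 2#4).toInt  -- 1
```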
@@ -373,8 +373,8 @@ end relations

section cast

/-- `cast eq x` embeds `x` into an equal `BitVec` type. -/
@[inline] def cast (eq : n = m) (x : BitVec n) : BitVec m := .ofNatLt x.toNat (eq ▸ x.isLt)
/-- `cast eq i` embeds `i` into an equal `BitVec` type. -/
@[inline] def cast (eq : n = m) (i : BitVec n) : BitVec m := .ofNatLt i.toNat (eq ▸ i.isLt)

@[simp] theorem cast_ofNat {n m : Nat} (h : n = m) (x : Nat) :
    cast h (BitVec.ofNat n x) = BitVec.ofNat m x := by
@@ -391,7 +391,7 @@ Extraction of bits `start` to `start + len - 1` from a bit vector of size `n` to

new bitvector of size `len`. If `start + len > n`, then the vector will be zero-padded in the
high bits.
-/
def extractLsb' (start len : Nat) (x : BitVec n) : BitVec len := .ofNat _ (x.toNat >>> start)
def extractLsb' (start len : Nat) (a : BitVec n) : BitVec len := .ofNat _ (a.toNat >>> start)

/--
Extraction of bits `hi` (inclusive) down to `lo` (inclusive) from a bit vector of size `n` to
@@ -399,12 +399,12 @@ yield a new bitvector of size `hi - lo + 1`.

SMT-Lib name: `extract`.
-/
def extractLsb (hi lo : Nat) (x : BitVec n) : BitVec (hi - lo + 1) := extractLsb' lo _ x
def extractLsb (hi lo : Nat) (a : BitVec n) : BitVec (hi - lo + 1) := extractLsb' lo _ a
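A sketch of `extractLsb` on a concrete value (assuming the definitions above are in scope; `171#8` is `0xAB`):

```lean
-- Bits 7..4: 171 >>> 4 = 10, kept at width 4.
#eval (BitVec.extractLsb 7 4 171#8).toNat  -- 10
-- Bits 3..0: 171 mod 16 = 11.
#eval (BitVec.extractLsb 3 0 171#8).toNat  -- 11
```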
/--
A version of `zeroExtend` that requires a proof, but is a noop.
-/
def zeroExtend' {n w : Nat} (le : n ≤ w) (x : BitVec n) : BitVec w :=
def zeroExtend' {n w : Nat} (le : n ≤ w) (x : BitVec n) : BitVec w :=
  x.toNat#'(by
    apply Nat.lt_of_lt_of_le x.isLt
    exact Nat.pow_le_pow_of_le_right (by trivial) le)
@@ -413,8 +413,8 @@ def zeroExtend' {n w : Nat} (le : n ≤ w) (x : BitVec n) : BitVec w :=

`shiftLeftZeroExtend x n` returns `zeroExtend (w+n) x <<< n` without
needing to compute `x % 2^(w+n)`.
-/
def shiftLeftZeroExtend (msbs : BitVec w) (m : Nat) : BitVec (w + m) :=
  let shiftLeftLt {x : Nat} (p : x < 2^w) (m : Nat) : x <<< m < 2^(w + m) := by
def shiftLeftZeroExtend (msbs : BitVec w) (m : Nat) : BitVec (w+m) :=
  let shiftLeftLt {x : Nat} (p : x < 2^w) (m : Nat) : x <<< m < 2^(w+m) := by
    simp [Nat.shiftLeft_eq, Nat.pow_add]
    apply Nat.mul_lt_mul_of_pos_right p
    exact (Nat.two_pow_pos m)
@@ -502,24 +502,24 @@ instance : Complement (BitVec w) := ⟨.not⟩

/--
Left shift for bit vectors. The low bits are filled with zeros. As a numeric operation, this is
equivalent to `x * 2^s`, modulo `2^n`.
equivalent to `a * 2^s`, modulo `2^n`.

SMT-Lib name: `bvshl` except this operator uses a `Nat` shift value.
-/
protected def shiftLeft (x : BitVec n) (s : Nat) : BitVec n := BitVec.ofNat n (x.toNat <<< s)
protected def shiftLeft (a : BitVec n) (s : Nat) : BitVec n := BitVec.ofNat n (a.toNat <<< s)
instance : HShiftLeft (BitVec w) Nat (BitVec w) := ⟨.shiftLeft⟩

/--
(Logical) right shift for bit vectors. The high bits are filled with zeros.
As a numeric operation, this is equivalent to `x / 2^s`, rounding down.
As a numeric operation, this is equivalent to `a / 2^s`, rounding down.

SMT-Lib name: `bvlshr` except this operator uses a `Nat` shift value.
-/
def ushiftRight (x : BitVec n) (s : Nat) : BitVec n :=
  (x.toNat >>> s)#'(by
    let ⟨x, lt⟩ := x
def ushiftRight (a : BitVec n) (s : Nat) : BitVec n :=
  (a.toNat >>> s)#'(by
    let ⟨a, lt⟩ := a
    simp only [BitVec.toNat, Nat.shiftRight_eq_div_pow, Nat.div_lt_iff_lt_mul (Nat.two_pow_pos s)]
    rw [←Nat.mul_one x]
    rw [←Nat.mul_one a]
    exact Nat.mul_lt_mul_of_lt_of_le' lt (Nat.two_pow_pos s) (Nat.le_refl 1))

instance : HShiftRight (BitVec w) Nat (BitVec w) := ⟨.ushiftRight⟩
@@ -527,24 +527,15 @@ instance : HShiftRight (BitVec w) Nat (BitVec w) := ⟨.ushiftRight⟩

/--
Arithmetic right shift for bit vectors. The high bits are filled with the
most-significant bit.
As a numeric operation, this is equivalent to `x.toInt >>> s`.
As a numeric operation, this is equivalent to `a.toInt >>> s`.

SMT-Lib name: `bvashr` except this operator uses a `Nat` shift value.
-/
def sshiftRight (x : BitVec n) (s : Nat) : BitVec n := .ofInt n (x.toInt >>> s)
def sshiftRight (a : BitVec n) (s : Nat) : BitVec n := .ofInt n (a.toInt >>> s)

instance {n} : HShiftLeft (BitVec m) (BitVec n) (BitVec m) := ⟨fun x y => x <<< y.toNat⟩
instance {n} : HShiftRight (BitVec m) (BitVec n) (BitVec m) := ⟨fun x y => x >>> y.toNat⟩

/--
Arithmetic right shift for bit vectors. The high bits are filled with the
most-significant bit.
As a numeric operation, this is equivalent to `a.toInt >>> s.toNat`.

SMT-Lib name: `bvashr`.
-/
def sshiftRight' (a : BitVec n) (s : BitVec m) : BitVec n := a.sshiftRight s.toNat
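A sketch of the three shifts above on width-4 values (assuming the definitions from this file are in scope), illustrating the numeric readings in their docstrings:

```lean
-- shiftLeft: 1 * 2^3 = 8, modulo 16.
#eval (1#4 <<< 3).toNat           -- 8
-- ushiftRight: 12 / 2^2 = 3, rounding down.
#eval (12#4 >>> 2).toNat          -- 3
-- sshiftRight: 9#4 reads as -7; (-7) >>> 1 = -4, high bit refilled from the msb.
#eval ((9#4).sshiftRight 1).toInt -- -4
```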
/-- Auxiliary function for `rotateLeft`, which does not take into account the case where
the rotation amount is greater than the bitvector width. -/
def rotateLeftAux (x : BitVec w) (n : Nat) : BitVec w :=
@@ -289,18 +289,18 @@ theorem sle_eq_carry (x y : BitVec w) :

A recurrence that describes multiplication as repeated addition.
Is useful for bitblasting multiplication.
-/
def mulRec (x y : BitVec w) (s : Nat) : BitVec w :=
  let cur := if y.getLsb s then (x <<< s) else 0
def mulRec (l r : BitVec w) (s : Nat) : BitVec w :=
  let cur := if r.getLsb s then (l <<< s) else 0
  match s with
  | 0 => cur
  | s + 1 => mulRec x y s + cur
  | s + 1 => mulRec l r s + cur

theorem mulRec_zero_eq (x y : BitVec w) :
    mulRec x y 0 = if y.getLsb 0 then x else 0 := by
theorem mulRec_zero_eq (l r : BitVec w) :
    mulRec l r 0 = if r.getLsb 0 then l else 0 := by
  simp [mulRec]

theorem mulRec_succ_eq (x y : BitVec w) (s : Nat) :
    mulRec x y (s + 1) = mulRec x y s + if y.getLsb (s + 1) then (x <<< (s + 1)) else 0 := rfl
theorem mulRec_succ_eq (l r : BitVec w) (s : Nat) :
    mulRec l r (s + 1) = mulRec l r s + if r.getLsb (s + 1) then (l <<< (s + 1)) else 0 := rfl
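The recurrence is shift-and-add: step `s` contributes the shifted multiplicand exactly when bit `s` of the other operand is set. A sketch (assuming `mulRec` from this file is in scope; `5#4` has bits 0 and 2 set):

```lean
-- Contributions: s = 0 adds 3, s = 2 adds 3 <<< 2 = 12, all other steps add 0.
#eval (BitVec.mulRec 3#4 5#4 3).toNat  -- 15, which agrees with (3 * 5) % 16
```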
/--
Recurrence lemma: truncating to `i+1` bits and then zero extending to `w`
@@ -326,29 +326,29 @@ theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow (x : BitVec w

  by_cases hi : x.getLsb i <;> simp [hi] <;> omega

/--
Recurrence lemma: multiplying `x` with the first `s` bits of `y` is the
same as truncating `y` to `s` bits, then zero extending to the original length,
Recurrence lemma: multiplying `l` with the first `s` bits of `r` is the
same as truncating `r` to `s` bits, then zero extending to the original length,
and performing the multiplication. -/
theorem mulRec_eq_mul_signExtend_truncate (x y : BitVec w) (s : Nat) :
    mulRec x y s = x * ((y.truncate (s + 1)).zeroExtend w) := by
theorem mulRec_eq_mul_signExtend_truncate (l r : BitVec w) (s : Nat) :
    mulRec l r s = l * ((r.truncate (s + 1)).zeroExtend w) := by
  induction s
  case zero =>
    simp only [mulRec_zero_eq, ofNat_eq_ofNat, Nat.reduceAdd]
    by_cases y.getLsb 0
    case pos hy =>
      simp only [hy, ↓reduceIte, truncate, zeroExtend_one_eq_ofBool_getLsb_zero,
        ofBool_true, ofNat_eq_ofNat]
    by_cases r.getLsb 0
    case pos hr =>
      simp only [hr, ↓reduceIte, truncate, zeroExtend_one_eq_ofBool_getLsb_zero,
        hr, ofBool_true, ofNat_eq_ofNat]
      rw [zeroExtend_ofNat_one_eq_ofNat_one_of_lt (by omega)]
      simp
    case neg hy =>
      simp [hy, zeroExtend_one_eq_ofBool_getLsb_zero]
    case neg hr =>
      simp [hr, zeroExtend_one_eq_ofBool_getLsb_zero]
  case succ s' hs =>
    rw [mulRec_succ_eq, hs]
    have heq :
      (if y.getLsb (s' + 1) = true then x <<< (s' + 1) else 0) =
        (x * (y &&& (BitVec.twoPow w (s' + 1)))) := by
      (if r.getLsb (s' + 1) = true then l <<< (s' + 1) else 0) =
        (l * (r &&& (BitVec.twoPow w (s' + 1)))) := by
      simp only [ofNat_eq_ofNat, and_twoPow]
      by_cases hy : y.getLsb (s' + 1) <;> simp [hy]
      by_cases hr : r.getLsb (s' + 1) <;> simp [hr]
    rw [heq, ← BitVec.mul_add, ← zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow]

theorem getLsb_mul (x y : BitVec w) (i : Nat) :
@@ -429,67 +429,6 @@ theorem shiftLeft_eq_shiftLeftRec (x : BitVec w₁) (y : BitVec w₂) :

  · simp [of_length_zero]
  · simp [shiftLeftRec_eq]

/- ### Arithmetic shift right (sshiftRight) recurrence -/

/--
`sshiftRightRec x y n` shifts `x` arithmetically/signed to the right by the first `n` bits of `y`.
The theorem `sshiftRight_eq_sshiftRightRec` proves the equivalence of `(x.sshiftRight y)` and `sshiftRightRec`.
Together with equations `sshiftRightRec_zero`, `sshiftRightRec_succ`,
this allows us to unfold `sshiftRight` into a circuit for bitblasting.
-/
def sshiftRightRec (x : BitVec w₁) (y : BitVec w₂) (n : Nat) : BitVec w₁ :=
  let shiftAmt := (y &&& (twoPow w₂ n))
  match n with
  | 0 => x.sshiftRight' shiftAmt
  | n + 1 => (sshiftRightRec x y n).sshiftRight' shiftAmt

@[simp]
theorem sshiftRightRec_zero_eq (x : BitVec w₁) (y : BitVec w₂) :
    sshiftRightRec x y 0 = x.sshiftRight' (y &&& 1#w₂) := by
  simp only [sshiftRightRec, twoPow_zero]

@[simp]
theorem sshiftRightRec_succ_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) :
    sshiftRightRec x y (n + 1) = (sshiftRightRec x y n).sshiftRight' (y &&& twoPow w₂ (n + 1)) := by
  simp [sshiftRightRec]

/--
If `y &&& z = 0`, `x.sshiftRight (y ||| z) = (x.sshiftRight y).sshiftRight z`.
This follows as `y &&& z = 0` implies `y ||| z = y + z`,
and thus `x.sshiftRight (y ||| z) = x.sshiftRight (y + z) = (x.sshiftRight y).sshiftRight z`.
-/
theorem sshiftRight'_or_of_and_eq_zero {x : BitVec w₁} {y z : BitVec w₂}
    (h : y &&& z = 0#w₂) :
    x.sshiftRight' (y ||| z) = (x.sshiftRight' y).sshiftRight' z := by
  simp [sshiftRight', ← add_eq_or_of_and_eq_zero _ _ h,
    toNat_add_of_and_eq_zero h, sshiftRight_add]

theorem sshiftRightRec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) :
    sshiftRightRec x y n = x.sshiftRight' ((y.truncate (n + 1)).zeroExtend w₂) := by
  induction n generalizing x y
  case zero =>
    ext i
    simp [twoPow_zero, Nat.reduceAdd, and_one_eq_zeroExtend_ofBool_getLsb, truncate_one]
  case succ n ih =>
    simp only [sshiftRightRec_succ_eq, and_twoPow, ih]
    by_cases h : y.getLsb (n + 1)
    · rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true h,
        sshiftRight'_or_of_and_eq_zero (by simp), h]
      simp
    · rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false (i := n + 1)
        (by simp [h])]
      simp [h]

/--
Show that `x.sshiftRight y` can be written in terms of `sshiftRightRec`.
This can be unfolded in terms of `sshiftRightRec_zero_eq`, `sshiftRightRec_succ_eq` for bitblasting.
-/
theorem sshiftRight_eq_sshiftRightRec (x : BitVec w₁) (y : BitVec w₂) :
    (x.sshiftRight' y).getLsb i = (sshiftRightRec x y (w₂ - 1)).getLsb i := by
  rcases w₂ with rfl | w₂
  · simp [of_length_zero]
  · simp [sshiftRightRec_eq]

/- ### Logical shift right (ushiftRight) recurrence for bitblasting -/

/--
@@ -23,7 +23,7 @@ theorem ofFin_eq_ofNat : @BitVec.ofFin w (Fin.mk x lt) = BitVec.ofNat w x := by
|
||||
simp only [BitVec.ofNat, Fin.ofNat', lt, Nat.mod_eq_of_lt]
|
||||
|
||||
/-- Prove equality of bitvectors in terms of nat operations. -/
|
||||
theorem eq_of_toNat_eq {n} : ∀ {x y : BitVec n}, x.toNat = y.toNat → x = y
|
||||
theorem eq_of_toNat_eq {n} : ∀ {i j : BitVec n}, i.toNat = j.toNat → i = j
|
||||
| ⟨_, _⟩, ⟨_, _⟩, rfl => rfl
|
||||
|
||||
@[simp] theorem val_toFin (x : BitVec w) : x.toFin.val = x.toNat := rfl
|
||||
@@ -228,12 +228,12 @@ theorem toNat_ge_of_msb_true {x : BitVec n} (p : BitVec.msb x = true) : x.toNat
|
||||
/-! ### toInt/ofInt -/
|
||||
|
||||
/-- Prove equality of bitvectors in terms of nat operations. -/
|
||||
theorem toInt_eq_toNat_cond (x : BitVec n) :
|
||||
x.toInt =
|
||||
if 2*x.toNat < 2^n then
|
||||
(x.toNat : Int)
|
||||
theorem toInt_eq_toNat_cond (i : BitVec n) :
|
||||
i.toInt =
|
||||
if 2*i.toNat < 2^n then
|
||||
(i.toNat : Int)
|
||||
else
|
||||
(x.toNat : Int) - (2^n : Nat) :=
|
||||
(i.toNat : Int) - (2^n : Nat) :=
|
||||
rfl
|
||||
|
||||
theorem msb_eq_false_iff_two_mul_lt (x : BitVec w) : x.msb = false ↔ 2 * x.toNat < 2^w := by
|
||||
@@ -260,13 +260,13 @@ theorem toInt_eq_toNat_bmod (x : BitVec n) : x.toInt = Int.bmod x.toNat (2^n) :=
|
||||
omega
|
||||
|
||||
/-- Prove equality of bitvectors in terms of nat operations. -/
|
||||
theorem eq_of_toInt_eq {x y : BitVec n} : x.toInt = y.toInt → x = y := by
|
||||
theorem eq_of_toInt_eq {i j : BitVec n} : i.toInt = j.toInt → i = j := by
|
||||
intro eq
|
||||
simp [toInt_eq_toNat_cond] at eq
|
||||
apply eq_of_toNat_eq
|
||||
revert eq
|
||||
have _xlt := x.isLt
|
||||
have _ylt := y.isLt
|
||||
have _ilt := i.isLt
|
||||
have _jlt := j.isLt
|
||||
split <;> split <;> omega
|
||||
|
||||
theorem toInt_inj (x y : BitVec n) : x.toInt = y.toInt ↔ x = y :=
|
||||
@@ -507,13 +507,6 @@ theorem or_assoc (x y z : BitVec w) :
|
||||
x ||| y ||| z = x ||| (y ||| z) := by
|
||||
ext i
|
||||
simp [Bool.or_assoc]
|
||||
instance : Std.Associative (α := BitVec n) (· ||| ·) := ⟨BitVec.or_assoc⟩
|
||||
|
||||
theorem or_comm (x y : BitVec w) :
|
||||
x ||| y = y ||| x := by
|
||||
ext i
|
||||
simp [Bool.or_comm]
|
||||
instance : Std.Commutative (fun (x y : BitVec w) => x ||| y) := ⟨BitVec.or_comm⟩
|
||||
|
||||
/-! ### and -/
|
||||
|
||||
@@ -545,13 +538,11 @@ theorem and_assoc (x y z : BitVec w) :
|
||||
x &&& y &&& z = x &&& (y &&& z) := by
|
||||
ext i
|
||||
simp [Bool.and_assoc]
|
||||
instance : Std.Associative (α := BitVec n) (· &&& ·) := ⟨BitVec.and_assoc⟩
|
||||
|
||||
theorem and_comm (x y : BitVec w) :
|
||||
x &&& y = y &&& x := by
|
||||
ext i
|
||||
simp [Bool.and_comm]
|
||||
instance : Std.Commutative (fun (x y : BitVec w) => x &&& y) := ⟨BitVec.and_comm⟩
|
||||
|
||||
/-! ### xor -/
|
||||
|
||||
@@ -577,13 +568,6 @@ theorem xor_assoc (x y z : BitVec w) :
|
||||
x ^^^ y ^^^ z = x ^^^ (y ^^^ z) := by
|
||||
ext i
|
||||
simp [Bool.xor_assoc]
|
||||
instance : Std.Associative (fun (x y : BitVec w) => x ^^^ y) := ⟨BitVec.xor_assoc⟩
|
||||
|
||||
theorem xor_comm (x y : BitVec w) :
|
||||
x ^^^ y = y ^^^ x := by
|
||||
ext i
|
||||
simp [Bool.xor_comm]
|
||||
instance : Std.Commutative (fun (x y : BitVec w) => x ^^^ y) := ⟨BitVec.xor_comm⟩
|
||||
|
||||
/-! ### not -/
|
||||
|
||||
@@ -749,21 +733,6 @@ theorem getLsb_shiftLeft' {x : BitVec w₁} {y : BitVec w₂} {i : Nat} :
|
||||
getLsb (x >>> i) j = getLsb x (i+j) := by
|
||||
unfold getLsb ; simp
|
||||
|
||||
theorem ushiftRight_xor_distrib (x y : BitVec w) (n : Nat) :
|
||||
(x ^^^ y) >>> n = (x >>> n) ^^^ (y >>> n) := by
|
||||
ext
|
||||
simp
|
||||
|
||||
theorem ushiftRight_and_distrib (x y : BitVec w) (n : Nat) :
|
||||
(x &&& y) >>> n = (x >>> n) &&& (y >>> n) := by
|
||||
ext
|
||||
simp
|
||||
|
||||
theorem ushiftRight_or_distrib (x y : BitVec w) (n : Nat) :
|
||||
(x ||| y) >>> n = (x >>> n) ||| (y >>> n) := by
|
||||
ext
|
||||
simp
|
||||
|
||||
@[simp]
|
||||
theorem ushiftRight_zero_eq (x : BitVec w) : x >>> 0 = x := by
|
||||
simp [bv_toNat]
|
||||
@@ -817,7 +786,7 @@ theorem sshiftRight_eq_of_msb_true {x : BitVec w} {s : Nat} (h : x.msb = true) :
|
||||
· rw [Nat.shiftRight_eq_div_pow]
|
||||
apply Nat.lt_of_le_of_lt (Nat.div_le_self _ _) (by omega)
|
||||
|
||||
@[simp] theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) :
|
||||
theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) :
|
||||
getLsb (x.sshiftRight s) i =
|
||||
(!decide (w ≤ i) && if s + i < w then x.getLsb (s + i) else x.msb) := by
|
||||
rcases hmsb : x.msb with rfl | rfl
|
||||
@@ -838,41 +807,6 @@ theorem sshiftRight_eq_of_msb_true {x : BitVec w} {s : Nat} (h : x.msb = true) :
Nat.not_lt, decide_eq_true_eq]
omega

/-- The msb after arithmetic shifting right equals the original msb. -/
theorem sshiftRight_msb_eq_msb {n : Nat} {x : BitVec w} :
(x.sshiftRight n).msb = x.msb := by
rw [msb_eq_getLsb_last, getLsb_sshiftRight, msb_eq_getLsb_last]
by_cases hw₀ : w = 0
· simp [hw₀]
· simp only [show ¬(w ≤ w - 1) by omega, decide_False, Bool.not_false, Bool.true_and,
ite_eq_right_iff]
intros h
simp [show n = 0 by omega]

@[simp] theorem sshiftRight_zero {x : BitVec w} : x.sshiftRight 0 = x := by
ext i
simp

theorem sshiftRight_add {x : BitVec w} {m n : Nat} :
x.sshiftRight (m + n) = (x.sshiftRight m).sshiftRight n := by
ext i
simp only [getLsb_sshiftRight, Nat.add_assoc]
by_cases h₁ : w ≤ (i : Nat)
· simp [h₁]
· simp only [h₁, decide_False, Bool.not_false, Bool.true_and]
by_cases h₂ : n + ↑i < w
· simp [h₂]
· simp only [h₂, ↓reduceIte]
by_cases h₃ : m + (n + ↑i) < w
· simp [h₃]
omega
· simp [h₃, sshiftRight_msb_eq_msb]

/-! ### sshiftRight reductions from BitVec to Nat -/

@[simp]
theorem sshiftRight_eq' (x : BitVec w) : x.sshiftRight' y = x.sshiftRight y.toNat := rfl

/-! ### signExtend -/

/-- Equation theorem for `Int.sub` when both arguments are `Int.ofNat` -/
@@ -935,15 +869,15 @@ theorem append_def (x : BitVec v) (y : BitVec w) :
(x ++ y).toNat = x.toNat <<< n ||| y.toNat :=
rfl

@[simp] theorem getLsb_append {x : BitVec n} {y : BitVec m} :
getLsb (x ++ y) i = bif i < m then getLsb y i else getLsb x (i - m) := by
@[simp] theorem getLsb_append {v : BitVec n} {w : BitVec m} :
getLsb (v ++ w) i = bif i < m then getLsb w i else getLsb v (i - m) := by
simp only [append_def, getLsb_or, getLsb_shiftLeftZeroExtend, getLsb_zeroExtend']
by_cases h : i < m
· simp [h]
· simp [h]; simp_all

@[simp] theorem getMsb_append {x : BitVec n} {y : BitVec m} :
getMsb (x ++ y) i = bif n ≤ i then getMsb y (i - n) else getMsb x i := by
@[simp] theorem getMsb_append {v : BitVec n} {w : BitVec m} :
getMsb (v ++ w) i = bif n ≤ i then getMsb w (i - n) else getMsb v i := by
simp [append_def]
by_cases h : n ≤ i
· simp [h]
@@ -438,24 +438,6 @@ Added for confluence between `if_true_left` and `ite_false_same` on
-/
@[simp] theorem eq_true_imp_eq_false : ∀(b:Bool), (b = true → b = false) ↔ (b = false) := by decide

/-! ### forall -/

theorem forall_bool' {p : Bool → Prop} (b : Bool) : (∀ x, p x) ↔ p b ∧ p !b :=
⟨fun h ↦ ⟨h _, h _⟩, fun ⟨h₁, h₂⟩ x ↦ by cases b <;> cases x <;> assumption⟩

@[simp]
theorem forall_bool {p : Bool → Prop} : (∀ b, p b) ↔ p false ∧ p true :=
forall_bool' false

/-! ### exists -/

theorem exists_bool' {p : Bool → Prop} (b : Bool) : (∃ x, p x) ↔ p b ∨ p !b :=
⟨fun ⟨x, hx⟩ ↦ by cases x <;> cases b <;> first | exact .inl ‹_› | exact .inr ‹_›,
fun h ↦ by cases h <;> exact ⟨_, ‹_›⟩⟩

@[simp]
theorem exists_bool {p : Bool → Prop} : (∃ b, p b) ↔ p false ∨ p true :=
exists_bool' false

/-! ### cond -/
@@ -354,7 +354,7 @@ theorem erase_eq_iff [LawfulBEq α] {a : α} {l : List α} :
rw [erase_of_not_mem]
simp_all

theorem Nodup.erase_eq_filter [LawfulBEq α] {l} (d : Nodup l) (a : α) : l.erase a = l.filter (· != a) := by
theorem Nodup.erase_eq_filter [BEq α] [LawfulBEq α] {l} (d : Nodup l) (a : α) : l.erase a = l.filter (· != a) := by
induction d with
| nil => rfl
| cons m _n ih =>
@@ -367,13 +367,13 @@ theorem Nodup.erase_eq_filter [LawfulBEq α] {l} (d : Nodup l) (a : α) : l.eras
simpa [@eq_comm α] using m
· simp [beq_false_of_ne h, ih, h]

theorem Nodup.mem_erase_iff [LawfulBEq α] {a : α} (d : Nodup l) : a ∈ l.erase b ↔ a ≠ b ∧ a ∈ l := by
theorem Nodup.mem_erase_iff [BEq α] [LawfulBEq α] {a : α} (d : Nodup l) : a ∈ l.erase b ↔ a ≠ b ∧ a ∈ l := by
rw [Nodup.erase_eq_filter d, mem_filter, and_comm, bne_iff_ne]

theorem Nodup.not_mem_erase [LawfulBEq α] {a : α} (h : Nodup l) : a ∉ l.erase a := fun H => by
theorem Nodup.not_mem_erase [BEq α] [LawfulBEq α] {a : α} (h : Nodup l) : a ∉ l.erase a := fun H => by
simpa using ((Nodup.mem_erase_iff h).mp H).left

theorem Nodup.erase [LawfulBEq α] (a : α) : Nodup l → Nodup (l.erase a) :=
theorem Nodup.erase [BEq α] [LawfulBEq α] (a : α) : Nodup l → Nodup (l.erase a) :=
Nodup.sublist <| erase_sublist _ _

end erase
@@ -5,18 +5,9 @@ Author: Leonardo de Moura
-/
prelude
import Init.SimpLemmas
import Init.NotationExtra

instance [BEq α] [BEq β] [LawfulBEq α] [LawfulBEq β] : LawfulBEq (α × β) where
eq_of_beq {a b} (h : a.1 == b.1 && a.2 == b.2) := by
cases a; cases b
refine congr (congrArg _ (eq_of_beq ?_)) (eq_of_beq ?_) <;> simp_all
rfl {a} := by cases a; simp [BEq.beq, LawfulBEq.rfl]

@[simp]
protected theorem Prod.forall {p : α × β → Prop} : (∀ x, p x) ↔ ∀ a b, p (a, b) :=
⟨fun h a b ↦ h (a, b), fun h ⟨a, b⟩ ↦ h a b⟩

@[simp]
protected theorem Prod.exists {p : α × β → Prop} : (∃ x, p x) ↔ ∃ a b, p (a, b) :=
⟨fun ⟨⟨a, b⟩, h⟩ ↦ ⟨a, b, h⟩, fun ⟨a, b, h⟩ ↦ ⟨⟨a, b⟩, h⟩⟩

@@ -78,10 +78,7 @@ end Elab.Tactic.Ext
end Lean

attribute [ext] Prod PProd Sigma PSigma
attribute [ext] funext propext Subtype.eq Array.ext
attribute [ext] funext propext Subtype.eq

@[ext] protected theorem PUnit.ext (x y : PUnit) : x = y := rfl
protected theorem Unit.ext (x y : Unit) : x = y := rfl

@[ext] protected theorem Thunk.ext : {a b : Thunk α} → a.get = b.get → a = b
| {..}, {..}, heq => congrArg _ <| funext fun _ => heq
@@ -470,23 +470,31 @@ def withFile (fn : FilePath) (mode : Mode) (f : Handle → IO α) : IO α :=
def Handle.putStrLn (h : Handle) (s : String) : IO Unit :=
h.putStr (s.push '\n')

partial def Handle.readBinToEndInto (h : Handle) (buf : ByteArray) : IO ByteArray := do
partial def Handle.readBinToEnd (h : Handle) : IO ByteArray := do
let rec loop (acc : ByteArray) : IO ByteArray := do
let buf ← h.read 1024
if buf.isEmpty then
return acc
else
loop (acc ++ buf)
loop buf
loop ByteArray.empty

partial def Handle.readBinToEnd (h : Handle) : IO ByteArray := do
h.readBinToEndInto .empty
partial def Handle.readToEnd (h : Handle) : IO String := do
let rec loop (s : String) := do
let line ← h.getLine
if line.isEmpty then
return s
else
loop (s ++ line)
loop ""

def Handle.readToEnd (h : Handle) : IO String := do
let data ← h.readBinToEnd
match String.fromUTF8? data with
| some s => return s
| none => throw <| .userError s!"Tried to read from handle containing non UTF-8 data."
def readBinFile (fname : FilePath) : IO ByteArray := do
let h ← Handle.mk fname Mode.read
h.readBinToEnd

def readFile (fname : FilePath) : IO String := do
let h ← Handle.mk fname Mode.read
h.readToEnd

partial def lines (fname : FilePath) : IO (Array String) := do
let h ← Handle.mk fname Mode.read
@@ -592,28 +600,6 @@ end System.FilePath

namespace IO

namespace FS

def readBinFile (fname : FilePath) : IO ByteArray := do
-- Requires metadata so defined after metadata
let mdata ← fname.metadata
let size := mdata.byteSize.toUSize
let handle ← IO.FS.Handle.mk fname .read
let buf ←
if size > 0 then
handle.read mdata.byteSize.toUSize
else
pure <| ByteArray.mkEmpty 0
handle.readBinToEndInto buf

def readFile (fname : FilePath) : IO String := do
let data ← readBinFile fname
match String.fromUTF8? data with
| some s => return s
| none => throw <| .userError s!"Tried to read file '{fname}' containing non UTF-8 data."

end FS

def withStdin [Monad m] [MonadFinally m] [MonadLiftT BaseIO m] (h : FS.Stream) (x : m α) : m α := do
let prev ← setStdin h
try x finally discard <| setStdin prev
@@ -68,7 +68,6 @@ noncomputable def recursion {C : α → Sort v} (a : α) (h : ∀ x, (∀ y, r y
induction (apply hwf a) with
| intro x₁ _ ih => exact h x₁ ih

include hwf in
theorem induction {C : α → Prop} (a : α) (h : ∀ x, (∀ y, r y x → C y) → C x) : C a :=
recursion hwf a h
@@ -53,7 +53,7 @@ structure AttributeImpl extends AttributeImplCore where
erase (decl : Name) : AttrM Unit := throwError "attribute cannot be erased"
deriving Inhabited

builtin_initialize attributeMapRef : IO.Ref (Std.HashMap Name AttributeImpl) ← IO.mkRef {}
builtin_initialize attributeMapRef : IO.Ref (HashMap Name AttributeImpl) ← IO.mkRef {}

/-- Low level attribute registration function. -/
def registerBuiltinAttribute (attr : AttributeImpl) : IO Unit := do
@@ -296,7 +296,7 @@ end EnumAttributes
-/

abbrev AttributeImplBuilder := Name → List DataValue → Except String AttributeImpl
abbrev AttributeImplBuilderTable := Std.HashMap Name AttributeImplBuilder
abbrev AttributeImplBuilderTable := HashMap Name AttributeImplBuilder

builtin_initialize attributeImplBuilderTableRef : IO.Ref AttributeImplBuilderTable ← IO.mkRef {}

@@ -307,7 +307,7 @@ def registerAttributeImplBuilder (builderId : Name) (builder : AttributeImplBuil

def mkAttributeImplOfBuilder (builderId ref : Name) (args : List DataValue) : IO AttributeImpl := do
let table ← attributeImplBuilderTableRef.get
match table[builderId]? with
match table.find? builderId with
| none => throw (IO.userError ("unknown attribute implementation builder '" ++ toString builderId ++ "'"))
| some builder => IO.ofExcept <| builder ref args
@@ -317,7 +317,7 @@ inductive AttributeExtensionOLeanEntry where

structure AttributeExtensionState where
newEntries : List AttributeExtensionOLeanEntry := []
map : Std.HashMap Name AttributeImpl
map : HashMap Name AttributeImpl
deriving Inhabited

abbrev AttributeExtension := PersistentEnvExtension AttributeExtensionOLeanEntry (AttributeExtensionOLeanEntry × AttributeImpl) AttributeExtensionState
@@ -348,7 +348,7 @@ private def AttributeExtension.addImported (es : Array (Array AttributeExtension
let map ← es.foldlM
(fun map entries =>
entries.foldlM
(fun (map : Std.HashMap Name AttributeImpl) entry => do
(fun (map : HashMap Name AttributeImpl) entry => do
let attrImpl ← mkAttributeImplOfEntry ctx.env ctx.opts entry
return map.insert attrImpl.name attrImpl)
map)
@@ -378,7 +378,7 @@ def getBuiltinAttributeNames : IO (List Name) :=

def getBuiltinAttributeImpl (attrName : Name) : IO AttributeImpl := do
let m ← attributeMapRef.get
match m[attrName]? with
match m.find? attrName with
| some attr => pure attr
| none => throw (IO.userError ("unknown attribute '" ++ toString attrName ++ "'"))

@@ -396,7 +396,7 @@ def getAttributeNames (env : Environment) : List Name :=

def getAttributeImpl (env : Environment) (attrName : Name) : Except String AttributeImpl :=
let m := (attributeExtension.getState env).map
match m[attrName]? with
match m.find? attrName with
| some attr => pure attr
| none => throw ("unknown attribute '" ++ toString attrName ++ "'")
@@ -26,9 +26,9 @@ instance : Hashable Key := ⟨getHash⟩
end OwnedSet

open OwnedSet (Key) in
abbrev OwnedSet := Std.HashMap Key Unit
def OwnedSet.insert (s : OwnedSet) (k : OwnedSet.Key) : OwnedSet := Std.HashMap.insert s k ()
def OwnedSet.contains (s : OwnedSet) (k : OwnedSet.Key) : Bool := Std.HashMap.contains s k
abbrev OwnedSet := HashMap Key Unit
def OwnedSet.insert (s : OwnedSet) (k : OwnedSet.Key) : OwnedSet := HashMap.insert s k ()
def OwnedSet.contains (s : OwnedSet) (k : OwnedSet.Key) : Bool := HashMap.contains s k

/-! We perform borrow inference in a block of mutually recursive functions.
Join points are viewed as local functions, and are identified using
@@ -49,7 +49,7 @@ instance : Hashable Key := ⟨getHash⟩
end ParamMap

open ParamMap (Key)
abbrev ParamMap := Std.HashMap Key (Array Param)
abbrev ParamMap := HashMap Key (Array Param)

def ParamMap.fmt (map : ParamMap) : Format :=
let fmts := map.fold (fun fmt k ps =>
@@ -109,7 +109,7 @@ partial def visitFnBody (fn : FunId) (paramMap : ParamMap) : FnBody → FnBody
| FnBody.jdecl j _ v b =>
let v := visitFnBody fn paramMap v
let b := visitFnBody fn paramMap b
match paramMap[ParamMap.Key.jp fn j]? with
match paramMap.find? (ParamMap.Key.jp fn j) with
| some ys => FnBody.jdecl j ys v b
| none => unreachable!
| FnBody.case tid x xType alts =>
@@ -125,7 +125,7 @@ def visitDecls (decls : Array Decl) (paramMap : ParamMap) : Array Decl :=
decls.map fun decl => match decl with
| Decl.fdecl f _ ty b info =>
let b := visitFnBody f paramMap b
match paramMap[ParamMap.Key.decl f]? with
match paramMap.find? (ParamMap.Key.decl f) with
| some xs => Decl.fdecl f xs ty b info
| none => unreachable!
| other => other
@@ -178,7 +178,7 @@ def isOwned (x : VarId) : M Bool := do
/-- Updates `map[k]` using the current set of `owned` variables. -/
def updateParamMap (k : ParamMap.Key) : M Unit := do
let s ← get
match s.paramMap[k]? with
match s.paramMap.find? k with
| some ps => do
let ps ← ps.mapM fun (p : Param) => do
if !p.borrow then pure p
@@ -192,7 +192,7 @@ def updateParamMap (k : ParamMap.Key) : M Unit := do

def getParamInfo (k : ParamMap.Key) : M (Array Param) := do
let s ← get
match s.paramMap[k]? with
match s.paramMap.find? k with
| some ps => pure ps
| none =>
match k with

@@ -11,7 +11,6 @@ import Lean.Compiler.IR.Basic
import Lean.Compiler.IR.CompilerM
import Lean.Compiler.IR.FreeVars
import Lean.Compiler.IR.ElimDeadVars
import Lean.Data.AssocList

namespace Lean.IR.ExplicitBoxing
/-!
@@ -152,7 +152,7 @@ def getFunctionSummary? (env : Environment) (fid : FunId) : Option Value :=
| some modIdx => findAtSorted? (functionSummariesExt.getModuleEntries env modIdx) fid
| none => functionSummariesExt.getState env |>.find? fid

abbrev Assignment := Std.HashMap VarId Value
abbrev Assignment := HashMap VarId Value

structure InterpContext where
currFnIdx : Nat := 0
@@ -172,7 +172,7 @@ def findVarValue (x : VarId) : M Value := do
let ctx ← read
let s ← get
let assignment := s.assignments[ctx.currFnIdx]!
return assignment.getD x bot
return assignment.findD x bot

def findArgValue (arg : Arg) : M Value :=
match arg with
@@ -303,7 +303,7 @@ partial def elimDeadAux (assignment : Assignment) : FnBody → FnBody
| FnBody.vdecl x t e b => FnBody.vdecl x t e (elimDeadAux assignment b)
| FnBody.jdecl j ys v b => FnBody.jdecl j ys (elimDeadAux assignment v) (elimDeadAux assignment b)
| FnBody.case tid x xType alts =>
let v := assignment.getD x bot
let v := assignment.findD x bot
let alts := alts.map fun alt =>
match alt with
| Alt.ctor i b => Alt.ctor i <| if containsCtor v i then elimDeadAux assignment b else FnBody.unreachable
@@ -96,13 +96,8 @@ def shouldExport (n : Name) : Bool :=
-- libleanshared to avoid Windows symbol limit
!(`Lean.Compiler.LCNF).isPrefixOf n &&
!(`Lean.IR).isPrefixOf n &&
-- Lean.Server.findModuleRefs is used in SubVerso, and the contents of RequestM are used by the
-- full Verso as well as anything else that extends the LSP server.
(!(`Lean.Server.Watchdog).isPrefixOf n) &&
(!(`Lean.Server.ImportCompletion).isPrefixOf n) &&
(!(`Lean.Server.Completion).isPrefixOf n)

-- Lean.Server.findModuleRefs is used in Verso
(!(`Lean.Server).isPrefixOf n || n == `Lean.Server.findModuleRefs)

def emitFnDeclAux (decl : Decl) (cppBaseName : String) (isExternal : Bool) : M Unit := do
let ps := decl.params
@@ -257,7 +252,7 @@ def throwUnknownVar {α : Type} (x : VarId) : M α :=

def getJPParams (j : JoinPointId) : M (Array Param) := do
let ctx ← read;
match ctx.jpMap[j]? with
match ctx.jpMap.find? j with
| some ps => pure ps
| none => throw "unknown join point"
@@ -65,8 +65,8 @@ structure Context (llvmctx : LLVM.Context) where
llvmmodule : LLVM.Module llvmctx

structure State (llvmctx : LLVM.Context) where
var2val : Std.HashMap VarId (LLVM.LLVMType llvmctx × LLVM.Value llvmctx)
jp2bb : Std.HashMap JoinPointId (LLVM.BasicBlock llvmctx)
var2val : HashMap VarId (LLVM.LLVMType llvmctx × LLVM.Value llvmctx)
jp2bb : HashMap JoinPointId (LLVM.BasicBlock llvmctx)

abbrev Error := String

@@ -84,7 +84,7 @@ def addJpTostate (jp : JoinPointId) (bb : LLVM.BasicBlock llvmctx) : M llvmctx U

def emitJp (jp : JoinPointId) : M llvmctx (LLVM.BasicBlock llvmctx) := do
let state ← get
match state.jp2bb[jp]? with
match state.jp2bb.find? jp with
| .some bb => return bb
| .none => throw s!"unable to find join point {jp}"
@@ -531,7 +531,7 @@ def emitFnDecls : M llvmctx Unit := do

def emitLhsSlot_ (x : VarId) : M llvmctx (LLVM.LLVMType llvmctx × LLVM.Value llvmctx) := do
let state ← get
match state.var2val[x]? with
match state.var2val.find? x with
| .some v => return v
| .none => throw s!"unable to find variable {x}"

@@ -1029,7 +1029,7 @@ def emitTailCall (builder : LLVM.Builder llvmctx) (f : FunId) (v : Expr) : M llv

def emitJmp (builder : LLVM.Builder llvmctx) (jp : JoinPointId) (xs : Array Arg) : M llvmctx Unit := do
let llvmctx ← read
let ps ← match llvmctx.jpMap[jp]? with
let ps ← match llvmctx.jpMap.find? jp with
| some ps => pure ps
| none => throw s!"Unknown join point {jp}"
unless xs.size == ps.size do throw s!"Invalid goto, mismatched sizes between arguments, formal parameters."
@@ -51,8 +51,8 @@ end CollectUsedDecls
def collectUsedDecls (env : Environment) (decl : Decl) (used : NameSet := {}) : NameSet :=
(CollectUsedDecls.collectDecl decl env).run' used

abbrev VarTypeMap := Std.HashMap VarId IRType
abbrev JPParamsMap := Std.HashMap JoinPointId (Array Param)
abbrev VarTypeMap := HashMap VarId IRType
abbrev JPParamsMap := HashMap JoinPointId (Array Param)

namespace CollectMaps
abbrev Collector := (VarTypeMap × JPParamsMap) → (VarTypeMap × JPParamsMap)
@@ -10,7 +10,7 @@ import Lean.Compiler.IR.FreeVars

namespace Lean.IR.ExpandResetReuse
/-- Mapping from variable to projections -/
abbrev ProjMap := Std.HashMap VarId Expr
abbrev ProjMap := HashMap VarId Expr
namespace CollectProjMap
abbrev Collector := ProjMap → ProjMap
@[inline] def collectVDecl (x : VarId) (v : Expr) : Collector := fun m =>
@@ -148,20 +148,20 @@ def setFields (y : VarId) (zs : Array Arg) (b : FnBody) : FnBody :=
def isSelfSet (ctx : Context) (x : VarId) (i : Nat) (y : Arg) : Bool :=
match y with
| Arg.var y =>
match ctx.projMap[y]? with
match ctx.projMap.find? y with
| some (Expr.proj j w) => j == i && w == x
| _ => false
| _ => false

/-- Given `uset x[i] := y`, return true iff `y := uproj[i] x` -/
def isSelfUSet (ctx : Context) (x : VarId) (i : Nat) (y : VarId) : Bool :=
match ctx.projMap[y]? with
match ctx.projMap.find? y with
| some (Expr.uproj j w) => j == i && w == x
| _ => false

/-- Given `sset x[n, i] := y`, return true iff `y := sproj[n, i] x` -/
def isSelfSSet (ctx : Context) (x : VarId) (n : Nat) (i : Nat) (y : VarId) : Bool :=
match ctx.projMap[y]? with
match ctx.projMap.find? y with
| some (Expr.sproj m j w) => n == m && j == i && w == x
| _ => false
@@ -64,34 +64,34 @@ instance : AddMessageContext CompilerM where

def getType (fvarId : FVarId) : CompilerM Expr := do
let lctx := (← get).lctx
if let some decl := lctx.letDecls[fvarId]? then
if let some decl := lctx.letDecls.find? fvarId then
return decl.type
else if let some decl := lctx.params[fvarId]? then
else if let some decl := lctx.params.find? fvarId then
return decl.type
else if let some decl := lctx.funDecls[fvarId]? then
else if let some decl := lctx.funDecls.find? fvarId then
return decl.type
else
throwError "unknown free variable {fvarId.name}"

def getBinderName (fvarId : FVarId) : CompilerM Name := do
let lctx := (← get).lctx
if let some decl := lctx.letDecls[fvarId]? then
if let some decl := lctx.letDecls.find? fvarId then
return decl.binderName
else if let some decl := lctx.params[fvarId]? then
else if let some decl := lctx.params.find? fvarId then
return decl.binderName
else if let some decl := lctx.funDecls[fvarId]? then
else if let some decl := lctx.funDecls.find? fvarId then
return decl.binderName
else
throwError "unknown free variable {fvarId.name}"

def findParam? (fvarId : FVarId) : CompilerM (Option Param) :=
return (← get).lctx.params[fvarId]?
return (← get).lctx.params.find? fvarId

def findLetDecl? (fvarId : FVarId) : CompilerM (Option LetDecl) :=
return (← get).lctx.letDecls[fvarId]?
return (← get).lctx.letDecls.find? fvarId

def findFunDecl? (fvarId : FVarId) : CompilerM (Option FunDecl) :=
return (← get).lctx.funDecls[fvarId]?
return (← get).lctx.funDecls.find? fvarId

def findLetValue? (fvarId : FVarId) : CompilerM (Option LetValue) := do
let some { value, .. } ← findLetDecl? fvarId | return none
@@ -166,7 +166,7 @@ it is a free variable, a type (or type former), or `lcErased`.

`Check.lean` contains a substitution validator.
-/
abbrev FVarSubst := Std.HashMap FVarId Expr
abbrev FVarSubst := HashMap FVarId Expr

/--
Replace the free variables in `e` using the given substitution.
@@ -190,7 +190,7 @@ where
go (e : Expr) : Expr :=
if e.hasFVar then
match e with
| .fvar fvarId => match s[fvarId]? with
| .fvar fvarId => match s.find? fvarId with
| some e => if translator then e else go e
| none => e
| .lit .. | .const .. | .sort .. | .mvar .. | .bvar .. => e
@@ -224,7 +224,7 @@ That is, it is not a type (or type former), nor `lcErased`. Recall that a valid
expressions that are free variables, `lcErased`, or type formers.
-/
private partial def normFVarImp (s : FVarSubst) (fvarId : FVarId) (translator : Bool) : NormFVarResult :=
match s[fvarId]? with
match s.find? fvarId with
| some (.fvar fvarId') =>
if translator then
.fvar fvarId'
@@ -246,7 +246,7 @@ private partial def normArgImp (s : FVarSubst) (arg : Arg) (translator : Bool) :
match arg with
| .erased => arg
| .fvar fvarId =>
match s[fvarId]? with
match s.find? fvarId with
| some (.fvar fvarId') =>
let arg' := .fvar fvarId'
if translator then arg' else normArgImp s arg' translator
@@ -268,7 +268,7 @@ def getFunctionSummary? (env : Environment) (fid : Name) : Option Value :=
A map from variable identifiers to the `Value` produced by the abstract
interpreter for them.
-/
abbrev Assignment := Std.HashMap FVarId Value
abbrev Assignment := HashMap FVarId Value

/--
The context of `InterpM`.
@@ -332,7 +332,7 @@ If none is available return `Value.bot`.
-/
def findVarValue (var : FVarId) : InterpM Value := do
let assignment ← getAssignment
return assignment.getD var .bot
return assignment.findD var .bot

/--
Find the value of `arg` using the logic of `findVarValue`.
@@ -547,13 +547,13 @@ where
| .jp decl k | .fun decl k =>
return code.updateFun! (← decl.updateValue (← go decl.value)) (← go k)
| .cases cs =>
let discrVal := assignment.getD cs.discr .bot
let discrVal := assignment.findD cs.discr .bot
let processAlt typ alt := do
match alt with
| .alt ctor args body =>
if discrVal.containsCtor ctor then
let filter param := do
if let some val := assignment[param.fvarId]? then
if let some val := assignment.find? param.fvarId then
if let some literal ← val.getLiteral then
return some (param, literal)
return none
@@ -62,7 +62,7 @@ structure State where
Whenever there is function application `f a₁ ... aₙ`, where `f` is in `decls`, `f` is not `main`, and
we visit with the abstract values assigned to `aᵢ`, but first we record the visit here.
-/
visited : Std.HashSet (Name × Array AbsValue) := {}
visited : HashSet (Name × Array AbsValue) := {}
/--
Bitmask containing the result, i.e., which parameters of `main` are fixed.
We initialize it with `true` everywhere.
@@ -59,7 +59,7 @@ structure FloatState where
/--
A map from identifiers of declarations to their current decision.
-/
decision : Std.HashMap FVarId Decision
decision : HashMap FVarId Decision
/--
A map from decisions (excluding `unknown`) to the declarations with
these decisions (in correct order). Basically:
@@ -67,7 +67,7 @@ structure FloatState where
- Which declarations do we move into a certain arm
- Which declarations do we move into the default arm
-/
newArms : Std.HashMap Decision (List CodeDecl)
newArms : HashMap Decision (List CodeDecl)

/--
Use to collect relevant declarations for the floating mechanism.
@@ -116,8 +116,8 @@ up to this point, with respect to `cs`. The initial decisions are:
- `arm` or `default` if we see the declaration only being used in exactly one cases arm
- `unknown` otherwise
-/
def initialDecisions (cs : Cases) : BaseFloatM (Std.HashMap FVarId Decision) := do
let mut map := Std.HashMap.empty (← read).decls.length
def initialDecisions (cs : Cases) : BaseFloatM (HashMap FVarId Decision) := do
let mut map := mkHashMap (← read).decls.length
let folder val acc := do
if let .let decl := val then
if (← ignore? decl) then
@@ -130,25 +130,25 @@ def initialDecisions (cs : Cases) : BaseFloatM (Std.HashMap FVarId Decision) :=
(_, map) ← goCases cs |>.run map
return map
where
goFVar (plannedDecision : Decision) (var : FVarId) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit := do
if let some decision := (← get)[var]? then
goFVar (plannedDecision : Decision) (var : FVarId) : StateRefT (HashMap FVarId Decision) BaseFloatM Unit := do
if let some decision := (← get).find? var then
if decision == .unknown then
modify fun s => s.insert var plannedDecision
else if decision != plannedDecision then
modify fun s => s.insert var .dont
-- otherwise we already have the proper decision

goAlt (alt : Alt) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit :=
goAlt (alt : Alt) : StateRefT (HashMap FVarId Decision) BaseFloatM Unit :=
forFVarM (goFVar (.ofAlt alt)) alt
goCases (cs : Cases) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit :=
goCases (cs : Cases) : StateRefT (HashMap FVarId Decision) BaseFloatM Unit :=
cs.alts.forM goAlt

/--
Compute the initial new arms. This will just set up a map from all arms of
`cs` to empty `Array`s, plus one additional entry for `dont`.
-/
def initialNewArms (cs : Cases) : Std.HashMap Decision (List CodeDecl) := Id.run do
let mut map := Std.HashMap.empty (cs.alts.size + 1)
def initialNewArms (cs : Cases) : HashMap Decision (List CodeDecl) := Id.run do
let mut map := mkHashMap (cs.alts.size + 1)
map := map.insert .dont []
cs.alts.foldr (init := map) fun val acc => acc.insert (.ofAlt val) []
@@ -170,7 +170,7 @@ respectively but since `z` can't be moved we don't want that to move `x` and `y`
-/
def dontFloat (decl : CodeDecl) : FloatM Unit := do
forFVarM goFVar decl
modify fun s => { s with newArms := s.newArms.insert .dont (decl :: s.newArms[Decision.dont]!) }
modify fun s => { s with newArms := s.newArms.insert .dont (decl :: s.newArms.find! .dont) }
where
goFVar (fvar : FVarId) : FloatM Unit := do
if (← get).decision.contains fvar then
@@ -223,12 +223,12 @@ Will:
If we are at `y` `x` is still marked to be moved but we don't want that.
-/
def float (decl : CodeDecl) : FloatM Unit := do
let arm := (← get).decision[decl.fvarId]!
let arm := (← get).decision.find! decl.fvarId
forFVarM (goFVar · arm) decl
modify fun s => { s with newArms := s.newArms.insert arm (decl :: s.newArms[arm]!) }
modify fun s => { s with newArms := s.newArms.insert arm (decl :: s.newArms.find! arm) }
where
goFVar (fvar : FVarId) (arm : Decision) : FloatM Unit := do
let some decision := (← get).decision[fvar]? | return ()
let some decision := (← get).decision.find? fvar | return ()
if decision != arm then
modify fun s => { s with decision := s.decision.insert fvar .dont }
else if decision == .unknown then
@@ -249,7 +249,7 @@ where
-/
goCases : FloatM Unit := do
for decl in (← read).decls do
let currentDecision := (← get).decision[decl.fvarId]!
let currentDecision := (← get).decision.find! decl.fvarId
if currentDecision == .unknown then
/-
If the decision is still unknown by now this means `decl` is
@@ -284,10 +284,10 @@ where
newArms := initialNewArms cs
}
let (_, res) ← goCases |>.run base
let remainders := res.newArms[Decision.dont]!
let remainders := res.newArms.find! .dont
let altMapper alt := do
let decision := Decision.ofAlt alt
let newCode := res.newArms[decision]!
let decision := .ofAlt alt
let newCode := res.newArms.find! decision
trace[Compiler.floatLetIn] "Size of code that was pushed into arm: {repr decision} {newCode.length}"
let fused ← withNewScope do
go (attachCodeDecls newCode.toArray alt.getCode)
@@ -29,7 +29,7 @@ structure CandidateInfo where
The set of candidates that rely on this candidate to be a join point.
For a more detailed explanation see the documentation of `find`
-/
associated : Std.HashSet FVarId
associated : HashSet FVarId
deriving Inhabited

/--
@@ -39,14 +39,14 @@ structure FindState where
/--
All current join point candidates accessible by their `FVarId`.
-/
candidates : Std.HashMap FVarId CandidateInfo := .empty
candidates : HashMap FVarId CandidateInfo := .empty
/--
The `FVarId`s of all `fun` declarations that were declared within the
current `fun`.
-/
scope : Std.HashSet FVarId := .empty
scope : HashSet FVarId := .empty

abbrev ReplaceCtx := Std.HashMap FVarId Name
abbrev ReplaceCtx := HashMap FVarId Name

abbrev FindM := ReaderT (Option FVarId) StateRefT FindState ScopeM
abbrev ReplaceM := ReaderT ReplaceCtx CompilerM
@@ -55,7 +55,7 @@ abbrev ReplaceM := ReaderT ReplaceCtx CompilerM
|
||||
Attempt to find a join point candidate by its `FVarId`.
|
||||
-/
|
||||
private def findCandidate? (fvarId : FVarId) : FindM (Option CandidateInfo) := do
|
||||
return (← get).candidates[fvarId]?
|
||||
return (← get).candidates.find? fvarId
|
||||
|
||||
/--
|
||||
Erase a join point candidate as well as all the ones that depend on it
|
||||
@@ -69,7 +69,7 @@ private partial def eraseCandidate (fvarId : FVarId) : FindM Unit := do
|
||||
/--
|
||||
Combinator for modifying the candidates in `FindM`.
|
||||
-/
|
||||
private def modifyCandidates (f : Std.HashMap FVarId CandidateInfo → Std.HashMap FVarId CandidateInfo) : FindM Unit :=
|
||||
private def modifyCandidates (f : HashMap FVarId CandidateInfo → HashMap FVarId CandidateInfo) : FindM Unit :=
|
||||
modify (fun state => {state with candidates := f state.candidates })
|
||||
|
||||
/--
|
||||
@@ -196,7 +196,7 @@ where
|
||||
return code
|
||||
| _, _ => return Code.updateLet! code decl (← go k)
|
||||
| .fun decl k =>
|
||||
if let some replacement := (← read)[decl.fvarId]? then
|
||||
if let some replacement := (← read).find? decl.fvarId then
|
||||
let newDecl := { decl with
|
||||
binderName := replacement,
|
||||
value := (← go decl.value)
|
||||
@@ -244,7 +244,7 @@ structure ExtendState where
|
||||
to `Param`s. The free variables in this map are the once that the context
|
||||
of said join point will be extended by by passing in the respective parameter.
|
||||
-/
|
||||
fvarMap : Std.HashMap FVarId (Std.HashMap FVarId Param) := {}
|
||||
fvarMap : HashMap FVarId (HashMap FVarId Param) := {}
|
||||
|
||||
/--
|
||||
The monad for the `extendJoinPointContext` pass.
|
||||
@@ -262,7 +262,7 @@ otherwise just return `fvar`.
|
||||
def replaceFVar (fvar : FVarId) : ExtendM FVarId := do
|
||||
if (← read).candidates.contains fvar then
|
||||
if let some currentJp := (← read).currentJp? then
|
||||
if let some replacement := (← get).fvarMap[currentJp]![fvar]? then
|
||||
if let some replacement := (← get).fvarMap.find! currentJp |>.find? fvar then
|
||||
return replacement.fvarId
|
||||
return fvar
|
||||
|
||||
@@ -313,7 +313,7 @@ This is necessary if:
|
||||
-/
|
||||
def extendByIfNecessary (fvar : FVarId) : ExtendM Unit := do
|
||||
if let some currentJp := (← read).currentJp? then
|
||||
let mut translator := (← get).fvarMap[currentJp]!
|
||||
let mut translator := (← get).fvarMap.find! currentJp
|
||||
let candidates := (← read).candidates
|
||||
if !(← isInScope fvar) && !translator.contains fvar && candidates.contains fvar then
|
||||
let typ ← getType fvar
|
||||
@@ -337,7 +337,7 @@ of `j.2` in `j.1`.
|
||||
-/
|
||||
def mergeJpContextIfNecessary (jp : FVarId) : ExtendM Unit := do
|
||||
if (← read).currentJp?.isSome then
|
||||
let additionalArgs := (← get).fvarMap[jp]!.toArray
|
||||
let additionalArgs := (← get).fvarMap.find! jp |>.toArray
|
||||
for (fvar, _) in additionalArgs do
|
||||
extendByIfNecessary fvar
|
||||
|
||||
@@ -405,7 +405,7 @@ where
|
||||
| .jp decl k =>
|
||||
let decl ← withNewJpScope decl do
|
||||
let value ← go decl.value
|
||||
let additionalParams := (← get).fvarMap[decl.fvarId]!.toArray |>.map Prod.snd
|
||||
let additionalParams := (← get).fvarMap.find! decl.fvarId |>.toArray |>.map Prod.snd
|
||||
let newType := additionalParams.foldr (init := decl.type) (fun val acc => .forallE val.binderName val.type acc .default)
|
||||
decl.update newType (additionalParams ++ decl.params) value
|
||||
mergeJpContextIfNecessary decl.fvarId
|
||||
@@ -426,7 +426,7 @@ where
|
||||
return Code.updateCases! code cs.resultType discr alts
|
||||
| .jmp fn args =>
|
||||
let mut newArgs ← args.mapM (mapFVarM goFVar)
|
||||
let additionalArgs := (← get).fvarMap[fn]!.toArray |>.map Prod.fst
|
||||
let additionalArgs := (← get).fvarMap.find! fn |>.toArray |>.map Prod.fst
|
||||
if let some _currentJp := (← read).currentJp? then
|
||||
let f := fun arg => do
|
||||
return .fvar (← goFVar arg)
|
||||
@@ -545,7 +545,7 @@ where
|
||||
if let some knownArgs := (← get).jpJmpArgs.find? fn then
|
||||
let mut newArgs := knownArgs
|
||||
for (param, arg) in decl.params.zip args do
|
||||
if let some knownVal := newArgs[param.fvarId]? then
|
||||
if let some knownVal := newArgs.find? param.fvarId then
|
||||
if arg.toExpr != knownVal then
|
||||
newArgs := newArgs.erase param.fvarId
|
||||
modify fun s => { s with jpJmpArgs := s.jpJmpArgs.insert fn newArgs }
|
||||
|
||||
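The hunks above mechanically swap `Std.HashMap`/`Std.HashSet` with getElem-style access (`m[k]?`, `m[k]!`) for the older `Lean`-style `find?`/`find!` API. A minimal sketch of how the two spellings correspond, assuming a toolchain that ships `Std.HashMap` (illustrative only, not part of this compare):

```lean
import Std.Data.HashMap

open Std

def lookupDemo (m : HashMap String Nat) : Option Nat :=
  -- getElem-style access, used on one side of these hunks:
  m["answer"]?
  -- the other side spells the same lookup `m.find? "answer"`,
  -- and `m[k]!` corresponds to `m.find! k` (panics on a missing key)
```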
@@ -13,9 +13,9 @@ namespace Lean.Compiler.LCNF
LCNF local context.
-/
structure LCtx where
params : Std.HashMap FVarId Param := {}
letDecls : Std.HashMap FVarId LetDecl := {}
funDecls : Std.HashMap FVarId FunDecl := {}
params : HashMap FVarId Param := {}
letDecls : HashMap FVarId LetDecl := {}
funDecls : HashMap FVarId FunDecl := {}
deriving Inhabited

def LCtx.addParam (lctx : LCtx) (param : Param) : LCtx :=

@@ -30,7 +30,7 @@ structure State where
/-- Counter for generating new (normalized) universe parameter names. -/
nextIdx : Nat := 1
/-- Mapping from existing universe parameter names to the new ones. -/
map : Std.HashMap Name Level := {}
map : HashMap Name Level := {}
/-- Parameters that have been normalized. -/
paramNames : Array Name := #[]

@@ -49,7 +49,7 @@ partial def normLevel (u : Level) : M Level := do
| .max v w => return u.updateMax! (← normLevel v) (← normLevel w)
| .imax v w => return u.updateIMax! (← normLevel v) (← normLevel w)
| .mvar _ => unreachable!
| .param n => match (← get).map[n]? with
| .param n => match (← get).map.find? n with
| some u => return u
| none =>
let u := Level.param <| (`u).appendIndexAfter (← get).nextIdx

@@ -31,9 +31,9 @@ def sortedBySize : Probe Decl (Nat × Decl) := fun decls =>
if sz₁ == sz₂ then Name.lt decl₁.name decl₂.name else sz₁ < sz₂

def countUnique [ToString α] [BEq α] [Hashable α] : Probe α (α × Nat) := fun data => do
let mut map := Std.HashMap.empty
let mut map := HashMap.empty
for d in data do
if let some count := map[d]? then
if let some count := map.find? d then
map := map.insert d (count + 1)
else
map := map.insert d 1

@@ -40,7 +40,7 @@ structure FunDeclInfoMap where
/--
Mapping from local function name to inlining information.
-/
map : Std.HashMap FVarId FunDeclInfo := {}
map : HashMap FVarId FunDeclInfo := {}
deriving Inhabited

def FunDeclInfoMap.format (s : FunDeclInfoMap) : CompilerM Format := do
@@ -56,7 +56,7 @@ Add new occurrence for the local function with binder name `key`.
def FunDeclInfoMap.add (s : FunDeclInfoMap) (fvarId : FVarId) : FunDeclInfoMap :=
match s with
| { map } =>
match map[fvarId]? with
match map.find? fvarId with
| some .once => { map := map.insert fvarId .many }
| none => { map := map.insert fvarId .once }
| _ => { map }
@@ -67,7 +67,7 @@ Add new occurrence for the local function occurring as an argument for another f
def FunDeclInfoMap.addHo (s : FunDeclInfoMap) (fvarId : FVarId) : FunDeclInfoMap :=
match s with
| { map } =>
match map[fvarId]? with
match map.find? fvarId with
| some .once | none => { map := map.insert fvarId .many }
| _ => { map }


@@ -173,7 +173,7 @@ Execute `x` with `fvarId` set as `mustInline`.
After execution the original setting is restored.
-/
def withAddMustInline (fvarId : FVarId) (x : SimpM α) : SimpM α := do
let saved? := (← get).funDeclInfoMap.map[fvarId]?
let saved? := (← get).funDeclInfoMap.map.find? fvarId
try
addMustInline fvarId
x
@@ -185,7 +185,7 @@ Return true if the given local function declaration or join point id is marked a
`once` or `mustInline`. We use this information to decide whether to inline them.
-/
def isOnceOrMustInline (fvarId : FVarId) : SimpM Bool := do
match (← get).funDeclInfoMap.map[fvarId]? with
match (← get).funDeclInfoMap.map.find? fvarId with
| some .once | some .mustInline => return true
| _ => return false


@@ -199,9 +199,9 @@ structure State where
/-- Cache from Lean regular expression to LCNF argument. -/
cache : PHashMap Expr Arg := {}
/-- `toLCNFType` cache -/
typeCache : Std.HashMap Expr Expr := {}
typeCache : HashMap Expr Expr := {}
/-- isTypeFormerType cache -/
isTypeFormerTypeCache : Std.HashMap Expr Bool := {}
isTypeFormerTypeCache : HashMap Expr Bool := {}
/-- LCNF sequence, we chain it to create a LCNF `Code` object. -/
seq : Array Element := #[]
/--
@@ -257,7 +257,7 @@ private partial def isTypeFormerType (type : Expr) : M Bool := do
| .true => return true
| .false => return false
| .undef =>
if let some result := (← get).isTypeFormerTypeCache[type]? then
if let some result := (← get).isTypeFormerTypeCache.find? type then
return result
let result ← liftMetaM <| Meta.isTypeFormerType type
modify fun s => { s with isTypeFormerTypeCache := s.isTypeFormerTypeCache.insert type result }
@@ -305,7 +305,7 @@ def applyToAny (type : Expr) : M Expr := do
| _ => none

def toLCNFType (type : Expr) : M Expr := do
match (← get).typeCache[type]? with
match (← get).typeCache.find? type with
| some type' => return type'
| none =>
let type' ← liftMetaM <| LCNF.toLCNFType type

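`toLCNFType` and `isTypeFormerType` above share one memoization pattern: probe a cache stored in the state, compute on a miss, then record the result. The pattern in isolation (a minimal sketch in plain `StateM`, not the compiler's actual monad):

```lean
import Std.Data.HashMap

structure CacheState where
  cache : Std.HashMap Nat Nat := {}

abbrev CacheM := StateM CacheState

/-- Compute `f n` at most once per `n`, memoizing the result in the state. -/
def memoized (f : Nat → Nat) (n : Nat) : CacheM Nat := do
  match (← get).cache[n]? with
  | some v => return v
  | none =>
    let v := f n
    modify fun s => { s with cache := s.cache.insert n v }
    return v
```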
@@ -6,8 +6,6 @@ Author: Leonardo de Moura
prelude
import Init.Data.Nat.Power2
import Lean.Data.AssocList
import Std.Data.HashMap.Basic
import Std.Data.HashMap.Raw
namespace Lean

def HashMapBucket (α : Type u) (β : Type v) :=
@@ -271,11 +269,17 @@ def ofListWith (l : List (α × β)) (f : β → β → β) : HashMap α β :=
| none => m.insert p.fst p.snd
| some v => m.insert p.fst $ f v p.snd)

attribute [deprecated Std.HashMap] HashMap
attribute [deprecated Std.HashMap.Raw] HashMapImp
attribute [deprecated Std.HashMap.Raw.empty] mkHashMapImp
attribute [deprecated Std.HashMap.empty] mkHashMap
attribute [deprecated Std.HashMap.empty] HashMap.empty
attribute [deprecated Std.HashMap.ofList] HashMap.ofList

end Lean.HashMap

/--
Groups all elements `x`, `y` in `xs` with `key x == key y` into the same array
`(xs.groupByKey key).find! (key x)`. Groups preserve the relative order of elements in `xs`.
-/
def Array.groupByKey [BEq α] [Hashable α] (key : β → α) (xs : Array β)
: Lean.HashMap α (Array β) := Id.run do
let mut groups := ∅
for x in xs do
let group := groups.findD (key x) #[]
groups := groups.erase (key x) -- make `group` referentially unique
groups := groups.insert (key x) (group.push x)
return groups

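The `attribute [deprecated …]` lines above point each legacy `Lean.HashMap` name at its `Std.HashMap` replacement, so uses of the old name produce a warning naming the new one. A rough sketch of the mechanism with hypothetical names (recent toolchains may additionally require a `(since := …)` argument):

```lean
def newApi (x : Nat) : Nat := x + 1

/-- Legacy spelling kept for compatibility. -/
def oldApi (x : Nat) : Nat := newApi x

-- From now on, using `oldApi` emits a deprecation warning
-- suggesting `newApi` instead.
attribute [deprecated newApi] oldApi
```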
@@ -6,8 +6,6 @@ Author: Leonardo de Moura
prelude
import Init.Data.Nat.Power2
import Init.Data.List.Control
import Std.Data.HashSet.Basic
import Std.Data.HashSet.Raw
namespace Lean
universe u v w

@@ -219,9 +217,3 @@ def insertMany [ForIn Id ρ α] (s : HashSet α) (as : ρ) : HashSet α := Id.ru
def merge {α : Type u} [BEq α] [Hashable α] (s t : HashSet α) : HashSet α :=
t.fold (init := s) fun s a => s.insert a
-- We don't use `insertMany` here because it gives weird universes.

attribute [deprecated Std.HashSet] HashSet
attribute [deprecated Std.HashSet.Raw] HashSetImp
attribute [deprecated Std.HashSet.Raw.empty] mkHashSetImp
attribute [deprecated Std.HashSet.empty] mkHashSet
attribute [deprecated Std.HashSet.empty] HashSet.empty

@@ -150,7 +150,7 @@ instance : FromJson RefInfo where
pure { definition?, usages }

/-- References from a single module/file -/
def ModuleRefs := Std.HashMap RefIdent RefInfo
def ModuleRefs := HashMap RefIdent RefInfo

instance : ToJson ModuleRefs where
toJson m := Json.mkObj <| m.toList.map fun (ident, info) => (ident.toJson.compress, toJson info)
@@ -158,7 +158,7 @@ instance : ToJson ModuleRefs where
instance : FromJson ModuleRefs where
fromJson? j := do
let node ← j.getObj?
node.foldM (init := Std.HashMap.empty) fun m k v =>
node.foldM (init := HashMap.empty) fun m k v =>
return m.insert (← RefIdent.fromJson? (← Json.parse k)) (← fromJson? v)

/--

@@ -4,7 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
-/
prelude
import Std.Data.HashSet.Basic
import Lean.Data.HashSet
import Lean.Data.RBMap
import Lean.Data.RBTree
@@ -65,14 +64,14 @@ abbrev insert (s : NameSSet) (n : Name) : NameSSet := SSet.insert s n
abbrev contains (s : NameSSet) (n : Name) : Bool := SSet.contains s n
end NameSSet

def NameHashSet := Std.HashSet Name
def NameHashSet := HashSet Name

namespace NameHashSet
@[inline] def empty : NameHashSet := Std.HashSet.empty
@[inline] def empty : NameHashSet := HashSet.empty
instance : EmptyCollection NameHashSet := ⟨empty⟩
instance : Inhabited NameHashSet := ⟨{}⟩
def insert (s : NameHashSet) (n : Name) := Std.HashSet.insert s n
def contains (s : NameHashSet) (n : Name) : Bool := Std.HashSet.contains s n
def insert (s : NameHashSet) (n : Name) := HashSet.insert s n
def contains (s : NameHashSet) (n : Name) : Bool := HashSet.contains s n
end NameHashSet

def MacroScopesView.isPrefixOf (v₁ v₂ : MacroScopesView) : Bool :=

@@ -4,7 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Std.Data.HashMap.Basic
import Lean.Data.HashMap
import Lean.Data.PersistentHashMap
universe u v w w'
@@ -29,7 +28,7 @@ namespace Lean
-/
structure SMap (α : Type u) (β : Type v) [BEq α] [Hashable α] where
stage₁ : Bool := true
map₁ : Std.HashMap α β := {}
map₁ : HashMap α β := {}
map₂ : PHashMap α β := {}

namespace SMap
@@ -38,7 +37,7 @@ variable {α : Type u} {β : Type v} [BEq α] [Hashable α]
instance : Inhabited (SMap α β) := ⟨{}⟩
def empty : SMap α β := {}

@[inline] def fromHashMap (m : Std.HashMap α β) (stage₁ := true) : SMap α β :=
@[inline] def fromHashMap (m : HashMap α β) (stage₁ := true) : SMap α β :=
{ map₁ := m, stage₁ := stage₁ }

@[specialize] def insert : SMap α β → α → β → SMap α β
@@ -50,8 +49,8 @@ def empty : SMap α β := {}
| ⟨false, m₁, m₂⟩, k, v => ⟨false, m₁, m₂.insert k v⟩

@[specialize] def find? : SMap α β → α → Option β
| ⟨true, m₁, _⟩, k => m₁[k]?
| ⟨false, m₁, m₂⟩, k => (m₂.find? k).orElse fun _ => m₁[k]?
| ⟨true, m₁, _⟩, k => m₁.find? k
| ⟨false, m₁, m₂⟩, k => (m₂.find? k).orElse fun _ => m₁.find? k

@[inline] def findD (m : SMap α β) (a : α) (b₀ : β) : β :=
(m.find? a).getD b₀
@@ -68,8 +67,8 @@ def empty : SMap α β := {}
/-- Similar to `find?`, but searches for result in the hashmap first.
So, the result is correct only if we never "overwrite" `map₁` entries using `map₂`. -/
@[specialize] def find?' : SMap α β → α → Option β
| ⟨true, m₁, _⟩, k => m₁[k]?
| ⟨false, m₁, m₂⟩, k => m₁[k]?.orElse fun _ => m₂.find? k
| ⟨true, m₁, _⟩, k => m₁.find? k
| ⟨false, m₁, m₂⟩, k => (m₁.find? k).orElse fun _ => m₂.find? k

def forM [Monad m] (s : SMap α β) (f : α → β → m PUnit) : m PUnit := do
s.map₁.forM f
@@ -97,7 +96,7 @@ def fold {σ : Type w} (f : σ → α → β → σ) (init : σ) (m : SMap α β
m.map₂.foldl f $ m.map₁.fold f init

def numBuckets (m : SMap α β) : Nat :=
Std.HashMap.Internal.numBuckets m.map₁
m.map₁.numBuckets

def toList (m : SMap α β) : List (α × β) :=
m.fold (init := []) fun es a b => (a, b)::es

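`SMap` above is a two-stage map: while `stage₁` is true, insertions go into the hash map `map₁`; after the map is switched to stage 2, insertions go into the persistent `map₂`, and `find?` consults `map₂` first. A small sketch of that staging discipline (assuming the `Lean.SMap` API shown here, including its `switch` operation):

```lean
import Lean.Data.SMap

open Lean

def staged : Option Nat × Option Nat :=
  let m : SMap String Nat := {}
  let m := m.insert "a" 1  -- stage 1: stored in `map₁`
  let m := m.switch        -- freeze stage 1
  let m := m.insert "b" 2  -- stage 2: stored in `map₂`
  (m.find? "a", m.find? "b")
```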
@@ -541,7 +541,7 @@ mutual
/--
Process a `fType` of the form `(x : A) → B x`.
This method assume `fType` is a function type -/
private partial def processExplicitArg (argName : Name) : M Expr := do
private partial def processExplictArg (argName : Name) : M Expr := do
match (← get).args with
| arg::args =>
if (← anyNamedArgDependsOnCurrent) then
@@ -586,16 +586,6 @@ mutual
| Except.ok tacticSyntax =>
-- TODO(Leo): does this work correctly for tactic sequences?
let tacticBlock ← `(by $(⟨tacticSyntax⟩))
/-
We insert position information from the current ref into `stx` everywhere, simulating this being
a tactic script inserted by the user, which ensures error messages and logging will always be attributed
to this application rather than sometimes being placed at position (1,0) in the file.
Placing position information on `by` syntax alone is not sufficient since incrementality
(in particular, `Lean.Elab.Term.withReuseContext`) controls the ref to avoid leakage of outside data.
Note that `tacticSyntax` contains no position information itself, since it is erased by `Lean.Elab.Term.quoteAutoTactic`.
-/
let info := (← getRef).getHeadInfo
let tacticBlock := tacticBlock.raw.rewriteBottomUp (·.setInfo info)
let argNew := Arg.stx tacticBlock
propagateExpectedType argNew
elabAndAddNewArg argName argNew
@@ -625,7 +615,7 @@ mutual
This method assume `fType` is a function type -/
private partial def processImplicitArg (argName : Name) : M Expr := do
if (← read).explicit then
processExplicitArg argName
processExplictArg argName
else
addImplicitArg argName

@@ -634,7 +624,7 @@ mutual
This method assume `fType` is a function type -/
private partial def processStrictImplicitArg (argName : Name) : M Expr := do
if (← read).explicit then
processExplicitArg argName
processExplictArg argName
else if (← hasArgsToProcess) then
addImplicitArg argName
else
@@ -653,7 +643,7 @@ mutual
addNewArg argName arg
main
else
processExplicitArg argName
processExplictArg argName
else
let arg ← mkFreshExprMVar (← getArgExpectedType) MetavarKind.synthetic
addInstMVar arg.mvarId!
@@ -678,7 +668,7 @@ mutual
| .implicit => processImplicitArg binderName
| .instImplicit => processInstImplicitArg binderName
| .strictImplicit => processStrictImplicitArg binderName
| _ => processExplicitArg binderName
| _ => processExplictArg binderName
else if (← hasArgsToProcess) then
synthesizePendingAndNormalizeFunType
main

@@ -502,16 +502,6 @@ def elabRunMeta : CommandElab := fun stx =>
addDocString declName (← getDocStringText doc)
| _ => throwUnsupportedSyntax

@[builtin_command_elab Lean.Parser.Command.include] def elabInclude : CommandElab
| `(Lean.Parser.Command.include| include $ids*) => do
let vars := (← getScope).varDecls.concatMap getBracketedBinderIds
for id in ids do
unless vars.contains id.getId do
throwError "invalid 'include', variable '{id}' has not been declared in the current scope"
modifyScope fun sc =>
{ sc with includedVars := sc.includedVars ++ ids.toList.map (·.getId) }
| _ => throwUnsupportedSyntax

@[builtin_command_elab Parser.Command.exit] def elabExit : CommandElab := fun _ =>
logWarning "using 'exit' to interrupt Lean"


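`elabInclude` above implements the `include` command: it checks that each identifier names a declared section variable and records it in the scope's `includedVars`, which the elaborator later consults when deciding which section variables a theorem receives. Usage might look like (illustrative):

```lean
section
variable (n : Nat) (h : 0 < n)

-- `h` appears only in the proof, not in the statement, so it must
-- be `include`d for the theorem to receive it:
include h

theorem n_ne_zero : n ≠ 0 := by omega
end
```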
@@ -257,52 +257,31 @@ partial def hasCDot : Syntax → Bool
Return `some` if succeeded expanding `·` notation occurring in
the given syntax. Otherwise, return `none`.
Examples:
- `· + 1` => `fun x => x + 1`
- `f · · b` => `fun x1 x2 => f x1 x2 b` -/
- `· + 1` => `fun _a_1 => _a_1 + 1`
- `f · · b` => `fun _a_1 _a_2 => f _a_1 _a_2 b` -/
partial def expandCDot? (stx : Term) : MacroM (Option Term) := do
if hasCDot stx then
withFreshMacroScope do
let mut (newStx, binders) ← (go stx).run #[]
if binders.size == 1 then
-- It is nicer using `x` over `x1` if there's only a single binder.
let x1 := binders[0]!
let x := mkIdentFrom x1 (← MonadQuotation.addMacroScope `x) (canonical := true)
binders := binders.set! 0 x
newStx ← newStx.replaceM fun s => pure (if s == x1 then x else none)
`(fun $binders* => $(⟨newStx⟩))
let (newStx, binders) ← (go stx).run #[]
`(fun $binders* => $(⟨newStx⟩))
else
pure none
where
/--
Auxiliary function for expanding the `·` notation.
The extra state `Array Syntax` contains the new binder names.
If `stx` is a `·`, we create a fresh identifier, store it in the
extra state, and return it. Otherwise, we just return `stx`.
-/
Auxiliary function for expanding the `·` notation.
The extra state `Array Syntax` contains the new binder names.
If `stx` is a `·`, we create a fresh identifier, store in the
extra state, and return it. Otherwise, we just return `stx`. -/
go : Syntax → StateT (Array Ident) MacroM Syntax
| stx@`(($(_))) => pure stx
| stx@`(·) => do
let name ← MonadQuotation.addMacroScope <| Name.mkSimple s!"x{(← get).size + 1}"
let id := mkIdentFrom stx name (canonical := true)
modify (fun s => s.push id)
pure id
| stx => match stx with
| .node _ k args => do
let args ←
if k == choiceKind then
if args.isEmpty then
return stx
let s ← get
let args' ← args.mapM (fun arg => go arg |>.run s)
let s' := args'[0]!.2
unless args'.all (fun (_, s'') => s''.size == s'.size) do
Macro.throwErrorAt stx "Ambiguous notation in cdot function has different numbers of '·' arguments in each alternative."
set s'
pure <| args'.map Prod.fst
else
args.mapM go
return .node (.fromRef stx (canonical := true)) k args
| _ => pure stx
| stx@`(($(_))) => pure stx
| stx@`(·) => withFreshMacroScope do
let id ← mkFreshIdent stx (canonical := true)
modify (·.push id)
pure id
| stx => match stx with
| .node _ k args => do
let args ← args.mapM go
return .node (.fromRef stx (canonical := true)) k args
| _ => pure stx

/--
Helper method for elaborating terms such as `(.+.)` where a constant name is expected.

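The longer `expandCDot?` above names the fresh binders `x1`, `x2`, … (preferring plain `x` when there is a single `·`) and rejects ambiguous `choiceKind` notation whose alternatives contain different numbers of `·`. The documented expansions in action (behavior as stated in the docstring):

```lean
def inc := (· + 1)          -- expands to `fun x => x + 1`
def sub' := (Nat.sub · ·)   -- expands to `fun x1 x2 => Nat.sub x1 x2`

#eval inc 41    -- 42
#eval sub' 5 2  -- 3
```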
@@ -51,8 +51,6 @@ structure Scope where
even if they do not work with binders per se.
-/
varDecls : Array (TSyntax ``Parser.Term.bracketedBinder) := #[]
/-- `include`d section variable names -/
includedVars : List Name := []
/--
Globally unique internal identifiers for the `varDecls`.
There is one identifier per variable introduced by the binders
@@ -203,12 +201,12 @@ def mkMessageAux (ctx : Context) (ref : Syntax) (msgData : MessageData) (severit

private def addTraceAsMessagesCore (ctx : Context) (log : MessageLog) (traceState : TraceState) : MessageLog := Id.run do
if traceState.traces.isEmpty then return log
let mut traces : Std.HashMap (String.Pos × String.Pos) (Array MessageData) := ∅
let mut traces : HashMap (String.Pos × String.Pos) (Array MessageData) := ∅
for traceElem in traceState.traces do
let ref := replaceRef traceElem.ref ctx.ref
let pos := ref.getPos?.getD 0
let endPos := ref.getTailPos?.getD pos
traces := traces.insert (pos, endPos) <| traces.getD (pos, endPos) #[] |>.push traceElem.msg
traces := traces.insert (pos, endPos) <| traces.findD (pos, endPos) #[] |>.push traceElem.msg
let mut log := log
let traces' := traces.toArray.qsort fun ((a, _), _) ((b, _), _) => a < b
for ((pos, endPos), traceMsg) in traces' do

@@ -630,7 +630,7 @@ private def replaceIndFVarsWithConsts (views : Array InductiveView) (indFVars :
let type := type.replace fun e =>
if !e.isFVar then
none
else match indFVar2Const[e]? with
else match indFVar2Const.find? e with
| none => none
| some c => mkAppN c (params.extract 0 numVars)
instantiateMVars (← mkForallFVars params type)

@@ -425,7 +425,7 @@ private def applyRefMap (e : Expr) (map : ExprMap Expr) : Expr :=
e.replace fun e =>
match patternWithRef? e with
| some _ => some e -- stop `e` already has annotation
| none => match map[e]? with
| none => match map.find? e with
| some eWithRef => some eWithRef -- stop `e` found annotation
| none => none -- continue


@@ -327,45 +327,7 @@ def instantiateMVarsProfiling (e : Expr) : MetaM Expr := do
profileitM Exception s!"instantiate metavars" (← getOptions) do
instantiateMVars e

/--
Runs `k` with a restricted local context where only section variables from `vars` are included that
* are directly referenced in any `headers`,
* are included in `includedVars` (via the `include` command),
* are directly referenced in any variable included by these rules, OR
* are instance-implicit variables that only reference section variables included by these rules.
-/
private def withHeaderSecVars {α} (vars : Array Expr) (includedVars : List Name) (headers : Array DefViewElabHeader)
(k : Array Expr → TermElabM α) : TermElabM α := do
let (_, used) ← collectUsed.run {}
let (lctx, localInsts, vars) ← removeUnused vars used
withLCtx lctx localInsts <| k vars
where
collectUsed : StateRefT CollectFVars.State MetaM Unit := do
-- directly referenced in headers
headers.forM (·.type.collectFVars)
-- included by `include`
vars.forM fun var => do
let ldecl ← getFVarLocalDecl var
if includedVars.contains ldecl.userName then
modify (·.add ldecl.fvarId)
-- transitively referenced
get >>= (·.addDependencies) >>= set
-- instances (`addDependencies` unnecessary as by definition they may only reference variables
-- already included)
vars.forM fun var => do
let ldecl ← getFVarLocalDecl var
let st ← get
if ldecl.binderInfo.isInstImplicit && (← getFVars ldecl.type).all st.fvarSet.contains then
modify (·.add ldecl.fvarId)
getFVars (e : Expr) : MetaM (Array FVarId) :=
(·.2.fvarIds) <$> e.collectFVars.run {}

register_builtin_option deprecated.oldSectionVars : Bool := {
defValue := false
descr := "re-enable deprecated behavior of including exactly the section variables used in a declaration"
}

private def elabFunValues (headers : Array DefViewElabHeader) (vars : Array Expr) (includedVars : List Name) : TermElabM (Array Expr) :=
private def elabFunValues (headers : Array DefViewElabHeader) : TermElabM (Array Expr) :=
headers.mapM fun header => do
let mut reusableResult? := none
if let some snap := header.bodySnap? then
@@ -380,7 +342,6 @@ private def elabFunValues (headers : Array DefViewElabHeader) (vars : Array Expr
withReuseContext header.value do
withDeclName header.declName <| withLevelNames header.levelNames do
let valStx ← liftMacroM <| declValToTerm header.value
(if header.kind.isTheorem && !deprecated.oldSectionVars.get (← getOptions) then withHeaderSecVars vars includedVars #[header] else fun x => x #[]) fun vars => do
forallBoundedTelescope header.type header.numParams fun xs type => do
-- Add new info nodes for new fvars. The server will detect all fvars of a binder by the binder's source location.
for i in [0:header.binderIds.size] do
@@ -392,20 +353,7 @@ private def elabFunValues (headers : Array DefViewElabHeader) (vars : Array Expr
-- NOTE: without this `instantiatedMVars`, `mkLambdaFVars` may leave around a redex that
-- leads to more section variables being included than necessary
let val ← instantiateMVarsProfiling val
let val ← mkLambdaFVars xs val
unless header.type.hasSorry || val.hasSorry do
for var in vars do
unless header.type.containsFVar var.fvarId! ||
val.containsFVar var.fvarId! ||
(← vars.anyM (fun v => return (← v.fvarId!.getType).containsFVar var.fvarId!)) do
let varDecl ← var.fvarId!.getDecl
let var := if varDecl.userName.hasMacroScopes && varDecl.binderInfo.isInstImplicit then
m!"[{varDecl.type}]".group
else
var
logWarningAt header.ref m!"included section variable '{var}' is not used in \
'{header.declName}', consider excluding it"
return val
mkLambdaFVars xs val
if let some snap := header.bodySnap? then
snap.new.resolve <| some {
diagnostics :=
@@ -956,7 +904,7 @@ partial def checkForHiddenUnivLevels (allUserLevelNames : List Name) (preDefs :
for preDef in preDefs do
checkPreDef preDef

def elabMutualDef (vars : Array Expr) (includedVars : List Name) (views : Array DefView) : TermElabM Unit :=
def elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit :=
if isExample views then
withoutModifyingEnv do
-- save correct environment in info tree
@@ -977,7 +925,7 @@ where
addLocalVarInfo view.declId funFVar
let values ←
try
let values ← elabFunValues headers vars includedVars
let values ← elabFunValues headers
Term.synthesizeSyntheticMVarsNoPostponing
values.mapM (instantiateMVarsProfiling ·)
catch ex =>
@@ -987,7 +935,7 @@ where
let letRecsToLift ← getLetRecsToLift
let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift
checkLetRecsToLiftTypes funFVars letRecsToLift
(if headers.all (·.kind.isTheorem) && !deprecated.oldSectionVars.get (← getOptions) then withHeaderSecVars vars includedVars headers else withUsed vars headers values letRecsToLift) fun vars => do
withUsed vars headers values letRecsToLift fun vars => do
let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift
for preDef in preDefs do
trace[Elab.definition] "{preDef.declName} : {preDef.type} :=\n{preDef.value}"
@@ -1058,8 +1006,7 @@ def elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do
if let some snap := snap? then
-- no non-fatal diagnostics at this point
snap.new.resolve <| .ofTyped { defs, diagnostics := .empty : DefsParsedSnapshot }
let includedVars := (← getScope).includedVars
runTermElabM fun vars => Term.elabMutualDef vars includedVars views
runTermElabM fun vars => Term.elabMutualDef vars views

builtin_initialize
registerTraceClass `Elab.definition.mkClosure

@@ -164,11 +164,8 @@ def addNonRec (preDef : PreDefinition) (applyAttrAfterCompilation := true) (all
/--
Eliminate recursive application annotations containing syntax. These annotations are used by the well-founded recursion module
to produce better error messages. -/
def eraseRecAppSyntaxExpr (e : Expr) : CoreM Expr := do
if e.find? hasRecAppSyntax |>.isSome then
Core.transform e (post := fun e => pure <| TransformStep.done <| if hasRecAppSyntax e then e.mdataExpr! else e)
else
return e
def eraseRecAppSyntaxExpr (e : Expr) : CoreM Expr :=
Core.transform e (post := fun e => pure <| TransformStep.done <| if (getRecAppSyntax? e).isSome then e.mdataExpr! else e)

def eraseRecAppSyntax (preDef : PreDefinition) : CoreM PreDefinition :=
return { preDef with value := (← eraseRecAppSyntaxExpr preDef.value) }

@@ -69,15 +69,12 @@ private def ensureNoUnassignedMVarsAtPreDef (preDef : PreDefinition) : TermElabM
This method beta-reduces them to make sure they can be eliminated by the well-founded recursion module. -/
private def betaReduceLetRecApps (preDefs : Array PreDefinition) : MetaM (Array PreDefinition) :=
preDefs.mapM fun preDef => do
if preDef.value.find? (fun e => e.isConst && preDefs.any fun preDef => preDef.declName == e.constName!) |>.isSome then
let value ← Core.transform preDef.value fun e => do
if e.isApp && e.getAppFn.isLambda && e.getAppArgs.all fun arg => arg.getAppFn.isConst && preDefs.any fun preDef => preDef.declName == arg.getAppFn.constName! then
return .visit e.headBeta
else
return .continue
return { preDef with value }
else
return preDef
let value ← Core.transform preDef.value fun e => do
if e.isApp && e.getAppFn.isLambda && e.getAppArgs.all fun arg => arg.getAppFn.isConst && preDefs.any fun preDef => preDef.declName == arg.getAppFn.constName! then
return .visit e.headBeta
else
return .continue
return { preDef with value }

private def addAsAxioms (preDefs : Array PreDefinition) : TermElabM Unit := do
for preDef in preDefs do

@@ -146,7 +146,7 @@ See issue #837 for an example where we can show termination using the index of a
we don't get the desired definitional equalities.
-/
def nonIndicesFirst (recArgInfos : Array RecArgInfo) : Array RecArgInfo := Id.run do
let mut indicesPos : Std.HashSet Nat := {}
let mut indicesPos : HashSet Nat := {}
for recArgInfo in recArgInfos do
for pos in recArgInfo.indicesPos do
indicesPos := indicesPos.insert pos

@@ -596,7 +596,7 @@ private partial def compileStxMatch (discrs : List Term) (alts : List Alt) : Ter
`(have __discr := $discr; $stx)
| _, _ => unreachable!

abbrev IdxSet := Std.HashSet Nat
abbrev IdxSet := HashSet Nat

private partial def hasNoErrorIfUnused : Syntax → Bool
| `(no_error_if_unused% $_) => true

@@ -11,35 +11,27 @@ namespace Lean
private def recAppKey := `_recApp

/--
We store the syntax at recursive applications to be able to generate better error messages
when performing well-founded and structural recursion.
We store the syntax at recursive applications to be able to generate better error messages
when performing well-founded and structural recursion.
-/
def mkRecAppWithSyntax (e : Expr) (stx : Syntax) : Expr :=
mkMData (KVMap.empty.insert recAppKey (.ofSyntax stx)) e
mkMData (KVMap.empty.insert recAppKey (DataValue.ofSyntax stx)) e

/--
Retrieve (if available) the syntax object attached to a recursive application.
Retrieve (if available) the syntax object attached to a recursive application.
-/
def getRecAppSyntax? (e : Expr) : Option Syntax :=
match e with
| .mdata d _ =>
| Expr.mdata d _ =>
match d.find recAppKey with
| some (DataValue.ofSyntax stx) => some stx
| _ => none
| _ => none

/--
Checks if the `MData` is for a recursive application.
Checks if the `MData` is for a recursive application.
-/
def MData.isRecApp (d : MData) : Bool :=
d.contains recAppKey

/--
Return `true` if `getRecAppSyntax? e` is a `some`.
-/
def hasRecAppSyntax (e : Expr) : Bool :=
match e with
| .mdata d _ => d.isRecApp
| _ => false

end Lean

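The `RecAppSyntax` API in the hunk above stores syntax in an `Expr.mdata` node and reads it back. A minimal usage sketch (assuming the definitions above are in scope; the helper name here is illustrative, not part of the diff):

```lean
open Lean

-- Attach syntax to a recursive application, then recover it.
-- By the definitions above, `getRecAppSyntax?` undoes `mkRecAppWithSyntax`:
-- the metadata entry stored under `recAppKey` is matched and returned.
def roundTrip (e : Expr) (stx : Syntax) : Option Syntax :=
  getRecAppSyntax? (mkRecAppWithSyntax e stx)

-- For metadata created by `mkRecAppWithSyntax`, `hasRecAppSyntax e`
-- agrees with `(getRecAppSyntax? e).isSome`.
```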
@@ -445,13 +445,13 @@ private def expandParentFields (s : Struct) : TermElabM Struct := do
| _ => throwErrorAt ref "failed to access field '{fieldName}' in parent structure"
| _ => return field

private abbrev FieldMap := Std.HashMap Name Fields
private abbrev FieldMap := HashMap Name Fields

private def mkFieldMap (fields : Fields) : TermElabM FieldMap :=
fields.foldlM (init := {}) fun fieldMap field =>
match field.lhs with
| .fieldName _ fieldName :: _ =>
match fieldMap[fieldName]? with
match fieldMap.find? fieldName with
| some (prevField::restFields) =>
if field.isSimple || prevField.isSimple then
throwErrorAt field.ref "field '{fieldName}' has already been specified"
@@ -677,10 +677,6 @@ private partial def elabStruct (s : Struct) (expectedType? : Option Expr) : Term
| .error err => throwError err
| .ok tacticSyntax =>
let stx ← `(by $tacticSyntax)
-- See comment in `Lean.Elab.Term.ElabAppArgs.processExplicitArg` about `tacticSyntax`.
-- We add info to get reliable positions for messages from evaluating the tactic script.
let info := field.ref.getHeadInfo
let stx := stx.raw.rewriteBottomUp (·.setInfo info)
cont (← elabTermEnsuringType stx (d.getArg! 0).consumeTypeAnnotations) field
| _ =>
if bi == .instImplicit then

@@ -246,7 +246,7 @@ private def getSomeSyntheticMVarsRef : TermElabM Syntax := do
private def throwStuckAtUniverseCnstr : TermElabM Unit := do
-- This code assumes `entries` is not empty. Note that `processPostponed` uses `exceptionOnFailure` to guarantee this property
let entries ← getPostponed
let mut found : Std.HashSet (Level × Level) := {}
let mut found : HashSet (Level × Level) := {}
let mut uniqueEntries := #[]
for entry in entries do
let mut lhs := entry.lhs

|
||||
import Lean.Elab.Tactic.Omega.OmegaM
|
||||
import Lean.Elab.Tactic.Omega.MinNatAbs
|
||||
|
||||
open Lean (HashMap HashSet)
|
||||
|
||||
namespace Lean.Elab.Tactic.Omega
|
||||
|
||||
initialize Lean.registerTraceClass `omega
|
||||
@@ -165,11 +167,11 @@ structure Problem where
|
||||
/-- The number of variables in the problem. -/
|
||||
numVars : Nat := 0
|
||||
/-- The current constraints, indexed by their coefficients. -/
|
||||
constraints : Std.HashMap Coeffs Fact := ∅
|
||||
constraints : HashMap Coeffs Fact := ∅
|
||||
/--
|
||||
The coefficients for which `constraints` contains an exact constraint (i.e. an equality).
|
||||
-/
|
||||
equalities : Std.HashSet Coeffs := ∅
|
||||
equalities : HashSet Coeffs := ∅
|
||||
/--
|
||||
Equations that have already been used to eliminate variables,
|
||||
along with the variable which was removed, and its coefficient (either `1` or `-1`).
|
||||
@@ -249,7 +251,7 @@ combining it with any existing constraints for the same coefficients.
|
||||
def addConstraint (p : Problem) : Fact → Problem
|
||||
| f@⟨x, s, j⟩ =>
|
||||
if p.possible then
|
||||
match p.constraints[x]? with
|
||||
match p.constraints.find? x with
|
||||
| none =>
|
||||
match s with
|
||||
| .trivial => p
|
||||
@@ -311,7 +313,7 @@ After solving, the variable will have been eliminated from all constraints.
|
||||
def solveEasyEquality (p : Problem) (c : Coeffs) : Problem :=
|
||||
let i := c.findIdx? (·.natAbs = 1) |>.getD 0 -- findIdx? is always some
|
||||
let sign := c.get i |> Int.sign
|
||||
match p.constraints[c]? with
|
||||
match p.constraints.find? c with
|
||||
| some f =>
|
||||
let init :=
|
||||
{ assumptions := p.assumptions
|
||||
@@ -333,7 +335,7 @@ After solving the easy equality,
|
||||
the minimum lexicographic value of `(c.minNatAbs, c.maxNatAbs)` will have been reduced.
|
||||
-/
|
||||
def dealWithHardEquality (p : Problem) (c : Coeffs) : OmegaM Problem :=
|
||||
match p.constraints[c]? with
|
||||
match p.constraints.find? c with
|
||||
| some ⟨_, ⟨some r, some r'⟩, j⟩ => do
|
||||
let m := c.minNatAbs + 1
|
||||
-- We have to store the valid value of the newly introduced variable in the atoms.
|
||||
@@ -477,7 +479,7 @@ def fourierMotzkinData (p : Problem) : Array FourierMotzkinData := Id.run do
|
||||
let n := p.numVars
|
||||
let mut data : Array FourierMotzkinData :=
|
||||
(List.range p.numVars).foldl (fun a i => a.push { var := i}) #[]
|
||||
for (_, f@⟨xs, s, _⟩) in p.constraints do
|
||||
for (_, f@⟨xs, s, _⟩) in p.constraints.toList do -- We could make a forIn instance for HashMap
|
||||
for i in [0:n] do
|
||||
let x := Coeffs.get xs i
|
||||
data := data.modify i fun d =>
|
||||
|
||||
@@ -58,7 +58,7 @@ structure MetaProblem where
|
||||
-/
|
||||
disjunctions : List Expr := []
|
||||
/-- Facts which have already been processed; we keep these to avoid duplicates. -/
|
||||
processedFacts : Std.HashSet Expr := ∅
|
||||
processedFacts : HashSet Expr := ∅
|
||||
|
||||
/-- Construct the `rfl` proof that `lc.eval atoms = e`. -/
|
||||
def mkEvalRflProof (e : Expr) (lc : LinearCombo) : OmegaM Expr := do
|
||||
@@ -80,7 +80,7 @@ def mkCoordinateEvalAtomsEq (e : Expr) (n : Nat) : OmegaM Expr := do
|
||||
mkEqTrans eq (← mkEqSymm (mkApp2 (.const ``LinearCombo.coordinate_eval []) n atoms))
|
||||
|
||||
/-- Construct the linear combination (and its associated proof and new facts) for an atom. -/
|
||||
def mkAtomLinearCombo (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
|
||||
def mkAtomLinearCombo (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
|
||||
let (n, facts) ← lookup e
|
||||
return ⟨LinearCombo.coordinate n, mkCoordinateEvalAtomsEq e n, facts.getD ∅⟩
|
||||
|
||||
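Most hunks in this file only swap the hash-map lookup notation between the two APIs the diff moves between. A side-by-side sketch of the two call styles (illustrative element types, not taken from the diff):

```lean
-- Std.HashMap lookup uses `GetElem` notation:
def findStd (m : Std.HashMap Nat String) (k : Nat) : Option String :=
  m[k]?

-- The older Lean.HashMap API spells the same lookup `find?`:
def findOld (m : Lean.HashMap Nat String) (k : Nat) : Option String :=
  m.find? k
```

Both return `none` when the key is absent, so the hunks are behavior-preserving renames of the lookup call.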
@@ -94,9 +94,9 @@ Gives a small (10%) speedup in testing.
I tried using a pointer based cache,
but there was never enough subexpression sharing to make it effective.
-/
partial def asLinearCombo (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
partial def asLinearCombo (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
let cache ← get
match cache.get? e with
match cache.find? e with
| some (lc, prf) =>
trace[omega] "Found in cache: {e}"
return (lc, prf, ∅)
@@ -120,7 +120,7 @@ We also transform the expression as we descend into it:
* pushing coercions: `↑(x + y)`, `↑(x * y)`, `↑(x / k)`, `↑(x % k)`, `↑k`
* unfolding `emod`: `x % k` → `x - x / k`
-/
partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
trace[omega] "processing {e}"
match groundInt? e with
| some i =>
@@ -142,7 +142,7 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
mkEqTrans
(← mkAppM ``Int.add_congr #[← prf₁, ← prf₂])
(← mkEqSymm add_eval)
pure (l₁ + l₂, prf, facts₁.union facts₂)
pure (l₁ + l₂, prf, facts₁.merge facts₂)
| (``HSub.hSub, #[_, _, _, _, e₁, e₂]) => do
let (l₁, prf₁, facts₁) ← asLinearCombo e₁
let (l₂, prf₂, facts₂) ← asLinearCombo e₂
@@ -152,7 +152,7 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
mkEqTrans
(← mkAppM ``Int.sub_congr #[← prf₁, ← prf₂])
(← mkEqSymm sub_eval)
pure (l₁ - l₂, prf, facts₁.union facts₂)
pure (l₁ - l₂, prf, facts₁.merge facts₂)
| (``Neg.neg, #[_, _, e']) => do
let (l, prf, facts) ← asLinearCombo e'
let prf' : OmegaM Expr := do
@@ -178,7 +178,7 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
mkEqTrans
(← mkAppM ``Int.mul_congr #[← xprf, ← yprf])
(← mkEqSymm mul_eval)
pure (some (LinearCombo.mul xl yl, prf, xfacts.union yfacts), true)
pure (some (LinearCombo.mul xl yl, prf, xfacts.merge yfacts), true)
else
pure (none, false)
match r? with
@@ -235,7 +235,7 @@ where
Apply a rewrite rule to an expression, and interpret the result as a `LinearCombo`.
(We're not rewriting any subexpressions here, just the top level, for efficiency.)
-/
rewrite (lhs rw : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
rewrite (lhs rw : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
trace[omega] "rewriting {lhs} via {rw} : {← inferType rw}"
match (← inferType rw).eq? with
| some (_, _lhs', rhs) =>
@@ -243,7 +243,7 @@ where
let prf' : OmegaM Expr := do mkEqTrans rw (← prf)
pure (lc, prf', facts)
| none => panic! "Invalid rewrite rule in 'asLinearCombo'"
handleNatCast (e i n : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
handleNatCast (e i n : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
match n with
| .fvar h =>
if let some v ← h.getValue? then
@@ -296,7 +296,7 @@ where
| (``Fin.val, #[n, x]) =>
handleFinVal e i n x
| _ => mkAtomLinearCombo e
handleFinVal (e i n x : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
handleFinVal (e i n x : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
match x with
| .fvar h =>
if let some v ← h.getValue? then
@@ -342,7 +342,7 @@ We solve equalities as they are discovered, as this often results in an earlier
-/
def addIntEquality (p : MetaProblem) (h x : Expr) : OmegaM MetaProblem := do
let (lc, prf, facts) ← asLinearCombo x
let newFacts : Std.HashSet Expr := facts.fold (init := ∅) fun s e =>
let newFacts : HashSet Expr := facts.fold (init := ∅) fun s e =>
if p.processedFacts.contains e then s else s.insert e
trace[omega] "Adding proof of {lc} = 0"
pure <|
@@ -358,7 +358,7 @@ We solve equalities as they are discovered, as this often results in an earlier
-/
def addIntInequality (p : MetaProblem) (h y : Expr) : OmegaM MetaProblem := do
let (lc, prf, facts) ← asLinearCombo y
let newFacts : Std.HashSet Expr := facts.fold (init := ∅) fun s e =>
let newFacts : HashSet Expr := facts.fold (init := ∅) fun s e =>
if p.processedFacts.contains e then s else s.insert e
trace[omega] "Adding proof of {lc} ≥ 0"
pure <|
@@ -590,7 +590,7 @@ where

-- We sort the constraints; otherwise the order is dependent on details of the hashing
-- and this can cause test suite output churn
prettyConstraints (names : Array String) (constraints : Std.HashMap Coeffs Fact) : String :=
prettyConstraints (names : Array String) (constraints : HashMap Coeffs Fact) : String :=
constraints.toList
|>.toArray
|>.qsort (·.1 < ·.1)
@@ -615,7 +615,7 @@ where
(if Int.natAbs c = 1 then names[i]! else s!"{c.natAbs}*{names[i]!}"))
|> String.join

mentioned (atoms : Array Expr) (constraints : Std.HashMap Coeffs Fact) : MetaM (Array Bool) := do
mentioned (atoms : Array Expr) (constraints : HashMap Coeffs Fact) : MetaM (Array Bool) := do
let initMask := Array.mkArray atoms.size false
return constraints.fold (init := initMask) fun mask coeffs _ =>
coeffs.enum.foldl (init := mask) fun mask (i, c) =>

@@ -10,8 +10,6 @@ import Init.Omega.Logic
import Init.Data.BitVec.Basic
import Lean.Meta.AppBuilder
import Lean.Meta.Canonicalizer
import Std.Data.HashMap.Basic
import Std.Data.HashSet.Basic

/-!
# The `OmegaM` state monad.
@@ -54,7 +52,7 @@ structure Context where
/-- The internal state for the `OmegaM` monad, recording previously encountered atoms. -/
structure State where
/-- The atoms up-to-defeq encountered so far. -/
atoms : Std.HashMap Expr Nat := {}
atoms : HashMap Expr Nat := {}

/-- An intermediate layer in the `OmegaM` monad. -/
abbrev OmegaM' := StateRefT State (ReaderT Context CanonM)
@@ -62,7 +60,7 @@ abbrev OmegaM' := StateRefT State (ReaderT Context CanonM)
/--
Cache of expressions that have been visited, and their reflection as a linear combination.
-/
def Cache : Type := Std.HashMap Expr (LinearCombo × OmegaM' Expr)
def Cache : Type := HashMap Expr (LinearCombo × OmegaM' Expr)

/--
The `OmegaM` monad maintains two pieces of state:
@@ -73,7 +71,7 @@ abbrev OmegaM := StateRefT Cache OmegaM'

/-- Run a computation in the `OmegaM` monad, starting with no recorded atoms. -/
def OmegaM.run (m : OmegaM α) (cfg : OmegaConfig) : MetaM α :=
m.run' Std.HashMap.empty |>.run' {} { cfg } |>.run'
m.run' HashMap.empty |>.run' {} { cfg } |>.run'

/-- Retrieve the user-specified configuration options. -/
def cfg : OmegaM OmegaConfig := do pure (← read).cfg
@@ -164,11 +162,11 @@ def mkEqReflWithExpectedType (a b : Expr) : MetaM Expr := do
Analyzes a newly recorded atom,
returning a collection of interesting facts about it that should be added to the context.
-/
def analyzeAtom (e : Expr) : OmegaM (Std.HashSet Expr) := do
def analyzeAtom (e : Expr) : OmegaM (HashSet Expr) := do
match e.getAppFnArgs with
| (``Nat.cast, #[.const ``Int [], _, e']) =>
-- Casts of natural numbers are non-negative.
let mut r := Std.HashSet.empty.insert (Expr.app (.const ``Int.ofNat_nonneg []) e')
let mut r := HashSet.empty.insert (Expr.app (.const ``Int.ofNat_nonneg []) e')
match (← cfg).splitNatSub, e'.getAppFnArgs with
| true, (``HSub.hSub, #[_, _, _, _, a, b]) =>
-- `((a - b : Nat) : Int)` gives a dichotomy
@@ -190,7 +188,7 @@ def analyzeAtom (e : Expr) : OmegaM (Std.HashSet Expr) := do
let ne_zero := mkApp3 (.const ``Ne [1]) (.const ``Int []) k (toExpr (0 : Int))
let pos := mkApp4 (.const ``LT.lt [0]) (.const ``Int []) (.const ``Int.instLTInt [])
(toExpr (0 : Int)) k
pure <| Std.HashSet.empty.insert
pure <| HashSet.empty.insert
(mkApp3 (.const ``Int.mul_ediv_self_le []) x k (← mkDecideProof ne_zero)) |>.insert
(mkApp3 (.const ``Int.lt_mul_ediv_self_add []) x k (← mkDecideProof pos))
| (``HMod.hMod, #[_, _, _, _, x, k]) =>
@@ -202,7 +200,7 @@ def analyzeAtom (e : Expr) : OmegaM (Std.HashSet Expr) := do
let b_pos := mkApp4 (.const ``LT.lt [0]) (.const ``Int []) (.const ``Int.instLTInt [])
(toExpr (0 : Int)) b
let pow_pos := mkApp3 (.const ``Lean.Omega.Int.pos_pow_of_pos []) b exp (← mkDecideProof b_pos)
pure <| Std.HashSet.empty.insert
pure <| HashSet.empty.insert
(mkApp3 (.const ``Int.emod_nonneg []) x k
(mkApp3 (.const ``Int.ne_of_gt []) k (toExpr (0 : Int)) pow_pos)) |>.insert
(mkApp3 (.const ``Int.emod_lt_of_pos []) x k pow_pos)
@@ -216,7 +214,7 @@ def analyzeAtom (e : Expr) : OmegaM (Std.HashSet Expr) := do
(toExpr (0 : Nat)) b
let pow_pos := mkApp3 (.const ``Nat.pos_pow_of_pos []) b exp (← mkDecideProof b_pos)
let cast_pos := mkApp2 (.const ``Int.ofNat_pos_of_pos []) k' pow_pos
pure <| Std.HashSet.empty.insert
pure <| HashSet.empty.insert
(mkApp3 (.const ``Int.emod_nonneg []) x k
(mkApp3 (.const ``Int.ne_of_gt []) k (toExpr (0 : Int)) cast_pos)) |>.insert
(mkApp3 (.const ``Int.emod_lt_of_pos []) x k cast_pos)
@@ -224,18 +222,18 @@ def analyzeAtom (e : Expr) : OmegaM (Std.HashSet Expr) := do
| (``Nat.cast, #[.const ``Int [], _, x']) =>
-- Since we push coercions inside `%`, we need to record here that
-- `(x : Int) % (y : Int)` is non-negative.
pure <| Std.HashSet.empty.insert (mkApp2 (.const ``Int.emod_ofNat_nonneg []) x' k)
pure <| HashSet.empty.insert (mkApp2 (.const ``Int.emod_ofNat_nonneg []) x' k)
| _ => pure ∅
| _ => pure ∅
| (``Min.min, #[_, _, x, y]) =>
pure <| Std.HashSet.empty.insert (mkApp2 (.const ``Int.min_le_left []) x y) |>.insert
pure <| HashSet.empty.insert (mkApp2 (.const ``Int.min_le_left []) x y) |>.insert
(mkApp2 (.const ``Int.min_le_right []) x y)
| (``Max.max, #[_, _, x, y]) =>
pure <| Std.HashSet.empty.insert (mkApp2 (.const ``Int.le_max_left []) x y) |>.insert
pure <| HashSet.empty.insert (mkApp2 (.const ``Int.le_max_left []) x y) |>.insert
(mkApp2 (.const ``Int.le_max_right []) x y)
| (``ite, #[α, i, dec, t, e]) =>
if α == (.const ``Int []) then
pure <| Std.HashSet.empty.insert <| mkApp5 (.const ``ite_disjunction [0]) α i dec t e
pure <| HashSet.empty.insert <| mkApp5 (.const ``ite_disjunction [0]) α i dec t e
else
pure {}
| _ => pure ∅
@@ -250,10 +248,10 @@ Return its index, and, if it is new, a collection of interesting facts about the
* for each new atom of the form `((a - b : Nat) : Int)`, the fact:
`b ≤ a ∧ ((a - b : Nat) : Int) = a - b ∨ a < b ∧ ((a - b : Nat) : Int) = 0`
-/
def lookup (e : Expr) : OmegaM (Nat × Option (Std.HashSet Expr)) := do
def lookup (e : Expr) : OmegaM (Nat × Option (HashSet Expr)) := do
let c ← getThe State
let e ← canon e
match c.atoms[e]? with
match c.atoms.find? e with
| some i => return (i, none)
| none =>
trace[omega] "New atom: {e}"

@@ -7,8 +7,8 @@ prelude
import Init.Control.StateRef
import Init.Data.Array.BinSearch
import Init.Data.Stream
import Lean.ImportingFlag
import Lean.Data.HashMap
import Lean.ImportingFlag
import Lean.Data.SMap
import Lean.Declaration
import Lean.LocalContext
@@ -134,7 +134,7 @@ structure Environment where
the field `constants`. These auxiliary constants are invisible to the Lean kernel and elaborator.
Only the code generator uses them.
-/
const2ModIdx : Std.HashMap Name ModuleIdx
const2ModIdx : HashMap Name ModuleIdx
/--
Mapping from constant name to `ConstantInfo`. It contains all constants (definitions, theorems, axioms, etc)
that have been already type checked by the kernel.
@@ -205,7 +205,7 @@ private def getTrustLevel (env : Environment) : UInt32 :=
env.header.trustLevel

def getModuleIdxFor? (env : Environment) (declName : Name) : Option ModuleIdx :=
env.const2ModIdx[declName]?
env.const2ModIdx.find? declName

def isConstructor (env : Environment) (declName : Name) : Bool :=
match env.find? declName with
@@ -721,7 +721,7 @@ def writeModule (env : Environment) (fname : System.FilePath) : IO Unit := do
Construct a mapping from persistent extension name to extension index in the array of persistent extensions.
We only consider extensions starting with index `>= startingAt`.
-/
def mkExtNameMap (startingAt : Nat) : IO (Std.HashMap Name Nat) := do
def mkExtNameMap (startingAt : Nat) : IO (HashMap Name Nat) := do
let descrs ← persistentEnvExtensionsRef.get
let mut result := {}
for h : i in [startingAt : descrs.size] do
@@ -742,7 +742,7 @@ private def setImportedEntries (env : Environment) (mods : Array ModuleData) (st
have : modIdx < mods.size := h.upper
let mod := mods[modIdx]
for (extName, entries) in mod.entries do
if let some entryIdx := extNameIdx[extName]? then
if let some entryIdx := extNameIdx.find? extName then
env := extDescrs[entryIdx]!.toEnvExtension.modifyState env fun s => { s with importedEntries := s.importedEntries.set! modIdx entries }
return env

@@ -790,9 +790,9 @@ structure ImportState where
moduleData : Array ModuleData := #[]
regions : Array CompactedRegion := #[]

def throwAlreadyImported (s : ImportState) (const2ModIdx : Std.HashMap Name ModuleIdx) (modIdx : Nat) (cname : Name) : IO α := do
def throwAlreadyImported (s : ImportState) (const2ModIdx : HashMap Name ModuleIdx) (modIdx : Nat) (cname : Name) : IO α := do
let modName := s.moduleNames[modIdx]!
let constModName := s.moduleNames[const2ModIdx[cname]!.toNat]!
let constModName := s.moduleNames[const2ModIdx[cname].get!.toNat]!
throw <| IO.userError s!"import {modName} failed, environment already contains '{cname}' from {constModName}"

abbrev ImportStateM := StateRefT ImportState IO
@@ -856,21 +856,21 @@ def finalizeImport (s : ImportState) (imports : Array Import) (opts : Options) (
(leakEnv := false) : IO Environment := do
let numConsts := s.moduleData.foldl (init := 0) fun numConsts mod =>
numConsts + mod.constants.size + mod.extraConstNames.size
let mut const2ModIdx : Std.HashMap Name ModuleIdx := Std.HashMap.empty (capacity := numConsts)
let mut constantMap : Std.HashMap Name ConstantInfo := Std.HashMap.empty (capacity := numConsts)
let mut const2ModIdx : HashMap Name ModuleIdx := mkHashMap (capacity := numConsts)
let mut constantMap : HashMap Name ConstantInfo := mkHashMap (capacity := numConsts)
for h:modIdx in [0:s.moduleData.size] do
let mod := s.moduleData[modIdx]'h.upper
for cname in mod.constNames, cinfo in mod.constants do
match constantMap.getThenInsertIfNew? cname cinfo with
| (cinfoPrev?, constantMap') =>
match constantMap.insertIfNew cname cinfo with
| (constantMap', cinfoPrev?) =>
constantMap := constantMap'
if let some cinfoPrev := cinfoPrev? then
-- Recall that the map has not been modified when `cinfoPrev? = some _`.
unless equivInfo cinfoPrev cinfo do
throwAlreadyImported s const2ModIdx modIdx cname
const2ModIdx := const2ModIdx.insertIfNew cname modIdx
const2ModIdx := const2ModIdx.insertIfNew cname modIdx |>.1
for cname in mod.extraConstNames do
const2ModIdx := const2ModIdx.insertIfNew cname modIdx
const2ModIdx := const2ModIdx.insertIfNew cname modIdx |>.1
let constants : ConstMap := SMap.fromHashMap constantMap false
let exts ← mkInitialExtensionStates
let mut env : Environment := {
@@ -936,7 +936,7 @@ builtin_initialize namespacesExt : SimplePersistentEnvExtension Name NameSSet
6.18% of the runtime is here. It was 9.31% before the `HashMap` optimization.
-/
let capacity := as.foldl (init := 0) fun r e => r + e.size
let map : Std.HashMap Name Unit := Std.HashMap.empty capacity
let map : HashMap Name Unit := mkHashMap capacity
let map := mkStateFromImportedEntries (fun map name => map.insert name ()) map as
SMap.fromHashMap map |>.switch
addEntryFn := fun s n => s.insert n

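The `finalizeImport` hunk above also shows the two `insertIfNew` conventions: `Std.HashMap.getThenInsertIfNew?` yields `(previousValue?, newMap)`, while the older `Lean.HashMap.insertIfNew` yields the components in the opposite order, `(newMap, previousValue?)` — hence the `|>.1` projections on the `const2ModIdx` updates. A sketch of the older shape (illustrative types only):

```lean
-- `prev?` is `some _` exactly when `k` was already present,
-- in which case the returned map is unchanged.
def insertOnce (m : Lean.HashMap Nat String) (k : Nat) (v : String) :
    Lean.HashMap Nat String × Option String :=
  m.insertIfNew k v
```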
@@ -8,7 +8,6 @@ import Init.Data.Hashable
import Lean.Data.KVMap
import Lean.Data.SMap
import Lean.Level
import Std.Data.HashSet.Basic

namespace Lean

@@ -245,7 +244,7 @@ def FVarIdSet.insert (s : FVarIdSet) (fvarId : FVarId) : FVarIdSet :=
A set of unique free variable identifiers implemented using hashtables.
Hashtables are faster than red-black trees if they are used linearly.
They are not persistent data-structures. -/
def FVarIdHashSet := Std.HashSet FVarId
def FVarIdHashSet := HashSet FVarId
deriving Inhabited, EmptyCollection

/--
@@ -1389,11 +1388,11 @@ def mkDecIsTrue (pred proof : Expr) :=
def mkDecIsFalse (pred proof : Expr) :=
mkAppB (mkConst `Decidable.isFalse) pred proof

abbrev ExprMap (α : Type) := Std.HashMap Expr α
abbrev ExprMap (α : Type) := HashMap Expr α
abbrev PersistentExprMap (α : Type) := PHashMap Expr α
abbrev SExprMap (α : Type) := SMap Expr α

abbrev ExprSet := Std.HashSet Expr
abbrev ExprSet := HashSet Expr
abbrev PersistentExprSet := PHashSet Expr
abbrev PExprSet := PersistentExprSet

@@ -1418,7 +1417,7 @@ instance : ToString ExprStructEq := ⟨fun e => toString e.val⟩

end ExprStructEq

abbrev ExprStructMap (α : Type) := Std.HashMap ExprStructEq α
abbrev ExprStructMap (α : Type) := HashMap ExprStructEq α
abbrev PersistentExprStructMap (α : Type) := PHashMap ExprStructEq α

namespace Expr

@@ -31,7 +31,7 @@ namespace Lean
abbrev LabelExtension := SimpleScopedEnvExtension Name (Array Name)

/-- The collection of all current `LabelExtension`s, indexed by name. -/
abbrev LabelExtensionMap := Std.HashMap Name LabelExtension
abbrev LabelExtensionMap := HashMap Name LabelExtension

/-- Store the current `LabelExtension`s. -/
builtin_initialize labelExtensionMapRef : IO.Ref LabelExtensionMap ← IO.mkRef {}
@@ -88,7 +88,7 @@ macro (name := _root_.Lean.Parser.Command.registerLabelAttr)
/-- When `attrName` is an attribute created using `register_labelled_attr`,
return the names of all declarations labelled using that attribute. -/
def labelled (attrName : Name) : CoreM (Array Name) := do
match (← labelExtensionMapRef.get)[attrName]? with
match (← labelExtensionMapRef.get).find? attrName with
| none => throwError "No extension named {attrName}"
| some ext => pure <| ext.getState (← getEnv)


@@ -519,17 +519,12 @@ where

-- definitely resolved in `doElab` task
let elabPromise ← IO.Promise.new
let finishedPromise ← IO.Promise.new
-- (Try to) use last line of command as range for final snapshot task. This ensures we do not
-- retract the progress bar to a previous position in case the command support incremental
-- reporting but has significant work after resolving its last incremental promise, such as
-- final type checking; if it does not support incrementality, `elabSnap` constructed in
-- `parseCmd` and containing the entire range of the command will determine the reported
-- progress and be resolved effectively at the same time as this snapshot task, so `tailPos` is
-- irrelevant in this case.
let endRange? := stx.getTailPos?.map fun pos => ⟨pos, pos⟩
let finishedSnap := { range? := endRange?, task := finishedPromise.result }
let tacticCache ← old?.map (·.data.tacticCache) |>.getDM (IO.mkRef {})
let finishedSnap ←
doElab stx cmdState beginPos
{ old? := old?.map fun old => ⟨old.data.stx, old.data.elabSnap⟩, new := elabPromise }
tacticCache
ctx

let minimalSnapshots := internal.minimalSnapshots.get cmdState.scopes.head!.opts
let next? ← if Parser.isTerminalCommand stx then pure none
@@ -541,31 +536,35 @@ where
stx := .missing
parserState := {}
elabSnap := { range? := stx.getRange?, task := elabPromise.result }
finishedSnap := { range? := none, task := finishedPromise.result.map fun finishedSnap => {
finishedSnap := .pure {
diagnostics := finishedSnap.diagnostics
infoTree? := none
cmdState := {
env := initEnv
maxRecDepth := 0
}
}}
}
tacticCache
} else {
diagnostics, stx, parserState, tacticCache
elabSnap := { range? := stx.getRange?, task := elabPromise.result }
finishedSnap
finishedSnap := .pure finishedSnap
}
prom.resolve <| .mk (nextCmdSnap? := next?.map ({ range? := some ⟨parserState.pos, ctx.input.endPos⟩, task := ·.result })) data
doElab stx cmdState beginPos
{ old? := old?.map fun old => ⟨old.data.stx, old.data.elabSnap⟩, new := elabPromise }
finishedPromise tacticCache ctx
if let some next := next? then
parseCmd none parserState finishedSnap.get.cmdState initEnv next ctx
parseCmd none parserState finishedSnap.cmdState initEnv next ctx

doElab (stx : Syntax) (cmdState : Command.State) (beginPos : String.Pos)
(snap : SnapshotBundle DynamicSnapshot) (finishedPromise : IO.Promise CommandFinishedSnapshot)
(tacticCache : IO.Ref Tactic.Cache) : LeanProcessingM Unit := do
(snap : SnapshotBundle DynamicSnapshot) (tacticCache : IO.Ref Tactic.Cache) :
LeanProcessingM CommandFinishedSnapshot := do
let ctx ← read
-- (Try to) use last line of command as range for final snapshot task. This ensures we do not
-- retract the progress bar to a previous position in case the command support incremental
-- reporting but has significant work after resolving its last incremental promise, such as
-- final type checking; if it does not support incrementality, `elabSnap` constructed in
-- `parseCmd` and containing the entire range of the command will determine the reported
-- progress and be resolved effectively at the same time as this snapshot task, so `tailPos` is
-- irrelevant in this case.
let scope := cmdState.scopes.head!
let cmdStateRef ← IO.mkRef { cmdState with messages := .empty }
/-
@@ -601,7 +600,7 @@ where
let cmdState := { cmdState with messages }
-- definitely resolve eventually
snap.new.resolve <| .ofTyped { diagnostics := .empty : SnapshotLeaf }
finishedPromise.resolve {
return {
diagnostics := (← Snapshot.Diagnostics.ofMessageLog cmdState.messages)
infoTree? := some cmdState.infoState.trees[0]!
cmdState

@@ -614,9 +614,9 @@ where

end Level

abbrev LevelMap (α : Type) := Std.HashMap Level α
abbrev LevelMap (α : Type) := HashMap Level α
abbrev PersistentLevelMap (α : Type) := PHashMap Level α
abbrev LevelSet := Std.HashSet Level
abbrev LevelSet := HashSet Level
abbrev PersistentLevelSet := PHashSet Level
abbrev PLevelSet := PersistentLevelSet

@@ -34,7 +34,7 @@ def constructorNameAsVariable : Linter where
| return

let infoTrees := (← get).infoState.trees.toArray
let warnings : IO.Ref (Std.HashMap String.Range (Syntax × Name × Name)) ← IO.mkRef {}
let warnings : IO.Ref (Lean.HashMap String.Range (Syntax × Name × Name)) ← IO.mkRef {}

for tree in infoTrees do
tree.visitM' (preNode := fun ci info _ => do

@@ -149,7 +149,7 @@ def checkDecl : SimpleHandler := fun stx => do
lintField rest[1][0] stx[1] "computed field"
else if rest.getKind == ``«structure» then
unless rest[5][2].isNone do
let redecls : Std.HashSet String.Pos :=
let redecls : HashSet String.Pos :=
(← get).infoState.trees.foldl (init := {}) fun s tree =>
tree.foldInfo (init := s) fun _ info s =>
if let .ofFieldRedeclInfo info := info then

@@ -270,14 +270,14 @@ pointer identity and does not store the objects, so it is important not to store
pointer to an object in the map, or it can be freed and reused, resulting in incorrect behavior.

Returns `true` if the object was not already in the set. -/
unsafe def insertObjImpl {α : Type} (set : IO.Ref (Std.HashSet USize)) (a : α) : IO Bool := do
unsafe def insertObjImpl {α : Type} (set : IO.Ref (HashSet USize)) (a : α) : IO Bool := do
if (← set.get).contains (ptrAddrUnsafe a) then
return false
set.modify (·.insert (ptrAddrUnsafe a))
return true

@[inherit_doc insertObjImpl, implemented_by insertObjImpl]
opaque insertObj {α : Type} (set : IO.Ref (Std.HashSet USize)) (a : α) : IO Bool
opaque insertObj {α : Type} (set : IO.Ref (HashSet USize)) (a : α) : IO Bool

/--
Collects into `fvarUses` all `fvar`s occurring in the `Expr`s in `assignments`.
@@ -285,8 +285,8 @@ This implementation respects subterm sharing in both the `PersistentHashMap` and
to ensure that pointer-equal subobjects are not visited multiple times, which is important
in practice because these expressions are very frequently highly shared.
-/
partial def visitAssignments (set : IO.Ref (Std.HashSet USize))
(fvarUses : IO.Ref (Std.HashSet FVarId))
partial def visitAssignments (set : IO.Ref (HashSet USize))
(fvarUses : IO.Ref (HashSet FVarId))
(assignments : Array (PersistentHashMap MVarId Expr)) : IO Unit := do
MonadCacheT.run do
for assignment in assignments do
@@ -316,8 +316,8 @@ where
/-- Given `aliases` as a map from an alias to what it aliases, we get the original
term by recursion. This has no cycle detection, so if `aliases` contains a loop
then this function will recurse infinitely. -/
partial def followAliases (aliases : Std.HashMap FVarId FVarId) (x : FVarId) : FVarId :=
match aliases[x]? with
partial def followAliases (aliases : HashMap FVarId FVarId) (x : FVarId) : FVarId :=
match aliases.find? x with
| none => x
| some y => followAliases aliases y

@@ -343,17 +343,17 @@ structure References where
the spans for `foo`, `bar`, and `baz`. Global definitions are always treated as used.
(It would be nice to be able to detect unused global definitions but this requires more
information than the linter framework can provide.) -/
constDecls : Std.HashSet String.Range := .empty
constDecls : HashSet String.Range := .empty
/-- The collection of all local declarations, organized by the span of the declaration.
We collapse all declarations declared at the same position into a single record using
`FVarDefinition.aliases`. -/
fvarDefs : Std.HashMap String.Range FVarDefinition := .empty
fvarDefs : HashMap String.Range FVarDefinition := .empty
/-- The set of `FVarId`s that are used directly. These may or may not be aliases. -/
fvarUses : Std.HashSet FVarId := .empty
fvarUses : HashSet FVarId := .empty
/-- A mapping from alias to original FVarId. We don't guarantee that the value is not itself
an alias, but we use `followAliases` when adding new elements to try to avoid long chains. -/
-- TODO: use a `UnionFind` data structure here
fvarAliases : Std.HashMap FVarId FVarId := .empty
fvarAliases : HashMap FVarId FVarId := .empty
/-- Collection of all `MetavarContext`s following the execution of a tactic. We trawl these
if needed to find additional `fvarUses`. -/
assignments : Array (PersistentHashMap MVarId Expr) := #[]
@@ -391,7 +391,7 @@ def collectReferences (infoTrees : Array Elab.InfoTree) (cmdStxRange : String.Ra
if s.startsWith "_" then return
-- Record this either as a new `fvarDefs`, or an alias of an existing one
modify fun s =>
if let some ref := s.fvarDefs[range]? then
if let some ref := s.fvarDefs.find? range then
{ s with fvarDefs := s.fvarDefs.insert range { ref with aliases := ref.aliases.push id } }
else
{ s with fvarDefs := s.fvarDefs.insert range { userName := ldecl.userName, stx, opts, aliases := #[id] } }
@@ -444,7 +444,7 @@ def unusedVariables : Linter where
-- Resolve all recursive references in `fvarAliases`.
-- At this point everything in `fvarAliases` is guaranteed not to be itself an alias,
-- and should point to some element of `FVarDefinition.aliases` in `s.fvarDefs`
let fvarAliases : Std.HashMap FVarId FVarId := s.fvarAliases.fold (init := {}) fun m id baseId =>
let fvarAliases : HashMap FVarId FVarId := s.fvarAliases.fold (init := {}) fun m id baseId =>
m.insert id (followAliases s.fvarAliases baseId)

-- Collect all non-alias fvars corresponding to `fvarUses` by resolving aliases in the list.
@@ -461,7 +461,7 @@ def unusedVariables : Linter where
let fvarUses ← fvarUsesRef.get
-- If any of the `fvar`s corresponding to this declaration is (an alias of) a variable in
-- `fvarUses`, then it is used
if aliases.any fun id => fvarUses.contains (fvarAliases.getD id id) then continue
if aliases.any fun id => fvarUses.contains (fvarAliases.findD id id) then continue
-- If this is a global declaration then it is (potentially) used after the command
if s.constDecls.contains range then continue

@@ -496,7 +496,7 @@ def unusedVariables : Linter where
initializedMVars := true
let fvarUses ← fvarUsesRef.get
-- Redo the initial check because `fvarUses` could be bigger now
if aliases.any fun id => fvarUses.contains (fvarAliases.getD id id) then continue
if aliases.any fun id => fvarUses.contains (fvarAliases.findD id id) then continue

-- If we made it this far then the variable is unused and not ignored
unused := unused.push (declStx, userName)

@@ -16,8 +16,8 @@ structure State where
nextParamIdx : Nat := 0
paramNames : Array Name := #[]
fvars : Array Expr := #[]
lmap : Std.HashMap LMVarId Level := {}
emap : Std.HashMap MVarId Expr := {}
lmap : HashMap LMVarId Level := {}
emap : HashMap MVarId Expr := {}
abstractLevels : Bool -- whether to abstract level mvars

abbrev M := StateM State
@@ -54,7 +54,7 @@ private partial def abstractLevelMVars (u : Level) : M Level := do
if depth != s.mctx.depth then
return u -- metavariables from lower depths are treated as constants
else
match s.lmap[mvarId]? with
match s.lmap.find? mvarId with
| some u => pure u
| none =>
let paramId := Name.mkNum `_abstMVar s.nextParamIdx
@@ -87,7 +87,7 @@ partial def abstractExprMVars (e : Expr) : M Expr := do
if e != eNew then
abstractExprMVars eNew
else
match (← get).emap[mvarId]? with
match (← get).emap.find? mvarId with
| some e =>
return e
| none =>

@@ -4,11 +4,10 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Data.HashMap
import Lean.Util.ShareCommon
import Lean.Data.HashMap
import Lean.Meta.Basic
import Lean.Meta.FunInfo
import Std.Data.HashMap.Raw

namespace Lean.Meta
namespace Canonicalizer
@@ -48,12 +47,12 @@ State for the `CanonM` monad.
-/
structure State where
/-- Mapping from `Expr` to hash. -/
-- We use `HashMap.Raw` to ensure we don't have to tag `State` as `unsafe`.
cache : Std.HashMap.Raw ExprVisited UInt64 := Std.HashMap.Raw.empty
-- We use `HashMapImp` to ensure we don't have to tag `State` as `unsafe`.
cache : HashMapImp ExprVisited UInt64 := mkHashMapImp
/--
Given a hashcode `k` and `keyToExprs.find? h = some es`, we have that all `es` have hashcode `k`, and
are not definitionally equal modulo the transparency setting used. -/
keyToExprs : Std.HashMap UInt64 (List Expr) := ∅
keyToExprs : HashMap UInt64 (List Expr) := mkHashMap

instance : Inhabited State where
default := {}
@@ -71,7 +70,7 @@ def CanonM.run (x : CanonM α) (transparency := TransparencyMode.instances) (s :
StateRefT'.run (x transparency) s

private partial def mkKey (e : Expr) : CanonM UInt64 := do
if let some hash := unsafe (← get).cache.get? { e } then
if let some hash := unsafe (← get).cache.find? { e } then
return hash
else
let key ← match e with
@@ -108,7 +107,7 @@ private partial def mkKey (e : Expr) : CanonM UInt64 := do
return mixHash (← mkKey v) (← mkKey b)
| .proj _ i s =>
return mixHash i.toUInt64 (← mkKey s)
unsafe modify fun { cache, keyToExprs} => { keyToExprs, cache := cache.insert { e } key }
unsafe modify fun { cache, keyToExprs} => { keyToExprs, cache := cache.insert { e } key |>.1 }
return key

/--
@@ -117,7 +116,7 @@ private partial def mkKey (e : Expr) : CanonM UInt64 := do
def canon (e : Expr) : CanonM Expr := do
let k ← mkKey e
-- Find all expressions canonicalized before that have the same key.
if let some es' := unsafe (← get).keyToExprs[k]? then
if let some es' := unsafe (← get).keyToExprs.find? k then
withTransparency (← read) do
for e' in es' do
-- Found an expression `e'` that is definitionally equal to `e` and share the same key.

@@ -127,7 +127,7 @@ abbrev ClosureM := ReaderT Context $ StateRefT State MetaM
pure u
else
let s ← get
match s.visitedLevel[u]? with
match s.visitedLevel.find? u with
| some v => pure v
| none => do
let v ← f u
@@ -139,7 +139,7 @@ abbrev ClosureM := ReaderT Context $ StateRefT State MetaM
pure e
else
let s ← get
match s.visitedExpr.get? e with
match s.visitedExpr.find? e with
| some r => pure r
| none =>
let r ← f e

@@ -52,14 +52,14 @@ which appear in the type and local context of `mvarId`, as well as the
metavariables which *those* metavariables depend on, etc.
-/
partial def _root_.Lean.MVarId.getMVarDependencies (mvarId : MVarId) (includeDelayed := false) :
MetaM (Std.HashSet MVarId) :=
MetaM (HashSet MVarId) :=
(·.snd) <$> (go mvarId).run {}
where
/-- Auxiliary definition for `getMVarDependencies`. -/
addMVars (e : Expr) : StateRefT (Std.HashSet MVarId) MetaM Unit := do
addMVars (e : Expr) : StateRefT (HashSet MVarId) MetaM Unit := do
let mvars ← getMVars e
let mut s ← get
set ({} : Std.HashSet MVarId) -- Ensure that `s` is not shared.
set ({} : HashSet MVarId) -- Ensure that `s` is not shared.
for mvarId in mvars do
if ← pure includeDelayed <||> notM (mvarId.isDelayedAssigned) then
s := s.insert mvarId
@@ -67,7 +67,7 @@ where
mvars.forM go

/-- Auxiliary definition for `getMVarDependencies`. -/
go (mvarId : MVarId) : StateRefT (Std.HashSet MVarId) MetaM Unit :=
go (mvarId : MVarId) : StateRefT (HashSet MVarId) MetaM Unit :=
withIncRecDepth do
let mdecl ← mvarId.getDecl
addMVars mdecl.type

@@ -695,7 +695,7 @@ def throwOutOfScopeFVar : CheckAssignmentM α :=
throw <| Exception.internal outOfScopeExceptionId

private def findCached? (e : Expr) : CheckAssignmentM (Option Expr) := do
return (← get).cache.get? e
return (← get).cache.find? e

private def cache (e r : Expr) : CheckAssignmentM Unit := do
modify fun s => { s with cache := s.cache.insert e r }

@@ -4,7 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
-/
prelude
import Lean.Data.AssocList
import Lean.HeadIndex
import Lean.Meta.Basic

@@ -286,7 +286,7 @@ private structure Trie (α : Type) where
/-- Index of trie matching star. -/
star : TrieIndex
/-- Following matches based on key of trie. -/
children : Std.HashMap Key TrieIndex
children : HashMap Key TrieIndex
/-- Lazy entries at this trie that are not processed. -/
pending : Array (LazyEntry α) := #[]
deriving Inhabited
@@ -318,7 +318,7 @@ structure LazyDiscrTree (α : Type) where
/-- Backing array of trie entries. Should be owned by this trie. -/
tries : Array (LazyDiscrTree.Trie α) := #[default]
/-- Map from discriminator trie roots to the index. -/
roots : Std.HashMap LazyDiscrTree.Key LazyDiscrTree.TrieIndex := {}
roots : Lean.HashMap LazyDiscrTree.Key LazyDiscrTree.TrieIndex := {}

namespace LazyDiscrTree

@@ -445,9 +445,9 @@ private def addLazyEntryToTrie (i:TrieIndex) (e : LazyEntry α) : MatchM α Unit
modify (·.modify i (·.pushPending e))

private def evalLazyEntry (config : WhnfCoreConfig)
(p : Array α × TrieIndex × Std.HashMap Key TrieIndex)
(p : Array α × TrieIndex × HashMap Key TrieIndex)
(entry : LazyEntry α)
: MatchM α (Array α × TrieIndex × Std.HashMap Key TrieIndex) := do
: MatchM α (Array α × TrieIndex × HashMap Key TrieIndex) := do
let (values, starIdx, children) := p
let (todo, lctx, v) := entry
if todo.isEmpty then
@@ -465,7 +465,7 @@ private def evalLazyEntry (config : WhnfCoreConfig)
addLazyEntryToTrie starIdx (todo, lctx, v)
pure (values, starIdx, children)
else
match children[k]? with
match children.find? k with
| none =>
let children := children.insert k (← newTrie (todo, lctx, v))
pure (values, starIdx, children)
@@ -478,16 +478,16 @@ This evaluates all lazy entries in a trie and updates `values`, `starIdx`, and `
accordingly.
-/
private partial def evalLazyEntries (config : WhnfCoreConfig)
(values : Array α) (starIdx : TrieIndex) (children : Std.HashMap Key TrieIndex)
(values : Array α) (starIdx : TrieIndex) (children : HashMap Key TrieIndex)
(entries : Array (LazyEntry α)) :
MatchM α (Array α × TrieIndex × Std.HashMap Key TrieIndex) := do
MatchM α (Array α × TrieIndex × HashMap Key TrieIndex) := do
let mut values := values
let mut starIdx := starIdx
let mut children := children
entries.foldlM (init := (values, starIdx, children)) (evalLazyEntry config)

private def evalNode (c : TrieIndex) :
MatchM α (Array α × TrieIndex × Std.HashMap Key TrieIndex) := do
MatchM α (Array α × TrieIndex × HashMap Key TrieIndex) := do
let .node vs star cs pending := (←get).get! c
if pending.size = 0 then
pure (vs, star, cs)
@@ -508,7 +508,7 @@ def dropKeyAux (next : TrieIndex) (rest : List Key) :
| [] =>
modify (·.set! next {values := #[], star, children})
| k :: r => do
let next := if k == .star then star else children.getD k 0
let next := if k == .star then star else children.findD k 0
dropKeyAux next r

/--
@@ -519,7 +519,7 @@ def dropKey (t : LazyDiscrTree α) (path : List LazyDiscrTree.Key) : MetaM (Lazy
match path with
| [] => pure t
| rootKey :: rest => do
let idx := t.roots.getD rootKey 0
let idx := t.roots.findD rootKey 0
Prod.snd <$> runMatch t (dropKeyAux idx rest)

/--
@@ -628,7 +628,7 @@ private partial def getMatchLoop (cases : Array PartialMatch) (result : MatchRes
else
cases.push { todo, score := ca.score, c := star }
let pushNonStar (k : Key) (args : Array Expr) (cases : Array PartialMatch) :=
match cs[k]? with
match cs.find? k with
| none => cases
| some c => cases.push { todo := todo ++ args, score := ca.score + 1, c }
let cases := pushStar cases
@@ -650,8 +650,8 @@ private partial def getMatchLoop (cases : Array PartialMatch) (result : MatchRes
cases |> pushNonStar k args
getMatchLoop cases result

private def getStarResult (root : Std.HashMap Key TrieIndex) : MatchM α (MatchResult α) :=
match root[Key.star]? with
private def getStarResult (root : Lean.HashMap Key TrieIndex) : MatchM α (MatchResult α) :=
match root.find? .star with
| none =>
pure <| {}
| some idx => do
@@ -661,16 +661,16 @@ private def getStarResult (root : Std.HashMap Key TrieIndex) : MatchM α (MatchR
/-
Add partial match to cases if discriminator tree root map has potential matches.
-/
private def pushRootCase (r : Std.HashMap Key TrieIndex) (k : Key) (args : Array Expr)
private def pushRootCase (r : Lean.HashMap Key TrieIndex) (k : Key) (args : Array Expr)
(cases : Array PartialMatch) : Array PartialMatch :=
match r[k]? with
match r.find? k with
| none => cases
| some c => cases.push { todo := args, score := 1, c }

/--
Find values that match `e` in `root`.
-/
private def getMatchCore (root : Std.HashMap Key TrieIndex) (e : Expr) :
private def getMatchCore (root : Lean.HashMap Key TrieIndex) (e : Expr) :
MatchM α (MatchResult α) := do
let result ← getStarResult root
let (k, args) ← MatchClone.getMatchKeyArgs e (root := true) (← read)
@@ -701,7 +701,7 @@ of elements using concurrent functions for generating entries.
-/
private structure PreDiscrTree (α : Type) where
/-- Maps keys to index in tries array. -/
roots : Std.HashMap Key Nat := {}
roots : HashMap Key Nat := {}
/-- Lazy entries for root of trie. -/
tries : Array (Array (LazyEntry α)) := #[]
deriving Inhabited
@@ -711,7 +711,7 @@ namespace PreDiscrTree
private def modifyAt (d : PreDiscrTree α) (k : Key)
(f : Array (LazyEntry α) → Array (LazyEntry α)) : PreDiscrTree α :=
let { roots, tries } := d
match roots[k]? with
match roots.find? k with
| .none =>
let roots := roots.insert k tries.size
{ roots, tries := tries.push (f #[]) }

@@ -68,7 +68,7 @@ where
loop lhss alts minors

structure State where
used : Std.HashSet Nat := {} -- used alternatives
used : HashSet Nat := {} -- used alternatives
counterExamples : List (List Example) := []

/-- Return true if the given (sub-)problem has been solved. -/

@@ -28,17 +28,17 @@ such as `contradiction`.
-/
private def _root_.Lean.MVarId.contradictionQuick (mvarId : MVarId) : MetaM Bool := do
mvarId.withContext do
let mut posMap : Std.HashMap Expr FVarId := {}
let mut negMap : Std.HashMap Expr FVarId := {}
let mut posMap : HashMap Expr FVarId := {}
let mut negMap : HashMap Expr FVarId := {}
for localDecl in (← getLCtx) do
unless localDecl.isImplementationDetail do
if let some p ← matchNot? localDecl.type then
if let some pFVarId := posMap[p]? then
if let some pFVarId := posMap.find? p then
mvarId.assign (← mkAbsurd (← mvarId.getType) (mkFVar pFVarId) localDecl.toExpr)
return true
negMap := negMap.insert p localDecl.fvarId
if (← isProp localDecl.type) then
if let some nFVarId := negMap[localDecl.type]? then
if let some nFVarId := negMap.find? localDecl.type then
mvarId.assign (← mkAbsurd (← mvarId.getType) localDecl.toExpr (mkFVar nFVarId))
return true
posMap := posMap.insert localDecl.type localDecl.fvarId

@@ -97,8 +97,8 @@ namespace MkTableKey

structure State where
nextIdx : Nat := 0
lmap : Std.HashMap LMVarId Level := {}
emap : Std.HashMap MVarId Expr := {}
lmap : HashMap LMVarId Level := {}
emap : HashMap MVarId Expr := {}
mctx : MetavarContext

abbrev M := StateM State
@@ -120,7 +120,7 @@ partial def normLevel (u : Level) : M Level := do
return u
else
let s ← get
match (← get).lmap[mvarId]? with
match (← get).lmap.find? mvarId with
| some u' => pure u'
| none =>
let u' := mkLevelParam <| Name.mkNum `_tc s.nextIdx
@@ -145,7 +145,7 @@ partial def normExpr (e : Expr) : M Expr := do
return e
else
let s ← get
match s.emap[mvarId]? with
match s.emap.find? mvarId with
| some e' => pure e'
| none => do
let e' := mkFVar { name := Name.mkNum `_tc s.nextIdx }
@@ -186,7 +186,7 @@ structure State where
result? : Option AbstractMVarsResult := none
generatorStack : Array GeneratorNode := #[]
resumeStack : Array (ConsumerNode × Answer) := #[]
tableEntries : Std.HashMap Expr TableEntry := {}
tableEntries : HashMap Expr TableEntry := {}

abbrev SynthM := ReaderT Context $ StateRefT State MetaM

@@ -265,7 +265,7 @@ def newSubgoal (mctx : MetavarContext) (key : Expr) (mvar : Expr) (waiter : Wait
pure ((), m!"new goal {key}")

def findEntry? (key : Expr) : SynthM (Option TableEntry) := do
return (← get).tableEntries[key]?
return (← get).tableEntries.find? key

def getEntry (key : Expr) : SynthM TableEntry := do
match (← findEntry? key) with
@@ -553,7 +553,7 @@ def generate : SynthM Unit := do
/- See comment at `typeHasMVars` -/
if backward.synthInstance.canonInstances.get (← getOptions) then
unless gNode.typeHasMVars do
if let some entry := (← get).tableEntries[key]? then
if let some entry := (← get).tableEntries.find? key then
if entry.answers.any fun answer => answer.result.numMVars == 0 then
/-
We already have an answer that:

@@ -66,9 +66,9 @@ inductive PreExpr
def toACExpr (op l r : Expr) : MetaM (Array Expr × ACExpr) := do
let (preExpr, vars) ←
toPreExpr (mkApp2 op l r)
|>.run Std.HashSet.empty
|>.run HashSet.empty
let vars := vars.toArray.insertionSort Expr.lt
let varMap := vars.foldl (fun xs x => xs.insert x xs.size) Std.HashMap.empty |>.get!
let varMap := vars.foldl (fun xs x => xs.insert x xs.size) HashMap.empty |>.find!

return (vars, toACExpr varMap preExpr)
where

@@ -290,7 +290,7 @@ structure RewriteResultConfig where
side : SideConditions := .solveByElim
mctx : MetavarContext

def takeListAux (cfg : RewriteResultConfig) (seen : Std.HashMap String Unit) (acc : Array RewriteResult)
def takeListAux (cfg : RewriteResultConfig) (seen : HashMap String Unit) (acc : Array RewriteResult)
(xs : List ((Expr ⊕ Name) × Bool × Nat)) : MetaM (Array RewriteResult) := do
let mut seen := seen
let mut acc := acc

@@ -420,12 +420,12 @@ def mkSimpExt (name : Name := by exact decl_name%) : IO SimpExtension :=
| .toUnfoldThms n thms => d.registerDeclToUnfoldThms n thms
}

abbrev SimpExtensionMap := Std.HashMap Name SimpExtension
abbrev SimpExtensionMap := HashMap Name SimpExtension

builtin_initialize simpExtensionMapRef : IO.Ref SimpExtensionMap ← IO.mkRef {}

def getSimpExtension? (attrName : Name) : IO (Option SimpExtension) :=
return (← simpExtensionMapRef.get)[attrName]?
return (← simpExtensionMapRef.get).find? attrName

/-- Auxiliary method for adding a global declaration to a `SimpTheorems` datastructure. -/
def SimpTheorems.addConst (s : SimpTheorems) (declName : Name) (post := true) (inv := false) (prio : Nat := eval_prio default) : MetaM SimpTheorems := do

@@ -19,8 +19,8 @@ It contains:
- The actual procedure associated with a name.
-/
structure BuiltinSimprocs where
keys : Std.HashMap Name (Array SimpTheoremKey) := {}
procs : Std.HashMap Name (Sum Simproc DSimproc) := {}
keys : HashMap Name (Array SimpTheoremKey) := {}
procs : HashMap Name (Sum Simproc DSimproc) := {}
deriving Inhabited

/--
@@ -37,7 +37,7 @@ structure SimprocDecl where
deriving Inhabited

structure SimprocDeclExtState where
builtin : Std.HashMap Name (Array SimpTheoremKey)
builtin : HashMap Name (Array SimpTheoremKey)
newEntries : PHashMap Name (Array SimpTheoremKey) := {}
deriving Inhabited

@@ -65,7 +65,7 @@ def getSimprocDeclKeys? (declName : Name) : CoreM (Option (Array SimpTheoremKey)
if let some keys := keys? then
return some keys
else
return (simprocDeclExt.getState env).builtin[declName]?
return (simprocDeclExt.getState env).builtin.find? declName

def isBuiltinSimproc (declName : Name) : CoreM Bool := do
let s := simprocDeclExt.getState (← getEnv)
@@ -160,7 +160,7 @@ def Simprocs.addCore (s : Simprocs) (keys : Array SimpTheoremKey) (declName : Na
Implements attributes `builtin_simproc` and `builtin_sevalproc`.
-/
def addSimprocBuiltinAttrCore (ref : IO.Ref Simprocs) (declName : Name) (post : Bool) (proc : Sum Simproc DSimproc) : IO Unit := do
let some keys := (← builtinSimprocDeclsRef.get).keys[declName]? |
let some keys := (← builtinSimprocDeclsRef.get).keys.find? declName |
throw (IO.userError "invalid [builtin_simproc] attribute, '{declName}' is not a builtin simproc")
ref.modify fun s => s.addCore keys declName post proc

@@ -176,7 +176,7 @@ def Simprocs.add (s : Simprocs) (declName : Name) (post : Bool) : CoreM Simprocs
getSimprocFromDecl declName
catch e =>
if (← isBuiltinSimproc declName) then
let some proc := (← builtinSimprocDeclsRef.get).procs[declName]?
let some proc := (← builtinSimprocDeclsRef.get).procs.find? declName
| throwError "invalid [simproc] attribute, '{declName}' is not a simproc"
pure proc
else
@@ -384,7 +384,7 @@ def mkSimprocAttr (attrName : Name) (attrDescr : String) (ext : SimprocExtension
erase := eraseSimprocAttr ext
}

abbrev SimprocExtensionMap := Std.HashMap Name SimprocExtension
abbrev SimprocExtensionMap := HashMap Name SimprocExtension

builtin_initialize simprocExtensionMapRef : IO.Ref SimprocExtensionMap ← IO.mkRef {}

@@ -438,7 +438,7 @@ def getSEvalSimprocs : CoreM Simprocs :=
return simprocSEvalExtension.getState (← getEnv)

def getSimprocExtensionCore? (attrName : Name) : IO (Option SimprocExtension) :=
return (← simprocExtensionMapRef.get)[attrName]?
return (← simprocExtensionMapRef.get).find? attrName

def simpAttrNameToSimprocAttrName (attrName : Name) : Name :=
if attrName == `simp then `simprocAttr

@@ -512,7 +512,7 @@ def mkCongrSimp? (f : Expr) : SimpM (Option CongrTheorem) := do
if kinds.all fun k => match k with | CongrArgKind.fixed => true | CongrArgKind.eq => true | _ => false then
/- See remark above. -/
return none
match (← get).congrCache[f]? with
match (← get).congrCache.find? f with
| some thm? => return thm?
| none =>
let thm? ← mkCongrSimpCore? f info kinds

@@ -903,7 +903,7 @@ structure State where
   mctx : MetavarContext
   nextMacroScope : MacroScope
   ngen : NameGenerator
-  cache : Std.HashMap ExprStructEq Expr := {}
+  cache : HashMap ExprStructEq Expr := {}
 
 structure Context where
   mainModule : Name
@@ -1319,7 +1319,7 @@ structure State where
   mctx : MetavarContext
   paramNames : Array Name := #[]
   nextParamIdx : Nat
-  cache : Std.HashMap ExprStructEq Expr := {}
+  cache : HashMap ExprStructEq Expr := {}
 
 abbrev M := ReaderT Context <| StateM State
 
@@ -1328,7 +1328,7 @@ instance : MonadMCtx M where
   modifyMCtx f := modify fun s => { s with mctx := f s.mctx }
 
 instance : MonadCache ExprStructEq Expr M where
-  findCached? e := return (← get).cache[e]?
+  findCached? e := return (← get).cache.find? e
   cache e v := modify fun s => { s with cache := s.cache.insert e v }
 
 partial def mkParamName : M Name := do
@@ -242,10 +242,10 @@ def «structure» := leading_parser
 @[builtin_command_parser] def noncomputableSection := leading_parser
   "noncomputable " >> "section" >> optional (ppSpace >> checkColGt >> ident)
 /--
-A `section`/`end` pair delimits the scope of `variable`, `include`, `open`, `set_option`, and `local`
-commands. Sections can be nested. `section <id>` provides a label to the section that has to appear
-with the matching `end`. In either case, the `end` can be omitted, in which case the section is
-closed at the end of the file.
+A `section`/`end` pair delimits the scope of `variable`, `open`, `set_option`, and `local` commands.
+Sections can be nested. `section <id>` provides a label to the section that has to appear with the
+matching `end`. In either case, the `end` can be omitted, in which case the section is closed at the
+end of the file.
 -/
 @[builtin_command_parser] def «section» := leading_parser
   "section" >> optional (ppSpace >> checkColGt >> ident)
@@ -274,12 +274,12 @@ with `end <id>`. The `end` command is optional at the end of a file.
 @[builtin_command_parser] def «end» := leading_parser
   "end" >> optional (ppSpace >> checkColGt >> ident)
 /-- Declares one or more typed variables, or modifies whether already-declared variables are
-implicit.
+implicit.
 
 Introduces variables that can be used in definitions within the same `namespace` or `section` block.
-When a definition mentions a variable, Lean will add it as an argument of the definition. This is
-useful in particular when writing many definitions that have parameters in common (see below for an
-example).
+When a definition mentions a variable, Lean will add it as an argument of the definition. The
+`variable` command is also able to add typeclass parameters. This is useful in particular when
+writing many definitions that have parameters in common (see below for an example).
 
 Variable declarations have the same flexibility as regular function paramaters. In particular they
 can be [explicit, implicit][binder docs], or [instance implicit][tpil classes] (in which case they
@@ -287,22 +287,17 @@ can be anonymous). This can be changed, for instance one can turn explicit varia
 implicit one with `variable {x}`. Note that currently, you should avoid changing how variables are
 bound and declare new variables at the same time; see [issue 2789] for more on this topic.
 
-In *theorem bodies* (i.e. proofs), variables are not included based on usage in order to ensure that
-changes to the proof cannot change the statement of the overall theorem. Instead, variables are only
-available to the proof if they have been mentioned in the theorem header or in an `include` command
-or are instance implicit and depend only on such variables.
-
 See [*Variables and Sections* from Theorem Proving in Lean][tpil vars] for a more detailed
 discussion.
 
-[tpil vars]:
-https://lean-lang.org/theorem_proving_in_lean4/dependent_type_theory.html#variables-and-sections
-(Variables and Sections on Theorem Proving in Lean) [tpil classes]:
-https://lean-lang.org/theorem_proving_in_lean4/type_classes.html (Type classes on Theorem Proving in
-Lean) [binder docs]:
-https://leanprover-community.github.io/mathlib4_docs/Lean/Expr.html#Lean.BinderInfo (Documentation
-for the BinderInfo type) [issue 2789]: https://github.com/leanprover/lean4/issues/2789 (Issue 2789
-on github)
+[tpil vars]: https://lean-lang.org/theorem_proving_in_lean4/dependent_type_theory.html#variables-and-sections
+(Variables and Sections on Theorem Proving in Lean)
+[tpil classes]: https://lean-lang.org/theorem_proving_in_lean4/type_classes.html
+(Type classes on Theorem Proving in Lean)
+[binder docs]: https://leanprover-community.github.io/mathlib4_docs/Lean/Expr.html#Lean.BinderInfo
+(Documentation for the BinderInfo type)
+[issue 2789]: https://github.com/leanprover/lean4/issues/2789
+(Issue 2789 on github)
 
 ## Examples
 
@@ -373,24 +368,6 @@ namespace Logger
 end Logger
 ```
-
-The following example demonstrates availability of variables in proofs:
-```lean
-variable
-  {α : Type} -- available in the proof as indirectly mentioned through `a`
-  [ToString α] -- available in the proof as `α` is included
-  (a : α) -- available in the proof as mentioned in the header
-  {β : Type} -- not available in the proof
-  [ToString β] -- not available in the proof
-
-theorem ex : a = a := rfl
-```
-After elaboration of the proof, the following warning will be generated to highlight the unused
-hypothesis:
-```
-included section variable '[ToString α]' is not used in 'ex', consider excluding it
-```
-In such cases, the offending variable declaration should be moved down or into a section so that
-only theorems that do depend on it follow it until the end of the section.
 -/
 @[builtin_command_parser] def «variable» := leading_parser
   "variable" >> many1 (ppSpace >> checkColGt >> Term.bracketedBinder)
@@ -726,13 +703,8 @@ list, so it should be brief.
 @[builtin_command_parser] def genInjectiveTheorems := leading_parser
   "gen_injective_theorems% " >> ident
 
-/--
-`include eeny meeny` instructs Lean to include the section `variable`s `eeny` and `meeny` in all
-declarations in the remainder of the current section, differing from the default behavior of
-conditionally including variables based on use in the declaration header. `include` is usually
-followed by the `in` combinator to limit the inclusion to the subsequent declaration.
--/
-@[builtin_command_parser] def «include» := leading_parser "include " >> many1 ident
+/-- To be implemented. -/
+@[builtin_command_parser] def «include» := leading_parser "include " >> many1 (checkColGt >> ident)
 
 /-- No-op parser used as syntax kind for attaching remaining whitespace at the end of the input. -/
 @[run_builtin_parser_attribute_hooks] def eoi : Parser := leading_parser ""
@@ -131,7 +131,7 @@ structure ParserCacheEntry where
 
 structure ParserCache where
   tokenCache : TokenCacheEntry
-  parserCache : Std.HashMap ParserCacheKey ParserCacheEntry
+  parserCache : HashMap ParserCacheKey ParserCacheEntry
 
 def initCacheForInput (input : String) : ParserCache where
   tokenCache := { startPos := input.endPos + ' ' /- make sure it is not a valid position -/ }
@@ -418,7 +418,7 @@ place if there was an error.
 -/
 def withCacheFn (parserName : Name) (p : ParserFn) : ParserFn := fun c s => Id.run do
   let key := ⟨c.toCacheableParserContext, parserName, s.pos⟩
-  if let some r := s.cache.parserCache[key]? then
+  if let some r := s.cache.parserCache.find? key then
     -- TODO: turn this into a proper trace once we have these in the parser
     --dbg_trace "parser cache hit: {parserName}:{s.pos} -> {r.stx}"
     return ⟨s.stxStack.push r.stx, r.lhsPrec, r.newPos, s.cache, r.errorMsg, s.recoveredErrors⟩
@@ -123,12 +123,6 @@ unsafe def mkDelabAttribute : IO (KeyedDeclsAttribute Delab) :=
   } `Lean.PrettyPrinter.Delaborator.delabAttribute
 @[builtin_init mkDelabAttribute] opaque delabAttribute : KeyedDeclsAttribute Delab
 
-macro "app_delab" id:ident : attr => do
-  match ← Macro.resolveGlobalName id.getId with
-  | [] => Macro.throwErrorAt id s!"unknown declaration '{id.getId}'"
-  | [(c, [])] => `(attr| delab $(mkIdentFrom (canonical := true) id (`app ++ c)))
-  | _ => Macro.throwErrorAt id s!"ambiguous declaration '{id.getId}'"
-
 def getExprKind : DelabM Name := do
   let e ← getExpr
   pure $ match e with
@@ -198,10 +198,10 @@ def isHBinOp (e : Expr) : Bool := Id.run do
 def replaceLPsWithVars (e : Expr) : MetaM Expr := do
   if !e.hasLevelParam then return e
   let lps := collectLevelParams {} e |>.params
-  let mut replaceMap : Std.HashMap Name Level := {}
+  let mut replaceMap : HashMap Name Level := {}
   for lp in lps do replaceMap := replaceMap.insert lp (← mkFreshLevelMVar)
   return e.replaceLevel fun
-    | Level.param n .. => replaceMap[n]!
+    | Level.param n .. => replaceMap.find! n
     | l => if !l.hasParam then some l else none
 
 def isDefEqAssigning (t s : Expr) : MetaM Bool := do
@@ -29,7 +29,7 @@ namespace Lean.Environment
 namespace Replay
 
 structure Context where
-  newConstants : Std.HashMap Name ConstantInfo
+  newConstants : HashMap Name ConstantInfo
 
 structure State where
   env : Environment
@@ -73,7 +73,7 @@ and add it to the environment.
 -/
 partial def replayConstant (name : Name) : M Unit := do
   if ← isTodo name then
-    let some ci := (← read).newConstants[name]? | unreachable!
+    let some ci := (← read).newConstants.find? name | unreachable!
     replayConstants ci.getUsedConstantsAsSet
     -- Check that this name is still pending: a mutual block may have taken care of it.
    if (← get).pending.contains name then
@@ -89,13 +89,13 @@ partial def replayConstant (name : Name) : M Unit := do
      | .inductInfo info =>
        let lparams := info.levelParams
        let nparams := info.numParams
-        let all ← info.all.mapM fun n => do pure <| ((← read).newConstants[n]!)
+        let all ← info.all.mapM fun n => do pure <| ((← read).newConstants.find! n)
        for o in all do
          modify fun s =>
            { s with remaining := s.remaining.erase o.name, pending := s.pending.erase o.name }
        let ctorInfo ← all.mapM fun ci => do
          pure (ci, ← ci.inductiveVal!.ctors.mapM fun n => do
-            pure ((← read).newConstants[n]!))
+            pure ((← read).newConstants.find! n))
        -- Make sure we are really finished with the constructors.
        for (_, ctors) in ctorInfo do
          for ctor in ctors do
@@ -129,7 +129,7 @@ when we replayed the inductives.
 -/
 def checkPostponedConstructors : M Unit := do
   for ctor in (← get).postponedConstructors do
-    match (← get).env.constants.find? ctor, (← read).newConstants[ctor]? with
+    match (← get).env.constants.find? ctor, (← read).newConstants.find? ctor with
    | some (.ctorInfo info), some (.ctorInfo info') =>
      if ! (info == info') then throw <| IO.userError s!"Invalid constructor {ctor}"
    | _, _ => throw <| IO.userError s!"No such constructor {ctor}"
@@ -140,7 +140,7 @@ when we replayed the inductives.
 -/
 def checkPostponedRecursors : M Unit := do
   for ctor in (← get).postponedRecursors do
-    match (← get).env.constants.find? ctor, (← read).newConstants[ctor]? with
+    match (← get).env.constants.find? ctor, (← read).newConstants.find? ctor with
    | some (.recInfo info), some (.recInfo info') =>
      if ! (info == info') then throw <| IO.userError s!"Invalid recursor {ctor}"
    | _, _ => throw <| IO.userError s!"No such recursor {ctor}"
@@ -155,7 +155,7 @@ open Replay
 Throws a `IO.userError` if the kernel rejects a constant,
 or if there are malformed recursors or constructors for inductive types.
 -/
-def replay (newConstants : Std.HashMap Name ConstantInfo) (env : Environment) : IO Environment := do
+def replay (newConstants : HashMap Name ConstantInfo) (env : Environment) : IO Environment := do
   let mut remaining : NameSet := ∅
   for (n, ci) in newConstants.toList do
    -- We skip unsafe constants, and also partial constants.
@@ -81,7 +81,7 @@ open Elab
 open Meta
 open FuzzyMatching
 
-abbrev EligibleHeaderDecls := Std.HashMap Name ConstantInfo
+abbrev EligibleHeaderDecls := HashMap Name ConstantInfo
 
 /-- Cached header declarations for which `allowCompletion headerEnv decl` is true. -/
 builtin_initialize eligibleHeaderDeclsRef : IO.Ref (Option EligibleHeaderDecls) ←
@@ -316,7 +316,7 @@ partial def handleDocumentHighlight (p : DocumentHighlightParams)
   let refs : Lsp.ModuleRefs ← findModuleRefs text trees |>.toLspModuleRefs
   let mut ranges := #[]
   for ident in refs.findAt p.position do
-    if let some info := refs.get? ident then
+    if let some info := refs.find? ident then
      if let some ⟨definitionRange, _⟩ := info.definition? then
        ranges := ranges.push definitionRange
      ranges := ranges.append <| info.usages.map (·.range)
@@ -93,20 +93,20 @@ def toLspRefInfo (i : RefInfo) : BaseIO Lsp.RefInfo := do
 end RefInfo
 
 /-- All references from within a module for all identifiers used in a single module. -/
-def ModuleRefs := Std.HashMap RefIdent RefInfo
+def ModuleRefs := HashMap RefIdent RefInfo
 
 namespace ModuleRefs
 
 /-- Adds `ref` to the `RefInfo` corresponding to `ref.ident` in `self`. See `RefInfo.addRef`. -/
 def addRef (self : ModuleRefs) (ref : Reference) : ModuleRefs :=
-  let refInfo := self.getD ref.ident RefInfo.empty
+  let refInfo := self.findD ref.ident RefInfo.empty
   self.insert ref.ident (refInfo.addRef ref)
 
 /-- Converts `refs` to a JSON-serializable `Lsp.ModuleRefs`. -/
 def toLspModuleRefs (refs : ModuleRefs) : BaseIO Lsp.ModuleRefs := do
   let refs ← refs.toList.mapM fun (k, v) => do
     return (k, ← v.toLspRefInfo)
-  return Std.HashMap.ofList refs
+  return HashMap.ofList refs
 
 end ModuleRefs
 
@@ -261,7 +261,7 @@ all identifiers that are being collapsed into one.
 -/
 partial def combineIdents (trees : Array InfoTree) (refs : Array Reference) : Array Reference := Id.run do
   -- Deduplicate definitions based on their exact range
-  let mut posMap : Std.HashMap Lsp.Range RefIdent := Std.HashMap.empty
+  let mut posMap : HashMap Lsp.Range RefIdent := HashMap.empty
   for ref in refs do
    if let { ident, range, isBinder := true, .. } := ref then
      posMap := posMap.insert range ident
@@ -277,17 +277,17 @@ partial def combineIdents (trees : Array InfoTree) (refs : Array Reference) : Ar
      refs' := refs'.push ref
   refs'
 where
-  useConstRepresentatives (idMap : Std.HashMap RefIdent RefIdent)
-      : Std.HashMap RefIdent RefIdent := Id.run do
+  useConstRepresentatives (idMap : HashMap RefIdent RefIdent)
+      : HashMap RefIdent RefIdent := Id.run do
    let insertIntoClass classesById id :=
      let representative := findCanonicalRepresentative idMap id
-      let «class» := classesById.getD representative ∅
+      let «class» := classesById.findD representative ∅
      let classesById := classesById.erase representative -- make `«class»` referentially unique
      let «class» := «class».insert id
      classesById.insert representative «class»
 
    -- collect equivalence classes
-    let mut classesById : Std.HashMap RefIdent (Std.HashSet RefIdent) := ∅
+    let mut classesById : HashMap RefIdent (HashSet RefIdent) := ∅
    for ⟨id, baseId⟩ in idMap.toArray do
      classesById := insertIntoClass classesById id
      classesById := insertIntoClass classesById baseId
@@ -310,17 +310,17 @@ where
      r := r.insert id bestRepresentative
    return r
 
-  findCanonicalRepresentative (idMap : Std.HashMap RefIdent RefIdent) (id : RefIdent) : RefIdent := Id.run do
+  findCanonicalRepresentative (idMap : HashMap RefIdent RefIdent) (id : RefIdent) : RefIdent := Id.run do
    let mut canonicalRepresentative := id
    while idMap.contains canonicalRepresentative do
-      canonicalRepresentative := idMap[canonicalRepresentative]!
+      canonicalRepresentative := idMap.find! canonicalRepresentative
    return canonicalRepresentative
 
-  buildIdMap posMap := Id.run <| StateT.run' (s := Std.HashMap.empty) do
+  buildIdMap posMap := Id.run <| StateT.run' (s := HashMap.empty) do
    -- map fvar defs to overlapping fvar defs/uses
    for ref in refs do
      let baseId := ref.ident
-      if let some id := posMap[ref.range]? then
+      if let some id := posMap.find? ref.range then
        insertIdMap id baseId
 
    -- apply `FVarAliasInfo`
@@ -346,11 +346,11 @@ are added to the `aliases` of the representative of the group.
 Yields to separate groups for declaration and usages if `allowSimultaneousBinderUse` is set.
 -/
 def dedupReferences (refs : Array Reference) (allowSimultaneousBinderUse := false) : Array Reference := Id.run do
-  let mut refsByIdAndRange : Std.HashMap (RefIdent × Option Bool × Lsp.Range) Reference := Std.HashMap.empty
+  let mut refsByIdAndRange : HashMap (RefIdent × Option Bool × Lsp.Range) Reference := HashMap.empty
   for ref in refs do
    let isBinder := if allowSimultaneousBinderUse then some ref.isBinder else none
    let key := (ref.ident, isBinder, ref.range)
-    refsByIdAndRange := match refsByIdAndRange[key]? with
+    refsByIdAndRange := match refsByIdAndRange[key] with
      | some ref' => refsByIdAndRange.insert key { ref' with aliases := ref'.aliases ++ ref.aliases }
      | none => refsByIdAndRange.insert key ref
 
@@ -371,21 +371,21 @@ def findModuleRefs (text : FileMap) (trees : Array InfoTree) (localVars : Bool :
    refs := refs.filter fun
      | { ident := RefIdent.fvar .., .. } => false
      | _ => true
-  refs.foldl (init := Std.HashMap.empty) fun m ref => m.addRef ref
+  refs.foldl (init := HashMap.empty) fun m ref => m.addRef ref
 
 /-! # Collecting and maintaining reference info from different sources -/
 
 /-- References from ilean files and current ilean information from file workers. -/
 structure References where
   /-- References loaded from ilean files -/
-  ileans : Std.HashMap Name (System.FilePath × Lsp.ModuleRefs)
+  ileans : HashMap Name (System.FilePath × Lsp.ModuleRefs)
   /-- References from workers, overriding the corresponding ilean files -/
-  workers : Std.HashMap Name (Nat × Lsp.ModuleRefs)
+  workers : HashMap Name (Nat × Lsp.ModuleRefs)
 
 namespace References
 
 /-- No ilean files, no information from workers. -/
-def empty : References := { ileans := Std.HashMap.empty, workers := Std.HashMap.empty }
+def empty : References := { ileans := HashMap.empty, workers := HashMap.empty }
 
 /-- Adds the contents of an ilean file `ilean` at `path` to `self`. -/
 def addIlean (self : References) (path : System.FilePath) (ilean : Ilean) : References :=
@@ -404,13 +404,13 @@ Replaces the current references with `refs` if `version` is newer than the curre
 in `refs` and otherwise merges the reference data if `version` is equal to the current version.
 -/
 def updateWorkerRefs (self : References) (name : Name) (version : Nat) (refs : Lsp.ModuleRefs) : References := Id.run do
-  if let some (currVersion, _) := self.workers[name]? then
+  if let some (currVersion, _) := self.workers.find? name then
    if version > currVersion then
      return { self with workers := self.workers.insert name (version, refs) }
    if version == currVersion then
-      let current := self.workers.getD name (version, Std.HashMap.empty)
+      let current := self.workers.findD name (version, HashMap.empty)
      let merged := refs.fold (init := current.snd) fun m ident info =>
-        m.getD ident Lsp.RefInfo.empty |>.merge info |> m.insert ident
+        m.findD ident Lsp.RefInfo.empty |>.merge info |> m.insert ident
      return { self with workers := self.workers.insert name (version, merged) }
   return self
@@ -419,7 +419,7 @@ Replaces the worker references in `self` with the `refs` of the worker managing
 if `version` is newer than the current version managed in `refs`.
 -/
 def finalizeWorkerRefs (self : References) (name : Name) (version : Nat) (refs : Lsp.ModuleRefs) : References := Id.run do
-  if let some (currVersion, _) := self.workers[name]? then
+  if let some (currVersion, _) := self.workers.find? name then
    if version < currVersion then
      return self
   return { self with workers := self.workers.insert name (version, refs) }
@@ -429,8 +429,8 @@ def removeWorkerRefs (self : References) (name : Name) : References :=
   { self with workers := self.workers.erase name }
 
 /-- Yields a map from all modules to all of their references. -/
-def allRefs (self : References) : Std.HashMap Name Lsp.ModuleRefs :=
-  let ileanRefs := self.ileans.toArray.foldl (init := Std.HashMap.empty) fun m (name, _, refs) => m.insert name refs
+def allRefs (self : References) : HashMap Name Lsp.ModuleRefs :=
+  let ileanRefs := self.ileans.toArray.foldl (init := HashMap.empty) fun m (name, _, refs) => m.insert name refs
   self.workers.toArray.foldl (init := ileanRefs) fun m (name, _, refs) => m.insert name refs
 
 /--
@@ -445,12 +445,12 @@ def allRefsFor
   let refsToCheck := match ident with
    | RefIdent.const .. => self.allRefs.toArray
    | RefIdent.fvar identModule .. =>
-      match self.allRefs[identModule]? with
+      match self.allRefs.find? identModule with
      | none => #[]
      | some refs => #[(identModule, refs)]
   let mut result := #[]
   for (module, refs) in refsToCheck do
-    let some info := refs.get? ident
+    let some info := refs.find? ident
      | continue
    let some path ← srcSearchPath.findModuleWithExt "lean" module
      | continue
@@ -462,13 +462,13 @@ def allRefsFor
 
 /-- Yields all references in `module` at `pos`. -/
 def findAt (self : References) (module : Name) (pos : Lsp.Position) (includeStop := false) : Array RefIdent := Id.run do
-  if let some refs := self.allRefs[module]? then
+  if let some refs := self.allRefs.find? module then
    return refs.findAt pos includeStop
   #[]
 
 /-- Yields the first reference in `module` at `pos`. -/
 def findRange? (self : References) (module : Name) (pos : Lsp.Position) (includeStop := false) : Option Range := do
-  let refs ← self.allRefs[module]?
+  let refs ← self.allRefs.find? module
   refs.findRange? pos includeStop
 
 /-- Location and parent declaration of a reference. -/
@@ -90,10 +90,6 @@ section Utils
    | crashed (e : IO.Error)
    | ioError (e : IO.Error)
 
-  inductive CrashOrigin
-    | fileWorkerToClientForwarding
-    | clientToFileWorkerForwarding
-
   inductive WorkerState where
    /-- The watchdog can detect a crashed file worker in two places: When trying to send a message
        to the file worker and when reading a request reply.
@@ -102,7 +98,7 @@ section Utils
        that are in-flight are errored. Upon receiving the next packet for that file worker, the file
        worker is restarted and the packet is forwarded to it. If the crash was detected while writing
        a packet, we queue that packet until the next packet for the file worker arrives. -/
-    | crashed (queuedMsgs : Array JsonRpc.Message) (origin : CrashOrigin)
+    | crashed (queuedMsgs : Array JsonRpc.Message)
    | running
 
   abbrev PendingRequestMap := RBMap RequestID JsonRpc.Message compare
@@ -140,11 +136,6 @@ section FileWorker
      for ⟨id, _⟩ in pendingRequests do
        hError.writeLspResponseError { id := id, code := code, message := msg }
 
-    def queuedMsgs (fw : FileWorker) : Array JsonRpc.Message :=
-      match fw.state with
-      | .running => #[]
-      | .crashed queuedMsgs _ => queuedMsgs
-
   end FileWorker
 end FileWorker
@@ -413,23 +404,10 @@ section ServerM
        return
      eraseFileWorker uri
 
-  def handleCrash (uri : DocumentUri) (queuedMsgs : Array JsonRpc.Message) (origin: CrashOrigin) : ServerM Unit := do
+  def handleCrash (uri : DocumentUri) (queuedMsgs : Array JsonRpc.Message) : ServerM Unit := do
    let some fw ← findFileWorker? uri
      | return
-    updateFileWorkers { fw with state := WorkerState.crashed queuedMsgs origin }
-
-  def tryDischargeQueuedMessages (uri : DocumentUri) (queuedMsgs : Array JsonRpc.Message) : ServerM Unit := do
-    let some fw ← findFileWorker? uri
-      | throwServerError "Cannot find file worker for '{uri}'."
-    let mut crashedMsgs := #[]
-    -- Try to discharge all queued msgs, tracking the ones that we can't discharge
-    for msg in queuedMsgs do
-      try
-        fw.stdin.writeLspMessage msg
-      catch _ =>
-        crashedMsgs := crashedMsgs.push msg
-    if ¬ crashedMsgs.isEmpty then
-      handleCrash uri crashedMsgs .clientToFileWorkerForwarding
+    updateFileWorkers { fw with state := WorkerState.crashed queuedMsgs }
 
 /-- Tries to write a message, sets the state of the FileWorker to `crashed` if it does not succeed
 and restarts the file worker if the `crashed` flag was already set. Just logs an error if
@@ -445,7 +423,7 @@ section ServerM
    let some fw ← findFileWorker? uri
      | return
    match fw.state with
-    | WorkerState.crashed queuedMsgs _ =>
+    | WorkerState.crashed queuedMsgs =>
      let mut queuedMsgs := queuedMsgs
      if queueFailedMessage then
        queuedMsgs := queuedMsgs.push msg
@@ -454,7 +432,17 @@ section ServerM
      -- restart the crashed FileWorker
      eraseFileWorker uri
      startFileWorker fw.doc
-      tryDischargeQueuedMessages uri queuedMsgs
+      let some newFw ← findFileWorker? uri
+        | throwServerError "Cannot find file worker for '{uri}'."
+      let mut crashedMsgs := #[]
+      -- try to discharge all queued msgs, tracking the ones that we can't discharge
+      for msg in queuedMsgs do
+        try
+          newFw.stdin.writeLspMessage msg
+        catch _ =>
+          crashedMsgs := crashedMsgs.push msg
+      if ¬ crashedMsgs.isEmpty then
+        handleCrash uri crashedMsgs
    | WorkerState.running =>
      let initialQueuedMsgs :=
        if queueFailedMessage then
@@ -464,7 +452,7 @@ section ServerM
      try
        fw.stdin.writeLspMessage msg
      catch _ =>
-        handleCrash uri initialQueuedMsgs .clientToFileWorkerForwarding
+        handleCrash uri initialQueuedMsgs
 
 /--
 Sends a notification to the file worker identified by `uri` that its dependency `staleDependency`
@@ -650,7 +638,7 @@ def handleCallHierarchyOutgoingCalls (p : CallHierarchyOutgoingCallsParams)
 
   let references ← (← read).references.get
 
-  let some refs := references.allRefs[module]?
+  let some refs := references.allRefs.find? module
    | return #[]
 
   let items ← refs.toArray.filterMapM fun ⟨ident, info⟩ => do
@@ -714,9 +702,9 @@ def handlePrepareRename (p : PrepareRenameParams) : ServerM (Option Range) := do
 def handleRename (p : RenameParams) : ServerM Lsp.WorkspaceEdit := do
   if (String.toName p.newName).isAnonymous then
    throwServerError s!"Can't rename: `{p.newName}` is not an identifier"
-  let mut refs : Std.HashMap DocumentUri (RBMap Lsp.Position Lsp.Position compare) := ∅
+  let mut refs : HashMap DocumentUri (RBMap Lsp.Position Lsp.Position compare) := ∅
   for { uri, range } in (← handleReference { p with context.includeDeclaration := true }) do
-    refs := refs.insert uri <| (refs.getD uri ∅).insert range.start range.end
+    refs := refs.insert uri <| (refs.findD uri ∅).insert range.start range.end
   -- We have to filter the list of changes to put the ranges in order and
   -- remove any duplicates or overlapping ranges, or else the rename will not apply
   let changes := refs.fold (init := ∅) fun changes uri map => Id.run do
@@ -967,16 +955,7 @@ section MainLoop
    let workers ← st.fileWorkersRef.get
    let mut workerTasks := #[]
    for (_, fw) in workers do
-      -- When the forwarding task crashes, its return value will be stuck at
-      -- `WorkerEvent.crashed _`.
-      -- We want to handle this event only once, not over and over again,
-      -- so once the state becomes `WorkerState.crashed _ .fileWorkerToClientForwarding`
-      -- as a result of `WorkerEvent.crashed _`, we stop handling this event until
-      -- eventually the file worker is restarted by a notification from the client.
-      -- We do not want to filter the forwarding task in case of
-      -- `WorkerState.crashed _ .clientToFileWorkerForwarding`, since the forwarding task
-      -- exit code may still contain valuable information in this case (e.g. that the imports changed).
-      if !(fw.state matches WorkerState.crashed _ .fileWorkerToClientForwarding) then
+      if let WorkerState.running := fw.state then
        workerTasks := workerTasks.push <| fw.commTask.map (ServerEvent.workerEvent fw)
 
    let ev ← IO.waitAny (clientTask :: workerTasks.toList)
@@ -1005,16 +984,13 @@ section MainLoop
      | WorkerEvent.ioError e =>
        throwServerError s!"IO error while processing events for {fw.doc.uri}: {e}"
      | WorkerEvent.crashed _ =>
-        handleCrash fw.doc.uri fw.queuedMsgs .fileWorkerToClientForwarding
+        handleCrash fw.doc.uri #[]
        mainLoop clientTask
      | WorkerEvent.terminated =>
        throwServerError <| "Internal server error: got termination event for worker that "
          ++ "should have been removed"
      | .importsChanged =>
        let uri := fw.doc.uri
-        let queuedMsgs := fw.queuedMsgs
        startFileWorker fw.doc
-        tryDischargeQueuedMessages uri queuedMsgs
        mainLoop clientTask
 end MainLoop

@@ -6,7 +6,6 @@ Authors: David Thrane Christiansen
 prelude
 import Init.Data
 import Lean.Data.HashMap
-import Std.Data.HashMap.Basic
 import Init.Omega
 
 namespace Lean.Diff
@@ -58,7 +57,7 @@ structure Histogram.Entry (α : Type u) (lsize rsize : Nat) where
 
 /-- A histogram for arrays maps each element to a count and, if applicable, an index.-/
 def Histogram (α : Type u) (lsize rsize : Nat) [BEq α] [Hashable α] :=
-  Std.HashMap α (Histogram.Entry α lsize rsize)
+  Lean.HashMap α (Histogram.Entry α lsize rsize)
 
 
 section
@@ -68,7 +67,7 @@ variable [BEq α] [Hashable α]
|
||||
/-- Add an element from the left array to a histogram -/
|
||||
def Histogram.addLeft (histogram : Histogram α lsize rsize) (index : Fin lsize) (val : α)
|
||||
: Histogram α lsize rsize :=
|
||||
match histogram.get? val with
|
||||
match histogram.find? val with
|
||||
| none => histogram.insert val {
|
||||
leftCount := 1, leftIndex := some index,
|
||||
leftWF := by simp,
|
||||
@@ -82,7 +81,7 @@ def Histogram.addLeft (histogram : Histogram α lsize rsize) (index : Fin lsize)
|
||||
/-- Add an element from the right array to a histogram -/
|
||||
def Histogram.addRight (histogram : Histogram α lsize rsize) (index : Fin rsize) (val : α)
|
||||
: Histogram α lsize rsize :=
|
||||
match histogram.get? val with
|
||||
match histogram.find? val with
|
||||
| none => histogram.insert val {
|
||||
leftCount := 0, leftIndex := none,
|
||||
leftWF := by simp,
|
||||
|
||||
@@ -28,7 +28,7 @@ structure State where
|
||||
Set of visited subterms that satisfy the predicate `p`.
|
||||
We have to use this set to make sure `f` is applied at most once of each subterm that satisfies `p`.
|
||||
-/
|
||||
checked : Std.HashSet Expr
|
||||
checked : HashSet Expr
|
||||
|
||||
unsafe def initCache : State := {
|
||||
visited := mkArray cacheSize.toNat (cast lcProof ())
|
||||
|
||||
@@ -5,15 +5,14 @@ Authors: Leonardo de Moura
-/
prelude
import Lean.Expr
import Std.Data.HashMap.Raw

namespace Lean

structure HasConstCache (declNames : Array Name) where
cache : Std.HashMap.Raw Expr Bool := Std.HashMap.Raw.empty
cache : HashMapImp Expr Bool := mkHashMapImp

unsafe def HasConstCache.containsUnsafe (e : Expr) : StateM (HasConstCache declNames) Bool := do
if let some r := (← get).cache.get? (beq := ⟨ptrEq⟩) e then
if let some r := (← get).cache.find? (beq := ⟨ptrEq⟩) e then
return r
else
match e with
@@ -27,7 +26,7 @@ unsafe def HasConstCache.containsUnsafe (e : Expr) : StateM (HasConstCache declN
| _ => return false
where
cache (e : Expr) (r : Bool) : StateM (HasConstCache declNames) Bool := do
modify fun ⟨cache⟩ => ⟨cache.insert (beq := ⟨ptrEq⟩) e r⟩
modify fun ⟨cache⟩ => ⟨cache.insert (beq := ⟨ptrEq⟩) e r |>.1⟩
return r

/--

@@ -6,7 +6,6 @@ Authors: Leonardo de Moura
prelude
import Init.Control.StateRef
import Lean.Data.HashMap
import Std.Data.HashMap.Basic

namespace Lean
/-- Interface for caching results. -/
@@ -37,15 +36,15 @@ instance {α β ε : Type} {m : Type → Type} [MonadCache α β m] [Monad m] :
/-- Adapter for implementing `MonadCache` interface using `HashMap`s.
We just have to specify how to extract/modify the `HashMap`. -/
class MonadHashMapCacheAdapter (α β : Type) (m : Type → Type) [BEq α] [Hashable α] where
getCache : m (Std.HashMap α β)
modifyCache : (Std.HashMap α β → Std.HashMap α β) → m Unit
getCache : m (HashMap α β)
modifyCache : (HashMap α β → HashMap α β) → m Unit

namespace MonadHashMapCacheAdapter

@[always_inline, inline]
def findCached? {α β : Type} {m : Type → Type} [BEq α] [Hashable α] [Monad m] [MonadHashMapCacheAdapter α β m] (a : α) : m (Option β) := do
let c ← getCache
pure (c.get? a)
pure (c.find? a)

@[always_inline, inline]
def cache {α β : Type} {m : Type → Type} [BEq α] [Hashable α] [MonadHashMapCacheAdapter α β m] (a : α) (b : β) : m Unit :=
@@ -57,7 +56,7 @@ instance {α β : Type} {m : Type → Type} [BEq α] [Hashable α] [Monad m] [Mo

end MonadHashMapCacheAdapter

def MonadCacheT {ω} (α β : Type) (m : Type → Type) [STWorld ω m] [BEq α] [Hashable α] := StateRefT (Std.HashMap α β) m
def MonadCacheT {ω} (α β : Type) (m : Type → Type) [STWorld ω m] [BEq α] [Hashable α] := StateRefT (HashMap α β) m

namespace MonadCacheT

@@ -68,7 +67,7 @@ instance : MonadHashMapCacheAdapter α β (MonadCacheT α β m) where
modifyCache f := (modify f : StateRefT' ..)

@[inline] def run {σ} (x : MonadCacheT α β m σ) : m σ :=
x.run' Std.HashMap.empty
x.run' mkHashMap

instance : Monad (MonadCacheT α β m) := inferInstanceAs (Monad (StateRefT' _ _ _))
instance : MonadLift m (MonadCacheT α β m) := inferInstanceAs (MonadLift m (StateRefT' _ _ _))
@@ -81,7 +80,7 @@ instance [Alternative m] : Alternative (MonadCacheT α β m) := inferInstanceAs
end MonadCacheT

/- Similar to `MonadCacheT`, but using `StateT` instead of `StateRefT` -/
def MonadStateCacheT (α β : Type) (m : Type → Type) [BEq α] [Hashable α] := StateT (Std.HashMap α β) m
def MonadStateCacheT (α β : Type) (m : Type → Type) [BEq α] [Hashable α] := StateT (HashMap α β) m

namespace MonadStateCacheT

@@ -92,7 +91,7 @@ instance : MonadHashMapCacheAdapter α β (MonadStateCacheT α β m) where
modifyCache f := (modify f : StateT ..)

@[always_inline, inline] def run {σ} (x : MonadStateCacheT α β m σ) : m σ :=
x.run' Std.HashMap.empty
x.run' mkHashMap

instance : Monad (MonadStateCacheT α β m) := inferInstanceAs (Monad (StateT _ _))
instance : MonadLift m (MonadStateCacheT α β m) := inferInstanceAs (MonadLift m (StateT _ _))

@@ -96,9 +96,9 @@ deriving FromJson, ToJson

/-- Thread with maps necessary for computing max sharing indices -/
structure ThreadWithMaps extends Thread where
stringMap : Std.HashMap String Nat := {}
funcMap : Std.HashMap Nat Nat := {}
stackMap : Std.HashMap (Nat × Option Nat) Nat := {}
stringMap : HashMap String Nat := {}
funcMap : HashMap Nat Nat := {}
stackMap : HashMap (Nat × Option Nat) Nat := {}
/-- Last timestamp encountered: stop time of preceding sibling, or else start time of parent. -/
lastTime : Float := 0

@@ -123,7 +123,7 @@ where
if pp then
funcName := s!"{funcName}: {← msg.format}"
let strIdx ← modifyGet fun thread =>
if let some idx := thread.stringMap[funcName]? then
if let some idx := thread.stringMap.find? funcName then
(idx, thread)
else
(thread.stringMap.size, { thread with
@@ -131,7 +131,7 @@ where
stringMap := thread.stringMap.insert funcName thread.stringMap.size })
let category := categories.findIdx? (·.name == data.cls.getRoot.toString) |>.getD 0
let funcIdx ← modifyGet fun thread =>
if let some idx := thread.funcMap[strIdx]? then
if let some idx := thread.funcMap.find? strIdx then
(idx, thread)
else
(thread.funcMap.size, { thread with
@@ -151,7 +151,7 @@ where
funcMap := thread.funcMap.insert strIdx thread.funcMap.size })
let frameIdx := funcIdx
let stackIdx ← modifyGet fun thread =>
if let some idx := thread.stackMap[(frameIdx, parentStackIdx?)]? then
if let some idx := thread.stackMap.find? (frameIdx, parentStackIdx?) then
(idx, thread)
else
(thread.stackMap.size, { thread with
@@ -222,7 +222,7 @@ def Profile.export (name : String) (startTime : Milliseconds) (traceState : Trac

structure ThreadWithCollideMaps extends ThreadWithMaps where
/-- Max sharing map for samples -/
sampleMap : Std.HashMap Nat Nat := {}
sampleMap : HashMap Nat Nat := {}

/--
Adds samples from `add` to `thread`, increasing the weight of existing samples with identical stacks
@@ -237,7 +237,7 @@ where
let oldStackIdx := add.samples.stack[oldSampleIdx]!
let stackIdx ← collideStacks oldStackIdx
modify fun thread =>
if let some idx := thread.sampleMap[stackIdx]? then
if let some idx := thread.sampleMap.find? stackIdx then
-- imperative to preserve linear use of arrays here!
let ⟨⟨⟨t1, t2, t3, samples, t5, t6, t7, t8, t9, t10⟩, o2, o3, o4, o5⟩, o6⟩ := thread
let ⟨s1, s2, weight, s3, s4⟩ := samples
@@ -265,7 +265,7 @@ where
let oldStrIdx := add.funcTable.name[oldFuncIdx]!
let strIdx ← getStrIdx add.stringArray[oldStrIdx]!
let funcIdx ← modifyGet fun thread =>
if let some idx := thread.funcMap[strIdx]? then
if let some idx := thread.funcMap.find? strIdx then
(idx, thread)
else
(thread.funcMap.size, { thread with
@@ -284,7 +284,7 @@ where
funcMap := thread.funcMap.insert strIdx thread.funcMap.size })
let frameIdx := funcIdx
modifyGet fun thread =>
if let some idx := thread.stackMap[(frameIdx, parentStackIdx?)]? then
if let some idx := thread.stackMap.find? (frameIdx, parentStackIdx?) then
(idx, thread)
else
(thread.stackMap.size,
@@ -302,7 +302,7 @@ where
⟨⟨⟨t1,t2, t3, t4, t5, stackTable, t7, t8, t9, t10⟩, o2, o3, stackMap, o5⟩, o6⟩)
getStrIdx (s : String) :=
modifyGet fun thread =>
if let some idx := thread.stringMap[s]? then
if let some idx := thread.stringMap.find? s then
(idx, thread)
else
(thread.stringMap.size, { thread with

@@ -7,8 +7,6 @@ prelude
import Init.Data.Hashable
import Lean.Data.HashSet
import Lean.Data.HashMap
import Std.Data.HashSet.Basic
import Std.Data.HashMap.Basic

namespace Lean

@@ -25,33 +23,33 @@ unsafe instance : BEq (Ptr α) where
Set of pointers. It is a low-level auxiliary datastructure used for traversing DAGs.
-/
unsafe def PtrSet (α : Type) :=
Std.HashSet (Ptr α)
HashSet (Ptr α)

unsafe def mkPtrSet {α : Type} (capacity : Nat := 64) : PtrSet α :=
Std.HashSet.empty capacity
mkHashSet capacity

unsafe abbrev PtrSet.insert (s : PtrSet α) (a : α) : PtrSet α :=
Std.HashSet.insert s { value := a }
HashSet.insert s { value := a }

unsafe abbrev PtrSet.contains (s : PtrSet α) (a : α) : Bool :=
Std.HashSet.contains s { value := a }
HashSet.contains s { value := a }

/--
Map of pointers. It is a low-level auxiliary datastructure used for traversing DAGs.
-/
unsafe def PtrMap (α : Type) (β : Type) :=
Std.HashMap (Ptr α) β
HashMap (Ptr α) β

unsafe def mkPtrMap {α β : Type} (capacity : Nat := 64) : PtrMap α β :=
Std.HashMap.empty capacity
mkHashMap capacity

unsafe abbrev PtrMap.insert (s : PtrMap α β) (a : α) (b : β) : PtrMap α β :=
Std.HashMap.insert s { value := a } b
HashMap.insert s { value := a } b

unsafe abbrev PtrMap.contains (s : PtrMap α β) (a : α) : Bool :=
Std.HashMap.contains s { value := a }
HashMap.contains s { value := a }

unsafe abbrev PtrMap.find? (s : PtrMap α β) (a : α) : Option β :=
Std.HashMap.get? s { value := a }
HashMap.find? s { value := a }

end Lean

@@ -6,7 +6,6 @@ Authors: Leonardo de Moura
prelude
import Init.Data.List.Control
import Lean.Data.HashMap
import Std.Data.HashMap.Basic
namespace Lean.SCC
/-!
Very simple implementation of Tarjan's SCC algorithm.
@@ -26,7 +25,7 @@ structure Data where
structure State where
stack : List α := []
nextIndex : Nat := 0
data : Std.HashMap α Data := {}
data : HashMap α Data := {}
sccs : List (List α) := []

abbrev M := StateM (State α)
@@ -36,7 +35,7 @@ variable {α : Type} [BEq α] [Hashable α]

private def getDataOf (a : α) : M α Data := do
let s ← get
match s.data[a]? with
match s.data.find? a with
| some d => pure d
| none => pure {}

@@ -53,7 +52,7 @@ private def push (a : α) : M α Unit :=

private def modifyDataOf (a : α) (f : Data → Data) : M α Unit :=
modify fun s => { s with
data := match s.data[a]? with
data := match s.data.find? a with
| none => s.data
| some d => s.data.insert a (f d)
}

@@ -13,7 +13,6 @@ import Lean.Data.PersistentHashSet
open ShareCommon
namespace Lean.ShareCommon

set_option linter.deprecated false in
def objectFactory :=
StateFactory.mk {
Map := HashMap, mkMap := (mkHashMap ·), mapFind? := (·.find?), mapInsert := (·.insert)

@@ -67,7 +67,7 @@ structure TraceState where
traces : PersistentArray TraceElem := {}
deriving Inhabited

builtin_initialize inheritedTraceOptions : IO.Ref (Std.HashSet Name) ← IO.mkRef ∅
builtin_initialize inheritedTraceOptions : IO.Ref (HashSet Name) ← IO.mkRef ∅

class MonadTrace (m : Type → Type) where
modifyTraceState : (TraceState → TraceState) → m Unit
@@ -88,7 +88,7 @@ def printTraces : m Unit := do
def resetTraceState : m Unit :=
modifyTraceState (fun _ => {})

private def checkTraceOption (inherited : Std.HashSet Name) (opts : Options) (cls : Name) : Bool :=
private def checkTraceOption (inherited : HashSet Name) (opts : Options) (cls : Name) : Bool :=
!opts.isEmpty && go (`trace ++ cls)
where
go (opt : Name) : Bool :=

@@ -5,4 +5,3 @@ Authors: Sebastian Ullrich
-/
prelude
import Std.Data
import Std.Sat

@@ -20,7 +20,7 @@ open Std.DHashMap.Internal.List

universe u v

variable {α : Type u} {β : α → Type v}
variable {α : Type u} {β : α → Type v} [BEq α] [Hashable α]

namespace Std.DHashMap.Internal

@@ -41,8 +41,6 @@ theorem Raw.buckets_emptyc {i : Nat} {h} :
(∅ : Raw α β).buckets[i]'h = AssocList.nil :=
buckets_empty

variable [BEq α] [Hashable α]

@[simp]
theorem buckets_empty {c} {i : Nat} {h} :
(empty c : DHashMap α β).1.buckets[i]'h = AssocList.nil := by
@@ -57,9 +55,7 @@ end empty

namespace Raw₀

variable [BEq α] [Hashable α]
variable (m : Raw₀ α β) (h : m.1.WF)
set_option deprecated.oldSectionVars true

/-- Internal implementation detail of the hash map -/
scoped macro "wf_trivial" : tactic => `(tactic|

@@ -75,7 +75,6 @@ namespace Raw
open Internal.Raw₀ Internal.Raw

variable {m : Raw α β} (h : m.WF)
set_option deprecated.oldSectionVars true

@[simp]
theorem isEmpty_empty {c} : (empty c : Raw α β).isEmpty := by

@@ -112,10 +112,6 @@ Tries to retrieve the mapping for the given key, returning `none` if no such map
@[inline] def get? (m : HashMap α β) (a : α) : Option β :=
DHashMap.Const.get? m.inner a

@[deprecated get? "Use `m[a]?` or `m.get? a` instead", inherit_doc get?]
def find? (m : HashMap α β) (a : α) : Option β :=
m.get? a

@[inline, inherit_doc DHashMap.contains] def contains (m : HashMap α β)
(a : α) : Bool :=
m.inner.contains a
@@ -139,10 +135,6 @@ Retrieves the mapping for the given key. Ensures that such a mapping exists by r
(fallback : β) : β :=
DHashMap.Const.getD m.inner a fallback

@[deprecated getD, inherit_doc getD]
def findD (m : HashMap α β) (a : α) (fallback : β) : β :=
m.getD a fallback

/--
The notation `m[a]!` is preferred over calling this function directly.

@@ -151,10 +143,6 @@ Tries to retrieve the mapping for the given key, panicking if no such mapping is
@[inline] def get! [Inhabited β] (m : HashMap α β) (a : α) : β :=
DHashMap.Const.get! m.inner a

@[deprecated get! "Use `m[a]!` or `m.get! a` instead", inherit_doc get!]
def find! [Inhabited β] (m : HashMap α β) (a : α) : Option β :=
m.get! a

instance [BEq α] [Hashable α] : GetElem? (HashMap α β) α β (fun m a => a ∈ m) where
getElem m a h := m.get a h
getElem? m a := m.get? a
@@ -248,16 +236,3 @@ instance [BEq α] [Hashable α] [Repr α] [Repr β] : Repr (HashMap α β) where
end Unverified

end Std.HashMap

/--
Groups all elements `x`, `y` in `xs` with `key x == key y` into the same array
`(xs.groupByKey key).find! (key x)`. Groups preserve the relative order of elements in `xs`.
-/
def Array.groupByKey [BEq α] [Hashable α] (key : β → α) (xs : Array β)
: Std.HashMap α (Array β) := Id.run do
let mut groups := ∅
for x in xs do
let group := groups.getD (key x) #[]
groups := groups.erase (key x) -- make `group` referentially unique
groups := groups.insert (key x) (group.push x)
return groups

Some files were not shown because too many files have changed in this diff.