mirror of
https://github.com/leanprover/lean4.git
synced 2026-03-27 23:34:11 +00:00
Compare commits
36 Commits
betaLetRec
...
hbv/readFi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
118709ee6c | ||
|
|
c2575107f2 | ||
|
|
4db828d885 | ||
|
|
3f16f339e7 | ||
|
|
2771296ca5 | ||
|
|
5e337872ce | ||
|
|
329fa6309b | ||
|
|
370488a9ff | ||
|
|
df38da8e09 | ||
|
|
a2b93d6c18 | ||
|
|
63c4de5fea | ||
|
|
3b14642c42 | ||
|
|
d52da36e68 | ||
|
|
bf82965eec | ||
|
|
4bac74c4ac | ||
|
|
8b9d27de31 | ||
|
|
d15f0335a9 | ||
|
|
240ebff549 | ||
|
|
a29bca7f00 | ||
|
|
313f6b3c74 | ||
|
|
43fa46412d | ||
|
|
234704e304 | ||
|
|
12a714a6f9 | ||
|
|
cdc7ed0224 | ||
|
|
217abdf97a | ||
|
|
490a2b4bf9 | ||
|
|
84d45deb10 | ||
|
|
f46d216e18 | ||
|
|
cc42a17931 | ||
|
|
e106be19dd | ||
|
|
1efd6657d4 | ||
|
|
473b34561d | ||
|
|
574066b30b | ||
|
|
1e6d617aad | ||
|
|
c17a4ddc94 | ||
|
|
5be4f5e30c |
30
.github/workflows/rebase-on-comment.yml
vendored
Normal file
30
.github/workflows/rebase-on-comment.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
# As the PR author, use `!rebase` to rebase a commit
|
||||
name: Rebase on Comment
|
||||
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
|
||||
jobs:
|
||||
rebase:
|
||||
if: github.event.issue.pull_request != '' && github.event.comment.body == '!rebase' && github.event.comment.user.login == github.event.issue.user.login
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: refs/pull/${{ github.event.issue.number }}/head
|
||||
- name: Rebase PR branch onto base branch
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
PR_NUMBER="${{ github.event.issue.number }}"
|
||||
API_URL="https://api.github.com/repos/${{ github.repository }}/pulls/$PR_NUMBER"
|
||||
PR_DETAILS="$(curl -s -H "Authorization: token $GITHUB_TOKEN" $API_URL)"
|
||||
|
||||
BASE_REF="$(echo $PR_DETAILS | jq -r .base.ref)"
|
||||
|
||||
git checkout -b working-branch
|
||||
git fetch origin $BASE_REF
|
||||
git rebase origin/$BASE_REF
|
||||
git push origin refs/pull/${{ github.event.issue.number }}/head --force-with-lease
|
||||
@@ -152,22 +152,26 @@ We'll use `v4.7.0-rc1` as the intended release version in this example.
|
||||
This will add a list of all the commits since the last stable version.
|
||||
- Delete "update stage0" commits, and anything with a completely inscrutable commit message.
|
||||
- Next, we will move a curated list of downstream repos to the release candidate.
|
||||
- This assumes that there is already a *reviewed* branch `bump/v4.7.0` on each repository
|
||||
containing the required adaptations (or no adaptations are required).
|
||||
The preparation of this branch is beyond the scope of this document.
|
||||
- This assumes that for each repository either:
|
||||
* There is already a *reviewed* branch `bump/v4.7.0` containing the required adaptations.
|
||||
The preparation of this branch is beyond the scope of this document.
|
||||
* The repository does not need any changes to move to the new version.
|
||||
- For each of the target repositories:
|
||||
- Checkout the `bump/v4.7.0` branch.
|
||||
- Verify that the `lean-toolchain` is set to the nightly from which the release candidate was created.
|
||||
- `git merge origin/master`
|
||||
- Change the `lean-toolchain` to `leanprover/lean4:v4.7.0-rc1`
|
||||
- In `lakefile.lean`, change any dependencies which were using `nightly-testing` or `bump/v4.7.0` branches
|
||||
back to `master` or `main`, and run `lake update` for those dependencies.
|
||||
- Run `lake build` to ensure that dependencies are found (but it's okay to stop it after a moment).
|
||||
- `git commit`
|
||||
- `git push`
|
||||
- Open a PR from `bump/v4.7.0` to `master`, and either merge it yourself after CI, if appropriate,
|
||||
or notify the maintainers that it is ready to go.
|
||||
- Once this PR has been merged, tag `master` with `v4.7.0-rc1` and push this tag.
|
||||
- If the repository does not need any changes (i.e. `bump/v4.7.0` does not exist) then create
|
||||
a new PR updating `lean-toolchain` to `leanprover/lean4:v4.7.0-rc1` and running `lake update`.
|
||||
- Otherwise:
|
||||
- Checkout the `bump/v4.7.0` branch.
|
||||
- Verify that the `lean-toolchain` is set to the nightly from which the release candidate was created.
|
||||
- `git merge origin/master`
|
||||
- Change the `lean-toolchain` to `leanprover/lean4:v4.7.0-rc1`
|
||||
- In `lakefile.lean`, change any dependencies which were using `nightly-testing` or `bump/v4.7.0` branches
|
||||
back to `master` or `main`, and run `lake update` for those dependencies.
|
||||
- Run `lake build` to ensure that dependencies are found (but it's okay to stop it after a moment).
|
||||
- `git commit`
|
||||
- `git push`
|
||||
- Open a PR from `bump/v4.7.0` to `master`, and either merge it yourself after CI, if appropriate,
|
||||
or notify the maintainers that it is ready to go.
|
||||
- Once the PR has been merged, tag `master` with `v4.7.0-rc1` and push this tag.
|
||||
- We do this for the same list of repositories as for stable releases, see above.
|
||||
As above, there are dependencies between these, and so the process above is iterative.
|
||||
It greatly helps if you can merge the `bump/v4.7.0` PRs yourself!
|
||||
|
||||
@@ -7,6 +7,7 @@ prelude
|
||||
import Init.Data.Nat.MinMax
|
||||
import Init.Data.Nat.Lemmas
|
||||
import Init.Data.List.Monadic
|
||||
import Init.Data.List.Nat.Range
|
||||
import Init.Data.Fin.Basic
|
||||
import Init.Data.Array.Mem
|
||||
import Init.TacticsExtra
|
||||
@@ -336,6 +337,10 @@ theorem not_mem_nil (a : α) : ¬ a ∈ #[] := nofun
|
||||
|
||||
/-- # get lemmas -/
|
||||
|
||||
theorem lt_of_getElem {x : α} {a : Array α} {idx : Nat} {hidx : idx < a.size} (_ : a[idx] = x) :
|
||||
idx < a.size :=
|
||||
hidx
|
||||
|
||||
theorem getElem?_mem {l : Array α} {i : Fin l.size} : l[i] ∈ l := by
|
||||
erw [Array.mem_def, getElem_eq_data_getElem]
|
||||
apply List.get_mem
|
||||
@@ -505,6 +510,13 @@ theorem size_eq_length_data (as : Array α) : as.size = as.data.length := rfl
|
||||
simp only [mkEmpty_eq, size_push] at *
|
||||
omega
|
||||
|
||||
@[simp] theorem data_range (n : Nat) : (range n).data = List.range n := by
|
||||
induction n <;> simp_all [range, Nat.fold, flip, List.range_succ]
|
||||
|
||||
@[simp]
|
||||
theorem getElem_range {n : Nat} {x : Nat} (h : x < (Array.range n).size) : (Array.range n)[x] = x := by
|
||||
simp [getElem_eq_data_getElem]
|
||||
|
||||
set_option linter.deprecated false in
|
||||
@[simp] theorem reverse_data (a : Array α) : a.reverse.data = a.data.reverse := by
|
||||
let rec go (as : Array α) (i j hj)
|
||||
@@ -707,13 +719,22 @@ theorem mapIdx_spec (as : Array α) (f : Fin as.size → α → β)
|
||||
unfold modify modifyM Id.run
|
||||
split <;> simp
|
||||
|
||||
theorem get_modify {arr : Array α} {x i} (h : i < arr.size) :
|
||||
(arr.modify x f).get ⟨i, by simp [h]⟩ =
|
||||
if x = i then f (arr.get ⟨i, h⟩) else arr.get ⟨i, h⟩ := by
|
||||
simp [modify, modifyM, Id.run]; split
|
||||
· simp [get_set _ _ _ h]; split <;> simp [*]
|
||||
theorem getElem_modify {as : Array α} {x i} (h : i < as.size) :
|
||||
(as.modify x f)[i]'(by simp [h]) = if x = i then f as[i] else as[i] := by
|
||||
simp only [modify, modifyM, get_eq_getElem, Id.run, Id.pure_eq]
|
||||
split
|
||||
· simp only [Id.bind_eq, get_set _ _ _ h]; split <;> simp [*]
|
||||
· rw [if_neg (mt (by rintro rfl; exact h) ‹_›)]
|
||||
|
||||
theorem getElem_modify_self {as : Array α} {i : Nat} (h : i < as.size) (f : α → α) :
|
||||
(as.modify i f)[i]'(by simp [h]) = f as[i] := by
|
||||
simp [getElem_modify h]
|
||||
|
||||
theorem getElem_modify_of_ne {as : Array α} {i : Nat} (hj : j < as.size)
|
||||
(f : α → α) (h : i ≠ j) :
|
||||
(as.modify i f)[j]'(by rwa [size_modify]) = as[j] := by
|
||||
simp [getElem_modify hj, h]
|
||||
|
||||
/-! ### filter -/
|
||||
|
||||
@[simp] theorem filter_data (p : α → Bool) (l : Array α) :
|
||||
|
||||
@@ -42,8 +42,8 @@ Bitvectors have decidable equality. This should be used via the instance `Decida
|
||||
-- We manually derive the `DecidableEq` instances for `BitVec` because
|
||||
-- we want to have builtin support for bit-vector literals, and we
|
||||
-- need a name for this function to implement `canUnfoldAtMatcher` at `WHNF.lean`.
|
||||
def BitVec.decEq (a b : BitVec n) : Decidable (a = b) :=
|
||||
match a, b with
|
||||
def BitVec.decEq (x y : BitVec n) : Decidable (x = y) :=
|
||||
match x, y with
|
||||
| ⟨n⟩, ⟨m⟩ =>
|
||||
if h : n = m then
|
||||
isTrue (h ▸ rfl)
|
||||
@@ -69,9 +69,9 @@ protected def ofNat (n : Nat) (i : Nat) : BitVec n where
|
||||
instance instOfNat : OfNat (BitVec n) i where ofNat := .ofNat n i
|
||||
instance natCastInst : NatCast (BitVec w) := ⟨BitVec.ofNat w⟩
|
||||
|
||||
/-- Given a bitvector `a`, return the underlying `Nat`. This is O(1) because `BitVec` is a
|
||||
/-- Given a bitvector `x`, return the underlying `Nat`. This is O(1) because `BitVec` is a
|
||||
(zero-cost) wrapper around a `Nat`. -/
|
||||
protected def toNat (a : BitVec n) : Nat := a.toFin.val
|
||||
protected def toNat (x : BitVec n) : Nat := x.toFin.val
|
||||
|
||||
/-- Return the bound in terms of toNat. -/
|
||||
theorem isLt (x : BitVec w) : x.toNat < 2^w := x.toFin.isLt
|
||||
@@ -123,18 +123,18 @@ section getXsb
|
||||
@[inline] def getMsb (x : BitVec w) (i : Nat) : Bool := i < w && getLsb x (w-1-i)
|
||||
|
||||
/-- Return most-significant bit in bitvector. -/
|
||||
@[inline] protected def msb (a : BitVec n) : Bool := getMsb a 0
|
||||
@[inline] protected def msb (x : BitVec n) : Bool := getMsb x 0
|
||||
|
||||
end getXsb
|
||||
|
||||
section Int
|
||||
|
||||
/-- Interpret the bitvector as an integer stored in two's complement form. -/
|
||||
protected def toInt (a : BitVec n) : Int :=
|
||||
if 2 * a.toNat < 2^n then
|
||||
a.toNat
|
||||
protected def toInt (x : BitVec n) : Int :=
|
||||
if 2 * x.toNat < 2^n then
|
||||
x.toNat
|
||||
else
|
||||
(a.toNat : Int) - (2^n : Nat)
|
||||
(x.toNat : Int) - (2^n : Nat)
|
||||
|
||||
/-- The `BitVec` with value `(2^n + (i mod 2^n)) mod 2^n`. -/
|
||||
protected def ofInt (n : Nat) (i : Int) : BitVec n := .ofNatLt (i % (Int.ofNat (2^n))).toNat (by
|
||||
@@ -215,7 +215,7 @@ instance : Neg (BitVec n) := ⟨.neg⟩
|
||||
/--
|
||||
Return the absolute value of a signed bitvector.
|
||||
-/
|
||||
protected def abs (s : BitVec n) : BitVec n := if s.msb then .neg s else s
|
||||
protected def abs (x : BitVec n) : BitVec n := if x.msb then .neg x else x
|
||||
|
||||
/--
|
||||
Multiplication for bit vectors. This can be interpreted as either signed or unsigned negation
|
||||
@@ -262,12 +262,12 @@ sdiv 5#4 -2 = -2#4
|
||||
sdiv (-7#4) (-2) = 3#4
|
||||
```
|
||||
-/
|
||||
def sdiv (s t : BitVec n) : BitVec n :=
|
||||
match s.msb, t.msb with
|
||||
| false, false => udiv s t
|
||||
| false, true => .neg (udiv s (.neg t))
|
||||
| true, false => .neg (udiv (.neg s) t)
|
||||
| true, true => udiv (.neg s) (.neg t)
|
||||
def sdiv (x y : BitVec n) : BitVec n :=
|
||||
match x.msb, y.msb with
|
||||
| false, false => udiv x y
|
||||
| false, true => .neg (udiv x (.neg y))
|
||||
| true, false => .neg (udiv (.neg x) y)
|
||||
| true, true => udiv (.neg x) (.neg y)
|
||||
|
||||
/--
|
||||
Signed division for bit vectors using SMTLIB rules for division by zero.
|
||||
@@ -276,40 +276,40 @@ Specifically, `smtSDiv x 0 = if x >= 0 then -1 else 1`
|
||||
|
||||
SMT-Lib name: `bvsdiv`.
|
||||
-/
|
||||
def smtSDiv (s t : BitVec n) : BitVec n :=
|
||||
match s.msb, t.msb with
|
||||
| false, false => smtUDiv s t
|
||||
| false, true => .neg (smtUDiv s (.neg t))
|
||||
| true, false => .neg (smtUDiv (.neg s) t)
|
||||
| true, true => smtUDiv (.neg s) (.neg t)
|
||||
def smtSDiv (x y : BitVec n) : BitVec n :=
|
||||
match x.msb, y.msb with
|
||||
| false, false => smtUDiv x y
|
||||
| false, true => .neg (smtUDiv x (.neg y))
|
||||
| true, false => .neg (smtUDiv (.neg x) y)
|
||||
| true, true => smtUDiv (.neg x) (.neg y)
|
||||
|
||||
/--
|
||||
Remainder for signed division rounding to zero.
|
||||
|
||||
SMT_Lib name: `bvsrem`.
|
||||
-/
|
||||
def srem (s t : BitVec n) : BitVec n :=
|
||||
match s.msb, t.msb with
|
||||
| false, false => umod s t
|
||||
| false, true => umod s (.neg t)
|
||||
| true, false => .neg (umod (.neg s) t)
|
||||
| true, true => .neg (umod (.neg s) (.neg t))
|
||||
def srem (x y : BitVec n) : BitVec n :=
|
||||
match x.msb, y.msb with
|
||||
| false, false => umod x y
|
||||
| false, true => umod x (.neg y)
|
||||
| true, false => .neg (umod (.neg x) y)
|
||||
| true, true => .neg (umod (.neg x) (.neg y))
|
||||
|
||||
/--
|
||||
Remainder for signed division rounded to negative infinity.
|
||||
|
||||
SMT_Lib name: `bvsmod`.
|
||||
-/
|
||||
def smod (s t : BitVec m) : BitVec m :=
|
||||
match s.msb, t.msb with
|
||||
| false, false => umod s t
|
||||
def smod (x y : BitVec m) : BitVec m :=
|
||||
match x.msb, y.msb with
|
||||
| false, false => umod x y
|
||||
| false, true =>
|
||||
let u := umod s (.neg t)
|
||||
(if u = .zero m then u else .add u t)
|
||||
let u := umod x (.neg y)
|
||||
(if u = .zero m then u else .add u y)
|
||||
| true, false =>
|
||||
let u := umod (.neg s) t
|
||||
(if u = .zero m then u else .sub t u)
|
||||
| true, true => .neg (umod (.neg s) (.neg t))
|
||||
let u := umod (.neg x) y
|
||||
(if u = .zero m then u else .sub y u)
|
||||
| true, true => .neg (umod (.neg x) (.neg y))
|
||||
|
||||
end arithmetic
|
||||
|
||||
@@ -373,8 +373,8 @@ end relations
|
||||
|
||||
section cast
|
||||
|
||||
/-- `cast eq i` embeds `i` into an equal `BitVec` type. -/
|
||||
@[inline] def cast (eq : n = m) (i : BitVec n) : BitVec m := .ofNatLt i.toNat (eq ▸ i.isLt)
|
||||
/-- `cast eq x` embeds `x` into an equal `BitVec` type. -/
|
||||
@[inline] def cast (eq : n = m) (x : BitVec n) : BitVec m := .ofNatLt x.toNat (eq ▸ x.isLt)
|
||||
|
||||
@[simp] theorem cast_ofNat {n m : Nat} (h : n = m) (x : Nat) :
|
||||
cast h (BitVec.ofNat n x) = BitVec.ofNat m x := by
|
||||
@@ -391,7 +391,7 @@ Extraction of bits `start` to `start + len - 1` from a bit vector of size `n` to
|
||||
new bitvector of size `len`. If `start + len > n`, then the vector will be zero-padded in the
|
||||
high bits.
|
||||
-/
|
||||
def extractLsb' (start len : Nat) (a : BitVec n) : BitVec len := .ofNat _ (a.toNat >>> start)
|
||||
def extractLsb' (start len : Nat) (x : BitVec n) : BitVec len := .ofNat _ (x.toNat >>> start)
|
||||
|
||||
/--
|
||||
Extraction of bits `hi` (inclusive) down to `lo` (inclusive) from a bit vector of size `n` to
|
||||
@@ -399,12 +399,12 @@ yield a new bitvector of size `hi - lo + 1`.
|
||||
|
||||
SMT-Lib name: `extract`.
|
||||
-/
|
||||
def extractLsb (hi lo : Nat) (a : BitVec n) : BitVec (hi - lo + 1) := extractLsb' lo _ a
|
||||
def extractLsb (hi lo : Nat) (x : BitVec n) : BitVec (hi - lo + 1) := extractLsb' lo _ x
|
||||
|
||||
/--
|
||||
A version of `zeroExtend` that requires a proof, but is a noop.
|
||||
-/
|
||||
def zeroExtend' {n w : Nat} (le : n ≤ w) (x : BitVec n) : BitVec w :=
|
||||
def zeroExtend' {n w : Nat} (le : n ≤ w) (x : BitVec n) : BitVec w :=
|
||||
x.toNat#'(by
|
||||
apply Nat.lt_of_lt_of_le x.isLt
|
||||
exact Nat.pow_le_pow_of_le_right (by trivial) le)
|
||||
@@ -413,8 +413,8 @@ def zeroExtend' {n w : Nat} (le : n ≤ w) (x : BitVec n) : BitVec w :=
|
||||
`shiftLeftZeroExtend x n` returns `zeroExtend (w+n) x <<< n` without
|
||||
needing to compute `x % 2^(2+n)`.
|
||||
-/
|
||||
def shiftLeftZeroExtend (msbs : BitVec w) (m : Nat) : BitVec (w+m) :=
|
||||
let shiftLeftLt {x : Nat} (p : x < 2^w) (m : Nat) : x <<< m < 2^(w+m) := by
|
||||
def shiftLeftZeroExtend (msbs : BitVec w) (m : Nat) : BitVec (w + m) :=
|
||||
let shiftLeftLt {x : Nat} (p : x < 2^w) (m : Nat) : x <<< m < 2^(w + m) := by
|
||||
simp [Nat.shiftLeft_eq, Nat.pow_add]
|
||||
apply Nat.mul_lt_mul_of_pos_right p
|
||||
exact (Nat.two_pow_pos m)
|
||||
@@ -502,24 +502,24 @@ instance : Complement (BitVec w) := ⟨.not⟩
|
||||
|
||||
/--
|
||||
Left shift for bit vectors. The low bits are filled with zeros. As a numeric operation, this is
|
||||
equivalent to `a * 2^s`, modulo `2^n`.
|
||||
equivalent to `x * 2^s`, modulo `2^n`.
|
||||
|
||||
SMT-Lib name: `bvshl` except this operator uses a `Nat` shift value.
|
||||
-/
|
||||
protected def shiftLeft (a : BitVec n) (s : Nat) : BitVec n := BitVec.ofNat n (a.toNat <<< s)
|
||||
protected def shiftLeft (x : BitVec n) (s : Nat) : BitVec n := BitVec.ofNat n (x.toNat <<< s)
|
||||
instance : HShiftLeft (BitVec w) Nat (BitVec w) := ⟨.shiftLeft⟩
|
||||
|
||||
/--
|
||||
(Logical) right shift for bit vectors. The high bits are filled with zeros.
|
||||
As a numeric operation, this is equivalent to `a / 2^s`, rounding down.
|
||||
As a numeric operation, this is equivalent to `x / 2^s`, rounding down.
|
||||
|
||||
SMT-Lib name: `bvlshr` except this operator uses a `Nat` shift value.
|
||||
-/
|
||||
def ushiftRight (a : BitVec n) (s : Nat) : BitVec n :=
|
||||
(a.toNat >>> s)#'(by
|
||||
let ⟨a, lt⟩ := a
|
||||
def ushiftRight (x : BitVec n) (s : Nat) : BitVec n :=
|
||||
(x.toNat >>> s)#'(by
|
||||
let ⟨x, lt⟩ := x
|
||||
simp only [BitVec.toNat, Nat.shiftRight_eq_div_pow, Nat.div_lt_iff_lt_mul (Nat.two_pow_pos s)]
|
||||
rw [←Nat.mul_one a]
|
||||
rw [←Nat.mul_one x]
|
||||
exact Nat.mul_lt_mul_of_lt_of_le' lt (Nat.two_pow_pos s) (Nat.le_refl 1))
|
||||
|
||||
instance : HShiftRight (BitVec w) Nat (BitVec w) := ⟨.ushiftRight⟩
|
||||
@@ -527,15 +527,24 @@ instance : HShiftRight (BitVec w) Nat (BitVec w) := ⟨.ushiftRight⟩
|
||||
/--
|
||||
Arithmetic right shift for bit vectors. The high bits are filled with the
|
||||
most-significant bit.
|
||||
As a numeric operation, this is equivalent to `a.toInt >>> s`.
|
||||
As a numeric operation, this is equivalent to `x.toInt >>> s`.
|
||||
|
||||
SMT-Lib name: `bvashr` except this operator uses a `Nat` shift value.
|
||||
-/
|
||||
def sshiftRight (a : BitVec n) (s : Nat) : BitVec n := .ofInt n (a.toInt >>> s)
|
||||
def sshiftRight (x : BitVec n) (s : Nat) : BitVec n := .ofInt n (x.toInt >>> s)
|
||||
|
||||
instance {n} : HShiftLeft (BitVec m) (BitVec n) (BitVec m) := ⟨fun x y => x <<< y.toNat⟩
|
||||
instance {n} : HShiftRight (BitVec m) (BitVec n) (BitVec m) := ⟨fun x y => x >>> y.toNat⟩
|
||||
|
||||
/--
|
||||
Arithmetic right shift for bit vectors. The high bits are filled with the
|
||||
most-significant bit.
|
||||
As a numeric operation, this is equivalent to `a.toInt >>> s.toNat`.
|
||||
|
||||
SMT-Lib name: `bvashr`.
|
||||
-/
|
||||
def sshiftRight' (a : BitVec n) (s : BitVec m) : BitVec n := a.sshiftRight s.toNat
|
||||
|
||||
/-- Auxiliary function for `rotateLeft`, which does not take into account the case where
|
||||
the rotation amount is greater than the bitvector width. -/
|
||||
def rotateLeftAux (x : BitVec w) (n : Nat) : BitVec w :=
|
||||
|
||||
@@ -289,18 +289,18 @@ theorem sle_eq_carry (x y : BitVec w) :
|
||||
A recurrence that describes multiplication as repeated addition.
|
||||
Is useful for bitblasting multiplication.
|
||||
-/
|
||||
def mulRec (l r : BitVec w) (s : Nat) : BitVec w :=
|
||||
let cur := if r.getLsb s then (l <<< s) else 0
|
||||
def mulRec (x y : BitVec w) (s : Nat) : BitVec w :=
|
||||
let cur := if y.getLsb s then (x <<< s) else 0
|
||||
match s with
|
||||
| 0 => cur
|
||||
| s + 1 => mulRec l r s + cur
|
||||
| s + 1 => mulRec x y s + cur
|
||||
|
||||
theorem mulRec_zero_eq (l r : BitVec w) :
|
||||
mulRec l r 0 = if r.getLsb 0 then l else 0 := by
|
||||
theorem mulRec_zero_eq (x y : BitVec w) :
|
||||
mulRec x y 0 = if y.getLsb 0 then x else 0 := by
|
||||
simp [mulRec]
|
||||
|
||||
theorem mulRec_succ_eq (l r : BitVec w) (s : Nat) :
|
||||
mulRec l r (s + 1) = mulRec l r s + if r.getLsb (s + 1) then (l <<< (s + 1)) else 0 := rfl
|
||||
theorem mulRec_succ_eq (x y : BitVec w) (s : Nat) :
|
||||
mulRec x y (s + 1) = mulRec x y s + if y.getLsb (s + 1) then (x <<< (s + 1)) else 0 := rfl
|
||||
|
||||
/--
|
||||
Recurrence lemma: truncating to `i+1` bits and then zero extending to `w`
|
||||
@@ -326,29 +326,29 @@ theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow (x : BitVec w
|
||||
by_cases hi : x.getLsb i <;> simp [hi] <;> omega
|
||||
|
||||
/--
|
||||
Recurrence lemma: multiplying `l` with the first `s` bits of `r` is the
|
||||
same as truncating `r` to `s` bits, then zero extending to the original length,
|
||||
Recurrence lemma: multiplying `x` with the first `s` bits of `y` is the
|
||||
same as truncating `y` to `s` bits, then zero extending to the original length,
|
||||
and performing the multplication. -/
|
||||
theorem mulRec_eq_mul_signExtend_truncate (l r : BitVec w) (s : Nat) :
|
||||
mulRec l r s = l * ((r.truncate (s + 1)).zeroExtend w) := by
|
||||
theorem mulRec_eq_mul_signExtend_truncate (x y : BitVec w) (s : Nat) :
|
||||
mulRec x y s = x * ((y.truncate (s + 1)).zeroExtend w) := by
|
||||
induction s
|
||||
case zero =>
|
||||
simp only [mulRec_zero_eq, ofNat_eq_ofNat, Nat.reduceAdd]
|
||||
by_cases r.getLsb 0
|
||||
case pos hr =>
|
||||
simp only [hr, ↓reduceIte, truncate, zeroExtend_one_eq_ofBool_getLsb_zero,
|
||||
hr, ofBool_true, ofNat_eq_ofNat]
|
||||
by_cases y.getLsb 0
|
||||
case pos hy =>
|
||||
simp only [hy, ↓reduceIte, truncate, zeroExtend_one_eq_ofBool_getLsb_zero,
|
||||
ofBool_true, ofNat_eq_ofNat]
|
||||
rw [zeroExtend_ofNat_one_eq_ofNat_one_of_lt (by omega)]
|
||||
simp
|
||||
case neg hr =>
|
||||
simp [hr, zeroExtend_one_eq_ofBool_getLsb_zero]
|
||||
case neg hy =>
|
||||
simp [hy, zeroExtend_one_eq_ofBool_getLsb_zero]
|
||||
case succ s' hs =>
|
||||
rw [mulRec_succ_eq, hs]
|
||||
have heq :
|
||||
(if r.getLsb (s' + 1) = true then l <<< (s' + 1) else 0) =
|
||||
(l * (r &&& (BitVec.twoPow w (s' + 1)))) := by
|
||||
(if y.getLsb (s' + 1) = true then x <<< (s' + 1) else 0) =
|
||||
(x * (y &&& (BitVec.twoPow w (s' + 1)))) := by
|
||||
simp only [ofNat_eq_ofNat, and_twoPow]
|
||||
by_cases hr : r.getLsb (s' + 1) <;> simp [hr]
|
||||
by_cases hy : y.getLsb (s' + 1) <;> simp [hy]
|
||||
rw [heq, ← BitVec.mul_add, ← zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow]
|
||||
|
||||
theorem getLsb_mul (x y : BitVec w) (i : Nat) :
|
||||
@@ -429,6 +429,67 @@ theorem shiftLeft_eq_shiftLeftRec (x : BitVec w₁) (y : BitVec w₂) :
|
||||
· simp [of_length_zero]
|
||||
· simp [shiftLeftRec_eq]
|
||||
|
||||
/- ### Arithmetic shift right (sshiftRight) recurrence -/
|
||||
|
||||
/--
|
||||
`sshiftRightRec x y n` shifts `x` arithmetically/signed to the right by the first `n` bits of `y`.
|
||||
The theorem `sshiftRight_eq_sshiftRightRec` proves the equivalence of `(x.sshiftRight y)` and `sshiftRightRec`.
|
||||
Together with equations `sshiftRightRec_zero`, `sshiftRightRec_succ`,
|
||||
this allows us to unfold `sshiftRight` into a circuit for bitblasting.
|
||||
-/
|
||||
def sshiftRightRec (x : BitVec w₁) (y : BitVec w₂) (n : Nat) : BitVec w₁ :=
|
||||
let shiftAmt := (y &&& (twoPow w₂ n))
|
||||
match n with
|
||||
| 0 => x.sshiftRight' shiftAmt
|
||||
| n + 1 => (sshiftRightRec x y n).sshiftRight' shiftAmt
|
||||
|
||||
@[simp]
|
||||
theorem sshiftRightRec_zero_eq (x : BitVec w₁) (y : BitVec w₂) :
|
||||
sshiftRightRec x y 0 = x.sshiftRight' (y &&& 1#w₂) := by
|
||||
simp only [sshiftRightRec, twoPow_zero]
|
||||
|
||||
@[simp]
|
||||
theorem sshiftRightRec_succ_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) :
|
||||
sshiftRightRec x y (n + 1) = (sshiftRightRec x y n).sshiftRight' (y &&& twoPow w₂ (n + 1)) := by
|
||||
simp [sshiftRightRec]
|
||||
|
||||
/--
|
||||
If `y &&& z = 0`, `x.sshiftRight (y ||| z) = (x.sshiftRight y).sshiftRight z`.
|
||||
This follows as `y &&& z = 0` implies `y ||| z = y + z`,
|
||||
and thus `x.sshiftRight (y ||| z) = x.sshiftRight (y + z) = (x.sshiftRight y).sshiftRight z`.
|
||||
-/
|
||||
theorem sshiftRight'_or_of_and_eq_zero {x : BitVec w₁} {y z : BitVec w₂}
|
||||
(h : y &&& z = 0#w₂) :
|
||||
x.sshiftRight' (y ||| z) = (x.sshiftRight' y).sshiftRight' z := by
|
||||
simp [sshiftRight', ← add_eq_or_of_and_eq_zero _ _ h,
|
||||
toNat_add_of_and_eq_zero h, sshiftRight_add]
|
||||
|
||||
theorem sshiftRightRec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) :
|
||||
sshiftRightRec x y n = x.sshiftRight' ((y.truncate (n + 1)).zeroExtend w₂) := by
|
||||
induction n generalizing x y
|
||||
case zero =>
|
||||
ext i
|
||||
simp [twoPow_zero, Nat.reduceAdd, and_one_eq_zeroExtend_ofBool_getLsb, truncate_one]
|
||||
case succ n ih =>
|
||||
simp only [sshiftRightRec_succ_eq, and_twoPow, ih]
|
||||
by_cases h : y.getLsb (n + 1)
|
||||
· rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true h,
|
||||
sshiftRight'_or_of_and_eq_zero (by simp), h]
|
||||
simp
|
||||
· rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false (i := n + 1)
|
||||
(by simp [h])]
|
||||
simp [h]
|
||||
|
||||
/--
|
||||
Show that `x.sshiftRight y` can be written in terms of `sshiftRightRec`.
|
||||
This can be unfolded in terms of `sshiftRightRec_zero_eq`, `sshiftRightRec_succ_eq` for bitblasting.
|
||||
-/
|
||||
theorem sshiftRight_eq_sshiftRightRec (x : BitVec w₁) (y : BitVec w₂) :
|
||||
(x.sshiftRight' y).getLsb i = (sshiftRightRec x y (w₂ - 1)).getLsb i := by
|
||||
rcases w₂ with rfl | w₂
|
||||
· simp [of_length_zero]
|
||||
· simp [sshiftRightRec_eq]
|
||||
|
||||
/- ### Logical shift right (ushiftRight) recurrence for bitblasting -/
|
||||
|
||||
/--
|
||||
|
||||
@@ -23,7 +23,7 @@ theorem ofFin_eq_ofNat : @BitVec.ofFin w (Fin.mk x lt) = BitVec.ofNat w x := by
|
||||
simp only [BitVec.ofNat, Fin.ofNat', lt, Nat.mod_eq_of_lt]
|
||||
|
||||
/-- Prove equality of bitvectors in terms of nat operations. -/
|
||||
theorem eq_of_toNat_eq {n} : ∀ {i j : BitVec n}, i.toNat = j.toNat → i = j
|
||||
theorem eq_of_toNat_eq {n} : ∀ {x y : BitVec n}, x.toNat = y.toNat → x = y
|
||||
| ⟨_, _⟩, ⟨_, _⟩, rfl => rfl
|
||||
|
||||
@[simp] theorem val_toFin (x : BitVec w) : x.toFin.val = x.toNat := rfl
|
||||
@@ -228,12 +228,12 @@ theorem toNat_ge_of_msb_true {x : BitVec n} (p : BitVec.msb x = true) : x.toNat
|
||||
/-! ### toInt/ofInt -/
|
||||
|
||||
/-- Prove equality of bitvectors in terms of nat operations. -/
|
||||
theorem toInt_eq_toNat_cond (i : BitVec n) :
|
||||
i.toInt =
|
||||
if 2*i.toNat < 2^n then
|
||||
(i.toNat : Int)
|
||||
theorem toInt_eq_toNat_cond (x : BitVec n) :
|
||||
x.toInt =
|
||||
if 2*x.toNat < 2^n then
|
||||
(x.toNat : Int)
|
||||
else
|
||||
(i.toNat : Int) - (2^n : Nat) :=
|
||||
(x.toNat : Int) - (2^n : Nat) :=
|
||||
rfl
|
||||
|
||||
theorem msb_eq_false_iff_two_mul_lt (x : BitVec w) : x.msb = false ↔ 2 * x.toNat < 2^w := by
|
||||
@@ -260,13 +260,13 @@ theorem toInt_eq_toNat_bmod (x : BitVec n) : x.toInt = Int.bmod x.toNat (2^n) :=
|
||||
omega
|
||||
|
||||
/-- Prove equality of bitvectors in terms of nat operations. -/
|
||||
theorem eq_of_toInt_eq {i j : BitVec n} : i.toInt = j.toInt → i = j := by
|
||||
theorem eq_of_toInt_eq {x y : BitVec n} : x.toInt = y.toInt → x = y := by
|
||||
intro eq
|
||||
simp [toInt_eq_toNat_cond] at eq
|
||||
apply eq_of_toNat_eq
|
||||
revert eq
|
||||
have _ilt := i.isLt
|
||||
have _jlt := j.isLt
|
||||
have _xlt := x.isLt
|
||||
have _ylt := y.isLt
|
||||
split <;> split <;> omega
|
||||
|
||||
theorem toInt_inj (x y : BitVec n) : x.toInt = y.toInt ↔ x = y :=
|
||||
@@ -733,6 +733,21 @@ theorem getLsb_shiftLeft' {x : BitVec w₁} {y : BitVec w₂} {i : Nat} :
|
||||
getLsb (x >>> i) j = getLsb x (i+j) := by
|
||||
unfold getLsb ; simp
|
||||
|
||||
theorem ushiftRight_xor_distrib (x y : BitVec w) (n : Nat) :
|
||||
(x ^^^ y) >>> n = (x >>> n) ^^^ (y >>> n) := by
|
||||
ext
|
||||
simp
|
||||
|
||||
theorem ushiftRight_and_distrib (x y : BitVec w) (n : Nat) :
|
||||
(x &&& y) >>> n = (x >>> n) &&& (y >>> n) := by
|
||||
ext
|
||||
simp
|
||||
|
||||
theorem ushiftRight_or_distrib (x y : BitVec w) (n : Nat) :
|
||||
(x ||| y) >>> n = (x >>> n) ||| (y >>> n) := by
|
||||
ext
|
||||
simp
|
||||
|
||||
@[simp]
|
||||
theorem ushiftRight_zero_eq (x : BitVec w) : x >>> 0 = x := by
|
||||
simp [bv_toNat]
|
||||
@@ -786,7 +801,7 @@ theorem sshiftRight_eq_of_msb_true {x : BitVec w} {s : Nat} (h : x.msb = true) :
|
||||
· rw [Nat.shiftRight_eq_div_pow]
|
||||
apply Nat.lt_of_le_of_lt (Nat.div_le_self _ _) (by omega)
|
||||
|
||||
theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) :
|
||||
@[simp] theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) :
|
||||
getLsb (x.sshiftRight s) i =
|
||||
(!decide (w ≤ i) && if s + i < w then x.getLsb (s + i) else x.msb) := by
|
||||
rcases hmsb : x.msb with rfl | rfl
|
||||
@@ -807,6 +822,41 @@ theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) :
|
||||
Nat.not_lt, decide_eq_true_eq]
|
||||
omega
|
||||
|
||||
/-- The msb after arithmetic shifting right equals the original msb. -/
|
||||
theorem sshiftRight_msb_eq_msb {n : Nat} {x : BitVec w} :
|
||||
(x.sshiftRight n).msb = x.msb := by
|
||||
rw [msb_eq_getLsb_last, getLsb_sshiftRight, msb_eq_getLsb_last]
|
||||
by_cases hw₀ : w = 0
|
||||
· simp [hw₀]
|
||||
· simp only [show ¬(w ≤ w - 1) by omega, decide_False, Bool.not_false, Bool.true_and,
|
||||
ite_eq_right_iff]
|
||||
intros h
|
||||
simp [show n = 0 by omega]
|
||||
|
||||
@[simp] theorem sshiftRight_zero {x : BitVec w} : x.sshiftRight 0 = x := by
|
||||
ext i
|
||||
simp
|
||||
|
||||
theorem sshiftRight_add {x : BitVec w} {m n : Nat} :
|
||||
x.sshiftRight (m + n) = (x.sshiftRight m).sshiftRight n := by
|
||||
ext i
|
||||
simp only [getLsb_sshiftRight, Nat.add_assoc]
|
||||
by_cases h₁ : w ≤ (i : Nat)
|
||||
· simp [h₁]
|
||||
· simp only [h₁, decide_False, Bool.not_false, Bool.true_and]
|
||||
by_cases h₂ : n + ↑i < w
|
||||
· simp [h₂]
|
||||
· simp only [h₂, ↓reduceIte]
|
||||
by_cases h₃ : m + (n + ↑i) < w
|
||||
· simp [h₃]
|
||||
omega
|
||||
· simp [h₃, sshiftRight_msb_eq_msb]
|
||||
|
||||
/-! ### sshiftRight reductions from BitVec to Nat -/
|
||||
|
||||
@[simp]
|
||||
theorem sshiftRight_eq' (x : BitVec w) : x.sshiftRight' y = x.sshiftRight y.toNat := rfl
|
||||
|
||||
/-! ### signExtend -/
|
||||
|
||||
/-- Equation theorem for `Int.sub` when both arguments are `Int.ofNat` -/
|
||||
@@ -869,15 +919,15 @@ theorem append_def (x : BitVec v) (y : BitVec w) :
|
||||
(x ++ y).toNat = x.toNat <<< n ||| y.toNat :=
|
||||
rfl
|
||||
|
||||
@[simp] theorem getLsb_append {v : BitVec n} {w : BitVec m} :
|
||||
getLsb (v ++ w) i = bif i < m then getLsb w i else getLsb v (i - m) := by
|
||||
@[simp] theorem getLsb_append {x : BitVec n} {y : BitVec m} :
|
||||
getLsb (x ++ y) i = bif i < m then getLsb y i else getLsb x (i - m) := by
|
||||
simp only [append_def, getLsb_or, getLsb_shiftLeftZeroExtend, getLsb_zeroExtend']
|
||||
by_cases h : i < m
|
||||
· simp [h]
|
||||
· simp [h]; simp_all
|
||||
|
||||
@[simp] theorem getMsb_append {v : BitVec n} {w : BitVec m} :
|
||||
getMsb (v ++ w) i = bif n ≤ i then getMsb w (i - n) else getMsb v i := by
|
||||
@[simp] theorem getMsb_append {x : BitVec n} {y : BitVec m} :
|
||||
getMsb (x ++ y) i = bif n ≤ i then getMsb y (i - n) else getMsb x i := by
|
||||
simp [append_def]
|
||||
by_cases h : n ≤ i
|
||||
· simp [h]
|
||||
|
||||
@@ -438,6 +438,24 @@ Added for confluence between `if_true_left` and `ite_false_same` on
|
||||
-/
|
||||
@[simp] theorem eq_true_imp_eq_false : ∀(b:Bool), (b = true → b = false) ↔ (b = false) := by decide
|
||||
|
||||
/-! ### forall -/
|
||||
|
||||
theorem forall_bool' {p : Bool → Prop} (b : Bool) : (∀ x, p x) ↔ p b ∧ p !b :=
|
||||
⟨fun h ↦ ⟨h _, h _⟩, fun ⟨h₁, h₂⟩ x ↦ by cases b <;> cases x <;> assumption⟩
|
||||
|
||||
@[simp]
|
||||
theorem forall_bool {p : Bool → Prop} : (∀ b, p b) ↔ p false ∧ p true :=
|
||||
forall_bool' false
|
||||
|
||||
/-! ### exists -/
|
||||
|
||||
theorem exists_bool' {p : Bool → Prop} (b : Bool) : (∃ x, p x) ↔ p b ∨ p !b :=
|
||||
⟨fun ⟨x, hx⟩ ↦ by cases x <;> cases b <;> first | exact .inl ‹_› | exact .inr ‹_›,
|
||||
fun h ↦ by cases h <;> exact ⟨_, ‹_›⟩⟩
|
||||
|
||||
@[simp]
|
||||
theorem exists_bool {p : Bool → Prop} : (∃ b, p b) ↔ p false ∨ p true :=
|
||||
exists_bool' false
|
||||
|
||||
/-! ### cond -/
|
||||
|
||||
|
||||
@@ -5,9 +5,18 @@ Author: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Init.SimpLemmas
|
||||
import Init.NotationExtra
|
||||
|
||||
instance [BEq α] [BEq β] [LawfulBEq α] [LawfulBEq β] : LawfulBEq (α × β) where
|
||||
eq_of_beq {a b} (h : a.1 == b.1 && a.2 == b.2) := by
|
||||
cases a; cases b
|
||||
refine congr (congrArg _ (eq_of_beq ?_)) (eq_of_beq ?_) <;> simp_all
|
||||
rfl {a} := by cases a; simp [BEq.beq, LawfulBEq.rfl]
|
||||
|
||||
@[simp]
|
||||
protected theorem Prod.forall {p : α × β → Prop} : (∀ x, p x) ↔ ∀ a b, p (a, b) :=
|
||||
⟨fun h a b ↦ h (a, b), fun h ⟨a, b⟩ ↦ h a b⟩
|
||||
|
||||
@[simp]
|
||||
protected theorem Prod.exists {p : α × β → Prop} : (∃ x, p x) ↔ ∃ a b, p (a, b) :=
|
||||
⟨fun ⟨⟨a, b⟩, h⟩ ↦ ⟨a, b, h⟩, fun ⟨a, b, h⟩ ↦ ⟨⟨a, b⟩, h⟩⟩
|
||||
|
||||
@@ -470,31 +470,23 @@ def withFile (fn : FilePath) (mode : Mode) (f : Handle → IO α) : IO α :=
|
||||
def Handle.putStrLn (h : Handle) (s : String) : IO Unit :=
|
||||
h.putStr (s.push '\n')
|
||||
|
||||
partial def Handle.readBinToEnd (h : Handle) : IO ByteArray := do
|
||||
partial def Handle.readBinToEndInto (h : Handle) (buf : ByteArray) : IO ByteArray := do
|
||||
let rec loop (acc : ByteArray) : IO ByteArray := do
|
||||
let buf ← h.read 1024
|
||||
if buf.isEmpty then
|
||||
return acc
|
||||
else
|
||||
loop (acc ++ buf)
|
||||
loop ByteArray.empty
|
||||
loop buf
|
||||
|
||||
partial def Handle.readToEnd (h : Handle) : IO String := do
|
||||
let rec loop (s : String) := do
|
||||
let line ← h.getLine
|
||||
if line.isEmpty then
|
||||
return s
|
||||
else
|
||||
loop (s ++ line)
|
||||
loop ""
|
||||
partial def Handle.readBinToEnd (h : Handle) : IO ByteArray := do
|
||||
h.readBinToEndInto .empty
|
||||
|
||||
def readBinFile (fname : FilePath) : IO ByteArray := do
|
||||
let h ← Handle.mk fname Mode.read
|
||||
h.readBinToEnd
|
||||
|
||||
def readFile (fname : FilePath) : IO String := do
|
||||
let h ← Handle.mk fname Mode.read
|
||||
h.readToEnd
|
||||
def Handle.readToEnd (h : Handle) : IO String := do
|
||||
let data ← h.readBinToEnd
|
||||
match String.fromUTF8? data with
|
||||
| some s => return s
|
||||
| none => throw <| .userError s!"Tried to read from handle containing non UTF-8 data."
|
||||
|
||||
partial def lines (fname : FilePath) : IO (Array String) := do
|
||||
let h ← Handle.mk fname Mode.read
|
||||
@@ -600,6 +592,28 @@ end System.FilePath
|
||||
|
||||
namespace IO
|
||||
|
||||
namespace FS
|
||||
|
||||
def readBinFile (fname : FilePath) : IO ByteArray := do
|
||||
-- Requires metadata so defined after metadata
|
||||
let mdata ← fname.metadata
|
||||
let size := mdata.byteSize.toUSize
|
||||
let handle ← IO.FS.Handle.mk fname .read
|
||||
let buf ←
|
||||
if size > 0 then
|
||||
handle.read mdata.byteSize.toUSize
|
||||
else
|
||||
pure <| ByteArray.mkEmpty 0
|
||||
handle.readBinToEndInto buf
|
||||
|
||||
def readFile (fname : FilePath) : IO String := do
|
||||
let data ← readBinFile fname
|
||||
match String.fromUTF8? data with
|
||||
| some s => return s
|
||||
| none => throw <| .userError s!"Tried to read file '{fname}' containing non UTF-8 data."
|
||||
|
||||
end FS
|
||||
|
||||
def withStdin [Monad m] [MonadFinally m] [MonadLiftT BaseIO m] (h : FS.Stream) (x : m α) : m α := do
|
||||
let prev ← setStdin h
|
||||
try x finally discard <| setStdin prev
|
||||
|
||||
@@ -53,7 +53,7 @@ structure AttributeImpl extends AttributeImplCore where
|
||||
erase (decl : Name) : AttrM Unit := throwError "attribute cannot be erased"
|
||||
deriving Inhabited
|
||||
|
||||
builtin_initialize attributeMapRef : IO.Ref (HashMap Name AttributeImpl) ← IO.mkRef {}
|
||||
builtin_initialize attributeMapRef : IO.Ref (Std.HashMap Name AttributeImpl) ← IO.mkRef {}
|
||||
|
||||
/-- Low level attribute registration function. -/
|
||||
def registerBuiltinAttribute (attr : AttributeImpl) : IO Unit := do
|
||||
@@ -296,7 +296,7 @@ end EnumAttributes
|
||||
-/
|
||||
|
||||
abbrev AttributeImplBuilder := Name → List DataValue → Except String AttributeImpl
|
||||
abbrev AttributeImplBuilderTable := HashMap Name AttributeImplBuilder
|
||||
abbrev AttributeImplBuilderTable := Std.HashMap Name AttributeImplBuilder
|
||||
|
||||
builtin_initialize attributeImplBuilderTableRef : IO.Ref AttributeImplBuilderTable ← IO.mkRef {}
|
||||
|
||||
@@ -307,7 +307,7 @@ def registerAttributeImplBuilder (builderId : Name) (builder : AttributeImplBuil
|
||||
|
||||
def mkAttributeImplOfBuilder (builderId ref : Name) (args : List DataValue) : IO AttributeImpl := do
|
||||
let table ← attributeImplBuilderTableRef.get
|
||||
match table.find? builderId with
|
||||
match table[builderId]? with
|
||||
| none => throw (IO.userError ("unknown attribute implementation builder '" ++ toString builderId ++ "'"))
|
||||
| some builder => IO.ofExcept <| builder ref args
|
||||
|
||||
@@ -317,7 +317,7 @@ inductive AttributeExtensionOLeanEntry where
|
||||
|
||||
structure AttributeExtensionState where
|
||||
newEntries : List AttributeExtensionOLeanEntry := []
|
||||
map : HashMap Name AttributeImpl
|
||||
map : Std.HashMap Name AttributeImpl
|
||||
deriving Inhabited
|
||||
|
||||
abbrev AttributeExtension := PersistentEnvExtension AttributeExtensionOLeanEntry (AttributeExtensionOLeanEntry × AttributeImpl) AttributeExtensionState
|
||||
@@ -348,7 +348,7 @@ private def AttributeExtension.addImported (es : Array (Array AttributeExtension
|
||||
let map ← es.foldlM
|
||||
(fun map entries =>
|
||||
entries.foldlM
|
||||
(fun (map : HashMap Name AttributeImpl) entry => do
|
||||
(fun (map : Std.HashMap Name AttributeImpl) entry => do
|
||||
let attrImpl ← mkAttributeImplOfEntry ctx.env ctx.opts entry
|
||||
return map.insert attrImpl.name attrImpl)
|
||||
map)
|
||||
@@ -378,7 +378,7 @@ def getBuiltinAttributeNames : IO (List Name) :=
|
||||
|
||||
def getBuiltinAttributeImpl (attrName : Name) : IO AttributeImpl := do
|
||||
let m ← attributeMapRef.get
|
||||
match m.find? attrName with
|
||||
match m[attrName]? with
|
||||
| some attr => pure attr
|
||||
| none => throw (IO.userError ("unknown attribute '" ++ toString attrName ++ "'"))
|
||||
|
||||
@@ -396,7 +396,7 @@ def getAttributeNames (env : Environment) : List Name :=
|
||||
|
||||
def getAttributeImpl (env : Environment) (attrName : Name) : Except String AttributeImpl :=
|
||||
let m := (attributeExtension.getState env).map
|
||||
match m.find? attrName with
|
||||
match m[attrName]? with
|
||||
| some attr => pure attr
|
||||
| none => throw ("unknown attribute '" ++ toString attrName ++ "'")
|
||||
|
||||
|
||||
@@ -26,9 +26,9 @@ instance : Hashable Key := ⟨getHash⟩
|
||||
end OwnedSet
|
||||
|
||||
open OwnedSet (Key) in
|
||||
abbrev OwnedSet := HashMap Key Unit
|
||||
def OwnedSet.insert (s : OwnedSet) (k : OwnedSet.Key) : OwnedSet := HashMap.insert s k ()
|
||||
def OwnedSet.contains (s : OwnedSet) (k : OwnedSet.Key) : Bool := HashMap.contains s k
|
||||
abbrev OwnedSet := Std.HashMap Key Unit
|
||||
def OwnedSet.insert (s : OwnedSet) (k : OwnedSet.Key) : OwnedSet := Std.HashMap.insert s k ()
|
||||
def OwnedSet.contains (s : OwnedSet) (k : OwnedSet.Key) : Bool := Std.HashMap.contains s k
|
||||
|
||||
/-! We perform borrow inference in a block of mutually recursive functions.
|
||||
Join points are viewed as local functions, and are identified using
|
||||
@@ -49,7 +49,7 @@ instance : Hashable Key := ⟨getHash⟩
|
||||
end ParamMap
|
||||
|
||||
open ParamMap (Key)
|
||||
abbrev ParamMap := HashMap Key (Array Param)
|
||||
abbrev ParamMap := Std.HashMap Key (Array Param)
|
||||
|
||||
def ParamMap.fmt (map : ParamMap) : Format :=
|
||||
let fmts := map.fold (fun fmt k ps =>
|
||||
@@ -109,7 +109,7 @@ partial def visitFnBody (fn : FunId) (paramMap : ParamMap) : FnBody → FnBody
|
||||
| FnBody.jdecl j _ v b =>
|
||||
let v := visitFnBody fn paramMap v
|
||||
let b := visitFnBody fn paramMap b
|
||||
match paramMap.find? (ParamMap.Key.jp fn j) with
|
||||
match paramMap[ParamMap.Key.jp fn j]? with
|
||||
| some ys => FnBody.jdecl j ys v b
|
||||
| none => unreachable!
|
||||
| FnBody.case tid x xType alts =>
|
||||
@@ -125,7 +125,7 @@ def visitDecls (decls : Array Decl) (paramMap : ParamMap) : Array Decl :=
|
||||
decls.map fun decl => match decl with
|
||||
| Decl.fdecl f _ ty b info =>
|
||||
let b := visitFnBody f paramMap b
|
||||
match paramMap.find? (ParamMap.Key.decl f) with
|
||||
match paramMap[ParamMap.Key.decl f]? with
|
||||
| some xs => Decl.fdecl f xs ty b info
|
||||
| none => unreachable!
|
||||
| other => other
|
||||
@@ -178,7 +178,7 @@ def isOwned (x : VarId) : M Bool := do
|
||||
/-- Updates `map[k]` using the current set of `owned` variables. -/
|
||||
def updateParamMap (k : ParamMap.Key) : M Unit := do
|
||||
let s ← get
|
||||
match s.paramMap.find? k with
|
||||
match s.paramMap[k]? with
|
||||
| some ps => do
|
||||
let ps ← ps.mapM fun (p : Param) => do
|
||||
if !p.borrow then pure p
|
||||
@@ -192,7 +192,7 @@ def updateParamMap (k : ParamMap.Key) : M Unit := do
|
||||
|
||||
def getParamInfo (k : ParamMap.Key) : M (Array Param) := do
|
||||
let s ← get
|
||||
match s.paramMap.find? k with
|
||||
match s.paramMap[k]? with
|
||||
| some ps => pure ps
|
||||
| none =>
|
||||
match k with
|
||||
|
||||
@@ -11,6 +11,7 @@ import Lean.Compiler.IR.Basic
|
||||
import Lean.Compiler.IR.CompilerM
|
||||
import Lean.Compiler.IR.FreeVars
|
||||
import Lean.Compiler.IR.ElimDeadVars
|
||||
import Lean.Data.AssocList
|
||||
|
||||
namespace Lean.IR.ExplicitBoxing
|
||||
/-!
|
||||
|
||||
@@ -152,7 +152,7 @@ def getFunctionSummary? (env : Environment) (fid : FunId) : Option Value :=
|
||||
| some modIdx => findAtSorted? (functionSummariesExt.getModuleEntries env modIdx) fid
|
||||
| none => functionSummariesExt.getState env |>.find? fid
|
||||
|
||||
abbrev Assignment := HashMap VarId Value
|
||||
abbrev Assignment := Std.HashMap VarId Value
|
||||
|
||||
structure InterpContext where
|
||||
currFnIdx : Nat := 0
|
||||
@@ -172,7 +172,7 @@ def findVarValue (x : VarId) : M Value := do
|
||||
let ctx ← read
|
||||
let s ← get
|
||||
let assignment := s.assignments[ctx.currFnIdx]!
|
||||
return assignment.findD x bot
|
||||
return assignment.getD x bot
|
||||
|
||||
def findArgValue (arg : Arg) : M Value :=
|
||||
match arg with
|
||||
@@ -303,7 +303,7 @@ partial def elimDeadAux (assignment : Assignment) : FnBody → FnBody
|
||||
| FnBody.vdecl x t e b => FnBody.vdecl x t e (elimDeadAux assignment b)
|
||||
| FnBody.jdecl j ys v b => FnBody.jdecl j ys (elimDeadAux assignment v) (elimDeadAux assignment b)
|
||||
| FnBody.case tid x xType alts =>
|
||||
let v := assignment.findD x bot
|
||||
let v := assignment.getD x bot
|
||||
let alts := alts.map fun alt =>
|
||||
match alt with
|
||||
| Alt.ctor i b => Alt.ctor i <| if containsCtor v i then elimDeadAux assignment b else FnBody.unreachable
|
||||
|
||||
@@ -252,7 +252,7 @@ def throwUnknownVar {α : Type} (x : VarId) : M α :=
|
||||
|
||||
def getJPParams (j : JoinPointId) : M (Array Param) := do
|
||||
let ctx ← read;
|
||||
match ctx.jpMap.find? j with
|
||||
match ctx.jpMap[j]? with
|
||||
| some ps => pure ps
|
||||
| none => throw "unknown join point"
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Siddharth Bhat
|
||||
-/
|
||||
prelude
|
||||
import Lean.Data.HashMap
|
||||
import Lean.Runtime
|
||||
import Lean.Compiler.NameMangling
|
||||
import Lean.Compiler.ExportAttr
|
||||
@@ -65,8 +64,8 @@ structure Context (llvmctx : LLVM.Context) where
|
||||
llvmmodule : LLVM.Module llvmctx
|
||||
|
||||
structure State (llvmctx : LLVM.Context) where
|
||||
var2val : HashMap VarId (LLVM.LLVMType llvmctx × LLVM.Value llvmctx)
|
||||
jp2bb : HashMap JoinPointId (LLVM.BasicBlock llvmctx)
|
||||
var2val : Std.HashMap VarId (LLVM.LLVMType llvmctx × LLVM.Value llvmctx)
|
||||
jp2bb : Std.HashMap JoinPointId (LLVM.BasicBlock llvmctx)
|
||||
|
||||
abbrev Error := String
|
||||
|
||||
@@ -84,7 +83,7 @@ def addJpTostate (jp : JoinPointId) (bb : LLVM.BasicBlock llvmctx) : M llvmctx U
|
||||
|
||||
def emitJp (jp : JoinPointId) : M llvmctx (LLVM.BasicBlock llvmctx) := do
|
||||
let state ← get
|
||||
match state.jp2bb.find? jp with
|
||||
match state.jp2bb[jp]? with
|
||||
| .some bb => return bb
|
||||
| .none => throw s!"unable to find join point {jp}"
|
||||
|
||||
@@ -531,7 +530,7 @@ def emitFnDecls : M llvmctx Unit := do
|
||||
|
||||
def emitLhsSlot_ (x : VarId) : M llvmctx (LLVM.LLVMType llvmctx × LLVM.Value llvmctx) := do
|
||||
let state ← get
|
||||
match state.var2val.find? x with
|
||||
match state.var2val[x]? with
|
||||
| .some v => return v
|
||||
| .none => throw s!"unable to find variable {x}"
|
||||
|
||||
@@ -1029,7 +1028,7 @@ def emitTailCall (builder : LLVM.Builder llvmctx) (f : FunId) (v : Expr) : M llv
|
||||
|
||||
def emitJmp (builder : LLVM.Builder llvmctx) (jp : JoinPointId) (xs : Array Arg) : M llvmctx Unit := do
|
||||
let llvmctx ← read
|
||||
let ps ← match llvmctx.jpMap.find? jp with
|
||||
let ps ← match llvmctx.jpMap[jp]? with
|
||||
| some ps => pure ps
|
||||
| none => throw s!"Unknown join point {jp}"
|
||||
unless xs.size == ps.size do throw s!"Invalid goto, mismatched sizes between arguments, formal parameters."
|
||||
|
||||
@@ -51,8 +51,8 @@ end CollectUsedDecls
|
||||
def collectUsedDecls (env : Environment) (decl : Decl) (used : NameSet := {}) : NameSet :=
|
||||
(CollectUsedDecls.collectDecl decl env).run' used
|
||||
|
||||
abbrev VarTypeMap := HashMap VarId IRType
|
||||
abbrev JPParamsMap := HashMap JoinPointId (Array Param)
|
||||
abbrev VarTypeMap := Std.HashMap VarId IRType
|
||||
abbrev JPParamsMap := Std.HashMap JoinPointId (Array Param)
|
||||
|
||||
namespace CollectMaps
|
||||
abbrev Collector := (VarTypeMap × JPParamsMap) → (VarTypeMap × JPParamsMap)
|
||||
|
||||
@@ -10,7 +10,7 @@ import Lean.Compiler.IR.FreeVars
|
||||
|
||||
namespace Lean.IR.ExpandResetReuse
|
||||
/-- Mapping from variable to projections -/
|
||||
abbrev ProjMap := HashMap VarId Expr
|
||||
abbrev ProjMap := Std.HashMap VarId Expr
|
||||
namespace CollectProjMap
|
||||
abbrev Collector := ProjMap → ProjMap
|
||||
@[inline] def collectVDecl (x : VarId) (v : Expr) : Collector := fun m =>
|
||||
@@ -148,20 +148,20 @@ def setFields (y : VarId) (zs : Array Arg) (b : FnBody) : FnBody :=
|
||||
def isSelfSet (ctx : Context) (x : VarId) (i : Nat) (y : Arg) : Bool :=
|
||||
match y with
|
||||
| Arg.var y =>
|
||||
match ctx.projMap.find? y with
|
||||
match ctx.projMap[y]? with
|
||||
| some (Expr.proj j w) => j == i && w == x
|
||||
| _ => false
|
||||
| _ => false
|
||||
|
||||
/-- Given `uset x[i] := y`, return true iff `y := uproj[i] x` -/
|
||||
def isSelfUSet (ctx : Context) (x : VarId) (i : Nat) (y : VarId) : Bool :=
|
||||
match ctx.projMap.find? y with
|
||||
match ctx.projMap[y]? with
|
||||
| some (Expr.uproj j w) => j == i && w == x
|
||||
| _ => false
|
||||
|
||||
/-- Given `sset x[n, i] := y`, return true iff `y := sproj[n, i] x` -/
|
||||
def isSelfSSet (ctx : Context) (x : VarId) (n : Nat) (i : Nat) (y : VarId) : Bool :=
|
||||
match ctx.projMap.find? y with
|
||||
match ctx.projMap[y]? with
|
||||
| some (Expr.sproj m j w) => n == m && j == i && w == x
|
||||
| _ => false
|
||||
|
||||
|
||||
@@ -64,34 +64,34 @@ instance : AddMessageContext CompilerM where
|
||||
|
||||
def getType (fvarId : FVarId) : CompilerM Expr := do
|
||||
let lctx := (← get).lctx
|
||||
if let some decl := lctx.letDecls.find? fvarId then
|
||||
if let some decl := lctx.letDecls[fvarId]? then
|
||||
return decl.type
|
||||
else if let some decl := lctx.params.find? fvarId then
|
||||
else if let some decl := lctx.params[fvarId]? then
|
||||
return decl.type
|
||||
else if let some decl := lctx.funDecls.find? fvarId then
|
||||
else if let some decl := lctx.funDecls[fvarId]? then
|
||||
return decl.type
|
||||
else
|
||||
throwError "unknown free variable {fvarId.name}"
|
||||
|
||||
def getBinderName (fvarId : FVarId) : CompilerM Name := do
|
||||
let lctx := (← get).lctx
|
||||
if let some decl := lctx.letDecls.find? fvarId then
|
||||
if let some decl := lctx.letDecls[fvarId]? then
|
||||
return decl.binderName
|
||||
else if let some decl := lctx.params.find? fvarId then
|
||||
else if let some decl := lctx.params[fvarId]? then
|
||||
return decl.binderName
|
||||
else if let some decl := lctx.funDecls.find? fvarId then
|
||||
else if let some decl := lctx.funDecls[fvarId]? then
|
||||
return decl.binderName
|
||||
else
|
||||
throwError "unknown free variable {fvarId.name}"
|
||||
|
||||
def findParam? (fvarId : FVarId) : CompilerM (Option Param) :=
|
||||
return (← get).lctx.params.find? fvarId
|
||||
return (← get).lctx.params[fvarId]?
|
||||
|
||||
def findLetDecl? (fvarId : FVarId) : CompilerM (Option LetDecl) :=
|
||||
return (← get).lctx.letDecls.find? fvarId
|
||||
return (← get).lctx.letDecls[fvarId]?
|
||||
|
||||
def findFunDecl? (fvarId : FVarId) : CompilerM (Option FunDecl) :=
|
||||
return (← get).lctx.funDecls.find? fvarId
|
||||
return (← get).lctx.funDecls[fvarId]?
|
||||
|
||||
def findLetValue? (fvarId : FVarId) : CompilerM (Option LetValue) := do
|
||||
let some { value, .. } ← findLetDecl? fvarId | return none
|
||||
@@ -166,7 +166,7 @@ it is a free variable, a type (or type former), or `lcErased`.
|
||||
|
||||
`Check.lean` contains a substitution validator.
|
||||
-/
|
||||
abbrev FVarSubst := HashMap FVarId Expr
|
||||
abbrev FVarSubst := Std.HashMap FVarId Expr
|
||||
|
||||
/--
|
||||
Replace the free variables in `e` using the given substitution.
|
||||
@@ -190,7 +190,7 @@ where
|
||||
go (e : Expr) : Expr :=
|
||||
if e.hasFVar then
|
||||
match e with
|
||||
| .fvar fvarId => match s.find? fvarId with
|
||||
| .fvar fvarId => match s[fvarId]? with
|
||||
| some e => if translator then e else go e
|
||||
| none => e
|
||||
| .lit .. | .const .. | .sort .. | .mvar .. | .bvar .. => e
|
||||
@@ -224,7 +224,7 @@ That is, it is not a type (or type former), nor `lcErased`. Recall that a valid
|
||||
expressions that are free variables, `lcErased`, or type formers.
|
||||
-/
|
||||
private partial def normFVarImp (s : FVarSubst) (fvarId : FVarId) (translator : Bool) : NormFVarResult :=
|
||||
match s.find? fvarId with
|
||||
match s[fvarId]? with
|
||||
| some (.fvar fvarId') =>
|
||||
if translator then
|
||||
.fvar fvarId'
|
||||
@@ -246,7 +246,7 @@ private partial def normArgImp (s : FVarSubst) (arg : Arg) (translator : Bool) :
|
||||
match arg with
|
||||
| .erased => arg
|
||||
| .fvar fvarId =>
|
||||
match s.find? fvarId with
|
||||
match s[fvarId]? with
|
||||
| some (.fvar fvarId') =>
|
||||
let arg' := .fvar fvarId'
|
||||
if translator then arg' else normArgImp s arg' translator
|
||||
|
||||
@@ -268,7 +268,7 @@ def getFunctionSummary? (env : Environment) (fid : Name) : Option Value :=
|
||||
A map from variable identifiers to the `Value` produced by the abstract
|
||||
interpreter for them.
|
||||
-/
|
||||
abbrev Assignment := HashMap FVarId Value
|
||||
abbrev Assignment := Std.HashMap FVarId Value
|
||||
|
||||
/--
|
||||
The context of `InterpM`.
|
||||
@@ -332,7 +332,7 @@ If none is available return `Value.bot`.
|
||||
-/
|
||||
def findVarValue (var : FVarId) : InterpM Value := do
|
||||
let assignment ← getAssignment
|
||||
return assignment.findD var .bot
|
||||
return assignment.getD var .bot
|
||||
|
||||
/--
|
||||
Find the value of `arg` using the logic of `findVarValue`.
|
||||
@@ -547,13 +547,13 @@ where
|
||||
| .jp decl k | .fun decl k =>
|
||||
return code.updateFun! (← decl.updateValue (← go decl.value)) (← go k)
|
||||
| .cases cs =>
|
||||
let discrVal := assignment.findD cs.discr .bot
|
||||
let discrVal := assignment.getD cs.discr .bot
|
||||
let processAlt typ alt := do
|
||||
match alt with
|
||||
| .alt ctor args body =>
|
||||
if discrVal.containsCtor ctor then
|
||||
let filter param := do
|
||||
if let some val := assignment.find? param.fvarId then
|
||||
if let some val := assignment[param.fvarId]? then
|
||||
if let some literal ← val.getLiteral then
|
||||
return some (param, literal)
|
||||
return none
|
||||
|
||||
@@ -62,7 +62,7 @@ structure State where
|
||||
Whenever there is function application `f a₁ ... aₙ`, where `f` is in `decls`, `f` is not `main`, and
|
||||
we visit with the abstract values assigned to `aᵢ`, but first we record the visit here.
|
||||
-/
|
||||
visited : HashSet (Name × Array AbsValue) := {}
|
||||
visited : Std.HashSet (Name × Array AbsValue) := {}
|
||||
/--
|
||||
Bitmask containing the result, i.e., which parameters of `main` are fixed.
|
||||
We initialize it with `true` everywhere.
|
||||
|
||||
@@ -59,7 +59,7 @@ structure FloatState where
|
||||
/--
|
||||
A map from identifiers of declarations to their current decision.
|
||||
-/
|
||||
decision : HashMap FVarId Decision
|
||||
decision : Std.HashMap FVarId Decision
|
||||
/--
|
||||
A map from decisions (excluding `unknown`) to the declarations with
|
||||
these decisions (in correct order). Basically:
|
||||
@@ -67,7 +67,7 @@ structure FloatState where
|
||||
- Which declarations do we move into a certain arm
|
||||
- Which declarations do we move into the default arm
|
||||
-/
|
||||
newArms : HashMap Decision (List CodeDecl)
|
||||
newArms : Std.HashMap Decision (List CodeDecl)
|
||||
|
||||
/--
|
||||
Use to collect relevant declarations for the floating mechanism.
|
||||
@@ -116,8 +116,8 @@ up to this point, with respect to `cs`. The initial decisions are:
|
||||
- `arm` or `default` if we see the declaration only being used in exactly one cases arm
|
||||
- `unknown` otherwise
|
||||
-/
|
||||
def initialDecisions (cs : Cases) : BaseFloatM (HashMap FVarId Decision) := do
|
||||
let mut map := mkHashMap (← read).decls.length
|
||||
def initialDecisions (cs : Cases) : BaseFloatM (Std.HashMap FVarId Decision) := do
|
||||
let mut map := Std.HashMap.empty (← read).decls.length
|
||||
let folder val acc := do
|
||||
if let .let decl := val then
|
||||
if (← ignore? decl) then
|
||||
@@ -130,25 +130,25 @@ def initialDecisions (cs : Cases) : BaseFloatM (HashMap FVarId Decision) := do
|
||||
(_, map) ← goCases cs |>.run map
|
||||
return map
|
||||
where
|
||||
goFVar (plannedDecision : Decision) (var : FVarId) : StateRefT (HashMap FVarId Decision) BaseFloatM Unit := do
|
||||
if let some decision := (← get).find? var then
|
||||
goFVar (plannedDecision : Decision) (var : FVarId) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit := do
|
||||
if let some decision := (← get)[var]? then
|
||||
if decision == .unknown then
|
||||
modify fun s => s.insert var plannedDecision
|
||||
else if decision != plannedDecision then
|
||||
modify fun s => s.insert var .dont
|
||||
-- otherwise we already have the proper decision
|
||||
|
||||
goAlt (alt : Alt) : StateRefT (HashMap FVarId Decision) BaseFloatM Unit :=
|
||||
goAlt (alt : Alt) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit :=
|
||||
forFVarM (goFVar (.ofAlt alt)) alt
|
||||
goCases (cs : Cases) : StateRefT (HashMap FVarId Decision) BaseFloatM Unit :=
|
||||
goCases (cs : Cases) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit :=
|
||||
cs.alts.forM goAlt
|
||||
|
||||
/--
|
||||
Compute the initial new arms. This will just set up a map from all arms of
|
||||
`cs` to empty `Array`s, plus one additional entry for `dont`.
|
||||
-/
|
||||
def initialNewArms (cs : Cases) : HashMap Decision (List CodeDecl) := Id.run do
|
||||
let mut map := mkHashMap (cs.alts.size + 1)
|
||||
def initialNewArms (cs : Cases) : Std.HashMap Decision (List CodeDecl) := Id.run do
|
||||
let mut map := Std.HashMap.empty (cs.alts.size + 1)
|
||||
map := map.insert .dont []
|
||||
cs.alts.foldr (init := map) fun val acc => acc.insert (.ofAlt val) []
|
||||
|
||||
@@ -170,7 +170,7 @@ respectively but since `z` can't be moved we don't want that to move `x` and `y`
|
||||
-/
|
||||
def dontFloat (decl : CodeDecl) : FloatM Unit := do
|
||||
forFVarM goFVar decl
|
||||
modify fun s => { s with newArms := s.newArms.insert .dont (decl :: s.newArms.find! .dont) }
|
||||
modify fun s => { s with newArms := s.newArms.insert .dont (decl :: s.newArms[Decision.dont]!) }
|
||||
where
|
||||
goFVar (fvar : FVarId) : FloatM Unit := do
|
||||
if (← get).decision.contains fvar then
|
||||
@@ -223,12 +223,12 @@ Will:
|
||||
If we are at `y` `x` is still marked to be moved but we don't want that.
|
||||
-/
|
||||
def float (decl : CodeDecl) : FloatM Unit := do
|
||||
let arm := (← get).decision.find! decl.fvarId
|
||||
let arm := (← get).decision[decl.fvarId]!
|
||||
forFVarM (goFVar · arm) decl
|
||||
modify fun s => { s with newArms := s.newArms.insert arm (decl :: s.newArms.find! arm) }
|
||||
modify fun s => { s with newArms := s.newArms.insert arm (decl :: s.newArms[arm]!) }
|
||||
where
|
||||
goFVar (fvar : FVarId) (arm : Decision) : FloatM Unit := do
|
||||
let some decision := (← get).decision.find? fvar | return ()
|
||||
let some decision := (← get).decision[fvar]? | return ()
|
||||
if decision != arm then
|
||||
modify fun s => { s with decision := s.decision.insert fvar .dont }
|
||||
else if decision == .unknown then
|
||||
@@ -249,7 +249,7 @@ where
|
||||
-/
|
||||
goCases : FloatM Unit := do
|
||||
for decl in (← read).decls do
|
||||
let currentDecision := (← get).decision.find! decl.fvarId
|
||||
let currentDecision := (← get).decision[decl.fvarId]!
|
||||
if currentDecision == .unknown then
|
||||
/-
|
||||
If the decision is still unknown by now this means `decl` is
|
||||
@@ -284,10 +284,10 @@ where
|
||||
newArms := initialNewArms cs
|
||||
}
|
||||
let (_, res) ← goCases |>.run base
|
||||
let remainders := res.newArms.find! .dont
|
||||
let remainders := res.newArms[Decision.dont]!
|
||||
let altMapper alt := do
|
||||
let decision := .ofAlt alt
|
||||
let newCode := res.newArms.find! decision
|
||||
let decision := Decision.ofAlt alt
|
||||
let newCode := res.newArms[decision]!
|
||||
trace[Compiler.floatLetIn] "Size of code that was pushed into arm: {repr decision} {newCode.length}"
|
||||
let fused ← withNewScope do
|
||||
go (attachCodeDecls newCode.toArray alt.getCode)
|
||||
|
||||
@@ -29,7 +29,7 @@ structure CandidateInfo where
|
||||
The set of candidates that rely on this candidate to be a join point.
|
||||
For a more detailed explanation see the documentation of `find`
|
||||
-/
|
||||
associated : HashSet FVarId
|
||||
associated : Std.HashSet FVarId
|
||||
deriving Inhabited
|
||||
|
||||
/--
|
||||
@@ -39,14 +39,14 @@ structure FindState where
|
||||
/--
|
||||
All current join point candidates accessible by their `FVarId`.
|
||||
-/
|
||||
candidates : HashMap FVarId CandidateInfo := .empty
|
||||
candidates : Std.HashMap FVarId CandidateInfo := .empty
|
||||
/--
|
||||
The `FVarId`s of all `fun` declarations that were declared within the
|
||||
current `fun`.
|
||||
-/
|
||||
scope : HashSet FVarId := .empty
|
||||
scope : Std.HashSet FVarId := .empty
|
||||
|
||||
abbrev ReplaceCtx := HashMap FVarId Name
|
||||
abbrev ReplaceCtx := Std.HashMap FVarId Name
|
||||
|
||||
abbrev FindM := ReaderT (Option FVarId) StateRefT FindState ScopeM
|
||||
abbrev ReplaceM := ReaderT ReplaceCtx CompilerM
|
||||
@@ -55,7 +55,7 @@ abbrev ReplaceM := ReaderT ReplaceCtx CompilerM
|
||||
Attempt to find a join point candidate by its `FVarId`.
|
||||
-/
|
||||
private def findCandidate? (fvarId : FVarId) : FindM (Option CandidateInfo) := do
|
||||
return (← get).candidates.find? fvarId
|
||||
return (← get).candidates[fvarId]?
|
||||
|
||||
/--
|
||||
Erase a join point candidate as well as all the ones that depend on it
|
||||
@@ -69,7 +69,7 @@ private partial def eraseCandidate (fvarId : FVarId) : FindM Unit := do
|
||||
/--
|
||||
Combinator for modifying the candidates in `FindM`.
|
||||
-/
|
||||
private def modifyCandidates (f : HashMap FVarId CandidateInfo → HashMap FVarId CandidateInfo) : FindM Unit :=
|
||||
private def modifyCandidates (f : Std.HashMap FVarId CandidateInfo → Std.HashMap FVarId CandidateInfo) : FindM Unit :=
|
||||
modify (fun state => {state with candidates := f state.candidates })
|
||||
|
||||
/--
|
||||
@@ -196,7 +196,7 @@ where
|
||||
return code
|
||||
| _, _ => return Code.updateLet! code decl (← go k)
|
||||
| .fun decl k =>
|
||||
if let some replacement := (← read).find? decl.fvarId then
|
||||
if let some replacement := (← read)[decl.fvarId]? then
|
||||
let newDecl := { decl with
|
||||
binderName := replacement,
|
||||
value := (← go decl.value)
|
||||
@@ -244,7 +244,7 @@ structure ExtendState where
|
||||
to `Param`s. The free variables in this map are the once that the context
|
||||
of said join point will be extended by by passing in the respective parameter.
|
||||
-/
|
||||
fvarMap : HashMap FVarId (HashMap FVarId Param) := {}
|
||||
fvarMap : Std.HashMap FVarId (Std.HashMap FVarId Param) := {}
|
||||
|
||||
/--
|
||||
The monad for the `extendJoinPointContext` pass.
|
||||
@@ -262,7 +262,7 @@ otherwise just return `fvar`.
|
||||
def replaceFVar (fvar : FVarId) : ExtendM FVarId := do
|
||||
if (← read).candidates.contains fvar then
|
||||
if let some currentJp := (← read).currentJp? then
|
||||
if let some replacement := (← get).fvarMap.find! currentJp |>.find? fvar then
|
||||
if let some replacement := (← get).fvarMap[currentJp]![fvar]? then
|
||||
return replacement.fvarId
|
||||
return fvar
|
||||
|
||||
@@ -313,7 +313,7 @@ This is necessary if:
|
||||
-/
|
||||
def extendByIfNecessary (fvar : FVarId) : ExtendM Unit := do
|
||||
if let some currentJp := (← read).currentJp? then
|
||||
let mut translator := (← get).fvarMap.find! currentJp
|
||||
let mut translator := (← get).fvarMap[currentJp]!
|
||||
let candidates := (← read).candidates
|
||||
if !(← isInScope fvar) && !translator.contains fvar && candidates.contains fvar then
|
||||
let typ ← getType fvar
|
||||
@@ -337,7 +337,7 @@ of `j.2` in `j.1`.
|
||||
-/
|
||||
def mergeJpContextIfNecessary (jp : FVarId) : ExtendM Unit := do
|
||||
if (← read).currentJp?.isSome then
|
||||
let additionalArgs := (← get).fvarMap.find! jp |>.toArray
|
||||
let additionalArgs := (← get).fvarMap[jp]!.toArray
|
||||
for (fvar, _) in additionalArgs do
|
||||
extendByIfNecessary fvar
|
||||
|
||||
@@ -405,7 +405,7 @@ where
|
||||
| .jp decl k =>
|
||||
let decl ← withNewJpScope decl do
|
||||
let value ← go decl.value
|
||||
let additionalParams := (← get).fvarMap.find! decl.fvarId |>.toArray |>.map Prod.snd
|
||||
let additionalParams := (← get).fvarMap[decl.fvarId]!.toArray |>.map Prod.snd
|
||||
let newType := additionalParams.foldr (init := decl.type) (fun val acc => .forallE val.binderName val.type acc .default)
|
||||
decl.update newType (additionalParams ++ decl.params) value
|
||||
mergeJpContextIfNecessary decl.fvarId
|
||||
@@ -426,7 +426,7 @@ where
|
||||
return Code.updateCases! code cs.resultType discr alts
|
||||
| .jmp fn args =>
|
||||
let mut newArgs ← args.mapM (mapFVarM goFVar)
|
||||
let additionalArgs := (← get).fvarMap.find! fn |>.toArray |>.map Prod.fst
|
||||
let additionalArgs := (← get).fvarMap[fn]!.toArray |>.map Prod.fst
|
||||
if let some _currentJp := (← read).currentJp? then
|
||||
let f := fun arg => do
|
||||
return .fvar (← goFVar arg)
|
||||
@@ -545,7 +545,7 @@ where
|
||||
if let some knownArgs := (← get).jpJmpArgs.find? fn then
|
||||
let mut newArgs := knownArgs
|
||||
for (param, arg) in decl.params.zip args do
|
||||
if let some knownVal := newArgs.find? param.fvarId then
|
||||
if let some knownVal := newArgs[param.fvarId]? then
|
||||
if arg.toExpr != knownVal then
|
||||
newArgs := newArgs.erase param.fvarId
|
||||
modify fun s => { s with jpJmpArgs := s.jpJmpArgs.insert fn newArgs }
|
||||
|
||||
@@ -13,9 +13,9 @@ namespace Lean.Compiler.LCNF
|
||||
LCNF local context.
|
||||
-/
|
||||
structure LCtx where
|
||||
params : HashMap FVarId Param := {}
|
||||
letDecls : HashMap FVarId LetDecl := {}
|
||||
funDecls : HashMap FVarId FunDecl := {}
|
||||
params : Std.HashMap FVarId Param := {}
|
||||
letDecls : Std.HashMap FVarId LetDecl := {}
|
||||
funDecls : Std.HashMap FVarId FunDecl := {}
|
||||
deriving Inhabited
|
||||
|
||||
def LCtx.addParam (lctx : LCtx) (param : Param) : LCtx :=
|
||||
|
||||
@@ -30,7 +30,7 @@ structure State where
|
||||
/-- Counter for generating new (normalized) universe parameter names. -/
|
||||
nextIdx : Nat := 1
|
||||
/-- Mapping from existing universe parameter names to the new ones. -/
|
||||
map : HashMap Name Level := {}
|
||||
map : Std.HashMap Name Level := {}
|
||||
/-- Parameters that have been normalized. -/
|
||||
paramNames : Array Name := #[]
|
||||
|
||||
@@ -49,7 +49,7 @@ partial def normLevel (u : Level) : M Level := do
|
||||
| .max v w => return u.updateMax! (← normLevel v) (← normLevel w)
|
||||
| .imax v w => return u.updateIMax! (← normLevel v) (← normLevel w)
|
||||
| .mvar _ => unreachable!
|
||||
| .param n => match (← get).map.find? n with
|
||||
| .param n => match (← get).map[n]? with
|
||||
| some u => return u
|
||||
| none =>
|
||||
let u := Level.param <| (`u).appendIndexAfter (← get).nextIdx
|
||||
|
||||
@@ -31,9 +31,9 @@ def sortedBySize : Probe Decl (Nat × Decl) := fun decls =>
|
||||
if sz₁ == sz₂ then Name.lt decl₁.name decl₂.name else sz₁ < sz₂
|
||||
|
||||
def countUnique [ToString α] [BEq α] [Hashable α] : Probe α (α × Nat) := fun data => do
|
||||
let mut map := HashMap.empty
|
||||
let mut map := Std.HashMap.empty
|
||||
for d in data do
|
||||
if let some count := map.find? d then
|
||||
if let some count := map[d]? then
|
||||
map := map.insert d (count + 1)
|
||||
else
|
||||
map := map.insert d 1
|
||||
|
||||
@@ -40,7 +40,7 @@ structure FunDeclInfoMap where
|
||||
/--
|
||||
Mapping from local function name to inlining information.
|
||||
-/
|
||||
map : HashMap FVarId FunDeclInfo := {}
|
||||
map : Std.HashMap FVarId FunDeclInfo := {}
|
||||
deriving Inhabited
|
||||
|
||||
def FunDeclInfoMap.format (s : FunDeclInfoMap) : CompilerM Format := do
|
||||
@@ -56,7 +56,7 @@ Add new occurrence for the local function with binder name `key`.
|
||||
def FunDeclInfoMap.add (s : FunDeclInfoMap) (fvarId : FVarId) : FunDeclInfoMap :=
|
||||
match s with
|
||||
| { map } =>
|
||||
match map.find? fvarId with
|
||||
match map[fvarId]? with
|
||||
| some .once => { map := map.insert fvarId .many }
|
||||
| none => { map := map.insert fvarId .once }
|
||||
| _ => { map }
|
||||
@@ -67,7 +67,7 @@ Add new occurrence for the local function occurring as an argument for another f
|
||||
def FunDeclInfoMap.addHo (s : FunDeclInfoMap) (fvarId : FVarId) : FunDeclInfoMap :=
|
||||
match s with
|
||||
| { map } =>
|
||||
match map.find? fvarId with
|
||||
match map[fvarId]? with
|
||||
| some .once | none => { map := map.insert fvarId .many }
|
||||
| _ => { map }
|
||||
|
||||
|
||||
@@ -173,7 +173,7 @@ Execute `x` with `fvarId` set as `mustInline`.
|
||||
After execution the original setting is restored.
|
||||
-/
|
||||
def withAddMustInline (fvarId : FVarId) (x : SimpM α) : SimpM α := do
|
||||
let saved? := (← get).funDeclInfoMap.map.find? fvarId
|
||||
let saved? := (← get).funDeclInfoMap.map[fvarId]?
|
||||
try
|
||||
addMustInline fvarId
|
||||
x
|
||||
@@ -185,7 +185,7 @@ Return true if the given local function declaration or join point id is marked a
|
||||
`once` or `mustInline`. We use this information to decide whether to inline them.
|
||||
-/
|
||||
def isOnceOrMustInline (fvarId : FVarId) : SimpM Bool := do
|
||||
match (← get).funDeclInfoMap.map.find? fvarId with
|
||||
match (← get).funDeclInfoMap.map[fvarId]? with
|
||||
| some .once | some .mustInline => return true
|
||||
| _ => return false
|
||||
|
||||
|
||||
@@ -199,9 +199,9 @@ structure State where
|
||||
/-- Cache from Lean regular expression to LCNF argument. -/
|
||||
cache : PHashMap Expr Arg := {}
|
||||
/-- `toLCNFType` cache -/
|
||||
typeCache : HashMap Expr Expr := {}
|
||||
typeCache : Std.HashMap Expr Expr := {}
|
||||
/-- isTypeFormerType cache -/
|
||||
isTypeFormerTypeCache : HashMap Expr Bool := {}
|
||||
isTypeFormerTypeCache : Std.HashMap Expr Bool := {}
|
||||
/-- LCNF sequence, we chain it to create a LCNF `Code` object. -/
|
||||
seq : Array Element := #[]
|
||||
/--
|
||||
@@ -257,7 +257,7 @@ private partial def isTypeFormerType (type : Expr) : M Bool := do
|
||||
| .true => return true
|
||||
| .false => return false
|
||||
| .undef =>
|
||||
if let some result := (← get).isTypeFormerTypeCache.find? type then
|
||||
if let some result := (← get).isTypeFormerTypeCache[type]? then
|
||||
return result
|
||||
let result ← liftMetaM <| Meta.isTypeFormerType type
|
||||
modify fun s => { s with isTypeFormerTypeCache := s.isTypeFormerTypeCache.insert type result }
|
||||
@@ -305,7 +305,7 @@ def applyToAny (type : Expr) : M Expr := do
|
||||
| _ => none
|
||||
|
||||
def toLCNFType (type : Expr) : M Expr := do
|
||||
match (← get).typeCache.find? type with
|
||||
match (← get).typeCache[type]? with
|
||||
| some type' => return type'
|
||||
| none =>
|
||||
let type' ← liftMetaM <| LCNF.toLCNFType type
|
||||
|
||||
@@ -270,16 +270,3 @@ def ofListWith (l : List (α × β)) (f : β → β → β) : HashMap α β :=
|
||||
| some v => m.insert p.fst $ f v p.snd)
|
||||
|
||||
end Lean.HashMap
|
||||
|
||||
/--
|
||||
Groups all elements `x`, `y` in `xs` with `key x == key y` into the same array
|
||||
`(xs.groupByKey key).find! (key x)`. Groups preserve the relative order of elements in `xs`.
|
||||
-/
|
||||
def Array.groupByKey [BEq α] [Hashable α] (key : β → α) (xs : Array β)
|
||||
: Lean.HashMap α (Array β) := Id.run do
|
||||
let mut groups := ∅
|
||||
for x in xs do
|
||||
let group := groups.findD (key x) #[]
|
||||
groups := groups.erase (key x) -- make `group` referentially unique
|
||||
groups := groups.insert (key x) (group.push x)
|
||||
return groups
|
||||
|
||||
@@ -150,7 +150,7 @@ instance : FromJson RefInfo where
|
||||
pure { definition?, usages }
|
||||
|
||||
/-- References from a single module/file -/
|
||||
def ModuleRefs := HashMap RefIdent RefInfo
|
||||
def ModuleRefs := Std.HashMap RefIdent RefInfo
|
||||
|
||||
instance : ToJson ModuleRefs where
|
||||
toJson m := Json.mkObj <| m.toList.map fun (ident, info) => (ident.toJson.compress, toJson info)
|
||||
@@ -158,7 +158,7 @@ instance : ToJson ModuleRefs where
|
||||
instance : FromJson ModuleRefs where
|
||||
fromJson? j := do
|
||||
let node ← j.getObj?
|
||||
node.foldM (init := HashMap.empty) fun m k v =>
|
||||
node.foldM (init := Std.HashMap.empty) fun m k v =>
|
||||
return m.insert (← RefIdent.fromJson? (← Json.parse k)) (← fromJson? v)
|
||||
|
||||
/--
|
||||
|
||||
@@ -4,7 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Author: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.Data.HashSet
|
||||
import Std.Data.HashSet.Basic
|
||||
import Lean.Data.RBMap
|
||||
import Lean.Data.RBTree
|
||||
import Lean.Data.SSet
|
||||
@@ -64,14 +64,14 @@ abbrev insert (s : NameSSet) (n : Name) : NameSSet := SSet.insert s n
|
||||
abbrev contains (s : NameSSet) (n : Name) : Bool := SSet.contains s n
|
||||
end NameSSet
|
||||
|
||||
def NameHashSet := HashSet Name
|
||||
def NameHashSet := Std.HashSet Name
|
||||
|
||||
namespace NameHashSet
|
||||
@[inline] def empty : NameHashSet := HashSet.empty
|
||||
@[inline] def empty : NameHashSet := Std.HashSet.empty
|
||||
instance : EmptyCollection NameHashSet := ⟨empty⟩
|
||||
instance : Inhabited NameHashSet := ⟨{}⟩
|
||||
def insert (s : NameHashSet) (n : Name) := HashSet.insert s n
|
||||
def contains (s : NameHashSet) (n : Name) : Bool := HashSet.contains s n
|
||||
def insert (s : NameHashSet) (n : Name) := Std.HashSet.insert s n
|
||||
def contains (s : NameHashSet) (n : Name) : Bool := Std.HashSet.contains s n
|
||||
end NameHashSet
|
||||
|
||||
def MacroScopesView.isPrefixOf (v₁ v₂ : MacroScopesView) : Bool :=
|
||||
|
||||
@@ -4,7 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.Data.HashMap
|
||||
import Std.Data.HashMap.Basic
|
||||
import Lean.Data.PersistentHashMap
|
||||
universe u v w w'
|
||||
|
||||
@@ -28,7 +28,7 @@ namespace Lean
|
||||
-/
|
||||
structure SMap (α : Type u) (β : Type v) [BEq α] [Hashable α] where
|
||||
stage₁ : Bool := true
|
||||
map₁ : HashMap α β := {}
|
||||
map₁ : Std.HashMap α β := {}
|
||||
map₂ : PHashMap α β := {}
|
||||
|
||||
namespace SMap
|
||||
@@ -37,7 +37,7 @@ variable {α : Type u} {β : Type v} [BEq α] [Hashable α]
|
||||
instance : Inhabited (SMap α β) := ⟨{}⟩
|
||||
def empty : SMap α β := {}
|
||||
|
||||
@[inline] def fromHashMap (m : HashMap α β) (stage₁ := true) : SMap α β :=
|
||||
@[inline] def fromHashMap (m : Std.HashMap α β) (stage₁ := true) : SMap α β :=
|
||||
{ map₁ := m, stage₁ := stage₁ }
|
||||
|
||||
@[specialize] def insert : SMap α β → α → β → SMap α β
|
||||
@@ -49,8 +49,8 @@ def empty : SMap α β := {}
|
||||
| ⟨false, m₁, m₂⟩, k, v => ⟨false, m₁, m₂.insert k v⟩
|
||||
|
||||
@[specialize] def find? : SMap α β → α → Option β
|
||||
| ⟨true, m₁, _⟩, k => m₁.find? k
|
||||
| ⟨false, m₁, m₂⟩, k => (m₂.find? k).orElse fun _ => m₁.find? k
|
||||
| ⟨true, m₁, _⟩, k => m₁[k]?
|
||||
| ⟨false, m₁, m₂⟩, k => (m₂.find? k).orElse fun _ => m₁[k]?
|
||||
|
||||
@[inline] def findD (m : SMap α β) (a : α) (b₀ : β) : β :=
|
||||
(m.find? a).getD b₀
|
||||
@@ -67,8 +67,8 @@ def empty : SMap α β := {}
|
||||
/-- Similar to `find?`, but searches for result in the hashmap first.
|
||||
So, the result is correct only if we never "overwrite" `map₁` entries using `map₂`. -/
|
||||
@[specialize] def find?' : SMap α β → α → Option β
|
||||
| ⟨true, m₁, _⟩, k => m₁.find? k
|
||||
| ⟨false, m₁, m₂⟩, k => (m₁.find? k).orElse fun _ => m₂.find? k
|
||||
| ⟨true, m₁, _⟩, k => m₁[k]?
|
||||
| ⟨false, m₁, m₂⟩, k => m₁[k]?.orElse fun _ => m₂.find? k
|
||||
|
||||
def forM [Monad m] (s : SMap α β) (f : α → β → m PUnit) : m PUnit := do
|
||||
s.map₁.forM f
|
||||
@@ -96,7 +96,7 @@ def fold {σ : Type w} (f : σ → α → β → σ) (init : σ) (m : SMap α β
|
||||
m.map₂.foldl f $ m.map₁.fold f init
|
||||
|
||||
def numBuckets (m : SMap α β) : Nat :=
|
||||
m.map₁.numBuckets
|
||||
Std.HashMap.Internal.numBuckets m.map₁
|
||||
|
||||
def toList (m : SMap α β) : List (α × β) :=
|
||||
m.fold (init := []) fun es a b => (a, b)::es
|
||||
|
||||
@@ -201,12 +201,12 @@ def mkMessageAux (ctx : Context) (ref : Syntax) (msgData : MessageData) (severit
|
||||
|
||||
private def addTraceAsMessagesCore (ctx : Context) (log : MessageLog) (traceState : TraceState) : MessageLog := Id.run do
|
||||
if traceState.traces.isEmpty then return log
|
||||
let mut traces : HashMap (String.Pos × String.Pos) (Array MessageData) := ∅
|
||||
let mut traces : Std.HashMap (String.Pos × String.Pos) (Array MessageData) := ∅
|
||||
for traceElem in traceState.traces do
|
||||
let ref := replaceRef traceElem.ref ctx.ref
|
||||
let pos := ref.getPos?.getD 0
|
||||
let endPos := ref.getTailPos?.getD pos
|
||||
traces := traces.insert (pos, endPos) <| traces.findD (pos, endPos) #[] |>.push traceElem.msg
|
||||
traces := traces.insert (pos, endPos) <| traces.getD (pos, endPos) #[] |>.push traceElem.msg
|
||||
let mut log := log
|
||||
let traces' := traces.toArray.qsort fun ((a, _), _) ((b, _), _) => a < b
|
||||
for ((pos, endPos), traceMsg) in traces' do
|
||||
|
||||
@@ -630,7 +630,7 @@ private def replaceIndFVarsWithConsts (views : Array InductiveView) (indFVars :
|
||||
let type := type.replace fun e =>
|
||||
if !e.isFVar then
|
||||
none
|
||||
else match indFVar2Const.find? e with
|
||||
else match indFVar2Const[e]? with
|
||||
| none => none
|
||||
| some c => mkAppN c (params.extract 0 numVars)
|
||||
instantiateMVars (← mkForallFVars params type)
|
||||
|
||||
@@ -425,7 +425,7 @@ private def applyRefMap (e : Expr) (map : ExprMap Expr) : Expr :=
|
||||
e.replace fun e =>
|
||||
match patternWithRef? e with
|
||||
| some _ => some e -- stop `e` already has annotation
|
||||
| none => match map.find? e with
|
||||
| none => match map[e]? with
|
||||
| some eWithRef => some eWithRef -- stop `e` found annotation
|
||||
| none => none -- continue
|
||||
|
||||
|
||||
@@ -164,8 +164,11 @@ def addNonRec (preDef : PreDefinition) (applyAttrAfterCompilation := true) (all
|
||||
/--
|
||||
Eliminate recursive application annotations containing syntax. These annotations are used by the well-founded recursion module
|
||||
to produce better error messages. -/
|
||||
def eraseRecAppSyntaxExpr (e : Expr) : CoreM Expr :=
|
||||
Core.transform e (post := fun e => pure <| TransformStep.done <| if (getRecAppSyntax? e).isSome then e.mdataExpr! else e)
|
||||
def eraseRecAppSyntaxExpr (e : Expr) : CoreM Expr := do
|
||||
if e.find? hasRecAppSyntax |>.isSome then
|
||||
Core.transform e (post := fun e => pure <| TransformStep.done <| if hasRecAppSyntax e then e.mdataExpr! else e)
|
||||
else
|
||||
return e
|
||||
|
||||
def eraseRecAppSyntax (preDef : PreDefinition) : CoreM PreDefinition :=
|
||||
return { preDef with value := (← eraseRecAppSyntaxExpr preDef.value) }
|
||||
|
||||
@@ -69,12 +69,15 @@ private def ensureNoUnassignedMVarsAtPreDef (preDef : PreDefinition) : TermElabM
|
||||
This method beta-reduces them to make sure they can be eliminated by the well-founded recursion module. -/
|
||||
private def betaReduceLetRecApps (preDefs : Array PreDefinition) : MetaM (Array PreDefinition) :=
|
||||
preDefs.mapM fun preDef => do
|
||||
let value ← Core.transform preDef.value fun e => do
|
||||
if e.isApp && e.getAppFn.isLambda && e.getAppArgs.all fun arg => arg.getAppFn.isConst && preDefs.any fun preDef => preDef.declName == arg.getAppFn.constName! then
|
||||
return .visit e.headBeta
|
||||
else
|
||||
return .continue
|
||||
return { preDef with value }
|
||||
if preDef.value.find? (fun e => e.isConst && preDefs.any fun preDef => preDef.declName == e.constName!) |>.isSome then
|
||||
let value ← Core.transform preDef.value fun e => do
|
||||
if e.isApp && e.getAppFn.isLambda && e.getAppArgs.all fun arg => arg.getAppFn.isConst && preDefs.any fun preDef => preDef.declName == arg.getAppFn.constName! then
|
||||
return .visit e.headBeta
|
||||
else
|
||||
return .continue
|
||||
return { preDef with value }
|
||||
else
|
||||
return preDef
|
||||
|
||||
private def addAsAxioms (preDefs : Array PreDefinition) : TermElabM Unit := do
|
||||
for preDef in preDefs do
|
||||
|
||||
@@ -146,7 +146,7 @@ See issue #837 for an example where we can show termination using the index of a
|
||||
we don't get the desired definitional equalities.
|
||||
-/
|
||||
def nonIndicesFirst (recArgInfos : Array RecArgInfo) : Array RecArgInfo := Id.run do
|
||||
let mut indicesPos : HashSet Nat := {}
|
||||
let mut indicesPos : Std.HashSet Nat := {}
|
||||
for recArgInfo in recArgInfos do
|
||||
for pos in recArgInfo.indicesPos do
|
||||
indicesPos := indicesPos.insert pos
|
||||
|
||||
@@ -596,7 +596,7 @@ private partial def compileStxMatch (discrs : List Term) (alts : List Alt) : Ter
|
||||
`(have __discr := $discr; $stx)
|
||||
| _, _ => unreachable!
|
||||
|
||||
abbrev IdxSet := HashSet Nat
|
||||
abbrev IdxSet := Std.HashSet Nat
|
||||
|
||||
private partial def hasNoErrorIfUnused : Syntax → Bool
|
||||
| `(no_error_if_unused% $_) => true
|
||||
|
||||
@@ -11,27 +11,35 @@ namespace Lean
|
||||
private def recAppKey := `_recApp
|
||||
|
||||
/--
|
||||
We store the syntax at recursive applications to be able to generate better error messages
|
||||
when performing well-founded and structural recursion.
|
||||
We store the syntax at recursive applications to be able to generate better error messages
|
||||
when performing well-founded and structural recursion.
|
||||
-/
|
||||
def mkRecAppWithSyntax (e : Expr) (stx : Syntax) : Expr :=
|
||||
mkMData (KVMap.empty.insert recAppKey (DataValue.ofSyntax stx)) e
|
||||
mkMData (KVMap.empty.insert recAppKey (.ofSyntax stx)) e
|
||||
|
||||
/--
|
||||
Retrieve (if available) the syntax object attached to a recursive application.
|
||||
Retrieve (if available) the syntax object attached to a recursive application.
|
||||
-/
|
||||
def getRecAppSyntax? (e : Expr) : Option Syntax :=
|
||||
match e with
|
||||
| Expr.mdata d _ =>
|
||||
| .mdata d _ =>
|
||||
match d.find recAppKey with
|
||||
| some (DataValue.ofSyntax stx) => some stx
|
||||
| _ => none
|
||||
| _ => none
|
||||
|
||||
/--
|
||||
Checks if the `MData` is for a recursive applciation.
|
||||
Checks if the `MData` is for a recursive applciation.
|
||||
-/
|
||||
def MData.isRecApp (d : MData) : Bool :=
|
||||
d.contains recAppKey
|
||||
|
||||
/--
|
||||
Return `true` if `getRecAppSyntax? e` is a `some`.
|
||||
-/
|
||||
def hasRecAppSyntax (e : Expr) : Bool :=
|
||||
match e with
|
||||
| .mdata d _ => d.isRecApp
|
||||
| _ => false
|
||||
|
||||
end Lean
|
||||
|
||||
@@ -445,13 +445,13 @@ private def expandParentFields (s : Struct) : TermElabM Struct := do
|
||||
| _ => throwErrorAt ref "failed to access field '{fieldName}' in parent structure"
|
||||
| _ => return field
|
||||
|
||||
private abbrev FieldMap := HashMap Name Fields
|
||||
private abbrev FieldMap := Std.HashMap Name Fields
|
||||
|
||||
private def mkFieldMap (fields : Fields) : TermElabM FieldMap :=
|
||||
fields.foldlM (init := {}) fun fieldMap field =>
|
||||
match field.lhs with
|
||||
| .fieldName _ fieldName :: _ =>
|
||||
match fieldMap.find? fieldName with
|
||||
match fieldMap[fieldName]? with
|
||||
| some (prevField::restFields) =>
|
||||
if field.isSimple || prevField.isSimple then
|
||||
throwErrorAt field.ref "field '{fieldName}' has already been specified"
|
||||
|
||||
@@ -246,7 +246,7 @@ private def getSomeSyntheticMVarsRef : TermElabM Syntax := do
|
||||
private def throwStuckAtUniverseCnstr : TermElabM Unit := do
|
||||
-- This code assumes `entries` is not empty. Note that `processPostponed` uses `exceptionOnFailure` to guarantee this property
|
||||
let entries ← getPostponed
|
||||
let mut found : HashSet (Level × Level) := {}
|
||||
let mut found : Std.HashSet (Level × Level) := {}
|
||||
let mut uniqueEntries := #[]
|
||||
for entry in entries do
|
||||
let mut lhs := entry.lhs
|
||||
|
||||
@@ -8,8 +8,6 @@ import Init.Omega.Constraint
|
||||
import Lean.Elab.Tactic.Omega.OmegaM
|
||||
import Lean.Elab.Tactic.Omega.MinNatAbs
|
||||
|
||||
open Lean (HashMap HashSet)
|
||||
|
||||
namespace Lean.Elab.Tactic.Omega
|
||||
|
||||
initialize Lean.registerTraceClass `omega
|
||||
@@ -167,11 +165,11 @@ structure Problem where
|
||||
/-- The number of variables in the problem. -/
|
||||
numVars : Nat := 0
|
||||
/-- The current constraints, indexed by their coefficients. -/
|
||||
constraints : HashMap Coeffs Fact := ∅
|
||||
constraints : Std.HashMap Coeffs Fact := ∅
|
||||
/--
|
||||
The coefficients for which `constraints` contains an exact constraint (i.e. an equality).
|
||||
-/
|
||||
equalities : HashSet Coeffs := ∅
|
||||
equalities : Std.HashSet Coeffs := ∅
|
||||
/--
|
||||
Equations that have already been used to eliminate variables,
|
||||
along with the variable which was removed, and its coefficient (either `1` or `-1`).
|
||||
@@ -251,7 +249,7 @@ combining it with any existing constraints for the same coefficients.
|
||||
def addConstraint (p : Problem) : Fact → Problem
|
||||
| f@⟨x, s, j⟩ =>
|
||||
if p.possible then
|
||||
match p.constraints.find? x with
|
||||
match p.constraints[x]? with
|
||||
| none =>
|
||||
match s with
|
||||
| .trivial => p
|
||||
@@ -313,7 +311,7 @@ After solving, the variable will have been eliminated from all constraints.
|
||||
def solveEasyEquality (p : Problem) (c : Coeffs) : Problem :=
|
||||
let i := c.findIdx? (·.natAbs = 1) |>.getD 0 -- findIdx? is always some
|
||||
let sign := c.get i |> Int.sign
|
||||
match p.constraints.find? c with
|
||||
match p.constraints[c]? with
|
||||
| some f =>
|
||||
let init :=
|
||||
{ assumptions := p.assumptions
|
||||
@@ -335,7 +333,7 @@ After solving the easy equality,
|
||||
the minimum lexicographic value of `(c.minNatAbs, c.maxNatAbs)` will have been reduced.
|
||||
-/
|
||||
def dealWithHardEquality (p : Problem) (c : Coeffs) : OmegaM Problem :=
|
||||
match p.constraints.find? c with
|
||||
match p.constraints[c]? with
|
||||
| some ⟨_, ⟨some r, some r'⟩, j⟩ => do
|
||||
let m := c.minNatAbs + 1
|
||||
-- We have to store the valid value of the newly introduced variable in the atoms.
|
||||
@@ -479,7 +477,7 @@ def fourierMotzkinData (p : Problem) : Array FourierMotzkinData := Id.run do
|
||||
let n := p.numVars
|
||||
let mut data : Array FourierMotzkinData :=
|
||||
(List.range p.numVars).foldl (fun a i => a.push { var := i}) #[]
|
||||
for (_, f@⟨xs, s, _⟩) in p.constraints.toList do -- We could make a forIn instance for HashMap
|
||||
for (_, f@⟨xs, s, _⟩) in p.constraints do
|
||||
for i in [0:n] do
|
||||
let x := Coeffs.get xs i
|
||||
data := data.modify i fun d =>
|
||||
|
||||
@@ -58,7 +58,7 @@ structure MetaProblem where
|
||||
-/
|
||||
disjunctions : List Expr := []
|
||||
/-- Facts which have already been processed; we keep these to avoid duplicates. -/
|
||||
processedFacts : HashSet Expr := ∅
|
||||
processedFacts : Std.HashSet Expr := ∅
|
||||
|
||||
/-- Construct the `rfl` proof that `lc.eval atoms = e`. -/
|
||||
def mkEvalRflProof (e : Expr) (lc : LinearCombo) : OmegaM Expr := do
|
||||
@@ -80,7 +80,7 @@ def mkCoordinateEvalAtomsEq (e : Expr) (n : Nat) : OmegaM Expr := do
|
||||
mkEqTrans eq (← mkEqSymm (mkApp2 (.const ``LinearCombo.coordinate_eval []) n atoms))
|
||||
|
||||
/-- Construct the linear combination (and its associated proof and new facts) for an atom. -/
|
||||
def mkAtomLinearCombo (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
|
||||
def mkAtomLinearCombo (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
|
||||
let (n, facts) ← lookup e
|
||||
return ⟨LinearCombo.coordinate n, mkCoordinateEvalAtomsEq e n, facts.getD ∅⟩
|
||||
|
||||
@@ -94,9 +94,9 @@ Gives a small (10%) speedup in testing.
|
||||
I tried using a pointer based cache,
|
||||
but there was never enough subexpression sharing to make it effective.
|
||||
-/
|
||||
partial def asLinearCombo (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
|
||||
partial def asLinearCombo (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
|
||||
let cache ← get
|
||||
match cache.find? e with
|
||||
match cache.get? e with
|
||||
| some (lc, prf) =>
|
||||
trace[omega] "Found in cache: {e}"
|
||||
return (lc, prf, ∅)
|
||||
@@ -120,7 +120,7 @@ We also transform the expression as we descend into it:
|
||||
* pushing coercions: `↑(x + y)`, `↑(x * y)`, `↑(x / k)`, `↑(x % k)`, `↑k`
|
||||
* unfolding `emod`: `x % k` → `x - x / k`
|
||||
-/
|
||||
partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
|
||||
partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
|
||||
trace[omega] "processing {e}"
|
||||
match groundInt? e with
|
||||
| some i =>
|
||||
@@ -142,7 +142,7 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
|
||||
mkEqTrans
|
||||
(← mkAppM ``Int.add_congr #[← prf₁, ← prf₂])
|
||||
(← mkEqSymm add_eval)
|
||||
pure (l₁ + l₂, prf, facts₁.merge facts₂)
|
||||
pure (l₁ + l₂, prf, facts₁.union facts₂)
|
||||
| (``HSub.hSub, #[_, _, _, _, e₁, e₂]) => do
|
||||
let (l₁, prf₁, facts₁) ← asLinearCombo e₁
|
||||
let (l₂, prf₂, facts₂) ← asLinearCombo e₂
|
||||
@@ -152,7 +152,7 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
|
||||
mkEqTrans
|
||||
(← mkAppM ``Int.sub_congr #[← prf₁, ← prf₂])
|
||||
(← mkEqSymm sub_eval)
|
||||
pure (l₁ - l₂, prf, facts₁.merge facts₂)
|
||||
pure (l₁ - l₂, prf, facts₁.union facts₂)
|
||||
| (``Neg.neg, #[_, _, e']) => do
|
||||
let (l, prf, facts) ← asLinearCombo e'
|
||||
let prf' : OmegaM Expr := do
|
||||
@@ -178,7 +178,7 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
|
||||
mkEqTrans
|
||||
(← mkAppM ``Int.mul_congr #[← xprf, ← yprf])
|
||||
(← mkEqSymm mul_eval)
|
||||
pure (some (LinearCombo.mul xl yl, prf, xfacts.merge yfacts), true)
|
||||
pure (some (LinearCombo.mul xl yl, prf, xfacts.union yfacts), true)
|
||||
else
|
||||
pure (none, false)
|
||||
match r? with
|
||||
@@ -235,7 +235,7 @@ where
|
||||
Apply a rewrite rule to an expression, and interpret the result as a `LinearCombo`.
|
||||
(We're not rewriting any subexpressions here, just the top level, for efficiency.)
|
||||
-/
|
||||
rewrite (lhs rw : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
|
||||
rewrite (lhs rw : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
|
||||
trace[omega] "rewriting {lhs} via {rw} : {← inferType rw}"
|
||||
match (← inferType rw).eq? with
|
||||
| some (_, _lhs', rhs) =>
|
||||
@@ -243,7 +243,7 @@ where
|
||||
let prf' : OmegaM Expr := do mkEqTrans rw (← prf)
|
||||
pure (lc, prf', facts)
|
||||
| none => panic! "Invalid rewrite rule in 'asLinearCombo'"
|
||||
handleNatCast (e i n : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
|
||||
handleNatCast (e i n : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
|
||||
match n with
|
||||
| .fvar h =>
|
||||
if let some v ← h.getValue? then
|
||||
@@ -296,7 +296,7 @@ where
|
||||
| (``Fin.val, #[n, x]) =>
|
||||
handleFinVal e i n x
|
||||
| _ => mkAtomLinearCombo e
|
||||
handleFinVal (e i n x : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
|
||||
handleFinVal (e i n x : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
|
||||
match x with
|
||||
| .fvar h =>
|
||||
if let some v ← h.getValue? then
|
||||
@@ -342,7 +342,7 @@ We solve equalities as they are discovered, as this often results in an earlier
|
||||
-/
|
||||
def addIntEquality (p : MetaProblem) (h x : Expr) : OmegaM MetaProblem := do
|
||||
let (lc, prf, facts) ← asLinearCombo x
|
||||
let newFacts : HashSet Expr := facts.fold (init := ∅) fun s e =>
|
||||
let newFacts : Std.HashSet Expr := facts.fold (init := ∅) fun s e =>
|
||||
if p.processedFacts.contains e then s else s.insert e
|
||||
trace[omega] "Adding proof of {lc} = 0"
|
||||
pure <|
|
||||
@@ -358,7 +358,7 @@ We solve equalities as they are discovered, as this often results in an earlier
|
||||
-/
|
||||
def addIntInequality (p : MetaProblem) (h y : Expr) : OmegaM MetaProblem := do
|
||||
let (lc, prf, facts) ← asLinearCombo y
|
||||
let newFacts : HashSet Expr := facts.fold (init := ∅) fun s e =>
|
||||
let newFacts : Std.HashSet Expr := facts.fold (init := ∅) fun s e =>
|
||||
if p.processedFacts.contains e then s else s.insert e
|
||||
trace[omega] "Adding proof of {lc} ≥ 0"
|
||||
pure <|
|
||||
@@ -590,7 +590,7 @@ where
|
||||
|
||||
-- We sort the constraints; otherwise the order is dependent on details of the hashing
|
||||
-- and this can cause test suite output churn
|
||||
prettyConstraints (names : Array String) (constraints : HashMap Coeffs Fact) : String :=
|
||||
prettyConstraints (names : Array String) (constraints : Std.HashMap Coeffs Fact) : String :=
|
||||
constraints.toList
|
||||
|>.toArray
|
||||
|>.qsort (·.1 < ·.1)
|
||||
@@ -615,7 +615,7 @@ where
|
||||
(if Int.natAbs c = 1 then names[i]! else s!"{c.natAbs}*{names[i]!}"))
|
||||
|> String.join
|
||||
|
||||
mentioned (atoms : Array Expr) (constraints : HashMap Coeffs Fact) : MetaM (Array Bool) := do
|
||||
mentioned (atoms : Array Expr) (constraints : Std.HashMap Coeffs Fact) : MetaM (Array Bool) := do
|
||||
let initMask := Array.mkArray atoms.size false
|
||||
return constraints.fold (init := initMask) fun mask coeffs _ =>
|
||||
coeffs.enum.foldl (init := mask) fun mask (i, c) =>
|
||||
|
||||
@@ -10,6 +10,8 @@ import Init.Omega.Logic
|
||||
import Init.Data.BitVec.Basic
|
||||
import Lean.Meta.AppBuilder
|
||||
import Lean.Meta.Canonicalizer
|
||||
import Std.Data.HashMap.Basic
|
||||
import Std.Data.HashSet.Basic
|
||||
|
||||
/-!
|
||||
# The `OmegaM` state monad.
|
||||
@@ -52,7 +54,7 @@ structure Context where
|
||||
/-- The internal state for the `OmegaM` monad, recording previously encountered atoms. -/
|
||||
structure State where
|
||||
/-- The atoms up-to-defeq encountered so far. -/
|
||||
atoms : HashMap Expr Nat := {}
|
||||
atoms : Std.HashMap Expr Nat := {}
|
||||
|
||||
/-- An intermediate layer in the `OmegaM` monad. -/
|
||||
abbrev OmegaM' := StateRefT State (ReaderT Context CanonM)
|
||||
@@ -60,7 +62,7 @@ abbrev OmegaM' := StateRefT State (ReaderT Context CanonM)
|
||||
/--
|
||||
Cache of expressions that have been visited, and their reflection as a linear combination.
|
||||
-/
|
||||
def Cache : Type := HashMap Expr (LinearCombo × OmegaM' Expr)
|
||||
def Cache : Type := Std.HashMap Expr (LinearCombo × OmegaM' Expr)
|
||||
|
||||
/--
|
||||
The `OmegaM` monad maintains two pieces of state:
|
||||
@@ -71,7 +73,7 @@ abbrev OmegaM := StateRefT Cache OmegaM'
|
||||
|
||||
/-- Run a computation in the `OmegaM` monad, starting with no recorded atoms. -/
|
||||
def OmegaM.run (m : OmegaM α) (cfg : OmegaConfig) : MetaM α :=
|
||||
m.run' HashMap.empty |>.run' {} { cfg } |>.run'
|
||||
m.run' Std.HashMap.empty |>.run' {} { cfg } |>.run'
|
||||
|
||||
/-- Retrieve the user-specified configuration options. -/
|
||||
def cfg : OmegaM OmegaConfig := do pure (← read).cfg
|
||||
@@ -162,11 +164,11 @@ def mkEqReflWithExpectedType (a b : Expr) : MetaM Expr := do
|
||||
Analyzes a newly recorded atom,
|
||||
returning a collection of interesting facts about it that should be added to the context.
|
||||
-/
|
||||
def analyzeAtom (e : Expr) : OmegaM (HashSet Expr) := do
|
||||
def analyzeAtom (e : Expr) : OmegaM (Std.HashSet Expr) := do
|
||||
match e.getAppFnArgs with
|
||||
| (``Nat.cast, #[.const ``Int [], _, e']) =>
|
||||
-- Casts of natural numbers are non-negative.
|
||||
let mut r := HashSet.empty.insert (Expr.app (.const ``Int.ofNat_nonneg []) e')
|
||||
let mut r := Std.HashSet.empty.insert (Expr.app (.const ``Int.ofNat_nonneg []) e')
|
||||
match (← cfg).splitNatSub, e'.getAppFnArgs with
|
||||
| true, (``HSub.hSub, #[_, _, _, _, a, b]) =>
|
||||
-- `((a - b : Nat) : Int)` gives a dichotomy
|
||||
@@ -188,7 +190,7 @@ def analyzeAtom (e : Expr) : OmegaM (HashSet Expr) := do
|
||||
let ne_zero := mkApp3 (.const ``Ne [1]) (.const ``Int []) k (toExpr (0 : Int))
|
||||
let pos := mkApp4 (.const ``LT.lt [0]) (.const ``Int []) (.const ``Int.instLTInt [])
|
||||
(toExpr (0 : Int)) k
|
||||
pure <| HashSet.empty.insert
|
||||
pure <| Std.HashSet.empty.insert
|
||||
(mkApp3 (.const ``Int.mul_ediv_self_le []) x k (← mkDecideProof ne_zero)) |>.insert
|
||||
(mkApp3 (.const ``Int.lt_mul_ediv_self_add []) x k (← mkDecideProof pos))
|
||||
| (``HMod.hMod, #[_, _, _, _, x, k]) =>
|
||||
@@ -200,7 +202,7 @@ def analyzeAtom (e : Expr) : OmegaM (HashSet Expr) := do
|
||||
let b_pos := mkApp4 (.const ``LT.lt [0]) (.const ``Int []) (.const ``Int.instLTInt [])
|
||||
(toExpr (0 : Int)) b
|
||||
let pow_pos := mkApp3 (.const ``Lean.Omega.Int.pos_pow_of_pos []) b exp (← mkDecideProof b_pos)
|
||||
pure <| HashSet.empty.insert
|
||||
pure <| Std.HashSet.empty.insert
|
||||
(mkApp3 (.const ``Int.emod_nonneg []) x k
|
||||
(mkApp3 (.const ``Int.ne_of_gt []) k (toExpr (0 : Int)) pow_pos)) |>.insert
|
||||
(mkApp3 (.const ``Int.emod_lt_of_pos []) x k pow_pos)
|
||||
@@ -214,7 +216,7 @@ def analyzeAtom (e : Expr) : OmegaM (HashSet Expr) := do
|
||||
(toExpr (0 : Nat)) b
|
||||
let pow_pos := mkApp3 (.const ``Nat.pos_pow_of_pos []) b exp (← mkDecideProof b_pos)
|
||||
let cast_pos := mkApp2 (.const ``Int.ofNat_pos_of_pos []) k' pow_pos
|
||||
pure <| HashSet.empty.insert
|
||||
pure <| Std.HashSet.empty.insert
|
||||
(mkApp3 (.const ``Int.emod_nonneg []) x k
|
||||
(mkApp3 (.const ``Int.ne_of_gt []) k (toExpr (0 : Int)) cast_pos)) |>.insert
|
||||
(mkApp3 (.const ``Int.emod_lt_of_pos []) x k cast_pos)
|
||||
@@ -222,18 +224,18 @@ def analyzeAtom (e : Expr) : OmegaM (HashSet Expr) := do
|
||||
| (``Nat.cast, #[.const ``Int [], _, x']) =>
|
||||
-- Since we push coercions inside `%`, we need to record here that
|
||||
-- `(x : Int) % (y : Int)` is non-negative.
|
||||
pure <| HashSet.empty.insert (mkApp2 (.const ``Int.emod_ofNat_nonneg []) x' k)
|
||||
pure <| Std.HashSet.empty.insert (mkApp2 (.const ``Int.emod_ofNat_nonneg []) x' k)
|
||||
| _ => pure ∅
|
||||
| _ => pure ∅
|
||||
| (``Min.min, #[_, _, x, y]) =>
|
||||
pure <| HashSet.empty.insert (mkApp2 (.const ``Int.min_le_left []) x y) |>.insert
|
||||
pure <| Std.HashSet.empty.insert (mkApp2 (.const ``Int.min_le_left []) x y) |>.insert
|
||||
(mkApp2 (.const ``Int.min_le_right []) x y)
|
||||
| (``Max.max, #[_, _, x, y]) =>
|
||||
pure <| HashSet.empty.insert (mkApp2 (.const ``Int.le_max_left []) x y) |>.insert
|
||||
pure <| Std.HashSet.empty.insert (mkApp2 (.const ``Int.le_max_left []) x y) |>.insert
|
||||
(mkApp2 (.const ``Int.le_max_right []) x y)
|
||||
| (``ite, #[α, i, dec, t, e]) =>
|
||||
if α == (.const ``Int []) then
|
||||
pure <| HashSet.empty.insert <| mkApp5 (.const ``ite_disjunction [0]) α i dec t e
|
||||
pure <| Std.HashSet.empty.insert <| mkApp5 (.const ``ite_disjunction [0]) α i dec t e
|
||||
else
|
||||
pure {}
|
||||
| _ => pure ∅
|
||||
@@ -248,10 +250,10 @@ Return its index, and, if it is new, a collection of interesting facts about the
|
||||
* for each new atom of the form `((a - b : Nat) : Int)`, the fact:
|
||||
`b ≤ a ∧ ((a - b : Nat) : Int) = a - b ∨ a < b ∧ ((a - b : Nat) : Int) = 0`
|
||||
-/
|
||||
def lookup (e : Expr) : OmegaM (Nat × Option (HashSet Expr)) := do
|
||||
def lookup (e : Expr) : OmegaM (Nat × Option (Std.HashSet Expr)) := do
|
||||
let c ← getThe State
|
||||
let e ← canon e
|
||||
match c.atoms.find? e with
|
||||
match c.atoms[e]? with
|
||||
| some i => return (i, none)
|
||||
| none =>
|
||||
trace[omega] "New atom: {e}"
|
||||
|
||||
@@ -7,7 +7,6 @@ prelude
|
||||
import Init.Control.StateRef
|
||||
import Init.Data.Array.BinSearch
|
||||
import Init.Data.Stream
|
||||
import Lean.Data.HashMap
|
||||
import Lean.ImportingFlag
|
||||
import Lean.Data.SMap
|
||||
import Lean.Declaration
|
||||
@@ -134,7 +133,7 @@ structure Environment where
|
||||
the field `constants`. These auxiliary constants are invisible to the Lean kernel and elaborator.
|
||||
Only the code generator uses them.
|
||||
-/
|
||||
const2ModIdx : HashMap Name ModuleIdx
|
||||
const2ModIdx : Std.HashMap Name ModuleIdx
|
||||
/--
|
||||
Mapping from constant name to `ConstantInfo`. It contains all constants (definitions, theorems, axioms, etc)
|
||||
that have been already type checked by the kernel.
|
||||
@@ -205,7 +204,7 @@ private def getTrustLevel (env : Environment) : UInt32 :=
|
||||
env.header.trustLevel
|
||||
|
||||
def getModuleIdxFor? (env : Environment) (declName : Name) : Option ModuleIdx :=
|
||||
env.const2ModIdx.find? declName
|
||||
env.const2ModIdx[declName]?
|
||||
|
||||
def isConstructor (env : Environment) (declName : Name) : Bool :=
|
||||
match env.find? declName with
|
||||
@@ -721,7 +720,7 @@ def writeModule (env : Environment) (fname : System.FilePath) : IO Unit := do
|
||||
Construct a mapping from persistent extension name to entension index at the array of persistent extensions.
|
||||
We only consider extensions starting with index `>= startingAt`.
|
||||
-/
|
||||
def mkExtNameMap (startingAt : Nat) : IO (HashMap Name Nat) := do
|
||||
def mkExtNameMap (startingAt : Nat) : IO (Std.HashMap Name Nat) := do
|
||||
let descrs ← persistentEnvExtensionsRef.get
|
||||
let mut result := {}
|
||||
for h : i in [startingAt : descrs.size] do
|
||||
@@ -742,7 +741,7 @@ private def setImportedEntries (env : Environment) (mods : Array ModuleData) (st
|
||||
have : modIdx < mods.size := h.upper
|
||||
let mod := mods[modIdx]
|
||||
for (extName, entries) in mod.entries do
|
||||
if let some entryIdx := extNameIdx.find? extName then
|
||||
if let some entryIdx := extNameIdx[extName]? then
|
||||
env := extDescrs[entryIdx]!.toEnvExtension.modifyState env fun s => { s with importedEntries := s.importedEntries.set! modIdx entries }
|
||||
return env
|
||||
|
||||
@@ -790,9 +789,9 @@ structure ImportState where
|
||||
moduleData : Array ModuleData := #[]
|
||||
regions : Array CompactedRegion := #[]
|
||||
|
||||
def throwAlreadyImported (s : ImportState) (const2ModIdx : HashMap Name ModuleIdx) (modIdx : Nat) (cname : Name) : IO α := do
|
||||
def throwAlreadyImported (s : ImportState) (const2ModIdx : Std.HashMap Name ModuleIdx) (modIdx : Nat) (cname : Name) : IO α := do
|
||||
let modName := s.moduleNames[modIdx]!
|
||||
let constModName := s.moduleNames[const2ModIdx[cname].get!.toNat]!
|
||||
let constModName := s.moduleNames[const2ModIdx[cname]!.toNat]!
|
||||
throw <| IO.userError s!"import {modName} failed, environment already contains '{cname}' from {constModName}"
|
||||
|
||||
abbrev ImportStateM := StateRefT ImportState IO
|
||||
@@ -856,21 +855,21 @@ def finalizeImport (s : ImportState) (imports : Array Import) (opts : Options) (
|
||||
(leakEnv := false) : IO Environment := do
|
||||
let numConsts := s.moduleData.foldl (init := 0) fun numConsts mod =>
|
||||
numConsts + mod.constants.size + mod.extraConstNames.size
|
||||
let mut const2ModIdx : HashMap Name ModuleIdx := mkHashMap (capacity := numConsts)
|
||||
let mut constantMap : HashMap Name ConstantInfo := mkHashMap (capacity := numConsts)
|
||||
let mut const2ModIdx : Std.HashMap Name ModuleIdx := Std.HashMap.empty (capacity := numConsts)
|
||||
let mut constantMap : Std.HashMap Name ConstantInfo := Std.HashMap.empty (capacity := numConsts)
|
||||
for h:modIdx in [0:s.moduleData.size] do
|
||||
let mod := s.moduleData[modIdx]'h.upper
|
||||
for cname in mod.constNames, cinfo in mod.constants do
|
||||
match constantMap.insertIfNew cname cinfo with
|
||||
| (constantMap', cinfoPrev?) =>
|
||||
match constantMap.getThenInsertIfNew? cname cinfo with
|
||||
| (cinfoPrev?, constantMap') =>
|
||||
constantMap := constantMap'
|
||||
if let some cinfoPrev := cinfoPrev? then
|
||||
-- Recall that the map has not been modified when `cinfoPrev? = some _`.
|
||||
unless equivInfo cinfoPrev cinfo do
|
||||
throwAlreadyImported s const2ModIdx modIdx cname
|
||||
const2ModIdx := const2ModIdx.insertIfNew cname modIdx |>.1
|
||||
const2ModIdx := const2ModIdx.insertIfNew cname modIdx
|
||||
for cname in mod.extraConstNames do
|
||||
const2ModIdx := const2ModIdx.insertIfNew cname modIdx |>.1
|
||||
const2ModIdx := const2ModIdx.insertIfNew cname modIdx
|
||||
let constants : ConstMap := SMap.fromHashMap constantMap false
|
||||
let exts ← mkInitialExtensionStates
|
||||
let mut env : Environment := {
|
||||
@@ -936,7 +935,7 @@ builtin_initialize namespacesExt : SimplePersistentEnvExtension Name NameSSet
|
||||
6.18% of the runtime is here. It was 9.31% before the `HashMap` optimization.
|
||||
-/
|
||||
let capacity := as.foldl (init := 0) fun r e => r + e.size
|
||||
let map : HashMap Name Unit := mkHashMap capacity
|
||||
let map : Std.HashMap Name Unit := Std.HashMap.empty capacity
|
||||
let map := mkStateFromImportedEntries (fun map name => map.insert name ()) map as
|
||||
SMap.fromHashMap map |>.switch
|
||||
addEntryFn := fun s n => s.insert n
|
||||
|
||||
@@ -8,6 +8,7 @@ import Init.Data.Hashable
|
||||
import Lean.Data.KVMap
|
||||
import Lean.Data.SMap
|
||||
import Lean.Level
|
||||
import Std.Data.HashSet.Basic
|
||||
|
||||
namespace Lean
|
||||
|
||||
@@ -244,7 +245,7 @@ def FVarIdSet.insert (s : FVarIdSet) (fvarId : FVarId) : FVarIdSet :=
|
||||
A set of unique free variable identifiers implemented using hashtables.
|
||||
Hashtables are faster than red-black trees if they are used linearly.
|
||||
They are not persistent data-structures. -/
|
||||
def FVarIdHashSet := HashSet FVarId
|
||||
def FVarIdHashSet := Std.HashSet FVarId
|
||||
deriving Inhabited, EmptyCollection
|
||||
|
||||
/--
|
||||
@@ -1388,11 +1389,11 @@ def mkDecIsTrue (pred proof : Expr) :=
|
||||
def mkDecIsFalse (pred proof : Expr) :=
|
||||
mkAppB (mkConst `Decidable.isFalse) pred proof
|
||||
|
||||
abbrev ExprMap (α : Type) := HashMap Expr α
|
||||
abbrev ExprMap (α : Type) := Std.HashMap Expr α
|
||||
abbrev PersistentExprMap (α : Type) := PHashMap Expr α
|
||||
abbrev SExprMap (α : Type) := SMap Expr α
|
||||
|
||||
abbrev ExprSet := HashSet Expr
|
||||
abbrev ExprSet := Std.HashSet Expr
|
||||
abbrev PersistentExprSet := PHashSet Expr
|
||||
abbrev PExprSet := PersistentExprSet
|
||||
|
||||
@@ -1417,7 +1418,7 @@ instance : ToString ExprStructEq := ⟨fun e => toString e.val⟩
|
||||
|
||||
end ExprStructEq
|
||||
|
||||
abbrev ExprStructMap (α : Type) := HashMap ExprStructEq α
|
||||
abbrev ExprStructMap (α : Type) := Std.HashMap ExprStructEq α
|
||||
abbrev PersistentExprStructMap (α : Type) := PHashMap ExprStructEq α
|
||||
|
||||
namespace Expr
|
||||
|
||||
@@ -31,7 +31,7 @@ namespace Lean
|
||||
abbrev LabelExtension := SimpleScopedEnvExtension Name (Array Name)
|
||||
|
||||
/-- The collection of all current `LabelExtension`s, indexed by name. -/
|
||||
abbrev LabelExtensionMap := HashMap Name LabelExtension
|
||||
abbrev LabelExtensionMap := Std.HashMap Name LabelExtension
|
||||
|
||||
/-- Store the current `LabelExtension`s. -/
|
||||
builtin_initialize labelExtensionMapRef : IO.Ref LabelExtensionMap ← IO.mkRef {}
|
||||
@@ -88,7 +88,7 @@ macro (name := _root_.Lean.Parser.Command.registerLabelAttr)
|
||||
/-- When `attrName` is an attribute created using `register_labelled_attr`,
|
||||
return the names of all declarations labelled using that attribute. -/
|
||||
def labelled (attrName : Name) : CoreM (Array Name) := do
|
||||
match (← labelExtensionMapRef.get).find? attrName with
|
||||
match (← labelExtensionMapRef.get)[attrName]? with
|
||||
| none => throwError "No extension named {attrName}"
|
||||
| some ext => pure <| ext.getState (← getEnv)
|
||||
|
||||
|
||||
@@ -5,8 +5,6 @@ Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Init.Data.Array.QSort
|
||||
import Lean.Data.HashMap
|
||||
import Lean.Data.HashSet
|
||||
import Lean.Data.PersistentHashMap
|
||||
import Lean.Data.PersistentHashSet
|
||||
import Lean.Hygiene
|
||||
@@ -614,9 +612,9 @@ where
|
||||
|
||||
end Level
|
||||
|
||||
abbrev LevelMap (α : Type) := HashMap Level α
|
||||
abbrev LevelMap (α : Type) := Std.HashMap Level α
|
||||
abbrev PersistentLevelMap (α : Type) := PHashMap Level α
|
||||
abbrev LevelSet := HashSet Level
|
||||
abbrev LevelSet := Std.HashSet Level
|
||||
abbrev PersistentLevelSet := PHashSet Level
|
||||
abbrev PLevelSet := PersistentLevelSet
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ def constructorNameAsVariable : Linter where
|
||||
| return
|
||||
|
||||
let infoTrees := (← get).infoState.trees.toArray
|
||||
let warnings : IO.Ref (Lean.HashMap String.Range (Syntax × Name × Name)) ← IO.mkRef {}
|
||||
let warnings : IO.Ref (Std.HashMap String.Range (Syntax × Name × Name)) ← IO.mkRef {}
|
||||
|
||||
for tree in infoTrees do
|
||||
tree.visitM' (preNode := fun ci info _ => do
|
||||
|
||||
@@ -149,7 +149,7 @@ def checkDecl : SimpleHandler := fun stx => do
|
||||
lintField rest[1][0] stx[1] "computed field"
|
||||
else if rest.getKind == ``«structure» then
|
||||
unless rest[5][2].isNone do
|
||||
let redecls : HashSet String.Pos :=
|
||||
let redecls : Std.HashSet String.Pos :=
|
||||
(← get).infoState.trees.foldl (init := {}) fun s tree =>
|
||||
tree.foldInfo (init := s) fun _ info s =>
|
||||
if let .ofFieldRedeclInfo info := info then
|
||||
|
||||
@@ -270,14 +270,14 @@ pointer identity and does not store the objects, so it is important not to store
|
||||
pointer to an object in the map, or it can be freed and reused, resulting in incorrect behavior.
|
||||
|
||||
Returns `true` if the object was not already in the set. -/
|
||||
unsafe def insertObjImpl {α : Type} (set : IO.Ref (HashSet USize)) (a : α) : IO Bool := do
|
||||
unsafe def insertObjImpl {α : Type} (set : IO.Ref (Std.HashSet USize)) (a : α) : IO Bool := do
|
||||
if (← set.get).contains (ptrAddrUnsafe a) then
|
||||
return false
|
||||
set.modify (·.insert (ptrAddrUnsafe a))
|
||||
return true
|
||||
|
||||
@[inherit_doc insertObjImpl, implemented_by insertObjImpl]
|
||||
opaque insertObj {α : Type} (set : IO.Ref (HashSet USize)) (a : α) : IO Bool
|
||||
opaque insertObj {α : Type} (set : IO.Ref (Std.HashSet USize)) (a : α) : IO Bool
|
||||
|
||||
/--
|
||||
Collects into `fvarUses` all `fvar`s occurring in the `Expr`s in `assignments`.
|
||||
@@ -285,8 +285,8 @@ This implementation respects subterm sharing in both the `PersistentHashMap` and
|
||||
to ensure that pointer-equal subobjects are not visited multiple times, which is important
|
||||
in practice because these expressions are very frequently highly shared.
|
||||
-/
|
||||
partial def visitAssignments (set : IO.Ref (HashSet USize))
|
||||
(fvarUses : IO.Ref (HashSet FVarId))
|
||||
partial def visitAssignments (set : IO.Ref (Std.HashSet USize))
|
||||
(fvarUses : IO.Ref (Std.HashSet FVarId))
|
||||
(assignments : Array (PersistentHashMap MVarId Expr)) : IO Unit := do
|
||||
MonadCacheT.run do
|
||||
for assignment in assignments do
|
||||
@@ -316,8 +316,8 @@ where
|
||||
/-- Given `aliases` as a map from an alias to what it aliases, we get the original
|
||||
term by recursion. This has no cycle detection, so if `aliases` contains a loop
|
||||
then this function will recurse infinitely. -/
|
||||
partial def followAliases (aliases : HashMap FVarId FVarId) (x : FVarId) : FVarId :=
|
||||
match aliases.find? x with
|
||||
partial def followAliases (aliases : Std.HashMap FVarId FVarId) (x : FVarId) : FVarId :=
|
||||
match aliases[x]? with
|
||||
| none => x
|
||||
| some y => followAliases aliases y
|
||||
|
||||
@@ -343,17 +343,17 @@ structure References where
|
||||
the spans for `foo`, `bar`, and `baz`. Global definitions are always treated as used.
|
||||
(It would be nice to be able to detect unused global definitions but this requires more
|
||||
information than the linter framework can provide.) -/
|
||||
constDecls : HashSet String.Range := .empty
|
||||
constDecls : Std.HashSet String.Range := .empty
|
||||
/-- The collection of all local declarations, organized by the span of the declaration.
|
||||
We collapse all declarations declared at the same position into a single record using
|
||||
`FVarDefinition.aliases`. -/
|
||||
fvarDefs : HashMap String.Range FVarDefinition := .empty
|
||||
fvarDefs : Std.HashMap String.Range FVarDefinition := .empty
|
||||
/-- The set of `FVarId`s that are used directly. These may or may not be aliases. -/
|
||||
fvarUses : HashSet FVarId := .empty
|
||||
fvarUses : Std.HashSet FVarId := .empty
|
||||
/-- A mapping from alias to original FVarId. We don't guarantee that the value is not itself
|
||||
an alias, but we use `followAliases` when adding new elements to try to avoid long chains. -/
|
||||
-- TODO: use a `UnionFind` data structure here
|
||||
fvarAliases : HashMap FVarId FVarId := .empty
|
||||
fvarAliases : Std.HashMap FVarId FVarId := .empty
|
||||
/-- Collection of all `MetavarContext`s following the execution of a tactic. We trawl these
|
||||
if needed to find additional `fvarUses`. -/
|
||||
assignments : Array (PersistentHashMap MVarId Expr) := #[]
|
||||
@@ -391,7 +391,7 @@ def collectReferences (infoTrees : Array Elab.InfoTree) (cmdStxRange : String.Ra
|
||||
if s.startsWith "_" then return
|
||||
-- Record this either as a new `fvarDefs`, or an alias of an existing one
|
||||
modify fun s =>
|
||||
if let some ref := s.fvarDefs.find? range then
|
||||
if let some ref := s.fvarDefs[range]? then
|
||||
{ s with fvarDefs := s.fvarDefs.insert range { ref with aliases := ref.aliases.push id } }
|
||||
else
|
||||
{ s with fvarDefs := s.fvarDefs.insert range { userName := ldecl.userName, stx, opts, aliases := #[id] } }
|
||||
@@ -444,7 +444,7 @@ def unusedVariables : Linter where
|
||||
-- Resolve all recursive references in `fvarAliases`.
|
||||
-- At this point everything in `fvarAliases` is guaranteed not to be itself an alias,
|
||||
-- and should point to some element of `FVarDefinition.aliases` in `s.fvarDefs`
|
||||
let fvarAliases : HashMap FVarId FVarId := s.fvarAliases.fold (init := {}) fun m id baseId =>
|
||||
let fvarAliases : Std.HashMap FVarId FVarId := s.fvarAliases.fold (init := {}) fun m id baseId =>
|
||||
m.insert id (followAliases s.fvarAliases baseId)
|
||||
|
||||
-- Collect all non-alias fvars corresponding to `fvarUses` by resolving aliases in the list.
|
||||
@@ -461,7 +461,7 @@ def unusedVariables : Linter where
|
||||
let fvarUses ← fvarUsesRef.get
|
||||
-- If any of the `fvar`s corresponding to this declaration is (an alias of) a variable in
|
||||
-- `fvarUses`, then it is used
|
||||
if aliases.any fun id => fvarUses.contains (fvarAliases.findD id id) then continue
|
||||
if aliases.any fun id => fvarUses.contains (fvarAliases.getD id id) then continue
|
||||
-- If this is a global declaration then it is (potentially) used after the command
|
||||
if s.constDecls.contains range then continue
|
||||
|
||||
@@ -496,7 +496,7 @@ def unusedVariables : Linter where
|
||||
initializedMVars := true
|
||||
let fvarUses ← fvarUsesRef.get
|
||||
-- Redo the initial check because `fvarUses` could be bigger now
|
||||
if aliases.any fun id => fvarUses.contains (fvarAliases.findD id id) then continue
|
||||
if aliases.any fun id => fvarUses.contains (fvarAliases.getD id id) then continue
|
||||
|
||||
-- If we made it this far then the variable is unused and not ignored
|
||||
unused := unused.push (declStx, userName)
|
||||
|
||||
@@ -16,8 +16,8 @@ structure State where
|
||||
nextParamIdx : Nat := 0
|
||||
paramNames : Array Name := #[]
|
||||
fvars : Array Expr := #[]
|
||||
lmap : HashMap LMVarId Level := {}
|
||||
emap : HashMap MVarId Expr := {}
|
||||
lmap : Std.HashMap LMVarId Level := {}
|
||||
emap : Std.HashMap MVarId Expr := {}
|
||||
abstractLevels : Bool -- whether to abstract level mvars
|
||||
|
||||
abbrev M := StateM State
|
||||
@@ -54,7 +54,7 @@ private partial def abstractLevelMVars (u : Level) : M Level := do
|
||||
if depth != s.mctx.depth then
|
||||
return u -- metavariables from lower depths are treated as constants
|
||||
else
|
||||
match s.lmap.find? mvarId with
|
||||
match s.lmap[mvarId]? with
|
||||
| some u => pure u
|
||||
| none =>
|
||||
let paramId := Name.mkNum `_abstMVar s.nextParamIdx
|
||||
@@ -87,7 +87,7 @@ partial def abstractExprMVars (e : Expr) : M Expr := do
|
||||
if e != eNew then
|
||||
abstractExprMVars eNew
|
||||
else
|
||||
match (← get).emap.find? mvarId with
|
||||
match (← get).emap[mvarId]? with
|
||||
| some e =>
|
||||
return e
|
||||
| none =>
|
||||
|
||||
@@ -5,9 +5,9 @@ Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.Util.ShareCommon
|
||||
import Lean.Data.HashMap
|
||||
import Lean.Meta.Basic
|
||||
import Lean.Meta.FunInfo
|
||||
import Std.Data.HashMap.Raw
|
||||
|
||||
namespace Lean.Meta
|
||||
namespace Canonicalizer
|
||||
@@ -47,12 +47,12 @@ State for the `CanonM` monad.
|
||||
-/
|
||||
structure State where
|
||||
/-- Mapping from `Expr` to hash. -/
|
||||
-- We use `HashMapImp` to ensure we don't have to tag `State` as `unsafe`.
|
||||
cache : HashMapImp ExprVisited UInt64 := mkHashMapImp
|
||||
-- We use `HashMap.Raw` to ensure we don't have to tag `State` as `unsafe`.
|
||||
cache : Std.HashMap.Raw ExprVisited UInt64 := Std.HashMap.Raw.empty
|
||||
/--
|
||||
Given a hashcode `k` and `keyToExprs.find? h = some es`, we have that all `es` have hashcode `k`, and
|
||||
are not definitionally equal modulo the transparency setting used. -/
|
||||
keyToExprs : HashMap UInt64 (List Expr) := mkHashMap
|
||||
keyToExprs : Std.HashMap UInt64 (List Expr) := ∅
|
||||
|
||||
instance : Inhabited State where
|
||||
default := {}
|
||||
@@ -70,7 +70,7 @@ def CanonM.run (x : CanonM α) (transparency := TransparencyMode.instances) (s :
|
||||
StateRefT'.run (x transparency) s
|
||||
|
||||
private partial def mkKey (e : Expr) : CanonM UInt64 := do
|
||||
if let some hash := unsafe (← get).cache.find? { e } then
|
||||
if let some hash := unsafe (← get).cache.get? { e } then
|
||||
return hash
|
||||
else
|
||||
let key ← match e with
|
||||
@@ -107,7 +107,7 @@ private partial def mkKey (e : Expr) : CanonM UInt64 := do
|
||||
return mixHash (← mkKey v) (← mkKey b)
|
||||
| .proj _ i s =>
|
||||
return mixHash i.toUInt64 (← mkKey s)
|
||||
unsafe modify fun { cache, keyToExprs} => { keyToExprs, cache := cache.insert { e } key |>.1 }
|
||||
unsafe modify fun { cache, keyToExprs} => { keyToExprs, cache := cache.insert { e } key }
|
||||
return key
|
||||
|
||||
/--
|
||||
@@ -116,7 +116,7 @@ private partial def mkKey (e : Expr) : CanonM UInt64 := do
|
||||
def canon (e : Expr) : CanonM Expr := do
|
||||
let k ← mkKey e
|
||||
-- Find all expressions canonicalized before that have the same key.
|
||||
if let some es' := unsafe (← get).keyToExprs.find? k then
|
||||
if let some es' := unsafe (← get).keyToExprs[k]? then
|
||||
withTransparency (← read) do
|
||||
for e' in es' do
|
||||
-- Found an expression `e'` that is definitionally equal to `e` and share the same key.
|
||||
|
||||
@@ -127,7 +127,7 @@ abbrev ClosureM := ReaderT Context $ StateRefT State MetaM
|
||||
pure u
|
||||
else
|
||||
let s ← get
|
||||
match s.visitedLevel.find? u with
|
||||
match s.visitedLevel[u]? with
|
||||
| some v => pure v
|
||||
| none => do
|
||||
let v ← f u
|
||||
@@ -139,7 +139,7 @@ abbrev ClosureM := ReaderT Context $ StateRefT State MetaM
|
||||
pure e
|
||||
else
|
||||
let s ← get
|
||||
match s.visitedExpr.find? e with
|
||||
match s.visitedExpr.get? e with
|
||||
| some r => pure r
|
||||
| none =>
|
||||
let r ← f e
|
||||
|
||||
@@ -52,14 +52,14 @@ which appear in the type and local context of `mvarId`, as well as the
|
||||
metavariables which *those* metavariables depend on, etc.
|
||||
-/
|
||||
partial def _root_.Lean.MVarId.getMVarDependencies (mvarId : MVarId) (includeDelayed := false) :
|
||||
MetaM (HashSet MVarId) :=
|
||||
MetaM (Std.HashSet MVarId) :=
|
||||
(·.snd) <$> (go mvarId).run {}
|
||||
where
|
||||
/-- Auxiliary definition for `getMVarDependencies`. -/
|
||||
addMVars (e : Expr) : StateRefT (HashSet MVarId) MetaM Unit := do
|
||||
addMVars (e : Expr) : StateRefT (Std.HashSet MVarId) MetaM Unit := do
|
||||
let mvars ← getMVars e
|
||||
let mut s ← get
|
||||
set ({} : HashSet MVarId) -- Ensure that `s` is not shared.
|
||||
set ({} : Std.HashSet MVarId) -- Ensure that `s` is not shared.
|
||||
for mvarId in mvars do
|
||||
if ← pure includeDelayed <||> notM (mvarId.isDelayedAssigned) then
|
||||
s := s.insert mvarId
|
||||
@@ -67,7 +67,7 @@ where
|
||||
mvars.forM go
|
||||
|
||||
/-- Auxiliary definition for `getMVarDependencies`. -/
|
||||
go (mvarId : MVarId) : StateRefT (HashSet MVarId) MetaM Unit :=
|
||||
go (mvarId : MVarId) : StateRefT (Std.HashSet MVarId) MetaM Unit :=
|
||||
withIncRecDepth do
|
||||
let mdecl ← mvarId.getDecl
|
||||
addMVars mdecl.type
|
||||
|
||||
@@ -695,7 +695,7 @@ def throwOutOfScopeFVar : CheckAssignmentM α :=
|
||||
throw <| Exception.internal outOfScopeExceptionId
|
||||
|
||||
private def findCached? (e : Expr) : CheckAssignmentM (Option Expr) := do
|
||||
return (← get).cache.find? e
|
||||
return (← get).cache.get? e
|
||||
|
||||
private def cache (e r : Expr) : CheckAssignmentM Unit := do
|
||||
modify fun s => { s with cache := s.cache.insert e r }
|
||||
|
||||
@@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Author: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.Data.AssocList
|
||||
import Lean.HeadIndex
|
||||
import Lean.Meta.Basic
|
||||
|
||||
|
||||
@@ -286,7 +286,7 @@ private structure Trie (α : Type) where
|
||||
/-- Index of trie matching star. -/
|
||||
star : TrieIndex
|
||||
/-- Following matches based on key of trie. -/
|
||||
children : HashMap Key TrieIndex
|
||||
children : Std.HashMap Key TrieIndex
|
||||
/-- Lazy entries at this trie that are not processed. -/
|
||||
pending : Array (LazyEntry α) := #[]
|
||||
deriving Inhabited
|
||||
@@ -318,7 +318,7 @@ structure LazyDiscrTree (α : Type) where
|
||||
/-- Backing array of trie entries. Should be owned by this trie. -/
|
||||
tries : Array (LazyDiscrTree.Trie α) := #[default]
|
||||
/-- Map from discriminator trie roots to the index. -/
|
||||
roots : Lean.HashMap LazyDiscrTree.Key LazyDiscrTree.TrieIndex := {}
|
||||
roots : Std.HashMap LazyDiscrTree.Key LazyDiscrTree.TrieIndex := {}
|
||||
|
||||
namespace LazyDiscrTree
|
||||
|
||||
@@ -445,9 +445,9 @@ private def addLazyEntryToTrie (i:TrieIndex) (e : LazyEntry α) : MatchM α Unit
|
||||
modify (·.modify i (·.pushPending e))
|
||||
|
||||
private def evalLazyEntry (config : WhnfCoreConfig)
|
||||
(p : Array α × TrieIndex × HashMap Key TrieIndex)
|
||||
(p : Array α × TrieIndex × Std.HashMap Key TrieIndex)
|
||||
(entry : LazyEntry α)
|
||||
: MatchM α (Array α × TrieIndex × HashMap Key TrieIndex) := do
|
||||
: MatchM α (Array α × TrieIndex × Std.HashMap Key TrieIndex) := do
|
||||
let (values, starIdx, children) := p
|
||||
let (todo, lctx, v) := entry
|
||||
if todo.isEmpty then
|
||||
@@ -465,7 +465,7 @@ private def evalLazyEntry (config : WhnfCoreConfig)
|
||||
addLazyEntryToTrie starIdx (todo, lctx, v)
|
||||
pure (values, starIdx, children)
|
||||
else
|
||||
match children.find? k with
|
||||
match children[k]? with
|
||||
| none =>
|
||||
let children := children.insert k (← newTrie (todo, lctx, v))
|
||||
pure (values, starIdx, children)
|
||||
@@ -478,16 +478,16 @@ This evaluates all lazy entries in a trie and updates `values`, `starIdx`, and `
|
||||
accordingly.
|
||||
-/
|
||||
private partial def evalLazyEntries (config : WhnfCoreConfig)
|
||||
(values : Array α) (starIdx : TrieIndex) (children : HashMap Key TrieIndex)
|
||||
(values : Array α) (starIdx : TrieIndex) (children : Std.HashMap Key TrieIndex)
|
||||
(entries : Array (LazyEntry α)) :
|
||||
MatchM α (Array α × TrieIndex × HashMap Key TrieIndex) := do
|
||||
MatchM α (Array α × TrieIndex × Std.HashMap Key TrieIndex) := do
|
||||
let mut values := values
|
||||
let mut starIdx := starIdx
|
||||
let mut children := children
|
||||
entries.foldlM (init := (values, starIdx, children)) (evalLazyEntry config)
|
||||
|
||||
private def evalNode (c : TrieIndex) :
|
||||
MatchM α (Array α × TrieIndex × HashMap Key TrieIndex) := do
|
||||
MatchM α (Array α × TrieIndex × Std.HashMap Key TrieIndex) := do
|
||||
let .node vs star cs pending := (←get).get! c
|
||||
if pending.size = 0 then
|
||||
pure (vs, star, cs)
|
||||
@@ -508,7 +508,7 @@ def dropKeyAux (next : TrieIndex) (rest : List Key) :
|
||||
| [] =>
|
||||
modify (·.set! next {values := #[], star, children})
|
||||
| k :: r => do
|
||||
let next := if k == .star then star else children.findD k 0
|
||||
let next := if k == .star then star else children.getD k 0
|
||||
dropKeyAux next r
|
||||
|
||||
/--
|
||||
@@ -519,7 +519,7 @@ def dropKey (t : LazyDiscrTree α) (path : List LazyDiscrTree.Key) : MetaM (Lazy
|
||||
match path with
|
||||
| [] => pure t
|
||||
| rootKey :: rest => do
|
||||
let idx := t.roots.findD rootKey 0
|
||||
let idx := t.roots.getD rootKey 0
|
||||
Prod.snd <$> runMatch t (dropKeyAux idx rest)
|
||||
|
||||
/--
|
||||
@@ -628,7 +628,7 @@ private partial def getMatchLoop (cases : Array PartialMatch) (result : MatchRes
|
||||
else
|
||||
cases.push { todo, score := ca.score, c := star }
|
||||
let pushNonStar (k : Key) (args : Array Expr) (cases : Array PartialMatch) :=
|
||||
match cs.find? k with
|
||||
match cs[k]? with
|
||||
| none => cases
|
||||
| some c => cases.push { todo := todo ++ args, score := ca.score + 1, c }
|
||||
let cases := pushStar cases
|
||||
@@ -650,8 +650,8 @@ private partial def getMatchLoop (cases : Array PartialMatch) (result : MatchRes
|
||||
cases |> pushNonStar k args
|
||||
getMatchLoop cases result
|
||||
|
||||
private def getStarResult (root : Lean.HashMap Key TrieIndex) : MatchM α (MatchResult α) :=
|
||||
match root.find? .star with
|
||||
private def getStarResult (root : Std.HashMap Key TrieIndex) : MatchM α (MatchResult α) :=
|
||||
match root[Key.star]? with
|
||||
| none =>
|
||||
pure <| {}
|
||||
| some idx => do
|
||||
@@ -661,16 +661,16 @@ private def getStarResult (root : Lean.HashMap Key TrieIndex) : MatchM α (Match
|
||||
/-
|
||||
Add partial match to cases if discriminator tree root map has potential matches.
|
||||
-/
|
||||
private def pushRootCase (r : Lean.HashMap Key TrieIndex) (k : Key) (args : Array Expr)
|
||||
private def pushRootCase (r : Std.HashMap Key TrieIndex) (k : Key) (args : Array Expr)
|
||||
(cases : Array PartialMatch) : Array PartialMatch :=
|
||||
match r.find? k with
|
||||
match r[k]? with
|
||||
| none => cases
|
||||
| some c => cases.push { todo := args, score := 1, c }
|
||||
|
||||
/--
|
||||
Find values that match `e` in `root`.
|
||||
-/
|
||||
private def getMatchCore (root : Lean.HashMap Key TrieIndex) (e : Expr) :
|
||||
private def getMatchCore (root : Std.HashMap Key TrieIndex) (e : Expr) :
|
||||
MatchM α (MatchResult α) := do
|
||||
let result ← getStarResult root
|
||||
let (k, args) ← MatchClone.getMatchKeyArgs e (root := true) (← read)
|
||||
@@ -701,7 +701,7 @@ of elements using concurrent functions for generating entries.
|
||||
-/
|
||||
private structure PreDiscrTree (α : Type) where
|
||||
/-- Maps keys to index in tries array. -/
|
||||
roots : HashMap Key Nat := {}
|
||||
roots : Std.HashMap Key Nat := {}
|
||||
/-- Lazy entries for root of trie. -/
|
||||
tries : Array (Array (LazyEntry α)) := #[]
|
||||
deriving Inhabited
|
||||
@@ -711,7 +711,7 @@ namespace PreDiscrTree
|
||||
private def modifyAt (d : PreDiscrTree α) (k : Key)
|
||||
(f : Array (LazyEntry α) → Array (LazyEntry α)) : PreDiscrTree α :=
|
||||
let { roots, tries } := d
|
||||
match roots.find? k with
|
||||
match roots[k]? with
|
||||
| .none =>
|
||||
let roots := roots.insert k tries.size
|
||||
{ roots, tries := tries.push (f #[]) }
|
||||
|
||||
@@ -68,7 +68,7 @@ where
|
||||
loop lhss alts minors
|
||||
|
||||
structure State where
|
||||
used : HashSet Nat := {} -- used alternatives
|
||||
used : Std.HashSet Nat := {} -- used alternatives
|
||||
counterExamples : List (List Example) := []
|
||||
|
||||
/-- Return true if the given (sub-)problem has been solved. -/
|
||||
|
||||
@@ -28,17 +28,17 @@ such as `contradiction`.
|
||||
-/
|
||||
private def _root_.Lean.MVarId.contradictionQuick (mvarId : MVarId) : MetaM Bool := do
|
||||
mvarId.withContext do
|
||||
let mut posMap : HashMap Expr FVarId := {}
|
||||
let mut negMap : HashMap Expr FVarId := {}
|
||||
let mut posMap : Std.HashMap Expr FVarId := {}
|
||||
let mut negMap : Std.HashMap Expr FVarId := {}
|
||||
for localDecl in (← getLCtx) do
|
||||
unless localDecl.isImplementationDetail do
|
||||
if let some p ← matchNot? localDecl.type then
|
||||
if let some pFVarId := posMap.find? p then
|
||||
if let some pFVarId := posMap[p]? then
|
||||
mvarId.assign (← mkAbsurd (← mvarId.getType) (mkFVar pFVarId) localDecl.toExpr)
|
||||
return true
|
||||
negMap := negMap.insert p localDecl.fvarId
|
||||
if (← isProp localDecl.type) then
|
||||
if let some nFVarId := negMap.find? localDecl.type then
|
||||
if let some nFVarId := negMap[localDecl.type]? then
|
||||
mvarId.assign (← mkAbsurd (← mvarId.getType) localDecl.toExpr (mkFVar nFVarId))
|
||||
return true
|
||||
posMap := posMap.insert localDecl.type localDecl.fvarId
|
||||
|
||||
@@ -97,8 +97,8 @@ namespace MkTableKey
|
||||
|
||||
structure State where
|
||||
nextIdx : Nat := 0
|
||||
lmap : HashMap LMVarId Level := {}
|
||||
emap : HashMap MVarId Expr := {}
|
||||
lmap : Std.HashMap LMVarId Level := {}
|
||||
emap : Std.HashMap MVarId Expr := {}
|
||||
mctx : MetavarContext
|
||||
|
||||
abbrev M := StateM State
|
||||
@@ -120,7 +120,7 @@ partial def normLevel (u : Level) : M Level := do
|
||||
return u
|
||||
else
|
||||
let s ← get
|
||||
match (← get).lmap.find? mvarId with
|
||||
match (← get).lmap[mvarId]? with
|
||||
| some u' => pure u'
|
||||
| none =>
|
||||
let u' := mkLevelParam <| Name.mkNum `_tc s.nextIdx
|
||||
@@ -145,7 +145,7 @@ partial def normExpr (e : Expr) : M Expr := do
|
||||
return e
|
||||
else
|
||||
let s ← get
|
||||
match s.emap.find? mvarId with
|
||||
match s.emap[mvarId]? with
|
||||
| some e' => pure e'
|
||||
| none => do
|
||||
let e' := mkFVar { name := Name.mkNum `_tc s.nextIdx }
|
||||
@@ -186,7 +186,7 @@ structure State where
|
||||
result? : Option AbstractMVarsResult := none
|
||||
generatorStack : Array GeneratorNode := #[]
|
||||
resumeStack : Array (ConsumerNode × Answer) := #[]
|
||||
tableEntries : HashMap Expr TableEntry := {}
|
||||
tableEntries : Std.HashMap Expr TableEntry := {}
|
||||
|
||||
abbrev SynthM := ReaderT Context $ StateRefT State MetaM
|
||||
|
||||
@@ -265,7 +265,7 @@ def newSubgoal (mctx : MetavarContext) (key : Expr) (mvar : Expr) (waiter : Wait
|
||||
pure ((), m!"new goal {key}")
|
||||
|
||||
def findEntry? (key : Expr) : SynthM (Option TableEntry) := do
|
||||
return (← get).tableEntries.find? key
|
||||
return (← get).tableEntries[key]?
|
||||
|
||||
def getEntry (key : Expr) : SynthM TableEntry := do
|
||||
match (← findEntry? key) with
|
||||
@@ -553,7 +553,7 @@ def generate : SynthM Unit := do
|
||||
/- See comment at `typeHasMVars` -/
|
||||
if backward.synthInstance.canonInstances.get (← getOptions) then
|
||||
unless gNode.typeHasMVars do
|
||||
if let some entry := (← get).tableEntries.find? key then
|
||||
if let some entry := (← get).tableEntries[key]? then
|
||||
if entry.answers.any fun answer => answer.result.numMVars == 0 then
|
||||
/-
|
||||
We already have an answer that:
|
||||
|
||||
@@ -66,9 +66,9 @@ inductive PreExpr
|
||||
def toACExpr (op l r : Expr) : MetaM (Array Expr × ACExpr) := do
|
||||
let (preExpr, vars) ←
|
||||
toPreExpr (mkApp2 op l r)
|
||||
|>.run HashSet.empty
|
||||
|>.run Std.HashSet.empty
|
||||
let vars := vars.toArray.insertionSort Expr.lt
|
||||
let varMap := vars.foldl (fun xs x => xs.insert x xs.size) HashMap.empty |>.find!
|
||||
let varMap := vars.foldl (fun xs x => xs.insert x xs.size) Std.HashMap.empty |>.get!
|
||||
|
||||
return (vars, toACExpr varMap preExpr)
|
||||
where
|
||||
|
||||
@@ -290,7 +290,7 @@ structure RewriteResultConfig where
|
||||
side : SideConditions := .solveByElim
|
||||
mctx : MetavarContext
|
||||
|
||||
def takeListAux (cfg : RewriteResultConfig) (seen : HashMap String Unit) (acc : Array RewriteResult)
|
||||
def takeListAux (cfg : RewriteResultConfig) (seen : Std.HashMap String Unit) (acc : Array RewriteResult)
|
||||
(xs : List ((Expr ⊕ Name) × Bool × Nat)) : MetaM (Array RewriteResult) := do
|
||||
let mut seen := seen
|
||||
let mut acc := acc
|
||||
|
||||
@@ -420,12 +420,12 @@ def mkSimpExt (name : Name := by exact decl_name%) : IO SimpExtension :=
|
||||
| .toUnfoldThms n thms => d.registerDeclToUnfoldThms n thms
|
||||
}
|
||||
|
||||
abbrev SimpExtensionMap := HashMap Name SimpExtension
|
||||
abbrev SimpExtensionMap := Std.HashMap Name SimpExtension
|
||||
|
||||
builtin_initialize simpExtensionMapRef : IO.Ref SimpExtensionMap ← IO.mkRef {}
|
||||
|
||||
def getSimpExtension? (attrName : Name) : IO (Option SimpExtension) :=
|
||||
return (← simpExtensionMapRef.get).find? attrName
|
||||
return (← simpExtensionMapRef.get)[attrName]?
|
||||
|
||||
/-- Auxiliary method for adding a global declaration to a `SimpTheorems` datastructure. -/
|
||||
def SimpTheorems.addConst (s : SimpTheorems) (declName : Name) (post := true) (inv := false) (prio : Nat := eval_prio default) : MetaM SimpTheorems := do
|
||||
|
||||
@@ -19,8 +19,8 @@ It contains:
|
||||
- The actual procedure associated with a name.
|
||||
-/
|
||||
structure BuiltinSimprocs where
|
||||
keys : HashMap Name (Array SimpTheoremKey) := {}
|
||||
procs : HashMap Name (Sum Simproc DSimproc) := {}
|
||||
keys : Std.HashMap Name (Array SimpTheoremKey) := {}
|
||||
procs : Std.HashMap Name (Sum Simproc DSimproc) := {}
|
||||
deriving Inhabited
|
||||
|
||||
/--
|
||||
@@ -37,7 +37,7 @@ structure SimprocDecl where
|
||||
deriving Inhabited
|
||||
|
||||
structure SimprocDeclExtState where
|
||||
builtin : HashMap Name (Array SimpTheoremKey)
|
||||
builtin : Std.HashMap Name (Array SimpTheoremKey)
|
||||
newEntries : PHashMap Name (Array SimpTheoremKey) := {}
|
||||
deriving Inhabited
|
||||
|
||||
@@ -65,7 +65,7 @@ def getSimprocDeclKeys? (declName : Name) : CoreM (Option (Array SimpTheoremKey)
|
||||
if let some keys := keys? then
|
||||
return some keys
|
||||
else
|
||||
return (simprocDeclExt.getState env).builtin.find? declName
|
||||
return (simprocDeclExt.getState env).builtin[declName]?
|
||||
|
||||
def isBuiltinSimproc (declName : Name) : CoreM Bool := do
|
||||
let s := simprocDeclExt.getState (← getEnv)
|
||||
@@ -160,7 +160,7 @@ def Simprocs.addCore (s : Simprocs) (keys : Array SimpTheoremKey) (declName : Na
|
||||
Implements attributes `builtin_simproc` and `builtin_sevalproc`.
|
||||
-/
|
||||
def addSimprocBuiltinAttrCore (ref : IO.Ref Simprocs) (declName : Name) (post : Bool) (proc : Sum Simproc DSimproc) : IO Unit := do
|
||||
let some keys := (← builtinSimprocDeclsRef.get).keys.find? declName |
|
||||
let some keys := (← builtinSimprocDeclsRef.get).keys[declName]? |
|
||||
throw (IO.userError "invalid [builtin_simproc] attribute, '{declName}' is not a builtin simproc")
|
||||
ref.modify fun s => s.addCore keys declName post proc
|
||||
|
||||
@@ -176,7 +176,7 @@ def Simprocs.add (s : Simprocs) (declName : Name) (post : Bool) : CoreM Simprocs
|
||||
getSimprocFromDecl declName
|
||||
catch e =>
|
||||
if (← isBuiltinSimproc declName) then
|
||||
let some proc := (← builtinSimprocDeclsRef.get).procs.find? declName
|
||||
let some proc := (← builtinSimprocDeclsRef.get).procs[declName]?
|
||||
| throwError "invalid [simproc] attribute, '{declName}' is not a simproc"
|
||||
pure proc
|
||||
else
|
||||
@@ -384,7 +384,7 @@ def mkSimprocAttr (attrName : Name) (attrDescr : String) (ext : SimprocExtension
|
||||
erase := eraseSimprocAttr ext
|
||||
}
|
||||
|
||||
abbrev SimprocExtensionMap := HashMap Name SimprocExtension
|
||||
abbrev SimprocExtensionMap := Std.HashMap Name SimprocExtension
|
||||
|
||||
builtin_initialize simprocExtensionMapRef : IO.Ref SimprocExtensionMap ← IO.mkRef {}
|
||||
|
||||
@@ -438,7 +438,7 @@ def getSEvalSimprocs : CoreM Simprocs :=
|
||||
return simprocSEvalExtension.getState (← getEnv)
|
||||
|
||||
def getSimprocExtensionCore? (attrName : Name) : IO (Option SimprocExtension) :=
|
||||
return (← simprocExtensionMapRef.get).find? attrName
|
||||
return (← simprocExtensionMapRef.get)[attrName]?
|
||||
|
||||
def simpAttrNameToSimprocAttrName (attrName : Name) : Name :=
|
||||
if attrName == `simp then `simprocAttr
|
||||
|
||||
@@ -512,7 +512,7 @@ def mkCongrSimp? (f : Expr) : SimpM (Option CongrTheorem) := do
|
||||
if kinds.all fun k => match k with | CongrArgKind.fixed => true | CongrArgKind.eq => true | _ => false then
|
||||
/- See remark above. -/
|
||||
return none
|
||||
match (← get).congrCache.find? f with
|
||||
match (← get).congrCache[f]? with
|
||||
| some thm? => return thm?
|
||||
| none =>
|
||||
let thm? ← mkCongrSimpCore? f info kinds
|
||||
|
||||
@@ -903,7 +903,7 @@ structure State where
|
||||
mctx : MetavarContext
|
||||
nextMacroScope : MacroScope
|
||||
ngen : NameGenerator
|
||||
cache : HashMap ExprStructEq Expr := {}
|
||||
cache : Std.HashMap ExprStructEq Expr := {}
|
||||
|
||||
structure Context where
|
||||
mainModule : Name
|
||||
@@ -1319,7 +1319,7 @@ structure State where
|
||||
mctx : MetavarContext
|
||||
paramNames : Array Name := #[]
|
||||
nextParamIdx : Nat
|
||||
cache : HashMap ExprStructEq Expr := {}
|
||||
cache : Std.HashMap ExprStructEq Expr := {}
|
||||
|
||||
abbrev M := ReaderT Context <| StateM State
|
||||
|
||||
@@ -1328,7 +1328,7 @@ instance : MonadMCtx M where
|
||||
modifyMCtx f := modify fun s => { s with mctx := f s.mctx }
|
||||
|
||||
instance : MonadCache ExprStructEq Expr M where
|
||||
findCached? e := return (← get).cache.find? e
|
||||
findCached? e := return (← get).cache[e]?
|
||||
cache e v := modify fun s => { s with cache := s.cache.insert e v }
|
||||
|
||||
partial def mkParamName : M Name := do
|
||||
|
||||
@@ -131,7 +131,7 @@ structure ParserCacheEntry where
|
||||
|
||||
structure ParserCache where
|
||||
tokenCache : TokenCacheEntry
|
||||
parserCache : HashMap ParserCacheKey ParserCacheEntry
|
||||
parserCache : Std.HashMap ParserCacheKey ParserCacheEntry
|
||||
|
||||
def initCacheForInput (input : String) : ParserCache where
|
||||
tokenCache := { startPos := input.endPos + ' ' /- make sure it is not a valid position -/ }
|
||||
@@ -418,7 +418,7 @@ place if there was an error.
|
||||
-/
|
||||
def withCacheFn (parserName : Name) (p : ParserFn) : ParserFn := fun c s => Id.run do
|
||||
let key := ⟨c.toCacheableParserContext, parserName, s.pos⟩
|
||||
if let some r := s.cache.parserCache.find? key then
|
||||
if let some r := s.cache.parserCache[key]? then
|
||||
-- TODO: turn this into a proper trace once we have these in the parser
|
||||
--dbg_trace "parser cache hit: {parserName}:{s.pos} -> {r.stx}"
|
||||
return ⟨s.stxStack.push r.stx, r.lhsPrec, r.newPos, s.cache, r.errorMsg, s.recoveredErrors⟩
|
||||
|
||||
@@ -198,10 +198,10 @@ def isHBinOp (e : Expr) : Bool := Id.run do
|
||||
def replaceLPsWithVars (e : Expr) : MetaM Expr := do
|
||||
if !e.hasLevelParam then return e
|
||||
let lps := collectLevelParams {} e |>.params
|
||||
let mut replaceMap : HashMap Name Level := {}
|
||||
let mut replaceMap : Std.HashMap Name Level := {}
|
||||
for lp in lps do replaceMap := replaceMap.insert lp (← mkFreshLevelMVar)
|
||||
return e.replaceLevel fun
|
||||
| Level.param n .. => replaceMap.find! n
|
||||
| Level.param n .. => replaceMap[n]!
|
||||
| l => if !l.hasParam then some l else none
|
||||
|
||||
def isDefEqAssigning (t s : Expr) : MetaM Bool := do
|
||||
|
||||
@@ -29,7 +29,7 @@ namespace Lean.Environment
|
||||
namespace Replay
|
||||
|
||||
structure Context where
|
||||
newConstants : HashMap Name ConstantInfo
|
||||
newConstants : Std.HashMap Name ConstantInfo
|
||||
|
||||
structure State where
|
||||
env : Environment
|
||||
@@ -73,7 +73,7 @@ and add it to the environment.
|
||||
-/
|
||||
partial def replayConstant (name : Name) : M Unit := do
|
||||
if ← isTodo name then
|
||||
let some ci := (← read).newConstants.find? name | unreachable!
|
||||
let some ci := (← read).newConstants[name]? | unreachable!
|
||||
replayConstants ci.getUsedConstantsAsSet
|
||||
-- Check that this name is still pending: a mutual block may have taken care of it.
|
||||
if (← get).pending.contains name then
|
||||
@@ -89,13 +89,13 @@ partial def replayConstant (name : Name) : M Unit := do
|
||||
| .inductInfo info =>
|
||||
let lparams := info.levelParams
|
||||
let nparams := info.numParams
|
||||
let all ← info.all.mapM fun n => do pure <| ((← read).newConstants.find! n)
|
||||
let all ← info.all.mapM fun n => do pure <| ((← read).newConstants[n]!)
|
||||
for o in all do
|
||||
modify fun s =>
|
||||
{ s with remaining := s.remaining.erase o.name, pending := s.pending.erase o.name }
|
||||
let ctorInfo ← all.mapM fun ci => do
|
||||
pure (ci, ← ci.inductiveVal!.ctors.mapM fun n => do
|
||||
pure ((← read).newConstants.find! n))
|
||||
pure ((← read).newConstants[n]!))
|
||||
-- Make sure we are really finished with the constructors.
|
||||
for (_, ctors) in ctorInfo do
|
||||
for ctor in ctors do
|
||||
@@ -129,7 +129,7 @@ when we replayed the inductives.
|
||||
-/
|
||||
def checkPostponedConstructors : M Unit := do
|
||||
for ctor in (← get).postponedConstructors do
|
||||
match (← get).env.constants.find? ctor, (← read).newConstants.find? ctor with
|
||||
match (← get).env.constants.find? ctor, (← read).newConstants[ctor]? with
|
||||
| some (.ctorInfo info), some (.ctorInfo info') =>
|
||||
if ! (info == info') then throw <| IO.userError s!"Invalid constructor {ctor}"
|
||||
| _, _ => throw <| IO.userError s!"No such constructor {ctor}"
|
||||
@@ -140,7 +140,7 @@ when we replayed the inductives.
|
||||
-/
|
||||
def checkPostponedRecursors : M Unit := do
|
||||
for ctor in (← get).postponedRecursors do
|
||||
match (← get).env.constants.find? ctor, (← read).newConstants.find? ctor with
|
||||
match (← get).env.constants.find? ctor, (← read).newConstants[ctor]? with
|
||||
| some (.recInfo info), some (.recInfo info') =>
|
||||
if ! (info == info') then throw <| IO.userError s!"Invalid recursor {ctor}"
|
||||
| _, _ => throw <| IO.userError s!"No such recursor {ctor}"
|
||||
@@ -155,7 +155,7 @@ open Replay
|
||||
Throws a `IO.userError` if the kernel rejects a constant,
|
||||
or if there are malformed recursors or constructors for inductive types.
|
||||
-/
|
||||
def replay (newConstants : HashMap Name ConstantInfo) (env : Environment) : IO Environment := do
|
||||
def replay (newConstants : Std.HashMap Name ConstantInfo) (env : Environment) : IO Environment := do
|
||||
let mut remaining : NameSet := ∅
|
||||
for (n, ci) in newConstants.toList do
|
||||
-- We skip unsafe constants, and also partial constants.
|
||||
|
||||
@@ -81,7 +81,7 @@ open Elab
|
||||
open Meta
|
||||
open FuzzyMatching
|
||||
|
||||
abbrev EligibleHeaderDecls := HashMap Name ConstantInfo
|
||||
abbrev EligibleHeaderDecls := Std.HashMap Name ConstantInfo
|
||||
|
||||
/-- Cached header declarations for which `allowCompletion headerEnv decl` is true. -/
|
||||
builtin_initialize eligibleHeaderDeclsRef : IO.Ref (Option EligibleHeaderDecls) ←
|
||||
|
||||
@@ -316,7 +316,7 @@ partial def handleDocumentHighlight (p : DocumentHighlightParams)
|
||||
let refs : Lsp.ModuleRefs ← findModuleRefs text trees |>.toLspModuleRefs
|
||||
let mut ranges := #[]
|
||||
for ident in refs.findAt p.position do
|
||||
if let some info := refs.find? ident then
|
||||
if let some info := refs.get? ident then
|
||||
if let some ⟨definitionRange, _⟩ := info.definition? then
|
||||
ranges := ranges.push definitionRange
|
||||
ranges := ranges.append <| info.usages.map (·.range)
|
||||
|
||||
@@ -93,20 +93,20 @@ def toLspRefInfo (i : RefInfo) : BaseIO Lsp.RefInfo := do
|
||||
end RefInfo
|
||||
|
||||
/-- All references from within a module for all identifiers used in a single module. -/
|
||||
def ModuleRefs := HashMap RefIdent RefInfo
|
||||
def ModuleRefs := Std.HashMap RefIdent RefInfo
|
||||
|
||||
namespace ModuleRefs
|
||||
|
||||
/-- Adds `ref` to the `RefInfo` corresponding to `ref.ident` in `self`. See `RefInfo.addRef`. -/
|
||||
def addRef (self : ModuleRefs) (ref : Reference) : ModuleRefs :=
|
||||
let refInfo := self.findD ref.ident RefInfo.empty
|
||||
let refInfo := self.getD ref.ident RefInfo.empty
|
||||
self.insert ref.ident (refInfo.addRef ref)
|
||||
|
||||
/-- Converts `refs` to a JSON-serializable `Lsp.ModuleRefs`. -/
|
||||
def toLspModuleRefs (refs : ModuleRefs) : BaseIO Lsp.ModuleRefs := do
|
||||
let refs ← refs.toList.mapM fun (k, v) => do
|
||||
return (k, ← v.toLspRefInfo)
|
||||
return HashMap.ofList refs
|
||||
return Std.HashMap.ofList refs
|
||||
|
||||
end ModuleRefs
|
||||
|
||||
@@ -261,7 +261,7 @@ all identifiers that are being collapsed into one.
|
||||
-/
|
||||
partial def combineIdents (trees : Array InfoTree) (refs : Array Reference) : Array Reference := Id.run do
|
||||
-- Deduplicate definitions based on their exact range
|
||||
let mut posMap : HashMap Lsp.Range RefIdent := HashMap.empty
|
||||
let mut posMap : Std.HashMap Lsp.Range RefIdent := Std.HashMap.empty
|
||||
for ref in refs do
|
||||
if let { ident, range, isBinder := true, .. } := ref then
|
||||
posMap := posMap.insert range ident
|
||||
@@ -277,17 +277,17 @@ partial def combineIdents (trees : Array InfoTree) (refs : Array Reference) : Ar
|
||||
refs' := refs'.push ref
|
||||
refs'
|
||||
where
|
||||
useConstRepresentatives (idMap : HashMap RefIdent RefIdent)
|
||||
: HashMap RefIdent RefIdent := Id.run do
|
||||
useConstRepresentatives (idMap : Std.HashMap RefIdent RefIdent)
|
||||
: Std.HashMap RefIdent RefIdent := Id.run do
|
||||
let insertIntoClass classesById id :=
|
||||
let representative := findCanonicalRepresentative idMap id
|
||||
let «class» := classesById.findD representative ∅
|
||||
let «class» := classesById.getD representative ∅
|
||||
let classesById := classesById.erase representative -- make `«class»` referentially unique
|
||||
let «class» := «class».insert id
|
||||
classesById.insert representative «class»
|
||||
|
||||
-- collect equivalence classes
|
||||
let mut classesById : HashMap RefIdent (HashSet RefIdent) := ∅
|
||||
let mut classesById : Std.HashMap RefIdent (Std.HashSet RefIdent) := ∅
|
||||
for ⟨id, baseId⟩ in idMap.toArray do
|
||||
classesById := insertIntoClass classesById id
|
||||
classesById := insertIntoClass classesById baseId
|
||||
@@ -310,17 +310,17 @@ where
|
||||
r := r.insert id bestRepresentative
|
||||
return r
|
||||
|
||||
findCanonicalRepresentative (idMap : HashMap RefIdent RefIdent) (id : RefIdent) : RefIdent := Id.run do
|
||||
findCanonicalRepresentative (idMap : Std.HashMap RefIdent RefIdent) (id : RefIdent) : RefIdent := Id.run do
|
||||
let mut canonicalRepresentative := id
|
||||
while idMap.contains canonicalRepresentative do
|
||||
canonicalRepresentative := idMap.find! canonicalRepresentative
|
||||
canonicalRepresentative := idMap[canonicalRepresentative]!
|
||||
return canonicalRepresentative
|
||||
|
||||
buildIdMap posMap := Id.run <| StateT.run' (s := HashMap.empty) do
|
||||
buildIdMap posMap := Id.run <| StateT.run' (s := Std.HashMap.empty) do
|
||||
-- map fvar defs to overlapping fvar defs/uses
|
||||
for ref in refs do
|
||||
let baseId := ref.ident
|
||||
if let some id := posMap.find? ref.range then
|
||||
if let some id := posMap[ref.range]? then
|
||||
insertIdMap id baseId
|
||||
|
||||
-- apply `FVarAliasInfo`
|
||||
@@ -346,11 +346,11 @@ are added to the `aliases` of the representative of the group.
|
||||
Yields to separate groups for declaration and usages if `allowSimultaneousBinderUse` is set.
|
||||
-/
|
||||
def dedupReferences (refs : Array Reference) (allowSimultaneousBinderUse := false) : Array Reference := Id.run do
|
||||
let mut refsByIdAndRange : HashMap (RefIdent × Option Bool × Lsp.Range) Reference := HashMap.empty
|
||||
let mut refsByIdAndRange : Std.HashMap (RefIdent × Option Bool × Lsp.Range) Reference := Std.HashMap.empty
|
||||
for ref in refs do
|
||||
let isBinder := if allowSimultaneousBinderUse then some ref.isBinder else none
|
||||
let key := (ref.ident, isBinder, ref.range)
|
||||
refsByIdAndRange := match refsByIdAndRange[key] with
|
||||
refsByIdAndRange := match refsByIdAndRange[key]? with
|
||||
| some ref' => refsByIdAndRange.insert key { ref' with aliases := ref'.aliases ++ ref.aliases }
|
||||
| none => refsByIdAndRange.insert key ref
|
||||
|
||||
@@ -371,21 +371,21 @@ def findModuleRefs (text : FileMap) (trees : Array InfoTree) (localVars : Bool :
|
||||
refs := refs.filter fun
|
||||
| { ident := RefIdent.fvar .., .. } => false
|
||||
| _ => true
|
||||
refs.foldl (init := HashMap.empty) fun m ref => m.addRef ref
|
||||
refs.foldl (init := Std.HashMap.empty) fun m ref => m.addRef ref
|
||||
|
||||
/-! # Collecting and maintaining reference info from different sources -/
|
||||
|
||||
/-- References from ilean files and current ilean information from file workers. -/
|
||||
structure References where
|
||||
/-- References loaded from ilean files -/
|
||||
ileans : HashMap Name (System.FilePath × Lsp.ModuleRefs)
|
||||
ileans : Std.HashMap Name (System.FilePath × Lsp.ModuleRefs)
|
||||
/-- References from workers, overriding the corresponding ilean files -/
|
||||
workers : HashMap Name (Nat × Lsp.ModuleRefs)
|
||||
workers : Std.HashMap Name (Nat × Lsp.ModuleRefs)
|
||||
|
||||
namespace References
|
||||
|
||||
/-- No ilean files, no information from workers. -/
|
||||
def empty : References := { ileans := HashMap.empty, workers := HashMap.empty }
|
||||
def empty : References := { ileans := Std.HashMap.empty, workers := Std.HashMap.empty }
|
||||
|
||||
/-- Adds the contents of an ilean file `ilean` at `path` to `self`. -/
|
||||
def addIlean (self : References) (path : System.FilePath) (ilean : Ilean) : References :=
|
||||
@@ -404,13 +404,13 @@ Replaces the current references with `refs` if `version` is newer than the curre
|
||||
in `refs` and otherwise merges the reference data if `version` is equal to the current version.
|
||||
-/
|
||||
def updateWorkerRefs (self : References) (name : Name) (version : Nat) (refs : Lsp.ModuleRefs) : References := Id.run do
|
||||
if let some (currVersion, _) := self.workers.find? name then
|
||||
if let some (currVersion, _) := self.workers[name]? then
|
||||
if version > currVersion then
|
||||
return { self with workers := self.workers.insert name (version, refs) }
|
||||
if version == currVersion then
|
||||
let current := self.workers.findD name (version, HashMap.empty)
|
||||
let current := self.workers.getD name (version, Std.HashMap.empty)
|
||||
let merged := refs.fold (init := current.snd) fun m ident info =>
|
||||
m.findD ident Lsp.RefInfo.empty |>.merge info |> m.insert ident
|
||||
m.getD ident Lsp.RefInfo.empty |>.merge info |> m.insert ident
|
||||
return { self with workers := self.workers.insert name (version, merged) }
|
||||
return self
|
||||
|
||||
@@ -419,7 +419,7 @@ Replaces the worker references in `self` with the `refs` of the worker managing
|
||||
if `version` is newer than the current version managed in `refs`.
|
||||
-/
|
||||
def finalizeWorkerRefs (self : References) (name : Name) (version : Nat) (refs : Lsp.ModuleRefs) : References := Id.run do
|
||||
if let some (currVersion, _) := self.workers.find? name then
|
||||
if let some (currVersion, _) := self.workers[name]? then
|
||||
if version < currVersion then
|
||||
return self
|
||||
return { self with workers := self.workers.insert name (version, refs) }
|
||||
@@ -429,8 +429,8 @@ def removeWorkerRefs (self : References) (name : Name) : References :=
|
||||
{ self with workers := self.workers.erase name }
|
||||
|
||||
/-- Yields a map from all modules to all of their references. -/
|
||||
def allRefs (self : References) : HashMap Name Lsp.ModuleRefs :=
|
||||
let ileanRefs := self.ileans.toArray.foldl (init := HashMap.empty) fun m (name, _, refs) => m.insert name refs
|
||||
def allRefs (self : References) : Std.HashMap Name Lsp.ModuleRefs :=
|
||||
let ileanRefs := self.ileans.toArray.foldl (init := Std.HashMap.empty) fun m (name, _, refs) => m.insert name refs
|
||||
self.workers.toArray.foldl (init := ileanRefs) fun m (name, _, refs) => m.insert name refs
|
||||
|
||||
/--
|
||||
@@ -445,12 +445,12 @@ def allRefsFor
|
||||
let refsToCheck := match ident with
|
||||
| RefIdent.const .. => self.allRefs.toArray
|
||||
| RefIdent.fvar identModule .. =>
|
||||
match self.allRefs.find? identModule with
|
||||
match self.allRefs[identModule]? with
|
||||
| none => #[]
|
||||
| some refs => #[(identModule, refs)]
|
||||
let mut result := #[]
|
||||
for (module, refs) in refsToCheck do
|
||||
let some info := refs.find? ident
|
||||
let some info := refs.get? ident
|
||||
| continue
|
||||
let some path ← srcSearchPath.findModuleWithExt "lean" module
|
||||
| continue
|
||||
@@ -462,13 +462,13 @@ def allRefsFor
|
||||
|
||||
/-- Yields all references in `module` at `pos`. -/
|
||||
def findAt (self : References) (module : Name) (pos : Lsp.Position) (includeStop := false) : Array RefIdent := Id.run do
|
||||
if let some refs := self.allRefs.find? module then
|
||||
if let some refs := self.allRefs[module]? then
|
||||
return refs.findAt pos includeStop
|
||||
#[]
|
||||
|
||||
/-- Yields the first reference in `module` at `pos`. -/
|
||||
def findRange? (self : References) (module : Name) (pos : Lsp.Position) (includeStop := false) : Option Range := do
|
||||
let refs ← self.allRefs.find? module
|
||||
let refs ← self.allRefs[module]?
|
||||
refs.findRange? pos includeStop
|
||||
|
||||
/-- Location and parent declaration of a reference. -/
|
||||
|
||||
@@ -90,6 +90,10 @@ section Utils
|
||||
| crashed (e : IO.Error)
|
||||
| ioError (e : IO.Error)
|
||||
|
||||
inductive CrashOrigin
|
||||
| fileWorkerToClientForwarding
|
||||
| clientToFileWorkerForwarding
|
||||
|
||||
inductive WorkerState where
|
||||
/-- The watchdog can detect a crashed file worker in two places: When trying to send a message
|
||||
to the file worker and when reading a request reply.
|
||||
@@ -98,7 +102,7 @@ section Utils
|
||||
that are in-flight are errored. Upon receiving the next packet for that file worker, the file
|
||||
worker is restarted and the packet is forwarded to it. If the crash was detected while writing
|
||||
a packet, we queue that packet until the next packet for the file worker arrives. -/
|
||||
| crashed (queuedMsgs : Array JsonRpc.Message)
|
||||
| crashed (queuedMsgs : Array JsonRpc.Message) (origin : CrashOrigin)
|
||||
| running
|
||||
|
||||
abbrev PendingRequestMap := RBMap RequestID JsonRpc.Message compare
|
||||
@@ -136,6 +140,11 @@ section FileWorker
|
||||
for ⟨id, _⟩ in pendingRequests do
|
||||
hError.writeLspResponseError { id := id, code := code, message := msg }
|
||||
|
||||
def queuedMsgs (fw : FileWorker) : Array JsonRpc.Message :=
|
||||
match fw.state with
|
||||
| .running => #[]
|
||||
| .crashed queuedMsgs _ => queuedMsgs
|
||||
|
||||
end FileWorker
|
||||
end FileWorker
|
||||
|
||||
@@ -404,10 +413,23 @@ section ServerM
|
||||
return
|
||||
eraseFileWorker uri
|
||||
|
||||
def handleCrash (uri : DocumentUri) (queuedMsgs : Array JsonRpc.Message) : ServerM Unit := do
|
||||
def handleCrash (uri : DocumentUri) (queuedMsgs : Array JsonRpc.Message) (origin: CrashOrigin) : ServerM Unit := do
|
||||
let some fw ← findFileWorker? uri
|
||||
| return
|
||||
updateFileWorkers { fw with state := WorkerState.crashed queuedMsgs }
|
||||
updateFileWorkers { fw with state := WorkerState.crashed queuedMsgs origin }
|
||||
|
||||
def tryDischargeQueuedMessages (uri : DocumentUri) (queuedMsgs : Array JsonRpc.Message) : ServerM Unit := do
|
||||
let some fw ← findFileWorker? uri
|
||||
| throwServerError "Cannot find file worker for '{uri}'."
|
||||
let mut crashedMsgs := #[]
|
||||
-- Try to discharge all queued msgs, tracking the ones that we can't discharge
|
||||
for msg in queuedMsgs do
|
||||
try
|
||||
fw.stdin.writeLspMessage msg
|
||||
catch _ =>
|
||||
crashedMsgs := crashedMsgs.push msg
|
||||
if ¬ crashedMsgs.isEmpty then
|
||||
handleCrash uri crashedMsgs .clientToFileWorkerForwarding
|
||||
|
||||
/-- Tries to write a message, sets the state of the FileWorker to `crashed` if it does not succeed
|
||||
and restarts the file worker if the `crashed` flag was already set. Just logs an error if
|
||||
@@ -423,7 +445,7 @@ section ServerM
|
||||
let some fw ← findFileWorker? uri
|
||||
| return
|
||||
match fw.state with
|
||||
| WorkerState.crashed queuedMsgs =>
|
||||
| WorkerState.crashed queuedMsgs _ =>
|
||||
let mut queuedMsgs := queuedMsgs
|
||||
if queueFailedMessage then
|
||||
queuedMsgs := queuedMsgs.push msg
|
||||
@@ -432,17 +454,7 @@ section ServerM
|
||||
-- restart the crashed FileWorker
|
||||
eraseFileWorker uri
|
||||
startFileWorker fw.doc
|
||||
let some newFw ← findFileWorker? uri
|
||||
| throwServerError "Cannot find file worker for '{uri}'."
|
||||
let mut crashedMsgs := #[]
|
||||
-- try to discharge all queued msgs, tracking the ones that we can't discharge
|
||||
for msg in queuedMsgs do
|
||||
try
|
||||
newFw.stdin.writeLspMessage msg
|
||||
catch _ =>
|
||||
crashedMsgs := crashedMsgs.push msg
|
||||
if ¬ crashedMsgs.isEmpty then
|
||||
handleCrash uri crashedMsgs
|
||||
tryDischargeQueuedMessages uri queuedMsgs
|
||||
| WorkerState.running =>
|
||||
let initialQueuedMsgs :=
|
||||
if queueFailedMessage then
|
||||
@@ -452,7 +464,7 @@ section ServerM
|
||||
try
|
||||
fw.stdin.writeLspMessage msg
|
||||
catch _ =>
|
||||
handleCrash uri initialQueuedMsgs
|
||||
handleCrash uri initialQueuedMsgs .clientToFileWorkerForwarding
|
||||
|
||||
/--
|
||||
Sends a notification to the file worker identified by `uri` that its dependency `staleDependency`
|
||||
@@ -638,7 +650,7 @@ def handleCallHierarchyOutgoingCalls (p : CallHierarchyOutgoingCallsParams)
|
||||
|
||||
let references ← (← read).references.get
|
||||
|
||||
let some refs := references.allRefs.find? module
|
||||
let some refs := references.allRefs[module]?
|
||||
| return #[]
|
||||
|
||||
let items ← refs.toArray.filterMapM fun ⟨ident, info⟩ => do
|
||||
@@ -702,9 +714,9 @@ def handlePrepareRename (p : PrepareRenameParams) : ServerM (Option Range) := do
|
||||
def handleRename (p : RenameParams) : ServerM Lsp.WorkspaceEdit := do
|
||||
if (String.toName p.newName).isAnonymous then
|
||||
throwServerError s!"Can't rename: `{p.newName}` is not an identifier"
|
||||
let mut refs : HashMap DocumentUri (RBMap Lsp.Position Lsp.Position compare) := ∅
|
||||
let mut refs : Std.HashMap DocumentUri (RBMap Lsp.Position Lsp.Position compare) := ∅
|
||||
for { uri, range } in (← handleReference { p with context.includeDeclaration := true }) do
|
||||
refs := refs.insert uri <| (refs.findD uri ∅).insert range.start range.end
|
||||
refs := refs.insert uri <| (refs.getD uri ∅).insert range.start range.end
|
||||
-- We have to filter the list of changes to put the ranges in order and
|
||||
-- remove any duplicates or overlapping ranges, or else the rename will not apply
|
||||
let changes := refs.fold (init := ∅) fun changes uri map => Id.run do
|
||||
@@ -955,7 +967,16 @@ section MainLoop
|
||||
let workers ← st.fileWorkersRef.get
|
||||
let mut workerTasks := #[]
|
||||
for (_, fw) in workers do
|
||||
if let WorkerState.running := fw.state then
|
||||
-- When the forwarding task crashes, its return value will be stuck at
|
||||
-- `WorkerEvent.crashed _`.
|
||||
-- We want to handle this event only once, not over and over again,
|
||||
-- so once the state becomes `WorkerState.crashed _ .fileWorkerToClientForwarding`
|
||||
-- as a result of `WorkerEvent.crashed _`, we stop handling this event until
|
||||
-- eventually the file worker is restarted by a notification from the client.
|
||||
-- We do not want to filter the forwarding task in case of
|
||||
-- `WorkerState.crashed _ .clientToFileWorkerForwarding`, since the forwarding task
|
||||
-- exit code may still contain valuable information in this case (e.g. that the imports changed).
|
||||
if !(fw.state matches WorkerState.crashed _ .fileWorkerToClientForwarding) then
|
||||
workerTasks := workerTasks.push <| fw.commTask.map (ServerEvent.workerEvent fw)
|
||||
|
||||
let ev ← IO.waitAny (clientTask :: workerTasks.toList)
|
||||
@@ -984,13 +1005,16 @@ section MainLoop
|
||||
| WorkerEvent.ioError e =>
|
||||
throwServerError s!"IO error while processing events for {fw.doc.uri}: {e}"
|
||||
| WorkerEvent.crashed _ =>
|
||||
handleCrash fw.doc.uri #[]
|
||||
handleCrash fw.doc.uri fw.queuedMsgs .fileWorkerToClientForwarding
|
||||
mainLoop clientTask
|
||||
| WorkerEvent.terminated =>
|
||||
throwServerError <| "Internal server error: got termination event for worker that "
|
||||
++ "should have been removed"
|
||||
| .importsChanged =>
|
||||
let uri := fw.doc.uri
|
||||
let queuedMsgs := fw.queuedMsgs
|
||||
startFileWorker fw.doc
|
||||
tryDischargeQueuedMessages uri queuedMsgs
|
||||
mainLoop clientTask
|
||||
end MainLoop
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ Authors: David Thrane Christiansen
|
||||
-/
|
||||
prelude
|
||||
import Init.Data
|
||||
import Lean.Data.HashMap
|
||||
import Std.Data.HashMap.Basic
|
||||
import Init.Omega
|
||||
|
||||
namespace Lean.Diff
|
||||
@@ -57,7 +57,7 @@ structure Histogram.Entry (α : Type u) (lsize rsize : Nat) where
|
||||
|
||||
/-- A histogram for arrays maps each element to a count and, if applicable, an index.-/
|
||||
def Histogram (α : Type u) (lsize rsize : Nat) [BEq α] [Hashable α] :=
|
||||
Lean.HashMap α (Histogram.Entry α lsize rsize)
|
||||
Std.HashMap α (Histogram.Entry α lsize rsize)
|
||||
|
||||
|
||||
section
|
||||
@@ -67,7 +67,7 @@ variable [BEq α] [Hashable α]
|
||||
/-- Add an element from the left array to a histogram -/
|
||||
def Histogram.addLeft (histogram : Histogram α lsize rsize) (index : Fin lsize) (val : α)
|
||||
: Histogram α lsize rsize :=
|
||||
match histogram.find? val with
|
||||
match histogram.get? val with
|
||||
| none => histogram.insert val {
|
||||
leftCount := 1, leftIndex := some index,
|
||||
leftWF := by simp,
|
||||
@@ -81,7 +81,7 @@ def Histogram.addLeft (histogram : Histogram α lsize rsize) (index : Fin lsize)
|
||||
/-- Add an element from the right array to a histogram -/
|
||||
def Histogram.addRight (histogram : Histogram α lsize rsize) (index : Fin rsize) (val : α)
|
||||
: Histogram α lsize rsize :=
|
||||
match histogram.find? val with
|
||||
match histogram.get? val with
|
||||
| none => histogram.insert val {
|
||||
leftCount := 0, leftIndex := none,
|
||||
leftWF := by simp,
|
||||
|
||||
@@ -28,7 +28,7 @@ structure State where
|
||||
Set of visited subterms that satisfy the predicate `p`.
|
||||
We have to use this set to make sure `f` is applied at most once of each subterm that satisfies `p`.
|
||||
-/
|
||||
checked : HashSet Expr
|
||||
checked : Std.HashSet Expr
|
||||
|
||||
unsafe def initCache : State := {
|
||||
visited := mkArray cacheSize.toNat (cast lcProof ())
|
||||
|
||||
@@ -5,14 +5,15 @@ Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.Expr
|
||||
import Std.Data.HashMap.Raw
|
||||
|
||||
namespace Lean
|
||||
|
||||
structure HasConstCache (declNames : Array Name) where
|
||||
cache : HashMapImp Expr Bool := mkHashMapImp
|
||||
cache : Std.HashMap.Raw Expr Bool := Std.HashMap.Raw.empty
|
||||
|
||||
unsafe def HasConstCache.containsUnsafe (e : Expr) : StateM (HasConstCache declNames) Bool := do
|
||||
if let some r := (← get).cache.find? (beq := ⟨ptrEq⟩) e then
|
||||
if let some r := (← get).cache.get? (beq := ⟨ptrEq⟩) e then
|
||||
return r
|
||||
else
|
||||
match e with
|
||||
@@ -26,7 +27,7 @@ unsafe def HasConstCache.containsUnsafe (e : Expr) : StateM (HasConstCache declN
|
||||
| _ => return false
|
||||
where
|
||||
cache (e : Expr) (r : Bool) : StateM (HasConstCache declNames) Bool := do
|
||||
modify fun ⟨cache⟩ => ⟨cache.insert (beq := ⟨ptrEq⟩) e r |>.1⟩
|
||||
modify fun ⟨cache⟩ => ⟨cache.insert (beq := ⟨ptrEq⟩) e r⟩
|
||||
return r
|
||||
|
||||
/--
|
||||
|
||||
@@ -5,7 +5,7 @@ Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Init.Control.StateRef
|
||||
import Lean.Data.HashMap
|
||||
import Std.Data.HashMap.Basic
|
||||
|
||||
namespace Lean
|
||||
/-- Interface for caching results. -/
|
||||
@@ -36,15 +36,15 @@ instance {α β ε : Type} {m : Type → Type} [MonadCache α β m] [Monad m] :
|
||||
/-- Adapter for implementing `MonadCache` interface using `HashMap`s.
|
||||
We just have to specify how to extract/modify the `HashMap`. -/
|
||||
class MonadHashMapCacheAdapter (α β : Type) (m : Type → Type) [BEq α] [Hashable α] where
|
||||
getCache : m (HashMap α β)
|
||||
modifyCache : (HashMap α β → HashMap α β) → m Unit
|
||||
getCache : m (Std.HashMap α β)
|
||||
modifyCache : (Std.HashMap α β → Std.HashMap α β) → m Unit
|
||||
|
||||
namespace MonadHashMapCacheAdapter
|
||||
|
||||
@[always_inline, inline]
|
||||
def findCached? {α β : Type} {m : Type → Type} [BEq α] [Hashable α] [Monad m] [MonadHashMapCacheAdapter α β m] (a : α) : m (Option β) := do
|
||||
let c ← getCache
|
||||
pure (c.find? a)
|
||||
pure (c.get? a)
|
||||
|
||||
@[always_inline, inline]
|
||||
def cache {α β : Type} {m : Type → Type} [BEq α] [Hashable α] [MonadHashMapCacheAdapter α β m] (a : α) (b : β) : m Unit :=
|
||||
@@ -56,7 +56,7 @@ instance {α β : Type} {m : Type → Type} [BEq α] [Hashable α] [Monad m] [Mo
|
||||
|
||||
end MonadHashMapCacheAdapter
|
||||
|
||||
def MonadCacheT {ω} (α β : Type) (m : Type → Type) [STWorld ω m] [BEq α] [Hashable α] := StateRefT (HashMap α β) m
|
||||
def MonadCacheT {ω} (α β : Type) (m : Type → Type) [STWorld ω m] [BEq α] [Hashable α] := StateRefT (Std.HashMap α β) m
|
||||
|
||||
namespace MonadCacheT
|
||||
|
||||
@@ -67,7 +67,7 @@ instance : MonadHashMapCacheAdapter α β (MonadCacheT α β m) where
|
||||
modifyCache f := (modify f : StateRefT' ..)
|
||||
|
||||
@[inline] def run {σ} (x : MonadCacheT α β m σ) : m σ :=
|
||||
x.run' mkHashMap
|
||||
x.run' Std.HashMap.empty
|
||||
|
||||
instance : Monad (MonadCacheT α β m) := inferInstanceAs (Monad (StateRefT' _ _ _))
|
||||
instance : MonadLift m (MonadCacheT α β m) := inferInstanceAs (MonadLift m (StateRefT' _ _ _))
|
||||
@@ -80,7 +80,7 @@ instance [Alternative m] : Alternative (MonadCacheT α β m) := inferInstanceAs
|
||||
end MonadCacheT
|
||||
|
||||
/- Similar to `MonadCacheT`, but using `StateT` instead of `StateRefT` -/
|
||||
def MonadStateCacheT (α β : Type) (m : Type → Type) [BEq α] [Hashable α] := StateT (HashMap α β) m
|
||||
def MonadStateCacheT (α β : Type) (m : Type → Type) [BEq α] [Hashable α] := StateT (Std.HashMap α β) m
|
||||
|
||||
namespace MonadStateCacheT
|
||||
|
||||
@@ -91,7 +91,7 @@ instance : MonadHashMapCacheAdapter α β (MonadStateCacheT α β m) where
|
||||
modifyCache f := (modify f : StateT ..)
|
||||
|
||||
@[always_inline, inline] def run {σ} (x : MonadStateCacheT α β m σ) : m σ :=
|
||||
x.run' mkHashMap
|
||||
x.run' Std.HashMap.empty
|
||||
|
||||
instance : Monad (MonadStateCacheT α β m) := inferInstanceAs (Monad (StateT _ _))
|
||||
instance : MonadLift m (MonadStateCacheT α β m) := inferInstanceAs (MonadLift m (StateT _ _))
|
||||
|
||||
@@ -96,9 +96,9 @@ deriving FromJson, ToJson
|
||||
|
||||
/-- Thread with maps necessary for computing max sharing indices -/
|
||||
structure ThreadWithMaps extends Thread where
|
||||
stringMap : HashMap String Nat := {}
|
||||
funcMap : HashMap Nat Nat := {}
|
||||
stackMap : HashMap (Nat × Option Nat) Nat := {}
|
||||
stringMap : Std.HashMap String Nat := {}
|
||||
funcMap : Std.HashMap Nat Nat := {}
|
||||
stackMap : Std.HashMap (Nat × Option Nat) Nat := {}
|
||||
/-- Last timestamp encountered: stop time of preceding sibling, or else start time of parent. -/
|
||||
lastTime : Float := 0
|
||||
|
||||
@@ -123,7 +123,7 @@ where
|
||||
if pp then
|
||||
funcName := s!"{funcName}: {← msg.format}"
|
||||
let strIdx ← modifyGet fun thread =>
|
||||
if let some idx := thread.stringMap.find? funcName then
|
||||
if let some idx := thread.stringMap[funcName]? then
|
||||
(idx, thread)
|
||||
else
|
||||
(thread.stringMap.size, { thread with
|
||||
@@ -131,7 +131,7 @@ where
|
||||
stringMap := thread.stringMap.insert funcName thread.stringMap.size })
|
||||
let category := categories.findIdx? (·.name == data.cls.getRoot.toString) |>.getD 0
|
||||
let funcIdx ← modifyGet fun thread =>
|
||||
if let some idx := thread.funcMap.find? strIdx then
|
||||
if let some idx := thread.funcMap[strIdx]? then
|
||||
(idx, thread)
|
||||
else
|
||||
(thread.funcMap.size, { thread with
|
||||
@@ -151,7 +151,7 @@ where
|
||||
funcMap := thread.funcMap.insert strIdx thread.funcMap.size })
|
||||
let frameIdx := funcIdx
|
||||
let stackIdx ← modifyGet fun thread =>
|
||||
if let some idx := thread.stackMap.find? (frameIdx, parentStackIdx?) then
|
||||
if let some idx := thread.stackMap[(frameIdx, parentStackIdx?)]? then
|
||||
(idx, thread)
|
||||
else
|
||||
(thread.stackMap.size, { thread with
|
||||
@@ -222,7 +222,7 @@ def Profile.export (name : String) (startTime : Milliseconds) (traceState : Trac
|
||||
|
||||
structure ThreadWithCollideMaps extends ThreadWithMaps where
|
||||
/-- Max sharing map for samples -/
|
||||
sampleMap : HashMap Nat Nat := {}
|
||||
sampleMap : Std.HashMap Nat Nat := {}
|
||||
|
||||
/--
|
||||
Adds samples from `add` to `thread`, increasing the weight of existing samples with identical stacks
|
||||
@@ -237,7 +237,7 @@ where
|
||||
let oldStackIdx := add.samples.stack[oldSampleIdx]!
|
||||
let stackIdx ← collideStacks oldStackIdx
|
||||
modify fun thread =>
|
||||
if let some idx := thread.sampleMap.find? stackIdx then
|
||||
if let some idx := thread.sampleMap[stackIdx]? then
|
||||
-- imperative to preserve linear use of arrays here!
|
||||
let ⟨⟨⟨t1, t2, t3, samples, t5, t6, t7, t8, t9, t10⟩, o2, o3, o4, o5⟩, o6⟩ := thread
|
||||
let ⟨s1, s2, weight, s3, s4⟩ := samples
|
||||
@@ -265,7 +265,7 @@ where
|
||||
let oldStrIdx := add.funcTable.name[oldFuncIdx]!
|
||||
let strIdx ← getStrIdx add.stringArray[oldStrIdx]!
|
||||
let funcIdx ← modifyGet fun thread =>
|
||||
if let some idx := thread.funcMap.find? strIdx then
|
||||
if let some idx := thread.funcMap[strIdx]? then
|
||||
(idx, thread)
|
||||
else
|
||||
(thread.funcMap.size, { thread with
|
||||
@@ -284,7 +284,7 @@ where
|
||||
funcMap := thread.funcMap.insert strIdx thread.funcMap.size })
|
||||
let frameIdx := funcIdx
|
||||
modifyGet fun thread =>
|
||||
if let some idx := thread.stackMap.find? (frameIdx, parentStackIdx?) then
|
||||
if let some idx := thread.stackMap[(frameIdx, parentStackIdx?)]? then
|
||||
(idx, thread)
|
||||
else
|
||||
(thread.stackMap.size,
|
||||
@@ -302,7 +302,7 @@ where
|
||||
⟨⟨⟨t1,t2, t3, t4, t5, stackTable, t7, t8, t9, t10⟩, o2, o3, stackMap, o5⟩, o6⟩)
|
||||
getStrIdx (s : String) :=
|
||||
modifyGet fun thread =>
|
||||
if let some idx := thread.stringMap.find? s then
|
||||
if let some idx := thread.stringMap[s]? then
|
||||
(idx, thread)
|
||||
else
|
||||
(thread.stringMap.size, { thread with
|
||||
|
||||
@@ -5,8 +5,8 @@ Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Init.Data.Hashable
|
||||
import Lean.Data.HashSet
|
||||
import Lean.Data.HashMap
|
||||
import Std.Data.HashSet.Basic
|
||||
import Std.Data.HashMap.Basic
|
||||
|
||||
namespace Lean
|
||||
|
||||
@@ -23,33 +23,33 @@ unsafe instance : BEq (Ptr α) where
|
||||
Set of pointers. It is a low-level auxiliary datastructure used for traversing DAGs.
|
||||
-/
|
||||
unsafe def PtrSet (α : Type) :=
|
||||
HashSet (Ptr α)
|
||||
Std.HashSet (Ptr α)
|
||||
|
||||
unsafe def mkPtrSet {α : Type} (capacity : Nat := 64) : PtrSet α :=
|
||||
mkHashSet capacity
|
||||
Std.HashSet.empty capacity
|
||||
|
||||
unsafe abbrev PtrSet.insert (s : PtrSet α) (a : α) : PtrSet α :=
|
||||
HashSet.insert s { value := a }
|
||||
Std.HashSet.insert s { value := a }
|
||||
|
||||
unsafe abbrev PtrSet.contains (s : PtrSet α) (a : α) : Bool :=
|
||||
HashSet.contains s { value := a }
|
||||
Std.HashSet.contains s { value := a }
|
||||
|
||||
/--
|
||||
Map of pointers. It is a low-level auxiliary datastructure used for traversing DAGs.
|
||||
-/
|
||||
unsafe def PtrMap (α : Type) (β : Type) :=
|
||||
HashMap (Ptr α) β
|
||||
Std.HashMap (Ptr α) β
|
||||
|
||||
unsafe def mkPtrMap {α β : Type} (capacity : Nat := 64) : PtrMap α β :=
|
||||
mkHashMap capacity
|
||||
Std.HashMap.empty capacity
|
||||
|
||||
unsafe abbrev PtrMap.insert (s : PtrMap α β) (a : α) (b : β) : PtrMap α β :=
|
||||
HashMap.insert s { value := a } b
|
||||
Std.HashMap.insert s { value := a } b
|
||||
|
||||
unsafe abbrev PtrMap.contains (s : PtrMap α β) (a : α) : Bool :=
|
||||
HashMap.contains s { value := a }
|
||||
Std.HashMap.contains s { value := a }
|
||||
|
||||
unsafe abbrev PtrMap.find? (s : PtrMap α β) (a : α) : Option β :=
|
||||
HashMap.find? s { value := a }
|
||||
Std.HashMap.get? s { value := a }
|
||||
|
||||
end Lean
|
||||
|
||||
@@ -5,7 +5,7 @@ Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Init.Data.List.Control
|
||||
import Lean.Data.HashMap
|
||||
import Std.Data.HashMap.Basic
|
||||
namespace Lean.SCC
|
||||
/-!
|
||||
Very simple implementation of Tarjan's SCC algorithm.
|
||||
@@ -25,7 +25,7 @@ structure Data where
|
||||
structure State where
|
||||
stack : List α := []
|
||||
nextIndex : Nat := 0
|
||||
data : HashMap α Data := {}
|
||||
data : Std.HashMap α Data := {}
|
||||
sccs : List (List α) := []
|
||||
|
||||
abbrev M := StateM (State α)
|
||||
@@ -35,7 +35,7 @@ variable {α : Type} [BEq α] [Hashable α]
|
||||
|
||||
private def getDataOf (a : α) : M α Data := do
|
||||
let s ← get
|
||||
match s.data.find? a with
|
||||
match s.data[a]? with
|
||||
| some d => pure d
|
||||
| none => pure {}
|
||||
|
||||
@@ -52,7 +52,7 @@ private def push (a : α) : M α Unit :=
|
||||
|
||||
private def modifyDataOf (a : α) (f : Data → Data) : M α Unit :=
|
||||
modify fun s => { s with
|
||||
data := match s.data.find? a with
|
||||
data := match s.data[a]? with
|
||||
| none => s.data
|
||||
| some d => s.data.insert a (f d)
|
||||
}
|
||||
|
||||
@@ -67,7 +67,7 @@ structure TraceState where
|
||||
traces : PersistentArray TraceElem := {}
|
||||
deriving Inhabited
|
||||
|
||||
builtin_initialize inheritedTraceOptions : IO.Ref (HashSet Name) ← IO.mkRef ∅
|
||||
builtin_initialize inheritedTraceOptions : IO.Ref (Std.HashSet Name) ← IO.mkRef ∅
|
||||
|
||||
class MonadTrace (m : Type → Type) where
|
||||
modifyTraceState : (TraceState → TraceState) → m Unit
|
||||
@@ -88,7 +88,7 @@ def printTraces : m Unit := do
|
||||
def resetTraceState : m Unit :=
|
||||
modifyTraceState (fun _ => {})
|
||||
|
||||
private def checkTraceOption (inherited : HashSet Name) (opts : Options) (cls : Name) : Bool :=
|
||||
private def checkTraceOption (inherited : Std.HashSet Name) (opts : Options) (cls : Name) : Bool :=
|
||||
!opts.isEmpty && go (`trace ++ cls)
|
||||
where
|
||||
go (opt : Name) : Bool :=
|
||||
|
||||
@@ -5,3 +5,4 @@ Authors: Sebastian Ullrich
|
||||
-/
|
||||
prelude
|
||||
import Std.Data
|
||||
import Std.Sat
|
||||
|
||||
@@ -112,6 +112,10 @@ Tries to retrieve the mapping for the given key, returning `none` if no such map
|
||||
@[inline] def get? (m : HashMap α β) (a : α) : Option β :=
|
||||
DHashMap.Const.get? m.inner a
|
||||
|
||||
@[deprecated get? "Use `m[a]?` or `m.get? a` instead", inherit_doc get?]
|
||||
def find? (m : HashMap α β) (a : α) : Option β :=
|
||||
m.get? a
|
||||
|
||||
@[inline, inherit_doc DHashMap.contains] def contains (m : HashMap α β)
|
||||
(a : α) : Bool :=
|
||||
m.inner.contains a
|
||||
@@ -135,6 +139,10 @@ Retrieves the mapping for the given key. Ensures that such a mapping exists by r
|
||||
(fallback : β) : β :=
|
||||
DHashMap.Const.getD m.inner a fallback
|
||||
|
||||
@[deprecated getD, inherit_doc getD]
|
||||
def findD (m : HashMap α β) (a : α) (fallback : β) : β :=
|
||||
m.getD a fallback
|
||||
|
||||
/--
|
||||
The notation `m[a]!` is preferred over calling this function directly.
|
||||
|
||||
@@ -143,6 +151,10 @@ Tries to retrieve the mapping for the given key, panicking if no such mapping is
|
||||
@[inline] def get! [Inhabited β] (m : HashMap α β) (a : α) : β :=
|
||||
DHashMap.Const.get! m.inner a
|
||||
|
||||
@[deprecated get! "Use `m[a]!` or `m.get! a` instead", inherit_doc get!]
|
||||
def find! [Inhabited β] (m : HashMap α β) (a : α) : Option β :=
|
||||
m.get! a
|
||||
|
||||
instance [BEq α] [Hashable α] : GetElem? (HashMap α β) α β (fun m a => a ∈ m) where
|
||||
getElem m a h := m.get a h
|
||||
getElem? m a := m.get? a
|
||||
@@ -236,3 +248,16 @@ instance [BEq α] [Hashable α] [Repr α] [Repr β] : Repr (HashMap α β) where
|
||||
end Unverified
|
||||
|
||||
end Std.HashMap
|
||||
|
||||
/--
|
||||
Groups all elements `x`, `y` in `xs` with `key x == key y` into the same array
|
||||
`(xs.groupByKey key).find! (key x)`. Groups preserve the relative order of elements in `xs`.
|
||||
-/
|
||||
def Array.groupByKey [BEq α] [Hashable α] (key : β → α) (xs : Array β)
|
||||
: Std.HashMap α (Array β) := Id.run do
|
||||
let mut groups := ∅
|
||||
for x in xs do
|
||||
let group := groups.getD (key x) #[]
|
||||
groups := groups.erase (key x) -- make `group` referentially unique
|
||||
groups := groups.insert (key x) (group.push x)
|
||||
return groups
|
||||
|
||||
@@ -69,8 +69,9 @@ instance : EmptyCollection (Raw α β) where
|
||||
instance : Inhabited (Raw α β) where
|
||||
default := ∅
|
||||
|
||||
@[inline, inherit_doc DHashMap.Raw.insert] def insert [BEq α] [Hashable α] (m : Raw α β) (a : α)
|
||||
(b : β) : Raw α β :=
|
||||
set_option linter.unusedVariables false in
|
||||
@[inline, inherit_doc DHashMap.Raw.insert] def insert [beq : BEq α] [Hashable α] (m : Raw α β)
|
||||
(a : α) (b : β) : Raw α β :=
|
||||
⟨m.inner.insert a b⟩
|
||||
|
||||
@[inline, inherit_doc DHashMap.Raw.insertIfNew] def insertIfNew [BEq α] [Hashable α] (m : Raw α β)
|
||||
@@ -93,12 +94,13 @@ instance : Inhabited (Raw α β) where
|
||||
let ⟨previous, r⟩ := DHashMap.Raw.Const.getThenInsertIfNew? m.inner a b
|
||||
⟨previous, ⟨r⟩⟩
|
||||
|
||||
set_option linter.unusedVariables false in
|
||||
/--
|
||||
The notation `m[a]?` is preferred over calling this function directly.
|
||||
|
||||
Tries to retrieve the mapping for the given key, returning `none` if no such mapping is present.
|
||||
-/
|
||||
@[inline] def get? [BEq α] [Hashable α] (m : Raw α β) (a : α) : Option β :=
|
||||
@[inline] def get? [beq : BEq α] [Hashable α] (m : Raw α β) (a : α) : Option β :=
|
||||
DHashMap.Raw.Const.get? m.inner a
|
||||
|
||||
@[inline, inherit_doc DHashMap.Raw.contains] def contains [BEq α] [Hashable α] (m : Raw α β)
|
||||
|
||||
@@ -184,6 +184,10 @@ in the collection will be present in the returned hash set.
|
||||
@[inline] def ofList [BEq α] [Hashable α] (l : List α) : HashSet α :=
|
||||
⟨HashMap.unitOfList l⟩
|
||||
|
||||
/-- Computes the union of the given hash sets. -/
|
||||
@[inline] def union [BEq α] [Hashable α] (m₁ m₂ : HashSet α) : HashSet α :=
|
||||
m₂.fold (init := m₁) fun acc x => acc.insert x
|
||||
|
||||
/--
|
||||
Returns the number of buckets in the internal representation of the hash set. This function may
|
||||
be useful for things like monitoring system health, but it should be considered an internal
|
||||
|
||||
7
src/Std/Sat.lean
Normal file
7
src/Std/Sat.lean
Normal file
@@ -0,0 +1,7 @@
|
||||
/-
|
||||
Copyright (c) 2024 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Henrik Böving
|
||||
-/
|
||||
prelude
|
||||
import Std.Sat.CNF
|
||||
10
src/Std/Sat/CNF.lean
Normal file
10
src/Std/Sat/CNF.lean
Normal file
@@ -0,0 +1,10 @@
|
||||
/-
|
||||
Copyright (c) 2024 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Henrik Böving
|
||||
-/
|
||||
prelude
|
||||
import Std.Sat.CNF.Basic
|
||||
import Std.Sat.CNF.Literal
|
||||
import Std.Sat.CNF.Relabel
|
||||
import Std.Sat.CNF.RelabelFin
|
||||
190
src/Std/Sat/CNF/Basic.lean
Normal file
190
src/Std/Sat/CNF/Basic.lean
Normal file
@@ -0,0 +1,190 @@
|
||||
/-
|
||||
Copyright (c) 2024 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Kim Morrison
|
||||
-/
|
||||
prelude
|
||||
import Init.Data.List.Lemmas
|
||||
import Init.Data.List.Impl
|
||||
import Std.Sat.CNF.Literal
|
||||
|
||||
namespace Std
|
||||
namespace Sat
|
||||
|
||||
/--
|
||||
A clause in a CNF.
|
||||
|
||||
The literal `(i, b)` is satisfied if the assignment to `i` agrees with `b`.
|
||||
-/
|
||||
abbrev CNF.Clause (α : Type u) : Type u := List (Literal α)
|
||||
|
||||
/--
|
||||
A CNF formula.
|
||||
|
||||
Literals are identified by members of `α`.
|
||||
-/
|
||||
abbrev CNF (α : Type u) : Type u := List (CNF.Clause α)
|
||||
|
||||
namespace CNF
|
||||
|
||||
/--
|
||||
Evaluating a `Clause` with respect to an assignment `a`.
|
||||
-/
|
||||
def Clause.eval (a : α → Bool) (c : Clause α) : Bool := c.any fun (i, n) => a i == n
|
||||
|
||||
@[simp] theorem Clause.eval_nil (a : α → Bool) : Clause.eval a [] = false := rfl
|
||||
@[simp] theorem Clause.eval_cons (a : α → Bool) :
|
||||
Clause.eval a (i :: c) = (a i.1 == i.2 || Clause.eval a c) := rfl
|
||||
|
||||
/--
|
||||
Evaluating a `CNF` formula with respect to an assignment `a`.
|
||||
-/
|
||||
def eval (a : α → Bool) (f : CNF α) : Bool := f.all fun c => c.eval a
|
||||
|
||||
@[simp] theorem eval_nil (a : α → Bool) : eval a [] = true := rfl
|
||||
@[simp] theorem eval_cons (a : α → Bool) : eval a (c :: f) = (c.eval a && eval a f) := rfl
|
||||
|
||||
@[simp] theorem eval_append (a : α → Bool) (f1 f2 : CNF α) :
|
||||
eval a (f1 ++ f2) = (eval a f1 && eval a f2) := List.all_append
|
||||
|
||||
def Sat (a : α → Bool) (f : CNF α) : Prop := eval a f = true
|
||||
def Unsat (f : CNF α) : Prop := ∀ a, eval a f = false
|
||||
|
||||
theorem sat_def (a : α → Bool) (f : CNF α) : Sat a f ↔ (eval a f = true) := by rfl
|
||||
theorem unsat_def (f : CNF α) : Unsat f ↔ (∀ a, eval a f = false) := by rfl
|
||||
|
||||
|
||||
@[simp] theorem not_unsat_nil : ¬Unsat ([] : CNF α) :=
|
||||
fun h => by simp [unsat_def] at h
|
||||
|
||||
@[simp] theorem sat_nil {assign : α → Bool} : Sat assign ([] : CNF α) := by
|
||||
simp [sat_def]
|
||||
|
||||
@[simp] theorem unsat_nil_cons {g : CNF α} : Unsat ([] :: g) := by
|
||||
simp [unsat_def]
|
||||
|
||||
namespace Clause
|
||||
|
||||
/--
|
||||
Variable `v` occurs in `Clause` `c`.
|
||||
-/
|
||||
def Mem (v : α) (c : Clause α) : Prop := (v, false) ∈ c ∨ (v, true) ∈ c
|
||||
|
||||
instance {v : α} {c : Clause α} [DecidableEq α] : Decidable (Mem v c) :=
|
||||
inferInstanceAs <| Decidable (_ ∨ _)
|
||||
|
||||
@[simp] theorem not_mem_nil {v : α} : ¬Mem v ([] : Clause α) := by simp [Mem]
|
||||
@[simp] theorem mem_cons {v : α} : Mem v (l :: c : Clause α) ↔ (v = l.1 ∨ Mem v c) := by
|
||||
rcases l with ⟨b, (_|_)⟩
|
||||
· simp [Mem, or_assoc]
|
||||
· simp [Mem]
|
||||
rw [or_left_comm]
|
||||
|
||||
theorem mem_of (h : (v, p) ∈ c) : Mem v c := by
|
||||
cases p
|
||||
· left; exact h
|
||||
· right; exact h
|
||||
|
||||
theorem eval_congr (a1 a2 : α → Bool) (c : Clause α) (hw : ∀ i, Mem i c → a1 i = a2 i) :
|
||||
eval a1 c = eval a2 c := by
|
||||
induction c
|
||||
case nil => rfl
|
||||
case cons i c ih =>
|
||||
simp only [eval_cons]
|
||||
rw [ih, hw]
|
||||
· rcases i with ⟨b, (_|_)⟩ <;> simp [Mem]
|
||||
· intro j h
|
||||
apply hw
|
||||
rcases h with h | h
|
||||
· left
|
||||
apply List.mem_cons_of_mem _ h
|
||||
· right
|
||||
apply List.mem_cons_of_mem _ h
|
||||
|
||||
end Clause
|
||||
|
||||
/--
|
||||
Variable `v` occurs in `CNF` formula `f`.
|
||||
-/
|
||||
def Mem (v : α) (f : CNF α) : Prop := ∃ c, c ∈ f ∧ c.Mem v
|
||||
|
||||
instance {v : α} {f : CNF α} [DecidableEq α] : Decidable (Mem v f) :=
|
||||
inferInstanceAs <| Decidable (∃ _, _)
|
||||
|
||||
theorem any_not_isEmpty_iff_exists_mem {f : CNF α} :
|
||||
(List.any f fun c => !List.isEmpty c) = true ↔ ∃ v, Mem v f := by
|
||||
simp only [List.any_eq_true, Bool.not_eq_true', List.isEmpty_false_iff_exists_mem, Mem,
|
||||
Clause.Mem]
|
||||
constructor
|
||||
. intro h
|
||||
rcases h with ⟨clause, ⟨hclause1, hclause2⟩⟩
|
||||
rcases hclause2 with ⟨lit, hlit⟩
|
||||
exists lit.fst, clause
|
||||
constructor
|
||||
. assumption
|
||||
. rcases lit with ⟨_, ⟨_ | _⟩⟩ <;> simp_all
|
||||
. intro h
|
||||
rcases h with ⟨lit, clause, ⟨hclause1, hclause2⟩⟩
|
||||
exists clause
|
||||
constructor
|
||||
. assumption
|
||||
. cases hclause2 with
|
||||
| inl hl => exact Exists.intro _ hl
|
||||
| inr hr => exact Exists.intro _ hr
|
||||
|
||||
@[simp] theorem not_exists_mem : (¬ ∃ v, Mem v f) ↔ ∃ n, f = List.replicate n [] := by
|
||||
simp only [← any_not_isEmpty_iff_exists_mem]
|
||||
simp only [List.any_eq_true, Bool.not_eq_true', not_exists, not_and, Bool.not_eq_false]
|
||||
induction f with
|
||||
| nil =>
|
||||
simp only [List.not_mem_nil, List.isEmpty_iff, false_implies, forall_const, true_iff]
|
||||
exact ⟨0, rfl⟩
|
||||
| cons c f ih =>
|
||||
simp_all [ih, List.isEmpty_iff]
|
||||
constructor
|
||||
· rintro ⟨rfl, n, rfl⟩
|
||||
exact ⟨n+1, rfl⟩
|
||||
· rintro ⟨n, h⟩
|
||||
cases n
|
||||
· simp at h
|
||||
· simp_all only [List.replicate, List.cons.injEq, true_and]
|
||||
exact ⟨_, rfl⟩
|
||||
|
||||
instance {f : CNF α} [DecidableEq α] : Decidable (∃ v, Mem v f) :=
|
||||
decidable_of_iff (f.any fun c => !c.isEmpty) any_not_isEmpty_iff_exists_mem
|
||||
|
||||
@[simp] theorem not_mem_nil {v : α} : ¬Mem v ([] : CNF α) := by simp [Mem]
|
||||
@[simp] theorem mem_cons {v : α} {c} {f : CNF α} :
|
||||
Mem v (c :: f : CNF α) ↔ (Clause.Mem v c ∨ Mem v f) := by simp [Mem]
|
||||
|
||||
theorem mem_of (h : c ∈ f) (w : Clause.Mem v c) : Mem v f := by
|
||||
apply Exists.intro c
|
||||
constructor <;> assumption
|
||||
|
||||
@[simp] theorem mem_append {v : α} {f1 f2 : CNF α} : Mem v (f1 ++ f2) ↔ Mem v f1 ∨ Mem v f2 := by
|
||||
simp [Mem, List.mem_append]
|
||||
constructor
|
||||
· rintro ⟨c, (mf1 | mf2), mc⟩
|
||||
· left
|
||||
exact ⟨c, mf1, mc⟩
|
||||
· right
|
||||
exact ⟨c, mf2, mc⟩
|
||||
· rintro (⟨c, mf1, mc⟩ | ⟨c, mf2, mc⟩)
|
||||
· exact ⟨c, Or.inl mf1, mc⟩
|
||||
· exact ⟨c, Or.inr mf2, mc⟩
|
||||
|
||||
theorem eval_congr (a1 a2 : α → Bool) (f : CNF α) (hw : ∀ v, Mem v f → a1 v = a2 v) :
|
||||
eval a1 f = eval a2 f := by
|
||||
induction f
|
||||
case nil => rfl
|
||||
case cons c x ih =>
|
||||
simp only [eval_cons]
|
||||
rw [ih, Clause.eval_congr] <;>
|
||||
· intro i h
|
||||
apply hw
|
||||
simp [h]
|
||||
|
||||
end CNF
|
||||
|
||||
end Sat
|
||||
end Std
|
||||
38
src/Std/Sat/CNF/Literal.lean
Normal file
38
src/Std/Sat/CNF/Literal.lean
Normal file
@@ -0,0 +1,38 @@
|
||||
/-
|
||||
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Josh Clune
|
||||
-/
|
||||
prelude
|
||||
import Init.Data.Hashable
|
||||
import Init.Data.ToString
|
||||
|
||||
namespace Std
|
||||
namespace Sat
|
||||
|
||||
/--
|
||||
CNF literals identified by some type `α`. The `Bool` is the polarity of the literal.
|
||||
`true` means positive polarity.
|
||||
-/
|
||||
abbrev Literal (α : Type u) := α × Bool
|
||||
|
||||
namespace Literal
|
||||
|
||||
/--
|
||||
Flip the polarity of `l`.
|
||||
-/
|
||||
def negate (l : Literal α) : Literal α := (l.1, not l.2)
|
||||
|
||||
/--
|
||||
Output `l` as a DIMACS literal identifier.
|
||||
-/
|
||||
def dimacs [ToString α] (l : Literal α) : String :=
|
||||
if l.2 then
|
||||
s!"{l.1}"
|
||||
else
|
||||
s!"-{l.1}"
|
||||
|
||||
end Literal
|
||||
|
||||
end Sat
|
||||
end Std
|
||||
123
src/Std/Sat/CNF/Relabel.lean
Normal file
123
src/Std/Sat/CNF/Relabel.lean
Normal file
@@ -0,0 +1,123 @@
|
||||
/-
|
||||
Copyright (c) 2024 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Kim Morrison
|
||||
-/
|
||||
prelude
|
||||
import Std.Sat.CNF.Basic
|
||||
|
||||
namespace Std
|
||||
namespace Sat
|
||||
|
||||
namespace CNF
|
||||
|
||||
namespace Clause
|
||||
|
||||
/--
|
||||
Change the literal type in a `Clause` from `α` to `β` by using `r`.
|
||||
-/
|
||||
def relabel (r : α → β) (c : Clause α) : Clause β := c.map (fun (i, n) => (r i, n))
|
||||
|
||||
@[simp] theorem eval_relabel {r : α → β} {a : β → Bool} {c : Clause α} :
|
||||
(relabel r c).eval a = c.eval (a ∘ r) := by
|
||||
induction c <;> simp_all [relabel]
|
||||
|
||||
@[simp] theorem relabel_id' : relabel (id : α → α) = id := by funext; simp [relabel]
|
||||
|
||||
theorem relabel_congr {c : Clause α} {r1 r2 : α → β} (hw : ∀ v, Mem v c → r1 v = r2 v) :
|
||||
relabel r1 c = relabel r2 c := by
|
||||
simp only [relabel]
|
||||
rw [List.map_congr_left]
|
||||
intro ⟨v, p⟩ h
|
||||
congr
|
||||
apply hw _ (mem_of h)
|
||||
|
||||
-- We need the unapplied equality later.
|
||||
@[simp] theorem relabel_relabel' : relabel r1 ∘ relabel r2 = relabel (r1 ∘ r2) := by
|
||||
funext i
|
||||
simp only [Function.comp_apply, relabel, List.map_map]
|
||||
rfl
|
||||
|
||||
end Clause
|
||||
|
||||
/-! ### Relabelling
|
||||
|
||||
It is convenient to be able to construct a CNF using a more complicated literal type,
|
||||
but eventually we need to embed in `Nat`.
|
||||
-/
|
||||
|
||||
/--
|
||||
Change the literal type in a `CNF` formula from `α` to `β` by using `r`.
|
||||
-/
|
||||
def relabel (r : α → β) (f : CNF α) : CNF β := f.map (Clause.relabel r)
|
||||
|
||||
@[simp] theorem relabel_nil {r : α → β} : relabel r [] = [] := by simp [relabel]
|
||||
@[simp] theorem relabel_cons {r : α → β} : relabel r (c :: f) = (c.relabel r) :: relabel r f := by
|
||||
simp [relabel]
|
||||
|
||||
@[simp] theorem eval_relabel (r : α → β) (a : β → Bool) (f : CNF α) :
|
||||
(relabel r f).eval a = f.eval (a ∘ r) := by
|
||||
induction f <;> simp_all
|
||||
|
||||
@[simp] theorem relabel_append : relabel r (f1 ++ f2) = relabel r f1 ++ relabel r f2 :=
|
||||
List.map_append _ _ _
|
||||
|
||||
@[simp] theorem relabel_relabel : relabel r1 (relabel r2 f) = relabel (r1 ∘ r2) f := by
|
||||
simp only [relabel, List.map_map, Clause.relabel_relabel']
|
||||
|
||||
@[simp] theorem relabel_id : relabel id x = x := by simp [relabel]
|
||||
|
||||
theorem relabel_congr {f : CNF α} {r1 r2 : α → β} (hw : ∀ v, Mem v f → r1 v = r2 v) :
|
||||
relabel r1 f = relabel r2 f := by
|
||||
dsimp only [relabel]
|
||||
rw [List.map_congr_left]
|
||||
intro c h
|
||||
apply Clause.relabel_congr
|
||||
intro v m
|
||||
exact hw _ (mem_of h m)
|
||||
|
||||
theorem sat_relabel {f : CNF α} (h : Sat (r1 ∘ r2) f) : Sat r1 (relabel r2 f) := by
|
||||
simp_all [sat_def]
|
||||
|
||||
theorem unsat_relabel {f : CNF α} (r : α → β) (h : Unsat f) :
|
||||
Unsat (relabel r f) := by
|
||||
simp_all [unsat_def]
|
||||
|
||||
theorem nonempty_or_impossible (f : CNF α) : Nonempty α ∨ ∃ n, f = List.replicate n [] := by
|
||||
induction f with
|
||||
| nil => exact Or.inr ⟨0, rfl⟩
|
||||
| cons c x ih => match c with
|
||||
| [] => cases ih with
|
||||
| inl h => left; exact h
|
||||
| inr h =>
|
||||
obtain ⟨n, rfl⟩ := h
|
||||
right
|
||||
exact ⟨n + 1, rfl⟩
|
||||
| ⟨a, b⟩ :: c => exact Or.inl ⟨a⟩
|
||||
|
||||
theorem unsat_relabel_iff {f : CNF α} {r : α → β}
|
||||
(hw : ∀ {v1 v2}, Mem v1 f → Mem v2 f → r v1 = r v2 → v1 = v2) :
|
||||
Unsat (relabel r f) ↔ Unsat f := by
|
||||
rcases nonempty_or_impossible f with (⟨⟨a₀⟩⟩ | ⟨n, rfl⟩)
|
||||
· refine ⟨fun h => ?_, unsat_relabel r⟩
|
||||
have em := Classical.propDecidable
|
||||
let g : β → α := fun b =>
|
||||
if h : ∃ a, Mem a f ∧ r a = b then h.choose else a₀
|
||||
have h' := unsat_relabel g h
|
||||
suffices w : relabel g (relabel r f) = f by
|
||||
rwa [w] at h'
|
||||
have : ∀ a, Mem a f → g (r a) = a := by
|
||||
intro v h
|
||||
dsimp [g]
|
||||
rw [dif_pos ⟨v, h, rfl⟩]
|
||||
apply hw _ h
|
||||
· exact (Exists.choose_spec (⟨v, h, rfl⟩ : ∃ a', Mem a' f ∧ r a' = r v)).2
|
||||
· exact (Exists.choose_spec (⟨v, h, rfl⟩ : ∃ a', Mem a' f ∧ r a' = r v)).1
|
||||
rw [relabel_relabel, relabel_congr, relabel_id]
|
||||
exact this
|
||||
· cases n <;> simp [unsat_def, List.replicate_succ]
|
||||
|
||||
end CNF
|
||||
|
||||
end Sat
|
||||
end Std
|
||||
138
src/Std/Sat/CNF/RelabelFin.lean
Normal file
138
src/Std/Sat/CNF/RelabelFin.lean
Normal file
@@ -0,0 +1,138 @@
|
||||
/-
|
||||
Copyright (c) 2024 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Kim Morrison
|
||||
-/
|
||||
prelude
|
||||
import Init.Data.List.Nat.Basic
|
||||
import Std.Sat.CNF.Relabel
|
||||
|
||||
namespace Std
|
||||
namespace Sat
|
||||
|
||||
namespace CNF
|
||||
|
||||
/--
|
||||
Obtain the literal with the largest identifier in `c`.
|
||||
-/
|
||||
def Clause.maxLiteral (c : Clause Nat) : Option Nat := (c.map (·.1)) |>.maximum?
|
||||
|
||||
theorem Clause.of_maxLiteral_eq_some (c : Clause Nat) (h : c.maxLiteral = some maxLit) :
|
||||
∀ lit, Mem lit c → lit ≤ maxLit := by
|
||||
intro lit hlit
|
||||
simp only [maxLiteral, List.maximum?_eq_some_iff', List.mem_map, forall_exists_index, and_imp,
|
||||
forall_apply_eq_imp_iff₂] at h
|
||||
simp only [Mem] at hlit
|
||||
rcases h with ⟨_, hbar⟩
|
||||
cases hlit
|
||||
all_goals
|
||||
have := hbar (lit, _) (by assumption)
|
||||
omega
|
||||
|
||||
theorem Clause.maxLiteral_eq_some_of_mem (c : Clause Nat) (h : Mem l c) :
|
||||
∃ maxLit, c.maxLiteral = some maxLit := by
|
||||
dsimp [Mem] at h
|
||||
cases h <;> rename_i h
|
||||
all_goals
|
||||
have h1 := List.ne_nil_of_mem h
|
||||
have h2 := not_congr <| @List.maximum?_eq_none_iff _ (c.map (·.1)) _
|
||||
simp [← Option.ne_none_iff_exists', h1, h2, maxLiteral]
|
||||
|
||||
theorem Clause.of_maxLiteral_eq_none (c : Clause Nat) (h : c.maxLiteral = none) :
|
||||
∀ lit, ¬Mem lit c := by
|
||||
intro lit hlit
|
||||
simp only [maxLiteral, List.maximum?_eq_none_iff, List.map_eq_nil] at h
|
||||
simp only [h, not_mem_nil] at hlit
|
||||
|
||||
/--
|
||||
Obtain the literal with the largest identifier in `f`.
|
||||
-/
|
||||
def maxLiteral (f : CNF Nat) : Option Nat :=
|
||||
List.filterMap Clause.maxLiteral f |>.maximum?
|
||||
|
||||
theorem of_maxLiteral_eq_some' (f : CNF Nat) (h : f.maxLiteral = some maxLit) :
|
||||
∀ clause, clause ∈ f → clause.maxLiteral = some localMax → localMax ≤ maxLit := by
|
||||
intro clause hclause1 hclause2
|
||||
simp [maxLiteral, List.maximum?_eq_some_iff'] at h
|
||||
rcases h with ⟨_, hclause3⟩
|
||||
apply hclause3 localMax clause hclause1 hclause2
|
||||
|
||||
theorem of_maxLiteral_eq_some (f : CNF Nat) (h : f.maxLiteral = some maxLit) :
|
||||
∀ lit, Mem lit f → lit ≤ maxLit := by
|
||||
intro lit hlit
|
||||
dsimp [Mem] at hlit
|
||||
rcases hlit with ⟨clause, ⟨hclause1, hclause2⟩⟩
|
||||
rcases Clause.maxLiteral_eq_some_of_mem clause hclause2 with ⟨localMax, hlocal⟩
|
||||
have h1 := of_maxLiteral_eq_some' f h clause hclause1 hlocal
|
||||
have h2 := Clause.of_maxLiteral_eq_some clause hlocal lit hclause2
|
||||
omega
|
||||
|
||||
theorem of_maxLiteral_eq_none (f : CNF Nat) (h : f.maxLiteral = none) :
|
||||
∀ lit, ¬Mem lit f := by
|
||||
intro lit hlit
|
||||
simp only [maxLiteral, List.maximum?_eq_none_iff] at h
|
||||
dsimp [Mem] at hlit
|
||||
rcases hlit with ⟨clause, ⟨hclause1, hclause2⟩⟩
|
||||
have := Clause.of_maxLiteral_eq_none clause (List.forall_none_of_filterMap_eq_nil h clause hclause1) lit
|
||||
contradiction
|
||||
|
||||
/--
|
||||
An upper bound for the amount of distinct literals in `f`.
|
||||
-/
|
||||
def numLiterals (f : CNF Nat) :=
|
||||
match f.maxLiteral with
|
||||
| none => 0
|
||||
| some n => n + 1
|
||||
|
||||
theorem lt_numLiterals {f : CNF Nat} (h : Mem v f) : v < numLiterals f := by
|
||||
dsimp [numLiterals]
|
||||
split <;> rename_i h2
|
||||
. exfalso
|
||||
apply of_maxLiteral_eq_none f h2 v h
|
||||
. have := of_maxLiteral_eq_some f h2 v h
|
||||
omega
|
||||
|
||||
theorem numLiterals_pos {f : CNF Nat} (h : Mem v f) : 0 < numLiterals f :=
|
||||
Nat.lt_of_le_of_lt (Nat.zero_le _) (lt_numLiterals h)
|
||||
|
||||
/--
|
||||
Relabel `f` to a `CNF` formula with a known upper bound for its literals.
|
||||
|
||||
This operation might be useful when e.g. using the literals to index into an array of known size
|
||||
without conducting bounds checks.
|
||||
-/
|
||||
def relabelFin (f : CNF Nat) : CNF (Fin f.numLiterals) :=
|
||||
if h : ∃ v, Mem v f then
|
||||
let n := f.numLiterals
|
||||
f.relabel fun i =>
|
||||
if w : i < n then
|
||||
-- This branch will always hold
|
||||
⟨i, w⟩
|
||||
else
|
||||
⟨0, numLiterals_pos h.choose_spec⟩
|
||||
else
|
||||
List.replicate f.length []
|
||||
|
||||
@[simp] theorem unsat_relabelFin {f : CNF Nat} : Unsat f.relabelFin ↔ Unsat f := by
|
||||
dsimp [relabelFin]
|
||||
split <;> rename_i h
|
||||
· apply unsat_relabel_iff
|
||||
intro a b ma mb
|
||||
replace ma := lt_numLiterals ma
|
||||
replace mb := lt_numLiterals mb
|
||||
split <;> rename_i a_lt
|
||||
· simp
|
||||
· contradiction
|
||||
· cases f with
|
||||
| nil => simp
|
||||
| cons c g =>
|
||||
simp only [not_exists_mem] at h
|
||||
obtain ⟨n, h⟩ := h
|
||||
cases n with
|
||||
| zero => simp at h
|
||||
| succ n => simp_all [List.replicate_succ]
|
||||
|
||||
end CNF
|
||||
|
||||
end Sat
|
||||
end Std
|
||||
@@ -23,8 +23,8 @@ structure Module where
|
||||
instance : Hashable Module where hash m := hash m.keyName
|
||||
instance : BEq Module where beq m n := m.keyName == n.keyName
|
||||
|
||||
abbrev ModuleSet := HashSet Module
|
||||
@[inline] def ModuleSet.empty : ModuleSet := HashSet.empty
|
||||
abbrev ModuleSet := Std.HashSet Module
|
||||
@[inline] def ModuleSet.empty : ModuleSet := Std.HashSet.empty
|
||||
|
||||
abbrev OrdModuleSet := OrdHashSet Module
|
||||
@[inline] def OrdModuleSet.empty : OrdModuleSet := OrdHashSet.empty
|
||||
|
||||
@@ -247,8 +247,8 @@ hydrate_opaque_type OpaquePackage Package
|
||||
instance : Hashable Package where hash pkg := hash pkg.config.name
|
||||
instance : BEq Package where beq p1 p2 := p1.config.name == p2.config.name
|
||||
|
||||
abbrev PackageSet := HashSet Package
|
||||
@[inline] def PackageSet.empty : PackageSet := HashSet.empty
|
||||
abbrev PackageSet := Std.HashSet Package
|
||||
@[inline] def PackageSet.empty : PackageSet := Std.HashSet.empty
|
||||
|
||||
abbrev OrdPackageSet := OrdHashSet Package
|
||||
@[inline] def OrdPackageSet.empty : OrdPackageSet := OrdHashSet.empty
|
||||
|
||||
@@ -23,11 +23,11 @@ namespace Lake
|
||||
deriving instance BEq, Hashable for Import
|
||||
|
||||
/- Cache for the imported header environment of Lake configuration files. -/
|
||||
initialize importEnvCache : IO.Ref (HashMap (Array Import) Environment) ← IO.mkRef {}
|
||||
initialize importEnvCache : IO.Ref (Std.HashMap (Array Import) Environment) ← IO.mkRef {}
|
||||
|
||||
/-- Like `importModules`, but fetch the resulting import state from the cache if possible. -/
|
||||
def importModulesUsingCache (imports : Array Import) (opts : Options) (trustLevel : UInt32) : IO Environment := do
|
||||
if let some env := (← importEnvCache.get).find? imports then
|
||||
if let some env := (← importEnvCache.get)[imports]? then
|
||||
return env
|
||||
let env ← importModules imports opts trustLevel
|
||||
importEnvCache.modify (·.insert imports env)
|
||||
@@ -120,7 +120,7 @@ def importConfigFileCore (olean : FilePath) (leanOpts : Options) : IO Environmen
|
||||
let extNameIdx ← mkExtNameMap 0
|
||||
let env := mod.entries.foldl (init := env) fun env (extName, ents) =>
|
||||
if lakeExts.contains extName then
|
||||
match extNameIdx.find? extName with
|
||||
match extNameIdx[extName]? with
|
||||
| some entryIdx => ents.foldl extDescrs[entryIdx]!.addEntry env
|
||||
| none => env
|
||||
else
|
||||
|
||||
@@ -3,7 +3,7 @@ Copyright (c) 2022 Mac Malone. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Mac Malone
|
||||
-/
|
||||
import Lean.Data.HashSet
|
||||
import Std.Data.HashSet.Basic
|
||||
|
||||
open Lean
|
||||
|
||||
@@ -11,13 +11,13 @@ namespace Lake
|
||||
|
||||
/-- A `HashSet` that preserves insertion order. -/
|
||||
structure OrdHashSet (α) [Hashable α] [BEq α] where
|
||||
toHashSet : HashSet α
|
||||
toHashSet : Std.HashSet α
|
||||
toArray : Array α
|
||||
|
||||
namespace OrdHashSet
|
||||
variable [Hashable α] [BEq α]
|
||||
|
||||
instance : Coe (OrdHashSet α) (HashSet α) := ⟨toHashSet⟩
|
||||
instance : Coe (OrdHashSet α) (Std.HashSet α) := ⟨toHashSet⟩
|
||||
|
||||
def empty : OrdHashSet α :=
|
||||
⟨.empty, .empty⟩
|
||||
|
||||
@@ -57,6 +57,7 @@ def testProc (args : IO.Process.SpawnArgs) : BaseIO Bool :=
|
||||
EIO.catchExceptions (h := fun _ => pure false) do
|
||||
let child ← IO.Process.spawn {
|
||||
args with
|
||||
stdin := IO.Process.Stdio.null
|
||||
stdout := IO.Process.Stdio.null
|
||||
stderr := IO.Process.Stdio.null
|
||||
}
|
||||
|
||||
@@ -7,13 +7,14 @@ Author: Leonardo de Moura
|
||||
#pragma once
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <lean/lean.h>
|
||||
|
||||
namespace lean {
|
||||
void init_thread_heap();
|
||||
void * alloc(size_t sz);
|
||||
void dealloc(void * o, size_t sz);
|
||||
void add_heartbeats(uint64_t count);
|
||||
uint64_t get_num_heartbeats();
|
||||
LEAN_EXPORT void * alloc(size_t sz);
|
||||
LEAN_EXPORT void dealloc(void * o, size_t sz);
|
||||
LEAN_EXPORT void add_heartbeats(uint64_t count);
|
||||
LEAN_EXPORT uint64_t get_num_heartbeats();
|
||||
void initialize_alloc();
|
||||
void finalize_alloc();
|
||||
}
|
||||
|
||||
@@ -485,43 +485,36 @@ extern "C" LEAN_EXPORT obj_res lean_io_prim_handle_write(b_obj_arg h, b_obj_arg
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
Handle.getLine : (@& Handle) → IO Unit
|
||||
The line returned by `lean_io_prim_handle_get_line`
|
||||
is truncated at the first '\0' character and the
|
||||
rest of the line is discarded. */
|
||||
/* Handle.getLine : (@& Handle) → IO Unit */
|
||||
extern "C" LEAN_EXPORT obj_res lean_io_prim_handle_get_line(b_obj_arg h, obj_arg /* w */) {
|
||||
FILE * fp = io_get_handle(h);
|
||||
const int buf_sz = 64;
|
||||
char buf_str[buf_sz]; // NOLINT
|
||||
|
||||
std::string result;
|
||||
bool first = true;
|
||||
while (true) {
|
||||
char * out = std::fgets(buf_str, buf_sz, fp);
|
||||
if (out != nullptr) {
|
||||
if (strlen(buf_str) < buf_sz-1 || buf_str[buf_sz-2] == '\n') {
|
||||
if (first) {
|
||||
return io_result_mk_ok(mk_string(out));
|
||||
} else {
|
||||
result.append(out);
|
||||
return io_result_mk_ok(mk_string(result));
|
||||
}
|
||||
}
|
||||
result.append(out);
|
||||
} else if (std::feof(fp)) {
|
||||
clearerr(fp);
|
||||
return io_result_mk_ok(mk_string(result));
|
||||
} else {
|
||||
return io_result_mk_error(decode_io_error(errno, nullptr));
|
||||
int c; // Note: int, not char, required to handle EOF
|
||||
while ((c = std::fgetc(fp)) != EOF) {
|
||||
result.push_back(c);
|
||||
if (c == '\n') {
|
||||
break;
|
||||
}
|
||||
first = false;
|
||||
}
|
||||
|
||||
if (std::ferror(fp)) {
|
||||
return io_result_mk_error(decode_io_error(errno, nullptr));
|
||||
} else if (std::feof(fp)) {
|
||||
clearerr(fp);
|
||||
return io_result_mk_ok(mk_string(result));
|
||||
} else {
|
||||
obj_res ret = io_result_mk_ok(mk_string(result));
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
/* Handle.putStr : (@& Handle) → (@& String) → IO Unit */
|
||||
extern "C" LEAN_EXPORT obj_res lean_io_prim_handle_put_str(b_obj_arg h, b_obj_arg s, obj_arg /* w */) {
|
||||
FILE * fp = io_get_handle(h);
|
||||
if (std::fputs(lean_string_cstr(s), fp) != EOF) {
|
||||
usize n = lean_string_size(s) - 1; // - 1 to ignore the terminal NULL byte.
|
||||
usize m = std::fwrite(lean_string_cstr(s), 1, n, fp);
|
||||
if (m == n) {
|
||||
return io_result_mk_ok(box(0));
|
||||
} else {
|
||||
return io_result_mk_error(decode_io_error(errno, nullptr));
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user