mirror of
https://github.com/leanprover/lean4.git
synced 2026-04-07 04:34:08 +00:00
Compare commits
66 Commits
fix-window
...
hbv/readFi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
118709ee6c | ||
|
|
c2575107f2 | ||
|
|
4db828d885 | ||
|
|
3f16f339e7 | ||
|
|
2771296ca5 | ||
|
|
5e337872ce | ||
|
|
329fa6309b | ||
|
|
370488a9ff | ||
|
|
df38da8e09 | ||
|
|
a2b93d6c18 | ||
|
|
63c4de5fea | ||
|
|
3b14642c42 | ||
|
|
d52da36e68 | ||
|
|
bf82965eec | ||
|
|
4bac74c4ac | ||
|
|
8b9d27de31 | ||
|
|
d15f0335a9 | ||
|
|
240ebff549 | ||
|
|
a29bca7f00 | ||
|
|
313f6b3c74 | ||
|
|
43fa46412d | ||
|
|
234704e304 | ||
|
|
12a714a6f9 | ||
|
|
cdc7ed0224 | ||
|
|
217abdf97a | ||
|
|
490a2b4bf9 | ||
|
|
84d45deb10 | ||
|
|
f46d216e18 | ||
|
|
cc42a17931 | ||
|
|
e106be19dd | ||
|
|
1efd6657d4 | ||
|
|
473b34561d | ||
|
|
574066b30b | ||
|
|
1e6d617aad | ||
|
|
c17a4ddc94 | ||
|
|
5be4f5e30c | ||
|
|
3c5ac9496f | ||
|
|
6c1f8a8a63 | ||
|
|
7bea3c1508 | ||
|
|
a27d4a9519 | ||
|
|
4a2fb6e922 | ||
|
|
b7db82894b | ||
|
|
35e1554ef7 | ||
|
|
14d59b3599 | ||
|
|
a8e480cd52 | ||
|
|
d07239d1bd | ||
|
|
590de785cc | ||
|
|
d671d0d61a | ||
|
|
8e476e9d22 | ||
|
|
a3d144a362 | ||
|
|
87d41e6326 | ||
|
|
d6cb2432c6 | ||
|
|
c0ffc85d75 | ||
|
|
f62359acc7 | ||
|
|
2d09c96caf | ||
|
|
21b4377d36 | ||
|
|
1e9d96be22 | ||
|
|
647a5e9492 | ||
|
|
9c4028aab4 | ||
|
|
2c002718e0 | ||
|
|
b07384acbb | ||
|
|
efc99b982e | ||
|
|
ee430b6c80 | ||
|
|
a8740f5ed9 | ||
|
|
5e6a3cf5f9 | ||
|
|
0ed1cf7244 |
4
.github/workflows/nix-ci.yml
vendored
4
.github/workflows/nix-ci.yml
vendored
@@ -103,7 +103,7 @@ jobs:
|
||||
continue-on-error: true
|
||||
- name: Build manual
|
||||
run: |
|
||||
nix build $NIX_BUILD_ARGS --update-input lean --no-write-lock-file ./doc#{lean-mdbook,leanInk,alectryon,test,inked} -o push-doc
|
||||
nix build $NIX_BUILD_ARGS --update-input lean --no-write-lock-file ./doc#{lean-mdbook,leanInk,alectryon,inked} -o push-doc
|
||||
nix build $NIX_BUILD_ARGS --update-input lean --no-write-lock-file ./doc
|
||||
# https://github.com/netlify/cli/issues/1809
|
||||
cp -r --dereference ./result ./dist
|
||||
@@ -146,5 +146,3 @@ jobs:
|
||||
- name: Fixup CCache Cache
|
||||
run: |
|
||||
sudo chown -R $USER /nix/var/cache
|
||||
- name: CCache stats
|
||||
run: CCACHE_DIR=/nix/var/cache/ccache nix run .#nixpkgs.ccache -- -s
|
||||
|
||||
30
.github/workflows/rebase-on-comment.yml
vendored
Normal file
30
.github/workflows/rebase-on-comment.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
# As the PR author, use `!rebase` to rebase a commit
|
||||
name: Rebase on Comment
|
||||
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
|
||||
jobs:
|
||||
rebase:
|
||||
if: github.event.issue.pull_request != '' && github.event.comment.body == '!rebase' && github.event.comment.user.login == github.event.issue.user.login
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: refs/pull/${{ github.event.issue.number }}/head
|
||||
- name: Rebase PR branch onto base branch
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
PR_NUMBER="${{ github.event.issue.number }}"
|
||||
API_URL="https://api.github.com/repos/${{ github.repository }}/pulls/$PR_NUMBER"
|
||||
PR_DETAILS="$(curl -s -H "Authorization: token $GITHUB_TOKEN" $API_URL)"
|
||||
|
||||
BASE_REF="$(echo $PR_DETAILS | jq -r .base.ref)"
|
||||
|
||||
git checkout -b working-branch
|
||||
git fetch origin $BASE_REF
|
||||
git rebase origin/$BASE_REF
|
||||
git push origin refs/pull/${{ github.event.issue.number }}/head --force-with-lease
|
||||
@@ -147,33 +147,31 @@ We'll use `v4.7.0-rc1` as the intended release version in this example.
|
||||
- You can monitor this at `https://github.com/leanprover/lean4/actions/workflows/ci.yml`, looking for the `v4.7.0-rc1` tag.
|
||||
- This step can take up to an hour.
|
||||
- (GitHub release notes) Once the release appears at https://github.com/leanprover/lean4/releases/
|
||||
- Edit the release notes on Github to select the "Set as a pre-release box".
|
||||
- If release notes have been written already, copy the section of `RELEASES.md` for this version into the Github release notes
|
||||
and use the title "Changes since v4.6.0 (from RELEASES.md)".
|
||||
- Otherwise, in the "previous tag" dropdown, select `v4.6.0`, and click "Generate release notes".
|
||||
- Verify that the release is marked as a prerelease (this should have been done automatically by the CI release job).
|
||||
- In the "previous tag" dropdown, select `v4.6.0`, and click "Generate release notes".
|
||||
This will add a list of all the commits since the last stable version.
|
||||
- Delete anything already mentioned in the hand-written release notes above.
|
||||
- Delete "update stage0" commits, and anything with a completely inscrutable commit message.
|
||||
- Briefly rearrange the remaining items by category (e.g. `simp`, `lake`, `bug fixes`),
|
||||
but for minor items don't put any work in expanding on commit messages.
|
||||
- (How we want to release notes to look is evolving: please update this section if it looks wrong!)
|
||||
- Next, we will move a curated list of downstream repos to the release candidate.
|
||||
- This assumes that there is already a *reviewed* branch `bump/v4.7.0` on each repository
|
||||
containing the required adaptations (or no adaptations are required).
|
||||
The preparation of this branch is beyond the scope of this document.
|
||||
- This assumes that for each repository either:
|
||||
* There is already a *reviewed* branch `bump/v4.7.0` containing the required adaptations.
|
||||
The preparation of this branch is beyond the scope of this document.
|
||||
* The repository does not need any changes to move to the new version.
|
||||
- For each of the target repositories:
|
||||
- Checkout the `bump/v4.7.0` branch.
|
||||
- Verify that the `lean-toolchain` is set to the nightly from which the release candidate was created.
|
||||
- `git merge origin/master`
|
||||
- Change the `lean-toolchain` to `leanprover/lean4:v4.7.0-rc1`
|
||||
- In `lakefile.lean`, change any dependencies which were using `nightly-testing` or `bump/v4.7.0` branches
|
||||
back to `master` or `main`, and run `lake update` for those dependencies.
|
||||
- Run `lake build` to ensure that dependencies are found (but it's okay to stop it after a moment).
|
||||
- `git commit`
|
||||
- `git push`
|
||||
- Open a PR from `bump/v4.7.0` to `master`, and either merge it yourself after CI, if appropriate,
|
||||
or notify the maintainers that it is ready to go.
|
||||
- Once this PR has been merged, tag `master` with `v4.7.0-rc1` and push this tag.
|
||||
- If the repository does not need any changes (i.e. `bump/v4.7.0` does not exist) then create
|
||||
a new PR updating `lean-toolchain` to `leanprover/lean4:v4.7.0-rc1` and running `lake update`.
|
||||
- Otherwise:
|
||||
- Checkout the `bump/v4.7.0` branch.
|
||||
- Verify that the `lean-toolchain` is set to the nightly from which the release candidate was created.
|
||||
- `git merge origin/master`
|
||||
- Change the `lean-toolchain` to `leanprover/lean4:v4.7.0-rc1`
|
||||
- In `lakefile.lean`, change any dependencies which were using `nightly-testing` or `bump/v4.7.0` branches
|
||||
back to `master` or `main`, and run `lake update` for those dependencies.
|
||||
- Run `lake build` to ensure that dependencies are found (but it's okay to stop it after a moment).
|
||||
- `git commit`
|
||||
- `git push`
|
||||
- Open a PR from `bump/v4.7.0` to `master`, and either merge it yourself after CI, if appropriate,
|
||||
or notify the maintainers that it is ready to go.
|
||||
- Once the PR has been merged, tag `master` with `v4.7.0-rc1` and push this tag.
|
||||
- We do this for the same list of repositories as for stable releases, see above.
|
||||
As above, there are dependencies between these, and so the process above is iterative.
|
||||
It greatly helps if you can merge the `bump/v4.7.0` PRs yourself!
|
||||
|
||||
138
doc/flake.lock
generated
138
doc/flake.lock
generated
@@ -18,12 +18,15 @@
|
||||
}
|
||||
},
|
||||
"flake-utils": {
|
||||
"inputs": {
|
||||
"systems": "systems"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1656928814,
|
||||
"narHash": "sha256-RIFfgBuKz6Hp89yRr7+NR5tzIAbn52h8vT6vXkYjZoM=",
|
||||
"lastModified": 1710146030,
|
||||
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "7e2a3b3dfd9af950a856d66b0a7d01e3c18aa249",
|
||||
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -35,13 +38,12 @@
|
||||
"lean": {
|
||||
"inputs": {
|
||||
"flake-utils": "flake-utils",
|
||||
"lean4-mode": "lean4-mode",
|
||||
"nix": "nix",
|
||||
"nixpkgs": "nixpkgs_2"
|
||||
"nixpkgs": "nixpkgs",
|
||||
"nixpkgs-old": "nixpkgs-old"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 0,
|
||||
"narHash": "sha256-YnYbmG0oou1Q/GE4JbMNb8/yqUVXBPIvcdQQJHBqtPk=",
|
||||
"narHash": "sha256-saRAtQ6VautVXKDw1XH35qwP0KEBKTKZbg/TRa4N9Vw=",
|
||||
"path": "../.",
|
||||
"type": "path"
|
||||
},
|
||||
@@ -50,22 +52,6 @@
|
||||
"type": "path"
|
||||
}
|
||||
},
|
||||
"lean4-mode": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1659020985,
|
||||
"narHash": "sha256-+dRaXB7uvN/weSZiKcfSKWhcdJVNg9Vg8k0pJkDNjpc=",
|
||||
"owner": "leanprover",
|
||||
"repo": "lean4-mode",
|
||||
"rev": "37d5c99b7b29c80ab78321edd6773200deb0bca6",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "leanprover",
|
||||
"repo": "lean4-mode",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"leanInk": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
@@ -83,22 +69,6 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"lowdown-src": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1633514407,
|
||||
"narHash": "sha256-Dw32tiMjdK9t3ETl5fzGrutQTzh2rufgZV4A/BbxuD4=",
|
||||
"owner": "kristapsdz",
|
||||
"repo": "lowdown",
|
||||
"rev": "d2c2b44ff6c27b936ec27358a2653caaef8f73b8",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "kristapsdz",
|
||||
"repo": "lowdown",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"mdBook": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
@@ -115,65 +85,13 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nix": {
|
||||
"inputs": {
|
||||
"lowdown-src": "lowdown-src",
|
||||
"nixpkgs": "nixpkgs",
|
||||
"nixpkgs-regression": "nixpkgs-regression"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1657097207,
|
||||
"narHash": "sha256-SmeGmjWM3fEed3kQjqIAO8VpGmkC2sL1aPE7kKpK650=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nix",
|
||||
"rev": "f6316b49a0c37172bca87ede6ea8144d7d89832f",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"repo": "nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1653988320,
|
||||
"narHash": "sha256-ZaqFFsSDipZ6KVqriwM34T739+KLYJvNmCWzErjAg7c=",
|
||||
"lastModified": 1710889954,
|
||||
"narHash": "sha256-Pr6F5Pmd7JnNEMHHmspZ0qVqIBVxyZ13ik1pJtm2QXk=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "2fa57ed190fd6c7c746319444f34b5917666e5c1",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-22.05-small",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs-regression": {
|
||||
"locked": {
|
||||
"lastModified": 1643052045,
|
||||
"narHash": "sha256-uGJ0VXIhWKGXxkeNnq4TvV3CIOkUJ3PAoLZ3HMzNVMw=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "215d4d0fd80ca5163643b03a33fde804a29cc1e2",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "215d4d0fd80ca5163643b03a33fde804a29cc1e2",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs_2": {
|
||||
"locked": {
|
||||
"lastModified": 1657208011,
|
||||
"narHash": "sha256-BlIFwopAykvdy1DYayEkj6ZZdkn+cVgPNX98QVLc0jM=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "2770cc0b1e8faa0e20eb2c6aea64c256a706d4f2",
|
||||
"rev": "7872526e9c5332274ea5932a0c3270d6e4724f3b",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -183,6 +101,23 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs-old": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1581379743,
|
||||
"narHash": "sha256-i1XCn9rKuLjvCdu2UeXKzGLF6IuQePQKFt4hEKRU5oc=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "34c7eb7545d155cc5b6f499b23a7cb1c96ab4d59",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-19.03",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"alectryon": "alectryon",
|
||||
@@ -194,6 +129,21 @@
|
||||
"leanInk": "leanInk",
|
||||
"mdBook": "mdBook"
|
||||
}
|
||||
},
|
||||
"systems": {
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
},
|
||||
"root": "root",
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
};
|
||||
|
||||
outputs = inputs@{ self, ... }: inputs.flake-utils.lib.eachDefaultSystem (system:
|
||||
with inputs.lean.packages.${system}; with nixpkgs;
|
||||
with inputs.lean.packages.${system}.deprecated; with nixpkgs;
|
||||
let
|
||||
doc-src = lib.sourceByRegex ../. ["doc.*" "tests(/lean(/beginEndAsMacro.lean)?)?"];
|
||||
in {
|
||||
@@ -44,21 +44,6 @@
|
||||
mdbook build -d $out
|
||||
'';
|
||||
};
|
||||
# We use a separate derivation instead of `checkPhase` so we can push it but not `doc` to the binary cache
|
||||
test = stdenv.mkDerivation {
|
||||
name ="lean-doc-test";
|
||||
src = doc-src;
|
||||
buildInputs = [ lean-mdbook stage1.Lean.lean-package strace ];
|
||||
patchPhase = ''
|
||||
cd doc
|
||||
patchShebangs test
|
||||
'';
|
||||
buildPhase = ''
|
||||
mdbook test
|
||||
touch $out
|
||||
'';
|
||||
dontInstall = true;
|
||||
};
|
||||
leanInk = (buildLeanPackage {
|
||||
name = "Main";
|
||||
src = inputs.leanInk;
|
||||
|
||||
113
flake.lock
generated
113
flake.lock
generated
@@ -1,21 +1,5 @@
|
||||
{
|
||||
"nodes": {
|
||||
"flake-compat": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1673956053,
|
||||
"narHash": "sha256-4gtG9iQuiKITOjNQQeQIpoIB6b16fm+504Ch3sNKLd8=",
|
||||
"owner": "edolstra",
|
||||
"repo": "flake-compat",
|
||||
"rev": "35bb57c0c8d8b62bbfd284272c928ceb64ddbde9",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "edolstra",
|
||||
"repo": "flake-compat",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-utils": {
|
||||
"inputs": {
|
||||
"systems": "systems"
|
||||
@@ -34,71 +18,18 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"lean4-mode": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1709737301,
|
||||
"narHash": "sha256-uT9JN2kLNKJK9c/S/WxLjiHmwijq49EgLb+gJUSDpz0=",
|
||||
"owner": "leanprover",
|
||||
"repo": "lean4-mode",
|
||||
"rev": "f1f24c15134dee3754b82c9d9924866fe6bc6b9f",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "leanprover",
|
||||
"repo": "lean4-mode",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"libgit2": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1697646580,
|
||||
"narHash": "sha256-oX4Z3S9WtJlwvj0uH9HlYcWv+x1hqp8mhXl7HsLu2f0=",
|
||||
"owner": "libgit2",
|
||||
"repo": "libgit2",
|
||||
"rev": "45fd9ed7ae1a9b74b957ef4f337bc3c8b3df01b5",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "libgit2",
|
||||
"repo": "libgit2",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nix": {
|
||||
"inputs": {
|
||||
"flake-compat": "flake-compat",
|
||||
"libgit2": "libgit2",
|
||||
"nixpkgs": "nixpkgs",
|
||||
"nixpkgs-regression": "nixpkgs-regression"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1711102798,
|
||||
"narHash": "sha256-CXOIJr8byjolqG7eqCLa+Wfi7rah62VmLoqSXENaZnw=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nix",
|
||||
"rev": "a22328066416650471c3545b0b138669ea212ab4",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"repo": "nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1709083642,
|
||||
"narHash": "sha256-7kkJQd4rZ+vFrzWu8sTRtta5D1kBG0LSRYAfhtmMlSo=",
|
||||
"lastModified": 1710889954,
|
||||
"narHash": "sha256-Pr6F5Pmd7JnNEMHHmspZ0qVqIBVxyZ13ik1pJtm2QXk=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "b550fe4b4776908ac2a861124307045f8e717c8e",
|
||||
"rev": "7872526e9c5332274ea5932a0c3270d6e4724f3b",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "release-23.11",
|
||||
"ref": "nixpkgs-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
@@ -120,44 +51,10 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs-regression": {
|
||||
"locked": {
|
||||
"lastModified": 1643052045,
|
||||
"narHash": "sha256-uGJ0VXIhWKGXxkeNnq4TvV3CIOkUJ3PAoLZ3HMzNVMw=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "215d4d0fd80ca5163643b03a33fde804a29cc1e2",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "215d4d0fd80ca5163643b03a33fde804a29cc1e2",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs_2": {
|
||||
"locked": {
|
||||
"lastModified": 1710889954,
|
||||
"narHash": "sha256-Pr6F5Pmd7JnNEMHHmspZ0qVqIBVxyZ13ik1pJtm2QXk=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "7872526e9c5332274ea5932a0c3270d6e4724f3b",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixpkgs-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"flake-utils": "flake-utils",
|
||||
"lean4-mode": "lean4-mode",
|
||||
"nix": "nix",
|
||||
"nixpkgs": "nixpkgs_2",
|
||||
"nixpkgs": "nixpkgs",
|
||||
"nixpkgs-old": "nixpkgs-old"
|
||||
}
|
||||
},
|
||||
|
||||
61
flake.nix
61
flake.nix
@@ -1,38 +1,21 @@
|
||||
{
|
||||
description = "Lean interactive theorem prover";
|
||||
description = "Lean development flake. Not intended for end users.";
|
||||
|
||||
inputs.nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable";
|
||||
# old nixpkgs used for portable release with older glibc (2.27)
|
||||
inputs.nixpkgs-old.url = "github:NixOS/nixpkgs/nixos-19.03";
|
||||
inputs.nixpkgs-old.flake = false;
|
||||
inputs.flake-utils.url = "github:numtide/flake-utils";
|
||||
inputs.nix.url = "github:NixOS/nix";
|
||||
inputs.lean4-mode = {
|
||||
url = "github:leanprover/lean4-mode";
|
||||
flake = false;
|
||||
};
|
||||
# used *only* by `stage0-from-input` below
|
||||
#inputs.lean-stage0 = {
|
||||
# url = github:leanprover/lean4;
|
||||
# inputs.nixpkgs.follows = "nixpkgs";
|
||||
# inputs.flake-utils.follows = "flake-utils";
|
||||
# inputs.nix.follows = "nix";
|
||||
# inputs.lean4-mode.follows = "lean4-mode";
|
||||
#};
|
||||
|
||||
outputs = { self, nixpkgs, nixpkgs-old, flake-utils, nix, lean4-mode, ... }@inputs: flake-utils.lib.eachDefaultSystem (system:
|
||||
outputs = { self, nixpkgs, nixpkgs-old, flake-utils, ... }@inputs: flake-utils.lib.eachDefaultSystem (system:
|
||||
let
|
||||
pkgs = import nixpkgs {
|
||||
inherit system;
|
||||
# for `vscode-with-extensions`
|
||||
config.allowUnfree = true;
|
||||
};
|
||||
pkgs = import nixpkgs { inherit system; };
|
||||
# An old nixpkgs for creating releases with an old glibc
|
||||
pkgsDist-old = import nixpkgs-old { inherit system; };
|
||||
# An old nixpkgs for creating releases with an old glibc
|
||||
pkgsDist-old-aarch = import nixpkgs-old { localSystem.config = "aarch64-unknown-linux-gnu"; };
|
||||
|
||||
lean-packages = pkgs.callPackage (./nix/packages.nix) { src = ./.; inherit nix lean4-mode; };
|
||||
lean-packages = pkgs.callPackage (./nix/packages.nix) { src = ./.; };
|
||||
|
||||
devShellWithDist = pkgsDist: pkgs.mkShell.override {
|
||||
stdenv = pkgs.overrideCC pkgs.stdenv lean-packages.llvmPackages.clang;
|
||||
@@ -58,41 +41,15 @@
|
||||
GDB = pkgsDist.gdb;
|
||||
});
|
||||
in {
|
||||
packages = lean-packages // rec {
|
||||
debug = lean-packages.override { debug = true; };
|
||||
stage0debug = lean-packages.override { stage0debug = true; };
|
||||
asan = lean-packages.override { extraCMakeFlags = [ "-DLEAN_EXTRA_CXX_FLAGS=-fsanitize=address" "-DLEANC_EXTRA_FLAGS=-fsanitize=address" "-DSMALL_ALLOCATOR=OFF" "-DSYMBOLIC=OFF" ]; };
|
||||
asandebug = asan.override { debug = true; };
|
||||
tsan = lean-packages.override {
|
||||
extraCMakeFlags = [ "-DLEAN_EXTRA_CXX_FLAGS=-fsanitize=thread" "-DLEANC_EXTRA_FLAGS=-fsanitize=thread" "-DCOMPRESSED_OBJECT_HEADER=OFF" ];
|
||||
stage0 = (lean-packages.override {
|
||||
# Compressed headers currently trigger data race reports in tsan.
|
||||
# Turn them off for stage 0 as well so stage 1 can read its own stdlib.
|
||||
extraCMakeFlags = [ "-DCOMPRESSED_OBJECT_HEADER=OFF" ];
|
||||
}).stage1;
|
||||
};
|
||||
tsandebug = tsan.override { debug = true; };
|
||||
stage0-from-input = lean-packages.override {
|
||||
stage0 = pkgs.writeShellScriptBin "lean" ''
|
||||
exec ${inputs.lean-stage0.packages.${system}.lean}/bin/lean -Dinterpreter.prefer_native=false "$@"
|
||||
'';
|
||||
};
|
||||
inherit self;
|
||||
packages = {
|
||||
# to be removed when Nix CI is not needed anymore
|
||||
inherit (lean-packages) cacheRoots test update-stage0-commit ciShell;
|
||||
deprecated = lean-packages;
|
||||
};
|
||||
defaultPackage = lean-packages.lean-all;
|
||||
|
||||
# The default development shell for working on lean itself
|
||||
devShells.default = devShellWithDist pkgs;
|
||||
devShells.oldGlibc = devShellWithDist pkgsDist-old;
|
||||
devShells.oldGlibcAArch = devShellWithDist pkgsDist-old-aarch;
|
||||
|
||||
checks.lean = lean-packages.test;
|
||||
}) // rec {
|
||||
templates.pkg = {
|
||||
path = ./nix/templates/pkg;
|
||||
description = "A custom Lean package";
|
||||
};
|
||||
|
||||
defaultTemplate = templates.pkg;
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
stdenv, lib, cmake, gmp, git, gnumake, bash, buildLeanPackage, writeShellScriptBin, runCommand, symlinkJoin, lndir, perl, gnused, darwin, llvmPackages, linkFarmFromDrvs,
|
||||
... } @ args:
|
||||
with builtins;
|
||||
rec {
|
||||
lib.warn "The Nix-based build is deprecated" rec {
|
||||
inherit stdenv;
|
||||
sourceByRegex = p: rs: lib.sourceByRegex p (map (r: "(/src/)?${r}") rs);
|
||||
buildCMake = args: stdenv.mkDerivation ({
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
{ lean, lean-leanDeps ? lean, lean-final ? lean, leanc,
|
||||
stdenv, lib, coreutils, gnused, writeShellScriptBin, bash, lean-emacs, lean-vscode, nix, substituteAll, symlinkJoin, linkFarmFromDrvs,
|
||||
stdenv, lib, coreutils, gnused, writeShellScriptBin, bash, substituteAll, symlinkJoin, linkFarmFromDrvs,
|
||||
runCommand, darwin, mkShell, ... }:
|
||||
let lean-final' = lean-final; in
|
||||
lib.makeOverridable (
|
||||
@@ -197,19 +197,6 @@ with builtins; let
|
||||
then map (m: m.module) header.imports
|
||||
else abort "errors while parsing imports of ${mod}:\n${lib.concatStringsSep "\n" header.errors}";
|
||||
in mkMod mod (map (dep: if modDepsMap ? ${dep} then modCandidates.${dep} else externalModMap.${dep}) deps)) modDepsMap;
|
||||
makeEmacsWrapper = name: emacs: lean: writeShellScriptBin name ''
|
||||
${emacs} --eval "(progn (setq lean4-rootdir \"${lean}\"))" "$@"
|
||||
'';
|
||||
makeVSCodeWrapper = name: lean: writeShellScriptBin name ''
|
||||
PATH=${lean}/bin:$PATH ${lean-vscode}/bin/code "$@"
|
||||
'';
|
||||
printPaths = deps: writeShellScriptBin "print-paths" ''
|
||||
echo '${toJSON {
|
||||
oleanPath = [(depRoot "print-paths" deps)];
|
||||
srcPath = ["."] ++ map (dep: dep.src) allExternalDeps;
|
||||
loadDynlibPaths = map pathOfSharedLib (loadDynlibsOfDeps deps);
|
||||
}}'
|
||||
'';
|
||||
expandGlob = g:
|
||||
if typeOf g == "string" then [g]
|
||||
else if g.glob == "one" then [g.mod]
|
||||
@@ -257,48 +244,4 @@ in rec {
|
||||
-o $out/bin/${executableName} \
|
||||
${lib.concatStringsSep " " allLinkFlags}
|
||||
'') {};
|
||||
|
||||
lean-package = writeShellScriptBin "lean" ''
|
||||
LEAN_PATH=${modRoot}:$LEAN_PATH LEAN_SRC_PATH=$LEAN_SRC_PATH:${src} exec ${lean-final}/bin/lean "$@"
|
||||
'';
|
||||
emacs-package = makeEmacsWrapper "emacs-package" lean-package;
|
||||
vscode-package = makeVSCodeWrapper "vscode-package" lean-package;
|
||||
|
||||
link-ilean = writeShellScriptBin "link-ilean" ''
|
||||
dest=''${1:-.}
|
||||
mkdir -p $dest/build/lib
|
||||
ln -sf ${iTree}/* $dest/build/lib
|
||||
'';
|
||||
|
||||
makePrintPathsFor = deps: mods: printPaths deps // mapAttrs (_: mod: makePrintPathsFor (deps ++ [mod]) mods) mods;
|
||||
print-paths = makePrintPathsFor [] (mods' // externalModMap);
|
||||
# `lean` wrapper that dynamically runs Nix for the actual `lean` executable so the same editor can be
|
||||
# used for multiple projects/after upgrading the `lean` input/for editing both stage 1 and the tests
|
||||
lean-bin-dev = substituteAll {
|
||||
name = "lean";
|
||||
dir = "bin";
|
||||
src = ./lean-dev.in;
|
||||
isExecutable = true;
|
||||
srcRoot = fullSrc; # use root flake.nix in case of Lean repo
|
||||
inherit bash nix srcTarget srcArgs;
|
||||
};
|
||||
lake-dev = substituteAll {
|
||||
name = "lake";
|
||||
dir = "bin";
|
||||
src = ./lake-dev.in;
|
||||
isExecutable = true;
|
||||
srcRoot = fullSrc; # use root flake.nix in case of Lean repo
|
||||
inherit bash nix srcTarget srcArgs;
|
||||
};
|
||||
lean-dev = symlinkJoin { name = "lean-dev"; paths = [ lean-bin-dev lake-dev ]; };
|
||||
emacs-dev = makeEmacsWrapper "emacs-dev" "${lean-emacs}/bin/emacs" lean-dev;
|
||||
emacs-path-dev = makeEmacsWrapper "emacs-path-dev" "emacs" lean-dev;
|
||||
vscode-dev = makeVSCodeWrapper "vscode-dev" lean-dev;
|
||||
|
||||
devShell = mkShell {
|
||||
buildInputs = [ nix ];
|
||||
shellHook = ''
|
||||
export LEAN_SRC_PATH="${srcPath}"
|
||||
'';
|
||||
};
|
||||
})
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
{ src, pkgs, nix, ... } @ args:
|
||||
{ src, pkgs, ... } @ args:
|
||||
with pkgs;
|
||||
let
|
||||
nix-pinned = writeShellScriptBin "nix" ''
|
||||
${nix.packages.${system}.default}/bin/nix --experimental-features 'nix-command flakes' --extra-substituters https://lean4.cachix.org/ --option warn-dirty false "$@"
|
||||
'';
|
||||
# https://github.com/NixOS/nixpkgs/issues/130963
|
||||
llvmPackages = if stdenv.isDarwin then llvmPackages_11 else llvmPackages_15;
|
||||
cc = (ccacheWrapper.override rec {
|
||||
@@ -42,40 +39,9 @@ let
|
||||
inherit (lean) stdenv;
|
||||
lean = lean.stage1;
|
||||
inherit (lean.stage1) leanc;
|
||||
inherit lean-emacs lean-vscode;
|
||||
nix = nix-pinned;
|
||||
}));
|
||||
lean4-mode = emacsPackages.melpaBuild {
|
||||
pname = "lean4-mode";
|
||||
version = "1";
|
||||
commit = "1";
|
||||
src = args.lean4-mode;
|
||||
packageRequires = with pkgs.emacsPackages.melpaPackages; [ dash f flycheck magit-section lsp-mode s ];
|
||||
recipe = pkgs.writeText "recipe" ''
|
||||
(lean4-mode
|
||||
:repo "leanprover/lean4-mode"
|
||||
:fetcher github
|
||||
:files ("*.el" "data"))
|
||||
'';
|
||||
};
|
||||
lean-emacs = emacsWithPackages [ lean4-mode ];
|
||||
# updating might be nicer by building from source from a flake input, but this is good enough for now
|
||||
vscode-lean4 = vscode-utils.extensionFromVscodeMarketplace {
|
||||
name = "lean4";
|
||||
publisher = "leanprover";
|
||||
version = "0.0.63";
|
||||
sha256 = "sha256-kjEex7L0F2P4pMdXi4NIZ1y59ywJVubqDqsoYagZNkI=";
|
||||
};
|
||||
lean-vscode = vscode-with-extensions.override {
|
||||
vscodeExtensions = [ vscode-lean4 ];
|
||||
};
|
||||
in {
|
||||
inherit cc lean4-mode buildLeanPackage llvmPackages vscode-lean4;
|
||||
lean = lean.stage1;
|
||||
stage0print-paths = lean.stage1.Lean.print-paths;
|
||||
HEAD-as-stage0 = (lean.stage1.Lean.overrideArgs { srcTarget = "..#stage0-from-input.stage0"; srcArgs = "(--override-input lean-stage0 ..\?rev=$(git rev-parse HEAD) -- -Dinterpreter.prefer_native=false \"$@\")"; });
|
||||
HEAD-as-stage1 = (lean.stage1.Lean.overrideArgs { srcTarget = "..\?rev=$(git rev-parse HEAD)#stage0"; });
|
||||
nix = nix-pinned;
|
||||
inherit cc buildLeanPackage llvmPackages;
|
||||
nixpkgs = pkgs;
|
||||
ciShell = writeShellScriptBin "ciShell" ''
|
||||
set -o pipefail
|
||||
@@ -83,5 +49,4 @@ in {
|
||||
# prefix lines with cumulative and individual execution time
|
||||
"$@" |& ts -i "(%.S)]" | ts -s "[%M:%S"
|
||||
'';
|
||||
vscode = lean-vscode;
|
||||
} // lean.stage1.Lean // lean.stage1 // lean
|
||||
} // lean.stage1
|
||||
|
||||
@@ -7,6 +7,7 @@ prelude
|
||||
import Init.Data.Nat.MinMax
|
||||
import Init.Data.Nat.Lemmas
|
||||
import Init.Data.List.Monadic
|
||||
import Init.Data.List.Nat.Range
|
||||
import Init.Data.Fin.Basic
|
||||
import Init.Data.Array.Mem
|
||||
import Init.TacticsExtra
|
||||
@@ -336,6 +337,10 @@ theorem not_mem_nil (a : α) : ¬ a ∈ #[] := nofun
|
||||
|
||||
/-- # get lemmas -/
|
||||
|
||||
theorem lt_of_getElem {x : α} {a : Array α} {idx : Nat} {hidx : idx < a.size} (_ : a[idx] = x) :
|
||||
idx < a.size :=
|
||||
hidx
|
||||
|
||||
theorem getElem?_mem {l : Array α} {i : Fin l.size} : l[i] ∈ l := by
|
||||
erw [Array.mem_def, getElem_eq_data_getElem]
|
||||
apply List.get_mem
|
||||
@@ -505,6 +510,13 @@ theorem size_eq_length_data (as : Array α) : as.size = as.data.length := rfl
|
||||
simp only [mkEmpty_eq, size_push] at *
|
||||
omega
|
||||
|
||||
@[simp] theorem data_range (n : Nat) : (range n).data = List.range n := by
|
||||
induction n <;> simp_all [range, Nat.fold, flip, List.range_succ]
|
||||
|
||||
@[simp]
|
||||
theorem getElem_range {n : Nat} {x : Nat} (h : x < (Array.range n).size) : (Array.range n)[x] = x := by
|
||||
simp [getElem_eq_data_getElem]
|
||||
|
||||
set_option linter.deprecated false in
|
||||
@[simp] theorem reverse_data (a : Array α) : a.reverse.data = a.data.reverse := by
|
||||
let rec go (as : Array α) (i j hj)
|
||||
@@ -707,13 +719,22 @@ theorem mapIdx_spec (as : Array α) (f : Fin as.size → α → β)
|
||||
unfold modify modifyM Id.run
|
||||
split <;> simp
|
||||
|
||||
theorem get_modify {arr : Array α} {x i} (h : i < arr.size) :
|
||||
(arr.modify x f).get ⟨i, by simp [h]⟩ =
|
||||
if x = i then f (arr.get ⟨i, h⟩) else arr.get ⟨i, h⟩ := by
|
||||
simp [modify, modifyM, Id.run]; split
|
||||
· simp [get_set _ _ _ h]; split <;> simp [*]
|
||||
theorem getElem_modify {as : Array α} {x i} (h : i < as.size) :
|
||||
(as.modify x f)[i]'(by simp [h]) = if x = i then f as[i] else as[i] := by
|
||||
simp only [modify, modifyM, get_eq_getElem, Id.run, Id.pure_eq]
|
||||
split
|
||||
· simp only [Id.bind_eq, get_set _ _ _ h]; split <;> simp [*]
|
||||
· rw [if_neg (mt (by rintro rfl; exact h) ‹_›)]
|
||||
|
||||
theorem getElem_modify_self {as : Array α} {i : Nat} (h : i < as.size) (f : α → α) :
|
||||
(as.modify i f)[i]'(by simp [h]) = f as[i] := by
|
||||
simp [getElem_modify h]
|
||||
|
||||
theorem getElem_modify_of_ne {as : Array α} {i : Nat} (hj : j < as.size)
|
||||
(f : α → α) (h : i ≠ j) :
|
||||
(as.modify i f)[j]'(by rwa [size_modify]) = as[j] := by
|
||||
simp [getElem_modify hj, h]
|
||||
|
||||
/-! ### filter -/
|
||||
|
||||
@[simp] theorem filter_data (p : α → Bool) (l : Array α) :
|
||||
|
||||
@@ -20,6 +20,8 @@ We define many of the bitvector operations from the
|
||||
of SMT-LIBv2.
|
||||
-/
|
||||
|
||||
set_option linter.missingDocs true
|
||||
|
||||
/--
|
||||
A bitvector of the specified width.
|
||||
|
||||
@@ -34,14 +36,14 @@ structure BitVec (w : Nat) where
|
||||
O(1), because we use `Fin` as the internal representation of a bitvector. -/
|
||||
toFin : Fin (2^w)
|
||||
|
||||
@[deprecated (since := "2024-04-12")]
|
||||
protected abbrev Std.BitVec := _root_.BitVec
|
||||
|
||||
/--
|
||||
Bitvectors have decidable equality. This should be used via the instance `DecidableEq (BitVec n)`.
|
||||
-/
|
||||
-- We manually derive the `DecidableEq` instances for `BitVec` because
|
||||
-- we want to have builtin support for bit-vector literals, and we
|
||||
-- need a name for this function to implement `canUnfoldAtMatcher` at `WHNF.lean`.
|
||||
def BitVec.decEq (a b : BitVec n) : Decidable (a = b) :=
|
||||
match a, b with
|
||||
def BitVec.decEq (x y : BitVec n) : Decidable (x = y) :=
|
||||
match x, y with
|
||||
| ⟨n⟩, ⟨m⟩ =>
|
||||
if h : n = m then
|
||||
isTrue (h ▸ rfl)
|
||||
@@ -67,9 +69,9 @@ protected def ofNat (n : Nat) (i : Nat) : BitVec n where
|
||||
instance instOfNat : OfNat (BitVec n) i where ofNat := .ofNat n i
|
||||
instance natCastInst : NatCast (BitVec w) := ⟨BitVec.ofNat w⟩
|
||||
|
||||
/-- Given a bitvector `a`, return the underlying `Nat`. This is O(1) because `BitVec` is a
|
||||
/-- Given a bitvector `x`, return the underlying `Nat`. This is O(1) because `BitVec` is a
|
||||
(zero-cost) wrapper around a `Nat`. -/
|
||||
protected def toNat (a : BitVec n) : Nat := a.toFin.val
|
||||
protected def toNat (x : BitVec n) : Nat := x.toFin.val
|
||||
|
||||
/-- Return the bound in terms of toNat. -/
|
||||
theorem isLt (x : BitVec w) : x.toNat < 2^w := x.toFin.isLt
|
||||
@@ -121,18 +123,18 @@ section getXsb
|
||||
@[inline] def getMsb (x : BitVec w) (i : Nat) : Bool := i < w && getLsb x (w-1-i)
|
||||
|
||||
/-- Return most-significant bit in bitvector. -/
|
||||
@[inline] protected def msb (a : BitVec n) : Bool := getMsb a 0
|
||||
@[inline] protected def msb (x : BitVec n) : Bool := getMsb x 0
|
||||
|
||||
end getXsb
|
||||
|
||||
section Int
|
||||
|
||||
/-- Interpret the bitvector as an integer stored in two's complement form. -/
|
||||
protected def toInt (a : BitVec n) : Int :=
|
||||
if 2 * a.toNat < 2^n then
|
||||
a.toNat
|
||||
protected def toInt (x : BitVec n) : Int :=
|
||||
if 2 * x.toNat < 2^n then
|
||||
x.toNat
|
||||
else
|
||||
(a.toNat : Int) - (2^n : Nat)
|
||||
(x.toNat : Int) - (2^n : Nat)
|
||||
|
||||
/-- The `BitVec` with value `(2^n + (i mod 2^n)) mod 2^n`. -/
|
||||
protected def ofInt (n : Nat) (i : Int) : BitVec n := .ofNatLt (i % (Int.ofNat (2^n))).toNat (by
|
||||
@@ -213,7 +215,7 @@ instance : Neg (BitVec n) := ⟨.neg⟩
|
||||
/--
|
||||
Return the absolute value of a signed bitvector.
|
||||
-/
|
||||
protected def abs (s : BitVec n) : BitVec n := if s.msb then .neg s else s
|
||||
protected def abs (x : BitVec n) : BitVec n := if x.msb then .neg x else x
|
||||
|
||||
/--
|
||||
Multiplication for bit vectors. This can be interpreted as either signed or unsigned negation
|
||||
@@ -260,12 +262,12 @@ sdiv 5#4 -2 = -2#4
|
||||
sdiv (-7#4) (-2) = 3#4
|
||||
```
|
||||
-/
|
||||
def sdiv (s t : BitVec n) : BitVec n :=
|
||||
match s.msb, t.msb with
|
||||
| false, false => udiv s t
|
||||
| false, true => .neg (udiv s (.neg t))
|
||||
| true, false => .neg (udiv (.neg s) t)
|
||||
| true, true => udiv (.neg s) (.neg t)
|
||||
def sdiv (x y : BitVec n) : BitVec n :=
|
||||
match x.msb, y.msb with
|
||||
| false, false => udiv x y
|
||||
| false, true => .neg (udiv x (.neg y))
|
||||
| true, false => .neg (udiv (.neg x) y)
|
||||
| true, true => udiv (.neg x) (.neg y)
|
||||
|
||||
/--
|
||||
Signed division for bit vectors using SMTLIB rules for division by zero.
|
||||
@@ -274,40 +276,40 @@ Specifically, `smtSDiv x 0 = if x >= 0 then -1 else 1`
|
||||
|
||||
SMT-Lib name: `bvsdiv`.
|
||||
-/
|
||||
def smtSDiv (s t : BitVec n) : BitVec n :=
|
||||
match s.msb, t.msb with
|
||||
| false, false => smtUDiv s t
|
||||
| false, true => .neg (smtUDiv s (.neg t))
|
||||
| true, false => .neg (smtUDiv (.neg s) t)
|
||||
| true, true => smtUDiv (.neg s) (.neg t)
|
||||
def smtSDiv (x y : BitVec n) : BitVec n :=
|
||||
match x.msb, y.msb with
|
||||
| false, false => smtUDiv x y
|
||||
| false, true => .neg (smtUDiv x (.neg y))
|
||||
| true, false => .neg (smtUDiv (.neg x) y)
|
||||
| true, true => smtUDiv (.neg x) (.neg y)
|
||||
|
||||
/--
|
||||
Remainder for signed division rounding to zero.
|
||||
|
||||
SMT_Lib name: `bvsrem`.
|
||||
-/
|
||||
def srem (s t : BitVec n) : BitVec n :=
|
||||
match s.msb, t.msb with
|
||||
| false, false => umod s t
|
||||
| false, true => umod s (.neg t)
|
||||
| true, false => .neg (umod (.neg s) t)
|
||||
| true, true => .neg (umod (.neg s) (.neg t))
|
||||
def srem (x y : BitVec n) : BitVec n :=
|
||||
match x.msb, y.msb with
|
||||
| false, false => umod x y
|
||||
| false, true => umod x (.neg y)
|
||||
| true, false => .neg (umod (.neg x) y)
|
||||
| true, true => .neg (umod (.neg x) (.neg y))
|
||||
|
||||
/--
|
||||
Remainder for signed division rounded to negative infinity.
|
||||
|
||||
SMT_Lib name: `bvsmod`.
|
||||
-/
|
||||
def smod (s t : BitVec m) : BitVec m :=
|
||||
match s.msb, t.msb with
|
||||
| false, false => umod s t
|
||||
def smod (x y : BitVec m) : BitVec m :=
|
||||
match x.msb, y.msb with
|
||||
| false, false => umod x y
|
||||
| false, true =>
|
||||
let u := umod s (.neg t)
|
||||
(if u = .zero m then u else .add u t)
|
||||
let u := umod x (.neg y)
|
||||
(if u = .zero m then u else .add u y)
|
||||
| true, false =>
|
||||
let u := umod (.neg s) t
|
||||
(if u = .zero m then u else .sub t u)
|
||||
| true, true => .neg (umod (.neg s) (.neg t))
|
||||
let u := umod (.neg x) y
|
||||
(if u = .zero m then u else .sub y u)
|
||||
| true, true => .neg (umod (.neg x) (.neg y))
|
||||
|
||||
end arithmetic
|
||||
|
||||
@@ -371,8 +373,8 @@ end relations
|
||||
|
||||
section cast
|
||||
|
||||
/-- `cast eq i` embeds `i` into an equal `BitVec` type. -/
|
||||
@[inline] def cast (eq : n = m) (i : BitVec n) : BitVec m := .ofNatLt i.toNat (eq ▸ i.isLt)
|
||||
/-- `cast eq x` embeds `x` into an equal `BitVec` type. -/
|
||||
@[inline] def cast (eq : n = m) (x : BitVec n) : BitVec m := .ofNatLt x.toNat (eq ▸ x.isLt)
|
||||
|
||||
@[simp] theorem cast_ofNat {n m : Nat} (h : n = m) (x : Nat) :
|
||||
cast h (BitVec.ofNat n x) = BitVec.ofNat m x := by
|
||||
@@ -389,7 +391,7 @@ Extraction of bits `start` to `start + len - 1` from a bit vector of size `n` to
|
||||
new bitvector of size `len`. If `start + len > n`, then the vector will be zero-padded in the
|
||||
high bits.
|
||||
-/
|
||||
def extractLsb' (start len : Nat) (a : BitVec n) : BitVec len := .ofNat _ (a.toNat >>> start)
|
||||
def extractLsb' (start len : Nat) (x : BitVec n) : BitVec len := .ofNat _ (x.toNat >>> start)
|
||||
|
||||
/--
|
||||
Extraction of bits `hi` (inclusive) down to `lo` (inclusive) from a bit vector of size `n` to
|
||||
@@ -397,12 +399,12 @@ yield a new bitvector of size `hi - lo + 1`.
|
||||
|
||||
SMT-Lib name: `extract`.
|
||||
-/
|
||||
def extractLsb (hi lo : Nat) (a : BitVec n) : BitVec (hi - lo + 1) := extractLsb' lo _ a
|
||||
def extractLsb (hi lo : Nat) (x : BitVec n) : BitVec (hi - lo + 1) := extractLsb' lo _ x
|
||||
|
||||
/--
|
||||
A version of `zeroExtend` that requires a proof, but is a noop.
|
||||
-/
|
||||
def zeroExtend' {n w : Nat} (le : n ≤ w) (x : BitVec n) : BitVec w :=
|
||||
def zeroExtend' {n w : Nat} (le : n ≤ w) (x : BitVec n) : BitVec w :=
|
||||
x.toNat#'(by
|
||||
apply Nat.lt_of_lt_of_le x.isLt
|
||||
exact Nat.pow_le_pow_of_le_right (by trivial) le)
|
||||
@@ -411,8 +413,8 @@ def zeroExtend' {n w : Nat} (le : n ≤ w) (x : BitVec n) : BitVec w :=
|
||||
`shiftLeftZeroExtend x n` returns `zeroExtend (w+n) x <<< n` without
|
||||
needing to compute `x % 2^(2+n)`.
|
||||
-/
|
||||
def shiftLeftZeroExtend (msbs : BitVec w) (m : Nat) : BitVec (w+m) :=
|
||||
let shiftLeftLt {x : Nat} (p : x < 2^w) (m : Nat) : x <<< m < 2^(w+m) := by
|
||||
def shiftLeftZeroExtend (msbs : BitVec w) (m : Nat) : BitVec (w + m) :=
|
||||
let shiftLeftLt {x : Nat} (p : x < 2^w) (m : Nat) : x <<< m < 2^(w + m) := by
|
||||
simp [Nat.shiftLeft_eq, Nat.pow_add]
|
||||
apply Nat.mul_lt_mul_of_pos_right p
|
||||
exact (Nat.two_pow_pos m)
|
||||
@@ -500,24 +502,24 @@ instance : Complement (BitVec w) := ⟨.not⟩
|
||||
|
||||
/--
|
||||
Left shift for bit vectors. The low bits are filled with zeros. As a numeric operation, this is
|
||||
equivalent to `a * 2^s`, modulo `2^n`.
|
||||
equivalent to `x * 2^s`, modulo `2^n`.
|
||||
|
||||
SMT-Lib name: `bvshl` except this operator uses a `Nat` shift value.
|
||||
-/
|
||||
protected def shiftLeft (a : BitVec n) (s : Nat) : BitVec n := BitVec.ofNat n (a.toNat <<< s)
|
||||
protected def shiftLeft (x : BitVec n) (s : Nat) : BitVec n := BitVec.ofNat n (x.toNat <<< s)
|
||||
instance : HShiftLeft (BitVec w) Nat (BitVec w) := ⟨.shiftLeft⟩
|
||||
|
||||
/--
|
||||
(Logical) right shift for bit vectors. The high bits are filled with zeros.
|
||||
As a numeric operation, this is equivalent to `a / 2^s`, rounding down.
|
||||
As a numeric operation, this is equivalent to `x / 2^s`, rounding down.
|
||||
|
||||
SMT-Lib name: `bvlshr` except this operator uses a `Nat` shift value.
|
||||
-/
|
||||
def ushiftRight (a : BitVec n) (s : Nat) : BitVec n :=
|
||||
(a.toNat >>> s)#'(by
|
||||
let ⟨a, lt⟩ := a
|
||||
def ushiftRight (x : BitVec n) (s : Nat) : BitVec n :=
|
||||
(x.toNat >>> s)#'(by
|
||||
let ⟨x, lt⟩ := x
|
||||
simp only [BitVec.toNat, Nat.shiftRight_eq_div_pow, Nat.div_lt_iff_lt_mul (Nat.two_pow_pos s)]
|
||||
rw [←Nat.mul_one a]
|
||||
rw [←Nat.mul_one x]
|
||||
exact Nat.mul_lt_mul_of_lt_of_le' lt (Nat.two_pow_pos s) (Nat.le_refl 1))
|
||||
|
||||
instance : HShiftRight (BitVec w) Nat (BitVec w) := ⟨.ushiftRight⟩
|
||||
@@ -525,15 +527,24 @@ instance : HShiftRight (BitVec w) Nat (BitVec w) := ⟨.ushiftRight⟩
|
||||
/--
|
||||
Arithmetic right shift for bit vectors. The high bits are filled with the
|
||||
most-significant bit.
|
||||
As a numeric operation, this is equivalent to `a.toInt >>> s`.
|
||||
As a numeric operation, this is equivalent to `x.toInt >>> s`.
|
||||
|
||||
SMT-Lib name: `bvashr` except this operator uses a `Nat` shift value.
|
||||
-/
|
||||
def sshiftRight (a : BitVec n) (s : Nat) : BitVec n := .ofInt n (a.toInt >>> s)
|
||||
def sshiftRight (x : BitVec n) (s : Nat) : BitVec n := .ofInt n (x.toInt >>> s)
|
||||
|
||||
instance {n} : HShiftLeft (BitVec m) (BitVec n) (BitVec m) := ⟨fun x y => x <<< y.toNat⟩
|
||||
instance {n} : HShiftRight (BitVec m) (BitVec n) (BitVec m) := ⟨fun x y => x >>> y.toNat⟩
|
||||
|
||||
/--
|
||||
Arithmetic right shift for bit vectors. The high bits are filled with the
|
||||
most-significant bit.
|
||||
As a numeric operation, this is equivalent to `a.toInt >>> s.toNat`.
|
||||
|
||||
SMT-Lib name: `bvashr`.
|
||||
-/
|
||||
def sshiftRight' (a : BitVec n) (s : BitVec m) : BitVec n := a.sshiftRight s.toNat
|
||||
|
||||
/-- Auxiliary function for `rotateLeft`, which does not take into account the case where
|
||||
the rotation amount is greater than the bitvector width. -/
|
||||
def rotateLeftAux (x : BitVec w) (n : Nat) : BitVec w :=
|
||||
@@ -583,11 +594,9 @@ instance : HAppend (BitVec w) (BitVec v) (BitVec (w + v)) := ⟨.append⟩
|
||||
-- TODO: write this using multiplication
|
||||
/-- `replicate i x` concatenates `i` copies of `x` into a new vector of length `w*i`. -/
|
||||
def replicate : (i : Nat) → BitVec w → BitVec (w*i)
|
||||
| 0, _ => 0
|
||||
| 0, _ => 0#0
|
||||
| n+1, x =>
|
||||
have hEq : w + w*n = w*(n + 1) := by
|
||||
rw [Nat.mul_add, Nat.add_comm, Nat.mul_one]
|
||||
hEq ▸ (x ++ replicate n x)
|
||||
(x ++ replicate n x).cast (by rw [Nat.mul_succ]; omega)
|
||||
|
||||
/-!
|
||||
### Cons and Concat
|
||||
|
||||
@@ -28,6 +28,8 @@ https://github.com/mhk119/lean-smt/blob/bitvec/Smt/Data/Bitwise.lean.
|
||||
|
||||
-/
|
||||
|
||||
set_option linter.missingDocs true
|
||||
|
||||
open Nat Bool
|
||||
|
||||
namespace Bool
|
||||
@@ -287,18 +289,18 @@ theorem sle_eq_carry (x y : BitVec w) :
|
||||
A recurrence that describes multiplication as repeated addition.
|
||||
Is useful for bitblasting multiplication.
|
||||
-/
|
||||
def mulRec (l r : BitVec w) (s : Nat) : BitVec w :=
|
||||
let cur := if r.getLsb s then (l <<< s) else 0
|
||||
def mulRec (x y : BitVec w) (s : Nat) : BitVec w :=
|
||||
let cur := if y.getLsb s then (x <<< s) else 0
|
||||
match s with
|
||||
| 0 => cur
|
||||
| s + 1 => mulRec l r s + cur
|
||||
| s + 1 => mulRec x y s + cur
|
||||
|
||||
theorem mulRec_zero_eq (l r : BitVec w) :
|
||||
mulRec l r 0 = if r.getLsb 0 then l else 0 := by
|
||||
theorem mulRec_zero_eq (x y : BitVec w) :
|
||||
mulRec x y 0 = if y.getLsb 0 then x else 0 := by
|
||||
simp [mulRec]
|
||||
|
||||
theorem mulRec_succ_eq (l r : BitVec w) (s : Nat) :
|
||||
mulRec l r (s + 1) = mulRec l r s + if r.getLsb (s + 1) then (l <<< (s + 1)) else 0 := rfl
|
||||
theorem mulRec_succ_eq (x y : BitVec w) (s : Nat) :
|
||||
mulRec x y (s + 1) = mulRec x y s + if y.getLsb (s + 1) then (x <<< (s + 1)) else 0 := rfl
|
||||
|
||||
/--
|
||||
Recurrence lemma: truncating to `i+1` bits and then zero extending to `w`
|
||||
@@ -324,29 +326,29 @@ theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow (x : BitVec w
|
||||
by_cases hi : x.getLsb i <;> simp [hi] <;> omega
|
||||
|
||||
/--
|
||||
Recurrence lemma: multiplying `l` with the first `s` bits of `r` is the
|
||||
same as truncating `r` to `s` bits, then zero extending to the original length,
|
||||
Recurrence lemma: multiplying `x` with the first `s` bits of `y` is the
|
||||
same as truncating `y` to `s` bits, then zero extending to the original length,
|
||||
and performing the multplication. -/
|
||||
theorem mulRec_eq_mul_signExtend_truncate (l r : BitVec w) (s : Nat) :
|
||||
mulRec l r s = l * ((r.truncate (s + 1)).zeroExtend w) := by
|
||||
theorem mulRec_eq_mul_signExtend_truncate (x y : BitVec w) (s : Nat) :
|
||||
mulRec x y s = x * ((y.truncate (s + 1)).zeroExtend w) := by
|
||||
induction s
|
||||
case zero =>
|
||||
simp only [mulRec_zero_eq, ofNat_eq_ofNat, Nat.reduceAdd]
|
||||
by_cases r.getLsb 0
|
||||
case pos hr =>
|
||||
simp only [hr, ↓reduceIte, truncate, zeroExtend_one_eq_ofBool_getLsb_zero,
|
||||
hr, ofBool_true, ofNat_eq_ofNat]
|
||||
by_cases y.getLsb 0
|
||||
case pos hy =>
|
||||
simp only [hy, ↓reduceIte, truncate, zeroExtend_one_eq_ofBool_getLsb_zero,
|
||||
ofBool_true, ofNat_eq_ofNat]
|
||||
rw [zeroExtend_ofNat_one_eq_ofNat_one_of_lt (by omega)]
|
||||
simp
|
||||
case neg hr =>
|
||||
simp [hr, zeroExtend_one_eq_ofBool_getLsb_zero]
|
||||
case neg hy =>
|
||||
simp [hy, zeroExtend_one_eq_ofBool_getLsb_zero]
|
||||
case succ s' hs =>
|
||||
rw [mulRec_succ_eq, hs]
|
||||
have heq :
|
||||
(if r.getLsb (s' + 1) = true then l <<< (s' + 1) else 0) =
|
||||
(l * (r &&& (BitVec.twoPow w (s' + 1)))) := by
|
||||
(if y.getLsb (s' + 1) = true then x <<< (s' + 1) else 0) =
|
||||
(x * (y &&& (BitVec.twoPow w (s' + 1)))) := by
|
||||
simp only [ofNat_eq_ofNat, and_twoPow]
|
||||
by_cases hr : r.getLsb (s' + 1) <;> simp [hr]
|
||||
by_cases hy : y.getLsb (s' + 1) <;> simp [hy]
|
||||
rw [heq, ← BitVec.mul_add, ← zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow]
|
||||
|
||||
theorem getLsb_mul (x y : BitVec w) (i : Nat) :
|
||||
@@ -427,6 +429,67 @@ theorem shiftLeft_eq_shiftLeftRec (x : BitVec w₁) (y : BitVec w₂) :
|
||||
· simp [of_length_zero]
|
||||
· simp [shiftLeftRec_eq]
|
||||
|
||||
/- ### Arithmetic shift right (sshiftRight) recurrence -/
|
||||
|
||||
/--
|
||||
`sshiftRightRec x y n` shifts `x` arithmetically/signed to the right by the first `n` bits of `y`.
|
||||
The theorem `sshiftRight_eq_sshiftRightRec` proves the equivalence of `(x.sshiftRight y)` and `sshiftRightRec`.
|
||||
Together with equations `sshiftRightRec_zero`, `sshiftRightRec_succ`,
|
||||
this allows us to unfold `sshiftRight` into a circuit for bitblasting.
|
||||
-/
|
||||
def sshiftRightRec (x : BitVec w₁) (y : BitVec w₂) (n : Nat) : BitVec w₁ :=
|
||||
let shiftAmt := (y &&& (twoPow w₂ n))
|
||||
match n with
|
||||
| 0 => x.sshiftRight' shiftAmt
|
||||
| n + 1 => (sshiftRightRec x y n).sshiftRight' shiftAmt
|
||||
|
||||
@[simp]
|
||||
theorem sshiftRightRec_zero_eq (x : BitVec w₁) (y : BitVec w₂) :
|
||||
sshiftRightRec x y 0 = x.sshiftRight' (y &&& 1#w₂) := by
|
||||
simp only [sshiftRightRec, twoPow_zero]
|
||||
|
||||
@[simp]
|
||||
theorem sshiftRightRec_succ_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) :
|
||||
sshiftRightRec x y (n + 1) = (sshiftRightRec x y n).sshiftRight' (y &&& twoPow w₂ (n + 1)) := by
|
||||
simp [sshiftRightRec]
|
||||
|
||||
/--
|
||||
If `y &&& z = 0`, `x.sshiftRight (y ||| z) = (x.sshiftRight y).sshiftRight z`.
|
||||
This follows as `y &&& z = 0` implies `y ||| z = y + z`,
|
||||
and thus `x.sshiftRight (y ||| z) = x.sshiftRight (y + z) = (x.sshiftRight y).sshiftRight z`.
|
||||
-/
|
||||
theorem sshiftRight'_or_of_and_eq_zero {x : BitVec w₁} {y z : BitVec w₂}
|
||||
(h : y &&& z = 0#w₂) :
|
||||
x.sshiftRight' (y ||| z) = (x.sshiftRight' y).sshiftRight' z := by
|
||||
simp [sshiftRight', ← add_eq_or_of_and_eq_zero _ _ h,
|
||||
toNat_add_of_and_eq_zero h, sshiftRight_add]
|
||||
|
||||
theorem sshiftRightRec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) :
|
||||
sshiftRightRec x y n = x.sshiftRight' ((y.truncate (n + 1)).zeroExtend w₂) := by
|
||||
induction n generalizing x y
|
||||
case zero =>
|
||||
ext i
|
||||
simp [twoPow_zero, Nat.reduceAdd, and_one_eq_zeroExtend_ofBool_getLsb, truncate_one]
|
||||
case succ n ih =>
|
||||
simp only [sshiftRightRec_succ_eq, and_twoPow, ih]
|
||||
by_cases h : y.getLsb (n + 1)
|
||||
· rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true h,
|
||||
sshiftRight'_or_of_and_eq_zero (by simp), h]
|
||||
simp
|
||||
· rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false (i := n + 1)
|
||||
(by simp [h])]
|
||||
simp [h]
|
||||
|
||||
/--
|
||||
Show that `x.sshiftRight y` can be written in terms of `sshiftRightRec`.
|
||||
This can be unfolded in terms of `sshiftRightRec_zero_eq`, `sshiftRightRec_succ_eq` for bitblasting.
|
||||
-/
|
||||
theorem sshiftRight_eq_sshiftRightRec (x : BitVec w₁) (y : BitVec w₂) :
|
||||
(x.sshiftRight' y).getLsb i = (sshiftRightRec x y (w₂ - 1)).getLsb i := by
|
||||
rcases w₂ with rfl | w₂
|
||||
· simp [of_length_zero]
|
||||
· simp [sshiftRightRec_eq]
|
||||
|
||||
/- ### Logical shift right (ushiftRight) recurrence for bitblasting -/
|
||||
|
||||
/--
|
||||
|
||||
@@ -8,6 +8,8 @@ import Init.Data.BitVec.Lemmas
|
||||
import Init.Data.Nat.Lemmas
|
||||
import Init.Data.Fin.Iterate
|
||||
|
||||
set_option linter.missingDocs true
|
||||
|
||||
namespace BitVec
|
||||
|
||||
/--
|
||||
|
||||
@@ -12,6 +12,8 @@ import Init.Data.Nat.Lemmas
|
||||
import Init.Data.Nat.Mod
|
||||
import Init.Data.Int.Bitwise.Lemmas
|
||||
|
||||
set_option linter.missingDocs true
|
||||
|
||||
namespace BitVec
|
||||
|
||||
/--
|
||||
@@ -21,7 +23,7 @@ theorem ofFin_eq_ofNat : @BitVec.ofFin w (Fin.mk x lt) = BitVec.ofNat w x := by
|
||||
simp only [BitVec.ofNat, Fin.ofNat', lt, Nat.mod_eq_of_lt]
|
||||
|
||||
/-- Prove equality of bitvectors in terms of nat operations. -/
|
||||
theorem eq_of_toNat_eq {n} : ∀ {i j : BitVec n}, i.toNat = j.toNat → i = j
|
||||
theorem eq_of_toNat_eq {n} : ∀ {x y : BitVec n}, x.toNat = y.toNat → x = y
|
||||
| ⟨_, _⟩, ⟨_, _⟩, rfl => rfl
|
||||
|
||||
@[simp] theorem val_toFin (x : BitVec w) : x.toFin.val = x.toNat := rfl
|
||||
@@ -226,12 +228,12 @@ theorem toNat_ge_of_msb_true {x : BitVec n} (p : BitVec.msb x = true) : x.toNat
|
||||
/-! ### toInt/ofInt -/
|
||||
|
||||
/-- Prove equality of bitvectors in terms of nat operations. -/
|
||||
theorem toInt_eq_toNat_cond (i : BitVec n) :
|
||||
i.toInt =
|
||||
if 2*i.toNat < 2^n then
|
||||
(i.toNat : Int)
|
||||
theorem toInt_eq_toNat_cond (x : BitVec n) :
|
||||
x.toInt =
|
||||
if 2*x.toNat < 2^n then
|
||||
(x.toNat : Int)
|
||||
else
|
||||
(i.toNat : Int) - (2^n : Nat) :=
|
||||
(x.toNat : Int) - (2^n : Nat) :=
|
||||
rfl
|
||||
|
||||
theorem msb_eq_false_iff_two_mul_lt (x : BitVec w) : x.msb = false ↔ 2 * x.toNat < 2^w := by
|
||||
@@ -258,13 +260,13 @@ theorem toInt_eq_toNat_bmod (x : BitVec n) : x.toInt = Int.bmod x.toNat (2^n) :=
|
||||
omega
|
||||
|
||||
/-- Prove equality of bitvectors in terms of nat operations. -/
|
||||
theorem eq_of_toInt_eq {i j : BitVec n} : i.toInt = j.toInt → i = j := by
|
||||
theorem eq_of_toInt_eq {x y : BitVec n} : x.toInt = y.toInt → x = y := by
|
||||
intro eq
|
||||
simp [toInt_eq_toNat_cond] at eq
|
||||
apply eq_of_toNat_eq
|
||||
revert eq
|
||||
have _ilt := i.isLt
|
||||
have _jlt := j.isLt
|
||||
have _xlt := x.isLt
|
||||
have _ylt := y.isLt
|
||||
split <;> split <;> omega
|
||||
|
||||
theorem toInt_inj (x y : BitVec n) : x.toInt = y.toInt ↔ x = y :=
|
||||
@@ -731,6 +733,21 @@ theorem getLsb_shiftLeft' {x : BitVec w₁} {y : BitVec w₂} {i : Nat} :
|
||||
getLsb (x >>> i) j = getLsb x (i+j) := by
|
||||
unfold getLsb ; simp
|
||||
|
||||
theorem ushiftRight_xor_distrib (x y : BitVec w) (n : Nat) :
|
||||
(x ^^^ y) >>> n = (x >>> n) ^^^ (y >>> n) := by
|
||||
ext
|
||||
simp
|
||||
|
||||
theorem ushiftRight_and_distrib (x y : BitVec w) (n : Nat) :
|
||||
(x &&& y) >>> n = (x >>> n) &&& (y >>> n) := by
|
||||
ext
|
||||
simp
|
||||
|
||||
theorem ushiftRight_or_distrib (x y : BitVec w) (n : Nat) :
|
||||
(x ||| y) >>> n = (x >>> n) ||| (y >>> n) := by
|
||||
ext
|
||||
simp
|
||||
|
||||
@[simp]
|
||||
theorem ushiftRight_zero_eq (x : BitVec w) : x >>> 0 = x := by
|
||||
simp [bv_toNat]
|
||||
@@ -784,7 +801,7 @@ theorem sshiftRight_eq_of_msb_true {x : BitVec w} {s : Nat} (h : x.msb = true) :
|
||||
· rw [Nat.shiftRight_eq_div_pow]
|
||||
apply Nat.lt_of_le_of_lt (Nat.div_le_self _ _) (by omega)
|
||||
|
||||
theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) :
|
||||
@[simp] theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) :
|
||||
getLsb (x.sshiftRight s) i =
|
||||
(!decide (w ≤ i) && if s + i < w then x.getLsb (s + i) else x.msb) := by
|
||||
rcases hmsb : x.msb with rfl | rfl
|
||||
@@ -805,6 +822,41 @@ theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) :
|
||||
Nat.not_lt, decide_eq_true_eq]
|
||||
omega
|
||||
|
||||
/-- The msb after arithmetic shifting right equals the original msb. -/
|
||||
theorem sshiftRight_msb_eq_msb {n : Nat} {x : BitVec w} :
|
||||
(x.sshiftRight n).msb = x.msb := by
|
||||
rw [msb_eq_getLsb_last, getLsb_sshiftRight, msb_eq_getLsb_last]
|
||||
by_cases hw₀ : w = 0
|
||||
· simp [hw₀]
|
||||
· simp only [show ¬(w ≤ w - 1) by omega, decide_False, Bool.not_false, Bool.true_and,
|
||||
ite_eq_right_iff]
|
||||
intros h
|
||||
simp [show n = 0 by omega]
|
||||
|
||||
@[simp] theorem sshiftRight_zero {x : BitVec w} : x.sshiftRight 0 = x := by
|
||||
ext i
|
||||
simp
|
||||
|
||||
theorem sshiftRight_add {x : BitVec w} {m n : Nat} :
|
||||
x.sshiftRight (m + n) = (x.sshiftRight m).sshiftRight n := by
|
||||
ext i
|
||||
simp only [getLsb_sshiftRight, Nat.add_assoc]
|
||||
by_cases h₁ : w ≤ (i : Nat)
|
||||
· simp [h₁]
|
||||
· simp only [h₁, decide_False, Bool.not_false, Bool.true_and]
|
||||
by_cases h₂ : n + ↑i < w
|
||||
· simp [h₂]
|
||||
· simp only [h₂, ↓reduceIte]
|
||||
by_cases h₃ : m + (n + ↑i) < w
|
||||
· simp [h₃]
|
||||
omega
|
||||
· simp [h₃, sshiftRight_msb_eq_msb]
|
||||
|
||||
/-! ### sshiftRight reductions from BitVec to Nat -/
|
||||
|
||||
@[simp]
|
||||
theorem sshiftRight_eq' (x : BitVec w) : x.sshiftRight' y = x.sshiftRight y.toNat := rfl
|
||||
|
||||
/-! ### signExtend -/
|
||||
|
||||
/-- Equation theorem for `Int.sub` when both arguments are `Int.ofNat` -/
|
||||
@@ -867,15 +919,15 @@ theorem append_def (x : BitVec v) (y : BitVec w) :
|
||||
(x ++ y).toNat = x.toNat <<< n ||| y.toNat :=
|
||||
rfl
|
||||
|
||||
@[simp] theorem getLsb_append {v : BitVec n} {w : BitVec m} :
|
||||
getLsb (v ++ w) i = bif i < m then getLsb w i else getLsb v (i - m) := by
|
||||
@[simp] theorem getLsb_append {x : BitVec n} {y : BitVec m} :
|
||||
getLsb (x ++ y) i = bif i < m then getLsb y i else getLsb x (i - m) := by
|
||||
simp only [append_def, getLsb_or, getLsb_shiftLeftZeroExtend, getLsb_zeroExtend']
|
||||
by_cases h : i < m
|
||||
· simp [h]
|
||||
· simp [h]; simp_all
|
||||
|
||||
@[simp] theorem getMsb_append {v : BitVec n} {w : BitVec m} :
|
||||
getMsb (v ++ w) i = bif n ≤ i then getMsb w (i - n) else getMsb v i := by
|
||||
@[simp] theorem getMsb_append {x : BitVec n} {y : BitVec m} :
|
||||
getMsb (x ++ y) i = bif n ≤ i then getMsb y (i - n) else getMsb x i := by
|
||||
simp [append_def]
|
||||
by_cases h : n ≤ i
|
||||
· simp [h]
|
||||
@@ -1567,4 +1619,46 @@ theorem and_one_eq_zeroExtend_ofBool_getLsb {x : BitVec w} :
|
||||
Bool.true_and]
|
||||
by_cases h : (0 = (i : Nat)) <;> simp [h] <;> omega
|
||||
|
||||
@[simp]
|
||||
theorem replicate_zero_eq {x : BitVec w} : x.replicate 0 = 0#0 := by
|
||||
simp [replicate]
|
||||
|
||||
@[simp]
|
||||
theorem replicate_succ_eq {x : BitVec w} :
|
||||
x.replicate (n + 1) =
|
||||
(x ++ replicate n x).cast (by rw [Nat.mul_succ]; omega) := by
|
||||
simp [replicate]
|
||||
|
||||
/--
|
||||
If a number `w * n ≤ i < w * (n + 1)`, then `i - w * n` equals `i % w`.
|
||||
This is true by subtracting `w * n` from the inequality, giving
|
||||
`0 ≤ i - w * n < w`, which uniquely identifies `i % w`.
|
||||
-/
|
||||
private theorem Nat.sub_mul_eq_mod_of_lt_of_le (hlo : w * n ≤ i) (hhi : i < w * (n + 1)) :
|
||||
i - w * n = i % w := by
|
||||
rw [Nat.mod_def]
|
||||
congr
|
||||
symm
|
||||
apply Nat.div_eq_of_lt_le
|
||||
(by rw [Nat.mul_comm]; omega)
|
||||
(by rw [Nat.mul_comm]; omega)
|
||||
|
||||
@[simp]
|
||||
theorem getLsb_replicate {n w : Nat} (x : BitVec w) :
|
||||
(x.replicate n).getLsb i =
|
||||
(decide (i < w * n) && x.getLsb (i % w)) := by
|
||||
induction n generalizing x
|
||||
case zero => simp
|
||||
case succ n ih =>
|
||||
simp only [replicate_succ_eq, getLsb_cast, getLsb_append]
|
||||
by_cases hi : i < w * (n + 1)
|
||||
· simp only [hi, decide_True, Bool.true_and]
|
||||
by_cases hi' : i < w * n
|
||||
· simp [hi', ih]
|
||||
· simp only [hi', decide_False, cond_false]
|
||||
rw [Nat.sub_mul_eq_mod_of_lt_of_le] <;> omega
|
||||
· rw [Nat.mul_succ] at hi ⊢
|
||||
simp only [show ¬i < w * n by omega, decide_False, cond_false, hi, Bool.false_and]
|
||||
apply BitVec.getLsb_ge (x := x) (i := i - w * n) (ge := by omega)
|
||||
|
||||
end BitVec
|
||||
|
||||
@@ -438,6 +438,24 @@ Added for confluence between `if_true_left` and `ite_false_same` on
|
||||
-/
|
||||
@[simp] theorem eq_true_imp_eq_false : ∀(b:Bool), (b = true → b = false) ↔ (b = false) := by decide
|
||||
|
||||
/-! ### forall -/
|
||||
|
||||
theorem forall_bool' {p : Bool → Prop} (b : Bool) : (∀ x, p x) ↔ p b ∧ p !b :=
|
||||
⟨fun h ↦ ⟨h _, h _⟩, fun ⟨h₁, h₂⟩ x ↦ by cases b <;> cases x <;> assumption⟩
|
||||
|
||||
@[simp]
|
||||
theorem forall_bool {p : Bool → Prop} : (∀ b, p b) ↔ p false ∧ p true :=
|
||||
forall_bool' false
|
||||
|
||||
/-! ### exists -/
|
||||
|
||||
theorem exists_bool' {p : Bool → Prop} (b : Bool) : (∃ x, p x) ↔ p b ∨ p !b :=
|
||||
⟨fun ⟨x, hx⟩ ↦ by cases x <;> cases b <;> first | exact .inl ‹_› | exact .inr ‹_›,
|
||||
fun h ↦ by cases h <;> exact ⟨_, ‹_›⟩⟩
|
||||
|
||||
@[simp]
|
||||
theorem exists_bool {p : Bool → Prop} : (∃ b, p b) ↔ p false ∨ p true :=
|
||||
exists_bool' false
|
||||
|
||||
/-! ### cond -/
|
||||
|
||||
|
||||
@@ -191,6 +191,121 @@ def foldlM {β : Type v} {m : Type v → Type w} [Monad m] (f : β → UInt8 →
|
||||
def foldl {β : Type v} (f : β → UInt8 → β) (init : β) (as : ByteArray) (start := 0) (stop := as.size) : β :=
|
||||
Id.run <| as.foldlM f init start stop
|
||||
|
||||
/-- Iterator over the bytes (`UInt8`) of a `ByteArray`.
|
||||
|
||||
Typically created by `arr.iter`, where `arr` is a `ByteArray`.
|
||||
|
||||
An iterator is *valid* if the position `i` is *valid* for the array `arr`, meaning `0 ≤ i ≤ arr.size`
|
||||
|
||||
Most operations on iterators return arbitrary values if the iterator is not valid. The functions in
|
||||
the `ByteArray.Iterator` API should rule out the creation of invalid iterators, with two exceptions:
|
||||
|
||||
- `Iterator.next iter` is invalid if `iter` is already at the end of the array (`iter.atEnd` is
|
||||
`true`)
|
||||
- `Iterator.forward iter n`/`Iterator.nextn iter n` is invalid if `n` is strictly greater than the
|
||||
number of remaining bytes.
|
||||
-/
|
||||
structure Iterator where
|
||||
/-- The array the iterator is for. -/
|
||||
array : ByteArray
|
||||
/-- The current position.
|
||||
|
||||
This position is not necessarily valid for the array, for instance if one keeps calling
|
||||
`Iterator.next` when `Iterator.atEnd` is true. If the position is not valid, then the
|
||||
current byte is `(default : UInt8)`. -/
|
||||
idx : Nat
|
||||
deriving Inhabited
|
||||
|
||||
/-- Creates an iterator at the beginning of an array. -/
|
||||
def mkIterator (arr : ByteArray) : Iterator :=
|
||||
⟨arr, 0⟩
|
||||
|
||||
@[inherit_doc mkIterator]
|
||||
abbrev iter := mkIterator
|
||||
|
||||
/-- The size of an array iterator is the number of bytes remaining. -/
|
||||
instance : SizeOf Iterator where
|
||||
sizeOf i := i.array.size - i.idx
|
||||
|
||||
theorem Iterator.sizeOf_eq (i : Iterator) : sizeOf i = i.array.size - i.idx :=
|
||||
rfl
|
||||
|
||||
namespace Iterator
|
||||
|
||||
/-- Number of bytes remaining in the iterator. -/
|
||||
def remainingBytes : Iterator → Nat
|
||||
| ⟨arr, i⟩ => arr.size - i
|
||||
|
||||
@[inherit_doc Iterator.idx]
|
||||
def pos := Iterator.idx
|
||||
|
||||
/-- The byte at the current position.
|
||||
|
||||
On an invalid position, returns `(default : UInt8)`. -/
|
||||
@[inline]
|
||||
def curr : Iterator → UInt8
|
||||
| ⟨arr, i⟩ =>
|
||||
if h:i < arr.size then
|
||||
arr[i]'h
|
||||
else
|
||||
default
|
||||
|
||||
/-- Moves the iterator's position forward by one byte, unconditionally.
|
||||
|
||||
It is only valid to call this function if the iterator is not at the end of the array, *i.e.*
|
||||
`Iterator.atEnd` is `false`; otherwise, the resulting iterator will be invalid. -/
|
||||
@[inline]
|
||||
def next : Iterator → Iterator
|
||||
| ⟨arr, i⟩ => ⟨arr, i + 1⟩
|
||||
|
||||
/-- Decreases the iterator's position.
|
||||
|
||||
If the position is zero, this function is the identity. -/
|
||||
@[inline]
|
||||
def prev : Iterator → Iterator
|
||||
| ⟨arr, i⟩ => ⟨arr, i - 1⟩
|
||||
|
||||
/-- True if the iterator is past the array's last byte. -/
|
||||
@[inline]
|
||||
def atEnd : Iterator → Bool
|
||||
| ⟨arr, i⟩ => i ≥ arr.size
|
||||
|
||||
/-- True if the iterator is not past the array's last byte. -/
|
||||
@[inline]
|
||||
def hasNext : Iterator → Bool
|
||||
| ⟨arr, i⟩ => i < arr.size
|
||||
|
||||
/-- True if the position is not zero. -/
|
||||
@[inline]
|
||||
def hasPrev : Iterator → Bool
|
||||
| ⟨_, i⟩ => i > 0
|
||||
|
||||
/-- Moves the iterator's position to the end of the array.
|
||||
|
||||
Note that `i.toEnd.atEnd` is always `true`. -/
|
||||
@[inline]
|
||||
def toEnd : Iterator → Iterator
|
||||
| ⟨arr, _⟩ => ⟨arr, arr.size⟩
|
||||
|
||||
/-- Moves the iterator's position several bytes forward.
|
||||
|
||||
The resulting iterator is only valid if the number of bytes to skip is less than or equal to
|
||||
the number of bytes left in the iterator. -/
|
||||
@[inline]
|
||||
def forward : Iterator → Nat → Iterator
|
||||
| ⟨arr, i⟩, f => ⟨arr, i + f⟩
|
||||
|
||||
@[inherit_doc forward, inline]
|
||||
def nextn : Iterator → Nat → Iterator := forward
|
||||
|
||||
/-- Moves the iterator's position several bytes back.
|
||||
|
||||
If asked to go back more bytes than available, stops at the beginning of the array. -/
|
||||
@[inline]
|
||||
def prevn : Iterator → Nat → Iterator
|
||||
| ⟨arr, i⟩, f => ⟨arr, i - f⟩
|
||||
|
||||
end Iterator
|
||||
end ByteArray
|
||||
|
||||
def List.toByteArray (bs : List UInt8) : ByteArray :=
|
||||
|
||||
@@ -322,8 +322,8 @@ protected def pow (m : Int) : Nat → Int
|
||||
| 0 => 1
|
||||
| succ n => Int.pow m n * m
|
||||
|
||||
instance : HPow Int Nat Int where
|
||||
hPow := Int.pow
|
||||
instance : NatPow Int where
|
||||
pow := Int.pow
|
||||
|
||||
instance : LawfulBEq Int where
|
||||
eq_of_beq h := by simp [BEq.beq] at h; assumption
|
||||
|
||||
@@ -203,6 +203,10 @@ theorem mod_add_div (m k : Nat) : m % k + k * (m / k) = m := by
|
||||
| base x y h => simp [h]
|
||||
| ind x y h IH => simp [h]; rw [Nat.mul_succ, ← Nat.add_assoc, IH, Nat.sub_add_cancel h.2]
|
||||
|
||||
theorem mod_def (m k : Nat) : m % k = m - k * (m / k) := by
|
||||
rw [Nat.sub_eq_of_eq_add]
|
||||
apply (Nat.mod_add_div _ _).symm
|
||||
|
||||
@[simp] protected theorem div_one (n : Nat) : n / 1 = n := by
|
||||
have := mod_add_div n 1
|
||||
rwa [mod_one, Nat.zero_add, Nat.one_mul] at this
|
||||
|
||||
@@ -5,9 +5,18 @@ Author: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Init.SimpLemmas
|
||||
import Init.NotationExtra
|
||||
|
||||
instance [BEq α] [BEq β] [LawfulBEq α] [LawfulBEq β] : LawfulBEq (α × β) where
|
||||
eq_of_beq {a b} (h : a.1 == b.1 && a.2 == b.2) := by
|
||||
cases a; cases b
|
||||
refine congr (congrArg _ (eq_of_beq ?_)) (eq_of_beq ?_) <;> simp_all
|
||||
rfl {a} := by cases a; simp [BEq.beq, LawfulBEq.rfl]
|
||||
|
||||
@[simp]
|
||||
protected theorem Prod.forall {p : α × β → Prop} : (∀ x, p x) ↔ ∀ a b, p (a, b) :=
|
||||
⟨fun h a b ↦ h (a, b), fun h ⟨a, b⟩ ↦ h a b⟩
|
||||
|
||||
@[simp]
|
||||
protected theorem Prod.exists {p : α × β → Prop} : (∃ x, p x) ↔ ∃ a b, p (a, b) :=
|
||||
⟨fun ⟨⟨a, b⟩, h⟩ ↦ ⟨a, b, h⟩, fun ⟨a, b, h⟩ ↦ ⟨⟨a, b⟩, h⟩⟩
|
||||
|
||||
@@ -435,6 +435,12 @@ Note that EOF does not actually close a handle, so further reads may block and r
|
||||
|
||||
end Handle
|
||||
|
||||
/--
|
||||
Resolves a pathname to an absolute pathname with no '.', '..', or symbolic links.
|
||||
|
||||
This function coincides with the [POSIX `realpath` function](https://pubs.opengroup.org/onlinepubs/9699919799/functions/realpath.html),
|
||||
see there for more information.
|
||||
-/
|
||||
@[extern "lean_io_realpath"] opaque realPath (fname : FilePath) : IO FilePath
|
||||
@[extern "lean_io_remove_file"] opaque removeFile (fname : @& FilePath) : IO Unit
|
||||
/-- Remove given directory. Fails if not empty; see also `IO.FS.removeDirAll`. -/
|
||||
@@ -464,31 +470,23 @@ def withFile (fn : FilePath) (mode : Mode) (f : Handle → IO α) : IO α :=
|
||||
def Handle.putStrLn (h : Handle) (s : String) : IO Unit :=
|
||||
h.putStr (s.push '\n')
|
||||
|
||||
partial def Handle.readBinToEnd (h : Handle) : IO ByteArray := do
|
||||
partial def Handle.readBinToEndInto (h : Handle) (buf : ByteArray) : IO ByteArray := do
|
||||
let rec loop (acc : ByteArray) : IO ByteArray := do
|
||||
let buf ← h.read 1024
|
||||
if buf.isEmpty then
|
||||
return acc
|
||||
else
|
||||
loop (acc ++ buf)
|
||||
loop ByteArray.empty
|
||||
loop buf
|
||||
|
||||
partial def Handle.readToEnd (h : Handle) : IO String := do
|
||||
let rec loop (s : String) := do
|
||||
let line ← h.getLine
|
||||
if line.isEmpty then
|
||||
return s
|
||||
else
|
||||
loop (s ++ line)
|
||||
loop ""
|
||||
partial def Handle.readBinToEnd (h : Handle) : IO ByteArray := do
|
||||
h.readBinToEndInto .empty
|
||||
|
||||
def readBinFile (fname : FilePath) : IO ByteArray := do
|
||||
let h ← Handle.mk fname Mode.read
|
||||
h.readBinToEnd
|
||||
|
||||
def readFile (fname : FilePath) : IO String := do
|
||||
let h ← Handle.mk fname Mode.read
|
||||
h.readToEnd
|
||||
def Handle.readToEnd (h : Handle) : IO String := do
|
||||
let data ← h.readBinToEnd
|
||||
match String.fromUTF8? data with
|
||||
| some s => return s
|
||||
| none => throw <| .userError s!"Tried to read from handle containing non UTF-8 data."
|
||||
|
||||
partial def lines (fname : FilePath) : IO (Array String) := do
|
||||
let h ← Handle.mk fname Mode.read
|
||||
@@ -594,6 +592,28 @@ end System.FilePath
|
||||
|
||||
namespace IO
|
||||
|
||||
namespace FS
|
||||
|
||||
def readBinFile (fname : FilePath) : IO ByteArray := do
|
||||
-- Requires metadata so defined after metadata
|
||||
let mdata ← fname.metadata
|
||||
let size := mdata.byteSize.toUSize
|
||||
let handle ← IO.FS.Handle.mk fname .read
|
||||
let buf ←
|
||||
if size > 0 then
|
||||
handle.read mdata.byteSize.toUSize
|
||||
else
|
||||
pure <| ByteArray.mkEmpty 0
|
||||
handle.readBinToEndInto buf
|
||||
|
||||
def readFile (fname : FilePath) : IO String := do
|
||||
let data ← readBinFile fname
|
||||
match String.fromUTF8? data with
|
||||
| some s => return s
|
||||
| none => throw <| .userError s!"Tried to read file '{fname}' containing non UTF-8 data."
|
||||
|
||||
end FS
|
||||
|
||||
def withStdin [Monad m] [MonadFinally m] [MonadLiftT BaseIO m] (h : FS.Stream) (x : m α) : m α := do
|
||||
let prev ← setStdin h
|
||||
try x finally discard <| setStdin prev
|
||||
|
||||
@@ -53,7 +53,7 @@ structure AttributeImpl extends AttributeImplCore where
|
||||
erase (decl : Name) : AttrM Unit := throwError "attribute cannot be erased"
|
||||
deriving Inhabited
|
||||
|
||||
builtin_initialize attributeMapRef : IO.Ref (HashMap Name AttributeImpl) ← IO.mkRef {}
|
||||
builtin_initialize attributeMapRef : IO.Ref (Std.HashMap Name AttributeImpl) ← IO.mkRef {}
|
||||
|
||||
/-- Low level attribute registration function. -/
|
||||
def registerBuiltinAttribute (attr : AttributeImpl) : IO Unit := do
|
||||
@@ -296,7 +296,7 @@ end EnumAttributes
|
||||
-/
|
||||
|
||||
abbrev AttributeImplBuilder := Name → List DataValue → Except String AttributeImpl
|
||||
abbrev AttributeImplBuilderTable := HashMap Name AttributeImplBuilder
|
||||
abbrev AttributeImplBuilderTable := Std.HashMap Name AttributeImplBuilder
|
||||
|
||||
builtin_initialize attributeImplBuilderTableRef : IO.Ref AttributeImplBuilderTable ← IO.mkRef {}
|
||||
|
||||
@@ -307,7 +307,7 @@ def registerAttributeImplBuilder (builderId : Name) (builder : AttributeImplBuil
|
||||
|
||||
def mkAttributeImplOfBuilder (builderId ref : Name) (args : List DataValue) : IO AttributeImpl := do
|
||||
let table ← attributeImplBuilderTableRef.get
|
||||
match table.find? builderId with
|
||||
match table[builderId]? with
|
||||
| none => throw (IO.userError ("unknown attribute implementation builder '" ++ toString builderId ++ "'"))
|
||||
| some builder => IO.ofExcept <| builder ref args
|
||||
|
||||
@@ -317,7 +317,7 @@ inductive AttributeExtensionOLeanEntry where
|
||||
|
||||
structure AttributeExtensionState where
|
||||
newEntries : List AttributeExtensionOLeanEntry := []
|
||||
map : HashMap Name AttributeImpl
|
||||
map : Std.HashMap Name AttributeImpl
|
||||
deriving Inhabited
|
||||
|
||||
abbrev AttributeExtension := PersistentEnvExtension AttributeExtensionOLeanEntry (AttributeExtensionOLeanEntry × AttributeImpl) AttributeExtensionState
|
||||
@@ -348,7 +348,7 @@ private def AttributeExtension.addImported (es : Array (Array AttributeExtension
|
||||
let map ← es.foldlM
|
||||
(fun map entries =>
|
||||
entries.foldlM
|
||||
(fun (map : HashMap Name AttributeImpl) entry => do
|
||||
(fun (map : Std.HashMap Name AttributeImpl) entry => do
|
||||
let attrImpl ← mkAttributeImplOfEntry ctx.env ctx.opts entry
|
||||
return map.insert attrImpl.name attrImpl)
|
||||
map)
|
||||
@@ -378,7 +378,7 @@ def getBuiltinAttributeNames : IO (List Name) :=
|
||||
|
||||
def getBuiltinAttributeImpl (attrName : Name) : IO AttributeImpl := do
|
||||
let m ← attributeMapRef.get
|
||||
match m.find? attrName with
|
||||
match m[attrName]? with
|
||||
| some attr => pure attr
|
||||
| none => throw (IO.userError ("unknown attribute '" ++ toString attrName ++ "'"))
|
||||
|
||||
@@ -396,7 +396,7 @@ def getAttributeNames (env : Environment) : List Name :=
|
||||
|
||||
def getAttributeImpl (env : Environment) (attrName : Name) : Except String AttributeImpl :=
|
||||
let m := (attributeExtension.getState env).map
|
||||
match m.find? attrName with
|
||||
match m[attrName]? with
|
||||
| some attr => pure attr
|
||||
| none => throw ("unknown attribute '" ++ toString attrName ++ "'")
|
||||
|
||||
|
||||
@@ -26,9 +26,9 @@ instance : Hashable Key := ⟨getHash⟩
|
||||
end OwnedSet
|
||||
|
||||
open OwnedSet (Key) in
|
||||
abbrev OwnedSet := HashMap Key Unit
|
||||
def OwnedSet.insert (s : OwnedSet) (k : OwnedSet.Key) : OwnedSet := HashMap.insert s k ()
|
||||
def OwnedSet.contains (s : OwnedSet) (k : OwnedSet.Key) : Bool := HashMap.contains s k
|
||||
abbrev OwnedSet := Std.HashMap Key Unit
|
||||
def OwnedSet.insert (s : OwnedSet) (k : OwnedSet.Key) : OwnedSet := Std.HashMap.insert s k ()
|
||||
def OwnedSet.contains (s : OwnedSet) (k : OwnedSet.Key) : Bool := Std.HashMap.contains s k
|
||||
|
||||
/-! We perform borrow inference in a block of mutually recursive functions.
|
||||
Join points are viewed as local functions, and are identified using
|
||||
@@ -49,7 +49,7 @@ instance : Hashable Key := ⟨getHash⟩
|
||||
end ParamMap
|
||||
|
||||
open ParamMap (Key)
|
||||
abbrev ParamMap := HashMap Key (Array Param)
|
||||
abbrev ParamMap := Std.HashMap Key (Array Param)
|
||||
|
||||
def ParamMap.fmt (map : ParamMap) : Format :=
|
||||
let fmts := map.fold (fun fmt k ps =>
|
||||
@@ -109,7 +109,7 @@ partial def visitFnBody (fn : FunId) (paramMap : ParamMap) : FnBody → FnBody
|
||||
| FnBody.jdecl j _ v b =>
|
||||
let v := visitFnBody fn paramMap v
|
||||
let b := visitFnBody fn paramMap b
|
||||
match paramMap.find? (ParamMap.Key.jp fn j) with
|
||||
match paramMap[ParamMap.Key.jp fn j]? with
|
||||
| some ys => FnBody.jdecl j ys v b
|
||||
| none => unreachable!
|
||||
| FnBody.case tid x xType alts =>
|
||||
@@ -125,7 +125,7 @@ def visitDecls (decls : Array Decl) (paramMap : ParamMap) : Array Decl :=
|
||||
decls.map fun decl => match decl with
|
||||
| Decl.fdecl f _ ty b info =>
|
||||
let b := visitFnBody f paramMap b
|
||||
match paramMap.find? (ParamMap.Key.decl f) with
|
||||
match paramMap[ParamMap.Key.decl f]? with
|
||||
| some xs => Decl.fdecl f xs ty b info
|
||||
| none => unreachable!
|
||||
| other => other
|
||||
@@ -178,7 +178,7 @@ def isOwned (x : VarId) : M Bool := do
|
||||
/-- Updates `map[k]` using the current set of `owned` variables. -/
|
||||
def updateParamMap (k : ParamMap.Key) : M Unit := do
|
||||
let s ← get
|
||||
match s.paramMap.find? k with
|
||||
match s.paramMap[k]? with
|
||||
| some ps => do
|
||||
let ps ← ps.mapM fun (p : Param) => do
|
||||
if !p.borrow then pure p
|
||||
@@ -192,7 +192,7 @@ def updateParamMap (k : ParamMap.Key) : M Unit := do
|
||||
|
||||
def getParamInfo (k : ParamMap.Key) : M (Array Param) := do
|
||||
let s ← get
|
||||
match s.paramMap.find? k with
|
||||
match s.paramMap[k]? with
|
||||
| some ps => pure ps
|
||||
| none =>
|
||||
match k with
|
||||
|
||||
@@ -11,6 +11,7 @@ import Lean.Compiler.IR.Basic
|
||||
import Lean.Compiler.IR.CompilerM
|
||||
import Lean.Compiler.IR.FreeVars
|
||||
import Lean.Compiler.IR.ElimDeadVars
|
||||
import Lean.Data.AssocList
|
||||
|
||||
namespace Lean.IR.ExplicitBoxing
|
||||
/-!
|
||||
|
||||
@@ -152,7 +152,7 @@ def getFunctionSummary? (env : Environment) (fid : FunId) : Option Value :=
|
||||
| some modIdx => findAtSorted? (functionSummariesExt.getModuleEntries env modIdx) fid
|
||||
| none => functionSummariesExt.getState env |>.find? fid
|
||||
|
||||
abbrev Assignment := HashMap VarId Value
|
||||
abbrev Assignment := Std.HashMap VarId Value
|
||||
|
||||
structure InterpContext where
|
||||
currFnIdx : Nat := 0
|
||||
@@ -172,7 +172,7 @@ def findVarValue (x : VarId) : M Value := do
|
||||
let ctx ← read
|
||||
let s ← get
|
||||
let assignment := s.assignments[ctx.currFnIdx]!
|
||||
return assignment.findD x bot
|
||||
return assignment.getD x bot
|
||||
|
||||
def findArgValue (arg : Arg) : M Value :=
|
||||
match arg with
|
||||
@@ -303,7 +303,7 @@ partial def elimDeadAux (assignment : Assignment) : FnBody → FnBody
|
||||
| FnBody.vdecl x t e b => FnBody.vdecl x t e (elimDeadAux assignment b)
|
||||
| FnBody.jdecl j ys v b => FnBody.jdecl j ys (elimDeadAux assignment v) (elimDeadAux assignment b)
|
||||
| FnBody.case tid x xType alts =>
|
||||
let v := assignment.findD x bot
|
||||
let v := assignment.getD x bot
|
||||
let alts := alts.map fun alt =>
|
||||
match alt with
|
||||
| Alt.ctor i b => Alt.ctor i <| if containsCtor v i then elimDeadAux assignment b else FnBody.unreachable
|
||||
|
||||
@@ -252,7 +252,7 @@ def throwUnknownVar {α : Type} (x : VarId) : M α :=
|
||||
|
||||
def getJPParams (j : JoinPointId) : M (Array Param) := do
|
||||
let ctx ← read;
|
||||
match ctx.jpMap.find? j with
|
||||
match ctx.jpMap[j]? with
|
||||
| some ps => pure ps
|
||||
| none => throw "unknown join point"
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Siddharth Bhat
|
||||
-/
|
||||
prelude
|
||||
import Lean.Data.HashMap
|
||||
import Lean.Runtime
|
||||
import Lean.Compiler.NameMangling
|
||||
import Lean.Compiler.ExportAttr
|
||||
@@ -65,8 +64,8 @@ structure Context (llvmctx : LLVM.Context) where
|
||||
llvmmodule : LLVM.Module llvmctx
|
||||
|
||||
structure State (llvmctx : LLVM.Context) where
|
||||
var2val : HashMap VarId (LLVM.LLVMType llvmctx × LLVM.Value llvmctx)
|
||||
jp2bb : HashMap JoinPointId (LLVM.BasicBlock llvmctx)
|
||||
var2val : Std.HashMap VarId (LLVM.LLVMType llvmctx × LLVM.Value llvmctx)
|
||||
jp2bb : Std.HashMap JoinPointId (LLVM.BasicBlock llvmctx)
|
||||
|
||||
abbrev Error := String
|
||||
|
||||
@@ -84,7 +83,7 @@ def addJpTostate (jp : JoinPointId) (bb : LLVM.BasicBlock llvmctx) : M llvmctx U
|
||||
|
||||
def emitJp (jp : JoinPointId) : M llvmctx (LLVM.BasicBlock llvmctx) := do
|
||||
let state ← get
|
||||
match state.jp2bb.find? jp with
|
||||
match state.jp2bb[jp]? with
|
||||
| .some bb => return bb
|
||||
| .none => throw s!"unable to find join point {jp}"
|
||||
|
||||
@@ -531,7 +530,7 @@ def emitFnDecls : M llvmctx Unit := do
|
||||
|
||||
def emitLhsSlot_ (x : VarId) : M llvmctx (LLVM.LLVMType llvmctx × LLVM.Value llvmctx) := do
|
||||
let state ← get
|
||||
match state.var2val.find? x with
|
||||
match state.var2val[x]? with
|
||||
| .some v => return v
|
||||
| .none => throw s!"unable to find variable {x}"
|
||||
|
||||
@@ -1029,7 +1028,7 @@ def emitTailCall (builder : LLVM.Builder llvmctx) (f : FunId) (v : Expr) : M llv
|
||||
|
||||
def emitJmp (builder : LLVM.Builder llvmctx) (jp : JoinPointId) (xs : Array Arg) : M llvmctx Unit := do
|
||||
let llvmctx ← read
|
||||
let ps ← match llvmctx.jpMap.find? jp with
|
||||
let ps ← match llvmctx.jpMap[jp]? with
|
||||
| some ps => pure ps
|
||||
| none => throw s!"Unknown join point {jp}"
|
||||
unless xs.size == ps.size do throw s!"Invalid goto, mismatched sizes between arguments, formal parameters."
|
||||
|
||||
@@ -51,8 +51,8 @@ end CollectUsedDecls
|
||||
def collectUsedDecls (env : Environment) (decl : Decl) (used : NameSet := {}) : NameSet :=
|
||||
(CollectUsedDecls.collectDecl decl env).run' used
|
||||
|
||||
abbrev VarTypeMap := HashMap VarId IRType
|
||||
abbrev JPParamsMap := HashMap JoinPointId (Array Param)
|
||||
abbrev VarTypeMap := Std.HashMap VarId IRType
|
||||
abbrev JPParamsMap := Std.HashMap JoinPointId (Array Param)
|
||||
|
||||
namespace CollectMaps
|
||||
abbrev Collector := (VarTypeMap × JPParamsMap) → (VarTypeMap × JPParamsMap)
|
||||
|
||||
@@ -10,7 +10,7 @@ import Lean.Compiler.IR.FreeVars
|
||||
|
||||
namespace Lean.IR.ExpandResetReuse
|
||||
/-- Mapping from variable to projections -/
|
||||
abbrev ProjMap := HashMap VarId Expr
|
||||
abbrev ProjMap := Std.HashMap VarId Expr
|
||||
namespace CollectProjMap
|
||||
abbrev Collector := ProjMap → ProjMap
|
||||
@[inline] def collectVDecl (x : VarId) (v : Expr) : Collector := fun m =>
|
||||
@@ -148,20 +148,20 @@ def setFields (y : VarId) (zs : Array Arg) (b : FnBody) : FnBody :=
|
||||
def isSelfSet (ctx : Context) (x : VarId) (i : Nat) (y : Arg) : Bool :=
|
||||
match y with
|
||||
| Arg.var y =>
|
||||
match ctx.projMap.find? y with
|
||||
match ctx.projMap[y]? with
|
||||
| some (Expr.proj j w) => j == i && w == x
|
||||
| _ => false
|
||||
| _ => false
|
||||
|
||||
/-- Given `uset x[i] := y`, return true iff `y := uproj[i] x` -/
|
||||
def isSelfUSet (ctx : Context) (x : VarId) (i : Nat) (y : VarId) : Bool :=
|
||||
match ctx.projMap.find? y with
|
||||
match ctx.projMap[y]? with
|
||||
| some (Expr.uproj j w) => j == i && w == x
|
||||
| _ => false
|
||||
|
||||
/-- Given `sset x[n, i] := y`, return true iff `y := sproj[n, i] x` -/
|
||||
def isSelfSSet (ctx : Context) (x : VarId) (n : Nat) (i : Nat) (y : VarId) : Bool :=
|
||||
match ctx.projMap.find? y with
|
||||
match ctx.projMap[y]? with
|
||||
| some (Expr.sproj m j w) => n == m && j == i && w == x
|
||||
| _ => false
|
||||
|
||||
|
||||
@@ -64,34 +64,34 @@ instance : AddMessageContext CompilerM where
|
||||
|
||||
def getType (fvarId : FVarId) : CompilerM Expr := do
|
||||
let lctx := (← get).lctx
|
||||
if let some decl := lctx.letDecls.find? fvarId then
|
||||
if let some decl := lctx.letDecls[fvarId]? then
|
||||
return decl.type
|
||||
else if let some decl := lctx.params.find? fvarId then
|
||||
else if let some decl := lctx.params[fvarId]? then
|
||||
return decl.type
|
||||
else if let some decl := lctx.funDecls.find? fvarId then
|
||||
else if let some decl := lctx.funDecls[fvarId]? then
|
||||
return decl.type
|
||||
else
|
||||
throwError "unknown free variable {fvarId.name}"
|
||||
|
||||
def getBinderName (fvarId : FVarId) : CompilerM Name := do
|
||||
let lctx := (← get).lctx
|
||||
if let some decl := lctx.letDecls.find? fvarId then
|
||||
if let some decl := lctx.letDecls[fvarId]? then
|
||||
return decl.binderName
|
||||
else if let some decl := lctx.params.find? fvarId then
|
||||
else if let some decl := lctx.params[fvarId]? then
|
||||
return decl.binderName
|
||||
else if let some decl := lctx.funDecls.find? fvarId then
|
||||
else if let some decl := lctx.funDecls[fvarId]? then
|
||||
return decl.binderName
|
||||
else
|
||||
throwError "unknown free variable {fvarId.name}"
|
||||
|
||||
def findParam? (fvarId : FVarId) : CompilerM (Option Param) :=
|
||||
return (← get).lctx.params.find? fvarId
|
||||
return (← get).lctx.params[fvarId]?
|
||||
|
||||
def findLetDecl? (fvarId : FVarId) : CompilerM (Option LetDecl) :=
|
||||
return (← get).lctx.letDecls.find? fvarId
|
||||
return (← get).lctx.letDecls[fvarId]?
|
||||
|
||||
def findFunDecl? (fvarId : FVarId) : CompilerM (Option FunDecl) :=
|
||||
return (← get).lctx.funDecls.find? fvarId
|
||||
return (← get).lctx.funDecls[fvarId]?
|
||||
|
||||
def findLetValue? (fvarId : FVarId) : CompilerM (Option LetValue) := do
|
||||
let some { value, .. } ← findLetDecl? fvarId | return none
|
||||
@@ -166,7 +166,7 @@ it is a free variable, a type (or type former), or `lcErased`.
|
||||
|
||||
`Check.lean` contains a substitution validator.
|
||||
-/
|
||||
abbrev FVarSubst := HashMap FVarId Expr
|
||||
abbrev FVarSubst := Std.HashMap FVarId Expr
|
||||
|
||||
/--
|
||||
Replace the free variables in `e` using the given substitution.
|
||||
@@ -190,7 +190,7 @@ where
|
||||
go (e : Expr) : Expr :=
|
||||
if e.hasFVar then
|
||||
match e with
|
||||
| .fvar fvarId => match s.find? fvarId with
|
||||
| .fvar fvarId => match s[fvarId]? with
|
||||
| some e => if translator then e else go e
|
||||
| none => e
|
||||
| .lit .. | .const .. | .sort .. | .mvar .. | .bvar .. => e
|
||||
@@ -224,7 +224,7 @@ That is, it is not a type (or type former), nor `lcErased`. Recall that a valid
|
||||
expressions that are free variables, `lcErased`, or type formers.
|
||||
-/
|
||||
private partial def normFVarImp (s : FVarSubst) (fvarId : FVarId) (translator : Bool) : NormFVarResult :=
|
||||
match s.find? fvarId with
|
||||
match s[fvarId]? with
|
||||
| some (.fvar fvarId') =>
|
||||
if translator then
|
||||
.fvar fvarId'
|
||||
@@ -246,7 +246,7 @@ private partial def normArgImp (s : FVarSubst) (arg : Arg) (translator : Bool) :
|
||||
match arg with
|
||||
| .erased => arg
|
||||
| .fvar fvarId =>
|
||||
match s.find? fvarId with
|
||||
match s[fvarId]? with
|
||||
| some (.fvar fvarId') =>
|
||||
let arg' := .fvar fvarId'
|
||||
if translator then arg' else normArgImp s arg' translator
|
||||
|
||||
@@ -268,7 +268,7 @@ def getFunctionSummary? (env : Environment) (fid : Name) : Option Value :=
|
||||
A map from variable identifiers to the `Value` produced by the abstract
|
||||
interpreter for them.
|
||||
-/
|
||||
abbrev Assignment := HashMap FVarId Value
|
||||
abbrev Assignment := Std.HashMap FVarId Value
|
||||
|
||||
/--
|
||||
The context of `InterpM`.
|
||||
@@ -332,7 +332,7 @@ If none is available return `Value.bot`.
|
||||
-/
|
||||
def findVarValue (var : FVarId) : InterpM Value := do
|
||||
let assignment ← getAssignment
|
||||
return assignment.findD var .bot
|
||||
return assignment.getD var .bot
|
||||
|
||||
/--
|
||||
Find the value of `arg` using the logic of `findVarValue`.
|
||||
@@ -547,13 +547,13 @@ where
|
||||
| .jp decl k | .fun decl k =>
|
||||
return code.updateFun! (← decl.updateValue (← go decl.value)) (← go k)
|
||||
| .cases cs =>
|
||||
let discrVal := assignment.findD cs.discr .bot
|
||||
let discrVal := assignment.getD cs.discr .bot
|
||||
let processAlt typ alt := do
|
||||
match alt with
|
||||
| .alt ctor args body =>
|
||||
if discrVal.containsCtor ctor then
|
||||
let filter param := do
|
||||
if let some val := assignment.find? param.fvarId then
|
||||
if let some val := assignment[param.fvarId]? then
|
||||
if let some literal ← val.getLiteral then
|
||||
return some (param, literal)
|
||||
return none
|
||||
|
||||
@@ -62,7 +62,7 @@ structure State where
|
||||
Whenever there is function application `f a₁ ... aₙ`, where `f` is in `decls`, `f` is not `main`, and
|
||||
we visit with the abstract values assigned to `aᵢ`, but first we record the visit here.
|
||||
-/
|
||||
visited : HashSet (Name × Array AbsValue) := {}
|
||||
visited : Std.HashSet (Name × Array AbsValue) := {}
|
||||
/--
|
||||
Bitmask containing the result, i.e., which parameters of `main` are fixed.
|
||||
We initialize it with `true` everywhere.
|
||||
|
||||
@@ -59,7 +59,7 @@ structure FloatState where
|
||||
/--
|
||||
A map from identifiers of declarations to their current decision.
|
||||
-/
|
||||
decision : HashMap FVarId Decision
|
||||
decision : Std.HashMap FVarId Decision
|
||||
/--
|
||||
A map from decisions (excluding `unknown`) to the declarations with
|
||||
these decisions (in correct order). Basically:
|
||||
@@ -67,7 +67,7 @@ structure FloatState where
|
||||
- Which declarations do we move into a certain arm
|
||||
- Which declarations do we move into the default arm
|
||||
-/
|
||||
newArms : HashMap Decision (List CodeDecl)
|
||||
newArms : Std.HashMap Decision (List CodeDecl)
|
||||
|
||||
/--
|
||||
Use to collect relevant declarations for the floating mechanism.
|
||||
@@ -116,8 +116,8 @@ up to this point, with respect to `cs`. The initial decisions are:
|
||||
- `arm` or `default` if we see the declaration only being used in exactly one cases arm
|
||||
- `unknown` otherwise
|
||||
-/
|
||||
def initialDecisions (cs : Cases) : BaseFloatM (HashMap FVarId Decision) := do
|
||||
let mut map := mkHashMap (← read).decls.length
|
||||
def initialDecisions (cs : Cases) : BaseFloatM (Std.HashMap FVarId Decision) := do
|
||||
let mut map := Std.HashMap.empty (← read).decls.length
|
||||
let folder val acc := do
|
||||
if let .let decl := val then
|
||||
if (← ignore? decl) then
|
||||
@@ -130,25 +130,25 @@ def initialDecisions (cs : Cases) : BaseFloatM (HashMap FVarId Decision) := do
|
||||
(_, map) ← goCases cs |>.run map
|
||||
return map
|
||||
where
|
||||
goFVar (plannedDecision : Decision) (var : FVarId) : StateRefT (HashMap FVarId Decision) BaseFloatM Unit := do
|
||||
if let some decision := (← get).find? var then
|
||||
goFVar (plannedDecision : Decision) (var : FVarId) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit := do
|
||||
if let some decision := (← get)[var]? then
|
||||
if decision == .unknown then
|
||||
modify fun s => s.insert var plannedDecision
|
||||
else if decision != plannedDecision then
|
||||
modify fun s => s.insert var .dont
|
||||
-- otherwise we already have the proper decision
|
||||
|
||||
goAlt (alt : Alt) : StateRefT (HashMap FVarId Decision) BaseFloatM Unit :=
|
||||
goAlt (alt : Alt) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit :=
|
||||
forFVarM (goFVar (.ofAlt alt)) alt
|
||||
goCases (cs : Cases) : StateRefT (HashMap FVarId Decision) BaseFloatM Unit :=
|
||||
goCases (cs : Cases) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit :=
|
||||
cs.alts.forM goAlt
|
||||
|
||||
/--
|
||||
Compute the initial new arms. This will just set up a map from all arms of
|
||||
`cs` to empty `Array`s, plus one additional entry for `dont`.
|
||||
-/
|
||||
def initialNewArms (cs : Cases) : HashMap Decision (List CodeDecl) := Id.run do
|
||||
let mut map := mkHashMap (cs.alts.size + 1)
|
||||
def initialNewArms (cs : Cases) : Std.HashMap Decision (List CodeDecl) := Id.run do
|
||||
let mut map := Std.HashMap.empty (cs.alts.size + 1)
|
||||
map := map.insert .dont []
|
||||
cs.alts.foldr (init := map) fun val acc => acc.insert (.ofAlt val) []
|
||||
|
||||
@@ -170,7 +170,7 @@ respectively but since `z` can't be moved we don't want that to move `x` and `y`
|
||||
-/
|
||||
def dontFloat (decl : CodeDecl) : FloatM Unit := do
|
||||
forFVarM goFVar decl
|
||||
modify fun s => { s with newArms := s.newArms.insert .dont (decl :: s.newArms.find! .dont) }
|
||||
modify fun s => { s with newArms := s.newArms.insert .dont (decl :: s.newArms[Decision.dont]!) }
|
||||
where
|
||||
goFVar (fvar : FVarId) : FloatM Unit := do
|
||||
if (← get).decision.contains fvar then
|
||||
@@ -223,12 +223,12 @@ Will:
|
||||
If we are at `y` `x` is still marked to be moved but we don't want that.
|
||||
-/
|
||||
def float (decl : CodeDecl) : FloatM Unit := do
|
||||
let arm := (← get).decision.find! decl.fvarId
|
||||
let arm := (← get).decision[decl.fvarId]!
|
||||
forFVarM (goFVar · arm) decl
|
||||
modify fun s => { s with newArms := s.newArms.insert arm (decl :: s.newArms.find! arm) }
|
||||
modify fun s => { s with newArms := s.newArms.insert arm (decl :: s.newArms[arm]!) }
|
||||
where
|
||||
goFVar (fvar : FVarId) (arm : Decision) : FloatM Unit := do
|
||||
let some decision := (← get).decision.find? fvar | return ()
|
||||
let some decision := (← get).decision[fvar]? | return ()
|
||||
if decision != arm then
|
||||
modify fun s => { s with decision := s.decision.insert fvar .dont }
|
||||
else if decision == .unknown then
|
||||
@@ -249,7 +249,7 @@ where
|
||||
-/
|
||||
goCases : FloatM Unit := do
|
||||
for decl in (← read).decls do
|
||||
let currentDecision := (← get).decision.find! decl.fvarId
|
||||
let currentDecision := (← get).decision[decl.fvarId]!
|
||||
if currentDecision == .unknown then
|
||||
/-
|
||||
If the decision is still unknown by now this means `decl` is
|
||||
@@ -284,10 +284,10 @@ where
|
||||
newArms := initialNewArms cs
|
||||
}
|
||||
let (_, res) ← goCases |>.run base
|
||||
let remainders := res.newArms.find! .dont
|
||||
let remainders := res.newArms[Decision.dont]!
|
||||
let altMapper alt := do
|
||||
let decision := .ofAlt alt
|
||||
let newCode := res.newArms.find! decision
|
||||
let decision := Decision.ofAlt alt
|
||||
let newCode := res.newArms[decision]!
|
||||
trace[Compiler.floatLetIn] "Size of code that was pushed into arm: {repr decision} {newCode.length}"
|
||||
let fused ← withNewScope do
|
||||
go (attachCodeDecls newCode.toArray alt.getCode)
|
||||
|
||||
@@ -29,7 +29,7 @@ structure CandidateInfo where
|
||||
The set of candidates that rely on this candidate to be a join point.
|
||||
For a more detailed explanation see the documentation of `find`
|
||||
-/
|
||||
associated : HashSet FVarId
|
||||
associated : Std.HashSet FVarId
|
||||
deriving Inhabited
|
||||
|
||||
/--
|
||||
@@ -39,14 +39,14 @@ structure FindState where
|
||||
/--
|
||||
All current join point candidates accessible by their `FVarId`.
|
||||
-/
|
||||
candidates : HashMap FVarId CandidateInfo := .empty
|
||||
candidates : Std.HashMap FVarId CandidateInfo := .empty
|
||||
/--
|
||||
The `FVarId`s of all `fun` declarations that were declared within the
|
||||
current `fun`.
|
||||
-/
|
||||
scope : HashSet FVarId := .empty
|
||||
scope : Std.HashSet FVarId := .empty
|
||||
|
||||
abbrev ReplaceCtx := HashMap FVarId Name
|
||||
abbrev ReplaceCtx := Std.HashMap FVarId Name
|
||||
|
||||
abbrev FindM := ReaderT (Option FVarId) StateRefT FindState ScopeM
|
||||
abbrev ReplaceM := ReaderT ReplaceCtx CompilerM
|
||||
@@ -55,7 +55,7 @@ abbrev ReplaceM := ReaderT ReplaceCtx CompilerM
|
||||
Attempt to find a join point candidate by its `FVarId`.
|
||||
-/
|
||||
private def findCandidate? (fvarId : FVarId) : FindM (Option CandidateInfo) := do
|
||||
return (← get).candidates.find? fvarId
|
||||
return (← get).candidates[fvarId]?
|
||||
|
||||
/--
|
||||
Erase a join point candidate as well as all the ones that depend on it
|
||||
@@ -69,7 +69,7 @@ private partial def eraseCandidate (fvarId : FVarId) : FindM Unit := do
|
||||
/--
|
||||
Combinator for modifying the candidates in `FindM`.
|
||||
-/
|
||||
private def modifyCandidates (f : HashMap FVarId CandidateInfo → HashMap FVarId CandidateInfo) : FindM Unit :=
|
||||
private def modifyCandidates (f : Std.HashMap FVarId CandidateInfo → Std.HashMap FVarId CandidateInfo) : FindM Unit :=
|
||||
modify (fun state => {state with candidates := f state.candidates })
|
||||
|
||||
/--
|
||||
@@ -196,7 +196,7 @@ where
|
||||
return code
|
||||
| _, _ => return Code.updateLet! code decl (← go k)
|
||||
| .fun decl k =>
|
||||
if let some replacement := (← read).find? decl.fvarId then
|
||||
if let some replacement := (← read)[decl.fvarId]? then
|
||||
let newDecl := { decl with
|
||||
binderName := replacement,
|
||||
value := (← go decl.value)
|
||||
@@ -244,7 +244,7 @@ structure ExtendState where
|
||||
to `Param`s. The free variables in this map are the once that the context
|
||||
of said join point will be extended by by passing in the respective parameter.
|
||||
-/
|
||||
fvarMap : HashMap FVarId (HashMap FVarId Param) := {}
|
||||
fvarMap : Std.HashMap FVarId (Std.HashMap FVarId Param) := {}
|
||||
|
||||
/--
|
||||
The monad for the `extendJoinPointContext` pass.
|
||||
@@ -262,7 +262,7 @@ otherwise just return `fvar`.
|
||||
def replaceFVar (fvar : FVarId) : ExtendM FVarId := do
|
||||
if (← read).candidates.contains fvar then
|
||||
if let some currentJp := (← read).currentJp? then
|
||||
if let some replacement := (← get).fvarMap.find! currentJp |>.find? fvar then
|
||||
if let some replacement := (← get).fvarMap[currentJp]![fvar]? then
|
||||
return replacement.fvarId
|
||||
return fvar
|
||||
|
||||
@@ -313,7 +313,7 @@ This is necessary if:
|
||||
-/
|
||||
def extendByIfNecessary (fvar : FVarId) : ExtendM Unit := do
|
||||
if let some currentJp := (← read).currentJp? then
|
||||
let mut translator := (← get).fvarMap.find! currentJp
|
||||
let mut translator := (← get).fvarMap[currentJp]!
|
||||
let candidates := (← read).candidates
|
||||
if !(← isInScope fvar) && !translator.contains fvar && candidates.contains fvar then
|
||||
let typ ← getType fvar
|
||||
@@ -337,7 +337,7 @@ of `j.2` in `j.1`.
|
||||
-/
|
||||
def mergeJpContextIfNecessary (jp : FVarId) : ExtendM Unit := do
|
||||
if (← read).currentJp?.isSome then
|
||||
let additionalArgs := (← get).fvarMap.find! jp |>.toArray
|
||||
let additionalArgs := (← get).fvarMap[jp]!.toArray
|
||||
for (fvar, _) in additionalArgs do
|
||||
extendByIfNecessary fvar
|
||||
|
||||
@@ -405,7 +405,7 @@ where
|
||||
| .jp decl k =>
|
||||
let decl ← withNewJpScope decl do
|
||||
let value ← go decl.value
|
||||
let additionalParams := (← get).fvarMap.find! decl.fvarId |>.toArray |>.map Prod.snd
|
||||
let additionalParams := (← get).fvarMap[decl.fvarId]!.toArray |>.map Prod.snd
|
||||
let newType := additionalParams.foldr (init := decl.type) (fun val acc => .forallE val.binderName val.type acc .default)
|
||||
decl.update newType (additionalParams ++ decl.params) value
|
||||
mergeJpContextIfNecessary decl.fvarId
|
||||
@@ -426,7 +426,7 @@ where
|
||||
return Code.updateCases! code cs.resultType discr alts
|
||||
| .jmp fn args =>
|
||||
let mut newArgs ← args.mapM (mapFVarM goFVar)
|
||||
let additionalArgs := (← get).fvarMap.find! fn |>.toArray |>.map Prod.fst
|
||||
let additionalArgs := (← get).fvarMap[fn]!.toArray |>.map Prod.fst
|
||||
if let some _currentJp := (← read).currentJp? then
|
||||
let f := fun arg => do
|
||||
return .fvar (← goFVar arg)
|
||||
@@ -545,7 +545,7 @@ where
|
||||
if let some knownArgs := (← get).jpJmpArgs.find? fn then
|
||||
let mut newArgs := knownArgs
|
||||
for (param, arg) in decl.params.zip args do
|
||||
if let some knownVal := newArgs.find? param.fvarId then
|
||||
if let some knownVal := newArgs[param.fvarId]? then
|
||||
if arg.toExpr != knownVal then
|
||||
newArgs := newArgs.erase param.fvarId
|
||||
modify fun s => { s with jpJmpArgs := s.jpJmpArgs.insert fn newArgs }
|
||||
|
||||
@@ -13,9 +13,9 @@ namespace Lean.Compiler.LCNF
|
||||
LCNF local context.
|
||||
-/
|
||||
structure LCtx where
|
||||
params : HashMap FVarId Param := {}
|
||||
letDecls : HashMap FVarId LetDecl := {}
|
||||
funDecls : HashMap FVarId FunDecl := {}
|
||||
params : Std.HashMap FVarId Param := {}
|
||||
letDecls : Std.HashMap FVarId LetDecl := {}
|
||||
funDecls : Std.HashMap FVarId FunDecl := {}
|
||||
deriving Inhabited
|
||||
|
||||
def LCtx.addParam (lctx : LCtx) (param : Param) : LCtx :=
|
||||
|
||||
@@ -30,7 +30,7 @@ structure State where
|
||||
/-- Counter for generating new (normalized) universe parameter names. -/
|
||||
nextIdx : Nat := 1
|
||||
/-- Mapping from existing universe parameter names to the new ones. -/
|
||||
map : HashMap Name Level := {}
|
||||
map : Std.HashMap Name Level := {}
|
||||
/-- Parameters that have been normalized. -/
|
||||
paramNames : Array Name := #[]
|
||||
|
||||
@@ -49,7 +49,7 @@ partial def normLevel (u : Level) : M Level := do
|
||||
| .max v w => return u.updateMax! (← normLevel v) (← normLevel w)
|
||||
| .imax v w => return u.updateIMax! (← normLevel v) (← normLevel w)
|
||||
| .mvar _ => unreachable!
|
||||
| .param n => match (← get).map.find? n with
|
||||
| .param n => match (← get).map[n]? with
|
||||
| some u => return u
|
||||
| none =>
|
||||
let u := Level.param <| (`u).appendIndexAfter (← get).nextIdx
|
||||
|
||||
@@ -31,9 +31,9 @@ def sortedBySize : Probe Decl (Nat × Decl) := fun decls =>
|
||||
if sz₁ == sz₂ then Name.lt decl₁.name decl₂.name else sz₁ < sz₂
|
||||
|
||||
def countUnique [ToString α] [BEq α] [Hashable α] : Probe α (α × Nat) := fun data => do
|
||||
let mut map := HashMap.empty
|
||||
let mut map := Std.HashMap.empty
|
||||
for d in data do
|
||||
if let some count := map.find? d then
|
||||
if let some count := map[d]? then
|
||||
map := map.insert d (count + 1)
|
||||
else
|
||||
map := map.insert d 1
|
||||
|
||||
@@ -40,7 +40,7 @@ structure FunDeclInfoMap where
|
||||
/--
|
||||
Mapping from local function name to inlining information.
|
||||
-/
|
||||
map : HashMap FVarId FunDeclInfo := {}
|
||||
map : Std.HashMap FVarId FunDeclInfo := {}
|
||||
deriving Inhabited
|
||||
|
||||
def FunDeclInfoMap.format (s : FunDeclInfoMap) : CompilerM Format := do
|
||||
@@ -56,7 +56,7 @@ Add new occurrence for the local function with binder name `key`.
|
||||
def FunDeclInfoMap.add (s : FunDeclInfoMap) (fvarId : FVarId) : FunDeclInfoMap :=
|
||||
match s with
|
||||
| { map } =>
|
||||
match map.find? fvarId with
|
||||
match map[fvarId]? with
|
||||
| some .once => { map := map.insert fvarId .many }
|
||||
| none => { map := map.insert fvarId .once }
|
||||
| _ => { map }
|
||||
@@ -67,7 +67,7 @@ Add new occurrence for the local function occurring as an argument for another f
|
||||
def FunDeclInfoMap.addHo (s : FunDeclInfoMap) (fvarId : FVarId) : FunDeclInfoMap :=
|
||||
match s with
|
||||
| { map } =>
|
||||
match map.find? fvarId with
|
||||
match map[fvarId]? with
|
||||
| some .once | none => { map := map.insert fvarId .many }
|
||||
| _ => { map }
|
||||
|
||||
|
||||
@@ -173,7 +173,7 @@ Execute `x` with `fvarId` set as `mustInline`.
|
||||
After execution the original setting is restored.
|
||||
-/
|
||||
def withAddMustInline (fvarId : FVarId) (x : SimpM α) : SimpM α := do
|
||||
let saved? := (← get).funDeclInfoMap.map.find? fvarId
|
||||
let saved? := (← get).funDeclInfoMap.map[fvarId]?
|
||||
try
|
||||
addMustInline fvarId
|
||||
x
|
||||
@@ -185,7 +185,7 @@ Return true if the given local function declaration or join point id is marked a
|
||||
`once` or `mustInline`. We use this information to decide whether to inline them.
|
||||
-/
|
||||
def isOnceOrMustInline (fvarId : FVarId) : SimpM Bool := do
|
||||
match (← get).funDeclInfoMap.map.find? fvarId with
|
||||
match (← get).funDeclInfoMap.map[fvarId]? with
|
||||
| some .once | some .mustInline => return true
|
||||
| _ => return false
|
||||
|
||||
|
||||
@@ -199,9 +199,9 @@ structure State where
|
||||
/-- Cache from Lean regular expression to LCNF argument. -/
|
||||
cache : PHashMap Expr Arg := {}
|
||||
/-- `toLCNFType` cache -/
|
||||
typeCache : HashMap Expr Expr := {}
|
||||
typeCache : Std.HashMap Expr Expr := {}
|
||||
/-- isTypeFormerType cache -/
|
||||
isTypeFormerTypeCache : HashMap Expr Bool := {}
|
||||
isTypeFormerTypeCache : Std.HashMap Expr Bool := {}
|
||||
/-- LCNF sequence, we chain it to create a LCNF `Code` object. -/
|
||||
seq : Array Element := #[]
|
||||
/--
|
||||
@@ -257,7 +257,7 @@ private partial def isTypeFormerType (type : Expr) : M Bool := do
|
||||
| .true => return true
|
||||
| .false => return false
|
||||
| .undef =>
|
||||
if let some result := (← get).isTypeFormerTypeCache.find? type then
|
||||
if let some result := (← get).isTypeFormerTypeCache[type]? then
|
||||
return result
|
||||
let result ← liftMetaM <| Meta.isTypeFormerType type
|
||||
modify fun s => { s with isTypeFormerTypeCache := s.isTypeFormerTypeCache.insert type result }
|
||||
@@ -305,7 +305,7 @@ def applyToAny (type : Expr) : M Expr := do
|
||||
| _ => none
|
||||
|
||||
def toLCNFType (type : Expr) : M Expr := do
|
||||
match (← get).typeCache.find? type with
|
||||
match (← get).typeCache[type]? with
|
||||
| some type' => return type'
|
||||
| none =>
|
||||
let type' ← liftMetaM <| LCNF.toLCNFType type
|
||||
|
||||
@@ -270,16 +270,3 @@ def ofListWith (l : List (α × β)) (f : β → β → β) : HashMap α β :=
|
||||
| some v => m.insert p.fst $ f v p.snd)
|
||||
|
||||
end Lean.HashMap
|
||||
|
||||
/--
|
||||
Groups all elements `x`, `y` in `xs` with `key x == key y` into the same array
|
||||
`(xs.groupByKey key).find! (key x)`. Groups preserve the relative order of elements in `xs`.
|
||||
-/
|
||||
def Array.groupByKey [BEq α] [Hashable α] (key : β → α) (xs : Array β)
|
||||
: Lean.HashMap α (Array β) := Id.run do
|
||||
let mut groups := ∅
|
||||
for x in xs do
|
||||
let group := groups.findD (key x) #[]
|
||||
groups := groups.erase (key x) -- make `group` referentially unique
|
||||
groups := groups.insert (key x) (group.push x)
|
||||
return groups
|
||||
|
||||
@@ -12,10 +12,11 @@ import Lean.Data.RBMap
|
||||
namespace Lean.Json.Parser
|
||||
|
||||
open Lean.Parsec
|
||||
open Lean.Parsec.String
|
||||
|
||||
@[inline]
|
||||
def hexChar : Parsec Nat := do
|
||||
let c ← anyChar
|
||||
def hexChar : Parser Nat := do
|
||||
let c ← any
|
||||
if '0' ≤ c ∧ c ≤ '9' then
|
||||
pure $ c.val.toNat - '0'.val.toNat
|
||||
else if 'a' ≤ c ∧ c ≤ 'f' then
|
||||
@@ -25,8 +26,8 @@ def hexChar : Parsec Nat := do
|
||||
else
|
||||
fail "invalid hex character"
|
||||
|
||||
def escapedChar : Parsec Char := do
|
||||
let c ← anyChar
|
||||
def escapedChar : Parser Char := do
|
||||
let c ← any
|
||||
match c with
|
||||
| '\\' => return '\\'
|
||||
| '"' => return '"'
|
||||
@@ -41,13 +42,13 @@ def escapedChar : Parsec Char := do
|
||||
return Char.ofNat $ 4096*u1 + 256*u2 + 16*u3 + u4
|
||||
| _ => fail "illegal \\u escape"
|
||||
|
||||
partial def strCore (acc : String) : Parsec String := do
|
||||
partial def strCore (acc : String) : Parser String := do
|
||||
let c ← peek!
|
||||
if c = '"' then -- "
|
||||
skip
|
||||
return acc
|
||||
else
|
||||
let c ← anyChar
|
||||
let c ← any
|
||||
if c = '\\' then
|
||||
strCore (acc.push (← escapedChar))
|
||||
-- as to whether c.val > 0xffff should be split up and encoded with multiple \u,
|
||||
@@ -58,9 +59,9 @@ partial def strCore (acc : String) : Parsec String := do
|
||||
else
|
||||
fail "unexpected character in string"
|
||||
|
||||
def str : Parsec String := strCore ""
|
||||
def str : Parser String := strCore ""
|
||||
|
||||
partial def natCore (acc digits : Nat) : Parsec (Nat × Nat) := do
|
||||
partial def natCore (acc digits : Nat) : Parser (Nat × Nat) := do
|
||||
let some c ← peek? | return (acc, digits)
|
||||
if '0' ≤ c ∧ c ≤ '9' then
|
||||
skip
|
||||
@@ -70,7 +71,7 @@ partial def natCore (acc digits : Nat) : Parsec (Nat × Nat) := do
|
||||
return (acc, digits)
|
||||
|
||||
@[inline]
|
||||
def lookahead (p : Char → Prop) (desc : String) [DecidablePred p] : Parsec Unit := do
|
||||
def lookahead (p : Char → Prop) (desc : String) [DecidablePred p] : Parser Unit := do
|
||||
let c ← peek!
|
||||
if p c then
|
||||
return ()
|
||||
@@ -78,22 +79,22 @@ def lookahead (p : Char → Prop) (desc : String) [DecidablePred p] : Parsec Uni
|
||||
fail <| "expected " ++ desc
|
||||
|
||||
@[inline]
|
||||
def natNonZero : Parsec Nat := do
|
||||
def natNonZero : Parser Nat := do
|
||||
lookahead (fun c => '1' ≤ c ∧ c ≤ '9') "1-9"
|
||||
let (n, _) ← natCore 0 0
|
||||
return n
|
||||
|
||||
@[inline]
|
||||
def natNumDigits : Parsec (Nat × Nat) := do
|
||||
def natNumDigits : Parser (Nat × Nat) := do
|
||||
lookahead (fun c => '0' ≤ c ∧ c ≤ '9') "digit"
|
||||
natCore 0 0
|
||||
|
||||
@[inline]
|
||||
def natMaybeZero : Parsec Nat := do
|
||||
def natMaybeZero : Parser Nat := do
|
||||
let (n, _) ← natNumDigits
|
||||
return n
|
||||
|
||||
def num : Parsec JsonNumber := do
|
||||
def num : Parser JsonNumber := do
|
||||
let c ← peek!
|
||||
let sign ← if c = '-' then
|
||||
skip
|
||||
@@ -132,10 +133,10 @@ def num : Parsec JsonNumber := do
|
||||
else
|
||||
return res
|
||||
|
||||
partial def arrayCore (anyCore : Parsec Json) (acc : Array Json) : Parsec (Array Json) := do
|
||||
partial def arrayCore (anyCore : Parser Json) (acc : Array Json) : Parser (Array Json) := do
|
||||
let hd ← anyCore
|
||||
let acc' := acc.push hd
|
||||
let c ← anyChar
|
||||
let c ← any
|
||||
if c = ']' then
|
||||
ws
|
||||
return acc'
|
||||
@@ -145,12 +146,12 @@ partial def arrayCore (anyCore : Parsec Json) (acc : Array Json) : Parsec (Array
|
||||
else
|
||||
fail "unexpected character in array"
|
||||
|
||||
partial def objectCore (anyCore : Parsec Json) : Parsec (RBNode String (fun _ => Json)) := do
|
||||
partial def objectCore (anyCore : Parser Json) : Parser (RBNode String (fun _ => Json)) := do
|
||||
lookahead (fun c => c = '"') "\""; skip; -- "
|
||||
let k ← strCore ""; ws
|
||||
lookahead (fun c => c = ':') ":"; skip; ws
|
||||
let v ← anyCore
|
||||
let c ← anyChar
|
||||
let c ← any
|
||||
if c = '}' then
|
||||
ws
|
||||
return RBNode.singleton k v
|
||||
@@ -161,7 +162,7 @@ partial def objectCore (anyCore : Parsec Json) : Parsec (RBNode String (fun _ =>
|
||||
else
|
||||
fail "unexpected character in object"
|
||||
|
||||
partial def anyCore : Parsec Json := do
|
||||
partial def anyCore : Parser Json := do
|
||||
let c ← peek!
|
||||
if c = '[' then
|
||||
skip; ws
|
||||
@@ -203,7 +204,7 @@ partial def anyCore : Parsec Json := do
|
||||
fail "unexpected input"
|
||||
|
||||
|
||||
def any : Parsec Json := do
|
||||
def any : Parser Json := do
|
||||
ws
|
||||
let res ← anyCore
|
||||
eof
|
||||
|
||||
@@ -150,7 +150,7 @@ instance : FromJson RefInfo where
|
||||
pure { definition?, usages }
|
||||
|
||||
/-- References from a single module/file -/
|
||||
def ModuleRefs := HashMap RefIdent RefInfo
|
||||
def ModuleRefs := Std.HashMap RefIdent RefInfo
|
||||
|
||||
instance : ToJson ModuleRefs where
|
||||
toJson m := Json.mkObj <| m.toList.map fun (ident, info) => (ident.toJson.compress, toJson info)
|
||||
@@ -158,7 +158,7 @@ instance : ToJson ModuleRefs where
|
||||
instance : FromJson ModuleRefs where
|
||||
fromJson? j := do
|
||||
let node ← j.getObj?
|
||||
node.foldM (init := HashMap.empty) fun m k v =>
|
||||
node.foldM (init := Std.HashMap.empty) fun m k v =>
|
||||
return m.insert (← RefIdent.fromJson? (← Json.parse k)) (← fromJson? v)
|
||||
|
||||
/--
|
||||
|
||||
@@ -4,7 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Author: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.Data.HashSet
|
||||
import Std.Data.HashSet.Basic
|
||||
import Lean.Data.RBMap
|
||||
import Lean.Data.RBTree
|
||||
import Lean.Data.SSet
|
||||
@@ -64,14 +64,14 @@ abbrev insert (s : NameSSet) (n : Name) : NameSSet := SSet.insert s n
|
||||
abbrev contains (s : NameSSet) (n : Name) : Bool := SSet.contains s n
|
||||
end NameSSet
|
||||
|
||||
def NameHashSet := HashSet Name
|
||||
def NameHashSet := Std.HashSet Name
|
||||
|
||||
namespace NameHashSet
|
||||
@[inline] def empty : NameHashSet := HashSet.empty
|
||||
@[inline] def empty : NameHashSet := Std.HashSet.empty
|
||||
instance : EmptyCollection NameHashSet := ⟨empty⟩
|
||||
instance : Inhabited NameHashSet := ⟨{}⟩
|
||||
def insert (s : NameHashSet) (n : Name) := HashSet.insert s n
|
||||
def contains (s : NameHashSet) (n : Name) : Bool := HashSet.contains s n
|
||||
def insert (s : NameHashSet) (n : Name) := Std.HashSet.insert s n
|
||||
def contains (s : NameHashSet) (n : Name) : Bool := Std.HashSet.contains s n
|
||||
end NameHashSet
|
||||
|
||||
def MacroScopesView.isPrefixOf (v₁ v₂ : MacroScopesView) : Bool :=
|
||||
|
||||
@@ -4,181 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Author: Dany Fabian
|
||||
-/
|
||||
prelude
|
||||
import Init.NotationExtra
|
||||
import Init.Data.ToString.Macro
|
||||
|
||||
namespace Lean
|
||||
|
||||
namespace Parsec
|
||||
inductive ParseResult (α : Type) where
|
||||
| success (pos : String.Iterator) (res : α)
|
||||
| error (pos : String.Iterator) (err : String)
|
||||
deriving Repr
|
||||
end Parsec
|
||||
|
||||
def Parsec (α : Type) : Type := String.Iterator → Lean.Parsec.ParseResult α
|
||||
|
||||
namespace Parsec
|
||||
|
||||
open ParseResult
|
||||
|
||||
instance (α : Type) : Inhabited (Parsec α) :=
|
||||
⟨λ it => error it ""⟩
|
||||
|
||||
@[inline]
|
||||
protected def pure (a : α) : Parsec α := λ it =>
|
||||
success it a
|
||||
|
||||
@[inline]
|
||||
def bind {α β : Type} (f : Parsec α) (g : α → Parsec β) : Parsec β := λ it =>
|
||||
match f it with
|
||||
| success rem a => g a rem
|
||||
| error pos msg => error pos msg
|
||||
|
||||
instance : Monad Parsec :=
|
||||
{ pure := Parsec.pure, bind }
|
||||
|
||||
@[inline]
|
||||
def fail (msg : String) : Parsec α := fun it =>
|
||||
error it msg
|
||||
|
||||
@[inline]
|
||||
def tryCatch (p : Parsec α)
|
||||
(csuccess : α → Parsec β)
|
||||
(cerror : Unit → Parsec β)
|
||||
: Parsec β := fun it =>
|
||||
match p it with
|
||||
| .success rem a => csuccess a rem
|
||||
| .error rem err =>
|
||||
-- We assume that it.s never changes as the `Parsec` monad only modifies `it.pos`.
|
||||
if it.pos = rem.pos then cerror () rem else .error rem err
|
||||
|
||||
@[inline]
|
||||
def orElse (p : Parsec α) (q : Unit → Parsec α) : Parsec α :=
|
||||
tryCatch p pure q
|
||||
|
||||
@[inline]
|
||||
def attempt (p : Parsec α) : Parsec α := λ it =>
|
||||
match p it with
|
||||
| success rem res => success rem res
|
||||
| error _ err => error it err
|
||||
|
||||
instance : Alternative Parsec :=
|
||||
{ failure := fail "", orElse }
|
||||
|
||||
protected def run (p : Parsec α) (s : String) : Except String α :=
|
||||
match p s.mkIterator with
|
||||
| Parsec.ParseResult.success _ res => Except.ok res
|
||||
| Parsec.ParseResult.error it err => Except.error s!"offset {repr it.i.byteIdx}: {err}"
|
||||
|
||||
def expectedEndOfInput := "expected end of input"
|
||||
|
||||
@[inline]
|
||||
def eof : Parsec Unit := fun it =>
|
||||
if it.hasNext then
|
||||
error it expectedEndOfInput
|
||||
else
|
||||
success it ()
|
||||
|
||||
@[specialize]
|
||||
partial def manyCore (p : Parsec α) (acc : Array α) : Parsec $ Array α :=
|
||||
tryCatch p (manyCore p <| acc.push ·) (fun _ => pure acc)
|
||||
|
||||
@[inline]
|
||||
def many (p : Parsec α) : Parsec $ Array α := manyCore p #[]
|
||||
|
||||
@[inline]
|
||||
def many1 (p : Parsec α) : Parsec $ Array α := do manyCore p #[←p]
|
||||
|
||||
@[specialize]
|
||||
partial def manyCharsCore (p : Parsec Char) (acc : String) : Parsec String :=
|
||||
tryCatch p (manyCharsCore p <| acc.push ·) (fun _ => pure acc)
|
||||
|
||||
@[inline]
|
||||
def manyChars (p : Parsec Char) : Parsec String := manyCharsCore p ""
|
||||
|
||||
@[inline]
|
||||
def many1Chars (p : Parsec Char) : Parsec String := do manyCharsCore p (←p).toString
|
||||
|
||||
/-- Parses the given string. -/
|
||||
def pstring (s : String) : Parsec String := λ it =>
|
||||
let substr := it.extract (it.forward s.length)
|
||||
if substr = s then
|
||||
success (it.forward s.length) substr
|
||||
else
|
||||
error it s!"expected: {s}"
|
||||
|
||||
@[inline]
|
||||
def skipString (s : String) : Parsec Unit := pstring s *> pure ()
|
||||
|
||||
def unexpectedEndOfInput := "unexpected end of input"
|
||||
|
||||
@[inline]
|
||||
def anyChar : Parsec Char := λ it =>
|
||||
if it.hasNext then success it.next it.curr else error it unexpectedEndOfInput
|
||||
|
||||
@[inline]
|
||||
def pchar (c : Char) : Parsec Char := attempt do
|
||||
if (←anyChar) = c then pure c else fail s!"expected: '{c}'"
|
||||
|
||||
@[inline]
|
||||
def skipChar (c : Char) : Parsec Unit := pchar c *> pure ()
|
||||
|
||||
@[inline]
|
||||
def digit : Parsec Char := attempt do
|
||||
let c ← anyChar
|
||||
if '0' ≤ c ∧ c ≤ '9' then return c else fail s!"digit expected"
|
||||
|
||||
@[inline]
|
||||
def hexDigit : Parsec Char := attempt do
|
||||
let c ← anyChar
|
||||
if ('0' ≤ c ∧ c ≤ '9')
|
||||
∨ ('a' ≤ c ∧ c ≤ 'f')
|
||||
∨ ('A' ≤ c ∧ c ≤ 'F') then return c else fail s!"hex digit expected"
|
||||
|
||||
@[inline]
|
||||
def asciiLetter : Parsec Char := attempt do
|
||||
let c ← anyChar
|
||||
if ('A' ≤ c ∧ c ≤ 'Z') ∨ ('a' ≤ c ∧ c ≤ 'z') then return c else fail s!"ASCII letter expected"
|
||||
|
||||
@[inline]
|
||||
def satisfy (p : Char → Bool) : Parsec Char := attempt do
|
||||
let c ← anyChar
|
||||
if p c then return c else fail "condition not satisfied"
|
||||
|
||||
@[inline]
|
||||
def notFollowedBy (p : Parsec α) : Parsec Unit := λ it =>
|
||||
match p it with
|
||||
| success _ _ => error it ""
|
||||
| error _ _ => success it ()
|
||||
|
||||
partial def skipWs (it : String.Iterator) : String.Iterator :=
|
||||
if it.hasNext then
|
||||
let c := it.curr
|
||||
if c = '\u0009' ∨ c = '\u000a' ∨ c = '\u000d' ∨ c = '\u0020' then
|
||||
skipWs it.next
|
||||
else
|
||||
it
|
||||
else
|
||||
it
|
||||
|
||||
@[inline]
|
||||
def peek? : Parsec (Option Char) := fun it =>
|
||||
if it.hasNext then
|
||||
success it it.curr
|
||||
else
|
||||
success it none
|
||||
|
||||
@[inline]
|
||||
def peek! : Parsec Char := do
|
||||
let some c ← peek? | fail unexpectedEndOfInput
|
||||
return c
|
||||
|
||||
@[inline]
|
||||
def skip : Parsec Unit := fun it =>
|
||||
success it.next ()
|
||||
|
||||
@[inline]
|
||||
def ws : Parsec Unit := fun it =>
|
||||
success (skipWs it) ()
|
||||
end Parsec
|
||||
import Lean.Data.Parsec.Basic
|
||||
import Lean.Data.Parsec.String
|
||||
import Lean.Data.Parsec.ByteArray
|
||||
|
||||
144
src/Lean/Data/Parsec/Basic.lean
Normal file
144
src/Lean/Data/Parsec/Basic.lean
Normal file
@@ -0,0 +1,144 @@
|
||||
/-
|
||||
Copyright (c) 2021 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Author: Dany Fabian, Henrik Böving
|
||||
-/
|
||||
prelude
|
||||
import Init.NotationExtra
|
||||
import Init.Data.ToString.Macro
|
||||
|
||||
namespace Lean
|
||||
|
||||
namespace Parsec
|
||||
|
||||
inductive ParseResult (α : Type) (ι : Type) where
|
||||
| success (pos : ι) (res : α)
|
||||
| error (pos : ι) (err : String)
|
||||
deriving Repr
|
||||
|
||||
end Parsec
|
||||
|
||||
def Parsec (ι : Type) (α : Type) : Type := ι → Lean.Parsec.ParseResult α ι
|
||||
|
||||
namespace Parsec
|
||||
|
||||
class Input (ι : Type) (elem : outParam Type) (idx : outParam Type) [DecidableEq idx] [DecidableEq elem] where
|
||||
pos : ι → idx
|
||||
next : ι → ι
|
||||
curr : ι → elem
|
||||
hasNext : ι → Bool
|
||||
|
||||
variable {α : Type} {ι : Type} {elem : Type} {idx : Type}
|
||||
variable [DecidableEq idx] [DecidableEq elem] [Input ι elem idx]
|
||||
|
||||
instance : Inhabited (Parsec ι α) where
|
||||
default := fun it => .error it ""
|
||||
|
||||
@[inline]
|
||||
protected def pure (a : α) : Parsec ι α := fun it =>
|
||||
.success it a
|
||||
|
||||
@[inline]
|
||||
def bind {α β : Type} (f : Parsec ι α) (g : α → Parsec ι β) : Parsec ι β := fun it =>
|
||||
match f it with
|
||||
| .success rem a => g a rem
|
||||
| .error pos msg => .error pos msg
|
||||
|
||||
instance : Monad (Parsec ι) where
|
||||
pure := Parsec.pure
|
||||
bind := Parsec.bind
|
||||
|
||||
@[inline]
|
||||
def fail (msg : String) : Parsec ι α := fun it =>
|
||||
.error it msg
|
||||
|
||||
@[inline]
|
||||
def tryCatch (p : Parsec ι α) (csuccess : α → Parsec ι β) (cerror : Unit → Parsec ι β)
|
||||
: Parsec ι β := fun it =>
|
||||
match p it with
|
||||
| .success rem a => csuccess a rem
|
||||
| .error rem err =>
|
||||
-- We assume that it.s never changes as the `Parsec` monad only modifies `it.pos`.
|
||||
if Input.pos it = Input.pos rem then cerror () rem else .error rem err
|
||||
|
||||
@[inline]
|
||||
def orElse (p : Parsec ι α) (q : Unit → Parsec ι α) : Parsec ι α :=
|
||||
tryCatch p pure q
|
||||
|
||||
@[inline]
|
||||
def attempt (p : Parsec ι α) : Parsec ι α := fun it =>
|
||||
match p it with
|
||||
| .success rem res => .success rem res
|
||||
| .error _ err => .error it err
|
||||
|
||||
instance : Alternative (Parsec ι) where
|
||||
failure := fail ""
|
||||
orElse := orElse
|
||||
|
||||
def expectedEndOfInput := "expected end of input"
|
||||
|
||||
@[inline]
|
||||
def eof : Parsec ι Unit := fun it =>
|
||||
if Input.hasNext it then
|
||||
.error it expectedEndOfInput
|
||||
else
|
||||
.success it ()
|
||||
|
||||
@[specialize]
|
||||
partial def manyCore (p : Parsec ι α) (acc : Array α) : Parsec ι <| Array α :=
|
||||
tryCatch p (manyCore p <| acc.push ·) (fun _ => pure acc)
|
||||
|
||||
@[inline]
|
||||
def many (p : Parsec ι α) : Parsec ι <| Array α := manyCore p #[]
|
||||
|
||||
@[inline]
|
||||
def many1 (p : Parsec ι α) : Parsec ι <| Array α := do manyCore p #[← p]
|
||||
|
||||
def unexpectedEndOfInput := "unexpected end of input"
|
||||
|
||||
@[inline]
|
||||
def any : Parsec ι elem := fun it =>
|
||||
if Input.hasNext it then
|
||||
.success (Input.next it) (Input.curr it)
|
||||
else
|
||||
.error it unexpectedEndOfInput
|
||||
|
||||
@[inline]
|
||||
def satisfy (p : elem → Bool) : Parsec ι elem := attempt do
|
||||
let c ← any
|
||||
if p c then return c else fail "condition not satisfied"
|
||||
|
||||
@[inline]
|
||||
def notFollowedBy (p : Parsec ι α) : Parsec ι Unit := fun it =>
|
||||
match p it with
|
||||
| .success _ _ => .error it ""
|
||||
| .error _ _ => .success it ()
|
||||
|
||||
@[inline]
|
||||
def peek? : Parsec ι (Option elem) := fun it =>
|
||||
if Input.hasNext it then
|
||||
.success it (Input.curr it)
|
||||
else
|
||||
.success it none
|
||||
|
||||
@[inline]
|
||||
def peek! : Parsec ι elem := do
|
||||
let some c ← peek? | fail unexpectedEndOfInput
|
||||
return c
|
||||
|
||||
@[inline]
|
||||
def skip : Parsec ι Unit := fun it =>
|
||||
.success (Input.next it) ()
|
||||
|
||||
@[specialize]
|
||||
partial def manyCharsCore (p : Parsec ι Char) (acc : String) : Parsec ι String :=
|
||||
tryCatch p (manyCharsCore p <| acc.push ·) (fun _ => pure acc)
|
||||
|
||||
@[inline]
|
||||
def manyChars (p : Parsec ι Char) : Parsec ι String := manyCharsCore p ""
|
||||
|
||||
@[inline]
|
||||
def many1Chars (p : Parsec ι Char) : Parsec ι String := do manyCharsCore p (← p).toString
|
||||
|
||||
|
||||
end Parsec
|
||||
103
src/Lean/Data/Parsec/ByteArray.lean
Normal file
103
src/Lean/Data/Parsec/ByteArray.lean
Normal file
@@ -0,0 +1,103 @@
|
||||
/-
|
||||
Copyright (c) 2024 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Henrik Böving
|
||||
-/
|
||||
prelude
|
||||
import Lean.Data.Parsec.Basic
|
||||
import Init.Data.ByteArray.Basic
|
||||
import Init.Data.String.Extra
|
||||
|
||||
namespace Lean
|
||||
namespace Parsec
|
||||
namespace ByteArray
|
||||
|
||||
instance : Input ByteArray.Iterator UInt8 Nat where
|
||||
pos it := it.pos
|
||||
next it := it.next
|
||||
curr it := it.curr
|
||||
hasNext it := it.hasNext
|
||||
|
||||
abbrev Parser (α : Type) : Type := Parsec ByteArray.Iterator α
|
||||
|
||||
protected def Parser.run (p : Parser α) (arr : ByteArray) : Except String α :=
|
||||
match p arr.iter with
|
||||
| .success _ res => Except.ok res
|
||||
| .error it err => Except.error s!"offset {repr it.pos}: {err}"
|
||||
|
||||
@[inline]
|
||||
def pbyte (b : UInt8) : Parser UInt8 := attempt do
|
||||
if (← any) = b then pure b else fail s!"expected: '{b}'"
|
||||
|
||||
@[inline]
|
||||
def skipByte (b : UInt8) : Parser Unit := pbyte b *> pure ()
|
||||
|
||||
def skipBytes (arr : ByteArray) : Parser Unit := do
|
||||
for b in arr do
|
||||
skipByte b
|
||||
|
||||
@[inline]
|
||||
def pstring (s : String) : Parser String := do
|
||||
skipBytes s.toUTF8
|
||||
return s
|
||||
|
||||
@[inline]
|
||||
def skipString (s : String) : Parser Unit := pstring s *> pure ()
|
||||
|
||||
/--
|
||||
Parse a `Char` that can be represented in 1 byte. If `c` uses more than 1 byte it is truncated.
|
||||
-/
|
||||
@[inline]
|
||||
def pByteChar (c : Char) : Parser Char := attempt do
|
||||
if (← any) = c.toUInt8 then pure c else fail s!"expected: '{c}'"
|
||||
|
||||
/--
|
||||
Skip a `Char` that can be represented in 1 byte. If `c` uses more than 1 byte it is truncated.
|
||||
-/
|
||||
@[inline]
|
||||
def skipByteChar (c : Char) : Parser Unit := skipByte c.toUInt8
|
||||
|
||||
@[inline]
|
||||
def digit : Parser Char := attempt do
|
||||
let b ← any
|
||||
if '0'.toUInt8 ≤ b ∧ b ≤ '9'.toUInt8 then return Char.ofUInt8 b else fail s!"digit expected"
|
||||
|
||||
@[inline]
|
||||
def hexDigit : Parser Char := attempt do
|
||||
let b ← any
|
||||
if ('0'.toUInt8 ≤ b ∧ b ≤ '9'.toUInt8)
|
||||
∨ ('a'.toUInt8 ≤ b ∧ b ≤ 'f'.toUInt8)
|
||||
∨ ('A'.toUInt8 ≤ b ∧ b ≤ 'F'.toUInt8) then return Char.ofUInt8 b else fail s!"hex digit expected"
|
||||
|
||||
@[inline]
|
||||
def asciiLetter : Parser Char := attempt do
|
||||
let b ← any
|
||||
if ('A'.toUInt8 ≤ b ∧ b ≤ 'Z'.toUInt8) ∨ ('a'.toUInt8 ≤ b ∧ b ≤ 'z'.toUInt8) then
|
||||
return Char.ofUInt8 b
|
||||
else
|
||||
fail s!"ASCII letter expected"
|
||||
|
||||
private partial def skipWs (it : ByteArray.Iterator) : ByteArray.Iterator :=
|
||||
if it.hasNext then
|
||||
let b := it.curr
|
||||
if b = '\u0009'.toUInt8 ∨ b = '\u000a'.toUInt8 ∨ b = '\u000d'.toUInt8 ∨ b = '\u0020'.toUInt8 then
|
||||
skipWs it.next
|
||||
else
|
||||
it
|
||||
else
|
||||
it
|
||||
|
||||
@[inline]
|
||||
def ws : Parser Unit := fun it =>
|
||||
.success (skipWs it) ()
|
||||
|
||||
def take (n : Nat) : Parser ByteArray := fun it =>
|
||||
let subarr := it.array.extract it.idx (it.idx + n)
|
||||
if subarr.size != n then
|
||||
.error it s!"expected: {n} bytes"
|
||||
else
|
||||
.success (it.forward n) subarr
|
||||
|
||||
end ByteArray
|
||||
end Parsec
|
||||
end Lean
|
||||
84
src/Lean/Data/Parsec/String.lean
Normal file
84
src/Lean/Data/Parsec/String.lean
Normal file
@@ -0,0 +1,84 @@
|
||||
/-
|
||||
Copyright (c) 2021 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Author: Dany Fabian, Henrik Böving
|
||||
-/
|
||||
prelude
|
||||
import Lean.Data.Parsec.Basic
|
||||
|
||||
namespace Lean
|
||||
namespace Parsec
|
||||
namespace String
|
||||
|
||||
instance : Input String.Iterator Char String.Pos where
|
||||
pos it := it.pos
|
||||
next it := it.next
|
||||
curr it := it.curr
|
||||
hasNext it := it.hasNext
|
||||
|
||||
abbrev Parser (α : Type) : Type := Parsec String.Iterator α
|
||||
|
||||
protected def Parser.run (p : Parser α) (s : String) : Except String α :=
|
||||
match p s.mkIterator with
|
||||
| .success _ res => Except.ok res
|
||||
| .error it err => Except.error s!"offset {repr it.i.byteIdx}: {err}"
|
||||
|
||||
/-- Parses the given string. -/
|
||||
def pstring (s : String) : Parser String := fun it =>
|
||||
let substr := it.extract (it.forward s.length)
|
||||
if substr = s then
|
||||
.success (it.forward s.length) substr
|
||||
else
|
||||
.error it s!"expected: {s}"
|
||||
|
||||
@[inline]
|
||||
def skipString (s : String) : Parser Unit := pstring s *> pure ()
|
||||
|
||||
@[inline]
|
||||
def pchar (c : Char) : Parser Char := attempt do
|
||||
if (← any) = c then pure c else fail s!"expected: '{c}'"
|
||||
|
||||
@[inline]
|
||||
def skipChar (c : Char) : Parser Unit := pchar c *> pure ()
|
||||
|
||||
@[inline]
|
||||
def digit : Parser Char := attempt do
|
||||
let c ← any
|
||||
if '0' ≤ c ∧ c ≤ '9' then return c else fail s!"digit expected"
|
||||
|
||||
@[inline]
|
||||
def hexDigit : Parser Char := attempt do
|
||||
let c ← any
|
||||
if ('0' ≤ c ∧ c ≤ '9')
|
||||
∨ ('a' ≤ c ∧ c ≤ 'f')
|
||||
∨ ('A' ≤ c ∧ c ≤ 'F') then return c else fail s!"hex digit expected"
|
||||
|
||||
@[inline]
|
||||
def asciiLetter : Parser Char := attempt do
|
||||
let c ← any
|
||||
if ('A' ≤ c ∧ c ≤ 'Z') ∨ ('a' ≤ c ∧ c ≤ 'z') then return c else fail s!"ASCII letter expected"
|
||||
|
||||
private partial def skipWs (it : String.Iterator) : String.Iterator :=
|
||||
if it.hasNext then
|
||||
let c := it.curr
|
||||
if c = '\u0009' ∨ c = '\u000a' ∨ c = '\u000d' ∨ c = '\u0020' then
|
||||
skipWs it.next
|
||||
else
|
||||
it
|
||||
else
|
||||
it
|
||||
|
||||
@[inline]
|
||||
def ws : Parser Unit := fun it =>
|
||||
.success (skipWs it) ()
|
||||
|
||||
def take (n : Nat) : Parser String := fun it =>
|
||||
let substr := it.extract (it.forward n)
|
||||
if substr.length != n then
|
||||
.error it s!"expected: {n} codepoints"
|
||||
else
|
||||
.success (it.forward n) substr
|
||||
|
||||
end String
|
||||
end Parsec
|
||||
end Lean
|
||||
@@ -4,7 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.Data.HashMap
|
||||
import Std.Data.HashMap.Basic
|
||||
import Lean.Data.PersistentHashMap
|
||||
universe u v w w'
|
||||
|
||||
@@ -28,7 +28,7 @@ namespace Lean
|
||||
-/
|
||||
structure SMap (α : Type u) (β : Type v) [BEq α] [Hashable α] where
|
||||
stage₁ : Bool := true
|
||||
map₁ : HashMap α β := {}
|
||||
map₁ : Std.HashMap α β := {}
|
||||
map₂ : PHashMap α β := {}
|
||||
|
||||
namespace SMap
|
||||
@@ -37,7 +37,7 @@ variable {α : Type u} {β : Type v} [BEq α] [Hashable α]
|
||||
instance : Inhabited (SMap α β) := ⟨{}⟩
|
||||
def empty : SMap α β := {}
|
||||
|
||||
@[inline] def fromHashMap (m : HashMap α β) (stage₁ := true) : SMap α β :=
|
||||
@[inline] def fromHashMap (m : Std.HashMap α β) (stage₁ := true) : SMap α β :=
|
||||
{ map₁ := m, stage₁ := stage₁ }
|
||||
|
||||
@[specialize] def insert : SMap α β → α → β → SMap α β
|
||||
@@ -49,8 +49,8 @@ def empty : SMap α β := {}
|
||||
| ⟨false, m₁, m₂⟩, k, v => ⟨false, m₁, m₂.insert k v⟩
|
||||
|
||||
@[specialize] def find? : SMap α β → α → Option β
|
||||
| ⟨true, m₁, _⟩, k => m₁.find? k
|
||||
| ⟨false, m₁, m₂⟩, k => (m₂.find? k).orElse fun _ => m₁.find? k
|
||||
| ⟨true, m₁, _⟩, k => m₁[k]?
|
||||
| ⟨false, m₁, m₂⟩, k => (m₂.find? k).orElse fun _ => m₁[k]?
|
||||
|
||||
@[inline] def findD (m : SMap α β) (a : α) (b₀ : β) : β :=
|
||||
(m.find? a).getD b₀
|
||||
@@ -67,8 +67,8 @@ def empty : SMap α β := {}
|
||||
/-- Similar to `find?`, but searches for result in the hashmap first.
|
||||
So, the result is correct only if we never "overwrite" `map₁` entries using `map₂`. -/
|
||||
@[specialize] def find?' : SMap α β → α → Option β
|
||||
| ⟨true, m₁, _⟩, k => m₁.find? k
|
||||
| ⟨false, m₁, m₂⟩, k => (m₁.find? k).orElse fun _ => m₂.find? k
|
||||
| ⟨true, m₁, _⟩, k => m₁[k]?
|
||||
| ⟨false, m₁, m₂⟩, k => m₁[k]?.orElse fun _ => m₂.find? k
|
||||
|
||||
def forM [Monad m] (s : SMap α β) (f : α → β → m PUnit) : m PUnit := do
|
||||
s.map₁.forM f
|
||||
@@ -96,7 +96,7 @@ def fold {σ : Type w} (f : σ → α → β → σ) (init : σ) (m : SMap α β
|
||||
m.map₂.foldl f $ m.map₁.fold f init
|
||||
|
||||
def numBuckets (m : SMap α β) : Nat :=
|
||||
m.map₁.numBuckets
|
||||
Std.HashMap.Internal.numBuckets m.map₁
|
||||
|
||||
def toList (m : SMap α β) : List (α × β) :=
|
||||
m.fold (init := []) fun es a b => (a, b)::es
|
||||
|
||||
@@ -13,23 +13,24 @@ namespace Lean
|
||||
namespace Xml
|
||||
|
||||
namespace Parser
|
||||
|
||||
open Lean.Parsec
|
||||
open Parsec.ParseResult
|
||||
open Lean.Parsec.String
|
||||
|
||||
abbrev LeanChar := Char
|
||||
|
||||
/-- consume a newline character sequence pretending, that we read '\n'. As per spec:
|
||||
https://www.w3.org/TR/xml/#sec-line-ends -/
|
||||
def endl : Parsec LeanChar := (skipString "\r\n" <|> skipChar '\r' <|> skipChar '\n') *> pure '\n'
|
||||
def endl : Parser LeanChar := (skipString "\r\n" <|> skipChar '\r' <|> skipChar '\n') *> pure '\n'
|
||||
|
||||
def quote (p : Parsec α) : Parsec α :=
|
||||
def quote (p : Parser α) : Parser α :=
|
||||
skipChar '\'' *> p <* skipChar '\''
|
||||
<|> skipChar '"' *> p <* skipChar '"'
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-Char -/
|
||||
def Char : Parsec LeanChar :=
|
||||
def Char : Parser LeanChar :=
|
||||
(attempt do
|
||||
let c ← anyChar
|
||||
let c ← any
|
||||
let cNat := c.toNat
|
||||
if (0x20 ≤ cNat ∧ cNat ≤ 0xD7FF)
|
||||
∨ (0xE000 ≤ cNat ∧ cNat ≤ 0xFFFD)
|
||||
@@ -37,11 +38,11 @@ def Char : Parsec LeanChar :=
|
||||
<|> pchar '\t' <|> endl
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-S -/
|
||||
def S : Parsec String :=
|
||||
def S : Parser String :=
|
||||
many1Chars (pchar ' ' <|> endl <|> pchar '\t')
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-Eq -/
|
||||
def Eq : Parsec Unit :=
|
||||
def Eq : Parser Unit :=
|
||||
optional S *> skipChar '=' <* optional S
|
||||
|
||||
private def nameStartCharRanges : Array (Nat × Nat) :=
|
||||
@@ -59,8 +60,8 @@ private def nameStartCharRanges : Array (Nat × Nat) :=
|
||||
(0x10000, 0xEFFFF)]
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-NameStartChar -/
|
||||
def NameStartChar : Parsec LeanChar := attempt do
|
||||
let c ← anyChar
|
||||
def NameStartChar : Parser LeanChar := attempt do
|
||||
let c ← any
|
||||
if ('A' ≤ c ∧ c ≤ 'Z') ∨ ('a' ≤ c ∧ c ≤ 'z') then pure c
|
||||
else if c = ':' ∨ c = '_' then pure c
|
||||
else
|
||||
@@ -69,44 +70,44 @@ def NameStartChar : Parsec LeanChar := attempt do
|
||||
else fail "expected a name character"
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-NameChar -/
|
||||
def NameChar : Parsec LeanChar :=
|
||||
def NameChar : Parser LeanChar :=
|
||||
NameStartChar <|> digit <|> pchar '-' <|> pchar '.' <|> pchar '\xB7'
|
||||
<|> satisfy (λ c => ('\u0300' ≤ c ∧ c ≤ '\u036F') ∨ ('\u203F' ≤ c ∧ c ≤ '\u2040'))
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-Name -/
|
||||
def Name : Parsec String := do
|
||||
def Name : Parser String := do
|
||||
let x ← NameStartChar
|
||||
manyCharsCore NameChar x.toString
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-VersionNum -/
|
||||
def VersionNum : Parsec Unit :=
|
||||
def VersionNum : Parser Unit :=
|
||||
skipString "1." <* (many1 digit)
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-VersionInfo -/
|
||||
def VersionInfo : Parsec Unit := do
|
||||
def VersionInfo : Parser Unit := do
|
||||
S *>
|
||||
skipString "version"
|
||||
Eq
|
||||
quote VersionNum
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-EncName -/
|
||||
def EncName : Parsec String := do
|
||||
def EncName : Parser String := do
|
||||
let x ← asciiLetter
|
||||
manyCharsCore (asciiLetter <|> digit <|> pchar '-' <|> pchar '_' <|> pchar '.') x.toString
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-EncodingDecl -/
|
||||
def EncodingDecl : Parsec String := do
|
||||
def EncodingDecl : Parser String := do
|
||||
S *>
|
||||
skipString "encoding"
|
||||
Eq
|
||||
quote EncName
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-SDDecl -/
|
||||
def SDDecl : Parsec String := do
|
||||
def SDDecl : Parser String := do
|
||||
S *> skipString "standalone" *> Eq *> quote (pstring "yes" <|> pstring "no")
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-XMLDecl -/
|
||||
def XMLdecl : Parsec Unit := do
|
||||
def XMLdecl : Parser Unit := do
|
||||
skipString "<?xml"
|
||||
VersionInfo
|
||||
optional EncodingDecl *>
|
||||
@@ -115,7 +116,7 @@ def XMLdecl : Parsec Unit := do
|
||||
skipString "?>"
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-Comment -/
|
||||
def Comment : Parsec String :=
|
||||
def Comment : Parser String :=
|
||||
let notDash := Char.toString <$> satisfy (λ c => c ≠ '-')
|
||||
skipString "<!--" *>
|
||||
Array.foldl String.append "" <$> many (attempt <| notDash <|> (do
|
||||
@@ -125,45 +126,45 @@ def Comment : Parsec String :=
|
||||
<* skipString "-->"
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-PITarget -/
|
||||
def PITarget : Parsec String :=
|
||||
def PITarget : Parser String :=
|
||||
Name <* (skipChar 'X' <|> skipChar 'x') <* (skipChar 'M' <|> skipChar 'm') <* (skipChar 'L' <|> skipChar 'l')
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-PI -/
|
||||
def PI : Parsec Unit := do
|
||||
def PI : Parser Unit := do
|
||||
skipString "<?"
|
||||
<* PITarget <*
|
||||
optional (S *> manyChars (notFollowedBy (skipString "?>") *> Char))
|
||||
skipString "?>"
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-Misc -/
|
||||
def Misc : Parsec Unit :=
|
||||
def Misc : Parser Unit :=
|
||||
Comment *> pure () <|> PI <|> S *> pure ()
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-SystemLiteral -/
|
||||
def SystemLiteral : Parsec String :=
|
||||
def SystemLiteral : Parser String :=
|
||||
pchar '"' *> manyChars (satisfy λ c => c ≠ '"') <* pchar '"'
|
||||
<|> pchar '\'' *> manyChars (satisfy λ c => c ≠ '\'') <* pure '\''
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-PubidChar -/
|
||||
def PubidChar : Parsec LeanChar :=
|
||||
def PubidChar : Parser LeanChar :=
|
||||
asciiLetter <|> digit <|> endl <|> attempt do
|
||||
let c ← anyChar
|
||||
let c ← any
|
||||
if "-'()+,./:=?;!*#@$_%".contains c then pure c else fail "PublidChar expected"
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-PubidLiteral -/
|
||||
def PubidLiteral : Parsec String :=
|
||||
def PubidLiteral : Parser String :=
|
||||
pchar '"' *> manyChars PubidChar <* pchar '"'
|
||||
<|> pchar '\'' *> manyChars (attempt do
|
||||
let c ← PubidChar
|
||||
if c = '\'' then fail "'\\'' not expected" else pure c) <* pchar '\''
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-ExternalID -/
|
||||
def ExternalID : Parsec Unit :=
|
||||
def ExternalID : Parser Unit :=
|
||||
skipString "SYSTEM" *> S *> SystemLiteral *> pure ()
|
||||
<|> skipString "PUBLIC" *> S *> PubidLiteral *> S *> SystemLiteral *> pure ()
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-Mixed -/
|
||||
def Mixed : Parsec Unit :=
|
||||
def Mixed : Parser Unit :=
|
||||
(do
|
||||
skipChar '('
|
||||
optional S *>
|
||||
@@ -175,11 +176,11 @@ def Mixed : Parsec Unit :=
|
||||
|
||||
mutual
|
||||
/-- https://www.w3.org/TR/xml/#NT-cp -/
|
||||
partial def cp : Parsec Unit :=
|
||||
partial def cp : Parser Unit :=
|
||||
(Name *> pure () <|> choice <|> seq) <* optional (skipChar '?' <|> skipChar '*' <|> skipChar '+')
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-choice -/
|
||||
partial def choice : Parsec Unit := do
|
||||
partial def choice : Parser Unit := do
|
||||
skipChar '('
|
||||
optional S *>
|
||||
cp
|
||||
@@ -188,7 +189,7 @@ mutual
|
||||
skipChar ')'
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-seq -/
|
||||
partial def seq : Parsec Unit := do
|
||||
partial def seq : Parser Unit := do
|
||||
skipChar '('
|
||||
optional S *>
|
||||
cp
|
||||
@@ -198,15 +199,15 @@ mutual
|
||||
end
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-children -/
|
||||
def children : Parsec Unit :=
|
||||
def children : Parser Unit :=
|
||||
(choice <|> seq) <* optional (skipChar '?' <|> skipChar '*' <|> skipChar '+')
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-contentspec -/
|
||||
def contentspec : Parsec Unit := do
|
||||
def contentspec : Parser Unit := do
|
||||
skipString "EMPTY" <|> skipString "ANY" <|> Mixed <|> children
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-elementdecl -/
|
||||
def elementDecl : Parsec Unit := do
|
||||
def elementDecl : Parser Unit := do
|
||||
skipString "<!ELEMENT"
|
||||
S *>
|
||||
Name *>
|
||||
@@ -215,11 +216,11 @@ def elementDecl : Parsec Unit := do
|
||||
skipChar '>'
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-StringType -/
|
||||
def StringType : Parsec Unit :=
|
||||
def StringType : Parser Unit :=
|
||||
skipString "CDATA"
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-TokenizedType -/
|
||||
def TokenizedType : Parsec Unit :=
|
||||
def TokenizedType : Parser Unit :=
|
||||
skipString "ID"
|
||||
<|> skipString "IDREF"
|
||||
<|> skipString "IDREFS"
|
||||
@@ -229,7 +230,7 @@ def TokenizedType : Parsec Unit :=
|
||||
<|> skipString "NMTOKENS"
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-NotationType -/
|
||||
def NotationType : Parsec Unit := do
|
||||
def NotationType : Parser Unit := do
|
||||
skipString "NOTATION"
|
||||
S *>
|
||||
skipChar '(' <*
|
||||
@@ -239,11 +240,11 @@ def NotationType : Parsec Unit := do
|
||||
skipChar ')'
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-Nmtoken -/
|
||||
def Nmtoken : Parsec String := do
|
||||
def Nmtoken : Parser String := do
|
||||
many1Chars NameChar
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-Enumeration -/
|
||||
def Enumeration : Parsec Unit := do
|
||||
def Enumeration : Parser Unit := do
|
||||
skipChar '('
|
||||
optional S *>
|
||||
Nmtoken *> many (optional S *> skipChar '|' *> optional S *> Nmtoken) *>
|
||||
@@ -251,11 +252,11 @@ def Enumeration : Parsec Unit := do
|
||||
skipChar ')'
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-EnumeratedType -/
|
||||
def EnumeratedType : Parsec Unit :=
|
||||
def EnumeratedType : Parser Unit :=
|
||||
NotationType <|> Enumeration
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-AttType -/
|
||||
def AttType : Parsec Unit :=
|
||||
def AttType : Parser Unit :=
|
||||
StringType <|> TokenizedType <|> EnumeratedType
|
||||
|
||||
def predefinedEntityToChar : String → Option LeanChar
|
||||
@@ -267,7 +268,7 @@ def predefinedEntityToChar : String → Option LeanChar
|
||||
| _ => none
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-EntityRef -/
|
||||
def EntityRef : Parsec $ Option LeanChar := attempt $
|
||||
def EntityRef : Parser $ Option LeanChar := attempt $
|
||||
skipChar '&' *> predefinedEntityToChar <$> Name <* skipChar ';'
|
||||
|
||||
@[inline]
|
||||
@@ -280,7 +281,7 @@ def digitsToNat (base : Nat) (digits : Array Nat) : Nat :=
|
||||
digits.foldl (λ r d => r * base + d) 0
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-CharRef -/
|
||||
def CharRef : Parsec LeanChar := do
|
||||
def CharRef : Parser LeanChar := do
|
||||
skipString "&#"
|
||||
let charCode ←
|
||||
digitsToNat 10 <$> many1 (hexDigitToNat <$> digit)
|
||||
@@ -289,11 +290,11 @@ def CharRef : Parsec LeanChar := do
|
||||
return Char.ofNat charCode
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-Reference -/
|
||||
def Reference : Parsec $ Option LeanChar :=
|
||||
def Reference : Parser $ Option LeanChar :=
|
||||
EntityRef <|> some <$> CharRef
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-AttValue -/
|
||||
def AttValue : Parsec String := do
|
||||
def AttValue : Parser String := do
|
||||
let chars ←
|
||||
(do
|
||||
skipChar '"'
|
||||
@@ -306,25 +307,25 @@ def AttValue : Parsec String := do
|
||||
return chars.foldl (λ s c => if let some c := c then s.push c else s) ""
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-DefaultDecl -/
|
||||
def DefaultDecl : Parsec Unit :=
|
||||
def DefaultDecl : Parser Unit :=
|
||||
skipString "#REQUIRED"
|
||||
<|> skipString "#IMPLIED"
|
||||
<|> optional (skipString "#FIXED" <* S) *> AttValue *> pure ()
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-AttDef -/
|
||||
def AttDef : Parsec Unit :=
|
||||
def AttDef : Parser Unit :=
|
||||
S *> Name *> S *> AttType *> S *> DefaultDecl
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-AttlistDecl -/
|
||||
def AttlistDecl : Parsec Unit :=
|
||||
def AttlistDecl : Parser Unit :=
|
||||
skipString "<!ATTLIST" *> S *> Name *> many AttDef *> optional S *> skipChar '>'
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-PEReference -/
|
||||
def PEReference : Parsec Unit :=
|
||||
def PEReference : Parser Unit :=
|
||||
skipChar '%' *> Name *> skipChar ';'
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-EntityValue -/
|
||||
def EntityValue : Parsec String := do
|
||||
def EntityValue : Parser String := do
|
||||
let chars ←
|
||||
(do
|
||||
skipChar '"'
|
||||
@@ -338,51 +339,51 @@ def EntityValue : Parsec String := do
|
||||
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-NDataDecl -/
|
||||
def NDataDecl : Parsec Unit :=
|
||||
def NDataDecl : Parser Unit :=
|
||||
S *> skipString "NDATA" <* S <* Name
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-EntityDef -/
|
||||
def EntityDef : Parsec Unit :=
|
||||
def EntityDef : Parser Unit :=
|
||||
EntityValue *> pure () <|> (ExternalID <* optional NDataDecl)
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-GEDecl -/
|
||||
def GEDecl : Parsec Unit :=
|
||||
def GEDecl : Parser Unit :=
|
||||
skipString "<!ENTITY" *> S *> Name *> S *> EntityDef *> optional S *> skipChar '>'
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-PEDef -/
|
||||
def PEDef : Parsec Unit :=
|
||||
def PEDef : Parser Unit :=
|
||||
EntityValue *> pure () <|> ExternalID
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-PEDecl -/
|
||||
def PEDecl : Parsec Unit :=
|
||||
def PEDecl : Parser Unit :=
|
||||
skipString "<!ENTITY" *> S *> skipChar '%' *> S *> Name *> PEDef *> optional S *> skipChar '>'
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-EntityDecl -/
|
||||
def EntityDecl : Parsec Unit :=
|
||||
def EntityDecl : Parser Unit :=
|
||||
GEDecl <|> PEDecl
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-PublicID -/
|
||||
def PublicID : Parsec Unit :=
|
||||
def PublicID : Parser Unit :=
|
||||
skipString "PUBLIC" <* S <* PubidLiteral
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-NotationDecl -/
|
||||
def NotationDecl : Parsec Unit :=
|
||||
def NotationDecl : Parser Unit :=
|
||||
skipString "<!NOTATION" *> S *> Name *> (ExternalID <|> PublicID) *> optional S *> skipChar '>'
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-markupdecl -/
|
||||
def markupDecl : Parsec Unit :=
|
||||
def markupDecl : Parser Unit :=
|
||||
elementDecl <|> AttlistDecl <|> EntityDecl <|> NotationDecl <|> PI <|> (Comment *> pure ())
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-DeclSep -/
|
||||
def DeclSep : Parsec Unit :=
|
||||
def DeclSep : Parser Unit :=
|
||||
PEReference <|> S *> pure ()
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-intSubset -/
|
||||
def intSubset : Parsec Unit :=
|
||||
def intSubset : Parser Unit :=
|
||||
many (markupDecl <|> DeclSep) *> pure ()
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-doctypedecl -/
|
||||
def doctypedecl : Parsec Unit := do
|
||||
def doctypedecl : Parser Unit := do
|
||||
skipString "<!DOCTYPE"
|
||||
S *>
|
||||
Name *>
|
||||
@@ -392,19 +393,19 @@ def doctypedecl : Parsec Unit := do
|
||||
skipChar '>'
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-prolog -/
|
||||
def prolog : Parsec Unit :=
|
||||
def prolog : Parser Unit :=
|
||||
optional XMLdecl *>
|
||||
many Misc *>
|
||||
optional (doctypedecl <* many Misc) *> pure ()
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-Attribute -/
|
||||
def Attribute : Parsec (String × String) := do
|
||||
def Attribute : Parser (String × String) := do
|
||||
let name ← Name
|
||||
Eq
|
||||
let value ← AttValue
|
||||
return (name, value)
|
||||
|
||||
protected def elementPrefix : Parsec (Array Content → Element) := do
|
||||
protected def elementPrefix : Parser (Array Content → Element) := do
|
||||
skipChar '<'
|
||||
let name ← Name
|
||||
let attributes ← many (attempt <| S *> Attribute)
|
||||
@@ -412,40 +413,40 @@ protected def elementPrefix : Parsec (Array Content → Element) := do
|
||||
return Element.Element name (RBMap.fromList attributes.toList compare)
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-EmptyElemTag -/
|
||||
def EmptyElemTag (elem : Array Content → Element) : Parsec Element := do
|
||||
def EmptyElemTag (elem : Array Content → Element) : Parser Element := do
|
||||
skipString "/>" *> pure (elem #[])
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-STag -/
|
||||
def STag (elem : Array Content → Element) : Parsec (Array Content → Element) := do
|
||||
def STag (elem : Array Content → Element) : Parser (Array Content → Element) := do
|
||||
skipChar '>' *> pure elem
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-ETag -/
|
||||
def ETag : Parsec Unit :=
|
||||
def ETag : Parser Unit :=
|
||||
skipString "</" *> Name *> optional S *> skipChar '>'
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-CDStart -/
|
||||
def CDStart : Parsec Unit :=
|
||||
def CDStart : Parser Unit :=
|
||||
skipString "<![CDATA["
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-CDEnd -/
|
||||
def CDEnd : Parsec Unit :=
|
||||
def CDEnd : Parser Unit :=
|
||||
skipString "]]>"
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-CData -/
|
||||
def CData : Parsec String :=
|
||||
manyChars (notFollowedBy (skipString "]]>") *> anyChar)
|
||||
def CData : Parser String :=
|
||||
manyChars (notFollowedBy (skipString "]]>") *> any)
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-CDSect -/
|
||||
def CDSect : Parsec String :=
|
||||
def CDSect : Parser String :=
|
||||
CDStart *> CData <* CDEnd
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-CharData -/
|
||||
def CharData : Parsec String :=
|
||||
def CharData : Parser String :=
|
||||
notFollowedBy (skipString "]]>") *> manyChars (satisfy λ c => c ≠ '<' ∧ c ≠ '&')
|
||||
|
||||
mutual
|
||||
/-- https://www.w3.org/TR/xml/#NT-content -/
|
||||
partial def content : Parsec (Array Content) := do
|
||||
partial def content : Parser (Array Content) := do
|
||||
let x ← optional (Content.Character <$> CharData)
|
||||
let xs ← many do
|
||||
let y ←
|
||||
@@ -468,20 +469,20 @@ mutual
|
||||
return res
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-element -/
|
||||
partial def element : Parsec Element := do
|
||||
partial def element : Parser Element := do
|
||||
let elem ← Parser.elementPrefix
|
||||
EmptyElemTag elem <|> STag elem <*> content <* ETag
|
||||
|
||||
end
|
||||
|
||||
/-- https://www.w3.org/TR/xml/#NT-document -/
|
||||
def document : Parsec Element := prolog *> element <* many Misc <* eof
|
||||
def document : Parser Element := prolog *> element <* many Misc <* eof
|
||||
|
||||
end Parser
|
||||
|
||||
def parse (s : String) : Except String Element :=
|
||||
match Xml.Parser.document s.mkIterator with
|
||||
| Parsec.ParseResult.success _ res => Except.ok res
|
||||
| Parsec.ParseResult.error it err => Except.error s!"offset {it.i.byteIdx.repr}: {err}\n{(it.prevn 10).extract it}"
|
||||
| .success _ res => Except.ok res
|
||||
| .error it err => Except.error s!"offset {it.i.byteIdx.repr}: {err}\n{(it.prevn 10).extract it}"
|
||||
|
||||
end Xml
|
||||
|
||||
@@ -541,7 +541,7 @@ mutual
|
||||
/--
|
||||
Process a `fType` of the form `(x : A) → B x`.
|
||||
This method assume `fType` is a function type -/
|
||||
private partial def processExplictArg (argName : Name) : M Expr := do
|
||||
private partial def processExplicitArg (argName : Name) : M Expr := do
|
||||
match (← get).args with
|
||||
| arg::args =>
|
||||
if (← anyNamedArgDependsOnCurrent) then
|
||||
@@ -586,6 +586,16 @@ mutual
|
||||
| Except.ok tacticSyntax =>
|
||||
-- TODO(Leo): does this work correctly for tactic sequences?
|
||||
let tacticBlock ← `(by $(⟨tacticSyntax⟩))
|
||||
/-
|
||||
We insert position information from the current ref into `stx` everywhere, simulating this being
|
||||
a tactic script inserted by the user, which ensures error messages and logging will always be attributed
|
||||
to this application rather than sometimes being placed at position (1,0) in the file.
|
||||
Placing position information on `by` syntax alone is not sufficient since incrementality
|
||||
(in particular, `Lean.Elab.Term.withReuseContext`) controls the ref to avoid leakage of outside data.
|
||||
Note that `tacticSyntax` contains no position information itself, since it is erased by `Lean.Elab.Term.quoteAutoTactic`.
|
||||
-/
|
||||
let info := (← getRef).getHeadInfo
|
||||
let tacticBlock := tacticBlock.raw.rewriteBottomUp (·.setInfo info)
|
||||
let argNew := Arg.stx tacticBlock
|
||||
propagateExpectedType argNew
|
||||
elabAndAddNewArg argName argNew
|
||||
@@ -615,7 +625,7 @@ mutual
|
||||
This method assume `fType` is a function type -/
|
||||
private partial def processImplicitArg (argName : Name) : M Expr := do
|
||||
if (← read).explicit then
|
||||
processExplictArg argName
|
||||
processExplicitArg argName
|
||||
else
|
||||
addImplicitArg argName
|
||||
|
||||
@@ -624,7 +634,7 @@ mutual
|
||||
This method assume `fType` is a function type -/
|
||||
private partial def processStrictImplicitArg (argName : Name) : M Expr := do
|
||||
if (← read).explicit then
|
||||
processExplictArg argName
|
||||
processExplicitArg argName
|
||||
else if (← hasArgsToProcess) then
|
||||
addImplicitArg argName
|
||||
else
|
||||
@@ -643,7 +653,7 @@ mutual
|
||||
addNewArg argName arg
|
||||
main
|
||||
else
|
||||
processExplictArg argName
|
||||
processExplicitArg argName
|
||||
else
|
||||
let arg ← mkFreshExprMVar (← getArgExpectedType) MetavarKind.synthetic
|
||||
addInstMVar arg.mvarId!
|
||||
@@ -668,7 +678,7 @@ mutual
|
||||
| .implicit => processImplicitArg binderName
|
||||
| .instImplicit => processInstImplicitArg binderName
|
||||
| .strictImplicit => processStrictImplicitArg binderName
|
||||
| _ => processExplictArg binderName
|
||||
| _ => processExplicitArg binderName
|
||||
else if (← hasArgsToProcess) then
|
||||
synthesizePendingAndNormalizeFunType
|
||||
main
|
||||
|
||||
@@ -201,12 +201,12 @@ def mkMessageAux (ctx : Context) (ref : Syntax) (msgData : MessageData) (severit
|
||||
|
||||
private def addTraceAsMessagesCore (ctx : Context) (log : MessageLog) (traceState : TraceState) : MessageLog := Id.run do
|
||||
if traceState.traces.isEmpty then return log
|
||||
let mut traces : HashMap (String.Pos × String.Pos) (Array MessageData) := ∅
|
||||
let mut traces : Std.HashMap (String.Pos × String.Pos) (Array MessageData) := ∅
|
||||
for traceElem in traceState.traces do
|
||||
let ref := replaceRef traceElem.ref ctx.ref
|
||||
let pos := ref.getPos?.getD 0
|
||||
let endPos := ref.getTailPos?.getD pos
|
||||
traces := traces.insert (pos, endPos) <| traces.findD (pos, endPos) #[] |>.push traceElem.msg
|
||||
traces := traces.insert (pos, endPos) <| traces.getD (pos, endPos) #[] |>.push traceElem.msg
|
||||
let mut log := log
|
||||
let traces' := traces.toArray.qsort fun ((a, _), _) ((b, _), _) => a < b
|
||||
for ((pos, endPos), traceMsg) in traces' do
|
||||
|
||||
@@ -81,10 +81,6 @@ end Frontend
|
||||
|
||||
open Frontend
|
||||
|
||||
def IO.processCommands (inputCtx : Parser.InputContext) (parserState : Parser.ModuleParserState) (commandState : Command.State) : IO State := do
|
||||
let (_, s) ← (Frontend.processCommands.run { inputCtx := inputCtx }).run { commandState := commandState, parserState := parserState, cmdPos := parserState.pos }
|
||||
pure s
|
||||
|
||||
structure IncrementalState extends State where
|
||||
inputCtx : Parser.InputContext
|
||||
initialSnap : Language.Lean.CommandParsedSnapshot
|
||||
@@ -92,12 +88,10 @@ deriving Nonempty
|
||||
|
||||
open Language in
|
||||
/--
|
||||
Variant of `IO.processCommands` that uses the new Lean language processor implementation for
|
||||
potential incremental reuse. Pass in result of a previous invocation done with the same state
|
||||
(but usually different input context) to allow for reuse.
|
||||
Variant of `IO.processCommands` that allows for potential incremental reuse. Pass in the result of a
|
||||
previous invocation done with the same state (but usually different input context) to allow for
|
||||
reuse.
|
||||
-/
|
||||
-- `IO.processCommands` can be reimplemented on top of this as soon as the additional tasks speed up
|
||||
-- things instead of slowing them down
|
||||
partial def IO.processCommandsIncrementally (inputCtx : Parser.InputContext)
|
||||
(parserState : Parser.ModuleParserState) (commandState : Command.State)
|
||||
(old? : Option IncrementalState) :
|
||||
@@ -110,7 +104,7 @@ where
|
||||
let snap := t.get
|
||||
let commands := commands.push snap.data.stx
|
||||
if let some next := snap.nextCmdSnap? then
|
||||
go initialSnap next commands
|
||||
go initialSnap next.task commands
|
||||
else
|
||||
-- Opting into reuse also enables incremental reporting, so make sure to collect messages from
|
||||
-- all snapshots
|
||||
@@ -126,6 +120,11 @@ where
|
||||
inputCtx, initialSnap, commands
|
||||
}
|
||||
|
||||
def IO.processCommands (inputCtx : Parser.InputContext) (parserState : Parser.ModuleParserState)
|
||||
(commandState : Command.State) : IO State := do
|
||||
let st ← IO.processCommandsIncrementally inputCtx parserState commandState none
|
||||
return st.toState
|
||||
|
||||
def process (input : String) (env : Environment) (opts : Options) (fileName : Option String := none) : IO (Environment × MessageLog) := do
|
||||
let fileName := fileName.getD "<input>"
|
||||
let inputCtx := Parser.mkInputContext input fileName
|
||||
@@ -144,62 +143,31 @@ def runFrontend
|
||||
: IO (Environment × Bool) := do
|
||||
let startTime := (← IO.monoNanosNow).toFloat / 1000000000
|
||||
let inputCtx := Parser.mkInputContext input fileName
|
||||
if true then
|
||||
-- Temporarily keep alive old cmdline driver for the Lean language so that we don't pay the
|
||||
-- overhead of passing the environment between snapshots until we actually make good use of it
|
||||
-- outside the server
|
||||
let (header, parserState, messages) ← Parser.parseHeader inputCtx
|
||||
-- allow `env` to be leaked, which would live until the end of the process anyway
|
||||
let (env, messages) ← processHeader (leakEnv := true) header opts messages inputCtx trustLevel
|
||||
let env := env.setMainModule mainModuleName
|
||||
let mut commandState := Command.mkState env messages opts
|
||||
let elabStartTime := (← IO.monoNanosNow).toFloat / 1000000000
|
||||
|
||||
if ileanFileName?.isSome then
|
||||
-- Collect InfoTrees so we can later extract and export their info to the ilean file
|
||||
commandState := { commandState with infoState.enabled := true }
|
||||
|
||||
let s ← IO.processCommands inputCtx parserState commandState
|
||||
Language.reportMessages s.commandState.messages opts jsonOutput
|
||||
|
||||
if let some ileanFileName := ileanFileName? then
|
||||
let trees := s.commandState.infoState.trees.toArray
|
||||
let references ←
|
||||
Lean.Server.findModuleRefs inputCtx.fileMap trees (localVars := false) |>.toLspModuleRefs
|
||||
let ilean := { module := mainModuleName, references : Lean.Server.Ilean }
|
||||
IO.FS.writeFile ileanFileName $ Json.compress $ toJson ilean
|
||||
|
||||
if let some out := trace.profiler.output.get? opts then
|
||||
let traceState := s.commandState.traceState
|
||||
-- importing does not happen in an elaboration monad, add now
|
||||
let traceState := { traceState with
|
||||
traces := #[{
|
||||
ref := .missing,
|
||||
msg := .trace { cls := `Import, startTime, stopTime := elabStartTime }
|
||||
(.ofFormat "importing") #[]
|
||||
}].toPArray' ++ traceState.traces
|
||||
}
|
||||
let profile ← Firefox.Profile.export mainModuleName.toString startTime traceState opts
|
||||
IO.FS.writeFile ⟨out⟩ <| Json.compress <| toJson profile
|
||||
|
||||
return (s.commandState.env, !s.commandState.messages.hasErrors)
|
||||
|
||||
let opts := Language.Lean.internal.minimalSnapshots.set opts true
|
||||
let ctx := { inputCtx with }
|
||||
let processor := Language.Lean.process
|
||||
let snap ← processor (fun _ => pure <| .ok { mainModuleName, opts, trustLevel }) none ctx
|
||||
let snaps := Language.toSnapshotTree snap
|
||||
snaps.runAndReport opts jsonOutput
|
||||
|
||||
if let some ileanFileName := ileanFileName? then
|
||||
let trees := snaps.getAll.concatMap (match ·.infoTree? with | some t => #[t] | _ => #[])
|
||||
let references := Lean.Server.findModuleRefs inputCtx.fileMap trees (localVars := false)
|
||||
let ilean := { module := mainModuleName, references := ← references.toLspModuleRefs : Lean.Server.Ilean }
|
||||
IO.FS.writeFile ileanFileName $ Json.compress $ toJson ilean
|
||||
|
||||
let hasErrors := snaps.getAll.any (·.diagnostics.msgLog.hasErrors)
|
||||
-- TODO: remove default when reworking cmdline interface in Lean; currently the only case
|
||||
-- where we use the environment despite errors in the file is `--stats`
|
||||
let env := Language.Lean.waitForFinalEnv? snap |>.getD (← mkEmptyEnvironment)
|
||||
pure (env, !hasErrors)
|
||||
let some cmdState := Language.Lean.waitForFinalCmdState? snap
|
||||
| return (← mkEmptyEnvironment, false)
|
||||
|
||||
if let some out := trace.profiler.output.get? opts then
|
||||
let traceState := cmdState.traceState
|
||||
let profile ← Firefox.Profile.export mainModuleName.toString startTime traceState opts
|
||||
IO.FS.writeFile ⟨out⟩ <| Json.compress <| toJson profile
|
||||
|
||||
let hasErrors := snaps.getAll.any (·.diagnostics.msgLog.hasErrors)
|
||||
pure (cmdState.env, !hasErrors)
|
||||
|
||||
|
||||
end Lean.Elab
|
||||
|
||||
@@ -39,7 +39,7 @@ def parseImports (input : String) (fileName : Option String := none) : IO (Array
|
||||
def printImports (input : String) (fileName : Option String) : IO Unit := do
|
||||
let (deps, _, _) ← parseImports input fileName
|
||||
for dep in deps do
|
||||
let fname ← findOLean (checkExists := false) dep.module
|
||||
let fname ← findOLean dep.module
|
||||
IO.println fname
|
||||
|
||||
end Lean.Elab
|
||||
|
||||
@@ -630,7 +630,7 @@ private def replaceIndFVarsWithConsts (views : Array InductiveView) (indFVars :
|
||||
let type := type.replace fun e =>
|
||||
if !e.isFVar then
|
||||
none
|
||||
else match indFVar2Const.find? e with
|
||||
else match indFVar2Const[e]? with
|
||||
| none => none
|
||||
| some c => mkAppN c (params.extract 0 numVars)
|
||||
instantiateMVars (← mkForallFVars params type)
|
||||
|
||||
@@ -425,7 +425,7 @@ private def applyRefMap (e : Expr) (map : ExprMap Expr) : Expr :=
|
||||
e.replace fun e =>
|
||||
match patternWithRef? e with
|
||||
| some _ => some e -- stop `e` already has annotation
|
||||
| none => match map.find? e with
|
||||
| none => match map[e]? with
|
||||
| some eWithRef => some eWithRef -- stop `e` found annotation
|
||||
| none => none -- continue
|
||||
|
||||
|
||||
@@ -323,6 +323,10 @@ private def declValToTerminationHint (declVal : Syntax) : TermElabM TerminationH
|
||||
else
|
||||
return .none
|
||||
|
||||
def instantiateMVarsProfiling (e : Expr) : MetaM Expr := do
|
||||
profileitM Exception s!"instantiate metavars" (← getOptions) do
|
||||
instantiateMVars e
|
||||
|
||||
private def elabFunValues (headers : Array DefViewElabHeader) : TermElabM (Array Expr) :=
|
||||
headers.mapM fun header => do
|
||||
let mut reusableResult? := none
|
||||
@@ -348,7 +352,7 @@ private def elabFunValues (headers : Array DefViewElabHeader) : TermElabM (Array
|
||||
elabTermEnsuringType valStx type <* synthesizeSyntheticMVarsNoPostponing
|
||||
-- NOTE: without this `instantiatedMVars`, `mkLambdaFVars` may leave around a redex that
|
||||
-- leads to more section variables being included than necessary
|
||||
let val ← instantiateMVars val
|
||||
let val ← instantiateMVarsProfiling val
|
||||
mkLambdaFVars xs val
|
||||
if let some snap := header.bodySnap? then
|
||||
snap.new.resolve <| some {
|
||||
@@ -389,7 +393,7 @@ private def instantiateMVarsAtHeader (header : DefViewElabHeader) : TermElabM De
|
||||
|
||||
private def instantiateMVarsAtLetRecToLift (toLift : LetRecToLift) : TermElabM LetRecToLift := do
|
||||
let type ← instantiateMVars toLift.type
|
||||
let val ← instantiateMVars toLift.val
|
||||
let val ← instantiateMVarsProfiling toLift.val
|
||||
pure { toLift with type, val }
|
||||
|
||||
private def typeHasRecFun (type : Expr) (funFVars : Array Expr) (letRecsToLift : List LetRecToLift) : Option FVarId :=
|
||||
@@ -597,7 +601,7 @@ private def pickMaxFVar? (lctx : LocalContext) (fvarIds : Array FVarId) : Option
|
||||
fvarIds.getMax? fun fvarId₁ fvarId₂ => (lctx.get! fvarId₁).index < (lctx.get! fvarId₂).index
|
||||
|
||||
private def preprocess (e : Expr) : TermElabM Expr := do
|
||||
let e ← instantiateMVars e
|
||||
let e ← instantiateMVarsProfiling e
|
||||
-- which let-decls are dependent. We say a let-decl is dependent if its lambda abstraction is type incorrect.
|
||||
Meta.check e
|
||||
pure e
|
||||
@@ -708,7 +712,7 @@ private def mkLetRecClosures (sectionVars : Array Expr) (mainFVarIds : Array FVa
|
||||
-- This can happen when this particular let-rec has nested let-rec that have been resolved in previous iterations.
|
||||
-- This code relies on the fact that nested let-recs occur before the outer most let-recs at `letRecsToLift`.
|
||||
-- Unresolved nested let-recs appear as metavariables before they are resolved. See `assignExprMVar` at `mkLetRecClosureFor`
|
||||
let valNew ← instantiateMVars letRecsToLift[i]!.val
|
||||
let valNew ← instantiateMVarsProfiling letRecsToLift[i]!.val
|
||||
letRecsToLift := letRecsToLift.modify i fun t => { t with val := valNew }
|
||||
-- We have to recompute the `freeVarMap` in this case. This overhead should not be an issue in practice.
|
||||
freeVarMap ← mkFreeVarMap sectionVars mainFVarIds recFVarIds letRecsToLift
|
||||
@@ -821,10 +825,10 @@ def main (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mai
|
||||
let letRecsToLift ← letRecsToLift.mapM fun toLift => withLCtx toLift.lctx toLift.localInstances do
|
||||
Meta.check toLift.type
|
||||
Meta.check toLift.val
|
||||
return { toLift with val := (← instantiateMVars toLift.val), type := (← instantiateMVars toLift.type) }
|
||||
return { toLift with val := (← instantiateMVarsProfiling toLift.val), type := (← instantiateMVars toLift.type) }
|
||||
let letRecClosures ← mkLetRecClosures sectionVars mainFVarIds recFVarIds letRecsToLift
|
||||
-- mkLetRecClosures assign metavariables that were placeholders for the lifted declarations.
|
||||
let mainVals ← mainVals.mapM (instantiateMVars ·)
|
||||
let mainVals ← mainVals.mapM (instantiateMVarsProfiling ·)
|
||||
let mainHeaders ← mainHeaders.mapM instantiateMVarsAtHeader
|
||||
let letRecClosures ← letRecClosures.mapM fun closure => do pure { closure with toLift := (← instantiateMVarsAtLetRecToLift closure.toLift) }
|
||||
-- Replace fvarIds for functions being defined with closed terms
|
||||
@@ -923,7 +927,7 @@ where
|
||||
try
|
||||
let values ← elabFunValues headers
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
values.mapM (instantiateMVars ·)
|
||||
values.mapM (instantiateMVarsProfiling ·)
|
||||
catch ex =>
|
||||
logException ex
|
||||
headers.mapM fun header => mkSorry header.type (synthetic := true)
|
||||
|
||||
@@ -7,6 +7,9 @@ prelude
|
||||
import Init.ShareCommon
|
||||
import Lean.Compiler.NoncomputableAttr
|
||||
import Lean.Util.CollectLevelParams
|
||||
import Lean.Util.NumObjs
|
||||
import Lean.Util.NumApps
|
||||
import Lean.PrettyPrinter
|
||||
import Lean.Meta.AbstractNestedProofs
|
||||
import Lean.Meta.ForEachExpr
|
||||
import Lean.Elab.RecAppSyntax
|
||||
@@ -17,7 +20,6 @@ namespace Lean.Elab
|
||||
open Meta
|
||||
open Term
|
||||
|
||||
|
||||
/--
|
||||
A (potentially recursive) definition.
|
||||
The elaborator converts it into Kernel definitions using many different strategies.
|
||||
@@ -98,15 +100,33 @@ private def compileDecl (decl : Declaration) : TermElabM Bool := do
|
||||
throw ex
|
||||
return true
|
||||
|
||||
register_builtin_option diagnostics.threshold.proofSize : Nat := {
|
||||
defValue := 16384
|
||||
group := "diagnostics"
|
||||
descr := "only display proof statistics when proof has at least this number of terms"
|
||||
}
|
||||
|
||||
private def reportTheoremDiag (d : TheoremVal) : TermElabM Unit := do
|
||||
if (← isDiagnosticsEnabled) then
|
||||
let proofSize ← d.value.numObjs
|
||||
if proofSize > diagnostics.threshold.proofSize.get (← getOptions) then
|
||||
let sizeMsg := MessageData.trace { cls := `size } m!"{proofSize}" #[]
|
||||
let constOccs ← d.value.numApps (threshold := diagnostics.threshold.get (← getOptions))
|
||||
let constOccsMsg ← constOccs.mapM fun (declName, numOccs) => return MessageData.trace { cls := `occs } m!"{MessageData.ofConst (← mkConstWithLevelParams declName)} ↦ {numOccs}" #[]
|
||||
-- let info
|
||||
logInfo <| MessageData.trace { cls := `theorem } m!"{d.name}" (#[sizeMsg] ++ constOccsMsg)
|
||||
|
||||
private def addNonRecAux (preDef : PreDefinition) (compile : Bool) (all : List Name) (applyAttrAfterCompilation := true) : TermElabM Unit :=
|
||||
withRef preDef.ref do
|
||||
let preDef ← abstractNestedProofs preDef
|
||||
let decl ←
|
||||
match preDef.kind with
|
||||
| DefKind.«theorem» =>
|
||||
pure <| Declaration.thmDecl {
|
||||
let d := {
|
||||
name := preDef.declName, levelParams := preDef.levelParams, type := preDef.type, value := preDef.value, all
|
||||
}
|
||||
reportTheoremDiag d
|
||||
pure <| Declaration.thmDecl d
|
||||
| DefKind.«opaque» =>
|
||||
pure <| Declaration.opaqueDecl {
|
||||
name := preDef.declName, levelParams := preDef.levelParams, type := preDef.type, value := preDef.value
|
||||
@@ -144,8 +164,11 @@ def addNonRec (preDef : PreDefinition) (applyAttrAfterCompilation := true) (all
|
||||
/--
|
||||
Eliminate recursive application annotations containing syntax. These annotations are used by the well-founded recursion module
|
||||
to produce better error messages. -/
|
||||
def eraseRecAppSyntaxExpr (e : Expr) : CoreM Expr :=
|
||||
Core.transform e (post := fun e => pure <| TransformStep.done <| if (getRecAppSyntax? e).isSome then e.mdataExpr! else e)
|
||||
def eraseRecAppSyntaxExpr (e : Expr) : CoreM Expr := do
|
||||
if e.find? hasRecAppSyntax |>.isSome then
|
||||
Core.transform e (post := fun e => pure <| TransformStep.done <| if hasRecAppSyntax e then e.mdataExpr! else e)
|
||||
else
|
||||
return e
|
||||
|
||||
def eraseRecAppSyntax (preDef : PreDefinition) : CoreM PreDefinition :=
|
||||
return { preDef with value := (← eraseRecAppSyntaxExpr preDef.value) }
|
||||
|
||||
@@ -69,12 +69,15 @@ private def ensureNoUnassignedMVarsAtPreDef (preDef : PreDefinition) : TermElabM
|
||||
This method beta-reduces them to make sure they can be eliminated by the well-founded recursion module. -/
|
||||
private def betaReduceLetRecApps (preDefs : Array PreDefinition) : MetaM (Array PreDefinition) :=
|
||||
preDefs.mapM fun preDef => do
|
||||
let value ← Core.transform preDef.value fun e => do
|
||||
if e.isApp && e.getAppFn.isLambda && e.getAppArgs.all fun arg => arg.getAppFn.isConst && preDefs.any fun preDef => preDef.declName == arg.getAppFn.constName! then
|
||||
return .visit e.headBeta
|
||||
else
|
||||
return .continue
|
||||
return { preDef with value }
|
||||
if preDef.value.find? (fun e => e.isConst && preDefs.any fun preDef => preDef.declName == e.constName!) |>.isSome then
|
||||
let value ← Core.transform preDef.value fun e => do
|
||||
if e.isApp && e.getAppFn.isLambda && e.getAppArgs.all fun arg => arg.getAppFn.isConst && preDefs.any fun preDef => preDef.declName == arg.getAppFn.constName! then
|
||||
return .visit e.headBeta
|
||||
else
|
||||
return .continue
|
||||
return { preDef with value }
|
||||
else
|
||||
return preDef
|
||||
|
||||
private def addAsAxioms (preDefs : Array PreDefinition) : TermElabM Unit := do
|
||||
for preDef in preDefs do
|
||||
|
||||
@@ -146,7 +146,7 @@ See issue #837 for an example where we can show termination using the index of a
|
||||
we don't get the desired definitional equalities.
|
||||
-/
|
||||
def nonIndicesFirst (recArgInfos : Array RecArgInfo) : Array RecArgInfo := Id.run do
|
||||
let mut indicesPos : HashSet Nat := {}
|
||||
let mut indicesPos : Std.HashSet Nat := {}
|
||||
for recArgInfo in recArgInfos do
|
||||
for pos in recArgInfo.indicesPos do
|
||||
indicesPos := indicesPos.insert pos
|
||||
|
||||
@@ -596,7 +596,7 @@ private partial def compileStxMatch (discrs : List Term) (alts : List Alt) : Ter
|
||||
`(have __discr := $discr; $stx)
|
||||
| _, _ => unreachable!
|
||||
|
||||
abbrev IdxSet := HashSet Nat
|
||||
abbrev IdxSet := Std.HashSet Nat
|
||||
|
||||
private partial def hasNoErrorIfUnused : Syntax → Bool
|
||||
| `(no_error_if_unused% $_) => true
|
||||
|
||||
@@ -11,27 +11,35 @@ namespace Lean
|
||||
private def recAppKey := `_recApp
|
||||
|
||||
/--
|
||||
We store the syntax at recursive applications to be able to generate better error messages
|
||||
when performing well-founded and structural recursion.
|
||||
We store the syntax at recursive applications to be able to generate better error messages
|
||||
when performing well-founded and structural recursion.
|
||||
-/
|
||||
def mkRecAppWithSyntax (e : Expr) (stx : Syntax) : Expr :=
|
||||
mkMData (KVMap.empty.insert recAppKey (DataValue.ofSyntax stx)) e
|
||||
mkMData (KVMap.empty.insert recAppKey (.ofSyntax stx)) e
|
||||
|
||||
/--
|
||||
Retrieve (if available) the syntax object attached to a recursive application.
|
||||
Retrieve (if available) the syntax object attached to a recursive application.
|
||||
-/
|
||||
def getRecAppSyntax? (e : Expr) : Option Syntax :=
|
||||
match e with
|
||||
| Expr.mdata d _ =>
|
||||
| .mdata d _ =>
|
||||
match d.find recAppKey with
|
||||
| some (DataValue.ofSyntax stx) => some stx
|
||||
| _ => none
|
||||
| _ => none
|
||||
|
||||
/--
|
||||
Checks if the `MData` is for a recursive applciation.
|
||||
Checks if the `MData` is for a recursive applciation.
|
||||
-/
|
||||
def MData.isRecApp (d : MData) : Bool :=
|
||||
d.contains recAppKey
|
||||
|
||||
/--
|
||||
Return `true` if `getRecAppSyntax? e` is a `some`.
|
||||
-/
|
||||
def hasRecAppSyntax (e : Expr) : Bool :=
|
||||
match e with
|
||||
| .mdata d _ => d.isRecApp
|
||||
| _ => false
|
||||
|
||||
end Lean
|
||||
|
||||
@@ -445,13 +445,13 @@ private def expandParentFields (s : Struct) : TermElabM Struct := do
|
||||
| _ => throwErrorAt ref "failed to access field '{fieldName}' in parent structure"
|
||||
| _ => return field
|
||||
|
||||
private abbrev FieldMap := HashMap Name Fields
|
||||
private abbrev FieldMap := Std.HashMap Name Fields
|
||||
|
||||
private def mkFieldMap (fields : Fields) : TermElabM FieldMap :=
|
||||
fields.foldlM (init := {}) fun fieldMap field =>
|
||||
match field.lhs with
|
||||
| .fieldName _ fieldName :: _ =>
|
||||
match fieldMap.find? fieldName with
|
||||
match fieldMap[fieldName]? with
|
||||
| some (prevField::restFields) =>
|
||||
if field.isSimple || prevField.isSimple then
|
||||
throwErrorAt field.ref "field '{fieldName}' has already been specified"
|
||||
@@ -677,6 +677,10 @@ private partial def elabStruct (s : Struct) (expectedType? : Option Expr) : Term
|
||||
| .error err => throwError err
|
||||
| .ok tacticSyntax =>
|
||||
let stx ← `(by $tacticSyntax)
|
||||
-- See comment in `Lean.Elab.Term.ElabAppArgs.processExplicitArg` about `tacticSyntax`.
|
||||
-- We add info to get reliable positions for messages from evaluating the tactic script.
|
||||
let info := field.ref.getHeadInfo
|
||||
let stx := stx.raw.rewriteBottomUp (·.setInfo info)
|
||||
cont (← elabTermEnsuringType stx (d.getArg! 0).consumeTypeAnnotations) field
|
||||
| _ =>
|
||||
if bi == .instImplicit then
|
||||
|
||||
@@ -246,7 +246,7 @@ private def getSomeSyntheticMVarsRef : TermElabM Syntax := do
|
||||
private def throwStuckAtUniverseCnstr : TermElabM Unit := do
|
||||
-- This code assumes `entries` is not empty. Note that `processPostponed` uses `exceptionOnFailure` to guarantee this property
|
||||
let entries ← getPostponed
|
||||
let mut found : HashSet (Level × Level) := {}
|
||||
let mut found : Std.HashSet (Level × Level) := {}
|
||||
let mut uniqueEntries := #[]
|
||||
for entry in entries do
|
||||
let mut lhs := entry.lhs
|
||||
|
||||
@@ -8,8 +8,6 @@ import Init.Omega.Constraint
|
||||
import Lean.Elab.Tactic.Omega.OmegaM
|
||||
import Lean.Elab.Tactic.Omega.MinNatAbs
|
||||
|
||||
open Lean (HashMap HashSet)
|
||||
|
||||
namespace Lean.Elab.Tactic.Omega
|
||||
|
||||
initialize Lean.registerTraceClass `omega
|
||||
@@ -167,11 +165,11 @@ structure Problem where
|
||||
/-- The number of variables in the problem. -/
|
||||
numVars : Nat := 0
|
||||
/-- The current constraints, indexed by their coefficients. -/
|
||||
constraints : HashMap Coeffs Fact := ∅
|
||||
constraints : Std.HashMap Coeffs Fact := ∅
|
||||
/--
|
||||
The coefficients for which `constraints` contains an exact constraint (i.e. an equality).
|
||||
-/
|
||||
equalities : HashSet Coeffs := ∅
|
||||
equalities : Std.HashSet Coeffs := ∅
|
||||
/--
|
||||
Equations that have already been used to eliminate variables,
|
||||
along with the variable which was removed, and its coefficient (either `1` or `-1`).
|
||||
@@ -251,7 +249,7 @@ combining it with any existing constraints for the same coefficients.
|
||||
def addConstraint (p : Problem) : Fact → Problem
|
||||
| f@⟨x, s, j⟩ =>
|
||||
if p.possible then
|
||||
match p.constraints.find? x with
|
||||
match p.constraints[x]? with
|
||||
| none =>
|
||||
match s with
|
||||
| .trivial => p
|
||||
@@ -313,7 +311,7 @@ After solving, the variable will have been eliminated from all constraints.
|
||||
def solveEasyEquality (p : Problem) (c : Coeffs) : Problem :=
|
||||
let i := c.findIdx? (·.natAbs = 1) |>.getD 0 -- findIdx? is always some
|
||||
let sign := c.get i |> Int.sign
|
||||
match p.constraints.find? c with
|
||||
match p.constraints[c]? with
|
||||
| some f =>
|
||||
let init :=
|
||||
{ assumptions := p.assumptions
|
||||
@@ -335,7 +333,7 @@ After solving the easy equality,
|
||||
the minimum lexicographic value of `(c.minNatAbs, c.maxNatAbs)` will have been reduced.
|
||||
-/
|
||||
def dealWithHardEquality (p : Problem) (c : Coeffs) : OmegaM Problem :=
|
||||
match p.constraints.find? c with
|
||||
match p.constraints[c]? with
|
||||
| some ⟨_, ⟨some r, some r'⟩, j⟩ => do
|
||||
let m := c.minNatAbs + 1
|
||||
-- We have to store the valid value of the newly introduced variable in the atoms.
|
||||
@@ -479,7 +477,7 @@ def fourierMotzkinData (p : Problem) : Array FourierMotzkinData := Id.run do
|
||||
let n := p.numVars
|
||||
let mut data : Array FourierMotzkinData :=
|
||||
(List.range p.numVars).foldl (fun a i => a.push { var := i}) #[]
|
||||
for (_, f@⟨xs, s, _⟩) in p.constraints.toList do -- We could make a forIn instance for HashMap
|
||||
for (_, f@⟨xs, s, _⟩) in p.constraints do
|
||||
for i in [0:n] do
|
||||
let x := Coeffs.get xs i
|
||||
data := data.modify i fun d =>
|
||||
|
||||
@@ -58,7 +58,7 @@ structure MetaProblem where
|
||||
-/
|
||||
disjunctions : List Expr := []
|
||||
/-- Facts which have already been processed; we keep these to avoid duplicates. -/
|
||||
processedFacts : HashSet Expr := ∅
|
||||
processedFacts : Std.HashSet Expr := ∅
|
||||
|
||||
/-- Construct the `rfl` proof that `lc.eval atoms = e`. -/
|
||||
def mkEvalRflProof (e : Expr) (lc : LinearCombo) : OmegaM Expr := do
|
||||
@@ -80,7 +80,7 @@ def mkCoordinateEvalAtomsEq (e : Expr) (n : Nat) : OmegaM Expr := do
|
||||
mkEqTrans eq (← mkEqSymm (mkApp2 (.const ``LinearCombo.coordinate_eval []) n atoms))
|
||||
|
||||
/-- Construct the linear combination (and its associated proof and new facts) for an atom. -/
|
||||
def mkAtomLinearCombo (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
|
||||
def mkAtomLinearCombo (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
|
||||
let (n, facts) ← lookup e
|
||||
return ⟨LinearCombo.coordinate n, mkCoordinateEvalAtomsEq e n, facts.getD ∅⟩
|
||||
|
||||
@@ -94,9 +94,9 @@ Gives a small (10%) speedup in testing.
|
||||
I tried using a pointer based cache,
|
||||
but there was never enough subexpression sharing to make it effective.
|
||||
-/
|
||||
partial def asLinearCombo (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
|
||||
partial def asLinearCombo (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
|
||||
let cache ← get
|
||||
match cache.find? e with
|
||||
match cache.get? e with
|
||||
| some (lc, prf) =>
|
||||
trace[omega] "Found in cache: {e}"
|
||||
return (lc, prf, ∅)
|
||||
@@ -120,7 +120,7 @@ We also transform the expression as we descend into it:
|
||||
* pushing coercions: `↑(x + y)`, `↑(x * y)`, `↑(x / k)`, `↑(x % k)`, `↑k`
|
||||
* unfolding `emod`: `x % k` → `x - x / k`
|
||||
-/
|
||||
partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
|
||||
partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
|
||||
trace[omega] "processing {e}"
|
||||
match groundInt? e with
|
||||
| some i =>
|
||||
@@ -142,7 +142,7 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
|
||||
mkEqTrans
|
||||
(← mkAppM ``Int.add_congr #[← prf₁, ← prf₂])
|
||||
(← mkEqSymm add_eval)
|
||||
pure (l₁ + l₂, prf, facts₁.merge facts₂)
|
||||
pure (l₁ + l₂, prf, facts₁.union facts₂)
|
||||
| (``HSub.hSub, #[_, _, _, _, e₁, e₂]) => do
|
||||
let (l₁, prf₁, facts₁) ← asLinearCombo e₁
|
||||
let (l₂, prf₂, facts₂) ← asLinearCombo e₂
|
||||
@@ -152,7 +152,7 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
|
||||
mkEqTrans
|
||||
(← mkAppM ``Int.sub_congr #[← prf₁, ← prf₂])
|
||||
(← mkEqSymm sub_eval)
|
||||
pure (l₁ - l₂, prf, facts₁.merge facts₂)
|
||||
pure (l₁ - l₂, prf, facts₁.union facts₂)
|
||||
| (``Neg.neg, #[_, _, e']) => do
|
||||
let (l, prf, facts) ← asLinearCombo e'
|
||||
let prf' : OmegaM Expr := do
|
||||
@@ -178,7 +178,7 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
|
||||
mkEqTrans
|
||||
(← mkAppM ``Int.mul_congr #[← xprf, ← yprf])
|
||||
(← mkEqSymm mul_eval)
|
||||
pure (some (LinearCombo.mul xl yl, prf, xfacts.merge yfacts), true)
|
||||
pure (some (LinearCombo.mul xl yl, prf, xfacts.union yfacts), true)
|
||||
else
|
||||
pure (none, false)
|
||||
match r? with
|
||||
@@ -235,7 +235,7 @@ where
|
||||
Apply a rewrite rule to an expression, and interpret the result as a `LinearCombo`.
|
||||
(We're not rewriting any subexpressions here, just the top level, for efficiency.)
|
||||
-/
|
||||
rewrite (lhs rw : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
|
||||
rewrite (lhs rw : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
|
||||
trace[omega] "rewriting {lhs} via {rw} : {← inferType rw}"
|
||||
match (← inferType rw).eq? with
|
||||
| some (_, _lhs', rhs) =>
|
||||
@@ -243,7 +243,7 @@ where
|
||||
let prf' : OmegaM Expr := do mkEqTrans rw (← prf)
|
||||
pure (lc, prf', facts)
|
||||
| none => panic! "Invalid rewrite rule in 'asLinearCombo'"
|
||||
handleNatCast (e i n : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
|
||||
handleNatCast (e i n : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
|
||||
match n with
|
||||
| .fvar h =>
|
||||
if let some v ← h.getValue? then
|
||||
@@ -296,7 +296,7 @@ where
|
||||
| (``Fin.val, #[n, x]) =>
|
||||
handleFinVal e i n x
|
||||
| _ => mkAtomLinearCombo e
|
||||
handleFinVal (e i n x : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
|
||||
handleFinVal (e i n x : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
|
||||
match x with
|
||||
| .fvar h =>
|
||||
if let some v ← h.getValue? then
|
||||
@@ -342,7 +342,7 @@ We solve equalities as they are discovered, as this often results in an earlier
|
||||
-/
|
||||
def addIntEquality (p : MetaProblem) (h x : Expr) : OmegaM MetaProblem := do
|
||||
let (lc, prf, facts) ← asLinearCombo x
|
||||
let newFacts : HashSet Expr := facts.fold (init := ∅) fun s e =>
|
||||
let newFacts : Std.HashSet Expr := facts.fold (init := ∅) fun s e =>
|
||||
if p.processedFacts.contains e then s else s.insert e
|
||||
trace[omega] "Adding proof of {lc} = 0"
|
||||
pure <|
|
||||
@@ -358,7 +358,7 @@ We solve equalities as they are discovered, as this often results in an earlier
|
||||
-/
|
||||
def addIntInequality (p : MetaProblem) (h y : Expr) : OmegaM MetaProblem := do
|
||||
let (lc, prf, facts) ← asLinearCombo y
|
||||
let newFacts : HashSet Expr := facts.fold (init := ∅) fun s e =>
|
||||
let newFacts : Std.HashSet Expr := facts.fold (init := ∅) fun s e =>
|
||||
if p.processedFacts.contains e then s else s.insert e
|
||||
trace[omega] "Adding proof of {lc} ≥ 0"
|
||||
pure <|
|
||||
@@ -590,7 +590,7 @@ where
|
||||
|
||||
-- We sort the constraints; otherwise the order is dependent on details of the hashing
|
||||
-- and this can cause test suite output churn
|
||||
prettyConstraints (names : Array String) (constraints : HashMap Coeffs Fact) : String :=
|
||||
prettyConstraints (names : Array String) (constraints : Std.HashMap Coeffs Fact) : String :=
|
||||
constraints.toList
|
||||
|>.toArray
|
||||
|>.qsort (·.1 < ·.1)
|
||||
@@ -615,7 +615,7 @@ where
|
||||
(if Int.natAbs c = 1 then names[i]! else s!"{c.natAbs}*{names[i]!}"))
|
||||
|> String.join
|
||||
|
||||
mentioned (atoms : Array Expr) (constraints : HashMap Coeffs Fact) : MetaM (Array Bool) := do
|
||||
mentioned (atoms : Array Expr) (constraints : Std.HashMap Coeffs Fact) : MetaM (Array Bool) := do
|
||||
let initMask := Array.mkArray atoms.size false
|
||||
return constraints.fold (init := initMask) fun mask coeffs _ =>
|
||||
coeffs.enum.foldl (init := mask) fun mask (i, c) =>
|
||||
|
||||
@@ -10,6 +10,8 @@ import Init.Omega.Logic
|
||||
import Init.Data.BitVec.Basic
|
||||
import Lean.Meta.AppBuilder
|
||||
import Lean.Meta.Canonicalizer
|
||||
import Std.Data.HashMap.Basic
|
||||
import Std.Data.HashSet.Basic
|
||||
|
||||
/-!
|
||||
# The `OmegaM` state monad.
|
||||
@@ -52,7 +54,7 @@ structure Context where
|
||||
/-- The internal state for the `OmegaM` monad, recording previously encountered atoms. -/
|
||||
structure State where
|
||||
/-- The atoms up-to-defeq encountered so far. -/
|
||||
atoms : HashMap Expr Nat := {}
|
||||
atoms : Std.HashMap Expr Nat := {}
|
||||
|
||||
/-- An intermediate layer in the `OmegaM` monad. -/
|
||||
abbrev OmegaM' := StateRefT State (ReaderT Context CanonM)
|
||||
@@ -60,7 +62,7 @@ abbrev OmegaM' := StateRefT State (ReaderT Context CanonM)
|
||||
/--
|
||||
Cache of expressions that have been visited, and their reflection as a linear combination.
|
||||
-/
|
||||
def Cache : Type := HashMap Expr (LinearCombo × OmegaM' Expr)
|
||||
def Cache : Type := Std.HashMap Expr (LinearCombo × OmegaM' Expr)
|
||||
|
||||
/--
|
||||
The `OmegaM` monad maintains two pieces of state:
|
||||
@@ -71,7 +73,7 @@ abbrev OmegaM := StateRefT Cache OmegaM'
|
||||
|
||||
/-- Run a computation in the `OmegaM` monad, starting with no recorded atoms. -/
|
||||
def OmegaM.run (m : OmegaM α) (cfg : OmegaConfig) : MetaM α :=
|
||||
m.run' HashMap.empty |>.run' {} { cfg } |>.run'
|
||||
m.run' Std.HashMap.empty |>.run' {} { cfg } |>.run'
|
||||
|
||||
/-- Retrieve the user-specified configuration options. -/
|
||||
def cfg : OmegaM OmegaConfig := do pure (← read).cfg
|
||||
@@ -162,11 +164,11 @@ def mkEqReflWithExpectedType (a b : Expr) : MetaM Expr := do
|
||||
Analyzes a newly recorded atom,
|
||||
returning a collection of interesting facts about it that should be added to the context.
|
||||
-/
|
||||
def analyzeAtom (e : Expr) : OmegaM (HashSet Expr) := do
|
||||
def analyzeAtom (e : Expr) : OmegaM (Std.HashSet Expr) := do
|
||||
match e.getAppFnArgs with
|
||||
| (``Nat.cast, #[.const ``Int [], _, e']) =>
|
||||
-- Casts of natural numbers are non-negative.
|
||||
let mut r := HashSet.empty.insert (Expr.app (.const ``Int.ofNat_nonneg []) e')
|
||||
let mut r := Std.HashSet.empty.insert (Expr.app (.const ``Int.ofNat_nonneg []) e')
|
||||
match (← cfg).splitNatSub, e'.getAppFnArgs with
|
||||
| true, (``HSub.hSub, #[_, _, _, _, a, b]) =>
|
||||
-- `((a - b : Nat) : Int)` gives a dichotomy
|
||||
@@ -188,7 +190,7 @@ def analyzeAtom (e : Expr) : OmegaM (HashSet Expr) := do
|
||||
let ne_zero := mkApp3 (.const ``Ne [1]) (.const ``Int []) k (toExpr (0 : Int))
|
||||
let pos := mkApp4 (.const ``LT.lt [0]) (.const ``Int []) (.const ``Int.instLTInt [])
|
||||
(toExpr (0 : Int)) k
|
||||
pure <| HashSet.empty.insert
|
||||
pure <| Std.HashSet.empty.insert
|
||||
(mkApp3 (.const ``Int.mul_ediv_self_le []) x k (← mkDecideProof ne_zero)) |>.insert
|
||||
(mkApp3 (.const ``Int.lt_mul_ediv_self_add []) x k (← mkDecideProof pos))
|
||||
| (``HMod.hMod, #[_, _, _, _, x, k]) =>
|
||||
@@ -200,7 +202,7 @@ def analyzeAtom (e : Expr) : OmegaM (HashSet Expr) := do
|
||||
let b_pos := mkApp4 (.const ``LT.lt [0]) (.const ``Int []) (.const ``Int.instLTInt [])
|
||||
(toExpr (0 : Int)) b
|
||||
let pow_pos := mkApp3 (.const ``Lean.Omega.Int.pos_pow_of_pos []) b exp (← mkDecideProof b_pos)
|
||||
pure <| HashSet.empty.insert
|
||||
pure <| Std.HashSet.empty.insert
|
||||
(mkApp3 (.const ``Int.emod_nonneg []) x k
|
||||
(mkApp3 (.const ``Int.ne_of_gt []) k (toExpr (0 : Int)) pow_pos)) |>.insert
|
||||
(mkApp3 (.const ``Int.emod_lt_of_pos []) x k pow_pos)
|
||||
@@ -214,7 +216,7 @@ def analyzeAtom (e : Expr) : OmegaM (HashSet Expr) := do
|
||||
(toExpr (0 : Nat)) b
|
||||
let pow_pos := mkApp3 (.const ``Nat.pos_pow_of_pos []) b exp (← mkDecideProof b_pos)
|
||||
let cast_pos := mkApp2 (.const ``Int.ofNat_pos_of_pos []) k' pow_pos
|
||||
pure <| HashSet.empty.insert
|
||||
pure <| Std.HashSet.empty.insert
|
||||
(mkApp3 (.const ``Int.emod_nonneg []) x k
|
||||
(mkApp3 (.const ``Int.ne_of_gt []) k (toExpr (0 : Int)) cast_pos)) |>.insert
|
||||
(mkApp3 (.const ``Int.emod_lt_of_pos []) x k cast_pos)
|
||||
@@ -222,18 +224,18 @@ def analyzeAtom (e : Expr) : OmegaM (HashSet Expr) := do
|
||||
| (``Nat.cast, #[.const ``Int [], _, x']) =>
|
||||
-- Since we push coercions inside `%`, we need to record here that
|
||||
-- `(x : Int) % (y : Int)` is non-negative.
|
||||
pure <| HashSet.empty.insert (mkApp2 (.const ``Int.emod_ofNat_nonneg []) x' k)
|
||||
pure <| Std.HashSet.empty.insert (mkApp2 (.const ``Int.emod_ofNat_nonneg []) x' k)
|
||||
| _ => pure ∅
|
||||
| _ => pure ∅
|
||||
| (``Min.min, #[_, _, x, y]) =>
|
||||
pure <| HashSet.empty.insert (mkApp2 (.const ``Int.min_le_left []) x y) |>.insert
|
||||
pure <| Std.HashSet.empty.insert (mkApp2 (.const ``Int.min_le_left []) x y) |>.insert
|
||||
(mkApp2 (.const ``Int.min_le_right []) x y)
|
||||
| (``Max.max, #[_, _, x, y]) =>
|
||||
pure <| HashSet.empty.insert (mkApp2 (.const ``Int.le_max_left []) x y) |>.insert
|
||||
pure <| Std.HashSet.empty.insert (mkApp2 (.const ``Int.le_max_left []) x y) |>.insert
|
||||
(mkApp2 (.const ``Int.le_max_right []) x y)
|
||||
| (``ite, #[α, i, dec, t, e]) =>
|
||||
if α == (.const ``Int []) then
|
||||
pure <| HashSet.empty.insert <| mkApp5 (.const ``ite_disjunction [0]) α i dec t e
|
||||
pure <| Std.HashSet.empty.insert <| mkApp5 (.const ``ite_disjunction [0]) α i dec t e
|
||||
else
|
||||
pure {}
|
||||
| _ => pure ∅
|
||||
@@ -248,10 +250,10 @@ Return its index, and, if it is new, a collection of interesting facts about the
|
||||
* for each new atom of the form `((a - b : Nat) : Int)`, the fact:
|
||||
`b ≤ a ∧ ((a - b : Nat) : Int) = a - b ∨ a < b ∧ ((a - b : Nat) : Int) = 0`
|
||||
-/
|
||||
def lookup (e : Expr) : OmegaM (Nat × Option (HashSet Expr)) := do
|
||||
def lookup (e : Expr) : OmegaM (Nat × Option (Std.HashSet Expr)) := do
|
||||
let c ← getThe State
|
||||
let e ← canon e
|
||||
match c.atoms.find? e with
|
||||
match c.atoms[e]? with
|
||||
| some i => return (i, none)
|
||||
| none =>
|
||||
trace[omega] "New atom: {e}"
|
||||
|
||||
@@ -7,7 +7,6 @@ prelude
|
||||
import Init.Control.StateRef
|
||||
import Init.Data.Array.BinSearch
|
||||
import Init.Data.Stream
|
||||
import Lean.Data.HashMap
|
||||
import Lean.ImportingFlag
|
||||
import Lean.Data.SMap
|
||||
import Lean.Declaration
|
||||
@@ -134,7 +133,7 @@ structure Environment where
|
||||
the field `constants`. These auxiliary constants are invisible to the Lean kernel and elaborator.
|
||||
Only the code generator uses them.
|
||||
-/
|
||||
const2ModIdx : HashMap Name ModuleIdx
|
||||
const2ModIdx : Std.HashMap Name ModuleIdx
|
||||
/--
|
||||
Mapping from constant name to `ConstantInfo`. It contains all constants (definitions, theorems, axioms, etc)
|
||||
that have been already type checked by the kernel.
|
||||
@@ -205,7 +204,7 @@ private def getTrustLevel (env : Environment) : UInt32 :=
|
||||
env.header.trustLevel
|
||||
|
||||
def getModuleIdxFor? (env : Environment) (declName : Name) : Option ModuleIdx :=
|
||||
env.const2ModIdx.find? declName
|
||||
env.const2ModIdx[declName]?
|
||||
|
||||
def isConstructor (env : Environment) (declName : Name) : Bool :=
|
||||
match env.find? declName with
|
||||
@@ -721,7 +720,7 @@ def writeModule (env : Environment) (fname : System.FilePath) : IO Unit := do
|
||||
Construct a mapping from persistent extension name to entension index at the array of persistent extensions.
|
||||
We only consider extensions starting with index `>= startingAt`.
|
||||
-/
|
||||
def mkExtNameMap (startingAt : Nat) : IO (HashMap Name Nat) := do
|
||||
def mkExtNameMap (startingAt : Nat) : IO (Std.HashMap Name Nat) := do
|
||||
let descrs ← persistentEnvExtensionsRef.get
|
||||
let mut result := {}
|
||||
for h : i in [startingAt : descrs.size] do
|
||||
@@ -742,7 +741,7 @@ private def setImportedEntries (env : Environment) (mods : Array ModuleData) (st
|
||||
have : modIdx < mods.size := h.upper
|
||||
let mod := mods[modIdx]
|
||||
for (extName, entries) in mod.entries do
|
||||
if let some entryIdx := extNameIdx.find? extName then
|
||||
if let some entryIdx := extNameIdx[extName]? then
|
||||
env := extDescrs[entryIdx]!.toEnvExtension.modifyState env fun s => { s with importedEntries := s.importedEntries.set! modIdx entries }
|
||||
return env
|
||||
|
||||
@@ -790,9 +789,9 @@ structure ImportState where
|
||||
moduleData : Array ModuleData := #[]
|
||||
regions : Array CompactedRegion := #[]
|
||||
|
||||
def throwAlreadyImported (s : ImportState) (const2ModIdx : HashMap Name ModuleIdx) (modIdx : Nat) (cname : Name) : IO α := do
|
||||
def throwAlreadyImported (s : ImportState) (const2ModIdx : Std.HashMap Name ModuleIdx) (modIdx : Nat) (cname : Name) : IO α := do
|
||||
let modName := s.moduleNames[modIdx]!
|
||||
let constModName := s.moduleNames[const2ModIdx[cname].get!.toNat]!
|
||||
let constModName := s.moduleNames[const2ModIdx[cname]!.toNat]!
|
||||
throw <| IO.userError s!"import {modName} failed, environment already contains '{cname}' from {constModName}"
|
||||
|
||||
abbrev ImportStateM := StateRefT ImportState IO
|
||||
@@ -806,6 +805,8 @@ partial def importModulesCore (imports : Array Import) : ImportStateM Unit := do
|
||||
continue
|
||||
modify fun s => { s with moduleNameSet := s.moduleNameSet.insert i.module }
|
||||
let mFile ← findOLean i.module
|
||||
unless (← mFile.pathExists) do
|
||||
throw <| IO.userError s!"object file '{mFile}' of module {i.module} does not exist"
|
||||
let (mod, region) ← readModuleData mFile
|
||||
importModulesCore mod.imports
|
||||
modify fun s => { s with
|
||||
@@ -854,21 +855,21 @@ def finalizeImport (s : ImportState) (imports : Array Import) (opts : Options) (
|
||||
(leakEnv := false) : IO Environment := do
|
||||
let numConsts := s.moduleData.foldl (init := 0) fun numConsts mod =>
|
||||
numConsts + mod.constants.size + mod.extraConstNames.size
|
||||
let mut const2ModIdx : HashMap Name ModuleIdx := mkHashMap (capacity := numConsts)
|
||||
let mut constantMap : HashMap Name ConstantInfo := mkHashMap (capacity := numConsts)
|
||||
let mut const2ModIdx : Std.HashMap Name ModuleIdx := Std.HashMap.empty (capacity := numConsts)
|
||||
let mut constantMap : Std.HashMap Name ConstantInfo := Std.HashMap.empty (capacity := numConsts)
|
||||
for h:modIdx in [0:s.moduleData.size] do
|
||||
let mod := s.moduleData[modIdx]'h.upper
|
||||
for cname in mod.constNames, cinfo in mod.constants do
|
||||
match constantMap.insertIfNew cname cinfo with
|
||||
| (constantMap', cinfoPrev?) =>
|
||||
match constantMap.getThenInsertIfNew? cname cinfo with
|
||||
| (cinfoPrev?, constantMap') =>
|
||||
constantMap := constantMap'
|
||||
if let some cinfoPrev := cinfoPrev? then
|
||||
-- Recall that the map has not been modified when `cinfoPrev? = some _`.
|
||||
unless equivInfo cinfoPrev cinfo do
|
||||
throwAlreadyImported s const2ModIdx modIdx cname
|
||||
const2ModIdx := const2ModIdx.insertIfNew cname modIdx |>.1
|
||||
const2ModIdx := const2ModIdx.insertIfNew cname modIdx
|
||||
for cname in mod.extraConstNames do
|
||||
const2ModIdx := const2ModIdx.insertIfNew cname modIdx |>.1
|
||||
const2ModIdx := const2ModIdx.insertIfNew cname modIdx
|
||||
let constants : ConstMap := SMap.fromHashMap constantMap false
|
||||
let exts ← mkInitialExtensionStates
|
||||
let mut env : Environment := {
|
||||
@@ -934,7 +935,7 @@ builtin_initialize namespacesExt : SimplePersistentEnvExtension Name NameSSet
|
||||
6.18% of the runtime is here. It was 9.31% before the `HashMap` optimization.
|
||||
-/
|
||||
let capacity := as.foldl (init := 0) fun r e => r + e.size
|
||||
let map : HashMap Name Unit := mkHashMap capacity
|
||||
let map : Std.HashMap Name Unit := Std.HashMap.empty capacity
|
||||
let map := mkStateFromImportedEntries (fun map name => map.insert name ()) map as
|
||||
SMap.fromHashMap map |>.switch
|
||||
addEntryFn := fun s n => s.insert n
|
||||
|
||||
@@ -8,6 +8,7 @@ import Init.Data.Hashable
|
||||
import Lean.Data.KVMap
|
||||
import Lean.Data.SMap
|
||||
import Lean.Level
|
||||
import Std.Data.HashSet.Basic
|
||||
|
||||
namespace Lean
|
||||
|
||||
@@ -244,7 +245,7 @@ def FVarIdSet.insert (s : FVarIdSet) (fvarId : FVarId) : FVarIdSet :=
|
||||
A set of unique free variable identifiers implemented using hashtables.
|
||||
Hashtables are faster than red-black trees if they are used linearly.
|
||||
They are not persistent data-structures. -/
|
||||
def FVarIdHashSet := HashSet FVarId
|
||||
def FVarIdHashSet := Std.HashSet FVarId
|
||||
deriving Inhabited, EmptyCollection
|
||||
|
||||
/--
|
||||
@@ -1388,11 +1389,11 @@ def mkDecIsTrue (pred proof : Expr) :=
|
||||
def mkDecIsFalse (pred proof : Expr) :=
|
||||
mkAppB (mkConst `Decidable.isFalse) pred proof
|
||||
|
||||
abbrev ExprMap (α : Type) := HashMap Expr α
|
||||
abbrev ExprMap (α : Type) := Std.HashMap Expr α
|
||||
abbrev PersistentExprMap (α : Type) := PHashMap Expr α
|
||||
abbrev SExprMap (α : Type) := SMap Expr α
|
||||
|
||||
abbrev ExprSet := HashSet Expr
|
||||
abbrev ExprSet := Std.HashSet Expr
|
||||
abbrev PersistentExprSet := PHashSet Expr
|
||||
abbrev PExprSet := PersistentExprSet
|
||||
|
||||
@@ -1417,7 +1418,7 @@ instance : ToString ExprStructEq := ⟨fun e => toString e.val⟩
|
||||
|
||||
end ExprStructEq
|
||||
|
||||
abbrev ExprStructMap (α : Type) := HashMap ExprStructEq α
|
||||
abbrev ExprStructMap (α : Type) := Std.HashMap ExprStructEq α
|
||||
abbrev PersistentExprStructMap (α : Type) := PHashMap ExprStructEq α
|
||||
|
||||
namespace Expr
|
||||
@@ -1452,28 +1453,26 @@ partial def betaRev (f : Expr) (revArgs : Array Expr) (useZeta := false) (preser
|
||||
else
|
||||
let sz := revArgs.size
|
||||
let rec go (e : Expr) (i : Nat) : Expr :=
|
||||
let done (_ : Unit) : Expr :=
|
||||
let n := sz - i
|
||||
mkAppRevRange (e.instantiateRange n sz revArgs) 0 n revArgs
|
||||
match e with
|
||||
| Expr.lam _ _ b _ =>
|
||||
| .lam _ _ b _ =>
|
||||
if i + 1 < sz then
|
||||
go b (i+1)
|
||||
else
|
||||
let n := sz - (i + 1)
|
||||
mkAppRevRange (b.instantiateRange n sz revArgs) 0 n revArgs
|
||||
| Expr.letE _ _ v b _ =>
|
||||
b.instantiate revArgs
|
||||
| .letE _ _ v b _ =>
|
||||
if useZeta && i < sz then
|
||||
go (b.instantiate1 v) i
|
||||
else
|
||||
let n := sz - i
|
||||
mkAppRevRange (e.instantiateRange n sz revArgs) 0 n revArgs
|
||||
| Expr.mdata k b =>
|
||||
done ()
|
||||
| .mdata _ b =>
|
||||
if preserveMData then
|
||||
let n := sz - i
|
||||
mkMData k (mkAppRevRange (b.instantiateRange n sz revArgs) 0 n revArgs)
|
||||
done ()
|
||||
else
|
||||
go b i
|
||||
| b =>
|
||||
let n := sz - i
|
||||
mkAppRevRange (b.instantiateRange n sz revArgs) 0 n revArgs
|
||||
| _ => done ()
|
||||
go f 0
|
||||
|
||||
/--
|
||||
|
||||
@@ -31,7 +31,7 @@ namespace Lean
|
||||
abbrev LabelExtension := SimpleScopedEnvExtension Name (Array Name)
|
||||
|
||||
/-- The collection of all current `LabelExtension`s, indexed by name. -/
|
||||
abbrev LabelExtensionMap := HashMap Name LabelExtension
|
||||
abbrev LabelExtensionMap := Std.HashMap Name LabelExtension
|
||||
|
||||
/-- Store the current `LabelExtension`s. -/
|
||||
builtin_initialize labelExtensionMapRef : IO.Ref LabelExtensionMap ← IO.mkRef {}
|
||||
@@ -88,7 +88,7 @@ macro (name := _root_.Lean.Parser.Command.registerLabelAttr)
|
||||
/-- When `attrName` is an attribute created using `register_labelled_attr`,
|
||||
return the names of all declarations labelled using that attribute. -/
|
||||
def labelled (attrName : Name) : CoreM (Array Name) := do
|
||||
match (← labelExtensionMapRef.get).find? attrName with
|
||||
match (← labelExtensionMapRef.get)[attrName]? with
|
||||
| none => throwError "No extension named {attrName}"
|
||||
| some ext => pure <| ext.getState (← getEnv)
|
||||
|
||||
|
||||
@@ -234,6 +234,54 @@ structure SetupImportsResult where
|
||||
/-- Kernel trust level. -/
|
||||
trustLevel : UInt32 := 0
|
||||
|
||||
/-- Performance option used by cmdline driver. -/
|
||||
register_builtin_option internal.minimalSnapshots : Bool := {
|
||||
defValue := false
|
||||
descr := "reduce information stored in snapshots to the minimum necessary for the cmdline \
|
||||
driver: diagnostics per command and final full snapshot"
|
||||
}
|
||||
|
||||
/--
|
||||
Parses values of options registered during import and left by the C++ frontend as strings, fails if
|
||||
any option names remain unknown.
|
||||
-/
|
||||
def reparseOptions (opts : Options) : IO Options := do
|
||||
let mut opts := opts
|
||||
let decls ← getOptionDecls
|
||||
for (name, val) in opts do
|
||||
let .ofString val := val
|
||||
| continue -- Already parsed by C++
|
||||
-- Options can be prefixed with `weak` in order to turn off the error when the option is not
|
||||
-- defined
|
||||
let weak := name.getRoot == `weak
|
||||
if weak then
|
||||
opts := opts.erase name
|
||||
let name := name.replacePrefix `weak Name.anonymous
|
||||
let some decl := decls.find? name
|
||||
| unless weak do
|
||||
throw <| .userError s!"invalid -D parameter, unknown configuration option '{name}'
|
||||
|
||||
If the option is defined in this library, use '-D{`weak ++ name}' to set it conditionally"
|
||||
|
||||
match decl.defValue with
|
||||
| .ofBool _ =>
|
||||
match val with
|
||||
| "true" => opts := opts.insert name true
|
||||
| "false" => opts := opts.insert name false
|
||||
| _ =>
|
||||
throw <| .userError s!"invalid -D parameter, invalid configuration option '{val}' value, \
|
||||
it must be true/false"
|
||||
| .ofNat _ =>
|
||||
let some val := val.toNat?
|
||||
| throw <| .userError s!"invalid -D parameter, invalid configuration option '{val}' value, \
|
||||
it must be a natural number"
|
||||
opts := opts.insert name val
|
||||
| .ofString _ => opts := opts.insert name val
|
||||
| _ => throw <| .userError s!"invalid -D parameter, configuration option '{name}' \
|
||||
cannot be set in the command line, use set_option command"
|
||||
|
||||
return opts
|
||||
|
||||
/--
|
||||
Entry point of the Lean language processor.
|
||||
|
||||
@@ -279,9 +327,11 @@ where
|
||||
if let some oldProcSuccess := oldProcessed.result? then
|
||||
-- also wait on old command parse snapshot as parsing is cheap and may allow for
|
||||
-- elaboration reuse
|
||||
oldProcSuccess.firstCmdSnap.bindIO (sync := true) fun oldCmd =>
|
||||
oldProcSuccess.firstCmdSnap.bindIO (sync := true) fun oldCmd => do
|
||||
let prom ← IO.Promise.new
|
||||
let _ ← IO.asTask (parseCmd oldCmd newParserState oldProcSuccess.cmdState oldProcSuccess.cmdState.env prom ctx)
|
||||
return .pure { oldProcessed with result? := some { oldProcSuccess with
|
||||
firstCmdSnap := (← parseCmd oldCmd newParserState oldProcSuccess.cmdState ctx) } }
|
||||
firstCmdSnap := { range? := none, task := prom.result } } }
|
||||
else
|
||||
return .pure oldProcessed) } }
|
||||
else return old
|
||||
@@ -343,42 +393,68 @@ where
|
||||
let setup ← match (← setupImports stx) with
|
||||
| .ok setup => pure setup
|
||||
| .error snap => return snap
|
||||
|
||||
let startTime := (← IO.monoNanosNow).toFloat / 1000000000
|
||||
-- allows `headerEnv` to be leaked, which would live until the end of the process anyway
|
||||
let (headerEnv, msgLog) ← Elab.processHeader (leakEnv := true) stx setup.opts .empty
|
||||
ctx.toInputContext setup.trustLevel
|
||||
let stopTime := (← IO.monoNanosNow).toFloat / 1000000000
|
||||
let diagnostics := (← Snapshot.Diagnostics.ofMessageLog msgLog)
|
||||
if msgLog.hasErrors then
|
||||
return { diagnostics, result? := none }
|
||||
|
||||
let headerEnv := headerEnv.setMainModule setup.mainModuleName
|
||||
let cmdState := Elab.Command.mkState headerEnv msgLog setup.opts
|
||||
let cmdState := { cmdState with infoState := {
|
||||
enabled := true
|
||||
trees := #[Elab.InfoTree.context (.commandCtx {
|
||||
env := headerEnv
|
||||
fileMap := ctx.fileMap
|
||||
ngen := { namePrefix := `_import }
|
||||
}) (Elab.InfoTree.node
|
||||
(Elab.Info.ofCommandInfo { elaborator := `header, stx })
|
||||
(stx[1].getArgs.toList.map (fun importStx =>
|
||||
Elab.InfoTree.node (Elab.Info.ofCommandInfo {
|
||||
elaborator := `import
|
||||
stx := importStx
|
||||
}) #[].toPArray'
|
||||
)).toPArray'
|
||||
)].toPArray'
|
||||
}}
|
||||
let mut traceState := default
|
||||
if trace.profiler.output.get? setup.opts |>.isSome then
|
||||
traceState := {
|
||||
traces := #[{
|
||||
ref := .missing,
|
||||
msg := .trace { cls := `Import, startTime, stopTime }
|
||||
(.ofFormat "importing") #[]
|
||||
: TraceElem
|
||||
}].toPArray'
|
||||
}
|
||||
-- now that imports have been loaded, check options again
|
||||
let opts ← reparseOptions setup.opts
|
||||
let cmdState := Elab.Command.mkState headerEnv msgLog opts
|
||||
let cmdState := { cmdState with
|
||||
infoState := {
|
||||
enabled := true
|
||||
trees := #[Elab.InfoTree.context (.commandCtx {
|
||||
env := headerEnv
|
||||
fileMap := ctx.fileMap
|
||||
ngen := { namePrefix := `_import }
|
||||
}) (Elab.InfoTree.node
|
||||
(Elab.Info.ofCommandInfo { elaborator := `header, stx })
|
||||
(stx[1].getArgs.toList.map (fun importStx =>
|
||||
Elab.InfoTree.node (Elab.Info.ofCommandInfo {
|
||||
elaborator := `import
|
||||
stx := importStx
|
||||
}) #[].toPArray'
|
||||
)).toPArray'
|
||||
)].toPArray'
|
||||
}
|
||||
traceState
|
||||
}
|
||||
let prom ← IO.Promise.new
|
||||
-- The speedup of these `markPersistent`s is negligible but they help in making unexpected
|
||||
-- `inc_ref_cold`s more visible
|
||||
let parserState := Runtime.markPersistent parserState
|
||||
let cmdState := Runtime.markPersistent cmdState
|
||||
let ctx := Runtime.markPersistent ctx
|
||||
let _ ← IO.asTask (parseCmd none parserState cmdState cmdState.env prom ctx)
|
||||
return {
|
||||
diagnostics
|
||||
infoTree? := cmdState.infoState.trees[0]!
|
||||
result? := some {
|
||||
cmdState
|
||||
firstCmdSnap := (← parseCmd none parserState cmdState)
|
||||
firstCmdSnap := { range? := none, task := prom.result }
|
||||
}
|
||||
}
|
||||
|
||||
parseCmd (old? : Option CommandParsedSnapshot) (parserState : Parser.ModuleParserState)
|
||||
(cmdState : Command.State) : LeanProcessingM (SnapshotTask CommandParsedSnapshot) := do
|
||||
(cmdState : Command.State) (initEnv : Environment) (prom : IO.Promise CommandParsedSnapshot) :
|
||||
LeanProcessingM Unit := do
|
||||
let ctx ← read
|
||||
|
||||
-- check for cancellation, most likely during elaboration of previous command, before starting
|
||||
@@ -387,82 +463,100 @@ where
|
||||
-- this is a bit ugly as we don't want to adjust our API with `Option`s just for cancellation
|
||||
-- (as no-one should look at this result in that case) but anything containing `Environment`
|
||||
-- is not `Inhabited`
|
||||
return .pure <| .mk (nextCmdSnap? := none) {
|
||||
prom.resolve <| .mk (nextCmdSnap? := none) {
|
||||
diagnostics := .empty, stx := .missing, parserState
|
||||
elabSnap := .pure <| .ofTyped { diagnostics := .empty : SnapshotLeaf }
|
||||
finishedSnap := .pure { diagnostics := .empty, cmdState }
|
||||
tacticCache := (← IO.mkRef {})
|
||||
}
|
||||
return
|
||||
|
||||
let unchanged old newParserState : BaseIO CommandParsedSnapshot :=
|
||||
let unchanged old newParserState : BaseIO Unit :=
|
||||
-- when syntax is unchanged, reuse command processing task as is
|
||||
-- NOTE: even if the syntax tree is functionally unchanged, the new parser state may still
|
||||
-- have changed because of trailing whitespace and comments etc., so it is passed separately
|
||||
-- from `old`
|
||||
if let some oldNext := old.nextCmdSnap? then
|
||||
return .mk (data := old.data)
|
||||
(nextCmdSnap? := (← old.data.finishedSnap.bindIO (sync := true) fun oldFinished =>
|
||||
-- also wait on old command parse snapshot as parsing is cheap and may allow for
|
||||
-- elaboration reuse
|
||||
oldNext.bindIO (sync := true) fun oldNext => do
|
||||
parseCmd oldNext newParserState oldFinished.cmdState ctx))
|
||||
else return old -- terminal command, we're done!
|
||||
if let some oldNext := old.nextCmdSnap? then do
|
||||
let newProm ← IO.Promise.new
|
||||
let _ ← old.data.finishedSnap.bindIO fun oldFinished =>
|
||||
-- also wait on old command parse snapshot as parsing is cheap and may allow for
|
||||
-- elaboration reuse
|
||||
oldNext.bindIO (sync := true) fun oldNext => do
|
||||
parseCmd oldNext newParserState oldFinished.cmdState initEnv newProm ctx
|
||||
return .pure ()
|
||||
prom.resolve <| .mk (data := old.data) (nextCmdSnap? := some { range? := none, task := newProm.result })
|
||||
else prom.resolve old -- terminal command, we're done!
|
||||
|
||||
-- fast path, do not even start new task for this snapshot
|
||||
if let some old := old? then
|
||||
if let some nextCom ← old.nextCmdSnap?.bindM (·.get?) then
|
||||
if (← isBeforeEditPos nextCom.data.parserState.pos) then
|
||||
return .pure (← unchanged old old.data.parserState)
|
||||
return (← unchanged old old.data.parserState)
|
||||
|
||||
SnapshotTask.ofIO (some ⟨parserState.pos, ctx.input.endPos⟩) do
|
||||
let beginPos := parserState.pos
|
||||
let scope := cmdState.scopes.head!
|
||||
let pmctx := {
|
||||
env := cmdState.env, options := scope.opts, currNamespace := scope.currNamespace
|
||||
openDecls := scope.openDecls
|
||||
}
|
||||
let (stx, parserState, msgLog) := Parser.parseCommand ctx.toInputContext pmctx parserState
|
||||
.empty
|
||||
let beginPos := parserState.pos
|
||||
let scope := cmdState.scopes.head!
|
||||
let pmctx := {
|
||||
env := cmdState.env, options := scope.opts, currNamespace := scope.currNamespace
|
||||
openDecls := scope.openDecls
|
||||
}
|
||||
let (stx, parserState, msgLog) :=
|
||||
profileit "parsing" scope.opts fun _ =>
|
||||
Parser.parseCommand ctx.toInputContext pmctx parserState .empty
|
||||
|
||||
-- semi-fast path
|
||||
if let some old := old? then
|
||||
-- NOTE: as `parserState.pos` includes trailing whitespace, this forces reprocessing even if
|
||||
-- only that whitespace changes, which is wasteful but still necessary because it may
|
||||
-- influence the range of error messages such as from a trailing `exact`
|
||||
if stx.eqWithInfo old.data.stx then
|
||||
-- Here we must make sure to pass the *new* parser state; see NOTE in `unchanged`
|
||||
return (← unchanged old parserState)
|
||||
-- on first change, make sure to cancel old invocation
|
||||
-- TODO: pass token into incrementality-aware elaborators to improve reuse of still-valid,
|
||||
-- still-running elaboration steps?
|
||||
if let some tk := ctx.oldCancelTk? then
|
||||
tk.set
|
||||
-- semi-fast path
|
||||
if let some old := old? then
|
||||
-- NOTE: as `parserState.pos` includes trailing whitespace, this forces reprocessing even if
|
||||
-- only that whitespace changes, which is wasteful but still necessary because it may
|
||||
-- influence the range of error messages such as from a trailing `exact`
|
||||
if stx.eqWithInfo old.data.stx then
|
||||
-- Here we must make sure to pass the *new* parser state; see NOTE in `unchanged`
|
||||
return (← unchanged old parserState)
|
||||
-- on first change, make sure to cancel old invocation
|
||||
-- TODO: pass token into incrementality-aware elaborators to improve reuse of still-valid,
|
||||
-- still-running elaboration steps?
|
||||
if let some tk := ctx.oldCancelTk? then
|
||||
tk.set
|
||||
|
||||
-- definitely resolved in `doElab` task
|
||||
let elabPromise ← IO.Promise.new
|
||||
let tacticCache ← old?.map (·.data.tacticCache) |>.getDM (IO.mkRef {})
|
||||
let finishedSnap ←
|
||||
doElab stx cmdState beginPos
|
||||
{ old? := old?.map fun old => ⟨old.data.stx, old.data.elabSnap⟩, new := elabPromise }
|
||||
tacticCache
|
||||
ctx
|
||||
|
||||
let next? ← if Parser.isTerminalCommand stx then pure none
|
||||
-- for now, wait on "command finished" snapshot before parsing next command
|
||||
else some <$> finishedSnap.bindIO fun finished =>
|
||||
parseCmd none parserState finished.cmdState ctx
|
||||
return .mk (nextCmdSnap? := next?) {
|
||||
diagnostics := (← Snapshot.Diagnostics.ofMessageLog msgLog)
|
||||
stx
|
||||
parserState
|
||||
elabSnap := { range? := stx.getRange?, task := elabPromise.result }
|
||||
finishedSnap
|
||||
-- definitely resolved in `doElab` task
|
||||
let elabPromise ← IO.Promise.new
|
||||
let tacticCache ← old?.map (·.data.tacticCache) |>.getDM (IO.mkRef {})
|
||||
let finishedSnap ←
|
||||
doElab stx cmdState beginPos
|
||||
{ old? := old?.map fun old => ⟨old.data.stx, old.data.elabSnap⟩, new := elabPromise }
|
||||
tacticCache
|
||||
ctx
|
||||
|
||||
let minimalSnapshots := internal.minimalSnapshots.get cmdState.scopes.head!.opts
|
||||
let next? ← if Parser.isTerminalCommand stx then pure none
|
||||
-- for now, wait on "command finished" snapshot before parsing next command
|
||||
else some <$> IO.Promise.new
|
||||
let diagnostics ← Snapshot.Diagnostics.ofMessageLog msgLog
|
||||
let data := if minimalSnapshots && !Parser.isTerminalCommand stx then {
|
||||
diagnostics
|
||||
stx := .missing
|
||||
parserState := {}
|
||||
elabSnap := { range? := stx.getRange?, task := elabPromise.result }
|
||||
finishedSnap := .pure {
|
||||
diagnostics := finishedSnap.diagnostics
|
||||
infoTree? := none
|
||||
cmdState := {
|
||||
env := initEnv
|
||||
maxRecDepth := 0
|
||||
}
|
||||
}
|
||||
tacticCache
|
||||
} else {
|
||||
diagnostics, stx, parserState, tacticCache
|
||||
elabSnap := { range? := stx.getRange?, task := elabPromise.result }
|
||||
finishedSnap := .pure finishedSnap
|
||||
}
|
||||
prom.resolve <| .mk (nextCmdSnap? := next?.map ({ range? := some ⟨parserState.pos, ctx.input.endPos⟩, task := ·.result })) data
|
||||
if let some next := next? then
|
||||
parseCmd none parserState finishedSnap.cmdState initEnv next ctx
|
||||
|
||||
doElab (stx : Syntax) (cmdState : Command.State) (beginPos : String.Pos)
|
||||
(snap : SnapshotBundle DynamicSnapshot) (tacticCache : IO.Ref Tactic.Cache) :
|
||||
LeanProcessingM (SnapshotTask CommandFinishedSnapshot) := do
|
||||
LeanProcessingM CommandFinishedSnapshot := do
|
||||
let ctx ← read
|
||||
-- (Try to) use last line of command as range for final snapshot task. This ensures we do not
|
||||
-- retract the progress bar to a previous position in case the command support incremental
|
||||
@@ -471,48 +565,46 @@ where
|
||||
-- `parseCmd` and containing the entire range of the command will determine the reported
|
||||
-- progress and be resolved effectively at the same time as this snapshot task, so `tailPos` is
|
||||
-- irrelevant in this case.
|
||||
let tailPos := stx.getTailPos? |>.getD beginPos
|
||||
SnapshotTask.ofIO (some ⟨tailPos, tailPos⟩) do
|
||||
let scope := cmdState.scopes.head!
|
||||
let cmdStateRef ← IO.mkRef { cmdState with messages := .empty }
|
||||
/-
|
||||
The same snapshot may be executed by different tasks. So, to make sure `elabCommandTopLevel`
|
||||
has exclusive access to the cache, we create a fresh reference here. Before this change, the
|
||||
following `tacticCache.modify` would reset the tactic post cache while another snapshot was
|
||||
still using it.
|
||||
-/
|
||||
let tacticCacheNew ← IO.mkRef (← tacticCache.get)
|
||||
let cmdCtx : Elab.Command.Context := { ctx with
|
||||
cmdPos := beginPos
|
||||
tacticCache? := some tacticCacheNew
|
||||
snap? := some snap
|
||||
cancelTk? := some ctx.newCancelTk
|
||||
}
|
||||
let (output, _) ←
|
||||
IO.FS.withIsolatedStreams (isolateStderr := stderrAsMessages.get scope.opts) do
|
||||
liftM (m := BaseIO) do
|
||||
withLoggingExceptions
|
||||
(getResetInfoTrees *> Elab.Command.elabCommandTopLevel stx)
|
||||
cmdCtx cmdStateRef
|
||||
let postNew := (← tacticCacheNew.get).post
|
||||
tacticCache.modify fun _ => { pre := postNew, post := {} }
|
||||
let cmdState ← cmdStateRef.get
|
||||
let mut messages := cmdState.messages
|
||||
if !output.isEmpty then
|
||||
messages := messages.add {
|
||||
fileName := ctx.fileName
|
||||
severity := MessageSeverity.information
|
||||
pos := ctx.fileMap.toPosition beginPos
|
||||
data := output
|
||||
}
|
||||
let cmdState := { cmdState with messages }
|
||||
-- definitely resolve eventually
|
||||
snap.new.resolve <| .ofTyped { diagnostics := .empty : SnapshotLeaf }
|
||||
return {
|
||||
diagnostics := (← Snapshot.Diagnostics.ofMessageLog cmdState.messages)
|
||||
infoTree? := some cmdState.infoState.trees[0]!
|
||||
cmdState
|
||||
let scope := cmdState.scopes.head!
|
||||
let cmdStateRef ← IO.mkRef { cmdState with messages := .empty }
|
||||
/-
|
||||
The same snapshot may be executed by different tasks. So, to make sure `elabCommandTopLevel`
|
||||
has exclusive access to the cache, we create a fresh reference here. Before this change, the
|
||||
following `tacticCache.modify` would reset the tactic post cache while another snapshot was
|
||||
still using it.
|
||||
-/
|
||||
let tacticCacheNew ← IO.mkRef (← tacticCache.get)
|
||||
let cmdCtx : Elab.Command.Context := { ctx with
|
||||
cmdPos := beginPos
|
||||
tacticCache? := some tacticCacheNew
|
||||
snap? := if internal.minimalSnapshots.get scope.opts then none else snap
|
||||
cancelTk? := some ctx.newCancelTk
|
||||
}
|
||||
let (output, _) ←
|
||||
IO.FS.withIsolatedStreams (isolateStderr := stderrAsMessages.get scope.opts) do
|
||||
liftM (m := BaseIO) do
|
||||
withLoggingExceptions
|
||||
(getResetInfoTrees *> Elab.Command.elabCommandTopLevel stx)
|
||||
cmdCtx cmdStateRef
|
||||
let postNew := (← tacticCacheNew.get).post
|
||||
tacticCache.modify fun _ => { pre := postNew, post := {} }
|
||||
let cmdState ← cmdStateRef.get
|
||||
let mut messages := cmdState.messages
|
||||
if !output.isEmpty then
|
||||
messages := messages.add {
|
||||
fileName := ctx.fileName
|
||||
severity := MessageSeverity.information
|
||||
pos := ctx.fileMap.toPosition beginPos
|
||||
data := output
|
||||
}
|
||||
let cmdState := { cmdState with messages }
|
||||
-- definitely resolve eventually
|
||||
snap.new.resolve <| .ofTyped { diagnostics := .empty : SnapshotLeaf }
|
||||
return {
|
||||
diagnostics := (← Snapshot.Diagnostics.ofMessageLog cmdState.messages)
|
||||
infoTree? := some cmdState.infoState.trees[0]!
|
||||
cmdState
|
||||
}
|
||||
|
||||
/--
|
||||
Convenience function for tool uses of the language processor that skips header handling.
|
||||
@@ -520,14 +612,15 @@ Convenience function for tool uses of the language processor that skips header h
|
||||
def processCommands (inputCtx : Parser.InputContext) (parserState : Parser.ModuleParserState)
|
||||
(commandState : Command.State)
|
||||
(old? : Option (Parser.InputContext × CommandParsedSnapshot) := none) :
|
||||
BaseIO (SnapshotTask CommandParsedSnapshot) := do
|
||||
process.parseCmd (old?.map (·.2)) parserState commandState
|
||||
BaseIO (Task CommandParsedSnapshot) := do
|
||||
let prom ← IO.Promise.new
|
||||
process.parseCmd (old?.map (·.2)) parserState commandState commandState.env prom
|
||||
|>.run (old?.map (·.1))
|
||||
|>.run { inputCtx with }
|
||||
return prom.result
|
||||
|
||||
|
||||
/-- Waits for and returns final environment, if importing was successful. -/
|
||||
partial def waitForFinalEnv? (snap : InitialSnapshot) : Option Environment := do
|
||||
/-- Waits for and returns final command state, if importing was successful. -/
|
||||
partial def waitForFinalCmdState? (snap : InitialSnapshot) : Option Command.State := do
|
||||
let snap ← snap.result?
|
||||
let snap ← snap.processedSnap.get.result?
|
||||
goCmd snap.firstCmdSnap.get
|
||||
@@ -535,6 +628,6 @@ where goCmd snap :=
|
||||
if let some next := snap.nextCmdSnap? then
|
||||
goCmd next.get
|
||||
else
|
||||
snap.data.finishedSnap.get.cmdState.env
|
||||
snap.data.finishedSnap.get.cmdState
|
||||
|
||||
end Lean
|
||||
|
||||
@@ -5,8 +5,6 @@ Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Init.Data.Array.QSort
|
||||
import Lean.Data.HashMap
|
||||
import Lean.Data.HashSet
|
||||
import Lean.Data.PersistentHashMap
|
||||
import Lean.Data.PersistentHashSet
|
||||
import Lean.Hygiene
|
||||
@@ -614,9 +612,9 @@ where
|
||||
|
||||
end Level
|
||||
|
||||
abbrev LevelMap (α : Type) := HashMap Level α
|
||||
abbrev LevelMap (α : Type) := Std.HashMap Level α
|
||||
abbrev PersistentLevelMap (α : Type) := PHashMap Level α
|
||||
abbrev LevelSet := HashSet Level
|
||||
abbrev LevelSet := Std.HashSet Level
|
||||
abbrev PersistentLevelSet := PHashSet Level
|
||||
abbrev PLevelSet := PersistentLevelSet
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ def constructorNameAsVariable : Linter where
|
||||
| return
|
||||
|
||||
let infoTrees := (← get).infoState.trees.toArray
|
||||
let warnings : IO.Ref (Lean.HashMap String.Range (Syntax × Name × Name)) ← IO.mkRef {}
|
||||
let warnings : IO.Ref (Std.HashMap String.Range (Syntax × Name × Name)) ← IO.mkRef {}
|
||||
|
||||
for tree in infoTrees do
|
||||
tree.visitM' (preNode := fun ci info _ => do
|
||||
|
||||
@@ -149,7 +149,7 @@ def checkDecl : SimpleHandler := fun stx => do
|
||||
lintField rest[1][0] stx[1] "computed field"
|
||||
else if rest.getKind == ``«structure» then
|
||||
unless rest[5][2].isNone do
|
||||
let redecls : HashSet String.Pos :=
|
||||
let redecls : Std.HashSet String.Pos :=
|
||||
(← get).infoState.trees.foldl (init := {}) fun s tree =>
|
||||
tree.foldInfo (init := s) fun _ info s =>
|
||||
if let .ofFieldRedeclInfo info := info then
|
||||
|
||||
@@ -270,14 +270,14 @@ pointer identity and does not store the objects, so it is important not to store
|
||||
pointer to an object in the map, or it can be freed and reused, resulting in incorrect behavior.
|
||||
|
||||
Returns `true` if the object was not already in the set. -/
|
||||
unsafe def insertObjImpl {α : Type} (set : IO.Ref (HashSet USize)) (a : α) : IO Bool := do
|
||||
unsafe def insertObjImpl {α : Type} (set : IO.Ref (Std.HashSet USize)) (a : α) : IO Bool := do
|
||||
if (← set.get).contains (ptrAddrUnsafe a) then
|
||||
return false
|
||||
set.modify (·.insert (ptrAddrUnsafe a))
|
||||
return true
|
||||
|
||||
@[inherit_doc insertObjImpl, implemented_by insertObjImpl]
|
||||
opaque insertObj {α : Type} (set : IO.Ref (HashSet USize)) (a : α) : IO Bool
|
||||
opaque insertObj {α : Type} (set : IO.Ref (Std.HashSet USize)) (a : α) : IO Bool
|
||||
|
||||
/--
|
||||
Collects into `fvarUses` all `fvar`s occurring in the `Expr`s in `assignments`.
|
||||
@@ -285,8 +285,8 @@ This implementation respects subterm sharing in both the `PersistentHashMap` and
|
||||
to ensure that pointer-equal subobjects are not visited multiple times, which is important
|
||||
in practice because these expressions are very frequently highly shared.
|
||||
-/
|
||||
partial def visitAssignments (set : IO.Ref (HashSet USize))
|
||||
(fvarUses : IO.Ref (HashSet FVarId))
|
||||
partial def visitAssignments (set : IO.Ref (Std.HashSet USize))
|
||||
(fvarUses : IO.Ref (Std.HashSet FVarId))
|
||||
(assignments : Array (PersistentHashMap MVarId Expr)) : IO Unit := do
|
||||
MonadCacheT.run do
|
||||
for assignment in assignments do
|
||||
@@ -316,8 +316,8 @@ where
|
||||
/-- Given `aliases` as a map from an alias to what it aliases, we get the original
|
||||
term by recursion. This has no cycle detection, so if `aliases` contains a loop
|
||||
then this function will recurse infinitely. -/
|
||||
partial def followAliases (aliases : HashMap FVarId FVarId) (x : FVarId) : FVarId :=
|
||||
match aliases.find? x with
|
||||
partial def followAliases (aliases : Std.HashMap FVarId FVarId) (x : FVarId) : FVarId :=
|
||||
match aliases[x]? with
|
||||
| none => x
|
||||
| some y => followAliases aliases y
|
||||
|
||||
@@ -343,17 +343,17 @@ structure References where
|
||||
the spans for `foo`, `bar`, and `baz`. Global definitions are always treated as used.
|
||||
(It would be nice to be able to detect unused global definitions but this requires more
|
||||
information than the linter framework can provide.) -/
|
||||
constDecls : HashSet String.Range := .empty
|
||||
constDecls : Std.HashSet String.Range := .empty
|
||||
/-- The collection of all local declarations, organized by the span of the declaration.
|
||||
We collapse all declarations declared at the same position into a single record using
|
||||
`FVarDefinition.aliases`. -/
|
||||
fvarDefs : HashMap String.Range FVarDefinition := .empty
|
||||
fvarDefs : Std.HashMap String.Range FVarDefinition := .empty
|
||||
/-- The set of `FVarId`s that are used directly. These may or may not be aliases. -/
|
||||
fvarUses : HashSet FVarId := .empty
|
||||
fvarUses : Std.HashSet FVarId := .empty
|
||||
/-- A mapping from alias to original FVarId. We don't guarantee that the value is not itself
|
||||
an alias, but we use `followAliases` when adding new elements to try to avoid long chains. -/
|
||||
-- TODO: use a `UnionFind` data structure here
|
||||
fvarAliases : HashMap FVarId FVarId := .empty
|
||||
fvarAliases : Std.HashMap FVarId FVarId := .empty
|
||||
/-- Collection of all `MetavarContext`s following the execution of a tactic. We trawl these
|
||||
if needed to find additional `fvarUses`. -/
|
||||
assignments : Array (PersistentHashMap MVarId Expr) := #[]
|
||||
@@ -391,7 +391,7 @@ def collectReferences (infoTrees : Array Elab.InfoTree) (cmdStxRange : String.Ra
|
||||
if s.startsWith "_" then return
|
||||
-- Record this either as a new `fvarDefs`, or an alias of an existing one
|
||||
modify fun s =>
|
||||
if let some ref := s.fvarDefs.find? range then
|
||||
if let some ref := s.fvarDefs[range]? then
|
||||
{ s with fvarDefs := s.fvarDefs.insert range { ref with aliases := ref.aliases.push id } }
|
||||
else
|
||||
{ s with fvarDefs := s.fvarDefs.insert range { userName := ldecl.userName, stx, opts, aliases := #[id] } }
|
||||
@@ -444,7 +444,7 @@ def unusedVariables : Linter where
|
||||
-- Resolve all recursive references in `fvarAliases`.
|
||||
-- At this point everything in `fvarAliases` is guaranteed not to be itself an alias,
|
||||
-- and should point to some element of `FVarDefinition.aliases` in `s.fvarDefs`
|
||||
let fvarAliases : HashMap FVarId FVarId := s.fvarAliases.fold (init := {}) fun m id baseId =>
|
||||
let fvarAliases : Std.HashMap FVarId FVarId := s.fvarAliases.fold (init := {}) fun m id baseId =>
|
||||
m.insert id (followAliases s.fvarAliases baseId)
|
||||
|
||||
-- Collect all non-alias fvars corresponding to `fvarUses` by resolving aliases in the list.
|
||||
@@ -461,7 +461,7 @@ def unusedVariables : Linter where
|
||||
let fvarUses ← fvarUsesRef.get
|
||||
-- If any of the `fvar`s corresponding to this declaration is (an alias of) a variable in
|
||||
-- `fvarUses`, then it is used
|
||||
if aliases.any fun id => fvarUses.contains (fvarAliases.findD id id) then continue
|
||||
if aliases.any fun id => fvarUses.contains (fvarAliases.getD id id) then continue
|
||||
-- If this is a global declaration then it is (potentially) used after the command
|
||||
if s.constDecls.contains range then continue
|
||||
|
||||
@@ -496,7 +496,7 @@ def unusedVariables : Linter where
|
||||
initializedMVars := true
|
||||
let fvarUses ← fvarUsesRef.get
|
||||
-- Redo the initial check because `fvarUses` could be bigger now
|
||||
if aliases.any fun id => fvarUses.contains (fvarAliases.findD id id) then continue
|
||||
if aliases.any fun id => fvarUses.contains (fvarAliases.getD id id) then continue
|
||||
|
||||
-- If we made it this far then the variable is unused and not ignored
|
||||
unused := unused.push (declStx, userName)
|
||||
|
||||
@@ -16,8 +16,8 @@ structure State where
|
||||
nextParamIdx : Nat := 0
|
||||
paramNames : Array Name := #[]
|
||||
fvars : Array Expr := #[]
|
||||
lmap : HashMap LMVarId Level := {}
|
||||
emap : HashMap MVarId Expr := {}
|
||||
lmap : Std.HashMap LMVarId Level := {}
|
||||
emap : Std.HashMap MVarId Expr := {}
|
||||
abstractLevels : Bool -- whether to abstract level mvars
|
||||
|
||||
abbrev M := StateM State
|
||||
@@ -54,7 +54,7 @@ private partial def abstractLevelMVars (u : Level) : M Level := do
|
||||
if depth != s.mctx.depth then
|
||||
return u -- metavariables from lower depths are treated as constants
|
||||
else
|
||||
match s.lmap.find? mvarId with
|
||||
match s.lmap[mvarId]? with
|
||||
| some u => pure u
|
||||
| none =>
|
||||
let paramId := Name.mkNum `_abstMVar s.nextParamIdx
|
||||
@@ -87,7 +87,7 @@ partial def abstractExprMVars (e : Expr) : M Expr := do
|
||||
if e != eNew then
|
||||
abstractExprMVars eNew
|
||||
else
|
||||
match (← get).emap.find? mvarId with
|
||||
match (← get).emap[mvarId]? with
|
||||
| some e =>
|
||||
return e
|
||||
| none =>
|
||||
|
||||
@@ -5,9 +5,9 @@ Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.Util.ShareCommon
|
||||
import Lean.Data.HashMap
|
||||
import Lean.Meta.Basic
|
||||
import Lean.Meta.FunInfo
|
||||
import Std.Data.HashMap.Raw
|
||||
|
||||
namespace Lean.Meta
|
||||
namespace Canonicalizer
|
||||
@@ -47,12 +47,12 @@ State for the `CanonM` monad.
|
||||
-/
|
||||
structure State where
|
||||
/-- Mapping from `Expr` to hash. -/
|
||||
-- We use `HashMapImp` to ensure we don't have to tag `State` as `unsafe`.
|
||||
cache : HashMapImp ExprVisited UInt64 := mkHashMapImp
|
||||
-- We use `HashMap.Raw` to ensure we don't have to tag `State` as `unsafe`.
|
||||
cache : Std.HashMap.Raw ExprVisited UInt64 := Std.HashMap.Raw.empty
|
||||
/--
|
||||
Given a hashcode `k` and `keyToExprs.find? h = some es`, we have that all `es` have hashcode `k`, and
|
||||
are not definitionally equal modulo the transparency setting used. -/
|
||||
keyToExprs : HashMap UInt64 (List Expr) := mkHashMap
|
||||
keyToExprs : Std.HashMap UInt64 (List Expr) := ∅
|
||||
|
||||
instance : Inhabited State where
|
||||
default := {}
|
||||
@@ -70,7 +70,7 @@ def CanonM.run (x : CanonM α) (transparency := TransparencyMode.instances) (s :
|
||||
StateRefT'.run (x transparency) s
|
||||
|
||||
private partial def mkKey (e : Expr) : CanonM UInt64 := do
|
||||
if let some hash := unsafe (← get).cache.find? { e } then
|
||||
if let some hash := unsafe (← get).cache.get? { e } then
|
||||
return hash
|
||||
else
|
||||
let key ← match e with
|
||||
@@ -107,7 +107,7 @@ private partial def mkKey (e : Expr) : CanonM UInt64 := do
|
||||
return mixHash (← mkKey v) (← mkKey b)
|
||||
| .proj _ i s =>
|
||||
return mixHash i.toUInt64 (← mkKey s)
|
||||
unsafe modify fun { cache, keyToExprs} => { keyToExprs, cache := cache.insert { e } key |>.1 }
|
||||
unsafe modify fun { cache, keyToExprs} => { keyToExprs, cache := cache.insert { e } key }
|
||||
return key
|
||||
|
||||
/--
|
||||
@@ -116,7 +116,7 @@ private partial def mkKey (e : Expr) : CanonM UInt64 := do
|
||||
def canon (e : Expr) : CanonM Expr := do
|
||||
let k ← mkKey e
|
||||
-- Find all expressions canonicalized before that have the same key.
|
||||
if let some es' := unsafe (← get).keyToExprs.find? k then
|
||||
if let some es' := unsafe (← get).keyToExprs[k]? then
|
||||
withTransparency (← read) do
|
||||
for e' in es' do
|
||||
-- Found an expression `e'` that is definitionally equal to `e` and share the same key.
|
||||
|
||||
@@ -127,7 +127,7 @@ abbrev ClosureM := ReaderT Context $ StateRefT State MetaM
|
||||
pure u
|
||||
else
|
||||
let s ← get
|
||||
match s.visitedLevel.find? u with
|
||||
match s.visitedLevel[u]? with
|
||||
| some v => pure v
|
||||
| none => do
|
||||
let v ← f u
|
||||
@@ -139,7 +139,7 @@ abbrev ClosureM := ReaderT Context $ StateRefT State MetaM
|
||||
pure e
|
||||
else
|
||||
let s ← get
|
||||
match s.visitedExpr.find? e with
|
||||
match s.visitedExpr.get? e with
|
||||
| some r => pure r
|
||||
| none =>
|
||||
let r ← f e
|
||||
|
||||
@@ -52,14 +52,14 @@ which appear in the type and local context of `mvarId`, as well as the
|
||||
metavariables which *those* metavariables depend on, etc.
|
||||
-/
|
||||
partial def _root_.Lean.MVarId.getMVarDependencies (mvarId : MVarId) (includeDelayed := false) :
|
||||
MetaM (HashSet MVarId) :=
|
||||
MetaM (Std.HashSet MVarId) :=
|
||||
(·.snd) <$> (go mvarId).run {}
|
||||
where
|
||||
/-- Auxiliary definition for `getMVarDependencies`. -/
|
||||
addMVars (e : Expr) : StateRefT (HashSet MVarId) MetaM Unit := do
|
||||
addMVars (e : Expr) : StateRefT (Std.HashSet MVarId) MetaM Unit := do
|
||||
let mvars ← getMVars e
|
||||
let mut s ← get
|
||||
set ({} : HashSet MVarId) -- Ensure that `s` is not shared.
|
||||
set ({} : Std.HashSet MVarId) -- Ensure that `s` is not shared.
|
||||
for mvarId in mvars do
|
||||
if ← pure includeDelayed <||> notM (mvarId.isDelayedAssigned) then
|
||||
s := s.insert mvarId
|
||||
@@ -67,7 +67,7 @@ where
|
||||
mvars.forM go
|
||||
|
||||
/-- Auxiliary definition for `getMVarDependencies`. -/
|
||||
go (mvarId : MVarId) : StateRefT (HashSet MVarId) MetaM Unit :=
|
||||
go (mvarId : MVarId) : StateRefT (Std.HashSet MVarId) MetaM Unit :=
|
||||
withIncRecDepth do
|
||||
let mdecl ← mvarId.getDecl
|
||||
addMVars mdecl.type
|
||||
|
||||
@@ -35,7 +35,7 @@ structure DiagSummary where
|
||||
def DiagSummary.isEmpty (s : DiagSummary) : Bool :=
|
||||
s.data.isEmpty
|
||||
|
||||
def mkDiagSummary (counters : PHashMap Name Nat) (p : Name → Bool := fun _ => true) : MetaM DiagSummary := do
|
||||
def mkDiagSummary (cls : Name) (counters : PHashMap Name Nat) (p : Name → Bool := fun _ => true) : MetaM DiagSummary := do
|
||||
let threshold := diagnostics.threshold.get (← getOptions)
|
||||
let entries := collectAboveThreshold counters threshold p Name.lt
|
||||
if entries.isEmpty then
|
||||
@@ -43,22 +43,22 @@ def mkDiagSummary (counters : PHashMap Name Nat) (p : Name → Bool := fun _ =>
|
||||
else
|
||||
let mut data := #[]
|
||||
for (declName, counter) in entries do
|
||||
data := data.push m!"{if data.isEmpty then " " else "\n"}{MessageData.ofConst (← mkConstWithLevelParams declName)} ↦ {counter}"
|
||||
data := data.push <| .trace { cls } m!"{MessageData.ofConst (← mkConstWithLevelParams declName)} ↦ {counter}" #[]
|
||||
return { data, max := entries[0]!.2 }
|
||||
|
||||
def mkDiagSummaryForUnfolded (counters : PHashMap Name Nat) (instances := false) : MetaM DiagSummary := do
|
||||
let env ← getEnv
|
||||
mkDiagSummary counters fun declName =>
|
||||
mkDiagSummary `reduction counters fun declName =>
|
||||
getReducibilityStatusCore env declName matches .semireducible
|
||||
&& isInstanceCore env declName == instances
|
||||
|
||||
def mkDiagSummaryForUnfoldedReducible (counters : PHashMap Name Nat) : MetaM DiagSummary := do
|
||||
let env ← getEnv
|
||||
mkDiagSummary counters fun declName =>
|
||||
mkDiagSummary `reduction counters fun declName =>
|
||||
getReducibilityStatusCore env declName matches .reducible
|
||||
|
||||
def mkDiagSummaryForUsedInstances : MetaM DiagSummary := do
|
||||
mkDiagSummary (← get).diag.instanceCounter
|
||||
mkDiagSummary `type_class (← get).diag.instanceCounter
|
||||
|
||||
def mkDiagSynthPendingFailure (failures : PHashMap Expr MessageData) : MetaM DiagSummary := do
|
||||
if failures.isEmpty then
|
||||
@@ -66,7 +66,7 @@ def mkDiagSynthPendingFailure (failures : PHashMap Expr MessageData) : MetaM Dia
|
||||
else
|
||||
let mut data := #[]
|
||||
for (_, msg) in failures do
|
||||
data := data.push m!"{if data.isEmpty then " " else "\n"}{msg}"
|
||||
data := data.push <| .trace { cls := `type_class } msg #[]
|
||||
return { data }
|
||||
|
||||
/--
|
||||
@@ -85,10 +85,10 @@ def reportDiag : MetaM Unit := do
|
||||
let unfoldDefault ← mkDiagSummaryForUnfolded unfoldCounter
|
||||
let unfoldInstance ← mkDiagSummaryForUnfolded unfoldCounter (instances := true)
|
||||
let unfoldReducible ← mkDiagSummaryForUnfoldedReducible unfoldCounter
|
||||
let heu ← mkDiagSummary (← get).diag.heuristicCounter
|
||||
let heu ← mkDiagSummary `def_eq (← get).diag.heuristicCounter
|
||||
let inst ← mkDiagSummaryForUsedInstances
|
||||
let synthPending ← mkDiagSynthPendingFailure (← get).diag.synthPendingFailures
|
||||
let unfoldKernel ← mkDiagSummary (Kernel.getDiagnostics (← getEnv)).unfoldCounter
|
||||
let unfoldKernel ← mkDiagSummary `kernel (Kernel.getDiagnostics (← getEnv)).unfoldCounter
|
||||
let m := MessageData.nil
|
||||
let m := appendSection m `reduction "unfolded declarations" unfoldDefault
|
||||
let m := appendSection m `reduction "unfolded instances" unfoldInstance
|
||||
|
||||
@@ -695,7 +695,7 @@ def throwOutOfScopeFVar : CheckAssignmentM α :=
|
||||
throw <| Exception.internal outOfScopeExceptionId
|
||||
|
||||
private def findCached? (e : Expr) : CheckAssignmentM (Option Expr) := do
|
||||
return (← get).cache.find? e
|
||||
return (← get).cache.get? e
|
||||
|
||||
private def cache (e r : Expr) : CheckAssignmentM Unit := do
|
||||
modify fun s => { s with cache := s.cache.insert e r }
|
||||
|
||||
@@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Author: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.Data.AssocList
|
||||
import Lean.HeadIndex
|
||||
import Lean.Meta.Basic
|
||||
|
||||
|
||||
@@ -286,7 +286,7 @@ private structure Trie (α : Type) where
|
||||
/-- Index of trie matching star. -/
|
||||
star : TrieIndex
|
||||
/-- Following matches based on key of trie. -/
|
||||
children : HashMap Key TrieIndex
|
||||
children : Std.HashMap Key TrieIndex
|
||||
/-- Lazy entries at this trie that are not processed. -/
|
||||
pending : Array (LazyEntry α) := #[]
|
||||
deriving Inhabited
|
||||
@@ -318,7 +318,7 @@ structure LazyDiscrTree (α : Type) where
|
||||
/-- Backing array of trie entries. Should be owned by this trie. -/
|
||||
tries : Array (LazyDiscrTree.Trie α) := #[default]
|
||||
/-- Map from discriminator trie roots to the index. -/
|
||||
roots : Lean.HashMap LazyDiscrTree.Key LazyDiscrTree.TrieIndex := {}
|
||||
roots : Std.HashMap LazyDiscrTree.Key LazyDiscrTree.TrieIndex := {}
|
||||
|
||||
namespace LazyDiscrTree
|
||||
|
||||
@@ -445,9 +445,9 @@ private def addLazyEntryToTrie (i:TrieIndex) (e : LazyEntry α) : MatchM α Unit
|
||||
modify (·.modify i (·.pushPending e))
|
||||
|
||||
private def evalLazyEntry (config : WhnfCoreConfig)
|
||||
(p : Array α × TrieIndex × HashMap Key TrieIndex)
|
||||
(p : Array α × TrieIndex × Std.HashMap Key TrieIndex)
|
||||
(entry : LazyEntry α)
|
||||
: MatchM α (Array α × TrieIndex × HashMap Key TrieIndex) := do
|
||||
: MatchM α (Array α × TrieIndex × Std.HashMap Key TrieIndex) := do
|
||||
let (values, starIdx, children) := p
|
||||
let (todo, lctx, v) := entry
|
||||
if todo.isEmpty then
|
||||
@@ -465,7 +465,7 @@ private def evalLazyEntry (config : WhnfCoreConfig)
|
||||
addLazyEntryToTrie starIdx (todo, lctx, v)
|
||||
pure (values, starIdx, children)
|
||||
else
|
||||
match children.find? k with
|
||||
match children[k]? with
|
||||
| none =>
|
||||
let children := children.insert k (← newTrie (todo, lctx, v))
|
||||
pure (values, starIdx, children)
|
||||
@@ -478,16 +478,16 @@ This evaluates all lazy entries in a trie and updates `values`, `starIdx`, and `
|
||||
accordingly.
|
||||
-/
|
||||
private partial def evalLazyEntries (config : WhnfCoreConfig)
|
||||
(values : Array α) (starIdx : TrieIndex) (children : HashMap Key TrieIndex)
|
||||
(values : Array α) (starIdx : TrieIndex) (children : Std.HashMap Key TrieIndex)
|
||||
(entries : Array (LazyEntry α)) :
|
||||
MatchM α (Array α × TrieIndex × HashMap Key TrieIndex) := do
|
||||
MatchM α (Array α × TrieIndex × Std.HashMap Key TrieIndex) := do
|
||||
let mut values := values
|
||||
let mut starIdx := starIdx
|
||||
let mut children := children
|
||||
entries.foldlM (init := (values, starIdx, children)) (evalLazyEntry config)
|
||||
|
||||
private def evalNode (c : TrieIndex) :
|
||||
MatchM α (Array α × TrieIndex × HashMap Key TrieIndex) := do
|
||||
MatchM α (Array α × TrieIndex × Std.HashMap Key TrieIndex) := do
|
||||
let .node vs star cs pending := (←get).get! c
|
||||
if pending.size = 0 then
|
||||
pure (vs, star, cs)
|
||||
@@ -508,7 +508,7 @@ def dropKeyAux (next : TrieIndex) (rest : List Key) :
|
||||
| [] =>
|
||||
modify (·.set! next {values := #[], star, children})
|
||||
| k :: r => do
|
||||
let next := if k == .star then star else children.findD k 0
|
||||
let next := if k == .star then star else children.getD k 0
|
||||
dropKeyAux next r
|
||||
|
||||
/--
|
||||
@@ -519,7 +519,7 @@ def dropKey (t : LazyDiscrTree α) (path : List LazyDiscrTree.Key) : MetaM (Lazy
|
||||
match path with
|
||||
| [] => pure t
|
||||
| rootKey :: rest => do
|
||||
let idx := t.roots.findD rootKey 0
|
||||
let idx := t.roots.getD rootKey 0
|
||||
Prod.snd <$> runMatch t (dropKeyAux idx rest)
|
||||
|
||||
/--
|
||||
@@ -628,7 +628,7 @@ private partial def getMatchLoop (cases : Array PartialMatch) (result : MatchRes
|
||||
else
|
||||
cases.push { todo, score := ca.score, c := star }
|
||||
let pushNonStar (k : Key) (args : Array Expr) (cases : Array PartialMatch) :=
|
||||
match cs.find? k with
|
||||
match cs[k]? with
|
||||
| none => cases
|
||||
| some c => cases.push { todo := todo ++ args, score := ca.score + 1, c }
|
||||
let cases := pushStar cases
|
||||
@@ -650,8 +650,8 @@ private partial def getMatchLoop (cases : Array PartialMatch) (result : MatchRes
|
||||
cases |> pushNonStar k args
|
||||
getMatchLoop cases result
|
||||
|
||||
private def getStarResult (root : Lean.HashMap Key TrieIndex) : MatchM α (MatchResult α) :=
|
||||
match root.find? .star with
|
||||
private def getStarResult (root : Std.HashMap Key TrieIndex) : MatchM α (MatchResult α) :=
|
||||
match root[Key.star]? with
|
||||
| none =>
|
||||
pure <| {}
|
||||
| some idx => do
|
||||
@@ -661,16 +661,16 @@ private def getStarResult (root : Lean.HashMap Key TrieIndex) : MatchM α (Match
|
||||
/-
|
||||
Add partial match to cases if discriminator tree root map has potential matches.
|
||||
-/
|
||||
private def pushRootCase (r : Lean.HashMap Key TrieIndex) (k : Key) (args : Array Expr)
|
||||
private def pushRootCase (r : Std.HashMap Key TrieIndex) (k : Key) (args : Array Expr)
|
||||
(cases : Array PartialMatch) : Array PartialMatch :=
|
||||
match r.find? k with
|
||||
match r[k]? with
|
||||
| none => cases
|
||||
| some c => cases.push { todo := args, score := 1, c }
|
||||
|
||||
/--
|
||||
Find values that match `e` in `root`.
|
||||
-/
|
||||
private def getMatchCore (root : Lean.HashMap Key TrieIndex) (e : Expr) :
|
||||
private def getMatchCore (root : Std.HashMap Key TrieIndex) (e : Expr) :
|
||||
MatchM α (MatchResult α) := do
|
||||
let result ← getStarResult root
|
||||
let (k, args) ← MatchClone.getMatchKeyArgs e (root := true) (← read)
|
||||
@@ -701,7 +701,7 @@ of elements using concurrent functions for generating entries.
|
||||
-/
|
||||
private structure PreDiscrTree (α : Type) where
|
||||
/-- Maps keys to index in tries array. -/
|
||||
roots : HashMap Key Nat := {}
|
||||
roots : Std.HashMap Key Nat := {}
|
||||
/-- Lazy entries for root of trie. -/
|
||||
tries : Array (Array (LazyEntry α)) := #[]
|
||||
deriving Inhabited
|
||||
@@ -711,7 +711,7 @@ namespace PreDiscrTree
|
||||
private def modifyAt (d : PreDiscrTree α) (k : Key)
|
||||
(f : Array (LazyEntry α) → Array (LazyEntry α)) : PreDiscrTree α :=
|
||||
let { roots, tries } := d
|
||||
match roots.find? k with
|
||||
match roots[k]? with
|
||||
| .none =>
|
||||
let roots := roots.insert k tries.size
|
||||
{ roots, tries := tries.push (f #[]) }
|
||||
|
||||
@@ -68,7 +68,7 @@ where
|
||||
loop lhss alts minors
|
||||
|
||||
structure State where
|
||||
used : HashSet Nat := {} -- used alternatives
|
||||
used : Std.HashSet Nat := {} -- used alternatives
|
||||
counterExamples : List (List Example) := []
|
||||
|
||||
/-- Return true if the given (sub-)problem has been solved. -/
|
||||
|
||||
@@ -28,17 +28,17 @@ such as `contradiction`.
|
||||
-/
|
||||
private def _root_.Lean.MVarId.contradictionQuick (mvarId : MVarId) : MetaM Bool := do
|
||||
mvarId.withContext do
|
||||
let mut posMap : HashMap Expr FVarId := {}
|
||||
let mut negMap : HashMap Expr FVarId := {}
|
||||
let mut posMap : Std.HashMap Expr FVarId := {}
|
||||
let mut negMap : Std.HashMap Expr FVarId := {}
|
||||
for localDecl in (← getLCtx) do
|
||||
unless localDecl.isImplementationDetail do
|
||||
if let some p ← matchNot? localDecl.type then
|
||||
if let some pFVarId := posMap.find? p then
|
||||
if let some pFVarId := posMap[p]? then
|
||||
mvarId.assign (← mkAbsurd (← mvarId.getType) (mkFVar pFVarId) localDecl.toExpr)
|
||||
return true
|
||||
negMap := negMap.insert p localDecl.fvarId
|
||||
if (← isProp localDecl.type) then
|
||||
if let some nFVarId := negMap.find? localDecl.type then
|
||||
if let some nFVarId := negMap[localDecl.type]? then
|
||||
mvarId.assign (← mkAbsurd (← mvarId.getType) localDecl.toExpr (mkFVar nFVarId))
|
||||
return true
|
||||
posMap := posMap.insert localDecl.type localDecl.fvarId
|
||||
|
||||
@@ -97,8 +97,8 @@ namespace MkTableKey
|
||||
|
||||
structure State where
|
||||
nextIdx : Nat := 0
|
||||
lmap : HashMap LMVarId Level := {}
|
||||
emap : HashMap MVarId Expr := {}
|
||||
lmap : Std.HashMap LMVarId Level := {}
|
||||
emap : Std.HashMap MVarId Expr := {}
|
||||
mctx : MetavarContext
|
||||
|
||||
abbrev M := StateM State
|
||||
@@ -120,7 +120,7 @@ partial def normLevel (u : Level) : M Level := do
|
||||
return u
|
||||
else
|
||||
let s ← get
|
||||
match (← get).lmap.find? mvarId with
|
||||
match (← get).lmap[mvarId]? with
|
||||
| some u' => pure u'
|
||||
| none =>
|
||||
let u' := mkLevelParam <| Name.mkNum `_tc s.nextIdx
|
||||
@@ -145,7 +145,7 @@ partial def normExpr (e : Expr) : M Expr := do
|
||||
return e
|
||||
else
|
||||
let s ← get
|
||||
match s.emap.find? mvarId with
|
||||
match s.emap[mvarId]? with
|
||||
| some e' => pure e'
|
||||
| none => do
|
||||
let e' := mkFVar { name := Name.mkNum `_tc s.nextIdx }
|
||||
@@ -186,7 +186,7 @@ structure State where
|
||||
result? : Option AbstractMVarsResult := none
|
||||
generatorStack : Array GeneratorNode := #[]
|
||||
resumeStack : Array (ConsumerNode × Answer) := #[]
|
||||
tableEntries : HashMap Expr TableEntry := {}
|
||||
tableEntries : Std.HashMap Expr TableEntry := {}
|
||||
|
||||
abbrev SynthM := ReaderT Context $ StateRefT State MetaM
|
||||
|
||||
@@ -265,7 +265,7 @@ def newSubgoal (mctx : MetavarContext) (key : Expr) (mvar : Expr) (waiter : Wait
|
||||
pure ((), m!"new goal {key}")
|
||||
|
||||
def findEntry? (key : Expr) : SynthM (Option TableEntry) := do
|
||||
return (← get).tableEntries.find? key
|
||||
return (← get).tableEntries[key]?
|
||||
|
||||
def getEntry (key : Expr) : SynthM TableEntry := do
|
||||
match (← findEntry? key) with
|
||||
@@ -553,7 +553,7 @@ def generate : SynthM Unit := do
|
||||
/- See comment at `typeHasMVars` -/
|
||||
if backward.synthInstance.canonInstances.get (← getOptions) then
|
||||
unless gNode.typeHasMVars do
|
||||
if let some entry := (← get).tableEntries.find? key then
|
||||
if let some entry := (← get).tableEntries[key]? then
|
||||
if entry.answers.any fun answer => answer.result.numMVars == 0 then
|
||||
/-
|
||||
We already have an answer that:
|
||||
|
||||
@@ -66,9 +66,9 @@ inductive PreExpr
|
||||
def toACExpr (op l r : Expr) : MetaM (Array Expr × ACExpr) := do
|
||||
let (preExpr, vars) ←
|
||||
toPreExpr (mkApp2 op l r)
|
||||
|>.run HashSet.empty
|
||||
|>.run Std.HashSet.empty
|
||||
let vars := vars.toArray.insertionSort Expr.lt
|
||||
let varMap := vars.foldl (fun xs x => xs.insert x xs.size) HashMap.empty |>.find!
|
||||
let varMap := vars.foldl (fun xs x => xs.insert x xs.size) Std.HashMap.empty |>.get!
|
||||
|
||||
return (vars, toACExpr varMap preExpr)
|
||||
where
|
||||
|
||||
@@ -290,7 +290,7 @@ structure RewriteResultConfig where
|
||||
side : SideConditions := .solveByElim
|
||||
mctx : MetavarContext
|
||||
|
||||
def takeListAux (cfg : RewriteResultConfig) (seen : HashMap String Unit) (acc : Array RewriteResult)
|
||||
def takeListAux (cfg : RewriteResultConfig) (seen : Std.HashMap String Unit) (acc : Array RewriteResult)
|
||||
(xs : List ((Expr ⊕ Name) × Bool × Nat)) : MetaM (Array RewriteResult) := do
|
||||
let mut seen := seen
|
||||
let mut acc := acc
|
||||
|
||||
@@ -32,7 +32,7 @@ def mkSimpDiagSummary (counters : PHashMap Origin Nat) (usedCounters? : Option (
|
||||
if let some c := usedCounters.find? thmId then pure s!", succeeded: {c}" else pure s!" {crossEmoji}" -- not used
|
||||
else
|
||||
pure ""
|
||||
data := data.push m!"{if data.isEmpty then " " else "\n"}{key} ↦ {counter}{usedMsg}"
|
||||
data := data.push <| .trace { cls := `simp } m!"{key} ↦ {counter}{usedMsg}" #[]
|
||||
return { data, max := entries[0]!.2 }
|
||||
|
||||
private def mkTheoremsWithBadKeySummary (thms : PArray SimpTheorem) : MetaM DiagSummary := do
|
||||
@@ -41,7 +41,7 @@ private def mkTheoremsWithBadKeySummary (thms : PArray SimpTheorem) : MetaM Diag
|
||||
else
|
||||
let mut data := #[]
|
||||
for thm in thms do
|
||||
data := data.push m!"{if data.isEmpty then " " else "\n"}{← originToKey thm.origin}, key: {← DiscrTree.keysAsPattern thm.keys}"
|
||||
data := data.push <| .trace { cls := `simp } m!"{← originToKey thm.origin}, key: {← DiscrTree.keysAsPattern thm.keys}" #[]
|
||||
pure ()
|
||||
return { data }
|
||||
|
||||
@@ -49,7 +49,7 @@ def reportDiag (diag : Simp.Diagnostics) : MetaM Unit := do
|
||||
if (← isDiagnosticsEnabled) then
|
||||
let used ← mkSimpDiagSummary diag.usedThmCounter
|
||||
let tried ← mkSimpDiagSummary diag.triedThmCounter diag.usedThmCounter
|
||||
let congr ← mkDiagSummary diag.congrThmCounter
|
||||
let congr ← mkDiagSummary `simp diag.congrThmCounter
|
||||
let thmsWithBadKeys ← mkTheoremsWithBadKeySummary diag.thmsWithBadKeys
|
||||
unless used.isEmpty && tried.isEmpty && congr.isEmpty && thmsWithBadKeys.isEmpty do
|
||||
let m := MessageData.nil
|
||||
|
||||
@@ -420,12 +420,12 @@ def mkSimpExt (name : Name := by exact decl_name%) : IO SimpExtension :=
|
||||
| .toUnfoldThms n thms => d.registerDeclToUnfoldThms n thms
|
||||
}
|
||||
|
||||
abbrev SimpExtensionMap := HashMap Name SimpExtension
|
||||
abbrev SimpExtensionMap := Std.HashMap Name SimpExtension
|
||||
|
||||
builtin_initialize simpExtensionMapRef : IO.Ref SimpExtensionMap ← IO.mkRef {}
|
||||
|
||||
def getSimpExtension? (attrName : Name) : IO (Option SimpExtension) :=
|
||||
return (← simpExtensionMapRef.get).find? attrName
|
||||
return (← simpExtensionMapRef.get)[attrName]?
|
||||
|
||||
/-- Auxiliary method for adding a global declaration to a `SimpTheorems` datastructure. -/
|
||||
def SimpTheorems.addConst (s : SimpTheorems) (declName : Name) (post := true) (inv := false) (prio : Nat := eval_prio default) : MetaM SimpTheorems := do
|
||||
|
||||
@@ -19,8 +19,8 @@ It contains:
|
||||
- The actual procedure associated with a name.
|
||||
-/
|
||||
structure BuiltinSimprocs where
|
||||
keys : HashMap Name (Array SimpTheoremKey) := {}
|
||||
procs : HashMap Name (Sum Simproc DSimproc) := {}
|
||||
keys : Std.HashMap Name (Array SimpTheoremKey) := {}
|
||||
procs : Std.HashMap Name (Sum Simproc DSimproc) := {}
|
||||
deriving Inhabited
|
||||
|
||||
/--
|
||||
@@ -37,7 +37,7 @@ structure SimprocDecl where
|
||||
deriving Inhabited
|
||||
|
||||
structure SimprocDeclExtState where
|
||||
builtin : HashMap Name (Array SimpTheoremKey)
|
||||
builtin : Std.HashMap Name (Array SimpTheoremKey)
|
||||
newEntries : PHashMap Name (Array SimpTheoremKey) := {}
|
||||
deriving Inhabited
|
||||
|
||||
@@ -65,7 +65,7 @@ def getSimprocDeclKeys? (declName : Name) : CoreM (Option (Array SimpTheoremKey)
|
||||
if let some keys := keys? then
|
||||
return some keys
|
||||
else
|
||||
return (simprocDeclExt.getState env).builtin.find? declName
|
||||
return (simprocDeclExt.getState env).builtin[declName]?
|
||||
|
||||
def isBuiltinSimproc (declName : Name) : CoreM Bool := do
|
||||
let s := simprocDeclExt.getState (← getEnv)
|
||||
@@ -160,7 +160,7 @@ def Simprocs.addCore (s : Simprocs) (keys : Array SimpTheoremKey) (declName : Na
|
||||
Implements attributes `builtin_simproc` and `builtin_sevalproc`.
|
||||
-/
|
||||
def addSimprocBuiltinAttrCore (ref : IO.Ref Simprocs) (declName : Name) (post : Bool) (proc : Sum Simproc DSimproc) : IO Unit := do
|
||||
let some keys := (← builtinSimprocDeclsRef.get).keys.find? declName |
|
||||
let some keys := (← builtinSimprocDeclsRef.get).keys[declName]? |
|
||||
throw (IO.userError "invalid [builtin_simproc] attribute, '{declName}' is not a builtin simproc")
|
||||
ref.modify fun s => s.addCore keys declName post proc
|
||||
|
||||
@@ -176,7 +176,7 @@ def Simprocs.add (s : Simprocs) (declName : Name) (post : Bool) : CoreM Simprocs
|
||||
getSimprocFromDecl declName
|
||||
catch e =>
|
||||
if (← isBuiltinSimproc declName) then
|
||||
let some proc := (← builtinSimprocDeclsRef.get).procs.find? declName
|
||||
let some proc := (← builtinSimprocDeclsRef.get).procs[declName]?
|
||||
| throwError "invalid [simproc] attribute, '{declName}' is not a simproc"
|
||||
pure proc
|
||||
else
|
||||
@@ -384,7 +384,7 @@ def mkSimprocAttr (attrName : Name) (attrDescr : String) (ext : SimprocExtension
|
||||
erase := eraseSimprocAttr ext
|
||||
}
|
||||
|
||||
abbrev SimprocExtensionMap := HashMap Name SimprocExtension
|
||||
abbrev SimprocExtensionMap := Std.HashMap Name SimprocExtension
|
||||
|
||||
builtin_initialize simprocExtensionMapRef : IO.Ref SimprocExtensionMap ← IO.mkRef {}
|
||||
|
||||
@@ -438,7 +438,7 @@ def getSEvalSimprocs : CoreM Simprocs :=
|
||||
return simprocSEvalExtension.getState (← getEnv)
|
||||
|
||||
def getSimprocExtensionCore? (attrName : Name) : IO (Option SimprocExtension) :=
|
||||
return (← simprocExtensionMapRef.get).find? attrName
|
||||
return (← simprocExtensionMapRef.get)[attrName]?
|
||||
|
||||
def simpAttrNameToSimprocAttrName (attrName : Name) : Name :=
|
||||
if attrName == `simp then `simprocAttr
|
||||
|
||||
@@ -512,7 +512,7 @@ def mkCongrSimp? (f : Expr) : SimpM (Option CongrTheorem) := do
|
||||
if kinds.all fun k => match k with | CongrArgKind.fixed => true | CongrArgKind.eq => true | _ => false then
|
||||
/- See remark above. -/
|
||||
return none
|
||||
match (← get).congrCache.find? f with
|
||||
match (← get).congrCache[f]? with
|
||||
| some thm? => return thm?
|
||||
| none =>
|
||||
let thm? ← mkCongrSimpCore? f info kinds
|
||||
|
||||
@@ -336,6 +336,8 @@ structure MetavarContext where
|
||||
For more information about delayed abstraction, see the docstring for `DelayedMetavarAssignment`. -/
|
||||
dAssignment : PersistentHashMap MVarId DelayedMetavarAssignment := {}
|
||||
|
||||
instance : Inhabited MetavarContext := ⟨{}⟩
|
||||
|
||||
/-- A monad with a stateful metavariable context, defining `getMCtx` and `modifyMCtx`. -/
|
||||
class MonadMCtx (m : Type → Type) where
|
||||
getMCtx : m MetavarContext
|
||||
@@ -358,15 +360,27 @@ abbrev setMCtx [MonadMCtx m] (mctx : MetavarContext) : m Unit :=
|
||||
abbrev getLevelMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : LMVarId) : m (Option Level) :=
|
||||
return (← getMCtx).lAssignment.find? mvarId
|
||||
|
||||
@[export lean_get_lmvar_assignment]
|
||||
def getLevelMVarAssignmentExp (m : MetavarContext) (mvarId : LMVarId) : Option Level :=
|
||||
m.lAssignment.find? mvarId
|
||||
|
||||
def MetavarContext.getExprAssignmentCore? (m : MetavarContext) (mvarId : MVarId) : Option Expr :=
|
||||
m.eAssignment.find? mvarId
|
||||
|
||||
@[export lean_get_mvar_assignment]
|
||||
def MetavarContext.getExprAssignmentExp (m : MetavarContext) (mvarId : MVarId) : Option Expr :=
|
||||
m.eAssignment.find? mvarId
|
||||
|
||||
def getExprMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : MVarId) : m (Option Expr) :=
|
||||
return (← getMCtx).getExprAssignmentCore? mvarId
|
||||
|
||||
def MetavarContext.getDelayedMVarAssignmentCore? (mctx : MetavarContext) (mvarId : MVarId) : Option DelayedMetavarAssignment :=
|
||||
mctx.dAssignment.find? mvarId
|
||||
|
||||
@[export lean_get_delayed_mvar_assignment]
|
||||
def MetavarContext.getDelayedMVarAssignmentExp (mctx : MetavarContext) (mvarId : MVarId) : Option DelayedMetavarAssignment :=
|
||||
mctx.dAssignment.find? mvarId
|
||||
|
||||
def getDelayedMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : MVarId) : m (Option DelayedMetavarAssignment) :=
|
||||
return (← getMCtx).getDelayedMVarAssignmentCore? mvarId
|
||||
|
||||
@@ -478,6 +492,10 @@ def hasAssignableMVar [Monad m] [MonadMCtx m] : Expr → m Bool
|
||||
def assignLevelMVar [MonadMCtx m] (mvarId : LMVarId) (val : Level) : m Unit :=
|
||||
modifyMCtx fun m => { m with lAssignment := m.lAssignment.insert mvarId val }
|
||||
|
||||
@[export lean_assign_lmvar]
|
||||
def assignLevelMVarExp (m : MetavarContext) (mvarId : LMVarId) (val : Level) : MetavarContext :=
|
||||
{ m with lAssignment := m.lAssignment.insert mvarId val }
|
||||
|
||||
/--
|
||||
Add `mvarId := x` to the metavariable assignment.
|
||||
This method does not check whether `mvarId` is already assigned, nor it checks whether
|
||||
@@ -487,6 +505,10 @@ This is a low-level API, and it is safer to use `isDefEq (mkMVar mvarId) x`.
|
||||
def _root_.Lean.MVarId.assign [MonadMCtx m] (mvarId : MVarId) (val : Expr) : m Unit :=
|
||||
modifyMCtx fun m => { m with eAssignment := m.eAssignment.insert mvarId val }
|
||||
|
||||
@[export lean_assign_mvar]
|
||||
def assignExp (m : MetavarContext) (mvarId : MVarId) (val : Expr) : MetavarContext :=
|
||||
{ m with eAssignment := m.eAssignment.insert mvarId val }
|
||||
|
||||
/--
|
||||
Add a delayed assignment for the given metavariable. You must make sure that
|
||||
the metavariable is not already assigned or delayed-assigned.
|
||||
@@ -516,95 +538,22 @@ To avoid this term eta-expanded term, we apply beta-reduction when instantiating
|
||||
This operation is performed at `instantiateExprMVars`, `elimMVarDeps`, and `levelMVarToParam`.
|
||||
-/
|
||||
|
||||
partial def instantiateLevelMVars [Monad m] [MonadMCtx m] : Level → m Level
|
||||
| lvl@(Level.succ lvl₁) => return Level.updateSucc! lvl (← instantiateLevelMVars lvl₁)
|
||||
| lvl@(Level.max lvl₁ lvl₂) => return Level.updateMax! lvl (← instantiateLevelMVars lvl₁) (← instantiateLevelMVars lvl₂)
|
||||
| lvl@(Level.imax lvl₁ lvl₂) => return Level.updateIMax! lvl (← instantiateLevelMVars lvl₁) (← instantiateLevelMVars lvl₂)
|
||||
| lvl@(Level.mvar mvarId) => do
|
||||
match (← getLevelMVarAssignment? mvarId) with
|
||||
| some newLvl =>
|
||||
if !newLvl.hasMVar then pure newLvl
|
||||
else do
|
||||
let newLvl' ← instantiateLevelMVars newLvl
|
||||
assignLevelMVar mvarId newLvl'
|
||||
pure newLvl'
|
||||
| none => pure lvl
|
||||
| lvl => pure lvl
|
||||
@[extern "lean_instantiate_level_mvars"]
|
||||
opaque instantiateLevelMVarsImp (mctx : MetavarContext) (l : Level) : MetavarContext × Level
|
||||
|
||||
partial def instantiateLevelMVars [Monad m] [MonadMCtx m] (l : Level) : m Level := do
|
||||
let (mctx, lNew) := instantiateLevelMVarsImp (← getMCtx) l
|
||||
setMCtx mctx
|
||||
return lNew
|
||||
|
||||
@[extern "lean_instantiate_expr_mvars"]
|
||||
opaque instantiateExprMVarsImp (mctx : MetavarContext) (e : Expr) : MetavarContext × Expr
|
||||
|
||||
/-- instantiateExprMVars main function -/
|
||||
partial def instantiateExprMVars [Monad m] [MonadMCtx m] [STWorld ω m] [MonadLiftT (ST ω) m] (e : Expr) : MonadCacheT ExprStructEq Expr m Expr :=
|
||||
if !e.hasMVar then
|
||||
pure e
|
||||
else checkCache { val := e : ExprStructEq } fun _ => do match e with
|
||||
| .proj _ _ s => return e.updateProj! (← instantiateExprMVars s)
|
||||
| .forallE _ d b _ => return e.updateForallE! (← instantiateExprMVars d) (← instantiateExprMVars b)
|
||||
| .lam _ d b _ => return e.updateLambdaE! (← instantiateExprMVars d) (← instantiateExprMVars b)
|
||||
| .letE _ t v b _ => return e.updateLet! (← instantiateExprMVars t) (← instantiateExprMVars v) (← instantiateExprMVars b)
|
||||
| .const _ lvls => return e.updateConst! (← lvls.mapM instantiateLevelMVars)
|
||||
| .sort lvl => return e.updateSort! (← instantiateLevelMVars lvl)
|
||||
| .mdata _ b => return e.updateMData! (← instantiateExprMVars b)
|
||||
| .app .. => e.withApp fun f args => do
|
||||
let instArgs (f : Expr) : MonadCacheT ExprStructEq Expr m Expr := do
|
||||
let args ← args.mapM instantiateExprMVars
|
||||
pure (mkAppN f args)
|
||||
let instApp : MonadCacheT ExprStructEq Expr m Expr := do
|
||||
let wasMVar := f.isMVar
|
||||
let f ← instantiateExprMVars f
|
||||
if wasMVar && f.isLambda then
|
||||
/- Some of the arguments in `args` are irrelevant after we beta
|
||||
reduce. Also, it may be a bug to not instantiate them, since they
|
||||
may depend on free variables that are not in the context (see
|
||||
issue #4375). So we pass `useZeta := true` to ensure that they are
|
||||
instantiated. -/
|
||||
instantiateExprMVars (f.betaRev args.reverse (useZeta := true))
|
||||
else
|
||||
instArgs f
|
||||
match f with
|
||||
| .mvar mvarId =>
|
||||
match (← getDelayedMVarAssignment? mvarId) with
|
||||
| none => instApp
|
||||
| some { fvars, mvarIdPending } =>
|
||||
/-
|
||||
Apply "delayed substitution" (i.e., delayed assignment + application).
|
||||
That is, `f` is some metavariable `?m`, that is delayed assigned to `val`.
|
||||
If after instantiating `val`, we obtain `newVal`, and `newVal` does not contain
|
||||
metavariables, we replace the free variables `fvars` in `newVal` with the first
|
||||
`fvars.size` elements of `args`. -/
|
||||
if fvars.size > args.size then
|
||||
/- We don't have sufficient arguments for instantiating the free variables `fvars`.
|
||||
This can only happen if a tactic or elaboration function is not implemented correctly.
|
||||
We decided to not use `panic!` here and report it as an error in the frontend
|
||||
when we are checking for unassigned metavariables in an elaborated term. -/
|
||||
instArgs f
|
||||
else
|
||||
let newVal ← instantiateExprMVars (mkMVar mvarIdPending)
|
||||
if newVal.hasExprMVar then
|
||||
instArgs f
|
||||
else do
|
||||
let args ← args.mapM instantiateExprMVars
|
||||
/-
|
||||
Example: suppose we have
|
||||
`?m t1 t2 t3`
|
||||
That is, `f := ?m` and `args := #[t1, t2, t3]`
|
||||
Moreover, `?m` is delayed assigned
|
||||
`?m #[x, y] := f x y`
|
||||
where, `fvars := #[x, y]` and `newVal := f x y`.
|
||||
After abstracting `newVal`, we have `f (Expr.bvar 0) (Expr.bvar 1)`.
|
||||
After `instantiaterRevRange 0 2 args`, we have `f t1 t2`.
|
||||
After `mkAppRange 2 3`, we have `f t1 t2 t3` -/
|
||||
let newVal := newVal.abstract fvars
|
||||
let result := newVal.instantiateRevRange 0 fvars.size args
|
||||
let result := mkAppRange result fvars.size args.size args
|
||||
pure result
|
||||
| _ => instApp
|
||||
| e@(.mvar mvarId) => checkCache { val := e : ExprStructEq } fun _ => do
|
||||
match (← getExprMVarAssignment? mvarId) with
|
||||
| some newE => do
|
||||
let newE' ← instantiateExprMVars newE
|
||||
mvarId.assign newE'
|
||||
pure newE'
|
||||
| none => pure e
|
||||
| e => pure e
|
||||
def instantiateExprMVars [Monad m] [MonadMCtx m] (e : Expr) : m Expr := do
|
||||
let (mctx, eNew) := instantiateExprMVarsImp (← getMCtx) e
|
||||
setMCtx mctx
|
||||
return eNew
|
||||
|
||||
instance : MonadMCtx (StateRefT MetavarContext (ST ω)) where
|
||||
getMCtx := get
|
||||
@@ -792,8 +741,6 @@ def localDeclDependsOnPred [Monad m] [MonadMCtx m] (localDecl : LocalDecl) (pf :
|
||||
|
||||
namespace MetavarContext
|
||||
|
||||
instance : Inhabited MetavarContext := ⟨{}⟩
|
||||
|
||||
@[export lean_mk_metavar_ctx]
|
||||
def mkMetavarContext : Unit → MetavarContext := fun _ => {}
|
||||
|
||||
@@ -956,7 +903,7 @@ structure State where
|
||||
mctx : MetavarContext
|
||||
nextMacroScope : MacroScope
|
||||
ngen : NameGenerator
|
||||
cache : HashMap ExprStructEq Expr := {}
|
||||
cache : Std.HashMap ExprStructEq Expr := {}
|
||||
|
||||
structure Context where
|
||||
mainModule : Name
|
||||
@@ -1372,7 +1319,7 @@ structure State where
|
||||
mctx : MetavarContext
|
||||
paramNames : Array Name := #[]
|
||||
nextParamIdx : Nat
|
||||
cache : HashMap ExprStructEq Expr := {}
|
||||
cache : Std.HashMap ExprStructEq Expr := {}
|
||||
|
||||
abbrev M := ReaderT Context <| StateM State
|
||||
|
||||
@@ -1381,7 +1328,7 @@ instance : MonadMCtx M where
|
||||
modifyMCtx f := modify fun s => { s with mctx := f s.mctx }
|
||||
|
||||
instance : MonadCache ExprStructEq Expr M where
|
||||
findCached? e := return (← get).cache.find? e
|
||||
findCached? e := return (← get).cache[e]?
|
||||
cache e v := modify fun s => { s with cache := s.cache.insert e v }
|
||||
|
||||
partial def mkParamName : M Name := do
|
||||
|
||||
@@ -131,7 +131,7 @@ structure ParserCacheEntry where
|
||||
|
||||
structure ParserCache where
|
||||
tokenCache : TokenCacheEntry
|
||||
parserCache : HashMap ParserCacheKey ParserCacheEntry
|
||||
parserCache : Std.HashMap ParserCacheKey ParserCacheEntry
|
||||
|
||||
def initCacheForInput (input : String) : ParserCache where
|
||||
tokenCache := { startPos := input.endPos + ' ' /- make sure it is not a valid position -/ }
|
||||
@@ -418,7 +418,7 @@ place if there was an error.
|
||||
-/
|
||||
def withCacheFn (parserName : Name) (p : ParserFn) : ParserFn := fun c s => Id.run do
|
||||
let key := ⟨c.toCacheableParserContext, parserName, s.pos⟩
|
||||
if let some r := s.cache.parserCache.find? key then
|
||||
if let some r := s.cache.parserCache[key]? then
|
||||
-- TODO: turn this into a proper trace once we have these in the parser
|
||||
--dbg_trace "parser cache hit: {parserName}:{s.pos} -> {r.stx}"
|
||||
return ⟨s.stxStack.push r.stx, r.lhsPrec, r.newPos, s.cache, r.errorMsg, s.recoveredErrors⟩
|
||||
|
||||
@@ -198,10 +198,10 @@ def isHBinOp (e : Expr) : Bool := Id.run do
|
||||
def replaceLPsWithVars (e : Expr) : MetaM Expr := do
|
||||
if !e.hasLevelParam then return e
|
||||
let lps := collectLevelParams {} e |>.params
|
||||
let mut replaceMap : HashMap Name Level := {}
|
||||
let mut replaceMap : Std.HashMap Name Level := {}
|
||||
for lp in lps do replaceMap := replaceMap.insert lp (← mkFreshLevelMVar)
|
||||
return e.replaceLevel fun
|
||||
| Level.param n .. => replaceMap.find! n
|
||||
| Level.param n .. => replaceMap[n]!
|
||||
| l => if !l.hasParam then some l else none
|
||||
|
||||
def isDefEqAssigning (t s : Expr) : MetaM Bool := do
|
||||
|
||||
@@ -29,7 +29,7 @@ namespace Lean.Environment
|
||||
namespace Replay
|
||||
|
||||
structure Context where
|
||||
newConstants : HashMap Name ConstantInfo
|
||||
newConstants : Std.HashMap Name ConstantInfo
|
||||
|
||||
structure State where
|
||||
env : Environment
|
||||
@@ -73,7 +73,7 @@ and add it to the environment.
|
||||
-/
|
||||
partial def replayConstant (name : Name) : M Unit := do
|
||||
if ← isTodo name then
|
||||
let some ci := (← read).newConstants.find? name | unreachable!
|
||||
let some ci := (← read).newConstants[name]? | unreachable!
|
||||
replayConstants ci.getUsedConstantsAsSet
|
||||
-- Check that this name is still pending: a mutual block may have taken care of it.
|
||||
if (← get).pending.contains name then
|
||||
@@ -89,13 +89,13 @@ partial def replayConstant (name : Name) : M Unit := do
|
||||
| .inductInfo info =>
|
||||
let lparams := info.levelParams
|
||||
let nparams := info.numParams
|
||||
let all ← info.all.mapM fun n => do pure <| ((← read).newConstants.find! n)
|
||||
let all ← info.all.mapM fun n => do pure <| ((← read).newConstants[n]!)
|
||||
for o in all do
|
||||
modify fun s =>
|
||||
{ s with remaining := s.remaining.erase o.name, pending := s.pending.erase o.name }
|
||||
let ctorInfo ← all.mapM fun ci => do
|
||||
pure (ci, ← ci.inductiveVal!.ctors.mapM fun n => do
|
||||
pure ((← read).newConstants.find! n))
|
||||
pure ((← read).newConstants[n]!))
|
||||
-- Make sure we are really finished with the constructors.
|
||||
for (_, ctors) in ctorInfo do
|
||||
for ctor in ctors do
|
||||
@@ -129,7 +129,7 @@ when we replayed the inductives.
|
||||
-/
|
||||
def checkPostponedConstructors : M Unit := do
|
||||
for ctor in (← get).postponedConstructors do
|
||||
match (← get).env.constants.find? ctor, (← read).newConstants.find? ctor with
|
||||
match (← get).env.constants.find? ctor, (← read).newConstants[ctor]? with
|
||||
| some (.ctorInfo info), some (.ctorInfo info') =>
|
||||
if ! (info == info') then throw <| IO.userError s!"Invalid constructor {ctor}"
|
||||
| _, _ => throw <| IO.userError s!"No such constructor {ctor}"
|
||||
@@ -140,7 +140,7 @@ when we replayed the inductives.
|
||||
-/
|
||||
def checkPostponedRecursors : M Unit := do
|
||||
for ctor in (← get).postponedRecursors do
|
||||
match (← get).env.constants.find? ctor, (← read).newConstants.find? ctor with
|
||||
match (← get).env.constants.find? ctor, (← read).newConstants[ctor]? with
|
||||
| some (.recInfo info), some (.recInfo info') =>
|
||||
if ! (info == info') then throw <| IO.userError s!"Invalid recursor {ctor}"
|
||||
| _, _ => throw <| IO.userError s!"No such recursor {ctor}"
|
||||
@@ -155,7 +155,7 @@ open Replay
|
||||
Throws a `IO.userError` if the kernel rejects a constant,
|
||||
or if there are malformed recursors or constructors for inductive types.
|
||||
-/
|
||||
def replay (newConstants : HashMap Name ConstantInfo) (env : Environment) : IO Environment := do
|
||||
def replay (newConstants : Std.HashMap Name ConstantInfo) (env : Environment) : IO Environment := do
|
||||
let mut remaining : NameSet := ∅
|
||||
for (n, ci) in newConstants.toList do
|
||||
-- We skip unsafe constants, and also partial constants.
|
||||
|
||||
@@ -81,7 +81,7 @@ open Elab
|
||||
open Meta
|
||||
open FuzzyMatching
|
||||
|
||||
abbrev EligibleHeaderDecls := HashMap Name ConstantInfo
|
||||
abbrev EligibleHeaderDecls := Std.HashMap Name ConstantInfo
|
||||
|
||||
/-- Cached header declarations for which `allowCompletion headerEnv decl` is true. -/
|
||||
builtin_initialize eligibleHeaderDeclsRef : IO.Ref (Option EligibleHeaderDecls) ←
|
||||
|
||||
@@ -316,7 +316,7 @@ partial def handleDocumentHighlight (p : DocumentHighlightParams)
|
||||
let refs : Lsp.ModuleRefs ← findModuleRefs text trees |>.toLspModuleRefs
|
||||
let mut ranges := #[]
|
||||
for ident in refs.findAt p.position do
|
||||
if let some info := refs.find? ident then
|
||||
if let some info := refs.get? ident then
|
||||
if let some ⟨definitionRange, _⟩ := info.definition? then
|
||||
ranges := ranges.push definitionRange
|
||||
ranges := ranges.append <| info.usages.map (·.range)
|
||||
|
||||
@@ -93,20 +93,20 @@ def toLspRefInfo (i : RefInfo) : BaseIO Lsp.RefInfo := do
|
||||
end RefInfo
|
||||
|
||||
/-- All references from within a module for all identifiers used in a single module. -/
|
||||
def ModuleRefs := HashMap RefIdent RefInfo
|
||||
def ModuleRefs := Std.HashMap RefIdent RefInfo
|
||||
|
||||
namespace ModuleRefs
|
||||
|
||||
/-- Adds `ref` to the `RefInfo` corresponding to `ref.ident` in `self`. See `RefInfo.addRef`. -/
|
||||
def addRef (self : ModuleRefs) (ref : Reference) : ModuleRefs :=
|
||||
let refInfo := self.findD ref.ident RefInfo.empty
|
||||
let refInfo := self.getD ref.ident RefInfo.empty
|
||||
self.insert ref.ident (refInfo.addRef ref)
|
||||
|
||||
/-- Converts `refs` to a JSON-serializable `Lsp.ModuleRefs`. -/
|
||||
def toLspModuleRefs (refs : ModuleRefs) : BaseIO Lsp.ModuleRefs := do
|
||||
let refs ← refs.toList.mapM fun (k, v) => do
|
||||
return (k, ← v.toLspRefInfo)
|
||||
return HashMap.ofList refs
|
||||
return Std.HashMap.ofList refs
|
||||
|
||||
end ModuleRefs
|
||||
|
||||
@@ -261,7 +261,7 @@ all identifiers that are being collapsed into one.
|
||||
-/
|
||||
partial def combineIdents (trees : Array InfoTree) (refs : Array Reference) : Array Reference := Id.run do
|
||||
-- Deduplicate definitions based on their exact range
|
||||
let mut posMap : HashMap Lsp.Range RefIdent := HashMap.empty
|
||||
let mut posMap : Std.HashMap Lsp.Range RefIdent := Std.HashMap.empty
|
||||
for ref in refs do
|
||||
if let { ident, range, isBinder := true, .. } := ref then
|
||||
posMap := posMap.insert range ident
|
||||
@@ -277,17 +277,17 @@ partial def combineIdents (trees : Array InfoTree) (refs : Array Reference) : Ar
|
||||
refs' := refs'.push ref
|
||||
refs'
|
||||
where
|
||||
useConstRepresentatives (idMap : HashMap RefIdent RefIdent)
|
||||
: HashMap RefIdent RefIdent := Id.run do
|
||||
useConstRepresentatives (idMap : Std.HashMap RefIdent RefIdent)
|
||||
: Std.HashMap RefIdent RefIdent := Id.run do
|
||||
let insertIntoClass classesById id :=
|
||||
let representative := findCanonicalRepresentative idMap id
|
||||
let «class» := classesById.findD representative ∅
|
||||
let «class» := classesById.getD representative ∅
|
||||
let classesById := classesById.erase representative -- make `«class»` referentially unique
|
||||
let «class» := «class».insert id
|
||||
classesById.insert representative «class»
|
||||
|
||||
-- collect equivalence classes
|
||||
let mut classesById : HashMap RefIdent (HashSet RefIdent) := ∅
|
||||
let mut classesById : Std.HashMap RefIdent (Std.HashSet RefIdent) := ∅
|
||||
for ⟨id, baseId⟩ in idMap.toArray do
|
||||
classesById := insertIntoClass classesById id
|
||||
classesById := insertIntoClass classesById baseId
|
||||
@@ -310,17 +310,17 @@ where
|
||||
r := r.insert id bestRepresentative
|
||||
return r
|
||||
|
||||
findCanonicalRepresentative (idMap : HashMap RefIdent RefIdent) (id : RefIdent) : RefIdent := Id.run do
|
||||
findCanonicalRepresentative (idMap : Std.HashMap RefIdent RefIdent) (id : RefIdent) : RefIdent := Id.run do
|
||||
let mut canonicalRepresentative := id
|
||||
while idMap.contains canonicalRepresentative do
|
||||
canonicalRepresentative := idMap.find! canonicalRepresentative
|
||||
canonicalRepresentative := idMap[canonicalRepresentative]!
|
||||
return canonicalRepresentative
|
||||
|
||||
buildIdMap posMap := Id.run <| StateT.run' (s := HashMap.empty) do
|
||||
buildIdMap posMap := Id.run <| StateT.run' (s := Std.HashMap.empty) do
|
||||
-- map fvar defs to overlapping fvar defs/uses
|
||||
for ref in refs do
|
||||
let baseId := ref.ident
|
||||
if let some id := posMap.find? ref.range then
|
||||
if let some id := posMap[ref.range]? then
|
||||
insertIdMap id baseId
|
||||
|
||||
-- apply `FVarAliasInfo`
|
||||
@@ -346,11 +346,11 @@ are added to the `aliases` of the representative of the group.
|
||||
Yields to separate groups for declaration and usages if `allowSimultaneousBinderUse` is set.
|
||||
-/
|
||||
def dedupReferences (refs : Array Reference) (allowSimultaneousBinderUse := false) : Array Reference := Id.run do
|
||||
let mut refsByIdAndRange : HashMap (RefIdent × Option Bool × Lsp.Range) Reference := HashMap.empty
|
||||
let mut refsByIdAndRange : Std.HashMap (RefIdent × Option Bool × Lsp.Range) Reference := Std.HashMap.empty
|
||||
for ref in refs do
|
||||
let isBinder := if allowSimultaneousBinderUse then some ref.isBinder else none
|
||||
let key := (ref.ident, isBinder, ref.range)
|
||||
refsByIdAndRange := match refsByIdAndRange[key] with
|
||||
refsByIdAndRange := match refsByIdAndRange[key]? with
|
||||
| some ref' => refsByIdAndRange.insert key { ref' with aliases := ref'.aliases ++ ref.aliases }
|
||||
| none => refsByIdAndRange.insert key ref
|
||||
|
||||
@@ -371,21 +371,21 @@ def findModuleRefs (text : FileMap) (trees : Array InfoTree) (localVars : Bool :
|
||||
refs := refs.filter fun
|
||||
| { ident := RefIdent.fvar .., .. } => false
|
||||
| _ => true
|
||||
refs.foldl (init := HashMap.empty) fun m ref => m.addRef ref
|
||||
refs.foldl (init := Std.HashMap.empty) fun m ref => m.addRef ref
|
||||
|
||||
/-! # Collecting and maintaining reference info from different sources -/
|
||||
|
||||
/-- References from ilean files and current ilean information from file workers. -/
|
||||
structure References where
|
||||
/-- References loaded from ilean files -/
|
||||
ileans : HashMap Name (System.FilePath × Lsp.ModuleRefs)
|
||||
ileans : Std.HashMap Name (System.FilePath × Lsp.ModuleRefs)
|
||||
/-- References from workers, overriding the corresponding ilean files -/
|
||||
workers : HashMap Name (Nat × Lsp.ModuleRefs)
|
||||
workers : Std.HashMap Name (Nat × Lsp.ModuleRefs)
|
||||
|
||||
namespace References
|
||||
|
||||
/-- No ilean files, no information from workers. -/
|
||||
def empty : References := { ileans := HashMap.empty, workers := HashMap.empty }
|
||||
def empty : References := { ileans := Std.HashMap.empty, workers := Std.HashMap.empty }
|
||||
|
||||
/-- Adds the contents of an ilean file `ilean` at `path` to `self`. -/
|
||||
def addIlean (self : References) (path : System.FilePath) (ilean : Ilean) : References :=
|
||||
@@ -404,13 +404,13 @@ Replaces the current references with `refs` if `version` is newer than the curre
|
||||
in `refs` and otherwise merges the reference data if `version` is equal to the current version.
|
||||
-/
|
||||
def updateWorkerRefs (self : References) (name : Name) (version : Nat) (refs : Lsp.ModuleRefs) : References := Id.run do
|
||||
if let some (currVersion, _) := self.workers.find? name then
|
||||
if let some (currVersion, _) := self.workers[name]? then
|
||||
if version > currVersion then
|
||||
return { self with workers := self.workers.insert name (version, refs) }
|
||||
if version == currVersion then
|
||||
let current := self.workers.findD name (version, HashMap.empty)
|
||||
let current := self.workers.getD name (version, Std.HashMap.empty)
|
||||
let merged := refs.fold (init := current.snd) fun m ident info =>
|
||||
m.findD ident Lsp.RefInfo.empty |>.merge info |> m.insert ident
|
||||
m.getD ident Lsp.RefInfo.empty |>.merge info |> m.insert ident
|
||||
return { self with workers := self.workers.insert name (version, merged) }
|
||||
return self
|
||||
|
||||
@@ -419,7 +419,7 @@ Replaces the worker references in `self` with the `refs` of the worker managing
|
||||
if `version` is newer than the current version managed in `refs`.
|
||||
-/
|
||||
def finalizeWorkerRefs (self : References) (name : Name) (version : Nat) (refs : Lsp.ModuleRefs) : References := Id.run do
|
||||
if let some (currVersion, _) := self.workers.find? name then
|
||||
if let some (currVersion, _) := self.workers[name]? then
|
||||
if version < currVersion then
|
||||
return self
|
||||
return { self with workers := self.workers.insert name (version, refs) }
|
||||
@@ -429,8 +429,8 @@ def removeWorkerRefs (self : References) (name : Name) : References :=
|
||||
{ self with workers := self.workers.erase name }
|
||||
|
||||
/-- Yields a map from all modules to all of their references. -/
|
||||
def allRefs (self : References) : HashMap Name Lsp.ModuleRefs :=
|
||||
let ileanRefs := self.ileans.toArray.foldl (init := HashMap.empty) fun m (name, _, refs) => m.insert name refs
|
||||
def allRefs (self : References) : Std.HashMap Name Lsp.ModuleRefs :=
|
||||
let ileanRefs := self.ileans.toArray.foldl (init := Std.HashMap.empty) fun m (name, _, refs) => m.insert name refs
|
||||
self.workers.toArray.foldl (init := ileanRefs) fun m (name, _, refs) => m.insert name refs
|
||||
|
||||
/--
|
||||
@@ -445,12 +445,12 @@ def allRefsFor
|
||||
let refsToCheck := match ident with
|
||||
| RefIdent.const .. => self.allRefs.toArray
|
||||
| RefIdent.fvar identModule .. =>
|
||||
match self.allRefs.find? identModule with
|
||||
match self.allRefs[identModule]? with
|
||||
| none => #[]
|
||||
| some refs => #[(identModule, refs)]
|
||||
let mut result := #[]
|
||||
for (module, refs) in refsToCheck do
|
||||
let some info := refs.find? ident
|
||||
let some info := refs.get? ident
|
||||
| continue
|
||||
let some path ← srcSearchPath.findModuleWithExt "lean" module
|
||||
| continue
|
||||
@@ -462,13 +462,13 @@ def allRefsFor
|
||||
|
||||
/-- Yields all references in `module` at `pos`. -/
|
||||
def findAt (self : References) (module : Name) (pos : Lsp.Position) (includeStop := false) : Array RefIdent := Id.run do
|
||||
if let some refs := self.allRefs.find? module then
|
||||
if let some refs := self.allRefs[module]? then
|
||||
return refs.findAt pos includeStop
|
||||
#[]
|
||||
|
||||
/-- Yields the first reference in `module` at `pos`. -/
|
||||
def findRange? (self : References) (module : Name) (pos : Lsp.Position) (includeStop := false) : Option Range := do
|
||||
let refs ← self.allRefs.find? module
|
||||
let refs ← self.allRefs[module]?
|
||||
refs.findRange? pos includeStop
|
||||
|
||||
/-- Location and parent declaration of a reference. -/
|
||||
|
||||
@@ -90,6 +90,10 @@ section Utils
|
||||
| crashed (e : IO.Error)
|
||||
| ioError (e : IO.Error)
|
||||
|
||||
inductive CrashOrigin
|
||||
| fileWorkerToClientForwarding
|
||||
| clientToFileWorkerForwarding
|
||||
|
||||
inductive WorkerState where
|
||||
/-- The watchdog can detect a crashed file worker in two places: When trying to send a message
|
||||
to the file worker and when reading a request reply.
|
||||
@@ -98,7 +102,7 @@ section Utils
|
||||
that are in-flight are errored. Upon receiving the next packet for that file worker, the file
|
||||
worker is restarted and the packet is forwarded to it. If the crash was detected while writing
|
||||
a packet, we queue that packet until the next packet for the file worker arrives. -/
|
||||
| crashed (queuedMsgs : Array JsonRpc.Message)
|
||||
| crashed (queuedMsgs : Array JsonRpc.Message) (origin : CrashOrigin)
|
||||
| running
|
||||
|
||||
abbrev PendingRequestMap := RBMap RequestID JsonRpc.Message compare
|
||||
@@ -136,6 +140,11 @@ section FileWorker
|
||||
for ⟨id, _⟩ in pendingRequests do
|
||||
hError.writeLspResponseError { id := id, code := code, message := msg }
|
||||
|
||||
def queuedMsgs (fw : FileWorker) : Array JsonRpc.Message :=
|
||||
match fw.state with
|
||||
| .running => #[]
|
||||
| .crashed queuedMsgs _ => queuedMsgs
|
||||
|
||||
end FileWorker
|
||||
end FileWorker
|
||||
|
||||
@@ -404,10 +413,23 @@ section ServerM
|
||||
return
|
||||
eraseFileWorker uri
|
||||
|
||||
def handleCrash (uri : DocumentUri) (queuedMsgs : Array JsonRpc.Message) : ServerM Unit := do
|
||||
def handleCrash (uri : DocumentUri) (queuedMsgs : Array JsonRpc.Message) (origin: CrashOrigin) : ServerM Unit := do
|
||||
let some fw ← findFileWorker? uri
|
||||
| return
|
||||
updateFileWorkers { fw with state := WorkerState.crashed queuedMsgs }
|
||||
updateFileWorkers { fw with state := WorkerState.crashed queuedMsgs origin }
|
||||
|
||||
def tryDischargeQueuedMessages (uri : DocumentUri) (queuedMsgs : Array JsonRpc.Message) : ServerM Unit := do
|
||||
let some fw ← findFileWorker? uri
|
||||
| throwServerError "Cannot find file worker for '{uri}'."
|
||||
let mut crashedMsgs := #[]
|
||||
-- Try to discharge all queued msgs, tracking the ones that we can't discharge
|
||||
for msg in queuedMsgs do
|
||||
try
|
||||
fw.stdin.writeLspMessage msg
|
||||
catch _ =>
|
||||
crashedMsgs := crashedMsgs.push msg
|
||||
if ¬ crashedMsgs.isEmpty then
|
||||
handleCrash uri crashedMsgs .clientToFileWorkerForwarding
|
||||
|
||||
/-- Tries to write a message, sets the state of the FileWorker to `crashed` if it does not succeed
|
||||
and restarts the file worker if the `crashed` flag was already set. Just logs an error if
|
||||
@@ -423,7 +445,7 @@ section ServerM
|
||||
let some fw ← findFileWorker? uri
|
||||
| return
|
||||
match fw.state with
|
||||
| WorkerState.crashed queuedMsgs =>
|
||||
| WorkerState.crashed queuedMsgs _ =>
|
||||
let mut queuedMsgs := queuedMsgs
|
||||
if queueFailedMessage then
|
||||
queuedMsgs := queuedMsgs.push msg
|
||||
@@ -432,17 +454,7 @@ section ServerM
|
||||
-- restart the crashed FileWorker
|
||||
eraseFileWorker uri
|
||||
startFileWorker fw.doc
|
||||
let some newFw ← findFileWorker? uri
|
||||
| throwServerError "Cannot find file worker for '{uri}'."
|
||||
let mut crashedMsgs := #[]
|
||||
-- try to discharge all queued msgs, tracking the ones that we can't discharge
|
||||
for msg in queuedMsgs do
|
||||
try
|
||||
newFw.stdin.writeLspMessage msg
|
||||
catch _ =>
|
||||
crashedMsgs := crashedMsgs.push msg
|
||||
if ¬ crashedMsgs.isEmpty then
|
||||
handleCrash uri crashedMsgs
|
||||
tryDischargeQueuedMessages uri queuedMsgs
|
||||
| WorkerState.running =>
|
||||
let initialQueuedMsgs :=
|
||||
if queueFailedMessage then
|
||||
@@ -452,7 +464,7 @@ section ServerM
|
||||
try
|
||||
fw.stdin.writeLspMessage msg
|
||||
catch _ =>
|
||||
handleCrash uri initialQueuedMsgs
|
||||
handleCrash uri initialQueuedMsgs .clientToFileWorkerForwarding
|
||||
|
||||
/--
|
||||
Sends a notification to the file worker identified by `uri` that its dependency `staleDependency`
|
||||
@@ -638,7 +650,7 @@ def handleCallHierarchyOutgoingCalls (p : CallHierarchyOutgoingCallsParams)
|
||||
|
||||
let references ← (← read).references.get
|
||||
|
||||
let some refs := references.allRefs.find? module
|
||||
let some refs := references.allRefs[module]?
|
||||
| return #[]
|
||||
|
||||
let items ← refs.toArray.filterMapM fun ⟨ident, info⟩ => do
|
||||
@@ -702,9 +714,9 @@ def handlePrepareRename (p : PrepareRenameParams) : ServerM (Option Range) := do
|
||||
def handleRename (p : RenameParams) : ServerM Lsp.WorkspaceEdit := do
|
||||
if (String.toName p.newName).isAnonymous then
|
||||
throwServerError s!"Can't rename: `{p.newName}` is not an identifier"
|
||||
let mut refs : HashMap DocumentUri (RBMap Lsp.Position Lsp.Position compare) := ∅
|
||||
let mut refs : Std.HashMap DocumentUri (RBMap Lsp.Position Lsp.Position compare) := ∅
|
||||
for { uri, range } in (← handleReference { p with context.includeDeclaration := true }) do
|
||||
refs := refs.insert uri <| (refs.findD uri ∅).insert range.start range.end
|
||||
refs := refs.insert uri <| (refs.getD uri ∅).insert range.start range.end
|
||||
-- We have to filter the list of changes to put the ranges in order and
|
||||
-- remove any duplicates or overlapping ranges, or else the rename will not apply
|
||||
let changes := refs.fold (init := ∅) fun changes uri map => Id.run do
|
||||
@@ -955,7 +967,16 @@ section MainLoop
|
||||
let workers ← st.fileWorkersRef.get
|
||||
let mut workerTasks := #[]
|
||||
for (_, fw) in workers do
|
||||
if let WorkerState.running := fw.state then
|
||||
-- When the forwarding task crashes, its return value will be stuck at
|
||||
-- `WorkerEvent.crashed _`.
|
||||
-- We want to handle this event only once, not over and over again,
|
||||
-- so once the state becomes `WorkerState.crashed _ .fileWorkerToClientForwarding`
|
||||
-- as a result of `WorkerEvent.crashed _`, we stop handling this event until
|
||||
-- eventually the file worker is restarted by a notification from the client.
|
||||
-- We do not want to filter the forwarding task in case of
|
||||
-- `WorkerState.crashed _ .clientToFileWorkerForwarding`, since the forwarding task
|
||||
-- exit code may still contain valuable information in this case (e.g. that the imports changed).
|
||||
if !(fw.state matches WorkerState.crashed _ .fileWorkerToClientForwarding) then
|
||||
workerTasks := workerTasks.push <| fw.commTask.map (ServerEvent.workerEvent fw)
|
||||
|
||||
let ev ← IO.waitAny (clientTask :: workerTasks.toList)
|
||||
@@ -984,13 +1005,16 @@ section MainLoop
|
||||
| WorkerEvent.ioError e =>
|
||||
throwServerError s!"IO error while processing events for {fw.doc.uri}: {e}"
|
||||
| WorkerEvent.crashed _ =>
|
||||
handleCrash fw.doc.uri #[]
|
||||
handleCrash fw.doc.uri fw.queuedMsgs .fileWorkerToClientForwarding
|
||||
mainLoop clientTask
|
||||
| WorkerEvent.terminated =>
|
||||
throwServerError <| "Internal server error: got termination event for worker that "
|
||||
++ "should have been removed"
|
||||
| .importsChanged =>
|
||||
let uri := fw.doc.uri
|
||||
let queuedMsgs := fw.queuedMsgs
|
||||
startFileWorker fw.doc
|
||||
tryDischargeQueuedMessages uri queuedMsgs
|
||||
mainLoop clientTask
|
||||
end MainLoop
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user