mirror of
https://github.com/leanprover/lean4.git
synced 2026-03-18 19:04:07 +00:00
Compare commits
40 Commits
array_atta
...
binSearch
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bcc22fd476 | ||
|
|
bf13b24692 | ||
|
|
51d1cc61d7 | ||
|
|
107a2e8b2e | ||
|
|
c4b0b94c91 | ||
|
|
ba3f2b3ecf | ||
|
|
4a69643858 | ||
|
|
b6a0d63612 | ||
|
|
5145030ff4 | ||
|
|
d3cb812fb6 | ||
|
|
e066c17a65 | ||
|
|
38cff08888 | ||
|
|
3388fc8d06 | ||
|
|
5adcd520fa | ||
|
|
1126407d9b | ||
|
|
a19ff61e15 | ||
|
|
6202461a21 | ||
|
|
ea221f3283 | ||
|
|
7c50d597c3 | ||
|
|
99031695bd | ||
|
|
b7248d5295 | ||
|
|
7f2e7e56d2 | ||
|
|
1fe66737ad | ||
|
|
765eb02279 | ||
|
|
a101377054 | ||
|
|
aca9929d84 | ||
|
|
19a701e5c9 | ||
|
|
fc4305ab15 | ||
|
|
9cf83706e7 | ||
|
|
459c6e2a46 | ||
|
|
72e952eadc | ||
|
|
56a80dec1b | ||
|
|
b894464191 | ||
|
|
b30903d1fc | ||
|
|
7fbe8e3b36 | ||
|
|
2fbc46641d | ||
|
|
17419aca7f | ||
|
|
f85c66789d | ||
|
|
c8b4f6b511 | ||
|
|
3c7555168d |
24
.github/workflows/labels-from-comments.yml
vendored
24
.github/workflows/labels-from-comments.yml
vendored
@@ -1,7 +1,8 @@
|
||||
# This workflow allows any user to add one of the `awaiting-review`, `awaiting-author`, `WIP`,
|
||||
# or `release-ci` labels by commenting on the PR or issue.
|
||||
# `release-ci`, or a `changelog-XXX` label by commenting on the PR or issue.
|
||||
# If any labels from the set {`awaiting-review`, `awaiting-author`, `WIP`} are added, other labels
|
||||
# from that set are removed automatically at the same time.
|
||||
# Similarly, if any `changelog-XXX` label is added, other `changelog-YYY` labels are removed.
|
||||
|
||||
name: Label PR based on Comment
|
||||
|
||||
@@ -11,7 +12,7 @@ on:
|
||||
|
||||
jobs:
|
||||
update-label:
|
||||
if: github.event.issue.pull_request != null && (contains(github.event.comment.body, 'awaiting-review') || contains(github.event.comment.body, 'awaiting-author') || contains(github.event.comment.body, 'WIP') || contains(github.event.comment.body, 'release-ci'))
|
||||
if: github.event.issue.pull_request != null && (contains(github.event.comment.body, 'awaiting-review') || contains(github.event.comment.body, 'awaiting-author') || contains(github.event.comment.body, 'WIP') || contains(github.event.comment.body, 'release-ci') || contains(github.event.comment.body, 'changelog-'))
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
@@ -20,13 +21,14 @@ jobs:
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
const { owner, repo, number: issue_number } = context.issue;
|
||||
const { owner, repo, number: issue_number } = context.issue;
|
||||
const commentLines = context.payload.comment.body.split('\r\n');
|
||||
|
||||
const awaitingReview = commentLines.includes('awaiting-review');
|
||||
const awaitingAuthor = commentLines.includes('awaiting-author');
|
||||
const wip = commentLines.includes('WIP');
|
||||
const releaseCI = commentLines.includes('release-ci');
|
||||
const changelogMatch = commentLines.find(line => line.startsWith('changelog-'));
|
||||
|
||||
if (awaitingReview || awaitingAuthor || wip) {
|
||||
await github.rest.issues.removeLabel({ owner, repo, issue_number, name: 'awaiting-review' }).catch(() => {});
|
||||
@@ -47,3 +49,19 @@ jobs:
|
||||
if (releaseCI) {
|
||||
await github.rest.issues.addLabels({ owner, repo, issue_number, labels: ['release-ci'] });
|
||||
}
|
||||
|
||||
if (changelogMatch) {
|
||||
const changelogLabel = changelogMatch.trim();
|
||||
const { data: existingLabels } = await github.rest.issues.listLabelsOnIssue({ owner, repo, issue_number });
|
||||
const changelogLabels = existingLabels.filter(label => label.name.startsWith('changelog-'));
|
||||
|
||||
// Remove all other changelog labels
|
||||
for (const label of changelogLabels) {
|
||||
if (label.name !== changelogLabel) {
|
||||
await github.rest.issues.removeLabel({ owner, repo, issue_number, name: label.name }).catch(() => {});
|
||||
}
|
||||
}
|
||||
|
||||
// Add the new changelog label
|
||||
await github.rest.issues.addLabels({ owner, repo, issue_number, labels: [changelogLabel] });
|
||||
}
|
||||
|
||||
@@ -82,9 +82,7 @@ theorem Expr.typeCheck_correct (h₁ : HasType e ty) (h₂ : e.typeCheck ≠ .un
|
||||
/-!
|
||||
Now, we prove that if `Expr.typeCheck e` returns `Maybe.unknown`, then forall `ty`, `HasType e ty` does not hold.
|
||||
The notation `e.typeCheck` is sugar for `Expr.typeCheck e`. Lean can infer this because we explicitly said that `e` has type `Expr`.
|
||||
The proof is by induction on `e` and case analysis. The tactic `rename_i` is used to rename "inaccessible" variables.
|
||||
We say a variable is inaccessible if it is introduced by a tactic (e.g., `cases`) or has been shadowed by another variable introduced
|
||||
by the user. Note that the tactic `simp [typeCheck]` is applied to all goal generated by the `induction` tactic, and closes
|
||||
The proof is by induction on `e` and case analysis. Note that the tactic `simp [typeCheck]` is applied to all goal generated by the `induction` tactic, and closes
|
||||
the cases corresponding to the constructors `Expr.nat` and `Expr.bool`.
|
||||
-/
|
||||
theorem Expr.typeCheck_complete {e : Expr} : e.typeCheck = .unknown → ¬ HasType e ty := by
|
||||
|
||||
@@ -10,12 +10,13 @@ import Init.Data.List.Attach
|
||||
|
||||
namespace Array
|
||||
|
||||
/-- `O(n)`. Partial map. If `f : Π a, P a → β` is a partial function defined on
|
||||
`a : α` satisfying `P`, then `pmap f l h` is essentially the same as `map f l`
|
||||
but is defined only when all members of `l` satisfy `P`, using the proof
|
||||
to apply `f`.
|
||||
/--
|
||||
`O(n)`. Partial map. If `f : Π a, P a → β` is a partial function defined on
|
||||
`a : α` satisfying `P`, then `pmap f l h` is essentially the same as `map f l`
|
||||
but is defined only when all members of `l` satisfy `P`, using the proof
|
||||
to apply `f`.
|
||||
|
||||
We replace this at runtime with a more efficient version via
|
||||
We replace this at runtime with a more efficient version via the `csimp` lemma `pmap_eq_pmapImpl`.
|
||||
-/
|
||||
def pmap {P : α → Prop} (f : ∀ a, P a → β) (l : Array α) (H : ∀ a ∈ l, P a) : Array β :=
|
||||
(l.toList.pmap f (fun a m => H a (mem_def.mpr m))).toArray
|
||||
@@ -73,6 +74,17 @@ Unsafe implementation of `attachWith`, taking advantage of the fact that the rep
|
||||
intro a m h₁ h₂
|
||||
congr
|
||||
|
||||
@[simp] theorem pmap_empty {P : α → Prop} (f : ∀ a, P a → β) : pmap f #[] (by simp) = #[] := rfl
|
||||
|
||||
@[simp] theorem pmap_push {P : α → Prop} (f : ∀ a, P a → β) (a : α) (l : Array α) (h : ∀ b ∈ l.push a, P b) :
|
||||
pmap f (l.push a) h =
|
||||
(pmap f l (fun a m => by simp at h; exact h a (.inl m))).push (f a (h a (by simp))) := by
|
||||
simp [pmap]
|
||||
|
||||
@[simp] theorem attach_empty : (#[] : Array α).attach = #[] := rfl
|
||||
|
||||
@[simp] theorem attachWith_empty {P : α → Prop} (H : ∀ x ∈ #[], P x) : (#[] : Array α).attachWith P H = #[] := rfl
|
||||
|
||||
@[simp] theorem _root_.List.attachWith_mem_toArray {l : List α} :
|
||||
l.attachWith (fun x => x ∈ l.toArray) (fun x h => by simpa using h) =
|
||||
l.attach.map fun ⟨x, h⟩ => ⟨x, by simpa using h⟩ := by
|
||||
@@ -80,6 +92,353 @@ Unsafe implementation of `attachWith`, taking advantage of the fact that the rep
|
||||
apply List.pmap_congr_left
|
||||
simp
|
||||
|
||||
@[simp]
|
||||
theorem pmap_eq_map (p : α → Prop) (f : α → β) (l : Array α) (H) :
|
||||
@pmap _ _ p (fun a _ => f a) l H = map f l := by
|
||||
cases l; simp
|
||||
|
||||
theorem pmap_congr_left {p q : α → Prop} {f : ∀ a, p a → β} {g : ∀ a, q a → β} (l : Array α) {H₁ H₂}
|
||||
(h : ∀ a ∈ l, ∀ (h₁ h₂), f a h₁ = g a h₂) : pmap f l H₁ = pmap g l H₂ := by
|
||||
cases l
|
||||
simp only [mem_toArray] at h
|
||||
simp only [List.pmap_toArray, mk.injEq]
|
||||
rw [List.pmap_congr_left _ h]
|
||||
|
||||
theorem map_pmap {p : α → Prop} (g : β → γ) (f : ∀ a, p a → β) (l H) :
|
||||
map g (pmap f l H) = pmap (fun a h => g (f a h)) l H := by
|
||||
cases l
|
||||
simp [List.map_pmap]
|
||||
|
||||
theorem pmap_map {p : β → Prop} (g : ∀ b, p b → γ) (f : α → β) (l H) :
|
||||
pmap g (map f l) H = pmap (fun a h => g (f a) h) l fun _ h => H _ (mem_map_of_mem _ h) := by
|
||||
cases l
|
||||
simp [List.pmap_map]
|
||||
|
||||
theorem attach_congr {l₁ l₂ : Array α} (h : l₁ = l₂) :
|
||||
l₁.attach = l₂.attach.map (fun x => ⟨x.1, h ▸ x.2⟩) := by
|
||||
subst h
|
||||
simp
|
||||
|
||||
theorem attachWith_congr {l₁ l₂ : Array α} (w : l₁ = l₂) {P : α → Prop} {H : ∀ x ∈ l₁, P x} :
|
||||
l₁.attachWith P H = l₂.attachWith P fun _ h => H _ (w ▸ h) := by
|
||||
subst w
|
||||
simp
|
||||
|
||||
@[simp] theorem attach_push {a : α} {l : Array α} :
|
||||
(l.push a).attach =
|
||||
(l.attach.map (fun ⟨x, h⟩ => ⟨x, mem_push_of_mem a h⟩)).push ⟨a, by simp⟩ := by
|
||||
cases l
|
||||
rw [attach_congr (List.push_toArray _ _)]
|
||||
simp [Function.comp_def]
|
||||
|
||||
@[simp] theorem attachWith_push {a : α} {l : Array α} {P : α → Prop} {H : ∀ x ∈ l.push a, P x} :
|
||||
(l.push a).attachWith P H =
|
||||
(l.attachWith P (fun x h => by simp at H; exact H x (.inl h))).push ⟨a, H a (by simp)⟩ := by
|
||||
cases l
|
||||
simp [attachWith_congr (List.push_toArray _ _)]
|
||||
|
||||
theorem pmap_eq_map_attach {p : α → Prop} (f : ∀ a, p a → β) (l H) :
|
||||
pmap f l H = l.attach.map fun x => f x.1 (H _ x.2) := by
|
||||
cases l
|
||||
simp [List.pmap_eq_map_attach]
|
||||
|
||||
theorem attach_map_coe (l : Array α) (f : α → β) :
|
||||
(l.attach.map fun (i : {i // i ∈ l}) => f i) = l.map f := by
|
||||
cases l
|
||||
simp [List.attach_map_coe]
|
||||
|
||||
theorem attach_map_val (l : Array α) (f : α → β) : (l.attach.map fun i => f i.val) = l.map f :=
|
||||
attach_map_coe _ _
|
||||
|
||||
@[simp]
|
||||
theorem attach_map_subtype_val (l : Array α) : l.attach.map Subtype.val = l := by
|
||||
cases l; simp
|
||||
|
||||
theorem attachWith_map_coe {p : α → Prop} (f : α → β) (l : Array α) (H : ∀ a ∈ l, p a) :
|
||||
((l.attachWith p H).map fun (i : { i // p i}) => f i) = l.map f := by
|
||||
cases l; simp
|
||||
|
||||
theorem attachWith_map_val {p : α → Prop} (f : α → β) (l : Array α) (H : ∀ a ∈ l, p a) :
|
||||
((l.attachWith p H).map fun i => f i.val) = l.map f :=
|
||||
attachWith_map_coe _ _ _
|
||||
|
||||
@[simp]
|
||||
theorem attachWith_map_subtype_val {p : α → Prop} (l : Array α) (H : ∀ a ∈ l, p a) :
|
||||
(l.attachWith p H).map Subtype.val = l := by
|
||||
cases l; simp
|
||||
|
||||
@[simp]
|
||||
theorem mem_attach (l : Array α) : ∀ x, x ∈ l.attach
|
||||
| ⟨a, h⟩ => by
|
||||
have := mem_map.1 (by rw [attach_map_subtype_val] <;> exact h)
|
||||
rcases this with ⟨⟨_, _⟩, m, rfl⟩
|
||||
exact m
|
||||
|
||||
@[simp]
|
||||
theorem mem_pmap {p : α → Prop} {f : ∀ a, p a → β} {l H b} :
|
||||
b ∈ pmap f l H ↔ ∃ (a : _) (h : a ∈ l), f a (H a h) = b := by
|
||||
simp only [pmap_eq_map_attach, mem_map, mem_attach, true_and, Subtype.exists, eq_comm]
|
||||
|
||||
theorem mem_pmap_of_mem {p : α → Prop} {f : ∀ a, p a → β} {l H} {a} (h : a ∈ l) :
|
||||
f a (H a h) ∈ pmap f l H := by
|
||||
rw [mem_pmap]
|
||||
exact ⟨a, h, rfl⟩
|
||||
|
||||
@[simp]
|
||||
theorem size_pmap {p : α → Prop} {f : ∀ a, p a → β} {l H} : (pmap f l H).size = l.size := by
|
||||
cases l; simp
|
||||
|
||||
@[simp]
|
||||
theorem size_attach {L : Array α} : L.attach.size = L.size := by
|
||||
cases L; simp
|
||||
|
||||
@[simp]
|
||||
theorem size_attachWith {p : α → Prop} {l : Array α} {H} : (l.attachWith p H).size = l.size := by
|
||||
cases l; simp
|
||||
|
||||
@[simp]
|
||||
theorem pmap_eq_empty_iff {p : α → Prop} {f : ∀ a, p a → β} {l H} : pmap f l H = #[] ↔ l = #[] := by
|
||||
cases l; simp
|
||||
|
||||
theorem pmap_ne_empty_iff {P : α → Prop} (f : (a : α) → P a → β) {xs : Array α}
|
||||
(H : ∀ (a : α), a ∈ xs → P a) : xs.pmap f H ≠ #[] ↔ xs ≠ #[] := by
|
||||
cases xs; simp
|
||||
|
||||
theorem pmap_eq_self {l : Array α} {p : α → Prop} (hp : ∀ (a : α), a ∈ l → p a)
|
||||
(f : (a : α) → p a → α) : l.pmap f hp = l ↔ ∀ a (h : a ∈ l), f a (hp a h) = a := by
|
||||
cases l; simp [List.pmap_eq_self]
|
||||
|
||||
@[simp]
|
||||
theorem attach_eq_empty_iff {l : Array α} : l.attach = #[] ↔ l = #[] := by
|
||||
cases l; simp
|
||||
|
||||
theorem attach_ne_empty_iff {l : Array α} : l.attach ≠ #[] ↔ l ≠ #[] := by
|
||||
cases l; simp
|
||||
|
||||
@[simp]
|
||||
theorem attachWith_eq_empty_iff {l : Array α} {P : α → Prop} {H : ∀ a ∈ l, P a} :
|
||||
l.attachWith P H = #[] ↔ l = #[] := by
|
||||
cases l; simp
|
||||
|
||||
theorem attachWith_ne_empty_iff {l : Array α} {P : α → Prop} {H : ∀ a ∈ l, P a} :
|
||||
l.attachWith P H ≠ #[] ↔ l ≠ #[] := by
|
||||
cases l; simp
|
||||
|
||||
@[simp]
|
||||
theorem getElem?_pmap {p : α → Prop} (f : ∀ a, p a → β) {l : Array α} (h : ∀ a ∈ l, p a) (n : Nat) :
|
||||
(pmap f l h)[n]? = Option.pmap f l[n]? fun x H => h x (mem_of_getElem? H) := by
|
||||
cases l; simp
|
||||
|
||||
@[simp]
|
||||
theorem getElem_pmap {p : α → Prop} (f : ∀ a, p a → β) {l : Array α} (h : ∀ a ∈ l, p a) {n : Nat}
|
||||
(hn : n < (pmap f l h).size) :
|
||||
(pmap f l h)[n] =
|
||||
f (l[n]'(@size_pmap _ _ p f l h ▸ hn))
|
||||
(h _ (getElem_mem (@size_pmap _ _ p f l h ▸ hn))) := by
|
||||
cases l; simp
|
||||
|
||||
@[simp]
|
||||
theorem getElem?_attachWith {xs : Array α} {i : Nat} {P : α → Prop} {H : ∀ a ∈ xs, P a} :
|
||||
(xs.attachWith P H)[i]? = xs[i]?.pmap Subtype.mk (fun _ a => H _ (mem_of_getElem? a)) :=
|
||||
getElem?_pmap ..
|
||||
|
||||
@[simp]
|
||||
theorem getElem?_attach {xs : Array α} {i : Nat} :
|
||||
xs.attach[i]? = xs[i]?.pmap Subtype.mk (fun _ a => mem_of_getElem? a) :=
|
||||
getElem?_attachWith
|
||||
|
||||
@[simp]
|
||||
theorem getElem_attachWith {xs : Array α} {P : α → Prop} {H : ∀ a ∈ xs, P a}
|
||||
{i : Nat} (h : i < (xs.attachWith P H).size) :
|
||||
(xs.attachWith P H)[i] = ⟨xs[i]'(by simpa using h), H _ (getElem_mem (by simpa using h))⟩ :=
|
||||
getElem_pmap ..
|
||||
|
||||
@[simp]
|
||||
theorem getElem_attach {xs : Array α} {i : Nat} (h : i < xs.attach.size) :
|
||||
xs.attach[i] = ⟨xs[i]'(by simpa using h), getElem_mem (by simpa using h)⟩ :=
|
||||
getElem_attachWith h
|
||||
|
||||
theorem foldl_pmap (l : Array α) {P : α → Prop} (f : (a : α) → P a → β)
|
||||
(H : ∀ (a : α), a ∈ l → P a) (g : γ → β → γ) (x : γ) :
|
||||
(l.pmap f H).foldl g x = l.attach.foldl (fun acc a => g acc (f a.1 (H _ a.2))) x := by
|
||||
rw [pmap_eq_map_attach, foldl_map]
|
||||
|
||||
theorem foldr_pmap (l : Array α) {P : α → Prop} (f : (a : α) → P a → β)
|
||||
(H : ∀ (a : α), a ∈ l → P a) (g : β → γ → γ) (x : γ) :
|
||||
(l.pmap f H).foldr g x = l.attach.foldr (fun a acc => g (f a.1 (H _ a.2)) acc) x := by
|
||||
rw [pmap_eq_map_attach, foldr_map]
|
||||
|
||||
/--
|
||||
If we fold over `l.attach` with a function that ignores the membership predicate,
|
||||
we get the same results as folding over `l` directly.
|
||||
|
||||
This is useful when we need to use `attach` to show termination.
|
||||
|
||||
Unfortunately this can't be applied by `simp` because of the higher order unification problem,
|
||||
and even when rewriting we need to specify the function explicitly.
|
||||
See however `foldl_subtype` below.
|
||||
-/
|
||||
theorem foldl_attach (l : Array α) (f : β → α → β) (b : β) :
|
||||
l.attach.foldl (fun acc t => f acc t.1) b = l.foldl f b := by
|
||||
rcases l with ⟨l⟩
|
||||
simp only [List.attach_toArray, List.attachWith_mem_toArray, List.map_attach, size_toArray,
|
||||
List.length_pmap, List.foldl_toArray', mem_toArray, List.foldl_subtype]
|
||||
congr
|
||||
ext
|
||||
simpa using fun a => List.mem_of_getElem? a
|
||||
|
||||
/--
|
||||
If we fold over `l.attach` with a function that ignores the membership predicate,
|
||||
we get the same results as folding over `l` directly.
|
||||
|
||||
This is useful when we need to use `attach` to show termination.
|
||||
|
||||
Unfortunately this can't be applied by `simp` because of the higher order unification problem,
|
||||
and even when rewriting we need to specify the function explicitly.
|
||||
See however `foldr_subtype` below.
|
||||
-/
|
||||
theorem foldr_attach (l : Array α) (f : α → β → β) (b : β) :
|
||||
l.attach.foldr (fun t acc => f t.1 acc) b = l.foldr f b := by
|
||||
rcases l with ⟨l⟩
|
||||
simp only [List.attach_toArray, List.attachWith_mem_toArray, List.map_attach, size_toArray,
|
||||
List.length_pmap, List.foldr_toArray', mem_toArray, List.foldr_subtype]
|
||||
congr
|
||||
ext
|
||||
simpa using fun a => List.mem_of_getElem? a
|
||||
|
||||
theorem attach_map {l : Array α} (f : α → β) :
|
||||
(l.map f).attach = l.attach.map (fun ⟨x, h⟩ => ⟨f x, mem_map_of_mem f h⟩) := by
|
||||
cases l
|
||||
ext <;> simp
|
||||
|
||||
theorem attachWith_map {l : Array α} (f : α → β) {P : β → Prop} {H : ∀ (b : β), b ∈ l.map f → P b} :
|
||||
(l.map f).attachWith P H = (l.attachWith (P ∘ f) (fun _ h => H _ (mem_map_of_mem f h))).map
|
||||
fun ⟨x, h⟩ => ⟨f x, h⟩ := by
|
||||
cases l
|
||||
ext
|
||||
· simp
|
||||
· simp only [List.map_toArray, List.attachWith_toArray, List.getElem_toArray,
|
||||
List.getElem_attachWith, List.getElem_map, Function.comp_apply]
|
||||
erw [List.getElem_attachWith] -- Why is `erw` needed here?
|
||||
|
||||
theorem map_attachWith {l : Array α} {P : α → Prop} {H : ∀ (a : α), a ∈ l → P a}
|
||||
(f : { x // P x } → β) :
|
||||
(l.attachWith P H).map f =
|
||||
l.pmap (fun a (h : a ∈ l ∧ P a) => f ⟨a, H _ h.1⟩) (fun a h => ⟨h, H a h⟩) := by
|
||||
cases l
|
||||
ext <;> simp
|
||||
|
||||
/-- See also `pmap_eq_map_attach` for writing `pmap` in terms of `map` and `attach`. -/
|
||||
theorem map_attach {l : Array α} (f : { x // x ∈ l } → β) :
|
||||
l.attach.map f = l.pmap (fun a h => f ⟨a, h⟩) (fun _ => id) := by
|
||||
cases l
|
||||
ext <;> simp
|
||||
|
||||
theorem attach_filterMap {l : Array α} {f : α → Option β} :
|
||||
(l.filterMap f).attach = l.attach.filterMap
|
||||
fun ⟨x, h⟩ => (f x).pbind (fun b m => some ⟨b, mem_filterMap.mpr ⟨x, h, m⟩⟩) := by
|
||||
cases l
|
||||
rw [attach_congr (List.filterMap_toArray f _)]
|
||||
simp [List.attach_filterMap, List.map_filterMap, Function.comp_def]
|
||||
|
||||
theorem attach_filter {l : Array α} (p : α → Bool) :
|
||||
(l.filter p).attach = l.attach.filterMap
|
||||
fun x => if w : p x.1 then some ⟨x.1, mem_filter.mpr ⟨x.2, w⟩⟩ else none := by
|
||||
cases l
|
||||
rw [attach_congr (List.filter_toArray p _)]
|
||||
simp [List.attach_filter, List.map_filterMap, Function.comp_def]
|
||||
|
||||
-- We are still missing here `attachWith_filterMap` and `attachWith_filter`.
|
||||
-- Also missing are `filterMap_attach`, `filter_attach`, `filterMap_attachWith` and `filter_attachWith`.
|
||||
|
||||
theorem pmap_pmap {p : α → Prop} {q : β → Prop} (g : ∀ a, p a → β) (f : ∀ b, q b → γ) (l H₁ H₂) :
|
||||
pmap f (pmap g l H₁) H₂ =
|
||||
pmap (α := { x // x ∈ l }) (fun a h => f (g a h) (H₂ (g a h) (mem_pmap_of_mem a.2))) l.attach
|
||||
(fun a _ => H₁ a a.2) := by
|
||||
cases l
|
||||
simp [List.pmap_pmap, List.pmap_map]
|
||||
|
||||
@[simp] theorem pmap_append {p : ι → Prop} (f : ∀ a : ι, p a → α) (l₁ l₂ : Array ι)
|
||||
(h : ∀ a ∈ l₁ ++ l₂, p a) :
|
||||
(l₁ ++ l₂).pmap f h =
|
||||
(l₁.pmap f fun a ha => h a (mem_append_left l₂ ha)) ++
|
||||
l₂.pmap f fun a ha => h a (mem_append_right l₁ ha) := by
|
||||
cases l₁
|
||||
cases l₂
|
||||
simp
|
||||
|
||||
theorem pmap_append' {p : α → Prop} (f : ∀ a : α, p a → β) (l₁ l₂ : Array α)
|
||||
(h₁ : ∀ a ∈ l₁, p a) (h₂ : ∀ a ∈ l₂, p a) :
|
||||
((l₁ ++ l₂).pmap f fun a ha => (mem_append.1 ha).elim (h₁ a) (h₂ a)) =
|
||||
l₁.pmap f h₁ ++ l₂.pmap f h₂ :=
|
||||
pmap_append f l₁ l₂ _
|
||||
|
||||
@[simp] theorem attach_append (xs ys : Array α) :
|
||||
(xs ++ ys).attach = xs.attach.map (fun ⟨x, h⟩ => ⟨x, mem_append_left ys h⟩) ++
|
||||
ys.attach.map fun ⟨x, h⟩ => ⟨x, mem_append_right xs h⟩ := by
|
||||
cases xs
|
||||
cases ys
|
||||
rw [attach_congr (List.append_toArray _ _)]
|
||||
simp [List.attach_append, Function.comp_def]
|
||||
|
||||
@[simp] theorem attachWith_append {P : α → Prop} {xs ys : Array α}
|
||||
{H : ∀ (a : α), a ∈ xs ++ ys → P a} :
|
||||
(xs ++ ys).attachWith P H = xs.attachWith P (fun a h => H a (mem_append_left ys h)) ++
|
||||
ys.attachWith P (fun a h => H a (mem_append_right xs h)) := by
|
||||
simp [attachWith, attach_append, map_pmap, pmap_append]
|
||||
|
||||
@[simp] theorem pmap_reverse {P : α → Prop} (f : (a : α) → P a → β) (xs : Array α)
|
||||
(H : ∀ (a : α), a ∈ xs.reverse → P a) :
|
||||
xs.reverse.pmap f H = (xs.pmap f (fun a h => H a (by simpa using h))).reverse := by
|
||||
induction xs <;> simp_all
|
||||
|
||||
theorem reverse_pmap {P : α → Prop} (f : (a : α) → P a → β) (xs : Array α)
|
||||
(H : ∀ (a : α), a ∈ xs → P a) :
|
||||
(xs.pmap f H).reverse = xs.reverse.pmap f (fun a h => H a (by simpa using h)) := by
|
||||
rw [pmap_reverse]
|
||||
|
||||
@[simp] theorem attachWith_reverse {P : α → Prop} {xs : Array α}
|
||||
{H : ∀ (a : α), a ∈ xs.reverse → P a} :
|
||||
xs.reverse.attachWith P H =
|
||||
(xs.attachWith P (fun a h => H a (by simpa using h))).reverse := by
|
||||
cases xs
|
||||
simp
|
||||
|
||||
theorem reverse_attachWith {P : α → Prop} {xs : Array α}
|
||||
{H : ∀ (a : α), a ∈ xs → P a} :
|
||||
(xs.attachWith P H).reverse = (xs.reverse.attachWith P (fun a h => H a (by simpa using h))) := by
|
||||
cases xs
|
||||
simp
|
||||
|
||||
@[simp] theorem attach_reverse (xs : Array α) :
|
||||
xs.reverse.attach = xs.attach.reverse.map fun ⟨x, h⟩ => ⟨x, by simpa using h⟩ := by
|
||||
cases xs
|
||||
rw [attach_congr (List.reverse_toArray _)]
|
||||
simp
|
||||
|
||||
theorem reverse_attach (xs : Array α) :
|
||||
xs.attach.reverse = xs.reverse.attach.map fun ⟨x, h⟩ => ⟨x, by simpa using h⟩ := by
|
||||
cases xs
|
||||
simp
|
||||
|
||||
@[simp] theorem back?_pmap {P : α → Prop} (f : (a : α) → P a → β) (xs : Array α)
|
||||
(H : ∀ (a : α), a ∈ xs → P a) :
|
||||
(xs.pmap f H).back? = xs.attach.back?.map fun ⟨a, m⟩ => f a (H a m) := by
|
||||
cases xs
|
||||
simp
|
||||
|
||||
@[simp] theorem back?_attachWith {P : α → Prop} {xs : Array α}
|
||||
{H : ∀ (a : α), a ∈ xs → P a} :
|
||||
(xs.attachWith P H).back? = xs.back?.pbind (fun a h => some ⟨a, H _ (mem_of_back?_eq_some h)⟩) := by
|
||||
cases xs
|
||||
simp
|
||||
|
||||
@[simp]
|
||||
theorem back?_attach {xs : Array α} :
|
||||
xs.attach.back? = xs.back?.pbind fun a h => some ⟨a, mem_of_back?_eq_some h⟩ := by
|
||||
cases xs
|
||||
simp
|
||||
|
||||
/-! ## unattach
|
||||
|
||||
`Array.unattach` is the (one-sided) inverse of `Array.attach`. It is a synonym for `Array.map Subtype.val`.
|
||||
@@ -128,6 +487,15 @@ def unattach {α : Type _} {p : α → Prop} (l : Array { x // p x }) := l.map (
|
||||
cases l
|
||||
simp
|
||||
|
||||
@[simp] theorem getElem?_unattach {p : α → Prop} {l : Array { x // p x }} (i : Nat) :
|
||||
l.unattach[i]? = l[i]?.map Subtype.val := by
|
||||
simp [unattach]
|
||||
|
||||
@[simp] theorem getElem_unattach
|
||||
{p : α → Prop} {l : Array { x // p x }} (i : Nat) (h : i < l.unattach.size) :
|
||||
l.unattach[i] = (l[i]'(by simpa using h)).1 := by
|
||||
simp [unattach]
|
||||
|
||||
/-! ### Recognizing higher order functions using a function that only depends on the value. -/
|
||||
|
||||
/--
|
||||
|
||||
@@ -233,7 +233,7 @@ def ofFn {n} (f : Fin n → α) : Array α := go 0 (mkEmpty n) where
|
||||
|
||||
/-- The array `#[0, 1, ..., n - 1]`. -/
|
||||
def range (n : Nat) : Array Nat :=
|
||||
n.fold (flip Array.push) (mkEmpty n)
|
||||
ofFn fun (i : Fin n) => i
|
||||
|
||||
def singleton (v : α) : Array α :=
|
||||
mkArray 1 v
|
||||
@@ -613,8 +613,15 @@ def findIdx? {α : Type u} (p : α → Bool) (as : Array α) : Option Nat :=
|
||||
decreasing_by simp_wf; decreasing_trivial_pre_omega
|
||||
loop 0
|
||||
|
||||
def getIdx? [BEq α] (a : Array α) (v : α) : Option Nat :=
|
||||
a.findIdx? fun a => a == v
|
||||
@[inline]
|
||||
def findFinIdx? {α : Type u} (p : α → Bool) (as : Array α) : Option (Fin as.size) :=
|
||||
let rec @[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
|
||||
loop (j : Nat) :=
|
||||
if h : j < as.size then
|
||||
if p as[j] then some ⟨j, h⟩ else loop (j + 1)
|
||||
else none
|
||||
decreasing_by simp_wf; decreasing_trivial_pre_omega
|
||||
loop 0
|
||||
|
||||
@[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
|
||||
def indexOfAux [BEq α] (a : Array α) (v : α) (i : Nat) : Option (Fin a.size) :=
|
||||
@@ -627,6 +634,10 @@ decreasing_by simp_wf; decreasing_trivial_pre_omega
|
||||
def indexOf? [BEq α] (a : Array α) (v : α) : Option (Fin a.size) :=
|
||||
indexOfAux a v 0
|
||||
|
||||
@[deprecated indexOf? (since := "2024-11-20")]
|
||||
def getIdx? [BEq α] (a : Array α) (v : α) : Option Nat :=
|
||||
a.findIdx? fun a => a == v
|
||||
|
||||
@[inline]
|
||||
def any (as : Array α) (p : α → Bool) (start := 0) (stop := as.size) : Bool :=
|
||||
Id.run <| as.anyM p start stop
|
||||
@@ -766,48 +777,62 @@ def takeWhile (p : α → Bool) (as : Array α) : Array α :=
|
||||
decreasing_by simp_wf; decreasing_trivial_pre_omega
|
||||
go 0 #[]
|
||||
|
||||
/-- Remove the element at a given index from an array without bounds checks, using a `Fin` index.
|
||||
/--
|
||||
Remove the element at a given index from an array without a runtime bounds checks,
|
||||
using a `Nat` index and a tactic-provided bound.
|
||||
|
||||
This function takes worst case O(n) time because
|
||||
it has to backshift all elements at positions greater than `i`.-/
|
||||
This function takes worst case O(n) time because
|
||||
it has to backshift all elements at positions greater than `i`.-/
|
||||
@[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
|
||||
def feraseIdx (a : Array α) (i : Fin a.size) : Array α :=
|
||||
if h : i.val + 1 < a.size then
|
||||
let a' := a.swap ⟨i.val + 1, h⟩ i
|
||||
let i' : Fin a'.size := ⟨i.val + 1, by simp [a', h]⟩
|
||||
a'.feraseIdx i'
|
||||
def eraseIdx (a : Array α) (i : Nat) (h : i < a.size := by get_elem_tactic) : Array α :=
|
||||
if h' : i + 1 < a.size then
|
||||
let a' := a.swap ⟨i + 1, h'⟩ ⟨i, h⟩
|
||||
a'.eraseIdx (i + 1) (by simp [a', h'])
|
||||
else
|
||||
a.pop
|
||||
termination_by a.size - i.val
|
||||
decreasing_by simp_wf; exact Nat.sub_succ_lt_self _ _ i.isLt
|
||||
termination_by a.size - i
|
||||
decreasing_by simp_wf; exact Nat.sub_succ_lt_self _ _ h
|
||||
|
||||
-- This is required in `Lean.Data.PersistentHashMap`.
|
||||
@[simp] theorem size_feraseIdx (a : Array α) (i : Fin a.size) : (a.feraseIdx i).size = a.size - 1 := by
|
||||
induction a, i using Array.feraseIdx.induct with
|
||||
| @case1 a i h a' _ ih =>
|
||||
unfold feraseIdx
|
||||
simp [h, a', ih]
|
||||
| case2 a i h =>
|
||||
unfold feraseIdx
|
||||
simp [h]
|
||||
@[simp] theorem size_eraseIdx (a : Array α) (i : Nat) (h) : (a.eraseIdx i h).size = a.size - 1 := by
|
||||
induction a, i, h using Array.eraseIdx.induct with
|
||||
| @case1 a i h h' a' ih =>
|
||||
unfold eraseIdx
|
||||
simp [h', a', ih]
|
||||
| case2 a i h h' =>
|
||||
unfold eraseIdx
|
||||
simp [h']
|
||||
|
||||
/-- Remove the element at a given index from an array, or do nothing if the index is out of bounds.
|
||||
|
||||
This function takes worst case O(n) time because
|
||||
it has to backshift all elements at positions greater than `i`.-/
|
||||
def eraseIdx (a : Array α) (i : Nat) : Array α :=
|
||||
if h : i < a.size then a.feraseIdx ⟨i, h⟩ else a
|
||||
def eraseIdxIfInBounds (a : Array α) (i : Nat) : Array α :=
|
||||
if h : i < a.size then a.eraseIdx i h else a
|
||||
|
||||
/-- Remove the element at a given index from an array, or panic if the index is out of bounds.
|
||||
|
||||
This function takes worst case O(n) time because
|
||||
it has to backshift all elements at positions greater than `i`. -/
|
||||
def eraseIdx! (a : Array α) (i : Nat) : Array α :=
|
||||
if h : i < a.size then a.eraseIdx i h else panic! "invalid index"
|
||||
|
||||
def erase [BEq α] (as : Array α) (a : α) : Array α :=
|
||||
match as.indexOf? a with
|
||||
| none => as
|
||||
| some i => as.feraseIdx i
|
||||
| some i => as.eraseIdx i
|
||||
|
||||
/-- Erase the first element that satisfies the predicate `p`. -/
|
||||
def eraseP (as : Array α) (p : α → Bool) : Array α :=
|
||||
match as.findIdx? p with
|
||||
| none => as
|
||||
| some i => as.eraseIdxIfInBounds i
|
||||
|
||||
/-- Insert element `a` at position `i`. -/
|
||||
@[inline] def insertAt (as : Array α) (i : Fin (as.size + 1)) (a : α) : Array α :=
|
||||
@[inline] def insertIdx (as : Array α) (i : Nat) (a : α) (_ : i ≤ as.size := by get_elem_tactic) : Array α :=
|
||||
let rec @[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
|
||||
loop (as : Array α) (j : Fin as.size) :=
|
||||
if i.1 < j then
|
||||
if i < j then
|
||||
let j' := ⟨j-1, Nat.lt_of_le_of_lt (Nat.pred_le _) j.2⟩
|
||||
let as := as.swap j' j
|
||||
loop as ⟨j', by rw [size_swap]; exact j'.2⟩
|
||||
@@ -818,12 +843,23 @@ def erase [BEq α] (as : Array α) (a : α) : Array α :=
|
||||
let as := as.push a
|
||||
loop as ⟨j, size_push .. ▸ j.lt_succ_self⟩
|
||||
|
||||
@[deprecated insertIdx (since := "2024-11-20")] abbrev insertAt := @insertIdx
|
||||
|
||||
/-- Insert element `a` at position `i`. Panics if `i` is not `i ≤ as.size`. -/
|
||||
def insertAt! (as : Array α) (i : Nat) (a : α) : Array α :=
|
||||
def insertIdx! (as : Array α) (i : Nat) (a : α) : Array α :=
|
||||
if h : i ≤ as.size then
|
||||
insertAt as ⟨i, Nat.lt_succ_of_le h⟩ a
|
||||
insertIdx as i a
|
||||
else panic! "invalid index"
|
||||
|
||||
@[deprecated insertIdx! (since := "2024-11-20")] abbrev insertAt! := @insertIdx!
|
||||
|
||||
/-- Insert element `a` at position `i`, or do nothing if `as.size < i`. -/
|
||||
def insertIdxIfInBounds (as : Array α) (i : Nat) (a : α) : Array α :=
|
||||
if h : i ≤ as.size then
|
||||
insertIdx as i a
|
||||
else
|
||||
as
|
||||
|
||||
@[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
|
||||
def isPrefixOfAux [BEq α] (as bs : Array α) (hle : as.size ≤ bs.size) (i : Nat) : Bool :=
|
||||
if h : i < as.size then
|
||||
|
||||
@@ -5,59 +5,64 @@ Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Init.Data.Array.Basic
|
||||
import Init.Omega
|
||||
universe u v
|
||||
|
||||
-- TODO: CLEANUP
|
||||
|
||||
namespace Array
|
||||
-- TODO: remove the [Inhabited α] parameters as soon as we have the tactic framework for automating proof generation and using Array.fget
|
||||
-- TODO: remove `partial` using well-founded recursion
|
||||
|
||||
@[specialize] partial def binSearchAux {α : Type u} {β : Type v} [Inhabited β] (lt : α → α → Bool) (found : Option α → β) (as : Array α) (k : α) : Nat → Nat → β
|
||||
| lo, hi =>
|
||||
if lo <= hi then
|
||||
let _ := Inhabited.mk k
|
||||
let m := (lo + hi)/2
|
||||
let a := as.get! m
|
||||
if lt a k then binSearchAux lt found as k (m+1) hi
|
||||
else if lt k a then
|
||||
if m == 0 then found none
|
||||
else binSearchAux lt found as k lo (m-1)
|
||||
else found (some a)
|
||||
else found none
|
||||
@[specialize] def binSearchAux {α : Type u} {β : Type v} (lt : α → α → Bool) (found : Option α → β) (as : Array α) (k : α) :
|
||||
(lo : Fin (as.size + 1)) → (hi : Fin as.size) → (lo.1 ≤ hi.1) → β
|
||||
| lo, hi, h =>
|
||||
let m := (lo.1 + hi.1)/2
|
||||
let a := as[m]
|
||||
if lt a k then
|
||||
if h' : m + 1 ≤ hi.1 then
|
||||
binSearchAux lt found as k ⟨m+1, by omega⟩ hi h'
|
||||
else found none
|
||||
else if lt k a then
|
||||
if h' : m = 0 ∨ m - 1 < lo.1 then found none
|
||||
else binSearchAux lt found as k lo ⟨m-1, by omega⟩ (by simp; omega)
|
||||
else found (some a)
|
||||
termination_by lo hi => hi.1 - lo.1
|
||||
|
||||
@[inline] def binSearch {α : Type} (as : Array α) (k : α) (lt : α → α → Bool) (lo := 0) (hi := as.size - 1) : Option α :=
|
||||
if lo < as.size then
|
||||
if h : lo < as.size then
|
||||
let hi := if hi < as.size then hi else as.size - 1
|
||||
binSearchAux lt id as k lo hi
|
||||
if w : lo ≤ hi then
|
||||
binSearchAux lt id as k ⟨lo, by omega⟩ ⟨hi, by simp [hi]; split <;> omega⟩ (by simp [hi]; omega)
|
||||
else
|
||||
none
|
||||
else
|
||||
none
|
||||
|
||||
@[inline] def binSearchContains {α : Type} (as : Array α) (k : α) (lt : α → α → Bool) (lo := 0) (hi := as.size - 1) : Bool :=
|
||||
if lo < as.size then
|
||||
if h : lo < as.size then
|
||||
let hi := if hi < as.size then hi else as.size - 1
|
||||
binSearchAux lt Option.isSome as k lo hi
|
||||
if w : lo ≤ hi then
|
||||
binSearchAux lt Option.isSome as k ⟨lo, by omega⟩ ⟨hi, by simp [hi]; split <;> omega⟩ (by simp [hi]; omega)
|
||||
else
|
||||
false
|
||||
else
|
||||
false
|
||||
|
||||
@[specialize] private partial def binInsertAux {α : Type u} {m : Type u → Type v} [Monad m]
|
||||
@[specialize] private def binInsertAux {α : Type u} {m : Type u → Type v} [Monad m]
|
||||
(lt : α → α → Bool)
|
||||
(merge : α → m α)
|
||||
(add : Unit → m α)
|
||||
(as : Array α)
|
||||
(k : α) : Nat → Nat → m (Array α)
|
||||
| lo, hi =>
|
||||
let _ := Inhabited.mk k
|
||||
-- as[lo] < k < as[hi]
|
||||
let mid := (lo + hi)/2
|
||||
let midVal := as.get! mid
|
||||
if lt midVal k then
|
||||
if mid == lo then do let v ← add (); pure <| as.insertAt! (lo+1) v
|
||||
else binInsertAux lt merge add as k mid hi
|
||||
else if lt k midVal then
|
||||
binInsertAux lt merge add as k lo mid
|
||||
(k : α) : (lo : Fin as.size) → (hi : Fin as.size) → (lo.1 ≤ hi.1) → (lt as[lo] k) → m (Array α)
|
||||
| lo, hi, h, w =>
|
||||
let mid := (lo.1 + hi.1)/2
|
||||
let midVal := as[mid]
|
||||
if w₁ : lt midVal k then
|
||||
if h' : mid = lo then do let v ← add (); pure <| as.insertIdx (lo+1) v
|
||||
else binInsertAux lt merge add as k ⟨mid, by omega⟩ hi (by simp; omega) w₁
|
||||
else if w₂ : lt k midVal then
|
||||
have : mid ≠ lo := fun z => by simp [midVal, z] at w₁; simp_all
|
||||
binInsertAux lt merge add as k lo ⟨mid, by omega⟩ (by simp; omega) w
|
||||
else do
|
||||
as.modifyM mid <| fun v => merge v
|
||||
termination_by lo hi => hi.1 - lo.1
|
||||
|
||||
@[specialize] def binInsertM {α : Type u} {m : Type u → Type v} [Monad m]
|
||||
(lt : α → α → Bool)
|
||||
@@ -65,13 +70,12 @@ namespace Array
|
||||
(add : Unit → m α)
|
||||
(as : Array α)
|
||||
(k : α) : m (Array α) :=
|
||||
let _ := Inhabited.mk k
|
||||
if as.isEmpty then do let v ← add (); pure <| as.push v
|
||||
else if lt k (as.get! 0) then do let v ← add (); pure <| as.insertAt! 0 v
|
||||
else if !lt (as.get! 0) k then as.modifyM 0 <| merge
|
||||
else if lt as.back! k then do let v ← add (); pure <| as.push v
|
||||
else if !lt k as.back! then as.modifyM (as.size - 1) <| merge
|
||||
else binInsertAux lt merge add as k 0 (as.size - 1)
|
||||
if h : as.size = 0 then do let v ← add (); pure <| as.push v
|
||||
else if lt k as[0] then do let v ← add (); pure <| as.insertIdx 0 v
|
||||
else if h' : !lt as[0] k then as.modifyM 0 <| merge
|
||||
else if lt as[as.size - 1] k then do let v ← add (); pure <| as.push v
|
||||
else if !lt k as[as.size - 1] then as.modifyM (as.size - 1) <| merge
|
||||
else binInsertAux lt merge add as k ⟨0, by omega⟩ ⟨as.size - 1, by omega⟩ (by simp) (by simpa using h')
|
||||
|
||||
@[inline] def binInsert {α : Type u} (lt : α → α → Bool) (as : Array α) (k : α) : Array α :=
|
||||
Id.run <| binInsertM lt (fun _ => k) (fun _ => k) as k
|
||||
|
||||
@@ -6,7 +6,6 @@ Authors: Leonardo de Moura
|
||||
prelude
|
||||
import Init.Data.Array.Basic
|
||||
import Init.Data.BEq
|
||||
import Init.Data.Nat.Lemmas
|
||||
import Init.Data.List.Nat.BEq
|
||||
import Init.ByCases
|
||||
|
||||
|
||||
@@ -272,4 +272,10 @@ theorem find?_mkArray_eq_none {n : Nat} {a : α} {p : α → Bool} :
|
||||
((mkArray n a).find? p).get h = a := by
|
||||
simp [mkArray_eq_toArray_replicate]
|
||||
|
||||
theorem find?_pmap {P : α → Prop} (f : (a : α) → P a → β) (xs : Array α)
|
||||
(H : ∀ (a : α), a ∈ xs → P a) (p : β → Bool) :
|
||||
(xs.pmap f H).find? p = (xs.attach.find? (fun ⟨a, m⟩ => p (f a (H a m)))).map fun ⟨a, m⟩ => f a (H a m) := by
|
||||
simp only [pmap_eq_map_attach, find?_map]
|
||||
rfl
|
||||
|
||||
end Array
|
||||
|
||||
@@ -6,7 +6,7 @@ Authors: Leonardo de Moura
|
||||
prelude
|
||||
import Init.Data.Array.Basic
|
||||
|
||||
@[inline] def Array.insertionSort (a : Array α) (lt : α → α → Bool) : Array α :=
|
||||
@[inline] def Array.insertionSort (a : Array α) (lt : α → α → Bool := by exact (· < ·)) : Array α :=
|
||||
traverse a 0 a.size
|
||||
where
|
||||
@[specialize] traverse (a : Array α) (i : Nat) (fuel : Nat) : Array α :=
|
||||
|
||||
@@ -23,6 +23,9 @@ import Init.TacticsExtra
|
||||
|
||||
namespace Array
|
||||
|
||||
@[simp] theorem mem_toArray {a : α} {l : List α} : a ∈ l.toArray ↔ a ∈ l := by
|
||||
simp [mem_def]
|
||||
|
||||
@[simp] theorem getElem_mk {xs : List α} {i : Nat} (h : i < xs.length) : (Array.mk xs)[i] = xs[i] := rfl
|
||||
|
||||
theorem getElem_eq_getElem_toList {a : Array α} (h : i < a.size) : a[i] = a.toList[i] := rfl
|
||||
@@ -36,12 +39,21 @@ theorem getElem?_eq_getElem {a : Array α} {i : Nat} (h : i < a.size) : a[i]? =
|
||||
· rw [getElem?_neg a i h]
|
||||
simp_all
|
||||
|
||||
@[simp] theorem none_eq_getElem?_iff {a : Array α} {i : Nat} : none = a[i]? ↔ a.size ≤ i := by
|
||||
simp [eq_comm (a := none)]
|
||||
|
||||
theorem getElem?_eq {a : Array α} {i : Nat} :
|
||||
a[i]? = if h : i < a.size then some a[i] else none := by
|
||||
split
|
||||
· simp_all [getElem?_eq_getElem]
|
||||
· simp_all
|
||||
|
||||
theorem getElem?_eq_some_iff {a : Array α} : a[i]? = some b ↔ ∃ h : i < a.size, a[i] = b := by
|
||||
simp [getElem?_eq]
|
||||
|
||||
theorem some_eq_getElem?_iff {a : Array α} : some b = a[i]? ↔ ∃ h : i < a.size, a[i] = b := by
|
||||
rw [eq_comm, getElem?_eq_some_iff]
|
||||
|
||||
theorem getElem?_eq_getElem?_toList (a : Array α) (i : Nat) : a[i]? = a.toList[i]? := by
|
||||
rw [getElem?_eq]
|
||||
split <;> simp_all
|
||||
@@ -66,6 +78,35 @@ theorem getElem_push (a : Array α) (x : α) (i : Nat) (h : i < (a.push x).size)
|
||||
@[deprecated getElem_push_lt (since := "2024-10-21")] abbrev get_push_lt := @getElem_push_lt
|
||||
@[deprecated getElem_push_eq (since := "2024-10-21")] abbrev get_push_eq := @getElem_push_eq
|
||||
|
||||
@[simp] theorem mem_push {a : Array α} {x y : α} : x ∈ a.push y ↔ x ∈ a ∨ x = y := by
|
||||
simp [mem_def]
|
||||
|
||||
theorem mem_push_self {a : Array α} {x : α} : x ∈ a.push x :=
|
||||
mem_push.2 (Or.inr rfl)
|
||||
|
||||
theorem mem_push_of_mem {a : Array α} {x : α} (y : α) (h : x ∈ a) : x ∈ a.push y :=
|
||||
mem_push.2 (Or.inl h)
|
||||
|
||||
theorem getElem_of_mem {a} {l : Array α} (h : a ∈ l) : ∃ (n : Nat) (h : n < l.size), l[n]'h = a := by
|
||||
cases l
|
||||
simp [List.getElem_of_mem (by simpa using h)]
|
||||
|
||||
theorem getElem?_of_mem {a} {l : Array α} (h : a ∈ l) : ∃ n : Nat, l[n]? = some a :=
|
||||
let ⟨n, _, e⟩ := getElem_of_mem h; ⟨n, e ▸ getElem?_eq_getElem _⟩
|
||||
|
||||
theorem mem_of_getElem? {l : Array α} {n : Nat} {a : α} (e : l[n]? = some a) : a ∈ l :=
|
||||
let ⟨_, e⟩ := getElem?_eq_some_iff.1 e; e ▸ getElem_mem ..
|
||||
|
||||
theorem mem_iff_getElem {a} {l : Array α} : a ∈ l ↔ ∃ (n : Nat) (h : n < l.size), l[n]'h = a :=
|
||||
⟨getElem_of_mem, fun ⟨_, _, e⟩ => e ▸ getElem_mem ..⟩
|
||||
|
||||
theorem mem_iff_getElem? {a} {l : Array α} : a ∈ l ↔ ∃ n : Nat, l[n]? = some a := by
|
||||
simp [getElem?_eq_some_iff, mem_iff_getElem]
|
||||
|
||||
theorem forall_getElem {l : Array α} {p : α → Prop} :
|
||||
(∀ (n : Nat) h, p (l[n]'h)) ↔ ∀ a, a ∈ l → p a := by
|
||||
cases l; simp [List.forall_getElem]
|
||||
|
||||
@[simp] theorem get!_eq_getElem! [Inhabited α] (a : Array α) (i : Nat) : a.get! i = a[i]! := by
|
||||
simp [getElem!_def, get!, getD]
|
||||
split <;> rename_i h
|
||||
@@ -93,9 +134,6 @@ We prefer to pull `List.toArray` outwards.
|
||||
(a.toArrayAux b).size = b.size + a.length := by
|
||||
simp [size]
|
||||
|
||||
@[simp] theorem mem_toArray {a : α} {l : List α} : a ∈ l.toArray ↔ a ∈ l := by
|
||||
simp [mem_def]
|
||||
|
||||
@[simp] theorem push_toArray (l : List α) (a : α) : l.toArray.push a = (l ++ [a]).toArray := by
|
||||
apply ext'
|
||||
simp
|
||||
@@ -583,7 +621,20 @@ theorem getElem?_ofFn (f : Fin n → α) (i : Nat) :
|
||||
(ofFn f)[i]? = if h : i < n then some (f ⟨i, h⟩) else none := by
|
||||
simp [getElem?_def]
|
||||
|
||||
/-- # mkArray -/
|
||||
@[simp] theorem ofFn_zero (f : Fin 0 → α) : ofFn f = #[] := rfl
|
||||
|
||||
theorem ofFn_succ (f : Fin (n+1) → α) :
|
||||
ofFn f = (ofFn (fun (i : Fin n) => f i.castSucc)).push (f ⟨n, by omega⟩) := by
|
||||
ext i h₁ h₂
|
||||
· simp
|
||||
· simp [getElem_push]
|
||||
split <;> rename_i h₃
|
||||
· rfl
|
||||
· congr
|
||||
simp at h₁ h₂
|
||||
omega
|
||||
|
||||
/-! # mkArray -/
|
||||
|
||||
@[simp] theorem size_mkArray (n : Nat) (v : α) : (mkArray n v).size = n :=
|
||||
List.length_replicate ..
|
||||
@@ -599,25 +650,12 @@ theorem getElem?_mkArray (n : Nat) (v : α) (i : Nat) :
|
||||
(mkArray n v)[i]? = if i < n then some v else none := by
|
||||
simp [getElem?_def]
|
||||
|
||||
/-- # mem -/
|
||||
/-! # mem -/
|
||||
|
||||
@[simp] theorem mem_toList {a : α} {l : Array α} : a ∈ l.toList ↔ a ∈ l := mem_def.symm
|
||||
|
||||
theorem not_mem_nil (a : α) : ¬ a ∈ #[] := nofun
|
||||
|
||||
theorem getElem_of_mem {a : α} {as : Array α} :
|
||||
a ∈ as → (∃ (n : Nat) (h : n < as.size), as[n]'h = a) := by
|
||||
intro ha
|
||||
rcases List.getElem_of_mem ha.val with ⟨i, hbound, hi⟩
|
||||
exists i
|
||||
exists hbound
|
||||
|
||||
theorem getElem?_of_mem {a : α} {as : Array α} :
|
||||
a ∈ as → ∃ (n : Nat), as[n]? = some a := by
|
||||
intro ha
|
||||
rcases List.getElem?_of_mem ha.val with ⟨i, hi⟩
|
||||
exists i
|
||||
|
||||
@[simp] theorem mem_dite_empty_left {x : α} [Decidable p] {l : ¬ p → Array α} :
|
||||
(x ∈ if h : p then #[] else l h) ↔ ∃ h : ¬ p, x ∈ l h := by
|
||||
split <;> simp_all
|
||||
@@ -634,7 +672,7 @@ theorem getElem?_of_mem {a : α} {as : Array α} :
|
||||
(x ∈ if p then l else #[]) ↔ p ∧ x ∈ l := by
|
||||
split <;> simp_all
|
||||
|
||||
/-- # get lemmas -/
|
||||
/-! # get lemmas -/
|
||||
|
||||
theorem lt_of_getElem {x : α} {a : Array α} {idx : Nat} {hidx : idx < a.size} (_ : a[idx] = x) :
|
||||
idx < a.size :=
|
||||
@@ -659,10 +697,6 @@ theorem get?_eq_get?_toList (a : Array α) (i : Nat) : a.get? i = a.toList.get?
|
||||
theorem get!_eq_get? [Inhabited α] (a : Array α) : a.get! n = (a.get? n).getD default := by
|
||||
simp only [get!_eq_getElem?, get?_eq_getElem?]
|
||||
|
||||
theorem getElem?_eq_some_iff {as : Array α} : as[n]? = some a ↔ ∃ h : n < as.size, as[n] = a := by
|
||||
cases as
|
||||
simp [List.getElem?_eq_some_iff]
|
||||
|
||||
theorem back!_eq_back? [Inhabited α] (a : Array α) : a.back! = a.back?.getD default := by
|
||||
simp only [back!, get!_eq_getElem?, get?_eq_getElem?, back?]
|
||||
|
||||
@@ -672,6 +706,10 @@ theorem back!_eq_back? [Inhabited α] (a : Array α) : a.back! = a.back?.getD de
|
||||
@[simp] theorem back!_push [Inhabited α] (a : Array α) : (a.push x).back! = x := by
|
||||
simp [back!_eq_back?]
|
||||
|
||||
theorem mem_of_back?_eq_some {xs : Array α} {a : α} (h : xs.back? = some a) : a ∈ xs := by
|
||||
cases xs
|
||||
simpa using List.mem_of_getLast?_eq_some (by simpa using h)
|
||||
|
||||
theorem getElem?_push_lt (a : Array α) (x : α) (i : Nat) (h : i < a.size) :
|
||||
(a.push x)[i]? = some a[i] := by
|
||||
rw [getElem?_pos, getElem_push_lt]
|
||||
@@ -815,16 +853,10 @@ theorem size_eq_length_toList (as : Array α) : as.size = as.toList.length := rf
|
||||
simp only [reverse]; split <;> simp [go]
|
||||
|
||||
@[simp] theorem size_range {n : Nat} : (range n).size = n := by
|
||||
unfold range
|
||||
induction n with
|
||||
| zero => simp [Nat.fold]
|
||||
| succ k ih =>
|
||||
rw [Nat.fold, flip]
|
||||
simp only [mkEmpty_eq, size_push] at *
|
||||
omega
|
||||
induction n <;> simp [range]
|
||||
|
||||
@[simp] theorem toList_range (n : Nat) : (range n).toList = List.range n := by
|
||||
induction n <;> simp_all [range, Nat.fold, flip, List.range_succ]
|
||||
apply List.ext_getElem <;> simp [range]
|
||||
|
||||
@[simp]
|
||||
theorem getElem_range {n : Nat} {x : Nat} (h : x < (Array.range n).size) : (Array.range n)[x] = x := by
|
||||
@@ -1025,6 +1057,10 @@ theorem foldr_congr {as bs : Array α} (h₀ : as = bs) {f g : α → β → β}
|
||||
@[simp] theorem mem_map {f : α → β} {l : Array α} : b ∈ l.map f ↔ ∃ a, a ∈ l ∧ f a = b := by
|
||||
simp only [mem_def, toList_map, List.mem_map]
|
||||
|
||||
theorem exists_of_mem_map (h : b ∈ map f l) : ∃ a, a ∈ l ∧ f a = b := mem_map.1 h
|
||||
|
||||
theorem mem_map_of_mem (f : α → β) (h : a ∈ l) : f a ∈ map f l := mem_map.2 ⟨_, h, rfl⟩
|
||||
|
||||
theorem mapM_eq_mapM_toList [Monad m] [LawfulMonad m] (f : α → m β) (arr : Array α) :
|
||||
arr.mapM f = List.toArray <$> (arr.toList.mapM f) := by
|
||||
rw [mapM_eq_foldlM, ← foldlM_toList, ← List.foldrM_reverse]
|
||||
@@ -1215,6 +1251,12 @@ theorem push_eq_append_singleton (as : Array α) (x) : as.push x = as ++ #[x] :=
|
||||
@[simp] theorem mem_append {a : α} {s t : Array α} : a ∈ s ++ t ↔ a ∈ s ∨ a ∈ t := by
|
||||
simp only [mem_def, toList_append, List.mem_append]
|
||||
|
||||
theorem mem_append_left {a : α} {l₁ : Array α} (l₂ : Array α) (h : a ∈ l₁) : a ∈ l₁ ++ l₂ :=
|
||||
mem_append.2 (Or.inl h)
|
||||
|
||||
theorem mem_append_right {a : α} (l₁ : Array α) {l₂ : Array α} (h : a ∈ l₂) : a ∈ l₁ ++ l₂ :=
|
||||
mem_append.2 (Or.inr h)
|
||||
|
||||
@[simp] theorem size_append (as bs : Array α) : (as ++ bs).size = as.size + bs.size := by
|
||||
simp only [size, toList_append, List.length_append]
|
||||
|
||||
@@ -1602,9 +1644,9 @@ theorem swap_comm (a : Array α) {i j : Fin a.size} : a.swap i j = a.swap j i :=
|
||||
|
||||
/-! ### eraseIdx -/
|
||||
|
||||
theorem feraseIdx_eq_eraseIdx {a : Array α} {i : Fin a.size} :
|
||||
a.feraseIdx i = a.eraseIdx i.1 := by
|
||||
simp [eraseIdx]
|
||||
theorem eraseIdx_eq_eraseIdxIfInBounds {a : Array α} {i : Nat} (h : i < a.size) :
|
||||
a.eraseIdx i h = a.eraseIdxIfInBounds i := by
|
||||
simp [eraseIdxIfInBounds, h]
|
||||
|
||||
/-! ### isPrefixOf -/
|
||||
|
||||
@@ -1834,16 +1876,15 @@ theorem takeWhile_go_toArray (p : α → Bool) (l : List α) (i : Nat) :
|
||||
l.toArray.takeWhile p = (l.takeWhile p).toArray := by
|
||||
simp [Array.takeWhile, takeWhile_go_toArray]
|
||||
|
||||
@[simp] theorem feraseIdx_toArray (l : List α) (i : Fin l.toArray.size) :
|
||||
l.toArray.feraseIdx i = (l.eraseIdx i).toArray := by
|
||||
rw [feraseIdx]
|
||||
split <;> rename_i h
|
||||
· rw [feraseIdx_toArray]
|
||||
@[simp] theorem eraseIdx_toArray (l : List α) (i : Nat) (h : i < l.toArray.size) :
|
||||
l.toArray.eraseIdx i h = (l.eraseIdx i).toArray := by
|
||||
rw [Array.eraseIdx]
|
||||
split <;> rename_i h'
|
||||
· rw [eraseIdx_toArray]
|
||||
simp only [swap_toArray, Fin.getElem_fin, toList_toArray, mk.injEq]
|
||||
rw [eraseIdx_set_gt (by simp), eraseIdx_set_eq]
|
||||
simp
|
||||
· rcases i with ⟨i, w⟩
|
||||
simp at h w
|
||||
· simp at h h'
|
||||
have t : i = l.length - 1 := by omega
|
||||
simp [t]
|
||||
termination_by l.length - i
|
||||
@@ -1853,9 +1894,9 @@ decreasing_by
|
||||
simp
|
||||
omega
|
||||
|
||||
@[simp] theorem eraseIdx_toArray (l : List α) (i : Nat) :
|
||||
l.toArray.eraseIdx i = (l.eraseIdx i).toArray := by
|
||||
rw [Array.eraseIdx]
|
||||
@[simp] theorem eraseIdxIfInBounds_toArray (l : List α) (i : Nat) :
|
||||
l.toArray.eraseIdxIfInBounds i = (l.eraseIdx i).toArray := by
|
||||
rw [Array.eraseIdxIfInBounds]
|
||||
split
|
||||
· simp
|
||||
· simp_all [eraseIdx_eq_self.2]
|
||||
@@ -1874,13 +1915,13 @@ namespace Array
|
||||
(as.takeWhile p).toList = as.toList.takeWhile p := by
|
||||
induction as; simp
|
||||
|
||||
@[simp] theorem toList_feraseIdx (as : Array α) (i : Fin as.size) :
|
||||
(as.feraseIdx i).toList = as.toList.eraseIdx i.1 := by
|
||||
@[simp] theorem toList_eraseIdx (as : Array α) (i : Nat) (h : i < as.size) :
|
||||
(as.eraseIdx i h).toList = as.toList.eraseIdx i := by
|
||||
induction as
|
||||
simp
|
||||
|
||||
@[simp] theorem toList_eraseIdx (as : Array α) (i : Nat) :
|
||||
(as.eraseIdx i).toList = as.toList.eraseIdx i := by
|
||||
@[simp] theorem toList_eraseIdxIfInBounds (as : Array α) (i : Nat) :
|
||||
(as.eraseIdxIfInBounds i).toList = as.toList.eraseIdx i := by
|
||||
induction as
|
||||
simp
|
||||
|
||||
@@ -1914,6 +1955,26 @@ theorem array_array_induction (P : Array (Array α) → Prop) (h : ∀ (xss : Li
|
||||
specialize h (ass.toList.map toList)
|
||||
simpa [← toList_map, Function.comp_def, map_id] using h
|
||||
|
||||
theorem foldl_map (f : β₁ → β₂) (g : α → β₂ → α) (l : Array β₁) (init : α) :
|
||||
(l.map f).foldl g init = l.foldl (fun x y => g x (f y)) init := by
|
||||
cases l; simp [List.foldl_map]
|
||||
|
||||
theorem foldr_map (f : α₁ → α₂) (g : α₂ → β → β) (l : Array α₁) (init : β) :
|
||||
(l.map f).foldr g init = l.foldr (fun x y => g (f x) y) init := by
|
||||
cases l; simp [List.foldr_map]
|
||||
|
||||
theorem foldl_filterMap (f : α → Option β) (g : γ → β → γ) (l : Array α) (init : γ) :
|
||||
(l.filterMap f).foldl g init = l.foldl (fun x y => match f y with | some b => g x b | none => x) init := by
|
||||
cases l
|
||||
simp [List.foldl_filterMap]
|
||||
rfl
|
||||
|
||||
theorem foldr_filterMap (f : α → Option β) (g : β → γ → γ) (l : Array α) (init : γ) :
|
||||
(l.filterMap f).foldr g init = l.foldr (fun x y => match f x with | some b => g b y | none => y) init := by
|
||||
cases l
|
||||
simp [List.foldr_filterMap]
|
||||
rfl
|
||||
|
||||
/-! ### flatten -/
|
||||
|
||||
@[simp] theorem flatten_empty : flatten (#[] : Array (Array α)) = #[] := rfl
|
||||
@@ -1928,6 +1989,12 @@ theorem array_array_induction (P : Array (Array α) → Prop) (h : ∀ (xss : Li
|
||||
| nil => simp
|
||||
| cons xs xss ih => simp [ih]
|
||||
|
||||
/-! ### reverse -/
|
||||
|
||||
@[simp] theorem mem_reverse {x : α} {as : Array α} : x ∈ as.reverse ↔ x ∈ as := by
|
||||
cases as
|
||||
simp
|
||||
|
||||
/-! ### findSomeRevM?, findRevM?, findSomeRev?, findRev? -/
|
||||
|
||||
@[simp] theorem findSomeRevM?_eq_findSomeM?_reverse
|
||||
|
||||
@@ -346,6 +346,10 @@ theorem getMsbD_sub {i : Nat} {i_lt : i < w} {x y : BitVec w} :
|
||||
· rfl
|
||||
· omega
|
||||
|
||||
theorem getElem_sub {i : Nat} {x y : BitVec w} (h : i < w) :
|
||||
(x - y)[i] = (x[i] ^^ ((~~~y + 1#w)[i] ^^ carry i x (~~~y + 1#w) false)) := by
|
||||
simp [← getLsbD_eq_getElem, getLsbD_sub, h]
|
||||
|
||||
theorem msb_sub {x y: BitVec w} :
|
||||
(x - y).msb
|
||||
= (x.msb ^^ ((~~~y + 1#w).msb ^^ carry (w - 1 - 0) x (~~~y + 1#w) false)) := by
|
||||
@@ -410,6 +414,10 @@ theorem getLsbD_neg {i : Nat} {x : BitVec w} :
|
||||
· have h_ge : w ≤ i := by omega
|
||||
simp [getLsbD_ge _ _ h_ge, h_ge, hi]
|
||||
|
||||
theorem getElem_neg {i : Nat} {x : BitVec w} (h : i < w) :
|
||||
(-x)[i] = (x[i] ^^ decide (∃ j < i, x.getLsbD j = true)) := by
|
||||
simp [← getLsbD_eq_getElem, getLsbD_neg, h]
|
||||
|
||||
theorem getMsbD_neg {i : Nat} {x : BitVec w} :
|
||||
getMsbD (-x) i =
|
||||
(getMsbD x i ^^ decide (∃ j < w, i < j ∧ getMsbD x j = true)) := by
|
||||
|
||||
@@ -269,6 +269,10 @@ theorem ofBool_eq_iff_eq : ∀ {b b' : Bool}, BitVec.ofBool b = BitVec.ofBool b'
|
||||
getLsbD (x#'lt) i = x.testBit i := by
|
||||
simp [getLsbD, BitVec.ofNatLt]
|
||||
|
||||
@[simp] theorem getMsbD_ofNatLt {n x i : Nat} (h : x < 2^n) :
|
||||
getMsbD (x#'h) i = (decide (i < n) && x.testBit (n - 1 - i)) := by
|
||||
simp [getMsbD, getLsbD]
|
||||
|
||||
@[simp, bv_toNat] theorem toNat_ofNat (x w : Nat) : (BitVec.ofNat w x).toNat = x % 2^w := by
|
||||
simp [BitVec.toNat, BitVec.ofNat, Fin.ofNat']
|
||||
|
||||
@@ -561,6 +565,10 @@ theorem zeroExtend_eq_setWidth {v : Nat} {x : BitVec w} :
|
||||
else
|
||||
simp [n_le_i, toNat_ofNat]
|
||||
|
||||
@[simp] theorem toInt_setWidth (x : BitVec w) :
|
||||
(x.setWidth v).toInt = Int.bmod x.toNat (2^v) := by
|
||||
simp [toInt_eq_toNat_bmod, toNat_setWidth, Int.emod_bmod]
|
||||
|
||||
theorem setWidth'_eq {x : BitVec w} (h : w ≤ v) : x.setWidth' h = x.setWidth v := by
|
||||
apply eq_of_toNat_eq
|
||||
rw [toNat_setWidth, toNat_setWidth']
|
||||
@@ -755,6 +763,10 @@ theorem extractLsb'_eq_extractLsb {w : Nat} (x : BitVec w) (start len : Nat) (h
|
||||
@[simp] theorem getLsbD_allOnes : (allOnes v).getLsbD i = decide (i < v) := by
|
||||
simp [allOnes]
|
||||
|
||||
@[simp] theorem getMsbD_allOnes : (allOnes v).getMsbD i = decide (i < v) := by
|
||||
simp [allOnes]
|
||||
omega
|
||||
|
||||
@[simp] theorem getElem_allOnes (i : Nat) (h : i < v) : (allOnes v)[i] = true := by
|
||||
simp [getElem_eq_testBit_toNat, h]
|
||||
|
||||
@@ -772,6 +784,12 @@ theorem extractLsb'_eq_extractLsb {w : Nat} (x : BitVec w) (start len : Nat) (h
|
||||
@[simp] theorem toNat_or (x y : BitVec v) :
|
||||
BitVec.toNat (x ||| y) = BitVec.toNat x ||| BitVec.toNat y := rfl
|
||||
|
||||
@[simp] theorem toInt_or (x y : BitVec w) :
|
||||
BitVec.toInt (x ||| y) = Int.bmod (BitVec.toNat x ||| BitVec.toNat y) (2^w) := by
|
||||
rw_mod_cast [Int.bmod_def, BitVec.toInt, toNat_or, Nat.mod_eq_of_lt
|
||||
(Nat.or_lt_two_pow (BitVec.isLt x) (BitVec.isLt y))]
|
||||
omega
|
||||
|
||||
@[simp] theorem toFin_or (x y : BitVec v) :
|
||||
BitVec.toFin (x ||| y) = BitVec.toFin x ||| BitVec.toFin y := by
|
||||
apply Fin.eq_of_val_eq
|
||||
@@ -839,6 +857,12 @@ instance : Std.LawfulCommIdentity (α := BitVec n) (· ||| · ) (0#n) where
|
||||
@[simp] theorem toNat_and (x y : BitVec v) :
|
||||
BitVec.toNat (x &&& y) = BitVec.toNat x &&& BitVec.toNat y := rfl
|
||||
|
||||
@[simp] theorem toInt_and (x y : BitVec w) :
|
||||
BitVec.toInt (x &&& y) = Int.bmod (BitVec.toNat x &&& BitVec.toNat y) (2^w) := by
|
||||
rw_mod_cast [Int.bmod_def, BitVec.toInt, toNat_and, Nat.mod_eq_of_lt
|
||||
(Nat.and_lt_two_pow x.toNat (BitVec.isLt y))]
|
||||
omega
|
||||
|
||||
@[simp] theorem toFin_and (x y : BitVec v) :
|
||||
BitVec.toFin (x &&& y) = BitVec.toFin x &&& BitVec.toFin y := by
|
||||
apply Fin.eq_of_val_eq
|
||||
@@ -906,6 +930,12 @@ instance : Std.LawfulCommIdentity (α := BitVec n) (· &&& · ) (allOnes n) wher
|
||||
@[simp] theorem toNat_xor (x y : BitVec v) :
|
||||
BitVec.toNat (x ^^^ y) = BitVec.toNat x ^^^ BitVec.toNat y := rfl
|
||||
|
||||
@[simp] theorem toInt_xor (x y : BitVec w) :
|
||||
BitVec.toInt (x ^^^ y) = Int.bmod (BitVec.toNat x ^^^ BitVec.toNat y) (2^w) := by
|
||||
rw_mod_cast [Int.bmod_def, BitVec.toInt, toNat_xor, Nat.mod_eq_of_lt
|
||||
(Nat.xor_lt_two_pow (BitVec.isLt x) (BitVec.isLt y))]
|
||||
omega
|
||||
|
||||
@[simp] theorem toFin_xor (x y : BitVec v) :
|
||||
BitVec.toFin (x ^^^ y) = BitVec.toFin x ^^^ BitVec.toFin y := by
|
||||
apply Fin.eq_of_val_eq
|
||||
@@ -983,6 +1013,13 @@ theorem not_def {x : BitVec v} : ~~~x = allOnes v ^^^ x := rfl
|
||||
_ ≤ 2 ^ i := Nat.pow_le_pow_of_le_right Nat.zero_lt_two w
|
||||
· simp
|
||||
|
||||
@[simp] theorem toInt_not {x : BitVec w} :
|
||||
(~~~x).toInt = Int.bmod (2^w - 1 - x.toNat) (2^w) := by
|
||||
rw_mod_cast [BitVec.toInt, BitVec.toNat_not, Int.bmod_def]
|
||||
simp [show ((2^w : Nat) : Int) - 1 - x.toNat = ((2^w - 1 - x.toNat) : Nat) by omega]
|
||||
rw_mod_cast [Nat.mod_eq_of_lt (by omega)]
|
||||
omega
|
||||
|
||||
@[simp] theorem ofInt_negSucc_eq_not_ofNat {w n : Nat} :
|
||||
BitVec.ofInt w (Int.negSucc n) = ~~~.ofNat w n := by
|
||||
simp only [BitVec.ofInt, Int.toNat, Int.ofNat_eq_coe, toNat_eq, toNat_ofNatLt, toNat_not,
|
||||
@@ -1007,6 +1044,10 @@ theorem not_def {x : BitVec v} : ~~~x = allOnes v ^^^ x := rfl
|
||||
@[simp] theorem getLsbD_not {x : BitVec v} : (~~~x).getLsbD i = (decide (i < v) && ! x.getLsbD i) := by
|
||||
by_cases h' : i < v <;> simp_all [not_def]
|
||||
|
||||
@[simp] theorem getMsbD_not {x : BitVec v} :
|
||||
(~~~x).getMsbD i = (decide (i < v) && ! x.getMsbD i) := by
|
||||
by_cases h' : i < v <;> simp_all [not_def]
|
||||
|
||||
@[simp] theorem getElem_not {x : BitVec w} {i : Nat} (h : i < w) : (~~~x)[i] = !x[i] := by
|
||||
simp only [getElem_eq_testBit_toNat, toNat_not]
|
||||
rw [← Nat.sub_add_eq, Nat.add_comm 1]
|
||||
@@ -1480,6 +1521,12 @@ theorem getLsbD_sshiftRight' {x y: BitVec w} {i : Nat} :
|
||||
(!decide (w ≤ i) && if y.toNat + i < w then x.getLsbD (y.toNat + i) else x.msb) := by
|
||||
simp only [BitVec.sshiftRight', BitVec.getLsbD_sshiftRight]
|
||||
|
||||
@[simp]
|
||||
theorem getElem_sshiftRight' {x y : BitVec w} {i : Nat} (h : i < w) :
|
||||
(x.sshiftRight' y)[i] =
|
||||
(!decide (w ≤ i) && if y.toNat + i < w then x.getLsbD (y.toNat + i) else x.msb) := by
|
||||
simp only [← getLsbD_eq_getElem, BitVec.sshiftRight', BitVec.getLsbD_sshiftRight]
|
||||
|
||||
@[simp]
|
||||
theorem getMsbD_sshiftRight' {x y: BitVec w} {i : Nat} :
|
||||
(x.sshiftRight y.toNat).getMsbD i = (decide (i < w) && if i < y.toNat then x.msb else x.getMsbD (i - y.toNat)) := by
|
||||
@@ -1572,6 +1619,79 @@ theorem signExtend_eq_setWidth_of_lt (x : BitVec w) {v : Nat} (hv : v ≤ w):
|
||||
theorem signExtend_eq (x : BitVec w) : x.signExtend w = x := by
|
||||
rw [signExtend_eq_setWidth_of_lt _ (Nat.le_refl _), setWidth_eq]
|
||||
|
||||
/-- Sign extending to a larger bitwidth depends on the msb.
|
||||
If the msb is false, then the result equals the original value.
|
||||
If the msb is true, then we add a value of `(2^v - 2^w)`, which arises from the sign extension. -/
|
||||
theorem toNat_signExtend_of_le (x : BitVec w) {v : Nat} (hv : w ≤ v) :
|
||||
(x.signExtend v).toNat = x.toNat + if x.msb then 2^v - 2^w else 0 := by
|
||||
apply Nat.eq_of_testBit_eq
|
||||
intro i
|
||||
have ⟨k, hk⟩ := Nat.exists_eq_add_of_le hv
|
||||
rw [hk, testBit_toNat, getLsbD_signExtend, Nat.pow_add, ← Nat.mul_sub_one, Nat.add_comm (x.toNat)]
|
||||
by_cases hx : x.msb
|
||||
· simp [hx, Nat.testBit_mul_pow_two_add _ x.isLt, testBit_toNat]
|
||||
-- Case analysis on i being in the intervals [0..w), [w..w + k), [w+k..∞)
|
||||
have hi : i < w ∨ (w ≤ i ∧ i < w + k) ∨ w + k ≤ i := by omega
|
||||
rcases hi with hi | hi | hi
|
||||
· simp [hi]; omega
|
||||
· simp [hi]; omega
|
||||
· simp [hi, show ¬ (i < w + k) by omega, show ¬ (i < w) by omega]
|
||||
omega
|
||||
· simp [hx, Nat.testBit_mul_pow_two_add _ x.isLt, testBit_toNat]
|
||||
have hi : i < w ∨ (w ≤ i ∧ i < w + k) ∨ w + k ≤ i := by omega
|
||||
rcases hi with hi | hi | hi
|
||||
· simp [hi]; omega
|
||||
· simp [hi]
|
||||
· simp [hi, show ¬ (i < w + k) by omega, show ¬ (i < w) by omega, getLsbD_ge x i (by omega)]
|
||||
|
||||
/-- Sign extending to a larger bitwidth depends on the msb.
|
||||
If the msb is false, then the result equals the original value.
|
||||
If the msb is true, then we add a value of `(2^v - 2^w)`, which arises from the sign extension. -/
|
||||
theorem toNat_signExtend (x : BitVec w) {v : Nat} :
|
||||
(x.signExtend v).toNat = (x.setWidth v).toNat + if x.msb then 2^v - 2^w else 0 := by
|
||||
by_cases h : v ≤ w
|
||||
· have : 2^v ≤ 2^w := Nat.pow_le_pow_of_le_right Nat.two_pos h
|
||||
simp [signExtend_eq_setWidth_of_lt x h, toNat_setWidth, Nat.sub_eq_zero_of_le this]
|
||||
· have : 2^w ≤ 2^v := Nat.pow_le_pow_of_le_right Nat.two_pos (by omega)
|
||||
rw [toNat_signExtend_of_le x (by omega), toNat_setWidth, Nat.mod_eq_of_lt (by omega)]
|
||||
|
||||
/-
|
||||
If the current width `w` is smaller than the extended width `v`,
|
||||
then the value when interpreted as an integer does not change.
|
||||
-/
|
||||
theorem toInt_signExtend_of_lt {x : BitVec w} (hv : w < v):
|
||||
(x.signExtend v).toInt = x.toInt := by
|
||||
simp only [toInt_eq_msb_cond, toNat_signExtend]
|
||||
have : (x.signExtend v).msb = x.msb := by
|
||||
rw [msb_eq_getLsbD_last, getLsbD_eq_getElem (Nat.sub_one_lt_of_lt hv)]
|
||||
simp [getElem_signExtend, Nat.le_sub_one_of_lt hv]
|
||||
have H : 2^w ≤ 2^v := Nat.pow_le_pow_of_le_right (by omega) (by omega)
|
||||
simp only [this, toNat_setWidth, Int.natCast_add, Int.ofNat_emod, Int.natCast_mul]
|
||||
by_cases h : x.msb
|
||||
<;> norm_cast
|
||||
<;> simp [h, Nat.mod_eq_of_lt (Nat.lt_of_lt_of_le x.isLt H)]
|
||||
omega
|
||||
|
||||
/-
|
||||
If the current width `w` is larger than the extended width `v`,
|
||||
then the value when interpreted as an integer is truncated,
|
||||
and we compute a modulo by `2^v`.
|
||||
-/
|
||||
theorem toInt_signExtend_of_le {x : BitVec w} (hv : v ≤ w) :
|
||||
(x.signExtend v).toInt = Int.bmod x.toNat (2^v) := by
|
||||
simp [signExtend_eq_setWidth_of_lt _ hv]
|
||||
|
||||
/-
|
||||
Interpreting the sign extension of `(x : BitVec w)` to width `v`
|
||||
computes `x % 2^v` (where `%` is the balanced mod).
|
||||
-/
|
||||
theorem toInt_signExtend (x : BitVec w) :
|
||||
(x.signExtend v).toInt = Int.bmod x.toNat (2^(min v w)) := by
|
||||
by_cases hv : v ≤ w
|
||||
· simp [toInt_signExtend_of_le hv, Nat.min_eq_left hv]
|
||||
· simp only [Nat.not_le] at hv
|
||||
rw [toInt_signExtend_of_lt hv, Nat.min_eq_right (by omega), toInt_eq_toNat_bmod]
|
||||
|
||||
/-! ### append -/
|
||||
|
||||
theorem append_def (x : BitVec v) (y : BitVec w) :
|
||||
@@ -2611,7 +2731,7 @@ theorem getLsbD_rotateLeftAux_of_geq {x : BitVec w} {r : Nat} {i : Nat} (hi : i
|
||||
apply getLsbD_ge
|
||||
omega
|
||||
|
||||
/-- When `r < w`, we give a formula for `(x.rotateRight r).getLsbD i`. -/
|
||||
/-- When `r < w`, we give a formula for `(x.rotateLeft r).getLsbD i`. -/
|
||||
theorem getLsbD_rotateLeft_of_le {x : BitVec w} {r i : Nat} (hr: r < w) :
|
||||
(x.rotateLeft r).getLsbD i =
|
||||
cond (i < r)
|
||||
@@ -2638,6 +2758,56 @@ theorem getElem_rotateLeft {x : BitVec w} {r i : Nat} (h : i < w) :
|
||||
if h' : i < r % w then x[(w - (r % w) + i)] else x[i - (r % w)] := by
|
||||
simp [← BitVec.getLsbD_eq_getElem, h]
|
||||
|
||||
/-- If `w ≤ x < 2 * w`, then `x % w = x - w` -/
|
||||
theorem mod_eq_sub_of_le_of_lt {x w : Nat} (x_le : w ≤ x) (x_lt : x < 2 * w) :
|
||||
x % w = x - w := by
|
||||
rw [Nat.mod_eq_sub_mod, Nat.mod_eq_of_lt (by omega)]
|
||||
omega
|
||||
|
||||
theorem getMsbD_rotateLeftAux_of_lt {x : BitVec w} {r : Nat} {i : Nat} (hi : i < w - r) :
|
||||
(x.rotateLeftAux r).getMsbD i = x.getMsbD (r + i) := by
|
||||
rw [rotateLeftAux, getMsbD_or]
|
||||
simp [show i < w - r by omega, Nat.add_comm]
|
||||
|
||||
theorem getMsbD_rotateLeftAux_of_ge {x : BitVec w} {r : Nat} {i : Nat} (hi : i ≥ w - r) :
|
||||
(x.rotateLeftAux r).getMsbD i = (decide (i < w) && x.getMsbD (i - (w - r))) := by
|
||||
simp [rotateLeftAux, getMsbD_or, show i + r ≥ w by omega, show ¬i < w - r by omega]
|
||||
|
||||
/-- When `r < w`, we give a formula for `(x.rotateLeft r).getMsbD i`. -/
|
||||
theorem getMsbD_rotateLeft_of_lt {n w : Nat} {x : BitVec w} (hi : r < w):
|
||||
(x.rotateLeft r).getMsbD n = (decide (n < w) && x.getMsbD ((r + n) % w)) := by
|
||||
rcases w with rfl | w
|
||||
· simp
|
||||
· rw [BitVec.rotateLeft_eq_rotateLeftAux_of_lt (by omega)]
|
||||
by_cases h : n < (w + 1) - r
|
||||
· simp [getMsbD_rotateLeftAux_of_lt h, Nat.mod_eq_of_lt, show r + n < (w + 1) by omega, show n < w + 1 by omega]
|
||||
· simp [getMsbD_rotateLeftAux_of_ge <| Nat.ge_of_not_lt h]
|
||||
by_cases h₁ : n < w + 1
|
||||
· simp only [h₁, decide_true, Bool.true_and]
|
||||
have h₂ : (r + n) < 2 * (w + 1) := by omega
|
||||
rw [mod_eq_sub_of_le_of_lt (by omega) (by omega)]
|
||||
congr 1
|
||||
omega
|
||||
· simp [h₁]
|
||||
|
||||
theorem getMsbD_rotateLeft {r n w : Nat} {x : BitVec w} :
|
||||
(x.rotateLeft r).getMsbD n = (decide (n < w) && x.getMsbD ((r + n) % w)) := by
|
||||
rcases w with rfl | w
|
||||
· simp
|
||||
· by_cases h : r < w
|
||||
· rw [getMsbD_rotateLeft_of_lt (by omega)]
|
||||
· rw [← rotateLeft_mod_eq_rotateLeft, getMsbD_rotateLeft_of_lt (by apply Nat.mod_lt; simp)]
|
||||
simp
|
||||
|
||||
@[simp]
|
||||
theorem msb_rotateLeft {m w : Nat} {x : BitVec w} :
|
||||
(x.rotateLeft m).msb = x.getMsbD (m % w) := by
|
||||
simp only [BitVec.msb, getMsbD_rotateLeft]
|
||||
by_cases h : w = 0
|
||||
· simp [h]
|
||||
· simp
|
||||
omega
|
||||
|
||||
/-! ## Rotate Right -/
|
||||
|
||||
/--
|
||||
@@ -2699,7 +2869,7 @@ theorem rotateRight_mod_eq_rotateRight {x : BitVec w} {r : Nat} :
|
||||
simp only [rotateRight, Nat.mod_mod]
|
||||
|
||||
/-- When `r < w`, we give a formula for `(x.rotateRight r).getLsb i`. -/
|
||||
theorem getLsbD_rotateRight_of_le {x : BitVec w} {r i : Nat} (hr: r < w) :
|
||||
theorem getLsbD_rotateRight_of_lt {x : BitVec w} {r i : Nat} (hr: r < w) :
|
||||
(x.rotateRight r).getLsbD i =
|
||||
cond (i < w - r)
|
||||
(x.getLsbD (r + i))
|
||||
@@ -2717,7 +2887,7 @@ theorem getLsbD_rotateRight {x : BitVec w} {r i : Nat} :
|
||||
(decide (i < w) && x.getLsbD (i - (w - (r % w)))) := by
|
||||
rcases w with ⟨rfl, w⟩
|
||||
· simp
|
||||
· rw [← rotateRight_mod_eq_rotateRight, getLsbD_rotateRight_of_le (Nat.mod_lt _ (by omega))]
|
||||
· rw [← rotateRight_mod_eq_rotateRight, getLsbD_rotateRight_of_lt (Nat.mod_lt _ (by omega))]
|
||||
|
||||
@[simp]
|
||||
theorem getElem_rotateRight {x : BitVec w} {r i : Nat} (h : i < w) :
|
||||
@@ -2725,6 +2895,56 @@ theorem getElem_rotateRight {x : BitVec w} {r i : Nat} (h : i < w) :
|
||||
simp only [← BitVec.getLsbD_eq_getElem]
|
||||
simp [getLsbD_rotateRight, h]
|
||||
|
||||
theorem getMsbD_rotateRightAux_of_lt {x : BitVec w} {r : Nat} {i : Nat} (hi : i < r) :
|
||||
(x.rotateRightAux r).getMsbD i = x.getMsbD (i + (w - r)) := by
|
||||
rw [rotateRightAux, getMsbD_or, getMsbD_ushiftRight]
|
||||
simp [show i < r by omega]
|
||||
|
||||
theorem getMsbD_rotateRightAux_of_ge {x : BitVec w} {r : Nat} {i : Nat} (hi : i ≥ r) :
|
||||
(x.rotateRightAux r).getMsbD i = (decide (i < w) && x.getMsbD (i - r)) := by
|
||||
simp [rotateRightAux, show ¬ i < r by omega, show i + (w - r) ≥ w by omega]
|
||||
|
||||
/-- When `m < w`, we give a formula for `(x.rotateLeft m).getMsbD i`. -/
|
||||
@[simp]
|
||||
theorem getMsbD_rotateRight_of_lt {w n m : Nat} {x : BitVec w} (hr : m < w):
|
||||
(x.rotateRight m).getMsbD n = (decide (n < w) && (if (n < m % w)
|
||||
then x.getMsbD ((w + n - m % w) % w) else x.getMsbD (n - m % w))):= by
|
||||
rcases w with rfl | w
|
||||
· simp
|
||||
· rw [rotateRight_eq_rotateRightAux_of_lt (by omega)]
|
||||
by_cases h : n < m
|
||||
· simp only [getMsbD_rotateRightAux_of_lt h, show n < w + 1 by omega, decide_true,
|
||||
show m % (w + 1) = m by rw [Nat.mod_eq_of_lt hr], h, ↓reduceIte,
|
||||
show (w + 1 + n - m) < (w + 1) by omega, Nat.mod_eq_of_lt, Bool.true_and]
|
||||
congr 1
|
||||
omega
|
||||
· simp [h, getMsbD_rotateRightAux_of_ge <| Nat.ge_of_not_lt h]
|
||||
by_cases h₁ : n < w + 1
|
||||
· simp [h, h₁, decide_true, Bool.true_and, Nat.mod_eq_of_lt hr]
|
||||
· simp [h₁]
|
||||
|
||||
@[simp]
|
||||
theorem getMsbD_rotateRight {w n m : Nat} {x : BitVec w} :
|
||||
(x.rotateRight m).getMsbD n = (decide (n < w) && (if (n < m % w)
|
||||
then x.getMsbD ((w + n - m % w) % w) else x.getMsbD (n - m % w))):= by
|
||||
rcases w with rfl | w
|
||||
· simp
|
||||
· by_cases h₀ : m < w
|
||||
· rw [getMsbD_rotateRight_of_lt (by omega)]
|
||||
· rw [← rotateRight_mod_eq_rotateRight, getMsbD_rotateRight_of_lt (by apply Nat.mod_lt; simp)]
|
||||
simp
|
||||
|
||||
@[simp]
|
||||
theorem msb_rotateRight {r w : Nat} {x : BitVec w} :
|
||||
(x.rotateRight r).msb = x.getMsbD ((w - r % w) % w) := by
|
||||
simp only [BitVec.msb, getMsbD_rotateRight]
|
||||
by_cases h₀ : 0 < w
|
||||
· simp only [h₀, decide_true, Nat.add_zero, Nat.zero_le, Nat.sub_eq_zero_of_le, Bool.true_and,
|
||||
ite_eq_left_iff, Nat.not_lt, Nat.le_zero_eq]
|
||||
intro h₁
|
||||
simp [h₁]
|
||||
· simp [show w = 0 by omega]
|
||||
|
||||
/- ## twoPow -/
|
||||
|
||||
theorem twoPow_eq (w : Nat) (i : Nat) : twoPow w i = 1#w <<< i := by
|
||||
@@ -3124,7 +3344,11 @@ theorem toNat_abs {x : BitVec w} : x.abs.toNat = if x.msb then 2^w - x.toNat els
|
||||
· simp [h]
|
||||
|
||||
theorem getLsbD_abs {i : Nat} {x : BitVec w} :
|
||||
getLsbD x.abs i = if x.msb then getLsbD (-x) i else getLsbD x i := by
|
||||
getLsbD x.abs i = if x.msb then getLsbD (-x) i else getLsbD x i := by
|
||||
by_cases h : x.msb <;> simp [BitVec.abs, h]
|
||||
|
||||
theorem getElem_abs {i : Nat} {x : BitVec w} (h : i < w) :
|
||||
x.abs[i] = if x.msb then (-x)[i] else x[i] := by
|
||||
by_cases h : x.msb <;> simp [BitVec.abs, h]
|
||||
|
||||
theorem getMsbD_abs {i : Nat} {x : BitVec w} :
|
||||
|
||||
@@ -108,8 +108,18 @@ def toList (bs : ByteArray) : List UInt8 :=
|
||||
|
||||
@[inline] def findIdx? (a : ByteArray) (p : UInt8 → Bool) (start := 0) : Option Nat :=
|
||||
let rec @[specialize] loop (i : Nat) :=
|
||||
if i < a.size then
|
||||
if p (a.get! i) then some i else loop (i+1)
|
||||
if h : i < a.size then
|
||||
if p a[i] then some i else loop (i+1)
|
||||
else
|
||||
none
|
||||
termination_by a.size - i
|
||||
decreasing_by decreasing_trivial_pre_omega
|
||||
loop start
|
||||
|
||||
@[inline] def findFinIdx? (a : ByteArray) (p : UInt8 → Bool) (start := 0) : Option (Fin a.size) :=
|
||||
let rec @[specialize] loop (i : Nat) :=
|
||||
if h : i < a.size then
|
||||
if p a[i] then some ⟨i, h⟩ else loop (i+1)
|
||||
else
|
||||
none
|
||||
termination_by a.size - i
|
||||
|
||||
@@ -31,7 +31,7 @@ opaque floatSpec : FloatSpec := {
|
||||
structure Float where
|
||||
val : floatSpec.float
|
||||
|
||||
instance : Inhabited Float := ⟨{ val := floatSpec.val }⟩
|
||||
instance : Nonempty Float := ⟨{ val := floatSpec.val }⟩
|
||||
|
||||
@[extern "lean_float_add"] opaque Float.add : Float → Float → Float
|
||||
@[extern "lean_float_sub"] opaque Float.sub : Float → Float → Float
|
||||
@@ -136,6 +136,9 @@ instance : ToString Float where
|
||||
|
||||
@[extern "lean_uint64_to_float"] opaque UInt64.toFloat (n : UInt64) : Float
|
||||
|
||||
instance : Inhabited Float where
|
||||
default := UInt64.toFloat 0
|
||||
|
||||
instance : Repr Float where
|
||||
reprPrec n prec := if n < UInt64.toFloat 0 then Repr.addAppParen (toString n) prec else toString n
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ namespace List
|
||||
`a : α` satisfying `P`, then `pmap f l h` is essentially the same as `map f l`
|
||||
but is defined only when all members of `l` satisfy `P`, using the proof
|
||||
to apply `f`. -/
|
||||
@[simp] def pmap {P : α → Prop} (f : ∀ a, P a → β) : ∀ l : List α, (H : ∀ a ∈ l, P a) → List β
|
||||
def pmap {P : α → Prop} (f : ∀ a, P a → β) : ∀ l : List α, (H : ∀ a ∈ l, P a) → List β
|
||||
| [], _ => []
|
||||
| a :: l, H => f a (forall_mem_cons.1 H).1 :: pmap f l (forall_mem_cons.1 H).2
|
||||
|
||||
@@ -46,6 +46,11 @@ Unsafe implementation of `attachWith`, taking advantage of the fact that the rep
|
||||
| cons _ L', hL' => congrArg _ <| go L' fun _ hx => hL' (.tail _ hx)
|
||||
exact go L h'
|
||||
|
||||
@[simp] theorem pmap_nil {P : α → Prop} (f : ∀ a, P a → β) : pmap f [] (by simp) = [] := rfl
|
||||
|
||||
@[simp] theorem pmap_cons {P : α → Prop} (f : ∀ a, P a → β) (a : α) (l : List α) (h : ∀ b ∈ a :: l, P b) :
|
||||
pmap f (a :: l) h = f a (forall_mem_cons.1 h).1 :: pmap f l (forall_mem_cons.1 h).2 := rfl
|
||||
|
||||
@[simp] theorem attach_nil : ([] : List α).attach = [] := rfl
|
||||
|
||||
@[simp] theorem attachWith_nil : ([] : List α).attachWith P H = [] := rfl
|
||||
@@ -148,7 +153,7 @@ theorem mem_pmap_of_mem {p : α → Prop} {f : ∀ a, p a → β} {l H} {a} (h :
|
||||
exact ⟨a, h, rfl⟩
|
||||
|
||||
@[simp]
|
||||
theorem length_pmap {p : α → Prop} {f : ∀ a, p a → β} {l H} : length (pmap f l H) = length l := by
|
||||
theorem length_pmap {p : α → Prop} {f : ∀ a, p a → β} {l H} : (pmap f l H).length = l.length := by
|
||||
induction l
|
||||
· rfl
|
||||
· simp only [*, pmap, length]
|
||||
@@ -199,7 +204,7 @@ theorem attachWith_ne_nil_iff {l : List α} {P : α → Prop} {H : ∀ a ∈ l,
|
||||
|
||||
@[simp]
|
||||
theorem getElem?_pmap {p : α → Prop} (f : ∀ a, p a → β) {l : List α} (h : ∀ a ∈ l, p a) (n : Nat) :
|
||||
(pmap f l h)[n]? = Option.pmap f l[n]? fun x H => h x (getElem?_mem H) := by
|
||||
(pmap f l h)[n]? = Option.pmap f l[n]? fun x H => h x (mem_of_getElem? H) := by
|
||||
induction l generalizing n with
|
||||
| nil => simp
|
||||
| cons hd tl hl =>
|
||||
@@ -215,7 +220,7 @@ theorem getElem?_pmap {p : α → Prop} (f : ∀ a, p a → β) {l : List α} (h
|
||||
· simp_all
|
||||
|
||||
theorem get?_pmap {p : α → Prop} (f : ∀ a, p a → β) {l : List α} (h : ∀ a ∈ l, p a) (n : Nat) :
|
||||
get? (pmap f l h) n = Option.pmap f (get? l n) fun x H => h x (get?_mem H) := by
|
||||
get? (pmap f l h) n = Option.pmap f (get? l n) fun x H => h x (mem_of_get? H) := by
|
||||
simp only [get?_eq_getElem?]
|
||||
simp [getElem?_pmap, h]
|
||||
|
||||
@@ -238,18 +243,18 @@ theorem get_pmap {p : α → Prop} (f : ∀ a, p a → β) {l : List α} (h :
|
||||
(hn : n < (pmap f l h).length) :
|
||||
get (pmap f l h) ⟨n, hn⟩ =
|
||||
f (get l ⟨n, @length_pmap _ _ p f l h ▸ hn⟩)
|
||||
(h _ (get_mem l ⟨n, @length_pmap _ _ p f l h ▸ hn⟩)) := by
|
||||
(h _ (getElem_mem (@length_pmap _ _ p f l h ▸ hn))) := by
|
||||
simp only [get_eq_getElem]
|
||||
simp [getElem_pmap]
|
||||
|
||||
@[simp]
|
||||
theorem getElem?_attachWith {xs : List α} {i : Nat} {P : α → Prop} {H : ∀ a ∈ xs, P a} :
|
||||
(xs.attachWith P H)[i]? = xs[i]?.pmap Subtype.mk (fun _ a => H _ (getElem?_mem a)) :=
|
||||
(xs.attachWith P H)[i]? = xs[i]?.pmap Subtype.mk (fun _ a => H _ (mem_of_getElem? a)) :=
|
||||
getElem?_pmap ..
|
||||
|
||||
@[simp]
|
||||
theorem getElem?_attach {xs : List α} {i : Nat} :
|
||||
xs.attach[i]? = xs[i]?.pmap Subtype.mk (fun _ a => getElem?_mem a) :=
|
||||
xs.attach[i]? = xs[i]?.pmap Subtype.mk (fun _ a => mem_of_getElem? a) :=
|
||||
getElem?_attachWith
|
||||
|
||||
@[simp]
|
||||
@@ -333,6 +338,7 @@ This is useful when we need to use `attach` to show termination.
|
||||
|
||||
Unfortunately this can't be applied by `simp` because of the higher order unification problem,
|
||||
and even when rewriting we need to specify the function explicitly.
|
||||
See however `foldl_subtype` below.
|
||||
-/
|
||||
theorem foldl_attach (l : List α) (f : β → α → β) (b : β) :
|
||||
l.attach.foldl (fun acc t => f acc t.1) b = l.foldl f b := by
|
||||
@@ -348,6 +354,7 @@ This is useful when we need to use `attach` to show termination.
|
||||
|
||||
Unfortunately this can't be applied by `simp` because of the higher order unification problem,
|
||||
and even when rewriting we need to specify the function explicitly.
|
||||
See however `foldr_subtype` below.
|
||||
-/
|
||||
theorem foldr_attach (l : List α) (f : α → β → β) (b : β) :
|
||||
l.attach.foldr (fun t acc => f t.1 acc) b = l.foldr f b := by
|
||||
@@ -452,16 +459,16 @@ theorem pmap_append' {p : α → Prop} (f : ∀ a : α, p a → β) (l₁ l₂ :
|
||||
pmap_append f l₁ l₂ _
|
||||
|
||||
@[simp] theorem attach_append (xs ys : List α) :
|
||||
(xs ++ ys).attach = xs.attach.map (fun ⟨x, h⟩ => ⟨x, mem_append_of_mem_left ys h⟩) ++
|
||||
ys.attach.map fun ⟨x, h⟩ => ⟨x, mem_append_of_mem_right xs h⟩ := by
|
||||
(xs ++ ys).attach = xs.attach.map (fun ⟨x, h⟩ => ⟨x, mem_append_left ys h⟩) ++
|
||||
ys.attach.map fun ⟨x, h⟩ => ⟨x, mem_append_right xs h⟩ := by
|
||||
simp only [attach, attachWith, pmap, map_pmap, pmap_append]
|
||||
congr 1 <;>
|
||||
exact pmap_congr_left _ fun _ _ _ _ => rfl
|
||||
|
||||
@[simp] theorem attachWith_append {P : α → Prop} {xs ys : List α}
|
||||
{H : ∀ (a : α), a ∈ xs ++ ys → P a} :
|
||||
(xs ++ ys).attachWith P H = xs.attachWith P (fun a h => H a (mem_append_of_mem_left ys h)) ++
|
||||
ys.attachWith P (fun a h => H a (mem_append_of_mem_right xs h)) := by
|
||||
(xs ++ ys).attachWith P H = xs.attachWith P (fun a h => H a (mem_append_left ys h)) ++
|
||||
ys.attachWith P (fun a h => H a (mem_append_right xs h)) := by
|
||||
simp only [attachWith, attach_append, map_pmap, pmap_append]
|
||||
|
||||
@[simp] theorem pmap_reverse {P : α → Prop} (f : (a : α) → P a → β) (xs : List α)
|
||||
@@ -598,6 +605,15 @@ def unattach {α : Type _} {p : α → Prop} (l : List { x // p x }) := l.map (
|
||||
| nil => simp
|
||||
| cons a l ih => simp [ih, Function.comp_def]
|
||||
|
||||
@[simp] theorem getElem?_unattach {p : α → Prop} {l : List { x // p x }} (i : Nat) :
|
||||
l.unattach[i]? = l[i]?.map Subtype.val := by
|
||||
simp [unattach]
|
||||
|
||||
@[simp] theorem getElem_unattach
|
||||
{p : α → Prop} {l : List { x // p x }} (i : Nat) (h : i < l.unattach.length) :
|
||||
l.unattach[i] = (l[i]'(by simpa using h)).1 := by
|
||||
simp [unattach]
|
||||
|
||||
/-! ### Recognizing higher order functions on subtypes using a function that only depends on the value. -/
|
||||
|
||||
/--
|
||||
|
||||
@@ -726,13 +726,13 @@ theorem elem_eq_true_of_mem [BEq α] [LawfulBEq α] {a : α} {as : List α} (h :
|
||||
instance [BEq α] [LawfulBEq α] (a : α) (as : List α) : Decidable (a ∈ as) :=
|
||||
decidable_of_decidable_of_iff (Iff.intro mem_of_elem_eq_true elem_eq_true_of_mem)
|
||||
|
||||
theorem mem_append_of_mem_left {a : α} {as : List α} (bs : List α) : a ∈ as → a ∈ as ++ bs := by
|
||||
theorem mem_append_left {a : α} {as : List α} (bs : List α) : a ∈ as → a ∈ as ++ bs := by
|
||||
intro h
|
||||
induction h with
|
||||
| head => apply Mem.head
|
||||
| tail => apply Mem.tail; assumption
|
||||
|
||||
theorem mem_append_of_mem_right {b : α} {bs : List α} (as : List α) : b ∈ bs → b ∈ as ++ bs := by
|
||||
theorem mem_append_right {b : α} {bs : List α} (as : List α) : b ∈ bs → b ∈ as ++ bs := by
|
||||
intro h
|
||||
induction as with
|
||||
| nil => simp [h]
|
||||
|
||||
@@ -256,7 +256,7 @@ theorem findM?_eq_findSomeM? [Monad m] [LawfulMonad m] (p : α → m Bool) (as :
|
||||
have : a ∈ as := by
|
||||
have ⟨bs, h⟩ := h
|
||||
subst h
|
||||
exact mem_append_of_mem_right _ (Mem.head ..)
|
||||
exact mem_append_right _ (Mem.head ..)
|
||||
match (← f a this b) with
|
||||
| ForInStep.done b => pure b
|
||||
| ForInStep.yield b =>
|
||||
|
||||
@@ -394,9 +394,9 @@ theorem get?_concat_length (l : List α) (a : α) : (l ++ [a]).get? l.length = s
|
||||
theorem mem_cons_self (a : α) (l : List α) : a ∈ a :: l := .head ..
|
||||
|
||||
theorem mem_concat_self (xs : List α) (a : α) : a ∈ xs ++ [a] :=
|
||||
mem_append_of_mem_right xs (mem_cons_self a _)
|
||||
mem_append_right xs (mem_cons_self a _)
|
||||
|
||||
theorem mem_append_cons_self : a ∈ xs ++ a :: ys := mem_append_of_mem_right _ (mem_cons_self _ _)
|
||||
theorem mem_append_cons_self : a ∈ xs ++ a :: ys := mem_append_right _ (mem_cons_self _ _)
|
||||
|
||||
theorem eq_append_cons_of_mem {a : α} {xs : List α} (h : a ∈ xs) :
|
||||
∃ as bs, xs = as ++ a :: bs ∧ a ∉ as := by
|
||||
@@ -507,12 +507,16 @@ theorem get_mem : ∀ (l : List α) n, get l n ∈ l
|
||||
| _ :: _, ⟨0, _⟩ => .head ..
|
||||
| _ :: l, ⟨_+1, _⟩ => .tail _ (get_mem l ..)
|
||||
|
||||
theorem getElem?_mem {l : List α} {n : Nat} {a : α} (e : l[n]? = some a) : a ∈ l :=
|
||||
theorem mem_of_getElem? {l : List α} {n : Nat} {a : α} (e : l[n]? = some a) : a ∈ l :=
|
||||
let ⟨_, e⟩ := getElem?_eq_some_iff.1 e; e ▸ getElem_mem ..
|
||||
|
||||
theorem get?_mem {l : List α} {n a} (e : l.get? n = some a) : a ∈ l :=
|
||||
@[deprecated mem_of_getElem? (since := "2024-09-06")] abbrev getElem?_mem := @mem_of_getElem?
|
||||
|
||||
theorem mem_of_get? {l : List α} {n a} (e : l.get? n = some a) : a ∈ l :=
|
||||
let ⟨_, e⟩ := get?_eq_some.1 e; e ▸ get_mem ..
|
||||
|
||||
@[deprecated mem_of_get? (since := "2024-09-06")] abbrev get?_mem := @mem_of_get?
|
||||
|
||||
theorem mem_iff_getElem {a} {l : List α} : a ∈ l ↔ ∃ (n : Nat) (h : n < l.length), l[n]'h = a :=
|
||||
⟨getElem_of_mem, fun ⟨_, _, e⟩ => e ▸ getElem_mem ..⟩
|
||||
|
||||
@@ -1997,11 +2001,8 @@ theorem not_mem_append {a : α} {s t : List α} (h₁ : a ∉ s) (h₂ : a ∉ t
|
||||
theorem mem_append_eq (a : α) (s t : List α) : (a ∈ s ++ t) = (a ∈ s ∨ a ∈ t) :=
|
||||
propext mem_append
|
||||
|
||||
theorem mem_append_left {a : α} {l₁ : List α} (l₂ : List α) (h : a ∈ l₁) : a ∈ l₁ ++ l₂ :=
|
||||
mem_append.2 (Or.inl h)
|
||||
|
||||
theorem mem_append_right {a : α} (l₁ : List α) {l₂ : List α} (h : a ∈ l₂) : a ∈ l₁ ++ l₂ :=
|
||||
mem_append.2 (Or.inr h)
|
||||
@[deprecated mem_append_left (since := "2024-11-20")] abbrev mem_append_of_mem_left := @mem_append_left
|
||||
@[deprecated mem_append_right (since := "2024-11-20")] abbrev mem_append_of_mem_right := @mem_append_right
|
||||
|
||||
theorem mem_iff_append {a : α} {l : List α} : a ∈ l ↔ ∃ s t : List α, l = s ++ a :: t :=
|
||||
⟨append_of_mem, fun ⟨s, t, e⟩ => e ▸ by simp⟩
|
||||
|
||||
@@ -9,7 +9,7 @@ import Init.Data.List.Basic
|
||||
|
||||
namespace List
|
||||
|
||||
/-! ### isEqv-/
|
||||
/-! ### isEqv -/
|
||||
|
||||
theorem isEqv_eq_decide (a b : List α) (r) :
|
||||
isEqv a b r = if h : a.length = b.length then
|
||||
|
||||
@@ -417,7 +417,7 @@ theorem Sublist.of_sublist_append_left (w : ∀ a, a ∈ l → a ∉ l₂) (h :
|
||||
obtain ⟨l₁', l₂', rfl, h₁, h₂⟩ := h
|
||||
have : l₂' = [] := by
|
||||
rw [eq_nil_iff_forall_not_mem]
|
||||
exact fun x m => w x (mem_append_of_mem_right l₁' m) (h₂.mem m)
|
||||
exact fun x m => w x (mem_append_right l₁' m) (h₂.mem m)
|
||||
simp_all
|
||||
|
||||
theorem Sublist.of_sublist_append_right (w : ∀ a, a ∈ l → a ∉ l₁) (h : l <+ l₁ ++ l₂) : l <+ l₂ := by
|
||||
@@ -425,7 +425,7 @@ theorem Sublist.of_sublist_append_right (w : ∀ a, a ∈ l → a ∉ l₁) (h :
|
||||
obtain ⟨l₁', l₂', rfl, h₁, h₂⟩ := h
|
||||
have : l₁' = [] := by
|
||||
rw [eq_nil_iff_forall_not_mem]
|
||||
exact fun x m => w x (mem_append_of_mem_left l₂' m) (h₁.mem m)
|
||||
exact fun x m => w x (mem_append_left l₂' m) (h₁.mem m)
|
||||
simp_all
|
||||
|
||||
theorem Sublist.middle {l : List α} (h : l <+ l₁ ++ l₂) (a : α) : l <+ l₁ ++ a :: l₂ := by
|
||||
|
||||
@@ -20,3 +20,4 @@ import Init.Data.Nat.Mod
|
||||
import Init.Data.Nat.Lcm
|
||||
import Init.Data.Nat.Compare
|
||||
import Init.Data.Nat.Simproc
|
||||
import Init.Data.Nat.Fold
|
||||
|
||||
@@ -35,52 +35,6 @@ Used as the default `Nat` eliminator by the `cases` tactic. -/
|
||||
protected abbrev casesAuxOn {motive : Nat → Sort u} (t : Nat) (zero : motive 0) (succ : (n : Nat) → motive (n + 1)) : motive t :=
|
||||
Nat.casesOn t zero succ
|
||||
|
||||
/--
|
||||
`Nat.fold` evaluates `f` on the numbers up to `n` exclusive, in increasing order:
|
||||
* `Nat.fold f 3 init = init |> f 0 |> f 1 |> f 2`
|
||||
-/
|
||||
@[specialize] def fold {α : Type u} (f : Nat → α → α) : (n : Nat) → (init : α) → α
|
||||
| 0, a => a
|
||||
| succ n, a => f n (fold f n a)
|
||||
|
||||
/-- Tail-recursive version of `Nat.fold`. -/
|
||||
@[inline] def foldTR {α : Type u} (f : Nat → α → α) (n : Nat) (init : α) : α :=
|
||||
let rec @[specialize] loop
|
||||
| 0, a => a
|
||||
| succ m, a => loop m (f (n - succ m) a)
|
||||
loop n init
|
||||
|
||||
/--
|
||||
`Nat.foldRev` evaluates `f` on the numbers up to `n` exclusive, in decreasing order:
|
||||
* `Nat.foldRev f 3 init = f 0 <| f 1 <| f 2 <| init`
|
||||
-/
|
||||
@[specialize] def foldRev {α : Type u} (f : Nat → α → α) : (n : Nat) → (init : α) → α
|
||||
| 0, a => a
|
||||
| succ n, a => foldRev f n (f n a)
|
||||
|
||||
/-- `any f n = true` iff there is `i in [0, n-1]` s.t. `f i = true` -/
|
||||
@[specialize] def any (f : Nat → Bool) : Nat → Bool
|
||||
| 0 => false
|
||||
| succ n => any f n || f n
|
||||
|
||||
/-- Tail-recursive version of `Nat.any`. -/
|
||||
@[inline] def anyTR (f : Nat → Bool) (n : Nat) : Bool :=
|
||||
let rec @[specialize] loop : Nat → Bool
|
||||
| 0 => false
|
||||
| succ m => f (n - succ m) || loop m
|
||||
loop n
|
||||
|
||||
/-- `all f n = true` iff every `i in [0, n-1]` satisfies `f i = true` -/
|
||||
@[specialize] def all (f : Nat → Bool) : Nat → Bool
|
||||
| 0 => true
|
||||
| succ n => all f n && f n
|
||||
|
||||
/-- Tail-recursive version of `Nat.all`. -/
|
||||
@[inline] def allTR (f : Nat → Bool) (n : Nat) : Bool :=
|
||||
let rec @[specialize] loop : Nat → Bool
|
||||
| 0 => true
|
||||
| succ m => f (n - succ m) && loop m
|
||||
loop n
|
||||
|
||||
/--
|
||||
`Nat.repeat f n a` is `f^(n) a`; that is, it iterates `f` `n` times on `a`.
|
||||
@@ -1158,33 +1112,6 @@ theorem not_lt_eq (a b : Nat) : (¬ (a < b)) = (b ≤ a) :=
|
||||
theorem not_gt_eq (a b : Nat) : (¬ (a > b)) = (a ≤ b) :=
|
||||
not_lt_eq b a
|
||||
|
||||
/-! # csimp theorems -/
|
||||
|
||||
@[csimp] theorem fold_eq_foldTR : @fold = @foldTR :=
|
||||
funext fun α => funext fun f => funext fun n => funext fun init =>
|
||||
let rec go : ∀ m n, foldTR.loop f (m + n) m (fold f n init) = fold f (m + n) init
|
||||
| 0, n => by simp [foldTR.loop]
|
||||
| succ m, n => by rw [foldTR.loop, add_sub_self_left, succ_add]; exact go m (succ n)
|
||||
(go n 0).symm
|
||||
|
||||
@[csimp] theorem any_eq_anyTR : @any = @anyTR :=
|
||||
funext fun f => funext fun n =>
|
||||
let rec go : ∀ m n, (any f n || anyTR.loop f (m + n) m) = any f (m + n)
|
||||
| 0, n => by simp [anyTR.loop]
|
||||
| succ m, n => by
|
||||
rw [anyTR.loop, add_sub_self_left, ← Bool.or_assoc, succ_add]
|
||||
exact go m (succ n)
|
||||
(go n 0).symm
|
||||
|
||||
@[csimp] theorem all_eq_allTR : @all = @allTR :=
|
||||
funext fun f => funext fun n =>
|
||||
let rec go : ∀ m n, (all f n && allTR.loop f (m + n) m) = all f (m + n)
|
||||
| 0, n => by simp [allTR.loop]
|
||||
| succ m, n => by
|
||||
rw [allTR.loop, add_sub_self_left, ← Bool.and_assoc, succ_add]
|
||||
exact go m (succ n)
|
||||
(go n 0).symm
|
||||
|
||||
@[csimp] theorem repeat_eq_repeatTR : @repeat = @repeatTR :=
|
||||
funext fun α => funext fun f => funext fun n => funext fun init =>
|
||||
let rec go : ∀ m n, repeatTR.loop f m (repeat f n init) = repeat f (m + n) init
|
||||
@@ -1193,31 +1120,3 @@ theorem not_gt_eq (a b : Nat) : (¬ (a > b)) = (a ≤ b) :=
|
||||
(go n 0).symm
|
||||
|
||||
end Nat
|
||||
|
||||
namespace Prod
|
||||
|
||||
/--
|
||||
`(start, stop).foldI f a` evaluates `f` on all the numbers
|
||||
from `start` (inclusive) to `stop` (exclusive) in increasing order:
|
||||
* `(5, 8).foldI f init = init |> f 5 |> f 6 |> f 7`
|
||||
-/
|
||||
@[inline] def foldI {α : Type u} (f : Nat → α → α) (i : Nat × Nat) (a : α) : α :=
|
||||
Nat.foldTR.loop f i.2 (i.2 - i.1) a
|
||||
|
||||
/--
|
||||
`(start, stop).anyI f a` returns true if `f` is true for some natural number
|
||||
from `start` (inclusive) to `stop` (exclusive):
|
||||
* `(5, 8).anyI f = f 5 || f 6 || f 7`
|
||||
-/
|
||||
@[inline] def anyI (f : Nat → Bool) (i : Nat × Nat) : Bool :=
|
||||
Nat.anyTR.loop f i.2 (i.2 - i.1)
|
||||
|
||||
/--
|
||||
`(start, stop).allI f a` returns true if `f` is true for all natural numbers
|
||||
from `start` (inclusive) to `stop` (exclusive):
|
||||
* `(5, 8).anyI f = f 5 && f 6 && f 7`
|
||||
-/
|
||||
@[inline] def allI (f : Nat → Bool) (i : Nat × Nat) : Bool :=
|
||||
Nat.allTR.loop f i.2 (i.2 - i.1)
|
||||
|
||||
end Prod
|
||||
|
||||
@@ -6,50 +6,51 @@ Author: Leonardo de Moura
|
||||
prelude
|
||||
import Init.Control.Basic
|
||||
import Init.Data.Nat.Basic
|
||||
import Init.Omega
|
||||
|
||||
namespace Nat
|
||||
universe u v
|
||||
|
||||
@[inline] def forM {m} [Monad m] (n : Nat) (f : Nat → m Unit) : m Unit :=
|
||||
let rec @[specialize] loop
|
||||
| 0 => pure ()
|
||||
| i+1 => do f (n-i-1); loop i
|
||||
loop n
|
||||
@[inline] def forM {m} [Monad m] (n : Nat) (f : (i : Nat) → i < n → m Unit) : m Unit :=
|
||||
let rec @[specialize] loop : ∀ i, i ≤ n → m Unit
|
||||
| 0, _ => pure ()
|
||||
| i+1, h => do f (n-i-1) (by omega); loop i (Nat.le_of_succ_le h)
|
||||
loop n (by simp)
|
||||
|
||||
@[inline] def forRevM {m} [Monad m] (n : Nat) (f : Nat → m Unit) : m Unit :=
|
||||
let rec @[specialize] loop
|
||||
| 0 => pure ()
|
||||
| i+1 => do f i; loop i
|
||||
loop n
|
||||
@[inline] def forRevM {m} [Monad m] (n : Nat) (f : (i : Nat) → i < n → m Unit) : m Unit :=
|
||||
let rec @[specialize] loop : ∀ i, i ≤ n → m Unit
|
||||
| 0, _ => pure ()
|
||||
| i+1, h => do f i (by omega); loop i (Nat.le_of_succ_le h)
|
||||
loop n (by simp)
|
||||
|
||||
@[inline] def foldM {α : Type u} {m : Type u → Type v} [Monad m] (f : Nat → α → m α) (init : α) (n : Nat) : m α :=
|
||||
let rec @[specialize] loop
|
||||
| 0, a => pure a
|
||||
| i+1, a => f (n-i-1) a >>= loop i
|
||||
loop n init
|
||||
@[inline] def foldM {α : Type u} {m : Type u → Type v} [Monad m] (n : Nat) (f : (i : Nat) → i < n → α → m α) (init : α) : m α :=
|
||||
let rec @[specialize] loop : ∀ i, i ≤ n → α → m α
|
||||
| 0, h, a => pure a
|
||||
| i+1, h, a => f (n-i-1) (by omega) a >>= loop i (Nat.le_of_succ_le h)
|
||||
loop n (by omega) init
|
||||
|
||||
@[inline] def foldRevM {α : Type u} {m : Type u → Type v} [Monad m] (f : Nat → α → m α) (init : α) (n : Nat) : m α :=
|
||||
let rec @[specialize] loop
|
||||
| 0, a => pure a
|
||||
| i+1, a => f i a >>= loop i
|
||||
loop n init
|
||||
@[inline] def foldRevM {α : Type u} {m : Type u → Type v} [Monad m] (n : Nat) (f : (i : Nat) → i < n → α → m α) (init : α) : m α :=
|
||||
let rec @[specialize] loop : ∀ i, i ≤ n → α → m α
|
||||
| 0, h, a => pure a
|
||||
| i+1, h, a => f i (by omega) a >>= loop i (Nat.le_of_succ_le h)
|
||||
loop n (by omega) init
|
||||
|
||||
@[inline] def allM {m} [Monad m] (n : Nat) (p : Nat → m Bool) : m Bool :=
|
||||
let rec @[specialize] loop
|
||||
| 0 => pure true
|
||||
| i+1 => do
|
||||
match (← p (n-i-1)) with
|
||||
| true => loop i
|
||||
@[inline] def allM {m} [Monad m] (n : Nat) (p : (i : Nat) → i < n → m Bool) : m Bool :=
|
||||
let rec @[specialize] loop : ∀ i, i ≤ n → m Bool
|
||||
| 0, _ => pure true
|
||||
| i+1 , h => do
|
||||
match (← p (n-i-1) (by omega)) with
|
||||
| true => loop i (by omega)
|
||||
| false => pure false
|
||||
loop n
|
||||
loop n (by simp)
|
||||
|
||||
@[inline] def anyM {m} [Monad m] (n : Nat) (p : Nat → m Bool) : m Bool :=
|
||||
let rec @[specialize] loop
|
||||
| 0 => pure false
|
||||
| i+1 => do
|
||||
match (← p (n-i-1)) with
|
||||
@[inline] def anyM {m} [Monad m] (n : Nat) (p : (i : Nat) → i < n → m Bool) : m Bool :=
|
||||
let rec @[specialize] loop : ∀ i, i ≤ n → m Bool
|
||||
| 0, _ => pure false
|
||||
| i+1, h => do
|
||||
match (← p (n-i-1) (by omega)) with
|
||||
| true => pure true
|
||||
| false => loop i
|
||||
loop n
|
||||
| false => loop i (Nat.le_of_succ_le h)
|
||||
loop n (by simp)
|
||||
|
||||
end Nat
|
||||
|
||||
168
src/Init/Data/Nat/Fold.lean
Normal file
168
src/Init/Data/Nat/Fold.lean
Normal file
@@ -0,0 +1,168 @@
|
||||
/-
|
||||
Copyright (c) 2014 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Floris van Doorn, Leonardo de Moura, Kim Morrison
|
||||
-/
|
||||
prelude
|
||||
import Init.Omega
|
||||
|
||||
set_option linter.missingDocs true -- keep it documented
|
||||
universe u
|
||||
|
||||
namespace Nat
|
||||
|
||||
/--
|
||||
`Nat.fold` evaluates `f` on the numbers up to `n` exclusive, in increasing order:
|
||||
* `Nat.fold f 3 init = init |> f 0 |> f 1 |> f 2`
|
||||
-/
|
||||
@[specialize] def fold {α : Type u} : (n : Nat) → (f : (i : Nat) → i < n → α → α) → (init : α) → α
|
||||
| 0, f, a => a
|
||||
| succ n, f, a => f n (by omega) (fold n (fun i h => f i (by omega)) a)
|
||||
|
||||
/-- Tail-recursive version of `Nat.fold`. -/
|
||||
@[inline] def foldTR {α : Type u} (n : Nat) (f : (i : Nat) → i < n → α → α) (init : α) : α :=
|
||||
let rec @[specialize] loop : ∀ j, j ≤ n → α → α
|
||||
| 0, h, a => a
|
||||
| succ m, h, a => loop m (by omega) (f (n - succ m) (by omega) a)
|
||||
loop n (by omega) init
|
||||
|
||||
/--
|
||||
`Nat.foldRev` evaluates `f` on the numbers up to `n` exclusive, in decreasing order:
|
||||
* `Nat.foldRev f 3 init = f 0 <| f 1 <| f 2 <| init`
|
||||
-/
|
||||
@[specialize] def foldRev {α : Type u} : (n : Nat) → (f : (i : Nat) → i < n → α → α) → (init : α) → α
|
||||
| 0, f, a => a
|
||||
| succ n, f, a => foldRev n (fun i h => f i (by omega)) (f n (by omega) a)
|
||||
|
||||
/-- `any f n = true` iff there is `i in [0, n-1]` s.t. `f i = true` -/
|
||||
@[specialize] def any : (n : Nat) → (f : (i : Nat) → i < n → Bool) → Bool
|
||||
| 0, f => false
|
||||
| succ n, f => any n (fun i h => f i (by omega)) || f n (by omega)
|
||||
|
||||
/-- Tail-recursive version of `Nat.any`. -/
|
||||
@[inline] def anyTR (n : Nat) (f : (i : Nat) → i < n → Bool) : Bool :=
|
||||
let rec @[specialize] loop : (i : Nat) → i ≤ n → Bool
|
||||
| 0, h => false
|
||||
| succ m, h => f (n - succ m) (by omega) || loop m (by omega)
|
||||
loop n (by omega)
|
||||
|
||||
/-- `all f n = true` iff every `i in [0, n-1]` satisfies `f i = true` -/
|
||||
@[specialize] def all : (n : Nat) → (f : (i : Nat) → i < n → Bool) → Bool
|
||||
| 0, f => true
|
||||
| succ n, f => all n (fun i h => f i (by omega)) && f n (by omega)
|
||||
|
||||
/-- Tail-recursive version of `Nat.all`. -/
|
||||
@[inline] def allTR (n : Nat) (f : (i : Nat) → i < n → Bool) : Bool :=
|
||||
let rec @[specialize] loop : (i : Nat) → i ≤ n → Bool
|
||||
| 0, h => true
|
||||
| succ m, h => f (n - succ m) (by omega) && loop m (by omega)
|
||||
loop n (by omega)
|
||||
|
||||
/-! # csimp theorems -/
|
||||
|
||||
theorem fold_congr {α : Type u} {n m : Nat} (w : n = m)
|
||||
(f : (i : Nat) → i < n → α → α) (init : α) :
|
||||
fold n f init = fold m (fun i h => f i (by omega)) init := by
|
||||
subst m
|
||||
rfl
|
||||
|
||||
theorem foldTR_loop_congr {α : Type u} {n m : Nat} (w : n = m)
|
||||
(f : (i : Nat) → i < n → α → α) (j : Nat) (h : j ≤ n) (init : α) :
|
||||
foldTR.loop n f j h init = foldTR.loop m (fun i h => f i (by omega)) j (by omega) init := by
|
||||
subst m
|
||||
rfl
|
||||
|
||||
@[csimp] theorem fold_eq_foldTR : @fold = @foldTR :=
|
||||
funext fun α => funext fun n => funext fun f => funext fun init =>
|
||||
let rec go : ∀ m n f, fold (m + n) f init = foldTR.loop (m + n) f m (by omega) (fold n (fun i h => f i (by omega)) init)
|
||||
| 0, n, f => by
|
||||
simp only [foldTR.loop]
|
||||
have t : 0 + n = n := by omega
|
||||
rw [fold_congr t]
|
||||
| succ m, n, f => by
|
||||
have t : (m + 1) + n = m + (n + 1) := by omega
|
||||
rw [foldTR.loop]
|
||||
simp only [succ_eq_add_one, Nat.add_sub_cancel]
|
||||
rw [fold_congr t, foldTR_loop_congr t, go, fold]
|
||||
congr
|
||||
omega
|
||||
go n 0 f
|
||||
|
||||
theorem any_congr {n m : Nat} (w : n = m) (f : (i : Nat) → i < n → Bool) : any n f = any m (fun i h => f i (by omega)) := by
|
||||
subst m
|
||||
rfl
|
||||
|
||||
theorem anyTR_loop_congr {n m : Nat} (w : n = m) (f : (i : Nat) → i < n → Bool) (j : Nat) (h : j ≤ n) :
|
||||
anyTR.loop n f j h = anyTR.loop m (fun i h => f i (by omega)) j (by omega) := by
|
||||
subst m
|
||||
rfl
|
||||
|
||||
@[csimp] theorem any_eq_anyTR : @any = @anyTR :=
|
||||
funext fun n => funext fun f =>
|
||||
let rec go : ∀ m n f, any (m + n) f = (any n (fun i h => f i (by omega)) || anyTR.loop (m + n) f m (by omega))
|
||||
| 0, n, f => by
|
||||
simp [anyTR.loop]
|
||||
have t : 0 + n = n := by omega
|
||||
rw [any_congr t]
|
||||
| succ m, n, f => by
|
||||
have t : (m + 1) + n = m + (n + 1) := by omega
|
||||
rw [anyTR.loop]
|
||||
simp only [succ_eq_add_one]
|
||||
rw [any_congr t, anyTR_loop_congr t, go, any, Bool.or_assoc]
|
||||
congr
|
||||
omega
|
||||
go n 0 f
|
||||
|
||||
theorem all_congr {n m : Nat} (w : n = m) (f : (i : Nat) → i < n → Bool) : all n f = all m (fun i h => f i (by omega)) := by
|
||||
subst m
|
||||
rfl
|
||||
|
||||
theorem allTR_loop_congr {n m : Nat} (w : n = m) (f : (i : Nat) → i < n → Bool) (j : Nat) (h : j ≤ n) : allTR.loop n f j h = allTR.loop m (fun i h => f i (by omega)) j (by omega) := by
|
||||
subst m
|
||||
rfl
|
||||
|
||||
@[csimp] theorem all_eq_allTR : @all = @allTR :=
|
||||
funext fun n => funext fun f =>
|
||||
let rec go : ∀ m n f, all (m + n) f = (all n (fun i h => f i (by omega)) && allTR.loop (m + n) f m (by omega))
|
||||
| 0, n, f => by
|
||||
simp [allTR.loop]
|
||||
have t : 0 + n = n := by omega
|
||||
rw [all_congr t]
|
||||
| succ m, n, f => by
|
||||
have t : (m + 1) + n = m + (n + 1) := by omega
|
||||
rw [allTR.loop]
|
||||
simp only [succ_eq_add_one]
|
||||
rw [all_congr t, allTR_loop_congr t, go, all, Bool.and_assoc]
|
||||
congr
|
||||
omega
|
||||
go n 0 f
|
||||
|
||||
end Nat
|
||||
|
||||
namespace Prod
|
||||
|
||||
/--
|
||||
`(start, stop).foldI f a` evaluates `f` on all the numbers
|
||||
from `start` (inclusive) to `stop` (exclusive) in increasing order:
|
||||
* `(5, 8).foldI f init = init |> f 5 |> f 6 |> f 7`
|
||||
-/
|
||||
@[inline] def foldI {α : Type u} (i : Nat × Nat) (f : (j : Nat) → i.1 ≤ j → j < i.2 → α → α) (a : α) : α :=
|
||||
(i.2 - i.1).fold (fun j _ => f (i.1 + j) (by omega) (by omega)) a
|
||||
|
||||
/--
|
||||
`(start, stop).anyI f a` returns true if `f` is true for some natural number
|
||||
from `start` (inclusive) to `stop` (exclusive):
|
||||
* `(5, 8).anyI f = f 5 || f 6 || f 7`
|
||||
-/
|
||||
@[inline] def anyI (i : Nat × Nat) (f : (j : Nat) → i.1 ≤ j → j < i.2 → Bool) : Bool :=
|
||||
(i.2 - i.1).any (fun j _ => f (i.1 + j) (by omega) (by omega))
|
||||
|
||||
/--
|
||||
`(start, stop).allI f a` returns true if `f` is true for all natural numbers
|
||||
from `start` (inclusive) to `stop` (exclusive):
|
||||
* `(5, 8).anyI f = f 5 && f 6 && f 7`
|
||||
-/
|
||||
@[inline] def allI (i : Nat × Nat) (f : (j : Nat) → i.1 ≤ j → j < i.2 → Bool) : Bool :=
|
||||
(i.2 - i.1).all (fun j _ => f (i.1 + j) (by omega) (by omega))
|
||||
|
||||
end Prod
|
||||
@@ -31,7 +31,7 @@ This file defines basic operations on the the sum type `α ⊕ β`.
|
||||
|
||||
## Further material
|
||||
|
||||
See `Batteries.Data.Sum.Lemmas` for theorems about these definitions.
|
||||
See `Init.Data.Sum.Lemmas` for theorems about these definitions.
|
||||
|
||||
## Notes
|
||||
|
||||
|
||||
@@ -431,21 +431,20 @@ def getSubstring? (stx : Syntax) (withLeading := true) (withTrailing := true) :
|
||||
}
|
||||
| _, _ => none
|
||||
|
||||
@[specialize] private partial def updateLast {α} [Inhabited α] (a : Array α) (f : α → Option α) (i : Nat) : Option (Array α) :=
|
||||
if i == 0 then
|
||||
none
|
||||
else
|
||||
let i := i - 1
|
||||
let v := a[i]!
|
||||
@[specialize] private partial def updateLast {α} (a : Array α) (f : α → Option α) (i : Fin (a.size + 1)) : Option (Array α) :=
|
||||
match i with
|
||||
| 0 => none
|
||||
| ⟨i + 1, h⟩ =>
|
||||
let v := a[i]'(Nat.succ_lt_succ_iff.mp h)
|
||||
match f v with
|
||||
| some v => some <| a.set! i v
|
||||
| none => updateLast a f i
|
||||
| some v => some <| a.set i v (Nat.succ_lt_succ_iff.mp h)
|
||||
| none => updateLast a f ⟨i, Nat.lt_of_succ_lt h⟩
|
||||
|
||||
partial def setTailInfoAux (info : SourceInfo) : Syntax → Option Syntax
|
||||
| atom _ val => some <| atom info val
|
||||
| ident _ rawVal val pre => some <| ident info rawVal val pre
|
||||
| node info' k args =>
|
||||
match updateLast args (setTailInfoAux info) args.size with
|
||||
match updateLast args (setTailInfoAux info) ⟨args.size, by simp⟩ with
|
||||
| some args => some <| node info' k args
|
||||
| none => none
|
||||
| _ => none
|
||||
|
||||
@@ -71,9 +71,9 @@ def prio : Category := {}
|
||||
|
||||
/-- `prec` is a builtin syntax category for precedences. A precedence is a value
|
||||
that expresses how tightly a piece of syntax binds: for example `1 + 2 * 3` is
|
||||
parsed as `1 + (2 * 3)` because `*` has a higher pr0ecedence than `+`.
|
||||
parsed as `1 + (2 * 3)` because `*` has a higher precedence than `+`.
|
||||
Higher numbers denote higher precedence.
|
||||
In addition to literals like `37`, there are some special named priorities:
|
||||
In addition to literals like `37`, there are some special named precedence levels:
|
||||
* `arg` for the precedence of function arguments
|
||||
* `max` for the highest precedence used in term parsers (not actually the maximum possible value)
|
||||
* `lead` for the precedence of terms not supposed to be used as arguments
|
||||
|
||||
@@ -22,28 +22,28 @@ syntax explicitBinders := (ppSpace bracketedExplicitBinders)+ <|> unb
|
||||
|
||||
open TSyntax.Compat in
|
||||
def expandExplicitBindersAux (combinator : Syntax) (idents : Array Syntax) (type? : Option Syntax) (body : Syntax) : MacroM Syntax :=
|
||||
let rec loop (i : Nat) (acc : Syntax) := do
|
||||
let rec loop (i : Nat) (h : i ≤ idents.size) (acc : Syntax) := do
|
||||
match i with
|
||||
| 0 => pure acc
|
||||
| i+1 =>
|
||||
let ident := idents[i]![0]
|
||||
| i + 1 =>
|
||||
let ident := idents[i][0]
|
||||
let acc ← match ident.isIdent, type? with
|
||||
| true, none => `($combinator fun $ident => $acc)
|
||||
| true, some type => `($combinator fun $ident : $type => $acc)
|
||||
| false, none => `($combinator fun _ => $acc)
|
||||
| false, some type => `($combinator fun _ : $type => $acc)
|
||||
loop i acc
|
||||
loop idents.size body
|
||||
loop i (Nat.le_of_succ_le h) acc
|
||||
loop idents.size (by simp) body
|
||||
|
||||
def expandBrackedBindersAux (combinator : Syntax) (binders : Array Syntax) (body : Syntax) : MacroM Syntax :=
|
||||
let rec loop (i : Nat) (acc : Syntax) := do
|
||||
let rec loop (i : Nat) (h : i ≤ binders.size) (acc : Syntax) := do
|
||||
match i with
|
||||
| 0 => pure acc
|
||||
| i+1 =>
|
||||
let idents := binders[i]![1].getArgs
|
||||
let type := binders[i]![3]
|
||||
loop i (← expandExplicitBindersAux combinator idents (some type) acc)
|
||||
loop binders.size body
|
||||
let idents := binders[i][1].getArgs
|
||||
let type := binders[i][3]
|
||||
loop i (Nat.le_of_succ_le h) (← expandExplicitBindersAux combinator idents (some type) acc)
|
||||
loop binders.size (by simp) body
|
||||
|
||||
def expandExplicitBinders (combinatorDeclName : Name) (explicitBinders : Syntax) (body : Syntax) : MacroM Syntax := do
|
||||
let combinator := mkCIdentFrom (← getRef) combinatorDeclName
|
||||
|
||||
@@ -462,6 +462,16 @@ Note that it is the caller's job to remove the file after use.
|
||||
-/
|
||||
@[extern "lean_io_create_tempfile"] opaque createTempFile : IO (Handle × FilePath)
|
||||
|
||||
/--
|
||||
Creates a temporary directory in the most secure manner possible. There are no race conditions in the
|
||||
directory’s creation. The directory is readable and writable only by the creating user ID.
|
||||
|
||||
Returns the new directory's path.
|
||||
|
||||
It is the caller's job to remove the directory after use.
|
||||
-/
|
||||
@[extern "lean_io_create_tempdir"] opaque createTempDir : IO FilePath
|
||||
|
||||
end FS
|
||||
|
||||
@[extern "lean_io_getenv"] opaque getEnv (var : @& String) : BaseIO (Option String)
|
||||
@@ -474,17 +484,6 @@ namespace FS
|
||||
def withFile (fn : FilePath) (mode : Mode) (f : Handle → IO α) : IO α :=
|
||||
Handle.mk fn mode >>= f
|
||||
|
||||
/--
|
||||
Like `createTempFile` but also takes care of removing the file after usage.
|
||||
-/
|
||||
def withTempFile [Monad m] [MonadFinally m] [MonadLiftT IO m] (f : Handle → FilePath → m α) :
|
||||
m α := do
|
||||
let (handle, path) ← createTempFile
|
||||
try
|
||||
f handle path
|
||||
finally
|
||||
removeFile path
|
||||
|
||||
def Handle.putStrLn (h : Handle) (s : String) : IO Unit :=
|
||||
h.putStr (s.push '\n')
|
||||
|
||||
@@ -675,8 +674,10 @@ def appDir : IO FilePath := do
|
||||
| throw <| IO.userError s!"System.IO.appDir: unexpected filename '{p}'"
|
||||
FS.realPath p
|
||||
|
||||
namespace FS
|
||||
|
||||
/-- Create given path and all missing parents as directories. -/
|
||||
partial def FS.createDirAll (p : FilePath) : IO Unit := do
|
||||
partial def createDirAll (p : FilePath) : IO Unit := do
|
||||
if ← p.isDir then
|
||||
return ()
|
||||
if let some parent := p.parent then
|
||||
@@ -693,7 +694,7 @@ partial def FS.createDirAll (p : FilePath) : IO Unit := do
|
||||
/--
|
||||
Fully remove given directory by deleting all contained files and directories in an unspecified order.
|
||||
Fails if any contained entry cannot be deleted or was newly created during execution. -/
|
||||
partial def FS.removeDirAll (p : FilePath) : IO Unit := do
|
||||
partial def removeDirAll (p : FilePath) : IO Unit := do
|
||||
for ent in (← p.readDir) do
|
||||
if (← ent.path.isDir : Bool) then
|
||||
removeDirAll ent.path
|
||||
@@ -701,6 +702,32 @@ partial def FS.removeDirAll (p : FilePath) : IO Unit := do
|
||||
removeFile ent.path
|
||||
removeDir p
|
||||
|
||||
/--
|
||||
Like `createTempFile`, but also takes care of removing the file after usage.
|
||||
-/
|
||||
def withTempFile [Monad m] [MonadFinally m] [MonadLiftT IO m] (f : Handle → FilePath → m α) :
|
||||
m α := do
|
||||
let (handle, path) ← createTempFile
|
||||
try
|
||||
f handle path
|
||||
finally
|
||||
removeFile path
|
||||
|
||||
/--
|
||||
Like `createTempDir`, but also takes care of removing the directory after usage.
|
||||
|
||||
All files in the directory are recursively deleted, regardless of how or when they were created.
|
||||
-/
|
||||
def withTempDir [Monad m] [MonadFinally m] [MonadLiftT IO m] (f : FilePath → m α) :
|
||||
m α := do
|
||||
let path ← createTempDir
|
||||
try
|
||||
f path
|
||||
finally
|
||||
removeDirAll path
|
||||
|
||||
end FS
|
||||
|
||||
namespace Process
|
||||
|
||||
/-- Returns the current working directory of the calling process. -/
|
||||
|
||||
@@ -29,13 +29,13 @@ def decodeUri (uri : String) : String := Id.run do
|
||||
let len := rawBytes.size
|
||||
let mut i := 0
|
||||
let percent := '%'.toNat.toUInt8
|
||||
while i < len do
|
||||
let c := rawBytes[i]!
|
||||
(decoded, i) := if c == percent && i + 1 < len then
|
||||
let h1 := rawBytes[i + 1]!
|
||||
while h : i < len do
|
||||
let c := rawBytes[i]
|
||||
(decoded, i) := if h₁ : c == percent ∧ i + 1 < len then
|
||||
let h1 := rawBytes[i + 1]
|
||||
if let some hd1 := hexDigitToUInt8? h1 then
|
||||
if i + 2 < len then
|
||||
let h2 := rawBytes[i + 2]!
|
||||
if h₂ : i + 2 < len then
|
||||
let h2 := rawBytes[i + 2]
|
||||
if let some hd2 := hexDigitToUInt8? h2 then
|
||||
-- decode the hex digits into a byte.
|
||||
(decoded.push (hd1 * 16 + hd2), i + 3)
|
||||
|
||||
@@ -428,11 +428,11 @@ macro "infer_instance" : tactic => `(tactic| exact inferInstance)
|
||||
/--
|
||||
`+opt` is short for `(opt := true)`. It sets the `opt` configuration option to `true`.
|
||||
-/
|
||||
syntax posConfigItem := "+" noWs ident
|
||||
syntax posConfigItem := " +" noWs ident
|
||||
/--
|
||||
`-opt` is short for `(opt := false)`. It sets the `opt` configuration option to `false`.
|
||||
-/
|
||||
syntax negConfigItem := "-" noWs ident
|
||||
syntax negConfigItem := " -" noWs ident
|
||||
/--
|
||||
`(opt := val)` sets the `opt` configuration option to `val`.
|
||||
|
||||
|
||||
@@ -205,8 +205,8 @@ def getParamInfo (k : ParamMap.Key) : M (Array Param) := do
|
||||
|
||||
/-- For each ps[i], if ps[i] is owned, then mark xs[i] as owned. -/
|
||||
def ownArgsUsingParams (xs : Array Arg) (ps : Array Param) : M Unit :=
|
||||
xs.size.forM fun i => do
|
||||
let x := xs[i]!
|
||||
xs.size.forM fun i _ => do
|
||||
let x := xs[i]
|
||||
let p := ps[i]!
|
||||
unless p.borrow do ownArg x
|
||||
|
||||
@@ -216,8 +216,8 @@ def ownArgsUsingParams (xs : Array Arg) (ps : Array Param) : M Unit :=
|
||||
we would have to insert a `dec xs[i]` after `f xs` and consequently
|
||||
"break" the tail call. -/
|
||||
def ownParamsUsingArgs (xs : Array Arg) (ps : Array Param) : M Unit :=
|
||||
xs.size.forM fun i => do
|
||||
let x := xs[i]!
|
||||
xs.size.forM fun i _ => do
|
||||
let x := xs[i]
|
||||
let p := ps[i]!
|
||||
match x with
|
||||
| Arg.var x => if (← isOwned x) then ownVar p.x
|
||||
|
||||
@@ -48,9 +48,9 @@ def requiresBoxedVersion (env : Environment) (decl : Decl) : Bool :=
|
||||
def mkBoxedVersionAux (decl : Decl) : N Decl := do
|
||||
let ps := decl.params
|
||||
let qs ← ps.mapM fun _ => do let x ← N.mkFresh; pure { x := x, ty := IRType.object, borrow := false : Param }
|
||||
let (newVDecls, xs) ← qs.size.foldM (init := (#[], #[])) fun i (newVDecls, xs) => do
|
||||
let (newVDecls, xs) ← qs.size.foldM (init := (#[], #[])) fun i _ (newVDecls, xs) => do
|
||||
let p := ps[i]!
|
||||
let q := qs[i]!
|
||||
let q := qs[i]
|
||||
if !p.ty.isScalar then
|
||||
pure (newVDecls, xs.push (Arg.var q.x))
|
||||
else
|
||||
|
||||
@@ -63,7 +63,7 @@ partial def merge (v₁ v₂ : Value) : Value :=
|
||||
| top, _ => top
|
||||
| _, top => top
|
||||
| v₁@(ctor i₁ vs₁), v₂@(ctor i₂ vs₂) =>
|
||||
if i₁ == i₂ then ctor i₁ <| vs₁.size.fold (init := #[]) fun i r => r.push (merge vs₁[i]! vs₂[i]!)
|
||||
if i₁ == i₂ then ctor i₁ <| vs₁.size.fold (init := #[]) fun i _ r => r.push (merge vs₁[i] vs₂[i]!)
|
||||
else choice [v₁, v₂]
|
||||
| choice vs₁, choice vs₂ => choice <| vs₁.foldl (addChoice merge) vs₂
|
||||
| choice vs, v => choice <| addChoice merge vs v
|
||||
@@ -225,8 +225,8 @@ def updateCurrFnSummary (v : Value) : M Unit := do
|
||||
def updateJPParamsAssignment (ys : Array Param) (xs : Array Arg) : M Bool := do
|
||||
let ctx ← read
|
||||
let currFnIdx := ctx.currFnIdx
|
||||
ys.size.foldM (init := false) fun i r => do
|
||||
let y := ys[i]!
|
||||
ys.size.foldM (init := false) fun i _ r => do
|
||||
let y := ys[i]
|
||||
let x := xs[i]!
|
||||
let yVal ← findVarValue y.x
|
||||
let xVal ← findArgValue x
|
||||
@@ -282,8 +282,8 @@ partial def interpFnBody : FnBody → M Unit
|
||||
def inferStep : M Bool := do
|
||||
let ctx ← read
|
||||
modify fun s => { s with assignments := ctx.decls.map fun _ => {} }
|
||||
ctx.decls.size.foldM (init := false) fun idx modified => do
|
||||
match ctx.decls[idx]! with
|
||||
ctx.decls.size.foldM (init := false) fun idx _ modified => do
|
||||
match ctx.decls[idx] with
|
||||
| .fdecl (xs := ys) (body := b) .. => do
|
||||
let s ← get
|
||||
let currVals := s.funVals[idx]!
|
||||
@@ -336,8 +336,8 @@ def elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do
|
||||
let funVals := s.funVals
|
||||
let assignments := s.assignments
|
||||
modify fun s =>
|
||||
let env := decls.size.fold (init := s.env) fun i env =>
|
||||
addFunctionSummary env decls[i]!.name funVals[i]!
|
||||
let env := decls.size.fold (init := s.env) fun i _ env =>
|
||||
addFunctionSummary env decls[i].name funVals[i]!
|
||||
{ s with env := env }
|
||||
return decls.mapIdx fun i decl => elimDead assignments[i]! decl
|
||||
|
||||
|
||||
@@ -108,9 +108,9 @@ def emitFnDeclAux (decl : Decl) (cppBaseName : String) (isExternal : Bool) : M U
|
||||
if ps.size > closureMaxArgs && isBoxedName decl.name then
|
||||
emit "lean_object**"
|
||||
else
|
||||
ps.size.forM fun i => do
|
||||
ps.size.forM fun i _ => do
|
||||
if i > 0 then emit ", "
|
||||
emit (toCType ps[i]!.ty)
|
||||
emit (toCType ps[i].ty)
|
||||
emit ")"
|
||||
emitLn ";"
|
||||
|
||||
@@ -271,9 +271,9 @@ def emitTag (x : VarId) (xType : IRType) : M Unit := do
|
||||
emit x
|
||||
|
||||
def isIf (alts : Array Alt) : Option (Nat × FnBody × FnBody) :=
|
||||
if alts.size != 2 then none
|
||||
else match alts[0]! with
|
||||
| Alt.ctor c b => some (c.cidx, b, alts[1]!.body)
|
||||
if h : alts.size ≠ 2 then none
|
||||
else match alts[0] with
|
||||
| Alt.ctor c b => some (c.cidx, b, alts[1].body)
|
||||
| _ => none
|
||||
|
||||
def emitInc (x : VarId) (n : Nat) (checkRef : Bool) : M Unit := do
|
||||
@@ -321,20 +321,22 @@ def emitSSet (x : VarId) (n : Nat) (offset : Nat) (y : VarId) (t : IRType) : M U
|
||||
|
||||
def emitJmp (j : JoinPointId) (xs : Array Arg) : M Unit := do
|
||||
let ps ← getJPParams j
|
||||
unless xs.size == ps.size do throw "invalid goto"
|
||||
xs.size.forM fun i => do
|
||||
let p := ps[i]!
|
||||
let x := xs[i]!
|
||||
emit p.x; emit " = "; emitArg x; emitLn ";"
|
||||
emit "goto "; emit j; emitLn ";"
|
||||
if h : xs.size = ps.size then
|
||||
xs.size.forM fun i _ => do
|
||||
let p := ps[i]
|
||||
let x := xs[i]
|
||||
emit p.x; emit " = "; emitArg x; emitLn ";"
|
||||
emit "goto "; emit j; emitLn ";"
|
||||
else
|
||||
do throw "invalid goto"
|
||||
|
||||
def emitLhs (z : VarId) : M Unit := do
|
||||
emit z; emit " = "
|
||||
|
||||
def emitArgs (ys : Array Arg) : M Unit :=
|
||||
ys.size.forM fun i => do
|
||||
ys.size.forM fun i _ => do
|
||||
if i > 0 then emit ", "
|
||||
emitArg ys[i]!
|
||||
emitArg ys[i]
|
||||
|
||||
def emitCtorScalarSize (usize : Nat) (ssize : Nat) : M Unit := do
|
||||
if usize == 0 then emit ssize
|
||||
@@ -346,8 +348,8 @@ def emitAllocCtor (c : CtorInfo) : M Unit := do
|
||||
emitCtorScalarSize c.usize c.ssize; emitLn ");"
|
||||
|
||||
def emitCtorSetArgs (z : VarId) (ys : Array Arg) : M Unit :=
|
||||
ys.size.forM fun i => do
|
||||
emit "lean_ctor_set("; emit z; emit ", "; emit i; emit ", "; emitArg ys[i]!; emitLn ");"
|
||||
ys.size.forM fun i _ => do
|
||||
emit "lean_ctor_set("; emit z; emit ", "; emit i; emit ", "; emitArg ys[i]; emitLn ");"
|
||||
|
||||
def emitCtor (z : VarId) (c : CtorInfo) (ys : Array Arg) : M Unit := do
|
||||
emitLhs z;
|
||||
@@ -358,7 +360,7 @@ def emitCtor (z : VarId) (c : CtorInfo) (ys : Array Arg) : M Unit := do
|
||||
|
||||
def emitReset (z : VarId) (n : Nat) (x : VarId) : M Unit := do
|
||||
emit "if (lean_is_exclusive("; emit x; emitLn ")) {";
|
||||
n.forM fun i => do
|
||||
n.forM fun i _ => do
|
||||
emit " lean_ctor_release("; emit x; emit ", "; emit i; emitLn ");"
|
||||
emit " "; emitLhs z; emit x; emitLn ";";
|
||||
emitLn "} else {";
|
||||
@@ -399,12 +401,12 @@ def emitSimpleExternalCall (f : String) (ps : Array Param) (ys : Array Arg) : M
|
||||
emit f; emit "("
|
||||
-- We must remove irrelevant arguments to extern calls.
|
||||
discard <| ys.size.foldM
|
||||
(fun i (first : Bool) =>
|
||||
(fun i _ (first : Bool) =>
|
||||
if ps[i]!.ty.isIrrelevant then
|
||||
pure first
|
||||
else do
|
||||
unless first do emit ", "
|
||||
emitArg ys[i]!
|
||||
emitArg ys[i]
|
||||
pure false)
|
||||
true
|
||||
emitLn ");"
|
||||
@@ -431,8 +433,8 @@ def emitPartialApp (z : VarId) (f : FunId) (ys : Array Arg) : M Unit := do
|
||||
let decl ← getDecl f
|
||||
let arity := decl.params.size;
|
||||
emitLhs z; emit "lean_alloc_closure((void*)("; emitCName f; emit "), "; emit arity; emit ", "; emit ys.size; emitLn ");";
|
||||
ys.size.forM fun i => do
|
||||
let y := ys[i]!
|
||||
ys.size.forM fun i _ => do
|
||||
let y := ys[i]
|
||||
emit "lean_closure_set("; emit z; emit ", "; emit i; emit ", "; emitArg y; emitLn ");"
|
||||
|
||||
def emitApp (z : VarId) (f : VarId) (ys : Array Arg) : M Unit :=
|
||||
@@ -544,34 +546,36 @@ That is, we have
|
||||
-/
|
||||
def overwriteParam (ps : Array Param) (ys : Array Arg) : Bool :=
|
||||
let n := ps.size;
|
||||
n.any fun i =>
|
||||
let p := ps[i]!
|
||||
(i+1, n).anyI fun j => paramEqArg p ys[j]!
|
||||
n.any fun i _ =>
|
||||
let p := ps[i]
|
||||
(i+1, n).anyI fun j _ _ => paramEqArg p ys[j]!
|
||||
|
||||
def emitTailCall (v : Expr) : M Unit :=
|
||||
match v with
|
||||
| Expr.fap _ ys => do
|
||||
let ctx ← read
|
||||
let ps := ctx.mainParams
|
||||
unless ps.size == ys.size do throw "invalid tail call"
|
||||
if overwriteParam ps ys then
|
||||
emitLn "{"
|
||||
ps.size.forM fun i => do
|
||||
let p := ps[i]!
|
||||
let y := ys[i]!
|
||||
unless paramEqArg p y do
|
||||
emit (toCType p.ty); emit " _tmp_"; emit i; emit " = "; emitArg y; emitLn ";"
|
||||
ps.size.forM fun i => do
|
||||
let p := ps[i]!
|
||||
let y := ys[i]!
|
||||
unless paramEqArg p y do emit p.x; emit " = _tmp_"; emit i; emitLn ";"
|
||||
emitLn "}"
|
||||
if h : ps.size = ys.size then
|
||||
if overwriteParam ps ys then
|
||||
emitLn "{"
|
||||
ps.size.forM fun i _ => do
|
||||
let p := ps[i]
|
||||
let y := ys[i]
|
||||
unless paramEqArg p y do
|
||||
emit (toCType p.ty); emit " _tmp_"; emit i; emit " = "; emitArg y; emitLn ";"
|
||||
ps.size.forM fun i _ => do
|
||||
let p := ps[i]
|
||||
let y := ys[i]
|
||||
unless paramEqArg p y do emit p.x; emit " = _tmp_"; emit i; emitLn ";"
|
||||
emitLn "}"
|
||||
else
|
||||
ys.size.forM fun i _ => do
|
||||
let p := ps[i]
|
||||
let y := ys[i]
|
||||
unless paramEqArg p y do emit p.x; emit " = "; emitArg y; emitLn ";"
|
||||
emitLn "goto _start;"
|
||||
else
|
||||
ys.size.forM fun i => do
|
||||
let p := ps[i]!
|
||||
let y := ys[i]!
|
||||
unless paramEqArg p y do emit p.x; emit " = "; emitArg y; emitLn ";"
|
||||
emitLn "goto _start;"
|
||||
throw "invalid tail call"
|
||||
| _ => throw "bug at emitTailCall"
|
||||
|
||||
mutual
|
||||
@@ -654,16 +658,16 @@ def emitDeclAux (d : Decl) : M Unit := do
|
||||
if xs.size > closureMaxArgs && isBoxedName d.name then
|
||||
emit "lean_object** _args"
|
||||
else
|
||||
xs.size.forM fun i => do
|
||||
xs.size.forM fun i _ => do
|
||||
if i > 0 then emit ", "
|
||||
let x := xs[i]!
|
||||
let x := xs[i]
|
||||
emit (toCType x.ty); emit " "; emit x.x
|
||||
emit ")"
|
||||
else
|
||||
emit ("_init_" ++ baseName ++ "()")
|
||||
emitLn " {";
|
||||
if xs.size > closureMaxArgs && isBoxedName d.name then
|
||||
xs.size.forM fun i => do
|
||||
xs.size.forM fun i _ => do
|
||||
let x := xs[i]!
|
||||
emit "lean_object* "; emit x.x; emit " = _args["; emit i; emitLn "];"
|
||||
emitLn "_start:";
|
||||
|
||||
@@ -571,9 +571,9 @@ def emitAllocCtor (builder : LLVM.Builder llvmctx)
|
||||
|
||||
def emitCtorSetArgs (builder : LLVM.Builder llvmctx)
|
||||
(z : VarId) (ys : Array Arg) : M llvmctx Unit := do
|
||||
ys.size.forM fun i => do
|
||||
ys.size.forM fun i _ => do
|
||||
let zv ← emitLhsVal builder z
|
||||
let (_yty, yv) ← emitArgVal builder ys[i]!
|
||||
let (_yty, yv) ← emitArgVal builder ys[i]
|
||||
let iv ← constIntUnsigned i
|
||||
callLeanCtorSet builder zv iv yv
|
||||
emitLhsSlotStore builder z zv
|
||||
@@ -702,8 +702,8 @@ def emitPartialApp (builder : LLVM.Builder llvmctx) (z : VarId) (f : FunId) (ys
|
||||
(← constIntUnsigned arity)
|
||||
(← constIntUnsigned ys.size)
|
||||
LLVM.buildStore builder zval zslot
|
||||
ys.size.forM fun i => do
|
||||
let (yty, yslot) ← emitArgSlot_ builder ys[i]!
|
||||
ys.size.forM fun i _ => do
|
||||
let (yty, yslot) ← emitArgSlot_ builder ys[i]
|
||||
let yval ← LLVM.buildLoad2 builder yty yslot
|
||||
callLeanClosureSetFn builder zval (← constIntUnsigned i) yval
|
||||
|
||||
@@ -922,7 +922,7 @@ def emitReset (builder : LLVM.Builder llvmctx) (z : VarId) (n : Nat) (x : VarId)
|
||||
buildIfThenElse_ builder "isExclusive" isExclusive
|
||||
(fun builder => do
|
||||
let xv ← emitLhsVal builder x
|
||||
n.forM fun i => do
|
||||
n.forM fun i _ => do
|
||||
callLeanCtorRelease builder xv (← constIntUnsigned i)
|
||||
emitLhsSlotStore builder z xv
|
||||
return ShouldForwardControlFlow.yes
|
||||
@@ -1172,8 +1172,8 @@ def emitFnArgs (builder : LLVM.Builder llvmctx)
|
||||
(needsPackedArgs? : Bool) (llvmfn : LLVM.Value llvmctx) (params : Array Param) : M llvmctx Unit := do
|
||||
if needsPackedArgs? then do
|
||||
let argsp ← LLVM.getParam llvmfn 0 -- lean_object **args
|
||||
for i in List.range params.size do
|
||||
let param := params[i]!
|
||||
for h : i in [:params.size] do
|
||||
let param := params[i]
|
||||
-- argsi := (args + i)
|
||||
let argsi ← LLVM.buildGEP2 builder (← LLVM.voidPtrType llvmctx) argsp #[← constIntUnsigned i] s!"packed_arg_{i}_slot"
|
||||
let llvmty ← toLLVMType param.ty
|
||||
@@ -1182,15 +1182,16 @@ def emitFnArgs (builder : LLVM.Builder llvmctx)
|
||||
-- slot for arg[i] which is always void* ?
|
||||
let alloca ← buildPrologueAlloca builder llvmty s!"arg_{i}"
|
||||
LLVM.buildStore builder pv alloca
|
||||
addVartoState params[i]!.x alloca llvmty
|
||||
addVartoState param.x alloca llvmty
|
||||
else
|
||||
let n ← LLVM.countParams llvmfn
|
||||
for i in (List.range n.toNat) do
|
||||
let llvmty ← toLLVMType params[i]!.ty
|
||||
for i in [:n.toNat] do
|
||||
let param := params[i]!
|
||||
let llvmty ← toLLVMType param.ty
|
||||
let alloca ← buildPrologueAlloca builder llvmty s!"arg_{i}"
|
||||
let arg ← LLVM.getParam llvmfn (UInt64.ofNat i)
|
||||
let _ ← LLVM.buildStore builder arg alloca
|
||||
addVartoState params[i]!.x alloca llvmty
|
||||
addVartoState param.x alloca llvmty
|
||||
|
||||
def emitDeclAux (mod : LLVM.Module llvmctx) (builder : LLVM.Builder llvmctx) (d : Decl) : M llvmctx Unit := do
|
||||
let env ← getEnv
|
||||
|
||||
@@ -54,7 +54,7 @@ abbrev Mask := Array (Option VarId)
|
||||
partial def eraseProjIncForAux (y : VarId) (bs : Array FnBody) (mask : Mask) (keep : Array FnBody) : Array FnBody × Mask :=
|
||||
let done (_ : Unit) := (bs ++ keep.reverse, mask)
|
||||
let keepInstr (b : FnBody) := eraseProjIncForAux y bs.pop mask (keep.push b)
|
||||
if bs.size < 2 then done ()
|
||||
if h : bs.size < 2 then done ()
|
||||
else
|
||||
let b := bs.back!
|
||||
match b with
|
||||
@@ -62,7 +62,7 @@ partial def eraseProjIncForAux (y : VarId) (bs : Array FnBody) (mask : Mask) (ke
|
||||
| .vdecl _ _ (.uproj _ _) _ => keepInstr b
|
||||
| .inc z n c p _ =>
|
||||
if n == 0 then done () else
|
||||
let b' := bs[bs.size - 2]!
|
||||
let b' := bs[bs.size - 2]
|
||||
match b' with
|
||||
| .vdecl w _ (.proj i x) _ =>
|
||||
if w == z && y == x then
|
||||
@@ -134,15 +134,15 @@ abbrev M := ReaderT Context (StateM Nat)
|
||||
modifyGet fun n => ({ idx := n }, n + 1)
|
||||
|
||||
def releaseUnreadFields (y : VarId) (mask : Mask) (b : FnBody) : M FnBody :=
|
||||
mask.size.foldM (init := b) fun i b =>
|
||||
match mask.get! i with
|
||||
mask.size.foldM (init := b) fun i _ b =>
|
||||
match mask[i] with
|
||||
| some _ => pure b -- code took ownership of this field
|
||||
| none => do
|
||||
let fld ← mkFresh
|
||||
pure (FnBody.vdecl fld IRType.object (Expr.proj i y) (FnBody.dec fld 1 true false b))
|
||||
|
||||
def setFields (y : VarId) (zs : Array Arg) (b : FnBody) : FnBody :=
|
||||
zs.size.fold (init := b) fun i b => FnBody.set y i (zs.get! i) b
|
||||
zs.size.fold (init := b) fun i _ b => FnBody.set y i zs[i] b
|
||||
|
||||
/-- Given `set x[i] := y`, return true iff `y := proj[i] x` -/
|
||||
def isSelfSet (ctx : Context) (x : VarId) (i : Nat) (y : Arg) : Bool :=
|
||||
|
||||
@@ -79,13 +79,13 @@ private def addDecForAlt (ctx : Context) (caseLiveVars altLiveVars : LiveVarSet)
|
||||
/-- `isFirstOcc xs x i = true` if `xs[i]` is the first occurrence of `xs[i]` in `xs` -/
|
||||
private def isFirstOcc (xs : Array Arg) (i : Nat) : Bool :=
|
||||
let x := xs[i]!
|
||||
i.all fun j => xs[j]! != x
|
||||
i.all fun j _ => xs[j]! != x
|
||||
|
||||
/-- Return true if `x` also occurs in `ys` in a position that is not consumed.
|
||||
That is, it is also passed as a borrow reference. -/
|
||||
private def isBorrowParamAux (x : VarId) (ys : Array Arg) (consumeParamPred : Nat → Bool) : Bool :=
|
||||
ys.size.any fun i =>
|
||||
let y := ys[i]!
|
||||
ys.size.any fun i _ =>
|
||||
let y := ys[i]
|
||||
match y with
|
||||
| Arg.irrelevant => false
|
||||
| Arg.var y => x == y && !consumeParamPred i
|
||||
@@ -99,15 +99,15 @@ Return `n`, the number of times `x` is consumed.
|
||||
- `consumeParamPred i = true` if parameter `i` is consumed.
|
||||
-/
|
||||
private def getNumConsumptions (x : VarId) (ys : Array Arg) (consumeParamPred : Nat → Bool) : Nat :=
|
||||
ys.size.fold (init := 0) fun i n =>
|
||||
let y := ys[i]!
|
||||
ys.size.fold (init := 0) fun i _ n =>
|
||||
let y := ys[i]
|
||||
match y with
|
||||
| Arg.irrelevant => n
|
||||
| Arg.var y => if x == y && consumeParamPred i then n+1 else n
|
||||
|
||||
private def addIncBeforeAux (ctx : Context) (xs : Array Arg) (consumeParamPred : Nat → Bool) (b : FnBody) (liveVarsAfter : LiveVarSet) : FnBody :=
|
||||
xs.size.fold (init := b) fun i b =>
|
||||
let x := xs[i]!
|
||||
xs.size.fold (init := b) fun i _ b =>
|
||||
let x := xs[i]
|
||||
match x with
|
||||
| Arg.irrelevant => b
|
||||
| Arg.var x =>
|
||||
@@ -128,8 +128,8 @@ private def addIncBefore (ctx : Context) (xs : Array Arg) (ps : Array Param) (b
|
||||
|
||||
/-- See `addIncBeforeAux`/`addIncBefore` for the procedure that inserts `inc` operations before an application. -/
|
||||
private def addDecAfterFullApp (ctx : Context) (xs : Array Arg) (ps : Array Param) (b : FnBody) (bLiveVars : LiveVarSet) : FnBody :=
|
||||
xs.size.fold (init := b) fun i b =>
|
||||
match xs[i]! with
|
||||
xs.size.fold (init := b) fun i _ b =>
|
||||
match xs[i] with
|
||||
| Arg.irrelevant => b
|
||||
| Arg.var x =>
|
||||
/- We must add a `dec` if `x` must be consumed, it is alive after the application,
|
||||
|
||||
@@ -366,10 +366,10 @@ to be updated.
|
||||
@[implemented_by updateFunDeclCoreImp] opaque FunDeclCore.updateCore (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : FunDecl
|
||||
|
||||
def CasesCore.extractAlt! (cases : Cases) (ctorName : Name) : Alt × Cases :=
|
||||
let found (i : Nat) := (cases.alts[i]!, { cases with alts := cases.alts.eraseIdx i })
|
||||
if let some i := cases.alts.findIdx? fun | .alt ctorName' .. => ctorName == ctorName' | _ => false then
|
||||
let found i := (cases.alts[i], { cases with alts := cases.alts.eraseIdx i })
|
||||
if let some i := cases.alts.findFinIdx? fun | .alt ctorName' .. => ctorName == ctorName' | _ => false then
|
||||
found i
|
||||
else if let some i := cases.alts.findIdx? fun | .default _ => true | _ => false then
|
||||
else if let some i := cases.alts.findFinIdx? fun | .default _ => true | _ => false then
|
||||
found i
|
||||
else
|
||||
unreachable!
|
||||
|
||||
@@ -587,15 +587,15 @@ def Decl.elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do
|
||||
refer to the docstring of `Decl.safe`.
|
||||
-/
|
||||
if decls[i]!.safe then .bot else .top
|
||||
let mut funVals := decls.size.fold (init := .empty) fun i p => p.push (initialVal i)
|
||||
let mut funVals := decls.size.fold (init := .empty) fun i _ p => p.push (initialVal i)
|
||||
let ctx := { decls }
|
||||
let mut state := { assignments, funVals }
|
||||
(_, state) ← inferMain |>.run ctx |>.run state
|
||||
funVals := state.funVals
|
||||
assignments := state.assignments
|
||||
modifyEnv fun e =>
|
||||
decls.size.fold (init := e) fun i env =>
|
||||
addFunctionSummary env decls[i]!.name funVals[i]!
|
||||
decls.size.fold (init := e) fun i _ env =>
|
||||
addFunctionSummary env decls[i].name funVals[i]!
|
||||
|
||||
decls.mapIdxM fun i decl => if decl.safe then elimDead assignments[i]! decl else return decl
|
||||
|
||||
|
||||
@@ -76,8 +76,8 @@ def getType (fvarId : FVarId) : InferTypeM Expr := do
|
||||
|
||||
def mkForallFVars (xs : Array Expr) (type : Expr) : InferTypeM Expr :=
|
||||
let b := type.abstract xs
|
||||
xs.size.foldRevM (init := b) fun i b => do
|
||||
let x := xs[i]!
|
||||
xs.size.foldRevM (init := b) fun i _ b => do
|
||||
let x := xs[i]
|
||||
let n ← InferType.getBinderName x.fvarId!
|
||||
let ty ← InferType.getType x.fvarId!
|
||||
let ty := ty.abstractRange i xs;
|
||||
|
||||
@@ -134,9 +134,9 @@ def withEachOccurrence (targetName : Name) (f : Nat → PassInstaller) : PassIns
|
||||
|
||||
def installAfter (targetName : Name) (p : Pass → Pass) (occurrence : Nat := 0) : PassInstaller where
|
||||
install passes :=
|
||||
if let some idx := passes.findIdx? (fun p => p.name == targetName && p.occurrence == occurrence) then
|
||||
let passUnderTest := passes[idx]!
|
||||
return passes.insertAt! (idx + 1) (p passUnderTest)
|
||||
if let some idx := passes.findFinIdx? (fun p => p.name == targetName && p.occurrence == occurrence) then
|
||||
let passUnderTest := passes[idx]
|
||||
return passes.insertIdx (idx + 1) (p passUnderTest)
|
||||
else
|
||||
throwError s!"Tried to insert pass after {targetName}, occurrence {occurrence} but {targetName} is not in the pass list"
|
||||
|
||||
@@ -145,9 +145,9 @@ def installAfterEach (targetName : Name) (p : Pass → Pass) : PassInstaller :=
|
||||
|
||||
def installBefore (targetName : Name) (p : Pass → Pass) (occurrence : Nat := 0): PassInstaller where
|
||||
install passes :=
|
||||
if let some idx := passes.findIdx? (fun p => p.name == targetName && p.occurrence == occurrence) then
|
||||
let passUnderTest := passes[idx]!
|
||||
return passes.insertAt! idx (p passUnderTest)
|
||||
if let some idx := passes.findFinIdx? (fun p => p.name == targetName && p.occurrence == occurrence) then
|
||||
let passUnderTest := passes[idx]
|
||||
return passes.insertIdx idx (p passUnderTest)
|
||||
else
|
||||
throwError s!"Tried to insert pass after {targetName}, occurrence {occurrence} but {targetName} is not in the pass list"
|
||||
|
||||
@@ -157,9 +157,7 @@ def installBeforeEachOccurrence (targetName : Name) (p : Pass → Pass) : PassIn
|
||||
def replacePass (targetName : Name) (p : Pass → Pass) (occurrence : Nat := 0) : PassInstaller where
|
||||
install passes := do
|
||||
let some idx := passes.findIdx? (fun p => p.name == targetName && p.occurrence == occurrence) | throwError s!"Tried to replace {targetName}, occurrence {occurrence} but {targetName} is not in the pass list"
|
||||
let target := passes[idx]!
|
||||
let replacement := p target
|
||||
return passes.set! idx replacement
|
||||
return passes.modify idx p
|
||||
|
||||
def replaceEachOccurrence (targetName : Name) (p : Pass → Pass) : PassInstaller :=
|
||||
withEachOccurrence targetName (replacePass targetName p ·)
|
||||
|
||||
@@ -152,8 +152,8 @@ def saveSpecParamInfo (decls : Array Decl) : CompilerM Unit := do
|
||||
let specArgs? := getSpecializationArgs? (← getEnv) decl.name
|
||||
let contains (i : Nat) : Bool := specArgs?.getD #[] |>.contains i
|
||||
let mut paramsInfo : Array SpecParamInfo := #[]
|
||||
for i in [:decl.params.size] do
|
||||
let param := decl.params[i]!
|
||||
for h :i in [:decl.params.size] do
|
||||
let param := decl.params[i]
|
||||
let info ←
|
||||
if contains i then
|
||||
pure .user
|
||||
@@ -181,14 +181,14 @@ def saveSpecParamInfo (decls : Array Decl) : CompilerM Unit := do
|
||||
declsInfo := declsInfo.push paramsInfo
|
||||
if declsInfo.any fun paramsInfo => paramsInfo.any (· matches .user | .fixedInst | .fixedHO) then
|
||||
let m := mkFixedParamsMap decls
|
||||
for i in [:decls.size] do
|
||||
let decl := decls[i]!
|
||||
for hi : i in [:decls.size] do
|
||||
let decl := decls[i]
|
||||
let mut paramsInfo := declsInfo[i]!
|
||||
let some mask := m.find? decl.name | unreachable!
|
||||
trace[Compiler.specialize.info] "{decl.name} {mask}"
|
||||
paramsInfo := paramsInfo.zipWith mask fun info fixed => if fixed || info matches .user then info else .other
|
||||
for j in [:paramsInfo.size] do
|
||||
let mut info := paramsInfo[j]!
|
||||
let mut info := paramsInfo[j]!
|
||||
if info matches .fixedNeutral && !hasFwdDeps decl paramsInfo j then
|
||||
paramsInfo := paramsInfo.set! j .other
|
||||
if paramsInfo.any fun info => info matches .fixedInst | .fixedHO | .user then
|
||||
|
||||
@@ -499,8 +499,8 @@ where
|
||||
match app with
|
||||
| .fvar f =>
|
||||
let mut argsNew := #[]
|
||||
for i in [arity : args.size] do
|
||||
argsNew := argsNew.push (← visitAppArg args[i]!)
|
||||
for h :i in [arity : args.size] do
|
||||
argsNew := argsNew.push (← visitAppArg args[i])
|
||||
letValueToArg <| .fvar f argsNew
|
||||
| .erased | .type .. => return .erased
|
||||
|
||||
|
||||
@@ -26,13 +26,14 @@ private def elabSpecArgs (declName : Name) (args : Array Syntax) : MetaM (Array
|
||||
if let some idx := arg.isNatLit? then
|
||||
if idx == 0 then throwErrorAt arg "invalid specialization argument index, index must be greater than 0"
|
||||
let idx := idx - 1
|
||||
if idx >= argNames.size then
|
||||
if h : idx >= argNames.size then
|
||||
throwErrorAt arg "invalid argument index, `{declName}` has #{argNames.size} arguments"
|
||||
if result.contains idx then throwErrorAt arg "invalid specialization argument index, `{argNames[idx]!}` has already been specified as a specialization candidate"
|
||||
result := result.push idx
|
||||
else
|
||||
if result.contains idx then throwErrorAt arg "invalid specialization argument index, `{argNames[idx]}` has already been specified as a specialization candidate"
|
||||
result := result.push idx
|
||||
else
|
||||
let argName := arg.getId
|
||||
if let some idx := argNames.getIdx? argName then
|
||||
if let some idx := argNames.indexOf? argName then
|
||||
if result.contains idx then throwErrorAt arg "invalid specialization argument name `{argName}`, it has already been specified as a specialization candidate"
|
||||
result := result.push idx
|
||||
else
|
||||
|
||||
@@ -11,6 +11,7 @@ import Lean.ResolveName
|
||||
import Lean.Elab.InfoTree.Types
|
||||
import Lean.MonadEnv
|
||||
import Lean.Elab.Exception
|
||||
import Lean.Language.Basic
|
||||
|
||||
namespace Lean
|
||||
register_builtin_option diagnostics : Bool := {
|
||||
@@ -72,6 +73,13 @@ structure State where
|
||||
messages : MessageLog := {}
|
||||
/-- Info tree. We have the info tree here because we want to update it while adding attributes. -/
|
||||
infoState : Elab.InfoState := {}
|
||||
/--
|
||||
Snapshot trees of asynchronous subtasks. As these are untyped and reported only at the end of the
|
||||
command's main elaboration thread, they are only useful for basic message log reporting; for
|
||||
incremental reporting and reuse within a long-running elaboration thread, types rooted in
|
||||
`CommandParsedSnapshot` need to be adjusted.
|
||||
-/
|
||||
snapshotTasks : Array (Language.SnapshotTask Language.SnapshotTree) := #[]
|
||||
deriving Nonempty
|
||||
|
||||
/-- Context for the CoreM monad. -/
|
||||
@@ -180,7 +188,8 @@ instance : Elab.MonadInfoTree CoreM where
|
||||
modifyInfoState f := modify fun s => { s with infoState := f s.infoState }
|
||||
|
||||
@[inline] def modifyCache (f : Cache → Cache) : CoreM Unit :=
|
||||
modify fun ⟨env, next, ngen, trace, cache, messages, infoState⟩ => ⟨env, next, ngen, trace, f cache, messages, infoState⟩
|
||||
modify fun ⟨env, next, ngen, trace, cache, messages, infoState, snaps⟩ =>
|
||||
⟨env, next, ngen, trace, f cache, messages, infoState, snaps⟩
|
||||
|
||||
@[inline] def modifyInstLevelTypeCache (f : InstantiateLevelCache → InstantiateLevelCache) : CoreM Unit :=
|
||||
modifyCache fun ⟨c₁, c₂⟩ => ⟨f c₁, c₂⟩
|
||||
@@ -355,13 +364,83 @@ instance : MonadLog CoreM where
|
||||
if (← read).suppressElabErrors then
|
||||
-- discard elaboration errors, except for a few important and unlikely misleading ones, on
|
||||
-- parse error
|
||||
unless msg.data.hasTag (· matches `Elab.synthPlaceholder | `Tactic.unsolvedGoals) do
|
||||
unless msg.data.hasTag (· matches `Elab.synthPlaceholder | `Tactic.unsolvedGoals | `trace) do
|
||||
return
|
||||
|
||||
let ctx ← read
|
||||
let msg := { msg with data := MessageData.withNamingContext { currNamespace := ctx.currNamespace, openDecls := ctx.openDecls } msg.data };
|
||||
modify fun s => { s with messages := s.messages.add msg }
|
||||
|
||||
/--
|
||||
Includes a given task (such as from `wrapAsyncAsSnapshot`) in the overall snapshot tree for this
|
||||
command's elaboration, making its result available to reporting and the language server. The
|
||||
reporter will not know about this snapshot tree node until the main elaboration thread for this
|
||||
command has finished so this function is not useful for incremental reporting within a longer
|
||||
elaboration thread but only for tasks that outlive it such as background kernel checking or proof
|
||||
elaboration.
|
||||
-/
|
||||
def logSnapshotTask (task : Language.SnapshotTask Language.SnapshotTree) : CoreM Unit :=
|
||||
modify fun s => { s with snapshotTasks := s.snapshotTasks.push task }
|
||||
|
||||
/-- Wraps the given action for use in `EIO.asTask` etc., discarding its final monadic state. -/
|
||||
def wrapAsync (act : Unit → CoreM α) : CoreM (EIO Exception α) := do
|
||||
let st ← get
|
||||
let ctx ← read
|
||||
let heartbeats := (← IO.getNumHeartbeats) - ctx.initHeartbeats
|
||||
return withCurrHeartbeats (do
|
||||
-- include heartbeats since start of elaboration in new thread as well such that forking off
|
||||
-- an action doesn't suddenly allow it to succeed from a lower heartbeat count
|
||||
IO.addHeartbeats heartbeats.toUInt64
|
||||
act () : CoreM _)
|
||||
|>.run' ctx st
|
||||
|
||||
/-- Option for capturing output to stderr during elaboration. -/
|
||||
register_builtin_option stderrAsMessages : Bool := {
|
||||
defValue := true
|
||||
group := "server"
|
||||
descr := "(server) capture output to the Lean stderr channel (such as from `dbg_trace`) during elaboration of a command as a diagnostic message"
|
||||
}
|
||||
|
||||
open Language in
|
||||
/--
|
||||
Wraps the given action for use in `BaseIO.asTask` etc., discarding its final state except for
|
||||
`logSnapshotTask` tasks, which are reported as part of the returned tree.
|
||||
-/
|
||||
def wrapAsyncAsSnapshot (act : Unit → CoreM Unit) (desc : String := by exact decl_name%.toString) :
|
||||
CoreM (BaseIO SnapshotTree) := do
|
||||
let t ← wrapAsync fun _ => do
|
||||
IO.FS.withIsolatedStreams (isolateStderr := stderrAsMessages.get (← getOptions)) do
|
||||
let tid ← IO.getTID
|
||||
-- reset trace state and message log so as not to report them twice
|
||||
modify ({ · with messages := {}, traceState := { tid } })
|
||||
try
|
||||
withTraceNode `Elab.async (fun _ => return desc) do
|
||||
act ()
|
||||
catch e =>
|
||||
logError e.toMessageData
|
||||
finally
|
||||
addTraceAsMessages
|
||||
get
|
||||
let ctx ← readThe Core.Context
|
||||
return do
|
||||
match (← t.toBaseIO) with
|
||||
| .ok (output, st) =>
|
||||
let mut msgs := st.messages
|
||||
if !output.isEmpty then
|
||||
msgs := msgs.add {
|
||||
fileName := ctx.fileName
|
||||
severity := MessageSeverity.information
|
||||
pos := ctx.fileMap.toPosition <| ctx.ref.getPos?.getD 0
|
||||
data := output
|
||||
}
|
||||
return .mk {
|
||||
desc
|
||||
diagnostics := (← Language.Snapshot.Diagnostics.ofMessageLog msgs)
|
||||
traces := st.traceState
|
||||
} st.snapshotTasks
|
||||
-- interrupt or abort exception as `try catch` above should have caught any others
|
||||
| .error _ => default
|
||||
|
||||
end Core
|
||||
|
||||
export Core (CoreM mkFreshUserName checkSystem withCurrHeartbeats)
|
||||
|
||||
@@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Init.Data.Nat.Fold
|
||||
import Init.Data.Array.Basic
|
||||
import Init.NotationExtra
|
||||
import Init.Data.ToString.Macro
|
||||
@@ -371,7 +372,7 @@ instance : ToString Stats := ⟨Stats.toString⟩
|
||||
end PersistentArray
|
||||
|
||||
def mkPersistentArray {α : Type u} (n : Nat) (v : α) : PArray α :=
|
||||
n.fold (init := PersistentArray.empty) fun _ p => p.push v
|
||||
n.fold (init := PersistentArray.empty) fun _ _ p => p.push v
|
||||
|
||||
@[inline] def mkPArray {α : Type u} (n : Nat) (v : α) : PArray α :=
|
||||
mkPersistentArray n v
|
||||
|
||||
@@ -233,10 +233,10 @@ partial def eraseAux [BEq α] : Node α β → USize → α → Node α β
|
||||
| n@(Node.collision keys vals heq), _, k =>
|
||||
match keys.indexOf? k with
|
||||
| some idx =>
|
||||
let keys' := keys.feraseIdx idx
|
||||
have keq := keys.size_feraseIdx idx
|
||||
let vals' := vals.feraseIdx (Eq.ndrec idx heq)
|
||||
have veq := vals.size_feraseIdx (Eq.ndrec idx heq)
|
||||
let keys' := keys.eraseIdx idx
|
||||
have keq := keys.size_eraseIdx idx _
|
||||
let vals' := vals.eraseIdx (Eq.ndrec idx heq)
|
||||
have veq := vals.size_eraseIdx (Eq.ndrec idx heq) _
|
||||
have : keys.size - 1 = vals.size - 1 := by rw [heq]
|
||||
Node.collision keys' vals' (keq.trans (this.trans veq.symm))
|
||||
| none => n
|
||||
|
||||
@@ -23,6 +23,7 @@ import Lean.Elab.Quotation
|
||||
import Lean.Elab.Syntax
|
||||
import Lean.Elab.Do
|
||||
import Lean.Elab.StructInst
|
||||
import Lean.Elab.MutualInductive
|
||||
import Lean.Elab.Inductive
|
||||
import Lean.Elab.Structure
|
||||
import Lean.Elab.Print
|
||||
|
||||
@@ -807,8 +807,8 @@ def getElabElimExprInfo (elimExpr : Expr) : MetaM ElabElimInfo := do
|
||||
These are the primary set of major parameters.
|
||||
-/
|
||||
let initMotiveFVars : CollectFVars.State := motiveArgs.foldl (init := {}) collectFVars
|
||||
let motiveFVars ← xs.size.foldRevM (init := initMotiveFVars) fun i s => do
|
||||
let x := xs[i]!
|
||||
let motiveFVars ← xs.size.foldRevM (init := initMotiveFVars) fun i _ s => do
|
||||
let x := xs[i]
|
||||
if s.fvarSet.contains x.fvarId! then
|
||||
return collectFVars s (← inferType x)
|
||||
else
|
||||
@@ -1347,7 +1347,7 @@ where
|
||||
let mut unusableNamedArgs := unusableNamedArgs
|
||||
for x in xs, bInfo in bInfos do
|
||||
let xDecl ← x.mvarId!.getDecl
|
||||
if let some idx := remainingNamedArgs.findIdx? (·.name == xDecl.userName) then
|
||||
if let some idx := remainingNamedArgs.findFinIdx? (·.name == xDecl.userName) then
|
||||
/- If there is named argument with name `xDecl.userName`, then it is accounted for and we can't make use of it. -/
|
||||
remainingNamedArgs := remainingNamedArgs.eraseIdx idx
|
||||
else
|
||||
@@ -1355,9 +1355,9 @@ where
|
||||
/- We found a type of the form (baseName ...).
|
||||
First, we check if the current argument is an explicit one,
|
||||
and if the current explicit position "fits" at `args` (i.e., it must be ≤ arg.size) -/
|
||||
if argIdx ≤ args.size && bInfo.isExplicit then
|
||||
if h : argIdx ≤ args.size ∧ bInfo.isExplicit then
|
||||
/- We can insert `e` as an explicit argument -/
|
||||
return (args.insertAt! argIdx (Arg.expr e), namedArgs)
|
||||
return (args.insertIdx argIdx (Arg.expr e), namedArgs)
|
||||
else
|
||||
/- If we can't add `e` to `args`, we try to add it using a named argument, but this is only possible
|
||||
if there isn't an argument with the same name occurring before it. -/
|
||||
|
||||
@@ -211,7 +211,7 @@ private def replaceBinderAnnotation (binder : TSyntax ``Parser.Term.bracketedBin
|
||||
else
|
||||
`(bracketedBinderF| {$id $[: $ty?]?})
|
||||
for id in ids.reverse do
|
||||
if let some idx := binderIds.findIdx? fun binderId => binderId.raw.isIdent && binderId.raw.getId == id.raw.getId then
|
||||
if let some idx := binderIds.findFinIdx? fun binderId => binderId.raw.isIdent && binderId.raw.getId == id.raw.getId then
|
||||
binderIds := binderIds.eraseIdx idx
|
||||
modifiedVarDecls := true
|
||||
varDeclsNew := varDeclsNew.push (← mkBinder id explicit)
|
||||
|
||||
@@ -84,6 +84,7 @@ structure State where
|
||||
ngen : NameGenerator := {}
|
||||
infoState : InfoState := {}
|
||||
traceState : TraceState := {}
|
||||
snapshotTasks : Array (Language.SnapshotTask Language.SnapshotTree) := #[]
|
||||
deriving Nonempty
|
||||
|
||||
structure Context where
|
||||
@@ -114,8 +115,7 @@ structure Context where
|
||||
-/
|
||||
suppressElabErrors : Bool := false
|
||||
|
||||
abbrev CommandElabCoreM (ε) := ReaderT Context $ StateRefT State $ EIO ε
|
||||
abbrev CommandElabM := CommandElabCoreM Exception
|
||||
abbrev CommandElabM := ReaderT Context $ StateRefT State $ EIO Exception
|
||||
abbrev CommandElab := Syntax → CommandElabM Unit
|
||||
structure Linter where
|
||||
run : Syntax → CommandElabM Unit
|
||||
@@ -198,36 +198,6 @@ instance : AddErrorMessageContext CommandElabM where
|
||||
let msg ← addMacroStack msg ctx.macroStack
|
||||
return (ref, msg)
|
||||
|
||||
def mkMessageAux (ctx : Context) (ref : Syntax) (msgData : MessageData) (severity : MessageSeverity) : Message :=
|
||||
let pos := ref.getPos?.getD ctx.cmdPos
|
||||
let endPos := ref.getTailPos?.getD pos
|
||||
mkMessageCore ctx.fileName ctx.fileMap msgData severity pos endPos
|
||||
|
||||
private def addTraceAsMessagesCore (ctx : Context) (log : MessageLog) (traceState : TraceState) : MessageLog := Id.run do
|
||||
if traceState.traces.isEmpty then return log
|
||||
let mut traces : Std.HashMap (String.Pos × String.Pos) (Array MessageData) := ∅
|
||||
for traceElem in traceState.traces do
|
||||
let ref := replaceRef traceElem.ref ctx.ref
|
||||
let pos := ref.getPos?.getD 0
|
||||
let endPos := ref.getTailPos?.getD pos
|
||||
traces := traces.insert (pos, endPos) <| traces.getD (pos, endPos) #[] |>.push traceElem.msg
|
||||
let mut log := log
|
||||
let traces' := traces.toArray.qsort fun ((a, _), _) ((b, _), _) => a < b
|
||||
for ((pos, endPos), traceMsg) in traces' do
|
||||
let data := .tagged `trace <| .joinSep traceMsg.toList "\n"
|
||||
log := log.add <| mkMessageCore ctx.fileName ctx.fileMap data .information pos endPos
|
||||
return log
|
||||
|
||||
private def addTraceAsMessages : CommandElabM Unit := do
|
||||
let ctx ← read
|
||||
-- do not add trace messages if `trace.profiler.output` is set as it would be redundant and
|
||||
-- pretty printing the trace messages is expensive
|
||||
if trace.profiler.output.get? (← getOptions) |>.isNone then
|
||||
modify fun s => { s with
|
||||
messages := addTraceAsMessagesCore ctx s.messages s.traceState
|
||||
traceState.traces := {}
|
||||
}
|
||||
|
||||
private def runCore (x : CoreM α) : CommandElabM α := do
|
||||
let s ← get
|
||||
let ctx ← read
|
||||
@@ -253,6 +223,7 @@ private def runCore (x : CoreM α) : CommandElabM α := do
|
||||
nextMacroScope := s.nextMacroScope
|
||||
infoState.enabled := s.infoState.enabled
|
||||
traceState := s.traceState
|
||||
snapshotTasks := s.snapshotTasks
|
||||
}
|
||||
let (ea, coreS) ← liftM x
|
||||
modify fun s => { s with
|
||||
@@ -261,6 +232,7 @@ private def runCore (x : CoreM α) : CommandElabM α := do
|
||||
ngen := coreS.ngen
|
||||
infoState.trees := s.infoState.trees.append coreS.infoState.trees
|
||||
traceState.traces := coreS.traceState.traces.map fun t => { t with ref := replaceRef t.ref ctx.ref }
|
||||
snapshotTasks := coreS.snapshotTasks
|
||||
messages := s.messages ++ coreS.messages
|
||||
}
|
||||
return ea
|
||||
@@ -268,10 +240,6 @@ private def runCore (x : CoreM α) : CommandElabM α := do
|
||||
def liftCoreM (x : CoreM α) : CommandElabM α := do
|
||||
MonadExcept.ofExcept (← runCore (observing x))
|
||||
|
||||
private def ioErrorToMessage (ctx : Context) (ref : Syntax) (err : IO.Error) : Message :=
|
||||
let ref := getBetterRef ref ctx.macroStack
|
||||
mkMessageAux ctx ref (toString err) MessageSeverity.error
|
||||
|
||||
@[inline] def liftIO {α} (x : IO α) : CommandElabM α := do
|
||||
let ctx ← read
|
||||
IO.toEIO (fun (ex : IO.Error) => Exception.error ctx.ref ex.toString) x
|
||||
@@ -294,9 +262,8 @@ instance : MonadLog CommandElabM where
|
||||
logMessage msg := do
|
||||
if (← read).suppressElabErrors then
|
||||
-- discard elaboration errors on parse error
|
||||
-- NOTE: unlike `CoreM`'s `logMessage`, we do not currently have any command-level errors that
|
||||
-- we want to allowlist
|
||||
return
|
||||
unless msg.data.hasTag (· matches `trace) do
|
||||
return
|
||||
let currNamespace ← getCurrNamespace
|
||||
let openDecls ← getOpenDecls
|
||||
let msg := { msg with data := MessageData.withNamingContext { currNamespace := currNamespace, openDecls := openDecls } msg.data }
|
||||
@@ -322,6 +289,61 @@ def runLinters (stx : Syntax) : CommandElabM Unit := do
|
||||
finally
|
||||
modify fun s => { savedState with messages := s.messages }
|
||||
|
||||
/--
|
||||
Catches and logs exceptions occurring in `x`. Unlike `try catch` in `CommandElabM`, this function
|
||||
catches interrupt exceptions as well and thus is intended for use at the top level of elaboration.
|
||||
Interrupt and abort exceptions are caught but not logged.
|
||||
-/
|
||||
@[inline] def withLoggingExceptions (x : CommandElabM Unit) : CommandElabM Unit := fun ctx ref =>
|
||||
EIO.catchExceptions (withLogging x ctx ref) (fun _ => pure ())
|
||||
|
||||
@[inherit_doc Core.wrapAsync]
|
||||
def wrapAsync (act : Unit → CommandElabM α) : CommandElabM (EIO Exception α) := do
|
||||
return act () |>.run (← read) |>.run' (← get)
|
||||
|
||||
open Language in
|
||||
@[inherit_doc Core.wrapAsyncAsSnapshot]
|
||||
-- `CoreM` and `CommandElabM` are too different to meaningfully share this code
|
||||
def wrapAsyncAsSnapshot (act : Unit → CommandElabM Unit)
|
||||
(desc : String := by exact decl_name%.toString) :
|
||||
CommandElabM (BaseIO SnapshotTree) := do
|
||||
let t ← wrapAsync fun _ => do
|
||||
IO.FS.withIsolatedStreams (isolateStderr := Core.stderrAsMessages.get (← getOptions)) do
|
||||
let tid ← IO.getTID
|
||||
-- reset trace state and message log so as not to report them twice
|
||||
modify ({ · with messages := {}, traceState := { tid } })
|
||||
try
|
||||
withTraceNode `Elab.async (fun _ => return desc) do
|
||||
act ()
|
||||
catch e =>
|
||||
logError e.toMessageData
|
||||
finally
|
||||
addTraceAsMessages
|
||||
get
|
||||
let ctx ← read
|
||||
return do
|
||||
match (← t.toBaseIO) with
|
||||
| .ok (output, st) =>
|
||||
let mut msgs := st.messages
|
||||
if !output.isEmpty then
|
||||
msgs := msgs.add {
|
||||
fileName := ctx.fileName
|
||||
severity := MessageSeverity.information
|
||||
pos := ctx.fileMap.toPosition <| ctx.ref.getPos?.getD 0
|
||||
data := output
|
||||
}
|
||||
return .mk {
|
||||
desc
|
||||
diagnostics := (← Language.Snapshot.Diagnostics.ofMessageLog msgs)
|
||||
traces := st.traceState
|
||||
} st.snapshotTasks
|
||||
-- interrupt or abort exception as `try catch` above should have caught any others
|
||||
| .error _ => default
|
||||
|
||||
@[inherit_doc Core.logSnapshotTask]
|
||||
def logSnapshotTask (task : Language.SnapshotTask Language.SnapshotTree) : CommandElabM Unit :=
|
||||
modify fun s => { s with snapshotTasks := s.snapshotTasks.push task }
|
||||
|
||||
protected def getCurrMacroScope : CommandElabM Nat := do pure (← read).currMacroScope
|
||||
protected def getMainModule : CommandElabM Name := do pure (← getEnv).mainModule
|
||||
|
||||
@@ -532,12 +554,6 @@ def elabCommandTopLevel (stx : Syntax) : CommandElabM Unit := withRef stx do pro
|
||||
let mut msgs := (← get).messages
|
||||
for tree in (← getInfoTrees) do
|
||||
trace[Elab.info] (← tree.format)
|
||||
if (← isTracingEnabledFor `Elab.snapshotTree) then
|
||||
if let some snap := (← read).snap? then
|
||||
-- We can assume that the root command snapshot is not involved in parallelism yet, so this
|
||||
-- should be true iff the command supports incrementality
|
||||
if (← IO.hasFinished snap.new.result) then
|
||||
liftCoreM <| Language.ToSnapshotTree.toSnapshotTree snap.new.result.get |>.trace
|
||||
modify fun st => { st with
|
||||
messages := initMsgs ++ msgs
|
||||
infoState := { st.infoState with trees := initInfoTrees ++ st.infoState.trees }
|
||||
@@ -668,14 +684,6 @@ def runTermElabM (elabFn : Array Expr → TermElabM α) : CommandElabM α := do
|
||||
Term.addAutoBoundImplicits' xs someType fun xs _ =>
|
||||
Term.withoutAutoBoundImplicit <| elabFn xs
|
||||
|
||||
/--
|
||||
Catches and logs exceptions occurring in `x`. Unlike `try catch` in `CommandElabM`, this function
|
||||
catches interrupt exceptions as well and thus is intended for use at the top level of elaboration.
|
||||
Interrupt and abort exceptions are caught but not logged.
|
||||
-/
|
||||
@[inline] def withLoggingExceptions (x : CommandElabM Unit) : CommandElabCoreM Empty Unit := fun ctx ref =>
|
||||
EIO.catchExceptions (withLogging x ctx ref) (fun _ => pure ())
|
||||
|
||||
private def liftAttrM {α} (x : AttrM α) : CommandElabM α := do
|
||||
liftCoreM x
|
||||
|
||||
|
||||
@@ -7,9 +7,8 @@ prelude
|
||||
import Lean.Util.CollectLevelParams
|
||||
import Lean.Elab.DeclUtil
|
||||
import Lean.Elab.DefView
|
||||
import Lean.Elab.Inductive
|
||||
import Lean.Elab.Structure
|
||||
import Lean.Elab.MutualDef
|
||||
import Lean.Elab.MutualInductive
|
||||
import Lean.Elab.DeclarationRange
|
||||
namespace Lean.Elab.Command
|
||||
|
||||
@@ -163,15 +162,11 @@ def elabDeclaration : CommandElab := fun stx => do
|
||||
if declKind == ``Lean.Parser.Command.«axiom» then
|
||||
let modifiers ← elabModifiers modifiers
|
||||
elabAxiom modifiers decl
|
||||
else if declKind == ``Lean.Parser.Command.«inductive» then
|
||||
else if declKind == ``Lean.Parser.Command.«inductive»
|
||||
|| declKind == ``Lean.Parser.Command.classInductive
|
||||
|| declKind == ``Lean.Parser.Command.«structure» then
|
||||
let modifiers ← elabModifiers modifiers
|
||||
elabInductive modifiers decl
|
||||
else if declKind == ``Lean.Parser.Command.classInductive then
|
||||
let modifiers ← elabModifiers modifiers
|
||||
elabClassInductive modifiers decl
|
||||
else if declKind == ``Lean.Parser.Command.«structure» then
|
||||
let modifiers ← elabModifiers modifiers
|
||||
elabStructure modifiers decl
|
||||
else
|
||||
throwError "unexpected declaration"
|
||||
|
||||
@@ -278,10 +273,10 @@ def elabMutual : CommandElab := fun stx => do
|
||||
-- only case implementing incrementality currently
|
||||
elabMutualDef stx[1].getArgs
|
||||
else withoutCommandIncrementality true do
|
||||
if isMutualInductive stx then
|
||||
if ← isMutualInductive stx then
|
||||
elabMutualInductive stx[1].getArgs
|
||||
else
|
||||
throwError "invalid mutual block: either all elements of the block must be inductive declarations, or they must all be definitions/theorems/abbrevs"
|
||||
throwError "invalid mutual block: either all elements of the block must be inductive/structure declarations, or they must all be definitions/theorems/abbrevs"
|
||||
|
||||
/- leading_parser "attribute " >> "[" >> sepBy1 (eraseAttr <|> Term.attrInstance) ", " >> "]" >> many1 ident -/
|
||||
@[builtin_command_elab «attribute»] def elabAttr : CommandElab := fun stx => do
|
||||
|
||||
@@ -49,9 +49,9 @@ invoking ``mkInstImplicitBinders `BarClass foo #[`α, `n, `β]`` gives `` `([Bar
|
||||
def mkInstImplicitBinders (className : Name) (indVal : InductiveVal) (argNames : Array Name) : TermElabM (Array Syntax) :=
|
||||
forallBoundedTelescope indVal.type indVal.numParams fun xs _ => do
|
||||
let mut binders := #[]
|
||||
for i in [:xs.size] do
|
||||
for h : i in [:xs.size] do
|
||||
try
|
||||
let x := xs[i]!
|
||||
let x := xs[i]
|
||||
let c ← mkAppM className #[x]
|
||||
if (← isTypeCorrect c) then
|
||||
let argName := argNames[i]!
|
||||
@@ -86,8 +86,8 @@ def mkContext (fnPrefix : String) (typeName : Name) : TermElabM Context := do
|
||||
|
||||
def mkLocalInstanceLetDecls (ctx : Context) (className : Name) (argNames : Array Name) : TermElabM (Array (TSyntax ``Parser.Term.letDecl)) := do
|
||||
let mut letDecls := #[]
|
||||
for i in [:ctx.typeInfos.size] do
|
||||
let indVal := ctx.typeInfos[i]!
|
||||
for h : i in [:ctx.typeInfos.size] do
|
||||
let indVal := ctx.typeInfos[i]
|
||||
let auxFunName := ctx.auxFunNames[i]!
|
||||
let currArgNames ← mkInductArgNames indVal
|
||||
let numParams := indVal.numParams
|
||||
|
||||
@@ -796,10 +796,10 @@ Note that we are not restricting the macro power since the
|
||||
actions to be in the same universe.
|
||||
-/
|
||||
private def mkTuple (elems : Array Syntax) : MacroM Syntax := do
|
||||
if elems.size == 0 then
|
||||
if elems.size = 0 then
|
||||
mkUnit
|
||||
else if elems.size == 1 then
|
||||
return elems[0]!
|
||||
else if h : elems.size = 1 then
|
||||
return elems[0]
|
||||
else
|
||||
elems.extract 0 (elems.size - 1) |>.foldrM (init := elems.back!) fun elem tuple =>
|
||||
``(MProd.mk $elem $tuple)
|
||||
@@ -831,10 +831,10 @@ def isDoExpr? (doElem : Syntax) : Option Syntax :=
|
||||
We use this method when expanding the `for-in` notation.
|
||||
-/
|
||||
private def destructTuple (uvars : Array Var) (x : Syntax) (body : Syntax) : MacroM Syntax := do
|
||||
if uvars.size == 0 then
|
||||
if uvars.size = 0 then
|
||||
return body
|
||||
else if uvars.size == 1 then
|
||||
`(let $(uvars[0]!):ident := $x; $body)
|
||||
else if h : uvars.size = 1 then
|
||||
`(let $(uvars[0]):ident := $x; $body)
|
||||
else
|
||||
destruct uvars.toList x body
|
||||
where
|
||||
@@ -1314,9 +1314,9 @@ private partial def expandLiftMethodAux (inQuot : Bool) (inBinder : Bool) : Synt
|
||||
else if liftMethodDelimiter k then
|
||||
return stx
|
||||
-- For `pure` if-then-else, we only lift `(<- ...)` occurring in the condition.
|
||||
else if args.size >= 2 && (k == ``termDepIfThenElse || k == ``termIfThenElse) then do
|
||||
else if h : args.size >= 2 ∧ (k == ``termDepIfThenElse || k == ``termIfThenElse) then do
|
||||
let inAntiquot := stx.isAntiquot && !stx.isEscapedAntiquot
|
||||
let arg1 ← expandLiftMethodAux (inQuot && !inAntiquot || stx.isQuot) inBinder args[1]!
|
||||
let arg1 ← expandLiftMethodAux (inQuot && !inAntiquot || stx.isQuot) inBinder args[1]
|
||||
let args := args.set! 1 arg1
|
||||
return Syntax.node i k args
|
||||
else if k == ``Parser.Term.liftMethod && !inQuot then withFreshMacroScope do
|
||||
@@ -1518,7 +1518,7 @@ mutual
|
||||
-/
|
||||
partial def doForToCode (doFor : Syntax) (doElems : List Syntax) : M CodeBlock := do
|
||||
let doForDecls := doFor[1].getSepArgs
|
||||
if doForDecls.size > 1 then
|
||||
if h : doForDecls.size > 1 then
|
||||
/-
|
||||
Expand
|
||||
```
|
||||
|
||||
@@ -102,7 +102,7 @@ partial def IO.processCommandsIncrementally (inputCtx : Parser.InputContext)
|
||||
where
|
||||
go initialSnap t commands :=
|
||||
let snap := t.get
|
||||
let commands := commands.push snap.data
|
||||
let commands := commands.push snap
|
||||
if let some next := snap.nextCmdSnap? then
|
||||
go initialSnap next.task commands
|
||||
else
|
||||
@@ -115,9 +115,9 @@ where
|
||||
-- snapshots as they subsume any info trees reported incrementally by their children.
|
||||
let trees := commands.map (·.finishedSnap.get.infoTree?) |>.filterMap id |>.toPArray'
|
||||
return {
|
||||
commandState := { snap.data.finishedSnap.get.cmdState with messages, infoState.trees := trees }
|
||||
parserState := snap.data.parserState
|
||||
cmdPos := snap.data.parserState.pos
|
||||
commandState := { snap.finishedSnap.get.cmdState with messages, infoState.trees := trees }
|
||||
parserState := snap.parserState
|
||||
cmdPos := snap.parserState.pos
|
||||
commands := commands.map (·.stx)
|
||||
inputCtx, initialSnap
|
||||
}
|
||||
@@ -164,8 +164,8 @@ def runFrontend
|
||||
| return (← mkEmptyEnvironment, false)
|
||||
|
||||
if let some out := trace.profiler.output.get? opts then
|
||||
let traceState := cmdState.traceState
|
||||
let profile ← Firefox.Profile.export mainModuleName.toString startTime traceState opts
|
||||
let traceStates := snaps.getAll.map (·.traces)
|
||||
let profile ← Firefox.Profile.export mainModuleName.toString startTime traceStates opts
|
||||
IO.FS.writeFile ⟨out⟩ <| Json.compress <| toJson profile
|
||||
|
||||
let hasErrors := snaps.getAll.any (·.diagnostics.msgLog.hasErrors)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -87,8 +87,8 @@ private def elabLetRecDeclValues (view : LetRecView) : TermElabM (Array Expr) :=
|
||||
view.decls.mapM fun view => do
|
||||
forallBoundedTelescope view.type view.binderIds.size fun xs type => do
|
||||
-- Add new info nodes for new fvars. The server will detect all fvars of a binder by the binder's source location.
|
||||
for i in [0:view.binderIds.size] do
|
||||
addLocalVarInfo view.binderIds[i]! xs[i]!
|
||||
for h : i in [0:view.binderIds.size] do
|
||||
addLocalVarInfo view.binderIds[i] xs[i]!
|
||||
withDeclName view.declName do
|
||||
withInfoContext' view.valStx
|
||||
(mkInfo := (pure <| .inl <| mkBodyInfo view.valStx ·))
|
||||
|
||||
@@ -282,8 +282,8 @@ where
|
||||
let dArg := dArgs[i]!
|
||||
unless (← isDefEq tArg dArg) do
|
||||
return i :: (← goType tArg dArg)
|
||||
for i in [info.numParams : tArgs.size] do
|
||||
let tArg := tArgs[i]!
|
||||
for h : i in [info.numParams : tArgs.size] do
|
||||
let tArg := tArgs[i]
|
||||
let dArg := dArgs[i]!
|
||||
unless (← isDefEq tArg dArg) do
|
||||
return i :: (← goIndex tArg dArg)
|
||||
|
||||
@@ -840,8 +840,8 @@ private def mkLetRecClosures (sectionVars : Array Expr) (mainFVarIds : Array FVa
|
||||
abbrev Replacement := FVarIdMap Expr
|
||||
|
||||
def insertReplacementForMainFns (r : Replacement) (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainFVars : Array Expr) : Replacement :=
|
||||
mainFVars.size.fold (init := r) fun i r =>
|
||||
r.insert mainFVars[i]!.fvarId! (mkAppN (Lean.mkConst mainHeaders[i]!.declName) sectionVars)
|
||||
mainFVars.size.fold (init := r) fun i _ r =>
|
||||
r.insert mainFVars[i].fvarId! (mkAppN (Lean.mkConst mainHeaders[i]!.declName) sectionVars)
|
||||
|
||||
|
||||
def insertReplacementForLetRecs (r : Replacement) (letRecClosures : List LetRecClosure) : Replacement :=
|
||||
@@ -871,8 +871,8 @@ def Replacement.apply (r : Replacement) (e : Expr) : Expr :=
|
||||
|
||||
def pushMain (preDefs : Array PreDefinition) (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainVals : Array Expr)
|
||||
: TermElabM (Array PreDefinition) :=
|
||||
mainHeaders.size.foldM (init := preDefs) fun i preDefs => do
|
||||
let header := mainHeaders[i]!
|
||||
mainHeaders.size.foldM (init := preDefs) fun i _ preDefs => do
|
||||
let header := mainHeaders[i]
|
||||
let termination ← declValToTerminationHint header.value
|
||||
let termination := termination.rememberExtraParams header.numParams mainVals[i]!
|
||||
let value ← mkLambdaFVars sectionVars mainVals[i]!
|
||||
|
||||
1041
src/Lean/Elab/MutualInductive.lean
Normal file
1041
src/Lean/Elab/MutualInductive.lean
Normal file
File diff suppressed because it is too large
Load Diff
@@ -17,7 +17,7 @@ open Lean.Parser.Command
|
||||
private partial def antiquote (vars : Array Syntax) : Syntax → Syntax
|
||||
| stx => match stx with
|
||||
| `($id:ident) =>
|
||||
if (vars.findIdx? (fun var => var.getId == id.getId)).isSome then
|
||||
if vars.any (fun var => var.getId == id.getId) then
|
||||
mkAntiquotNode id (kind := `term) (isPseudoKind := true)
|
||||
else
|
||||
stx
|
||||
|
||||
@@ -49,12 +49,12 @@ private def resolveNameUsingNamespacesCore (nss : List Name) (idStx : Syntax) :
|
||||
exs := exs.push ex
|
||||
if exs.size == nss.length then
|
||||
withRef idStx do
|
||||
if exs.size == 1 then
|
||||
throw exs[0]!
|
||||
if h : exs.size = 1 then
|
||||
throw exs[0]
|
||||
else
|
||||
throwErrorWithNestedErrors "failed to open" exs
|
||||
if result.size == 1 then
|
||||
return result[0]!
|
||||
if h : result.size = 1 then
|
||||
return result[0]
|
||||
else
|
||||
withRef idStx do throwError "ambiguous identifier '{idStx.getId}', possible interpretations: {result.map mkConst}"
|
||||
|
||||
|
||||
@@ -332,9 +332,9 @@ where
|
||||
else
|
||||
let accessible := isNextArgAccessible ctx
|
||||
let (d, ctx) := getNextParam ctx
|
||||
match ctx.namedArgs.findIdx? fun namedArg => namedArg.name == d.1 with
|
||||
match ctx.namedArgs.findFinIdx? fun namedArg => namedArg.name == d.1 with
|
||||
| some idx =>
|
||||
let arg := ctx.namedArgs[idx]!
|
||||
let arg := ctx.namedArgs[idx]
|
||||
let ctx := { ctx with namedArgs := ctx.namedArgs.eraseIdx idx }
|
||||
let ctx ← pushNewArg accessible ctx arg.val
|
||||
processCtorAppContext ctx
|
||||
|
||||
@@ -244,8 +244,8 @@ def checkCodomainsLevel (preDefs : Array PreDefinition) : MetaM Unit := do
|
||||
lambdaTelescope preDef.value fun xs _ => return xs.size
|
||||
forallBoundedTelescope preDefs[0]!.type arities[0]! fun _ type₀ => do
|
||||
let u₀ ← getLevel type₀
|
||||
for i in [1:preDefs.size] do
|
||||
forallBoundedTelescope preDefs[i]!.type arities[i]! fun _ typeᵢ =>
|
||||
for h : i in [1:preDefs.size] do
|
||||
forallBoundedTelescope preDefs[i].type arities[i]! fun _ typeᵢ =>
|
||||
unless ← isLevelDefEq u₀ (← getLevel typeᵢ) do
|
||||
withOptions (fun o => pp.sanitizeNames.set o false) do
|
||||
throwError m!"invalid mutual definition, result types must be in the same universe " ++
|
||||
|
||||
@@ -145,8 +145,8 @@ private partial def replaceRecApps (recArgInfos : Array RecArgInfo) (positions :
|
||||
| Expr.app _ _ =>
|
||||
let processApp (e : Expr) : StateRefT (HasConstCache recFnNames) M Expr :=
|
||||
e.withApp fun f args => do
|
||||
if let .some fnIdx := recArgInfos.findIdx? (f.isConstOf ·.fnName) then
|
||||
let recArgInfo := recArgInfos[fnIdx]!
|
||||
if let .some fnIdx := recArgInfos.findFinIdx? (f.isConstOf ·.fnName) then
|
||||
let recArgInfo := recArgInfos[fnIdx]
|
||||
let some recArg := args[recArgInfo.recArgPos]?
|
||||
| throwError "insufficient number of parameters at recursive application {indentExpr e}"
|
||||
-- For reflexive type, we may have nested recursive applications in recArg
|
||||
@@ -292,9 +292,9 @@ def mkBrecOnApp (positions : Positions) (fnIdx : Nat) (brecOnConst : Nat → Exp
|
||||
let packedFTypes ← inferArgumentTypesN positions.size brecOn
|
||||
let packedFArgs ← positions.mapMwith PProdN.mkLambdas packedFTypes FArgs
|
||||
let brecOn := mkAppN brecOn packedFArgs
|
||||
let some poss := positions.find? (·.contains fnIdx)
|
||||
let some (size, idx) := positions.findSome? fun pos => (pos.size, ·) <$> pos.indexOf? fnIdx
|
||||
| throwError "mkBrecOnApp: Could not find {fnIdx} in {positions}"
|
||||
let brecOn ← PProdN.proj poss.size (poss.getIdx? fnIdx).get! brecOn
|
||||
let brecOn ← PProdN.proj size idx brecOn
|
||||
mkLambdaFVars ys (mkAppN brecOn otherArgs)
|
||||
|
||||
end Lean.Elab.Structural
|
||||
|
||||
@@ -53,10 +53,10 @@ def TerminationArgument.elab (funName : Name) (type : Expr) (arity extraParams :
|
||||
(hint : TerminationBy) : TermElabM TerminationArgument := withDeclName funName do
|
||||
assert! extraParams ≤ arity
|
||||
|
||||
if hint.vars.size > extraParams then
|
||||
if h : hint.vars.size > extraParams then
|
||||
let mut msg := m!"{parameters hint.vars.size} bound in `termination_by`, but the body of " ++
|
||||
m!"{funName} only binds {parameters extraParams}."
|
||||
if let `($ident:ident) := hint.vars[0]! then
|
||||
if let `($ident:ident) := hint.vars[0] then
|
||||
if ident.getId.isSuffixOf funName then
|
||||
msg := msg ++ m!" (Since Lean v4.6.0, the `termination_by` clause no longer " ++
|
||||
"expects the function name here.)"
|
||||
|
||||
@@ -90,10 +90,10 @@ lambda of `value`, and throws appropriate errors.
|
||||
-/
|
||||
def TerminationBy.checkVars (funName : Name) (extraParams : Nat) (tb : TerminationBy) : MetaM Unit := do
|
||||
unless tb.synthetic do
|
||||
if tb.vars.size > extraParams then
|
||||
if h : tb.vars.size > extraParams then
|
||||
let mut msg := m!"{parameters tb.vars.size} bound in `termination_by`, but the body of " ++
|
||||
m!"{funName} only binds {parameters extraParams}."
|
||||
if let `($ident:ident) := tb.vars[0]! then
|
||||
if let `($ident:ident) := tb.vars[0] then
|
||||
if ident.getId.isSuffixOf funName then
|
||||
msg := msg ++ m!" (Since Lean v4.6.0, the `termination_by` clause no longer " ++
|
||||
"expects the function name here.)"
|
||||
|
||||
@@ -21,8 +21,8 @@ open Meta
|
||||
private partial def addNonRecPreDefs (fixedPrefixSize : Nat) (argsPacker : ArgsPacker) (preDefs : Array PreDefinition) (preDefNonRec : PreDefinition) : TermElabM Unit := do
|
||||
let us := preDefNonRec.levelParams.map mkLevelParam
|
||||
let all := preDefs.toList.map (·.declName)
|
||||
for fidx in [:preDefs.size] do
|
||||
let preDef := preDefs[fidx]!
|
||||
for h : fidx in [:preDefs.size] do
|
||||
let preDef := preDefs[fidx]
|
||||
let value ← forallBoundedTelescope preDef.type (some fixedPrefixSize) fun xs _ => do
|
||||
let value := mkAppN (mkConst preDefNonRec.declName us) xs
|
||||
let value ← argsPacker.curryProj value fidx
|
||||
|
||||
@@ -40,7 +40,7 @@ private partial def post (fixedPrefix : Nat) (argsPacker : ArgsPacker) (funNames
|
||||
return TransformStep.done e
|
||||
let declName := f.constName!
|
||||
let us := f.constLevels!
|
||||
if let some fidx := funNames.getIdx? declName then
|
||||
if let some fidx := funNames.indexOf? declName then
|
||||
let arity := fixedPrefix + argsPacker.varNamess[fidx]!.size
|
||||
let e' ← withAppN arity e fun args => do
|
||||
let fixedArgs := args[:fixedPrefix]
|
||||
|
||||
@@ -58,7 +58,7 @@ partial def mkTuple : Array Syntax → TermElabM Syntax
|
||||
| #[] => `(Unit.unit)
|
||||
| #[e] => return e
|
||||
| es => do
|
||||
let stx ← mkTuple (es.eraseIdx 0)
|
||||
let stx ← mkTuple (es.eraseIdxIfInBounds 0)
|
||||
`(Prod.mk $(es[0]!) $stx)
|
||||
|
||||
def resolveSectionVariable (sectionVars : NameMap Name) (id : Name) : List (Name × List String) :=
|
||||
|
||||
@@ -302,59 +302,58 @@ instance : ToFormat FieldLHS where
|
||||
| .fieldIndex _ i => format i
|
||||
| .modifyOp _ i => "[" ++ i.prettyPrint ++ "]"
|
||||
|
||||
/--
|
||||
`FieldVal StructInstView` is a representation of a field value in the structure instance.
|
||||
-/
|
||||
inductive FieldVal (σ : Type) where
|
||||
/-- A `term` to use for the value of the field. -/
|
||||
| term (stx : Syntax) : FieldVal σ
|
||||
/-- A `StructInstView` to use for the value of a subobject field. -/
|
||||
| nested (s : σ) : FieldVal σ
|
||||
/-- A field that was not provided and should be synthesized using default values. -/
|
||||
| default : FieldVal σ
|
||||
deriving Inhabited
|
||||
mutual
|
||||
/--
|
||||
`FieldVal StructInstView` is a representation of a field value in the structure instance.
|
||||
-/
|
||||
inductive FieldVal where
|
||||
/-- A `term` to use for the value of the field. -/
|
||||
| term (stx : Syntax) : FieldVal
|
||||
/-- A `StructInstView` to use for the value of a subobject field. -/
|
||||
| nested (s : StructInstView) : FieldVal
|
||||
/-- A field that was not provided and should be synthesized using default values. -/
|
||||
| default : FieldVal
|
||||
deriving Inhabited
|
||||
|
||||
/--
|
||||
`Field StructInstView` is a representation of a field in the structure instance.
|
||||
-/
|
||||
structure Field (σ : Type) where
|
||||
/-- The whole field syntax. -/
|
||||
ref : Syntax
|
||||
/-- The LHS decomposed into components. -/
|
||||
lhs : List FieldLHS
|
||||
/-- The value of the field. -/
|
||||
val : FieldVal σ
|
||||
/-- The elaborated field value, filled in at `elabStruct`.
|
||||
Missing fields use a metavariable for the elaborated value and are later solved for in `DefaultFields.propagate`. -/
|
||||
expr? : Option Expr := none
|
||||
deriving Inhabited
|
||||
/--
|
||||
`Field StructInstView` is a representation of a field in the structure instance.
|
||||
-/
|
||||
structure Field where
|
||||
/-- The whole field syntax. -/
|
||||
ref : Syntax
|
||||
/-- The LHS decomposed into components. -/
|
||||
lhs : List FieldLHS
|
||||
/-- The value of the field. -/
|
||||
val : FieldVal
|
||||
/-- The elaborated field value, filled in at `elabStruct`.
|
||||
Missing fields use a metavariable for the elaborated value and are later solved for in `DefaultFields.propagate`. -/
|
||||
expr? : Option Expr := none
|
||||
deriving Inhabited
|
||||
|
||||
/--
|
||||
The view for structure instance notation.
|
||||
-/
|
||||
structure StructInstView where
|
||||
/-- The syntax for the whole structure instance. -/
|
||||
ref : Syntax
|
||||
/-- The name of the structure for the type of the structure instance. -/
|
||||
structName : Name
|
||||
/-- Used for default values, to propagate structure type parameters. It is initially empty, and then set at `elabStruct`. -/
|
||||
params : Array (Name × Expr)
|
||||
/-- The fields of the structure instance. -/
|
||||
fields : List Field
|
||||
/-- The additional sources for fields for the structure instance. -/
|
||||
sources : SourcesView
|
||||
deriving Inhabited
|
||||
end
|
||||
|
||||
/--
|
||||
Returns if the field has a single component in its LHS.
|
||||
-/
|
||||
def Field.isSimple {σ} : Field σ → Bool
|
||||
def Field.isSimple : Field → Bool
|
||||
| { lhs := [_], .. } => true
|
||||
| _ => false
|
||||
|
||||
/--
|
||||
The view for structure instance notation.
|
||||
-/
|
||||
structure StructInstView where
|
||||
/-- The syntax for the whole structure instance. -/
|
||||
ref : Syntax
|
||||
/-- The name of the structure for the type of the structure instance. -/
|
||||
structName : Name
|
||||
/-- Used for default values, to propagate structure type parameters. It is initially empty, and then set at `elabStruct`. -/
|
||||
params : Array (Name × Expr)
|
||||
/-- The fields of the structure instance. -/
|
||||
fields : List (Field StructInstView)
|
||||
/-- The additional sources for fields for the structure instance. -/
|
||||
sources : SourcesView
|
||||
deriving Inhabited
|
||||
|
||||
/-- Abbreviation for the type of `StructInstView.fields`, namely `List (Field StructInstView)`. -/
|
||||
abbrev Fields := List (Field StructInstView)
|
||||
|
||||
/-- `true` iff all fields of the given structure are marked as `default` -/
|
||||
partial def StructInstView.allDefault (s : StructInstView) : Bool :=
|
||||
s.fields.all fun { val := val, .. } => match val with
|
||||
@@ -362,7 +361,7 @@ partial def StructInstView.allDefault (s : StructInstView) : Bool :=
|
||||
| .default => true
|
||||
| .nested s => allDefault s
|
||||
|
||||
def formatField (formatStruct : StructInstView → Format) (field : Field StructInstView) : Format :=
|
||||
def formatField (formatStruct : StructInstView → Format) (field : Field) : Format :=
|
||||
Format.joinSep field.lhs " . " ++ " := " ++
|
||||
match field.val with
|
||||
| .term v => v.prettyPrint
|
||||
@@ -378,11 +377,11 @@ partial def formatStruct : StructInstView → Format
|
||||
else
|
||||
"{" ++ format (source.explicit.map (·.stx)) ++ " with " ++ fieldsFmt ++ implicitFmt ++ "}"
|
||||
|
||||
instance : ToFormat StructInstView := ⟨formatStruct⟩
|
||||
instance : ToFormat StructInstView := ⟨formatStruct⟩
|
||||
instance : ToString StructInstView := ⟨toString ∘ format⟩
|
||||
|
||||
instance : ToFormat (Field StructInstView) := ⟨formatField formatStruct⟩
|
||||
instance : ToString (Field StructInstView) := ⟨toString ∘ format⟩
|
||||
instance : ToFormat Field := ⟨formatField formatStruct⟩
|
||||
instance : ToString Field := ⟨toString ∘ format⟩
|
||||
|
||||
/--
|
||||
Converts a `FieldLHS` back into syntax. This assumes the `ref` fields have the correct structure.
|
||||
@@ -403,14 +402,14 @@ private def FieldLHS.toSyntax (first : Bool) : FieldLHS → Syntax
|
||||
/--
|
||||
Converts a `FieldVal StructInstView` back into syntax. Only supports `.term`, and it assumes the `stx` field has the correct structure.
|
||||
-/
|
||||
private def FieldVal.toSyntax : FieldVal Struct → Syntax
|
||||
private def FieldVal.toSyntax : FieldVal → Syntax
|
||||
| .term stx => stx
|
||||
| _ => unreachable!
|
||||
|
||||
/--
|
||||
Converts a `Field StructInstView` back into syntax. Used to construct synthetic structure instance notation for subobjects in `StructInst.expandStruct` processing.
|
||||
-/
|
||||
private def Field.toSyntax : Field Struct → Syntax
|
||||
private def Field.toSyntax : Field → Syntax
|
||||
| field =>
|
||||
let stx := field.ref
|
||||
let stx := stx.setArg 2 field.val.toSyntax
|
||||
@@ -452,14 +451,14 @@ private def mkStructView (stx : Syntax) (structName : Name) (sources : SourcesVi
|
||||
let val := fieldStx[2]
|
||||
let first ← toFieldLHS fieldStx[0][0]
|
||||
let rest ← fieldStx[0][1].getArgs.toList.mapM toFieldLHS
|
||||
return { ref := fieldStx, lhs := first :: rest, val := FieldVal.term val : Field StructInstView }
|
||||
return { ref := fieldStx, lhs := first :: rest, val := FieldVal.term val : Field }
|
||||
return { ref := stx, structName, params := #[], fields, sources }
|
||||
|
||||
def StructInstView.modifyFieldsM {m : Type → Type} [Monad m] (s : StructInstView) (f : Fields → m Fields) : m StructInstView :=
|
||||
def StructInstView.modifyFieldsM {m : Type → Type} [Monad m] (s : StructInstView) (f : List Field → m (List Field)) : m StructInstView :=
|
||||
match s with
|
||||
| { ref, structName, params, fields, sources } => return { ref, structName, params, fields := (← f fields), sources }
|
||||
|
||||
def StructInstView.modifyFields (s : StructInstView) (f : Fields → Fields) : StructInstView :=
|
||||
def StructInstView.modifyFields (s : StructInstView) (f : List Field → List Field) : StructInstView :=
|
||||
Id.run <| s.modifyFieldsM f
|
||||
|
||||
/-- Expands name field LHSs with multi-component names into multi-component LHSs. -/
|
||||
@@ -525,14 +524,14 @@ private def expandParentFields (s : StructInstView) : TermElabM StructInstView :
|
||||
| _ => throwErrorAt ref "failed to access field '{fieldName}' in parent structure"
|
||||
| _ => return field
|
||||
|
||||
private abbrev FieldMap := Std.HashMap Name Fields
|
||||
private abbrev FieldMap := Std.HashMap Name (List Field)
|
||||
|
||||
/--
|
||||
Creates a hash map collecting all fields with the same first name component.
|
||||
Throws an error if there are multiple simple fields with the same name.
|
||||
Used by `StructInst.expandStruct` processing.
|
||||
-/
|
||||
private def mkFieldMap (fields : Fields) : TermElabM FieldMap :=
|
||||
private def mkFieldMap (fields : List Field) : TermElabM FieldMap :=
|
||||
fields.foldlM (init := {}) fun fieldMap field =>
|
||||
match field.lhs with
|
||||
| .fieldName _ fieldName :: _ =>
|
||||
@@ -548,7 +547,7 @@ private def mkFieldMap (fields : Fields) : TermElabM FieldMap :=
|
||||
/--
|
||||
Given a value of the hash map created by `mkFieldMap`, returns true if the value corresponds to a simple field.
|
||||
-/
|
||||
private def isSimpleField? : Fields → Option (Field StructInstView)
|
||||
private def isSimpleField? : List Field → Option Field
|
||||
| [field] => if field.isSimple then some field else none
|
||||
| _ => none
|
||||
|
||||
@@ -566,7 +565,7 @@ def mkProjStx? (s : Syntax) (structName : Name) (fieldName : Name) : TermElabM (
|
||||
/--
|
||||
Finds a simple field of the given name.
|
||||
-/
|
||||
def findField? (fields : Fields) (fieldName : Name) : Option (Field StructInstView) :=
|
||||
def findField? (fields : List Field) (fieldName : Name) : Option Field :=
|
||||
fields.find? fun field =>
|
||||
match field.lhs with
|
||||
| [.fieldName _ n] => n == fieldName
|
||||
@@ -620,7 +619,7 @@ mutual
|
||||
match findField? s.fields fieldName with
|
||||
| some field => return field::fields
|
||||
| none =>
|
||||
let addField (val : FieldVal StructInstView) : TermElabM Fields := do
|
||||
let addField (val : FieldVal) : TermElabM (List Field) := do
|
||||
return { ref, lhs := [FieldLHS.fieldName ref fieldName], val := val } :: fields
|
||||
match Lean.isSubobjectField? env s.structName fieldName with
|
||||
| some substructName =>
|
||||
@@ -773,7 +772,7 @@ private partial def elabStructInstView (s : StructInstView) (expectedType? : Opt
|
||||
trace[Elab.struct] "elabStruct {field}, {type}"
|
||||
match type with
|
||||
| .forallE _ d b bi =>
|
||||
let cont (val : Expr) (field : Field StructInstView) (instMVars := instMVars) : TermElabM (Expr × Expr × Fields × Array MVarId) := do
|
||||
let cont (val : Expr) (field : Field) (instMVars := instMVars) : TermElabM (Expr × Expr × List Field × Array MVarId) := do
|
||||
pushInfoTree <| InfoTree.node (children := {}) <| Info.ofFieldInfo {
|
||||
projName := s.structName.append fieldName, fieldName, lctx := (← getLCtx), val, stx := ref }
|
||||
let e := mkApp e val
|
||||
@@ -879,7 +878,7 @@ partial def getHierarchyDepth (struct : StructInstView) : Nat :=
|
||||
| _ => max
|
||||
|
||||
/-- Returns whether the field is still missing. -/
|
||||
def isDefaultMissing? [Monad m] [MonadMCtx m] (field : Field Struct) : m Bool := do
|
||||
def isDefaultMissing? [Monad m] [MonadMCtx m] (field : Field) : m Bool := do
|
||||
if let some expr := field.expr? then
|
||||
if let some (.mvar mvarId) := defaultMissing? expr then
|
||||
unless (← mvarId.isAssigned) do
|
||||
@@ -887,17 +886,17 @@ def isDefaultMissing? [Monad m] [MonadMCtx m] (field : Field Struct) : m Bool :=
|
||||
return false
|
||||
|
||||
/-- Returns a field that is still missing. -/
|
||||
partial def findDefaultMissing? [Monad m] [MonadMCtx m] (struct : StructInstView) : m (Option (Field StructInstView)) :=
|
||||
partial def findDefaultMissing? [Monad m] [MonadMCtx m] (struct : StructInstView) : m (Option Field) :=
|
||||
struct.fields.findSomeM? fun field => do
|
||||
match field.val with
|
||||
| .nested struct => findDefaultMissing? struct
|
||||
| _ => return if (← isDefaultMissing? field) then field else none
|
||||
|
||||
/-- Returns all fields that are still missing. -/
|
||||
partial def allDefaultMissing [Monad m] [MonadMCtx m] (struct : StructInstView) : m (Array (Field StructInstView)) :=
|
||||
partial def allDefaultMissing [Monad m] [MonadMCtx m] (struct : StructInstView) : m (Array Field) :=
|
||||
go struct *> get |>.run' #[]
|
||||
where
|
||||
go (struct : StructInstView) : StateT (Array (Field StructInstView)) m Unit :=
|
||||
go (struct : StructInstView) : StateT (Array Field) m Unit :=
|
||||
for field in struct.fields do
|
||||
if let .nested struct := field.val then
|
||||
go struct
|
||||
@@ -905,7 +904,7 @@ where
|
||||
modify (·.push field)
|
||||
|
||||
/-- Returns the name of the field. Assumes all fields under consideration are simple and named. -/
|
||||
def getFieldName (field : Field StructInstView) : Name :=
|
||||
def getFieldName (field : Field) : Name :=
|
||||
match field.lhs with
|
||||
| [.fieldName _ fieldName] => fieldName
|
||||
| _ => unreachable!
|
||||
|
||||
@@ -4,22 +4,15 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.Class
|
||||
import Lean.Parser.Command
|
||||
import Lean.Meta.Closure
|
||||
import Lean.Meta.SizeOf
|
||||
import Lean.Meta.Injective
|
||||
import Lean.Meta.Structure
|
||||
import Lean.Meta.AppBuilder
|
||||
import Lean.Elab.Command
|
||||
import Lean.Elab.DeclModifiers
|
||||
import Lean.Elab.DeclUtil
|
||||
import Lean.Elab.Inductive
|
||||
import Lean.Elab.DeclarationRange
|
||||
import Lean.Elab.Binders
|
||||
import Lean.Elab.MutualInductive
|
||||
|
||||
namespace Lean.Elab.Command
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Elab.structure
|
||||
registerTraceClass `Elab.structure.resolutionOrder
|
||||
|
||||
register_builtin_option structureDiamondWarning : Bool := {
|
||||
defValue := false
|
||||
descr := "if true, enable warnings when a structure has diamond inheritance"
|
||||
@@ -39,13 +32,6 @@ leading_parser (structureTk <|> classTk) >> declId >> many Term.bracketedBinder
|
||||
```
|
||||
-/
|
||||
|
||||
structure StructCtorView where
|
||||
ref : Syntax
|
||||
modifiers : Modifiers
|
||||
name : Name
|
||||
declName : Name
|
||||
deriving Inhabited
|
||||
|
||||
structure StructFieldView where
|
||||
ref : Syntax
|
||||
modifiers : Modifiers
|
||||
@@ -61,22 +47,15 @@ structure StructFieldView where
|
||||
type? : Option Syntax
|
||||
value? : Option Syntax
|
||||
|
||||
structure StructView where
|
||||
ref : Syntax
|
||||
declId : Syntax
|
||||
modifiers : Modifiers
|
||||
isClass : Bool -- struct-only
|
||||
shortDeclName : Name
|
||||
declName : Name
|
||||
levelNames : List Name
|
||||
binders : Syntax
|
||||
type : Syntax -- modified (inductive has type?)
|
||||
parents : Array Syntax -- struct-only
|
||||
ctor : StructCtorView -- struct-only
|
||||
fields : Array StructFieldView -- struct-only
|
||||
derivingClasses : Array DerivingClassView
|
||||
structure StructView extends InductiveView where
|
||||
parents : Array Syntax
|
||||
fields : Array StructFieldView
|
||||
deriving Inhabited
|
||||
|
||||
def StructView.ctor : StructView → CtorView
|
||||
| { ctors := #[ctor], ..} => ctor
|
||||
| _ => unreachable!
|
||||
|
||||
structure StructParentInfo where
|
||||
ref : Syntax
|
||||
fvar? : Option Expr
|
||||
@@ -102,18 +81,6 @@ structure StructFieldInfo where
|
||||
value? : Option Expr := none
|
||||
deriving Inhabited, Repr
|
||||
|
||||
structure ElabStructHeaderResult where
|
||||
view : StructView
|
||||
lctx : LocalContext
|
||||
localInsts : LocalInstances
|
||||
levelNames : List Name
|
||||
params : Array Expr
|
||||
type : Expr
|
||||
parents : Array StructParentInfo
|
||||
/-- Field infos from parents. -/
|
||||
parentFieldInfos : Array StructFieldInfo
|
||||
deriving Inhabited
|
||||
|
||||
def StructFieldInfo.isFromParent (info : StructFieldInfo) : Bool :=
|
||||
match info.kind with
|
||||
| StructFieldKind.fromParent => true
|
||||
@@ -130,12 +97,12 @@ The structure constructor syntax is
|
||||
leading_parser try (declModifiers >> ident >> " :: ")
|
||||
```
|
||||
-/
|
||||
private def expandCtor (structStx : Syntax) (structModifiers : Modifiers) (structDeclName : Name) : TermElabM StructCtorView := do
|
||||
private def expandCtor (structStx : Syntax) (structModifiers : Modifiers) (structDeclName : Name) : TermElabM CtorView := do
|
||||
let useDefault := do
|
||||
let declName := structDeclName ++ defaultCtorName
|
||||
let ref := structStx[1].mkSynthetic
|
||||
addDeclarationRangesFromSyntax declName ref
|
||||
pure { ref, modifiers := default, name := defaultCtorName, declName }
|
||||
pure { ref, declId := ref, modifiers := default, declName }
|
||||
if structStx[5].isNone then
|
||||
useDefault
|
||||
else
|
||||
@@ -156,7 +123,7 @@ private def expandCtor (structStx : Syntax) (structModifiers : Modifiers) (struc
|
||||
let declName ← applyVisibility ctorModifiers.visibility declName
|
||||
addDocString' declName ctorModifiers.docString?
|
||||
addDeclarationRangesFromSyntax declName ctor[1]
|
||||
pure { ref := ctor[1], name, modifiers := ctorModifiers, declName }
|
||||
pure { ref := ctor[1], declId := ctor[1], modifiers := ctorModifiers, declName }
|
||||
|
||||
def checkValidFieldModifier (modifiers : Modifiers) : TermElabM Unit := do
|
||||
if modifiers.isNoncomputable then
|
||||
@@ -271,7 +238,7 @@ def structureSyntaxToView (modifiers : Modifiers) (stx : Syntax) : TermElabM Str
|
||||
let parents := if exts.isNone then #[] else exts[0][1].getSepArgs
|
||||
let optType := stx[4]
|
||||
let derivingClasses ← getOptDerivingClasses stx[6]
|
||||
let type ← if optType.isNone then `(Sort _) else pure optType[0][1]
|
||||
let type? := if optType.isNone then none else some optType[0][1]
|
||||
let ctor ← expandCtor stx modifiers declName
|
||||
let fields ← expandFields stx modifiers declName
|
||||
fields.forM fun field => do
|
||||
@@ -287,10 +254,13 @@ def structureSyntaxToView (modifiers : Modifiers) (stx : Syntax) : TermElabM Str
|
||||
declName
|
||||
levelNames
|
||||
binders
|
||||
type
|
||||
type?
|
||||
allowIndices := false
|
||||
allowSortPolymorphism := false
|
||||
ctors := #[ctor]
|
||||
parents
|
||||
ctor
|
||||
fields
|
||||
computedFields := #[]
|
||||
derivingClasses
|
||||
}
|
||||
|
||||
@@ -536,7 +506,7 @@ private partial def mkToParentName (parentStructName : Name) (p : Name → Bool)
|
||||
if p curr then curr else go (i+1)
|
||||
go 1
|
||||
|
||||
private partial def elabParents (view : StructView)
|
||||
private partial def withParents (view : StructView) (rs : Array ElabHeaderResult) (indFVar : Expr)
|
||||
(k : Array StructFieldInfo → Array StructParentInfo → TermElabM α) : TermElabM α := do
|
||||
go 0 #[] #[]
|
||||
where
|
||||
@@ -544,11 +514,17 @@ where
|
||||
if h : i < view.parents.size then
|
||||
let parent := view.parents[i]
|
||||
withRef parent do
|
||||
let type ← Term.elabType parent
|
||||
-- The only use case for autobound implicits for parents might be outParams, but outParam is not propagated.
|
||||
let type ← Term.withoutAutoBoundImplicit <| Term.elabType parent
|
||||
let parentType ← whnf type
|
||||
if parentType.getAppFn == indFVar then
|
||||
logWarning "structure extends itself, skipping"
|
||||
return ← go (i + 1) infos parents
|
||||
if rs.any (fun r => r.indFVar == parentType.getAppFn) then
|
||||
throwError "structure cannot extend types defined in the same mutual block"
|
||||
let parentStructName ← getStructureName parentType
|
||||
if parents.any (fun info => info.structName == parentStructName) then
|
||||
logWarningAt parent m!"duplicate parent structure '{.ofConstName parentStructName}', skipping"
|
||||
logWarning m!"duplicate parent structure '{.ofConstName parentStructName}', skipping"
|
||||
go (i + 1) infos parents
|
||||
else if let some existingFieldName ← findExistingField? infos parentStructName then
|
||||
if structureDiamondWarning.get (← getOptions) then
|
||||
@@ -570,6 +546,13 @@ where
|
||||
else
|
||||
k infos parents
|
||||
|
||||
private def registerFailedToInferFieldType (fieldName : Name) (e : Expr) (ref : Syntax) : TermElabM Unit := do
|
||||
Term.registerCustomErrorIfMVar (← instantiateMVars e) ref m!"failed to infer type of field '{.ofConstName fieldName}'"
|
||||
|
||||
private def registerFailedToInferDefaultValue (fieldName : Name) (e : Expr) (ref : Syntax) : TermElabM Unit := do
|
||||
Term.registerCustomErrorIfMVar (← instantiateMVars e) ref m!"failed to infer default value for field '{.ofConstName fieldName}'"
|
||||
Term.registerLevelMVarErrorExprInfo e ref m!"failed to infer universe levels in default value for field '{.ofConstName fieldName}'"
|
||||
|
||||
private def elabFieldTypeValue (view : StructFieldView) : TermElabM (Option Expr × Option Expr) :=
|
||||
Term.withAutoBoundImplicit <| Term.withAutoBoundImplicitForbiddenPred (fun n => view.name == n) <| Term.elabBinders view.binders.getArgs fun params => do
|
||||
match view.type? with
|
||||
@@ -581,10 +564,13 @@ private def elabFieldTypeValue (view : StructFieldView) : TermElabM (Option Expr
|
||||
-- TODO: add forbidden predicate using `shortDeclName` from `view`
|
||||
let params ← Term.addAutoBoundImplicits params
|
||||
let value ← Term.withoutAutoBoundImplicit <| Term.elabTerm valStx none
|
||||
registerFailedToInferFieldType view.name (← inferType value) view.nameId
|
||||
registerFailedToInferDefaultValue view.name value valStx
|
||||
let value ← mkLambdaFVars params value
|
||||
return (none, value)
|
||||
| some typeStx =>
|
||||
let type ← Term.elabType typeStx
|
||||
registerFailedToInferFieldType view.name type typeStx
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let params ← Term.addAutoBoundImplicits params
|
||||
match view.value? with
|
||||
@@ -593,6 +579,7 @@ private def elabFieldTypeValue (view : StructFieldView) : TermElabM (Option Expr
|
||||
return (type, none)
|
||||
| some valStx =>
|
||||
let value ← Term.withoutAutoBoundImplicit <| Term.elabTermEnsuringType valStx type
|
||||
registerFailedToInferDefaultValue view.name value valStx
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let type ← mkForallFVars params type
|
||||
let value ← mkLambdaFVars params value
|
||||
@@ -639,6 +626,7 @@ where
|
||||
valStx ← `(fun $(view.binders.getArgs)* => $valStx:term)
|
||||
let fvarType ← inferType info.fvar
|
||||
let value ← Term.elabTermEnsuringType valStx fvarType
|
||||
registerFailedToInferDefaultValue view.name value valStx
|
||||
pushInfoLeaf <| .ofFieldRedeclInfo { stx := view.ref }
|
||||
let infos := replaceFieldInfo infos { info with ref := view.nameId, value? := value }
|
||||
go (i+1) defaultValsOverridden infos
|
||||
@@ -650,113 +638,14 @@ where
|
||||
else
|
||||
k infos
|
||||
|
||||
private def getResultUniverse (type : Expr) : TermElabM Level := do
|
||||
let type ← whnf type
|
||||
match type with
|
||||
| Expr.sort u => pure u
|
||||
| _ => throwError "unexpected structure resulting type"
|
||||
|
||||
private def collectUsed (params : Array Expr) (fieldInfos : Array StructFieldInfo) : StateRefT CollectFVars.State MetaM Unit := do
|
||||
params.forM fun p => do
|
||||
let type ← inferType p
|
||||
type.collectFVars
|
||||
fieldInfos.forM fun info => do
|
||||
let fvarType ← inferType info.fvar
|
||||
fvarType.collectFVars
|
||||
match info.value? with
|
||||
| none => pure ()
|
||||
| some value => value.collectFVars
|
||||
|
||||
private def removeUnused (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo)
|
||||
: TermElabM (LocalContext × LocalInstances × Array Expr) := do
|
||||
let (_, used) ← (collectUsed params fieldInfos).run {}
|
||||
Meta.removeUnused scopeVars used
|
||||
|
||||
private def withUsed {α} (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo) (k : Array Expr → TermElabM α)
|
||||
: TermElabM α := do
|
||||
let (lctx, localInsts, vars) ← removeUnused scopeVars params fieldInfos
|
||||
withLCtx lctx localInsts <| k vars
|
||||
|
||||
private def levelMVarToParam (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo) (univToInfer? : Option LMVarId) : TermElabM (Array StructFieldInfo) := do
|
||||
levelMVarToParamFVars scopeVars
|
||||
levelMVarToParamFVars params
|
||||
fieldInfos.mapM fun info => do
|
||||
levelMVarToParamFVar info.fvar
|
||||
match info.value? with
|
||||
| none => pure info
|
||||
| some value =>
|
||||
let value ← levelMVarToParam' value
|
||||
pure { info with value? := value }
|
||||
where
|
||||
levelMVarToParam' (type : Expr) : TermElabM Expr := do
|
||||
Term.levelMVarToParam type (except := fun mvarId => univToInfer? == some mvarId)
|
||||
|
||||
levelMVarToParamFVars (fvars : Array Expr) : TermElabM Unit :=
|
||||
fvars.forM levelMVarToParamFVar
|
||||
|
||||
levelMVarToParamFVar (fvar : Expr) : TermElabM Unit := do
|
||||
let type ← inferType fvar
|
||||
discard <| levelMVarToParam' type
|
||||
|
||||
|
||||
private partial def collectUniversesFromFields (r : Level) (rOffset : Nat) (fieldInfos : Array StructFieldInfo) : TermElabM (Array Level) := do
|
||||
let (_, us) ← go |>.run #[]
|
||||
return us
|
||||
where
|
||||
go : StateRefT (Array Level) TermElabM Unit :=
|
||||
for info in fieldInfos do
|
||||
let type ← inferType info.fvar
|
||||
let u ← getLevel type
|
||||
let u ← instantiateLevelMVars u
|
||||
match (← modifyGet fun s => accLevel u r rOffset |>.run |>.run s) with
|
||||
| some _ => pure ()
|
||||
| none =>
|
||||
let typeType ← inferType type
|
||||
let mut msg := m!"failed to compute resulting universe level of structure, field '{info.declName}' has type{indentD m!"{type} : {typeType}"}\nstructure resulting type{indentExpr (mkSort (r.addOffset rOffset))}"
|
||||
if r.isMVar then
|
||||
msg := msg ++ "\nrecall that Lean only infers the resulting universe level automatically when there is a unique solution for the universe level constraints, consider explicitly providing the structure resulting universe level"
|
||||
throwError msg
|
||||
|
||||
/--
|
||||
Decides whether the structure should be `Prop`-valued when the universe is not given
|
||||
and when the universe inference algorithm `collectUniversesFromFields` determines
|
||||
that the inductive type could naturally be `Prop`-valued.
|
||||
|
||||
See `Lean.Elab.Command.isPropCandidate` for an explanation.
|
||||
Specialized to structures, the heuristic is that we prefer a `Prop` instead of a `Type` structure
|
||||
when it could be a syntactic subsingleton.
|
||||
Exception: no-field structures are `Type` since they are likely stubbed-out declarations.
|
||||
-/
|
||||
private def isPropCandidate (fieldInfos : Array StructFieldInfo) : Bool :=
|
||||
!fieldInfos.isEmpty
|
||||
|
||||
private def updateResultingUniverse (fieldInfos : Array StructFieldInfo) (type : Expr) : TermElabM Expr := do
|
||||
let r ← getResultUniverse type
|
||||
let rOffset : Nat := r.getOffset
|
||||
let r : Level := r.getLevelOffset
|
||||
unless r.isMVar do
|
||||
throwError "failed to compute resulting universe level of inductive datatype, provide universe explicitly: {r}"
|
||||
let us ← collectUniversesFromFields r rOffset fieldInfos
|
||||
trace[Elab.structure] "updateResultingUniverse us: {us}, r: {r}, rOffset: {rOffset}"
|
||||
let rNew := mkResultUniverse us rOffset (isPropCandidate fieldInfos)
|
||||
assignLevelMVar r.mvarId! rNew
|
||||
instantiateMVars type
|
||||
|
||||
private def collectLevelParamsInFVar (s : CollectLevelParams.State) (fvar : Expr) : TermElabM CollectLevelParams.State := do
|
||||
let type ← inferType fvar
|
||||
let type ← instantiateMVars type
|
||||
return collectLevelParams s type
|
||||
|
||||
private def collectLevelParamsInFVars (fvars : Array Expr) (s : CollectLevelParams.State) : TermElabM CollectLevelParams.State :=
|
||||
fvars.foldlM collectLevelParamsInFVar s
|
||||
|
||||
private def collectLevelParamsInStructure (structType : Expr) (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo)
|
||||
: TermElabM (Array Name) := do
|
||||
let s := collectLevelParams {} structType
|
||||
let s ← collectLevelParamsInFVars scopeVars s
|
||||
let s ← collectLevelParamsInFVars params s
|
||||
let s ← fieldInfos.foldlM (init := s) fun s info => collectLevelParamsInFVar s info.fvar
|
||||
return s.params
|
||||
private def collectUsedFVars (lctx : LocalContext) (localInsts : LocalInstances) (fieldInfos : Array StructFieldInfo) :
|
||||
StateRefT CollectFVars.State MetaM Unit := do
|
||||
withLCtx lctx localInsts do
|
||||
fieldInfos.forM fun info => do
|
||||
let fvarType ← inferType info.fvar
|
||||
fvarType.collectFVars
|
||||
if let some value := info.value? then
|
||||
value.collectFVars
|
||||
|
||||
private def addCtorFields (fieldInfos : Array StructFieldInfo) : Nat → Expr → TermElabM Expr
|
||||
| 0, type => pure type
|
||||
@@ -772,19 +661,29 @@ private def addCtorFields (fieldInfos : Array StructFieldInfo) : Nat → Expr
|
||||
| _ =>
|
||||
addCtorFields fieldInfos i (mkForall decl.userName decl.binderInfo decl.type type)
|
||||
|
||||
private def mkCtor (view : StructView) (levelParams : List Name) (params : Array Expr) (fieldInfos : Array StructFieldInfo) : TermElabM Constructor :=
|
||||
private def mkCtor (view : StructView) (r : ElabHeaderResult) (params : Array Expr) (fieldInfos : Array StructFieldInfo) : TermElabM Constructor :=
|
||||
withRef view.ref do
|
||||
let type := mkAppN (mkConst view.declName (levelParams.map mkLevelParam)) params
|
||||
let type := mkAppN r.indFVar params
|
||||
let type ← addCtorFields fieldInfos fieldInfos.size type
|
||||
let type ← mkForallFVars params type
|
||||
let type ← instantiateMVars type
|
||||
let type := type.inferImplicit params.size true
|
||||
pure { name := view.ctor.declName, type }
|
||||
|
||||
private partial def checkResultingUniversesForFields (fieldInfos : Array StructFieldInfo) (u : Level) : TermElabM Unit := do
|
||||
for info in fieldInfos do
|
||||
let type ← inferType info.fvar
|
||||
let v := (← instantiateLevelMVars (← getLevel type)).normalize
|
||||
unless u.geq v do
|
||||
let msg := m!"invalid universe level for field '{info.name}', has type{indentExpr type}\n\
|
||||
at universe level{indentD v}\n\
|
||||
which is not less than or equal to the structure's resulting universe level{indentD u}"
|
||||
throwErrorAt info.ref msg
|
||||
|
||||
@[extern "lean_mk_projections"]
|
||||
private opaque mkProjections (env : Environment) (structName : Name) (projs : List Name) (isClass : Bool) : Except KernelException Environment
|
||||
|
||||
private def addProjections (r : ElabStructHeaderResult) (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
|
||||
private def addProjections (r : ElabHeaderResult) (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
|
||||
if r.type.isProp then
|
||||
if let some fieldInfo ← fieldInfos.findM? (not <$> Meta.isProof ·.fvar) then
|
||||
throwErrorAt fieldInfo.ref m!"failed to generate projections for 'Prop' structure, field '{format fieldInfo.name}' is not a proof"
|
||||
@@ -795,49 +694,71 @@ private def addProjections (r : ElabStructHeaderResult) (fieldInfos : Array Stru
|
||||
|
||||
private def registerStructure (structName : Name) (infos : Array StructFieldInfo) : TermElabM Unit := do
|
||||
let fields ← infos.filterMapM fun info => do
|
||||
if info.kind == StructFieldKind.fromParent then
|
||||
return none
|
||||
else
|
||||
return some {
|
||||
fieldName := info.name
|
||||
projFn := info.declName
|
||||
binderInfo := (← getFVarLocalDecl info.fvar).binderInfo
|
||||
autoParam? := (← inferType info.fvar).getAutoParamTactic?
|
||||
subobject? := if let .subobject parentName := info.kind then parentName else none
|
||||
}
|
||||
if info.kind == StructFieldKind.fromParent then
|
||||
return none
|
||||
else
|
||||
return some {
|
||||
fieldName := info.name
|
||||
projFn := info.declName
|
||||
binderInfo := (← getFVarLocalDecl info.fvar).binderInfo
|
||||
autoParam? := (← inferType info.fvar).getAutoParamTactic?
|
||||
subobject? := if let .subobject parentName := info.kind then parentName else none
|
||||
}
|
||||
modifyEnv fun env => Lean.registerStructure env { structName, fields }
|
||||
|
||||
private def mkAuxConstructions (declName : Name) : TermElabM Unit := do
|
||||
let env ← getEnv
|
||||
let hasEq := env.contains ``Eq
|
||||
let hasHEq := env.contains ``HEq
|
||||
let hasUnit := env.contains ``PUnit
|
||||
let hasProd := env.contains ``Prod
|
||||
mkRecOn declName
|
||||
if hasUnit then mkCasesOn declName
|
||||
if hasUnit && hasEq && hasHEq then mkNoConfusion declName
|
||||
let ival ← getConstInfoInduct declName
|
||||
if ival.isRec then
|
||||
if hasUnit && hasProd then mkBelow declName
|
||||
if hasUnit && hasProd then mkIBelow declName
|
||||
if hasUnit && hasProd then mkBRecOn declName
|
||||
if hasUnit && hasProd then mkBInductionOn declName
|
||||
private def checkDefaults (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
|
||||
let mut mvars := {}
|
||||
let mut lmvars := {}
|
||||
for fieldInfo in fieldInfos do
|
||||
if let some value := fieldInfo.value? then
|
||||
let value ← instantiateMVars value
|
||||
mvars := Expr.collectMVars mvars value
|
||||
lmvars := collectLevelMVars lmvars value
|
||||
-- Log errors and ignore the failure; we later will just omit adding a default value.
|
||||
if ← Term.logUnassignedUsingErrorInfos mvars.result then
|
||||
return
|
||||
else if ← Term.logUnassignedLevelMVarsUsingErrorInfos lmvars.result then
|
||||
return
|
||||
|
||||
private def addDefaults (lctx : LocalContext) (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
|
||||
private def addDefaults (params : Array Expr) (replaceIndFVars : Expr → MetaM Expr) (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
|
||||
let lctx ← getLCtx
|
||||
/- The `lctx` and `defaultAuxDecls` are used to create the auxiliary "default value" declarations
|
||||
The parameters `params` for these definitions must be marked as implicit, and all others as explicit. -/
|
||||
let lctx :=
|
||||
params.foldl (init := lctx) fun (lctx : LocalContext) (p : Expr) =>
|
||||
if p.isFVar then
|
||||
lctx.setBinderInfo p.fvarId! BinderInfo.implicit
|
||||
else
|
||||
lctx
|
||||
let lctx :=
|
||||
fieldInfos.foldl (init := lctx) fun (lctx : LocalContext) (info : StructFieldInfo) =>
|
||||
if info.isFromParent then lctx -- `fromParent` fields are elaborated as let-decls, and are zeta-expanded when creating "default value" auxiliary functions
|
||||
else lctx.setBinderInfo info.fvar.fvarId! BinderInfo.default
|
||||
-- Make all indFVar replacements in the local context.
|
||||
let lctx ←
|
||||
lctx.foldlM (init := {}) fun lctx ldecl => do
|
||||
match ldecl with
|
||||
| .cdecl _ fvarId userName type bi k =>
|
||||
let type ← replaceIndFVars type
|
||||
return lctx.mkLocalDecl fvarId userName type bi k
|
||||
| .ldecl _ fvarId userName type value nonDep k =>
|
||||
let type ← replaceIndFVars type
|
||||
let value ← replaceIndFVars value
|
||||
return lctx.mkLetDecl fvarId userName type value nonDep k
|
||||
withLCtx lctx (← getLocalInstances) do
|
||||
fieldInfos.forM fun fieldInfo => do
|
||||
if let some value := fieldInfo.value? then
|
||||
let declName := mkDefaultFnOfProjFn fieldInfo.declName
|
||||
let type ← inferType fieldInfo.fvar
|
||||
let value ← instantiateMVars value
|
||||
if value.hasExprMVar then
|
||||
discard <| Term.logUnassignedUsingErrorInfos (← getMVars value)
|
||||
throwErrorAt fieldInfo.ref "invalid default value for field '{format fieldInfo.name}', it contains metavariables{indentExpr value}"
|
||||
/- The identity function is used as "marker". -/
|
||||
let value ← mkId value
|
||||
-- No need to compile the definition, since it is only used during elaboration.
|
||||
discard <| mkAuxDefinition declName type value (zetaDelta := true) (compile := false)
|
||||
setReducibleAttribute declName
|
||||
let type ← replaceIndFVars (← inferType fieldInfo.fvar)
|
||||
let value ← instantiateMVars (← replaceIndFVars value)
|
||||
trace[Elab.structure] "default value after 'replaceIndFVars': {indentExpr value}"
|
||||
-- If there are mvars, `checkDefaults` already logged an error.
|
||||
unless value.hasMVar || value.hasSyntheticSorry do
|
||||
/- The identity function is used as "marker". -/
|
||||
let value ← mkId value
|
||||
-- No need to compile the definition, since it is only used during elaboration.
|
||||
discard <| mkAuxDefinition declName type value (zetaDelta := true) (compile := false)
|
||||
setReducibleAttribute declName
|
||||
|
||||
/--
|
||||
Given `type` of the form `forall ... (source : A), B`, return `forall ... [source : A], B`.
|
||||
@@ -906,57 +827,6 @@ private partial def mkCoercionToCopiedParent (levelParams : List Name) (params :
|
||||
setReducibleAttribute declName
|
||||
return { structName := parentStructName, subobject := false, projFn := declName }
|
||||
|
||||
private def elabStructHeader (view : StructView) : TermElabM ElabStructHeaderResult :=
|
||||
Term.withAutoBoundImplicitForbiddenPred (fun n => view.shortDeclName == n) do
|
||||
Term.withAutoBoundImplicit do
|
||||
Term.elabBinders view.binders.getArgs fun params => do
|
||||
elabParents view fun parentFieldInfos parents => do
|
||||
let type ← Term.elabType view.type
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let u ← mkFreshLevelMVar
|
||||
unless ← isDefEq type (mkSort u) do
|
||||
throwErrorAt view.type "invalid structure type, expecting 'Type _' or 'Prop'"
|
||||
let type ← instantiateMVars (← whnf type)
|
||||
Term.addAutoBoundImplicits' params type fun params type => do
|
||||
let levelNames ← Term.getLevelNames
|
||||
trace[Elab.structure] "header params: {params}, type: {type}, levelNames: {levelNames}"
|
||||
return { lctx := (← getLCtx), localInsts := (← getLocalInstances), levelNames, params, type, view, parents, parentFieldInfos }
|
||||
|
||||
private def mkTypeFor (r : ElabStructHeaderResult) : TermElabM Expr := do
|
||||
withLCtx r.lctx r.localInsts do
|
||||
mkForallFVars r.params r.type
|
||||
|
||||
/--
|
||||
Create a local declaration for the structure and execute `x params indFVar`, where `params` are the structure's type parameters and
|
||||
`indFVar` is the new local declaration.
|
||||
-/
|
||||
private partial def withStructureLocalDecl (r : ElabStructHeaderResult) (x : Array Expr → Expr → TermElabM α) : TermElabM α := do
|
||||
let declName := r.view.declName
|
||||
let shortDeclName := r.view.shortDeclName
|
||||
let type ← mkTypeFor r
|
||||
let params := r.params
|
||||
withLCtx r.lctx r.localInsts <| withRef r.view.ref do
|
||||
Term.withAuxDecl shortDeclName type declName fun indFVar =>
|
||||
x params indFVar
|
||||
|
||||
/--
|
||||
Remark: `numVars <= numParams`.
|
||||
`numVars` is the number of context `variables` used in the declaration,
|
||||
and `numParams - numVars` is the number of parameters provided as binders in the declaration.
|
||||
-/
|
||||
private def mkInductiveType (view : StructView) (indFVar : Expr) (levelNames : List Name)
|
||||
(numVars : Nat) (numParams : Nat) (type : Expr) (ctor : Constructor) : TermElabM InductiveType := do
|
||||
let levelParams := levelNames.map mkLevelParam
|
||||
let const := mkConst view.declName levelParams
|
||||
let ctorType ← forallBoundedTelescope ctor.type numParams fun params type => do
|
||||
let type := type.replace fun e =>
|
||||
if e == indFVar then
|
||||
mkAppN const (params.extract 0 numVars)
|
||||
else
|
||||
none
|
||||
instantiateMVars (← mkForallFVars params type)
|
||||
return { name := view.declName, type := ← instantiateMVars type, ctors := [{ ctor with type := ← instantiateMVars ctorType }] }
|
||||
|
||||
/--
|
||||
Precomputes the structure's resolution order.
|
||||
Option `structure.strictResolutionOrder` controls whether to create a warning if the C3 algorithm failed.
|
||||
@@ -987,109 +857,50 @@ private def addParentInstances (parents : Array StructureParentInfo) : MetaM Uni
|
||||
for instParent in instParents do
|
||||
addInstance instParent.projFn AttributeKind.global (eval_prio default)
|
||||
|
||||
def mkStructureDecl (vars : Array Expr) (view : StructView) : TermElabM Unit := Term.withoutSavingRecAppSyntax do
|
||||
let scopeLevelNames ← Term.getLevelNames
|
||||
let isUnsafe := view.modifiers.isUnsafe
|
||||
withRef view.ref <| Term.withLevelNames view.levelNames do
|
||||
let r ← elabStructHeader view
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
withLCtx r.lctx r.localInsts do
|
||||
withStructureLocalDecl r fun params indFVar => do
|
||||
trace[Elab.structure] "indFVar: {indFVar}"
|
||||
Term.addLocalVarInfo view.declId indFVar
|
||||
withFields view.fields r.parentFieldInfos fun fieldInfos =>
|
||||
withRef view.ref do
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let type ← instantiateMVars r.type
|
||||
let u ← getResultUniverse type
|
||||
let univToInfer? ← shouldInferResultUniverse u
|
||||
withUsed vars params fieldInfos fun scopeVars => do
|
||||
let fieldInfos ← levelMVarToParam scopeVars params fieldInfos univToInfer?
|
||||
let type ← withRef view.ref do
|
||||
if univToInfer?.isSome then
|
||||
updateResultingUniverse fieldInfos type
|
||||
else
|
||||
checkResultingUniverse (← getResultUniverse type)
|
||||
pure type
|
||||
trace[Elab.structure] "type: {type}"
|
||||
let usedLevelNames ← collectLevelParamsInStructure type scopeVars params fieldInfos
|
||||
match sortDeclLevelParams scopeLevelNames r.levelNames usedLevelNames with
|
||||
| Except.error msg => throwErrorAt view.declId msg
|
||||
| Except.ok levelParams =>
|
||||
let params := scopeVars ++ params
|
||||
let ctor ← mkCtor view levelParams params fieldInfos
|
||||
let type ← mkForallFVars params type
|
||||
let type ← instantiateMVars type
|
||||
let indType ← mkInductiveType view indFVar levelParams scopeVars.size params.size type ctor
|
||||
let decl := Declaration.inductDecl levelParams params.size [indType] isUnsafe
|
||||
Term.ensureNoUnassignedMVars decl
|
||||
addDecl decl
|
||||
-- rename indFVar so that it does not shadow the actual declaration:
|
||||
let lctx := (← getLCtx).modifyLocalDecl indFVar.fvarId! fun decl => decl.setUserName .anonymous
|
||||
withLCtx lctx (← getLocalInstances) do
|
||||
addProjections r fieldInfos
|
||||
registerStructure view.declName fieldInfos
|
||||
mkAuxConstructions view.declName
|
||||
withSaveInfoContext do -- save new env
|
||||
Term.addLocalVarInfo view.ref[1] (← mkConstWithLevelParams view.declName)
|
||||
if let some _ := view.ctor.ref.getPos? (canonicalOnly := true) then
|
||||
Term.addTermInfo' view.ctor.ref (← mkConstWithLevelParams view.ctor.declName) (isBinder := true)
|
||||
for field in view.fields do
|
||||
-- may not exist if overriding inherited field
|
||||
if (← getEnv).contains field.declName then
|
||||
Term.addTermInfo' field.ref (← mkConstWithLevelParams field.declName) (isBinder := true)
|
||||
withRef view.declId do
|
||||
Term.applyAttributesAt view.declName view.modifiers.attrs AttributeApplicationTime.afterTypeChecking
|
||||
let parentInfos ← r.parents.mapM fun parent => do
|
||||
if parent.subobject then
|
||||
let some info := fieldInfos.find? (·.kind == .subobject parent.structName) | unreachable!
|
||||
pure { structName := parent.structName, subobject := true, projFn := info.declName }
|
||||
else
|
||||
mkCoercionToCopiedParent levelParams params view parent.structName parent.type
|
||||
setStructureParents view.declName parentInfos
|
||||
checkResolutionOrder view.declName
|
||||
if view.isClass then
|
||||
addParentInstances parentInfos
|
||||
|
||||
let lctx ← getLCtx
|
||||
/- The `lctx` and `defaultAuxDecls` are used to create the auxiliary "default value" declarations
|
||||
The parameters `params` for these definitions must be marked as implicit, and all others as explicit. -/
|
||||
let lctx :=
|
||||
params.foldl (init := lctx) fun (lctx : LocalContext) (p : Expr) =>
|
||||
if p.isFVar then
|
||||
lctx.setBinderInfo p.fvarId! BinderInfo.implicit
|
||||
else
|
||||
lctx
|
||||
let lctx :=
|
||||
fieldInfos.foldl (init := lctx) fun (lctx : LocalContext) (info : StructFieldInfo) =>
|
||||
if info.isFromParent then lctx -- `fromParent` fields are elaborated as let-decls, and are zeta-expanded when creating "default value" auxiliary functions
|
||||
else lctx.setBinderInfo info.fvar.fvarId! BinderInfo.default
|
||||
addDefaults lctx fieldInfos
|
||||
|
||||
|
||||
def elabStructureView (vars : Array Expr) (view : StructView) : TermElabM Unit := do
|
||||
Term.withDeclName view.declName <| withRef view.ref do
|
||||
mkStructureDecl vars view
|
||||
unless view.isClass do
|
||||
Lean.Meta.IndPredBelow.mkBelow view.declName
|
||||
mkSizeOfInstances view.declName
|
||||
mkInjectiveTheorems view.declName
|
||||
|
||||
def elabStructureViewPostprocessing (view : StructView) : CommandElabM Unit := do
|
||||
view.derivingClasses.forM fun classView => classView.applyHandlers #[view.declName]
|
||||
runTermElabM fun _ => Term.withDeclName view.declName <| withRef view.declId do
|
||||
Term.applyAttributesAt view.declName view.modifiers.attrs .afterCompilation
|
||||
|
||||
def elabStructure (modifiers : Modifiers) (stx : Syntax) : CommandElabM Unit := do
|
||||
let view ← runTermElabM fun vars => do
|
||||
@[builtin_inductive_elab Lean.Parser.Command.«structure»]
|
||||
def elabStructureCommand : InductiveElabDescr where
|
||||
mkInductiveView (modifiers : Modifiers) (stx : Syntax) := do
|
||||
let view ← structureSyntaxToView modifiers stx
|
||||
trace[Elab.structure] "view.levelNames: {view.levelNames}"
|
||||
elabStructureView vars view
|
||||
pure view
|
||||
elabStructureViewPostprocessing view
|
||||
return {
|
||||
view := view.toInductiveView
|
||||
elabCtors := fun rs r params => do
|
||||
withParents view rs r.indFVar fun parentFieldInfos parents =>
|
||||
withFields view.fields parentFieldInfos fun fieldInfos => do
|
||||
withRef view.ref do
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let lctx ← getLCtx
|
||||
let localInsts ← getLocalInstances
|
||||
let ctor ← mkCtor view r params fieldInfos
|
||||
return {
|
||||
ctors := [ctor]
|
||||
collectUsedFVars := collectUsedFVars lctx localInsts fieldInfos
|
||||
checkUniverses := fun _ u => withLCtx lctx localInsts do checkResultingUniversesForFields fieldInfos u
|
||||
finalizeTermElab := withLCtx lctx localInsts do checkDefaults fieldInfos
|
||||
prefinalize := fun _ _ _ => do
|
||||
withLCtx lctx localInsts do
|
||||
addProjections r fieldInfos
|
||||
registerStructure view.declName fieldInfos
|
||||
withSaveInfoContext do -- save new env
|
||||
for field in view.fields do
|
||||
-- may not exist if overriding inherited field
|
||||
if (← getEnv).contains field.declName then
|
||||
Term.addTermInfo' field.ref (← mkConstWithLevelParams field.declName) (isBinder := true)
|
||||
finalize := fun levelParams params replaceIndFVars => do
|
||||
let parentInfos ← parents.mapM fun parent => do
|
||||
if parent.subobject then
|
||||
let some info := fieldInfos.find? (·.kind == .subobject parent.structName) | unreachable!
|
||||
pure { structName := parent.structName, subobject := true, projFn := info.declName }
|
||||
else
|
||||
mkCoercionToCopiedParent levelParams params view parent.structName parent.type
|
||||
setStructureParents view.declName parentInfos
|
||||
checkResolutionOrder view.declName
|
||||
if view.isClass then
|
||||
addParentInstances parentInfos
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Elab.structure
|
||||
registerTraceClass `Elab.structure.resolutionOrder
|
||||
withLCtx lctx localInsts do
|
||||
addDefaults params replaceIndFVars fieldInfos
|
||||
}
|
||||
}
|
||||
|
||||
end Lean.Elab.Command
|
||||
|
||||
@@ -19,12 +19,12 @@ def expandOptPrecedence (stx : Syntax) : MacroM (Option Nat) :=
|
||||
return some (← evalPrec stx[0][1])
|
||||
|
||||
private def mkParserSeq (ds : Array (Term × Nat)) : TermElabM (Term × Nat) := do
|
||||
if ds.size == 0 then
|
||||
if h₀ : ds.size = 0 then
|
||||
throwUnsupportedSyntax
|
||||
else if ds.size == 1 then
|
||||
pure ds[0]!
|
||||
else if h₁ : ds.size = 1 then
|
||||
pure ds[0]
|
||||
else
|
||||
let mut (r, stackSum) := ds[0]!
|
||||
let mut (r, stackSum) := ds[0]
|
||||
for (d, stackSz) in ds[1:ds.size] do
|
||||
r ← `(ParserDescr.binary `andthen $r $d)
|
||||
stackSum := stackSum + stackSz
|
||||
@@ -142,7 +142,7 @@ where
|
||||
let args := stx.getArgs
|
||||
if (← checkLeftRec stx[0]) then
|
||||
if args.size == 1 then throwErrorAt stx "invalid atomic left recursive syntax"
|
||||
let args := args.eraseIdx 0
|
||||
let args := args.eraseIdxIfInBounds 0
|
||||
let args ← args.mapM fun arg => withNestedParser do process arg
|
||||
mkParserSeq args
|
||||
else
|
||||
|
||||
@@ -149,8 +149,8 @@ where
|
||||
-- Succeeded. Collect new TC problems
|
||||
trace[Elab.defaultInstance] "isDefEq worked {mkMVar mvarId} : {← inferType (mkMVar mvarId)} =?= {candidate} : {← inferType candidate}"
|
||||
let mut pending := []
|
||||
for i in [:bis.size] do
|
||||
if bis[i]! == BinderInfo.instImplicit then
|
||||
for h : i in [:bis.size] do
|
||||
if bis[i] == BinderInfo.instImplicit then
|
||||
pending := mvars[i]!.mvarId! :: pending
|
||||
synthesizePending pending
|
||||
else
|
||||
|
||||
@@ -58,15 +58,28 @@ where
|
||||
let some eqBVPred ← ReifiedBVPred.mkBinPred atom resValExpr atomExpr resExpr .eq | return none
|
||||
let eqBV ← ReifiedBVLogical.ofPred eqBVPred
|
||||
|
||||
let trueExpr := mkConst ``Bool.true
|
||||
let impExpr ← mkArrow (← mkEq eqDiscrExpr trueExpr) (← mkEq eqBVExpr trueExpr)
|
||||
let decideImpExpr ← mkAppOptM ``Decidable.decide #[some impExpr, none]
|
||||
let imp ← ReifiedBVLogical.mkGate eqDiscr eqBV eqDiscrExpr eqBVExpr .imp
|
||||
|
||||
let proof := do
|
||||
let evalExpr ← ReifiedBVLogical.mkEvalExpr imp.expr
|
||||
let congrProof ← imp.evalsAtAtoms
|
||||
let lemmaProof := mkApp4 (mkConst lemmaName) (toExpr lhs.width) discrExpr lhsExpr rhsExpr
|
||||
|
||||
let trueExpr := mkConst ``Bool.true
|
||||
let eqDiscrTrueExpr ← mkEq eqDiscrExpr trueExpr
|
||||
let eqBVExprTrueExpr ← mkEq eqBVExpr trueExpr
|
||||
let impExpr ← mkArrow eqDiscrTrueExpr eqBVExprTrueExpr
|
||||
-- construct a `Decidable` instance for the implication using forall_prop_decidable
|
||||
let decEqDiscrTrue := mkApp2 (mkConst ``instDecidableEqBool) eqDiscrExpr trueExpr
|
||||
let decEqBVExprTrue := mkApp2 (mkConst ``instDecidableEqBool) eqBVExpr trueExpr
|
||||
let impDecidable := mkApp4 (mkConst ``forall_prop_decidable)
|
||||
eqDiscrTrueExpr
|
||||
(.lam .anonymous eqDiscrTrueExpr eqBVExprTrueExpr .default)
|
||||
decEqDiscrTrue
|
||||
(.lam .anonymous eqDiscrTrueExpr decEqBVExprTrue .default)
|
||||
|
||||
let decideImpExpr := mkApp2 (mkConst ``Decidable.decide) impExpr impDecidable
|
||||
|
||||
return mkApp4
|
||||
(mkConst ``Std.Tactic.BVDecide.Reflect.Bool.lemma_congr)
|
||||
decideImpExpr
|
||||
|
||||
@@ -218,7 +218,7 @@ where
|
||||
let old ← snap.old?
|
||||
-- If the kind is equal, we can assume the old version was a macro as well
|
||||
guard <| old.stx.isOfKind stx.getKind
|
||||
let state ← old.val.get.data.finished.get.state?
|
||||
let state ← old.val.get.finished.get.state?
|
||||
guard <| state.term.meta.core.nextMacroScope == nextMacroScope
|
||||
-- check absence of traces; see Note [Incremental Macros]
|
||||
guard <| state.term.meta.core.traceState.traces.size == 0
|
||||
@@ -226,9 +226,10 @@ where
|
||||
return old.val.get
|
||||
Language.withAlwaysResolvedPromise fun promise => do
|
||||
-- Store new unfolding in the snapshot tree
|
||||
snap.new.resolve <| .mk {
|
||||
snap.new.resolve {
|
||||
stx := stx'
|
||||
diagnostics := .empty
|
||||
inner? := none
|
||||
finished := .pure {
|
||||
diagnostics := .empty
|
||||
state? := (← Tactic.saveState)
|
||||
@@ -240,7 +241,7 @@ where
|
||||
new := promise
|
||||
old? := do
|
||||
let old ← old?
|
||||
return ⟨old.data.stx, (← old.data.next.get? 0)⟩
|
||||
return ⟨old.stx, (← old.next.get? 0)⟩
|
||||
} }) do
|
||||
evalTactic stx'
|
||||
return
|
||||
|
||||
@@ -60,7 +60,7 @@ where
|
||||
if let some snap := (← readThe Term.Context).tacSnap? then
|
||||
if let some old := snap.old? then
|
||||
let oldParsed := old.val.get
|
||||
oldInner? := oldParsed.data.inner? |>.map (⟨oldParsed.data.stx, ·⟩)
|
||||
oldInner? := oldParsed.inner? |>.map (⟨oldParsed.stx, ·⟩)
|
||||
-- compare `stx[0]` for `finished`/`next` reuse, focus on remainder of script
|
||||
Term.withNarrowedTacticReuse (stx := stx) (fun stx => (stx[0], mkNullNode stx.getArgs[1:])) fun stxs => do
|
||||
let some snap := (← readThe Term.Context).tacSnap?
|
||||
@@ -70,10 +70,10 @@ where
|
||||
if let some old := snap.old? then
|
||||
-- `tac` must be unchanged given the narrow above; let's reuse `finished`'s state!
|
||||
let oldParsed := old.val.get
|
||||
if let some state := oldParsed.data.finished.get.state? then
|
||||
if let some state := oldParsed.finished.get.state? then
|
||||
reusableResult? := some ((), state)
|
||||
-- only allow `next` reuse in this case
|
||||
oldNext? := oldParsed.data.next.get? 0 |>.map (⟨old.stx, ·⟩)
|
||||
oldNext? := oldParsed.next.get? 0 |>.map (⟨old.stx, ·⟩)
|
||||
|
||||
-- For `tac`'s snapshot task range, disregard synthetic info as otherwise
|
||||
-- `SnapshotTree.findInfoTreeAtPos` might choose the wrong snapshot: for example, when
|
||||
@@ -89,7 +89,7 @@ where
|
||||
withAlwaysResolvedPromise fun next => do
|
||||
withAlwaysResolvedPromise fun finished => do
|
||||
withAlwaysResolvedPromise fun inner => do
|
||||
snap.new.resolve <| .mk {
|
||||
snap.new.resolve {
|
||||
desc := tac.getKind.toString
|
||||
diagnostics := .empty
|
||||
stx := tac
|
||||
|
||||
@@ -282,10 +282,11 @@ where
|
||||
-- them, eventually put each of them back in `Context.tacSnap?` in `applyAltStx`
|
||||
withAlwaysResolvedPromise fun finished => do
|
||||
withAlwaysResolvedPromises altStxs.size fun altPromises => do
|
||||
tacSnap.new.resolve <| .mk {
|
||||
tacSnap.new.resolve {
|
||||
-- save all relevant syntax here for comparison with next document version
|
||||
stx := mkNullNode altStxs
|
||||
diagnostics := .empty
|
||||
inner? := none
|
||||
finished := { range? := none, task := finished.result }
|
||||
next := altStxs.zipWith altPromises fun stx prom =>
|
||||
{ range? := stx.getRange?, task := prom.result }
|
||||
@@ -298,7 +299,7 @@ where
|
||||
let old := old.val.get
|
||||
-- use old version of `mkNullNode altsSyntax` as guard, will be compared with new
|
||||
-- version and picked apart in `applyAltStx`
|
||||
return ⟨old.data.stx, (← old.data.next[i]?)⟩
|
||||
return ⟨old.stx, (← old.next[i]?)⟩
|
||||
new := prom
|
||||
}
|
||||
finished.resolve { diagnostics := .empty, state? := (← saveState) }
|
||||
@@ -340,9 +341,9 @@ where
|
||||
for h : altStxIdx in [0:altStxs.size] do
|
||||
let altStx := altStxs[altStxIdx]
|
||||
let altName := getAltName altStx
|
||||
if let some i := alts.findIdx? (·.1 == altName) then
|
||||
if let some i := alts.findFinIdx? (·.1 == altName) then
|
||||
-- cover named alternative
|
||||
applyAltStx tacSnaps altStxIdx altStx alts[i]!
|
||||
applyAltStx tacSnaps altStxIdx altStx alts[i]
|
||||
alts := alts.eraseIdx i
|
||||
else if !alts.isEmpty && isWildcard altStx then
|
||||
-- cover all alternatives
|
||||
|
||||
@@ -40,7 +40,7 @@ private def tryAssumption (mvarId : MVarId) : MetaM (List MVarId) := do
|
||||
let ids := stx[1].getArgs.toList.map getNameOfIdent'
|
||||
liftMetaTactic fun mvarId => do
|
||||
match (← Meta.injections mvarId ids) with
|
||||
| .solved => checkUnusedIds `injections mvarId ids; return []
|
||||
| .subgoal mvarId' unusedIds => checkUnusedIds `injections mvarId unusedIds; tryAssumption mvarId'
|
||||
| .solved => checkUnusedIds `injections mvarId ids; return []
|
||||
| .subgoal mvarId' unusedIds _ => checkUnusedIds `injections mvarId unusedIds; tryAssumption mvarId'
|
||||
|
||||
end Lean.Elab.Tactic
|
||||
|
||||
@@ -230,7 +230,7 @@ instance : ToSnapshotTree TacticFinishedSnapshot where
|
||||
toSnapshotTree s := ⟨s.toSnapshot, #[]⟩
|
||||
|
||||
/-- Snapshot just before execution of a tactic. -/
|
||||
structure TacticParsedSnapshotData (TacticParsedSnapshot : Type) extends Language.Snapshot where
|
||||
structure TacticParsedSnapshot extends Language.Snapshot where
|
||||
/-- Syntax tree of the tactic, stored and compared for incremental reuse. -/
|
||||
stx : Syntax
|
||||
/-- Task for nested incrementality, if enabled for tactic. -/
|
||||
@@ -240,16 +240,9 @@ structure TacticParsedSnapshotData (TacticParsedSnapshot : Type) extends Languag
|
||||
/-- Tasks for subsequent, potentially parallel, tactic steps. -/
|
||||
next : Array (SnapshotTask TacticParsedSnapshot) := #[]
|
||||
deriving Inhabited
|
||||
|
||||
/-- State after execution of a single synchronous tactic step. -/
|
||||
inductive TacticParsedSnapshot where
|
||||
| mk (data : TacticParsedSnapshotData TacticParsedSnapshot)
|
||||
deriving Inhabited
|
||||
abbrev TacticParsedSnapshot.data : TacticParsedSnapshot → TacticParsedSnapshotData TacticParsedSnapshot
|
||||
| .mk data => data
|
||||
partial instance : ToSnapshotTree TacticParsedSnapshot where
|
||||
toSnapshotTree := go where
|
||||
go := fun ⟨s⟩ => ⟨s.toSnapshot,
|
||||
go := fun s => ⟨s.toSnapshot,
|
||||
s.inner?.toArray.map (·.map (sync := true) go) ++
|
||||
#[s.finished.map (sync := true) toSnapshotTree] ++
|
||||
s.next.map (·.map (sync := true) go)⟩
|
||||
@@ -480,6 +473,26 @@ def withoutTacticReuse [Monad m] [MonadWithReaderOf Context m] [MonadOptions m]
|
||||
return !cond }
|
||||
}) act
|
||||
|
||||
@[inherit_doc Core.wrapAsync]
|
||||
def wrapAsync (act : Unit → TermElabM α) : TermElabM (EIO Exception α) := do
|
||||
let ctx ← read
|
||||
let st ← get
|
||||
let metaCtx ← readThe Meta.Context
|
||||
let metaSt ← getThe Meta.State
|
||||
Core.wrapAsync fun _ =>
|
||||
act () |>.run ctx |>.run' st |>.run' metaCtx metaSt
|
||||
|
||||
@[inherit_doc Core.wrapAsyncAsSnapshot]
|
||||
def wrapAsyncAsSnapshot (act : Unit → TermElabM Unit)
|
||||
(desc : String := by exact decl_name%.toString) :
|
||||
TermElabM (BaseIO Language.SnapshotTree) := do
|
||||
let ctx ← read
|
||||
let st ← get
|
||||
let metaCtx ← readThe Meta.Context
|
||||
let metaSt ← getThe Meta.State
|
||||
Core.wrapAsyncAsSnapshot (desc := desc) fun _ =>
|
||||
act () |>.run ctx |>.run' st |>.run' metaCtx metaSt
|
||||
|
||||
abbrev TermElabResult (α : Type) := EStateM.Result Exception SavedState α
|
||||
|
||||
/--
|
||||
|
||||
@@ -317,8 +317,8 @@ partial def ensureExtensionsArraySize (exts : Array EnvExtensionState) : IO (Arr
|
||||
where
|
||||
loop (i : Nat) (exts : Array EnvExtensionState) : IO (Array EnvExtensionState) := do
|
||||
let envExtensions ← envExtensionsRef.get
|
||||
if i < envExtensions.size then
|
||||
let s ← envExtensions[i]!.mkInitial
|
||||
if h : i < envExtensions.size then
|
||||
let s ← envExtensions[i].mkInitial
|
||||
let exts := exts.push s
|
||||
loop (i + 1) exts
|
||||
else
|
||||
@@ -726,7 +726,6 @@ def mkExtNameMap (startingAt : Nat) : IO (Std.HashMap Name Nat) := do
|
||||
let descrs ← persistentEnvExtensionsRef.get
|
||||
let mut result := {}
|
||||
for h : i in [startingAt : descrs.size] do
|
||||
have : i < descrs.size := h.upper
|
||||
let descr := descrs[i]
|
||||
result := result.insert descr.name i
|
||||
return result
|
||||
@@ -740,7 +739,6 @@ private def setImportedEntries (env : Environment) (mods : Array ModuleData) (st
|
||||
/- For each module `mod`, and `mod.entries`, if the extension name is one of the extensions after `startingAt`, set `entries` -/
|
||||
let extNameIdx ← mkExtNameMap startingAt
|
||||
for h : modIdx in [:mods.size] do
|
||||
have : modIdx < mods.size := h.upper
|
||||
let mod := mods[modIdx]
|
||||
for (extName, entries) in mod.entries do
|
||||
if let some entryIdx := extNameIdx[extName]? then
|
||||
@@ -860,7 +858,7 @@ def finalizeImport (s : ImportState) (imports : Array Import) (opts : Options) (
|
||||
let mut const2ModIdx : Std.HashMap Name ModuleIdx := Std.HashMap.empty (capacity := numConsts)
|
||||
let mut constantMap : Std.HashMap Name ConstantInfo := Std.HashMap.empty (capacity := numConsts)
|
||||
for h : modIdx in [0:s.moduleData.size] do
|
||||
let mod := s.moduleData[modIdx]'h.upper
|
||||
let mod := s.moduleData[modIdx]
|
||||
for cname in mod.constNames, cinfo in mod.constants do
|
||||
match constantMap.getThenInsertIfNew? cname cinfo with
|
||||
| (cinfoPrev?, constantMap') =>
|
||||
|
||||
@@ -10,9 +10,8 @@ Authors: Sebastian Ullrich
|
||||
|
||||
prelude
|
||||
import Init.System.Promise
|
||||
import Lean.Message
|
||||
import Lean.Parser.Types
|
||||
import Lean.Elab.InfoTree
|
||||
import Lean.Util.Trace
|
||||
|
||||
set_option linter.missingDocs true
|
||||
|
||||
@@ -56,6 +55,11 @@ structure Snapshot where
|
||||
/-- General elaboration metadata produced by this step. -/
|
||||
infoTree? : Option Elab.InfoTree := none
|
||||
/--
|
||||
Trace data produced by this step. Currently used only by `trace.profiler.output`, otherwise we
|
||||
depend on the elaborator adding traces to `diagnostics` eventually.
|
||||
-/
|
||||
traces : TraceState := {}
|
||||
/--
|
||||
Whether it should be indicated to the user that a fatal error (which should be part of
|
||||
`diagnostics`) occurred that prevents processing of the remainder of the file.
|
||||
-/
|
||||
@@ -193,32 +197,13 @@ def withAlwaysResolvedPromises [Monad m] [MonadLiftT BaseIO m] [MonadFinally m]
|
||||
for asynchronously collecting information about the entirety of snapshots in the language server.
|
||||
The involved tasks may form a DAG on the `Task` dependency level but this is not captured by this
|
||||
data structure. -/
|
||||
inductive SnapshotTree where
|
||||
/-- Creates a snapshot tree node. -/
|
||||
| mk (element : Snapshot) (children : Array (SnapshotTask SnapshotTree))
|
||||
structure SnapshotTree where
|
||||
/-- The immediately available element of the snapshot tree node. -/
|
||||
element : Snapshot
|
||||
/-- The asynchronously available children of the snapshot tree node. -/
|
||||
children : Array (SnapshotTask SnapshotTree)
|
||||
deriving Inhabited
|
||||
|
||||
/-- The immediately available element of the snapshot tree node. -/
|
||||
abbrev SnapshotTree.element : SnapshotTree → Snapshot
|
||||
| mk s _ => s
|
||||
/-- The asynchronously available children of the snapshot tree node. -/
|
||||
abbrev SnapshotTree.children : SnapshotTree → Array (SnapshotTask SnapshotTree)
|
||||
| mk _ children => children
|
||||
|
||||
/-- Produces trace of given snapshot tree, synchronously waiting on all children. -/
|
||||
partial def SnapshotTree.trace (s : SnapshotTree) : CoreM Unit :=
|
||||
go none s
|
||||
where go range? s := do
|
||||
let file ← getFileMap
|
||||
let mut desc := f!"{s.element.desc}"
|
||||
if let some range := range? then
|
||||
desc := desc ++ f!"{file.toPosition range.start}-{file.toPosition range.stop} "
|
||||
desc := desc ++ .prefixJoin "\n• " (← s.element.diagnostics.msgLog.toList.mapM (·.toString))
|
||||
if let some t := s.element.infoTree? then
|
||||
trace[Elab.info] (← t.format)
|
||||
withTraceNode `Elab.snapshotTree (fun _ => pure desc) do
|
||||
s.children.toList.forM fun c => go c.range? c.get
|
||||
|
||||
/--
|
||||
Helper class for projecting a heterogeneous hierarchy of snapshot classes to a homogeneous
|
||||
representation. -/
|
||||
@@ -303,6 +288,14 @@ def SnapshotTree.runAndReport (s : SnapshotTree) (opts : Options) (json := false
|
||||
def SnapshotTree.getAll (s : SnapshotTree) : Array Snapshot :=
|
||||
s.forM (m := StateM _) (fun s => modify (·.push s)) |>.run #[] |>.2
|
||||
|
||||
/-- Returns a task that waits on all snapshots in the tree. -/
|
||||
def SnapshotTree.waitAll : SnapshotTree → BaseIO (Task Unit)
|
||||
| mk _ children => go children.toList
|
||||
where
|
||||
go : List (SnapshotTask SnapshotTree) → BaseIO (Task Unit)
|
||||
| [] => return .pure ()
|
||||
| t::ts => BaseIO.bindTask t.task fun _ => go ts
|
||||
|
||||
/-- Context of an input processing invocation. -/
|
||||
structure ProcessingContext extends Parser.InputContext
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ Authors: Sebastian Ullrich
|
||||
|
||||
prelude
|
||||
import Lean.Language.Basic
|
||||
import Lean.Language.Util
|
||||
import Lean.Language.Lean.Types
|
||||
import Lean.Parser.Module
|
||||
import Lean.Elab.Import
|
||||
@@ -166,13 +167,6 @@ namespace Lean.Language.Lean
|
||||
open Lean.Elab Command
|
||||
open Lean.Parser
|
||||
|
||||
/-- Option for capturing output to stderr during elaboration. -/
|
||||
register_builtin_option stderrAsMessages : Bool := {
|
||||
defValue := true
|
||||
group := "server"
|
||||
descr := "(server) capture output to the Lean stderr channel (such as from `dbg_trace`) during elaboration of a command as a diagnostic message"
|
||||
}
|
||||
|
||||
/-- Lean-specific processing context. -/
|
||||
structure LeanProcessingContext extends ProcessingContext where
|
||||
/-- Position of the first file difference if there was a previous invocation. -/
|
||||
@@ -348,7 +342,7 @@ where
|
||||
if let some (some processed) ← old.processedResult.get? then
|
||||
-- ...and the edit location is after the next command (see note [Incremental Parsing])...
|
||||
if let some nextCom ← processed.firstCmdSnap.get? then
|
||||
if (← isBeforeEditPos nextCom.data.parserState.pos) then
|
||||
if (← isBeforeEditPos nextCom.parserState.pos) then
|
||||
-- ...go immediately to next snapshot
|
||||
return (← unchanged old old.stx oldSuccess.parserState)
|
||||
|
||||
@@ -470,20 +464,20 @@ where
|
||||
-- from `old`
|
||||
if let some oldNext := old.nextCmdSnap? then do
|
||||
let newProm ← IO.Promise.new
|
||||
let _ ← old.data.finishedSnap.bindIO (sync := true) fun oldFinished =>
|
||||
let _ ← old.finishedSnap.bindIO (sync := true) fun oldFinished =>
|
||||
-- also wait on old command parse snapshot as parsing is cheap and may allow for
|
||||
-- elaboration reuse
|
||||
oldNext.bindIO (sync := true) fun oldNext => do
|
||||
parseCmd oldNext newParserState oldFinished.cmdState newProm sync ctx
|
||||
return .pure ()
|
||||
prom.resolve <| .mk (data := old.data) (nextCmdSnap? := some { range? := none, task := newProm.result })
|
||||
prom.resolve <| { old with nextCmdSnap? := some { range? := none, task := newProm.result } }
|
||||
else prom.resolve old -- terminal command, we're done!
|
||||
|
||||
-- fast path, do not even start new task for this snapshot
|
||||
if let some old := old? then
|
||||
if let some nextCom ← old.nextCmdSnap?.bindM (·.get?) then
|
||||
if (← isBeforeEditPos nextCom.data.parserState.pos) then
|
||||
return (← unchanged old old.data.parserState)
|
||||
if (← isBeforeEditPos nextCom.parserState.pos) then
|
||||
return (← unchanged old old.parserState)
|
||||
|
||||
let beginPos := parserState.pos
|
||||
let scope := cmdState.scopes.head!
|
||||
@@ -500,7 +494,7 @@ where
|
||||
-- NOTE: as `parserState.pos` includes trailing whitespace, this forces reprocessing even if
|
||||
-- only that whitespace changes, which is wasteful but still necessary because it may
|
||||
-- influence the range of error messages such as from a trailing `exact`
|
||||
if stx.eqWithInfo old.data.stx then
|
||||
if stx.eqWithInfo old.stx then
|
||||
-- Here we must make sure to pass the *new* parser state; see NOTE in `unchanged`
|
||||
return (← unchanged old parserState)
|
||||
-- on first change, make sure to cancel old invocation
|
||||
@@ -515,11 +509,13 @@ where
|
||||
-- this is a bit ugly as we don't want to adjust our API with `Option`s just for cancellation
|
||||
-- (as no-one should look at this result in that case) but anything containing `Environment`
|
||||
-- is not `Inhabited`
|
||||
prom.resolve <| .mk (nextCmdSnap? := none) {
|
||||
prom.resolve <| {
|
||||
diagnostics := .empty, stx := .missing, parserState
|
||||
elabSnap := .pure <| .ofTyped { diagnostics := .empty : SnapshotLeaf }
|
||||
finishedSnap := .pure { diagnostics := .empty, cmdState }
|
||||
reportSnap := default
|
||||
tacticCache := (← IO.mkRef {})
|
||||
nextCmdSnap? := none
|
||||
}
|
||||
return
|
||||
|
||||
@@ -528,6 +524,7 @@ where
|
||||
-- definitely resolved in `doElab` task
|
||||
let elabPromise ← IO.Promise.new
|
||||
let finishedPromise ← IO.Promise.new
|
||||
let reportPromise ← IO.Promise.new
|
||||
-- (Try to) use last line of command as range for final snapshot task. This ensures we do not
|
||||
-- retract the progress bar to a previous position in case the command support incremental
|
||||
-- reporting but has significant work after resolving its last incremental promise, such as
|
||||
@@ -537,30 +534,55 @@ where
|
||||
-- irrelevant in this case.
|
||||
let endRange? := stx.getTailPos?.map fun pos => ⟨pos, pos⟩
|
||||
let finishedSnap := { range? := endRange?, task := finishedPromise.result }
|
||||
let tacticCache ← old?.map (·.data.tacticCache) |>.getDM (IO.mkRef {})
|
||||
let tacticCache ← old?.map (·.tacticCache) |>.getDM (IO.mkRef {})
|
||||
|
||||
let minimalSnapshots := internal.cmdlineSnapshots.get cmdState.scopes.head!.opts
|
||||
let next? ← if Parser.isTerminalCommand stx then pure none
|
||||
-- for now, wait on "command finished" snapshot before parsing next command
|
||||
else some <$> IO.Promise.new
|
||||
let nextCmdSnap? := next?.map
|
||||
({ range? := some ⟨parserState.pos, ctx.input.endPos⟩, task := ·.result })
|
||||
let diagnostics ← Snapshot.Diagnostics.ofMessageLog msgLog
|
||||
let data := if minimalSnapshots && !Parser.isTerminalCommand stx then {
|
||||
diagnostics
|
||||
stx := .missing
|
||||
parserState := {}
|
||||
let (stx', parserState') := if minimalSnapshots && !Parser.isTerminalCommand stx then
|
||||
(default, default)
|
||||
else
|
||||
(stx, parserState)
|
||||
prom.resolve {
|
||||
diagnostics, finishedSnap, tacticCache, nextCmdSnap?
|
||||
stx := stx', parserState := parserState'
|
||||
elabSnap := { range? := stx.getRange?, task := elabPromise.result }
|
||||
finishedSnap
|
||||
tacticCache
|
||||
} else {
|
||||
diagnostics, stx, parserState, tacticCache
|
||||
elabSnap := { range? := stx.getRange?, task := elabPromise.result }
|
||||
finishedSnap
|
||||
reportSnap := { range? := endRange?, task := reportPromise.result }
|
||||
}
|
||||
prom.resolve <| .mk (nextCmdSnap? := next?.map
|
||||
({ range? := some ⟨parserState.pos, ctx.input.endPos⟩, task := ·.result })) data
|
||||
let cmdState ← doElab stx cmdState beginPos
|
||||
{ old? := old?.map fun old => ⟨old.data.stx, old.data.elabSnap⟩, new := elabPromise }
|
||||
{ old? := old?.map fun old => ⟨old.stx, old.elabSnap⟩, new := elabPromise }
|
||||
finishedPromise tacticCache ctx
|
||||
let traceTask ←
|
||||
if (← isTracingEnabledForCore `Elab.snapshotTree cmdState.scopes.head!.opts) then
|
||||
-- We want to trace all of `CommandParsedSnapshot` but `traceTask` is part of it, so let's
|
||||
-- create a temporary snapshot tree containing all tasks but it
|
||||
let snaps := #[
|
||||
{ range? := none, task := elabPromise.result.map (sync := true) toSnapshotTree },
|
||||
{ range? := none, task := finishedPromise.result.map (sync := true) toSnapshotTree }] ++
|
||||
cmdState.snapshotTasks
|
||||
let tree := SnapshotTree.mk { diagnostics := .empty } snaps
|
||||
BaseIO.bindTask (← tree.waitAll) fun _ => do
|
||||
let .ok (_, s) ← EIO.toBaseIO <| tree.trace |>.run
|
||||
{ ctx with options := cmdState.scopes.head!.opts } { env := cmdState.env }
|
||||
| pure <| .pure <| .mk { diagnostics := .empty } #[]
|
||||
let mut msgLog := MessageLog.empty
|
||||
for trace in s.traceState.traces do
|
||||
msgLog := msgLog.add {
|
||||
fileName := ctx.fileName
|
||||
severity := MessageSeverity.information
|
||||
pos := ctx.fileMap.toPosition beginPos
|
||||
data := trace.msg
|
||||
}
|
||||
return .pure <| .mk { diagnostics := (← Snapshot.Diagnostics.ofMessageLog msgLog) } #[]
|
||||
else
|
||||
pure <| .pure <| .mk { diagnostics := .empty } #[]
|
||||
reportPromise.resolve <|
|
||||
.mk { diagnostics := .empty } <|
|
||||
cmdState.snapshotTasks.push { range? := endRange?, task := traceTask }
|
||||
if let some next := next? then
|
||||
-- We're definitely off the fast-forwarding path now
|
||||
parseCmd none parserState cmdState next (sync := false) ctx
|
||||
@@ -571,7 +593,9 @@ where
|
||||
LeanProcessingM Command.State := do
|
||||
let ctx ← read
|
||||
let scope := cmdState.scopes.head!
|
||||
let cmdStateRef ← IO.mkRef { cmdState with messages := .empty }
|
||||
-- reset per-command state
|
||||
let cmdStateRef ← IO.mkRef { cmdState with
|
||||
messages := .empty, traceState := {}, snapshotTasks := #[] }
|
||||
/-
|
||||
The same snapshot may be executed by different tasks. So, to make sure `elabCommandTopLevel`
|
||||
has exclusive access to the cache, we create a fresh reference here. Before this change, the
|
||||
@@ -586,8 +610,8 @@ where
|
||||
cancelTk? := some ctx.newCancelTk
|
||||
}
|
||||
let (output, _) ←
|
||||
IO.FS.withIsolatedStreams (isolateStderr := stderrAsMessages.get scope.opts) do
|
||||
liftM (m := BaseIO) do
|
||||
IO.FS.withIsolatedStreams (isolateStderr := Core.stderrAsMessages.get scope.opts) do
|
||||
EIO.toBaseIO do
|
||||
withLoggingExceptions
|
||||
(getResetInfoTrees *> Elab.Command.elabCommandTopLevel stx)
|
||||
cmdCtx cmdStateRef
|
||||
@@ -613,6 +637,7 @@ where
|
||||
finishedPromise.resolve {
|
||||
diagnostics := (← Snapshot.Diagnostics.ofMessageLog cmdState.messages)
|
||||
infoTree? := infoTree
|
||||
traces := cmdState.traceState
|
||||
cmdState := if cmdline then {
|
||||
env := Runtime.markPersistent cmdState.env
|
||||
maxRecDepth := 0
|
||||
@@ -644,6 +669,6 @@ where goCmd snap :=
|
||||
if let some next := snap.nextCmdSnap? then
|
||||
goCmd next.get
|
||||
else
|
||||
snap.data.finishedSnap.get.cmdState
|
||||
snap.finishedSnap.get.cmdState
|
||||
|
||||
end Lean
|
||||
|
||||
@@ -34,7 +34,7 @@ instance : ToSnapshotTree CommandFinishedSnapshot where
|
||||
toSnapshotTree s := ⟨s.toSnapshot, #[]⟩
|
||||
|
||||
/-- State after a command has been parsed. -/
|
||||
structure CommandParsedSnapshotData extends Snapshot where
|
||||
structure CommandParsedSnapshot extends Snapshot where
|
||||
/-- Syntax tree of the command. -/
|
||||
stx : Syntax
|
||||
/-- Resulting parser state. -/
|
||||
@@ -46,29 +46,19 @@ structure CommandParsedSnapshotData extends Snapshot where
|
||||
elabSnap : SnapshotTask DynamicSnapshot
|
||||
/-- State after processing is finished. -/
|
||||
finishedSnap : SnapshotTask CommandFinishedSnapshot
|
||||
/-- Additional, untyped snapshots used for reporting, not reuse. -/
|
||||
reportSnap : SnapshotTask SnapshotTree
|
||||
/-- Cache for `save`; to be replaced with incrementality. -/
|
||||
tacticCache : IO.Ref Tactic.Cache
|
||||
/-- Next command, unless this is a terminal command. -/
|
||||
nextCmdSnap? : Option (SnapshotTask CommandParsedSnapshot)
|
||||
deriving Nonempty
|
||||
|
||||
/-- State after a command has been parsed. -/
|
||||
-- workaround for lack of recursive structures
|
||||
inductive CommandParsedSnapshot where
|
||||
/-- Creates a command parsed snapshot. -/
|
||||
| mk (data : CommandParsedSnapshotData)
|
||||
(nextCmdSnap? : Option (SnapshotTask CommandParsedSnapshot))
|
||||
deriving Nonempty
|
||||
/-- The snapshot data. -/
|
||||
abbrev CommandParsedSnapshot.data : CommandParsedSnapshot → CommandParsedSnapshotData
|
||||
| mk data _ => data
|
||||
/-- Next command, unless this is a terminal command. -/
|
||||
abbrev CommandParsedSnapshot.nextCmdSnap? : CommandParsedSnapshot →
|
||||
Option (SnapshotTask CommandParsedSnapshot)
|
||||
| mk _ next? => next?
|
||||
partial instance : ToSnapshotTree CommandParsedSnapshot where
|
||||
toSnapshotTree := go where
|
||||
go s := ⟨s.data.toSnapshot,
|
||||
#[s.data.elabSnap.map (sync := true) toSnapshotTree,
|
||||
s.data.finishedSnap.map (sync := true) toSnapshotTree] |>
|
||||
go s := ⟨s.toSnapshot,
|
||||
#[s.elabSnap.map (sync := true) toSnapshotTree,
|
||||
s.finishedSnap.map (sync := true) toSnapshotTree,
|
||||
s.reportSnap] |>
|
||||
pushOpt (s.nextCmdSnap?.map (·.map (sync := true) go))⟩
|
||||
|
||||
/-- State after successful importing. -/
|
||||
|
||||
29
src/Lean/Language/Util.lean
Normal file
29
src/Lean/Language/Util.lean
Normal file
@@ -0,0 +1,29 @@
|
||||
/-
|
||||
Copyright (c) 2023 Lean FRO. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
|
||||
Additional snapshot functionality that needs further imports.
|
||||
|
||||
Authors: Sebastian Ullrich
|
||||
-/
|
||||
|
||||
prelude
|
||||
import Lean.Language.Basic
|
||||
import Lean.CoreM
|
||||
import Lean.Elab.InfoTree
|
||||
|
||||
namespace Lean.Language
|
||||
|
||||
/-- Produces trace of given snapshot tree, synchronously waiting on all children. -/
|
||||
partial def SnapshotTree.trace (s : SnapshotTree) : CoreM Unit :=
|
||||
go none s
|
||||
where go range? s := do
|
||||
let file ← getFileMap
|
||||
let mut desc := f!"{s.element.desc}"
|
||||
if let some range := range? then
|
||||
desc := desc ++ f!"{file.toPosition range.start}-{file.toPosition range.stop} "
|
||||
desc := desc ++ .prefixJoin "\n• " (← s.element.diagnostics.msgLog.toList.mapM (·.toString))
|
||||
if let some t := s.element.infoTree? then
|
||||
trace[Elab.info] (← t.format)
|
||||
withTraceNode `Elab.snapshotTree (fun _ => pure desc) do
|
||||
s.children.toList.forM fun c => go c.range? c.get
|
||||
@@ -406,8 +406,8 @@ def isSubPrefixOf (lctx₁ lctx₂ : LocalContext) (exceptFVars : Array Expr :=
|
||||
|
||||
@[inline] def mkBinding (isLambda : Bool) (lctx : LocalContext) (xs : Array Expr) (b : Expr) : Expr :=
|
||||
let b := b.abstract xs
|
||||
xs.size.foldRev (init := b) fun i b =>
|
||||
let x := xs[i]!
|
||||
xs.size.foldRev (init := b) fun i _ b =>
|
||||
let x := xs[i]
|
||||
match lctx.findFVar? x with
|
||||
| some (.cdecl _ _ n ty bi _) =>
|
||||
let ty := ty.abstractRange i xs;
|
||||
@@ -457,7 +457,7 @@ def sanitizeNames (lctx : LocalContext) : StateM NameSanitizerState LocalContext
|
||||
let st ← get
|
||||
if !getSanitizeNames st.options then pure lctx else
|
||||
StateT.run' (s := ({} : NameSet)) <|
|
||||
lctx.decls.size.foldRevM (init := lctx) fun i lctx => do
|
||||
lctx.decls.size.foldRevM (init := lctx) fun i _ lctx => do
|
||||
match lctx.decls[i]! with
|
||||
| none => pure lctx
|
||||
| some decl =>
|
||||
|
||||
@@ -244,8 +244,8 @@ partial def formatAux : NamingContext → Option MessageDataContext → MessageD
|
||||
| panic! s!"MessageData.ofLazy: expected MessageData in Dynamic, got {dyn.typeName}"
|
||||
formatAux nCtx ctx? msg
|
||||
|
||||
protected def format (msgData : MessageData) : IO Format :=
|
||||
formatAux { currNamespace := Name.anonymous, openDecls := [] } none msgData
|
||||
protected def format (msgData : MessageData) (ctx? : Option MessageDataContext := none) : IO Format :=
|
||||
formatAux { currNamespace := Name.anonymous, openDecls := [] } ctx? msgData
|
||||
|
||||
protected def toString (msgData : MessageData) : IO String := do
|
||||
return toString (← msgData.format)
|
||||
|
||||
@@ -825,7 +825,7 @@ def mkFreshExprMVarWithId (mvarId : MVarId) (type? : Option Expr := none) (kind
|
||||
mkFreshExprMVarWithIdCore mvarId type kind userName
|
||||
|
||||
def mkFreshLevelMVars (num : Nat) : MetaM (List Level) :=
|
||||
num.foldM (init := []) fun _ us =>
|
||||
num.foldM (init := []) fun _ _ us =>
|
||||
return (← mkFreshLevelMVar)::us
|
||||
|
||||
def mkFreshLevelMVarsFor (info : ConstantInfo) : MetaM (List Level) :=
|
||||
|
||||
@@ -286,8 +286,8 @@ partial def process : ClosureM Unit := do
|
||||
@[inline] def mkBinding (isLambda : Bool) (decls : Array LocalDecl) (b : Expr) : Expr :=
|
||||
let xs := decls.map LocalDecl.toExpr
|
||||
let b := b.abstract xs
|
||||
decls.size.foldRev (init := b) fun i b =>
|
||||
let decl := decls[i]!
|
||||
decls.size.foldRev (init := b) fun i _ b =>
|
||||
let decl := decls[i]
|
||||
match decl with
|
||||
| .cdecl _ _ n ty bi _ =>
|
||||
let ty := ty.abstractRange i xs
|
||||
|
||||
@@ -122,8 +122,8 @@ def mkHCongr (f : Expr) : MetaM CongrTheorem := do
|
||||
private def fixKindsForDependencies (info : FunInfo) (kinds : Array CongrArgKind) : Array CongrArgKind := Id.run do
|
||||
let mut kinds := kinds
|
||||
for i in [:info.paramInfo.size] do
|
||||
for j in [i+1:info.paramInfo.size] do
|
||||
if info.paramInfo[j]!.backDeps.contains i then
|
||||
for hj : j in [i+1:info.paramInfo.size] do
|
||||
if info.paramInfo[j].backDeps.contains i then
|
||||
if kinds[j]! matches CongrArgKind.eq || kinds[j]! matches CongrArgKind.fixed then
|
||||
-- We must fix `i` because there is a `j` that depends on `i` and `j` is not cast-fixed.
|
||||
kinds := kinds.set! i CongrArgKind.fixed
|
||||
|
||||
@@ -255,8 +255,8 @@ private def isDefEqArgsFirstPass
|
||||
(paramInfo : Array ParamInfo) (args₁ args₂ : Array Expr) : MetaM DefEqArgsFirstPassResult := do
|
||||
let mut postponedImplicit := #[]
|
||||
let mut postponedHO := #[]
|
||||
for i in [:paramInfo.size] do
|
||||
let info := paramInfo[i]!
|
||||
for h : i in [:paramInfo.size] do
|
||||
let info := paramInfo[i]
|
||||
let a₁ := args₁[i]!
|
||||
let a₂ := args₂[i]!
|
||||
if info.dependsOnHigherOrderOutParam || info.higherOrderOutParam then
|
||||
@@ -939,29 +939,6 @@ def check (hasCtxLocals : Bool) (mctx : MetavarContext) (lctx : LocalContext) (m
|
||||
|
||||
end CheckAssignmentQuick
|
||||
|
||||
/--
|
||||
Auxiliary function used at `typeOccursCheckImp`.
|
||||
Given `type`, it tries to eliminate "dependencies". For example, suppose we are trying to
|
||||
perform the assignment `?m := f (?n a b)` where
|
||||
```
|
||||
?n : let k := g ?m; A -> h k ?m -> C
|
||||
```
|
||||
If we just perform occurs check `?m` at the type of `?n`, we get a failure, but
|
||||
we claim these occurrences are ok because the type `?n a b : C`.
|
||||
In the example above, `typeOccursCheckImp` invokes this function with `n := 2`.
|
||||
Note that we avoid using `whnf` and `inferType` at `typeOccursCheckImp` to minimize the
|
||||
performance impact of this extra check.
|
||||
|
||||
See test `typeOccursCheckIssue.lean` for an example where this refinement is needed.
|
||||
The test is derived from a Mathlib file.
|
||||
-/
|
||||
private partial def skipAtMostNumBinders (type : Expr) (n : Nat) : Expr :=
|
||||
match type, n with
|
||||
| .forallE _ _ b _, n+1 => skipAtMostNumBinders b n
|
||||
| .mdata _ b, n => skipAtMostNumBinders b n
|
||||
| .letE _ _ v b _, n => skipAtMostNumBinders (b.instantiate1 v) n
|
||||
| type, _ => type
|
||||
|
||||
/-- `typeOccursCheck` implementation using unsafe (i.e., pointer equality) features. -/
|
||||
private unsafe def typeOccursCheckImp (mctx : MetavarContext) (mvarId : MVarId) (v : Expr) : Bool :=
|
||||
if v.hasExprMVar then
|
||||
@@ -982,19 +959,11 @@ where
|
||||
-- this function assumes all assigned metavariables have already been
|
||||
-- instantiated.
|
||||
go.run' mctx
|
||||
visitMVar (mvarId' : MVarId) (numArgs : Nat := 0) : Bool :=
|
||||
visitMVar (mvarId' : MVarId) : Bool :=
|
||||
if let some mvarDecl := mctx.findDecl? mvarId' then
|
||||
occursCheck (skipAtMostNumBinders mvarDecl.type numArgs)
|
||||
occursCheck mvarDecl.type
|
||||
else
|
||||
false
|
||||
visitApp (e : Expr) : StateM (PtrSet Expr) Bool :=
|
||||
e.withApp fun f args => do
|
||||
unless (← args.allM visit) do
|
||||
return false
|
||||
if f.isMVar then
|
||||
return visitMVar f.mvarId! args.size
|
||||
else
|
||||
visit f
|
||||
visit (e : Expr) : StateM (PtrSet Expr) Bool := do
|
||||
if !e.hasExprMVar then
|
||||
return true
|
||||
@@ -1003,7 +972,7 @@ where
|
||||
else match e with
|
||||
| .mdata _ b => visit b
|
||||
| .proj _ _ s => visit s
|
||||
| .app .. => visitApp e
|
||||
| .app f a => visit f <&&> visit a
|
||||
| .lam _ d b _ => visit d <&&> visit b
|
||||
| .forallE _ d b _ => visit d <&&> visit b
|
||||
| .letE _ t v b _ => visit t <&&> visit v <&&> visit b
|
||||
|
||||
@@ -93,8 +93,8 @@ def setMVarUserNamesAt (e : Expr) (isTarget : Array Expr) : MetaM (Array MVarId)
|
||||
forEachExpr (← instantiateMVars e) fun e => do
|
||||
if e.isApp then
|
||||
let args := e.getAppArgs
|
||||
for i in [:args.size] do
|
||||
let arg := args[i]!
|
||||
for h : i in [:args.size] do
|
||||
let arg := args[i]
|
||||
if arg.isMVar && isTarget.contains arg then
|
||||
let mvarId := arg.mvarId!
|
||||
if (← mvarId.getDecl).userName.isAnonymous then
|
||||
|
||||
@@ -83,7 +83,7 @@ where
|
||||
forallTelescopeReducing t fun xs s => do
|
||||
let motiveType ← instantiateForall motive xs[:numParams]
|
||||
withLocalDecl motiveName BinderInfo.implicit motiveType fun motive => do
|
||||
mkForallFVars (xs.insertAt! numParams motive) s)
|
||||
mkForallFVars (xs.insertIdxIfInBounds numParams motive) s)
|
||||
|
||||
motiveType (indVal : InductiveVal) : MetaM Expr :=
|
||||
forallTelescopeReducing indVal.type fun xs _ => do
|
||||
@@ -152,7 +152,7 @@ where
|
||||
let run (x : StateRefT Nat MetaM Expr) : MetaM (Expr × Nat) := StateRefT'.run x 0
|
||||
let (_, cnt) ← run <| transform domain fun e => do
|
||||
if let some name := e.constName? then
|
||||
if let some _ := ctx.typeInfos.findIdx? fun indVal => indVal.name == name then
|
||||
if ctx.typeInfos.any fun indVal => indVal.name == name then
|
||||
modify (· + 1)
|
||||
return .continue
|
||||
|
||||
|
||||
@@ -566,9 +566,9 @@ Append results to array
|
||||
partial def appendResultsAux (mr : MatchResult α) (a : Array β) (f : Nat → α → β) : Array β :=
|
||||
let aa := mr.elts
|
||||
let n := aa.size
|
||||
Nat.fold (n := n) (init := a) fun i r =>
|
||||
Nat.fold (n := n) (init := a) fun i _ r =>
|
||||
let j := n-1-i
|
||||
let b := aa[j]!
|
||||
let b := aa[j]
|
||||
b.foldl (init := r) (· ++ ·.map (f j))
|
||||
|
||||
partial def appendResults (mr : MatchResult α) (a : Array α) : Array α :=
|
||||
|
||||
@@ -882,7 +882,7 @@ def mkMatcher (input : MkMatcherInput) (exceptionIfContainsSorry := false) : Met
|
||||
| none => pure ()
|
||||
|
||||
trace[Meta.Match.debug] "matcher: {matcher}"
|
||||
let unusedAltIdxs := lhss.length.fold (init := []) fun i r =>
|
||||
let unusedAltIdxs := lhss.length.fold (init := []) fun i _ r =>
|
||||
if s.used.contains i then r else i::r
|
||||
return {
|
||||
matcher,
|
||||
|
||||
@@ -129,9 +129,9 @@ where
|
||||
let typeNew := b.instantiate1 y
|
||||
if let some (_, lhs, rhs) ← matchEq? d then
|
||||
if lhs.isFVar && ys.contains lhs && args.contains lhs && isNamedPatternProof typeNew y then
|
||||
let some j := ys.getIdx? lhs | unreachable!
|
||||
let some j := ys.indexOf? lhs | unreachable!
|
||||
let ys := ys.eraseIdx j
|
||||
let some k := args.getIdx? lhs | unreachable!
|
||||
let some k := args.indexOf? lhs | unreachable!
|
||||
let mask := mask.set! k false
|
||||
let args := args.map fun arg => if arg == lhs then rhs else arg
|
||||
let arg ← mkEqRefl rhs
|
||||
@@ -222,16 +222,16 @@ private def contradiction (mvarId : MVarId) : MetaM Bool :=
|
||||
Auxiliary tactic that tries to replace as many variables as possible and then apply `contradiction`.
|
||||
We use it to discard redundant hypotheses.
|
||||
-/
|
||||
partial def trySubstVarsAndContradiction (mvarId : MVarId) : MetaM Bool :=
|
||||
partial def trySubstVarsAndContradiction (mvarId : MVarId) (forbidden : FVarIdSet := {}) : MetaM Bool :=
|
||||
commitWhen do
|
||||
let mvarId ← substVars mvarId
|
||||
match (← injections mvarId) with
|
||||
match (← injections mvarId (forbidden := forbidden)) with
|
||||
| .solved => return true -- closed goal
|
||||
| .subgoal mvarId' _ =>
|
||||
| .subgoal mvarId' _ forbidden =>
|
||||
if mvarId' == mvarId then
|
||||
contradiction mvarId
|
||||
else
|
||||
trySubstVarsAndContradiction mvarId'
|
||||
trySubstVarsAndContradiction mvarId' forbidden
|
||||
|
||||
private def processNextEq : M Bool := do
|
||||
let s ← get
|
||||
@@ -375,7 +375,8 @@ private partial def withSplitterAlts (altTypes : Array Expr) (f : Array Expr →
|
||||
inductive InjectionAnyResult where
|
||||
| solved
|
||||
| failed
|
||||
| subgoal (mvarId : MVarId)
|
||||
/-- `fvarId` refers to the local declaration selected for the application of the `injection` tactic. -/
|
||||
| subgoal (fvarId : FVarId) (mvarId : MVarId)
|
||||
|
||||
private def injectionAnyCandidate? (type : Expr) : MetaM (Option (Expr × Expr)) := do
|
||||
if let some (_, lhs, rhs) ← matchEq? type then
|
||||
@@ -385,21 +386,28 @@ private def injectionAnyCandidate? (type : Expr) : MetaM (Option (Expr × Expr))
|
||||
return some (lhs, rhs)
|
||||
return none
|
||||
|
||||
private def injectionAny (mvarId : MVarId) : MetaM InjectionAnyResult := do
|
||||
/--
|
||||
Try applying `injection` to a local declaration that is not in `forbidden`.
|
||||
|
||||
We use `forbidden` because the `injection` tactic might fail to clear the variable if there are forward dependencies.
|
||||
See `proveSubgoalLoop` for additional details.
|
||||
-/
|
||||
private def injectionAny (mvarId : MVarId) (forbidden : FVarIdSet := {}) : MetaM InjectionAnyResult := do
|
||||
mvarId.withContext do
|
||||
for localDecl in (← getLCtx) do
|
||||
if let some (lhs, rhs) ← injectionAnyCandidate? localDecl.type then
|
||||
unless (← isDefEq lhs rhs) do
|
||||
let lhs ← whnf lhs
|
||||
let rhs ← whnf rhs
|
||||
unless lhs.isRawNatLit && rhs.isRawNatLit do
|
||||
try
|
||||
match (← injection mvarId localDecl.fvarId) with
|
||||
| InjectionResult.solved => return InjectionAnyResult.solved
|
||||
| InjectionResult.subgoal mvarId .. => return InjectionAnyResult.subgoal mvarId
|
||||
catch ex =>
|
||||
trace[Meta.Match.matchEqs] "injectionAnyFailed at {localDecl.userName}, error\n{ex.toMessageData}"
|
||||
pure ()
|
||||
unless forbidden.contains localDecl.fvarId do
|
||||
if let some (lhs, rhs) ← injectionAnyCandidate? localDecl.type then
|
||||
unless (← isDefEq lhs rhs) do
|
||||
let lhs ← whnf lhs
|
||||
let rhs ← whnf rhs
|
||||
unless lhs.isRawNatLit && rhs.isRawNatLit do
|
||||
try
|
||||
match (← injection mvarId localDecl.fvarId) with
|
||||
| InjectionResult.solved => return InjectionAnyResult.solved
|
||||
| InjectionResult.subgoal mvarId .. => return InjectionAnyResult.subgoal localDecl.fvarId mvarId
|
||||
catch ex =>
|
||||
trace[Meta.Match.matchEqs] "injectionAnyFailed at {localDecl.userName}, error\n{ex.toMessageData}"
|
||||
pure ()
|
||||
return InjectionAnyResult.failed
|
||||
|
||||
|
||||
@@ -557,8 +565,8 @@ where
|
||||
let mut minorBodyNew := minor
|
||||
-- We have to extend the mapping to make sure `convertTemplate` can "fix" occurrences of the refined minor premises
|
||||
let mut m ← read
|
||||
for i in [:isAlt.size] do
|
||||
if isAlt[i]! then
|
||||
for h : i in [:isAlt.size] do
|
||||
if isAlt[i] then
|
||||
-- `convertTemplate` will correct occurrences of the alternative
|
||||
let alt := args[6+i]! -- Recall that `Eq.ndrec` has 6 arguments
|
||||
let some (_, numParams, argMask) := m.find? alt.fvarId! | unreachable!
|
||||
@@ -601,27 +609,32 @@ where
|
||||
let eNew := mkAppN eNew mvars
|
||||
return TransformStep.done eNew
|
||||
|
||||
proveSubgoalLoop (mvarId : MVarId) : MetaM Unit := do
|
||||
/-
|
||||
`forbidden` tracks variables that we have already applied `injection`.
|
||||
Recall that the `injection` tactic may not be able to eliminate them when
|
||||
they have forward dependencies.
|
||||
-/
|
||||
proveSubgoalLoop (mvarId : MVarId) (forbidden : FVarIdSet) : MetaM Unit := do
|
||||
trace[Meta.Match.matchEqs] "proveSubgoalLoop\n{mvarId}"
|
||||
if (← mvarId.contradictionQuick) then
|
||||
return ()
|
||||
match (← injectionAny mvarId) with
|
||||
| InjectionAnyResult.solved => return ()
|
||||
| InjectionAnyResult.failed =>
|
||||
match (← injectionAny mvarId forbidden) with
|
||||
| .solved => return ()
|
||||
| .failed =>
|
||||
let mvarId' ← substVars mvarId
|
||||
if mvarId' == mvarId then
|
||||
if (← mvarId.contradictionCore {}) then
|
||||
return ()
|
||||
throwError "failed to generate splitter for match auxiliary declaration '{matchDeclName}', unsolved subgoal:\n{MessageData.ofGoal mvarId}"
|
||||
else
|
||||
proveSubgoalLoop mvarId'
|
||||
| InjectionAnyResult.subgoal mvarId => proveSubgoalLoop mvarId
|
||||
proveSubgoalLoop mvarId' forbidden
|
||||
| .subgoal fvarId mvarId => proveSubgoalLoop mvarId (forbidden.insert fvarId)
|
||||
|
||||
proveSubgoal (mvarId : MVarId) : MetaM Unit := do
|
||||
trace[Meta.Match.matchEqs] "subgoal {mkMVar mvarId}, {repr (← mvarId.getDecl).kind}, {← mvarId.isAssigned}\n{MessageData.ofGoal mvarId}"
|
||||
let (_, mvarId) ← mvarId.intros
|
||||
let mvarId ← mvarId.tryClearMany (alts.map (·.fvarId!))
|
||||
proveSubgoalLoop mvarId
|
||||
proveSubgoalLoop mvarId {}
|
||||
|
||||
/--
|
||||
Create new alternatives (aka minor premises) by replacing `discrs` with `patterns` at `alts`.
|
||||
@@ -646,7 +659,7 @@ where
|
||||
/--
|
||||
Create conditional equations and splitter for the given match auxiliary declaration. -/
|
||||
private partial def mkEquationsFor (matchDeclName : Name) : MetaM MatchEqns := withLCtx {} {} do
|
||||
trace[Meta.Match.matchEqs] "mkEquationsFor '{matchDeclName}'"
|
||||
withTraceNode `Meta.Match.matchEqs (fun _ => return m!"mkEquationsFor '{matchDeclName}'") do
|
||||
withConfig (fun c => { c with etaStruct := .none }) do
|
||||
/-
|
||||
Remark: user have requested the `split` tactic to be available for writing code.
|
||||
|
||||
@@ -59,9 +59,9 @@ def addArg (matcherApp : MatcherApp) (e : Expr) : MetaM MatcherApp :=
|
||||
-- This error can only happen if someone implemented a transformation that rewrites the motive created by `mkMatcher`.
|
||||
throwError "unexpected matcher application, motive must be lambda expression with #{matcherApp.discrs.size} arguments"
|
||||
let eType ← inferType e
|
||||
let eTypeAbst ← matcherApp.discrs.size.foldRevM (init := eType) fun i eTypeAbst => do
|
||||
let eTypeAbst ← matcherApp.discrs.size.foldRevM (init := eType) fun i _ eTypeAbst => do
|
||||
let motiveArg := motiveArgs[i]!
|
||||
let discr := matcherApp.discrs[i]!
|
||||
let discr := matcherApp.discrs[i]
|
||||
let eTypeAbst ← kabstract eTypeAbst discr
|
||||
return eTypeAbst.instantiate1 motiveArg
|
||||
let motiveBody ← mkArrow eTypeAbst motiveBody
|
||||
@@ -118,9 +118,9 @@ def refineThrough (matcherApp : MatcherApp) (e : Expr) : MetaM (Array Expr) :=
|
||||
-- This error can only happen if someone implemented a transformation that rewrites the motive created by `mkMatcher`.
|
||||
throwError "failed to transfer argument through matcher application, motive must be lambda expression with #{matcherApp.discrs.size} arguments"
|
||||
|
||||
let eAbst ← matcherApp.discrs.size.foldRevM (init := e) fun i eAbst => do
|
||||
let eAbst ← matcherApp.discrs.size.foldRevM (init := e) fun i _ eAbst => do
|
||||
let motiveArg := motiveArgs[i]!
|
||||
let discr := matcherApp.discrs[i]!
|
||||
let discr := matcherApp.discrs[i]
|
||||
let eTypeAbst ← kabstract eAbst discr
|
||||
return eTypeAbst.instantiate1 motiveArg
|
||||
-- Let's create something that’s a `Sort` and mentions `e`
|
||||
|
||||
@@ -107,7 +107,7 @@ private def getMajorPosDepElim (declName : Name) (majorPos? : Option Nat) (xs :
|
||||
if motiveArgs.isEmpty then
|
||||
throwError "invalid user defined recursor, '{declName}' does not support dependent elimination, and position of the major premise was not specified (solution: set attribute '[recursor <pos>]', where <pos> is the position of the major premise)"
|
||||
let major := motiveArgs.back!
|
||||
match xs.getIdx? major with
|
||||
match xs.indexOf? major with
|
||||
| some majorPos => pure (major, majorPos, true)
|
||||
| none => throwError "ill-formed recursor '{declName}'"
|
||||
|
||||
|
||||
@@ -104,8 +104,8 @@ def appendParentTag (mvarId : MVarId) (newMVars : Array Expr) (binderInfos : Arr
|
||||
newMVars[0]!.mvarId!.setTag parentTag
|
||||
else
|
||||
unless parentTag.isAnonymous do
|
||||
newMVars.size.forM fun i => do
|
||||
let mvarIdNew := newMVars[i]!.mvarId!
|
||||
newMVars.size.forM fun i _ => do
|
||||
let mvarIdNew := newMVars[i].mvarId!
|
||||
unless (← mvarIdNew.isAssigned) do
|
||||
unless binderInfos[i]!.isInstImplicit do
|
||||
let currTag ← mvarIdNew.getTag
|
||||
|
||||
@@ -168,7 +168,7 @@ private def hasIndepIndices (ctx : Context) : MetaM Bool := do
|
||||
else if ctx.majorTypeIndices.any fun idx => !idx.isFVar then
|
||||
/- One of the indices is not a free variable. -/
|
||||
return false
|
||||
else if ctx.majorTypeIndices.size.any fun i => i.any fun j => ctx.majorTypeIndices[i]! == ctx.majorTypeIndices[j]! then
|
||||
else if ctx.majorTypeIndices.size.any fun i _ => i.any fun j _ => ctx.majorTypeIndices[i] == ctx.majorTypeIndices[j] then
|
||||
/- An index occurs more than once -/
|
||||
return false
|
||||
else
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user