mirror of
https://github.com/leanprover/lean4.git
synced 2026-03-24 05:44:15 +00:00
Compare commits
25 Commits
array_boun
...
issue6067
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aa83bcd070 | ||
|
|
1126407d9b | ||
|
|
a19ff61e15 | ||
|
|
6202461a21 | ||
|
|
ea221f3283 | ||
|
|
7c50d597c3 | ||
|
|
99031695bd | ||
|
|
b7248d5295 | ||
|
|
7f2e7e56d2 | ||
|
|
1fe66737ad | ||
|
|
765eb02279 | ||
|
|
a101377054 | ||
|
|
aca9929d84 | ||
|
|
19a701e5c9 | ||
|
|
fc4305ab15 | ||
|
|
9cf83706e7 | ||
|
|
459c6e2a46 | ||
|
|
72e952eadc | ||
|
|
56a80dec1b | ||
|
|
b894464191 | ||
|
|
b30903d1fc | ||
|
|
7fbe8e3b36 | ||
|
|
2fbc46641d | ||
|
|
17419aca7f | ||
|
|
f85c66789d |
24
.github/workflows/labels-from-comments.yml
vendored
24
.github/workflows/labels-from-comments.yml
vendored
@@ -1,7 +1,8 @@
|
||||
# This workflow allows any user to add one of the `awaiting-review`, `awaiting-author`, `WIP`,
|
||||
# or `release-ci` labels by commenting on the PR or issue.
|
||||
# `release-ci`, or a `changelog-XXX` label by commenting on the PR or issue.
|
||||
# If any labels from the set {`awaiting-review`, `awaiting-author`, `WIP`} are added, other labels
|
||||
# from that set are removed automatically at the same time.
|
||||
# Similarly, if any `changelog-XXX` label is added, other `changelog-YYY` labels are removed.
|
||||
|
||||
name: Label PR based on Comment
|
||||
|
||||
@@ -11,7 +12,7 @@ on:
|
||||
|
||||
jobs:
|
||||
update-label:
|
||||
if: github.event.issue.pull_request != null && (contains(github.event.comment.body, 'awaiting-review') || contains(github.event.comment.body, 'awaiting-author') || contains(github.event.comment.body, 'WIP') || contains(github.event.comment.body, 'release-ci'))
|
||||
if: github.event.issue.pull_request != null && (contains(github.event.comment.body, 'awaiting-review') || contains(github.event.comment.body, 'awaiting-author') || contains(github.event.comment.body, 'WIP') || contains(github.event.comment.body, 'release-ci') || contains(github.event.comment.body, 'changelog-'))
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
@@ -20,13 +21,14 @@ jobs:
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
const { owner, repo, number: issue_number } = context.issue;
|
||||
const { owner, repo, number: issue_number } = context.issue;
|
||||
const commentLines = context.payload.comment.body.split('\r\n');
|
||||
|
||||
const awaitingReview = commentLines.includes('awaiting-review');
|
||||
const awaitingAuthor = commentLines.includes('awaiting-author');
|
||||
const wip = commentLines.includes('WIP');
|
||||
const releaseCI = commentLines.includes('release-ci');
|
||||
const changelogMatch = commentLines.find(line => line.startsWith('changelog-'));
|
||||
|
||||
if (awaitingReview || awaitingAuthor || wip) {
|
||||
await github.rest.issues.removeLabel({ owner, repo, issue_number, name: 'awaiting-review' }).catch(() => {});
|
||||
@@ -47,3 +49,19 @@ jobs:
|
||||
if (releaseCI) {
|
||||
await github.rest.issues.addLabels({ owner, repo, issue_number, labels: ['release-ci'] });
|
||||
}
|
||||
|
||||
if (changelogMatch) {
|
||||
const changelogLabel = changelogMatch.trim();
|
||||
const { data: existingLabels } = await github.rest.issues.listLabelsOnIssue({ owner, repo, issue_number });
|
||||
const changelogLabels = existingLabels.filter(label => label.name.startsWith('changelog-'));
|
||||
|
||||
// Remove all other changelog labels
|
||||
for (const label of changelogLabels) {
|
||||
if (label.name !== changelogLabel) {
|
||||
await github.rest.issues.removeLabel({ owner, repo, issue_number, name: label.name }).catch(() => {});
|
||||
}
|
||||
}
|
||||
|
||||
// Add the new changelog label
|
||||
await github.rest.issues.addLabels({ owner, repo, issue_number, labels: [changelogLabel] });
|
||||
}
|
||||
|
||||
@@ -233,7 +233,7 @@ def ofFn {n} (f : Fin n → α) : Array α := go 0 (mkEmpty n) where
|
||||
|
||||
/-- The array `#[0, 1, ..., n - 1]`. -/
|
||||
def range (n : Nat) : Array Nat :=
|
||||
n.fold (flip Array.push) (mkEmpty n)
|
||||
ofFn fun (i : Fin n) => i
|
||||
|
||||
def singleton (v : α) : Array α :=
|
||||
mkArray 1 v
|
||||
@@ -613,8 +613,15 @@ def findIdx? {α : Type u} (p : α → Bool) (as : Array α) : Option Nat :=
|
||||
decreasing_by simp_wf; decreasing_trivial_pre_omega
|
||||
loop 0
|
||||
|
||||
def getIdx? [BEq α] (a : Array α) (v : α) : Option Nat :=
|
||||
a.findIdx? fun a => a == v
|
||||
@[inline]
|
||||
def findFinIdx? {α : Type u} (p : α → Bool) (as : Array α) : Option (Fin as.size) :=
|
||||
let rec @[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
|
||||
loop (j : Nat) :=
|
||||
if h : j < as.size then
|
||||
if p as[j] then some ⟨j, h⟩ else loop (j + 1)
|
||||
else none
|
||||
decreasing_by simp_wf; decreasing_trivial_pre_omega
|
||||
loop 0
|
||||
|
||||
@[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
|
||||
def indexOfAux [BEq α] (a : Array α) (v : α) (i : Nat) : Option (Fin a.size) :=
|
||||
@@ -627,6 +634,10 @@ decreasing_by simp_wf; decreasing_trivial_pre_omega
|
||||
def indexOf? [BEq α] (a : Array α) (v : α) : Option (Fin a.size) :=
|
||||
indexOfAux a v 0
|
||||
|
||||
@[deprecated indexOf? (since := "2024-11-20")]
|
||||
def getIdx? [BEq α] (a : Array α) (v : α) : Option Nat :=
|
||||
a.findIdx? fun a => a == v
|
||||
|
||||
@[inline]
|
||||
def any (as : Array α) (p : α → Bool) (start := 0) (stop := as.size) : Bool :=
|
||||
Id.run <| as.anyM p start stop
|
||||
@@ -766,48 +777,62 @@ def takeWhile (p : α → Bool) (as : Array α) : Array α :=
|
||||
decreasing_by simp_wf; decreasing_trivial_pre_omega
|
||||
go 0 #[]
|
||||
|
||||
/-- Remove the element at a given index from an array without bounds checks, using a `Fin` index.
|
||||
/--
|
||||
Remove the element at a given index from an array without a runtime bounds checks,
|
||||
using a `Nat` index and a tactic-provided bound.
|
||||
|
||||
This function takes worst case O(n) time because
|
||||
it has to backshift all elements at positions greater than `i`.-/
|
||||
This function takes worst case O(n) time because
|
||||
it has to backshift all elements at positions greater than `i`.-/
|
||||
@[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
|
||||
def feraseIdx (a : Array α) (i : Fin a.size) : Array α :=
|
||||
if h : i.val + 1 < a.size then
|
||||
let a' := a.swap ⟨i.val + 1, h⟩ i
|
||||
let i' : Fin a'.size := ⟨i.val + 1, by simp [a', h]⟩
|
||||
a'.feraseIdx i'
|
||||
def eraseIdx (a : Array α) (i : Nat) (h : i < a.size := by get_elem_tactic) : Array α :=
|
||||
if h' : i + 1 < a.size then
|
||||
let a' := a.swap ⟨i + 1, h'⟩ ⟨i, h⟩
|
||||
a'.eraseIdx (i + 1) (by simp [a', h'])
|
||||
else
|
||||
a.pop
|
||||
termination_by a.size - i.val
|
||||
decreasing_by simp_wf; exact Nat.sub_succ_lt_self _ _ i.isLt
|
||||
termination_by a.size - i
|
||||
decreasing_by simp_wf; exact Nat.sub_succ_lt_self _ _ h
|
||||
|
||||
-- This is required in `Lean.Data.PersistentHashMap`.
|
||||
@[simp] theorem size_feraseIdx (a : Array α) (i : Fin a.size) : (a.feraseIdx i).size = a.size - 1 := by
|
||||
induction a, i using Array.feraseIdx.induct with
|
||||
| @case1 a i h a' _ ih =>
|
||||
unfold feraseIdx
|
||||
simp [h, a', ih]
|
||||
| case2 a i h =>
|
||||
unfold feraseIdx
|
||||
simp [h]
|
||||
@[simp] theorem size_eraseIdx (a : Array α) (i : Nat) (h) : (a.eraseIdx i h).size = a.size - 1 := by
|
||||
induction a, i, h using Array.eraseIdx.induct with
|
||||
| @case1 a i h h' a' ih =>
|
||||
unfold eraseIdx
|
||||
simp [h', a', ih]
|
||||
| case2 a i h h' =>
|
||||
unfold eraseIdx
|
||||
simp [h']
|
||||
|
||||
/-- Remove the element at a given index from an array, or do nothing if the index is out of bounds.
|
||||
|
||||
This function takes worst case O(n) time because
|
||||
it has to backshift all elements at positions greater than `i`.-/
|
||||
def eraseIdx (a : Array α) (i : Nat) : Array α :=
|
||||
if h : i < a.size then a.feraseIdx ⟨i, h⟩ else a
|
||||
def eraseIdxIfInBounds (a : Array α) (i : Nat) : Array α :=
|
||||
if h : i < a.size then a.eraseIdx i h else a
|
||||
|
||||
/-- Remove the element at a given index from an array, or panic if the index is out of bounds.
|
||||
|
||||
This function takes worst case O(n) time because
|
||||
it has to backshift all elements at positions greater than `i`. -/
|
||||
def eraseIdx! (a : Array α) (i : Nat) : Array α :=
|
||||
if h : i < a.size then a.eraseIdx i h else panic! "invalid index"
|
||||
|
||||
def erase [BEq α] (as : Array α) (a : α) : Array α :=
|
||||
match as.indexOf? a with
|
||||
| none => as
|
||||
| some i => as.feraseIdx i
|
||||
| some i => as.eraseIdx i
|
||||
|
||||
/-- Erase the first element that satisfies the predicate `p`. -/
|
||||
def eraseP (as : Array α) (p : α → Bool) : Array α :=
|
||||
match as.findIdx? p with
|
||||
| none => as
|
||||
| some i => as.eraseIdxIfInBounds i
|
||||
|
||||
/-- Insert element `a` at position `i`. -/
|
||||
@[inline] def insertAt (as : Array α) (i : Fin (as.size + 1)) (a : α) : Array α :=
|
||||
@[inline] def insertIdx (as : Array α) (i : Nat) (a : α) (_ : i ≤ as.size := by get_elem_tactic) : Array α :=
|
||||
let rec @[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
|
||||
loop (as : Array α) (j : Fin as.size) :=
|
||||
if i.1 < j then
|
||||
if i < j then
|
||||
let j' := ⟨j-1, Nat.lt_of_le_of_lt (Nat.pred_le _) j.2⟩
|
||||
let as := as.swap j' j
|
||||
loop as ⟨j', by rw [size_swap]; exact j'.2⟩
|
||||
@@ -818,12 +843,23 @@ def erase [BEq α] (as : Array α) (a : α) : Array α :=
|
||||
let as := as.push a
|
||||
loop as ⟨j, size_push .. ▸ j.lt_succ_self⟩
|
||||
|
||||
@[deprecated insertIdx (since := "2024-11-20")] abbrev insertAt := @insertIdx
|
||||
|
||||
/-- Insert element `a` at position `i`. Panics if `i` is not `i ≤ as.size`. -/
|
||||
def insertAt! (as : Array α) (i : Nat) (a : α) : Array α :=
|
||||
def insertIdx! (as : Array α) (i : Nat) (a : α) : Array α :=
|
||||
if h : i ≤ as.size then
|
||||
insertAt as ⟨i, Nat.lt_succ_of_le h⟩ a
|
||||
insertIdx as i a
|
||||
else panic! "invalid index"
|
||||
|
||||
@[deprecated insertIdx! (since := "2024-11-20")] abbrev insertAt! := @insertIdx!
|
||||
|
||||
/-- Insert element `a` at position `i`, or do nothing if `as.size < i`. -/
|
||||
def insertIdxIfInBounds (as : Array α) (i : Nat) (a : α) : Array α :=
|
||||
if h : i ≤ as.size then
|
||||
insertIdx as i a
|
||||
else
|
||||
as
|
||||
|
||||
@[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
|
||||
def isPrefixOfAux [BEq α] (as bs : Array α) (hle : as.size ≤ bs.size) (i : Nat) : Bool :=
|
||||
if h : i < as.size then
|
||||
|
||||
@@ -52,7 +52,7 @@ namespace Array
|
||||
let mid := (lo + hi)/2
|
||||
let midVal := as.get! mid
|
||||
if lt midVal k then
|
||||
if mid == lo then do let v ← add (); pure <| as.insertAt! (lo+1) v
|
||||
if mid == lo then do let v ← add (); pure <| as.insertIdx! (lo+1) v
|
||||
else binInsertAux lt merge add as k mid hi
|
||||
else if lt k midVal then
|
||||
binInsertAux lt merge add as k lo mid
|
||||
@@ -67,7 +67,7 @@ namespace Array
|
||||
(k : α) : m (Array α) :=
|
||||
let _ := Inhabited.mk k
|
||||
if as.isEmpty then do let v ← add (); pure <| as.push v
|
||||
else if lt k (as.get! 0) then do let v ← add (); pure <| as.insertAt! 0 v
|
||||
else if lt k (as.get! 0) then do let v ← add (); pure <| as.insertIdx! 0 v
|
||||
else if !lt (as.get! 0) k then as.modifyM 0 <| merge
|
||||
else if lt as.back! k then do let v ← add (); pure <| as.push v
|
||||
else if !lt k as.back! then as.modifyM (as.size - 1) <| merge
|
||||
|
||||
@@ -621,7 +621,20 @@ theorem getElem?_ofFn (f : Fin n → α) (i : Nat) :
|
||||
(ofFn f)[i]? = if h : i < n then some (f ⟨i, h⟩) else none := by
|
||||
simp [getElem?_def]
|
||||
|
||||
/-- # mkArray -/
|
||||
@[simp] theorem ofFn_zero (f : Fin 0 → α) : ofFn f = #[] := rfl
|
||||
|
||||
theorem ofFn_succ (f : Fin (n+1) → α) :
|
||||
ofFn f = (ofFn (fun (i : Fin n) => f i.castSucc)).push (f ⟨n, by omega⟩) := by
|
||||
ext i h₁ h₂
|
||||
· simp
|
||||
· simp [getElem_push]
|
||||
split <;> rename_i h₃
|
||||
· rfl
|
||||
· congr
|
||||
simp at h₁ h₂
|
||||
omega
|
||||
|
||||
/-! # mkArray -/
|
||||
|
||||
@[simp] theorem size_mkArray (n : Nat) (v : α) : (mkArray n v).size = n :=
|
||||
List.length_replicate ..
|
||||
@@ -637,7 +650,7 @@ theorem getElem?_mkArray (n : Nat) (v : α) (i : Nat) :
|
||||
(mkArray n v)[i]? = if i < n then some v else none := by
|
||||
simp [getElem?_def]
|
||||
|
||||
/-- # mem -/
|
||||
/-! # mem -/
|
||||
|
||||
@[simp] theorem mem_toList {a : α} {l : Array α} : a ∈ l.toList ↔ a ∈ l := mem_def.symm
|
||||
|
||||
@@ -659,7 +672,7 @@ theorem not_mem_nil (a : α) : ¬ a ∈ #[] := nofun
|
||||
(x ∈ if p then l else #[]) ↔ p ∧ x ∈ l := by
|
||||
split <;> simp_all
|
||||
|
||||
/-- # get lemmas -/
|
||||
/-! # get lemmas -/
|
||||
|
||||
theorem lt_of_getElem {x : α} {a : Array α} {idx : Nat} {hidx : idx < a.size} (_ : a[idx] = x) :
|
||||
idx < a.size :=
|
||||
@@ -840,16 +853,10 @@ theorem size_eq_length_toList (as : Array α) : as.size = as.toList.length := rf
|
||||
simp only [reverse]; split <;> simp [go]
|
||||
|
||||
@[simp] theorem size_range {n : Nat} : (range n).size = n := by
|
||||
unfold range
|
||||
induction n with
|
||||
| zero => simp [Nat.fold]
|
||||
| succ k ih =>
|
||||
rw [Nat.fold, flip]
|
||||
simp only [mkEmpty_eq, size_push] at *
|
||||
omega
|
||||
induction n <;> simp [range]
|
||||
|
||||
@[simp] theorem toList_range (n : Nat) : (range n).toList = List.range n := by
|
||||
induction n <;> simp_all [range, Nat.fold, flip, List.range_succ]
|
||||
apply List.ext_getElem <;> simp [range]
|
||||
|
||||
@[simp]
|
||||
theorem getElem_range {n : Nat} {x : Nat} (h : x < (Array.range n).size) : (Array.range n)[x] = x := by
|
||||
@@ -1637,9 +1644,9 @@ theorem swap_comm (a : Array α) {i j : Fin a.size} : a.swap i j = a.swap j i :=
|
||||
|
||||
/-! ### eraseIdx -/
|
||||
|
||||
theorem feraseIdx_eq_eraseIdx {a : Array α} {i : Fin a.size} :
|
||||
a.feraseIdx i = a.eraseIdx i.1 := by
|
||||
simp [eraseIdx]
|
||||
theorem eraseIdx_eq_eraseIdxIfInBounds {a : Array α} {i : Nat} (h : i < a.size) :
|
||||
a.eraseIdx i h = a.eraseIdxIfInBounds i := by
|
||||
simp [eraseIdxIfInBounds, h]
|
||||
|
||||
/-! ### isPrefixOf -/
|
||||
|
||||
@@ -1869,16 +1876,15 @@ theorem takeWhile_go_toArray (p : α → Bool) (l : List α) (i : Nat) :
|
||||
l.toArray.takeWhile p = (l.takeWhile p).toArray := by
|
||||
simp [Array.takeWhile, takeWhile_go_toArray]
|
||||
|
||||
@[simp] theorem feraseIdx_toArray (l : List α) (i : Fin l.toArray.size) :
|
||||
l.toArray.feraseIdx i = (l.eraseIdx i).toArray := by
|
||||
rw [feraseIdx]
|
||||
split <;> rename_i h
|
||||
· rw [feraseIdx_toArray]
|
||||
@[simp] theorem eraseIdx_toArray (l : List α) (i : Nat) (h : i < l.toArray.size) :
|
||||
l.toArray.eraseIdx i h = (l.eraseIdx i).toArray := by
|
||||
rw [Array.eraseIdx]
|
||||
split <;> rename_i h'
|
||||
· rw [eraseIdx_toArray]
|
||||
simp only [swap_toArray, Fin.getElem_fin, toList_toArray, mk.injEq]
|
||||
rw [eraseIdx_set_gt (by simp), eraseIdx_set_eq]
|
||||
simp
|
||||
· rcases i with ⟨i, w⟩
|
||||
simp at h w
|
||||
· simp at h h'
|
||||
have t : i = l.length - 1 := by omega
|
||||
simp [t]
|
||||
termination_by l.length - i
|
||||
@@ -1888,9 +1894,9 @@ decreasing_by
|
||||
simp
|
||||
omega
|
||||
|
||||
@[simp] theorem eraseIdx_toArray (l : List α) (i : Nat) :
|
||||
l.toArray.eraseIdx i = (l.eraseIdx i).toArray := by
|
||||
rw [Array.eraseIdx]
|
||||
@[simp] theorem eraseIdxIfInBounds_toArray (l : List α) (i : Nat) :
|
||||
l.toArray.eraseIdxIfInBounds i = (l.eraseIdx i).toArray := by
|
||||
rw [Array.eraseIdxIfInBounds]
|
||||
split
|
||||
· simp
|
||||
· simp_all [eraseIdx_eq_self.2]
|
||||
@@ -1909,13 +1915,13 @@ namespace Array
|
||||
(as.takeWhile p).toList = as.toList.takeWhile p := by
|
||||
induction as; simp
|
||||
|
||||
@[simp] theorem toList_feraseIdx (as : Array α) (i : Fin as.size) :
|
||||
(as.feraseIdx i).toList = as.toList.eraseIdx i.1 := by
|
||||
@[simp] theorem toList_eraseIdx (as : Array α) (i : Nat) (h : i < as.size) :
|
||||
(as.eraseIdx i h).toList = as.toList.eraseIdx i := by
|
||||
induction as
|
||||
simp
|
||||
|
||||
@[simp] theorem toList_eraseIdx (as : Array α) (i : Nat) :
|
||||
(as.eraseIdx i).toList = as.toList.eraseIdx i := by
|
||||
@[simp] theorem toList_eraseIdxIfInBounds (as : Array α) (i : Nat) :
|
||||
(as.eraseIdxIfInBounds i).toList = as.toList.eraseIdx i := by
|
||||
induction as
|
||||
simp
|
||||
|
||||
|
||||
@@ -346,6 +346,10 @@ theorem getMsbD_sub {i : Nat} {i_lt : i < w} {x y : BitVec w} :
|
||||
· rfl
|
||||
· omega
|
||||
|
||||
theorem getElem_sub {i : Nat} {x y : BitVec w} (h : i < w) :
|
||||
(x - y)[i] = (x[i] ^^ ((~~~y + 1#w)[i] ^^ carry i x (~~~y + 1#w) false)) := by
|
||||
simp [← getLsbD_eq_getElem, getLsbD_sub, h]
|
||||
|
||||
theorem msb_sub {x y: BitVec w} :
|
||||
(x - y).msb
|
||||
= (x.msb ^^ ((~~~y + 1#w).msb ^^ carry (w - 1 - 0) x (~~~y + 1#w) false)) := by
|
||||
@@ -410,6 +414,10 @@ theorem getLsbD_neg {i : Nat} {x : BitVec w} :
|
||||
· have h_ge : w ≤ i := by omega
|
||||
simp [getLsbD_ge _ _ h_ge, h_ge, hi]
|
||||
|
||||
theorem getElem_neg {i : Nat} {x : BitVec w} (h : i < w) :
|
||||
(-x)[i] = (x[i] ^^ decide (∃ j < i, x.getLsbD j = true)) := by
|
||||
simp [← getLsbD_eq_getElem, getLsbD_neg, h]
|
||||
|
||||
theorem getMsbD_neg {i : Nat} {x : BitVec w} :
|
||||
getMsbD (-x) i =
|
||||
(getMsbD x i ^^ decide (∃ j < w, i < j ∧ getMsbD x j = true)) := by
|
||||
|
||||
@@ -269,6 +269,10 @@ theorem ofBool_eq_iff_eq : ∀ {b b' : Bool}, BitVec.ofBool b = BitVec.ofBool b'
|
||||
getLsbD (x#'lt) i = x.testBit i := by
|
||||
simp [getLsbD, BitVec.ofNatLt]
|
||||
|
||||
@[simp] theorem getMsbD_ofNatLt {n x i : Nat} (h : x < 2^n) :
|
||||
getMsbD (x#'h) i = (decide (i < n) && x.testBit (n - 1 - i)) := by
|
||||
simp [getMsbD, getLsbD]
|
||||
|
||||
@[simp, bv_toNat] theorem toNat_ofNat (x w : Nat) : (BitVec.ofNat w x).toNat = x % 2^w := by
|
||||
simp [BitVec.toNat, BitVec.ofNat, Fin.ofNat']
|
||||
|
||||
@@ -755,6 +759,10 @@ theorem extractLsb'_eq_extractLsb {w : Nat} (x : BitVec w) (start len : Nat) (h
|
||||
@[simp] theorem getLsbD_allOnes : (allOnes v).getLsbD i = decide (i < v) := by
|
||||
simp [allOnes]
|
||||
|
||||
@[simp] theorem getMsbD_allOnes : (allOnes v).getMsbD i = decide (i < v) := by
|
||||
simp [allOnes]
|
||||
omega
|
||||
|
||||
@[simp] theorem getElem_allOnes (i : Nat) (h : i < v) : (allOnes v)[i] = true := by
|
||||
simp [getElem_eq_testBit_toNat, h]
|
||||
|
||||
@@ -772,6 +780,12 @@ theorem extractLsb'_eq_extractLsb {w : Nat} (x : BitVec w) (start len : Nat) (h
|
||||
@[simp] theorem toNat_or (x y : BitVec v) :
|
||||
BitVec.toNat (x ||| y) = BitVec.toNat x ||| BitVec.toNat y := rfl
|
||||
|
||||
@[simp] theorem toInt_or (x y : BitVec w) :
|
||||
BitVec.toInt (x ||| y) = Int.bmod (BitVec.toNat x ||| BitVec.toNat y) (2^w) := by
|
||||
rw_mod_cast [Int.bmod_def, BitVec.toInt, toNat_or, Nat.mod_eq_of_lt
|
||||
(Nat.or_lt_two_pow (BitVec.isLt x) (BitVec.isLt y))]
|
||||
omega
|
||||
|
||||
@[simp] theorem toFin_or (x y : BitVec v) :
|
||||
BitVec.toFin (x ||| y) = BitVec.toFin x ||| BitVec.toFin y := by
|
||||
apply Fin.eq_of_val_eq
|
||||
@@ -839,6 +853,12 @@ instance : Std.LawfulCommIdentity (α := BitVec n) (· ||| · ) (0#n) where
|
||||
@[simp] theorem toNat_and (x y : BitVec v) :
|
||||
BitVec.toNat (x &&& y) = BitVec.toNat x &&& BitVec.toNat y := rfl
|
||||
|
||||
@[simp] theorem toInt_and (x y : BitVec w) :
|
||||
BitVec.toInt (x &&& y) = Int.bmod (BitVec.toNat x &&& BitVec.toNat y) (2^w) := by
|
||||
rw_mod_cast [Int.bmod_def, BitVec.toInt, toNat_and, Nat.mod_eq_of_lt
|
||||
(Nat.and_lt_two_pow x.toNat (BitVec.isLt y))]
|
||||
omega
|
||||
|
||||
@[simp] theorem toFin_and (x y : BitVec v) :
|
||||
BitVec.toFin (x &&& y) = BitVec.toFin x &&& BitVec.toFin y := by
|
||||
apply Fin.eq_of_val_eq
|
||||
@@ -906,6 +926,12 @@ instance : Std.LawfulCommIdentity (α := BitVec n) (· &&& · ) (allOnes n) wher
|
||||
@[simp] theorem toNat_xor (x y : BitVec v) :
|
||||
BitVec.toNat (x ^^^ y) = BitVec.toNat x ^^^ BitVec.toNat y := rfl
|
||||
|
||||
@[simp] theorem toInt_xor (x y : BitVec w) :
|
||||
BitVec.toInt (x ^^^ y) = Int.bmod (BitVec.toNat x ^^^ BitVec.toNat y) (2^w) := by
|
||||
rw_mod_cast [Int.bmod_def, BitVec.toInt, toNat_xor, Nat.mod_eq_of_lt
|
||||
(Nat.xor_lt_two_pow (BitVec.isLt x) (BitVec.isLt y))]
|
||||
omega
|
||||
|
||||
@[simp] theorem toFin_xor (x y : BitVec v) :
|
||||
BitVec.toFin (x ^^^ y) = BitVec.toFin x ^^^ BitVec.toFin y := by
|
||||
apply Fin.eq_of_val_eq
|
||||
@@ -983,6 +1009,13 @@ theorem not_def {x : BitVec v} : ~~~x = allOnes v ^^^ x := rfl
|
||||
_ ≤ 2 ^ i := Nat.pow_le_pow_of_le_right Nat.zero_lt_two w
|
||||
· simp
|
||||
|
||||
@[simp] theorem toInt_not {x : BitVec w} :
|
||||
(~~~x).toInt = Int.bmod (2^w - 1 - x.toNat) (2^w) := by
|
||||
rw_mod_cast [BitVec.toInt, BitVec.toNat_not, Int.bmod_def]
|
||||
simp [show ((2^w : Nat) : Int) - 1 - x.toNat = ((2^w - 1 - x.toNat) : Nat) by omega]
|
||||
rw_mod_cast [Nat.mod_eq_of_lt (by omega)]
|
||||
omega
|
||||
|
||||
@[simp] theorem ofInt_negSucc_eq_not_ofNat {w n : Nat} :
|
||||
BitVec.ofInt w (Int.negSucc n) = ~~~.ofNat w n := by
|
||||
simp only [BitVec.ofInt, Int.toNat, Int.ofNat_eq_coe, toNat_eq, toNat_ofNatLt, toNat_not,
|
||||
@@ -1007,6 +1040,10 @@ theorem not_def {x : BitVec v} : ~~~x = allOnes v ^^^ x := rfl
|
||||
@[simp] theorem getLsbD_not {x : BitVec v} : (~~~x).getLsbD i = (decide (i < v) && ! x.getLsbD i) := by
|
||||
by_cases h' : i < v <;> simp_all [not_def]
|
||||
|
||||
@[simp] theorem getMsbD_not {x : BitVec v} :
|
||||
(~~~x).getMsbD i = (decide (i < v) && ! x.getMsbD i) := by
|
||||
by_cases h' : i < v <;> simp_all [not_def]
|
||||
|
||||
@[simp] theorem getElem_not {x : BitVec w} {i : Nat} (h : i < w) : (~~~x)[i] = !x[i] := by
|
||||
simp only [getElem_eq_testBit_toNat, toNat_not]
|
||||
rw [← Nat.sub_add_eq, Nat.add_comm 1]
|
||||
@@ -1480,6 +1517,12 @@ theorem getLsbD_sshiftRight' {x y: BitVec w} {i : Nat} :
|
||||
(!decide (w ≤ i) && if y.toNat + i < w then x.getLsbD (y.toNat + i) else x.msb) := by
|
||||
simp only [BitVec.sshiftRight', BitVec.getLsbD_sshiftRight]
|
||||
|
||||
@[simp]
|
||||
theorem getElem_sshiftRight' {x y : BitVec w} {i : Nat} (h : i < w) :
|
||||
(x.sshiftRight' y)[i] =
|
||||
(!decide (w ≤ i) && if y.toNat + i < w then x.getLsbD (y.toNat + i) else x.msb) := by
|
||||
simp only [← getLsbD_eq_getElem, BitVec.sshiftRight', BitVec.getLsbD_sshiftRight]
|
||||
|
||||
@[simp]
|
||||
theorem getMsbD_sshiftRight' {x y: BitVec w} {i : Nat} :
|
||||
(x.sshiftRight y.toNat).getMsbD i = (decide (i < w) && if i < y.toNat then x.msb else x.getMsbD (i - y.toNat)) := by
|
||||
@@ -3224,7 +3267,11 @@ theorem toNat_abs {x : BitVec w} : x.abs.toNat = if x.msb then 2^w - x.toNat els
|
||||
· simp [h]
|
||||
|
||||
theorem getLsbD_abs {i : Nat} {x : BitVec w} :
|
||||
getLsbD x.abs i = if x.msb then getLsbD (-x) i else getLsbD x i := by
|
||||
getLsbD x.abs i = if x.msb then getLsbD (-x) i else getLsbD x i := by
|
||||
by_cases h : x.msb <;> simp [BitVec.abs, h]
|
||||
|
||||
theorem getElem_abs {i : Nat} {x : BitVec w} (h : i < w) :
|
||||
x.abs[i] = if x.msb then (-x)[i] else x[i] := by
|
||||
by_cases h : x.msb <;> simp [BitVec.abs, h]
|
||||
|
||||
theorem getMsbD_abs {i : Nat} {x : BitVec w} :
|
||||
|
||||
@@ -31,7 +31,7 @@ opaque floatSpec : FloatSpec := {
|
||||
structure Float where
|
||||
val : floatSpec.float
|
||||
|
||||
instance : Inhabited Float := ⟨{ val := floatSpec.val }⟩
|
||||
instance : Nonempty Float := ⟨{ val := floatSpec.val }⟩
|
||||
|
||||
@[extern "lean_float_add"] opaque Float.add : Float → Float → Float
|
||||
@[extern "lean_float_sub"] opaque Float.sub : Float → Float → Float
|
||||
@@ -136,6 +136,9 @@ instance : ToString Float where
|
||||
|
||||
@[extern "lean_uint64_to_float"] opaque UInt64.toFloat (n : UInt64) : Float
|
||||
|
||||
instance : Inhabited Float where
|
||||
default := UInt64.toFloat 0
|
||||
|
||||
instance : Repr Float where
|
||||
reprPrec n prec := if n < UInt64.toFloat 0 then Repr.addAppParen (toString n) prec else toString n
|
||||
|
||||
|
||||
@@ -20,3 +20,4 @@ import Init.Data.Nat.Mod
|
||||
import Init.Data.Nat.Lcm
|
||||
import Init.Data.Nat.Compare
|
||||
import Init.Data.Nat.Simproc
|
||||
import Init.Data.Nat.Fold
|
||||
|
||||
@@ -35,52 +35,6 @@ Used as the default `Nat` eliminator by the `cases` tactic. -/
|
||||
protected abbrev casesAuxOn {motive : Nat → Sort u} (t : Nat) (zero : motive 0) (succ : (n : Nat) → motive (n + 1)) : motive t :=
|
||||
Nat.casesOn t zero succ
|
||||
|
||||
/--
|
||||
`Nat.fold` evaluates `f` on the numbers up to `n` exclusive, in increasing order:
|
||||
* `Nat.fold f 3 init = init |> f 0 |> f 1 |> f 2`
|
||||
-/
|
||||
@[specialize] def fold {α : Type u} (f : Nat → α → α) : (n : Nat) → (init : α) → α
|
||||
| 0, a => a
|
||||
| succ n, a => f n (fold f n a)
|
||||
|
||||
/-- Tail-recursive version of `Nat.fold`. -/
|
||||
@[inline] def foldTR {α : Type u} (f : Nat → α → α) (n : Nat) (init : α) : α :=
|
||||
let rec @[specialize] loop
|
||||
| 0, a => a
|
||||
| succ m, a => loop m (f (n - succ m) a)
|
||||
loop n init
|
||||
|
||||
/--
|
||||
`Nat.foldRev` evaluates `f` on the numbers up to `n` exclusive, in decreasing order:
|
||||
* `Nat.foldRev f 3 init = f 0 <| f 1 <| f 2 <| init`
|
||||
-/
|
||||
@[specialize] def foldRev {α : Type u} (f : Nat → α → α) : (n : Nat) → (init : α) → α
|
||||
| 0, a => a
|
||||
| succ n, a => foldRev f n (f n a)
|
||||
|
||||
/-- `any f n = true` iff there is `i in [0, n-1]` s.t. `f i = true` -/
|
||||
@[specialize] def any (f : Nat → Bool) : Nat → Bool
|
||||
| 0 => false
|
||||
| succ n => any f n || f n
|
||||
|
||||
/-- Tail-recursive version of `Nat.any`. -/
|
||||
@[inline] def anyTR (f : Nat → Bool) (n : Nat) : Bool :=
|
||||
let rec @[specialize] loop : Nat → Bool
|
||||
| 0 => false
|
||||
| succ m => f (n - succ m) || loop m
|
||||
loop n
|
||||
|
||||
/-- `all f n = true` iff every `i in [0, n-1]` satisfies `f i = true` -/
|
||||
@[specialize] def all (f : Nat → Bool) : Nat → Bool
|
||||
| 0 => true
|
||||
| succ n => all f n && f n
|
||||
|
||||
/-- Tail-recursive version of `Nat.all`. -/
|
||||
@[inline] def allTR (f : Nat → Bool) (n : Nat) : Bool :=
|
||||
let rec @[specialize] loop : Nat → Bool
|
||||
| 0 => true
|
||||
| succ m => f (n - succ m) && loop m
|
||||
loop n
|
||||
|
||||
/--
|
||||
`Nat.repeat f n a` is `f^(n) a`; that is, it iterates `f` `n` times on `a`.
|
||||
@@ -1158,33 +1112,6 @@ theorem not_lt_eq (a b : Nat) : (¬ (a < b)) = (b ≤ a) :=
|
||||
theorem not_gt_eq (a b : Nat) : (¬ (a > b)) = (a ≤ b) :=
|
||||
not_lt_eq b a
|
||||
|
||||
/-! # csimp theorems -/
|
||||
|
||||
@[csimp] theorem fold_eq_foldTR : @fold = @foldTR :=
|
||||
funext fun α => funext fun f => funext fun n => funext fun init =>
|
||||
let rec go : ∀ m n, foldTR.loop f (m + n) m (fold f n init) = fold f (m + n) init
|
||||
| 0, n => by simp [foldTR.loop]
|
||||
| succ m, n => by rw [foldTR.loop, add_sub_self_left, succ_add]; exact go m (succ n)
|
||||
(go n 0).symm
|
||||
|
||||
@[csimp] theorem any_eq_anyTR : @any = @anyTR :=
|
||||
funext fun f => funext fun n =>
|
||||
let rec go : ∀ m n, (any f n || anyTR.loop f (m + n) m) = any f (m + n)
|
||||
| 0, n => by simp [anyTR.loop]
|
||||
| succ m, n => by
|
||||
rw [anyTR.loop, add_sub_self_left, ← Bool.or_assoc, succ_add]
|
||||
exact go m (succ n)
|
||||
(go n 0).symm
|
||||
|
||||
@[csimp] theorem all_eq_allTR : @all = @allTR :=
|
||||
funext fun f => funext fun n =>
|
||||
let rec go : ∀ m n, (all f n && allTR.loop f (m + n) m) = all f (m + n)
|
||||
| 0, n => by simp [allTR.loop]
|
||||
| succ m, n => by
|
||||
rw [allTR.loop, add_sub_self_left, ← Bool.and_assoc, succ_add]
|
||||
exact go m (succ n)
|
||||
(go n 0).symm
|
||||
|
||||
@[csimp] theorem repeat_eq_repeatTR : @repeat = @repeatTR :=
|
||||
funext fun α => funext fun f => funext fun n => funext fun init =>
|
||||
let rec go : ∀ m n, repeatTR.loop f m (repeat f n init) = repeat f (m + n) init
|
||||
@@ -1193,31 +1120,3 @@ theorem not_gt_eq (a b : Nat) : (¬ (a > b)) = (a ≤ b) :=
|
||||
(go n 0).symm
|
||||
|
||||
end Nat
|
||||
|
||||
namespace Prod
|
||||
|
||||
/--
|
||||
`(start, stop).foldI f a` evaluates `f` on all the numbers
|
||||
from `start` (inclusive) to `stop` (exclusive) in increasing order:
|
||||
* `(5, 8).foldI f init = init |> f 5 |> f 6 |> f 7`
|
||||
-/
|
||||
@[inline] def foldI {α : Type u} (f : Nat → α → α) (i : Nat × Nat) (a : α) : α :=
|
||||
Nat.foldTR.loop f i.2 (i.2 - i.1) a
|
||||
|
||||
/--
|
||||
`(start, stop).anyI f a` returns true if `f` is true for some natural number
|
||||
from `start` (inclusive) to `stop` (exclusive):
|
||||
* `(5, 8).anyI f = f 5 || f 6 || f 7`
|
||||
-/
|
||||
@[inline] def anyI (f : Nat → Bool) (i : Nat × Nat) : Bool :=
|
||||
Nat.anyTR.loop f i.2 (i.2 - i.1)
|
||||
|
||||
/--
|
||||
`(start, stop).allI f a` returns true if `f` is true for all natural numbers
|
||||
from `start` (inclusive) to `stop` (exclusive):
|
||||
* `(5, 8).anyI f = f 5 && f 6 && f 7`
|
||||
-/
|
||||
@[inline] def allI (f : Nat → Bool) (i : Nat × Nat) : Bool :=
|
||||
Nat.allTR.loop f i.2 (i.2 - i.1)
|
||||
|
||||
end Prod
|
||||
|
||||
@@ -6,50 +6,51 @@ Author: Leonardo de Moura
|
||||
prelude
|
||||
import Init.Control.Basic
|
||||
import Init.Data.Nat.Basic
|
||||
import Init.Omega
|
||||
|
||||
namespace Nat
|
||||
universe u v
|
||||
|
||||
@[inline] def forM {m} [Monad m] (n : Nat) (f : Nat → m Unit) : m Unit :=
|
||||
let rec @[specialize] loop
|
||||
| 0 => pure ()
|
||||
| i+1 => do f (n-i-1); loop i
|
||||
loop n
|
||||
@[inline] def forM {m} [Monad m] (n : Nat) (f : (i : Nat) → i < n → m Unit) : m Unit :=
|
||||
let rec @[specialize] loop : ∀ i, i ≤ n → m Unit
|
||||
| 0, _ => pure ()
|
||||
| i+1, h => do f (n-i-1) (by omega); loop i (Nat.le_of_succ_le h)
|
||||
loop n (by simp)
|
||||
|
||||
@[inline] def forRevM {m} [Monad m] (n : Nat) (f : Nat → m Unit) : m Unit :=
|
||||
let rec @[specialize] loop
|
||||
| 0 => pure ()
|
||||
| i+1 => do f i; loop i
|
||||
loop n
|
||||
@[inline] def forRevM {m} [Monad m] (n : Nat) (f : (i : Nat) → i < n → m Unit) : m Unit :=
|
||||
let rec @[specialize] loop : ∀ i, i ≤ n → m Unit
|
||||
| 0, _ => pure ()
|
||||
| i+1, h => do f i (by omega); loop i (Nat.le_of_succ_le h)
|
||||
loop n (by simp)
|
||||
|
||||
@[inline] def foldM {α : Type u} {m : Type u → Type v} [Monad m] (f : Nat → α → m α) (init : α) (n : Nat) : m α :=
|
||||
let rec @[specialize] loop
|
||||
| 0, a => pure a
|
||||
| i+1, a => f (n-i-1) a >>= loop i
|
||||
loop n init
|
||||
@[inline] def foldM {α : Type u} {m : Type u → Type v} [Monad m] (n : Nat) (f : (i : Nat) → i < n → α → m α) (init : α) : m α :=
|
||||
let rec @[specialize] loop : ∀ i, i ≤ n → α → m α
|
||||
| 0, h, a => pure a
|
||||
| i+1, h, a => f (n-i-1) (by omega) a >>= loop i (Nat.le_of_succ_le h)
|
||||
loop n (by omega) init
|
||||
|
||||
@[inline] def foldRevM {α : Type u} {m : Type u → Type v} [Monad m] (f : Nat → α → m α) (init : α) (n : Nat) : m α :=
|
||||
let rec @[specialize] loop
|
||||
| 0, a => pure a
|
||||
| i+1, a => f i a >>= loop i
|
||||
loop n init
|
||||
@[inline] def foldRevM {α : Type u} {m : Type u → Type v} [Monad m] (n : Nat) (f : (i : Nat) → i < n → α → m α) (init : α) : m α :=
|
||||
let rec @[specialize] loop : ∀ i, i ≤ n → α → m α
|
||||
| 0, h, a => pure a
|
||||
| i+1, h, a => f i (by omega) a >>= loop i (Nat.le_of_succ_le h)
|
||||
loop n (by omega) init
|
||||
|
||||
@[inline] def allM {m} [Monad m] (n : Nat) (p : Nat → m Bool) : m Bool :=
|
||||
let rec @[specialize] loop
|
||||
| 0 => pure true
|
||||
| i+1 => do
|
||||
match (← p (n-i-1)) with
|
||||
| true => loop i
|
||||
@[inline] def allM {m} [Monad m] (n : Nat) (p : (i : Nat) → i < n → m Bool) : m Bool :=
|
||||
let rec @[specialize] loop : ∀ i, i ≤ n → m Bool
|
||||
| 0, _ => pure true
|
||||
| i+1 , h => do
|
||||
match (← p (n-i-1) (by omega)) with
|
||||
| true => loop i (by omega)
|
||||
| false => pure false
|
||||
loop n
|
||||
loop n (by simp)
|
||||
|
||||
@[inline] def anyM {m} [Monad m] (n : Nat) (p : Nat → m Bool) : m Bool :=
|
||||
let rec @[specialize] loop
|
||||
| 0 => pure false
|
||||
| i+1 => do
|
||||
match (← p (n-i-1)) with
|
||||
@[inline] def anyM {m} [Monad m] (n : Nat) (p : (i : Nat) → i < n → m Bool) : m Bool :=
|
||||
let rec @[specialize] loop : ∀ i, i ≤ n → m Bool
|
||||
| 0, _ => pure false
|
||||
| i+1, h => do
|
||||
match (← p (n-i-1) (by omega)) with
|
||||
| true => pure true
|
||||
| false => loop i
|
||||
loop n
|
||||
| false => loop i (Nat.le_of_succ_le h)
|
||||
loop n (by simp)
|
||||
|
||||
end Nat
|
||||
|
||||
168
src/Init/Data/Nat/Fold.lean
Normal file
168
src/Init/Data/Nat/Fold.lean
Normal file
@@ -0,0 +1,168 @@
|
||||
/-
|
||||
Copyright (c) 2014 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Floris van Doorn, Leonardo de Moura, Kim Morrison
|
||||
-/
|
||||
prelude
|
||||
import Init.Omega
|
||||
|
||||
set_option linter.missingDocs true -- keep it documented
|
||||
universe u
|
||||
|
||||
namespace Nat
|
||||
|
||||
/--
|
||||
`Nat.fold` evaluates `f` on the numbers up to `n` exclusive, in increasing order:
|
||||
* `Nat.fold f 3 init = init |> f 0 |> f 1 |> f 2`
|
||||
-/
|
||||
@[specialize] def fold {α : Type u} : (n : Nat) → (f : (i : Nat) → i < n → α → α) → (init : α) → α
|
||||
| 0, f, a => a
|
||||
| succ n, f, a => f n (by omega) (fold n (fun i h => f i (by omega)) a)
|
||||
|
||||
/-- Tail-recursive version of `Nat.fold`. -/
|
||||
@[inline] def foldTR {α : Type u} (n : Nat) (f : (i : Nat) → i < n → α → α) (init : α) : α :=
|
||||
let rec @[specialize] loop : ∀ j, j ≤ n → α → α
|
||||
| 0, h, a => a
|
||||
| succ m, h, a => loop m (by omega) (f (n - succ m) (by omega) a)
|
||||
loop n (by omega) init
|
||||
|
||||
/--
|
||||
`Nat.foldRev` evaluates `f` on the numbers up to `n` exclusive, in decreasing order:
|
||||
* `Nat.foldRev f 3 init = f 0 <| f 1 <| f 2 <| init`
|
||||
-/
|
||||
@[specialize] def foldRev {α : Type u} : (n : Nat) → (f : (i : Nat) → i < n → α → α) → (init : α) → α
|
||||
| 0, f, a => a
|
||||
| succ n, f, a => foldRev n (fun i h => f i (by omega)) (f n (by omega) a)
|
||||
|
||||
/-- `any f n = true` iff there is `i in [0, n-1]` s.t. `f i = true` -/
|
||||
@[specialize] def any : (n : Nat) → (f : (i : Nat) → i < n → Bool) → Bool
|
||||
| 0, f => false
|
||||
| succ n, f => any n (fun i h => f i (by omega)) || f n (by omega)
|
||||
|
||||
/-- Tail-recursive version of `Nat.any`. -/
|
||||
@[inline] def anyTR (n : Nat) (f : (i : Nat) → i < n → Bool) : Bool :=
|
||||
let rec @[specialize] loop : (i : Nat) → i ≤ n → Bool
|
||||
| 0, h => false
|
||||
| succ m, h => f (n - succ m) (by omega) || loop m (by omega)
|
||||
loop n (by omega)
|
||||
|
||||
/-- `all f n = true` iff every `i in [0, n-1]` satisfies `f i = true` -/
|
||||
@[specialize] def all : (n : Nat) → (f : (i : Nat) → i < n → Bool) → Bool
|
||||
| 0, f => true
|
||||
| succ n, f => all n (fun i h => f i (by omega)) && f n (by omega)
|
||||
|
||||
/-- Tail-recursive version of `Nat.all`. -/
|
||||
@[inline] def allTR (n : Nat) (f : (i : Nat) → i < n → Bool) : Bool :=
|
||||
let rec @[specialize] loop : (i : Nat) → i ≤ n → Bool
|
||||
| 0, h => true
|
||||
| succ m, h => f (n - succ m) (by omega) && loop m (by omega)
|
||||
loop n (by omega)
|
||||
|
||||
/-! # csimp theorems -/
|
||||
|
||||
theorem fold_congr {α : Type u} {n m : Nat} (w : n = m)
|
||||
(f : (i : Nat) → i < n → α → α) (init : α) :
|
||||
fold n f init = fold m (fun i h => f i (by omega)) init := by
|
||||
subst m
|
||||
rfl
|
||||
|
||||
theorem foldTR_loop_congr {α : Type u} {n m : Nat} (w : n = m)
|
||||
(f : (i : Nat) → i < n → α → α) (j : Nat) (h : j ≤ n) (init : α) :
|
||||
foldTR.loop n f j h init = foldTR.loop m (fun i h => f i (by omega)) j (by omega) init := by
|
||||
subst m
|
||||
rfl
|
||||
|
||||
@[csimp] theorem fold_eq_foldTR : @fold = @foldTR :=
|
||||
funext fun α => funext fun n => funext fun f => funext fun init =>
|
||||
let rec go : ∀ m n f, fold (m + n) f init = foldTR.loop (m + n) f m (by omega) (fold n (fun i h => f i (by omega)) init)
|
||||
| 0, n, f => by
|
||||
simp only [foldTR.loop]
|
||||
have t : 0 + n = n := by omega
|
||||
rw [fold_congr t]
|
||||
| succ m, n, f => by
|
||||
have t : (m + 1) + n = m + (n + 1) := by omega
|
||||
rw [foldTR.loop]
|
||||
simp only [succ_eq_add_one, Nat.add_sub_cancel]
|
||||
rw [fold_congr t, foldTR_loop_congr t, go, fold]
|
||||
congr
|
||||
omega
|
||||
go n 0 f
|
||||
|
||||
theorem any_congr {n m : Nat} (w : n = m) (f : (i : Nat) → i < n → Bool) : any n f = any m (fun i h => f i (by omega)) := by
|
||||
subst m
|
||||
rfl
|
||||
|
||||
theorem anyTR_loop_congr {n m : Nat} (w : n = m) (f : (i : Nat) → i < n → Bool) (j : Nat) (h : j ≤ n) :
|
||||
anyTR.loop n f j h = anyTR.loop m (fun i h => f i (by omega)) j (by omega) := by
|
||||
subst m
|
||||
rfl
|
||||
|
||||
@[csimp] theorem any_eq_anyTR : @any = @anyTR :=
|
||||
funext fun n => funext fun f =>
|
||||
let rec go : ∀ m n f, any (m + n) f = (any n (fun i h => f i (by omega)) || anyTR.loop (m + n) f m (by omega))
|
||||
| 0, n, f => by
|
||||
simp [anyTR.loop]
|
||||
have t : 0 + n = n := by omega
|
||||
rw [any_congr t]
|
||||
| succ m, n, f => by
|
||||
have t : (m + 1) + n = m + (n + 1) := by omega
|
||||
rw [anyTR.loop]
|
||||
simp only [succ_eq_add_one]
|
||||
rw [any_congr t, anyTR_loop_congr t, go, any, Bool.or_assoc]
|
||||
congr
|
||||
omega
|
||||
go n 0 f
|
||||
|
||||
theorem all_congr {n m : Nat} (w : n = m) (f : (i : Nat) → i < n → Bool) : all n f = all m (fun i h => f i (by omega)) := by
|
||||
subst m
|
||||
rfl
|
||||
|
||||
theorem allTR_loop_congr {n m : Nat} (w : n = m) (f : (i : Nat) → i < n → Bool) (j : Nat) (h : j ≤ n) : allTR.loop n f j h = allTR.loop m (fun i h => f i (by omega)) j (by omega) := by
|
||||
subst m
|
||||
rfl
|
||||
|
||||
@[csimp] theorem all_eq_allTR : @all = @allTR :=
|
||||
funext fun n => funext fun f =>
|
||||
let rec go : ∀ m n f, all (m + n) f = (all n (fun i h => f i (by omega)) && allTR.loop (m + n) f m (by omega))
|
||||
| 0, n, f => by
|
||||
simp [allTR.loop]
|
||||
have t : 0 + n = n := by omega
|
||||
rw [all_congr t]
|
||||
| succ m, n, f => by
|
||||
have t : (m + 1) + n = m + (n + 1) := by omega
|
||||
rw [allTR.loop]
|
||||
simp only [succ_eq_add_one]
|
||||
rw [all_congr t, allTR_loop_congr t, go, all, Bool.and_assoc]
|
||||
congr
|
||||
omega
|
||||
go n 0 f
|
||||
|
||||
end Nat
|
||||
|
||||
namespace Prod
|
||||
|
||||
/--
|
||||
`(start, stop).foldI f a` evaluates `f` on all the numbers
|
||||
from `start` (inclusive) to `stop` (exclusive) in increasing order:
|
||||
* `(5, 8).foldI f init = init |> f 5 |> f 6 |> f 7`
|
||||
-/
|
||||
@[inline] def foldI {α : Type u} (i : Nat × Nat) (f : (j : Nat) → i.1 ≤ j → j < i.2 → α → α) (a : α) : α :=
|
||||
(i.2 - i.1).fold (fun j _ => f (i.1 + j) (by omega) (by omega)) a
|
||||
|
||||
/--
|
||||
`(start, stop).anyI f a` returns true if `f` is true for some natural number
|
||||
from `start` (inclusive) to `stop` (exclusive):
|
||||
* `(5, 8).anyI f = f 5 || f 6 || f 7`
|
||||
-/
|
||||
@[inline] def anyI (i : Nat × Nat) (f : (j : Nat) → i.1 ≤ j → j < i.2 → Bool) : Bool :=
|
||||
(i.2 - i.1).any (fun j _ => f (i.1 + j) (by omega) (by omega))
|
||||
|
||||
/--
|
||||
`(start, stop).allI f a` returns true if `f` is true for all natural numbers
|
||||
from `start` (inclusive) to `stop` (exclusive):
|
||||
* `(5, 8).anyI f = f 5 && f 6 && f 7`
|
||||
-/
|
||||
@[inline] def allI (i : Nat × Nat) (f : (j : Nat) → i.1 ≤ j → j < i.2 → Bool) : Bool :=
|
||||
(i.2 - i.1).all (fun j _ => f (i.1 + j) (by omega) (by omega))
|
||||
|
||||
end Prod
|
||||
@@ -31,7 +31,7 @@ This file defines basic operations on the the sum type `α ⊕ β`.
|
||||
|
||||
## Further material
|
||||
|
||||
See `Batteries.Data.Sum.Lemmas` for theorems about these definitions.
|
||||
See `Init.Data.Sum.Lemmas` for theorems about these definitions.
|
||||
|
||||
## Notes
|
||||
|
||||
|
||||
@@ -462,6 +462,16 @@ Note that it is the caller's job to remove the file after use.
|
||||
-/
|
||||
@[extern "lean_io_create_tempfile"] opaque createTempFile : IO (Handle × FilePath)
|
||||
|
||||
/--
|
||||
Creates a temporary directory in the most secure manner possible. There are no race conditions in the
|
||||
directory’s creation. The directory is readable and writable only by the creating user ID.
|
||||
|
||||
Returns the new directory's path.
|
||||
|
||||
It is the caller's job to remove the directory after use.
|
||||
-/
|
||||
@[extern "lean_io_create_tempdir"] opaque createTempDir : IO FilePath
|
||||
|
||||
end FS
|
||||
|
||||
@[extern "lean_io_getenv"] opaque getEnv (var : @& String) : BaseIO (Option String)
|
||||
@@ -474,17 +484,6 @@ namespace FS
|
||||
def withFile (fn : FilePath) (mode : Mode) (f : Handle → IO α) : IO α :=
|
||||
Handle.mk fn mode >>= f
|
||||
|
||||
/--
|
||||
Like `createTempFile` but also takes care of removing the file after usage.
|
||||
-/
|
||||
def withTempFile [Monad m] [MonadFinally m] [MonadLiftT IO m] (f : Handle → FilePath → m α) :
|
||||
m α := do
|
||||
let (handle, path) ← createTempFile
|
||||
try
|
||||
f handle path
|
||||
finally
|
||||
removeFile path
|
||||
|
||||
def Handle.putStrLn (h : Handle) (s : String) : IO Unit :=
|
||||
h.putStr (s.push '\n')
|
||||
|
||||
@@ -675,8 +674,10 @@ def appDir : IO FilePath := do
|
||||
| throw <| IO.userError s!"System.IO.appDir: unexpected filename '{p}'"
|
||||
FS.realPath p
|
||||
|
||||
namespace FS
|
||||
|
||||
/-- Create given path and all missing parents as directories. -/
|
||||
partial def FS.createDirAll (p : FilePath) : IO Unit := do
|
||||
partial def createDirAll (p : FilePath) : IO Unit := do
|
||||
if ← p.isDir then
|
||||
return ()
|
||||
if let some parent := p.parent then
|
||||
@@ -693,7 +694,7 @@ partial def FS.createDirAll (p : FilePath) : IO Unit := do
|
||||
/--
|
||||
Fully remove given directory by deleting all contained files and directories in an unspecified order.
|
||||
Fails if any contained entry cannot be deleted or was newly created during execution. -/
|
||||
partial def FS.removeDirAll (p : FilePath) : IO Unit := do
|
||||
partial def removeDirAll (p : FilePath) : IO Unit := do
|
||||
for ent in (← p.readDir) do
|
||||
if (← ent.path.isDir : Bool) then
|
||||
removeDirAll ent.path
|
||||
@@ -701,6 +702,32 @@ partial def FS.removeDirAll (p : FilePath) : IO Unit := do
|
||||
removeFile ent.path
|
||||
removeDir p
|
||||
|
||||
/--
|
||||
Like `createTempFile`, but also takes care of removing the file after usage.
|
||||
-/
|
||||
def withTempFile [Monad m] [MonadFinally m] [MonadLiftT IO m] (f : Handle → FilePath → m α) :
|
||||
m α := do
|
||||
let (handle, path) ← createTempFile
|
||||
try
|
||||
f handle path
|
||||
finally
|
||||
removeFile path
|
||||
|
||||
/--
|
||||
Like `createTempDir`, but also takes care of removing the directory after usage.
|
||||
|
||||
All files in the directory are recursively deleted, regardless of how or when they were created.
|
||||
-/
|
||||
def withTempDir [Monad m] [MonadFinally m] [MonadLiftT IO m] (f : FilePath → m α) :
|
||||
m α := do
|
||||
let path ← createTempDir
|
||||
try
|
||||
f path
|
||||
finally
|
||||
removeDirAll path
|
||||
|
||||
end FS
|
||||
|
||||
namespace Process
|
||||
|
||||
/-- Returns the current working directory of the calling process. -/
|
||||
|
||||
@@ -428,11 +428,11 @@ macro "infer_instance" : tactic => `(tactic| exact inferInstance)
|
||||
/--
|
||||
`+opt` is short for `(opt := true)`. It sets the `opt` configuration option to `true`.
|
||||
-/
|
||||
syntax posConfigItem := "+" noWs ident
|
||||
syntax posConfigItem := " +" noWs ident
|
||||
/--
|
||||
`-opt` is short for `(opt := false)`. It sets the `opt` configuration option to `false`.
|
||||
-/
|
||||
syntax negConfigItem := "-" noWs ident
|
||||
syntax negConfigItem := " -" noWs ident
|
||||
/--
|
||||
`(opt := val)` sets the `opt` configuration option to `val`.
|
||||
|
||||
|
||||
@@ -205,8 +205,8 @@ def getParamInfo (k : ParamMap.Key) : M (Array Param) := do
|
||||
|
||||
/-- For each ps[i], if ps[i] is owned, then mark xs[i] as owned. -/
|
||||
def ownArgsUsingParams (xs : Array Arg) (ps : Array Param) : M Unit :=
|
||||
xs.size.forM fun i => do
|
||||
let x := xs[i]!
|
||||
xs.size.forM fun i _ => do
|
||||
let x := xs[i]
|
||||
let p := ps[i]!
|
||||
unless p.borrow do ownArg x
|
||||
|
||||
@@ -216,8 +216,8 @@ def ownArgsUsingParams (xs : Array Arg) (ps : Array Param) : M Unit :=
|
||||
we would have to insert a `dec xs[i]` after `f xs` and consequently
|
||||
"break" the tail call. -/
|
||||
def ownParamsUsingArgs (xs : Array Arg) (ps : Array Param) : M Unit :=
|
||||
xs.size.forM fun i => do
|
||||
let x := xs[i]!
|
||||
xs.size.forM fun i _ => do
|
||||
let x := xs[i]
|
||||
let p := ps[i]!
|
||||
match x with
|
||||
| Arg.var x => if (← isOwned x) then ownVar p.x
|
||||
|
||||
@@ -48,9 +48,9 @@ def requiresBoxedVersion (env : Environment) (decl : Decl) : Bool :=
|
||||
def mkBoxedVersionAux (decl : Decl) : N Decl := do
|
||||
let ps := decl.params
|
||||
let qs ← ps.mapM fun _ => do let x ← N.mkFresh; pure { x := x, ty := IRType.object, borrow := false : Param }
|
||||
let (newVDecls, xs) ← qs.size.foldM (init := (#[], #[])) fun i (newVDecls, xs) => do
|
||||
let (newVDecls, xs) ← qs.size.foldM (init := (#[], #[])) fun i _ (newVDecls, xs) => do
|
||||
let p := ps[i]!
|
||||
let q := qs[i]!
|
||||
let q := qs[i]
|
||||
if !p.ty.isScalar then
|
||||
pure (newVDecls, xs.push (Arg.var q.x))
|
||||
else
|
||||
|
||||
@@ -63,7 +63,7 @@ partial def merge (v₁ v₂ : Value) : Value :=
|
||||
| top, _ => top
|
||||
| _, top => top
|
||||
| v₁@(ctor i₁ vs₁), v₂@(ctor i₂ vs₂) =>
|
||||
if i₁ == i₂ then ctor i₁ <| vs₁.size.fold (init := #[]) fun i r => r.push (merge vs₁[i]! vs₂[i]!)
|
||||
if i₁ == i₂ then ctor i₁ <| vs₁.size.fold (init := #[]) fun i _ r => r.push (merge vs₁[i] vs₂[i]!)
|
||||
else choice [v₁, v₂]
|
||||
| choice vs₁, choice vs₂ => choice <| vs₁.foldl (addChoice merge) vs₂
|
||||
| choice vs, v => choice <| addChoice merge vs v
|
||||
@@ -225,8 +225,8 @@ def updateCurrFnSummary (v : Value) : M Unit := do
|
||||
def updateJPParamsAssignment (ys : Array Param) (xs : Array Arg) : M Bool := do
|
||||
let ctx ← read
|
||||
let currFnIdx := ctx.currFnIdx
|
||||
ys.size.foldM (init := false) fun i r => do
|
||||
let y := ys[i]!
|
||||
ys.size.foldM (init := false) fun i _ r => do
|
||||
let y := ys[i]
|
||||
let x := xs[i]!
|
||||
let yVal ← findVarValue y.x
|
||||
let xVal ← findArgValue x
|
||||
@@ -282,8 +282,8 @@ partial def interpFnBody : FnBody → M Unit
|
||||
def inferStep : M Bool := do
|
||||
let ctx ← read
|
||||
modify fun s => { s with assignments := ctx.decls.map fun _ => {} }
|
||||
ctx.decls.size.foldM (init := false) fun idx modified => do
|
||||
match ctx.decls[idx]! with
|
||||
ctx.decls.size.foldM (init := false) fun idx _ modified => do
|
||||
match ctx.decls[idx] with
|
||||
| .fdecl (xs := ys) (body := b) .. => do
|
||||
let s ← get
|
||||
let currVals := s.funVals[idx]!
|
||||
@@ -336,8 +336,8 @@ def elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do
|
||||
let funVals := s.funVals
|
||||
let assignments := s.assignments
|
||||
modify fun s =>
|
||||
let env := decls.size.fold (init := s.env) fun i env =>
|
||||
addFunctionSummary env decls[i]!.name funVals[i]!
|
||||
let env := decls.size.fold (init := s.env) fun i _ env =>
|
||||
addFunctionSummary env decls[i].name funVals[i]!
|
||||
{ s with env := env }
|
||||
return decls.mapIdx fun i decl => elimDead assignments[i]! decl
|
||||
|
||||
|
||||
@@ -108,9 +108,9 @@ def emitFnDeclAux (decl : Decl) (cppBaseName : String) (isExternal : Bool) : M U
|
||||
if ps.size > closureMaxArgs && isBoxedName decl.name then
|
||||
emit "lean_object**"
|
||||
else
|
||||
ps.size.forM fun i => do
|
||||
ps.size.forM fun i _ => do
|
||||
if i > 0 then emit ", "
|
||||
emit (toCType ps[i]!.ty)
|
||||
emit (toCType ps[i].ty)
|
||||
emit ")"
|
||||
emitLn ";"
|
||||
|
||||
@@ -321,20 +321,22 @@ def emitSSet (x : VarId) (n : Nat) (offset : Nat) (y : VarId) (t : IRType) : M U
|
||||
|
||||
def emitJmp (j : JoinPointId) (xs : Array Arg) : M Unit := do
|
||||
let ps ← getJPParams j
|
||||
unless xs.size == ps.size do throw "invalid goto"
|
||||
xs.size.forM fun i => do
|
||||
let p := ps[i]!
|
||||
let x := xs[i]!
|
||||
emit p.x; emit " = "; emitArg x; emitLn ";"
|
||||
emit "goto "; emit j; emitLn ";"
|
||||
if h : xs.size = ps.size then
|
||||
xs.size.forM fun i _ => do
|
||||
let p := ps[i]
|
||||
let x := xs[i]
|
||||
emit p.x; emit " = "; emitArg x; emitLn ";"
|
||||
emit "goto "; emit j; emitLn ";"
|
||||
else
|
||||
do throw "invalid goto"
|
||||
|
||||
def emitLhs (z : VarId) : M Unit := do
|
||||
emit z; emit " = "
|
||||
|
||||
def emitArgs (ys : Array Arg) : M Unit :=
|
||||
ys.size.forM fun i => do
|
||||
ys.size.forM fun i _ => do
|
||||
if i > 0 then emit ", "
|
||||
emitArg ys[i]!
|
||||
emitArg ys[i]
|
||||
|
||||
def emitCtorScalarSize (usize : Nat) (ssize : Nat) : M Unit := do
|
||||
if usize == 0 then emit ssize
|
||||
@@ -346,8 +348,8 @@ def emitAllocCtor (c : CtorInfo) : M Unit := do
|
||||
emitCtorScalarSize c.usize c.ssize; emitLn ");"
|
||||
|
||||
def emitCtorSetArgs (z : VarId) (ys : Array Arg) : M Unit :=
|
||||
ys.size.forM fun i => do
|
||||
emit "lean_ctor_set("; emit z; emit ", "; emit i; emit ", "; emitArg ys[i]!; emitLn ");"
|
||||
ys.size.forM fun i _ => do
|
||||
emit "lean_ctor_set("; emit z; emit ", "; emit i; emit ", "; emitArg ys[i]; emitLn ");"
|
||||
|
||||
def emitCtor (z : VarId) (c : CtorInfo) (ys : Array Arg) : M Unit := do
|
||||
emitLhs z;
|
||||
@@ -358,7 +360,7 @@ def emitCtor (z : VarId) (c : CtorInfo) (ys : Array Arg) : M Unit := do
|
||||
|
||||
def emitReset (z : VarId) (n : Nat) (x : VarId) : M Unit := do
|
||||
emit "if (lean_is_exclusive("; emit x; emitLn ")) {";
|
||||
n.forM fun i => do
|
||||
n.forM fun i _ => do
|
||||
emit " lean_ctor_release("; emit x; emit ", "; emit i; emitLn ");"
|
||||
emit " "; emitLhs z; emit x; emitLn ";";
|
||||
emitLn "} else {";
|
||||
@@ -399,12 +401,12 @@ def emitSimpleExternalCall (f : String) (ps : Array Param) (ys : Array Arg) : M
|
||||
emit f; emit "("
|
||||
-- We must remove irrelevant arguments to extern calls.
|
||||
discard <| ys.size.foldM
|
||||
(fun i (first : Bool) =>
|
||||
(fun i _ (first : Bool) =>
|
||||
if ps[i]!.ty.isIrrelevant then
|
||||
pure first
|
||||
else do
|
||||
unless first do emit ", "
|
||||
emitArg ys[i]!
|
||||
emitArg ys[i]
|
||||
pure false)
|
||||
true
|
||||
emitLn ");"
|
||||
@@ -431,8 +433,8 @@ def emitPartialApp (z : VarId) (f : FunId) (ys : Array Arg) : M Unit := do
|
||||
let decl ← getDecl f
|
||||
let arity := decl.params.size;
|
||||
emitLhs z; emit "lean_alloc_closure((void*)("; emitCName f; emit "), "; emit arity; emit ", "; emit ys.size; emitLn ");";
|
||||
ys.size.forM fun i => do
|
||||
let y := ys[i]!
|
||||
ys.size.forM fun i _ => do
|
||||
let y := ys[i]
|
||||
emit "lean_closure_set("; emit z; emit ", "; emit i; emit ", "; emitArg y; emitLn ");"
|
||||
|
||||
def emitApp (z : VarId) (f : VarId) (ys : Array Arg) : M Unit :=
|
||||
@@ -544,34 +546,36 @@ That is, we have
|
||||
-/
|
||||
def overwriteParam (ps : Array Param) (ys : Array Arg) : Bool :=
|
||||
let n := ps.size;
|
||||
n.any fun i =>
|
||||
let p := ps[i]!
|
||||
(i+1, n).anyI fun j => paramEqArg p ys[j]!
|
||||
n.any fun i _ =>
|
||||
let p := ps[i]
|
||||
(i+1, n).anyI fun j _ _ => paramEqArg p ys[j]!
|
||||
|
||||
def emitTailCall (v : Expr) : M Unit :=
|
||||
match v with
|
||||
| Expr.fap _ ys => do
|
||||
let ctx ← read
|
||||
let ps := ctx.mainParams
|
||||
unless ps.size == ys.size do throw "invalid tail call"
|
||||
if overwriteParam ps ys then
|
||||
emitLn "{"
|
||||
ps.size.forM fun i => do
|
||||
let p := ps[i]!
|
||||
let y := ys[i]!
|
||||
unless paramEqArg p y do
|
||||
emit (toCType p.ty); emit " _tmp_"; emit i; emit " = "; emitArg y; emitLn ";"
|
||||
ps.size.forM fun i => do
|
||||
let p := ps[i]!
|
||||
let y := ys[i]!
|
||||
unless paramEqArg p y do emit p.x; emit " = _tmp_"; emit i; emitLn ";"
|
||||
emitLn "}"
|
||||
if h : ps.size = ys.size then
|
||||
if overwriteParam ps ys then
|
||||
emitLn "{"
|
||||
ps.size.forM fun i _ => do
|
||||
let p := ps[i]
|
||||
let y := ys[i]
|
||||
unless paramEqArg p y do
|
||||
emit (toCType p.ty); emit " _tmp_"; emit i; emit " = "; emitArg y; emitLn ";"
|
||||
ps.size.forM fun i _ => do
|
||||
let p := ps[i]
|
||||
let y := ys[i]
|
||||
unless paramEqArg p y do emit p.x; emit " = _tmp_"; emit i; emitLn ";"
|
||||
emitLn "}"
|
||||
else
|
||||
ys.size.forM fun i _ => do
|
||||
let p := ps[i]
|
||||
let y := ys[i]
|
||||
unless paramEqArg p y do emit p.x; emit " = "; emitArg y; emitLn ";"
|
||||
emitLn "goto _start;"
|
||||
else
|
||||
ys.size.forM fun i => do
|
||||
let p := ps[i]!
|
||||
let y := ys[i]!
|
||||
unless paramEqArg p y do emit p.x; emit " = "; emitArg y; emitLn ";"
|
||||
emitLn "goto _start;"
|
||||
throw "invalid tail call"
|
||||
| _ => throw "bug at emitTailCall"
|
||||
|
||||
mutual
|
||||
@@ -654,16 +658,16 @@ def emitDeclAux (d : Decl) : M Unit := do
|
||||
if xs.size > closureMaxArgs && isBoxedName d.name then
|
||||
emit "lean_object** _args"
|
||||
else
|
||||
xs.size.forM fun i => do
|
||||
xs.size.forM fun i _ => do
|
||||
if i > 0 then emit ", "
|
||||
let x := xs[i]!
|
||||
let x := xs[i]
|
||||
emit (toCType x.ty); emit " "; emit x.x
|
||||
emit ")"
|
||||
else
|
||||
emit ("_init_" ++ baseName ++ "()")
|
||||
emitLn " {";
|
||||
if xs.size > closureMaxArgs && isBoxedName d.name then
|
||||
xs.size.forM fun i => do
|
||||
xs.size.forM fun i _ => do
|
||||
let x := xs[i]!
|
||||
emit "lean_object* "; emit x.x; emit " = _args["; emit i; emitLn "];"
|
||||
emitLn "_start:";
|
||||
|
||||
@@ -571,9 +571,9 @@ def emitAllocCtor (builder : LLVM.Builder llvmctx)
|
||||
|
||||
def emitCtorSetArgs (builder : LLVM.Builder llvmctx)
|
||||
(z : VarId) (ys : Array Arg) : M llvmctx Unit := do
|
||||
ys.size.forM fun i => do
|
||||
ys.size.forM fun i _ => do
|
||||
let zv ← emitLhsVal builder z
|
||||
let (_yty, yv) ← emitArgVal builder ys[i]!
|
||||
let (_yty, yv) ← emitArgVal builder ys[i]
|
||||
let iv ← constIntUnsigned i
|
||||
callLeanCtorSet builder zv iv yv
|
||||
emitLhsSlotStore builder z zv
|
||||
@@ -702,8 +702,8 @@ def emitPartialApp (builder : LLVM.Builder llvmctx) (z : VarId) (f : FunId) (ys
|
||||
(← constIntUnsigned arity)
|
||||
(← constIntUnsigned ys.size)
|
||||
LLVM.buildStore builder zval zslot
|
||||
ys.size.forM fun i => do
|
||||
let (yty, yslot) ← emitArgSlot_ builder ys[i]!
|
||||
ys.size.forM fun i _ => do
|
||||
let (yty, yslot) ← emitArgSlot_ builder ys[i]
|
||||
let yval ← LLVM.buildLoad2 builder yty yslot
|
||||
callLeanClosureSetFn builder zval (← constIntUnsigned i) yval
|
||||
|
||||
@@ -922,7 +922,7 @@ def emitReset (builder : LLVM.Builder llvmctx) (z : VarId) (n : Nat) (x : VarId)
|
||||
buildIfThenElse_ builder "isExclusive" isExclusive
|
||||
(fun builder => do
|
||||
let xv ← emitLhsVal builder x
|
||||
n.forM fun i => do
|
||||
n.forM fun i _ => do
|
||||
callLeanCtorRelease builder xv (← constIntUnsigned i)
|
||||
emitLhsSlotStore builder z xv
|
||||
return ShouldForwardControlFlow.yes
|
||||
|
||||
@@ -134,15 +134,15 @@ abbrev M := ReaderT Context (StateM Nat)
|
||||
modifyGet fun n => ({ idx := n }, n + 1)
|
||||
|
||||
def releaseUnreadFields (y : VarId) (mask : Mask) (b : FnBody) : M FnBody :=
|
||||
mask.size.foldM (init := b) fun i b =>
|
||||
match mask.get! i with
|
||||
mask.size.foldM (init := b) fun i _ b =>
|
||||
match mask[i] with
|
||||
| some _ => pure b -- code took ownership of this field
|
||||
| none => do
|
||||
let fld ← mkFresh
|
||||
pure (FnBody.vdecl fld IRType.object (Expr.proj i y) (FnBody.dec fld 1 true false b))
|
||||
|
||||
def setFields (y : VarId) (zs : Array Arg) (b : FnBody) : FnBody :=
|
||||
zs.size.fold (init := b) fun i b => FnBody.set y i (zs.get! i) b
|
||||
zs.size.fold (init := b) fun i _ b => FnBody.set y i zs[i] b
|
||||
|
||||
/-- Given `set x[i] := y`, return true iff `y := proj[i] x` -/
|
||||
def isSelfSet (ctx : Context) (x : VarId) (i : Nat) (y : Arg) : Bool :=
|
||||
|
||||
@@ -79,13 +79,13 @@ private def addDecForAlt (ctx : Context) (caseLiveVars altLiveVars : LiveVarSet)
|
||||
/-- `isFirstOcc xs x i = true` if `xs[i]` is the first occurrence of `xs[i]` in `xs` -/
|
||||
private def isFirstOcc (xs : Array Arg) (i : Nat) : Bool :=
|
||||
let x := xs[i]!
|
||||
i.all fun j => xs[j]! != x
|
||||
i.all fun j _ => xs[j]! != x
|
||||
|
||||
/-- Return true if `x` also occurs in `ys` in a position that is not consumed.
|
||||
That is, it is also passed as a borrow reference. -/
|
||||
private def isBorrowParamAux (x : VarId) (ys : Array Arg) (consumeParamPred : Nat → Bool) : Bool :=
|
||||
ys.size.any fun i =>
|
||||
let y := ys[i]!
|
||||
ys.size.any fun i _ =>
|
||||
let y := ys[i]
|
||||
match y with
|
||||
| Arg.irrelevant => false
|
||||
| Arg.var y => x == y && !consumeParamPred i
|
||||
@@ -99,15 +99,15 @@ Return `n`, the number of times `x` is consumed.
|
||||
- `consumeParamPred i = true` if parameter `i` is consumed.
|
||||
-/
|
||||
private def getNumConsumptions (x : VarId) (ys : Array Arg) (consumeParamPred : Nat → Bool) : Nat :=
|
||||
ys.size.fold (init := 0) fun i n =>
|
||||
let y := ys[i]!
|
||||
ys.size.fold (init := 0) fun i _ n =>
|
||||
let y := ys[i]
|
||||
match y with
|
||||
| Arg.irrelevant => n
|
||||
| Arg.var y => if x == y && consumeParamPred i then n+1 else n
|
||||
|
||||
private def addIncBeforeAux (ctx : Context) (xs : Array Arg) (consumeParamPred : Nat → Bool) (b : FnBody) (liveVarsAfter : LiveVarSet) : FnBody :=
|
||||
xs.size.fold (init := b) fun i b =>
|
||||
let x := xs[i]!
|
||||
xs.size.fold (init := b) fun i _ b =>
|
||||
let x := xs[i]
|
||||
match x with
|
||||
| Arg.irrelevant => b
|
||||
| Arg.var x =>
|
||||
@@ -128,8 +128,8 @@ private def addIncBefore (ctx : Context) (xs : Array Arg) (ps : Array Param) (b
|
||||
|
||||
/-- See `addIncBeforeAux`/`addIncBefore` for the procedure that inserts `inc` operations before an application. -/
|
||||
private def addDecAfterFullApp (ctx : Context) (xs : Array Arg) (ps : Array Param) (b : FnBody) (bLiveVars : LiveVarSet) : FnBody :=
|
||||
xs.size.fold (init := b) fun i b =>
|
||||
match xs[i]! with
|
||||
xs.size.fold (init := b) fun i _ b =>
|
||||
match xs[i] with
|
||||
| Arg.irrelevant => b
|
||||
| Arg.var x =>
|
||||
/- We must add a `dec` if `x` must be consumed, it is alive after the application,
|
||||
|
||||
@@ -366,10 +366,10 @@ to be updated.
|
||||
@[implemented_by updateFunDeclCoreImp] opaque FunDeclCore.updateCore (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : FunDecl
|
||||
|
||||
def CasesCore.extractAlt! (cases : Cases) (ctorName : Name) : Alt × Cases :=
|
||||
let found (i : Nat) := (cases.alts[i]!, { cases with alts := cases.alts.eraseIdx i })
|
||||
if let some i := cases.alts.findIdx? fun | .alt ctorName' .. => ctorName == ctorName' | _ => false then
|
||||
let found i := (cases.alts[i], { cases with alts := cases.alts.eraseIdx i })
|
||||
if let some i := cases.alts.findFinIdx? fun | .alt ctorName' .. => ctorName == ctorName' | _ => false then
|
||||
found i
|
||||
else if let some i := cases.alts.findIdx? fun | .default _ => true | _ => false then
|
||||
else if let some i := cases.alts.findFinIdx? fun | .default _ => true | _ => false then
|
||||
found i
|
||||
else
|
||||
unreachable!
|
||||
|
||||
@@ -587,15 +587,15 @@ def Decl.elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do
|
||||
refer to the docstring of `Decl.safe`.
|
||||
-/
|
||||
if decls[i]!.safe then .bot else .top
|
||||
let mut funVals := decls.size.fold (init := .empty) fun i p => p.push (initialVal i)
|
||||
let mut funVals := decls.size.fold (init := .empty) fun i _ p => p.push (initialVal i)
|
||||
let ctx := { decls }
|
||||
let mut state := { assignments, funVals }
|
||||
(_, state) ← inferMain |>.run ctx |>.run state
|
||||
funVals := state.funVals
|
||||
assignments := state.assignments
|
||||
modifyEnv fun e =>
|
||||
decls.size.fold (init := e) fun i env =>
|
||||
addFunctionSummary env decls[i]!.name funVals[i]!
|
||||
decls.size.fold (init := e) fun i _ env =>
|
||||
addFunctionSummary env decls[i].name funVals[i]!
|
||||
|
||||
decls.mapIdxM fun i decl => if decl.safe then elimDead assignments[i]! decl else return decl
|
||||
|
||||
|
||||
@@ -76,8 +76,8 @@ def getType (fvarId : FVarId) : InferTypeM Expr := do
|
||||
|
||||
def mkForallFVars (xs : Array Expr) (type : Expr) : InferTypeM Expr :=
|
||||
let b := type.abstract xs
|
||||
xs.size.foldRevM (init := b) fun i b => do
|
||||
let x := xs[i]!
|
||||
xs.size.foldRevM (init := b) fun i _ b => do
|
||||
let x := xs[i]
|
||||
let n ← InferType.getBinderName x.fvarId!
|
||||
let ty ← InferType.getType x.fvarId!
|
||||
let ty := ty.abstractRange i xs;
|
||||
|
||||
@@ -134,9 +134,9 @@ def withEachOccurrence (targetName : Name) (f : Nat → PassInstaller) : PassIns
|
||||
|
||||
def installAfter (targetName : Name) (p : Pass → Pass) (occurrence : Nat := 0) : PassInstaller where
|
||||
install passes :=
|
||||
if let some idx := passes.findIdx? (fun p => p.name == targetName && p.occurrence == occurrence) then
|
||||
let passUnderTest := passes[idx]!
|
||||
return passes.insertAt! (idx + 1) (p passUnderTest)
|
||||
if let some idx := passes.findFinIdx? (fun p => p.name == targetName && p.occurrence == occurrence) then
|
||||
let passUnderTest := passes[idx]
|
||||
return passes.insertIdx (idx + 1) (p passUnderTest)
|
||||
else
|
||||
throwError s!"Tried to insert pass after {targetName}, occurrence {occurrence} but {targetName} is not in the pass list"
|
||||
|
||||
@@ -145,9 +145,9 @@ def installAfterEach (targetName : Name) (p : Pass → Pass) : PassInstaller :=
|
||||
|
||||
def installBefore (targetName : Name) (p : Pass → Pass) (occurrence : Nat := 0): PassInstaller where
|
||||
install passes :=
|
||||
if let some idx := passes.findIdx? (fun p => p.name == targetName && p.occurrence == occurrence) then
|
||||
let passUnderTest := passes[idx]!
|
||||
return passes.insertAt! idx (p passUnderTest)
|
||||
if let some idx := passes.findFinIdx? (fun p => p.name == targetName && p.occurrence == occurrence) then
|
||||
let passUnderTest := passes[idx]
|
||||
return passes.insertIdx idx (p passUnderTest)
|
||||
else
|
||||
throwError s!"Tried to insert pass after {targetName}, occurrence {occurrence} but {targetName} is not in the pass list"
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ private def elabSpecArgs (declName : Name) (args : Array Syntax) : MetaM (Array
|
||||
result := result.push idx
|
||||
else
|
||||
let argName := arg.getId
|
||||
if let some idx := argNames.getIdx? argName then
|
||||
if let some idx := argNames.indexOf? argName then
|
||||
if result.contains idx then throwErrorAt arg "invalid specialization argument name `{argName}`, it has already been specified as a specialization candidate"
|
||||
result := result.push idx
|
||||
else
|
||||
|
||||
@@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Init.Data.Nat.Fold
|
||||
import Init.Data.Array.Basic
|
||||
import Init.NotationExtra
|
||||
import Init.Data.ToString.Macro
|
||||
@@ -371,7 +372,7 @@ instance : ToString Stats := ⟨Stats.toString⟩
|
||||
end PersistentArray
|
||||
|
||||
def mkPersistentArray {α : Type u} (n : Nat) (v : α) : PArray α :=
|
||||
n.fold (init := PersistentArray.empty) fun _ p => p.push v
|
||||
n.fold (init := PersistentArray.empty) fun _ _ p => p.push v
|
||||
|
||||
@[inline] def mkPArray {α : Type u} (n : Nat) (v : α) : PArray α :=
|
||||
mkPersistentArray n v
|
||||
|
||||
@@ -233,10 +233,10 @@ partial def eraseAux [BEq α] : Node α β → USize → α → Node α β
|
||||
| n@(Node.collision keys vals heq), _, k =>
|
||||
match keys.indexOf? k with
|
||||
| some idx =>
|
||||
let keys' := keys.feraseIdx idx
|
||||
have keq := keys.size_feraseIdx idx
|
||||
let vals' := vals.feraseIdx (Eq.ndrec idx heq)
|
||||
have veq := vals.size_feraseIdx (Eq.ndrec idx heq)
|
||||
let keys' := keys.eraseIdx idx
|
||||
have keq := keys.size_eraseIdx idx _
|
||||
let vals' := vals.eraseIdx (Eq.ndrec idx heq)
|
||||
have veq := vals.size_eraseIdx (Eq.ndrec idx heq) _
|
||||
have : keys.size - 1 = vals.size - 1 := by rw [heq]
|
||||
Node.collision keys' vals' (keq.trans (this.trans veq.symm))
|
||||
| none => n
|
||||
|
||||
@@ -23,6 +23,7 @@ import Lean.Elab.Quotation
|
||||
import Lean.Elab.Syntax
|
||||
import Lean.Elab.Do
|
||||
import Lean.Elab.StructInst
|
||||
import Lean.Elab.MutualInductive
|
||||
import Lean.Elab.Inductive
|
||||
import Lean.Elab.Structure
|
||||
import Lean.Elab.Print
|
||||
|
||||
@@ -807,8 +807,8 @@ def getElabElimExprInfo (elimExpr : Expr) : MetaM ElabElimInfo := do
|
||||
These are the primary set of major parameters.
|
||||
-/
|
||||
let initMotiveFVars : CollectFVars.State := motiveArgs.foldl (init := {}) collectFVars
|
||||
let motiveFVars ← xs.size.foldRevM (init := initMotiveFVars) fun i s => do
|
||||
let x := xs[i]!
|
||||
let motiveFVars ← xs.size.foldRevM (init := initMotiveFVars) fun i _ s => do
|
||||
let x := xs[i]
|
||||
if s.fvarSet.contains x.fvarId! then
|
||||
return collectFVars s (← inferType x)
|
||||
else
|
||||
@@ -1347,7 +1347,7 @@ where
|
||||
let mut unusableNamedArgs := unusableNamedArgs
|
||||
for x in xs, bInfo in bInfos do
|
||||
let xDecl ← x.mvarId!.getDecl
|
||||
if let some idx := remainingNamedArgs.findIdx? (·.name == xDecl.userName) then
|
||||
if let some idx := remainingNamedArgs.findFinIdx? (·.name == xDecl.userName) then
|
||||
/- If there is named argument with name `xDecl.userName`, then it is accounted for and we can't make use of it. -/
|
||||
remainingNamedArgs := remainingNamedArgs.eraseIdx idx
|
||||
else
|
||||
@@ -1355,9 +1355,9 @@ where
|
||||
/- We found a type of the form (baseName ...).
|
||||
First, we check if the current argument is an explicit one,
|
||||
and if the current explicit position "fits" at `args` (i.e., it must be ≤ arg.size) -/
|
||||
if argIdx ≤ args.size && bInfo.isExplicit then
|
||||
if h : argIdx ≤ args.size ∧ bInfo.isExplicit then
|
||||
/- We can insert `e` as an explicit argument -/
|
||||
return (args.insertAt! argIdx (Arg.expr e), namedArgs)
|
||||
return (args.insertIdx argIdx (Arg.expr e), namedArgs)
|
||||
else
|
||||
/- If we can't add `e` to `args`, we try to add it using a named argument, but this is only possible
|
||||
if there isn't an argument with the same name occurring before it. -/
|
||||
|
||||
@@ -211,7 +211,7 @@ private def replaceBinderAnnotation (binder : TSyntax ``Parser.Term.bracketedBin
|
||||
else
|
||||
`(bracketedBinderF| {$id $[: $ty?]?})
|
||||
for id in ids.reverse do
|
||||
if let some idx := binderIds.findIdx? fun binderId => binderId.raw.isIdent && binderId.raw.getId == id.raw.getId then
|
||||
if let some idx := binderIds.findFinIdx? fun binderId => binderId.raw.isIdent && binderId.raw.getId == id.raw.getId then
|
||||
binderIds := binderIds.eraseIdx idx
|
||||
modifiedVarDecls := true
|
||||
varDeclsNew := varDeclsNew.push (← mkBinder id explicit)
|
||||
|
||||
@@ -7,9 +7,8 @@ prelude
|
||||
import Lean.Util.CollectLevelParams
|
||||
import Lean.Elab.DeclUtil
|
||||
import Lean.Elab.DefView
|
||||
import Lean.Elab.Inductive
|
||||
import Lean.Elab.Structure
|
||||
import Lean.Elab.MutualDef
|
||||
import Lean.Elab.MutualInductive
|
||||
import Lean.Elab.DeclarationRange
|
||||
namespace Lean.Elab.Command
|
||||
|
||||
@@ -163,15 +162,11 @@ def elabDeclaration : CommandElab := fun stx => do
|
||||
if declKind == ``Lean.Parser.Command.«axiom» then
|
||||
let modifiers ← elabModifiers modifiers
|
||||
elabAxiom modifiers decl
|
||||
else if declKind == ``Lean.Parser.Command.«inductive» then
|
||||
else if declKind == ``Lean.Parser.Command.«inductive»
|
||||
|| declKind == ``Lean.Parser.Command.classInductive
|
||||
|| declKind == ``Lean.Parser.Command.«structure» then
|
||||
let modifiers ← elabModifiers modifiers
|
||||
elabInductive modifiers decl
|
||||
else if declKind == ``Lean.Parser.Command.classInductive then
|
||||
let modifiers ← elabModifiers modifiers
|
||||
elabClassInductive modifiers decl
|
||||
else if declKind == ``Lean.Parser.Command.«structure» then
|
||||
let modifiers ← elabModifiers modifiers
|
||||
elabStructure modifiers decl
|
||||
else
|
||||
throwError "unexpected declaration"
|
||||
|
||||
@@ -278,10 +273,10 @@ def elabMutual : CommandElab := fun stx => do
|
||||
-- only case implementing incrementality currently
|
||||
elabMutualDef stx[1].getArgs
|
||||
else withoutCommandIncrementality true do
|
||||
if isMutualInductive stx then
|
||||
if ← isMutualInductive stx then
|
||||
elabMutualInductive stx[1].getArgs
|
||||
else
|
||||
throwError "invalid mutual block: either all elements of the block must be inductive declarations, or they must all be definitions/theorems/abbrevs"
|
||||
throwError "invalid mutual block: either all elements of the block must be inductive/structure declarations, or they must all be definitions/theorems/abbrevs"
|
||||
|
||||
/- leading_parser "attribute " >> "[" >> sepBy1 (eraseAttr <|> Term.attrInstance) ", " >> "]" >> many1 ident -/
|
||||
@[builtin_command_elab «attribute»] def elabAttr : CommandElab := fun stx => do
|
||||
|
||||
@@ -1518,7 +1518,7 @@ mutual
|
||||
-/
|
||||
partial def doForToCode (doFor : Syntax) (doElems : List Syntax) : M CodeBlock := do
|
||||
let doForDecls := doFor[1].getSepArgs
|
||||
if doForDecls.size > 1 then
|
||||
if h : doForDecls.size > 1 then
|
||||
/-
|
||||
Expand
|
||||
```
|
||||
|
||||
@@ -102,7 +102,7 @@ partial def IO.processCommandsIncrementally (inputCtx : Parser.InputContext)
|
||||
where
|
||||
go initialSnap t commands :=
|
||||
let snap := t.get
|
||||
let commands := commands.push snap.data
|
||||
let commands := commands.push snap
|
||||
if let some next := snap.nextCmdSnap? then
|
||||
go initialSnap next.task commands
|
||||
else
|
||||
@@ -115,9 +115,9 @@ where
|
||||
-- snapshots as they subsume any info trees reported incrementally by their children.
|
||||
let trees := commands.map (·.finishedSnap.get.infoTree?) |>.filterMap id |>.toPArray'
|
||||
return {
|
||||
commandState := { snap.data.finishedSnap.get.cmdState with messages, infoState.trees := trees }
|
||||
parserState := snap.data.parserState
|
||||
cmdPos := snap.data.parserState.pos
|
||||
commandState := { snap.finishedSnap.get.cmdState with messages, infoState.trees := trees }
|
||||
parserState := snap.parserState
|
||||
cmdPos := snap.parserState.pos
|
||||
commands := commands.map (·.stx)
|
||||
inputCtx, initialSnap
|
||||
}
|
||||
@@ -164,8 +164,8 @@ def runFrontend
|
||||
| return (← mkEmptyEnvironment, false)
|
||||
|
||||
if let some out := trace.profiler.output.get? opts then
|
||||
let traceState := cmdState.traceState
|
||||
let profile ← Firefox.Profile.export mainModuleName.toString startTime traceState opts
|
||||
let traceStates := snaps.getAll.map (·.traces)
|
||||
let profile ← Firefox.Profile.export mainModuleName.toString startTime traceStates opts
|
||||
IO.FS.writeFile ⟨out⟩ <| Json.compress <| toJson profile
|
||||
|
||||
let hasErrors := snaps.getAll.any (·.diagnostics.msgLog.hasErrors)
|
||||
|
||||
@@ -1,102 +1,36 @@
|
||||
/-
|
||||
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
Authors: Leonardo de Moura, Kyle Miller
|
||||
-/
|
||||
prelude
|
||||
import Lean.Util.ForEachExprWhere
|
||||
import Lean.Util.ReplaceLevel
|
||||
import Lean.Util.ReplaceExpr
|
||||
import Lean.Util.CollectLevelParams
|
||||
import Lean.Meta.Constructions
|
||||
import Lean.Meta.CollectFVars
|
||||
import Lean.Meta.SizeOf
|
||||
import Lean.Meta.Injective
|
||||
import Lean.Meta.IndPredBelow
|
||||
import Lean.Elab.Command
|
||||
import Lean.Elab.ComputedFields
|
||||
import Lean.Elab.DefView
|
||||
import Lean.Elab.DeclUtil
|
||||
import Lean.Elab.Deriving.Basic
|
||||
import Lean.Elab.DeclarationRange
|
||||
import Lean.Elab.MutualInductive
|
||||
|
||||
namespace Lean.Elab.Command
|
||||
open Meta
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Elab.inductive
|
||||
|
||||
register_builtin_option inductive.autoPromoteIndices : Bool := {
|
||||
defValue := true
|
||||
descr := "Promote indices to parameters in inductive types whenever possible."
|
||||
}
|
||||
|
||||
def checkValidInductiveModifier [Monad m] [MonadError m] (modifiers : Modifiers) : m Unit := do
|
||||
if modifiers.isNoncomputable then
|
||||
throwError "invalid use of 'noncomputable' in inductive declaration"
|
||||
if modifiers.isPartial then
|
||||
throwError "invalid use of 'partial' in inductive declaration"
|
||||
|
||||
def checkValidCtorModifier [Monad m] [MonadError m] (modifiers : Modifiers) : m Unit := do
|
||||
if modifiers.isNoncomputable then
|
||||
throwError "invalid use of 'noncomputable' in constructor declaration"
|
||||
if modifiers.isPartial then
|
||||
throwError "invalid use of 'partial' in constructor declaration"
|
||||
if modifiers.isUnsafe then
|
||||
throwError "invalid use of 'unsafe' in constructor declaration"
|
||||
if modifiers.attrs.size != 0 then
|
||||
throwError "invalid use of attributes in constructor declaration"
|
||||
|
||||
structure CtorView where
|
||||
ref : Syntax
|
||||
modifiers : Modifiers
|
||||
declName : Name
|
||||
binders : Syntax
|
||||
type? : Option Syntax
|
||||
deriving Inhabited
|
||||
|
||||
structure ComputedFieldView where
|
||||
ref : Syntax
|
||||
modifiers : Syntax
|
||||
fieldId : Name
|
||||
type : Syntax.Term
|
||||
matchAlts : TSyntax ``Parser.Term.matchAlts
|
||||
|
||||
structure InductiveView where
|
||||
ref : Syntax
|
||||
declId : Syntax
|
||||
modifiers : Modifiers
|
||||
shortDeclName : Name
|
||||
declName : Name
|
||||
levelNames : List Name
|
||||
binders : Syntax
|
||||
type? : Option Syntax
|
||||
ctors : Array CtorView
|
||||
derivingClasses : Array DerivingClassView
|
||||
computedFields : Array ComputedFieldView
|
||||
deriving Inhabited
|
||||
|
||||
structure ElabHeaderResult where
|
||||
view : InductiveView
|
||||
lctx : LocalContext
|
||||
localInsts : LocalInstances
|
||||
levelNames : List Name
|
||||
params : Array Expr
|
||||
type : Expr
|
||||
deriving Inhabited
|
||||
|
||||
/-
|
||||
leading_parser "inductive " >> declId >> optDeclSig >> optional ("where" <|> ":=") >> many ctor
|
||||
leading_parser atomic (group ("class " >> "inductive ")) >> declId >> optDeclSig >> optional ("where" <|> ":=") >> many ctor >> optDeriving
|
||||
```
|
||||
def Lean.Parser.Command.«inductive» :=
|
||||
leading_parser "inductive " >> declId >> optDeclSig >> optional ("where" <|> ":=") >> many ctor
|
||||
|
||||
def Lean.Parser.Command.classInductive :=
|
||||
leading_parser atomic (group ("class " >> "inductive ")) >> declId >> optDeclSig >> optional ("where" <|> ":=") >> many ctor >> optDeriving
|
||||
```
|
||||
-/
|
||||
private def inductiveSyntaxToView (modifiers : Modifiers) (decl : Syntax) : TermElabM InductiveView := do
|
||||
checkValidInductiveModifier modifiers
|
||||
let isClass := decl.isOfKind ``Parser.Command.classInductive
|
||||
let modifiers := if isClass then modifiers.addAttr { name := `class } else modifiers
|
||||
let (binders, type?) := expandOptDeclSig decl[2]
|
||||
let declId := decl[1]
|
||||
let ⟨name, declName, levelNames⟩ ← Term.expandDeclId (← getCurrNamespace) (← Term.getLevelNames) declId modifiers
|
||||
addDeclarationRangesForBuiltin declName modifiers.stx decl
|
||||
let ctors ← decl[4].getArgs.mapM fun ctor => withRef ctor do
|
||||
-- def ctor := leading_parser optional docComment >> "\n| " >> declModifiers >> rawIdent >> optDeclSig
|
||||
/-
|
||||
```
|
||||
def ctor := leading_parser optional docComment >> "\n| " >> declModifiers >> rawIdent >> optDeclSig
|
||||
```
|
||||
-/
|
||||
let mut ctorModifiers ← elabModifiers ⟨ctor[2]⟩
|
||||
if let some leadingDocComment := ctor[0].getOptional? then
|
||||
if ctorModifiers.docString?.isSome then
|
||||
@@ -113,7 +47,7 @@ private def inductiveSyntaxToView (modifiers : Modifiers) (decl : Syntax) : Term
|
||||
let (binders, type?) := expandOptDeclSig ctor[4]
|
||||
addDocString' ctorName ctorModifiers.docString?
|
||||
addDeclarationRangesFromSyntax ctorName ctor ctor[3]
|
||||
return { ref := ctor, modifiers := ctorModifiers, declName := ctorName, binders := binders, type? := type? : CtorView }
|
||||
return { ref := ctor, declId := ctor[3], modifiers := ctorModifiers, declName := ctorName, binders := binders, type? := type? : CtorView }
|
||||
let computedFields ← (decl[5].getOptional?.map (·[1].getArgs) |>.getD #[]).mapM fun cf => withRef cf do
|
||||
return { ref := cf, modifiers := cf[0], fieldId := cf[1].getId, type := ⟨cf[3]⟩, matchAlts := ⟨cf[4]⟩ }
|
||||
let classes ← getOptDerivingClasses decl[6]
|
||||
@@ -125,137 +59,13 @@ private def inductiveSyntaxToView (modifiers : Modifiers) (decl : Syntax) : Term
|
||||
ref := decl
|
||||
shortDeclName := name
|
||||
derivingClasses := classes
|
||||
declId, modifiers, declName, levelNames
|
||||
allowIndices := true
|
||||
allowSortPolymorphism := true
|
||||
declId, modifiers, isClass, declName, levelNames
|
||||
binders, type?, ctors
|
||||
computedFields
|
||||
}
|
||||
|
||||
private partial def elabHeaderAux (views : Array InductiveView) (i : Nat) (acc : Array ElabHeaderResult) : TermElabM (Array ElabHeaderResult) :=
|
||||
Term.withAutoBoundImplicitForbiddenPred (fun n => views.any (·.shortDeclName == n)) do
|
||||
if h : i < views.size then
|
||||
let view := views[i]
|
||||
let acc ← Term.withAutoBoundImplicit <| Term.elabBinders view.binders.getArgs fun params => do
|
||||
match view.type? with
|
||||
| none =>
|
||||
let u ← mkFreshLevelMVar
|
||||
let type := mkSort u
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
Term.addAutoBoundImplicits' params type fun params type => do
|
||||
let levelNames ← Term.getLevelNames
|
||||
return acc.push { lctx := (← getLCtx), localInsts := (← getLocalInstances), levelNames, params, type, view }
|
||||
| some typeStx =>
|
||||
let (type, _) ← Term.withAutoBoundImplicit do
|
||||
let type ← Term.elabType typeStx
|
||||
unless (← isTypeFormerType type) do
|
||||
throwErrorAt typeStx "invalid inductive type, resultant type is not a sort"
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let indices ← Term.addAutoBoundImplicits #[]
|
||||
return (← mkForallFVars indices type, indices.size)
|
||||
Term.addAutoBoundImplicits' params type fun params type => do
|
||||
trace[Elab.inductive] "header params: {params}, type: {type}"
|
||||
let levelNames ← Term.getLevelNames
|
||||
return acc.push { lctx := (← getLCtx), localInsts := (← getLocalInstances), levelNames, params, type, view }
|
||||
elabHeaderAux views (i+1) acc
|
||||
else
|
||||
return acc
|
||||
|
||||
private def checkNumParams (rs : Array ElabHeaderResult) : TermElabM Nat := do
|
||||
let numParams := rs[0]!.params.size
|
||||
for r in rs do
|
||||
unless r.params.size == numParams do
|
||||
throwErrorAt r.view.ref "invalid inductive type, number of parameters mismatch in mutually inductive datatypes"
|
||||
return numParams
|
||||
|
||||
private def checkUnsafe (rs : Array ElabHeaderResult) : TermElabM Unit := do
|
||||
let isUnsafe := rs[0]!.view.modifiers.isUnsafe
|
||||
for r in rs do
|
||||
unless r.view.modifiers.isUnsafe == isUnsafe do
|
||||
throwErrorAt r.view.ref "invalid inductive type, cannot mix unsafe and safe declarations in a mutually inductive datatypes"
|
||||
|
||||
private def InductiveView.checkLevelNames (views : Array InductiveView) : TermElabM Unit := do
|
||||
if h : views.size > 1 then
|
||||
let levelNames := views[0].levelNames
|
||||
for view in views do
|
||||
unless view.levelNames == levelNames do
|
||||
throwErrorAt view.ref "invalid inductive type, universe parameters mismatch in mutually inductive datatypes"
|
||||
|
||||
private def ElabHeaderResult.checkLevelNames (rs : Array ElabHeaderResult) : TermElabM Unit := do
|
||||
if h : rs.size > 1 then
|
||||
let levelNames := rs[0].levelNames
|
||||
for r in rs do
|
||||
unless r.levelNames == levelNames do
|
||||
throwErrorAt r.view.ref "invalid inductive type, universe parameters mismatch in mutually inductive datatypes"
|
||||
|
||||
private def mkTypeFor (r : ElabHeaderResult) : TermElabM Expr := do
|
||||
withLCtx r.lctx r.localInsts do
|
||||
mkForallFVars r.params r.type
|
||||
|
||||
private def throwUnexpectedInductiveType : TermElabM α :=
|
||||
throwError "unexpected inductive resulting type"
|
||||
|
||||
private def eqvFirstTypeResult (firstType type : Expr) : MetaM Bool :=
|
||||
forallTelescopeReducing firstType fun _ firstTypeResult => isDefEq firstTypeResult type
|
||||
|
||||
-- Auxiliary function for checking whether the types in mutually inductive declaration are compatible.
|
||||
private partial def checkParamsAndResultType (type firstType : Expr) (numParams : Nat) : TermElabM Unit := do
|
||||
try
|
||||
forallTelescopeCompatible type firstType numParams fun _ type firstType =>
|
||||
forallTelescopeReducing type fun _ type =>
|
||||
forallTelescopeReducing firstType fun _ firstType => do
|
||||
let type ← whnfD type
|
||||
match type with
|
||||
| .sort .. =>
|
||||
unless (← isDefEq firstType type) do
|
||||
throwError "resulting universe mismatch, given{indentExpr type}\nexpected type{indentExpr firstType}"
|
||||
| _ =>
|
||||
throwError "unexpected inductive resulting type"
|
||||
catch
|
||||
| Exception.error ref msg => throw (Exception.error ref m!"invalid mutually inductive types, {msg}")
|
||||
| ex => throw ex
|
||||
|
||||
-- Auxiliary function for checking whether the types in mutually inductive declaration are compatible.
|
||||
private def checkHeader (r : ElabHeaderResult) (numParams : Nat) (firstType? : Option Expr) : TermElabM Expr := do
|
||||
let type ← mkTypeFor r
|
||||
match firstType? with
|
||||
| none => return type
|
||||
| some firstType =>
|
||||
withRef r.view.ref <| checkParamsAndResultType type firstType numParams
|
||||
return firstType
|
||||
|
||||
-- Auxiliary function for checking whether the types in mutually inductive declaration are compatible.
|
||||
private partial def checkHeaders (rs : Array ElabHeaderResult) (numParams : Nat) (i : Nat) (firstType? : Option Expr) : TermElabM Unit := do
|
||||
if h : i < rs.size then
|
||||
let type ← checkHeader rs[i] numParams firstType?
|
||||
checkHeaders rs numParams (i+1) type
|
||||
|
||||
private def elabHeader (views : Array InductiveView) : TermElabM (Array ElabHeaderResult) := do
|
||||
let rs ← elabHeaderAux views 0 #[]
|
||||
if rs.size > 1 then
|
||||
checkUnsafe rs
|
||||
let numParams ← checkNumParams rs
|
||||
checkHeaders rs numParams 0 none
|
||||
return rs
|
||||
|
||||
/-- Create a local declaration for each inductive type in `rs`, and execute `x params indFVars`, where `params` are the inductive type parameters and
|
||||
`indFVars` are the new local declarations.
|
||||
We use the local context/instances and parameters of rs[0].
|
||||
Note that this method is executed after we executed `checkHeaders` and established all
|
||||
parameters are compatible. -/
|
||||
private partial def withInductiveLocalDecls (rs : Array ElabHeaderResult) (x : Array Expr → Array Expr → TermElabM α) : TermElabM α := do
|
||||
let namesAndTypes ← rs.mapM fun r => do
|
||||
let type ← mkTypeFor r
|
||||
pure (r.view.declName, r.view.shortDeclName, type)
|
||||
let r0 := rs[0]!
|
||||
let params := r0.params
|
||||
withLCtx r0.lctx r0.localInsts <| withRef r0.view.ref do
|
||||
let rec loop (i : Nat) (indFVars : Array Expr) := do
|
||||
if h : i < namesAndTypes.size then
|
||||
let (declName, shortDeclName, type) := namesAndTypes[i]
|
||||
Term.withAuxDecl shortDeclName type declName fun indFVar => loop (i+1) (indFVars.push indFVar)
|
||||
else
|
||||
x params indFVars
|
||||
loop 0 #[]
|
||||
|
||||
private def isInductiveFamily (numParams : Nat) (indFVar : Expr) : TermElabM Bool := do
|
||||
let indFVarType ← inferType indFVar
|
||||
forallTelescopeReducing indFVarType fun xs _ =>
|
||||
@@ -271,8 +81,8 @@ where
|
||||
| _ => acc
|
||||
|
||||
/--
|
||||
Replace binder names in `type` with `newNames`.
|
||||
Remark: we only replace the names for binder containing macroscopes.
|
||||
Replaces binder names in `type` with `newNames`.
|
||||
Remark: we only replace the names for binder containing macroscopes.
|
||||
-/
|
||||
private def replaceArrowBinderNames (type : Expr) (newNames : Array Name) : Expr :=
|
||||
go type 0
|
||||
@@ -290,20 +100,20 @@ where
|
||||
type
|
||||
|
||||
/--
|
||||
Reorder constructor arguments to improve the effectiveness of the `fixedIndicesToParams` method.
|
||||
Reorders constructor arguments to improve the effectiveness of the `fixedIndicesToParams` method.
|
||||
|
||||
The idea is quite simple. Given a constructor type of the form
|
||||
```
|
||||
(a₁ : A₁) → ... → (aₙ : Aₙ) → C b₁ ... bₘ
|
||||
```
|
||||
We try to find the longest prefix `b₁ ... bᵢ`, `i ≤ m` s.t.
|
||||
- each `bₖ` is in `{a₁, ..., aₙ}`
|
||||
- each `bₖ` only depends on variables in `{b₁, ..., bₖ₋₁}`
|
||||
The idea is quite simple. Given a constructor type of the form
|
||||
```
|
||||
(a₁ : A₁) → ... → (aₙ : Aₙ) → C b₁ ... bₘ
|
||||
```
|
||||
We try to find the longest prefix `b₁ ... bᵢ`, `i ≤ m` s.t.
|
||||
- each `bₖ` is in `{a₁, ..., aₙ}`
|
||||
- each `bₖ` only depends on variables in `{b₁, ..., bₖ₋₁}`
|
||||
|
||||
Then, it moves this prefix `b₁ ... bᵢ` to the front.
|
||||
Then, it moves this prefix `b₁ ... bᵢ` to the front.
|
||||
|
||||
Remark: We only reorder implicit arguments that have macroscopes. See issue #1156.
|
||||
The macroscope test is an approximation, we could have restricted ourselves to auto-implicit arguments.
|
||||
Remark: We only reorder implicit arguments that have macroscopes. See issue #1156.
|
||||
The macroscope test is an approximation, we could have restricted ourselves to auto-implicit arguments.
|
||||
-/
|
||||
private def reorderCtorArgs (ctorType : Expr) : MetaM Expr := do
|
||||
forallTelescopeReducing ctorType fun as type => do
|
||||
@@ -348,16 +158,6 @@ private def reorderCtorArgs (ctorType : Expr) : MetaM Expr := do
|
||||
let binderNames := getArrowBinderNames (← instantiateMVars (← inferType C))
|
||||
return replaceArrowBinderNames r binderNames[:bsPrefix.size]
|
||||
|
||||
/--
|
||||
Execute `k` with updated binder information for `xs`. Any `x` that is explicit becomes implicit.
|
||||
-/
|
||||
private def withExplicitToImplicit (xs : Array Expr) (k : TermElabM α) : TermElabM α := do
|
||||
let mut toImplicit := #[]
|
||||
for x in xs do
|
||||
if (← getFVarLocalDecl x).binderInfo.isExplicit then
|
||||
toImplicit := toImplicit.push (x.fvarId!, BinderInfo.implicit)
|
||||
withNewBinderInfos toImplicit k
|
||||
|
||||
/--
|
||||
Elaborate constructor types.
|
||||
|
||||
@@ -366,17 +166,20 @@ private def withExplicitToImplicit (xs : Array Expr) (k : TermElabM α) : TermEl
|
||||
- Positivity (it is a rare failure, and the kernel already checks for it).
|
||||
- Universe constraints (the kernel checks for it).
|
||||
-/
|
||||
private def elabCtors (indFVars : Array Expr) (indFVar : Expr) (params : Array Expr) (r : ElabHeaderResult) : TermElabM (List Constructor) := withRef r.view.ref do
|
||||
private def elabCtors (indFVars : Array Expr) (params : Array Expr) (r : ElabHeaderResult) : TermElabM (List Constructor) :=
|
||||
withRef r.view.ref do
|
||||
withExplicitToImplicit params do
|
||||
let indFVar := r.indFVar
|
||||
let indFamily ← isInductiveFamily params.size indFVar
|
||||
r.view.ctors.toList.mapM fun ctorView =>
|
||||
Term.withAutoBoundImplicit <| Term.elabBinders ctorView.binders.getArgs fun ctorParams =>
|
||||
withRef ctorView.ref do
|
||||
let rec elabCtorType (k : Expr → TermElabM Constructor) : TermElabM Constructor := do
|
||||
let elabCtorType : TermElabM Expr := do
|
||||
match ctorView.type? with
|
||||
| none =>
|
||||
if indFamily then
|
||||
throwError "constructor resulting type must be specified in inductive family declaration"
|
||||
k <| mkAppN indFVar params
|
||||
return mkAppN indFVar params
|
||||
| some ctorType =>
|
||||
let type ← Term.elabType ctorType
|
||||
trace[Elab.inductive] "elabType {ctorView.declName} : {type} "
|
||||
@@ -388,43 +191,43 @@ private def elabCtors (indFVars : Array Expr) (indFVar : Expr) (params : Array E
|
||||
throwError "unexpected constructor resulting type{indentExpr resultingType}"
|
||||
unless (← isType resultingType) do
|
||||
throwError "unexpected constructor resulting type, type expected{indentExpr resultingType}"
|
||||
k type
|
||||
elabCtorType fun type => do
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let ctorParams ← Term.addAutoBoundImplicits ctorParams
|
||||
let except (mvarId : MVarId) := ctorParams.any fun ctorParam => ctorParam.isMVar && ctorParam.mvarId! == mvarId
|
||||
/-
|
||||
We convert metavariables in the resulting type info extra parameters. Otherwise, we would not be able to elaborate
|
||||
declarations such as
|
||||
```
|
||||
inductive Palindrome : List α → Prop where
|
||||
| nil : Palindrome [] -- We would get an error here saying "failed to synthesize implicit argument" at `@List.nil ?m`
|
||||
| single : (a : α) → Palindrome [a]
|
||||
| sandwich : (a : α) → Palindrome as → Palindrome ([a] ++ as ++ [a])
|
||||
```
|
||||
We used to also collect unassigned metavariables on `ctorParams`, but it produced counterintuitive behavior.
|
||||
For example, the following declaration used to be accepted.
|
||||
```
|
||||
inductive Foo
|
||||
| bar (x)
|
||||
return type
|
||||
let type ← elabCtorType
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let ctorParams ← Term.addAutoBoundImplicits ctorParams
|
||||
let except (mvarId : MVarId) := ctorParams.any fun ctorParam => ctorParam.isMVar && ctorParam.mvarId! == mvarId
|
||||
/-
|
||||
We convert metavariables in the resulting type into extra parameters. Otherwise, we would not be able to elaborate
|
||||
declarations such as
|
||||
```
|
||||
inductive Palindrome : List α → Prop where
|
||||
| nil : Palindrome [] -- We would get an error here saying "failed to synthesize implicit argument" at `@List.nil ?m`
|
||||
| single : (a : α) → Palindrome [a]
|
||||
| sandwich : (a : α) → Palindrome as → Palindrome ([a] ++ as ++ [a])
|
||||
```
|
||||
We used to also collect unassigned metavariables on `ctorParams`, but it produced counterintuitive behavior.
|
||||
For example, the following declaration used to be accepted.
|
||||
```
|
||||
inductive Foo
|
||||
| bar (x)
|
||||
|
||||
#check Foo.bar
|
||||
-- @Foo.bar : {x : Sort u_1} → x → Foo
|
||||
```
|
||||
which is also inconsistent with the behavior of auto implicits in definitions. For example, the following example was never accepted.
|
||||
```
|
||||
def bar (x) := 1
|
||||
```
|
||||
-/
|
||||
let extraCtorParams ← Term.collectUnassignedMVars (← instantiateMVars type) #[] except
|
||||
trace[Elab.inductive] "extraCtorParams: {extraCtorParams}"
|
||||
/- We must abstract `extraCtorParams` and `ctorParams` simultaneously to make
|
||||
sure we do not create auxiliary metavariables. -/
|
||||
let type ← mkForallFVars (extraCtorParams ++ ctorParams) type
|
||||
let type ← reorderCtorArgs type
|
||||
let type ← mkForallFVars params type
|
||||
trace[Elab.inductive] "{ctorView.declName} : {type}"
|
||||
return { name := ctorView.declName, type }
|
||||
#check Foo.bar
|
||||
-- @Foo.bar : {x : Sort u_1} → x → Foo
|
||||
```
|
||||
which is also inconsistent with the behavior of auto implicits in definitions. For example, the following example was never accepted.
|
||||
```
|
||||
def bar (x) := 1
|
||||
```
|
||||
-/
|
||||
let extraCtorParams ← Term.collectUnassignedMVars (← instantiateMVars type) #[] except
|
||||
trace[Elab.inductive] "extraCtorParams: {extraCtorParams}"
|
||||
/- We must abstract `extraCtorParams` and `ctorParams` simultaneously to make
|
||||
sure we do not create auxiliary metavariables. -/
|
||||
let type ← mkForallFVars (extraCtorParams ++ ctorParams) type
|
||||
let type ← reorderCtorArgs type
|
||||
let type ← mkForallFVars params type
|
||||
trace[Elab.inductive] "{ctorView.declName} : {type}"
|
||||
return { name := ctorView.declName, type }
|
||||
where
|
||||
checkParamOccs (ctorType : Expr) : MetaM Expr :=
|
||||
let visit (e : Expr) : MetaM TransformStep := do
|
||||
@@ -444,555 +247,15 @@ where
|
||||
return .continue
|
||||
transform ctorType (pre := visit)
|
||||
|
||||
private def getResultingUniverse : List InductiveType → TermElabM Level
|
||||
| [] => throwError "unexpected empty inductive declaration"
|
||||
| indType :: _ => forallTelescopeReducing indType.type fun _ r => do
|
||||
let r ← whnfD r
|
||||
match r with
|
||||
| Expr.sort u => return u
|
||||
| _ => throwError "unexpected inductive type resulting type{indentExpr r}"
|
||||
|
||||
/--
|
||||
Return `some ?m` if `u` is of the form `?m + k`.
|
||||
Return none if `u` does not contain universe metavariables.
|
||||
Throw exception otherwise. -/
|
||||
def shouldInferResultUniverse (u : Level) : TermElabM (Option LMVarId) := do
|
||||
let u ← instantiateLevelMVars u
|
||||
if u.hasMVar then
|
||||
match u.getLevelOffset with
|
||||
| Level.mvar mvarId => return some mvarId
|
||||
| _ =>
|
||||
throwError "cannot infer resulting universe level of inductive datatype, given level contains metavariables {mkSort u}, provide universe explicitly"
|
||||
else
|
||||
return none
|
||||
|
||||
/--
|
||||
Convert universe metavariables into new parameters. It skips `univToInfer?` (the inductive datatype resulting universe) because
|
||||
it should be inferred later using `inferResultingUniverse`.
|
||||
-/
|
||||
private def levelMVarToParam (indTypes : List InductiveType) (univToInfer? : Option LMVarId) : TermElabM (List InductiveType) :=
|
||||
indTypes.mapM fun indType => do
|
||||
let type ← levelMVarToParam' indType.type
|
||||
let ctors ← indType.ctors.mapM fun ctor => do
|
||||
let ctorType ← levelMVarToParam' ctor.type
|
||||
return { ctor with type := ctorType }
|
||||
return { indType with ctors, type }
|
||||
where
|
||||
levelMVarToParam' (type : Expr) : TermElabM Expr := do
|
||||
Term.levelMVarToParam type (except := fun mvarId => univToInfer? == some mvarId)
|
||||
|
||||
def mkResultUniverse (us : Array Level) (rOffset : Nat) (preferProp : Bool) : Level :=
|
||||
if us.isEmpty && rOffset == 0 then
|
||||
if preferProp then levelZero else levelOne
|
||||
else
|
||||
let r := Level.mkNaryMax us.toList
|
||||
if rOffset == 0 && !r.isZero && !r.isNeverZero then
|
||||
mkLevelMax r levelOne |>.normalize
|
||||
else
|
||||
r.normalize
|
||||
|
||||
/--
|
||||
Auxiliary function for `updateResultingUniverse`
|
||||
`accLevel u r rOffset` add `u` to state if it is not already there and
|
||||
it is different from the resulting universe level `r+rOffset`.
|
||||
|
||||
|
||||
If `u` is a `max`, then its components are recursively processed.
|
||||
If `u` is a `succ` and `rOffset > 0`, we process the `u`s child using `rOffset-1`.
|
||||
|
||||
This method is used to infer the resulting universe level of an inductive datatype.
|
||||
-/
|
||||
def accLevel (u : Level) (r : Level) (rOffset : Nat) : OptionT (StateT (Array Level) Id) Unit := do
|
||||
go u rOffset
|
||||
where
|
||||
go (u : Level) (rOffset : Nat) : OptionT (StateT (Array Level) Id) Unit := do
|
||||
match u, rOffset with
|
||||
| .max u v, rOffset => go u rOffset; go v rOffset
|
||||
| .imax u v, rOffset => go u rOffset; go v rOffset
|
||||
| .zero, _ => return ()
|
||||
| .succ u, rOffset+1 => go u rOffset
|
||||
| u, rOffset =>
|
||||
if rOffset == 0 && u == r then
|
||||
return ()
|
||||
else if r.occurs u then
|
||||
failure
|
||||
else if rOffset > 0 then
|
||||
failure
|
||||
else if (← get).contains u then
|
||||
return ()
|
||||
else
|
||||
modify fun us => us.push u
|
||||
|
||||
/--
|
||||
Auxiliary function for `updateResultingUniverse`
|
||||
`accLevelAtCtor ctor ctorParam r rOffset` add `u` (`ctorParam`'s universe) to state if it is not already there and
|
||||
it is different from the resulting universe level `r+rOffset`.
|
||||
|
||||
See `accLevel`.
|
||||
-/
|
||||
def accLevelAtCtor (ctor : Constructor) (ctorParam : Expr) (r : Level) (rOffset : Nat) : StateRefT (Array Level) TermElabM Unit := do
|
||||
let type ← inferType ctorParam
|
||||
let u ← instantiateLevelMVars (← getLevel type)
|
||||
match (← modifyGet fun s => accLevel u r rOffset |>.run |>.run s) with
|
||||
| some _ => pure ()
|
||||
| none =>
|
||||
let typeType ← inferType type
|
||||
let mut msg := m!"failed to compute resulting universe level of inductive datatype, constructor '{ctor.name}' has type{indentExpr ctor.type}\nparameter"
|
||||
let localDecl ← getFVarLocalDecl ctorParam
|
||||
unless localDecl.userName.hasMacroScopes do
|
||||
msg := msg ++ m!" '{ctorParam}'"
|
||||
msg := msg ++ m!" has type{indentD m!"{type} : {typeType}"}\ninductive type resulting type{indentExpr (mkSort (r.addOffset rOffset))}"
|
||||
if r.isMVar then
|
||||
msg := msg ++ "\nrecall that Lean only infers the resulting universe level automatically when there is a unique solution for the universe level constraints, consider explicitly providing the inductive type resulting universe level"
|
||||
throwError msg
|
||||
|
||||
/--
|
||||
Execute `k` using the `Syntax` reference associated with constructor `ctorName`.
|
||||
-/
|
||||
def withCtorRef [Monad m] [MonadRef m] (views : Array InductiveView) (ctorName : Name) (k : m α) : m α := do
|
||||
for view in views do
|
||||
for ctorView in view.ctors do
|
||||
if ctorView.declName == ctorName then
|
||||
return (← withRef ctorView.ref k)
|
||||
k
|
||||
|
||||
/-- Auxiliary function for `updateResultingUniverse` -/
|
||||
private partial def collectUniverses (views : Array InductiveView) (r : Level) (rOffset : Nat) (numParams : Nat) (indTypes : List InductiveType) : TermElabM (Array Level) := do
|
||||
let (_, us) ← go |>.run #[]
|
||||
return us
|
||||
where
|
||||
go : StateRefT (Array Level) TermElabM Unit :=
|
||||
indTypes.forM fun indType => indType.ctors.forM fun ctor =>
|
||||
withCtorRef views ctor.name do
|
||||
forallTelescopeReducing ctor.type fun ctorParams _ =>
|
||||
for ctorParam in ctorParams[numParams:] do
|
||||
accLevelAtCtor ctor ctorParam r rOffset
|
||||
|
||||
/--
|
||||
Decides whether the inductive type should be `Prop`-valued when the universe is not given
|
||||
and when the universe inference algorithm `collectUniverses` determines
|
||||
that the inductive type could naturally be `Prop`-valued.
|
||||
Recall: the natural universe level is the mimimum universe level for all the types of all the constructor parameters.
|
||||
|
||||
Heuristic:
|
||||
- We want `Prop` when each inductive type is a syntactic subsingleton.
|
||||
That's to say, when each inductive type has at most one constructor.
|
||||
Such types carry no data anyway.
|
||||
- Exception: if no inductive type has any constructors, these are likely stubbed-out declarations,
|
||||
so we prefer `Type` instead.
|
||||
- Exception: if each constructor has no parameters, then these are likely partially-written enumerations,
|
||||
so we prefer `Type` instead.
|
||||
-/
|
||||
private def isPropCandidate (numParams : Nat) (indTypes : List InductiveType) : MetaM Bool := do
|
||||
unless indTypes.foldl (fun n indType => max n indType.ctors.length) 0 == 1 do
|
||||
return false
|
||||
for indType in indTypes do
|
||||
for ctor in indType.ctors do
|
||||
let cparams ← forallTelescopeReducing ctor.type fun ctorParams _ => pure (ctorParams.size - numParams)
|
||||
unless cparams == 0 do
|
||||
return true
|
||||
return false
|
||||
|
||||
private def updateResultingUniverse (views : Array InductiveView) (numParams : Nat) (indTypes : List InductiveType) : TermElabM (List InductiveType) := do
|
||||
let r ← getResultingUniverse indTypes
|
||||
let rOffset : Nat := r.getOffset
|
||||
let r : Level := r.getLevelOffset
|
||||
unless r.isMVar do
|
||||
throwError "failed to compute resulting universe level of inductive datatype, provide universe explicitly: {r}"
|
||||
let us ← collectUniverses views r rOffset numParams indTypes
|
||||
trace[Elab.inductive] "updateResultingUniverse us: {us}, r: {r}, rOffset: {rOffset}"
|
||||
let rNew := mkResultUniverse us rOffset (← isPropCandidate numParams indTypes)
|
||||
assignLevelMVar r.mvarId! rNew
|
||||
indTypes.mapM fun indType => do
|
||||
let type ← instantiateMVars indType.type
|
||||
let ctors ← indType.ctors.mapM fun ctor => return { ctor with type := (← instantiateMVars ctor.type) }
|
||||
return { indType with type, ctors }
|
||||
|
||||
register_builtin_option bootstrap.inductiveCheckResultingUniverse : Bool := {
|
||||
defValue := true,
|
||||
group := "bootstrap",
|
||||
descr := "by default the `inductive`/`structure` commands report an error if the resulting universe is not zero, but may be zero for some universe parameters. Reason: unless this type is a subsingleton, it is hardly what the user wants since it can only eliminate into `Prop`. In the `Init` package, we define subsingletons, and we use this option to disable the check. This option may be deleted in the future after we improve the validator"
|
||||
}
|
||||
|
||||
def checkResultingUniverse (u : Level) : TermElabM Unit := do
|
||||
if bootstrap.inductiveCheckResultingUniverse.get (← getOptions) then
|
||||
let u ← instantiateLevelMVars u
|
||||
if !u.isZero && !u.isNeverZero then
|
||||
throwError "invalid universe polymorphic type, the resultant universe is not Prop (i.e., 0), but it may be Prop for some parameter values (solution: use 'u+1' or 'max 1 u'){indentD u}"
|
||||
|
||||
private def checkResultingUniverses (views : Array InductiveView) (numParams : Nat) (indTypes : List InductiveType) : TermElabM Unit := do
|
||||
let u := (← instantiateLevelMVars (← getResultingUniverse indTypes)).normalize
|
||||
checkResultingUniverse u
|
||||
unless u.isZero do
|
||||
indTypes.forM fun indType => indType.ctors.forM fun ctor =>
|
||||
forallTelescopeReducing ctor.type fun ctorArgs _ => do
|
||||
for ctorArg in ctorArgs[numParams:] do
|
||||
let type ← inferType ctorArg
|
||||
let v := (← instantiateLevelMVars (← getLevel type)).normalize
|
||||
let rec check (v' : Level) (u' : Level) : TermElabM Unit :=
|
||||
match v', u' with
|
||||
| .succ v', .succ u' => check v' u'
|
||||
| .mvar id, .param .. =>
|
||||
/- Special case:
|
||||
The constructor parameter `v` is at unverse level `?v+k` and
|
||||
the resulting inductive universe level is `u'+k`, where `u'` is a parameter (or zero).
|
||||
Thus, `?v := u'` is the only choice for satisfying the universe constraint `?v+k <= u'+k`.
|
||||
Note that, we still generate an error for cases where there is more than one of satisfying the constraint.
|
||||
Examples:
|
||||
-----------------------------------------------------------
|
||||
| ctor universe level | inductive datatype universe level |
|
||||
-----------------------------------------------------------
|
||||
| ?v | max u w |
|
||||
-----------------------------------------------------------
|
||||
| ?v | u + 1 |
|
||||
-----------------------------------------------------------
|
||||
-/
|
||||
assignLevelMVar id u'
|
||||
| .mvar id, .zero => assignLevelMVar id u' -- TODO: merge with previous case
|
||||
| _, _ =>
|
||||
unless u.geq v do
|
||||
let mut msg := m!"invalid universe level in constructor '{ctor.name}', parameter"
|
||||
let localDecl ← getFVarLocalDecl ctorArg
|
||||
unless localDecl.userName.hasMacroScopes do
|
||||
msg := msg ++ m!" '{ctorArg}'"
|
||||
msg := msg ++ m!" has type{indentExpr type}"
|
||||
msg := msg ++ m!"\nat universe level{indentD v}"
|
||||
msg := msg ++ m!"\nit must be smaller than or equal to the inductive datatype universe level{indentD u}"
|
||||
withCtorRef views ctor.name <| throwError msg
|
||||
check v u
|
||||
|
||||
private def collectUsed (indTypes : List InductiveType) : StateRefT CollectFVars.State MetaM Unit := do
|
||||
indTypes.forM fun indType => do
|
||||
indType.type.collectFVars
|
||||
indType.ctors.forM fun ctor =>
|
||||
ctor.type.collectFVars
|
||||
|
||||
private def removeUnused (vars : Array Expr) (indTypes : List InductiveType) : TermElabM (LocalContext × LocalInstances × Array Expr) := do
|
||||
let (_, used) ← (collectUsed indTypes).run {}
|
||||
Meta.removeUnused vars used
|
||||
|
||||
private def withUsed {α} (vars : Array Expr) (indTypes : List InductiveType) (k : Array Expr → TermElabM α) : TermElabM α := do
|
||||
let (lctx, localInsts, vars) ← removeUnused vars indTypes
|
||||
withLCtx lctx localInsts <| k vars
|
||||
|
||||
private def updateParams (vars : Array Expr) (indTypes : List InductiveType) : TermElabM (List InductiveType) :=
|
||||
indTypes.mapM fun indType => do
|
||||
let type ← mkForallFVars vars indType.type
|
||||
let ctors ← indType.ctors.mapM fun ctor => do
|
||||
let ctorType ← withExplicitToImplicit vars (mkForallFVars vars ctor.type)
|
||||
return { ctor with type := ctorType }
|
||||
return { indType with type, ctors }
|
||||
|
||||
private def collectLevelParamsInInductive (indTypes : List InductiveType) : Array Name := Id.run do
|
||||
let mut usedParams : CollectLevelParams.State := {}
|
||||
for indType in indTypes do
|
||||
usedParams := collectLevelParams usedParams indType.type
|
||||
for ctor in indType.ctors do
|
||||
usedParams := collectLevelParams usedParams ctor.type
|
||||
return usedParams.params
|
||||
|
||||
private def mkIndFVar2Const (views : Array InductiveView) (indFVars : Array Expr) (levelNames : List Name) : ExprMap Expr := Id.run do
|
||||
let levelParams := levelNames.map mkLevelParam;
|
||||
let mut m : ExprMap Expr := {}
|
||||
for h : i in [:views.size] do
|
||||
let view := views[i]
|
||||
let indFVar := indFVars[i]!
|
||||
m := m.insert indFVar (mkConst view.declName levelParams)
|
||||
return m
|
||||
|
||||
/-- Remark: `numVars <= numParams`. `numVars` is the number of context `variables` used in the inductive declaration,
|
||||
and `numParams` is `numVars` + number of explicit parameters provided in the declaration. -/
|
||||
private def replaceIndFVarsWithConsts (views : Array InductiveView) (indFVars : Array Expr) (levelNames : List Name)
|
||||
(numVars : Nat) (numParams : Nat) (indTypes : List InductiveType) : TermElabM (List InductiveType) :=
|
||||
let indFVar2Const := mkIndFVar2Const views indFVars levelNames
|
||||
indTypes.mapM fun indType => do
|
||||
let ctors ← indType.ctors.mapM fun ctor => do
|
||||
let type ← forallBoundedTelescope ctor.type numParams fun params type => do
|
||||
let type := type.replace fun e =>
|
||||
if !e.isFVar then
|
||||
none
|
||||
else match indFVar2Const[e]? with
|
||||
| none => none
|
||||
| some c => mkAppN c (params.extract 0 numVars)
|
||||
instantiateMVars (← mkForallFVars params type)
|
||||
return { ctor with type }
|
||||
return { indType with ctors }
|
||||
|
||||
private def mkAuxConstructions (views : Array InductiveView) : TermElabM Unit := do
|
||||
let env ← getEnv
|
||||
let hasEq := env.contains ``Eq
|
||||
let hasHEq := env.contains ``HEq
|
||||
let hasUnit := env.contains ``PUnit
|
||||
let hasProd := env.contains ``Prod
|
||||
for view in views do
|
||||
let n := view.declName
|
||||
mkRecOn n
|
||||
if hasUnit then mkCasesOn n
|
||||
if hasUnit && hasEq && hasHEq then mkNoConfusion n
|
||||
if hasUnit && hasProd then mkBelow n
|
||||
if hasUnit && hasProd then mkIBelow n
|
||||
for view in views do
|
||||
let n := view.declName;
|
||||
if hasUnit && hasProd then mkBRecOn n
|
||||
if hasUnit && hasProd then mkBInductionOn n
|
||||
|
||||
private def getArity (indType : InductiveType) : MetaM Nat :=
|
||||
forallTelescopeReducing indType.type fun xs _ => return xs.size
|
||||
|
||||
private def resetMaskAt (mask : Array Bool) (i : Nat) : Array Bool :=
|
||||
mask.setD i false
|
||||
|
||||
/--
|
||||
Compute a bit-mask that for `indType`. The size of the resulting array `result` is the arity of `indType`.
|
||||
The first `numParams` elements are `false` since they are parameters.
|
||||
For `i ∈ [numParams, arity)`, we have that `result[i]` if this index of the inductive family is fixed.
|
||||
-/
|
||||
private def computeFixedIndexBitMask (numParams : Nat) (indType : InductiveType) (indFVars : Array Expr) : MetaM (Array Bool) := do
|
||||
let arity ← getArity indType
|
||||
if arity ≤ numParams then
|
||||
return mkArray arity false
|
||||
else
|
||||
let maskRef ← IO.mkRef (mkArray numParams false ++ mkArray (arity - numParams) true)
|
||||
let rec go (ctors : List Constructor) : MetaM (Array Bool) := do
|
||||
match ctors with
|
||||
| [] => maskRef.get
|
||||
| ctor :: ctors =>
|
||||
forallTelescopeReducing ctor.type fun xs type => do
|
||||
let typeArgs := type.getAppArgs
|
||||
for i in [numParams:arity] do
|
||||
unless i < xs.size && xs[i]! == typeArgs[i]! do -- Remark: if we want to allow arguments to be rearranged, this test should be xs.contains typeArgs[i]
|
||||
maskRef.modify fun mask => mask.set! i false
|
||||
for x in xs[numParams:] do
|
||||
let xType ← inferType x
|
||||
let cond (e : Expr) := indFVars.any (fun indFVar => e.getAppFn == indFVar)
|
||||
xType.forEachWhere (stopWhenVisited := true) cond fun e => do
|
||||
let eArgs := e.getAppArgs
|
||||
for i in [numParams:eArgs.size] do
|
||||
if i >= typeArgs.size then
|
||||
maskRef.modify (resetMaskAt · i)
|
||||
else
|
||||
unless eArgs[i]! == typeArgs[i]! do
|
||||
maskRef.modify (resetMaskAt · i)
|
||||
/-If an index is missing in the arguments of the inductive type, then it must be non-fixed.
|
||||
Consider the following example:
|
||||
```lean
|
||||
inductive All {I : Type u} (P : I → Type v) : List I → Type (max u v) where
|
||||
| cons : P x → All P xs → All P (x :: xs)
|
||||
|
||||
inductive Iμ {I : Type u} : I → Type (max u v) where
|
||||
| mk : (i : I) → All Iμ [] → Iμ i
|
||||
```
|
||||
because `i` doesn't appear in `All Iμ []`, the index shouldn't be fixed.
|
||||
-/
|
||||
for i in [eArgs.size:arity] do
|
||||
maskRef.modify (resetMaskAt · i)
|
||||
go ctors
|
||||
go indType.ctors
|
||||
|
||||
/-- Return true iff `arrowType` is an arrow and its domain is defeq to `type` -/
|
||||
private def isDomainDefEq (arrowType : Expr) (type : Expr) : MetaM Bool := do
|
||||
if !arrowType.isForall then
|
||||
return false
|
||||
else
|
||||
/-
|
||||
We used to use `withNewMCtxDepth` to make sure we do not assign universe metavariables,
|
||||
but it was not satisfactory. For example, in declarations such as
|
||||
```
|
||||
inductive Eq : α → α → Prop where
|
||||
| refl (a : α) : Eq a a
|
||||
```
|
||||
We want the first two indices to be promoted to parameters, and this will only
|
||||
happen if we can assign universe metavariables.
|
||||
-/
|
||||
isDefEq arrowType.bindingDomain! type
|
||||
|
||||
/--
|
||||
Convert fixed indices to parameters.
|
||||
-/
|
||||
private partial def fixedIndicesToParams (numParams : Nat) (indTypes : Array InductiveType) (indFVars : Array Expr) : MetaM Nat := do
|
||||
if !inductive.autoPromoteIndices.get (← getOptions) then
|
||||
return numParams
|
||||
let masks ← indTypes.mapM (computeFixedIndexBitMask numParams · indFVars)
|
||||
trace[Elab.inductive] "masks: {masks}"
|
||||
if masks.all fun mask => !mask.contains true then
|
||||
return numParams
|
||||
-- We process just a non-fixed prefix of the indices for now. Reason: we don't want to change the order.
|
||||
-- TODO: extend it in the future. For example, it should be reasonable to change
|
||||
-- the order of indices generated by the auto implicit feature.
|
||||
let mask := masks[0]!
|
||||
forallBoundedTelescope indTypes[0]!.type numParams fun params type => do
|
||||
let otherTypes ← indTypes[1:].toArray.mapM fun indType => do whnfD (← instantiateForall indType.type params)
|
||||
let ctorTypes ← indTypes.toList.mapM fun indType => indType.ctors.mapM fun ctor => do whnfD (← instantiateForall ctor.type params)
|
||||
let typesToCheck := otherTypes.toList ++ ctorTypes.flatten
|
||||
let rec go (i : Nat) (type : Expr) (typesToCheck : List Expr) : MetaM Nat := do
|
||||
if i < mask.size then
|
||||
if !masks.all fun mask => i < mask.size && mask[i]! then
|
||||
return i
|
||||
if !type.isForall then
|
||||
return i
|
||||
let paramType := type.bindingDomain!
|
||||
if !(← typesToCheck.allM fun type => isDomainDefEq type paramType) then
|
||||
trace[Elab.inductive] "domain not def eq: {i}, {type} =?= {paramType}"
|
||||
return i
|
||||
withLocalDeclD `a paramType fun paramNew => do
|
||||
let typesToCheck ← typesToCheck.mapM fun type => whnfD (type.bindingBody!.instantiate1 paramNew)
|
||||
go (i+1) (type.bindingBody!.instantiate1 paramNew) typesToCheck
|
||||
else
|
||||
return i
|
||||
go numParams type typesToCheck
|
||||
|
||||
private def mkInductiveDecl (vars : Array Expr) (views : Array InductiveView) : TermElabM Unit := Term.withoutSavingRecAppSyntax do
|
||||
let view0 := views[0]!
|
||||
let scopeLevelNames ← Term.getLevelNames
|
||||
InductiveView.checkLevelNames views
|
||||
let allUserLevelNames := view0.levelNames
|
||||
let isUnsafe := view0.modifiers.isUnsafe
|
||||
withRef view0.ref <| Term.withLevelNames allUserLevelNames do
|
||||
let rs ← elabHeader views
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
ElabHeaderResult.checkLevelNames rs
|
||||
let allUserLevelNames := rs[0]!.levelNames
|
||||
trace[Elab.inductive] "level names: {allUserLevelNames}"
|
||||
withInductiveLocalDecls rs fun params indFVars => do
|
||||
trace[Elab.inductive] "indFVars: {indFVars}"
|
||||
let mut indTypesArray := #[]
|
||||
for h : i in [:views.size] do
|
||||
let indFVar := indFVars[i]!
|
||||
Term.addLocalVarInfo views[i].declId indFVar
|
||||
let r := rs[i]!
|
||||
/- At this point, because of `withInductiveLocalDecls`, the only fvars that are in context are the ones related to the first inductive type.
|
||||
Because of this, we need to replace the fvars present in each inductive type's header of the mutual block with those of the first inductive.
|
||||
However, some mvars may still be uninstantiated there, and might hide some of the old fvars.
|
||||
As such we first need to synthesize all possible mvars at this stage, instantiate them in the header types and only
|
||||
then replace the parameters' fvars in the header type.
|
||||
|
||||
See issue #3242 (`https://github.com/leanprover/lean4/issues/3242`)
|
||||
-/
|
||||
let type ← instantiateMVars r.type
|
||||
let type := type.replaceFVars r.params params
|
||||
let type ← mkForallFVars params type
|
||||
let ctors ← withExplicitToImplicit params (elabCtors indFVars indFVar params r)
|
||||
indTypesArray := indTypesArray.push { name := r.view.declName, type, ctors }
|
||||
let numExplicitParams ← fixedIndicesToParams params.size indTypesArray indFVars
|
||||
trace[Elab.inductive] "numExplicitParams: {numExplicitParams}"
|
||||
let indTypes := indTypesArray.toList
|
||||
let u ← getResultingUniverse indTypes
|
||||
let univToInfer? ← shouldInferResultUniverse u
|
||||
withUsed vars indTypes fun vars => do
|
||||
let numVars := vars.size
|
||||
let numParams := numVars + numExplicitParams
|
||||
let indTypes ← updateParams vars indTypes
|
||||
let indTypes ← if let some univToInfer := univToInfer? then
|
||||
updateResultingUniverse views numParams (← levelMVarToParam indTypes univToInfer)
|
||||
else
|
||||
checkResultingUniverses views numParams indTypes
|
||||
levelMVarToParam indTypes none
|
||||
let usedLevelNames := collectLevelParamsInInductive indTypes
|
||||
match sortDeclLevelParams scopeLevelNames allUserLevelNames usedLevelNames with
|
||||
| .error msg => throwError msg
|
||||
| .ok levelParams => do
|
||||
let indTypes ← replaceIndFVarsWithConsts views indFVars levelParams numVars numParams indTypes
|
||||
let decl := Declaration.inductDecl levelParams numParams indTypes isUnsafe
|
||||
Term.ensureNoUnassignedMVars decl
|
||||
addDecl decl
|
||||
mkAuxConstructions views
|
||||
withSaveInfoContext do -- save new env
|
||||
for view in views do
|
||||
Term.addTermInfo' view.ref[1] (← mkConstWithLevelParams view.declName) (isBinder := true)
|
||||
for ctor in view.ctors do
|
||||
Term.addTermInfo' ctor.ref[3] (← mkConstWithLevelParams ctor.declName) (isBinder := true)
|
||||
-- We need to invoke `applyAttributes` because `class` is implemented as an attribute.
|
||||
Term.applyAttributesAt view.declName view.modifiers.attrs .afterTypeChecking
|
||||
|
||||
private def applyDerivingHandlers (views : Array InductiveView) : CommandElabM Unit := do
|
||||
let mut processed : NameSet := {}
|
||||
for view in views do
|
||||
for classView in view.derivingClasses do
|
||||
let className := classView.className
|
||||
unless processed.contains className do
|
||||
processed := processed.insert className
|
||||
let mut declNames := #[]
|
||||
for view in views do
|
||||
if view.derivingClasses.any fun classView => classView.className == className then
|
||||
declNames := declNames.push view.declName
|
||||
classView.applyHandlers declNames
|
||||
|
||||
private def applyComputedFields (indViews : Array InductiveView) : CommandElabM Unit := do
|
||||
if indViews.all (·.computedFields.isEmpty) then return
|
||||
|
||||
let mut computedFields := #[]
|
||||
let mut computedFieldDefs := #[]
|
||||
for indView@{declName, ..} in indViews do
|
||||
for {ref, fieldId, type, matchAlts, modifiers, ..} in indView.computedFields do
|
||||
computedFieldDefs := computedFieldDefs.push <| ← do
|
||||
let modifiers ← match modifiers with
|
||||
| `(Lean.Parser.Command.declModifiersT| $[$doc:docComment]? $[$attrs:attributes]? $[$vis]? $[noncomputable]?) =>
|
||||
`(Lean.Parser.Command.declModifiersT| $[$doc]? $[$attrs]? $[$vis]? noncomputable)
|
||||
| _ => do
|
||||
withRef modifiers do logError "unsupported modifiers for computed field"
|
||||
`(Parser.Command.declModifiersT| noncomputable)
|
||||
`($(⟨modifiers⟩):declModifiers
|
||||
def%$ref $(mkIdent <| `_root_ ++ declName ++ fieldId):ident : $type $matchAlts:matchAlts)
|
||||
let computedFieldNames := indView.computedFields.map fun {fieldId, ..} => declName ++ fieldId
|
||||
computedFields := computedFields.push (declName, computedFieldNames)
|
||||
withScope (fun scope => { scope with
|
||||
opts := scope.opts
|
||||
|>.setBool `bootstrap.genMatcherCode false
|
||||
|>.setBool `elaboratingComputedFields true}) <|
|
||||
elabCommand <| ← `(mutual $computedFieldDefs* end)
|
||||
|
||||
liftTermElabM do Term.withDeclName indViews[0]!.declName do
|
||||
ComputedFields.setComputedFields computedFields
|
||||
|
||||
def elabInductiveViews (vars : Array Expr) (views : Array InductiveView) : TermElabM Unit := do
|
||||
let view0 := views[0]!
|
||||
let ref := view0.ref
|
||||
Term.withDeclName view0.declName do withRef ref do
|
||||
mkInductiveDecl vars views
|
||||
mkSizeOfInstances view0.declName
|
||||
Lean.Meta.IndPredBelow.mkBelow view0.declName
|
||||
for view in views do
|
||||
mkInjectiveTheorems view.declName
|
||||
|
||||
def elabInductiveViewsPostprocessing (views : Array InductiveView) : CommandElabM Unit := do
|
||||
let view0 := views[0]!
|
||||
let ref := view0.ref
|
||||
applyComputedFields views -- NOTE: any generated code before this line is invalid
|
||||
applyDerivingHandlers views
|
||||
runTermElabM fun _ => Term.withDeclName view0.declName do withRef ref do
|
||||
for view in views do
|
||||
Term.applyAttributesAt view.declName view.modifiers.attrs .afterCompilation
|
||||
|
||||
def elabInductives (inductives : Array (Modifiers × Syntax)) : CommandElabM Unit := do
|
||||
let vs ← runTermElabM fun vars => do
|
||||
let vs ← inductives.mapM fun (modifiers, stx) => inductiveSyntaxToView modifiers stx
|
||||
elabInductiveViews vars vs
|
||||
pure vs
|
||||
elabInductiveViewsPostprocessing vs
|
||||
|
||||
def elabInductive (modifiers : Modifiers) (stx : Syntax) : CommandElabM Unit := do
|
||||
elabInductives #[(modifiers, stx)]
|
||||
|
||||
def elabClassInductive (modifiers : Modifiers) (stx : Syntax) : CommandElabM Unit := do
|
||||
let modifiers := modifiers.addAttr { name := `class }
|
||||
elabInductive modifiers stx
|
||||
|
||||
/--
|
||||
Returns true if all elements of the `mutual` block (`Lean.Parser.Command.mutual`) are inductive declarations.
|
||||
-/
|
||||
def isMutualInductive (stx : Syntax) : Bool :=
|
||||
stx[1].getArgs.all fun elem =>
|
||||
let decl := elem[1]
|
||||
let declKind := decl.getKind
|
||||
declKind == `Lean.Parser.Command.inductive
|
||||
|
||||
/--
|
||||
Elaborates a `mutual` block satisfying `Lean.Elab.Command.isMutualInductive`.
|
||||
-/
|
||||
def elabMutualInductive (elems : Array Syntax) : CommandElabM Unit := do
|
||||
let inductives ← elems.mapM fun stx => do
|
||||
let modifiers ← elabModifiers ⟨stx[0]⟩
|
||||
pure (modifiers, stx[1])
|
||||
elabInductives inductives
|
||||
@[builtin_inductive_elab Lean.Parser.Command.inductive, builtin_inductive_elab Lean.Parser.Command.classInductive]
|
||||
def elabInductiveCommand : InductiveElabDescr where
|
||||
mkInductiveView (modifiers : Modifiers) (stx : Syntax) := do
|
||||
let view ← inductiveSyntaxToView modifiers stx
|
||||
return {
|
||||
view
|
||||
elabCtors := fun rs r params => do
|
||||
let ctors ← elabCtors (rs.map (·.indFVar)) params r
|
||||
return { ctors }
|
||||
}
|
||||
|
||||
end Lean.Elab.Command
|
||||
|
||||
@@ -840,8 +840,8 @@ private def mkLetRecClosures (sectionVars : Array Expr) (mainFVarIds : Array FVa
|
||||
abbrev Replacement := FVarIdMap Expr
|
||||
|
||||
def insertReplacementForMainFns (r : Replacement) (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainFVars : Array Expr) : Replacement :=
|
||||
mainFVars.size.fold (init := r) fun i r =>
|
||||
r.insert mainFVars[i]!.fvarId! (mkAppN (Lean.mkConst mainHeaders[i]!.declName) sectionVars)
|
||||
mainFVars.size.fold (init := r) fun i _ r =>
|
||||
r.insert mainFVars[i].fvarId! (mkAppN (Lean.mkConst mainHeaders[i]!.declName) sectionVars)
|
||||
|
||||
|
||||
def insertReplacementForLetRecs (r : Replacement) (letRecClosures : List LetRecClosure) : Replacement :=
|
||||
@@ -871,8 +871,8 @@ def Replacement.apply (r : Replacement) (e : Expr) : Expr :=
|
||||
|
||||
def pushMain (preDefs : Array PreDefinition) (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainVals : Array Expr)
|
||||
: TermElabM (Array PreDefinition) :=
|
||||
mainHeaders.size.foldM (init := preDefs) fun i preDefs => do
|
||||
let header := mainHeaders[i]!
|
||||
mainHeaders.size.foldM (init := preDefs) fun i _ preDefs => do
|
||||
let header := mainHeaders[i]
|
||||
let termination ← declValToTerminationHint header.value
|
||||
let termination := termination.rememberExtraParams header.numParams mainVals[i]!
|
||||
let value ← mkLambdaFVars sectionVars mainVals[i]!
|
||||
|
||||
1041
src/Lean/Elab/MutualInductive.lean
Normal file
1041
src/Lean/Elab/MutualInductive.lean
Normal file
File diff suppressed because it is too large
Load Diff
@@ -332,9 +332,9 @@ where
|
||||
else
|
||||
let accessible := isNextArgAccessible ctx
|
||||
let (d, ctx) := getNextParam ctx
|
||||
match ctx.namedArgs.findIdx? fun namedArg => namedArg.name == d.1 with
|
||||
match ctx.namedArgs.findFinIdx? fun namedArg => namedArg.name == d.1 with
|
||||
| some idx =>
|
||||
let arg := ctx.namedArgs[idx]!
|
||||
let arg := ctx.namedArgs[idx]
|
||||
let ctx := { ctx with namedArgs := ctx.namedArgs.eraseIdx idx }
|
||||
let ctx ← pushNewArg accessible ctx arg.val
|
||||
processCtorAppContext ctx
|
||||
|
||||
@@ -292,9 +292,9 @@ def mkBrecOnApp (positions : Positions) (fnIdx : Nat) (brecOnConst : Nat → Exp
|
||||
let packedFTypes ← inferArgumentTypesN positions.size brecOn
|
||||
let packedFArgs ← positions.mapMwith PProdN.mkLambdas packedFTypes FArgs
|
||||
let brecOn := mkAppN brecOn packedFArgs
|
||||
let some poss := positions.find? (·.contains fnIdx)
|
||||
let some (size, idx) := positions.findSome? fun pos => (pos.size, ·) <$> pos.indexOf? fnIdx
|
||||
| throwError "mkBrecOnApp: Could not find {fnIdx} in {positions}"
|
||||
let brecOn ← PProdN.proj poss.size (poss.getIdx? fnIdx).get! brecOn
|
||||
let brecOn ← PProdN.proj size idx brecOn
|
||||
mkLambdaFVars ys (mkAppN brecOn otherArgs)
|
||||
|
||||
end Lean.Elab.Structural
|
||||
|
||||
@@ -40,7 +40,7 @@ private partial def post (fixedPrefix : Nat) (argsPacker : ArgsPacker) (funNames
|
||||
return TransformStep.done e
|
||||
let declName := f.constName!
|
||||
let us := f.constLevels!
|
||||
if let some fidx := funNames.getIdx? declName then
|
||||
if let some fidx := funNames.indexOf? declName then
|
||||
let arity := fixedPrefix + argsPacker.varNamess[fidx]!.size
|
||||
let e' ← withAppN arity e fun args => do
|
||||
let fixedArgs := args[:fixedPrefix]
|
||||
|
||||
@@ -58,7 +58,7 @@ partial def mkTuple : Array Syntax → TermElabM Syntax
|
||||
| #[] => `(Unit.unit)
|
||||
| #[e] => return e
|
||||
| es => do
|
||||
let stx ← mkTuple (es.eraseIdx 0)
|
||||
let stx ← mkTuple (es.eraseIdxIfInBounds 0)
|
||||
`(Prod.mk $(es[0]!) $stx)
|
||||
|
||||
def resolveSectionVariable (sectionVars : NameMap Name) (id : Name) : List (Name × List String) :=
|
||||
|
||||
@@ -4,22 +4,15 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.Class
|
||||
import Lean.Parser.Command
|
||||
import Lean.Meta.Closure
|
||||
import Lean.Meta.SizeOf
|
||||
import Lean.Meta.Injective
|
||||
import Lean.Meta.Structure
|
||||
import Lean.Meta.AppBuilder
|
||||
import Lean.Elab.Command
|
||||
import Lean.Elab.DeclModifiers
|
||||
import Lean.Elab.DeclUtil
|
||||
import Lean.Elab.Inductive
|
||||
import Lean.Elab.DeclarationRange
|
||||
import Lean.Elab.Binders
|
||||
import Lean.Elab.MutualInductive
|
||||
|
||||
namespace Lean.Elab.Command
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Elab.structure
|
||||
registerTraceClass `Elab.structure.resolutionOrder
|
||||
|
||||
register_builtin_option structureDiamondWarning : Bool := {
|
||||
defValue := false
|
||||
descr := "if true, enable warnings when a structure has diamond inheritance"
|
||||
@@ -39,13 +32,6 @@ leading_parser (structureTk <|> classTk) >> declId >> many Term.bracketedBinder
|
||||
```
|
||||
-/
|
||||
|
||||
structure StructCtorView where
|
||||
ref : Syntax
|
||||
modifiers : Modifiers
|
||||
name : Name
|
||||
declName : Name
|
||||
deriving Inhabited
|
||||
|
||||
structure StructFieldView where
|
||||
ref : Syntax
|
||||
modifiers : Modifiers
|
||||
@@ -61,22 +47,15 @@ structure StructFieldView where
|
||||
type? : Option Syntax
|
||||
value? : Option Syntax
|
||||
|
||||
structure StructView where
|
||||
ref : Syntax
|
||||
declId : Syntax
|
||||
modifiers : Modifiers
|
||||
isClass : Bool -- struct-only
|
||||
shortDeclName : Name
|
||||
declName : Name
|
||||
levelNames : List Name
|
||||
binders : Syntax
|
||||
type : Syntax -- modified (inductive has type?)
|
||||
parents : Array Syntax -- struct-only
|
||||
ctor : StructCtorView -- struct-only
|
||||
fields : Array StructFieldView -- struct-only
|
||||
derivingClasses : Array DerivingClassView
|
||||
structure StructView extends InductiveView where
|
||||
parents : Array Syntax
|
||||
fields : Array StructFieldView
|
||||
deriving Inhabited
|
||||
|
||||
def StructView.ctor : StructView → CtorView
|
||||
| { ctors := #[ctor], ..} => ctor
|
||||
| _ => unreachable!
|
||||
|
||||
structure StructParentInfo where
|
||||
ref : Syntax
|
||||
fvar? : Option Expr
|
||||
@@ -102,18 +81,6 @@ structure StructFieldInfo where
|
||||
value? : Option Expr := none
|
||||
deriving Inhabited, Repr
|
||||
|
||||
structure ElabStructHeaderResult where
|
||||
view : StructView
|
||||
lctx : LocalContext
|
||||
localInsts : LocalInstances
|
||||
levelNames : List Name
|
||||
params : Array Expr
|
||||
type : Expr
|
||||
parents : Array StructParentInfo
|
||||
/-- Field infos from parents. -/
|
||||
parentFieldInfos : Array StructFieldInfo
|
||||
deriving Inhabited
|
||||
|
||||
def StructFieldInfo.isFromParent (info : StructFieldInfo) : Bool :=
|
||||
match info.kind with
|
||||
| StructFieldKind.fromParent => true
|
||||
@@ -130,12 +97,12 @@ The structure constructor syntax is
|
||||
leading_parser try (declModifiers >> ident >> " :: ")
|
||||
```
|
||||
-/
|
||||
private def expandCtor (structStx : Syntax) (structModifiers : Modifiers) (structDeclName : Name) : TermElabM StructCtorView := do
|
||||
private def expandCtor (structStx : Syntax) (structModifiers : Modifiers) (structDeclName : Name) : TermElabM CtorView := do
|
||||
let useDefault := do
|
||||
let declName := structDeclName ++ defaultCtorName
|
||||
let ref := structStx[1].mkSynthetic
|
||||
addDeclarationRangesFromSyntax declName ref
|
||||
pure { ref, modifiers := default, name := defaultCtorName, declName }
|
||||
pure { ref, declId := ref, modifiers := default, declName }
|
||||
if structStx[5].isNone then
|
||||
useDefault
|
||||
else
|
||||
@@ -156,7 +123,7 @@ private def expandCtor (structStx : Syntax) (structModifiers : Modifiers) (struc
|
||||
let declName ← applyVisibility ctorModifiers.visibility declName
|
||||
addDocString' declName ctorModifiers.docString?
|
||||
addDeclarationRangesFromSyntax declName ctor[1]
|
||||
pure { ref := ctor[1], name, modifiers := ctorModifiers, declName }
|
||||
pure { ref := ctor[1], declId := ctor[1], modifiers := ctorModifiers, declName }
|
||||
|
||||
def checkValidFieldModifier (modifiers : Modifiers) : TermElabM Unit := do
|
||||
if modifiers.isNoncomputable then
|
||||
@@ -271,7 +238,7 @@ def structureSyntaxToView (modifiers : Modifiers) (stx : Syntax) : TermElabM Str
|
||||
let parents := if exts.isNone then #[] else exts[0][1].getSepArgs
|
||||
let optType := stx[4]
|
||||
let derivingClasses ← getOptDerivingClasses stx[6]
|
||||
let type ← if optType.isNone then `(Sort _) else pure optType[0][1]
|
||||
let type? := if optType.isNone then none else some optType[0][1]
|
||||
let ctor ← expandCtor stx modifiers declName
|
||||
let fields ← expandFields stx modifiers declName
|
||||
fields.forM fun field => do
|
||||
@@ -287,10 +254,13 @@ def structureSyntaxToView (modifiers : Modifiers) (stx : Syntax) : TermElabM Str
|
||||
declName
|
||||
levelNames
|
||||
binders
|
||||
type
|
||||
type?
|
||||
allowIndices := false
|
||||
allowSortPolymorphism := false
|
||||
ctors := #[ctor]
|
||||
parents
|
||||
ctor
|
||||
fields
|
||||
computedFields := #[]
|
||||
derivingClasses
|
||||
}
|
||||
|
||||
@@ -536,7 +506,7 @@ private partial def mkToParentName (parentStructName : Name) (p : Name → Bool)
|
||||
if p curr then curr else go (i+1)
|
||||
go 1
|
||||
|
||||
private partial def elabParents (view : StructView)
|
||||
private partial def withParents (view : StructView) (rs : Array ElabHeaderResult) (indFVar : Expr)
|
||||
(k : Array StructFieldInfo → Array StructParentInfo → TermElabM α) : TermElabM α := do
|
||||
go 0 #[] #[]
|
||||
where
|
||||
@@ -544,11 +514,17 @@ where
|
||||
if h : i < view.parents.size then
|
||||
let parent := view.parents[i]
|
||||
withRef parent do
|
||||
let type ← Term.elabType parent
|
||||
-- The only use case for autobound implicits for parents might be outParams, but outParam is not propagated.
|
||||
let type ← Term.withoutAutoBoundImplicit <| Term.elabType parent
|
||||
let parentType ← whnf type
|
||||
if parentType.getAppFn == indFVar then
|
||||
logWarning "structure extends itself, skipping"
|
||||
return ← go (i + 1) infos parents
|
||||
if rs.any (fun r => r.indFVar == parentType.getAppFn) then
|
||||
throwError "structure cannot extend types defined in the same mutual block"
|
||||
let parentStructName ← getStructureName parentType
|
||||
if parents.any (fun info => info.structName == parentStructName) then
|
||||
logWarningAt parent m!"duplicate parent structure '{.ofConstName parentStructName}', skipping"
|
||||
logWarning m!"duplicate parent structure '{.ofConstName parentStructName}', skipping"
|
||||
go (i + 1) infos parents
|
||||
else if let some existingFieldName ← findExistingField? infos parentStructName then
|
||||
if structureDiamondWarning.get (← getOptions) then
|
||||
@@ -570,6 +546,13 @@ where
|
||||
else
|
||||
k infos parents
|
||||
|
||||
private def registerFailedToInferFieldType (fieldName : Name) (e : Expr) (ref : Syntax) : TermElabM Unit := do
|
||||
Term.registerCustomErrorIfMVar (← instantiateMVars e) ref m!"failed to infer type of field '{.ofConstName fieldName}'"
|
||||
|
||||
private def registerFailedToInferDefaultValue (fieldName : Name) (e : Expr) (ref : Syntax) : TermElabM Unit := do
|
||||
Term.registerCustomErrorIfMVar (← instantiateMVars e) ref m!"failed to infer default value for field '{.ofConstName fieldName}'"
|
||||
Term.registerLevelMVarErrorExprInfo e ref m!"failed to infer universe levels in default value for field '{.ofConstName fieldName}'"
|
||||
|
||||
private def elabFieldTypeValue (view : StructFieldView) : TermElabM (Option Expr × Option Expr) :=
|
||||
Term.withAutoBoundImplicit <| Term.withAutoBoundImplicitForbiddenPred (fun n => view.name == n) <| Term.elabBinders view.binders.getArgs fun params => do
|
||||
match view.type? with
|
||||
@@ -581,10 +564,13 @@ private def elabFieldTypeValue (view : StructFieldView) : TermElabM (Option Expr
|
||||
-- TODO: add forbidden predicate using `shortDeclName` from `view`
|
||||
let params ← Term.addAutoBoundImplicits params
|
||||
let value ← Term.withoutAutoBoundImplicit <| Term.elabTerm valStx none
|
||||
registerFailedToInferFieldType view.name (← inferType value) view.nameId
|
||||
registerFailedToInferDefaultValue view.name value valStx
|
||||
let value ← mkLambdaFVars params value
|
||||
return (none, value)
|
||||
| some typeStx =>
|
||||
let type ← Term.elabType typeStx
|
||||
registerFailedToInferFieldType view.name type typeStx
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let params ← Term.addAutoBoundImplicits params
|
||||
match view.value? with
|
||||
@@ -593,6 +579,7 @@ private def elabFieldTypeValue (view : StructFieldView) : TermElabM (Option Expr
|
||||
return (type, none)
|
||||
| some valStx =>
|
||||
let value ← Term.withoutAutoBoundImplicit <| Term.elabTermEnsuringType valStx type
|
||||
registerFailedToInferDefaultValue view.name value valStx
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let type ← mkForallFVars params type
|
||||
let value ← mkLambdaFVars params value
|
||||
@@ -639,6 +626,7 @@ where
|
||||
valStx ← `(fun $(view.binders.getArgs)* => $valStx:term)
|
||||
let fvarType ← inferType info.fvar
|
||||
let value ← Term.elabTermEnsuringType valStx fvarType
|
||||
registerFailedToInferDefaultValue view.name value valStx
|
||||
pushInfoLeaf <| .ofFieldRedeclInfo { stx := view.ref }
|
||||
let infos := replaceFieldInfo infos { info with ref := view.nameId, value? := value }
|
||||
go (i+1) defaultValsOverridden infos
|
||||
@@ -650,113 +638,14 @@ where
|
||||
else
|
||||
k infos
|
||||
|
||||
private def getResultUniverse (type : Expr) : TermElabM Level := do
|
||||
let type ← whnf type
|
||||
match type with
|
||||
| Expr.sort u => pure u
|
||||
| _ => throwError "unexpected structure resulting type"
|
||||
|
||||
private def collectUsed (params : Array Expr) (fieldInfos : Array StructFieldInfo) : StateRefT CollectFVars.State MetaM Unit := do
|
||||
params.forM fun p => do
|
||||
let type ← inferType p
|
||||
type.collectFVars
|
||||
fieldInfos.forM fun info => do
|
||||
let fvarType ← inferType info.fvar
|
||||
fvarType.collectFVars
|
||||
match info.value? with
|
||||
| none => pure ()
|
||||
| some value => value.collectFVars
|
||||
|
||||
private def removeUnused (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo)
|
||||
: TermElabM (LocalContext × LocalInstances × Array Expr) := do
|
||||
let (_, used) ← (collectUsed params fieldInfos).run {}
|
||||
Meta.removeUnused scopeVars used
|
||||
|
||||
private def withUsed {α} (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo) (k : Array Expr → TermElabM α)
|
||||
: TermElabM α := do
|
||||
let (lctx, localInsts, vars) ← removeUnused scopeVars params fieldInfos
|
||||
withLCtx lctx localInsts <| k vars
|
||||
|
||||
private def levelMVarToParam (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo) (univToInfer? : Option LMVarId) : TermElabM (Array StructFieldInfo) := do
|
||||
levelMVarToParamFVars scopeVars
|
||||
levelMVarToParamFVars params
|
||||
fieldInfos.mapM fun info => do
|
||||
levelMVarToParamFVar info.fvar
|
||||
match info.value? with
|
||||
| none => pure info
|
||||
| some value =>
|
||||
let value ← levelMVarToParam' value
|
||||
pure { info with value? := value }
|
||||
where
|
||||
levelMVarToParam' (type : Expr) : TermElabM Expr := do
|
||||
Term.levelMVarToParam type (except := fun mvarId => univToInfer? == some mvarId)
|
||||
|
||||
levelMVarToParamFVars (fvars : Array Expr) : TermElabM Unit :=
|
||||
fvars.forM levelMVarToParamFVar
|
||||
|
||||
levelMVarToParamFVar (fvar : Expr) : TermElabM Unit := do
|
||||
let type ← inferType fvar
|
||||
discard <| levelMVarToParam' type
|
||||
|
||||
|
||||
private partial def collectUniversesFromFields (r : Level) (rOffset : Nat) (fieldInfos : Array StructFieldInfo) : TermElabM (Array Level) := do
|
||||
let (_, us) ← go |>.run #[]
|
||||
return us
|
||||
where
|
||||
go : StateRefT (Array Level) TermElabM Unit :=
|
||||
for info in fieldInfos do
|
||||
let type ← inferType info.fvar
|
||||
let u ← getLevel type
|
||||
let u ← instantiateLevelMVars u
|
||||
match (← modifyGet fun s => accLevel u r rOffset |>.run |>.run s) with
|
||||
| some _ => pure ()
|
||||
| none =>
|
||||
let typeType ← inferType type
|
||||
let mut msg := m!"failed to compute resulting universe level of structure, field '{info.declName}' has type{indentD m!"{type} : {typeType}"}\nstructure resulting type{indentExpr (mkSort (r.addOffset rOffset))}"
|
||||
if r.isMVar then
|
||||
msg := msg ++ "\nrecall that Lean only infers the resulting universe level automatically when there is a unique solution for the universe level constraints, consider explicitly providing the structure resulting universe level"
|
||||
throwError msg
|
||||
|
||||
/--
|
||||
Decides whether the structure should be `Prop`-valued when the universe is not given
|
||||
and when the universe inference algorithm `collectUniversesFromFields` determines
|
||||
that the inductive type could naturally be `Prop`-valued.
|
||||
|
||||
See `Lean.Elab.Command.isPropCandidate` for an explanation.
|
||||
Specialized to structures, the heuristic is that we prefer a `Prop` instead of a `Type` structure
|
||||
when it could be a syntactic subsingleton.
|
||||
Exception: no-field structures are `Type` since they are likely stubbed-out declarations.
|
||||
-/
|
||||
private def isPropCandidate (fieldInfos : Array StructFieldInfo) : Bool :=
|
||||
!fieldInfos.isEmpty
|
||||
|
||||
private def updateResultingUniverse (fieldInfos : Array StructFieldInfo) (type : Expr) : TermElabM Expr := do
|
||||
let r ← getResultUniverse type
|
||||
let rOffset : Nat := r.getOffset
|
||||
let r : Level := r.getLevelOffset
|
||||
unless r.isMVar do
|
||||
throwError "failed to compute resulting universe level of inductive datatype, provide universe explicitly: {r}"
|
||||
let us ← collectUniversesFromFields r rOffset fieldInfos
|
||||
trace[Elab.structure] "updateResultingUniverse us: {us}, r: {r}, rOffset: {rOffset}"
|
||||
let rNew := mkResultUniverse us rOffset (isPropCandidate fieldInfos)
|
||||
assignLevelMVar r.mvarId! rNew
|
||||
instantiateMVars type
|
||||
|
||||
private def collectLevelParamsInFVar (s : CollectLevelParams.State) (fvar : Expr) : TermElabM CollectLevelParams.State := do
|
||||
let type ← inferType fvar
|
||||
let type ← instantiateMVars type
|
||||
return collectLevelParams s type
|
||||
|
||||
private def collectLevelParamsInFVars (fvars : Array Expr) (s : CollectLevelParams.State) : TermElabM CollectLevelParams.State :=
|
||||
fvars.foldlM collectLevelParamsInFVar s
|
||||
|
||||
private def collectLevelParamsInStructure (structType : Expr) (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo)
|
||||
: TermElabM (Array Name) := do
|
||||
let s := collectLevelParams {} structType
|
||||
let s ← collectLevelParamsInFVars scopeVars s
|
||||
let s ← collectLevelParamsInFVars params s
|
||||
let s ← fieldInfos.foldlM (init := s) fun s info => collectLevelParamsInFVar s info.fvar
|
||||
return s.params
|
||||
private def collectUsedFVars (lctx : LocalContext) (localInsts : LocalInstances) (fieldInfos : Array StructFieldInfo) :
|
||||
StateRefT CollectFVars.State MetaM Unit := do
|
||||
withLCtx lctx localInsts do
|
||||
fieldInfos.forM fun info => do
|
||||
let fvarType ← inferType info.fvar
|
||||
fvarType.collectFVars
|
||||
if let some value := info.value? then
|
||||
value.collectFVars
|
||||
|
||||
private def addCtorFields (fieldInfos : Array StructFieldInfo) : Nat → Expr → TermElabM Expr
|
||||
| 0, type => pure type
|
||||
@@ -772,19 +661,29 @@ private def addCtorFields (fieldInfos : Array StructFieldInfo) : Nat → Expr
|
||||
| _ =>
|
||||
addCtorFields fieldInfos i (mkForall decl.userName decl.binderInfo decl.type type)
|
||||
|
||||
private def mkCtor (view : StructView) (levelParams : List Name) (params : Array Expr) (fieldInfos : Array StructFieldInfo) : TermElabM Constructor :=
|
||||
private def mkCtor (view : StructView) (r : ElabHeaderResult) (params : Array Expr) (fieldInfos : Array StructFieldInfo) : TermElabM Constructor :=
|
||||
withRef view.ref do
|
||||
let type := mkAppN (mkConst view.declName (levelParams.map mkLevelParam)) params
|
||||
let type := mkAppN r.indFVar params
|
||||
let type ← addCtorFields fieldInfos fieldInfos.size type
|
||||
let type ← mkForallFVars params type
|
||||
let type ← instantiateMVars type
|
||||
let type := type.inferImplicit params.size true
|
||||
pure { name := view.ctor.declName, type }
|
||||
|
||||
private partial def checkResultingUniversesForFields (fieldInfos : Array StructFieldInfo) (u : Level) : TermElabM Unit := do
|
||||
for info in fieldInfos do
|
||||
let type ← inferType info.fvar
|
||||
let v := (← instantiateLevelMVars (← getLevel type)).normalize
|
||||
unless u.geq v do
|
||||
let msg := m!"invalid universe level for field '{info.name}', has type{indentExpr type}\n\
|
||||
at universe level{indentD v}\n\
|
||||
which is not less than or equal to the structure's resulting universe level{indentD u}"
|
||||
throwErrorAt info.ref msg
|
||||
|
||||
@[extern "lean_mk_projections"]
|
||||
private opaque mkProjections (env : Environment) (structName : Name) (projs : List Name) (isClass : Bool) : Except KernelException Environment
|
||||
|
||||
private def addProjections (r : ElabStructHeaderResult) (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
|
||||
private def addProjections (r : ElabHeaderResult) (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
|
||||
if r.type.isProp then
|
||||
if let some fieldInfo ← fieldInfos.findM? (not <$> Meta.isProof ·.fvar) then
|
||||
throwErrorAt fieldInfo.ref m!"failed to generate projections for 'Prop' structure, field '{format fieldInfo.name}' is not a proof"
|
||||
@@ -795,49 +694,71 @@ private def addProjections (r : ElabStructHeaderResult) (fieldInfos : Array Stru
|
||||
|
||||
private def registerStructure (structName : Name) (infos : Array StructFieldInfo) : TermElabM Unit := do
|
||||
let fields ← infos.filterMapM fun info => do
|
||||
if info.kind == StructFieldKind.fromParent then
|
||||
return none
|
||||
else
|
||||
return some {
|
||||
fieldName := info.name
|
||||
projFn := info.declName
|
||||
binderInfo := (← getFVarLocalDecl info.fvar).binderInfo
|
||||
autoParam? := (← inferType info.fvar).getAutoParamTactic?
|
||||
subobject? := if let .subobject parentName := info.kind then parentName else none
|
||||
}
|
||||
if info.kind == StructFieldKind.fromParent then
|
||||
return none
|
||||
else
|
||||
return some {
|
||||
fieldName := info.name
|
||||
projFn := info.declName
|
||||
binderInfo := (← getFVarLocalDecl info.fvar).binderInfo
|
||||
autoParam? := (← inferType info.fvar).getAutoParamTactic?
|
||||
subobject? := if let .subobject parentName := info.kind then parentName else none
|
||||
}
|
||||
modifyEnv fun env => Lean.registerStructure env { structName, fields }
|
||||
|
||||
private def mkAuxConstructions (declName : Name) : TermElabM Unit := do
|
||||
let env ← getEnv
|
||||
let hasEq := env.contains ``Eq
|
||||
let hasHEq := env.contains ``HEq
|
||||
let hasUnit := env.contains ``PUnit
|
||||
let hasProd := env.contains ``Prod
|
||||
mkRecOn declName
|
||||
if hasUnit then mkCasesOn declName
|
||||
if hasUnit && hasEq && hasHEq then mkNoConfusion declName
|
||||
let ival ← getConstInfoInduct declName
|
||||
if ival.isRec then
|
||||
if hasUnit && hasProd then mkBelow declName
|
||||
if hasUnit && hasProd then mkIBelow declName
|
||||
if hasUnit && hasProd then mkBRecOn declName
|
||||
if hasUnit && hasProd then mkBInductionOn declName
|
||||
private def checkDefaults (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
|
||||
let mut mvars := {}
|
||||
let mut lmvars := {}
|
||||
for fieldInfo in fieldInfos do
|
||||
if let some value := fieldInfo.value? then
|
||||
let value ← instantiateMVars value
|
||||
mvars := Expr.collectMVars mvars value
|
||||
lmvars := collectLevelMVars lmvars value
|
||||
-- Log errors and ignore the failure; we later will just omit adding a default value.
|
||||
if ← Term.logUnassignedUsingErrorInfos mvars.result then
|
||||
return
|
||||
else if ← Term.logUnassignedLevelMVarsUsingErrorInfos lmvars.result then
|
||||
return
|
||||
|
||||
private def addDefaults (lctx : LocalContext) (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
|
||||
private def addDefaults (params : Array Expr) (replaceIndFVars : Expr → MetaM Expr) (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
|
||||
let lctx ← getLCtx
|
||||
/- The `lctx` and `defaultAuxDecls` are used to create the auxiliary "default value" declarations
|
||||
The parameters `params` for these definitions must be marked as implicit, and all others as explicit. -/
|
||||
let lctx :=
|
||||
params.foldl (init := lctx) fun (lctx : LocalContext) (p : Expr) =>
|
||||
if p.isFVar then
|
||||
lctx.setBinderInfo p.fvarId! BinderInfo.implicit
|
||||
else
|
||||
lctx
|
||||
let lctx :=
|
||||
fieldInfos.foldl (init := lctx) fun (lctx : LocalContext) (info : StructFieldInfo) =>
|
||||
if info.isFromParent then lctx -- `fromParent` fields are elaborated as let-decls, and are zeta-expanded when creating "default value" auxiliary functions
|
||||
else lctx.setBinderInfo info.fvar.fvarId! BinderInfo.default
|
||||
-- Make all indFVar replacements in the local context.
|
||||
let lctx ←
|
||||
lctx.foldlM (init := {}) fun lctx ldecl => do
|
||||
match ldecl with
|
||||
| .cdecl _ fvarId userName type bi k =>
|
||||
let type ← replaceIndFVars type
|
||||
return lctx.mkLocalDecl fvarId userName type bi k
|
||||
| .ldecl _ fvarId userName type value nonDep k =>
|
||||
let type ← replaceIndFVars type
|
||||
let value ← replaceIndFVars value
|
||||
return lctx.mkLetDecl fvarId userName type value nonDep k
|
||||
withLCtx lctx (← getLocalInstances) do
|
||||
fieldInfos.forM fun fieldInfo => do
|
||||
if let some value := fieldInfo.value? then
|
||||
let declName := mkDefaultFnOfProjFn fieldInfo.declName
|
||||
let type ← inferType fieldInfo.fvar
|
||||
let value ← instantiateMVars value
|
||||
if value.hasExprMVar then
|
||||
discard <| Term.logUnassignedUsingErrorInfos (← getMVars value)
|
||||
throwErrorAt fieldInfo.ref "invalid default value for field '{format fieldInfo.name}', it contains metavariables{indentExpr value}"
|
||||
/- The identity function is used as "marker". -/
|
||||
let value ← mkId value
|
||||
-- No need to compile the definition, since it is only used during elaboration.
|
||||
discard <| mkAuxDefinition declName type value (zetaDelta := true) (compile := false)
|
||||
setReducibleAttribute declName
|
||||
let type ← replaceIndFVars (← inferType fieldInfo.fvar)
|
||||
let value ← instantiateMVars (← replaceIndFVars value)
|
||||
trace[Elab.structure] "default value after 'replaceIndFVars': {indentExpr value}"
|
||||
-- If there are mvars, `checkDefaults` already logged an error.
|
||||
unless value.hasMVar || value.hasSyntheticSorry do
|
||||
/- The identity function is used as "marker". -/
|
||||
let value ← mkId value
|
||||
-- No need to compile the definition, since it is only used during elaboration.
|
||||
discard <| mkAuxDefinition declName type value (zetaDelta := true) (compile := false)
|
||||
setReducibleAttribute declName
|
||||
|
||||
/--
|
||||
Given `type` of the form `forall ... (source : A), B`, return `forall ... [source : A], B`.
|
||||
@@ -906,57 +827,6 @@ private partial def mkCoercionToCopiedParent (levelParams : List Name) (params :
|
||||
setReducibleAttribute declName
|
||||
return { structName := parentStructName, subobject := false, projFn := declName }
|
||||
|
||||
private def elabStructHeader (view : StructView) : TermElabM ElabStructHeaderResult :=
|
||||
Term.withAutoBoundImplicitForbiddenPred (fun n => view.shortDeclName == n) do
|
||||
Term.withAutoBoundImplicit do
|
||||
Term.elabBinders view.binders.getArgs fun params => do
|
||||
elabParents view fun parentFieldInfos parents => do
|
||||
let type ← Term.elabType view.type
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let u ← mkFreshLevelMVar
|
||||
unless ← isDefEq type (mkSort u) do
|
||||
throwErrorAt view.type "invalid structure type, expecting 'Type _' or 'Prop'"
|
||||
let type ← instantiateMVars (← whnf type)
|
||||
Term.addAutoBoundImplicits' params type fun params type => do
|
||||
let levelNames ← Term.getLevelNames
|
||||
trace[Elab.structure] "header params: {params}, type: {type}, levelNames: {levelNames}"
|
||||
return { lctx := (← getLCtx), localInsts := (← getLocalInstances), levelNames, params, type, view, parents, parentFieldInfos }
|
||||
|
||||
private def mkTypeFor (r : ElabStructHeaderResult) : TermElabM Expr := do
|
||||
withLCtx r.lctx r.localInsts do
|
||||
mkForallFVars r.params r.type
|
||||
|
||||
/--
|
||||
Create a local declaration for the structure and execute `x params indFVar`, where `params` are the structure's type parameters and
|
||||
`indFVar` is the new local declaration.
|
||||
-/
|
||||
private partial def withStructureLocalDecl (r : ElabStructHeaderResult) (x : Array Expr → Expr → TermElabM α) : TermElabM α := do
|
||||
let declName := r.view.declName
|
||||
let shortDeclName := r.view.shortDeclName
|
||||
let type ← mkTypeFor r
|
||||
let params := r.params
|
||||
withLCtx r.lctx r.localInsts <| withRef r.view.ref do
|
||||
Term.withAuxDecl shortDeclName type declName fun indFVar =>
|
||||
x params indFVar
|
||||
|
||||
/--
|
||||
Remark: `numVars <= numParams`.
|
||||
`numVars` is the number of context `variables` used in the declaration,
|
||||
and `numParams - numVars` is the number of parameters provided as binders in the declaration.
|
||||
-/
|
||||
private def mkInductiveType (view : StructView) (indFVar : Expr) (levelNames : List Name)
|
||||
(numVars : Nat) (numParams : Nat) (type : Expr) (ctor : Constructor) : TermElabM InductiveType := do
|
||||
let levelParams := levelNames.map mkLevelParam
|
||||
let const := mkConst view.declName levelParams
|
||||
let ctorType ← forallBoundedTelescope ctor.type numParams fun params type => do
|
||||
let type := type.replace fun e =>
|
||||
if e == indFVar then
|
||||
mkAppN const (params.extract 0 numVars)
|
||||
else
|
||||
none
|
||||
instantiateMVars (← mkForallFVars params type)
|
||||
return { name := view.declName, type := ← instantiateMVars type, ctors := [{ ctor with type := ← instantiateMVars ctorType }] }
|
||||
|
||||
/--
|
||||
Precomputes the structure's resolution order.
|
||||
Option `structure.strictResolutionOrder` controls whether to create a warning if the C3 algorithm failed.
|
||||
@@ -987,109 +857,50 @@ private def addParentInstances (parents : Array StructureParentInfo) : MetaM Uni
|
||||
for instParent in instParents do
|
||||
addInstance instParent.projFn AttributeKind.global (eval_prio default)
|
||||
|
||||
def mkStructureDecl (vars : Array Expr) (view : StructView) : TermElabM Unit := Term.withoutSavingRecAppSyntax do
|
||||
let scopeLevelNames ← Term.getLevelNames
|
||||
let isUnsafe := view.modifiers.isUnsafe
|
||||
withRef view.ref <| Term.withLevelNames view.levelNames do
|
||||
let r ← elabStructHeader view
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
withLCtx r.lctx r.localInsts do
|
||||
withStructureLocalDecl r fun params indFVar => do
|
||||
trace[Elab.structure] "indFVar: {indFVar}"
|
||||
Term.addLocalVarInfo view.declId indFVar
|
||||
withFields view.fields r.parentFieldInfos fun fieldInfos =>
|
||||
withRef view.ref do
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let type ← instantiateMVars r.type
|
||||
let u ← getResultUniverse type
|
||||
let univToInfer? ← shouldInferResultUniverse u
|
||||
withUsed vars params fieldInfos fun scopeVars => do
|
||||
let fieldInfos ← levelMVarToParam scopeVars params fieldInfos univToInfer?
|
||||
let type ← withRef view.ref do
|
||||
if univToInfer?.isSome then
|
||||
updateResultingUniverse fieldInfos type
|
||||
else
|
||||
checkResultingUniverse (← getResultUniverse type)
|
||||
pure type
|
||||
trace[Elab.structure] "type: {type}"
|
||||
let usedLevelNames ← collectLevelParamsInStructure type scopeVars params fieldInfos
|
||||
match sortDeclLevelParams scopeLevelNames r.levelNames usedLevelNames with
|
||||
| Except.error msg => throwErrorAt view.declId msg
|
||||
| Except.ok levelParams =>
|
||||
let params := scopeVars ++ params
|
||||
let ctor ← mkCtor view levelParams params fieldInfos
|
||||
let type ← mkForallFVars params type
|
||||
let type ← instantiateMVars type
|
||||
let indType ← mkInductiveType view indFVar levelParams scopeVars.size params.size type ctor
|
||||
let decl := Declaration.inductDecl levelParams params.size [indType] isUnsafe
|
||||
Term.ensureNoUnassignedMVars decl
|
||||
addDecl decl
|
||||
-- rename indFVar so that it does not shadow the actual declaration:
|
||||
let lctx := (← getLCtx).modifyLocalDecl indFVar.fvarId! fun decl => decl.setUserName .anonymous
|
||||
withLCtx lctx (← getLocalInstances) do
|
||||
addProjections r fieldInfos
|
||||
registerStructure view.declName fieldInfos
|
||||
mkAuxConstructions view.declName
|
||||
withSaveInfoContext do -- save new env
|
||||
Term.addLocalVarInfo view.ref[1] (← mkConstWithLevelParams view.declName)
|
||||
if let some _ := view.ctor.ref.getPos? (canonicalOnly := true) then
|
||||
Term.addTermInfo' view.ctor.ref (← mkConstWithLevelParams view.ctor.declName) (isBinder := true)
|
||||
for field in view.fields do
|
||||
-- may not exist if overriding inherited field
|
||||
if (← getEnv).contains field.declName then
|
||||
Term.addTermInfo' field.ref (← mkConstWithLevelParams field.declName) (isBinder := true)
|
||||
withRef view.declId do
|
||||
Term.applyAttributesAt view.declName view.modifiers.attrs AttributeApplicationTime.afterTypeChecking
|
||||
let parentInfos ← r.parents.mapM fun parent => do
|
||||
if parent.subobject then
|
||||
let some info := fieldInfos.find? (·.kind == .subobject parent.structName) | unreachable!
|
||||
pure { structName := parent.structName, subobject := true, projFn := info.declName }
|
||||
else
|
||||
mkCoercionToCopiedParent levelParams params view parent.structName parent.type
|
||||
setStructureParents view.declName parentInfos
|
||||
checkResolutionOrder view.declName
|
||||
if view.isClass then
|
||||
addParentInstances parentInfos
|
||||
|
||||
let lctx ← getLCtx
|
||||
/- The `lctx` and `defaultAuxDecls` are used to create the auxiliary "default value" declarations
|
||||
The parameters `params` for these definitions must be marked as implicit, and all others as explicit. -/
|
||||
let lctx :=
|
||||
params.foldl (init := lctx) fun (lctx : LocalContext) (p : Expr) =>
|
||||
if p.isFVar then
|
||||
lctx.setBinderInfo p.fvarId! BinderInfo.implicit
|
||||
else
|
||||
lctx
|
||||
let lctx :=
|
||||
fieldInfos.foldl (init := lctx) fun (lctx : LocalContext) (info : StructFieldInfo) =>
|
||||
if info.isFromParent then lctx -- `fromParent` fields are elaborated as let-decls, and are zeta-expanded when creating "default value" auxiliary functions
|
||||
else lctx.setBinderInfo info.fvar.fvarId! BinderInfo.default
|
||||
addDefaults lctx fieldInfos
|
||||
|
||||
|
||||
def elabStructureView (vars : Array Expr) (view : StructView) : TermElabM Unit := do
|
||||
Term.withDeclName view.declName <| withRef view.ref do
|
||||
mkStructureDecl vars view
|
||||
unless view.isClass do
|
||||
Lean.Meta.IndPredBelow.mkBelow view.declName
|
||||
mkSizeOfInstances view.declName
|
||||
mkInjectiveTheorems view.declName
|
||||
|
||||
def elabStructureViewPostprocessing (view : StructView) : CommandElabM Unit := do
|
||||
view.derivingClasses.forM fun classView => classView.applyHandlers #[view.declName]
|
||||
runTermElabM fun _ => Term.withDeclName view.declName <| withRef view.declId do
|
||||
Term.applyAttributesAt view.declName view.modifiers.attrs .afterCompilation
|
||||
|
||||
def elabStructure (modifiers : Modifiers) (stx : Syntax) : CommandElabM Unit := do
|
||||
let view ← runTermElabM fun vars => do
|
||||
@[builtin_inductive_elab Lean.Parser.Command.«structure»]
|
||||
def elabStructureCommand : InductiveElabDescr where
|
||||
mkInductiveView (modifiers : Modifiers) (stx : Syntax) := do
|
||||
let view ← structureSyntaxToView modifiers stx
|
||||
trace[Elab.structure] "view.levelNames: {view.levelNames}"
|
||||
elabStructureView vars view
|
||||
pure view
|
||||
elabStructureViewPostprocessing view
|
||||
return {
|
||||
view := view.toInductiveView
|
||||
elabCtors := fun rs r params => do
|
||||
withParents view rs r.indFVar fun parentFieldInfos parents =>
|
||||
withFields view.fields parentFieldInfos fun fieldInfos => do
|
||||
withRef view.ref do
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let lctx ← getLCtx
|
||||
let localInsts ← getLocalInstances
|
||||
let ctor ← mkCtor view r params fieldInfos
|
||||
return {
|
||||
ctors := [ctor]
|
||||
collectUsedFVars := collectUsedFVars lctx localInsts fieldInfos
|
||||
checkUniverses := fun _ u => withLCtx lctx localInsts do checkResultingUniversesForFields fieldInfos u
|
||||
finalizeTermElab := withLCtx lctx localInsts do checkDefaults fieldInfos
|
||||
prefinalize := fun _ _ _ => do
|
||||
withLCtx lctx localInsts do
|
||||
addProjections r fieldInfos
|
||||
registerStructure view.declName fieldInfos
|
||||
withSaveInfoContext do -- save new env
|
||||
for field in view.fields do
|
||||
-- may not exist if overriding inherited field
|
||||
if (← getEnv).contains field.declName then
|
||||
Term.addTermInfo' field.ref (← mkConstWithLevelParams field.declName) (isBinder := true)
|
||||
finalize := fun levelParams params replaceIndFVars => do
|
||||
let parentInfos ← parents.mapM fun parent => do
|
||||
if parent.subobject then
|
||||
let some info := fieldInfos.find? (·.kind == .subobject parent.structName) | unreachable!
|
||||
pure { structName := parent.structName, subobject := true, projFn := info.declName }
|
||||
else
|
||||
mkCoercionToCopiedParent levelParams params view parent.structName parent.type
|
||||
setStructureParents view.declName parentInfos
|
||||
checkResolutionOrder view.declName
|
||||
if view.isClass then
|
||||
addParentInstances parentInfos
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Elab.structure
|
||||
registerTraceClass `Elab.structure.resolutionOrder
|
||||
withLCtx lctx localInsts do
|
||||
addDefaults params replaceIndFVars fieldInfos
|
||||
}
|
||||
}
|
||||
|
||||
end Lean.Elab.Command
|
||||
|
||||
@@ -142,7 +142,7 @@ where
|
||||
let args := stx.getArgs
|
||||
if (← checkLeftRec stx[0]) then
|
||||
if args.size == 1 then throwErrorAt stx "invalid atomic left recursive syntax"
|
||||
let args := args.eraseIdx 0
|
||||
let args := args.eraseIdxIfInBounds 0
|
||||
let args ← args.mapM fun arg => withNestedParser do process arg
|
||||
mkParserSeq args
|
||||
else
|
||||
|
||||
@@ -58,15 +58,28 @@ where
|
||||
let some eqBVPred ← ReifiedBVPred.mkBinPred atom resValExpr atomExpr resExpr .eq | return none
|
||||
let eqBV ← ReifiedBVLogical.ofPred eqBVPred
|
||||
|
||||
let trueExpr := mkConst ``Bool.true
|
||||
let impExpr ← mkArrow (← mkEq eqDiscrExpr trueExpr) (← mkEq eqBVExpr trueExpr)
|
||||
let decideImpExpr ← mkAppOptM ``Decidable.decide #[some impExpr, none]
|
||||
let imp ← ReifiedBVLogical.mkGate eqDiscr eqBV eqDiscrExpr eqBVExpr .imp
|
||||
|
||||
let proof := do
|
||||
let evalExpr ← ReifiedBVLogical.mkEvalExpr imp.expr
|
||||
let congrProof ← imp.evalsAtAtoms
|
||||
let lemmaProof := mkApp4 (mkConst lemmaName) (toExpr lhs.width) discrExpr lhsExpr rhsExpr
|
||||
|
||||
let trueExpr := mkConst ``Bool.true
|
||||
let eqDiscrTrueExpr ← mkEq eqDiscrExpr trueExpr
|
||||
let eqBVExprTrueExpr ← mkEq eqBVExpr trueExpr
|
||||
let impExpr ← mkArrow eqDiscrTrueExpr eqBVExprTrueExpr
|
||||
-- construct a `Decidable` instance for the implication using forall_prop_decidable
|
||||
let decEqDiscrTrue := mkApp2 (mkConst ``instDecidableEqBool) eqDiscrExpr trueExpr
|
||||
let decEqBVExprTrue := mkApp2 (mkConst ``instDecidableEqBool) eqBVExpr trueExpr
|
||||
let impDecidable := mkApp4 (mkConst ``forall_prop_decidable)
|
||||
eqDiscrTrueExpr
|
||||
(.lam .anonymous eqDiscrTrueExpr eqBVExprTrueExpr .default)
|
||||
decEqDiscrTrue
|
||||
(.lam .anonymous eqDiscrTrueExpr decEqBVExprTrue .default)
|
||||
|
||||
let decideImpExpr := mkApp2 (mkConst ``Decidable.decide) impExpr impDecidable
|
||||
|
||||
return mkApp4
|
||||
(mkConst ``Std.Tactic.BVDecide.Reflect.Bool.lemma_congr)
|
||||
decideImpExpr
|
||||
|
||||
@@ -218,7 +218,7 @@ where
|
||||
let old ← snap.old?
|
||||
-- If the kind is equal, we can assume the old version was a macro as well
|
||||
guard <| old.stx.isOfKind stx.getKind
|
||||
let state ← old.val.get.data.finished.get.state?
|
||||
let state ← old.val.get.finished.get.state?
|
||||
guard <| state.term.meta.core.nextMacroScope == nextMacroScope
|
||||
-- check absence of traces; see Note [Incremental Macros]
|
||||
guard <| state.term.meta.core.traceState.traces.size == 0
|
||||
@@ -226,9 +226,10 @@ where
|
||||
return old.val.get
|
||||
Language.withAlwaysResolvedPromise fun promise => do
|
||||
-- Store new unfolding in the snapshot tree
|
||||
snap.new.resolve <| .mk {
|
||||
snap.new.resolve {
|
||||
stx := stx'
|
||||
diagnostics := .empty
|
||||
inner? := none
|
||||
finished := .pure {
|
||||
diagnostics := .empty
|
||||
state? := (← Tactic.saveState)
|
||||
@@ -240,7 +241,7 @@ where
|
||||
new := promise
|
||||
old? := do
|
||||
let old ← old?
|
||||
return ⟨old.data.stx, (← old.data.next.get? 0)⟩
|
||||
return ⟨old.stx, (← old.next.get? 0)⟩
|
||||
} }) do
|
||||
evalTactic stx'
|
||||
return
|
||||
|
||||
@@ -60,7 +60,7 @@ where
|
||||
if let some snap := (← readThe Term.Context).tacSnap? then
|
||||
if let some old := snap.old? then
|
||||
let oldParsed := old.val.get
|
||||
oldInner? := oldParsed.data.inner? |>.map (⟨oldParsed.data.stx, ·⟩)
|
||||
oldInner? := oldParsed.inner? |>.map (⟨oldParsed.stx, ·⟩)
|
||||
-- compare `stx[0]` for `finished`/`next` reuse, focus on remainder of script
|
||||
Term.withNarrowedTacticReuse (stx := stx) (fun stx => (stx[0], mkNullNode stx.getArgs[1:])) fun stxs => do
|
||||
let some snap := (← readThe Term.Context).tacSnap?
|
||||
@@ -70,10 +70,10 @@ where
|
||||
if let some old := snap.old? then
|
||||
-- `tac` must be unchanged given the narrow above; let's reuse `finished`'s state!
|
||||
let oldParsed := old.val.get
|
||||
if let some state := oldParsed.data.finished.get.state? then
|
||||
if let some state := oldParsed.finished.get.state? then
|
||||
reusableResult? := some ((), state)
|
||||
-- only allow `next` reuse in this case
|
||||
oldNext? := oldParsed.data.next.get? 0 |>.map (⟨old.stx, ·⟩)
|
||||
oldNext? := oldParsed.next.get? 0 |>.map (⟨old.stx, ·⟩)
|
||||
|
||||
-- For `tac`'s snapshot task range, disregard synthetic info as otherwise
|
||||
-- `SnapshotTree.findInfoTreeAtPos` might choose the wrong snapshot: for example, when
|
||||
@@ -89,7 +89,7 @@ where
|
||||
withAlwaysResolvedPromise fun next => do
|
||||
withAlwaysResolvedPromise fun finished => do
|
||||
withAlwaysResolvedPromise fun inner => do
|
||||
snap.new.resolve <| .mk {
|
||||
snap.new.resolve {
|
||||
desc := tac.getKind.toString
|
||||
diagnostics := .empty
|
||||
stx := tac
|
||||
|
||||
@@ -282,10 +282,11 @@ where
|
||||
-- them, eventually put each of them back in `Context.tacSnap?` in `applyAltStx`
|
||||
withAlwaysResolvedPromise fun finished => do
|
||||
withAlwaysResolvedPromises altStxs.size fun altPromises => do
|
||||
tacSnap.new.resolve <| .mk {
|
||||
tacSnap.new.resolve {
|
||||
-- save all relevant syntax here for comparison with next document version
|
||||
stx := mkNullNode altStxs
|
||||
diagnostics := .empty
|
||||
inner? := none
|
||||
finished := { range? := none, task := finished.result }
|
||||
next := altStxs.zipWith altPromises fun stx prom =>
|
||||
{ range? := stx.getRange?, task := prom.result }
|
||||
@@ -298,7 +299,7 @@ where
|
||||
let old := old.val.get
|
||||
-- use old version of `mkNullNode altsSyntax` as guard, will be compared with new
|
||||
-- version and picked apart in `applyAltStx`
|
||||
return ⟨old.data.stx, (← old.data.next[i]?)⟩
|
||||
return ⟨old.stx, (← old.next[i]?)⟩
|
||||
new := prom
|
||||
}
|
||||
finished.resolve { diagnostics := .empty, state? := (← saveState) }
|
||||
@@ -340,9 +341,9 @@ where
|
||||
for h : altStxIdx in [0:altStxs.size] do
|
||||
let altStx := altStxs[altStxIdx]
|
||||
let altName := getAltName altStx
|
||||
if let some i := alts.findIdx? (·.1 == altName) then
|
||||
if let some i := alts.findFinIdx? (·.1 == altName) then
|
||||
-- cover named alternative
|
||||
applyAltStx tacSnaps altStxIdx altStx alts[i]!
|
||||
applyAltStx tacSnaps altStxIdx altStx alts[i]
|
||||
alts := alts.eraseIdx i
|
||||
else if !alts.isEmpty && isWildcard altStx then
|
||||
-- cover all alternatives
|
||||
|
||||
@@ -40,7 +40,7 @@ private def tryAssumption (mvarId : MVarId) : MetaM (List MVarId) := do
|
||||
let ids := stx[1].getArgs.toList.map getNameOfIdent'
|
||||
liftMetaTactic fun mvarId => do
|
||||
match (← Meta.injections mvarId ids) with
|
||||
| .solved => checkUnusedIds `injections mvarId ids; return []
|
||||
| .subgoal mvarId' unusedIds => checkUnusedIds `injections mvarId unusedIds; tryAssumption mvarId'
|
||||
| .solved => checkUnusedIds `injections mvarId ids; return []
|
||||
| .subgoal mvarId' unusedIds _ => checkUnusedIds `injections mvarId unusedIds; tryAssumption mvarId'
|
||||
|
||||
end Lean.Elab.Tactic
|
||||
|
||||
@@ -230,7 +230,7 @@ instance : ToSnapshotTree TacticFinishedSnapshot where
|
||||
toSnapshotTree s := ⟨s.toSnapshot, #[]⟩
|
||||
|
||||
/-- Snapshot just before execution of a tactic. -/
|
||||
structure TacticParsedSnapshotData (TacticParsedSnapshot : Type) extends Language.Snapshot where
|
||||
structure TacticParsedSnapshot extends Language.Snapshot where
|
||||
/-- Syntax tree of the tactic, stored and compared for incremental reuse. -/
|
||||
stx : Syntax
|
||||
/-- Task for nested incrementality, if enabled for tactic. -/
|
||||
@@ -240,16 +240,9 @@ structure TacticParsedSnapshotData (TacticParsedSnapshot : Type) extends Languag
|
||||
/-- Tasks for subsequent, potentially parallel, tactic steps. -/
|
||||
next : Array (SnapshotTask TacticParsedSnapshot) := #[]
|
||||
deriving Inhabited
|
||||
|
||||
/-- State after execution of a single synchronous tactic step. -/
|
||||
inductive TacticParsedSnapshot where
|
||||
| mk (data : TacticParsedSnapshotData TacticParsedSnapshot)
|
||||
deriving Inhabited
|
||||
abbrev TacticParsedSnapshot.data : TacticParsedSnapshot → TacticParsedSnapshotData TacticParsedSnapshot
|
||||
| .mk data => data
|
||||
partial instance : ToSnapshotTree TacticParsedSnapshot where
|
||||
toSnapshotTree := go where
|
||||
go := fun ⟨s⟩ => ⟨s.toSnapshot,
|
||||
go := fun s => ⟨s.toSnapshot,
|
||||
s.inner?.toArray.map (·.map (sync := true) go) ++
|
||||
#[s.finished.map (sync := true) toSnapshotTree] ++
|
||||
s.next.map (·.map (sync := true) go)⟩
|
||||
|
||||
@@ -56,6 +56,11 @@ structure Snapshot where
|
||||
/-- General elaboration metadata produced by this step. -/
|
||||
infoTree? : Option Elab.InfoTree := none
|
||||
/--
|
||||
Trace data produced by this step. Currently used only by `trace.profiler.output`, otherwise we
|
||||
depend on the elaborator adding traces to `diagnostics` eventually.
|
||||
-/
|
||||
traces : TraceState := {}
|
||||
/--
|
||||
Whether it should be indicated to the user that a fatal error (which should be part of
|
||||
`diagnostics`) occurred that prevents processing of the remainder of the file.
|
||||
-/
|
||||
@@ -193,18 +198,13 @@ def withAlwaysResolvedPromises [Monad m] [MonadLiftT BaseIO m] [MonadFinally m]
|
||||
for asynchronously collecting information about the entirety of snapshots in the language server.
|
||||
The involved tasks may form a DAG on the `Task` dependency level but this is not captured by this
|
||||
data structure. -/
|
||||
inductive SnapshotTree where
|
||||
/-- Creates a snapshot tree node. -/
|
||||
| mk (element : Snapshot) (children : Array (SnapshotTask SnapshotTree))
|
||||
structure SnapshotTree where
|
||||
/-- The immediately available element of the snapshot tree node. -/
|
||||
element : Snapshot
|
||||
/-- The asynchronously available children of the snapshot tree node. -/
|
||||
children : Array (SnapshotTask SnapshotTree)
|
||||
deriving Inhabited
|
||||
|
||||
/-- The immediately available element of the snapshot tree node. -/
|
||||
abbrev SnapshotTree.element : SnapshotTree → Snapshot
|
||||
| mk s _ => s
|
||||
/-- The asynchronously available children of the snapshot tree node. -/
|
||||
abbrev SnapshotTree.children : SnapshotTree → Array (SnapshotTask SnapshotTree)
|
||||
| mk _ children => children
|
||||
|
||||
/-- Produces trace of given snapshot tree, synchronously waiting on all children. -/
|
||||
partial def SnapshotTree.trace (s : SnapshotTree) : CoreM Unit :=
|
||||
go none s
|
||||
|
||||
@@ -348,7 +348,7 @@ where
|
||||
if let some (some processed) ← old.processedResult.get? then
|
||||
-- ...and the edit location is after the next command (see note [Incremental Parsing])...
|
||||
if let some nextCom ← processed.firstCmdSnap.get? then
|
||||
if (← isBeforeEditPos nextCom.data.parserState.pos) then
|
||||
if (← isBeforeEditPos nextCom.parserState.pos) then
|
||||
-- ...go immediately to next snapshot
|
||||
return (← unchanged old old.stx oldSuccess.parserState)
|
||||
|
||||
@@ -470,20 +470,20 @@ where
|
||||
-- from `old`
|
||||
if let some oldNext := old.nextCmdSnap? then do
|
||||
let newProm ← IO.Promise.new
|
||||
let _ ← old.data.finishedSnap.bindIO (sync := true) fun oldFinished =>
|
||||
let _ ← old.finishedSnap.bindIO (sync := true) fun oldFinished =>
|
||||
-- also wait on old command parse snapshot as parsing is cheap and may allow for
|
||||
-- elaboration reuse
|
||||
oldNext.bindIO (sync := true) fun oldNext => do
|
||||
parseCmd oldNext newParserState oldFinished.cmdState newProm sync ctx
|
||||
return .pure ()
|
||||
prom.resolve <| .mk (data := old.data) (nextCmdSnap? := some { range? := none, task := newProm.result })
|
||||
prom.resolve <| { old with nextCmdSnap? := some { range? := none, task := newProm.result } }
|
||||
else prom.resolve old -- terminal command, we're done!
|
||||
|
||||
-- fast path, do not even start new task for this snapshot
|
||||
if let some old := old? then
|
||||
if let some nextCom ← old.nextCmdSnap?.bindM (·.get?) then
|
||||
if (← isBeforeEditPos nextCom.data.parserState.pos) then
|
||||
return (← unchanged old old.data.parserState)
|
||||
if (← isBeforeEditPos nextCom.parserState.pos) then
|
||||
return (← unchanged old old.parserState)
|
||||
|
||||
let beginPos := parserState.pos
|
||||
let scope := cmdState.scopes.head!
|
||||
@@ -500,7 +500,7 @@ where
|
||||
-- NOTE: as `parserState.pos` includes trailing whitespace, this forces reprocessing even if
|
||||
-- only that whitespace changes, which is wasteful but still necessary because it may
|
||||
-- influence the range of error messages such as from a trailing `exact`
|
||||
if stx.eqWithInfo old.data.stx then
|
||||
if stx.eqWithInfo old.stx then
|
||||
-- Here we must make sure to pass the *new* parser state; see NOTE in `unchanged`
|
||||
return (← unchanged old parserState)
|
||||
-- on first change, make sure to cancel old invocation
|
||||
@@ -515,11 +515,12 @@ where
|
||||
-- this is a bit ugly as we don't want to adjust our API with `Option`s just for cancellation
|
||||
-- (as no-one should look at this result in that case) but anything containing `Environment`
|
||||
-- is not `Inhabited`
|
||||
prom.resolve <| .mk (nextCmdSnap? := none) {
|
||||
prom.resolve <| {
|
||||
diagnostics := .empty, stx := .missing, parserState
|
||||
elabSnap := .pure <| .ofTyped { diagnostics := .empty : SnapshotLeaf }
|
||||
finishedSnap := .pure { diagnostics := .empty, cmdState }
|
||||
tacticCache := (← IO.mkRef {})
|
||||
nextCmdSnap? := none
|
||||
}
|
||||
return
|
||||
|
||||
@@ -537,29 +538,30 @@ where
|
||||
-- irrelevant in this case.
|
||||
let endRange? := stx.getTailPos?.map fun pos => ⟨pos, pos⟩
|
||||
let finishedSnap := { range? := endRange?, task := finishedPromise.result }
|
||||
let tacticCache ← old?.map (·.data.tacticCache) |>.getDM (IO.mkRef {})
|
||||
let tacticCache ← old?.map (·.tacticCache) |>.getDM (IO.mkRef {})
|
||||
|
||||
let minimalSnapshots := internal.cmdlineSnapshots.get cmdState.scopes.head!.opts
|
||||
let next? ← if Parser.isTerminalCommand stx then pure none
|
||||
-- for now, wait on "command finished" snapshot before parsing next command
|
||||
else some <$> IO.Promise.new
|
||||
let nextCmdSnap? := next?.map
|
||||
({ range? := some ⟨parserState.pos, ctx.input.endPos⟩, task := ·.result })
|
||||
let diagnostics ← Snapshot.Diagnostics.ofMessageLog msgLog
|
||||
let data := if minimalSnapshots && !Parser.isTerminalCommand stx then {
|
||||
diagnostics
|
||||
stx := .missing
|
||||
parserState := {}
|
||||
elabSnap := { range? := stx.getRange?, task := elabPromise.result }
|
||||
finishedSnap
|
||||
tacticCache
|
||||
} else {
|
||||
diagnostics, stx, parserState, tacticCache
|
||||
elabSnap := { range? := stx.getRange?, task := elabPromise.result }
|
||||
finishedSnap
|
||||
}
|
||||
prom.resolve <| .mk (nextCmdSnap? := next?.map
|
||||
({ range? := some ⟨parserState.pos, ctx.input.endPos⟩, task := ·.result })) data
|
||||
if minimalSnapshots && !Parser.isTerminalCommand stx then
|
||||
prom.resolve {
|
||||
diagnostics, finishedSnap, tacticCache, nextCmdSnap?
|
||||
stx := .missing
|
||||
parserState := {}
|
||||
elabSnap := { range? := stx.getRange?, task := elabPromise.result }
|
||||
}
|
||||
else
|
||||
prom.resolve {
|
||||
diagnostics, stx, parserState, tacticCache, nextCmdSnap?
|
||||
elabSnap := { range? := stx.getRange?, task := elabPromise.result }
|
||||
finishedSnap
|
||||
}
|
||||
let cmdState ← doElab stx cmdState beginPos
|
||||
{ old? := old?.map fun old => ⟨old.data.stx, old.data.elabSnap⟩, new := elabPromise }
|
||||
{ old? := old?.map fun old => ⟨old.stx, old.elabSnap⟩, new := elabPromise }
|
||||
finishedPromise tacticCache ctx
|
||||
if let some next := next? then
|
||||
-- We're definitely off the fast-forwarding path now
|
||||
@@ -571,7 +573,7 @@ where
|
||||
LeanProcessingM Command.State := do
|
||||
let ctx ← read
|
||||
let scope := cmdState.scopes.head!
|
||||
let cmdStateRef ← IO.mkRef { cmdState with messages := .empty }
|
||||
let cmdStateRef ← IO.mkRef { cmdState with messages := .empty, traceState := {} }
|
||||
/-
|
||||
The same snapshot may be executed by different tasks. So, to make sure `elabCommandTopLevel`
|
||||
has exclusive access to the cache, we create a fresh reference here. Before this change, the
|
||||
@@ -613,6 +615,7 @@ where
|
||||
finishedPromise.resolve {
|
||||
diagnostics := (← Snapshot.Diagnostics.ofMessageLog cmdState.messages)
|
||||
infoTree? := infoTree
|
||||
traces := cmdState.traceState
|
||||
cmdState := if cmdline then {
|
||||
env := Runtime.markPersistent cmdState.env
|
||||
maxRecDepth := 0
|
||||
@@ -644,6 +647,6 @@ where goCmd snap :=
|
||||
if let some next := snap.nextCmdSnap? then
|
||||
goCmd next.get
|
||||
else
|
||||
snap.data.finishedSnap.get.cmdState
|
||||
snap.finishedSnap.get.cmdState
|
||||
|
||||
end Lean
|
||||
|
||||
@@ -34,7 +34,7 @@ instance : ToSnapshotTree CommandFinishedSnapshot where
|
||||
toSnapshotTree s := ⟨s.toSnapshot, #[]⟩
|
||||
|
||||
/-- State after a command has been parsed. -/
|
||||
structure CommandParsedSnapshotData extends Snapshot where
|
||||
structure CommandParsedSnapshot extends Snapshot where
|
||||
/-- Syntax tree of the command. -/
|
||||
stx : Syntax
|
||||
/-- Resulting parser state. -/
|
||||
@@ -48,27 +48,14 @@ structure CommandParsedSnapshotData extends Snapshot where
|
||||
finishedSnap : SnapshotTask CommandFinishedSnapshot
|
||||
/-- Cache for `save`; to be replaced with incrementality. -/
|
||||
tacticCache : IO.Ref Tactic.Cache
|
||||
/-- Next command, unless this is a terminal command. -/
|
||||
nextCmdSnap? : Option (SnapshotTask CommandParsedSnapshot)
|
||||
deriving Nonempty
|
||||
|
||||
/-- State after a command has been parsed. -/
|
||||
-- workaround for lack of recursive structures
|
||||
inductive CommandParsedSnapshot where
|
||||
/-- Creates a command parsed snapshot. -/
|
||||
| mk (data : CommandParsedSnapshotData)
|
||||
(nextCmdSnap? : Option (SnapshotTask CommandParsedSnapshot))
|
||||
deriving Nonempty
|
||||
/-- The snapshot data. -/
|
||||
abbrev CommandParsedSnapshot.data : CommandParsedSnapshot → CommandParsedSnapshotData
|
||||
| mk data _ => data
|
||||
/-- Next command, unless this is a terminal command. -/
|
||||
abbrev CommandParsedSnapshot.nextCmdSnap? : CommandParsedSnapshot →
|
||||
Option (SnapshotTask CommandParsedSnapshot)
|
||||
| mk _ next? => next?
|
||||
partial instance : ToSnapshotTree CommandParsedSnapshot where
|
||||
toSnapshotTree := go where
|
||||
go s := ⟨s.data.toSnapshot,
|
||||
#[s.data.elabSnap.map (sync := true) toSnapshotTree,
|
||||
s.data.finishedSnap.map (sync := true) toSnapshotTree] |>
|
||||
go s := ⟨s.toSnapshot,
|
||||
#[s.elabSnap.map (sync := true) toSnapshotTree,
|
||||
s.finishedSnap.map (sync := true) toSnapshotTree] |>
|
||||
pushOpt (s.nextCmdSnap?.map (·.map (sync := true) go))⟩
|
||||
|
||||
/-- State after successful importing. -/
|
||||
|
||||
@@ -406,8 +406,8 @@ def isSubPrefixOf (lctx₁ lctx₂ : LocalContext) (exceptFVars : Array Expr :=
|
||||
|
||||
@[inline] def mkBinding (isLambda : Bool) (lctx : LocalContext) (xs : Array Expr) (b : Expr) : Expr :=
|
||||
let b := b.abstract xs
|
||||
xs.size.foldRev (init := b) fun i b =>
|
||||
let x := xs[i]!
|
||||
xs.size.foldRev (init := b) fun i _ b =>
|
||||
let x := xs[i]
|
||||
match lctx.findFVar? x with
|
||||
| some (.cdecl _ _ n ty bi _) =>
|
||||
let ty := ty.abstractRange i xs;
|
||||
@@ -457,7 +457,7 @@ def sanitizeNames (lctx : LocalContext) : StateM NameSanitizerState LocalContext
|
||||
let st ← get
|
||||
if !getSanitizeNames st.options then pure lctx else
|
||||
StateT.run' (s := ({} : NameSet)) <|
|
||||
lctx.decls.size.foldRevM (init := lctx) fun i lctx => do
|
||||
lctx.decls.size.foldRevM (init := lctx) fun i _ lctx => do
|
||||
match lctx.decls[i]! with
|
||||
| none => pure lctx
|
||||
| some decl =>
|
||||
|
||||
@@ -244,8 +244,8 @@ partial def formatAux : NamingContext → Option MessageDataContext → MessageD
|
||||
| panic! s!"MessageData.ofLazy: expected MessageData in Dynamic, got {dyn.typeName}"
|
||||
formatAux nCtx ctx? msg
|
||||
|
||||
protected def format (msgData : MessageData) : IO Format :=
|
||||
formatAux { currNamespace := Name.anonymous, openDecls := [] } none msgData
|
||||
protected def format (msgData : MessageData) (ctx? : Option MessageDataContext := none) : IO Format :=
|
||||
formatAux { currNamespace := Name.anonymous, openDecls := [] } ctx? msgData
|
||||
|
||||
protected def toString (msgData : MessageData) : IO String := do
|
||||
return toString (← msgData.format)
|
||||
|
||||
@@ -825,7 +825,7 @@ def mkFreshExprMVarWithId (mvarId : MVarId) (type? : Option Expr := none) (kind
|
||||
mkFreshExprMVarWithIdCore mvarId type kind userName
|
||||
|
||||
def mkFreshLevelMVars (num : Nat) : MetaM (List Level) :=
|
||||
num.foldM (init := []) fun _ us =>
|
||||
num.foldM (init := []) fun _ _ us =>
|
||||
return (← mkFreshLevelMVar)::us
|
||||
|
||||
def mkFreshLevelMVarsFor (info : ConstantInfo) : MetaM (List Level) :=
|
||||
|
||||
@@ -286,8 +286,8 @@ partial def process : ClosureM Unit := do
|
||||
@[inline] def mkBinding (isLambda : Bool) (decls : Array LocalDecl) (b : Expr) : Expr :=
|
||||
let xs := decls.map LocalDecl.toExpr
|
||||
let b := b.abstract xs
|
||||
decls.size.foldRev (init := b) fun i b =>
|
||||
let decl := decls[i]!
|
||||
decls.size.foldRev (init := b) fun i _ b =>
|
||||
let decl := decls[i]
|
||||
match decl with
|
||||
| .cdecl _ _ n ty bi _ =>
|
||||
let ty := ty.abstractRange i xs
|
||||
|
||||
@@ -939,29 +939,6 @@ def check (hasCtxLocals : Bool) (mctx : MetavarContext) (lctx : LocalContext) (m
|
||||
|
||||
end CheckAssignmentQuick
|
||||
|
||||
/--
|
||||
Auxiliary function used at `typeOccursCheckImp`.
|
||||
Given `type`, it tries to eliminate "dependencies". For example, suppose we are trying to
|
||||
perform the assignment `?m := f (?n a b)` where
|
||||
```
|
||||
?n : let k := g ?m; A -> h k ?m -> C
|
||||
```
|
||||
If we just perform occurs check `?m` at the type of `?n`, we get a failure, but
|
||||
we claim these occurrences are ok because the type `?n a b : C`.
|
||||
In the example above, `typeOccursCheckImp` invokes this function with `n := 2`.
|
||||
Note that we avoid using `whnf` and `inferType` at `typeOccursCheckImp` to minimize the
|
||||
performance impact of this extra check.
|
||||
|
||||
See test `typeOccursCheckIssue.lean` for an example where this refinement is needed.
|
||||
The test is derived from a Mathlib file.
|
||||
-/
|
||||
private partial def skipAtMostNumBinders (type : Expr) (n : Nat) : Expr :=
|
||||
match type, n with
|
||||
| .forallE _ _ b _, n+1 => skipAtMostNumBinders b n
|
||||
| .mdata _ b, n => skipAtMostNumBinders b n
|
||||
| .letE _ _ v b _, n => skipAtMostNumBinders (b.instantiate1 v) n
|
||||
| type, _ => type
|
||||
|
||||
/-- `typeOccursCheck` implementation using unsafe (i.e., pointer equality) features. -/
|
||||
private unsafe def typeOccursCheckImp (mctx : MetavarContext) (mvarId : MVarId) (v : Expr) : Bool :=
|
||||
if v.hasExprMVar then
|
||||
@@ -982,19 +959,11 @@ where
|
||||
-- this function assumes all assigned metavariables have already been
|
||||
-- instantiated.
|
||||
go.run' mctx
|
||||
visitMVar (mvarId' : MVarId) (numArgs : Nat := 0) : Bool :=
|
||||
visitMVar (mvarId' : MVarId) : Bool :=
|
||||
if let some mvarDecl := mctx.findDecl? mvarId' then
|
||||
occursCheck (skipAtMostNumBinders mvarDecl.type numArgs)
|
||||
occursCheck mvarDecl.type
|
||||
else
|
||||
false
|
||||
visitApp (e : Expr) : StateM (PtrSet Expr) Bool :=
|
||||
e.withApp fun f args => do
|
||||
unless (← args.allM visit) do
|
||||
return false
|
||||
if f.isMVar then
|
||||
return visitMVar f.mvarId! args.size
|
||||
else
|
||||
visit f
|
||||
visit (e : Expr) : StateM (PtrSet Expr) Bool := do
|
||||
if !e.hasExprMVar then
|
||||
return true
|
||||
@@ -1003,7 +972,7 @@ where
|
||||
else match e with
|
||||
| .mdata _ b => visit b
|
||||
| .proj _ _ s => visit s
|
||||
| .app .. => visitApp e
|
||||
| .app f a => visit f <&&> visit a
|
||||
| .lam _ d b _ => visit d <&&> visit b
|
||||
| .forallE _ d b _ => visit d <&&> visit b
|
||||
| .letE _ t v b _ => visit t <&&> visit v <&&> visit b
|
||||
|
||||
@@ -83,7 +83,7 @@ where
|
||||
forallTelescopeReducing t fun xs s => do
|
||||
let motiveType ← instantiateForall motive xs[:numParams]
|
||||
withLocalDecl motiveName BinderInfo.implicit motiveType fun motive => do
|
||||
mkForallFVars (xs.insertAt! numParams motive) s)
|
||||
mkForallFVars (xs.insertIdxIfInBounds numParams motive) s)
|
||||
|
||||
motiveType (indVal : InductiveVal) : MetaM Expr :=
|
||||
forallTelescopeReducing indVal.type fun xs _ => do
|
||||
|
||||
@@ -566,9 +566,9 @@ Append results to array
|
||||
partial def appendResultsAux (mr : MatchResult α) (a : Array β) (f : Nat → α → β) : Array β :=
|
||||
let aa := mr.elts
|
||||
let n := aa.size
|
||||
Nat.fold (n := n) (init := a) fun i r =>
|
||||
Nat.fold (n := n) (init := a) fun i _ r =>
|
||||
let j := n-1-i
|
||||
let b := aa[j]!
|
||||
let b := aa[j]
|
||||
b.foldl (init := r) (· ++ ·.map (f j))
|
||||
|
||||
partial def appendResults (mr : MatchResult α) (a : Array α) : Array α :=
|
||||
|
||||
@@ -882,7 +882,7 @@ def mkMatcher (input : MkMatcherInput) (exceptionIfContainsSorry := false) : Met
|
||||
| none => pure ()
|
||||
|
||||
trace[Meta.Match.debug] "matcher: {matcher}"
|
||||
let unusedAltIdxs := lhss.length.fold (init := []) fun i r =>
|
||||
let unusedAltIdxs := lhss.length.fold (init := []) fun i _ r =>
|
||||
if s.used.contains i then r else i::r
|
||||
return {
|
||||
matcher,
|
||||
|
||||
@@ -129,9 +129,9 @@ where
|
||||
let typeNew := b.instantiate1 y
|
||||
if let some (_, lhs, rhs) ← matchEq? d then
|
||||
if lhs.isFVar && ys.contains lhs && args.contains lhs && isNamedPatternProof typeNew y then
|
||||
let some j := ys.getIdx? lhs | unreachable!
|
||||
let some j := ys.indexOf? lhs | unreachable!
|
||||
let ys := ys.eraseIdx j
|
||||
let some k := args.getIdx? lhs | unreachable!
|
||||
let some k := args.indexOf? lhs | unreachable!
|
||||
let mask := mask.set! k false
|
||||
let args := args.map fun arg => if arg == lhs then rhs else arg
|
||||
let arg ← mkEqRefl rhs
|
||||
@@ -222,16 +222,16 @@ private def contradiction (mvarId : MVarId) : MetaM Bool :=
|
||||
Auxiliary tactic that tries to replace as many variables as possible and then apply `contradiction`.
|
||||
We use it to discard redundant hypotheses.
|
||||
-/
|
||||
partial def trySubstVarsAndContradiction (mvarId : MVarId) : MetaM Bool :=
|
||||
partial def trySubstVarsAndContradiction (mvarId : MVarId) (forbidden : FVarIdSet := {}) : MetaM Bool :=
|
||||
commitWhen do
|
||||
let mvarId ← substVars mvarId
|
||||
match (← injections mvarId) with
|
||||
match (← injections mvarId (forbidden := forbidden)) with
|
||||
| .solved => return true -- closed goal
|
||||
| .subgoal mvarId' _ =>
|
||||
| .subgoal mvarId' _ forbidden =>
|
||||
if mvarId' == mvarId then
|
||||
contradiction mvarId
|
||||
else
|
||||
trySubstVarsAndContradiction mvarId'
|
||||
trySubstVarsAndContradiction mvarId' forbidden
|
||||
|
||||
private def processNextEq : M Bool := do
|
||||
let s ← get
|
||||
@@ -375,7 +375,8 @@ private partial def withSplitterAlts (altTypes : Array Expr) (f : Array Expr →
|
||||
inductive InjectionAnyResult where
|
||||
| solved
|
||||
| failed
|
||||
| subgoal (mvarId : MVarId)
|
||||
/-- `fvarId` refers to the local declaration selected for the application of the `injection` tactic. -/
|
||||
| subgoal (fvarId : FVarId) (mvarId : MVarId)
|
||||
|
||||
private def injectionAnyCandidate? (type : Expr) : MetaM (Option (Expr × Expr)) := do
|
||||
if let some (_, lhs, rhs) ← matchEq? type then
|
||||
@@ -385,21 +386,28 @@ private def injectionAnyCandidate? (type : Expr) : MetaM (Option (Expr × Expr))
|
||||
return some (lhs, rhs)
|
||||
return none
|
||||
|
||||
private def injectionAny (mvarId : MVarId) : MetaM InjectionAnyResult := do
|
||||
/--
|
||||
Try applying `injection` to a local declaration that is not in `forbidden`.
|
||||
|
||||
We use `forbidden` because the `injection` tactic might fail to clear the variable if there are forward dependencies.
|
||||
See `proveSubgoalLoop` for additional details.
|
||||
-/
|
||||
private def injectionAny (mvarId : MVarId) (forbidden : FVarIdSet := {}) : MetaM InjectionAnyResult := do
|
||||
mvarId.withContext do
|
||||
for localDecl in (← getLCtx) do
|
||||
if let some (lhs, rhs) ← injectionAnyCandidate? localDecl.type then
|
||||
unless (← isDefEq lhs rhs) do
|
||||
let lhs ← whnf lhs
|
||||
let rhs ← whnf rhs
|
||||
unless lhs.isRawNatLit && rhs.isRawNatLit do
|
||||
try
|
||||
match (← injection mvarId localDecl.fvarId) with
|
||||
| InjectionResult.solved => return InjectionAnyResult.solved
|
||||
| InjectionResult.subgoal mvarId .. => return InjectionAnyResult.subgoal mvarId
|
||||
catch ex =>
|
||||
trace[Meta.Match.matchEqs] "injectionAnyFailed at {localDecl.userName}, error\n{ex.toMessageData}"
|
||||
pure ()
|
||||
unless forbidden.contains localDecl.fvarId do
|
||||
if let some (lhs, rhs) ← injectionAnyCandidate? localDecl.type then
|
||||
unless (← isDefEq lhs rhs) do
|
||||
let lhs ← whnf lhs
|
||||
let rhs ← whnf rhs
|
||||
unless lhs.isRawNatLit && rhs.isRawNatLit do
|
||||
try
|
||||
match (← injection mvarId localDecl.fvarId) with
|
||||
| InjectionResult.solved => return InjectionAnyResult.solved
|
||||
| InjectionResult.subgoal mvarId .. => return InjectionAnyResult.subgoal localDecl.fvarId mvarId
|
||||
catch ex =>
|
||||
trace[Meta.Match.matchEqs] "injectionAnyFailed at {localDecl.userName}, error\n{ex.toMessageData}"
|
||||
pure ()
|
||||
return InjectionAnyResult.failed
|
||||
|
||||
|
||||
@@ -601,27 +609,32 @@ where
|
||||
let eNew := mkAppN eNew mvars
|
||||
return TransformStep.done eNew
|
||||
|
||||
proveSubgoalLoop (mvarId : MVarId) : MetaM Unit := do
|
||||
/-
|
||||
`forbidden` tracks variables that we have already applied `injection`.
|
||||
Recall that the `injection` tactic may not be able to eliminate them when
|
||||
they have forward dependencies.
|
||||
-/
|
||||
proveSubgoalLoop (mvarId : MVarId) (forbidden : FVarIdSet) : MetaM Unit := do
|
||||
trace[Meta.Match.matchEqs] "proveSubgoalLoop\n{mvarId}"
|
||||
if (← mvarId.contradictionQuick) then
|
||||
return ()
|
||||
match (← injectionAny mvarId) with
|
||||
| InjectionAnyResult.solved => return ()
|
||||
| InjectionAnyResult.failed =>
|
||||
match (← injectionAny mvarId forbidden) with
|
||||
| .solved => return ()
|
||||
| .failed =>
|
||||
let mvarId' ← substVars mvarId
|
||||
if mvarId' == mvarId then
|
||||
if (← mvarId.contradictionCore {}) then
|
||||
return ()
|
||||
throwError "failed to generate splitter for match auxiliary declaration '{matchDeclName}', unsolved subgoal:\n{MessageData.ofGoal mvarId}"
|
||||
else
|
||||
proveSubgoalLoop mvarId'
|
||||
| InjectionAnyResult.subgoal mvarId => proveSubgoalLoop mvarId
|
||||
proveSubgoalLoop mvarId' forbidden
|
||||
| .subgoal fvarId mvarId => proveSubgoalLoop mvarId (forbidden.insert fvarId)
|
||||
|
||||
proveSubgoal (mvarId : MVarId) : MetaM Unit := do
|
||||
trace[Meta.Match.matchEqs] "subgoal {mkMVar mvarId}, {repr (← mvarId.getDecl).kind}, {← mvarId.isAssigned}\n{MessageData.ofGoal mvarId}"
|
||||
let (_, mvarId) ← mvarId.intros
|
||||
let mvarId ← mvarId.tryClearMany (alts.map (·.fvarId!))
|
||||
proveSubgoalLoop mvarId
|
||||
proveSubgoalLoop mvarId {}
|
||||
|
||||
/--
|
||||
Create new alternatives (aka minor premises) by replacing `discrs` with `patterns` at `alts`.
|
||||
@@ -646,7 +659,7 @@ where
|
||||
/--
|
||||
Create conditional equations and splitter for the given match auxiliary declaration. -/
|
||||
private partial def mkEquationsFor (matchDeclName : Name) : MetaM MatchEqns := withLCtx {} {} do
|
||||
trace[Meta.Match.matchEqs] "mkEquationsFor '{matchDeclName}'"
|
||||
withTraceNode `Meta.Match.matchEqs (fun _ => return m!"mkEquationsFor '{matchDeclName}'") do
|
||||
withConfig (fun c => { c with etaStruct := .none }) do
|
||||
/-
|
||||
Remark: user have requested the `split` tactic to be available for writing code.
|
||||
|
||||
@@ -59,9 +59,9 @@ def addArg (matcherApp : MatcherApp) (e : Expr) : MetaM MatcherApp :=
|
||||
-- This error can only happen if someone implemented a transformation that rewrites the motive created by `mkMatcher`.
|
||||
throwError "unexpected matcher application, motive must be lambda expression with #{matcherApp.discrs.size} arguments"
|
||||
let eType ← inferType e
|
||||
let eTypeAbst ← matcherApp.discrs.size.foldRevM (init := eType) fun i eTypeAbst => do
|
||||
let eTypeAbst ← matcherApp.discrs.size.foldRevM (init := eType) fun i _ eTypeAbst => do
|
||||
let motiveArg := motiveArgs[i]!
|
||||
let discr := matcherApp.discrs[i]!
|
||||
let discr := matcherApp.discrs[i]
|
||||
let eTypeAbst ← kabstract eTypeAbst discr
|
||||
return eTypeAbst.instantiate1 motiveArg
|
||||
let motiveBody ← mkArrow eTypeAbst motiveBody
|
||||
@@ -118,9 +118,9 @@ def refineThrough (matcherApp : MatcherApp) (e : Expr) : MetaM (Array Expr) :=
|
||||
-- This error can only happen if someone implemented a transformation that rewrites the motive created by `mkMatcher`.
|
||||
throwError "failed to transfer argument through matcher application, motive must be lambda expression with #{matcherApp.discrs.size} arguments"
|
||||
|
||||
let eAbst ← matcherApp.discrs.size.foldRevM (init := e) fun i eAbst => do
|
||||
let eAbst ← matcherApp.discrs.size.foldRevM (init := e) fun i _ eAbst => do
|
||||
let motiveArg := motiveArgs[i]!
|
||||
let discr := matcherApp.discrs[i]!
|
||||
let discr := matcherApp.discrs[i]
|
||||
let eTypeAbst ← kabstract eAbst discr
|
||||
return eTypeAbst.instantiate1 motiveArg
|
||||
-- Let's create something that’s a `Sort` and mentions `e`
|
||||
|
||||
@@ -107,7 +107,7 @@ private def getMajorPosDepElim (declName : Name) (majorPos? : Option Nat) (xs :
|
||||
if motiveArgs.isEmpty then
|
||||
throwError "invalid user defined recursor, '{declName}' does not support dependent elimination, and position of the major premise was not specified (solution: set attribute '[recursor <pos>]', where <pos> is the position of the major premise)"
|
||||
let major := motiveArgs.back!
|
||||
match xs.getIdx? major with
|
||||
match xs.indexOf? major with
|
||||
| some majorPos => pure (major, majorPos, true)
|
||||
| none => throwError "ill-formed recursor '{declName}'"
|
||||
|
||||
|
||||
@@ -104,8 +104,8 @@ def appendParentTag (mvarId : MVarId) (newMVars : Array Expr) (binderInfos : Arr
|
||||
newMVars[0]!.mvarId!.setTag parentTag
|
||||
else
|
||||
unless parentTag.isAnonymous do
|
||||
newMVars.size.forM fun i => do
|
||||
let mvarIdNew := newMVars[i]!.mvarId!
|
||||
newMVars.size.forM fun i _ => do
|
||||
let mvarIdNew := newMVars[i].mvarId!
|
||||
unless (← mvarIdNew.isAssigned) do
|
||||
unless binderInfos[i]!.isInstImplicit do
|
||||
let currTag ← mvarIdNew.getTag
|
||||
|
||||
@@ -168,7 +168,7 @@ private def hasIndepIndices (ctx : Context) : MetaM Bool := do
|
||||
else if ctx.majorTypeIndices.any fun idx => !idx.isFVar then
|
||||
/- One of the indices is not a free variable. -/
|
||||
return false
|
||||
else if ctx.majorTypeIndices.size.any fun i => i.any fun j => ctx.majorTypeIndices[i]! == ctx.majorTypeIndices[j]! then
|
||||
else if ctx.majorTypeIndices.size.any fun i _ => i.any fun j _ => ctx.majorTypeIndices[i] == ctx.majorTypeIndices[j] then
|
||||
/- An index occurs more than once -/
|
||||
return false
|
||||
else
|
||||
|
||||
@@ -27,7 +27,7 @@ def _root_.Lean.MVarId.clear (mvarId : MVarId) (fvarId : FVarId) : MetaM MVarId
|
||||
throwTacticEx `clear mvarId m!"target depends on '{mkFVar fvarId}'"
|
||||
let lctx := lctx.erase fvarId
|
||||
let localInsts ← getLocalInstances
|
||||
let localInsts := match localInsts.findIdx? fun localInst => localInst.fvar.fvarId! == fvarId with
|
||||
let localInsts := match localInsts.findFinIdx? fun localInst => localInst.fvar.fvarId! == fvarId with
|
||||
| none => localInsts
|
||||
| some idx => localInsts.eraseIdx idx
|
||||
let newMVar ← mkFreshExprMVarAt lctx localInsts mvarDecl.type MetavarKind.syntheticOpaque tag
|
||||
|
||||
@@ -283,9 +283,9 @@ partial def foldAndCollect (oldIH newIH : FVarId) (isRecCall : Expr → Option E
|
||||
(onMotive := fun xs _body => do
|
||||
-- Remove the old IH that was added in mkFix
|
||||
let eType ← newIH.getType
|
||||
let eTypeAbst ← matcherApp.discrs.size.foldRevM (init := eType) fun i eTypeAbst => do
|
||||
let eTypeAbst ← matcherApp.discrs.size.foldRevM (init := eType) fun i _ eTypeAbst => do
|
||||
let motiveArg := xs[i]!
|
||||
let discr := matcherApp.discrs[i]!
|
||||
let discr := matcherApp.discrs[i]
|
||||
let eTypeAbst ← kabstract eTypeAbst discr
|
||||
return eTypeAbst.instantiate1 motiveArg
|
||||
|
||||
|
||||
@@ -105,10 +105,10 @@ private partial def finalize
|
||||
let mvarId' ← mvar.mvarId!.tryClear major.fvarId!
|
||||
let (fields, mvarId') ← mvarId'.introN nparams minorGivenNames.varNames (useNamesForExplicitOnly := !minorGivenNames.explicit)
|
||||
let (extra, mvarId') ← mvarId'.introNP nextra
|
||||
let subst := reverted.size.fold (init := baseSubst) fun i (subst : FVarSubst) =>
|
||||
let subst := reverted.size.fold (init := baseSubst) fun i _ (subst : FVarSubst) =>
|
||||
if i < indices.size + 1 then subst
|
||||
else
|
||||
let revertedFVarId := reverted[i]!
|
||||
let revertedFVarId := reverted[i]
|
||||
let newFVarId := extra[i - indices.size - 1]!
|
||||
subst.insert revertedFVarId (mkFVar newFVarId)
|
||||
let fields := fields.map mkFVar
|
||||
@@ -134,8 +134,8 @@ def getMajorTypeIndices (mvarId : MVarId) (tacticName : Name) (recursorInfo : Re
|
||||
if idxPos ≥ majorTypeArgs.size then throwTacticEx tacticName mvarId m!"major premise type is ill-formed{indentExpr majorType}"
|
||||
let idx := majorTypeArgs.get! idxPos
|
||||
unless idx.isFVar do throwTacticEx tacticName mvarId m!"major premise type index {idx} is not a variable{indentExpr majorType}"
|
||||
majorTypeArgs.size.forM fun i => do
|
||||
let arg := majorTypeArgs[i]!
|
||||
majorTypeArgs.size.forM fun i _ => do
|
||||
let arg := majorTypeArgs[i]
|
||||
if i != idxPos && arg == idx then
|
||||
throwTacticEx tacticName mvarId m!"'{idx}' is an index in major premise, but it occurs more than once{indentExpr majorType}"
|
||||
if i < idxPos then
|
||||
|
||||
@@ -94,31 +94,48 @@ def injection (mvarId : MVarId) (fvarId : FVarId) (newNames : List Name := []) :
|
||||
| .subgoal mvarId numEqs => injectionIntro mvarId numEqs newNames
|
||||
|
||||
inductive InjectionsResult where
|
||||
/-- `injections` closed the input goal. -/
|
||||
| solved
|
||||
| subgoal (mvarId : MVarId) (remainingNames : List Name)
|
||||
/--
|
||||
`injections` produces a new goal `mvarId`. `remainingNames` contains the user-facing names that have not been used.
|
||||
`forbidden` contains all local declarations to which `injection` has been applied.
|
||||
Recall that some of these declarations may not have been eliminated from the local context due to forward dependencies, and
|
||||
we use `forbidden` to avoid non-termination when using `injections` in a loop.
|
||||
-/
|
||||
| subgoal (mvarId : MVarId) (remainingNames : List Name) (forbidden : FVarIdSet)
|
||||
|
||||
partial def injections (mvarId : MVarId) (newNames : List Name := []) (maxDepth : Nat := 5) : MetaM InjectionsResult :=
|
||||
/--
|
||||
Applies `injection` to local declarations in `mvarId`. It uses `newNames` to name the new local declarations.
|
||||
`maxDepth` is the maximum recursion depth. Only local declarations that are not in `forbidden` are considered.
|
||||
Recall that some of local declarations may not have been eliminated from the local context due to forward dependencies, and
|
||||
we use `forbidden` to avoid non-termination when using `injections` in a loop.
|
||||
-/
|
||||
partial def injections (mvarId : MVarId) (newNames : List Name := []) (maxDepth : Nat := 5) (forbidden : FVarIdSet := {}) : MetaM InjectionsResult :=
|
||||
mvarId.withContext do
|
||||
let fvarIds := (← getLCtx).getFVarIds
|
||||
go maxDepth fvarIds.toList mvarId newNames
|
||||
go maxDepth fvarIds.toList mvarId newNames forbidden
|
||||
where
|
||||
go : Nat → List FVarId → MVarId → List Name → MetaM InjectionsResult
|
||||
| 0, _, _, _ => throwTacticEx `injections mvarId "recursion depth exceeded"
|
||||
| _, [], mvarId, newNames => return .subgoal mvarId newNames
|
||||
| d+1, fvarId :: fvarIds, mvarId, newNames => do
|
||||
go (depth : Nat) (fvarIds : List FVarId) (mvarId : MVarId) (newNames : List Name) (forbidden : FVarIdSet) : MetaM InjectionsResult := do
|
||||
match depth, fvarIds with
|
||||
| 0, _ => throwTacticEx `injections mvarId "recursion depth exceeded"
|
||||
| _, [] => return .subgoal mvarId newNames forbidden
|
||||
| d+1, fvarId :: fvarIds => do
|
||||
let cont := do
|
||||
go (d+1) fvarIds mvarId newNames
|
||||
if let some (_, lhs, rhs) ← matchEqHEq? (← fvarId.getType) then
|
||||
go (d+1) fvarIds mvarId newNames forbidden
|
||||
if forbidden.contains fvarId then
|
||||
cont
|
||||
else if let some (_, lhs, rhs) ← matchEqHEq? (← fvarId.getType) then
|
||||
let lhs ← whnf lhs
|
||||
let rhs ← whnf rhs
|
||||
if lhs.isRawNatLit && rhs.isRawNatLit then cont
|
||||
if lhs.isRawNatLit && rhs.isRawNatLit then
|
||||
cont
|
||||
else
|
||||
try
|
||||
commitIfNoEx do
|
||||
match (← injection mvarId fvarId newNames) with
|
||||
| .solved => return .solved
|
||||
| .subgoal mvarId newEqs remainingNames =>
|
||||
mvarId.withContext <| go d (newEqs.toList ++ fvarIds) mvarId remainingNames
|
||||
mvarId.withContext <| go d (newEqs.toList ++ fvarIds) mvarId remainingNames (forbidden.insert fvarId)
|
||||
catch _ => cont
|
||||
else cont
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ def _root_.Lean.MVarId.revert (mvarId : MVarId) (fvarIds : Array FVarId) (preser
|
||||
finally
|
||||
mvarId.setKind .syntheticOpaque
|
||||
let mvar := e.getAppFn
|
||||
mvar.mvarId!.setKind .syntheticOpaque
|
||||
mvar.mvarId!.setTag tag
|
||||
return (toRevert.map Expr.fvarId!, mvar.mvarId!)
|
||||
|
||||
|
||||
@@ -80,9 +80,9 @@ def substCore (mvarId : MVarId) (hFVarId : FVarId) (symm := false) (fvarSubst :
|
||||
pure mvarId
|
||||
let (newFVars, mvarId) ← mvarId.introNP (vars.size - 2)
|
||||
trace[Meta.Tactic.subst] "after intro rest {vars.size - 2} {MessageData.ofGoal mvarId}"
|
||||
let fvarSubst ← newFVars.size.foldM (init := fvarSubst) fun i (fvarSubst : FVarSubst) =>
|
||||
let fvarSubst ← newFVars.size.foldM (init := fvarSubst) fun i _ (fvarSubst : FVarSubst) =>
|
||||
let var := vars[i+2]!
|
||||
let newFVar := newFVars[i]!
|
||||
let newFVar := newFVars[i]
|
||||
pure $ fvarSubst.insert var (mkFVar newFVar)
|
||||
let fvarSubst := fvarSubst.insert aFVarIdOriginal (if clearH then b else mkFVar aFVarId)
|
||||
let fvarSubst := fvarSubst.insert hFVarIdOriginal (mkFVar hFVarId)
|
||||
|
||||
@@ -248,9 +248,7 @@ instance : Hashable LocalInstance where
|
||||
|
||||
/-- Remove local instance with the given `fvarId`. Do nothing if `localInsts` does not contain any free variable with id `fvarId`. -/
|
||||
def LocalInstances.erase (localInsts : LocalInstances) (fvarId : FVarId) : LocalInstances :=
|
||||
match localInsts.findIdx? (fun inst => inst.fvar.fvarId! == fvarId) with
|
||||
| some idx => localInsts.eraseIdx idx
|
||||
| _ => localInsts
|
||||
localInsts.eraseP (fun inst => inst.fvar.fvarId! == fvarId)
|
||||
|
||||
/-- A kind for the metavariable that determines its unification behaviour.
|
||||
For more information see the large comment at the beginning of this file. -/
|
||||
@@ -981,10 +979,10 @@ def collectForwardDeps (lctx : LocalContext) (toRevert : Array Expr) : M (Array
|
||||
else
|
||||
if (← preserveOrder) then
|
||||
-- Make sure toRevert[j] does not depend on toRevert[i] for j < i
|
||||
toRevert.size.forM fun i => do
|
||||
let fvar := toRevert[i]!
|
||||
i.forM fun j => do
|
||||
let prevFVar := toRevert[j]!
|
||||
toRevert.size.forM fun i _ => do
|
||||
let fvar := toRevert[i]
|
||||
i.forM fun j _ => do
|
||||
let prevFVar := toRevert[j]
|
||||
let prevDecl := lctx.getFVar! prevFVar
|
||||
if (← localDeclDependsOn prevDecl fvar.fvarId!) then
|
||||
throw (Exception.revertFailure (← getMCtx) lctx toRevert prevDecl.userName.toString)
|
||||
@@ -992,7 +990,7 @@ def collectForwardDeps (lctx : LocalContext) (toRevert : Array Expr) : M (Array
|
||||
let firstDeclToVisit := getLocalDeclWithSmallestIdx lctx toRevert
|
||||
let initSize := newToRevert.size
|
||||
lctx.foldlM (init := newToRevert) (start := firstDeclToVisit.index) fun (newToRevert : Array Expr) decl => do
|
||||
if initSize.any fun i => decl.fvarId == newToRevert[i]!.fvarId! then
|
||||
if initSize.any fun i _ => decl.fvarId == newToRevert[i]!.fvarId! then
|
||||
return newToRevert
|
||||
else if toRevert.any fun x => decl.fvarId == x.fvarId! then
|
||||
return newToRevert.push decl.toExpr
|
||||
@@ -1061,10 +1059,10 @@ mutual
|
||||
|
||||
Note: It is assumed that `xs` is the result of calling `collectForwardDeps` on a subset of variables in `lctx`.
|
||||
-/
|
||||
private partial def mkAuxMVarType (lctx : LocalContext) (xs : Array Expr) (kind : MetavarKind) (e : Expr) : M Expr := do
|
||||
private partial def mkAuxMVarType (lctx : LocalContext) (xs : Array Expr) (kind : MetavarKind) (e : Expr) (usedLetOnly : Bool) : M Expr := do
|
||||
let e ← abstractRangeAux xs xs.size e
|
||||
xs.size.foldRevM (init := e) fun i e => do
|
||||
let x := xs[i]!
|
||||
xs.size.foldRevM (init := e) fun i _ e => do
|
||||
let x := xs[i]
|
||||
if x.isFVar then
|
||||
match lctx.getFVar! x with
|
||||
| LocalDecl.cdecl _ _ n type bi _ =>
|
||||
@@ -1072,16 +1070,25 @@ mutual
|
||||
let type ← abstractRangeAux xs i type
|
||||
return Lean.mkForall n bi type e
|
||||
| LocalDecl.ldecl _ _ n type value nonDep _ =>
|
||||
let type := type.headBeta
|
||||
let type ← abstractRangeAux xs i type
|
||||
let value ← abstractRangeAux xs i value
|
||||
let e := mkLet n type value e nonDep
|
||||
match kind with
|
||||
| MetavarKind.syntheticOpaque =>
|
||||
-- See "Gruesome details" section in the beginning of the file
|
||||
let e := e.liftLooseBVars 0 1
|
||||
return mkForall n BinderInfo.default type e
|
||||
| _ => pure e
|
||||
if !usedLetOnly || e.hasLooseBVar 0 then
|
||||
let type := type.headBeta
|
||||
let type ← abstractRangeAux xs i type
|
||||
let value ← abstractRangeAux xs i value
|
||||
let e := mkLet n type value e nonDep
|
||||
match kind with
|
||||
| MetavarKind.syntheticOpaque =>
|
||||
-- See "Gruesome details" section in the beginning of the file
|
||||
let e := e.liftLooseBVars 0 1
|
||||
return mkForall n BinderInfo.default type e
|
||||
| _ => pure e
|
||||
else
|
||||
match kind with
|
||||
| MetavarKind.syntheticOpaque =>
|
||||
let type := type.headBeta
|
||||
let type ← abstractRangeAux xs i type
|
||||
return mkForall n BinderInfo.default type e
|
||||
| _ =>
|
||||
return e.lowerLooseBVars 1 1
|
||||
else
|
||||
-- `xs` may contain metavariables as "may dependencies" (see `findExprDependsOn`)
|
||||
let mvarDecl := (← get).mctx.getDecl x.mvarId!
|
||||
@@ -1101,7 +1108,7 @@ mutual
|
||||
|
||||
See details in the comment at the top of the file.
|
||||
-/
|
||||
private partial def elimMVar (xs : Array Expr) (mvarId : MVarId) (args : Array Expr) : M (Expr × Array Expr) := do
|
||||
private partial def elimMVar (xs : Array Expr) (mvarId : MVarId) (args : Array Expr) (usedLetOnly : Bool) : M (Expr × Array Expr) := do
|
||||
let mvarDecl := (← getMCtx).getDecl mvarId
|
||||
let mvarLCtx := mvarDecl.lctx
|
||||
let toRevert := getInScope mvarLCtx xs
|
||||
@@ -1129,7 +1136,7 @@ mutual
|
||||
let newMVarLCtx := reduceLocalContext mvarLCtx toRevert
|
||||
let newLocalInsts := mvarDecl.localInstances.filter fun inst => toRevert.all fun x => inst.fvar != x
|
||||
-- Remark: we must reset the cache before processing `mkAuxMVarType` because `toRevert` may not be equal to `xs`
|
||||
let newMVarType ← withFreshCache do mkAuxMVarType mvarLCtx toRevert newMVarKind mvarDecl.type
|
||||
let newMVarType ← withFreshCache do mkAuxMVarType mvarLCtx toRevert newMVarKind mvarDecl.type usedLetOnly
|
||||
let newMVarId := { name := (← get).ngen.curr }
|
||||
let newMVar := mkMVar newMVarId
|
||||
let result := mkMVarApp mvarLCtx newMVar toRevert newMVarKind
|
||||
@@ -1170,7 +1177,8 @@ mutual
|
||||
if (← read).mvarIdsToAbstract.contains mvarId then
|
||||
return mkAppN f (← args.mapM (visit xs))
|
||||
else
|
||||
return (← elimMVar xs mvarId args).1
|
||||
-- We set `usedLetOnly := true` to avoid unnecessary let binders in the new metavariable type.
|
||||
return (← elimMVar xs mvarId args (usedLetOnly := true)).1
|
||||
| _ =>
|
||||
return mkAppN (← visit xs f) (← args.mapM (visit xs))
|
||||
|
||||
@@ -1192,7 +1200,9 @@ partial def elimMVarDeps (xs : Array Expr) (e : Expr) : M Expr :=
|
||||
-/
|
||||
partial def revert (xs : Array Expr) (mvarId : MVarId) : M (Expr × Array Expr) :=
|
||||
withFreshCache do
|
||||
elimMVar xs mvarId #[]
|
||||
-- We set `usedLetOnly := false`, because in the `revert` tactic
|
||||
-- we expect that reverting a let variable always results in a let binder.
|
||||
elimMVar xs mvarId #[] (usedLetOnly := false)
|
||||
|
||||
/--
|
||||
Similar to `Expr.abstractRange`, but handles metavariables correctly.
|
||||
@@ -1221,8 +1231,8 @@ private def mkLambda' (x : Name) (bi : BinderInfo) (t : Expr) (b : Expr) (etaRed
|
||||
If `usedLetOnly == true` then `let` expressions are created only for used (let-) variables. -/
|
||||
def mkBinding (isLambda : Bool) (lctx : LocalContext) (xs : Array Expr) (e : Expr) (usedOnly : Bool) (usedLetOnly : Bool) (etaReduce : Bool) : M Expr := do
|
||||
let e ← abstractRange xs xs.size e
|
||||
xs.size.foldRevM (init := e) fun i e => do
|
||||
let x := xs[i]!
|
||||
xs.size.foldRevM (init := e) fun i _ e => do
|
||||
let x := xs[i]
|
||||
if x.isFVar then
|
||||
match lctx.getFVar! x with
|
||||
| LocalDecl.cdecl _ _ n type bi _ =>
|
||||
|
||||
@@ -349,7 +349,7 @@ def delabAppExplicitCore (fieldNotation : Bool) (numArgs : Nat) (delabHead : (in
|
||||
if idx == 0 then
|
||||
-- If it's the first argument, then we can tag `obj.field` with the first app.
|
||||
head ← withBoundedAppFn (numArgs - 1) <| annotateTermInfo head
|
||||
return Syntax.mkApp head (argStxs.eraseIdx idx)
|
||||
return Syntax.mkApp head (argStxs.eraseIdxIfInBounds idx)
|
||||
else
|
||||
return Syntax.mkApp fnStx argStxs
|
||||
|
||||
@@ -876,7 +876,7 @@ def delabLam : Delab :=
|
||||
-- "default" binder group is the only one that expects binder names
|
||||
-- as a term, i.e. a single `Syntax.ident` or an application thereof
|
||||
let stxCurNames ←
|
||||
if curNames.size > 1 then
|
||||
if h : curNames.size > 1 then
|
||||
`($(curNames.get! 0) $(curNames.eraseIdx 0)*)
|
||||
else
|
||||
pure $ curNames.get! 0;
|
||||
@@ -1115,17 +1115,18 @@ def coeDelaborator : Delab := whenPPOption getPPCoercions do
|
||||
delabAppCore nargs (delabHead info nargs) (unexpand := false)
|
||||
where
|
||||
delabHead (info : CoeFnInfo) (nargs : Nat) (insertExplicit : Bool) : Delab := do
|
||||
guard <| !insertExplicit
|
||||
if info.type == .coeFun && nargs > 0 then
|
||||
-- In the CoeFun case, annotate with the coercee itself.
|
||||
-- We can still see the whole coercion expression by hovering over the whitespace between the arguments.
|
||||
withNaryArg info.coercee <| withAnnotateTermInfo delab
|
||||
else
|
||||
withAnnotateTermInfo do
|
||||
match info.type with
|
||||
| .coe => `(↑$(← withNaryArg info.coercee delab))
|
||||
| .coeFun => `(⇑$(← withNaryArg info.coercee delab))
|
||||
| .coeSort => `(↥$(← withNaryArg info.coercee delab))
|
||||
withTypeAscription (cond := ← getPPOption getPPCoercionsTypes) do
|
||||
guard <| !insertExplicit
|
||||
if info.type == .coeFun && nargs > 0 then
|
||||
-- In the CoeFun case, annotate with the coercee itself.
|
||||
-- We can still see the whole coercion expression by hovering over the whitespace between the arguments.
|
||||
withNaryArg info.coercee <| withAnnotateTermInfo delab
|
||||
else
|
||||
withAnnotateTermInfo do
|
||||
match info.type with
|
||||
| .coe => `(↑$(← withNaryArg info.coercee delab))
|
||||
| .coeFun => `(⇑$(← withNaryArg info.coercee delab))
|
||||
| .coeSort => `(↥$(← withNaryArg info.coercee delab))
|
||||
|
||||
@[builtin_delab app.dite]
|
||||
def delabDIte : Delab := whenNotPPOption getPPExplicit <| whenPPOption getPPNotation <| withOverApp 5 do
|
||||
|
||||
@@ -44,6 +44,11 @@ register_builtin_option pp.coercions : Bool := {
|
||||
group := "pp"
|
||||
descr := "(pretty printer) hide coercion applications"
|
||||
}
|
||||
register_builtin_option pp.coercions.types : Bool := {
|
||||
defValue := false
|
||||
group := "pp"
|
||||
descr := "(pretty printer) display coercion applications with a type ascription"
|
||||
}
|
||||
register_builtin_option pp.universes : Bool := {
|
||||
defValue := false
|
||||
group := "pp"
|
||||
@@ -251,6 +256,7 @@ def getPPLetVarTypes (o : Options) : Bool := o.get pp.letVarTypes.name (getPPAll
|
||||
def getPPNumericTypes (o : Options) : Bool := o.get pp.numericTypes.name pp.numericTypes.defValue
|
||||
def getPPNatLit (o : Options) : Bool := o.get pp.natLit.name (getPPNumericTypes o && !getPPAll o)
|
||||
def getPPCoercions (o : Options) : Bool := o.get pp.coercions.name (!getPPAll o)
|
||||
def getPPCoercionsTypes (o : Options) : Bool := o.get pp.coercions.types.name pp.coercions.types.defValue
|
||||
def getPPExplicit (o : Options) : Bool := o.get pp.explicit.name (getPPAll o)
|
||||
def getPPNotation (o : Options) : Bool := o.get pp.notation.name (!getPPAll o)
|
||||
def getPPParens (o : Options) : Bool := o.get pp.parens.name pp.parens.defValue
|
||||
|
||||
@@ -467,7 +467,7 @@ def visitAtom (k : SyntaxNodeKind) : Formatter := do
|
||||
@[combinator_formatter manyNoAntiquot]
|
||||
def manyNoAntiquot.formatter (p : Formatter) : Formatter := do
|
||||
let stx ← getCur
|
||||
visitArgs $ stx.getArgs.size.forM fun _ => p
|
||||
visitArgs $ stx.getArgs.size.forM fun _ _ => p
|
||||
|
||||
@[combinator_formatter many1NoAntiquot] def many1NoAntiquot.formatter (p : Formatter) : Formatter := manyNoAntiquot.formatter p
|
||||
|
||||
@@ -487,7 +487,7 @@ def many1Unbox.formatter (p : Formatter) : Formatter := do
|
||||
@[combinator_formatter sepByNoAntiquot]
|
||||
def sepByNoAntiquot.formatter (p pSep : Formatter) : Formatter := do
|
||||
let stx ← getCur
|
||||
visitArgs <| stx.getArgs.size.forRevM fun i => if i % 2 == 0 then p else pSep
|
||||
visitArgs <| stx.getArgs.size.forRevM fun i _ => if i % 2 == 0 then p else pSep
|
||||
|
||||
@[combinator_formatter sepBy1NoAntiquot] def sepBy1NoAntiquot.formatter := sepByNoAntiquot.formatter
|
||||
|
||||
|
||||
@@ -268,7 +268,7 @@ def visitToken : Parenthesizer := do
|
||||
let stx ← getCur
|
||||
-- `orelse` may produce `choice` nodes for antiquotations
|
||||
if stx.getKind == `choice then
|
||||
visitArgs $ stx.getArgs.size.forM fun _ => do
|
||||
visitArgs $ stx.getArgs.size.forM fun _ _ => do
|
||||
orelse.parenthesizer p1 p2
|
||||
else
|
||||
-- HACK: We have no (immediate) information on which side of the orelse could have produced the current node, so try
|
||||
@@ -332,7 +332,7 @@ partial def parenthesizeCategoryCore (cat : Name) (_prec : Nat) : Parenthesizer
|
||||
withReader (fun ctx => { ctx with cat := cat }) do
|
||||
let stx ← getCur
|
||||
if stx.getKind == `choice then
|
||||
visitArgs $ stx.getArgs.size.forM fun _ => do
|
||||
visitArgs $ stx.getArgs.size.forM fun _ _ => do
|
||||
parenthesizeCategoryCore cat _prec
|
||||
else
|
||||
withAntiquot.parenthesizer (mkAntiquot.parenthesizer' cat.toString cat (isPseudoKind := true)) (parenthesizerForKind stx.getKind)
|
||||
@@ -470,7 +470,7 @@ def trailingNode.parenthesizer (k : SyntaxNodeKind) (prec lhsPrec : Nat) (p : Pa
|
||||
@[combinator_parenthesizer manyNoAntiquot]
|
||||
def manyNoAntiquot.parenthesizer (p : Parenthesizer) : Parenthesizer := do
|
||||
let stx ← getCur
|
||||
visitArgs $ stx.getArgs.size.forM fun _ => p
|
||||
visitArgs $ stx.getArgs.size.forM fun _ _ => p
|
||||
|
||||
@[combinator_parenthesizer many1NoAntiquot]
|
||||
def many1NoAntiquot.parenthesizer (p : Parenthesizer) : Parenthesizer := do
|
||||
|
||||
@@ -33,9 +33,9 @@ def findCompletionCmdDataAtPos
|
||||
(pos : String.Pos)
|
||||
: Task (Option (Syntax × Elab.InfoTree)) :=
|
||||
findCmdDataAtPos doc (pos := pos) fun s => Id.run do
|
||||
let some tailPos := s.data.stx.getTailPos?
|
||||
let some tailPos := s.stx.getTailPos?
|
||||
| return false
|
||||
return pos.byteIdx <= tailPos.byteIdx + s.data.stx.getTrailingSize
|
||||
return pos.byteIdx <= tailPos.byteIdx + s.stx.getTrailingSize
|
||||
|
||||
def handleCompletion (p : CompletionParams)
|
||||
: RequestM (RequestTask CompletionList) := do
|
||||
|
||||
@@ -28,10 +28,10 @@ private partial def mkCmdSnaps (initSnap : Language.Lean.InitialSnapshot) :
|
||||
} <| .delayed <| headerSuccess.firstCmdSnap.task.bind go
|
||||
where
|
||||
go cmdParsed :=
|
||||
cmdParsed.data.finishedSnap.task.map fun finished =>
|
||||
cmdParsed.finishedSnap.task.map fun finished =>
|
||||
.ok <| .cons {
|
||||
stx := cmdParsed.data.stx
|
||||
mpState := cmdParsed.data.parserState
|
||||
stx := cmdParsed.stx
|
||||
mpState := cmdParsed.parserState
|
||||
cmdState := finished.cmdState
|
||||
} (match cmdParsed.nextCmdSnap? with
|
||||
| some next => .delayed <| next.task.bind go
|
||||
|
||||
@@ -204,7 +204,7 @@ partial def findInfoTreeAtPos
|
||||
findCmdParsedSnap doc (isMatchingSnapshot ·) |>.bind (sync := true) fun
|
||||
| some cmdParsed => toSnapshotTree cmdParsed |>.findInfoTreeAtPos pos |>.bind (sync := true) fun
|
||||
| some infoTree => .pure <| some infoTree
|
||||
| none => cmdParsed.data.finishedSnap.task.map (sync := true) fun s =>
|
||||
| none => cmdParsed.finishedSnap.task.map (sync := true) fun s =>
|
||||
-- the parser returns exactly one command per snapshot, and the elaborator creates exactly one node per command
|
||||
assert! s.cmdState.infoState.trees.size == 1
|
||||
some s.cmdState.infoState.trees[0]!
|
||||
@@ -225,11 +225,11 @@ def findCmdDataAtPos
|
||||
: Task (Option (Syntax × Elab.InfoTree)) :=
|
||||
findCmdParsedSnap doc (isMatchingSnapshot ·) |>.bind (sync := true) fun
|
||||
| some cmdParsed => toSnapshotTree cmdParsed |>.findInfoTreeAtPos pos |>.bind (sync := true) fun
|
||||
| some infoTree => .pure <| some (cmdParsed.data.stx, infoTree)
|
||||
| none => cmdParsed.data.finishedSnap.task.map (sync := true) fun s =>
|
||||
| some infoTree => .pure <| some (cmdParsed.stx, infoTree)
|
||||
| none => cmdParsed.finishedSnap.task.map (sync := true) fun s =>
|
||||
-- the parser returns exactly one command per snapshot, and the elaborator creates exactly one node per command
|
||||
assert! s.cmdState.infoState.trees.size == 1
|
||||
some (cmdParsed.data.stx, s.cmdState.infoState.trees[0]!)
|
||||
some (cmdParsed.stx, s.cmdState.infoState.trees[0]!)
|
||||
| none => .pure none
|
||||
|
||||
/--
|
||||
@@ -244,7 +244,7 @@ def findInfoTreeAtPosWithTrailingWhitespace
|
||||
: Task (Option Elab.InfoTree) :=
|
||||
-- NOTE: use `>=` since the cursor can be *after* the input (and there is no interesting info on
|
||||
-- the first character of the subsequent command if any)
|
||||
findInfoTreeAtPos doc (·.data.parserState.pos ≥ pos) pos
|
||||
findInfoTreeAtPos doc (·.parserState.pos ≥ pos) pos
|
||||
|
||||
open Elab.Command in
|
||||
def runCommandElabM (snap : Snapshot) (c : RequestT CommandElabM α) : RequestM α := do
|
||||
|
||||
@@ -383,7 +383,7 @@ partial def computeStructureResolutionOrder [Monad m] [MonadEnv m]
|
||||
-- `resOrders` contains the resolution orders to merge.
|
||||
-- The parent list is inserted as a pseudo resolution order to ensure immediate parents come out in order,
|
||||
-- and it is added first to be the primary ordering constraint when there are ordering errors.
|
||||
let mut resOrders := parentResOrders.insertAt 0 parentNames |>.filter (!·.isEmpty)
|
||||
let mut resOrders := parentResOrders.insertIdx 0 parentNames |>.filter (!·.isEmpty)
|
||||
|
||||
let mut resOrder : Array Name := #[structName]
|
||||
let mut defects : Array StructureResolutionOrderConflict := #[]
|
||||
|
||||
@@ -12,7 +12,23 @@ namespace Lean.Firefox
|
||||
|
||||
/-! Definitions from https://github.com/firefox-devtools/profiler/blob/main/src/types/profile.js -/
|
||||
|
||||
abbrev Milliseconds := Float
|
||||
structure Milliseconds where
|
||||
ms : Float
|
||||
deriving Inhabited
|
||||
|
||||
instance : Coe Float Milliseconds := ⟨fun s => .mk (s * 1000)⟩
|
||||
instance : OfScientific Milliseconds := ⟨fun m s e => .mk (OfScientific.ofScientific m s e)⟩
|
||||
instance : ToJson Milliseconds where toJson x := toJson x.ms
|
||||
instance : FromJson Milliseconds where fromJson? j := Milliseconds.mk <$> fromJson? j
|
||||
instance : Add Milliseconds where add x y := ⟨x.ms + y.ms⟩
|
||||
|
||||
structure Microseconds where
|
||||
μs : Float
|
||||
|
||||
instance : Coe Float Microseconds := ⟨fun s => .mk (s * 1000000)⟩
|
||||
instance : OfScientific Microseconds := ⟨fun m s e => .mk (OfScientific.ofScientific m s e)⟩
|
||||
instance : ToJson Microseconds where toJson x := toJson x.μs
|
||||
instance : FromJson Microseconds where fromJson? j := Microseconds.mk <$> fromJson? j
|
||||
|
||||
structure Category where
|
||||
name : String
|
||||
@@ -20,14 +36,22 @@ structure Category where
|
||||
subcategories : Array String := #[]
|
||||
deriving FromJson, ToJson
|
||||
|
||||
structure SampleUnits where
|
||||
time : String := "ms"
|
||||
eventDelay : String := "ms"
|
||||
threadCPUDelta : String := "µs"
|
||||
deriving FromJson, ToJson
|
||||
|
||||
structure ProfileMeta where
|
||||
interval : Milliseconds := 0 -- should be irrelevant with `tracing-ms`
|
||||
interval : Milliseconds := 0.0 -- should be irrelevant with `tracing-ms`
|
||||
startTime : Milliseconds
|
||||
categories : Array Category := #[]
|
||||
processType : Nat := 0
|
||||
product : String := "lean"
|
||||
preprocessedProfileVersion : Nat := 48
|
||||
markerSchema : Array Json := #[]
|
||||
-- without this field, no CPU usage is shown
|
||||
sampleUnits : SampleUnits := {}
|
||||
deriving FromJson, ToJson
|
||||
|
||||
structure StackTable where
|
||||
@@ -43,6 +67,7 @@ structure SamplesTable where
|
||||
time : Array Milliseconds
|
||||
weight : Array Milliseconds
|
||||
weightType : String := "tracing-ms"
|
||||
threadCPUDelta : Array Float
|
||||
length : Nat
|
||||
deriving FromJson, ToJson
|
||||
|
||||
@@ -109,11 +134,22 @@ def categories : Array Category := #[
|
||||
{ name := "Meta", color := "yellow" }
|
||||
]
|
||||
|
||||
/-- Returns first `startTime` in the trace tree, if any. -/
|
||||
private partial def getFirstStart? : MessageData → Option Float
|
||||
| .trace data _ children => do
|
||||
if data.startTime != 0 then
|
||||
return data.startTime
|
||||
if let some startTime := children.findSome? getFirstStart? then
|
||||
return startTime
|
||||
none
|
||||
| .withContext _ msg => getFirstStart? msg
|
||||
| _ => none
|
||||
|
||||
private partial def addTrace (pp : Bool) (thread : ThreadWithMaps) (trace : MessageData) :
|
||||
IO ThreadWithMaps :=
|
||||
(·.2) <$> StateT.run (go none trace) thread
|
||||
(·.2) <$> StateT.run (go none none trace) thread
|
||||
where
|
||||
go parentStackIdx? : _ → StateT ThreadWithMaps IO Unit
|
||||
go parentStackIdx? ctx? : _ → StateT ThreadWithMaps IO Unit
|
||||
| .trace data msg children => do
|
||||
if data.startTime == 0 then
|
||||
return -- no time data, skip
|
||||
@@ -121,7 +157,7 @@ where
|
||||
if !data.tag.isEmpty then
|
||||
funcName := s!"{funcName}: {data.tag}"
|
||||
if pp then
|
||||
funcName := s!"{funcName}: {← msg.format}"
|
||||
funcName := s!"{funcName}: {← msg.format ctx?}"
|
||||
let strIdx ← modifyGet fun thread =>
|
||||
if let some idx := thread.stringMap[funcName]? then
|
||||
(idx, thread)
|
||||
@@ -165,40 +201,34 @@ where
|
||||
stackMap := thread.stackMap.insert (frameIdx, parentStackIdx?) thread.stackMap.size })
|
||||
modify fun thread => { thread with lastTime := data.startTime }
|
||||
for c in children do
|
||||
if let some nextStart := getNextStart? c then
|
||||
if let some nextStart := getFirstStart? c then
|
||||
-- add run time slice between children/before first child
|
||||
-- a "slice" is always a pair of samples as otherwise Firefox Profiler interpolates
|
||||
-- between adjacent samples in undesirable ways
|
||||
modify fun thread => { thread with samples := {
|
||||
stack := thread.samples.stack.push stackIdx
|
||||
time := thread.samples.time.push (thread.lastTime * 1000)
|
||||
weight := thread.samples.weight.push ((nextStart - thread.lastTime) * 1000)
|
||||
length := thread.samples.length + 1
|
||||
stack := thread.samples.stack.push stackIdx |>.push stackIdx
|
||||
time := thread.samples.time.push thread.lastTime |>.push nextStart
|
||||
weight := thread.samples.weight.push 0.0 |>.push (nextStart - thread.lastTime)
|
||||
threadCPUDelta := thread.samples.threadCPUDelta.push 0.0 |>.push (nextStart - thread.lastTime)
|
||||
length := thread.samples.length + 2
|
||||
} }
|
||||
go (some stackIdx) c
|
||||
go (some stackIdx) ctx? c
|
||||
-- add remaining slice after last child
|
||||
modify fun thread => { thread with
|
||||
lastTime := data.stopTime
|
||||
samples := {
|
||||
stack := thread.samples.stack.push stackIdx
|
||||
time := thread.samples.time.push (thread.lastTime * 1000)
|
||||
weight := thread.samples.weight.push ((data.stopTime - thread.lastTime) * 1000)
|
||||
length := thread.samples.length + 1
|
||||
stack := thread.samples.stack.push stackIdx |>.push stackIdx
|
||||
time := thread.samples.time.push thread.lastTime |>.push data.stopTime
|
||||
weight := thread.samples.weight.push 0.0 |>.push (data.stopTime - thread.lastTime)
|
||||
threadCPUDelta := thread.samples.threadCPUDelta.push 0.0 |>.push (data.stopTime - thread.lastTime)
|
||||
length := thread.samples.length + 2
|
||||
} }
|
||||
| .withContext _ msg => go parentStackIdx? msg
|
||||
| .withContext ctx msg => go parentStackIdx? (some ctx) msg
|
||||
| _ => return
|
||||
/-- Returns first `startTime` in the trace tree, if any. -/
|
||||
getNextStart?
|
||||
| .trace data _ children => do
|
||||
if data.startTime != 0 then
|
||||
return data.startTime
|
||||
if let some startTime := children.findSome? getNextStart? then
|
||||
return startTime
|
||||
none
|
||||
| .withContext _ msg => getNextStart? msg
|
||||
| _ => none
|
||||
|
||||
def Thread.new (name : String) : Thread := {
|
||||
name
|
||||
samples := { stack := #[], time := #[], weight := #[], length := 0 }
|
||||
samples := { stack := #[], time := #[], weight := #[], threadCPUDelta := #[], length := 0 }
|
||||
stackTable := { frame := #[], «prefix» := #[], category := #[], subcategory := #[], length := 0 }
|
||||
frameTable := { func := #[], length := 0 }
|
||||
stringArray := #[]
|
||||
@@ -207,17 +237,39 @@ def Thread.new (name : String) : Thread := {
|
||||
length := 0 }
|
||||
}
|
||||
|
||||
def Profile.export (name : String) (startTime : Milliseconds) (traceState : TraceState)
|
||||
def Profile.export (name : String) (startTime : Float) (traceStates : Array TraceState)
|
||||
(opts : Options) : IO Profile := do
|
||||
let thread := Thread.new name
|
||||
-- wrap entire trace up to current time in `runFrontend` node
|
||||
let trace := .trace {
|
||||
cls := `runFrontend, startTime, stopTime := (← IO.monoNanosNow).toFloat / 1000000000,
|
||||
collapsed := true } "" (traceState.traces.toArray.map (·.msg))
|
||||
let thread ← addTrace (Lean.trace.profiler.output.pp.get opts) { thread with } trace
|
||||
let tids := traceStates.groupByKey (·.tid)
|
||||
let stopTime := (← IO.monoNanosNow).toFloat / 1000000000
|
||||
let threads ← tids.toArray.qsort (fun (t1, _) (t2, _) => t1 < t2) |>.mapM fun (tid, traceStates) => do
|
||||
let mut thread := { Thread.new (if tid = 0 then name else toString tid) with isMainThread := tid = 0 }
|
||||
let traces := traceStates.map (·.traces.toArray)
|
||||
-- sort traces of thread by start time
|
||||
let traces := traces.qsort (fun tr1 tr2 =>
|
||||
let f tr := tr.get? 0 |>.bind (getFirstStart? ·.msg) |>.getD 0
|
||||
f tr1 < f tr2)
|
||||
let mut traces := traces.flatMap id |>.map (·.msg)
|
||||
if tid = 0 then
|
||||
-- wrap entire trace up to current time in `runFrontend` node
|
||||
traces := #[.trace {
|
||||
cls := `runFrontend, startTime, stopTime,
|
||||
collapsed := true } "" traces]
|
||||
for trace in traces do
|
||||
if thread.lastTime > 0.0 then
|
||||
if let some nextStart := getFirstStart? trace then
|
||||
-- add 0% CPU run time slice between traces
|
||||
thread := { thread with samples := {
|
||||
stack := thread.samples.stack.push 0 |>.push 0
|
||||
time := thread.samples.time.push thread.lastTime |>.push nextStart
|
||||
weight := thread.samples.weight.push 0.0 |>.push 0.0
|
||||
threadCPUDelta := thread.samples.threadCPUDelta.push 0.0 |>.push 0.0
|
||||
length := thread.samples.length + 2
|
||||
} }
|
||||
thread ← addTrace (Lean.trace.profiler.output.pp.get opts) thread trace
|
||||
return thread
|
||||
return {
|
||||
meta := { startTime, categories }
|
||||
threads := #[thread.toThread]
|
||||
threads := threads.map (·.toThread)
|
||||
}
|
||||
|
||||
structure ThreadWithCollideMaps extends ThreadWithMaps where
|
||||
@@ -240,19 +292,20 @@ where
|
||||
if let some idx := thread.sampleMap[stackIdx]? then
|
||||
-- imperative to preserve linear use of arrays here!
|
||||
let ⟨⟨⟨t1, t2, t3, samples, t5, t6, t7, t8, t9, t10⟩, o2, o3, o4, o5⟩, o6⟩ := thread
|
||||
let ⟨s1, s2, weight, s3, s4⟩ := samples
|
||||
let ⟨s1, s2, weight, s3, s4, s5⟩ := samples
|
||||
let weight := weight.set! idx <| weight[idx]! + add.samples.weight[oldSampleIdx]!
|
||||
let samples := ⟨s1, s2, weight, s3, s4⟩
|
||||
let samples := ⟨s1, s2, weight, s3, s4, s5⟩
|
||||
⟨⟨⟨t1, t2, t3, samples, t5, t6, t7, t8, t9, t10⟩, o2, o3, o4, o5⟩, o6⟩
|
||||
else
|
||||
-- imperative to preserve linear use of arrays here!
|
||||
let ⟨⟨⟨t1, t2, t3, samples, t5, t6, t7, t8, t9, t10⟩, o2, o3, o4, o5⟩, sampleMap⟩ :=
|
||||
thread
|
||||
let ⟨stack, time, weight, _, length⟩ := samples
|
||||
let ⟨stack, time, weight, _, threadCPUDelta, length⟩ := samples
|
||||
let samples := {
|
||||
stack := stack.push stackIdx
|
||||
time := time.push time.size.toFloat
|
||||
weight := weight.push add.samples.weight[oldSampleIdx]!
|
||||
threadCPUDelta := threadCPUDelta.push 1
|
||||
length := length + 1
|
||||
}
|
||||
let sampleMap := sampleMap.insert stackIdx sampleMap.size
|
||||
|
||||
@@ -64,6 +64,8 @@ structure TraceElem where
|
||||
deriving Inhabited
|
||||
|
||||
structure TraceState where
|
||||
/-- Thread ID, used by `trace.profiler.output`. -/
|
||||
tid : UInt64 := 0
|
||||
traces : PersistentArray TraceElem := {}
|
||||
deriving Inhabited
|
||||
|
||||
|
||||
@@ -1140,6 +1140,7 @@ extern "C" LEAN_EXPORT obj_res lean_io_create_tempfile(lean_object * /* w */) {
|
||||
strcat(path, file_pattern);
|
||||
|
||||
uv_fs_t req;
|
||||
// Differences from lean_io_create_tempdir start here
|
||||
ret = uv_fs_mkstemp(NULL, &req, path, NULL);
|
||||
if (ret < 0) {
|
||||
// If mkstemp throws an error we cannot rely on path to contain a proper file name.
|
||||
@@ -1151,6 +1152,48 @@ extern "C" LEAN_EXPORT obj_res lean_io_create_tempfile(lean_object * /* w */) {
|
||||
}
|
||||
}
|
||||
|
||||
/* createTempDir : IO FilePath */
|
||||
extern "C" LEAN_EXPORT obj_res lean_io_create_tempdir(lean_object * /* w */) {
|
||||
char path[PATH_MAX];
|
||||
size_t base_len = PATH_MAX;
|
||||
int ret = uv_os_tmpdir(path, &base_len);
|
||||
if (ret < 0) {
|
||||
return io_result_mk_error(decode_uv_error(ret, nullptr));
|
||||
} else if (base_len == 0) {
|
||||
return lean_io_result_mk_error(decode_uv_error(UV_ENOENT, mk_string("")));
|
||||
}
|
||||
|
||||
#if defined(LEAN_WINDOWS)
|
||||
// On Windows `GetTempPathW` always returns a path ending in \, but libuv removes it.
|
||||
// https://learn.microsoft.com/en-us/windows/win32/fileio/creating-and-using-a-temporary-file
|
||||
if (path[base_len - 1] != '\\') {
|
||||
lean_always_assert(PATH_MAX >= base_len + 1 + 1);
|
||||
strcat(path, "\\");
|
||||
}
|
||||
#else
|
||||
// No guarantee that we have a trailing / in TMPDIR.
|
||||
if (path[base_len - 1] != '/') {
|
||||
lean_always_assert(PATH_MAX >= base_len + 1 + 1);
|
||||
strcat(path, "/");
|
||||
}
|
||||
#endif
|
||||
|
||||
const char* file_pattern = "tmp.XXXXXXXX";
|
||||
const size_t file_pattern_size = strlen(file_pattern);
|
||||
lean_always_assert(PATH_MAX >= strlen(path) + file_pattern_size + 1);
|
||||
strcat(path, file_pattern);
|
||||
|
||||
uv_fs_t req;
|
||||
// Differences from lean_io_create_tempfile start here
|
||||
ret = uv_fs_mkdtemp(NULL, &req, path, NULL);
|
||||
if (ret < 0) {
|
||||
// If mkdtemp throws an error we cannot rely on path to contain a proper file name.
|
||||
return io_result_mk_error(decode_uv_error(ret, nullptr));
|
||||
} else {
|
||||
return lean_io_result_mk_ok(mk_string(req.path));
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT obj_res lean_io_remove_file(b_obj_arg fname, obj_arg) {
|
||||
if (std::remove(string_cstr(fname)) == 0) {
|
||||
return io_result_mk_ok(box(0));
|
||||
|
||||
BIN
stage0/src/include/lean/lean.h
generated
BIN
stage0/src/include/lean/lean.h
generated
Binary file not shown.
BIN
stage0/src/runtime/object.cpp
generated
BIN
stage0/src/runtime/object.cpp
generated
Binary file not shown.
BIN
stage0/src/runtime/sharecommon.cpp
generated
BIN
stage0/src/runtime/sharecommon.cpp
generated
Binary file not shown.
BIN
stage0/stdlib/Init/Data/Array/Basic.c
generated
BIN
stage0/stdlib/Init/Data/Array/Basic.c
generated
Binary file not shown.
BIN
stage0/stdlib/Init/Data/Array/BinSearch.c
generated
BIN
stage0/stdlib/Init/Data/Array/BinSearch.c
generated
Binary file not shown.
BIN
stage0/stdlib/Init/Data/Array/Lemmas.c
generated
BIN
stage0/stdlib/Init/Data/Array/Lemmas.c
generated
Binary file not shown.
BIN
stage0/stdlib/Init/Data/Float.c
generated
BIN
stage0/stdlib/Init/Data/Float.c
generated
Binary file not shown.
BIN
stage0/stdlib/Init/Data/List/Attach.c
generated
BIN
stage0/stdlib/Init/Data/List/Attach.c
generated
Binary file not shown.
BIN
stage0/stdlib/Init/Data/List/Sort/Basic.c
generated
BIN
stage0/stdlib/Init/Data/List/Sort/Basic.c
generated
Binary file not shown.
BIN
stage0/stdlib/Init/Data/List/Sort/Impl.c
generated
BIN
stage0/stdlib/Init/Data/List/Sort/Impl.c
generated
Binary file not shown.
BIN
stage0/stdlib/Init/Data/Nat.c
generated
BIN
stage0/stdlib/Init/Data/Nat.c
generated
Binary file not shown.
BIN
stage0/stdlib/Init/Data/Nat/Basic.c
generated
BIN
stage0/stdlib/Init/Data/Nat/Basic.c
generated
Binary file not shown.
BIN
stage0/stdlib/Init/Data/Nat/Control.c
generated
BIN
stage0/stdlib/Init/Data/Nat/Control.c
generated
Binary file not shown.
BIN
stage0/stdlib/Init/Data/Nat/Fold.c
generated
Normal file
BIN
stage0/stdlib/Init/Data/Nat/Fold.c
generated
Normal file
Binary file not shown.
BIN
stage0/stdlib/Init/Data/String/Basic.c
generated
BIN
stage0/stdlib/Init/Data/String/Basic.c
generated
Binary file not shown.
BIN
stage0/stdlib/Init/Meta.c
generated
BIN
stage0/stdlib/Init/Meta.c
generated
Binary file not shown.
BIN
stage0/stdlib/Init/NotationExtra.c
generated
BIN
stage0/stdlib/Init/NotationExtra.c
generated
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user