Compare commits

..

1 Commits

Author SHA1 Message Date
Leonardo de Moura
7d0769a274 fix: Context.setConfig must update metaConfig and indexConfig fields
This PR fixes a bug at `Context.setConfig`. It was not propagating
the configuration to the redundant fields (cache) `metaConfig` and `indexConfig`.
2024-11-21 18:10:57 -08:00
459 changed files with 2463 additions and 4485 deletions

View File

@@ -1,8 +1,7 @@
# This workflow allows any user to add one of the `awaiting-review`, `awaiting-author`, `WIP`,
# `release-ci`, or a `changelog-XXX` label by commenting on the PR or issue.
# or `release-ci` labels by commenting on the PR or issue.
# If any labels from the set {`awaiting-review`, `awaiting-author`, `WIP`} are added, other labels
# from that set are removed automatically at the same time.
# Similarly, if any `changelog-XXX` label is added, other `changelog-YYY` labels are removed.
name: Label PR based on Comment
@@ -12,7 +11,7 @@ on:
jobs:
update-label:
if: github.event.issue.pull_request != null && (contains(github.event.comment.body, 'awaiting-review') || contains(github.event.comment.body, 'awaiting-author') || contains(github.event.comment.body, 'WIP') || contains(github.event.comment.body, 'release-ci') || contains(github.event.comment.body, 'changelog-'))
if: github.event.issue.pull_request != null && (contains(github.event.comment.body, 'awaiting-review') || contains(github.event.comment.body, 'awaiting-author') || contains(github.event.comment.body, 'WIP') || contains(github.event.comment.body, 'release-ci'))
runs-on: ubuntu-latest
steps:
@@ -21,14 +20,13 @@ jobs:
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const { owner, repo, number: issue_number } = context.issue;
const { owner, repo, number: issue_number } = context.issue;
const commentLines = context.payload.comment.body.split('\r\n');
const awaitingReview = commentLines.includes('awaiting-review');
const awaitingAuthor = commentLines.includes('awaiting-author');
const wip = commentLines.includes('WIP');
const releaseCI = commentLines.includes('release-ci');
const changelogMatch = commentLines.find(line => line.startsWith('changelog-'));
if (awaitingReview || awaitingAuthor || wip) {
await github.rest.issues.removeLabel({ owner, repo, issue_number, name: 'awaiting-review' }).catch(() => {});
@@ -49,19 +47,3 @@ jobs:
if (releaseCI) {
await github.rest.issues.addLabels({ owner, repo, issue_number, labels: ['release-ci'] });
}
if (changelogMatch) {
const changelogLabel = changelogMatch.trim();
const { data: existingLabels } = await github.rest.issues.listLabelsOnIssue({ owner, repo, issue_number });
const changelogLabels = existingLabels.filter(label => label.name.startsWith('changelog-'));
// Remove all other changelog labels
for (const label of changelogLabels) {
if (label.name !== changelogLabel) {
await github.rest.issues.removeLabel({ owner, repo, issue_number, name: label.name }).catch(() => {});
}
}
// Add the new changelog label
await github.rest.issues.addLabels({ owner, repo, issue_number, labels: [changelogLabel] });
}

View File

@@ -12,17 +12,17 @@ Remark: this example is based on an example found in the Idris manual.
Vectors
--------
A `Vec` is a list of size `n` whose elements belong to a type `α`.
A `Vector` is a list of size `n` whose elements belong to a type `α`.
-/
inductive Vec (α : Type u) : Nat Type u
| nil : Vec α 0
| cons : α Vec α n Vec α (n+1)
inductive Vector (α : Type u) : Nat Type u
| nil : Vector α 0
| cons : α Vector α n Vector α (n+1)
/-!
We can overload the `List.cons` notation `::` and use it to create `Vec`s.
We can overload the `List.cons` notation `::` and use it to create `Vector`s.
-/
infix:67 " :: " => Vec.cons
infix:67 " :: " => Vector.cons
/-!
Now, we define the types of our simple functional language.
@@ -50,11 +50,11 @@ the builtin instance for `Add Int` as the solution.
/-!
Expressions are indexed by the types of the local variables, and the type of the expression itself.
-/
inductive HasType : Fin n Vec Ty n Ty Type where
inductive HasType : Fin n Vector Ty n Ty Type where
| stop : HasType 0 (ty :: ctx) ty
| pop : HasType k ctx ty HasType k.succ (u :: ctx) ty
inductive Expr : Vec Ty n Ty Type where
inductive Expr : Vector Ty n Ty Type where
| var : HasType i ctx ty Expr ctx ty
| val : Int Expr ctx Ty.int
| lam : Expr (a :: ctx) ty Expr ctx (Ty.fn a ty)
@@ -102,8 +102,8 @@ indexed over the types in scope. Since an environment is just another form of li
to the vector of local variable types, we overload again the notation `::` so that we can use the usual list syntax.
Given a proof that a variable is defined in the context, we can then produce a value from the environment.
-/
inductive Env : Vec Ty n Type where
| nil : Env Vec.nil
inductive Env : Vector Ty n Type where
| nil : Env Vector.nil
| cons : Ty.interp a Env ctx Env (a :: ctx)
infix:67 " :: " => Env.cons

View File

@@ -82,7 +82,9 @@ theorem Expr.typeCheck_correct (h₁ : HasType e ty) (h₂ : e.typeCheck ≠ .un
/-!
Now, we prove that if `Expr.typeCheck e` returns `Maybe.unknown`, then forall `ty`, `HasType e ty` does not hold.
The notation `e.typeCheck` is sugar for `Expr.typeCheck e`. Lean can infer this because we explicitly said that `e` has type `Expr`.
The proof is by induction on `e` and case analysis. Note that the tactic `simp [typeCheck]` is applied to all goal generated by the `induction` tactic, and closes
The proof is by induction on `e` and case analysis. The tactic `rename_i` is used to rename "inaccessible" variables.
We say a variable is inaccessible if it is introduced by a tactic (e.g., `cases`) or has been shadowed by another variable introduced
by the user. Note that the tactic `simp [typeCheck]` is applied to all goal generated by the `induction` tactic, and closes
the cases corresponding to the constructors `Expr.nat` and `Expr.bool`.
-/
theorem Expr.typeCheck_complete {e : Expr} : e.typeCheck = .unknown ¬ HasType e ty := by

View File

@@ -43,4 +43,3 @@ import Init.Data.Zero
import Init.Data.NeZero
import Init.Data.Function
import Init.Data.RArray
import Init.Data.Vector

View File

@@ -13,7 +13,6 @@ import Init.Data.ToString.Basic
import Init.GetElem
import Init.Data.List.ToArray
import Init.Data.Array.Set
universe u v w
/-! ### Array literal syntax -/
@@ -166,15 +165,15 @@ This will perform the update destructively provided that `a` has a reference
count of 1 when called.
-/
@[extern "lean_array_fswap"]
def swap (a : Array α) (i j : @& Nat) (hi : i < a.size := by get_elem_tactic) (hj : j < a.size := by get_elem_tactic) : Array α :=
def swap (a : Array α) (i j : @& Fin a.size) : Array α :=
let v₁ := a[i]
let v₂ := a[j]
let a' := a.set i v₂
a'.set j v₁ (Nat.lt_of_lt_of_eq hj (size_set a i v₂ _).symm)
a'.set j v₁ (Nat.lt_of_lt_of_eq j.isLt (size_set a i v₂ _).symm)
@[simp] theorem size_swap (a : Array α) (i j : Nat) {hi hj} : (a.swap i j hi hj).size = a.size := by
@[simp] theorem size_swap (a : Array α) (i j : Fin a.size) : (a.swap i j).size = a.size := by
show ((a.set i a[j]).set j a[i]
(Nat.lt_of_lt_of_eq hj (size_set a i a[j] _).symm)).size = a.size
(Nat.lt_of_lt_of_eq j.isLt (size_set a i a[j] _).symm)).size = a.size
rw [size_set, size_set]
/--
@@ -184,14 +183,12 @@ This will perform the update destructively provided that `a` has a reference
count of 1 when called.
-/
@[extern "lean_array_swap"]
def swapIfInBounds (a : Array α) (i j : @& Nat) : Array α :=
def swap! (a : Array α) (i j : @& Nat) : Array α :=
if h₁ : i < a.size then
if h₂ : j < a.size then swap a i j
if h₂ : j < a.size then swap a i, h₁ j, h₂
else a
else a
@[deprecated swapIfInBounds (since := "2024-11-24")] abbrev swap! := @swapIfInBounds
/-! ### GetElem instance for `USize`, backed by `uget` -/
instance : GetElem (Array α) USize α fun xs i => i.toNat < xs.size where
@@ -236,7 +233,7 @@ def ofFn {n} (f : Fin n → α) : Array α := go 0 (mkEmpty n) where
/-- The array `#[0, 1, ..., n - 1]`. -/
def range (n : Nat) : Array Nat :=
ofFn fun (i : Fin n) => i
n.fold (flip Array.push) (mkEmpty n)
def singleton (v : α) : Array α :=
mkArray 1 v
@@ -252,7 +249,7 @@ def get? (a : Array α) (i : Nat) : Option α :=
def back? (a : Array α) : Option α :=
a[a.size - 1]?
@[inline] def swapAt (a : Array α) (i : Nat) (v : α) (hi : i < a.size := by get_elem_tactic) : α × Array α :=
@[inline] def swapAt (a : Array α) (i : Fin a.size) (v : α) : α × Array α :=
let e := a[i]
let a := a.set i v
(e, a)
@@ -260,7 +257,7 @@ def back? (a : Array α) : Option α :=
@[inline]
def swapAt! (a : Array α) (i : Nat) (v : α) : α × Array α :=
if h : i < a.size then
swapAt a i v
swapAt a i, h v
else
have : Inhabited (α × Array α) := (v, a)
panic! ("index " ++ toString i ++ " out of bounds")
@@ -749,7 +746,7 @@ where
loop (as : Array α) (i : Nat) (j : Fin as.size) :=
if h : i < j then
have := termination h
let as := as.swap i j (Nat.lt_trans h j.2)
let as := as.swap i, Nat.lt_trans h j.2 j
have : j-1 < as.size := by rw [size_swap]; exact Nat.lt_of_le_of_lt (Nat.pred_le _) j.2
loop as (i+1) j-1, this
else
@@ -789,7 +786,7 @@ it has to backshift all elements at positions greater than `i`.-/
@[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
def eraseIdx (a : Array α) (i : Nat) (h : i < a.size := by get_elem_tactic) : Array α :=
if h' : i + 1 < a.size then
let a' := a.swap (i + 1) i
let a' := a.swap i + 1, h' i, h
a'.eraseIdx (i + 1) (by simp [a', h'])
else
a.pop
@@ -836,7 +833,7 @@ def eraseP (as : Array α) (p : α → Bool) : Array α :=
let rec @[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
loop (as : Array α) (j : Fin as.size) :=
if i < j then
let j' : Fin as.size := j-1, Nat.lt_of_le_of_lt (Nat.pred_le _) j.2
let j' := j-1, Nat.lt_of_le_of_lt (Nat.pred_le _) j.2
let as := as.swap j' j
loop as j', by rw [size_swap]; exact j'.2
else
@@ -886,12 +883,12 @@ def isPrefixOf [BEq α] (as bs : Array α) : Bool :=
false
@[semireducible, specialize] -- This is otherwise irreducible because it uses well-founded recursion.
def zipWithAux (as : Array α) (bs : Array β) (f : α β γ) (i : Nat) (cs : Array γ) : Array γ :=
def zipWithAux (f : α β γ) (as : Array α) (bs : Array β) (i : Nat) (cs : Array γ) : Array γ :=
if h : i < as.size then
let a := as[i]
if h : i < bs.size then
let b := bs[i]
zipWithAux as bs f (i+1) <| cs.push <| f a b
zipWithAux f as bs (i+1) <| cs.push <| f a b
else
cs
else
@@ -899,23 +896,11 @@ def zipWithAux (as : Array α) (bs : Array β) (f : α → β → γ) (i : Nat)
decreasing_by simp_wf; decreasing_trivial_pre_omega
@[inline] def zipWith (as : Array α) (bs : Array β) (f : α β γ) : Array γ :=
zipWithAux as bs f 0 #[]
zipWithAux f as bs 0 #[]
def zip (as : Array α) (bs : Array β) : Array (α × β) :=
zipWith as bs Prod.mk
def zipWithAll (as : Array α) (bs : Array β) (f : Option α Option β γ) : Array γ :=
go as bs 0 #[]
where go (as : Array α) (bs : Array β) (i : Nat) (cs : Array γ) :=
if i < max as.size bs.size then
let a := as[i]?
let b := bs[i]?
go as bs (i+1) (cs.push (f a b))
else
cs
termination_by max as.size bs.size - i
decreasing_by simp_wf; decreasing_trivial_pre_omega
def unzip (as : Array (α × β)) : Array α × Array β :=
as.foldl (init := (#[], #[])) fun (as, bs) (a, b) => (as.push a, bs.push b)

View File

@@ -5,64 +5,59 @@ Authors: Leonardo de Moura
-/
prelude
import Init.Data.Array.Basic
import Init.Omega
universe u v
namespace Array
-- TODO: CLEANUP
@[specialize] def binSearchAux {α : Type u} {β : Type v} (lt : α α Bool) (found : Option α β) (as : Array α) (k : α) :
(lo : Fin (as.size + 1)) (hi : Fin as.size) (lo.1 hi.1) β
| lo, hi, h =>
let m := (lo.1 + hi.1)/2
let a := as[m]
if lt a k then
if h' : m + 1 hi.1 then
binSearchAux lt found as k m+1, by omega hi h'
else found none
else if lt k a then
if h' : m = 0 m - 1 < lo.1 then found none
else binSearchAux lt found as k lo m-1, by omega (by simp; omega)
else found (some a)
termination_by lo hi => hi.1 - lo.1
namespace Array
-- TODO: remove the [Inhabited α] parameters as soon as we have the tactic framework for automating proof generation and using Array.fget
-- TODO: remove `partial` using well-founded recursion
@[specialize] partial def binSearchAux {α : Type u} {β : Type v} [Inhabited β] (lt : α α Bool) (found : Option α β) (as : Array α) (k : α) : Nat Nat β
| lo, hi =>
if lo <= hi then
let _ := Inhabited.mk k
let m := (lo + hi)/2
let a := as.get! m
if lt a k then binSearchAux lt found as k (m+1) hi
else if lt k a then
if m == 0 then found none
else binSearchAux lt found as k lo (m-1)
else found (some a)
else found none
@[inline] def binSearch {α : Type} (as : Array α) (k : α) (lt : α α Bool) (lo := 0) (hi := as.size - 1) : Option α :=
if h : lo < as.size then
if lo < as.size then
let hi := if hi < as.size then hi else as.size - 1
if w : lo hi then
binSearchAux lt id as k lo, by omega hi, by simp [hi]; split <;> omega (by simp [hi]; omega)
else
none
binSearchAux lt id as k lo hi
else
none
@[inline] def binSearchContains {α : Type} (as : Array α) (k : α) (lt : α α Bool) (lo := 0) (hi := as.size - 1) : Bool :=
if h : lo < as.size then
if lo < as.size then
let hi := if hi < as.size then hi else as.size - 1
if w : lo hi then
binSearchAux lt Option.isSome as k lo, by omega hi, by simp [hi]; split <;> omega (by simp [hi]; omega)
else
false
binSearchAux lt Option.isSome as k lo hi
else
false
@[specialize] private def binInsertAux {α : Type u} {m : Type u Type v} [Monad m]
@[specialize] private partial def binInsertAux {α : Type u} {m : Type u Type v} [Monad m]
(lt : α α Bool)
(merge : α m α)
(add : Unit m α)
(as : Array α)
(k : α) : (lo : Fin as.size) (hi : Fin as.size) (lo.1 hi.1) (lt as[lo] k) m (Array α)
| lo, hi, h, w =>
let mid := (lo.1 + hi.1)/2
let midVal := as[mid]
if w₁ : lt midVal k then
if h' : mid = lo then do let v add (); pure <| as.insertIdx (lo+1) v
else binInsertAux lt merge add as k mid, by omega hi (by simp; omega) w₁
else if w₂ : lt k midVal then
have : mid lo := fun z => by simp [midVal, z] at w₁; simp_all
binInsertAux lt merge add as k lo mid, by omega (by simp; omega) w
(k : α) : Nat Nat m (Array α)
| lo, hi =>
let _ := Inhabited.mk k
-- as[lo] < k < as[hi]
let mid := (lo + hi)/2
let midVal := as.get! mid
if lt midVal k then
if mid == lo then do let v add (); pure <| as.insertIdx! (lo+1) v
else binInsertAux lt merge add as k mid hi
else if lt k midVal then
binInsertAux lt merge add as k lo mid
else do
as.modifyM mid <| fun v => merge v
termination_by lo hi => hi.1 - lo.1
@[specialize] def binInsertM {α : Type u} {m : Type u Type v} [Monad m]
(lt : α α Bool)
@@ -70,12 +65,13 @@ termination_by lo hi => hi.1 - lo.1
(add : Unit m α)
(as : Array α)
(k : α) : m (Array α) :=
if h : as.size = 0 then do let v add (); pure <| as.push v
else if lt k as[0] then do let v add (); pure <| as.insertIdx 0 v
else if h' : !lt as[0] k then as.modifyM 0 <| merge
else if lt as[as.size - 1] k then do let v add (); pure <| as.push v
else if !lt k as[as.size - 1] then as.modifyM (as.size - 1) <| merge
else binInsertAux lt merge add as k 0, by omega as.size - 1, by omega (by simp) (by simpa using h')
let _ := Inhabited.mk k
if as.isEmpty then do let v add (); pure <| as.push v
else if lt k (as.get! 0) then do let v add (); pure <| as.insertIdx! 0 v
else if !lt (as.get! 0) k then as.modifyM 0 <| merge
else if lt as.back! k then do let v add (); pure <| as.push v
else if !lt k as.back! then as.modifyM (as.size - 1) <| merge
else binInsertAux lt merge add as k 0 (as.size - 1)
@[inline] def binInsert {α : Type u} (lt : α α Bool) (as : Array α) (k : α) : Array α :=
Id.run <| binInsertM lt (fun _ => k) (fun _ => k) as k

View File

@@ -23,7 +23,7 @@ theorem foldlM_toList.aux [Monad m]
· cases Nat.not_le_of_gt _ (Nat.zero_add _ H)
· rename_i i; rw [Nat.succ_add] at H
simp [foldlM_toList.aux f arr i (j+1) H]
rw (occs := [2]) [ List.getElem_cons_drop_succ_eq_drop _]
rw (occs := .pos [2]) [ List.getElem_cons_drop_succ_eq_drop _]
rfl
· rw [List.drop_of_length_le (Nat.ge_of_not_lt _)]; rfl

View File

@@ -6,6 +6,7 @@ Authors: Leonardo de Moura
prelude
import Init.Data.Array.Basic
import Init.Data.BEq
import Init.Data.Nat.Lemmas
import Init.Data.List.Nat.BEq
import Init.ByCases

View File

@@ -6,7 +6,7 @@ Authors: Leonardo de Moura
prelude
import Init.Data.Array.Basic
@[inline] def Array.insertionSort (a : Array α) (lt : α α Bool := by exact (· < ·)) : Array α :=
@[inline] def Array.insertionSort (a : Array α) (lt : α α Bool) : Array α :=
traverse a 0 a.size
where
@[specialize] traverse (a : Array α) (i : Nat) (fuel : Nat) : Array α :=
@@ -23,6 +23,6 @@ where
| j'+1 =>
have h' : j' < a.size := by subst j; exact Nat.lt_trans (Nat.lt_succ_self _) h
if lt a[j] a[j'] then
swapLoop (a.swap j j') j' (by rw [size_swap]; assumption; done)
swapLoop (a.swap j, h j', h') j' (by rw [size_swap]; assumption; done)
else
a

View File

@@ -343,8 +343,8 @@ theorem isPrefixOfAux_toArray_zero [BEq α] (l₁ l₂ : List α) (hle : l₁.le
rw [ih]
simp_all
theorem zipWithAux_toArray_succ (as : List α) (bs : List β) (f : α β γ) (i : Nat) (cs : Array γ) :
zipWithAux as.toArray bs.toArray f (i + 1) cs = zipWithAux as.tail.toArray bs.tail.toArray f i cs := by
theorem zipWithAux_toArray_succ (f : α β γ) (as : List α) (bs : List β) (i : Nat) (cs : Array γ) :
zipWithAux f as.toArray bs.toArray (i + 1) cs = zipWithAux f as.tail.toArray bs.tail.toArray i cs := by
rw [zipWithAux]
conv => rhs; rw [zipWithAux]
simp only [size_toArray, getElem_toArray, length_tail, getElem_tail]
@@ -355,8 +355,8 @@ theorem zipWithAux_toArray_succ (as : List α) (bs : List β) (f : α → β →
rw [dif_neg (by omega)]
· rw [dif_neg (by omega)]
theorem zipWithAux_toArray_succ' (as : List α) (bs : List β) (f : α β γ) (i : Nat) (cs : Array γ) :
zipWithAux as.toArray bs.toArray f (i + 1) cs = zipWithAux (as.drop (i+1)).toArray (bs.drop (i+1)).toArray f 0 cs := by
theorem zipWithAux_toArray_succ' (f : α β γ) (as : List α) (bs : List β) (i : Nat) (cs : Array γ) :
zipWithAux f as.toArray bs.toArray (i + 1) cs = zipWithAux f (as.drop (i+1)).toArray (bs.drop (i+1)).toArray 0 cs := by
induction i generalizing as bs cs with
| zero => simp [zipWithAux_toArray_succ]
| succ i ih =>
@@ -364,7 +364,7 @@ theorem zipWithAux_toArray_succ' (as : List α) (bs : List β) (f : α → β
simp
theorem zipWithAux_toArray_zero (f : α β γ) (as : List α) (bs : List β) (cs : Array γ) :
zipWithAux as.toArray bs.toArray f 0 cs = cs ++ (List.zipWith f as bs).toArray := by
zipWithAux f as.toArray bs.toArray 0 cs = cs ++ (List.zipWith f as bs).toArray := by
rw [Array.zipWithAux]
match as, bs with
| [], _ => simp
@@ -372,7 +372,7 @@ theorem zipWithAux_toArray_zero (f : α → β → γ) (as : List α) (bs : List
| a :: as, b :: bs =>
simp [zipWith_cons_cons, zipWithAux_toArray_succ', zipWithAux_toArray_zero, push_append_toArray]
@[simp] theorem zipWith_toArray (as : List α) (bs : List β) (f : α β γ) :
@[simp] theorem zipWith_toArray (f : α β γ) (as : List α) (bs : List β) :
Array.zipWith as.toArray bs.toArray f = (List.zipWith f as bs).toArray := by
rw [Array.zipWith]
simp [zipWithAux_toArray_zero]
@@ -381,44 +381,6 @@ theorem zipWithAux_toArray_zero (f : α → β → γ) (as : List α) (bs : List
Array.zip as.toArray bs.toArray = (List.zip as bs).toArray := by
simp [Array.zip, zipWith_toArray, zip]
theorem zipWithAll_go_toArray (as : List α) (bs : List β) (f : Option α Option β γ) (i : Nat) (cs : Array γ) :
zipWithAll.go f as.toArray bs.toArray i cs = cs ++ (List.zipWithAll f (as.drop i) (bs.drop i)).toArray := by
unfold zipWithAll.go
split <;> rename_i h
· rw [zipWithAll_go_toArray]
simp at h
simp only [getElem?_toArray, push_append_toArray]
if ha : i < as.length then
if hb : i < bs.length then
rw [List.drop_eq_getElem_cons ha, List.drop_eq_getElem_cons hb]
simp only [ha, hb, getElem?_eq_getElem, zipWithAll_cons_cons]
else
simp only [Nat.not_lt] at hb
rw [List.drop_eq_getElem_cons ha]
rw [(drop_eq_nil_iff (l := bs)).mpr (by omega), (drop_eq_nil_iff (l := bs)).mpr (by omega)]
simp only [zipWithAll_nil, map_drop, map_cons]
rw [getElem?_eq_getElem ha]
rw [getElem?_eq_none hb]
else
if hb : i < bs.length then
simp only [Nat.not_lt] at ha
rw [List.drop_eq_getElem_cons hb]
rw [(drop_eq_nil_iff (l := as)).mpr (by omega), (drop_eq_nil_iff (l := as)).mpr (by omega)]
simp only [nil_zipWithAll, map_drop, map_cons]
rw [getElem?_eq_getElem hb]
rw [getElem?_eq_none ha]
else
omega
· simp only [size_toArray, Nat.not_lt] at h
rw [drop_eq_nil_of_le (by omega), drop_eq_nil_of_le (by omega)]
simp
termination_by max as.length bs.length - i
decreasing_by simp_wf; decreasing_trivial_pre_omega
@[simp] theorem zipWithAll_toArray (f : Option α Option β γ) (as : List α) (bs : List β) :
Array.zipWithAll as.toArray bs.toArray f = (List.zipWithAll f as bs).toArray := by
simp [Array.zipWithAll, zipWithAll_go_toArray]
end List
namespace Array
@@ -551,10 +513,10 @@ theorem getElem?_len_le (a : Array α) {i : Nat} (h : a.size ≤ i) : a[i]? = no
theorem getD_get? (a : Array α) (i : Nat) (d : α) :
Option.getD a[i]? d = if p : i < a.size then a[i]'p else d := by
if h : i < a.size then
simp [setIfInBounds, h, getElem?_def]
simp [setD, h, getElem?_def]
else
have p : i a.size := Nat.le_of_not_gt h
simp [setIfInBounds, getElem?_len_le _ p, h]
simp [setD, getElem?_len_le _ p, h]
@[simp] theorem getD_eq_get? (a : Array α) (n d) : a.getD n d = (a[n]?).getD d := by
simp only [getD, get_eq_getElem, get?_eq_getElem?]; split <;> simp [getD_get?, *]
@@ -590,32 +552,31 @@ theorem getElem_set (a : Array α) (i : Nat) (h' : i < a.size) (v : α) (j : Nat
(ne : i j) : (a.set i v)[j]? = a[j]? := by
by_cases h : j < a.size <;> simp [getElem?_lt, getElem?_ge, Nat.ge_of_not_lt, ne, h]
/-! # setIfInBounds -/
/-! # setD -/
@[simp] theorem set!_is_setIfInBounds : @set! = @setIfInBounds := rfl
@[simp] theorem set!_is_setD : @set! = @setD := rfl
@[simp] theorem size_setIfInBounds (a : Array α) (index : Nat) (val : α) :
(Array.setIfInBounds a index val).size = a.size := by
@[simp] theorem size_setD (a : Array α) (index : Nat) (val : α) :
(Array.setD a index val).size = a.size := by
if h : index < a.size then
simp [setIfInBounds, h]
simp [setD, h]
else
simp [setIfInBounds, h]
simp [setD, h]
@[simp] theorem getElem_setIfInBounds_eq (a : Array α) {i : Nat} (v : α) (h : _) :
(setIfInBounds a i v)[i]'h = v := by
@[simp] theorem getElem_setD_eq (a : Array α) {i : Nat} (v : α) (h : _) :
(setD a i v)[i]'h = v := by
simp at h
simp only [setIfInBounds, h, reduceDIte, getElem_set_eq]
simp only [setD, h, reduceDIte, getElem_set_eq]
@[simp]
theorem getElem?_setIfInBounds_eq (a : Array α) {i : Nat} (p : i < a.size) (v : α) :
(a.setIfInBounds i v)[i]? = some v := by
theorem getElem?_setD_eq (a : Array α) {i : Nat} (p : i < a.size) (v : α) : (a.setD i v)[i]? = some v := by
simp [getElem?_lt, p]
/-- Simplifies a normal form from `get!` -/
@[simp] theorem getD_get?_setIfInBounds (a : Array α) (i : Nat) (v d : α) :
Option.getD (setIfInBounds a i v)[i]? d = if i < a.size then v else d := by
@[simp] theorem getD_get?_setD (a : Array α) (i : Nat) (v d : α) :
Option.getD (setD a i v)[i]? d = if i < a.size then v else d := by
by_cases h : i < a.size <;>
simp [setIfInBounds, Nat.not_lt_of_le, h, getD_get?]
simp [setD, Nat.not_lt_of_le, h, getD_get?]
/-! # ofFn -/
@@ -660,19 +621,6 @@ theorem getElem?_ofFn (f : Fin n → α) (i : Nat) :
(ofFn f)[i]? = if h : i < n then some (f i, h) else none := by
simp [getElem?_def]
@[simp] theorem ofFn_zero (f : Fin 0 α) : ofFn f = #[] := rfl
theorem ofFn_succ (f : Fin (n+1) α) :
ofFn f = (ofFn (fun (i : Fin n) => f i.castSucc)).push (f n, by omega) := by
ext i h₁ h₂
· simp
· simp [getElem_push]
split <;> rename_i h₃
· rfl
· congr
simp at h₁ h₂
omega
/-! # mkArray -/
@[simp] theorem size_mkArray (n : Nat) (v : α) : (mkArray n v).size = n :=
@@ -807,32 +755,32 @@ theorem get_set (a : Array α) (i : Nat) (hi : i < a.size) (j : Nat) (hj : j < a
(h : i j) : (a.set i v)[j]'(by simp [*]) = a[j] := by
simp only [set, getElem_eq_getElem_toList, List.getElem_set_ne h]
theorem getElem_setIfInBounds (a : Array α) (i : Nat) (v : α) (h : i < (setIfInBounds a i v).size) :
(setIfInBounds a i v)[i] = v := by
theorem getElem_setD (a : Array α) (i : Nat) (v : α) (h : i < (setD a i v).size) :
(setD a i v)[i] = v := by
simp at h
simp only [setIfInBounds, h, reduceDIte, getElem_set_eq]
simp only [setD, h, reduceDIte, getElem_set_eq]
theorem set_set (a : Array α) (i : Nat) (h) (v v' : α) :
(a.set i v h).set i v' (by simp [h]) = a.set i v' := by simp [set, List.set_set]
private theorem fin_cast_val (e : n = n') (i : Fin n) : e i = i.1, e i.2 := by cases e; rfl
theorem swap_def (a : Array α) (i j : Nat) (hi hj) :
a.swap i j hi hj = (a.set i a[j]).set j a[i] (by simpa using hj) := by
theorem swap_def (a : Array α) (i j : Fin a.size) :
a.swap i j = (a.set i a[j]).set j a[i] := by
simp [swap, fin_cast_val]
@[simp] theorem toList_swap (a : Array α) (i j : Nat) (hi hj) :
(a.swap i j hi hj).toList = (a.toList.set i a[j]).set j a[i] := by simp [swap_def]
@[simp] theorem toList_swap (a : Array α) (i j : Fin a.size) :
(a.swap i j).toList = (a.toList.set i a[j]).set j a[i] := by simp [swap_def]
theorem getElem?_swap (a : Array α) (i j : Nat) (hi hj) (k : Nat) : (a.swap i j hi hj)[k]? =
if j = k then some a[i] else if i = k then some a[j] else a[k]? := by
theorem getElem?_swap (a : Array α) (i j : Fin a.size) (k : Nat) : (a.swap i j)[k]? =
if j = k then some a[i.1] else if i = k then some a[j.1] else a[k]? := by
simp [swap_def, get?_set, getElem_fin_eq_getElem_toList]
@[simp] theorem swapAt_def (a : Array α) (i : Nat) (v : α) (hi) :
a.swapAt i v hi = (a[i], a.set i v) := rfl
@[simp] theorem swapAt_def (a : Array α) (i : Fin a.size) (v : α) :
a.swapAt i v = (a[i.1], a.set i v) := rfl
@[simp] theorem size_swapAt (a : Array α) (i : Nat) (v : α) (hi) :
(a.swapAt i v hi).2.size = a.size := by simp [swapAt_def]
@[simp] theorem size_swapAt (a : Array α) (i : Fin a.size) (v : α) :
(a.swapAt i v).2.size = a.size := by simp [swapAt_def]
@[simp]
theorem swapAt!_def (a : Array α) (i : Nat) (v : α) (h : i < a.size) :
@@ -879,10 +827,8 @@ theorem eq_push_of_size_ne_zero {as : Array α} (h : as.size ≠ 0) :
theorem size_eq_length_toList (as : Array α) : as.size = as.toList.length := rfl
@[simp] theorem size_swapIfInBounds (a : Array α) (i j) :
(a.swapIfInBounds i j).size = a.size := by unfold swapIfInBounds; split <;> (try split) <;> simp [size_swap]
@[deprecated size_swapIfInBounds (since := "2024-11-24")] abbrev size_swap! := @size_swapIfInBounds
@[simp] theorem size_swap! (a : Array α) (i j) :
(a.swap! i j).size = a.size := by unfold swap!; split <;> (try split) <;> simp [size_swap]
@[simp] theorem size_reverse (a : Array α) : a.reverse.size = a.size := by
let rec go (as : Array α) (i j) : (reverse.loop as i j).size = as.size := by
@@ -894,10 +840,16 @@ theorem size_eq_length_toList (as : Array α) : as.size = as.toList.length := rf
simp only [reverse]; split <;> simp [go]
@[simp] theorem size_range {n : Nat} : (range n).size = n := by
induction n <;> simp [range]
unfold range
induction n with
| zero => simp [Nat.fold]
| succ k ih =>
rw [Nat.fold, flip]
simp only [mkEmpty_eq, size_push] at *
omega
@[simp] theorem toList_range (n : Nat) : (range n).toList = List.range n := by
apply List.ext_getElem <;> simp [range]
induction n <;> simp_all [range, Nat.fold, flip, List.range_succ]
@[simp]
theorem getElem_range {n : Nat} {x : Nat} (h : x < (Array.range n).size) : (Array.range n)[x] = x := by
@@ -1093,34 +1045,6 @@ theorem foldr_congr {as bs : Array α} (h₀ : as = bs) {f g : α → β → β}
as.foldr f a start stop = bs.foldr g b start' stop' := by
congr
theorem foldl_eq_foldlM (f : β α β) (b) (l : Array α) :
l.foldl f b = l.foldlM (m := Id) f b := by
cases l
simp [List.foldl_eq_foldlM]
theorem foldr_eq_foldrM (f : α β β) (b) (l : Array α) :
l.foldr f b = l.foldrM (m := Id) f b := by
cases l
simp [List.foldr_eq_foldrM]
@[simp] theorem id_run_foldlM (f : β α Id β) (b) (l : Array α) :
Id.run (l.foldlM f b) = l.foldl f b := (foldl_eq_foldlM f b l).symm
@[simp] theorem id_run_foldrM (f : α β Id β) (b) (l : Array α) :
Id.run (l.foldrM f b) = l.foldr f b := (foldr_eq_foldrM f b l).symm
theorem foldl_hom (f : α₁ α₂) (g₁ : α₁ β α₁) (g₂ : α₂ β α₂) (l : Array β) (init : α₁)
(H : x y, g₂ (f x) y = f (g₁ x y)) : l.foldl g₂ (f init) = f (l.foldl g₁ init) := by
cases l
simp
rw [List.foldl_hom _ _ _ _ _ H]
theorem foldr_hom (f : β₁ β₂) (g₁ : α β₁ β₁) (g₂ : α β₂ β₂) (l : Array α) (init : β₁)
(H : x y, g₂ x (f y) = f (g₁ x y)) : l.foldr g₂ (f init) = f (l.foldr g₁ init) := by
cases l
simp
rw [List.foldr_hom _ _ _ _ _ H]
/-! ### map -/
@[simp] theorem mem_map {f : α β} {l : Array α} : b l.map f a, a l f a = b := by
@@ -1672,30 +1596,28 @@ instance [DecidableEq α] (a : α) (as : Array α) : Decidable (a ∈ as) :=
open Fin
@[simp] theorem getElem_swap_right (a : Array α) {i j : Nat} {hi hj} :
(a.swap i j hi hj)[j]'(by simpa using hj) = a[i] := by
@[simp] theorem getElem_swap_right (a : Array α) {i j : Fin a.size} : (a.swap i j)[j.1] = a[i] := by
simp [swap_def, getElem_set]
@[simp] theorem getElem_swap_left (a : Array α) {i j : Nat} {hi hj} :
(a.swap i j hi hj)[i]'(by simpa using hi) = a[j] := by
@[simp] theorem getElem_swap_left (a : Array α) {i j : Fin a.size} : (a.swap i j)[i.1] = a[j] := by
simp +contextual [swap_def, getElem_set]
@[simp] theorem getElem_swap_of_ne (a : Array α) {i j : Nat} {hi hj} (hp : p < a.size)
(hi' : p i) (hj' : p j) : (a.swap i j hi hj)[p]'(a.size_swap .. |>.symm hp) = a[p] := by
simp [swap_def, getElem_set, hi'.symm, hj'.symm]
@[simp] theorem getElem_swap_of_ne (a : Array α) {i j : Fin a.size} (hp : p < a.size)
(hi : p i) (hj : p j) : (a.swap i j)[p]'(a.size_swap .. |>.symm hp) = a[p] := by
simp [swap_def, getElem_set, hi.symm, hj.symm]
theorem getElem_swap' (a : Array α) (i j : Nat) {hi hj} (k : Nat) (hk : k < a.size) :
(a.swap i j hi hj)[k]'(by simp_all) = if k = i then a[j] else if k = j then a[i] else a[k] := by
theorem getElem_swap' (a : Array α) (i j : Fin a.size) (k : Nat) (hk : k < a.size) :
(a.swap i j)[k]'(by simp_all) = if k = i then a[j] else if k = j then a[i] else a[k] := by
split
· simp_all only [getElem_swap_left]
· split <;> simp_all
theorem getElem_swap (a : Array α) (i j : Nat) {hi hj}(k : Nat) (hk : k < (a.swap i j).size) :
(a.swap i j hi hj)[k] = if k = i then a[j] else if k = j then a[i] else a[k]'(by simp_all) := by
theorem getElem_swap (a : Array α) (i j : Fin a.size) (k : Nat) (hk : k < (a.swap i j).size) :
(a.swap i j)[k] = if k = i then a[j] else if k = j then a[i] else a[k]'(by simp_all) := by
apply getElem_swap'
@[simp] theorem swap_swap (a : Array α) {i j : Nat} (hi hj) :
(a.swap i j hi hj).swap i j ((a.size_swap ..).symm hi) ((a.size_swap ..).symm hj) = a := by
@[simp] theorem swap_swap (a : Array α) {i j : Fin a.size} :
(a.swap i j).swap i.1, (a.size_swap ..).symm i.2 j.1, (a.size_swap ..).symm j.2 = a := by
apply ext
· simp only [size_swap]
· intros
@@ -1704,7 +1626,7 @@ theorem getElem_swap (a : Array α) (i j : Nat) {hi hj}(k : Nat) (hk : k < (a.sw
· simp_all
· split <;> simp_all
theorem swap_comm (a : Array α) {i j : Nat} {hi hj} : a.swap i j hi hj = a.swap j i hj hi := by
theorem swap_comm (a : Array α) {i j : Fin a.size} : a.swap i j = a.swap j i := by
apply ext
· simp only [size_swap]
· intros
@@ -1739,20 +1661,6 @@ theorem eraseIdx_eq_eraseIdxIfInBounds {a : Array α} {i : Nat} (h : i < a.size)
(Array.zip as bs).toList = List.zip as.toList bs.toList := by
simp [zip, toList_zipWith, List.zip]
@[simp] theorem toList_zipWithAll (f : Option α Option β γ) (as : Array α) (bs : Array β) :
(Array.zipWithAll as bs f).toList = List.zipWithAll f as.toList bs.toList := by
cases as
cases bs
simp
@[simp] theorem size_zipWith (as : Array α) (bs : Array β) (f : α β γ) :
(as.zipWith bs f).size = min as.size bs.size := by
rw [size_eq_length_toList, toList_zipWith, List.length_zipWith]
@[simp] theorem size_zip (as : Array α) (bs : Array β) :
(as.zip bs).size = min as.size bs.size :=
as.size_zipWith bs Prod.mk
/-! ### findSomeM?, findM?, findSome?, find? -/
@[simp] theorem findSomeM?_toList [Monad m] [LawfulMonad m] (p : α m (Option β)) (as : Array α) :
@@ -1827,10 +1735,10 @@ Our goal is to have `simp` "pull `List.toArray` outwards" as much as possible.
apply ext'
simp
@[simp] theorem setIfInBounds_toArray (l : List α) (i : Nat) (a : α) :
l.toArray.setIfInBounds i a = (l.set i a).toArray := by
@[simp] theorem setD_toArray (l : List α) (i : Nat) (a : α) :
l.toArray.setD i a = (l.set i a).toArray := by
apply ext'
simp only [setIfInBounds]
simp only [setD]
split
· simp
· simp_all [List.set_eq_of_length_le]
@@ -1875,8 +1783,8 @@ theorem all_toArray (p : α → Bool) (l : List α) : l.toArray.all p = l.all p
subst h
rw [all_toList]
@[simp] theorem swap_toArray (l : List α) (i j : Nat) {hi hj}:
l.toArray.swap i j hi hj = ((l.set i l[j]).set j l[i]).toArray := by
@[simp] theorem swap_toArray (l : List α) (i j : Fin l.toArray.size) :
l.toArray.swap i j = ((l.set i l[j]).set j l[i]).toArray := by
apply ext'
simp
@@ -2060,20 +1968,6 @@ theorem foldr_filterMap (f : α → Option β) (g : β → γγ) (l : Array
simp [List.foldr_filterMap]
rfl
theorem foldl_map' (g : α β) (f : α α α) (f' : β β β) (a : α) (l : Array α)
(h : x y, f' (g x) (g y) = g (f x y)) :
(l.map g).foldl f' (g a) = g (l.foldl f a) := by
cases l
simp
rw [List.foldl_map' _ _ _ _ _ h]
theorem foldr_map' (g : α β) (f : α α α) (f' : β β β) (a : α) (l : List α)
(h : x y, f' (g x) (g y) = g (f x y)) :
(l.map g).foldr f' (g a) = g (l.foldr f a) := by
cases l
simp
rw [List.foldr_map' _ _ _ _ _ h]
/-! ### flatten -/
@[simp] theorem flatten_empty : flatten (#[] : Array (Array α)) = #[] := rfl
@@ -2196,8 +2090,6 @@ theorem toArray_concat {as : List α} {x : α} :
@[deprecated back!_toArray (since := "2024-10-31")] abbrev back_toArray := @back!_toArray
@[deprecated setIfInBounds_toArray (since := "2024-11-24")] abbrev setD_toArray := @setIfInBounds_toArray
end List
namespace Array
@@ -2343,11 +2235,4 @@ abbrev get_swap' := @getElem_swap'
@[deprecated eq_push_pop_back!_of_size_ne_zero (since := "2024-10-31")]
abbrev eq_push_pop_back_of_size_ne_zero := @eq_push_pop_back!_of_size_ne_zero
@[deprecated set!_is_setIfInBounds (since := "2024-11-24")] abbrev set_is_setIfInBounds := @set!_is_setIfInBounds
@[deprecated size_setIfInBounds (since := "2024-11-24")] abbrev size_setD := @size_setIfInBounds
@[deprecated getElem_setIfInBounds_eq (since := "2024-11-24")] abbrev getElem_setD_eq := @getElem_setIfInBounds_eq
@[deprecated getElem?_setIfInBounds_eq (since := "2024-11-24")] abbrev get?_setD_eq := @getElem?_setIfInBounds_eq
@[deprecated getD_get?_setIfInBounds (since := "2024-11-24")] abbrev getD_setD := @getD_get?_setIfInBounds
@[deprecated getElem_setIfInBounds (since := "2024-11-24")] abbrev getElem_setD := @getElem_setIfInBounds
end Array

View File

@@ -13,19 +13,19 @@ namespace Array
def qpartition (as : Array α) (lt : α α Bool) (lo hi : Nat) : Nat × Array α :=
if h : as.size = 0 then (0, as) else have : Inhabited α := as[0]'(by revert h; cases as.size <;> simp) -- TODO: remove
let mid := (lo + hi) / 2
let as := if lt (as.get! mid) (as.get! lo) then as.swapIfInBounds lo mid else as
let as := if lt (as.get! hi) (as.get! lo) then as.swapIfInBounds lo hi else as
let as := if lt (as.get! mid) (as.get! hi) then as.swapIfInBounds mid hi else as
let as := if lt (as.get! mid) (as.get! lo) then as.swap! lo mid else as
let as := if lt (as.get! hi) (as.get! lo) then as.swap! lo hi else as
let as := if lt (as.get! mid) (as.get! hi) then as.swap! mid hi else as
let pivot := as.get! hi
let rec loop (as : Array α) (i j : Nat) :=
if h : j < hi then
if lt (as.get! j) pivot then
let as := as.swapIfInBounds i j
let as := as.swap! i j
loop as (i+1) (j+1)
else
loop as i (j+1)
else
let as := as.swapIfInBounds i hi
let as := as.swap! i hi
(i, as)
termination_by hi - j
decreasing_by all_goals simp_wf; decreasing_trivial_pre_omega

View File

@@ -25,11 +25,9 @@ Set an element in an array, or do nothing if the index is out of bounds.
This will perform the update destructively provided that `a` has a reference
count of 1 when called.
-/
@[inline] def Array.setIfInBounds (a : Array α) (i : Nat) (v : α) : Array α :=
@[inline] def Array.setD (a : Array α) (i : Nat) (v : α) : Array α :=
dite (LT.lt i a.size) (fun h => a.set i v h) (fun _ => a)
@[deprecated Array.setIfInBounds (since := "2024-11-24")] abbrev Array.setD := @Array.setIfInBounds
/--
Set an element in an array, or panic if the index is out of bounds.
@@ -38,4 +36,4 @@ count of 1 when called.
-/
@[extern "lean_array_set"]
def Array.set! (a : Array α) (i : @& Nat) (v : α) : Array α :=
Array.setIfInBounds a i v
Array.setD a i v

View File

@@ -23,13 +23,16 @@ def split (s : Subarray α) (i : Fin s.size.succ) : (Subarray α × Subarray α)
let i', isLt := i
have := s.start_le_stop
have := s.stop_le_array_size
have : i' s.stop - s.start := Nat.lt_succ.mp isLt
have : s.start + i' s.stop := by omega
have : s.start + i' s.array.size := by omega
have : s.start + i' s.stop := by
simp only [size] at isLt
omega
let pre := {s with
stop := s.start + i',
start_le_stop := by omega,
stop_le_array_size := by omega
stop_le_array_size := by assumption
}
let post := {s with
start := s.start + i'
@@ -45,7 +48,9 @@ def drop (arr : Subarray α) (i : Nat) : Subarray α where
array := arr.array
start := min (arr.start + i) arr.stop
stop := arr.stop
start_le_stop := by omega
start_le_stop := by
rw [Nat.min_def]
split <;> simp only [Nat.le_refl, *]
stop_le_array_size := arr.stop_le_array_size
/--
@@ -58,7 +63,9 @@ def take (arr : Subarray α) (i : Nat) : Subarray α where
stop := min (arr.start + i) arr.stop
start_le_stop := by
have := arr.start_le_stop
omega
rw [Nat.min_def]
split <;> omega
stop_le_array_size := by
have := arr.stop_le_array_size
omega
rw [Nat.min_def]
split <;> omega

View File

@@ -346,10 +346,6 @@ theorem getMsbD_sub {i : Nat} {i_lt : i < w} {x y : BitVec w} :
· rfl
· omega
theorem getElem_sub {i : Nat} {x y : BitVec w} (h : i < w) :
(x - y)[i] = (x[i] ^^ ((~~~y + 1#w)[i] ^^ carry i x (~~~y + 1#w) false)) := by
simp [ getLsbD_eq_getElem, getLsbD_sub, h]
theorem msb_sub {x y: BitVec w} :
(x - y).msb
= (x.msb ^^ ((~~~y + 1#w).msb ^^ carry (w - 1 - 0) x (~~~y + 1#w) false)) := by
@@ -414,10 +410,6 @@ theorem getLsbD_neg {i : Nat} {x : BitVec w} :
· have h_ge : w i := by omega
simp [getLsbD_ge _ _ h_ge, h_ge, hi]
theorem getElem_neg {i : Nat} {x : BitVec w} (h : i < w) :
(-x)[i] = (x[i] ^^ decide ( j < i, x.getLsbD j = true)) := by
simp [ getLsbD_eq_getElem, getLsbD_neg, h]
theorem getMsbD_neg {i : Nat} {x : BitVec w} :
getMsbD (-x) i =
(getMsbD x i ^^ decide ( j < w, i < j getMsbD x j = true)) := by

View File

@@ -269,10 +269,6 @@ theorem ofBool_eq_iff_eq : ∀ {b b' : Bool}, BitVec.ofBool b = BitVec.ofBool b'
getLsbD (x#'lt) i = x.testBit i := by
simp [getLsbD, BitVec.ofNatLt]
@[simp] theorem getMsbD_ofNatLt {n x i : Nat} (h : x < 2^n) :
getMsbD (x#'h) i = (decide (i < n) && x.testBit (n - 1 - i)) := by
simp [getMsbD, getLsbD]
@[simp, bv_toNat] theorem toNat_ofNat (x w : Nat) : (BitVec.ofNat w x).toNat = x % 2^w := by
simp [BitVec.toNat, BitVec.ofNat, Fin.ofNat']
@@ -565,10 +561,6 @@ theorem zeroExtend_eq_setWidth {v : Nat} {x : BitVec w} :
else
simp [n_le_i, toNat_ofNat]
@[simp] theorem toInt_setWidth (x : BitVec w) :
(x.setWidth v).toInt = Int.bmod x.toNat (2^v) := by
simp [toInt_eq_toNat_bmod, toNat_setWidth, Int.emod_bmod]
theorem setWidth'_eq {x : BitVec w} (h : w v) : x.setWidth' h = x.setWidth v := by
apply eq_of_toNat_eq
rw [toNat_setWidth, toNat_setWidth']
@@ -763,10 +755,6 @@ theorem extractLsb'_eq_extractLsb {w : Nat} (x : BitVec w) (start len : Nat) (h
@[simp] theorem getLsbD_allOnes : (allOnes v).getLsbD i = decide (i < v) := by
simp [allOnes]
@[simp] theorem getMsbD_allOnes : (allOnes v).getMsbD i = decide (i < v) := by
simp [allOnes]
omega
@[simp] theorem getElem_allOnes (i : Nat) (h : i < v) : (allOnes v)[i] = true := by
simp [getElem_eq_testBit_toNat, h]
@@ -784,12 +772,6 @@ theorem extractLsb'_eq_extractLsb {w : Nat} (x : BitVec w) (start len : Nat) (h
@[simp] theorem toNat_or (x y : BitVec v) :
BitVec.toNat (x ||| y) = BitVec.toNat x ||| BitVec.toNat y := rfl
@[simp] theorem toInt_or (x y : BitVec w) :
BitVec.toInt (x ||| y) = Int.bmod (BitVec.toNat x ||| BitVec.toNat y) (2^w) := by
rw_mod_cast [Int.bmod_def, BitVec.toInt, toNat_or, Nat.mod_eq_of_lt
(Nat.or_lt_two_pow (BitVec.isLt x) (BitVec.isLt y))]
omega
@[simp] theorem toFin_or (x y : BitVec v) :
BitVec.toFin (x ||| y) = BitVec.toFin x ||| BitVec.toFin y := by
apply Fin.eq_of_val_eq
@@ -857,12 +839,6 @@ instance : Std.LawfulCommIdentity (α := BitVec n) (· ||| · ) (0#n) where
@[simp] theorem toNat_and (x y : BitVec v) :
BitVec.toNat (x &&& y) = BitVec.toNat x &&& BitVec.toNat y := rfl
@[simp] theorem toInt_and (x y : BitVec w) :
BitVec.toInt (x &&& y) = Int.bmod (BitVec.toNat x &&& BitVec.toNat y) (2^w) := by
rw_mod_cast [Int.bmod_def, BitVec.toInt, toNat_and, Nat.mod_eq_of_lt
(Nat.and_lt_two_pow x.toNat (BitVec.isLt y))]
omega
@[simp] theorem toFin_and (x y : BitVec v) :
BitVec.toFin (x &&& y) = BitVec.toFin x &&& BitVec.toFin y := by
apply Fin.eq_of_val_eq
@@ -930,12 +906,6 @@ instance : Std.LawfulCommIdentity (α := BitVec n) (· &&& · ) (allOnes n) wher
@[simp] theorem toNat_xor (x y : BitVec v) :
BitVec.toNat (x ^^^ y) = BitVec.toNat x ^^^ BitVec.toNat y := rfl
@[simp] theorem toInt_xor (x y : BitVec w) :
BitVec.toInt (x ^^^ y) = Int.bmod (BitVec.toNat x ^^^ BitVec.toNat y) (2^w) := by
rw_mod_cast [Int.bmod_def, BitVec.toInt, toNat_xor, Nat.mod_eq_of_lt
(Nat.xor_lt_two_pow (BitVec.isLt x) (BitVec.isLt y))]
omega
@[simp] theorem toFin_xor (x y : BitVec v) :
BitVec.toFin (x ^^^ y) = BitVec.toFin x ^^^ BitVec.toFin y := by
apply Fin.eq_of_val_eq
@@ -1013,13 +983,6 @@ theorem not_def {x : BitVec v} : ~~~x = allOnes v ^^^ x := rfl
_ 2 ^ i := Nat.pow_le_pow_of_le_right Nat.zero_lt_two w
· simp
@[simp] theorem toInt_not {x : BitVec w} :
(~~~x).toInt = Int.bmod (2^w - 1 - x.toNat) (2^w) := by
rw_mod_cast [BitVec.toInt, BitVec.toNat_not, Int.bmod_def]
simp [show ((2^w : Nat) : Int) - 1 - x.toNat = ((2^w - 1 - x.toNat) : Nat) by omega]
rw_mod_cast [Nat.mod_eq_of_lt (by omega)]
omega
@[simp] theorem ofInt_negSucc_eq_not_ofNat {w n : Nat} :
BitVec.ofInt w (Int.negSucc n) = ~~~.ofNat w n := by
simp only [BitVec.ofInt, Int.toNat, Int.ofNat_eq_coe, toNat_eq, toNat_ofNatLt, toNat_not,
@@ -1044,10 +1007,6 @@ theorem not_def {x : BitVec v} : ~~~x = allOnes v ^^^ x := rfl
@[simp] theorem getLsbD_not {x : BitVec v} : (~~~x).getLsbD i = (decide (i < v) && ! x.getLsbD i) := by
by_cases h' : i < v <;> simp_all [not_def]
@[simp] theorem getMsbD_not {x : BitVec v} :
(~~~x).getMsbD i = (decide (i < v) && ! x.getMsbD i) := by
by_cases h' : i < v <;> simp_all [not_def]
@[simp] theorem getElem_not {x : BitVec w} {i : Nat} (h : i < w) : (~~~x)[i] = !x[i] := by
simp only [getElem_eq_testBit_toNat, toNat_not]
rw [ Nat.sub_add_eq, Nat.add_comm 1]
@@ -1521,12 +1480,6 @@ theorem getLsbD_sshiftRight' {x y: BitVec w} {i : Nat} :
(!decide (w i) && if y.toNat + i < w then x.getLsbD (y.toNat + i) else x.msb) := by
simp only [BitVec.sshiftRight', BitVec.getLsbD_sshiftRight]
@[simp]
theorem getElem_sshiftRight' {x y : BitVec w} {i : Nat} (h : i < w) :
(x.sshiftRight' y)[i] =
(!decide (w i) && if y.toNat + i < w then x.getLsbD (y.toNat + i) else x.msb) := by
simp only [ getLsbD_eq_getElem, BitVec.sshiftRight', BitVec.getLsbD_sshiftRight]
@[simp]
theorem getMsbD_sshiftRight' {x y: BitVec w} {i : Nat} :
(x.sshiftRight y.toNat).getMsbD i = (decide (i < w) && if i < y.toNat then x.msb else x.getMsbD (i - y.toNat)) := by
@@ -1619,79 +1572,6 @@ theorem signExtend_eq_setWidth_of_lt (x : BitVec w) {v : Nat} (hv : v ≤ w):
theorem signExtend_eq (x : BitVec w) : x.signExtend w = x := by
rw [signExtend_eq_setWidth_of_lt _ (Nat.le_refl _), setWidth_eq]
/-- Sign extending to a larger bitwidth depends on the msb.
If the msb is false, then the result equals the original value.
If the msb is true, then we add a value of `(2^v - 2^w)`, which arises from the sign extension. -/
theorem toNat_signExtend_of_le (x : BitVec w) {v : Nat} (hv : w v) :
(x.signExtend v).toNat = x.toNat + if x.msb then 2^v - 2^w else 0 := by
apply Nat.eq_of_testBit_eq
intro i
have k, hk := Nat.exists_eq_add_of_le hv
rw [hk, testBit_toNat, getLsbD_signExtend, Nat.pow_add, Nat.mul_sub_one, Nat.add_comm (x.toNat)]
by_cases hx : x.msb
· simp [hx, Nat.testBit_mul_pow_two_add _ x.isLt, testBit_toNat]
-- Case analysis on i being in the intervals [0..w), [w..w + k), [w+k..∞)
have hi : i < w (w i i < w + k) w + k i := by omega
rcases hi with hi | hi | hi
· simp [hi]; omega
· simp [hi]; omega
· simp [hi, show ¬ (i < w + k) by omega, show ¬ (i < w) by omega]
omega
· simp [hx, Nat.testBit_mul_pow_two_add _ x.isLt, testBit_toNat]
have hi : i < w (w i i < w + k) w + k i := by omega
rcases hi with hi | hi | hi
· simp [hi]; omega
· simp [hi]
· simp [hi, show ¬ (i < w + k) by omega, show ¬ (i < w) by omega, getLsbD_ge x i (by omega)]
/-- Sign extending to a larger bitwidth depends on the msb.
If the msb is false, then the result equals the original value.
If the msb is true, then we add a value of `(2^v - 2^w)`, which arises from the sign extension. -/
theorem toNat_signExtend (x : BitVec w) {v : Nat} :
(x.signExtend v).toNat = (x.setWidth v).toNat + if x.msb then 2^v - 2^w else 0 := by
by_cases h : v w
· have : 2^v 2^w := Nat.pow_le_pow_of_le_right Nat.two_pos h
simp [signExtend_eq_setWidth_of_lt x h, toNat_setWidth, Nat.sub_eq_zero_of_le this]
· have : 2^w 2^v := Nat.pow_le_pow_of_le_right Nat.two_pos (by omega)
rw [toNat_signExtend_of_le x (by omega), toNat_setWidth, Nat.mod_eq_of_lt (by omega)]
/-
If the current width `w` is smaller than the extended width `v`,
then the value when interpreted as an integer does not change.
-/
theorem toInt_signExtend_of_lt {x : BitVec w} (hv : w < v):
(x.signExtend v).toInt = x.toInt := by
simp only [toInt_eq_msb_cond, toNat_signExtend]
have : (x.signExtend v).msb = x.msb := by
rw [msb_eq_getLsbD_last, getLsbD_eq_getElem (Nat.sub_one_lt_of_lt hv)]
simp [getElem_signExtend, Nat.le_sub_one_of_lt hv]
have H : 2^w 2^v := Nat.pow_le_pow_of_le_right (by omega) (by omega)
simp only [this, toNat_setWidth, Int.natCast_add, Int.ofNat_emod, Int.natCast_mul]
by_cases h : x.msb
<;> norm_cast
<;> simp [h, Nat.mod_eq_of_lt (Nat.lt_of_lt_of_le x.isLt H)]
omega
/-
If the current width `w` is larger than the extended width `v`,
then the value when interpreted as an integer is truncated,
and we compute a modulo by `2^v`.
-/
theorem toInt_signExtend_of_le {x : BitVec w} (hv : v w) :
(x.signExtend v).toInt = Int.bmod x.toNat (2^v) := by
simp [signExtend_eq_setWidth_of_lt _ hv]
/-
Interpreting the sign extension of `(x : BitVec w)` to width `v`
computes `x % 2^v` (where `%` is the balanced mod).
-/
theorem toInt_signExtend (x : BitVec w) :
(x.signExtend v).toInt = Int.bmod x.toNat (2^(min v w)) := by
by_cases hv : v w
· simp [toInt_signExtend_of_le hv, Nat.min_eq_left hv]
· simp only [Nat.not_le] at hv
rw [toInt_signExtend_of_lt hv, Nat.min_eq_right (by omega), toInt_eq_toNat_bmod]
/-! ### append -/
theorem append_def (x : BitVec v) (y : BitVec w) :
@@ -3344,11 +3224,7 @@ theorem toNat_abs {x : BitVec w} : x.abs.toNat = if x.msb then 2^w - x.toNat els
· simp [h]
theorem getLsbD_abs {i : Nat} {x : BitVec w} :
getLsbD x.abs i = if x.msb then getLsbD (-x) i else getLsbD x i := by
by_cases h : x.msb <;> simp [BitVec.abs, h]
theorem getElem_abs {i : Nat} {x : BitVec w} (h : i < w) :
x.abs[i] = if x.msb then (-x)[i] else x[i] := by
getLsbD x.abs i = if x.msb then getLsbD (-x) i else getLsbD x i := by
by_cases h : x.msb <;> simp [BitVec.abs, h]
theorem getMsbD_abs {i : Nat} {x : BitVec w} :

View File

@@ -108,18 +108,8 @@ def toList (bs : ByteArray) : List UInt8 :=
@[inline] def findIdx? (a : ByteArray) (p : UInt8 Bool) (start := 0) : Option Nat :=
let rec @[specialize] loop (i : Nat) :=
if h : i < a.size then
if p a[i] then some i else loop (i+1)
else
none
termination_by a.size - i
decreasing_by decreasing_trivial_pre_omega
loop start
@[inline] def findFinIdx? (a : ByteArray) (p : UInt8 Bool) (start := 0) : Option (Fin a.size) :=
let rec @[specialize] loop (i : Nat) :=
if h : i < a.size then
if p a[i] then some i, h else loop (i+1)
if i < a.size then
if p (a.get! i) then some i else loop (i+1)
else
none
termination_by a.size - i

View File

@@ -13,17 +13,17 @@ namespace Fin
/-- Folds over `Fin n` from the left: `foldl 3 f x = f (f (f x 0) 1) 2`. -/
@[inline] def foldl (n) (f : α Fin n α) (init : α) : α := loop init 0 where
/-- Inner loop for `Fin.foldl`. `Fin.foldl.loop n f x i = f (f (f x i) ...) (n-1)` -/
@[semireducible] loop (x : α) (i : Nat) : α :=
loop (x : α) (i : Nat) : α :=
if h : i < n then loop (f x i, h) (i+1) else x
termination_by n - i
decreasing_by decreasing_trivial_pre_omega
/-- Folds over `Fin n` from the right: `foldr 3 f x = f 0 (f 1 (f 2 x))`. -/
@[inline] def foldr (n) (f : Fin n α α) (init : α) : α := loop n (Nat.le_refl n) init where
@[inline] def foldr (n) (f : Fin n α α) (init : α) : α := loop n, Nat.le_refl n init where
/-- Inner loop for `Fin.foldr`. `Fin.foldr.loop n f i x = f 0 (f ... (f (i-1) x))` -/
loop : (i : _) i n α α
| 0, _, x => x
| i+1, h, x => loop i (Nat.le_of_lt h) (f i, h x)
termination_by structural i => i
loop : {i // i n} α α
| 0, _, x => x
| i+1, h, x => loop i, Nat.le_of_lt h (f i, h x)
/--
Folds a monadic function over `Fin n` from left to right:
@@ -176,19 +176,17 @@ theorem foldl_eq_foldlM (f : α → Fin n → α) (x) :
/-! ### foldr -/
theorem foldr_loop_zero (f : Fin n α α) (x) :
foldr.loop n f 0 (Nat.zero_le _) x = x := by
foldr.loop n f 0, Nat.zero_le _ x = x := by
rw [foldr.loop]
theorem foldr_loop_succ (f : Fin n α α) (x) (h : i < n) :
foldr.loop n f (i+1) h x = foldr.loop n f i (Nat.le_of_lt h) (f i, h x) := by
foldr.loop n f i+1, h x = foldr.loop n f i, Nat.le_of_lt h (f i, h x) := by
rw [foldr.loop]
theorem foldr_loop (f : Fin (n+1) α α) (x) (h : i+1 n+1) :
foldr.loop (n+1) f (i+1) h x =
f 0 (foldr.loop n (fun j => f j.succ) i (Nat.le_of_succ_le_succ h) x) := by
induction i generalizing x with
| zero => simp [foldr_loop_succ, foldr_loop_zero]
| succ i ih => rw [foldr_loop_succ, ih]; rfl
foldr.loop (n+1) f i+1, h x =
f 0 (foldr.loop n (fun j => f j.succ) i, Nat.le_of_succ_le_succ h x) := by
induction i generalizing x <;> simp [foldr_loop_zero, foldr_loop_succ, *]
@[simp] theorem foldr_zero (f : Fin 0 α α) (x) : foldr 0 f x = x :=
foldr_loop_zero ..

View File

@@ -231,7 +231,7 @@ theorem ext_get? : ∀ {l₁ l₂ : List α}, (∀ n, l₁.get? n = l₂.get? n)
injection h0 with aa; simp only [aa, ext_get? fun n => h (n+1)]
/-- Deprecated alias for `ext_get?`. The preferred extensionality theorem is now `ext_getElem?`. -/
@[deprecated ext_get? (since := "2024-06-07")] abbrev ext := @ext_get?
@[deprecated (since := "2024-06-07")] abbrev ext := @ext_get?
/-! ### getD -/
@@ -682,7 +682,7 @@ theorem elem_cons [BEq α] {a : α} :
(b::bs).elem a = match a == b with | true => true | false => bs.elem a := rfl
/-- `notElem a l` is `!(elem a l)`. -/
@[deprecated "Use `!(elem a l)` instead."(since := "2024-06-15")]
@[deprecated (since := "2024-06-15")]
def notElem [BEq α] (a : α) (as : List α) : Bool :=
!(as.elem a)
@@ -1427,10 +1427,10 @@ def zipWithAll (f : Option α → Option β → γ) : List α → List β → Li
| a :: as, [] => (a :: as).map fun a => f (some a) none
| a :: as, b :: bs => f a b :: zipWithAll f as bs
@[simp] theorem zipWithAll_nil :
@[simp] theorem zipWithAll_nil_right :
zipWithAll f as [] = as.map fun a => f (some a) none := by
cases as <;> rfl
@[simp] theorem nil_zipWithAll :
@[simp] theorem zipWithAll_nil_left :
zipWithAll f [] bs = bs.map fun b => f none (some b) := rfl
@[simp] theorem zipWithAll_cons_cons :
zipWithAll f (a :: as) (b :: bs) = f (some a) (some b) :: zipWithAll f as bs := rfl

View File

@@ -101,7 +101,7 @@ theorem tail_eq_of_cons_eq (H : h₁ :: t₁ = h₂ :: t₂) : t₁ = t₂ := (c
theorem cons_inj_right (a : α) {l l' : List α} : a :: l = a :: l' l = l' :=
tail_eq_of_cons_eq, congrArg _
@[deprecated cons_inj_right (since := "2024-06-15")] abbrev cons_inj := @cons_inj_right
@[deprecated (since := "2024-06-15")] abbrev cons_inj := @cons_inj_right
theorem cons_eq_cons {a b : α} {l l' : List α} : a :: l = b :: l' a = b l = l' :=
List.cons.injEq .. .rfl
@@ -171,7 +171,7 @@ theorem get_cons_succ {as : List α} {h : i + 1 < (a :: as).length} :
theorem get_cons_succ' {as : List α} {i : Fin as.length} :
(a :: as).get i.succ = as.get i := rfl
@[deprecated "Deprecated without replacement." (since := "2024-07-09")]
@[deprecated (since := "2024-07-09")]
theorem get_cons_cons_one : (a₁ :: a₂ :: as).get (1 : Fin (as.length + 2)) = a₂ := rfl
theorem get_mk_zero : {l : List α} (h : 0 < l.length), l.get 0, h = l.head (length_pos.mp h)
@@ -791,24 +791,6 @@ theorem mem_or_eq_of_mem_set : ∀ {l : List α} {n : Nat} {a b : α}, a ∈ l.s
· intro a
simp
@[simp] theorem beq_nil_iff [BEq α] {l : List α} : (l == []) = l.isEmpty := by
cases l <;> rfl
@[simp] theorem nil_beq_iff [BEq α] {l : List α} : ([] == l) = l.isEmpty := by
cases l <;> rfl
@[simp] theorem cons_beq_cons [BEq α] {a b : α} {l₁ l₂ : List α} :
(a :: l₁ == b :: l₂) = (a == b && l₁ == l₂) := rfl
theorem length_eq_of_beq [BEq α] {l₁ l₂ : List α} (h : l₁ == l₂) : l₁.length = l₂.length :=
match l₁, l₂ with
| [], [] => rfl
| [], _ :: _ => by simp [beq_nil_iff] at h
| _ :: _, [] => by simp [nil_beq_iff] at h
| a :: l₁, b :: l₂ => by
simp at h
simpa [Nat.add_one_inj]using length_eq_of_beq h.2
/-! ### Lexicographic ordering -/
protected theorem lt_irrefl [LT α] (lt_irrefl : x : α, ¬x < x) (l : List α) : ¬l < l := by
@@ -874,12 +856,6 @@ theorem foldr_eq_foldrM (f : α → β → β) (b) (l : List α) :
l.foldr f b = l.foldrM (m := Id) f b := by
induction l <;> simp [*, foldr]
@[simp] theorem id_run_foldlM (f : β α Id β) (b) (l : List α) :
Id.run (l.foldlM f b) = l.foldl f b := (foldl_eq_foldlM f b l).symm
@[simp] theorem id_run_foldrM (f : α β Id β) (b) (l : List α) :
Id.run (l.foldrM f b) = l.foldr f b := (foldr_eq_foldrM f b l).symm
/-! ### foldl and foldr -/
@[simp] theorem foldr_cons_eq_append (l : List α) : l.foldr cons l' = l ++ l' := by
@@ -1824,7 +1800,7 @@ theorem getElem_append_right' (l₁ : List α) {l₂ : List α} {n : Nat} (hn :
l₂[n] = (l₁ ++ l₂)[n + l₁.length]'(by simpa [Nat.add_comm] using Nat.add_lt_add_left hn _) := by
rw [getElem_append_right] <;> simp [*, le_add_left]
@[deprecated "Deprecated without replacement." (since := "2024-06-12")]
@[deprecated (since := "2024-06-12")]
theorem get_append_right_aux {l₁ l₂ : List α} {n : Nat}
(h₁ : l₁.length n) (h₂ : n < (l₁ ++ l₂).length) : n - l₁.length < l₂.length := by
rw [length_append] at h₂
@@ -1841,7 +1817,7 @@ theorem getElem_of_append {l : List α} (eq : l = l₁ ++ a :: l₂) (h : l₁.l
rw [ getElem?_eq_getElem, eq, getElem?_append_right (h Nat.le_refl _), h]
simp
@[deprecated "Deprecated without replacement." (since := "2024-06-12")]
@[deprecated (since := "2024-06-12")]
theorem get_of_append_proof {l : List α}
(eq : l = l₁ ++ a :: l₂) (h : l₁.length = n) : n < length l := eq h by simp_arith

View File

@@ -9,7 +9,7 @@ import Init.Data.List.Basic
namespace List
/-! ### isEqv -/
/-! ### isEqv-/
theorem isEqv_eq_decide (a b : List α) (r) :
isEqv a b r = if h : a.length = b.length then

View File

@@ -293,7 +293,7 @@ theorem sorted_mergeSort
apply sorted_mergeSort trans total
termination_by l => l.length
@[deprecated sorted_mergeSort (since := "2024-09-02")] abbrev mergeSort_sorted := @sorted_mergeSort
@[deprecated (since := "2024-09-02")] abbrev mergeSort_sorted := @sorted_mergeSort
/--
If the input list is already sorted, then `mergeSort` does not change the list.
@@ -429,8 +429,7 @@ theorem sublist_mergeSort
((fun w => Sublist.of_sublist_append_right w h') fun b m₁ m₃ =>
(Bool.eq_not_self true).mp ((rel_of_pairwise_cons hc m₁).symm.trans (h₃ b m₃))))
@[deprecated sublist_mergeSort (since := "2024-09-02")]
abbrev mergeSort_stable := @sublist_mergeSort
@[deprecated (since := "2024-09-02")] abbrev mergeSort_stable := @sublist_mergeSort
/--
Another statement of stability of merge sort.
@@ -443,8 +442,7 @@ theorem pair_sublist_mergeSort
(hab : le a b) (h : [a, b] <+ l) : [a, b] <+ mergeSort l le :=
sublist_mergeSort trans total (pairwise_pair.mpr hab) h
@[deprecated pair_sublist_mergeSort(since := "2024-09-02")]
abbrev mergeSort_stable_pair := @pair_sublist_mergeSort
@[deprecated (since := "2024-09-02")] abbrev mergeSort_stable_pair := @pair_sublist_mergeSort
theorem map_merge {f : α β} {r : α α Bool} {s : β β Bool} {l l' : List α}
(hl : a l, b l', r a b = s (f a) (f b)) :

View File

@@ -835,7 +835,7 @@ theorem isPrefix_iff : l₁ <+: l₂ ↔ ∀ i (h : i < l₁.length), l₂[i]? =
simpa using 0, by simp
| cons b l₂ =>
simp only [cons_append, cons_prefix_cons, ih]
rw (occs := [2]) [ Nat.and_forall_add_one]
rw (occs := .pos [2]) [ Nat.and_forall_add_one]
simp [Nat.succ_lt_succ_iff, eq_comm]
theorem isPrefix_iff_getElem {l₁ l₂ : List α} :

View File

@@ -224,7 +224,7 @@ theorem take_succ {l : List α} {n : Nat} : l.take (n + 1) = l.take n ++ l[n]?.t
· simp only [take, Option.toList, getElem?_cons_zero, nil_append]
· simp only [take, hl, getElem?_cons_succ, cons_append]
@[deprecated "Deprecated without replacement." (since := "2024-07-25")]
@[deprecated (since := "2024-07-25")]
theorem drop_sizeOf_le [SizeOf α] (l : List α) (n : Nat) : sizeOf (l.drop n) sizeOf l := by
induction l generalizing n with
| nil => rw [drop_nil]; apply Nat.le_refl

View File

@@ -20,4 +20,3 @@ import Init.Data.Nat.Mod
import Init.Data.Nat.Lcm
import Init.Data.Nat.Compare
import Init.Data.Nat.Simproc
import Init.Data.Nat.Fold

View File

@@ -35,6 +35,52 @@ Used as the default `Nat` eliminator by the `cases` tactic. -/
protected abbrev casesAuxOn {motive : Nat Sort u} (t : Nat) (zero : motive 0) (succ : (n : Nat) motive (n + 1)) : motive t :=
Nat.casesOn t zero succ
/--
`Nat.fold` evaluates `f` on the numbers up to `n` exclusive, in increasing order:
* `Nat.fold f 3 init = init |> f 0 |> f 1 |> f 2`
-/
@[specialize] def fold {α : Type u} (f : Nat α α) : (n : Nat) (init : α) α
| 0, a => a
| succ n, a => f n (fold f n a)
/-- Tail-recursive version of `Nat.fold`. -/
@[inline] def foldTR {α : Type u} (f : Nat α α) (n : Nat) (init : α) : α :=
let rec @[specialize] loop
| 0, a => a
| succ m, a => loop m (f (n - succ m) a)
loop n init
/--
`Nat.foldRev` evaluates `f` on the numbers up to `n` exclusive, in decreasing order:
* `Nat.foldRev f 3 init = f 0 <| f 1 <| f 2 <| init`
-/
@[specialize] def foldRev {α : Type u} (f : Nat α α) : (n : Nat) (init : α) α
| 0, a => a
| succ n, a => foldRev f n (f n a)
/-- `any f n = true` iff there is `i in [0, n-1]` s.t. `f i = true` -/
@[specialize] def any (f : Nat Bool) : Nat Bool
| 0 => false
| succ n => any f n || f n
/-- Tail-recursive version of `Nat.any`. -/
@[inline] def anyTR (f : Nat Bool) (n : Nat) : Bool :=
let rec @[specialize] loop : Nat Bool
| 0 => false
| succ m => f (n - succ m) || loop m
loop n
/-- `all f n = true` iff every `i in [0, n-1]` satisfies `f i = true` -/
@[specialize] def all (f : Nat Bool) : Nat Bool
| 0 => true
| succ n => all f n && f n
/-- Tail-recursive version of `Nat.all`. -/
@[inline] def allTR (f : Nat Bool) (n : Nat) : Bool :=
let rec @[specialize] loop : Nat Bool
| 0 => true
| succ m => f (n - succ m) && loop m
loop n
/--
`Nat.repeat f n a` is `f^(n) a`; that is, it iterates `f` `n` times on `a`.
@@ -789,7 +835,7 @@ theorem pred_lt_of_lt {n m : Nat} (h : m < n) : pred n < n :=
pred_lt (not_eq_zero_of_lt h)
set_option linter.missingDocs false in
@[deprecated pred_lt_of_lt (since := "2024-06-01")] abbrev pred_lt' := @pred_lt_of_lt
@[deprecated (since := "2024-06-01")] abbrev pred_lt' := @pred_lt_of_lt
theorem sub_one_lt_of_lt {n m : Nat} (h : m < n) : n - 1 < n :=
sub_one_lt (not_eq_zero_of_lt h)
@@ -1075,7 +1121,7 @@ theorem pred_mul (n m : Nat) : pred n * m = n * m - m := by
| succ n => rw [Nat.pred_succ, succ_mul, Nat.add_sub_cancel]
set_option linter.missingDocs false in
@[deprecated pred_mul (since := "2024-06-01")] abbrev mul_pred_left := @pred_mul
@[deprecated (since := "2024-06-01")] abbrev mul_pred_left := @pred_mul
protected theorem sub_one_mul (n m : Nat) : (n - 1) * m = n * m - m := by
cases n with
@@ -1087,7 +1133,7 @@ theorem mul_pred (n m : Nat) : n * pred m = n * m - n := by
rw [Nat.mul_comm, pred_mul, Nat.mul_comm]
set_option linter.missingDocs false in
@[deprecated mul_pred (since := "2024-06-01")] abbrev mul_pred_right := @mul_pred
@[deprecated (since := "2024-06-01")] abbrev mul_pred_right := @mul_pred
theorem mul_sub_one (n m : Nat) : n * (m - 1) = n * m - n := by
rw [Nat.mul_comm, Nat.sub_one_mul , Nat.mul_comm]
@@ -1112,6 +1158,33 @@ theorem not_lt_eq (a b : Nat) : (¬ (a < b)) = (b ≤ a) :=
theorem not_gt_eq (a b : Nat) : (¬ (a > b)) = (a b) :=
not_lt_eq b a
/-! # csimp theorems -/
@[csimp] theorem fold_eq_foldTR : @fold = @foldTR :=
funext fun α => funext fun f => funext fun n => funext fun init =>
let rec go : m n, foldTR.loop f (m + n) m (fold f n init) = fold f (m + n) init
| 0, n => by simp [foldTR.loop]
| succ m, n => by rw [foldTR.loop, add_sub_self_left, succ_add]; exact go m (succ n)
(go n 0).symm
@[csimp] theorem any_eq_anyTR : @any = @anyTR :=
funext fun f => funext fun n =>
let rec go : m n, (any f n || anyTR.loop f (m + n) m) = any f (m + n)
| 0, n => by simp [anyTR.loop]
| succ m, n => by
rw [anyTR.loop, add_sub_self_left, Bool.or_assoc, succ_add]
exact go m (succ n)
(go n 0).symm
@[csimp] theorem all_eq_allTR : @all = @allTR :=
funext fun f => funext fun n =>
let rec go : m n, (all f n && allTR.loop f (m + n) m) = all f (m + n)
| 0, n => by simp [allTR.loop]
| succ m, n => by
rw [allTR.loop, add_sub_self_left, Bool.and_assoc, succ_add]
exact go m (succ n)
(go n 0).symm
@[csimp] theorem repeat_eq_repeatTR : @repeat = @repeatTR :=
funext fun α => funext fun f => funext fun n => funext fun init =>
let rec go : m n, repeatTR.loop f m (repeat f n init) = repeat f (m + n) init
@@ -1120,3 +1193,31 @@ theorem not_gt_eq (a b : Nat) : (¬ (a > b)) = (a ≤ b) :=
(go n 0).symm
end Nat
namespace Prod
/--
`(start, stop).foldI f a` evaluates `f` on all the numbers
from `start` (inclusive) to `stop` (exclusive) in increasing order:
* `(5, 8).foldI f init = init |> f 5 |> f 6 |> f 7`
-/
@[inline] def foldI {α : Type u} (f : Nat α α) (i : Nat × Nat) (a : α) : α :=
Nat.foldTR.loop f i.2 (i.2 - i.1) a
/--
`(start, stop).anyI f a` returns true if `f` is true for some natural number
from `start` (inclusive) to `stop` (exclusive):
* `(5, 8).anyI f = f 5 || f 6 || f 7`
-/
@[inline] def anyI (f : Nat Bool) (i : Nat × Nat) : Bool :=
Nat.anyTR.loop f i.2 (i.2 - i.1)
/--
`(start, stop).allI f a` returns true if `f` is true for all natural numbers
from `start` (inclusive) to `stop` (exclusive):
* `(5, 8).anyI f = f 5 && f 6 && f 7`
-/
@[inline] def allI (f : Nat Bool) (i : Nat × Nat) : Bool :=
Nat.allTR.loop f i.2 (i.2 - i.1)
end Prod

View File

@@ -6,51 +6,50 @@ Author: Leonardo de Moura
prelude
import Init.Control.Basic
import Init.Data.Nat.Basic
import Init.Omega
namespace Nat
universe u v
@[inline] def forM {m} [Monad m] (n : Nat) (f : (i : Nat) i < n m Unit) : m Unit :=
let rec @[specialize] loop : i, i n m Unit
| 0, _ => pure ()
| i+1, h => do f (n-i-1) (by omega); loop i (Nat.le_of_succ_le h)
loop n (by simp)
@[inline] def forM {m} [Monad m] (n : Nat) (f : Nat m Unit) : m Unit :=
let rec @[specialize] loop
| 0 => pure ()
| i+1 => do f (n-i-1); loop i
loop n
@[inline] def forRevM {m} [Monad m] (n : Nat) (f : (i : Nat) i < n m Unit) : m Unit :=
let rec @[specialize] loop : i, i n m Unit
| 0, _ => pure ()
| i+1, h => do f i (by omega); loop i (Nat.le_of_succ_le h)
loop n (by simp)
@[inline] def forRevM {m} [Monad m] (n : Nat) (f : Nat m Unit) : m Unit :=
let rec @[specialize] loop
| 0 => pure ()
| i+1 => do f i; loop i
loop n
@[inline] def foldM {α : Type u} {m : Type u Type v} [Monad m] (n : Nat) (f : (i : Nat) i < n α m α) (init : α) : m α :=
let rec @[specialize] loop : i, i n α m α
| 0, h, a => pure a
| i+1, h, a => f (n-i-1) (by omega) a >>= loop i (Nat.le_of_succ_le h)
loop n (by omega) init
@[inline] def foldM {α : Type u} {m : Type u Type v} [Monad m] (f : Nat α m α) (init : α) (n : Nat) : m α :=
let rec @[specialize] loop
| 0, a => pure a
| i+1, a => f (n-i-1) a >>= loop i
loop n init
@[inline] def foldRevM {α : Type u} {m : Type u Type v} [Monad m] (n : Nat) (f : (i : Nat) i < n α m α) (init : α) : m α :=
let rec @[specialize] loop : i, i n α m α
| 0, h, a => pure a
| i+1, h, a => f i (by omega) a >>= loop i (Nat.le_of_succ_le h)
loop n (by omega) init
@[inline] def foldRevM {α : Type u} {m : Type u Type v} [Monad m] (f : Nat α m α) (init : α) (n : Nat) : m α :=
let rec @[specialize] loop
| 0, a => pure a
| i+1, a => f i a >>= loop i
loop n init
@[inline] def allM {m} [Monad m] (n : Nat) (p : (i : Nat) i < n m Bool) : m Bool :=
let rec @[specialize] loop : i, i n m Bool
| 0, _ => pure true
| i+1 , h => do
match ( p (n-i-1) (by omega)) with
| true => loop i (by omega)
@[inline] def allM {m} [Monad m] (n : Nat) (p : Nat m Bool) : m Bool :=
let rec @[specialize] loop
| 0 => pure true
| i+1 => do
match ( p (n-i-1)) with
| true => loop i
| false => pure false
loop n (by simp)
loop n
@[inline] def anyM {m} [Monad m] (n : Nat) (p : (i : Nat) i < n m Bool) : m Bool :=
let rec @[specialize] loop : i, i n m Bool
| 0, _ => pure false
| i+1, h => do
match ( p (n-i-1) (by omega)) with
@[inline] def anyM {m} [Monad m] (n : Nat) (p : Nat m Bool) : m Bool :=
let rec @[specialize] loop
| 0 => pure false
| i+1 => do
match ( p (n-i-1)) with
| true => pure true
| false => loop i (Nat.le_of_succ_le h)
loop n (by simp)
| false => loop i
loop n
end Nat

View File

@@ -92,7 +92,7 @@ protected theorem div_mul_cancel {n m : Nat} (H : n m) : m / n * n = m := by
rw [Nat.mul_comm, Nat.mul_div_cancel' H]
@[simp] theorem mod_mod_of_dvd (a : Nat) (h : c b) : a % b % c = a % c := by
rw (occs := [2]) [ mod_add_div a b]
rw (occs := .pos [2]) [ mod_add_div a b]
have x, h := h
subst h
rw [Nat.mul_assoc, add_mul_mod_self_left]

View File

@@ -1,168 +0,0 @@
/-
Copyright (c) 2014 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Floris van Doorn, Leonardo de Moura, Kim Morrison
-/
prelude
import Init.Omega
set_option linter.missingDocs true -- keep it documented
universe u
namespace Nat
/--
`Nat.fold` evaluates `f` on the numbers up to `n` exclusive, in increasing order:
* `Nat.fold f 3 init = init |> f 0 |> f 1 |> f 2`
-/
@[specialize] def fold {α : Type u} : (n : Nat) (f : (i : Nat) i < n α α) (init : α) α
| 0, f, a => a
| succ n, f, a => f n (by omega) (fold n (fun i h => f i (by omega)) a)
/-- Tail-recursive version of `Nat.fold`. -/
@[inline] def foldTR {α : Type u} (n : Nat) (f : (i : Nat) i < n α α) (init : α) : α :=
let rec @[specialize] loop : j, j n α α
| 0, h, a => a
| succ m, h, a => loop m (by omega) (f (n - succ m) (by omega) a)
loop n (by omega) init
/--
`Nat.foldRev` evaluates `f` on the numbers up to `n` exclusive, in decreasing order:
* `Nat.foldRev f 3 init = f 0 <| f 1 <| f 2 <| init`
-/
@[specialize] def foldRev {α : Type u} : (n : Nat) (f : (i : Nat) i < n α α) (init : α) α
| 0, f, a => a
| succ n, f, a => foldRev n (fun i h => f i (by omega)) (f n (by omega) a)
/-- `any f n = true` iff there is `i in [0, n-1]` s.t. `f i = true` -/
@[specialize] def any : (n : Nat) (f : (i : Nat) i < n Bool) Bool
| 0, f => false
| succ n, f => any n (fun i h => f i (by omega)) || f n (by omega)
/-- Tail-recursive version of `Nat.any`. -/
@[inline] def anyTR (n : Nat) (f : (i : Nat) i < n Bool) : Bool :=
let rec @[specialize] loop : (i : Nat) i n Bool
| 0, h => false
| succ m, h => f (n - succ m) (by omega) || loop m (by omega)
loop n (by omega)
/-- `all f n = true` iff every `i in [0, n-1]` satisfies `f i = true` -/
@[specialize] def all : (n : Nat) (f : (i : Nat) i < n Bool) Bool
| 0, f => true
| succ n, f => all n (fun i h => f i (by omega)) && f n (by omega)
/-- Tail-recursive version of `Nat.all`. -/
@[inline] def allTR (n : Nat) (f : (i : Nat) i < n Bool) : Bool :=
let rec @[specialize] loop : (i : Nat) i n Bool
| 0, h => true
| succ m, h => f (n - succ m) (by omega) && loop m (by omega)
loop n (by omega)
/-! # csimp theorems -/
theorem fold_congr {α : Type u} {n m : Nat} (w : n = m)
(f : (i : Nat) i < n α α) (init : α) :
fold n f init = fold m (fun i h => f i (by omega)) init := by
subst m
rfl
theorem foldTR_loop_congr {α : Type u} {n m : Nat} (w : n = m)
(f : (i : Nat) i < n α α) (j : Nat) (h : j n) (init : α) :
foldTR.loop n f j h init = foldTR.loop m (fun i h => f i (by omega)) j (by omega) init := by
subst m
rfl
@[csimp] theorem fold_eq_foldTR : @fold = @foldTR :=
funext fun α => funext fun n => funext fun f => funext fun init =>
let rec go : m n f, fold (m + n) f init = foldTR.loop (m + n) f m (by omega) (fold n (fun i h => f i (by omega)) init)
| 0, n, f => by
simp only [foldTR.loop]
have t : 0 + n = n := by omega
rw [fold_congr t]
| succ m, n, f => by
have t : (m + 1) + n = m + (n + 1) := by omega
rw [foldTR.loop]
simp only [succ_eq_add_one, Nat.add_sub_cancel]
rw [fold_congr t, foldTR_loop_congr t, go, fold]
congr
omega
go n 0 f
theorem any_congr {n m : Nat} (w : n = m) (f : (i : Nat) i < n Bool) : any n f = any m (fun i h => f i (by omega)) := by
subst m
rfl
theorem anyTR_loop_congr {n m : Nat} (w : n = m) (f : (i : Nat) i < n Bool) (j : Nat) (h : j n) :
anyTR.loop n f j h = anyTR.loop m (fun i h => f i (by omega)) j (by omega) := by
subst m
rfl
@[csimp] theorem any_eq_anyTR : @any = @anyTR :=
funext fun n => funext fun f =>
let rec go : m n f, any (m + n) f = (any n (fun i h => f i (by omega)) || anyTR.loop (m + n) f m (by omega))
| 0, n, f => by
simp [anyTR.loop]
have t : 0 + n = n := by omega
rw [any_congr t]
| succ m, n, f => by
have t : (m + 1) + n = m + (n + 1) := by omega
rw [anyTR.loop]
simp only [succ_eq_add_one]
rw [any_congr t, anyTR_loop_congr t, go, any, Bool.or_assoc]
congr
omega
go n 0 f
theorem all_congr {n m : Nat} (w : n = m) (f : (i : Nat) i < n Bool) : all n f = all m (fun i h => f i (by omega)) := by
subst m
rfl
theorem allTR_loop_congr {n m : Nat} (w : n = m) (f : (i : Nat) i < n Bool) (j : Nat) (h : j n) : allTR.loop n f j h = allTR.loop m (fun i h => f i (by omega)) j (by omega) := by
subst m
rfl
@[csimp] theorem all_eq_allTR : @all = @allTR :=
funext fun n => funext fun f =>
let rec go : m n f, all (m + n) f = (all n (fun i h => f i (by omega)) && allTR.loop (m + n) f m (by omega))
| 0, n, f => by
simp [allTR.loop]
have t : 0 + n = n := by omega
rw [all_congr t]
| succ m, n, f => by
have t : (m + 1) + n = m + (n + 1) := by omega
rw [allTR.loop]
simp only [succ_eq_add_one]
rw [all_congr t, allTR_loop_congr t, go, all, Bool.and_assoc]
congr
omega
go n 0 f
end Nat
namespace Prod
/--
`(start, stop).foldI f a` evaluates `f` on all the numbers
from `start` (inclusive) to `stop` (exclusive) in increasing order:
* `(5, 8).foldI f init = init |> f 5 |> f 6 |> f 7`
-/
@[inline] def foldI {α : Type u} (i : Nat × Nat) (f : (j : Nat) i.1 j j < i.2 α α) (a : α) : α :=
(i.2 - i.1).fold (fun j _ => f (i.1 + j) (by omega) (by omega)) a
/--
`(start, stop).anyI f a` returns true if `f` is true for some natural number
from `start` (inclusive) to `stop` (exclusive):
* `(5, 8).anyI f = f 5 || f 6 || f 7`
-/
@[inline] def anyI (i : Nat × Nat) (f : (j : Nat) i.1 j j < i.2 Bool) : Bool :=
(i.2 - i.1).any (fun j _ => f (i.1 + j) (by omega) (by omega))
/--
`(start, stop).allI f a` returns true if `f` is true for all natural numbers
from `start` (inclusive) to `stop` (exclusive):
* `(5, 8).anyI f = f 5 && f 6 && f 7`
-/
@[inline] def allI (i : Nat × Nat) (f : (j : Nat) i.1 j j < i.2 Bool) : Bool :=
(i.2 - i.1).all (fun j _ => f (i.1 + j) (by omega) (by omega))
end Prod

View File

@@ -651,8 +651,8 @@ theorem sub_mul_mod {x k n : Nat} (h₁ : n*k ≤ x) : (x - n*k) % n = x % n :=
| .inr npos => Nat.mod_eq_of_lt (mod_lt _ npos)
theorem mul_mod (a b n : Nat) : a * b % n = (a % n) * (b % n) % n := by
rw (occs := [1]) [ mod_add_div a n]
rw (occs := [1]) [ mod_add_div b n]
rw (occs := .pos [1]) [ mod_add_div a n]
rw (occs := .pos [1]) [ mod_add_div b n]
rw [Nat.add_mul, Nat.mul_add, Nat.mul_add,
Nat.mul_assoc, Nat.mul_assoc, Nat.mul_add n, add_mul_mod_self_left,
Nat.mul_comm _ (n * (b / n)), Nat.mul_assoc, add_mul_mod_self_left]
@@ -846,18 +846,6 @@ protected theorem pow_lt_pow_iff_pow_mul_le_pow {a n m : Nat} (h : 1 < a) :
rw [Nat.pow_add_one, Nat.pow_le_pow_iff_right (by omega), Nat.pow_lt_pow_iff_right (by omega)]
omega
protected theorem lt_pow_self {n a : Nat} (h : 1 < a) : n < a ^ n := by
induction n with
| zero => exact Nat.zero_lt_one
| succ _ ih => exact Nat.lt_of_lt_of_le (Nat.add_lt_add_right ih 1) (Nat.pow_lt_pow_succ h)
protected theorem lt_two_pow_self : n < 2 ^ n :=
Nat.lt_pow_self Nat.one_lt_two
@[simp]
protected theorem mod_two_pow_self : n % 2 ^ n = n :=
Nat.mod_eq_of_lt Nat.lt_two_pow_self
@[simp]
theorem two_pow_pred_mul_two (h : 0 < w) :
2 ^ (w - 1) * 2 = 2 ^ w := by

View File

@@ -31,7 +31,7 @@ This file defines basic operations on the the sum type `α ⊕ β`.
## Further material
See `Init.Data.Sum.Lemmas` for theorems about these definitions.
See `Batteries.Data.Sum.Lemmas` for theorems about these definitions.
## Notes

View File

@@ -246,12 +246,6 @@ instance (a b : UInt64) : Decidable (a ≤ b) := UInt64.decLe a b
instance : Max UInt64 := maxOfLe
instance : Min UInt64 := minOfLe
theorem usize_size_le : USize.size 18446744073709551616 := by
cases usize_size_eq <;> next h => rw [h]; decide
theorem le_usize_size : 4294967296 USize.size := by
cases usize_size_eq <;> next h => rw [h]; decide
@[extern "lean_usize_mul"]
def USize.mul (a b : USize) : USize := a.toBitVec * b.toBitVec
@[extern "lean_usize_div"]
@@ -270,29 +264,10 @@ def USize.xor (a b : USize) : USize := ⟨a.toBitVec ^^^ b.toBitVec⟩
def USize.shiftLeft (a b : USize) : USize := a.toBitVec <<< (mod b (USize.ofNat System.Platform.numBits)).toBitVec
@[extern "lean_usize_shift_right"]
def USize.shiftRight (a b : USize) : USize := a.toBitVec >>> (mod b (USize.ofNat System.Platform.numBits)).toBitVec
/--
Upcast a `Nat` less than `2^32` to a `USize`.
This is lossless because `USize.size` is either `2^32` or `2^64`.
This function is overridden with a native implementation.
-/
@[extern "lean_usize_of_nat"]
def USize.ofNat32 (n : @& Nat) (h : n < 4294967296) : USize :=
USize.ofNatCore n (Nat.lt_of_lt_of_le h le_usize_size)
@[extern "lean_uint32_to_usize"]
def UInt32.toUSize (a : UInt32) : USize := USize.ofNat32 a.toBitVec.toNat a.toBitVec.isLt
@[extern "lean_usize_to_uint32"]
def USize.toUInt32 (a : USize) : UInt32 := a.toNat.toUInt32
/-- Converts a `UInt64` to a `USize` by reducing modulo `USize.size`. -/
@[extern "lean_uint64_to_usize"]
def UInt64.toUSize (a : UInt64) : USize := a.toNat.toUSize
/--
Upcast a `USize` to a `UInt64`.
This is lossless because `USize.size` is either `2^32` or `2^64`.
This function is overridden with a native implementation.
-/
@[extern "lean_usize_to_uint64"]
def USize.toUInt64 (a : USize) : UInt64 :=
UInt64.ofNatCore a.toBitVec.toNat (Nat.lt_of_lt_of_le a.toBitVec.isLt usize_size_le)
instance : Mul USize := USize.mul
instance : Mod USize := USize.mod

View File

@@ -94,8 +94,10 @@ def UInt32.toUInt64 (a : UInt32) : UInt64 := ⟨⟨a.toNat, Nat.lt_trans a.toBit
instance UInt64.instOfNat : OfNat UInt64 n := UInt64.ofNat n
@[deprecated usize_size_pos (since := "2024-11-24")] theorem usize_size_gt_zero : USize.size > 0 :=
usize_size_pos
theorem usize_size_gt_zero : USize.size > 0 := by
cases usize_size_eq with
| inl h => rw [h]; decide
| inr h => rw [h]; decide
def USize.val (x : USize) : Fin USize.size := x.toBitVec.toFin
@[extern "lean_usize_of_nat"]

View File

@@ -133,9 +133,6 @@ declare_uint_theorems UInt32
declare_uint_theorems UInt64
declare_uint_theorems USize
theorem USize.toNat_ofNat_of_lt_32 {n : Nat} (h : n < 4294967296) : toNat (ofNat n) = n :=
toNat_ofNat_of_lt (Nat.lt_of_lt_of_le h le_usize_size)
theorem UInt32.toNat_lt_of_lt {n : UInt32} {m : Nat} (h : m < size) : n < ofNat m n.toNat < m := by
simp [lt_def, BitVec.lt_def, UInt32.toNat, toBitVec_eq_of_lt h]
@@ -148,22 +145,22 @@ theorem UInt32.toNat_le_of_le {n : UInt32} {m : Nat} (h : m < size) : n ≤ ofNa
theorem UInt32.le_toNat_of_le {n : UInt32} {m : Nat} (h : m < size) : ofNat m n m n.toNat := by
simp [le_def, BitVec.le_def, UInt32.toNat, toBitVec_eq_of_lt h]
@[deprecated UInt8.toNat_zero (since := "2024-06-23")] protected abbrev UInt8.zero_toNat := @UInt8.toNat_zero
@[deprecated UInt8.toNat_div (since := "2024-06-23")] protected abbrev UInt8.div_toNat := @UInt8.toNat_div
@[deprecated UInt8.toNat_mod (since := "2024-06-23")] protected abbrev UInt8.mod_toNat := @UInt8.toNat_mod
@[deprecated (since := "2024-06-23")] protected abbrev UInt8.zero_toNat := @UInt8.toNat_zero
@[deprecated (since := "2024-06-23")] protected abbrev UInt8.div_toNat := @UInt8.toNat_div
@[deprecated (since := "2024-06-23")] protected abbrev UInt8.mod_toNat := @UInt8.toNat_mod
@[deprecated UInt16.toNat_zero (since := "2024-06-23")] protected abbrev UInt16.zero_toNat := @UInt16.toNat_zero
@[deprecated UInt16.toNat_div (since := "2024-06-23")] protected abbrev UInt16.div_toNat := @UInt16.toNat_div
@[deprecated UInt16.toNat_mod (since := "2024-06-23")] protected abbrev UInt16.mod_toNat := @UInt16.toNat_mod
@[deprecated (since := "2024-06-23")] protected abbrev UInt16.zero_toNat := @UInt16.toNat_zero
@[deprecated (since := "2024-06-23")] protected abbrev UInt16.div_toNat := @UInt16.toNat_div
@[deprecated (since := "2024-06-23")] protected abbrev UInt16.mod_toNat := @UInt16.toNat_mod
@[deprecated UInt32.toNat_zero (since := "2024-06-23")] protected abbrev UInt32.zero_toNat := @UInt32.toNat_zero
@[deprecated UInt32.toNat_div (since := "2024-06-23")] protected abbrev UInt32.div_toNat := @UInt32.toNat_div
@[deprecated UInt32.toNat_mod (since := "2024-06-23")] protected abbrev UInt32.mod_toNat := @UInt32.toNat_mod
@[deprecated (since := "2024-06-23")] protected abbrev UInt32.zero_toNat := @UInt32.toNat_zero
@[deprecated (since := "2024-06-23")] protected abbrev UInt32.div_toNat := @UInt32.toNat_div
@[deprecated (since := "2024-06-23")] protected abbrev UInt32.mod_toNat := @UInt32.toNat_mod
@[deprecated UInt64.toNat_zero (since := "2024-06-23")] protected abbrev UInt64.zero_toNat := @UInt64.toNat_zero
@[deprecated UInt64.toNat_div (since := "2024-06-23")] protected abbrev UInt64.div_toNat := @UInt64.toNat_div
@[deprecated UInt64.toNat_mod (since := "2024-06-23")] protected abbrev UInt64.mod_toNat := @UInt64.toNat_mod
@[deprecated (since := "2024-06-23")] protected abbrev UInt64.zero_toNat := @UInt64.toNat_zero
@[deprecated (since := "2024-06-23")] protected abbrev UInt64.div_toNat := @UInt64.toNat_div
@[deprecated (since := "2024-06-23")] protected abbrev UInt64.mod_toNat := @UInt64.toNat_mod
@[deprecated USize.toNat_zero (since := "2024-06-23")] protected abbrev USize.zero_toNat := @USize.toNat_zero
@[deprecated USize.toNat_div (since := "2024-06-23")] protected abbrev USize.div_toNat := @USize.toNat_div
@[deprecated USize.toNat_mod (since := "2024-06-23")] protected abbrev USize.mod_toNat := @USize.toNat_mod
@[deprecated (since := "2024-06-23")] protected abbrev USize.zero_toNat := @USize.toNat_zero
@[deprecated (since := "2024-06-23")] protected abbrev USize.div_toNat := @USize.toNat_div
@[deprecated (since := "2024-06-23")] protected abbrev USize.mod_toNat := @USize.toNat_mod

View File

@@ -1,7 +0,0 @@
/-
Copyright (c) 2024 Lean FRO. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Kim Morrison
-/
prelude
import Init.Data.Vector.Basic

View File

@@ -1,256 +0,0 @@
/-
Copyright (c) 2024 Shreyas Srinivas. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Shreyas Srinivas, François G. Dorais, Kim Morrison
-/
prelude
import Init.Data.Array.Lemmas
/-!
# Vectors
`Vector α n` is a thin wrapper around `Array α` for arrays of fixed size `n`.
-/
/-- `Vector α n` is an `Array α` with size `n`. -/
structure Vector (α : Type u) (n : Nat) extends Array α where
/-- Array size. -/
size_toArray : toArray.size = n
deriving Repr, DecidableEq
attribute [simp] Vector.size_toArray
namespace Vector
/-- Syntax for `Vector α n` -/
syntax "#v[" withoutPosition(sepBy(term, ", ")) "]" : term
open Lean in
macro_rules
| `(#v[ $elems,* ]) => `(Vector.mk (n := $(quote elems.getElems.size)) #[$elems,*] rfl)
/-- Custom eliminator for `Vector α n` through `Array α` -/
@[elab_as_elim]
def elimAsArray {motive : Vector α n Sort u}
(mk : (a : Array α) (ha : a.size = n), motive a, ha) :
(v : Vector α n) motive v
| a, ha => mk a ha
/-- Custom eliminator for `Vector α n` through `List α` -/
@[elab_as_elim]
def elimAsList {motive : Vector α n Sort u}
(mk : (a : List α) (ha : a.length = n), motive a, ha) :
(v : Vector α n) motive v
| a, ha => mk a ha
/-- The empty vector. -/
@[inline] def empty : Vector α 0 := .empty, rfl
/-- Make an empty vector with pre-allocated capacity. -/
@[inline] def mkEmpty (capacity : Nat) : Vector α 0 := .mkEmpty capacity, rfl
/-- Makes a vector of size `n` with all cells containing `v`. -/
@[inline] def mkVector (n) (v : α) : Vector α n := mkArray n v, by simp
/-- Returns a vector of size `1` with element `v`. -/
@[inline] def singleton (v : α) : Vector α 1 := #[v], rfl
instance [Inhabited α] : Inhabited (Vector α n) where
default := mkVector n default
/-- Get an element of a vector using a `Fin` index. -/
@[inline] def get (v : Vector α n) (i : Fin n) : α :=
v.toArray[(i.cast v.size_toArray.symm).1]
/-- Get an element of a vector using a `USize` index and a proof that the index is within bounds. -/
@[inline] def uget (v : Vector α n) (i : USize) (h : i.toNat < n) : α :=
v.toArray.uget i (v.size_toArray.symm h)
instance : GetElem (Vector α n) Nat α fun _ i => i < n where
getElem x i h := get x i, h
/--
Get an element of a vector using a `Nat` index. Returns the given default value if the index is out
of bounds.
-/
@[inline] def getD (v : Vector α n) (i : Nat) (default : α) : α := v.toArray.getD i default
/-- The last element of a vector. Panics if the vector is empty. -/
@[inline] def back! [Inhabited α] (v : Vector α n) : α := v.toArray.back!
/-- The last element of a vector, or `none` if the array is empty. -/
@[inline] def back? (v : Vector α n) : Option α := v.toArray.back?
/-- The last element of a non-empty vector. -/
@[inline] def back [NeZero n] (v : Vector α n) : α :=
-- TODO: change to just `v[n]`
have : Inhabited α := v[0]'(Nat.pos_of_neZero n)
v.back!
/-- The first element of a non-empty vector. -/
@[inline] def head [NeZero n] (v : Vector α n) := v[0]'(Nat.pos_of_neZero n)
/-- Push an element `x` to the end of a vector. -/
@[inline] def push (x : α) (v : Vector α n) : Vector α (n + 1) :=
v.toArray.push x, by simp
/-- Remove the last element of a vector. -/
@[inline] def pop (v : Vector α n) : Vector α (n - 1) :=
Array.pop v.toArray, by simp
/--
Set an element in a vector using a `Nat` index, with a tactic provided proof that the index is in
bounds.
This will perform the update destructively provided that the vector has a reference count of 1.
-/
@[inline] def set (v : Vector α n) (i : Nat) (x : α) (h : i < n := by get_elem_tactic): Vector α n :=
v.toArray.set i x (by simp [*]), by simp
/--
Set an element in a vector using a `Nat` index. Returns the vector unchanged if the index is out of
bounds.
This will perform the update destructively provided that the vector has a reference count of 1.
-/
@[inline] def setIfInBounds (v : Vector α n) (i : Nat) (x : α) : Vector α n :=
v.toArray.setIfInBounds i x, by simp
/--
Set an element in a vector using a `Nat` index. Panics if the index is out of bounds.
This will perform the update destructively provided that the vector has a reference count of 1.
-/
@[inline] def set! (v : Vector α n) (i : Nat) (x : α) : Vector α n :=
v.toArray.set! i x, by simp
/-- Append two vectors. -/
@[inline] def append (v : Vector α n) (w : Vector α m) : Vector α (n + m) :=
v.toArray ++ w.toArray, by simp
instance : HAppend (Vector α n) (Vector α m) (Vector α (n + m)) where
hAppend := append
/-- Creates a vector from another with a provably equal length. -/
@[inline] protected def cast (h : n = m) (v : Vector α n) : Vector α m :=
v.toArray, by simp [*]
/--
Extracts the slice of a vector from indices `start` to `stop` (exclusive). If `start ≥ stop`, the
result is empty. If `stop` is greater than the size of the vector, the size is used instead.
-/
@[inline] def extract (v : Vector α n) (start stop : Nat) : Vector α (min stop n - start) :=
v.toArray.extract start stop, by simp
/-- Maps elements of a vector using the function `f`. -/
@[inline] def map (f : α β) (v : Vector α n) : Vector β n :=
v.toArray.map f, by simp
/-- Maps corresponding elements of two vectors of equal size using the function `f`. -/
@[inline] def zipWith (a : Vector α n) (b : Vector β n) (f : α β φ) : Vector φ n :=
Array.zipWith a.toArray b.toArray f, by simp
/-- The vector of length `n` whose `i`-th element is `f i`. -/
@[inline] def ofFn (f : Fin n α) : Vector α n :=
Array.ofFn f, by simp
/--
Swap two elements of a vector using `Fin` indices.
This will perform the update destructively provided that the vector has a reference count of 1.
-/
@[inline] def swap (v : Vector α n) (i j : Nat)
(hi : i < n := by get_elem_tactic) (hj : j < n := by get_elem_tactic) : Vector α n :=
v.toArray.swap i j (by simpa using hi) (by simpa using hj), by simp
/--
Swap two elements of a vector using `Nat` indices. Panics if either index is out of bounds.
This will perform the update destructively provided that the vector has a reference count of 1.
-/
@[inline] def swapIfInBounds (v : Vector α n) (i j : Nat) : Vector α n :=
v.toArray.swapIfInBounds i j, by simp
/--
Swaps an element of a vector with a given value using a `Fin` index. The original value is returned
along with the updated vector.
This will perform the update destructively provided that the vector has a reference count of 1.
-/
@[inline] def swapAt (v : Vector α n) (i : Nat) (x : α) (hi : i < n := by get_elem_tactic) :
α × Vector α n :=
let a := v.toArray.swapAt i x (by simpa using hi)
a.fst, a.snd, by simp [a]
/--
Swaps an element of a vector with a given value using a `Nat` index. Panics if the index is out of
bounds. The original value is returned along with the updated vector.
This will perform the update destructively provided that the vector has a reference count of 1.
-/
@[inline] def swapAt! (v : Vector α n) (i : Nat) (x : α) : α × Vector α n :=
let a := v.toArray.swapAt! i x
a.fst, a.snd, by simp [a]
/-- The vector `#v[0,1,2,...,n-1]`. -/
@[inline] def range (n : Nat) : Vector Nat n := Array.range n, by simp
/--
Extract the first `m` elements of a vector. If `m` is greater than or equal to the size of the
vector then the vector is returned unchanged.
-/
@[inline] def take (v : Vector α n) (m : Nat) : Vector α (min m n) :=
v.toArray.take m, by simp
/--
Deletes the first `m` elements of a vector. If `m` is greater than or equal to the size of the
vector then the empty vector is returned.
-/
@[inline] def drop (v : Vector α n) (m : Nat) : Vector α (n - m) :=
v.toArray.extract m v.size, by simp
/--
Compares two vectors of the same size using a given boolean relation `r`. `isEqv v w r` returns
`true` if and only if `r v[i] w[i]` is true for all indices `i`.
-/
@[inline] def isEqv (v w : Vector α n) (r : α α Bool) : Bool :=
Array.isEqvAux v.toArray w.toArray (by simp) r n (by simp)
instance [BEq α] : BEq (Vector α n) where
beq a b := isEqv a b (· == ·)
/-- Reverse the elements of a vector. -/
@[inline] def reverse (v : Vector α n) : Vector α n :=
v.toArray.reverse, by simp
/-- Delete an element of a vector using a `Nat` index and a tactic provided proof. -/
@[inline] def eraseIdx (v : Vector α n) (i : Nat) (h : i < n := by get_elem_tactic) :
Vector α (n-1) :=
v.toArray.eraseIdx i (v.size_toArray.symm h), by simp [Array.size_eraseIdx]
/-- Delete an element of a vector using a `Nat` index. Panics if the index is out of bounds. -/
@[inline] def eraseIdx! (v : Vector α n) (i : Nat) : Vector α (n-1) :=
if _ : i < n then
v.eraseIdx i
else
have : Inhabited (Vector α (n-1)) := v.pop
panic! "index out of bounds"
/-- Delete the first element of a vector. Returns the empty vector if the input vector is empty. -/
@[inline] def tail (v : Vector α n) : Vector α (n-1) :=
if _ : 0 < n then
v.eraseIdx 0
else
v.cast (by omega)
/--
Finds the first index of a given value in a vector using `==` for comparison. Returns `none` if the
no element of the index matches the given value.
-/
@[inline] def indexOf? [BEq α] (v : Vector α n) (x : α) : Option (Fin n) :=
(v.toArray.indexOf? x).map (Fin.cast v.size_toArray)
/-- Returns `true` when `v` is a prefix of the vector `w`. -/
@[inline] def isPrefixOf [BEq α] (v : Vector α m) (w : Vector α n) : Bool :=
v.toArray.isPrefixOf w.toArray

View File

@@ -206,12 +206,12 @@ instance : GetElem (List α) Nat α fun as i => i < as.length where
@[simp] theorem getElem_cons_zero (a : α) (as : List α) (h : 0 < (a :: as).length) : getElem (a :: as) 0 h = a := by
rfl
@[deprecated getElem_cons_zero (since := "2024-06-12")] abbrev cons_getElem_zero := @getElem_cons_zero
@[deprecated (since := "2024-06-12")] abbrev cons_getElem_zero := @getElem_cons_zero
@[simp] theorem getElem_cons_succ (a : α) (as : List α) (i : Nat) (h : i + 1 < (a :: as).length) : getElem (a :: as) (i+1) h = getElem as i (Nat.lt_of_succ_lt_succ h) := by
rfl
@[deprecated getElem_cons_succ (since := "2024-06-12")] abbrev cons_getElem_succ := @getElem_cons_succ
@[deprecated (since := "2024-06-12")] abbrev cons_getElem_succ := @getElem_cons_succ
@[simp] theorem getElem_mem : {l : List α} {n} (h : n < l.length), l[n]'h l
| _ :: _, 0, _ => .head ..
@@ -223,8 +223,7 @@ theorem getElem_cons_drop_succ_eq_drop {as : List α} {i : Nat} (h : i < as.leng
| _::_, 0 => rfl
| _::_, i+1 => getElem_cons_drop_succ_eq_drop (i := i) _
@[deprecated getElem_cons_drop_succ_eq_drop (since := "2024-11-05")]
abbrev get_drop_eq_drop := @getElem_cons_drop_succ_eq_drop
@[deprecated (since := "2024-11-05")] abbrev get_drop_eq_drop := @getElem_cons_drop_succ_eq_drop
end List

View File

@@ -251,16 +251,10 @@ def neutralConfig : Simp.Config := {
end Simp
/-- Configuration for which occurrences that match an expression should be rewritten. -/
inductive Occurrences where
/-- All occurrences should be rewritten. -/
| all
/-- A list of indices for which occurrences should be rewritten. -/
| pos (idxs : List Nat)
/-- A list of indices for which occurrences should not be rewritten. -/
| neg (idxs : List Nat)
deriving Inhabited, BEq
instance : Coe (List Nat) Occurrences := .pos
end Lean.Meta

View File

@@ -71,9 +71,9 @@ def prio : Category := {}
/-- `prec` is a builtin syntax category for precedences. A precedence is a value
that expresses how tightly a piece of syntax binds: for example `1 + 2 * 3` is
parsed as `1 + (2 * 3)` because `*` has a higher precedence than `+`.
parsed as `1 + (2 * 3)` because `*` has a higher pr0ecedence than `+`.
Higher numbers denote higher precedence.
In addition to literals like `37`, there are some special named precedence levels:
In addition to literals like `37`, there are some special named priorities:
* `arg` for the precedence of function arguments
* `max` for the highest precedence used in term parsers (not actually the maximum possible value)
* `lead` for the precedence of terms not supposed to be used as arguments

View File

@@ -2116,11 +2116,6 @@ theorem usize_size_eq : Or (Eq USize.size 4294967296) (Eq USize.size 18446744073
| _, Or.inl rfl => Or.inl (of_decide_eq_true rfl)
| _, Or.inr rfl => Or.inr (of_decide_eq_true rfl)
theorem usize_size_pos : LT.lt 0 USize.size :=
match USize.size, usize_size_eq with
| _, Or.inl rfl => of_decide_eq_true rfl
| _, Or.inr rfl => of_decide_eq_true rfl
/--
A `USize` is an unsigned integer with the size of a word
for the platform's architecture.
@@ -2160,7 +2155,24 @@ def USize.decEq (a b : USize) : Decidable (Eq a b) :=
instance : DecidableEq USize := USize.decEq
instance : Inhabited USize where
default := USize.ofNatCore 0 usize_size_pos
default := USize.ofNatCore 0 (match USize.size, usize_size_eq with
| _, Or.inl rfl => of_decide_eq_true rfl
| _, Or.inr rfl => of_decide_eq_true rfl)
/--
Upcast a `Nat` less than `2^32` to a `USize`.
This is lossless because `USize.size` is either `2^32` or `2^64`.
This function is overridden with a native implementation.
-/
@[extern "lean_usize_of_nat"]
def USize.ofNat32 (n : @& Nat) (h : LT.lt n 4294967296) : USize where
toBitVec :=
BitVec.ofNatLt n (
match System.Platform.numBits, System.Platform.numBits_eq with
| _, Or.inl rfl => h
| _, Or.inr rfl => Nat.lt_trans h (of_decide_eq_true rfl)
)
/--
A `Nat` denotes a valid unicode codepoint if it is less than `0x110000`, and
it is also not a "surrogate" character (the range `0xd800` to `0xdfff` inclusive).
@@ -3420,6 +3432,25 @@ class Hashable (α : Sort u) where
export Hashable (hash)
/-- Converts a `UInt64` to a `USize` by reducing modulo `USize.size`. -/
@[extern "lean_uint64_to_usize"]
opaque UInt64.toUSize (u : UInt64) : USize
/--
Upcast a `USize` to a `UInt64`.
This is lossless because `USize.size` is either `2^32` or `2^64`.
This function is overridden with a native implementation.
-/
@[extern "lean_usize_to_uint64"]
def USize.toUInt64 (u : USize) : UInt64 where
toBitVec := BitVec.ofNatLt u.toBitVec.toNat (
let n, h := u
show LT.lt n _ from
match System.Platform.numBits, System.Platform.numBits_eq, h with
| _, Or.inl rfl, h => Nat.lt_trans h (of_decide_eq_true rfl)
| _, Or.inr rfl, h => h
)
/-- An opaque hash mixing operation, used to implement hashing for tuples. -/
@[extern "lean_uint64_mix_hash"]
opaque mixHash (u₁ u₂ : UInt64) : UInt64

View File

@@ -5,7 +5,6 @@ Authors: Leonardo de Moura, Mario Carneiro
-/
prelude
import Init.Util
import Init.Data.UInt.Basic
namespace ShareCommon
/-

View File

@@ -72,21 +72,6 @@ theorem let_body_congr {α : Sort u} {β : α → Sort v} {b b' : (a : α) →
(a : α) (h : x, b x = b' x) : (let x := a; b x) = (let x := a; b' x) :=
(funext h : b = b') rfl
theorem letFun_unused {α : Sort u} {β : Sort v} (a : α) {b b' : β} (h : b = b') : @letFun α (fun _ => β) a (fun _ => b) = b' :=
h
theorem letFun_congr {α : Sort u} {β : Sort v} {a a' : α} {f f' : α β} (h₁ : a = a') (h₂ : x, f x = f' x)
: @letFun α (fun _ => β) a f = @letFun α (fun _ => β) a' f' := by
rw [h₁, funext h₂]
theorem letFun_body_congr {α : Sort u} {β : Sort v} (a : α) {f f' : α β} (h : x, f x = f' x)
: @letFun α (fun _ => β) a f = @letFun α (fun _ => β) a f' := by
rw [funext h]
theorem letFun_val_congr {α : Sort u} {β : Sort v} {a a' : α} {f : α β} (h : a = a')
: @letFun α (fun _ => β) a f = @letFun α (fun _ => β) a' f := by
rw [h]
@[congr]
theorem ite_congr {x y u v : α} {s : Decidable b} [Decidable c]
(h₁ : b = c) (h₂ : c x = u) (h₃ : ¬ c y = v) : ite b x y = ite c u v := by

View File

@@ -30,7 +30,7 @@ Does nothing for non-`node` nodes, or if `i` is out of bounds of the node list.
-/
def setArg (stx : Syntax) (i : Nat) (arg : Syntax) : Syntax :=
match stx with
| node info k args => node info k (args.setIfInBounds i arg)
| node info k args => node info k (args.setD i arg)
| stx => stx
end Lean.Syntax

View File

@@ -462,16 +462,6 @@ Note that it is the caller's job to remove the file after use.
-/
@[extern "lean_io_create_tempfile"] opaque createTempFile : IO (Handle × FilePath)
/--
Creates a temporary directory in the most secure manner possible. There are no race conditions in the
directorys creation. The directory is readable and writable only by the creating user ID.
Returns the new directory's path.
It is the caller's job to remove the directory after use.
-/
@[extern "lean_io_create_tempdir"] opaque createTempDir : IO FilePath
end FS
@[extern "lean_io_getenv"] opaque getEnv (var : @& String) : BaseIO (Option String)
@@ -484,6 +474,17 @@ namespace FS
def withFile (fn : FilePath) (mode : Mode) (f : Handle IO α) : IO α :=
Handle.mk fn mode >>= f
/--
Like `createTempFile` but also takes care of removing the file after usage.
-/
def withTempFile [Monad m] [MonadFinally m] [MonadLiftT IO m] (f : Handle FilePath m α) :
m α := do
let (handle, path) createTempFile
try
f handle path
finally
removeFile path
def Handle.putStrLn (h : Handle) (s : String) : IO Unit :=
h.putStr (s.push '\n')
@@ -674,10 +675,8 @@ def appDir : IO FilePath := do
| throw <| IO.userError s!"System.IO.appDir: unexpected filename '{p}'"
FS.realPath p
namespace FS
/-- Create given path and all missing parents as directories. -/
partial def createDirAll (p : FilePath) : IO Unit := do
partial def FS.createDirAll (p : FilePath) : IO Unit := do
if p.isDir then
return ()
if let some parent := p.parent then
@@ -694,7 +693,7 @@ partial def createDirAll (p : FilePath) : IO Unit := do
/--
Fully remove given directory by deleting all contained files and directories in an unspecified order.
Fails if any contained entry cannot be deleted or was newly created during execution. -/
partial def removeDirAll (p : FilePath) : IO Unit := do
partial def FS.removeDirAll (p : FilePath) : IO Unit := do
for ent in ( p.readDir) do
if ( ent.path.isDir : Bool) then
removeDirAll ent.path
@@ -702,32 +701,6 @@ partial def removeDirAll (p : FilePath) : IO Unit := do
removeFile ent.path
removeDir p
/--
Like `createTempFile`, but also takes care of removing the file after usage.
-/
def withTempFile [Monad m] [MonadFinally m] [MonadLiftT IO m] (f : Handle FilePath m α) :
m α := do
let (handle, path) createTempFile
try
f handle path
finally
removeFile path
/--
Like `createTempDir`, but also takes care of removing the directory after usage.
All files in the directory are recursively deleted, regardless of how or when they were created.
-/
def withTempDir [Monad m] [MonadFinally m] [MonadLiftT IO m] (f : FilePath m α) :
m α := do
let path createTempDir
try
f path
finally
removeDirAll path
end FS
namespace Process
/-- Returns the current working directory of the calling process. -/

View File

@@ -428,11 +428,11 @@ macro "infer_instance" : tactic => `(tactic| exact inferInstance)
/--
`+opt` is short for `(opt := true)`. It sets the `opt` configuration option to `true`.
-/
syntax posConfigItem := " +" noWs ident
syntax posConfigItem := "+" noWs ident
/--
`-opt` is short for `(opt := false)`. It sets the `opt` configuration option to `false`.
-/
syntax negConfigItem := " -" noWs ident
syntax negConfigItem := "-" noWs ident
/--
`(opt := val)` sets the `opt` configuration option to `val`.

View File

@@ -205,8 +205,8 @@ def getParamInfo (k : ParamMap.Key) : M (Array Param) := do
/-- For each ps[i], if ps[i] is owned, then mark xs[i] as owned. -/
def ownArgsUsingParams (xs : Array Arg) (ps : Array Param) : M Unit :=
xs.size.forM fun i _ => do
let x := xs[i]
xs.size.forM fun i => do
let x := xs[i]!
let p := ps[i]!
unless p.borrow do ownArg x
@@ -216,8 +216,8 @@ def ownArgsUsingParams (xs : Array Arg) (ps : Array Param) : M Unit :=
we would have to insert a `dec xs[i]` after `f xs` and consequently
"break" the tail call. -/
def ownParamsUsingArgs (xs : Array Arg) (ps : Array Param) : M Unit :=
xs.size.forM fun i _ => do
let x := xs[i]
xs.size.forM fun i => do
let x := xs[i]!
let p := ps[i]!
match x with
| Arg.var x => if ( isOwned x) then ownVar p.x

View File

@@ -48,9 +48,9 @@ def requiresBoxedVersion (env : Environment) (decl : Decl) : Bool :=
def mkBoxedVersionAux (decl : Decl) : N Decl := do
let ps := decl.params
let qs ps.mapM fun _ => do let x N.mkFresh; pure { x := x, ty := IRType.object, borrow := false : Param }
let (newVDecls, xs) qs.size.foldM (init := (#[], #[])) fun i _ (newVDecls, xs) => do
let (newVDecls, xs) qs.size.foldM (init := (#[], #[])) fun i (newVDecls, xs) => do
let p := ps[i]!
let q := qs[i]
let q := qs[i]!
if !p.ty.isScalar then
pure (newVDecls, xs.push (Arg.var q.x))
else

View File

@@ -63,7 +63,7 @@ partial def merge (v₁ v₂ : Value) : Value :=
| top, _ => top
| _, top => top
| v₁@(ctor i₁ vs₁), v₂@(ctor i₂ vs₂) =>
if i₁ == i₂ then ctor i₁ <| vs₁.size.fold (init := #[]) fun i _ r => r.push (merge vs₁[i] vs₂[i]!)
if i₁ == i₂ then ctor i₁ <| vs₁.size.fold (init := #[]) fun i r => r.push (merge vs₁[i]! vs₂[i]!)
else choice [v₁, v₂]
| choice vs₁, choice vs₂ => choice <| vs₁.foldl (addChoice merge) vs₂
| choice vs, v => choice <| addChoice merge vs v
@@ -225,8 +225,8 @@ def updateCurrFnSummary (v : Value) : M Unit := do
def updateJPParamsAssignment (ys : Array Param) (xs : Array Arg) : M Bool := do
let ctx read
let currFnIdx := ctx.currFnIdx
ys.size.foldM (init := false) fun i _ r => do
let y := ys[i]
ys.size.foldM (init := false) fun i r => do
let y := ys[i]!
let x := xs[i]!
let yVal findVarValue y.x
let xVal findArgValue x
@@ -282,8 +282,8 @@ partial def interpFnBody : FnBody → M Unit
def inferStep : M Bool := do
let ctx read
modify fun s => { s with assignments := ctx.decls.map fun _ => {} }
ctx.decls.size.foldM (init := false) fun idx _ modified => do
match ctx.decls[idx] with
ctx.decls.size.foldM (init := false) fun idx modified => do
match ctx.decls[idx]! with
| .fdecl (xs := ys) (body := b) .. => do
let s get
let currVals := s.funVals[idx]!
@@ -336,8 +336,8 @@ def elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do
let funVals := s.funVals
let assignments := s.assignments
modify fun s =>
let env := decls.size.fold (init := s.env) fun i _ env =>
addFunctionSummary env decls[i].name funVals[i]!
let env := decls.size.fold (init := s.env) fun i env =>
addFunctionSummary env decls[i]!.name funVals[i]!
{ s with env := env }
return decls.mapIdx fun i decl => elimDead assignments[i]! decl

View File

@@ -108,9 +108,9 @@ def emitFnDeclAux (decl : Decl) (cppBaseName : String) (isExternal : Bool) : M U
if ps.size > closureMaxArgs && isBoxedName decl.name then
emit "lean_object**"
else
ps.size.forM fun i _ => do
ps.size.forM fun i => do
if i > 0 then emit ", "
emit (toCType ps[i].ty)
emit (toCType ps[i]!.ty)
emit ")"
emitLn ";"
@@ -321,22 +321,20 @@ def emitSSet (x : VarId) (n : Nat) (offset : Nat) (y : VarId) (t : IRType) : M U
def emitJmp (j : JoinPointId) (xs : Array Arg) : M Unit := do
let ps getJPParams j
if h : xs.size = ps.size then
xs.size.forM fun i _ => do
let p := ps[i]
let x := xs[i]
emit p.x; emit " = "; emitArg x; emitLn ";"
emit "goto "; emit j; emitLn ";"
else
do throw "invalid goto"
unless xs.size == ps.size do throw "invalid goto"
xs.size.forM fun i => do
let p := ps[i]!
let x := xs[i]!
emit p.x; emit " = "; emitArg x; emitLn ";"
emit "goto "; emit j; emitLn ";"
def emitLhs (z : VarId) : M Unit := do
emit z; emit " = "
def emitArgs (ys : Array Arg) : M Unit :=
ys.size.forM fun i _ => do
ys.size.forM fun i => do
if i > 0 then emit ", "
emitArg ys[i]
emitArg ys[i]!
def emitCtorScalarSize (usize : Nat) (ssize : Nat) : M Unit := do
if usize == 0 then emit ssize
@@ -348,8 +346,8 @@ def emitAllocCtor (c : CtorInfo) : M Unit := do
emitCtorScalarSize c.usize c.ssize; emitLn ");"
def emitCtorSetArgs (z : VarId) (ys : Array Arg) : M Unit :=
ys.size.forM fun i _ => do
emit "lean_ctor_set("; emit z; emit ", "; emit i; emit ", "; emitArg ys[i]; emitLn ");"
ys.size.forM fun i => do
emit "lean_ctor_set("; emit z; emit ", "; emit i; emit ", "; emitArg ys[i]!; emitLn ");"
def emitCtor (z : VarId) (c : CtorInfo) (ys : Array Arg) : M Unit := do
emitLhs z;
@@ -360,7 +358,7 @@ def emitCtor (z : VarId) (c : CtorInfo) (ys : Array Arg) : M Unit := do
def emitReset (z : VarId) (n : Nat) (x : VarId) : M Unit := do
emit "if (lean_is_exclusive("; emit x; emitLn ")) {";
n.forM fun i _ => do
n.forM fun i => do
emit " lean_ctor_release("; emit x; emit ", "; emit i; emitLn ");"
emit " "; emitLhs z; emit x; emitLn ";";
emitLn "} else {";
@@ -401,12 +399,12 @@ def emitSimpleExternalCall (f : String) (ps : Array Param) (ys : Array Arg) : M
emit f; emit "("
-- We must remove irrelevant arguments to extern calls.
discard <| ys.size.foldM
(fun i _ (first : Bool) =>
(fun i (first : Bool) =>
if ps[i]!.ty.isIrrelevant then
pure first
else do
unless first do emit ", "
emitArg ys[i]
emitArg ys[i]!
pure false)
true
emitLn ");"
@@ -433,8 +431,8 @@ def emitPartialApp (z : VarId) (f : FunId) (ys : Array Arg) : M Unit := do
let decl getDecl f
let arity := decl.params.size;
emitLhs z; emit "lean_alloc_closure((void*)("; emitCName f; emit "), "; emit arity; emit ", "; emit ys.size; emitLn ");";
ys.size.forM fun i _ => do
let y := ys[i]
ys.size.forM fun i => do
let y := ys[i]!
emit "lean_closure_set("; emit z; emit ", "; emit i; emit ", "; emitArg y; emitLn ");"
def emitApp (z : VarId) (f : VarId) (ys : Array Arg) : M Unit :=
@@ -546,36 +544,34 @@ That is, we have
-/
def overwriteParam (ps : Array Param) (ys : Array Arg) : Bool :=
let n := ps.size;
n.any fun i _ =>
let p := ps[i]
(i+1, n).anyI fun j _ _ => paramEqArg p ys[j]!
n.any fun i =>
let p := ps[i]!
(i+1, n).anyI fun j => paramEqArg p ys[j]!
def emitTailCall (v : Expr) : M Unit :=
match v with
| Expr.fap _ ys => do
let ctx read
let ps := ctx.mainParams
if h : ps.size = ys.size then
if overwriteParam ps ys then
emitLn "{"
ps.size.forM fun i _ => do
let p := ps[i]
let y := ys[i]
unless paramEqArg p y do
emit (toCType p.ty); emit " _tmp_"; emit i; emit " = "; emitArg y; emitLn ";"
ps.size.forM fun i _ => do
let p := ps[i]
let y := ys[i]
unless paramEqArg p y do emit p.x; emit " = _tmp_"; emit i; emitLn ";"
emitLn "}"
else
ys.size.forM fun i _ => do
let p := ps[i]
let y := ys[i]
unless paramEqArg p y do emit p.x; emit " = "; emitArg y; emitLn ";"
emitLn "goto _start;"
unless ps.size == ys.size do throw "invalid tail call"
if overwriteParam ps ys then
emitLn "{"
ps.size.forM fun i => do
let p := ps[i]!
let y := ys[i]!
unless paramEqArg p y do
emit (toCType p.ty); emit " _tmp_"; emit i; emit " = "; emitArg y; emitLn ";"
ps.size.forM fun i => do
let p := ps[i]!
let y := ys[i]!
unless paramEqArg p y do emit p.x; emit " = _tmp_"; emit i; emitLn ";"
emitLn "}"
else
throw "invalid tail call"
ys.size.forM fun i => do
let p := ps[i]!
let y := ys[i]!
unless paramEqArg p y do emit p.x; emit " = "; emitArg y; emitLn ";"
emitLn "goto _start;"
| _ => throw "bug at emitTailCall"
mutual
@@ -658,16 +654,16 @@ def emitDeclAux (d : Decl) : M Unit := do
if xs.size > closureMaxArgs && isBoxedName d.name then
emit "lean_object** _args"
else
xs.size.forM fun i _ => do
xs.size.forM fun i => do
if i > 0 then emit ", "
let x := xs[i]
let x := xs[i]!
emit (toCType x.ty); emit " "; emit x.x
emit ")"
else
emit ("_init_" ++ baseName ++ "()")
emitLn " {";
if xs.size > closureMaxArgs && isBoxedName d.name then
xs.size.forM fun i _ => do
xs.size.forM fun i => do
let x := xs[i]!
emit "lean_object* "; emit x.x; emit " = _args["; emit i; emitLn "];"
emitLn "_start:";

View File

@@ -571,9 +571,9 @@ def emitAllocCtor (builder : LLVM.Builder llvmctx)
def emitCtorSetArgs (builder : LLVM.Builder llvmctx)
(z : VarId) (ys : Array Arg) : M llvmctx Unit := do
ys.size.forM fun i _ => do
ys.size.forM fun i => do
let zv emitLhsVal builder z
let (_yty, yv) emitArgVal builder ys[i]
let (_yty, yv) emitArgVal builder ys[i]!
let iv constIntUnsigned i
callLeanCtorSet builder zv iv yv
emitLhsSlotStore builder z zv
@@ -702,8 +702,8 @@ def emitPartialApp (builder : LLVM.Builder llvmctx) (z : VarId) (f : FunId) (ys
( constIntUnsigned arity)
( constIntUnsigned ys.size)
LLVM.buildStore builder zval zslot
ys.size.forM fun i _ => do
let (yty, yslot) emitArgSlot_ builder ys[i]
ys.size.forM fun i => do
let (yty, yslot) emitArgSlot_ builder ys[i]!
let yval LLVM.buildLoad2 builder yty yslot
callLeanClosureSetFn builder zval ( constIntUnsigned i) yval
@@ -922,7 +922,7 @@ def emitReset (builder : LLVM.Builder llvmctx) (z : VarId) (n : Nat) (x : VarId)
buildIfThenElse_ builder "isExclusive" isExclusive
(fun builder => do
let xv emitLhsVal builder x
n.forM fun i _ => do
n.forM fun i => do
callLeanCtorRelease builder xv ( constIntUnsigned i)
emitLhsSlotStore builder z xv
return ShouldForwardControlFlow.yes

View File

@@ -134,15 +134,15 @@ abbrev M := ReaderT Context (StateM Nat)
modifyGet fun n => ({ idx := n }, n + 1)
def releaseUnreadFields (y : VarId) (mask : Mask) (b : FnBody) : M FnBody :=
mask.size.foldM (init := b) fun i _ b =>
match mask[i] with
mask.size.foldM (init := b) fun i b =>
match mask.get! i with
| some _ => pure b -- code took ownership of this field
| none => do
let fld mkFresh
pure (FnBody.vdecl fld IRType.object (Expr.proj i y) (FnBody.dec fld 1 true false b))
def setFields (y : VarId) (zs : Array Arg) (b : FnBody) : FnBody :=
zs.size.fold (init := b) fun i _ b => FnBody.set y i zs[i] b
zs.size.fold (init := b) fun i b => FnBody.set y i (zs.get! i) b
/-- Given `set x[i] := y`, return true iff `y := proj[i] x` -/
def isSelfSet (ctx : Context) (x : VarId) (i : Nat) (y : Arg) : Bool :=

View File

@@ -79,13 +79,13 @@ private def addDecForAlt (ctx : Context) (caseLiveVars altLiveVars : LiveVarSet)
/-- `isFirstOcc xs x i = true` if `xs[i]` is the first occurrence of `xs[i]` in `xs` -/
private def isFirstOcc (xs : Array Arg) (i : Nat) : Bool :=
let x := xs[i]!
i.all fun j _ => xs[j]! != x
i.all fun j => xs[j]! != x
/-- Return true if `x` also occurs in `ys` in a position that is not consumed.
That is, it is also passed as a borrow reference. -/
private def isBorrowParamAux (x : VarId) (ys : Array Arg) (consumeParamPred : Nat Bool) : Bool :=
ys.size.any fun i _ =>
let y := ys[i]
ys.size.any fun i =>
let y := ys[i]!
match y with
| Arg.irrelevant => false
| Arg.var y => x == y && !consumeParamPred i
@@ -99,15 +99,15 @@ Return `n`, the number of times `x` is consumed.
- `consumeParamPred i = true` if parameter `i` is consumed.
-/
private def getNumConsumptions (x : VarId) (ys : Array Arg) (consumeParamPred : Nat Bool) : Nat :=
ys.size.fold (init := 0) fun i _ n =>
let y := ys[i]
ys.size.fold (init := 0) fun i n =>
let y := ys[i]!
match y with
| Arg.irrelevant => n
| Arg.var y => if x == y && consumeParamPred i then n+1 else n
private def addIncBeforeAux (ctx : Context) (xs : Array Arg) (consumeParamPred : Nat Bool) (b : FnBody) (liveVarsAfter : LiveVarSet) : FnBody :=
xs.size.fold (init := b) fun i _ b =>
let x := xs[i]
xs.size.fold (init := b) fun i b =>
let x := xs[i]!
match x with
| Arg.irrelevant => b
| Arg.var x =>
@@ -128,8 +128,8 @@ private def addIncBefore (ctx : Context) (xs : Array Arg) (ps : Array Param) (b
/-- See `addIncBeforeAux`/`addIncBefore` for the procedure that inserts `inc` operations before an application. -/
private def addDecAfterFullApp (ctx : Context) (xs : Array Arg) (ps : Array Param) (b : FnBody) (bLiveVars : LiveVarSet) : FnBody :=
xs.size.fold (init := b) fun i _ b =>
match xs[i] with
xs.size.fold (init := b) fun i b =>
match xs[i]! with
| Arg.irrelevant => b
| Arg.var x =>
/- We must add a `dec` if `x` must be consumed, it is alive after the application,

View File

@@ -587,15 +587,15 @@ def Decl.elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do
refer to the docstring of `Decl.safe`.
-/
if decls[i]!.safe then .bot else .top
let mut funVals := decls.size.fold (init := .empty) fun i _ p => p.push (initialVal i)
let mut funVals := decls.size.fold (init := .empty) fun i p => p.push (initialVal i)
let ctx := { decls }
let mut state := { assignments, funVals }
(_, state) inferMain |>.run ctx |>.run state
funVals := state.funVals
assignments := state.assignments
modifyEnv fun e =>
decls.size.fold (init := e) fun i _ env =>
addFunctionSummary env decls[i].name funVals[i]!
decls.size.fold (init := e) fun i env =>
addFunctionSummary env decls[i]!.name funVals[i]!
decls.mapIdxM fun i decl => if decl.safe then elimDead assignments[i]! decl else return decl

View File

@@ -76,8 +76,8 @@ def getType (fvarId : FVarId) : InferTypeM Expr := do
def mkForallFVars (xs : Array Expr) (type : Expr) : InferTypeM Expr :=
let b := type.abstract xs
xs.size.foldRevM (init := b) fun i _ b => do
let x := xs[i]
xs.size.foldRevM (init := b) fun i b => do
let x := xs[i]!
let n InferType.getBinderName x.fvarId!
let ty InferType.getType x.fvarId!
let ty := ty.abstractRange i xs;

View File

@@ -157,7 +157,9 @@ def installBeforeEachOccurrence (targetName : Name) (p : Pass → Pass) : PassIn
def replacePass (targetName : Name) (p : Pass Pass) (occurrence : Nat := 0) : PassInstaller where
install passes := do
let some idx := passes.findIdx? (fun p => p.name == targetName && p.occurrence == occurrence) | throwError s!"Tried to replace {targetName}, occurrence {occurrence} but {targetName} is not in the pass list"
return passes.modify idx p
let target := passes[idx]!
let replacement := p target
return passes.set! idx replacement
def replaceEachOccurrence (targetName : Name) (p : Pass Pass) : PassInstaller :=
withEachOccurrence targetName (replacePass targetName p ·)

View File

@@ -11,7 +11,6 @@ import Lean.ResolveName
import Lean.Elab.InfoTree.Types
import Lean.MonadEnv
import Lean.Elab.Exception
import Lean.Language.Basic
namespace Lean
register_builtin_option diagnostics : Bool := {
@@ -73,13 +72,6 @@ structure State where
messages : MessageLog := {}
/-- Info tree. We have the info tree here because we want to update it while adding attributes. -/
infoState : Elab.InfoState := {}
/--
Snapshot trees of asynchronous subtasks. As these are untyped and reported only at the end of the
command's main elaboration thread, they are only useful for basic message log reporting; for
incremental reporting and reuse within a long-running elaboration thread, types rooted in
`CommandParsedSnapshot` need to be adjusted.
-/
snapshotTasks : Array (Language.SnapshotTask Language.SnapshotTree) := #[]
deriving Nonempty
/-- Context for the CoreM monad. -/
@@ -188,8 +180,7 @@ instance : Elab.MonadInfoTree CoreM where
modifyInfoState f := modify fun s => { s with infoState := f s.infoState }
@[inline] def modifyCache (f : Cache Cache) : CoreM Unit :=
modify fun env, next, ngen, trace, cache, messages, infoState, snaps =>
env, next, ngen, trace, f cache, messages, infoState, snaps
modify fun env, next, ngen, trace, cache, messages, infoState => env, next, ngen, trace, f cache, messages, infoState
@[inline] def modifyInstLevelTypeCache (f : InstantiateLevelCache InstantiateLevelCache) : CoreM Unit :=
modifyCache fun c₁, c₂ => f c₁, c₂
@@ -364,83 +355,13 @@ instance : MonadLog CoreM where
if ( read).suppressElabErrors then
-- discard elaboration errors, except for a few important and unlikely misleading ones, on
-- parse error
unless msg.data.hasTag (· matches `Elab.synthPlaceholder | `Tactic.unsolvedGoals | `trace) do
unless msg.data.hasTag (· matches `Elab.synthPlaceholder | `Tactic.unsolvedGoals) do
return
let ctx read
let msg := { msg with data := MessageData.withNamingContext { currNamespace := ctx.currNamespace, openDecls := ctx.openDecls } msg.data };
modify fun s => { s with messages := s.messages.add msg }
/--
Includes a given task (such as from `wrapAsyncAsSnapshot`) in the overall snapshot tree for this
command's elaboration, making its result available to reporting and the language server. The
reporter will not know about this snapshot tree node until the main elaboration thread for this
command has finished so this function is not useful for incremental reporting within a longer
elaboration thread but only for tasks that outlive it such as background kernel checking or proof
elaboration.
-/
def logSnapshotTask (task : Language.SnapshotTask Language.SnapshotTree) : CoreM Unit :=
modify fun s => { s with snapshotTasks := s.snapshotTasks.push task }
/-- Wraps the given action for use in `EIO.asTask` etc., discarding its final monadic state. -/
def wrapAsync (act : Unit CoreM α) : CoreM (EIO Exception α) := do
let st get
let ctx read
let heartbeats := ( IO.getNumHeartbeats) - ctx.initHeartbeats
return withCurrHeartbeats (do
-- include heartbeats since start of elaboration in new thread as well such that forking off
-- an action doesn't suddenly allow it to succeed from a lower heartbeat count
IO.addHeartbeats heartbeats.toUInt64
act () : CoreM _)
|>.run' ctx st
/-- Option for capturing output to stderr during elaboration. -/
register_builtin_option stderrAsMessages : Bool := {
defValue := true
group := "server"
descr := "(server) capture output to the Lean stderr channel (such as from `dbg_trace`) during elaboration of a command as a diagnostic message"
}
open Language in
/--
Wraps the given action for use in `BaseIO.asTask` etc., discarding its final state except for
`logSnapshotTask` tasks, which are reported as part of the returned tree.
-/
def wrapAsyncAsSnapshot (act : Unit CoreM Unit) (desc : String := by exact decl_name%.toString) :
CoreM (BaseIO SnapshotTree) := do
let t wrapAsync fun _ => do
IO.FS.withIsolatedStreams (isolateStderr := stderrAsMessages.get ( getOptions)) do
let tid IO.getTID
-- reset trace state and message log so as not to report them twice
modify ({ · with messages := {}, traceState := { tid } })
try
withTraceNode `Elab.async (fun _ => return desc) do
act ()
catch e =>
logError e.toMessageData
finally
addTraceAsMessages
get
let ctx readThe Core.Context
return do
match ( t.toBaseIO) with
| .ok (output, st) =>
let mut msgs := st.messages
if !output.isEmpty then
msgs := msgs.add {
fileName := ctx.fileName
severity := MessageSeverity.information
pos := ctx.fileMap.toPosition <| ctx.ref.getPos?.getD 0
data := output
}
return .mk {
desc
diagnostics := ( Language.Snapshot.Diagnostics.ofMessageLog msgs)
traces := st.traceState
} st.snapshotTasks
-- interrupt or abort exception as `try catch` above should have caught any others
| .error _ => default
end Core
export Core (CoreM mkFreshUserName checkSystem withCurrHeartbeats)

View File

@@ -277,23 +277,4 @@ attribute [deprecated Std.HashMap.empty (since := "2024-08-08")] mkHashMap
attribute [deprecated Std.HashMap.empty (since := "2024-08-08")] HashMap.empty
attribute [deprecated Std.HashMap.ofList (since := "2024-08-08")] HashMap.ofList
attribute [deprecated Std.HashMap.insert (since := "2024-08-08")] HashMap.insert
attribute [deprecated Std.HashMap.containsThenInsert (since := "2024-08-08")] HashMap.insert'
attribute [deprecated Std.HashMap.insertIfNew (since := "2024-08-08")] HashMap.insertIfNew
attribute [deprecated Std.HashMap.erase (since := "2024-08-08")] HashMap.erase
attribute [deprecated "Use `m[k]?` instead." (since := "2024-08-08")] HashMap.findEntry?
attribute [deprecated "Use `m[k]?` instead." (since := "2024-08-08")] HashMap.find?
attribute [deprecated "Use `m[k]?.getD` instead." (since := "2024-08-08")] HashMap.findD
attribute [deprecated "Use `m[k]!` instead." (since := "2024-08-08")] HashMap.find!
attribute [deprecated Std.HashMap.contains (since := "2024-08-08")] HashMap.contains
attribute [deprecated Std.HashMap.foldM (since := "2024-08-08")] HashMap.foldM
attribute [deprecated Std.HashMap.fold (since := "2024-08-08")] HashMap.fold
attribute [deprecated Std.HashMap.forM (since := "2024-08-08")] HashMap.forM
attribute [deprecated Std.HashMap.size (since := "2024-08-08")] HashMap.size
attribute [deprecated Std.HashMap.isEmpty (since := "2024-08-08")] HashMap.isEmpty
attribute [deprecated Std.HashMap.toList (since := "2024-08-08")] HashMap.toList
attribute [deprecated Std.HashMap.toArray (since := "2024-08-08")] HashMap.toArray
attribute [deprecated "Deprecateed without a replacement." (since := "2024-08-08")] HashMap.numBuckets
attribute [deprecated "Deprecateed without a replacement." (since := "2024-08-08")] HashMap.ofListWith
end Lean.HashMap

View File

@@ -4,7 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Data.Nat.Fold
import Init.Data.Array.Basic
import Init.NotationExtra
import Init.Data.ToString.Macro
@@ -372,7 +371,7 @@ instance : ToString Stats := ⟨Stats.toString⟩
end PersistentArray
def mkPersistentArray {α : Type u} (n : Nat) (v : α) : PArray α :=
n.fold (init := PersistentArray.empty) fun _ _ p => p.push v
n.fold (init := PersistentArray.empty) fun _ p => p.push v
@[inline] def mkPArray {α : Type u} (n : Nat) (v : α) : PArray α :=
mkPersistentArray n v

View File

@@ -23,7 +23,6 @@ import Lean.Elab.Quotation
import Lean.Elab.Syntax
import Lean.Elab.Do
import Lean.Elab.StructInst
import Lean.Elab.MutualInductive
import Lean.Elab.Inductive
import Lean.Elab.Structure
import Lean.Elab.Print

View File

@@ -807,8 +807,8 @@ def getElabElimExprInfo (elimExpr : Expr) : MetaM ElabElimInfo := do
These are the primary set of major parameters.
-/
let initMotiveFVars : CollectFVars.State := motiveArgs.foldl (init := {}) collectFVars
let motiveFVars xs.size.foldRevM (init := initMotiveFVars) fun i _ s => do
let x := xs[i]
let motiveFVars xs.size.foldRevM (init := initMotiveFVars) fun i s => do
let x := xs[i]!
if s.fvarSet.contains x.fvarId! then
return collectFVars s ( inferType x)
else
@@ -1150,33 +1150,48 @@ private def throwLValError (e : Expr) (eType : Expr) (msg : MessageData) : TermE
throwError "{msg}{indentExpr e}\nhas type{indentExpr eType}"
/--
`findMethod? S fName` tries the for each namespace `S'` in the resolution order for `S` to resolve the name `S'.fname`.
If it resolves to `name`, returns `(S', name)`.
`findMethod? S fName` tries the following for each namespace `S'` in the resolution order for `S`:
- If `env` contains `S' ++ fName`, returns `(S', S' ++ fName)`
- Otherwise if `env` contains private name `prv` for `S' ++ fName`, returns `(S', prv)`
-/
private partial def findMethod? (structName fieldName : Name) : MetaM (Option (Name × Name)) := do
let env getEnv
let find? structName' : MetaM (Option (Name × Name)) := do
let fullName := structName' ++ fieldName
-- We do not want to make use of the current namespace for resolution.
let candidates := ResolveName.resolveGlobalName ( getEnv) Name.anonymous ( getOpenDecls) fullName
|>.filter (fun (_, fieldList) => fieldList.isEmpty)
|>.map Prod.fst
match candidates with
| [] => return none
| [fullName'] => return some (structName', fullName')
| _ => throwError "\
invalid field notation '{fieldName}', the name '{fullName}' is ambiguous, possible interpretations: \
{MessageData.joinSep (candidates.map (m!"'{.ofConstName ·}'")) ", "}"
if env.contains fullName then
return some (structName', fullName)
let fullNamePrv := mkPrivateName env fullName
if env.contains fullNamePrv then
return some (structName', fullNamePrv)
return none
-- Optimization: the first element of the resolution order is `structName`,
-- so we can skip computing the resolution order in the common case
-- of the name resolving in the `structName` namespace.
find? structName <||> do
let resolutionOrder if isStructure env structName then getStructureResolutionOrder structName else pure #[structName]
for ns in resolutionOrder[1:resolutionOrder.size] do
if let some res find? ns then
for h : i in [1:resolutionOrder.size] do
if let some res find? resolutionOrder[i] then
return res
return none
/--
Return `some (structName', fullName)` if `structName ++ fieldName` is an alias for `fullName`, and
`fullName` is of the form `structName' ++ fieldName`.
TODO: if there is more than one applicable alias, it returns `none`. We should consider throwing an error or
warning.
-/
private def findMethodAlias? (env : Environment) (structName fieldName : Name) : Option (Name × Name) :=
let fullName := structName ++ fieldName
-- We never skip `protected` aliases when resolving dot-notation.
let aliasesCandidates := getAliases env fullName (skipProtected := false) |>.filterMap fun alias =>
match alias.eraseSuffix? fieldName with
| none => none
| some structName' => some (structName', alias)
match aliasesCandidates with
| [r] => some r
| _ => none
private def throwInvalidFieldNotation (e eType : Expr) : TermElabM α :=
throwLValError e eType "invalid field notation, type is not of the form (C ...) where C is a constant"
@@ -1208,22 +1223,30 @@ private def resolveLValAux (e : Expr) (eType : Expr) (lval : LVal) : TermElabM L
throwLValError e eType m!"invalid projection, structure has only {numFields} field(s)"
| some structName, LVal.fieldName _ fieldName _ _ =>
let env getEnv
let searchEnv : Unit TermElabM LValResolution := fun _ => do
if let some (baseStructName, fullName) findMethod? structName (.mkSimple fieldName) then
return LValResolution.const baseStructName structName fullName
else if let some (structName', fullName) := findMethodAlias? env structName (.mkSimple fieldName) then
return LValResolution.const structName' structName' fullName
else
throwLValError e eType
m!"invalid field '{fieldName}', the environment does not contain '{Name.mkStr structName fieldName}'"
-- search local context first, then environment
let searchCtx : Unit TermElabM LValResolution := fun _ => do
let fullName := Name.mkStr structName fieldName
for localDecl in ( getLCtx) do
if localDecl.isAuxDecl then
if let some localDeclFullName := ( read).auxDeclToFullName.find? localDecl.fvarId then
if fullName == (privateToUserName? localDeclFullName).getD localDeclFullName then
/- LVal notation is being used to make a "local" recursive call. -/
return LValResolution.localRec structName fullName localDecl.toExpr
searchEnv ()
if isStructure env structName then
if let some baseStructName := findField? env structName (Name.mkSimple fieldName) then
return LValResolution.projFn baseStructName structName (Name.mkSimple fieldName)
-- Search the local context first
let fullName := Name.mkStr structName fieldName
for localDecl in ( getLCtx) do
if localDecl.isAuxDecl then
if let some localDeclFullName := ( read).auxDeclToFullName.find? localDecl.fvarId then
if fullName == (privateToUserName? localDeclFullName).getD localDeclFullName then
/- LVal notation is being used to make a "local" recursive call. -/
return LValResolution.localRec structName fullName localDecl.toExpr
-- Then search the environment
if let some (baseStructName, fullName) findMethod? structName (.mkSimple fieldName) then
return LValResolution.const baseStructName structName fullName
throwLValError e eType
m!"invalid field '{fieldName}', the environment does not contain '{Name.mkStr structName fieldName}'"
match findField? env structName (Name.mkSimple fieldName) with
| some baseStructName => return LValResolution.projFn baseStructName structName (Name.mkSimple fieldName)
| none => searchCtx ()
else
searchCtx ()
| none, LVal.fieldName _ _ (some suffix) _ =>
if e.isConst then
throwUnknownConstant (e.constName! ++ suffix)
@@ -1303,7 +1326,7 @@ Otherwise, if there isn't another parameter with the same name, we add `e` to `n
Remark: `fullName` is the name of the resolved "field" access function. It is used for reporting errors
-/
private partial def addLValArg (baseName : Name) (fullName : Name) (e : Expr) (args : Array Arg) (namedArgs : Array NamedArg) (f : Expr) (explicit : Bool) :
private partial def addLValArg (baseName : Name) (fullName : Name) (e : Expr) (args : Array Arg) (namedArgs : Array NamedArg) (f : Expr) :
MetaM (Array Arg × Array NamedArg) := do
withoutModifyingState <| go f ( inferType f) 0 namedArgs (namedArgs.map (·.name)) true
where
@@ -1328,11 +1351,11 @@ where
/- If there is named argument with name `xDecl.userName`, then it is accounted for and we can't make use of it. -/
remainingNamedArgs := remainingNamedArgs.eraseIdx idx
else
if typeMatchesBaseName xDecl.type baseName then
/- We found a type of the form (baseName ...), or we found the first explicit argument in useFirstExplicit mode.
First, we check if the current argument is one that can be used positionally,
if ( typeMatchesBaseName xDecl.type baseName) then
/- We found a type of the form (baseName ...).
First, we check if the current argument is an explicit one,
and if the current explicit position "fits" at `args` (i.e., it must be ≤ arg.size) -/
if h : argIdx args.size (explicit || bInfo.isExplicit) then
if h : argIdx args.size bInfo.isExplicit then
/- We can insert `e` as an explicit argument -/
return (args.insertIdx argIdx (Arg.expr e), namedArgs)
else
@@ -1340,13 +1363,13 @@ where
if there isn't an argument with the same name occurring before it. -/
if !allowNamed || unusableNamedArgs.contains xDecl.userName then
throwError "\
invalid field notation, function '{.ofConstName fullName}' has argument with the expected type\
invalid field notation, function '{fullName}' has argument with the expected type\
{indentExpr xDecl.type}\n\
but it cannot be used"
else
return (args, namedArgs.push { name := xDecl.userName, val := Arg.expr e })
/- Advance `argIdx` and update seen named arguments. -/
if explicit || bInfo.isExplicit then
if bInfo.isExplicit then
argIdx := argIdx + 1
unusableNamedArgs := unusableNamedArgs.push xDecl.userName
/- If named arguments aren't allowed, then it must still be possible to pass the value as an explicit argument.
@@ -1357,7 +1380,7 @@ where
if let some f' coerceToFunction? (mkAppN f xs) then
return go f' ( inferType f') argIdx remainingNamedArgs unusableNamedArgs false
throwError "\
invalid field notation, function '{.ofConstName fullName}' does not have argument with type ({.ofConstName baseName} ...) that can be used, \
invalid field notation, function '{fullName}' does not have argument with type ({baseName} ...) that can be used, \
it must be explicit or implicit with a unique name"
/-- Adds the `TermInfo` for the field of a projection. See `Lean.Parser.Term.identProjKind`. -/
@@ -1403,7 +1426,7 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp
let projFn mkConst constName
let projFn addProjTermInfo lval.getRef projFn
if lvals.isEmpty then
let (args, namedArgs) addLValArg baseStructName constName f args namedArgs projFn explicit
let (args, namedArgs) addLValArg baseStructName constName f args namedArgs projFn
elabAppArgs projFn namedArgs args expectedType? explicit ellipsis
else
let f elabAppArgs projFn #[] #[Arg.expr f] (expectedType? := none) (explicit := false) (ellipsis := false)
@@ -1411,7 +1434,7 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp
| LValResolution.localRec baseName fullName fvar =>
let fvar addProjTermInfo lval.getRef fvar
if lvals.isEmpty then
let (args, namedArgs) addLValArg baseName fullName f args namedArgs fvar explicit
let (args, namedArgs) addLValArg baseName fullName f args namedArgs fvar
elabAppArgs fvar namedArgs args expectedType? explicit ellipsis
else
let f elabAppArgs fvar #[] #[Arg.expr f] (expectedType? := none) (explicit := false) (ellipsis := false)

View File

@@ -84,7 +84,6 @@ structure State where
ngen : NameGenerator := {}
infoState : InfoState := {}
traceState : TraceState := {}
snapshotTasks : Array (Language.SnapshotTask Language.SnapshotTree) := #[]
deriving Nonempty
structure Context where
@@ -115,7 +114,8 @@ structure Context where
-/
suppressElabErrors : Bool := false
abbrev CommandElabM := ReaderT Context $ StateRefT State $ EIO Exception
abbrev CommandElabCoreM (ε) := ReaderT Context $ StateRefT State $ EIO ε
abbrev CommandElabM := CommandElabCoreM Exception
abbrev CommandElab := Syntax CommandElabM Unit
structure Linter where
run : Syntax CommandElabM Unit
@@ -198,6 +198,36 @@ instance : AddErrorMessageContext CommandElabM where
let msg addMacroStack msg ctx.macroStack
return (ref, msg)
def mkMessageAux (ctx : Context) (ref : Syntax) (msgData : MessageData) (severity : MessageSeverity) : Message :=
let pos := ref.getPos?.getD ctx.cmdPos
let endPos := ref.getTailPos?.getD pos
mkMessageCore ctx.fileName ctx.fileMap msgData severity pos endPos
private def addTraceAsMessagesCore (ctx : Context) (log : MessageLog) (traceState : TraceState) : MessageLog := Id.run do
if traceState.traces.isEmpty then return log
let mut traces : Std.HashMap (String.Pos × String.Pos) (Array MessageData) :=
for traceElem in traceState.traces do
let ref := replaceRef traceElem.ref ctx.ref
let pos := ref.getPos?.getD 0
let endPos := ref.getTailPos?.getD pos
traces := traces.insert (pos, endPos) <| traces.getD (pos, endPos) #[] |>.push traceElem.msg
let mut log := log
let traces' := traces.toArray.qsort fun ((a, _), _) ((b, _), _) => a < b
for ((pos, endPos), traceMsg) in traces' do
let data := .tagged `trace <| .joinSep traceMsg.toList "\n"
log := log.add <| mkMessageCore ctx.fileName ctx.fileMap data .information pos endPos
return log
private def addTraceAsMessages : CommandElabM Unit := do
let ctx read
-- do not add trace messages if `trace.profiler.output` is set as it would be redundant and
-- pretty printing the trace messages is expensive
if trace.profiler.output.get? ( getOptions) |>.isNone then
modify fun s => { s with
messages := addTraceAsMessagesCore ctx s.messages s.traceState
traceState.traces := {}
}
private def runCore (x : CoreM α) : CommandElabM α := do
let s get
let ctx read
@@ -223,7 +253,6 @@ private def runCore (x : CoreM α) : CommandElabM α := do
nextMacroScope := s.nextMacroScope
infoState.enabled := s.infoState.enabled
traceState := s.traceState
snapshotTasks := s.snapshotTasks
}
let (ea, coreS) liftM x
modify fun s => { s with
@@ -232,7 +261,6 @@ private def runCore (x : CoreM α) : CommandElabM α := do
ngen := coreS.ngen
infoState.trees := s.infoState.trees.append coreS.infoState.trees
traceState.traces := coreS.traceState.traces.map fun t => { t with ref := replaceRef t.ref ctx.ref }
snapshotTasks := coreS.snapshotTasks
messages := s.messages ++ coreS.messages
}
return ea
@@ -240,6 +268,10 @@ private def runCore (x : CoreM α) : CommandElabM α := do
def liftCoreM (x : CoreM α) : CommandElabM α := do
MonadExcept.ofExcept ( runCore (observing x))
private def ioErrorToMessage (ctx : Context) (ref : Syntax) (err : IO.Error) : Message :=
let ref := getBetterRef ref ctx.macroStack
mkMessageAux ctx ref (toString err) MessageSeverity.error
@[inline] def liftIO {α} (x : IO α) : CommandElabM α := do
let ctx read
IO.toEIO (fun (ex : IO.Error) => Exception.error ctx.ref ex.toString) x
@@ -262,8 +294,9 @@ instance : MonadLog CommandElabM where
logMessage msg := do
if ( read).suppressElabErrors then
-- discard elaboration errors on parse error
unless msg.data.hasTag (· matches `trace) do
return
-- NOTE: unlike `CoreM`'s `logMessage`, we do not currently have any command-level errors that
-- we want to allowlist
return
let currNamespace getCurrNamespace
let openDecls getOpenDecls
let msg := { msg with data := MessageData.withNamingContext { currNamespace := currNamespace, openDecls := openDecls } msg.data }
@@ -289,61 +322,6 @@ def runLinters (stx : Syntax) : CommandElabM Unit := do
finally
modify fun s => { savedState with messages := s.messages }
/--
Catches and logs exceptions occurring in `x`. Unlike `try catch` in `CommandElabM`, this function
catches interrupt exceptions as well and thus is intended for use at the top level of elaboration.
Interrupt and abort exceptions are caught but not logged.
-/
@[inline] def withLoggingExceptions (x : CommandElabM Unit) : CommandElabM Unit := fun ctx ref =>
EIO.catchExceptions (withLogging x ctx ref) (fun _ => pure ())
@[inherit_doc Core.wrapAsync]
def wrapAsync (act : Unit CommandElabM α) : CommandElabM (EIO Exception α) := do
return act () |>.run ( read) |>.run' ( get)
open Language in
@[inherit_doc Core.wrapAsyncAsSnapshot]
-- `CoreM` and `CommandElabM` are too different to meaningfully share this code
def wrapAsyncAsSnapshot (act : Unit CommandElabM Unit)
(desc : String := by exact decl_name%.toString) :
CommandElabM (BaseIO SnapshotTree) := do
let t wrapAsync fun _ => do
IO.FS.withIsolatedStreams (isolateStderr := Core.stderrAsMessages.get ( getOptions)) do
let tid IO.getTID
-- reset trace state and message log so as not to report them twice
modify ({ · with messages := {}, traceState := { tid } })
try
withTraceNode `Elab.async (fun _ => return desc) do
act ()
catch e =>
logError e.toMessageData
finally
addTraceAsMessages
get
let ctx read
return do
match ( t.toBaseIO) with
| .ok (output, st) =>
let mut msgs := st.messages
if !output.isEmpty then
msgs := msgs.add {
fileName := ctx.fileName
severity := MessageSeverity.information
pos := ctx.fileMap.toPosition <| ctx.ref.getPos?.getD 0
data := output
}
return .mk {
desc
diagnostics := ( Language.Snapshot.Diagnostics.ofMessageLog msgs)
traces := st.traceState
} st.snapshotTasks
-- interrupt or abort exception as `try catch` above should have caught any others
| .error _ => default
@[inherit_doc Core.logSnapshotTask]
def logSnapshotTask (task : Language.SnapshotTask Language.SnapshotTree) : CommandElabM Unit :=
modify fun s => { s with snapshotTasks := s.snapshotTasks.push task }
protected def getCurrMacroScope : CommandElabM Nat := do pure ( read).currMacroScope
protected def getMainModule : CommandElabM Name := do pure ( getEnv).mainModule
@@ -554,6 +532,12 @@ def elabCommandTopLevel (stx : Syntax) : CommandElabM Unit := withRef stx do pro
let mut msgs := ( get).messages
for tree in ( getInfoTrees) do
trace[Elab.info] ( tree.format)
if ( isTracingEnabledFor `Elab.snapshotTree) then
if let some snap := ( read).snap? then
-- We can assume that the root command snapshot is not involved in parallelism yet, so this
-- should be true iff the command supports incrementality
if ( IO.hasFinished snap.new.result) then
liftCoreM <| Language.ToSnapshotTree.toSnapshotTree snap.new.result.get |>.trace
modify fun st => { st with
messages := initMsgs ++ msgs
infoState := { st.infoState with trees := initInfoTrees ++ st.infoState.trees }
@@ -684,6 +668,14 @@ def runTermElabM (elabFn : Array Expr → TermElabM α) : CommandElabM α := do
Term.addAutoBoundImplicits' xs someType fun xs _ =>
Term.withoutAutoBoundImplicit <| elabFn xs
/--
Catches and logs exceptions occurring in `x`. Unlike `try catch` in `CommandElabM`, this function
catches interrupt exceptions as well and thus is intended for use at the top level of elaboration.
Interrupt and abort exceptions are caught but not logged.
-/
@[inline] def withLoggingExceptions (x : CommandElabM Unit) : CommandElabCoreM Empty Unit := fun ctx ref =>
EIO.catchExceptions (withLogging x ctx ref) (fun _ => pure ())
private def liftAttrM {α} (x : AttrM α) : CommandElabM α := do
liftCoreM x

View File

@@ -7,8 +7,9 @@ prelude
import Lean.Util.CollectLevelParams
import Lean.Elab.DeclUtil
import Lean.Elab.DefView
import Lean.Elab.Inductive
import Lean.Elab.Structure
import Lean.Elab.MutualDef
import Lean.Elab.MutualInductive
import Lean.Elab.DeclarationRange
namespace Lean.Elab.Command
@@ -162,11 +163,15 @@ def elabDeclaration : CommandElab := fun stx => do
if declKind == ``Lean.Parser.Command.«axiom» then
let modifiers elabModifiers modifiers
elabAxiom modifiers decl
else if declKind == ``Lean.Parser.Command.«inductive»
|| declKind == ``Lean.Parser.Command.classInductive
|| declKind == ``Lean.Parser.Command.«structure» then
else if declKind == ``Lean.Parser.Command.«inductive» then
let modifiers elabModifiers modifiers
elabInductive modifiers decl
else if declKind == ``Lean.Parser.Command.classInductive then
let modifiers elabModifiers modifiers
elabClassInductive modifiers decl
else if declKind == ``Lean.Parser.Command.«structure» then
let modifiers elabModifiers modifiers
elabStructure modifiers decl
else
throwError "unexpected declaration"
@@ -273,10 +278,10 @@ def elabMutual : CommandElab := fun stx => do
-- only case implementing incrementality currently
elabMutualDef stx[1].getArgs
else withoutCommandIncrementality true do
if isMutualInductive stx then
if isMutualInductive stx then
elabMutualInductive stx[1].getArgs
else
throwError "invalid mutual block: either all elements of the block must be inductive/structure declarations, or they must all be definitions/theorems/abbrevs"
throwError "invalid mutual block: either all elements of the block must be inductive declarations, or they must all be definitions/theorems/abbrevs"
/- leading_parser "attribute " >> "[" >> sepBy1 (eraseAttr <|> Term.attrInstance) ", " >> "]" >> many1 ident -/
@[builtin_command_elab «attribute»] def elabAttr : CommandElab := fun stx => do

View File

@@ -1,36 +1,102 @@
/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura, Kyle Miller
Authors: Leonardo de Moura
-/
prelude
import Lean.Elab.MutualInductive
import Lean.Util.ForEachExprWhere
import Lean.Util.ReplaceLevel
import Lean.Util.ReplaceExpr
import Lean.Util.CollectLevelParams
import Lean.Meta.Constructions
import Lean.Meta.CollectFVars
import Lean.Meta.SizeOf
import Lean.Meta.Injective
import Lean.Meta.IndPredBelow
import Lean.Elab.Command
import Lean.Elab.ComputedFields
import Lean.Elab.DefView
import Lean.Elab.DeclUtil
import Lean.Elab.Deriving.Basic
import Lean.Elab.DeclarationRange
namespace Lean.Elab.Command
open Meta
/-
```
def Lean.Parser.Command.«inductive» :=
leading_parser "inductive " >> declId >> optDeclSig >> optional ("where" <|> ":=") >> many ctor
builtin_initialize
registerTraceClass `Elab.inductive
def Lean.Parser.Command.classInductive :=
leading_parser atomic (group ("class " >> "inductive ")) >> declId >> optDeclSig >> optional ("where" <|> ":=") >> many ctor >> optDeriving
```
register_builtin_option inductive.autoPromoteIndices : Bool := {
defValue := true
descr := "Promote indices to parameters in inductive types whenever possible."
}
def checkValidInductiveModifier [Monad m] [MonadError m] (modifiers : Modifiers) : m Unit := do
if modifiers.isNoncomputable then
throwError "invalid use of 'noncomputable' in inductive declaration"
if modifiers.isPartial then
throwError "invalid use of 'partial' in inductive declaration"
def checkValidCtorModifier [Monad m] [MonadError m] (modifiers : Modifiers) : m Unit := do
if modifiers.isNoncomputable then
throwError "invalid use of 'noncomputable' in constructor declaration"
if modifiers.isPartial then
throwError "invalid use of 'partial' in constructor declaration"
if modifiers.isUnsafe then
throwError "invalid use of 'unsafe' in constructor declaration"
if modifiers.attrs.size != 0 then
throwError "invalid use of attributes in constructor declaration"
structure CtorView where
ref : Syntax
modifiers : Modifiers
declName : Name
binders : Syntax
type? : Option Syntax
deriving Inhabited
structure ComputedFieldView where
ref : Syntax
modifiers : Syntax
fieldId : Name
type : Syntax.Term
matchAlts : TSyntax ``Parser.Term.matchAlts
structure InductiveView where
ref : Syntax
declId : Syntax
modifiers : Modifiers
shortDeclName : Name
declName : Name
levelNames : List Name
binders : Syntax
type? : Option Syntax
ctors : Array CtorView
derivingClasses : Array DerivingClassView
computedFields : Array ComputedFieldView
deriving Inhabited
structure ElabHeaderResult where
view : InductiveView
lctx : LocalContext
localInsts : LocalInstances
levelNames : List Name
params : Array Expr
type : Expr
deriving Inhabited
/-
leading_parser "inductive " >> declId >> optDeclSig >> optional ("where" <|> ":=") >> many ctor
leading_parser atomic (group ("class " >> "inductive ")) >> declId >> optDeclSig >> optional ("where" <|> ":=") >> many ctor >> optDeriving
-/
private def inductiveSyntaxToView (modifiers : Modifiers) (decl : Syntax) : TermElabM InductiveView := do
let isClass := decl.isOfKind ``Parser.Command.classInductive
let modifiers := if isClass then modifiers.addAttr { name := `class } else modifiers
checkValidInductiveModifier modifiers
let (binders, type?) := expandOptDeclSig decl[2]
let declId := decl[1]
let name, declName, levelNames Term.expandDeclId ( getCurrNamespace) ( Term.getLevelNames) declId modifiers
addDeclarationRangesForBuiltin declName modifiers.stx decl
let ctors decl[4].getArgs.mapM fun ctor => withRef ctor do
/-
```
def ctor := leading_parser optional docComment >> "\n| " >> declModifiers >> rawIdent >> optDeclSig
```
-/
-- def ctor := leading_parser optional docComment >> "\n| " >> declModifiers >> rawIdent >> optDeclSig
let mut ctorModifiers elabModifiers ctor[2]
if let some leadingDocComment := ctor[0].getOptional? then
if ctorModifiers.docString?.isSome then
@@ -47,7 +113,7 @@ private def inductiveSyntaxToView (modifiers : Modifiers) (decl : Syntax) : Term
let (binders, type?) := expandOptDeclSig ctor[4]
addDocString' ctorName ctorModifiers.docString?
addDeclarationRangesFromSyntax ctorName ctor ctor[3]
return { ref := ctor, declId := ctor[3], modifiers := ctorModifiers, declName := ctorName, binders := binders, type? := type? : CtorView }
return { ref := ctor, modifiers := ctorModifiers, declName := ctorName, binders := binders, type? := type? : CtorView }
let computedFields (decl[5].getOptional?.map (·[1].getArgs) |>.getD #[]).mapM fun cf => withRef cf do
return { ref := cf, modifiers := cf[0], fieldId := cf[1].getId, type := cf[3], matchAlts := cf[4] }
let classes getOptDerivingClasses decl[6]
@@ -59,13 +125,137 @@ private def inductiveSyntaxToView (modifiers : Modifiers) (decl : Syntax) : Term
ref := decl
shortDeclName := name
derivingClasses := classes
allowIndices := true
allowSortPolymorphism := true
declId, modifiers, isClass, declName, levelNames
declId, modifiers, declName, levelNames
binders, type?, ctors
computedFields
}
private partial def elabHeaderAux (views : Array InductiveView) (i : Nat) (acc : Array ElabHeaderResult) : TermElabM (Array ElabHeaderResult) :=
Term.withAutoBoundImplicitForbiddenPred (fun n => views.any (·.shortDeclName == n)) do
if h : i < views.size then
let view := views[i]
let acc Term.withAutoBoundImplicit <| Term.elabBinders view.binders.getArgs fun params => do
match view.type? with
| none =>
let u mkFreshLevelMVar
let type := mkSort u
Term.synthesizeSyntheticMVarsNoPostponing
Term.addAutoBoundImplicits' params type fun params type => do
let levelNames Term.getLevelNames
return acc.push { lctx := ( getLCtx), localInsts := ( getLocalInstances), levelNames, params, type, view }
| some typeStx =>
let (type, _) Term.withAutoBoundImplicit do
let type Term.elabType typeStx
unless ( isTypeFormerType type) do
throwErrorAt typeStx "invalid inductive type, resultant type is not a sort"
Term.synthesizeSyntheticMVarsNoPostponing
let indices Term.addAutoBoundImplicits #[]
return ( mkForallFVars indices type, indices.size)
Term.addAutoBoundImplicits' params type fun params type => do
trace[Elab.inductive] "header params: {params}, type: {type}"
let levelNames Term.getLevelNames
return acc.push { lctx := ( getLCtx), localInsts := ( getLocalInstances), levelNames, params, type, view }
elabHeaderAux views (i+1) acc
else
return acc
private def checkNumParams (rs : Array ElabHeaderResult) : TermElabM Nat := do
let numParams := rs[0]!.params.size
for r in rs do
unless r.params.size == numParams do
throwErrorAt r.view.ref "invalid inductive type, number of parameters mismatch in mutually inductive datatypes"
return numParams
private def checkUnsafe (rs : Array ElabHeaderResult) : TermElabM Unit := do
let isUnsafe := rs[0]!.view.modifiers.isUnsafe
for r in rs do
unless r.view.modifiers.isUnsafe == isUnsafe do
throwErrorAt r.view.ref "invalid inductive type, cannot mix unsafe and safe declarations in a mutually inductive datatypes"
private def InductiveView.checkLevelNames (views : Array InductiveView) : TermElabM Unit := do
if h : views.size > 1 then
let levelNames := views[0].levelNames
for view in views do
unless view.levelNames == levelNames do
throwErrorAt view.ref "invalid inductive type, universe parameters mismatch in mutually inductive datatypes"
private def ElabHeaderResult.checkLevelNames (rs : Array ElabHeaderResult) : TermElabM Unit := do
if h : rs.size > 1 then
let levelNames := rs[0].levelNames
for r in rs do
unless r.levelNames == levelNames do
throwErrorAt r.view.ref "invalid inductive type, universe parameters mismatch in mutually inductive datatypes"
private def mkTypeFor (r : ElabHeaderResult) : TermElabM Expr := do
withLCtx r.lctx r.localInsts do
mkForallFVars r.params r.type
private def throwUnexpectedInductiveType : TermElabM α :=
throwError "unexpected inductive resulting type"
private def eqvFirstTypeResult (firstType type : Expr) : MetaM Bool :=
forallTelescopeReducing firstType fun _ firstTypeResult => isDefEq firstTypeResult type
-- Auxiliary function for checking whether the types in mutually inductive declaration are compatible.
private partial def checkParamsAndResultType (type firstType : Expr) (numParams : Nat) : TermElabM Unit := do
try
forallTelescopeCompatible type firstType numParams fun _ type firstType =>
forallTelescopeReducing type fun _ type =>
forallTelescopeReducing firstType fun _ firstType => do
let type whnfD type
match type with
| .sort .. =>
unless ( isDefEq firstType type) do
throwError "resulting universe mismatch, given{indentExpr type}\nexpected type{indentExpr firstType}"
| _ =>
throwError "unexpected inductive resulting type"
catch
| Exception.error ref msg => throw (Exception.error ref m!"invalid mutually inductive types, {msg}")
| ex => throw ex
-- Auxiliary function for checking whether the types in mutually inductive declaration are compatible.
private def checkHeader (r : ElabHeaderResult) (numParams : Nat) (firstType? : Option Expr) : TermElabM Expr := do
let type mkTypeFor r
match firstType? with
| none => return type
| some firstType =>
withRef r.view.ref <| checkParamsAndResultType type firstType numParams
return firstType
-- Auxiliary function for checking whether the types in mutually inductive declaration are compatible.
private partial def checkHeaders (rs : Array ElabHeaderResult) (numParams : Nat) (i : Nat) (firstType? : Option Expr) : TermElabM Unit := do
if h : i < rs.size then
let type checkHeader rs[i] numParams firstType?
checkHeaders rs numParams (i+1) type
private def elabHeader (views : Array InductiveView) : TermElabM (Array ElabHeaderResult) := do
let rs elabHeaderAux views 0 #[]
if rs.size > 1 then
checkUnsafe rs
let numParams checkNumParams rs
checkHeaders rs numParams 0 none
return rs
/-- Create a local declaration for each inductive type in `rs`, and execute `x params indFVars`, where `params` are the inductive type parameters and
`indFVars` are the new local declarations.
We use the local context/instances and parameters of rs[0].
Note that this method is executed after we executed `checkHeaders` and established all
parameters are compatible. -/
private partial def withInductiveLocalDecls (rs : Array ElabHeaderResult) (x : Array Expr Array Expr TermElabM α) : TermElabM α := do
let namesAndTypes rs.mapM fun r => do
let type mkTypeFor r
pure (r.view.declName, r.view.shortDeclName, type)
let r0 := rs[0]!
let params := r0.params
withLCtx r0.lctx r0.localInsts <| withRef r0.view.ref do
let rec loop (i : Nat) (indFVars : Array Expr) := do
if h : i < namesAndTypes.size then
let (declName, shortDeclName, type) := namesAndTypes[i]
Term.withAuxDecl shortDeclName type declName fun indFVar => loop (i+1) (indFVars.push indFVar)
else
x params indFVars
loop 0 #[]
private def isInductiveFamily (numParams : Nat) (indFVar : Expr) : TermElabM Bool := do
let indFVarType inferType indFVar
forallTelescopeReducing indFVarType fun xs _ =>
@@ -81,8 +271,8 @@ where
| _ => acc
/--
Replaces binder names in `type` with `newNames`.
Remark: we only replace the names for binder containing macroscopes.
Replace binder names in `type` with `newNames`.
Remark: we only replace the names for binder containing macroscopes.
-/
private def replaceArrowBinderNames (type : Expr) (newNames : Array Name) : Expr :=
go type 0
@@ -100,20 +290,20 @@ where
type
/--
Reorders constructor arguments to improve the effectiveness of the `fixedIndicesToParams` method.
Reorder constructor arguments to improve the effectiveness of the `fixedIndicesToParams` method.
The idea is quite simple. Given a constructor type of the form
```
(a₁ : A₁) → ... → (aₙ : Aₙ) → C b₁ ... bₘ
```
We try to find the longest prefix `b₁ ... bᵢ`, `i ≤ m` s.t.
- each `bₖ` is in `{a₁, ..., aₙ}`
- each `bₖ` only depends on variables in `{b₁, ..., bₖ₋₁}`
The idea is quite simple. Given a constructor type of the form
```
(a₁ : A₁) → ... → (aₙ : Aₙ) → C b₁ ... bₘ
```
We try to find the longest prefix `b₁ ... bᵢ`, `i ≤ m` s.t.
- each `bₖ` is in `{a₁, ..., aₙ}`
- each `bₖ` only depends on variables in `{b₁, ..., bₖ₋₁}`
Then, it moves this prefix `b₁ ... bᵢ` to the front.
Then, it moves this prefix `b₁ ... bᵢ` to the front.
Remark: We only reorder implicit arguments that have macroscopes. See issue #1156.
The macroscope test is an approximation, we could have restricted ourselves to auto-implicit arguments.
Remark: We only reorder implicit arguments that have macroscopes. See issue #1156.
The macroscope test is an approximation, we could have restricted ourselves to auto-implicit arguments.
-/
private def reorderCtorArgs (ctorType : Expr) : MetaM Expr := do
forallTelescopeReducing ctorType fun as type => do
@@ -158,6 +348,16 @@ private def reorderCtorArgs (ctorType : Expr) : MetaM Expr := do
let binderNames := getArrowBinderNames ( instantiateMVars ( inferType C))
return replaceArrowBinderNames r binderNames[:bsPrefix.size]
/--
Execute `k` with updated binder information for `xs`. Any `x` that is explicit becomes implicit.
-/
private def withExplicitToImplicit (xs : Array Expr) (k : TermElabM α) : TermElabM α := do
let mut toImplicit := #[]
for x in xs do
if ( getFVarLocalDecl x).binderInfo.isExplicit then
toImplicit := toImplicit.push (x.fvarId!, BinderInfo.implicit)
withNewBinderInfos toImplicit k
/--
Elaborate constructor types.
@@ -166,20 +366,17 @@ private def reorderCtorArgs (ctorType : Expr) : MetaM Expr := do
- Positivity (it is a rare failure, and the kernel already checks for it).
- Universe constraints (the kernel checks for it).
-/
private def elabCtors (indFVars : Array Expr) (params : Array Expr) (r : ElabHeaderResult) : TermElabM (List Constructor) :=
withRef r.view.ref do
withExplicitToImplicit params do
let indFVar := r.indFVar
private def elabCtors (indFVars : Array Expr) (indFVar : Expr) (params : Array Expr) (r : ElabHeaderResult) : TermElabM (List Constructor) := withRef r.view.ref do
let indFamily isInductiveFamily params.size indFVar
r.view.ctors.toList.mapM fun ctorView =>
Term.withAutoBoundImplicit <| Term.elabBinders ctorView.binders.getArgs fun ctorParams =>
withRef ctorView.ref do
let elabCtorType : TermElabM Expr := do
let rec elabCtorType (k : Expr TermElabM Constructor) : TermElabM Constructor := do
match ctorView.type? with
| none =>
if indFamily then
throwError "constructor resulting type must be specified in inductive family declaration"
return mkAppN indFVar params
k <| mkAppN indFVar params
| some ctorType =>
let type Term.elabType ctorType
trace[Elab.inductive] "elabType {ctorView.declName} : {type} "
@@ -191,43 +388,43 @@ private def elabCtors (indFVars : Array Expr) (params : Array Expr) (r : ElabHea
throwError "unexpected constructor resulting type{indentExpr resultingType}"
unless ( isType resultingType) do
throwError "unexpected constructor resulting type, type expected{indentExpr resultingType}"
return type
let type elabCtorType
Term.synthesizeSyntheticMVarsNoPostponing
let ctorParams Term.addAutoBoundImplicits ctorParams
let except (mvarId : MVarId) := ctorParams.any fun ctorParam => ctorParam.isMVar && ctorParam.mvarId! == mvarId
/-
We convert metavariables in the resulting type into extra parameters. Otherwise, we would not be able to elaborate
declarations such as
```
inductive Palindrome : List α → Prop where
| nil : Palindrome [] -- We would get an error here saying "failed to synthesize implicit argument" at `@List.nil ?m`
| single : (a : α) → Palindrome [a]
| sandwich : (a : α) → Palindrome as → Palindrome ([a] ++ as ++ [a])
```
We used to also collect unassigned metavariables on `ctorParams`, but it produced counterintuitive behavior.
For example, the following declaration used to be accepted.
```
inductive Foo
| bar (x)
k type
elabCtorType fun type => do
Term.synthesizeSyntheticMVarsNoPostponing
let ctorParams Term.addAutoBoundImplicits ctorParams
let except (mvarId : MVarId) := ctorParams.any fun ctorParam => ctorParam.isMVar && ctorParam.mvarId! == mvarId
/-
We convert metavariables in the resulting type info extra parameters. Otherwise, we would not be able to elaborate
declarations such as
```
inductive Palindrome : List α → Prop where
| nil : Palindrome [] -- We would get an error here saying "failed to synthesize implicit argument" at `@List.nil ?m`
| single : (a : α) → Palindrome [a]
| sandwich : (a : α) → Palindrome as → Palindrome ([a] ++ as ++ [a])
```
We used to also collect unassigned metavariables on `ctorParams`, but it produced counterintuitive behavior.
For example, the following declaration used to be accepted.
```
inductive Foo
| bar (x)
#check Foo.bar
-- @Foo.bar : {x : Sort u_1} → x → Foo
```
which is also inconsistent with the behavior of auto implicits in definitions. For example, the following example was never accepted.
```
def bar (x) := 1
```
-/
let extraCtorParams Term.collectUnassignedMVars ( instantiateMVars type) #[] except
trace[Elab.inductive] "extraCtorParams: {extraCtorParams}"
/- We must abstract `extraCtorParams` and `ctorParams` simultaneously to make
sure we do not create auxiliary metavariables. -/
let type mkForallFVars (extraCtorParams ++ ctorParams) type
let type reorderCtorArgs type
let type mkForallFVars params type
trace[Elab.inductive] "{ctorView.declName} : {type}"
return { name := ctorView.declName, type }
#check Foo.bar
-- @Foo.bar : {x : Sort u_1} → x → Foo
```
which is also inconsistent with the behavior of auto implicits in definitions. For example, the following example was never accepted.
```
def bar (x) := 1
```
-/
let extraCtorParams Term.collectUnassignedMVars ( instantiateMVars type) #[] except
trace[Elab.inductive] "extraCtorParams: {extraCtorParams}"
/- We must abstract `extraCtorParams` and `ctorParams` simultaneously to make
sure we do not create auxiliary metavariables. -/
let type mkForallFVars (extraCtorParams ++ ctorParams) type
let type reorderCtorArgs type
let type mkForallFVars params type
trace[Elab.inductive] "{ctorView.declName} : {type}"
return { name := ctorView.declName, type }
where
checkParamOccs (ctorType : Expr) : MetaM Expr :=
let visit (e : Expr) : MetaM TransformStep := do
@@ -247,15 +444,555 @@ where
return .continue
transform ctorType (pre := visit)
@[builtin_inductive_elab Lean.Parser.Command.inductive, builtin_inductive_elab Lean.Parser.Command.classInductive]
def elabInductiveCommand : InductiveElabDescr where
mkInductiveView (modifiers : Modifiers) (stx : Syntax) := do
let view inductiveSyntaxToView modifiers stx
return {
view
elabCtors := fun rs r params => do
let ctors elabCtors (rs.map (·.indFVar)) params r
return { ctors }
}
private def getResultingUniverse : List InductiveType TermElabM Level
| [] => throwError "unexpected empty inductive declaration"
| indType :: _ => forallTelescopeReducing indType.type fun _ r => do
let r whnfD r
match r with
| Expr.sort u => return u
| _ => throwError "unexpected inductive type resulting type{indentExpr r}"
/--
Return `some ?m` if `u` is of the form `?m + k`.
Return none if `u` does not contain universe metavariables.
Throw exception otherwise. -/
def shouldInferResultUniverse (u : Level) : TermElabM (Option LMVarId) := do
let u instantiateLevelMVars u
if u.hasMVar then
match u.getLevelOffset with
| Level.mvar mvarId => return some mvarId
| _ =>
throwError "cannot infer resulting universe level of inductive datatype, given level contains metavariables {mkSort u}, provide universe explicitly"
else
return none
/--
Convert universe metavariables into new parameters. It skips `univToInfer?` (the inductive datatype resulting universe) because
it should be inferred later using `inferResultingUniverse`.
-/
private def levelMVarToParam (indTypes : List InductiveType) (univToInfer? : Option LMVarId) : TermElabM (List InductiveType) :=
indTypes.mapM fun indType => do
let type levelMVarToParam' indType.type
let ctors indType.ctors.mapM fun ctor => do
let ctorType levelMVarToParam' ctor.type
return { ctor with type := ctorType }
return { indType with ctors, type }
where
levelMVarToParam' (type : Expr) : TermElabM Expr := do
Term.levelMVarToParam type (except := fun mvarId => univToInfer? == some mvarId)
def mkResultUniverse (us : Array Level) (rOffset : Nat) (preferProp : Bool) : Level :=
if us.isEmpty && rOffset == 0 then
if preferProp then levelZero else levelOne
else
let r := Level.mkNaryMax us.toList
if rOffset == 0 && !r.isZero && !r.isNeverZero then
mkLevelMax r levelOne |>.normalize
else
r.normalize
/--
Auxiliary function for `updateResultingUniverse`
`accLevel u r rOffset` add `u` to state if it is not already there and
it is different from the resulting universe level `r+rOffset`.
If `u` is a `max`, then its components are recursively processed.
If `u` is a `succ` and `rOffset > 0`, we process the `u`s child using `rOffset-1`.
This method is used to infer the resulting universe level of an inductive datatype.
-/
def accLevel (u : Level) (r : Level) (rOffset : Nat) : OptionT (StateT (Array Level) Id) Unit := do
go u rOffset
where
go (u : Level) (rOffset : Nat) : OptionT (StateT (Array Level) Id) Unit := do
match u, rOffset with
| .max u v, rOffset => go u rOffset; go v rOffset
| .imax u v, rOffset => go u rOffset; go v rOffset
| .zero, _ => return ()
| .succ u, rOffset+1 => go u rOffset
| u, rOffset =>
if rOffset == 0 && u == r then
return ()
else if r.occurs u then
failure
else if rOffset > 0 then
failure
else if ( get).contains u then
return ()
else
modify fun us => us.push u
/--
Auxiliary function for `updateResultingUniverse`
`accLevelAtCtor ctor ctorParam r rOffset` add `u` (`ctorParam`'s universe) to state if it is not already there and
it is different from the resulting universe level `r+rOffset`.
See `accLevel`.
-/
def accLevelAtCtor (ctor : Constructor) (ctorParam : Expr) (r : Level) (rOffset : Nat) : StateRefT (Array Level) TermElabM Unit := do
let type inferType ctorParam
let u instantiateLevelMVars ( getLevel type)
match ( modifyGet fun s => accLevel u r rOffset |>.run |>.run s) with
| some _ => pure ()
| none =>
let typeType inferType type
let mut msg := m!"failed to compute resulting universe level of inductive datatype, constructor '{ctor.name}' has type{indentExpr ctor.type}\nparameter"
let localDecl getFVarLocalDecl ctorParam
unless localDecl.userName.hasMacroScopes do
msg := msg ++ m!" '{ctorParam}'"
msg := msg ++ m!" has type{indentD m!"{type} : {typeType}"}\ninductive type resulting type{indentExpr (mkSort (r.addOffset rOffset))}"
if r.isMVar then
msg := msg ++ "\nrecall that Lean only infers the resulting universe level automatically when there is a unique solution for the universe level constraints, consider explicitly providing the inductive type resulting universe level"
throwError msg
/--
Execute `k` using the `Syntax` reference associated with constructor `ctorName`.
-/
def withCtorRef [Monad m] [MonadRef m] (views : Array InductiveView) (ctorName : Name) (k : m α) : m α := do
for view in views do
for ctorView in view.ctors do
if ctorView.declName == ctorName then
return ( withRef ctorView.ref k)
k
/-- Auxiliary function for `updateResultingUniverse` -/
private partial def collectUniverses (views : Array InductiveView) (r : Level) (rOffset : Nat) (numParams : Nat) (indTypes : List InductiveType) : TermElabM (Array Level) := do
let (_, us) go |>.run #[]
return us
where
go : StateRefT (Array Level) TermElabM Unit :=
indTypes.forM fun indType => indType.ctors.forM fun ctor =>
withCtorRef views ctor.name do
forallTelescopeReducing ctor.type fun ctorParams _ =>
for ctorParam in ctorParams[numParams:] do
accLevelAtCtor ctor ctorParam r rOffset
/--
Decides whether the inductive type should be `Prop`-valued when the universe is not given
and when the universe inference algorithm `collectUniverses` determines
that the inductive type could naturally be `Prop`-valued.
Recall: the natural universe level is the mimimum universe level for all the types of all the constructor parameters.
Heuristic:
- We want `Prop` when each inductive type is a syntactic subsingleton.
That's to say, when each inductive type has at most one constructor.
Such types carry no data anyway.
- Exception: if no inductive type has any constructors, these are likely stubbed-out declarations,
so we prefer `Type` instead.
- Exception: if each constructor has no parameters, then these are likely partially-written enumerations,
so we prefer `Type` instead.
-/
private def isPropCandidate (numParams : Nat) (indTypes : List InductiveType) : MetaM Bool := do
unless indTypes.foldl (fun n indType => max n indType.ctors.length) 0 == 1 do
return false
for indType in indTypes do
for ctor in indType.ctors do
let cparams forallTelescopeReducing ctor.type fun ctorParams _ => pure (ctorParams.size - numParams)
unless cparams == 0 do
return true
return false
private def updateResultingUniverse (views : Array InductiveView) (numParams : Nat) (indTypes : List InductiveType) : TermElabM (List InductiveType) := do
let r getResultingUniverse indTypes
let rOffset : Nat := r.getOffset
let r : Level := r.getLevelOffset
unless r.isMVar do
throwError "failed to compute resulting universe level of inductive datatype, provide universe explicitly: {r}"
let us collectUniverses views r rOffset numParams indTypes
trace[Elab.inductive] "updateResultingUniverse us: {us}, r: {r}, rOffset: {rOffset}"
let rNew := mkResultUniverse us rOffset ( isPropCandidate numParams indTypes)
assignLevelMVar r.mvarId! rNew
indTypes.mapM fun indType => do
let type instantiateMVars indType.type
let ctors indType.ctors.mapM fun ctor => return { ctor with type := ( instantiateMVars ctor.type) }
return { indType with type, ctors }
register_builtin_option bootstrap.inductiveCheckResultingUniverse : Bool := {
defValue := true,
group := "bootstrap",
descr := "by default the `inductive`/`structure` commands report an error if the resulting universe is not zero, but may be zero for some universe parameters. Reason: unless this type is a subsingleton, it is hardly what the user wants since it can only eliminate into `Prop`. In the `Init` package, we define subsingletons, and we use this option to disable the check. This option may be deleted in the future after we improve the validator"
}
def checkResultingUniverse (u : Level) : TermElabM Unit := do
if bootstrap.inductiveCheckResultingUniverse.get ( getOptions) then
let u instantiateLevelMVars u
if !u.isZero && !u.isNeverZero then
throwError "invalid universe polymorphic type, the resultant universe is not Prop (i.e., 0), but it may be Prop for some parameter values (solution: use 'u+1' or 'max 1 u'){indentD u}"
private def checkResultingUniverses (views : Array InductiveView) (numParams : Nat) (indTypes : List InductiveType) : TermElabM Unit := do
let u := ( instantiateLevelMVars ( getResultingUniverse indTypes)).normalize
checkResultingUniverse u
unless u.isZero do
indTypes.forM fun indType => indType.ctors.forM fun ctor =>
forallTelescopeReducing ctor.type fun ctorArgs _ => do
for ctorArg in ctorArgs[numParams:] do
let type inferType ctorArg
let v := ( instantiateLevelMVars ( getLevel type)).normalize
let rec check (v' : Level) (u' : Level) : TermElabM Unit :=
match v', u' with
| .succ v', .succ u' => check v' u'
| .mvar id, .param .. =>
/- Special case:
The constructor parameter `v` is at unverse level `?v+k` and
the resulting inductive universe level is `u'+k`, where `u'` is a parameter (or zero).
Thus, `?v := u'` is the only choice for satisfying the universe constraint `?v+k <= u'+k`.
Note that, we still generate an error for cases where there is more than one of satisfying the constraint.
Examples:
-----------------------------------------------------------
| ctor universe level | inductive datatype universe level |
-----------------------------------------------------------
| ?v | max u w |
-----------------------------------------------------------
| ?v | u + 1 |
-----------------------------------------------------------
-/
assignLevelMVar id u'
| .mvar id, .zero => assignLevelMVar id u' -- TODO: merge with previous case
| _, _ =>
unless u.geq v do
let mut msg := m!"invalid universe level in constructor '{ctor.name}', parameter"
let localDecl getFVarLocalDecl ctorArg
unless localDecl.userName.hasMacroScopes do
msg := msg ++ m!" '{ctorArg}'"
msg := msg ++ m!" has type{indentExpr type}"
msg := msg ++ m!"\nat universe level{indentD v}"
msg := msg ++ m!"\nit must be smaller than or equal to the inductive datatype universe level{indentD u}"
withCtorRef views ctor.name <| throwError msg
check v u
private def collectUsed (indTypes : List InductiveType) : StateRefT CollectFVars.State MetaM Unit := do
indTypes.forM fun indType => do
indType.type.collectFVars
indType.ctors.forM fun ctor =>
ctor.type.collectFVars
private def removeUnused (vars : Array Expr) (indTypes : List InductiveType) : TermElabM (LocalContext × LocalInstances × Array Expr) := do
let (_, used) (collectUsed indTypes).run {}
Meta.removeUnused vars used
private def withUsed {α} (vars : Array Expr) (indTypes : List InductiveType) (k : Array Expr TermElabM α) : TermElabM α := do
let (lctx, localInsts, vars) removeUnused vars indTypes
withLCtx lctx localInsts <| k vars
private def updateParams (vars : Array Expr) (indTypes : List InductiveType) : TermElabM (List InductiveType) :=
indTypes.mapM fun indType => do
let type mkForallFVars vars indType.type
let ctors indType.ctors.mapM fun ctor => do
let ctorType withExplicitToImplicit vars (mkForallFVars vars ctor.type)
return { ctor with type := ctorType }
return { indType with type, ctors }
private def collectLevelParamsInInductive (indTypes : List InductiveType) : Array Name := Id.run do
let mut usedParams : CollectLevelParams.State := {}
for indType in indTypes do
usedParams := collectLevelParams usedParams indType.type
for ctor in indType.ctors do
usedParams := collectLevelParams usedParams ctor.type
return usedParams.params
private def mkIndFVar2Const (views : Array InductiveView) (indFVars : Array Expr) (levelNames : List Name) : ExprMap Expr := Id.run do
let levelParams := levelNames.map mkLevelParam;
let mut m : ExprMap Expr := {}
for h : i in [:views.size] do
let view := views[i]
let indFVar := indFVars[i]!
m := m.insert indFVar (mkConst view.declName levelParams)
return m
/-- Remark: `numVars <= numParams`. `numVars` is the number of context `variables` used in the inductive declaration,
and `numParams` is `numVars` + number of explicit parameters provided in the declaration. -/
private def replaceIndFVarsWithConsts (views : Array InductiveView) (indFVars : Array Expr) (levelNames : List Name)
(numVars : Nat) (numParams : Nat) (indTypes : List InductiveType) : TermElabM (List InductiveType) :=
let indFVar2Const := mkIndFVar2Const views indFVars levelNames
indTypes.mapM fun indType => do
let ctors indType.ctors.mapM fun ctor => do
let type forallBoundedTelescope ctor.type numParams fun params type => do
let type := type.replace fun e =>
if !e.isFVar then
none
else match indFVar2Const[e]? with
| none => none
| some c => mkAppN c (params.extract 0 numVars)
instantiateMVars ( mkForallFVars params type)
return { ctor with type }
return { indType with ctors }
private def mkAuxConstructions (views : Array InductiveView) : TermElabM Unit := do
let env getEnv
let hasEq := env.contains ``Eq
let hasHEq := env.contains ``HEq
let hasUnit := env.contains ``PUnit
let hasProd := env.contains ``Prod
for view in views do
let n := view.declName
mkRecOn n
if hasUnit then mkCasesOn n
if hasUnit && hasEq && hasHEq then mkNoConfusion n
if hasUnit && hasProd then mkBelow n
if hasUnit && hasProd then mkIBelow n
for view in views do
let n := view.declName;
if hasUnit && hasProd then mkBRecOn n
if hasUnit && hasProd then mkBInductionOn n
private def getArity (indType : InductiveType) : MetaM Nat :=
forallTelescopeReducing indType.type fun xs _ => return xs.size
private def resetMaskAt (mask : Array Bool) (i : Nat) : Array Bool :=
mask.setD i false
/--
Compute a bit-mask that for `indType`. The size of the resulting array `result` is the arity of `indType`.
The first `numParams` elements are `false` since they are parameters.
For `i ∈ [numParams, arity)`, we have that `result[i]` if this index of the inductive family is fixed.
-/
private def computeFixedIndexBitMask (numParams : Nat) (indType : InductiveType) (indFVars : Array Expr) : MetaM (Array Bool) := do
let arity getArity indType
if arity numParams then
return mkArray arity false
else
let maskRef IO.mkRef (mkArray numParams false ++ mkArray (arity - numParams) true)
let rec go (ctors : List Constructor) : MetaM (Array Bool) := do
match ctors with
| [] => maskRef.get
| ctor :: ctors =>
forallTelescopeReducing ctor.type fun xs type => do
let typeArgs := type.getAppArgs
for i in [numParams:arity] do
unless i < xs.size && xs[i]! == typeArgs[i]! do -- Remark: if we want to allow arguments to be rearranged, this test should be xs.contains typeArgs[i]
maskRef.modify fun mask => mask.set! i false
for x in xs[numParams:] do
let xType inferType x
let cond (e : Expr) := indFVars.any (fun indFVar => e.getAppFn == indFVar)
xType.forEachWhere (stopWhenVisited := true) cond fun e => do
let eArgs := e.getAppArgs
for i in [numParams:eArgs.size] do
if i >= typeArgs.size then
maskRef.modify (resetMaskAt · i)
else
unless eArgs[i]! == typeArgs[i]! do
maskRef.modify (resetMaskAt · i)
/-If an index is missing in the arguments of the inductive type, then it must be non-fixed.
Consider the following example:
```lean
inductive All {I : Type u} (P : I → Type v) : List I → Type (max u v) where
| cons : P x → All P xs → All P (x :: xs)
inductive Iμ {I : Type u} : I → Type (max u v) where
| mk : (i : I) → All Iμ [] → Iμ i
```
because `i` doesn't appear in `All Iμ []`, the index shouldn't be fixed.
-/
for i in [eArgs.size:arity] do
maskRef.modify (resetMaskAt · i)
go ctors
go indType.ctors
/-- Return true iff `arrowType` is an arrow and its domain is defeq to `type` -/
private def isDomainDefEq (arrowType : Expr) (type : Expr) : MetaM Bool := do
if !arrowType.isForall then
return false
else
/-
We used to use `withNewMCtxDepth` to make sure we do not assign universe metavariables,
but it was not satisfactory. For example, in declarations such as
```
inductive Eq : αα → Prop where
| refl (a : α) : Eq a a
```
We want the first two indices to be promoted to parameters, and this will only
happen if we can assign universe metavariables.
-/
isDefEq arrowType.bindingDomain! type
/--
Convert fixed indices to parameters.
-/
private partial def fixedIndicesToParams (numParams : Nat) (indTypes : Array InductiveType) (indFVars : Array Expr) : MetaM Nat := do
if !inductive.autoPromoteIndices.get ( getOptions) then
return numParams
let masks indTypes.mapM (computeFixedIndexBitMask numParams · indFVars)
trace[Elab.inductive] "masks: {masks}"
if masks.all fun mask => !mask.contains true then
return numParams
-- We process just a non-fixed prefix of the indices for now. Reason: we don't want to change the order.
-- TODO: extend it in the future. For example, it should be reasonable to change
-- the order of indices generated by the auto implicit feature.
let mask := masks[0]!
forallBoundedTelescope indTypes[0]!.type numParams fun params type => do
let otherTypes indTypes[1:].toArray.mapM fun indType => do whnfD ( instantiateForall indType.type params)
let ctorTypes indTypes.toList.mapM fun indType => indType.ctors.mapM fun ctor => do whnfD ( instantiateForall ctor.type params)
let typesToCheck := otherTypes.toList ++ ctorTypes.flatten
let rec go (i : Nat) (type : Expr) (typesToCheck : List Expr) : MetaM Nat := do
if i < mask.size then
if !masks.all fun mask => i < mask.size && mask[i]! then
return i
if !type.isForall then
return i
let paramType := type.bindingDomain!
if !( typesToCheck.allM fun type => isDomainDefEq type paramType) then
trace[Elab.inductive] "domain not def eq: {i}, {type} =?= {paramType}"
return i
withLocalDeclD `a paramType fun paramNew => do
let typesToCheck typesToCheck.mapM fun type => whnfD (type.bindingBody!.instantiate1 paramNew)
go (i+1) (type.bindingBody!.instantiate1 paramNew) typesToCheck
else
return i
go numParams type typesToCheck
private def mkInductiveDecl (vars : Array Expr) (views : Array InductiveView) : TermElabM Unit := Term.withoutSavingRecAppSyntax do
let view0 := views[0]!
let scopeLevelNames Term.getLevelNames
InductiveView.checkLevelNames views
let allUserLevelNames := view0.levelNames
let isUnsafe := view0.modifiers.isUnsafe
withRef view0.ref <| Term.withLevelNames allUserLevelNames do
let rs elabHeader views
Term.synthesizeSyntheticMVarsNoPostponing
ElabHeaderResult.checkLevelNames rs
let allUserLevelNames := rs[0]!.levelNames
trace[Elab.inductive] "level names: {allUserLevelNames}"
withInductiveLocalDecls rs fun params indFVars => do
trace[Elab.inductive] "indFVars: {indFVars}"
let mut indTypesArray := #[]
for h : i in [:views.size] do
let indFVar := indFVars[i]!
Term.addLocalVarInfo views[i].declId indFVar
let r := rs[i]!
/- At this point, because of `withInductiveLocalDecls`, the only fvars that are in context are the ones related to the first inductive type.
Because of this, we need to replace the fvars present in each inductive type's header of the mutual block with those of the first inductive.
However, some mvars may still be uninstantiated there, and might hide some of the old fvars.
As such we first need to synthesize all possible mvars at this stage, instantiate them in the header types and only
then replace the parameters' fvars in the header type.
See issue #3242 (`https://github.com/leanprover/lean4/issues/3242`)
-/
let type instantiateMVars r.type
let type := type.replaceFVars r.params params
let type mkForallFVars params type
let ctors withExplicitToImplicit params (elabCtors indFVars indFVar params r)
indTypesArray := indTypesArray.push { name := r.view.declName, type, ctors }
let numExplicitParams fixedIndicesToParams params.size indTypesArray indFVars
trace[Elab.inductive] "numExplicitParams: {numExplicitParams}"
let indTypes := indTypesArray.toList
let u getResultingUniverse indTypes
let univToInfer? shouldInferResultUniverse u
withUsed vars indTypes fun vars => do
let numVars := vars.size
let numParams := numVars + numExplicitParams
let indTypes updateParams vars indTypes
let indTypes if let some univToInfer := univToInfer? then
updateResultingUniverse views numParams ( levelMVarToParam indTypes univToInfer)
else
checkResultingUniverses views numParams indTypes
levelMVarToParam indTypes none
let usedLevelNames := collectLevelParamsInInductive indTypes
match sortDeclLevelParams scopeLevelNames allUserLevelNames usedLevelNames with
| .error msg => throwError msg
| .ok levelParams => do
let indTypes replaceIndFVarsWithConsts views indFVars levelParams numVars numParams indTypes
let decl := Declaration.inductDecl levelParams numParams indTypes isUnsafe
Term.ensureNoUnassignedMVars decl
addDecl decl
mkAuxConstructions views
withSaveInfoContext do -- save new env
for view in views do
Term.addTermInfo' view.ref[1] ( mkConstWithLevelParams view.declName) (isBinder := true)
for ctor in view.ctors do
Term.addTermInfo' ctor.ref[3] ( mkConstWithLevelParams ctor.declName) (isBinder := true)
-- We need to invoke `applyAttributes` because `class` is implemented as an attribute.
Term.applyAttributesAt view.declName view.modifiers.attrs .afterTypeChecking
private def applyDerivingHandlers (views : Array InductiveView) : CommandElabM Unit := do
let mut processed : NameSet := {}
for view in views do
for classView in view.derivingClasses do
let className := classView.className
unless processed.contains className do
processed := processed.insert className
let mut declNames := #[]
for view in views do
if view.derivingClasses.any fun classView => classView.className == className then
declNames := declNames.push view.declName
classView.applyHandlers declNames
private def applyComputedFields (indViews : Array InductiveView) : CommandElabM Unit := do
if indViews.all (·.computedFields.isEmpty) then return
let mut computedFields := #[]
let mut computedFieldDefs := #[]
for indView@{declName, ..} in indViews do
for {ref, fieldId, type, matchAlts, modifiers, ..} in indView.computedFields do
computedFieldDefs := computedFieldDefs.push <| do
let modifiers match modifiers with
| `(Lean.Parser.Command.declModifiersT| $[$doc:docComment]? $[$attrs:attributes]? $[$vis]? $[noncomputable]?) =>
`(Lean.Parser.Command.declModifiersT| $[$doc]? $[$attrs]? $[$vis]? noncomputable)
| _ => do
withRef modifiers do logError "unsupported modifiers for computed field"
`(Parser.Command.declModifiersT| noncomputable)
`($(modifiers):declModifiers
def%$ref $(mkIdent <| `_root_ ++ declName ++ fieldId):ident : $type $matchAlts:matchAlts)
let computedFieldNames := indView.computedFields.map fun {fieldId, ..} => declName ++ fieldId
computedFields := computedFields.push (declName, computedFieldNames)
withScope (fun scope => { scope with
opts := scope.opts
|>.setBool `bootstrap.genMatcherCode false
|>.setBool `elaboratingComputedFields true}) <|
elabCommand <| `(mutual $computedFieldDefs* end)
liftTermElabM do Term.withDeclName indViews[0]!.declName do
ComputedFields.setComputedFields computedFields
def elabInductiveViews (vars : Array Expr) (views : Array InductiveView) : TermElabM Unit := do
let view0 := views[0]!
let ref := view0.ref
Term.withDeclName view0.declName do withRef ref do
mkInductiveDecl vars views
mkSizeOfInstances view0.declName
Lean.Meta.IndPredBelow.mkBelow view0.declName
for view in views do
mkInjectiveTheorems view.declName
def elabInductiveViewsPostprocessing (views : Array InductiveView) : CommandElabM Unit := do
let view0 := views[0]!
let ref := view0.ref
applyComputedFields views -- NOTE: any generated code before this line is invalid
applyDerivingHandlers views
runTermElabM fun _ => Term.withDeclName view0.declName do withRef ref do
for view in views do
Term.applyAttributesAt view.declName view.modifiers.attrs .afterCompilation
def elabInductives (inductives : Array (Modifiers × Syntax)) : CommandElabM Unit := do
let vs runTermElabM fun vars => do
let vs inductives.mapM fun (modifiers, stx) => inductiveSyntaxToView modifiers stx
elabInductiveViews vars vs
pure vs
elabInductiveViewsPostprocessing vs
def elabInductive (modifiers : Modifiers) (stx : Syntax) : CommandElabM Unit := do
elabInductives #[(modifiers, stx)]
def elabClassInductive (modifiers : Modifiers) (stx : Syntax) : CommandElabM Unit := do
let modifiers := modifiers.addAttr { name := `class }
elabInductive modifiers stx
/--
Returns true if all elements of the `mutual` block (`Lean.Parser.Command.mutual`) are inductive declarations.
-/
def isMutualInductive (stx : Syntax) : Bool :=
stx[1].getArgs.all fun elem =>
let decl := elem[1]
let declKind := decl.getKind
declKind == `Lean.Parser.Command.inductive
/--
Elaborates a `mutual` block satisfying `Lean.Elab.Command.isMutualInductive`.
-/
def elabMutualInductive (elems : Array Syntax) : CommandElabM Unit := do
let inductives elems.mapM fun stx => do
let modifiers elabModifiers stx[0]
pure (modifiers, stx[1])
elabInductives inductives
end Lean.Elab.Command

View File

@@ -840,8 +840,8 @@ private def mkLetRecClosures (sectionVars : Array Expr) (mainFVarIds : Array FVa
abbrev Replacement := FVarIdMap Expr
def insertReplacementForMainFns (r : Replacement) (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainFVars : Array Expr) : Replacement :=
mainFVars.size.fold (init := r) fun i _ r =>
r.insert mainFVars[i].fvarId! (mkAppN (Lean.mkConst mainHeaders[i]!.declName) sectionVars)
mainFVars.size.fold (init := r) fun i r =>
r.insert mainFVars[i]!.fvarId! (mkAppN (Lean.mkConst mainHeaders[i]!.declName) sectionVars)
def insertReplacementForLetRecs (r : Replacement) (letRecClosures : List LetRecClosure) : Replacement :=
@@ -871,8 +871,8 @@ def Replacement.apply (r : Replacement) (e : Expr) : Expr :=
def pushMain (preDefs : Array PreDefinition) (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainVals : Array Expr)
: TermElabM (Array PreDefinition) :=
mainHeaders.size.foldM (init := preDefs) fun i _ preDefs => do
let header := mainHeaders[i]
mainHeaders.size.foldM (init := preDefs) fun i preDefs => do
let header := mainHeaders[i]!
let termination declValToTerminationHint header.value
let termination := termination.rememberExtraParams header.numParams mainVals[i]!
let value mkLambdaFVars sectionVars mainVals[i]!

File diff suppressed because it is too large Load Diff

View File

@@ -17,7 +17,7 @@ open Lean.Parser.Command
private partial def antiquote (vars : Array Syntax) : Syntax Syntax
| stx => match stx with
| `($id:ident) =>
if vars.any (fun var => var.getId == id.getId) then
if (vars.findIdx? (fun var => var.getId == id.getId)).isSome then
mkAntiquotNode id (kind := `term) (isPseudoKind := true)
else
stx

View File

@@ -145,8 +145,8 @@ private partial def replaceRecApps (recArgInfos : Array RecArgInfo) (positions :
| Expr.app _ _ =>
let processApp (e : Expr) : StateRefT (HasConstCache recFnNames) M Expr :=
e.withApp fun f args => do
if let .some fnIdx := recArgInfos.findFinIdx? (f.isConstOf ·.fnName) then
let recArgInfo := recArgInfos[fnIdx]
if let .some fnIdx := recArgInfos.findIdx? (f.isConstOf ·.fnName) then
let recArgInfo := recArgInfos[fnIdx]!
let some recArg := args[recArgInfo.recArgPos]?
| throwError "insufficient number of parameters at recursive application {indentExpr e}"
-- For reflexive type, we may have nested recursive applications in recArg

View File

@@ -302,58 +302,59 @@ instance : ToFormat FieldLHS where
| .fieldIndex _ i => format i
| .modifyOp _ i => "[" ++ i.prettyPrint ++ "]"
mutual
/--
`FieldVal StructInstView` is a representation of a field value in the structure instance.
-/
inductive FieldVal where
/-- A `term` to use for the value of the field. -/
| term (stx : Syntax) : FieldVal
/-- A `StructInstView` to use for the value of a subobject field. -/
| nested (s : StructInstView) : FieldVal
/-- A field that was not provided and should be synthesized using default values. -/
| default : FieldVal
deriving Inhabited
/--
`FieldVal StructInstView` is a representation of a field value in the structure instance.
-/
inductive FieldVal (σ : Type) where
/-- A `term` to use for the value of the field. -/
| term (stx : Syntax) : FieldVal σ
/-- A `StructInstView` to use for the value of a subobject field. -/
| nested (s : σ) : FieldVal σ
/-- A field that was not provided and should be synthesized using default values. -/
| default : FieldVal σ
deriving Inhabited
/--
`Field StructInstView` is a representation of a field in the structure instance.
-/
structure Field where
/-- The whole field syntax. -/
ref : Syntax
/-- The LHS decomposed into components. -/
lhs : List FieldLHS
/-- The value of the field. -/
val : FieldVal
/-- The elaborated field value, filled in at `elabStruct`.
Missing fields use a metavariable for the elaborated value and are later solved for in `DefaultFields.propagate`. -/
expr? : Option Expr := none
deriving Inhabited
/--
The view for structure instance notation.
-/
structure StructInstView where
/-- The syntax for the whole structure instance. -/
ref : Syntax
/-- The name of the structure for the type of the structure instance. -/
structName : Name
/-- Used for default values, to propagate structure type parameters. It is initially empty, and then set at `elabStruct`. -/
params : Array (Name × Expr)
/-- The fields of the structure instance. -/
fields : List Field
/-- The additional sources for fields for the structure instance. -/
sources : SourcesView
deriving Inhabited
end
/--
`Field StructInstView` is a representation of a field in the structure instance.
-/
structure Field (σ : Type) where
/-- The whole field syntax. -/
ref : Syntax
/-- The LHS decomposed into components. -/
lhs : List FieldLHS
/-- The value of the field. -/
val : FieldVal σ
/-- The elaborated field value, filled in at `elabStruct`.
Missing fields use a metavariable for the elaborated value and are later solved for in `DefaultFields.propagate`. -/
expr? : Option Expr := none
deriving Inhabited
/--
Returns if the field has a single component in its LHS.
-/
def Field.isSimple : Field Bool
def Field.isSimple {σ} : Field σ Bool
| { lhs := [_], .. } => true
| _ => false
/--
The view for structure instance notation.
-/
structure StructInstView where
/-- The syntax for the whole structure instance. -/
ref : Syntax
/-- The name of the structure for the type of the structure instance. -/
structName : Name
/-- Used for default values, to propagate structure type parameters. It is initially empty, and then set at `elabStruct`. -/
params : Array (Name × Expr)
/-- The fields of the structure instance. -/
fields : List (Field StructInstView)
/-- The additional sources for fields for the structure instance. -/
sources : SourcesView
deriving Inhabited
/-- Abbreviation for the type of `StructInstView.fields`, namely `List (Field StructInstView)`. -/
abbrev Fields := List (Field StructInstView)
/-- `true` iff all fields of the given structure are marked as `default` -/
partial def StructInstView.allDefault (s : StructInstView) : Bool :=
s.fields.all fun { val := val, .. } => match val with
@@ -361,7 +362,7 @@ partial def StructInstView.allDefault (s : StructInstView) : Bool :=
| .default => true
| .nested s => allDefault s
def formatField (formatStruct : StructInstView Format) (field : Field) : Format :=
def formatField (formatStruct : StructInstView Format) (field : Field StructInstView) : Format :=
Format.joinSep field.lhs " . " ++ " := " ++
match field.val with
| .term v => v.prettyPrint
@@ -377,11 +378,11 @@ partial def formatStruct : StructInstView → Format
else
"{" ++ format (source.explicit.map (·.stx)) ++ " with " ++ fieldsFmt ++ implicitFmt ++ "}"
instance : ToFormat StructInstView := formatStruct
instance : ToFormat StructInstView := formatStruct
instance : ToString StructInstView := toString format
instance : ToFormat Field := formatField formatStruct
instance : ToString Field := toString format
instance : ToFormat (Field StructInstView) := formatField formatStruct
instance : ToString (Field StructInstView) := toString format
/--
Converts a `FieldLHS` back into syntax. This assumes the `ref` fields have the correct structure.
@@ -402,14 +403,14 @@ private def FieldLHS.toSyntax (first : Bool) : FieldLHS → Syntax
/--
Converts a `FieldVal StructInstView` back into syntax. Only supports `.term`, and it assumes the `stx` field has the correct structure.
-/
private def FieldVal.toSyntax : FieldVal Syntax
private def FieldVal.toSyntax : FieldVal Struct Syntax
| .term stx => stx
| _ => unreachable!
/--
Converts a `Field StructInstView` back into syntax. Used to construct synthetic structure instance notation for subobjects in `StructInst.expandStruct` processing.
-/
private def Field.toSyntax : Field Syntax
private def Field.toSyntax : Field Struct Syntax
| field =>
let stx := field.ref
let stx := stx.setArg 2 field.val.toSyntax
@@ -451,14 +452,14 @@ private def mkStructView (stx : Syntax) (structName : Name) (sources : SourcesVi
let val := fieldStx[2]
let first toFieldLHS fieldStx[0][0]
let rest fieldStx[0][1].getArgs.toList.mapM toFieldLHS
return { ref := fieldStx, lhs := first :: rest, val := FieldVal.term val : Field }
return { ref := fieldStx, lhs := first :: rest, val := FieldVal.term val : Field StructInstView }
return { ref := stx, structName, params := #[], fields, sources }
def StructInstView.modifyFieldsM {m : Type Type} [Monad m] (s : StructInstView) (f : List Field m (List Field)) : m StructInstView :=
def StructInstView.modifyFieldsM {m : Type Type} [Monad m] (s : StructInstView) (f : Fields m Fields) : m StructInstView :=
match s with
| { ref, structName, params, fields, sources } => return { ref, structName, params, fields := ( f fields), sources }
def StructInstView.modifyFields (s : StructInstView) (f : List Field List Field) : StructInstView :=
def StructInstView.modifyFields (s : StructInstView) (f : Fields Fields) : StructInstView :=
Id.run <| s.modifyFieldsM f
/-- Expands name field LHSs with multi-component names into multi-component LHSs. -/
@@ -524,14 +525,14 @@ private def expandParentFields (s : StructInstView) : TermElabM StructInstView :
| _ => throwErrorAt ref "failed to access field '{fieldName}' in parent structure"
| _ => return field
private abbrev FieldMap := Std.HashMap Name (List Field)
private abbrev FieldMap := Std.HashMap Name Fields
/--
Creates a hash map collecting all fields with the same first name component.
Throws an error if there are multiple simple fields with the same name.
Used by `StructInst.expandStruct` processing.
-/
private def mkFieldMap (fields : List Field) : TermElabM FieldMap :=
private def mkFieldMap (fields : Fields) : TermElabM FieldMap :=
fields.foldlM (init := {}) fun fieldMap field =>
match field.lhs with
| .fieldName _ fieldName :: _ =>
@@ -547,7 +548,7 @@ private def mkFieldMap (fields : List Field) : TermElabM FieldMap :=
/--
Given a value of the hash map created by `mkFieldMap`, returns true if the value corresponds to a simple field.
-/
private def isSimpleField? : List Field Option Field
private def isSimpleField? : Fields Option (Field StructInstView)
| [field] => if field.isSimple then some field else none
| _ => none
@@ -565,7 +566,7 @@ def mkProjStx? (s : Syntax) (structName : Name) (fieldName : Name) : TermElabM (
/--
Finds a simple field of the given name.
-/
def findField? (fields : List Field) (fieldName : Name) : Option Field :=
def findField? (fields : Fields) (fieldName : Name) : Option (Field StructInstView) :=
fields.find? fun field =>
match field.lhs with
| [.fieldName _ n] => n == fieldName
@@ -619,7 +620,7 @@ mutual
match findField? s.fields fieldName with
| some field => return field::fields
| none =>
let addField (val : FieldVal) : TermElabM (List Field) := do
let addField (val : FieldVal StructInstView) : TermElabM Fields := do
return { ref, lhs := [FieldLHS.fieldName ref fieldName], val := val } :: fields
match Lean.isSubobjectField? env s.structName fieldName with
| some substructName =>
@@ -772,7 +773,7 @@ private partial def elabStructInstView (s : StructInstView) (expectedType? : Opt
trace[Elab.struct] "elabStruct {field}, {type}"
match type with
| .forallE _ d b bi =>
let cont (val : Expr) (field : Field) (instMVars := instMVars) : TermElabM (Expr × Expr × List Field × Array MVarId) := do
let cont (val : Expr) (field : Field StructInstView) (instMVars := instMVars) : TermElabM (Expr × Expr × Fields × Array MVarId) := do
pushInfoTree <| InfoTree.node (children := {}) <| Info.ofFieldInfo {
projName := s.structName.append fieldName, fieldName, lctx := ( getLCtx), val, stx := ref }
let e := mkApp e val
@@ -878,7 +879,7 @@ partial def getHierarchyDepth (struct : StructInstView) : Nat :=
| _ => max
/-- Returns whether the field is still missing. -/
def isDefaultMissing? [Monad m] [MonadMCtx m] (field : Field) : m Bool := do
def isDefaultMissing? [Monad m] [MonadMCtx m] (field : Field Struct) : m Bool := do
if let some expr := field.expr? then
if let some (.mvar mvarId) := defaultMissing? expr then
unless ( mvarId.isAssigned) do
@@ -886,17 +887,17 @@ def isDefaultMissing? [Monad m] [MonadMCtx m] (field : Field) : m Bool := do
return false
/-- Returns a field that is still missing. -/
partial def findDefaultMissing? [Monad m] [MonadMCtx m] (struct : StructInstView) : m (Option Field) :=
partial def findDefaultMissing? [Monad m] [MonadMCtx m] (struct : StructInstView) : m (Option (Field StructInstView)) :=
struct.fields.findSomeM? fun field => do
match field.val with
| .nested struct => findDefaultMissing? struct
| _ => return if ( isDefaultMissing? field) then field else none
/-- Returns all fields that are still missing. -/
partial def allDefaultMissing [Monad m] [MonadMCtx m] (struct : StructInstView) : m (Array Field) :=
partial def allDefaultMissing [Monad m] [MonadMCtx m] (struct : StructInstView) : m (Array (Field StructInstView)) :=
go struct *> get |>.run' #[]
where
go (struct : StructInstView) : StateT (Array Field) m Unit :=
go (struct : StructInstView) : StateT (Array (Field StructInstView)) m Unit :=
for field in struct.fields do
if let .nested struct := field.val then
go struct
@@ -904,7 +905,7 @@ where
modify (·.push field)
/-- Returns the name of the field. Assumes all fields under consideration are simple and named. -/
def getFieldName (field : Field) : Name :=
def getFieldName (field : Field StructInstView) : Name :=
match field.lhs with
| [.fieldName _ fieldName] => fieldName
| _ => unreachable!

View File

@@ -4,15 +4,22 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Class
import Lean.Parser.Command
import Lean.Meta.Closure
import Lean.Meta.SizeOf
import Lean.Meta.Injective
import Lean.Meta.Structure
import Lean.Elab.MutualInductive
import Lean.Meta.AppBuilder
import Lean.Elab.Command
import Lean.Elab.DeclModifiers
import Lean.Elab.DeclUtil
import Lean.Elab.Inductive
import Lean.Elab.DeclarationRange
import Lean.Elab.Binders
namespace Lean.Elab.Command
builtin_initialize
registerTraceClass `Elab.structure
registerTraceClass `Elab.structure.resolutionOrder
register_builtin_option structureDiamondWarning : Bool := {
defValue := false
descr := "if true, enable warnings when a structure has diamond inheritance"
@@ -32,6 +39,13 @@ leading_parser (structureTk <|> classTk) >> declId >> many Term.bracketedBinder
```
-/
structure StructCtorView where
ref : Syntax
modifiers : Modifiers
name : Name
declName : Name
deriving Inhabited
structure StructFieldView where
ref : Syntax
modifiers : Modifiers
@@ -47,15 +61,22 @@ structure StructFieldView where
type? : Option Syntax
value? : Option Syntax
structure StructView extends InductiveView where
parents : Array Syntax
fields : Array StructFieldView
structure StructView where
ref : Syntax
declId : Syntax
modifiers : Modifiers
isClass : Bool -- struct-only
shortDeclName : Name
declName : Name
levelNames : List Name
binders : Syntax
type : Syntax -- modified (inductive has type?)
parents : Array Syntax -- struct-only
ctor : StructCtorView -- struct-only
fields : Array StructFieldView -- struct-only
derivingClasses : Array DerivingClassView
deriving Inhabited
def StructView.ctor : StructView CtorView
| { ctors := #[ctor], ..} => ctor
| _ => unreachable!
structure StructParentInfo where
ref : Syntax
fvar? : Option Expr
@@ -81,6 +102,18 @@ structure StructFieldInfo where
value? : Option Expr := none
deriving Inhabited, Repr
structure ElabStructHeaderResult where
view : StructView
lctx : LocalContext
localInsts : LocalInstances
levelNames : List Name
params : Array Expr
type : Expr
parents : Array StructParentInfo
/-- Field infos from parents. -/
parentFieldInfos : Array StructFieldInfo
deriving Inhabited
def StructFieldInfo.isFromParent (info : StructFieldInfo) : Bool :=
match info.kind with
| StructFieldKind.fromParent => true
@@ -97,12 +130,12 @@ The structure constructor syntax is
leading_parser try (declModifiers >> ident >> " :: ")
```
-/
private def expandCtor (structStx : Syntax) (structModifiers : Modifiers) (structDeclName : Name) : TermElabM CtorView := do
private def expandCtor (structStx : Syntax) (structModifiers : Modifiers) (structDeclName : Name) : TermElabM StructCtorView := do
let useDefault := do
let declName := structDeclName ++ defaultCtorName
let ref := structStx[1].mkSynthetic
addDeclarationRangesFromSyntax declName ref
pure { ref, declId := ref, modifiers := default, declName }
pure { ref, modifiers := default, name := defaultCtorName, declName }
if structStx[5].isNone then
useDefault
else
@@ -123,7 +156,7 @@ private def expandCtor (structStx : Syntax) (structModifiers : Modifiers) (struc
let declName applyVisibility ctorModifiers.visibility declName
addDocString' declName ctorModifiers.docString?
addDeclarationRangesFromSyntax declName ctor[1]
pure { ref := ctor[1], declId := ctor[1], modifiers := ctorModifiers, declName }
pure { ref := ctor[1], name, modifiers := ctorModifiers, declName }
def checkValidFieldModifier (modifiers : Modifiers) : TermElabM Unit := do
if modifiers.isNoncomputable then
@@ -238,7 +271,7 @@ def structureSyntaxToView (modifiers : Modifiers) (stx : Syntax) : TermElabM Str
let parents := if exts.isNone then #[] else exts[0][1].getSepArgs
let optType := stx[4]
let derivingClasses getOptDerivingClasses stx[6]
let type? := if optType.isNone then none else some optType[0][1]
let type if optType.isNone then `(Sort _) else pure optType[0][1]
let ctor expandCtor stx modifiers declName
let fields expandFields stx modifiers declName
fields.forM fun field => do
@@ -254,13 +287,10 @@ def structureSyntaxToView (modifiers : Modifiers) (stx : Syntax) : TermElabM Str
declName
levelNames
binders
type?
allowIndices := false
allowSortPolymorphism := false
ctors := #[ctor]
type
parents
ctor
fields
computedFields := #[]
derivingClasses
}
@@ -285,7 +315,7 @@ private def findExistingField? (infos : Array StructFieldInfo) (parentStructName
return some fieldName
return none
private def processSubfields (structDeclName : Name) (parentFVar : Expr) (parentStructName : Name) (subfieldNames : Array Name)
private partial def processSubfields (structDeclName : Name) (parentFVar : Expr) (parentStructName : Name) (subfieldNames : Array Name)
(infos : Array StructFieldInfo) (k : Array StructFieldInfo TermElabM α) : TermElabM α :=
go 0 infos
where
@@ -506,7 +536,7 @@ private partial def mkToParentName (parentStructName : Name) (p : Name → Bool)
if p curr then curr else go (i+1)
go 1
private def withParents (view : StructView) (rs : Array ElabHeaderResult) (indFVar : Expr)
private partial def elabParents (view : StructView)
(k : Array StructFieldInfo Array StructParentInfo TermElabM α) : TermElabM α := do
go 0 #[] #[]
where
@@ -514,17 +544,11 @@ where
if h : i < view.parents.size then
let parent := view.parents[i]
withRef parent do
-- The only use case for autobound implicits for parents might be outParams, but outParam is not propagated.
let type Term.withoutAutoBoundImplicit <| Term.elabType parent
let type Term.elabType parent
let parentType whnf type
if parentType.getAppFn == indFVar then
logWarning "structure extends itself, skipping"
return go (i + 1) infos parents
if rs.any (fun r => r.indFVar == parentType.getAppFn) then
throwError "structure cannot extend types defined in the same mutual block"
let parentStructName getStructureName parentType
if parents.any (fun info => info.structName == parentStructName) then
logWarning m!"duplicate parent structure '{.ofConstName parentStructName}', skipping"
logWarningAt parent m!"duplicate parent structure '{.ofConstName parentStructName}', skipping"
go (i + 1) infos parents
else if let some existingFieldName findExistingField? infos parentStructName then
if structureDiamondWarning.get ( getOptions) then
@@ -546,13 +570,6 @@ where
else
k infos parents
private def registerFailedToInferFieldType (fieldName : Name) (e : Expr) (ref : Syntax) : TermElabM Unit := do
Term.registerCustomErrorIfMVar ( instantiateMVars e) ref m!"failed to infer type of field '{.ofConstName fieldName}'"
private def registerFailedToInferDefaultValue (fieldName : Name) (e : Expr) (ref : Syntax) : TermElabM Unit := do
Term.registerCustomErrorIfMVar ( instantiateMVars e) ref m!"failed to infer default value for field '{.ofConstName fieldName}'"
Term.registerLevelMVarErrorExprInfo e ref m!"failed to infer universe levels in default value for field '{.ofConstName fieldName}'"
private def elabFieldTypeValue (view : StructFieldView) : TermElabM (Option Expr × Option Expr) :=
Term.withAutoBoundImplicit <| Term.withAutoBoundImplicitForbiddenPred (fun n => view.name == n) <| Term.elabBinders view.binders.getArgs fun params => do
match view.type? with
@@ -564,13 +581,10 @@ private def elabFieldTypeValue (view : StructFieldView) : TermElabM (Option Expr
-- TODO: add forbidden predicate using `shortDeclName` from `view`
let params Term.addAutoBoundImplicits params
let value Term.withoutAutoBoundImplicit <| Term.elabTerm valStx none
registerFailedToInferFieldType view.name ( inferType value) view.nameId
registerFailedToInferDefaultValue view.name value valStx
let value mkLambdaFVars params value
return (none, value)
| some typeStx =>
let type Term.elabType typeStx
registerFailedToInferFieldType view.name type typeStx
Term.synthesizeSyntheticMVarsNoPostponing
let params Term.addAutoBoundImplicits params
match view.value? with
@@ -579,7 +593,6 @@ private def elabFieldTypeValue (view : StructFieldView) : TermElabM (Option Expr
return (type, none)
| some valStx =>
let value Term.withoutAutoBoundImplicit <| Term.elabTermEnsuringType valStx type
registerFailedToInferDefaultValue view.name value valStx
Term.synthesizeSyntheticMVarsNoPostponing
let type mkForallFVars params type
let value mkLambdaFVars params value
@@ -626,7 +639,6 @@ where
valStx `(fun $(view.binders.getArgs)* => $valStx:term)
let fvarType inferType info.fvar
let value Term.elabTermEnsuringType valStx fvarType
registerFailedToInferDefaultValue view.name value valStx
pushInfoLeaf <| .ofFieldRedeclInfo { stx := view.ref }
let infos := replaceFieldInfo infos { info with ref := view.nameId, value? := value }
go (i+1) defaultValsOverridden infos
@@ -638,14 +650,113 @@ where
else
k infos
private def collectUsedFVars (lctx : LocalContext) (localInsts : LocalInstances) (fieldInfos : Array StructFieldInfo) :
StateRefT CollectFVars.State MetaM Unit := do
withLCtx lctx localInsts do
fieldInfos.forM fun info => do
let fvarType inferType info.fvar
fvarType.collectFVars
if let some value := info.value? then
value.collectFVars
private def getResultUniverse (type : Expr) : TermElabM Level := do
let type whnf type
match type with
| Expr.sort u => pure u
| _ => throwError "unexpected structure resulting type"
private def collectUsed (params : Array Expr) (fieldInfos : Array StructFieldInfo) : StateRefT CollectFVars.State MetaM Unit := do
params.forM fun p => do
let type inferType p
type.collectFVars
fieldInfos.forM fun info => do
let fvarType inferType info.fvar
fvarType.collectFVars
match info.value? with
| none => pure ()
| some value => value.collectFVars
private def removeUnused (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo)
: TermElabM (LocalContext × LocalInstances × Array Expr) := do
let (_, used) (collectUsed params fieldInfos).run {}
Meta.removeUnused scopeVars used
private def withUsed {α} (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo) (k : Array Expr TermElabM α)
: TermElabM α := do
let (lctx, localInsts, vars) removeUnused scopeVars params fieldInfos
withLCtx lctx localInsts <| k vars
private def levelMVarToParam (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo) (univToInfer? : Option LMVarId) : TermElabM (Array StructFieldInfo) := do
levelMVarToParamFVars scopeVars
levelMVarToParamFVars params
fieldInfos.mapM fun info => do
levelMVarToParamFVar info.fvar
match info.value? with
| none => pure info
| some value =>
let value levelMVarToParam' value
pure { info with value? := value }
where
levelMVarToParam' (type : Expr) : TermElabM Expr := do
Term.levelMVarToParam type (except := fun mvarId => univToInfer? == some mvarId)
levelMVarToParamFVars (fvars : Array Expr) : TermElabM Unit :=
fvars.forM levelMVarToParamFVar
levelMVarToParamFVar (fvar : Expr) : TermElabM Unit := do
let type inferType fvar
discard <| levelMVarToParam' type
private partial def collectUniversesFromFields (r : Level) (rOffset : Nat) (fieldInfos : Array StructFieldInfo) : TermElabM (Array Level) := do
let (_, us) go |>.run #[]
return us
where
go : StateRefT (Array Level) TermElabM Unit :=
for info in fieldInfos do
let type inferType info.fvar
let u getLevel type
let u instantiateLevelMVars u
match ( modifyGet fun s => accLevel u r rOffset |>.run |>.run s) with
| some _ => pure ()
| none =>
let typeType inferType type
let mut msg := m!"failed to compute resulting universe level of structure, field '{info.declName}' has type{indentD m!"{type} : {typeType}"}\nstructure resulting type{indentExpr (mkSort (r.addOffset rOffset))}"
if r.isMVar then
msg := msg ++ "\nrecall that Lean only infers the resulting universe level automatically when there is a unique solution for the universe level constraints, consider explicitly providing the structure resulting universe level"
throwError msg
/--
Decides whether the structure should be `Prop`-valued when the universe is not given
and when the universe inference algorithm `collectUniversesFromFields` determines
that the inductive type could naturally be `Prop`-valued.
See `Lean.Elab.Command.isPropCandidate` for an explanation.
Specialized to structures, the heuristic is that we prefer a `Prop` instead of a `Type` structure
when it could be a syntactic subsingleton.
Exception: no-field structures are `Type` since they are likely stubbed-out declarations.
-/
private def isPropCandidate (fieldInfos : Array StructFieldInfo) : Bool :=
!fieldInfos.isEmpty
private def updateResultingUniverse (fieldInfos : Array StructFieldInfo) (type : Expr) : TermElabM Expr := do
let r getResultUniverse type
let rOffset : Nat := r.getOffset
let r : Level := r.getLevelOffset
unless r.isMVar do
throwError "failed to compute resulting universe level of inductive datatype, provide universe explicitly: {r}"
let us collectUniversesFromFields r rOffset fieldInfos
trace[Elab.structure] "updateResultingUniverse us: {us}, r: {r}, rOffset: {rOffset}"
let rNew := mkResultUniverse us rOffset (isPropCandidate fieldInfos)
assignLevelMVar r.mvarId! rNew
instantiateMVars type
private def collectLevelParamsInFVar (s : CollectLevelParams.State) (fvar : Expr) : TermElabM CollectLevelParams.State := do
let type inferType fvar
let type instantiateMVars type
return collectLevelParams s type
private def collectLevelParamsInFVars (fvars : Array Expr) (s : CollectLevelParams.State) : TermElabM CollectLevelParams.State :=
fvars.foldlM collectLevelParamsInFVar s
private def collectLevelParamsInStructure (structType : Expr) (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo)
: TermElabM (Array Name) := do
let s := collectLevelParams {} structType
let s collectLevelParamsInFVars scopeVars s
let s collectLevelParamsInFVars params s
let s fieldInfos.foldlM (init := s) fun s info => collectLevelParamsInFVar s info.fvar
return s.params
private def addCtorFields (fieldInfos : Array StructFieldInfo) : Nat Expr TermElabM Expr
| 0, type => pure type
@@ -661,29 +772,19 @@ private def addCtorFields (fieldInfos : Array StructFieldInfo) : Nat → Expr
| _ =>
addCtorFields fieldInfos i (mkForall decl.userName decl.binderInfo decl.type type)
private def mkCtor (view : StructView) (r : ElabHeaderResult) (params : Array Expr) (fieldInfos : Array StructFieldInfo) : TermElabM Constructor :=
private def mkCtor (view : StructView) (levelParams : List Name) (params : Array Expr) (fieldInfos : Array StructFieldInfo) : TermElabM Constructor :=
withRef view.ref do
let type := mkAppN r.indFVar params
let type := mkAppN (mkConst view.declName (levelParams.map mkLevelParam)) params
let type addCtorFields fieldInfos fieldInfos.size type
let type mkForallFVars params type
let type instantiateMVars type
let type := type.inferImplicit params.size true
pure { name := view.ctor.declName, type }
private partial def checkResultingUniversesForFields (fieldInfos : Array StructFieldInfo) (u : Level) : TermElabM Unit := do
for info in fieldInfos do
let type inferType info.fvar
let v := ( instantiateLevelMVars ( getLevel type)).normalize
unless u.geq v do
let msg := m!"invalid universe level for field '{info.name}', has type{indentExpr type}\n\
at universe level{indentD v}\n\
which is not less than or equal to the structure's resulting universe level{indentD u}"
throwErrorAt info.ref msg
@[extern "lean_mk_projections"]
private opaque mkProjections (env : Environment) (structName : Name) (projs : List Name) (isClass : Bool) : Except KernelException Environment
private def addProjections (r : ElabHeaderResult) (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
private def addProjections (r : ElabStructHeaderResult) (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
if r.type.isProp then
if let some fieldInfo fieldInfos.findM? (not <$> Meta.isProof ·.fvar) then
throwErrorAt fieldInfo.ref m!"failed to generate projections for 'Prop' structure, field '{format fieldInfo.name}' is not a proof"
@@ -694,71 +795,49 @@ private def addProjections (r : ElabHeaderResult) (fieldInfos : Array StructFiel
private def registerStructure (structName : Name) (infos : Array StructFieldInfo) : TermElabM Unit := do
let fields infos.filterMapM fun info => do
if info.kind == StructFieldKind.fromParent then
return none
else
return some {
fieldName := info.name
projFn := info.declName
binderInfo := ( getFVarLocalDecl info.fvar).binderInfo
autoParam? := ( inferType info.fvar).getAutoParamTactic?
subobject? := if let .subobject parentName := info.kind then parentName else none
}
if info.kind == StructFieldKind.fromParent then
return none
else
return some {
fieldName := info.name
projFn := info.declName
binderInfo := ( getFVarLocalDecl info.fvar).binderInfo
autoParam? := ( inferType info.fvar).getAutoParamTactic?
subobject? := if let .subobject parentName := info.kind then parentName else none
}
modifyEnv fun env => Lean.registerStructure env { structName, fields }
private def checkDefaults (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
let mut mvars := {}
let mut lmvars := {}
for fieldInfo in fieldInfos do
if let some value := fieldInfo.value? then
let value instantiateMVars value
mvars := Expr.collectMVars mvars value
lmvars := collectLevelMVars lmvars value
-- Log errors and ignore the failure; we later will just omit adding a default value.
if Term.logUnassignedUsingErrorInfos mvars.result then
return
else if Term.logUnassignedLevelMVarsUsingErrorInfos lmvars.result then
return
private def mkAuxConstructions (declName : Name) : TermElabM Unit := do
let env getEnv
let hasEq := env.contains ``Eq
let hasHEq := env.contains ``HEq
let hasUnit := env.contains ``PUnit
let hasProd := env.contains ``Prod
mkRecOn declName
if hasUnit then mkCasesOn declName
if hasUnit && hasEq && hasHEq then mkNoConfusion declName
let ival getConstInfoInduct declName
if ival.isRec then
if hasUnit && hasProd then mkBelow declName
if hasUnit && hasProd then mkIBelow declName
if hasUnit && hasProd then mkBRecOn declName
if hasUnit && hasProd then mkBInductionOn declName
private def addDefaults (params : Array Expr) (replaceIndFVars : Expr MetaM Expr) (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
let lctx getLCtx
/- The `lctx` and `defaultAuxDecls` are used to create the auxiliary "default value" declarations
The parameters `params` for these definitions must be marked as implicit, and all others as explicit. -/
let lctx :=
params.foldl (init := lctx) fun (lctx : LocalContext) (p : Expr) =>
if p.isFVar then
lctx.setBinderInfo p.fvarId! BinderInfo.implicit
else
lctx
let lctx :=
fieldInfos.foldl (init := lctx) fun (lctx : LocalContext) (info : StructFieldInfo) =>
if info.isFromParent then lctx -- `fromParent` fields are elaborated as let-decls, and are zeta-expanded when creating "default value" auxiliary functions
else lctx.setBinderInfo info.fvar.fvarId! BinderInfo.default
-- Make all indFVar replacements in the local context.
let lctx
lctx.foldlM (init := {}) fun lctx ldecl => do
match ldecl with
| .cdecl _ fvarId userName type bi k =>
let type replaceIndFVars type
return lctx.mkLocalDecl fvarId userName type bi k
| .ldecl _ fvarId userName type value nonDep k =>
let type replaceIndFVars type
let value replaceIndFVars value
return lctx.mkLetDecl fvarId userName type value nonDep k
private def addDefaults (lctx : LocalContext) (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
withLCtx lctx ( getLocalInstances) do
fieldInfos.forM fun fieldInfo => do
if let some value := fieldInfo.value? then
let declName := mkDefaultFnOfProjFn fieldInfo.declName
let type replaceIndFVars ( inferType fieldInfo.fvar)
let value instantiateMVars ( replaceIndFVars value)
trace[Elab.structure] "default value after 'replaceIndFVars': {indentExpr value}"
-- If there are mvars, `checkDefaults` already logged an error.
unless value.hasMVar || value.hasSyntheticSorry do
/- The identity function is used as "marker". -/
let value mkId value
-- No need to compile the definition, since it is only used during elaboration.
discard <| mkAuxDefinition declName type value (zetaDelta := true) (compile := false)
setReducibleAttribute declName
let type inferType fieldInfo.fvar
let value instantiateMVars value
if value.hasExprMVar then
discard <| Term.logUnassignedUsingErrorInfos ( getMVars value)
throwErrorAt fieldInfo.ref "invalid default value for field '{format fieldInfo.name}', it contains metavariables{indentExpr value}"
/- The identity function is used as "marker". -/
let value mkId value
-- No need to compile the definition, since it is only used during elaboration.
discard <| mkAuxDefinition declName type value (zetaDelta := true) (compile := false)
setReducibleAttribute declName
/--
Given `type` of the form `forall ... (source : A), B`, return `forall ... [source : A], B`.
@@ -775,81 +854,108 @@ private def setSourceInstImplicit (type : Expr) : Expr :=
/--
Creates a projection function to a non-subobject parent.
-/
private partial def mkCoercionToCopiedParent (levelParams : List Name) (params : Array Expr) (view : StructView) (source : Expr) (parentStructName : Name) (parentType : Expr) : MetaM StructureParentInfo := do
private partial def mkCoercionToCopiedParent (levelParams : List Name) (params : Array Expr) (view : StructView) (parentStructName : Name) (parentType : Expr) : MetaM StructureParentInfo := do
let isProp Meta.isProp parentType
let env getEnv
let structName := view.declName
let sourceFieldNames := getStructureFieldsFlattened env structName
let structType := mkAppN (Lean.mkConst structName (levelParams.map mkLevelParam)) params
let binfo := if view.isClass && isClass env parentStructName then BinderInfo.instImplicit else BinderInfo.default
let mut declType instantiateMVars ( mkForallFVars params ( mkForallFVars #[source] parentType))
if view.isClass && isClass env parentStructName then
declType := setSourceInstImplicit declType
declType := declType.inferImplicit params.size true
let rec copyFields (parentType : Expr) : MetaM Expr := do
let Expr.const parentStructName us pure parentType.getAppFn | unreachable!
let parentCtor := getStructureCtor env parentStructName
let mut result := mkAppN (mkConst parentCtor.name us) parentType.getAppArgs
for fieldName in getStructureFields env parentStructName do
if sourceFieldNames.contains fieldName then
let fieldVal mkProjection source fieldName
result := mkApp result fieldVal
else
-- fieldInfo must be a field of `parentStructName`
let some fieldInfo := getFieldInfo? env parentStructName fieldName | unreachable!
if fieldInfo.subobject?.isNone then throwError "failed to build coercion to parent structure"
let resultType whnfD ( inferType result)
unless resultType.isForall do throwError "failed to build coercion to parent structure, unexpected type{indentExpr resultType}"
let fieldVal copyFields resultType.bindingDomain!
result := mkApp result fieldVal
return result
let declVal instantiateMVars ( mkLambdaFVars params ( mkLambdaFVars #[source] ( copyFields parentType)))
let declName := structName ++ mkToParentName ( getStructureName parentType) fun n => !env.contains (structName ++ n)
-- Logic from `mk_projections`: prop-valued projections are theorems (or at least opaque)
let cval : ConstantVal := { name := declName, levelParams, type := declType }
if isProp then
addDecl <|
if view.modifiers.isUnsafe then
-- Theorems cannot be unsafe.
Declaration.opaqueDecl { cval with value := declVal, isUnsafe := true }
else
Declaration.thmDecl { cval with value := declVal }
else
addAndCompile <| Declaration.defnDecl { cval with
value := declVal
hints := ReducibilityHints.abbrev
safety := if view.modifiers.isUnsafe then DefinitionSafety.unsafe else DefinitionSafety.safe
}
-- Logic from `mk_projections`: non-instance-implicits that aren't props become reducible.
-- (Instances will get instance reducibility in `Lean.Elab.Command.addParentInstances`.)
if !binfo.isInstImplicit && !( Meta.isProp parentType) then
setReducibleAttribute declName
return { structName := parentStructName, subobject := false, projFn := declName }
private def mkRemainingProjections (levelParams : List Name) (params : Array Expr) (view : StructView)
(parents : Array StructParentInfo) (fieldInfos : Array StructFieldInfo) : TermElabM (Array StructureParentInfo) := do
let structType := mkAppN (Lean.mkConst view.declName (levelParams.map mkLevelParam)) params
withLocalDeclD `self structType fun source => do
/-
Remark: copied parents might still be referring to the fvars of other parents. We need to replace these fvars with projection constants.
For subobject parents, this has already been done by `mkProjections`.
https://github.com/leanprover/lean4/issues/2611
-/
let mut parentInfos := #[]
let mut parentFVarToConst : ExprMap Expr := {}
for h : i in [0:parents.size] do
let parent := parents[i]
let parentInfo : StructureParentInfo (do
if parent.subobject then
let some info := fieldInfos.find? (·.kind == .subobject parent.structName) | unreachable!
pure { structName := parent.structName, subobject := true, projFn := info.declName }
let mut declType instantiateMVars ( mkForallFVars params ( mkForallFVars #[source] parentType))
if view.isClass && isClass env parentStructName then
declType := setSourceInstImplicit declType
declType := declType.inferImplicit params.size true
let rec copyFields (parentType : Expr) : MetaM Expr := do
let Expr.const parentStructName us pure parentType.getAppFn | unreachable!
let parentCtor := getStructureCtor env parentStructName
let mut result := mkAppN (mkConst parentCtor.name us) parentType.getAppArgs
for fieldName in getStructureFields env parentStructName do
if sourceFieldNames.contains fieldName then
let fieldVal mkProjection source fieldName
result := mkApp result fieldVal
else
let parent_type := ( instantiateMVars parent.type).replace fun e => parentFVarToConst[e]?
mkCoercionToCopiedParent levelParams params view source parent.structName parent_type)
parentInfos := parentInfos.push parentInfo
if let some fvar := parent.fvar? then
parentFVarToConst := parentFVarToConst.insert fvar <|
mkApp (mkAppN (.const parentInfo.projFn (levelParams.map mkLevelParam)) params) source
pure parentInfos
-- fieldInfo must be a field of `parentStructName`
let some fieldInfo := getFieldInfo? env parentStructName fieldName | unreachable!
if fieldInfo.subobject?.isNone then throwError "failed to build coercion to parent structure"
let resultType whnfD ( inferType result)
unless resultType.isForall do throwError "failed to build coercion to parent structure, unexpected type{indentExpr resultType}"
let fieldVal copyFields resultType.bindingDomain!
result := mkApp result fieldVal
return result
let declVal instantiateMVars ( mkLambdaFVars params ( mkLambdaFVars #[source] ( copyFields parentType)))
let declName := structName ++ mkToParentName ( getStructureName parentType) fun n => !env.contains (structName ++ n)
-- Logic from `mk_projections`: prop-valued projections are theorems (or at least opaque)
let cval : ConstantVal := { name := declName, levelParams, type := declType }
if isProp then
addDecl <|
if view.modifiers.isUnsafe then
-- Theorems cannot be unsafe.
Declaration.opaqueDecl { cval with value := declVal, isUnsafe := true }
else
Declaration.thmDecl { cval with value := declVal }
else
addAndCompile <| Declaration.defnDecl { cval with
value := declVal
hints := ReducibilityHints.abbrev
safety := if view.modifiers.isUnsafe then DefinitionSafety.unsafe else DefinitionSafety.safe
}
-- Logic from `mk_projections`: non-instance-implicits that aren't props become reducible.
-- (Instances will get instance reducibility in `Lean.Elab.Command.addParentInstances`.)
if !binfo.isInstImplicit && !( Meta.isProp parentType) then
setReducibleAttribute declName
return { structName := parentStructName, subobject := false, projFn := declName }
private def elabStructHeader (view : StructView) : TermElabM ElabStructHeaderResult :=
Term.withAutoBoundImplicitForbiddenPred (fun n => view.shortDeclName == n) do
Term.withAutoBoundImplicit do
Term.elabBinders view.binders.getArgs fun params => do
elabParents view fun parentFieldInfos parents => do
let type Term.elabType view.type
Term.synthesizeSyntheticMVarsNoPostponing
let u mkFreshLevelMVar
unless isDefEq type (mkSort u) do
throwErrorAt view.type "invalid structure type, expecting 'Type _' or 'Prop'"
let type instantiateMVars ( whnf type)
Term.addAutoBoundImplicits' params type fun params type => do
let levelNames Term.getLevelNames
trace[Elab.structure] "header params: {params}, type: {type}, levelNames: {levelNames}"
return { lctx := ( getLCtx), localInsts := ( getLocalInstances), levelNames, params, type, view, parents, parentFieldInfos }
private def mkTypeFor (r : ElabStructHeaderResult) : TermElabM Expr := do
withLCtx r.lctx r.localInsts do
mkForallFVars r.params r.type
/--
Create a local declaration for the structure and execute `x params indFVar`, where `params` are the structure's type parameters and
`indFVar` is the new local declaration.
-/
private partial def withStructureLocalDecl (r : ElabStructHeaderResult) (x : Array Expr Expr TermElabM α) : TermElabM α := do
let declName := r.view.declName
let shortDeclName := r.view.shortDeclName
let type mkTypeFor r
let params := r.params
withLCtx r.lctx r.localInsts <| withRef r.view.ref do
Term.withAuxDecl shortDeclName type declName fun indFVar =>
x params indFVar
/--
Remark: `numVars <= numParams`.
`numVars` is the number of context `variables` used in the declaration,
and `numParams - numVars` is the number of parameters provided as binders in the declaration.
-/
private def mkInductiveType (view : StructView) (indFVar : Expr) (levelNames : List Name)
(numVars : Nat) (numParams : Nat) (type : Expr) (ctor : Constructor) : TermElabM InductiveType := do
let levelParams := levelNames.map mkLevelParam
let const := mkConst view.declName levelParams
let ctorType forallBoundedTelescope ctor.type numParams fun params type => do
let type := type.replace fun e =>
if e == indFVar then
mkAppN const (params.extract 0 numVars)
else
none
instantiateMVars ( mkForallFVars params type)
return { name := view.declName, type := instantiateMVars type, ctors := [{ ctor with type := instantiateMVars ctorType }] }
/--
Precomputes the structure's resolution order.
@@ -881,45 +987,109 @@ private def addParentInstances (parents : Array StructureParentInfo) : MetaM Uni
for instParent in instParents do
addInstance instParent.projFn AttributeKind.global (eval_prio default)
@[builtin_inductive_elab Lean.Parser.Command.«structure»]
def elabStructureCommand : InductiveElabDescr where
mkInductiveView (modifiers : Modifiers) (stx : Syntax) := do
def mkStructureDecl (vars : Array Expr) (view : StructView) : TermElabM Unit := Term.withoutSavingRecAppSyntax do
let scopeLevelNames Term.getLevelNames
let isUnsafe := view.modifiers.isUnsafe
withRef view.ref <| Term.withLevelNames view.levelNames do
let r elabStructHeader view
Term.synthesizeSyntheticMVarsNoPostponing
withLCtx r.lctx r.localInsts do
withStructureLocalDecl r fun params indFVar => do
trace[Elab.structure] "indFVar: {indFVar}"
Term.addLocalVarInfo view.declId indFVar
withFields view.fields r.parentFieldInfos fun fieldInfos =>
withRef view.ref do
Term.synthesizeSyntheticMVarsNoPostponing
let type instantiateMVars r.type
let u getResultUniverse type
let univToInfer? shouldInferResultUniverse u
withUsed vars params fieldInfos fun scopeVars => do
let fieldInfos levelMVarToParam scopeVars params fieldInfos univToInfer?
let type withRef view.ref do
if univToInfer?.isSome then
updateResultingUniverse fieldInfos type
else
checkResultingUniverse ( getResultUniverse type)
pure type
trace[Elab.structure] "type: {type}"
let usedLevelNames collectLevelParamsInStructure type scopeVars params fieldInfos
match sortDeclLevelParams scopeLevelNames r.levelNames usedLevelNames with
| Except.error msg => throwErrorAt view.declId msg
| Except.ok levelParams =>
let params := scopeVars ++ params
let ctor mkCtor view levelParams params fieldInfos
let type mkForallFVars params type
let type instantiateMVars type
let indType mkInductiveType view indFVar levelParams scopeVars.size params.size type ctor
let decl := Declaration.inductDecl levelParams params.size [indType] isUnsafe
Term.ensureNoUnassignedMVars decl
addDecl decl
-- rename indFVar so that it does not shadow the actual declaration:
let lctx := ( getLCtx).modifyLocalDecl indFVar.fvarId! fun decl => decl.setUserName .anonymous
withLCtx lctx ( getLocalInstances) do
addProjections r fieldInfos
registerStructure view.declName fieldInfos
mkAuxConstructions view.declName
withSaveInfoContext do -- save new env
Term.addLocalVarInfo view.ref[1] ( mkConstWithLevelParams view.declName)
if let some _ := view.ctor.ref.getPos? (canonicalOnly := true) then
Term.addTermInfo' view.ctor.ref ( mkConstWithLevelParams view.ctor.declName) (isBinder := true)
for field in view.fields do
-- may not exist if overriding inherited field
if ( getEnv).contains field.declName then
Term.addTermInfo' field.ref ( mkConstWithLevelParams field.declName) (isBinder := true)
withRef view.declId do
Term.applyAttributesAt view.declName view.modifiers.attrs AttributeApplicationTime.afterTypeChecking
let parentInfos r.parents.mapM fun parent => do
if parent.subobject then
let some info := fieldInfos.find? (·.kind == .subobject parent.structName) | unreachable!
pure { structName := parent.structName, subobject := true, projFn := info.declName }
else
mkCoercionToCopiedParent levelParams params view parent.structName parent.type
setStructureParents view.declName parentInfos
checkResolutionOrder view.declName
if view.isClass then
addParentInstances parentInfos
let lctx getLCtx
/- The `lctx` and `defaultAuxDecls` are used to create the auxiliary "default value" declarations
The parameters `params` for these definitions must be marked as implicit, and all others as explicit. -/
let lctx :=
params.foldl (init := lctx) fun (lctx : LocalContext) (p : Expr) =>
if p.isFVar then
lctx.setBinderInfo p.fvarId! BinderInfo.implicit
else
lctx
let lctx :=
fieldInfos.foldl (init := lctx) fun (lctx : LocalContext) (info : StructFieldInfo) =>
if info.isFromParent then lctx -- `fromParent` fields are elaborated as let-decls, and are zeta-expanded when creating "default value" auxiliary functions
else lctx.setBinderInfo info.fvar.fvarId! BinderInfo.default
addDefaults lctx fieldInfos
def elabStructureView (vars : Array Expr) (view : StructView) : TermElabM Unit := do
Term.withDeclName view.declName <| withRef view.ref do
mkStructureDecl vars view
unless view.isClass do
Lean.Meta.IndPredBelow.mkBelow view.declName
mkSizeOfInstances view.declName
mkInjectiveTheorems view.declName
def elabStructureViewPostprocessing (view : StructView) : CommandElabM Unit := do
view.derivingClasses.forM fun classView => classView.applyHandlers #[view.declName]
runTermElabM fun _ => Term.withDeclName view.declName <| withRef view.declId do
Term.applyAttributesAt view.declName view.modifiers.attrs .afterCompilation
def elabStructure (modifiers : Modifiers) (stx : Syntax) : CommandElabM Unit := do
let view runTermElabM fun vars => do
let view structureSyntaxToView modifiers stx
trace[Elab.structure] "view.levelNames: {view.levelNames}"
return {
view := view.toInductiveView
elabCtors := fun rs r params => do
withParents view rs r.indFVar fun parentFieldInfos parents =>
withFields view.fields parentFieldInfos fun fieldInfos => do
withRef view.ref do
Term.synthesizeSyntheticMVarsNoPostponing
let lctx getLCtx
let localInsts getLocalInstances
let ctor mkCtor view r params fieldInfos
return {
ctors := [ctor]
collectUsedFVars := collectUsedFVars lctx localInsts fieldInfos
checkUniverses := fun _ u => withLCtx lctx localInsts do checkResultingUniversesForFields fieldInfos u
finalizeTermElab := withLCtx lctx localInsts do checkDefaults fieldInfos
prefinalize := fun _ _ _ => do
withLCtx lctx localInsts do
addProjections r fieldInfos
registerStructure view.declName fieldInfos
withSaveInfoContext do -- save new env
for field in view.fields do
-- may not exist if overriding inherited field
if ( getEnv).contains field.declName then
Term.addTermInfo' field.ref ( mkConstWithLevelParams field.declName) (isBinder := true)
finalize := fun levelParams params replaceIndFVars => do
let parentInfos mkRemainingProjections levelParams params view parents fieldInfos
setStructureParents view.declName parentInfos
checkResolutionOrder view.declName
if view.isClass then
addParentInstances parentInfos
elabStructureView vars view
pure view
elabStructureViewPostprocessing view
withLCtx lctx localInsts do
addDefaults params replaceIndFVars fieldInfos
}
}
builtin_initialize
registerTraceClass `Elab.structure
registerTraceClass `Elab.structure.resolutionOrder
end Lean.Elab.Command

View File

@@ -58,28 +58,15 @@ where
let some eqBVPred ReifiedBVPred.mkBinPred atom resValExpr atomExpr resExpr .eq | return none
let eqBV ReifiedBVLogical.ofPred eqBVPred
let trueExpr := mkConst ``Bool.true
let impExpr mkArrow ( mkEq eqDiscrExpr trueExpr) ( mkEq eqBVExpr trueExpr)
let decideImpExpr mkAppOptM ``Decidable.decide #[some impExpr, none]
let imp ReifiedBVLogical.mkGate eqDiscr eqBV eqDiscrExpr eqBVExpr .imp
let proof := do
let evalExpr ReifiedBVLogical.mkEvalExpr imp.expr
let congrProof imp.evalsAtAtoms
let lemmaProof := mkApp4 (mkConst lemmaName) (toExpr lhs.width) discrExpr lhsExpr rhsExpr
let trueExpr := mkConst ``Bool.true
let eqDiscrTrueExpr mkEq eqDiscrExpr trueExpr
let eqBVExprTrueExpr mkEq eqBVExpr trueExpr
let impExpr mkArrow eqDiscrTrueExpr eqBVExprTrueExpr
-- construct a `Decidable` instance for the implication using forall_prop_decidable
let decEqDiscrTrue := mkApp2 (mkConst ``instDecidableEqBool) eqDiscrExpr trueExpr
let decEqBVExprTrue := mkApp2 (mkConst ``instDecidableEqBool) eqBVExpr trueExpr
let impDecidable := mkApp4 (mkConst ``forall_prop_decidable)
eqDiscrTrueExpr
(.lam .anonymous eqDiscrTrueExpr eqBVExprTrueExpr .default)
decEqDiscrTrue
(.lam .anonymous eqDiscrTrueExpr decEqBVExprTrue .default)
let decideImpExpr := mkApp2 (mkConst ``Decidable.decide) impExpr impDecidable
return mkApp4
(mkConst ``Std.Tactic.BVDecide.Reflect.Bool.lemma_congr)
decideImpExpr

View File

@@ -40,7 +40,7 @@ private def tryAssumption (mvarId : MVarId) : MetaM (List MVarId) := do
let ids := stx[1].getArgs.toList.map getNameOfIdent'
liftMetaTactic fun mvarId => do
match ( Meta.injections mvarId ids) with
| .solved => checkUnusedIds `injections mvarId ids; return []
| .subgoal mvarId' unusedIds _ => checkUnusedIds `injections mvarId unusedIds; tryAssumption mvarId'
| .solved => checkUnusedIds `injections mvarId ids; return []
| .subgoal mvarId' unusedIds => checkUnusedIds `injections mvarId unusedIds; tryAssumption mvarId'
end Lean.Elab.Tactic

View File

@@ -473,26 +473,6 @@ def withoutTacticReuse [Monad m] [MonadWithReaderOf Context m] [MonadOptions m]
return !cond }
}) act
@[inherit_doc Core.wrapAsync]
def wrapAsync (act : Unit TermElabM α) : TermElabM (EIO Exception α) := do
let ctx read
let st get
let metaCtx readThe Meta.Context
let metaSt getThe Meta.State
Core.wrapAsync fun _ =>
act () |>.run ctx |>.run' st |>.run' metaCtx metaSt
@[inherit_doc Core.wrapAsyncAsSnapshot]
def wrapAsyncAsSnapshot (act : Unit TermElabM Unit)
(desc : String := by exact decl_name%.toString) :
TermElabM (BaseIO Language.SnapshotTree) := do
let ctx read
let st get
let metaCtx readThe Meta.Context
let metaSt getThe Meta.State
Core.wrapAsyncAsSnapshot (desc := desc) fun _ =>
act () |>.run ctx |>.run' st |>.run' metaCtx metaSt
abbrev TermElabResult (α : Type) := EStateM.Result Exception SavedState α
/--

View File

@@ -10,8 +10,9 @@ Authors: Sebastian Ullrich
prelude
import Init.System.Promise
import Lean.Message
import Lean.Parser.Types
import Lean.Util.Trace
import Lean.Elab.InfoTree
set_option linter.missingDocs true
@@ -197,13 +198,32 @@ def withAlwaysResolvedPromises [Monad m] [MonadLiftT BaseIO m] [MonadFinally m]
for asynchronously collecting information about the entirety of snapshots in the language server.
The involved tasks may form a DAG on the `Task` dependency level but this is not captured by this
data structure. -/
structure SnapshotTree where
/-- The immediately available element of the snapshot tree node. -/
element : Snapshot
/-- The asynchronously available children of the snapshot tree node. -/
children : Array (SnapshotTask SnapshotTree)
inductive SnapshotTree where
/-- Creates a snapshot tree node. -/
| mk (element : Snapshot) (children : Array (SnapshotTask SnapshotTree))
deriving Inhabited
/-- The immediately available element of the snapshot tree node. -/
abbrev SnapshotTree.element : SnapshotTree Snapshot
| mk s _ => s
/-- The asynchronously available children of the snapshot tree node. -/
abbrev SnapshotTree.children : SnapshotTree Array (SnapshotTask SnapshotTree)
| mk _ children => children
/-- Produces trace of given snapshot tree, synchronously waiting on all children. -/
partial def SnapshotTree.trace (s : SnapshotTree) : CoreM Unit :=
go none s
where go range? s := do
let file getFileMap
let mut desc := f!"{s.element.desc}"
if let some range := range? then
desc := desc ++ f!"{file.toPosition range.start}-{file.toPosition range.stop} "
desc := desc ++ .prefixJoin "\n" ( s.element.diagnostics.msgLog.toList.mapM (·.toString))
if let some t := s.element.infoTree? then
trace[Elab.info] ( t.format)
withTraceNode `Elab.snapshotTree (fun _ => pure desc) do
s.children.toList.forM fun c => go c.range? c.get
/--
Helper class for projecting a heterogeneous hierarchy of snapshot classes to a homogeneous
representation. -/
@@ -288,14 +308,6 @@ def SnapshotTree.runAndReport (s : SnapshotTree) (opts : Options) (json := false
def SnapshotTree.getAll (s : SnapshotTree) : Array Snapshot :=
s.forM (m := StateM _) (fun s => modify (·.push s)) |>.run #[] |>.2
/-- Returns a task that waits on all snapshots in the tree. -/
def SnapshotTree.waitAll : SnapshotTree BaseIO (Task Unit)
| mk _ children => go children.toList
where
go : List (SnapshotTask SnapshotTree) BaseIO (Task Unit)
| [] => return .pure ()
| t::ts => BaseIO.bindTask t.task fun _ => go ts
/-- Context of an input processing invocation. -/
structure ProcessingContext extends Parser.InputContext

View File

@@ -10,7 +10,6 @@ Authors: Sebastian Ullrich
prelude
import Lean.Language.Basic
import Lean.Language.Util
import Lean.Language.Lean.Types
import Lean.Parser.Module
import Lean.Elab.Import
@@ -167,6 +166,13 @@ namespace Lean.Language.Lean
open Lean.Elab Command
open Lean.Parser
/-- Option for capturing output to stderr during elaboration. -/
register_builtin_option stderrAsMessages : Bool := {
defValue := true
group := "server"
descr := "(server) capture output to the Lean stderr channel (such as from `dbg_trace`) during elaboration of a command as a diagnostic message"
}
/-- Lean-specific processing context. -/
structure LeanProcessingContext extends ProcessingContext where
/-- Position of the first file difference if there was a previous invocation. -/
@@ -513,7 +519,6 @@ where
diagnostics := .empty, stx := .missing, parserState
elabSnap := .pure <| .ofTyped { diagnostics := .empty : SnapshotLeaf }
finishedSnap := .pure { diagnostics := .empty, cmdState }
reportSnap := default
tacticCache := ( IO.mkRef {})
nextCmdSnap? := none
}
@@ -524,7 +529,6 @@ where
-- definitely resolved in `doElab` task
let elabPromise IO.Promise.new
let finishedPromise IO.Promise.new
let reportPromise IO.Promise.new
-- (Try to) use last line of command as range for final snapshot task. This ensures we do not
-- retract the progress bar to a previous position in case the command support incremental
-- reporting but has significant work after resolving its last incremental promise, such as
@@ -543,46 +547,22 @@ where
let nextCmdSnap? := next?.map
({ range? := some parserState.pos, ctx.input.endPos, task := ·.result })
let diagnostics Snapshot.Diagnostics.ofMessageLog msgLog
let (stx', parserState') := if minimalSnapshots && !Parser.isTerminalCommand stx then
(default, default)
if minimalSnapshots && !Parser.isTerminalCommand stx then
prom.resolve {
diagnostics, finishedSnap, tacticCache, nextCmdSnap?
stx := .missing
parserState := {}
elabSnap := { range? := stx.getRange?, task := elabPromise.result }
}
else
(stx, parserState)
prom.resolve {
diagnostics, finishedSnap, tacticCache, nextCmdSnap?
stx := stx', parserState := parserState'
elabSnap := { range? := stx.getRange?, task := elabPromise.result }
reportSnap := { range? := endRange?, task := reportPromise.result }
}
prom.resolve {
diagnostics, stx, parserState, tacticCache, nextCmdSnap?
elabSnap := { range? := stx.getRange?, task := elabPromise.result }
finishedSnap
}
let cmdState doElab stx cmdState beginPos
{ old? := old?.map fun old => old.stx, old.elabSnap, new := elabPromise }
finishedPromise tacticCache ctx
let traceTask
if ( isTracingEnabledForCore `Elab.snapshotTree cmdState.scopes.head!.opts) then
-- We want to trace all of `CommandParsedSnapshot` but `traceTask` is part of it, so let's
-- create a temporary snapshot tree containing all tasks but it
let snaps := #[
{ range? := none, task := elabPromise.result.map (sync := true) toSnapshotTree },
{ range? := none, task := finishedPromise.result.map (sync := true) toSnapshotTree }] ++
cmdState.snapshotTasks
let tree := SnapshotTree.mk { diagnostics := .empty } snaps
BaseIO.bindTask ( tree.waitAll) fun _ => do
let .ok (_, s) EIO.toBaseIO <| tree.trace |>.run
{ ctx with options := cmdState.scopes.head!.opts } { env := cmdState.env }
| pure <| .pure <| .mk { diagnostics := .empty } #[]
let mut msgLog := MessageLog.empty
for trace in s.traceState.traces do
msgLog := msgLog.add {
fileName := ctx.fileName
severity := MessageSeverity.information
pos := ctx.fileMap.toPosition beginPos
data := trace.msg
}
return .pure <| .mk { diagnostics := ( Snapshot.Diagnostics.ofMessageLog msgLog) } #[]
else
pure <| .pure <| .mk { diagnostics := .empty } #[]
reportPromise.resolve <|
.mk { diagnostics := .empty } <|
cmdState.snapshotTasks.push { range? := endRange?, task := traceTask }
if let some next := next? then
-- We're definitely off the fast-forwarding path now
parseCmd none parserState cmdState next (sync := false) ctx
@@ -593,9 +573,7 @@ where
LeanProcessingM Command.State := do
let ctx read
let scope := cmdState.scopes.head!
-- reset per-command state
let cmdStateRef IO.mkRef { cmdState with
messages := .empty, traceState := {}, snapshotTasks := #[] }
let cmdStateRef IO.mkRef { cmdState with messages := .empty, traceState := {} }
/-
The same snapshot may be executed by different tasks. So, to make sure `elabCommandTopLevel`
has exclusive access to the cache, we create a fresh reference here. Before this change, the
@@ -610,8 +588,8 @@ where
cancelTk? := some ctx.newCancelTk
}
let (output, _)
IO.FS.withIsolatedStreams (isolateStderr := Core.stderrAsMessages.get scope.opts) do
EIO.toBaseIO do
IO.FS.withIsolatedStreams (isolateStderr := stderrAsMessages.get scope.opts) do
liftM (m := BaseIO) do
withLoggingExceptions
(getResetInfoTrees *> Elab.Command.elabCommandTopLevel stx)
cmdCtx cmdStateRef

View File

@@ -46,8 +46,6 @@ structure CommandParsedSnapshot extends Snapshot where
elabSnap : SnapshotTask DynamicSnapshot
/-- State after processing is finished. -/
finishedSnap : SnapshotTask CommandFinishedSnapshot
/-- Additional, untyped snapshots used for reporting, not reuse. -/
reportSnap : SnapshotTask SnapshotTree
/-- Cache for `save`; to be replaced with incrementality. -/
tacticCache : IO.Ref Tactic.Cache
/-- Next command, unless this is a terminal command. -/
@@ -57,8 +55,7 @@ partial instance : ToSnapshotTree CommandParsedSnapshot where
toSnapshotTree := go where
go s := s.toSnapshot,
#[s.elabSnap.map (sync := true) toSnapshotTree,
s.finishedSnap.map (sync := true) toSnapshotTree,
s.reportSnap] |>
s.finishedSnap.map (sync := true) toSnapshotTree] |>
pushOpt (s.nextCmdSnap?.map (·.map (sync := true) go))
/-- State after successful importing. -/

View File

@@ -1,29 +0,0 @@
/-
Copyright (c) 2023 Lean FRO. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Additional snapshot functionality that needs further imports.
Authors: Sebastian Ullrich
-/
prelude
import Lean.Language.Basic
import Lean.CoreM
import Lean.Elab.InfoTree
namespace Lean.Language
/-- Produces trace of given snapshot tree, synchronously waiting on all children. -/
partial def SnapshotTree.trace (s : SnapshotTree) : CoreM Unit :=
go none s
where go range? s := do
let file getFileMap
let mut desc := f!"{s.element.desc}"
if let some range := range? then
desc := desc ++ f!"{file.toPosition range.start}-{file.toPosition range.stop} "
desc := desc ++ .prefixJoin "\n" ( s.element.diagnostics.msgLog.toList.mapM (·.toString))
if let some t := s.element.infoTree? then
trace[Elab.info] ( t.format)
withTraceNode `Elab.snapshotTree (fun _ => pure desc) do
s.children.toList.forM fun c => go c.range? c.get

View File

@@ -31,10 +31,6 @@ builtin_initialize deprecatedAttr : ParametricAttribute DeprecationEntry ←
let newName? id?.mapM Elab.realizeGlobalConstNoOverloadWithInfo
let text? := text?.map TSyntax.getString
let since? := since?.map TSyntax.getString
if id?.isNone && text?.isNone then
logWarning "`[deprecated]` attribute should specify either a new name or a deprecation message"
if since?.isNone then
logWarning "`[deprecated]` attribute should specify the date or library version at which the deprecation was introduced, using `(since := \"...\")`"
return { newName?, text?, since? }
}
@@ -50,9 +46,7 @@ def getDeprecatedNewName (env : Environment) (declName : Name) : Option Name :=
def checkDeprecated [Monad m] [MonadEnv m] [MonadLog m] [AddMessageContext m] [MonadOptions m] (declName : Name) : m Unit := do
if getLinterValue linter.deprecated ( getOptions) then
let some attr := deprecatedAttr.getParam? ( getEnv) declName | pure ()
logWarning <| .tagged ``deprecatedAttr <|
s!"`{declName}` has been deprecated" ++ match attr.text? with
| some text => s!": {text}"
| none => match attr.newName? with
| some newName => s!": use `{newName}` instead"
| none => ""
logWarning <| .tagged ``deprecatedAttr <| attr.text?.getD <|
match attr.newName? with
| none => s!"`{declName}` has been deprecated"
| some newName => s!"`{declName}` has been deprecated, use `{newName}` instead"

View File

@@ -406,8 +406,8 @@ def isSubPrefixOf (lctx₁ lctx₂ : LocalContext) (exceptFVars : Array Expr :=
@[inline] def mkBinding (isLambda : Bool) (lctx : LocalContext) (xs : Array Expr) (b : Expr) : Expr :=
let b := b.abstract xs
xs.size.foldRev (init := b) fun i _ b =>
let x := xs[i]
xs.size.foldRev (init := b) fun i b =>
let x := xs[i]!
match lctx.findFVar? x with
| some (.cdecl _ _ n ty bi _) =>
let ty := ty.abstractRange i xs;
@@ -457,7 +457,7 @@ def sanitizeNames (lctx : LocalContext) : StateM NameSanitizerState LocalContext
let st get
if !getSanitizeNames st.options then pure lctx else
StateT.run' (s := ({} : NameSet)) <|
lctx.decls.size.foldRevM (init := lctx) fun i _ lctx => do
lctx.decls.size.foldRevM (init := lctx) fun i lctx => do
match lctx.decls[i]! with
| none => pure lctx
| some decl =>

View File

@@ -825,7 +825,7 @@ def mkFreshExprMVarWithId (mvarId : MVarId) (type? : Option Expr := none) (kind
mkFreshExprMVarWithIdCore mvarId type kind userName
def mkFreshLevelMVars (num : Nat) : MetaM (List Level) :=
num.foldM (init := []) fun _ _ us =>
num.foldM (init := []) fun _ us =>
return ( mkFreshLevelMVar)::us
def mkFreshLevelMVarsFor (info : ConstantInfo) : MetaM (List Level) :=

View File

@@ -286,8 +286,8 @@ partial def process : ClosureM Unit := do
@[inline] def mkBinding (isLambda : Bool) (decls : Array LocalDecl) (b : Expr) : Expr :=
let xs := decls.map LocalDecl.toExpr
let b := b.abstract xs
decls.size.foldRev (init := b) fun i _ b =>
let decl := decls[i]
decls.size.foldRev (init := b) fun i b =>
let decl := decls[i]!
match decl with
| .cdecl _ _ n ty bi _ =>
let ty := ty.abstractRange i xs

View File

@@ -152,7 +152,7 @@ where
let run (x : StateRefT Nat MetaM Expr) : MetaM (Expr × Nat) := StateRefT'.run x 0
let (_, cnt) run <| transform domain fun e => do
if let some name := e.constName? then
if ctx.typeInfos.any fun indVal => indVal.name == name then
if let some _ := ctx.typeInfos.findIdx? fun indVal => indVal.name == name then
modify (· + 1)
return .continue

View File

@@ -566,9 +566,9 @@ Append results to array
partial def appendResultsAux (mr : MatchResult α) (a : Array β) (f : Nat α β) : Array β :=
let aa := mr.elts
let n := aa.size
Nat.fold (n := n) (init := a) fun i _ r =>
Nat.fold (n := n) (init := a) fun i r =>
let j := n-1-i
let b := aa[j]
let b := aa[j]!
b.foldl (init := r) (· ++ ·.map (f j))
partial def appendResults (mr : MatchResult α) (a : Array α) : Array α :=

View File

@@ -882,7 +882,7 @@ def mkMatcher (input : MkMatcherInput) (exceptionIfContainsSorry := false) : Met
| none => pure ()
trace[Meta.Match.debug] "matcher: {matcher}"
let unusedAltIdxs := lhss.length.fold (init := []) fun i _ r =>
let unusedAltIdxs := lhss.length.fold (init := []) fun i r =>
if s.used.contains i then r else i::r
return {
matcher,

View File

@@ -222,16 +222,16 @@ private def contradiction (mvarId : MVarId) : MetaM Bool :=
Auxiliary tactic that tries to replace as many variables as possible and then apply `contradiction`.
We use it to discard redundant hypotheses.
-/
partial def trySubstVarsAndContradiction (mvarId : MVarId) (forbidden : FVarIdSet := {}) : MetaM Bool :=
partial def trySubstVarsAndContradiction (mvarId : MVarId) : MetaM Bool :=
commitWhen do
let mvarId substVars mvarId
match ( injections mvarId (forbidden := forbidden)) with
match ( injections mvarId) with
| .solved => return true -- closed goal
| .subgoal mvarId' _ forbidden =>
| .subgoal mvarId' _ =>
if mvarId' == mvarId then
contradiction mvarId
else
trySubstVarsAndContradiction mvarId' forbidden
trySubstVarsAndContradiction mvarId'
private def processNextEq : M Bool := do
let s get
@@ -375,8 +375,7 @@ private partial def withSplitterAlts (altTypes : Array Expr) (f : Array Expr →
inductive InjectionAnyResult where
| solved
| failed
/-- `fvarId` refers to the local declaration selected for the application of the `injection` tactic. -/
| subgoal (fvarId : FVarId) (mvarId : MVarId)
| subgoal (mvarId : MVarId)
private def injectionAnyCandidate? (type : Expr) : MetaM (Option (Expr × Expr)) := do
if let some (_, lhs, rhs) matchEq? type then
@@ -386,28 +385,21 @@ private def injectionAnyCandidate? (type : Expr) : MetaM (Option (Expr × Expr))
return some (lhs, rhs)
return none
/--
Try applying `injection` to a local declaration that is not in `forbidden`.
We use `forbidden` because the `injection` tactic might fail to clear the variable if there are forward dependencies.
See `proveSubgoalLoop` for additional details.
-/
private def injectionAny (mvarId : MVarId) (forbidden : FVarIdSet := {}) : MetaM InjectionAnyResult := do
private def injectionAny (mvarId : MVarId) : MetaM InjectionAnyResult := do
mvarId.withContext do
for localDecl in ( getLCtx) do
unless forbidden.contains localDecl.fvarId do
if let some (lhs, rhs) injectionAnyCandidate? localDecl.type then
unless ( isDefEq lhs rhs) do
let lhs whnf lhs
let rhs whnf rhs
unless lhs.isRawNatLit && rhs.isRawNatLit do
try
match ( injection mvarId localDecl.fvarId) with
| InjectionResult.solved => return InjectionAnyResult.solved
| InjectionResult.subgoal mvarId .. => return InjectionAnyResult.subgoal localDecl.fvarId mvarId
catch ex =>
trace[Meta.Match.matchEqs] "injectionAnyFailed at {localDecl.userName}, error\n{ex.toMessageData}"
pure ()
if let some (lhs, rhs) injectionAnyCandidate? localDecl.type then
unless ( isDefEq lhs rhs) do
let lhs whnf lhs
let rhs whnf rhs
unless lhs.isRawNatLit && rhs.isRawNatLit do
try
match ( injection mvarId localDecl.fvarId) with
| InjectionResult.solved => return InjectionAnyResult.solved
| InjectionResult.subgoal mvarId .. => return InjectionAnyResult.subgoal mvarId
catch ex =>
trace[Meta.Match.matchEqs] "injectionAnyFailed at {localDecl.userName}, error\n{ex.toMessageData}"
pure ()
return InjectionAnyResult.failed
@@ -609,32 +601,27 @@ where
let eNew := mkAppN eNew mvars
return TransformStep.done eNew
/-
`forbidden` tracks variables that we have already applied `injection`.
Recall that the `injection` tactic may not be able to eliminate them when
they have forward dependencies.
-/
proveSubgoalLoop (mvarId : MVarId) (forbidden : FVarIdSet) : MetaM Unit := do
proveSubgoalLoop (mvarId : MVarId) : MetaM Unit := do
trace[Meta.Match.matchEqs] "proveSubgoalLoop\n{mvarId}"
if ( mvarId.contradictionQuick) then
return ()
match ( injectionAny mvarId forbidden) with
| .solved => return ()
| .failed =>
match ( injectionAny mvarId) with
| InjectionAnyResult.solved => return ()
| InjectionAnyResult.failed =>
let mvarId' substVars mvarId
if mvarId' == mvarId then
if ( mvarId.contradictionCore {}) then
return ()
throwError "failed to generate splitter for match auxiliary declaration '{matchDeclName}', unsolved subgoal:\n{MessageData.ofGoal mvarId}"
else
proveSubgoalLoop mvarId' forbidden
| .subgoal fvarId mvarId => proveSubgoalLoop mvarId (forbidden.insert fvarId)
proveSubgoalLoop mvarId'
| InjectionAnyResult.subgoal mvarId => proveSubgoalLoop mvarId
proveSubgoal (mvarId : MVarId) : MetaM Unit := do
trace[Meta.Match.matchEqs] "subgoal {mkMVar mvarId}, {repr (← mvarId.getDecl).kind}, {← mvarId.isAssigned}\n{MessageData.ofGoal mvarId}"
let (_, mvarId) mvarId.intros
let mvarId mvarId.tryClearMany (alts.map (·.fvarId!))
proveSubgoalLoop mvarId {}
proveSubgoalLoop mvarId
/--
Create new alternatives (aka minor premises) by replacing `discrs` with `patterns` at `alts`.
@@ -659,7 +646,7 @@ where
/--
Create conditional equations and splitter for the given match auxiliary declaration. -/
private partial def mkEquationsFor (matchDeclName : Name) : MetaM MatchEqns := withLCtx {} {} do
withTraceNode `Meta.Match.matchEqs (fun _ => return m!"mkEquationsFor '{matchDeclName}'") do
trace[Meta.Match.matchEqs] "mkEquationsFor '{matchDeclName}'"
withConfig (fun c => { c with etaStruct := .none }) do
/-
Remark: user have requested the `split` tactic to be available for writing code.

View File

@@ -59,9 +59,9 @@ def addArg (matcherApp : MatcherApp) (e : Expr) : MetaM MatcherApp :=
-- This error can only happen if someone implemented a transformation that rewrites the motive created by `mkMatcher`.
throwError "unexpected matcher application, motive must be lambda expression with #{matcherApp.discrs.size} arguments"
let eType inferType e
let eTypeAbst matcherApp.discrs.size.foldRevM (init := eType) fun i _ eTypeAbst => do
let eTypeAbst matcherApp.discrs.size.foldRevM (init := eType) fun i eTypeAbst => do
let motiveArg := motiveArgs[i]!
let discr := matcherApp.discrs[i]
let discr := matcherApp.discrs[i]!
let eTypeAbst kabstract eTypeAbst discr
return eTypeAbst.instantiate1 motiveArg
let motiveBody mkArrow eTypeAbst motiveBody
@@ -118,9 +118,9 @@ def refineThrough (matcherApp : MatcherApp) (e : Expr) : MetaM (Array Expr) :=
-- This error can only happen if someone implemented a transformation that rewrites the motive created by `mkMatcher`.
throwError "failed to transfer argument through matcher application, motive must be lambda expression with #{matcherApp.discrs.size} arguments"
let eAbst matcherApp.discrs.size.foldRevM (init := e) fun i _ eAbst => do
let eAbst matcherApp.discrs.size.foldRevM (init := e) fun i eAbst => do
let motiveArg := motiveArgs[i]!
let discr := matcherApp.discrs[i]
let discr := matcherApp.discrs[i]!
let eTypeAbst kabstract eAbst discr
return eTypeAbst.instantiate1 motiveArg
-- Let's create something thats a `Sort` and mentions `e`

View File

@@ -104,8 +104,8 @@ def appendParentTag (mvarId : MVarId) (newMVars : Array Expr) (binderInfos : Arr
newMVars[0]!.mvarId!.setTag parentTag
else
unless parentTag.isAnonymous do
newMVars.size.forM fun i _ => do
let mvarIdNew := newMVars[i].mvarId!
newMVars.size.forM fun i => do
let mvarIdNew := newMVars[i]!.mvarId!
unless ( mvarIdNew.isAssigned) do
unless binderInfos[i]!.isInstImplicit do
let currTag mvarIdNew.getTag

View File

@@ -168,7 +168,7 @@ private def hasIndepIndices (ctx : Context) : MetaM Bool := do
else if ctx.majorTypeIndices.any fun idx => !idx.isFVar then
/- One of the indices is not a free variable. -/
return false
else if ctx.majorTypeIndices.size.any fun i _ => i.any fun j _ => ctx.majorTypeIndices[i] == ctx.majorTypeIndices[j] then
else if ctx.majorTypeIndices.size.any fun i => i.any fun j => ctx.majorTypeIndices[i]! == ctx.majorTypeIndices[j]! then
/- An index occurs more than once -/
return false
else

View File

@@ -283,9 +283,9 @@ partial def foldAndCollect (oldIH newIH : FVarId) (isRecCall : Expr → Option E
(onMotive := fun xs _body => do
-- Remove the old IH that was added in mkFix
let eType newIH.getType
let eTypeAbst matcherApp.discrs.size.foldRevM (init := eType) fun i _ eTypeAbst => do
let eTypeAbst matcherApp.discrs.size.foldRevM (init := eType) fun i eTypeAbst => do
let motiveArg := xs[i]!
let discr := matcherApp.discrs[i]
let discr := matcherApp.discrs[i]!
let eTypeAbst kabstract eTypeAbst discr
return eTypeAbst.instantiate1 motiveArg

View File

@@ -105,10 +105,10 @@ private partial def finalize
let mvarId' mvar.mvarId!.tryClear major.fvarId!
let (fields, mvarId') mvarId'.introN nparams minorGivenNames.varNames (useNamesForExplicitOnly := !minorGivenNames.explicit)
let (extra, mvarId') mvarId'.introNP nextra
let subst := reverted.size.fold (init := baseSubst) fun i _ (subst : FVarSubst) =>
let subst := reverted.size.fold (init := baseSubst) fun i (subst : FVarSubst) =>
if i < indices.size + 1 then subst
else
let revertedFVarId := reverted[i]
let revertedFVarId := reverted[i]!
let newFVarId := extra[i - indices.size - 1]!
subst.insert revertedFVarId (mkFVar newFVarId)
let fields := fields.map mkFVar
@@ -134,8 +134,8 @@ def getMajorTypeIndices (mvarId : MVarId) (tacticName : Name) (recursorInfo : Re
if idxPos majorTypeArgs.size then throwTacticEx tacticName mvarId m!"major premise type is ill-formed{indentExpr majorType}"
let idx := majorTypeArgs.get! idxPos
unless idx.isFVar do throwTacticEx tacticName mvarId m!"major premise type index {idx} is not a variable{indentExpr majorType}"
majorTypeArgs.size.forM fun i _ => do
let arg := majorTypeArgs[i]
majorTypeArgs.size.forM fun i => do
let arg := majorTypeArgs[i]!
if i != idxPos && arg == idx then
throwTacticEx tacticName mvarId m!"'{idx}' is an index in major premise, but it occurs more than once{indentExpr majorType}"
if i < idxPos then

View File

@@ -94,48 +94,31 @@ def injection (mvarId : MVarId) (fvarId : FVarId) (newNames : List Name := []) :
| .subgoal mvarId numEqs => injectionIntro mvarId numEqs newNames
inductive InjectionsResult where
/-- `injections` closed the input goal. -/
| solved
/--
`injections` produces a new goal `mvarId`. `remainingNames` contains the user-facing names that have not been used.
`forbidden` contains all local declarations to which `injection` has been applied.
Recall that some of these declarations may not have been eliminated from the local context due to forward dependencies, and
we use `forbidden` to avoid non-termination when using `injections` in a loop.
-/
| subgoal (mvarId : MVarId) (remainingNames : List Name) (forbidden : FVarIdSet)
| subgoal (mvarId : MVarId) (remainingNames : List Name)
/--
Applies `injection` to local declarations in `mvarId`. It uses `newNames` to name the new local declarations.
`maxDepth` is the maximum recursion depth. Only local declarations that are not in `forbidden` are considered.
Recall that some of local declarations may not have been eliminated from the local context due to forward dependencies, and
we use `forbidden` to avoid non-termination when using `injections` in a loop.
-/
partial def injections (mvarId : MVarId) (newNames : List Name := []) (maxDepth : Nat := 5) (forbidden : FVarIdSet := {}) : MetaM InjectionsResult :=
partial def injections (mvarId : MVarId) (newNames : List Name := []) (maxDepth : Nat := 5) : MetaM InjectionsResult :=
mvarId.withContext do
let fvarIds := ( getLCtx).getFVarIds
go maxDepth fvarIds.toList mvarId newNames forbidden
go maxDepth fvarIds.toList mvarId newNames
where
go (depth : Nat) (fvarIds : List FVarId) (mvarId : MVarId) (newNames : List Name) (forbidden : FVarIdSet) : MetaM InjectionsResult := do
match depth, fvarIds with
| 0, _ => throwTacticEx `injections mvarId "recursion depth exceeded"
| _, [] => return .subgoal mvarId newNames forbidden
| d+1, fvarId :: fvarIds => do
go : Nat List FVarId MVarId List Name MetaM InjectionsResult
| 0, _, _, _ => throwTacticEx `injections mvarId "recursion depth exceeded"
| _, [], mvarId, newNames => return .subgoal mvarId newNames
| d+1, fvarId :: fvarIds, mvarId, newNames => do
let cont := do
go (d+1) fvarIds mvarId newNames forbidden
if forbidden.contains fvarId then
cont
else if let some (_, lhs, rhs) matchEqHEq? ( fvarId.getType) then
go (d+1) fvarIds mvarId newNames
if let some (_, lhs, rhs) matchEqHEq? ( fvarId.getType) then
let lhs whnf lhs
let rhs whnf rhs
if lhs.isRawNatLit && rhs.isRawNatLit then
cont
if lhs.isRawNatLit && rhs.isRawNatLit then cont
else
try
commitIfNoEx do
match ( injection mvarId fvarId newNames) with
| .solved => return .solved
| .subgoal mvarId newEqs remainingNames =>
mvarId.withContext <| go d (newEqs.toList ++ fvarIds) mvarId remainingNames (forbidden.insert fvarId)
mvarId.withContext <| go d (newEqs.toList ++ fvarIds) mvarId remainingNames
catch _ => cont
else cont

View File

@@ -43,7 +43,6 @@ def _root_.Lean.MVarId.revert (mvarId : MVarId) (fvarIds : Array FVarId) (preser
finally
mvarId.setKind .syntheticOpaque
let mvar := e.getAppFn
mvar.mvarId!.setKind .syntheticOpaque
mvar.mvarId!.setTag tag
return (toRevert.map Expr.fvarId!, mvar.mvarId!)

View File

@@ -45,26 +45,12 @@ def _root_.Lean.MVarId.rewrite (mvarId : MVarId) (e : Expr) (heq : Expr)
let eNew instantiateMVars eNew
let eType inferType e
let motive := Lean.mkLambda `_a BinderInfo.default α eAbst
try
check motive
catch ex =>
throwTacticEx `rewrite mvarId m!"\
motive is not type correct:{indentD motive}\nError: {ex.toMessageData}\
\n\n\
Explanation: The rewrite tactic rewrites an expression 'e' using an equality 'a = b' by the following process. \
First, it looks for all 'a' in 'e'. Second, it tries to abstract these occurrences of 'a' to create a function 'm := fun _a => ...', called the *motive*, \
with the property that 'm a' is definitionally equal to 'e'. \
Third, we observe that '{.ofConstName ``congrArg}' implies that 'm a = m b', which can be used with lemmas such as '{.ofConstName ``Eq.mpr}' to change the goal. \
However, if 'e' depends on specific properties of 'a', then the motive 'm' might not typecheck.\
\n\n\
Possible solutions: use rewrite's 'occs' configuration option to limit which occurrences are rewritten, \
or use 'simp' or 'conv' mode, which have strategies for certain kinds of dependencies \
(these tactics can handle proofs and '{.ofConstName ``Decidable}' instances whose types depend on the rewritten term, \
and 'simp' can apply user-defined '@[congr]' theorems as well)."
unless ( isTypeCorrect motive) do
throwTacticEx `rewrite mvarId "motive is not type correct"
unless ( withLocalDeclD `_a α fun a => do isDefEq ( inferType (eAbst.instantiate1 a)) eType) do
-- NB: using motive.arrow? would disallow motives where the dependency
-- can be reduced away
throwTacticEx `rewrite mvarId m!"motive is dependent{indentD motive}"
throwTacticEx `rewrite mvarId "motive is dependent"
let u1 getLevel α
let u2 getLevel eType
let eqPrf := mkApp6 (.const ``congrArg [u1, u2]) α eType lhs rhs motive heq

View File

@@ -571,127 +571,10 @@ def congr (e : Expr) : SimpM Result := do
else
congrDefault e
/--
Returns `true` if `e` is of the form `@letFun _ (fun _ => β) _ _`
`β` must not contain loose bound variables. Recall that `simp` does not have support for `let_fun`s
where resulting type depends on the `let`-value. Example:
```
let_fun n := 10;
BitVec.zero n
```
-/
def isNonDepLetFun (e : Expr) : Bool :=
let_expr letFun _ beta _ body := e | false
beta.isLambda && !beta.bindingBody!.hasLooseBVars && body.isLambda
/--
Auxiliary structure used to represent the return value of `simpNonDepLetFun.go`.
-/
private structure SimpLetFunResult where
/--
The simplified expression. Note that is may contain loose bound variables.
`simpNonDepLetFun.go` attempts to minimize the quadratic overhead imposed
by the locally nameless discipline in a sequence of `let_fun` declarations.
-/
expr : Expr
/--
The proof that the simplified expression is equal to the input one.
It may containt loose bound variables. See `expr` field.
-/
proof : Expr
/--
`modified := false` iff `expr` is structurally equal to the input expression.
-/
modified : Bool
/--
Simplifies a sequence of `let_fun` declarations.
It attempts to minimize the quadratic overhead imposed by
the locally nameless discipline.
-/
partial def simpNonDepLetFun (e : Expr) : SimpM Result := do
let rec go (xs : Array Expr) (e : Expr) : SimpM SimpLetFunResult := do
/-
Helper function applied when `e` is not a `let_fun` or
is a non supported `let_fun` (e.g., the resulting type depends on the value).
-/
let stop : SimpM SimpLetFunResult := do
let e := e.instantiateRev xs
let r simp e
return { expr := r.expr.abstract xs, proof := ( r.getProof).abstract xs, modified := r.expr != e }
let_expr f@letFun alpha betaFun val body := e | stop
let us := f.constLevels!
let [_, v] := us | stop
/-
Recall that `let_fun x : α := val; e[x]` is encoded at
```
@letFun α (fun x : α => β[x]) val (fun x : α => e[x])
```
`betaFun` is `(fun x : α => β[x])`. If `β[x]` does not have loose bound variables then the resulting type
does not depend on the value since it does not depend on `x`.
We also check whether `alpha` does not depend on the previous `let_fun`s in the sequence.
This check is just to make the code simpler. It is not common to have a type depending on the value of a previous `let_fun`.
-/
if alpha.hasLooseBVars || !betaFun.isLambda || !body.isLambda || betaFun.bindingBody!.hasLooseBVars then
stop
else if !body.bindingBody!.hasLooseBVar 0 then
/-
Redundant `let_fun`. The simplifier will remove it.
Remark: the `hasLooseBVar` check here may introduce a quadratic overhead in the worst case.
If observe that in practice, we may use a separate step for removing unused variables.
Remark: note that we do **not** simplify the value in this case.
-/
let x := mkConst `__no_used_dummy__ -- dummy value
let { expr, proof, .. } go (xs.push x) body.bindingBody!
let proof := mkApp6 (mkConst ``letFun_unused us) alpha betaFun.bindingBody! val body.bindingBody! expr proof
return { expr, proof, modified := true }
else
let beta := betaFun.bindingBody!
let valInst := val.instantiateRev xs
let valResult simp valInst
withLocalDecl body.bindingName! body.bindingInfo! alpha fun x => do
let valIsNew := valResult.expr != valInst
let { expr, proof, modified := bodyIsNew } go (xs.push x) body.bindingBody!
if !valIsNew && !bodyIsNew then
/-
Value and body were not simplified. We just return `e` and a new refl proof.
We must use the low-level `Expr` APIs because `e` may contain loose bound variables.
-/
let proof := mkApp2 (mkConst ``Eq.refl [v]) beta e
return { expr := e, proof, modified := false }
else
let body' := mkLambda body.bindingName! body.bindingInfo! alpha expr
let val' := valResult.expr.abstract xs
let e' := mkApp4 f alpha betaFun val' body'
if valIsNew && bodyIsNew then
-- Value and body were simplified
let valProof := ( valResult.getProof).abstract xs
let proof := mkApp8 (mkConst ``letFun_congr us) alpha beta val val' body body' valProof (mkLambda body.bindingName! body.bindingInfo! alpha proof)
return { expr := e', proof, modified := true }
else if valIsNew then
-- Only the value was simplified.
let valProof := ( valResult.getProof).abstract xs
let proof := mkApp6 (mkConst ``letFun_val_congr us) alpha beta val val' body valProof
return { expr := e', proof, modified := true }
else
-- Only the body was simplified.
let proof := mkApp6 (mkConst ``letFun_body_congr us) alpha beta val body body' (mkLambda body.bindingName! body.bindingInfo! alpha proof)
return { expr := e', proof, modified := true }
let { expr, proof, modified } go #[] e
if !modified then
return { expr := e }
else
return { expr, proof? := proof }
def simpApp (e : Expr) : SimpM Result := do
if isOfNatNatLit e || isOfScientificLit e || isCharLit e then
-- Recall that we fold "orphan" kernel Nat literals `n` into `OfNat.ofNat n`
return { expr := e }
else if isNonDepLetFun e then
simpNonDepLetFun e
else
congr e

View File

@@ -146,8 +146,12 @@ def mkContext (config : Config := {}) (simpTheorems : SimpTheoremsArray := {}) (
indexConfig := mkIndexConfig config
}
def Context.setConfig (context : Context) (config : Config) : Context :=
{ context with config }
def Context.setConfig (context : Context) (config : Config) : MetaM Context := do
return { context with
config
metaConfig := ( mkMetaConfig config)
indexConfig := ( mkIndexConfig config)
}
def Context.setSimpTheorems (c : Context) (simpTheorems : SimpTheoremsArray) : Context :=
{ c with simpTheorems }

View File

@@ -80,9 +80,9 @@ def substCore (mvarId : MVarId) (hFVarId : FVarId) (symm := false) (fvarSubst :
pure mvarId
let (newFVars, mvarId) mvarId.introNP (vars.size - 2)
trace[Meta.Tactic.subst] "after intro rest {vars.size - 2} {MessageData.ofGoal mvarId}"
let fvarSubst newFVars.size.foldM (init := fvarSubst) fun i _ (fvarSubst : FVarSubst) =>
let fvarSubst newFVars.size.foldM (init := fvarSubst) fun i (fvarSubst : FVarSubst) =>
let var := vars[i+2]!
let newFVar := newFVars[i]
let newFVar := newFVars[i]!
pure $ fvarSubst.insert var (mkFVar newFVar)
let fvarSubst := fvarSubst.insert aFVarIdOriginal (if clearH then b else mkFVar aFVarId)
let fvarSubst := fvarSubst.insert hFVarIdOriginal (mkFVar hFVarId)

View File

@@ -979,10 +979,10 @@ def collectForwardDeps (lctx : LocalContext) (toRevert : Array Expr) : M (Array
else
if ( preserveOrder) then
-- Make sure toRevert[j] does not depend on toRevert[i] for j < i
toRevert.size.forM fun i _ => do
let fvar := toRevert[i]
i.forM fun j _ => do
let prevFVar := toRevert[j]
toRevert.size.forM fun i => do
let fvar := toRevert[i]!
i.forM fun j => do
let prevFVar := toRevert[j]!
let prevDecl := lctx.getFVar! prevFVar
if ( localDeclDependsOn prevDecl fvar.fvarId!) then
throw (Exception.revertFailure ( getMCtx) lctx toRevert prevDecl.userName.toString)
@@ -990,7 +990,7 @@ def collectForwardDeps (lctx : LocalContext) (toRevert : Array Expr) : M (Array
let firstDeclToVisit := getLocalDeclWithSmallestIdx lctx toRevert
let initSize := newToRevert.size
lctx.foldlM (init := newToRevert) (start := firstDeclToVisit.index) fun (newToRevert : Array Expr) decl => do
if initSize.any fun i _ => decl.fvarId == newToRevert[i]!.fvarId! then
if initSize.any fun i => decl.fvarId == newToRevert[i]!.fvarId! then
return newToRevert
else if toRevert.any fun x => decl.fvarId == x.fvarId! then
return newToRevert.push decl.toExpr
@@ -1061,8 +1061,8 @@ mutual
-/
private partial def mkAuxMVarType (lctx : LocalContext) (xs : Array Expr) (kind : MetavarKind) (e : Expr) (usedLetOnly : Bool) : M Expr := do
let e abstractRangeAux xs xs.size e
xs.size.foldRevM (init := e) fun i _ e => do
let x := xs[i]
xs.size.foldRevM (init := e) fun i e => do
let x := xs[i]!
if x.isFVar then
match lctx.getFVar! x with
| LocalDecl.cdecl _ _ n type bi _ =>
@@ -1231,8 +1231,8 @@ private def mkLambda' (x : Name) (bi : BinderInfo) (t : Expr) (b : Expr) (etaRed
If `usedLetOnly == true` then `let` expressions are created only for used (let-) variables. -/
def mkBinding (isLambda : Bool) (lctx : LocalContext) (xs : Array Expr) (e : Expr) (usedOnly : Bool) (usedLetOnly : Bool) (etaReduce : Bool) : M Expr := do
let e abstractRange xs xs.size e
xs.size.foldRevM (init := e) fun i _ e => do
let x := xs[i]
xs.size.foldRevM (init := e) fun i e => do
let x := xs[i]!
if x.isFVar then
match lctx.getFVar! x with
| LocalDecl.cdecl _ _ n type bi _ =>

View File

@@ -91,12 +91,9 @@ open Delaborator in
/-- Pretty-prints a declaration `c` as `c.{<levels>} <params> : <type>`. -/
def ppSignature (c : Name) : MetaM FormatWithInfos := do
let decl getConstInfo c
let e := Expr.const c (decl.levelParams.map mkLevelParam)
if pp.raw.get ( getOptions) then
return s!"{e} : {decl.type}"
else
let (stx, infos) delabCore e (delab := delabConstWithSignature)
return ppTerm stx, infos -- HACK: not a term
let e := .const c (decl.levelParams.map mkLevelParam)
let (stx, infos) delabCore e (delab := delabConstWithSignature)
return ppTerm stx, infos -- HACK: not a term
private partial def noContext : MessageData MessageData
| MessageData.withContext _ msg => noContext msg

View File

@@ -1115,18 +1115,17 @@ def coeDelaborator : Delab := whenPPOption getPPCoercions do
delabAppCore nargs (delabHead info nargs) (unexpand := false)
where
delabHead (info : CoeFnInfo) (nargs : Nat) (insertExplicit : Bool) : Delab := do
withTypeAscription (cond := getPPOption getPPCoercionsTypes) do
guard <| !insertExplicit
if info.type == .coeFun && nargs > 0 then
-- In the CoeFun case, annotate with the coercee itself.
-- We can still see the whole coercion expression by hovering over the whitespace between the arguments.
withNaryArg info.coercee <| withAnnotateTermInfo delab
else
withAnnotateTermInfo do
match info.type with
| .coe => `($( withNaryArg info.coercee delab))
| .coeFun => `($( withNaryArg info.coercee delab))
| .coeSort => `($( withNaryArg info.coercee delab))
guard <| !insertExplicit
if info.type == .coeFun && nargs > 0 then
-- In the CoeFun case, annotate with the coercee itself.
-- We can still see the whole coercion expression by hovering over the whitespace between the arguments.
withNaryArg info.coercee <| withAnnotateTermInfo delab
else
withAnnotateTermInfo do
match info.type with
| .coe => `($( withNaryArg info.coercee delab))
| .coeFun => `($( withNaryArg info.coercee delab))
| .coeSort => `($( withNaryArg info.coercee delab))
@[builtin_delab app.dite]
def delabDIte : Delab := whenNotPPOption getPPExplicit <| whenPPOption getPPNotation <| withOverApp 5 do

View File

@@ -44,11 +44,6 @@ register_builtin_option pp.coercions : Bool := {
group := "pp"
descr := "(pretty printer) hide coercion applications"
}
register_builtin_option pp.coercions.types : Bool := {
defValue := false
group := "pp"
descr := "(pretty printer) display coercion applications with a type ascription"
}
register_builtin_option pp.universes : Bool := {
defValue := false
group := "pp"
@@ -256,7 +251,6 @@ def getPPLetVarTypes (o : Options) : Bool := o.get pp.letVarTypes.name (getPPAll
def getPPNumericTypes (o : Options) : Bool := o.get pp.numericTypes.name pp.numericTypes.defValue
def getPPNatLit (o : Options) : Bool := o.get pp.natLit.name (getPPNumericTypes o && !getPPAll o)
def getPPCoercions (o : Options) : Bool := o.get pp.coercions.name (!getPPAll o)
def getPPCoercionsTypes (o : Options) : Bool := o.get pp.coercions.types.name pp.coercions.types.defValue
def getPPExplicit (o : Options) : Bool := o.get pp.explicit.name (getPPAll o)
def getPPNotation (o : Options) : Bool := o.get pp.notation.name (!getPPAll o)
def getPPParens (o : Options) : Bool := o.get pp.parens.name pp.parens.defValue

View File

@@ -467,7 +467,7 @@ def visitAtom (k : SyntaxNodeKind) : Formatter := do
@[combinator_formatter manyNoAntiquot]
def manyNoAntiquot.formatter (p : Formatter) : Formatter := do
let stx getCur
visitArgs $ stx.getArgs.size.forM fun _ _ => p
visitArgs $ stx.getArgs.size.forM fun _ => p
@[combinator_formatter many1NoAntiquot] def many1NoAntiquot.formatter (p : Formatter) : Formatter := manyNoAntiquot.formatter p
@@ -487,7 +487,7 @@ def many1Unbox.formatter (p : Formatter) : Formatter := do
@[combinator_formatter sepByNoAntiquot]
def sepByNoAntiquot.formatter (p pSep : Formatter) : Formatter := do
let stx getCur
visitArgs <| stx.getArgs.size.forRevM fun i _ => if i % 2 == 0 then p else pSep
visitArgs <| stx.getArgs.size.forRevM fun i => if i % 2 == 0 then p else pSep
@[combinator_formatter sepBy1NoAntiquot] def sepBy1NoAntiquot.formatter := sepByNoAntiquot.formatter

View File

@@ -268,7 +268,7 @@ def visitToken : Parenthesizer := do
let stx getCur
-- `orelse` may produce `choice` nodes for antiquotations
if stx.getKind == `choice then
visitArgs $ stx.getArgs.size.forM fun _ _ => do
visitArgs $ stx.getArgs.size.forM fun _ => do
orelse.parenthesizer p1 p2
else
-- HACK: We have no (immediate) information on which side of the orelse could have produced the current node, so try
@@ -332,7 +332,7 @@ partial def parenthesizeCategoryCore (cat : Name) (_prec : Nat) : Parenthesizer
withReader (fun ctx => { ctx with cat := cat }) do
let stx getCur
if stx.getKind == `choice then
visitArgs $ stx.getArgs.size.forM fun _ _ => do
visitArgs $ stx.getArgs.size.forM fun _ => do
parenthesizeCategoryCore cat _prec
else
withAntiquot.parenthesizer (mkAntiquot.parenthesizer' cat.toString cat (isPseudoKind := true)) (parenthesizerForKind stx.getKind)
@@ -470,7 +470,7 @@ def trailingNode.parenthesizer (k : SyntaxNodeKind) (prec lhsPrec : Nat) (p : Pa
@[combinator_parenthesizer manyNoAntiquot]
def manyNoAntiquot.parenthesizer (p : Parenthesizer) : Parenthesizer := do
let stx getCur
visitArgs $ stx.getArgs.size.forM fun _ _ => p
visitArgs $ stx.getArgs.size.forM fun _ => p
@[combinator_parenthesizer many1NoAntiquot]
def many1NoAntiquot.parenthesizer (p : Parenthesizer) : Parenthesizer := do

Some files were not shown because too many files have changed in this diff Show More