mirror of
https://github.com/leanprover/lean4.git
synced 2026-03-20 11:54:07 +00:00
Compare commits
1 Commits
insertionS
...
checkConfi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7d0769a274 |
24
.github/workflows/labels-from-comments.yml
vendored
24
.github/workflows/labels-from-comments.yml
vendored
@@ -1,8 +1,7 @@
|
||||
# This workflow allows any user to add one of the `awaiting-review`, `awaiting-author`, `WIP`,
|
||||
# `release-ci`, or a `changelog-XXX` label by commenting on the PR or issue.
|
||||
# or `release-ci` labels by commenting on the PR or issue.
|
||||
# If any labels from the set {`awaiting-review`, `awaiting-author`, `WIP`} are added, other labels
|
||||
# from that set are removed automatically at the same time.
|
||||
# Similarly, if any `changelog-XXX` label is added, other `changelog-YYY` labels are removed.
|
||||
|
||||
name: Label PR based on Comment
|
||||
|
||||
@@ -12,7 +11,7 @@ on:
|
||||
|
||||
jobs:
|
||||
update-label:
|
||||
if: github.event.issue.pull_request != null && (contains(github.event.comment.body, 'awaiting-review') || contains(github.event.comment.body, 'awaiting-author') || contains(github.event.comment.body, 'WIP') || contains(github.event.comment.body, 'release-ci') || contains(github.event.comment.body, 'changelog-'))
|
||||
if: github.event.issue.pull_request != null && (contains(github.event.comment.body, 'awaiting-review') || contains(github.event.comment.body, 'awaiting-author') || contains(github.event.comment.body, 'WIP') || contains(github.event.comment.body, 'release-ci'))
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
@@ -21,14 +20,13 @@ jobs:
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
const { owner, repo, number: issue_number } = context.issue;
|
||||
const { owner, repo, number: issue_number } = context.issue;
|
||||
const commentLines = context.payload.comment.body.split('\r\n');
|
||||
|
||||
const awaitingReview = commentLines.includes('awaiting-review');
|
||||
const awaitingAuthor = commentLines.includes('awaiting-author');
|
||||
const wip = commentLines.includes('WIP');
|
||||
const releaseCI = commentLines.includes('release-ci');
|
||||
const changelogMatch = commentLines.find(line => line.startsWith('changelog-'));
|
||||
|
||||
if (awaitingReview || awaitingAuthor || wip) {
|
||||
await github.rest.issues.removeLabel({ owner, repo, issue_number, name: 'awaiting-review' }).catch(() => {});
|
||||
@@ -49,19 +47,3 @@ jobs:
|
||||
if (releaseCI) {
|
||||
await github.rest.issues.addLabels({ owner, repo, issue_number, labels: ['release-ci'] });
|
||||
}
|
||||
|
||||
if (changelogMatch) {
|
||||
const changelogLabel = changelogMatch.trim();
|
||||
const { data: existingLabels } = await github.rest.issues.listLabelsOnIssue({ owner, repo, issue_number });
|
||||
const changelogLabels = existingLabels.filter(label => label.name.startsWith('changelog-'));
|
||||
|
||||
// Remove all other changelog labels
|
||||
for (const label of changelogLabels) {
|
||||
if (label.name !== changelogLabel) {
|
||||
await github.rest.issues.removeLabel({ owner, repo, issue_number, name: label.name }).catch(() => {});
|
||||
}
|
||||
}
|
||||
|
||||
// Add the new changelog label
|
||||
await github.rest.issues.addLabels({ owner, repo, issue_number, labels: [changelogLabel] });
|
||||
}
|
||||
|
||||
@@ -103,21 +103,10 @@ your PR using rebase merge, bypassing the merge queue.
|
||||
As written above, changes in meta code in the current stage usually will only
|
||||
affect later stages. This is an issue in two specific cases.
|
||||
|
||||
* For the special case of *quotations*, it is desirable to have changes in builtin parsers affect them immediately: when the changes in the parser become active in the next stage, builtin macros implemented via quotations should generate syntax trees compatible with the new parser, and quotation patterns in builtin macros and elaborators should be able to match syntax created by the new parser and macros.
|
||||
Since quotations capture the syntax tree structure during execution of the current stage and turn it into code for the next stage, we need to run the current stage's builtin parsers in quotations via the interpreter for this to work.
|
||||
Caveats:
|
||||
* We activate this behavior by default when building stage 1 by setting `-Dinternal.parseQuotWithCurrentStage=true`.
|
||||
We force-disable it inside `macro/macro_rules/elab/elab_rules` via `suppressInsideQuot` as they are guaranteed not to run in the next stage and may need to be run in the current one, so the stage 0 parser is the correct one to use for them.
|
||||
It may be necessary to extend this disabling to functions that contain quotations and are (exclusively) used by one of the mentioned commands. A function using quotations should never be used by both builtin and non-builtin macros/elaborators. Example: https://github.com/leanprover/lean4/blob/f70b7e5722da6101572869d87832494e2f8534b7/src/Lean/Elab/Tactic/Config.lean#L118-L122
|
||||
* The parser needs to be reachable via an `import` statement, otherwise the version of the previous stage will silently be used.
|
||||
* Only the parser code (`Parser.fn`) is affected; all metadata such as leading tokens is taken from the previous stage.
|
||||
|
||||
For an example, see https://github.com/leanprover/lean4/commit/f9dcbbddc48ccab22c7674ba20c5f409823b4cc1#diff-371387aed38bb02bf7761084fd9460e4168ae16d1ffe5de041b47d3ad2d22422R13
|
||||
|
||||
* For *non-builtin* meta code such as `notation`s or `macro`s in
|
||||
`Notation.lean`, we expect changes to affect the current file and all later
|
||||
files of the same stage immediately, just like outside the stdlib. To ensure
|
||||
this, we build stage 1 using `-Dinterpreter.prefer_native=false` -
|
||||
this, we need to build the stage using `-Dinterpreter.prefer_native=false` -
|
||||
otherwise, when executing a macro, the interpreter would notice that there is
|
||||
already a native symbol available for this function and run it instead of the
|
||||
new IR, but the symbol is from the previous stage!
|
||||
@@ -135,11 +124,26 @@ affect later stages. This is an issue in two specific cases.
|
||||
further stages (e.g. after an `update-stage0`) will then need to be compiled
|
||||
with the flag set to `false` again since they will expect the new signature.
|
||||
|
||||
When enabling `prefer_native`, we usually want to *disable* `parseQuotWithCurrentStage` as it would otherwise make quotations use the interpreter after all.
|
||||
However, there is a specific case where we want to set both options to `true`: when we make changes to a non-builtin parser like `simp` that has a builtin elaborator, we cannot have the new parser be active outside of quotations in stage 1 as the builtin elaborator from stage 0 would not understand them; on the other hand, we need quotations in e.g. the builtin `simp` elaborator to produce the new syntax in the next stage.
|
||||
As this issue usually affects only tactics, enabling `debug.byAsSorry` instead of `prefer_native` can be a simpler solution.
|
||||
For an example, see https://github.com/leanprover/lean4/commit/da4c46370d85add64ef7ca5e7cc4638b62823fbb.
|
||||
|
||||
For a `prefer_native` example, see https://github.com/leanprover/lean4/commit/da4c46370d85add64ef7ca5e7cc4638b62823fbb.
|
||||
* For the special case of *quotations*, it is desirable to have changes in
|
||||
built-in parsers affect them immediately: when the changes in the parser
|
||||
become active in the next stage, macros implemented via quotations should
|
||||
generate syntax trees compatible with the new parser, and quotation patterns
|
||||
in macro and elaborators should be able to match syntax created by the new
|
||||
parser and macros. Since quotations capture the syntax tree structure during
|
||||
execution of the current stage and turn it into code for the next stage, we
|
||||
need to run the current stage's built-in parsers in quotation via the
|
||||
interpreter for this to work. Caveats:
|
||||
* Since interpreting full parsers is not nearly as cheap and we rarely change
|
||||
built-in syntax, this needs to be opted in using `-Dinternal.parseQuotWithCurrentStage=true`.
|
||||
* The parser needs to be reachable via an `import` statement, otherwise the
|
||||
version of the previous stage will silently be used.
|
||||
* Only the parser code (`Parser.fn`) is affected; all metadata such as leading
|
||||
tokens is taken from the previous stage.
|
||||
|
||||
For an example, see https://github.com/leanprover/lean4/commit/f9dcbbddc48ccab22c7674ba20c5f409823b4cc1#diff-371387aed38bb02bf7761084fd9460e4168ae16d1ffe5de041b47d3ad2d22422
|
||||
(from before the flag defaulted to `false`).
|
||||
|
||||
To modify either of these flags both for building and editing the stdlib, adjust
|
||||
the code in `stage0/src/stdlib_flags.h`. The flags will automatically be reset
|
||||
|
||||
@@ -12,17 +12,17 @@ Remark: this example is based on an example found in the Idris manual.
|
||||
Vectors
|
||||
--------
|
||||
|
||||
A `Vec` is a list of size `n` whose elements belong to a type `α`.
|
||||
A `Vector` is a list of size `n` whose elements belong to a type `α`.
|
||||
-/
|
||||
|
||||
inductive Vec (α : Type u) : Nat → Type u
|
||||
| nil : Vec α 0
|
||||
| cons : α → Vec α n → Vec α (n+1)
|
||||
inductive Vector (α : Type u) : Nat → Type u
|
||||
| nil : Vector α 0
|
||||
| cons : α → Vector α n → Vector α (n+1)
|
||||
|
||||
/-!
|
||||
We can overload the `List.cons` notation `::` and use it to create `Vec`s.
|
||||
We can overload the `List.cons` notation `::` and use it to create `Vector`s.
|
||||
-/
|
||||
infix:67 " :: " => Vec.cons
|
||||
infix:67 " :: " => Vector.cons
|
||||
|
||||
/-!
|
||||
Now, we define the types of our simple functional language.
|
||||
@@ -50,11 +50,11 @@ the builtin instance for `Add Int` as the solution.
|
||||
/-!
|
||||
Expressions are indexed by the types of the local variables, and the type of the expression itself.
|
||||
-/
|
||||
inductive HasType : Fin n → Vec Ty n → Ty → Type where
|
||||
inductive HasType : Fin n → Vector Ty n → Ty → Type where
|
||||
| stop : HasType 0 (ty :: ctx) ty
|
||||
| pop : HasType k ctx ty → HasType k.succ (u :: ctx) ty
|
||||
|
||||
inductive Expr : Vec Ty n → Ty → Type where
|
||||
inductive Expr : Vector Ty n → Ty → Type where
|
||||
| var : HasType i ctx ty → Expr ctx ty
|
||||
| val : Int → Expr ctx Ty.int
|
||||
| lam : Expr (a :: ctx) ty → Expr ctx (Ty.fn a ty)
|
||||
@@ -102,8 +102,8 @@ indexed over the types in scope. Since an environment is just another form of li
|
||||
to the vector of local variable types, we overload again the notation `::` so that we can use the usual list syntax.
|
||||
Given a proof that a variable is defined in the context, we can then produce a value from the environment.
|
||||
-/
|
||||
inductive Env : Vec Ty n → Type where
|
||||
| nil : Env Vec.nil
|
||||
inductive Env : Vector Ty n → Type where
|
||||
| nil : Env Vector.nil
|
||||
| cons : Ty.interp a → Env ctx → Env (a :: ctx)
|
||||
|
||||
infix:67 " :: " => Env.cons
|
||||
|
||||
@@ -82,7 +82,9 @@ theorem Expr.typeCheck_correct (h₁ : HasType e ty) (h₂ : e.typeCheck ≠ .un
|
||||
/-!
|
||||
Now, we prove that if `Expr.typeCheck e` returns `Maybe.unknown`, then forall `ty`, `HasType e ty` does not hold.
|
||||
The notation `e.typeCheck` is sugar for `Expr.typeCheck e`. Lean can infer this because we explicitly said that `e` has type `Expr`.
|
||||
The proof is by induction on `e` and case analysis. Note that the tactic `simp [typeCheck]` is applied to all goal generated by the `induction` tactic, and closes
|
||||
The proof is by induction on `e` and case analysis. The tactic `rename_i` is used to rename "inaccessible" variables.
|
||||
We say a variable is inaccessible if it is introduced by a tactic (e.g., `cases`) or has been shadowed by another variable introduced
|
||||
by the user. Note that the tactic `simp [typeCheck]` is applied to all goal generated by the `induction` tactic, and closes
|
||||
the cases corresponding to the constructors `Expr.nat` and `Expr.bool`.
|
||||
-/
|
||||
theorem Expr.typeCheck_complete {e : Expr} : e.typeCheck = .unknown → ¬ HasType e ty := by
|
||||
|
||||
@@ -51,8 +51,6 @@ option(LLVM "LLVM" OFF)
|
||||
option(USE_GITHASH "GIT_HASH" ON)
|
||||
# When ON we install LICENSE files to CMAKE_INSTALL_PREFIX
|
||||
option(INSTALL_LICENSE "INSTALL_LICENSE" ON)
|
||||
# When ON we install a copy of cadical
|
||||
option(INSTALL_CADICAL "Install a copy of cadical" ON)
|
||||
# When ON thread storage is automatically finalized, it assumes platform support pthreads.
|
||||
# This option is important when using Lean as library that is invoked from a different programming language (e.g., Haskell).
|
||||
option(AUTO_THREAD_FINALIZATION "AUTO_THREAD_FINALIZATION" ON)
|
||||
@@ -618,7 +616,7 @@ else()
|
||||
OUTPUT_NAME leancpp)
|
||||
endif()
|
||||
|
||||
if((${STAGE} GREATER 0) AND CADICAL AND INSTALL_CADICAL)
|
||||
if((${STAGE} GREATER 0) AND CADICAL)
|
||||
add_custom_target(copy-cadical
|
||||
COMMAND cmake -E copy_if_different "${CADICAL}" "${CMAKE_BINARY_DIR}/bin/cadical${CMAKE_EXECUTABLE_SUFFIX}")
|
||||
add_dependencies(leancpp copy-cadical)
|
||||
@@ -740,7 +738,7 @@ file(COPY ${LEAN_SOURCE_DIR}/bin/leanmake DESTINATION ${CMAKE_BINARY_DIR}/bin)
|
||||
|
||||
install(DIRECTORY "${CMAKE_BINARY_DIR}/bin/" USE_SOURCE_PERMISSIONS DESTINATION bin)
|
||||
|
||||
if (${STAGE} GREATER 0 AND CADICAL AND INSTALL_CADICAL)
|
||||
if (${STAGE} GREATER 0 AND CADICAL)
|
||||
install(PROGRAMS "${CADICAL}" DESTINATION bin)
|
||||
endif()
|
||||
|
||||
|
||||
@@ -43,4 +43,3 @@ import Init.Data.Zero
|
||||
import Init.Data.NeZero
|
||||
import Init.Data.Function
|
||||
import Init.Data.RArray
|
||||
import Init.Data.Vector
|
||||
|
||||
@@ -19,4 +19,3 @@ import Init.Data.Array.GetLit
|
||||
import Init.Data.Array.MapIdx
|
||||
import Init.Data.Array.Set
|
||||
import Init.Data.Array.Monadic
|
||||
import Init.Data.Array.FinRange
|
||||
|
||||
@@ -13,7 +13,6 @@ import Init.Data.ToString.Basic
|
||||
import Init.GetElem
|
||||
import Init.Data.List.ToArray
|
||||
import Init.Data.Array.Set
|
||||
|
||||
universe u v w
|
||||
|
||||
/-! ### Array literal syntax -/
|
||||
@@ -166,15 +165,15 @@ This will perform the update destructively provided that `a` has a reference
|
||||
count of 1 when called.
|
||||
-/
|
||||
@[extern "lean_array_fswap"]
|
||||
def swap (a : Array α) (i j : @& Nat) (hi : i < a.size := by get_elem_tactic) (hj : j < a.size := by get_elem_tactic) : Array α :=
|
||||
def swap (a : Array α) (i j : @& Fin a.size) : Array α :=
|
||||
let v₁ := a[i]
|
||||
let v₂ := a[j]
|
||||
let a' := a.set i v₂
|
||||
a'.set j v₁ (Nat.lt_of_lt_of_eq hj (size_set a i v₂ _).symm)
|
||||
a'.set j v₁ (Nat.lt_of_lt_of_eq j.isLt (size_set a i v₂ _).symm)
|
||||
|
||||
@[simp] theorem size_swap (a : Array α) (i j : Nat) {hi hj} : (a.swap i j hi hj).size = a.size := by
|
||||
@[simp] theorem size_swap (a : Array α) (i j : Fin a.size) : (a.swap i j).size = a.size := by
|
||||
show ((a.set i a[j]).set j a[i]
|
||||
(Nat.lt_of_lt_of_eq hj (size_set a i a[j] _).symm)).size = a.size
|
||||
(Nat.lt_of_lt_of_eq j.isLt (size_set a i a[j] _).symm)).size = a.size
|
||||
rw [size_set, size_set]
|
||||
|
||||
/--
|
||||
@@ -184,14 +183,12 @@ This will perform the update destructively provided that `a` has a reference
|
||||
count of 1 when called.
|
||||
-/
|
||||
@[extern "lean_array_swap"]
|
||||
def swapIfInBounds (a : Array α) (i j : @& Nat) : Array α :=
|
||||
def swap! (a : Array α) (i j : @& Nat) : Array α :=
|
||||
if h₁ : i < a.size then
|
||||
if h₂ : j < a.size then swap a i j
|
||||
if h₂ : j < a.size then swap a ⟨i, h₁⟩ ⟨j, h₂⟩
|
||||
else a
|
||||
else a
|
||||
|
||||
@[deprecated swapIfInBounds (since := "2024-11-24")] abbrev swap! := @swapIfInBounds
|
||||
|
||||
/-! ### GetElem instance for `USize`, backed by `uget` -/
|
||||
|
||||
instance : GetElem (Array α) USize α fun xs i => i.toNat < xs.size where
|
||||
@@ -236,7 +233,7 @@ def ofFn {n} (f : Fin n → α) : Array α := go 0 (mkEmpty n) where
|
||||
|
||||
/-- The array `#[0, 1, ..., n - 1]`. -/
|
||||
def range (n : Nat) : Array Nat :=
|
||||
ofFn fun (i : Fin n) => i
|
||||
n.fold (flip Array.push) (mkEmpty n)
|
||||
|
||||
def singleton (v : α) : Array α :=
|
||||
mkArray 1 v
|
||||
@@ -252,7 +249,7 @@ def get? (a : Array α) (i : Nat) : Option α :=
|
||||
def back? (a : Array α) : Option α :=
|
||||
a[a.size - 1]?
|
||||
|
||||
@[inline] def swapAt (a : Array α) (i : Nat) (v : α) (hi : i < a.size := by get_elem_tactic) : α × Array α :=
|
||||
@[inline] def swapAt (a : Array α) (i : Fin a.size) (v : α) : α × Array α :=
|
||||
let e := a[i]
|
||||
let a := a.set i v
|
||||
(e, a)
|
||||
@@ -260,7 +257,7 @@ def back? (a : Array α) : Option α :=
|
||||
@[inline]
|
||||
def swapAt! (a : Array α) (i : Nat) (v : α) : α × Array α :=
|
||||
if h : i < a.size then
|
||||
swapAt a i v
|
||||
swapAt a ⟨i, h⟩ v
|
||||
else
|
||||
have : Inhabited (α × Array α) := ⟨(v, a)⟩
|
||||
panic! ("index " ++ toString i ++ " out of bounds")
|
||||
@@ -749,7 +746,7 @@ where
|
||||
loop (as : Array α) (i : Nat) (j : Fin as.size) :=
|
||||
if h : i < j then
|
||||
have := termination h
|
||||
let as := as.swap i j (Nat.lt_trans h j.2)
|
||||
let as := as.swap ⟨i, Nat.lt_trans h j.2⟩ j
|
||||
have : j-1 < as.size := by rw [size_swap]; exact Nat.lt_of_le_of_lt (Nat.pred_le _) j.2
|
||||
loop as (i+1) ⟨j-1, this⟩
|
||||
else
|
||||
@@ -789,7 +786,7 @@ it has to backshift all elements at positions greater than `i`.-/
|
||||
@[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
|
||||
def eraseIdx (a : Array α) (i : Nat) (h : i < a.size := by get_elem_tactic) : Array α :=
|
||||
if h' : i + 1 < a.size then
|
||||
let a' := a.swap (i + 1) i
|
||||
let a' := a.swap ⟨i + 1, h'⟩ ⟨i, h⟩
|
||||
a'.eraseIdx (i + 1) (by simp [a', h'])
|
||||
else
|
||||
a.pop
|
||||
@@ -836,7 +833,7 @@ def eraseP (as : Array α) (p : α → Bool) : Array α :=
|
||||
let rec @[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
|
||||
loop (as : Array α) (j : Fin as.size) :=
|
||||
if i < j then
|
||||
let j' : Fin as.size := ⟨j-1, Nat.lt_of_le_of_lt (Nat.pred_le _) j.2⟩
|
||||
let j' := ⟨j-1, Nat.lt_of_le_of_lt (Nat.pred_le _) j.2⟩
|
||||
let as := as.swap j' j
|
||||
loop as ⟨j', by rw [size_swap]; exact j'.2⟩
|
||||
else
|
||||
@@ -886,12 +883,12 @@ def isPrefixOf [BEq α] (as bs : Array α) : Bool :=
|
||||
false
|
||||
|
||||
@[semireducible, specialize] -- This is otherwise irreducible because it uses well-founded recursion.
|
||||
def zipWithAux (as : Array α) (bs : Array β) (f : α → β → γ) (i : Nat) (cs : Array γ) : Array γ :=
|
||||
def zipWithAux (f : α → β → γ) (as : Array α) (bs : Array β) (i : Nat) (cs : Array γ) : Array γ :=
|
||||
if h : i < as.size then
|
||||
let a := as[i]
|
||||
if h : i < bs.size then
|
||||
let b := bs[i]
|
||||
zipWithAux as bs f (i+1) <| cs.push <| f a b
|
||||
zipWithAux f as bs (i+1) <| cs.push <| f a b
|
||||
else
|
||||
cs
|
||||
else
|
||||
@@ -899,23 +896,11 @@ def zipWithAux (as : Array α) (bs : Array β) (f : α → β → γ) (i : Nat)
|
||||
decreasing_by simp_wf; decreasing_trivial_pre_omega
|
||||
|
||||
@[inline] def zipWith (as : Array α) (bs : Array β) (f : α → β → γ) : Array γ :=
|
||||
zipWithAux as bs f 0 #[]
|
||||
zipWithAux f as bs 0 #[]
|
||||
|
||||
def zip (as : Array α) (bs : Array β) : Array (α × β) :=
|
||||
zipWith as bs Prod.mk
|
||||
|
||||
def zipWithAll (as : Array α) (bs : Array β) (f : Option α → Option β → γ) : Array γ :=
|
||||
go as bs 0 #[]
|
||||
where go (as : Array α) (bs : Array β) (i : Nat) (cs : Array γ) :=
|
||||
if i < max as.size bs.size then
|
||||
let a := as[i]?
|
||||
let b := bs[i]?
|
||||
go as bs (i+1) (cs.push (f a b))
|
||||
else
|
||||
cs
|
||||
termination_by max as.size bs.size - i
|
||||
decreasing_by simp_wf; decreasing_trivial_pre_omega
|
||||
|
||||
def unzip (as : Array (α × β)) : Array α × Array β :=
|
||||
as.foldl (init := (#[], #[])) fun (as, bs) (a, b) => (as.push a, bs.push b)
|
||||
|
||||
|
||||
@@ -5,64 +5,59 @@ Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Init.Data.Array.Basic
|
||||
import Init.Omega
|
||||
universe u v
|
||||
|
||||
namespace Array
|
||||
-- TODO: CLEANUP
|
||||
|
||||
@[specialize] def binSearchAux {α : Type u} {β : Type v} (lt : α → α → Bool) (found : Option α → β) (as : Array α) (k : α) :
|
||||
(lo : Fin (as.size + 1)) → (hi : Fin as.size) → (lo.1 ≤ hi.1) → β
|
||||
| lo, hi, h =>
|
||||
let m := (lo.1 + hi.1)/2
|
||||
let a := as[m]
|
||||
if lt a k then
|
||||
if h' : m + 1 ≤ hi.1 then
|
||||
binSearchAux lt found as k ⟨m+1, by omega⟩ hi h'
|
||||
else found none
|
||||
else if lt k a then
|
||||
if h' : m = 0 ∨ m - 1 < lo.1 then found none
|
||||
else binSearchAux lt found as k lo ⟨m-1, by omega⟩ (by simp; omega)
|
||||
else found (some a)
|
||||
termination_by lo hi => hi.1 - lo.1
|
||||
namespace Array
|
||||
-- TODO: remove the [Inhabited α] parameters as soon as we have the tactic framework for automating proof generation and using Array.fget
|
||||
-- TODO: remove `partial` using well-founded recursion
|
||||
|
||||
@[specialize] partial def binSearchAux {α : Type u} {β : Type v} [Inhabited β] (lt : α → α → Bool) (found : Option α → β) (as : Array α) (k : α) : Nat → Nat → β
|
||||
| lo, hi =>
|
||||
if lo <= hi then
|
||||
let _ := Inhabited.mk k
|
||||
let m := (lo + hi)/2
|
||||
let a := as.get! m
|
||||
if lt a k then binSearchAux lt found as k (m+1) hi
|
||||
else if lt k a then
|
||||
if m == 0 then found none
|
||||
else binSearchAux lt found as k lo (m-1)
|
||||
else found (some a)
|
||||
else found none
|
||||
|
||||
@[inline] def binSearch {α : Type} (as : Array α) (k : α) (lt : α → α → Bool) (lo := 0) (hi := as.size - 1) : Option α :=
|
||||
if h : lo < as.size then
|
||||
if lo < as.size then
|
||||
let hi := if hi < as.size then hi else as.size - 1
|
||||
if w : lo ≤ hi then
|
||||
binSearchAux lt id as k ⟨lo, by omega⟩ ⟨hi, by simp [hi]; split <;> omega⟩ (by simp [hi]; omega)
|
||||
else
|
||||
none
|
||||
binSearchAux lt id as k lo hi
|
||||
else
|
||||
none
|
||||
|
||||
@[inline] def binSearchContains {α : Type} (as : Array α) (k : α) (lt : α → α → Bool) (lo := 0) (hi := as.size - 1) : Bool :=
|
||||
if h : lo < as.size then
|
||||
if lo < as.size then
|
||||
let hi := if hi < as.size then hi else as.size - 1
|
||||
if w : lo ≤ hi then
|
||||
binSearchAux lt Option.isSome as k ⟨lo, by omega⟩ ⟨hi, by simp [hi]; split <;> omega⟩ (by simp [hi]; omega)
|
||||
else
|
||||
false
|
||||
binSearchAux lt Option.isSome as k lo hi
|
||||
else
|
||||
false
|
||||
|
||||
@[specialize] private def binInsertAux {α : Type u} {m : Type u → Type v} [Monad m]
|
||||
@[specialize] private partial def binInsertAux {α : Type u} {m : Type u → Type v} [Monad m]
|
||||
(lt : α → α → Bool)
|
||||
(merge : α → m α)
|
||||
(add : Unit → m α)
|
||||
(as : Array α)
|
||||
(k : α) : (lo : Fin as.size) → (hi : Fin as.size) → (lo.1 ≤ hi.1) → (lt as[lo] k) → m (Array α)
|
||||
| lo, hi, h, w =>
|
||||
let mid := (lo.1 + hi.1)/2
|
||||
let midVal := as[mid]
|
||||
if w₁ : lt midVal k then
|
||||
if h' : mid = lo then do let v ← add (); pure <| as.insertIdx (lo+1) v
|
||||
else binInsertAux lt merge add as k ⟨mid, by omega⟩ hi (by simp; omega) w₁
|
||||
else if w₂ : lt k midVal then
|
||||
have : mid ≠ lo := fun z => by simp [midVal, z] at w₁; simp_all
|
||||
binInsertAux lt merge add as k lo ⟨mid, by omega⟩ (by simp; omega) w
|
||||
(k : α) : Nat → Nat → m (Array α)
|
||||
| lo, hi =>
|
||||
let _ := Inhabited.mk k
|
||||
-- as[lo] < k < as[hi]
|
||||
let mid := (lo + hi)/2
|
||||
let midVal := as.get! mid
|
||||
if lt midVal k then
|
||||
if mid == lo then do let v ← add (); pure <| as.insertIdx! (lo+1) v
|
||||
else binInsertAux lt merge add as k mid hi
|
||||
else if lt k midVal then
|
||||
binInsertAux lt merge add as k lo mid
|
||||
else do
|
||||
as.modifyM mid <| fun v => merge v
|
||||
termination_by lo hi => hi.1 - lo.1
|
||||
|
||||
@[specialize] def binInsertM {α : Type u} {m : Type u → Type v} [Monad m]
|
||||
(lt : α → α → Bool)
|
||||
@@ -70,12 +65,13 @@ termination_by lo hi => hi.1 - lo.1
|
||||
(add : Unit → m α)
|
||||
(as : Array α)
|
||||
(k : α) : m (Array α) :=
|
||||
if h : as.size = 0 then do let v ← add (); pure <| as.push v
|
||||
else if lt k as[0] then do let v ← add (); pure <| as.insertIdx 0 v
|
||||
else if h' : !lt as[0] k then as.modifyM 0 <| merge
|
||||
else if lt as[as.size - 1] k then do let v ← add (); pure <| as.push v
|
||||
else if !lt k as[as.size - 1] then as.modifyM (as.size - 1) <| merge
|
||||
else binInsertAux lt merge add as k ⟨0, by omega⟩ ⟨as.size - 1, by omega⟩ (by simp) (by simpa using h')
|
||||
let _ := Inhabited.mk k
|
||||
if as.isEmpty then do let v ← add (); pure <| as.push v
|
||||
else if lt k (as.get! 0) then do let v ← add (); pure <| as.insertIdx! 0 v
|
||||
else if !lt (as.get! 0) k then as.modifyM 0 <| merge
|
||||
else if lt as.back! k then do let v ← add (); pure <| as.push v
|
||||
else if !lt k as.back! then as.modifyM (as.size - 1) <| merge
|
||||
else binInsertAux lt merge add as k 0 (as.size - 1)
|
||||
|
||||
@[inline] def binInsert {α : Type u} (lt : α → α → Bool) (as : Array α) (k : α) : Array α :=
|
||||
Id.run <| binInsertM lt (fun _ => k) (fun _ => k) as k
|
||||
|
||||
@@ -23,7 +23,7 @@ theorem foldlM_toList.aux [Monad m]
|
||||
· cases Nat.not_le_of_gt ‹_› (Nat.zero_add _ ▸ H)
|
||||
· rename_i i; rw [Nat.succ_add] at H
|
||||
simp [foldlM_toList.aux f arr i (j+1) H]
|
||||
rw (occs := [2]) [← List.getElem_cons_drop_succ_eq_drop ‹_›]
|
||||
rw (occs := .pos [2]) [← List.getElem_cons_drop_succ_eq_drop ‹_›]
|
||||
rfl
|
||||
· rw [List.drop_of_length_le (Nat.ge_of_not_lt ‹_›)]; rfl
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ Authors: Leonardo de Moura
|
||||
prelude
|
||||
import Init.Data.Array.Basic
|
||||
import Init.Data.BEq
|
||||
import Init.Data.Nat.Lemmas
|
||||
import Init.Data.List.Nat.BEq
|
||||
import Init.ByCases
|
||||
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
/-
|
||||
Copyright (c) 2024 François G. Dorais. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: François G. Dorais
|
||||
-/
|
||||
prelude
|
||||
import Init.Data.List.FinRange
|
||||
|
||||
namespace Array
|
||||
|
||||
/-- `finRange n` is the array of all elements of `Fin n` in order. -/
|
||||
protected def finRange (n : Nat) : Array (Fin n) := ofFn fun i => i
|
||||
|
||||
end Array
|
||||
@@ -5,91 +5,24 @@ Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Init.Data.Array.Basic
|
||||
import Init.Data.Nat.Fold
|
||||
import Init.Data.Vector.Lemmas
|
||||
|
||||
namespace Vector
|
||||
|
||||
/-- Swap the `i`-th element repeatedly to the left, while the element to its left is not `lt` it. -/
|
||||
@[specialize, inline] def swapLeftWhileLT {n} (a : Vector α n) (i : Nat) (h : i < n)
|
||||
(lt : α → α → Bool := by exact (· < ·)) : Vector α n :=
|
||||
match h' : i with
|
||||
| 0 => a
|
||||
| i'+1 =>
|
||||
if lt a[i] a[i'] then
|
||||
swapLeftWhileLT (a.swap i' i) i' (by omega) lt
|
||||
else
|
||||
a
|
||||
|
||||
end Vector
|
||||
|
||||
open Vector
|
||||
namespace Array
|
||||
|
||||
/-- Sort an array in place using insertion sort. -/
|
||||
@[inline] def insertionSort (a : Array α) (lt : α → α → Bool := by exact (· < ·)) : Array α :=
|
||||
a.size.fold (init := ⟨a, rfl⟩) (fun i h acc => swapLeftWhileLT acc i h lt) |>.toArray
|
||||
|
||||
/-- Insert an element into an array, after the last element which is not `lt` the inserted element. -/
|
||||
def orderedInsert (a : Array α) (x : α) (lt : α → α → Bool := by exact (· < ·)) : Array α :=
|
||||
swapLeftWhileLT ⟨a.push x, rfl⟩ a.size (by simp) lt |>.toArray
|
||||
|
||||
end Array
|
||||
|
||||
/-! ### Verification -/
|
||||
|
||||
namespace Vector
|
||||
|
||||
theorem swapLeftWhileLT_push {n} (a : Vector α n) (x : α) (j : Nat) (h : j < n) :
|
||||
swapLeftWhileLT (a.push x) j (by omega) lt = (swapLeftWhileLT a j h lt).push x := by
|
||||
induction j generalizing a with
|
||||
| zero => simp [swapLeftWhileLT]
|
||||
| succ j ih =>
|
||||
simp [swapLeftWhileLT]
|
||||
split <;> rename_i h
|
||||
· rw [Vector.getElem_push_lt (by omega), Vector.getElem_push_lt (by omega)] at h
|
||||
rw [← Vector.push_swap, ih, if_pos h]
|
||||
· rw [Vector.getElem_push_lt (by omega), Vector.getElem_push_lt (by omega)] at h
|
||||
rw [if_neg h]
|
||||
|
||||
theorem swapLeftWhileLT_cast {n m} (a : Vector α n) (j : Nat) (h : j < n) (h' : n = m) :
|
||||
swapLeftWhileLT (a.cast h') j (by omega) lt = (swapLeftWhileLT a j h lt).cast h' := by
|
||||
subst h'
|
||||
simp
|
||||
|
||||
end Vector
|
||||
|
||||
namespace Array
|
||||
|
||||
@[simp] theorem size_insertionSort (a : Array α) : (a.insertionSort lt).size = a.size := by
|
||||
simp [insertionSort]
|
||||
|
||||
private theorem insertionSort_push' (a : Array α) (x : α) :
|
||||
(a.push x).insertionSort lt =
|
||||
(swapLeftWhileLT ⟨(a.insertionSort lt).push x, rfl⟩ a.size (by simp) lt).toArray := by
|
||||
rw [insertionSort, Nat.fold_congr (size_push a x), Nat.fold]
|
||||
have : (a.size.fold (fun i h acc => swapLeftWhileLT acc i (by simp; omega) lt) ⟨a.push x, rfl⟩) =
|
||||
((a.size.fold (fun i h acc => swapLeftWhileLT acc i h lt) ⟨a, rfl⟩).push x).cast (by simp) := by
|
||||
rw [Vector.eq_cast_iff]
|
||||
simp only [Nat.fold_eq_finRange_foldl]
|
||||
rw [← List.foldl_hom (fun a => (Vector.push x a)) _ (fun v ⟨i, h⟩ => swapLeftWhileLT v i (by omega) lt)]
|
||||
rw [Vector.push_mk]
|
||||
rw [← List.foldl_hom (Vector.cast _) _ (fun v ⟨i, h⟩ => swapLeftWhileLT v i (by omega) lt)]
|
||||
· simp
|
||||
· intro v i
|
||||
simp only
|
||||
rw [swapLeftWhileLT_cast]
|
||||
· simp [swapLeftWhileLT_push]
|
||||
rw [this]
|
||||
simp only [Nat.lt_add_one, swapLeftWhileLT_cast, Vector.toArray_cast]
|
||||
unfold insertionSort
|
||||
simp only [Vector.push]
|
||||
congr
|
||||
all_goals simp
|
||||
|
||||
theorem insertionSort_push (a : Array α) (x : α) :
|
||||
(a.push x).insertionSort lt = (a.insertionSort lt).orderedInsert x lt := by
|
||||
rw [insertionSort_push', orderedInsert]
|
||||
simp
|
||||
|
||||
end Array
|
||||
@[inline] def Array.insertionSort (a : Array α) (lt : α → α → Bool) : Array α :=
|
||||
traverse a 0 a.size
|
||||
where
|
||||
@[specialize] traverse (a : Array α) (i : Nat) (fuel : Nat) : Array α :=
|
||||
match fuel with
|
||||
| 0 => a
|
||||
| fuel+1 =>
|
||||
if h : i < a.size then
|
||||
traverse (swapLoop a i h) (i+1) fuel
|
||||
else
|
||||
a
|
||||
@[specialize] swapLoop (a : Array α) (j : Nat) (h : j < a.size) : Array α :=
|
||||
match (generalizing := false) he:j with -- using `generalizing` because we don't want to refine the type of `h`
|
||||
| 0 => a
|
||||
| j'+1 =>
|
||||
have h' : j' < a.size := by subst j; exact Nat.lt_trans (Nat.lt_succ_self _) h
|
||||
if lt a[j] a[j'] then
|
||||
swapLoop (a.swap ⟨j, h⟩ ⟨j', h'⟩) j' (by rw [size_swap]; assumption; done)
|
||||
else
|
||||
a
|
||||
|
||||
@@ -343,8 +343,8 @@ theorem isPrefixOfAux_toArray_zero [BEq α] (l₁ l₂ : List α) (hle : l₁.le
|
||||
rw [ih]
|
||||
simp_all
|
||||
|
||||
theorem zipWithAux_toArray_succ (as : List α) (bs : List β) (f : α → β → γ) (i : Nat) (cs : Array γ) :
|
||||
zipWithAux as.toArray bs.toArray f (i + 1) cs = zipWithAux as.tail.toArray bs.tail.toArray f i cs := by
|
||||
theorem zipWithAux_toArray_succ (f : α → β → γ) (as : List α) (bs : List β) (i : Nat) (cs : Array γ) :
|
||||
zipWithAux f as.toArray bs.toArray (i + 1) cs = zipWithAux f as.tail.toArray bs.tail.toArray i cs := by
|
||||
rw [zipWithAux]
|
||||
conv => rhs; rw [zipWithAux]
|
||||
simp only [size_toArray, getElem_toArray, length_tail, getElem_tail]
|
||||
@@ -355,8 +355,8 @@ theorem zipWithAux_toArray_succ (as : List α) (bs : List β) (f : α → β →
|
||||
rw [dif_neg (by omega)]
|
||||
· rw [dif_neg (by omega)]
|
||||
|
||||
theorem zipWithAux_toArray_succ' (as : List α) (bs : List β) (f : α → β → γ) (i : Nat) (cs : Array γ) :
|
||||
zipWithAux as.toArray bs.toArray f (i + 1) cs = zipWithAux (as.drop (i+1)).toArray (bs.drop (i+1)).toArray f 0 cs := by
|
||||
theorem zipWithAux_toArray_succ' (f : α → β → γ) (as : List α) (bs : List β) (i : Nat) (cs : Array γ) :
|
||||
zipWithAux f as.toArray bs.toArray (i + 1) cs = zipWithAux f (as.drop (i+1)).toArray (bs.drop (i+1)).toArray 0 cs := by
|
||||
induction i generalizing as bs cs with
|
||||
| zero => simp [zipWithAux_toArray_succ]
|
||||
| succ i ih =>
|
||||
@@ -364,7 +364,7 @@ theorem zipWithAux_toArray_succ' (as : List α) (bs : List β) (f : α → β
|
||||
simp
|
||||
|
||||
theorem zipWithAux_toArray_zero (f : α → β → γ) (as : List α) (bs : List β) (cs : Array γ) :
|
||||
zipWithAux as.toArray bs.toArray f 0 cs = cs ++ (List.zipWith f as bs).toArray := by
|
||||
zipWithAux f as.toArray bs.toArray 0 cs = cs ++ (List.zipWith f as bs).toArray := by
|
||||
rw [Array.zipWithAux]
|
||||
match as, bs with
|
||||
| [], _ => simp
|
||||
@@ -372,7 +372,7 @@ theorem zipWithAux_toArray_zero (f : α → β → γ) (as : List α) (bs : List
|
||||
| a :: as, b :: bs =>
|
||||
simp [zipWith_cons_cons, zipWithAux_toArray_succ', zipWithAux_toArray_zero, push_append_toArray]
|
||||
|
||||
@[simp] theorem zipWith_toArray (as : List α) (bs : List β) (f : α → β → γ) :
|
||||
@[simp] theorem zipWith_toArray (f : α → β → γ) (as : List α) (bs : List β) :
|
||||
Array.zipWith as.toArray bs.toArray f = (List.zipWith f as bs).toArray := by
|
||||
rw [Array.zipWith]
|
||||
simp [zipWithAux_toArray_zero]
|
||||
@@ -381,44 +381,6 @@ theorem zipWithAux_toArray_zero (f : α → β → γ) (as : List α) (bs : List
|
||||
Array.zip as.toArray bs.toArray = (List.zip as bs).toArray := by
|
||||
simp [Array.zip, zipWith_toArray, zip]
|
||||
|
||||
theorem zipWithAll_go_toArray (as : List α) (bs : List β) (f : Option α → Option β → γ) (i : Nat) (cs : Array γ) :
|
||||
zipWithAll.go f as.toArray bs.toArray i cs = cs ++ (List.zipWithAll f (as.drop i) (bs.drop i)).toArray := by
|
||||
unfold zipWithAll.go
|
||||
split <;> rename_i h
|
||||
· rw [zipWithAll_go_toArray]
|
||||
simp at h
|
||||
simp only [getElem?_toArray, push_append_toArray]
|
||||
if ha : i < as.length then
|
||||
if hb : i < bs.length then
|
||||
rw [List.drop_eq_getElem_cons ha, List.drop_eq_getElem_cons hb]
|
||||
simp only [ha, hb, getElem?_eq_getElem, zipWithAll_cons_cons]
|
||||
else
|
||||
simp only [Nat.not_lt] at hb
|
||||
rw [List.drop_eq_getElem_cons ha]
|
||||
rw [(drop_eq_nil_iff (l := bs)).mpr (by omega), (drop_eq_nil_iff (l := bs)).mpr (by omega)]
|
||||
simp only [zipWithAll_nil, map_drop, map_cons]
|
||||
rw [getElem?_eq_getElem ha]
|
||||
rw [getElem?_eq_none hb]
|
||||
else
|
||||
if hb : i < bs.length then
|
||||
simp only [Nat.not_lt] at ha
|
||||
rw [List.drop_eq_getElem_cons hb]
|
||||
rw [(drop_eq_nil_iff (l := as)).mpr (by omega), (drop_eq_nil_iff (l := as)).mpr (by omega)]
|
||||
simp only [nil_zipWithAll, map_drop, map_cons]
|
||||
rw [getElem?_eq_getElem hb]
|
||||
rw [getElem?_eq_none ha]
|
||||
else
|
||||
omega
|
||||
· simp only [size_toArray, Nat.not_lt] at h
|
||||
rw [drop_eq_nil_of_le (by omega), drop_eq_nil_of_le (by omega)]
|
||||
simp
|
||||
termination_by max as.length bs.length - i
|
||||
decreasing_by simp_wf; decreasing_trivial_pre_omega
|
||||
|
||||
@[simp] theorem zipWithAll_toArray (f : Option α → Option β → γ) (as : List α) (bs : List β) :
|
||||
Array.zipWithAll as.toArray bs.toArray f = (List.zipWithAll f as bs).toArray := by
|
||||
simp [Array.zipWithAll, zipWithAll_go_toArray]
|
||||
|
||||
end List
|
||||
|
||||
namespace Array
|
||||
@@ -496,11 +458,6 @@ where
|
||||
simp only [← length_toList]
|
||||
simp
|
||||
|
||||
@[simp] theorem mapM_empty [Monad m] (f : α → m β) : mapM f #[] = pure #[] := by
|
||||
rw [mapM, mapM.map]; rfl
|
||||
|
||||
@[simp] theorem map_empty (f : α → β) : map f #[] = #[] := mapM_empty f
|
||||
|
||||
@[simp] theorem appendList_nil (arr : Array α) : arr ++ ([] : List α) = arr := Array.ext' (by simp)
|
||||
|
||||
@[simp] theorem appendList_cons (arr : Array α) (a : α) (l : List α) :
|
||||
@@ -556,10 +513,10 @@ theorem getElem?_len_le (a : Array α) {i : Nat} (h : a.size ≤ i) : a[i]? = no
|
||||
theorem getD_get? (a : Array α) (i : Nat) (d : α) :
|
||||
Option.getD a[i]? d = if p : i < a.size then a[i]'p else d := by
|
||||
if h : i < a.size then
|
||||
simp [setIfInBounds, h, getElem?_def]
|
||||
simp [setD, h, getElem?_def]
|
||||
else
|
||||
have p : i ≥ a.size := Nat.le_of_not_gt h
|
||||
simp [setIfInBounds, getElem?_len_le _ p, h]
|
||||
simp [setD, getElem?_len_le _ p, h]
|
||||
|
||||
@[simp] theorem getD_eq_get? (a : Array α) (n d) : a.getD n d = (a[n]?).getD d := by
|
||||
simp only [getD, get_eq_getElem, get?_eq_getElem?]; split <;> simp [getD_get?, *]
|
||||
@@ -595,46 +552,31 @@ theorem getElem_set (a : Array α) (i : Nat) (h' : i < a.size) (v : α) (j : Nat
|
||||
(ne : i ≠ j) : (a.set i v)[j]? = a[j]? := by
|
||||
by_cases h : j < a.size <;> simp [getElem?_lt, getElem?_ge, Nat.ge_of_not_lt, ne, h]
|
||||
|
||||
theorem push_set (a : Array α) (x y : α) {i : Nat} {hi} :
|
||||
(a.set i x).push y = (a.push y).set i x (by simp; omega):= by
|
||||
ext j h₁ h₂
|
||||
· simp
|
||||
· if h' : j = a.size then
|
||||
rw [getElem_push, getElem_set_ne, dif_neg]
|
||||
all_goals simp_all <;> omega
|
||||
else
|
||||
rw [getElem_push_lt, getElem_set, getElem_set]
|
||||
split
|
||||
· rfl
|
||||
· rw [getElem_push_lt]
|
||||
simp_all; omega
|
||||
/-! # setD -/
|
||||
|
||||
/-! # setIfInBounds -/
|
||||
@[simp] theorem set!_is_setD : @set! = @setD := rfl
|
||||
|
||||
@[simp] theorem set!_is_setIfInBounds : @set! = @setIfInBounds := rfl
|
||||
|
||||
@[simp] theorem size_setIfInBounds (a : Array α) (index : Nat) (val : α) :
|
||||
(Array.setIfInBounds a index val).size = a.size := by
|
||||
@[simp] theorem size_setD (a : Array α) (index : Nat) (val : α) :
|
||||
(Array.setD a index val).size = a.size := by
|
||||
if h : index < a.size then
|
||||
simp [setIfInBounds, h]
|
||||
simp [setD, h]
|
||||
else
|
||||
simp [setIfInBounds, h]
|
||||
simp [setD, h]
|
||||
|
||||
@[simp] theorem getElem_setIfInBounds_eq (a : Array α) {i : Nat} (v : α) (h : _) :
|
||||
(setIfInBounds a i v)[i]'h = v := by
|
||||
@[simp] theorem getElem_setD_eq (a : Array α) {i : Nat} (v : α) (h : _) :
|
||||
(setD a i v)[i]'h = v := by
|
||||
simp at h
|
||||
simp only [setIfInBounds, h, ↓reduceDIte, getElem_set_eq]
|
||||
simp only [setD, h, ↓reduceDIte, getElem_set_eq]
|
||||
|
||||
@[simp]
|
||||
theorem getElem?_setIfInBounds_eq (a : Array α) {i : Nat} (p : i < a.size) (v : α) :
|
||||
(a.setIfInBounds i v)[i]? = some v := by
|
||||
theorem getElem?_setD_eq (a : Array α) {i : Nat} (p : i < a.size) (v : α) : (a.setD i v)[i]? = some v := by
|
||||
simp [getElem?_lt, p]
|
||||
|
||||
/-- Simplifies a normal form from `get!` -/
|
||||
@[simp] theorem getD_get?_setIfInBounds (a : Array α) (i : Nat) (v d : α) :
|
||||
Option.getD (setIfInBounds a i v)[i]? d = if i < a.size then v else d := by
|
||||
@[simp] theorem getD_get?_setD (a : Array α) (i : Nat) (v d : α) :
|
||||
Option.getD (setD a i v)[i]? d = if i < a.size then v else d := by
|
||||
by_cases h : i < a.size <;>
|
||||
simp [setIfInBounds, Nat.not_lt_of_le, h, getD_get?]
|
||||
simp [setD, Nat.not_lt_of_le, h, getD_get?]
|
||||
|
||||
/-! # ofFn -/
|
||||
|
||||
@@ -679,19 +621,6 @@ theorem getElem?_ofFn (f : Fin n → α) (i : Nat) :
|
||||
(ofFn f)[i]? = if h : i < n then some (f ⟨i, h⟩) else none := by
|
||||
simp [getElem?_def]
|
||||
|
||||
@[simp] theorem ofFn_zero (f : Fin 0 → α) : ofFn f = #[] := rfl
|
||||
|
||||
theorem ofFn_succ (f : Fin (n+1) → α) :
|
||||
ofFn f = (ofFn (fun (i : Fin n) => f i.castSucc)).push (f ⟨n, by omega⟩) := by
|
||||
ext i h₁ h₂
|
||||
· simp
|
||||
· simp [getElem_push]
|
||||
split <;> rename_i h₃
|
||||
· rfl
|
||||
· congr
|
||||
simp at h₁ h₂
|
||||
omega
|
||||
|
||||
/-! # mkArray -/
|
||||
|
||||
@[simp] theorem size_mkArray (n : Nat) (v : α) : (mkArray n v).size = n :=
|
||||
@@ -826,32 +755,32 @@ theorem get_set (a : Array α) (i : Nat) (hi : i < a.size) (j : Nat) (hj : j < a
|
||||
(h : i ≠ j) : (a.set i v)[j]'(by simp [*]) = a[j] := by
|
||||
simp only [set, getElem_eq_getElem_toList, List.getElem_set_ne h]
|
||||
|
||||
theorem getElem_setIfInBounds (a : Array α) (i : Nat) (v : α) (h : i < (setIfInBounds a i v).size) :
|
||||
(setIfInBounds a i v)[i] = v := by
|
||||
theorem getElem_setD (a : Array α) (i : Nat) (v : α) (h : i < (setD a i v).size) :
|
||||
(setD a i v)[i] = v := by
|
||||
simp at h
|
||||
simp only [setIfInBounds, h, ↓reduceDIte, getElem_set_eq]
|
||||
simp only [setD, h, ↓reduceDIte, getElem_set_eq]
|
||||
|
||||
theorem set_set (a : Array α) (i : Nat) (h) (v v' : α) :
|
||||
(a.set i v h).set i v' (by simp [h]) = a.set i v' := by simp [set, List.set_set]
|
||||
|
||||
private theorem fin_cast_val (e : n = n') (i : Fin n) : e ▸ i = ⟨i.1, e ▸ i.2⟩ := by cases e; rfl
|
||||
|
||||
theorem swap_def (a : Array α) (i j : Nat) (hi hj) :
|
||||
a.swap i j hi hj = (a.set i a[j]).set j a[i] (by simpa using hj) := by
|
||||
theorem swap_def (a : Array α) (i j : Fin a.size) :
|
||||
a.swap i j = (a.set i a[j]).set j a[i] := by
|
||||
simp [swap, fin_cast_val]
|
||||
|
||||
@[simp] theorem toList_swap (a : Array α) (i j : Nat) (hi hj) :
|
||||
(a.swap i j hi hj).toList = (a.toList.set i a[j]).set j a[i] := by simp [swap_def]
|
||||
@[simp] theorem toList_swap (a : Array α) (i j : Fin a.size) :
|
||||
(a.swap i j).toList = (a.toList.set i a[j]).set j a[i] := by simp [swap_def]
|
||||
|
||||
theorem getElem?_swap (a : Array α) (i j : Nat) (hi hj) (k : Nat) : (a.swap i j hi hj)[k]? =
|
||||
if j = k then some a[i] else if i = k then some a[j] else a[k]? := by
|
||||
theorem getElem?_swap (a : Array α) (i j : Fin a.size) (k : Nat) : (a.swap i j)[k]? =
|
||||
if j = k then some a[i.1] else if i = k then some a[j.1] else a[k]? := by
|
||||
simp [swap_def, get?_set, ← getElem_fin_eq_getElem_toList]
|
||||
|
||||
@[simp] theorem swapAt_def (a : Array α) (i : Nat) (v : α) (hi) :
|
||||
a.swapAt i v hi = (a[i], a.set i v) := rfl
|
||||
@[simp] theorem swapAt_def (a : Array α) (i : Fin a.size) (v : α) :
|
||||
a.swapAt i v = (a[i.1], a.set i v) := rfl
|
||||
|
||||
@[simp] theorem size_swapAt (a : Array α) (i : Nat) (v : α) (hi) :
|
||||
(a.swapAt i v hi).2.size = a.size := by simp [swapAt_def]
|
||||
@[simp] theorem size_swapAt (a : Array α) (i : Fin a.size) (v : α) :
|
||||
(a.swapAt i v).2.size = a.size := by simp [swapAt_def]
|
||||
|
||||
@[simp]
|
||||
theorem swapAt!_def (a : Array α) (i : Nat) (v : α) (h : i < a.size) :
|
||||
@@ -898,10 +827,8 @@ theorem eq_push_of_size_ne_zero {as : Array α} (h : as.size ≠ 0) :
|
||||
|
||||
theorem size_eq_length_toList (as : Array α) : as.size = as.toList.length := rfl
|
||||
|
||||
@[simp] theorem size_swapIfInBounds (a : Array α) (i j) :
|
||||
(a.swapIfInBounds i j).size = a.size := by unfold swapIfInBounds; split <;> (try split) <;> simp [size_swap]
|
||||
|
||||
@[deprecated size_swapIfInBounds (since := "2024-11-24")] abbrev size_swap! := @size_swapIfInBounds
|
||||
@[simp] theorem size_swap! (a : Array α) (i j) :
|
||||
(a.swap! i j).size = a.size := by unfold swap!; split <;> (try split) <;> simp [size_swap]
|
||||
|
||||
@[simp] theorem size_reverse (a : Array α) : a.reverse.size = a.size := by
|
||||
let rec go (as : Array α) (i j) : (reverse.loop as i j).size = as.size := by
|
||||
@@ -913,10 +840,16 @@ theorem size_eq_length_toList (as : Array α) : as.size = as.toList.length := rf
|
||||
simp only [reverse]; split <;> simp [go]
|
||||
|
||||
@[simp] theorem size_range {n : Nat} : (range n).size = n := by
|
||||
induction n <;> simp [range]
|
||||
unfold range
|
||||
induction n with
|
||||
| zero => simp [Nat.fold]
|
||||
| succ k ih =>
|
||||
rw [Nat.fold, flip]
|
||||
simp only [mkEmpty_eq, size_push] at *
|
||||
omega
|
||||
|
||||
@[simp] theorem toList_range (n : Nat) : (range n).toList = List.range n := by
|
||||
apply List.ext_getElem <;> simp [range]
|
||||
induction n <;> simp_all [range, Nat.fold, flip, List.range_succ]
|
||||
|
||||
@[simp]
|
||||
theorem getElem_range {n : Nat} {x : Nat} (h : x < (Array.range n).size) : (Array.range n)[x] = x := by
|
||||
@@ -1112,34 +1045,6 @@ theorem foldr_congr {as bs : Array α} (h₀ : as = bs) {f g : α → β → β}
|
||||
as.foldr f a start stop = bs.foldr g b start' stop' := by
|
||||
congr
|
||||
|
||||
theorem foldl_eq_foldlM (f : β → α → β) (b) (l : Array α) :
|
||||
l.foldl f b = l.foldlM (m := Id) f b := by
|
||||
cases l
|
||||
simp [List.foldl_eq_foldlM]
|
||||
|
||||
theorem foldr_eq_foldrM (f : α → β → β) (b) (l : Array α) :
|
||||
l.foldr f b = l.foldrM (m := Id) f b := by
|
||||
cases l
|
||||
simp [List.foldr_eq_foldrM]
|
||||
|
||||
@[simp] theorem id_run_foldlM (f : β → α → Id β) (b) (l : Array α) :
|
||||
Id.run (l.foldlM f b) = l.foldl f b := (foldl_eq_foldlM f b l).symm
|
||||
|
||||
@[simp] theorem id_run_foldrM (f : α → β → Id β) (b) (l : Array α) :
|
||||
Id.run (l.foldrM f b) = l.foldr f b := (foldr_eq_foldrM f b l).symm
|
||||
|
||||
theorem foldl_hom (f : α₁ → α₂) (g₁ : α₁ → β → α₁) (g₂ : α₂ → β → α₂) (l : Array β) (init : α₁)
|
||||
(H : ∀ x y, g₂ (f x) y = f (g₁ x y)) : l.foldl g₂ (f init) = f (l.foldl g₁ init) := by
|
||||
cases l
|
||||
simp
|
||||
rw [List.foldl_hom _ _ _ _ _ H]
|
||||
|
||||
theorem foldr_hom (f : β₁ → β₂) (g₁ : α → β₁ → β₁) (g₂ : α → β₂ → β₂) (l : Array α) (init : β₁)
|
||||
(H : ∀ x y, g₂ x (f y) = f (g₁ x y)) : l.foldr g₂ (f init) = f (l.foldr g₁ init) := by
|
||||
cases l
|
||||
simp
|
||||
rw [List.foldr_hom _ _ _ _ _ H]
|
||||
|
||||
/-! ### map -/
|
||||
|
||||
@[simp] theorem mem_map {f : α → β} {l : Array α} : b ∈ l.map f ↔ ∃ a, a ∈ l ∧ f a = b := by
|
||||
@@ -1691,30 +1596,28 @@ instance [DecidableEq α] (a : α) (as : Array α) : Decidable (a ∈ as) :=
|
||||
|
||||
open Fin
|
||||
|
||||
@[simp] theorem getElem_swap_right (a : Array α) {i j : Nat} {hi hj} :
|
||||
(a.swap i j hi hj)[j]'(by simpa using hj) = a[i] := by
|
||||
@[simp] theorem getElem_swap_right (a : Array α) {i j : Fin a.size} : (a.swap i j)[j.1] = a[i] := by
|
||||
simp [swap_def, getElem_set]
|
||||
|
||||
@[simp] theorem getElem_swap_left (a : Array α) {i j : Nat} {hi hj} :
|
||||
(a.swap i j hi hj)[i]'(by simpa using hi) = a[j] := by
|
||||
@[simp] theorem getElem_swap_left (a : Array α) {i j : Fin a.size} : (a.swap i j)[i.1] = a[j] := by
|
||||
simp +contextual [swap_def, getElem_set]
|
||||
|
||||
@[simp] theorem getElem_swap_of_ne (a : Array α) {i j : Nat} {hi hj} (hp : p < a.size)
|
||||
(hi' : p ≠ i) (hj' : p ≠ j) : (a.swap i j hi hj)[p]'(a.size_swap .. |>.symm ▸ hp) = a[p] := by
|
||||
simp [swap_def, getElem_set, hi'.symm, hj'.symm]
|
||||
@[simp] theorem getElem_swap_of_ne (a : Array α) {i j : Fin a.size} (hp : p < a.size)
|
||||
(hi : p ≠ i) (hj : p ≠ j) : (a.swap i j)[p]'(a.size_swap .. |>.symm ▸ hp) = a[p] := by
|
||||
simp [swap_def, getElem_set, hi.symm, hj.symm]
|
||||
|
||||
theorem getElem_swap' (a : Array α) (i j : Nat) {hi hj} (k : Nat) (hk : k < a.size) :
|
||||
(a.swap i j hi hj)[k]'(by simp_all) = if k = i then a[j] else if k = j then a[i] else a[k] := by
|
||||
theorem getElem_swap' (a : Array α) (i j : Fin a.size) (k : Nat) (hk : k < a.size) :
|
||||
(a.swap i j)[k]'(by simp_all) = if k = i then a[j] else if k = j then a[i] else a[k] := by
|
||||
split
|
||||
· simp_all only [getElem_swap_left]
|
||||
· split <;> simp_all
|
||||
|
||||
theorem getElem_swap (a : Array α) (i j : Nat) {hi hj}(k : Nat) (hk : k < (a.swap i j).size) :
|
||||
(a.swap i j hi hj)[k] = if k = i then a[j] else if k = j then a[i] else a[k]'(by simp_all) := by
|
||||
theorem getElem_swap (a : Array α) (i j : Fin a.size) (k : Nat) (hk : k < (a.swap i j).size) :
|
||||
(a.swap i j)[k] = if k = i then a[j] else if k = j then a[i] else a[k]'(by simp_all) := by
|
||||
apply getElem_swap'
|
||||
|
||||
@[simp] theorem swap_swap (a : Array α) {i j : Nat} (hi hj) :
|
||||
(a.swap i j hi hj).swap i j ((a.size_swap ..).symm ▸ hi) ((a.size_swap ..).symm ▸ hj) = a := by
|
||||
@[simp] theorem swap_swap (a : Array α) {i j : Fin a.size} :
|
||||
(a.swap i j).swap ⟨i.1, (a.size_swap ..).symm ▸ i.2⟩ ⟨j.1, (a.size_swap ..).symm ▸ j.2⟩ = a := by
|
||||
apply ext
|
||||
· simp only [size_swap]
|
||||
· intros
|
||||
@@ -1723,7 +1626,7 @@ theorem getElem_swap (a : Array α) (i j : Nat) {hi hj}(k : Nat) (hk : k < (a.sw
|
||||
· simp_all
|
||||
· split <;> simp_all
|
||||
|
||||
theorem swap_comm (a : Array α) {i j : Nat} {hi hj} : a.swap i j hi hj = a.swap j i hj hi := by
|
||||
theorem swap_comm (a : Array α) {i j : Fin a.size} : a.swap i j = a.swap j i := by
|
||||
apply ext
|
||||
· simp only [size_swap]
|
||||
· intros
|
||||
@@ -1732,11 +1635,6 @@ theorem swap_comm (a : Array α) {i j : Nat} {hi hj} : a.swap i j hi hj = a.swap
|
||||
· split <;> simp_all
|
||||
· split <;> simp_all
|
||||
|
||||
theorem push_swap (a : Array α) (x : α) {i j : Nat} {hi hj} :
|
||||
(a.swap i j hi hj).push x = (a.push x).swap i j (by simp; omega) (by simp; omega) := by
|
||||
rw [swap_def, swap_def]
|
||||
simp [push_set, getElem_push_lt, hi, hj]
|
||||
|
||||
/-! ### eraseIdx -/
|
||||
|
||||
theorem eraseIdx_eq_eraseIdxIfInBounds {a : Array α} {i : Nat} (h : i < a.size) :
|
||||
@@ -1763,20 +1661,6 @@ theorem eraseIdx_eq_eraseIdxIfInBounds {a : Array α} {i : Nat} (h : i < a.size)
|
||||
(Array.zip as bs).toList = List.zip as.toList bs.toList := by
|
||||
simp [zip, toList_zipWith, List.zip]
|
||||
|
||||
@[simp] theorem toList_zipWithAll (f : Option α → Option β → γ) (as : Array α) (bs : Array β) :
|
||||
(Array.zipWithAll as bs f).toList = List.zipWithAll f as.toList bs.toList := by
|
||||
cases as
|
||||
cases bs
|
||||
simp
|
||||
|
||||
@[simp] theorem size_zipWith (as : Array α) (bs : Array β) (f : α → β → γ) :
|
||||
(as.zipWith bs f).size = min as.size bs.size := by
|
||||
rw [size_eq_length_toList, toList_zipWith, List.length_zipWith]
|
||||
|
||||
@[simp] theorem size_zip (as : Array α) (bs : Array β) :
|
||||
(as.zip bs).size = min as.size bs.size :=
|
||||
as.size_zipWith bs Prod.mk
|
||||
|
||||
/-! ### findSomeM?, findM?, findSome?, find? -/
|
||||
|
||||
@[simp] theorem findSomeM?_toList [Monad m] [LawfulMonad m] (p : α → m (Option β)) (as : Array α) :
|
||||
@@ -1851,10 +1735,10 @@ Our goal is to have `simp` "pull `List.toArray` outwards" as much as possible.
|
||||
apply ext'
|
||||
simp
|
||||
|
||||
@[simp] theorem setIfInBounds_toArray (l : List α) (i : Nat) (a : α) :
|
||||
l.toArray.setIfInBounds i a = (l.set i a).toArray := by
|
||||
@[simp] theorem setD_toArray (l : List α) (i : Nat) (a : α) :
|
||||
l.toArray.setD i a = (l.set i a).toArray := by
|
||||
apply ext'
|
||||
simp only [setIfInBounds]
|
||||
simp only [setD]
|
||||
split
|
||||
· simp
|
||||
· simp_all [List.set_eq_of_length_le]
|
||||
@@ -1899,8 +1783,8 @@ theorem all_toArray (p : α → Bool) (l : List α) : l.toArray.all p = l.all p
|
||||
subst h
|
||||
rw [all_toList]
|
||||
|
||||
@[simp] theorem swap_toArray (l : List α) (i j : Nat) {hi hj}:
|
||||
l.toArray.swap i j hi hj = ((l.set i l[j]).set j l[i]).toArray := by
|
||||
@[simp] theorem swap_toArray (l : List α) (i j : Fin l.toArray.size) :
|
||||
l.toArray.swap i j = ((l.set i l[j]).set j l[i]).toArray := by
|
||||
apply ext'
|
||||
simp
|
||||
|
||||
@@ -2084,20 +1968,6 @@ theorem foldr_filterMap (f : α → Option β) (g : β → γ → γ) (l : Array
|
||||
simp [List.foldr_filterMap]
|
||||
rfl
|
||||
|
||||
theorem foldl_map' (g : α → β) (f : α → α → α) (f' : β → β → β) (a : α) (l : Array α)
|
||||
(h : ∀ x y, f' (g x) (g y) = g (f x y)) :
|
||||
(l.map g).foldl f' (g a) = g (l.foldl f a) := by
|
||||
cases l
|
||||
simp
|
||||
rw [List.foldl_map' _ _ _ _ _ h]
|
||||
|
||||
theorem foldr_map' (g : α → β) (f : α → α → α) (f' : β → β → β) (a : α) (l : List α)
|
||||
(h : ∀ x y, f' (g x) (g y) = g (f x y)) :
|
||||
(l.map g).foldr f' (g a) = g (l.foldr f a) := by
|
||||
cases l
|
||||
simp
|
||||
rw [List.foldr_map' _ _ _ _ _ h]
|
||||
|
||||
/-! ### flatten -/
|
||||
|
||||
@[simp] theorem flatten_empty : flatten (#[] : Array (Array α)) = #[] := rfl
|
||||
@@ -2220,8 +2090,6 @@ theorem toArray_concat {as : List α} {x : α} :
|
||||
|
||||
@[deprecated back!_toArray (since := "2024-10-31")] abbrev back_toArray := @back!_toArray
|
||||
|
||||
@[deprecated setIfInBounds_toArray (since := "2024-11-24")] abbrev setD_toArray := @setIfInBounds_toArray
|
||||
|
||||
end List
|
||||
|
||||
namespace Array
|
||||
@@ -2367,11 +2235,4 @@ abbrev get_swap' := @getElem_swap'
|
||||
@[deprecated eq_push_pop_back!_of_size_ne_zero (since := "2024-10-31")]
|
||||
abbrev eq_push_pop_back_of_size_ne_zero := @eq_push_pop_back!_of_size_ne_zero
|
||||
|
||||
@[deprecated set!_is_setIfInBounds (since := "2024-11-24")] abbrev set_is_setIfInBounds := @set!_is_setIfInBounds
|
||||
@[deprecated size_setIfInBounds (since := "2024-11-24")] abbrev size_setD := @size_setIfInBounds
|
||||
@[deprecated getElem_setIfInBounds_eq (since := "2024-11-24")] abbrev getElem_setD_eq := @getElem_setIfInBounds_eq
|
||||
@[deprecated getElem?_setIfInBounds_eq (since := "2024-11-24")] abbrev get?_setD_eq := @getElem?_setIfInBounds_eq
|
||||
@[deprecated getD_get?_setIfInBounds (since := "2024-11-24")] abbrev getD_setD := @getD_get?_setIfInBounds
|
||||
@[deprecated getElem_setIfInBounds (since := "2024-11-24")] abbrev getElem_setD := @getElem_setIfInBounds
|
||||
|
||||
end Array
|
||||
|
||||
@@ -13,19 +13,19 @@ namespace Array
|
||||
def qpartition (as : Array α) (lt : α → α → Bool) (lo hi : Nat) : Nat × Array α :=
|
||||
if h : as.size = 0 then (0, as) else have : Inhabited α := ⟨as[0]'(by revert h; cases as.size <;> simp)⟩ -- TODO: remove
|
||||
let mid := (lo + hi) / 2
|
||||
let as := if lt (as.get! mid) (as.get! lo) then as.swapIfInBounds lo mid else as
|
||||
let as := if lt (as.get! hi) (as.get! lo) then as.swapIfInBounds lo hi else as
|
||||
let as := if lt (as.get! mid) (as.get! hi) then as.swapIfInBounds mid hi else as
|
||||
let as := if lt (as.get! mid) (as.get! lo) then as.swap! lo mid else as
|
||||
let as := if lt (as.get! hi) (as.get! lo) then as.swap! lo hi else as
|
||||
let as := if lt (as.get! mid) (as.get! hi) then as.swap! mid hi else as
|
||||
let pivot := as.get! hi
|
||||
let rec loop (as : Array α) (i j : Nat) :=
|
||||
if h : j < hi then
|
||||
if lt (as.get! j) pivot then
|
||||
let as := as.swapIfInBounds i j
|
||||
let as := as.swap! i j
|
||||
loop as (i+1) (j+1)
|
||||
else
|
||||
loop as i (j+1)
|
||||
else
|
||||
let as := as.swapIfInBounds i hi
|
||||
let as := as.swap! i hi
|
||||
(i, as)
|
||||
termination_by hi - j
|
||||
decreasing_by all_goals simp_wf; decreasing_trivial_pre_omega
|
||||
|
||||
@@ -25,11 +25,9 @@ Set an element in an array, or do nothing if the index is out of bounds.
|
||||
This will perform the update destructively provided that `a` has a reference
|
||||
count of 1 when called.
|
||||
-/
|
||||
@[inline] def Array.setIfInBounds (a : Array α) (i : Nat) (v : α) : Array α :=
|
||||
@[inline] def Array.setD (a : Array α) (i : Nat) (v : α) : Array α :=
|
||||
dite (LT.lt i a.size) (fun h => a.set i v h) (fun _ => a)
|
||||
|
||||
@[deprecated Array.setIfInBounds (since := "2024-11-24")] abbrev Array.setD := @Array.setIfInBounds
|
||||
|
||||
/--
|
||||
Set an element in an array, or panic if the index is out of bounds.
|
||||
|
||||
@@ -38,4 +36,4 @@ count of 1 when called.
|
||||
-/
|
||||
@[extern "lean_array_set"]
|
||||
def Array.set! (a : Array α) (i : @& Nat) (v : α) : Array α :=
|
||||
Array.setIfInBounds a i v
|
||||
Array.setD a i v
|
||||
|
||||
@@ -23,13 +23,16 @@ def split (s : Subarray α) (i : Fin s.size.succ) : (Subarray α × Subarray α)
|
||||
let ⟨i', isLt⟩ := i
|
||||
have := s.start_le_stop
|
||||
have := s.stop_le_array_size
|
||||
have : i' ≤ s.stop - s.start := Nat.lt_succ.mp isLt
|
||||
have : s.start + i' ≤ s.stop := by omega
|
||||
have : s.start + i' ≤ s.array.size := by omega
|
||||
have : s.start + i' ≤ s.stop := by
|
||||
simp only [size] at isLt
|
||||
omega
|
||||
let pre := {s with
|
||||
stop := s.start + i',
|
||||
start_le_stop := by omega,
|
||||
stop_le_array_size := by omega
|
||||
stop_le_array_size := by assumption
|
||||
}
|
||||
let post := {s with
|
||||
start := s.start + i'
|
||||
@@ -45,7 +48,9 @@ def drop (arr : Subarray α) (i : Nat) : Subarray α where
|
||||
array := arr.array
|
||||
start := min (arr.start + i) arr.stop
|
||||
stop := arr.stop
|
||||
start_le_stop := by omega
|
||||
start_le_stop := by
|
||||
rw [Nat.min_def]
|
||||
split <;> simp only [Nat.le_refl, *]
|
||||
stop_le_array_size := arr.stop_le_array_size
|
||||
|
||||
/--
|
||||
@@ -58,7 +63,9 @@ def take (arr : Subarray α) (i : Nat) : Subarray α where
|
||||
stop := min (arr.start + i) arr.stop
|
||||
start_le_stop := by
|
||||
have := arr.start_le_stop
|
||||
omega
|
||||
rw [Nat.min_def]
|
||||
split <;> omega
|
||||
stop_le_array_size := by
|
||||
have := arr.stop_le_array_size
|
||||
omega
|
||||
rw [Nat.min_def]
|
||||
split <;> omega
|
||||
|
||||
@@ -346,10 +346,6 @@ theorem getMsbD_sub {i : Nat} {i_lt : i < w} {x y : BitVec w} :
|
||||
· rfl
|
||||
· omega
|
||||
|
||||
theorem getElem_sub {i : Nat} {x y : BitVec w} (h : i < w) :
|
||||
(x - y)[i] = (x[i] ^^ ((~~~y + 1#w)[i] ^^ carry i x (~~~y + 1#w) false)) := by
|
||||
simp [← getLsbD_eq_getElem, getLsbD_sub, h]
|
||||
|
||||
theorem msb_sub {x y: BitVec w} :
|
||||
(x - y).msb
|
||||
= (x.msb ^^ ((~~~y + 1#w).msb ^^ carry (w - 1 - 0) x (~~~y + 1#w) false)) := by
|
||||
@@ -414,10 +410,6 @@ theorem getLsbD_neg {i : Nat} {x : BitVec w} :
|
||||
· have h_ge : w ≤ i := by omega
|
||||
simp [getLsbD_ge _ _ h_ge, h_ge, hi]
|
||||
|
||||
theorem getElem_neg {i : Nat} {x : BitVec w} (h : i < w) :
|
||||
(-x)[i] = (x[i] ^^ decide (∃ j < i, x.getLsbD j = true)) := by
|
||||
simp [← getLsbD_eq_getElem, getLsbD_neg, h]
|
||||
|
||||
theorem getMsbD_neg {i : Nat} {x : BitVec w} :
|
||||
getMsbD (-x) i =
|
||||
(getMsbD x i ^^ decide (∃ j < w, i < j ∧ getMsbD x j = true)) := by
|
||||
|
||||
@@ -269,10 +269,6 @@ theorem ofBool_eq_iff_eq : ∀ {b b' : Bool}, BitVec.ofBool b = BitVec.ofBool b'
|
||||
getLsbD (x#'lt) i = x.testBit i := by
|
||||
simp [getLsbD, BitVec.ofNatLt]
|
||||
|
||||
@[simp] theorem getMsbD_ofNatLt {n x i : Nat} (h : x < 2^n) :
|
||||
getMsbD (x#'h) i = (decide (i < n) && x.testBit (n - 1 - i)) := by
|
||||
simp [getMsbD, getLsbD]
|
||||
|
||||
@[simp, bv_toNat] theorem toNat_ofNat (x w : Nat) : (BitVec.ofNat w x).toNat = x % 2^w := by
|
||||
simp [BitVec.toNat, BitVec.ofNat, Fin.ofNat']
|
||||
|
||||
@@ -565,10 +561,6 @@ theorem zeroExtend_eq_setWidth {v : Nat} {x : BitVec w} :
|
||||
else
|
||||
simp [n_le_i, toNat_ofNat]
|
||||
|
||||
@[simp] theorem toInt_setWidth (x : BitVec w) :
|
||||
(x.setWidth v).toInt = Int.bmod x.toNat (2^v) := by
|
||||
simp [toInt_eq_toNat_bmod, toNat_setWidth, Int.emod_bmod]
|
||||
|
||||
theorem setWidth'_eq {x : BitVec w} (h : w ≤ v) : x.setWidth' h = x.setWidth v := by
|
||||
apply eq_of_toNat_eq
|
||||
rw [toNat_setWidth, toNat_setWidth']
|
||||
@@ -763,10 +755,6 @@ theorem extractLsb'_eq_extractLsb {w : Nat} (x : BitVec w) (start len : Nat) (h
|
||||
@[simp] theorem getLsbD_allOnes : (allOnes v).getLsbD i = decide (i < v) := by
|
||||
simp [allOnes]
|
||||
|
||||
@[simp] theorem getMsbD_allOnes : (allOnes v).getMsbD i = decide (i < v) := by
|
||||
simp [allOnes]
|
||||
omega
|
||||
|
||||
@[simp] theorem getElem_allOnes (i : Nat) (h : i < v) : (allOnes v)[i] = true := by
|
||||
simp [getElem_eq_testBit_toNat, h]
|
||||
|
||||
@@ -784,12 +772,6 @@ theorem extractLsb'_eq_extractLsb {w : Nat} (x : BitVec w) (start len : Nat) (h
|
||||
@[simp] theorem toNat_or (x y : BitVec v) :
|
||||
BitVec.toNat (x ||| y) = BitVec.toNat x ||| BitVec.toNat y := rfl
|
||||
|
||||
@[simp] theorem toInt_or (x y : BitVec w) :
|
||||
BitVec.toInt (x ||| y) = Int.bmod (BitVec.toNat x ||| BitVec.toNat y) (2^w) := by
|
||||
rw_mod_cast [Int.bmod_def, BitVec.toInt, toNat_or, Nat.mod_eq_of_lt
|
||||
(Nat.or_lt_two_pow (BitVec.isLt x) (BitVec.isLt y))]
|
||||
omega
|
||||
|
||||
@[simp] theorem toFin_or (x y : BitVec v) :
|
||||
BitVec.toFin (x ||| y) = BitVec.toFin x ||| BitVec.toFin y := by
|
||||
apply Fin.eq_of_val_eq
|
||||
@@ -857,12 +839,6 @@ instance : Std.LawfulCommIdentity (α := BitVec n) (· ||| · ) (0#n) where
|
||||
@[simp] theorem toNat_and (x y : BitVec v) :
|
||||
BitVec.toNat (x &&& y) = BitVec.toNat x &&& BitVec.toNat y := rfl
|
||||
|
||||
@[simp] theorem toInt_and (x y : BitVec w) :
|
||||
BitVec.toInt (x &&& y) = Int.bmod (BitVec.toNat x &&& BitVec.toNat y) (2^w) := by
|
||||
rw_mod_cast [Int.bmod_def, BitVec.toInt, toNat_and, Nat.mod_eq_of_lt
|
||||
(Nat.and_lt_two_pow x.toNat (BitVec.isLt y))]
|
||||
omega
|
||||
|
||||
@[simp] theorem toFin_and (x y : BitVec v) :
|
||||
BitVec.toFin (x &&& y) = BitVec.toFin x &&& BitVec.toFin y := by
|
||||
apply Fin.eq_of_val_eq
|
||||
@@ -930,12 +906,6 @@ instance : Std.LawfulCommIdentity (α := BitVec n) (· &&& · ) (allOnes n) wher
|
||||
@[simp] theorem toNat_xor (x y : BitVec v) :
|
||||
BitVec.toNat (x ^^^ y) = BitVec.toNat x ^^^ BitVec.toNat y := rfl
|
||||
|
||||
@[simp] theorem toInt_xor (x y : BitVec w) :
|
||||
BitVec.toInt (x ^^^ y) = Int.bmod (BitVec.toNat x ^^^ BitVec.toNat y) (2^w) := by
|
||||
rw_mod_cast [Int.bmod_def, BitVec.toInt, toNat_xor, Nat.mod_eq_of_lt
|
||||
(Nat.xor_lt_two_pow (BitVec.isLt x) (BitVec.isLt y))]
|
||||
omega
|
||||
|
||||
@[simp] theorem toFin_xor (x y : BitVec v) :
|
||||
BitVec.toFin (x ^^^ y) = BitVec.toFin x ^^^ BitVec.toFin y := by
|
||||
apply Fin.eq_of_val_eq
|
||||
@@ -1013,13 +983,6 @@ theorem not_def {x : BitVec v} : ~~~x = allOnes v ^^^ x := rfl
|
||||
_ ≤ 2 ^ i := Nat.pow_le_pow_of_le_right Nat.zero_lt_two w
|
||||
· simp
|
||||
|
||||
@[simp] theorem toInt_not {x : BitVec w} :
|
||||
(~~~x).toInt = Int.bmod (2^w - 1 - x.toNat) (2^w) := by
|
||||
rw_mod_cast [BitVec.toInt, BitVec.toNat_not, Int.bmod_def]
|
||||
simp [show ((2^w : Nat) : Int) - 1 - x.toNat = ((2^w - 1 - x.toNat) : Nat) by omega]
|
||||
rw_mod_cast [Nat.mod_eq_of_lt (by omega)]
|
||||
omega
|
||||
|
||||
@[simp] theorem ofInt_negSucc_eq_not_ofNat {w n : Nat} :
|
||||
BitVec.ofInt w (Int.negSucc n) = ~~~.ofNat w n := by
|
||||
simp only [BitVec.ofInt, Int.toNat, Int.ofNat_eq_coe, toNat_eq, toNat_ofNatLt, toNat_not,
|
||||
@@ -1044,10 +1007,6 @@ theorem not_def {x : BitVec v} : ~~~x = allOnes v ^^^ x := rfl
|
||||
@[simp] theorem getLsbD_not {x : BitVec v} : (~~~x).getLsbD i = (decide (i < v) && ! x.getLsbD i) := by
|
||||
by_cases h' : i < v <;> simp_all [not_def]
|
||||
|
||||
@[simp] theorem getMsbD_not {x : BitVec v} :
|
||||
(~~~x).getMsbD i = (decide (i < v) && ! x.getMsbD i) := by
|
||||
by_cases h' : i < v <;> simp_all [not_def]
|
||||
|
||||
@[simp] theorem getElem_not {x : BitVec w} {i : Nat} (h : i < w) : (~~~x)[i] = !x[i] := by
|
||||
simp only [getElem_eq_testBit_toNat, toNat_not]
|
||||
rw [← Nat.sub_add_eq, Nat.add_comm 1]
|
||||
@@ -1521,12 +1480,6 @@ theorem getLsbD_sshiftRight' {x y: BitVec w} {i : Nat} :
|
||||
(!decide (w ≤ i) && if y.toNat + i < w then x.getLsbD (y.toNat + i) else x.msb) := by
|
||||
simp only [BitVec.sshiftRight', BitVec.getLsbD_sshiftRight]
|
||||
|
||||
@[simp]
|
||||
theorem getElem_sshiftRight' {x y : BitVec w} {i : Nat} (h : i < w) :
|
||||
(x.sshiftRight' y)[i] =
|
||||
(!decide (w ≤ i) && if y.toNat + i < w then x.getLsbD (y.toNat + i) else x.msb) := by
|
||||
simp only [← getLsbD_eq_getElem, BitVec.sshiftRight', BitVec.getLsbD_sshiftRight]
|
||||
|
||||
@[simp]
|
||||
theorem getMsbD_sshiftRight' {x y: BitVec w} {i : Nat} :
|
||||
(x.sshiftRight y.toNat).getMsbD i = (decide (i < w) && if i < y.toNat then x.msb else x.getMsbD (i - y.toNat)) := by
|
||||
@@ -1619,82 +1572,6 @@ theorem signExtend_eq_setWidth_of_lt (x : BitVec w) {v : Nat} (hv : v ≤ w):
|
||||
theorem signExtend_eq (x : BitVec w) : x.signExtend w = x := by
|
||||
rw [signExtend_eq_setWidth_of_lt _ (Nat.le_refl _), setWidth_eq]
|
||||
|
||||
/-- Sign extending to a larger bitwidth depends on the msb.
|
||||
If the msb is false, then the result equals the original value.
|
||||
If the msb is true, then we add a value of `(2^v - 2^w)`, which arises from the sign extension. -/
|
||||
private theorem toNat_signExtend_of_le (x : BitVec w) {v : Nat} (hv : w ≤ v) :
|
||||
(x.signExtend v).toNat = x.toNat + if x.msb then 2^v - 2^w else 0 := by
|
||||
apply Nat.eq_of_testBit_eq
|
||||
intro i
|
||||
have ⟨k, hk⟩ := Nat.exists_eq_add_of_le hv
|
||||
rw [hk, testBit_toNat, getLsbD_signExtend, Nat.pow_add, ← Nat.mul_sub_one, Nat.add_comm (x.toNat)]
|
||||
by_cases hx : x.msb
|
||||
· simp only [hx, Bool.if_true_right, ↓reduceIte,
|
||||
Nat.testBit_mul_pow_two_add _ x.isLt,
|
||||
testBit_toNat, Nat.testBit_two_pow_sub_one]
|
||||
-- Case analysis on i being in the intervals [0..w), [w..w + k), [w+k..∞)
|
||||
have hi : i < w ∨ (w ≤ i ∧ i < w + k) ∨ w + k ≤ i := by omega
|
||||
rcases hi with hi | hi | hi
|
||||
· simp [hi]; omega
|
||||
· simp [hi]; omega
|
||||
· simp [hi, show ¬ (i < w + k) by omega, show ¬ (i < w) by omega]
|
||||
omega
|
||||
· simp only [hx, Bool.if_false_right,
|
||||
Bool.false_eq_true, ↓reduceIte, Nat.zero_add, testBit_toNat]
|
||||
have hi : i < w ∨ (w ≤ i ∧ i < w + k) ∨ w + k ≤ i := by omega
|
||||
rcases hi with hi | hi | hi
|
||||
· simp [hi]; omega
|
||||
· simp [hi]
|
||||
· simp [hi, show ¬ (i < w + k) by omega, show ¬ (i < w) by omega, getLsbD_ge x i (by omega)]
|
||||
|
||||
/-- Sign extending to a larger bitwidth depends on the msb.
|
||||
If the msb is false, then the result equals the original value.
|
||||
If the msb is true, then we add a value of `(2^v - 2^w)`, which arises from the sign extension. -/
|
||||
theorem toNat_signExtend (x : BitVec w) {v : Nat} :
|
||||
(x.signExtend v).toNat = (x.setWidth v).toNat + if x.msb then 2^v - 2^w else 0 := by
|
||||
by_cases h : v ≤ w
|
||||
· have : 2^v ≤ 2^w := Nat.pow_le_pow_of_le_right Nat.two_pos h
|
||||
simp [signExtend_eq_setWidth_of_lt x h, toNat_setWidth, Nat.sub_eq_zero_of_le this]
|
||||
· have : 2^w ≤ 2^v := Nat.pow_le_pow_of_le_right Nat.two_pos (by omega)
|
||||
rw [toNat_signExtend_of_le x (by omega), toNat_setWidth, Nat.mod_eq_of_lt (by omega)]
|
||||
|
||||
/-
|
||||
If the current width `w` is smaller than the extended width `v`,
|
||||
then the value when interpreted as an integer does not change.
|
||||
-/
|
||||
theorem toInt_signExtend_of_lt {x : BitVec w} (hv : w < v):
|
||||
(x.signExtend v).toInt = x.toInt := by
|
||||
simp only [toInt_eq_msb_cond, toNat_signExtend]
|
||||
have : (x.signExtend v).msb = x.msb := by
|
||||
rw [msb_eq_getLsbD_last, getLsbD_eq_getElem (Nat.sub_one_lt_of_lt hv)]
|
||||
simp [getElem_signExtend, Nat.le_sub_one_of_lt hv]
|
||||
have H : 2^w ≤ 2^v := Nat.pow_le_pow_of_le_right (by omega) (by omega)
|
||||
simp only [this, toNat_setWidth, Int.natCast_add, Int.ofNat_emod, Int.natCast_mul]
|
||||
by_cases h : x.msb
|
||||
<;> norm_cast
|
||||
<;> simp [h, Nat.mod_eq_of_lt (Nat.lt_of_lt_of_le x.isLt H)]
|
||||
omega
|
||||
|
||||
/-
|
||||
If the current width `w` is larger than the extended width `v`,
|
||||
then the value when interpreted as an integer is truncated,
|
||||
and we compute a modulo by `2^v`.
|
||||
-/
|
||||
theorem toInt_signExtend_of_le {x : BitVec w} (hv : v ≤ w) :
|
||||
(x.signExtend v).toInt = Int.bmod x.toNat (2^v) := by
|
||||
simp [signExtend_eq_setWidth_of_lt _ hv]
|
||||
|
||||
/-
|
||||
Interpreting the sign extension of `(x : BitVec w)` to width `v`
|
||||
computes `x % 2^v` (where `%` is the balanced mod).
|
||||
-/
|
||||
theorem toInt_signExtend (x : BitVec w) :
|
||||
(x.signExtend v).toInt = Int.bmod x.toNat (2^(min v w)) := by
|
||||
by_cases hv : v ≤ w
|
||||
· simp [toInt_signExtend_of_le hv, Nat.min_eq_left hv]
|
||||
· simp only [Nat.not_le] at hv
|
||||
rw [toInt_signExtend_of_lt hv, Nat.min_eq_right (by omega), toInt_eq_toNat_bmod]
|
||||
|
||||
/-! ### append -/
|
||||
|
||||
theorem append_def (x : BitVec v) (y : BitVec w) :
|
||||
@@ -2761,6 +2638,12 @@ theorem getElem_rotateLeft {x : BitVec w} {r i : Nat} (h : i < w) :
|
||||
if h' : i < r % w then x[(w - (r % w) + i)] else x[i - (r % w)] := by
|
||||
simp [← BitVec.getLsbD_eq_getElem, h]
|
||||
|
||||
/-- If `w ≤ x < 2 * w`, then `x % w = x - w` -/
|
||||
theorem mod_eq_sub_of_le_of_lt {x w : Nat} (x_le : w ≤ x) (x_lt : x < 2 * w) :
|
||||
x % w = x - w := by
|
||||
rw [Nat.mod_eq_sub_mod, Nat.mod_eq_of_lt (by omega)]
|
||||
omega
|
||||
|
||||
theorem getMsbD_rotateLeftAux_of_lt {x : BitVec w} {r : Nat} {i : Nat} (hi : i < w - r) :
|
||||
(x.rotateLeftAux r).getMsbD i = x.getMsbD (r + i) := by
|
||||
rw [rotateLeftAux, getMsbD_or]
|
||||
@@ -2770,20 +2653,6 @@ theorem getMsbD_rotateLeftAux_of_ge {x : BitVec w} {r : Nat} {i : Nat} (hi : i
|
||||
(x.rotateLeftAux r).getMsbD i = (decide (i < w) && x.getMsbD (i - (w - r))) := by
|
||||
simp [rotateLeftAux, getMsbD_or, show i + r ≥ w by omega, show ¬i < w - r by omega]
|
||||
|
||||
/--
|
||||
If a number `w * n ≤ i < w * (n + 1)`, then `i - w * n` equals `i % w`.
|
||||
This is true by subtracting `w * n` from the inequality, giving
|
||||
`0 ≤ i - w * n < w`, which uniquely identifies `i % w`.
|
||||
-/
|
||||
private theorem Nat.sub_mul_eq_mod_of_lt_of_le (hlo : w * n ≤ i) (hhi : i < w * (n + 1)) :
|
||||
i - w * n = i % w := by
|
||||
rw [Nat.mod_def]
|
||||
congr
|
||||
symm
|
||||
apply Nat.div_eq_of_lt_le
|
||||
(by rw [Nat.mul_comm]; omega)
|
||||
(by rw [Nat.mul_comm]; omega)
|
||||
|
||||
/-- When `r < w`, we give a formula for `(x.rotateLeft r).getMsbD i`. -/
|
||||
theorem getMsbD_rotateLeft_of_lt {n w : Nat} {x : BitVec w} (hi : r < w):
|
||||
(x.rotateLeft r).getMsbD n = (decide (n < w) && x.getMsbD ((r + n) % w)) := by
|
||||
@@ -2796,8 +2665,8 @@ theorem getMsbD_rotateLeft_of_lt {n w : Nat} {x : BitVec w} (hi : r < w):
|
||||
by_cases h₁ : n < w + 1
|
||||
· simp only [h₁, decide_true, Bool.true_and]
|
||||
have h₂ : (r + n) < 2 * (w + 1) := by omega
|
||||
rw [mod_eq_sub_of_le_of_lt (by omega) (by omega)]
|
||||
congr 1
|
||||
rw [← Nat.sub_mul_eq_mod_of_lt_of_le (n := 1) (by omega) (by omega), Nat.mul_one]
|
||||
omega
|
||||
· simp [h₁]
|
||||
|
||||
@@ -3114,6 +2983,20 @@ theorem replicate_succ_eq {x : BitVec w} :
|
||||
(x ++ replicate n x).cast (by rw [Nat.mul_succ]; omega) := by
|
||||
simp [replicate]
|
||||
|
||||
/--
|
||||
If a number `w * n ≤ i < w * (n + 1)`, then `i - w * n` equals `i % w`.
|
||||
This is true by subtracting `w * n` from the inequality, giving
|
||||
`0 ≤ i - w * n < w`, which uniquely identifies `i % w`.
|
||||
-/
|
||||
private theorem Nat.sub_mul_eq_mod_of_lt_of_le (hlo : w * n ≤ i) (hhi : i < w * (n + 1)) :
|
||||
i - w * n = i % w := by
|
||||
rw [Nat.mod_def]
|
||||
congr
|
||||
symm
|
||||
apply Nat.div_eq_of_lt_le
|
||||
(by rw [Nat.mul_comm]; omega)
|
||||
(by rw [Nat.mul_comm]; omega)
|
||||
|
||||
@[simp]
|
||||
theorem getLsbD_replicate {n w : Nat} (x : BitVec w) :
|
||||
(x.replicate n).getLsbD i =
|
||||
@@ -3219,11 +3102,6 @@ theorem toInt_neg_of_ne_intMin {x : BitVec w} (rs : x ≠ intMin w) :
|
||||
have := @Nat.two_pow_pred_mul_two w (by omega)
|
||||
split <;> split <;> omega
|
||||
|
||||
theorem toInt_neg_eq_ite {x : BitVec w} :
|
||||
(-x).toInt = if x = intMin w then x.toInt else -(x.toInt) := by
|
||||
by_cases hx : x = intMin w <;>
|
||||
simp [hx, neg_intMin, toInt_neg_of_ne_intMin]
|
||||
|
||||
theorem msb_intMin {w : Nat} : (intMin w).msb = decide (0 < w) := by
|
||||
simp only [msb_eq_decide, toNat_intMin, decide_eq_decide]
|
||||
by_cases h : 0 < w <;> simp_all
|
||||
@@ -3346,84 +3224,13 @@ theorem toNat_abs {x : BitVec w} : x.abs.toNat = if x.msb then 2^w - x.toNat els
|
||||
· simp [h]
|
||||
|
||||
theorem getLsbD_abs {i : Nat} {x : BitVec w} :
|
||||
getLsbD x.abs i = if x.msb then getLsbD (-x) i else getLsbD x i := by
|
||||
by_cases h : x.msb <;> simp [BitVec.abs, h]
|
||||
|
||||
theorem getElem_abs {i : Nat} {x : BitVec w} (h : i < w) :
|
||||
x.abs[i] = if x.msb then (-x)[i] else x[i] := by
|
||||
getLsbD x.abs i = if x.msb then getLsbD (-x) i else getLsbD x i := by
|
||||
by_cases h : x.msb <;> simp [BitVec.abs, h]
|
||||
|
||||
theorem getMsbD_abs {i : Nat} {x : BitVec w} :
|
||||
getMsbD (x.abs) i = if x.msb then getMsbD (-x) i else getMsbD x i := by
|
||||
by_cases h : x.msb <;> simp [BitVec.abs, h]
|
||||
|
||||
/-
|
||||
The absolute value of `x : BitVec w` is naively a case split on the sign of `x`.
|
||||
However, recall that when `x = intMin w`, `-x = x`.
|
||||
Thus, the full value of `abs x` is computed by the case split:
|
||||
- If `x : BitVec w` is `intMin`, then its absolute value is also `intMin w`, and
|
||||
thus `toInt` will equal `intMin.toInt`.
|
||||
- Otherwise, if `x` is negative, then `x.abs.toInt = (-x).toInt`.
|
||||
- If `x` is positive, then it is equal to `x.abs.toInt = x.toInt`.
|
||||
-/
|
||||
theorem toInt_abs_eq_ite {x : BitVec w} :
|
||||
x.abs.toInt =
|
||||
if x = intMin w then (intMin w).toInt
|
||||
else if x.msb then -x.toInt
|
||||
else x.toInt := by
|
||||
by_cases hx : x = intMin w
|
||||
· simp [hx]
|
||||
· simp [hx]
|
||||
by_cases hx₂ : x.msb
|
||||
· simp [hx₂, abs_eq, toInt_neg_of_ne_intMin hx]
|
||||
· simp [hx₂, abs_eq]
|
||||
|
||||
|
||||
|
||||
/-
|
||||
The absolute value of `x : BitVec w` is a case split on the sign of `x`, when `x ≠ intMin w`.
|
||||
This is a variant of `toInt_abs_eq_ite`.
|
||||
-/
|
||||
theorem toInt_abs_eq_ite_of_ne_intMin {x : BitVec w} (hx : x ≠ intMin w) :
|
||||
x.abs.toInt = if x.msb then -x.toInt else x.toInt := by
|
||||
simp [toInt_abs_eq_ite, hx]
|
||||
|
||||
|
||||
/--
|
||||
The absolute value of `x : BitVec w`, interpreted as an integer, is a case split:
|
||||
- When `x = intMin w`, then `x.abs = intMin w`
|
||||
- Otherwise, `x.abs.toInt` equals the absolute value (`x.toInt.natAbs`).
|
||||
|
||||
This is a simpler version of `BitVec.toInt_abs_eq_ite`, which hides a case split on `x.msb`.
|
||||
-/
|
||||
theorem toInt_abs_eq_natAbs {x : BitVec w} : x.abs.toInt =
|
||||
if x = intMin w then (intMin w).toInt else x.toInt.natAbs := by
|
||||
rw [toInt_abs_eq_ite]
|
||||
by_cases hx : x = intMin w
|
||||
· simp [hx]
|
||||
· simp [hx]
|
||||
by_cases h : x.msb
|
||||
· simp only [h, ↓reduceIte]
|
||||
have : x.toInt < 0 := by
|
||||
rw [toInt_neg_iff]
|
||||
have := msb_eq_true_iff_two_mul_ge.mp h
|
||||
omega
|
||||
omega
|
||||
· simp only [h, Bool.false_eq_true, ↓reduceIte]
|
||||
have : 0 ≤ x.toInt := by
|
||||
rw [toInt_pos_iff]
|
||||
exact msb_eq_false_iff_two_mul_lt.mp (by simp [h])
|
||||
omega
|
||||
|
||||
/-
|
||||
The absolute value of `(x : BitVec w)`, when interpreted as an integer,
|
||||
is the absolute value of `x.toInt` when `(x ≠ intMin)`.
|
||||
-/
|
||||
theorem toInt_abs_eq_natAbs_of_ne_intMin {x : BitVec w} (hx : x ≠ intMin w) :
|
||||
x.abs.toInt = x.toInt.natAbs := by
|
||||
simp [toInt_abs_eq_natAbs, hx]
|
||||
|
||||
|
||||
/-! ### Decidable quantifiers -/
|
||||
|
||||
theorem forall_zero_iff {P : BitVec 0 → Prop} :
|
||||
|
||||
@@ -108,18 +108,8 @@ def toList (bs : ByteArray) : List UInt8 :=
|
||||
|
||||
@[inline] def findIdx? (a : ByteArray) (p : UInt8 → Bool) (start := 0) : Option Nat :=
|
||||
let rec @[specialize] loop (i : Nat) :=
|
||||
if h : i < a.size then
|
||||
if p a[i] then some i else loop (i+1)
|
||||
else
|
||||
none
|
||||
termination_by a.size - i
|
||||
decreasing_by decreasing_trivial_pre_omega
|
||||
loop start
|
||||
|
||||
@[inline] def findFinIdx? (a : ByteArray) (p : UInt8 → Bool) (start := 0) : Option (Fin a.size) :=
|
||||
let rec @[specialize] loop (i : Nat) :=
|
||||
if h : i < a.size then
|
||||
if p a[i] then some ⟨i, h⟩ else loop (i+1)
|
||||
if i < a.size then
|
||||
if p (a.get! i) then some i else loop (i+1)
|
||||
else
|
||||
none
|
||||
termination_by a.size - i
|
||||
|
||||
@@ -13,17 +13,17 @@ namespace Fin
|
||||
/-- Folds over `Fin n` from the left: `foldl 3 f x = f (f (f x 0) 1) 2`. -/
|
||||
@[inline] def foldl (n) (f : α → Fin n → α) (init : α) : α := loop init 0 where
|
||||
/-- Inner loop for `Fin.foldl`. `Fin.foldl.loop n f x i = f (f (f x i) ...) (n-1)` -/
|
||||
@[semireducible] loop (x : α) (i : Nat) : α :=
|
||||
loop (x : α) (i : Nat) : α :=
|
||||
if h : i < n then loop (f x ⟨i, h⟩) (i+1) else x
|
||||
termination_by n - i
|
||||
decreasing_by decreasing_trivial_pre_omega
|
||||
|
||||
/-- Folds over `Fin n` from the right: `foldr 3 f x = f 0 (f 1 (f 2 x))`. -/
|
||||
@[inline] def foldr (n) (f : Fin n → α → α) (init : α) : α := loop n (Nat.le_refl n) init where
|
||||
@[inline] def foldr (n) (f : Fin n → α → α) (init : α) : α := loop ⟨n, Nat.le_refl n⟩ init where
|
||||
/-- Inner loop for `Fin.foldr`. `Fin.foldr.loop n f i x = f 0 (f ... (f (i-1) x))` -/
|
||||
loop : (i : _) → i ≤ n → α → α
|
||||
| 0, _, x => x
|
||||
| i+1, h, x => loop i (Nat.le_of_lt h) (f ⟨i, h⟩ x)
|
||||
termination_by structural i => i
|
||||
loop : {i // i ≤ n} → α → α
|
||||
| ⟨0, _⟩, x => x
|
||||
| ⟨i+1, h⟩, x => loop ⟨i, Nat.le_of_lt h⟩ (f ⟨i, h⟩ x)
|
||||
|
||||
/--
|
||||
Folds a monadic function over `Fin n` from left to right:
|
||||
@@ -176,19 +176,17 @@ theorem foldl_eq_foldlM (f : α → Fin n → α) (x) :
|
||||
/-! ### foldr -/
|
||||
|
||||
theorem foldr_loop_zero (f : Fin n → α → α) (x) :
|
||||
foldr.loop n f 0 (Nat.zero_le _) x = x := by
|
||||
foldr.loop n f ⟨0, Nat.zero_le _⟩ x = x := by
|
||||
rw [foldr.loop]
|
||||
|
||||
theorem foldr_loop_succ (f : Fin n → α → α) (x) (h : i < n) :
|
||||
foldr.loop n f (i+1) h x = foldr.loop n f i (Nat.le_of_lt h) (f ⟨i, h⟩ x) := by
|
||||
foldr.loop n f ⟨i+1, h⟩ x = foldr.loop n f ⟨i, Nat.le_of_lt h⟩ (f ⟨i, h⟩ x) := by
|
||||
rw [foldr.loop]
|
||||
|
||||
theorem foldr_loop (f : Fin (n+1) → α → α) (x) (h : i+1 ≤ n+1) :
|
||||
foldr.loop (n+1) f (i+1) h x =
|
||||
f 0 (foldr.loop n (fun j => f j.succ) i (Nat.le_of_succ_le_succ h) x) := by
|
||||
induction i generalizing x with
|
||||
| zero => simp [foldr_loop_succ, foldr_loop_zero]
|
||||
| succ i ih => rw [foldr_loop_succ, ih]; rfl
|
||||
foldr.loop (n+1) f ⟨i+1, h⟩ x =
|
||||
f 0 (foldr.loop n (fun j => f j.succ) ⟨i, Nat.le_of_succ_le_succ h⟩ x) := by
|
||||
induction i generalizing x <;> simp [foldr_loop_zero, foldr_loop_succ, *]
|
||||
|
||||
@[simp] theorem foldr_zero (f : Fin 0 → α → α) (x) : foldr 0 f x = x :=
|
||||
foldr_loop_zero ..
|
||||
|
||||
@@ -26,4 +26,3 @@ import Init.Data.List.Sort
|
||||
import Init.Data.List.ToArray
|
||||
import Init.Data.List.MapIdx
|
||||
import Init.Data.List.OfFn
|
||||
import Init.Data.List.FinRange
|
||||
|
||||
@@ -231,7 +231,7 @@ theorem ext_get? : ∀ {l₁ l₂ : List α}, (∀ n, l₁.get? n = l₂.get? n)
|
||||
injection h0 with aa; simp only [aa, ext_get? fun n => h (n+1)]
|
||||
|
||||
/-- Deprecated alias for `ext_get?`. The preferred extensionality theorem is now `ext_getElem?`. -/
|
||||
@[deprecated ext_get? (since := "2024-06-07")] abbrev ext := @ext_get?
|
||||
@[deprecated (since := "2024-06-07")] abbrev ext := @ext_get?
|
||||
|
||||
/-! ### getD -/
|
||||
|
||||
@@ -682,7 +682,7 @@ theorem elem_cons [BEq α] {a : α} :
|
||||
(b::bs).elem a = match a == b with | true => true | false => bs.elem a := rfl
|
||||
|
||||
/-- `notElem a l` is `!(elem a l)`. -/
|
||||
@[deprecated "Use `!(elem a l)` instead."(since := "2024-06-15")]
|
||||
@[deprecated (since := "2024-06-15")]
|
||||
def notElem [BEq α] (a : α) (as : List α) : Bool :=
|
||||
!(as.elem a)
|
||||
|
||||
@@ -1427,10 +1427,10 @@ def zipWithAll (f : Option α → Option β → γ) : List α → List β → Li
|
||||
| a :: as, [] => (a :: as).map fun a => f (some a) none
|
||||
| a :: as, b :: bs => f a b :: zipWithAll f as bs
|
||||
|
||||
@[simp] theorem zipWithAll_nil :
|
||||
@[simp] theorem zipWithAll_nil_right :
|
||||
zipWithAll f as [] = as.map fun a => f (some a) none := by
|
||||
cases as <;> rfl
|
||||
@[simp] theorem nil_zipWithAll :
|
||||
@[simp] theorem zipWithAll_nil_left :
|
||||
zipWithAll f [] bs = bs.map fun b => f none (some b) := rfl
|
||||
@[simp] theorem zipWithAll_cons_cons :
|
||||
zipWithAll f (a :: as) (b :: bs) = f (some a) (some b) :: zipWithAll f as bs := rfl
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
/-
|
||||
Copyright (c) 2024 François G. Dorais. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: François G. Dorais
|
||||
-/
|
||||
prelude
|
||||
import Init.Data.List.OfFn
|
||||
|
||||
namespace List
|
||||
|
||||
/-- `finRange n` lists all elements of `Fin n` in order -/
|
||||
def finRange (n : Nat) : List (Fin n) := ofFn fun i => i
|
||||
|
||||
@[simp] theorem length_finRange (n) : (List.finRange n).length = n := by
|
||||
simp [List.finRange]
|
||||
|
||||
@[simp] theorem getElem_finRange (i : Nat) (h : i < (List.finRange n).length) :
|
||||
(finRange n)[i] = Fin.cast (length_finRange n) ⟨i, h⟩ := by
|
||||
simp [List.finRange]
|
||||
|
||||
@[simp] theorem finRange_zero : finRange 0 = [] := by simp [finRange, ofFn]
|
||||
|
||||
theorem finRange_succ (n) : finRange (n+1) = 0 :: (finRange n).map Fin.succ := by
|
||||
apply List.ext_getElem; simp; intro i; cases i <;> simp
|
||||
|
||||
theorem finRange_succ_last (n) :
|
||||
finRange (n+1) = (finRange n).map Fin.castSucc ++ [Fin.last n] := by
|
||||
apply List.ext_getElem
|
||||
· simp
|
||||
· intros
|
||||
simp only [List.finRange, List.getElem_ofFn, getElem_append, length_map, length_ofFn,
|
||||
getElem_map, Fin.castSucc_mk, getElem_singleton]
|
||||
split
|
||||
· rfl
|
||||
· next h => exact Fin.eq_last_of_not_lt h
|
||||
|
||||
theorem finRange_reverse (n) : (finRange n).reverse = (finRange n).map Fin.rev := by
|
||||
induction n with
|
||||
| zero => simp
|
||||
| succ n ih =>
|
||||
conv => lhs; rw [finRange_succ_last]
|
||||
conv => rhs; rw [finRange_succ]
|
||||
rw [reverse_append, reverse_cons, reverse_nil, nil_append, singleton_append, ← map_reverse,
|
||||
map_cons, ih, map_map, map_map]
|
||||
congr; funext
|
||||
simp [Fin.rev_succ]
|
||||
|
||||
end List
|
||||
@@ -101,7 +101,7 @@ theorem tail_eq_of_cons_eq (H : h₁ :: t₁ = h₂ :: t₂) : t₁ = t₂ := (c
|
||||
theorem cons_inj_right (a : α) {l l' : List α} : a :: l = a :: l' ↔ l = l' :=
|
||||
⟨tail_eq_of_cons_eq, congrArg _⟩
|
||||
|
||||
@[deprecated cons_inj_right (since := "2024-06-15")] abbrev cons_inj := @cons_inj_right
|
||||
@[deprecated (since := "2024-06-15")] abbrev cons_inj := @cons_inj_right
|
||||
|
||||
theorem cons_eq_cons {a b : α} {l l' : List α} : a :: l = b :: l' ↔ a = b ∧ l = l' :=
|
||||
List.cons.injEq .. ▸ .rfl
|
||||
@@ -171,7 +171,7 @@ theorem get_cons_succ {as : List α} {h : i + 1 < (a :: as).length} :
|
||||
theorem get_cons_succ' {as : List α} {i : Fin as.length} :
|
||||
(a :: as).get i.succ = as.get i := rfl
|
||||
|
||||
@[deprecated "Deprecated without replacement." (since := "2024-07-09")]
|
||||
@[deprecated (since := "2024-07-09")]
|
||||
theorem get_cons_cons_one : (a₁ :: a₂ :: as).get (1 : Fin (as.length + 2)) = a₂ := rfl
|
||||
|
||||
theorem get_mk_zero : ∀ {l : List α} (h : 0 < l.length), l.get ⟨0, h⟩ = l.head (length_pos.mp h)
|
||||
@@ -791,24 +791,6 @@ theorem mem_or_eq_of_mem_set : ∀ {l : List α} {n : Nat} {a b : α}, a ∈ l.s
|
||||
· intro a
|
||||
simp
|
||||
|
||||
@[simp] theorem beq_nil_iff [BEq α] {l : List α} : (l == []) = l.isEmpty := by
|
||||
cases l <;> rfl
|
||||
|
||||
@[simp] theorem nil_beq_iff [BEq α] {l : List α} : ([] == l) = l.isEmpty := by
|
||||
cases l <;> rfl
|
||||
|
||||
@[simp] theorem cons_beq_cons [BEq α] {a b : α} {l₁ l₂ : List α} :
|
||||
(a :: l₁ == b :: l₂) = (a == b && l₁ == l₂) := rfl
|
||||
|
||||
theorem length_eq_of_beq [BEq α] {l₁ l₂ : List α} (h : l₁ == l₂) : l₁.length = l₂.length :=
|
||||
match l₁, l₂ with
|
||||
| [], [] => rfl
|
||||
| [], _ :: _ => by simp [beq_nil_iff] at h
|
||||
| _ :: _, [] => by simp [nil_beq_iff] at h
|
||||
| a :: l₁, b :: l₂ => by
|
||||
simp at h
|
||||
simpa [Nat.add_one_inj]using length_eq_of_beq h.2
|
||||
|
||||
/-! ### Lexicographic ordering -/
|
||||
|
||||
protected theorem lt_irrefl [LT α] (lt_irrefl : ∀ x : α, ¬x < x) (l : List α) : ¬l < l := by
|
||||
@@ -874,12 +856,6 @@ theorem foldr_eq_foldrM (f : α → β → β) (b) (l : List α) :
|
||||
l.foldr f b = l.foldrM (m := Id) f b := by
|
||||
induction l <;> simp [*, foldr]
|
||||
|
||||
@[simp] theorem id_run_foldlM (f : β → α → Id β) (b) (l : List α) :
|
||||
Id.run (l.foldlM f b) = l.foldl f b := (foldl_eq_foldlM f b l).symm
|
||||
|
||||
@[simp] theorem id_run_foldrM (f : α → β → Id β) (b) (l : List α) :
|
||||
Id.run (l.foldrM f b) = l.foldr f b := (foldr_eq_foldrM f b l).symm
|
||||
|
||||
/-! ### foldl and foldr -/
|
||||
|
||||
@[simp] theorem foldr_cons_eq_append (l : List α) : l.foldr cons l' = l ++ l' := by
|
||||
@@ -1824,7 +1800,7 @@ theorem getElem_append_right' (l₁ : List α) {l₂ : List α} {n : Nat} (hn :
|
||||
l₂[n] = (l₁ ++ l₂)[n + l₁.length]'(by simpa [Nat.add_comm] using Nat.add_lt_add_left hn _) := by
|
||||
rw [getElem_append_right] <;> simp [*, le_add_left]
|
||||
|
||||
@[deprecated "Deprecated without replacement." (since := "2024-06-12")]
|
||||
@[deprecated (since := "2024-06-12")]
|
||||
theorem get_append_right_aux {l₁ l₂ : List α} {n : Nat}
|
||||
(h₁ : l₁.length ≤ n) (h₂ : n < (l₁ ++ l₂).length) : n - l₁.length < l₂.length := by
|
||||
rw [length_append] at h₂
|
||||
@@ -1841,7 +1817,7 @@ theorem getElem_of_append {l : List α} (eq : l = l₁ ++ a :: l₂) (h : l₁.l
|
||||
rw [← getElem?_eq_getElem, eq, getElem?_append_right (h ▸ Nat.le_refl _), h]
|
||||
simp
|
||||
|
||||
@[deprecated "Deprecated without replacement." (since := "2024-06-12")]
|
||||
@[deprecated (since := "2024-06-12")]
|
||||
theorem get_of_append_proof {l : List α}
|
||||
(eq : l = l₁ ++ a :: l₂) (h : l₁.length = n) : n < length l := eq ▸ h ▸ by simp_arith
|
||||
|
||||
@@ -3357,10 +3333,10 @@ theorem any_eq_not_all_not (l : List α) (p : α → Bool) : l.any p = !l.all (!
|
||||
theorem all_eq_not_any_not (l : List α) (p : α → Bool) : l.all p = !l.any (!p .) := by
|
||||
simp only [not_any_eq_all_not, Bool.not_not]
|
||||
|
||||
@[simp] theorem any_map {l : List α} {p : β → Bool} : (l.map f).any p = l.any (p ∘ f) := by
|
||||
@[simp] theorem any_map {l : List α} {p : α → Bool} : (l.map f).any p = l.any (p ∘ f) := by
|
||||
induction l with simp | cons _ _ ih => rw [ih]
|
||||
|
||||
@[simp] theorem all_map {l : List α} {p : β → Bool} : (l.map f).all p = l.all (p ∘ f) := by
|
||||
@[simp] theorem all_map {l : List α} {p : α → Bool} : (l.map f).all p = l.all (p ∘ f) := by
|
||||
induction l with simp | cons _ _ ih => rw [ih]
|
||||
|
||||
@[simp] theorem any_filter {l : List α} {p q : α → Bool} :
|
||||
|
||||
@@ -9,7 +9,7 @@ import Init.Data.List.Basic
|
||||
|
||||
namespace List
|
||||
|
||||
/-! ### isEqv -/
|
||||
/-! ### isEqv-/
|
||||
|
||||
theorem isEqv_eq_decide (a b : List α) (r) :
|
||||
isEqv a b r = if h : a.length = b.length then
|
||||
|
||||
@@ -293,7 +293,7 @@ theorem sorted_mergeSort
|
||||
apply sorted_mergeSort trans total
|
||||
termination_by l => l.length
|
||||
|
||||
@[deprecated sorted_mergeSort (since := "2024-09-02")] abbrev mergeSort_sorted := @sorted_mergeSort
|
||||
@[deprecated (since := "2024-09-02")] abbrev mergeSort_sorted := @sorted_mergeSort
|
||||
|
||||
/--
|
||||
If the input list is already sorted, then `mergeSort` does not change the list.
|
||||
@@ -429,8 +429,7 @@ theorem sublist_mergeSort
|
||||
((fun w => Sublist.of_sublist_append_right w h') fun b m₁ m₃ =>
|
||||
(Bool.eq_not_self true).mp ((rel_of_pairwise_cons hc m₁).symm.trans (h₃ b m₃))))
|
||||
|
||||
@[deprecated sublist_mergeSort (since := "2024-09-02")]
|
||||
abbrev mergeSort_stable := @sublist_mergeSort
|
||||
@[deprecated (since := "2024-09-02")] abbrev mergeSort_stable := @sublist_mergeSort
|
||||
|
||||
/--
|
||||
Another statement of stability of merge sort.
|
||||
@@ -443,8 +442,7 @@ theorem pair_sublist_mergeSort
|
||||
(hab : le a b) (h : [a, b] <+ l) : [a, b] <+ mergeSort l le :=
|
||||
sublist_mergeSort trans total (pairwise_pair.mpr hab) h
|
||||
|
||||
@[deprecated pair_sublist_mergeSort(since := "2024-09-02")]
|
||||
abbrev mergeSort_stable_pair := @pair_sublist_mergeSort
|
||||
@[deprecated (since := "2024-09-02")] abbrev mergeSort_stable_pair := @pair_sublist_mergeSort
|
||||
|
||||
theorem map_merge {f : α → β} {r : α → α → Bool} {s : β → β → Bool} {l l' : List α}
|
||||
(hl : ∀ a ∈ l, ∀ b ∈ l', r a b = s (f a) (f b)) :
|
||||
|
||||
@@ -835,7 +835,7 @@ theorem isPrefix_iff : l₁ <+: l₂ ↔ ∀ i (h : i < l₁.length), l₂[i]? =
|
||||
simpa using ⟨0, by simp⟩
|
||||
| cons b l₂ =>
|
||||
simp only [cons_append, cons_prefix_cons, ih]
|
||||
rw (occs := [2]) [← Nat.and_forall_add_one]
|
||||
rw (occs := .pos [2]) [← Nat.and_forall_add_one]
|
||||
simp [Nat.succ_lt_succ_iff, eq_comm]
|
||||
|
||||
theorem isPrefix_iff_getElem {l₁ l₂ : List α} :
|
||||
|
||||
@@ -224,7 +224,7 @@ theorem take_succ {l : List α} {n : Nat} : l.take (n + 1) = l.take n ++ l[n]?.t
|
||||
· simp only [take, Option.toList, getElem?_cons_zero, nil_append]
|
||||
· simp only [take, hl, getElem?_cons_succ, cons_append]
|
||||
|
||||
@[deprecated "Deprecated without replacement." (since := "2024-07-25")]
|
||||
@[deprecated (since := "2024-07-25")]
|
||||
theorem drop_sizeOf_le [SizeOf α] (l : List α) (n : Nat) : sizeOf (l.drop n) ≤ sizeOf l := by
|
||||
induction l generalizing n with
|
||||
| nil => rw [drop_nil]; apply Nat.le_refl
|
||||
|
||||
@@ -20,4 +20,3 @@ import Init.Data.Nat.Mod
|
||||
import Init.Data.Nat.Lcm
|
||||
import Init.Data.Nat.Compare
|
||||
import Init.Data.Nat.Simproc
|
||||
import Init.Data.Nat.Fold
|
||||
|
||||
@@ -35,6 +35,52 @@ Used as the default `Nat` eliminator by the `cases` tactic. -/
|
||||
protected abbrev casesAuxOn {motive : Nat → Sort u} (t : Nat) (zero : motive 0) (succ : (n : Nat) → motive (n + 1)) : motive t :=
|
||||
Nat.casesOn t zero succ
|
||||
|
||||
/--
|
||||
`Nat.fold` evaluates `f` on the numbers up to `n` exclusive, in increasing order:
|
||||
* `Nat.fold f 3 init = init |> f 0 |> f 1 |> f 2`
|
||||
-/
|
||||
@[specialize] def fold {α : Type u} (f : Nat → α → α) : (n : Nat) → (init : α) → α
|
||||
| 0, a => a
|
||||
| succ n, a => f n (fold f n a)
|
||||
|
||||
/-- Tail-recursive version of `Nat.fold`. -/
|
||||
@[inline] def foldTR {α : Type u} (f : Nat → α → α) (n : Nat) (init : α) : α :=
|
||||
let rec @[specialize] loop
|
||||
| 0, a => a
|
||||
| succ m, a => loop m (f (n - succ m) a)
|
||||
loop n init
|
||||
|
||||
/--
|
||||
`Nat.foldRev` evaluates `f` on the numbers up to `n` exclusive, in decreasing order:
|
||||
* `Nat.foldRev f 3 init = f 0 <| f 1 <| f 2 <| init`
|
||||
-/
|
||||
@[specialize] def foldRev {α : Type u} (f : Nat → α → α) : (n : Nat) → (init : α) → α
|
||||
| 0, a => a
|
||||
| succ n, a => foldRev f n (f n a)
|
||||
|
||||
/-- `any f n = true` iff there is `i in [0, n-1]` s.t. `f i = true` -/
|
||||
@[specialize] def any (f : Nat → Bool) : Nat → Bool
|
||||
| 0 => false
|
||||
| succ n => any f n || f n
|
||||
|
||||
/-- Tail-recursive version of `Nat.any`. -/
|
||||
@[inline] def anyTR (f : Nat → Bool) (n : Nat) : Bool :=
|
||||
let rec @[specialize] loop : Nat → Bool
|
||||
| 0 => false
|
||||
| succ m => f (n - succ m) || loop m
|
||||
loop n
|
||||
|
||||
/-- `all f n = true` iff every `i in [0, n-1]` satisfies `f i = true` -/
|
||||
@[specialize] def all (f : Nat → Bool) : Nat → Bool
|
||||
| 0 => true
|
||||
| succ n => all f n && f n
|
||||
|
||||
/-- Tail-recursive version of `Nat.all`. -/
|
||||
@[inline] def allTR (f : Nat → Bool) (n : Nat) : Bool :=
|
||||
let rec @[specialize] loop : Nat → Bool
|
||||
| 0 => true
|
||||
| succ m => f (n - succ m) && loop m
|
||||
loop n
|
||||
|
||||
/--
|
||||
`Nat.repeat f n a` is `f^(n) a`; that is, it iterates `f` `n` times on `a`.
|
||||
@@ -789,7 +835,7 @@ theorem pred_lt_of_lt {n m : Nat} (h : m < n) : pred n < n :=
|
||||
pred_lt (not_eq_zero_of_lt h)
|
||||
|
||||
set_option linter.missingDocs false in
|
||||
@[deprecated pred_lt_of_lt (since := "2024-06-01")] abbrev pred_lt' := @pred_lt_of_lt
|
||||
@[deprecated (since := "2024-06-01")] abbrev pred_lt' := @pred_lt_of_lt
|
||||
|
||||
theorem sub_one_lt_of_lt {n m : Nat} (h : m < n) : n - 1 < n :=
|
||||
sub_one_lt (not_eq_zero_of_lt h)
|
||||
@@ -1075,7 +1121,7 @@ theorem pred_mul (n m : Nat) : pred n * m = n * m - m := by
|
||||
| succ n => rw [Nat.pred_succ, succ_mul, Nat.add_sub_cancel]
|
||||
|
||||
set_option linter.missingDocs false in
|
||||
@[deprecated pred_mul (since := "2024-06-01")] abbrev mul_pred_left := @pred_mul
|
||||
@[deprecated (since := "2024-06-01")] abbrev mul_pred_left := @pred_mul
|
||||
|
||||
protected theorem sub_one_mul (n m : Nat) : (n - 1) * m = n * m - m := by
|
||||
cases n with
|
||||
@@ -1087,7 +1133,7 @@ theorem mul_pred (n m : Nat) : n * pred m = n * m - n := by
|
||||
rw [Nat.mul_comm, pred_mul, Nat.mul_comm]
|
||||
|
||||
set_option linter.missingDocs false in
|
||||
@[deprecated mul_pred (since := "2024-06-01")] abbrev mul_pred_right := @mul_pred
|
||||
@[deprecated (since := "2024-06-01")] abbrev mul_pred_right := @mul_pred
|
||||
|
||||
theorem mul_sub_one (n m : Nat) : n * (m - 1) = n * m - n := by
|
||||
rw [Nat.mul_comm, Nat.sub_one_mul , Nat.mul_comm]
|
||||
@@ -1112,6 +1158,33 @@ theorem not_lt_eq (a b : Nat) : (¬ (a < b)) = (b ≤ a) :=
|
||||
theorem not_gt_eq (a b : Nat) : (¬ (a > b)) = (a ≤ b) :=
|
||||
not_lt_eq b a
|
||||
|
||||
/-! # csimp theorems -/
|
||||
|
||||
@[csimp] theorem fold_eq_foldTR : @fold = @foldTR :=
|
||||
funext fun α => funext fun f => funext fun n => funext fun init =>
|
||||
let rec go : ∀ m n, foldTR.loop f (m + n) m (fold f n init) = fold f (m + n) init
|
||||
| 0, n => by simp [foldTR.loop]
|
||||
| succ m, n => by rw [foldTR.loop, add_sub_self_left, succ_add]; exact go m (succ n)
|
||||
(go n 0).symm
|
||||
|
||||
@[csimp] theorem any_eq_anyTR : @any = @anyTR :=
|
||||
funext fun f => funext fun n =>
|
||||
let rec go : ∀ m n, (any f n || anyTR.loop f (m + n) m) = any f (m + n)
|
||||
| 0, n => by simp [anyTR.loop]
|
||||
| succ m, n => by
|
||||
rw [anyTR.loop, add_sub_self_left, ← Bool.or_assoc, succ_add]
|
||||
exact go m (succ n)
|
||||
(go n 0).symm
|
||||
|
||||
@[csimp] theorem all_eq_allTR : @all = @allTR :=
|
||||
funext fun f => funext fun n =>
|
||||
let rec go : ∀ m n, (all f n && allTR.loop f (m + n) m) = all f (m + n)
|
||||
| 0, n => by simp [allTR.loop]
|
||||
| succ m, n => by
|
||||
rw [allTR.loop, add_sub_self_left, ← Bool.and_assoc, succ_add]
|
||||
exact go m (succ n)
|
||||
(go n 0).symm
|
||||
|
||||
@[csimp] theorem repeat_eq_repeatTR : @repeat = @repeatTR :=
|
||||
funext fun α => funext fun f => funext fun n => funext fun init =>
|
||||
let rec go : ∀ m n, repeatTR.loop f m (repeat f n init) = repeat f (m + n) init
|
||||
@@ -1120,3 +1193,31 @@ theorem not_gt_eq (a b : Nat) : (¬ (a > b)) = (a ≤ b) :=
|
||||
(go n 0).symm
|
||||
|
||||
end Nat
|
||||
|
||||
namespace Prod
|
||||
|
||||
/--
|
||||
`(start, stop).foldI f a` evaluates `f` on all the numbers
|
||||
from `start` (inclusive) to `stop` (exclusive) in increasing order:
|
||||
* `(5, 8).foldI f init = init |> f 5 |> f 6 |> f 7`
|
||||
-/
|
||||
@[inline] def foldI {α : Type u} (f : Nat → α → α) (i : Nat × Nat) (a : α) : α :=
|
||||
Nat.foldTR.loop f i.2 (i.2 - i.1) a
|
||||
|
||||
/--
|
||||
`(start, stop).anyI f a` returns true if `f` is true for some natural number
|
||||
from `start` (inclusive) to `stop` (exclusive):
|
||||
* `(5, 8).anyI f = f 5 || f 6 || f 7`
|
||||
-/
|
||||
@[inline] def anyI (f : Nat → Bool) (i : Nat × Nat) : Bool :=
|
||||
Nat.anyTR.loop f i.2 (i.2 - i.1)
|
||||
|
||||
/--
|
||||
`(start, stop).allI f a` returns true if `f` is true for all natural numbers
|
||||
from `start` (inclusive) to `stop` (exclusive):
|
||||
* `(5, 8).anyI f = f 5 && f 6 && f 7`
|
||||
-/
|
||||
@[inline] def allI (f : Nat → Bool) (i : Nat × Nat) : Bool :=
|
||||
Nat.allTR.loop f i.2 (i.2 - i.1)
|
||||
|
||||
end Prod
|
||||
|
||||
@@ -6,51 +6,50 @@ Author: Leonardo de Moura
|
||||
prelude
|
||||
import Init.Control.Basic
|
||||
import Init.Data.Nat.Basic
|
||||
import Init.Omega
|
||||
|
||||
namespace Nat
|
||||
universe u v
|
||||
|
||||
@[inline] def forM {m} [Monad m] (n : Nat) (f : (i : Nat) → i < n → m Unit) : m Unit :=
|
||||
let rec @[specialize] loop : ∀ i, i ≤ n → m Unit
|
||||
| 0, _ => pure ()
|
||||
| i+1, h => do f (n-i-1) (by omega); loop i (Nat.le_of_succ_le h)
|
||||
loop n (by simp)
|
||||
@[inline] def forM {m} [Monad m] (n : Nat) (f : Nat → m Unit) : m Unit :=
|
||||
let rec @[specialize] loop
|
||||
| 0 => pure ()
|
||||
| i+1 => do f (n-i-1); loop i
|
||||
loop n
|
||||
|
||||
@[inline] def forRevM {m} [Monad m] (n : Nat) (f : (i : Nat) → i < n → m Unit) : m Unit :=
|
||||
let rec @[specialize] loop : ∀ i, i ≤ n → m Unit
|
||||
| 0, _ => pure ()
|
||||
| i+1, h => do f i (by omega); loop i (Nat.le_of_succ_le h)
|
||||
loop n (by simp)
|
||||
@[inline] def forRevM {m} [Monad m] (n : Nat) (f : Nat → m Unit) : m Unit :=
|
||||
let rec @[specialize] loop
|
||||
| 0 => pure ()
|
||||
| i+1 => do f i; loop i
|
||||
loop n
|
||||
|
||||
@[inline] def foldM {α : Type u} {m : Type u → Type v} [Monad m] (n : Nat) (f : (i : Nat) → i < n → α → m α) (init : α) : m α :=
|
||||
let rec @[specialize] loop : ∀ i, i ≤ n → α → m α
|
||||
| 0, h, a => pure a
|
||||
| i+1, h, a => f (n-i-1) (by omega) a >>= loop i (Nat.le_of_succ_le h)
|
||||
loop n (by omega) init
|
||||
@[inline] def foldM {α : Type u} {m : Type u → Type v} [Monad m] (f : Nat → α → m α) (init : α) (n : Nat) : m α :=
|
||||
let rec @[specialize] loop
|
||||
| 0, a => pure a
|
||||
| i+1, a => f (n-i-1) a >>= loop i
|
||||
loop n init
|
||||
|
||||
@[inline] def foldRevM {α : Type u} {m : Type u → Type v} [Monad m] (n : Nat) (f : (i : Nat) → i < n → α → m α) (init : α) : m α :=
|
||||
let rec @[specialize] loop : ∀ i, i ≤ n → α → m α
|
||||
| 0, h, a => pure a
|
||||
| i+1, h, a => f i (by omega) a >>= loop i (Nat.le_of_succ_le h)
|
||||
loop n (by omega) init
|
||||
@[inline] def foldRevM {α : Type u} {m : Type u → Type v} [Monad m] (f : Nat → α → m α) (init : α) (n : Nat) : m α :=
|
||||
let rec @[specialize] loop
|
||||
| 0, a => pure a
|
||||
| i+1, a => f i a >>= loop i
|
||||
loop n init
|
||||
|
||||
@[inline] def allM {m} [Monad m] (n : Nat) (p : (i : Nat) → i < n → m Bool) : m Bool :=
|
||||
let rec @[specialize] loop : ∀ i, i ≤ n → m Bool
|
||||
| 0, _ => pure true
|
||||
| i+1 , h => do
|
||||
match (← p (n-i-1) (by omega)) with
|
||||
| true => loop i (by omega)
|
||||
@[inline] def allM {m} [Monad m] (n : Nat) (p : Nat → m Bool) : m Bool :=
|
||||
let rec @[specialize] loop
|
||||
| 0 => pure true
|
||||
| i+1 => do
|
||||
match (← p (n-i-1)) with
|
||||
| true => loop i
|
||||
| false => pure false
|
||||
loop n (by simp)
|
||||
loop n
|
||||
|
||||
@[inline] def anyM {m} [Monad m] (n : Nat) (p : (i : Nat) → i < n → m Bool) : m Bool :=
|
||||
let rec @[specialize] loop : ∀ i, i ≤ n → m Bool
|
||||
| 0, _ => pure false
|
||||
| i+1, h => do
|
||||
match (← p (n-i-1) (by omega)) with
|
||||
@[inline] def anyM {m} [Monad m] (n : Nat) (p : Nat → m Bool) : m Bool :=
|
||||
let rec @[specialize] loop
|
||||
| 0 => pure false
|
||||
| i+1 => do
|
||||
match (← p (n-i-1)) with
|
||||
| true => pure true
|
||||
| false => loop i (Nat.le_of_succ_le h)
|
||||
loop n (by simp)
|
||||
| false => loop i
|
||||
loop n
|
||||
|
||||
end Nat
|
||||
|
||||
@@ -92,7 +92,7 @@ protected theorem div_mul_cancel {n m : Nat} (H : n ∣ m) : m / n * n = m := by
|
||||
rw [Nat.mul_comm, Nat.mul_div_cancel' H]
|
||||
|
||||
@[simp] theorem mod_mod_of_dvd (a : Nat) (h : c ∣ b) : a % b % c = a % c := by
|
||||
rw (occs := [2]) [← mod_add_div a b]
|
||||
rw (occs := .pos [2]) [← mod_add_div a b]
|
||||
have ⟨x, h⟩ := h
|
||||
subst h
|
||||
rw [Nat.mul_assoc, add_mul_mod_self_left]
|
||||
|
||||
@@ -1,217 +0,0 @@
|
||||
/-
|
||||
Copyright (c) 2014 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Floris van Doorn, Leonardo de Moura, Kim Morrison
|
||||
-/
|
||||
prelude
|
||||
import Init.Omega
|
||||
import Init.Data.List.FinRange
|
||||
|
||||
set_option linter.missingDocs true -- keep it documented
|
||||
universe u
|
||||
|
||||
namespace Nat
|
||||
|
||||
/--
|
||||
`Nat.fold` evaluates `f` on the numbers up to `n` exclusive, in increasing order:
|
||||
* `Nat.fold f 3 init = init |> f 0 |> f 1 |> f 2`
|
||||
-/
|
||||
@[specialize] def fold {α : Type u} : (n : Nat) → (f : (i : Nat) → i < n → α → α) → (init : α) → α
|
||||
| 0, f, a => a
|
||||
| succ n, f, a => f n (by omega) (fold n (fun i h => f i (by omega)) a)
|
||||
|
||||
/-- Tail-recursive version of `Nat.fold`. -/
|
||||
@[inline] def foldTR {α : Type u} (n : Nat) (f : (i : Nat) → i < n → α → α) (init : α) : α :=
|
||||
let rec @[specialize] loop : ∀ j, j ≤ n → α → α
|
||||
| 0, h, a => a
|
||||
| succ m, h, a => loop m (by omega) (f (n - succ m) (by omega) a)
|
||||
loop n (by omega) init
|
||||
|
||||
/--
|
||||
`Nat.foldRev` evaluates `f` on the numbers up to `n` exclusive, in decreasing order:
|
||||
* `Nat.foldRev f 3 init = f 0 <| f 1 <| f 2 <| init`
|
||||
-/
|
||||
@[specialize] def foldRev {α : Type u} : (n : Nat) → (f : (i : Nat) → i < n → α → α) → (init : α) → α
|
||||
| 0, f, a => a
|
||||
| succ n, f, a => foldRev n (fun i h => f i (by omega)) (f n (by omega) a)
|
||||
|
||||
/-- `any f n = true` iff there is `i in [0, n-1]` s.t. `f i = true` -/
|
||||
@[specialize] def any : (n : Nat) → (f : (i : Nat) → i < n → Bool) → Bool
|
||||
| 0, f => false
|
||||
| succ n, f => any n (fun i h => f i (by omega)) || f n (by omega)
|
||||
|
||||
/-- Tail-recursive version of `Nat.any`. -/
|
||||
@[inline] def anyTR (n : Nat) (f : (i : Nat) → i < n → Bool) : Bool :=
|
||||
let rec @[specialize] loop : (i : Nat) → i ≤ n → Bool
|
||||
| 0, h => false
|
||||
| succ m, h => f (n - succ m) (by omega) || loop m (by omega)
|
||||
loop n (by omega)
|
||||
|
||||
/-- `all f n = true` iff every `i in [0, n-1]` satisfies `f i = true` -/
|
||||
@[specialize] def all : (n : Nat) → (f : (i : Nat) → i < n → Bool) → Bool
|
||||
| 0, f => true
|
||||
| succ n, f => all n (fun i h => f i (by omega)) && f n (by omega)
|
||||
|
||||
/-- Tail-recursive version of `Nat.all`. -/
|
||||
@[inline] def allTR (n : Nat) (f : (i : Nat) → i < n → Bool) : Bool :=
|
||||
let rec @[specialize] loop : (i : Nat) → i ≤ n → Bool
|
||||
| 0, h => true
|
||||
| succ m, h => f (n - succ m) (by omega) && loop m (by omega)
|
||||
loop n (by omega)
|
||||
|
||||
/-! # csimp theorems -/
|
||||
|
||||
theorem fold_congr {α : Type u} {n m : Nat} (w : n = m)
|
||||
(f : (i : Nat) → i < n → α → α) (init : α) :
|
||||
fold n f init = fold m (fun i h => f i (by omega)) init := by
|
||||
subst m
|
||||
rfl
|
||||
|
||||
theorem foldTR_loop_congr {α : Type u} {n m : Nat} (w : n = m)
|
||||
(f : (i : Nat) → i < n → α → α) (j : Nat) (h : j ≤ n) (init : α) :
|
||||
foldTR.loop n f j h init = foldTR.loop m (fun i h => f i (by omega)) j (by omega) init := by
|
||||
subst m
|
||||
rfl
|
||||
|
||||
@[csimp] theorem fold_eq_foldTR : @fold = @foldTR :=
|
||||
funext fun α => funext fun n => funext fun f => funext fun init =>
|
||||
let rec go : ∀ m n f, fold (m + n) f init = foldTR.loop (m + n) f m (by omega) (fold n (fun i h => f i (by omega)) init)
|
||||
| 0, n, f => by
|
||||
simp only [foldTR.loop]
|
||||
have t : 0 + n = n := by omega
|
||||
rw [fold_congr t]
|
||||
| succ m, n, f => by
|
||||
have t : (m + 1) + n = m + (n + 1) := by omega
|
||||
rw [foldTR.loop]
|
||||
simp only [succ_eq_add_one, Nat.add_sub_cancel]
|
||||
rw [fold_congr t, foldTR_loop_congr t, go, fold]
|
||||
congr
|
||||
omega
|
||||
go n 0 f
|
||||
|
||||
theorem any_congr {n m : Nat} (w : n = m) (f : (i : Nat) → i < n → Bool) : any n f = any m (fun i h => f i (by omega)) := by
|
||||
subst m
|
||||
rfl
|
||||
|
||||
theorem anyTR_loop_congr {n m : Nat} (w : n = m) (f : (i : Nat) → i < n → Bool) (j : Nat) (h : j ≤ n) :
|
||||
anyTR.loop n f j h = anyTR.loop m (fun i h => f i (by omega)) j (by omega) := by
|
||||
subst m
|
||||
rfl
|
||||
|
||||
@[csimp] theorem any_eq_anyTR : @any = @anyTR :=
|
||||
funext fun n => funext fun f =>
|
||||
let rec go : ∀ m n f, any (m + n) f = (any n (fun i h => f i (by omega)) || anyTR.loop (m + n) f m (by omega))
|
||||
| 0, n, f => by
|
||||
simp [anyTR.loop]
|
||||
have t : 0 + n = n := by omega
|
||||
rw [any_congr t]
|
||||
| succ m, n, f => by
|
||||
have t : (m + 1) + n = m + (n + 1) := by omega
|
||||
rw [anyTR.loop]
|
||||
simp only [succ_eq_add_one]
|
||||
rw [any_congr t, anyTR_loop_congr t, go, any, Bool.or_assoc]
|
||||
congr
|
||||
omega
|
||||
go n 0 f
|
||||
|
||||
theorem all_congr {n m : Nat} (w : n = m) (f : (i : Nat) → i < n → Bool) : all n f = all m (fun i h => f i (by omega)) := by
|
||||
subst m
|
||||
rfl
|
||||
|
||||
theorem allTR_loop_congr {n m : Nat} (w : n = m) (f : (i : Nat) → i < n → Bool) (j : Nat) (h : j ≤ n) : allTR.loop n f j h = allTR.loop m (fun i h => f i (by omega)) j (by omega) := by
|
||||
subst m
|
||||
rfl
|
||||
|
||||
@[csimp] theorem all_eq_allTR : @all = @allTR :=
|
||||
funext fun n => funext fun f =>
|
||||
let rec go : ∀ m n f, all (m + n) f = (all n (fun i h => f i (by omega)) && allTR.loop (m + n) f m (by omega))
|
||||
| 0, n, f => by
|
||||
simp [allTR.loop]
|
||||
have t : 0 + n = n := by omega
|
||||
rw [all_congr t]
|
||||
| succ m, n, f => by
|
||||
have t : (m + 1) + n = m + (n + 1) := by omega
|
||||
rw [allTR.loop]
|
||||
simp only [succ_eq_add_one]
|
||||
rw [all_congr t, allTR_loop_congr t, go, all, Bool.and_assoc]
|
||||
congr
|
||||
omega
|
||||
go n 0 f
|
||||
|
||||
@[simp] theorem fold_zero {α : Type u} (f : (i : Nat) → i < 0 → α → α) (init : α) :
|
||||
fold 0 f init = init := by simp [fold]
|
||||
|
||||
@[simp] theorem fold_succ {α : Type u} (n : Nat) (f : (i : Nat) → i < n + 1 → α → α) (init : α) :
|
||||
fold (n + 1) f init = f n (by omega) (fold n (fun i h => f i (by omega)) init) := by simp [fold]
|
||||
|
||||
theorem fold_eq_finRange_foldl {α : Type u} (n : Nat) (f : (i : Nat) → i < n → α → α) (init : α) :
|
||||
fold n f init = (List.finRange n).foldl (fun acc ⟨i, h⟩ => f i h acc) init := by
|
||||
induction n with
|
||||
| zero => simp
|
||||
| succ n ih =>
|
||||
simp [ih, List.finRange_succ_last, List.foldl_map]
|
||||
|
||||
@[simp] theorem foldRev_zero {α : Type u} (f : (i : Nat) → i < 0 → α → α) (init : α) :
|
||||
foldRev 0 f init = init := by simp [foldRev]
|
||||
|
||||
@[simp] theorem foldRev_succ {α : Type u} (n : Nat) (f : (i : Nat) → i < n + 1 → α → α) (init : α) :
|
||||
foldRev (n + 1) f init = foldRev n (fun i h => f i (by omega)) (f n (by omega) init) := by
|
||||
simp [foldRev]
|
||||
|
||||
theorem foldRev_eq_finRange_foldr {α : Type u} (n : Nat) (f : (i : Nat) → i < n → α → α) (init : α) :
|
||||
foldRev n f init = (List.finRange n).foldr (fun ⟨i, h⟩ acc => f i h acc) init := by
|
||||
induction n generalizing init with
|
||||
| zero => simp
|
||||
| succ n ih => simp [ih, List.finRange_succ_last, List.foldr_map]
|
||||
|
||||
@[simp] theorem any_zero {f : (i : Nat) → i < 0 → Bool} : any 0 f = false := by simp [any]
|
||||
|
||||
@[simp] theorem any_succ {n : Nat} (f : (i : Nat) → i < n + 1 → Bool) :
|
||||
any (n + 1) f = (any n (fun i h => f i (by omega)) || f n (by omega)) := by simp [any]
|
||||
|
||||
theorem any_eq_finRange_any {n : Nat} (f : (i : Nat) → i < n → Bool) :
|
||||
any n f = (List.finRange n).any (fun ⟨i, h⟩ => f i h) := by
|
||||
induction n with
|
||||
| zero => simp
|
||||
| succ n ih => simp [ih, List.finRange_succ_last, List.any_map, Function.comp_def]
|
||||
|
||||
@[simp] theorem all_zero {f : (i : Nat) → i < 0 → Bool} : all 0 f = true := by simp [all]
|
||||
|
||||
@[simp] theorem all_succ {n : Nat} (f : (i : Nat) → i < n + 1 → Bool) :
|
||||
all (n + 1) f = (all n (fun i h => f i (by omega)) && f n (by omega)) := by simp [all]
|
||||
|
||||
theorem all_eq_finRange_all {n : Nat} (f : (i : Nat) → i < n → Bool) :
|
||||
all n f = (List.finRange n).all (fun ⟨i, h⟩ => f i h) := by
|
||||
induction n with
|
||||
| zero => simp
|
||||
| succ n ih => simp [ih, List.finRange_succ_last, List.all_map, Function.comp_def]
|
||||
|
||||
end Nat
|
||||
|
||||
namespace Prod
|
||||
|
||||
/--
|
||||
`(start, stop).foldI f a` evaluates `f` on all the numbers
|
||||
from `start` (inclusive) to `stop` (exclusive) in increasing order:
|
||||
* `(5, 8).foldI f init = init |> f 5 |> f 6 |> f 7`
|
||||
-/
|
||||
@[inline] def foldI {α : Type u} (i : Nat × Nat) (f : (j : Nat) → i.1 ≤ j → j < i.2 → α → α) (a : α) : α :=
|
||||
(i.2 - i.1).fold (fun j _ => f (i.1 + j) (by omega) (by omega)) a
|
||||
|
||||
/--
|
||||
`(start, stop).anyI f a` returns true if `f` is true for some natural number
|
||||
from `start` (inclusive) to `stop` (exclusive):
|
||||
* `(5, 8).anyI f = f 5 || f 6 || f 7`
|
||||
-/
|
||||
@[inline] def anyI (i : Nat × Nat) (f : (j : Nat) → i.1 ≤ j → j < i.2 → Bool) : Bool :=
|
||||
(i.2 - i.1).any (fun j _ => f (i.1 + j) (by omega) (by omega))
|
||||
|
||||
/--
|
||||
`(start, stop).allI f a` returns true if `f` is true for all natural numbers
|
||||
from `start` (inclusive) to `stop` (exclusive):
|
||||
* `(5, 8).anyI f = f 5 && f 6 && f 7`
|
||||
-/
|
||||
@[inline] def allI (i : Nat × Nat) (f : (j : Nat) → i.1 ≤ j → j < i.2 → Bool) : Bool :=
|
||||
(i.2 - i.1).all (fun j _ => f (i.1 + j) (by omega) (by omega))
|
||||
|
||||
end Prod
|
||||
@@ -651,8 +651,8 @@ theorem sub_mul_mod {x k n : Nat} (h₁ : n*k ≤ x) : (x - n*k) % n = x % n :=
|
||||
| .inr npos => Nat.mod_eq_of_lt (mod_lt _ npos)
|
||||
|
||||
theorem mul_mod (a b n : Nat) : a * b % n = (a % n) * (b % n) % n := by
|
||||
rw (occs := [1]) [← mod_add_div a n]
|
||||
rw (occs := [1]) [← mod_add_div b n]
|
||||
rw (occs := .pos [1]) [← mod_add_div a n]
|
||||
rw (occs := .pos [1]) [← mod_add_div b n]
|
||||
rw [Nat.add_mul, Nat.mul_add, Nat.mul_add,
|
||||
Nat.mul_assoc, Nat.mul_assoc, ← Nat.mul_add n, add_mul_mod_self_left,
|
||||
Nat.mul_comm _ (n * (b / n)), Nat.mul_assoc, add_mul_mod_self_left]
|
||||
@@ -679,10 +679,6 @@ theorem add_mod (a b n : Nat) : (a + b) % n = ((a % n) + (b % n)) % n := by
|
||||
@[simp] theorem mod_mul_mod {a b c : Nat} : (a % c * b) % c = a * b % c := by
|
||||
rw [mul_mod, mod_mod, ← mul_mod]
|
||||
|
||||
theorem mod_eq_sub (x w : Nat) : x % w = x - w * (x / w) := by
|
||||
conv => rhs; congr; rw [← mod_add_div x w]
|
||||
simp
|
||||
|
||||
/-! ### pow -/
|
||||
|
||||
theorem pow_succ' {m n : Nat} : m ^ n.succ = m * m ^ n := by
|
||||
@@ -850,18 +846,6 @@ protected theorem pow_lt_pow_iff_pow_mul_le_pow {a n m : Nat} (h : 1 < a) :
|
||||
rw [←Nat.pow_add_one, Nat.pow_le_pow_iff_right (by omega), Nat.pow_lt_pow_iff_right (by omega)]
|
||||
omega
|
||||
|
||||
protected theorem lt_pow_self {n a : Nat} (h : 1 < a) : n < a ^ n := by
|
||||
induction n with
|
||||
| zero => exact Nat.zero_lt_one
|
||||
| succ _ ih => exact Nat.lt_of_lt_of_le (Nat.add_lt_add_right ih 1) (Nat.pow_lt_pow_succ h)
|
||||
|
||||
protected theorem lt_two_pow_self : n < 2 ^ n :=
|
||||
Nat.lt_pow_self Nat.one_lt_two
|
||||
|
||||
@[simp]
|
||||
protected theorem mod_two_pow_self : n % 2 ^ n = n :=
|
||||
Nat.mod_eq_of_lt Nat.lt_two_pow_self
|
||||
|
||||
@[simp]
|
||||
theorem two_pow_pred_mul_two (h : 0 < w) :
|
||||
2 ^ (w - 1) * 2 = 2 ^ w := by
|
||||
|
||||
@@ -31,7 +31,7 @@ This file defines basic operations on the the sum type `α ⊕ β`.
|
||||
|
||||
## Further material
|
||||
|
||||
See `Init.Data.Sum.Lemmas` for theorems about these definitions.
|
||||
See `Batteries.Data.Sum.Lemmas` for theorems about these definitions.
|
||||
|
||||
## Notes
|
||||
|
||||
|
||||
@@ -246,12 +246,6 @@ instance (a b : UInt64) : Decidable (a ≤ b) := UInt64.decLe a b
|
||||
instance : Max UInt64 := maxOfLe
|
||||
instance : Min UInt64 := minOfLe
|
||||
|
||||
theorem usize_size_le : USize.size ≤ 18446744073709551616 := by
|
||||
cases usize_size_eq <;> next h => rw [h]; decide
|
||||
|
||||
theorem le_usize_size : 4294967296 ≤ USize.size := by
|
||||
cases usize_size_eq <;> next h => rw [h]; decide
|
||||
|
||||
@[extern "lean_usize_mul"]
|
||||
def USize.mul (a b : USize) : USize := ⟨a.toBitVec * b.toBitVec⟩
|
||||
@[extern "lean_usize_div"]
|
||||
@@ -270,29 +264,10 @@ def USize.xor (a b : USize) : USize := ⟨a.toBitVec ^^^ b.toBitVec⟩
|
||||
def USize.shiftLeft (a b : USize) : USize := ⟨a.toBitVec <<< (mod b (USize.ofNat System.Platform.numBits)).toBitVec⟩
|
||||
@[extern "lean_usize_shift_right"]
|
||||
def USize.shiftRight (a b : USize) : USize := ⟨a.toBitVec >>> (mod b (USize.ofNat System.Platform.numBits)).toBitVec⟩
|
||||
/--
|
||||
Upcast a `Nat` less than `2^32` to a `USize`.
|
||||
This is lossless because `USize.size` is either `2^32` or `2^64`.
|
||||
This function is overridden with a native implementation.
|
||||
-/
|
||||
@[extern "lean_usize_of_nat"]
|
||||
def USize.ofNat32 (n : @& Nat) (h : n < 4294967296) : USize :=
|
||||
USize.ofNatCore n (Nat.lt_of_lt_of_le h le_usize_size)
|
||||
@[extern "lean_uint32_to_usize"]
|
||||
def UInt32.toUSize (a : UInt32) : USize := USize.ofNat32 a.toBitVec.toNat a.toBitVec.isLt
|
||||
@[extern "lean_usize_to_uint32"]
|
||||
def USize.toUInt32 (a : USize) : UInt32 := a.toNat.toUInt32
|
||||
/-- Converts a `UInt64` to a `USize` by reducing modulo `USize.size`. -/
|
||||
@[extern "lean_uint64_to_usize"]
|
||||
def UInt64.toUSize (a : UInt64) : USize := a.toNat.toUSize
|
||||
/--
|
||||
Upcast a `USize` to a `UInt64`.
|
||||
This is lossless because `USize.size` is either `2^32` or `2^64`.
|
||||
This function is overridden with a native implementation.
|
||||
-/
|
||||
@[extern "lean_usize_to_uint64"]
|
||||
def USize.toUInt64 (a : USize) : UInt64 :=
|
||||
UInt64.ofNatCore a.toBitVec.toNat (Nat.lt_of_lt_of_le a.toBitVec.isLt usize_size_le)
|
||||
|
||||
instance : Mul USize := ⟨USize.mul⟩
|
||||
instance : Mod USize := ⟨USize.mod⟩
|
||||
|
||||
@@ -94,8 +94,10 @@ def UInt32.toUInt64 (a : UInt32) : UInt64 := ⟨⟨a.toNat, Nat.lt_trans a.toBit
|
||||
|
||||
instance UInt64.instOfNat : OfNat UInt64 n := ⟨UInt64.ofNat n⟩
|
||||
|
||||
@[deprecated usize_size_pos (since := "2024-11-24")] theorem usize_size_gt_zero : USize.size > 0 :=
|
||||
usize_size_pos
|
||||
theorem usize_size_gt_zero : USize.size > 0 := by
|
||||
cases usize_size_eq with
|
||||
| inl h => rw [h]; decide
|
||||
| inr h => rw [h]; decide
|
||||
|
||||
def USize.val (x : USize) : Fin USize.size := x.toBitVec.toFin
|
||||
@[extern "lean_usize_of_nat"]
|
||||
|
||||
@@ -133,9 +133,6 @@ declare_uint_theorems UInt32
|
||||
declare_uint_theorems UInt64
|
||||
declare_uint_theorems USize
|
||||
|
||||
theorem USize.toNat_ofNat_of_lt_32 {n : Nat} (h : n < 4294967296) : toNat (ofNat n) = n :=
|
||||
toNat_ofNat_of_lt (Nat.lt_of_lt_of_le h le_usize_size)
|
||||
|
||||
theorem UInt32.toNat_lt_of_lt {n : UInt32} {m : Nat} (h : m < size) : n < ofNat m → n.toNat < m := by
|
||||
simp [lt_def, BitVec.lt_def, UInt32.toNat, toBitVec_eq_of_lt h]
|
||||
|
||||
@@ -148,22 +145,22 @@ theorem UInt32.toNat_le_of_le {n : UInt32} {m : Nat} (h : m < size) : n ≤ ofNa
|
||||
theorem UInt32.le_toNat_of_le {n : UInt32} {m : Nat} (h : m < size) : ofNat m ≤ n → m ≤ n.toNat := by
|
||||
simp [le_def, BitVec.le_def, UInt32.toNat, toBitVec_eq_of_lt h]
|
||||
|
||||
@[deprecated UInt8.toNat_zero (since := "2024-06-23")] protected abbrev UInt8.zero_toNat := @UInt8.toNat_zero
|
||||
@[deprecated UInt8.toNat_div (since := "2024-06-23")] protected abbrev UInt8.div_toNat := @UInt8.toNat_div
|
||||
@[deprecated UInt8.toNat_mod (since := "2024-06-23")] protected abbrev UInt8.mod_toNat := @UInt8.toNat_mod
|
||||
@[deprecated (since := "2024-06-23")] protected abbrev UInt8.zero_toNat := @UInt8.toNat_zero
|
||||
@[deprecated (since := "2024-06-23")] protected abbrev UInt8.div_toNat := @UInt8.toNat_div
|
||||
@[deprecated (since := "2024-06-23")] protected abbrev UInt8.mod_toNat := @UInt8.toNat_mod
|
||||
|
||||
@[deprecated UInt16.toNat_zero (since := "2024-06-23")] protected abbrev UInt16.zero_toNat := @UInt16.toNat_zero
|
||||
@[deprecated UInt16.toNat_div (since := "2024-06-23")] protected abbrev UInt16.div_toNat := @UInt16.toNat_div
|
||||
@[deprecated UInt16.toNat_mod (since := "2024-06-23")] protected abbrev UInt16.mod_toNat := @UInt16.toNat_mod
|
||||
@[deprecated (since := "2024-06-23")] protected abbrev UInt16.zero_toNat := @UInt16.toNat_zero
|
||||
@[deprecated (since := "2024-06-23")] protected abbrev UInt16.div_toNat := @UInt16.toNat_div
|
||||
@[deprecated (since := "2024-06-23")] protected abbrev UInt16.mod_toNat := @UInt16.toNat_mod
|
||||
|
||||
@[deprecated UInt32.toNat_zero (since := "2024-06-23")] protected abbrev UInt32.zero_toNat := @UInt32.toNat_zero
|
||||
@[deprecated UInt32.toNat_div (since := "2024-06-23")] protected abbrev UInt32.div_toNat := @UInt32.toNat_div
|
||||
@[deprecated UInt32.toNat_mod (since := "2024-06-23")] protected abbrev UInt32.mod_toNat := @UInt32.toNat_mod
|
||||
@[deprecated (since := "2024-06-23")] protected abbrev UInt32.zero_toNat := @UInt32.toNat_zero
|
||||
@[deprecated (since := "2024-06-23")] protected abbrev UInt32.div_toNat := @UInt32.toNat_div
|
||||
@[deprecated (since := "2024-06-23")] protected abbrev UInt32.mod_toNat := @UInt32.toNat_mod
|
||||
|
||||
@[deprecated UInt64.toNat_zero (since := "2024-06-23")] protected abbrev UInt64.zero_toNat := @UInt64.toNat_zero
|
||||
@[deprecated UInt64.toNat_div (since := "2024-06-23")] protected abbrev UInt64.div_toNat := @UInt64.toNat_div
|
||||
@[deprecated UInt64.toNat_mod (since := "2024-06-23")] protected abbrev UInt64.mod_toNat := @UInt64.toNat_mod
|
||||
@[deprecated (since := "2024-06-23")] protected abbrev UInt64.zero_toNat := @UInt64.toNat_zero
|
||||
@[deprecated (since := "2024-06-23")] protected abbrev UInt64.div_toNat := @UInt64.toNat_div
|
||||
@[deprecated (since := "2024-06-23")] protected abbrev UInt64.mod_toNat := @UInt64.toNat_mod
|
||||
|
||||
@[deprecated USize.toNat_zero (since := "2024-06-23")] protected abbrev USize.zero_toNat := @USize.toNat_zero
|
||||
@[deprecated USize.toNat_div (since := "2024-06-23")] protected abbrev USize.div_toNat := @USize.toNat_div
|
||||
@[deprecated USize.toNat_mod (since := "2024-06-23")] protected abbrev USize.mod_toNat := @USize.toNat_mod
|
||||
@[deprecated (since := "2024-06-23")] protected abbrev USize.zero_toNat := @USize.toNat_zero
|
||||
@[deprecated (since := "2024-06-23")] protected abbrev USize.div_toNat := @USize.toNat_div
|
||||
@[deprecated (since := "2024-06-23")] protected abbrev USize.mod_toNat := @USize.toNat_mod
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
/-
|
||||
Copyright (c) 2024 Lean FRO. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Kim Morrison
|
||||
-/
|
||||
prelude
|
||||
import Init.Data.Vector.Basic
|
||||
@@ -1,253 +0,0 @@
|
||||
/-
|
||||
Copyright (c) 2024 Shreyas Srinivas. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Shreyas Srinivas, François G. Dorais, Kim Morrison
|
||||
-/
|
||||
|
||||
prelude
|
||||
import Init.Data.Array.Lemmas
|
||||
|
||||
/-!
|
||||
# Vectors
|
||||
|
||||
`Vector α n` is a thin wrapper around `Array α` for arrays of fixed size `n`.
|
||||
-/
|
||||
|
||||
/-- `Vector α n` is an `Array α` with size `n`. -/
|
||||
structure Vector (α : Type u) (n : Nat) extends Array α where
|
||||
/-- Array size. -/
|
||||
size_toArray : toArray.size = n
|
||||
deriving Repr, DecidableEq
|
||||
|
||||
attribute [simp] Vector.size_toArray
|
||||
|
||||
namespace Vector
|
||||
|
||||
/-- Syntax for `Vector α n` -/
|
||||
syntax "#v[" withoutPosition(sepBy(term, ", ")) "]" : term
|
||||
|
||||
open Lean in
|
||||
macro_rules
|
||||
| `(#v[ $elems,* ]) => `(Vector.mk (n := $(quote elems.getElems.size)) #[$elems,*] rfl)
|
||||
|
||||
/-- Custom eliminator for `Vector α n` through `Array α` -/
|
||||
@[elab_as_elim]
|
||||
def elimAsArray {motive : Vector α n → Sort u}
|
||||
(mk : ∀ (a : Array α) (ha : a.size = n), motive ⟨a, ha⟩) :
|
||||
(v : Vector α n) → motive v
|
||||
| ⟨a, ha⟩ => mk a ha
|
||||
|
||||
/-- Custom eliminator for `Vector α n` through `List α` -/
|
||||
@[elab_as_elim]
|
||||
def elimAsList {motive : Vector α n → Sort u}
|
||||
(mk : ∀ (a : List α) (ha : a.length = n), motive ⟨⟨a⟩, ha⟩) :
|
||||
(v : Vector α n) → motive v
|
||||
| ⟨⟨a⟩, ha⟩ => mk a ha
|
||||
|
||||
/-- Make an empty vector with pre-allocated capacity. -/
|
||||
@[inline] def mkEmpty (capacity : Nat) : Vector α 0 := ⟨.mkEmpty capacity, rfl⟩
|
||||
|
||||
/-- Makes a vector of size `n` with all cells containing `v`. -/
|
||||
@[inline] def mkVector (n) (v : α) : Vector α n := ⟨mkArray n v, by simp⟩
|
||||
|
||||
/-- Returns a vector of size `1` with element `v`. -/
|
||||
@[inline] def singleton (v : α) : Vector α 1 := ⟨#[v], rfl⟩
|
||||
|
||||
instance [Inhabited α] : Inhabited (Vector α n) where
|
||||
default := mkVector n default
|
||||
|
||||
/-- Get an element of a vector using a `Fin` index. -/
|
||||
@[inline] def get (v : Vector α n) (i : Fin n) : α :=
|
||||
v.toArray[(i.cast v.size_toArray.symm).1]
|
||||
|
||||
/-- Get an element of a vector using a `USize` index and a proof that the index is within bounds. -/
|
||||
@[inline] def uget (v : Vector α n) (i : USize) (h : i.toNat < n) : α :=
|
||||
v.toArray.uget i (v.size_toArray.symm ▸ h)
|
||||
|
||||
instance : GetElem (Vector α n) Nat α fun _ i => i < n where
|
||||
getElem x i h := get x ⟨i, h⟩
|
||||
|
||||
/--
|
||||
Get an element of a vector using a `Nat` index. Returns the given default value if the index is out
|
||||
of bounds.
|
||||
-/
|
||||
@[inline] def getD (v : Vector α n) (i : Nat) (default : α) : α := v.toArray.getD i default
|
||||
|
||||
/-- The last element of a vector. Panics if the vector is empty. -/
|
||||
@[inline] def back! [Inhabited α] (v : Vector α n) : α := v.toArray.back!
|
||||
|
||||
/-- The last element of a vector, or `none` if the array is empty. -/
|
||||
@[inline] def back? (v : Vector α n) : Option α := v.toArray.back?
|
||||
|
||||
/-- The last element of a non-empty vector. -/
|
||||
@[inline] def back [NeZero n] (v : Vector α n) : α :=
|
||||
-- TODO: change to just `v[n]`
|
||||
have : Inhabited α := ⟨v[0]'(Nat.pos_of_neZero n)⟩
|
||||
v.back!
|
||||
|
||||
/-- The first element of a non-empty vector. -/
|
||||
@[inline] def head [NeZero n] (v : Vector α n) := v[0]'(Nat.pos_of_neZero n)
|
||||
|
||||
/-- Push an element `x` to the end of a vector. -/
|
||||
@[inline] def push (x : α) (v : Vector α n) : Vector α (n + 1) :=
|
||||
⟨v.toArray.push x, by simp⟩
|
||||
|
||||
/-- Remove the last element of a vector. -/
|
||||
@[inline] def pop (v : Vector α n) : Vector α (n - 1) :=
|
||||
⟨Array.pop v.toArray, by simp⟩
|
||||
|
||||
/--
|
||||
Set an element in a vector using a `Nat` index, with a tactic provided proof that the index is in
|
||||
bounds.
|
||||
|
||||
This will perform the update destructively provided that the vector has a reference count of 1.
|
||||
-/
|
||||
@[inline] def set (v : Vector α n) (i : Nat) (x : α) (h : i < n := by get_elem_tactic): Vector α n :=
|
||||
⟨v.toArray.set i x (by simp [*]), by simp⟩
|
||||
|
||||
/--
|
||||
Set an element in a vector using a `Nat` index. Returns the vector unchanged if the index is out of
|
||||
bounds.
|
||||
|
||||
This will perform the update destructively provided that the vector has a reference count of 1.
|
||||
-/
|
||||
@[inline] def setIfInBounds (v : Vector α n) (i : Nat) (x : α) : Vector α n :=
|
||||
⟨v.toArray.setIfInBounds i x, by simp⟩
|
||||
|
||||
/--
|
||||
Set an element in a vector using a `Nat` index. Panics if the index is out of bounds.
|
||||
|
||||
This will perform the update destructively provided that the vector has a reference count of 1.
|
||||
-/
|
||||
@[inline] def set! (v : Vector α n) (i : Nat) (x : α) : Vector α n :=
|
||||
⟨v.toArray.set! i x, by simp⟩
|
||||
|
||||
/-- Append two vectors. -/
|
||||
@[inline] def append (v : Vector α n) (w : Vector α m) : Vector α (n + m) :=
|
||||
⟨v.toArray ++ w.toArray, by simp⟩
|
||||
|
||||
instance : HAppend (Vector α n) (Vector α m) (Vector α (n + m)) where
|
||||
hAppend := append
|
||||
|
||||
/-- Creates a vector from another with a provably equal length. -/
|
||||
@[inline] protected def cast (h : n = m) (v : Vector α n) : Vector α m :=
|
||||
⟨v.toArray, by simp [*]⟩
|
||||
|
||||
/--
|
||||
Extracts the slice of a vector from indices `start` to `stop` (exclusive). If `start ≥ stop`, the
|
||||
result is empty. If `stop` is greater than the size of the vector, the size is used instead.
|
||||
-/
|
||||
@[inline] def extract (v : Vector α n) (start stop : Nat) : Vector α (min stop n - start) :=
|
||||
⟨v.toArray.extract start stop, by simp⟩
|
||||
|
||||
/-- Maps elements of a vector using the function `f`. -/
|
||||
@[inline] def map (f : α → β) (v : Vector α n) : Vector β n :=
|
||||
⟨v.toArray.map f, by simp⟩
|
||||
|
||||
/-- Maps corresponding elements of two vectors of equal size using the function `f`. -/
|
||||
@[inline] def zipWith (a : Vector α n) (b : Vector β n) (f : α → β → φ) : Vector φ n :=
|
||||
⟨Array.zipWith a.toArray b.toArray f, by simp⟩
|
||||
|
||||
/-- The vector of length `n` whose `i`-th element is `f i`. -/
|
||||
@[inline] def ofFn (f : Fin n → α) : Vector α n :=
|
||||
⟨Array.ofFn f, by simp⟩
|
||||
|
||||
/--
|
||||
Swap two elements of a vector using `Fin` indices.
|
||||
|
||||
This will perform the update destructively provided that the vector has a reference count of 1.
|
||||
-/
|
||||
@[inline] def swap (v : Vector α n) (i j : Nat)
|
||||
(hi : i < n := by get_elem_tactic) (hj : j < n := by get_elem_tactic) : Vector α n :=
|
||||
⟨v.toArray.swap i j (by simpa using hi) (by simpa using hj), by simp⟩
|
||||
|
||||
/--
|
||||
Swap two elements of a vector using `Nat` indices. Panics if either index is out of bounds.
|
||||
|
||||
This will perform the update destructively provided that the vector has a reference count of 1.
|
||||
-/
|
||||
@[inline] def swapIfInBounds (v : Vector α n) (i j : Nat) : Vector α n :=
|
||||
⟨v.toArray.swapIfInBounds i j, by simp⟩
|
||||
|
||||
/--
|
||||
Swaps an element of a vector with a given value using a `Fin` index. The original value is returned
|
||||
along with the updated vector.
|
||||
|
||||
This will perform the update destructively provided that the vector has a reference count of 1.
|
||||
-/
|
||||
@[inline] def swapAt (v : Vector α n) (i : Nat) (x : α) (hi : i < n := by get_elem_tactic) :
|
||||
α × Vector α n :=
|
||||
let a := v.toArray.swapAt i x (by simpa using hi)
|
||||
⟨a.fst, a.snd, by simp [a]⟩
|
||||
|
||||
/--
|
||||
Swaps an element of a vector with a given value using a `Nat` index. Panics if the index is out of
|
||||
bounds. The original value is returned along with the updated vector.
|
||||
|
||||
This will perform the update destructively provided that the vector has a reference count of 1.
|
||||
-/
|
||||
@[inline] def swapAt! (v : Vector α n) (i : Nat) (x : α) : α × Vector α n :=
|
||||
let a := v.toArray.swapAt! i x
|
||||
⟨a.fst, a.snd, by simp [a]⟩
|
||||
|
||||
/-- The vector `#v[0,1,2,...,n-1]`. -/
|
||||
@[inline] def range (n : Nat) : Vector Nat n := ⟨Array.range n, by simp⟩
|
||||
|
||||
/--
|
||||
Extract the first `m` elements of a vector. If `m` is greater than or equal to the size of the
|
||||
vector then the vector is returned unchanged.
|
||||
-/
|
||||
@[inline] def take (v : Vector α n) (m : Nat) : Vector α (min m n) :=
|
||||
⟨v.toArray.take m, by simp⟩
|
||||
|
||||
/--
|
||||
Deletes the first `m` elements of a vector. If `m` is greater than or equal to the size of the
|
||||
vector then the empty vector is returned.
|
||||
-/
|
||||
@[inline] def drop (v : Vector α n) (m : Nat) : Vector α (n - m) :=
|
||||
⟨v.toArray.extract m v.size, by simp⟩
|
||||
|
||||
/--
|
||||
Compares two vectors of the same size using a given boolean relation `r`. `isEqv v w r` returns
|
||||
`true` if and only if `r v[i] w[i]` is true for all indices `i`.
|
||||
-/
|
||||
@[inline] def isEqv (v w : Vector α n) (r : α → α → Bool) : Bool :=
|
||||
Array.isEqvAux v.toArray w.toArray (by simp) r n (by simp)
|
||||
|
||||
instance [BEq α] : BEq (Vector α n) where
|
||||
beq a b := isEqv a b (· == ·)
|
||||
|
||||
/-- Reverse the elements of a vector. -/
|
||||
@[inline] def reverse (v : Vector α n) : Vector α n :=
|
||||
⟨v.toArray.reverse, by simp⟩
|
||||
|
||||
/-- Delete an element of a vector using a `Nat` index and a tactic provided proof. -/
|
||||
@[inline] def eraseIdx (v : Vector α n) (i : Nat) (h : i < n := by get_elem_tactic) :
|
||||
Vector α (n-1) :=
|
||||
⟨v.toArray.eraseIdx i (v.size_toArray.symm ▸ h), by simp [Array.size_eraseIdx]⟩
|
||||
|
||||
/-- Delete an element of a vector using a `Nat` index. Panics if the index is out of bounds. -/
|
||||
@[inline] def eraseIdx! (v : Vector α n) (i : Nat) : Vector α (n-1) :=
|
||||
if _ : i < n then
|
||||
v.eraseIdx i
|
||||
else
|
||||
have : Inhabited (Vector α (n-1)) := ⟨v.pop⟩
|
||||
panic! "index out of bounds"
|
||||
|
||||
/-- Delete the first element of a vector. Returns the empty vector if the input vector is empty. -/
|
||||
@[inline] def tail (v : Vector α n) : Vector α (n-1) :=
|
||||
if _ : 0 < n then
|
||||
v.eraseIdx 0
|
||||
else
|
||||
v.cast (by omega)
|
||||
|
||||
/--
|
||||
Finds the first index of a given value in a vector using `==` for comparison. Returns `none` if the
|
||||
no element of the index matches the given value.
|
||||
-/
|
||||
@[inline] def indexOf? [BEq α] (v : Vector α n) (x : α) : Option (Fin n) :=
|
||||
(v.toArray.indexOf? x).map (Fin.cast v.size_toArray)
|
||||
|
||||
/-- Returns `true` when `v` is a prefix of the vector `w`. -/
|
||||
@[inline] def isPrefixOf [BEq α] (v : Vector α m) (w : Vector α n) : Bool :=
|
||||
v.toArray.isPrefixOf w.toArray
|
||||
@@ -1,172 +0,0 @@
|
||||
/-
|
||||
Copyright (c) 2024 Shreyas Srinivas. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Shreyas Srinivas, Francois Dorais
|
||||
-/
|
||||
prelude
|
||||
import Init.Data.Vector.Basic
|
||||
|
||||
/-!
|
||||
## Vectors
|
||||
Lemmas about `Vector α n`
|
||||
-/
|
||||
|
||||
namespace Vector
|
||||
|
||||
theorem length_toList {α n} (xs : Vector α n) : xs.toList.length = n := by simp
|
||||
|
||||
@[simp] theorem getElem_mk {data : Array α} {size : data.size = n} {i : Nat} (h : i < n) :
|
||||
(Vector.mk data size)[i] = data[i] := rfl
|
||||
|
||||
@[simp] theorem getElem_toArray {α n} (xs : Vector α n) (i : Nat) (h : i < xs.toArray.size) :
|
||||
xs.toArray[i] = xs[i]'(by simpa using h) := by
|
||||
cases xs
|
||||
simp
|
||||
|
||||
theorem getElem_toList {α n} (xs : Vector α n) (i : Nat) (h : i < xs.toList.length) :
|
||||
xs.toList[i] = xs[i]'(by simpa using h) := by simp
|
||||
|
||||
@[simp] theorem getElem_ofFn {α n} (f : Fin n → α) (i : Nat) (h : i < n) :
|
||||
(Vector.ofFn f)[i] = f ⟨i, by simpa using h⟩ := by
|
||||
simp [ofFn]
|
||||
|
||||
/-- The empty vector maps to the empty vector. -/
|
||||
@[simp]
|
||||
theorem map_empty (f : α → β) : map f #v[] = #v[] := by
|
||||
rw [map, mk.injEq]
|
||||
exact Array.map_empty f
|
||||
|
||||
theorem toArray_inj : ∀ {v w : Vector α n}, v.toArray = w.toArray → v = w
|
||||
| {..}, {..}, rfl => rfl
|
||||
|
||||
/-- A vector of length `0` is the empty vector. -/
|
||||
protected theorem eq_empty (v : Vector α 0) : v = #v[] := by
|
||||
apply Vector.toArray_inj
|
||||
apply Array.eq_empty_of_size_eq_zero v.2
|
||||
|
||||
/--
|
||||
`Vector.ext` is an extensionality theorem.
|
||||
Vectors `a` and `b` are equal to each other if their elements are equal for each valid index.
|
||||
-/
|
||||
@[ext]
|
||||
protected theorem ext {a b : Vector α n} (h : (i : Nat) → (_ : i < n) → a[i] = b[i]) : a = b := by
|
||||
apply Vector.toArray_inj
|
||||
apply Array.ext
|
||||
· rw [a.size_toArray, b.size_toArray]
|
||||
· intro i hi _
|
||||
rw [a.size_toArray] at hi
|
||||
exact h i hi
|
||||
|
||||
@[simp] theorem push_mk {data : Array α} {size : data.size = n} {x : α} :
|
||||
(Vector.mk data size).push x =
|
||||
Vector.mk (data.push x) (by simp [size, Nat.succ_eq_add_one]) := rfl
|
||||
|
||||
@[simp] theorem pop_mk {data : Array α} {size : data.size = n} :
|
||||
(Vector.mk data size).pop = Vector.mk data.pop (by simp [size]) := rfl
|
||||
|
||||
@[simp] theorem swap_mk {data : Array α} {size : data.size = n} {i j : Nat} {hi hj} :
|
||||
(Vector.mk data size).swap i j hi hj = Vector.mk (data.swap i j) (by simp_all) := rfl
|
||||
|
||||
@[simp] theorem getElem_push_last {v : Vector α n} {x : α} : (v.push x)[n] = x := by
|
||||
rcases v with ⟨data, rfl⟩
|
||||
simp
|
||||
|
||||
@[simp] theorem getElem_push_lt {v : Vector α n} {x : α} {i : Nat} (h : i < n) :
|
||||
(v.push x)[i] = v[i] := by
|
||||
rcases v with ⟨data, rfl⟩
|
||||
simp [Array.getElem_push_lt, h]
|
||||
|
||||
@[simp] theorem getElem_pop {v : Vector α n} {i : Nat} (h : i < n - 1) : (v.pop)[i] = v[i] := by
|
||||
rcases v with ⟨data, rfl⟩
|
||||
simp
|
||||
|
||||
/--
|
||||
Variant of `getElem_pop` that will sometimes fire when `getElem_pop` gets stuck because of
|
||||
defeq issues in the implicit size argument.
|
||||
-/
|
||||
@[simp] theorem getElem_pop' (v : Vector α (n + 1)) (i : Nat) (h : i < n + 1 - 1) :
|
||||
@getElem (Vector α n) Nat α (fun _ i => i < n) instGetElemNatLt v.pop i h = v[i] :=
|
||||
getElem_pop h
|
||||
|
||||
@[simp] theorem push_pop_back (v : Vector α (n + 1)) : v.pop.push v.back = v := by
|
||||
ext i
|
||||
by_cases h : i < n
|
||||
· simp [h]
|
||||
· replace h : i = v.size - 1 := by rw [size_toArray]; omega
|
||||
subst h
|
||||
simp [pop, back, back!, ← Array.eq_push_pop_back!_of_size_ne_zero]
|
||||
|
||||
theorem push_swap (a : Vector α n) (x : α) {i j : Nat} {hi hj} :
|
||||
(a.swap i j hi hj).push x = (a.push x).swap i j := by
|
||||
cases a
|
||||
simp [Array.push_swap]
|
||||
|
||||
/-! ### cast -/
|
||||
|
||||
@[simp] theorem cast_mk {n m} (a : Array α) (w : a.size = n) (h : n = m) :
|
||||
(Vector.mk a w).cast h = ⟨a, h ▸ w⟩ := by
|
||||
simp [Vector.cast]
|
||||
|
||||
@[simp] theorem cast_refl {n} (a : Vector α n) : a.cast rfl = a := by
|
||||
cases a
|
||||
simp
|
||||
|
||||
@[simp] theorem toArray_cast {n m} (a : Vector α n) (h : n = m) :
|
||||
(a.cast h).toArray = a.toArray := by
|
||||
subst h
|
||||
simp
|
||||
|
||||
theorem cast_inj {n m} (a : Vector α n) (b : Vector α n) (h : n = m) :
|
||||
a.cast h = b.cast h ↔ a = b := by
|
||||
cases h
|
||||
simp
|
||||
|
||||
theorem cast_eq_iff {n m} (a : Vector α n) (b : Vector α m) (h : n = m) :
|
||||
a.cast h = b ↔ a = b.cast h.symm := by
|
||||
cases h
|
||||
simp
|
||||
|
||||
theorem eq_cast_iff {n m} (a : Vector α n) (b : Vector α m) (h : m = n) :
|
||||
a = b.cast h ↔ a.cast h.symm = b := by
|
||||
cases h
|
||||
simp
|
||||
|
||||
/-! ### Decidable quantifiers. -/
|
||||
|
||||
theorem forall_zero_iff {P : Vector α 0 → Prop} :
|
||||
(∀ v, P v) ↔ P #v[] := by
|
||||
constructor
|
||||
· intro h
|
||||
apply h
|
||||
· intro h v
|
||||
obtain (rfl : v = #v[]) := (by ext i h; simp at h)
|
||||
apply h
|
||||
|
||||
theorem forall_cons_iff {P : Vector α (n + 1) → Prop} :
|
||||
(∀ v : Vector α (n + 1), P v) ↔ (∀ (x : α) (v : Vector α n), P (v.push x)) := by
|
||||
constructor
|
||||
· intro h _ _
|
||||
apply h
|
||||
· intro h v
|
||||
have w : v = v.pop.push v.back := by simp
|
||||
rw [w]
|
||||
apply h
|
||||
|
||||
instance instDecidableForallVectorZero (P : Vector α 0 → Prop) :
|
||||
∀ [Decidable (P #v[])], Decidable (∀ v, P v)
|
||||
| .isTrue h => .isTrue fun ⟨v, s⟩ => by
|
||||
obtain (rfl : v = .empty) := (by ext i h₁ h₂; exact s; cases h₂)
|
||||
exact h
|
||||
| .isFalse h => .isFalse (fun w => h (w _))
|
||||
|
||||
instance instDecidableForallVectorSucc (P : Vector α (n+1) → Prop)
|
||||
[Decidable (∀ (x : α) (v : Vector α n), P (v.push x))] : Decidable (∀ v, P v) :=
|
||||
decidable_of_iff' (∀ x (v : Vector α n), P (v.push x)) forall_cons_iff
|
||||
|
||||
instance instDecidableExistsVectorZero (P : Vector α 0 → Prop) [Decidable (P #v[])] :
|
||||
Decidable (∃ v, P v) :=
|
||||
decidable_of_iff (¬ ∀ v, ¬ P v) Classical.not_forall_not
|
||||
|
||||
instance instDecidableExistsVectorSucc (P : Vector α (n+1) → Prop)
|
||||
[Decidable (∀ (x : α) (v : Vector α n), ¬ P (v.push x))] : Decidable (∃ v, P v) :=
|
||||
decidable_of_iff (¬ ∀ v, ¬ P v) Classical.not_forall_not
|
||||
@@ -206,12 +206,12 @@ instance : GetElem (List α) Nat α fun as i => i < as.length where
|
||||
@[simp] theorem getElem_cons_zero (a : α) (as : List α) (h : 0 < (a :: as).length) : getElem (a :: as) 0 h = a := by
|
||||
rfl
|
||||
|
||||
@[deprecated getElem_cons_zero (since := "2024-06-12")] abbrev cons_getElem_zero := @getElem_cons_zero
|
||||
@[deprecated (since := "2024-06-12")] abbrev cons_getElem_zero := @getElem_cons_zero
|
||||
|
||||
@[simp] theorem getElem_cons_succ (a : α) (as : List α) (i : Nat) (h : i + 1 < (a :: as).length) : getElem (a :: as) (i+1) h = getElem as i (Nat.lt_of_succ_lt_succ h) := by
|
||||
rfl
|
||||
|
||||
@[deprecated getElem_cons_succ (since := "2024-06-12")] abbrev cons_getElem_succ := @getElem_cons_succ
|
||||
@[deprecated (since := "2024-06-12")] abbrev cons_getElem_succ := @getElem_cons_succ
|
||||
|
||||
@[simp] theorem getElem_mem : ∀ {l : List α} {n} (h : n < l.length), l[n]'h ∈ l
|
||||
| _ :: _, 0, _ => .head ..
|
||||
@@ -223,8 +223,7 @@ theorem getElem_cons_drop_succ_eq_drop {as : List α} {i : Nat} (h : i < as.leng
|
||||
| _::_, 0 => rfl
|
||||
| _::_, i+1 => getElem_cons_drop_succ_eq_drop (i := i) _
|
||||
|
||||
@[deprecated getElem_cons_drop_succ_eq_drop (since := "2024-11-05")]
|
||||
abbrev get_drop_eq_drop := @getElem_cons_drop_succ_eq_drop
|
||||
@[deprecated (since := "2024-11-05")] abbrev get_drop_eq_drop := @getElem_cons_drop_succ_eq_drop
|
||||
|
||||
end List
|
||||
|
||||
|
||||
@@ -251,16 +251,10 @@ def neutralConfig : Simp.Config := {
|
||||
|
||||
end Simp
|
||||
|
||||
/-- Configuration for which occurrences that match an expression should be rewritten. -/
|
||||
inductive Occurrences where
|
||||
/-- All occurrences should be rewritten. -/
|
||||
| all
|
||||
/-- A list of indices for which occurrences should be rewritten. -/
|
||||
| pos (idxs : List Nat)
|
||||
/-- A list of indices for which occurrences should not be rewritten. -/
|
||||
| neg (idxs : List Nat)
|
||||
deriving Inhabited, BEq
|
||||
|
||||
instance : Coe (List Nat) Occurrences := ⟨.pos⟩
|
||||
|
||||
end Lean.Meta
|
||||
|
||||
@@ -71,9 +71,9 @@ def prio : Category := {}
|
||||
|
||||
/-- `prec` is a builtin syntax category for precedences. A precedence is a value
|
||||
that expresses how tightly a piece of syntax binds: for example `1 + 2 * 3` is
|
||||
parsed as `1 + (2 * 3)` because `*` has a higher precedence than `+`.
|
||||
parsed as `1 + (2 * 3)` because `*` has a higher pr0ecedence than `+`.
|
||||
Higher numbers denote higher precedence.
|
||||
In addition to literals like `37`, there are some special named precedence levels:
|
||||
In addition to literals like `37`, there are some special named priorities:
|
||||
* `arg` for the precedence of function arguments
|
||||
* `max` for the highest precedence used in term parsers (not actually the maximum possible value)
|
||||
* `lead` for the precedence of terms not supposed to be used as arguments
|
||||
|
||||
@@ -2116,11 +2116,6 @@ theorem usize_size_eq : Or (Eq USize.size 4294967296) (Eq USize.size 18446744073
|
||||
| _, Or.inl rfl => Or.inl (of_decide_eq_true rfl)
|
||||
| _, Or.inr rfl => Or.inr (of_decide_eq_true rfl)
|
||||
|
||||
theorem usize_size_pos : LT.lt 0 USize.size :=
|
||||
match USize.size, usize_size_eq with
|
||||
| _, Or.inl rfl => of_decide_eq_true rfl
|
||||
| _, Or.inr rfl => of_decide_eq_true rfl
|
||||
|
||||
/--
|
||||
A `USize` is an unsigned integer with the size of a word
|
||||
for the platform's architecture.
|
||||
@@ -2160,7 +2155,24 @@ def USize.decEq (a b : USize) : Decidable (Eq a b) :=
|
||||
instance : DecidableEq USize := USize.decEq
|
||||
|
||||
instance : Inhabited USize where
|
||||
default := USize.ofNatCore 0 usize_size_pos
|
||||
default := USize.ofNatCore 0 (match USize.size, usize_size_eq with
|
||||
| _, Or.inl rfl => of_decide_eq_true rfl
|
||||
| _, Or.inr rfl => of_decide_eq_true rfl)
|
||||
|
||||
/--
|
||||
Upcast a `Nat` less than `2^32` to a `USize`.
|
||||
This is lossless because `USize.size` is either `2^32` or `2^64`.
|
||||
This function is overridden with a native implementation.
|
||||
-/
|
||||
@[extern "lean_usize_of_nat"]
|
||||
def USize.ofNat32 (n : @& Nat) (h : LT.lt n 4294967296) : USize where
|
||||
toBitVec :=
|
||||
BitVec.ofNatLt n (
|
||||
match System.Platform.numBits, System.Platform.numBits_eq with
|
||||
| _, Or.inl rfl => h
|
||||
| _, Or.inr rfl => Nat.lt_trans h (of_decide_eq_true rfl)
|
||||
)
|
||||
|
||||
/--
|
||||
A `Nat` denotes a valid unicode codepoint if it is less than `0x110000`, and
|
||||
it is also not a "surrogate" character (the range `0xd800` to `0xdfff` inclusive).
|
||||
@@ -3420,6 +3432,25 @@ class Hashable (α : Sort u) where
|
||||
|
||||
export Hashable (hash)
|
||||
|
||||
/-- Converts a `UInt64` to a `USize` by reducing modulo `USize.size`. -/
|
||||
@[extern "lean_uint64_to_usize"]
|
||||
opaque UInt64.toUSize (u : UInt64) : USize
|
||||
|
||||
/--
|
||||
Upcast a `USize` to a `UInt64`.
|
||||
This is lossless because `USize.size` is either `2^32` or `2^64`.
|
||||
This function is overridden with a native implementation.
|
||||
-/
|
||||
@[extern "lean_usize_to_uint64"]
|
||||
def USize.toUInt64 (u : USize) : UInt64 where
|
||||
toBitVec := BitVec.ofNatLt u.toBitVec.toNat (
|
||||
let ⟨⟨n, h⟩⟩ := u
|
||||
show LT.lt n _ from
|
||||
match System.Platform.numBits, System.Platform.numBits_eq, h with
|
||||
| _, Or.inl rfl, h => Nat.lt_trans h (of_decide_eq_true rfl)
|
||||
| _, Or.inr rfl, h => h
|
||||
)
|
||||
|
||||
/-- An opaque hash mixing operation, used to implement hashing for tuples. -/
|
||||
@[extern "lean_uint64_mix_hash"]
|
||||
opaque mixHash (u₁ u₂ : UInt64) : UInt64
|
||||
|
||||
@@ -5,7 +5,6 @@ Authors: Leonardo de Moura, Mario Carneiro
|
||||
-/
|
||||
prelude
|
||||
import Init.Util
|
||||
import Init.Data.UInt.Basic
|
||||
|
||||
namespace ShareCommon
|
||||
/-
|
||||
|
||||
@@ -72,21 +72,6 @@ theorem let_body_congr {α : Sort u} {β : α → Sort v} {b b' : (a : α) →
|
||||
(a : α) (h : ∀ x, b x = b' x) : (let x := a; b x) = (let x := a; b' x) :=
|
||||
(funext h : b = b') ▸ rfl
|
||||
|
||||
theorem letFun_unused {α : Sort u} {β : Sort v} (a : α) {b b' : β} (h : b = b') : @letFun α (fun _ => β) a (fun _ => b) = b' :=
|
||||
h
|
||||
|
||||
theorem letFun_congr {α : Sort u} {β : Sort v} {a a' : α} {f f' : α → β} (h₁ : a = a') (h₂ : ∀ x, f x = f' x)
|
||||
: @letFun α (fun _ => β) a f = @letFun α (fun _ => β) a' f' := by
|
||||
rw [h₁, funext h₂]
|
||||
|
||||
theorem letFun_body_congr {α : Sort u} {β : Sort v} (a : α) {f f' : α → β} (h : ∀ x, f x = f' x)
|
||||
: @letFun α (fun _ => β) a f = @letFun α (fun _ => β) a f' := by
|
||||
rw [funext h]
|
||||
|
||||
theorem letFun_val_congr {α : Sort u} {β : Sort v} {a a' : α} {f : α → β} (h : a = a')
|
||||
: @letFun α (fun _ => β) a f = @letFun α (fun _ => β) a' f := by
|
||||
rw [h]
|
||||
|
||||
@[congr]
|
||||
theorem ite_congr {x y u v : α} {s : Decidable b} [Decidable c]
|
||||
(h₁ : b = c) (h₂ : c → x = u) (h₃ : ¬ c → y = v) : ite b x y = ite c u v := by
|
||||
|
||||
@@ -30,7 +30,7 @@ Does nothing for non-`node` nodes, or if `i` is out of bounds of the node list.
|
||||
-/
|
||||
def setArg (stx : Syntax) (i : Nat) (arg : Syntax) : Syntax :=
|
||||
match stx with
|
||||
| node info k args => node info k (args.setIfInBounds i arg)
|
||||
| node info k args => node info k (args.setD i arg)
|
||||
| stx => stx
|
||||
|
||||
end Lean.Syntax
|
||||
|
||||
@@ -462,16 +462,6 @@ Note that it is the caller's job to remove the file after use.
|
||||
-/
|
||||
@[extern "lean_io_create_tempfile"] opaque createTempFile : IO (Handle × FilePath)
|
||||
|
||||
/--
|
||||
Creates a temporary directory in the most secure manner possible. There are no race conditions in the
|
||||
directory’s creation. The directory is readable and writable only by the creating user ID.
|
||||
|
||||
Returns the new directory's path.
|
||||
|
||||
It is the caller's job to remove the directory after use.
|
||||
-/
|
||||
@[extern "lean_io_create_tempdir"] opaque createTempDir : IO FilePath
|
||||
|
||||
end FS
|
||||
|
||||
@[extern "lean_io_getenv"] opaque getEnv (var : @& String) : BaseIO (Option String)
|
||||
@@ -484,6 +474,17 @@ namespace FS
|
||||
def withFile (fn : FilePath) (mode : Mode) (f : Handle → IO α) : IO α :=
|
||||
Handle.mk fn mode >>= f
|
||||
|
||||
/--
|
||||
Like `createTempFile` but also takes care of removing the file after usage.
|
||||
-/
|
||||
def withTempFile [Monad m] [MonadFinally m] [MonadLiftT IO m] (f : Handle → FilePath → m α) :
|
||||
m α := do
|
||||
let (handle, path) ← createTempFile
|
||||
try
|
||||
f handle path
|
||||
finally
|
||||
removeFile path
|
||||
|
||||
def Handle.putStrLn (h : Handle) (s : String) : IO Unit :=
|
||||
h.putStr (s.push '\n')
|
||||
|
||||
@@ -674,10 +675,8 @@ def appDir : IO FilePath := do
|
||||
| throw <| IO.userError s!"System.IO.appDir: unexpected filename '{p}'"
|
||||
FS.realPath p
|
||||
|
||||
namespace FS
|
||||
|
||||
/-- Create given path and all missing parents as directories. -/
|
||||
partial def createDirAll (p : FilePath) : IO Unit := do
|
||||
partial def FS.createDirAll (p : FilePath) : IO Unit := do
|
||||
if ← p.isDir then
|
||||
return ()
|
||||
if let some parent := p.parent then
|
||||
@@ -694,7 +693,7 @@ partial def createDirAll (p : FilePath) : IO Unit := do
|
||||
/--
|
||||
Fully remove given directory by deleting all contained files and directories in an unspecified order.
|
||||
Fails if any contained entry cannot be deleted or was newly created during execution. -/
|
||||
partial def removeDirAll (p : FilePath) : IO Unit := do
|
||||
partial def FS.removeDirAll (p : FilePath) : IO Unit := do
|
||||
for ent in (← p.readDir) do
|
||||
if (← ent.path.isDir : Bool) then
|
||||
removeDirAll ent.path
|
||||
@@ -702,32 +701,6 @@ partial def removeDirAll (p : FilePath) : IO Unit := do
|
||||
removeFile ent.path
|
||||
removeDir p
|
||||
|
||||
/--
|
||||
Like `createTempFile`, but also takes care of removing the file after usage.
|
||||
-/
|
||||
def withTempFile [Monad m] [MonadFinally m] [MonadLiftT IO m] (f : Handle → FilePath → m α) :
|
||||
m α := do
|
||||
let (handle, path) ← createTempFile
|
||||
try
|
||||
f handle path
|
||||
finally
|
||||
removeFile path
|
||||
|
||||
/--
|
||||
Like `createTempDir`, but also takes care of removing the directory after usage.
|
||||
|
||||
All files in the directory are recursively deleted, regardless of how or when they were created.
|
||||
-/
|
||||
def withTempDir [Monad m] [MonadFinally m] [MonadLiftT IO m] (f : FilePath → m α) :
|
||||
m α := do
|
||||
let path ← createTempDir
|
||||
try
|
||||
f path
|
||||
finally
|
||||
removeDirAll path
|
||||
|
||||
end FS
|
||||
|
||||
namespace Process
|
||||
|
||||
/-- Returns the current working directory of the calling process. -/
|
||||
|
||||
@@ -428,11 +428,11 @@ macro "infer_instance" : tactic => `(tactic| exact inferInstance)
|
||||
/--
|
||||
`+opt` is short for `(opt := true)`. It sets the `opt` configuration option to `true`.
|
||||
-/
|
||||
syntax posConfigItem := " +" noWs ident
|
||||
syntax posConfigItem := "+" noWs ident
|
||||
/--
|
||||
`-opt` is short for `(opt := false)`. It sets the `opt` configuration option to `false`.
|
||||
-/
|
||||
syntax negConfigItem := " -" noWs ident
|
||||
syntax negConfigItem := "-" noWs ident
|
||||
/--
|
||||
`(opt := val)` sets the `opt` configuration option to `val`.
|
||||
|
||||
|
||||
@@ -90,14 +90,10 @@ def withPtrAddr {α : Type u} {β : Type v} (a : α) (k : USize → β) (h : ∀
|
||||
def Runtime.markMultiThreaded (a : α) : α := a
|
||||
|
||||
/--
|
||||
Marks given value and its object graph closure as persistent. This will remove
|
||||
reference counter updates but prevent the closure from being deallocated until
|
||||
the end of the process! It can still be useful to do eagerly when the value
|
||||
will be marked persistent later anyway and there is available time budget to
|
||||
mark it now or it would be unnecessarily marked multi-threaded in between.
|
||||
|
||||
This function is only safe to use on objects (in the full closure) which are
|
||||
not used concurrently or which are already persistent.
|
||||
-/
|
||||
Marks given value and its object graph closure as persistent. This will remove
|
||||
reference counter updates but prevent the closure from being deallocated until
|
||||
the end of the process! It can still be useful to do eagerly when the value
|
||||
will be marked persistent later anyway and there is available time budget to
|
||||
mark it now or it would be unnecessarily marked multi-threaded in between. -/
|
||||
@[extern "lean_runtime_mark_persistent"]
|
||||
unsafe def Runtime.markPersistent (a : α) : α := a
|
||||
def Runtime.markPersistent (a : α) : α := a
|
||||
|
||||
@@ -205,8 +205,8 @@ def getParamInfo (k : ParamMap.Key) : M (Array Param) := do
|
||||
|
||||
/-- For each ps[i], if ps[i] is owned, then mark xs[i] as owned. -/
|
||||
def ownArgsUsingParams (xs : Array Arg) (ps : Array Param) : M Unit :=
|
||||
xs.size.forM fun i _ => do
|
||||
let x := xs[i]
|
||||
xs.size.forM fun i => do
|
||||
let x := xs[i]!
|
||||
let p := ps[i]!
|
||||
unless p.borrow do ownArg x
|
||||
|
||||
@@ -216,8 +216,8 @@ def ownArgsUsingParams (xs : Array Arg) (ps : Array Param) : M Unit :=
|
||||
we would have to insert a `dec xs[i]` after `f xs` and consequently
|
||||
"break" the tail call. -/
|
||||
def ownParamsUsingArgs (xs : Array Arg) (ps : Array Param) : M Unit :=
|
||||
xs.size.forM fun i _ => do
|
||||
let x := xs[i]
|
||||
xs.size.forM fun i => do
|
||||
let x := xs[i]!
|
||||
let p := ps[i]!
|
||||
match x with
|
||||
| Arg.var x => if (← isOwned x) then ownVar p.x
|
||||
|
||||
@@ -48,9 +48,9 @@ def requiresBoxedVersion (env : Environment) (decl : Decl) : Bool :=
|
||||
def mkBoxedVersionAux (decl : Decl) : N Decl := do
|
||||
let ps := decl.params
|
||||
let qs ← ps.mapM fun _ => do let x ← N.mkFresh; pure { x := x, ty := IRType.object, borrow := false : Param }
|
||||
let (newVDecls, xs) ← qs.size.foldM (init := (#[], #[])) fun i _ (newVDecls, xs) => do
|
||||
let (newVDecls, xs) ← qs.size.foldM (init := (#[], #[])) fun i (newVDecls, xs) => do
|
||||
let p := ps[i]!
|
||||
let q := qs[i]
|
||||
let q := qs[i]!
|
||||
if !p.ty.isScalar then
|
||||
pure (newVDecls, xs.push (Arg.var q.x))
|
||||
else
|
||||
|
||||
@@ -63,7 +63,7 @@ partial def merge (v₁ v₂ : Value) : Value :=
|
||||
| top, _ => top
|
||||
| _, top => top
|
||||
| v₁@(ctor i₁ vs₁), v₂@(ctor i₂ vs₂) =>
|
||||
if i₁ == i₂ then ctor i₁ <| vs₁.size.fold (init := #[]) fun i _ r => r.push (merge vs₁[i] vs₂[i]!)
|
||||
if i₁ == i₂ then ctor i₁ <| vs₁.size.fold (init := #[]) fun i r => r.push (merge vs₁[i]! vs₂[i]!)
|
||||
else choice [v₁, v₂]
|
||||
| choice vs₁, choice vs₂ => choice <| vs₁.foldl (addChoice merge) vs₂
|
||||
| choice vs, v => choice <| addChoice merge vs v
|
||||
@@ -225,8 +225,8 @@ def updateCurrFnSummary (v : Value) : M Unit := do
|
||||
def updateJPParamsAssignment (ys : Array Param) (xs : Array Arg) : M Bool := do
|
||||
let ctx ← read
|
||||
let currFnIdx := ctx.currFnIdx
|
||||
ys.size.foldM (init := false) fun i _ r => do
|
||||
let y := ys[i]
|
||||
ys.size.foldM (init := false) fun i r => do
|
||||
let y := ys[i]!
|
||||
let x := xs[i]!
|
||||
let yVal ← findVarValue y.x
|
||||
let xVal ← findArgValue x
|
||||
@@ -282,8 +282,8 @@ partial def interpFnBody : FnBody → M Unit
|
||||
def inferStep : M Bool := do
|
||||
let ctx ← read
|
||||
modify fun s => { s with assignments := ctx.decls.map fun _ => {} }
|
||||
ctx.decls.size.foldM (init := false) fun idx _ modified => do
|
||||
match ctx.decls[idx] with
|
||||
ctx.decls.size.foldM (init := false) fun idx modified => do
|
||||
match ctx.decls[idx]! with
|
||||
| .fdecl (xs := ys) (body := b) .. => do
|
||||
let s ← get
|
||||
let currVals := s.funVals[idx]!
|
||||
@@ -336,8 +336,8 @@ def elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do
|
||||
let funVals := s.funVals
|
||||
let assignments := s.assignments
|
||||
modify fun s =>
|
||||
let env := decls.size.fold (init := s.env) fun i _ env =>
|
||||
addFunctionSummary env decls[i].name funVals[i]!
|
||||
let env := decls.size.fold (init := s.env) fun i env =>
|
||||
addFunctionSummary env decls[i]!.name funVals[i]!
|
||||
{ s with env := env }
|
||||
return decls.mapIdx fun i decl => elimDead assignments[i]! decl
|
||||
|
||||
|
||||
@@ -108,9 +108,9 @@ def emitFnDeclAux (decl : Decl) (cppBaseName : String) (isExternal : Bool) : M U
|
||||
if ps.size > closureMaxArgs && isBoxedName decl.name then
|
||||
emit "lean_object**"
|
||||
else
|
||||
ps.size.forM fun i _ => do
|
||||
ps.size.forM fun i => do
|
||||
if i > 0 then emit ", "
|
||||
emit (toCType ps[i].ty)
|
||||
emit (toCType ps[i]!.ty)
|
||||
emit ")"
|
||||
emitLn ";"
|
||||
|
||||
@@ -321,22 +321,20 @@ def emitSSet (x : VarId) (n : Nat) (offset : Nat) (y : VarId) (t : IRType) : M U
|
||||
|
||||
def emitJmp (j : JoinPointId) (xs : Array Arg) : M Unit := do
|
||||
let ps ← getJPParams j
|
||||
if h : xs.size = ps.size then
|
||||
xs.size.forM fun i _ => do
|
||||
let p := ps[i]
|
||||
let x := xs[i]
|
||||
emit p.x; emit " = "; emitArg x; emitLn ";"
|
||||
emit "goto "; emit j; emitLn ";"
|
||||
else
|
||||
do throw "invalid goto"
|
||||
unless xs.size == ps.size do throw "invalid goto"
|
||||
xs.size.forM fun i => do
|
||||
let p := ps[i]!
|
||||
let x := xs[i]!
|
||||
emit p.x; emit " = "; emitArg x; emitLn ";"
|
||||
emit "goto "; emit j; emitLn ";"
|
||||
|
||||
def emitLhs (z : VarId) : M Unit := do
|
||||
emit z; emit " = "
|
||||
|
||||
def emitArgs (ys : Array Arg) : M Unit :=
|
||||
ys.size.forM fun i _ => do
|
||||
ys.size.forM fun i => do
|
||||
if i > 0 then emit ", "
|
||||
emitArg ys[i]
|
||||
emitArg ys[i]!
|
||||
|
||||
def emitCtorScalarSize (usize : Nat) (ssize : Nat) : M Unit := do
|
||||
if usize == 0 then emit ssize
|
||||
@@ -348,8 +346,8 @@ def emitAllocCtor (c : CtorInfo) : M Unit := do
|
||||
emitCtorScalarSize c.usize c.ssize; emitLn ");"
|
||||
|
||||
def emitCtorSetArgs (z : VarId) (ys : Array Arg) : M Unit :=
|
||||
ys.size.forM fun i _ => do
|
||||
emit "lean_ctor_set("; emit z; emit ", "; emit i; emit ", "; emitArg ys[i]; emitLn ");"
|
||||
ys.size.forM fun i => do
|
||||
emit "lean_ctor_set("; emit z; emit ", "; emit i; emit ", "; emitArg ys[i]!; emitLn ");"
|
||||
|
||||
def emitCtor (z : VarId) (c : CtorInfo) (ys : Array Arg) : M Unit := do
|
||||
emitLhs z;
|
||||
@@ -360,7 +358,7 @@ def emitCtor (z : VarId) (c : CtorInfo) (ys : Array Arg) : M Unit := do
|
||||
|
||||
def emitReset (z : VarId) (n : Nat) (x : VarId) : M Unit := do
|
||||
emit "if (lean_is_exclusive("; emit x; emitLn ")) {";
|
||||
n.forM fun i _ => do
|
||||
n.forM fun i => do
|
||||
emit " lean_ctor_release("; emit x; emit ", "; emit i; emitLn ");"
|
||||
emit " "; emitLhs z; emit x; emitLn ";";
|
||||
emitLn "} else {";
|
||||
@@ -401,12 +399,12 @@ def emitSimpleExternalCall (f : String) (ps : Array Param) (ys : Array Arg) : M
|
||||
emit f; emit "("
|
||||
-- We must remove irrelevant arguments to extern calls.
|
||||
discard <| ys.size.foldM
|
||||
(fun i _ (first : Bool) =>
|
||||
(fun i (first : Bool) =>
|
||||
if ps[i]!.ty.isIrrelevant then
|
||||
pure first
|
||||
else do
|
||||
unless first do emit ", "
|
||||
emitArg ys[i]
|
||||
emitArg ys[i]!
|
||||
pure false)
|
||||
true
|
||||
emitLn ");"
|
||||
@@ -433,8 +431,8 @@ def emitPartialApp (z : VarId) (f : FunId) (ys : Array Arg) : M Unit := do
|
||||
let decl ← getDecl f
|
||||
let arity := decl.params.size;
|
||||
emitLhs z; emit "lean_alloc_closure((void*)("; emitCName f; emit "), "; emit arity; emit ", "; emit ys.size; emitLn ");";
|
||||
ys.size.forM fun i _ => do
|
||||
let y := ys[i]
|
||||
ys.size.forM fun i => do
|
||||
let y := ys[i]!
|
||||
emit "lean_closure_set("; emit z; emit ", "; emit i; emit ", "; emitArg y; emitLn ");"
|
||||
|
||||
def emitApp (z : VarId) (f : VarId) (ys : Array Arg) : M Unit :=
|
||||
@@ -546,36 +544,34 @@ That is, we have
|
||||
-/
|
||||
def overwriteParam (ps : Array Param) (ys : Array Arg) : Bool :=
|
||||
let n := ps.size;
|
||||
n.any fun i _ =>
|
||||
let p := ps[i]
|
||||
(i+1, n).anyI fun j _ _ => paramEqArg p ys[j]!
|
||||
n.any fun i =>
|
||||
let p := ps[i]!
|
||||
(i+1, n).anyI fun j => paramEqArg p ys[j]!
|
||||
|
||||
def emitTailCall (v : Expr) : M Unit :=
|
||||
match v with
|
||||
| Expr.fap _ ys => do
|
||||
let ctx ← read
|
||||
let ps := ctx.mainParams
|
||||
if h : ps.size = ys.size then
|
||||
if overwriteParam ps ys then
|
||||
emitLn "{"
|
||||
ps.size.forM fun i _ => do
|
||||
let p := ps[i]
|
||||
let y := ys[i]
|
||||
unless paramEqArg p y do
|
||||
emit (toCType p.ty); emit " _tmp_"; emit i; emit " = "; emitArg y; emitLn ";"
|
||||
ps.size.forM fun i _ => do
|
||||
let p := ps[i]
|
||||
let y := ys[i]
|
||||
unless paramEqArg p y do emit p.x; emit " = _tmp_"; emit i; emitLn ";"
|
||||
emitLn "}"
|
||||
else
|
||||
ys.size.forM fun i _ => do
|
||||
let p := ps[i]
|
||||
let y := ys[i]
|
||||
unless paramEqArg p y do emit p.x; emit " = "; emitArg y; emitLn ";"
|
||||
emitLn "goto _start;"
|
||||
unless ps.size == ys.size do throw "invalid tail call"
|
||||
if overwriteParam ps ys then
|
||||
emitLn "{"
|
||||
ps.size.forM fun i => do
|
||||
let p := ps[i]!
|
||||
let y := ys[i]!
|
||||
unless paramEqArg p y do
|
||||
emit (toCType p.ty); emit " _tmp_"; emit i; emit " = "; emitArg y; emitLn ";"
|
||||
ps.size.forM fun i => do
|
||||
let p := ps[i]!
|
||||
let y := ys[i]!
|
||||
unless paramEqArg p y do emit p.x; emit " = _tmp_"; emit i; emitLn ";"
|
||||
emitLn "}"
|
||||
else
|
||||
throw "invalid tail call"
|
||||
ys.size.forM fun i => do
|
||||
let p := ps[i]!
|
||||
let y := ys[i]!
|
||||
unless paramEqArg p y do emit p.x; emit " = "; emitArg y; emitLn ";"
|
||||
emitLn "goto _start;"
|
||||
| _ => throw "bug at emitTailCall"
|
||||
|
||||
mutual
|
||||
@@ -658,16 +654,16 @@ def emitDeclAux (d : Decl) : M Unit := do
|
||||
if xs.size > closureMaxArgs && isBoxedName d.name then
|
||||
emit "lean_object** _args"
|
||||
else
|
||||
xs.size.forM fun i _ => do
|
||||
xs.size.forM fun i => do
|
||||
if i > 0 then emit ", "
|
||||
let x := xs[i]
|
||||
let x := xs[i]!
|
||||
emit (toCType x.ty); emit " "; emit x.x
|
||||
emit ")"
|
||||
else
|
||||
emit ("_init_" ++ baseName ++ "()")
|
||||
emitLn " {";
|
||||
if xs.size > closureMaxArgs && isBoxedName d.name then
|
||||
xs.size.forM fun i _ => do
|
||||
xs.size.forM fun i => do
|
||||
let x := xs[i]!
|
||||
emit "lean_object* "; emit x.x; emit " = _args["; emit i; emitLn "];"
|
||||
emitLn "_start:";
|
||||
|
||||
@@ -571,9 +571,9 @@ def emitAllocCtor (builder : LLVM.Builder llvmctx)
|
||||
|
||||
def emitCtorSetArgs (builder : LLVM.Builder llvmctx)
|
||||
(z : VarId) (ys : Array Arg) : M llvmctx Unit := do
|
||||
ys.size.forM fun i _ => do
|
||||
ys.size.forM fun i => do
|
||||
let zv ← emitLhsVal builder z
|
||||
let (_yty, yv) ← emitArgVal builder ys[i]
|
||||
let (_yty, yv) ← emitArgVal builder ys[i]!
|
||||
let iv ← constIntUnsigned i
|
||||
callLeanCtorSet builder zv iv yv
|
||||
emitLhsSlotStore builder z zv
|
||||
@@ -702,8 +702,8 @@ def emitPartialApp (builder : LLVM.Builder llvmctx) (z : VarId) (f : FunId) (ys
|
||||
(← constIntUnsigned arity)
|
||||
(← constIntUnsigned ys.size)
|
||||
LLVM.buildStore builder zval zslot
|
||||
ys.size.forM fun i _ => do
|
||||
let (yty, yslot) ← emitArgSlot_ builder ys[i]
|
||||
ys.size.forM fun i => do
|
||||
let (yty, yslot) ← emitArgSlot_ builder ys[i]!
|
||||
let yval ← LLVM.buildLoad2 builder yty yslot
|
||||
callLeanClosureSetFn builder zval (← constIntUnsigned i) yval
|
||||
|
||||
@@ -922,7 +922,7 @@ def emitReset (builder : LLVM.Builder llvmctx) (z : VarId) (n : Nat) (x : VarId)
|
||||
buildIfThenElse_ builder "isExclusive" isExclusive
|
||||
(fun builder => do
|
||||
let xv ← emitLhsVal builder x
|
||||
n.forM fun i _ => do
|
||||
n.forM fun i => do
|
||||
callLeanCtorRelease builder xv (← constIntUnsigned i)
|
||||
emitLhsSlotStore builder z xv
|
||||
return ShouldForwardControlFlow.yes
|
||||
|
||||
@@ -134,15 +134,15 @@ abbrev M := ReaderT Context (StateM Nat)
|
||||
modifyGet fun n => ({ idx := n }, n + 1)
|
||||
|
||||
def releaseUnreadFields (y : VarId) (mask : Mask) (b : FnBody) : M FnBody :=
|
||||
mask.size.foldM (init := b) fun i _ b =>
|
||||
match mask[i] with
|
||||
mask.size.foldM (init := b) fun i b =>
|
||||
match mask.get! i with
|
||||
| some _ => pure b -- code took ownership of this field
|
||||
| none => do
|
||||
let fld ← mkFresh
|
||||
pure (FnBody.vdecl fld IRType.object (Expr.proj i y) (FnBody.dec fld 1 true false b))
|
||||
|
||||
def setFields (y : VarId) (zs : Array Arg) (b : FnBody) : FnBody :=
|
||||
zs.size.fold (init := b) fun i _ b => FnBody.set y i zs[i] b
|
||||
zs.size.fold (init := b) fun i b => FnBody.set y i (zs.get! i) b
|
||||
|
||||
/-- Given `set x[i] := y`, return true iff `y := proj[i] x` -/
|
||||
def isSelfSet (ctx : Context) (x : VarId) (i : Nat) (y : Arg) : Bool :=
|
||||
|
||||
@@ -79,13 +79,13 @@ private def addDecForAlt (ctx : Context) (caseLiveVars altLiveVars : LiveVarSet)
|
||||
/-- `isFirstOcc xs x i = true` if `xs[i]` is the first occurrence of `xs[i]` in `xs` -/
|
||||
private def isFirstOcc (xs : Array Arg) (i : Nat) : Bool :=
|
||||
let x := xs[i]!
|
||||
i.all fun j _ => xs[j]! != x
|
||||
i.all fun j => xs[j]! != x
|
||||
|
||||
/-- Return true if `x` also occurs in `ys` in a position that is not consumed.
|
||||
That is, it is also passed as a borrow reference. -/
|
||||
private def isBorrowParamAux (x : VarId) (ys : Array Arg) (consumeParamPred : Nat → Bool) : Bool :=
|
||||
ys.size.any fun i _ =>
|
||||
let y := ys[i]
|
||||
ys.size.any fun i =>
|
||||
let y := ys[i]!
|
||||
match y with
|
||||
| Arg.irrelevant => false
|
||||
| Arg.var y => x == y && !consumeParamPred i
|
||||
@@ -99,15 +99,15 @@ Return `n`, the number of times `x` is consumed.
|
||||
- `consumeParamPred i = true` if parameter `i` is consumed.
|
||||
-/
|
||||
private def getNumConsumptions (x : VarId) (ys : Array Arg) (consumeParamPred : Nat → Bool) : Nat :=
|
||||
ys.size.fold (init := 0) fun i _ n =>
|
||||
let y := ys[i]
|
||||
ys.size.fold (init := 0) fun i n =>
|
||||
let y := ys[i]!
|
||||
match y with
|
||||
| Arg.irrelevant => n
|
||||
| Arg.var y => if x == y && consumeParamPred i then n+1 else n
|
||||
|
||||
private def addIncBeforeAux (ctx : Context) (xs : Array Arg) (consumeParamPred : Nat → Bool) (b : FnBody) (liveVarsAfter : LiveVarSet) : FnBody :=
|
||||
xs.size.fold (init := b) fun i _ b =>
|
||||
let x := xs[i]
|
||||
xs.size.fold (init := b) fun i b =>
|
||||
let x := xs[i]!
|
||||
match x with
|
||||
| Arg.irrelevant => b
|
||||
| Arg.var x =>
|
||||
@@ -128,8 +128,8 @@ private def addIncBefore (ctx : Context) (xs : Array Arg) (ps : Array Param) (b
|
||||
|
||||
/-- See `addIncBeforeAux`/`addIncBefore` for the procedure that inserts `inc` operations before an application. -/
|
||||
private def addDecAfterFullApp (ctx : Context) (xs : Array Arg) (ps : Array Param) (b : FnBody) (bLiveVars : LiveVarSet) : FnBody :=
|
||||
xs.size.fold (init := b) fun i _ b =>
|
||||
match xs[i] with
|
||||
xs.size.fold (init := b) fun i b =>
|
||||
match xs[i]! with
|
||||
| Arg.irrelevant => b
|
||||
| Arg.var x =>
|
||||
/- We must add a `dec` if `x` must be consumed, it is alive after the application,
|
||||
|
||||
@@ -587,15 +587,15 @@ def Decl.elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do
|
||||
refer to the docstring of `Decl.safe`.
|
||||
-/
|
||||
if decls[i]!.safe then .bot else .top
|
||||
let mut funVals := decls.size.fold (init := .empty) fun i _ p => p.push (initialVal i)
|
||||
let mut funVals := decls.size.fold (init := .empty) fun i p => p.push (initialVal i)
|
||||
let ctx := { decls }
|
||||
let mut state := { assignments, funVals }
|
||||
(_, state) ← inferMain |>.run ctx |>.run state
|
||||
funVals := state.funVals
|
||||
assignments := state.assignments
|
||||
modifyEnv fun e =>
|
||||
decls.size.fold (init := e) fun i _ env =>
|
||||
addFunctionSummary env decls[i].name funVals[i]!
|
||||
decls.size.fold (init := e) fun i env =>
|
||||
addFunctionSummary env decls[i]!.name funVals[i]!
|
||||
|
||||
decls.mapIdxM fun i decl => if decl.safe then elimDead assignments[i]! decl else return decl
|
||||
|
||||
|
||||
@@ -76,8 +76,8 @@ def getType (fvarId : FVarId) : InferTypeM Expr := do
|
||||
|
||||
def mkForallFVars (xs : Array Expr) (type : Expr) : InferTypeM Expr :=
|
||||
let b := type.abstract xs
|
||||
xs.size.foldRevM (init := b) fun i _ b => do
|
||||
let x := xs[i]
|
||||
xs.size.foldRevM (init := b) fun i b => do
|
||||
let x := xs[i]!
|
||||
let n ← InferType.getBinderName x.fvarId!
|
||||
let ty ← InferType.getType x.fvarId!
|
||||
let ty := ty.abstractRange i xs;
|
||||
|
||||
@@ -157,7 +157,9 @@ def installBeforeEachOccurrence (targetName : Name) (p : Pass → Pass) : PassIn
|
||||
def replacePass (targetName : Name) (p : Pass → Pass) (occurrence : Nat := 0) : PassInstaller where
|
||||
install passes := do
|
||||
let some idx := passes.findIdx? (fun p => p.name == targetName && p.occurrence == occurrence) | throwError s!"Tried to replace {targetName}, occurrence {occurrence} but {targetName} is not in the pass list"
|
||||
return passes.modify idx p
|
||||
let target := passes[idx]!
|
||||
let replacement := p target
|
||||
return passes.set! idx replacement
|
||||
|
||||
def replaceEachOccurrence (targetName : Name) (p : Pass → Pass) : PassInstaller :=
|
||||
withEachOccurrence targetName (replacePass targetName p ·)
|
||||
|
||||
@@ -11,7 +11,6 @@ import Lean.ResolveName
|
||||
import Lean.Elab.InfoTree.Types
|
||||
import Lean.MonadEnv
|
||||
import Lean.Elab.Exception
|
||||
import Lean.Language.Basic
|
||||
|
||||
namespace Lean
|
||||
register_builtin_option diagnostics : Bool := {
|
||||
@@ -31,11 +30,6 @@ register_builtin_option maxHeartbeats : Nat := {
|
||||
descr := "maximum amount of heartbeats per command. A heartbeat is number of (small) memory allocations (in thousands), 0 means no limit"
|
||||
}
|
||||
|
||||
register_builtin_option Elab.async : Bool := {
|
||||
defValue := false
|
||||
descr := "perform elaboration using multiple threads where possible"
|
||||
}
|
||||
|
||||
/--
|
||||
If the `diagnostics` option is not already set, gives a message explaining this option.
|
||||
Begins with a `\n`, so an error message can look like `m!"some error occurred{useDiagnosticMsg}"`.
|
||||
@@ -78,13 +72,6 @@ structure State where
|
||||
messages : MessageLog := {}
|
||||
/-- Info tree. We have the info tree here because we want to update it while adding attributes. -/
|
||||
infoState : Elab.InfoState := {}
|
||||
/--
|
||||
Snapshot trees of asynchronous subtasks. As these are untyped and reported only at the end of the
|
||||
command's main elaboration thread, they are only useful for basic message log reporting; for
|
||||
incremental reporting and reuse within a long-running elaboration thread, types rooted in
|
||||
`CommandParsedSnapshot` need to be adjusted.
|
||||
-/
|
||||
snapshotTasks : Array (Language.SnapshotTask Language.SnapshotTree) := #[]
|
||||
deriving Nonempty
|
||||
|
||||
/-- Context for the CoreM monad. -/
|
||||
@@ -193,8 +180,7 @@ instance : Elab.MonadInfoTree CoreM where
|
||||
modifyInfoState f := modify fun s => { s with infoState := f s.infoState }
|
||||
|
||||
@[inline] def modifyCache (f : Cache → Cache) : CoreM Unit :=
|
||||
modify fun ⟨env, next, ngen, trace, cache, messages, infoState, snaps⟩ =>
|
||||
⟨env, next, ngen, trace, f cache, messages, infoState, snaps⟩
|
||||
modify fun ⟨env, next, ngen, trace, cache, messages, infoState⟩ => ⟨env, next, ngen, trace, f cache, messages, infoState⟩
|
||||
|
||||
@[inline] def modifyInstLevelTypeCache (f : InstantiateLevelCache → InstantiateLevelCache) : CoreM Unit :=
|
||||
modifyCache fun ⟨c₁, c₂⟩ => ⟨f c₁, c₂⟩
|
||||
@@ -369,83 +355,13 @@ instance : MonadLog CoreM where
|
||||
if (← read).suppressElabErrors then
|
||||
-- discard elaboration errors, except for a few important and unlikely misleading ones, on
|
||||
-- parse error
|
||||
unless msg.data.hasTag (· matches `Elab.synthPlaceholder | `Tactic.unsolvedGoals | `trace) do
|
||||
unless msg.data.hasTag (· matches `Elab.synthPlaceholder | `Tactic.unsolvedGoals) do
|
||||
return
|
||||
|
||||
let ctx ← read
|
||||
let msg := { msg with data := MessageData.withNamingContext { currNamespace := ctx.currNamespace, openDecls := ctx.openDecls } msg.data };
|
||||
modify fun s => { s with messages := s.messages.add msg }
|
||||
|
||||
/--
|
||||
Includes a given task (such as from `wrapAsyncAsSnapshot`) in the overall snapshot tree for this
|
||||
command's elaboration, making its result available to reporting and the language server. The
|
||||
reporter will not know about this snapshot tree node until the main elaboration thread for this
|
||||
command has finished so this function is not useful for incremental reporting within a longer
|
||||
elaboration thread but only for tasks that outlive it such as background kernel checking or proof
|
||||
elaboration.
|
||||
-/
|
||||
def logSnapshotTask (task : Language.SnapshotTask Language.SnapshotTree) : CoreM Unit :=
|
||||
modify fun s => { s with snapshotTasks := s.snapshotTasks.push task }
|
||||
|
||||
/-- Wraps the given action for use in `EIO.asTask` etc., discarding its final monadic state. -/
|
||||
def wrapAsync (act : Unit → CoreM α) : CoreM (EIO Exception α) := do
|
||||
let st ← get
|
||||
let ctx ← read
|
||||
let heartbeats := (← IO.getNumHeartbeats) - ctx.initHeartbeats
|
||||
return withCurrHeartbeats (do
|
||||
-- include heartbeats since start of elaboration in new thread as well such that forking off
|
||||
-- an action doesn't suddenly allow it to succeed from a lower heartbeat count
|
||||
IO.addHeartbeats heartbeats.toUInt64
|
||||
act () : CoreM _)
|
||||
|>.run' ctx st
|
||||
|
||||
/-- Option for capturing output to stderr during elaboration. -/
|
||||
register_builtin_option stderrAsMessages : Bool := {
|
||||
defValue := true
|
||||
group := "server"
|
||||
descr := "(server) capture output to the Lean stderr channel (such as from `dbg_trace`) during elaboration of a command as a diagnostic message"
|
||||
}
|
||||
|
||||
open Language in
|
||||
/--
|
||||
Wraps the given action for use in `BaseIO.asTask` etc., discarding its final state except for
|
||||
`logSnapshotTask` tasks, which are reported as part of the returned tree.
|
||||
-/
|
||||
def wrapAsyncAsSnapshot (act : Unit → CoreM Unit) (desc : String := by exact decl_name%.toString) :
|
||||
CoreM (BaseIO SnapshotTree) := do
|
||||
let t ← wrapAsync fun _ => do
|
||||
IO.FS.withIsolatedStreams (isolateStderr := stderrAsMessages.get (← getOptions)) do
|
||||
let tid ← IO.getTID
|
||||
-- reset trace state and message log so as not to report them twice
|
||||
modify ({ · with messages := {}, traceState := { tid } })
|
||||
try
|
||||
withTraceNode `Elab.async (fun _ => return desc) do
|
||||
act ()
|
||||
catch e =>
|
||||
logError e.toMessageData
|
||||
finally
|
||||
addTraceAsMessages
|
||||
get
|
||||
let ctx ← readThe Core.Context
|
||||
return do
|
||||
match (← t.toBaseIO) with
|
||||
| .ok (output, st) =>
|
||||
let mut msgs := st.messages
|
||||
if !output.isEmpty then
|
||||
msgs := msgs.add {
|
||||
fileName := ctx.fileName
|
||||
severity := MessageSeverity.information
|
||||
pos := ctx.fileMap.toPosition <| ctx.ref.getPos?.getD 0
|
||||
data := output
|
||||
}
|
||||
return .mk {
|
||||
desc
|
||||
diagnostics := (← Language.Snapshot.Diagnostics.ofMessageLog msgs)
|
||||
traces := st.traceState
|
||||
} st.snapshotTasks
|
||||
-- interrupt or abort exception as `try catch` above should have caught any others
|
||||
| .error _ => default
|
||||
|
||||
end Core
|
||||
|
||||
export Core (CoreM mkFreshUserName checkSystem withCurrHeartbeats)
|
||||
|
||||
@@ -277,23 +277,4 @@ attribute [deprecated Std.HashMap.empty (since := "2024-08-08")] mkHashMap
|
||||
attribute [deprecated Std.HashMap.empty (since := "2024-08-08")] HashMap.empty
|
||||
attribute [deprecated Std.HashMap.ofList (since := "2024-08-08")] HashMap.ofList
|
||||
|
||||
attribute [deprecated Std.HashMap.insert (since := "2024-08-08")] HashMap.insert
|
||||
attribute [deprecated Std.HashMap.containsThenInsert (since := "2024-08-08")] HashMap.insert'
|
||||
attribute [deprecated Std.HashMap.insertIfNew (since := "2024-08-08")] HashMap.insertIfNew
|
||||
attribute [deprecated Std.HashMap.erase (since := "2024-08-08")] HashMap.erase
|
||||
attribute [deprecated "Use `m[k]?` instead." (since := "2024-08-08")] HashMap.findEntry?
|
||||
attribute [deprecated "Use `m[k]?` instead." (since := "2024-08-08")] HashMap.find?
|
||||
attribute [deprecated "Use `m[k]?.getD` instead." (since := "2024-08-08")] HashMap.findD
|
||||
attribute [deprecated "Use `m[k]!` instead." (since := "2024-08-08")] HashMap.find!
|
||||
attribute [deprecated Std.HashMap.contains (since := "2024-08-08")] HashMap.contains
|
||||
attribute [deprecated Std.HashMap.foldM (since := "2024-08-08")] HashMap.foldM
|
||||
attribute [deprecated Std.HashMap.fold (since := "2024-08-08")] HashMap.fold
|
||||
attribute [deprecated Std.HashMap.forM (since := "2024-08-08")] HashMap.forM
|
||||
attribute [deprecated Std.HashMap.size (since := "2024-08-08")] HashMap.size
|
||||
attribute [deprecated Std.HashMap.isEmpty (since := "2024-08-08")] HashMap.isEmpty
|
||||
attribute [deprecated Std.HashMap.toList (since := "2024-08-08")] HashMap.toList
|
||||
attribute [deprecated Std.HashMap.toArray (since := "2024-08-08")] HashMap.toArray
|
||||
attribute [deprecated "Deprecateed without a replacement." (since := "2024-08-08")] HashMap.numBuckets
|
||||
attribute [deprecated "Deprecateed without a replacement." (since := "2024-08-08")] HashMap.ofListWith
|
||||
|
||||
end Lean.HashMap
|
||||
|
||||
@@ -4,7 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Init.Data.Nat.Fold
|
||||
import Init.Data.Array.Basic
|
||||
import Init.NotationExtra
|
||||
import Init.Data.ToString.Macro
|
||||
@@ -372,7 +371,7 @@ instance : ToString Stats := ⟨Stats.toString⟩
|
||||
end PersistentArray
|
||||
|
||||
def mkPersistentArray {α : Type u} (n : Nat) (v : α) : PArray α :=
|
||||
n.fold (init := PersistentArray.empty) fun _ _ p => p.push v
|
||||
n.fold (init := PersistentArray.empty) fun _ p => p.push v
|
||||
|
||||
@[inline] def mkPArray {α : Type u} (n : Nat) (v : α) : PArray α :=
|
||||
mkPersistentArray n v
|
||||
|
||||
@@ -23,7 +23,6 @@ import Lean.Elab.Quotation
|
||||
import Lean.Elab.Syntax
|
||||
import Lean.Elab.Do
|
||||
import Lean.Elab.StructInst
|
||||
import Lean.Elab.MutualInductive
|
||||
import Lean.Elab.Inductive
|
||||
import Lean.Elab.Structure
|
||||
import Lean.Elab.Print
|
||||
|
||||
@@ -807,8 +807,8 @@ def getElabElimExprInfo (elimExpr : Expr) : MetaM ElabElimInfo := do
|
||||
These are the primary set of major parameters.
|
||||
-/
|
||||
let initMotiveFVars : CollectFVars.State := motiveArgs.foldl (init := {}) collectFVars
|
||||
let motiveFVars ← xs.size.foldRevM (init := initMotiveFVars) fun i _ s => do
|
||||
let x := xs[i]
|
||||
let motiveFVars ← xs.size.foldRevM (init := initMotiveFVars) fun i s => do
|
||||
let x := xs[i]!
|
||||
if s.fvarSet.contains x.fvarId! then
|
||||
return collectFVars s (← inferType x)
|
||||
else
|
||||
@@ -1150,33 +1150,48 @@ private def throwLValError (e : Expr) (eType : Expr) (msg : MessageData) : TermE
|
||||
throwError "{msg}{indentExpr e}\nhas type{indentExpr eType}"
|
||||
|
||||
/--
|
||||
`findMethod? S fName` tries the for each namespace `S'` in the resolution order for `S` to resolve the name `S'.fname`.
|
||||
If it resolves to `name`, returns `(S', name)`.
|
||||
`findMethod? S fName` tries the following for each namespace `S'` in the resolution order for `S`:
|
||||
- If `env` contains `S' ++ fName`, returns `(S', S' ++ fName)`
|
||||
- Otherwise if `env` contains private name `prv` for `S' ++ fName`, returns `(S', prv)`
|
||||
-/
|
||||
private partial def findMethod? (structName fieldName : Name) : MetaM (Option (Name × Name)) := do
|
||||
let env ← getEnv
|
||||
let find? structName' : MetaM (Option (Name × Name)) := do
|
||||
let fullName := structName' ++ fieldName
|
||||
-- We do not want to make use of the current namespace for resolution.
|
||||
let candidates := ResolveName.resolveGlobalName (← getEnv) Name.anonymous (← getOpenDecls) fullName
|
||||
|>.filter (fun (_, fieldList) => fieldList.isEmpty)
|
||||
|>.map Prod.fst
|
||||
match candidates with
|
||||
| [] => return none
|
||||
| [fullName'] => return some (structName', fullName')
|
||||
| _ => throwError "\
|
||||
invalid field notation '{fieldName}', the name '{fullName}' is ambiguous, possible interpretations: \
|
||||
{MessageData.joinSep (candidates.map (m!"'{.ofConstName ·}'")) ", "}"
|
||||
if env.contains fullName then
|
||||
return some (structName', fullName)
|
||||
let fullNamePrv := mkPrivateName env fullName
|
||||
if env.contains fullNamePrv then
|
||||
return some (structName', fullNamePrv)
|
||||
return none
|
||||
-- Optimization: the first element of the resolution order is `structName`,
|
||||
-- so we can skip computing the resolution order in the common case
|
||||
-- of the name resolving in the `structName` namespace.
|
||||
find? structName <||> do
|
||||
let resolutionOrder ← if isStructure env structName then getStructureResolutionOrder structName else pure #[structName]
|
||||
for ns in resolutionOrder[1:resolutionOrder.size] do
|
||||
if let some res ← find? ns then
|
||||
for h : i in [1:resolutionOrder.size] do
|
||||
if let some res ← find? resolutionOrder[i] then
|
||||
return res
|
||||
return none
|
||||
|
||||
/--
|
||||
Return `some (structName', fullName)` if `structName ++ fieldName` is an alias for `fullName`, and
|
||||
`fullName` is of the form `structName' ++ fieldName`.
|
||||
|
||||
TODO: if there is more than one applicable alias, it returns `none`. We should consider throwing an error or
|
||||
warning.
|
||||
-/
|
||||
private def findMethodAlias? (env : Environment) (structName fieldName : Name) : Option (Name × Name) :=
|
||||
let fullName := structName ++ fieldName
|
||||
-- We never skip `protected` aliases when resolving dot-notation.
|
||||
let aliasesCandidates := getAliases env fullName (skipProtected := false) |>.filterMap fun alias =>
|
||||
match alias.eraseSuffix? fieldName with
|
||||
| none => none
|
||||
| some structName' => some (structName', alias)
|
||||
match aliasesCandidates with
|
||||
| [r] => some r
|
||||
| _ => none
|
||||
|
||||
private def throwInvalidFieldNotation (e eType : Expr) : TermElabM α :=
|
||||
throwLValError e eType "invalid field notation, type is not of the form (C ...) where C is a constant"
|
||||
|
||||
@@ -1208,22 +1223,30 @@ private def resolveLValAux (e : Expr) (eType : Expr) (lval : LVal) : TermElabM L
|
||||
throwLValError e eType m!"invalid projection, structure has only {numFields} field(s)"
|
||||
| some structName, LVal.fieldName _ fieldName _ _ =>
|
||||
let env ← getEnv
|
||||
let searchEnv : Unit → TermElabM LValResolution := fun _ => do
|
||||
if let some (baseStructName, fullName) ← findMethod? structName (.mkSimple fieldName) then
|
||||
return LValResolution.const baseStructName structName fullName
|
||||
else if let some (structName', fullName) := findMethodAlias? env structName (.mkSimple fieldName) then
|
||||
return LValResolution.const structName' structName' fullName
|
||||
else
|
||||
throwLValError e eType
|
||||
m!"invalid field '{fieldName}', the environment does not contain '{Name.mkStr structName fieldName}'"
|
||||
-- search local context first, then environment
|
||||
let searchCtx : Unit → TermElabM LValResolution := fun _ => do
|
||||
let fullName := Name.mkStr structName fieldName
|
||||
for localDecl in (← getLCtx) do
|
||||
if localDecl.isAuxDecl then
|
||||
if let some localDeclFullName := (← read).auxDeclToFullName.find? localDecl.fvarId then
|
||||
if fullName == (privateToUserName? localDeclFullName).getD localDeclFullName then
|
||||
/- LVal notation is being used to make a "local" recursive call. -/
|
||||
return LValResolution.localRec structName fullName localDecl.toExpr
|
||||
searchEnv ()
|
||||
if isStructure env structName then
|
||||
if let some baseStructName := findField? env structName (Name.mkSimple fieldName) then
|
||||
return LValResolution.projFn baseStructName structName (Name.mkSimple fieldName)
|
||||
-- Search the local context first
|
||||
let fullName := Name.mkStr structName fieldName
|
||||
for localDecl in (← getLCtx) do
|
||||
if localDecl.isAuxDecl then
|
||||
if let some localDeclFullName := (← read).auxDeclToFullName.find? localDecl.fvarId then
|
||||
if fullName == (privateToUserName? localDeclFullName).getD localDeclFullName then
|
||||
/- LVal notation is being used to make a "local" recursive call. -/
|
||||
return LValResolution.localRec structName fullName localDecl.toExpr
|
||||
-- Then search the environment
|
||||
if let some (baseStructName, fullName) ← findMethod? structName (.mkSimple fieldName) then
|
||||
return LValResolution.const baseStructName structName fullName
|
||||
throwLValError e eType
|
||||
m!"invalid field '{fieldName}', the environment does not contain '{Name.mkStr structName fieldName}'"
|
||||
match findField? env structName (Name.mkSimple fieldName) with
|
||||
| some baseStructName => return LValResolution.projFn baseStructName structName (Name.mkSimple fieldName)
|
||||
| none => searchCtx ()
|
||||
else
|
||||
searchCtx ()
|
||||
| none, LVal.fieldName _ _ (some suffix) _ =>
|
||||
if e.isConst then
|
||||
throwUnknownConstant (e.constName! ++ suffix)
|
||||
@@ -1303,7 +1326,7 @@ Otherwise, if there isn't another parameter with the same name, we add `e` to `n
|
||||
|
||||
Remark: `fullName` is the name of the resolved "field" access function. It is used for reporting errors
|
||||
-/
|
||||
private partial def addLValArg (baseName : Name) (fullName : Name) (e : Expr) (args : Array Arg) (namedArgs : Array NamedArg) (f : Expr) (explicit : Bool) :
|
||||
private partial def addLValArg (baseName : Name) (fullName : Name) (e : Expr) (args : Array Arg) (namedArgs : Array NamedArg) (f : Expr) :
|
||||
MetaM (Array Arg × Array NamedArg) := do
|
||||
withoutModifyingState <| go f (← inferType f) 0 namedArgs (namedArgs.map (·.name)) true
|
||||
where
|
||||
@@ -1328,11 +1351,11 @@ where
|
||||
/- If there is named argument with name `xDecl.userName`, then it is accounted for and we can't make use of it. -/
|
||||
remainingNamedArgs := remainingNamedArgs.eraseIdx idx
|
||||
else
|
||||
if ← typeMatchesBaseName xDecl.type baseName then
|
||||
/- We found a type of the form (baseName ...), or we found the first explicit argument in useFirstExplicit mode.
|
||||
First, we check if the current argument is one that can be used positionally,
|
||||
if (← typeMatchesBaseName xDecl.type baseName) then
|
||||
/- We found a type of the form (baseName ...).
|
||||
First, we check if the current argument is an explicit one,
|
||||
and if the current explicit position "fits" at `args` (i.e., it must be ≤ arg.size) -/
|
||||
if h : argIdx ≤ args.size ∧ (explicit || bInfo.isExplicit) then
|
||||
if h : argIdx ≤ args.size ∧ bInfo.isExplicit then
|
||||
/- We can insert `e` as an explicit argument -/
|
||||
return (args.insertIdx argIdx (Arg.expr e), namedArgs)
|
||||
else
|
||||
@@ -1340,13 +1363,13 @@ where
|
||||
if there isn't an argument with the same name occurring before it. -/
|
||||
if !allowNamed || unusableNamedArgs.contains xDecl.userName then
|
||||
throwError "\
|
||||
invalid field notation, function '{.ofConstName fullName}' has argument with the expected type\
|
||||
invalid field notation, function '{fullName}' has argument with the expected type\
|
||||
{indentExpr xDecl.type}\n\
|
||||
but it cannot be used"
|
||||
else
|
||||
return (args, namedArgs.push { name := xDecl.userName, val := Arg.expr e })
|
||||
/- Advance `argIdx` and update seen named arguments. -/
|
||||
if explicit || bInfo.isExplicit then
|
||||
if bInfo.isExplicit then
|
||||
argIdx := argIdx + 1
|
||||
unusableNamedArgs := unusableNamedArgs.push xDecl.userName
|
||||
/- If named arguments aren't allowed, then it must still be possible to pass the value as an explicit argument.
|
||||
@@ -1357,7 +1380,7 @@ where
|
||||
if let some f' ← coerceToFunction? (mkAppN f xs) then
|
||||
return ← go f' (← inferType f') argIdx remainingNamedArgs unusableNamedArgs false
|
||||
throwError "\
|
||||
invalid field notation, function '{.ofConstName fullName}' does not have argument with type ({.ofConstName baseName} ...) that can be used, \
|
||||
invalid field notation, function '{fullName}' does not have argument with type ({baseName} ...) that can be used, \
|
||||
it must be explicit or implicit with a unique name"
|
||||
|
||||
/-- Adds the `TermInfo` for the field of a projection. See `Lean.Parser.Term.identProjKind`. -/
|
||||
@@ -1403,7 +1426,7 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp
|
||||
let projFn ← mkConst constName
|
||||
let projFn ← addProjTermInfo lval.getRef projFn
|
||||
if lvals.isEmpty then
|
||||
let (args, namedArgs) ← addLValArg baseStructName constName f args namedArgs projFn explicit
|
||||
let (args, namedArgs) ← addLValArg baseStructName constName f args namedArgs projFn
|
||||
elabAppArgs projFn namedArgs args expectedType? explicit ellipsis
|
||||
else
|
||||
let f ← elabAppArgs projFn #[] #[Arg.expr f] (expectedType? := none) (explicit := false) (ellipsis := false)
|
||||
@@ -1411,7 +1434,7 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp
|
||||
| LValResolution.localRec baseName fullName fvar =>
|
||||
let fvar ← addProjTermInfo lval.getRef fvar
|
||||
if lvals.isEmpty then
|
||||
let (args, namedArgs) ← addLValArg baseName fullName f args namedArgs fvar explicit
|
||||
let (args, namedArgs) ← addLValArg baseName fullName f args namedArgs fvar
|
||||
elabAppArgs fvar namedArgs args expectedType? explicit ellipsis
|
||||
else
|
||||
let f ← elabAppArgs fvar #[] #[Arg.expr f] (expectedType? := none) (explicit := false) (ellipsis := false)
|
||||
|
||||
@@ -84,7 +84,6 @@ structure State where
|
||||
ngen : NameGenerator := {}
|
||||
infoState : InfoState := {}
|
||||
traceState : TraceState := {}
|
||||
snapshotTasks : Array (Language.SnapshotTask Language.SnapshotTree) := #[]
|
||||
deriving Nonempty
|
||||
|
||||
structure Context where
|
||||
@@ -115,7 +114,8 @@ structure Context where
|
||||
-/
|
||||
suppressElabErrors : Bool := false
|
||||
|
||||
abbrev CommandElabM := ReaderT Context $ StateRefT State $ EIO Exception
|
||||
abbrev CommandElabCoreM (ε) := ReaderT Context $ StateRefT State $ EIO ε
|
||||
abbrev CommandElabM := CommandElabCoreM Exception
|
||||
abbrev CommandElab := Syntax → CommandElabM Unit
|
||||
structure Linter where
|
||||
run : Syntax → CommandElabM Unit
|
||||
@@ -198,6 +198,36 @@ instance : AddErrorMessageContext CommandElabM where
|
||||
let msg ← addMacroStack msg ctx.macroStack
|
||||
return (ref, msg)
|
||||
|
||||
def mkMessageAux (ctx : Context) (ref : Syntax) (msgData : MessageData) (severity : MessageSeverity) : Message :=
|
||||
let pos := ref.getPos?.getD ctx.cmdPos
|
||||
let endPos := ref.getTailPos?.getD pos
|
||||
mkMessageCore ctx.fileName ctx.fileMap msgData severity pos endPos
|
||||
|
||||
private def addTraceAsMessagesCore (ctx : Context) (log : MessageLog) (traceState : TraceState) : MessageLog := Id.run do
|
||||
if traceState.traces.isEmpty then return log
|
||||
let mut traces : Std.HashMap (String.Pos × String.Pos) (Array MessageData) := ∅
|
||||
for traceElem in traceState.traces do
|
||||
let ref := replaceRef traceElem.ref ctx.ref
|
||||
let pos := ref.getPos?.getD 0
|
||||
let endPos := ref.getTailPos?.getD pos
|
||||
traces := traces.insert (pos, endPos) <| traces.getD (pos, endPos) #[] |>.push traceElem.msg
|
||||
let mut log := log
|
||||
let traces' := traces.toArray.qsort fun ((a, _), _) ((b, _), _) => a < b
|
||||
for ((pos, endPos), traceMsg) in traces' do
|
||||
let data := .tagged `trace <| .joinSep traceMsg.toList "\n"
|
||||
log := log.add <| mkMessageCore ctx.fileName ctx.fileMap data .information pos endPos
|
||||
return log
|
||||
|
||||
private def addTraceAsMessages : CommandElabM Unit := do
|
||||
let ctx ← read
|
||||
-- do not add trace messages if `trace.profiler.output` is set as it would be redundant and
|
||||
-- pretty printing the trace messages is expensive
|
||||
if trace.profiler.output.get? (← getOptions) |>.isNone then
|
||||
modify fun s => { s with
|
||||
messages := addTraceAsMessagesCore ctx s.messages s.traceState
|
||||
traceState.traces := {}
|
||||
}
|
||||
|
||||
private def runCore (x : CoreM α) : CommandElabM α := do
|
||||
let s ← get
|
||||
let ctx ← read
|
||||
@@ -223,7 +253,6 @@ private def runCore (x : CoreM α) : CommandElabM α := do
|
||||
nextMacroScope := s.nextMacroScope
|
||||
infoState.enabled := s.infoState.enabled
|
||||
traceState := s.traceState
|
||||
snapshotTasks := s.snapshotTasks
|
||||
}
|
||||
let (ea, coreS) ← liftM x
|
||||
modify fun s => { s with
|
||||
@@ -232,7 +261,6 @@ private def runCore (x : CoreM α) : CommandElabM α := do
|
||||
ngen := coreS.ngen
|
||||
infoState.trees := s.infoState.trees.append coreS.infoState.trees
|
||||
traceState.traces := coreS.traceState.traces.map fun t => { t with ref := replaceRef t.ref ctx.ref }
|
||||
snapshotTasks := coreS.snapshotTasks
|
||||
messages := s.messages ++ coreS.messages
|
||||
}
|
||||
return ea
|
||||
@@ -240,6 +268,10 @@ private def runCore (x : CoreM α) : CommandElabM α := do
|
||||
def liftCoreM (x : CoreM α) : CommandElabM α := do
|
||||
MonadExcept.ofExcept (← runCore (observing x))
|
||||
|
||||
private def ioErrorToMessage (ctx : Context) (ref : Syntax) (err : IO.Error) : Message :=
|
||||
let ref := getBetterRef ref ctx.macroStack
|
||||
mkMessageAux ctx ref (toString err) MessageSeverity.error
|
||||
|
||||
@[inline] def liftIO {α} (x : IO α) : CommandElabM α := do
|
||||
let ctx ← read
|
||||
IO.toEIO (fun (ex : IO.Error) => Exception.error ctx.ref ex.toString) x
|
||||
@@ -262,8 +294,9 @@ instance : MonadLog CommandElabM where
|
||||
logMessage msg := do
|
||||
if (← read).suppressElabErrors then
|
||||
-- discard elaboration errors on parse error
|
||||
unless msg.data.hasTag (· matches `trace) do
|
||||
return
|
||||
-- NOTE: unlike `CoreM`'s `logMessage`, we do not currently have any command-level errors that
|
||||
-- we want to allowlist
|
||||
return
|
||||
let currNamespace ← getCurrNamespace
|
||||
let openDecls ← getOpenDecls
|
||||
let msg := { msg with data := MessageData.withNamingContext { currNamespace := currNamespace, openDecls := openDecls } msg.data }
|
||||
@@ -289,61 +322,6 @@ def runLinters (stx : Syntax) : CommandElabM Unit := do
|
||||
finally
|
||||
modify fun s => { savedState with messages := s.messages }
|
||||
|
||||
/--
|
||||
Catches and logs exceptions occurring in `x`. Unlike `try catch` in `CommandElabM`, this function
|
||||
catches interrupt exceptions as well and thus is intended for use at the top level of elaboration.
|
||||
Interrupt and abort exceptions are caught but not logged.
|
||||
-/
|
||||
@[inline] def withLoggingExceptions (x : CommandElabM Unit) : CommandElabM Unit := fun ctx ref =>
|
||||
EIO.catchExceptions (withLogging x ctx ref) (fun _ => pure ())
|
||||
|
||||
@[inherit_doc Core.wrapAsync]
|
||||
def wrapAsync (act : Unit → CommandElabM α) : CommandElabM (EIO Exception α) := do
|
||||
return act () |>.run (← read) |>.run' (← get)
|
||||
|
||||
open Language in
|
||||
@[inherit_doc Core.wrapAsyncAsSnapshot]
|
||||
-- `CoreM` and `CommandElabM` are too different to meaningfully share this code
|
||||
def wrapAsyncAsSnapshot (act : Unit → CommandElabM Unit)
|
||||
(desc : String := by exact decl_name%.toString) :
|
||||
CommandElabM (BaseIO SnapshotTree) := do
|
||||
let t ← wrapAsync fun _ => do
|
||||
IO.FS.withIsolatedStreams (isolateStderr := Core.stderrAsMessages.get (← getOptions)) do
|
||||
let tid ← IO.getTID
|
||||
-- reset trace state and message log so as not to report them twice
|
||||
modify ({ · with messages := {}, traceState := { tid } })
|
||||
try
|
||||
withTraceNode `Elab.async (fun _ => return desc) do
|
||||
act ()
|
||||
catch e =>
|
||||
logError e.toMessageData
|
||||
finally
|
||||
addTraceAsMessages
|
||||
get
|
||||
let ctx ← read
|
||||
return do
|
||||
match (← t.toBaseIO) with
|
||||
| .ok (output, st) =>
|
||||
let mut msgs := st.messages
|
||||
if !output.isEmpty then
|
||||
msgs := msgs.add {
|
||||
fileName := ctx.fileName
|
||||
severity := MessageSeverity.information
|
||||
pos := ctx.fileMap.toPosition <| ctx.ref.getPos?.getD 0
|
||||
data := output
|
||||
}
|
||||
return .mk {
|
||||
desc
|
||||
diagnostics := (← Language.Snapshot.Diagnostics.ofMessageLog msgs)
|
||||
traces := st.traceState
|
||||
} st.snapshotTasks
|
||||
-- interrupt or abort exception as `try catch` above should have caught any others
|
||||
| .error _ => default
|
||||
|
||||
@[inherit_doc Core.logSnapshotTask]
|
||||
def logSnapshotTask (task : Language.SnapshotTask Language.SnapshotTree) : CommandElabM Unit :=
|
||||
modify fun s => { s with snapshotTasks := s.snapshotTasks.push task }
|
||||
|
||||
protected def getCurrMacroScope : CommandElabM Nat := do pure (← read).currMacroScope
|
||||
protected def getMainModule : CommandElabM Name := do pure (← getEnv).mainModule
|
||||
|
||||
@@ -554,6 +532,12 @@ def elabCommandTopLevel (stx : Syntax) : CommandElabM Unit := withRef stx do pro
|
||||
let mut msgs := (← get).messages
|
||||
for tree in (← getInfoTrees) do
|
||||
trace[Elab.info] (← tree.format)
|
||||
if (← isTracingEnabledFor `Elab.snapshotTree) then
|
||||
if let some snap := (← read).snap? then
|
||||
-- We can assume that the root command snapshot is not involved in parallelism yet, so this
|
||||
-- should be true iff the command supports incrementality
|
||||
if (← IO.hasFinished snap.new.result) then
|
||||
liftCoreM <| Language.ToSnapshotTree.toSnapshotTree snap.new.result.get |>.trace
|
||||
modify fun st => { st with
|
||||
messages := initMsgs ++ msgs
|
||||
infoState := { st.infoState with trees := initInfoTrees ++ st.infoState.trees }
|
||||
@@ -684,6 +668,14 @@ def runTermElabM (elabFn : Array Expr → TermElabM α) : CommandElabM α := do
|
||||
Term.addAutoBoundImplicits' xs someType fun xs _ =>
|
||||
Term.withoutAutoBoundImplicit <| elabFn xs
|
||||
|
||||
/--
|
||||
Catches and logs exceptions occurring in `x`. Unlike `try catch` in `CommandElabM`, this function
|
||||
catches interrupt exceptions as well and thus is intended for use at the top level of elaboration.
|
||||
Interrupt and abort exceptions are caught but not logged.
|
||||
-/
|
||||
@[inline] def withLoggingExceptions (x : CommandElabM Unit) : CommandElabCoreM Empty Unit := fun ctx ref =>
|
||||
EIO.catchExceptions (withLogging x ctx ref) (fun _ => pure ())
|
||||
|
||||
private def liftAttrM {α} (x : AttrM α) : CommandElabM α := do
|
||||
liftCoreM x
|
||||
|
||||
|
||||
@@ -7,8 +7,9 @@ prelude
|
||||
import Lean.Util.CollectLevelParams
|
||||
import Lean.Elab.DeclUtil
|
||||
import Lean.Elab.DefView
|
||||
import Lean.Elab.Inductive
|
||||
import Lean.Elab.Structure
|
||||
import Lean.Elab.MutualDef
|
||||
import Lean.Elab.MutualInductive
|
||||
import Lean.Elab.DeclarationRange
|
||||
namespace Lean.Elab.Command
|
||||
|
||||
@@ -162,11 +163,15 @@ def elabDeclaration : CommandElab := fun stx => do
|
||||
if declKind == ``Lean.Parser.Command.«axiom» then
|
||||
let modifiers ← elabModifiers modifiers
|
||||
elabAxiom modifiers decl
|
||||
else if declKind == ``Lean.Parser.Command.«inductive»
|
||||
|| declKind == ``Lean.Parser.Command.classInductive
|
||||
|| declKind == ``Lean.Parser.Command.«structure» then
|
||||
else if declKind == ``Lean.Parser.Command.«inductive» then
|
||||
let modifiers ← elabModifiers modifiers
|
||||
elabInductive modifiers decl
|
||||
else if declKind == ``Lean.Parser.Command.classInductive then
|
||||
let modifiers ← elabModifiers modifiers
|
||||
elabClassInductive modifiers decl
|
||||
else if declKind == ``Lean.Parser.Command.«structure» then
|
||||
let modifiers ← elabModifiers modifiers
|
||||
elabStructure modifiers decl
|
||||
else
|
||||
throwError "unexpected declaration"
|
||||
|
||||
@@ -273,10 +278,10 @@ def elabMutual : CommandElab := fun stx => do
|
||||
-- only case implementing incrementality currently
|
||||
elabMutualDef stx[1].getArgs
|
||||
else withoutCommandIncrementality true do
|
||||
if ← isMutualInductive stx then
|
||||
if isMutualInductive stx then
|
||||
elabMutualInductive stx[1].getArgs
|
||||
else
|
||||
throwError "invalid mutual block: either all elements of the block must be inductive/structure declarations, or they must all be definitions/theorems/abbrevs"
|
||||
throwError "invalid mutual block: either all elements of the block must be inductive declarations, or they must all be definitions/theorems/abbrevs"
|
||||
|
||||
/- leading_parser "attribute " >> "[" >> sepBy1 (eraseAttr <|> Term.attrInstance) ", " >> "]" >> many1 ident -/
|
||||
@[builtin_command_elab «attribute»] def elabAttr : CommandElab := fun stx => do
|
||||
|
||||
@@ -1,36 +1,102 @@
|
||||
/-
|
||||
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura, Kyle Miller
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.Elab.MutualInductive
|
||||
import Lean.Util.ForEachExprWhere
|
||||
import Lean.Util.ReplaceLevel
|
||||
import Lean.Util.ReplaceExpr
|
||||
import Lean.Util.CollectLevelParams
|
||||
import Lean.Meta.Constructions
|
||||
import Lean.Meta.CollectFVars
|
||||
import Lean.Meta.SizeOf
|
||||
import Lean.Meta.Injective
|
||||
import Lean.Meta.IndPredBelow
|
||||
import Lean.Elab.Command
|
||||
import Lean.Elab.ComputedFields
|
||||
import Lean.Elab.DefView
|
||||
import Lean.Elab.DeclUtil
|
||||
import Lean.Elab.Deriving.Basic
|
||||
import Lean.Elab.DeclarationRange
|
||||
|
||||
namespace Lean.Elab.Command
|
||||
open Meta
|
||||
|
||||
/-
|
||||
```
|
||||
def Lean.Parser.Command.«inductive» :=
|
||||
leading_parser "inductive " >> declId >> optDeclSig >> optional ("where" <|> ":=") >> many ctor
|
||||
builtin_initialize
|
||||
registerTraceClass `Elab.inductive
|
||||
|
||||
def Lean.Parser.Command.classInductive :=
|
||||
leading_parser atomic (group ("class " >> "inductive ")) >> declId >> optDeclSig >> optional ("where" <|> ":=") >> many ctor >> optDeriving
|
||||
```
|
||||
register_builtin_option inductive.autoPromoteIndices : Bool := {
|
||||
defValue := true
|
||||
descr := "Promote indices to parameters in inductive types whenever possible."
|
||||
}
|
||||
|
||||
def checkValidInductiveModifier [Monad m] [MonadError m] (modifiers : Modifiers) : m Unit := do
|
||||
if modifiers.isNoncomputable then
|
||||
throwError "invalid use of 'noncomputable' in inductive declaration"
|
||||
if modifiers.isPartial then
|
||||
throwError "invalid use of 'partial' in inductive declaration"
|
||||
|
||||
def checkValidCtorModifier [Monad m] [MonadError m] (modifiers : Modifiers) : m Unit := do
|
||||
if modifiers.isNoncomputable then
|
||||
throwError "invalid use of 'noncomputable' in constructor declaration"
|
||||
if modifiers.isPartial then
|
||||
throwError "invalid use of 'partial' in constructor declaration"
|
||||
if modifiers.isUnsafe then
|
||||
throwError "invalid use of 'unsafe' in constructor declaration"
|
||||
if modifiers.attrs.size != 0 then
|
||||
throwError "invalid use of attributes in constructor declaration"
|
||||
|
||||
structure CtorView where
|
||||
ref : Syntax
|
||||
modifiers : Modifiers
|
||||
declName : Name
|
||||
binders : Syntax
|
||||
type? : Option Syntax
|
||||
deriving Inhabited
|
||||
|
||||
structure ComputedFieldView where
|
||||
ref : Syntax
|
||||
modifiers : Syntax
|
||||
fieldId : Name
|
||||
type : Syntax.Term
|
||||
matchAlts : TSyntax ``Parser.Term.matchAlts
|
||||
|
||||
structure InductiveView where
|
||||
ref : Syntax
|
||||
declId : Syntax
|
||||
modifiers : Modifiers
|
||||
shortDeclName : Name
|
||||
declName : Name
|
||||
levelNames : List Name
|
||||
binders : Syntax
|
||||
type? : Option Syntax
|
||||
ctors : Array CtorView
|
||||
derivingClasses : Array DerivingClassView
|
||||
computedFields : Array ComputedFieldView
|
||||
deriving Inhabited
|
||||
|
||||
structure ElabHeaderResult where
|
||||
view : InductiveView
|
||||
lctx : LocalContext
|
||||
localInsts : LocalInstances
|
||||
levelNames : List Name
|
||||
params : Array Expr
|
||||
type : Expr
|
||||
deriving Inhabited
|
||||
|
||||
/-
|
||||
leading_parser "inductive " >> declId >> optDeclSig >> optional ("where" <|> ":=") >> many ctor
|
||||
leading_parser atomic (group ("class " >> "inductive ")) >> declId >> optDeclSig >> optional ("where" <|> ":=") >> many ctor >> optDeriving
|
||||
-/
|
||||
private def inductiveSyntaxToView (modifiers : Modifiers) (decl : Syntax) : TermElabM InductiveView := do
|
||||
let isClass := decl.isOfKind ``Parser.Command.classInductive
|
||||
let modifiers := if isClass then modifiers.addAttr { name := `class } else modifiers
|
||||
checkValidInductiveModifier modifiers
|
||||
let (binders, type?) := expandOptDeclSig decl[2]
|
||||
let declId := decl[1]
|
||||
let ⟨name, declName, levelNames⟩ ← Term.expandDeclId (← getCurrNamespace) (← Term.getLevelNames) declId modifiers
|
||||
addDeclarationRangesForBuiltin declName modifiers.stx decl
|
||||
let ctors ← decl[4].getArgs.mapM fun ctor => withRef ctor do
|
||||
/-
|
||||
```
|
||||
def ctor := leading_parser optional docComment >> "\n| " >> declModifiers >> rawIdent >> optDeclSig
|
||||
```
|
||||
-/
|
||||
-- def ctor := leading_parser optional docComment >> "\n| " >> declModifiers >> rawIdent >> optDeclSig
|
||||
let mut ctorModifiers ← elabModifiers ⟨ctor[2]⟩
|
||||
if let some leadingDocComment := ctor[0].getOptional? then
|
||||
if ctorModifiers.docString?.isSome then
|
||||
@@ -47,7 +113,7 @@ private def inductiveSyntaxToView (modifiers : Modifiers) (decl : Syntax) : Term
|
||||
let (binders, type?) := expandOptDeclSig ctor[4]
|
||||
addDocString' ctorName ctorModifiers.docString?
|
||||
addDeclarationRangesFromSyntax ctorName ctor ctor[3]
|
||||
return { ref := ctor, declId := ctor[3], modifiers := ctorModifiers, declName := ctorName, binders := binders, type? := type? : CtorView }
|
||||
return { ref := ctor, modifiers := ctorModifiers, declName := ctorName, binders := binders, type? := type? : CtorView }
|
||||
let computedFields ← (decl[5].getOptional?.map (·[1].getArgs) |>.getD #[]).mapM fun cf => withRef cf do
|
||||
return { ref := cf, modifiers := cf[0], fieldId := cf[1].getId, type := ⟨cf[3]⟩, matchAlts := ⟨cf[4]⟩ }
|
||||
let classes ← getOptDerivingClasses decl[6]
|
||||
@@ -59,13 +125,137 @@ private def inductiveSyntaxToView (modifiers : Modifiers) (decl : Syntax) : Term
|
||||
ref := decl
|
||||
shortDeclName := name
|
||||
derivingClasses := classes
|
||||
allowIndices := true
|
||||
allowSortPolymorphism := true
|
||||
declId, modifiers, isClass, declName, levelNames
|
||||
declId, modifiers, declName, levelNames
|
||||
binders, type?, ctors
|
||||
computedFields
|
||||
}
|
||||
|
||||
private partial def elabHeaderAux (views : Array InductiveView) (i : Nat) (acc : Array ElabHeaderResult) : TermElabM (Array ElabHeaderResult) :=
|
||||
Term.withAutoBoundImplicitForbiddenPred (fun n => views.any (·.shortDeclName == n)) do
|
||||
if h : i < views.size then
|
||||
let view := views[i]
|
||||
let acc ← Term.withAutoBoundImplicit <| Term.elabBinders view.binders.getArgs fun params => do
|
||||
match view.type? with
|
||||
| none =>
|
||||
let u ← mkFreshLevelMVar
|
||||
let type := mkSort u
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
Term.addAutoBoundImplicits' params type fun params type => do
|
||||
let levelNames ← Term.getLevelNames
|
||||
return acc.push { lctx := (← getLCtx), localInsts := (← getLocalInstances), levelNames, params, type, view }
|
||||
| some typeStx =>
|
||||
let (type, _) ← Term.withAutoBoundImplicit do
|
||||
let type ← Term.elabType typeStx
|
||||
unless (← isTypeFormerType type) do
|
||||
throwErrorAt typeStx "invalid inductive type, resultant type is not a sort"
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let indices ← Term.addAutoBoundImplicits #[]
|
||||
return (← mkForallFVars indices type, indices.size)
|
||||
Term.addAutoBoundImplicits' params type fun params type => do
|
||||
trace[Elab.inductive] "header params: {params}, type: {type}"
|
||||
let levelNames ← Term.getLevelNames
|
||||
return acc.push { lctx := (← getLCtx), localInsts := (← getLocalInstances), levelNames, params, type, view }
|
||||
elabHeaderAux views (i+1) acc
|
||||
else
|
||||
return acc
|
||||
|
||||
private def checkNumParams (rs : Array ElabHeaderResult) : TermElabM Nat := do
|
||||
let numParams := rs[0]!.params.size
|
||||
for r in rs do
|
||||
unless r.params.size == numParams do
|
||||
throwErrorAt r.view.ref "invalid inductive type, number of parameters mismatch in mutually inductive datatypes"
|
||||
return numParams
|
||||
|
||||
private def checkUnsafe (rs : Array ElabHeaderResult) : TermElabM Unit := do
|
||||
let isUnsafe := rs[0]!.view.modifiers.isUnsafe
|
||||
for r in rs do
|
||||
unless r.view.modifiers.isUnsafe == isUnsafe do
|
||||
throwErrorAt r.view.ref "invalid inductive type, cannot mix unsafe and safe declarations in a mutually inductive datatypes"
|
||||
|
||||
private def InductiveView.checkLevelNames (views : Array InductiveView) : TermElabM Unit := do
|
||||
if h : views.size > 1 then
|
||||
let levelNames := views[0].levelNames
|
||||
for view in views do
|
||||
unless view.levelNames == levelNames do
|
||||
throwErrorAt view.ref "invalid inductive type, universe parameters mismatch in mutually inductive datatypes"
|
||||
|
||||
private def ElabHeaderResult.checkLevelNames (rs : Array ElabHeaderResult) : TermElabM Unit := do
|
||||
if h : rs.size > 1 then
|
||||
let levelNames := rs[0].levelNames
|
||||
for r in rs do
|
||||
unless r.levelNames == levelNames do
|
||||
throwErrorAt r.view.ref "invalid inductive type, universe parameters mismatch in mutually inductive datatypes"
|
||||
|
||||
private def mkTypeFor (r : ElabHeaderResult) : TermElabM Expr := do
|
||||
withLCtx r.lctx r.localInsts do
|
||||
mkForallFVars r.params r.type
|
||||
|
||||
private def throwUnexpectedInductiveType : TermElabM α :=
|
||||
throwError "unexpected inductive resulting type"
|
||||
|
||||
private def eqvFirstTypeResult (firstType type : Expr) : MetaM Bool :=
|
||||
forallTelescopeReducing firstType fun _ firstTypeResult => isDefEq firstTypeResult type
|
||||
|
||||
-- Auxiliary function for checking whether the types in mutually inductive declaration are compatible.
|
||||
private partial def checkParamsAndResultType (type firstType : Expr) (numParams : Nat) : TermElabM Unit := do
|
||||
try
|
||||
forallTelescopeCompatible type firstType numParams fun _ type firstType =>
|
||||
forallTelescopeReducing type fun _ type =>
|
||||
forallTelescopeReducing firstType fun _ firstType => do
|
||||
let type ← whnfD type
|
||||
match type with
|
||||
| .sort .. =>
|
||||
unless (← isDefEq firstType type) do
|
||||
throwError "resulting universe mismatch, given{indentExpr type}\nexpected type{indentExpr firstType}"
|
||||
| _ =>
|
||||
throwError "unexpected inductive resulting type"
|
||||
catch
|
||||
| Exception.error ref msg => throw (Exception.error ref m!"invalid mutually inductive types, {msg}")
|
||||
| ex => throw ex
|
||||
|
||||
-- Auxiliary function for checking whether the types in mutually inductive declaration are compatible.
|
||||
private def checkHeader (r : ElabHeaderResult) (numParams : Nat) (firstType? : Option Expr) : TermElabM Expr := do
|
||||
let type ← mkTypeFor r
|
||||
match firstType? with
|
||||
| none => return type
|
||||
| some firstType =>
|
||||
withRef r.view.ref <| checkParamsAndResultType type firstType numParams
|
||||
return firstType
|
||||
|
||||
-- Auxiliary function for checking whether the types in mutually inductive declaration are compatible.
|
||||
private partial def checkHeaders (rs : Array ElabHeaderResult) (numParams : Nat) (i : Nat) (firstType? : Option Expr) : TermElabM Unit := do
|
||||
if h : i < rs.size then
|
||||
let type ← checkHeader rs[i] numParams firstType?
|
||||
checkHeaders rs numParams (i+1) type
|
||||
|
||||
private def elabHeader (views : Array InductiveView) : TermElabM (Array ElabHeaderResult) := do
|
||||
let rs ← elabHeaderAux views 0 #[]
|
||||
if rs.size > 1 then
|
||||
checkUnsafe rs
|
||||
let numParams ← checkNumParams rs
|
||||
checkHeaders rs numParams 0 none
|
||||
return rs
|
||||
|
||||
/-- Create a local declaration for each inductive type in `rs`, and execute `x params indFVars`, where `params` are the inductive type parameters and
|
||||
`indFVars` are the new local declarations.
|
||||
We use the local context/instances and parameters of rs[0].
|
||||
Note that this method is executed after we executed `checkHeaders` and established all
|
||||
parameters are compatible. -/
|
||||
private partial def withInductiveLocalDecls (rs : Array ElabHeaderResult) (x : Array Expr → Array Expr → TermElabM α) : TermElabM α := do
|
||||
let namesAndTypes ← rs.mapM fun r => do
|
||||
let type ← mkTypeFor r
|
||||
pure (r.view.declName, r.view.shortDeclName, type)
|
||||
let r0 := rs[0]!
|
||||
let params := r0.params
|
||||
withLCtx r0.lctx r0.localInsts <| withRef r0.view.ref do
|
||||
let rec loop (i : Nat) (indFVars : Array Expr) := do
|
||||
if h : i < namesAndTypes.size then
|
||||
let (declName, shortDeclName, type) := namesAndTypes[i]
|
||||
Term.withAuxDecl shortDeclName type declName fun indFVar => loop (i+1) (indFVars.push indFVar)
|
||||
else
|
||||
x params indFVars
|
||||
loop 0 #[]
|
||||
|
||||
private def isInductiveFamily (numParams : Nat) (indFVar : Expr) : TermElabM Bool := do
|
||||
let indFVarType ← inferType indFVar
|
||||
forallTelescopeReducing indFVarType fun xs _ =>
|
||||
@@ -81,8 +271,8 @@ where
|
||||
| _ => acc
|
||||
|
||||
/--
|
||||
Replaces binder names in `type` with `newNames`.
|
||||
Remark: we only replace the names for binder containing macroscopes.
|
||||
Replace binder names in `type` with `newNames`.
|
||||
Remark: we only replace the names for binder containing macroscopes.
|
||||
-/
|
||||
private def replaceArrowBinderNames (type : Expr) (newNames : Array Name) : Expr :=
|
||||
go type 0
|
||||
@@ -100,20 +290,20 @@ where
|
||||
type
|
||||
|
||||
/--
|
||||
Reorders constructor arguments to improve the effectiveness of the `fixedIndicesToParams` method.
|
||||
Reorder constructor arguments to improve the effectiveness of the `fixedIndicesToParams` method.
|
||||
|
||||
The idea is quite simple. Given a constructor type of the form
|
||||
```
|
||||
(a₁ : A₁) → ... → (aₙ : Aₙ) → C b₁ ... bₘ
|
||||
```
|
||||
We try to find the longest prefix `b₁ ... bᵢ`, `i ≤ m` s.t.
|
||||
- each `bₖ` is in `{a₁, ..., aₙ}`
|
||||
- each `bₖ` only depends on variables in `{b₁, ..., bₖ₋₁}`
|
||||
The idea is quite simple. Given a constructor type of the form
|
||||
```
|
||||
(a₁ : A₁) → ... → (aₙ : Aₙ) → C b₁ ... bₘ
|
||||
```
|
||||
We try to find the longest prefix `b₁ ... bᵢ`, `i ≤ m` s.t.
|
||||
- each `bₖ` is in `{a₁, ..., aₙ}`
|
||||
- each `bₖ` only depends on variables in `{b₁, ..., bₖ₋₁}`
|
||||
|
||||
Then, it moves this prefix `b₁ ... bᵢ` to the front.
|
||||
Then, it moves this prefix `b₁ ... bᵢ` to the front.
|
||||
|
||||
Remark: We only reorder implicit arguments that have macroscopes. See issue #1156.
|
||||
The macroscope test is an approximation, we could have restricted ourselves to auto-implicit arguments.
|
||||
Remark: We only reorder implicit arguments that have macroscopes. See issue #1156.
|
||||
The macroscope test is an approximation, we could have restricted ourselves to auto-implicit arguments.
|
||||
-/
|
||||
private def reorderCtorArgs (ctorType : Expr) : MetaM Expr := do
|
||||
forallTelescopeReducing ctorType fun as type => do
|
||||
@@ -158,6 +348,16 @@ private def reorderCtorArgs (ctorType : Expr) : MetaM Expr := do
|
||||
let binderNames := getArrowBinderNames (← instantiateMVars (← inferType C))
|
||||
return replaceArrowBinderNames r binderNames[:bsPrefix.size]
|
||||
|
||||
/--
|
||||
Execute `k` with updated binder information for `xs`. Any `x` that is explicit becomes implicit.
|
||||
-/
|
||||
private def withExplicitToImplicit (xs : Array Expr) (k : TermElabM α) : TermElabM α := do
|
||||
let mut toImplicit := #[]
|
||||
for x in xs do
|
||||
if (← getFVarLocalDecl x).binderInfo.isExplicit then
|
||||
toImplicit := toImplicit.push (x.fvarId!, BinderInfo.implicit)
|
||||
withNewBinderInfos toImplicit k
|
||||
|
||||
/--
|
||||
Elaborate constructor types.
|
||||
|
||||
@@ -166,20 +366,17 @@ private def reorderCtorArgs (ctorType : Expr) : MetaM Expr := do
|
||||
- Positivity (it is a rare failure, and the kernel already checks for it).
|
||||
- Universe constraints (the kernel checks for it).
|
||||
-/
|
||||
private def elabCtors (indFVars : Array Expr) (params : Array Expr) (r : ElabHeaderResult) : TermElabM (List Constructor) :=
|
||||
withRef r.view.ref do
|
||||
withExplicitToImplicit params do
|
||||
let indFVar := r.indFVar
|
||||
private def elabCtors (indFVars : Array Expr) (indFVar : Expr) (params : Array Expr) (r : ElabHeaderResult) : TermElabM (List Constructor) := withRef r.view.ref do
|
||||
let indFamily ← isInductiveFamily params.size indFVar
|
||||
r.view.ctors.toList.mapM fun ctorView =>
|
||||
Term.withAutoBoundImplicit <| Term.elabBinders ctorView.binders.getArgs fun ctorParams =>
|
||||
withRef ctorView.ref do
|
||||
let elabCtorType : TermElabM Expr := do
|
||||
let rec elabCtorType (k : Expr → TermElabM Constructor) : TermElabM Constructor := do
|
||||
match ctorView.type? with
|
||||
| none =>
|
||||
if indFamily then
|
||||
throwError "constructor resulting type must be specified in inductive family declaration"
|
||||
return mkAppN indFVar params
|
||||
k <| mkAppN indFVar params
|
||||
| some ctorType =>
|
||||
let type ← Term.elabType ctorType
|
||||
trace[Elab.inductive] "elabType {ctorView.declName} : {type} "
|
||||
@@ -191,43 +388,43 @@ private def elabCtors (indFVars : Array Expr) (params : Array Expr) (r : ElabHea
|
||||
throwError "unexpected constructor resulting type{indentExpr resultingType}"
|
||||
unless (← isType resultingType) do
|
||||
throwError "unexpected constructor resulting type, type expected{indentExpr resultingType}"
|
||||
return type
|
||||
let type ← elabCtorType
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let ctorParams ← Term.addAutoBoundImplicits ctorParams
|
||||
let except (mvarId : MVarId) := ctorParams.any fun ctorParam => ctorParam.isMVar && ctorParam.mvarId! == mvarId
|
||||
/-
|
||||
We convert metavariables in the resulting type into extra parameters. Otherwise, we would not be able to elaborate
|
||||
declarations such as
|
||||
```
|
||||
inductive Palindrome : List α → Prop where
|
||||
| nil : Palindrome [] -- We would get an error here saying "failed to synthesize implicit argument" at `@List.nil ?m`
|
||||
| single : (a : α) → Palindrome [a]
|
||||
| sandwich : (a : α) → Palindrome as → Palindrome ([a] ++ as ++ [a])
|
||||
```
|
||||
We used to also collect unassigned metavariables on `ctorParams`, but it produced counterintuitive behavior.
|
||||
For example, the following declaration used to be accepted.
|
||||
```
|
||||
inductive Foo
|
||||
| bar (x)
|
||||
k type
|
||||
elabCtorType fun type => do
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let ctorParams ← Term.addAutoBoundImplicits ctorParams
|
||||
let except (mvarId : MVarId) := ctorParams.any fun ctorParam => ctorParam.isMVar && ctorParam.mvarId! == mvarId
|
||||
/-
|
||||
We convert metavariables in the resulting type info extra parameters. Otherwise, we would not be able to elaborate
|
||||
declarations such as
|
||||
```
|
||||
inductive Palindrome : List α → Prop where
|
||||
| nil : Palindrome [] -- We would get an error here saying "failed to synthesize implicit argument" at `@List.nil ?m`
|
||||
| single : (a : α) → Palindrome [a]
|
||||
| sandwich : (a : α) → Palindrome as → Palindrome ([a] ++ as ++ [a])
|
||||
```
|
||||
We used to also collect unassigned metavariables on `ctorParams`, but it produced counterintuitive behavior.
|
||||
For example, the following declaration used to be accepted.
|
||||
```
|
||||
inductive Foo
|
||||
| bar (x)
|
||||
|
||||
#check Foo.bar
|
||||
-- @Foo.bar : {x : Sort u_1} → x → Foo
|
||||
```
|
||||
which is also inconsistent with the behavior of auto implicits in definitions. For example, the following example was never accepted.
|
||||
```
|
||||
def bar (x) := 1
|
||||
```
|
||||
-/
|
||||
let extraCtorParams ← Term.collectUnassignedMVars (← instantiateMVars type) #[] except
|
||||
trace[Elab.inductive] "extraCtorParams: {extraCtorParams}"
|
||||
/- We must abstract `extraCtorParams` and `ctorParams` simultaneously to make
|
||||
sure we do not create auxiliary metavariables. -/
|
||||
let type ← mkForallFVars (extraCtorParams ++ ctorParams) type
|
||||
let type ← reorderCtorArgs type
|
||||
let type ← mkForallFVars params type
|
||||
trace[Elab.inductive] "{ctorView.declName} : {type}"
|
||||
return { name := ctorView.declName, type }
|
||||
#check Foo.bar
|
||||
-- @Foo.bar : {x : Sort u_1} → x → Foo
|
||||
```
|
||||
which is also inconsistent with the behavior of auto implicits in definitions. For example, the following example was never accepted.
|
||||
```
|
||||
def bar (x) := 1
|
||||
```
|
||||
-/
|
||||
let extraCtorParams ← Term.collectUnassignedMVars (← instantiateMVars type) #[] except
|
||||
trace[Elab.inductive] "extraCtorParams: {extraCtorParams}"
|
||||
/- We must abstract `extraCtorParams` and `ctorParams` simultaneously to make
|
||||
sure we do not create auxiliary metavariables. -/
|
||||
let type ← mkForallFVars (extraCtorParams ++ ctorParams) type
|
||||
let type ← reorderCtorArgs type
|
||||
let type ← mkForallFVars params type
|
||||
trace[Elab.inductive] "{ctorView.declName} : {type}"
|
||||
return { name := ctorView.declName, type }
|
||||
where
|
||||
checkParamOccs (ctorType : Expr) : MetaM Expr :=
|
||||
let visit (e : Expr) : MetaM TransformStep := do
|
||||
@@ -247,15 +444,555 @@ where
|
||||
return .continue
|
||||
transform ctorType (pre := visit)
|
||||
|
||||
@[builtin_inductive_elab Lean.Parser.Command.inductive, builtin_inductive_elab Lean.Parser.Command.classInductive]
|
||||
def elabInductiveCommand : InductiveElabDescr where
|
||||
mkInductiveView (modifiers : Modifiers) (stx : Syntax) := do
|
||||
let view ← inductiveSyntaxToView modifiers stx
|
||||
return {
|
||||
view
|
||||
elabCtors := fun rs r params => do
|
||||
let ctors ← elabCtors (rs.map (·.indFVar)) params r
|
||||
return { ctors }
|
||||
}
|
||||
private def getResultingUniverse : List InductiveType → TermElabM Level
|
||||
| [] => throwError "unexpected empty inductive declaration"
|
||||
| indType :: _ => forallTelescopeReducing indType.type fun _ r => do
|
||||
let r ← whnfD r
|
||||
match r with
|
||||
| Expr.sort u => return u
|
||||
| _ => throwError "unexpected inductive type resulting type{indentExpr r}"
|
||||
|
||||
/--
|
||||
Return `some ?m` if `u` is of the form `?m + k`.
|
||||
Return none if `u` does not contain universe metavariables.
|
||||
Throw exception otherwise. -/
|
||||
def shouldInferResultUniverse (u : Level) : TermElabM (Option LMVarId) := do
|
||||
let u ← instantiateLevelMVars u
|
||||
if u.hasMVar then
|
||||
match u.getLevelOffset with
|
||||
| Level.mvar mvarId => return some mvarId
|
||||
| _ =>
|
||||
throwError "cannot infer resulting universe level of inductive datatype, given level contains metavariables {mkSort u}, provide universe explicitly"
|
||||
else
|
||||
return none
|
||||
|
||||
/--
|
||||
Convert universe metavariables into new parameters. It skips `univToInfer?` (the inductive datatype resulting universe) because
|
||||
it should be inferred later using `inferResultingUniverse`.
|
||||
-/
|
||||
private def levelMVarToParam (indTypes : List InductiveType) (univToInfer? : Option LMVarId) : TermElabM (List InductiveType) :=
|
||||
indTypes.mapM fun indType => do
|
||||
let type ← levelMVarToParam' indType.type
|
||||
let ctors ← indType.ctors.mapM fun ctor => do
|
||||
let ctorType ← levelMVarToParam' ctor.type
|
||||
return { ctor with type := ctorType }
|
||||
return { indType with ctors, type }
|
||||
where
|
||||
levelMVarToParam' (type : Expr) : TermElabM Expr := do
|
||||
Term.levelMVarToParam type (except := fun mvarId => univToInfer? == some mvarId)
|
||||
|
||||
def mkResultUniverse (us : Array Level) (rOffset : Nat) (preferProp : Bool) : Level :=
|
||||
if us.isEmpty && rOffset == 0 then
|
||||
if preferProp then levelZero else levelOne
|
||||
else
|
||||
let r := Level.mkNaryMax us.toList
|
||||
if rOffset == 0 && !r.isZero && !r.isNeverZero then
|
||||
mkLevelMax r levelOne |>.normalize
|
||||
else
|
||||
r.normalize
|
||||
|
||||
/--
|
||||
Auxiliary function for `updateResultingUniverse`
|
||||
`accLevel u r rOffset` add `u` to state if it is not already there and
|
||||
it is different from the resulting universe level `r+rOffset`.
|
||||
|
||||
|
||||
If `u` is a `max`, then its components are recursively processed.
|
||||
If `u` is a `succ` and `rOffset > 0`, we process the `u`s child using `rOffset-1`.
|
||||
|
||||
This method is used to infer the resulting universe level of an inductive datatype.
|
||||
-/
|
||||
def accLevel (u : Level) (r : Level) (rOffset : Nat) : OptionT (StateT (Array Level) Id) Unit := do
|
||||
go u rOffset
|
||||
where
|
||||
go (u : Level) (rOffset : Nat) : OptionT (StateT (Array Level) Id) Unit := do
|
||||
match u, rOffset with
|
||||
| .max u v, rOffset => go u rOffset; go v rOffset
|
||||
| .imax u v, rOffset => go u rOffset; go v rOffset
|
||||
| .zero, _ => return ()
|
||||
| .succ u, rOffset+1 => go u rOffset
|
||||
| u, rOffset =>
|
||||
if rOffset == 0 && u == r then
|
||||
return ()
|
||||
else if r.occurs u then
|
||||
failure
|
||||
else if rOffset > 0 then
|
||||
failure
|
||||
else if (← get).contains u then
|
||||
return ()
|
||||
else
|
||||
modify fun us => us.push u
|
||||
|
||||
/--
|
||||
Auxiliary function for `updateResultingUniverse`
|
||||
`accLevelAtCtor ctor ctorParam r rOffset` add `u` (`ctorParam`'s universe) to state if it is not already there and
|
||||
it is different from the resulting universe level `r+rOffset`.
|
||||
|
||||
See `accLevel`.
|
||||
-/
|
||||
def accLevelAtCtor (ctor : Constructor) (ctorParam : Expr) (r : Level) (rOffset : Nat) : StateRefT (Array Level) TermElabM Unit := do
|
||||
let type ← inferType ctorParam
|
||||
let u ← instantiateLevelMVars (← getLevel type)
|
||||
match (← modifyGet fun s => accLevel u r rOffset |>.run |>.run s) with
|
||||
| some _ => pure ()
|
||||
| none =>
|
||||
let typeType ← inferType type
|
||||
let mut msg := m!"failed to compute resulting universe level of inductive datatype, constructor '{ctor.name}' has type{indentExpr ctor.type}\nparameter"
|
||||
let localDecl ← getFVarLocalDecl ctorParam
|
||||
unless localDecl.userName.hasMacroScopes do
|
||||
msg := msg ++ m!" '{ctorParam}'"
|
||||
msg := msg ++ m!" has type{indentD m!"{type} : {typeType}"}\ninductive type resulting type{indentExpr (mkSort (r.addOffset rOffset))}"
|
||||
if r.isMVar then
|
||||
msg := msg ++ "\nrecall that Lean only infers the resulting universe level automatically when there is a unique solution for the universe level constraints, consider explicitly providing the inductive type resulting universe level"
|
||||
throwError msg
|
||||
|
||||
/--
|
||||
Execute `k` using the `Syntax` reference associated with constructor `ctorName`.
|
||||
-/
|
||||
def withCtorRef [Monad m] [MonadRef m] (views : Array InductiveView) (ctorName : Name) (k : m α) : m α := do
|
||||
for view in views do
|
||||
for ctorView in view.ctors do
|
||||
if ctorView.declName == ctorName then
|
||||
return (← withRef ctorView.ref k)
|
||||
k
|
||||
|
||||
/-- Auxiliary function for `updateResultingUniverse` -/
|
||||
private partial def collectUniverses (views : Array InductiveView) (r : Level) (rOffset : Nat) (numParams : Nat) (indTypes : List InductiveType) : TermElabM (Array Level) := do
|
||||
let (_, us) ← go |>.run #[]
|
||||
return us
|
||||
where
|
||||
go : StateRefT (Array Level) TermElabM Unit :=
|
||||
indTypes.forM fun indType => indType.ctors.forM fun ctor =>
|
||||
withCtorRef views ctor.name do
|
||||
forallTelescopeReducing ctor.type fun ctorParams _ =>
|
||||
for ctorParam in ctorParams[numParams:] do
|
||||
accLevelAtCtor ctor ctorParam r rOffset
|
||||
|
||||
/--
|
||||
Decides whether the inductive type should be `Prop`-valued when the universe is not given
|
||||
and when the universe inference algorithm `collectUniverses` determines
|
||||
that the inductive type could naturally be `Prop`-valued.
|
||||
Recall: the natural universe level is the mimimum universe level for all the types of all the constructor parameters.
|
||||
|
||||
Heuristic:
|
||||
- We want `Prop` when each inductive type is a syntactic subsingleton.
|
||||
That's to say, when each inductive type has at most one constructor.
|
||||
Such types carry no data anyway.
|
||||
- Exception: if no inductive type has any constructors, these are likely stubbed-out declarations,
|
||||
so we prefer `Type` instead.
|
||||
- Exception: if each constructor has no parameters, then these are likely partially-written enumerations,
|
||||
so we prefer `Type` instead.
|
||||
-/
|
||||
private def isPropCandidate (numParams : Nat) (indTypes : List InductiveType) : MetaM Bool := do
|
||||
unless indTypes.foldl (fun n indType => max n indType.ctors.length) 0 == 1 do
|
||||
return false
|
||||
for indType in indTypes do
|
||||
for ctor in indType.ctors do
|
||||
let cparams ← forallTelescopeReducing ctor.type fun ctorParams _ => pure (ctorParams.size - numParams)
|
||||
unless cparams == 0 do
|
||||
return true
|
||||
return false
|
||||
|
||||
private def updateResultingUniverse (views : Array InductiveView) (numParams : Nat) (indTypes : List InductiveType) : TermElabM (List InductiveType) := do
|
||||
let r ← getResultingUniverse indTypes
|
||||
let rOffset : Nat := r.getOffset
|
||||
let r : Level := r.getLevelOffset
|
||||
unless r.isMVar do
|
||||
throwError "failed to compute resulting universe level of inductive datatype, provide universe explicitly: {r}"
|
||||
let us ← collectUniverses views r rOffset numParams indTypes
|
||||
trace[Elab.inductive] "updateResultingUniverse us: {us}, r: {r}, rOffset: {rOffset}"
|
||||
let rNew := mkResultUniverse us rOffset (← isPropCandidate numParams indTypes)
|
||||
assignLevelMVar r.mvarId! rNew
|
||||
indTypes.mapM fun indType => do
|
||||
let type ← instantiateMVars indType.type
|
||||
let ctors ← indType.ctors.mapM fun ctor => return { ctor with type := (← instantiateMVars ctor.type) }
|
||||
return { indType with type, ctors }
|
||||
|
||||
register_builtin_option bootstrap.inductiveCheckResultingUniverse : Bool := {
|
||||
defValue := true,
|
||||
group := "bootstrap",
|
||||
descr := "by default the `inductive`/`structure` commands report an error if the resulting universe is not zero, but may be zero for some universe parameters. Reason: unless this type is a subsingleton, it is hardly what the user wants since it can only eliminate into `Prop`. In the `Init` package, we define subsingletons, and we use this option to disable the check. This option may be deleted in the future after we improve the validator"
|
||||
}
|
||||
|
||||
def checkResultingUniverse (u : Level) : TermElabM Unit := do
|
||||
if bootstrap.inductiveCheckResultingUniverse.get (← getOptions) then
|
||||
let u ← instantiateLevelMVars u
|
||||
if !u.isZero && !u.isNeverZero then
|
||||
throwError "invalid universe polymorphic type, the resultant universe is not Prop (i.e., 0), but it may be Prop for some parameter values (solution: use 'u+1' or 'max 1 u'){indentD u}"
|
||||
|
||||
private def checkResultingUniverses (views : Array InductiveView) (numParams : Nat) (indTypes : List InductiveType) : TermElabM Unit := do
|
||||
let u := (← instantiateLevelMVars (← getResultingUniverse indTypes)).normalize
|
||||
checkResultingUniverse u
|
||||
unless u.isZero do
|
||||
indTypes.forM fun indType => indType.ctors.forM fun ctor =>
|
||||
forallTelescopeReducing ctor.type fun ctorArgs _ => do
|
||||
for ctorArg in ctorArgs[numParams:] do
|
||||
let type ← inferType ctorArg
|
||||
let v := (← instantiateLevelMVars (← getLevel type)).normalize
|
||||
let rec check (v' : Level) (u' : Level) : TermElabM Unit :=
|
||||
match v', u' with
|
||||
| .succ v', .succ u' => check v' u'
|
||||
| .mvar id, .param .. =>
|
||||
/- Special case:
|
||||
The constructor parameter `v` is at unverse level `?v+k` and
|
||||
the resulting inductive universe level is `u'+k`, where `u'` is a parameter (or zero).
|
||||
Thus, `?v := u'` is the only choice for satisfying the universe constraint `?v+k <= u'+k`.
|
||||
Note that, we still generate an error for cases where there is more than one of satisfying the constraint.
|
||||
Examples:
|
||||
-----------------------------------------------------------
|
||||
| ctor universe level | inductive datatype universe level |
|
||||
-----------------------------------------------------------
|
||||
| ?v | max u w |
|
||||
-----------------------------------------------------------
|
||||
| ?v | u + 1 |
|
||||
-----------------------------------------------------------
|
||||
-/
|
||||
assignLevelMVar id u'
|
||||
| .mvar id, .zero => assignLevelMVar id u' -- TODO: merge with previous case
|
||||
| _, _ =>
|
||||
unless u.geq v do
|
||||
let mut msg := m!"invalid universe level in constructor '{ctor.name}', parameter"
|
||||
let localDecl ← getFVarLocalDecl ctorArg
|
||||
unless localDecl.userName.hasMacroScopes do
|
||||
msg := msg ++ m!" '{ctorArg}'"
|
||||
msg := msg ++ m!" has type{indentExpr type}"
|
||||
msg := msg ++ m!"\nat universe level{indentD v}"
|
||||
msg := msg ++ m!"\nit must be smaller than or equal to the inductive datatype universe level{indentD u}"
|
||||
withCtorRef views ctor.name <| throwError msg
|
||||
check v u
|
||||
|
||||
private def collectUsed (indTypes : List InductiveType) : StateRefT CollectFVars.State MetaM Unit := do
|
||||
indTypes.forM fun indType => do
|
||||
indType.type.collectFVars
|
||||
indType.ctors.forM fun ctor =>
|
||||
ctor.type.collectFVars
|
||||
|
||||
private def removeUnused (vars : Array Expr) (indTypes : List InductiveType) : TermElabM (LocalContext × LocalInstances × Array Expr) := do
|
||||
let (_, used) ← (collectUsed indTypes).run {}
|
||||
Meta.removeUnused vars used
|
||||
|
||||
private def withUsed {α} (vars : Array Expr) (indTypes : List InductiveType) (k : Array Expr → TermElabM α) : TermElabM α := do
|
||||
let (lctx, localInsts, vars) ← removeUnused vars indTypes
|
||||
withLCtx lctx localInsts <| k vars
|
||||
|
||||
private def updateParams (vars : Array Expr) (indTypes : List InductiveType) : TermElabM (List InductiveType) :=
|
||||
indTypes.mapM fun indType => do
|
||||
let type ← mkForallFVars vars indType.type
|
||||
let ctors ← indType.ctors.mapM fun ctor => do
|
||||
let ctorType ← withExplicitToImplicit vars (mkForallFVars vars ctor.type)
|
||||
return { ctor with type := ctorType }
|
||||
return { indType with type, ctors }
|
||||
|
||||
private def collectLevelParamsInInductive (indTypes : List InductiveType) : Array Name := Id.run do
|
||||
let mut usedParams : CollectLevelParams.State := {}
|
||||
for indType in indTypes do
|
||||
usedParams := collectLevelParams usedParams indType.type
|
||||
for ctor in indType.ctors do
|
||||
usedParams := collectLevelParams usedParams ctor.type
|
||||
return usedParams.params
|
||||
|
||||
private def mkIndFVar2Const (views : Array InductiveView) (indFVars : Array Expr) (levelNames : List Name) : ExprMap Expr := Id.run do
|
||||
let levelParams := levelNames.map mkLevelParam;
|
||||
let mut m : ExprMap Expr := {}
|
||||
for h : i in [:views.size] do
|
||||
let view := views[i]
|
||||
let indFVar := indFVars[i]!
|
||||
m := m.insert indFVar (mkConst view.declName levelParams)
|
||||
return m
|
||||
|
||||
/-- Remark: `numVars <= numParams`. `numVars` is the number of context `variables` used in the inductive declaration,
|
||||
and `numParams` is `numVars` + number of explicit parameters provided in the declaration. -/
|
||||
private def replaceIndFVarsWithConsts (views : Array InductiveView) (indFVars : Array Expr) (levelNames : List Name)
|
||||
(numVars : Nat) (numParams : Nat) (indTypes : List InductiveType) : TermElabM (List InductiveType) :=
|
||||
let indFVar2Const := mkIndFVar2Const views indFVars levelNames
|
||||
indTypes.mapM fun indType => do
|
||||
let ctors ← indType.ctors.mapM fun ctor => do
|
||||
let type ← forallBoundedTelescope ctor.type numParams fun params type => do
|
||||
let type := type.replace fun e =>
|
||||
if !e.isFVar then
|
||||
none
|
||||
else match indFVar2Const[e]? with
|
||||
| none => none
|
||||
| some c => mkAppN c (params.extract 0 numVars)
|
||||
instantiateMVars (← mkForallFVars params type)
|
||||
return { ctor with type }
|
||||
return { indType with ctors }
|
||||
|
||||
private def mkAuxConstructions (views : Array InductiveView) : TermElabM Unit := do
|
||||
let env ← getEnv
|
||||
let hasEq := env.contains ``Eq
|
||||
let hasHEq := env.contains ``HEq
|
||||
let hasUnit := env.contains ``PUnit
|
||||
let hasProd := env.contains ``Prod
|
||||
for view in views do
|
||||
let n := view.declName
|
||||
mkRecOn n
|
||||
if hasUnit then mkCasesOn n
|
||||
if hasUnit && hasEq && hasHEq then mkNoConfusion n
|
||||
if hasUnit && hasProd then mkBelow n
|
||||
if hasUnit && hasProd then mkIBelow n
|
||||
for view in views do
|
||||
let n := view.declName;
|
||||
if hasUnit && hasProd then mkBRecOn n
|
||||
if hasUnit && hasProd then mkBInductionOn n
|
||||
|
||||
private def getArity (indType : InductiveType) : MetaM Nat :=
|
||||
forallTelescopeReducing indType.type fun xs _ => return xs.size
|
||||
|
||||
private def resetMaskAt (mask : Array Bool) (i : Nat) : Array Bool :=
|
||||
mask.setD i false
|
||||
|
||||
/--
|
||||
Compute a bit-mask that for `indType`. The size of the resulting array `result` is the arity of `indType`.
|
||||
The first `numParams` elements are `false` since they are parameters.
|
||||
For `i ∈ [numParams, arity)`, we have that `result[i]` if this index of the inductive family is fixed.
|
||||
-/
|
||||
private def computeFixedIndexBitMask (numParams : Nat) (indType : InductiveType) (indFVars : Array Expr) : MetaM (Array Bool) := do
|
||||
let arity ← getArity indType
|
||||
if arity ≤ numParams then
|
||||
return mkArray arity false
|
||||
else
|
||||
let maskRef ← IO.mkRef (mkArray numParams false ++ mkArray (arity - numParams) true)
|
||||
let rec go (ctors : List Constructor) : MetaM (Array Bool) := do
|
||||
match ctors with
|
||||
| [] => maskRef.get
|
||||
| ctor :: ctors =>
|
||||
forallTelescopeReducing ctor.type fun xs type => do
|
||||
let typeArgs := type.getAppArgs
|
||||
for i in [numParams:arity] do
|
||||
unless i < xs.size && xs[i]! == typeArgs[i]! do -- Remark: if we want to allow arguments to be rearranged, this test should be xs.contains typeArgs[i]
|
||||
maskRef.modify fun mask => mask.set! i false
|
||||
for x in xs[numParams:] do
|
||||
let xType ← inferType x
|
||||
let cond (e : Expr) := indFVars.any (fun indFVar => e.getAppFn == indFVar)
|
||||
xType.forEachWhere (stopWhenVisited := true) cond fun e => do
|
||||
let eArgs := e.getAppArgs
|
||||
for i in [numParams:eArgs.size] do
|
||||
if i >= typeArgs.size then
|
||||
maskRef.modify (resetMaskAt · i)
|
||||
else
|
||||
unless eArgs[i]! == typeArgs[i]! do
|
||||
maskRef.modify (resetMaskAt · i)
|
||||
/-If an index is missing in the arguments of the inductive type, then it must be non-fixed.
|
||||
Consider the following example:
|
||||
```lean
|
||||
inductive All {I : Type u} (P : I → Type v) : List I → Type (max u v) where
|
||||
| cons : P x → All P xs → All P (x :: xs)
|
||||
|
||||
inductive Iμ {I : Type u} : I → Type (max u v) where
|
||||
| mk : (i : I) → All Iμ [] → Iμ i
|
||||
```
|
||||
because `i` doesn't appear in `All Iμ []`, the index shouldn't be fixed.
|
||||
-/
|
||||
for i in [eArgs.size:arity] do
|
||||
maskRef.modify (resetMaskAt · i)
|
||||
go ctors
|
||||
go indType.ctors
|
||||
|
||||
/-- Return true iff `arrowType` is an arrow and its domain is defeq to `type` -/
|
||||
private def isDomainDefEq (arrowType : Expr) (type : Expr) : MetaM Bool := do
|
||||
if !arrowType.isForall then
|
||||
return false
|
||||
else
|
||||
/-
|
||||
We used to use `withNewMCtxDepth` to make sure we do not assign universe metavariables,
|
||||
but it was not satisfactory. For example, in declarations such as
|
||||
```
|
||||
inductive Eq : α → α → Prop where
|
||||
| refl (a : α) : Eq a a
|
||||
```
|
||||
We want the first two indices to be promoted to parameters, and this will only
|
||||
happen if we can assign universe metavariables.
|
||||
-/
|
||||
isDefEq arrowType.bindingDomain! type
|
||||
|
||||
/--
|
||||
Convert fixed indices to parameters.
|
||||
-/
|
||||
private partial def fixedIndicesToParams (numParams : Nat) (indTypes : Array InductiveType) (indFVars : Array Expr) : MetaM Nat := do
|
||||
if !inductive.autoPromoteIndices.get (← getOptions) then
|
||||
return numParams
|
||||
let masks ← indTypes.mapM (computeFixedIndexBitMask numParams · indFVars)
|
||||
trace[Elab.inductive] "masks: {masks}"
|
||||
if masks.all fun mask => !mask.contains true then
|
||||
return numParams
|
||||
-- We process just a non-fixed prefix of the indices for now. Reason: we don't want to change the order.
|
||||
-- TODO: extend it in the future. For example, it should be reasonable to change
|
||||
-- the order of indices generated by the auto implicit feature.
|
||||
let mask := masks[0]!
|
||||
forallBoundedTelescope indTypes[0]!.type numParams fun params type => do
|
||||
let otherTypes ← indTypes[1:].toArray.mapM fun indType => do whnfD (← instantiateForall indType.type params)
|
||||
let ctorTypes ← indTypes.toList.mapM fun indType => indType.ctors.mapM fun ctor => do whnfD (← instantiateForall ctor.type params)
|
||||
let typesToCheck := otherTypes.toList ++ ctorTypes.flatten
|
||||
let rec go (i : Nat) (type : Expr) (typesToCheck : List Expr) : MetaM Nat := do
|
||||
if i < mask.size then
|
||||
if !masks.all fun mask => i < mask.size && mask[i]! then
|
||||
return i
|
||||
if !type.isForall then
|
||||
return i
|
||||
let paramType := type.bindingDomain!
|
||||
if !(← typesToCheck.allM fun type => isDomainDefEq type paramType) then
|
||||
trace[Elab.inductive] "domain not def eq: {i}, {type} =?= {paramType}"
|
||||
return i
|
||||
withLocalDeclD `a paramType fun paramNew => do
|
||||
let typesToCheck ← typesToCheck.mapM fun type => whnfD (type.bindingBody!.instantiate1 paramNew)
|
||||
go (i+1) (type.bindingBody!.instantiate1 paramNew) typesToCheck
|
||||
else
|
||||
return i
|
||||
go numParams type typesToCheck
|
||||
|
||||
private def mkInductiveDecl (vars : Array Expr) (views : Array InductiveView) : TermElabM Unit := Term.withoutSavingRecAppSyntax do
|
||||
let view0 := views[0]!
|
||||
let scopeLevelNames ← Term.getLevelNames
|
||||
InductiveView.checkLevelNames views
|
||||
let allUserLevelNames := view0.levelNames
|
||||
let isUnsafe := view0.modifiers.isUnsafe
|
||||
withRef view0.ref <| Term.withLevelNames allUserLevelNames do
|
||||
let rs ← elabHeader views
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
ElabHeaderResult.checkLevelNames rs
|
||||
let allUserLevelNames := rs[0]!.levelNames
|
||||
trace[Elab.inductive] "level names: {allUserLevelNames}"
|
||||
withInductiveLocalDecls rs fun params indFVars => do
|
||||
trace[Elab.inductive] "indFVars: {indFVars}"
|
||||
let mut indTypesArray := #[]
|
||||
for h : i in [:views.size] do
|
||||
let indFVar := indFVars[i]!
|
||||
Term.addLocalVarInfo views[i].declId indFVar
|
||||
let r := rs[i]!
|
||||
/- At this point, because of `withInductiveLocalDecls`, the only fvars that are in context are the ones related to the first inductive type.
|
||||
Because of this, we need to replace the fvars present in each inductive type's header of the mutual block with those of the first inductive.
|
||||
However, some mvars may still be uninstantiated there, and might hide some of the old fvars.
|
||||
As such we first need to synthesize all possible mvars at this stage, instantiate them in the header types and only
|
||||
then replace the parameters' fvars in the header type.
|
||||
|
||||
See issue #3242 (`https://github.com/leanprover/lean4/issues/3242`)
|
||||
-/
|
||||
let type ← instantiateMVars r.type
|
||||
let type := type.replaceFVars r.params params
|
||||
let type ← mkForallFVars params type
|
||||
let ctors ← withExplicitToImplicit params (elabCtors indFVars indFVar params r)
|
||||
indTypesArray := indTypesArray.push { name := r.view.declName, type, ctors }
|
||||
let numExplicitParams ← fixedIndicesToParams params.size indTypesArray indFVars
|
||||
trace[Elab.inductive] "numExplicitParams: {numExplicitParams}"
|
||||
let indTypes := indTypesArray.toList
|
||||
let u ← getResultingUniverse indTypes
|
||||
let univToInfer? ← shouldInferResultUniverse u
|
||||
withUsed vars indTypes fun vars => do
|
||||
let numVars := vars.size
|
||||
let numParams := numVars + numExplicitParams
|
||||
let indTypes ← updateParams vars indTypes
|
||||
let indTypes ← if let some univToInfer := univToInfer? then
|
||||
updateResultingUniverse views numParams (← levelMVarToParam indTypes univToInfer)
|
||||
else
|
||||
checkResultingUniverses views numParams indTypes
|
||||
levelMVarToParam indTypes none
|
||||
let usedLevelNames := collectLevelParamsInInductive indTypes
|
||||
match sortDeclLevelParams scopeLevelNames allUserLevelNames usedLevelNames with
|
||||
| .error msg => throwError msg
|
||||
| .ok levelParams => do
|
||||
let indTypes ← replaceIndFVarsWithConsts views indFVars levelParams numVars numParams indTypes
|
||||
let decl := Declaration.inductDecl levelParams numParams indTypes isUnsafe
|
||||
Term.ensureNoUnassignedMVars decl
|
||||
addDecl decl
|
||||
mkAuxConstructions views
|
||||
withSaveInfoContext do -- save new env
|
||||
for view in views do
|
||||
Term.addTermInfo' view.ref[1] (← mkConstWithLevelParams view.declName) (isBinder := true)
|
||||
for ctor in view.ctors do
|
||||
Term.addTermInfo' ctor.ref[3] (← mkConstWithLevelParams ctor.declName) (isBinder := true)
|
||||
-- We need to invoke `applyAttributes` because `class` is implemented as an attribute.
|
||||
Term.applyAttributesAt view.declName view.modifiers.attrs .afterTypeChecking
|
||||
|
||||
private def applyDerivingHandlers (views : Array InductiveView) : CommandElabM Unit := do
|
||||
let mut processed : NameSet := {}
|
||||
for view in views do
|
||||
for classView in view.derivingClasses do
|
||||
let className := classView.className
|
||||
unless processed.contains className do
|
||||
processed := processed.insert className
|
||||
let mut declNames := #[]
|
||||
for view in views do
|
||||
if view.derivingClasses.any fun classView => classView.className == className then
|
||||
declNames := declNames.push view.declName
|
||||
classView.applyHandlers declNames
|
||||
|
||||
private def applyComputedFields (indViews : Array InductiveView) : CommandElabM Unit := do
|
||||
if indViews.all (·.computedFields.isEmpty) then return
|
||||
|
||||
let mut computedFields := #[]
|
||||
let mut computedFieldDefs := #[]
|
||||
for indView@{declName, ..} in indViews do
|
||||
for {ref, fieldId, type, matchAlts, modifiers, ..} in indView.computedFields do
|
||||
computedFieldDefs := computedFieldDefs.push <| ← do
|
||||
let modifiers ← match modifiers with
|
||||
| `(Lean.Parser.Command.declModifiersT| $[$doc:docComment]? $[$attrs:attributes]? $[$vis]? $[noncomputable]?) =>
|
||||
`(Lean.Parser.Command.declModifiersT| $[$doc]? $[$attrs]? $[$vis]? noncomputable)
|
||||
| _ => do
|
||||
withRef modifiers do logError "unsupported modifiers for computed field"
|
||||
`(Parser.Command.declModifiersT| noncomputable)
|
||||
`($(⟨modifiers⟩):declModifiers
|
||||
def%$ref $(mkIdent <| `_root_ ++ declName ++ fieldId):ident : $type $matchAlts:matchAlts)
|
||||
let computedFieldNames := indView.computedFields.map fun {fieldId, ..} => declName ++ fieldId
|
||||
computedFields := computedFields.push (declName, computedFieldNames)
|
||||
withScope (fun scope => { scope with
|
||||
opts := scope.opts
|
||||
|>.setBool `bootstrap.genMatcherCode false
|
||||
|>.setBool `elaboratingComputedFields true}) <|
|
||||
elabCommand <| ← `(mutual $computedFieldDefs* end)
|
||||
|
||||
liftTermElabM do Term.withDeclName indViews[0]!.declName do
|
||||
ComputedFields.setComputedFields computedFields
|
||||
|
||||
def elabInductiveViews (vars : Array Expr) (views : Array InductiveView) : TermElabM Unit := do
|
||||
let view0 := views[0]!
|
||||
let ref := view0.ref
|
||||
Term.withDeclName view0.declName do withRef ref do
|
||||
mkInductiveDecl vars views
|
||||
mkSizeOfInstances view0.declName
|
||||
Lean.Meta.IndPredBelow.mkBelow view0.declName
|
||||
for view in views do
|
||||
mkInjectiveTheorems view.declName
|
||||
|
||||
def elabInductiveViewsPostprocessing (views : Array InductiveView) : CommandElabM Unit := do
|
||||
let view0 := views[0]!
|
||||
let ref := view0.ref
|
||||
applyComputedFields views -- NOTE: any generated code before this line is invalid
|
||||
applyDerivingHandlers views
|
||||
runTermElabM fun _ => Term.withDeclName view0.declName do withRef ref do
|
||||
for view in views do
|
||||
Term.applyAttributesAt view.declName view.modifiers.attrs .afterCompilation
|
||||
|
||||
def elabInductives (inductives : Array (Modifiers × Syntax)) : CommandElabM Unit := do
|
||||
let vs ← runTermElabM fun vars => do
|
||||
let vs ← inductives.mapM fun (modifiers, stx) => inductiveSyntaxToView modifiers stx
|
||||
elabInductiveViews vars vs
|
||||
pure vs
|
||||
elabInductiveViewsPostprocessing vs
|
||||
|
||||
def elabInductive (modifiers : Modifiers) (stx : Syntax) : CommandElabM Unit := do
|
||||
elabInductives #[(modifiers, stx)]
|
||||
|
||||
def elabClassInductive (modifiers : Modifiers) (stx : Syntax) : CommandElabM Unit := do
|
||||
let modifiers := modifiers.addAttr { name := `class }
|
||||
elabInductive modifiers stx
|
||||
|
||||
/--
|
||||
Returns true if all elements of the `mutual` block (`Lean.Parser.Command.mutual`) are inductive declarations.
|
||||
-/
|
||||
def isMutualInductive (stx : Syntax) : Bool :=
|
||||
stx[1].getArgs.all fun elem =>
|
||||
let decl := elem[1]
|
||||
let declKind := decl.getKind
|
||||
declKind == `Lean.Parser.Command.inductive
|
||||
|
||||
/--
|
||||
Elaborates a `mutual` block satisfying `Lean.Elab.Command.isMutualInductive`.
|
||||
-/
|
||||
def elabMutualInductive (elems : Array Syntax) : CommandElabM Unit := do
|
||||
let inductives ← elems.mapM fun stx => do
|
||||
let modifiers ← elabModifiers ⟨stx[0]⟩
|
||||
pure (modifiers, stx[1])
|
||||
elabInductives inductives
|
||||
|
||||
end Lean.Elab.Command
|
||||
|
||||
@@ -840,8 +840,8 @@ private def mkLetRecClosures (sectionVars : Array Expr) (mainFVarIds : Array FVa
|
||||
abbrev Replacement := FVarIdMap Expr
|
||||
|
||||
def insertReplacementForMainFns (r : Replacement) (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainFVars : Array Expr) : Replacement :=
|
||||
mainFVars.size.fold (init := r) fun i _ r =>
|
||||
r.insert mainFVars[i].fvarId! (mkAppN (Lean.mkConst mainHeaders[i]!.declName) sectionVars)
|
||||
mainFVars.size.fold (init := r) fun i r =>
|
||||
r.insert mainFVars[i]!.fvarId! (mkAppN (Lean.mkConst mainHeaders[i]!.declName) sectionVars)
|
||||
|
||||
|
||||
def insertReplacementForLetRecs (r : Replacement) (letRecClosures : List LetRecClosure) : Replacement :=
|
||||
@@ -871,8 +871,8 @@ def Replacement.apply (r : Replacement) (e : Expr) : Expr :=
|
||||
|
||||
def pushMain (preDefs : Array PreDefinition) (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainVals : Array Expr)
|
||||
: TermElabM (Array PreDefinition) :=
|
||||
mainHeaders.size.foldM (init := preDefs) fun i _ preDefs => do
|
||||
let header := mainHeaders[i]
|
||||
mainHeaders.size.foldM (init := preDefs) fun i preDefs => do
|
||||
let header := mainHeaders[i]!
|
||||
let termination ← declValToTerminationHint header.value
|
||||
let termination := termination.rememberExtraParams header.numParams mainVals[i]!
|
||||
let value ← mkLambdaFVars sectionVars mainVals[i]!
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -17,7 +17,7 @@ open Lean.Parser.Command
|
||||
private partial def antiquote (vars : Array Syntax) : Syntax → Syntax
|
||||
| stx => match stx with
|
||||
| `($id:ident) =>
|
||||
if vars.any (fun var => var.getId == id.getId) then
|
||||
if (vars.findIdx? (fun var => var.getId == id.getId)).isSome then
|
||||
mkAntiquotNode id (kind := `term) (isPseudoKind := true)
|
||||
else
|
||||
stx
|
||||
|
||||
@@ -145,8 +145,8 @@ private partial def replaceRecApps (recArgInfos : Array RecArgInfo) (positions :
|
||||
| Expr.app _ _ =>
|
||||
let processApp (e : Expr) : StateRefT (HasConstCache recFnNames) M Expr :=
|
||||
e.withApp fun f args => do
|
||||
if let .some fnIdx := recArgInfos.findFinIdx? (f.isConstOf ·.fnName) then
|
||||
let recArgInfo := recArgInfos[fnIdx]
|
||||
if let .some fnIdx := recArgInfos.findIdx? (f.isConstOf ·.fnName) then
|
||||
let recArgInfo := recArgInfos[fnIdx]!
|
||||
let some recArg := args[recArgInfo.recArgPos]?
|
||||
| throwError "insufficient number of parameters at recursive application {indentExpr e}"
|
||||
-- For reflexive type, we may have nested recursive applications in recArg
|
||||
|
||||
@@ -302,58 +302,59 @@ instance : ToFormat FieldLHS where
|
||||
| .fieldIndex _ i => format i
|
||||
| .modifyOp _ i => "[" ++ i.prettyPrint ++ "]"
|
||||
|
||||
mutual
|
||||
/--
|
||||
`FieldVal StructInstView` is a representation of a field value in the structure instance.
|
||||
-/
|
||||
inductive FieldVal where
|
||||
/-- A `term` to use for the value of the field. -/
|
||||
| term (stx : Syntax) : FieldVal
|
||||
/-- A `StructInstView` to use for the value of a subobject field. -/
|
||||
| nested (s : StructInstView) : FieldVal
|
||||
/-- A field that was not provided and should be synthesized using default values. -/
|
||||
| default : FieldVal
|
||||
deriving Inhabited
|
||||
/--
|
||||
`FieldVal StructInstView` is a representation of a field value in the structure instance.
|
||||
-/
|
||||
inductive FieldVal (σ : Type) where
|
||||
/-- A `term` to use for the value of the field. -/
|
||||
| term (stx : Syntax) : FieldVal σ
|
||||
/-- A `StructInstView` to use for the value of a subobject field. -/
|
||||
| nested (s : σ) : FieldVal σ
|
||||
/-- A field that was not provided and should be synthesized using default values. -/
|
||||
| default : FieldVal σ
|
||||
deriving Inhabited
|
||||
|
||||
/--
|
||||
`Field StructInstView` is a representation of a field in the structure instance.
|
||||
-/
|
||||
structure Field where
|
||||
/-- The whole field syntax. -/
|
||||
ref : Syntax
|
||||
/-- The LHS decomposed into components. -/
|
||||
lhs : List FieldLHS
|
||||
/-- The value of the field. -/
|
||||
val : FieldVal
|
||||
/-- The elaborated field value, filled in at `elabStruct`.
|
||||
Missing fields use a metavariable for the elaborated value and are later solved for in `DefaultFields.propagate`. -/
|
||||
expr? : Option Expr := none
|
||||
deriving Inhabited
|
||||
|
||||
/--
|
||||
The view for structure instance notation.
|
||||
-/
|
||||
structure StructInstView where
|
||||
/-- The syntax for the whole structure instance. -/
|
||||
ref : Syntax
|
||||
/-- The name of the structure for the type of the structure instance. -/
|
||||
structName : Name
|
||||
/-- Used for default values, to propagate structure type parameters. It is initially empty, and then set at `elabStruct`. -/
|
||||
params : Array (Name × Expr)
|
||||
/-- The fields of the structure instance. -/
|
||||
fields : List Field
|
||||
/-- The additional sources for fields for the structure instance. -/
|
||||
sources : SourcesView
|
||||
deriving Inhabited
|
||||
end
|
||||
/--
|
||||
`Field StructInstView` is a representation of a field in the structure instance.
|
||||
-/
|
||||
structure Field (σ : Type) where
|
||||
/-- The whole field syntax. -/
|
||||
ref : Syntax
|
||||
/-- The LHS decomposed into components. -/
|
||||
lhs : List FieldLHS
|
||||
/-- The value of the field. -/
|
||||
val : FieldVal σ
|
||||
/-- The elaborated field value, filled in at `elabStruct`.
|
||||
Missing fields use a metavariable for the elaborated value and are later solved for in `DefaultFields.propagate`. -/
|
||||
expr? : Option Expr := none
|
||||
deriving Inhabited
|
||||
|
||||
/--
|
||||
Returns if the field has a single component in its LHS.
|
||||
-/
|
||||
def Field.isSimple : Field → Bool
|
||||
def Field.isSimple {σ} : Field σ → Bool
|
||||
| { lhs := [_], .. } => true
|
||||
| _ => false
|
||||
|
||||
/--
|
||||
The view for structure instance notation.
|
||||
-/
|
||||
structure StructInstView where
|
||||
/-- The syntax for the whole structure instance. -/
|
||||
ref : Syntax
|
||||
/-- The name of the structure for the type of the structure instance. -/
|
||||
structName : Name
|
||||
/-- Used for default values, to propagate structure type parameters. It is initially empty, and then set at `elabStruct`. -/
|
||||
params : Array (Name × Expr)
|
||||
/-- The fields of the structure instance. -/
|
||||
fields : List (Field StructInstView)
|
||||
/-- The additional sources for fields for the structure instance. -/
|
||||
sources : SourcesView
|
||||
deriving Inhabited
|
||||
|
||||
/-- Abbreviation for the type of `StructInstView.fields`, namely `List (Field StructInstView)`. -/
|
||||
abbrev Fields := List (Field StructInstView)
|
||||
|
||||
/-- `true` iff all fields of the given structure are marked as `default` -/
|
||||
partial def StructInstView.allDefault (s : StructInstView) : Bool :=
|
||||
s.fields.all fun { val := val, .. } => match val with
|
||||
@@ -361,7 +362,7 @@ partial def StructInstView.allDefault (s : StructInstView) : Bool :=
|
||||
| .default => true
|
||||
| .nested s => allDefault s
|
||||
|
||||
def formatField (formatStruct : StructInstView → Format) (field : Field) : Format :=
|
||||
def formatField (formatStruct : StructInstView → Format) (field : Field StructInstView) : Format :=
|
||||
Format.joinSep field.lhs " . " ++ " := " ++
|
||||
match field.val with
|
||||
| .term v => v.prettyPrint
|
||||
@@ -377,11 +378,11 @@ partial def formatStruct : StructInstView → Format
|
||||
else
|
||||
"{" ++ format (source.explicit.map (·.stx)) ++ " with " ++ fieldsFmt ++ implicitFmt ++ "}"
|
||||
|
||||
instance : ToFormat StructInstView := ⟨formatStruct⟩
|
||||
instance : ToFormat StructInstView := ⟨formatStruct⟩
|
||||
instance : ToString StructInstView := ⟨toString ∘ format⟩
|
||||
|
||||
instance : ToFormat Field := ⟨formatField formatStruct⟩
|
||||
instance : ToString Field := ⟨toString ∘ format⟩
|
||||
instance : ToFormat (Field StructInstView) := ⟨formatField formatStruct⟩
|
||||
instance : ToString (Field StructInstView) := ⟨toString ∘ format⟩
|
||||
|
||||
/--
|
||||
Converts a `FieldLHS` back into syntax. This assumes the `ref` fields have the correct structure.
|
||||
@@ -402,14 +403,14 @@ private def FieldLHS.toSyntax (first : Bool) : FieldLHS → Syntax
|
||||
/--
|
||||
Converts a `FieldVal StructInstView` back into syntax. Only supports `.term`, and it assumes the `stx` field has the correct structure.
|
||||
-/
|
||||
private def FieldVal.toSyntax : FieldVal → Syntax
|
||||
private def FieldVal.toSyntax : FieldVal Struct → Syntax
|
||||
| .term stx => stx
|
||||
| _ => unreachable!
|
||||
|
||||
/--
|
||||
Converts a `Field StructInstView` back into syntax. Used to construct synthetic structure instance notation for subobjects in `StructInst.expandStruct` processing.
|
||||
-/
|
||||
private def Field.toSyntax : Field → Syntax
|
||||
private def Field.toSyntax : Field Struct → Syntax
|
||||
| field =>
|
||||
let stx := field.ref
|
||||
let stx := stx.setArg 2 field.val.toSyntax
|
||||
@@ -451,14 +452,14 @@ private def mkStructView (stx : Syntax) (structName : Name) (sources : SourcesVi
|
||||
let val := fieldStx[2]
|
||||
let first ← toFieldLHS fieldStx[0][0]
|
||||
let rest ← fieldStx[0][1].getArgs.toList.mapM toFieldLHS
|
||||
return { ref := fieldStx, lhs := first :: rest, val := FieldVal.term val : Field }
|
||||
return { ref := fieldStx, lhs := first :: rest, val := FieldVal.term val : Field StructInstView }
|
||||
return { ref := stx, structName, params := #[], fields, sources }
|
||||
|
||||
def StructInstView.modifyFieldsM {m : Type → Type} [Monad m] (s : StructInstView) (f : List Field → m (List Field)) : m StructInstView :=
|
||||
def StructInstView.modifyFieldsM {m : Type → Type} [Monad m] (s : StructInstView) (f : Fields → m Fields) : m StructInstView :=
|
||||
match s with
|
||||
| { ref, structName, params, fields, sources } => return { ref, structName, params, fields := (← f fields), sources }
|
||||
|
||||
def StructInstView.modifyFields (s : StructInstView) (f : List Field → List Field) : StructInstView :=
|
||||
def StructInstView.modifyFields (s : StructInstView) (f : Fields → Fields) : StructInstView :=
|
||||
Id.run <| s.modifyFieldsM f
|
||||
|
||||
/-- Expands name field LHSs with multi-component names into multi-component LHSs. -/
|
||||
@@ -524,14 +525,14 @@ private def expandParentFields (s : StructInstView) : TermElabM StructInstView :
|
||||
| _ => throwErrorAt ref "failed to access field '{fieldName}' in parent structure"
|
||||
| _ => return field
|
||||
|
||||
private abbrev FieldMap := Std.HashMap Name (List Field)
|
||||
private abbrev FieldMap := Std.HashMap Name Fields
|
||||
|
||||
/--
|
||||
Creates a hash map collecting all fields with the same first name component.
|
||||
Throws an error if there are multiple simple fields with the same name.
|
||||
Used by `StructInst.expandStruct` processing.
|
||||
-/
|
||||
private def mkFieldMap (fields : List Field) : TermElabM FieldMap :=
|
||||
private def mkFieldMap (fields : Fields) : TermElabM FieldMap :=
|
||||
fields.foldlM (init := {}) fun fieldMap field =>
|
||||
match field.lhs with
|
||||
| .fieldName _ fieldName :: _ =>
|
||||
@@ -547,7 +548,7 @@ private def mkFieldMap (fields : List Field) : TermElabM FieldMap :=
|
||||
/--
|
||||
Given a value of the hash map created by `mkFieldMap`, returns true if the value corresponds to a simple field.
|
||||
-/
|
||||
private def isSimpleField? : List Field → Option Field
|
||||
private def isSimpleField? : Fields → Option (Field StructInstView)
|
||||
| [field] => if field.isSimple then some field else none
|
||||
| _ => none
|
||||
|
||||
@@ -565,7 +566,7 @@ def mkProjStx? (s : Syntax) (structName : Name) (fieldName : Name) : TermElabM (
|
||||
/--
|
||||
Finds a simple field of the given name.
|
||||
-/
|
||||
def findField? (fields : List Field) (fieldName : Name) : Option Field :=
|
||||
def findField? (fields : Fields) (fieldName : Name) : Option (Field StructInstView) :=
|
||||
fields.find? fun field =>
|
||||
match field.lhs with
|
||||
| [.fieldName _ n] => n == fieldName
|
||||
@@ -619,7 +620,7 @@ mutual
|
||||
match findField? s.fields fieldName with
|
||||
| some field => return field::fields
|
||||
| none =>
|
||||
let addField (val : FieldVal) : TermElabM (List Field) := do
|
||||
let addField (val : FieldVal StructInstView) : TermElabM Fields := do
|
||||
return { ref, lhs := [FieldLHS.fieldName ref fieldName], val := val } :: fields
|
||||
match Lean.isSubobjectField? env s.structName fieldName with
|
||||
| some substructName =>
|
||||
@@ -772,7 +773,7 @@ private partial def elabStructInstView (s : StructInstView) (expectedType? : Opt
|
||||
trace[Elab.struct] "elabStruct {field}, {type}"
|
||||
match type with
|
||||
| .forallE _ d b bi =>
|
||||
let cont (val : Expr) (field : Field) (instMVars := instMVars) : TermElabM (Expr × Expr × List Field × Array MVarId) := do
|
||||
let cont (val : Expr) (field : Field StructInstView) (instMVars := instMVars) : TermElabM (Expr × Expr × Fields × Array MVarId) := do
|
||||
pushInfoTree <| InfoTree.node (children := {}) <| Info.ofFieldInfo {
|
||||
projName := s.structName.append fieldName, fieldName, lctx := (← getLCtx), val, stx := ref }
|
||||
let e := mkApp e val
|
||||
@@ -878,7 +879,7 @@ partial def getHierarchyDepth (struct : StructInstView) : Nat :=
|
||||
| _ => max
|
||||
|
||||
/-- Returns whether the field is still missing. -/
|
||||
def isDefaultMissing? [Monad m] [MonadMCtx m] (field : Field) : m Bool := do
|
||||
def isDefaultMissing? [Monad m] [MonadMCtx m] (field : Field Struct) : m Bool := do
|
||||
if let some expr := field.expr? then
|
||||
if let some (.mvar mvarId) := defaultMissing? expr then
|
||||
unless (← mvarId.isAssigned) do
|
||||
@@ -886,17 +887,17 @@ def isDefaultMissing? [Monad m] [MonadMCtx m] (field : Field) : m Bool := do
|
||||
return false
|
||||
|
||||
/-- Returns a field that is still missing. -/
|
||||
partial def findDefaultMissing? [Monad m] [MonadMCtx m] (struct : StructInstView) : m (Option Field) :=
|
||||
partial def findDefaultMissing? [Monad m] [MonadMCtx m] (struct : StructInstView) : m (Option (Field StructInstView)) :=
|
||||
struct.fields.findSomeM? fun field => do
|
||||
match field.val with
|
||||
| .nested struct => findDefaultMissing? struct
|
||||
| _ => return if (← isDefaultMissing? field) then field else none
|
||||
|
||||
/-- Returns all fields that are still missing. -/
|
||||
partial def allDefaultMissing [Monad m] [MonadMCtx m] (struct : StructInstView) : m (Array Field) :=
|
||||
partial def allDefaultMissing [Monad m] [MonadMCtx m] (struct : StructInstView) : m (Array (Field StructInstView)) :=
|
||||
go struct *> get |>.run' #[]
|
||||
where
|
||||
go (struct : StructInstView) : StateT (Array Field) m Unit :=
|
||||
go (struct : StructInstView) : StateT (Array (Field StructInstView)) m Unit :=
|
||||
for field in struct.fields do
|
||||
if let .nested struct := field.val then
|
||||
go struct
|
||||
@@ -904,7 +905,7 @@ where
|
||||
modify (·.push field)
|
||||
|
||||
/-- Returns the name of the field. Assumes all fields under consideration are simple and named. -/
|
||||
def getFieldName (field : Field) : Name :=
|
||||
def getFieldName (field : Field StructInstView) : Name :=
|
||||
match field.lhs with
|
||||
| [.fieldName _ fieldName] => fieldName
|
||||
| _ => unreachable!
|
||||
|
||||
@@ -4,15 +4,22 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.Class
|
||||
import Lean.Parser.Command
|
||||
import Lean.Meta.Closure
|
||||
import Lean.Meta.SizeOf
|
||||
import Lean.Meta.Injective
|
||||
import Lean.Meta.Structure
|
||||
import Lean.Elab.MutualInductive
|
||||
import Lean.Meta.AppBuilder
|
||||
import Lean.Elab.Command
|
||||
import Lean.Elab.DeclModifiers
|
||||
import Lean.Elab.DeclUtil
|
||||
import Lean.Elab.Inductive
|
||||
import Lean.Elab.DeclarationRange
|
||||
import Lean.Elab.Binders
|
||||
|
||||
namespace Lean.Elab.Command
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Elab.structure
|
||||
registerTraceClass `Elab.structure.resolutionOrder
|
||||
|
||||
register_builtin_option structureDiamondWarning : Bool := {
|
||||
defValue := false
|
||||
descr := "if true, enable warnings when a structure has diamond inheritance"
|
||||
@@ -32,6 +39,13 @@ leading_parser (structureTk <|> classTk) >> declId >> many Term.bracketedBinder
|
||||
```
|
||||
-/
|
||||
|
||||
structure StructCtorView where
|
||||
ref : Syntax
|
||||
modifiers : Modifiers
|
||||
name : Name
|
||||
declName : Name
|
||||
deriving Inhabited
|
||||
|
||||
structure StructFieldView where
|
||||
ref : Syntax
|
||||
modifiers : Modifiers
|
||||
@@ -47,15 +61,22 @@ structure StructFieldView where
|
||||
type? : Option Syntax
|
||||
value? : Option Syntax
|
||||
|
||||
structure StructView extends InductiveView where
|
||||
parents : Array Syntax
|
||||
fields : Array StructFieldView
|
||||
structure StructView where
|
||||
ref : Syntax
|
||||
declId : Syntax
|
||||
modifiers : Modifiers
|
||||
isClass : Bool -- struct-only
|
||||
shortDeclName : Name
|
||||
declName : Name
|
||||
levelNames : List Name
|
||||
binders : Syntax
|
||||
type : Syntax -- modified (inductive has type?)
|
||||
parents : Array Syntax -- struct-only
|
||||
ctor : StructCtorView -- struct-only
|
||||
fields : Array StructFieldView -- struct-only
|
||||
derivingClasses : Array DerivingClassView
|
||||
deriving Inhabited
|
||||
|
||||
def StructView.ctor : StructView → CtorView
|
||||
| { ctors := #[ctor], ..} => ctor
|
||||
| _ => unreachable!
|
||||
|
||||
structure StructParentInfo where
|
||||
ref : Syntax
|
||||
fvar? : Option Expr
|
||||
@@ -81,6 +102,18 @@ structure StructFieldInfo where
|
||||
value? : Option Expr := none
|
||||
deriving Inhabited, Repr
|
||||
|
||||
structure ElabStructHeaderResult where
|
||||
view : StructView
|
||||
lctx : LocalContext
|
||||
localInsts : LocalInstances
|
||||
levelNames : List Name
|
||||
params : Array Expr
|
||||
type : Expr
|
||||
parents : Array StructParentInfo
|
||||
/-- Field infos from parents. -/
|
||||
parentFieldInfos : Array StructFieldInfo
|
||||
deriving Inhabited
|
||||
|
||||
def StructFieldInfo.isFromParent (info : StructFieldInfo) : Bool :=
|
||||
match info.kind with
|
||||
| StructFieldKind.fromParent => true
|
||||
@@ -97,12 +130,12 @@ The structure constructor syntax is
|
||||
leading_parser try (declModifiers >> ident >> " :: ")
|
||||
```
|
||||
-/
|
||||
private def expandCtor (structStx : Syntax) (structModifiers : Modifiers) (structDeclName : Name) : TermElabM CtorView := do
|
||||
private def expandCtor (structStx : Syntax) (structModifiers : Modifiers) (structDeclName : Name) : TermElabM StructCtorView := do
|
||||
let useDefault := do
|
||||
let declName := structDeclName ++ defaultCtorName
|
||||
let ref := structStx[1].mkSynthetic
|
||||
addDeclarationRangesFromSyntax declName ref
|
||||
pure { ref, declId := ref, modifiers := default, declName }
|
||||
pure { ref, modifiers := default, name := defaultCtorName, declName }
|
||||
if structStx[5].isNone then
|
||||
useDefault
|
||||
else
|
||||
@@ -123,7 +156,7 @@ private def expandCtor (structStx : Syntax) (structModifiers : Modifiers) (struc
|
||||
let declName ← applyVisibility ctorModifiers.visibility declName
|
||||
addDocString' declName ctorModifiers.docString?
|
||||
addDeclarationRangesFromSyntax declName ctor[1]
|
||||
pure { ref := ctor[1], declId := ctor[1], modifiers := ctorModifiers, declName }
|
||||
pure { ref := ctor[1], name, modifiers := ctorModifiers, declName }
|
||||
|
||||
def checkValidFieldModifier (modifiers : Modifiers) : TermElabM Unit := do
|
||||
if modifiers.isNoncomputable then
|
||||
@@ -238,7 +271,7 @@ def structureSyntaxToView (modifiers : Modifiers) (stx : Syntax) : TermElabM Str
|
||||
let parents := if exts.isNone then #[] else exts[0][1].getSepArgs
|
||||
let optType := stx[4]
|
||||
let derivingClasses ← getOptDerivingClasses stx[6]
|
||||
let type? := if optType.isNone then none else some optType[0][1]
|
||||
let type ← if optType.isNone then `(Sort _) else pure optType[0][1]
|
||||
let ctor ← expandCtor stx modifiers declName
|
||||
let fields ← expandFields stx modifiers declName
|
||||
fields.forM fun field => do
|
||||
@@ -254,13 +287,10 @@ def structureSyntaxToView (modifiers : Modifiers) (stx : Syntax) : TermElabM Str
|
||||
declName
|
||||
levelNames
|
||||
binders
|
||||
type?
|
||||
allowIndices := false
|
||||
allowSortPolymorphism := false
|
||||
ctors := #[ctor]
|
||||
type
|
||||
parents
|
||||
ctor
|
||||
fields
|
||||
computedFields := #[]
|
||||
derivingClasses
|
||||
}
|
||||
|
||||
@@ -285,7 +315,7 @@ private def findExistingField? (infos : Array StructFieldInfo) (parentStructName
|
||||
return some fieldName
|
||||
return none
|
||||
|
||||
private def processSubfields (structDeclName : Name) (parentFVar : Expr) (parentStructName : Name) (subfieldNames : Array Name)
|
||||
private partial def processSubfields (structDeclName : Name) (parentFVar : Expr) (parentStructName : Name) (subfieldNames : Array Name)
|
||||
(infos : Array StructFieldInfo) (k : Array StructFieldInfo → TermElabM α) : TermElabM α :=
|
||||
go 0 infos
|
||||
where
|
||||
@@ -506,7 +536,7 @@ private partial def mkToParentName (parentStructName : Name) (p : Name → Bool)
|
||||
if p curr then curr else go (i+1)
|
||||
go 1
|
||||
|
||||
private def withParents (view : StructView) (rs : Array ElabHeaderResult) (indFVar : Expr)
|
||||
private partial def elabParents (view : StructView)
|
||||
(k : Array StructFieldInfo → Array StructParentInfo → TermElabM α) : TermElabM α := do
|
||||
go 0 #[] #[]
|
||||
where
|
||||
@@ -514,17 +544,11 @@ where
|
||||
if h : i < view.parents.size then
|
||||
let parent := view.parents[i]
|
||||
withRef parent do
|
||||
-- The only use case for autobound implicits for parents might be outParams, but outParam is not propagated.
|
||||
let type ← Term.withoutAutoBoundImplicit <| Term.elabType parent
|
||||
let type ← Term.elabType parent
|
||||
let parentType ← whnf type
|
||||
if parentType.getAppFn == indFVar then
|
||||
logWarning "structure extends itself, skipping"
|
||||
return ← go (i + 1) infos parents
|
||||
if rs.any (fun r => r.indFVar == parentType.getAppFn) then
|
||||
throwError "structure cannot extend types defined in the same mutual block"
|
||||
let parentStructName ← getStructureName parentType
|
||||
if parents.any (fun info => info.structName == parentStructName) then
|
||||
logWarning m!"duplicate parent structure '{.ofConstName parentStructName}', skipping"
|
||||
logWarningAt parent m!"duplicate parent structure '{.ofConstName parentStructName}', skipping"
|
||||
go (i + 1) infos parents
|
||||
else if let some existingFieldName ← findExistingField? infos parentStructName then
|
||||
if structureDiamondWarning.get (← getOptions) then
|
||||
@@ -546,13 +570,6 @@ where
|
||||
else
|
||||
k infos parents
|
||||
|
||||
private def registerFailedToInferFieldType (fieldName : Name) (e : Expr) (ref : Syntax) : TermElabM Unit := do
|
||||
Term.registerCustomErrorIfMVar (← instantiateMVars e) ref m!"failed to infer type of field '{.ofConstName fieldName}'"
|
||||
|
||||
private def registerFailedToInferDefaultValue (fieldName : Name) (e : Expr) (ref : Syntax) : TermElabM Unit := do
|
||||
Term.registerCustomErrorIfMVar (← instantiateMVars e) ref m!"failed to infer default value for field '{.ofConstName fieldName}'"
|
||||
Term.registerLevelMVarErrorExprInfo e ref m!"failed to infer universe levels in default value for field '{.ofConstName fieldName}'"
|
||||
|
||||
private def elabFieldTypeValue (view : StructFieldView) : TermElabM (Option Expr × Option Expr) :=
|
||||
Term.withAutoBoundImplicit <| Term.withAutoBoundImplicitForbiddenPred (fun n => view.name == n) <| Term.elabBinders view.binders.getArgs fun params => do
|
||||
match view.type? with
|
||||
@@ -564,13 +581,10 @@ private def elabFieldTypeValue (view : StructFieldView) : TermElabM (Option Expr
|
||||
-- TODO: add forbidden predicate using `shortDeclName` from `view`
|
||||
let params ← Term.addAutoBoundImplicits params
|
||||
let value ← Term.withoutAutoBoundImplicit <| Term.elabTerm valStx none
|
||||
registerFailedToInferFieldType view.name (← inferType value) view.nameId
|
||||
registerFailedToInferDefaultValue view.name value valStx
|
||||
let value ← mkLambdaFVars params value
|
||||
return (none, value)
|
||||
| some typeStx =>
|
||||
let type ← Term.elabType typeStx
|
||||
registerFailedToInferFieldType view.name type typeStx
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let params ← Term.addAutoBoundImplicits params
|
||||
match view.value? with
|
||||
@@ -579,7 +593,6 @@ private def elabFieldTypeValue (view : StructFieldView) : TermElabM (Option Expr
|
||||
return (type, none)
|
||||
| some valStx =>
|
||||
let value ← Term.withoutAutoBoundImplicit <| Term.elabTermEnsuringType valStx type
|
||||
registerFailedToInferDefaultValue view.name value valStx
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let type ← mkForallFVars params type
|
||||
let value ← mkLambdaFVars params value
|
||||
@@ -626,7 +639,6 @@ where
|
||||
valStx ← `(fun $(view.binders.getArgs)* => $valStx:term)
|
||||
let fvarType ← inferType info.fvar
|
||||
let value ← Term.elabTermEnsuringType valStx fvarType
|
||||
registerFailedToInferDefaultValue view.name value valStx
|
||||
pushInfoLeaf <| .ofFieldRedeclInfo { stx := view.ref }
|
||||
let infos := replaceFieldInfo infos { info with ref := view.nameId, value? := value }
|
||||
go (i+1) defaultValsOverridden infos
|
||||
@@ -638,14 +650,113 @@ where
|
||||
else
|
||||
k infos
|
||||
|
||||
private def collectUsedFVars (lctx : LocalContext) (localInsts : LocalInstances) (fieldInfos : Array StructFieldInfo) :
|
||||
StateRefT CollectFVars.State MetaM Unit := do
|
||||
withLCtx lctx localInsts do
|
||||
fieldInfos.forM fun info => do
|
||||
let fvarType ← inferType info.fvar
|
||||
fvarType.collectFVars
|
||||
if let some value := info.value? then
|
||||
value.collectFVars
|
||||
private def getResultUniverse (type : Expr) : TermElabM Level := do
|
||||
let type ← whnf type
|
||||
match type with
|
||||
| Expr.sort u => pure u
|
||||
| _ => throwError "unexpected structure resulting type"
|
||||
|
||||
private def collectUsed (params : Array Expr) (fieldInfos : Array StructFieldInfo) : StateRefT CollectFVars.State MetaM Unit := do
|
||||
params.forM fun p => do
|
||||
let type ← inferType p
|
||||
type.collectFVars
|
||||
fieldInfos.forM fun info => do
|
||||
let fvarType ← inferType info.fvar
|
||||
fvarType.collectFVars
|
||||
match info.value? with
|
||||
| none => pure ()
|
||||
| some value => value.collectFVars
|
||||
|
||||
private def removeUnused (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo)
|
||||
: TermElabM (LocalContext × LocalInstances × Array Expr) := do
|
||||
let (_, used) ← (collectUsed params fieldInfos).run {}
|
||||
Meta.removeUnused scopeVars used
|
||||
|
||||
private def withUsed {α} (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo) (k : Array Expr → TermElabM α)
|
||||
: TermElabM α := do
|
||||
let (lctx, localInsts, vars) ← removeUnused scopeVars params fieldInfos
|
||||
withLCtx lctx localInsts <| k vars
|
||||
|
||||
private def levelMVarToParam (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo) (univToInfer? : Option LMVarId) : TermElabM (Array StructFieldInfo) := do
|
||||
levelMVarToParamFVars scopeVars
|
||||
levelMVarToParamFVars params
|
||||
fieldInfos.mapM fun info => do
|
||||
levelMVarToParamFVar info.fvar
|
||||
match info.value? with
|
||||
| none => pure info
|
||||
| some value =>
|
||||
let value ← levelMVarToParam' value
|
||||
pure { info with value? := value }
|
||||
where
|
||||
levelMVarToParam' (type : Expr) : TermElabM Expr := do
|
||||
Term.levelMVarToParam type (except := fun mvarId => univToInfer? == some mvarId)
|
||||
|
||||
levelMVarToParamFVars (fvars : Array Expr) : TermElabM Unit :=
|
||||
fvars.forM levelMVarToParamFVar
|
||||
|
||||
levelMVarToParamFVar (fvar : Expr) : TermElabM Unit := do
|
||||
let type ← inferType fvar
|
||||
discard <| levelMVarToParam' type
|
||||
|
||||
|
||||
private partial def collectUniversesFromFields (r : Level) (rOffset : Nat) (fieldInfos : Array StructFieldInfo) : TermElabM (Array Level) := do
|
||||
let (_, us) ← go |>.run #[]
|
||||
return us
|
||||
where
|
||||
go : StateRefT (Array Level) TermElabM Unit :=
|
||||
for info in fieldInfos do
|
||||
let type ← inferType info.fvar
|
||||
let u ← getLevel type
|
||||
let u ← instantiateLevelMVars u
|
||||
match (← modifyGet fun s => accLevel u r rOffset |>.run |>.run s) with
|
||||
| some _ => pure ()
|
||||
| none =>
|
||||
let typeType ← inferType type
|
||||
let mut msg := m!"failed to compute resulting universe level of structure, field '{info.declName}' has type{indentD m!"{type} : {typeType}"}\nstructure resulting type{indentExpr (mkSort (r.addOffset rOffset))}"
|
||||
if r.isMVar then
|
||||
msg := msg ++ "\nrecall that Lean only infers the resulting universe level automatically when there is a unique solution for the universe level constraints, consider explicitly providing the structure resulting universe level"
|
||||
throwError msg
|
||||
|
||||
/--
|
||||
Decides whether the structure should be `Prop`-valued when the universe is not given
|
||||
and when the universe inference algorithm `collectUniversesFromFields` determines
|
||||
that the inductive type could naturally be `Prop`-valued.
|
||||
|
||||
See `Lean.Elab.Command.isPropCandidate` for an explanation.
|
||||
Specialized to structures, the heuristic is that we prefer a `Prop` instead of a `Type` structure
|
||||
when it could be a syntactic subsingleton.
|
||||
Exception: no-field structures are `Type` since they are likely stubbed-out declarations.
|
||||
-/
|
||||
private def isPropCandidate (fieldInfos : Array StructFieldInfo) : Bool :=
|
||||
!fieldInfos.isEmpty
|
||||
|
||||
private def updateResultingUniverse (fieldInfos : Array StructFieldInfo) (type : Expr) : TermElabM Expr := do
|
||||
let r ← getResultUniverse type
|
||||
let rOffset : Nat := r.getOffset
|
||||
let r : Level := r.getLevelOffset
|
||||
unless r.isMVar do
|
||||
throwError "failed to compute resulting universe level of inductive datatype, provide universe explicitly: {r}"
|
||||
let us ← collectUniversesFromFields r rOffset fieldInfos
|
||||
trace[Elab.structure] "updateResultingUniverse us: {us}, r: {r}, rOffset: {rOffset}"
|
||||
let rNew := mkResultUniverse us rOffset (isPropCandidate fieldInfos)
|
||||
assignLevelMVar r.mvarId! rNew
|
||||
instantiateMVars type
|
||||
|
||||
private def collectLevelParamsInFVar (s : CollectLevelParams.State) (fvar : Expr) : TermElabM CollectLevelParams.State := do
|
||||
let type ← inferType fvar
|
||||
let type ← instantiateMVars type
|
||||
return collectLevelParams s type
|
||||
|
||||
private def collectLevelParamsInFVars (fvars : Array Expr) (s : CollectLevelParams.State) : TermElabM CollectLevelParams.State :=
|
||||
fvars.foldlM collectLevelParamsInFVar s
|
||||
|
||||
private def collectLevelParamsInStructure (structType : Expr) (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo)
|
||||
: TermElabM (Array Name) := do
|
||||
let s := collectLevelParams {} structType
|
||||
let s ← collectLevelParamsInFVars scopeVars s
|
||||
let s ← collectLevelParamsInFVars params s
|
||||
let s ← fieldInfos.foldlM (init := s) fun s info => collectLevelParamsInFVar s info.fvar
|
||||
return s.params
|
||||
|
||||
private def addCtorFields (fieldInfos : Array StructFieldInfo) : Nat → Expr → TermElabM Expr
|
||||
| 0, type => pure type
|
||||
@@ -661,29 +772,19 @@ private def addCtorFields (fieldInfos : Array StructFieldInfo) : Nat → Expr
|
||||
| _ =>
|
||||
addCtorFields fieldInfos i (mkForall decl.userName decl.binderInfo decl.type type)
|
||||
|
||||
private def mkCtor (view : StructView) (r : ElabHeaderResult) (params : Array Expr) (fieldInfos : Array StructFieldInfo) : TermElabM Constructor :=
|
||||
private def mkCtor (view : StructView) (levelParams : List Name) (params : Array Expr) (fieldInfos : Array StructFieldInfo) : TermElabM Constructor :=
|
||||
withRef view.ref do
|
||||
let type := mkAppN r.indFVar params
|
||||
let type := mkAppN (mkConst view.declName (levelParams.map mkLevelParam)) params
|
||||
let type ← addCtorFields fieldInfos fieldInfos.size type
|
||||
let type ← mkForallFVars params type
|
||||
let type ← instantiateMVars type
|
||||
let type := type.inferImplicit params.size true
|
||||
pure { name := view.ctor.declName, type }
|
||||
|
||||
private partial def checkResultingUniversesForFields (fieldInfos : Array StructFieldInfo) (u : Level) : TermElabM Unit := do
|
||||
for info in fieldInfos do
|
||||
let type ← inferType info.fvar
|
||||
let v := (← instantiateLevelMVars (← getLevel type)).normalize
|
||||
unless u.geq v do
|
||||
let msg := m!"invalid universe level for field '{info.name}', has type{indentExpr type}\n\
|
||||
at universe level{indentD v}\n\
|
||||
which is not less than or equal to the structure's resulting universe level{indentD u}"
|
||||
throwErrorAt info.ref msg
|
||||
|
||||
@[extern "lean_mk_projections"]
|
||||
private opaque mkProjections (env : Environment) (structName : Name) (projs : List Name) (isClass : Bool) : Except KernelException Environment
|
||||
|
||||
private def addProjections (r : ElabHeaderResult) (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
|
||||
private def addProjections (r : ElabStructHeaderResult) (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
|
||||
if r.type.isProp then
|
||||
if let some fieldInfo ← fieldInfos.findM? (not <$> Meta.isProof ·.fvar) then
|
||||
throwErrorAt fieldInfo.ref m!"failed to generate projections for 'Prop' structure, field '{format fieldInfo.name}' is not a proof"
|
||||
@@ -694,71 +795,49 @@ private def addProjections (r : ElabHeaderResult) (fieldInfos : Array StructFiel
|
||||
|
||||
private def registerStructure (structName : Name) (infos : Array StructFieldInfo) : TermElabM Unit := do
|
||||
let fields ← infos.filterMapM fun info => do
|
||||
if info.kind == StructFieldKind.fromParent then
|
||||
return none
|
||||
else
|
||||
return some {
|
||||
fieldName := info.name
|
||||
projFn := info.declName
|
||||
binderInfo := (← getFVarLocalDecl info.fvar).binderInfo
|
||||
autoParam? := (← inferType info.fvar).getAutoParamTactic?
|
||||
subobject? := if let .subobject parentName := info.kind then parentName else none
|
||||
}
|
||||
if info.kind == StructFieldKind.fromParent then
|
||||
return none
|
||||
else
|
||||
return some {
|
||||
fieldName := info.name
|
||||
projFn := info.declName
|
||||
binderInfo := (← getFVarLocalDecl info.fvar).binderInfo
|
||||
autoParam? := (← inferType info.fvar).getAutoParamTactic?
|
||||
subobject? := if let .subobject parentName := info.kind then parentName else none
|
||||
}
|
||||
modifyEnv fun env => Lean.registerStructure env { structName, fields }
|
||||
|
||||
private def checkDefaults (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
|
||||
let mut mvars := {}
|
||||
let mut lmvars := {}
|
||||
for fieldInfo in fieldInfos do
|
||||
if let some value := fieldInfo.value? then
|
||||
let value ← instantiateMVars value
|
||||
mvars := Expr.collectMVars mvars value
|
||||
lmvars := collectLevelMVars lmvars value
|
||||
-- Log errors and ignore the failure; we later will just omit adding a default value.
|
||||
if ← Term.logUnassignedUsingErrorInfos mvars.result then
|
||||
return
|
||||
else if ← Term.logUnassignedLevelMVarsUsingErrorInfos lmvars.result then
|
||||
return
|
||||
private def mkAuxConstructions (declName : Name) : TermElabM Unit := do
|
||||
let env ← getEnv
|
||||
let hasEq := env.contains ``Eq
|
||||
let hasHEq := env.contains ``HEq
|
||||
let hasUnit := env.contains ``PUnit
|
||||
let hasProd := env.contains ``Prod
|
||||
mkRecOn declName
|
||||
if hasUnit then mkCasesOn declName
|
||||
if hasUnit && hasEq && hasHEq then mkNoConfusion declName
|
||||
let ival ← getConstInfoInduct declName
|
||||
if ival.isRec then
|
||||
if hasUnit && hasProd then mkBelow declName
|
||||
if hasUnit && hasProd then mkIBelow declName
|
||||
if hasUnit && hasProd then mkBRecOn declName
|
||||
if hasUnit && hasProd then mkBInductionOn declName
|
||||
|
||||
private def addDefaults (params : Array Expr) (replaceIndFVars : Expr → MetaM Expr) (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
|
||||
let lctx ← getLCtx
|
||||
/- The `lctx` and `defaultAuxDecls` are used to create the auxiliary "default value" declarations
|
||||
The parameters `params` for these definitions must be marked as implicit, and all others as explicit. -/
|
||||
let lctx :=
|
||||
params.foldl (init := lctx) fun (lctx : LocalContext) (p : Expr) =>
|
||||
if p.isFVar then
|
||||
lctx.setBinderInfo p.fvarId! BinderInfo.implicit
|
||||
else
|
||||
lctx
|
||||
let lctx :=
|
||||
fieldInfos.foldl (init := lctx) fun (lctx : LocalContext) (info : StructFieldInfo) =>
|
||||
if info.isFromParent then lctx -- `fromParent` fields are elaborated as let-decls, and are zeta-expanded when creating "default value" auxiliary functions
|
||||
else lctx.setBinderInfo info.fvar.fvarId! BinderInfo.default
|
||||
-- Make all indFVar replacements in the local context.
|
||||
let lctx ←
|
||||
lctx.foldlM (init := {}) fun lctx ldecl => do
|
||||
match ldecl with
|
||||
| .cdecl _ fvarId userName type bi k =>
|
||||
let type ← replaceIndFVars type
|
||||
return lctx.mkLocalDecl fvarId userName type bi k
|
||||
| .ldecl _ fvarId userName type value nonDep k =>
|
||||
let type ← replaceIndFVars type
|
||||
let value ← replaceIndFVars value
|
||||
return lctx.mkLetDecl fvarId userName type value nonDep k
|
||||
private def addDefaults (lctx : LocalContext) (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
|
||||
withLCtx lctx (← getLocalInstances) do
|
||||
fieldInfos.forM fun fieldInfo => do
|
||||
if let some value := fieldInfo.value? then
|
||||
let declName := mkDefaultFnOfProjFn fieldInfo.declName
|
||||
let type ← replaceIndFVars (← inferType fieldInfo.fvar)
|
||||
let value ← instantiateMVars (← replaceIndFVars value)
|
||||
trace[Elab.structure] "default value after 'replaceIndFVars': {indentExpr value}"
|
||||
-- If there are mvars, `checkDefaults` already logged an error.
|
||||
unless value.hasMVar || value.hasSyntheticSorry do
|
||||
/- The identity function is used as "marker". -/
|
||||
let value ← mkId value
|
||||
-- No need to compile the definition, since it is only used during elaboration.
|
||||
discard <| mkAuxDefinition declName type value (zetaDelta := true) (compile := false)
|
||||
setReducibleAttribute declName
|
||||
let type ← inferType fieldInfo.fvar
|
||||
let value ← instantiateMVars value
|
||||
if value.hasExprMVar then
|
||||
discard <| Term.logUnassignedUsingErrorInfos (← getMVars value)
|
||||
throwErrorAt fieldInfo.ref "invalid default value for field '{format fieldInfo.name}', it contains metavariables{indentExpr value}"
|
||||
/- The identity function is used as "marker". -/
|
||||
let value ← mkId value
|
||||
-- No need to compile the definition, since it is only used during elaboration.
|
||||
discard <| mkAuxDefinition declName type value (zetaDelta := true) (compile := false)
|
||||
setReducibleAttribute declName
|
||||
|
||||
/--
|
||||
Given `type` of the form `forall ... (source : A), B`, return `forall ... [source : A], B`.
|
||||
@@ -775,81 +854,108 @@ private def setSourceInstImplicit (type : Expr) : Expr :=
|
||||
/--
|
||||
Creates a projection function to a non-subobject parent.
|
||||
-/
|
||||
private partial def mkCoercionToCopiedParent (levelParams : List Name) (params : Array Expr) (view : StructView) (source : Expr) (parentStructName : Name) (parentType : Expr) : MetaM StructureParentInfo := do
|
||||
private partial def mkCoercionToCopiedParent (levelParams : List Name) (params : Array Expr) (view : StructView) (parentStructName : Name) (parentType : Expr) : MetaM StructureParentInfo := do
|
||||
let isProp ← Meta.isProp parentType
|
||||
let env ← getEnv
|
||||
let structName := view.declName
|
||||
let sourceFieldNames := getStructureFieldsFlattened env structName
|
||||
let structType := mkAppN (Lean.mkConst structName (levelParams.map mkLevelParam)) params
|
||||
let binfo := if view.isClass && isClass env parentStructName then BinderInfo.instImplicit else BinderInfo.default
|
||||
let mut declType ← instantiateMVars (← mkForallFVars params (← mkForallFVars #[source] parentType))
|
||||
if view.isClass && isClass env parentStructName then
|
||||
declType := setSourceInstImplicit declType
|
||||
declType := declType.inferImplicit params.size true
|
||||
let rec copyFields (parentType : Expr) : MetaM Expr := do
|
||||
let Expr.const parentStructName us ← pure parentType.getAppFn | unreachable!
|
||||
let parentCtor := getStructureCtor env parentStructName
|
||||
let mut result := mkAppN (mkConst parentCtor.name us) parentType.getAppArgs
|
||||
for fieldName in getStructureFields env parentStructName do
|
||||
if sourceFieldNames.contains fieldName then
|
||||
let fieldVal ← mkProjection source fieldName
|
||||
result := mkApp result fieldVal
|
||||
else
|
||||
-- fieldInfo must be a field of `parentStructName`
|
||||
let some fieldInfo := getFieldInfo? env parentStructName fieldName | unreachable!
|
||||
if fieldInfo.subobject?.isNone then throwError "failed to build coercion to parent structure"
|
||||
let resultType ← whnfD (← inferType result)
|
||||
unless resultType.isForall do throwError "failed to build coercion to parent structure, unexpected type{indentExpr resultType}"
|
||||
let fieldVal ← copyFields resultType.bindingDomain!
|
||||
result := mkApp result fieldVal
|
||||
return result
|
||||
let declVal ← instantiateMVars (← mkLambdaFVars params (← mkLambdaFVars #[source] (← copyFields parentType)))
|
||||
let declName := structName ++ mkToParentName (← getStructureName parentType) fun n => !env.contains (structName ++ n)
|
||||
-- Logic from `mk_projections`: prop-valued projections are theorems (or at least opaque)
|
||||
let cval : ConstantVal := { name := declName, levelParams, type := declType }
|
||||
if isProp then
|
||||
addDecl <|
|
||||
if view.modifiers.isUnsafe then
|
||||
-- Theorems cannot be unsafe.
|
||||
Declaration.opaqueDecl { cval with value := declVal, isUnsafe := true }
|
||||
else
|
||||
Declaration.thmDecl { cval with value := declVal }
|
||||
else
|
||||
addAndCompile <| Declaration.defnDecl { cval with
|
||||
value := declVal
|
||||
hints := ReducibilityHints.abbrev
|
||||
safety := if view.modifiers.isUnsafe then DefinitionSafety.unsafe else DefinitionSafety.safe
|
||||
}
|
||||
-- Logic from `mk_projections`: non-instance-implicits that aren't props become reducible.
|
||||
-- (Instances will get instance reducibility in `Lean.Elab.Command.addParentInstances`.)
|
||||
if !binfo.isInstImplicit && !(← Meta.isProp parentType) then
|
||||
setReducibleAttribute declName
|
||||
return { structName := parentStructName, subobject := false, projFn := declName }
|
||||
|
||||
private def mkRemainingProjections (levelParams : List Name) (params : Array Expr) (view : StructView)
|
||||
(parents : Array StructParentInfo) (fieldInfos : Array StructFieldInfo) : TermElabM (Array StructureParentInfo) := do
|
||||
let structType := mkAppN (Lean.mkConst view.declName (levelParams.map mkLevelParam)) params
|
||||
withLocalDeclD `self structType fun source => do
|
||||
/-
|
||||
Remark: copied parents might still be referring to the fvars of other parents. We need to replace these fvars with projection constants.
|
||||
For subobject parents, this has already been done by `mkProjections`.
|
||||
https://github.com/leanprover/lean4/issues/2611
|
||||
-/
|
||||
let mut parentInfos := #[]
|
||||
let mut parentFVarToConst : ExprMap Expr := {}
|
||||
for h : i in [0:parents.size] do
|
||||
let parent := parents[i]
|
||||
let parentInfo : StructureParentInfo ← (do
|
||||
if parent.subobject then
|
||||
let some info := fieldInfos.find? (·.kind == .subobject parent.structName) | unreachable!
|
||||
pure { structName := parent.structName, subobject := true, projFn := info.declName }
|
||||
let mut declType ← instantiateMVars (← mkForallFVars params (← mkForallFVars #[source] parentType))
|
||||
if view.isClass && isClass env parentStructName then
|
||||
declType := setSourceInstImplicit declType
|
||||
declType := declType.inferImplicit params.size true
|
||||
let rec copyFields (parentType : Expr) : MetaM Expr := do
|
||||
let Expr.const parentStructName us ← pure parentType.getAppFn | unreachable!
|
||||
let parentCtor := getStructureCtor env parentStructName
|
||||
let mut result := mkAppN (mkConst parentCtor.name us) parentType.getAppArgs
|
||||
for fieldName in getStructureFields env parentStructName do
|
||||
if sourceFieldNames.contains fieldName then
|
||||
let fieldVal ← mkProjection source fieldName
|
||||
result := mkApp result fieldVal
|
||||
else
|
||||
let parent_type := (← instantiateMVars parent.type).replace fun e => parentFVarToConst[e]?
|
||||
mkCoercionToCopiedParent levelParams params view source parent.structName parent_type)
|
||||
parentInfos := parentInfos.push parentInfo
|
||||
if let some fvar := parent.fvar? then
|
||||
parentFVarToConst := parentFVarToConst.insert fvar <|
|
||||
mkApp (mkAppN (.const parentInfo.projFn (levelParams.map mkLevelParam)) params) source
|
||||
pure parentInfos
|
||||
-- fieldInfo must be a field of `parentStructName`
|
||||
let some fieldInfo := getFieldInfo? env parentStructName fieldName | unreachable!
|
||||
if fieldInfo.subobject?.isNone then throwError "failed to build coercion to parent structure"
|
||||
let resultType ← whnfD (← inferType result)
|
||||
unless resultType.isForall do throwError "failed to build coercion to parent structure, unexpected type{indentExpr resultType}"
|
||||
let fieldVal ← copyFields resultType.bindingDomain!
|
||||
result := mkApp result fieldVal
|
||||
return result
|
||||
let declVal ← instantiateMVars (← mkLambdaFVars params (← mkLambdaFVars #[source] (← copyFields parentType)))
|
||||
let declName := structName ++ mkToParentName (← getStructureName parentType) fun n => !env.contains (structName ++ n)
|
||||
-- Logic from `mk_projections`: prop-valued projections are theorems (or at least opaque)
|
||||
let cval : ConstantVal := { name := declName, levelParams, type := declType }
|
||||
if isProp then
|
||||
addDecl <|
|
||||
if view.modifiers.isUnsafe then
|
||||
-- Theorems cannot be unsafe.
|
||||
Declaration.opaqueDecl { cval with value := declVal, isUnsafe := true }
|
||||
else
|
||||
Declaration.thmDecl { cval with value := declVal }
|
||||
else
|
||||
addAndCompile <| Declaration.defnDecl { cval with
|
||||
value := declVal
|
||||
hints := ReducibilityHints.abbrev
|
||||
safety := if view.modifiers.isUnsafe then DefinitionSafety.unsafe else DefinitionSafety.safe
|
||||
}
|
||||
-- Logic from `mk_projections`: non-instance-implicits that aren't props become reducible.
|
||||
-- (Instances will get instance reducibility in `Lean.Elab.Command.addParentInstances`.)
|
||||
if !binfo.isInstImplicit && !(← Meta.isProp parentType) then
|
||||
setReducibleAttribute declName
|
||||
return { structName := parentStructName, subobject := false, projFn := declName }
|
||||
|
||||
private def elabStructHeader (view : StructView) : TermElabM ElabStructHeaderResult :=
|
||||
Term.withAutoBoundImplicitForbiddenPred (fun n => view.shortDeclName == n) do
|
||||
Term.withAutoBoundImplicit do
|
||||
Term.elabBinders view.binders.getArgs fun params => do
|
||||
elabParents view fun parentFieldInfos parents => do
|
||||
let type ← Term.elabType view.type
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let u ← mkFreshLevelMVar
|
||||
unless ← isDefEq type (mkSort u) do
|
||||
throwErrorAt view.type "invalid structure type, expecting 'Type _' or 'Prop'"
|
||||
let type ← instantiateMVars (← whnf type)
|
||||
Term.addAutoBoundImplicits' params type fun params type => do
|
||||
let levelNames ← Term.getLevelNames
|
||||
trace[Elab.structure] "header params: {params}, type: {type}, levelNames: {levelNames}"
|
||||
return { lctx := (← getLCtx), localInsts := (← getLocalInstances), levelNames, params, type, view, parents, parentFieldInfos }
|
||||
|
||||
private def mkTypeFor (r : ElabStructHeaderResult) : TermElabM Expr := do
|
||||
withLCtx r.lctx r.localInsts do
|
||||
mkForallFVars r.params r.type
|
||||
|
||||
/--
|
||||
Create a local declaration for the structure and execute `x params indFVar`, where `params` are the structure's type parameters and
|
||||
`indFVar` is the new local declaration.
|
||||
-/
|
||||
private partial def withStructureLocalDecl (r : ElabStructHeaderResult) (x : Array Expr → Expr → TermElabM α) : TermElabM α := do
|
||||
let declName := r.view.declName
|
||||
let shortDeclName := r.view.shortDeclName
|
||||
let type ← mkTypeFor r
|
||||
let params := r.params
|
||||
withLCtx r.lctx r.localInsts <| withRef r.view.ref do
|
||||
Term.withAuxDecl shortDeclName type declName fun indFVar =>
|
||||
x params indFVar
|
||||
|
||||
/--
|
||||
Remark: `numVars <= numParams`.
|
||||
`numVars` is the number of context `variables` used in the declaration,
|
||||
and `numParams - numVars` is the number of parameters provided as binders in the declaration.
|
||||
-/
|
||||
private def mkInductiveType (view : StructView) (indFVar : Expr) (levelNames : List Name)
|
||||
(numVars : Nat) (numParams : Nat) (type : Expr) (ctor : Constructor) : TermElabM InductiveType := do
|
||||
let levelParams := levelNames.map mkLevelParam
|
||||
let const := mkConst view.declName levelParams
|
||||
let ctorType ← forallBoundedTelescope ctor.type numParams fun params type => do
|
||||
let type := type.replace fun e =>
|
||||
if e == indFVar then
|
||||
mkAppN const (params.extract 0 numVars)
|
||||
else
|
||||
none
|
||||
instantiateMVars (← mkForallFVars params type)
|
||||
return { name := view.declName, type := ← instantiateMVars type, ctors := [{ ctor with type := ← instantiateMVars ctorType }] }
|
||||
|
||||
/--
|
||||
Precomputes the structure's resolution order.
|
||||
@@ -881,45 +987,109 @@ private def addParentInstances (parents : Array StructureParentInfo) : MetaM Uni
|
||||
for instParent in instParents do
|
||||
addInstance instParent.projFn AttributeKind.global (eval_prio default)
|
||||
|
||||
@[builtin_inductive_elab Lean.Parser.Command.«structure»]
|
||||
def elabStructureCommand : InductiveElabDescr where
|
||||
mkInductiveView (modifiers : Modifiers) (stx : Syntax) := do
|
||||
def mkStructureDecl (vars : Array Expr) (view : StructView) : TermElabM Unit := Term.withoutSavingRecAppSyntax do
|
||||
let scopeLevelNames ← Term.getLevelNames
|
||||
let isUnsafe := view.modifiers.isUnsafe
|
||||
withRef view.ref <| Term.withLevelNames view.levelNames do
|
||||
let r ← elabStructHeader view
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
withLCtx r.lctx r.localInsts do
|
||||
withStructureLocalDecl r fun params indFVar => do
|
||||
trace[Elab.structure] "indFVar: {indFVar}"
|
||||
Term.addLocalVarInfo view.declId indFVar
|
||||
withFields view.fields r.parentFieldInfos fun fieldInfos =>
|
||||
withRef view.ref do
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let type ← instantiateMVars r.type
|
||||
let u ← getResultUniverse type
|
||||
let univToInfer? ← shouldInferResultUniverse u
|
||||
withUsed vars params fieldInfos fun scopeVars => do
|
||||
let fieldInfos ← levelMVarToParam scopeVars params fieldInfos univToInfer?
|
||||
let type ← withRef view.ref do
|
||||
if univToInfer?.isSome then
|
||||
updateResultingUniverse fieldInfos type
|
||||
else
|
||||
checkResultingUniverse (← getResultUniverse type)
|
||||
pure type
|
||||
trace[Elab.structure] "type: {type}"
|
||||
let usedLevelNames ← collectLevelParamsInStructure type scopeVars params fieldInfos
|
||||
match sortDeclLevelParams scopeLevelNames r.levelNames usedLevelNames with
|
||||
| Except.error msg => throwErrorAt view.declId msg
|
||||
| Except.ok levelParams =>
|
||||
let params := scopeVars ++ params
|
||||
let ctor ← mkCtor view levelParams params fieldInfos
|
||||
let type ← mkForallFVars params type
|
||||
let type ← instantiateMVars type
|
||||
let indType ← mkInductiveType view indFVar levelParams scopeVars.size params.size type ctor
|
||||
let decl := Declaration.inductDecl levelParams params.size [indType] isUnsafe
|
||||
Term.ensureNoUnassignedMVars decl
|
||||
addDecl decl
|
||||
-- rename indFVar so that it does not shadow the actual declaration:
|
||||
let lctx := (← getLCtx).modifyLocalDecl indFVar.fvarId! fun decl => decl.setUserName .anonymous
|
||||
withLCtx lctx (← getLocalInstances) do
|
||||
addProjections r fieldInfos
|
||||
registerStructure view.declName fieldInfos
|
||||
mkAuxConstructions view.declName
|
||||
withSaveInfoContext do -- save new env
|
||||
Term.addLocalVarInfo view.ref[1] (← mkConstWithLevelParams view.declName)
|
||||
if let some _ := view.ctor.ref.getPos? (canonicalOnly := true) then
|
||||
Term.addTermInfo' view.ctor.ref (← mkConstWithLevelParams view.ctor.declName) (isBinder := true)
|
||||
for field in view.fields do
|
||||
-- may not exist if overriding inherited field
|
||||
if (← getEnv).contains field.declName then
|
||||
Term.addTermInfo' field.ref (← mkConstWithLevelParams field.declName) (isBinder := true)
|
||||
withRef view.declId do
|
||||
Term.applyAttributesAt view.declName view.modifiers.attrs AttributeApplicationTime.afterTypeChecking
|
||||
let parentInfos ← r.parents.mapM fun parent => do
|
||||
if parent.subobject then
|
||||
let some info := fieldInfos.find? (·.kind == .subobject parent.structName) | unreachable!
|
||||
pure { structName := parent.structName, subobject := true, projFn := info.declName }
|
||||
else
|
||||
mkCoercionToCopiedParent levelParams params view parent.structName parent.type
|
||||
setStructureParents view.declName parentInfos
|
||||
checkResolutionOrder view.declName
|
||||
if view.isClass then
|
||||
addParentInstances parentInfos
|
||||
|
||||
let lctx ← getLCtx
|
||||
/- The `lctx` and `defaultAuxDecls` are used to create the auxiliary "default value" declarations
|
||||
The parameters `params` for these definitions must be marked as implicit, and all others as explicit. -/
|
||||
let lctx :=
|
||||
params.foldl (init := lctx) fun (lctx : LocalContext) (p : Expr) =>
|
||||
if p.isFVar then
|
||||
lctx.setBinderInfo p.fvarId! BinderInfo.implicit
|
||||
else
|
||||
lctx
|
||||
let lctx :=
|
||||
fieldInfos.foldl (init := lctx) fun (lctx : LocalContext) (info : StructFieldInfo) =>
|
||||
if info.isFromParent then lctx -- `fromParent` fields are elaborated as let-decls, and are zeta-expanded when creating "default value" auxiliary functions
|
||||
else lctx.setBinderInfo info.fvar.fvarId! BinderInfo.default
|
||||
addDefaults lctx fieldInfos
|
||||
|
||||
|
||||
def elabStructureView (vars : Array Expr) (view : StructView) : TermElabM Unit := do
|
||||
Term.withDeclName view.declName <| withRef view.ref do
|
||||
mkStructureDecl vars view
|
||||
unless view.isClass do
|
||||
Lean.Meta.IndPredBelow.mkBelow view.declName
|
||||
mkSizeOfInstances view.declName
|
||||
mkInjectiveTheorems view.declName
|
||||
|
||||
def elabStructureViewPostprocessing (view : StructView) : CommandElabM Unit := do
|
||||
view.derivingClasses.forM fun classView => classView.applyHandlers #[view.declName]
|
||||
runTermElabM fun _ => Term.withDeclName view.declName <| withRef view.declId do
|
||||
Term.applyAttributesAt view.declName view.modifiers.attrs .afterCompilation
|
||||
|
||||
def elabStructure (modifiers : Modifiers) (stx : Syntax) : CommandElabM Unit := do
|
||||
let view ← runTermElabM fun vars => do
|
||||
let view ← structureSyntaxToView modifiers stx
|
||||
trace[Elab.structure] "view.levelNames: {view.levelNames}"
|
||||
return {
|
||||
view := view.toInductiveView
|
||||
elabCtors := fun rs r params => do
|
||||
withParents view rs r.indFVar fun parentFieldInfos parents =>
|
||||
withFields view.fields parentFieldInfos fun fieldInfos => do
|
||||
withRef view.ref do
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let lctx ← getLCtx
|
||||
let localInsts ← getLocalInstances
|
||||
let ctor ← mkCtor view r params fieldInfos
|
||||
return {
|
||||
ctors := [ctor]
|
||||
collectUsedFVars := collectUsedFVars lctx localInsts fieldInfos
|
||||
checkUniverses := fun _ u => withLCtx lctx localInsts do checkResultingUniversesForFields fieldInfos u
|
||||
finalizeTermElab := withLCtx lctx localInsts do checkDefaults fieldInfos
|
||||
prefinalize := fun _ _ _ => do
|
||||
withLCtx lctx localInsts do
|
||||
addProjections r fieldInfos
|
||||
registerStructure view.declName fieldInfos
|
||||
withSaveInfoContext do -- save new env
|
||||
for field in view.fields do
|
||||
-- may not exist if overriding inherited field
|
||||
if (← getEnv).contains field.declName then
|
||||
Term.addTermInfo' field.ref (← mkConstWithLevelParams field.declName) (isBinder := true)
|
||||
finalize := fun levelParams params replaceIndFVars => do
|
||||
let parentInfos ← mkRemainingProjections levelParams params view parents fieldInfos
|
||||
setStructureParents view.declName parentInfos
|
||||
checkResolutionOrder view.declName
|
||||
if view.isClass then
|
||||
addParentInstances parentInfos
|
||||
elabStructureView vars view
|
||||
pure view
|
||||
elabStructureViewPostprocessing view
|
||||
|
||||
withLCtx lctx localInsts do
|
||||
addDefaults params replaceIndFVars fieldInfos
|
||||
}
|
||||
}
|
||||
builtin_initialize
|
||||
registerTraceClass `Elab.structure
|
||||
registerTraceClass `Elab.structure.resolutionOrder
|
||||
|
||||
end Lean.Elab.Command
|
||||
|
||||
@@ -58,28 +58,15 @@ where
|
||||
let some eqBVPred ← ReifiedBVPred.mkBinPred atom resValExpr atomExpr resExpr .eq | return none
|
||||
let eqBV ← ReifiedBVLogical.ofPred eqBVPred
|
||||
|
||||
let trueExpr := mkConst ``Bool.true
|
||||
let impExpr ← mkArrow (← mkEq eqDiscrExpr trueExpr) (← mkEq eqBVExpr trueExpr)
|
||||
let decideImpExpr ← mkAppOptM ``Decidable.decide #[some impExpr, none]
|
||||
let imp ← ReifiedBVLogical.mkGate eqDiscr eqBV eqDiscrExpr eqBVExpr .imp
|
||||
|
||||
let proof := do
|
||||
let evalExpr ← ReifiedBVLogical.mkEvalExpr imp.expr
|
||||
let congrProof ← imp.evalsAtAtoms
|
||||
let lemmaProof := mkApp4 (mkConst lemmaName) (toExpr lhs.width) discrExpr lhsExpr rhsExpr
|
||||
|
||||
let trueExpr := mkConst ``Bool.true
|
||||
let eqDiscrTrueExpr ← mkEq eqDiscrExpr trueExpr
|
||||
let eqBVExprTrueExpr ← mkEq eqBVExpr trueExpr
|
||||
let impExpr ← mkArrow eqDiscrTrueExpr eqBVExprTrueExpr
|
||||
-- construct a `Decidable` instance for the implication using forall_prop_decidable
|
||||
let decEqDiscrTrue := mkApp2 (mkConst ``instDecidableEqBool) eqDiscrExpr trueExpr
|
||||
let decEqBVExprTrue := mkApp2 (mkConst ``instDecidableEqBool) eqBVExpr trueExpr
|
||||
let impDecidable := mkApp4 (mkConst ``forall_prop_decidable)
|
||||
eqDiscrTrueExpr
|
||||
(.lam .anonymous eqDiscrTrueExpr eqBVExprTrueExpr .default)
|
||||
decEqDiscrTrue
|
||||
(.lam .anonymous eqDiscrTrueExpr decEqBVExprTrue .default)
|
||||
|
||||
let decideImpExpr := mkApp2 (mkConst ``Decidable.decide) impExpr impDecidable
|
||||
|
||||
return mkApp4
|
||||
(mkConst ``Std.Tactic.BVDecide.Reflect.Bool.lemma_congr)
|
||||
decideImpExpr
|
||||
|
||||
@@ -40,7 +40,7 @@ private def tryAssumption (mvarId : MVarId) : MetaM (List MVarId) := do
|
||||
let ids := stx[1].getArgs.toList.map getNameOfIdent'
|
||||
liftMetaTactic fun mvarId => do
|
||||
match (← Meta.injections mvarId ids) with
|
||||
| .solved => checkUnusedIds `injections mvarId ids; return []
|
||||
| .subgoal mvarId' unusedIds _ => checkUnusedIds `injections mvarId unusedIds; tryAssumption mvarId'
|
||||
| .solved => checkUnusedIds `injections mvarId ids; return []
|
||||
| .subgoal mvarId' unusedIds => checkUnusedIds `injections mvarId unusedIds; tryAssumption mvarId'
|
||||
|
||||
end Lean.Elab.Tactic
|
||||
|
||||
@@ -473,26 +473,6 @@ def withoutTacticReuse [Monad m] [MonadWithReaderOf Context m] [MonadOptions m]
|
||||
return !cond }
|
||||
}) act
|
||||
|
||||
@[inherit_doc Core.wrapAsync]
|
||||
def wrapAsync (act : Unit → TermElabM α) : TermElabM (EIO Exception α) := do
|
||||
let ctx ← read
|
||||
let st ← get
|
||||
let metaCtx ← readThe Meta.Context
|
||||
let metaSt ← getThe Meta.State
|
||||
Core.wrapAsync fun _ =>
|
||||
act () |>.run ctx |>.run' st |>.run' metaCtx metaSt
|
||||
|
||||
@[inherit_doc Core.wrapAsyncAsSnapshot]
|
||||
def wrapAsyncAsSnapshot (act : Unit → TermElabM Unit)
|
||||
(desc : String := by exact decl_name%.toString) :
|
||||
TermElabM (BaseIO Language.SnapshotTree) := do
|
||||
let ctx ← read
|
||||
let st ← get
|
||||
let metaCtx ← readThe Meta.Context
|
||||
let metaSt ← getThe Meta.State
|
||||
Core.wrapAsyncAsSnapshot (desc := desc) fun _ =>
|
||||
act () |>.run ctx |>.run' st |>.run' metaCtx metaSt
|
||||
|
||||
abbrev TermElabResult (α : Type) := EStateM.Result Exception SavedState α
|
||||
|
||||
/--
|
||||
@@ -1079,9 +1059,7 @@ def synthesizeInstMVarCore (instMVar : MVarId) (maxResultSize? : Option Nat := n
|
||||
let oldValType ← inferType oldVal
|
||||
let valType ← inferType val
|
||||
unless (← isDefEq oldValType valType) do
|
||||
let (oldValType, valType) ← addPPExplicitToExposeDiff oldValType valType
|
||||
throwError "synthesized type class instance type is not definitionally equal to expected type, synthesized{indentExpr val}\nhas type{indentExpr valType}\nexpected{indentExpr oldValType}{extraErrorMsg}"
|
||||
let (oldVal, val) ← addPPExplicitToExposeDiff oldVal val
|
||||
throwError "synthesized type class instance is not definitionally equal to expression inferred by typing rules, synthesized{indentExpr val}\ninferred{indentExpr oldVal}{extraErrorMsg}"
|
||||
else
|
||||
unless (← isDefEq (mkMVar instMVar) val) do
|
||||
|
||||
@@ -897,18 +897,13 @@ def finalizeImport (s : ImportState) (imports : Array Import) (opts : Options) (
|
||||
initialized constant. We have seen significant savings in `open Mathlib`
|
||||
timings, where we have both a big environment and interpreted environment
|
||||
extensions, from this. There is no significant extra cost to calling
|
||||
`markPersistent` multiple times like this.
|
||||
|
||||
Safety: There are no concurrent accesses to `env` at this point. -/
|
||||
env := unsafe Runtime.markPersistent env
|
||||
`markPersistent` multiple times like this. -/
|
||||
env := Runtime.markPersistent env
|
||||
env ← finalizePersistentExtensions env s.moduleData opts
|
||||
if leakEnv then
|
||||
/- Ensure the final environment including environment extension states is
|
||||
marked persistent as documented.
|
||||
|
||||
Safety: There are no concurrent accesses to `env` at this point, assuming
|
||||
extensions' `addImportFn`s did not spawn any unbound tasks. -/
|
||||
env := unsafe Runtime.markPersistent env
|
||||
marked persistent as documented. -/
|
||||
env := Runtime.markPersistent env
|
||||
pure env
|
||||
|
||||
@[export lean_import_modules]
|
||||
|
||||
@@ -10,8 +10,9 @@ Authors: Sebastian Ullrich
|
||||
|
||||
prelude
|
||||
import Init.System.Promise
|
||||
import Lean.Message
|
||||
import Lean.Parser.Types
|
||||
import Lean.Util.Trace
|
||||
import Lean.Elab.InfoTree
|
||||
|
||||
set_option linter.missingDocs true
|
||||
|
||||
@@ -197,13 +198,32 @@ def withAlwaysResolvedPromises [Monad m] [MonadLiftT BaseIO m] [MonadFinally m]
|
||||
for asynchronously collecting information about the entirety of snapshots in the language server.
|
||||
The involved tasks may form a DAG on the `Task` dependency level but this is not captured by this
|
||||
data structure. -/
|
||||
structure SnapshotTree where
|
||||
/-- The immediately available element of the snapshot tree node. -/
|
||||
element : Snapshot
|
||||
/-- The asynchronously available children of the snapshot tree node. -/
|
||||
children : Array (SnapshotTask SnapshotTree)
|
||||
inductive SnapshotTree where
|
||||
/-- Creates a snapshot tree node. -/
|
||||
| mk (element : Snapshot) (children : Array (SnapshotTask SnapshotTree))
|
||||
deriving Inhabited
|
||||
|
||||
/-- The immediately available element of the snapshot tree node. -/
|
||||
abbrev SnapshotTree.element : SnapshotTree → Snapshot
|
||||
| mk s _ => s
|
||||
/-- The asynchronously available children of the snapshot tree node. -/
|
||||
abbrev SnapshotTree.children : SnapshotTree → Array (SnapshotTask SnapshotTree)
|
||||
| mk _ children => children
|
||||
|
||||
/-- Produces trace of given snapshot tree, synchronously waiting on all children. -/
|
||||
partial def SnapshotTree.trace (s : SnapshotTree) : CoreM Unit :=
|
||||
go none s
|
||||
where go range? s := do
|
||||
let file ← getFileMap
|
||||
let mut desc := f!"{s.element.desc}"
|
||||
if let some range := range? then
|
||||
desc := desc ++ f!"{file.toPosition range.start}-{file.toPosition range.stop} "
|
||||
desc := desc ++ .prefixJoin "\n• " (← s.element.diagnostics.msgLog.toList.mapM (·.toString))
|
||||
if let some t := s.element.infoTree? then
|
||||
trace[Elab.info] (← t.format)
|
||||
withTraceNode `Elab.snapshotTree (fun _ => pure desc) do
|
||||
s.children.toList.forM fun c => go c.range? c.get
|
||||
|
||||
/--
|
||||
Helper class for projecting a heterogeneous hierarchy of snapshot classes to a homogeneous
|
||||
representation. -/
|
||||
@@ -288,14 +308,6 @@ def SnapshotTree.runAndReport (s : SnapshotTree) (opts : Options) (json := false
|
||||
def SnapshotTree.getAll (s : SnapshotTree) : Array Snapshot :=
|
||||
s.forM (m := StateM _) (fun s => modify (·.push s)) |>.run #[] |>.2
|
||||
|
||||
/-- Returns a task that waits on all snapshots in the tree. -/
|
||||
def SnapshotTree.waitAll : SnapshotTree → BaseIO (Task Unit)
|
||||
| mk _ children => go children.toList
|
||||
where
|
||||
go : List (SnapshotTask SnapshotTree) → BaseIO (Task Unit)
|
||||
| [] => return .pure ()
|
||||
| t::ts => BaseIO.bindTask t.task fun _ => go ts
|
||||
|
||||
/-- Context of an input processing invocation. -/
|
||||
structure ProcessingContext extends Parser.InputContext
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ Authors: Sebastian Ullrich
|
||||
|
||||
prelude
|
||||
import Lean.Language.Basic
|
||||
import Lean.Language.Util
|
||||
import Lean.Language.Lean.Types
|
||||
import Lean.Parser.Module
|
||||
import Lean.Elab.Import
|
||||
@@ -46,33 +45,17 @@ delete the space after private, it becomes a syntactically correct structure wit
|
||||
privateaxiom! So clearly, because of uses of atomic in the grammar, an edit can affect a command
|
||||
syntax tree even across multiple tokens.
|
||||
|
||||
What we did in Lean 3 was to always reparse the last command completely preceding the edit location.
|
||||
If its syntax tree is unchanged, we preserve its data and reprocess all following commands only,
|
||||
otherwise we reprocess it fully as well. This worked well but did seem a bit arbitrary given that
|
||||
even if it works for a grammar at some point, it can certainly be extended in ways that break the
|
||||
assumption.
|
||||
|
||||
With grammar changes in Lean 4, we found that the following example indeed breaks this assumption:
|
||||
```
|
||||
structure Signature where
|
||||
/-- a docstring -/
|
||||
Sort : Type
|
||||
--^ insert: "s"
|
||||
```
|
||||
As the keyword `Sort` is not a valid start of a structure field and the parser backtracks across the
|
||||
docstring in that case, this is parsed as the complete command `structure Signature where` followed
|
||||
by the partial command `/-- a docstring -/ <missing>`. If we insert an `s` after the `t`, the last
|
||||
command completely preceding the edit location is the partial command containing the docstring. Thus
|
||||
we need to go up two commands to ensure we reparse the `structure` command as well. This kind of
|
||||
nested docstring is the only part of the grammar to our knowledge that requires going up at least
|
||||
two commands; as we never backtrack across more than one docstring, going up two commands should
|
||||
also be sufficient.
|
||||
Now, what we do today, and have done since Lean 3, is to always reparse the last command completely
|
||||
preceding the edit location. If its syntax tree is unchanged, we preserve its data and reprocess all
|
||||
following commands only, otherwise we reprocess it fully as well. This seems to have worked well so
|
||||
far but it does seem a bit arbitrary given that even if it works for our current grammar, it can
|
||||
certainly be extended in ways that break the assumption.
|
||||
|
||||
Finally, a more actually principled and generic solution would be to invalidate a syntax tree when
|
||||
the parser has reached the edit location during parsing. If it did not, surely the edit cannot have
|
||||
an effect on the syntax tree in question. Sadly such a "high-water mark" parser position does not
|
||||
exist currently and likely it could at best be approximated by e.g. "furthest `tokenFn` parse". Thus
|
||||
we remain at "go up two commands" at this point.
|
||||
we remain at "go two commands up" at this point.
|
||||
-/
|
||||
|
||||
/-!
|
||||
@@ -183,6 +166,13 @@ namespace Lean.Language.Lean
|
||||
open Lean.Elab Command
|
||||
open Lean.Parser
|
||||
|
||||
/-- Option for capturing output to stderr during elaboration. -/
|
||||
register_builtin_option stderrAsMessages : Bool := {
|
||||
defValue := true
|
||||
group := "server"
|
||||
descr := "(server) capture output to the Lean stderr channel (such as from `dbg_trace`) during elaboration of a command as a diagnostic message"
|
||||
}
|
||||
|
||||
/-- Lean-specific processing context. -/
|
||||
structure LeanProcessingContext extends ProcessingContext where
|
||||
/-- Position of the first file difference if there was a previous invocation. -/
|
||||
@@ -356,12 +346,11 @@ where
|
||||
if let some old := old? then
|
||||
if let some oldSuccess := old.result? then
|
||||
if let some (some processed) ← old.processedResult.get? then
|
||||
-- ...and the edit is after the second-next command (see note [Incremental Parsing])...
|
||||
-- ...and the edit location is after the next command (see note [Incremental Parsing])...
|
||||
if let some nextCom ← processed.firstCmdSnap.get? then
|
||||
if let some nextNextCom ← processed.firstCmdSnap.get? then
|
||||
if (← isBeforeEditPos nextNextCom.parserState.pos) then
|
||||
-- ...go immediately to next snapshot
|
||||
return (← unchanged old old.stx oldSuccess.parserState)
|
||||
if (← isBeforeEditPos nextCom.parserState.pos) then
|
||||
-- ...go immediately to next snapshot
|
||||
return (← unchanged old old.stx oldSuccess.parserState)
|
||||
|
||||
withHeaderExceptions ({ · with
|
||||
ictx, stx := .missing, result? := none, cancelTk? := none }) do
|
||||
@@ -454,6 +443,11 @@ where
|
||||
traceState
|
||||
}
|
||||
let prom ← IO.Promise.new
|
||||
-- The speedup of these `markPersistent`s is negligible but they help in making unexpected
|
||||
-- `inc_ref_cold`s more visible
|
||||
let parserState := Runtime.markPersistent parserState
|
||||
let cmdState := Runtime.markPersistent cmdState
|
||||
let ctx := Runtime.markPersistent ctx
|
||||
parseCmd none parserState cmdState prom (sync := true) ctx
|
||||
return {
|
||||
diagnostics
|
||||
@@ -485,12 +479,11 @@ where
|
||||
prom.resolve <| { old with nextCmdSnap? := some { range? := none, task := newProm.result } }
|
||||
else prom.resolve old -- terminal command, we're done!
|
||||
|
||||
-- fast path, do not even start new task for this snapshot (see [Incremental Parsing])
|
||||
-- fast path, do not even start new task for this snapshot
|
||||
if let some old := old? then
|
||||
if let some nextCom ← old.nextCmdSnap?.bindM (·.get?) then
|
||||
if let some nextNextCom ← nextCom.nextCmdSnap?.bindM (·.get?) then
|
||||
if (← isBeforeEditPos nextNextCom.parserState.pos) then
|
||||
return (← unchanged old old.parserState)
|
||||
if (← isBeforeEditPos nextCom.parserState.pos) then
|
||||
return (← unchanged old old.parserState)
|
||||
|
||||
let beginPos := parserState.pos
|
||||
let scope := cmdState.scopes.head!
|
||||
@@ -526,7 +519,6 @@ where
|
||||
diagnostics := .empty, stx := .missing, parserState
|
||||
elabSnap := .pure <| .ofTyped { diagnostics := .empty : SnapshotLeaf }
|
||||
finishedSnap := .pure { diagnostics := .empty, cmdState }
|
||||
reportSnap := default
|
||||
tacticCache := (← IO.mkRef {})
|
||||
nextCmdSnap? := none
|
||||
}
|
||||
@@ -537,7 +529,6 @@ where
|
||||
-- definitely resolved in `doElab` task
|
||||
let elabPromise ← IO.Promise.new
|
||||
let finishedPromise ← IO.Promise.new
|
||||
let reportPromise ← IO.Promise.new
|
||||
-- (Try to) use last line of command as range for final snapshot task. This ensures we do not
|
||||
-- retract the progress bar to a previous position in case the command support incremental
|
||||
-- reporting but has significant work after resolving its last incremental promise, such as
|
||||
@@ -556,46 +547,22 @@ where
|
||||
let nextCmdSnap? := next?.map
|
||||
({ range? := some ⟨parserState.pos, ctx.input.endPos⟩, task := ·.result })
|
||||
let diagnostics ← Snapshot.Diagnostics.ofMessageLog msgLog
|
||||
let (stx', parserState') := if minimalSnapshots && !Parser.isTerminalCommand stx then
|
||||
(default, default)
|
||||
if minimalSnapshots && !Parser.isTerminalCommand stx then
|
||||
prom.resolve {
|
||||
diagnostics, finishedSnap, tacticCache, nextCmdSnap?
|
||||
stx := .missing
|
||||
parserState := {}
|
||||
elabSnap := { range? := stx.getRange?, task := elabPromise.result }
|
||||
}
|
||||
else
|
||||
(stx, parserState)
|
||||
prom.resolve {
|
||||
diagnostics, finishedSnap, tacticCache, nextCmdSnap?
|
||||
stx := stx', parserState := parserState'
|
||||
elabSnap := { range? := stx.getRange?, task := elabPromise.result }
|
||||
reportSnap := { range? := endRange?, task := reportPromise.result }
|
||||
}
|
||||
prom.resolve {
|
||||
diagnostics, stx, parserState, tacticCache, nextCmdSnap?
|
||||
elabSnap := { range? := stx.getRange?, task := elabPromise.result }
|
||||
finishedSnap
|
||||
}
|
||||
let cmdState ← doElab stx cmdState beginPos
|
||||
{ old? := old?.map fun old => ⟨old.stx, old.elabSnap⟩, new := elabPromise }
|
||||
finishedPromise tacticCache ctx
|
||||
let traceTask ←
|
||||
if (← isTracingEnabledForCore `Elab.snapshotTree cmdState.scopes.head!.opts) then
|
||||
-- We want to trace all of `CommandParsedSnapshot` but `traceTask` is part of it, so let's
|
||||
-- create a temporary snapshot tree containing all tasks but it
|
||||
let snaps := #[
|
||||
{ range? := none, task := elabPromise.result.map (sync := true) toSnapshotTree },
|
||||
{ range? := none, task := finishedPromise.result.map (sync := true) toSnapshotTree }] ++
|
||||
cmdState.snapshotTasks
|
||||
let tree := SnapshotTree.mk { diagnostics := .empty } snaps
|
||||
BaseIO.bindTask (← tree.waitAll) fun _ => do
|
||||
let .ok (_, s) ← EIO.toBaseIO <| tree.trace |>.run
|
||||
{ ctx with options := cmdState.scopes.head!.opts } { env := cmdState.env }
|
||||
| pure <| .pure <| .mk { diagnostics := .empty } #[]
|
||||
let mut msgLog := MessageLog.empty
|
||||
for trace in s.traceState.traces do
|
||||
msgLog := msgLog.add {
|
||||
fileName := ctx.fileName
|
||||
severity := MessageSeverity.information
|
||||
pos := ctx.fileMap.toPosition beginPos
|
||||
data := trace.msg
|
||||
}
|
||||
return .pure <| .mk { diagnostics := (← Snapshot.Diagnostics.ofMessageLog msgLog) } #[]
|
||||
else
|
||||
pure <| .pure <| .mk { diagnostics := .empty } #[]
|
||||
reportPromise.resolve <|
|
||||
.mk { diagnostics := .empty } <|
|
||||
cmdState.snapshotTasks.push { range? := endRange?, task := traceTask }
|
||||
if let some next := next? then
|
||||
-- We're definitely off the fast-forwarding path now
|
||||
parseCmd none parserState cmdState next (sync := false) ctx
|
||||
@@ -606,9 +573,7 @@ where
|
||||
LeanProcessingM Command.State := do
|
||||
let ctx ← read
|
||||
let scope := cmdState.scopes.head!
|
||||
-- reset per-command state
|
||||
let cmdStateRef ← IO.mkRef { cmdState with
|
||||
messages := .empty, traceState := {}, snapshotTasks := #[] }
|
||||
let cmdStateRef ← IO.mkRef { cmdState with messages := .empty, traceState := {} }
|
||||
/-
|
||||
The same snapshot may be executed by different tasks. So, to make sure `elabCommandTopLevel`
|
||||
has exclusive access to the cache, we create a fresh reference here. Before this change, the
|
||||
@@ -623,8 +588,8 @@ where
|
||||
cancelTk? := some ctx.newCancelTk
|
||||
}
|
||||
let (output, _) ←
|
||||
IO.FS.withIsolatedStreams (isolateStderr := Core.stderrAsMessages.get scope.opts) do
|
||||
EIO.toBaseIO do
|
||||
IO.FS.withIsolatedStreams (isolateStderr := stderrAsMessages.get scope.opts) do
|
||||
liftM (m := BaseIO) do
|
||||
withLoggingExceptions
|
||||
(getResetInfoTrees *> Elab.Command.elabCommandTopLevel stx)
|
||||
cmdCtx cmdStateRef
|
||||
@@ -643,24 +608,16 @@ where
|
||||
-- definitely resolve eventually
|
||||
snap.new.resolve <| .ofTyped { diagnostics := .empty : SnapshotLeaf }
|
||||
|
||||
let mut infoTree : InfoTree := cmdState.infoState.trees[0]!
|
||||
let mut infoTree := cmdState.infoState.trees[0]!
|
||||
let cmdline := internal.cmdlineSnapshots.get scope.opts && !Parser.isTerminalCommand stx
|
||||
if cmdline && !Elab.async.get scope.opts then
|
||||
/-
|
||||
Safety: `infoTree` was created by `elabCommandTopLevel`. Thus it
|
||||
should not have any concurrent accesses if we are on the cmdline and
|
||||
async elaboration is disabled.
|
||||
-/
|
||||
-- TODO: we should likely remove this call when `Elab.async` is turned on
|
||||
-- by default
|
||||
infoTree := unsafe Runtime.markPersistent infoTree
|
||||
if cmdline then
|
||||
infoTree := Runtime.markPersistent infoTree
|
||||
finishedPromise.resolve {
|
||||
diagnostics := (← Snapshot.Diagnostics.ofMessageLog cmdState.messages)
|
||||
infoTree? := infoTree
|
||||
traces := cmdState.traceState
|
||||
cmdState := if cmdline then {
|
||||
/- Safety: as above -/
|
||||
env := unsafe Runtime.markPersistent cmdState.env
|
||||
env := Runtime.markPersistent cmdState.env
|
||||
maxRecDepth := 0
|
||||
} else cmdState
|
||||
}
|
||||
|
||||
@@ -46,8 +46,6 @@ structure CommandParsedSnapshot extends Snapshot where
|
||||
elabSnap : SnapshotTask DynamicSnapshot
|
||||
/-- State after processing is finished. -/
|
||||
finishedSnap : SnapshotTask CommandFinishedSnapshot
|
||||
/-- Additional, untyped snapshots used for reporting, not reuse. -/
|
||||
reportSnap : SnapshotTask SnapshotTree
|
||||
/-- Cache for `save`; to be replaced with incrementality. -/
|
||||
tacticCache : IO.Ref Tactic.Cache
|
||||
/-- Next command, unless this is a terminal command. -/
|
||||
@@ -57,8 +55,7 @@ partial instance : ToSnapshotTree CommandParsedSnapshot where
|
||||
toSnapshotTree := go where
|
||||
go s := ⟨s.toSnapshot,
|
||||
#[s.elabSnap.map (sync := true) toSnapshotTree,
|
||||
s.finishedSnap.map (sync := true) toSnapshotTree,
|
||||
s.reportSnap] |>
|
||||
s.finishedSnap.map (sync := true) toSnapshotTree] |>
|
||||
pushOpt (s.nextCmdSnap?.map (·.map (sync := true) go))⟩
|
||||
|
||||
/-- State after successful importing. -/
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
/-
|
||||
Copyright (c) 2023 Lean FRO. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
|
||||
Additional snapshot functionality that needs further imports.
|
||||
|
||||
Authors: Sebastian Ullrich
|
||||
-/
|
||||
|
||||
prelude
|
||||
import Lean.Language.Basic
|
||||
import Lean.CoreM
|
||||
import Lean.Elab.InfoTree
|
||||
|
||||
namespace Lean.Language
|
||||
|
||||
/-- Produces trace of given snapshot tree, synchronously waiting on all children. -/
|
||||
partial def SnapshotTree.trace (s : SnapshotTree) : CoreM Unit :=
|
||||
go none s
|
||||
where go range? s := do
|
||||
let file ← getFileMap
|
||||
let mut desc := f!"{s.element.desc}"
|
||||
if let some range := range? then
|
||||
desc := desc ++ f!"{file.toPosition range.start}-{file.toPosition range.stop} "
|
||||
desc := desc ++ .prefixJoin "\n• " (← s.element.diagnostics.msgLog.toList.mapM (·.toString))
|
||||
if let some t := s.element.infoTree? then
|
||||
trace[Elab.info] (← t.format)
|
||||
withTraceNode `Elab.snapshotTree (fun _ => pure desc) do
|
||||
s.children.toList.forM fun c => go c.range? c.get
|
||||
@@ -31,10 +31,6 @@ builtin_initialize deprecatedAttr : ParametricAttribute DeprecationEntry ←
|
||||
let newName? ← id?.mapM Elab.realizeGlobalConstNoOverloadWithInfo
|
||||
let text? := text?.map TSyntax.getString
|
||||
let since? := since?.map TSyntax.getString
|
||||
if id?.isNone && text?.isNone then
|
||||
logWarning "`[deprecated]` attribute should specify either a new name or a deprecation message"
|
||||
if since?.isNone then
|
||||
logWarning "`[deprecated]` attribute should specify the date or library version at which the deprecation was introduced, using `(since := \"...\")`"
|
||||
return { newName?, text?, since? }
|
||||
}
|
||||
|
||||
@@ -50,9 +46,7 @@ def getDeprecatedNewName (env : Environment) (declName : Name) : Option Name :=
|
||||
def checkDeprecated [Monad m] [MonadEnv m] [MonadLog m] [AddMessageContext m] [MonadOptions m] (declName : Name) : m Unit := do
|
||||
if getLinterValue linter.deprecated (← getOptions) then
|
||||
let some attr := deprecatedAttr.getParam? (← getEnv) declName | pure ()
|
||||
logWarning <| .tagged ``deprecatedAttr <|
|
||||
s!"`{declName}` has been deprecated" ++ match attr.text? with
|
||||
| some text => s!": {text}"
|
||||
| none => match attr.newName? with
|
||||
| some newName => s!": use `{newName}` instead"
|
||||
| none => ""
|
||||
logWarning <| .tagged ``deprecatedAttr <| attr.text?.getD <|
|
||||
match attr.newName? with
|
||||
| none => s!"`{declName}` has been deprecated"
|
||||
| some newName => s!"`{declName}` has been deprecated, use `{newName}` instead"
|
||||
|
||||
@@ -5,40 +5,14 @@ Authors: Mac Malone
|
||||
-/
|
||||
prelude
|
||||
import Init.System.IO
|
||||
|
||||
namespace Lean
|
||||
|
||||
/--
|
||||
Dynamically loads a shared library so that its symbols can be used by
|
||||
the Lean interpreter (e.g., for interpreting `@[extern]` declarations).
|
||||
Equivalent to passing `--load-dynlib=path` to `lean`.
|
||||
Equivalent to passing `--load-dynlib=lib` to `lean`.
|
||||
|
||||
**Lean never unloads libraries.** Attempting to load a library that defines
|
||||
symbols shared with a previously loaded library (including itself) will error.
|
||||
If multiple libraries share common symbols, those symbols should be linked
|
||||
and loaded as separate libraries.
|
||||
Note that Lean never unloads libraries.
|
||||
-/
|
||||
@[extern "lean_load_dynlib"]
|
||||
opaque loadDynlib (path : @& System.FilePath) : IO Unit
|
||||
|
||||
/--
|
||||
Loads a Lean plugin and runs its initializers.
|
||||
|
||||
A Lean plugin is a shared library built from a Lean module.
|
||||
This means it has an `initialize_<module-name>` symbol that runs the
|
||||
module's initializers (including its imports' initializers). Initializers
|
||||
are declared with the `initialize` or `builtin_initialize` commands.
|
||||
|
||||
This is similar to passing `--plugin=path` to `lean`.
|
||||
Lean environment initializers, such as definitions calling
|
||||
`registerEnvExtension`, also require `Lean.initializing` to be `true`.
|
||||
To enable them, use `loadPlugin` within a `withImporting` block. This will
|
||||
set `Lean.initializing` (but not `IO.initializing`).
|
||||
|
||||
**Lean never unloads plugins.** Attempting to load a plugin that defines
|
||||
symbols shared with a previously loaded plugin (including itself) will error.
|
||||
If multiple plugins share common symbols (e.g., imports), those symbols
|
||||
should be linked and loaded separately.
|
||||
-/
|
||||
@[extern "lean_load_plugin"]
|
||||
opaque loadPlugin (path : @& System.FilePath) : IO Unit
|
||||
|
||||
@@ -406,8 +406,8 @@ def isSubPrefixOf (lctx₁ lctx₂ : LocalContext) (exceptFVars : Array Expr :=
|
||||
|
||||
@[inline] def mkBinding (isLambda : Bool) (lctx : LocalContext) (xs : Array Expr) (b : Expr) : Expr :=
|
||||
let b := b.abstract xs
|
||||
xs.size.foldRev (init := b) fun i _ b =>
|
||||
let x := xs[i]
|
||||
xs.size.foldRev (init := b) fun i b =>
|
||||
let x := xs[i]!
|
||||
match lctx.findFVar? x with
|
||||
| some (.cdecl _ _ n ty bi _) =>
|
||||
let ty := ty.abstractRange i xs;
|
||||
@@ -457,7 +457,7 @@ def sanitizeNames (lctx : LocalContext) : StateM NameSanitizerState LocalContext
|
||||
let st ← get
|
||||
if !getSanitizeNames st.options then pure lctx else
|
||||
StateT.run' (s := ({} : NameSet)) <|
|
||||
lctx.decls.size.foldRevM (init := lctx) fun i _ lctx => do
|
||||
lctx.decls.size.foldRevM (init := lctx) fun i lctx => do
|
||||
match lctx.decls[i]! with
|
||||
| none => pure lctx
|
||||
| some decl =>
|
||||
|
||||
@@ -825,7 +825,7 @@ def mkFreshExprMVarWithId (mvarId : MVarId) (type? : Option Expr := none) (kind
|
||||
mkFreshExprMVarWithIdCore mvarId type kind userName
|
||||
|
||||
def mkFreshLevelMVars (num : Nat) : MetaM (List Level) :=
|
||||
num.foldM (init := []) fun _ _ us =>
|
||||
num.foldM (init := []) fun _ us =>
|
||||
return (← mkFreshLevelMVar)::us
|
||||
|
||||
def mkFreshLevelMVarsFor (info : ConstantInfo) : MetaM (List Level) :=
|
||||
|
||||
@@ -286,8 +286,8 @@ partial def process : ClosureM Unit := do
|
||||
@[inline] def mkBinding (isLambda : Bool) (decls : Array LocalDecl) (b : Expr) : Expr :=
|
||||
let xs := decls.map LocalDecl.toExpr
|
||||
let b := b.abstract xs
|
||||
decls.size.foldRev (init := b) fun i _ b =>
|
||||
let decl := decls[i]
|
||||
decls.size.foldRev (init := b) fun i b =>
|
||||
let decl := decls[i]!
|
||||
match decl with
|
||||
| .cdecl _ _ n ty bi _ =>
|
||||
let ty := ty.abstractRange i xs
|
||||
|
||||
@@ -152,7 +152,7 @@ where
|
||||
let run (x : StateRefT Nat MetaM Expr) : MetaM (Expr × Nat) := StateRefT'.run x 0
|
||||
let (_, cnt) ← run <| transform domain fun e => do
|
||||
if let some name := e.constName? then
|
||||
if ctx.typeInfos.any fun indVal => indVal.name == name then
|
||||
if let some _ := ctx.typeInfos.findIdx? fun indVal => indVal.name == name then
|
||||
modify (· + 1)
|
||||
return .continue
|
||||
|
||||
|
||||
@@ -566,9 +566,9 @@ Append results to array
|
||||
partial def appendResultsAux (mr : MatchResult α) (a : Array β) (f : Nat → α → β) : Array β :=
|
||||
let aa := mr.elts
|
||||
let n := aa.size
|
||||
Nat.fold (n := n) (init := a) fun i _ r =>
|
||||
Nat.fold (n := n) (init := a) fun i r =>
|
||||
let j := n-1-i
|
||||
let b := aa[j]
|
||||
let b := aa[j]!
|
||||
b.foldl (init := r) (· ++ ·.map (f j))
|
||||
|
||||
partial def appendResults (mr : MatchResult α) (a : Array α) : Array α :=
|
||||
|
||||
@@ -882,7 +882,7 @@ def mkMatcher (input : MkMatcherInput) (exceptionIfContainsSorry := false) : Met
|
||||
| none => pure ()
|
||||
|
||||
trace[Meta.Match.debug] "matcher: {matcher}"
|
||||
let unusedAltIdxs := lhss.length.fold (init := []) fun i _ r =>
|
||||
let unusedAltIdxs := lhss.length.fold (init := []) fun i r =>
|
||||
if s.used.contains i then r else i::r
|
||||
return {
|
||||
matcher,
|
||||
|
||||
@@ -222,16 +222,16 @@ private def contradiction (mvarId : MVarId) : MetaM Bool :=
|
||||
Auxiliary tactic that tries to replace as many variables as possible and then apply `contradiction`.
|
||||
We use it to discard redundant hypotheses.
|
||||
-/
|
||||
partial def trySubstVarsAndContradiction (mvarId : MVarId) (forbidden : FVarIdSet := {}) : MetaM Bool :=
|
||||
partial def trySubstVarsAndContradiction (mvarId : MVarId) : MetaM Bool :=
|
||||
commitWhen do
|
||||
let mvarId ← substVars mvarId
|
||||
match (← injections mvarId (forbidden := forbidden)) with
|
||||
match (← injections mvarId) with
|
||||
| .solved => return true -- closed goal
|
||||
| .subgoal mvarId' _ forbidden =>
|
||||
| .subgoal mvarId' _ =>
|
||||
if mvarId' == mvarId then
|
||||
contradiction mvarId
|
||||
else
|
||||
trySubstVarsAndContradiction mvarId' forbidden
|
||||
trySubstVarsAndContradiction mvarId'
|
||||
|
||||
private def processNextEq : M Bool := do
|
||||
let s ← get
|
||||
@@ -375,8 +375,7 @@ private partial def withSplitterAlts (altTypes : Array Expr) (f : Array Expr →
|
||||
inductive InjectionAnyResult where
|
||||
| solved
|
||||
| failed
|
||||
/-- `fvarId` refers to the local declaration selected for the application of the `injection` tactic. -/
|
||||
| subgoal (fvarId : FVarId) (mvarId : MVarId)
|
||||
| subgoal (mvarId : MVarId)
|
||||
|
||||
private def injectionAnyCandidate? (type : Expr) : MetaM (Option (Expr × Expr)) := do
|
||||
if let some (_, lhs, rhs) ← matchEq? type then
|
||||
@@ -386,28 +385,21 @@ private def injectionAnyCandidate? (type : Expr) : MetaM (Option (Expr × Expr))
|
||||
return some (lhs, rhs)
|
||||
return none
|
||||
|
||||
/--
|
||||
Try applying `injection` to a local declaration that is not in `forbidden`.
|
||||
|
||||
We use `forbidden` because the `injection` tactic might fail to clear the variable if there are forward dependencies.
|
||||
See `proveSubgoalLoop` for additional details.
|
||||
-/
|
||||
private def injectionAny (mvarId : MVarId) (forbidden : FVarIdSet := {}) : MetaM InjectionAnyResult := do
|
||||
private def injectionAny (mvarId : MVarId) : MetaM InjectionAnyResult := do
|
||||
mvarId.withContext do
|
||||
for localDecl in (← getLCtx) do
|
||||
unless forbidden.contains localDecl.fvarId do
|
||||
if let some (lhs, rhs) ← injectionAnyCandidate? localDecl.type then
|
||||
unless (← isDefEq lhs rhs) do
|
||||
let lhs ← whnf lhs
|
||||
let rhs ← whnf rhs
|
||||
unless lhs.isRawNatLit && rhs.isRawNatLit do
|
||||
try
|
||||
match (← injection mvarId localDecl.fvarId) with
|
||||
| InjectionResult.solved => return InjectionAnyResult.solved
|
||||
| InjectionResult.subgoal mvarId .. => return InjectionAnyResult.subgoal localDecl.fvarId mvarId
|
||||
catch ex =>
|
||||
trace[Meta.Match.matchEqs] "injectionAnyFailed at {localDecl.userName}, error\n{ex.toMessageData}"
|
||||
pure ()
|
||||
if let some (lhs, rhs) ← injectionAnyCandidate? localDecl.type then
|
||||
unless (← isDefEq lhs rhs) do
|
||||
let lhs ← whnf lhs
|
||||
let rhs ← whnf rhs
|
||||
unless lhs.isRawNatLit && rhs.isRawNatLit do
|
||||
try
|
||||
match (← injection mvarId localDecl.fvarId) with
|
||||
| InjectionResult.solved => return InjectionAnyResult.solved
|
||||
| InjectionResult.subgoal mvarId .. => return InjectionAnyResult.subgoal mvarId
|
||||
catch ex =>
|
||||
trace[Meta.Match.matchEqs] "injectionAnyFailed at {localDecl.userName}, error\n{ex.toMessageData}"
|
||||
pure ()
|
||||
return InjectionAnyResult.failed
|
||||
|
||||
|
||||
@@ -609,32 +601,27 @@ where
|
||||
let eNew := mkAppN eNew mvars
|
||||
return TransformStep.done eNew
|
||||
|
||||
/-
|
||||
`forbidden` tracks variables that we have already applied `injection`.
|
||||
Recall that the `injection` tactic may not be able to eliminate them when
|
||||
they have forward dependencies.
|
||||
-/
|
||||
proveSubgoalLoop (mvarId : MVarId) (forbidden : FVarIdSet) : MetaM Unit := do
|
||||
proveSubgoalLoop (mvarId : MVarId) : MetaM Unit := do
|
||||
trace[Meta.Match.matchEqs] "proveSubgoalLoop\n{mvarId}"
|
||||
if (← mvarId.contradictionQuick) then
|
||||
return ()
|
||||
match (← injectionAny mvarId forbidden) with
|
||||
| .solved => return ()
|
||||
| .failed =>
|
||||
match (← injectionAny mvarId) with
|
||||
| InjectionAnyResult.solved => return ()
|
||||
| InjectionAnyResult.failed =>
|
||||
let mvarId' ← substVars mvarId
|
||||
if mvarId' == mvarId then
|
||||
if (← mvarId.contradictionCore {}) then
|
||||
return ()
|
||||
throwError "failed to generate splitter for match auxiliary declaration '{matchDeclName}', unsolved subgoal:\n{MessageData.ofGoal mvarId}"
|
||||
else
|
||||
proveSubgoalLoop mvarId' forbidden
|
||||
| .subgoal fvarId mvarId => proveSubgoalLoop mvarId (forbidden.insert fvarId)
|
||||
proveSubgoalLoop mvarId'
|
||||
| InjectionAnyResult.subgoal mvarId => proveSubgoalLoop mvarId
|
||||
|
||||
proveSubgoal (mvarId : MVarId) : MetaM Unit := do
|
||||
trace[Meta.Match.matchEqs] "subgoal {mkMVar mvarId}, {repr (← mvarId.getDecl).kind}, {← mvarId.isAssigned}\n{MessageData.ofGoal mvarId}"
|
||||
let (_, mvarId) ← mvarId.intros
|
||||
let mvarId ← mvarId.tryClearMany (alts.map (·.fvarId!))
|
||||
proveSubgoalLoop mvarId {}
|
||||
proveSubgoalLoop mvarId
|
||||
|
||||
/--
|
||||
Create new alternatives (aka minor premises) by replacing `discrs` with `patterns` at `alts`.
|
||||
@@ -659,7 +646,7 @@ where
|
||||
/--
|
||||
Create conditional equations and splitter for the given match auxiliary declaration. -/
|
||||
private partial def mkEquationsFor (matchDeclName : Name) : MetaM MatchEqns := withLCtx {} {} do
|
||||
withTraceNode `Meta.Match.matchEqs (fun _ => return m!"mkEquationsFor '{matchDeclName}'") do
|
||||
trace[Meta.Match.matchEqs] "mkEquationsFor '{matchDeclName}'"
|
||||
withConfig (fun c => { c with etaStruct := .none }) do
|
||||
/-
|
||||
Remark: user have requested the `split` tactic to be available for writing code.
|
||||
|
||||
@@ -59,9 +59,9 @@ def addArg (matcherApp : MatcherApp) (e : Expr) : MetaM MatcherApp :=
|
||||
-- This error can only happen if someone implemented a transformation that rewrites the motive created by `mkMatcher`.
|
||||
throwError "unexpected matcher application, motive must be lambda expression with #{matcherApp.discrs.size} arguments"
|
||||
let eType ← inferType e
|
||||
let eTypeAbst ← matcherApp.discrs.size.foldRevM (init := eType) fun i _ eTypeAbst => do
|
||||
let eTypeAbst ← matcherApp.discrs.size.foldRevM (init := eType) fun i eTypeAbst => do
|
||||
let motiveArg := motiveArgs[i]!
|
||||
let discr := matcherApp.discrs[i]
|
||||
let discr := matcherApp.discrs[i]!
|
||||
let eTypeAbst ← kabstract eTypeAbst discr
|
||||
return eTypeAbst.instantiate1 motiveArg
|
||||
let motiveBody ← mkArrow eTypeAbst motiveBody
|
||||
@@ -118,9 +118,9 @@ def refineThrough (matcherApp : MatcherApp) (e : Expr) : MetaM (Array Expr) :=
|
||||
-- This error can only happen if someone implemented a transformation that rewrites the motive created by `mkMatcher`.
|
||||
throwError "failed to transfer argument through matcher application, motive must be lambda expression with #{matcherApp.discrs.size} arguments"
|
||||
|
||||
let eAbst ← matcherApp.discrs.size.foldRevM (init := e) fun i _ eAbst => do
|
||||
let eAbst ← matcherApp.discrs.size.foldRevM (init := e) fun i eAbst => do
|
||||
let motiveArg := motiveArgs[i]!
|
||||
let discr := matcherApp.discrs[i]
|
||||
let discr := matcherApp.discrs[i]!
|
||||
let eTypeAbst ← kabstract eAbst discr
|
||||
return eTypeAbst.instantiate1 motiveArg
|
||||
-- Let's create something that’s a `Sort` and mentions `e`
|
||||
|
||||
@@ -104,8 +104,8 @@ def appendParentTag (mvarId : MVarId) (newMVars : Array Expr) (binderInfos : Arr
|
||||
newMVars[0]!.mvarId!.setTag parentTag
|
||||
else
|
||||
unless parentTag.isAnonymous do
|
||||
newMVars.size.forM fun i _ => do
|
||||
let mvarIdNew := newMVars[i].mvarId!
|
||||
newMVars.size.forM fun i => do
|
||||
let mvarIdNew := newMVars[i]!.mvarId!
|
||||
unless (← mvarIdNew.isAssigned) do
|
||||
unless binderInfos[i]!.isInstImplicit do
|
||||
let currTag ← mvarIdNew.getTag
|
||||
|
||||
@@ -168,7 +168,7 @@ private def hasIndepIndices (ctx : Context) : MetaM Bool := do
|
||||
else if ctx.majorTypeIndices.any fun idx => !idx.isFVar then
|
||||
/- One of the indices is not a free variable. -/
|
||||
return false
|
||||
else if ctx.majorTypeIndices.size.any fun i _ => i.any fun j _ => ctx.majorTypeIndices[i] == ctx.majorTypeIndices[j] then
|
||||
else if ctx.majorTypeIndices.size.any fun i => i.any fun j => ctx.majorTypeIndices[i]! == ctx.majorTypeIndices[j]! then
|
||||
/- An index occurs more than once -/
|
||||
return false
|
||||
else
|
||||
|
||||
@@ -283,9 +283,9 @@ partial def foldAndCollect (oldIH newIH : FVarId) (isRecCall : Expr → Option E
|
||||
(onMotive := fun xs _body => do
|
||||
-- Remove the old IH that was added in mkFix
|
||||
let eType ← newIH.getType
|
||||
let eTypeAbst ← matcherApp.discrs.size.foldRevM (init := eType) fun i _ eTypeAbst => do
|
||||
let eTypeAbst ← matcherApp.discrs.size.foldRevM (init := eType) fun i eTypeAbst => do
|
||||
let motiveArg := xs[i]!
|
||||
let discr := matcherApp.discrs[i]
|
||||
let discr := matcherApp.discrs[i]!
|
||||
let eTypeAbst ← kabstract eTypeAbst discr
|
||||
return eTypeAbst.instantiate1 motiveArg
|
||||
|
||||
|
||||
@@ -105,10 +105,10 @@ private partial def finalize
|
||||
let mvarId' ← mvar.mvarId!.tryClear major.fvarId!
|
||||
let (fields, mvarId') ← mvarId'.introN nparams minorGivenNames.varNames (useNamesForExplicitOnly := !minorGivenNames.explicit)
|
||||
let (extra, mvarId') ← mvarId'.introNP nextra
|
||||
let subst := reverted.size.fold (init := baseSubst) fun i _ (subst : FVarSubst) =>
|
||||
let subst := reverted.size.fold (init := baseSubst) fun i (subst : FVarSubst) =>
|
||||
if i < indices.size + 1 then subst
|
||||
else
|
||||
let revertedFVarId := reverted[i]
|
||||
let revertedFVarId := reverted[i]!
|
||||
let newFVarId := extra[i - indices.size - 1]!
|
||||
subst.insert revertedFVarId (mkFVar newFVarId)
|
||||
let fields := fields.map mkFVar
|
||||
@@ -134,8 +134,8 @@ def getMajorTypeIndices (mvarId : MVarId) (tacticName : Name) (recursorInfo : Re
|
||||
if idxPos ≥ majorTypeArgs.size then throwTacticEx tacticName mvarId m!"major premise type is ill-formed{indentExpr majorType}"
|
||||
let idx := majorTypeArgs.get! idxPos
|
||||
unless idx.isFVar do throwTacticEx tacticName mvarId m!"major premise type index {idx} is not a variable{indentExpr majorType}"
|
||||
majorTypeArgs.size.forM fun i _ => do
|
||||
let arg := majorTypeArgs[i]
|
||||
majorTypeArgs.size.forM fun i => do
|
||||
let arg := majorTypeArgs[i]!
|
||||
if i != idxPos && arg == idx then
|
||||
throwTacticEx tacticName mvarId m!"'{idx}' is an index in major premise, but it occurs more than once{indentExpr majorType}"
|
||||
if i < idxPos then
|
||||
|
||||
@@ -94,48 +94,31 @@ def injection (mvarId : MVarId) (fvarId : FVarId) (newNames : List Name := []) :
|
||||
| .subgoal mvarId numEqs => injectionIntro mvarId numEqs newNames
|
||||
|
||||
inductive InjectionsResult where
|
||||
/-- `injections` closed the input goal. -/
|
||||
| solved
|
||||
/--
|
||||
`injections` produces a new goal `mvarId`. `remainingNames` contains the user-facing names that have not been used.
|
||||
`forbidden` contains all local declarations to which `injection` has been applied.
|
||||
Recall that some of these declarations may not have been eliminated from the local context due to forward dependencies, and
|
||||
we use `forbidden` to avoid non-termination when using `injections` in a loop.
|
||||
-/
|
||||
| subgoal (mvarId : MVarId) (remainingNames : List Name) (forbidden : FVarIdSet)
|
||||
| subgoal (mvarId : MVarId) (remainingNames : List Name)
|
||||
|
||||
/--
|
||||
Applies `injection` to local declarations in `mvarId`. It uses `newNames` to name the new local declarations.
|
||||
`maxDepth` is the maximum recursion depth. Only local declarations that are not in `forbidden` are considered.
|
||||
Recall that some of local declarations may not have been eliminated from the local context due to forward dependencies, and
|
||||
we use `forbidden` to avoid non-termination when using `injections` in a loop.
|
||||
-/
|
||||
partial def injections (mvarId : MVarId) (newNames : List Name := []) (maxDepth : Nat := 5) (forbidden : FVarIdSet := {}) : MetaM InjectionsResult :=
|
||||
partial def injections (mvarId : MVarId) (newNames : List Name := []) (maxDepth : Nat := 5) : MetaM InjectionsResult :=
|
||||
mvarId.withContext do
|
||||
let fvarIds := (← getLCtx).getFVarIds
|
||||
go maxDepth fvarIds.toList mvarId newNames forbidden
|
||||
go maxDepth fvarIds.toList mvarId newNames
|
||||
where
|
||||
go (depth : Nat) (fvarIds : List FVarId) (mvarId : MVarId) (newNames : List Name) (forbidden : FVarIdSet) : MetaM InjectionsResult := do
|
||||
match depth, fvarIds with
|
||||
| 0, _ => throwTacticEx `injections mvarId "recursion depth exceeded"
|
||||
| _, [] => return .subgoal mvarId newNames forbidden
|
||||
| d+1, fvarId :: fvarIds => do
|
||||
go : Nat → List FVarId → MVarId → List Name → MetaM InjectionsResult
|
||||
| 0, _, _, _ => throwTacticEx `injections mvarId "recursion depth exceeded"
|
||||
| _, [], mvarId, newNames => return .subgoal mvarId newNames
|
||||
| d+1, fvarId :: fvarIds, mvarId, newNames => do
|
||||
let cont := do
|
||||
go (d+1) fvarIds mvarId newNames forbidden
|
||||
if forbidden.contains fvarId then
|
||||
cont
|
||||
else if let some (_, lhs, rhs) ← matchEqHEq? (← fvarId.getType) then
|
||||
go (d+1) fvarIds mvarId newNames
|
||||
if let some (_, lhs, rhs) ← matchEqHEq? (← fvarId.getType) then
|
||||
let lhs ← whnf lhs
|
||||
let rhs ← whnf rhs
|
||||
if lhs.isRawNatLit && rhs.isRawNatLit then
|
||||
cont
|
||||
if lhs.isRawNatLit && rhs.isRawNatLit then cont
|
||||
else
|
||||
try
|
||||
commitIfNoEx do
|
||||
match (← injection mvarId fvarId newNames) with
|
||||
| .solved => return .solved
|
||||
| .subgoal mvarId newEqs remainingNames =>
|
||||
mvarId.withContext <| go d (newEqs.toList ++ fvarIds) mvarId remainingNames (forbidden.insert fvarId)
|
||||
mvarId.withContext <| go d (newEqs.toList ++ fvarIds) mvarId remainingNames
|
||||
catch _ => cont
|
||||
else cont
|
||||
|
||||
|
||||
@@ -43,7 +43,6 @@ def _root_.Lean.MVarId.revert (mvarId : MVarId) (fvarIds : Array FVarId) (preser
|
||||
finally
|
||||
mvarId.setKind .syntheticOpaque
|
||||
let mvar := e.getAppFn
|
||||
mvar.mvarId!.setKind .syntheticOpaque
|
||||
mvar.mvarId!.setTag tag
|
||||
return (toRevert.map Expr.fvarId!, mvar.mvarId!)
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user