Compare commits

...

74 Commits

Author SHA1 Message Date
Kim Morrison
29ff22c560 chore: rename Array.setD to setIfInBounds 2024-11-24 19:23:21 +11:00
Kyle Miller
ba3f2b3ecf fix: make sure #check id heeds pp.raw (#6181)
This PR fixes a bug where the signature pretty printer would ignore the
current setting of `pp.raw`. This fixes an issue where `#check ident`
would not heed `pp.raw`. Closes #6090.
2024-11-23 00:39:58 +00:00
Leonardo de Moura
4a69643858 fix: nontermination while generating equation lemmas for match-expressions (#6180)
This PR fixes a non-termination bug that occurred when generating the
match-expression equation theorems. The bug was triggered when the proof
automation for the equation theorem repeatedly applied `injection(` to
the same local declaration, as it could not be removed due to forward
dependencies. See issue #6067 for an example that reproduces this issue.

closes #6067
2024-11-23 00:06:34 +00:00
Kyle Miller
b6a0d63612 feat: have "motive is not type correct" come with an explanation (#6168)
This PR extends the "motive is not type correct" error message for the
rewrite tactic to explain what it means. It also pretty prints the
type-incorrect motive and reports the type error.

Suggested [on
Zulip](https://leanprover.zulipchat.com/#narrow/channel/113489-new-members/topic/tactic.20'rewrite'.20failed.2C.20motive.20is.20not.20type.20correct/near/483545154).
2024-11-22 23:56:17 +00:00
Kyle Miller
5145030ff4 chore: refactor Elab.StructInst to use mutual for its structures/inductives (#6174)
Making use of #6125.
2024-11-22 19:17:48 +00:00
Kyle Miller
d3cb812fb6 chore: add test for recursive structures (#6173)
Closes #6140. This was fixed by #6125.
2024-11-22 18:53:59 +00:00
Lean stage0 autoupdater
e066c17a65 chore: update stage0 2024-11-22 18:29:01 +00:00
Sebastian Ullrich
38cff08888 feat: creation and reporting for asynchronous elaboration tasks (#6170)
This PR adds core metaprogramming functions for forking off background
tasks from elaboration such that their results are visible to reporting
and the language server
2024-11-22 17:12:30 +00:00
David Thrane Christiansen
3388fc8d06 doc: fix typo and make docstring more precise (#6009)
This PR fixes a typo in the docstring for prec and makes the text
slightly more precise.
2024-11-22 16:30:01 +00:00
Eric Wieser
5adcd520fa fix: make the stack handling more robust to sanitizers and -O3 (#6143)
This PR should make lean better-behaved around sanitizers, per
https://github.com/google/sanitizers/issues/1688.
As far as I can tell,
https://github.com/google/sanitizers/wiki/AddressSanitizerUseAfterReturn#algorithm
replaces local variables with heap allocations, and so taking the
address of a local is not effective at producing a monotonic measure of
stack usage.

The approach used here is the same as the one used by clang.
2024-11-22 15:10:20 +00:00
David Thrane Christiansen
1126407d9b feat: create temporary directories (#6148)
This PR adds a primitive for creating temporary directories, akin to the
existing functionality for creating temporary files.
2024-11-22 12:24:32 +00:00
Kyle Miller
a19ff61e15 feat: allow structure in mutual blocks (#6125)
This PR adds support for `structure` in `mutual` blocks, allowing
inductive types defined by `inductive` and `structure` to be mutually
recursive. The limitations are (1) that the parents in the `extends`
clause must be defined before the `mutual` block and (2) mutually
recursive classes are not allowed (a limitation shared by `class
inductive`). There are also improvements to universe level inference for
inductive types and structures. Breaking change: structure parents now
elaborate with the structure in scope (fix: use qualified names or
rename the structure to avoid shadowing), and structure parents no
longer elaborate with autoimplicits enabled.

Internally, this is a large refactor of both the `inductive` and
`structure` commands. Common material is now in
`Lean.Elab.MutualInductive`, and each command plugs into this mutual
inductive elaboration framework with the logic specific to the
respective command. For example, `structure` has code to add projections
after the inductive types are added to the environment.

Closes #4182
2024-11-22 09:20:07 +00:00
Lean stage0 autoupdater
6202461a21 chore: update stage0 2024-11-22 04:42:45 +00:00
Kim Morrison
ea221f3283 feat: Nat.(fold|foldRev|any|all)M? take a function which sees the upper bound (#6139)
This PR modifies the signature of the functions `Nat.fold`,
`Nat.foldRev`, `Nat.any`, `Nat.all`, so that the function is passed the
upper bound. This allows us to change runtime array bounds checks to
compile time checks in many places.
2024-11-22 03:05:51 +00:00
Kyle Miller
7c50d597c3 feat: add builtin attribute to support elaboration of mutual inductives/structures (#6166)
This PR is a prerequisite for #6125.
2024-11-22 01:48:37 +00:00
Tony Beta Lambda
99031695bd feat: display coercions with a type ascription (#6119)
This PR adds a new delab option `pp.coercions.types` which, when
enabled, will display all coercions with an explicit type ascription.

[Link to Zulip
discussion](https://leanprover.zulipchat.com/#narrow/channel/239415-metaprogramming-.2F-tactics/topic/Roundtripping.20delaboration.20involving.20coercions)

Towards #4315
2024-11-21 23:02:47 +00:00
JovanGerb
b7248d5295 fix: revert creates natural metavariable goal (#6145)
This PR fixes the `revert` tactic so that it creates a `syntheticOpaque`
metavariable as the new goal, instead of a `natural` metavariable

I reported it on
[Zulip](https://leanprover.zulipchat.com/#narrow/channel/270676-lean4/topic/.60revert.60.20gives.20natural.20metavariable.20goal/near/483388096)
2024-11-21 23:00:57 +00:00
Tobias Grosser
7f2e7e56d2 feat: BitVec.getMsbD_[ofNatLt|allOnes|not] (#6149)
This PR completes the elementwise accessors for `ofNatLt`, `allOnes`,
and `not` by adding their implementations of `getMsbD`.
2024-11-21 22:13:09 +00:00
Tobias Grosser
1fe66737ad feat: BitVec.toInt_[or|and|xor|not] (#6151)
This PR completes the `toInt` interface for `BitVec` bitwise operations.
2024-11-21 22:10:33 +00:00
Bhavik Mehta
765eb02279 doc: adjust file reference in Data.Sum (#6158)
This file was upstreamed from batteries; I just got bitten by the
invalid reference and it took quite a while to figure out that this one
had been moved!
2024-11-21 21:48:27 +00:00
Henrik Böving
a101377054 perf: speed up reflection of if in bv_decide (#6162)
This PR adds a slight performance improvement to reflection of `if`
statements that I noticed by profiling Leanwuzla against SMTCOMP's
`non-incremental/QF_BV/fft/Sz256_6616.smt2`.

In particular:
1. The profile showed about 4 percent of the total run time were spent
constructing these decidable instances in reflection of `if` statements.
We can construct them much quicker by hand as they always have the same
structure
2. This delays construction of these statements until we actually
generate the reflection proof that we wish to submit to the kernel. Thus
if we encounter a SAT instad of an UNSAT problem we will not spend time
generating these expressions anymore.

```
baseline
  Time (mean ± σ):     31.236 s ±  0.258 s
  Range (min … max):   30.899 s … 31.661 s    10 runs

after
  Time (mean ± σ):     30.671 s ±  0.288 s
  Range (min … max):   30.350 s … 31.156 s    10 runs
```
2024-11-21 19:40:14 +00:00
Kyle Miller
aca9929d84 fix: make sure whitespace is printed before tactic configuration (#6161)
This PR ensures whitespace is printed before `+opt` and `-opt`
configuration options when pretty printing, improving the experience of
tactics such as `simp?`.

Reported [on
Zulip](https://leanprover.zulipchat.com/#narrow/channel/270676-lean4/topic/Minor.20simp.3F.20annoyances/near/483736310)
2024-11-21 19:21:59 +00:00
Sebastian Ullrich
19a701e5c9 refactor: one more recursive structure (#6159) 2024-11-21 18:30:28 +00:00
Leonardo de Moura
fc4305ab15 fix: nontermination when generating the match-expression splitter theorem (#6146)
This PR fixes a non-termination bug that occurred when generating the
match-expression splitter theorem. The bug was triggered when the proof
automation for the splitter theorem repeatedly applied `injection` to
the same local declaration, as it could not be removed due to forward
dependencies. See issue #6065 for an example that reproduces this issue.

closes #6065
2024-11-21 17:20:33 +00:00
Kim Morrison
9cf83706e7 chore: add changelog-* labels via comment (#6147)
This PR enables contributors to modify `changelog-*` labels simply by
writing a comment with the desired label.
2024-11-21 07:23:13 +00:00
Tobias Grosser
459c6e2a46 feat: BitVec.getElem_[sub|neg|sshiftRight'|abs] (#6126)
This PR adds lemmas for extracting a given bit of a `BitVec` obtained
via `sub`/`neg`/`sshiftRight'`/`abs`.

---------

Co-authored-by: Kim Morrison <scott@tqft.net>
2024-11-21 07:01:11 +00:00
Kim Morrison
72e952eadc chore: avoid runtime array bounds checks (#6134)
This PR avoids runtime array bounds checks in places where it can
trivially be done at compile time.

None of these changes are of particular consequence: I mostly wanted to
learn how much we do this, and what the obstacles are to doing it less.
2024-11-21 05:04:52 +00:00
damiano
56a80dec1b doc: doc-strings to module docs in Data/Array/Lemmas (#6144)
This PR converts 3 doc-string to module docs since it seems that this is
what they were intended to be!
2024-11-21 05:04:09 +00:00
JovanGerb
b894464191 fix: type occurs check bug (#6128)
This PR does the same fix as #6104, but such that it doesn't break the
test/the file in `Plausible`. This is done by not creating unused let
binders in metavariable types that are made by `elimMVar`. (This is also
a positive thing for users looking at metavariable types, for example in
error messages)

We get rid of `skipAtMostNumBinders`. This function was originally
defined for the purpose of making this test work, but it is a hack
because it allows cycles in the metavariable context.

It would make sense to split these changes into 2 PRs, but I combined
them here to show that the combination of them closes #6013 without
breaking anything

Closes #6013
2024-11-21 00:28:36 +00:00
Sebastian Ullrich
b30903d1fc refactor: make use of recursive structures in snapshot types (#6141) 2024-11-20 15:15:14 +00:00
Sebastian Ullrich
7fbe8e3b36 fix: Inhabited Float produced a bogus run-time value (#6136)
This PR fixes the run-time evaluation of `(default : Float)`.
2024-11-20 10:43:59 +00:00
Sebastian Ullrich
2fbc46641d fix: trace.profiler pretty-printing (#6138)
This PR fixes `trace.profiler.pp` not using the term pretty printer.

Fixes #5872
2024-11-20 10:21:02 +00:00
Sebastian Ullrich
17419aca7f feat: thread support for trace.profiler.output (#6137)
This PR adds support for displaying multiple threads in the trace
profiler output.

`TraceState.tid` needs to be adjusted for this purpose, which is not
done yet by the Lean elaborator as it is still single-threaded.
2024-11-20 10:02:39 +00:00
Kim Morrison
f85c66789d feat: Array.insertIdx/eraseIdx take a tactic-provided proof (#6133)
This PR replaces `Array.feraseIdx` and `Array.insertAt` with
`Array.eraseIdx` and `Array.insertIdx`, both of which take a `Nat`
argument and a tactic-provided proof that it is in bounds. We also have
`eraseIdxIfInBounds` and `insertIdxIfInBounds` which are noops if the
index is out of bounds. We also provide a `Fin` valued version of
`Array.findIdx?`. Together, these quite ergonomically improve the array
indexing safety at a number of places in the compiler/elaborator.
2024-11-20 09:52:38 +00:00
Kim Morrison
c8b4f6b511 feat: duplicate List.attach/attachWith/pmap API for Array (#6132)
This PR duplicates the verification API for
`List.attach`/`attachWith`/`pmap` over to `Array`.
2024-11-20 01:16:48 +00:00
Luisa Cicolini
3c7555168d feat: add BitVec.(msb, getMsbD)_(rotateLeft, rotateRight) (#6120)
This PR adds theorems `BitVec.(getMsbD, msb)_(rotateLeft, rotateRight)`.

We follow the same strategy taken for `getLsbD`, constructing the
necessary auxilliary theorems first (relying on different hypotheses)
and then generalizing.

---------

Co-authored-by: Siddharth <siddu.druid@gmail.com>
Co-authored-by: Tobias Grosser <tobias@grosser.es>
2024-11-19 23:04:14 +00:00
Kyle Miller
5eef3d27fb feat: have #print show precise fields of structures (#6096)
This PR improves the `#print` command for structures to show all fields
and which parents the fields were inherited from, hiding internal
details such as which parents are represented as subobjects. This
information is still present in the constructor if needed. The pretty
printer for private constants is also improved, and it now handles
private names from the current module like any other name; private names
from other modules are made hygienic.

Example output for `#print Monad`:
```
class Monad.{u, v} (m : Type u → Type v) : Type (max (u + 1) v)
number of parameters: 1
parents:
  Monad.toApplicative : Applicative m
  Monad.toBind : Bind m
fields:
  Functor.map : {α β : Type u} → (α → β) → m α → m β
  Functor.mapConst : {α β : Type u} → α → m β → m α
  Pure.pure : {α : Type u} → α → m α
  Seq.seq : {α β : Type u} → m (α → β) → (Unit → m α) → m β
  SeqLeft.seqLeft : {α β : Type u} → m α → (Unit → m β) → m α
  SeqRight.seqRight : {α β : Type u} → m α → (Unit → m β) → m β
  Bind.bind : {α β : Type u} → m α → (α → m β) → m β
constructor:
  Monad.mk.{u, v} {m : Type u → Type v} [toApplicative : Applicative m] [toBind : Bind m] : Monad m
resolution order:
  Monad, Applicative, Bind, Functor, Pure, Seq, SeqLeft, SeqRight
```

Suggested by Floris van Doorn [on
Zulip](https://leanprover.zulipchat.com/#narrow/channel/270676-lean4/topic/.23print.20command.20for.20structures/near/482503637).
2024-11-19 21:54:45 +00:00
Leonardo de Moura
75d1504af2 fix: isDefEq for constants with different universe parameters (#6131)
This PR fixes a bug at the definitional equality test (`isDefEq`). At
unification constraints of the form `c.{u} =?= c.{v}`, it was not trying
to unfold `c`. This bug did not affect the kernel.

closes #6117
2024-11-19 21:39:13 +00:00
Mario Carneiro
a00cf6330f fix: add a missing case to Level.geq (#2689)
This PR adds a case to `Level.geq` that is present in the kernel's level
`is_geq` procedure, making them consistent with one another.

This came up during testing of `lean4lean`. Currently `Level.geq`
differs from `level::is_geq` in the case of `max u v >= imax u v`. The
elaborator function is overly pessimistic and yields `false` on this
while the kernel function yields true. This comes up concretely in the
`Trans` class:
```lean
class Trans (r : α → β → Sort u) (s : β → γ → Sort v) (t : outParam (α → γ → Sort w)) where
  trans : r a b → s b c → t a c
```
The type of this class is `Sort (max (max (max (max (max (max 1 u) u_1)
u_2) u_3) v) w)` (where `u_1 u_2 u_3` are the levels of `α β γ`), but if
you try writing that type explicitly then the `class` command fails.
Omitting the type leaves the `class` to infer the universe level (the
command assumes the level is correct, and the kernel agrees it is), but
including the type then the elaborator checks the level inequality with
`Level.geq` and fails.

---------

Co-authored-by: Kyle Miller <kmill31415@gmail.com>
2024-11-19 21:27:00 +00:00
Leonardo de Moura
1f32477385 fix: isDefEq when zetaDelta := false (#6129)
This PR fixes a bug at `isDefEq` when `zetaDelta := false`. See new test
for a small example that exposes the issue.
2024-11-19 21:22:02 +00:00
Thomas Köppe
91c14c7ee9 fix: only consider salient bytes in sharecommon eq, hash (#5840)
This PR changes `lean_sharecommon_{eq,hash}` to only consider the
salient bytes of an object, and not any bytes of any
unspecified/uninitialized unused capacity.

Accessing uninitialized storage results in undefined behaviour.

This does not seem to have any semantics disadvantages: If objects
compare equal after this change, their salient bytes are still equal. By
contrast, if the actual identity of allocations needs to be
distinguished, that can be done by just comparing pointers to the
storage.

If we wanted to retain the current logic, we would need initialize the
otherwise unused parts to some specific value to avoid the undefined
behaviour.

Closes #5831
2024-11-19 13:56:46 +00:00
Lean stage0 autoupdater
69530afdf9 chore: update stage0 2024-11-19 13:06:43 +00:00
Marc Huisinga
b7667c1604 fix: don't issue atomic id completions when there is a dangling dot (#5837)
This PR fixes an old auto-completion bug where `x.` would issue
nonsensical completions when `x.` could not be elaborated as a dot
completion.
2024-11-19 12:23:41 +00:00
Eric Wieser
d6f898001b chore: generalize List.get_mem (#6095)
This is syntactically more general than before, though up to eta
expansion it make no difference.
2024-11-19 11:08:10 +00:00
Marc Huisinga
a38566693b test: fix brittle structure instance completion test (#6127)
#5835 contains a brittle test that uses an FVar ID, which caused a
failure on master. This PR changes that test to use a declaration
instead.
2024-11-19 10:13:51 +00:00
Marc Huisinga
4bef3588b5 chore: update stage0 2024-11-19 09:26:58 +01:00
Marc Huisinga
64538cf6e8 chore: prepare for bootstrap
Co-Authored-By: Sebastian Ullrich <sebasti@nullri.ch>
2024-11-19 09:26:58 +01:00
Marc Huisinga
aadf3f1d2c feat: use new structInstFields parser to tag structure instance fields 2024-11-19 09:26:58 +01:00
Marc Huisinga
95bf45ff8b refactor: split Completion.lean 2024-11-19 09:26:58 +01:00
Marc Huisinga
2a02c121cf feat: structure auto-completion & partial InfoTrees 2024-11-19 09:26:58 +01:00
Mac Malone
4600bb16fc feat: use BaseIO at IO.rand (#6102)
This PR moves `IO.rand` and `IO.setRandSeed` to be in the `BaseIO`
monad.

This is their proper monad as neither can error.
2024-11-19 05:26:03 +00:00
Kim Morrison
7ccdfc30ff chore: turn off pp.mvars in apply? results (#6108)
Per request on
[zulip](https://leanprover.zulipchat.com/#narrow/channel/270676-lean4/topic/apply.3F.20using.20tombstones/near/482895588).
2024-11-19 02:02:32 +00:00
Kim Morrison
7f0bdefb6e chore: fix apply? error reporting when out of heartbeats (#6121) 2024-11-19 00:57:59 +00:00
Joachim Breitner
799b2b6628 fix: handle reordered indices in structural recursion (#6116)
This PR fixes a bug where structural recursion did not work when indices
of the recursive argument appeared as function parameters in a different
order than in the argument's type's definition.

Fixes #6015.
2024-11-18 11:28:02 +00:00
David Thrane Christiansen
b8d6e44c4f fix: liberalize rules for atoms by allowing leading '' (#6114)
This PR liberalizes atom rules by allowing `''` to be a prefix of an
atom, after #6012 only added an exception for `''` alone, and also adds
some unit tests for atom validation.
2024-11-18 10:19:20 +00:00
Kim Morrison
5a99cb326c chore: make Lean.Elab.Command.mkMetaContext public (#6113) 2024-11-18 06:14:34 +00:00
Kim Morrison
e10fac93a6 feat: lemmas for Array.findSome? and find? (#6111)
This PR fills in the API for `Array.findSome?` and `Array.find?`,
transferring proofs from the corresponding List statements.
2024-11-18 04:19:56 +00:00
Kyle Miller
62ae320e1c chore: document Lean.Elab.StructInst, refactor (#6110)
This PR does some mild refactoring of the `Lean.Elab.StructInst` module
while adding documentation.

Documentation is drawn from @thorimur's #1928.
2024-11-18 02:57:22 +00:00
Leonardo de Moura
98b1edfc1f fix: backtrack at injection failure (#6109)
This PR fixes an issue in the `injection` tactic. This tactic may
execute multiple sub-tactics. If any of them fail, we must backtrack the
partial assignment. This issue was causing the error: "`mvarId` is
already assigned" in issue #6066. The issue is not yet resolved, as the
equation generator for the match expressions is failing in the example
provided in this issue.
2024-11-18 02:26:06 +00:00
Leonardo de Moura
ab162b3f52 fix: isDefEq, whnf, simp caching and configuration (#6053)
This PR fixes the caching infrastructure for `whnf` and `isDefEq`,
ensuring the cache accounts for all relevant configuration flags. It
also cleans up the `WHNF.lean` module and improves the configuration of
`whnf`.
2024-11-18 01:17:26 +00:00
Kim Morrison
b8a13ab755 chore: fix naming of left/right injectivity lemmas (#6106)
We've been internally inconsistent on the naming of these lemmas in
Lean; this changes them to match Mathlib (which, moreover, I think is
correct).
2024-11-18 00:53:46 +00:00
Sebastian Ullrich
405593ea28 chore: avoid stack overflow in debug tests (#6103) 2024-11-17 14:54:49 +00:00
Kim Morrison
24f305c0e3 chore: fix canonicalizer handling over forall/lambda (#6082)
This PR changes how the canonicalizer handles `forall` and `lambda`,
replacing bvars with temporary fvars. Fixes a bug reported by @hrmacbeth
on
[zulip](https://leanprover.zulipchat.com/#narrow/channel/270676-lean4/topic/Quantifiers.20in.20CanonM/near/482483448).
2024-11-17 07:34:45 +00:00
Leonardo de Moura
5d553d6369 fix: circular assignment at structure instance elaborator (#6105)
This PR fixes a stack overflow caused by a cyclic assignment in the
metavariable context. The cycle is unintentionally introduced by the
structure instance elaborator.

closes #3150
2024-11-17 00:56:52 +00:00
Sebastian Ullrich
a449e3fdd6 feat: IO.getTID (#6049)
This PR adds a primitive for accessing the current thread ID

To be used in a thread-aware trace profiler
2024-11-16 19:13:11 +00:00
Kyle Miller
764386734c fix: improvements to change tactic (#6022)
This PR makes the `change` tactic and conv tactic use the same
elaboration strategy. It works uniformly for both the target and local
hypotheses. Now `change` can assign metavariables, for example:
```lean
example (x y z : Nat) : x + y = z := by
  change ?a = _
  let w := ?a
  -- now `w : Nat := x + y`
```
2024-11-16 07:08:29 +00:00
Kyle Miller
7f1d7a595b fix: use Expr.equal instead of == in MVarId.replaceTargetDefEq and MVarId.replaceLocalDeclDefEq (#6098)
This PR modifies `Lean.MVarId.replaceTargetDefEq` and
`Lean.MVarId.replaceLocalDeclDefEq` to use `Expr.equal` instead of
`Expr.eqv` when determining whether the expression has changed. This is
justified on the grounds that binder names and binder infos are
user-visible and affect elaboration.
2024-11-16 02:03:16 +00:00
Leonardo de Moura
f13e5ca852 chore: naming convention and NaN normalization (#6097)
Changes:
- `Float.fromBits` => `Float.ofBits`
- NaN normalization
2024-11-16 00:14:28 +00:00
Leonardo de Moura
ecbaeff24b feat: add Float.toBits and Float.fromBits (#6094)
This PR adds raw transmutation of floating-point numbers to and from
`UInt64`. Floats and UInts share the same endianness across all
supported platforms. The IEEE 754 standard precisely specifies the bit
layout of floats. Note that `Float.toBits` is distinct from
`Float.toUInt64`, which attempts to preserve the numeric value rather
than the bitwise value.

closes #6071
2024-11-15 19:45:19 +00:00
Kyle Miller
691acde696 feat: pp.parens option to pretty print with all parentheses (#2934)
This PR adds the option `pp.parens` (default: false) that causes the
pretty printer to eagerly insert parentheses, which can be useful for
teaching and for understanding the structure of expressions. For
example, it causes `p → q → r` to pretty print as `p → (q → r)`.

Any notations with precedence greater than or equal to `maxPrec` do not
receive such discretionary parentheses, since this precedence level is
considered to be infinity.

This option was a feature in the Lean 3 community edition.
2024-11-15 19:11:54 +00:00
Kyle Miller
b1e0c1b594 chore: remove decide! tactic (#6016)
This PR removes the `decide!` tactic in favor of `decide +kernel`
(breaking change).
2024-11-15 17:49:33 +00:00
Joachim Breitner
93b4ec0351 refactor: use mkFreshUserName in ArgsPacker (#6093)
and other small refinements done while investigating an issue; not
actually user-visible.
2024-11-15 15:59:14 +00:00
JovanGerb
f06fc30c0b perf: remove @[specialize] from mkBinding (#6019)
This PR removes @[specilize] from `MkBinding.mkBinding`, which is a
function that cannot be specialized (as none of its arguments are
functions). As a result, the specializable function `Nat.foldRevM.loop`
doesn't get specialized, which leads to worse performing code.

As expected, the mathlib bench shows a very small improvement. About 95%
of files show a speedup.
(http://speed.lean-fro.org/mathlib4/compare/e7b27246-a3e6-496a-b552-ff4b45c7236e/to/6033df75-aa53-44d9-819d-51f93fc05e94?hash1=b28f0d7f7e9cc3949a9a3556a6b36513f37f690d)
2024-11-15 15:06:49 +00:00
Markus Himmel
64b35a8c19 perf: add LEAN_ALWAYS_INLINE to some functions (#6045)
Otherwise, clang refuses to inline them for large functions which leads
to a performance cliff.
2024-11-15 15:05:32 +00:00
697 changed files with 9360 additions and 4598 deletions

View File

@@ -1,7 +1,8 @@
# This workflow allows any user to add one of the `awaiting-review`, `awaiting-author`, `WIP`,
# or `release-ci` labels by commenting on the PR or issue.
# `release-ci`, or a `changelog-XXX` label by commenting on the PR or issue.
# If any labels from the set {`awaiting-review`, `awaiting-author`, `WIP`} are added, other labels
# from that set are removed automatically at the same time.
# Similarly, if any `changelog-XXX` label is added, other `changelog-YYY` labels are removed.
name: Label PR based on Comment
@@ -11,7 +12,7 @@ on:
jobs:
update-label:
if: github.event.issue.pull_request != null && (contains(github.event.comment.body, 'awaiting-review') || contains(github.event.comment.body, 'awaiting-author') || contains(github.event.comment.body, 'WIP') || contains(github.event.comment.body, 'release-ci'))
if: github.event.issue.pull_request != null && (contains(github.event.comment.body, 'awaiting-review') || contains(github.event.comment.body, 'awaiting-author') || contains(github.event.comment.body, 'WIP') || contains(github.event.comment.body, 'release-ci') || contains(github.event.comment.body, 'changelog-'))
runs-on: ubuntu-latest
steps:
@@ -20,13 +21,14 @@ jobs:
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const { owner, repo, number: issue_number } = context.issue;
const { owner, repo, number: issue_number } = context.issue;
const commentLines = context.payload.comment.body.split('\r\n');
const awaitingReview = commentLines.includes('awaiting-review');
const awaitingAuthor = commentLines.includes('awaiting-author');
const wip = commentLines.includes('WIP');
const releaseCI = commentLines.includes('release-ci');
const changelogMatch = commentLines.find(line => line.startsWith('changelog-'));
if (awaitingReview || awaitingAuthor || wip) {
await github.rest.issues.removeLabel({ owner, repo, issue_number, name: 'awaiting-review' }).catch(() => {});
@@ -47,3 +49,19 @@ jobs:
if (releaseCI) {
await github.rest.issues.addLabels({ owner, repo, issue_number, labels: ['release-ci'] });
}
if (changelogMatch) {
const changelogLabel = changelogMatch.trim();
const { data: existingLabels } = await github.rest.issues.listLabelsOnIssue({ owner, repo, issue_number });
const changelogLabels = existingLabels.filter(label => label.name.startsWith('changelog-'));
// Remove all other changelog labels
for (const label of changelogLabels) {
if (label.name !== changelogLabel) {
await github.rest.issues.removeLabel({ owner, repo, issue_number, name: label.name }).catch(() => {});
}
}
// Add the new changelog label
await github.rest.issues.addLabels({ owner, repo, issue_number, labels: [changelogLabel] });
}

View File

@@ -10,12 +10,13 @@ import Init.Data.List.Attach
namespace Array
/-- `O(n)`. Partial map. If `f : Π a, P a → β` is a partial function defined on
`a : α` satisfying `P`, then `pmap f l h` is essentially the same as `map f l`
but is defined only when all members of `l` satisfy `P`, using the proof
to apply `f`.
/--
`O(n)`. Partial map. If `f : Π a, P a → β` is a partial function defined on
`a : α` satisfying `P`, then `pmap f l h` is essentially the same as `map f l`
but is defined only when all members of `l` satisfy `P`, using the proof
to apply `f`.
We replace this at runtime with a more efficient version via
We replace this at runtime with a more efficient version via the `csimp` lemma `pmap_eq_pmapImpl`.
-/
def pmap {P : α Prop} (f : a, P a β) (l : Array α) (H : a l, P a) : Array β :=
(l.toList.pmap f (fun a m => H a (mem_def.mpr m))).toArray
@@ -73,6 +74,17 @@ Unsafe implementation of `attachWith`, taking advantage of the fact that the rep
intro a m h₁ h₂
congr
@[simp] theorem pmap_empty {P : α Prop} (f : a, P a β) : pmap f #[] (by simp) = #[] := rfl
@[simp] theorem pmap_push {P : α Prop} (f : a, P a β) (a : α) (l : Array α) (h : b l.push a, P b) :
pmap f (l.push a) h =
(pmap f l (fun a m => by simp at h; exact h a (.inl m))).push (f a (h a (by simp))) := by
simp [pmap]
@[simp] theorem attach_empty : (#[] : Array α).attach = #[] := rfl
@[simp] theorem attachWith_empty {P : α Prop} (H : x #[], P x) : (#[] : Array α).attachWith P H = #[] := rfl
@[simp] theorem _root_.List.attachWith_mem_toArray {l : List α} :
l.attachWith (fun x => x l.toArray) (fun x h => by simpa using h) =
l.attach.map fun x, h => x, by simpa using h := by
@@ -80,6 +92,353 @@ Unsafe implementation of `attachWith`, taking advantage of the fact that the rep
apply List.pmap_congr_left
simp
@[simp]
theorem pmap_eq_map (p : α Prop) (f : α β) (l : Array α) (H) :
@pmap _ _ p (fun a _ => f a) l H = map f l := by
cases l; simp
theorem pmap_congr_left {p q : α Prop} {f : a, p a β} {g : a, q a β} (l : Array α) {H₁ H₂}
(h : a l, (h₁ h₂), f a h₁ = g a h₂) : pmap f l H₁ = pmap g l H₂ := by
cases l
simp only [mem_toArray] at h
simp only [List.pmap_toArray, mk.injEq]
rw [List.pmap_congr_left _ h]
theorem map_pmap {p : α Prop} (g : β γ) (f : a, p a β) (l H) :
map g (pmap f l H) = pmap (fun a h => g (f a h)) l H := by
cases l
simp [List.map_pmap]
theorem pmap_map {p : β Prop} (g : b, p b γ) (f : α β) (l H) :
pmap g (map f l) H = pmap (fun a h => g (f a) h) l fun _ h => H _ (mem_map_of_mem _ h) := by
cases l
simp [List.pmap_map]
theorem attach_congr {l₁ l₂ : Array α} (h : l₁ = l₂) :
l₁.attach = l₂.attach.map (fun x => x.1, h x.2) := by
subst h
simp
theorem attachWith_congr {l₁ l₂ : Array α} (w : l₁ = l₂) {P : α Prop} {H : x l₁, P x} :
l₁.attachWith P H = l₂.attachWith P fun _ h => H _ (w h) := by
subst w
simp
@[simp] theorem attach_push {a : α} {l : Array α} :
(l.push a).attach =
(l.attach.map (fun x, h => x, mem_push_of_mem a h)).push a, by simp := by
cases l
rw [attach_congr (List.push_toArray _ _)]
simp [Function.comp_def]
@[simp] theorem attachWith_push {a : α} {l : Array α} {P : α Prop} {H : x l.push a, P x} :
(l.push a).attachWith P H =
(l.attachWith P (fun x h => by simp at H; exact H x (.inl h))).push a, H a (by simp) := by
cases l
simp [attachWith_congr (List.push_toArray _ _)]
theorem pmap_eq_map_attach {p : α Prop} (f : a, p a β) (l H) :
pmap f l H = l.attach.map fun x => f x.1 (H _ x.2) := by
cases l
simp [List.pmap_eq_map_attach]
theorem attach_map_coe (l : Array α) (f : α β) :
(l.attach.map fun (i : {i // i l}) => f i) = l.map f := by
cases l
simp [List.attach_map_coe]
theorem attach_map_val (l : Array α) (f : α β) : (l.attach.map fun i => f i.val) = l.map f :=
attach_map_coe _ _
@[simp]
theorem attach_map_subtype_val (l : Array α) : l.attach.map Subtype.val = l := by
cases l; simp
theorem attachWith_map_coe {p : α Prop} (f : α β) (l : Array α) (H : a l, p a) :
((l.attachWith p H).map fun (i : { i // p i}) => f i) = l.map f := by
cases l; simp
theorem attachWith_map_val {p : α Prop} (f : α β) (l : Array α) (H : a l, p a) :
((l.attachWith p H).map fun i => f i.val) = l.map f :=
attachWith_map_coe _ _ _
@[simp]
theorem attachWith_map_subtype_val {p : α Prop} (l : Array α) (H : a l, p a) :
(l.attachWith p H).map Subtype.val = l := by
cases l; simp
@[simp]
theorem mem_attach (l : Array α) : x, x l.attach
| a, h => by
have := mem_map.1 (by rw [attach_map_subtype_val] <;> exact h)
rcases this with _, _, m, rfl
exact m
@[simp]
theorem mem_pmap {p : α Prop} {f : a, p a β} {l H b} :
b pmap f l H (a : _) (h : a l), f a (H a h) = b := by
simp only [pmap_eq_map_attach, mem_map, mem_attach, true_and, Subtype.exists, eq_comm]
theorem mem_pmap_of_mem {p : α Prop} {f : a, p a β} {l H} {a} (h : a l) :
f a (H a h) pmap f l H := by
rw [mem_pmap]
exact a, h, rfl
@[simp]
theorem size_pmap {p : α Prop} {f : a, p a β} {l H} : (pmap f l H).size = l.size := by
cases l; simp
@[simp]
theorem size_attach {L : Array α} : L.attach.size = L.size := by
cases L; simp
@[simp]
theorem size_attachWith {p : α Prop} {l : Array α} {H} : (l.attachWith p H).size = l.size := by
cases l; simp
@[simp]
theorem pmap_eq_empty_iff {p : α Prop} {f : a, p a β} {l H} : pmap f l H = #[] l = #[] := by
cases l; simp
theorem pmap_ne_empty_iff {P : α Prop} (f : (a : α) P a β) {xs : Array α}
(H : (a : α), a xs P a) : xs.pmap f H #[] xs #[] := by
cases xs; simp
theorem pmap_eq_self {l : Array α} {p : α Prop} (hp : (a : α), a l p a)
(f : (a : α) p a α) : l.pmap f hp = l a (h : a l), f a (hp a h) = a := by
cases l; simp [List.pmap_eq_self]
@[simp]
theorem attach_eq_empty_iff {l : Array α} : l.attach = #[] l = #[] := by
cases l; simp
theorem attach_ne_empty_iff {l : Array α} : l.attach #[] l #[] := by
cases l; simp
@[simp]
theorem attachWith_eq_empty_iff {l : Array α} {P : α Prop} {H : a l, P a} :
l.attachWith P H = #[] l = #[] := by
cases l; simp
theorem attachWith_ne_empty_iff {l : Array α} {P : α Prop} {H : a l, P a} :
l.attachWith P H #[] l #[] := by
cases l; simp
@[simp]
theorem getElem?_pmap {p : α Prop} (f : a, p a β) {l : Array α} (h : a l, p a) (n : Nat) :
(pmap f l h)[n]? = Option.pmap f l[n]? fun x H => h x (mem_of_getElem? H) := by
cases l; simp
@[simp]
theorem getElem_pmap {p : α Prop} (f : a, p a β) {l : Array α} (h : a l, p a) {n : Nat}
(hn : n < (pmap f l h).size) :
(pmap f l h)[n] =
f (l[n]'(@size_pmap _ _ p f l h hn))
(h _ (getElem_mem (@size_pmap _ _ p f l h hn))) := by
cases l; simp
@[simp]
theorem getElem?_attachWith {xs : Array α} {i : Nat} {P : α Prop} {H : a xs, P a} :
(xs.attachWith P H)[i]? = xs[i]?.pmap Subtype.mk (fun _ a => H _ (mem_of_getElem? a)) :=
getElem?_pmap ..
@[simp]
theorem getElem?_attach {xs : Array α} {i : Nat} :
xs.attach[i]? = xs[i]?.pmap Subtype.mk (fun _ a => mem_of_getElem? a) :=
getElem?_attachWith
@[simp]
theorem getElem_attachWith {xs : Array α} {P : α Prop} {H : a xs, P a}
{i : Nat} (h : i < (xs.attachWith P H).size) :
(xs.attachWith P H)[i] = xs[i]'(by simpa using h), H _ (getElem_mem (by simpa using h)) :=
getElem_pmap ..
@[simp]
theorem getElem_attach {xs : Array α} {i : Nat} (h : i < xs.attach.size) :
xs.attach[i] = xs[i]'(by simpa using h), getElem_mem (by simpa using h) :=
getElem_attachWith h
theorem foldl_pmap (l : Array α) {P : α Prop} (f : (a : α) P a β)
(H : (a : α), a l P a) (g : γ β γ) (x : γ) :
(l.pmap f H).foldl g x = l.attach.foldl (fun acc a => g acc (f a.1 (H _ a.2))) x := by
rw [pmap_eq_map_attach, foldl_map]
theorem foldr_pmap (l : Array α) {P : α Prop} (f : (a : α) P a β)
(H : (a : α), a l P a) (g : β γ γ) (x : γ) :
(l.pmap f H).foldr g x = l.attach.foldr (fun a acc => g (f a.1 (H _ a.2)) acc) x := by
rw [pmap_eq_map_attach, foldr_map]
/--
If we fold over `l.attach` with a function that ignores the membership predicate,
we get the same results as folding over `l` directly.
This is useful when we need to use `attach` to show termination.
Unfortunately this can't be applied by `simp` because of the higher order unification problem,
and even when rewriting we need to specify the function explicitly.
See however `foldl_subtype` below.
-/
theorem foldl_attach (l : Array α) (f : β α β) (b : β) :
l.attach.foldl (fun acc t => f acc t.1) b = l.foldl f b := by
rcases l with l
simp only [List.attach_toArray, List.attachWith_mem_toArray, List.map_attach, size_toArray,
List.length_pmap, List.foldl_toArray', mem_toArray, List.foldl_subtype]
congr
ext
simpa using fun a => List.mem_of_getElem? a
/--
If we fold over `l.attach` with a function that ignores the membership predicate,
we get the same results as folding over `l` directly.
This is useful when we need to use `attach` to show termination.
Unfortunately this can't be applied by `simp` because of the higher order unification problem,
and even when rewriting we need to specify the function explicitly.
See however `foldr_subtype` below.
-/
theorem foldr_attach (l : Array α) (f : α β β) (b : β) :
l.attach.foldr (fun t acc => f t.1 acc) b = l.foldr f b := by
rcases l with l
simp only [List.attach_toArray, List.attachWith_mem_toArray, List.map_attach, size_toArray,
List.length_pmap, List.foldr_toArray', mem_toArray, List.foldr_subtype]
congr
ext
simpa using fun a => List.mem_of_getElem? a
theorem attach_map {l : Array α} (f : α β) :
(l.map f).attach = l.attach.map (fun x, h => f x, mem_map_of_mem f h) := by
cases l
ext <;> simp
theorem attachWith_map {l : Array α} (f : α β) {P : β Prop} {H : (b : β), b l.map f P b} :
(l.map f).attachWith P H = (l.attachWith (P f) (fun _ h => H _ (mem_map_of_mem f h))).map
fun x, h => f x, h := by
cases l
ext
· simp
· simp only [List.map_toArray, List.attachWith_toArray, List.getElem_toArray,
List.getElem_attachWith, List.getElem_map, Function.comp_apply]
erw [List.getElem_attachWith] -- Why is `erw` needed here?
theorem map_attachWith {l : Array α} {P : α Prop} {H : (a : α), a l P a}
(f : { x // P x } β) :
(l.attachWith P H).map f =
l.pmap (fun a (h : a l P a) => f a, H _ h.1) (fun a h => h, H a h) := by
cases l
ext <;> simp
/-- See also `pmap_eq_map_attach` for writing `pmap` in terms of `map` and `attach`. -/
theorem map_attach {l : Array α} (f : { x // x l } β) :
l.attach.map f = l.pmap (fun a h => f a, h) (fun _ => id) := by
cases l
ext <;> simp
theorem attach_filterMap {l : Array α} {f : α Option β} :
(l.filterMap f).attach = l.attach.filterMap
fun x, h => (f x).pbind (fun b m => some b, mem_filterMap.mpr x, h, m) := by
cases l
rw [attach_congr (List.filterMap_toArray f _)]
simp [List.attach_filterMap, List.map_filterMap, Function.comp_def]
theorem attach_filter {l : Array α} (p : α Bool) :
(l.filter p).attach = l.attach.filterMap
fun x => if w : p x.1 then some x.1, mem_filter.mpr x.2, w else none := by
cases l
rw [attach_congr (List.filter_toArray p _)]
simp [List.attach_filter, List.map_filterMap, Function.comp_def]
-- We are still missing here `attachWith_filterMap` and `attachWith_filter`.
-- Also missing are `filterMap_attach`, `filter_attach`, `filterMap_attachWith` and `filter_attachWith`.
theorem pmap_pmap {p : α Prop} {q : β Prop} (g : a, p a β) (f : b, q b γ) (l H₁ H₂) :
pmap f (pmap g l H₁) H₂ =
pmap (α := { x // x l }) (fun a h => f (g a h) (H₂ (g a h) (mem_pmap_of_mem a.2))) l.attach
(fun a _ => H₁ a a.2) := by
cases l
simp [List.pmap_pmap, List.pmap_map]
@[simp] theorem pmap_append {p : ι Prop} (f : a : ι, p a α) (l₁ l₂ : Array ι)
(h : a l₁ ++ l₂, p a) :
(l₁ ++ l₂).pmap f h =
(l₁.pmap f fun a ha => h a (mem_append_left l₂ ha)) ++
l₂.pmap f fun a ha => h a (mem_append_right l₁ ha) := by
cases l₁
cases l₂
simp
theorem pmap_append' {p : α Prop} (f : a : α, p a β) (l₁ l₂ : Array α)
(h₁ : a l₁, p a) (h₂ : a l₂, p a) :
((l₁ ++ l₂).pmap f fun a ha => (mem_append.1 ha).elim (h₁ a) (h₂ a)) =
l₁.pmap f h₁ ++ l₂.pmap f h₂ :=
pmap_append f l₁ l₂ _
@[simp] theorem attach_append (xs ys : Array α) :
(xs ++ ys).attach = xs.attach.map (fun x, h => x, mem_append_left ys h) ++
ys.attach.map fun x, h => x, mem_append_right xs h := by
cases xs
cases ys
rw [attach_congr (List.append_toArray _ _)]
simp [List.attach_append, Function.comp_def]
@[simp] theorem attachWith_append {P : α Prop} {xs ys : Array α}
{H : (a : α), a xs ++ ys P a} :
(xs ++ ys).attachWith P H = xs.attachWith P (fun a h => H a (mem_append_left ys h)) ++
ys.attachWith P (fun a h => H a (mem_append_right xs h)) := by
simp [attachWith, attach_append, map_pmap, pmap_append]
@[simp] theorem pmap_reverse {P : α Prop} (f : (a : α) P a β) (xs : Array α)
(H : (a : α), a xs.reverse P a) :
xs.reverse.pmap f H = (xs.pmap f (fun a h => H a (by simpa using h))).reverse := by
induction xs <;> simp_all
theorem reverse_pmap {P : α Prop} (f : (a : α) P a β) (xs : Array α)
(H : (a : α), a xs P a) :
(xs.pmap f H).reverse = xs.reverse.pmap f (fun a h => H a (by simpa using h)) := by
rw [pmap_reverse]
@[simp] theorem attachWith_reverse {P : α Prop} {xs : Array α}
{H : (a : α), a xs.reverse P a} :
xs.reverse.attachWith P H =
(xs.attachWith P (fun a h => H a (by simpa using h))).reverse := by
cases xs
simp
theorem reverse_attachWith {P : α Prop} {xs : Array α}
{H : (a : α), a xs P a} :
(xs.attachWith P H).reverse = (xs.reverse.attachWith P (fun a h => H a (by simpa using h))) := by
cases xs
simp
@[simp] theorem attach_reverse (xs : Array α) :
xs.reverse.attach = xs.attach.reverse.map fun x, h => x, by simpa using h := by
cases xs
rw [attach_congr (List.reverse_toArray _)]
simp
theorem reverse_attach (xs : Array α) :
xs.attach.reverse = xs.reverse.attach.map fun x, h => x, by simpa using h := by
cases xs
simp
@[simp] theorem back?_pmap {P : α Prop} (f : (a : α) P a β) (xs : Array α)
(H : (a : α), a xs P a) :
(xs.pmap f H).back? = xs.attach.back?.map fun a, m => f a (H a m) := by
cases xs
simp
@[simp] theorem back?_attachWith {P : α Prop} {xs : Array α}
{H : (a : α), a xs P a} :
(xs.attachWith P H).back? = xs.back?.pbind (fun a h => some a, H _ (mem_of_back?_eq_some h)) := by
cases xs
simp
@[simp]
theorem back?_attach {xs : Array α} :
xs.attach.back? = xs.back?.pbind fun a h => some a, mem_of_back?_eq_some h := by
cases xs
simp
/-! ## unattach
`Array.unattach` is the (one-sided) inverse of `Array.attach`. It is a synonym for `Array.map Subtype.val`.
@@ -128,6 +487,15 @@ def unattach {α : Type _} {p : α → Prop} (l : Array { x // p x }) := l.map (
cases l
simp
@[simp] theorem getElem?_unattach {p : α Prop} {l : Array { x // p x }} (i : Nat) :
l.unattach[i]? = l[i]?.map Subtype.val := by
simp [unattach]
@[simp] theorem getElem_unattach
{p : α Prop} {l : Array { x // p x }} (i : Nat) (h : i < l.unattach.size) :
l.unattach[i] = (l[i]'(by simpa using h)).1 := by
simp [unattach]
/-! ### Recognizing higher order functions using a function that only depends on the value. -/
/--

View File

@@ -233,7 +233,7 @@ def ofFn {n} (f : Fin n → α) : Array α := go 0 (mkEmpty n) where
/-- The array `#[0, 1, ..., n - 1]`. -/
def range (n : Nat) : Array Nat :=
n.fold (flip Array.push) (mkEmpty n)
ofFn fun (i : Fin n) => i
def singleton (v : α) : Array α :=
mkArray 1 v
@@ -613,8 +613,15 @@ def findIdx? {α : Type u} (p : α → Bool) (as : Array α) : Option Nat :=
decreasing_by simp_wf; decreasing_trivial_pre_omega
loop 0
def getIdx? [BEq α] (a : Array α) (v : α) : Option Nat :=
a.findIdx? fun a => a == v
@[inline]
def findFinIdx? {α : Type u} (p : α Bool) (as : Array α) : Option (Fin as.size) :=
let rec @[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
loop (j : Nat) :=
if h : j < as.size then
if p as[j] then some j, h else loop (j + 1)
else none
decreasing_by simp_wf; decreasing_trivial_pre_omega
loop 0
@[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
def indexOfAux [BEq α] (a : Array α) (v : α) (i : Nat) : Option (Fin a.size) :=
@@ -627,6 +634,10 @@ decreasing_by simp_wf; decreasing_trivial_pre_omega
def indexOf? [BEq α] (a : Array α) (v : α) : Option (Fin a.size) :=
indexOfAux a v 0
@[deprecated indexOf? (since := "2024-11-20")]
def getIdx? [BEq α] (a : Array α) (v : α) : Option Nat :=
a.findIdx? fun a => a == v
@[inline]
def any (as : Array α) (p : α Bool) (start := 0) (stop := as.size) : Bool :=
Id.run <| as.anyM p start stop
@@ -766,48 +777,62 @@ def takeWhile (p : α → Bool) (as : Array α) : Array α :=
decreasing_by simp_wf; decreasing_trivial_pre_omega
go 0 #[]
/-- Remove the element at a given index from an array without bounds checks, using a `Fin` index.
/--
Remove the element at a given index from an array without a runtime bounds checks,
using a `Nat` index and a tactic-provided bound.
This function takes worst case O(n) time because
it has to backshift all elements at positions greater than `i`.-/
This function takes worst case O(n) time because
it has to backshift all elements at positions greater than `i`.-/
@[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
def feraseIdx (a : Array α) (i : Fin a.size) : Array α :=
if h : i.val + 1 < a.size then
let a' := a.swap i.val + 1, h i
let i' : Fin a'.size := i.val + 1, by simp [a', h]
a'.feraseIdx i'
def eraseIdx (a : Array α) (i : Nat) (h : i < a.size := by get_elem_tactic) : Array α :=
if h' : i + 1 < a.size then
let a' := a.swap i + 1, h' i, h
a'.eraseIdx (i + 1) (by simp [a', h'])
else
a.pop
termination_by a.size - i.val
decreasing_by simp_wf; exact Nat.sub_succ_lt_self _ _ i.isLt
termination_by a.size - i
decreasing_by simp_wf; exact Nat.sub_succ_lt_self _ _ h
-- This is required in `Lean.Data.PersistentHashMap`.
@[simp] theorem size_feraseIdx (a : Array α) (i : Fin a.size) : (a.feraseIdx i).size = a.size - 1 := by
induction a, i using Array.feraseIdx.induct with
| @case1 a i h a' _ ih =>
unfold feraseIdx
simp [h, a', ih]
| case2 a i h =>
unfold feraseIdx
simp [h]
@[simp] theorem size_eraseIdx (a : Array α) (i : Nat) (h) : (a.eraseIdx i h).size = a.size - 1 := by
induction a, i, h using Array.eraseIdx.induct with
| @case1 a i h h' a' ih =>
unfold eraseIdx
simp [h', a', ih]
| case2 a i h h' =>
unfold eraseIdx
simp [h']
/-- Remove the element at a given index from an array, or do nothing if the index is out of bounds.
This function takes worst case O(n) time because
it has to backshift all elements at positions greater than `i`.-/
def eraseIdx (a : Array α) (i : Nat) : Array α :=
if h : i < a.size then a.feraseIdx i, h else a
def eraseIdxIfInBounds (a : Array α) (i : Nat) : Array α :=
if h : i < a.size then a.eraseIdx i h else a
/-- Remove the element at a given index from an array, or panic if the index is out of bounds.
This function takes worst case O(n) time because
it has to backshift all elements at positions greater than `i`. -/
def eraseIdx! (a : Array α) (i : Nat) : Array α :=
if h : i < a.size then a.eraseIdx i h else panic! "invalid index"
def erase [BEq α] (as : Array α) (a : α) : Array α :=
match as.indexOf? a with
| none => as
| some i => as.feraseIdx i
| some i => as.eraseIdx i
/-- Erase the first element that satisfies the predicate `p`. -/
def eraseP (as : Array α) (p : α Bool) : Array α :=
match as.findIdx? p with
| none => as
| some i => as.eraseIdxIfInBounds i
/-- Insert element `a` at position `i`. -/
@[inline] def insertAt (as : Array α) (i : Fin (as.size + 1)) (a : α) : Array α :=
@[inline] def insertIdx (as : Array α) (i : Nat) (a : α) (_ : i as.size := by get_elem_tactic) : Array α :=
let rec @[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
loop (as : Array α) (j : Fin as.size) :=
if i.1 < j then
if i < j then
let j' := j-1, Nat.lt_of_le_of_lt (Nat.pred_le _) j.2
let as := as.swap j' j
loop as j', by rw [size_swap]; exact j'.2
@@ -818,12 +843,23 @@ def erase [BEq α] (as : Array α) (a : α) : Array α :=
let as := as.push a
loop as j, size_push .. j.lt_succ_self
@[deprecated insertIdx (since := "2024-11-20")] abbrev insertAt := @insertIdx
/-- Insert element `a` at position `i`. Panics if `i` is not `i ≤ as.size`. -/
def insertAt! (as : Array α) (i : Nat) (a : α) : Array α :=
def insertIdx! (as : Array α) (i : Nat) (a : α) : Array α :=
if h : i as.size then
insertAt as i, Nat.lt_succ_of_le h a
insertIdx as i a
else panic! "invalid index"
@[deprecated insertIdx! (since := "2024-11-20")] abbrev insertAt! := @insertIdx!
/-- Insert element `a` at position `i`, or do nothing if `as.size < i`. -/
def insertIdxIfInBounds (as : Array α) (i : Nat) (a : α) : Array α :=
if h : i as.size then
insertIdx as i a
else
as
@[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
def isPrefixOfAux [BEq α] (as bs : Array α) (hle : as.size bs.size) (i : Nat) : Bool :=
if h : i < as.size then

View File

@@ -52,7 +52,7 @@ namespace Array
let mid := (lo + hi)/2
let midVal := as.get! mid
if lt midVal k then
if mid == lo then do let v add (); pure <| as.insertAt! (lo+1) v
if mid == lo then do let v add (); pure <| as.insertIdx! (lo+1) v
else binInsertAux lt merge add as k mid hi
else if lt k midVal then
binInsertAux lt merge add as k lo mid
@@ -67,7 +67,7 @@ namespace Array
(k : α) : m (Array α) :=
let _ := Inhabited.mk k
if as.isEmpty then do let v add (); pure <| as.push v
else if lt k (as.get! 0) then do let v add (); pure <| as.insertAt! 0 v
else if lt k (as.get! 0) then do let v add (); pure <| as.insertIdx! 0 v
else if !lt (as.get! 0) k then as.modifyM 0 <| merge
else if lt as.back! k then do let v add (); pure <| as.push v
else if !lt k as.back! then as.modifyM (as.size - 1) <| merge

View File

@@ -0,0 +1,281 @@
/-
Copyright (c) 2024 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Kim Morrison
-/
prelude
import Init.Data.List.Find
import Init.Data.Array.Lemmas
import Init.Data.Array.Attach
/-!
# Lemmas about `Array.findSome?`, `Array.find?`.
-/
namespace Array
open Nat
/-! ### findSome? -/
@[simp] theorem findSomeRev?_push_of_isSome (l : Array α) (h : (f a).isSome) : (l.push a).findSomeRev? f = f a := by
cases l; simp_all
@[simp] theorem findSomeRev?_push_of_isNone (l : Array α) (h : (f a).isNone) : (l.push a).findSomeRev? f = l.findSomeRev? f := by
cases l; simp_all
theorem exists_of_findSome?_eq_some {f : α Option β} {l : Array α} (w : l.findSome? f = some b) :
a, a l f a = b := by
cases l; simp_all [List.exists_of_findSome?_eq_some]
@[simp] theorem findSome?_eq_none_iff : findSome? p l = none x l, p x = none := by
cases l; simp
@[simp] theorem findSome?_isSome_iff {f : α Option β} {l : Array α} :
(l.findSome? f).isSome x, x l (f x).isSome := by
cases l; simp
theorem findSome?_eq_some_iff {f : α Option β} {l : Array α} {b : β} :
l.findSome? f = some b (l₁ : Array α) (a : α) (l₂ : Array α), l = l₁.push a ++ l₂ f a = some b x l₁, f x = none := by
cases l
simp only [List.findSome?_toArray, List.findSome?_eq_some_iff]
constructor
· rintro l₁, a, l₂, rfl, h₁, h₂
exact l₁.toArray, a, l₂.toArray, by simp_all
· rintro l₁, a, l₂, h₀, h₁, h₂
exact l₁.toList, a, l₂.toList, by simpa using congrArg toList h₀, h₁, by simpa
@[simp] theorem findSome?_guard (l : Array α) : findSome? (Option.guard fun x => p x) l = find? p l := by
cases l; simp
@[simp] theorem getElem?_zero_filterMap (f : α Option β) (l : Array α) : (l.filterMap f)[0]? = l.findSome? f := by
cases l; simp [ List.head?_eq_getElem?]
@[simp] theorem getElem_zero_filterMap (f : α Option β) (l : Array α) (h) :
(l.filterMap f)[0] = (l.findSome? f).get (by cases l; simpa [List.length_filterMap_eq_countP] using h) := by
cases l; simp [ List.head_eq_getElem, getElem?_zero_filterMap]
@[simp] theorem back?_filterMap (f : α Option β) (l : Array α) : (l.filterMap f).back? = l.findSomeRev? f := by
cases l; simp
@[simp] theorem back!_filterMap [Inhabited β] (f : α Option β) (l : Array α) :
(l.filterMap f).back! = (l.findSomeRev? f).getD default := by
cases l; simp
@[simp] theorem map_findSome? (f : α Option β) (g : β γ) (l : Array α) :
(l.findSome? f).map g = l.findSome? (Option.map g f) := by
cases l; simp
theorem findSome?_map (f : β γ) (l : Array β) : findSome? p (l.map f) = l.findSome? (p f) := by
cases l; simp [List.findSome?_map]
theorem findSome?_append {l₁ l₂ : Array α} : (l₁ ++ l₂).findSome? f = (l₁.findSome? f).or (l₂.findSome? f) := by
cases l₁; cases l₂; simp [List.findSome?_append]
theorem getElem?_zero_flatten (L : Array (Array α)) :
(flatten L)[0]? = L.findSome? fun l => l[0]? := by
cases L using array_array_induction
simp [ List.head?_eq_getElem?, List.head?_flatten, List.findSome?_map, Function.comp_def]
theorem getElem_zero_flatten.proof {L : Array (Array α)} (h : 0 < L.flatten.size) :
(L.findSome? fun l => l[0]?).isSome := by
cases L using array_array_induction
simp only [List.findSome?_toArray, List.findSome?_map, Function.comp_def, List.getElem?_toArray,
List.findSome?_isSome_iff, List.isSome_getElem?]
simp only [flatten_toArray_map_toArray, size_toArray, List.length_flatten,
Nat.sum_pos_iff_exists_pos, List.mem_map] at h
obtain _, xs, m, rfl, h := h
exact xs, m, by simpa using h
theorem getElem_zero_flatten {L : Array (Array α)} (h) :
(flatten L)[0] = (L.findSome? fun l => l[0]?).get (getElem_zero_flatten.proof h) := by
have t := getElem?_zero_flatten L
simp [getElem?_eq_getElem, h] at t
simp [ t]
theorem back?_flatten {L : Array (Array α)} :
(flatten L).back? = (L.findSomeRev? fun l => l.back?) := by
cases L using array_array_induction
simp [List.getLast?_flatten, List.map_reverse, List.findSome?_map, Function.comp_def]
theorem findSome?_mkArray : findSome? f (mkArray n a) = if n = 0 then none else f a := by
simp [mkArray_eq_toArray_replicate, List.findSome?_replicate]
@[simp] theorem findSome?_mkArray_of_pos (h : 0 < n) : findSome? f (mkArray n a) = f a := by
simp [findSome?_mkArray, Nat.ne_of_gt h]
-- Argument is unused, but used to decide whether `simp` should unfold.
@[simp] theorem findSome?_mkArray_of_isSome (_ : (f a).isSome) :
findSome? f (mkArray n a) = if n = 0 then none else f a := by
simp [findSome?_mkArray]
@[simp] theorem findSome?_mkArray_of_isNone (h : (f a).isNone) :
findSome? f (mkArray n a) = none := by
rw [Option.isNone_iff_eq_none] at h
simp [findSome?_mkArray, h]
/-! ### find? -/
@[simp] theorem find?_singleton (a : α) (p : α Bool) :
#[a].find? p = if p a then some a else none := by
simp [singleton_eq_toArray_singleton]
@[simp] theorem findRev?_push_of_pos (l : Array α) (h : p a) :
findRev? p (l.push a) = some a := by
cases l; simp [h]
@[simp] theorem findRev?_cons_of_neg (l : Array α) (h : ¬p a) :
findRev? p (l.push a) = findRev? p l := by
cases l; simp [h]
@[simp] theorem find?_eq_none : find? p l = none x l, ¬ p x := by
cases l; simp
theorem find?_eq_some_iff_append {xs : Array α} :
xs.find? p = some b p b (as bs : Array α), xs = as.push b ++ bs a as, !p a := by
rcases xs with xs
simp only [List.find?_toArray, List.find?_eq_some_iff_append, Bool.not_eq_eq_eq_not,
Bool.not_true, exists_and_right, and_congr_right_iff]
intro w
constructor
· rintro as, x, rfl, h
exact as.toArray, x.toArray, by simp , by simpa using h
· rintro as, x, h', h
exact as.toList, x.toList, by simpa using congrArg Array.toList h',
by simpa using h
@[simp]
theorem find?_push_eq_some {xs : Array α} :
(xs.push a).find? p = some b xs.find? p = some b (xs.find? p = none (p a a = b)) := by
cases xs; simp
@[simp] theorem find?_isSome {xs : Array α} {p : α Bool} : (xs.find? p).isSome x, x xs p x := by
cases xs; simp
theorem find?_some {xs : Array α} (h : find? p xs = some a) : p a := by
cases xs
simp at h
exact List.find?_some h
theorem mem_of_find?_eq_some {xs : Array α} (h : find? p xs = some a) : a xs := by
cases xs
simp at h
simpa using List.mem_of_find?_eq_some h
theorem get_find?_mem {xs : Array α} (h) : (xs.find? p).get h xs := by
cases xs
simp [List.get_find?_mem]
@[simp] theorem find?_filter {xs : Array α} (p q : α Bool) :
(xs.filter p).find? q = xs.find? (fun a => p a q a) := by
cases xs; simp
@[simp] theorem getElem?_zero_filter (p : α Bool) (l : Array α) :
(l.filter p)[0]? = l.find? p := by
cases l; simp [ List.head?_eq_getElem?]
@[simp] theorem getElem_zero_filter (p : α Bool) (l : Array α) (h) :
(l.filter p)[0] =
(l.find? p).get (by cases l; simpa [ List.countP_eq_length_filter] using h) := by
cases l
simp [List.getElem_zero_eq_head]
@[simp] theorem back?_filter (p : α Bool) (l : Array α) : (l.filter p).back? = l.findRev? p := by
cases l; simp
@[simp] theorem back!_filter [Inhabited α] (p : α Bool) (l : Array α) :
(l.filter p).back! = (l.findRev? p).get! := by
cases l; simp [Option.get!_eq_getD]
@[simp] theorem find?_filterMap (xs : Array α) (f : α Option β) (p : β Bool) :
(xs.filterMap f).find? p = (xs.find? (fun a => (f a).any p)).bind f := by
cases xs; simp
@[simp] theorem find?_map (f : β α) (xs : Array β) :
find? p (xs.map f) = (xs.find? (p f)).map f := by
cases xs; simp
@[simp] theorem find?_append {l₁ l₂ : Array α} :
(l₁ ++ l₂).find? p = (l₁.find? p).or (l₂.find? p) := by
cases l₁
cases l₂
simp
@[simp] theorem find?_flatten (xs : Array (Array α)) (p : α Bool) :
xs.flatten.find? p = xs.findSome? (·.find? p) := by
cases xs using array_array_induction
simp [List.findSome?_map, Function.comp_def]
theorem find?_flatten_eq_none {xs : Array (Array α)} {p : α Bool} :
xs.flatten.find? p = none ys xs, x ys, !p x := by
simp
/--
If `find? p` returns `some a` from `xs.flatten`, then `p a` holds, and
some array in `xs` contains `a`, and no earlier element of that array satisfies `p`.
Moreover, no earlier array in `xs` has an element satisfying `p`.
-/
theorem find?_flatten_eq_some {xs : Array (Array α)} {p : α Bool} {a : α} :
xs.flatten.find? p = some a
p a (as : Array (Array α)) (ys zs : Array α) (bs : Array (Array α)),
xs = as.push (ys.push a ++ zs) ++ bs
( a as, x a, !p x) ( x ys, !p x) := by
cases xs using array_array_induction
simp only [flatten_toArray_map_toArray, List.find?_toArray, List.find?_flatten_eq_some]
simp only [Bool.not_eq_eq_eq_not, Bool.not_true, exists_and_right, and_congr_right_iff]
intro w
constructor
· rintro as, ys, zs, bs, rfl, h₁, h₂
exact as.toArray.map List.toArray, ys.toArray,
zs.toArray, bs.toArray.map List.toArray, by simp, by simpa using h₁, by simpa using h₂
· rintro as, ys, zs, bs, h, h₁, h₂
replace h := congrArg (·.map Array.toList) (congrArg Array.toList h)
simp [Function.comp_def] at h
exact as.toList.map Array.toList, ys.toList,
zs.toList, bs.toList.map Array.toList, by simpa using h,
by simpa using h₁, by simpa using h₂
@[simp] theorem find?_flatMap (xs : Array α) (f : α Array β) (p : β Bool) :
(xs.flatMap f).find? p = xs.findSome? (fun x => (f x).find? p) := by
cases xs
simp [List.find?_flatMap, Array.flatMap_toArray]
theorem find?_flatMap_eq_none {xs : Array α} {f : α Array β} {p : β Bool} :
(xs.flatMap f).find? p = none x xs, y f x, !p y := by
simp
theorem find?_mkArray :
find? p (mkArray n a) = if n = 0 then none else if p a then some a else none := by
simp [mkArray_eq_toArray_replicate, List.find?_replicate]
@[simp] theorem find?_mkArray_of_length_pos (h : 0 < n) :
find? p (mkArray n a) = if p a then some a else none := by
simp [find?_mkArray, Nat.ne_of_gt h]
@[simp] theorem find?_mkArray_of_pos (h : p a) :
find? p (mkArray n a) = if n = 0 then none else some a := by
simp [find?_mkArray, h]
@[simp] theorem find?_mkArray_of_neg (h : ¬ p a) : find? p (mkArray n a) = none := by
simp [find?_mkArray, h]
-- This isn't a `@[simp]` lemma since there is already a lemma for `l.find? p = none` for any `l`.
theorem find?_mkArray_eq_none {n : Nat} {a : α} {p : α Bool} :
(mkArray n a).find? p = none n = 0 !p a := by
simp [mkArray_eq_toArray_replicate, List.find?_replicate_eq_none, Classical.or_iff_not_imp_left]
@[simp] theorem find?_mkArray_eq_some {n : Nat} {a b : α} {p : α Bool} :
(mkArray n a).find? p = some b n 0 p a a = b := by
simp [mkArray_eq_toArray_replicate]
@[simp] theorem get_find?_mkArray (n : Nat) (a : α) (p : α Bool) (h) :
((mkArray n a).find? p).get h = a := by
simp [mkArray_eq_toArray_replicate]
theorem find?_pmap {P : α Prop} (f : (a : α) P a β) (xs : Array α)
(H : (a : α), a xs P a) (p : β Bool) :
(xs.pmap f H).find? p = (xs.attach.find? (fun a, m => p (f a (H a m)))).map fun a, m => f a (H a m) := by
simp only [pmap_eq_map_attach, find?_map]
rfl
end Array

View File

@@ -23,6 +23,9 @@ import Init.TacticsExtra
namespace Array
@[simp] theorem mem_toArray {a : α} {l : List α} : a l.toArray a l := by
simp [mem_def]
@[simp] theorem getElem_mk {xs : List α} {i : Nat} (h : i < xs.length) : (Array.mk xs)[i] = xs[i] := rfl
theorem getElem_eq_getElem_toList {a : Array α} (h : i < a.size) : a[i] = a.toList[i] := rfl
@@ -36,12 +39,21 @@ theorem getElem?_eq_getElem {a : Array α} {i : Nat} (h : i < a.size) : a[i]? =
· rw [getElem?_neg a i h]
simp_all
@[simp] theorem none_eq_getElem?_iff {a : Array α} {i : Nat} : none = a[i]? a.size i := by
simp [eq_comm (a := none)]
theorem getElem?_eq {a : Array α} {i : Nat} :
a[i]? = if h : i < a.size then some a[i] else none := by
split
· simp_all [getElem?_eq_getElem]
· simp_all
theorem getElem?_eq_some_iff {a : Array α} : a[i]? = some b h : i < a.size, a[i] = b := by
simp [getElem?_eq]
theorem some_eq_getElem?_iff {a : Array α} : some b = a[i]? h : i < a.size, a[i] = b := by
rw [eq_comm, getElem?_eq_some_iff]
theorem getElem?_eq_getElem?_toList (a : Array α) (i : Nat) : a[i]? = a.toList[i]? := by
rw [getElem?_eq]
split <;> simp_all
@@ -66,6 +78,35 @@ theorem getElem_push (a : Array α) (x : α) (i : Nat) (h : i < (a.push x).size)
@[deprecated getElem_push_lt (since := "2024-10-21")] abbrev get_push_lt := @getElem_push_lt
@[deprecated getElem_push_eq (since := "2024-10-21")] abbrev get_push_eq := @getElem_push_eq
@[simp] theorem mem_push {a : Array α} {x y : α} : x a.push y x a x = y := by
simp [mem_def]
theorem mem_push_self {a : Array α} {x : α} : x a.push x :=
mem_push.2 (Or.inr rfl)
theorem mem_push_of_mem {a : Array α} {x : α} (y : α) (h : x a) : x a.push y :=
mem_push.2 (Or.inl h)
theorem getElem_of_mem {a} {l : Array α} (h : a l) : (n : Nat) (h : n < l.size), l[n]'h = a := by
cases l
simp [List.getElem_of_mem (by simpa using h)]
theorem getElem?_of_mem {a} {l : Array α} (h : a l) : n : Nat, l[n]? = some a :=
let n, _, e := getElem_of_mem h; n, e getElem?_eq_getElem _
theorem mem_of_getElem? {l : Array α} {n : Nat} {a : α} (e : l[n]? = some a) : a l :=
let _, e := getElem?_eq_some_iff.1 e; e getElem_mem ..
theorem mem_iff_getElem {a} {l : Array α} : a l (n : Nat) (h : n < l.size), l[n]'h = a :=
getElem_of_mem, fun _, _, e => e getElem_mem ..
theorem mem_iff_getElem? {a} {l : Array α} : a l n : Nat, l[n]? = some a := by
simp [getElem?_eq_some_iff, mem_iff_getElem]
theorem forall_getElem {l : Array α} {p : α Prop} :
( (n : Nat) h, p (l[n]'h)) a, a l p a := by
cases l; simp [List.forall_getElem]
@[simp] theorem get!_eq_getElem! [Inhabited α] (a : Array α) (i : Nat) : a.get! i = a[i]! := by
simp [getElem!_def, get!, getD]
split <;> rename_i h
@@ -93,9 +134,6 @@ We prefer to pull `List.toArray` outwards.
(a.toArrayAux b).size = b.size + a.length := by
simp [size]
@[simp] theorem mem_toArray {a : α} {l : List α} : a l.toArray a l := by
simp [mem_def]
@[simp] theorem push_toArray (l : List α) (a : α) : l.toArray.push a = (l ++ [a]).toArray := by
apply ext'
simp
@@ -475,10 +513,10 @@ theorem getElem?_len_le (a : Array α) {i : Nat} (h : a.size ≤ i) : a[i]? = no
theorem getD_get? (a : Array α) (i : Nat) (d : α) :
Option.getD a[i]? d = if p : i < a.size then a[i]'p else d := by
if h : i < a.size then
simp [setD, h, getElem?_def]
simp [setIfInBounds, h, getElem?_def]
else
have p : i a.size := Nat.le_of_not_gt h
simp [setD, getElem?_len_le _ p, h]
simp [setIfInBounds, getElem?_len_le _ p, h]
@[simp] theorem getD_eq_get? (a : Array α) (n d) : a.getD n d = (a[n]?).getD d := by
simp only [getD, get_eq_getElem, get?_eq_getElem?]; split <;> simp [getD_get?, *]
@@ -514,31 +552,32 @@ theorem getElem_set (a : Array α) (i : Nat) (h' : i < a.size) (v : α) (j : Nat
(ne : i j) : (a.set i v)[j]? = a[j]? := by
by_cases h : j < a.size <;> simp [getElem?_lt, getElem?_ge, Nat.ge_of_not_lt, ne, h]
/-! # setD -/
/-! # setIfInBounds -/
@[simp] theorem set!_is_setD : @set! = @setD := rfl
@[simp] theorem set!_is_setIfInBounds : @set! = @setIfInBounds := rfl
@[simp] theorem size_setD (a : Array α) (index : Nat) (val : α) :
(Array.setD a index val).size = a.size := by
@[simp] theorem size_setIfInBounds (a : Array α) (index : Nat) (val : α) :
(Array.setIfInBounds a index val).size = a.size := by
if h : index < a.size then
simp [setD, h]
simp [setIfInBounds, h]
else
simp [setD, h]
simp [setIfInBounds, h]
@[simp] theorem getElem_setD_eq (a : Array α) {i : Nat} (v : α) (h : _) :
(setD a i v)[i]'h = v := by
@[simp] theorem getElem_setIfInBounds_eq (a : Array α) {i : Nat} (v : α) (h : _) :
(setIfInBounds a i v)[i]'h = v := by
simp at h
simp only [setD, h, reduceDIte, getElem_set_eq]
simp only [setIfInBounds, h, reduceDIte, getElem_set_eq]
@[simp]
theorem getElem?_setD_eq (a : Array α) {i : Nat} (p : i < a.size) (v : α) : (a.setD i v)[i]? = some v := by
theorem getElem?_setIfInBounds_eq (a : Array α) {i : Nat} (p : i < a.size) (v : α) :
(a.setIfInBounds i v)[i]? = some v := by
simp [getElem?_lt, p]
/-- Simplifies a normal form from `get!` -/
@[simp] theorem getD_get?_setD (a : Array α) (i : Nat) (v d : α) :
Option.getD (setD a i v)[i]? d = if i < a.size then v else d := by
@[simp] theorem getD_get?_setIfInBounds (a : Array α) (i : Nat) (v d : α) :
Option.getD (setIfInBounds a i v)[i]? d = if i < a.size then v else d := by
by_cases h : i < a.size <;>
simp [setD, Nat.not_lt_of_le, h, getD_get?]
simp [setIfInBounds, Nat.not_lt_of_le, h, getD_get?]
/-! # ofFn -/
@@ -583,7 +622,20 @@ theorem getElem?_ofFn (f : Fin n → α) (i : Nat) :
(ofFn f)[i]? = if h : i < n then some (f i, h) else none := by
simp [getElem?_def]
/-- # mkArray -/
@[simp] theorem ofFn_zero (f : Fin 0 α) : ofFn f = #[] := rfl
theorem ofFn_succ (f : Fin (n+1) α) :
ofFn f = (ofFn (fun (i : Fin n) => f i.castSucc)).push (f n, by omega) := by
ext i h₁ h₂
· simp
· simp [getElem_push]
split <;> rename_i h₃
· rfl
· congr
simp at h₁ h₂
omega
/-! # mkArray -/
@[simp] theorem size_mkArray (n : Nat) (v : α) : (mkArray n v).size = n :=
List.length_replicate ..
@@ -599,42 +651,29 @@ theorem getElem?_mkArray (n : Nat) (v : α) (i : Nat) :
(mkArray n v)[i]? = if i < n then some v else none := by
simp [getElem?_def]
/-- # mem -/
/-! # mem -/
theorem mem_toList {a : α} {l : Array α} : a l.toList a l := mem_def.symm
@[simp] theorem mem_toList {a : α} {l : Array α} : a l.toList a l := mem_def.symm
theorem not_mem_nil (a : α) : ¬ a #[] := nofun
theorem getElem_of_mem {a : α} {as : Array α} :
a as ( (n : Nat) (h : n < as.size), as[n]'h = a) := by
intro ha
rcases List.getElem_of_mem ha.val with i, hbound, hi
exists i
exists hbound
theorem getElem?_of_mem {a : α} {as : Array α} :
a as (n : Nat), as[n]? = some a := by
intro ha
rcases List.getElem?_of_mem ha.val with i, hi
exists i
@[simp] theorem mem_dite_empty_left {x : α} [Decidable p] {l : ¬ p Array α} :
(x if h : p then #[] else l h) h : ¬ p, x l h := by
split <;> simp_all [mem_def]
split <;> simp_all
@[simp] theorem mem_dite_empty_right {x : α} [Decidable p] {l : p Array α} :
(x if h : p then l h else #[]) h : p, x l h := by
split <;> simp_all [mem_def]
split <;> simp_all
@[simp] theorem mem_ite_empty_left {x : α} [Decidable p] {l : Array α} :
(x if p then #[] else l) ¬ p x l := by
split <;> simp_all [mem_def]
split <;> simp_all
@[simp] theorem mem_ite_empty_right {x : α} [Decidable p] {l : Array α} :
(x if p then l else #[]) p x l := by
split <;> simp_all [mem_def]
split <;> simp_all
/-- # get lemmas -/
/-! # get lemmas -/
theorem lt_of_getElem {x : α} {a : Array α} {idx : Nat} {hidx : idx < a.size} (_ : a[idx] = x) :
idx < a.size :=
@@ -659,10 +698,6 @@ theorem get?_eq_get?_toList (a : Array α) (i : Nat) : a.get? i = a.toList.get?
theorem get!_eq_get? [Inhabited α] (a : Array α) : a.get! n = (a.get? n).getD default := by
simp only [get!_eq_getElem?, get?_eq_getElem?]
theorem getElem?_eq_some_iff {as : Array α} : as[n]? = some a h : n < as.size, as[n] = a := by
cases as
simp [List.getElem?_eq_some_iff]
theorem back!_eq_back? [Inhabited α] (a : Array α) : a.back! = a.back?.getD default := by
simp only [back!, get!_eq_getElem?, get?_eq_getElem?, back?]
@@ -672,6 +707,10 @@ theorem back!_eq_back? [Inhabited α] (a : Array α) : a.back! = a.back?.getD de
@[simp] theorem back!_push [Inhabited α] (a : Array α) : (a.push x).back! = x := by
simp [back!_eq_back?]
theorem mem_of_back?_eq_some {xs : Array α} {a : α} (h : xs.back? = some a) : a xs := by
cases xs
simpa using List.mem_of_getLast?_eq_some (by simpa using h)
theorem getElem?_push_lt (a : Array α) (x : α) (i : Nat) (h : i < a.size) :
(a.push x)[i]? = some a[i] := by
rw [getElem?_pos, getElem_push_lt]
@@ -730,10 +769,10 @@ theorem get_set (a : Array α) (i : Nat) (hi : i < a.size) (j : Nat) (hj : j < a
(h : i j) : (a.set i v)[j]'(by simp [*]) = a[j] := by
simp only [set, getElem_eq_getElem_toList, List.getElem_set_ne h]
theorem getElem_setD (a : Array α) (i : Nat) (v : α) (h : i < (setD a i v).size) :
(setD a i v)[i] = v := by
theorem getElem_setIfInBounds (a : Array α) (i : Nat) (v : α) (h : i < (setIfInBounds a i v).size) :
(setIfInBounds a i v)[i] = v := by
simp at h
simp only [setD, h, reduceDIte, getElem_set_eq]
simp only [setIfInBounds, h, reduceDIte, getElem_set_eq]
theorem set_set (a : Array α) (i : Nat) (h) (v v' : α) :
(a.set i v h).set i v' (by simp [h]) = a.set i v' := by simp [set, List.set_set]
@@ -815,16 +854,10 @@ theorem size_eq_length_toList (as : Array α) : as.size = as.toList.length := rf
simp only [reverse]; split <;> simp [go]
@[simp] theorem size_range {n : Nat} : (range n).size = n := by
unfold range
induction n with
| zero => simp [Nat.fold]
| succ k ih =>
rw [Nat.fold, flip]
simp only [mkEmpty_eq, size_push] at *
omega
induction n <;> simp [range]
@[simp] theorem toList_range (n : Nat) : (range n).toList = List.range n := by
induction n <;> simp_all [range, Nat.fold, flip, List.range_succ]
apply List.ext_getElem <;> simp [range]
@[simp]
theorem getElem_range {n : Nat} {x : Nat} (h : x < (Array.range n).size) : (Array.range n)[x] = x := by
@@ -1025,6 +1058,10 @@ theorem foldr_congr {as bs : Array α} (h₀ : as = bs) {f g : α → β → β}
@[simp] theorem mem_map {f : α β} {l : Array α} : b l.map f a, a l f a = b := by
simp only [mem_def, toList_map, List.mem_map]
theorem exists_of_mem_map (h : b map f l) : a, a l f a = b := mem_map.1 h
theorem mem_map_of_mem (f : α β) (h : a l) : f a map f l := mem_map.2 _, h, rfl
theorem mapM_eq_mapM_toList [Monad m] [LawfulMonad m] (f : α m β) (arr : Array α) :
arr.mapM f = List.toArray <$> (arr.toList.mapM f) := by
rw [mapM_eq_foldlM, foldlM_toList, List.foldrM_reverse]
@@ -1215,9 +1252,23 @@ theorem push_eq_append_singleton (as : Array α) (x) : as.push x = as ++ #[x] :=
@[simp] theorem mem_append {a : α} {s t : Array α} : a s ++ t a s a t := by
simp only [mem_def, toList_append, List.mem_append]
theorem mem_append_left {a : α} {l₁ : Array α} (l₂ : Array α) (h : a l₁) : a l₁ ++ l₂ :=
mem_append.2 (Or.inl h)
theorem mem_append_right {a : α} (l₁ : Array α) {l₂ : Array α} (h : a l₂) : a l₁ ++ l₂ :=
mem_append.2 (Or.inr h)
@[simp] theorem size_append (as bs : Array α) : (as ++ bs).size = as.size + bs.size := by
simp only [size, toList_append, List.length_append]
@[simp] theorem empty_append (as : Array α) : #[] ++ as = as := by
cases as
simp
@[simp] theorem append_empty (as : Array α) : as ++ #[] = as := by
cases as
simp
theorem getElem_append {as bs : Array α} (h : i < (as ++ bs).size) :
(as ++ bs)[i] = if h' : i < as.size then as[i] else bs[i - as.size]'(by simp at h; omega) := by
cases as; cases bs
@@ -1594,9 +1645,9 @@ theorem swap_comm (a : Array α) {i j : Fin a.size} : a.swap i j = a.swap j i :=
/-! ### eraseIdx -/
theorem feraseIdx_eq_eraseIdx {a : Array α} {i : Fin a.size} :
a.feraseIdx i = a.eraseIdx i.1 := by
simp [eraseIdx]
theorem eraseIdx_eq_eraseIdxIfInBounds {a : Array α} {i : Nat} (h : i < a.size) :
a.eraseIdx i h = a.eraseIdxIfInBounds i := by
simp [eraseIdxIfInBounds, h]
/-! ### isPrefixOf -/
@@ -1692,10 +1743,10 @@ Our goal is to have `simp` "pull `List.toArray` outwards" as much as possible.
apply ext'
simp
@[simp] theorem setD_toArray (l : List α) (i : Nat) (a : α) :
l.toArray.setD i a = (l.set i a).toArray := by
@[simp] theorem setIfInBounds_toArray (l : List α) (i : Nat) (a : α) :
l.toArray.setIfInBounds i a = (l.set i a).toArray := by
apply ext'
simp only [setD]
simp only [setIfInBounds]
split
· simp
· simp_all [List.set_eq_of_length_le]
@@ -1826,16 +1877,15 @@ theorem takeWhile_go_toArray (p : α → Bool) (l : List α) (i : Nat) :
l.toArray.takeWhile p = (l.takeWhile p).toArray := by
simp [Array.takeWhile, takeWhile_go_toArray]
@[simp] theorem feraseIdx_toArray (l : List α) (i : Fin l.toArray.size) :
l.toArray.feraseIdx i = (l.eraseIdx i).toArray := by
rw [feraseIdx]
split <;> rename_i h
· rw [feraseIdx_toArray]
@[simp] theorem eraseIdx_toArray (l : List α) (i : Nat) (h : i < l.toArray.size) :
l.toArray.eraseIdx i h = (l.eraseIdx i).toArray := by
rw [Array.eraseIdx]
split <;> rename_i h'
· rw [eraseIdx_toArray]
simp only [swap_toArray, Fin.getElem_fin, toList_toArray, mk.injEq]
rw [eraseIdx_set_gt (by simp), eraseIdx_set_eq]
simp
· rcases i with i, w
simp at h w
· simp at h h'
have t : i = l.length - 1 := by omega
simp [t]
termination_by l.length - i
@@ -1845,9 +1895,9 @@ decreasing_by
simp
omega
@[simp] theorem eraseIdx_toArray (l : List α) (i : Nat) :
l.toArray.eraseIdx i = (l.eraseIdx i).toArray := by
rw [Array.eraseIdx]
@[simp] theorem eraseIdxIfInBounds_toArray (l : List α) (i : Nat) :
l.toArray.eraseIdxIfInBounds i = (l.eraseIdx i).toArray := by
rw [Array.eraseIdxIfInBounds]
split
· simp
· simp_all [eraseIdx_eq_self.2]
@@ -1866,16 +1916,86 @@ namespace Array
(as.takeWhile p).toList = as.toList.takeWhile p := by
induction as; simp
@[simp] theorem toList_feraseIdx (as : Array α) (i : Fin as.size) :
(as.feraseIdx i).toList = as.toList.eraseIdx i.1 := by
@[simp] theorem toList_eraseIdx (as : Array α) (i : Nat) (h : i < as.size) :
(as.eraseIdx i h).toList = as.toList.eraseIdx i := by
induction as
simp
@[simp] theorem toList_eraseIdx (as : Array α) (i : Nat) :
(as.eraseIdx i).toList = as.toList.eraseIdx i := by
@[simp] theorem toList_eraseIdxIfInBounds (as : Array α) (i : Nat) :
(as.eraseIdxIfInBounds i).toList = as.toList.eraseIdx i := by
induction as
simp
/-! ### map -/
@[simp] theorem map_map {f : α β} {g : β γ} {as : Array α} :
(as.map f).map g = as.map (g f) := by
cases as; simp
@[simp] theorem map_id_fun : map (id : α α) = id := by
funext l
induction l <;> simp_all
/-- `map_id_fun'` differs from `map_id_fun` by representing the identity function as a lambda, rather than `id`. -/
@[simp] theorem map_id_fun' : map (fun (a : α) => a) = id := map_id_fun
-- This is not a `@[simp]` lemma because `map_id_fun` will apply.
theorem map_id (as : Array α) : map (id : α α) as = as := by
cases as <;> simp_all
/-- `map_id'` differs from `map_id` by representing the identity function as a lambda, rather than `id`. -/
-- This is not a `@[simp]` lemma because `map_id_fun'` will apply.
theorem map_id' (as : Array α) : map (fun (a : α) => a) as = as := map_id as
/-- Variant of `map_id`, with a side condition that the function is pointwise the identity. -/
theorem map_id'' {f : α α} (h : x, f x = x) (as : Array α) : map f as = as := by
simp [show f = id from funext h]
theorem array_array_induction (P : Array (Array α) Prop) (h : (xss : List (List α)), P (xss.map List.toArray).toArray)
(ass : Array (Array α)) : P ass := by
specialize h (ass.toList.map toList)
simpa [ toList_map, Function.comp_def, map_id] using h
theorem foldl_map (f : β₁ β₂) (g : α β₂ α) (l : Array β₁) (init : α) :
(l.map f).foldl g init = l.foldl (fun x y => g x (f y)) init := by
cases l; simp [List.foldl_map]
theorem foldr_map (f : α₁ α₂) (g : α₂ β β) (l : Array α₁) (init : β) :
(l.map f).foldr g init = l.foldr (fun x y => g (f x) y) init := by
cases l; simp [List.foldr_map]
theorem foldl_filterMap (f : α Option β) (g : γ β γ) (l : Array α) (init : γ) :
(l.filterMap f).foldl g init = l.foldl (fun x y => match f y with | some b => g x b | none => x) init := by
cases l
simp [List.foldl_filterMap]
rfl
theorem foldr_filterMap (f : α Option β) (g : β γ γ) (l : Array α) (init : γ) :
(l.filterMap f).foldr g init = l.foldr (fun x y => match f x with | some b => g b y | none => y) init := by
cases l
simp [List.foldr_filterMap]
rfl
/-! ### flatten -/
@[simp] theorem flatten_empty : flatten (#[] : Array (Array α)) = #[] := rfl
@[simp] theorem flatten_toArray_map_toArray (xss : List (List α)) :
(xss.map List.toArray).toArray.flatten = xss.flatten.toArray := by
simp [flatten]
suffices as, List.foldl (fun r a => r ++ a) as (List.map List.toArray xss) = as ++ xss.flatten.toArray by
simpa using this #[]
intro as
induction xss generalizing as with
| nil => simp
| cons xs xss ih => simp [ih]
/-! ### reverse -/
@[simp] theorem mem_reverse {x : α} {as : Array α} : x as.reverse x as := by
cases as
simp
/-! ### findSomeRevM?, findRevM?, findSomeRev?, findRev? -/
@[simp] theorem findSomeRevM?_eq_findSomeM?_reverse
@@ -1940,6 +2060,27 @@ namespace Array
cases as
simp
@[simp] theorem flatMap_empty {β} (f : α Array β) : (#[] : Array α).flatMap f = #[] := rfl
@[simp] theorem flatMap_toArray_cons {β} (f : α Array β) (a : α) (as : List α) :
(a :: as).toArray.flatMap f = f a ++ as.toArray.flatMap f := by
simp [flatMap]
suffices cs, List.foldl (fun bs a => bs ++ f a) (f a ++ cs) as =
f a ++ List.foldl (fun bs a => bs ++ f a) cs as by
erw [empty_append] -- Why doesn't this work via `simp`?
simpa using this #[]
intro cs
induction as generalizing cs <;> simp_all
@[simp] theorem flatMap_toArray {β} (f : α Array β) (as : List α) :
as.toArray.flatMap f = (as.flatMap (fun a => (f a).toList)).toArray := by
induction as with
| nil => simp
| cons a as ih =>
apply ext'
simp [ih]
end Array
/-! ### Deprecations -/
@@ -1957,6 +2098,8 @@ theorem toArray_concat {as : List α} {x : α} :
@[deprecated back!_toArray (since := "2024-10-31")] abbrev back_toArray := @back!_toArray
@[deprecated setIfInBounds_toArray (since := "2024-11-24")] abbrev setD_toArray := @setIfInBounds_toArray
end List
namespace Array
@@ -2102,4 +2245,11 @@ abbrev get_swap' := @getElem_swap'
@[deprecated eq_push_pop_back!_of_size_ne_zero (since := "2024-10-31")]
abbrev eq_push_pop_back_of_size_ne_zero := @eq_push_pop_back!_of_size_ne_zero
@[deprecated set!_is_setIfInBounds (since := "2024-11-24")] abbrev set_is_setIfInBounds := @set!_is_setIfInBounds
@[deprecated size_setIfInBounds (since := "2024-11-24")] abbrev size_setD := @size_setIfInBounds
@[deprecated getElem_setIfInBounds_eq (since := "2024-11-24")] abbrev getElem_setD_eq := @getElem_setIfInBounds_eq
@[deprecated getElem?_setIfInBounds_eq (since := "2024-11-24")] abbrev get?_setD_eq := @getElem?_setIfInBounds_eq
@[deprecated getD_get?_setIfInBounds (since := "2024-11-24")] abbrev getD_setD := @getD_get?_setIfInBounds
@[deprecated getElem_setIfInBounds (since := "2024-11-24")] abbrev getElem_setD := @getElem_setIfInBounds
end Array

View File

@@ -25,9 +25,11 @@ Set an element in an array, or do nothing if the index is out of bounds.
This will perform the update destructively provided that `a` has a reference
count of 1 when called.
-/
@[inline] def Array.setD (a : Array α) (i : Nat) (v : α) : Array α :=
@[inline] def Array.setIfInBounds (a : Array α) (i : Nat) (v : α) : Array α :=
dite (LT.lt i a.size) (fun h => a.set i v h) (fun _ => a)
@[deprecated Array.setIfInBounds (since := "2024-11-24")] abbrev Array.setD := @Array.setIfInBounds
/--
Set an element in an array, or panic if the index is out of bounds.
@@ -36,4 +38,4 @@ count of 1 when called.
-/
@[extern "lean_array_set"]
def Array.set! (a : Array α) (i : @& Nat) (v : α) : Array α :=
Array.setD a i v
Array.setIfInBounds a i v

View File

@@ -346,6 +346,10 @@ theorem getMsbD_sub {i : Nat} {i_lt : i < w} {x y : BitVec w} :
· rfl
· omega
theorem getElem_sub {i : Nat} {x y : BitVec w} (h : i < w) :
(x - y)[i] = (x[i] ^^ ((~~~y + 1#w)[i] ^^ carry i x (~~~y + 1#w) false)) := by
simp [ getLsbD_eq_getElem, getLsbD_sub, h]
theorem msb_sub {x y: BitVec w} :
(x - y).msb
= (x.msb ^^ ((~~~y + 1#w).msb ^^ carry (w - 1 - 0) x (~~~y + 1#w) false)) := by
@@ -403,13 +407,17 @@ theorem getLsbD_neg {i : Nat} {x : BitVec w} :
rw [carry_succ_one _ _ (by omega), Bool.xor_not, decide_not]
simp only [add_one_ne_zero, decide_false, getLsbD_not, and_eq_true, decide_eq_true_eq,
not_eq_eq_eq_not, Bool.not_true, false_bne, not_exists, _root_.not_and, not_eq_true,
bne_left_inj, decide_eq_decide]
bne_right_inj, decide_eq_decide]
constructor
· rintro h j hj; exact And.right <| h j (by omega)
· rintro h j hj; exact by omega, h j (by omega)
· have h_ge : w i := by omega
simp [getLsbD_ge _ _ h_ge, h_ge, hi]
theorem getElem_neg {i : Nat} {x : BitVec w} (h : i < w) :
(-x)[i] = (x[i] ^^ decide ( j < i, x.getLsbD j = true)) := by
simp [ getLsbD_eq_getElem, getLsbD_neg, h]
theorem getMsbD_neg {i : Nat} {x : BitVec w} :
getMsbD (-x) i =
(getMsbD x i ^^ decide ( j < w, i < j getMsbD x j = true)) := by
@@ -419,7 +427,7 @@ theorem getMsbD_neg {i : Nat} {x : BitVec w} :
simp [hi]; omega
case pos =>
have h₁ : w - 1 - i < w := by omega
simp only [hi, decide_true, h₁, Bool.true_and, Bool.bne_left_inj, decide_eq_decide]
simp only [hi, decide_true, h₁, Bool.true_and, Bool.bne_right_inj, decide_eq_decide]
constructor
· rintro j, hj, h
refine w - 1 - j, by omega, by omega, by omega, _root_.cast ?_ h

View File

@@ -269,6 +269,10 @@ theorem ofBool_eq_iff_eq : ∀ {b b' : Bool}, BitVec.ofBool b = BitVec.ofBool b'
getLsbD (x#'lt) i = x.testBit i := by
simp [getLsbD, BitVec.ofNatLt]
@[simp] theorem getMsbD_ofNatLt {n x i : Nat} (h : x < 2^n) :
getMsbD (x#'h) i = (decide (i < n) && x.testBit (n - 1 - i)) := by
simp [getMsbD, getLsbD]
@[simp, bv_toNat] theorem toNat_ofNat (x w : Nat) : (BitVec.ofNat w x).toNat = x % 2^w := by
simp [BitVec.toNat, BitVec.ofNat, Fin.ofNat']
@@ -755,6 +759,10 @@ theorem extractLsb'_eq_extractLsb {w : Nat} (x : BitVec w) (start len : Nat) (h
@[simp] theorem getLsbD_allOnes : (allOnes v).getLsbD i = decide (i < v) := by
simp [allOnes]
@[simp] theorem getMsbD_allOnes : (allOnes v).getMsbD i = decide (i < v) := by
simp [allOnes]
omega
@[simp] theorem getElem_allOnes (i : Nat) (h : i < v) : (allOnes v)[i] = true := by
simp [getElem_eq_testBit_toNat, h]
@@ -772,6 +780,12 @@ theorem extractLsb'_eq_extractLsb {w : Nat} (x : BitVec w) (start len : Nat) (h
@[simp] theorem toNat_or (x y : BitVec v) :
BitVec.toNat (x ||| y) = BitVec.toNat x ||| BitVec.toNat y := rfl
@[simp] theorem toInt_or (x y : BitVec w) :
BitVec.toInt (x ||| y) = Int.bmod (BitVec.toNat x ||| BitVec.toNat y) (2^w) := by
rw_mod_cast [Int.bmod_def, BitVec.toInt, toNat_or, Nat.mod_eq_of_lt
(Nat.or_lt_two_pow (BitVec.isLt x) (BitVec.isLt y))]
omega
@[simp] theorem toFin_or (x y : BitVec v) :
BitVec.toFin (x ||| y) = BitVec.toFin x ||| BitVec.toFin y := by
apply Fin.eq_of_val_eq
@@ -839,6 +853,12 @@ instance : Std.LawfulCommIdentity (α := BitVec n) (· ||| · ) (0#n) where
@[simp] theorem toNat_and (x y : BitVec v) :
BitVec.toNat (x &&& y) = BitVec.toNat x &&& BitVec.toNat y := rfl
@[simp] theorem toInt_and (x y : BitVec w) :
BitVec.toInt (x &&& y) = Int.bmod (BitVec.toNat x &&& BitVec.toNat y) (2^w) := by
rw_mod_cast [Int.bmod_def, BitVec.toInt, toNat_and, Nat.mod_eq_of_lt
(Nat.and_lt_two_pow x.toNat (BitVec.isLt y))]
omega
@[simp] theorem toFin_and (x y : BitVec v) :
BitVec.toFin (x &&& y) = BitVec.toFin x &&& BitVec.toFin y := by
apply Fin.eq_of_val_eq
@@ -906,6 +926,12 @@ instance : Std.LawfulCommIdentity (α := BitVec n) (· &&& · ) (allOnes n) wher
@[simp] theorem toNat_xor (x y : BitVec v) :
BitVec.toNat (x ^^^ y) = BitVec.toNat x ^^^ BitVec.toNat y := rfl
@[simp] theorem toInt_xor (x y : BitVec w) :
BitVec.toInt (x ^^^ y) = Int.bmod (BitVec.toNat x ^^^ BitVec.toNat y) (2^w) := by
rw_mod_cast [Int.bmod_def, BitVec.toInt, toNat_xor, Nat.mod_eq_of_lt
(Nat.xor_lt_two_pow (BitVec.isLt x) (BitVec.isLt y))]
omega
@[simp] theorem toFin_xor (x y : BitVec v) :
BitVec.toFin (x ^^^ y) = BitVec.toFin x ^^^ BitVec.toFin y := by
apply Fin.eq_of_val_eq
@@ -983,6 +1009,13 @@ theorem not_def {x : BitVec v} : ~~~x = allOnes v ^^^ x := rfl
_ 2 ^ i := Nat.pow_le_pow_of_le_right Nat.zero_lt_two w
· simp
@[simp] theorem toInt_not {x : BitVec w} :
(~~~x).toInt = Int.bmod (2^w - 1 - x.toNat) (2^w) := by
rw_mod_cast [BitVec.toInt, BitVec.toNat_not, Int.bmod_def]
simp [show ((2^w : Nat) : Int) - 1 - x.toNat = ((2^w - 1 - x.toNat) : Nat) by omega]
rw_mod_cast [Nat.mod_eq_of_lt (by omega)]
omega
@[simp] theorem ofInt_negSucc_eq_not_ofNat {w n : Nat} :
BitVec.ofInt w (Int.negSucc n) = ~~~.ofNat w n := by
simp only [BitVec.ofInt, Int.toNat, Int.ofNat_eq_coe, toNat_eq, toNat_ofNatLt, toNat_not,
@@ -1007,6 +1040,10 @@ theorem not_def {x : BitVec v} : ~~~x = allOnes v ^^^ x := rfl
@[simp] theorem getLsbD_not {x : BitVec v} : (~~~x).getLsbD i = (decide (i < v) && ! x.getLsbD i) := by
by_cases h' : i < v <;> simp_all [not_def]
@[simp] theorem getMsbD_not {x : BitVec v} :
(~~~x).getMsbD i = (decide (i < v) && ! x.getMsbD i) := by
by_cases h' : i < v <;> simp_all [not_def]
@[simp] theorem getElem_not {x : BitVec w} {i : Nat} (h : i < w) : (~~~x)[i] = !x[i] := by
simp only [getElem_eq_testBit_toNat, toNat_not]
rw [ Nat.sub_add_eq, Nat.add_comm 1]
@@ -1480,6 +1517,12 @@ theorem getLsbD_sshiftRight' {x y: BitVec w} {i : Nat} :
(!decide (w i) && if y.toNat + i < w then x.getLsbD (y.toNat + i) else x.msb) := by
simp only [BitVec.sshiftRight', BitVec.getLsbD_sshiftRight]
@[simp]
theorem getElem_sshiftRight' {x y : BitVec w} {i : Nat} (h : i < w) :
(x.sshiftRight' y)[i] =
(!decide (w i) && if y.toNat + i < w then x.getLsbD (y.toNat + i) else x.msb) := by
simp only [ getLsbD_eq_getElem, BitVec.sshiftRight', BitVec.getLsbD_sshiftRight]
@[simp]
theorem getMsbD_sshiftRight' {x y: BitVec w} {i : Nat} :
(x.sshiftRight y.toNat).getMsbD i = (decide (i < w) && if i < y.toNat then x.msb else x.getMsbD (i - y.toNat)) := by
@@ -2611,7 +2654,7 @@ theorem getLsbD_rotateLeftAux_of_geq {x : BitVec w} {r : Nat} {i : Nat} (hi : i
apply getLsbD_ge
omega
/-- When `r < w`, we give a formula for `(x.rotateRight r).getLsbD i`. -/
/-- When `r < w`, we give a formula for `(x.rotateLeft r).getLsbD i`. -/
theorem getLsbD_rotateLeft_of_le {x : BitVec w} {r i : Nat} (hr: r < w) :
(x.rotateLeft r).getLsbD i =
cond (i < r)
@@ -2638,6 +2681,56 @@ theorem getElem_rotateLeft {x : BitVec w} {r i : Nat} (h : i < w) :
if h' : i < r % w then x[(w - (r % w) + i)] else x[i - (r % w)] := by
simp [ BitVec.getLsbD_eq_getElem, h]
/-- If `w ≤ x < 2 * w`, then `x % w = x - w` -/
theorem mod_eq_sub_of_le_of_lt {x w : Nat} (x_le : w x) (x_lt : x < 2 * w) :
x % w = x - w := by
rw [Nat.mod_eq_sub_mod, Nat.mod_eq_of_lt (by omega)]
omega
theorem getMsbD_rotateLeftAux_of_lt {x : BitVec w} {r : Nat} {i : Nat} (hi : i < w - r) :
(x.rotateLeftAux r).getMsbD i = x.getMsbD (r + i) := by
rw [rotateLeftAux, getMsbD_or]
simp [show i < w - r by omega, Nat.add_comm]
theorem getMsbD_rotateLeftAux_of_ge {x : BitVec w} {r : Nat} {i : Nat} (hi : i w - r) :
(x.rotateLeftAux r).getMsbD i = (decide (i < w) && x.getMsbD (i - (w - r))) := by
simp [rotateLeftAux, getMsbD_or, show i + r w by omega, show ¬i < w - r by omega]
/-- When `r < w`, we give a formula for `(x.rotateLeft r).getMsbD i`. -/
theorem getMsbD_rotateLeft_of_lt {n w : Nat} {x : BitVec w} (hi : r < w):
(x.rotateLeft r).getMsbD n = (decide (n < w) && x.getMsbD ((r + n) % w)) := by
rcases w with rfl | w
· simp
· rw [BitVec.rotateLeft_eq_rotateLeftAux_of_lt (by omega)]
by_cases h : n < (w + 1) - r
· simp [getMsbD_rotateLeftAux_of_lt h, Nat.mod_eq_of_lt, show r + n < (w + 1) by omega, show n < w + 1 by omega]
· simp [getMsbD_rotateLeftAux_of_ge <| Nat.ge_of_not_lt h]
by_cases h₁ : n < w + 1
· simp only [h₁, decide_true, Bool.true_and]
have h₂ : (r + n) < 2 * (w + 1) := by omega
rw [mod_eq_sub_of_le_of_lt (by omega) (by omega)]
congr 1
omega
· simp [h₁]
theorem getMsbD_rotateLeft {r n w : Nat} {x : BitVec w} :
(x.rotateLeft r).getMsbD n = (decide (n < w) && x.getMsbD ((r + n) % w)) := by
rcases w with rfl | w
· simp
· by_cases h : r < w
· rw [getMsbD_rotateLeft_of_lt (by omega)]
· rw [ rotateLeft_mod_eq_rotateLeft, getMsbD_rotateLeft_of_lt (by apply Nat.mod_lt; simp)]
simp
@[simp]
theorem msb_rotateLeft {m w : Nat} {x : BitVec w} :
(x.rotateLeft m).msb = x.getMsbD (m % w) := by
simp only [BitVec.msb, getMsbD_rotateLeft]
by_cases h : w = 0
· simp [h]
· simp
omega
/-! ## Rotate Right -/
/--
@@ -2699,7 +2792,7 @@ theorem rotateRight_mod_eq_rotateRight {x : BitVec w} {r : Nat} :
simp only [rotateRight, Nat.mod_mod]
/-- When `r < w`, we give a formula for `(x.rotateRight r).getLsb i`. -/
theorem getLsbD_rotateRight_of_le {x : BitVec w} {r i : Nat} (hr: r < w) :
theorem getLsbD_rotateRight_of_lt {x : BitVec w} {r i : Nat} (hr: r < w) :
(x.rotateRight r).getLsbD i =
cond (i < w - r)
(x.getLsbD (r + i))
@@ -2717,7 +2810,7 @@ theorem getLsbD_rotateRight {x : BitVec w} {r i : Nat} :
(decide (i < w) && x.getLsbD (i - (w - (r % w)))) := by
rcases w with rfl, w
· simp
· rw [ rotateRight_mod_eq_rotateRight, getLsbD_rotateRight_of_le (Nat.mod_lt _ (by omega))]
· rw [ rotateRight_mod_eq_rotateRight, getLsbD_rotateRight_of_lt (Nat.mod_lt _ (by omega))]
@[simp]
theorem getElem_rotateRight {x : BitVec w} {r i : Nat} (h : i < w) :
@@ -2725,6 +2818,56 @@ theorem getElem_rotateRight {x : BitVec w} {r i : Nat} (h : i < w) :
simp only [ BitVec.getLsbD_eq_getElem]
simp [getLsbD_rotateRight, h]
theorem getMsbD_rotateRightAux_of_lt {x : BitVec w} {r : Nat} {i : Nat} (hi : i < r) :
(x.rotateRightAux r).getMsbD i = x.getMsbD (i + (w - r)) := by
rw [rotateRightAux, getMsbD_or, getMsbD_ushiftRight]
simp [show i < r by omega]
theorem getMsbD_rotateRightAux_of_ge {x : BitVec w} {r : Nat} {i : Nat} (hi : i r) :
(x.rotateRightAux r).getMsbD i = (decide (i < w) && x.getMsbD (i - r)) := by
simp [rotateRightAux, show ¬ i < r by omega, show i + (w - r) w by omega]
/-- When `m < w`, we give a formula for `(x.rotateLeft m).getMsbD i`. -/
@[simp]
theorem getMsbD_rotateRight_of_lt {w n m : Nat} {x : BitVec w} (hr : m < w):
(x.rotateRight m).getMsbD n = (decide (n < w) && (if (n < m % w)
then x.getMsbD ((w + n - m % w) % w) else x.getMsbD (n - m % w))):= by
rcases w with rfl | w
· simp
· rw [rotateRight_eq_rotateRightAux_of_lt (by omega)]
by_cases h : n < m
· simp only [getMsbD_rotateRightAux_of_lt h, show n < w + 1 by omega, decide_true,
show m % (w + 1) = m by rw [Nat.mod_eq_of_lt hr], h, reduceIte,
show (w + 1 + n - m) < (w + 1) by omega, Nat.mod_eq_of_lt, Bool.true_and]
congr 1
omega
· simp [h, getMsbD_rotateRightAux_of_ge <| Nat.ge_of_not_lt h]
by_cases h₁ : n < w + 1
· simp [h, h₁, decide_true, Bool.true_and, Nat.mod_eq_of_lt hr]
· simp [h₁]
@[simp]
theorem getMsbD_rotateRight {w n m : Nat} {x : BitVec w} :
(x.rotateRight m).getMsbD n = (decide (n < w) && (if (n < m % w)
then x.getMsbD ((w + n - m % w) % w) else x.getMsbD (n - m % w))):= by
rcases w with rfl | w
· simp
· by_cases h₀ : m < w
· rw [getMsbD_rotateRight_of_lt (by omega)]
· rw [ rotateRight_mod_eq_rotateRight, getMsbD_rotateRight_of_lt (by apply Nat.mod_lt; simp)]
simp
@[simp]
theorem msb_rotateRight {r w : Nat} {x : BitVec w} :
(x.rotateRight r).msb = x.getMsbD ((w - r % w) % w) := by
simp only [BitVec.msb, getMsbD_rotateRight]
by_cases h₀ : 0 < w
· simp only [h₀, decide_true, Nat.add_zero, Nat.zero_le, Nat.sub_eq_zero_of_le, Bool.true_and,
ite_eq_left_iff, Nat.not_lt, Nat.le_zero_eq]
intro h₁
simp [h₁]
· simp [show w = 0 by omega]
/- ## twoPow -/
theorem twoPow_eq (w : Nat) (i : Nat) : twoPow w i = 1#w <<< i := by
@@ -3124,7 +3267,11 @@ theorem toNat_abs {x : BitVec w} : x.abs.toNat = if x.msb then 2^w - x.toNat els
· simp [h]
theorem getLsbD_abs {i : Nat} {x : BitVec w} :
getLsbD x.abs i = if x.msb then getLsbD (-x) i else getLsbD x i := by
getLsbD x.abs i = if x.msb then getLsbD (-x) i else getLsbD x i := by
by_cases h : x.msb <;> simp [BitVec.abs, h]
theorem getElem_abs {i : Nat} {x : BitVec w} (h : i < w) :
x.abs[i] = if x.msb then (-x)[i] else x[i] := by
by_cases h : x.msb <;> simp [BitVec.abs, h]
theorem getMsbD_abs {i : Nat} {x : BitVec w} :

View File

@@ -238,8 +238,8 @@ theorem not_bne_not : ∀ (x y : Bool), ((!x) != (!y)) = (x != y) := by simp
@[simp] theorem bne_assoc : (x y z : Bool), ((x != y) != z) = (x != (y != z)) := by decide
instance : Std.Associative (· != ·) := bne_assoc
@[simp] theorem bne_left_inj : {x y z : Bool}, (x != y) = (x != z) y = z := by decide
@[simp] theorem bne_right_inj : {x y z : Bool}, (x != z) = (y != z) x = y := by decide
@[simp] theorem bne_right_inj : {x y z : Bool}, (x != y) = (x != z) y = z := by decide
@[simp] theorem bne_left_inj : {x y z : Bool}, (x != z) = (y != z) x = y := by decide
theorem eq_not_of_ne : {x y : Bool}, x y x = !y := by decide
@@ -295,9 +295,9 @@ theorem xor_right_comm : ∀ (x y z : Bool), ((x ^^ y) ^^ z) = ((x ^^ z) ^^ y) :
theorem xor_assoc : (x y z : Bool), ((x ^^ y) ^^ z) = (x ^^ (y ^^ z)) := bne_assoc
theorem xor_left_inj : {x y z : Bool}, (x ^^ y) = (x ^^ z) y = z := bne_left_inj
theorem xor_right_inj : {x y z : Bool}, (x ^^ y) = (x ^^ z) y = z := bne_right_inj
theorem xor_right_inj : {x y z : Bool}, (x ^^ z) = (y ^^ z) x = y := bne_right_inj
theorem xor_left_inj : {x y z : Bool}, (x ^^ z) = (y ^^ z) x = y := bne_left_inj
/-! ### le/lt -/

View File

@@ -31,7 +31,7 @@ opaque floatSpec : FloatSpec := {
structure Float where
val : floatSpec.float
instance : Inhabited Float := { val := floatSpec.val }
instance : Nonempty Float := { val := floatSpec.val }
@[extern "lean_float_add"] opaque Float.add : Float Float Float
@[extern "lean_float_sub"] opaque Float.sub : Float Float Float
@@ -47,6 +47,25 @@ def Float.lt : Float → Float → Prop := fun a b =>
def Float.le : Float Float Prop := fun a b =>
floatSpec.le a.val b.val
/--
Raw transmutation from `UInt64`.
Floats and UInts have the same endianness on all supported platforms.
IEEE 754 very precisely specifies the bit layout of floats.
-/
@[extern "lean_float_of_bits"] opaque Float.ofBits : UInt64 Float
/--
Raw transmutation to `UInt64`.
Floats and UInts have the same endianness on all supported platforms.
IEEE 754 very precisely specifies the bit layout of floats.
Note that this function is distinct from `Float.toUInt64`, which attempts
to preserve the numeric value, and not the bitwise value.
-/
@[extern "lean_float_to_bits"] opaque Float.toBits : Float UInt64
instance : Add Float := Float.add
instance : Sub Float := Float.sub
instance : Mul Float := Float.mul
@@ -117,6 +136,9 @@ instance : ToString Float where
@[extern "lean_uint64_to_float"] opaque UInt64.toFloat (n : UInt64) : Float
instance : Inhabited Float where
default := UInt64.toFloat 0
instance : Repr Float where
reprPrec n prec := if n < UInt64.toFloat 0 then Repr.addAppParen (toString n) prec else toString n

View File

@@ -329,22 +329,22 @@ theorem toNat_sub (m n : Nat) : toNat (m - n) = m - n := by
/- ## add/sub injectivity -/
@[simp]
protected theorem add_right_inj {i j : Int} (k : Int) : (i + k = j + k) i = j := by
protected theorem add_left_inj {i j : Int} (k : Int) : (i + k = j + k) i = j := by
apply Iff.intro
· intro p
rw [Int.add_sub_cancel i k, Int.add_sub_cancel j k, p]
· exact congrArg (· + k)
@[simp]
protected theorem add_left_inj {i j : Int} (k : Int) : (k + i = k + j) i = j := by
protected theorem add_right_inj {i j : Int} (k : Int) : (k + i = k + j) i = j := by
simp [Int.add_comm k]
@[simp]
protected theorem sub_left_inj {i j : Int} (k : Int) : (k - i = k - j) i = j := by
protected theorem sub_right_inj {i j : Int} (k : Int) : (k - i = k - j) i = j := by
simp [Int.sub_eq_add_neg, Int.neg_inj]
@[simp]
protected theorem sub_right_inj {i j : Int} (k : Int) : (i - k = j - k) i = j := by
protected theorem sub_left_inj {i j : Int} (k : Int) : (i - k = j - k) i = j := by
simp [Int.sub_eq_add_neg]
/- ## Ring properties -/

View File

@@ -13,7 +13,7 @@ namespace List
`a : α` satisfying `P`, then `pmap f l h` is essentially the same as `map f l`
but is defined only when all members of `l` satisfy `P`, using the proof
to apply `f`. -/
@[simp] def pmap {P : α Prop} (f : a, P a β) : l : List α, (H : a l, P a) List β
def pmap {P : α Prop} (f : a, P a β) : l : List α, (H : a l, P a) List β
| [], _ => []
| a :: l, H => f a (forall_mem_cons.1 H).1 :: pmap f l (forall_mem_cons.1 H).2
@@ -46,6 +46,11 @@ Unsafe implementation of `attachWith`, taking advantage of the fact that the rep
| cons _ L', hL' => congrArg _ <| go L' fun _ hx => hL' (.tail _ hx)
exact go L h'
@[simp] theorem pmap_nil {P : α Prop} (f : a, P a β) : pmap f [] (by simp) = [] := rfl
@[simp] theorem pmap_cons {P : α Prop} (f : a, P a β) (a : α) (l : List α) (h : b a :: l, P b) :
pmap f (a :: l) h = f a (forall_mem_cons.1 h).1 :: pmap f l (forall_mem_cons.1 h).2 := rfl
@[simp] theorem attach_nil : ([] : List α).attach = [] := rfl
@[simp] theorem attachWith_nil : ([] : List α).attachWith P H = [] := rfl
@@ -148,7 +153,7 @@ theorem mem_pmap_of_mem {p : α → Prop} {f : ∀ a, p a → β} {l H} {a} (h :
exact a, h, rfl
@[simp]
theorem length_pmap {p : α Prop} {f : a, p a β} {l H} : length (pmap f l H) = length l := by
theorem length_pmap {p : α Prop} {f : a, p a β} {l H} : (pmap f l H).length = l.length := by
induction l
· rfl
· simp only [*, pmap, length]
@@ -199,7 +204,7 @@ theorem attachWith_ne_nil_iff {l : List α} {P : α → Prop} {H : ∀ a ∈ l,
@[simp]
theorem getElem?_pmap {p : α Prop} (f : a, p a β) {l : List α} (h : a l, p a) (n : Nat) :
(pmap f l h)[n]? = Option.pmap f l[n]? fun x H => h x (getElem?_mem H) := by
(pmap f l h)[n]? = Option.pmap f l[n]? fun x H => h x (mem_of_getElem? H) := by
induction l generalizing n with
| nil => simp
| cons hd tl hl =>
@@ -215,7 +220,7 @@ theorem getElem?_pmap {p : α → Prop} (f : ∀ a, p a → β) {l : List α} (h
· simp_all
theorem get?_pmap {p : α Prop} (f : a, p a β) {l : List α} (h : a l, p a) (n : Nat) :
get? (pmap f l h) n = Option.pmap f (get? l n) fun x H => h x (get?_mem H) := by
get? (pmap f l h) n = Option.pmap f (get? l n) fun x H => h x (mem_of_get? H) := by
simp only [get?_eq_getElem?]
simp [getElem?_pmap, h]
@@ -238,18 +243,18 @@ theorem get_pmap {p : α → Prop} (f : ∀ a, p a → β) {l : List α} (h :
(hn : n < (pmap f l h).length) :
get (pmap f l h) n, hn =
f (get l n, @length_pmap _ _ p f l h hn)
(h _ (get_mem l n (@length_pmap _ _ p f l h hn))) := by
(h _ (getElem_mem (@length_pmap _ _ p f l h hn))) := by
simp only [get_eq_getElem]
simp [getElem_pmap]
@[simp]
theorem getElem?_attachWith {xs : List α} {i : Nat} {P : α Prop} {H : a xs, P a} :
(xs.attachWith P H)[i]? = xs[i]?.pmap Subtype.mk (fun _ a => H _ (getElem?_mem a)) :=
(xs.attachWith P H)[i]? = xs[i]?.pmap Subtype.mk (fun _ a => H _ (mem_of_getElem? a)) :=
getElem?_pmap ..
@[simp]
theorem getElem?_attach {xs : List α} {i : Nat} :
xs.attach[i]? = xs[i]?.pmap Subtype.mk (fun _ a => getElem?_mem a) :=
xs.attach[i]? = xs[i]?.pmap Subtype.mk (fun _ a => mem_of_getElem? a) :=
getElem?_attachWith
@[simp]
@@ -333,6 +338,7 @@ This is useful when we need to use `attach` to show termination.
Unfortunately this can't be applied by `simp` because of the higher order unification problem,
and even when rewriting we need to specify the function explicitly.
See however `foldl_subtype` below.
-/
theorem foldl_attach (l : List α) (f : β α β) (b : β) :
l.attach.foldl (fun acc t => f acc t.1) b = l.foldl f b := by
@@ -348,6 +354,7 @@ This is useful when we need to use `attach` to show termination.
Unfortunately this can't be applied by `simp` because of the higher order unification problem,
and even when rewriting we need to specify the function explicitly.
See however `foldr_subtype` below.
-/
theorem foldr_attach (l : List α) (f : α β β) (b : β) :
l.attach.foldr (fun t acc => f t.1 acc) b = l.foldr f b := by
@@ -452,16 +459,16 @@ theorem pmap_append' {p : α → Prop} (f : ∀ a : α, p a → β) (l₁ l₂ :
pmap_append f l₁ l₂ _
@[simp] theorem attach_append (xs ys : List α) :
(xs ++ ys).attach = xs.attach.map (fun x, h => x, mem_append_of_mem_left ys h) ++
ys.attach.map fun x, h => x, mem_append_of_mem_right xs h := by
(xs ++ ys).attach = xs.attach.map (fun x, h => x, mem_append_left ys h) ++
ys.attach.map fun x, h => x, mem_append_right xs h := by
simp only [attach, attachWith, pmap, map_pmap, pmap_append]
congr 1 <;>
exact pmap_congr_left _ fun _ _ _ _ => rfl
@[simp] theorem attachWith_append {P : α Prop} {xs ys : List α}
{H : (a : α), a xs ++ ys P a} :
(xs ++ ys).attachWith P H = xs.attachWith P (fun a h => H a (mem_append_of_mem_left ys h)) ++
ys.attachWith P (fun a h => H a (mem_append_of_mem_right xs h)) := by
(xs ++ ys).attachWith P H = xs.attachWith P (fun a h => H a (mem_append_left ys h)) ++
ys.attachWith P (fun a h => H a (mem_append_right xs h)) := by
simp only [attachWith, attach_append, map_pmap, pmap_append]
@[simp] theorem pmap_reverse {P : α Prop} (f : (a : α) P a β) (xs : List α)
@@ -598,6 +605,15 @@ def unattach {α : Type _} {p : α → Prop} (l : List { x // p x }) := l.map (
| nil => simp
| cons a l ih => simp [ih, Function.comp_def]
@[simp] theorem getElem?_unattach {p : α Prop} {l : List { x // p x }} (i : Nat) :
l.unattach[i]? = l[i]?.map Subtype.val := by
simp [unattach]
@[simp] theorem getElem_unattach
{p : α Prop} {l : List { x // p x }} (i : Nat) (h : i < l.unattach.length) :
l.unattach[i] = (l[i]'(by simpa using h)).1 := by
simp [unattach]
/-! ### Recognizing higher order functions on subtypes using a function that only depends on the value. -/
/--

View File

@@ -726,13 +726,13 @@ theorem elem_eq_true_of_mem [BEq α] [LawfulBEq α] {a : α} {as : List α} (h :
instance [BEq α] [LawfulBEq α] (a : α) (as : List α) : Decidable (a as) :=
decidable_of_decidable_of_iff (Iff.intro mem_of_elem_eq_true elem_eq_true_of_mem)
theorem mem_append_of_mem_left {a : α} {as : List α} (bs : List α) : a as a as ++ bs := by
theorem mem_append_left {a : α} {as : List α} (bs : List α) : a as a as ++ bs := by
intro h
induction h with
| head => apply Mem.head
| tail => apply Mem.tail; assumption
theorem mem_append_of_mem_right {b : α} {bs : List α} (as : List α) : b bs b as ++ bs := by
theorem mem_append_right {b : α} {bs : List α} (as : List α) : b bs b as ++ bs := by
intro h
induction as with
| nil => simp [h]

View File

@@ -256,7 +256,7 @@ theorem findM?_eq_findSomeM? [Monad m] [LawfulMonad m] (p : α → m Bool) (as :
have : a as := by
have bs, h := h
subst h
exact mem_append_of_mem_right _ (Mem.head ..)
exact mem_append_right _ (Mem.head ..)
match ( f a this b) with
| ForInStep.done b => pure b
| ForInStep.yield b =>

View File

@@ -372,6 +372,17 @@ theorem getElem?_concat_length (l : List α) (a : α) : (l ++ [a])[l.length]? =
@[deprecated getElem?_concat_length (since := "2024-06-12")]
theorem get?_concat_length (l : List α) (a : α) : (l ++ [a]).get? l.length = some a := by simp
@[simp] theorem isSome_getElem? {l : List α} {n : Nat} : l[n]?.isSome n < l.length := by
by_cases h : n < l.length
· simp_all
· simp [h]
simp_all
@[simp] theorem isNone_getElem? {l : List α} {n : Nat} : l[n]?.isNone l.length n := by
by_cases h : n < l.length
· simp_all
· simp [h]
/-! ### mem -/
@[simp] theorem not_mem_nil (a : α) : ¬ a [] := nofun
@@ -383,9 +394,9 @@ theorem get?_concat_length (l : List α) (a : α) : (l ++ [a]).get? l.length = s
theorem mem_cons_self (a : α) (l : List α) : a a :: l := .head ..
theorem mem_concat_self (xs : List α) (a : α) : a xs ++ [a] :=
mem_append_of_mem_right xs (mem_cons_self a _)
mem_append_right xs (mem_cons_self a _)
theorem mem_append_cons_self : a xs ++ a :: ys := mem_append_of_mem_right _ (mem_cons_self _ _)
theorem mem_append_cons_self : a xs ++ a :: ys := mem_append_right _ (mem_cons_self _ _)
theorem eq_append_cons_of_mem {a : α} {xs : List α} (h : a xs) :
as bs, xs = as ++ a :: bs a as := by
@@ -492,16 +503,20 @@ theorem getElem?_of_mem {a} {l : List α} (h : a ∈ l) : ∃ n : Nat, l[n]? = s
theorem get?_of_mem {a} {l : List α} (h : a l) : n, l.get? n = some a :=
let n, _, e := get_of_mem h; n, e get?_eq_get _
theorem get_mem : (l : List α) n h, get l n, h l
| _ :: _, 0, _ => .head ..
| _ :: l, _+1, _ => .tail _ (get_mem l ..)
theorem get_mem : (l : List α) n, get l n l
| _ :: _, 0, _ => .head ..
| _ :: l, _+1, _ => .tail _ (get_mem l ..)
theorem getElem?_mem {l : List α} {n : Nat} {a : α} (e : l[n]? = some a) : a l :=
theorem mem_of_getElem? {l : List α} {n : Nat} {a : α} (e : l[n]? = some a) : a l :=
let _, e := getElem?_eq_some_iff.1 e; e getElem_mem ..
theorem get?_mem {l : List α} {n a} (e : l.get? n = some a) : a l :=
@[deprecated mem_of_getElem? (since := "2024-09-06")] abbrev getElem?_mem := @mem_of_getElem?
theorem mem_of_get? {l : List α} {n a} (e : l.get? n = some a) : a l :=
let _, e := get?_eq_some.1 e; e get_mem ..
@[deprecated mem_of_get? (since := "2024-09-06")] abbrev get?_mem := @mem_of_get?
theorem mem_iff_getElem {a} {l : List α} : a l (n : Nat) (h : n < l.length), l[n]'h = a :=
getElem_of_mem, fun _, _, e => e getElem_mem ..
@@ -1025,6 +1040,10 @@ theorem getLast_eq_getElem : ∀ (l : List α) (h : l ≠ []),
| _ :: _ :: _, _ => by
simp [getLast, get, Nat.succ_sub_succ, getLast_eq_getElem]
theorem getElem_length_sub_one_eq_getLast (l : List α) (h) :
l[l.length - 1] = getLast l (by cases l; simp at h; simp) := by
rw [ getLast_eq_getElem]
@[deprecated getLast_eq_getElem (since := "2024-07-15")]
theorem getLast_eq_get (l : List α) (h : l []) :
getLast l h = l.get l.length - 1, by
@@ -1149,6 +1168,11 @@ theorem head_eq_getElem (l : List α) (h : l ≠ []) : head l h = l[0]'(length_p
| nil => simp at h
| cons _ _ => simp
theorem getElem_zero_eq_head (l : List α) (h) : l[0] = head l (by simpa [length_pos] using h) := by
cases l with
| nil => simp at h
| cons _ _ => simp
theorem head_eq_iff_head?_eq_some {xs : List α} (h) : xs.head h = a xs.head? = some a := by
cases xs with
| nil => simp at h
@@ -1977,11 +2001,8 @@ theorem not_mem_append {a : α} {s t : List α} (h₁ : a ∉ s) (h₂ : a ∉ t
theorem mem_append_eq (a : α) (s t : List α) : (a s ++ t) = (a s a t) :=
propext mem_append
theorem mem_append_left {a : α} {l₁ : List α} (l₂ : List α) (h : a l₁) : a l₁ ++ l₂ :=
mem_append.2 (Or.inl h)
theorem mem_append_right {a : α} (l₁ : List α) {l₂ : List α} (h : a l₂) : a l₁ ++ l₂ :=
mem_append.2 (Or.inr h)
@[deprecated mem_append_left (since := "2024-11-20")] abbrev mem_append_of_mem_left := @mem_append_left
@[deprecated mem_append_right (since := "2024-11-20")] abbrev mem_append_of_mem_right := @mem_append_right
theorem mem_iff_append {a : α} {l : List α} : a l s t : List α, l = s ++ a :: t :=
append_of_mem, fun s, t, e => e by simp
@@ -2395,7 +2416,7 @@ theorem forall_mem_replicate {p : α → Prop} {a : α} {n} :
@[simp] theorem getElem_replicate (a : α) {n : Nat} {m} (h : m < (replicate n a).length) :
(replicate n a)[m] = a :=
eq_of_mem_replicate (get_mem _ _ _)
eq_of_mem_replicate (getElem_mem _)
@[deprecated getElem_replicate (since := "2024-06-12")]
theorem get_replicate (a : α) {n : Nat} (m : Fin _) : (replicate n a).get m = a := by

View File

@@ -417,7 +417,7 @@ theorem Sublist.of_sublist_append_left (w : ∀ a, a ∈ l → a ∉ l₂) (h :
obtain l₁', l₂', rfl, h₁, h₂ := h
have : l₂' = [] := by
rw [eq_nil_iff_forall_not_mem]
exact fun x m => w x (mem_append_of_mem_right l₁' m) (h₂.mem m)
exact fun x m => w x (mem_append_right l₁' m) (h₂.mem m)
simp_all
theorem Sublist.of_sublist_append_right (w : a, a l a l₁) (h : l <+ l₁ ++ l₂) : l <+ l₂ := by
@@ -425,7 +425,7 @@ theorem Sublist.of_sublist_append_right (w : ∀ a, a ∈ l → a ∉ l₁) (h :
obtain l₁', l₂', rfl, h₁, h₂ := h
have : l₁' = [] := by
rw [eq_nil_iff_forall_not_mem]
exact fun x m => w x (mem_append_of_mem_left l₂' m) (h₁.mem m)
exact fun x m => w x (mem_append_left l₂' m) (h₁.mem m)
simp_all
theorem Sublist.middle {l : List α} (h : l <+ l₁ ++ l₂) (a : α) : l <+ l₁ ++ a :: l₂ := by

View File

@@ -20,3 +20,4 @@ import Init.Data.Nat.Mod
import Init.Data.Nat.Lcm
import Init.Data.Nat.Compare
import Init.Data.Nat.Simproc
import Init.Data.Nat.Fold

View File

@@ -35,52 +35,6 @@ Used as the default `Nat` eliminator by the `cases` tactic. -/
protected abbrev casesAuxOn {motive : Nat Sort u} (t : Nat) (zero : motive 0) (succ : (n : Nat) motive (n + 1)) : motive t :=
Nat.casesOn t zero succ
/--
`Nat.fold` evaluates `f` on the numbers up to `n` exclusive, in increasing order:
* `Nat.fold f 3 init = init |> f 0 |> f 1 |> f 2`
-/
@[specialize] def fold {α : Type u} (f : Nat α α) : (n : Nat) (init : α) α
| 0, a => a
| succ n, a => f n (fold f n a)
/-- Tail-recursive version of `Nat.fold`. -/
@[inline] def foldTR {α : Type u} (f : Nat α α) (n : Nat) (init : α) : α :=
let rec @[specialize] loop
| 0, a => a
| succ m, a => loop m (f (n - succ m) a)
loop n init
/--
`Nat.foldRev` evaluates `f` on the numbers up to `n` exclusive, in decreasing order:
* `Nat.foldRev f 3 init = f 0 <| f 1 <| f 2 <| init`
-/
@[specialize] def foldRev {α : Type u} (f : Nat α α) : (n : Nat) (init : α) α
| 0, a => a
| succ n, a => foldRev f n (f n a)
/-- `any f n = true` iff there is `i in [0, n-1]` s.t. `f i = true` -/
@[specialize] def any (f : Nat Bool) : Nat Bool
| 0 => false
| succ n => any f n || f n
/-- Tail-recursive version of `Nat.any`. -/
@[inline] def anyTR (f : Nat Bool) (n : Nat) : Bool :=
let rec @[specialize] loop : Nat Bool
| 0 => false
| succ m => f (n - succ m) || loop m
loop n
/-- `all f n = true` iff every `i in [0, n-1]` satisfies `f i = true` -/
@[specialize] def all (f : Nat Bool) : Nat Bool
| 0 => true
| succ n => all f n && f n
/-- Tail-recursive version of `Nat.all`. -/
@[inline] def allTR (f : Nat Bool) (n : Nat) : Bool :=
let rec @[specialize] loop : Nat Bool
| 0 => true
| succ m => f (n - succ m) && loop m
loop n
/--
`Nat.repeat f n a` is `f^(n) a`; that is, it iterates `f` `n` times on `a`.
@@ -1158,33 +1112,6 @@ theorem not_lt_eq (a b : Nat) : (¬ (a < b)) = (b ≤ a) :=
theorem not_gt_eq (a b : Nat) : (¬ (a > b)) = (a b) :=
not_lt_eq b a
/-! # csimp theorems -/
@[csimp] theorem fold_eq_foldTR : @fold = @foldTR :=
funext fun α => funext fun f => funext fun n => funext fun init =>
let rec go : m n, foldTR.loop f (m + n) m (fold f n init) = fold f (m + n) init
| 0, n => by simp [foldTR.loop]
| succ m, n => by rw [foldTR.loop, add_sub_self_left, succ_add]; exact go m (succ n)
(go n 0).symm
@[csimp] theorem any_eq_anyTR : @any = @anyTR :=
funext fun f => funext fun n =>
let rec go : m n, (any f n || anyTR.loop f (m + n) m) = any f (m + n)
| 0, n => by simp [anyTR.loop]
| succ m, n => by
rw [anyTR.loop, add_sub_self_left, Bool.or_assoc, succ_add]
exact go m (succ n)
(go n 0).symm
@[csimp] theorem all_eq_allTR : @all = @allTR :=
funext fun f => funext fun n =>
let rec go : m n, (all f n && allTR.loop f (m + n) m) = all f (m + n)
| 0, n => by simp [allTR.loop]
| succ m, n => by
rw [allTR.loop, add_sub_self_left, Bool.and_assoc, succ_add]
exact go m (succ n)
(go n 0).symm
@[csimp] theorem repeat_eq_repeatTR : @repeat = @repeatTR :=
funext fun α => funext fun f => funext fun n => funext fun init =>
let rec go : m n, repeatTR.loop f m (repeat f n init) = repeat f (m + n) init
@@ -1193,31 +1120,3 @@ theorem not_gt_eq (a b : Nat) : (¬ (a > b)) = (a ≤ b) :=
(go n 0).symm
end Nat
namespace Prod
/--
`(start, stop).foldI f a` evaluates `f` on all the numbers
from `start` (inclusive) to `stop` (exclusive) in increasing order:
* `(5, 8).foldI f init = init |> f 5 |> f 6 |> f 7`
-/
@[inline] def foldI {α : Type u} (f : Nat α α) (i : Nat × Nat) (a : α) : α :=
Nat.foldTR.loop f i.2 (i.2 - i.1) a
/--
`(start, stop).anyI f a` returns true if `f` is true for some natural number
from `start` (inclusive) to `stop` (exclusive):
* `(5, 8).anyI f = f 5 || f 6 || f 7`
-/
@[inline] def anyI (f : Nat Bool) (i : Nat × Nat) : Bool :=
Nat.anyTR.loop f i.2 (i.2 - i.1)
/--
`(start, stop).allI f a` returns true if `f` is true for all natural numbers
from `start` (inclusive) to `stop` (exclusive):
* `(5, 8).anyI f = f 5 && f 6 && f 7`
-/
@[inline] def allI (f : Nat Bool) (i : Nat × Nat) : Bool :=
Nat.allTR.loop f i.2 (i.2 - i.1)
end Prod

View File

@@ -6,50 +6,51 @@ Author: Leonardo de Moura
prelude
import Init.Control.Basic
import Init.Data.Nat.Basic
import Init.Omega
namespace Nat
universe u v
@[inline] def forM {m} [Monad m] (n : Nat) (f : Nat m Unit) : m Unit :=
let rec @[specialize] loop
| 0 => pure ()
| i+1 => do f (n-i-1); loop i
loop n
@[inline] def forM {m} [Monad m] (n : Nat) (f : (i : Nat) i < n m Unit) : m Unit :=
let rec @[specialize] loop : i, i n m Unit
| 0, _ => pure ()
| i+1, h => do f (n-i-1) (by omega); loop i (Nat.le_of_succ_le h)
loop n (by simp)
@[inline] def forRevM {m} [Monad m] (n : Nat) (f : Nat m Unit) : m Unit :=
let rec @[specialize] loop
| 0 => pure ()
| i+1 => do f i; loop i
loop n
@[inline] def forRevM {m} [Monad m] (n : Nat) (f : (i : Nat) i < n m Unit) : m Unit :=
let rec @[specialize] loop : i, i n m Unit
| 0, _ => pure ()
| i+1, h => do f i (by omega); loop i (Nat.le_of_succ_le h)
loop n (by simp)
@[inline] def foldM {α : Type u} {m : Type u Type v} [Monad m] (f : Nat α m α) (init : α) (n : Nat) : m α :=
let rec @[specialize] loop
| 0, a => pure a
| i+1, a => f (n-i-1) a >>= loop i
loop n init
@[inline] def foldM {α : Type u} {m : Type u Type v} [Monad m] (n : Nat) (f : (i : Nat) i < n α m α) (init : α) : m α :=
let rec @[specialize] loop : i, i n α m α
| 0, h, a => pure a
| i+1, h, a => f (n-i-1) (by omega) a >>= loop i (Nat.le_of_succ_le h)
loop n (by omega) init
@[inline] def foldRevM {α : Type u} {m : Type u Type v} [Monad m] (f : Nat α m α) (init : α) (n : Nat) : m α :=
let rec @[specialize] loop
| 0, a => pure a
| i+1, a => f i a >>= loop i
loop n init
@[inline] def foldRevM {α : Type u} {m : Type u Type v} [Monad m] (n : Nat) (f : (i : Nat) i < n α m α) (init : α) : m α :=
let rec @[specialize] loop : i, i n α m α
| 0, h, a => pure a
| i+1, h, a => f i (by omega) a >>= loop i (Nat.le_of_succ_le h)
loop n (by omega) init
@[inline] def allM {m} [Monad m] (n : Nat) (p : Nat m Bool) : m Bool :=
let rec @[specialize] loop
| 0 => pure true
| i+1 => do
match ( p (n-i-1)) with
| true => loop i
@[inline] def allM {m} [Monad m] (n : Nat) (p : (i : Nat) i < n m Bool) : m Bool :=
let rec @[specialize] loop : i, i n m Bool
| 0, _ => pure true
| i+1 , h => do
match ( p (n-i-1) (by omega)) with
| true => loop i (by omega)
| false => pure false
loop n
loop n (by simp)
@[inline] def anyM {m} [Monad m] (n : Nat) (p : Nat m Bool) : m Bool :=
let rec @[specialize] loop
| 0 => pure false
| i+1 => do
match ( p (n-i-1)) with
@[inline] def anyM {m} [Monad m] (n : Nat) (p : (i : Nat) i < n m Bool) : m Bool :=
let rec @[specialize] loop : i, i n m Bool
| 0, _ => pure false
| i+1, h => do
match ( p (n-i-1) (by omega)) with
| true => pure true
| false => loop i
loop n
| false => loop i (Nat.le_of_succ_le h)
loop n (by simp)
end Nat

168
src/Init/Data/Nat/Fold.lean Normal file
View File

@@ -0,0 +1,168 @@
/-
Copyright (c) 2014 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Floris van Doorn, Leonardo de Moura, Kim Morrison
-/
prelude
import Init.Omega
set_option linter.missingDocs true -- keep it documented
universe u
namespace Nat
/--
`Nat.fold` evaluates `f` on the numbers up to `n` exclusive, in increasing order:
* `Nat.fold f 3 init = init |> f 0 |> f 1 |> f 2`
-/
@[specialize] def fold {α : Type u} : (n : Nat) (f : (i : Nat) i < n α α) (init : α) α
| 0, f, a => a
| succ n, f, a => f n (by omega) (fold n (fun i h => f i (by omega)) a)
/-- Tail-recursive version of `Nat.fold`. -/
@[inline] def foldTR {α : Type u} (n : Nat) (f : (i : Nat) i < n α α) (init : α) : α :=
let rec @[specialize] loop : j, j n α α
| 0, h, a => a
| succ m, h, a => loop m (by omega) (f (n - succ m) (by omega) a)
loop n (by omega) init
/--
`Nat.foldRev` evaluates `f` on the numbers up to `n` exclusive, in decreasing order:
* `Nat.foldRev f 3 init = f 0 <| f 1 <| f 2 <| init`
-/
@[specialize] def foldRev {α : Type u} : (n : Nat) (f : (i : Nat) i < n α α) (init : α) α
| 0, f, a => a
| succ n, f, a => foldRev n (fun i h => f i (by omega)) (f n (by omega) a)
/-- `any f n = true` iff there is `i in [0, n-1]` s.t. `f i = true` -/
@[specialize] def any : (n : Nat) (f : (i : Nat) i < n Bool) Bool
| 0, f => false
| succ n, f => any n (fun i h => f i (by omega)) || f n (by omega)
/-- Tail-recursive version of `Nat.any`. -/
@[inline] def anyTR (n : Nat) (f : (i : Nat) i < n Bool) : Bool :=
let rec @[specialize] loop : (i : Nat) i n Bool
| 0, h => false
| succ m, h => f (n - succ m) (by omega) || loop m (by omega)
loop n (by omega)
/-- `all f n = true` iff every `i in [0, n-1]` satisfies `f i = true` -/
@[specialize] def all : (n : Nat) (f : (i : Nat) i < n Bool) Bool
| 0, f => true
| succ n, f => all n (fun i h => f i (by omega)) && f n (by omega)
/-- Tail-recursive version of `Nat.all`. -/
@[inline] def allTR (n : Nat) (f : (i : Nat) i < n Bool) : Bool :=
let rec @[specialize] loop : (i : Nat) i n Bool
| 0, h => true
| succ m, h => f (n - succ m) (by omega) && loop m (by omega)
loop n (by omega)
/-! # csimp theorems -/
theorem fold_congr {α : Type u} {n m : Nat} (w : n = m)
(f : (i : Nat) i < n α α) (init : α) :
fold n f init = fold m (fun i h => f i (by omega)) init := by
subst m
rfl
theorem foldTR_loop_congr {α : Type u} {n m : Nat} (w : n = m)
(f : (i : Nat) i < n α α) (j : Nat) (h : j n) (init : α) :
foldTR.loop n f j h init = foldTR.loop m (fun i h => f i (by omega)) j (by omega) init := by
subst m
rfl
@[csimp] theorem fold_eq_foldTR : @fold = @foldTR :=
funext fun α => funext fun n => funext fun f => funext fun init =>
let rec go : m n f, fold (m + n) f init = foldTR.loop (m + n) f m (by omega) (fold n (fun i h => f i (by omega)) init)
| 0, n, f => by
simp only [foldTR.loop]
have t : 0 + n = n := by omega
rw [fold_congr t]
| succ m, n, f => by
have t : (m + 1) + n = m + (n + 1) := by omega
rw [foldTR.loop]
simp only [succ_eq_add_one, Nat.add_sub_cancel]
rw [fold_congr t, foldTR_loop_congr t, go, fold]
congr
omega
go n 0 f
theorem any_congr {n m : Nat} (w : n = m) (f : (i : Nat) i < n Bool) : any n f = any m (fun i h => f i (by omega)) := by
subst m
rfl
theorem anyTR_loop_congr {n m : Nat} (w : n = m) (f : (i : Nat) i < n Bool) (j : Nat) (h : j n) :
anyTR.loop n f j h = anyTR.loop m (fun i h => f i (by omega)) j (by omega) := by
subst m
rfl
@[csimp] theorem any_eq_anyTR : @any = @anyTR :=
funext fun n => funext fun f =>
let rec go : m n f, any (m + n) f = (any n (fun i h => f i (by omega)) || anyTR.loop (m + n) f m (by omega))
| 0, n, f => by
simp [anyTR.loop]
have t : 0 + n = n := by omega
rw [any_congr t]
| succ m, n, f => by
have t : (m + 1) + n = m + (n + 1) := by omega
rw [anyTR.loop]
simp only [succ_eq_add_one]
rw [any_congr t, anyTR_loop_congr t, go, any, Bool.or_assoc]
congr
omega
go n 0 f
theorem all_congr {n m : Nat} (w : n = m) (f : (i : Nat) i < n Bool) : all n f = all m (fun i h => f i (by omega)) := by
subst m
rfl
theorem allTR_loop_congr {n m : Nat} (w : n = m) (f : (i : Nat) i < n Bool) (j : Nat) (h : j n) : allTR.loop n f j h = allTR.loop m (fun i h => f i (by omega)) j (by omega) := by
subst m
rfl
@[csimp] theorem all_eq_allTR : @all = @allTR :=
funext fun n => funext fun f =>
let rec go : m n f, all (m + n) f = (all n (fun i h => f i (by omega)) && allTR.loop (m + n) f m (by omega))
| 0, n, f => by
simp [allTR.loop]
have t : 0 + n = n := by omega
rw [all_congr t]
| succ m, n, f => by
have t : (m + 1) + n = m + (n + 1) := by omega
rw [allTR.loop]
simp only [succ_eq_add_one]
rw [all_congr t, allTR_loop_congr t, go, all, Bool.and_assoc]
congr
omega
go n 0 f
end Nat
namespace Prod
/--
`(start, stop).foldI f a` evaluates `f` on all the numbers
from `start` (inclusive) to `stop` (exclusive) in increasing order:
* `(5, 8).foldI f init = init |> f 5 |> f 6 |> f 7`
-/
@[inline] def foldI {α : Type u} (i : Nat × Nat) (f : (j : Nat) i.1 j j < i.2 α α) (a : α) : α :=
(i.2 - i.1).fold (fun j _ => f (i.1 + j) (by omega) (by omega)) a
/--
`(start, stop).anyI f a` returns true if `f` is true for some natural number
from `start` (inclusive) to `stop` (exclusive):
* `(5, 8).anyI f = f 5 || f 6 || f 7`
-/
@[inline] def anyI (i : Nat × Nat) (f : (j : Nat) i.1 j j < i.2 Bool) : Bool :=
(i.2 - i.1).any (fun j _ => f (i.1 + j) (by omega) (by omega))
/--
`(start, stop).allI f a` returns true if `f` is true for all natural numbers
from `start` (inclusive) to `stop` (exclusive):
* `(5, 8).anyI f = f 5 && f 6 && f 7`
-/
@[inline] def allI (i : Nat × Nat) (f : (j : Nat) i.1 j j < i.2 Bool) : Bool :=
(i.2 - i.1).all (fun j _ => f (i.1 + j) (by omega) (by omega))
end Prod

View File

@@ -1029,3 +1029,12 @@ instance decidableExistsLT [h : DecidablePred p] : DecidablePred fun n => ∃ m
instance decidableExistsLE [DecidablePred p] : DecidablePred fun n => m : Nat, m n p m :=
fun n => decidable_of_iff ( m, m < n + 1 p m)
(exists_congr fun _ => and_congr_left' Nat.lt_succ_iff)
/-! ### Results about `List.sum` specialized to `Nat` -/
protected theorem sum_pos_iff_exists_pos {l : List Nat} : 0 < l.sum x l, 0 < x := by
induction l with
| nil => simp
| cons x xs ih =>
simp [ ih]
omega

View File

@@ -55,7 +55,9 @@ theorem get_eq_getD {fallback : α} : (o : Option α) → {h : o.isSome} → o.g
theorem some_get! [Inhabited α] : (o : Option α) o.isSome some (o.get!) = o
| some _, _ => rfl
theorem get!_eq_getD_default [Inhabited α] (o : Option α) : o.get! = o.getD default := rfl
theorem get!_eq_getD [Inhabited α] (o : Option α) : o.get! = o.getD default := rfl
@[deprecated get!_eq_getD (since := "2024-11-18")] abbrev get!_eq_getD_default := @get!_eq_getD
theorem mem_unique {o : Option α} {a b : α} (ha : a o) (hb : b o) : a = b :=
some.inj <| ha hb

View File

@@ -113,10 +113,10 @@ initialize IO.stdGenRef : IO.Ref StdGen ←
let seed := UInt64.toNat (ByteArray.toUInt64LE! ( IO.getRandomBytes 8))
IO.mkRef (mkStdGen seed)
def IO.setRandSeed (n : Nat) : IO Unit :=
def IO.setRandSeed (n : Nat) : BaseIO Unit :=
IO.stdGenRef.set (mkStdGen n)
def IO.rand (lo hi : Nat) : IO Nat := do
def IO.rand (lo hi : Nat) : BaseIO Nat := do
let gen IO.stdGenRef.get
let (r, gen) := randNat gen lo hi
IO.stdGenRef.set gen

View File

@@ -31,7 +31,7 @@ This file defines basic operations on the the sum type `α ⊕ β`.
## Further material
See `Batteries.Data.Sum.Lemmas` for theorems about these definitions.
See `Init.Data.Sum.Lemmas` for theorems about these definitions.
## Notes

View File

@@ -374,6 +374,9 @@ partial def structEq : Syntax → Syntax → Bool
instance : BEq Lean.Syntax := structEq
instance : BEq (Lean.TSyntax k) := (·.raw == ·.raw)
/--
Finds the first `SourceInfo` from the back of `stx` or `none` if no `SourceInfo` can be found.
-/
partial def getTailInfo? : Syntax Option SourceInfo
| atom info _ => info
| ident info .. => info
@@ -382,14 +385,39 @@ partial def getTailInfo? : Syntax → Option SourceInfo
| node info _ _ => info
| _ => none
/--
Finds the first `SourceInfo` from the back of `stx` or `SourceInfo.none`
if no `SourceInfo` can be found.
-/
def getTailInfo (stx : Syntax) : SourceInfo :=
stx.getTailInfo?.getD SourceInfo.none
/--
Finds the trailing size of the first `SourceInfo` from the back of `stx`.
If no `SourceInfo` can be found or the first `SourceInfo` from the back of `stx` contains no
trailing whitespace, the result is `0`.
-/
def getTrailingSize (stx : Syntax) : Nat :=
match stx.getTailInfo? with
| some (SourceInfo.original (trailing := trailing) ..) => trailing.bsize
| _ => 0
/--
Finds the trailing whitespace substring of the first `SourceInfo` from the back of `stx`.
If no `SourceInfo` can be found or the first `SourceInfo` from the back of `stx` contains
no trailing whitespace, the result is `none`.
-/
def getTrailing? (stx : Syntax) : Option Substring :=
stx.getTailInfo.getTrailing?
/--
Finds the tail position of the trailing whitespace of the first `SourceInfo` from the back of `stx`.
If no `SourceInfo` can be found or the first `SourceInfo` from the back of `stx` contains
no trailing whitespace and lacks a tail position, the result is `none`.
-/
def getTrailingTailPos? (stx : Syntax) (canonicalOnly := false) : Option String.Pos :=
stx.getTailInfo.getTrailingTailPos? canonicalOnly
/--
Return substring of original input covering `stx`.
Result is meaningful only if all involved `SourceInfo.original`s refer to the same string (as is the case after parsing). -/
@@ -403,21 +431,20 @@ def getSubstring? (stx : Syntax) (withLeading := true) (withTrailing := true) :
}
| _, _ => none
@[specialize] private partial def updateLast {α} [Inhabited α] (a : Array α) (f : α Option α) (i : Nat) : Option (Array α) :=
if i == 0 then
none
else
let i := i - 1
let v := a[i]!
@[specialize] private partial def updateLast {α} (a : Array α) (f : α Option α) (i : Fin (a.size + 1)) : Option (Array α) :=
match i with
| 0 => none
| i + 1, h =>
let v := a[i]'(Nat.succ_lt_succ_iff.mp h)
match f v with
| some v => some <| a.set! i v
| none => updateLast a f i
| some v => some <| a.set i v (Nat.succ_lt_succ_iff.mp h)
| none => updateLast a f i, Nat.lt_of_succ_lt h
partial def setTailInfoAux (info : SourceInfo) : Syntax Option Syntax
| atom _ val => some <| atom info val
| ident _ rawVal val pre => some <| ident info rawVal val pre
| node info' k args =>
match updateLast args (setTailInfoAux info) args.size with
match updateLast args (setTailInfoAux info) args.size, by simp with
| some args => some <| node info' k args
| none => none
| _ => none

View File

@@ -71,9 +71,9 @@ def prio : Category := {}
/-- `prec` is a builtin syntax category for precedences. A precedence is a value
that expresses how tightly a piece of syntax binds: for example `1 + 2 * 3` is
parsed as `1 + (2 * 3)` because `*` has a higher pr0ecedence than `+`.
parsed as `1 + (2 * 3)` because `*` has a higher precedence than `+`.
Higher numbers denote higher precedence.
In addition to literals like `37`, there are some special named priorities:
In addition to literals like `37`, there are some special named precedence levels:
* `arg` for the precedence of function arguments
* `max` for the highest precedence used in term parsers (not actually the maximum possible value)
* `lead` for the precedence of terms not supposed to be used as arguments

View File

@@ -22,28 +22,28 @@ syntax explicitBinders := (ppSpace bracketedExplicitBinders)+ <|> unb
open TSyntax.Compat in
def expandExplicitBindersAux (combinator : Syntax) (idents : Array Syntax) (type? : Option Syntax) (body : Syntax) : MacroM Syntax :=
let rec loop (i : Nat) (acc : Syntax) := do
let rec loop (i : Nat) (h : i idents.size) (acc : Syntax) := do
match i with
| 0 => pure acc
| i+1 =>
let ident := idents[i]![0]
| i + 1 =>
let ident := idents[i][0]
let acc match ident.isIdent, type? with
| true, none => `($combinator fun $ident => $acc)
| true, some type => `($combinator fun $ident : $type => $acc)
| false, none => `($combinator fun _ => $acc)
| false, some type => `($combinator fun _ : $type => $acc)
loop i acc
loop idents.size body
loop i (Nat.le_of_succ_le h) acc
loop idents.size (by simp) body
def expandBrackedBindersAux (combinator : Syntax) (binders : Array Syntax) (body : Syntax) : MacroM Syntax :=
let rec loop (i : Nat) (acc : Syntax) := do
let rec loop (i : Nat) (h : i binders.size) (acc : Syntax) := do
match i with
| 0 => pure acc
| i+1 =>
let idents := binders[i]![1].getArgs
let type := binders[i]![3]
loop i ( expandExplicitBindersAux combinator idents (some type) acc)
loop binders.size body
let idents := binders[i][1].getArgs
let type := binders[i][3]
loop i (Nat.le_of_succ_le h) ( expandExplicitBindersAux combinator idents (some type) acc)
loop binders.size (by simp) body
def expandExplicitBinders (combinatorDeclName : Name) (explicitBinders : Syntax) (body : Syntax) : MacroM Syntax := do
let combinator := mkCIdentFrom ( getRef) combinatorDeclName

View File

@@ -3654,7 +3654,8 @@ namespace SourceInfo
/--
Gets the position information from a `SourceInfo`, if available.
If `originalOnly` is true, then `.synthetic` syntax will also return `none`.
If `canonicalOnly` is true, then `.synthetic` syntax with `canonical := false`
will also return `none`.
-/
def getPos? (info : SourceInfo) (canonicalOnly := false) : Option String.Pos :=
match info, canonicalOnly with
@@ -3665,7 +3666,8 @@ def getPos? (info : SourceInfo) (canonicalOnly := false) : Option String.Pos :=
/--
Gets the end position information from a `SourceInfo`, if available.
If `originalOnly` is true, then `.synthetic` syntax will also return `none`.
If `canonicalOnly` is true, then `.synthetic` syntax with `canonical := false`
will also return `none`.
-/
def getTailPos? (info : SourceInfo) (canonicalOnly := false) : Option String.Pos :=
match info, canonicalOnly with
@@ -3674,6 +3676,24 @@ def getTailPos? (info : SourceInfo) (canonicalOnly := false) : Option String.Pos
| synthetic (endPos := endPos) .., false => some endPos
| _, _ => none
/--
Gets the substring representing the trailing whitespace of a `SourceInfo`, if available.
-/
def getTrailing? (info : SourceInfo) : Option Substring :=
match info with
| original (trailing := trailing) .. => some trailing
| _ => none
/--
Gets the end position information of the trailing whitespace of a `SourceInfo`, if available.
If `canonicalOnly` is true, then `.synthetic` syntax with `canonical := false`
will also return `none`.
-/
def getTrailingTailPos? (info : SourceInfo) (canonicalOnly := false) : Option String.Pos :=
match info.getTrailing? with
| some trailing => some trailing.stopPos
| none => info.getTailPos? canonicalOnly
end SourceInfo
/--
@@ -3972,7 +3992,6 @@ position information.
def getPos? (stx : Syntax) (canonicalOnly := false) : Option String.Pos :=
stx.getHeadInfo.getPos? canonicalOnly
/--
Get the ending position of the syntax, if possible.
If `canonicalOnly` is true, non-canonical `synthetic` nodes are treated as not carrying

View File

@@ -30,7 +30,7 @@ Does nothing for non-`node` nodes, or if `i` is out of bounds of the node list.
-/
def setArg (stx : Syntax) (i : Nat) (arg : Syntax) : Syntax :=
match stx with
| node info k args => node info k (args.setD i arg)
| node info k args => node info k (args.setIfInBounds i arg)
| stx => stx
end Lean.Syntax

View File

@@ -462,6 +462,16 @@ Note that it is the caller's job to remove the file after use.
-/
@[extern "lean_io_create_tempfile"] opaque createTempFile : IO (Handle × FilePath)
/--
Creates a temporary directory in the most secure manner possible. There are no race conditions in the
directorys creation. The directory is readable and writable only by the creating user ID.
Returns the new directory's path.
It is the caller's job to remove the directory after use.
-/
@[extern "lean_io_create_tempdir"] opaque createTempDir : IO FilePath
end FS
@[extern "lean_io_getenv"] opaque getEnv (var : @& String) : BaseIO (Option String)
@@ -474,17 +484,6 @@ namespace FS
def withFile (fn : FilePath) (mode : Mode) (f : Handle IO α) : IO α :=
Handle.mk fn mode >>= f
/--
Like `createTempFile` but also takes care of removing the file after usage.
-/
def withTempFile [Monad m] [MonadFinally m] [MonadLiftT IO m] (f : Handle FilePath m α) :
m α := do
let (handle, path) createTempFile
try
f handle path
finally
removeFile path
def Handle.putStrLn (h : Handle) (s : String) : IO Unit :=
h.putStr (s.push '\n')
@@ -675,8 +674,10 @@ def appDir : IO FilePath := do
| throw <| IO.userError s!"System.IO.appDir: unexpected filename '{p}'"
FS.realPath p
namespace FS
/-- Create given path and all missing parents as directories. -/
partial def FS.createDirAll (p : FilePath) : IO Unit := do
partial def createDirAll (p : FilePath) : IO Unit := do
if p.isDir then
return ()
if let some parent := p.parent then
@@ -693,7 +694,7 @@ partial def FS.createDirAll (p : FilePath) : IO Unit := do
/--
Fully remove given directory by deleting all contained files and directories in an unspecified order.
Fails if any contained entry cannot be deleted or was newly created during execution. -/
partial def FS.removeDirAll (p : FilePath) : IO Unit := do
partial def removeDirAll (p : FilePath) : IO Unit := do
for ent in ( p.readDir) do
if ( ent.path.isDir : Bool) then
removeDirAll ent.path
@@ -701,6 +702,32 @@ partial def FS.removeDirAll (p : FilePath) : IO Unit := do
removeFile ent.path
removeDir p
/--
Like `createTempFile`, but also takes care of removing the file after usage.
-/
def withTempFile [Monad m] [MonadFinally m] [MonadLiftT IO m] (f : Handle FilePath m α) :
m α := do
let (handle, path) createTempFile
try
f handle path
finally
removeFile path
/--
Like `createTempDir`, but also takes care of removing the directory after usage.
All files in the directory are recursively deleted, regardless of how or when they were created.
-/
def withTempDir [Monad m] [MonadFinally m] [MonadLiftT IO m] (f : FilePath m α) :
m α := do
let path createTempDir
try
f path
finally
removeDirAll path
end FS
namespace Process
/-- Returns the current working directory of the calling process. -/
@@ -802,6 +829,9 @@ def run (args : SpawnArgs) : IO String := do
end Process
/-- Returns the thread ID of the calling thread. -/
@[extern "lean_io_get_tid"] opaque getTID : BaseIO UInt64
structure AccessRight where
read : Bool := false
write : Bool := false

View File

@@ -29,13 +29,13 @@ def decodeUri (uri : String) : String := Id.run do
let len := rawBytes.size
let mut i := 0
let percent := '%'.toNat.toUInt8
while i < len do
let c := rawBytes[i]!
(decoded, i) := if c == percent && i + 1 < len then
let h1 := rawBytes[i + 1]!
while h : i < len do
let c := rawBytes[i]
(decoded, i) := if h₁ : c == percent i + 1 < len then
let h1 := rawBytes[i + 1]
if let some hd1 := hexDigitToUInt8? h1 then
if i + 2 < len then
let h2 := rawBytes[i + 2]!
if h₂ : i + 2 < len then
let h2 := rawBytes[i + 2]
if let some hd2 := hexDigitToUInt8? h2 then
-- decode the hex digits into a byte.
(decoded.push (hd1 * 16 + hd2), i + 3)

View File

@@ -428,11 +428,11 @@ macro "infer_instance" : tactic => `(tactic| exact inferInstance)
/--
`+opt` is short for `(opt := true)`. It sets the `opt` configuration option to `true`.
-/
syntax posConfigItem := "+" noWs ident
syntax posConfigItem := " +" noWs ident
/--
`-opt` is short for `(opt := false)`. It sets the `opt` configuration option to `false`.
-/
syntax negConfigItem := "-" noWs ident
syntax negConfigItem := " -" noWs ident
/--
`(opt := val)` sets the `opt` configuration option to `val`.
@@ -1155,7 +1155,7 @@ Configuration for the `decide` tactic family.
structure DecideConfig where
/-- If true (default: false), then use only kernel reduction when reducing the `Decidable` instance.
This is more efficient, since the default mode reduces twice (once in the elaborator and again in the kernel),
however kernel reduction ignores transparency settings. The `decide!` tactic is a synonym for `decide +kernel`. -/
however kernel reduction ignores transparency settings. -/
kernel : Bool := false
/-- If true (default: false), then uses the native code compiler to evaluate the `Decidable` instance,
admitting the result via the axiom `Lean.ofReduceBool`. This can be significantly more efficient,
@@ -1165,7 +1165,9 @@ structure DecideConfig where
native : Bool := false
/-- If true (default: true), then when preprocessing the goal, do zeta reduction to attempt to eliminate free variables. -/
zetaReduce : Bool := true
/-- If true (default: false), then when preprocessing reverts free variables. -/
/-- If true (default: false), then when preprocessing, removes irrelevant variables and reverts the local context.
A variable is *relevant* if it appears in the target, if it appears in a relevant variable,
or if it is a proposition that refers to a relevant variable. -/
revert : Bool := false
/--
@@ -1240,17 +1242,6 @@ example : 1 + 1 = 2 := by rfl
-/
syntax (name := decide) "decide" optConfig : tactic
/--
`decide!` is a variant of the `decide` tactic that uses kernel reduction to prove the goal.
It has the following properties:
- Since it uses kernel reduction instead of elaborator reduction, it ignores transparency and can unfold everything.
- While `decide` needs to reduce the `Decidable` instance twice (once during elaboration to verify whether the tactic succeeds,
and once during kernel type checking), the `decide!` tactic reduces it exactly once.
The `decide!` syntax is short for `decide +kernel`.
-/
syntax (name := decideBang) "decide!" optConfig : tactic
/--
`native_decide` is a synonym for `decide +native`.
It will attempt to prove a goal of type `p` by synthesizing an instance

View File

@@ -205,8 +205,8 @@ def getParamInfo (k : ParamMap.Key) : M (Array Param) := do
/-- For each ps[i], if ps[i] is owned, then mark xs[i] as owned. -/
def ownArgsUsingParams (xs : Array Arg) (ps : Array Param) : M Unit :=
xs.size.forM fun i => do
let x := xs[i]!
xs.size.forM fun i _ => do
let x := xs[i]
let p := ps[i]!
unless p.borrow do ownArg x
@@ -216,8 +216,8 @@ def ownArgsUsingParams (xs : Array Arg) (ps : Array Param) : M Unit :=
we would have to insert a `dec xs[i]` after `f xs` and consequently
"break" the tail call. -/
def ownParamsUsingArgs (xs : Array Arg) (ps : Array Param) : M Unit :=
xs.size.forM fun i => do
let x := xs[i]!
xs.size.forM fun i _ => do
let x := xs[i]
let p := ps[i]!
match x with
| Arg.var x => if ( isOwned x) then ownVar p.x

View File

@@ -48,9 +48,9 @@ def requiresBoxedVersion (env : Environment) (decl : Decl) : Bool :=
def mkBoxedVersionAux (decl : Decl) : N Decl := do
let ps := decl.params
let qs ps.mapM fun _ => do let x N.mkFresh; pure { x := x, ty := IRType.object, borrow := false : Param }
let (newVDecls, xs) qs.size.foldM (init := (#[], #[])) fun i (newVDecls, xs) => do
let (newVDecls, xs) qs.size.foldM (init := (#[], #[])) fun i _ (newVDecls, xs) => do
let p := ps[i]!
let q := qs[i]!
let q := qs[i]
if !p.ty.isScalar then
pure (newVDecls, xs.push (Arg.var q.x))
else

View File

@@ -63,7 +63,7 @@ partial def merge (v₁ v₂ : Value) : Value :=
| top, _ => top
| _, top => top
| v₁@(ctor i₁ vs₁), v₂@(ctor i₂ vs₂) =>
if i₁ == i₂ then ctor i₁ <| vs₁.size.fold (init := #[]) fun i r => r.push (merge vs₁[i]! vs₂[i]!)
if i₁ == i₂ then ctor i₁ <| vs₁.size.fold (init := #[]) fun i _ r => r.push (merge vs₁[i] vs₂[i]!)
else choice [v₁, v₂]
| choice vs₁, choice vs₂ => choice <| vs₁.foldl (addChoice merge) vs₂
| choice vs, v => choice <| addChoice merge vs v
@@ -225,8 +225,8 @@ def updateCurrFnSummary (v : Value) : M Unit := do
def updateJPParamsAssignment (ys : Array Param) (xs : Array Arg) : M Bool := do
let ctx read
let currFnIdx := ctx.currFnIdx
ys.size.foldM (init := false) fun i r => do
let y := ys[i]!
ys.size.foldM (init := false) fun i _ r => do
let y := ys[i]
let x := xs[i]!
let yVal findVarValue y.x
let xVal findArgValue x
@@ -282,8 +282,8 @@ partial def interpFnBody : FnBody → M Unit
def inferStep : M Bool := do
let ctx read
modify fun s => { s with assignments := ctx.decls.map fun _ => {} }
ctx.decls.size.foldM (init := false) fun idx modified => do
match ctx.decls[idx]! with
ctx.decls.size.foldM (init := false) fun idx _ modified => do
match ctx.decls[idx] with
| .fdecl (xs := ys) (body := b) .. => do
let s get
let currVals := s.funVals[idx]!
@@ -336,8 +336,8 @@ def elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do
let funVals := s.funVals
let assignments := s.assignments
modify fun s =>
let env := decls.size.fold (init := s.env) fun i env =>
addFunctionSummary env decls[i]!.name funVals[i]!
let env := decls.size.fold (init := s.env) fun i _ env =>
addFunctionSummary env decls[i].name funVals[i]!
{ s with env := env }
return decls.mapIdx fun i decl => elimDead assignments[i]! decl

View File

@@ -108,9 +108,9 @@ def emitFnDeclAux (decl : Decl) (cppBaseName : String) (isExternal : Bool) : M U
if ps.size > closureMaxArgs && isBoxedName decl.name then
emit "lean_object**"
else
ps.size.forM fun i => do
ps.size.forM fun i _ => do
if i > 0 then emit ", "
emit (toCType ps[i]!.ty)
emit (toCType ps[i].ty)
emit ")"
emitLn ";"
@@ -271,9 +271,9 @@ def emitTag (x : VarId) (xType : IRType) : M Unit := do
emit x
def isIf (alts : Array Alt) : Option (Nat × FnBody × FnBody) :=
if alts.size != 2 then none
else match alts[0]! with
| Alt.ctor c b => some (c.cidx, b, alts[1]!.body)
if h : alts.size 2 then none
else match alts[0] with
| Alt.ctor c b => some (c.cidx, b, alts[1].body)
| _ => none
def emitInc (x : VarId) (n : Nat) (checkRef : Bool) : M Unit := do
@@ -321,20 +321,22 @@ def emitSSet (x : VarId) (n : Nat) (offset : Nat) (y : VarId) (t : IRType) : M U
def emitJmp (j : JoinPointId) (xs : Array Arg) : M Unit := do
let ps getJPParams j
unless xs.size == ps.size do throw "invalid goto"
xs.size.forM fun i => do
let p := ps[i]!
let x := xs[i]!
emit p.x; emit " = "; emitArg x; emitLn ";"
emit "goto "; emit j; emitLn ";"
if h : xs.size = ps.size then
xs.size.forM fun i _ => do
let p := ps[i]
let x := xs[i]
emit p.x; emit " = "; emitArg x; emitLn ";"
emit "goto "; emit j; emitLn ";"
else
do throw "invalid goto"
def emitLhs (z : VarId) : M Unit := do
emit z; emit " = "
def emitArgs (ys : Array Arg) : M Unit :=
ys.size.forM fun i => do
ys.size.forM fun i _ => do
if i > 0 then emit ", "
emitArg ys[i]!
emitArg ys[i]
def emitCtorScalarSize (usize : Nat) (ssize : Nat) : M Unit := do
if usize == 0 then emit ssize
@@ -346,8 +348,8 @@ def emitAllocCtor (c : CtorInfo) : M Unit := do
emitCtorScalarSize c.usize c.ssize; emitLn ");"
def emitCtorSetArgs (z : VarId) (ys : Array Arg) : M Unit :=
ys.size.forM fun i => do
emit "lean_ctor_set("; emit z; emit ", "; emit i; emit ", "; emitArg ys[i]!; emitLn ");"
ys.size.forM fun i _ => do
emit "lean_ctor_set("; emit z; emit ", "; emit i; emit ", "; emitArg ys[i]; emitLn ");"
def emitCtor (z : VarId) (c : CtorInfo) (ys : Array Arg) : M Unit := do
emitLhs z;
@@ -358,7 +360,7 @@ def emitCtor (z : VarId) (c : CtorInfo) (ys : Array Arg) : M Unit := do
def emitReset (z : VarId) (n : Nat) (x : VarId) : M Unit := do
emit "if (lean_is_exclusive("; emit x; emitLn ")) {";
n.forM fun i => do
n.forM fun i _ => do
emit " lean_ctor_release("; emit x; emit ", "; emit i; emitLn ");"
emit " "; emitLhs z; emit x; emitLn ";";
emitLn "} else {";
@@ -399,12 +401,12 @@ def emitSimpleExternalCall (f : String) (ps : Array Param) (ys : Array Arg) : M
emit f; emit "("
-- We must remove irrelevant arguments to extern calls.
discard <| ys.size.foldM
(fun i (first : Bool) =>
(fun i _ (first : Bool) =>
if ps[i]!.ty.isIrrelevant then
pure first
else do
unless first do emit ", "
emitArg ys[i]!
emitArg ys[i]
pure false)
true
emitLn ");"
@@ -431,8 +433,8 @@ def emitPartialApp (z : VarId) (f : FunId) (ys : Array Arg) : M Unit := do
let decl getDecl f
let arity := decl.params.size;
emitLhs z; emit "lean_alloc_closure((void*)("; emitCName f; emit "), "; emit arity; emit ", "; emit ys.size; emitLn ");";
ys.size.forM fun i => do
let y := ys[i]!
ys.size.forM fun i _ => do
let y := ys[i]
emit "lean_closure_set("; emit z; emit ", "; emit i; emit ", "; emitArg y; emitLn ");"
def emitApp (z : VarId) (f : VarId) (ys : Array Arg) : M Unit :=
@@ -544,34 +546,36 @@ That is, we have
-/
def overwriteParam (ps : Array Param) (ys : Array Arg) : Bool :=
let n := ps.size;
n.any fun i =>
let p := ps[i]!
(i+1, n).anyI fun j => paramEqArg p ys[j]!
n.any fun i _ =>
let p := ps[i]
(i+1, n).anyI fun j _ _ => paramEqArg p ys[j]!
def emitTailCall (v : Expr) : M Unit :=
match v with
| Expr.fap _ ys => do
let ctx read
let ps := ctx.mainParams
unless ps.size == ys.size do throw "invalid tail call"
if overwriteParam ps ys then
emitLn "{"
ps.size.forM fun i => do
let p := ps[i]!
let y := ys[i]!
unless paramEqArg p y do
emit (toCType p.ty); emit " _tmp_"; emit i; emit " = "; emitArg y; emitLn ";"
ps.size.forM fun i => do
let p := ps[i]!
let y := ys[i]!
unless paramEqArg p y do emit p.x; emit " = _tmp_"; emit i; emitLn ";"
emitLn "}"
if h : ps.size = ys.size then
if overwriteParam ps ys then
emitLn "{"
ps.size.forM fun i _ => do
let p := ps[i]
let y := ys[i]
unless paramEqArg p y do
emit (toCType p.ty); emit " _tmp_"; emit i; emit " = "; emitArg y; emitLn ";"
ps.size.forM fun i _ => do
let p := ps[i]
let y := ys[i]
unless paramEqArg p y do emit p.x; emit " = _tmp_"; emit i; emitLn ";"
emitLn "}"
else
ys.size.forM fun i _ => do
let p := ps[i]
let y := ys[i]
unless paramEqArg p y do emit p.x; emit " = "; emitArg y; emitLn ";"
emitLn "goto _start;"
else
ys.size.forM fun i => do
let p := ps[i]!
let y := ys[i]!
unless paramEqArg p y do emit p.x; emit " = "; emitArg y; emitLn ";"
emitLn "goto _start;"
throw "invalid tail call"
| _ => throw "bug at emitTailCall"
mutual
@@ -654,16 +658,16 @@ def emitDeclAux (d : Decl) : M Unit := do
if xs.size > closureMaxArgs && isBoxedName d.name then
emit "lean_object** _args"
else
xs.size.forM fun i => do
xs.size.forM fun i _ => do
if i > 0 then emit ", "
let x := xs[i]!
let x := xs[i]
emit (toCType x.ty); emit " "; emit x.x
emit ")"
else
emit ("_init_" ++ baseName ++ "()")
emitLn " {";
if xs.size > closureMaxArgs && isBoxedName d.name then
xs.size.forM fun i => do
xs.size.forM fun i _ => do
let x := xs[i]!
emit "lean_object* "; emit x.x; emit " = _args["; emit i; emitLn "];"
emitLn "_start:";

View File

@@ -571,9 +571,9 @@ def emitAllocCtor (builder : LLVM.Builder llvmctx)
def emitCtorSetArgs (builder : LLVM.Builder llvmctx)
(z : VarId) (ys : Array Arg) : M llvmctx Unit := do
ys.size.forM fun i => do
ys.size.forM fun i _ => do
let zv emitLhsVal builder z
let (_yty, yv) emitArgVal builder ys[i]!
let (_yty, yv) emitArgVal builder ys[i]
let iv constIntUnsigned i
callLeanCtorSet builder zv iv yv
emitLhsSlotStore builder z zv
@@ -702,8 +702,8 @@ def emitPartialApp (builder : LLVM.Builder llvmctx) (z : VarId) (f : FunId) (ys
( constIntUnsigned arity)
( constIntUnsigned ys.size)
LLVM.buildStore builder zval zslot
ys.size.forM fun i => do
let (yty, yslot) emitArgSlot_ builder ys[i]!
ys.size.forM fun i _ => do
let (yty, yslot) emitArgSlot_ builder ys[i]
let yval LLVM.buildLoad2 builder yty yslot
callLeanClosureSetFn builder zval ( constIntUnsigned i) yval
@@ -922,7 +922,7 @@ def emitReset (builder : LLVM.Builder llvmctx) (z : VarId) (n : Nat) (x : VarId)
buildIfThenElse_ builder "isExclusive" isExclusive
(fun builder => do
let xv emitLhsVal builder x
n.forM fun i => do
n.forM fun i _ => do
callLeanCtorRelease builder xv ( constIntUnsigned i)
emitLhsSlotStore builder z xv
return ShouldForwardControlFlow.yes
@@ -1172,8 +1172,8 @@ def emitFnArgs (builder : LLVM.Builder llvmctx)
(needsPackedArgs? : Bool) (llvmfn : LLVM.Value llvmctx) (params : Array Param) : M llvmctx Unit := do
if needsPackedArgs? then do
let argsp LLVM.getParam llvmfn 0 -- lean_object **args
for i in List.range params.size do
let param := params[i]!
for h : i in [:params.size] do
let param := params[i]
-- argsi := (args + i)
let argsi LLVM.buildGEP2 builder ( LLVM.voidPtrType llvmctx) argsp #[ constIntUnsigned i] s!"packed_arg_{i}_slot"
let llvmty toLLVMType param.ty
@@ -1182,15 +1182,16 @@ def emitFnArgs (builder : LLVM.Builder llvmctx)
-- slot for arg[i] which is always void* ?
let alloca buildPrologueAlloca builder llvmty s!"arg_{i}"
LLVM.buildStore builder pv alloca
addVartoState params[i]!.x alloca llvmty
addVartoState param.x alloca llvmty
else
let n LLVM.countParams llvmfn
for i in (List.range n.toNat) do
let llvmty toLLVMType params[i]!.ty
for i in [:n.toNat] do
let param := params[i]!
let llvmty toLLVMType param.ty
let alloca buildPrologueAlloca builder llvmty s!"arg_{i}"
let arg LLVM.getParam llvmfn (UInt64.ofNat i)
let _ LLVM.buildStore builder arg alloca
addVartoState params[i]!.x alloca llvmty
addVartoState param.x alloca llvmty
def emitDeclAux (mod : LLVM.Module llvmctx) (builder : LLVM.Builder llvmctx) (d : Decl) : M llvmctx Unit := do
let env getEnv

View File

@@ -54,7 +54,7 @@ abbrev Mask := Array (Option VarId)
partial def eraseProjIncForAux (y : VarId) (bs : Array FnBody) (mask : Mask) (keep : Array FnBody) : Array FnBody × Mask :=
let done (_ : Unit) := (bs ++ keep.reverse, mask)
let keepInstr (b : FnBody) := eraseProjIncForAux y bs.pop mask (keep.push b)
if bs.size < 2 then done ()
if h : bs.size < 2 then done ()
else
let b := bs.back!
match b with
@@ -62,7 +62,7 @@ partial def eraseProjIncForAux (y : VarId) (bs : Array FnBody) (mask : Mask) (ke
| .vdecl _ _ (.uproj _ _) _ => keepInstr b
| .inc z n c p _ =>
if n == 0 then done () else
let b' := bs[bs.size - 2]!
let b' := bs[bs.size - 2]
match b' with
| .vdecl w _ (.proj i x) _ =>
if w == z && y == x then
@@ -134,15 +134,15 @@ abbrev M := ReaderT Context (StateM Nat)
modifyGet fun n => ({ idx := n }, n + 1)
def releaseUnreadFields (y : VarId) (mask : Mask) (b : FnBody) : M FnBody :=
mask.size.foldM (init := b) fun i b =>
match mask.get! i with
mask.size.foldM (init := b) fun i _ b =>
match mask[i] with
| some _ => pure b -- code took ownership of this field
| none => do
let fld mkFresh
pure (FnBody.vdecl fld IRType.object (Expr.proj i y) (FnBody.dec fld 1 true false b))
def setFields (y : VarId) (zs : Array Arg) (b : FnBody) : FnBody :=
zs.size.fold (init := b) fun i b => FnBody.set y i (zs.get! i) b
zs.size.fold (init := b) fun i _ b => FnBody.set y i zs[i] b
/-- Given `set x[i] := y`, return true iff `y := proj[i] x` -/
def isSelfSet (ctx : Context) (x : VarId) (i : Nat) (y : Arg) : Bool :=

View File

@@ -79,13 +79,13 @@ private def addDecForAlt (ctx : Context) (caseLiveVars altLiveVars : LiveVarSet)
/-- `isFirstOcc xs x i = true` if `xs[i]` is the first occurrence of `xs[i]` in `xs` -/
private def isFirstOcc (xs : Array Arg) (i : Nat) : Bool :=
let x := xs[i]!
i.all fun j => xs[j]! != x
i.all fun j _ => xs[j]! != x
/-- Return true if `x` also occurs in `ys` in a position that is not consumed.
That is, it is also passed as a borrow reference. -/
private def isBorrowParamAux (x : VarId) (ys : Array Arg) (consumeParamPred : Nat Bool) : Bool :=
ys.size.any fun i =>
let y := ys[i]!
ys.size.any fun i _ =>
let y := ys[i]
match y with
| Arg.irrelevant => false
| Arg.var y => x == y && !consumeParamPred i
@@ -99,15 +99,15 @@ Return `n`, the number of times `x` is consumed.
- `consumeParamPred i = true` if parameter `i` is consumed.
-/
private def getNumConsumptions (x : VarId) (ys : Array Arg) (consumeParamPred : Nat Bool) : Nat :=
ys.size.fold (init := 0) fun i n =>
let y := ys[i]!
ys.size.fold (init := 0) fun i _ n =>
let y := ys[i]
match y with
| Arg.irrelevant => n
| Arg.var y => if x == y && consumeParamPred i then n+1 else n
private def addIncBeforeAux (ctx : Context) (xs : Array Arg) (consumeParamPred : Nat Bool) (b : FnBody) (liveVarsAfter : LiveVarSet) : FnBody :=
xs.size.fold (init := b) fun i b =>
let x := xs[i]!
xs.size.fold (init := b) fun i _ b =>
let x := xs[i]
match x with
| Arg.irrelevant => b
| Arg.var x =>
@@ -128,8 +128,8 @@ private def addIncBefore (ctx : Context) (xs : Array Arg) (ps : Array Param) (b
/-- See `addIncBeforeAux`/`addIncBefore` for the procedure that inserts `inc` operations before an application. -/
private def addDecAfterFullApp (ctx : Context) (xs : Array Arg) (ps : Array Param) (b : FnBody) (bLiveVars : LiveVarSet) : FnBody :=
xs.size.fold (init := b) fun i b =>
match xs[i]! with
xs.size.fold (init := b) fun i _ b =>
match xs[i] with
| Arg.irrelevant => b
| Arg.var x =>
/- We must add a `dec` if `x` must be consumed, it is alive after the application,

View File

@@ -366,10 +366,10 @@ to be updated.
@[implemented_by updateFunDeclCoreImp] opaque FunDeclCore.updateCore (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : FunDecl
def CasesCore.extractAlt! (cases : Cases) (ctorName : Name) : Alt × Cases :=
let found (i : Nat) := (cases.alts[i]!, { cases with alts := cases.alts.eraseIdx i })
if let some i := cases.alts.findIdx? fun | .alt ctorName' .. => ctorName == ctorName' | _ => false then
let found i := (cases.alts[i], { cases with alts := cases.alts.eraseIdx i })
if let some i := cases.alts.findFinIdx? fun | .alt ctorName' .. => ctorName == ctorName' | _ => false then
found i
else if let some i := cases.alts.findIdx? fun | .default _ => true | _ => false then
else if let some i := cases.alts.findFinIdx? fun | .default _ => true | _ => false then
found i
else
unreachable!

View File

@@ -587,15 +587,15 @@ def Decl.elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do
refer to the docstring of `Decl.safe`.
-/
if decls[i]!.safe then .bot else .top
let mut funVals := decls.size.fold (init := .empty) fun i p => p.push (initialVal i)
let mut funVals := decls.size.fold (init := .empty) fun i _ p => p.push (initialVal i)
let ctx := { decls }
let mut state := { assignments, funVals }
(_, state) inferMain |>.run ctx |>.run state
funVals := state.funVals
assignments := state.assignments
modifyEnv fun e =>
decls.size.fold (init := e) fun i env =>
addFunctionSummary env decls[i]!.name funVals[i]!
decls.size.fold (init := e) fun i _ env =>
addFunctionSummary env decls[i].name funVals[i]!
decls.mapIdxM fun i decl => if decl.safe then elimDead assignments[i]! decl else return decl

View File

@@ -76,8 +76,8 @@ def getType (fvarId : FVarId) : InferTypeM Expr := do
def mkForallFVars (xs : Array Expr) (type : Expr) : InferTypeM Expr :=
let b := type.abstract xs
xs.size.foldRevM (init := b) fun i b => do
let x := xs[i]!
xs.size.foldRevM (init := b) fun i _ b => do
let x := xs[i]
let n InferType.getBinderName x.fvarId!
let ty InferType.getType x.fvarId!
let ty := ty.abstractRange i xs;

View File

@@ -134,9 +134,9 @@ def withEachOccurrence (targetName : Name) (f : Nat → PassInstaller) : PassIns
def installAfter (targetName : Name) (p : Pass Pass) (occurrence : Nat := 0) : PassInstaller where
install passes :=
if let some idx := passes.findIdx? (fun p => p.name == targetName && p.occurrence == occurrence) then
let passUnderTest := passes[idx]!
return passes.insertAt! (idx + 1) (p passUnderTest)
if let some idx := passes.findFinIdx? (fun p => p.name == targetName && p.occurrence == occurrence) then
let passUnderTest := passes[idx]
return passes.insertIdx (idx + 1) (p passUnderTest)
else
throwError s!"Tried to insert pass after {targetName}, occurrence {occurrence} but {targetName} is not in the pass list"
@@ -145,9 +145,9 @@ def installAfterEach (targetName : Name) (p : Pass → Pass) : PassInstaller :=
def installBefore (targetName : Name) (p : Pass Pass) (occurrence : Nat := 0): PassInstaller where
install passes :=
if let some idx := passes.findIdx? (fun p => p.name == targetName && p.occurrence == occurrence) then
let passUnderTest := passes[idx]!
return passes.insertAt! idx (p passUnderTest)
if let some idx := passes.findFinIdx? (fun p => p.name == targetName && p.occurrence == occurrence) then
let passUnderTest := passes[idx]
return passes.insertIdx idx (p passUnderTest)
else
throwError s!"Tried to insert pass after {targetName}, occurrence {occurrence} but {targetName} is not in the pass list"

View File

@@ -152,8 +152,8 @@ def saveSpecParamInfo (decls : Array Decl) : CompilerM Unit := do
let specArgs? := getSpecializationArgs? ( getEnv) decl.name
let contains (i : Nat) : Bool := specArgs?.getD #[] |>.contains i
let mut paramsInfo : Array SpecParamInfo := #[]
for i in [:decl.params.size] do
let param := decl.params[i]!
for h :i in [:decl.params.size] do
let param := decl.params[i]
let info
if contains i then
pure .user
@@ -181,14 +181,14 @@ def saveSpecParamInfo (decls : Array Decl) : CompilerM Unit := do
declsInfo := declsInfo.push paramsInfo
if declsInfo.any fun paramsInfo => paramsInfo.any (· matches .user | .fixedInst | .fixedHO) then
let m := mkFixedParamsMap decls
for i in [:decls.size] do
let decl := decls[i]!
for hi : i in [:decls.size] do
let decl := decls[i]
let mut paramsInfo := declsInfo[i]!
let some mask := m.find? decl.name | unreachable!
trace[Compiler.specialize.info] "{decl.name} {mask}"
paramsInfo := paramsInfo.zipWith mask fun info fixed => if fixed || info matches .user then info else .other
for j in [:paramsInfo.size] do
let mut info := paramsInfo[j]!
let mut info := paramsInfo[j]!
if info matches .fixedNeutral && !hasFwdDeps decl paramsInfo j then
paramsInfo := paramsInfo.set! j .other
if paramsInfo.any fun info => info matches .fixedInst | .fixedHO | .user then

View File

@@ -499,8 +499,8 @@ where
match app with
| .fvar f =>
let mut argsNew := #[]
for i in [arity : args.size] do
argsNew := argsNew.push ( visitAppArg args[i]!)
for h :i in [arity : args.size] do
argsNew := argsNew.push ( visitAppArg args[i])
letValueToArg <| .fvar f argsNew
| .erased | .type .. => return .erased

View File

@@ -26,13 +26,14 @@ private def elabSpecArgs (declName : Name) (args : Array Syntax) : MetaM (Array
if let some idx := arg.isNatLit? then
if idx == 0 then throwErrorAt arg "invalid specialization argument index, index must be greater than 0"
let idx := idx - 1
if idx >= argNames.size then
if h : idx >= argNames.size then
throwErrorAt arg "invalid argument index, `{declName}` has #{argNames.size} arguments"
if result.contains idx then throwErrorAt arg "invalid specialization argument index, `{argNames[idx]!}` has already been specified as a specialization candidate"
result := result.push idx
else
if result.contains idx then throwErrorAt arg "invalid specialization argument index, `{argNames[idx]}` has already been specified as a specialization candidate"
result := result.push idx
else
let argName := arg.getId
if let some idx := argNames.getIdx? argName then
if let some idx := argNames.indexOf? argName then
if result.contains idx then throwErrorAt arg "invalid specialization argument name `{argName}`, it has already been specified as a specialization candidate"
result := result.push idx
else

View File

@@ -11,6 +11,7 @@ import Lean.ResolveName
import Lean.Elab.InfoTree.Types
import Lean.MonadEnv
import Lean.Elab.Exception
import Lean.Language.Basic
namespace Lean
register_builtin_option diagnostics : Bool := {
@@ -72,6 +73,13 @@ structure State where
messages : MessageLog := {}
/-- Info tree. We have the info tree here because we want to update it while adding attributes. -/
infoState : Elab.InfoState := {}
/--
Snapshot trees of asynchronous subtasks. As these are untyped and reported only at the end of the
command's main elaboration thread, they are only useful for basic message log reporting; for
incremental reporting and reuse within a long-running elaboration thread, types rooted in
`CommandParsedSnapshot` need to be adjusted.
-/
snapshotTasks : Array (Language.SnapshotTask Language.SnapshotTree) := #[]
deriving Nonempty
/-- Context for the CoreM monad. -/
@@ -180,7 +188,8 @@ instance : Elab.MonadInfoTree CoreM where
modifyInfoState f := modify fun s => { s with infoState := f s.infoState }
@[inline] def modifyCache (f : Cache Cache) : CoreM Unit :=
modify fun env, next, ngen, trace, cache, messages, infoState => env, next, ngen, trace, f cache, messages, infoState
modify fun env, next, ngen, trace, cache, messages, infoState, snaps =>
env, next, ngen, trace, f cache, messages, infoState, snaps
@[inline] def modifyInstLevelTypeCache (f : InstantiateLevelCache InstantiateLevelCache) : CoreM Unit :=
modifyCache fun c₁, c₂ => f c₁, c₂
@@ -355,13 +364,83 @@ instance : MonadLog CoreM where
if ( read).suppressElabErrors then
-- discard elaboration errors, except for a few important and unlikely misleading ones, on
-- parse error
unless msg.data.hasTag (· matches `Elab.synthPlaceholder | `Tactic.unsolvedGoals) do
unless msg.data.hasTag (· matches `Elab.synthPlaceholder | `Tactic.unsolvedGoals | `trace) do
return
let ctx read
let msg := { msg with data := MessageData.withNamingContext { currNamespace := ctx.currNamespace, openDecls := ctx.openDecls } msg.data };
modify fun s => { s with messages := s.messages.add msg }
/--
Includes a given task (such as from `wrapAsyncAsSnapshot`) in the overall snapshot tree for this
command's elaboration, making its result available to reporting and the language server. The
reporter will not know about this snapshot tree node until the main elaboration thread for this
command has finished so this function is not useful for incremental reporting within a longer
elaboration thread but only for tasks that outlive it such as background kernel checking or proof
elaboration.
-/
def logSnapshotTask (task : Language.SnapshotTask Language.SnapshotTree) : CoreM Unit :=
modify fun s => { s with snapshotTasks := s.snapshotTasks.push task }
/-- Wraps the given action for use in `EIO.asTask` etc., discarding its final monadic state. -/
def wrapAsync (act : Unit CoreM α) : CoreM (EIO Exception α) := do
let st get
let ctx read
let heartbeats := ( IO.getNumHeartbeats) - ctx.initHeartbeats
return withCurrHeartbeats (do
-- include heartbeats since start of elaboration in new thread as well such that forking off
-- an action doesn't suddenly allow it to succeed from a lower heartbeat count
IO.addHeartbeats heartbeats.toUInt64
act () : CoreM _)
|>.run' ctx st
/-- Option for capturing output to stderr during elaboration. -/
register_builtin_option stderrAsMessages : Bool := {
defValue := true
group := "server"
descr := "(server) capture output to the Lean stderr channel (such as from `dbg_trace`) during elaboration of a command as a diagnostic message"
}
open Language in
/--
Wraps the given action for use in `BaseIO.asTask` etc., discarding its final state except for
`logSnapshotTask` tasks, which are reported as part of the returned tree.
-/
def wrapAsyncAsSnapshot (act : Unit CoreM Unit) (desc : String := by exact decl_name%.toString) :
CoreM (BaseIO SnapshotTree) := do
let t wrapAsync fun _ => do
IO.FS.withIsolatedStreams (isolateStderr := stderrAsMessages.get ( getOptions)) do
let tid IO.getTID
-- reset trace state and message log so as not to report them twice
modify ({ · with messages := {}, traceState := { tid } })
try
withTraceNode `Elab.async (fun _ => return desc) do
act ()
catch e =>
logError e.toMessageData
finally
addTraceAsMessages
get
let ctx readThe Core.Context
return do
match ( t.toBaseIO) with
| .ok (output, st) =>
let mut msgs := st.messages
if !output.isEmpty then
msgs := msgs.add {
fileName := ctx.fileName
severity := MessageSeverity.information
pos := ctx.fileMap.toPosition <| ctx.ref.getPos?.getD 0
data := output
}
return .mk {
desc
diagnostics := ( Language.Snapshot.Diagnostics.ofMessageLog msgs)
traces := st.traceState
} st.snapshotTasks
-- interrupt or abort exception as `try catch` above should have caught any others
| .error _ => default
end Core
export Core (CoreM mkFreshUserName checkSystem withCurrHeartbeats)

View File

@@ -365,6 +365,7 @@ structure TextDocumentRegistrationOptions where
inductive MarkupKind where
| plaintext | markdown
deriving DecidableEq, Hashable
instance : FromJson MarkupKind := fun
| str "plaintext" => Except.ok MarkupKind.plaintext
@@ -378,7 +379,7 @@ instance : ToJson MarkupKind := ⟨fun
structure MarkupContent where
kind : MarkupKind
value : String
deriving ToJson, FromJson
deriving ToJson, FromJson, DecidableEq, Hashable
/-- Reference to the progress of some in-flight piece of work.

View File

@@ -25,7 +25,7 @@ inductive CompletionItemKind where
| unit | value | enum | keyword | snippet
| color | file | reference | folder | enumMember
| constant | struct | event | operator | typeParameter
deriving Inhabited, DecidableEq, Repr
deriving Inhabited, DecidableEq, Repr, Hashable
instance : ToJson CompletionItemKind where
toJson a := toJson (a.toCtorIdx + 1)
@@ -39,11 +39,11 @@ structure InsertReplaceEdit where
newText : String
insert : Range
replace : Range
deriving FromJson, ToJson
deriving FromJson, ToJson, BEq, Hashable
inductive CompletionItemTag where
| deprecated
deriving Inhabited, DecidableEq, Repr
deriving Inhabited, DecidableEq, Repr, Hashable
instance : ToJson CompletionItemTag where
toJson t := toJson (t.toCtorIdx + 1)
@@ -73,7 +73,7 @@ structure CompletionItem where
commitCharacters? : string[]
command? : Command
-/
deriving FromJson, ToJson, Inhabited
deriving FromJson, ToJson, Inhabited, BEq, Hashable
structure CompletionList where
isIncomplete : Bool

View File

@@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Data.Nat.Fold
import Init.Data.Array.Basic
import Init.NotationExtra
import Init.Data.ToString.Macro
@@ -371,7 +372,7 @@ instance : ToString Stats := ⟨Stats.toString⟩
end PersistentArray
def mkPersistentArray {α : Type u} (n : Nat) (v : α) : PArray α :=
n.fold (init := PersistentArray.empty) fun _ p => p.push v
n.fold (init := PersistentArray.empty) fun _ _ p => p.push v
@[inline] def mkPArray {α : Type u} (n : Nat) (v : α) : PArray α :=
mkPersistentArray n v

View File

@@ -233,10 +233,10 @@ partial def eraseAux [BEq α] : Node α β → USize → α → Node α β
| n@(Node.collision keys vals heq), _, k =>
match keys.indexOf? k with
| some idx =>
let keys' := keys.feraseIdx idx
have keq := keys.size_feraseIdx idx
let vals' := vals.feraseIdx (Eq.ndrec idx heq)
have veq := vals.size_feraseIdx (Eq.ndrec idx heq)
let keys' := keys.eraseIdx idx
have keq := keys.size_eraseIdx idx _
let vals' := vals.eraseIdx (Eq.ndrec idx heq)
have veq := vals.size_eraseIdx (Eq.ndrec idx heq) _
have : keys.size - 1 = vals.size - 1 := by rw [heq]
Node.collision keys' vals' (keq.trans (this.trans veq.symm))
| none => n

View File

@@ -23,6 +23,7 @@ import Lean.Elab.Quotation
import Lean.Elab.Syntax
import Lean.Elab.Do
import Lean.Elab.StructInst
import Lean.Elab.MutualInductive
import Lean.Elab.Inductive
import Lean.Elab.Structure
import Lean.Elab.Print

View File

@@ -807,8 +807,8 @@ def getElabElimExprInfo (elimExpr : Expr) : MetaM ElabElimInfo := do
These are the primary set of major parameters.
-/
let initMotiveFVars : CollectFVars.State := motiveArgs.foldl (init := {}) collectFVars
let motiveFVars xs.size.foldRevM (init := initMotiveFVars) fun i s => do
let x := xs[i]!
let motiveFVars xs.size.foldRevM (init := initMotiveFVars) fun i _ s => do
let x := xs[i]
if s.fvarSet.contains x.fvarId! then
return collectFVars s ( inferType x)
else
@@ -1347,7 +1347,7 @@ where
let mut unusableNamedArgs := unusableNamedArgs
for x in xs, bInfo in bInfos do
let xDecl x.mvarId!.getDecl
if let some idx := remainingNamedArgs.findIdx? (·.name == xDecl.userName) then
if let some idx := remainingNamedArgs.findFinIdx? (·.name == xDecl.userName) then
/- If there is named argument with name `xDecl.userName`, then it is accounted for and we can't make use of it. -/
remainingNamedArgs := remainingNamedArgs.eraseIdx idx
else
@@ -1355,9 +1355,9 @@ where
/- We found a type of the form (baseName ...).
First, we check if the current argument is an explicit one,
and if the current explicit position "fits" at `args` (i.e., it must be ≤ arg.size) -/
if argIdx args.size && bInfo.isExplicit then
if h : argIdx args.size bInfo.isExplicit then
/- We can insert `e` as an explicit argument -/
return (args.insertAt! argIdx (Arg.expr e), namedArgs)
return (args.insertIdx argIdx (Arg.expr e), namedArgs)
else
/- If we can't add `e` to `args`, we try to add it using a named argument, but this is only possible
if there isn't an argument with the same name occurring before it. -/
@@ -1399,8 +1399,8 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp
let rec loop : Expr List LVal TermElabM Expr
| f, [] => elabAppArgs f namedArgs args expectedType? explicit ellipsis
| f, lval::lvals => do
if let LVal.fieldName (fullRef := fullRef) .. := lval then
addDotCompletionInfo fullRef f expectedType?
if let LVal.fieldName (ref := ref) .. := lval then
addDotCompletionInfo ref f expectedType?
let hasArgs := !namedArgs.isEmpty || !args.isEmpty
let (f, lvalRes) resolveLVal f lval hasArgs
match lvalRes with
@@ -1650,6 +1650,14 @@ private def getSuccesses (candidates : Array (TermElabResult Expr)) : TermElabM
-/
private def mergeFailures (failures : Array (TermElabResult Expr)) : TermElabM α := do
let exs := failures.map fun | .error ex _ => ex | _ => unreachable!
let trees := failures.map (fun | .error _ s => s.meta.core.infoState.trees | _ => unreachable!)
|>.filterMap (·[0]?)
-- Retain partial `InfoTree` subtrees in an `.ofChoiceInfo` node in case of multiple failures.
-- This ensures that the language server still has `Info` to work with when multiple overloaded
-- elaborators fail.
withInfoContext (mkInfo := pure <| .ofChoiceInfo { elaborator := .anonymous, stx := getRef }) do
for tree in trees do
pushInfoTree tree
throwErrorWithNestedErrors "overloaded" exs
private def elabAppAux (f : Syntax) (namedArgs : Array NamedArg) (args : Array Arg) (ellipsis : Bool) (expectedType? : Option Expr) : TermElabM Expr := do

View File

@@ -211,7 +211,7 @@ private def replaceBinderAnnotation (binder : TSyntax ``Parser.Term.bracketedBin
else
`(bracketedBinderF| {$id $[: $ty?]?})
for id in ids.reverse do
if let some idx := binderIds.findIdx? fun binderId => binderId.raw.isIdent && binderId.raw.getId == id.raw.getId then
if let some idx := binderIds.findFinIdx? fun binderId => binderId.raw.isIdent && binderId.raw.getId == id.raw.getId then
binderIds := binderIds.eraseIdx idx
modifiedVarDecls := true
varDeclsNew := varDeclsNew.push ( mkBinder id explicit)

View File

@@ -42,16 +42,15 @@ private def elabOptLevel (stx : Syntax) : TermElabM Level :=
@[builtin_term_elab «completion»] def elabCompletion : TermElab := fun stx expectedType? => do
/- `ident.` is ambiguous in Lean, we may try to be completing a declaration name or access a "field". -/
if stx[0].isIdent then
/- If we can elaborate the identifier successfully, we assume it is a dot-completion. Otherwise, we treat it as
identifier completion with a dangling `.`.
Recall that the server falls back to identifier completion when dot-completion fails. -/
-- Add both an `id` and a `dot` `CompletionInfo` and have the language server figure out which
-- one to use.
addCompletionInfo <| CompletionInfo.id stx stx[0].getId (danglingDot := true) ( getLCtx) expectedType?
let s saveState
try
let e elabTerm stx[0] none
addDotCompletionInfo stx e expectedType?
catch _ =>
s.restore
addCompletionInfo <| CompletionInfo.id stx stx[0].getId (danglingDot := true) ( getLCtx) expectedType?
throwErrorAt stx[1] "invalid field notation, identifier or numeral expected"
else
elabPipeCompletion stx expectedType?
@@ -328,7 +327,7 @@ private def mkSilentAnnotationIfHole (e : Expr) : TermElabM Expr := do
@[builtin_term_elab withAnnotateTerm] def elabWithAnnotateTerm : TermElab := fun stx expectedType? => do
match stx with
| `(with_annotate_term $stx $e) =>
withInfoContext' stx (elabTerm e expectedType?) (mkTermInfo .anonymous (expectedType? := expectedType?) stx)
withTermInfoContext' .anonymous stx (expectedType? := expectedType?) (elabTerm e expectedType?)
| _ => throwUnsupportedSyntax
private unsafe def evalFilePathUnsafe (stx : Syntax) : TermElabM System.FilePath :=

View File

@@ -84,6 +84,7 @@ structure State where
ngen : NameGenerator := {}
infoState : InfoState := {}
traceState : TraceState := {}
snapshotTasks : Array (Language.SnapshotTask Language.SnapshotTree) := #[]
deriving Nonempty
structure Context where
@@ -114,8 +115,7 @@ structure Context where
-/
suppressElabErrors : Bool := false
abbrev CommandElabCoreM (ε) := ReaderT Context $ StateRefT State $ EIO ε
abbrev CommandElabM := CommandElabCoreM Exception
abbrev CommandElabM := ReaderT Context $ StateRefT State $ EIO Exception
abbrev CommandElab := Syntax CommandElabM Unit
structure Linter where
run : Syntax CommandElabM Unit
@@ -198,36 +198,6 @@ instance : AddErrorMessageContext CommandElabM where
let msg addMacroStack msg ctx.macroStack
return (ref, msg)
def mkMessageAux (ctx : Context) (ref : Syntax) (msgData : MessageData) (severity : MessageSeverity) : Message :=
let pos := ref.getPos?.getD ctx.cmdPos
let endPos := ref.getTailPos?.getD pos
mkMessageCore ctx.fileName ctx.fileMap msgData severity pos endPos
private def addTraceAsMessagesCore (ctx : Context) (log : MessageLog) (traceState : TraceState) : MessageLog := Id.run do
if traceState.traces.isEmpty then return log
let mut traces : Std.HashMap (String.Pos × String.Pos) (Array MessageData) :=
for traceElem in traceState.traces do
let ref := replaceRef traceElem.ref ctx.ref
let pos := ref.getPos?.getD 0
let endPos := ref.getTailPos?.getD pos
traces := traces.insert (pos, endPos) <| traces.getD (pos, endPos) #[] |>.push traceElem.msg
let mut log := log
let traces' := traces.toArray.qsort fun ((a, _), _) ((b, _), _) => a < b
for ((pos, endPos), traceMsg) in traces' do
let data := .tagged `trace <| .joinSep traceMsg.toList "\n"
log := log.add <| mkMessageCore ctx.fileName ctx.fileMap data .information pos endPos
return log
private def addTraceAsMessages : CommandElabM Unit := do
let ctx read
-- do not add trace messages if `trace.profiler.output` is set as it would be redundant and
-- pretty printing the trace messages is expensive
if trace.profiler.output.get? ( getOptions) |>.isNone then
modify fun s => { s with
messages := addTraceAsMessagesCore ctx s.messages s.traceState
traceState.traces := {}
}
private def runCore (x : CoreM α) : CommandElabM α := do
let s get
let ctx read
@@ -253,6 +223,7 @@ private def runCore (x : CoreM α) : CommandElabM α := do
nextMacroScope := s.nextMacroScope
infoState.enabled := s.infoState.enabled
traceState := s.traceState
snapshotTasks := s.snapshotTasks
}
let (ea, coreS) liftM x
modify fun s => { s with
@@ -261,6 +232,7 @@ private def runCore (x : CoreM α) : CommandElabM α := do
ngen := coreS.ngen
infoState.trees := s.infoState.trees.append coreS.infoState.trees
traceState.traces := coreS.traceState.traces.map fun t => { t with ref := replaceRef t.ref ctx.ref }
snapshotTasks := coreS.snapshotTasks
messages := s.messages ++ coreS.messages
}
return ea
@@ -268,10 +240,6 @@ private def runCore (x : CoreM α) : CommandElabM α := do
def liftCoreM (x : CoreM α) : CommandElabM α := do
MonadExcept.ofExcept ( runCore (observing x))
private def ioErrorToMessage (ctx : Context) (ref : Syntax) (err : IO.Error) : Message :=
let ref := getBetterRef ref ctx.macroStack
mkMessageAux ctx ref (toString err) MessageSeverity.error
@[inline] def liftIO {α} (x : IO α) : CommandElabM α := do
let ctx read
IO.toEIO (fun (ex : IO.Error) => Exception.error ctx.ref ex.toString) x
@@ -294,9 +262,8 @@ instance : MonadLog CommandElabM where
logMessage msg := do
if ( read).suppressElabErrors then
-- discard elaboration errors on parse error
-- NOTE: unlike `CoreM`'s `logMessage`, we do not currently have any command-level errors that
-- we want to allowlist
return
unless msg.data.hasTag (· matches `trace) do
return
let currNamespace getCurrNamespace
let openDecls getOpenDecls
let msg := { msg with data := MessageData.withNamingContext { currNamespace := currNamespace, openDecls := openDecls } msg.data }
@@ -322,6 +289,61 @@ def runLinters (stx : Syntax) : CommandElabM Unit := do
finally
modify fun s => { savedState with messages := s.messages }
/--
Catches and logs exceptions occurring in `x`. Unlike `try catch` in `CommandElabM`, this function
catches interrupt exceptions as well and thus is intended for use at the top level of elaboration.
Interrupt and abort exceptions are caught but not logged.
-/
@[inline] def withLoggingExceptions (x : CommandElabM Unit) : CommandElabM Unit := fun ctx ref =>
EIO.catchExceptions (withLogging x ctx ref) (fun _ => pure ())
@[inherit_doc Core.wrapAsync]
def wrapAsync (act : Unit CommandElabM α) : CommandElabM (EIO Exception α) := do
return act () |>.run ( read) |>.run' ( get)
open Language in
@[inherit_doc Core.wrapAsyncAsSnapshot]
-- `CoreM` and `CommandElabM` are too different to meaningfully share this code
def wrapAsyncAsSnapshot (act : Unit CommandElabM Unit)
(desc : String := by exact decl_name%.toString) :
CommandElabM (BaseIO SnapshotTree) := do
let t wrapAsync fun _ => do
IO.FS.withIsolatedStreams (isolateStderr := Core.stderrAsMessages.get ( getOptions)) do
let tid IO.getTID
-- reset trace state and message log so as not to report them twice
modify ({ · with messages := {}, traceState := { tid } })
try
withTraceNode `Elab.async (fun _ => return desc) do
act ()
catch e =>
logError e.toMessageData
finally
addTraceAsMessages
get
let ctx read
return do
match ( t.toBaseIO) with
| .ok (output, st) =>
let mut msgs := st.messages
if !output.isEmpty then
msgs := msgs.add {
fileName := ctx.fileName
severity := MessageSeverity.information
pos := ctx.fileMap.toPosition <| ctx.ref.getPos?.getD 0
data := output
}
return .mk {
desc
diagnostics := ( Language.Snapshot.Diagnostics.ofMessageLog msgs)
traces := st.traceState
} st.snapshotTasks
-- interrupt or abort exception as `try catch` above should have caught any others
| .error _ => default
@[inherit_doc Core.logSnapshotTask]
def logSnapshotTask (task : Language.SnapshotTask Language.SnapshotTree) : CommandElabM Unit :=
modify fun s => { s with snapshotTasks := s.snapshotTasks.push task }
protected def getCurrMacroScope : CommandElabM Nat := do pure ( read).currMacroScope
protected def getMainModule : CommandElabM Name := do pure ( getEnv).mainModule
@@ -532,12 +554,6 @@ def elabCommandTopLevel (stx : Syntax) : CommandElabM Unit := withRef stx do pro
let mut msgs := ( get).messages
for tree in ( getInfoTrees) do
trace[Elab.info] ( tree.format)
if ( isTracingEnabledFor `Elab.snapshotTree) then
if let some snap := ( read).snap? then
-- We can assume that the root command snapshot is not involved in parallelism yet, so this
-- should be true iff the command supports incrementality
if ( IO.hasFinished snap.new.result) then
liftCoreM <| Language.ToSnapshotTree.toSnapshotTree snap.new.result.get |>.trace
modify fun st => { st with
messages := initMsgs ++ msgs
infoState := { st.infoState with trees := initInfoTrees ++ st.infoState.trees }
@@ -555,7 +571,11 @@ private def getVarDecls (s : State) : Array Syntax :=
instance {α} : Inhabited (CommandElabM α) where
default := throw default
private def mkMetaContext : Meta.Context := {
/--
The environment linter framework needs to be able to run linters with the same context
as `liftTermElabM`, so we expose that context as a public function here.
-/
def mkMetaContext : Meta.Context := {
config := { foApprox := true, ctxApprox := true, quasiPatternApprox := true }
}
@@ -664,14 +684,6 @@ def runTermElabM (elabFn : Array Expr → TermElabM α) : CommandElabM α := do
Term.addAutoBoundImplicits' xs someType fun xs _ =>
Term.withoutAutoBoundImplicit <| elabFn xs
/--
Catches and logs exceptions occurring in `x`. Unlike `try catch` in `CommandElabM`, this function
catches interrupt exceptions as well and thus is intended for use at the top level of elaboration.
Interrupt and abort exceptions are caught but not logged.
-/
@[inline] def withLoggingExceptions (x : CommandElabM Unit) : CommandElabCoreM Empty Unit := fun ctx ref =>
EIO.catchExceptions (withLogging x ctx ref) (fun _ => pure ())
private def liftAttrM {α} (x : AttrM α) : CommandElabM α := do
liftCoreM x

View File

@@ -7,9 +7,8 @@ prelude
import Lean.Util.CollectLevelParams
import Lean.Elab.DeclUtil
import Lean.Elab.DefView
import Lean.Elab.Inductive
import Lean.Elab.Structure
import Lean.Elab.MutualDef
import Lean.Elab.MutualInductive
import Lean.Elab.DeclarationRange
namespace Lean.Elab.Command
@@ -163,15 +162,11 @@ def elabDeclaration : CommandElab := fun stx => do
if declKind == ``Lean.Parser.Command.«axiom» then
let modifiers elabModifiers modifiers
elabAxiom modifiers decl
else if declKind == ``Lean.Parser.Command.«inductive» then
else if declKind == ``Lean.Parser.Command.«inductive»
|| declKind == ``Lean.Parser.Command.classInductive
|| declKind == ``Lean.Parser.Command.«structure» then
let modifiers elabModifiers modifiers
elabInductive modifiers decl
else if declKind == ``Lean.Parser.Command.classInductive then
let modifiers elabModifiers modifiers
elabClassInductive modifiers decl
else if declKind == ``Lean.Parser.Command.«structure» then
let modifiers elabModifiers modifiers
elabStructure modifiers decl
else
throwError "unexpected declaration"
@@ -278,10 +273,10 @@ def elabMutual : CommandElab := fun stx => do
-- only case implementing incrementality currently
elabMutualDef stx[1].getArgs
else withoutCommandIncrementality true do
if isMutualInductive stx then
if isMutualInductive stx then
elabMutualInductive stx[1].getArgs
else
throwError "invalid mutual block: either all elements of the block must be inductive declarations, or they must all be definitions/theorems/abbrevs"
throwError "invalid mutual block: either all elements of the block must be inductive/structure declarations, or they must all be definitions/theorems/abbrevs"
/- leading_parser "attribute " >> "[" >> sepBy1 (eraseAttr <|> Term.attrInstance) ", " >> "]" >> many1 ident -/
@[builtin_command_elab «attribute»] def elabAttr : CommandElab := fun stx => do

View File

@@ -49,9 +49,9 @@ invoking ``mkInstImplicitBinders `BarClass foo #[`α, `n, `β]`` gives `` `([Bar
def mkInstImplicitBinders (className : Name) (indVal : InductiveVal) (argNames : Array Name) : TermElabM (Array Syntax) :=
forallBoundedTelescope indVal.type indVal.numParams fun xs _ => do
let mut binders := #[]
for i in [:xs.size] do
for h : i in [:xs.size] do
try
let x := xs[i]!
let x := xs[i]
let c mkAppM className #[x]
if ( isTypeCorrect c) then
let argName := argNames[i]!
@@ -86,8 +86,8 @@ def mkContext (fnPrefix : String) (typeName : Name) : TermElabM Context := do
def mkLocalInstanceLetDecls (ctx : Context) (className : Name) (argNames : Array Name) : TermElabM (Array (TSyntax ``Parser.Term.letDecl)) := do
let mut letDecls := #[]
for i in [:ctx.typeInfos.size] do
let indVal := ctx.typeInfos[i]!
for h : i in [:ctx.typeInfos.size] do
let indVal := ctx.typeInfos[i]
let auxFunName := ctx.auxFunNames[i]!
let currArgNames mkInductArgNames indVal
let numParams := indVal.numParams

View File

@@ -796,10 +796,10 @@ Note that we are not restricting the macro power since the
actions to be in the same universe.
-/
private def mkTuple (elems : Array Syntax) : MacroM Syntax := do
if elems.size == 0 then
if elems.size = 0 then
mkUnit
else if elems.size == 1 then
return elems[0]!
else if h : elems.size = 1 then
return elems[0]
else
elems.extract 0 (elems.size - 1) |>.foldrM (init := elems.back!) fun elem tuple =>
``(MProd.mk $elem $tuple)
@@ -831,10 +831,10 @@ def isDoExpr? (doElem : Syntax) : Option Syntax :=
We use this method when expanding the `for-in` notation.
-/
private def destructTuple (uvars : Array Var) (x : Syntax) (body : Syntax) : MacroM Syntax := do
if uvars.size == 0 then
if uvars.size = 0 then
return body
else if uvars.size == 1 then
`(let $(uvars[0]!):ident := $x; $body)
else if h : uvars.size = 1 then
`(let $(uvars[0]):ident := $x; $body)
else
destruct uvars.toList x body
where
@@ -1314,9 +1314,9 @@ private partial def expandLiftMethodAux (inQuot : Bool) (inBinder : Bool) : Synt
else if liftMethodDelimiter k then
return stx
-- For `pure` if-then-else, we only lift `(<- ...)` occurring in the condition.
else if args.size >= 2 && (k == ``termDepIfThenElse || k == ``termIfThenElse) then do
else if h : args.size >= 2 (k == ``termDepIfThenElse || k == ``termIfThenElse) then do
let inAntiquot := stx.isAntiquot && !stx.isEscapedAntiquot
let arg1 expandLiftMethodAux (inQuot && !inAntiquot || stx.isQuot) inBinder args[1]!
let arg1 expandLiftMethodAux (inQuot && !inAntiquot || stx.isQuot) inBinder args[1]
let args := args.set! 1 arg1
return Syntax.node i k args
else if k == ``Parser.Term.liftMethod && !inQuot then withFreshMacroScope do
@@ -1518,7 +1518,7 @@ mutual
-/
partial def doForToCode (doFor : Syntax) (doElems : List Syntax) : M CodeBlock := do
let doForDecls := doFor[1].getSepArgs
if doForDecls.size > 1 then
if h : doForDecls.size > 1 then
/-
Expand
```

View File

@@ -327,15 +327,18 @@ private def toExprCore (t : Tree) : TermElabM Expr := do
| .term _ trees e =>
modifyInfoState (fun s => { s with trees := s.trees ++ trees }); return e
| .binop ref kind f lhs rhs =>
withRef ref <| withInfoContext' ref (mkInfo := mkTermInfo .anonymous ref) do
mkBinOp (kind == .lazy) f ( toExprCore lhs) ( toExprCore rhs)
withRef ref <|
withTermInfoContext' .anonymous ref do
mkBinOp (kind == .lazy) f ( toExprCore lhs) ( toExprCore rhs)
| .unop ref f arg =>
withRef ref <| withInfoContext' ref (mkInfo := mkTermInfo .anonymous ref) do
mkUnOp f ( toExprCore arg)
withRef ref <|
withTermInfoContext' .anonymous ref do
mkUnOp f ( toExprCore arg)
| .macroExpansion macroName stx stx' nested =>
withRef stx <| withInfoContext' stx (mkInfo := mkTermInfo macroName stx) do
withMacroExpansion stx stx' do
toExprCore nested
withRef stx <|
withTermInfoContext' macroName stx <|
withMacroExpansion stx stx' <|
toExprCore nested
/--
Auxiliary function to decide whether we should coerce `f`'s argument to `maxType` or not.

View File

@@ -102,7 +102,7 @@ partial def IO.processCommandsIncrementally (inputCtx : Parser.InputContext)
where
go initialSnap t commands :=
let snap := t.get
let commands := commands.push snap.data
let commands := commands.push snap
if let some next := snap.nextCmdSnap? then
go initialSnap next.task commands
else
@@ -115,9 +115,9 @@ where
-- snapshots as they subsume any info trees reported incrementally by their children.
let trees := commands.map (·.finishedSnap.get.infoTree?) |>.filterMap id |>.toPArray'
return {
commandState := { snap.data.finishedSnap.get.cmdState with messages, infoState.trees := trees }
parserState := snap.data.parserState
cmdPos := snap.data.parserState.pos
commandState := { snap.finishedSnap.get.cmdState with messages, infoState.trees := trees }
parserState := snap.parserState
cmdPos := snap.parserState.pos
commands := commands.map (·.stx)
inputCtx, initialSnap
}
@@ -164,8 +164,8 @@ def runFrontend
| return ( mkEmptyEnvironment, false)
if let some out := trace.profiler.output.get? opts then
let traceState := cmdState.traceState
let profile Firefox.Profile.export mainModuleName.toString startTime traceState opts
let traceStates := snaps.getAll.map (·.traces)
let profile Firefox.Profile.export mainModuleName.toString startTime traceStates opts
IO.FS.writeFile out <| Json.compress <| toJson profile
let hasErrors := snaps.getAll.any (·.diagnostics.msgLog.hasErrors)

File diff suppressed because it is too large Load Diff

View File

@@ -139,12 +139,16 @@ def TermInfo.runMetaM (info : TermInfo) (ctx : ContextInfo) (x : MetaM α) : IO
def TermInfo.format (ctx : ContextInfo) (info : TermInfo) : IO Format := do
info.runMetaM ctx do
let ty : Format try
Meta.ppExpr ( Meta.inferType info.expr)
catch _ =>
pure "<failed-to-infer-type>"
let ty : Format
try
Meta.ppExpr ( Meta.inferType info.expr)
catch _ =>
pure "<failed-to-infer-type>"
return f!"{← Meta.ppExpr info.expr} {if info.isBinder then "(isBinder := true) " else ""}: {ty} @ {formatElabInfo ctx info.toElabInfo}"
def PartialTermInfo.format (ctx : ContextInfo) (info : PartialTermInfo) : Format :=
f!"Partial term @ {formatElabInfo ctx info.toElabInfo}"
def CompletionInfo.format (ctx : ContextInfo) (info : CompletionInfo) : IO Format :=
match info with
| .dot i (expectedType? := expectedType?) .. => return f!"[.] {← i.format ctx} : {expectedType?}"
@@ -191,9 +195,13 @@ def FieldRedeclInfo.format (ctx : ContextInfo) (info : FieldRedeclInfo) : Format
def OmissionInfo.format (ctx : ContextInfo) (info : OmissionInfo) : IO Format := do
return f!"Omission @ {← TermInfo.format ctx info.toTermInfo}\nReason: {info.reason}"
def ChoiceInfo.format (ctx : ContextInfo) (info : ChoiceInfo) : Format :=
f!"Choice @ {formatElabInfo ctx info.toElabInfo}"
def Info.format (ctx : ContextInfo) : Info IO Format
| ofTacticInfo i => i.format ctx
| ofTermInfo i => i.format ctx
| ofPartialTermInfo i => pure <| i.format ctx
| ofCommandInfo i => i.format ctx
| ofMacroExpansionInfo i => i.format ctx
| ofOptionInfo i => i.format ctx
@@ -204,10 +212,12 @@ def Info.format (ctx : ContextInfo) : Info → IO Format
| ofFVarAliasInfo i => pure <| i.format
| ofFieldRedeclInfo i => pure <| i.format ctx
| ofOmissionInfo i => i.format ctx
| ofChoiceInfo i => pure <| i.format ctx
def Info.toElabInfo? : Info Option ElabInfo
| ofTacticInfo i => some i.toElabInfo
| ofTermInfo i => some i.toElabInfo
| ofPartialTermInfo i => some i.toElabInfo
| ofCommandInfo i => some i.toElabInfo
| ofMacroExpansionInfo _ => none
| ofOptionInfo _ => none
@@ -218,6 +228,7 @@ def Info.toElabInfo? : Info → Option ElabInfo
| ofFVarAliasInfo _ => none
| ofFieldRedeclInfo _ => none
| ofOmissionInfo i => some i.toElabInfo
| ofChoiceInfo i => some i.toElabInfo
/--
Helper function for propagating the tactic metavariable context to its children nodes.
@@ -311,24 +322,36 @@ def realizeGlobalNameWithInfos (ref : Syntax) (id : Name) : CoreM (List (Name ×
addConstInfo ref n
return ns
/-- Use this to descend a node on the infotree that is being built.
/--
Adds a node containing the `InfoTree`s generated by `x` to the `InfoTree`s in `m`.
It saves the current list of trees `t₀` and resets it and then runs `x >>= mkInfo`, producing either an `i : Info` or a hole id.
Running `x >>= mkInfo` will modify the trees state and produce a new list of trees `t₁`.
In the `i : Info` case, `t₁` become the children of a node `node i t₁` that is appended to `t₀`.
-/
def withInfoContext' [MonadFinally m] (x : m α) (mkInfo : α m (Sum Info MVarId)) : m α := do
If `x` succeeds and `mkInfo` yields an `Info`, the `InfoTree`s of `x` become subtrees of a node
containing the `Info` produced by `mkInfo`, which is then added to the `InfoTree`s in `m`.
If `x` succeeds and `mkInfo` yields an `MVarId`, the `InfoTree`s of `x` are discarded and a `hole`
node is added to the `InfoTree`s in `m`.
If `x` fails, the `InfoTree`s of `x` become subtrees of a node containing the `Info` produced by
`mkInfoOnError`, which is then added to the `InfoTree`s in `m`.
The `InfoTree`s in `m` are reset before `x` is executed and restored with the addition of a new tree
after `x` is executed.
-/
def withInfoContext'
[MonadFinally m]
(x : m α)
(mkInfo : α m (Sum Info MVarId))
(mkInfoOnError : m Info) :
m α := do
if ( getInfoState).enabled then
let treesSaved getResetInfoTrees
Prod.fst <$> MonadFinally.tryFinally' x fun a? => do
match a? with
| none => modifyInfoTrees fun _ => treesSaved
| some a =>
let info mkInfo a
modifyInfoTrees fun trees =>
match info with
| Sum.inl info => treesSaved.push <| InfoTree.node info trees
| Sum.inr mvarId => treesSaved.push <| InfoTree.hole mvarId
let info do
match a? with
| none => pure <| .inl <| mkInfoOnError
| some a => mkInfo a
modifyInfoTrees fun trees =>
match info with
| Sum.inl info => treesSaved.push <| InfoTree.node info trees
| Sum.inr mvarId => treesSaved.push <| InfoTree.hole mvarId
else
x

View File

@@ -70,6 +70,18 @@ structure TermInfo extends ElabInfo where
isBinder : Bool := false
deriving Inhabited
/--
Used instead of `TermInfo` when a term couldn't successfully be elaborated,
and so there is no complete expression available.
The main purpose of `PartialTermInfo` is to ensure that the sub-`InfoTree`s of a failed elaborator
are retained so that they can still be used in the language server.
-/
structure PartialTermInfo extends ElabInfo where
lctx : LocalContext -- The local context when the term was elaborated.
expectedType? : Option Expr
deriving Inhabited
structure CommandInfo extends ElabInfo where
deriving Inhabited
@@ -79,7 +91,7 @@ inductive CompletionInfo where
| dot (termInfo : TermInfo) (expectedType? : Option Expr)
| id (stx : Syntax) (id : Name) (danglingDot : Bool) (lctx : LocalContext) (expectedType? : Option Expr)
| dotId (stx : Syntax) (id : Name) (lctx : LocalContext) (expectedType? : Option Expr)
| fieldId (stx : Syntax) (id : Name) (lctx : LocalContext) (structName : Name)
| fieldId (stx : Syntax) (id : Option Name) (lctx : LocalContext) (structName : Name)
| namespaceId (stx : Syntax)
| option (stx : Syntax)
| endSection (stx : Syntax) (scopeNames : List String)
@@ -165,10 +177,18 @@ regular delaboration settings.
structure OmissionInfo extends TermInfo where
reason : String
/--
Indicates that all overloaded elaborators failed. The subtrees of a `ChoiceInfo` node are the
partial `InfoTree`s of those failed elaborators. Retaining these partial `InfoTree`s helps
the language server provide interactivity even when all overloaded elaborators failed.
-/
structure ChoiceInfo extends ElabInfo where
/-- Header information for a node in `InfoTree`. -/
inductive Info where
| ofTacticInfo (i : TacticInfo)
| ofTermInfo (i : TermInfo)
| ofPartialTermInfo (i : PartialTermInfo)
| ofCommandInfo (i : CommandInfo)
| ofMacroExpansionInfo (i : MacroExpansionInfo)
| ofOptionInfo (i : OptionInfo)
@@ -179,6 +199,7 @@ inductive Info where
| ofFVarAliasInfo (i : FVarAliasInfo)
| ofFieldRedeclInfo (i : FieldRedeclInfo)
| ofOmissionInfo (i : OmissionInfo)
| ofChoiceInfo (i : ChoiceInfo)
deriving Inhabited
/-- The InfoTree is a structure that is generated during elaboration and used

View File

@@ -87,12 +87,15 @@ private def elabLetRecDeclValues (view : LetRecView) : TermElabM (Array Expr) :=
view.decls.mapM fun view => do
forallBoundedTelescope view.type view.binderIds.size fun xs type => do
-- Add new info nodes for new fvars. The server will detect all fvars of a binder by the binder's source location.
for i in [0:view.binderIds.size] do
addLocalVarInfo view.binderIds[i]! xs[i]!
for h : i in [0:view.binderIds.size] do
addLocalVarInfo view.binderIds[i] xs[i]!
withDeclName view.declName do
withInfoContext' view.valStx (mkInfo := (pure <| .inl <| mkBodyInfo view.valStx ·)) do
let value elabTermEnsuringType view.valStx type
mkLambdaFVars xs value
withInfoContext' view.valStx
(mkInfo := (pure <| .inl <| mkBodyInfo view.valStx ·))
(mkInfoOnError := (pure <| mkBodyInfo view.valStx none))
do
let value elabTermEnsuringType view.valStx type
mkLambdaFVars xs value
private def registerLetRecsToLift (views : Array LetRecDeclView) (fvars : Array Expr) (values : Array Expr) : TermElabM Unit := do
let letRecsToLiftCurr := ( get).letRecsToLift

View File

@@ -282,8 +282,8 @@ where
let dArg := dArgs[i]!
unless ( isDefEq tArg dArg) do
return i :: ( goType tArg dArg)
for i in [info.numParams : tArgs.size] do
let tArg := tArgs[i]!
for h : i in [info.numParams : tArgs.size] do
let tArg := tArgs[i]
let dArg := dArgs[i]!
unless ( isDefEq tArg dArg) do
return i :: ( goIndex tArg dArg)
@@ -644,7 +644,7 @@ where
if inaccessible? p |>.isSome then
return mkMData k ( withReader (fun _ => true) (go b))
else if let some (stx, p) := patternWithRef? p then
Elab.withInfoContext' (go p) fun p => do
Elab.withInfoContext' (go p) (mkInfoOnError := mkPartialTermInfo .anonymous stx) fun p => do
/- If `p` is a free variable and we are not inside of an "inaccessible" pattern, this `p` is a binder. -/
mkTermInfo Name.anonymous stx p (isBinder := p.isFVar && !( read))
else

View File

@@ -283,7 +283,7 @@ private partial def withFunLocalDecls {α} (headers : Array DefViewElabHeader) (
loop 0 #[]
private def expandWhereStructInst : Macro
| `(Parser.Command.whereStructInst|where $[$decls:letDecl];* $[$whereDecls?:whereDecls]?) => do
| whereStx@`(Parser.Command.whereStructInst|where%$whereTk $[$decls:letDecl];* $[$whereDecls?:whereDecls]?) => do
let letIdDecls decls.mapM fun stx => match stx with
| `(letDecl|$_decl:letPatDecl) => Macro.throwErrorAt stx "patterns are not allowed here"
| `(letDecl|$decl:letEqnsDecl) => expandLetEqnsDecl decl (useExplicit := false)
@@ -300,7 +300,30 @@ private def expandWhereStructInst : Macro
`(structInstField|$id:ident := $val)
| stx@`(letIdDecl|_ $_* $[: $_]? := $_) => Macro.throwErrorAt stx "'_' is not allowed here"
| _ => Macro.throwUnsupported
let startOfStructureTkInfo : SourceInfo :=
match whereTk.getPos? with
| some pos => .synthetic pos pos.byteIdx + 1 true
| none => .none
-- Position the closing `}` at the end of the trailing whitespace of `where $[$_:letDecl];*`.
-- We need an accurate range of the generated structure instance in the generated `TermInfo`
-- so that we can determine the expected type in structure field completion.
let structureStxTailInfo :=
whereStx[1].getTailInfo?
<|> whereStx[0].getTailInfo?
let endOfStructureTkInfo : SourceInfo :=
match structureStxTailInfo with
| some (SourceInfo.original _ _ trailing _) =>
let tokenPos := trailing.str.prev trailing.stopPos
let tokenEndPos := trailing.stopPos
.synthetic tokenPos tokenEndPos true
| _ => .none
let body `(structInst| { $structInstFields,* })
let body := body.raw.setInfo <|
match startOfStructureTkInfo.getPos?, endOfStructureTkInfo.getTailPos? with
| some startPos, some endPos => .synthetic startPos endPos true
| _, _ => .none
match whereDecls? with
| some whereDecls => expandWhereDecls whereDecls body
| none => return body
@@ -417,12 +440,15 @@ private def elabFunValues (headers : Array DefViewElabHeader) (vars : Array Expr
-- Store instantiated body in info tree for the benefit of the unused variables linter
-- and other metaprograms that may want to inspect it without paying for the instantiation
-- again
withInfoContext' valStx (mkInfo := (pure <| .inl <| mkBodyInfo valStx ·)) do
-- synthesize mvars here to force the top-level tactic block (if any) to run
let val elabTermEnsuringType valStx type <* synthesizeSyntheticMVarsNoPostponing
-- NOTE: without this `instantiatedMVars`, `mkLambdaFVars` may leave around a redex that
-- leads to more section variables being included than necessary
instantiateMVarsProfiling val
withInfoContext' valStx
(mkInfo := (pure <| .inl <| mkBodyInfo valStx ·))
(mkInfoOnError := (pure <| mkBodyInfo valStx none))
do
-- synthesize mvars here to force the top-level tactic block (if any) to run
let val elabTermEnsuringType valStx type <* synthesizeSyntheticMVarsNoPostponing
-- NOTE: without this `instantiatedMVars`, `mkLambdaFVars` may leave around a redex that
-- leads to more section variables being included than necessary
instantiateMVarsProfiling val
let val mkLambdaFVars xs val
if linter.unusedSectionVars.get ( getOptions) && !header.type.hasSorry && !val.hasSorry then
let unusedVars vars.filterMapM fun var => do
@@ -814,8 +840,8 @@ private def mkLetRecClosures (sectionVars : Array Expr) (mainFVarIds : Array FVa
abbrev Replacement := FVarIdMap Expr
def insertReplacementForMainFns (r : Replacement) (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainFVars : Array Expr) : Replacement :=
mainFVars.size.fold (init := r) fun i r =>
r.insert mainFVars[i]!.fvarId! (mkAppN (Lean.mkConst mainHeaders[i]!.declName) sectionVars)
mainFVars.size.fold (init := r) fun i _ r =>
r.insert mainFVars[i].fvarId! (mkAppN (Lean.mkConst mainHeaders[i]!.declName) sectionVars)
def insertReplacementForLetRecs (r : Replacement) (letRecClosures : List LetRecClosure) : Replacement :=
@@ -845,8 +871,8 @@ def Replacement.apply (r : Replacement) (e : Expr) : Expr :=
def pushMain (preDefs : Array PreDefinition) (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainVals : Array Expr)
: TermElabM (Array PreDefinition) :=
mainHeaders.size.foldM (init := preDefs) fun i preDefs => do
let header := mainHeaders[i]!
mainHeaders.size.foldM (init := preDefs) fun i _ preDefs => do
let header := mainHeaders[i]
let termination declValToTerminationHint header.value
let termination := termination.rememberExtraParams header.numParams mainVals[i]!
let value mkLambdaFVars sectionVars mainVals[i]!

File diff suppressed because it is too large Load Diff

View File

@@ -49,12 +49,12 @@ private def resolveNameUsingNamespacesCore (nss : List Name) (idStx : Syntax) :
exs := exs.push ex
if exs.size == nss.length then
withRef idStx do
if exs.size == 1 then
throw exs[0]!
if h : exs.size = 1 then
throw exs[0]
else
throwErrorWithNestedErrors "failed to open" exs
if result.size == 1 then
return result[0]!
if h : result.size = 1 then
return result[0]
else
withRef idStx do throwError "ambiguous identifier '{idStx.getId}', possible interpretations: {result.map mkConst}"

View File

@@ -332,9 +332,9 @@ where
else
let accessible := isNextArgAccessible ctx
let (d, ctx) := getNextParam ctx
match ctx.namedArgs.findIdx? fun namedArg => namedArg.name == d.1 with
match ctx.namedArgs.findFinIdx? fun namedArg => namedArg.name == d.1 with
| some idx =>
let arg := ctx.namedArgs[idx]!
let arg := ctx.namedArgs[idx]
let ctx := { ctx with namedArgs := ctx.namedArgs.eraseIdx idx }
let ctx pushNewArg accessible ctx arg.val
processCtorAppContext ctx

View File

@@ -244,8 +244,8 @@ def checkCodomainsLevel (preDefs : Array PreDefinition) : MetaM Unit := do
lambdaTelescope preDef.value fun xs _ => return xs.size
forallBoundedTelescope preDefs[0]!.type arities[0]! fun _ type₀ => do
let u₀ getLevel type₀
for i in [1:preDefs.size] do
forallBoundedTelescope preDefs[i]!.type arities[i]! fun _ typeᵢ =>
for h : i in [1:preDefs.size] do
forallBoundedTelescope preDefs[i].type arities[i]! fun _ typeᵢ =>
unless isLevelDefEq u₀ ( getLevel typeᵢ) do
withOptions (fun o => pp.sanitizeNames.set o false) do
throwError m!"invalid mutual definition, result types must be in the same universe " ++

View File

@@ -292,9 +292,9 @@ def mkBrecOnApp (positions : Positions) (fnIdx : Nat) (brecOnConst : Nat → Exp
let packedFTypes inferArgumentTypesN positions.size brecOn
let packedFArgs positions.mapMwith PProdN.mkLambdas packedFTypes FArgs
let brecOn := mkAppN brecOn packedFArgs
let some poss := positions.find? (·.contains fnIdx)
let some (size, idx) := positions.findSome? fun pos => (pos.size, ·) <$> pos.indexOf? fnIdx
| throwError "mkBrecOnApp: Could not find {fnIdx} in {positions}"
let brecOn PProdN.proj poss.size (poss.getIdx? fnIdx).get! brecOn
let brecOn PProdN.proj size idx brecOn
mkLambdaFVars ys (mkAppN brecOn otherArgs)
end Lean.Elab.Structural

View File

@@ -243,7 +243,7 @@ def tryAllArgs (fnNames : Array Name) (xs : Array Expr) (values : Array Expr)
recArgInfoss := recArgInfoss.push recArgInfos
-- Put non-indices first
recArgInfoss := recArgInfoss.map nonIndicesFirst
trace[Elab.definition.structural] "recArgInfoss: {recArgInfoss.map (·.map (·.recArgPos))}"
trace[Elab.definition.structural] "recArgInfos:{indentD (.joinSep (recArgInfoss.flatten.toList.map (repr ·)) Format.line)}"
-- Inductive groups to consider
let groups inductiveGroups recArgInfoss.flatten
trace[Elab.definition.structural] "inductive groups: {groups}"

View File

@@ -27,7 +27,7 @@ constituents.
structure IndGroupInfo where
all : Array Name
numNested : Nat
deriving BEq, Inhabited
deriving BEq, Inhabited, Repr
def IndGroupInfo.ofInductiveVal (indInfo : InductiveVal) : IndGroupInfo where
all := indInfo.all.toArray
@@ -56,7 +56,7 @@ mutual structural recursion on such incompatible types.
structure IndGroupInst extends IndGroupInfo where
levels : List Level
params : Array Expr
deriving Inhabited
deriving Inhabited, Repr
def IndGroupInst.toMessageData (igi : IndGroupInst) : MessageData :=
mkAppN (.const igi.all[0]! igi.levels) igi.params

View File

@@ -23,9 +23,9 @@ structure RecArgInfo where
fnName : Name
/-- the fixed prefix of arguments of the function we are trying to justify termination using structural recursion. -/
numFixed : Nat
/-- position of the argument (counted including fixed prefix) we are recursing on -/
/-- position (counted including fixed prefix) of the argument we are recursing on -/
recArgPos : Nat
/-- position of the indices (counted including fixed prefix) of the inductive datatype indices we are recursing on -/
/-- position (counted including fixed prefix) of the indices of the inductive datatype we are recursing on -/
indicesPos : Array Nat
/-- The inductive group (with parameters) of the argument's type -/
indGroupInst : IndGroupInst
@@ -34,20 +34,23 @@ structure RecArgInfo where
If `< indAll.all`, a normal data type, else an auxiliary data type due to nested recursion
-/
indIdx : Nat
deriving Inhabited
deriving Inhabited, Repr
/--
If `xs` are the parameters of the functions (excluding fixed prefix), partitions them
into indices and major arguments, and other parameters.
-/
def RecArgInfo.pickIndicesMajor (info : RecArgInfo) (xs : Array Expr) : (Array Expr × Array Expr) := Id.run do
-- First indices and major arg, using the order they appear in `info.indicesPos`
let mut indexMajorArgs := #[]
let indexMajorPos := info.indicesPos.push info.recArgPos
for j in indexMajorPos do
assert! info.numFixed j && j - info.numFixed < xs.size
indexMajorArgs := indexMajorArgs.push xs[j - info.numFixed]!
-- Then the other arguments, in the order they appear in `xs`
let mut otherArgs := #[]
for h : i in [:xs.size] do
let j := i + info.numFixed
if j = info.recArgPos || info.indicesPos.contains j then
indexMajorArgs := indexMajorArgs.push xs[i]
else
unless indexMajorPos.contains (i + info.numFixed) do
otherArgs := otherArgs.push xs[i]
return (indexMajorArgs, otherArgs)

View File

@@ -53,10 +53,10 @@ def TerminationArgument.elab (funName : Name) (type : Expr) (arity extraParams :
(hint : TerminationBy) : TermElabM TerminationArgument := withDeclName funName do
assert! extraParams arity
if hint.vars.size > extraParams then
if h : hint.vars.size > extraParams then
let mut msg := m!"{parameters hint.vars.size} bound in `termination_by`, but the body of " ++
m!"{funName} only binds {parameters extraParams}."
if let `($ident:ident) := hint.vars[0]! then
if let `($ident:ident) := hint.vars[0] then
if ident.getId.isSuffixOf funName then
msg := msg ++ m!" (Since Lean v4.6.0, the `termination_by` clause no longer " ++
"expects the function name here.)"

View File

@@ -90,10 +90,10 @@ lambda of `value`, and throws appropriate errors.
-/
def TerminationBy.checkVars (funName : Name) (extraParams : Nat) (tb : TerminationBy) : MetaM Unit := do
unless tb.synthetic do
if tb.vars.size > extraParams then
if h : tb.vars.size > extraParams then
let mut msg := m!"{parameters tb.vars.size} bound in `termination_by`, but the body of " ++
m!"{funName} only binds {parameters extraParams}."
if let `($ident:ident) := tb.vars[0]! then
if let `($ident:ident) := tb.vars[0] then
if ident.getId.isSuffixOf funName then
msg := msg ++ m!" (Since Lean v4.6.0, the `termination_by` clause no longer " ++
"expects the function name here.)"

View File

@@ -21,8 +21,8 @@ open Meta
private partial def addNonRecPreDefs (fixedPrefixSize : Nat) (argsPacker : ArgsPacker) (preDefs : Array PreDefinition) (preDefNonRec : PreDefinition) : TermElabM Unit := do
let us := preDefNonRec.levelParams.map mkLevelParam
let all := preDefs.toList.map (·.declName)
for fidx in [:preDefs.size] do
let preDef := preDefs[fidx]!
for h : fidx in [:preDefs.size] do
let preDef := preDefs[fidx]
let value forallBoundedTelescope preDef.type (some fixedPrefixSize) fun xs _ => do
let value := mkAppN (mkConst preDefNonRec.declName us) xs
let value argsPacker.curryProj value fidx

View File

@@ -40,7 +40,7 @@ private partial def post (fixedPrefix : Nat) (argsPacker : ArgsPacker) (funNames
return TransformStep.done e
let declName := f.constName!
let us := f.constLevels!
if let some fidx := funNames.getIdx? declName then
if let some fidx := funNames.indexOf? declName then
let arity := fixedPrefix + argsPacker.varNamess[fidx]!.size
let e' withAppN arity e fun args => do
let fixedArgs := args[:fixedPrefix]

View File

@@ -22,7 +22,7 @@ private def levelParamsToMessageData (levelParams : List Name) : MessageData :=
m := m ++ ", " ++ toMessageData u
return m ++ "}"
private def mkHeader (kind : String) (id : Name) (levelParams : List Name) (type : Expr) (safety : DefinitionSafety) : CommandElabM MessageData := do
private def mkHeader (kind : String) (id : Name) (levelParams : List Name) (type : Expr) (safety : DefinitionSafety) (sig : Bool := true) : CommandElabM MessageData := do
let m : MessageData :=
match ( getReducibilityStatus id) with
| ReducibilityStatus.irreducible => "@[irreducible] "
@@ -38,11 +38,13 @@ private def mkHeader (kind : String) (id : Name) (levelParams : List Name) (type
let (m, id) := match privateToUserName? id with
| some id => (m ++ "private ", id)
| none => (m, id)
let m := m ++ kind ++ " " ++ id ++ levelParamsToMessageData levelParams ++ " : " ++ type
pure m
if sig then
return m!"{m}{kind} {id}{levelParamsToMessageData levelParams} : {type}"
else
return m!"{m}{kind}"
private def mkHeader' (kind : String) (id : Name) (levelParams : List Name) (type : Expr) (isUnsafe : Bool) : CommandElabM MessageData :=
mkHeader kind id levelParams type (if isUnsafe then DefinitionSafety.unsafe else DefinitionSafety.safe)
private def mkHeader' (kind : String) (id : Name) (levelParams : List Name) (type : Expr) (isUnsafe : Bool) (sig : Bool := true) : CommandElabM MessageData :=
mkHeader kind id levelParams type (if isUnsafe then DefinitionSafety.unsafe else DefinitionSafety.safe) (sig := sig)
private def printDefLike (kind : String) (id : Name) (levelParams : List Name) (type : Expr) (value : Expr) (safety := DefinitionSafety.safe) : CommandElabM Unit := do
let m mkHeader kind id levelParams type safety
@@ -65,32 +67,63 @@ private def printInduct (id : Name) (levelParams : List Name) (numParams : Nat)
m := m ++ Format.line ++ ctor ++ " : " ++ cinfo.type
logInfo m
/--
Computes the origin of a field. Returns its projection function at the origin.
Multiple parents could be the origin of a field, but we say the first parent that provides it is the one that determines the origin.
-/
private partial def getFieldOrigin (structName field : Name) : MetaM Name := do
let env getEnv
for parent in getStructureParentInfo env structName do
if (findField? env parent.structName field).isSome then
return getFieldOrigin parent.structName field
let some fi := getFieldInfo? env structName field
| throwError "no such field {field} in {structName}"
return fi.projFn
open Meta in
private def printStructure (id : Name) (levelParams : List Name) (numParams : Nat) (type : Expr)
(ctor : Name) (fields : Array Name) (isUnsafe : Bool) (isClass : Bool) : CommandElabM Unit := do
let kind := if isClass then "class" else "structure"
let mut m mkHeader' kind id levelParams type isUnsafe
m := m ++ Format.line ++ "number of parameters: " ++ toString numParams
m := m ++ Format.line ++ "constructor:"
let cinfo getConstInfo ctor
m := m ++ Format.line ++ ctor ++ " : " ++ cinfo.type
m := m ++ Format.line ++ "fields:" ++ ( doFields)
logInfo m
where
doFields := liftTermElabM do
forallTelescope ( getConstInfo id).type fun params _ =>
withLocalDeclD `self (mkAppN (Expr.const id (levelParams.map .param)) params) fun self => do
let params := params.push self
let mut m : MessageData := ""
(isUnsafe : Bool) : CommandElabM Unit := do
let env getEnv
let kind := if isClass env id then "class" else "structure"
let header mkHeader' kind id levelParams type isUnsafe (sig := false)
liftTermElabM <| forallTelescope ( getConstInfo id).type fun params _ =>
let s := Expr.const id (levelParams.map .param)
withLocalDeclD `self (mkAppN s params) fun self => do
let mut m : MessageData := header
-- Signature
m := m ++ " " ++ .ofFormatWithInfosM do
let (stx, infos) PrettyPrinter.delabCore s (delab := PrettyPrinter.Delaborator.delabConstWithSignature)
pure PrettyPrinter.ppTerm stx, infos
m := m ++ Format.line ++ m!"number of parameters: {numParams}"
-- Parents
let parents := getStructureParentInfo env id
unless parents.isEmpty do
m := m ++ Format.line ++ "parents:"
for parent in parents do
let ptype inferType (mkApp (mkAppN (.const parent.projFn (levelParams.map .param)) params) self)
m := m ++ indentD m!"{.ofConstName parent.projFn (fullNames := true)} : {ptype}"
-- Fields
let fields := getStructureFieldsFlattened env id (includeSubobjectFields := false)
if fields.isEmpty then
m := m ++ Format.line ++ "fields: (none)"
else
m := m ++ Format.line ++ "fields:"
for field in fields do
match getProjFnForField? ( getEnv) id field with
| some proj =>
let field : Format := if isPrivateName proj then "private " ++ toString field else toString field
let cinfo getConstInfo proj
let ftype instantiateForall cinfo.type params
m := m ++ Format.line ++ field ++ " : " ++ ftype
| none => panic! "missing structure field info"
addMessageContext m
let some source := findField? env id field | panic! "missing structure field info"
let proj getFieldOrigin source field
let modifier := if isPrivateName proj then "private " else ""
let ftype inferType ( mkProjection self field)
m := m ++ indentD (m!"{modifier}{.ofConstName proj (fullNames := true)} : {ftype}")
-- Constructor
let cinfo := getStructureCtor ( getEnv) id
let ctorModifier := if isPrivateName cinfo.name then "private " else ""
m := m ++ Format.line ++ "constructor:" ++ indentD (ctorModifier ++ .signature cinfo.name)
-- Resolution order
let resOrder getStructureResolutionOrder id
if resOrder.size > 1 then
m := m ++ Format.line ++ "resolution order:"
++ indentD (MessageData.joinSep (resOrder.map (.ofConstName · (fullNames := true))).toList ", ")
logInfo m
private def printIdCore (id : Name) : CommandElabM Unit := do
let env getEnv
@@ -103,11 +136,10 @@ private def printIdCore (id : Name) : CommandElabM Unit := do
| ConstantInfo.ctorInfo { levelParams := us, type := t, isUnsafe := u, .. } => printAxiomLike "constructor" id us t u
| ConstantInfo.recInfo { levelParams := us, type := t, isUnsafe := u, .. } => printAxiomLike "recursor" id us t u
| ConstantInfo.inductInfo { levelParams := us, numParams, type := t, ctors, isUnsafe := u, .. } =>
match getStructureInfo? env id with
| some { fieldNames, .. } =>
let [ctor] := ctors | panic! "structures have only one constructor"
printStructure id us numParams t ctor fieldNames u (isClass env id)
| none => printInduct id us numParams t ctors u
if isStructure env id then
printStructure id us numParams t u
else
printInduct id us numParams t ctors u
| none => throwUnknownId id
private def printId (id : Syntax) : CommandElabM Unit := do

View File

@@ -58,7 +58,7 @@ partial def mkTuple : Array Syntax → TermElabM Syntax
| #[] => `(Unit.unit)
| #[e] => return e
| es => do
let stx mkTuple (es.eraseIdx 0)
let stx mkTuple (es.eraseIdxIfInBounds 0)
`(Prod.mk $(es[0]!) $stx)
def resolveSectionVariable (sectionVars : NameMap Name) (id : Name) : List (Name × List String) :=

View File

@@ -11,21 +11,40 @@ import Lean.Elab.App
import Lean.Elab.Binders
import Lean.PrettyPrinter
/-!
# Structure instance elaborator
A *structure instance* is notation to construct a term of a `structure`.
Examples: `{ x := 2, y.z := true }`, `{ s with cache := c' }`, and `{ s with values[2] := v }`.
Structure instances are the preferred way to invoke a `structure`'s constructor,
since they hide Lean implementation details such as whether parents are represented as subobjects,
and also they do correct processing of default values, which are complicated due to the fact that `structure`s can override default values of their parents.
This module elaborates structure instance notation.
Note that the `where` syntax to define structures (`Lean.Parser.Command.whereStructInst`)
macro expands into the structure instance notation elaborated by this module.
-/
namespace Lean.Elab.Term.StructInst
open Meta
open TSyntax.Compat
/-
Structure instances are of the form:
"{" >> optional (atomic (sepBy1 termParser ", " >> " with "))
>> manyIndent (group ((structInstFieldAbbrev <|> structInstField) >> optional ", "))
>> optEllipsis
>> optional (" : " >> termParser)
>> " }"
/-!
Recall that structure instances are of the form:
```
"{" >> optional (atomic (sepBy1 termParser ", " >> " with "))
>> manyIndent (group ((structInstFieldAbbrev <|> structInstField) >> optional ", "))
>> optEllipsis
>> optional (" : " >> termParser)
>> " }"
```
-/
/--
Transforms structure instances such as `{ x := 0 : Foo }` into `({ x := 0 } : Foo)`.
Structure instance notation makes use of the expected type.
-/
@[builtin_macro Lean.Parser.Term.structInst] def expandStructInstExpectedType : Macro := fun stx =>
let expectedArg := stx[4]
if expectedArg.isNone then
@@ -35,7 +54,10 @@ open TSyntax.Compat
let stxNew := stx.setArg 4 mkNullNode
`(($stxNew : $expected))
/-- Expand field abbreviations. Example: `{ x, y := 0 }` expands to `{ x := x, y := 0 }` -/
/--
Expands field abbreviation notation.
Example: `{ x, y := 0 }` expands to `{ x := x, y := 0 }`.
-/
@[builtin_macro Lean.Parser.Term.structInst] def expandStructInstFieldAbbrev : Macro
| `({ $[$srcs,* with]? $fields,* $[..%$ell]? $[: $ty]? }) =>
if fields.getElems.raw.any (·.getKind == ``Lean.Parser.Term.structInstFieldAbbrev) then do
@@ -49,9 +71,12 @@ open TSyntax.Compat
| _ => Macro.throwUnsupported
/--
If `stx` is of the form `{ s₁, ..., sₙ with ... }` and `sᵢ` is not a local variable, expand into `let src := sᵢ; { ..., src, ... with ... }`.
If `stx` is of the form `{ s₁, ..., sₙ with ... }` and `sᵢ` is not a local variable,
expands into `let __src := sᵢ; { ..., __src, ... with ... }`.
The significance of `__src` is that the variable is treated as an implementation-detail local variable,
which can be unfolded by `simp` when `zetaDelta := false`.
Note that this one is not a `Macro` because we need to access the local context.
Note that this one is not a `Macro` because we need to access the local context.
-/
private def expandNonAtomicExplicitSources (stx : Syntax) : TermElabM (Option Syntax) := do
let sourcesOpt := stx[1]
@@ -100,27 +125,44 @@ where
let r go sources (sourcesNew.push sourceNew)
`(let __src := $source; $r)
structure ExplicitSourceInfo where
/--
An *explicit source* is one of the structures `sᵢ` that appear in `{ s₁, …, sₙ with … }`.
-/
structure ExplicitSourceView where
/-- The syntax of the explicit source. -/
stx : Syntax
/-- The name of the structure for the type of the explicit source. -/
structName : Name
deriving Inhabited
structure Source where
explicit : Array ExplicitSourceInfo -- `s₁ ... sₙ with`
implicit : Option Syntax -- `..`
/--
A view of the sources of fields for the structure instance notation.
-/
structure SourcesView where
/-- Explicit sources (i.e., one of the structures `sᵢ` that appear in `{ s₁, …, sₙ with … }`). -/
explicit : Array ExplicitSourceView
/-- The syntax for a trailing `..`. This is "ellipsis mode" for missing fields, similar to ellipsis mode for applications. -/
implicit : Option Syntax
deriving Inhabited
def Source.isNone : Source Bool
/-- Returns `true` if the structure instance has no sources (neither explicit sources nor a `..`). -/
def SourcesView.isNone : SourcesView Bool
| { explicit := #[], implicit := none } => true
| _ => false
/-- `optional (atomic (sepBy1 termParser ", " >> " with ")` -/
/--
Given an array of explicit sources, returns syntax of the form
`optional (atomic (sepBy1 termParser ", " >> " with ")`
-/
private def mkSourcesWithSyntax (sources : Array Syntax) : Syntax :=
let ref := sources[0]!
let stx := Syntax.mkSep sources (mkAtomFrom ref ", ")
mkNullNode #[stx, mkAtomFrom ref "with "]
private def getStructSource (structStx : Syntax) : TermElabM Source :=
/--
Creates a structure source view from structure instance notation.
-/
private def getStructSources (structStx : Syntax) : TermElabM SourcesView :=
withRef structStx do
let explicitSource := structStx[1]
let implicitSource := structStx[3]
@@ -138,13 +180,13 @@ private def getStructSource (structStx : Syntax) : TermElabM Source :=
return { explicit, implicit }
/--
We say a `{ ... }` notation is a `modifyOp` if it contains only one
```
def structInstArrayRef := leading_parser "[" >> termParser >>"]"
```
We say a structure instance notation is a "modifyOp" if it contains only a single array update.
```lean
def structInstArrayRef := leading_parser "[" >> termParser >>"]"
```
-/
private def isModifyOp? (stx : Syntax) : TermElabM (Option Syntax) := do
let s? stx[2].getSepArgs.foldlM (init := none) fun s? arg => do
let s? stx[2][0].getSepArgs.foldlM (init := none) fun s? arg => do
/- arg is of the form `structInstFieldAbbrev <|> structInstField` -/
if arg.getKind == ``Lean.Parser.Term.structInstField then
/- Remark: the syntax for `structInstField` is
@@ -177,7 +219,11 @@ private def isModifyOp? (stx : Syntax) : TermElabM (Option Syntax) := do
| none => return none
| some s => if s[0][0].getKind == ``Lean.Parser.Term.structInstArrayRef then return s? else return none
private def elabModifyOp (stx modifyOp : Syntax) (sources : Array ExplicitSourceInfo) (expectedType? : Option Expr) : TermElabM Expr := do
/--
Given a `stx` that is a structure instance notation that's a modifyOp (according to `isModifyOp?`), elaborates it.
Only supports structure instances with a single source.
-/
private def elabModifyOp (stx modifyOp : Syntax) (sources : Array ExplicitSourceView) (expectedType? : Option Expr) : TermElabM Expr := do
if sources.size > 1 then
throwError "invalid \{...} notation, multiple sources and array update is not supported."
let cont (val : Syntax) : TermElabM Expr := do
@@ -199,17 +245,18 @@ private def elabModifyOp (stx modifyOp : Syntax) (sources : Array ExplicitSource
let valField := modifyOp.setArg 0 <| mkNode ``Parser.Term.structInstLVal #[valFirst, valRest]
let valSource := mkSourcesWithSyntax #[s]
let val := stx.setArg 1 valSource
let val := val.setArg 2 <| mkNullNode #[valField]
let val := val.setArg 2 <| mkNode ``Parser.Term.structInstFields #[mkNullNode #[valField]]
trace[Elab.struct.modifyOp] "{stx}\nval: {val}"
cont val
/--
Get structure name.
This method triest to postpone execution if the expected type is not available.
Gets the structure name for the structure instance from the expected type and the sources.
This method tries to postpone execution if the expected type is not available.
If the expected type is available and it is a structure, then we use it.
Otherwise, we use the type of the first source. -/
private def getStructName (expectedType? : Option Expr) (sourceView : Source) : TermElabM Name := do
If the expected type is available and it is a structure, then we use it.
Otherwise, we use the type of the first source.
-/
private def getStructName (expectedType? : Option Expr) (sourceView : SourcesView) : TermElabM Name := do
tryPostponeIfNoneOrMVar expectedType?
let useSource : Unit TermElabM Name := fun _ => do
unless sourceView.explicit.isEmpty do
@@ -226,7 +273,7 @@ private def getStructName (expectedType? : Option Expr) (sourceView : Source) :
unless isStructure ( getEnv) constName do
throwError "invalid \{...} notation, structure type expected{indentExpr expectedType}"
return constName
| _ => useSource ()
| _ => useSource ()
where
throwUnknownExpectedType :=
throwError "invalid \{...} notation, expected type is not known"
@@ -237,72 +284,91 @@ where
else
throwError "invalid \{...} notation, {kind} type is not of the form (C ...){indentExpr type}"
/--
A component of a left-hand side for a field appearing in structure instance syntax.
-/
inductive FieldLHS where
/-- A name component for a field left-hand side. For example, `x` and `y` in `{ x.y := v }`. -/
| fieldName (ref : Syntax) (name : Name)
/-- A numeric index component for a field left-hand side. For example `3` in `{ x.3 := v }`. -/
| fieldIndex (ref : Syntax) (idx : Nat)
/-- An array indexing component for a field left-hand side. For example `[3]` in `{ arr[3] := v }`. -/
| modifyOp (ref : Syntax) (index : Syntax)
deriving Inhabited
instance : ToFormat FieldLHS := fun lhs =>
match lhs with
| .fieldName _ n => format n
| .fieldIndex _ i => format i
| .modifyOp _ i => "[" ++ i.prettyPrint ++ "]"
instance : ToFormat FieldLHS where
format
| .fieldName _ n => format n
| .fieldIndex _ i => format i
| .modifyOp _ i => "[" ++ i.prettyPrint ++ "]"
inductive FieldVal (σ : Type) where
| term (stx : Syntax) : FieldVal σ
| nested (s : σ) : FieldVal σ
| default : FieldVal σ -- mark that field must be synthesized using default value
deriving Inhabited
mutual
/--
`FieldVal StructInstView` is a representation of a field value in the structure instance.
-/
inductive FieldVal where
/-- A `term` to use for the value of the field. -/
| term (stx : Syntax) : FieldVal
/-- A `StructInstView` to use for the value of a subobject field. -/
| nested (s : StructInstView) : FieldVal
/-- A field that was not provided and should be synthesized using default values. -/
| default : FieldVal
deriving Inhabited
structure Field (σ : Type) where
ref : Syntax
lhs : List FieldLHS
val : FieldVal σ
expr? : Option Expr := none
deriving Inhabited
/--
`Field StructInstView` is a representation of a field in the structure instance.
-/
structure Field where
/-- The whole field syntax. -/
ref : Syntax
/-- The LHS decomposed into components. -/
lhs : List FieldLHS
/-- The value of the field. -/
val : FieldVal
/-- The elaborated field value, filled in at `elabStruct`.
Missing fields use a metavariable for the elaborated value and are later solved for in `DefaultFields.propagate`. -/
expr? : Option Expr := none
deriving Inhabited
def Field.isSimple {σ} : Field σ Bool
/--
The view for structure instance notation.
-/
structure StructInstView where
/-- The syntax for the whole structure instance. -/
ref : Syntax
/-- The name of the structure for the type of the structure instance. -/
structName : Name
/-- Used for default values, to propagate structure type parameters. It is initially empty, and then set at `elabStruct`. -/
params : Array (Name × Expr)
/-- The fields of the structure instance. -/
fields : List Field
/-- The additional sources for fields for the structure instance. -/
sources : SourcesView
deriving Inhabited
end
/--
Returns if the field has a single component in its LHS.
-/
def Field.isSimple : Field Bool
| { lhs := [_], .. } => true
| _ => false
inductive Struct where
/-- Remark: the field `params` is use for default value propagation. It is initially empty, and then set at `elabStruct`. -/
| mk (ref : Syntax) (structName : Name) (params : Array (Name × Expr)) (fields : List (Field Struct)) (source : Source)
deriving Inhabited
abbrev Fields := List (Field Struct)
def Struct.ref : Struct Syntax
| ref, _, _, _, _ => ref
def Struct.structName : Struct Name
| _, structName, _, _, _ => structName
def Struct.params : Struct Array (Name × Expr)
| _, _, params, _, _ => params
def Struct.fields : Struct Fields
| _, _, _, fields, _ => fields
def Struct.source : Struct Source
| _, _, _, _, s => s
/-- `true` iff all fields of the given structure are marked as `default` -/
partial def Struct.allDefault (s : Struct) : Bool :=
partial def StructInstView.allDefault (s : StructInstView) : Bool :=
s.fields.all fun { val := val, .. } => match val with
| .term _ => false
| .default => true
| .nested s => allDefault s
def formatField (formatStruct : Struct Format) (field : Field Struct) : Format :=
def formatField (formatStruct : StructInstView Format) (field : Field) : Format :=
Format.joinSep field.lhs " . " ++ " := " ++
match field.val with
| .term v => v.prettyPrint
| .nested s => formatStruct s
| .default => "<default>"
partial def formatStruct : Struct Format
partial def formatStruct : StructInstView Format
| _, _, _, fields, source =>
let fieldsFmt := Format.joinSep (fields.map (formatField formatStruct)) ", "
let implicitFmt := if source.implicit.isSome then " .. " else ""
@@ -311,31 +377,39 @@ partial def formatStruct : Struct → Format
else
"{" ++ format (source.explicit.map (·.stx)) ++ " with " ++ fieldsFmt ++ implicitFmt ++ "}"
instance : ToFormat Struct := formatStruct
instance : ToString Struct := toString format
instance : ToFormat StructInstView := formatStruct
instance : ToString StructInstView := toString format
instance : ToFormat (Field Struct) := formatField formatStruct
instance : ToString (Field Struct) := toString format
instance : ToFormat Field := formatField formatStruct
instance : ToString Field := toString format
/--
Converts a `FieldLHS` back into syntax. This assumes the `ref` fields have the correct structure.
/-
Recall that `structInstField` elements have the form
```
def structInstField := leading_parser structInstLVal >> " := " >> termParser
def structInstLVal := leading_parser (ident <|> numLit <|> structInstArrayRef) >> many (("." >> (ident <|> numLit)) <|> structInstArrayRef)
def structInstArrayRef := leading_parser "[" >> termParser >>"]"
```lean
def structInstField := leading_parser structInstLVal >> " := " >> termParser
def structInstLVal := leading_parser (ident <|> numLit <|> structInstArrayRef) >> many (("." >> (ident <|> numLit)) <|> structInstArrayRef)
def structInstArrayRef := leading_parser "[" >> termParser >>"]"
```
-/
-- Remark: this code relies on the fact that `expandStruct` only transforms `fieldLHS.fieldName`
def FieldLHS.toSyntax (first : Bool) : FieldLHS Syntax
private def FieldLHS.toSyntax (first : Bool) : FieldLHS Syntax
| .modifyOp stx _ => stx
| .fieldName stx name => if first then mkIdentFrom stx name else mkGroupNode #[mkAtomFrom stx ".", mkIdentFrom stx name]
| .fieldIndex stx _ => if first then stx else mkGroupNode #[mkAtomFrom stx ".", stx]
def FieldVal.toSyntax : FieldVal Struct Syntax
/--
Converts a `FieldVal StructInstView` back into syntax. Only supports `.term`, and it assumes the `stx` field has the correct structure.
-/
private def FieldVal.toSyntax : FieldVal Syntax
| .term stx => stx
| _ => unreachable!
| _ => unreachable!
def Field.toSyntax : Field Struct Syntax
/--
Converts a `Field StructInstView` back into syntax. Used to construct synthetic structure instance notation for subobjects in `StructInst.expandStruct` processing.
-/
private def Field.toSyntax : Field Syntax
| field =>
let stx := field.ref
let stx := stx.setArg 2 field.val.toSyntax
@@ -343,6 +417,7 @@ def Field.toSyntax : Field Struct → Syntax
| first::rest => stx.setArg 0 <| mkNullNode #[first.toSyntax true, mkNullNode <| rest.toArray.map (FieldLHS.toSyntax false) ]
| _ => unreachable!
/-- Creates a view of a field left-hand side. -/
private def toFieldLHS (stx : Syntax) : MacroM FieldLHS :=
if stx.getKind == ``Lean.Parser.Term.structInstArrayRef then
return FieldLHS.modifyOp stx stx[1]
@@ -355,11 +430,16 @@ private def toFieldLHS (stx : Syntax) : MacroM FieldLHS :=
| some idx => return FieldLHS.fieldIndex stx idx
| none => Macro.throwError "unexpected structure syntax"
private def mkStructView (stx : Syntax) (structName : Name) (source : Source) : MacroM Struct := do
/--
Creates a structure instance view from structure instance notation
and the computed structure name (from `Lean.Elab.Term.StructInst.getStructName`)
and structure source view (from `Lean.Elab.Term.StructInst.getStructSources`).
-/
private def mkStructView (stx : Syntax) (structName : Name) (sources : SourcesView) : MacroM StructInstView := do
/- Recall that `stx` is of the form
```
leading_parser "{" >> optional (atomic (sepBy1 termParser ", " >> " with "))
>> sepByIndent (structInstFieldAbbrev <|> structInstField) ...
>> structInstFields (sepByIndent (structInstFieldAbbrev <|> structInstField) ...)
>> optional ".."
>> optional (" : " >> termParser)
>> " }"
@@ -367,28 +447,22 @@ private def mkStructView (stx : Syntax) (structName : Name) (source : Source) :
This method assumes that `structInstFieldAbbrev` had already been expanded.
-/
let fields stx[2].getSepArgs.toList.mapM fun fieldStx => do
let fields stx[2][0].getSepArgs.toList.mapM fun fieldStx => do
let val := fieldStx[2]
let first toFieldLHS fieldStx[0][0]
let rest fieldStx[0][1].getArgs.toList.mapM toFieldLHS
return { ref := fieldStx, lhs := first :: rest, val := FieldVal.term val : Field Struct }
return stx, structName, #[], fields, source
return { ref := fieldStx, lhs := first :: rest, val := FieldVal.term val : Field }
return { ref := stx, structName, params := #[], fields, sources }
def Struct.modifyFieldsM {m : Type Type} [Monad m] (s : Struct) (f : Fields m Fields) : m Struct :=
def StructInstView.modifyFieldsM {m : Type Type} [Monad m] (s : StructInstView) (f : List Field m (List Field)) : m StructInstView :=
match s with
| ref, structName, params, fields, source => return ref, structName, params, ( f fields), source
| { ref, structName, params, fields, sources } => return { ref, structName, params, fields := ( f fields), sources }
def Struct.modifyFields (s : Struct) (f : Fields Fields) : Struct :=
def StructInstView.modifyFields (s : StructInstView) (f : List Field List Field) : StructInstView :=
Id.run <| s.modifyFieldsM f
def Struct.setFields (s : Struct) (fields : Fields) : Struct :=
s.modifyFields fun _ => fields
def Struct.setParams (s : Struct) (ps : Array (Name × Expr)) : Struct :=
match s with
| ref, structName, _, fields, source => ref, structName, ps, fields, source
private def expandCompositeFields (s : Struct) : Struct :=
/-- Expands name field LHSs with multi-component names into multi-component LHSs. -/
private def expandCompositeFields (s : StructInstView) : StructInstView :=
s.modifyFields fun fields => fields.map fun field => match field with
| { lhs := .fieldName _ (.str Name.anonymous ..) :: _, .. } => field
| { lhs := .fieldName ref n@(.str ..) :: rest, .. } =>
@@ -396,7 +470,8 @@ private def expandCompositeFields (s : Struct) : Struct :=
{ field with lhs := newEntries ++ rest }
| _ => field
private def expandNumLitFields (s : Struct) : TermElabM Struct :=
/-- Replaces numeric index field LHSs with the corresponding named field, or throws an error if no such field exists. -/
private def expandNumLitFields (s : StructInstView) : TermElabM StructInstView :=
s.modifyFieldsM fun fields => do
let env getEnv
let fieldNames := getStructureFields env s.structName
@@ -407,28 +482,31 @@ private def expandNumLitFields (s : Struct) : TermElabM Struct :=
else return { field with lhs := .fieldName ref fieldNames[idx - 1]! :: rest }
| _ => return field
/-- For example, consider the following structures:
```
structure A where
x : Nat
/--
Expands fields that are actually represented as fields of subobject fields.
structure B extends A where
y : Nat
For example, consider the following structures:
```
structure A where
x : Nat
structure C extends B where
z : Bool
```
This method expands parent structure fields using the path to the parent structure.
For example,
```
{ x := 0, y := 0, z := true : C }
```
is expanded into
```
{ toB.toA.x := 0, toB.y := 0, z := true : C }
```
structure B extends A where
y : Nat
structure C extends B where
z : Bool
```
This method expands parent structure fields using the path to the parent structure.
For example,
```
{ x := 0, y := 0, z := true : C }
```
is expanded into
```
{ toB.toA.x := 0, toB.y := 0, z := true : C }
```
-/
private def expandParentFields (s : Struct) : TermElabM Struct := do
private def expandParentFields (s : StructInstView) : TermElabM StructInstView := do
let env getEnv
s.modifyFieldsM fun fields => fields.mapM fun field => do match field with
| { lhs := .fieldName ref fieldName :: _, .. } =>
@@ -446,9 +524,14 @@ private def expandParentFields (s : Struct) : TermElabM Struct := do
| _ => throwErrorAt ref "failed to access field '{fieldName}' in parent structure"
| _ => return field
private abbrev FieldMap := Std.HashMap Name Fields
private abbrev FieldMap := Std.HashMap Name (List Field)
private def mkFieldMap (fields : Fields) : TermElabM FieldMap :=
/--
Creates a hash map collecting all fields with the same first name component.
Throws an error if there are multiple simple fields with the same name.
Used by `StructInst.expandStruct` processing.
-/
private def mkFieldMap (fields : List Field) : TermElabM FieldMap :=
fields.foldlM (init := {}) fun fieldMap field =>
match field.lhs with
| .fieldName _ fieldName :: _ =>
@@ -461,15 +544,16 @@ private def mkFieldMap (fields : Fields) : TermElabM FieldMap :=
| _ => return fieldMap.insert fieldName [field]
| _ => unreachable!
private def isSimpleField? : Fields Option (Field Struct)
/--
Given a value of the hash map created by `mkFieldMap`, returns true if the value corresponds to a simple field.
-/
private def isSimpleField? : List Field Option Field
| [field] => if field.isSimple then some field else none
| _ => none
private def getFieldIdx (structName : Name) (fieldNames : Array Name) (fieldName : Name) : TermElabM Nat := do
match fieldNames.findIdx? fun n => n == fieldName with
| some idx => return idx
| none => throwError "field '{fieldName}' is not a valid field of '{structName}'"
/--
Creates projection notation for the given structure field. Used
-/
def mkProjStx? (s : Syntax) (structName : Name) (fieldName : Name) : TermElabM (Option Syntax) := do
if (findField? ( getEnv) structName fieldName).isNone then
return none
@@ -478,7 +562,10 @@ def mkProjStx? (s : Syntax) (structName : Name) (fieldName : Name) : TermElabM (
#[mkAtomFrom s "@",
mkNode ``Parser.Term.proj #[s, mkAtomFrom s ".", mkIdentFrom s fieldName]]
def findField? (fields : Fields) (fieldName : Name) : Option (Field Struct) :=
/--
Finds a simple field of the given name.
-/
def findField? (fields : List Field) (fieldName : Name) : Option Field :=
fields.find? fun field =>
match field.lhs with
| [.fieldName _ n] => n == fieldName
@@ -486,7 +573,10 @@ def findField? (fields : Fields) (fieldName : Name) : Option (Field Struct) :=
mutual
private partial def groupFields (s : Struct) : TermElabM Struct := do
/--
Groups compound fields according to which subobject they are from.
-/
private partial def groupFields (s : StructInstView) : TermElabM StructInstView := do
let env getEnv
withRef s.ref do
s.modifyFieldsM fun fields => do
@@ -499,26 +589,28 @@ mutual
let field := fields.head!
match Lean.isSubobjectField? env s.structName fieldName with
| some substructName =>
let substruct := Struct.mk s.ref substructName #[] substructFields s.source
let substruct := { ref := s.ref, structName := substructName, params := #[], fields := substructFields, sources := s.sources }
let substruct expandStruct substruct
pure { field with lhs := [field.lhs.head!], val := FieldVal.nested substruct }
| none =>
let updateSource (structStx : Syntax) : TermElabM Syntax := do
let sourcesNew s.source.explicit.filterMapM fun source => mkProjStx? source.stx source.structName fieldName
let sourcesNew s.sources.explicit.filterMapM fun source => mkProjStx? source.stx source.structName fieldName
let explicitSourceStx := if sourcesNew.isEmpty then mkNullNode else mkSourcesWithSyntax sourcesNew
let implicitSourceStx := s.source.implicit.getD mkNullNode
let implicitSourceStx := s.sources.implicit.getD mkNullNode
return (structStx.setArg 1 explicitSourceStx).setArg 3 implicitSourceStx
let valStx := s.ref -- construct substructure syntax using s.ref as template
let valStx := valStx.setArg 4 mkNullNode -- erase optional expected type
let args := substructFields.toArray.map (·.toSyntax)
let valStx := valStx.setArg 2 (mkNullNode <| mkSepArray args (mkAtom ","))
let fieldsStx := mkNode ``Parser.Term.structInstFields
#[mkNullNode <| mkSepArray args (mkAtom ",")]
let valStx := valStx.setArg 2 fieldsStx
let valStx updateSource valStx
return { field with lhs := [field.lhs.head!], val := FieldVal.term valStx }
/--
Adds in the missing fields using the explicit sources.
Invariant: a missing field always comes from the first source that can provide it.
-/
private partial def addMissingFields (s : Struct) : TermElabM Struct := do
private partial def addMissingFields (s : StructInstView) : TermElabM StructInstView := do
let env getEnv
let fieldNames := getStructureFields env s.structName
let ref := s.ref.mkSynthetic
@@ -527,7 +619,7 @@ mutual
match findField? s.fields fieldName with
| some field => return field::fields
| none =>
let addField (val : FieldVal Struct) : TermElabM Fields := do
let addField (val : FieldVal) : TermElabM (List Field) := do
return { ref, lhs := [FieldLHS.fieldName ref fieldName], val := val } :: fields
match Lean.isSubobjectField? env s.structName fieldName with
| some substructName =>
@@ -535,8 +627,8 @@ mutual
let downFields := getStructureFieldsFlattened env substructName false
-- Filter out all explicit sources that do not share a leaf field keeping
-- structure with no fields
let filtered := s.source.explicit.filter fun source =>
let sourceFields := getStructureFieldsFlattened env source.structName false
let filtered := s.sources.explicit.filter fun sources =>
let sourceFields := getStructureFieldsFlattened env sources.structName false
sourceFields.any (fun name => downFields.contains name) || sourceFields.isEmpty
-- Take the first such one remaining
match filtered[0]? with
@@ -550,27 +642,30 @@ mutual
-- No sources could provide this subobject in the proper order.
-- Recurse to handle default values for fields.
else
let substruct := Struct.mk ref substructName #[] [] s.source
let substruct := { ref, structName := substructName, params := #[], fields := [], sources := s.sources }
let substruct expandStruct substruct
addField (FieldVal.nested substruct)
-- No sources could provide this subobject.
-- Recurse to handle default values for fields.
| none =>
let substruct := Struct.mk ref substructName #[] [] s.source
let substruct := { ref, structName := substructName, params := #[], fields := [], sources := s.sources }
let substruct expandStruct substruct
addField (FieldVal.nested substruct)
-- Since this is not a subobject field, we are free to use the first source that can
-- provide it.
| none =>
if let some val s.source.explicit.findSomeM? fun source => mkProjStx? source.stx source.structName fieldName then
if let some val s.sources.explicit.findSomeM? fun source => mkProjStx? source.stx source.structName fieldName then
addField (FieldVal.term val)
else if s.source.implicit.isSome then
else if s.sources.implicit.isSome then
addField (FieldVal.term (mkHole ref))
else
addField FieldVal.default
return s.setFields fields.reverse
return { s with fields := fields.reverse }
private partial def expandStruct (s : Struct) : TermElabM Struct := do
/--
Expands all fields of the structure instance, consolidates compound fields into subobject fields, and adds missing fields.
-/
private partial def expandStruct (s : StructInstView) : TermElabM StructInstView := do
let s := expandCompositeFields s
let s expandNumLitFields s
let s expandParentFields s
@@ -579,10 +674,17 @@ mutual
end
/--
The constructor to use for the structure instance notation.
-/
structure CtorHeaderResult where
/-- The constructor function with applied structure parameters. -/
ctorFn : Expr
/-- The type of `ctorFn` -/
ctorFnType : Expr
/-- Instance metavariables for structure parameters that are instance implicit. -/
instMVars : Array MVarId
/-- Type parameter names and metavariables for each parameter. Used to seed `StructInstView.params`. -/
params : Array (Name × Expr)
private def mkCtorHeaderAux : Nat Expr Expr Array MVarId Array (Name × Expr) TermElabM CtorHeaderResult
@@ -604,6 +706,7 @@ private partial def getForallBody : Nat → Expr → Option Expr
| _+1, _ => none
| 0, type => type
/-- Attempts to use the expected type to solve for structure parameters. -/
private def propagateExpectedType (type : Expr) (numFields : Nat) (expectedType? : Option Expr) : TermElabM Unit := do
match expectedType? with
| none => return ()
@@ -614,6 +717,7 @@ private def propagateExpectedType (type : Expr) (numFields : Nat) (expectedType?
unless typeBody.hasLooseBVars do
discard <| isDefEq expectedType typeBody
/-- Elaborates the structure constructor using the expected type, filling in all structure parameters. -/
private def mkCtorHeader (ctorVal : ConstructorVal) (expectedType? : Option Expr) : TermElabM CtorHeaderResult := do
let us mkFreshLevelMVars ctorVal.levelParams.length
let val := Lean.mkConst ctorVal.name us
@@ -623,32 +727,43 @@ private def mkCtorHeader (ctorVal : ConstructorVal) (expectedType? : Option Expr
synthesizeAppInstMVars r.instMVars r.ctorFn
return r
/-- Annotates an expression that it is a value for a missing field. -/
def markDefaultMissing (e : Expr) : Expr :=
mkAnnotation `structInstDefault e
/-- If the expression has been annotated by `markDefaultMissing`, returns the unannotated expression. -/
def defaultMissing? (e : Expr) : Option Expr :=
annotation? `structInstDefault e
/-- Throws "failed to elaborate field" error. -/
def throwFailedToElabField {α} (fieldName : Name) (structName : Name) (msgData : MessageData) : TermElabM α :=
throwError "failed to elaborate field '{fieldName}' of '{structName}, {msgData}"
def trySynthStructInstance? (s : Struct) (expectedType : Expr) : TermElabM (Option Expr) := do
/-- If the struct has all-missing fields, tries to synthesize the structure using typeclass inference. -/
def trySynthStructInstance? (s : StructInstView) (expectedType : Expr) : TermElabM (Option Expr) := do
if !s.allDefault then
return none
else
try synthInstance? expectedType catch _ => return none
/-- The result of elaborating a `StructInstView` structure instance view. -/
structure ElabStructResult where
/-- The elaborated value. -/
val : Expr
struct : Struct
/-- The modified `StructInstView` view after elaboration. -/
struct : StructInstView
/-- Metavariables for instance implicit fields. These will be registered after default value propagation. -/
instMVars : Array MVarId
private partial def elabStruct (s : Struct) (expectedType? : Option Expr) : TermElabM ElabStructResult := withRef s.ref do
/--
Main elaborator for structure instances.
-/
private partial def elabStructInstView (s : StructInstView) (expectedType? : Option Expr) : TermElabM ElabStructResult := withRef s.ref do
let env getEnv
let ctorVal := getStructureCtor env s.structName
if isPrivateNameFromImportedModule env ctorVal.name then
throwError "invalid \{...} notation, constructor for `{s.structName}` is marked as private"
-- We store the parameters at the resulting `Struct`. We use this information during default value propagation.
-- We store the parameters at the resulting `StructInstView`. We use this information during default value propagation.
let { ctorFn, ctorFnType, params, .. } mkCtorHeader ctorVal expectedType?
let (e, _, fields, instMVars) s.fields.foldlM (init := (ctorFn, ctorFnType, [], #[])) fun (e, type, fields, instMVars) field => do
match field.lhs with
@@ -657,7 +772,7 @@ private partial def elabStruct (s : Struct) (expectedType? : Option Expr) : Term
trace[Elab.struct] "elabStruct {field}, {type}"
match type with
| .forallE _ d b bi =>
let cont (val : Expr) (field : Field Struct) (instMVars := instMVars) : TermElabM (Expr × Expr × Fields × Array MVarId) := do
let cont (val : Expr) (field : Field) (instMVars := instMVars) : TermElabM (Expr × Expr × List Field × Array MVarId) := do
pushInfoTree <| InfoTree.node (children := {}) <| Info.ofFieldInfo {
projName := s.structName.append fieldName, fieldName, lctx := ( getLCtx), val, stx := ref }
let e := mkApp e val
@@ -671,7 +786,7 @@ private partial def elabStruct (s : Struct) (expectedType? : Option Expr) : Term
match ( trySynthStructInstance? s d) with
| some val => cont val { field with val := FieldVal.term (mkHole field.ref) }
| none =>
let { val, struct := sNew, instMVars := instMVarsNew } elabStruct s (some d)
let { val, struct := sNew, instMVars := instMVarsNew } elabStructInstView s (some d)
let val ensureHasType d val
cont val { field with val := FieldVal.nested sNew } (instMVars ++ instMVarsNew)
| .default =>
@@ -700,17 +815,21 @@ private partial def elabStruct (s : Struct) (expectedType? : Option Expr) : Term
cont (markDefaultMissing val) field
| _ => withRef field.ref <| throwFailedToElabField fieldName s.structName m!"unexpected constructor type{indentExpr type}"
| _ => throwErrorAt field.ref "unexpected unexpanded structure field"
return { val := e, struct := s.setFields fields.reverse |>.setParams params, instMVars }
return { val := e, struct := { s with fields := fields.reverse, params }, instMVars }
namespace DefaultFields
/--
Context for default value propagation.
-/
structure Context where
-- We must search for default values overridden in derived structures
structs : Array Struct := #[]
/-- The current path through `.nested` subobject structures. We must search for default values overridden in derived structures. -/
structs : Array StructInstView := #[]
/-- The collection of structures that could provide a default value. -/
allStructNames : Array Name := #[]
/--
Consider the following example:
```
```lean
structure A where
x : Nat := 1
@@ -736,63 +855,81 @@ structure Context where
-/
maxDistance : Nat := 0
/--
State for default value propagation
-/
structure State where
/-- Whether progress has been made so far on this round of the propagation loop. -/
progress : Bool := false
partial def collectStructNames (struct : Struct) (names : Array Name) : Array Name :=
/-- Collects all structures that may provide default values for fields. -/
partial def collectStructNames (struct : StructInstView) (names : Array Name) : Array Name :=
let names := names.push struct.structName
struct.fields.foldl (init := names) fun names field =>
match field.val with
| .nested struct => collectStructNames struct names
| _ => names
partial def getHierarchyDepth (struct : Struct) : Nat :=
/-- Gets the maximum nesting depth of subobjects. -/
partial def getHierarchyDepth (struct : StructInstView) : Nat :=
struct.fields.foldl (init := 0) fun max field =>
match field.val with
| .nested struct => Nat.max max (getHierarchyDepth struct + 1)
| _ => max
def isDefaultMissing? [Monad m] [MonadMCtx m] (field : Field Struct) : m Bool := do
/-- Returns whether the field is still missing. -/
def isDefaultMissing? [Monad m] [MonadMCtx m] (field : Field) : m Bool := do
if let some expr := field.expr? then
if let some (.mvar mvarId) := defaultMissing? expr then
unless ( mvarId.isAssigned) do
return true
return false
partial def findDefaultMissing? [Monad m] [MonadMCtx m] (struct : Struct) : m (Option (Field Struct)) :=
/-- Returns a field that is still missing. -/
partial def findDefaultMissing? [Monad m] [MonadMCtx m] (struct : StructInstView) : m (Option Field) :=
struct.fields.findSomeM? fun field => do
match field.val with
| .nested struct => findDefaultMissing? struct
| _ => return if ( isDefaultMissing? field) then field else none
partial def allDefaultMissing [Monad m] [MonadMCtx m] (struct : Struct) : m (Array (Field Struct)) :=
/-- Returns all fields that are still missing. -/
partial def allDefaultMissing [Monad m] [MonadMCtx m] (struct : StructInstView) : m (Array Field) :=
go struct *> get |>.run' #[]
where
go (struct : Struct) : StateT (Array (Field Struct)) m Unit :=
go (struct : StructInstView) : StateT (Array Field) m Unit :=
for field in struct.fields do
if let .nested struct := field.val then
go struct
else if ( isDefaultMissing? field) then
modify (·.push field)
def getFieldName (field : Field Struct) : Name :=
/-- Returns the name of the field. Assumes all fields under consideration are simple and named. -/
def getFieldName (field : Field) : Name :=
match field.lhs with
| [.fieldName _ fieldName] => fieldName
| _ => unreachable!
abbrev M := ReaderT Context (StateRefT State TermElabM)
/-- Returns whether we should interrupt the round because we have made progress allowing nonzero depth. -/
def isRoundDone : M Bool := do
return ( get).progress && ( read).maxDistance > 0
def getFieldValue? (struct : Struct) (fieldName : Name) : Option Expr :=
/-- Returns the `expr?` for the given field. -/
def getFieldValue? (struct : StructInstView) (fieldName : Name) : Option Expr :=
struct.fields.findSome? fun field =>
if getFieldName field == fieldName then
field.expr?
else
none
partial def mkDefaultValueAux? (struct : Struct) : Expr TermElabM (Option Expr)
/-- Instantiates a default value from the given default value declaration, if applicable. -/
partial def mkDefaultValue? (struct : StructInstView) (cinfo : ConstantInfo) : TermElabM (Option Expr) :=
withRef struct.ref do
let us mkFreshLevelMVarsFor cinfo
process ( instantiateValueLevelParams cinfo us)
where
process : Expr TermElabM (Option Expr)
| .lam n d b c => withRef struct.ref do
if c.isExplicit then
let fieldName := n
@@ -801,29 +938,26 @@ partial def mkDefaultValueAux? (struct : Struct) : Expr → TermElabM (Option Ex
| some val =>
let valType inferType val
if ( isDefEq valType d) then
mkDefaultValueAux? struct (b.instantiate1 val)
process (b.instantiate1 val)
else
return none
else
if let some (_, param) := struct.params.find? fun (paramName, _) => paramName == n then
-- Recall that we did not use to have support for parameter propagation here.
if ( isDefEq ( inferType param) d) then
mkDefaultValueAux? struct (b.instantiate1 param)
process (b.instantiate1 param)
else
return none
else
let arg mkFreshExprMVar d
mkDefaultValueAux? struct (b.instantiate1 arg)
process (b.instantiate1 arg)
| e =>
let_expr id _ a := e | return some e
return some a
def mkDefaultValue? (struct : Struct) (cinfo : ConstantInfo) : TermElabM (Option Expr) :=
withRef struct.ref do
let us mkFreshLevelMVarsFor cinfo
mkDefaultValueAux? struct ( instantiateValueLevelParams cinfo us)
/-- Reduce default value. It performs beta reduction and projections of the given structures. -/
/--
Reduces a default value. It performs beta reduction and projections of the given structures to reduce them to the provided values for fields.
-/
partial def reduce (structNames : Array Name) (e : Expr) : MetaM Expr := do
match e with
| .forallE .. =>
@@ -880,7 +1014,10 @@ where
else
k
partial def tryToSynthesizeDefault (structs : Array Struct) (allStructNames : Array Name) (maxDistance : Nat) (fieldName : Name) (mvarId : MVarId) : TermElabM Bool :=
/--
Attempts to synthesize a default value for a missing field `fieldName` using default values from each structure in `structs`.
-/
def tryToSynthesizeDefault (structs : Array StructInstView) (allStructNames : Array Name) (maxDistance : Nat) (fieldName : Name) (mvarId : MVarId) : TermElabM Bool :=
let rec loop (i : Nat) (dist : Nat) := do
if dist > maxDistance then
return false
@@ -900,14 +1037,25 @@ partial def tryToSynthesizeDefault (structs : Array Struct) (allStructNames : Ar
| none =>
let mvarDecl getMVarDecl mvarId
let val ensureHasType mvarDecl.type val
mvarId.assign val
return true
/-
We must use `checkedAssign` here to ensure we do not create a cyclic
assignment. See #3150.
This can happen when there are holes in the the fields the default value
depends on.
Possible improvement: create a new `_` instead of returning `false` when
`checkedAssign` fails. Reason: the field will not be needed after the
other `_` are resolved by the user.
-/
mvarId.checkedAssign val
| _ => loop (i+1) dist
else
return false
loop 0 0
partial def step (struct : Struct) : M Unit :=
/--
Performs one step of default value synthesis.
-/
partial def step (struct : StructInstView) : M Unit :=
unless ( isRoundDone) do
withReader (fun ctx => { ctx with structs := ctx.structs.push struct }) do
for field in struct.fields do
@@ -924,7 +1072,10 @@ partial def step (struct : Struct) : M Unit :=
modify fun _ => { progress := true }
| _ => pure ()
partial def propagateLoop (hierarchyDepth : Nat) (d : Nat) (struct : Struct) : M Unit := do
/--
Main entry point to default value synthesis in the `M` monad.
-/
partial def propagateLoop (hierarchyDepth : Nat) (d : Nat) (struct : StructInstView) : M Unit := do
match ( findDefaultMissing? struct) with
| none => return () -- Done
| some field =>
@@ -947,16 +1098,22 @@ partial def propagateLoop (hierarchyDepth : Nat) (d : Nat) (struct : Struct) : M
else
propagateLoop hierarchyDepth (d+1) struct
def propagate (struct : Struct) : TermElabM Unit :=
/--
Synthesizes default values for all missing fields, if possible.
-/
def propagate (struct : StructInstView) : TermElabM Unit :=
let hierarchyDepth := getHierarchyDepth struct
let structNames := collectStructNames struct #[]
propagateLoop hierarchyDepth 0 struct { allStructNames := structNames } |>.run' {}
end DefaultFields
private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (source : Source) : TermElabM Expr := do
let structName getStructName expectedType? source
let struct liftMacroM <| mkStructView stx structName source
/--
Main entry point to elaborator for structure instance notation, unless the structure instance is a modifyOp.
-/
private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (sources : SourcesView) : TermElabM Expr := do
let structName getStructName expectedType? sources
let struct liftMacroM <| mkStructView stx structName sources
let struct expandStruct struct
trace[Elab.struct] "{struct}"
/- We try to synthesize pending problems with `withSynthesize` combinator before trying to use default values.
@@ -974,7 +1131,7 @@ private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (sour
TODO: investigate whether this design decision may have unintended side effects or produce confusing behavior.
-/
let { val := r, struct, instMVars } withSynthesize (postpone := .yes) <| elabStruct struct expectedType?
let { val := r, struct, instMVars } withSynthesize (postpone := .yes) <| elabStructInstView struct expectedType?
trace[Elab.struct] "before propagate {r}"
DefaultFields.propagate struct
synthesizeAppInstMVars instMVars r
@@ -984,13 +1141,13 @@ private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (sour
match ( expandNonAtomicExplicitSources stx) with
| some stxNew => withMacroExpansion stx stxNew <| elabTerm stxNew expectedType?
| none =>
let sourceView getStructSource stx
let sourcesView getStructSources stx
if let some modifyOp isModifyOp? stx then
if sourceView.explicit.isEmpty then
if sourcesView.explicit.isEmpty then
throwError "invalid \{...} notation, explicit source is required when using '[<index>] := <value>'"
elabModifyOp stx modifyOp sourceView.explicit expectedType?
elabModifyOp stx modifyOp sourcesView.explicit expectedType?
else
elabStructInstAux stx expectedType? sourceView
elabStructInstAux stx expectedType? sourcesView
builtin_initialize
registerTraceClass `Elab.struct

View File

@@ -4,22 +4,15 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Class
import Lean.Parser.Command
import Lean.Meta.Closure
import Lean.Meta.SizeOf
import Lean.Meta.Injective
import Lean.Meta.Structure
import Lean.Meta.AppBuilder
import Lean.Elab.Command
import Lean.Elab.DeclModifiers
import Lean.Elab.DeclUtil
import Lean.Elab.Inductive
import Lean.Elab.DeclarationRange
import Lean.Elab.Binders
import Lean.Elab.MutualInductive
namespace Lean.Elab.Command
builtin_initialize
registerTraceClass `Elab.structure
registerTraceClass `Elab.structure.resolutionOrder
register_builtin_option structureDiamondWarning : Bool := {
defValue := false
descr := "if true, enable warnings when a structure has diamond inheritance"
@@ -39,13 +32,6 @@ leading_parser (structureTk <|> classTk) >> declId >> many Term.bracketedBinder
```
-/
structure StructCtorView where
ref : Syntax
modifiers : Modifiers
name : Name
declName : Name
deriving Inhabited
structure StructFieldView where
ref : Syntax
modifiers : Modifiers
@@ -61,22 +47,15 @@ structure StructFieldView where
type? : Option Syntax
value? : Option Syntax
structure StructView where
ref : Syntax
declId : Syntax
modifiers : Modifiers
isClass : Bool -- struct-only
shortDeclName : Name
declName : Name
levelNames : List Name
binders : Syntax
type : Syntax -- modified (inductive has type?)
parents : Array Syntax -- struct-only
ctor : StructCtorView -- struct-only
fields : Array StructFieldView -- struct-only
derivingClasses : Array DerivingClassView
structure StructView extends InductiveView where
parents : Array Syntax
fields : Array StructFieldView
deriving Inhabited
def StructView.ctor : StructView CtorView
| { ctors := #[ctor], ..} => ctor
| _ => unreachable!
structure StructParentInfo where
ref : Syntax
fvar? : Option Expr
@@ -102,18 +81,6 @@ structure StructFieldInfo where
value? : Option Expr := none
deriving Inhabited, Repr
structure ElabStructHeaderResult where
view : StructView
lctx : LocalContext
localInsts : LocalInstances
levelNames : List Name
params : Array Expr
type : Expr
parents : Array StructParentInfo
/-- Field infos from parents. -/
parentFieldInfos : Array StructFieldInfo
deriving Inhabited
def StructFieldInfo.isFromParent (info : StructFieldInfo) : Bool :=
match info.kind with
| StructFieldKind.fromParent => true
@@ -130,12 +97,12 @@ The structure constructor syntax is
leading_parser try (declModifiers >> ident >> " :: ")
```
-/
private def expandCtor (structStx : Syntax) (structModifiers : Modifiers) (structDeclName : Name) : TermElabM StructCtorView := do
private def expandCtor (structStx : Syntax) (structModifiers : Modifiers) (structDeclName : Name) : TermElabM CtorView := do
let useDefault := do
let declName := structDeclName ++ defaultCtorName
let ref := structStx[1].mkSynthetic
addDeclarationRangesFromSyntax declName ref
pure { ref, modifiers := default, name := defaultCtorName, declName }
pure { ref, declId := ref, modifiers := default, declName }
if structStx[5].isNone then
useDefault
else
@@ -156,7 +123,7 @@ private def expandCtor (structStx : Syntax) (structModifiers : Modifiers) (struc
let declName applyVisibility ctorModifiers.visibility declName
addDocString' declName ctorModifiers.docString?
addDeclarationRangesFromSyntax declName ctor[1]
pure { ref := ctor[1], name, modifiers := ctorModifiers, declName }
pure { ref := ctor[1], declId := ctor[1], modifiers := ctorModifiers, declName }
def checkValidFieldModifier (modifiers : Modifiers) : TermElabM Unit := do
if modifiers.isNoncomputable then
@@ -271,7 +238,7 @@ def structureSyntaxToView (modifiers : Modifiers) (stx : Syntax) : TermElabM Str
let parents := if exts.isNone then #[] else exts[0][1].getSepArgs
let optType := stx[4]
let derivingClasses getOptDerivingClasses stx[6]
let type if optType.isNone then `(Sort _) else pure optType[0][1]
let type? := if optType.isNone then none else some optType[0][1]
let ctor expandCtor stx modifiers declName
let fields expandFields stx modifiers declName
fields.forM fun field => do
@@ -287,10 +254,13 @@ def structureSyntaxToView (modifiers : Modifiers) (stx : Syntax) : TermElabM Str
declName
levelNames
binders
type
type?
allowIndices := false
allowSortPolymorphism := false
ctors := #[ctor]
parents
ctor
fields
computedFields := #[]
derivingClasses
}
@@ -536,7 +506,7 @@ private partial def mkToParentName (parentStructName : Name) (p : Name → Bool)
if p curr then curr else go (i+1)
go 1
private partial def elabParents (view : StructView)
private partial def withParents (view : StructView) (rs : Array ElabHeaderResult) (indFVar : Expr)
(k : Array StructFieldInfo Array StructParentInfo TermElabM α) : TermElabM α := do
go 0 #[] #[]
where
@@ -544,11 +514,17 @@ where
if h : i < view.parents.size then
let parent := view.parents[i]
withRef parent do
let type Term.elabType parent
-- The only use case for autobound implicits for parents might be outParams, but outParam is not propagated.
let type Term.withoutAutoBoundImplicit <| Term.elabType parent
let parentType whnf type
if parentType.getAppFn == indFVar then
logWarning "structure extends itself, skipping"
return go (i + 1) infos parents
if rs.any (fun r => r.indFVar == parentType.getAppFn) then
throwError "structure cannot extend types defined in the same mutual block"
let parentStructName getStructureName parentType
if parents.any (fun info => info.structName == parentStructName) then
logWarningAt parent m!"duplicate parent structure '{.ofConstName parentStructName}', skipping"
logWarning m!"duplicate parent structure '{.ofConstName parentStructName}', skipping"
go (i + 1) infos parents
else if let some existingFieldName findExistingField? infos parentStructName then
if structureDiamondWarning.get ( getOptions) then
@@ -570,6 +546,13 @@ where
else
k infos parents
private def registerFailedToInferFieldType (fieldName : Name) (e : Expr) (ref : Syntax) : TermElabM Unit := do
Term.registerCustomErrorIfMVar ( instantiateMVars e) ref m!"failed to infer type of field '{.ofConstName fieldName}'"
private def registerFailedToInferDefaultValue (fieldName : Name) (e : Expr) (ref : Syntax) : TermElabM Unit := do
Term.registerCustomErrorIfMVar ( instantiateMVars e) ref m!"failed to infer default value for field '{.ofConstName fieldName}'"
Term.registerLevelMVarErrorExprInfo e ref m!"failed to infer universe levels in default value for field '{.ofConstName fieldName}'"
private def elabFieldTypeValue (view : StructFieldView) : TermElabM (Option Expr × Option Expr) :=
Term.withAutoBoundImplicit <| Term.withAutoBoundImplicitForbiddenPred (fun n => view.name == n) <| Term.elabBinders view.binders.getArgs fun params => do
match view.type? with
@@ -581,10 +564,13 @@ private def elabFieldTypeValue (view : StructFieldView) : TermElabM (Option Expr
-- TODO: add forbidden predicate using `shortDeclName` from `view`
let params Term.addAutoBoundImplicits params
let value Term.withoutAutoBoundImplicit <| Term.elabTerm valStx none
registerFailedToInferFieldType view.name ( inferType value) view.nameId
registerFailedToInferDefaultValue view.name value valStx
let value mkLambdaFVars params value
return (none, value)
| some typeStx =>
let type Term.elabType typeStx
registerFailedToInferFieldType view.name type typeStx
Term.synthesizeSyntheticMVarsNoPostponing
let params Term.addAutoBoundImplicits params
match view.value? with
@@ -593,6 +579,7 @@ private def elabFieldTypeValue (view : StructFieldView) : TermElabM (Option Expr
return (type, none)
| some valStx =>
let value Term.withoutAutoBoundImplicit <| Term.elabTermEnsuringType valStx type
registerFailedToInferDefaultValue view.name value valStx
Term.synthesizeSyntheticMVarsNoPostponing
let type mkForallFVars params type
let value mkLambdaFVars params value
@@ -639,6 +626,7 @@ where
valStx `(fun $(view.binders.getArgs)* => $valStx:term)
let fvarType inferType info.fvar
let value Term.elabTermEnsuringType valStx fvarType
registerFailedToInferDefaultValue view.name value valStx
pushInfoLeaf <| .ofFieldRedeclInfo { stx := view.ref }
let infos := replaceFieldInfo infos { info with ref := view.nameId, value? := value }
go (i+1) defaultValsOverridden infos
@@ -650,113 +638,14 @@ where
else
k infos
private def getResultUniverse (type : Expr) : TermElabM Level := do
let type whnf type
match type with
| Expr.sort u => pure u
| _ => throwError "unexpected structure resulting type"
private def collectUsed (params : Array Expr) (fieldInfos : Array StructFieldInfo) : StateRefT CollectFVars.State MetaM Unit := do
params.forM fun p => do
let type inferType p
type.collectFVars
fieldInfos.forM fun info => do
let fvarType inferType info.fvar
fvarType.collectFVars
match info.value? with
| none => pure ()
| some value => value.collectFVars
private def removeUnused (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo)
: TermElabM (LocalContext × LocalInstances × Array Expr) := do
let (_, used) (collectUsed params fieldInfos).run {}
Meta.removeUnused scopeVars used
private def withUsed {α} (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo) (k : Array Expr TermElabM α)
: TermElabM α := do
let (lctx, localInsts, vars) removeUnused scopeVars params fieldInfos
withLCtx lctx localInsts <| k vars
private def levelMVarToParam (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo) (univToInfer? : Option LMVarId) : TermElabM (Array StructFieldInfo) := do
levelMVarToParamFVars scopeVars
levelMVarToParamFVars params
fieldInfos.mapM fun info => do
levelMVarToParamFVar info.fvar
match info.value? with
| none => pure info
| some value =>
let value levelMVarToParam' value
pure { info with value? := value }
where
levelMVarToParam' (type : Expr) : TermElabM Expr := do
Term.levelMVarToParam type (except := fun mvarId => univToInfer? == some mvarId)
levelMVarToParamFVars (fvars : Array Expr) : TermElabM Unit :=
fvars.forM levelMVarToParamFVar
levelMVarToParamFVar (fvar : Expr) : TermElabM Unit := do
let type inferType fvar
discard <| levelMVarToParam' type
private partial def collectUniversesFromFields (r : Level) (rOffset : Nat) (fieldInfos : Array StructFieldInfo) : TermElabM (Array Level) := do
let (_, us) go |>.run #[]
return us
where
go : StateRefT (Array Level) TermElabM Unit :=
for info in fieldInfos do
let type inferType info.fvar
let u getLevel type
let u instantiateLevelMVars u
match ( modifyGet fun s => accLevel u r rOffset |>.run |>.run s) with
| some _ => pure ()
| none =>
let typeType inferType type
let mut msg := m!"failed to compute resulting universe level of structure, field '{info.declName}' has type{indentD m!"{type} : {typeType}"}\nstructure resulting type{indentExpr (mkSort (r.addOffset rOffset))}"
if r.isMVar then
msg := msg ++ "\nrecall that Lean only infers the resulting universe level automatically when there is a unique solution for the universe level constraints, consider explicitly providing the structure resulting universe level"
throwError msg
/--
Decides whether the structure should be `Prop`-valued when the universe is not given
and when the universe inference algorithm `collectUniversesFromFields` determines
that the inductive type could naturally be `Prop`-valued.
See `Lean.Elab.Command.isPropCandidate` for an explanation.
Specialized to structures, the heuristic is that we prefer a `Prop` instead of a `Type` structure
when it could be a syntactic subsingleton.
Exception: no-field structures are `Type` since they are likely stubbed-out declarations.
-/
private def isPropCandidate (fieldInfos : Array StructFieldInfo) : Bool :=
!fieldInfos.isEmpty
private def updateResultingUniverse (fieldInfos : Array StructFieldInfo) (type : Expr) : TermElabM Expr := do
let r getResultUniverse type
let rOffset : Nat := r.getOffset
let r : Level := r.getLevelOffset
unless r.isMVar do
throwError "failed to compute resulting universe level of inductive datatype, provide universe explicitly: {r}"
let us collectUniversesFromFields r rOffset fieldInfos
trace[Elab.structure] "updateResultingUniverse us: {us}, r: {r}, rOffset: {rOffset}"
let rNew := mkResultUniverse us rOffset (isPropCandidate fieldInfos)
assignLevelMVar r.mvarId! rNew
instantiateMVars type
private def collectLevelParamsInFVar (s : CollectLevelParams.State) (fvar : Expr) : TermElabM CollectLevelParams.State := do
let type inferType fvar
let type instantiateMVars type
return collectLevelParams s type
private def collectLevelParamsInFVars (fvars : Array Expr) (s : CollectLevelParams.State) : TermElabM CollectLevelParams.State :=
fvars.foldlM collectLevelParamsInFVar s
private def collectLevelParamsInStructure (structType : Expr) (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo)
: TermElabM (Array Name) := do
let s := collectLevelParams {} structType
let s collectLevelParamsInFVars scopeVars s
let s collectLevelParamsInFVars params s
let s fieldInfos.foldlM (init := s) fun s info => collectLevelParamsInFVar s info.fvar
return s.params
private def collectUsedFVars (lctx : LocalContext) (localInsts : LocalInstances) (fieldInfos : Array StructFieldInfo) :
StateRefT CollectFVars.State MetaM Unit := do
withLCtx lctx localInsts do
fieldInfos.forM fun info => do
let fvarType inferType info.fvar
fvarType.collectFVars
if let some value := info.value? then
value.collectFVars
private def addCtorFields (fieldInfos : Array StructFieldInfo) : Nat Expr TermElabM Expr
| 0, type => pure type
@@ -772,19 +661,29 @@ private def addCtorFields (fieldInfos : Array StructFieldInfo) : Nat → Expr
| _ =>
addCtorFields fieldInfos i (mkForall decl.userName decl.binderInfo decl.type type)
private def mkCtor (view : StructView) (levelParams : List Name) (params : Array Expr) (fieldInfos : Array StructFieldInfo) : TermElabM Constructor :=
private def mkCtor (view : StructView) (r : ElabHeaderResult) (params : Array Expr) (fieldInfos : Array StructFieldInfo) : TermElabM Constructor :=
withRef view.ref do
let type := mkAppN (mkConst view.declName (levelParams.map mkLevelParam)) params
let type := mkAppN r.indFVar params
let type addCtorFields fieldInfos fieldInfos.size type
let type mkForallFVars params type
let type instantiateMVars type
let type := type.inferImplicit params.size true
pure { name := view.ctor.declName, type }
private partial def checkResultingUniversesForFields (fieldInfos : Array StructFieldInfo) (u : Level) : TermElabM Unit := do
for info in fieldInfos do
let type inferType info.fvar
let v := ( instantiateLevelMVars ( getLevel type)).normalize
unless u.geq v do
let msg := m!"invalid universe level for field '{info.name}', has type{indentExpr type}\n\
at universe level{indentD v}\n\
which is not less than or equal to the structure's resulting universe level{indentD u}"
throwErrorAt info.ref msg
@[extern "lean_mk_projections"]
private opaque mkProjections (env : Environment) (structName : Name) (projs : List Name) (isClass : Bool) : Except KernelException Environment
private def addProjections (r : ElabStructHeaderResult) (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
private def addProjections (r : ElabHeaderResult) (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
if r.type.isProp then
if let some fieldInfo fieldInfos.findM? (not <$> Meta.isProof ·.fvar) then
throwErrorAt fieldInfo.ref m!"failed to generate projections for 'Prop' structure, field '{format fieldInfo.name}' is not a proof"
@@ -795,49 +694,71 @@ private def addProjections (r : ElabStructHeaderResult) (fieldInfos : Array Stru
private def registerStructure (structName : Name) (infos : Array StructFieldInfo) : TermElabM Unit := do
let fields infos.filterMapM fun info => do
if info.kind == StructFieldKind.fromParent then
return none
else
return some {
fieldName := info.name
projFn := info.declName
binderInfo := ( getFVarLocalDecl info.fvar).binderInfo
autoParam? := ( inferType info.fvar).getAutoParamTactic?
subobject? := if let .subobject parentName := info.kind then parentName else none
}
if info.kind == StructFieldKind.fromParent then
return none
else
return some {
fieldName := info.name
projFn := info.declName
binderInfo := ( getFVarLocalDecl info.fvar).binderInfo
autoParam? := ( inferType info.fvar).getAutoParamTactic?
subobject? := if let .subobject parentName := info.kind then parentName else none
}
modifyEnv fun env => Lean.registerStructure env { structName, fields }
private def mkAuxConstructions (declName : Name) : TermElabM Unit := do
let env getEnv
let hasEq := env.contains ``Eq
let hasHEq := env.contains ``HEq
let hasUnit := env.contains ``PUnit
let hasProd := env.contains ``Prod
mkRecOn declName
if hasUnit then mkCasesOn declName
if hasUnit && hasEq && hasHEq then mkNoConfusion declName
let ival getConstInfoInduct declName
if ival.isRec then
if hasUnit && hasProd then mkBelow declName
if hasUnit && hasProd then mkIBelow declName
if hasUnit && hasProd then mkBRecOn declName
if hasUnit && hasProd then mkBInductionOn declName
private def checkDefaults (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
let mut mvars := {}
let mut lmvars := {}
for fieldInfo in fieldInfos do
if let some value := fieldInfo.value? then
let value instantiateMVars value
mvars := Expr.collectMVars mvars value
lmvars := collectLevelMVars lmvars value
-- Log errors and ignore the failure; we later will just omit adding a default value.
if Term.logUnassignedUsingErrorInfos mvars.result then
return
else if Term.logUnassignedLevelMVarsUsingErrorInfos lmvars.result then
return
private def addDefaults (lctx : LocalContext) (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
private def addDefaults (params : Array Expr) (replaceIndFVars : Expr MetaM Expr) (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
let lctx getLCtx
/- The `lctx` and `defaultAuxDecls` are used to create the auxiliary "default value" declarations
The parameters `params` for these definitions must be marked as implicit, and all others as explicit. -/
let lctx :=
params.foldl (init := lctx) fun (lctx : LocalContext) (p : Expr) =>
if p.isFVar then
lctx.setBinderInfo p.fvarId! BinderInfo.implicit
else
lctx
let lctx :=
fieldInfos.foldl (init := lctx) fun (lctx : LocalContext) (info : StructFieldInfo) =>
if info.isFromParent then lctx -- `fromParent` fields are elaborated as let-decls, and are zeta-expanded when creating "default value" auxiliary functions
else lctx.setBinderInfo info.fvar.fvarId! BinderInfo.default
-- Make all indFVar replacements in the local context.
let lctx
lctx.foldlM (init := {}) fun lctx ldecl => do
match ldecl with
| .cdecl _ fvarId userName type bi k =>
let type replaceIndFVars type
return lctx.mkLocalDecl fvarId userName type bi k
| .ldecl _ fvarId userName type value nonDep k =>
let type replaceIndFVars type
let value replaceIndFVars value
return lctx.mkLetDecl fvarId userName type value nonDep k
withLCtx lctx ( getLocalInstances) do
fieldInfos.forM fun fieldInfo => do
if let some value := fieldInfo.value? then
let declName := mkDefaultFnOfProjFn fieldInfo.declName
let type inferType fieldInfo.fvar
let value instantiateMVars value
if value.hasExprMVar then
discard <| Term.logUnassignedUsingErrorInfos ( getMVars value)
throwErrorAt fieldInfo.ref "invalid default value for field '{format fieldInfo.name}', it contains metavariables{indentExpr value}"
/- The identity function is used as "marker". -/
let value mkId value
-- No need to compile the definition, since it is only used during elaboration.
discard <| mkAuxDefinition declName type value (zetaDelta := true) (compile := false)
setReducibleAttribute declName
let type replaceIndFVars ( inferType fieldInfo.fvar)
let value instantiateMVars ( replaceIndFVars value)
trace[Elab.structure] "default value after 'replaceIndFVars': {indentExpr value}"
-- If there are mvars, `checkDefaults` already logged an error.
unless value.hasMVar || value.hasSyntheticSorry do
/- The identity function is used as "marker". -/
let value mkId value
-- No need to compile the definition, since it is only used during elaboration.
discard <| mkAuxDefinition declName type value (zetaDelta := true) (compile := false)
setReducibleAttribute declName
/--
Given `type` of the form `forall ... (source : A), B`, return `forall ... [source : A], B`.
@@ -906,57 +827,6 @@ private partial def mkCoercionToCopiedParent (levelParams : List Name) (params :
setReducibleAttribute declName
return { structName := parentStructName, subobject := false, projFn := declName }
private def elabStructHeader (view : StructView) : TermElabM ElabStructHeaderResult :=
Term.withAutoBoundImplicitForbiddenPred (fun n => view.shortDeclName == n) do
Term.withAutoBoundImplicit do
Term.elabBinders view.binders.getArgs fun params => do
elabParents view fun parentFieldInfos parents => do
let type Term.elabType view.type
Term.synthesizeSyntheticMVarsNoPostponing
let u mkFreshLevelMVar
unless isDefEq type (mkSort u) do
throwErrorAt view.type "invalid structure type, expecting 'Type _' or 'Prop'"
let type instantiateMVars ( whnf type)
Term.addAutoBoundImplicits' params type fun params type => do
let levelNames Term.getLevelNames
trace[Elab.structure] "header params: {params}, type: {type}, levelNames: {levelNames}"
return { lctx := ( getLCtx), localInsts := ( getLocalInstances), levelNames, params, type, view, parents, parentFieldInfos }
private def mkTypeFor (r : ElabStructHeaderResult) : TermElabM Expr := do
withLCtx r.lctx r.localInsts do
mkForallFVars r.params r.type
/--
Create a local declaration for the structure and execute `x params indFVar`, where `params` are the structure's type parameters and
`indFVar` is the new local declaration.
-/
private partial def withStructureLocalDecl (r : ElabStructHeaderResult) (x : Array Expr Expr TermElabM α) : TermElabM α := do
let declName := r.view.declName
let shortDeclName := r.view.shortDeclName
let type mkTypeFor r
let params := r.params
withLCtx r.lctx r.localInsts <| withRef r.view.ref do
Term.withAuxDecl shortDeclName type declName fun indFVar =>
x params indFVar
/--
Remark: `numVars <= numParams`.
`numVars` is the number of context `variables` used in the declaration,
and `numParams - numVars` is the number of parameters provided as binders in the declaration.
-/
private def mkInductiveType (view : StructView) (indFVar : Expr) (levelNames : List Name)
(numVars : Nat) (numParams : Nat) (type : Expr) (ctor : Constructor) : TermElabM InductiveType := do
let levelParams := levelNames.map mkLevelParam
let const := mkConst view.declName levelParams
let ctorType forallBoundedTelescope ctor.type numParams fun params type => do
let type := type.replace fun e =>
if e == indFVar then
mkAppN const (params.extract 0 numVars)
else
none
instantiateMVars ( mkForallFVars params type)
return { name := view.declName, type := instantiateMVars type, ctors := [{ ctor with type := instantiateMVars ctorType }] }
/--
Precomputes the structure's resolution order.
Option `structure.strictResolutionOrder` controls whether to create a warning if the C3 algorithm failed.
@@ -987,109 +857,50 @@ private def addParentInstances (parents : Array StructureParentInfo) : MetaM Uni
for instParent in instParents do
addInstance instParent.projFn AttributeKind.global (eval_prio default)
def mkStructureDecl (vars : Array Expr) (view : StructView) : TermElabM Unit := Term.withoutSavingRecAppSyntax do
let scopeLevelNames Term.getLevelNames
let isUnsafe := view.modifiers.isUnsafe
withRef view.ref <| Term.withLevelNames view.levelNames do
let r elabStructHeader view
Term.synthesizeSyntheticMVarsNoPostponing
withLCtx r.lctx r.localInsts do
withStructureLocalDecl r fun params indFVar => do
trace[Elab.structure] "indFVar: {indFVar}"
Term.addLocalVarInfo view.declId indFVar
withFields view.fields r.parentFieldInfos fun fieldInfos =>
withRef view.ref do
Term.synthesizeSyntheticMVarsNoPostponing
let type instantiateMVars r.type
let u getResultUniverse type
let univToInfer? shouldInferResultUniverse u
withUsed vars params fieldInfos fun scopeVars => do
let fieldInfos levelMVarToParam scopeVars params fieldInfos univToInfer?
let type withRef view.ref do
if univToInfer?.isSome then
updateResultingUniverse fieldInfos type
else
checkResultingUniverse ( getResultUniverse type)
pure type
trace[Elab.structure] "type: {type}"
let usedLevelNames collectLevelParamsInStructure type scopeVars params fieldInfos
match sortDeclLevelParams scopeLevelNames r.levelNames usedLevelNames with
| Except.error msg => throwErrorAt view.declId msg
| Except.ok levelParams =>
let params := scopeVars ++ params
let ctor mkCtor view levelParams params fieldInfos
let type mkForallFVars params type
let type instantiateMVars type
let indType mkInductiveType view indFVar levelParams scopeVars.size params.size type ctor
let decl := Declaration.inductDecl levelParams params.size [indType] isUnsafe
Term.ensureNoUnassignedMVars decl
addDecl decl
-- rename indFVar so that it does not shadow the actual declaration:
let lctx := ( getLCtx).modifyLocalDecl indFVar.fvarId! fun decl => decl.setUserName .anonymous
withLCtx lctx ( getLocalInstances) do
addProjections r fieldInfos
registerStructure view.declName fieldInfos
mkAuxConstructions view.declName
withSaveInfoContext do -- save new env
Term.addLocalVarInfo view.ref[1] ( mkConstWithLevelParams view.declName)
if let some _ := view.ctor.ref.getPos? (canonicalOnly := true) then
Term.addTermInfo' view.ctor.ref ( mkConstWithLevelParams view.ctor.declName) (isBinder := true)
for field in view.fields do
-- may not exist if overriding inherited field
if ( getEnv).contains field.declName then
Term.addTermInfo' field.ref ( mkConstWithLevelParams field.declName) (isBinder := true)
withRef view.declId do
Term.applyAttributesAt view.declName view.modifiers.attrs AttributeApplicationTime.afterTypeChecking
let parentInfos r.parents.mapM fun parent => do
if parent.subobject then
let some info := fieldInfos.find? (·.kind == .subobject parent.structName) | unreachable!
pure { structName := parent.structName, subobject := true, projFn := info.declName }
else
mkCoercionToCopiedParent levelParams params view parent.structName parent.type
setStructureParents view.declName parentInfos
checkResolutionOrder view.declName
if view.isClass then
addParentInstances parentInfos
let lctx getLCtx
/- The `lctx` and `defaultAuxDecls` are used to create the auxiliary "default value" declarations
The parameters `params` for these definitions must be marked as implicit, and all others as explicit. -/
let lctx :=
params.foldl (init := lctx) fun (lctx : LocalContext) (p : Expr) =>
if p.isFVar then
lctx.setBinderInfo p.fvarId! BinderInfo.implicit
else
lctx
let lctx :=
fieldInfos.foldl (init := lctx) fun (lctx : LocalContext) (info : StructFieldInfo) =>
if info.isFromParent then lctx -- `fromParent` fields are elaborated as let-decls, and are zeta-expanded when creating "default value" auxiliary functions
else lctx.setBinderInfo info.fvar.fvarId! BinderInfo.default
addDefaults lctx fieldInfos
def elabStructureView (vars : Array Expr) (view : StructView) : TermElabM Unit := do
Term.withDeclName view.declName <| withRef view.ref do
mkStructureDecl vars view
unless view.isClass do
Lean.Meta.IndPredBelow.mkBelow view.declName
mkSizeOfInstances view.declName
mkInjectiveTheorems view.declName
def elabStructureViewPostprocessing (view : StructView) : CommandElabM Unit := do
view.derivingClasses.forM fun classView => classView.applyHandlers #[view.declName]
runTermElabM fun _ => Term.withDeclName view.declName <| withRef view.declId do
Term.applyAttributesAt view.declName view.modifiers.attrs .afterCompilation
def elabStructure (modifiers : Modifiers) (stx : Syntax) : CommandElabM Unit := do
let view runTermElabM fun vars => do
@[builtin_inductive_elab Lean.Parser.Command.«structure»]
def elabStructureCommand : InductiveElabDescr where
mkInductiveView (modifiers : Modifiers) (stx : Syntax) := do
let view structureSyntaxToView modifiers stx
trace[Elab.structure] "view.levelNames: {view.levelNames}"
elabStructureView vars view
pure view
elabStructureViewPostprocessing view
return {
view := view.toInductiveView
elabCtors := fun rs r params => do
withParents view rs r.indFVar fun parentFieldInfos parents =>
withFields view.fields parentFieldInfos fun fieldInfos => do
withRef view.ref do
Term.synthesizeSyntheticMVarsNoPostponing
let lctx getLCtx
let localInsts getLocalInstances
let ctor mkCtor view r params fieldInfos
return {
ctors := [ctor]
collectUsedFVars := collectUsedFVars lctx localInsts fieldInfos
checkUniverses := fun _ u => withLCtx lctx localInsts do checkResultingUniversesForFields fieldInfos u
finalizeTermElab := withLCtx lctx localInsts do checkDefaults fieldInfos
prefinalize := fun _ _ _ => do
withLCtx lctx localInsts do
addProjections r fieldInfos
registerStructure view.declName fieldInfos
withSaveInfoContext do -- save new env
for field in view.fields do
-- may not exist if overriding inherited field
if ( getEnv).contains field.declName then
Term.addTermInfo' field.ref ( mkConstWithLevelParams field.declName) (isBinder := true)
finalize := fun levelParams params replaceIndFVars => do
let parentInfos parents.mapM fun parent => do
if parent.subobject then
let some info := fieldInfos.find? (·.kind == .subobject parent.structName) | unreachable!
pure { structName := parent.structName, subobject := true, projFn := info.declName }
else
mkCoercionToCopiedParent levelParams params view parent.structName parent.type
setStructureParents view.declName parentInfos
checkResolutionOrder view.declName
if view.isClass then
addParentInstances parentInfos
builtin_initialize
registerTraceClass `Elab.structure
registerTraceClass `Elab.structure.resolutionOrder
withLCtx lctx localInsts do
addDefaults params replaceIndFVars fieldInfos
}
}
end Lean.Elab.Command

View File

@@ -19,12 +19,12 @@ def expandOptPrecedence (stx : Syntax) : MacroM (Option Nat) :=
return some ( evalPrec stx[0][1])
private def mkParserSeq (ds : Array (Term × Nat)) : TermElabM (Term × Nat) := do
if ds.size == 0 then
if h₀ : ds.size = 0 then
throwUnsupportedSyntax
else if ds.size == 1 then
pure ds[0]!
else if h₁ : ds.size = 1 then
pure ds[0]
else
let mut (r, stackSum) := ds[0]!
let mut (r, stackSum) := ds[0]
for (d, stackSz) in ds[1:ds.size] do
r `(ParserDescr.binary `andthen $r $d)
stackSum := stackSum + stackSz
@@ -142,7 +142,7 @@ where
let args := stx.getArgs
if ( checkLeftRec stx[0]) then
if args.size == 1 then throwErrorAt stx "invalid atomic left recursive syntax"
let args := args.eraseIdx 0
let args := args.eraseIdxIfInBounds 0
let args args.mapM fun arg => withNestedParser do process arg
mkParserSeq args
else
@@ -236,8 +236,9 @@ where
-- Pretty-printing instructions shouldn't affect validity
let s := s.trim
!s.isEmpty &&
(s.front != '\'' || s == "''") &&
(s.front != '\'' || "''".isPrefixOf s) &&
s.front != '\"' &&
!(isIdBeginEscape s.front) &&
!(s.front == '`' && (s.endPos == ⟨1⟩ || isIdFirst (s.get ⟨1⟩) || isIdBeginEscape (s.get ⟨1⟩))) &&
!s.front.isDigit &&
!(s.any Char.isWhitespace)

View File

@@ -149,8 +149,8 @@ where
-- Succeeded. Collect new TC problems
trace[Elab.defaultInstance] "isDefEq worked {mkMVar mvarId} : {← inferType (mkMVar mvarId)} =?= {candidate} : {← inferType candidate}"
let mut pending := []
for i in [:bis.size] do
if bis[i]! == BinderInfo.instImplicit then
for h : i in [:bis.size] do
if bis[i] == BinderInfo.instImplicit then
pending := mvars[i]!.mvarId! :: pending
synthesizePending pending
else

View File

@@ -58,15 +58,28 @@ where
let some eqBVPred ReifiedBVPred.mkBinPred atom resValExpr atomExpr resExpr .eq | return none
let eqBV ReifiedBVLogical.ofPred eqBVPred
let trueExpr := mkConst ``Bool.true
let impExpr mkArrow ( mkEq eqDiscrExpr trueExpr) ( mkEq eqBVExpr trueExpr)
let decideImpExpr mkAppOptM ``Decidable.decide #[some impExpr, none]
let imp ReifiedBVLogical.mkGate eqDiscr eqBV eqDiscrExpr eqBVExpr .imp
let proof := do
let evalExpr ReifiedBVLogical.mkEvalExpr imp.expr
let congrProof imp.evalsAtAtoms
let lemmaProof := mkApp4 (mkConst lemmaName) (toExpr lhs.width) discrExpr lhsExpr rhsExpr
let trueExpr := mkConst ``Bool.true
let eqDiscrTrueExpr mkEq eqDiscrExpr trueExpr
let eqBVExprTrueExpr mkEq eqBVExpr trueExpr
let impExpr mkArrow eqDiscrTrueExpr eqBVExprTrueExpr
-- construct a `Decidable` instance for the implication using forall_prop_decidable
let decEqDiscrTrue := mkApp2 (mkConst ``instDecidableEqBool) eqDiscrExpr trueExpr
let decEqBVExprTrue := mkApp2 (mkConst ``instDecidableEqBool) eqBVExpr trueExpr
let impDecidable := mkApp4 (mkConst ``forall_prop_decidable)
eqDiscrTrueExpr
(.lam .anonymous eqDiscrTrueExpr eqBVExprTrueExpr .default)
decEqDiscrTrue
(.lam .anonymous eqDiscrTrueExpr decEqBVExprTrue .default)
let decideImpExpr := mkApp2 (mkConst ``Decidable.decide) impExpr impDecidable
return mkApp4
(mkConst ``Std.Tactic.BVDecide.Reflect.Bool.lemma_congr)
decideImpExpr

View File

@@ -218,7 +218,7 @@ where
let old snap.old?
-- If the kind is equal, we can assume the old version was a macro as well
guard <| old.stx.isOfKind stx.getKind
let state old.val.get.data.finished.get.state?
let state old.val.get.finished.get.state?
guard <| state.term.meta.core.nextMacroScope == nextMacroScope
-- check absence of traces; see Note [Incremental Macros]
guard <| state.term.meta.core.traceState.traces.size == 0
@@ -226,9 +226,10 @@ where
return old.val.get
Language.withAlwaysResolvedPromise fun promise => do
-- Store new unfolding in the snapshot tree
snap.new.resolve <| .mk {
snap.new.resolve {
stx := stx'
diagnostics := .empty
inner? := none
finished := .pure {
diagnostics := .empty
state? := ( Tactic.saveState)
@@ -240,7 +241,7 @@ where
new := promise
old? := do
let old old?
return old.data.stx, ( old.data.next.get? 0)
return old.stx, ( old.next.get? 0)
} }) do
evalTactic stx'
return

View File

@@ -60,7 +60,7 @@ where
if let some snap := ( readThe Term.Context).tacSnap? then
if let some old := snap.old? then
let oldParsed := old.val.get
oldInner? := oldParsed.data.inner? |>.map (oldParsed.data.stx, ·)
oldInner? := oldParsed.inner? |>.map (oldParsed.stx, ·)
-- compare `stx[0]` for `finished`/`next` reuse, focus on remainder of script
Term.withNarrowedTacticReuse (stx := stx) (fun stx => (stx[0], mkNullNode stx.getArgs[1:])) fun stxs => do
let some snap := ( readThe Term.Context).tacSnap?
@@ -70,10 +70,10 @@ where
if let some old := snap.old? then
-- `tac` must be unchanged given the narrow above; let's reuse `finished`'s state!
let oldParsed := old.val.get
if let some state := oldParsed.data.finished.get.state? then
if let some state := oldParsed.finished.get.state? then
reusableResult? := some ((), state)
-- only allow `next` reuse in this case
oldNext? := oldParsed.data.next.get? 0 |>.map (old.stx, ·)
oldNext? := oldParsed.next.get? 0 |>.map (old.stx, ·)
-- For `tac`'s snapshot task range, disregard synthetic info as otherwise
-- `SnapshotTree.findInfoTreeAtPos` might choose the wrong snapshot: for example, when
@@ -89,7 +89,7 @@ where
withAlwaysResolvedPromise fun next => do
withAlwaysResolvedPromise fun finished => do
withAlwaysResolvedPromise fun inner => do
snap.new.resolve <| .mk {
snap.new.resolve {
desc := tac.getKind.toString
diagnostics := .empty
stx := tac

View File

@@ -13,6 +13,31 @@ open Meta
# Implementation of the `change` tactic
-/
/--
Elaborates the pattern `p` and ensures that it is defeq to `e`.
Emulates `(show p from ?m : e)`, returning the type of `?m`, but `e` and `p` do not need to be types.
Unlike `(show p from ?m : e)`, this can assign synthetic opaque metavariables appearing in `p`.
-/
def elabChange (e : Expr) (p : Term) : TacticM Expr := do
let p runTermElab do
let p Term.elabTermEnsuringType p ( inferType e)
unless isDefEq p e do
/-
Sometimes isDefEq can fail due to postponed elaboration problems.
We synthesize pending synthetic mvars while allowing typeclass instances to be postponed,
which might enable solving for them with an additional `isDefEq`.
-/
Term.synthesizeSyntheticMVars (postpone := .partial)
discard <| isDefEq p e
pure p
withAssignableSyntheticOpaque do
unless isDefEq p e do
let (p, tgt) addPPExplicitToExposeDiff p e
throwError "\
'change' tactic failed, pattern{indentExpr p}\n\
is not definitionally equal to target{indentExpr tgt}"
instantiateMVars p
/-- `change` can be used to replace the main goal or its hypotheses with
different, yet definitionally equal, goal or hypotheses.
@@ -38,15 +63,13 @@ the main goal. -/
| `(tactic| change $newType:term $[$loc:location]?) => do
withLocation (expandOptLocation (Lean.mkOptionalNode loc))
(atLocal := fun h => do
let hTy h.getType
-- This is a hack to get the new type to elaborate in the same sort of way that
-- it would for a `show` expression for the goal.
let mvar mkFreshExprMVar none
let (_, mvars) elabTermWithHoles
( `(term | show $newType from $( Term.exprToSyntax mvar))) hTy `change
let (hTy', mvars) withCollectingNewGoalsFrom (elabChange ( h.getType) newType) ( getMainTag) `change
liftMetaTactic fun mvarId => do
return ( mvarId.changeLocalDecl h ( inferType mvar)) :: mvars)
(atTarget := evalTactic <| `(tactic| refine_lift show $newType from ?_))
(failed := fun _ => throwError "change tactic failed")
return ( mvarId.changeLocalDecl h hTy') :: mvars)
(atTarget := do
let (tgt', mvars) withCollectingNewGoalsFrom (elabChange ( getMainTarget) newType) ( getMainTag) `change
liftMetaTactic fun mvarId => do
return ( mvarId.replaceTargetDefEq tgt') :: mvars)
(failed := fun _ => throwError "'change' tactic failed")
end Lean.Elab.Tactic

View File

@@ -114,6 +114,13 @@ private def elabConfig (recover : Bool) (structName : Name) (items : Array Confi
let e Term.withSynthesize <| Term.elabTermEnsuringType stx (mkConst structName)
instantiateMVars e
section
-- We automatically disable the following option for `macro`s but the subsequent `def` both contains
-- a quotation and is called only by `macro`s, so we disable the option for it manually. Note that
-- we can't use `in` as it is parsed as a single command and so the option would not influence the
-- parser.
set_option internal.parseQuotWithCurrentStage false
private def mkConfigElaborator
(doc? : Option (TSyntax ``Parser.Command.docComment)) (elabName type monadName : Ident)
(adapt recover : Term) : MacroM (TSyntax `command) := do
@@ -148,6 +155,8 @@ private def mkConfigElaborator
throwError msg
go)
end
/-!
`declare_config_elab elabName TypeName` declares a function `elabName : Syntax → TacticM TypeName`
that elaborates a tactic configuration.

View File

@@ -5,6 +5,7 @@ Authors: Leonardo de Moura
-/
prelude
import Lean.Elab.Tactic.ElabTerm
import Lean.Elab.Tactic.Change
import Lean.Elab.Tactic.Conv.Basic
namespace Lean.Elab.Tactic.Conv
@@ -15,11 +16,9 @@ open Meta
| `(conv| change $e) => withMainContext do
let lhs getLhs
let mvarCounterSaved := ( getMCtx).mvarCounter
let r elabTermEnsuringType e ( inferType lhs)
logUnassignedAndAbort ( filterOldMVars ( getMVars r) mvarCounterSaved)
unless ( isDefEqGuarded r lhs) do
throwError "invalid 'change' conv tactic, term{indentExpr r}\nis not definitionally equal to current left-hand-side{indentExpr lhs}"
changeLhs r
let lhs' elabChange lhs e
logUnassignedAndAbort ( filterOldMVars ( getMVars lhs') mvarCounterSaved)
changeLhs lhs'
| _ => throwUnsupportedSyntax
end Lean.Elab.Tactic.Conv

View File

@@ -18,21 +18,22 @@ private def mkKey (e : Expr) (simp : Bool) : MetaM (Array Key) := do
let (_, _, type) withReducible <| forallMetaTelescopeReducing e
let type whnfR type
if simp then
if let some (_, lhs, _) := type.eq? then
mkPath lhs simpDtConfig
else if let some (lhs, _) := type.iff? then
mkPath lhs simpDtConfig
else if let some (_, lhs, _) := type.ne? then
mkPath lhs simpDtConfig
else if let some p := type.not? then
match p.eq? with
| some (_, lhs, _) =>
mkPath lhs simpDtConfig
| _ => mkPath p simpDtConfig
else
mkPath type simpDtConfig
withSimpGlobalConfig do
if let some (_, lhs, _) := type.eq? then
mkPath lhs
else if let some (lhs, _) := type.iff? then
mkPath lhs
else if let some (_, lhs, _) := type.ne? then
mkPath lhs
else if let some p := type.not? then
match p.eq? with
| some (_, lhs, _) =>
mkPath lhs
| _ => mkPath p
else
mkPath type
else
mkPath type {}
mkPath type
private def getType (t : TSyntax `term) : TermElabM Expr := do
if let `($id:ident) := t then

View File

@@ -542,11 +542,6 @@ declare_config_elab elabDecideConfig Parser.Tactic.DecideConfig
let cfg elabDecideConfig stx[1]
evalDecideCore `decide cfg
@[builtin_tactic Lean.Parser.Tactic.decideBang] def evalDecideBang : Tactic := fun stx => do
let cfg elabDecideConfig stx[1]
let cfg := { cfg with kernel := true }
evalDecideCore `decide! cfg
@[builtin_tactic Lean.Parser.Tactic.nativeDecide] def evalNativeDecide : Tactic := fun stx => do
let cfg elabDecideConfig stx[1]
let cfg := { cfg with native := true }

View File

@@ -195,9 +195,6 @@ structure ExtTheorems where
erased : PHashSet Name := {}
deriving Inhabited
/-- Discrimation tree settings for the `ext` extension. -/
def extExt.config : WhnfCoreConfig := {}
/-- The environment extension to track `@[ext]` theorems. -/
builtin_initialize extExtension :
SimpleScopedEnvExtension ExtTheorem ExtTheorems
@@ -211,7 +208,7 @@ builtin_initialize extExtension :
ordered from high priority to low. -/
@[inline] def getExtTheorems (ty : Expr) : MetaM (Array ExtTheorem) := do
let extTheorems := extExtension.getState ( getEnv)
let arr extTheorems.tree.getMatch ty extExt.config
let arr extTheorems.tree.getMatch ty
let erasedArr := arr.filter fun thm => !extTheorems.erased.contains thm.declName
-- Using insertion sort because it is stable and the list of matches should be mostly sorted.
-- Most ext theorems have default priority.
@@ -258,7 +255,7 @@ builtin_initialize registerBuiltinAttribute {
but this theorem proves{indentD declTy}"
let some (ty, lhs, rhs) := declTy.eq? | failNotEq
unless lhs.isMVar && rhs.isMVar do failNotEq
let keys withReducible <| DiscrTree.mkPath ty extExt.config
let keys withReducible <| DiscrTree.mkPath ty
let priority liftCommandElabM <| Elab.liftMacroM do evalPrio (prio.getD ( `(prio| default)))
extExtension.add {declName, keys, priority} kind
-- Realize iff theorem

View File

@@ -282,10 +282,11 @@ where
-- them, eventually put each of them back in `Context.tacSnap?` in `applyAltStx`
withAlwaysResolvedPromise fun finished => do
withAlwaysResolvedPromises altStxs.size fun altPromises => do
tacSnap.new.resolve <| .mk {
tacSnap.new.resolve {
-- save all relevant syntax here for comparison with next document version
stx := mkNullNode altStxs
diagnostics := .empty
inner? := none
finished := { range? := none, task := finished.result }
next := altStxs.zipWith altPromises fun stx prom =>
{ range? := stx.getRange?, task := prom.result }
@@ -298,7 +299,7 @@ where
let old := old.val.get
-- use old version of `mkNullNode altsSyntax` as guard, will be compared with new
-- version and picked apart in `applyAltStx`
return old.data.stx, ( old.data.next[i]?)
return old.stx, ( old.next[i]?)
new := prom
}
finished.resolve { diagnostics := .empty, state? := ( saveState) }
@@ -340,9 +341,9 @@ where
for h : altStxIdx in [0:altStxs.size] do
let altStx := altStxs[altStxIdx]
let altName := getAltName altStx
if let some i := alts.findIdx? (·.1 == altName) then
if let some i := alts.findFinIdx? (·.1 == altName) then
-- cover named alternative
applyAltStx tacSnaps altStxIdx altStx alts[i]!
applyAltStx tacSnaps altStxIdx altStx alts[i]
alts := alts.eraseIdx i
else if !alts.isEmpty && isWildcard altStx then
-- cover all alternatives

View File

@@ -40,7 +40,7 @@ private def tryAssumption (mvarId : MVarId) : MetaM (List MVarId) := do
let ids := stx[1].getArgs.toList.map getNameOfIdent'
liftMetaTactic fun mvarId => do
match ( Meta.injections mvarId ids) with
| .solved => checkUnusedIds `injections mvarId ids; return []
| .subgoal mvarId' unusedIds => checkUnusedIds `injections mvarId unusedIds; tryAssumption mvarId'
| .solved => checkUnusedIds `injections mvarId ids; return []
| .subgoal mvarId' unusedIds _ => checkUnusedIds `injections mvarId unusedIds; tryAssumption mvarId'
end Lean.Elab.Tactic

View File

@@ -40,7 +40,7 @@ def exact? (ref : Syntax) (required : Option (Array (TSyntax `term))) (requireCl
| some suggestions =>
if requireClose then throwError
"`exact?` could not close the goal. Try `apply?` to see partial suggestions."
reportOutOfHeartbeats `library_search ref
reportOutOfHeartbeats `apply? ref
for (_, suggestionMCtx) in suggestions do
withMCtx suggestionMCtx do
addExactSuggestion ref ( instantiateMVars (mkMVar mvar)).headBeta (addSubgoalsMsg := true)

View File

@@ -91,7 +91,7 @@ def elabSimpConfig (optConfig : Syntax) (kind : SimpKind) : TacticM Meta.Simp.Co
| .simpAll => return ( elabSimpConfigCtxCore optConfig).toConfig
| .dsimp => return { ( elabDSimpConfigCore optConfig) with }
private def addDeclToUnfoldOrTheorem (thms : SimpTheorems) (id : Origin) (e : Expr) (post : Bool) (inv : Bool) (kind : SimpKind) : MetaM SimpTheorems := do
private def addDeclToUnfoldOrTheorem (config : Meta.ConfigWithKey) (thms : SimpTheorems) (id : Origin) (e : Expr) (post : Bool) (inv : Bool) (kind : SimpKind) : MetaM SimpTheorems := do
if e.isConst then
let declName := e.constName!
let info getConstInfo declName
@@ -108,7 +108,7 @@ private def addDeclToUnfoldOrTheorem (thms : SimpTheorems) (id : Origin) (e : Ex
let fvarId := e.fvarId!
let decl fvarId.getDecl
if ( isProp decl.type) then
thms.add id #[] e (post := post) (inv := inv)
thms.add id #[] e (post := post) (inv := inv) (config := config)
else if !decl.isLet then
throwError "invalid argument, variable is not a proposition or let-declaration"
else if inv then
@@ -116,9 +116,9 @@ private def addDeclToUnfoldOrTheorem (thms : SimpTheorems) (id : Origin) (e : Ex
else
return thms.addLetDeclToUnfold fvarId
else
thms.add id #[] e (post := post) (inv := inv)
thms.add id #[] e (post := post) (inv := inv) (config := config)
private def addSimpTheorem (thms : SimpTheorems) (id : Origin) (stx : Syntax) (post : Bool) (inv : Bool) : TermElabM SimpTheorems := do
private def addSimpTheorem (config : Meta.ConfigWithKey) (thms : SimpTheorems) (id : Origin) (stx : Syntax) (post : Bool) (inv : Bool) : TermElabM SimpTheorems := do
let thm? Term.withoutModifyingElabMetaStateWithInfo <| withRef stx do
let e Term.elabTerm stx none
Term.synthesizeSyntheticMVars (postpone := .no) (ignoreStuckTC := true)
@@ -132,7 +132,7 @@ private def addSimpTheorem (thms : SimpTheorems) (id : Origin) (stx : Syntax) (p
else
return some (#[], e)
if let some (levelParams, proof) := thm? then
thms.add id levelParams proof (post := post) (inv := inv)
thms.add id levelParams proof (post := post) (inv := inv) (config := config)
else
return thms
@@ -212,7 +212,7 @@ def elabSimpArgs (stx : Syntax) (ctx : Simp.Context) (simprocs : Simp.SimprocsAr
match ( resolveSimpIdTheorem? term) with
| .expr e =>
let name mkFreshId
thms addDeclToUnfoldOrTheorem thms (.stx name arg) e post inv kind
thms addDeclToUnfoldOrTheorem ctx.indexConfig thms (.stx name arg) e post inv kind
| .simproc declName =>
simprocs simprocs.add declName post
| .ext (some ext₁) (some ext₂) _ =>
@@ -224,7 +224,7 @@ def elabSimpArgs (stx : Syntax) (ctx : Simp.Context) (simprocs : Simp.SimprocsAr
simprocs := simprocs.push ( ext₂.getSimprocs)
| .none =>
let name mkFreshId
thms addSimpTheorem thms (.stx name arg) term post inv
thms addSimpTheorem ctx.indexConfig thms (.stx name arg) term post inv
else if arg.getKind == ``Lean.Parser.Tactic.simpStar then
starArg := true
else
@@ -329,7 +329,7 @@ def mkSimpContext (stx : Syntax) (eraseLocal : Bool) (kind := SimpKind.simp)
let hs getPropHyps
for h in hs do
unless simpTheorems.isErased (.fvar h) do
simpTheorems simpTheorems.addTheorem (.fvar h) ( h.getDecl).toExpr
simpTheorems simpTheorems.addTheorem (.fvar h) ( h.getDecl).toExpr (config := ctx.indexConfig)
let ctx := ctx.setSimpTheorems simpTheorems
return { ctx, simprocs, dischargeWrapper }

View File

@@ -25,7 +25,7 @@ def elabSimprocPattern (stx : Syntax) : MetaM Expr := do
def elabSimprocKeys (stx : Syntax) : MetaM (Array Meta.SimpTheoremKey) := do
let pattern elabSimprocPattern stx
DiscrTree.mkPath pattern simpDtConfig
withSimpGlobalConfig <| DiscrTree.mkPath pattern
def checkSimprocType (declName : Name) : CoreM Bool := do
let decl getConstInfo declName

Some files were not shown because too many files have changed in this diff Show More