Compare commits

...

49 Commits

Author SHA1 Message Date
Kim Morrison
7a2e87fabe chore: update grind IndexMap example 2026-02-01 20:49:19 +11:00
Lean stage0 autoupdater
ec60620534 chore: update stage0 2026-02-01 03:54:16 +00:00
Leonardo de Moura
4606c35c40 feat: @[instance_reducible] (#12247)
This PR adds the new transparency setting `@[instance_reducible]`. We
used to check whether a declaration had `instance` reducibility by using
the `isInstance` predicate. However, this was not a robust solution
because:

- We have scoped instances, and `isInstance` returns `true` only if the
scope is active.

- We have auxiliary declarations used to construct instances manually,
such as:

```lean
    def lt_wfRel : WellFoundedRelation Nat
```
    
`isInstance` also returns `false` for this kind of declaration.

In both cases, the declaration may be (or may have been) used to
construct an instance, but `isInstance`
returns `false`. Thus, we claim it is a mistake to check the
reducibility status using `isInstance`.
`isInstance` indicates whether a declaration is available for the type
class resolution mechanism,
not its transparency status.

**We are decoupling whether a declaration is available for type class
resolution from its transparency status.**

**Remak**: We need a update stage0 to complete this feature.

---------

Co-authored-by: Sebastian Ullrich <sebasti@nullri.ch>
2026-02-01 03:03:16 +00:00
Leonardo de Moura
9f4c81342e chore: cleanup test (#12262) 2026-01-31 23:22:13 +00:00
Mac Malone
89c01c9e7e fix: lake: facet names in unknown facet errors (#12261)
This PR fixes a bug in Lake where the facet names printed in unknown
facet errors would contain the internal facet kind.
2026-01-31 20:57:13 +00:00
Mac Malone
ce980895b2 fix: IO.Process.spawn empty env var on Windows (#12220)
This PR fixes a bug on Windows with `IO.Process.spawn` where setting an
environment variable to the empty string would not set the environment
variable on the subprocess.
2026-01-31 19:17:26 +00:00
Wojciech Różowski
6c5de545f9 feat: add orElse combinator to Sym.Simp.Simproc (#12236)
This PR adds `orElse` combinator to simprocs of `Sym.Simp`.
2026-01-31 18:34:19 +00:00
Leonardo de Moura
21a281b496 fix: bug in instantiateRangeS' (#12260)
This PR fixes a bug in the function `instantiateRangeS'` in the `Sym`
framework.
2026-01-31 17:50:03 +00:00
Paul Reichert
7cd6b78a9c feat: Std.Iter.isEmpty (#12212)
This PR adds the function `Std.Iter.isEmpty` and proves the
specification lemmas `Std.Iter.isEmpty_eq_match_step` and
`Std.Iter.isEmpty_toList` if the iterator is productive.

The monadic variant on `Std.IterM` is also provided.
2026-01-31 16:18:35 +00:00
Paul Reichert
b64e5dec1e feat: projected minima and maxima (#11938)
This PR introduces projected minima and maxima, also known as
"argmin/argmax", for lists under the names `List.minOn` and
`List.maxOn`. It also introduces `List.minIdxOn` and `List.maxIdxOn`,
which return the index of the minimal or maximal element. Moreover,
there are variants with `?` suffix that return an `Option`. The change
further introduces new instances for opposite orders, such as
`LE.opposite`, `IsLinearOrder.opposite` etc. The change also adds the
missing `Std.lt_irrefl` lemma.
2026-01-31 16:16:32 +00:00
Leonardo de Moura
d1514f3cec perf: cache unfold_definition in the kernel (#12259)
This PR ensures we cache the result of `unfold_definition` definition in
the kernel type checker. We used to cache this information in a thread
local storage, but it was deleted during the Lean 3 to Lean 4
transition.
2026-01-31 03:44:50 +00:00
Kim Morrison
a972c4f50d fix: include local variable dot notation params in grind? suggestions (#12224)
This PR fixes a bug where `grind?` suggestions would not include
parameters using local variable dot notation (e.g.,
`cs.getD_rightInvSeq` where `cs` is a local variable). These parameters
were incorrectly filtered out because the code assumed all ident params
resolve to global declarations. In fact, local variable dot notation
produces anchors that need the original term to be loaded during replay,
so they must be preserved in the suggestion.

Closes #12185

🤖 Prepared with Claude Code

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-31 00:34:28 +00:00
Sebastian Graf
7416309805 test: teach SymM mvcgen to recognize specialized theorem applications (#12256)
This PR recognizes certain kinds of composite proof terms of the form
`hpre.trans hspec |> (wp prog).mono _ _ hpost` and abstracts them into
bespoke theorems. This should yield smaller proof terms. Sadly, kernel
checking time is unaffected, even regressing a bit. The number of shared
terms stays almost the same (+- a constant). Hence I deactivate the code
path in this patch. We keep the code, though, because it might be useful
in the future, also there are a few other improvements.
2026-01-30 17:00:59 +00:00
Kim Morrison
bb68f31527 doc: add pp.mvars advice to #guard_msgs docstring (#12253)
This PR adds a "Stabilizing output" section to the `#guard_msgs`
docstring, explaining how to use `pp.mvars.anonymous` and `pp.mvars`
options to stabilize output containing autogenerated metavariable names
like `?m.47`.

This was prompted by discussion on Zulip about improving #mwe
documentation:
https://leanprover.zulipchat.com/#narrow/channel/287929-mathlib4/topic/JacobiZariski.20is.20slow.2E/near/570739745

🤖 Prepared with Claude Code

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-30 16:13:24 +00:00
Sebastian Ullrich
85341d02ac feat: immediate noncomputable check (#12028)
This PR gives a simpler semantics to `noncomputable`, improving
predictability as well as preparing codegen to be moved into a separate
build step without breaking immediate generation of error messages.

Specifically, `noncomputable` is now needed whenever an axiom or another
`noncomputable` def is used by a def except for the following special
cases:
* uses inside proofs, types, type formers, and constructor arguments
corresponding to (fixed) inductive parameters are ignored
* uses of functions marked `@[extern]/@[implemented_by]/@[csimp]` are
ignored
* for applications of a function marked `@[macro_inline]`,
noncomputability of the inlining is instead inspected

# Breaking change

After this change, more `noncomputable` annotations than before may be
required in exchange for improved future stability.
2026-01-30 16:07:25 +00:00
Sebastian Graf
59f3abd0bd test: add SymM mvcgen and port the add_sub_cancel benchmark (#12251)
This PR adds a clone of the `mvcgen` tactic based on `SymM` and
evaluates it based on a ported `add_sub_cancel` benchmark. Notably, it
can reuse all the existing `@[spec]`-annotated theorems to generate VCs.
(It doesn't do control-flow splitting, simp rules on the program
expression or handling of lets; we'll get there.)

It is quite fast already, with the kernel being the bottle-neck:

```
goal_50: 69.524305 ms, kernel: 155.327778 ms
goal_100: 93.834221 ms, kernel: 407.370786 ms
goal_150: 131.364098 ms, kernel: 762.936720 ms
goal_200: 169.577172 ms, kernel: 1181.199093 ms
goal_250: 206.421738 ms, kernel: 1707.539380 ms
```

```
goal_200: 169.458637 ms, kernel: 1186.221085 ms
goal_400: 322.819718 ms, kernel: 3791.613854 ms
goal_600: 474.929013 ms, kernel: 7763.373757 ms
goal_800: 634.379422 ms, kernel: 13107.810430 ms
```

It is best compared to the `solveUsingSym <n> false true` measurements
of the SymM `add_sub_cancel` benchmark (`false`: without intermediate
eager simplification). For `n=200`, it reports

```
goal_200: 779.482300 ms, kernel: 742.097404 ms
```

suggesting that the generated proof term could be improved for kernel
reduction. (TODO.)
I'm unsure whether `solveUsingSym` is run in interpreted mode, so take
the >400% speedup with a grain of salt.
We can definitely conclude that VC generation time is currently not a
bottleneck compared to kernel checking time.

Plot for discharging goals of sizes 100..800:
<img width="1000" height="600" alt="Code_Generated_Image(1)"
src="https://github.com/user-attachments/assets/90e76a45-fa46-4d02-912a-c3355e2aa094"
/>

Plot comparing Kernel and Goal time: 
<img width="1000" height="600" alt="Code_Generated_Image(2)"
src="https://github.com/user-attachments/assets/5849ba0f-1d83-4f2d-98dd-fa65b840bb4e"
/>
2026-01-30 15:12:54 +00:00
Sebastian Graf
2f3912df74 feat: define Triple.iff, Triple.iff_conseq etc. and use defeq less (#12250)
This PR introduces the defining equality `Triple.iff` and uses that in
proofs instead of relying on definitional equality. It also introduces
`Triple.iff_conseq` that is useful for backward reasoning and introduces
verification conditions. Similarly, `Triple.entails_wp_*` theorems are
introduced for backward reasoning where the target is an stateful
entailment rather than a triple.
2026-01-30 14:03:22 +00:00
Henrik Böving
5ce756f350 refactor: introduce a phase separation to the IR (#12214)
This PR introduces a phase separation to the LCNF IR. This is a
preparation for the merge of
the old `Lean.Compiler.IR` and the new `Lean.Compiler.LCNF` framework.

The change parametrizes all relevant `LCNF` data structures over a
`Purity` parameter and
additionally carries around proofs that the `Purity` has certain values,
depending on what's
required. This is done as opposed to indexing the types over `Purity`
because we do (almost) never
have to store the `Purity` value for phase generic structures this way.
2026-01-30 09:42:29 +00:00
Lean stage0 autoupdater
6d370ec3c2 chore: update stage0 2026-01-30 09:12:55 +00:00
Henrik Böving
332c1ec46a perf: specializer a little more courageously (#12239)
This PR reverts a lot of the changes done in #8308. We practically
encountered situations such as:
```
fun y (z) :=
  let x := inst
  mkInst x z
f y
```
Where the instance puller turns it into:
```
let x := inst
fun y (z) :=
  mkInst x z
f y
```
The current heuristic now discovers `x` being in scope at the call site
of `f` and being used under a binder in `y` and thus blocks pulling in
`x` to the specialization, abstracting over an instance.

According to @zwarich this was done at the time either due to observed
stack overflows or pulling in computation into loops. With the current
configuration for abstraction in specialization it seems rather unlikely
that we pull in a non trivial computation into a loop with this. We also
practically didn't observe stack overflows in our tests or benchmarks.
Cameron speculates that the issues he observed might've been fixed
otherwise by now.

Crucial note: Deciding not to abstract over ground terms *might* cause
us to pull in computationally intensive ground terms into a loop. We
could decide to weaken this to just instance terms though of course even
computing instances might end up being non-trivial.
2026-01-30 08:23:15 +00:00
Joachim Breitner
4c5e3d73af fix: deriving Ord with indexed data type (#12243)
This PR fixes #12240, where `deriving Ord` failed with `Unknown
identifier a✝`.
2026-01-29 20:50:14 +00:00
Sebastian Ullrich
2b2b72d113 test: more .git cleanup (#12238)
Co-authored-by: Mac Malone <mac@lean-fro.org>
2026-01-29 17:43:31 +00:00
Garmelon
5b0b365406 chore: stop make install from printing every individual file (#12235)
https://cmake.org/cmake/help/latest/variable/CMAKE_INSTALL_MESSAGE.html
2026-01-29 16:50:21 +00:00
Sebastian Ullrich
892cbe22f8 fix: run @[init] declarations in declaration order (#12221)
Fixes #10175 harder.
2026-01-29 15:32:56 +00:00
Paul Reichert
3883f0f669 feat: min(?)/max(?) for Array (#11936)
This PR provides `Array` operations analogous to `List.min(?)` and
`List.max(?)`.

I had to prove a few auxiliary lemmas. Downstream in Batteries, which
already had `List.min` and `List.max`, I renamed their variants to
`List.rangeMin` and `List.rangeMax` in the PR testing branch. Their
version is more general in the sense that it has `start` and `stop`
autoParams, like `Array.foldl` has, but I think the futore belongs to
`Subarray.min` instead (which I haven't implemented yet).
2026-01-29 14:12:02 +00:00
Marc Huisinga
30c8b39b23 test: fix broken uri test (#12230) 2026-01-29 13:52:36 +00:00
Paul Reichert
e7b6bd6734 refactor: rename Iter(M).count to Iter(M).length (#12210)
This PR renames `Iter(M).count` to `Iter(M).length` and updates lots of
lemmas, adding deprecations.
2026-01-29 07:26:13 +00:00
Paul Reichert
16919852d9 refactor: remove last appearances of allowNontermination (#12211)
This PR updates docstrings and function signatures in order to complete
the transition from `Iter.Partial` to `Iter.Total` (extrinsically
terminating by default). It also deprecates `allowNontermination` and
adds `Iter.Total.atIdxSlow?`.
2026-01-29 07:22:19 +00:00
Leonardo de Moura
29545dcf10 feat: do not dsimp instances (#12195)
This PR ensures `dsimp` does not "simplify" instances by default. The
old behavior can be retrieved by using
```
set_option backward.dsimp.instances true
```
Applying `dsimp` to instances creates non-standard instances, and this
creates all sorts of problems in Mathlib.
This modification is similar to
```
set_option backward.dsimp.proofs true
```

---------

Co-authored-by: Kim Morrison <kim@tqft.net>
Co-authored-by: Claude <noreply@anthropic.com>
2026-01-29 05:25:01 +00:00
Kim Morrison
b772852522 fix: verify PR release artifacts before creating tags (#12223)
This PR moves the artifact verification step before tag creation and
release deletion, so we fail early if no artifacts are available rather
than creating side effects that would need to be cleaned up.

Addresses feedback from
https://lean-fro.zulipchat.com/#narrow/channel/399079-infrastructure/topic/PR.20toolchain.20even.20if.20test.20suite.20fails/near/570482678

🤖 Prepared with Claude Code

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-29 04:41:11 +00:00
Kim Morrison
ebec1b3a16 fix: typo in ExtractLetsConfig doc comment (#12174)
This PR fixes a typo in `ExtractLetsConfig.merge` doc comment.

Reported on Zulip:
https://leanprover.zulipchat.com/#narrow/channel/270676-lean4/topic/Typo.20in.20Init.2FMetaTypes.2Elean/near/568698828

🤖 Prepared with Claude Code

Co-authored-by: Claude <noreply@anthropic.com>
2026-01-29 04:40:43 +00:00
Kim Morrison
00c8431cf8 doc: add changelog label instructions to CLAUDE.md (#12227)
This PR documents the available `changelog-*` labels and when to use
them in the project-specific CLAUDE.md instructions.

🤖 Prepared with Claude Code

Co-authored-by: Claude <noreply@anthropic.com>
2026-01-29 03:24:23 +00:00
Rob23oba
b919cfff30 fix: public section in Dyadic files (#12199)
This PR fixes `Init.Data.Dyadic.Instances` and `Init.Data.Dyadic.Inv`.
Previously, all declarations defined in boths file were private and not
exposed.
2026-01-29 03:05:43 +00:00
Kim Morrison
e441ed8e46 Revert "doc: add changelog label instructions to CLAUDE.md"
This reverts commit 119533d602.
2026-01-29 03:16:10 +00:00
Kim Morrison
119533d602 doc: add changelog label instructions to CLAUDE.md
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-29 03:14:38 +00:00
Kim Morrison
9b9ce0c2ac feat: adjust grind annotations for List.drop (#12170)
This PR adjusts the grind annotations for List.take/drop, and adds two
theorems.

This resolves problems @datokrat encountered while working on
https://github.com/leanprover/human-eval-lean/blob/master/HumanEvalLean/HumanEval114.lean.
2026-01-29 00:27:46 +00:00
Leonardo de Moura
3f0acbbb48 fix: use isClass? instead of binder annotation to identify instance parameters (#12172)
This PR fixes how we determine whether a function parameter is an
instance.
Previously, we relied on binder annotations (e.g., `[Ring A]` vs `{_ :
Ring A}`)
to make this determination. This is unreliable because users
legitimately use
`{..}` binders for class types when the instance is already available
from
context. For example:
```lean
structure OrdSet (α : Type) [Hashable α] [BEq α] where
  ...

def OrdSet.insert {_ : Hashable α} {_ : BEq α} (s : OrdSet α) (a : α) : OrdSet α :=
  ...
```

Here, `Hashable` and `BEq` are classes, but the `{..}` binder is
intentional, the
instances come from `OrdSet`'s parameters, so type class resolution is
unnecessary.

The fix checks the parameter's *type* using `isClass?` rather than its
syntax, and
caches this information in `FunInfo`. This affects several subsystems:

- **Discrimination trees**: instance parameters should not be indexed
even if marked with `{..}`
- **Congruence lemma generation**: instances require special treatment
- **`grind` canonicalizer**: must ensure canonical instances

**Potential regressions**: automation may now behave differently in
cases where it
previously misidentified instance parameters. For example, a rewrite
rule in `simp` that was
not firing due to incorrect indexing may now fire.

---------

Co-authored-by: Kim Morrison <kim@tqft.net>
Co-authored-by: Claude <noreply@anthropic.com>
2026-01-28 20:33:43 +00:00
Garmelon
6dcd6c8f08 chore: reformat all cmake files (#12218)
The script to run for reformatting is `script/fmt`.
2026-01-28 18:23:08 +00:00
Eric Wieser
71be4901c3 fix: do not compile with -fwrapv (#12132)
This PR removes the requirement that libraries compiled against the lean
headers must use `-fwrapv`.

clang
[documents](https://clang.llvm.org/docs/UndefinedBehaviorSanitizer.html#:~:text=Note%20that%20checks%20are%20still%20added%20even%20when%20%2Dfwrapv%20is%20enabled)
that `-fwrapv` does not automatically turn off the integer overflow
sanitizer; and so overflow should still be avoided in normal execution.

This is a retry of #12098 after it was reverted in #12125.
2026-01-28 16:16:15 +00:00
Garmelon
5e13e71a84 chore: fix cmake if conditions (#12213)
Due to the way variable expansion and if interact in cmake, unquoted
variable expansions should essentially never be used inside if and may
lead to unexpected behavior. Also, quoted variable expansions can
usually be replaced by the unquoted variable name.

For more details, see this section in the cmake docs:
https://cmake.org/cmake/help/latest/command/if.html#variable-expansion

As one example of the kinds of issues that can occur with unquoted
variable expansions, consider this check from
`src/shell/CMakeLists.txt`, which tries to ensure that a test is only
run in non-WASM builds.

```cmake
if(NOT ${EMSCRIPTEN})
```

If the variable `EMSCRIPTEN` is empty or not defined (as is the case in
a non-WASM build), `${EMSCRIPTEN}` expands to 0 arguments, meaning the
check becomes

```cmake
if(NOT)
```

Since the `NOT` is unquoted, the if now tries to resolve it as a
variable. Since the variable `NOT` does not exist, the condition is
false and the test is never executed, even in non-WASM builds.
2026-01-28 15:37:18 +00:00
Henrik Böving
08ee91a433 feat: add DecidableEq instances for Sigma and PSimga (#12193)
This PR adds `DecidableEq` instances for `Sigma` and `PSigma`.
2026-01-28 15:00:45 +00:00
Sebastian Ullrich
f790ff1961 chore: remove obsolete repeat macro 2026-01-28 16:27:57 +01:00
Sebastian Ullrich
c4aac5d7c5 chore: update stage0 2026-01-28 16:27:57 +01:00
Sebastian Ullrich
316761c202 perf: make repeat an elaborator 2026-01-28 16:27:57 +01:00
Paul Reichert
b248b13ac2 feat: add useful lemmas about division (#12019)
This PR provides the `Nat`/`Int` lemmas `x ≤ y * z ↔ (x + z - 1) / z ≤
y`, `x ≤ y * z ↔ (x + y - 1) / y ≤ z` and `x / z + y / z ≤ (x + y) / z`.

The PR is inspired by a `human-eval-lean` problem, the solution of which
required these lemmas.
2026-01-28 14:17:47 +00:00
Joachim Breitner
08f43acefb perf: add introSubstEq shortcut (#12190)
This PR adds the `introSubstEq` MetaM tactic, as an optimization over
`intro h; subst h` that avoids introducing `h : a = b` if it can be
avoided,
which is the case when `b` can be reverted without reverting anything
else. Speeds up the generation of `injEq` theorem.
2026-01-28 12:33:14 +00:00
Sebastian Graf
9a37dba765 chore: express SPred lemmas using Iff instead of Eq (#12209) 2026-01-28 10:19:55 +00:00
Henrik Böving
a47eb31076 chore: remove the LCNF testing framework (#12207)
This PR removes the LCNF testing framework. Unfortunately it never got
used much and porting it to
the extended LCNF structure now would be a bit of effort that would
ultimately be in vain.
2026-01-28 10:09:30 +00:00
Marc Huisinga
819fb6a6a8 fix: use windows path separators in System.Uri.fileUriToPath? (#12197)
This PR fixes a bug in `System.Uri.fileUriToPath?` where it wouldn't use
the default Windows path separator in the path it produces.

It also adjusts the URI patching in the interactive test runner to be
more robust.
2026-01-28 09:10:34 +00:00
1053 changed files with 7825 additions and 4050 deletions

View File

@@ -46,6 +46,21 @@ This PR adds a `num?` parameter to `mkPatternFromTheorem` to control how many
leading quantifiers are stripped when creating a pattern.
```
**Changelog labels:** Add one `changelog-*` label to categorize the PR for release notes:
- `changelog-language` - Language features and metaprograms
- `changelog-tactics` - User facing tactics
- `changelog-server` - Language server, widgets, and IDE extensions
- `changelog-pp` - Pretty printing
- `changelog-library` - Library
- `changelog-compiler` - Compiler, runtime, and FFI
- `changelog-lake` - Lake
- `changelog-doc` - Documentation
- `changelog-ffi` - FFI changes
- `changelog-other` - Other changes
- `changelog-no` - Do not include this PR in the release changelog
If you're unsure which label applies, it's fine to omit the label and let reviewers add it.
## CI Log Retrieval
When CI jobs fail, investigate immediately - don't wait for other jobs to complete. Individual job logs are often available even while other jobs are still running. Try `gh run view <run-id> --log` or `gh run view <run-id> --log-failed`, or use `gh run view <run-id> --job=<job-id>` to target the specific failed job. Sleeping is fine when asked to monitor CI and no failures exist yet, but once any job fails, investigate that failure immediately.

View File

@@ -43,6 +43,19 @@ jobs:
name: build-.*
name_is_regexp: true
# Verify artifacts were downloaded before any side effects (tag creation, release deletion).
- name: Verify release artifacts exist
if: ${{ steps.workflow-info.outputs.pullRequestNumber != '' }}
run: |
shopt -s nullglob
files=(artifacts/*/*)
if [ ${#files[@]} -eq 0 ]; then
echo "::error::No artifacts found matching artifacts/*/*"
exit 1
fi
echo "Found ${#files[@]} artifacts to upload:"
printf '%s\n' "${files[@]}"
- name: Push tag
if: ${{ steps.workflow-info.outputs.pullRequestNumber != '' }}
run: |
@@ -74,18 +87,6 @@ jobs:
gh release delete --repo ${{ github.repository_owner }}/lean4-pr-releases pr-release-${{ steps.workflow-info.outputs.pullRequestNumber }}-${{ env.SHORT_SHA }} -y || true
env:
GH_TOKEN: ${{ secrets.PR_RELEASES_TOKEN }}
# Verify artifacts were downloaded (equivalent to fail_on_unmatched_files in the old action).
- name: Verify release artifacts exist
if: ${{ steps.workflow-info.outputs.pullRequestNumber != '' }}
run: |
shopt -s nullglob
files=(artifacts/*/*)
if [ ${#files[@]} -eq 0 ]; then
echo "::error::No artifacts found matching artifacts/*/*"
exit 1
fi
echo "Found ${#files[@]} artifacts to upload:"
printf '%s\n' "${files[@]}"
# We use `gh release create` instead of `softprops/action-gh-release` because
# the latter enumerates all releases to check for existing ones, which fails
# when the repository has more than 10000 releases (GitHub API pagination limit).

View File

@@ -10,22 +10,22 @@ option(USE_MIMALLOC "use mimalloc" ON)
get_cmake_property(vars CACHE_VARIABLES)
foreach(var ${vars})
get_property(currentHelpString CACHE "${var}" PROPERTY HELPSTRING)
if("${var}" MATCHES "STAGE0_(.*)")
if(var MATCHES "STAGE0_(.*)")
list(APPEND STAGE0_ARGS "-D${CMAKE_MATCH_1}=${${var}}")
elseif("${var}" MATCHES "STAGE1_(.*)")
elseif(var MATCHES "STAGE1_(.*)")
list(APPEND STAGE1_ARGS "-D${CMAKE_MATCH_1}=${${var}}")
elseif("${currentHelpString}" MATCHES "No help, variable specified on the command line." OR "${currentHelpString}" STREQUAL "")
elseif(currentHelpString MATCHES "No help, variable specified on the command line." OR currentHelpString STREQUAL "")
list(APPEND CL_ARGS "-D${var}=${${var}}")
if("${var}" MATCHES "USE_GMP|CHECK_OLEAN_VERSION|LEAN_VERSION_.*|LEAN_SPECIAL_VERSION_DESC")
if(var MATCHES "USE_GMP|CHECK_OLEAN_VERSION|LEAN_VERSION_.*|LEAN_SPECIAL_VERSION_DESC")
# must forward options that generate incompatible .olean format
list(APPEND STAGE0_ARGS "-D${var}=${${var}}")
elseif("${var}" MATCHES "LLVM*|PKG_CONFIG|USE_LAKE|USE_MIMALLOC")
elseif(var MATCHES "LLVM*|PKG_CONFIG|USE_LAKE|USE_MIMALLOC")
list(APPEND STAGE0_ARGS "-D${var}=${${var}}")
endif()
elseif("${var}" MATCHES "USE_MIMALLOC")
elseif(var MATCHES "USE_MIMALLOC")
list(APPEND CL_ARGS "-D${var}=${${var}}")
list(APPEND STAGE0_ARGS "-D${var}=${${var}}")
elseif(("${var}" MATCHES "CMAKE_.*") AND NOT ("${var}" MATCHES "CMAKE_BUILD_TYPE") AND NOT ("${var}" MATCHES "CMAKE_HOME_DIRECTORY"))
elseif((var MATCHES "CMAKE_.*") AND NOT (var MATCHES "CMAKE_BUILD_TYPE") AND NOT (var MATCHES "CMAKE_HOME_DIRECTORY"))
list(APPEND PLATFORM_ARGS "-D${var}=${${var}}")
endif()
endforeach()
@@ -34,15 +34,15 @@ include(ExternalProject)
project(LEAN CXX C)
if(NOT (DEFINED STAGE0_CMAKE_EXECUTABLE_SUFFIX))
set(STAGE0_CMAKE_EXECUTABLE_SUFFIX "${CMAKE_EXECUTABLE_SUFFIX}")
set(STAGE0_CMAKE_EXECUTABLE_SUFFIX "${CMAKE_EXECUTABLE_SUFFIX}")
endif()
# Don't do anything with cadical on wasm
if (NOT ${CMAKE_SYSTEM_NAME} MATCHES "Emscripten")
if(NOT CMAKE_SYSTEM_NAME MATCHES "Emscripten")
find_program(CADICAL cadical)
if(NOT CADICAL)
set(CADICAL_CXX c++)
if (CADICAL_USE_CUSTOM_CXX)
if(CADICAL_USE_CUSTOM_CXX)
set(CADICAL_CXX ${CMAKE_CXX_COMPILER})
# Use same platform flags as for Lean executables, in particular from `prepare-llvm-linux.sh`,
# but not Lean-specific `LEAN_EXTRA_CXX_FLAGS` such as fsanitize.
@@ -54,42 +54,51 @@ if (NOT ${CMAKE_SYSTEM_NAME} MATCHES "Emscripten")
set(CADICAL_CXX "${CCACHE} ${CADICAL_CXX}")
endif()
# missing stdio locking API on Windows
if(${CMAKE_SYSTEM_NAME} MATCHES "Windows")
if(CMAKE_SYSTEM_NAME MATCHES "Windows")
string(APPEND CADICAL_CXXFLAGS " -DNUNLOCKED")
endif()
string(APPEND CADICAL_CXXFLAGS " -DNCLOSEFROM")
ExternalProject_add(cadical
ExternalProject_Add(
cadical
PREFIX cadical
GIT_REPOSITORY https://github.com/arminbiere/cadical
GIT_TAG rel-2.1.2
CONFIGURE_COMMAND ""
BUILD_COMMAND $(MAKE) -f ${CMAKE_SOURCE_DIR}/src/cadical.mk
CMAKE_EXECUTABLE_SUFFIX=${CMAKE_EXECUTABLE_SUFFIX}
CXX=${CADICAL_CXX}
CXXFLAGS=${CADICAL_CXXFLAGS}
LDFLAGS=${CADICAL_LDFLAGS}
BUILD_COMMAND
$(MAKE) -f ${CMAKE_SOURCE_DIR}/src/cadical.mk CMAKE_EXECUTABLE_SUFFIX=${CMAKE_EXECUTABLE_SUFFIX}
CXX=${CADICAL_CXX} CXXFLAGS=${CADICAL_CXXFLAGS} LDFLAGS=${CADICAL_LDFLAGS}
BUILD_IN_SOURCE ON
INSTALL_COMMAND "")
set(CADICAL ${CMAKE_BINARY_DIR}/cadical/cadical${CMAKE_EXECUTABLE_SUFFIX} CACHE FILEPATH "path to cadical binary" FORCE)
INSTALL_COMMAND ""
)
set(
CADICAL
${CMAKE_BINARY_DIR}/cadical/cadical${CMAKE_EXECUTABLE_SUFFIX}
CACHE FILEPATH
"path to cadical binary"
FORCE
)
list(APPEND EXTRA_DEPENDS cadical)
endif()
list(APPEND CL_ARGS -DCADICAL=${CADICAL})
endif()
if (USE_MIMALLOC)
ExternalProject_add(mimalloc
if(USE_MIMALLOC)
ExternalProject_Add(
mimalloc
PREFIX mimalloc
GIT_REPOSITORY https://github.com/microsoft/mimalloc
GIT_TAG v2.2.3
# just download, we compile it as part of each stage as it is small
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND "")
INSTALL_COMMAND ""
)
list(APPEND EXTRA_DEPENDS mimalloc)
endif()
if (NOT STAGE1_PREV_STAGE)
ExternalProject_add(stage0
if(NOT STAGE1_PREV_STAGE)
ExternalProject_Add(
stage0
SOURCE_DIR "${LEAN_SOURCE_DIR}/stage0"
SOURCE_SUBDIR src
BINARY_DIR stage0
@@ -97,38 +106,49 @@ if (NOT STAGE1_PREV_STAGE)
# (however, CI will override this as we need to embed the githash into the stage 1 library built
# by stage 0)
CMAKE_ARGS -DSTAGE=0 -DUSE_GITHASH=OFF ${PLATFORM_ARGS} ${STAGE0_ARGS}
BUILD_ALWAYS ON # cmake doesn't auto-detect changes without a download method
INSTALL_COMMAND "" # skip install
BUILD_ALWAYS
ON # cmake doesn't auto-detect changes without a download method
INSTALL_COMMAND
"" # skip install
DEPENDS ${EXTRA_DEPENDS}
)
list(APPEND EXTRA_DEPENDS stage0)
endif()
ExternalProject_add(stage1
ExternalProject_Add(
stage1
SOURCE_DIR "${LEAN_SOURCE_DIR}"
SOURCE_SUBDIR src
BINARY_DIR stage1
CMAKE_ARGS -DSTAGE=1 -DPREV_STAGE=${CMAKE_BINARY_DIR}/stage0 -DPREV_STAGE_CMAKE_EXECUTABLE_SUFFIX=${STAGE0_CMAKE_EXECUTABLE_SUFFIX} ${CL_ARGS} ${STAGE1_ARGS}
CMAKE_ARGS
-DSTAGE=1 -DPREV_STAGE=${CMAKE_BINARY_DIR}/stage0
-DPREV_STAGE_CMAKE_EXECUTABLE_SUFFIX=${STAGE0_CMAKE_EXECUTABLE_SUFFIX} ${CL_ARGS} ${STAGE1_ARGS}
BUILD_ALWAYS ON
INSTALL_COMMAND ""
DEPENDS ${EXTRA_DEPENDS}
STEP_TARGETS configure
)
ExternalProject_add(stage2
ExternalProject_Add(
stage2
SOURCE_DIR "${LEAN_SOURCE_DIR}"
SOURCE_SUBDIR src
BINARY_DIR stage2
CMAKE_ARGS -DSTAGE=2 -DPREV_STAGE=${CMAKE_BINARY_DIR}/stage1 -DPREV_STAGE_CMAKE_EXECUTABLE_SUFFIX=${CMAKE_EXECUTABLE_SUFFIX} ${CL_ARGS}
CMAKE_ARGS
-DSTAGE=2 -DPREV_STAGE=${CMAKE_BINARY_DIR}/stage1 -DPREV_STAGE_CMAKE_EXECUTABLE_SUFFIX=${CMAKE_EXECUTABLE_SUFFIX}
${CL_ARGS}
BUILD_ALWAYS ON
INSTALL_COMMAND ""
DEPENDS stage1
EXCLUDE_FROM_ALL ON
STEP_TARGETS configure
)
ExternalProject_add(stage3
ExternalProject_Add(
stage3
SOURCE_DIR "${LEAN_SOURCE_DIR}"
SOURCE_SUBDIR src
BINARY_DIR stage3
CMAKE_ARGS -DSTAGE=3 -DPREV_STAGE=${CMAKE_BINARY_DIR}/stage2 -DPREV_STAGE_CMAKE_EXECUTABLE_SUFFIX=${CMAKE_EXECUTABLE_SUFFIX} ${CL_ARGS}
CMAKE_ARGS
-DSTAGE=3 -DPREV_STAGE=${CMAKE_BINARY_DIR}/stage2 -DPREV_STAGE_CMAKE_EXECUTABLE_SUFFIX=${CMAKE_EXECUTABLE_SUFFIX}
${CL_ARGS}
BUILD_ALWAYS ON
INSTALL_COMMAND ""
DEPENDS stage2
@@ -137,24 +157,14 @@ ExternalProject_add(stage3
# targets forwarded to appropriate stages
add_custom_target(update-stage0
COMMAND $(MAKE) -C stage1 update-stage0
DEPENDS stage1)
add_custom_target(update-stage0 COMMAND $(MAKE) -C stage1 update-stage0 DEPENDS stage1)
add_custom_target(update-stage0-commit
COMMAND $(MAKE) -C stage1 update-stage0-commit
DEPENDS stage1)
add_custom_target(update-stage0-commit COMMAND $(MAKE) -C stage1 update-stage0-commit DEPENDS stage1)
add_custom_target(test
COMMAND $(MAKE) -C stage1 test
DEPENDS stage1)
add_custom_target(test COMMAND $(MAKE) -C stage1 test DEPENDS stage1)
add_custom_target(clean-stdlib
COMMAND $(MAKE) -C stage1 clean-stdlib
DEPENDS stage1)
add_custom_target(clean-stdlib COMMAND $(MAKE) -C stage1 clean-stdlib DEPENDS stage1)
install(CODE "execute_process(COMMAND make -C stage1 install)")
add_custom_target(check-stage3
COMMAND diff "stage2/bin/lean" "stage3/bin/lean"
DEPENDS stage3)
add_custom_target(check-stage3 COMMAND diff "stage2/bin/lean" "stage3/bin/lean" DEPENDS stage3)

13
script/fmt Executable file
View File

@@ -0,0 +1,13 @@
#!/usr/bin/env bash
set -euo pipefail
# This script expects to be run from the repo root.
# Format cmake files
find -regex '.*/CMakeLists\.txt\(\.in\)?\|.*\.cmake\(\.in\)?' \
! -path './build/*' \
! -path "./stage0/*" \
-exec \
uvx gersemi --in-place --line-length 120 --indent 2 \
--definitions src/cmake/Modules/ src/CMakeLists.txt \
-- {} +

File diff suppressed because it is too large Load Diff

View File

@@ -469,13 +469,13 @@ namespace EStateM
instance : LawfulMonad (EStateM ε σ) := .mk'
(id_map := fun x => funext <| fun s => by
dsimp only [EStateM.instMonad, EStateM.map]
simp only [Functor.map, EStateM.map]
match x s with
| .ok _ _ => rfl
| .error _ _ => rfl)
(pure_bind := fun _ _ => by rfl)
(bind_assoc := fun x _ _ => funext <| fun s => by
dsimp only [EStateM.instMonad, EStateM.bind]
simp only [bind, EStateM.bind]
match x s with
| .ok _ _ => rfl
| .error _ _ => rfl)

View File

@@ -932,6 +932,14 @@ noncomputable def HEq.ndrec.{u1, u2} {α : Sort u2} {a : α} {motive : {β : Sor
noncomputable def HEq.ndrecOn.{u1, u2} {α : Sort u2} {a : α} {motive : {β : Sort u2} β Sort u1} {β : Sort u2} {b : β} (h : a b) (m : motive a) : motive b :=
h.rec m
/-- `HEq.ndrec` specialized to homogeneous heterogeneous equality -/
noncomputable def HEq.homo_ndrec.{u1, u2} {α : Sort u2} {a : α} {motive : α Sort u1} (m : motive a) {b : α} (h : a b) : motive b :=
(eq_of_heq h).ndrec m
/-- `HEq.ndrec` specialized to homogeneous heterogeneous equality, symmetric variant -/
noncomputable def HEq.homo_ndrec_symm.{u1, u2} {α : Sort u2} {a : α} {motive : α Sort u1} (m : motive a) {b : α} (h : b a) : motive b :=
(eq_of_heq h).ndrec_symm m
/-- `HEq.ndrec` variant -/
noncomputable def HEq.elim {α : Sort u} {a : α} {p : α Sort v} {b : α} (h₁ : a b) (h₂ : p a) : p b :=
eq_of_heq h₁ h₂
@@ -1478,6 +1486,29 @@ def Prod.map {α₁ : Type u₁} {α₂ : Type u₂} {β₁ : Type v₁} {β₂
/-! # Dependent products -/
instance {α : Type u} {β : α Type v} [h₁ : DecidableEq α] [h₂ : a, DecidableEq (β a)] :
DecidableEq (Sigma β)
| a₁, b₁, a₂, b₂ =>
match a₁, b₁, a₂, b₂, h₁ a₁ a₂ with
| _, b₁, _, b₂, isTrue (Eq.refl _) =>
match b₁, b₂, h₂ _ b₁ b₂ with
| _, _, isTrue (Eq.refl _) => isTrue rfl
| _, _, isFalse n => isFalse fun h
Sigma.noConfusion rfl .rfl (heq_of_eq h) fun _ e₂ n (eq_of_heq e₂)
| _, _, _, _, isFalse n => isFalse fun h
Sigma.noConfusion rfl .rfl (heq_of_eq h) fun e₁ _ n (eq_of_heq e₁)
instance {α : Sort u} {β : α Sort v} [h₁ : DecidableEq α] [h₂ : a, DecidableEq (β a)] : DecidableEq (PSigma β)
| a₁, b₁, a₂, b₂ =>
match a₁, b₁, a₂, b₂, h₁ a₁ a₂ with
| _, b₁, _, b₂, isTrue (Eq.refl _) =>
match b₁, b₂, h₂ _ b₁ b₂ with
| _, _, isTrue (Eq.refl _) => isTrue rfl
| _, _, isFalse n => isFalse fun h
PSigma.noConfusion rfl .rfl (heq_of_eq h) fun _ e₂ n (eq_of_heq e₂)
| _, _, _, _, isFalse n => isFalse fun h
PSigma.noConfusion rfl .rfl (heq_of_eq h) fun e₁ _ n (eq_of_heq e₁)
theorem Exists.of_psigma_prop {α : Sort u} {p : α Prop} : (PSigma (fun x => p x)) Exists (fun x => p x)
| x, hx => x, hx

View File

@@ -30,3 +30,4 @@ public import Init.Data.Array.Erase
public import Init.Data.Array.Zip
public import Init.Data.Array.InsertIdx
public import Init.Data.Array.Extract
public import Init.Data.Array.MinMax

View File

@@ -3065,6 +3065,18 @@ theorem foldl_eq_foldlM {f : β → α → β} {b} {xs : Array α} {start stop :
theorem foldr_eq_foldrM {f : α β β} {b} {xs : Array α} {start stop : Nat} :
xs.foldr f b start stop = (xs.foldrM (m := Id) (pure <| f · ·) b start stop).run := rfl
public theorem foldl_eq_foldl_extract {xs : Array α} {f : β α β} {init : β} :
xs.foldl (init := init) (start := start) (stop := stop) f =
(xs.extract start stop).foldl (init := init) f := by
simp only [foldl_eq_foldlM]
rw [foldlM_start_stop]
public theorem foldr_eq_foldr_extract {xs : Array α} {f : α β β} {init : β} :
xs.foldr (init := init) (start := start) (stop := stop) f =
(xs.extract stop start).foldr (init := init) f := by
simp only [foldr_eq_foldrM]
rw [foldrM_start_stop]
@[simp] theorem id_run_foldlM {f : β α Id β} {b} {xs : Array α} {start stop : Nat} :
Id.run (xs.foldlM f b start stop) = xs.foldl (f · · |>.run) b start stop := rfl

View File

@@ -0,0 +1,401 @@
/-
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Reichert
-/
module
prelude
public import Init.Data.Array.Bootstrap
public import Init.Data.Array.Lemmas
public import Init.Data.Array.DecidableEq
import Init.Data.List.MinMax
import Init.Data.List.ToArray
namespace Array
/-! ## Minima and maxima -/
/-! ### min -/
/--
Returns the smallest element of a non-empty array.
Examples:
* `#[4].min (by decide) = 4`
* `#[1, 4, 2, 10, 6].min (by decide) = 1`
-/
public protected def min [Min α] (arr : Array α) (h : arr #[]) : α :=
haveI : arr.size > 0 := by simp [Array.size_pos_iff, h]
arr.foldl min arr[0] (start := 1)
/-! ### min? -/
/--
Returns the smallest element of the array if it is not empty, or `none` if it is empty.
Examples:
* `#[].min? = none`
* `#[4].min? = some 4`
* `#[1, 4, 2, 10, 6].min? = some 1`
-/
public protected def min? [Min α] (arr : Array α) : Option α :=
if h : arr #[] then
some (arr.min h)
else
none
/-! ### max -/
/--
Returns the largest element of a non-empty array.
Examples:
* `#[4].max (by decide) = 4`
* `#[1, 4, 2, 10, 6].max (by decide) = 10`
-/
public protected def max [Max α] (arr : Array α) (h : arr #[]) : α :=
haveI : arr.size > 0 := by simp [Array.size_pos_iff, h]
arr.foldl max arr[0] (start := 1)
/-! ### max? -/
/--
Returns the largest element of the array if it is not empty, or `none` if it is empty.
Examples:
* `#[].max? = none`
* `#[4].max? = some 4`
* `#[1, 4, 2, 10, 6].max? = some 10`
-/
public protected def max? [Max α] (arr : Array α) : Option α :=
if h : arr #[] then
some (arr.max h)
else
none
/-! ### Compatibility with `List` -/
@[simp, grind =]
public theorem _root_.List.min_toArray [Min α] {l : List α} {h} :
l.toArray.min h = l.min (by simpa [List.ne_nil_iff_length_pos] using h) := by
let h' : l [] := by simpa [List.ne_nil_iff_length_pos] using h
change l.toArray.min h = l.min h'
rw [Array.min]
· induction l
· contradiction
· rename_i x xs
simp only [List.getElem_toArray, List.getElem_cons_zero, List.size_toArray, List.length_cons]
rw [List.toArray_cons, foldl_eq_foldl_extract]
rw [ Array.foldl_toList, Array.toList_extract, List.extract_eq_drop_take]
simp [List.min]
public theorem _root_.List.min_eq_min_toArray [Min α] {l : List α} {h} :
l.min h = l.toArray.min (by simpa [List.ne_nil_iff_length_pos] using h) := by
simp
@[simp, grind =]
public theorem min_toList [Min α] {xs : Array α} {h} :
xs.toList.min h = xs.min (by simpa [List.ne_nil_iff_length_pos] using h) := by
cases xs; simp
public theorem min_eq_min_toList [Min α] {xs : Array α} {h} :
xs.min h = xs.toList.min (by simpa [List.ne_nil_iff_length_pos] using h) := by
simp
@[simp, grind =]
public theorem _root_.List.min?_toArray [Min α] {l : List α} :
l.toArray.min? = l.min? := by
rw [Array.min?]
split
· simp [List.min_toArray, List.min_eq_get_min?, - List.get_min?]
· simp_all
@[simp, grind =]
public theorem min?_toList [Min α] {xs : Array α} :
xs.toList.min? = xs.min? := by
cases xs; simp
@[simp, grind =]
public theorem _root_.List.max_toArray [Max α] {l : List α} {h} :
l.toArray.max h = l.max (by simpa [List.ne_nil_iff_length_pos] using h) := by
let h' : l [] := by simpa [List.ne_nil_iff_length_pos] using h
change l.toArray.max h = l.max h'
rw [Array.max]
· induction l
· contradiction
· rename_i x xs
simp only [List.getElem_toArray, List.getElem_cons_zero, List.size_toArray, List.length_cons]
rw [List.toArray_cons, foldl_eq_foldl_extract]
rw [ Array.foldl_toList, Array.toList_extract, List.extract_eq_drop_take]
simp [List.max]
public theorem _root_.List.max_eq_max_toArray [Max α] {l : List α} {h} :
l.max h = l.toArray.max (by simpa [List.ne_nil_iff_length_pos] using h) := by
simp
@[simp, grind =]
public theorem max_toList [Max α] {xs : Array α} {h} :
xs.toList.max h = xs.max (by simpa [List.ne_nil_iff_length_pos] using h) := by
cases xs; simp
public theorem max_eq_max_toList [Max α] {xs : Array α} {h} :
xs.max h = xs.toList.max (by simpa [List.ne_nil_iff_length_pos] using h) := by
simp
@[simp, grind =]
public theorem _root_.List.max?_toArray [Max α] {l : List α} :
l.toArray.max? = l.max? := by
rw [Array.max?]
split
· simp [List.max_toArray, List.max_eq_get_max?, - List.get_max?]
· simp_all
@[simp, grind =]
public theorem max?_toList [Max α] {xs : Array α} :
xs.toList.max? = xs.max? := by
cases xs; simp
/-! ### Lemmas about `min?` -/
@[simp, grind =]
public theorem min?_empty [Min α] : (#[] : Array α).min? = none :=
(rfl)
@[simp, grind =]
public theorem min?_singleton [Min α] {x : α} : #[x].min? = some x :=
(rfl)
-- We don't put `@[simp]` on `min?_singleton_append'`,
-- because the definition in terms of `foldl` is not useful for proofs.
public theorem min?_singleton_append' [Min α] {xs : Array α} :
(#[x] ++ xs).min? = some (xs.foldl min x) := by
simp [ min?_toList, toList_append, List.min?]
@[simp]
public theorem min?_singleton_append [Min α] [Std.Associative (min : α α α)] {xs : Array α} :
(#[x] ++ xs).min? = some (xs.min?.elim x (min x)) := by
simp [ min?_toList, toList_append, List.min?_cons]
@[simp, grind =]
public theorem min?_eq_none_iff {xs : Array α} [Min α] : xs.min? = none xs = #[] := by
rcases xs with l
simp
@[simp, grind =]
public theorem isSome_min?_iff {xs : Array α} [Min α] : xs.min?.isSome xs #[] := by
rcases xs with l
simp
@[grind .]
public theorem isSome_min?_of_mem {xs : Array α} [Min α] {a : α} (h : a xs) :
xs.min?.isSome := by
rw [ min?_toList]
apply List.isSome_min?_of_mem (a := a)
simpa
public theorem isSome_min?_of_ne_empty [Min α] (xs : Array α) (h : xs #[]) : xs.min?.isSome := by
rw [ min?_toList]
apply List.isSome_min?_of_ne_nil
simpa
public theorem min?_mem [Min α] [Std.MinEqOr α] (xs : Array α) (h : xs.min? = some a) : a xs := by
rw [ min?_toList] at h
simpa using List.min?_mem h
public theorem le_min?_iff [Min α] [LE α] [Std.LawfulOrderInf α] :
{xs : Array α} xs.min? = some a {x}, x a b, b xs x b := by
intro xs h x
simp only [ min?_toList] at h
simpa using List.le_min?_iff h
public theorem min?_eq_some_iff [Min α] [LE α] {xs : Array α} [Std.IsLinearOrder α]
[Std.LawfulOrderMin α] : xs.min? = some a a xs b, b xs a b := by
rcases xs with l
simpa using List.min?_eq_some_iff
public theorem min?_replicate [Min α] [Std.IdempotentOp (min : α α α)] {n : Nat} {a : α} :
(replicate n a).min? = if n = 0 then none else some a := by
rw [ List.toArray_replicate, List.min?_toArray, List.min?_replicate]
@[simp, grind =]
public theorem min?_replicate_of_pos [Min α] [Std.MinEqOr α] {n : Nat} {a : α} (h : 0 < n) :
(replicate n a).min? = some a := by
simp [min?_replicate, Nat.ne_of_gt h]
public theorem foldl_min [Min α] [Std.IdempotentOp (min : α α α)]
[Std.Associative (min : α α α)] {xs : Array α} {a : α} :
xs.foldl (init := a) min = min a (xs.min?.getD a) := by
rcases xs with l
simp [List.foldl_min]
/-! ### Lemmas about `max?` -/
@[simp, grind =]
public theorem max?_empty [Max α] : (#[] : Array α).max? = none :=
(rfl)
@[simp, grind =]
public theorem max?_singleton [Max α] {x : α} : #[x].max? = some x :=
(rfl)
-- We don't put `@[simp]` on `max?_singleton_append'`,
-- because the definition in terms of `foldl` is not useful for proofs.
public theorem max?_singleton_append' [Max α] {xs : Array α} : (#[x] ++ xs).max? = some (xs.foldl max x) := by
simp [ max?_toList, toList_append, List.max?]
@[simp]
public theorem max?_singleton_append [Max α] [Std.Associative (max : α α α)] {xs : Array α} :
(#[x] ++ xs).max? = some (xs.max?.elim x (max x)) := by
simp [ max?_toList, toList_append, List.max?_cons]
@[simp, grind =]
public theorem max?_eq_none_iff {xs : Array α} [Max α] : xs.max? = none xs = #[] := by
rcases xs with l
simp
@[simp, grind =]
public theorem isSome_max?_iff {xs : Array α} [Max α] : xs.max?.isSome xs #[] := by
rcases xs with l
simp
@[grind .]
public theorem isSome_max?_of_mem {xs : Array α} [Max α] {a : α} (h : a xs) :
xs.max?.isSome := by
rw [ max?_toList]
apply List.isSome_max?_of_mem (a := a)
simpa
public theorem isSome_max?_of_ne_empty [Max α] (xs : Array α) (h : xs #[]) : xs.max?.isSome := by
rw [ max?_toList]
apply List.isSome_max?_of_ne_nil
simpa
public theorem max?_mem [Max α] [Std.MaxEqOr α] (xs : Array α) (h : xs.max? = some a) : a xs := by
rw [ max?_toList] at h
simpa using List.max?_mem h
public theorem max?_le_iff [Max α] [LE α] [Std.LawfulOrderSup α] :
{xs : Array α} xs.max? = some a {x}, a x b, b xs b x := by
intro xs h x
simp only [ max?_toList] at h
simpa using List.max?_le_iff h
public theorem max?_eq_some_iff [Max α] [LE α] {xs : Array α} [Std.IsLinearOrder α]
[Std.LawfulOrderMax α] : xs.max? = some a a xs b, b xs b a := by
rcases xs with l
simpa using List.max?_eq_some_iff
public theorem max?_replicate [Max α] [Std.IdempotentOp (max : α α α)] {n : Nat} {a : α} :
(replicate n a).max? = if n = 0 then none else some a := by
rw [ List.toArray_replicate, List.max?_toArray, List.max?_replicate]
@[simp, grind =]
public theorem max?_replicate_of_pos [Max α] [Std.MaxEqOr α] {n : Nat} {a : α} (h : 0 < n) :
(replicate n a).max? = some a := by
simp [max?_replicate, Nat.ne_of_gt h]
public theorem foldl_max [Max α] [Std.IdempotentOp (max : α α α)] [Std.Associative (max : α α α)]
{xs : Array α} {a : α} : xs.foldl (init := a) max = max a (xs.max?.getD a) := by
rcases xs with l
simp [List.foldl_max]
/-! ### Lemmas about `min` -/
@[simp, grind =]
theorem min_singleton [Min α] {x : α} :
#[x].min (ne_empty_of_size_eq_add_one rfl) = x := by
(rfl)
public theorem min?_eq_some_min [Min α] : {xs : Array α} (h : xs #[])
xs.min? = some (xs.min h)
| a::as, _ => by simp [Array.min, Array.min?]
public theorem min_eq_get_min? [Min α] : (xs : Array α) (h : xs #[])
xs.min h = xs.min?.get (xs.isSome_min?_of_ne_empty h)
| a::as, _ => by simp [Array.min, Array.min?]
@[simp, grind =]
public theorem get_min? [Min α] {xs : Array α} {h : xs.min?.isSome} :
xs.min?.get h = xs.min (isSome_min?_iff.mp h) := by
simp [min?_eq_some_min (isSome_min?_iff.mp h)]
@[grind .]
public theorem min_mem [Min α] [Std.MinEqOr α] {xs : Array α} (h : xs #[]) : xs.min h xs :=
xs.min?_mem (min?_eq_some_min h)
@[grind .]
public theorem min_le_of_mem [Min α] [LE α] [Std.IsLinearOrder α] [Std.LawfulOrderMin α]
{xs : Array α} {a : α} (ha : a xs) :
xs.min (ne_empty_of_mem ha) a :=
(Array.min?_eq_some_iff.mp (min?_eq_some_min (ne_empty_of_mem ha))).right a ha
public protected theorem le_min_iff [Min α] [LE α] [Std.LawfulOrderInf α]
{xs : Array α} (h : xs #[]) : {x}, x xs.min h b, b xs x b :=
le_min?_iff (min?_eq_some_min h)
public theorem min_eq_iff [Min α] [LE α] {xs : Array α} [Std.IsLinearOrder α] [Std.LawfulOrderMin α]
(h : xs #[]) : xs.min h = a a xs b, b xs a b := by
simpa [min?_eq_some_min h] using (min?_eq_some_iff (xs := xs))
@[simp, grind =]
public theorem min_replicate [Min α] [Std.MinEqOr α] {n : Nat} {a : α} (h : (replicate n a) #[]) :
(replicate n a).min h = a := by
have n_pos : 0 < n := by simpa [Nat.ne_zero_iff_zero_lt] using h
simpa [min?_eq_some_min h] using (min?_replicate_of_pos (a := a) n_pos)
public theorem foldl_min_eq_min [Min α] [Std.IdempotentOp (min : α α α)]
[Std.Associative (min : α α α)] {xs : Array α} (h : xs #[]) {a : α} :
xs.foldl min a = min a (xs.min h) := by
simpa [min?_eq_some_min h] using foldl_min (xs := xs)
/-! ### Lemmas about `max` -/
@[simp, grind =]
theorem max_singleton [Max α] {x : α} :
#[x].max (ne_empty_of_size_eq_add_one rfl) = x := by
(rfl)
public theorem max?_eq_some_max [Max α] : {xs : Array α} (h : xs #[])
xs.max? = some (xs.max h)
| a::as, _ => by simp [Array.max, Array.max?]
public theorem max_eq_get_max? [Max α] : (xs : Array α) (h : xs #[])
xs.max h = xs.max?.get (xs.isSome_max?_of_ne_empty h)
| a::as, _ => by simp [Array.max, Array.max?]
@[simp, grind =]
public theorem get_max? [Max α] {xs : Array α} {h : xs.max?.isSome} :
xs.max?.get h = xs.max (isSome_max?_iff.mp h) := by
simp [max?_eq_some_max (isSome_max?_iff.mp h)]
@[grind .]
public theorem max_mem [Max α] [Std.MaxEqOr α] {xs : Array α} (h : xs #[]) : xs.max h xs :=
xs.max?_mem (max?_eq_some_max h)
public protected theorem max_le_iff [Max α] [LE α] [Std.LawfulOrderSup α]
{xs : Array α} (h : xs #[]) : {x}, xs.max h x b, b xs b x :=
max?_le_iff (max?_eq_some_max h)
public theorem max_eq_iff [Max α] [LE α] {xs : Array α} [Std.IsLinearOrder α] [Std.LawfulOrderMax α]
(h : xs #[]) : xs.max h = a a xs b, b xs b a := by
simpa [max?_eq_some_max h] using (max?_eq_some_iff (xs := xs))
@[grind .]
public theorem le_max_of_mem [Max α] [LE α] [Std.IsLinearOrder α] [Std.LawfulOrderMax α]
{xs : Array α} {a : α} (ha : a xs) :
a xs.max (ne_empty_of_mem ha) :=
(Array.max?_eq_some_iff.mp (max?_eq_some_max (ne_empty_of_mem ha))).right a ha
@[simp, grind =]
public theorem max_replicate [Max α] [Std.MaxEqOr α] {n : Nat} {a : α} (h : (replicate n a) #[]) :
(replicate n a).max h = a := by
have n_pos : 0 < n := by simpa [Nat.ne_zero_iff_zero_lt] using h
simpa [max?_eq_some_max h] using (max?_replicate_of_pos (a := a) n_pos)
public theorem foldl_max_eq_max [Max α] [Std.IdempotentOp (max : α α α)]
[Std.Associative (max : α α α)] {xs : Array α} (h : xs #[]) {a : α} :
xs.foldl max a = max a (xs.max h) := by
simpa [max?_eq_some_max h] using foldl_max (xs := xs)
end Array

View File

@@ -11,6 +11,8 @@ public import Init.Grind.Ordered.Ring
/-! # Internal `grind` algebra instances for `Dyadic`. -/
@[expose] public section
open Lean.Grind
namespace Dyadic

View File

@@ -4,7 +4,9 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Kim Morrison
-/
module
prelude
public import Init.Data.Dyadic.Basic
import Init.Data.Dyadic.Round
import Init.Grind.Ordered.Ring
@@ -12,6 +14,8 @@ import Init.Grind.Ordered.Ring
# Inversion for dyadic numbers
-/
@[expose] public section
namespace Dyadic
/--

View File

@@ -7,7 +7,7 @@ module
prelude
public import Init.Data.Dyadic.Basic
import all Init.Data.Dyadic.Instances
import Init.Data.Dyadic.Instances
import Init.Grind.Ordered.Rat
import Init.Grind.Ordered.Field

View File

@@ -1153,6 +1153,15 @@ theorem ediv_le_iff_le_mul {k x y : Int} (h : 0 < k) : x / k ≤ y ↔ x < y * k
rw [Int.le_iff_lt_add_one, Int.ediv_lt_iff_lt_mul h, Int.add_mul]
omega
theorem le_mul_iff_le_left {x y z : Int} (hz : 0 < z) :
x y * z (x + z - 1) / z y := by
rw [Int.ediv_le_iff_le_mul hz]
omega
theorem le_mul_iff_le_right {x y z : Int} (hy : 0 < y) :
x y * z (x + y - 1) / y z := by
rw [ le_mul_iff_le_left hy, Int.mul_comm]
protected theorem le_mul_of_ediv_le {a b c : Int} (H1 : 0 b) (H2 : b a) (H3 : a / b c) :
a c * b := by
rw [ Int.ediv_mul_cancel H2]; exact Int.mul_le_mul_of_nonneg_right H3 H1
@@ -1206,6 +1215,11 @@ theorem add_ediv {a b c : Int} (h : c ≠ 0) :
protected theorem ediv_le_ediv {a b c : Int} (H : 0 < c) (H' : a b) : a / c b / c :=
Int.le_ediv_of_mul_le H (Int.le_trans (Int.ediv_mul_le _ (Int.ne_of_gt H)) H')
theorem ediv_add_ediv_le_add_ediv {x y z : Int} (hz : 0 < z) :
x / z + y / z (x + y) / z := by
rw [Int.le_ediv_iff_mul_le hz, Int.add_mul]
apply Int.add_le_add <;> apply Int.ediv_mul_le <;> omega
/-- If `n > 0` then `m` is not divisible by `n` iff it is between `n * k` and `n * (k + 1)`
for some `k`. -/
theorem not_dvd_iff_lt_mul_succ (m : Int) (hn : 0 < n) :
@@ -1783,12 +1797,12 @@ theorem ediv_lt_ediv_iff_of_dvd_of_neg_of_neg {a b c d : Int} (hb : b < 0) (hd :
theorem ediv_lt_ediv_of_lt {a b c : Int} (h : a < b) (hcb : c b) (hc : 0 < c) :
a / c < b / c :=
Int.lt_ediv_of_mul_lt (Int.le_of_lt hc) hcb
Int.lt_ediv_of_mul_lt (Int.le_of_lt hc) hcb
(Int.lt_of_le_of_lt (Int.ediv_mul_le _ (Int.ne_of_gt hc)) h)
theorem ediv_lt_ediv_of_lt_of_neg {a b c : Int} (h : b < a) (hca : c a) (hc : c < 0) :
a / c < b / c :=
(Int.ediv_lt_iff_of_dvd_of_neg hc hca).2
(Int.ediv_lt_iff_of_dvd_of_neg hc hca).2
(Int.lt_of_le_of_lt (Int.mul_ediv_self_le (Int.ne_of_lt hc)) h)
/-! ### `tdiv` and ordering -/

View File

@@ -94,8 +94,9 @@ By convention, the monadic iterator associated with an object can be obtained vi
For example, `List.iterM IO` creates an iterator over a list in the monad `IO`.
See `Init.Data.Iterators.Consumers` for ways to use an iterator. For example, `it.toList` will
convert a provably finite iterator `it` into a list and `it.allowNontermination.toList` will
do so even if finiteness cannot be proved. It is also always possible to manually iterate using
convert an iterator `it` into a list and `it.ensureTermination.toList` guarantees that this
operation will terminate, given a proof that the iterator is finite.
It is also always possible to manually iterate using
`it.step`, relying on the termination measures `it.finitelyManySteps` and `it.finitelyManySkips`.
See `Iter` for a more convenient interface in case that no monadic effects are needed (`m = Id`).
@@ -139,8 +140,9 @@ By convention, the monadic iterator associated with an object can be obtained vi
For example, `List.iterM IO` creates an iterator over a list in the monad `IO`.
See `Init.Data.Iterators.Consumers` for ways to use an iterator. For example, `it.toList` will
convert a provably finite iterator `it` into a list and `it.allowNontermination.toList` will
do so even if finiteness cannot be proved. It is also always possible to manually iterate using
convert an iterator `it` into a list and `it.ensureTermination.toList` guarantees that this
operation will terminate, given a proof that the iterator is finite.
It is also always possible to manually iterate using
`it.step`, relying on the termination measures `it.finitelyManySteps` and `it.finitelyManySkips`.
See `IterM` for iterators that operate in a monad.
@@ -754,8 +756,8 @@ def IterM.finitelyManySteps {α : Type w} {m : Type w → Type w'} {β : Type w}
it
/--
Termination measure to be used in well-founded recursive functions recursing over a finite iterator
(see also `Finite`).
Termination measure to be used in recursive functions built with `WellFounded.extrinsicFix`
recursing over a finite iterator without requiring a proof of finiteness (see also `Finite`).
-/
@[expose]
def IterM.finitelyManySteps! {α : Type w} {m : Type w Type w'} {β : Type w} [Iterator α m β]
@@ -796,6 +798,11 @@ def Iter.finitelyManySteps {α : Type w} {β : Type w} [Iterator α Id β] [Iter
(it : Iter (α := α) β) : IterM.TerminationMeasures.Finite α Id :=
it.toIterM.finitelyManySteps
@[inherit_doc IterM.finitelyManySteps!, expose]
def Iter.finitelyManySteps! {α : Type w} {β : Type w} [Iterator α Id β]
(it : Iter (α := α) β) : IterM.TerminationMeasures.Finite α Id :=
it.toIterM.finitelyManySteps!
/--
This theorem is used by a `decreasing_trivial` extension. It powers automatic termination proofs
with `IterM.finitelyManySteps`.
@@ -902,6 +909,16 @@ def IterM.finitelyManySkips {α : Type w} {m : Type w → Type w'} {β : Type w}
[Iterators.Productive α m] (it : IterM (α := α) m β) : IterM.TerminationMeasures.Productive α m :=
it
/--
Termination measure to be used in recursive functions built with `WellFounded.extrinsicFix`
recursing over a productive iterator without requiring a proof of productiveness
(see also `Productive`).
-/
@[expose]
def IterM.finitelyManySkips! {α : Type w} {m : Type w Type w'} {β : Type w} [Iterator α m β]
(it : IterM (α := α) m β) : IterM.TerminationMeasures.Productive α m :=
it
/--
This theorem is used by a `decreasing_trivial` extension. It powers automatic termination proofs
with `IterM.finitelyManySkips`.
@@ -922,6 +939,11 @@ def Iter.finitelyManySkips {α : Type w} {β : Type w} [Iterator α Id β] [Iter
(it : Iter (α := α) β) : IterM.TerminationMeasures.Productive α Id :=
it.toIterM.finitelyManySkips
@[inherit_doc IterM.finitelyManySkips!, expose]
def Iter.finitelyManySkips! {α : Type w} {β : Type w} [Iterator α Id β]
(it : Iter (α := α) β) : IterM.TerminationMeasures.Productive α Id :=
it.toIterM.finitelyManySkips!
/--
This theorem is used by a `decreasing_trivial` extension. It powers automatic termination proofs
with `Iter.finitelyManySkips`.

View File

@@ -21,21 +21,70 @@ If possible, takes `n` steps with the iterator `it` and
returns the `n`-th emitted value, or `none` if `it` finished
before emitting `n` values.
This function requires a `Productive` instance proving that the iterator will always emit a value
after a finite number of skips. If the iterator is not productive or such an instance is not
available, consider using `it.allowNontermination.atIdxSlow?` instead of `it.atIdxSlow?`. However,
it is not possible to formally verify the behavior of the partial variant.
If the iterator is not productive, this function might run forever in an endless loop of iterator
steps. The variant `it.ensureTermination.atIdxSlow?` is guaranteed to terminate after finitely many
steps.
-/
@[specialize]
def Iter.atIdxSlow? {α β} [Iterator α Id β] [Productive α Id]
def Iter.atIdxSlow? {α β} [Iterator α Id β]
(n : Nat) (it : Iter (α := α) β) : Option β :=
match it.step with
| .yield it' out _ =>
match n with
| 0 => some out
| k + 1 => it'.atIdxSlow? k
| .skip it' _ => it'.atIdxSlow? n
| .done _ => none
WellFounded.extrinsicFix₂ (C₂ := fun _ _ => Option β) (α := Iter (α := α) β) (β := fun _ => Nat)
(InvImage
(Prod.Lex WellFoundedRelation.rel IterM.TerminationMeasures.Productive.Rel)
(fun p => (p.2, p.1.finitelyManySkips!)))
(fun it n recur =>
match it.step with
| .yield it' out _ =>
match n with
| 0 => some out
| k + 1 => recur it' k (by decreasing_tactic)
| .skip it' _ => recur it' n (by decreasing_tactic)
| .done _ => none) it n
-- We provide the functional induction principle by hand because `atIdxSlow?` is implemented using
-- `extrinsicFix₂` and not using well-founded recursion.
/-
An induction principle for `Iter.atIdxSlow?`.
This lemma provides a functional induction principle for reasoning about `Iter.atIdxSlow? n it`.
The induction follows the structure of iterator steps.
- base case: when we reach the desired index (`n = 0`) and get a `.yield` step
- inductive case: when we have a `.yield` step but need to continue (`n > 0`)
- skip case: when we encounter a `.skip` step and continue with the same index
- done case: when the iterator is exhausted and we return `none`
-/
theorem Iter.atIdxSlow?.induct_unfolding {α β : Type u} [Iterator α Id β] [Productive α Id]
(motive : Nat Iter β Option β Prop)
-- Base case: we have reached index 0 and found a value
(yield_zero : (it it' : Iter (α := α) β) (out : β) (property : it.IsPlausibleStep (IterStep.yield it' out)),
it.step = IterStep.yield it' out, property motive 0 it (some out))
-- Inductive case: we have a yield but need to continue to a higher index
(yield_succ : (it it' : Iter (α := α) β) (out : β) (property : it.IsPlausibleStep (IterStep.yield it' out)),
it.step = IterStep.yield it' out, property
(k : Nat), motive k it' (Iter.atIdxSlow? k it') motive k.succ it (Iter.atIdxSlow? k it'))
-- Skip case: we encounter a skip and continue with the same index
(skip_case : (n : Nat) (it it' : Iter β) (property : it.IsPlausibleStep (IterStep.skip it')),
it.step = IterStep.skip it', property
motive n it' (Iter.atIdxSlow? n it') motive n it (Iter.atIdxSlow? n it'))
-- Done case: the iterator is exhausted, return none
(done_case : (n : Nat) (it : Iter β) (property : it.IsPlausibleStep IterStep.done),
it.step = IterStep.done, property motive n it none)
-- The conclusion: the property holds for all indices and iterators
(n : Nat) (it : Iter β) : motive n it (Iter.atIdxSlow? n it) := by
simp only [atIdxSlow?] at *
rw [WellFounded.extrinsicFix₂_eq_apply]
· split
· split
· apply yield_zero <;> assumption
· apply yield_succ
all_goals try assumption
apply Iter.atIdxSlow?.induct_unfolding <;> assumption
· apply skip_case
all_goals try assumption
apply Iter.atIdxSlow?.induct_unfolding <;> assumption
· apply done_case <;> assumption
· exact InvImage.wf _ WellFoundedRelation.wf
termination_by (n, it.finitelyManySkips)
/--
@@ -43,22 +92,21 @@ If possible, takes `n` steps with the iterator `it` and
returns the `n`-th emitted value, or `none` if `it` finished
before emitting `n` values.
This is a partial, potentially nonterminating, function. It is not possible to formally verify
its behavior. If the iterator has a `Productive` instance, consider using `Iter.atIdxSlow?` instead.
This variant terminates after finitely many steps and requires a proof that the iterator is
productive. If such a proof is not available, consider using `Iter.toArray`.
-/
@[specialize]
partial def Iter.Partial.atIdxSlow? {α β} [Iterator α Id β] [Monad Id]
(n : Nat) (it : Iter.Partial (α := α) β) : Option β := do
match it.it.step with
| .yield it' out _ =>
match n with
| 0 => some out
| k + 1 => (it' : Iter.Partial (α := α) β).atIdxSlow? k
| .skip it' _ => (it' : Iter.Partial (α := α) β).atIdxSlow? n
| .done _ => none
@[inline]
def Iter.Total.atIdxSlow? {α β} [Iterator α Id β] [Productive α Id]
(n : Nat) (it : Iter.Total (α := α) β) : Option β :=
it.it.atIdxSlow? n
@[inline, inherit_doc Iter.atIdxSlow?, deprecated Iter.atIdxSlow? (since := "2026-01-28")]
def Iter.Partial.atIdxSlow? {α β} [Iterator α Id β]
(n : Nat) (it : Iter.Partial (α := α) β) : Option β :=
it.it.atIdxSlow? n
@[always_inline, inline, inherit_doc IterM.atIdx?]
def Iter.atIdx? {α β} [Iterator α Id β] [Productive α Id] [IteratorAccess α Id]
def Iter.atIdx? {α β} [Iterator α Id β] [IteratorAccess α Id]
(n : Nat) (it : Iter (α := α) β) : Option β :=
match (IteratorAccess.nextAtIdx? it.toIterM n).run.val with
| .yield _ out => some out

View File

@@ -667,6 +667,42 @@ def Iter.Total.first? {α β : Type w} [Iterator α Id β] [IteratorLoop α Id I
(it : Iter.Total (α := α) β) : Option β :=
it.it.first?
/--
Returns `true` if the iterator yields no values.
`O(|it|)` since the iterator may skip an unknown number of times before returning a result.
Short-circuits upon encountering the first result. Only the first element of `it` is examined.
If the iterator is not productive, this function might run forever. The variant
`it.ensureTermination.isEmpty` always terminates after finitely many steps.
Examples:
* `[].iter.isEmpty = true`
* `[1].iter.isEmpty = false`
-/
@[inline]
def Iter.isEmpty {α β : Type w} [Iterator α Id β] [IteratorLoop α Id Id]
(it : Iter (α := α) β) : Bool :=
it.toIterM.isEmpty.run.down
/--
Returns `true` if the iterator yields no values.
`O(|it|)` since the iterator may skip an unknown number of times before returning a result.
Short-circuits upon encountering the first result. Only the first element of `it` is examined.
This variant terminates after finitely many steps and requires a proof that the iterator is
productive. If such a proof is not available, consider using `Iter.isEmpty`.
Examples:
* `[].iter.isEmpty = true`
* `[1].iter.isEmpty = false`
-/
@[inline]
def Iter.Total.isEmpty {α β : Type w} [Iterator α Id β] [IteratorLoop α Id Id] [Productive α Id]
(it : Iter.Total (α := α) β) : Bool :=
it.it.isEmpty
/--
Steps through the whole iterator, counting the number of outputs emitted.
@@ -675,9 +711,15 @@ Steps through the whole iterator, counting the number of outputs emitted.
This function's runtime is linear in the number of steps taken by the iterator.
-/
@[always_inline, inline, expose]
def Iter.count {α : Type w} {β : Type w} [Iterator α Id β] [IteratorLoop α Id Id]
def Iter.length {α : Type w} {β : Type w} [Iterator α Id β] [IteratorLoop α Id Id]
(it : Iter (α := α) β) : Nat :=
it.toIterM.count.run.down
it.toIterM.length.run.down
@[inline, inherit_doc Iter.length, deprecated Iter.length (since := "2026-01-28"), expose]
def Iter.count := @Iter.length
@[inline, inherit_doc Iter.length, deprecated Iter.length (since := "2025-10-29"), expose]
def Iter.size := @Iter.length
/--
Steps through the whole iterator, counting the number of outputs emitted.
@@ -686,22 +728,10 @@ Steps through the whole iterator, counting the number of outputs emitted.
This function's runtime is linear in the number of steps taken by the iterator.
-/
@[always_inline, inline, expose, deprecated Iter.count (since := "2025-10-29")]
def Iter.size {α : Type w} {β : Type w} [Iterator α Id β] [IteratorLoop α Id Id]
(it : Iter (α := α) β) : Nat :=
it.count
/--
Steps through the whole iterator, counting the number of outputs emitted.
**Performance**:
This function's runtime is linear in the number of steps taken by the iterator.
-/
@[always_inline, inline, expose, deprecated Iter.count (since := "2025-12-04")]
@[always_inline, inline, expose, deprecated Iter.length (since := "2025-12-04")]
def Iter.Partial.count {α : Type w} {β : Type w} [Iterator α Id β] [IteratorLoop α Id Id]
(it : Iter.Partial (α := α) β) : Nat :=
it.it.toIterM.count.run.down
it.it.toIterM.length.run.down
/--
Steps through the whole iterator, counting the number of outputs emitted.
@@ -710,9 +740,9 @@ Steps through the whole iterator, counting the number of outputs emitted.
This function's runtime is linear in the number of steps taken by the iterator.
-/
@[always_inline, inline, expose, deprecated Iter.count (since := "2025-10-29")]
@[always_inline, inline, expose, deprecated Iter.length (since := "2025-10-29")]
def Iter.Partial.size {α : Type w} {β : Type w} [Iterator α Id β] [IteratorLoop α Id Id]
(it : Iter.Partial (α := α) β) : Nat :=
it.it.count
it.it.length
end Std

View File

@@ -950,6 +950,38 @@ def IterM.Total.first? {α β : Type w} {m : Type w → Type w'} [Monad m] [Iter
[IteratorLoop α m m] [Productive α m] (it : IterM.Total (α := α) m β) : m (Option β) :=
it.it.first?
set_option doc.verso true in
/--
Returns {lean}`ULift.up true` if the iterator {name}`it` yields no values.
{lit}`O(|it|)` since the iterator may skip an unknown number of times before returning a result.
Short-circuits upon encountering the first result. Only the first element of {name}`it` is examined.
If the iterator is not productive, this function might run forever. The variant
{lit}`it.ensureTermination.isEmpty` always terminates after finitely many steps.
-/
@[always_inline]
def IterM.isEmpty {α β : Type w} {m : Type w Type w'} [Monad m] [Iterator α m β]
[IteratorLoop α m m] (it : IterM (α := α) m β) : m (ULift Bool) :=
IteratorLoop.forIn (fun _ _ => flip Bind.bind) _ (fun _ _ s => s = ForInStep.done (.up false)) it
(.up true) (fun _ _ _ => pure ForInStep.done (.up false), rfl)
set_option doc.verso true in
/--
Returns {lean}`ULift.up true` if the iterator {name}`it` yields no values.
{lit}`O(|it|)` since the iterator may skip an unknown number of times before returning a result.
Short-circuits upon encountering the first result. Only the first element of {name}`it` is examined.
This variant terminates after finitely many steps and requires a proof that the iterator is
finite. If such a proof is not available, consider using {name}`IterM.isEmpty`.
-/
@[always_inline, inline]
def IterM.Total.isEmpty {α β : Type w} {m : Type w Type w'} [Monad m]
[Iterator α m β] [IteratorLoop α m m] [Productive α m] (it : IterM.Total (α := α) m β) :
m (ULift Bool) :=
it.it.isEmpty
section Count
/--
@@ -960,21 +992,15 @@ Steps through the whole iterator, counting the number of outputs emitted.
This function's runtime is linear in the number of steps taken by the iterator.
-/
@[always_inline, inline]
def IterM.count {α : Type w} {m : Type w Type w'} {β : Type w} [Iterator α m β]
def IterM.length {α : Type w} {m : Type w Type w'} {β : Type w} [Iterator α m β]
[IteratorLoop α m m] [Monad m] (it : IterM (α := α) m β) : m (ULift Nat) :=
it.fold (init := .up 0) fun acc _ => .up (acc.down + 1)
/--
Steps through the whole iterator, counting the number of outputs emitted.
@[inline, inherit_doc IterM.length, deprecated IterM.length (since := "2026-01-28"), expose]
def IterM.count := @IterM.length
**Performance**:
This function's runtime is linear in the number of steps taken by the iterator.
-/
@[always_inline, inline, deprecated IterM.count (since := "2025-10-29")]
def IterM.size {α : Type w} {m : Type w Type w'} {β : Type w} [Iterator α m β]
[IteratorLoop α m m] [Monad m] (it : IterM (α := α) m β) : m (ULift Nat) :=
it.count
@[inline, inherit_doc IterM.length, deprecated IterM.length (since := "2025-10-29"), expose]
def IterM.size := @IterM.length
/--
Steps through the whole iterator, counting the number of outputs emitted.
@@ -983,7 +1009,7 @@ Steps through the whole iterator, counting the number of outputs emitted.
This function's runtime is linear in the number of steps taken by the iterator.
-/
@[always_inline, inline, deprecated IterM.count (since := "2025-12-04")]
@[always_inline, inline, deprecated IterM.length (since := "2025-12-04")]
def IterM.Partial.count {α : Type w} {m : Type w Type w'} {β : Type w} [Iterator α m β]
[IteratorLoop α m m] [Monad m] (it : IterM.Partial (α := α) m β) : m (ULift Nat) :=
it.it.fold (init := .up 0) fun acc _ => .up (acc.down + 1)
@@ -995,10 +1021,10 @@ Steps through the whole iterator, counting the number of outputs emitted.
This function's runtime is linear in the number of steps taken by the iterator.
-/
@[always_inline, inline, deprecated IterM.Partial.count (since := "2025-10-29")]
@[always_inline, inline, deprecated IterM.length (since := "2025-10-29")]
def IterM.Partial.size {α : Type w} {m : Type w Type w'} {β : Type w} [Iterator α m β]
[IteratorLoop α m m] [Monad m] (it : IterM.Partial (α := α) m β) : m (ULift Nat) :=
it.it.count
it.it.length
end Count

View File

@@ -29,7 +29,7 @@ consumers such as `toList`. They can be used without any proof of termination su
or `Productive`, but as they are implemented with the `partial` declaration modifier, they are
opaque for the kernel and it is impossible to prove anything about them.
-/
@[always_inline, inline]
@[always_inline, inline, deprecated "The consumers on iterators do not require proofs of termination anymore. For example, use `it.toList` instead of `it.allowNontermination.toList`." (since := "2026-01-28")]
def IterM.allowNontermination {α : Type w} {m : Type w Type w'} {β : Type w}
(it : IterM (α := α) m β) : IterM.Partial (α := α) m β :=
it

View File

@@ -29,7 +29,7 @@ consumers such as `toList`. They can be used without any proof of termination su
or `Productive`, but as they are implemented with the `partial` declaration modifier, they are
opaque for the kernel and it is impossible to prove anything about them.
-/
@[always_inline, inline]
@[always_inline, inline, deprecated "The consumers on iterators do not require proofs of termination anymore. For example, use `it.toList` instead of `it.allowNontermination.toList`." (since := "2026-01-28")]
def Iter.allowNontermination {α : Type w} {β : Type w}
(it : Iter (α := α) β) : Iter.Partial (α := α) β :=
it

View File

@@ -79,12 +79,15 @@ theorem Iter.toArray_attachWith [Iterator α Id β]
simp [Iter.toList_toArray]
@[simp]
theorem Iter.count_attachWith [Iterator α Id β]
theorem Iter.length_attachWith [Iterator α Id β]
{it : Iter (α := α) β} {hP}
[Finite α Id] [IteratorLoop α Id Id]
[LawfulIteratorLoop α Id Id] :
(it.attachWith P hP).count = it.count := by
rw [ Iter.length_toList_eq_count, toList_attachWith]
(it.attachWith P hP).length = it.length := by
rw [ Iter.length_toList_eq_length, toList_attachWith]
simp
@[deprecated Iter.length_attachWith (since := "2026-01-28")]
def Iter.count_attachWith := @Iter.length_attachWith
end Std

View File

@@ -722,11 +722,14 @@ end Fold
section Count
@[simp]
theorem Iter.count_map {α β β' : Type w} [Iterator α Id β]
theorem Iter.length_map {α β β' : Type w} [Iterator α Id β]
[IteratorLoop α Id Id] [Finite α Id] [LawfulIteratorLoop α Id Id]
{it : Iter (α := α) β} {f : β β'} :
(it.map f).count = it.count := by
simp [map_eq_toIter_map_toIterM, count_eq_count_toIterM]
(it.map f).length = it.length := by
simp [map_eq_toIter_map_toIterM, length_eq_length_toIterM]
@[deprecated Iter.length_map (since := "2026-01-28")]
def Iter.count_map := @Iter.length_map
end Count

View File

@@ -60,12 +60,15 @@ theorem IterM.map_unattach_toArray_attachWith [Iterator α m β] [Monad m] [Mona
simp [-map_unattach_toList_attachWith, -IterM.toArray_toList]
@[simp]
theorem IterM.count_attachWith [Iterator α m β] [Monad m] [Monad n]
theorem IterM.length_attachWith [Iterator α m β] [Monad m] [Monad n]
{it : IterM (α := α) m β} {hP}
[Finite α m] [IteratorLoop α m m] [LawfulMonad m] [LawfulIteratorLoop α m m] :
(it.attachWith P hP).count = it.count := by
rw [ up_length_toList_eq_count, up_length_toList_eq_count,
(it.attachWith P hP).length = it.length := by
rw [ up_length_toList_eq_length, up_length_toList_eq_length,
map_unattach_toList_attachWith (it := it) (P := P) (hP := hP)]
simp only [Functor.map_map, List.length_unattach]
@[deprecated IterM.length_attachWith (since := "2026-01-28")]
def IterM.count_attachWith := @IterM.length_attachWith
end Std

View File

@@ -1620,18 +1620,21 @@ end Fold
section Count
@[simp]
theorem IterM.count_map {α β β' : Type w} {m : Type w Type w'} [Iterator α m β] [Monad m]
theorem IterM.length_map {α β β' : Type w} {m : Type w Type w'} [Iterator α m β] [Monad m]
[IteratorLoop α m m] [Finite α m] [LawfulMonad m] [LawfulIteratorLoop α m m]
{it : IterM (α := α) m β} {f : β β'} :
(it.map f).count = it.count := by
(it.map f).length = it.length := by
induction it using IterM.inductSteps with | step it ihy ihs
rw [count_eq_match_step, count_eq_match_step, step_map, bind_assoc]
rw [length_eq_match_step, length_eq_match_step, step_map, bind_assoc]
apply bind_congr; intro step
cases step.inflate using PlausibleIterStep.casesOn
· simp [ihy _]
· simp [ihs _]
· simp
@[deprecated IterM.length_map (since := "2026-01-28")]
def IterM.count_map := @IterM.length_map
end Count
section AnyAll

View File

@@ -66,14 +66,14 @@ theorem IterM.toArray_uLift [Iterator α m β] [Monad m] [Monad n] {it : IterM (
simp
@[simp]
theorem IterM.count_uLift [Iterator α m β] [Monad m] [Monad n] {it : IterM (α := α) m β}
theorem IterM.length_uLift [Iterator α m β] [Monad m] [Monad n] {it : IterM (α := α) m β}
[MonadLiftT m (ULiftT n)] [Finite α m] [IteratorLoop α m m]
[LawfulMonad m] [LawfulMonad n] [LawfulIteratorLoop α m m]
[LawfulMonadLiftT m (ULiftT n)] :
(it.uLift n).count =
(.up ·.down.down) <$> (monadLift (n := ULiftT n) it.count).run := by
(it.uLift n).length =
(.up ·.down.down) <$> (monadLift (n := ULiftT n) it.length).run := by
induction it using IterM.inductSteps with | step it ihy ihs
rw [count_eq_match_step, count_eq_match_step, monadLift_bind, map_eq_pure_bind, step_uLift]
rw [length_eq_match_step, length_eq_match_step, monadLift_bind, map_eq_pure_bind, step_uLift]
simp only [bind_assoc, ULiftT.run_bind]
apply bind_congr; intro step
cases step.down.inflate using PlausibleIterStep.casesOn
@@ -81,4 +81,7 @@ theorem IterM.count_uLift [Iterator α m β] [Monad m] [Monad n] {it : IterM (α
· simp [ihs _]
· simp
@[deprecated IterM.length_uLift (since := "2026-01-28")]
def IterM.count_uLift := @IterM.length_uLift
end Std

View File

@@ -47,18 +47,18 @@ theorem Iter.atIdxSlow?_take {α β}
[Iterator α Id β] [Productive α Id] {k l : Nat}
{it : Iter (α := α) β} :
(it.take k).atIdxSlow? l = if l < k then it.atIdxSlow? l else none := by
fun_induction it.atIdxSlow? l generalizing k
case case1 it it' out h h' =>
simp only [atIdxSlow?.eq_def (it := it.take k), step_take, h']
induction l, it using Iter.atIdxSlow?.induct_unfolding generalizing k
case yield_zero it it' out h h' =>
simp only [atIdxSlow?_eq_match (it := it.take k), step_take, h']
cases k <;> simp
case case2 it it' out h h' l ih =>
simp only [Nat.succ_eq_add_one, atIdxSlow?.eq_def (it := it.take k), step_take, h']
case yield_succ it it' out h h' l ih =>
simp only [Nat.succ_eq_add_one, atIdxSlow?_eq_match (it := it.take k), step_take, h']
cases k <;> cases l <;> simp [ih]
case case3 l it it' h h' ih =>
simp only [atIdxSlow?.eq_def (it := it.take k), step_take, h']
case skip_case l it it' h h' ih =>
simp only [atIdxSlow?_eq_match (it := it.take k), step_take, h']
cases k <;> cases l <;> simp [ih]
case case4 l it h h' =>
simp only [atIdxSlow?.eq_def (it := it.take k), step_take, h']
case done_case l it h h' =>
simp only [atIdxSlow?_eq_match (it := it.take k), step_take, h']
cases k <;> cases l <;> simp
@[simp]

View File

@@ -57,11 +57,14 @@ theorem Iter.toArray_uLift [Iterator α Id β] {it : Iter (α := α) β}
simp [-toArray_toList]
@[simp]
theorem Iter.count_uLift [Iterator α Id β] {it : Iter (α := α) β}
theorem Iter.length_uLift [Iterator α Id β] {it : Iter (α := α) β}
[Finite α Id] [IteratorLoop α Id Id] [LawfulIteratorLoop α Id Id] :
it.uLift.count = it.count := by
simp only [monadLift, uLift_eq_toIter_uLift_toIterM, count_eq_count_toIterM, toIterM_toIter]
rw [IterM.count_uLift]
it.uLift.length = it.length := by
simp only [monadLift, uLift_eq_toIter_uLift_toIterM, length_eq_length_toIterM, toIterM_toIter]
rw [IterM.length_uLift]
simp [monadLift]
@[deprecated Iter.length_uLift (since := "2026-01-28")]
def Iter.count_uLift := @Iter.length_uLift
end Std

View File

@@ -7,6 +7,7 @@ module
prelude
public import Init.Data.Iterators.Consumers.Access
import Init.Data.Iterators.Lemmas.Basic
namespace Std.Iter
open Std.Iterators
@@ -21,6 +22,6 @@ public theorem atIdxSlow?_eq_match [Iterator α Id β] [Productive α Id]
| n + 1 => it'.atIdxSlow? n
| .skip it' => it'.atIdxSlow? n
| .done => none) := by
fun_induction it.atIdxSlow? n <;> simp_all
induction n, it using Iter.atIdxSlow?.induct_unfolding <;> simp_all
end Std.Iter

View File

@@ -163,12 +163,14 @@ theorem Iter.getElem?_toList_eq_atIdxSlow? {α β}
{it : Iter (α := α) β} {k : Nat} :
it.toList[k]? = it.atIdxSlow? k := by
induction it using Iter.inductSteps generalizing k with | step it ihy ihs
rw [toList_eq_match_step, atIdxSlow?]
obtain step, h := it.step
cases step
· cases k <;> simp [ihy h]
· simp [ihs h]
· simp
rw [toList_eq_match_step, atIdxSlow?, WellFounded.extrinsicFix₂_eq_apply]
· obtain step, h := it.step
cases step
· cases k <;> simp [ihy h, atIdxSlow?]
· simp [ihs h, atIdxSlow?]
· simp
· apply InvImage.wf
exact WellFoundedRelation.wf
theorem Iter.toList_eq_of_atIdxSlow?_eq {α₁ α₂ β}
[Iterator α₁ Id β] [Finite α₁ Id]

View File

@@ -460,69 +460,90 @@ theorem Iter.foldl_toArray {α β : Type w} {γ : Type x} [Iterator α Id β] [F
it.toArray.foldl (init := init) f = it.fold (init := init) f := by
rw [fold_eq_foldM, Array.foldl_eq_foldlM, Iter.foldlM_toArray]
theorem Iter.count_eq_count_toIterM {α β : Type w} [Iterator α Id β]
theorem Iter.length_eq_length_toIterM {α β : Type w} [Iterator α Id β]
[Finite α Id] [IteratorLoop α Id Id.{w}] {it : Iter (α := α) β} :
it.count = it.toIterM.count.run.down :=
it.length = it.toIterM.length.run.down :=
(rfl)
theorem Iter.count_eq_fold {α β : Type w} [Iterator α Id β]
@[deprecated Iter.length_eq_length_toIterM (since := "2026-01-28")]
def Iter.count_eq_count_toIterM := @Iter.length_eq_length_toIterM
theorem Iter.length_eq_fold {α β : Type w} [Iterator α Id β]
[Finite α Id] [IteratorLoop α Id Id.{w}] [LawfulIteratorLoop α Id Id.{w}]
[IteratorLoop α Id Id.{0}] [LawfulIteratorLoop α Id Id.{0}]
{it : Iter (α := α) β} :
it.count = it.fold (γ := Nat) (init := 0) (fun acc _ => acc + 1) := by
rw [count_eq_count_toIterM, IterM.count_eq_fold, fold_eq_fold_toIterM]
it.length = it.fold (γ := Nat) (init := 0) (fun acc _ => acc + 1) := by
rw [length_eq_length_toIterM, IterM.length_eq_fold, fold_eq_fold_toIterM]
rw [ fold_hom (f := ULift.down)]
simp
theorem Iter.count_eq_forIn {α β : Type w} [Iterator α Id β]
@[deprecated Iter.length_eq_fold (since := "2026-01-28")]
def Iter.count_eq_fold := @Iter.length_eq_fold
theorem Iter.length_eq_forIn {α β : Type w} [Iterator α Id β]
[Finite α Id] [IteratorLoop α Id Id.{w}] [LawfulIteratorLoop α Id Id.{w}]
[IteratorLoop α Id Id.{0}] [LawfulIteratorLoop α Id Id.{0}]
{it : Iter (α := α) β} :
it.count = (ForIn.forIn (m := Id) it 0 (fun _ acc => return .yield (acc + 1))).run := by
rw [count_eq_fold, forIn_pure_yield_eq_fold, Id.run_pure]
it.length = (ForIn.forIn (m := Id) it 0 (fun _ acc => return .yield (acc + 1))).run := by
rw [length_eq_fold, forIn_pure_yield_eq_fold, Id.run_pure]
theorem Iter.count_eq_match_step {α β : Type w} [Iterator α Id β]
@[deprecated Iter.length_eq_forIn (since := "2026-01-28")]
def Iter.count_eq_forIn := @Iter.length_eq_forIn
theorem Iter.length_eq_match_step {α β : Type w} [Iterator α Id β]
[Finite α Id] [IteratorLoop α Id Id] [LawfulIteratorLoop α Id Id]
{it : Iter (α := α) β} :
it.count = (match it.step.val with
| .yield it' _ => it'.count + 1
| .skip it' => it'.count
it.length = (match it.step.val with
| .yield it' _ => it'.length + 1
| .skip it' => it'.length
| .done => 0) := by
simp only [count_eq_count_toIterM]
rw [IterM.count_eq_match_step]
simp only [length_eq_length_toIterM]
rw [IterM.length_eq_match_step]
simp only [bind_pure_comp, id_map', Id.run_bind, Iter.step]
cases it.toIterM.step.run.inflate using PlausibleIterStep.casesOn <;> simp
@[simp]
theorem Iter.size_toArray_eq_count {α β : Type w} [Iterator α Id β] [Finite α Id]
[IteratorLoop α Id Id] [LawfulIteratorLoop α Id Id]
{it : Iter (α := α) β} :
it.toArray.size = it.count := by
simp only [toArray_eq_toArray_toIterM, count_eq_count_toIterM, Id.run_map,
IterM.up_size_toArray_eq_count]
@[deprecated Iter.size_toArray_eq_count (since := "2025-10-29")]
def Iter.size_toArray_eq_size := @size_toArray_eq_count
@[deprecated Iter.length_eq_match_step (since := "2026-01-28")]
def Iter.count_eq_match_step := @Iter.length_eq_match_step
@[simp]
theorem Iter.length_toList_eq_count {α β : Type w} [Iterator α Id β] [Finite α Id]
theorem Iter.size_toArray_eq_length {α β : Type w} [Iterator α Id β] [Finite α Id]
[IteratorLoop α Id Id] [LawfulIteratorLoop α Id Id]
{it : Iter (α := α) β} :
it.toList.length = it.count := by
rw [ toList_toArray, Array.length_toList, size_toArray_eq_count]
it.toArray.size = it.length := by
simp only [toArray_eq_toArray_toIterM, length_eq_length_toIterM, Id.run_map,
IterM.up_size_toArray_eq_length]
@[deprecated Iter.length_toList_eq_count (since := "2025-10-29")]
def Iter.length_toList_eq_size := @length_toList_eq_count
@[deprecated Iter.size_toArray_eq_length (since := "2025-10-29")]
def Iter.size_toArray_eq_size := @size_toArray_eq_length
@[deprecated Iter.size_toArray_eq_length (since := "2026-01-28")]
def Iter.size_toArray_eq_count := @size_toArray_eq_length
@[simp]
theorem Iter.length_toListRev_eq_count {α β : Type w} [Iterator α Id β] [Finite α Id]
theorem Iter.length_toList_eq_length {α β : Type w} [Iterator α Id β] [Finite α Id]
[IteratorLoop α Id Id] [LawfulIteratorLoop α Id Id]
{it : Iter (α := α) β} :
it.toListRev.length = it.count := by
rw [toListRev_eq, List.length_reverse, length_toList_eq_count]
it.toList.length = it.length := by
rw [ toList_toArray, Array.length_toList, size_toArray_eq_length]
@[deprecated Iter.length_toListRev_eq_count (since := "2025-10-29")]
def Iter.length_toListRev_eq_size := @length_toListRev_eq_count
@[deprecated Iter.length_toList_eq_length (since := "2025-10-29")]
def Iter.length_toList_eq_size := @length_toList_eq_length
@[deprecated Iter.length_toList_eq_length (since := "2026-01-28")]
def Iter.length_toList_eq_count := @length_toList_eq_length
@[simp]
theorem Iter.length_toListRev_eq_length {α β : Type w} [Iterator α Id β] [Finite α Id]
[IteratorLoop α Id Id] [LawfulIteratorLoop α Id Id]
{it : Iter (α := α) β} :
it.toListRev.length = it.length := by
rw [toListRev_eq, List.length_reverse, length_toList_eq_length]
@[deprecated Iter.length_toListRev_eq_length (since := "2025-10-29")]
def Iter.length_toListRev_eq_size := @length_toListRev_eq_length
@[deprecated Iter.length_toListRev_eq_length (since := "2026-01-28")]
def Iter.length_toListRev_eq_count := @length_toListRev_eq_length
theorem Iter.anyM_eq_forIn {α β : Type w} {m : Type Type w'} [Iterator α Id β]
[Finite α Id] [Monad m] [LawfulMonad m] [IteratorLoop α Id m] [LawfulIteratorLoop α Id m]
@@ -930,11 +951,35 @@ theorem Iter.first?_eq_match_step {α β : Type w} [Iterator α Id β] [Iterator
generalize it.toIterM.step.run.inflate = s
rcases s with _|_|_, _ <;> simp [Iter.first?_eq_first?_toIterM]
theorem Iter.first?_eq_head?_toList {α β : Type w} [Iterator α Id β] [IteratorLoop α Id Id]
@[simp, grind =]
theorem Iter.head?_toList {α β : Type w} [Iterator α Id β] [IteratorLoop α Id Id]
[Finite α Id] [LawfulIteratorLoop α Id Id] {it : Iter (α := α) β} :
it.first? = it.toList.head? := by
it.toList.head? = it.first? := by
induction it using Iter.inductSteps with | step it ihy ihs
rw [first?_eq_match_step, toList_eq_match_step]
cases it.step using PlausibleIterStep.casesOn <;> simp [*]
theorem Iter.isEmpty_eq_isEmpty_toIterM {α β : Type w} [Iterator α Id β] [IteratorLoop α Id Id]
{it : Iter (α := α) β} :
it.isEmpty = it.toIterM.isEmpty.run.down := (rfl)
theorem Iter.isEmpty_eq_match_step {α β : Type w} [Iterator α Id β] [IteratorLoop α Id Id]
[Productive α Id] [LawfulIteratorLoop α Id Id] {it : Iter (α := α) β} :
it.isEmpty = match it.step.val with
| .yield _ _ => false
| .skip it' => it'.isEmpty
| .done => true := by
rw [Iter.isEmpty_eq_isEmpty_toIterM, IterM.isEmpty_eq_match_step]
simp only [Id.run_bind, step]
generalize it.toIterM.step.run.inflate = s
rcases s with _|_|_, _ <;> simp [Iter.isEmpty_eq_isEmpty_toIterM]
@[simp, grind =]
theorem Iter.isEmpty_toList {α β : Type w} [Iterator α Id β] [IteratorLoop α Id Id]
[Finite α Id] [LawfulIteratorLoop α Id Id] {it : Iter (α := α) β} :
it.toList.isEmpty = it.isEmpty := by
induction it using Iter.inductSteps with | step it ihy ihs
rw [isEmpty_eq_match_step, toList_eq_match_step]
cases it.step using PlausibleIterStep.casesOn <;> simp [*]
end Std

View File

@@ -476,27 +476,33 @@ theorem IterM.drain_eq_map_toArray {α β : Type w} {m : Type w → Type w'} [It
it.drain = (fun _ => .unit) <$> it.toList := by
simp [IterM.drain_eq_map_toList]
theorem IterM.count_eq_fold {α β : Type w} {m : Type w Type w'} [Iterator α m β]
theorem IterM.length_eq_fold {α β : Type w} {m : Type w Type w'} [Iterator α m β]
[Finite α m] [Monad m] [LawfulMonad m] [IteratorLoop α m m]
{it : IterM (α := α) m β} :
it.count = it.fold (init := .up 0) (fun acc _ => .up <| acc.down + 1) :=
it.length = it.fold (init := .up 0) (fun acc _ => .up <| acc.down + 1) :=
(rfl)
theorem IterM.count_eq_forIn {α β : Type w} {m : Type w Type w'} [Iterator α m β]
@[deprecated IterM.length_eq_fold (since := "2026-01-28")]
def IterM.count_eq_fold := @IterM.length_eq_fold
theorem IterM.length_eq_forIn {α β : Type w} {m : Type w Type w'} [Iterator α m β]
[Finite α m] [Monad m] [LawfulMonad m] [IteratorLoop α m m]
{it : IterM (α := α) m β} :
it.count = ForIn.forIn it (.up 0) (fun _ acc => return .yield (.up (acc.down + 1))) :=
it.length = ForIn.forIn it (.up 0) (fun _ acc => return .yield (.up (acc.down + 1))) :=
(rfl)
theorem IterM.count_eq_match_step {α β : Type w} {m : Type w Type w'} [Iterator α m β]
@[deprecated IterM.length_eq_forIn (since := "2026-01-28")]
def IterM.count_eq_forIn := @IterM.length_eq_forIn
theorem IterM.length_eq_match_step {α β : Type w} {m : Type w Type w'} [Iterator α m β]
[Finite α m] [Monad m] [LawfulMonad m] [IteratorLoop α m m] [LawfulIteratorLoop α m m]
{it : IterM (α := α) m β} :
it.count = (do
it.length = (do
match ( it.step).inflate.val with
| .yield it' _ => return .up (( it'.count).down + 1)
| .skip it' => return .up ( it'.count).down
| .yield it' _ => return .up (( it'.length).down + 1)
| .skip it' => return .up ( it'.length).down
| .done => return .up 0) := by
simp only [count_eq_fold]
simp only [length_eq_fold]
have (acc : Nat) (it' : IterM (α := α) m β) :
it'.fold (init := ULift.up acc) (fun acc _ => .up (acc.down + 1)) =
(ULift.up <| ·.down + acc) <$>
@@ -512,33 +518,45 @@ theorem IterM.count_eq_match_step {α β : Type w} {m : Type w → Type w'} [Ite
· simp
· simp
@[deprecated IterM.length_eq_match_step (since := "2026-01-28")]
def IterM.count_eq_match_step := @IterM.length_eq_match_step
@[simp]
theorem IterM.up_size_toArray_eq_count {α β : Type w} [Iterator α m β] [Finite α m]
theorem IterM.up_size_toArray_eq_length {α β : Type w} [Iterator α m β] [Finite α m]
[Monad m] [LawfulMonad m]
[IteratorLoop α m m] [LawfulIteratorLoop α m m]
{it : IterM (α := α) m β} :
(.up <| ·.size) <$> it.toArray = it.count := by
rw [toArray_eq_fold, count_eq_fold, fold_hom]
(.up <| ·.size) <$> it.toArray = it.length := by
rw [toArray_eq_fold, length_eq_fold, fold_hom]
· simp only [List.size_toArray, List.length_nil]; rfl
· simp
@[deprecated IterM.up_size_toArray_eq_length (since := "2026-01-28")]
def IterM.up_size_toArray_eq_count := @IterM.up_size_toArray_eq_length
@[simp]
theorem IterM.up_length_toList_eq_count {α β : Type w} [Iterator α m β] [Finite α m]
theorem IterM.up_length_toList_eq_length {α β : Type w} [Iterator α m β] [Finite α m]
[Monad m] [LawfulMonad m]
[IteratorLoop α m m] [LawfulIteratorLoop α m m]
{it : IterM (α := α) m β} :
(.up <| ·.length) <$> it.toList = it.count := by
rw [toList_eq_fold, count_eq_fold, fold_hom]
(.up <| ·.length) <$> it.toList = it.length := by
rw [toList_eq_fold, length_eq_fold, fold_hom]
· simp only [List.length_nil]; rfl
· simp
@[deprecated IterM.up_length_toList_eq_length (since := "2026-01-28")]
def IterM.up_length_toList_eq_count := @IterM.up_length_toList_eq_length
@[simp]
theorem IterM.up_length_toListRev_eq_count {α β : Type w} [Iterator α m β] [Finite α m]
theorem IterM.up_length_toListRev_eq_length {α β : Type w} [Iterator α m β] [Finite α m]
[Monad m] [LawfulMonad m]
[IteratorLoop α m m] [LawfulIteratorLoop α m m]
{it : IterM (α := α) m β} :
(.up <| ·.length) <$> it.toListRev = it.count := by
simp only [toListRev_eq, Functor.map_map, List.length_reverse, up_length_toList_eq_count]
(.up <| ·.length) <$> it.toListRev = it.length := by
simp only [toListRev_eq, Functor.map_map, List.length_reverse, up_length_toList_eq_length]
@[deprecated IterM.up_length_toListRev_eq_length (since := "2026-01-28")]
def IterM.up_length_toListRev_eq_count := @IterM.up_length_toListRev_eq_length
theorem IterM.anyM_eq_forIn {α β : Type w} {m : Type w Type w'} [Iterator α m β]
[Finite α m] [Monad m] [LawfulMonad m] [IteratorLoop α m m] [LawfulIteratorLoop α m m]
@@ -861,4 +879,24 @@ theorem IterM.first?_eq_match_step {α β : Type w} {m : Type w → Type w'} [Mo
simp only [DefaultConsumers.forIn_eq, *]
exact IterM.DefaultConsumers.forIn'_eq_forIn' _ this (by simp)
theorem IterM.isEmpty_eq_match_step {α β : Type w} {m : Type w Type w'} [Monad m]
[Iterator α m β] [IteratorLoop α m m] [LawfulMonad m] [Productive α m]
[LawfulIteratorLoop α m m] {it : IterM (α := α) m β} :
it.isEmpty = (do
match ( it.step).inflate.val with
| .yield _ _ => return .up false
| .skip it' => it'.isEmpty
| .done => return .up true) := by
simp only [isEmpty]
have := IteratorLoop.wellFounded_of_productive (α := α) (β := β) (m := m)
(P := fun _ _ s => s = ForInStep.done (ULift.up false)) (by simp)
simp only [LawfulIteratorLoop.lawful _ _ _ _ _ this]
rw [IterM.DefaultConsumers.forIn_eq, IterM.DefaultConsumers.forIn'_eq_match_step _ this]
simp only [flip, pure_bind]
congr
ext s
split <;> try (simp [*]; done)
simp only [DefaultConsumers.forIn_eq, *]
exact IterM.DefaultConsumers.forIn'_eq_forIn' _ this (by simp)
end Std

View File

@@ -16,6 +16,8 @@ public import Init.Data.List.Find
public import Init.Data.List.Impl
public import Init.Data.List.Lemmas
public import Init.Data.List.MinMax
public import Init.Data.List.MinMaxIdx
public import Init.Data.List.MinMaxOn
public import Init.Data.List.Monadic
public import Init.Data.List.Nat
public import Init.Data.List.Notation

View File

@@ -85,7 +85,7 @@ theorem cons_lex_cons_iff : Lex r (a :: l₁) (b :: l₂) ↔ r a b a = b
theorem cons_lt_cons_iff [LT α] {a b} {l₁ l₂ : List α} :
(a :: l₁) < (b :: l₂) a < b a = b l₁ < l₂ := by
dsimp only [instLT, List.lt]
simp only [LT.lt, List.lt]
simp [cons_lex_cons_iff]
@[simp] theorem cons_lt_cons_self [LT α] [i₀ : Std.Irrefl (· < · : α α Prop)] {l₁ l₂ : List α} :
@@ -101,7 +101,7 @@ theorem cons_le_cons_iff [LT α]
[i₂ : Std.Trichotomous (· < · : α α Prop)]
{a b} {l₁ l₂ : List α} :
(a :: l₁) (b :: l₂) a < b a = b l₁ l₂ := by
dsimp only [instLE, instLT, List.le, List.lt]
simp only [LE.le, LT.lt, List.le, List.lt]
open Classical in
simp only [not_cons_lex_cons_iff, ne_eq]
constructor

View File

@@ -29,7 +29,11 @@ open Nat
/-! ### min? -/
@[simp] theorem min?_nil [Min α] : ([] : List α).min? = none := rfl
@[simp, grind =] theorem min?_nil [Min α] : ([] : List α).min? = none := rfl
@[simp, grind =]
public theorem min?_singleton [Min α] {x : α} : [x].min? = some x :=
(rfl)
-- We don't put `@[simp]` on `min?_cons'`,
-- because the definition in terms of `foldl` is not useful for proofs.
@@ -39,9 +43,14 @@ theorem min?_cons' [Min α] {xs : List α} : (x :: xs).min? = some (foldl min x
(x :: xs).min? = some (xs.min?.elim x (min x)) := by
cases xs <;> simp [min?_cons', foldl_assoc]
@[simp] theorem min?_eq_none_iff {xs : List α} [Min α] : xs.min? = none xs = [] := by
@[simp, grind =] theorem min?_eq_none_iff {xs : List α} [Min α] : xs.min? = none xs = [] := by
cases xs <;> simp [min?]
@[simp, grind =]
public theorem isSome_min?_iff {xs : List α} [Min α] : xs.min?.isSome xs [] := by
cases xs <;> simp [min?]
@[grind .]
theorem isSome_min?_of_mem {l : List α} [Min α] {a : α} (h : a l) :
l.min?.isSome := by
cases l <;> simp_all [min?_cons']
@@ -143,7 +152,8 @@ theorem min?_replicate [Min α] [Std.IdempotentOp (min : ααα)] {n :
| zero => rfl
| succ n ih => cases n <;> simp_all [replicate_succ, min?_cons', Std.IdempotentOp.idempotent]
@[simp] theorem min?_replicate_of_pos [Min α] [MinEqOr α] {n : Nat} {a : α} (h : 0 < n) :
@[simp, grind =]
theorem min?_replicate_of_pos [Min α] [MinEqOr α] {n : Nat} {a : α} (h : 0 < n) :
(replicate n a).min? = some a := by
simp [min?_replicate, Nat.ne_of_gt h]
@@ -160,6 +170,11 @@ theorem foldl_min [Min α] [Std.IdempotentOp (min : ααα)] [Std.Asso
/-! ### min -/
@[simp, grind =]
theorem min_singleton [Min α] {x : α} :
[x].min (cons_ne_nil _ _) = x := by
(rfl)
theorem min?_eq_some_min [Min α] : {l : List α} (hl : l [])
l.min? = some (l.min hl)
| a::as, _ => by simp [List.min, List.min?_cons']
@@ -168,15 +183,22 @@ theorem min_eq_get_min? [Min α] : (l : List α) → (hl : l ≠ []) →
l.min hl = l.min?.get (isSome_min?_of_ne_nil hl)
| a::as, _ => by simp [List.min, List.min?_cons']
@[simp, grind =]
theorem get_min? [Min α] {l : List α} {h : l.min?.isSome} :
l.min?.get h = l.min (isSome_min?_iff.mp h) := by
simp [min?_eq_some_min (isSome_min?_iff.mp h)]
theorem min_eq_head {α : Type u} [Min α] {l : List α} (hl : l [])
(h : l.Pairwise (fun a b => min a b = a)) : l.min hl = l.head hl := by
apply Option.some.inj
rw [ min?_eq_some_min, head?_eq_some_head]
exact min?_eq_head? h
@[grind .]
theorem min_mem [Min α] [MinEqOr α] {l : List α} (hl : l []) : l.min hl l :=
min?_mem (min?_eq_some_min hl)
@[grind .]
theorem min_le_of_mem [Min α] [LE α] [Std.IsLinearOrder α] [Std.LawfulOrderMin α]
{l : List α} {a : α} (ha : a l) :
l.min (ne_nil_of_mem ha) a :=
@@ -190,7 +212,7 @@ theorem min_eq_iff [Min α] [LE α] {l : List α} [IsLinearOrder α] [LawfulOrde
l.min hl = a a l b, b l a b := by
simpa [min?_eq_some_min hl] using (min?_eq_some_iff (xs := l))
@[simp] theorem min_replicate [Min α] [MinEqOr α] {n : Nat} {a : α} (h : replicate n a []) :
@[simp, grind =] theorem min_replicate [Min α] [MinEqOr α] {n : Nat} {a : α} (h : replicate n a []) :
(replicate n a).min h = a := by
have n_pos : 0 < n := Nat.pos_of_ne_zero (fun hn => by simp [hn] at h)
simpa [min?_eq_some_min h] using (min?_replicate_of_pos (a := a) n_pos)
@@ -202,7 +224,11 @@ theorem foldl_min_eq_min [Min α] [Std.IdempotentOp (min : ααα)] [S
/-! ### max? -/
@[simp] theorem max?_nil [Max α] : ([] : List α).max? = none := rfl
@[simp, grind =] theorem max?_nil [Max α] : ([] : List α).max? = none := rfl
@[simp, grind =]
public theorem max?_singleton [Max α] {x : α} : [x].max? = some x :=
(rfl)
-- We don't put `@[simp]` on `max?_cons'`,
-- because the definition in terms of `foldl` is not useful for proofs.
@@ -212,9 +238,14 @@ theorem max?_cons' [Max α] {xs : List α} : (x :: xs).max? = some (foldl max x
(x :: xs).max? = some (xs.max?.elim x (max x)) := by
cases xs <;> simp [max?_cons', foldl_assoc]
@[simp] theorem max?_eq_none_iff {xs : List α} [Max α] : xs.max? = none xs = [] := by
@[simp, grind =] theorem max?_eq_none_iff {xs : List α} [Max α] : xs.max? = none xs = [] := by
cases xs <;> simp [max?]
@[simp, grind =]
public theorem isSome_max?_iff {xs : List α} [Max α] : xs.max?.isSome xs [] := by
cases xs <;> simp [max?]
@[grind .]
theorem isSome_max?_of_mem {l : List α} [Max α] {a : α} (h : a l) :
l.max?.isSome := by
cases l <;> simp_all [max?_cons']
@@ -329,7 +360,8 @@ theorem max?_replicate [Max α] [Std.IdempotentOp (max : ααα)] {n :
| zero => rfl
| succ n ih => cases n <;> simp_all [replicate_succ, max?_cons', Std.IdempotentOp.idempotent]
@[simp] theorem max?_replicate_of_pos [Max α] [MaxEqOr α] {n : Nat} {a : α} (h : 0 < n) :
@[simp, grind =]
theorem max?_replicate_of_pos [Max α] [MaxEqOr α] {n : Nat} {a : α} (h : 0 < n) :
(replicate n a).max? = some a := by
simp [max?_replicate, Nat.ne_of_gt h]
@@ -346,6 +378,11 @@ theorem foldl_max [Max α] [Std.IdempotentOp (max : ααα)] [Std.Asso
/-! ### max -/
@[simp, grind =]
theorem max_singleton [Max α] {x : α} :
[x].max (cons_ne_nil _ _) = x := by
(rfl)
theorem max?_eq_some_max [Max α] : {l : List α} (hl : l [])
l.max? = some (l.max hl)
| a::as, _ => by simp [List.max, List.max?_cons']
@@ -354,12 +391,18 @@ theorem max_eq_get_max? [Max α] : (l : List α) → (hl : l ≠ []) →
l.max hl = l.max?.get (isSome_max?_of_ne_nil hl)
| a::as, _ => by simp [List.max, List.max?_cons']
@[simp, grind =]
theorem get_max? [Max α] {l : List α} {h : l.max?.isSome} :
l.max?.get h = l.max (isSome_max?_iff.mp h) := by
simp [max?_eq_some_max (isSome_max?_iff.mp h)]
theorem max_eq_head {α : Type u} [Max α] {l : List α} (hl : l [])
(h : l.Pairwise (fun a b => max a b = a)) : l.max hl = l.head hl := by
apply Option.some.inj
rw [ max?_eq_some_max, head?_eq_some_head]
exact max?_eq_head? h
@[grind .]
theorem max_mem [Max α] [MaxEqOr α] {l : List α} (hl : l []) : l.max hl l :=
max?_mem (max?_eq_some_max hl)
@@ -371,12 +414,13 @@ theorem max_eq_iff [Max α] [LE α] {l : List α} [IsLinearOrder α] [LawfulOrde
l.max hl = a a l b, b l b a := by
simpa [max?_eq_some_max hl] using (max?_eq_some_iff (xs := l))
@[grind .]
theorem le_max_of_mem [Max α] [LE α] [Std.IsLinearOrder α] [Std.LawfulOrderMax α]
{l : List α} {a : α} (ha : a l) :
a l.max (List.ne_nil_of_mem ha) :=
(max?_eq_some_iff.mp (max?_eq_some_max (List.ne_nil_of_mem ha))).right a ha
@[simp] theorem max_replicate [Max α] [MaxEqOr α] {n : Nat} {a : α} (h : replicate n a []) :
@[simp, grind =] theorem max_replicate [Max α] [MaxEqOr α] {n : Nat} {a : α} (h : replicate n a []) :
(replicate n a).max h = a := by
have n_pos : 0 < n := Nat.pos_of_ne_zero (fun hn => by simp [hn] at h)
simpa [max?_eq_some_max h] using (max?_replicate_of_pos (a := a) n_pos)

View File

@@ -0,0 +1,830 @@
/-
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Reichert
-/
module
prelude
public import Init.Data.List.MinMaxOn
import Init.Data.List.MinMaxOn
public import Init.Data.List.Pairwise
public import Init.Data.Subtype.Order
import Init.Data.Order.Lemmas
import Init.Data.List.Nat.TakeDrop
import Init.Data.Order.Opposite
import Init.Data.Nat.Order
public section
open Std
open scoped OppositeOrderInstances
set_option doc.verso true
set_option linter.missingDocs true
set_option linter.listVariables true -- Enforce naming conventions for `List`/`Array`/`Vector` variables.
set_option linter.indexVariables true -- Enforce naming conventions for index variables.
namespace List
/--
Returns the index of an element of the non-empty list {name}`xs` that minimizes {name}`f`.
If {given}`x, y` are such that {lean}`f x = f y`, it returns the index of whichever comes first
in the list.
The correctness of this function assumes {name}`β` to be linearly pre-ordered.
-/
@[inline]
def minIdxOn [LE β] [DecidableLE β] (f : α β) (xs : List α) (h : xs []) : Nat :=
match xs with
| y :: ys => go y 0 1 ys
where
@[specialize]
go (x : α) (i : Nat) (j : Nat) (xs : List α) :=
match xs with
| [] => i
| y :: ys =>
if f x f y then
go x i (j + 1) ys
else
go y j (j + 1) ys
/--
Returns the index of an element of {name}`xs` that minimizes {name}`f`. If {given}`x, y`
are such that {lean}`f x = f y`, it returns the index of whichever comes first in the list.
Returns {name}`none` if the list is empty.
The correctness of this function assumes {name}`β` to be linearly pre-ordered.
-/
@[inline]
def minIdxOn? [LE β] [DecidableLE β] (f : α β) (xs : List α) : Option Nat :=
match xs with
| [] => none
| y :: ys => some ((y :: ys).minIdxOn f (nomatch ·))
/--
Returns the index of an element of the non-empty list {name}`xs` that maximizes {name}`f`.
If {given}`x, y` are such that {lean}`f x = f y`, it returns the index of whichever comes first
in the list.
The correctness of this function assumes {name}`β` to be linearly pre-ordered.
-/
@[inline]
def maxIdxOn [LE β] [DecidableLE β] (f : α β) (xs : List α) (h : xs []) : Nat :=
letI : LE β := LE.opposite inferInstance
xs.minIdxOn f h
/--
Returns the index of an element of {name}`xs` that maximizes {name}`f`. If {given}`x, y`
are such that {lean}`f x = f y`, it returns the index of whichever comes first in the list.
Returns {name}`none` if the list is empty.
The correctness of this function assumes {name}`β` to be linearly pre-ordered.
-/
@[inline]
def maxIdxOn? [LE β] [DecidableLE β] (f : α β) (xs : List α) : Option Nat :=
letI : LE β := LE.opposite inferInstance
xs.minIdxOn? f
protected theorem maxIdxOn_eq_minIdxOn {le : LE β} {_ : DecidableLE β} {f : α β}
{xs : List α} {h} :
xs.maxIdxOn f h = (letI := le.opposite; xs.minIdxOn f h) :=
(rfl)
private theorem minIdxOn.go_lt_length_add [LE β] [DecidableLE β] {f : α β} {x : α}
{i j : Nat} {xs : List α} (h : i < j) :
List.minIdxOn.go f x i j xs < xs.length + j := by
induction xs generalizing x i j
· simp [go, h]
· rename_i y ys ih
simp only [go, length_cons, Nat.add_assoc, Nat.add_comm 1]
split
· exact ih (Nat.lt_succ_of_lt i < j)
· exact ih (Nat.lt_succ_self j)
private theorem minIdxOn.go_eq_of_forall_le [LE β] [DecidableLE β] {f : α β}
{x : α} {i j : Nat} {xs : List α} (h : y xs, f x f y) :
List.minIdxOn.go f x i j xs = i := by
induction xs generalizing x i j
· simp [go]
· rename_i y ys ih
simp only [go]
split
· apply ih
simp_all
· simp_all
private theorem exists_getElem_eq_of_drop_eq_cons {xs : List α} {k : Nat} {y : α} {ys : List α}
(h : xs.drop k = y :: ys) : hlt : k < xs.length, xs[k] = y := by
have hlt : k < xs.length := by
false_or_by_contra
have : drop k xs = [] := drop_of_length_le (by omega)
simp [this] at h
refine hlt, ?_
have := take_append_drop k xs
rw [h] at this
simp +singlePass only [ this]
rw [getElem_append_right (length_take_le _ _)]
simp [length_take_of_le (Nat.le_of_lt hlt)]
private theorem take_succ_eq_append_of_drop_eq_cons {xs : List α} {k : Nat} {y : α}
{ys : List α} (h : xs.drop k = y :: ys) : xs.take (k + 1) = xs.take k ++ [y] := by
obtain hlt, rfl := exists_getElem_eq_of_drop_eq_cons h
rw [take_succ_eq_append_getElem hlt]
private theorem minIdxOn_eq_go_drop [LE β] [DecidableLE β] [IsLinearPreorder β] {f : α β}
{xs : List α} (h : xs []) {k : Nat} :
(i : Nat) (hlt : i < xs.length), i k xs[i] = (xs.take (k + 1)).minOn f (by simpa)
xs.minIdxOn f h = List.minIdxOn.go f ((xs.take (k + 1)).minOn f (by cases xs <;> simp_all)) i (k + 1) (xs.drop (k + 1)) := by
match xs with
| y :: ys =>
simp only [drop_succ_cons]
induction k
· simp [minIdxOn]
· rename_i k ih
specialize ih
obtain i, hlt, hi, ih := ih
simp only [ih, drop_drop]
simp only [length_cons] at hlt
match h : drop k ys with
| [] =>
have : ys.length k := by simp_all
simp [drop_nil, minIdxOn.go, take_of_length_le, hi, ih, hlt, this, Nat.le_succ_of_le]
| z :: zs =>
simp only [minIdxOn.go]
have : take (k + 1 + 1) (y :: ys) = take (k + 1) (y :: ys) ++ [z] := by apply take_succ_eq_append_of_drop_eq_cons _
simp only [this, List.minOn_append (xs := take (k + 1) (y :: ys)) (by simp) (cons_ne_nil _ _)]
simp only [take_succ_cons] at this
split
· simp only [List.minOn_singleton, minOn_eq_left, length_cons, *]
exact i, by omega, Nat.le_succ_of_le i k, by simp [ih], rfl
· simp only [List.minOn_singleton, not_false_eq_true, minOn_eq_right, length_cons, *]
obtain hlt, rfl := exists_getElem_eq_of_drop_eq_cons h
exact k + 1, by omega, Nat.le_refl _, by simp, rfl
@[simp]
protected theorem minIdxOn_nil_eq_iff_true [LE β] [DecidableLE β] {f : α β} {x : Nat}
(h : [] []) : ([] : List α).minIdxOn f h = x True :=
nomatch h
protected theorem minIdxOn_nil_eq_iff_false [LE β] [DecidableLE β] {f : α β} {x : Nat}
(h : [] []) : ([] : List α).minIdxOn f h = x False :=
nomatch h
@[simp]
protected theorem minIdxOn_singleton [LE β] [DecidableLE β] {x : α} {f : α β} :
[x].minIdxOn f (of_decide_eq_false rfl) = 0 := by
rw [minIdxOn, minIdxOn.go]
@[simp]
protected theorem minIdxOn_lt_length [LE β] [DecidableLE β] {f : α β} {xs : List α}
(h : xs []) : xs.minIdxOn f h < xs.length := by
rw [minIdxOn.eq_def]
split
simp [minIdxOn.go_lt_length_add]
protected theorem minIdxOn_le_of_apply_getElem_le_apply_minOn [LE β] [DecidableLE β] [IsLinearPreorder β]
{f : α β} {xs : List α} (h : xs [])
{k : Nat} (hi : k < xs.length) (hle : f xs[k] f (xs.minOn f h)) :
xs.minIdxOn f h k := by
obtain i, _, hi, _, h' := minIdxOn_eq_go_drop (f := f) h (k := k)
rw [h']
refine Nat.le_trans ?_ hi
apply Nat.le_of_eq
apply minIdxOn.go_eq_of_forall_le
intro y hy
refine le_trans (List.apply_minOn_le_of_mem (y := xs[k]) (by rw [mem_take_iff_getElem]; exact k, by omega, rfl)) ?_
refine le_trans hle ?_
apply List.apply_minOn_le_of_mem
apply mem_of_mem_drop
exact hy
protected theorem apply_minOn_lt_apply_getElem_of_lt_minIdxOn [LE β] [DecidableLE β] [LT β] [IsLinearPreorder β]
[LawfulOrderLT β]
{f : α β} {xs : List α} (h : xs [])
{k : Nat} (hk : k < xs.minIdxOn f h) :
f (xs.minOn f h) < f (xs[k]'(by haveI := List.minIdxOn_lt_length (f := f) h; omega)) := by
simp only [ not_le] at hk
apply hk.imp
apply List.minIdxOn_le_of_apply_getElem_le_apply_minOn
@[simp]
protected theorem getElem_minIdxOn [LE β] [DecidableLE β] [IsLinearPreorder β]
{f : α β} {xs : List α} (h : xs []) :
xs[xs.minIdxOn f h] = xs.minOn f h := by
obtain i, hlt, hi, heq, h' := minIdxOn_eq_go_drop (f := f) h (k := xs.length)
simp only [drop_eq_nil_of_le (as := xs) (i := xs.length + 1) (by omega), minIdxOn.go] at h'
simp [h', heq, take_of_length_le (l := xs) (i := xs.length + 1) (by omega)]
protected theorem le_minIdxOn_of_apply_getElem_lt_apply_getElem [LE β] [DecidableLE β] [LT β] [IsLinearPreorder β]
[LawfulOrderLT β] {f : α β} {xs : List α} (h : xs []) {i : Nat} (hi : i < xs.length)
(hi' : j, (_ : j < i) f xs[i] < f xs[j]) :
i xs.minIdxOn f h := by
false_or_by_contra; rename_i hgt
simp only [not_le] at hgt
specialize hi' _ hgt
simp only [List.getElem_minIdxOn] at hi'
apply (not_le.mpr hi').elim
apply List.apply_minOn_le_of_mem
simp
protected theorem minIdxOn_le_of_apply_getElem_le_apply_getElem [LE β] [DecidableLE β] [IsLinearPreorder β]
{f : α β} {xs : List α} (h : xs []) {i : Nat} (hi : i < xs.length)
(hi' : j, (_ : j < xs.length) f xs[i] f xs[j]) :
xs.minIdxOn f h i := by
apply List.minIdxOn_le_of_apply_getElem_le_apply_minOn h hi
simp only [List.le_apply_minOn_iff, List.mem_iff_getElem]
rintro _ j, hj, rfl
exact hi' _ hj
protected theorem minIdxOn_eq_iff [LE β] [DecidableLE β] [LT β] [IsLinearPreorder β]
[LawfulOrderLT β]
{f : α β} {xs : List α} (h : xs []) {i : Nat} :
xs.minIdxOn f h = i (h : i < xs.length),
( j, (_ : j < xs.length) f xs[i] f xs[j])
( j, (_ : j < i) f xs[i] < f xs[j]) := by
apply Iff.intro
· rintro rfl
simp only [List.getElem_minIdxOn]
refine List.minIdxOn_lt_length h, ?_, ?_
· simp [List.apply_minOn_le_of_mem]
· exact fun j hj => List.apply_minOn_lt_apply_getElem_of_lt_minIdxOn h hj
· rintro hi, h₁, h₂
apply le_antisymm
· apply List.minIdxOn_le_of_apply_getElem_le_apply_getElem h hi h₁
· apply List.le_minIdxOn_of_apply_getElem_lt_apply_getElem h hi h₂
protected theorem minIdxOn_eq_iff_eq_minOn [LE β] [DecidableLE β] [LT β] [IsLinearPreorder β]
[LawfulOrderLT β] {f : α β} {xs : List α} (h : xs []) {i : Nat} :
xs.minIdxOn f h = i hi : i < xs.length, xs[i] = xs.minOn f h
(j : Nat) (hj : j < i), f (xs.minOn f h) < f xs[j] := by
apply Iff.intro
· rintro rfl
refine List.minIdxOn_lt_length h, List.getElem_minIdxOn h, ?_
intro j hj
exact List.apply_minOn_lt_apply_getElem_of_lt_minIdxOn h hj
· rintro hlt, heq, h'
specialize h' (xs.minIdxOn f h)
simp only [List.getElem_minIdxOn] at h'
apply le_antisymm
· apply List.minIdxOn_le_of_apply_getElem_le_apply_minOn h hlt
simp [heq, le_refl]
· simpa [lt_irrefl] using h'
private theorem minIdxOn.go_eq
[LE β] [DecidableLE β] [IsLinearPreorder β] {x : α} {xs : List α} {f : α β} :
List.minIdxOn.go f x i j xs =
if h : xs = [] then i
else if f x f (xs.minOn f h) then i
else (xs.minIdxOn f h) + j := by
open scoped Classical.Order in
induction xs generalizing x i j
· simp [go]
· rename_i y ys ih
simp only [go, reduceCtorEq, reduceDIte]
split
· rw [ih]
split
· simp [*]
· simp only [List.minOn_cons, reduceDIte, le_apply_minOn_iff, true_and, *]
split
· rfl
· rename_i hlt
simp only [minIdxOn]
split
simp only [ih, reduceCtorEq, reduceDIte]
rw [if_neg]
· simp [minIdxOn, Nat.add_assoc, Nat.add_comm 1]
· simp only [not_le] at hlt
exact lt_of_lt_of_le hlt _
· rename_i hlt
rw [if_neg]
· rw [minIdxOn, ih]
split
· simp [*, go]
· simp only [reduceDIte, *]
split
· simp
· simp only [Nat.add_assoc, Nat.add_comm 1]
· simp only [not_le] at hlt
exact lt_of_le_of_lt (List.apply_minOn_le_of_mem mem_cons_self) hlt
protected theorem minIdxOn_cons
[LE β] [DecidableLE β] [IsLinearPreorder β] {x : α} {xs : List α} {f : α β} :
(x :: xs).minIdxOn f (by exact of_decide_eq_false rfl) =
if h : xs = [] then 0
else if f x f (xs.minOn f h) then 0
else (xs.minIdxOn f h) + 1 := by
simpa [List.minIdxOn] using minIdxOn.go_eq
protected theorem minIdxOn_eq_zero_iff [LE β] [DecidableLE β] [IsLinearPreorder β]
{xs : List α} {f : α β} (h : xs []) :
xs.minIdxOn f h = 0 x xs, f (xs.head h) f x := by
rw [minIdxOn.eq_def]
split
rename_i y ys _
simp only [mem_cons, head_cons, forall_eq_or_imp, le_refl, true_and]
apply Iff.intro
· intro h
cases ys
· simp
· intro a ha
refine le_trans ?_ (List.apply_minOn_le_of_mem ha)
simpa [minIdxOn.go_eq] using h
· intro h
cases ys
· simp [minIdxOn.go]
· simpa [minIdxOn.go_eq, List.le_apply_minOn_iff] using h
section Append
/-!
The proof of {name}`List.minOn_append` uses associativity of {name}`minOn` and applies {name}`foldl_assoc`.
The proof of {name (scope := "Init.Data.List.MinMaxIdx")}`minIdxOn_append` is analogous, but the
aggregation operation, {name (scope := "Init.Data.List.MinMaxIdx")}`combineMinIdxOn`, depends on
the length of the lists to combine. After proving associativity of the aggregation operation,
the proof closely follows the proof of {name}`foldl_assoc`.
-/
private def combineMinIdxOn [LE β] [DecidableLE β]
(f : α β) {xs ys : List α} (i j : Nat) (hi : i < xs.length) (hj : j < ys.length) : Nat :=
if f xs[i] f ys[j] then
i
else
xs.length + j
private theorem combineMinIdxOn_lt [LE β] [DecidableLE β]
(f : α β) {xs ys : List α} {i j : Nat} (hi : i < xs.length) (hj : j < ys.length) :
combineMinIdxOn f i j hi hj < (xs ++ ys).length := by
simp only [combineMinIdxOn]
split <;> (simp; omega)
private theorem combineMinIdxOn_assoc [LE β] [DecidableLE β] [IsLinearPreorder β]
{xs ys zs : List α} {i j k : Nat} {f : α β} (hi : i < xs.length) (hj : j < ys.length)
(hk : k < zs.length) :
combineMinIdxOn f (combineMinIdxOn f i j _ _) k
(combineMinIdxOn_lt f hi hj) hk = combineMinIdxOn f i (combineMinIdxOn f j k _ _) hi (combineMinIdxOn_lt f hj hk) := by
open scoped Classical.Order in
simp only [combineMinIdxOn]
split
· rw [getElem_append_left (by omega)]
split
· split
· rw [getElem_append_left (by omega)]
simp [*]
· rw [getElem_append_right (by omega)]
simp [*]
· split
· have := le_trans f xs[i] f ys[j] f ys[j] f zs[k]
contradiction
· rw [getElem_append_right (by omega)]
simp [*, Nat.add_assoc]
· rw [getElem_append_right (by omega)]
simp only [Nat.add_sub_cancel_left]
split
· rw [getElem_append_left (by omega), if_neg _]
· rename_i h₁ h₂
rw [not_le] at h₁ h₂
rw [getElem_append_right (by omega)]
simp only [Nat.add_sub_cancel_left]
have := not_le.mpr <| lt_trans h₂ h₁
simp [*, Nat.add_assoc]
private theorem minIdxOn_cons_aux [LE β] [DecidableLE β]
[IsLinearPreorder β] {x : α} {xs : List α} {f : α β} (hxs : xs []) :
(x :: xs).minIdxOn f (by simp) =
combineMinIdxOn f _ _
(List.minIdxOn_lt_length (f := f) (cons_ne_nil x []))
(List.minIdxOn_lt_length (f := f) hxs) := by
rw [minIdxOn, combineMinIdxOn]
simp [minIdxOn.go_eq, hxs, List.getElem_minIdxOn, Nat.add_comm 1]
private theorem minIdxOn_append_aux [LE β] [DecidableLE β]
[IsLinearPreorder β] {xs ys : List α} {f : α β} (hxs : xs []) (hys : ys []) :
(xs ++ ys).minIdxOn f (by simp [hxs]) =
combineMinIdxOn f _ _
(List.minIdxOn_lt_length (f := f) hxs)
(List.minIdxOn_lt_length (f := f) hys) := by
induction xs
· contradiction
· rename_i x xs ih
match xs with
| [] => simp [minIdxOn_cons_aux (xs := ys) _]
| z :: zs =>
simp +singlePass only [cons_append]
simp only [minIdxOn_cons_aux (xs := z :: zs ++ ys) (by simp), ih (by simp),
minIdxOn_cons_aux (xs := z :: zs) (by simp), combineMinIdxOn_assoc]
protected theorem minIdxOn_append [LE β] [DecidableLE β] [IsLinearPreorder β]
{xs ys : List α} {f : α β} (hxs : xs []) (hys : ys []) :
(xs ++ ys).minIdxOn f (by simp [hxs]) =
if f (xs.minOn f hxs) f (ys.minOn f hys) then
xs.minIdxOn f hxs
else
xs.length + ys.minIdxOn f hys := by
simp [minIdxOn_append_aux hxs hys, combineMinIdxOn, List.getElem_minIdxOn]
end Append
protected theorem left_le_minIdxOn_append [LE β] [DecidableLE β] [IsLinearPreorder β]
{xs ys : List α} {f : α β} (h : xs []) :
xs.minIdxOn f h (xs ++ ys).minIdxOn f (by simp [h]) := by
by_cases hys : ys = []
· simp [hys]
· rw [List.minIdxOn_append h hys]
split
· apply Nat.le_refl
· have := List.minIdxOn_lt_length (f := f) h
omega
protected theorem minIdxOn_take_le [LE β] [DecidableLE β] [IsLinearPreorder β]
{xs : List α} {f : α β} {i : Nat} (h : xs.take i []) :
(xs.take i).minIdxOn f h xs.minIdxOn f (List.ne_nil_of_take_ne_nil h) := by
have := take_append_drop i xs
conv => rhs; simp +singlePass only [ this]
apply List.left_le_minIdxOn_append
@[simp]
protected theorem minIdxOn_replicate [LE β] [DecidableLE β] [Refl (α := β) (· ·)]
{n : Nat} {a : α} {f : α β} (h : replicate n a []) :
(replicate n a).minIdxOn f h = 0 := by
match n with
| 0 => simp at h
| n + 1 =>
simp only [minIdxOn, replicate_succ]
generalize 1 = j
induction n generalizing j
· simp [minIdxOn.go]
· simp only [replicate_succ, minIdxOn.go] at *
split
· simp [*]
· have := le_refl (f a)
contradiction
@[simp]
protected theorem maxIdxOn_nil_eq_iff_true [LE β] [DecidableLE β] {f : α β} {x : Nat}
(h : [] []) : ([] : List α).maxIdxOn f h = x True :=
nomatch h
protected theorem maxIdxOn_nil_eq_iff_false [LE β] [DecidableLE β] {f : α β} {x : Nat}
(h : [] []) : ([] : List α).maxIdxOn f h = x False :=
nomatch h
@[simp]
protected theorem maxIdxOn_singleton [LE β] [DecidableLE β] {x : α} {f : α β} :
[x].maxIdxOn f (of_decide_eq_false rfl) = 0 :=
letI : LE β := (inferInstanceAs (LE β)).opposite
List.minIdxOn_singleton
@[simp]
protected theorem maxIdxOn_lt_length [LE β] [DecidableLE β] {f : α β} {xs : List α}
(h : xs []) : xs.maxIdxOn f h < xs.length :=
letI : LE β := (inferInstanceAs (LE β)).opposite
List.minIdxOn_lt_length h
protected theorem maxIdxOn_le_of_apply_getElem_le_apply_maxOn [LE β] [DecidableLE β] [IsLinearPreorder β]
{f : α β} {xs : List α} (h : xs [])
{k : Nat} (hi : k < xs.length) (hle : f (xs.maxOn f h) f xs[k]) :
xs.maxIdxOn f h k := by
simp only [List.maxIdxOn_eq_minIdxOn, List.maxOn_eq_minOn] at hle
letI : LE β := (inferInstanceAs (LE β)).opposite
exact List.minIdxOn_le_of_apply_getElem_le_apply_minOn h hi (by simpa [LE.le_opposite_iff] using hle)
protected theorem apply_maxOn_lt_apply_getElem_of_lt_maxIdxOn [LE β] [DecidableLE β] [LT β] [IsLinearPreorder β]
[LawfulOrderLT β]
{f : α β} {xs : List α} (h : xs [])
{k : Nat} (hk : k < xs.maxIdxOn f h) :
f (xs[k]'(by haveI := List.maxIdxOn_lt_length (f := f) h; omega)) < f (xs.maxOn f h) := by
simp only [List.maxIdxOn_eq_minIdxOn, List.maxOn_eq_minOn] at hk
letI : LE β := LE.opposite inferInstance
letI : LT β := LT.opposite inferInstance
simpa [LT.lt_opposite_iff] using List.apply_minOn_lt_apply_getElem_of_lt_minIdxOn (f := f) h hk
@[simp]
protected theorem getElem_maxIdxOn [LE β] [DecidableLE β] [IsLinearPreorder β]
{f : α β} {xs : List α} (h : xs []) :
xs[xs.maxIdxOn f h] = xs.maxOn f h := by
simp only [List.maxIdxOn_eq_minIdxOn, List.maxOn_eq_minOn]
letI : LE β := (inferInstanceAs (LE β)).opposite
exact List.getElem_minIdxOn h
protected theorem le_maxIdxOn_of_apply_getElem_lt_apply_getElem [LE β] [DecidableLE β] [LT β]
[IsLinearPreorder β] [LawfulOrderLT β] {f : α β} {xs : List α} (h : xs []) {i : Nat}
(hi : i < xs.length) (hi' : j, (_ : j < i) f xs[j] < f xs[i]) :
i xs.maxIdxOn f h := by
simp only [List.maxIdxOn_eq_minIdxOn]
letI : LE β := LE.opposite inferInstance
letI : LT β := LT.opposite inferInstance
simpa [LE.le_opposite_iff] using List.le_minIdxOn_of_apply_getElem_lt_apply_getElem h hi
(by simpa [LT.lt_opposite_iff] using hi')
protected theorem maxIdxOn_le_of_apply_getElem_le_apply_getElem [LE β] [DecidableLE β] [IsLinearPreorder β]
{f : α β} {xs : List α} (h : xs []) {i : Nat} (hi : i < xs.length)
(hi' : j, (_ : j < xs.length) f xs[j] f xs[i]) :
xs.maxIdxOn f h i := by
simp only [List.maxIdxOn_eq_minIdxOn]
letI : LE β := LE.opposite inferInstance
simpa [LE.le_opposite_iff] using List.minIdxOn_le_of_apply_getElem_le_apply_getElem (f := f) h hi
(by simpa [LE.le_opposite_iff] using hi')
protected theorem maxIdxOn_eq_iff [LE β] [DecidableLE β] [LT β] [IsLinearPreorder β]
[LawfulOrderLT β]
{f : α β} {xs : List α} (h : xs []) {i : Nat} :
xs.maxIdxOn f h = i (h : i < xs.length),
( j, (_ : j < xs.length) f xs[j] f xs[i])
( j, (_ : j < i) f xs[j] < f xs[i]) := by
simp only [List.maxIdxOn_eq_minIdxOn]
letI : LE β := LE.opposite inferInstance
letI : LT β := LT.opposite inferInstance
simpa [LE.le_opposite_iff, LT.lt_opposite_iff] using List.minIdxOn_eq_iff (f := f) h
protected theorem maxIdxOn_eq_iff_eq_maxOn [LE β] [DecidableLE β] [LT β] [IsLinearPreorder β]
[LawfulOrderLT β] {f : α β} {xs : List α} (h : xs []) {i : Nat} :
xs.maxIdxOn f h = i hi : i < xs.length, xs[i] = xs.maxOn f h
(j : Nat) (hj : j < i), f xs[j] < f (xs.maxOn f h) := by
simp only [List.maxIdxOn_eq_minIdxOn, List.maxOn_eq_minOn]
letI : LE β := LE.opposite inferInstance
letI : LT β := LT.opposite inferInstance
simpa [LT.lt_opposite_iff] using List.minIdxOn_eq_iff_eq_minOn (f := f) h
protected theorem maxIdxOn_cons
[LE β] [DecidableLE β] [IsLinearPreorder β] {x : α} {xs : List α} {f : α β} :
(x :: xs).maxIdxOn f (by exact of_decide_eq_false rfl) =
if h : xs = [] then 0
else if f (xs.maxOn f h) f x then 0
else (xs.maxIdxOn f h) + 1 := by
simp only [List.maxIdxOn_eq_minIdxOn, List.maxOn_eq_minOn]
letI : LE β := (inferInstanceAs (LE β)).opposite
simpa [LE.le_opposite_iff] using List.minIdxOn_cons (f := f)
protected theorem maxIdxOn_eq_zero_iff [LE β] [DecidableLE β] [IsLinearPreorder β]
{xs : List α} {f : α β} (h : xs []) :
xs.maxIdxOn f h = 0 x xs, f x f (xs.head h) := by
simp only [List.maxIdxOn_eq_minIdxOn]
letI : LE β := (inferInstanceAs (LE β)).opposite
simpa [LE.le_opposite_iff] using List.minIdxOn_eq_zero_iff h (f := f)
protected theorem maxIdxOn_append [LE β] [DecidableLE β] [IsLinearPreorder β]
{xs ys : List α} {f : α β} (hxs : xs []) (hys : ys []) :
(xs ++ ys).maxIdxOn f (by simp [hxs]) =
if f (ys.maxOn f hys) f (xs.maxOn f hxs) then
xs.maxIdxOn f hxs
else
xs.length + ys.maxIdxOn f hys := by
simp only [List.maxIdxOn_eq_minIdxOn, List.maxOn_eq_minOn]
letI : LE β := (inferInstanceAs (LE β)).opposite
simpa [LE.le_opposite_iff] using List.minIdxOn_append hxs hys (f := f)
protected theorem left_le_maxIdxOn_append [LE β] [DecidableLE β] [IsLinearPreorder β]
{xs ys : List α} {f : α β} (h : xs []) :
xs.maxIdxOn f h (xs ++ ys).maxIdxOn f (by simp [h]) :=
letI : LE β := (inferInstanceAs (LE β)).opposite
List.left_le_minIdxOn_append h
protected theorem maxIdxOn_take_le [LE β] [DecidableLE β] [IsLinearPreorder β]
{xs : List α} {f : α β} {i : Nat} (h : xs.take i []) :
(xs.take i).maxIdxOn f h xs.maxIdxOn f (List.ne_nil_of_take_ne_nil h) :=
letI : LE β := (inferInstanceAs (LE β)).opposite
List.minIdxOn_take_le h
@[simp]
protected theorem maxIdxOn_replicate [LE β] [DecidableLE β] [Refl (α := β) (· ·)]
{n : Nat} {a : α} {f : α β} (h : replicate n a []) :
(replicate n a).maxIdxOn f h = 0 :=
letI : LE β := (inferInstanceAs (LE β)).opposite
List.minIdxOn_replicate h
@[simp]
protected theorem minIdxOn?_nil [LE β] [DecidableLE β] {f : α β} :
([] : List α).minIdxOn? f = none :=
(rfl)
@[simp]
protected theorem minIdxOn?_singleton [LE β] [DecidableLE β] {x : α} {f : α β} :
[x].minIdxOn? f = some 0 :=
(rfl)
@[simp]
protected theorem isSome_minIdxOn?_iff [LE β] [DecidableLE β] {f : α β} {xs : List α} :
(xs.minIdxOn? f).isSome xs [] := by
cases xs <;> simp [minIdxOn?]
protected theorem minIdxOn_eq_get_minIdxOn? [LE β] [DecidableLE β] {f : α β} {xs : List α}
(h : xs []) : xs.minIdxOn f h = (xs.minIdxOn? f).get (List.isSome_minIdxOn?_iff.mpr h) := by
match xs with
| [] => contradiction
| _ :: _ => simp [minIdxOn?]
@[simp]
protected theorem get_minIdxOn?_eq_minIdxOn [LE β] [DecidableLE β] {f : α β} {xs : List α}
(h : (xs.minIdxOn? f).isSome) :
(xs.minIdxOn? f).get h = xs.minIdxOn f (List.isSome_minIdxOn?_iff.mp h) := by
rw [List.minIdxOn_eq_get_minIdxOn?]
protected theorem minIdxOn?_eq_some_minIdxOn [LE β] [DecidableLE β] {f : α β} {xs : List α}
(h : xs []) : xs.minIdxOn? f = some (xs.minIdxOn f h) := by
match xs with
| [] => contradiction
| _ :: _ => simp [minIdxOn?]
protected theorem minIdxOn_eq_of_minIdxOn?_eq_some
[LE β] [DecidableLE β] {f : α β} {xs : List α} {i : Nat} (h : xs.minIdxOn? f = some i) :
xs.minIdxOn f (List.isSome_minIdxOn?_iff.mp (Option.isSome_of_eq_some h)) = i := by
have h' := List.isSome_minIdxOn?_iff.mp (Option.isSome_of_eq_some h)
rwa [List.minIdxOn?_eq_some_minIdxOn h', Option.some.injEq] at h
protected theorem isSome_minIdxOn?_of_mem
[LE β] [DecidableLE β] {f : α β} {xs : List α} {x : α} (h : x xs) :
(xs.minIdxOn? f).isSome := by
apply List.isSome_minIdxOn?_iff.mpr
exact ne_nil_of_mem h
protected theorem minIdxOn?_cons_eq_some_minIdxOn
[LE β] [DecidableLE β] {f : α β} {x : α} {xs : List α} :
(x :: xs).minIdxOn? f = some ((x :: xs).minIdxOn f (nomatch ·)) := by
simp [List.minIdxOn?_eq_some_minIdxOn]
protected theorem minIdxOn?_eq_if
[LE β] [DecidableLE β] {f : α β} {xs : List α} :
xs.minIdxOn? f =
if h : xs [] then
some (xs.minIdxOn f h)
else
none := by
cases xs <;> simp [List.minIdxOn?_cons_eq_some_minIdxOn]
protected theorem minIdxOn?_cons
[LE β] [DecidableLE β] [IsLinearPreorder β] {f : α β} {x : α} {xs : List α} :
(x :: xs).minIdxOn? f = some
(if h : xs = [] then 0
else if f x f (xs.minOn f h) then 0
else (xs.minIdxOn f h) + 1) := by
simp [List.minIdxOn?_eq_some_minIdxOn, List.minIdxOn_cons]
protected theorem ne_nil_of_minIdxOn?_eq_some
[LE β] [DecidableLE β] {f : α β} {k : Nat} {xs : List α} (h : xs.minIdxOn? f = some k) :
xs [] := by
rintro rfl
simp at h
protected theorem lt_length_of_minIdxOn?_eq_some [LE β] [DecidableLE β] {f : α β}
{xs : List α} (h : xs.minIdxOn? f = some i) : i < xs.length := by
have hne : xs [] := List.ne_nil_of_minIdxOn?_eq_some h
rw [List.minIdxOn?_eq_some_minIdxOn hne] at h
have := List.minIdxOn_lt_length (f := f) hne
simp_all
@[simp]
protected theorem get_minIdxOn?_lt_length [LE β] [DecidableLE β] {f : α β} {xs : List α}
(h : (xs.minIdxOn? f).isSome) : (xs.minIdxOn? f).get h < xs.length := by
rw [List.get_minIdxOn?_eq_minIdxOn]
apply List.minIdxOn_lt_length
@[simp]
protected theorem getElem_get_minIdxOn? [LE β] [DecidableLE β] [IsLinearPreorder β]
{f : α β} {xs : List α} (h : (xs.minIdxOn? f).isSome) :
xs[(xs.minIdxOn? f).get h] = xs.minOn f (List.isSome_minIdxOn?_iff.mp h) := by
rw [getElem_congr rfl (List.get_minIdxOn?_eq_minIdxOn _), List.getElem_minIdxOn]
protected theorem minIdxOn?_eq_some_zero_iff [LE β] [DecidableLE β] [IsLinearPreorder β]
{xs : List α} {f : α β} :
xs.minIdxOn? f = some 0 h : xs [], x xs, f (xs.head h) f x := by
simp [Option.eq_some_iff_get_eq, List.minIdxOn_eq_zero_iff]
protected theorem minIdxOn?_replicate [LE β] [DecidableLE β] [Refl (α := β) (· ·)]
{n : Nat} {a : α} {f : α β} :
(replicate n a).minIdxOn? f = if n = 0 then none else some 0 := by
simp [List.minIdxOn?_eq_if]
@[simp]
protected theorem minIdxOn?_replicate_of_pos [LE β] [DecidableLE β] [Refl (α := β) (· ·)]
{n : Nat} {a : α} {f : α β} (h : 0 < n) :
(replicate n a).minIdxOn? f = some 0 := by
simp [List.minIdxOn?_replicate, Nat.ne_zero_of_lt h]
/-! ### maxIdxOn? -/
protected theorem maxIdxOn?_eq_minIdxOn? {le : LE β} {_ : DecidableLE β} {f : α β}
{xs : List α} :
xs.maxIdxOn? f = (letI := le.opposite; xs.minIdxOn? f) :=
(rfl)
@[simp]
protected theorem maxIdxOn?_nil [LE β] [DecidableLE β] {f : α β} :
([] : List α).maxIdxOn? f = none :=
letI : LE β := LE.opposite inferInstance
List.minIdxOn?_nil
@[simp]
protected theorem maxIdxOn?_singleton [LE β] [DecidableLE β] {x : α} {f : α β} :
[x].maxIdxOn? f = some 0 :=
letI : LE β := LE.opposite inferInstance
List.minIdxOn?_singleton
@[simp]
protected theorem isSome_maxIdxOn?_iff [LE β] [DecidableLE β] {f : α β} {xs : List α} :
(xs.maxIdxOn? f).isSome xs [] := by
letI : LE β := LE.opposite inferInstance
exact List.isSome_minIdxOn?_iff
protected theorem maxIdxOn_eq_get_maxIdxOn? [LE β] [DecidableLE β] {f : α β} {xs : List α}
(h : xs []) : xs.maxIdxOn f h = (xs.maxIdxOn? f).get (List.isSome_maxIdxOn?_iff.mpr h) := by
letI : LE β := LE.opposite inferInstance
exact List.minIdxOn_eq_get_minIdxOn? h
@[simp]
protected theorem get_maxIdxOn?_eq_maxIdxOn [LE β] [DecidableLE β] {f : α β} {xs : List α}
(h : (xs.maxIdxOn? f).isSome) :
(xs.maxIdxOn? f).get h = xs.maxIdxOn f (List.isSome_maxIdxOn?_iff.mp h) := by
letI : LE β := LE.opposite inferInstance
exact List.get_minIdxOn?_eq_minIdxOn h
protected theorem maxIdxOn?_eq_some_maxIdxOn [LE β] [DecidableLE β] {f : α β} {xs : List α}
(h : xs []) : xs.maxIdxOn? f = some (xs.maxIdxOn f h) := by
letI : LE β := LE.opposite inferInstance
exact List.minIdxOn?_eq_some_minIdxOn h
protected theorem maxIdxOn_eq_of_maxIdxOn?_eq_some
[LE β] [DecidableLE β] {f : α β} {xs : List α} {i : Nat} (h : xs.maxIdxOn? f = some i) :
xs.maxIdxOn f (List.isSome_maxIdxOn?_iff.mp (Option.isSome_of_eq_some h)) = i := by
letI : LE β := LE.opposite inferInstance
exact List.minIdxOn_eq_of_minIdxOn?_eq_some h
protected theorem isSome_maxIdxOn?_of_mem
[LE β] [DecidableLE β] {f : α β} {xs : List α} {x : α} (h : x xs) :
(xs.maxIdxOn? f).isSome := by
letI : LE β := LE.opposite inferInstance
exact List.isSome_minIdxOn?_of_mem h
protected theorem maxIdxOn?_cons_eq_some_maxIdxOn
[LE β] [DecidableLE β] {f : α β} {x : α} {xs : List α} :
(x :: xs).maxIdxOn? f = some ((x :: xs).maxIdxOn f (nomatch ·)) := by
letI : LE β := LE.opposite inferInstance
exact List.minIdxOn?_cons_eq_some_minIdxOn
protected theorem maxIdxOn?_eq_if
[LE β] [DecidableLE β] {f : α β} {xs : List α} :
xs.maxIdxOn? f =
if h : xs [] then
some (xs.maxIdxOn f h)
else
none := by
letI : LE β := LE.opposite inferInstance
exact List.minIdxOn?_eq_if
protected theorem maxIdxOn?_cons
[LE β] [DecidableLE β] [IsLinearPreorder β] {f : α β} {x : α} {xs : List α} :
(x :: xs).maxIdxOn? f = some
(if h : xs = [] then 0
else if f (xs.maxOn f h) f x then 0
else (xs.maxIdxOn f h) + 1) := by
simp only [List.maxIdxOn_eq_minIdxOn, List.maxOn_eq_minOn]
letI : LE β := LE.opposite inferInstance
simpa [LE.le_opposite_iff] using List.minIdxOn?_cons (f := f)
protected theorem ne_nil_of_maxIdxOn?_eq_some
[LE β] [DecidableLE β] {f : α β} {k : Nat} {xs : List α} (h : xs.maxIdxOn? f = some k) :
xs [] := by
letI : LE β := LE.opposite inferInstance
exact List.ne_nil_of_minIdxOn?_eq_some (by simpa only [List.maxIdxOn?_eq_minIdxOn?] using h)
protected theorem lt_length_of_maxIdxOn?_eq_some [LE β] [DecidableLE β] {f : α β}
{xs : List α} (h : xs.maxIdxOn? f = some i) : i < xs.length := by
letI : LE β := LE.opposite inferInstance
exact List.lt_length_of_minIdxOn?_eq_some h
@[simp]
protected theorem get_maxIdxOn?_lt_length [LE β] [DecidableLE β] {f : α β} {xs : List α}
(h : (xs.maxIdxOn? f).isSome) : (xs.maxIdxOn? f).get h < xs.length := by
letI : LE β := LE.opposite inferInstance
exact List.get_minIdxOn?_lt_length h
@[simp]
protected theorem getElem_get_maxIdxOn? [LE β] [DecidableLE β] [IsLinearPreorder β]
{f : α β} {xs : List α} (h : (xs.maxIdxOn? f).isSome) :
xs[(xs.maxIdxOn? f).get h] = xs.maxOn f (List.isSome_maxIdxOn?_iff.mp h) := by
simp only [List.maxIdxOn?_eq_minIdxOn?, List.maxOn_eq_minOn]
letI : LE β := LE.opposite inferInstance
exact List.getElem_get_minIdxOn? h
protected theorem maxIdxOn?_eq_some_zero_iff [LE β] [DecidableLE β] [IsLinearPreorder β]
{xs : List α} {f : α β} :
xs.maxIdxOn? f = some 0 h : xs [], x xs, f x f (xs.head h) := by
simp only [List.maxIdxOn?_eq_minIdxOn?]
letI : LE β := LE.opposite inferInstance
simpa [LE.le_opposite_iff] using List.minIdxOn?_eq_some_zero_iff (f := f)
protected theorem maxIdxOn?_replicate [LE β] [DecidableLE β] [Refl (α := β) (· ·)]
{n : Nat} {a : α} {f : α β} :
(replicate n a).maxIdxOn? f = if n = 0 then none else some 0 := by
letI : LE β := LE.opposite inferInstance
exact List.minIdxOn?_replicate
@[simp]
protected theorem maxIdxOn?_replicate_of_pos [LE β] [DecidableLE β] [Refl (α := β) (· ·)]
{n : Nat} {a : α} {f : α β} (h : 0 < n) :
(replicate n a).maxIdxOn? f = some 0 := by
letI : LE β := LE.opposite inferInstance
exact List.minIdxOn?_replicate_of_pos h
end List

View File

@@ -0,0 +1,623 @@
/-
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Reichert
-/
module
prelude
public import Init.Data.Order.MinMaxOn
public import Init.Data.Int.OfNat
public import Init.Data.List.Lemmas
public import Init.Data.List.TakeDrop
import Init.Data.Order.Lemmas
import Init.Data.List.Sublist
import Init.Data.List.MinMax
import Init.Data.Order.Opposite
set_option doc.verso true
set_option linter.missingDocs true
set_option linter.listVariables true -- Enforce naming conventions for `List`/`Array`/`Vector` variables.
set_option linter.indexVariables true -- Enforce naming conventions for index variables.
public section
open Std
open scoped OppositeOrderInstances
namespace List
/--
Returns an element of the non-empty list {name}`l` that minimizes {name}`f`. If {given}`x, y` are
such that {lean}`f x = f y`, it returns whichever comes first in the list.
The correctness of this function assumes {name}`β` to be linearly pre-ordered.
The property that {name}`List.minOn` is the first minimizer in the list is guaranteed by the lemma
{name (scope := "Init.Data.List.MinMaxIdx")}`List.getElem_minIdxOn`.
-/
@[inline, suggest_for List.argmin]
protected def minOn [LE β] [DecidableLE β] (f : α β) (l : List α) (h : l []) : α :=
match l with
| x :: xs => xs.foldl (init := x) (minOn f)
/--
Returns an element of the non-empty list {name}`l` that maximizes {name}`f`. If {given}`x, y` are
such that {lean}`f x = f y`, it returns whichever comes first in the list.
The correctness of this function assumes {name}`β` to be linearly pre-ordered.
The property that {name}`List.maxOn` is the first maximizer in the list is guaranteed by the lemma
{name (scope := "Init.Data.List.MinMaxIdx")}`List.getElem_maxIdxOn`.
-/
@[inline, suggest_for List.argmax]
protected def maxOn [i : LE β] [DecidableLE β] (f : α β) (l : List α) (h : l []) : α :=
letI : LE β := i.opposite
l.minOn f h
/--
Returns an element of {name}`l` that minimizes {name}`f`. If {given}`x, y` are such that
{lean}`f x = f y`, it returns whichever comes first in the list. Returns {name}`none` if the list is
empty.
The correctness of this function assumes {name}`β` to be linearly pre-ordered.
The property that {name}`List.minOn?` is the first minimizer in the list is guaranteed by the lemma
{name (scope := "Init.Data.List.MinMaxIdx")}`List.getElem_get_minIdxOn?`
-/
@[inline, suggest_for List.argmin? List.argmin] -- Mathlib's `List.argmin` returns an `Option α`
protected def minOn? [LE β] [DecidableLE β] (f : α β) (l : List α) : Option α :=
match l with
| [] => none
| x :: xs => some (xs.foldl (init := x) (minOn f))
/--
Returns an element of {name}`l` that maximizes {name}`f`. If {given}`x, y` are such that
{lean}`f x = f y`, it returns whichever comes first in the list. Returns {name}`none` if the list is
empty.
The correctness of this function assumes {name}`β` to be linearly pre-ordered.
The property that {name}`List.maxOn?` is the first minimizer in the list is guaranteed by the lemma
{name (scope := "Init.Data.List.MinMaxIdx")}`List.getElem_get_maxIdxOn?`.
-/
@[inline, suggest_for List.argmax? List.argmax] -- Mathlib's `List.argmax` returns an `Option α`
protected def maxOn? [i : LE β] [DecidableLE β] (f : α β) (l : List α) : Option α :=
letI : LE β := i.opposite
l.minOn? f
/-! ### minOn -/
@[simp]
protected theorem minOn_singleton [LE β] [DecidableLE β] {x : α} {f : α β} :
[x].minOn f (of_decide_eq_false rfl) = x := by
simp [List.minOn]
protected theorem minOn_cons
[LE β] [DecidableLE β] [IsLinearPreorder β] {x : α} {xs : List α} {f : α β} :
(x :: xs).minOn f (by exact of_decide_eq_false rfl) =
if h : xs = [] then x else minOn f x (xs.minOn f h) := by
simp only [List.minOn]
match xs with
| [] => simp
| y :: xs => simp [foldl_assoc]
@[simp]
protected theorem minOn_id [Min α] [LE α] [DecidableLE α] [LawfulOrderLeftLeaningMin α]
{xs : List α} (h : xs []) :
xs.minOn id h = xs.min h := by
have : minOn (α := α) id = min := by ext; apply minOn_id
simp only [List.minOn, List.min, this]
match xs with
| _ :: _ => simp
@[simp]
protected theorem minOn_mem [LE β] [DecidableLE β] {xs : List α}
{f : α β} {h : xs []} : xs.minOn f h xs := by
simp only [List.minOn]
match xs with
| x :: xs =>
fun_induction xs.foldl (init := x) (_root_.minOn f)
· simp
· rename_i x y _ ih
simp only [ne_eq, reduceCtorEq, not_false_eq_true, mem_cons, forall_const, foldl_cons] at ih
cases ih <;> rename_i heq
· cases minOn_eq_or (f := f) (x := x) (y := y) <;> simp_all
· simp [*]
protected theorem apply_minOn_le_of_mem [LE β] [DecidableLE β] [IsLinearPreorder β]
{xs : List α} {f : α β} {y : α} (hx : y xs) :
f (xs.minOn f (List.ne_nil_of_mem hx)) f y := by
have h : xs [] := List.ne_nil_of_mem hx
simp only [List.minOn]
match xs with
| x :: xs =>
fun_induction xs.foldl (init := x) (_root_.minOn f) generalizing y
· simp only [mem_cons] at hx
simp_all [le_refl _]
· rename_i x y _ ih
simp at ih
rcases mem_cons.mp hx with rfl | hx
· exact le_trans ih.1 apply_minOn_le_left
· rcases mem_cons.mp hx with rfl | hx
· exact le_trans ih.1 apply_minOn_le_right
· apply ih.2
assumption
protected theorem le_apply_minOn_iff [LE β] [DecidableLE β] [IsLinearPreorder β]
{xs : List α} {f : α β} (h : xs []) {b : β} :
b f (xs.minOn f h) x xs, b f x := by
match xs with
| x :: xs =>
rw [List.minOn]
induction xs generalizing x
· simp
· rw [foldl_cons, foldl_assoc, le_apply_minOn_iff]
simp_all
protected theorem apply_minOn_le_iff [LE β] [DecidableLE β] [IsLinearPreorder β]
{xs : List α} {f : α β} (h : xs []) {b : β} :
f (xs.minOn f h) b x xs, f x b := by
apply Iff.intro
· intro h
match xs with
| x :: xs =>
rw [List.minOn] at h
induction xs generalizing x
· simpa using h
· rename_i y ys ih _
rw [foldl_cons] at h
specialize ih (minOn f x y) (by simp) h
obtain z, hm, hle := ih
rcases mem_cons.mp hm with rfl | hm
· cases minOn_eq_or (f := f) (x := x) (y := y)
· exact x, by simp_all
· exact y, by simp_all
· exact z, by simp_all
· rintro x, hm, hx
exact le_trans (List.apply_minOn_le_of_mem hm) hx
protected theorem lt_apply_minOn_iff
[LE β] [DecidableLE β] [LT β] [IsLinearPreorder β] [LawfulOrderLT β]
{xs : List α} {f : α β} (h : xs []) {b : β} :
b < f (xs.minOn f h) x xs, b < f x := by
simpa [not_le] using not_congr <| xs.apply_minOn_le_iff (f := f) h (b := b)
protected theorem apply_minOn_lt_iff
[LE β] [DecidableLE β] [LT β] [IsLinearPreorder β] [LawfulOrderLT β]
{xs : List α} {f : α β} (h : xs []) {b : β} :
f (xs.minOn f h) < b x xs, f x < b := by
simpa [not_le] using not_congr <| xs.le_apply_minOn_iff (f := f) h (b := b)
protected theorem apply_minOn_le_apply_minOn_of_subset [LE β] [DecidableLE β]
[IsLinearPreorder β] {xs ys : List α} {f : α β} (hxs : ys xs) (hys : ys []) :
haveI : xs [] := by intro h; rw [h] at hxs; simp_all [subset_nil]
f (xs.minOn f this) f (ys.minOn f hys) := by
rw [List.le_apply_minOn_iff]
intro x hx
exact List.apply_minOn_le_of_mem (hxs hx)
protected theorem le_apply_minOn_take [LE β] [DecidableLE β] [IsLinearPreorder β]
{xs : List α} {f : α β} {i : Nat} (h : xs.take i []) :
f (xs.minOn f (List.ne_nil_of_take_ne_nil h)) f ((xs.take i).minOn f h) := by
apply List.apply_minOn_le_apply_minOn_of_subset
apply take_subset
protected theorem apply_minOn_append_le_left [LE β] [DecidableLE β] [IsLinearPreorder β]
{xs ys : List α} {f : α β} (h : xs []) :
f ((xs ++ ys).minOn f (append_ne_nil_of_left_ne_nil h ys))
f (xs.minOn f h) := by
apply List.apply_minOn_le_apply_minOn_of_subset
apply subset_append_left
protected theorem apply_minOn_append_le_right [LE β] [DecidableLE β] [IsLinearPreorder β]
{xs ys : List α} {f : α β} (h : ys []) :
f ((xs ++ ys).minOn f (append_ne_nil_of_right_ne_nil xs h))
f (ys.minOn f h) := by
apply List.apply_minOn_le_apply_minOn_of_subset
apply subset_append_right
@[simp]
protected theorem minOn_append [LE β] [DecidableLE β] [IsLinearPreorder β] {xs ys : List α}
{f : α β} (hxs : xs []) (hys : ys []) :
(xs ++ ys).minOn f (by simp [hxs]) = minOn f (xs.minOn f hxs) (ys.minOn f hys) := by
match xs, ys with
| x :: xs, y :: ys => simp [List.minOn, foldl_assoc]
protected theorem minOn_eq_head [LE β] [DecidableLE β] [IsLinearPreorder β]
{xs : List α} {f : α β} (h : xs []) (h' : x xs, f (xs.head h) f x) :
xs.minOn f h = xs.head h := by
match xs with
| x :: xs =>
simp only [List.minOn]
induction xs
· simp
· simp only [foldl_cons, head_cons]
rw [minOn_eq_left] <;> simp_all
protected theorem min_map
[LE β] [DecidableLE β] [Min β] [IsLinearPreorder β] [LawfulOrderLeftLeaningMin β] {xs : List α}
{f : α β} (h : xs []) :
(xs.map f).min (by simpa) = f (xs.minOn f h) := by
match xs with
| x :: xs =>
simp only [List.minOn, map_cons, List.min, foldl_map]
rw [foldl_hom]
simp [min_apply]
@[simp]
protected theorem minOn_replicate [LE β] [DecidableLE β] [IsLinearPreorder β]
{n : Nat} {a : α} {f : α β} (h : replicate n a []) :
(replicate n a).minOn f h = a := by
induction n
· simp at h
· rename_i n ih
simp only [ne_eq, replicate_eq_nil_iff] at ih
simp +contextual [List.replicate, List.minOn_cons, ih]
/-! ### maxOn -/
protected theorem maxOn_eq_minOn {le : LE β} {dle : DecidableLE β} {xs : List α} {f : α β} {h} :
xs.maxOn f h = (letI := le.opposite; xs.minOn f h) :=
(rfl)
@[simp]
protected theorem maxOn_singleton [LE β] [DecidableLE β] {x : α} {f : α β} :
[x].maxOn f (of_decide_eq_false rfl) = x := by
simp [List.maxOn]
protected theorem maxOn_cons
[LE β] [DecidableLE β] [IsLinearPreorder β] {x : α} {xs : List α} {f : α β} :
(x :: xs).maxOn f (by exact of_decide_eq_false rfl) =
if h : xs = [] then x else maxOn f x (xs.maxOn f h) := by
simp only [maxOn_eq_minOn]
letI : LE β := (inferInstanceAs (LE β)).opposite
exact List.minOn_cons (f := f)
protected theorem min_eq_max {min : Min α} {xs : List α} {h} :
xs.min h = (letI := min.oppositeMax; xs.max h) := by
simp only [List.min, List.max]
rw [Min.oppositeMax_def]
simp
protected theorem max_eq_min {max : Max α} {xs : List α} {h} :
xs.max h = (letI := max.oppositeMin; xs.min h) := by
simp only [List.min, List.max]
rw [Max.oppositeMin_def]
simp
protected theorem max?_eq_min? {max : Max α} {xs : List α} :
xs.max? = (letI := max.oppositeMin; xs.min?) := by
simp only [List.min?, List.max?]
rw [Max.oppositeMin_def]
simp
@[simp]
protected theorem maxOn_id [Max α] [LE α] [DecidableLE α] [LawfulOrderLeftLeaningMax α]
{xs : List α} (h : xs []) :
xs.maxOn id h = xs.max h := by
simp only [List.maxOn_eq_minOn]
letI : LE α := (inferInstanceAs (LE α)).opposite
letI : Min α := (inferInstanceAs (Max α)).oppositeMin
simpa only [List.max_eq_min] using List.minOn_id h
@[simp]
protected theorem maxOn_mem [LE β] [DecidableLE β] {xs : List α}
{f : α β} {h : xs []} : xs.maxOn f h xs :=
letI : LE β := (inferInstanceAs (LE β)).opposite
List.minOn_mem (f := f)
protected theorem le_apply_maxOn_of_mem [LE β] [DecidableLE β] [IsLinearPreorder β]
{xs : List α} {f : α β} {y : α} (hx : y xs) :
f y f (xs.maxOn f (List.ne_nil_of_mem hx)) := by
rw [List.maxOn_eq_minOn]
letI : LE β := (inferInstanceAs (LE β)).opposite
simpa [LE.le_opposite_iff] using List.apply_minOn_le_of_mem (f := f) hx
protected theorem apply_maxOn_le_iff [LE β] [DecidableLE β] [IsLinearPreorder β] {xs : List α}
{f : α β} (h : xs []) {b : β} :
f (xs.maxOn f h) b x xs, f x b := by
rw [List.maxOn_eq_minOn]
letI : LE β := (inferInstanceAs (LE β)).opposite
simpa [LE.le_opposite_iff] using List.le_apply_minOn_iff (f := f) h
protected theorem le_apply_maxOn_iff [LE β] [DecidableLE β] [IsLinearPreorder β] {xs : List α}
{f : α β} (h : xs []) {b : β} :
b f (xs.maxOn f h) x xs, b f x := by
rw [List.maxOn_eq_minOn]
letI : LE β := (inferInstanceAs (LE β)).opposite
simpa [LE.le_opposite_iff] using List.apply_minOn_le_iff (f := f) h
protected theorem apply_maxOn_lt_iff
[LE β] [DecidableLE β] [LT β] [IsLinearPreorder β] [LawfulOrderLT β]
{xs : List α} {f : α β} (h : xs []) {b : β} :
f (xs.maxOn f h) < b x xs, f x < b := by
letI : LE β := (inferInstanceAs (LE β)).opposite
letI : LT β := (inferInstanceAs (LT β)).opposite
simpa [LT.lt_opposite_iff] using List.lt_apply_minOn_iff (f := f) h
protected theorem lt_apply_maxOn_iff
[LE β] [DecidableLE β] [LT β] [IsLinearPreorder β] [LawfulOrderLT β]
{xs : List α} {f : α β} (h : xs []) {b : β} :
b < f (xs.maxOn f h) x xs, b < f x := by
letI : LE β := (inferInstanceAs (LE β)).opposite
letI : LT β := (inferInstanceAs (LT β)).opposite
simpa [LT.lt_opposite_iff] using List.apply_minOn_lt_iff (f := f) h
protected theorem apply_maxOn_le_apply_maxOn_of_subset [LE β] [DecidableLE β]
[IsLinearPreorder β] {xs ys : List α} {f : α β} (hxs : ys xs) (hys : ys []) :
haveI : xs [] := by intro h; rw [h] at hxs; simp_all [subset_nil]
f (ys.maxOn f hys) f (xs.maxOn f this) := by
rw [List.maxOn_eq_minOn]
letI : LE β := (inferInstanceAs (LE β)).opposite
simpa [LE.le_opposite_iff] using List.apply_minOn_le_apply_minOn_of_subset (f := f) hxs hys
protected theorem apply_maxOn_take_le [LE β] [DecidableLE β] [IsLinearPreorder β]
{xs : List α} {f : α β} {i : Nat} (h : xs.take i []) :
f ((xs.take i).maxOn f h) f (xs.maxOn f (List.ne_nil_of_take_ne_nil h)) := by
rw [List.maxOn_eq_minOn]
letI : LE β := (inferInstanceAs (LE β)).opposite
simpa [LE.le_opposite_iff] using List.le_apply_minOn_take (f := f) h
protected theorem le_apply_maxOn_append_left [LE β] [DecidableLE β] [IsLinearPreorder β]
{xs ys : List α} {f : α β} (h : xs []) :
f (xs.maxOn f h)
f ((xs ++ ys).maxOn f (append_ne_nil_of_left_ne_nil h ys)) := by
rw [List.maxOn_eq_minOn]
letI : LE β := (inferInstanceAs (LE β)).opposite
simpa [LE.le_opposite_iff] using List.apply_minOn_append_le_left (f := f) h
protected theorem le_apply_maxOn_append_right [LE β] [DecidableLE β] [IsLinearPreorder β]
{xs ys : List α} {f : α β} (h : ys []) :
f (ys.maxOn f h)
f ((xs ++ ys).maxOn f (append_ne_nil_of_right_ne_nil xs h)) := by
rw [List.maxOn_eq_minOn]
letI : LE β := (inferInstanceAs (LE β)).opposite
simpa [LE.le_opposite_iff] using List.apply_minOn_append_le_right (f := f) h
@[simp]
protected theorem maxOn_append [LE β] [DecidableLE β] [IsLinearPreorder β] {xs ys : List α}
{f : α β} (hxs : xs []) (hys : ys []) :
(xs ++ ys).maxOn f (by simp [hxs]) = maxOn f (xs.maxOn f hxs) (ys.maxOn f hys) := by
simp only [List.maxOn_eq_minOn, maxOn_eq_minOn]
letI : LE β := (inferInstanceAs (LE β)).opposite
simpa [LE.le_opposite_iff] using List.minOn_append (f := f) hxs hys
protected theorem maxOn_eq_head [LE β] [DecidableLE β] [IsLinearPreorder β] {xs : List α}
{f : α β} (h : xs []) (h' : x xs, f x f (xs.head h)) :
xs.maxOn f h = xs.head h := by
rw [List.maxOn_eq_minOn]
letI : LE β := (inferInstanceAs (LE β)).opposite
simpa [LE.le_opposite_iff] using List.minOn_eq_head (f := f) h (by simpa [LE.le_opposite_iff] using h')
protected theorem max_map
[LE β] [DecidableLE β] [Max β] [IsLinearPreorder β] [LawfulOrderLeftLeaningMax β] {xs : List α}
{f : α β} (h : xs []) : (xs.map f).max (by simpa) = f (xs.maxOn f h) := by
letI : LE β := (inferInstanceAs (LE β)).opposite
letI : Min β := (inferInstanceAs (Max β)).oppositeMin
simpa [List.max_eq_min] using List.min_map (f := f) h
@[simp]
protected theorem maxOn_replicate [LE β] [DecidableLE β] [IsLinearPreorder β]
{n : Nat} {a : α} {f : α β} (h : replicate n a []) :
(replicate n a).maxOn f h = a :=
letI : LE β := (inferInstanceAs (LE β)).opposite
List.minOn_replicate (f := f) h
/-! ### minOn? -/
/-- {lit}`List.minOn?` returns {name}`none` when applied to an empty list. -/
@[simp]
protected theorem minOn?_nil [LE β] [DecidableLE β] {f : α β} :
([] : List α).minOn? f = none := by
simp [List.minOn?]
protected theorem minOn?_cons_eq_some_minOn
[LE β] [DecidableLE β] {f : α β} {x : α} {xs : List α} :
(x :: xs).minOn? f = some ((x :: xs).minOn f (fun h => nomatch h)) := by
simp [List.minOn?, List.minOn]
protected theorem minOn?_cons
[LE β] [DecidableLE β] [IsLinearPreorder β] {f : α β} {x : α} {xs : List α} :
(x :: xs).minOn? f = some ((xs.minOn? f).elim x (minOn f x)) := by
simp only [List.minOn?]
split <;> simp [foldl_assoc]
@[simp]
protected theorem minOn?_singleton [LE β] [DecidableLE β] {x : α} {f : α β} :
[x].minOn? f = some x := by
simp [List.minOn?_cons_eq_some_minOn]
@[simp]
protected theorem minOn?_id [Min α] [LE α] [DecidableLE α] [LawfulOrderLeftLeaningMin α]
{xs : List α} : xs.minOn? id = xs.min? := by
cases xs
· simp
· simp only [List.minOn?_cons_eq_some_minOn, List.minOn_id, List.min?_eq_some_min (List.cons_ne_nil _ _)]
protected theorem minOn?_eq_if
[LE β] [DecidableLE β] [IsLinearPreorder β] {f : α β} {xs : List α} :
xs.minOn? f =
if h : xs [] then
some (xs.minOn f h)
else
none := by
fun_cases xs.minOn? f <;> simp [List.minOn]
@[simp]
protected theorem isSome_minOn?_iff [LE β] [DecidableLE β] {f : α β} {xs : List α} :
(xs.minOn? f).isSome xs [] := by
fun_cases xs.minOn? f <;> simp
protected theorem minOn_eq_get_minOn? [LE β] [DecidableLE β] {f : α β} {xs : List α}
(h : xs []) : xs.minOn f h = (xs.minOn? f).get (List.isSome_minOn?_iff.mpr h) := by
fun_cases xs.minOn? f
· contradiction
· simp [List.minOn?, List.minOn]
protected theorem minOn?_eq_some_minOn [LE β] [DecidableLE β] {f : α β} {xs : List α}
(h : xs []) : xs.minOn? f = some (xs.minOn f h) := by
simp [List.minOn_eq_get_minOn? h]
@[simp]
protected theorem get_minOn? [LE β] [DecidableLE β] {f : α β} {xs : List α}
(h : xs []) : (xs.minOn? f).get (List.isSome_minOn?_iff.mpr h) = xs.minOn f h := by
rw [List.minOn_eq_get_minOn?]
protected theorem minOn_eq_of_minOn?_eq_some
[LE β] [DecidableLE β] {f : α β} {xs : List α} {x : α} (h : xs.minOn? f = some x) :
xs.minOn f (List.isSome_minOn?_iff.mp (Option.isSome_of_eq_some h)) = x := by
have h' := List.isSome_minOn?_iff.mp (Option.isSome_of_eq_some h)
rwa [List.minOn?_eq_some_minOn h', Option.some.injEq] at h
protected theorem isSome_minOn?_of_mem
[LE β] [DecidableLE β] {f : α β} {xs : List α} {x : α} (h : x xs) :
(xs.minOn? f).isSome := by
apply List.isSome_minOn?_iff.mpr
exact ne_nil_of_mem h
protected theorem apply_get_minOn?_le_of_mem
[LE β] [DecidableLE β] [IsLinearPreorder β] {f : α β} {xs : List α} {x : α} (h : x xs) :
f ((xs.minOn? f).get (List.isSome_minOn?_of_mem h)) f x := by
rw [List.get_minOn? (ne_nil_of_mem h)]
apply List.apply_minOn_le_of_mem h
protected theorem minOn?_mem [LE β] [DecidableLE β] {xs : List α}
{f : α β} (h : xs.minOn? f = some a) : a xs := by
rw [ List.minOn_eq_of_minOn?_eq_some h]
apply List.minOn_mem
protected theorem minOn?_replicate [LE β] [DecidableLE β] [IsLinearPreorder β]
{n : Nat} {a : α} {f : α β} :
(replicate n a).minOn? f = if n = 0 then none else some a := by
split
· simp [*]
· rw [List.minOn?_eq_some_minOn, List.minOn_replicate]
simp [*]
@[simp]
protected theorem minOn?_replicate_of_pos [LE β] [DecidableLE β] [IsLinearPreorder β]
{n : Nat} {a : α} {f : α β} (h : 0 < n) :
(replicate n a).minOn? f = some a := by
simp [List.minOn?_replicate, show n 0 from Nat.ne_zero_of_lt h]
@[simp]
protected theorem minOn?_append [LE β] [DecidableLE β] [IsLinearPreorder β]
(xs ys : List α) (f : α β) :
(xs ++ ys).minOn? f =
(xs.minOn? f).merge (_root_.minOn f) (ys.minOn? f) := by
by_cases xs = [] <;> by_cases ys = [] <;> simp [*, List.minOn?_eq_if, List.minOn_append]
/-! ### maxOn? -/
protected theorem maxOn?_eq_minOn? {le : LE β} {dle : DecidableLE β} {xs : List α} {f : α β} :
xs.maxOn? f = (letI := le.opposite; xs.minOn? f) :=
(rfl)
@[simp]
protected theorem maxOn?_nil [LE β] [DecidableLE β] {f : α β} :
([] : List α).maxOn? f = none :=
List.minOn?_nil (f := f)
protected theorem maxOn?_cons_eq_some_maxOn
[LE β] [DecidableLE β] {f : α β} {x : α} {xs : List α} :
(x :: xs).maxOn? f = some ((x :: xs).maxOn f (fun h => nomatch h)) :=
letI : LE β := (inferInstanceAs (LE β)).opposite
List.minOn?_cons_eq_some_minOn
protected theorem maxOn?_cons
[LE β] [DecidableLE β] [IsLinearPreorder β] {f : α β} {x : α} {xs : List α} :
(x :: xs).maxOn? f = some ((xs.maxOn? f).elim x (maxOn f x)) := by
have : maxOn f x = (letI : LE β := LE.opposite inferInstance; minOn f x) := by
ext; simp only [maxOn_eq_minOn]
simp only [List.maxOn?_eq_minOn?, this]
letI : LE β := (inferInstanceAs (LE β)).opposite
exact List.minOn?_cons
@[simp]
protected theorem maxOn?_singleton [LE β] [DecidableLE β] {x : α} {f : α β} :
[x].maxOn? f = some x :=
List.minOn?_singleton (f := f)
@[simp]
protected theorem maxOn?_id [Max α] [LE α] [DecidableLE α] [LawfulOrderLeftLeaningMax α]
{xs : List α} : xs.maxOn? id = xs.max? := by
letI : LE α := (inferInstanceAs (LE α)).opposite
letI : Min α := (inferInstanceAs (Max α)).oppositeMin
simpa only [List.maxOn?_eq_minOn?, List.max?_eq_min?] using List.minOn?_id (α := α)
protected theorem maxOn?_eq_if
[LE β] [DecidableLE β] [IsLinearPreorder β] {f : α β} {xs : List α} :
xs.maxOn? f =
if h : xs [] then
some (xs.maxOn f h)
else
none :=
letI : LE β := (inferInstanceAs (LE β)).opposite
List.minOn?_eq_if
@[simp]
protected theorem isSome_maxOn?_iff [LE β] [DecidableLE β] {f : α β} {xs : List α} :
(xs.maxOn? f).isSome xs [] := by
fun_cases xs.maxOn? f <;> simp
protected theorem maxOn_eq_get_maxOn? [LE β] [DecidableLE β] {f : α β} {xs : List α}
(h : xs []) : xs.maxOn f h = (xs.maxOn? f).get (List.isSome_maxOn?_iff.mpr h) :=
letI : LE β := (inferInstanceAs (LE β)).opposite
List.minOn_eq_get_minOn? (f := f) h
protected theorem maxOn?_eq_some_maxOn [LE β] [DecidableLE β] {f : α β} {xs : List α}
(h : xs []) : xs.maxOn? f = some (xs.maxOn f h) :=
letI : LE β := (inferInstanceAs (LE β)).opposite
List.minOn?_eq_some_minOn (f := f) h
@[simp]
protected theorem get_maxOn? [LE β] [DecidableLE β] {f : α β} {xs : List α}
(h : xs []) : (xs.maxOn? f).get (List.isSome_maxOn?_iff.mpr h) = xs.maxOn f h :=
letI : LE β := (inferInstanceAs (LE β)).opposite
List.get_minOn? (f := f) h
protected theorem maxOn_eq_of_maxOn?_eq_some
[LE β] [DecidableLE β] {f : α β} {xs : List α} {x : α} (h : xs.maxOn? f = some x) :
xs.maxOn f (List.isSome_maxOn?_iff.mp (Option.isSome_of_eq_some h)) = x :=
letI : LE β := (inferInstanceAs (LE β)).opposite
List.minOn_eq_of_minOn?_eq_some (f := f) h
protected theorem isSome_maxOn?_of_mem
[LE β] [DecidableLE β] {f : α β} {xs : List α} {x : α} (h : x xs) :
(xs.maxOn? f).isSome :=
letI : LE β := (inferInstanceAs (LE β)).opposite
List.isSome_minOn?_of_mem (f := f) h
protected theorem le_apply_get_maxOn?_of_mem
[LE β] [DecidableLE β] [IsLinearPreorder β] {f : α β} {xs : List α} {x : α} (h : x xs) :
f x f ((xs.maxOn? f).get (List.isSome_maxOn?_of_mem h)) := by
simp only [List.maxOn?_eq_minOn?]
letI : LE β := (inferInstanceAs (LE β)).opposite
simpa [LE.le_opposite_iff] using List.apply_get_minOn?_le_of_mem (f := f) h
protected theorem maxOn?_mem [LE β] [DecidableLE β] {xs : List α}
{f : α β} (h : xs.maxOn? f = some a) : a xs :=
letI : LE β := (inferInstanceAs (LE β)).opposite
List.minOn?_mem (f := f) h
protected theorem maxOn?_replicate [LE β] [DecidableLE β] [IsLinearPreorder β]
{n : Nat} {a : α} {f : α β} :
(replicate n a).maxOn? f = if n = 0 then none else some a :=
letI : LE β := (inferInstanceAs (LE β)).opposite
List.minOn?_replicate
@[simp]
protected theorem maxOn?_replicate_of_pos [LE β] [DecidableLE β] [IsLinearPreorder β]
{n : Nat} {a : α} {f : α β} (h : 0 < n) :
(replicate n a).maxOn? f = some a :=
letI : LE β := (inferInstanceAs (LE β)).opposite
List.minOn?_replicate_of_pos (f := f) h
@[simp]
protected theorem maxOn?_append [LE β] [DecidableLE β] [IsLinearPreorder β]
(xs ys : List α) (f : α β) : (xs ++ ys).maxOn? f =
(xs.maxOn? f).merge (_root_.maxOn f) (ys.maxOn? f) := by
have : maxOn f = (letI : LE β := LE.opposite inferInstance; minOn f) := by
ext; simp only [maxOn_eq_minOn]
simp only [List.maxOn?_eq_minOn?, this]
letI : LE β := (inferInstanceAs (LE β)).opposite
exact List.minOn?_append xs ys f
end List

View File

@@ -141,6 +141,10 @@ theorem take_append_of_le_length {l₁ l₂ : List α} {i : Nat} (h : i ≤ l₁
(l₁ ++ l₂).take i = l₁.take i := by
simp [take_append, Nat.sub_eq_zero_of_le h]
@[grind =]
theorem take_append_length {l₁ l₂ : List α} : (l₁ ++ l₂).take l₁.length = l₁ := by
simp
/-- Taking the first `l₁.length + i` elements in `l₁ ++ l₂` is the same as appending the first
`i` elements of `l₂` to `l₁`. -/
theorem take_length_add_append {l₁ l₂ : List α} (i : Nat) :
@@ -304,7 +308,6 @@ theorem drop_length_cons {l : List α} (h : l ≠ []) (a : α) :
/-- Dropping the elements up to `i` in `l₁ ++ l₂` is the same as dropping the elements up to `i`
in `l₁`, dropping the elements up to `i - l₁.length` in `l₂`, and appending them. -/
@[grind =]
theorem drop_append {l₁ l₂ : List α} {i : Nat} :
drop i (l₁ ++ l₂) = drop i l₁ ++ drop (i - l₁.length) l₂ := by
induction l₁ generalizing i
@@ -315,10 +318,15 @@ theorem drop_append {l₁ l₂ : List α} {i : Nat} :
congr 1
omega
@[grind =]
theorem drop_append_of_le_length {l₁ l₂ : List α} {i : Nat} (h : i l₁.length) :
(l₁ ++ l₂).drop i = l₁.drop i ++ l₂ := by
simp [drop_append, Nat.sub_eq_zero_of_le h]
@[grind =]
theorem drop_append_length {l₁ l₂ : List α} : (l₁ ++ l₂).drop l₁.length = l₂ := by
simp [List.drop_append_of_le_length (Nat.le_refl _)]
/-- Dropping the elements up to `l₁.length + i` in `l₁ + l₂` is the same as dropping the elements
up to `i` in `l₂`. -/
@[simp]

View File

@@ -54,6 +54,15 @@ theorem div_le_iff_le_mul (h : 0 < k) : x / k ≤ y ↔ x ≤ y * k + k - 1 := b
rw [le_iff_lt_add_one, Nat.div_lt_iff_lt_mul h, Nat.add_one_mul]
omega
theorem le_mul_iff_le_left (hz : 0 < z) :
x y * z (x + z - 1) / z y := by
rw [Nat.div_le_iff_le_mul hz]
omega
theorem le_mul_iff_le_right (hy : 0 < y) :
x y * z (x + y - 1) / y z := by
rw [ le_mul_iff_le_left hy, Nat.mul_comm]
-- TODO: reprove `div_eq_of_lt_le` in terms of this:
protected theorem div_eq_iff (h : 0 < k) : x / k = y y * k x x y * k + k - 1 := by
rw [Nat.eq_iff_le_and_ge, and_comm, le_div_iff_mul_le h, Nat.div_le_iff_le_mul h]
@@ -95,6 +104,12 @@ theorem div_add_le_right {z : Nat} (h : 0 < z) (x y : Nat) :
x / (y + z) x / z :=
div_le_div_left (Nat.le_add_left z y) h
theorem div_add_div_le_add_div {x y z : Nat} : x / z + y / z (x + y) / z := by
by_cases hc : z > 0
· rw [Nat.le_div_iff_mul_le hc, Nat.add_mul]
apply Nat.add_le_add <;> apply Nat.div_mul_le_self
· simp_all
theorem succ_div_of_dvd {a b : Nat} (h : b a + 1) :
(a + 1) / b = a / b + 1 := by
replace h := mod_eq_zero_of_dvd h

View File

@@ -13,4 +13,6 @@ public import Init.Data.Order.Lemmas
public import Init.Data.Order.LemmasExtra
public import Init.Data.Order.Factories
public import Init.Data.Order.FactoriesExtra
public import Init.Data.Order.MinMaxOn
public import Init.Data.Order.Opposite
public import Init.Data.Order.PackageFactories

View File

@@ -142,6 +142,10 @@ public theorem not_gt_of_lt {α : Type u} [LT α] [i : Std.Asymm (α := α) (·
(h : a < b) : ¬ b < a :=
i.asymm a b h
public theorem lt_irrefl {α : Type u} [LT α] [i : Std.Irrefl (α := α) (· < ·)] {a : α} :
¬ a < a :=
i.irrefl a
public theorem le_of_lt {α : Type u} [LT α] [LE α] [LawfulOrderLT α] {a b : α} (h : a < b) :
a b := (lt_iff_le_and_not_ge.1 h).1

View File

@@ -0,0 +1,198 @@
/-
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Reichert
-/
module
prelude
public import Init.NotationExtra
public import Init.Data.Order.Lemmas
public import Init.Data.Order.Opposite
open Std
open scoped OppositeOrderInstances
/-! ## Definitions -/
/--
Returns either `x` or `y`, the one with the smaller value under `f`.
If `f x ≤ f y`, it returns `x`, and otherwise returns `y`.
-/
public def minOn [LE β] [DecidableLE β] (f : α β) (x y : α) :=
if f x f y then x else y
/--
Returns either `x` or `y`, the one with the greater value under `f`.
If `f y ≤ f x`, it returns `x`, and otherwise returns `y`.
-/
public def maxOn [i : LE β] [DecidableLE β] (f : α β) (x y : α) :=
letI := i.opposite
minOn f x y
/-! ## `minOn` Lemmas -/
public theorem minOn_id [Min α] [LE α] [DecidableLE α] [LawfulOrderLeftLeaningMin α] {x y : α} :
minOn id x y = min x y := by
simp [minOn, min_eq_if]
public theorem maxOn_id [Max α] [LE α] [DecidableLE α] [LawfulOrderLeftLeaningMax α] {x y : α} :
maxOn id x y = max x y := by
letI : LE α := (inferInstanceAs (LE α)).opposite
letI : Min α := (inferInstanceAs (Max α)).oppositeMin
simp [maxOn, minOn_id, Max.min_oppositeMin, this]
public theorem minOn_eq_or [LE β] [DecidableLE β] {f : α β} {x y : α} :
minOn f x y = x minOn f x y = y := by
rw [minOn]
split
· exact Or.inl rfl
· exact Or.inr rfl
@[simp]
public theorem minOn_self [LE β] [DecidableLE β] {f : α β} {x : α} :
minOn f x x = x := by
cases minOn_eq_or (f := f) (x := x) (y := x) <;> assumption
public theorem minOn_eq_left [LE β] [DecidableLE β] {f : α β} {x y : α} (h : f x f y) :
minOn f x y = x := by
simp [minOn, h]
public theorem minOn_eq_right [LE β] [DecidableLE β] {f : α β} {x y : α} (h : ¬ f x f y) :
minOn f x y = y := by
simp [minOn, h]
public theorem minOn_eq_right_of_lt
[LE β] [DecidableLE β] [LT β] [Total (α := β) (· ·)] [LawfulOrderLT β]
{f : α β} {x y : α} (h : f y < f x) :
minOn f x y = y := by
apply minOn_eq_right
simpa [not_le] using h
public theorem apply_minOn_le_left [LE β] [DecidableLE β] [IsLinearPreorder β] {f : α β}
{x y : α} : f (minOn f x y) f x := by
rw [minOn]
split
· apply le_refl
· exact le_of_not_ge _
public theorem apply_minOn_le_right [LE β] [DecidableLE β] [IsLinearPreorder β] {f : α β}
{x y : α} : f (minOn f x y) f y := by
rw [minOn]
split
· assumption
· apply le_refl
public theorem le_apply_minOn_iff [LE β] [DecidableLE β] [IsLinearPreorder β] {f : α β}
{x y : α} {b : β} :
b f (minOn f x y) b f x b f y := by
apply Iff.intro
· intro h
exact le_trans h apply_minOn_le_left, le_trans h apply_minOn_le_right
· intro h
cases minOn_eq_or (f := f) (x := x) (y := y) <;> simp_all
public theorem minOn_assoc [LE β] [DecidableLE β] [IsLinearPreorder β] {f : α β}
{x y z : α} : minOn f (minOn f x y) z = minOn f x (minOn f y z) := by
open scoped Classical.Order in
simp only [minOn]
split
· split
· split
· rfl
· rfl
· split
· have : ¬ f x f z := by assumption
have : f x f z := le_trans f x f y f y f z
contradiction
· rfl
· split
· rfl
· have : f z < f y := not_le.mp ¬ f y f z
have : f y < f x := not_le.mp ¬ f x f y
have : f z < f x := lt_trans _ _
rw [if_neg]
exact not_le.mpr _
public instance [LE β] [DecidableLE β] [IsLinearPreorder β] {f : α β} :
Associative (minOn f) where
assoc := by apply minOn_assoc
public theorem min_apply [LE β] [DecidableLE β] [Min β] [LawfulOrderLeftLeaningMin β]
{f : α β} {x y : α} : min (f x) (f y) = f (minOn f x y) := by
rw [min_eq_if, minOn]
split <;> rfl
/-! ## `maxOn` Lemmas -/
public theorem maxOn_eq_minOn [le : LE β] [DecidableLE β] {f : α β} {x y : α} :
maxOn f x y = (letI := le.opposite; minOn f x y) :=
(rfl)
public theorem maxOn_eq_or [LE β] [DecidableLE β] {f : α β} {x y : α} :
maxOn f x y = x maxOn f x y = y :=
@minOn_eq_or ..
@[simp]
public theorem maxOn_self [LE β] [DecidableLE β] {f : α β} {x : α} :
maxOn f x x = x :=
@minOn_self ..
public theorem maxOn_eq_left [le : LE β] [DecidableLE β] {f : α β} {x y : α} (h : f y f x) :
maxOn f x y = x := by
simp only [maxOn_eq_minOn]
exact @minOn_eq_left (h := by simpa [LE.opposite_def]) ..
public theorem maxOn_eq_right [LE β] [DecidableLE β] {f : α β} {x y : α} (h : ¬ f y f x) :
maxOn f x y = y := by
simp only [maxOn_eq_minOn]
exact @minOn_eq_right (h := by simpa [LE.opposite_def]) ..
public theorem maxOn_eq_right_of_lt
[LE β] [DecidableLE β] [LT β] [Total (α := β) (· ·)] [LawfulOrderLT β]
{f : α β} {x y : α} (h : f x < f y) :
maxOn f x y = y :=
letI : LE β := (inferInstanceAs (LE β)).opposite
letI : LT β := (inferInstanceAs (LT β)).opposite
minOn_eq_right_of_lt (h := by simpa [LT.lt_opposite_iff] using h) ..
public theorem left_le_apply_maxOn [le : LE β] [DecidableLE β] [IsLinearPreorder β] {f : α β}
{x y : α} : f x f (maxOn f x y) := by
rw [maxOn_eq_minOn]
letI : LE β := (inferInstanceAs (LE β)).opposite
simpa only [LE.le_opposite_iff] using apply_minOn_le_left (f := f) ..
public theorem right_le_apply_maxOn [LE β] [DecidableLE β] [IsLinearPreorder β] {f : α β}
{x y : α} : f y f (maxOn f x y) := by
rw [maxOn_eq_minOn]
letI : LE β := (inferInstanceAs (LE β)).opposite
simpa only [LE.le_opposite_iff] using apply_minOn_le_right (f := f)
public theorem apply_maxOn_le_iff [LE β] [DecidableLE β] [IsLinearPreorder β] {f : α β}
{x y : α} {b : β} :
f (maxOn f x y) b f x b f y b := by
rw [maxOn_eq_minOn]
letI : LE β := (inferInstanceAs (LE β)).opposite
simpa only [LE.le_opposite_iff] using le_apply_minOn_iff (f := f)
public theorem maxOn_assoc [LE β] [DecidableLE β] [IsLinearPreorder β] {f : α β}
{x y z : α} : maxOn f (maxOn f x y) z = maxOn f x (maxOn f y z) :=
letI : LE β := (inferInstanceAs (LE β)).opposite
minOn_assoc (f := f)
public instance [LE β] [DecidableLE β] [IsLinearPreorder β] {f : α β} :
Associative (maxOn f) where
assoc := by
apply maxOn_assoc
public theorem max_apply [LE β] [DecidableLE β] [Max β] [LawfulOrderLeftLeaningMax β]
{f : α β} {x y : α} : max (f x) (f y) = f (maxOn f x y) := by
letI : LE β := (inferInstanceAs (LE β)).opposite
letI : Min β := (inferInstanceAs (Max β)).oppositeMin
simpa [Max.min_oppositeMin] using min_apply (f := f)
public theorem apply_maxOn [LE β] [DecidableLE β] [Max β] [LawfulOrderLeftLeaningMax β]
{f : α β} {x y : α} : f (maxOn f x y) = max (f x) (f y) :=
max_apply.symm

View File

@@ -0,0 +1,407 @@
/-
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Reichert
-/
module
prelude
public import Init.Data.Order.ClassesExtra
public import Init.Data.Order.LemmasExtra
public section
open Std
set_option linter.missingDocs true
set_option linter.listVariables true -- Enforce naming conventions for `List`/`Array`/`Vector` variables.
set_option linter.indexVariables true -- Enforce naming conventions for index variables.
/-
Note: We're having verso docstrings disabled because the examples depend on instances that
are provided later in the module. They will be converted into verso docstrings at the end
of the module.
-/
/--
Inverts an {name}`LE` instance.
The result is an {lean}`LE α` instance where {lit}`a ≤ b` holds when {name}`le` would have
{lit}`b ≤ a` hold.
If {name}`le` obeys laws, then {lean}`le.opposite` obeys the opposite laws. For example, if
{name}`le` encodes a linear order, then {lean}`le.opposite` also encodes a linear order.
To automatically derive these laws, use {lit}`open Std.OppositeOrderInstances`.
For example, {name}`LE.opposite` can be used to derive maximum operations from minimum operations,
since finding the minimum in the opposite order is the same as finding the maximum in the original order:
```lean +warning
def min' [LE α] [DecidableLE α] (a b : α) : α :=
if a ≤ b then a else b
open scoped Std.OppositeOrderInstances in
def max' [LE α] [DecidableLE α] (a b : α) : α :=
letI : LE α := (inferInstanceAs (LE α)).opposite
-- `DecidableLE` for the opposite order is derived automatically via `OppositeOrderInstances`
min' a b
```
Without the `open scoped` command, Lean would not find the required {lit}`DecidableLE α`
instance for the opposite order.
-/
def LE.opposite (le : LE α) : LE α where
le a b := b a
theorem LE.opposite_def {le : LE α} :
le.opposite = fun a b => b a :=
(rfl)
theorem LE.le_opposite_iff {le : LE α} {a b : α} :
(haveI := le.opposite; a b) b a := by
exact Iff.rfl
/--
Inverts an {name}`LT` instance.
The result is an {lean}`LT α` instance where {lit}`a < b` holds when {name}`lt` would have
{lit}`b < a` hold.
If {name}`lt` obeys laws, then {lean}`lt.opposite` obeys the opposite laws.
To automatically derive these laws, use {lit}`open scoped Std.OppositeOrderInstances`.
For example, one can use the derived instances to prove properties about the opposite {name}`LT`
instance:
```lean
open scoped Std.OppositeOrderInstances in
example [LE α] [LT α] [Std.LawfulOrderLT α] [Std.IsLinearOrder α] {x y : α} :
letI : LE α := LE.opposite inferInstance
letI : LT α := LT.opposite inferInstance
¬ y ≤ x ↔ x < y :=
letI : LE α := LE.opposite inferInstance
letI : LT α := LT.opposite inferInstance
Std.not_le
```
Without the `open scoped` command, Lean would not find the {lit}`LawfulOrderLT α`
and {lit}`IsLinearOrder α` instances for the opposite order that are required by {name}`not_le`.
-/
def LT.opposite (lt : LT α) : LT α where
lt a b := b < a
theorem LT.opposite_def {lt : LT α} :
lt.opposite = fun a b => b < a :=
(rfl)
theorem LT.lt_opposite_iff {lt : LT α} {a b : α} :
(haveI := lt.opposite; a < b) b < a := by
exact Iff.rfl
/--
Creates a {name}`Max` instance from a {name}`Min` instance.
The result is a {lean}`Max α` instance that uses {lean}`min.min` as its {name}`max` operation.
If {name}`min` obeys laws, then {lean}`min.oppositeMax` obeys the corresponding laws for {name}`Max`.
To automatically derive these laws, use {lit}`open scoped Std.OppositeOrderInstances`.
For example, one can use the derived instances to prove properties about the opposite {name}`Max`
instance:
```lean
open scoped Std.OppositeOrderInstances in
example [LE α] [DecidableLE α] [Min α] [Std.LawfulOrderLeftLeaningMin α] {a b : α} :
letI : LE α := LE.opposite inferInstance
letI : Max α := (inferInstance : Min α).oppositeMax
max a b = if b ≤ a then a else b :=
letI : LE α := LE.opposite inferInstance
letI : Max α := (inferInstance : Min α).oppositeMax
Std.max_eq_if
```
Without the `open scoped` command, Lean would not find the {lit}`LawfulOrderLeftLeaningMax α`
instance for the opposite order that is required by {name}`max_eq_if`.
-/
def Min.oppositeMax (min : Min α) : Max α where
max a b := Min.min a b
theorem Min.oppositeMax_def {min : Min α} :
min.oppositeMax = Min.min :=
(rfl)
theorem Min.max_oppositeMax {min : Min α} {a b : α} :
(haveI := min.oppositeMax; Max.max a b) = Min.min a b :=
(rfl)
/--
Creates a {name}`Min` instance from a {name}`Max` instance.
The result is a {lean}`Min α` instance that uses {lean}`max.max` as its {name}`min` operation.
If {name}`max` obeys laws, then {lean}`max.oppositeMin` obeys the corresponding laws for {name}`Min`.
To automatically derive these laws, use {lit}`open scoped Std.OppositeOrderInstances`.
For example, one can use the derived instances to prove properties about the opposite {name}`Min`
instance:
```lean
open scoped Std.OppositeOrderInstances in
example [LE α] [DecidableLE α] [Max α] [Std.LawfulOrderLeftLeaningMax α] {a b : α} :
letI : LE α := LE.opposite inferInstance
letI : Min α := (inferInstance : Max α).oppositeMin
min a b = if a ≤ b then a else b :=
letI : LE α := LE.opposite inferInstance
letI : Min α := (inferInstance : Max α).oppositeMin
Std.min_eq_if
```
Without the `open scoped` command, Lean would not find the {lit}`LawfulOrderLeftLeaningMin α`
instance for the opposite order that is required by {name}`min_eq_if`.
-/
def Max.oppositeMin (max : Max α) : Min α where
min a b := Max.max a b
theorem Max.oppositeMin_def {min : Max α} :
min.oppositeMin = Max.max :=
(rfl)
theorem Max.min_oppositeMin {max : Max α} {a b : α} :
(haveI := max.oppositeMin; Min.min a b) = Max.max a b :=
(rfl)
namespace Std.OppositeOrderInstances
@[no_expose]
scoped instance (priority := low) instDecidableLEOpposite {i : LE α} [id : DecidableLE α] :
haveI := i.opposite
DecidableLE α :=
fun a b => id b a
@[no_expose]
scoped instance (priority := low) instDecidableLTOpposite {i : LT α} [id : DecidableLT α] :
haveI := i.opposite
DecidableLT α :=
fun a b => id b a
scoped instance (priority := low) instLEReflOpposite {i : LE α} [Refl (α := α) (· ·)] :
haveI := i.opposite
Refl (α := α) (· ·) :=
letI := i.opposite
{ refl a := letI := i; le_refl a }
scoped instance (priority := low) instLESymmOpposite {i : LE α} [Symm (α := α) (· ·)] :
haveI := i.opposite
Symm (α := α) (· ·) :=
letI := i.opposite
{ symm a b hab := by
simp only [LE.opposite] at *
letI := i
exact Symm.symm b a hab }
scoped instance (priority := low) instLEAntisymmOpposite {i : LE α} [Antisymm (α := α) (· ·)] :
haveI := i.opposite
Antisymm (α := α) (· ·) :=
letI := i.opposite
{ antisymm a b hab hba := by
simp only [LE.opposite] at *
letI := i
exact le_antisymm hba hab }
scoped instance (priority := low) instLEAsymmOpposite {i : LE α} [Asymm (α := α) (· ·)] :
haveI := i.opposite
Asymm (α := α) (· ·) :=
letI := i.opposite
{ asymm a b hab := by
simp only [LE.opposite] at *
letI := i
exact Asymm.asymm b a hab }
scoped instance (priority := low) instLETransOpposite {i : LE α}
[Trans (· ·) (· ·) (· · : α α Prop)] :
haveI := i.opposite
Trans (· ·) (· ·) (· · : α α Prop) :=
letI := i.opposite
{ trans hab hbc := by
simp only [LE.opposite] at *
letI := i
exact Trans.trans hbc hab }
scoped instance (priority := low) instLETotalOpposite {i : LE α} [Total (α := α) (· ·)] :
haveI := i.opposite
Total (α := α) (· ·) :=
letI := i.opposite
{ total a b := letI := i; le_total (a := b) (b := a) }
scoped instance (priority := low) instLEIrreflOpposite {i : LE α} [Irrefl (α := α) (· ·)] :
haveI := i.opposite
Irrefl (α := α) (· ·) :=
letI := i.opposite
{ irrefl a := letI := i; Irrefl.irrefl (r := (· ·)) a }
scoped instance (priority := low) instIsPreorderOpposite {i : LE α} [IsPreorder α] :
haveI := i.opposite
IsPreorder α :=
letI := i.opposite
{ le_refl a := le_refl a
le_trans _ _ _ := le_trans }
scoped instance (priority := low) instIsPartialOrderOpposite {i : LE α} [IsPartialOrder α] :
haveI := i.opposite
IsPartialOrder α :=
letI := i.opposite
{ le_antisymm _ _ := le_antisymm }
scoped instance (priority := low) instIsLinearPreorderOpposite {i : LE α} [IsLinearPreorder α] :
haveI := i.opposite
IsLinearPreorder α :=
letI := i.opposite
{ le_total _ _ := le_total }
scoped instance (priority := low) instIsLinearOrderOpposite {i : LE α} [IsLinearOrder α] :
haveI := i.opposite
IsLinearOrder α :=
letI := i.opposite; {}
scoped instance (priority := low) instLawfulOrderOrdOpposite {il : LE α} {io : Ord α}
[LawfulOrderOrd α] :
haveI := il.opposite
haveI := io.opposite
LawfulOrderOrd α :=
letI := il.opposite
letI := io.opposite
{ isLE_compare a b := by
simp only [LE.opposite, Ord.opposite]
letI := il; letI := io
apply isLE_compare
isGE_compare a b := by
simp only [LE.opposite, Ord.opposite]
letI := il; letI := io
apply isGE_compare }
scoped instance (priority := low) instLawfulOrderLTOpposite {il : LE α} {it : LT α}
[LawfulOrderLT α] :
haveI := il.opposite
haveI := it.opposite
LawfulOrderLT α :=
letI := il.opposite
letI := it.opposite
{ lt_iff a b := by
simp only [LE.opposite, LT.opposite]
letI := il; letI := it
exact LawfulOrderLT.lt_iff b a }
scoped instance (priority := low) instLawfulOrderBEqOpposite {il : LE α} {ib : BEq α}
[LawfulOrderBEq α] :
haveI := il.opposite
LawfulOrderBEq α :=
letI := il.opposite
{ beq_iff_le_and_ge a b := by
simp only [LE.opposite]
letI := il; letI := ib
rw [LawfulOrderBEq.beq_iff_le_and_ge]
exact and_comm }
scoped instance (priority := low) instLawfulOrderInfOpposite {il : LE α} {im : Min α}
[LawfulOrderInf α] :
haveI := il.opposite
haveI := im.oppositeMax
LawfulOrderSup α :=
letI := il.opposite
letI := im.oppositeMax
{ max_le_iff a b c := by
simp only [LE.opposite, Min.oppositeMax]
letI := il; letI := im
exact LawfulOrderInf.le_min_iff c a b }
scoped instance (priority := low) instLawfulOrderMinOpposite {il : LE α} {im : Min α}
[LawfulOrderMin α] :
haveI := il.opposite
haveI := im.oppositeMax
LawfulOrderMax α :=
letI := il.opposite
letI := im.oppositeMax
{ max_eq_or a b := by
simp only [Min.oppositeMax]
letI := il; letI := im
exact MinEqOr.min_eq_or a b
max_le_iff a b c := by
simp only [LE.opposite, Min.oppositeMax]
letI := il; letI := im
exact LawfulOrderInf.le_min_iff c a b }
scoped instance (priority := low) instLawfulOrderSupOpposite {il : LE α} {im : Max α}
[LawfulOrderSup α] :
haveI := il.opposite
haveI := im.oppositeMin
LawfulOrderInf α :=
letI := il.opposite
letI := im.oppositeMin
{ le_min_iff a b c := by
simp only [LE.opposite, Max.oppositeMin]
letI := il; letI := im
exact LawfulOrderSup.max_le_iff b c a }
scoped instance (priority := low) instLawfulOrderMaxOpposite {il : LE α} {im : Max α}
[LawfulOrderMax α] :
haveI := il.opposite
haveI := im.oppositeMin
LawfulOrderMin α :=
letI := il.opposite
letI := im.oppositeMin
{ min_eq_or a b := by
simp only [Max.oppositeMin]
letI := il; letI := im
exact MaxEqOr.max_eq_or a b
le_min_iff a b c := by
simp only [LE.opposite, Max.oppositeMin]
letI := il; letI := im
exact LawfulOrderSup.max_le_iff b c a }
scoped instance (priority := low) instLawfulOrderLeftLeaningMinOpposite {il : LE α} {im : Min α}
[LawfulOrderLeftLeaningMin α] :
haveI := il.opposite
haveI := im.oppositeMax
LawfulOrderLeftLeaningMax α :=
letI := il.opposite
letI := im.oppositeMax
{ max_eq_left a b hab := by
simp only [Min.oppositeMax]
letI := il; letI := im
exact LawfulOrderLeftLeaningMin.min_eq_left a b hab
max_eq_right a b hab := by
simp only [Min.oppositeMax]
letI := il; letI := im
exact LawfulOrderLeftLeaningMin.min_eq_right a b hab }
scoped instance (priority := low) instLawfulOrderLeftLeaningMaxOpposite {il : LE α} {im : Max α}
[LawfulOrderLeftLeaningMax α] :
haveI := il.opposite
haveI := im.oppositeMin
LawfulOrderLeftLeaningMin α :=
letI := il.opposite
letI := im.oppositeMin
{ min_eq_left a b hab := by
simp only [Max.oppositeMin]
letI := il; letI := im
exact LawfulOrderLeftLeaningMax.max_eq_left a b hab
min_eq_right a b hab := by
simp only [Max.oppositeMin]
letI := il; letI := im
exact LawfulOrderLeftLeaningMax.max_eq_right a b hab }
end OppositeOrderInstances
-- When imported from a non-module, these instances are exposed, and reducing them during
-- type class resolution is too inefficient.
attribute [irreducible] LE.opposite LT.opposite Min.oppositeMax Max.oppositeMin
section DocsToVerso
set_option linter.unusedVariables false -- Otherwise, we get warnings about Verso code blocks.
docs_to_verso LE.opposite
docs_to_verso LT.opposite
docs_to_verso Min.oppositeMax
docs_to_verso Max.oppositeMin
end DocsToVerso

View File

@@ -85,9 +85,12 @@ theorem toList_eq {α : Type u} {it : Iter (α := SubarrayIterator α) α} :
· rw [dif_neg]; rotate_left; exact h
simp_all [it.internalState.xs.stop_le_array_size]
theorem count_eq {α : Type u} {it : Iter (α := SubarrayIterator α) α} :
it.count = it.internalState.xs.stop - it.internalState.xs.start := by
simp [ Iter.length_toList_eq_count, toList_eq, it.internalState.xs.stop_le_array_size]
theorem length_eq {α : Type u} {it : Iter (α := SubarrayIterator α) α} :
it.length = it.internalState.xs.stop - it.internalState.xs.start := by
simp [ Iter.length_toList_eq_length, toList_eq, it.internalState.xs.stop_le_array_size]
@[deprecated length_eq (since := "2026-01-28")]
def count_eq := @length_eq
end SubarrayIterator
@@ -105,7 +108,7 @@ theorem toList_internalIter {α : Type u} {s : Subarray α} :
public instance : LawfulSliceSize (Internal.SubarrayData α) where
lawful s := by
simp [SliceSize.size, ToIterator.iter_eq, Iter.toIter_toIterM,
Iter.length_toList_eq_count, SubarrayIterator.toList_eq,
Iter.length_toList_eq_length, SubarrayIterator.toList_eq,
s.internalRepresentation.stop_le_array_size, start, stop, array]
public theorem toArray_eq_sliceToArray {α : Type u} {s : Subarray α} :

View File

@@ -60,12 +60,15 @@ public theorem forIn_toArray {γ : Type u} {β : Type v}
ForIn.forIn s.toArray init f = ForIn.forIn s init f := by
rw [ forIn_internalIter, Iter.forIn_toArray, Slice.toArray]
theorem Internal.size_eq_count_iter [ToIterator (Slice γ) Id α β]
theorem Internal.size_eq_length_iter [ToIterator (Slice γ) Id α β]
[Iterator α Id β] [Finite α Id]
[IteratorLoop α Id Id] [LawfulIteratorLoop α Id Id]
{s : Slice γ} [SliceSize γ] [LawfulSliceSize γ] :
s.size = (Internal.iter s).count := by
simp only [Slice.size, iter, LawfulSliceSize.lawful, Iter.length_toList_eq_count]
s.size = (Internal.iter s).length := by
simp only [Slice.size, iter, LawfulSliceSize.lawful, Iter.length_toList_eq_length]
@[deprecated Internal.size_eq_length_iter (since := "2026-01-28")]
def Internal.size_eq_count_iter := @Internal.size_eq_length_iter
theorem Internal.toArray_eq_toArray_iter {s : Slice γ} [ToIterator (Slice γ) Id α β]
[Iterator α Id β]
@@ -91,7 +94,7 @@ theorem size_toArray_eq_size [ToIterator (Slice γ) Id α β]
{s : Slice γ} :
s.toArray.size = s.size := by
letI : IteratorLoop α Id Id := .defaultImplementation
rw [Internal.size_eq_count_iter, Internal.toArray_eq_toArray_iter, Iter.size_toArray_eq_count]
rw [Internal.size_eq_length_iter, Internal.toArray_eq_toArray_iter, Iter.size_toArray_eq_length]
@[simp]
theorem length_toList_eq_size [ToIterator (Slice γ) Id α β]
@@ -100,7 +103,7 @@ theorem length_toList_eq_size [ToIterator (Slice γ) Id α β]
[Finite α Id] :
s.toList.length = s.size := by
letI : IteratorLoop α Id Id := .defaultImplementation
rw [Internal.size_eq_count_iter, Internal.toList_eq_toList_iter, Iter.length_toList_eq_count]
rw [Internal.size_eq_length_iter, Internal.toList_eq_toList_iter, Iter.length_toList_eq_length]
@[simp]
theorem length_toListRev_eq_size [ToIterator (Slice γ) Id α β]
@@ -109,7 +112,7 @@ theorem length_toListRev_eq_size [ToIterator (Slice γ) Id α β]
[Finite α Id]
[LawfulIteratorLoop α Id Id] :
s.toListRev.length = s.size := by
rw [Internal.size_eq_count_iter, Internal.toListRev_eq_toListRev_iter,
Iter.length_toListRev_eq_count]
rw [Internal.size_eq_length_iter, Internal.toListRev_eq_toListRev_iter,
Iter.length_toListRev_eq_length]
end Std.Slice

View File

@@ -34,7 +34,7 @@ attribute [instance] ListSlice.instToIterator
universe v w
instance : SliceSize (Internal.ListSliceData α) where
size s := (Internal.iter s).count
size s := (Internal.iter s).length
@[no_expose]
instance {α : Type u} {m : Type v Type w} [Monad m] :

View File

@@ -60,7 +60,7 @@ public theorem toList_toArray {xs : ListSlice α} :
@[simp, grind =]
public theorem length_toList {xs : ListSlice α} :
xs.toList.length = xs.size := by
simp [ListSlice.toList_eq, Std.Slice.size, Std.Slice.SliceSize.size, Iter.length_toList_eq_count,
simp [ListSlice.toList_eq, Std.Slice.size, Std.Slice.SliceSize.size, Iter.length_toList_eq_length,
toList_internalIter]; rfl
@[grind =]

View File

@@ -45,7 +45,7 @@ class LawfulSliceSize (γ : Type u) [SliceSize γ] [ToIterator (Slice γ) Id α
/-- The iterator of a slice `s` of type `Slice γ` emits exactly `SliceSize.size s` elements. -/
lawful :
letI : IteratorLoop α Id Id := .defaultImplementation
s : Slice γ, SliceSize.size s = (ToIterator.iter (γ := Slice γ) s).count
s : Slice γ, SliceSize.size s = (ToIterator.iter (γ := Slice γ) s).length
/--
Returns the number of elements with distinct indices in the given slice.

View File

@@ -905,9 +905,9 @@ Examples:
def chars (s : Slice) :=
Std.Iter.map (fun pos, h => pos.get h) (positions s)
@[deprecated "There is no constant-time length function on slices. Use `s.positions.count` instead, or `isEmpty` if you only need to know whether the slice is empty." (since := "2025-11-20")]
@[deprecated "There is no constant-time length function on slices. Use `s.positions.length` instead, or `isEmpty` if you only need to know whether the slice is empty." (since := "2025-11-20")]
def length (s : Slice) : Nat :=
s.positions.count
s.positions.length
structure RevPosIterator (s : Slice) where
currPos : s.Pos

View File

@@ -137,6 +137,11 @@ structure Config where
For local theorems, use `+suggestions` instead.
-/
locals : Bool := false
/--
If `instances` is `true`, `dsimp` will visit instance arguments.
If option `backward.dsimp.instances` is `true`, it overrides this field.
-/
instances : Bool := false
deriving Inhabited, BEq
end DSimp
@@ -308,6 +313,11 @@ structure Config where
For local theorems, use `+suggestions` instead.
-/
locals : Bool := false
/--
If `instances` is `true`, `dsimp` will visit instance arguments.
If option `backward.dsimp.instances` is `true`, it overrides this field.
-/
instances : Bool := false
deriving Inhabited, BEq
-- Configuration object for `simp_all`
@@ -374,7 +384,7 @@ structure ExtractLetsConfig where
/-- If true (default: false), eliminate unused lets rather than extract them. -/
usedOnly : Bool := false
/-- If true (default: true), reuse local declarations that have syntactically equal values.
Note that even when false, the caching strategy for `extract_let`s may result in fewer extracted let bindings than expected. -/
Note that even when false, the caching strategy for `extract_lets` may result in fewer extracted let bindings than expected. -/
merge : Bool := true
/-- When merging is enabled, if true (default: true), make use of pre-existing local definitions in the local context. -/
useContext : Bool := true

View File

@@ -872,6 +872,12 @@ Substring matching:
(after whitespace normalization). This is useful when you only care about part of the message.
- `substring := false` (the default) requires exact matching (modulo whitespace normalization).
Stabilizing output:
When messages contain autogenerated names (e.g., metavariables like `?m.47`), the output may
differ between runs or Lean versions. Use `set_option pp.mvars.anonymous false` to replace
anonymous metavariables with `?_` while preserving user-named metavariables like `?a`.
Alternatively, `set_option pp.mvars false` replaces all metavariables with `?_`.
For example, `#guard_msgs (error, drop all) in cmd` means to check errors and drop
everything else.

View File

@@ -322,6 +322,10 @@ For more information: [Equality](https://lean-lang.org/theorem_proving_in_lean4/
@[symm] theorem Eq.symm {α : Sort u} {a b : α} (h : Eq a b) : Eq b a :=
h rfl
/-- Non-dependent recursor for the equality type (symmetric variant) -/
@[simp] abbrev Eq.ndrec_symm.{u1, u2} {α : Sort u2} {a : α} {motive : α Sort u1} (m : motive a) {b : α} (h : Eq b a) : motive b :=
h.symm.ndrec m
/--
Equality is transitive: if `a = b` and `b = c` then `a = c`.

View File

@@ -3,12 +3,9 @@ Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura, Mario Carneiro
-/
module
prelude
public import Init.Data.Array.Set
public section
/-!

View File

@@ -110,11 +110,13 @@ def fileUriToPath? (uri : String) : Option System.FilePath := Id.run do
else
let mut p := (unescapeUri uri).drop "file://".length |>.copy
p := p.dropWhile (λ c => c != '/') |>.copy -- drop the hostname.
-- On Windows, the path "/c:/temp" needs to become "C:/temp"
if System.Platform.isWindows && p.length >= 2 &&
p.front == '/' && (String.Pos.Raw.get p 1).isAlpha && String.Pos.Raw.get p 2 == ':' then
-- see also `pathToUri`
p := String.Pos.Raw.modify (p.drop 1).copy 0 .toUpper
if System.Platform.isWindows then
-- On Windows, the path "/c:/temp" needs to become "C:/temp"
if p.length >= 2 &&
p.front == '/' && (String.Pos.Raw.get p 1).isAlpha && String.Pos.Raw.get p 2 == ':' then
-- see also `pathToUri`
p := String.Pos.Raw.modify (p.drop 1).copy 0 .toUpper
p := p.map (fun c => if c == '/' then '\\' else c)
some p
end Uri

View File

@@ -1093,8 +1093,6 @@ See also:
* `first | tac1 | tac2` implements the backtracking used by `repeat`
-/
syntax "repeat " tacticSeq : tactic
macro_rules
| `(tactic| repeat $seq) => `(tactic| first | ($seq); repeat $seq | skip)
/--
`repeat' tac` recursively applies `tac` on all of the goals so long as it succeeds.

View File

@@ -270,10 +270,11 @@ def registerParametricAttribute (impl : ParametricAttributeImpl α) : IO (Parame
let mut r := if impl.preserveOrder then
decls.toArray.reverse.filterMap (fun n => return (n, m.find? n))
else
m.foldl (fun a n p => a.push (n, p)) #[]
let r := m.foldl (fun a n p => a.push (n, p)) #[]
r.qsort (fun a b => Name.quickLt a.1 b.1)
if lvl != .private then
r := r.filter (fun n, a => impl.filterExport env n a)
r.qsort (fun a b => Name.quickLt a.1 b.1)
r
statsFn := fun (_, m) => "parametric attribute" ++ Format.line ++ "number of local entries: " ++ format m.size
}
let attrImpl : AttributeImpl := {

View File

@@ -33,6 +33,7 @@ def isAuxRecursor (env : Environment) (declName : Name) : Bool :=
-- TODO: use `markAuxRecursor` when they are defined
-- An attribute is not a good solution since we don't want users to control what is tagged as an auxiliary recursor.
|| declName == ``Eq.ndrec
|| declName == ``Eq.ndrec_symm
|| declName == ``Eq.ndrecOn
def isAuxRecursorWithSuffix (env : Environment) (declName : Name) (suffix : String) : Bool :=

View File

@@ -115,10 +115,10 @@ private def exportIREntries (env : Environment) : Array (Name × Array EnvExtens
-- safety: cast to erased type
let irEntries : Array EnvExtensionEntry := unsafe unsafeCast <| sortDecls irDecls
-- see `regularInitAttr.filterExport`
let initDecls : Array (Name × Name) := regularInitAttr.ext.getState env
|>.2.foldl (fun a n p => a.push (n, p)) #[]
|>.qsort (fun a b => Name.quickLt a.1 b.1)
-- save all initializers independent of meta/private. Non-meta initializers will only be used when
-- .ir is actually loaded, and private ones iff visible.
let initDecls : Array (Name × Name) :=
regularInitAttr.ext.exportEntriesFn env (regularInitAttr.ext.getState env) .private
-- safety: cast to erased type
let initDecls : Array EnvExtensionEntry := unsafe unsafeCast initDecls

View File

@@ -40,14 +40,14 @@ structure BuilderState where
For this reason we carry around these kinds of bindings in this substitution and apply it whenever
we access an fvar in the conversion.
-/
subst : LCNF.FVarSubst := {}
subst : LCNF.FVarSubst .pure := {}
abbrev M := StateRefT BuilderState CoreM
instance : LCNF.MonadFVarSubst M false where
instance : LCNF.MonadFVarSubst M .pure false where
getSubst := return ( get).subst
instance : LCNF.MonadFVarSubstState M where
instance : LCNF.MonadFVarSubstState M .pure where
modifySubst f := modify fun s => { s with subst := f s.subst }
def M.run (x : M α) : CoreM α := do
@@ -102,7 +102,7 @@ def lowerLitValue (v : LCNF.LitValue) : LitVal × IRType :=
| .uint64 v => .num (UInt64.toNat v), .uint64
| .usize v => .num (UInt64.toNat v), .usize
def lowerArg (a : LCNF.Arg) : M Arg := do
def lowerArg (a : LCNF.Arg .pure) : M Arg := do
match a with
| .fvar fvarId => getFVarValue fvarId
| .erased | .type .. => return .erased
@@ -121,15 +121,15 @@ def lowerProj (base : VarId) (ctorInfo : CtorInfo) (field : CtorFieldInfo)
| .erased => .erased, .erased
| .void => .erased, .void
def lowerParam (p : LCNF.Param) : M Param := do
def lowerParam (p : LCNF.Param .pure) : M Param := do
let x bindVar p.fvarId
let ty toIRType p.type
if ty.isVoid || ty.isErased then
Compiler.LCNF.addSubst p.fvarId .erased
Compiler.LCNF.addSubst p.fvarId (.erased : LCNF.Arg .pure)
return { x, borrow := p.borrow, ty }
mutual
partial def lowerCode (c : LCNF.Code) : M FnBody := do
partial def lowerCode (c : LCNF.Code .pure) : M FnBody := do
match c with
| .let decl k => lowerLet decl k
| .jp decl k =>
@@ -149,7 +149,7 @@ partial def lowerCode (c : LCNF.Code) : M FnBody := do
for idx in 0...ps.size do
let p := ps[idx]!
if idx == info.fieldIdx then
LCNF.addSubst p.fvarId (.fvar cases.discr)
LCNF.addSubst p.fvarId (.fvar cases.discr : LCNF.Arg .pure)
else
bindErased p.fvarId
lowerCode k
@@ -165,7 +165,7 @@ partial def lowerCode (c : LCNF.Code) : M FnBody := do
| .unreach .. => return .unreachable
| .fun .. => panic! "all local functions should be λ-lifted"
partial def lowerLet (decl : LCNF.LetDecl) (k : LCNF.Code) : M FnBody := do
partial def lowerLet (decl : LCNF.LetDecl .pure) (k : LCNF.Code .pure) : M FnBody := do
let value LCNF.normLetValue decl.value
match value with
| .lit litValue =>
@@ -175,7 +175,7 @@ partial def lowerLet (decl : LCNF.LetDecl) (k : LCNF.Code) : M FnBody := do
| .proj typeName i fvarId =>
if let some info hasTrivialStructure? typeName then
if info.fieldIdx == i then
LCNF.addSubst decl.fvarId (.fvar fvarId)
LCNF.addSubst decl.fvarId (.fvar fvarId : LCNF.Arg .pure)
else
bindErased decl.fvarId
lowerCode k
@@ -250,7 +250,8 @@ partial def lowerLet (decl : LCNF.LetDecl) (k : LCNF.Code) : M FnBody := do
| some (.defnInfo ..) | some (.opaqueInfo ..) =>
mkFap name irArgs
| some (.axiomInfo ..) | .some (.quotInfo ..) | .some (.inductInfo ..) | .some (.thmInfo ..) =>
throwNamedError lean.dependsOnNoncomputable f!"`{name}` not supported by code generator; consider marking definition as `noncomputable`"
-- Should have been caught by `ToLCNF`
throwError f!"ToIR: unexpected use of noncomputable declaration `{name}`; please report this issue"
| some (.recInfo ..) =>
throwError f!"code generator does not support recursor `{name}` yet, consider using 'match ... with' and/or structural recursion"
| none => panic! "reference to unbound name"
@@ -302,11 +303,11 @@ where
else
mkOverApplication name numParams args
partial def lowerAlt (discr : VarId) (a : LCNF.Alt) : M Alt := do
partial def lowerAlt (discr : VarId) (a : LCNF.Alt .pure) : M Alt := do
match a with
| .alt ctorName params code =>
let ctorInfo, fields getCtorLayout ctorName
let lowerParams (params : Array LCNF.Param) (fields : Array CtorFieldInfo) : M FnBody := do
let lowerParams (params : Array (LCNF.Param .pure)) (fields : Array CtorFieldInfo) : M FnBody := do
let rec loop (i : Nat) : M FnBody := do
match params[i]?, fields[i]? with
| some param, some field =>
@@ -340,7 +341,7 @@ where resultTypeForArity (type : Lean.Expr) (arity : Nat) : Lean.Expr :=
| .const ``lcErased _ => mkConst ``lcErased
| _ => panic! "invalid arity"
def lowerDecl (d : LCNF.Decl) : M (Option Decl) := do
def lowerDecl (d : LCNF.Decl .pure) : M (Option Decl) := do
let params d.params.mapM lowerParam
let mut resultType lowerResultType d.type d.params.size
let taggedReturn := taggedReturnAttr.hasTag ( getEnv) d.name
@@ -366,7 +367,7 @@ def lowerDecl (d : LCNF.Decl) : M (Option Decl) := do
end ToIR
def toIR (decls: Array LCNF.Decl) : CoreM (Array Decl) := do
def toIR (decls: Array (LCNF.Decl .pure)) : CoreM (Array Decl) := do
let mut irDecls := #[]
for decl in decls do
if let some irDecl ToIR.lowerDecl decl |>.run then

View File

@@ -174,8 +174,11 @@ private unsafe def runInitAttrs (env : Environment) (opts : Options) : IO Unit :
continue
interpretedModInits.modify (·.insert mod)
let modEntries := regularInitAttr.ext.getModuleEntries env modIdx
-- `getModuleIREntries` is identical to `getModuleEntries` if we loaded only one of .olean/.ir
-- so deduplicate (these lists should be very short)
-- `getModuleIREntries` is identical to `getModuleEntries` if we loaded only one of
-- .olean (from `meta initialize`)/.ir (`initialize` via transitive `meta import`)
-- so deduplicate (these lists should be very short).
-- If we have both, we should not need to worry about their relative ordering as `meta` and
-- non-`meta` initialize should not have interdependencies.
let modEntries := modEntries ++ (regularInitAttr.ext.getModuleIREntries env modIdx).filter (!modEntries.contains ·)
for (decl, initDecl) in modEntries do
-- Skip initializers we do not have IR for; they should not be reachable by interpretation.

View File

@@ -30,7 +30,6 @@ public import Lean.Compiler.LCNF.ReduceJpArity
public import Lean.Compiler.LCNF.Simp
public import Lean.Compiler.LCNF.Specialize
public import Lean.Compiler.LCNF.SpecInfo
public import Lean.Compiler.LCNF.Testing
public import Lean.Compiler.LCNF.ToDecl
public import Lean.Compiler.LCNF.ToExpr
public import Lean.Compiler.LCNF.ToLCNF

View File

@@ -40,14 +40,14 @@ def eqvTypes (es₁ es₂ : Array Expr) : EqvM Bool := do
else
return false
def eqvArg (a₁ a₂ : Arg) : EqvM Bool := do
def eqvArg (a₁ a₂ : Arg pu) : EqvM Bool := do
match a₁, a₂ with
| .type e₁, .type e₂ => eqvType e₁ e₂
| .type e₁ _, .type e₂ _ => eqvType e₁ e₂
| .fvar x₁, .fvar x₂ => eqvFVar x₁ x₂
| .erased, .erased => return true
| _, _ => return false
def eqvArgs (as₁ as₂ : Array Arg) : EqvM Bool := do
def eqvArgs (as₁ as₂ : Array (Arg pu)) : EqvM Bool := do
if as₁.size = as₂.size then
for a₁ in as₁, a₂ in as₂ do
unless ( eqvArg a₁ a₂) do
@@ -56,19 +56,19 @@ def eqvArgs (as₁ as₂ : Array Arg) : EqvM Bool := do
else
return false
def eqvLetValue (e₁ e₂ : LetValue) : EqvM Bool := do
def eqvLetValue (e₁ e₂ : LetValue pu) : EqvM Bool := do
match e₁, e₂ with
| .lit v₁, .lit v₂ => return v₁ == v₂
| .erased, .erased => return true
| .proj s₁ i₁ x₁, .proj s₂ i₂ x₂ => pure (s₁ == s₂ && i₁ == i₂) <&&> eqvFVar x₁ x₂
| .const n₁ us₁ as₁, .const n₂ us₂ as₂ => pure (n₁ == n₂ && us₁ == us₂) <&&> eqvArgs as₁ as₂
| .proj s₁ i₁ x₁ _, .proj s₂ i₂ x₂ _ => pure (s₁ == s₂ && i₁ == i₂) <&&> eqvFVar x₁ x₂
| .const n₁ us₁ as₁ _, .const n₂ us₂ as₂ _ => pure (n₁ == n₂ && us₁ == us₂) <&&> eqvArgs as₁ as₂
| .fvar f₁ as₁, .fvar f₂ as₂ => eqvFVar f₁ f₂ <&&> eqvArgs as₁ as₂
| _, _ => return false
@[inline] def withFVar (fvarId₁ fvarId₂ : FVarId) (x : EqvM α) : EqvM α :=
withReader (·.insert fvarId₂ fvarId₁) x
@[inline] def withParams (params₁ params₂ : Array Param) (x : EqvM Bool) : EqvM Bool := do
@[inline] def withParams (params₁ params₂ : Array (Param pu)) (x : EqvM Bool) : EqvM Bool := do
if h : params₂.size = params₁.size then
let rec @[specialize] go (i : Nat) : EqvM Bool := do
if h : i < params₁.size then
@@ -85,7 +85,7 @@ def eqvLetValue (e₁ e₂ : LetValue) : EqvM Bool := do
else
return false
def sortAlts (alts : Array Alt) : Array Alt :=
def sortAlts (alts : Array (Alt pu)) : Array (Alt pu) :=
alts.qsort fun
| .alt .., .default .. => true
| .alt ctorName₁ .., .alt ctorName₂ .. => Name.lt ctorName₁ ctorName₂
@@ -93,13 +93,13 @@ def sortAlts (alts : Array Alt) : Array Alt :=
mutual
partial def eqvAlts (alts₁ alts₂ : Array Alt) : EqvM Bool := do
partial def eqvAlts (alts₁ alts₂ : Array (Alt pu)) : EqvM Bool := do
if alts₁.size = alts₂.size then
let alts₁ := sortAlts alts₁
let alts₂ := sortAlts alts₂
for alt₁ in alts₁, alt₂ in alts₂ do
match alt₁, alt₂ with
| .alt ctorName₁ ps₁ k₁, .alt ctorName₂ ps₂ k₂ =>
| .alt ctorName₁ ps₁ k₁ _, .alt ctorName₂ ps₂ k₂ _ =>
unless ctorName₁ == ctorName₂ do return false
unless ( withParams ps₁ ps₂ (eqv k₁ k₂)) do return false
| .default k₁, .default k₂ => unless ( eqv k₁ k₂) do return false
@@ -108,13 +108,13 @@ partial def eqvAlts (alts₁ alts₂ : Array Alt) : EqvM Bool := do
else
return false
partial def eqv (code₁ code₂ : Code) : EqvM Bool := do
partial def eqv (code₁ code₂ : Code pu) : EqvM Bool := do
match code₁, code₂ with
| .let decl₁ k₁, .let decl₂ k₂ =>
eqvType decl₁.type decl₂.type <&&>
eqvLetValue decl₁.value decl₂.value <&&>
withFVar decl₁.fvarId decl₂.fvarId (eqv k₁ k₂)
| .fun decl₁ k₁, .fun decl₂ k₂
| .fun decl₁ k₁ _, .fun decl₂ k₂ _
| .jp decl₁ k₁, .jp decl₂ k₂ =>
eqvType decl₁.type decl₂.type <&&>
withParams decl₁.params decl₂.params (eqv decl₁.value decl₂.value) <&&>
@@ -135,7 +135,7 @@ end AlphaEqv
/--
Return `true` if `c₁` and `c₂` are alpha equivalent.
-/
def Code.alphaEqv (c₁ c₂ : Code) : Bool :=
def Code.alphaEqv (c₁ c₂ : Code pu) : Bool :=
AlphaEqv.eqv c₁ c₂ |>.run {}
end Lean.Compiler.LCNF

View File

@@ -13,15 +13,21 @@ public section
namespace Lean.Compiler.LCNF
builtin_initialize auxDeclCacheExt : CacheExtension Decl Name CacheExtension.register
structure AuxDeclCacheKey where
pu : Purity
decl : Decl pu
deriving BEq, Hashable
builtin_initialize auxDeclCacheExt : CacheExtension AuxDeclCacheKey Name CacheExtension.register
inductive CacheAuxDeclResult where
| new
| alreadyCached (declName : Name)
def cacheAuxDecl (decl : Decl) : CompilerM CacheAuxDeclResult := do
def cacheAuxDecl (decl : Decl pu) : CompilerM CacheAuxDeclResult := do
let key := { decl with name := .anonymous }
let key normalizeFVarIds key
let key := pu, key
match ( auxDeclCacheExt.find? key) with
| some declName =>
return .alreadyCached declName

View File

@@ -24,14 +24,50 @@ and the approach described in the paper
-/
structure Param where
/--
This type is used to index the fundamental LCNF IR data structures. Depending on its value different
constructors are available for the different semantic phases of LCNF.
Notably in order to save memory we never index the IR types over `Purity`. Instead the type is
parametrized by the phase and the individual constructors might carry a proof (that will be erased)
that they are only allowed in a certain phase.
-/
inductive Purity where
/--
The code we are acting on is still pure, things like reordering up to value dependencies are
acceptable.
-/
| pure
/--
The code we are acting on is to be considered generally impure, doing reorderings is potentially
no longer legal.
-/
| impure
deriving Inhabited, DecidableEq, Hashable
instance : ToString Purity where
toString
| .pure => "pure"
| .impure => "impure"
@[inline]
def Purity.withAssertPurity [Inhabited α] (is : Purity) (should : Purity)
(k : (is = should) α) : α :=
if h : is = should then
k h
else
panic! s!"Purity should be {should} but is {is}, this is a bug"
scoped macro "purity_tac" : tactic => `(tactic| first | with_reducible rfl | assumption)
structure Param (pu : Purity) where
fvarId : FVarId
binderName : Name
type : Expr
borrow : Bool
deriving Inhabited, BEq
def Param.toExpr (p : Param) : Expr :=
def Param.toExpr (p : Param pu) : Expr :=
.fvar p.fvarId
inductive LitValue where
@@ -55,111 +91,111 @@ def LitValue.toExpr : LitValue → Expr
| .uint64 v => .app (.const ``UInt64.ofNat []) (.lit (.natVal (UInt64.toNat v)))
| .usize v => .app (.const ``USize.ofNat []) (.lit (.natVal (UInt64.toNat v)))
inductive Arg where
inductive Arg (pu : Purity) where
| erased
| fvar (fvarId : FVarId)
| type (expr : Expr)
| type (expr : Expr) (h : pu = .pure := by purity_tac)
deriving Inhabited, BEq, Hashable
def Param.toArg (p : Param) : Arg :=
def Param.toArg (p : Param pu) : Arg pu :=
.fvar p.fvarId
def Arg.toExpr (arg : Arg) : Expr :=
def Arg.toExpr (arg : Arg pu) : Expr :=
match arg with
| .erased => erasedExpr
| .fvar fvarId => .fvar fvarId
| .type e => e
| .type e _ => e
private unsafe def Arg.updateTypeImp (arg : Arg) (type' : Expr) : Arg :=
private unsafe def Arg.updateTypeImp (arg : Arg pu) (type' : Expr) : Arg pu :=
match arg with
| .type ty => if ptrEq ty type' then arg else .type type'
| .type ty _ => if ptrEq ty type' then arg else .type type'
| _ => unreachable!
@[implemented_by Arg.updateTypeImp] opaque Arg.updateType! (arg : Arg) (type : Expr) : Arg
@[implemented_by Arg.updateTypeImp] opaque Arg.updateType! (arg : Arg pu) (type : Expr) : Arg pu
private unsafe def Arg.updateFVarImp (arg : Arg) (fvarId' : FVarId) : Arg :=
private unsafe def Arg.updateFVarImp (arg : Arg pu) (fvarId' : FVarId) : Arg pu :=
match arg with
| .fvar fvarId => if fvarId' == fvarId then arg else .fvar fvarId'
| _ => unreachable!
@[implemented_by Arg.updateFVarImp] opaque Arg.updateFVar! (arg : Arg) (fvarId' : FVarId) : Arg
@[implemented_by Arg.updateFVarImp] opaque Arg.updateFVar! (arg : Arg pu) (fvarId' : FVarId) : Arg pu
inductive LetValue where
inductive LetValue (pu : Purity) where
| lit (value : LitValue)
| erased
| proj (typeName : Name) (idx : Nat) (struct : FVarId)
| const (declName : Name) (us : List Level) (args : Array Arg)
| fvar (fvarId : FVarId) (args : Array Arg)
| proj (typeName : Name) (idx : Nat) (struct : FVarId) (h : pu = .pure := by purity_tac)
| const (declName : Name) (us : List Level) (args : Array (Arg pu)) (h : pu = .pure := by purity_tac)
| fvar (fvarId : FVarId) (args : Array (Arg pu))
deriving Inhabited, BEq, Hashable
def Arg.toLetValue (arg : Arg) : LetValue :=
def Arg.toLetValue (arg : Arg pu) : LetValue pu :=
match arg with
| .fvar fvarId => .fvar fvarId #[]
| .erased | .type .. => .erased
private unsafe def LetValue.updateProjImp (e : LetValue) (fvarId' : FVarId) : LetValue :=
private unsafe def LetValue.updateProjImp (e : LetValue pu) (fvarId' : FVarId) : LetValue pu :=
match e with
| .proj s i fvarId => if fvarId == fvarId' then e else .proj s i fvarId'
| .proj s i fvarId _ => if fvarId == fvarId' then e else .proj s i fvarId'
| _ => unreachable!
@[implemented_by LetValue.updateProjImp] opaque LetValue.updateProj! (e : LetValue) (fvarId' : FVarId) : LetValue
@[implemented_by LetValue.updateProjImp] opaque LetValue.updateProj! (e : LetValue pu) (fvarId' : FVarId) : LetValue pu
private unsafe def LetValue.updateConstImp (e : LetValue) (declName' : Name) (us' : List Level) (args' : Array Arg) : LetValue :=
private unsafe def LetValue.updateConstImp (e : LetValue pu) (declName' : Name) (us' : List Level) (args' : Array (Arg pu)) : LetValue pu :=
match e with
| .const declName us args => if declName == declName' && ptrEq us us' && ptrEq args args' then e else .const declName' us' args'
| .const declName us args _ => if declName == declName' && ptrEq us us' && ptrEq args args' then e else .const declName' us' args'
| _ => unreachable!
@[implemented_by LetValue.updateConstImp] opaque LetValue.updateConst! (e : LetValue) (declName' : Name) (us' : List Level) (args' : Array Arg) : LetValue
@[implemented_by LetValue.updateConstImp] opaque LetValue.updateConst! (e : LetValue pu) (declName' : Name) (us' : List Level) (args' : Array (Arg pu)) : LetValue pu
private unsafe def LetValue.updateFVarImp (e : LetValue) (fvarId' : FVarId) (args' : Array Arg) : LetValue :=
private unsafe def LetValue.updateFVarImp (e : LetValue pu) (fvarId' : FVarId) (args' : Array (Arg pu)) : LetValue pu :=
match e with
| .fvar fvarId args => if fvarId == fvarId' && ptrEq args args' then e else .fvar fvarId' args'
| _ => unreachable!
@[implemented_by LetValue.updateFVarImp] opaque LetValue.updateFVar! (e : LetValue) (fvarId' : FVarId) (args' : Array Arg) : LetValue
@[implemented_by LetValue.updateFVarImp] opaque LetValue.updateFVar! (e : LetValue pu) (fvarId' : FVarId) (args' : Array (Arg pu)) : LetValue pu
private unsafe def LetValue.updateArgsImp (e : LetValue) (args' : Array Arg) : LetValue :=
private unsafe def LetValue.updateArgsImp (e : LetValue pu) (args' : Array (Arg pu)) : LetValue pu :=
match e with
| .const declName us args => if ptrEq args args' then e else .const declName us args'
| .const declName us args h => if ptrEq args args' then e else .const declName us args'
| .fvar fvarId args => if ptrEq args args' then e else .fvar fvarId args'
| _ => unreachable!
@[implemented_by LetValue.updateArgsImp] opaque LetValue.updateArgs! (e : LetValue) (args' : Array Arg) : LetValue
@[implemented_by LetValue.updateArgsImp] opaque LetValue.updateArgs! (e : LetValue pu) (args' : Array (Arg pu)) : LetValue pu
def LetValue.toExpr (e : LetValue) : Expr :=
def LetValue.toExpr (e : LetValue pu) : Expr :=
match e with
| .lit v => v.toExpr
| .erased => erasedExpr
| .proj n i s => .proj n i (.fvar s)
| .const n us as => mkAppN (.const n us) (as.map Arg.toExpr)
| .proj n i s _ => .proj n i (.fvar s)
| .const n us as _ => mkAppN (.const n us) (as.map Arg.toExpr)
| .fvar fvarId as => mkAppN (.fvar fvarId) (as.map Arg.toExpr)
structure LetDecl where
structure LetDecl (pu : Purity) where
fvarId : FVarId
binderName : Name
type : Expr
value : LetValue
value : LetValue pu
deriving Inhabited, BEq
mutual
inductive Alt where
| alt (ctorName : Name) (params : Array Param) (code : Code)
| default (code : Code)
inductive Alt (pu : Purity) where
| alt (ctorName : Name) (params : Array (Param pu)) (code : Code pu) (h : pu = .pure := by purity_tac)
| default (code : Code pu)
inductive FunDecl where
| mk (fvarId : FVarId) (binderName : Name) (params : Array Param) (type : Expr) (value : Code)
inductive FunDecl (pu : Purity) where
| mk (fvarId : FVarId) (binderName : Name) (params : Array (Param pu)) (type : Expr) (value : Code pu)
inductive Cases where
| mk (typeName : Name) (resultType : Expr) (discr : FVarId) (alts : Array Alt)
inductive Cases (pu : Purity) where
| mk (typeName : Name) (resultType : Expr) (discr : FVarId) (alts : Array (Alt pu))
deriving Inhabited
inductive Code where
| let (decl : LetDecl) (k : Code)
| fun (decl : FunDecl) (k : Code)
| jp (decl : FunDecl) (k : Code)
| jmp (fvarId : FVarId) (args : Array Arg)
| cases (cases : Cases)
inductive Code (pu : Purity) where
| let (decl : LetDecl pu) (k : Code pu)
| fun (decl : FunDecl pu) (k : Code pu) (h : pu = .pure := by purity_tac)
| jp (decl : FunDecl pu) (k : Code pu)
| jmp (fvarId : FVarId) (args : Array (Arg pu))
| cases (cases : Cases pu)
| return (fvarId : FVarId)
| unreach (type : Expr)
deriving Inhabited
@@ -167,99 +203,99 @@ inductive Code where
end
@[inline]
def FunDecl.fvarId : FunDecl FVarId
def FunDecl.fvarId : FunDecl pu FVarId
| .mk (fvarId := fvarId) .. => fvarId
@[inline]
def FunDecl.binderName : FunDecl Name
def FunDecl.binderName : FunDecl pu Name
| .mk (binderName := binderName) .. => binderName
@[inline]
def FunDecl.params : FunDecl Array Param
def FunDecl.params : FunDecl pu Array (Param pu)
| .mk (params := params) .. => params
@[inline]
def FunDecl.type : FunDecl Expr
def FunDecl.type : FunDecl pu Expr
| .mk (type := type) .. => type
@[inline]
def FunDecl.value : FunDecl Code
def FunDecl.value : FunDecl pu Code pu
| .mk (value := value) .. => value
@[inline]
def FunDecl.updateBinderName : FunDecl Name FunDecl
def FunDecl.updateBinderName : FunDecl pu Name FunDecl pu
| .mk fvarId _ params type value, new =>
.mk fvarId new params type value
@[inline]
def FunDecl.toParam (decl : FunDecl) (borrow : Bool) : Param :=
def FunDecl.toParam (decl : FunDecl pu) (borrow : Bool) : Param pu :=
match decl with
| .mk fvarId binderName _ type .. => fvarId, binderName, type, borrow
@[inline]
def Cases.typeName : Cases Name
def Cases.typeName : Cases pu Name
| .mk (typeName := typeName) .. => typeName
@[inline]
def Cases.resultType : Cases Expr
def Cases.resultType : Cases pu Expr
| .mk (resultType := resultType) .. => resultType
@[inline]
def Cases.discr : Cases FVarId
def Cases.discr : Cases pu FVarId
| .mk (discr := discr) .. => discr
@[inline]
def Cases.alts : Cases Array Alt
def Cases.alts : Cases pu Array (Alt pu)
| .mk (alts := alts) .. => alts
@[inline]
def Cases.updateAlts : Cases Array Alt Cases
def Cases.updateAlts : Cases pu Array (Alt pu) Cases pu
| .mk typeName resultType discr _, new =>
.mk typeName resultType discr new
deriving instance Inhabited for Alt
deriving instance Inhabited for FunDecl
def FunDecl.getArity (decl : FunDecl) : Nat :=
def FunDecl.getArity (decl : FunDecl pu) : Nat :=
decl.params.size
/--
Return the constructor names that have an explicit (non-default) alternative.
-/
def Cases.getCtorNames (c : Cases) : NameSet :=
def Cases.getCtorNames (c : Cases pu) : NameSet :=
c.alts.foldl (init := {}) fun ctorNames alt =>
match alt with
| .default _ => ctorNames
| .alt ctorName .. => ctorNames.insert ctorName
inductive CodeDecl where
| let (decl : LetDecl)
| fun (decl : FunDecl)
| jp (decl : FunDecl)
inductive CodeDecl (pu : Purity) where
| let (decl : LetDecl pu)
| fun (decl : FunDecl pu) (h : pu = .pure := by purity_tac)
| jp (decl : FunDecl pu)
deriving Inhabited
def CodeDecl.fvarId : CodeDecl FVarId
| .let decl | .fun decl | .jp decl => decl.fvarId
def CodeDecl.fvarId : CodeDecl pu FVarId
| .let decl | .fun decl _ | .jp decl => decl.fvarId
def attachCodeDecls (decls : Array CodeDecl) (code : Code) : Code :=
def attachCodeDecls (decls : Array (CodeDecl pu)) (code : Code pu) : Code pu :=
go decls.size code
where
go (i : Nat) (code : Code) : Code :=
go (i : Nat) (code : Code pu) : Code pu :=
if i > 0 then
match decls[i-1]! with
| .let decl => go (i-1) (.let decl code)
| .fun decl => go (i-1) (.fun decl code)
| .fun decl _ => go (i-1) (.fun decl code)
| .jp decl => go (i-1) (.jp decl code)
else
code
mutual
private unsafe def eqImp (c₁ c₂ : Code) : Bool :=
private unsafe def eqImp (c₁ c₂ : Code pu) : Bool :=
if ptrEq c₁ c₂ then
true
else match c₁, c₂ with
| .let d₁ k₁, .let d₂ k₂ => d₁ == d₂ && eqImp k₁ k₂
| .fun d₁ k₁, .fun d₂ k₂
| .fun d₁ k₁ _, .fun d₂ k₂ _
| .jp d₁ k₁, .jp d₂ k₂ => eqFunDecl d₁ d₂ && eqImp k₁ k₂
| .cases c₁, .cases c₂ => eqCases c₁ c₂
| .jmp j₁ as₁, .jmp j₂ as₂ => j₁ == j₂ && as₁ == as₂
@@ -267,7 +303,7 @@ mutual
| .unreach t₁, .unreach t₂ => t₁ == t₂
| _, _ => false
private unsafe def eqFunDecl (d₁ d₂ : FunDecl) : Bool :=
private unsafe def eqFunDecl (d₁ d₂ : FunDecl pu) : Bool :=
if ptrEq d₁ d₂ then
true
else
@@ -275,62 +311,62 @@ mutual
d₁.params == d₂.params && d₁.type == d₂.type &&
eqImp d₁.value d₂.value
private unsafe def eqCases (c₁ c₂ : Cases) : Bool :=
private unsafe def eqCases (c₁ c₂ : Cases pu) : Bool :=
c₁.resultType == c₂.resultType && c₁.discr == c₂.discr &&
c₁.typeName == c₂.typeName && c₁.alts.isEqv c₂.alts eqAlt
private unsafe def eqAlt (a₁ a₂ : Alt) : Bool :=
private unsafe def eqAlt (a₁ a₂ : Alt pu) : Bool :=
match a₁, a₂ with
| .default k₁, .default k₂ => eqImp k₁ k₂
| .alt c₁ ps₁ k₁, .alt c₂ ps₂ k₂ => c₁ == c₂ && ps₁ == ps₂ && eqImp k₁ k₂
| .alt c₁ ps₁ k₁ _, .alt c₂ ps₂ k₂ _ => c₁ == c₂ && ps₁ == ps₂ && eqImp k₁ k₂
| _, _ => false
end
@[implemented_by eqImp] protected opaque Code.beq : Code Code Bool
@[implemented_by eqImp] protected opaque Code.beq : Code pu Code pu Bool
instance : BEq Code where
instance : BEq (Code pu) where
beq := Code.beq
@[implemented_by eqFunDecl] protected opaque FunDecl.beq : FunDecl FunDecl Bool
@[implemented_by eqFunDecl] protected opaque FunDecl.beq : FunDecl pu FunDecl pu Bool
instance : BEq FunDecl where
instance : BEq (FunDecl pu) where
beq := FunDecl.beq
def Alt.getCode : Alt Code
def Alt.getCode : Alt pu Code pu
| .default k => k
| .alt _ _ k => k
| .alt _ _ k _ => k
def Alt.getParams : Alt Array Param
def Alt.getParams : Alt pu Array (Param pu)
| .default _ => #[]
| .alt _ ps _ => ps
| .alt _ ps _ _ => ps
def Alt.forCodeM [Monad m] (alt : Alt) (f : Code m Unit) : m Unit := do
def Alt.forCodeM [Monad m] (alt : Alt pu) (f : Code pu m Unit) : m Unit := do
match alt with
| .default k => f k
| .alt _ _ k => f k
| .alt _ _ k _ => f k
private unsafe def updateAltCodeImp (alt : Alt) (k' : Code) : Alt :=
private unsafe def updateAltCodeImp (alt : Alt pu) (k' : Code pu) : Alt pu :=
match alt with
| .default k => if ptrEq k k' then alt else .default k'
| .alt ctorName ps k => if ptrEq k k' then alt else .alt ctorName ps k'
| .alt ctorName ps k _ => if ptrEq k k' then alt else .alt ctorName ps k'
@[implemented_by updateAltCodeImp] opaque Alt.updateCode (alt : Alt) (c : Code) : Alt
@[implemented_by updateAltCodeImp] opaque Alt.updateCode (alt : Alt pu) (c : Code pu) : Alt pu
private unsafe def updateAltImp (alt : Alt) (ps' : Array Param) (k' : Code) : Alt :=
private unsafe def updateAltImp (alt : Alt pu) (ps' : Array (Param pu)) (k' : Code pu) : Alt pu :=
match alt with
| .alt ctorName ps k => if ptrEq k k' && ptrEq ps ps' then alt else .alt ctorName ps' k'
| .alt ctorName ps k _ => if ptrEq k k' && ptrEq ps ps' then alt else .alt ctorName ps' k'
| _ => unreachable!
@[implemented_by updateAltImp] opaque Alt.updateAlt! (alt : Alt) (ps' : Array Param) (k' : Code) : Alt
@[implemented_by updateAltImp] opaque Alt.updateAlt! (alt : Alt pu) (ps' : Array (Param pu)) (k' : Code pu) : Alt pu
@[inline] private unsafe def updateAltsImp (c : Code) (alts : Array Alt) : Code :=
@[inline] private unsafe def updateAltsImp (c : Code pu) (alts : Array (Alt pu)) : Code pu :=
match c with
| .cases cs => if ptrEq cs.alts alts then c else .cases <| cs.updateAlts alts
| _ => unreachable!
@[implemented_by updateAltsImp] opaque Code.updateAlts! (c : Code) (alts : Array Alt) : Code
@[implemented_by updateAltsImp] opaque Code.updateAlts! (c : Code pu) (alts : Array (Alt pu)) : Code pu
@[inline] private unsafe def updateCasesImp (c : Code) (resultType : Expr) (discr : FVarId) (alts : Array Alt) : Code :=
@[inline] private unsafe def updateCasesImp (c : Code pu) (resultType : Expr) (discr : FVarId) (alts : Array (Alt pu)) : Code pu :=
match c with
| .cases cs =>
if ptrEq cs.alts alts && ptrEq cs.resultType resultType && cs.discr == discr then
@@ -339,54 +375,54 @@ private unsafe def updateAltImp (alt : Alt) (ps' : Array Param) (k' : Code) : Al
.cases <| cs.typeName, resultType, discr, alts
| _ => unreachable!
@[implemented_by updateCasesImp] opaque Code.updateCases! (c : Code) (resultType : Expr) (discr : FVarId) (alts : Array Alt) : Code
@[implemented_by updateCasesImp] opaque Code.updateCases! (c : Code pu) (resultType : Expr) (discr : FVarId) (alts : Array (Alt pu)) : Code pu
@[inline] private unsafe def updateLetImp (c : Code) (decl' : LetDecl) (k' : Code) : Code :=
@[inline] private unsafe def updateLetImp (c : Code pu) (decl' : LetDecl pu) (k' : Code pu) : Code pu :=
match c with
| .let decl k => if ptrEq k k' && ptrEq decl decl' then c else .let decl' k'
| _ => unreachable!
@[implemented_by updateLetImp] opaque Code.updateLet! (c : Code) (decl' : LetDecl) (k' : Code) : Code
@[implemented_by updateLetImp] opaque Code.updateLet! (c : Code pu) (decl' : LetDecl pu) (k' : Code pu) : Code pu
@[inline] private unsafe def updateContImp (c : Code) (k' : Code) : Code :=
@[inline] private unsafe def updateContImp (c : Code pu) (k' : Code pu) : Code pu :=
match c with
| .let decl k => if ptrEq k k' then c else .let decl k'
| .fun decl k => if ptrEq k k' then c else .fun decl k'
| .fun decl k _ => if ptrEq k k' then c else .fun decl k'
| .jp decl k => if ptrEq k k' then c else .jp decl k'
| _ => unreachable!
@[implemented_by updateContImp] opaque Code.updateCont! (c : Code) (k' : Code) : Code
@[implemented_by updateContImp] opaque Code.updateCont! (c : Code pu) (k' : Code pu) : Code pu
@[inline] private unsafe def updateFunImp (c : Code) (decl' : FunDecl) (k' : Code) : Code :=
@[inline] private unsafe def updateFunImp (c : Code pu) (decl' : FunDecl pu) (k' : Code pu) : Code pu :=
match c with
| .fun decl k => if ptrEq k k' && ptrEq decl decl' then c else .fun decl' k'
| .fun decl k _ => if ptrEq k k' && ptrEq decl decl' then c else .fun decl' k'
| .jp decl k => if ptrEq k k' && ptrEq decl decl' then c else .jp decl' k'
| _ => unreachable!
@[implemented_by updateFunImp] opaque Code.updateFun! (c : Code) (decl' : FunDecl) (k' : Code) : Code
@[implemented_by updateFunImp] opaque Code.updateFun! (c : Code pu) (decl' : FunDecl pu) (k' : Code pu) : Code pu
@[inline] private unsafe def updateReturnImp (c : Code) (fvarId' : FVarId) : Code :=
@[inline] private unsafe def updateReturnImp (c : Code pu) (fvarId' : FVarId) : Code pu :=
match c with
| .return fvarId => if fvarId == fvarId' then c else .return fvarId'
| _ => unreachable!
@[implemented_by updateReturnImp] opaque Code.updateReturn! (c : Code) (fvarId' : FVarId) : Code
@[implemented_by updateReturnImp] opaque Code.updateReturn! (c : Code pu) (fvarId' : FVarId) : Code pu
@[inline] private unsafe def updateJmpImp (c : Code) (fvarId' : FVarId) (args' : Array Arg) : Code :=
@[inline] private unsafe def updateJmpImp (c : Code pu) (fvarId' : FVarId) (args' : Array (Arg pu)) : Code pu :=
match c with
| .jmp fvarId args => if fvarId == fvarId' && ptrEq args args' then c else .jmp fvarId' args'
| _ => unreachable!
@[implemented_by updateJmpImp] opaque Code.updateJmp! (c : Code) (fvarId' : FVarId) (args' : Array Arg) : Code
@[implemented_by updateJmpImp] opaque Code.updateJmp! (c : Code pu) (fvarId' : FVarId) (args' : Array (Arg pu)) : Code pu
@[inline] private unsafe def updateUnreachImp (c : Code) (type' : Expr) : Code :=
@[inline] private unsafe def updateUnreachImp (c : Code pu) (type' : Expr) : Code pu :=
match c with
| .unreach type => if ptrEq type type' then c else .unreach type'
| _ => unreachable!
@[implemented_by updateUnreachImp] opaque Code.updateUnreach! (c : Code) (type' : Expr) : Code
@[implemented_by updateUnreachImp] opaque Code.updateUnreach! (c : Code pu) (type' : Expr) : Code pu
private unsafe def updateParamCoreImp (p : Param) (type : Expr) : Param :=
private unsafe def updateParamCoreImp (p : Param pu) (type : Expr) : Param pu :=
if ptrEq type p.type then
p
else
@@ -397,9 +433,9 @@ Low-level update `Param` function. It does not update the local context.
Consider using `Param.update : Param → Expr → CompilerM Param` if you want the local context
to be updated.
-/
@[implemented_by updateParamCoreImp] opaque Param.updateCore (p : Param) (type : Expr) : Param
@[implemented_by updateParamCoreImp] opaque Param.updateCore (p : Param pu) (type : Expr) : Param pu
private unsafe def updateLetDeclCoreImp (decl : LetDecl) (type : Expr) (value : LetValue) : LetDecl :=
private unsafe def updateLetDeclCoreImp (decl : LetDecl pu) (type : Expr) (value : LetValue pu) : LetDecl pu :=
if ptrEq type decl.type && ptrEq value decl.value then
decl
else
@@ -410,9 +446,9 @@ Low-level update `LetDecl` function. It does not update the local context.
Consider using `LetDecl.update : LetDecl → Expr → Expr → CompilerM LetDecl` if you want the local context
to be updated.
-/
@[implemented_by updateLetDeclCoreImp] opaque LetDecl.updateCore (decl : LetDecl) (type : Expr) (value : LetValue) : LetDecl
@[implemented_by updateLetDeclCoreImp] opaque LetDecl.updateCore (decl : LetDecl pu) (type : Expr) (value : LetValue pu) : LetDecl pu
private unsafe def updateFunDeclCoreImp (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : FunDecl :=
private unsafe def updateFunDeclCoreImp (decl: FunDecl pu) (type : Expr) (params : Array (Param pu)) (value : Code pu) : FunDecl pu :=
if ptrEq type decl.type && ptrEq params decl.params && ptrEq value decl.value then
decl
else
@@ -423,9 +459,9 @@ Low-level update `FunDecl` function. It does not update the local context.
Consider using `FunDecl.update : LetDecl → Expr → Array Param → Code → CompilerM FunDecl` if you want the local context
to be updated.
-/
@[implemented_by updateFunDeclCoreImp] opaque FunDecl.updateCore (decl : FunDecl) (type : Expr) (params : Array Param) (value : Code) : FunDecl
@[implemented_by updateFunDeclCoreImp] opaque FunDecl.updateCore (decl : FunDecl pu) (type : Expr) (params : Array (Param pu)) (value : Code pu) : FunDecl pu
def Cases.extractAlt! (cases : Cases) (ctorName : Name) : Alt × Cases :=
def Cases.extractAlt! (cases : Cases pu) (ctorName : Name) : Alt pu × Cases pu :=
let found i := (cases.alts[i]!, cases.updateAlts (cases.alts.eraseIdx! i))
if let some i := cases.alts.findFinIdx? fun | .alt ctorName' .. => ctorName == ctorName' | _ => false then
found i
@@ -434,34 +470,34 @@ def Cases.extractAlt! (cases : Cases) (ctorName : Name) : Alt × Cases :=
else
unreachable!
def Alt.mapCodeM [Monad m] (alt : Alt) (f : Code m Code) : m Alt := do
def Alt.mapCodeM [Monad m] (alt : Alt pu) (f : Code pu m (Code pu)) : m (Alt pu) := do
return alt.updateCode ( f alt.getCode)
def Code.isDecl : Code Bool
def Code.isDecl : Code pu Bool
| .let .. | .fun .. | .jp .. => true
| _ => false
def Code.isFun : Code Bool
def Code.isFun : Code pu Bool
| .fun .. => true
| _ => false
def Code.isReturnOf : Code FVarId Bool
def Code.isReturnOf : Code pu FVarId Bool
| .return fvarId, fvarId' => fvarId == fvarId'
| _, _ => false
partial def Code.size (c : Code) : Nat :=
partial def Code.size (c : Code pu) : Nat :=
go c 0
where
go (c : Code) (n : Nat) : Nat :=
go (c : Code pu) (n : Nat) : Nat :=
match c with
| .let _ k => go k (n+1)
| .jp decl k | .fun decl k => go k <| go decl.value n
| .jp decl k | .fun decl k _ => go k <| go decl.value n
| .cases c => c.alts.foldl (init := n+1) fun n alt => go alt.getCode (n+1)
| .jmp .. => n+1
| .return .. | unreach .. => n -- `return` & `unreach` have weight zero
/-- Return true iff `c.size ≤ n` -/
partial def Code.sizeLe (c : Code) (n : Nat) : Bool :=
partial def Code.sizeLe (c : Code pu) (n : Nat) : Bool :=
match go c |>.run 0 with
| .ok .. => true
| .error .. => false
@@ -470,26 +506,26 @@ where
modify (·+1)
unless ( get) <= n do throw ()
go (c : Code) : EStateM Unit Nat Unit := do
go (c : Code pu) : EStateM Unit Nat Unit := do
match c with
| .let _ k => inc; go k
| .jp decl k | .fun decl k => inc; go decl.value; go k
| .jp decl k | .fun decl k _ => inc; go decl.value; go k
| .cases c => inc; c.alts.forM fun alt => go alt.getCode
| .jmp .. => inc
| .return .. | unreach .. => return ()
partial def Code.forM [Monad m] (c : Code) (f : Code m Unit) : m Unit :=
partial def Code.forM [Monad m] (c : Code pu) (f : Code pu m Unit) : m Unit :=
go c
where
go (c : Code) : m Unit := do
go (c : Code pu) : m Unit := do
f c
match c with
| .let _ k => go k
| .fun decl k | .jp decl k => go decl.value; go k
| .fun decl k _ | .jp decl k => go decl.value; go k
| .cases c => c.alts.forM fun alt => go alt.getCode
| .unreach .. | .return .. | .jmp .. => return ()
partial def Code.instantiateValueLevelParams (code : Code) (levelParams : List Name) (us : List Level) : Code :=
partial def Code.instantiateValueLevelParams (code : Code .pure) (levelParams : List Name) (us : List Level) : Code .pure :=
instCode code
where
instLevel (u : Level) :=
@@ -498,67 +534,67 @@ where
instExpr (e : Expr) :=
e.instantiateLevelParamsNoCache levelParams us
instParams (ps : Array Param) :=
instParams (ps : Array (Param .pure)) :=
ps.mapMono fun p => p.updateCore (instExpr p.type)
instAlt (alt : Alt) :=
instAlt (alt : Alt .pure) :=
match alt with
| .default k => alt.updateCode (instCode k)
| .alt _ ps k => alt.updateAlt! (instParams ps) (instCode k)
| .alt _ ps k _ => alt.updateAlt! (instParams ps) (instCode k)
instArg (arg : Arg) : Arg :=
instArg (arg : Arg .pure) : Arg .pure :=
match arg with
| .type e => arg.updateType! (instExpr e)
| .type e _ => arg.updateType! (instExpr e)
| .fvar .. | .erased => arg
instLetValue (e : LetValue) : LetValue :=
instLetValue (e : LetValue .pure) : LetValue .pure :=
match e with
| .const declName vs args => e.updateConst! declName (vs.mapMono instLevel) (args.mapMono instArg)
| .const declName vs args _ => e.updateConst! declName (vs.mapMono instLevel) (args.mapMono instArg)
| .fvar fvarId args => e.updateFVar! fvarId (args.mapMono instArg)
| .proj .. | .lit .. | .erased => e
instLetDecl (decl : LetDecl) :=
instLetDecl (decl : LetDecl .pure) :=
decl.updateCore (instExpr decl.type) (instLetValue decl.value)
instFunDecl (decl : FunDecl) :=
instFunDecl (decl : FunDecl .pure) :=
decl.updateCore (instExpr decl.type) (instParams decl.params) (instCode decl.value)
instCode (code : Code) :=
instCode (code : Code .pure) :=
match code with
| .let decl k => code.updateLet! (instLetDecl decl) (instCode k)
| .jp decl k | .fun decl k => code.updateFun! (instFunDecl decl) (instCode k)
| .jp decl k | .fun decl k _ => code.updateFun! (instFunDecl decl) (instCode k)
| .cases c => code.updateCases! (instExpr c.resultType) c.discr (c.alts.mapMono instAlt)
| .jmp fvarId args => code.updateJmp! fvarId (args.mapMono instArg)
| .return .. => code
| .unreach type => code.updateUnreach! (instExpr type)
inductive DeclValue where
| code (code : Code)
inductive DeclValue (pu : Purity) where
| code (code : Code pu)
| extern (externAttrData : ExternAttrData)
deriving Inhabited, BEq
partial def DeclValue.size : DeclValue Nat
partial def DeclValue.size : DeclValue pu Nat
| .code c => c.size
| .extern .. => 0
def DeclValue.mapCode (f : Code Code) : DeclValue DeclValue :=
def DeclValue.mapCode (f : Code pu Code pu) : DeclValue pu DeclValue pu :=
fun
| .code c => .code (f c)
| .extern e => .extern e
def DeclValue.mapCodeM [Monad m] (f : Code m Code) : DeclValue m DeclValue :=
def DeclValue.mapCodeM [Monad m] (f : Code pu m (Code pu)) : DeclValue pu m (DeclValue pu) :=
fun v => do
match v with
| .code c => return .code ( f c)
| .extern .. => return v
def DeclValue.forCodeM [Monad m] (f : Code m Unit) : DeclValue m Unit :=
def DeclValue.forCodeM [Monad m] (f : Code pu m Unit) : DeclValue pu m Unit :=
fun v => do
match v with
| .code c => f c
| .extern .. => return ()
def DeclValue.isCodeAndM [Monad m] (v : DeclValue) (f : Code m Bool) : m Bool :=
def DeclValue.isCodeAndM [Monad m] (v : DeclValue pu) (f : Code pu m Bool) : m Bool :=
match v with
| .code c => f c
| .extern .. => pure false
@@ -566,7 +602,7 @@ def DeclValue.isCodeAndM [Monad m] (v : DeclValue) (f : Code → m Bool) : m Boo
/--
Declaration being processed by the Lean to Lean compiler passes.
-/
structure Decl where
structure Decl (pu : Purity) where
/--
The name of the declaration from the `Environment` it came from
-/
@@ -584,12 +620,12 @@ structure Decl where
/--
Parameters.
-/
params : Array Param
params : Array (Param pu)
/--
The body of the declaration, usually changes as it progresses
through compiler passes.
-/
value : DeclValue
value : DeclValue pu
/--
We set this flag to true during LCNF conversion. When we receive
a block of functions to be compiled, we set this flag to `true`
@@ -631,31 +667,37 @@ structure Decl where
inlineAttr? : Option InlineAttributeKind
deriving Inhabited, BEq
def Decl.size (decl : Decl) : Nat :=
def Decl.size (decl : Decl pu) : Nat :=
decl.value.size
def Decl.getArity (decl : Decl) : Nat :=
def Decl.getArity (decl : Decl pu) : Nat :=
decl.params.size
def Decl.inlineAttr (decl : Decl) : Bool :=
def Decl.inlineAttr (decl : Decl pu) : Bool :=
decl.inlineAttr? matches some .inline
def Decl.noinlineAttr (decl : Decl) : Bool :=
def Decl.noinlineAttr (decl : Decl pu) : Bool :=
decl.inlineAttr? matches some .noinline
def Decl.inlineIfReduceAttr (decl : Decl) : Bool :=
def Decl.inlineIfReduceAttr (decl : Decl pu) : Bool :=
decl.inlineAttr? matches some .inlineIfReduce
def Decl.alwaysInlineAttr (decl : Decl) : Bool :=
def Decl.alwaysInlineAttr (decl : Decl pu) : Bool :=
decl.inlineAttr? matches some .alwaysInline
/-- Return `true` if the given declaration has been annotated with `[inline]`, `[inline_if_reduce]`, `[macro_inline]`, or `[always_inline]` -/
def Decl.inlineable (decl : Decl) : Bool :=
def Decl.inlineable (decl : Decl pu) : Bool :=
match decl.inlineAttr? with
| some .noinline => false
| some _ => true
| none => false
def Decl.castPurity! (decl : Decl pu1) (pu2 : Purity) : Decl pu2 :=
if h : pu1 = pu2 then
h decl
else
panic! s!"Purity {pu1} does not match {pu2}, this is a bug"
/--
Return `some i` if `decl` is of the form
```
@@ -669,21 +711,21 @@ That is, `f` is a sequence of declarations followed by a `cases` on the paramete
We use this function to decide whether we should inline a declaration tagged with
`[inline_if_reduce]` or not.
-/
def Decl.isCasesOnParam? (decl : Decl) : Option Nat :=
def Decl.isCasesOnParam? (decl : Decl pu) : Option Nat :=
match decl.value with
| .code c => go c
| .extern .. => none
where
go (code : Code) : Option Nat :=
go {pu : Purity} (code : Code pu) : Option Nat :=
match code with
| .let _ k | .jp _ k | .fun _ k => go k
| .let _ k | .jp _ k | .fun _ k _ => go k
| .cases c => decl.params.findIdx? fun param => param.fvarId == c.discr
| _ => none
def Decl.instantiateTypeLevelParams (decl : Decl) (us : List Level) : Expr :=
def Decl.instantiateTypeLevelParams (decl : Decl pu) (us : List Level) : Expr :=
decl.type.instantiateLevelParamsNoCache decl.levelParams us
def Decl.instantiateParamsLevelParams (decl : Decl) (us : List Level) : Array Param :=
def Decl.instantiateParamsLevelParams (decl : Decl pu) (us : List Level) : Array (Param pu) :=
decl.params.mapMono fun param => param.updateCore (param.type.instantiateLevelParamsNoCache decl.levelParams us)
/--
@@ -700,11 +742,11 @@ def hasLocalInst (type : Expr) : CoreM Bool := do
/--
Return `true` if `decl` is supposed to be inlined/specialized.
-/
def Decl.isTemplateLike (decl : Decl) : CoreM Bool := do
def Decl.isTemplateLike (decl : Decl pu) : CoreM Bool := do
let env getEnv
if hasLocalInst decl.type then
return true -- `decl` applications will be specialized
else if Meta.isInstanceCore env decl.name then
else if ( isInstanceReducible decl.name) then
return true -- `decl` is "fuel" for code specialization
else if decl.inlineable || hasSpecializeAttribute env decl.name then
return true -- `decl` is going to be inlined or specialized
@@ -721,40 +763,40 @@ private partial def collectType (e : Expr) : FVarIdHashSet → FVarIdHashSet :=
| .proj .. | .letE .. => unreachable!
| _ => id
private def collectArg (arg : Arg) (s : FVarIdHashSet) : FVarIdHashSet :=
private def collectArg (arg : Arg pu) (s : FVarIdHashSet) : FVarIdHashSet :=
match arg with
| .erased => s
| .fvar fvarId => s.insert fvarId
| .type e => collectType e s
| .type e _ => collectType e s
private def collectArgs (args : Array Arg) (s : FVarIdHashSet) : FVarIdHashSet :=
private def collectArgs (args : Array (Arg pu)) (s : FVarIdHashSet) : FVarIdHashSet :=
args.foldl (init := s) fun s arg => collectArg arg s
private def collectLetValue (e : LetValue) (s : FVarIdHashSet) : FVarIdHashSet :=
private def collectLetValue (e : LetValue pu) (s : FVarIdHashSet) : FVarIdHashSet :=
match e with
| .fvar fvarId args => collectArgs args <| s.insert fvarId
| .const _ _ args => collectArgs args s
| .proj _ _ fvarId => s.insert fvarId
| .const _ _ args _ => collectArgs args s
| .proj _ _ fvarId _ => s.insert fvarId
| .lit .. | .erased => s
private partial def collectParams (ps : Array Param) (s : FVarIdHashSet) : FVarIdHashSet :=
private partial def collectParams (ps : Array (Param pu)) (s : FVarIdHashSet) : FVarIdHashSet :=
ps.foldl (init := s) fun s p => collectType p.type s
mutual
partial def FunDecl.collectUsed (decl : FunDecl) (s : FVarIdHashSet := {}) : FVarIdHashSet :=
partial def FunDecl.collectUsed (decl : FunDecl pu) (s : FVarIdHashSet := {}) : FVarIdHashSet :=
decl.value.collectUsed <| collectParams decl.params <| collectType decl.type s
partial def Code.collectUsed (code : Code) (s : FVarIdHashSet := {}) : FVarIdHashSet :=
partial def Code.collectUsed (code : Code pu) (s : FVarIdHashSet := {}) : FVarIdHashSet :=
match code with
| .let decl k => k.collectUsed <| collectLetValue decl.value <| collectType decl.type s
| .jp decl k | .fun decl k => k.collectUsed <| decl.collectUsed s
| .jp decl k | .fun decl k _ => k.collectUsed <| decl.collectUsed s
| .cases c =>
let s := s.insert c.discr
let s := collectType c.resultType s
c.alts.foldl (init := s) fun s alt =>
match alt with
| .default k => k.collectUsed s
| .alt _ ps k => k.collectUsed <| collectParams ps s
| .alt _ ps k _ => k.collectUsed <| collectParams ps s
| .return fvarId => s.insert fvarId
| .unreach type => collectType type s
| .jmp fvarId args => collectArgs args <| s.insert fvarId
@@ -771,7 +813,7 @@ This is an overapproximation, and relies on the fact that our frontend
computes strongly connected components.
See comment at `recursive` field.
-/
partial def markRecDecls (decls : Array Decl) : Array Decl :=
partial def markRecDecls (decls : Array (Decl pu)) : Array (Decl pu) :=
let (_, isRec) := go |>.run {}
decls.map fun decl =>
if isRec.contains decl.name then
@@ -779,13 +821,13 @@ partial def markRecDecls (decls : Array Decl) : Array Decl :=
else
decl
where
visit (code : Code) : StateM NameSet Unit := do
visit {pu : Purity} (code : Code pu) : StateM NameSet Unit := do
match code with
| .jp decl k | .fun decl k => visit decl.value; visit k
| .jp decl k | .fun decl k _ => visit decl.value; visit k
| .cases c => c.alts.forM fun alt => visit alt.getCode
| .unreach .. | .jmp .. | .return .. => return ()
| .let decl k =>
if let .const declName _ _ := decl.value then
if let .const declName _ _ _ := decl.value then
if decls.any (·.name == declName) then
modify fun s => s.insert declName
visit k
@@ -793,13 +835,13 @@ where
go : StateM NameSet Unit :=
decls.forM (·.value.forCodeM visit)
def instantiateRangeArgs (e : Expr) (beginIdx endIdx : Nat) (args : Array Arg) : Expr :=
def instantiateRangeArgs (e : Expr) (beginIdx endIdx : Nat) (args : Array (Arg pu)) : Expr :=
if !e.hasLooseBVars then
e
else
e.instantiateRange beginIdx endIdx (args.map (·.toExpr))
def instantiateRevRangeArgs (e : Expr) (beginIdx endIdx : Nat) (args : Array Arg) : Expr :=
def instantiateRevRangeArgs (e : Expr) (beginIdx endIdx : Nat) (args : Array (Arg pu)) : Expr :=
if !e.hasLooseBVars then
e
else

View File

@@ -14,7 +14,7 @@ namespace Lean.Compiler.LCNF
/-- Helper class for lifting `CompilerM.codeBind` -/
class MonadCodeBind (m : Type Type) where
codeBind : (c : Code) (f : FVarId m Code) m Code
codeBind : {pu : Purity} (c : Code pu) (f : FVarId m (Code pu)) m (Code pu)
/--
Return code that is equivalent to `c >>= f`. That is, executes `c`, and then `f x`, where
@@ -25,16 +25,17 @@ an invalid block would be generated. It would be invalid because `f` would not
be applied to `jp_i`. Note that, we could have decided to create a copy of `jp_i` where we apply `f` to it,
by we decided to not do it to avoid code duplication.
-/
abbrev Code.bind [MonadCodeBind m] (c : Code) (f : FVarId m Code) : m Code :=
abbrev Code.bind [MonadCodeBind m] (c : Code pu) (f : FVarId m (Code pu)) : m (Code pu) :=
MonadCodeBind.codeBind c f
partial def CompilerM.codeBind (c : Code) (f : FVarId CompilerM Code) : CompilerM Code := do
partial def CompilerM.codeBind (c : Code pu) (f : FVarId CompilerM (Code pu)) :
CompilerM (Code pu) := do
go c |>.run {}
where
go (c : Code) : ReaderT FVarIdSet CompilerM Code := do
go (c : Code pu) : ReaderT FVarIdSet CompilerM (Code pu) := do
match c with
| .let decl k => return .let decl ( go k)
| .fun decl k => return .fun decl ( go k)
| .fun decl k _ => return .fun decl ( go k)
| .jp decl k =>
let value go decl.value
let type value.inferParamType decl.params
@@ -43,7 +44,7 @@ where
return .jp decl ( go k)
| .cases c =>
let alts c.alts.mapM fun
| .alt ctorName params k => return .alt ctorName params ( go k)
| .alt ctorName params k _ => return .alt ctorName params ( go k)
| .default k => return .default ( go k)
if alts.isEmpty then
throwError "`Code.bind` failed, empty `cases` found"
@@ -60,7 +61,7 @@ where
This code is not very efficient, we could ask caller to provide the type of `c >>= f`,
but this is more convenient, and this case is seldom reached.
-/
let auxParam mkAuxParam type
let auxParam mkAuxParam (pu := pu) type
let k f auxParam.fvarId
let typeNew k.inferType
eraseCode k
@@ -81,10 +82,10 @@ Create new parameters for the given arrow type.
Example: if `type` is `Nat → Bool → Int`, the result is
an array containing two new parameters with types `Nat` and `Bool`.
-/
partial def mkNewParams (type : Expr) : CompilerM (Array Param) :=
partial def mkNewParams (type : Expr) : CompilerM (Array (Param pu)) :=
go type #[] #[]
where
go (type : Expr) (xs : Array Expr) (ps : Array Param) : CompilerM (Array Param) := do
go (type : Expr) (xs : Array Expr) (ps : Array (Param pu)) : CompilerM (Array (Param pu)) := do
match type with
| .forallE _ d b _ =>
let d := d.instantiateRev xs
@@ -98,15 +99,16 @@ where
else
return ps
def isEtaExpandCandidateCore (type : Expr) (params : Array Param) : Bool :=
def isEtaExpandCandidateCore (type : Expr) (params : Array (Param .pure)) : Bool :=
let typeArity := getArrowArity type
let valueArity := params.size
typeArity > valueArity
abbrev FunDecl.isEtaExpandCandidate (decl : FunDecl) : Bool :=
abbrev FunDecl.isEtaExpandCandidate (decl : FunDecl .pure) : Bool :=
isEtaExpandCandidateCore decl.type decl.params
def etaExpandCore (type : Expr) (params : Array Param) (value : Code) : CompilerM (Array Param × Code) := do
def etaExpandCore (type : Expr) (params : Array (Param .pure)) (value : Code .pure) :
CompilerM (Array (Param .pure) × Code .pure) := do
let valueType instantiateForall type (params.map (mkFVar ·.fvarId))
let psNew mkNewParams valueType
let params := params ++ psNew
@@ -116,17 +118,17 @@ def etaExpandCore (type : Expr) (params : Array Param) (value : Code) : Compiler
return .let auxDecl (.return auxDecl.fvarId)
return (params, value)
def etaExpandCore? (type : Expr) (params : Array Param) (value : Code) : CompilerM (Option (Array Param × Code)) := do
def etaExpandCore? (type : Expr) (params : Array (Param .pure)) (value : Code .pure) : CompilerM (Option (Array (Param .pure) × Code .pure)) := do
if isEtaExpandCandidateCore type params then
etaExpandCore type params value
else
return none
def FunDecl.etaExpand (decl : FunDecl) : CompilerM FunDecl := do
def FunDecl.etaExpand (decl : FunDecl .pure) : CompilerM (FunDecl .pure) := do
let some (params, value) etaExpandCore? decl.type decl.params decl.value | return decl
decl.update decl.type params value
def Decl.etaExpand (decl : Decl) : CompilerM Decl := do
def Decl.etaExpand (decl : Decl .pure) : CompilerM (Decl .pure) := do
match decl.value with
| .code code =>
let some (params, newCode) etaExpandCore? decl.type decl.params code | return decl

View File

@@ -20,17 +20,17 @@ namespace CSE
structure State where
map : PHashMap Expr FVarId := {}
subst : FVarSubst := {}
subst : FVarSubst .pure := {}
abbrev M := StateRefT State CompilerM
instance : MonadFVarSubst M false where
instance : MonadFVarSubst M .pure false where
getSubst := return ( get).subst
instance : MonadFVarSubstState M where
instance : MonadFVarSubstState M .pure where
modifySubst f := modify fun s => { s with subst := f s.subst }
@[inline] def getSubst : M FVarSubst :=
@[inline] def getSubst : M (FVarSubst .pure) :=
return ( get).subst
@[inline] def addEntry (value : Expr) (fvarId : FVarId) : M Unit :=
@@ -40,31 +40,32 @@ instance : MonadFVarSubstState M where
let map := ( get).map
try x finally modify fun s => { s with map }
def replaceLet (decl : LetDecl) (fvarId : FVarId) : M Unit := do
def replaceLet (decl : LetDecl .pure) (fvarId : FVarId) : M Unit := do
eraseLetDecl decl
addFVarSubst decl.fvarId fvarId
def replaceFun (decl : FunDecl) (fvarId : FVarId) : M Unit := do
def replaceFun (decl : FunDecl .pure) (fvarId : FVarId) : M Unit := do
eraseFunDecl decl
addFVarSubst decl.fvarId fvarId
def hasNeverExtract (v : LetValue) : CompilerM Bool :=
def hasNeverExtract (v : LetValue .pure) : CompilerM Bool :=
match v with
| .const declName .. =>
return hasNeverExtractAttribute ( getEnv) declName
| .lit _ | .erased | .proj .. | .fvar .. =>
return false
partial def _root_.Lean.Compiler.LCNF.Code.cse (shouldElimFunDecls : Bool) (code : Code) : CompilerM Code :=
partial def _root_.Lean.Compiler.LCNF.Code.cse (shouldElimFunDecls : Bool) (code : Code .pure) :
CompilerM (Code .pure) :=
go code |>.run' {}
where
goFunDecl (decl : FunDecl) : M FunDecl := do
goFunDecl (decl : FunDecl .pure) : M (FunDecl .pure) := do
let type normExpr decl.type
let params normParams decl.params
let value withNewScope do go decl.value
decl.update type params value
go (code : Code) : M Code := do
go (code : Code .pure) : M (Code .pure) := do
match code with
| .let decl k =>
let decl normLetDecl decl
@@ -118,12 +119,13 @@ end CSE
/--
Common sub-expression elimination
-/
def Decl.cse (shouldElimFunDecls : Bool) (decl : Decl) : CompilerM Decl := do
def Decl.cse (shouldElimFunDecls : Bool) (decl : Decl .pure) : CompilerM (Decl .pure) := do
let value decl.value.mapCodeM (·.cse shouldElimFunDecls)
return { decl with value }
def cse (phase : Phase := .base) (shouldElimFunDecls := false) (occurrence := 0) : Pass :=
.mkPerDeclaration `cse (Decl.cse shouldElimFunDecls) phase occurrence
phase.withPurityCheck .pure fun h =>
.mkPerDeclaration `cse phase (h Decl.cse shouldElimFunDecls) occurrence
builtin_initialize
registerTraceClass `Compiler.cse (inherited := true)

View File

@@ -79,7 +79,8 @@ the subtype relation in sanity checks and add the necessary casts.
-/
namespace Check
open InferType
namespace Pure
open InferType InferType.Pure
/-
Type and structural properties checker for LCNF expressions.
@@ -110,7 +111,7 @@ def isCtorParam (f : Expr) (i : Nat) : CoreM Bool := do
let .ctorInfo info getConstInfo declName | return false
return i < info.numParams
def checkAppArgs (f : Expr) (args : Array Arg) : CheckM Unit := do
def checkAppArgs (f : Expr) (args : Array (Arg .pure)) : CheckM Unit := do
let mut fType inferType f
let mut j := 0
for h : i in *...args.size do
@@ -129,11 +130,11 @@ def checkAppArgs (f : Expr) (args : Array Arg) : CheckM Unit := do
let expectedType := instantiateRevRangeArgs d j i args
if ( checkTypes) then
let argType arg.inferType
unless ( InferType.compatibleTypes argType expectedType) do
unless ( compatibleTypes argType expectedType) do
throwError "type mismatch at LCNF application{indentExpr (mkAppN f (args.map Arg.toExpr))}\nargument {arg.toExpr} has type{indentExpr argType}\nbut is expected to have type{indentExpr expectedType}"
fType := b
def checkLetValue (e : LetValue) : CheckM Unit := do
def checkLetValue (e : LetValue .pure) : CheckM Unit := do
match e with
| .lit .. | .erased => pure ()
| .const declName us args => checkAppArgs (mkConst declName us) args
@@ -154,18 +155,18 @@ def checkJpInScope (jp : FVarId) : CheckM Unit := do
-/
throwError "invalid jump to out of scope join point `{mkFVar jp}`"
def checkParam (param : Param) : CheckM Unit := do
def checkParam (param : Param .pure) : CheckM Unit := do
unless param == ( getParam param.fvarId) do
throwError "LCNF parameter mismatch at `{param.binderName}`, does not value in local context"
def checkParams (params : Array Param) : CheckM Unit :=
def checkParams (params : Array (Param .pure)) : CheckM Unit :=
params.forM checkParam
def checkLetDecl (letDecl : LetDecl) : CheckM Unit := do
def checkLetDecl (letDecl : LetDecl .pure) : CheckM Unit := do
checkLetValue letDecl.value
if ( checkTypes) then
let valueType letDecl.value.inferType
unless ( InferType.compatibleTypes letDecl.type valueType) do
unless ( compatibleTypes letDecl.type valueType) do
throwError "type mismatch at `{letDecl.binderName}`, value has type{indentExpr valueType}\nbut is expected to have type{indentExpr letDecl.type}"
unless letDecl == ( getLetDecl letDecl.fvarId) do
throwError "LCNF let declaration mismatch at `{letDecl.binderName}`, does not match value in local context"
@@ -183,7 +184,7 @@ def addFVarId (fvarId : FVarId) : CheckM Unit := do
addFVarId fvarId
withReader (fun ctx => { ctx with jps := ctx.jps.insert fvarId }) x
@[inline] def withParams (params : Array Param) (x : CheckM α) : CheckM α := do
@[inline] def withParams (params : Array (Param .pure)) (x : CheckM α) : CheckM α := do
params.forM (addFVarId ·.fvarId)
withReader (fun ctx => { ctx with vars := params.foldl (init := ctx.vars) fun vars p => vars.insert p.fvarId })
x
@@ -192,18 +193,18 @@ mutual
set_option linter.all false
partial def checkFunDeclCore (declName : Name) (params : Array Param) (type : Expr) (value : Code) : CheckM Unit := do
partial def checkFunDeclCore (declName : Name) (params : Array (Param .pure)) (type : Expr) (value : Code .pure) : CheckM Unit := do
checkParams params
withParams params do
discard <| check value
if ( checkTypes) then
let valueType mkForallParams params ( value.inferType)
unless ( InferType.compatibleTypes type valueType) do
unless ( compatibleTypes type valueType) do
throwError "type mismatch at `{.ofConstName declName}`, value has type{indentExpr valueType}\nbut is expected to have type{indentExpr type}"
partial def checkFunDecl (funDecl : FunDecl) : CheckM Unit := do
partial def checkFunDecl (funDecl : FunDecl .pure) : CheckM Unit := do
checkFunDeclCore funDecl.binderName funDecl.params funDecl.type funDecl.value
let decl getFunDecl funDecl.fvarId
let decl getFunDecl (pu := .pure) funDecl.fvarId
unless decl.binderName == funDecl.binderName do
throwError "LCNF local function declaration mismatch at `{funDecl.binderName}`, binder name in local context `{decl.binderName}`"
unless decl.type == funDecl.type do
@@ -211,7 +212,7 @@ partial def checkFunDecl (funDecl : FunDecl) : CheckM Unit := do
unless ( getFunDecl funDecl.fvarId) == funDecl do
throwError "LCNF local function declaration mismatch at `{funDecl.binderName}`, declaration in local context does match"
partial def checkCases (c : Cases) : CheckM Unit := do
partial def checkCases (c : Cases .pure) : CheckM Unit := do
let mut ctorNames : NameSet := {}
let mut hasDefault := false
checkFVar c.discr
@@ -230,7 +231,7 @@ partial def checkCases (c : Cases) : CheckM Unit := do
throwError "invalid LCNF `cases`, `{ctorName}` has # {val.numFields} fields, but alternative has # {params.size} alternatives"
withParams params do check k
partial def check (code : Code) : CheckM Unit := do
partial def check (code : Code .pure) : CheckM Unit := do
match code with
| .let decl k => checkLetDecl decl; withFVarId decl.fvarId do check k
| .fun decl k =>
@@ -241,7 +242,7 @@ partial def check (code : Code) : CheckM Unit := do
| .cases c => checkCases c
| .jmp fvarId args =>
checkJpInScope fvarId
let decl getFunDecl fvarId
let decl getFunDecl (pu := .pure) fvarId
unless decl.getArity == args.size do
throwError "invalid LCNF `goto`, join point {decl.binderName} has #{decl.getArity} parameters, but #{args.size} were provided"
checkAppArgs (.fvar fvarId) args
@@ -253,9 +254,12 @@ end
def run (x : CheckM α) : CompilerM α :=
x |>.run {} |>.run' {} |>.run {}
end Pure
end Check
def Decl.check (decl : Decl) : CompilerM Unit := do
Check.run do decl.value.forCodeM (Check.checkFunDeclCore decl.name decl.params decl.type)
def Decl.check (decl : Decl pu) : CompilerM Unit := do
match pu with
| .pure => Check.Pure.run do decl.value.forCodeM (Check.Pure.checkFunDeclCore decl.name decl.params decl.type)
| .impure => panic! "Check for impure unimplemented" -- TODO
end Lean.Compiler.LCNF

View File

@@ -33,10 +33,6 @@ structure Context where
Remark: the lambda lifting pass abstracts all `let`/`fun`-declarations.
-/
abstract : FVarId Bool
/--
Indicates whether we are processing terms beneath a binder.
-/
isUnderBinder : Bool
/--
State for the `ClosureM` monad.
@@ -49,7 +45,7 @@ structure State where
/--
Free variables that must become new parameters of the code being specialized.
-/
params : Array Param := #[]
params : Array (Param .pure) := #[]
/--
Let-declarations and local function declarations that are going to be "copied" to the code
being processed. For example, when this module is used in the code specializer, the let-declarations
@@ -60,7 +56,7 @@ structure State where
All customers of this module try to avoid work duplication. If a let-declaration is a ground value,
it most likely will be computed during compilation time, and work duplication is not an issue.
-/
decls : Array CodeDecl := #[]
decls : Array (CodeDecl .pure) := #[]
/--
Monad for implementing the dependency collector.
@@ -79,16 +75,16 @@ mutual
Collect dependencies in parameters. We need this because parameters may
contain other type parameters.
-/
partial def collectParams (params : Array Param) : ClosureM Unit :=
partial def collectParams (params : Array (Param .pure)) : ClosureM Unit :=
params.forM (collectType ·.type)
partial def collectArg (arg : Arg) : ClosureM Unit :=
partial def collectArg (arg : Arg .pure) : ClosureM Unit :=
match arg with
| .erased => return ()
| .type e => collectType e
| .fvar fvarId => collectFVar fvarId
partial def collectLetValue (e : LetValue) : ClosureM Unit := do
partial def collectLetValue (e : LetValue .pure) : ClosureM Unit := do
match e with
| .erased | .lit .. => return ()
| .proj _ _ fvarId => collectFVar fvarId
@@ -99,12 +95,11 @@ mutual
Collect dependencies in the given code. We need this function to be able
to collect dependencies in a local function declaration.
-/
partial def collectCode (c : Code) : ClosureM Unit := do
partial def collectCode (c : Code .pure) : ClosureM Unit := do
match c with
| .let decl k =>
collectType decl.type
withReader (fun ctx => { ctx with isUnderBinder := ctx.isUnderBinder || decl.type.isForall })
do collectLetValue decl.value
collectLetValue decl.value
collectCode k
| .fun decl k | .jp decl k => collectFunDecl decl; collectCode k
| .cases c =>
@@ -119,11 +114,10 @@ mutual
| .return fvarId => collectFVar fvarId
/-- Collect dependencies of a local function declaration. -/
partial def collectFunDecl (decl : FunDecl) : ClosureM Unit := do
partial def collectFunDecl (decl : FunDecl .pure) : ClosureM Unit := do
collectType decl.type
collectParams decl.params
withReader (fun ctx => { ctx with isUnderBinder := true }) do
collectCode decl.value
collectCode decl.value
/--
Process the given free variable.
@@ -146,7 +140,7 @@ mutual
modify fun s => { s with params := s.params.push param }
else if let some letDecl findLetDecl? fvarId then
collectType letDecl.type
if ctx.isUnderBinder || ctx.abstract letDecl.fvarId then
if ctx.abstract letDecl.fvarId then
modify fun s => { s with params := s.params.push <| { letDecl with borrow := false } }
else
collectLetValue letDecl.value
@@ -161,8 +155,9 @@ mutual
end
def run (x : ClosureM α) (inScope : FVarId Bool) (abstract : FVarId Bool := fun _ => true) : CompilerM (α × Array Param × Array CodeDecl) := do
let (a, s) x { inScope, abstract, isUnderBinder := false } |>.run {}
def run (x : ClosureM α) (inScope : FVarId Bool) (abstract : FVarId Bool := fun _ => true) :
CompilerM (α × Array (Param .pure) × Array (CodeDecl .pure)) := do
let (a, s) x { inScope, abstract } |>.run {}
-- If we've abstracted an fvar into a param, exclude its definition. Note that this still allows
-- for other decls the removed decl depends upon to be included, but they will be removed later
-- for having no users.

View File

@@ -72,10 +72,13 @@ partial def compatibleTypesQuick (a b : Expr) : Bool :=
| .const n us, .const m vs => n == m && List.isEqv us vs Level.isEquiv
| _, _ => false
namespace InferType
namespace Pure
/--
Complete check for `compatibleTypes`. It eta-expands type formers. See comment at `compatibleTypes`.
-/
partial def InferType.compatibleTypesFull (a b : Expr) : InferTypeM Bool := do
partial def compatibleTypesFull (a b : Expr) : InferTypeM Bool := do
if a.isErased || b.isErased then
return true
else
@@ -141,10 +144,13 @@ This is a simplification. We used to use `isErasedCompatible`, but this only add
For item 2, we would have to modify the `toLCNFType` function and make sure a type former is erased if the expected
type is not always a type former (see `S.mk` type and example in the note above).
-/
def InferType.compatibleTypes (a b : Expr) : InferTypeM Bool := do
def compatibleTypes (a b : Expr) : InferTypeM Bool := do
if compatibleTypesQuick a b then
return true
else
compatibleTypesFull a b
end Pure
end InferType
end Lean.Compiler.LCNF

View File

@@ -21,7 +21,12 @@ inductive Phase where
| base
/-- In this phase polymorphism has been eliminated. -/
| mono
deriving Inhabited, BEq
| impure
deriving Inhabited, DecidableEq
@[expose, reducible] def Phase.toPurity : Phase Purity
| .base | .mono => .pure
| .impure => .impure
/--
The state managed by the `CompilerM` `Monad`.
@@ -52,48 +57,53 @@ instance : Monad CompilerM := let i := inferInstanceAs (Monad CompilerM); { pure
def getPhase : CompilerM Phase :=
return ( read).phase
def getPurity : CompilerM Purity :=
return ( getPhase).toPurity
def inBasePhase : CompilerM Bool :=
return ( getPhase) matches .base
instance : AddMessageContext CompilerM where
addMessageContext msgData := do
let env getEnv
let lctx := ( get).lctx.toLocalContext
let lctx := ( get).lctx.toLocalContext ( getPurity)
let opts getOptions
return MessageData.withContext { env, lctx, opts, mctx := {} } msgData
def getType (fvarId : FVarId) : CompilerM Expr := do
let lctx := ( get).lctx
if let some decl := lctx.letDecls[fvarId]? then
let pu getPurity
if let some decl := (lctx.letDecls pu)[fvarId]? then
return decl.type
else if let some decl := lctx.params[fvarId]? then
else if let some decl := (lctx.params pu)[fvarId]? then
return decl.type
else if let some decl := lctx.funDecls[fvarId]? then
else if let some decl := (lctx.funDecls pu)[fvarId]? then
return decl.type
else
throwError "unknown free variable {fvarId.name}"
def getBinderName (fvarId : FVarId) : CompilerM Name := do
let lctx := ( get).lctx
if let some decl := lctx.letDecls[fvarId]? then
let pu getPurity
if let some decl := (lctx.letDecls pu)[fvarId]? then
return decl.binderName
else if let some decl := lctx.params[fvarId]? then
else if let some decl := (lctx.params pu)[fvarId]? then
return decl.binderName
else if let some decl := lctx.funDecls[fvarId]? then
else if let some decl := (lctx.funDecls pu)[fvarId]? then
return decl.binderName
else
throwError "unknown free variable {fvarId.name}"
def findParam? (fvarId : FVarId) : CompilerM (Option Param) :=
return ( get).lctx.params[fvarId]?
def findParam? (fvarId : FVarId) : CompilerM (Option (Param pu)) := do
return (( get).lctx.params pu)[fvarId]?
def findLetDecl? (fvarId : FVarId) : CompilerM (Option LetDecl) :=
return ( get).lctx.letDecls[fvarId]?
def findLetDecl? (fvarId : FVarId) : CompilerM (Option (LetDecl pu)) := do
return (( get).lctx.letDecls pu)[fvarId]?
def findFunDecl? (fvarId : FVarId) : CompilerM (Option FunDecl) :=
return ( get).lctx.funDecls[fvarId]?
def findFunDecl? (fvarId : FVarId) : CompilerM (Option (FunDecl pu)) := do
return (( get).lctx.funDecls pu)[fvarId]?
def findLetValue? (fvarId : FVarId) : CompilerM (Option LetValue) := do
def findLetValue? (fvarId : FVarId) : CompilerM (Option (LetValue pu)) := do
let some { value, .. } findLetDecl? fvarId | return none
return some value
@@ -101,56 +111,56 @@ def isConstructorApp (fvarId : FVarId) : CompilerM Bool := do
let some (.const declName _ _) findLetValue? fvarId | return false
return ( getEnv).find? declName matches some (.ctorInfo ..)
def Arg.isConstructorApp (arg : Arg) : CompilerM Bool := do
def Arg.isConstructorApp (arg : Arg pu) : CompilerM Bool := do
let .fvar fvarId := arg | return false
LCNF.isConstructorApp fvarId
def getParam (fvarId : FVarId) : CompilerM Param := do
def getParam (fvarId : FVarId) : CompilerM (Param pu) := do
let some param findParam? fvarId | throwError "unknown parameter {fvarId.name}"
return param
def getLetDecl (fvarId : FVarId) : CompilerM LetDecl := do
def getLetDecl (fvarId : FVarId) : CompilerM (LetDecl pu) := do
let some decl findLetDecl? fvarId | throwError "unknown let-declaration {fvarId.name}"
return decl
def getFunDecl (fvarId : FVarId) : CompilerM FunDecl := do
def getFunDecl (fvarId : FVarId) : CompilerM (FunDecl pu) := do
let some decl findFunDecl? fvarId | throwError "unknown local function {fvarId.name}"
return decl
@[inline] def modifyLCtx (f : LCtx LCtx) : CompilerM Unit := do
modify fun s => { s with lctx := f s.lctx }
def eraseLetDecl (decl : LetDecl) : CompilerM Unit := do
def eraseLetDecl (decl : LetDecl pu) : CompilerM Unit := do
modifyLCtx fun lctx => lctx.eraseLetDecl decl
def eraseFunDecl (decl : FunDecl) (recursive := true) : CompilerM Unit := do
def eraseFunDecl (decl : FunDecl pu) (recursive := true) : CompilerM Unit := do
modifyLCtx fun lctx => lctx.eraseFunDecl decl recursive
def eraseCode (code : Code) : CompilerM Unit := do
def eraseCode (code : Code pu) : CompilerM Unit := do
modifyLCtx fun lctx => lctx.eraseCode code
def eraseParam (param : Param) : CompilerM Unit :=
def eraseParam (param : Param pu) : CompilerM Unit :=
modifyLCtx fun lctx => lctx.eraseParam param
def eraseParams (params : Array Param) : CompilerM Unit :=
def eraseParams (params : Array (Param pu)) : CompilerM Unit :=
modifyLCtx fun lctx => lctx.eraseParams params
def eraseCodeDecl (decl : CodeDecl) : CompilerM Unit := do
def eraseCodeDecl (decl : CodeDecl pu) : CompilerM Unit := do
match decl with
| .let decl => eraseLetDecl decl
| .jp decl | .fun decl => eraseFunDecl decl
| .jp decl | .fun decl _ => eraseFunDecl decl
/--
Erase all free variables occurring in `decls` from the local context.
-/
def eraseCodeDecls (decls : Array CodeDecl) : CompilerM Unit := do
def eraseCodeDecls (decls : Array (CodeDecl pu)) : CompilerM Unit := do
decls.forM fun decl => eraseCodeDecl decl
def eraseDecl (decl : Decl) : CompilerM Unit := do
def eraseDecl (decl : Decl pu) : CompilerM Unit := do
eraseParams decl.params
decl.value.forCodeM eraseCode
abbrev Decl.erase (decl : Decl) : CompilerM Unit :=
abbrev Decl.erase (decl : Decl pu) : CompilerM Unit :=
eraseDecl decl
/--
@@ -166,7 +176,7 @@ it is a free variable, a type (or type former), or `lcErased`.
`Check.lean` contains a substitution validator.
-/
abbrev FVarSubst := Std.HashMap FVarId Arg
abbrev FVarSubst (pu : Purity) := Std.HashMap FVarId (Arg pu)
/--
Replace the free variables in `e` using the given substitution.
@@ -179,7 +189,7 @@ If `translator = false`, we assume the substitution contains free variable repla
and given entries such as `x₁ ↦ x₂`, `x₂ ↦ x₃`, ..., `xₙ₋₁ ↦ xₙ`, and the expression `f x₁ x₂`, we want the resulting
expression to be `f xₙ xₙ`. We use this setting, for example, in the simplifier.
-/
private partial def normExprImp (s : FVarSubst) (e : Expr) (translator : Bool) : Expr :=
private partial def normExprImp (s : FVarSubst pu) (e : Expr) (translator : Bool) : Expr :=
go e
where
goApp (e : Expr) : Expr :=
@@ -192,7 +202,7 @@ where
match e with
| .fvar fvarId => match s[fvarId]? with
| some (.fvar fvarId') => if translator then .fvar fvarId' else go (.fvar fvarId')
| some (.type e) => if translator then e else go e
| some (.type e _) => if translator then e else go e
| some .erased => erasedExpr
| none => e
| .lit .. | .const .. | .sort .. | .mvar .. | .bvar .. => e
@@ -225,7 +235,7 @@ This function panics if the substitution is mapping `fvarId` to an expression th
That is, it is not a type (or type former), nor `lcErased`. Recall that a valid `FVarSubst` contains only
expressions that are free variables, `lcErased`, or type formers.
-/
partial def normFVarImp (s : FVarSubst) (fvarId : FVarId) (translator : Bool) : NormFVarResult :=
partial def normFVarImp (s : FVarSubst pu) (fvarId : FVarId) (translator : Bool) : NormFVarResult :=
match s[fvarId]? with
| some (.fvar fvarId') =>
if translator then
@@ -234,7 +244,7 @@ partial def normFVarImp (s : FVarSubst) (fvarId : FVarId) (translator : Bool) :
normFVarImp s fvarId' translator
-- Types and type formers are only preserved as hints and
-- are erased in computationally relevant contexts.
| some .erased | some (.type _) => .erased
| some .erased | some (.type _ _) => .erased
| none => .fvar fvarId
/--
@@ -242,18 +252,18 @@ Replace the free variables in `arg` using the given substitution.
See `normExprImp`
-/
private partial def normArgImp (s : FVarSubst) (arg : Arg) (translator : Bool) : Arg :=
private partial def normArgImp (s : FVarSubst pu) (arg : Arg pu) (translator : Bool) : Arg pu :=
match arg with
| .erased => arg
| .fvar fvarId =>
match s[fvarId]? with
| some (arg'@(.fvar _)) =>
if translator then arg' else normArgImp s arg' translator
| some (arg'@.erased) | some (arg'@(.type _)) => arg'
| some (arg'@.erased) | some (arg'@(.type _ _)) => arg'
| none => arg
| .type e => arg.updateType! (normExprImp s e translator)
| .type e _ => arg.updateType! (normExprImp s e translator)
private def normArgsImp (s : FVarSubst) (args : Array Arg) (translator : Bool) : Array Arg :=
private def normArgsImp (s : FVarSubst pu) (args : Array (Arg pu)) (translator : Bool) : Array (Arg pu) :=
args.mapMono (normArgImp s · translator)
/--
@@ -261,13 +271,13 @@ Replace the free variables in `e` using the given substitution.
See `normExprImp`
-/
private partial def normLetValueImp (s : FVarSubst) (e : LetValue) (translator : Bool) : LetValue :=
private partial def normLetValueImp (s : FVarSubst pu) (e : LetValue pu) (translator : Bool) : LetValue pu :=
match e with
| .erased | .lit .. => e
| .proj _ _ fvarId => match normFVarImp s fvarId translator with
| .proj _ _ fvarId _ => match normFVarImp s fvarId translator with
| .fvar fvarId' => e.updateProj! fvarId'
| .erased => .erased
| .const _ _ args => e.updateArgs! (normArgsImp s args translator)
| .const _ _ args _ => e.updateArgs! (normArgsImp s args translator)
| .fvar fvarId args => match normFVarImp s fvarId translator with
| .fvar fvarId' => e.updateFVar! fvarId' (normArgsImp s args translator)
| .erased => .erased
@@ -275,20 +285,20 @@ private partial def normLetValueImp (s : FVarSubst) (e : LetValue) (translator :
/--
Interface for monads that have a free substitutions.
-/
class MonadFVarSubst (m : Type Type) (translator : outParam Bool) where
getSubst : m FVarSubst
class MonadFVarSubst (m : Type Type) (pu : outParam Purity) (translator : outParam Bool) where
getSubst : m (FVarSubst pu)
export MonadFVarSubst (getSubst)
instance (m n) [MonadLift m n] [MonadFVarSubst m t] : MonadFVarSubst n t where
instance (m n) [MonadLift m n] [MonadFVarSubst m pu t] : MonadFVarSubst n pu t where
getSubst := liftM (getSubst : m _)
class MonadFVarSubstState (m : Type Type) where
modifySubst : (FVarSubst FVarSubst) m Unit
class MonadFVarSubstState (m : Type Type) (pu : outParam Purity) where
modifySubst : (FVarSubst pu FVarSubst pu) m Unit
export MonadFVarSubstState (modifySubst)
instance (m n) [MonadLift m n] [MonadFVarSubstState m] : MonadFVarSubstState n where
instance (m n) [MonadLift m n] [MonadFVarSubstState m pu] : MonadFVarSubstState n pu where
modifySubst f := liftM (modifySubst f : m _)
/--
@@ -296,35 +306,35 @@ Add the substitution `fvarId ↦ e`, `e` must be a valid LCNF `Arg`.
See `Check.lean` for the free variable substitution checker.
-/
@[inline] def addSubst [MonadFVarSubstState m] (fvarId : FVarId) (arg : Arg) : m Unit :=
@[inline] def addSubst [MonadFVarSubstState m pu] (fvarId : FVarId) (arg : Arg pu) : m Unit :=
modifySubst fun s => s.insert fvarId arg
/--
Add the entry `fvarId ↦ fvarId'` to the free variable substitution.
-/
@[inline] def addFVarSubst [MonadFVarSubstState m] (fvarId : FVarId) (fvarId' : FVarId) : m Unit :=
@[inline] def addFVarSubst [MonadFVarSubstState m ph] (fvarId : FVarId) (fvarId' : FVarId) : m Unit :=
modifySubst fun s => s.insert fvarId (.fvar fvarId')
@[inline, inherit_doc normFVarImp] def normFVar [MonadFVarSubst m t] [Monad m] (fvarId : FVarId) : m NormFVarResult :=
@[inline, inherit_doc normFVarImp] def normFVar [MonadFVarSubst m pu t] [Monad m] (fvarId : FVarId) : m NormFVarResult :=
return normFVarImp ( getSubst) fvarId t
@[inline, inherit_doc normExprImp] def normExpr [MonadFVarSubst m t] [Monad m] (e : Expr) : m Expr :=
@[inline, inherit_doc normExprImp] def normExpr [MonadFVarSubst m pu t] [Monad m] (e : Expr) : m Expr :=
return normExprImp ( getSubst) e t
@[inline, inherit_doc normArgImp] def normArg [MonadFVarSubst m t] [Monad m] (arg : Arg) : m Arg :=
@[inline, inherit_doc normArgImp] def normArg [MonadFVarSubst m pu t] [Monad m] (arg : Arg pu) : m (Arg pu) :=
return normArgImp ( getSubst) arg t
@[inline, inherit_doc normLetValueImp] def normLetValue [MonadFVarSubst m t] [Monad m] (e : LetValue) : m LetValue :=
@[inline, inherit_doc normLetValueImp] def normLetValue [MonadFVarSubst m pu t] [Monad m] (e : LetValue pu) : m (LetValue pu) :=
return normLetValueImp ( getSubst) e t
@[inherit_doc normExprImp, inline]
def normExprCore (s : FVarSubst) (e : Expr) (translator : Bool) : Expr :=
def normExprCore (s : FVarSubst pu) (e : Expr) (translator : Bool) : Expr :=
normExprImp s e translator
/--
Normalize the given arguments using the current substitution.
-/
def normArgs [MonadFVarSubst m t] [Monad m] (args : Array Arg) : m (Array Arg) :=
def normArgs [MonadFVarSubst m pu t] [Monad m] (args : Array (Arg pu)) : m (Array (Arg pu)) :=
return normArgsImp ( getSubst) args t
def mkFreshBinderName (binderName := `_x): CompilerM Name := do
@@ -342,35 +352,35 @@ def ensureNotAnonymous (binderName : Name) (baseName : Name) : CompilerM Name :=
Helper functions for creating LCNF local declarations.
-/
def mkParam (binderName : Name) (type : Expr) (borrow : Bool) : CompilerM Param := do
def mkParam (binderName : Name) (type : Expr) (borrow : Bool) : CompilerM (Param pu) := do
let fvarId mkFreshFVarId
let binderName ensureNotAnonymous binderName `_y
let param := { fvarId, binderName, type, borrow }
modifyLCtx fun lctx => lctx.addParam param
return param
def mkLetDecl (binderName : Name) (type : Expr) (value : LetValue) : CompilerM LetDecl := do
def mkLetDecl (binderName : Name) (type : Expr) (value : LetValue pu) : CompilerM (LetDecl pu) := do
let fvarId mkFreshFVarId
let binderName ensureNotAnonymous binderName `_x
let decl := { fvarId, binderName, type, value }
modifyLCtx fun lctx => lctx.addLetDecl decl
return decl
def mkFunDecl (binderName : Name) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do
def mkFunDecl (binderName : Name) (type : Expr) (params : Array (Param pu)) (value : Code pu) : CompilerM (FunDecl pu) := do
let fvarId mkFreshFVarId
let binderName ensureNotAnonymous binderName `_f
let funDecl := fvarId, binderName, params, type, value
modifyLCtx fun lctx => lctx.addFunDecl funDecl
return funDecl
def mkLetDeclErased : CompilerM LetDecl := do
def mkLetDeclErased : CompilerM (LetDecl pu) := do
mkLetDecl ( mkFreshBinderName `_x) erasedExpr .erased
def mkReturnErased : CompilerM Code := do
def mkReturnErased : CompilerM (Code pu) := do
let auxDecl mkLetDeclErased
return .let auxDecl (.return auxDecl.fvarId)
private unsafe def updateParamImp (p : Param) (type : Expr) : CompilerM Param := do
private unsafe def updateParamImp (p : Param pu) (type : Expr) : CompilerM (Param pu) := do
if ptrEq type p.type then
return p
else
@@ -378,9 +388,9 @@ private unsafe def updateParamImp (p : Param) (type : Expr) : CompilerM Param :=
modifyLCtx fun lctx => lctx.addParam p
return p
@[implemented_by updateParamImp] opaque Param.update (p : Param) (type : Expr) : CompilerM Param
@[implemented_by updateParamImp] opaque Param.update (p : Param pu) (type : Expr) : CompilerM (Param pu)
private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : LetValue) : CompilerM LetDecl := do
private unsafe def updateLetDeclImp (decl : LetDecl pu) (type : Expr) (value : LetValue pu) : CompilerM (LetDecl pu) := do
if ptrEq type decl.type && ptrEq value decl.value then
return decl
else
@@ -388,12 +398,12 @@ private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : LetV
modifyLCtx fun lctx => lctx.addLetDecl decl
return decl
@[implemented_by updateLetDeclImp] opaque LetDecl.update (decl : LetDecl) (type : Expr) (value : LetValue) : CompilerM LetDecl
@[implemented_by updateLetDeclImp] opaque LetDecl.update (decl : LetDecl pu) (type : Expr) (value : LetValue pu) : CompilerM (LetDecl pu)
def LetDecl.updateValue (decl : LetDecl) (value : LetValue) : CompilerM LetDecl :=
def LetDecl.updateValue (decl : LetDecl pu) (value : LetValue pu) : CompilerM (LetDecl pu) :=
decl.update decl.type value
private unsafe def updateFunDeclImp (decl : FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do
private unsafe def updateFunDeclImp (decl : FunDecl pu) (type : Expr) (params : Array (Param pu)) (value : Code pu) : CompilerM (FunDecl pu) := do
if ptrEq type decl.type && ptrEq params decl.params && ptrEq value decl.value then
return decl
else
@@ -401,48 +411,48 @@ private unsafe def updateFunDeclImp (decl : FunDecl) (type : Expr) (params : Arr
modifyLCtx fun lctx => lctx.addFunDecl decl
return decl
@[implemented_by updateFunDeclImp] opaque FunDecl.update (decl : FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl
@[implemented_by updateFunDeclImp] opaque FunDecl.update (decl : FunDecl pu) (type : Expr) (params : Array (Param pu)) (value : Code pu) : CompilerM (FunDecl pu)
abbrev FunDecl.update' (decl : FunDecl) (type : Expr) (value : Code) : CompilerM FunDecl :=
abbrev FunDecl.update' (decl : FunDecl pu) (type : Expr) (value : Code pu) : CompilerM (FunDecl pu) :=
decl.update type decl.params value
abbrev FunDecl.updateValue (decl : FunDecl) (value : Code) : CompilerM FunDecl :=
abbrev FunDecl.updateValue (decl : FunDecl pu) (value : Code pu) : CompilerM (FunDecl pu) :=
decl.update decl.type decl.params value
@[inline] def normParam [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (p : Param) : m Param := do
@[inline] def normParam [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (p : Param pu) : m (Param pu) := do
p.update ( normExpr p.type)
def normParams [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (ps : Array Param) : m (Array Param) :=
def normParams [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (ps : Array (Param pu)) : m (Array (Param pu)) :=
ps.mapMonoM normParam
def normLetDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (decl : LetDecl) : m LetDecl := do
def normLetDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (decl : LetDecl pu) : m (LetDecl pu) := do
decl.update ( normExpr decl.type) ( normLetValue decl.value)
abbrev NormalizerM (_translator : Bool) := ReaderT FVarSubst CompilerM
abbrev NormalizerM (pu : Purity) (_translator : Bool) := ReaderT (FVarSubst pu) CompilerM
instance : MonadFVarSubst (NormalizerM t) t where
instance : MonadFVarSubst (NormalizerM pu t) pu t where
getSubst := read
/--
If `result` is `.fvar fvarId`, then return `x fvarId`. Otherwise, it is `.erased`,
and method returns `let _x.i := .erased; return _x.i`.
-/
@[inline] def withNormFVarResult [MonadLiftT CompilerM m] [Monad m] (result : NormFVarResult) (x : FVarId m Code) : m Code := do
@[inline] def withNormFVarResult [MonadLiftT CompilerM m] [Monad m] (result : NormFVarResult) (x : FVarId m (Code pu)) : m (Code pu) := do
match result with
| .fvar fvarId => x fvarId
| .erased => mkReturnErased
mutual
partial def normFunDeclImp (decl : FunDecl) : NormalizerM t FunDecl := do
partial def normFunDeclImp (decl : FunDecl pu) : NormalizerM pu t (FunDecl pu) := do
let type normExpr decl.type
let params normParams decl.params
let value normCodeImp decl.value
decl.update type params value
partial def normCodeImp (code : Code) : NormalizerM t Code := do
partial def normCodeImp (code : Code pu) : NormalizerM pu t (Code pu) := do
match code with
| .let decl k => return code.updateLet! ( normLetDecl decl) ( normCodeImp k)
| .fun decl k | .jp decl k => return code.updateFun! ( normFunDeclImp decl) ( normCodeImp k)
| .fun decl k _ | .jp decl k => return code.updateFun! ( normFunDeclImp decl) ( normCodeImp k)
| .return fvarId => withNormFVarResult ( normFVar fvarId) fun fvarId => return code.updateReturn! fvarId
| .jmp fvarId args => withNormFVarResult ( normFVar fvarId) fun fvarId => return code.updateJmp! fvarId ( normArgs args)
| .unreach type => return code.updateUnreach! ( normExpr type)
@@ -451,28 +461,28 @@ mutual
withNormFVarResult ( normFVar c.discr) fun discr => do
let alts c.alts.mapMonoM fun alt =>
match alt with
| .alt _ params k => return alt.updateAlt! ( normParams params) ( normCodeImp k)
| .alt _ params k _ => return alt.updateAlt! ( normParams params) ( normCodeImp k)
| .default k => return alt.updateCode ( normCodeImp k)
return code.updateCases! resultType discr alts
end
@[inline] def normFunDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (decl : FunDecl) : m FunDecl := do
@[inline] def normFunDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (decl : FunDecl pu) : m (FunDecl pu) := do
normFunDeclImp (t := t) decl ( getSubst)
/-- Similar to `internalize`, but does not refresh `FVarId`s. -/
@[inline] def normCode [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (code : Code) : m Code := do
@[inline] def normCode [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (code : Code pu) : m (Code pu) := do
normCodeImp (t := t) code ( getSubst)
def replaceExprFVars (e : Expr) (s : FVarSubst) (translator : Bool) : CompilerM Expr :=
(normExpr e : NormalizerM translator Expr).run s
def replaceExprFVars (e : Expr) (s : FVarSubst pu) (translator : Bool) : CompilerM Expr :=
(normExpr e : NormalizerM pu translator Expr).run s
def replaceFVars (code : Code) (s : FVarSubst) (translator : Bool) : CompilerM Code :=
(normCode code : NormalizerM translator Code).run s
def replaceFVars (code : Code pu) (s : FVarSubst pu) (translator : Bool) : CompilerM (Code pu) :=
(normCode code : NormalizerM pu translator (Code pu)).run s
def mkFreshJpName : CompilerM Name := do
mkFreshBinderName `_jp
def mkAuxParam (type : Expr) (borrow := false) : CompilerM Param := do
def mkAuxParam (type : Expr) (borrow := false) : CompilerM (Param pu) := do
mkParam ( mkFreshBinderName `_y) type borrow
def getConfig : CompilerM ConfigOptions :=

View File

@@ -12,25 +12,25 @@ public section
namespace Lean.Compiler.LCNF
instance : Hashable Param where
instance : Hashable (Param pu) where
hash p := mixHash (hash p.fvarId) (hash p.type)
def hashParams (ps : Array Param) : UInt64 :=
def hashParams (ps : Array (Param pu)) : UInt64 :=
hash ps
mutual
partial def hashAlt (alt : Alt) : UInt64 :=
partial def hashAlt (alt : Alt pu) : UInt64 :=
match alt with
| .alt ctorName ps k => mixHash (mixHash (hash ctorName) (hash ps)) (hashCode k)
| .alt ctorName ps k _ => mixHash (mixHash (hash ctorName) (hash ps)) (hashCode k)
| .default k => hashCode k
partial def hashAlts (alts : Array Alt) : UInt64 :=
partial def hashAlts (alts : Array (Alt pu)) : UInt64 :=
alts.foldl (fun r a => mixHash r (hashAlt a)) 7
partial def hashCode (code : Code) : UInt64 :=
partial def hashCode (code : Code pu) : UInt64 :=
match code with
| .let decl k => mixHash (mixHash (hash decl.fvarId) (hash decl.type)) (mixHash (hash decl.value) (hashCode k))
| .fun decl k | .jp decl k =>
| .fun decl k _ | .jp decl k =>
mixHash (mixHash (mixHash (hash decl.fvarId) (hash decl.type)) (mixHash (hashCode decl.value) (hashCode k))) (hash decl.params)
| .return fvarId => hash fvarId
| .unreach type => hash type
@@ -39,7 +39,7 @@ partial def hashCode (code : Code) : UInt64 :=
end
instance : Hashable Code where
instance : Hashable (Code pu) where
hash c := hashCode c
deriving instance Hashable for DeclValue

View File

@@ -21,46 +21,46 @@ private def typeDepOn (e : Expr) : M Bool := do
let s read
return e.hasAnyFVar fun fvarId => s.contains fvarId
private def argDepOn (a : Arg) : M Bool := do
private def argDepOn (a : Arg pu) : M Bool := do
match a with
| .erased => return false
| .fvar fvarId => fvarDepOn fvarId
| .type e => typeDepOn e
| .type e _ => typeDepOn e
private def letValueDepOn (e : LetValue) : M Bool :=
private def letValueDepOn (e : LetValue pu) : M Bool :=
match e with
| .erased | .lit .. => return false
| .proj _ _ fvarId => fvarDepOn fvarId
| .proj _ _ fvarId _ => fvarDepOn fvarId
| .fvar fvarId args => fvarDepOn fvarId <||> args.anyM argDepOn
| .const _ _ args => args.anyM argDepOn
| .const _ _ args _ => args.anyM argDepOn
private def LetDecl.depOn (decl : LetDecl) : M Bool :=
private def LetDecl.depOn (decl : LetDecl pu) : M Bool :=
typeDepOn decl.type <||> letValueDepOn decl.value
private partial def depOn (c : Code) : M Bool :=
private partial def depOn (c : Code pu) : M Bool :=
match c with
| .let decl k => decl.depOn <||> depOn k
| .jp decl k | .fun decl k => typeDepOn decl.type <||> depOn decl.value <||> depOn k
| .jp decl k | .fun decl k _ => typeDepOn decl.type <||> depOn decl.value <||> depOn k
| .cases c => typeDepOn c.resultType <||> fvarDepOn c.discr <||> c.alts.anyM fun alt => depOn alt.getCode
| .jmp fvarId args => fvarDepOn fvarId <||> args.anyM argDepOn
| .return fvarId => fvarDepOn fvarId
| .unreach _ => return false
@[inline] def LetDecl.dependsOn (decl : LetDecl) (s : FVarIdSet) : Bool :=
@[inline] def LetDecl.dependsOn (decl : LetDecl pu) (s : FVarIdSet) : Bool :=
decl.depOn s
@[inline] def FunDecl.dependsOn (decl : FunDecl) (s : FVarIdSet) : Bool :=
@[inline] def FunDecl.dependsOn (decl : FunDecl pu) (s : FVarIdSet) : Bool :=
typeDepOn decl.type s || depOn decl.value s
def CodeDecl.dependsOn (decl : CodeDecl) (s : FVarIdSet) : Bool :=
def CodeDecl.dependsOn (decl : CodeDecl pu) (s : FVarIdSet) : Bool :=
match decl with
| .let decl => decl.dependsOn s
| .jp decl | .fun decl => decl.dependsOn s
| .jp decl | .fun decl _ => decl.dependsOn s
/--
Return `true` is `c` depends on a free variable in `s`.
-/
def Code.dependsOn (c : Code) (s : FVarIdSet) : Bool :=
def Code.dependsOn (c : Code pu) (s : FVarIdSet) : Bool :=
depOn c s
end Lean.Compiler.LCNF

View File

@@ -19,16 +19,16 @@ Collect set of (let) free variables in a LCNF value.
This code exploits the LCNF property that local declarations do not occur in types.
-/
def collectLocalDeclsArg (s : UsedLocalDecls) (arg : Arg) : UsedLocalDecls :=
def collectLocalDeclsArg (s : UsedLocalDecls) (arg : Arg .pure) : UsedLocalDecls :=
match arg with
| .fvar fvarId => s.insert fvarId
-- Locally declared variables do not occur in types.
| .type _ | .erased => s
def collectLocalDeclsArgs (s : UsedLocalDecls) (args : Array Arg) : UsedLocalDecls :=
def collectLocalDeclsArgs (s : UsedLocalDecls) (args : Array (Arg .pure)) : UsedLocalDecls :=
args.foldl (init := s) collectLocalDeclsArg
def collectLocalDeclsLetValue (s : UsedLocalDecls) (e : LetValue) : UsedLocalDecls :=
def collectLocalDeclsLetValue (s : UsedLocalDecls) (e : LetValue .pure) : UsedLocalDecls :=
match e with
| .erased | .lit .. => s
| .proj _ _ fvarId => s.insert fvarId
@@ -39,21 +39,22 @@ namespace ElimDead
abbrev M := StateRefT UsedLocalDecls CompilerM
private abbrev collectArgM (arg : Arg) : M Unit :=
private abbrev collectArgM (arg : Arg .pure) : M Unit :=
modify (collectLocalDeclsArg · arg)
private abbrev collectLetValueM (e : LetValue) : M Unit :=
private abbrev collectLetValueM (e : LetValue .pure) : M Unit :=
modify (collectLocalDeclsLetValue · e)
private abbrev collectFVarM (fvarId : FVarId) : M Unit :=
modify (·.insert fvarId)
mutual
partial def visitFunDecl (funDecl : FunDecl) : M FunDecl := do
partial def visitFunDecl (funDecl : FunDecl .pure) : M (FunDecl .pure) := do
let value elimDead funDecl.value
funDecl.updateValue value
partial def elimDead (code : Code) : M Code := do
partial def elimDead (code : Code .pure) : M (Code .pure) := do
match code with
| .let decl k =>
let k elimDead k
@@ -84,10 +85,11 @@ end
end ElimDead
def Code.elimDead (code : Code) : CompilerM Code :=
-- TODO: Generalize this to arbitrary phases, keep in mind that in impure elim dead is not as easy though
def Code.elimDead (code : Code .pure) : CompilerM (Code .pure) :=
ElimDead.elimDead code |>.run' {}
def Decl.elimDead (decl : Decl) : CompilerM Decl := do
def Decl.elimDead (decl : Decl .pure) : CompilerM (Decl .pure) := do
return { decl with value := ( decl.value.mapCodeM Code.elimDead) }
end Lean.Compiler.LCNF

View File

@@ -239,14 +239,14 @@ Attempt to turn a `Value` that is representing a literal into a set of
auxiliary declarations + the final `FVarId` of the declaration that
contains the actual literal. If it is not a literal return none.
-/
partial def getLiteral (v : Value) : CompilerM (Option ((Array CodeDecl) × FVarId)) := do
partial def getLiteral (v : Value) : CompilerM (Option ((Array (CodeDecl .pure)) × FVarId)) := do
if isLiteral v then
let literal go v
return some literal
else
return none
where
go : Value CompilerM ((Array CodeDecl) × FVarId)
go : Value CompilerM ((Array (CodeDecl .pure)) × FVarId)
| .ctor ``Nat.zero #[] .. => do
let decl mkAuxLetDecl <| .lit <| .nat <| 0
return (#[.let decl], decl.fvarId)
@@ -260,7 +260,7 @@ where
let flatten acc := fun (decls, var) => (acc.fst ++ decls, acc.snd.push <| .fvar var)
let (decls, args) :=
fields.foldl (init := (#[], Array.replicate ctorInfo.numParams .erased)) flatten
let letVal : LetValue := .const ctorName [] args
let letVal : LetValue .pure := .const ctorName [] args
let letDecl mkAuxLetDecl letVal
return (decls.push <| .let letDecl, letDecl.fvarId)
| _ => unreachable!
@@ -328,7 +328,7 @@ structure InterpContext where
a single declaration or a mutual block of declarations where their
analysis might influence each other as we approach the fixpoint.
-/
decls : Array Decl
decls : Array (Decl .pure)
/--
The index of the function we are currently operating on in `decls.`
-/
@@ -386,7 +386,7 @@ def findVarValue (var : FVarId) : InterpM Value := do
/--
Find the value of `arg` using the logic of `findVarValue`.
-/
def findArgValue (arg : Arg) : InterpM Value := do
def findArgValue (arg : Arg .pure) : InterpM Value := do
match arg with
| .fvar fvarId => findVarValue fvarId
| _ => return .top
@@ -421,7 +421,8 @@ Furthermore if we see that `params.size != args.size` we know that this is
a partial application and set the values of the remaining parameters to
`top` since it is impossible to track what will happen with them from here on.
-/
def updateFunDeclParamsAssignment (params : Array Param) (args : Array Arg) : InterpM Bool := do
def updateFunDeclParamsAssignment (params : Array (Param .pure)) (args : Array (Arg .pure)) :
InterpM Bool := do
let mut ret := false
let env getEnv
for param in params, arg in args do
@@ -443,7 +444,7 @@ def updateFunDeclParamsAssignment (params : Array Param) (args : Array Arg) : In
updateVarAssignment param.fvarId .top
return ret
def updateFunDeclParamsTop (params : Array Param) : InterpM Bool := do
def updateFunDeclParamsTop (params : Array (Param .pure)) : InterpM Bool := do
let mut ret := false
for param in params do
let paramVal findVarValue param.fvarId
@@ -453,7 +454,7 @@ def updateFunDeclParamsTop (params : Array Param) : InterpM Bool := do
ret := true
return ret
private partial def resetNestedFunDeclParams : Code InterpM Unit
private partial def resetNestedFunDeclParams : Code .pure InterpM Unit
| .let _ k => resetNestedFunDeclParams k
| .jp decl k | .fun decl k => do
decl.params.forM (resetVarAssignment ·.fvarId)
@@ -467,7 +468,7 @@ private partial def resetNestedFunDeclParams : Code → InterpM Unit
/--
The actual abstract interpreter on a block of `Code`.
-/
partial def interpCode : Code InterpM Unit
partial def interpCode : Code .pure InterpM Unit
| .let decl k => do
let val interpLetValue decl.value
updateVarAssignment decl.fvarId val
@@ -503,7 +504,7 @@ where
/--
The abstract interpreter on a `LetValue`.
-/
interpLetValue (letVal : LetValue) : InterpM Value := do
interpLetValue (letVal : LetValue .pure) : InterpM Value := do
match letVal with
| .lit val => return .ofLCNFLit val
| .proj _ idx struct =>
@@ -513,7 +514,7 @@ where
let env getEnv
args.forM handleFunArg
match ( getDecl? declName) with
| some decl =>
| some _, decl =>
if decl.getArity == args.size then
match getFunctionSummary? env declName with
| some v => return v
@@ -538,7 +539,7 @@ where
return .top
| .erased => return .top
handleFunArg (arg : Arg) : InterpM Unit := do
handleFunArg (arg : Arg .pure) : InterpM Unit := do
if let .fvar fvarId := arg then
handleFunVar fvarId
@@ -557,7 +558,7 @@ where
resetNestedFunDeclParams funDecl.value
interpCode funDecl.value
interpFunCall (funDecl : FunDecl) (args : Array Arg) : InterpM Unit := do
interpFunCall (funDecl : FunDecl .pure) (args : Array (Arg .pure)) : InterpM Unit := do
let updated updateFunDeclParamsAssignment funDecl.params args
if updated then
/- We must reset the value of nested function declaration
@@ -608,11 +609,11 @@ Use the information produced by the abstract interpreter to:
- Eliminate branches that we know cannot be hit
- Eliminate values that we know have to be constants.
-/
partial def elimDead (assignment : Assignment) (decl : Decl) : CompilerM Decl := do
partial def elimDead (assignment : Assignment) (decl : Decl .pure) : CompilerM (Decl .pure) := do
trace[Compiler.elimDeadBranches] s!"Eliminating {decl.name} with {repr (← assignment.toArray |>.mapM (fun (name, val) => do return (toString (← getBinderName name), val)))}"
return { decl with value := ( decl.value.mapCodeM go) }
where
go (code : Code) : CompilerM Code := do
go (code : Code .pure) : CompilerM (Code .pure) := do
match code with
| .let decl k =>
return code.updateLet! decl ( go k)
@@ -624,16 +625,14 @@ where
match alt with
| .alt ctor args body =>
if discrVal.containsCtor ctor then
let filter param := do
let constantInfos args.filterMapM fun param => do
if let some val := assignment[param.fvarId]? then
if let some literal val.getLiteral then
return some (param, literal)
return none
let constantInfos args.filterMapM filter
if constantInfos.size != 0 then
let folder := fun (body, subst) (param, decls, var) => do
let (body, subst) constantInfos.foldlM (init := ( go body, {})) fun (body, subst) (param, decls, var) => do
return (attachCodeDecls decls body, subst.insert param.fvarId (.fvar var))
let (body, subst) constantInfos.foldlM (init := ( go body, {})) folder
let body replaceFVars body subst false
return alt.updateCode body
else
@@ -649,7 +648,7 @@ where
end UnreachableBranches
open UnreachableBranches in
def Decl.elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do
def Decl.elimDeadBranches (decls : Array (Decl .pure)) : CompilerM (Array (Decl .pure)) := do
/-
We sort declarations by size here to ensure that when we restart in inferStep it will mostly be
small declarations that get re-analyzed.

View File

@@ -16,11 +16,11 @@ public section
namespace Lean.Compiler.LCNF
namespace ExtractClosed
abbrev ExtractM := StateRefT (Array CodeDecl) CompilerM
abbrev ExtractM := StateRefT (Array (CodeDecl .pure)) CompilerM
mutual
partial def extractLetValue (v : LetValue) : ExtractM Unit := do
partial def extractLetValue (v : LetValue .pure) : ExtractM Unit := do
match v with
| .const _ _ args => args.forM extractArg
| .fvar fnVar args =>
@@ -29,7 +29,7 @@ partial def extractLetValue (v : LetValue) : ExtractM Unit := do
| .proj _ _ baseVar => extractFVar baseVar
| .lit _ | .erased => return ()
partial def extractArg (arg : Arg) : ExtractM Unit := do
partial def extractArg (arg : Arg .pure) : ExtractM Unit := do
match arg with
| .fvar fvarId => extractFVar fvarId
| .type _ | .erased => return ()
@@ -41,17 +41,17 @@ partial def extractFVar (fvarId : FVarId) : ExtractM Unit := do
end
def isIrrelevantArg (arg : Arg) : Bool :=
def isIrrelevantArg (arg : Arg .pure) : Bool :=
match arg with
| .erased | .type _ => true
| .fvar _ => false
structure Context where
baseName : Name
sccDecls : Array Decl
sccDecls : Array (Decl .pure)
structure State where
decls : Array Decl := {}
decls : Array (Decl .pure) := {}
/--
Cache for `shouldExtractFVar` in order to avoid superlinear behavior.
-/
@@ -61,7 +61,7 @@ abbrev M := ReaderT Context $ StateRefT State CompilerM
mutual
partial def shouldExtractLetValue (isRoot : Bool) (v : LetValue) : M Bool := do
partial def shouldExtractLetValue (isRoot : Bool) (v : LetValue .pure) : M Bool := do
match v with
| .lit (.str _) => return true
| .lit (.nat v) =>
@@ -90,7 +90,7 @@ partial def shouldExtractLetValue (isRoot : Bool) (v : LetValue) : M Bool := do
| .fvar fnVar args => return ( shouldExtractFVar fnVar) && ( args.allM shouldExtractArg)
| .proj _ _ baseVar => shouldExtractFVar baseVar
partial def shouldExtractArg (arg : Arg) : M Bool := do
partial def shouldExtractArg (arg : Arg .pure) : M Bool := do
match arg with
| .fvar fvarId => shouldExtractFVar fvarId
| .type _ | .erased => return true
@@ -113,7 +113,7 @@ end
mutual
partial def visitCode (code : Code) : M Code := do
partial def visitCode (code : Code .pure) : M (Code .pure) := do
match code with
| .let decl k =>
if ( shouldExtractLetValue true decl.value) then
@@ -151,13 +151,14 @@ partial def visitCode (code : Code) : M Code := do
end
def visitDecl (decl : Decl) : M Decl := do
def visitDecl (decl : Decl .pure) : M (Decl .pure) := do
let value decl.value.mapCodeM visitCode
return { decl with value }
end ExtractClosed
partial def Decl.extractClosed (decl : Decl) (sccDecls : Array Decl) : CompilerM (Array Decl) := do
partial def Decl.extractClosed (decl : Decl .pure) (sccDecls : Array (Decl .pure)) :
CompilerM (Array (Decl .pure)) := do
let decl, s ExtractClosed.visitDecl decl |>.run { baseName := decl.name, sccDecls } |>.run {}
return s.decls.push decl

View File

@@ -48,67 +48,67 @@ instance : TraverseFVar Expr where
mapFVarM := Expr.mapFVarM
forFVarM := Expr.forFVarM
def Arg.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId m FVarId) (arg : Arg) : m Arg := do
def Arg.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId m FVarId) (arg : Arg pu) : m (Arg pu) := do
match arg with
| .erased => return .erased
| .type e => return arg.updateType! ( TraverseFVar.mapFVarM f e)
| .type e _ => return arg.updateType! ( TraverseFVar.mapFVarM f e)
| .fvar fvarId => return arg.updateFVar! ( f fvarId)
def Arg.forFVarM [Monad m] (f : FVarId m Unit) (arg : Arg) : m Unit := do
def Arg.forFVarM [Monad m] (f : FVarId m Unit) (arg : Arg pu) : m Unit := do
match arg with
| .erased => return ()
| .type e => TraverseFVar.forFVarM f e
| .type e _ => TraverseFVar.forFVarM f e
| .fvar fvarId => f fvarId
instance : TraverseFVar Arg where
instance : TraverseFVar (Arg pu) where
mapFVarM := Arg.mapFVarM
forFVarM := Arg.forFVarM
def LetValue.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId m FVarId) (e : LetValue) : m LetValue := do
def LetValue.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId m FVarId) (e : LetValue pu) : m (LetValue pu) := do
match e with
| .lit .. | .erased => return e
| .proj _ _ fvarId => return e.updateProj! ( f fvarId)
| .const _ _ args => return e.updateArgs! ( args.mapM (TraverseFVar.mapFVarM f))
| .proj _ _ fvarId _ => return e.updateProj! ( f fvarId)
| .const _ _ args _ => return e.updateArgs! ( args.mapM (TraverseFVar.mapFVarM f))
| .fvar fvarId args => return e.updateFVar! ( f fvarId) ( args.mapM (TraverseFVar.mapFVarM f))
def LetValue.forFVarM [Monad m] (f : FVarId m Unit) (e : LetValue) : m Unit := do
def LetValue.forFVarM [Monad m] (f : FVarId m Unit) (e : LetValue pu) : m Unit := do
match e with
| .lit .. | .erased => return ()
| .proj _ _ fvarId => f fvarId
| .const _ _ args => args.forM (TraverseFVar.forFVarM f)
| .proj _ _ fvarId _ => f fvarId
| .const _ _ args _ => args.forM (TraverseFVar.forFVarM f)
| .fvar fvarId args => f fvarId; args.forM (TraverseFVar.forFVarM f)
instance : TraverseFVar LetValue where
instance : TraverseFVar (LetValue pu) where
mapFVarM := LetValue.mapFVarM
forFVarM := LetValue.forFVarM
partial def LetDecl.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId m FVarId) (decl : LetDecl) : m LetDecl := do
partial def LetDecl.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId m FVarId) (decl : LetDecl pu) : m (LetDecl pu) := do
decl.update ( Expr.mapFVarM f decl.type) ( LetValue.mapFVarM f decl.value)
partial def LetDecl.forFVarM [Monad m] (f : FVarId m Unit) (decl : LetDecl) : m Unit := do
partial def LetDecl.forFVarM [Monad m] (f : FVarId m Unit) (decl : LetDecl pu) : m Unit := do
Expr.forFVarM f decl.type
LetValue.forFVarM f decl.value
instance : TraverseFVar LetDecl where
instance : TraverseFVar (LetDecl pu) where
mapFVarM := LetDecl.mapFVarM
forFVarM := LetDecl.forFVarM
partial def Param.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId m FVarId) (param : Param) : m Param := do
partial def Param.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId m FVarId) (param : Param pu) : m (Param pu) := do
param.update ( Expr.mapFVarM f param.type)
partial def Param.forFVarM [Monad m] (f : FVarId m Unit) (param : Param) : m Unit := do
partial def Param.forFVarM [Monad m] (f : FVarId m Unit) (param : Param pu) : m Unit := do
Expr.forFVarM f param.type
instance : TraverseFVar Param where
instance : TraverseFVar (Param pu) where
mapFVarM := Param.mapFVarM
forFVarM := Param.forFVarM
partial def Code.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId m FVarId) (c : Code) : m Code := do
partial def Code.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId m FVarId) (c : Code pu) : m (Code pu) := do
match c with
| .let decl k =>
let decl LetDecl.mapFVarM f decl
return Code.updateLet! c decl ( mapFVarM f k)
| .fun decl k =>
| .fun decl k _ =>
let params decl.params.mapM (Param.mapFVarM f)
let decl decl.update ( Expr.mapFVarM f decl.type) params ( mapFVarM f decl.value)
return Code.updateFun! c decl ( mapFVarM f k)
@@ -125,12 +125,12 @@ partial def Code.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m F
| .unreach typ =>
return Code.updateUnreach! c ( Expr.mapFVarM f typ)
partial def Code.forFVarM [Monad m] (f : FVarId m Unit) (c : Code) : m Unit := do
partial def Code.forFVarM [Monad m] (f : FVarId m Unit) (c : Code pu) : m Unit := do
match c with
| .let decl k =>
LetDecl.forFVarM f decl
forFVarM f k
| .fun decl k =>
| .fun decl k _ =>
decl.params.forM (Param.forFVarM f)
Expr.forFVarM f decl.type
forFVarM f decl.value
@@ -151,45 +151,45 @@ partial def Code.forFVarM [Monad m] (f : FVarId → m Unit) (c : Code) : m Unit
| .unreach typ =>
Expr.forFVarM f typ
instance : TraverseFVar Code where
instance : TraverseFVar (Code pu) where
mapFVarM := Code.mapFVarM
forFVarM := Code.forFVarM
def FunDecl.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId m FVarId) (decl : FunDecl) : m FunDecl := do
def FunDecl.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId m FVarId) (decl : FunDecl pu) : m (FunDecl pu) := do
let params decl.params.mapM (Param.mapFVarM f)
decl.update ( Expr.mapFVarM f decl.type) params ( Code.mapFVarM f decl.value)
def FunDecl.forFVarM [Monad m] (f : FVarId m Unit) (decl : FunDecl) : m Unit := do
def FunDecl.forFVarM [Monad m] (f : FVarId m Unit) (decl : FunDecl pu) : m Unit := do
decl.params.forM (Param.forFVarM f)
Expr.forFVarM f decl.type
Code.forFVarM f decl.value
instance : TraverseFVar FunDecl where
instance : TraverseFVar (FunDecl pu) where
mapFVarM := FunDecl.mapFVarM
forFVarM := FunDecl.forFVarM
instance : TraverseFVar CodeDecl where
instance : TraverseFVar (CodeDecl pu) where
mapFVarM f decl := do
match decl with
| .fun decl => return .fun ( mapFVarM f decl)
| .fun decl _ => return .fun ( mapFVarM f decl)
| .jp decl => return .jp ( mapFVarM f decl)
| .let decl => return .let ( mapFVarM f decl)
forFVarM f decl :=
match decl with
| .fun decl => forFVarM f decl
| .fun decl _ => forFVarM f decl
| .jp decl => forFVarM f decl
| .let decl => forFVarM f decl
instance : TraverseFVar Alt where
instance : TraverseFVar (Alt pu) where
mapFVarM f alt := do
match alt with
| .alt ctor params c =>
| .alt ctor params c _ =>
let params params.mapM (Param.mapFVarM f)
return .alt ctor params ( Code.mapFVarM f c)
| .default c => return .default ( Code.mapFVarM f c)
forFVarM f alt := do
match alt with
| .alt _ params c =>
| .alt _ params c _ =>
params.forM (Param.forFVarM f)
Code.forFVarM f c
| .default c => Code.forFVarM f c

View File

@@ -46,12 +46,12 @@ inductive AbsValue where
structure Context where
/-- Declaration in the same mutual block. -/
decls : Array Decl
decls : Array (Decl .pure)
/--
Function being analyzed. We check every recursive call to this function.
Remark: `main` is in `decls`.
-/
main : Decl
main : Decl .pure
/--
The assignment maps free variable ids in the current code being analyzed to abstract values.
We only track the abstract value assigned to parameters.
@@ -84,17 +84,17 @@ def evalFVar (fvarId : FVarId) : FixParamM AbsValue := do
let some val := ( read).assignment.get? fvarId | return .top
return val
def evalArg (arg : Arg) : FixParamM AbsValue := do
def evalArg (arg : Arg .pure) : FixParamM AbsValue := do
match arg with
| .erased => return .erased
| .type (.fvar fvarId) => evalFVar fvarId
| .type _ => return .top
| .type (.fvar fvarId) _ => evalFVar fvarId
| .type _ _ => return .top
| .fvar fvarId => evalFVar fvarId
def inMutualBlock (declName : Name) : FixParamM Bool :=
return ( read).decls.any (·.name == declName)
def mkAssignment (decl : Decl) (values : Array AbsValue) : FVarIdMap AbsValue := Id.run do
def mkAssignment (decl : Decl .pure) (values : Array AbsValue) : FVarIdMap AbsValue := Id.run do
let mut assignment := {}
for param in decl.params, value in values do
assignment := assignment.insert param.fvarId value
@@ -102,12 +102,12 @@ def mkAssignment (decl : Decl) (values : Array AbsValue) : FVarIdMap AbsValue :=
mutual
partial def evalLetValue (e : LetValue) : FixParamM Unit := do
partial def evalLetValue (e : LetValue .pure) : FixParamM Unit := do
match e with
| .const declName _ args => evalApp declName args
| .const declName _ args _ => evalApp declName args
| _ => return ()
partial def isEquivalentFunDecl? (decl : FunDecl) : FixParamM (Option Nat) := do
partial def isEquivalentFunDecl? (decl : FunDecl .pure) : FixParamM (Option Nat) := do
let .let { fvarId, value := (.fvar funFvarId args), .. } k := decl.value | return none
if args.size != decl.params.size then return none
let .return retFVarId := k | return none
@@ -120,10 +120,10 @@ partial def isEquivalentFunDecl? (decl : FunDecl) : FixParamM (Option Nat) := do
if arg != .fvar param.fvarId && arg != .erased then return none
return some funIdx
partial def evalCode (code : Code) : FixParamM Unit := do
partial def evalCode (code : Code .pure) : FixParamM Unit := do
match code with
| .let decl k => evalLetValue decl.value; evalCode k
| .fun decl k =>
| .fun decl k _ =>
if let some paramIdx isEquivalentFunDecl? decl then
withReader (fun ctx =>
{ ctx with assignment := ctx.assignment.insert decl.fvarId (.val paramIdx) })
@@ -135,7 +135,7 @@ partial def evalCode (code : Code) : FixParamM Unit := do
| .cases c => c.alts.forM fun alt => evalCode alt.getCode
| .unreach .. | .jmp .. | .return .. => return ()
partial def evalApp (declName : Name) (args : Array Arg) : FixParamM Unit := do
partial def evalApp (declName : Name) (args : Array (Arg .pure)) : FixParamM Unit := do
let main := ( read).main
if declName == main.name then
-- Recursive call to the function being analyzed
@@ -180,6 +180,9 @@ def mkInitialValues (numParams : Nat) : Array AbsValue := Id.run do
end FixedParams
open FixedParams
-- TODO: consider making it phase polymorphic, this requires detecting in place mutations of
-- variables etc in addition to just graph theory
/--
Given the (potentially mutually) recursive declarations `decls`,
return a map from declaration name `decl.name` to a bit-mask `m` where `m[i]` is true
@@ -188,7 +191,7 @@ applications.
The function assumes that if a function `f` was declared in a mutual block, then `decls`
contains all (computationally relevant) functions in the mutual block.
-/
def mkFixedParamsMap (decls : Array Decl) : NameMap (Array Bool) := Id.run do
def mkFixedParamsMap (decls : Array (Decl .pure)) : NameMap (Array Bool) := Id.run do
let mut result := {}
for decl in decls do
let values := mkInitialValues decl.params.size

View File

@@ -38,7 +38,7 @@ inductive Decision where
| unknown
deriving Hashable, BEq, Inhabited, Repr
def Decision.ofAlt : Alt Decision
def Decision.ofAlt : Alt .pure Decision
| .alt name _ _ => .arm name
| .default _ => .default
@@ -50,7 +50,7 @@ structure BaseFloatContext where
All the declarations that were collected in the current LCNF basic
block up to the current statement (in reverse order for efficiency).
-/
decls : List CodeDecl := []
decls : List (CodeDecl .pure) := []
/--
The state for `FloatM`
@@ -67,7 +67,7 @@ structure FloatState where
- Which declarations do we move into a certain arm
- Which declarations do we move into the default arm
-/
newArms : Std.HashMap Decision (List CodeDecl)
newArms : Std.HashMap Decision (List (CodeDecl .pure))
/--
Use to collect relevant declarations for the floating mechanism.
@@ -82,7 +82,7 @@ abbrev FloatM := StateRefT FloatState BaseFloatM
/--
Add `decl` to the list of declarations and run `x` with that updated context.
-/
def withNewCandidate (decl : CodeDecl) (x : BaseFloatM α) : BaseFloatM α :=
def withNewCandidate (decl : CodeDecl .pure) (x : BaseFloatM α) : BaseFloatM α :=
withReader (fun r => { r with decls := decl :: r.decls }) do
x
@@ -98,7 +98,7 @@ Whether to ignore `decl` for the floating mechanism. We want to do this if:
- `decl`' is storing a typeclass instance
- `decl` is a projection from a variable that is storing a typeclass instance
-/
def ignore? (decl : LetDecl) : BaseFloatM Bool := do
def ignore? (decl : LetDecl .pure) : BaseFloatM Bool := do
if ( isArrowClass? decl.type).isSome then
return true
else if let .proj _ _ fvarId := decl.value then
@@ -117,7 +117,7 @@ up to this point, with respect to `cs`. The initial decisions are:
- `arm` or `default` if we see the declaration only being used in exactly one cases arm
- `unknown` otherwise
-/
def initialDecisions (cs : Cases) : BaseFloatM (Std.HashMap FVarId Decision) := do
def initialDecisions (cs : Cases .pure) : BaseFloatM (Std.HashMap FVarId Decision) := do
let mut map := Std.HashMap.emptyWithCapacity ( read).decls.length
let owned : Std.HashSet FVarId :=
(map, _) ( read).decls.foldlM (init := (map, owned)) fun (acc, owned) val => do
@@ -135,12 +135,12 @@ def initialDecisions (cs : Cases) : BaseFloatM (Std.HashMap FVarId Decision) :=
(_, map) goCases cs |>.run map
return map
where
visitDecl (env : Environment) (value : CodeDecl) : StateM (Std.HashSet FVarId) Bool := do
visitDecl (env : Environment) (value : CodeDecl .pure) : StateM (Std.HashSet FVarId) Bool := do
match value with
| .let decl => visitLetValue env decl.value
| _ => return false -- will need to investigate whether that can be a problem
visitLetValue (env : Environment) (value : LetValue) : StateM (Std.HashSet FVarId) Bool := do
visitLetValue (env : Environment) (value : LetValue .pure) : StateM (Std.HashSet FVarId) Bool := do
match value with
| .proj _ _ x => visitArg (.fvar x) true
| .const nm _ args =>
@@ -158,7 +158,7 @@ where
( visitArg (.fvar x) false)
| .erased | .lit _ => return false
visitArg (var : Arg) (borrowed : Bool) : StateM (Std.HashSet FVarId) Bool := do
visitArg (var : Arg .pure) (borrowed : Bool) : StateM (Std.HashSet FVarId) Bool := do
let .fvar v := var | return false
let res := ( get).contains v
unless borrowed do
@@ -173,16 +173,16 @@ where
modify fun s => s.insert var .dont
-- otherwise we already have the proper decision
goAlt (alt : Alt) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit :=
goAlt (alt : Alt .pure) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit :=
forFVarM (goFVar (.ofAlt alt)) alt
goCases (cs : Cases) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit :=
goCases (cs : Cases .pure) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit :=
cs.alts.forM goAlt
/--
Compute the initial new arms. This will just set up a map from all arms of
`cs` to empty `Array`s, plus one additional entry for `dont`.
-/
def initialNewArms (cs : Cases) : Std.HashMap Decision (List CodeDecl) := Id.run do
def initialNewArms (cs : Cases .pure) : Std.HashMap Decision (List (CodeDecl .pure)) := Id.run do
let mut map := Std.HashMap.emptyWithCapacity (cs.alts.size + 1)
map := map.insert .dont []
cs.alts.foldr (init := map) fun val acc => acc.insert (.ofAlt val) []
@@ -203,7 +203,7 @@ cases z with
Here `x` and `y` are originally marked as getting floated into `n` and `m`
respectively but since `z` can't be moved we don't want that to move `x` and `y`.
-/
def dontFloat (decl : CodeDecl) : FloatM Unit := do
def dontFloat (decl : CodeDecl .pure) : FloatM Unit := do
forFVarM goFVar decl
modify fun s => { s with newArms := s.newArms.insert .dont (decl :: s.newArms[Decision.dont]!) }
where
@@ -257,7 +257,7 @@ Will:
```
If we are at `y` `x` is still marked to be moved but we don't want that.
-/
def float (decl : CodeDecl) : FloatM Unit := do
def float (decl : CodeDecl .pure) : FloatM Unit := do
let arm := ( get).decision[decl.fvarId]!
forFVarM (goFVar · arm) decl
modify fun s => { s with newArms := s.newArms.insert arm (decl :: s.newArms[arm]!) }
@@ -273,7 +273,7 @@ where
Iterate through `decl`, pushing local declarations that are only used in one
control flow arm into said arm in order to avoid useless computations.
-/
partial def floatLetIn (decl : Decl) : CompilerM Decl := do
partial def floatLetIn (decl : Decl .pure) : CompilerM (Decl .pure) := do
let newValue decl.value.mapCodeM go |>.run {}
return { decl with value := newValue }
where
@@ -296,7 +296,7 @@ where
else
float decl
go (code : Code) : BaseFloatM Code := do
go (code : Code .pure) : BaseFloatM (Code .pure) := do
match code with
| .let decl k =>
withNewCandidate (.let decl) do
@@ -334,11 +334,12 @@ where
end FloatLetIn
def Decl.floatLetIn (decl : Decl) : CompilerM Decl := do
def Decl.floatLetIn (decl : Decl .pure) : CompilerM (Decl .pure) := do
FloatLetIn.floatLetIn decl
def floatLetIn (phase := Phase.base) (occurrence := 0) : Pass :=
.mkPerDeclaration `floatLetIn Decl.floatLetIn phase occurrence
phase.withPurityCheck .pure fun h =>
.mkPerDeclaration `floatLetIn phase (h Decl.floatLetIn) occurrence
builtin_initialize
registerTraceClass `Compiler.floatLetIn (inherited := true)

View File

@@ -14,6 +14,10 @@ public section
namespace Lean.Compiler.LCNF
/-! # Type inference for LCNF -/
namespace InferType
namespace Pure
/-
Note about **erasure confusion**.
@@ -53,10 +57,9 @@ but the expected type is `S Nat Type (fun x => Nat)`. `fun x => Nat` is not eras
here because it is a type former.
-/
namespace InferType
/-
Type inference algorithm for LCNF. Invoked by the LCNF type checker
Type inference algorithm for pure LCNF. Invoked by the LCNF type checker
to check correctness of LCNF IR.
-/
@@ -80,12 +83,12 @@ def mkForallFVars (xs : Array Expr) (type : Expr) : InferTypeM Expr :=
let b := type.abstract xs
xs.size.foldRevM (init := b) fun i _ b => do
let x := xs[i]
let n InferType.getBinderName x.fvarId!
let ty InferType.getType x.fvarId!
let n getBinderName x.fvarId!
let ty getType x.fvarId!
let ty := ty.abstractRange i xs;
return .forallE n ty b .default
def mkForallParams (params : Array Param) (type : Expr) : InferTypeM Expr :=
def mkForallParams (params : Array (Param .pure)) (type : Expr) : InferTypeM Expr :=
let xs := params.map fun p => .fvar p.fvarId
mkForallFVars xs type |>.run {}
@@ -97,7 +100,7 @@ def mkForallParams (params : Array Param) (type : Expr) : InferTypeM Expr :=
def inferConstType (declName : Name) (us : List Level) : CompilerM Expr := do
if declName == ``lcErased then
return erasedExpr
else if let some decl getDecl? declName then
else if let some _, decl getDecl? declName then
return decl.instantiateTypeLevelParams us
else
/- Declaration does not have code associated with it: constructor, inductive type, foreign function -/
@@ -114,7 +117,7 @@ def inferLitValueType (value : LitValue) : Expr :=
| .usize .. => mkConst ``USize
mutual
partial def inferArgType (arg : Arg) : InferTypeM Expr :=
partial def inferArgType (arg : Arg .pure) : InferTypeM Expr :=
match arg with
| .erased => return erasedExpr
| .type e => inferType e
@@ -124,13 +127,13 @@ mutual
match e with
| .const c us => inferConstType c us
| .app .. => inferAppType e
| .fvar fvarId => InferType.getType fvarId
| .fvar fvarId => getType fvarId
| .sort lvl => return .sort (mkLevelSucc lvl)
| .forallE .. => inferForallType e
| .lam .. => inferLambdaType e
| .letE .. | .mvar .. | .mdata .. | .lit .. | .bvar .. | .proj .. => unreachable!
partial def inferLetValueType (e : LetValue) : InferTypeM Expr := do
partial def inferLetValueType (e : LetValue .pure) : InferTypeM Expr := do
match e with
| .erased => return erasedExpr
| .lit v => return inferLitValueType v
@@ -138,7 +141,7 @@ mutual
| .const declName us args => inferAppTypeCore ( inferConstType declName us) args
| .fvar fvarId args => inferAppTypeCore ( getType fvarId) args
partial def inferAppTypeCore (fType : Expr) (args : Array Arg) : InferTypeM Expr := do
partial def inferAppTypeCore (fType : Expr) (args : Array (Arg .pure)) : InferTypeM Expr := do
let mut j := 0
let mut fType := fType
for i in *...args.size do
@@ -237,60 +240,79 @@ mutual
mkForallFVars fvars type
end
end Pure
namespace Impure
end Impure
end InferType
-- TODO
def inferType (e : Expr) : CompilerM Expr :=
InferType.inferType e |>.run {}
InferType.Pure.inferType e |>.run {}
def inferAppType (fnType : Expr) (args : Array Arg) : CompilerM Expr :=
InferType.inferAppTypeCore fnType args |>.run {}
def inferAppType (fnType : Expr) (args : Array (Arg pu)) : CompilerM Expr :=
match pu with
| .pure => InferType.Pure.inferAppTypeCore fnType args |>.run {}
| .impure => panic! "Infer type for impure unimplemented" -- TODO
def getLevel (type : Expr) : CompilerM Level := do
match ( inferType type) with
| .sort u => return u
| e => if e.isErased then return levelOne else throwError "type expected{indentExpr type}"
def Arg.inferType (arg : Arg pu) : CompilerM Expr :=
match pu with
| .pure => InferType.Pure.inferArgType arg |>.run {}
| .impure => panic! "Infer type for impure unimplemented" -- TODO
def Arg.inferType (arg : Arg) : CompilerM Expr :=
InferType.inferArgType arg |>.run {}
def LetValue.inferType (e : LetValue pu) : CompilerM Expr :=
match pu with
| .pure => InferType.Pure.inferLetValueType e |>.run {}
| .impure => panic! "Infer type for impure unimplemented" -- TODO
def LetValue.inferType (e : LetValue) : CompilerM Expr :=
InferType.inferLetValueType e |>.run {}
def Code.inferType (code : Code pu) : CompilerM Expr := do
match pu with
| .pure =>
match code with
| .let _ k | .fun _ k _ | .jp _ k => k.inferType
| .return fvarId => getType fvarId
| .jmp fvarId args => InferType.Pure.inferAppTypeCore ( getType fvarId) args |>.run {}
| .unreach type => return type
| .cases c => return c.resultType
| .impure => panic! "Infer type for impure unimplemented" -- TODO
def Code.inferType (code : Code) : CompilerM Expr := do
match code with
| .let _ k | .fun _ k | .jp _ k => k.inferType
| .return fvarId => getType fvarId
| .jmp fvarId args => InferType.inferAppTypeCore ( getType fvarId) args |>.run {}
| .unreach type => return type
| .cases c => return c.resultType
def Code.inferParamType (params : Array Param) (code : Code) : CompilerM Expr := do
def Code.inferParamType (params : Array (Param pu)) (code : Code pu) : CompilerM Expr := do
let type code.inferType
let xs := params.map fun p => .fvar p.fvarId
InferType.mkForallFVars xs type |>.run {}
InferType.Pure.mkForallFVars xs type |>.run {}
def Alt.inferType (alt : Alt) : CompilerM Expr :=
def Alt.inferType (alt : Alt pu) : CompilerM Expr :=
alt.getCode.inferType
def mkAuxLetDecl (e : LetValue) (prefixName := `_x) : CompilerM LetDecl := do
def mkAuxLetDecl (e : LetValue pu) (prefixName := `_x) : CompilerM (LetDecl pu) := do
mkLetDecl ( mkFreshBinderName prefixName) ( e.inferType) e
def mkForallParams (params : Array Param) (type : Expr) : CompilerM Expr :=
InferType.mkForallParams params type |>.run {}
def mkForallParams (params : Array (Param pu)) (type : Expr) : CompilerM Expr :=
match pu with
| .pure => InferType.Pure.mkForallParams params type |>.run {}
| .impure => panic! "Infer type for impure unimplemented" -- TODO
def mkAuxFunDecl (params : Array Param) (code : Code) (prefixName := `_f) : CompilerM FunDecl := do
private def mkAuxFunDeclAux (params : Array (Param pu)) (code : Code pu) (prefixName : Name) :
CompilerM (FunDecl pu) := do
let type mkForallParams params ( code.inferType)
let binderName mkFreshBinderName prefixName
mkFunDecl binderName type params code
def mkAuxJpDecl (params : Array Param) (code : Code) (prefixName := `_jp) : CompilerM FunDecl := do
mkAuxFunDecl params code prefixName
def mkAuxFunDecl (params : Array (Param .pure)) (code : Code .pure) (prefixName := `_f) :
CompilerM (FunDecl .pure) := do
mkAuxFunDeclAux params code prefixName
def mkAuxJpDecl' (param : Param) (code : Code) (prefixName := `_jp) : CompilerM FunDecl := do
def mkAuxJpDecl (params : Array (Param pu)) (code : Code pu) (prefixName := `_jp) :
CompilerM (FunDecl pu) := do
mkAuxFunDeclAux params code prefixName
def mkAuxJpDecl' (param : Param pu) (code : Code pu) (prefixName := `_jp) :
CompilerM (FunDecl pu) := do
let params := #[param]
mkAuxFunDecl params code prefixName
mkAuxFunDeclAux params code prefixName
def mkCasesResultType (alts : Array Alt) : CompilerM Expr := do
def mkCasesResultType (alts : Array (Alt pu)) : CompilerM Expr := do
if alts.isEmpty then
throwError "`Code.bind` failed, empty `cases` found"
let mut resultType alts[0]!.inferType

View File

@@ -22,44 +22,45 @@ private def refreshBinderName (binderName : Name) : CompilerM Name := do
namespace Internalize
abbrev InternalizeM := StateRefT FVarSubst CompilerM
abbrev InternalizeM (pu : Purity) := StateRefT (FVarSubst pu) CompilerM
/--
The `InternalizeM` monad is a translator. It "translates" the free variables
in the input expressions and `Code`, into new fresh free variables in the
local context.
-/
instance : MonadFVarSubst InternalizeM true where
instance : MonadFVarSubst (InternalizeM pu) pu true where
getSubst := get
instance : MonadFVarSubstState InternalizeM where
instance : MonadFVarSubstState (InternalizeM pu) pu where
modifySubst := modify
private def mkNewFVarId (fvarId : FVarId) : InternalizeM FVarId := do
private def mkNewFVarId (fvarId : FVarId) : InternalizeM pu FVarId := do
let fvarId' Lean.mkFreshFVarId
addFVarSubst fvarId fvarId'
return fvarId'
private partial def internalizeExpr (e : Expr) : InternalizeM Expr :=
private partial def internalizeExpr (e : Expr) : InternalizeM pu Expr :=
go e
where
goApp (e : Expr) : InternalizeM Expr := do
goApp (e : Expr) : InternalizeM pu Expr := do
match e with
| .app f a => return e.updateApp! ( goApp f) ( go a)
| _ => go e
go (e : Expr) : InternalizeM Expr := do
go (e : Expr) : InternalizeM pu Expr := do
if e.hasFVar then
match e with
| .fvar fvarId => match ( get)[fvarId]? with
| .fvar fvarId =>
match ( get)[fvarId]? with
| some (.fvar fvarId') =>
-- In LCNF, types can't depend on let-bound fvars.
if ( findParam? fvarId').isSome then
if ( findParam? (pu := pu) fvarId').isSome then
return .fvar fvarId'
else
return anyExpr
| some .erased => return erasedExpr
| some (.type e) | none => return e
| some (.type e _) | none => return e
| .lit .. | .const .. | .sort .. | .mvar .. | .bvar .. => return e
| .app f a => return e.updateApp! ( goApp f) ( go a) |>.headBeta
| .mdata _ b => return e.updateMData! ( go b)
@@ -70,7 +71,7 @@ where
else
return e
def internalizeParam (p : Param) : InternalizeM Param := do
def internalizeParam (p : Param pu) : InternalizeM pu (Param pu) := do
let binderName refreshBinderName p.binderName
let type internalizeExpr p.type
let fvarId mkNewFVarId p.fvarId
@@ -78,31 +79,31 @@ def internalizeParam (p : Param) : InternalizeM Param := do
modifyLCtx fun lctx => lctx.addParam p
return p
def internalizeArg (arg : Arg) : InternalizeM Arg := do
def internalizeArg (arg : Arg pu) : InternalizeM pu (Arg pu) := do
match arg with
| .fvar fvarId =>
match ( get)[fvarId]? with
| some arg'@(.fvar _) => return arg'
| some arg'@.erased | some arg'@(.type _) => return arg'
| some arg'@.erased | some arg'@(.type _ _) => return arg'
| none => return arg
| .type e => return arg.updateType! ( internalizeExpr e)
| .type e _ => return arg.updateType! ( internalizeExpr e)
| .erased => return arg
def internalizeArgs (args : Array Arg) : InternalizeM (Array Arg) :=
def internalizeArgs (args : Array (Arg pu)) : InternalizeM pu (Array (Arg pu)) :=
args.mapM internalizeArg
private partial def internalizeLetValue (e : LetValue) : InternalizeM LetValue := do
private partial def internalizeLetValue (e : LetValue pu) : InternalizeM pu (LetValue pu) := do
match e with
| .erased | .lit .. => return e
| .proj _ _ fvarId => match ( normFVar fvarId) with
| .proj _ _ fvarId _ => match ( normFVar fvarId) with
| .fvar fvarId' => return e.updateProj! fvarId'
| .erased => return .erased
| .const _ _ args => return e.updateArgs! ( internalizeArgs args)
| .const _ _ args _ => return e.updateArgs! ( internalizeArgs args)
| .fvar fvarId args => match ( normFVar fvarId) with
| .fvar fvarId' => return e.updateFVar! fvarId' ( internalizeArgs args)
| .erased => return .erased
def internalizeLetDecl (decl : LetDecl) : InternalizeM LetDecl := do
def internalizeLetDecl (decl : LetDecl pu) : InternalizeM pu (LetDecl pu) := do
let binderName refreshBinderName decl.binderName
let type internalizeExpr decl.type
let value internalizeLetValue decl.value
@@ -113,7 +114,7 @@ def internalizeLetDecl (decl : LetDecl) : InternalizeM LetDecl := do
mutual
partial def internalizeFunDecl (decl : FunDecl) : InternalizeM FunDecl := do
partial def internalizeFunDecl (decl : FunDecl pu) : InternalizeM pu (FunDecl pu) := do
let type internalizeExpr decl.type
let binderName refreshBinderName decl.binderName
let params decl.params.mapM internalizeParam
@@ -123,10 +124,10 @@ partial def internalizeFunDecl (decl : FunDecl) : InternalizeM FunDecl := do
modifyLCtx fun lctx => lctx.addFunDecl decl
return decl
partial def internalizeCode (code : Code) : InternalizeM Code := do
partial def internalizeCode (code : Code pu) : InternalizeM pu (Code pu) := do
match code with
| .let decl k => return .let ( internalizeLetDecl decl) ( internalizeCode k)
| .fun decl k => return .fun ( internalizeFunDecl decl) ( internalizeCode k)
| .fun decl k _ => return .fun ( internalizeFunDecl decl) ( internalizeCode k)
| .jp decl k => return .jp ( internalizeFunDecl decl) ( internalizeCode k)
| .return fvarId => withNormFVarResult ( normFVar fvarId) fun fvarId => return .return fvarId
| .jmp fvarId args => withNormFVarResult ( normFVar fvarId) fun fvarId => return .jmp fvarId ( internalizeArgs args)
@@ -134,19 +135,19 @@ partial def internalizeCode (code : Code) : InternalizeM Code := do
| .cases c =>
withNormFVarResult ( normFVar c.discr) fun discr => do
let resultType internalizeExpr c.resultType
let internalizeAltCode (k : Code) : InternalizeM Code :=
let internalizeAltCode (k : Code pu) : InternalizeM pu (Code pu) :=
internalizeCode k
let alts c.alts.mapM fun
| .alt ctorName params k => return .alt ctorName ( params.mapM internalizeParam) ( internalizeAltCode k)
| .alt ctorName params k _ => return .alt ctorName ( params.mapM internalizeParam) ( internalizeAltCode k)
| .default k => return .default ( internalizeAltCode k)
return .cases c.typeName, resultType, discr, alts
end
partial def internalizeCodeDecl (decl : CodeDecl) : InternalizeM CodeDecl := do
partial def internalizeCodeDecl (decl : CodeDecl pu) : InternalizeM pu (CodeDecl pu) := do
match decl with
| .let decl => return .let ( internalizeLetDecl decl)
| .fun decl => return .fun ( internalizeFunDecl decl)
| .fun decl _ => return .fun ( internalizeFunDecl decl)
| .jp decl => return .jp ( internalizeFunDecl decl)
end Internalize
@@ -154,14 +155,14 @@ end Internalize
/--
Refresh free variables ids in `code`, and store their declarations in the local context.
-/
partial def Code.internalize (code : Code) (s : FVarSubst := {}) : CompilerM Code :=
partial def Code.internalize (code : Code pu) (s : FVarSubst pu := {}) : CompilerM (Code pu) :=
Internalize.internalizeCode code |>.run' s
open Internalize in
def Decl.internalize (decl : Decl) (s : FVarSubst := {}): CompilerM Decl :=
def Decl.internalize (decl : Decl pu) (s : FVarSubst pu := {}): CompilerM (Decl pu) :=
go decl |>.run' s
where
go (decl : Decl) : InternalizeM Decl := do
go (decl : Decl pu) : InternalizeM pu (Decl pu) := do
let type internalizeExpr decl.type
let params decl.params.mapM internalizeParam
let value decl.value.mapCodeM internalizeCode
@@ -170,13 +171,13 @@ where
/--
Create a fresh local context and internalize the given decls.
-/
def cleanup (decl : Array Decl) : CompilerM (Array Decl) := do
def cleanup (decl : Array (Decl pu)) : CompilerM (Array (Decl pu)) := do
modify fun _ => {}
decl.mapM fun decl => do
modify fun s => { s with nextIdx := 1 }
decl.internalize
def normalizeFVarIds (decl : Decl) : CoreM Decl := do
def normalizeFVarIds (decl : Decl pu) : CoreM (Decl pu) := do
let ngenSaved getNGen
setNGen {}
try

View File

@@ -92,13 +92,13 @@ private partial def eraseCandidate (fvarId : FVarId) : FindM Unit := do
/--
Remove all join point candidates contained in `a`.
-/
private partial def removeCandidatesInArg (a : Arg) : FindM Unit := do
private partial def removeCandidatesInArg (a : Arg .pure) : FindM Unit := do
forFVarM eraseCandidate a
/--
Remove all join point candidates contained in `a`.
-/
private partial def removeCandidatesInLetValue (e : LetValue) : FindM Unit := do
private partial def removeCandidatesInLetValue (e : LetValue .pure) : FindM Unit := do
forFVarM eraseCandidate e
/--
@@ -117,7 +117,7 @@ private def addDependency (src : FVarId) (target : FVarId) : FindM Unit := do
{ targetInfo with associated := targetInfo.associated.insert src }
@[inline]
private def withFnBody (decl : FunDecl) (x : FindM α) : FindM α :=
private def withFnBody (decl : FunDecl .pure) (x : FindM α) : FindM α :=
withReader (fun ctx => {
ctx with
definitionDepth := ctx.definitionDepth + 1,
@@ -125,7 +125,7 @@ private def withFnBody (decl : FunDecl) (x : FindM α) : FindM α :=
x
@[inline]
private def withFnDefined (decl : FunDecl) (x : FindM α) : FindM α :=
private def withFnDefined (decl : FunDecl .pure) (x : FindM α) : FindM α :=
withReader (fun ctx => {
ctx with
scope := ctx.scope.insert decl.fvarId ctx.definitionDepth }) do
@@ -163,11 +163,11 @@ def test (b : Bool) (x y : Nat) : Nat :=
this. This is because otherwise the calls to `myjp` in `f` and `g` would
produce out of scope join point jumps.
-/
partial def find (decl : Decl) : CompilerM FindState := do
partial def find (decl : Decl .pure) : CompilerM FindState := do
let (_, candidates) decl.value.forCodeM go |>.run {} |>.run {}
return candidates
where
go : Code FindM Unit
go : Code .pure FindM Unit
| .let decl k => do
match k, decl.value with
| .return valId, .fvar fvarId args =>
@@ -207,13 +207,13 @@ where
Replace all join point candidate `fun` declarations with `jp` ones
and all calls to them with `jmp`s.
-/
partial def replace (decl : Decl) (state : FindState) : CompilerM Decl := do
partial def replace (decl : Decl .pure) (state : FindState) : CompilerM (Decl .pure) := do
let mapper := fun acc cname _ => do return acc.insert cname ( mkFreshJpName)
let replaceCtx : ReplaceCtx state.candidates.foldM (init := ) mapper
let newValue decl.value.mapCodeM go |>.run replaceCtx
return { decl with value := newValue }
where
go (code : Code) : ReplaceM Code := do
go (code : Code .pure) : ReplaceM (Code .pure) := do
match code with
| .let decl k =>
match k, decl.value with
@@ -274,7 +274,7 @@ structure ExtendState where
to `Param`s. The free variables in this map are the once that the context
of said join point will be extended by passing in the respective parameter.
-/
fvarMap : Std.HashMap FVarId (Std.HashMap FVarId Param) := {}
fvarMap : Std.HashMap FVarId (Std.HashMap FVarId (Param .pure)) := {}
/--
The monad for the `extendJoinPointContext` pass.
@@ -388,7 +388,7 @@ the join point. This is so in the case of nested join points that refer
to parameters of the current one we extend the context of the nested
join points by said parameters.
-/
def withNewJpScope (decl : FunDecl) (x : ExtendM α): ExtendM α := do
def withNewJpScope (decl : FunDecl .pure) (x : ExtendM α): ExtendM α := do
withReader (fun ctx => { ctx with currentJp? := some decl.fvarId }) do
modify fun s => { s with fvarMap := s.fvarMap.insert decl.fvarId {} }
withNewScope do
@@ -401,7 +401,7 @@ It will back up the current scope (since we are doing a case split
and want to continue with other arms afterwards) and add all of the
parameters of the match arm to the list of candidates.
-/
def withNewAltScope (alt : Alt) (x : ExtendM α) : ExtendM α := do
def withNewAltScope (alt : Alt .pure) (x : ExtendM α) : ExtendM α := do
withBackTrackingScope do
withNewCandidates (alt.getParams.map (·.fvarId)) do
x
@@ -418,7 +418,7 @@ All of this is done to eliminate dependencies of join points onto their
position within the code so we can pull them out as far as possible, hopefully
enabling new inlining possibilities in the next simplifier run.
-/
partial def extend (decl : Decl) : CompilerM Decl := do
partial def extend (decl : Decl .pure) : CompilerM (Decl .pure) := do
let newValue decl.value.mapCodeM go |>.run {} |>.run' {} |>.run' {}
let decl := { decl with value := newValue }
decl.pullFunDecls
@@ -426,7 +426,7 @@ where
goFVar (fvar : FVarId) : ExtendM FVarId := do
extendByIfNecessary fvar
replaceFVar fvar
go (code : Code) : ExtendM Code := do
go (code : Code .pure) : ExtendM (Code .pure) := do
match code with
| .let decl k =>
let decl decl.updateValue ( mapFVarM goFVar decl.value)
@@ -491,7 +491,7 @@ structure AnalysisState where
A map, that for each join point id contains a map from all (so far)
duplicated argument ids to the respective duplicate value
-/
jpJmpArgs : FVarIdMap FVarSubst := {}
jpJmpArgs : FVarIdMap (FVarSubst .pure) := {}
abbrev ReduceAnalysisM := ReaderT AnalysisCtx StateRefT AnalysisState ScopeM
abbrev ReduceActionM := ReaderT AnalysisState CompilerM
@@ -539,17 +539,17 @@ After we have performed all of these optimizations we can take away the
(remaining) common arguments and end up with nicely floated and optimized
code that has as little arguments as possible in the join points.
-/
partial def reduce (decl : Decl) : CompilerM Decl := do
partial def reduce (decl : Decl .pure) : CompilerM (Decl .pure) := do
let (_, analysis) decl.value.forCodeM goAnalyze |>.run {} |>.run {} |>.run' {}
let newValue decl.value.mapCodeM goReduce |>.run analysis
return { decl with value := newValue }
where
goAnalyzeFunDecl (fn : FunDecl) : ReduceAnalysisM Unit := do
goAnalyzeFunDecl (fn : FunDecl .pure) : ReduceAnalysisM Unit := do
withNewScope do
fn.params.forM (addToScope ·.fvarId)
goAnalyze fn.value
goAnalyze (code : Code) : ReduceAnalysisM Unit := do
goAnalyze (code : Code .pure) : ReduceAnalysisM Unit := do
match code with
| .let decl k =>
addToScope decl.fvarId
@@ -571,7 +571,7 @@ where
goAnalyze alt.getCode
cs.alts.forM visitor
| .jmp fn args =>
let decl getFunDecl fn
let decl getFunDecl (pu := .pure) fn
if let some knownArgs := ( get).jpJmpArgs.get? fn then
let mut newArgs := knownArgs
for (param, arg) in decl.params.zip args do
@@ -589,7 +589,7 @@ where
modify fun s => { s with jpJmpArgs := s.jpJmpArgs.insert fn interestingArgs }
| .return .. | .unreach .. => return ()
goReduce (code : Code) : ReduceActionM Code := do
goReduce (code : Code .pure) : ReduceActionM (Code .pure) := do
match code with
| .jp decl k =>
if let some reducibleArgs := ( read).jpJmpArgs.get? decl.fvarId then
@@ -613,7 +613,7 @@ where
return Code.updateFun! code decl ( goReduce k)
| .jmp fn args =>
let reducibleArgs := ( read).jpJmpArgs.get! fn
let decl getFunDecl fn
let decl getFunDecl (pu := .pure) fn
let newParams := decl.params.zip args
|>.filter (!reducibleArgs.contains ·.fst.fvarId)
|>.map Prod.snd
@@ -630,7 +630,7 @@ where
end JoinPointCommonArgs
def Decl.findJoinPoints? (decl : Decl) : CompilerM (Option Decl) := do
def Decl.findJoinPoints? (decl : Decl .pure) : CompilerM (Option (Decl .pure)) := do
let findResult JoinPointFinder.find decl
trace[Compiler.findJoinPoints] "Found {findResult.candidates.size} jp candidates for {decl.name}"
if findResult.candidates.isEmpty then
@@ -642,29 +642,32 @@ def Decl.findJoinPoints? (decl : Decl) : CompilerM (Option Decl) := do
Find all `fun` declarations in `decl` that qualify as join points then replace
their definitions and call sites with `jp`/`jmp`.
-/
def Decl.findJoinPoints (decl : Decl) : CompilerM Decl := do
def Decl.findJoinPoints (decl : Decl .pure) : CompilerM (Decl .pure) := do
return ( Decl.findJoinPoints? decl).getD decl
def findJoinPoints (occurrence : Nat := 0) : Pass :=
.mkPerDeclaration `findJoinPoints Decl.findJoinPoints .base (occurrence := occurrence)
.mkPerDeclaration `findJoinPoints .base Decl.findJoinPoints (occurrence := occurrence)
builtin_initialize
registerTraceClass `Compiler.findJoinPoints (inherited := true)
def Decl.extendJoinPointContext (decl : Decl) : CompilerM Decl := do
def Decl.extendJoinPointContext (decl : Decl .pure) : CompilerM (Decl .pure) := do
JoinPointContextExtender.extend decl
-- TODO: It might make sense to extend this to impure one day
def extendJoinPointContext (occurrence : Nat := 0) (phase := Phase.mono) (_h : phase .base := by simp): Pass :=
.mkPerDeclaration `extendJoinPointContext Decl.extendJoinPointContext phase (occurrence := occurrence)
phase.withPurityCheck .pure fun h =>
.mkPerDeclaration `extendJoinPointContext phase (h Decl.extendJoinPointContext) (occurrence := occurrence)
builtin_initialize
registerTraceClass `Compiler.extendJoinPointContext (inherited := true)
def Decl.commonJoinPointArgs (decl : Decl) : CompilerM Decl := do
def Decl.commonJoinPointArgs (decl : Decl .pure) : CompilerM (Decl .pure) := do
JoinPointCommonArgs.reduce decl
-- TODO: It might make sense to extend this to impure one day
def commonJoinPointArgs : Pass :=
.mkPerDeclaration `commonJoinPointArgs Decl.commonJoinPointArgs .mono
.mkPerDeclaration `commonJoinPointArgs .mono Decl.commonJoinPointArgs
builtin_initialize
registerTraceClass `Compiler.commonJoinPointArgs (inherited := true)

View File

@@ -16,61 +16,97 @@ namespace Lean.Compiler.LCNF
LCNF local context.
-/
structure LCtx where
params : Std.HashMap FVarId Param := {}
letDecls : Std.HashMap FVarId LetDecl := {}
funDecls : Std.HashMap FVarId FunDecl := {}
paramsPure : Std.HashMap FVarId (Param .pure) := {}
paramsImpure : Std.HashMap FVarId (Param .impure) := {}
letDeclsPure : Std.HashMap FVarId (LetDecl .pure) := {}
letDeclsImpure : Std.HashMap FVarId (LetDecl .impure) := {}
funDeclsPure : Std.HashMap FVarId (FunDecl .pure) := {}
funDeclsImpure : Std.HashMap FVarId (FunDecl .impure) := {}
deriving Inhabited
def LCtx.addParam (lctx : LCtx) (param : Param) : LCtx :=
{ lctx with params := lctx.params.insert param.fvarId param }
def LCtx.addParam (lctx : LCtx) (param : Param pu) : LCtx :=
match pu with
| .pure => { lctx with paramsPure := lctx.paramsPure.insert param.fvarId param }
| .impure => { lctx with paramsImpure := lctx.paramsImpure.insert param.fvarId param }
def LCtx.addLetDecl (lctx : LCtx) (letDecl : LetDecl) : LCtx :=
{ lctx with letDecls := lctx.letDecls.insert letDecl.fvarId letDecl }
def LCtx.addLetDecl (lctx : LCtx) (letDecl : LetDecl pu) : LCtx :=
match pu with
| .pure => { lctx with letDeclsPure := lctx.letDeclsPure.insert letDecl.fvarId letDecl }
| .impure => { lctx with letDeclsImpure := lctx.letDeclsImpure.insert letDecl.fvarId letDecl }
def LCtx.addFunDecl (lctx : LCtx) (funDecl : FunDecl) : LCtx :=
{ lctx with funDecls := lctx.funDecls.insert funDecl.fvarId funDecl }
def LCtx.addFunDecl (lctx : LCtx) (funDecl : FunDecl pu) : LCtx :=
match pu with
| .pure => { lctx with funDeclsPure := lctx.funDeclsPure.insert funDecl.fvarId funDecl }
| .impure => { lctx with funDeclsImpure := lctx.funDeclsImpure.insert funDecl.fvarId funDecl }
def LCtx.eraseParam (lctx : LCtx) (param : Param) : LCtx :=
{ lctx with params := lctx.params.erase param.fvarId }
def LCtx.eraseParam (lctx : LCtx) (param : Param pu) : LCtx :=
match pu with
| .pure => { lctx with paramsPure := lctx.paramsPure.erase param.fvarId }
| .impure => { lctx with paramsImpure := lctx.paramsImpure.erase param.fvarId }
def LCtx.eraseParams (lctx : LCtx) (ps : Array Param) : LCtx :=
{ lctx with params := ps.foldl (init := lctx.params) fun params p => params.erase p.fvarId }
def LCtx.eraseParams (lctx : LCtx) (ps : Array (Param pu)) : LCtx :=
match pu with
| .pure => { lctx with paramsPure := ps.foldl (init := lctx.paramsPure) fun params p => params.erase p.fvarId }
| .impure => { lctx with paramsImpure := ps.foldl (init := lctx.paramsImpure) fun params p => params.erase p.fvarId }
def LCtx.eraseLetDecl (lctx : LCtx) (decl : LetDecl) : LCtx :=
{ lctx with letDecls := lctx.letDecls.erase decl.fvarId }
def LCtx.eraseLetDecl (lctx : LCtx) (decl : LetDecl pu) : LCtx :=
match pu with
| .pure => { lctx with letDeclsPure := lctx.letDeclsPure.erase decl.fvarId }
| .impure => { lctx with letDeclsImpure := lctx.letDeclsImpure.erase decl.fvarId }
mutual
partial def LCtx.eraseFunDecl (lctx : LCtx) (decl : FunDecl) (recursive := true) : LCtx :=
let lctx := { lctx with funDecls := lctx.funDecls.erase decl.fvarId }
partial def LCtx.eraseFunDecl (lctx : LCtx) (decl : FunDecl pu) (recursive := true) : LCtx :=
let lctx :=
match pu with
| .pure => { lctx with funDeclsPure := lctx.funDeclsPure.erase decl.fvarId }
| .impure => { lctx with funDeclsImpure := lctx.funDeclsImpure.erase decl.fvarId }
if recursive then
eraseCode decl.value <| eraseParams lctx decl.params
else
lctx
partial def LCtx.eraseAlts (alts : Array Alt) (lctx : LCtx) : LCtx :=
partial def LCtx.eraseAlts (alts : Array (Alt pu)) (lctx : LCtx) : LCtx :=
alts.foldl (init := lctx) fun lctx alt =>
match alt with
| .default k => eraseCode k lctx
| .alt _ ps k => eraseCode k <| eraseParams lctx ps
| .alt _ ps k _ => eraseCode k <| eraseParams lctx ps
partial def LCtx.eraseCode (code : Code) (lctx : LCtx) : LCtx :=
partial def LCtx.eraseCode (code : Code pu) (lctx : LCtx) : LCtx :=
match code with
| .let decl k => eraseCode k <| lctx.eraseLetDecl decl
| .jp decl k | .fun decl k => eraseCode k <| eraseFunDecl lctx decl
| .jp decl k | .fun decl k _ => eraseCode k <| eraseFunDecl lctx decl
| .cases c => eraseAlts c.alts lctx
| _ => lctx
end
@[inline]
def LCtx.params (lctx : LCtx) (pu : Purity) : Std.HashMap FVarId (Param pu) :=
match pu with
| .pure => lctx.paramsPure
| .impure => lctx.paramsImpure
@[inline]
def LCtx.letDecls (lctx : LCtx) (pu : Purity) : Std.HashMap FVarId (LetDecl pu) :=
match pu with
| .pure => lctx.letDeclsPure
| .impure => lctx.letDeclsImpure
@[inline]
def LCtx.funDecls (lctx : LCtx) (pu : Purity) : Std.HashMap FVarId (FunDecl pu) :=
match pu with
| .pure => lctx.funDeclsPure
| .impure => lctx.funDeclsImpure
/--
Convert a LCNF local context into a regular Lean local context.
-/
def LCtx.toLocalContext (lctx : LCtx) : LocalContext := Id.run do
def LCtx.toLocalContext (lctx : LCtx) (pu : Purity) : LocalContext := Id.run do
let mut result := {}
for (_, param) in lctx.params.toArray do
for (_, param) in lctx.params pu do
result := result.addDecl (.cdecl 0 param.fvarId param.binderName param.type .default .default)
for (_, decl) in lctx.letDecls.toArray do
for (_, decl) in lctx.letDecls pu do
result := result.addDecl (.ldecl 0 decl.fvarId decl.binderName decl.type decl.value.toExpr true .default)
for (_, decl) in lctx.funDecls.toArray do
for (_, decl) in lctx.funDecls pu do
result := result.addDecl (.cdecl 0 decl.fvarId decl.binderName decl.type .default .default)
return result

View File

@@ -29,7 +29,7 @@ structure Context where
Declaration where lambda lifting is being applied.
We use it to provide the "base name" for auxiliary declarations and the flag `safe`.
-/
mainDecl : Decl
mainDecl : Decl .pure
/--
If true, the lambda-lifted functions inherit the inline attribute from `mainDecl`.
We use this feature to implement `@[inline] instance ...` and `@[always_inline] instance ...`
@@ -51,7 +51,7 @@ structure State where
/--
New auxiliary declarations
-/
decls : Array Decl := #[]
decls : Array (Decl .pure) := #[]
/--
Next index for generating auxiliary declaration name.
-/
@@ -64,13 +64,13 @@ abbrev LiftM := ReaderT Context (StateRefT State (ScopeT CompilerM))
Return `true` if the given declaration takes a local instance as a parameter.
We lambda lift this kind of local function declaration before specialization.
-/
def hasInstParam (decl : FunDecl) : CompilerM Bool :=
def hasInstParam (decl : FunDecl .pure) : CompilerM Bool :=
decl.params.anyM fun param => return ( isArrowClass? param.type).isSome
/--
Return `true` if the given declaration should be lambda lifted.
-/
def shouldLift (decl : FunDecl) : LiftM Bool := do
def shouldLift (decl : FunDecl .pure) : LiftM Bool := do
let minSize := ( read).minSize
if decl.value.size < minSize then
return false
@@ -85,7 +85,7 @@ partial def mkAuxDeclName : LiftM Name := do
if ( getDecl? nameNew).isNone then return nameNew
mkAuxDeclName
def replaceFunDecl (decl : FunDecl) (value : LetValue) : LiftM LetDecl := do
def replaceFunDecl (decl : FunDecl .pure) (value : LetValue .pure) : LiftM (LetDecl .pure) := do
/- We reuse `decl`s `fvarId` to avoid substitution -/
let declNew := { fvarId := decl.fvarId, binderName := decl.binderName, type := decl.type, value }
modifyLCtx fun lctx => lctx.addLetDecl declNew
@@ -97,7 +97,7 @@ open Internalize in
Create a new auxiliary declaration. The array `closure` contains all free variables
occurring in `decl`.
-/
def mkAuxDecl (closure : Array Param) (decl : FunDecl) : LiftM LetDecl := do
def mkAuxDecl (closure : Array (Param .pure)) (decl : FunDecl .pure) : LiftM (LetDecl .pure) := do
let nameNew mkAuxDeclName
let inlineAttr? if ( read).inheritInlineAttrs then pure ( read).mainDecl.inlineAttr? else pure none
let auxDecl go nameNew ( read).mainDecl.safe inlineAttr? |>.run' {}
@@ -113,16 +113,16 @@ def mkAuxDecl (closure : Array Param) (decl : FunDecl) : LiftM LetDecl := do
let value := .const auxDeclName us (closure.map (.fvar ·.fvarId))
replaceFunDecl decl value
where
go (nameNew : Name) (safe : Bool) (inlineAttr? : Option InlineAttributeKind) : InternalizeM Decl := do
go (nameNew : Name) (safe : Bool) (inlineAttr? : Option InlineAttributeKind) : InternalizeM .pure (Decl .pure):= do
let params := ( closure.mapM internalizeParam) ++ ( decl.params.mapM internalizeParam)
let code internalizeCode decl.value
let type code.inferType
let type mkForallParams params type
let value := .code code
let decl := { name := nameNew, levelParams := [], params, type, value, safe, inlineAttr?, recursive := false : Decl }
let decl := { name := nameNew, levelParams := [], params, type, value, safe, inlineAttr?, recursive := false : Decl .pure }
return decl.setLevelParams
def etaContractibleDecl? (decl : FunDecl) : LiftM (Option LetDecl) := do
def etaContractibleDecl? (decl : FunDecl .pure) : LiftM (Option (LetDecl .pure)) := do
if !( read).allowEtaContraction then return none
let .let { fvarId := letVar, value := .const declName us args, .. } (.return retVar) := decl.value
| return none
@@ -137,11 +137,11 @@ def etaContractibleDecl? (decl : FunDecl) : LiftM (Option LetDecl) := do
replaceFunDecl decl value
mutual
partial def visitFunDecl (funDecl : FunDecl) : LiftM FunDecl := do
partial def visitFunDecl (funDecl : FunDecl .pure) : LiftM (FunDecl .pure) := do
let value withParams funDecl.params <| visitCode funDecl.value
funDecl.update' funDecl.type value
partial def visitCode (code : Code) : LiftM Code := do
partial def visitCode (code : Code .pure) : LiftM (Code .pure) := do
match code with
| .let decl k =>
let k withFVar decl.fvarId <| visitCode k
@@ -174,14 +174,14 @@ mutual
| .unreach .. | .jmp .. | .return .. => return code
end
def main (decl : Decl) : LiftM Decl := do
def main (decl : Decl .pure) : LiftM (Decl .pure) := do
let value withParams decl.params <| decl.value.mapCodeM visitCode
return { decl with value }
end LambdaLifting
partial def Decl.lambdaLifting (decl : Decl) (liftInstParamOnly : Bool) (allowEtaContraction : Bool)
(suffix : Name) (inheritInlineAttrs := false) (minSize := 0) : CompilerM (Array Decl) := do
partial def Decl.lambdaLifting (decl : Decl .pure) (liftInstParamOnly : Bool) (allowEtaContraction : Bool)
(suffix : Name) (inheritInlineAttrs := false) (minSize := 0) : CompilerM (Array (Decl .pure)) := do
let ctx := {
mainDecl := decl,
liftInstParamOnly,
@@ -214,7 +214,7 @@ def eagerLambdaLifting : Pass where
name := `eagerLambdaLifting
run := fun decls => do
decls.foldlM (init := #[]) fun decls decl => do
if decl.inlineable || ( Meta.isInstance decl.name) then
if decl.inlineable || ( isInstanceReducible decl.name) then
return decls.push decl
else
return decls ++ ( decl.lambdaLifting (liftInstParamOnly := true) (allowEtaContraction := false) (suffix := `_elam))

View File

@@ -105,45 +105,45 @@ open Lean.CollectLevelParams
abbrev visitType (type : Expr) : Visitor :=
visitExpr type
def visitArg (arg : Arg) : Visitor :=
def visitArg (arg : Arg .pure) : Visitor :=
match arg with
| .erased | .fvar .. => id
| .type e => visitType e
| .type e _ => visitType e
def visitArgs (args : Array Arg) : Visitor :=
def visitArgs (args : Array (Arg .pure)) : Visitor :=
fun s => args.foldl (init := s) fun s arg => visitArg arg s
def visitLetValue (e : LetValue) : Visitor :=
def visitLetValue (e : LetValue .pure) : Visitor :=
match e with
| .erased | .lit .. | .proj .. => id
| .const _ us args => visitLevels us visitArgs args
| .const _ us args _ => visitLevels us visitArgs args
| .fvar _ args => visitArgs args
def visitParam (p : Param) : Visitor :=
def visitParam (p : Param .pure) : Visitor :=
visitType p.type
def visitParams (ps : Array Param) : Visitor :=
def visitParams (ps : Array (Param .pure)) : Visitor :=
fun s => ps.foldl (init := s) fun s p => visitParam p s
mutual
partial def visitAlt (alt : Alt) : Visitor :=
partial def visitAlt (alt : Alt .pure) : Visitor :=
match alt with
| .default k => visitCode k
| .alt _ ps k => visitCode k visitParams ps
| .alt _ ps k _ => visitCode k visitParams ps
partial def visitAlts (alts : Array Alt) : Visitor :=
partial def visitAlts (alts : Array (Alt .pure)) : Visitor :=
fun s => alts.foldl (init := s) fun s alt => visitAlt alt s
partial def visitCode : Code Visitor
partial def visitCode : Code .pure Visitor
| .let decl k => visitCode k visitLetValue decl.value visitType decl.type
| .fun decl k | .jp decl k => visitCode k visitCode decl.value visitParams decl.params visitType decl.type
| .fun decl k _ | .jp decl k => visitCode k visitCode decl.value visitParams decl.params visitType decl.type
| .cases c => visitAlts c.alts visitType c.resultType
| .unreach type => visitType type
| .return _ => id
| .jmp _ args => visitArgs args
end
def visitDeclValue : DeclValue Visitor
def visitDeclValue : DeclValue .pure Visitor
| .code c => visitCode c
| .extern .. => id
@@ -156,7 +156,7 @@ open CollectLevelParams
Collect universe level parameters collecting in the type, parameters, and value, and then
set `decl.levelParams` with the resulting value.
-/
def Decl.setLevelParams (decl : Decl) : Decl :=
def Decl.setLevelParams (decl : Decl .pure) : Decl .pure :=
let levelParams := (visitDeclValue decl.value visitParams decl.params visitType decl.type) {} |>.params.toList
{ decl with levelParams }

View File

@@ -14,6 +14,7 @@ import Lean.Meta.Match.MatcherInfo
import Lean.Compiler.LCNF.SplitSCC
public import Lean.Compiler.IR.Basic
public import Lean.Compiler.LCNF.CompilerM
public section
namespace Lean.Compiler.LCNF
/--
@@ -50,7 +51,7 @@ A checkpoint in code generation to print all declarations in between
compiler passes in order to ease debugging.
The trace can be viewed with `set_option trace.Compiler.step true`.
-/
def checkpoint (stepName : Name) (decls : Array Decl) (shouldCheck : Bool) : CompilerM Unit := do
def checkpoint (stepName : Name) (decls : Array (Decl pu)) (shouldCheck : Bool) : CompilerM Unit := do
for decl in decls do
trace[Compiler.stat] "{decl.name} : {decl.size}"
withOptions (fun opts => opts.set `pp.motives.pi false) do
@@ -101,12 +102,12 @@ def run (declNames : Array Name) : CompilerM (Array (Array IR.Decl)) := withAtLe
let decls := markRecDecls decls
let manager getPassManager
let isCheckEnabled := compiler.check.get ( getOptions)
let decls runPassManagerPart "compilation (LCNF base)" manager.basePasses decls isCheckEnabled
let decls runPassManagerPart "compilation (LCNF mono)" manager.monoPasses decls isCheckEnabled
let decls runPassManagerPart .pure .pure "compilation (LCNF base)" manager.basePasses decls isCheckEnabled
let decls runPassManagerPart .pure .pure "compilation (LCNF mono)" manager.monoPasses decls isCheckEnabled
let sccs withTraceNode `Compiler.splitSCC (fun _ => return m!"Splitting up SCC") do
splitScc decls
sccs.mapM fun decls => do
let decls runPassManagerPart "compilation (LCNF mono)" manager.monoPassesNoLambda decls isCheckEnabled
let decls runPassManagerPart .pure .pure "compilation (LCNF mono)" manager.monoPassesNoLambda decls isCheckEnabled
if ( Lean.isTracingEnabledFor `Compiler.result) then
for decl in decls do
let decl normalizeFVarIds decl
@@ -115,14 +116,19 @@ def run (declNames : Array Name) : CompilerM (Array (Array IR.Decl)) := withAtLe
let irDecls IR.toIR decls
IR.compile irDecls
where
runPassManagerPart (profilerName : String) (passes : Array Pass) (decls : Array Decl)
(isCheckEnabled : Bool) : CompilerM (Array Decl) := do
runPassManagerPart (inPhase outPhase : Purity) (profilerName : String)
(passes : Array Pass) (decls : Array (Decl inPhase)) (isCheckEnabled : Bool) :
CompilerM (Array (Decl outPhase)) := do
profileitM Exception profilerName ( getOptions) do
let mut decls := decls
let mut state : (pu : Purity) × Array (Decl pu) := inPhase, decls
for pass in passes do
decls withTraceNode `Compiler (fun _ => return m!"compiler phase: {pass.phase}, pass: {pass.name}") do
withPhase pass.phase <| pass.run decls
withPhase pass.phaseOut <| checkpoint pass.name decls (isCheckEnabled || pass.shouldAlwaysRunCheck)
state withTraceNode `Compiler (fun _ => return m!"compiler phase: {pass.phase}, pass: {pass.name}") do
let decls withPhase pass.phase do
state.fst.withAssertPurity pass.phase.toPurity fun h => do
pass.run (h state.snd)
pure _, decls
withPhase pass.phaseOut <| checkpoint pass.name state.snd (isCheckEnabled || pass.shouldAlwaysRunCheck)
let decls := state.fst.withAssertPurity outPhase fun h => h state.snd
return decls
end PassManager

View File

@@ -33,7 +33,7 @@ instance (m n) [MonadLift m n] [MonadFunctor m n] [MonadScope m] : MonadScope n
def inScope [MonadScope m] [Monad m] (fvarId : FVarId) : m Bool :=
return ( getScope).contains fvarId
@[inline] def withParams [MonadScope m] [Monad m] (ps : Array Param) (x : m α) : m α :=
@[inline] def withParams [MonadScope m] [Monad m] (ps : Array (Param pu)) (x : m α) : m α :=
withScope (fun s => ps.foldl (init := s) fun s p => s.insert p.fvarId) x
@[inline] def withFVar [MonadScope m] [Monad m] (fvarId : FVarId) (x : m α) : m α :=

View File

@@ -99,4 +99,7 @@ def getOtherDeclMonoType (declName : Name) : CoreM Expr := do
monoTypeExt.insert declName type
return type
def getOtherDeclImpureType (_declName : Name) : CoreM Expr := do
panic! "Other decl impure type unimplemented" -- TODO
end Lean.Compiler.LCNF

View File

@@ -19,5 +19,6 @@ def getOtherDeclType (declName : Name) (us : List Level := []) : CompilerM Expr
match ( getPhase) with
| .base => getOtherDeclBaseType declName us
| .mono => getOtherDeclMonoType declName
| .impure => getOtherDeclImpureType declName
end Lean.Compiler.LCNF

View File

@@ -15,6 +15,20 @@ namespace Lean.Compiler.LCNF
@[expose] def Phase.toNat : Phase Nat
| .base => 0
| .mono => 1
| .impure => 2
instance : ToString Phase where
toString
| .base => "base"
| .mono => "mono"
| .impure => "impure"
def Phase.withPurityCheck [Inhabited α] (pp : Phase) (ip : Purity)
(x : pp.toPurity = ip α) : α :=
if h : pp.toPurity = ip then
x h
else
panic! s!"Compiler error: {pp} is not equivalent to IR phase {ip}, this is a bug"
instance : LT Phase where
lt l r := l.toNat < r.toNat
@@ -60,7 +74,7 @@ structure Pass where
/--
The actual pass function, operating on the `Decl`s.
-/
run : Array Decl CompilerM (Array Decl)
run : Array (Decl phase.toPurity) CompilerM (Array (Decl phase.toPurity))
instance : Inhabited Pass where
default := { phase := .base, name := default, run := fun decls => return decls }
@@ -90,14 +104,10 @@ structure PassManager where
monoPassesNoLambda : Array Pass
deriving Inhabited
instance : ToString Phase where
toString
| .base => "base"
| .mono => "mono"
namespace Pass
def mkPerDeclaration (name : Name) (run : Decl CompilerM Decl) (phase : Phase) (occurrence : Nat := 0) : Pass where
def mkPerDeclaration (name : Name) (phase : Phase)
(run : Decl phase.toPurity CompilerM (Decl phase.toPurity)) (occurrence : Nat := 0) : Pass where
occurrence := occurrence
phase := phase
name := name
@@ -190,6 +200,7 @@ def run (manager : PassManager) (installer : PassInstaller) : CoreM PassManager
return { manager with basePasses := ( installer.install manager.basePasses) }
| .mono =>
return { manager with monoPasses := ( installer.install manager.monoPasses) }
| .impure => panic! "Pass manager support for impure unimplemented" -- TODO
private unsafe def getPassInstallerUnsafe (declName : Name) : CoreM PassInstaller := do
ofExcept <| ( getEnv).evalConstCheck PassInstaller ( getOptions) ``PassInstaller declName

View File

@@ -45,10 +45,15 @@ private builtin_initialize baseTransparentDeclsExt : EnvExtension (List Name ×
Set of public declarations whose mono bodies should be exported to other modules
-/
private builtin_initialize monoTransparentDeclsExt : EnvExtension (List Name × NameSet) mkDeclSetExt
/--
Set of public declarations whose impure bodies should be exported to other modules
-/
private builtin_initialize impureTransparentDeclsExt : EnvExtension (List Name × NameSet) mkDeclSetExt
private def getTransparencyExt : Phase EnvExtension (List Name × NameSet)
| .base => baseTransparentDeclsExt
| .mono => monoTransparentDeclsExt
| .impure => impureTransparentDeclsExt
def isDeclPublic (env : Environment) (declName : Name) : Bool := Id.run do
if !env.header.isModule then
@@ -81,26 +86,28 @@ def setDeclTransparent (env : Environment) (phase : Phase) (declName : Name) : E
getTransparencyExt phase |>.modifyState env fun s =>
(declName :: s.1, s.2.insert declName)
abbrev DeclExtState := PHashMap Name Decl
abbrev DeclExtState (pu : Purity) := PHashMap Name (Decl pu)
private abbrev declLt (a b : Decl) :=
private abbrev declLt (a b : Decl pu) :=
Name.quickLt a.name b.name
private def sortedDecls (s : DeclExtState) : Array Decl :=
private def sortedDecls (s : DeclExtState pu) : Array (Decl pu) :=
let decls := s.foldl (init := #[]) fun ps _ v => ps.push v
decls.qsort declLt
private abbrev findAtSorted? (decls : Array Decl) (declName : Name) : Option Decl :=
let tmpDecl : Decl := default
private abbrev findAtSorted? (decls : Array (Decl pu)) (declName : Name) : Option (Decl pu) :=
let tmpDecl : Decl pu := default
let tmpDecl := { tmpDecl with name := declName }
decls.binSearch tmpDecl declLt
@[expose] def DeclExt := PersistentEnvExtension Decl Decl DeclExtState
@[expose] def DeclExt (pu : Purity) :=
PersistentEnvExtension (Decl pu) (Decl pu) (DeclExtState pu)
instance : Inhabited DeclExt :=
inferInstanceAs (Inhabited (PersistentEnvExtension Decl Decl DeclExtState))
instance : Inhabited (DeclExt pu) :=
inferInstanceAs (Inhabited (PersistentEnvExtension (Decl pu) (Decl pu) (DeclExtState pu)))
def mkDeclExt (phase : Phase) (name : Name := by exact decl_name%) : IO DeclExt :=
def mkDeclExt (phase : Phase) (name : Name := by exact decl_name%) :
IO (DeclExt phase.toPurity) :=
registerPersistentEnvExtension {
name,
mkInitial := pure {},
@@ -128,74 +135,77 @@ def mkDeclExt (phase : Phase) (name : Name := by exact decl_name%) : IO DeclExt
otherState.insert k v
}
builtin_initialize baseExt : DeclExt mkDeclExt .base
builtin_initialize monoExt : DeclExt mkDeclExt .mono
builtin_initialize baseExt : DeclExt .pure mkDeclExt .base
builtin_initialize monoExt : DeclExt .pure mkDeclExt .mono
builtin_initialize impureExt : DeclExt .impure mkDeclExt .impure
def getDeclCore? (env : Environment) (ext : DeclExt) (declName : Name) : Option Decl :=
def getDeclCore? (env : Environment) (ext : DeclExt pu) (declName : Name) : Option (Decl pu) :=
match env.getModuleIdxFor? declName with
| some modIdx => findAtSorted? (ext.getModuleEntries env modIdx) declName
| none => ext.getState env |>.find? declName
def getBaseDecl? (declName : Name) : CoreM (Option Decl) := do
def getBaseDecl? (declName : Name) : CoreM (Option (Decl .pure)) := do
return getDeclCore? ( getEnv) baseExt declName
def getMonoDecl? (declName : Name) : CoreM (Option Decl) := do
def getMonoDecl? (declName : Name) : CoreM (Option (Decl .pure)) := do
return getDeclCore? ( getEnv) monoExt declName
def saveBaseDeclCore (env : Environment) (decl : Decl) : Environment :=
def getImpureDecl? (declName : Name) : CoreM (Option (Decl .impure)) := do
return getDeclCore? ( getEnv) impureExt declName
def saveBaseDeclCore (env : Environment) (decl : Decl .pure) : Environment :=
baseExt.addEntry env decl
def saveMonoDeclCore (env : Environment) (decl : Decl) : Environment :=
def saveMonoDeclCore (env : Environment) (decl : Decl .pure) : Environment :=
monoExt.addEntry env decl
def Decl.saveBase (decl : Decl) : CoreM Unit :=
def saveImpureDeclCore (env : Environment) (decl : Decl .impure) : Environment :=
impureExt.addEntry env decl
def Decl.saveBase (decl : Decl .pure) : CoreM Unit :=
modifyEnv (saveBaseDeclCore · decl)
def Decl.saveMono (decl : Decl) : CoreM Unit :=
def Decl.saveMono (decl : Decl .pure) : CoreM Unit :=
modifyEnv (saveMonoDeclCore · decl)
def Decl.save (decl : Decl) : CompilerM Unit := do
match ( getPhase) with
| .base => decl.saveBase
| .mono => decl.saveMono
def Decl.saveImpure (decl : Decl .impure) : CoreM Unit :=
modifyEnv (saveImpureDeclCore · decl)
def getDeclAt? (declName : Name) (phase : Phase) : CoreM (Option Decl) :=
def Decl.save (decl : Decl pu) : CompilerM Unit := do
match ( getPhase) with
| .base => Phase.withPurityCheck .base pu fun h =>
(h.symm decl).saveBase
| .mono => Phase.withPurityCheck .mono pu fun h =>
(h.symm decl).saveMono
| .impure => Phase.withPurityCheck .impure pu fun h =>
(h.symm decl).saveImpure
def getDeclAt? (declName : Name) (phase : Phase) : CoreM (Option (Decl phase.toPurity)) :=
match phase with
| .base => getBaseDecl? declName
| .mono => getMonoDecl? declName
| .impure => getImpureDecl? declName
def getDecl? (declName : Name) : CompilerM (Option Decl) := do
getDeclAt? declName ( getPhase)
@[inline]
def getDecl? (declName : Name) : CompilerM (Option ((pu : Purity) × Decl pu)) := do
let some decl getDeclAt? declName ( getPhase) | return none
return some _, decl
def getLocalDeclAt? (declName : Name) (phase : Phase) : CompilerM (Option Decl) := do
def getLocalDeclAt? (declName : Name) (phase : Phase) : CompilerM (Option (Decl phase.toPurity)) := do
match phase with
| .base => return baseExt.getState ( getEnv) |>.find? declName
| .mono => return monoExt.getState ( getEnv) |>.find? declName
| .impure => return impureExt.getState ( getEnv) |>.find? declName
def getLocalDecl? (declName : Name) : CompilerM (Option Decl) := do
getLocalDeclAt? declName ( getPhase)
@[inline]
def getLocalDecl? (declName : Name) : CompilerM (Option ((pu : Purity) × Decl pu)) := do
let some decl getLocalDeclAt? declName ( getPhase) | return none
return some _, decl
def getExt (phase : Phase) : DeclExt :=
def getExt (phase : Phase) : DeclExt phase.toPurity :=
match phase with
| .base => baseExt
| .mono => monoExt
def forEachDecl (f : Decl CoreM Unit) (phase := Phase.base) : CoreM Unit := do
let ext := getExt phase
let env getEnv
for modIdx in *...env.allImportedModuleNames.size do
for decl in ext.getModuleEntries env modIdx do
f decl
ext.getState env |>.forM fun _ decl => f decl
def forEachModuleDecl (moduleName : Name) (f : Decl CoreM Unit) (phase := Phase.base) : CoreM Unit := do
let ext := getExt phase
let env getEnv
let some modIdx := env.getModuleIdx? moduleName | throwError "module `{moduleName}` not found"
for decl in ext.getModuleEntries env modIdx do
f decl
def forEachMainModuleDecl (f : Decl CoreM Unit) (phase := Phase.base) : CoreM Unit := do
(getExt phase).getState ( getEnv) |>.forM fun _ decl => f decl
| .impure => impureExt
end Lean.Compiler.LCNF

View File

@@ -43,11 +43,11 @@ def ppFVar (fvarId : FVarId) : M Format :=
def ppExpr (e : Expr) : M Format := do
Meta.ppExpr e |>.run' { lctx := ( read) }
def ppArg (e : Arg) : M Format := do
def ppArg (e : Arg pu) : M Format := do
match e with
| .erased => return ""
| .fvar fvarId => ppFVar fvarId
| .type e =>
| .type e _ =>
if pp.explicit.get ( getOptions) then
if e.isConst || e.isProp || e.isType0 || e.isFVar then
ppExpr e
@@ -56,7 +56,7 @@ def ppArg (e : Arg) : M Format := do
else
return "_"
def ppArgs (args : Array Arg) : M Format := do
def ppArgs (args : Array (Arg pu)) : M Format := do
prefixJoin " " args ppArg
def ppLitValue (lit : LitValue) : M Format := do
@@ -64,49 +64,49 @@ def ppLitValue (lit : LitValue) : M Format := do
| .nat v | .uint8 v | .uint16 v | .uint32 v | .uint64 v | .usize v => return format v
| .str v => return format (repr v)
def ppLetValue (e : LetValue) : M Format := do
def ppLetValue (e : LetValue pu) : M Format := do
match e with
| .erased => return ""
| .lit v => ppLitValue v
| .proj _ i fvarId => return f!"{← ppFVar fvarId} # {i}"
| .proj _ i fvarId _ => return f!"{← ppFVar fvarId} # {i}"
| .fvar fvarId args => return f!"{← ppFVar fvarId}{← ppArgs args}"
| .const declName us args => return f!"{← ppExpr (.const declName us)}{← ppArgs args}"
| .const declName us args _ => return f!"{← ppExpr (.const declName us)}{← ppArgs args}"
def ppParam (param : Param) : M Format := do
def ppParam (param : Param pu) : M Format := do
let borrow := if param.borrow then "@&" else ""
if pp.funBinderTypes.get ( getOptions) then
return Format.paren f!"{param.binderName} : {borrow}{← ppExpr param.type}"
else
return format s!"{borrow}{param.binderName}"
def ppParams (params : Array Param) : M Format := do
def ppParams (params : Array (Param pu)) : M Format := do
prefixJoin " " params ppParam
def ppLetDecl (letDecl : LetDecl) : M Format := do
def ppLetDecl (letDecl : LetDecl pu) : M Format := do
if pp.letVarTypes.get ( getOptions) then
return f!"let {letDecl.binderName} : {← ppExpr letDecl.type} := {← ppLetValue letDecl.value}"
else
return f!"let {letDecl.binderName} := {← ppLetValue letDecl.value}"
def getFunType (ps : Array Param) (type : Expr) : CoreM Expr :=
def getFunType (ps : Array (Param pu)) (type : Expr) : CoreM Expr :=
if type.isErased then
pure type
else
instantiateForall type (ps.map (mkFVar ·.fvarId))
mutual
partial def ppFunDecl (funDecl : FunDecl) : M Format := do
partial def ppFunDecl (funDecl : FunDecl pu) : M Format := do
return f!"{funDecl.binderName}{← ppParams funDecl.params} : {← ppExpr (← getFunType funDecl.params funDecl.type)} :={indentD (← ppCode funDecl.value)}"
partial def ppAlt (alt : Alt) : M Format := do
partial def ppAlt (alt : Alt pu) : M Format := do
match alt with
| .default k => return f!"| _ =>{indentD (← ppCode k)}"
| .alt ctorName params k => return f!"| {ctorName}{← ppParams params} =>{indentD (← ppCode k)}"
| .alt ctorName params k _ => return f!"| {ctorName}{← ppParams params} =>{indentD (← ppCode k)}"
partial def ppCode (c : Code) : M Format := do
partial def ppCode (c : Code pu) : M Format := do
match c with
| .let decl k => return ( ppLetDecl decl) ++ ";" ++ .line ++ ( ppCode k)
| .fun decl k => return f!"fun " ++ ( ppFunDecl decl) ++ ";" ++ .line ++ ( ppCode k)
| .fun decl k _ => return f!"fun " ++ ( ppFunDecl decl) ++ ";" ++ .line ++ ( ppCode k)
| .jp decl k => return f!"jp " ++ ( ppFunDecl decl) ++ ";" ++ .line ++ ( ppCode k)
| .cases c => return f!"cases {← ppFVar c.discr} : {← ppExpr c.resultType}{← prefixJoin .line c.alts ppAlt}"
| .return fvarId => return f!"return {← ppFVar fvarId}"
@@ -117,7 +117,7 @@ mutual
else
return ""
partial def ppDeclValue (b : DeclValue) : M Format := do
partial def ppDeclValue (b : DeclValue pu) : M Format := do
match b with
| .code c => ppCode c
| .extern .. => return "extern"
@@ -125,21 +125,21 @@ end
def run (x : M α) : CompilerM α :=
withOptions (pp.sanitizeNames.set · false) do
x |>.run ( get).lctx.toLocalContext
x |>.run (( get).lctx.toLocalContext ( getPurity))
end PP
def ppCode (code : Code) : CompilerM Format :=
def ppCode (code : Code pu) : CompilerM Format :=
PP.run <| PP.ppCode code
def ppLetValue (e : LetValue) : CompilerM Format :=
def ppLetValue (e : LetValue pu) : CompilerM Format :=
PP.run <| PP.ppLetValue e
def ppDecl (decl : Decl) : CompilerM Format :=
def ppDecl (decl : Decl pu) : CompilerM Format :=
PP.run do
return f!"def {decl.name}{← PP.ppParams decl.params} : {← PP.ppExpr (← PP.getFunType decl.params decl.type)} :={indentD (← PP.ppDeclValue decl.value)}"
def ppFunDecl (decl : FunDecl) : CompilerM Format :=
def ppFunDecl (decl : FunDecl pu) : CompilerM Format :=
PP.run do
return f!"fun {← PP.ppFunDecl decl}"
@@ -159,7 +159,7 @@ Similar to `ppDecl`, but in `CoreM`, and it does not assume
`decl` has already been internalized.
This function is used for debugging purposes.
-/
def ppDecl' (decl : Decl) : CoreM Format := do
def ppDecl' (decl : Decl pu) : CoreM Format := do
runCompilerWithoutModifyingState do
ppDecl ( decl.internalize)
@@ -167,7 +167,7 @@ def ppDecl' (decl : Decl) : CoreM Format := do
Similar to `ppCode`, but in `CoreM`, and it does not assume
`code` has already been internalized.
-/
def ppCode' (code : Code) : CoreM Format := do
def ppCode' (code : Code pu) : CoreM Format := do
runCompilerWithoutModifyingState do
ppCode ( code.internalize)

View File

@@ -26,7 +26,7 @@ def filter (f : α → CompilerM Bool) : Probe α α := fun data => data.filterM
def sorted [Inhabited α] [LT α] [DecidableLT α] : Probe α α := fun data => return data.qsort (· < ·)
@[inline]
def sortedBySize : Probe Decl (Nat × Decl) := fun decls =>
def sortedBySize (pu : Purity) : Probe (Decl pu) (Nat × Decl pu) := fun decls =>
let decls := decls.map fun decl => (decl.size, decl)
return decls.qsort fun (sz₁, decl₁) (sz₂, decl₂) =>
if sz₁ == sz₂ then Name.lt decl₁.name decl₂.name else sz₁ < sz₂
@@ -44,116 +44,118 @@ def countUnique [ToString α] [BEq α] [Hashable α] : Probe α (α × Nat) := f
def countUniqueSorted [ToString α] [BEq α] [Hashable α] [Inhabited α] : Probe α (α × Nat) :=
countUnique >=> fun data => return data.qsort (fun l r => l.snd < r.snd)
partial def getLetValues : Probe Decl LetValue := fun decls => do
partial def getLetValues (pu : Purity) : Probe (Decl pu) (LetValue pu) := fun decls => do
let (_, res) start decls |>.run #[]
return res
where
go (c : Code) : StateRefT (Array LetValue) CompilerM Unit := do
go (c : Code pu) : StateRefT (Array (LetValue pu)) CompilerM Unit := do
match c with
| .let (decl : LetDecl) (k : Code) =>
| .let decl k =>
modify fun s => s.push decl.value
go k
| .fun decl k | .jp decl k =>
| .fun decl k _ | .jp decl k =>
go decl.value
go k
| .cases cs => cs.alts.forM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return ()
start (decls : Array Decl) : StateRefT (Array LetValue) CompilerM Unit :=
start (decls : Array (Decl pu)) : StateRefT (Array (LetValue pu)) CompilerM Unit :=
decls.forM (·.value.forCodeM go)
partial def getJps : Probe Decl FunDecl := fun decls => do
partial def getJps (pu : Purity) : Probe (Decl pu) (FunDecl pu) := fun decls => do
let (_, res) start decls |>.run #[]
return res
where
go (code : Code) : StateRefT (Array FunDecl) CompilerM Unit := do
go (code : Code pu) : StateRefT (Array (FunDecl pu)) CompilerM Unit := do
match code with
| .let _ k => go k
| .fun decl k => go decl.value; go k
| .fun decl k _ => go decl.value; go k
| .jp decl k => modify (·.push decl); go decl.value; go k
| .cases cs => cs.alts.forM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return ()
start (decls : Array Decl) : StateRefT (Array FunDecl) CompilerM Unit :=
start (decls : Array (Decl pu)) : StateRefT (Array (FunDecl pu)) CompilerM Unit :=
decls.forM (·.value.forCodeM go)
partial def filterByLet (f : LetDecl CompilerM Bool) : Probe Decl Decl :=
partial def filterByLet (pu : Purity) (f : LetDecl pu CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
filter (·.value.isCodeAndM go)
where
go : Code CompilerM Bool
go : Code pu CompilerM Bool
| .let decl k => do if ( f decl) then return true else go k
| .fun decl k | .jp decl k => go decl.value <||> go k
| .fun decl k _ | .jp decl k => go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
partial def filterByFun (f : FunDecl CompilerM Bool) : Probe Decl Decl :=
partial def filterByFun (pu : Purity) (f : FunDecl pu CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
filter (·.value.isCodeAndM go)
where
go : Code CompilerM Bool
go : Code pu CompilerM Bool
| .let _ k | .jp _ k => go k
| .fun decl k => do if ( f decl) then return true else go decl.value <||> go k
| .fun decl k _ => do if ( f decl) then return true else go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
partial def filterByJp (f : FunDecl CompilerM Bool) : Probe Decl Decl :=
partial def filterByJp (pu : Purity) (f : FunDecl pu CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
filter (·.value.isCodeAndM go)
where
go : Code CompilerM Bool
go : Code pu CompilerM Bool
| .let _ k => go k
| .fun decl k => go decl.value <||> go k
| .fun decl k _ => go decl.value <||> go k
| .jp decl k => do if ( f decl) then return true else go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
partial def filterByFunDecl (f : FunDecl CompilerM Bool) : Probe Decl Decl :=
partial def filterByFunDecl (pu : Purity) (f : FunDecl pu CompilerM Bool) :
Probe (Decl pu) (Decl pu):=
filter (·.value.isCodeAndM go)
where
go : Code CompilerM Bool
go : Code pu CompilerM Bool
| .let _ k => go k
| .fun decl k | .jp decl k => do if ( f decl) then return true else go decl.value <||> go k
| .fun decl k _ | .jp decl k => do if ( f decl) then return true else go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
partial def filterByCases (f : Cases CompilerM Bool) : Probe Decl Decl :=
partial def filterByCases (pu : Purity) (f : Cases pu CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
filter (·.value.isCodeAndM go)
where
go : Code CompilerM Bool
go : Code pu CompilerM Bool
| .let _ k => go k
| .fun decl k | .jp decl k => go decl.value <||> go k
| .fun decl k _ | .jp decl k => go decl.value <||> go k
| .cases cs => do if ( f cs) then return true else cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
partial def filterByJmp (f : FVarId Array Arg CompilerM Bool) : Probe Decl Decl :=
partial def filterByJmp (pu : Purity) (f : FVarId Array (Arg pu) CompilerM Bool) :
Probe (Decl pu) (Decl pu) :=
filter (·.value.isCodeAndM go)
where
go : Code CompilerM Bool
go : Code pu CompilerM Bool
| .let _ k => go k
| .fun decl k | .jp decl k => go decl.value <||> go k
| .fun decl k _ | .jp decl k => go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp fn var => f fn var
| .return .. | .unreach .. => return false
partial def filterByReturn (f : FVarId CompilerM Bool) : Probe Decl Decl :=
partial def filterByReturn (pu : Purity) (f : FVarId CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
filter (·.value.isCodeAndM go)
where
go : Code CompilerM Bool
go : Code pu CompilerM Bool
| .let _ k => go k
| .fun decl k | .jp decl k => go decl.value <||> go k
| .fun decl k _ | .jp decl k => go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .unreach .. => return false
| .return var => f var
partial def filterByUnreach (f : Expr CompilerM Bool) : Probe Decl Decl :=
partial def filterByUnreach (pu : Purity) (f : Expr CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
filter (·.value.isCodeAndM go)
where
go : Code CompilerM Bool
go : Code pu CompilerM Bool
| .let _ k => go k
| .fun decl k | .jp decl k => go decl.value <||> go k
| .fun decl k _ | .jp decl k => go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. => return false
| .unreach typ => f typ
@[inline]
def declNames : Probe Decl Name :=
def declNames (pu : Purity) : Probe (Decl pu) Name :=
Probe.map (fun decl => return decl.name)
@[inline]
@@ -172,7 +174,8 @@ def tail (n : Nat) : Probe α α := fun data => return data[(data.size - n)...*]
@[inline]
def head (n : Nat) : Probe α α := fun data => return data[*...n]
def runOnDeclsNamed (declNames : Array Name) (probe : Probe Decl β) (phase : Phase := Phase.base): CoreM (Array β) := do
def runOnDeclsNamed (declNames : Array Name) (phase : Phase := Phase.base)
(probe : Probe (Decl phase.toPurity) β) : CoreM (Array β) := do
let ext := getExt phase
let env getEnv
let decls declNames.mapM fun name => do
@@ -180,14 +183,15 @@ def runOnDeclsNamed (declNames : Array Name) (probe : Probe Decl β) (phase : Ph
return decl
probe decls |>.run (phase := phase)
def runOnModule (moduleName : Name) (probe : Probe Decl β) (phase : Phase := Phase.base): CoreM (Array β) := do
def runOnModule (moduleName : Name) (phase : Phase := Phase.base)
(probe : Probe (Decl phase.toPurity) β) : CoreM (Array β) := do
let ext := getExt phase
let env getEnv
let some modIdx := env.getModuleIdx? moduleName | throwError "module `{moduleName}` not found"
let decls := ext.getModuleEntries env modIdx
probe decls |>.run (phase := phase)
def runGlobally (probe : Probe Decl β) (phase : Phase := Phase.base) : CoreM (Array β) := do
def runGlobally (phase : Phase := Phase.base) (probe : Probe (Decl phase.toPurity) β) : CoreM (Array β) := do
let ext := getExt phase
let env getEnv
let mut decls := #[]
@@ -195,7 +199,7 @@ def runGlobally (probe : Probe Decl β) (phase : Phase := Phase.base) : CoreM (A
decls := decls.append <| ext.getModuleEntries env modIdx
probe decls |>.run (phase := phase)
def toPass [ToString β] (probe : Probe Decl β) (phase : Phase) : Pass where
def toPass [ToString β] (phase : Phase) (probe : Probe (Decl phase.toPurity) β) : Pass where
phase := phase
name := `probe
run := fun decls => do

View File

@@ -19,7 +19,7 @@ Local function declaration and join point being pulled.
-/
structure ToPull where
isFun : Bool
decl : FunDecl
decl : FunDecl .pure
used : FVarIdHashSet
deriving Inhabited
@@ -50,7 +50,8 @@ where
else
go as (a :: keep) dep
partial def findFVarDepsFixpoint (todo : List ToPull) (acc : Array ToPull := #[]) : PullM (Array ToPull) := do
partial def findFVarDepsFixpoint (todo : List ToPull) (acc : Array ToPull := #[]) :
PullM (Array ToPull) := do
match todo with
| [] => return acc
| p :: ps =>
@@ -65,7 +66,7 @@ partial def findFVarDeps (fvarId : FVarId) : PullM (Array ToPull) := do
Similar to `findFVarDeps`. Extract from the state any local function declarations that depends on the given
parameters.
-/
def findParamsDeps (params : Array Param) : PullM (Array ToPull) := do
def findParamsDeps (params : Array (Param pu)) : PullM (Array ToPull) := do
let mut acc := #[]
for param in params do
acc := acc ++ ( findFVarDeps param.fvarId)
@@ -74,7 +75,7 @@ def findParamsDeps (params : Array Param) : PullM (Array ToPull) := do
/--
Construct the code `fun p.decl k` or `jp p.decl k`.
-/
def ToPull.attach (p : ToPull) (k : Code) : Code :=
def ToPull.attach (p : ToPull) (k : Code .pure) : Code .pure :=
if p.isFun then
.fun p.decl k
else
@@ -83,19 +84,19 @@ def ToPull.attach (p : ToPull) (k : Code) : Code :=
/--
Attach the given array of local function declarations and join points to `k`.
-/
partial def attach (ps : Array ToPull) (k : Code) : Code := Id.run do
partial def attach (ps : Array ToPull) (k : Code .pure) : Code .pure := Id.run do
let visited := ps.map fun _ => false
let (_, (k, _)) := go |>.run (k, visited)
return k
where
go : StateM (Code × Array Bool) Unit := do
go : StateM (Code .pure × Array Bool) Unit := do
for i in *...ps.size do
visit i
visited (i : Nat) : StateM (Code × Array Bool) Bool :=
visited (i : Nat) : StateM (Code .pure × Array Bool) Bool :=
return ( get).2[i]!
visit (i : Nat) : StateM (Code × Array Bool) Unit := do
visit (i : Nat) : StateM (Code .pure × Array Bool) Unit := do
unless ( visited i) do
modify fun (k, visited) => (k, visited.set! i true)
let pi := ps[i]!
@@ -110,7 +111,7 @@ where
Extract from the state any local function declarations that depends on the given
free variable, **and** attach to code `k`.
-/
partial def attachFVarDeps (fvarId : FVarId) (k : Code) : PullM Code := do
partial def attachFVarDeps (fvarId : FVarId) (k : Code .pure) : PullM (Code .pure) := do
let ps findFVarDeps fvarId
return attach ps k
@@ -118,11 +119,11 @@ partial def attachFVarDeps (fvarId : FVarId) (k : Code) : PullM Code := do
Similar to `attachFVarDeps`. Extract from the state any local function declarations that depends on the given
parameters, **and** attach to code `k`.
-/
def attachParamsDeps (params : Array Param) (k : Code) : PullM Code := do
def attachParamsDeps (params : Array (Param .pure)) (k : Code .pure) : PullM (Code .pure) := do
let ps findParamsDeps params
return attach ps k
def attachJps (k : Code) : PullM Code := do
def attachJps (k : Code .pure) : PullM (Code .pure) := do
let jps := ( get).filter fun info => !info.isFun
modify fun s => s.filter fun info => info.isFun
let jps findFVarDepsFixpoint jps
@@ -132,7 +133,7 @@ mutual
/--
Add local function declaration (or join point if `isFun = false`) to the state.
-/
partial def addToPull (isFun : Bool) (decl : FunDecl) : PullM Unit := do
partial def addToPull (isFun : Bool) (decl : FunDecl .pure) : PullM Unit := do
let saved get
modify fun _ => []
let mut value pull decl.value
@@ -147,19 +148,19 @@ partial def addToPull (isFun : Bool) (decl : FunDecl) : PullM Unit := do
Pull local function declarations and join points in `code`.
The state contains the declarations being pulled.
-/
partial def pull (code : Code) : PullM Code := do
partial def pull (code : Code .pure) : PullM (Code .pure) := do
match code with
| .let decl k =>
let k pull k
let k attachFVarDeps decl.fvarId k
return code.updateLet! decl k
| .fun decl k => addToPull true decl; pull k
| .fun decl k _ => addToPull true decl; pull k
| .jp decl k => addToPull false decl; pull k
| .cases c =>
let alts c.alts.mapMonoM fun alt => do
match alt with
| .default k => return alt.updateCode ( pull k)
| .alt _ ps k =>
| .alt _ ps k _ =>
let k pull k
let k attachParamsDeps ps k
return alt.updateCode k
@@ -174,13 +175,13 @@ open PullFunDecls
/--
Pull local function declarations and join points in the given declaration.
-/
def Decl.pullFunDecls (decl : Decl) : CompilerM Decl := do
def Decl.pullFunDecls (decl : Decl .pure) : CompilerM (Decl .pure) := do
let (value, ps) decl.value.mapCodeM pull |>.run []
let value := value.mapCode (attach ps.toArray)
return { decl with value }
def pullFunDecls : Pass :=
.mkPerDeclaration `pullFunDecls Decl.pullFunDecls .base
.mkPerDeclaration `pullFunDecls .base Decl.pullFunDecls
builtin_initialize
registerTraceClass `Compiler.pullFunDecls (inherited := true)

View File

@@ -15,28 +15,28 @@ namespace Lean.Compiler.LCNF
namespace PullLetDecls
structure Context where
isCandidateFn : LetDecl FVarIdSet CompilerM Bool
isCandidateFn : LetDecl .pure FVarIdSet CompilerM Bool
included : FVarIdSet := {}
structure State where
toPull : Array LetDecl := #[]
toPull : Array (LetDecl .pure) := #[]
abbrev PullM := ReaderT Context $ StateRefT State CompilerM
@[inline] def withFVar (fvarId : FVarId) (x : PullM α) : PullM α :=
withReader (fun ctx => { ctx with included := ctx.included.insert fvarId }) x
@[inline] def withParams (ps : Array Param) (x : PullM α) : PullM α :=
@[inline] def withParams (ps : Array (Param .pure)) (x : PullM α) : PullM α :=
withReader (fun ctx => { ctx with included := ps.foldl (init := ctx.included) fun s p => s.insert p.fvarId }) x
@[inline] def withNewScope (x : PullM α) : PullM α :=
withReader (fun ctx => { ctx with included := {} }) x
partial def withCheckpoint (x : PullM Code) : PullM Code := do
partial def withCheckpoint (x : PullM (Code .pure)) : PullM (Code .pure) := do
let toPullSizeSaved := ( get).toPull.size
let c withNewScope x
let toPull := ( get).toPull
let rec go (i : Nat) (included : FVarIdSet) : StateM (Array LetDecl) Code := do
let rec go (i : Nat) (included : FVarIdSet) : StateM (Array (LetDecl .pure)) (Code .pure) := do
if h : i < toPull.size then
let letDecl := toPull[i]
if letDecl.dependsOn included then
@@ -51,11 +51,11 @@ partial def withCheckpoint (x : PullM Code) : PullM Code := do
modify fun s => { s with toPull := s.toPull.shrink toPullSizeSaved ++ keep }
return c
def attachToPull (c : Code) : PullM Code := do
def attachToPull (c : Code .pure) : PullM (Code .pure) := do
let toPull := ( get).toPull
return toPull.foldr (init := c) fun decl c => .let decl c
def shouldPull (decl : LetDecl) : PullM Bool := do
def shouldPull (decl : LetDecl .pure) : PullM Bool := do
unless decl.dependsOn ( read).included do
if ( ( read).isCandidateFn decl ( read).included) then
modify fun s => { s with toPull := s.toPull.push decl }
@@ -63,12 +63,12 @@ def shouldPull (decl : LetDecl) : PullM Bool := do
return false
mutual
partial def pullAlt (alt : Alt) : PullM Alt :=
partial def pullAlt (alt : (Alt .pure)) : PullM (Alt .pure) :=
match alt with
| .default k => return alt.updateCode ( withNewScope <| pullDecls k)
| .alt _ params k => return alt.updateCode ( withNewScope <| withParams params <| pullDecls k)
partial def pullDecls (code : Code) : PullM Code := do
partial def pullDecls (code : Code .pure) : PullM (Code .pure) := do
match code with
| .cases c =>
-- At the present time, we can't correctly enforce the dependencies required for lifting
@@ -93,21 +93,21 @@ mutual
end
def PullM.run (x : PullM α) (isCandidateFn : LetDecl FVarIdSet CompilerM Bool) : CompilerM α :=
def PullM.run (x : PullM α) (isCandidateFn : LetDecl .pure FVarIdSet CompilerM Bool) : CompilerM α :=
x { isCandidateFn } |>.run' {}
end PullLetDecls
open PullLetDecls
def Decl.pullLetDecls (decl : Decl) (isCandidateFn : LetDecl FVarIdSet CompilerM Bool) : CompilerM Decl := do
def Decl.pullLetDecls (decl : Decl .pure) (isCandidateFn : LetDecl .pure FVarIdSet CompilerM Bool) : CompilerM (Decl .pure) := do
PullM.run (isCandidateFn := isCandidateFn) do
withParams decl.params do
let value decl.value.mapCodeM pullDecls
let value value.mapCodeM attachToPull
return { decl with value }
def Decl.pullInstances (decl : Decl) : CompilerM Decl :=
def Decl.pullInstances (decl : Decl .pure) : CompilerM (Decl .pure) :=
decl.pullLetDecls fun letDecl candidates => do
-- TODO: Correctly represent these dependencies so this check isn't required.
if let .const _ _ args := letDecl.value then
@@ -122,7 +122,7 @@ def Decl.pullInstances (decl : Decl) : CompilerM Decl :=
return false
def pullInstances : Pass :=
.mkPerDeclaration `pullInstances Decl.pullInstances .base
.mkPerDeclaration `pullInstances .base Decl.pullInstances
builtin_initialize
registerTraceClass `Compiler.pullInstances (inherited := true)

View File

@@ -52,7 +52,7 @@ We assume this limitation is irrelevant in practice.
namespace FindUsed
structure Context where
decl : Decl
decl : Decl .pure
params : FVarIdSet
structure State where
@@ -64,12 +64,12 @@ def visitFVar (fvarId : FVarId) : FindUsedM Unit := do
if ( read).params.contains fvarId then
modify fun s => { s with used := s.used.insert fvarId }
def visitArg (arg : Arg) : FindUsedM Unit := do
def visitArg (arg : Arg .pure) : FindUsedM Unit := do
match arg with
| .erased | .type .. => return ()
| .fvar fvarId => visitFVar fvarId
def visitLetValue (e : LetValue) : FindUsedM Unit := do
def visitLetValue (e : LetValue .pure) : FindUsedM Unit := do
match e with
| .erased | .lit .. => return ()
| .proj _ _ fvarId => visitFVar fvarId
@@ -93,7 +93,7 @@ def visitLetValue (e : LetValue) : FindUsedM Unit := do
else
args.forM visitArg
partial def visit (code : Code) : FindUsedM Unit := do
partial def visit (code : Code .pure) : FindUsedM Unit := do
match code with
| .let decl k =>
visitLetValue decl.value
@@ -107,7 +107,7 @@ partial def visit (code : Code) : FindUsedM Unit := do
| .return fvarId => visitFVar fvarId
| .unreach _ => return ()
def collectUsedParams (decl : Decl) : CompilerM FVarIdHashSet := do
def collectUsedParams (decl : Decl .pure) : CompilerM FVarIdHashSet := do
let params := decl.params.foldl (init := {}) fun s p => s.insert p.fvarId
let (_, { used, .. }) decl.value.forCodeM visit |>.run { decl, params } |>.run {}
return used
@@ -123,7 +123,7 @@ structure Context where
abbrev ReduceM := ReaderT Context CompilerM
partial def reduce (code : Code) : ReduceM Code := do
partial def reduce (code : Code .pure) : ReduceM (Code .pure) := do
match code with
| .let decl k =>
let .const declName _ args := decl.value | do return code.updateLet! decl ( reduce k)
@@ -148,7 +148,7 @@ end ReduceArity
open FindUsed ReduceArity Internalize
def Decl.reduceArity (decl : Decl) : CompilerM (Array Decl) := do
def Decl.reduceArity (decl : Decl .pure) : CompilerM (Array (Decl .pure)) := do
match decl.value with
| .code code =>
let used collectUsedParams decl
@@ -160,7 +160,7 @@ def Decl.reduceArity (decl : Decl) : CompilerM (Array Decl) := do
trace[Compiler.reduceArity] "{decl.name}, used params: {used.toList.map mkFVar}"
let mask := decl.params.map fun param => used.contains param.fvarId
let auxName := decl.name ++ `_redArg
let mkAuxDecl : CompilerM Decl := do
let mkAuxDecl : CompilerM (Decl .pure) := do
let params := decl.params.filter fun param => used.contains param.fvarId
let value decl.value.mapCodeM reduce |>.run { declName := decl.name, auxDeclName := auxName, paramMask := mask }
let type code.inferType
@@ -168,7 +168,7 @@ def Decl.reduceArity (decl : Decl) : CompilerM (Array Decl) := do
let auxDecl := { decl with name := auxName, levelParams := [], type, params, value }
auxDecl.saveMono
return auxDecl
let updateDecl : InternalizeM Decl := do
let updateDecl : InternalizeM .pure (Decl .pure) := do
let params decl.params.mapM internalizeParam
let mut args := #[]
for used in mask, param in params do

View File

@@ -18,7 +18,7 @@ namespace ReduceJpArity
abbrev ReduceM := ReaderT (FVarIdMap (Array Bool)) CompilerM
partial def reduce (code : Code) : ReduceM Code := do
partial def reduce (code : Code .pure) : ReduceM (Code .pure) := do
match code with
| .let decl k => return code.updateLet! decl ( reduce k)
| .fun decl k =>
@@ -69,12 +69,14 @@ open ReduceJpArity
/--
Try to reduce arity of join points
-/
def Decl.reduceJpArity (decl : Decl) : CompilerM Decl := do
def Decl.reduceJpArity (decl : Decl .pure) : CompilerM (Decl .pure) := do
let value decl.value.mapCodeM reduce |>.run {}
return { decl with value }
-- TODO: This can be made Purity generic
def reduceJpArity (phase := Phase.base) : Pass :=
.mkPerDeclaration `reduceJpArity Decl.reduceJpArity phase
phase.withPurityCheck .pure fun h =>
.mkPerDeclaration `reduceJpArity phase (h Decl.reduceJpArity)
builtin_initialize
registerTraceClass `Compiler.reduceJpArity (inherited := true)

View File

@@ -16,7 +16,7 @@ A mapping from free variable id to binder name.
-/
abbrev Renaming := FVarIdMap Name
def Param.applyRenaming (param : Param) (r : Renaming) : CompilerM Param := do
def Param.applyRenaming (param : Param pu) (r : Renaming) : CompilerM (Param pu) := do
if let some binderName := r.get? param.fvarId then
let param := { param with binderName }
modifyLCtx fun lctx => lctx.addParam param
@@ -24,7 +24,7 @@ def Param.applyRenaming (param : Param) (r : Renaming) : CompilerM Param := do
else
return param
def LetDecl.applyRenaming (decl : LetDecl) (r : Renaming) : CompilerM LetDecl := do
def LetDecl.applyRenaming (decl : LetDecl pu) (r : Renaming) : CompilerM (LetDecl pu) := do
if let some binderName := r.get? decl.fvarId then
let decl := { decl with binderName }
modifyLCtx fun lctx => lctx.addLetDecl decl
@@ -33,7 +33,7 @@ def LetDecl.applyRenaming (decl : LetDecl) (r : Renaming) : CompilerM LetDecl :=
return decl
mutual
partial def FunDecl.applyRenaming (decl : FunDecl) (r : Renaming) : CompilerM FunDecl := do
partial def FunDecl.applyRenaming (decl : (FunDecl pu)) (r : Renaming) : CompilerM (FunDecl pu) := do
if let some binderName := r.get? decl.fvarId then
let decl := decl.updateBinderName binderName
modifyLCtx fun lctx => lctx.addFunDecl decl
@@ -41,20 +41,20 @@ partial def FunDecl.applyRenaming (decl : FunDecl) (r : Renaming) : CompilerM Fu
else
decl.updateValue ( decl.value.applyRenaming r)
partial def Code.applyRenaming (code : Code) (r : Renaming) : CompilerM Code := do
partial def Code.applyRenaming (code : Code pu) (r : Renaming) : CompilerM (Code pu) := do
match code with
| .let decl k => return code.updateLet! ( decl.applyRenaming r) ( k.applyRenaming r)
| .fun decl k | .jp decl k => return code.updateFun! ( decl.applyRenaming r) ( k.applyRenaming r)
| .fun decl k _ | .jp decl k => return code.updateFun! ( decl.applyRenaming r) ( k.applyRenaming r)
| .cases c =>
let alts c.alts.mapMonoM fun alt =>
match alt with
| .default k => return alt.updateCode ( k.applyRenaming r)
| .alt _ ps k => return alt.updateAlt! ( ps.mapMonoM (·.applyRenaming r)) ( k.applyRenaming r)
| .alt _ ps k _ => return alt.updateAlt! ( ps.mapMonoM (·.applyRenaming r)) ( k.applyRenaming r)
return code.updateAlts! alts
| .jmp .. | .unreach .. | .return .. => return code
end
def Decl.applyRenaming (decl : Decl) (r : Renaming) : CompilerM Decl := do
def Decl.applyRenaming (decl : Decl pu) (r : Renaming) : CompilerM (Decl pu) := do
if r.isEmpty then
return decl
else

View File

@@ -24,7 +24,7 @@ public section
namespace Lean.Compiler.LCNF
open Simp
def Decl.simp? (decl : Decl) : SimpM (Option Decl) := do
def Decl.simp? (decl : Decl .pure) : SimpM (Option (Decl .pure)) := do
let .code code := decl.value | return none
updateFunDeclInfo code
traceM `Compiler.simp.inline.info do return m!"{decl.name}:{Format.nest 2 (← (← get).funDeclInfoMap.format)}"
@@ -42,7 +42,7 @@ def Decl.simp? (decl : Decl) : SimpM (Option Decl) := do
else
return none
partial def Decl.simp (decl : Decl) (config : Config) : CompilerM Decl := do
partial def Decl.simp (decl : Decl .pure) (config : Config) : CompilerM (Decl .pure) := do
let mut config := config
if ( isTemplateLike decl) then
/-
@@ -54,7 +54,7 @@ partial def Decl.simp (decl : Decl) (config : Config) : CompilerM Decl := do
config := { config with etaPoly := false, inlinePartial := false }
go decl config
where
go (decl : Decl) (config : Config) : CompilerM Decl := do
go (decl : Decl .pure) (config : Config) : CompilerM (Decl .pure) := do
if let some decl decl.simp? |>.run { config, declName := decl.name } |>.run' {} |>.run {} then
-- TODO: bound number of steps?
go decl config
@@ -62,7 +62,8 @@ where
return decl
def simp (config : Config := {}) (occurrence : Nat := 0) (phase := Phase.base) : Pass :=
.mkPerDeclaration `simp (Decl.simp · config) phase (occurrence := occurrence)
phase.withPurityCheck .pure fun h =>
.mkPerDeclaration `simp phase (h (Decl.simp · config)) (occurrence := occurrence)
builtin_initialize
registerTraceClass `Compiler.simp (inherited := true)

View File

@@ -22,10 +22,10 @@ let _x.2 := _f.1
```
`findFunDecl? _x.2` returns `none`, but `findFunDecl'? _x.2` returns the declaration for `_f.1`.
-/
partial def findFunDecl'? (fvarId : FVarId) : CompilerM (Option FunDecl) := do
if let some decl findFunDecl? fvarId then
partial def findFunDecl'? (fvarId : FVarId) : CompilerM (Option (FunDecl pu)) := do
if let some decl findFunDecl? (pu := pu) fvarId then
return decl
else if let some (.fvar fvarId' #[]) findLetValue? fvarId then
else if let some (.fvar fvarId' #[]) findLetValue? (pu := pu) fvarId then
findFunDecl'? fvarId'
else
return none

View File

@@ -18,14 +18,14 @@ namespace ConstantFold
A constant folding monad, the additional state stores auxiliary declarations
required to build the new constant.
-/
abbrev FolderM := StateRefT (Array CodeDecl) CompilerM
abbrev FolderM := StateRefT (Array (CodeDecl .pure)) CompilerM
/--
A constant folder for a specific function, takes all the arguments of a
certain function and produces a new `Expr` + auxiliary declarations in
the `FolderM` monad on success. If the folding fails it returns `none`.
-/
abbrev Folder := Array Arg FolderM (Option LetValue)
abbrev Folder := Array (Arg .pure) FolderM (Option (LetValue .pure))
/--
A typeclass for detecting and producing literals of arbitrary types
@@ -43,7 +43,7 @@ class Literal (α : Type) where
final `Expr` putting them all together into a literal of type `α`,
where again the idea of what a literal is depends on `α`.
-/
mkLit : α FolderM LetValue
mkLit : α FolderM (LetValue .pure)
export Literal (getLit mkLit)
@@ -51,7 +51,7 @@ export Literal (getLit mkLit)
A wrapper around `LCNF.mkAuxLetDecl` that will automatically store the
`LetDecl` in the state of `FolderM`.
-/
def mkAuxLetDecl (e : LetValue) (prefixName := `_x) : FolderM FVarId := do
def mkAuxLetDecl (e : LetValue .pure) (prefixName := `_x) : FolderM FVarId := do
let decl LCNF.mkAuxLetDecl e prefixName
modify fun s => s.push <| .let decl
return decl.fvarId
@@ -66,10 +66,10 @@ def mkAuxLit [Literal α] (x : α) (prefixName := `_x) : FolderM FVarId := do
mkAuxLetDecl lit prefixName
partial def getNatLit (fvarId : FVarId) : CompilerM (Option Nat) := do
let some (.lit (.nat n)) findLetValue? fvarId | return none
let some (.lit (.nat n)) findLetValue? (pu := .pure) fvarId | return none
return n
def mkNatLit (n : Nat) : FolderM LetValue :=
def mkNatLit (n : Nat) : FolderM (LetValue .pure) :=
return .lit (.nat n)
instance : Literal Nat where
@@ -77,10 +77,10 @@ instance : Literal Nat where
mkLit := mkNatLit
def getStringLit (fvarId : FVarId) : CompilerM (Option String) := do
let some (.lit (.str s)) findLetValue? fvarId | return none
let some (.lit (.str s)) findLetValue? (pu := .pure) fvarId | return none
return s
def mkStringLit (n : String) : FolderM LetValue :=
def mkStringLit (n : String) : FolderM (LetValue .pure) :=
return .lit (.str n)
instance : Literal String where
@@ -91,7 +91,7 @@ def getBoolLit (fvarId : FVarId) : CompilerM (Option Bool) := do
let some (.const ctor [] #[]) findLetValue? fvarId | return none
return ctor == ``Bool.true
def mkBoolLit (b : Bool) : FolderM LetValue :=
def mkBoolLit (b : Bool) : FolderM (LetValue .pure) :=
let ctor := if b then ``Bool.true else ``Bool.false
return .const ctor [] #[]
@@ -115,7 +115,7 @@ instance : Literal Char := mkNatWrapperInstance Char.ofNat ``Char.ofNat Char.toN
def mkUIntInstance (matchLit : LitValue Option α) (litValueCtor : α LitValue) : Literal α where
getLit fvarId := do
let some (.lit litVal) findLetValue? fvarId | return none
let some (.lit litVal) findLetValue? (pu := .pure) fvarId | return none
return matchLit litVal
mkLit x :=
return .lit <| litValueCtor x
@@ -162,7 +162,7 @@ let _x.26 := @Array.push _ _x.24 z
_x.26
```
-/
def mkPseudoArrayLiteral (elements : Array FVarId) (typ : Expr) (typLevel : Level) : FolderM LetValue := do
def mkPseudoArrayLiteral (elements : Array FVarId) (typ : Expr) (typLevel : Level) : FolderM (LetValue .pure) := do
let sizeLit mkAuxLit elements.size
let mut literal mkAuxLetDecl <| .const ``Array.mkEmpty [typLevel] #[.type typ, .fvar sizeLit]
for element in elements do
@@ -335,7 +335,7 @@ def Folder.mulShift [Literal α] [BEq α] (shiftLeft : Name) (pow2 : αα)
-- TODO: add option for controlling the limit
def natPowThreshold := 256
def foldNatPow (args : Array Arg) : FolderM (Option LetValue) := do
def foldNatPow (args : Array (Arg .pure)) : FolderM (Option (LetValue .pure)) := do
let #[.fvar fvarId₁, .fvar fvarId₂] := args | return none
let some value₁ getNatLit fvarId₁ | return none
let some value₂ getNatLit fvarId₂ | return none
@@ -347,14 +347,14 @@ def foldNatPow (args : Array Arg) : FolderM (Option LetValue) := do
/--
Folder for ofNat operations on fixed-sized integer types.
-/
def Folder.ofNat (f : Nat LitValue) (args : Array Arg) : FolderM (Option LetValue) := do
def Folder.ofNat (f : Nat LitValue) (args : Array (Arg .pure)) : FolderM (Option (LetValue .pure)) := do
let #[.fvar fvarId] := args | return none
let some value getNatLit fvarId | return none
return some (.lit (f value))
def Folder.toNat (args : Array Arg) : FolderM (Option LetValue) := do
def Folder.toNat (args : Array (Arg .pure)) : FolderM (Option (LetValue .pure)) := do
let #[.fvar fvarId] := args | return none
let some (.lit lit) findLetValue? fvarId | return none
let some (.lit lit) findLetValue? (pu := .pure) fvarId | return none
match lit with
| .uint8 v | .uint16 v | .uint32 v | .uint64 v | .usize v => return some (.lit (.nat v.toNat))
| .nat _ | .str _ => return none
@@ -436,7 +436,7 @@ def stringFolders : List (Name × Folder) := [
/--
Apply all known folders to `decl`.
-/
def applyFolders (decl : LetDecl) (folders : SMap Name Folder) : CompilerM (Option (Array CodeDecl)) := do
def applyFolders (decl : LetDecl .pure) (folders : SMap Name Folder) : CompilerM (Option (Array (CodeDecl .pure))) := do
match decl.value with
| .const name _ args =>
if let some folder := folders.find? name then
@@ -495,7 +495,7 @@ def getFolders : CoreM (SMap Name Folder) :=
/--
Apply a list of default folders to `decl`
-/
def foldConstants (decl : LetDecl) : CompilerM (Option (Array CodeDecl)) := do
def foldConstants (decl : LetDecl .pure) : CompilerM (Option (Array (CodeDecl .pure))) := do
applyFolders decl ( getFolders)
end ConstantFold

View File

@@ -19,7 +19,7 @@ and the number of occurrences.
We use this function to decide whether to create a `.default` case
or not.
-/
private def getMaxOccs (alts : Array Alt) : Alt × Nat := Id.run do
private def getMaxOccs (alts : Array (Alt .pure)) : Alt .pure × Nat := Id.run do
let mut maxAlt := alts[0]!
let mut max := getNumOccsOf alts 0
for h : i in 1...alts.size do
@@ -35,7 +35,7 @@ where
Note that the number of occurrences can be greater than 1 only when
the alternative does not depend on field parameters
-/
getNumOccsOf (alts : Array Alt) (i : Nat) : Nat := Id.run do
getNumOccsOf (alts : Array (Alt .pure)) (i : Nat) : Nat := Id.run do
let code := alts[i]!.getCode
let mut n := 1
for h : j in (i+1)...alts.size do
@@ -47,7 +47,7 @@ where
Add a default case to the given `cases` alternatives if there
are alternatives with equivalent (aka alpha equivalent) right hand sides.
-/
def addDefaultAlt (alts : Array Alt) : SimpM (Array Alt) := do
def addDefaultAlt (alts : Array (Alt .pure)) : SimpM (Array (Alt .pure)) := do
if alts.size <= 1 || alts.any (· matches .default ..) then
return alts
else

Some files were not shown because too many files have changed in this diff Show More