Merge branch 'nightly-with-mathlib' of https://github.com/leanprover/lean4 into joachim/instantiateMVarsNoUpdate

This commit is contained in:
Joachim Breitner
2026-03-03 11:18:47 +00:00
6994 changed files with 25502 additions and 4302 deletions

View File

@@ -4,29 +4,25 @@ To build Lean you should use `make -j$(nproc) -C build/release`.
## Running Tests
See `doc/dev/testing.md` for full documentation. Quick reference:
See `tests/README.md` for full documentation. Quick reference:
```bash
# Full test suite (use after builds to verify correctness)
make -j$(nproc) -C build/release test ARGS="-j$(nproc)"
CTEST_PARALLEL_LEVEL="$(nproc)" CTEST_OUTPUT_ON_FAILURE=1 \
make -C build/release -j "$(nproc)" test
# Specific test by name (supports regex via ctest -R)
make -j$(nproc) -C build/release test ARGS='-R grind_ematch --output-on-failure'
CTEST_PARALLEL_LEVEL="$(nproc)" CTEST_OUTPUT_ON_FAILURE=1 \
make -C build/release -j "$(nproc)" test ARGS='-R grind_ematch'
# Rerun only previously failed tests
make -j$(nproc) -C build/release test ARGS='--rerun-failed --output-on-failure'
CTEST_PARALLEL_LEVEL="$(nproc)" CTEST_OUTPUT_ON_FAILURE=1 \
make -C build/release -j "$(nproc)" test ARGS='--rerun-failed'
# Single test from tests/lean/run/ (quick check during development)
cd tests/lean/run && ./test_single.sh example_test.lean
# ctest directly (from stage1 build dir)
cd build/release/stage1 && ctest -j$(nproc) --output-on-failure --timeout 300
# Single test from tests/foo/bar/ (quick check during development)
cd tests/foo/bar && ./run_test example_test.lean
```
The full test suite includes `tests/lean/`, `tests/lean/run/`, `tests/lean/interactive/`,
`tests/compiler/`, `tests/pkg/`, Lake tests, and more. Using `make test` or `ctest` runs
all of them; `test_single.sh` in `tests/lean/run/` only covers that one directory.
## New features
When asked to implement new features:
@@ -34,8 +30,6 @@ When asked to implement new features:
* write comprehensive tests first (expecting that these will initially fail)
* and then iterate on the implementation until the tests pass.
All new tests should go in `tests/lean/run/`. These tests don't have expected output; we just check there are no errors. You should use `#guard_msgs` to check for specific messages.
## Success Criteria
*Never* report success on a task unless you have verified both a clean build without errors, and that the relevant tests pass.

View File

@@ -121,6 +121,20 @@ The nightly build system uses branches and tags across two repositories:
When a nightly succeeds with mathlib, all three should point to the same commit. Don't confuse these: branches are in the main lean4 repo, dated tags are in lean4-nightly.
## Waiting for CI or Merges
Use `gh pr checks --watch` to block until a PR's CI checks complete (no polling needed).
Run these as background bash commands so you get notified when they finish:
```bash
# Watch CI, then check merge state
gh pr checks <number> --repo <owner>/<repo> --watch && gh pr view <number> --repo <owner>/<repo> --json state --jq '.state'
```
For multiple PRs, launch one background command per PR in parallel. When each completes,
you'll be notified automatically via a task-notification. Do NOT use sleep-based polling
loops — `--watch` is event-driven and exits as soon as checks finish.
## Error Handling
**CRITICAL**: If something goes wrong or a command fails:

View File

@@ -0,0 +1,26 @@
---
name: profiling
description: Profile Lean programs with demangled names using samply and Firefox Profiler. Use when the user asks to profile a Lean binary or investigate performance.
allowed-tools: Bash, Read, Glob, Grep
---
# Profiling Lean Programs
Full documentation: `script/PROFILER_README.md`.
## Quick Start
```bash
script/lean_profile.sh ./build/release/stage1/bin/lean some_file.lean
```
Requires `samply` (`cargo install samply`) and `python3`.
## Agent Notes
- The pipeline is interactive (serves to browser at the end). When running non-interactively, run the steps manually instead of using the wrapper script.
- The three steps are: `samply record --save-only`, `symbolicate_profile.py`, then `serve_profile.py`.
- `lean_demangle.py` works standalone as a stdin filter (like `c++filt`) for quick name lookups.
- The `--raw` flag on `lean_demangle.py` gives exact demangled names without postprocessing (keeps `._redArg`, `._lam_0` suffixes as-is).
- Use `PROFILE_KEEP=1` to keep the temp directory for later inspection.
- The demangled profile is a standard Firefox Profiler JSON. Function names live in `threads[i].stringArray`, indexed by `threads[i].funcTable.name`.

View File

@@ -2,16 +2,19 @@ name: Check awaiting-manual label
on:
merge_group:
pull_request:
pull_request_target:
types: [opened, synchronize, reopened, labeled, unlabeled]
permissions:
pull-requests: read
jobs:
check-awaiting-manual:
runs-on: ubuntu-latest
steps:
- name: Check awaiting-manual label
id: check-awaiting-manual-label
if: github.event_name == 'pull_request'
if: github.event_name == 'pull_request_target'
uses: actions/github-script@v8
with:
script: |
@@ -28,7 +31,7 @@ jobs:
}
- name: Wait for manual compatibility
if: github.event_name == 'pull_request' && steps.check-awaiting-manual-label.outputs.awaiting == 'true'
if: github.event_name == 'pull_request_target' && steps.check-awaiting-manual-label.outputs.awaiting == 'true'
run: |
echo "::notice title=Awaiting manual::PR is marked 'awaiting-manual' but neither 'breaks-manual' nor 'builds-manual' labels are present."
echo "This check will remain in progress until the PR is updated with appropriate manual compatibility labels."

View File

@@ -2,16 +2,19 @@ name: Check awaiting-mathlib label
on:
merge_group:
pull_request:
pull_request_target:
types: [opened, synchronize, reopened, labeled, unlabeled]
permissions:
pull-requests: read
jobs:
check-awaiting-mathlib:
runs-on: ubuntu-latest
steps:
- name: Check awaiting-mathlib label
id: check-awaiting-mathlib-label
if: github.event_name == 'pull_request'
if: github.event_name == 'pull_request_target'
uses: actions/github-script@v8
with:
script: |
@@ -28,7 +31,7 @@ jobs:
}
- name: Wait for mathlib compatibility
if: github.event_name == 'pull_request' && steps.check-awaiting-mathlib-label.outputs.awaiting == 'true'
if: github.event_name == 'pull_request_target' && steps.check-awaiting-mathlib-label.outputs.awaiting == 'true'
run: |
echo "::notice title=Awaiting mathlib::PR is marked 'awaiting-mathlib' but neither 'breaks-mathlib' nor 'builds-mathlib' labels are present."
echo "This check will remain in progress until the PR is updated with appropriate mathlib compatibility labels."

View File

@@ -66,16 +66,10 @@ jobs:
brew install ccache tree zstd coreutils gmp libuv
if: runner.os == 'macOS'
- name: Checkout
if: (!endsWith(matrix.os, '-with-cache'))
uses: actions/checkout@v6
with:
# the default is to use a virtual merge commit between the PR and master: just use the PR
ref: ${{ github.event.pull_request.head.sha }}
- name: Namespace Checkout
if: endsWith(matrix.os, '-with-cache')
uses: namespacelabs/nscloud-checkout-action@v8
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: Open Nix shell once
run: true
if: runner.os == 'Linux'
@@ -85,7 +79,7 @@ jobs:
- name: CI Merge Checkout
run: |
git fetch --depth=1 origin ${{ github.sha }}
git checkout FETCH_HEAD flake.nix flake.lock script/prepare-* tests/lean/run/importStructure.lean
git checkout FETCH_HEAD flake.nix flake.lock script/prepare-* tests/elab/importStructure.lean
if: github.event_name == 'pull_request'
# (needs to be after "Checkout" so files don't get overridden)
- name: Setup emsdk
@@ -235,25 +229,21 @@ jobs:
# prefix `if` above with `always` so it's run even if tests failed
if: always() && steps.test.conclusion != 'skipped'
- name: Check Test Binary
run: ${{ matrix.binary-check }} tests/compiler/534.lean.out
run: ${{ matrix.binary-check }} tests/compile/534.lean.out
if: (!matrix.cross) && steps.test.conclusion != 'skipped'
- name: Build Stage 2
run: |
make -C build -j$NPROC stage2
if: matrix.test-speedcenter
if: matrix.test-bench
- name: Check Stage 3
run: |
make -C build -j$NPROC check-stage3
if: matrix.check-stage3
- name: Test Speedcenter Benchmarks
- name: Test Benchmarks
run: |
# Necessary for some timing metrics but does not work on Namespace runners
# and we just want to test that the benchmarks run at all here
#echo -1 | sudo tee /proc/sys/kernel/perf_event_paranoid
export BUILD=$PWD/build PATH=$PWD/build/stage1/bin:$PATH
cd tests/bench
nix shell .#temci -c temci exec --config speedcenter.yaml --included_blocks fast --runs 1
if: matrix.test-speedcenter
cd tests
nix develop -c make -C ../build -j$NPROC bench
if: matrix.test-bench
- name: Check rebootstrap
run: |
set -e

View File

@@ -1,9 +1,12 @@
name: Check stdlib_flags.h modifications
on:
pull_request:
pull_request_target:
types: [opened, synchronize, reopened, labeled, unlabeled]
permissions:
pull-requests: read
jobs:
check-stdlib-flags:
runs-on: ubuntu-latest

View File

@@ -258,8 +258,8 @@ jobs:
"check-rebootstrap": level >= 1,
"check-stage3": level >= 2,
"test": true,
// NOTE: `test-speedcenter` currently seems to be broken on `ubuntu-latest`
"test-speedcenter": large && level >= 2,
// NOTE: `test-bench` currently seems to be broken on `ubuntu-latest`
"test-bench": large && level >= 2,
// We are not warning-free yet on all platforms, start here
"CMAKE_OPTIONS": "-DLEAN_EXTRA_CXX_FLAGS=-Werror",
},
@@ -269,6 +269,8 @@ jobs:
"enabled": level >= 2,
"test": true,
"CMAKE_PRESET": "reldebug",
// * `elab_bench/big_do` crashes with exit code 134
"CTEST_OPTIONS": "-E 'elab_bench/big_do'",
},
{
"name": "Linux fsanitize",

View File

@@ -2,17 +2,23 @@ name: Check PR body for changelog convention
on:
merge_group:
pull_request:
pull_request_target:
types: [opened, synchronize, reopened, edited, labeled, converted_to_draft, ready_for_review]
permissions:
pull-requests: read
jobs:
check-pr-body:
runs-on: ubuntu-latest
steps:
- name: Check PR body
if: github.event_name == 'pull_request'
if: github.event_name == 'pull_request_target'
uses: actions/github-script@v8
with:
# Safety note: this uses pull_request_target, so the workflow has elevated privileges.
# The PR title and body are only used in regex tests (read-only string matching),
# never interpolated into shell commands, eval'd, or written to GITHUB_ENV/GITHUB_OUTPUT.
script: |
const { title, body, labels, draft } = context.payload.pull_request;
if (!draft && /^(feat|fix):/.test(title) && !labels.some(label => label.name == "changelog-no")) {

1
.gitignore vendored
View File

@@ -1,7 +1,6 @@
*~
\#*
.#*
*.lock
.lake
lake-manifest.json
/build

View File

@@ -1,4 +1,8 @@
cmake_minimum_required(VERSION 3.11)
cmake_minimum_required(VERSION 3.21)
if(NOT CMAKE_GENERATOR MATCHES "Makefiles")
message(FATAL_ERROR "Only makefile generators are supported")
endif()
option(USE_MIMALLOC "use mimalloc" ON)
@@ -147,6 +151,7 @@ ExternalProject_Add(
INSTALL_COMMAND ""
DEPENDS stage2
EXCLUDE_FROM_ALL ON
STEP_TARGETS configure
)
# targets forwarded to appropriate stages
@@ -157,6 +162,25 @@ add_custom_target(update-stage0-commit COMMAND $(MAKE) -C stage1 update-stage0-c
add_custom_target(test COMMAND $(MAKE) -C stage1 test DEPENDS stage1)
add_custom_target(
bench
COMMAND $(MAKE) -C stage2
COMMAND $(MAKE) -C stage2 -j1 bench
DEPENDS stage2
)
add_custom_target(
bench-part1
COMMAND $(MAKE) -C stage2
COMMAND $(MAKE) -C stage2 -j1 bench-part1
DEPENDS stage2
)
add_custom_target(
bench-part2
COMMAND $(MAKE) -C stage2
COMMAND $(MAKE) -C stage2 -j1 bench-part2
DEPENDS stage2
)
add_custom_target(clean-stdlib COMMAND $(MAKE) -C stage1 clean-stdlib DEPENDS stage1)
install(CODE "execute_process(COMMAND make -C stage1 install)")

View File

@@ -41,7 +41,7 @@
"SMALL_ALLOCATOR": "OFF",
"USE_MIMALLOC": "OFF",
"BSYMBOLIC": "OFF",
"LEAN_TEST_VARS": "MAIN_STACK_SIZE=16000 LSAN_OPTIONS=max_leaks=10"
"LEAN_TEST_VARS": "MAIN_STACK_SIZE=16000 TEST_STACK_SIZE=16000 LSAN_OPTIONS=max_leaks=10"
},
"generator": "Unix Makefiles",
"binaryDir": "${sourceDir}/build/sanitize"

View File

@@ -1,5 +1,9 @@
# Test Suite
**Warning:** This document is partially outdated.
It describes the old test suite, which is currently in the process of being replaced.
The new test suite's documentation can be found at [`tests/README.md`](../../tests/README.md).
After [building Lean](../make/index.md) you can run all the tests using
```
cd build/release

View File

@@ -1 +1 @@
lean4
../../../build/release/stage1

View File

@@ -1 +1 @@
lean4
build/release/stage1

View File

@@ -2,21 +2,9 @@
"folders": [
{
"path": "."
},
{
"path": "src"
},
{
"path": "tests"
},
{
"path": "script"
}
],
"settings": {
// Open terminal at root, not current workspace folder
// (there is no way to directly refer to the root folder included as `.` above)
"terminal.integrated.cwd": "${workspaceFolder:src}/..",
"files.insertFinalNewline": true,
"files.trimTrailingWhitespace": true,
"cmake.buildDirectory": "${workspaceFolder}/build/release",

View File

@@ -83,7 +83,7 @@ def main (args : List String) : IO Unit := do
lastRSS? := some rss
let avgRSSDelta := totalRSSDelta / (n - 2)
IO.println s!"avg-reelab-rss-delta: {avgRSSDelta}"
IO.println s!"measurement: avg-reelab-rss-delta {avgRSSDelta*1024} b"
let _ ← Ipc.collectDiagnostics requestNo uri versionNo
(← Ipc.stdin).writeLspMessage (Message.notification "exit" none)

View File

@@ -82,7 +82,7 @@ def main (args : List String) : IO Unit := do
lastRSS? := some rss
let avgRSSDelta := totalRSSDelta / (n - 2)
IO.println s!"avg-reelab-rss-delta: {avgRSSDelta}"
IO.println s!"measurement: avg-reelab-rss-delta {avgRSSDelta*1024} b"
let _ ← Ipc.collectDiagnostics requestNo uri versionNo
Ipc.shutdown requestNo

View File

@@ -9,5 +9,5 @@ find -regex '.*/CMakeLists\.txt\(\.in\)?\|.*\.cmake\(\.in\)?' \
! -path "./stage0/*" \
-exec \
uvx gersemi --in-place --line-length 120 --indent 2 \
--definitions src/cmake/Modules/ src/CMakeLists.txt \
--definitions src/cmake/Modules/ src/CMakeLists.txt tests/CMakeLists.txt \
-- {} +

View File

@@ -1 +1 @@
lean4
../build/release/stage1

View File

@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
# Profile a Lean binary with demangled names.
#
# Usage:

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env bash
set -euo pipefail
rm -r stage0 || true
rm -rf stage0 || true
# don't copy untracked files
# `:!` is git glob flavor for exclude patterns
for f in $(git ls-files src ':!:src/lake/*' ':!:src/Leanc.lean'); do

View File

@@ -1,6 +1,4 @@
cmake_minimum_required(VERSION 3.10)
cmake_policy(SET CMP0054 NEW)
cmake_policy(SET CMP0110 NEW)
cmake_minimum_required(VERSION 3.21)
if(NOT CMAKE_GENERATOR MATCHES "Unix Makefiles")
message(FATAL_ERROR "The only supported CMake generator at the moment is 'Unix Makefiles'")
endif()

View File

@@ -72,6 +72,9 @@ theorem toArray_eq : List.toArray as = xs ↔ as = xs.toList := by
/-! ### size -/
theorem size_singleton {x : α} : #[x].size = 1 := by
simp
theorem eq_empty_of_size_eq_zero (h : xs.size = 0) : xs = #[] := by
cases xs
simp_all
@@ -3483,6 +3486,21 @@ theorem foldl_eq_foldr_reverse {xs : Array α} {f : β → α → β} {b} :
theorem foldr_eq_foldl_reverse {xs : Array α} {f : α → β → β} {b} :
xs.foldr f b = xs.reverse.foldl (fun x y => f y x) b := by simp
theorem foldl_eq_apply_foldr {xs : Array α} {f : α → α → α}
[Std.Associative f] [Std.LawfulRightIdentity f init] :
xs.foldl f x = f x (xs.foldr f init) := by
simp [← foldl_toList, ← foldr_toList, List.foldl_eq_apply_foldr]
theorem foldr_eq_apply_foldl {xs : Array α} {f : α → α → α}
[Std.Associative f] [Std.LawfulLeftIdentity f init] :
xs.foldr f x = f (xs.foldl f init) x := by
simp [← foldl_toList, ← foldr_toList, List.foldr_eq_apply_foldl]
theorem foldr_eq_foldl {xs : Array α} {f : α → α → α}
[Std.Associative f] [Std.LawfulIdentity f init] :
xs.foldr f init = xs.foldl f init := by
simp [foldl_eq_apply_foldr, Std.LawfulLeftIdentity.left_id]
@[simp] theorem foldr_push_eq_append {as : Array α} {bs : Array β} {f : α → β} (w : start = as.size) :
as.foldr (fun a xs => Array.push xs (f a)) bs start 0 = bs ++ (as.map f).reverse := by
subst w
@@ -4335,16 +4353,33 @@ def sum_eq_sum_toList := @sum_toList
@[simp, grind =]
theorem sum_append [Zero α] [Add α] [Std.Associative (α := α) (· + ·)]
[Std.LeftIdentity (α := α) (· + ·) 0] [Std.LawfulLeftIdentity (α := α) (· + ·) 0]
[Std.LawfulLeftIdentity (α := α) (· + ·) 0]
{as₁ as₂ : Array α} : (as₁ ++ as₂).sum = as₁.sum + as₂.sum := by
simp [← sum_toList, List.sum_append]
@[simp, grind =]
theorem sum_singleton [Add α] [Zero α] [Std.LawfulRightIdentity (· + ·) (0 : α)] {x : α} :
#[x].sum = x := by
simp [Array.sum_eq_foldr, Std.LawfulRightIdentity.right_id x]
@[simp, grind =]
theorem sum_push [Add α] [Zero α] [Std.Associative (α := α) (· + ·)]
[Std.LawfulIdentity (· + ·) (0 : α)] {xs : Array α} {x : α} :
(xs.push x).sum = xs.sum + x := by
simp [Array.sum_eq_foldr, Std.LawfulRightIdentity.right_id, Std.LawfulLeftIdentity.left_id,
Array.foldr_assoc]
@[simp, grind =]
theorem sum_reverse [Zero α] [Add α] [Std.Associative (α := α) (· + ·)]
[Std.Commutative (α := α) (· + ·)]
[Std.LawfulLeftIdentity (α := α) (· + ·) 0] (xs : Array α) : xs.reverse.sum = xs.sum := by
simp [← sum_toList, List.sum_reverse]
theorem sum_eq_foldl [Zero α] [Add α] [Std.Associative (α := α) (· + ·)]
[Std.LawfulIdentity (· + ·) (0 : α)] {xs : Array α} :
xs.sum = xs.foldl (init := 0) (· + ·) := by
simp [← sum_toList, List.sum_eq_foldl]
theorem foldl_toList_eq_flatMap {l : List α} {acc : Array β}
{F : Array β → α → Array β} {G : α → List β}
(H : ∀ acc a, (F acc a).toList = acc.toList ++ G a) :

View File

@@ -126,6 +126,14 @@ theorem swap_perm {xs : Array α} {i j : Nat} (h₁ : i < xs.size) (h₂ : j < x
simp only [swap, perm_iff_toList_perm, toList_set]
apply set_set_perm
theorem Perm.pairwise_iff {R : α → α → Prop} (S : ∀ {x y}, R x y → R y x) {xs ys : Array α}
: ∀ _p : xs.Perm ys, xs.toList.Pairwise R ↔ ys.toList.Pairwise R := by
simpa only [perm_iff_toList_perm] using List.Perm.pairwise_iff S
theorem Perm.pairwise {R : α → α → Prop} {xs ys : Array α} (hp : xs ~ ys)
(hR : xs.toList.Pairwise R) (hsymm : ∀ {x y}, R x y → R y x) :
ys.toList.Pairwise R := (hp.pairwise_iff hsymm).mp hR
namespace Perm
set_option linter.indexVariables false in

View File

@@ -37,3 +37,4 @@ public import Init.Data.List.Lex
public import Init.Data.List.Range
public import Init.Data.List.Scan
public import Init.Data.List.ControlImpl
public import Init.Data.List.SplitOn

View File

@@ -135,7 +135,11 @@ protected def beq [BEq α] : List α → List α → Bool
@[simp] theorem beq_nil_nil [BEq α] : List.beq ([] : List α) ([] : List α) = true := rfl
@[simp] theorem beq_cons_nil [BEq α] {a : α} {as : List α} : List.beq (a::as) [] = false := rfl
@[simp] theorem beq_nil_cons [BEq α] {a : α} {as : List α} : List.beq [] (a::as) = false := rfl
theorem beq_cons [BEq α] {a b : α} {as bs : List α} : List.beq (a::as) (b::bs) = (a == b && List.beq as bs) := rfl
theorem beq_cons_cons [BEq α] {a b : α} {as bs : List α} : List.beq (a::as) (b::bs) = (a == b && List.beq as bs) := rfl
@[deprecated beq_cons_cons (since := "2026-02-26")]
theorem beq_cons₂ [BEq α] {a b : α} {as bs : List α} :
List.beq (a::as) (b::bs) = (a == b && List.beq as bs) := beq_cons_cons
instance [BEq α] : BEq (List α) := List.beq
@@ -175,7 +179,10 @@ Examples:
@[simp, grind =] theorem isEqv_nil_nil : isEqv ([] : List α) [] eqv = true := rfl
@[simp, grind =] theorem isEqv_nil_cons : isEqv ([] : List α) (a::as) eqv = false := rfl
@[simp, grind =] theorem isEqv_cons_nil : isEqv (a::as : List α) [] eqv = false := rfl
@[grind =] theorem isEqv_cons : isEqv (a::as) (b::bs) eqv = (eqv a b && isEqv as bs eqv) := rfl
@[grind =] theorem isEqv_cons_cons : isEqv (a::as) (b::bs) eqv = (eqv a b && isEqv as bs eqv) := rfl
@[deprecated isEqv_cons_cons (since := "2026-02-26")]
theorem isEqv_cons₂ : isEqv (a::as) (b::bs) eqv = (eqv a b && isEqv as bs eqv) := isEqv_cons_cons
/-! ## Lexicographic ordering -/
@@ -1048,9 +1055,12 @@ def dropLast {α} : List α → List α
@[simp, grind =] theorem dropLast_nil : ([] : List α).dropLast = [] := rfl
@[simp, grind =] theorem dropLast_singleton : [x].dropLast = [] := rfl
@[simp, grind =] theorem dropLast_cons :
@[simp, grind =] theorem dropLast_cons_cons :
(x::y::zs).dropLast = x :: (y::zs).dropLast := rfl
@[deprecated dropLast_cons_cons (since := "2026-02-26")]
theorem dropLast_cons₂ : (x::y::zs).dropLast = x :: (y::zs).dropLast := dropLast_cons_cons
-- Later this can be proved by `simp` via `[List.length_dropLast, List.length_cons, Nat.add_sub_cancel]`,
-- but we need this while bootstrapping `Array`.
@[simp] theorem length_dropLast_cons {a : α} {as : List α} : (a :: as).dropLast.length = as.length := by
@@ -1085,7 +1095,11 @@ inductive Sublist {α} : List α → List α → Prop
/-- If `l₁` is a subsequence of `l₂`, then it is also a subsequence of `a :: l₂`. -/
| cons a : Sublist l₁ l₂ Sublist l₁ (a :: l₂)
/-- If `l₁` is a subsequence of `l₂`, then `a :: l₁` is a subsequence of `a :: l₂`. -/
| cons a : Sublist l₁ l₂ Sublist (a :: l₁) (a :: l₂)
| cons_cons a : Sublist l₁ l₂ Sublist (a :: l₁) (a :: l₂)
set_option linter.missingDocs false in
@[deprecated Sublist.cons_cons (since := "2026-02-26"), match_pattern]
abbrev Sublist.cons₂ := @Sublist.cons_cons
@[inherit_doc] scoped infixl:50 " <+ " => Sublist
@@ -1143,9 +1157,13 @@ def isPrefixOf [BEq α] : List α → List α → Bool
@[simp, grind =] theorem isPrefixOf_nil_left [BEq α] : isPrefixOf ([] : List α) l = true := by
simp [isPrefixOf]
@[simp, grind =] theorem isPrefixOf_cons_nil [BEq α] : isPrefixOf (a::as) ([] : List α) = false := rfl
@[grind =] theorem isPrefixOf_cons [BEq α] {a : α} :
@[grind =] theorem isPrefixOf_cons_cons [BEq α] {a : α} :
isPrefixOf (a::as) (b::bs) = (a == b && isPrefixOf as bs) := rfl
@[deprecated isPrefixOf_cons_cons (since := "2026-02-26")]
theorem isPrefixOf_cons₂ [BEq α] {a : α} :
isPrefixOf (a::as) (b::bs) = (a == b && isPrefixOf as bs) := isPrefixOf_cons_cons
/--
If the first list is a prefix of the second, returns the result of dropping the prefix.
@@ -2164,10 +2182,16 @@ def intersperse (sep : α) : (l : List α) → List α
| x::xs => x :: sep :: intersperse sep xs
@[simp] theorem intersperse_nil {sep : α} : ([] : List α).intersperse sep = [] := rfl
@[simp] theorem intersperse_single {x : α} {sep : α} : [x].intersperse sep = [x] := rfl
@[simp] theorem intersperse_cons₂ {x : α} {y : α} {zs : List α} {sep : α} :
@[simp] theorem intersperse_singleton {x : α} {sep : α} : [x].intersperse sep = [x] := rfl
@[deprecated intersperse_singleton (since := "2026-02-26")]
theorem intersperse_single {x : α} {sep : α} : [x].intersperse sep = [x] := rfl
@[simp] theorem intersperse_cons_cons {x : α} {y : α} {zs : List α} {sep : α} :
(x::y::zs).intersperse sep = x::sep::((y::zs).intersperse sep) := rfl
@[deprecated intersperse_cons_cons (since := "2026-02-26")]
theorem intersperse_cons₂ {x : α} {y : α} {zs : List α} {sep : α} :
(x::y::zs).intersperse sep = x::sep::((y::zs).intersperse sep) := intersperse_cons_cons
/-! ### intercalate -/
set_option linter.listVariables false in

View File

@@ -125,7 +125,7 @@ protected theorem Sublist.eraseP : l₁ <+ l₂ → l₁.eraseP p <+ l₂.eraseP
by_cases h : p a
· simpa [h] using s.eraseP.trans eraseP_sublist
· simpa [h] using s.eraseP.cons _
| .cons a s => by
| .cons_cons a s => by
by_cases h : p a
· simpa [h] using s
· simpa [h] using s.eraseP

View File

@@ -184,7 +184,7 @@ theorem Sublist.findSome?_isSome {l₁ l₂ : List α} (h : l₁ <+ l₂) :
induction h with
| slnil => simp
| cons a h ih
| cons a h ih =>
| cons_cons a h ih =>
simp only [findSome?]
split
· simp_all
@@ -455,7 +455,7 @@ theorem Sublist.find?_isSome {l₁ l₂ : List α} (h : l₁ <+ l₂) : (l₁.fi
induction h with
| slnil => simp
| cons a h ih
| cons a h ih =>
| cons_cons a h ih =>
simp only [find?]
split
· simp

View File

@@ -1394,7 +1394,7 @@ theorem head_filter_of_pos {p : α → Bool} {l : List α} (w : l ≠ []) (h : p
@[simp] theorem filter_sublist {p : α → Bool} : ∀ {l : List α}, filter p l <+ l
| [] => .slnil
| a :: l => by rw [filter]; split <;> simp [Sublist.cons, Sublist.cons, filter_sublist]
| a :: l => by rw [filter]; split <;> simp [Sublist.cons, Sublist.cons_cons, filter_sublist]
/-! ### filterMap -/
@@ -1838,6 +1838,11 @@ theorem sum_append [Add α] [Zero α] [Std.LawfulLeftIdentity (α := α) (· +
[Std.Associative (α := α) (· + ·)] {l₁ l₂ : List α} : (l₁ ++ l₂).sum = l₁.sum + l₂.sum := by
induction l₁ generalizing l₂ <;> simp_all [Std.Associative.assoc, Std.LawfulLeftIdentity.left_id]
@[simp, grind =]
theorem sum_singleton [Add α] [Zero α] [Std.LawfulRightIdentity (· + ·) (0 : α)] {x : α} :
[x].sum = x := by
simp [List.sum_eq_foldr, Std.LawfulRightIdentity.right_id x]
@[simp, grind =]
theorem sum_reverse [Zero α] [Add α] [Std.Associative (α := α) (· + ·)]
[Std.Commutative (α := α) (· + ·)]
@@ -2727,6 +2732,31 @@ theorem foldr_assoc {op : ααα} [ha : Std.Associative op] :
simp only [foldr_cons, ha.assoc]
rw [foldr_assoc]
theorem foldl_eq_apply_foldr {xs : List α} {f : α → α → α}
[Std.Associative f] [Std.LawfulRightIdentity f init] :
xs.foldl f x = f x (xs.foldr f init) := by
induction xs generalizing x
· simp [Std.LawfulRightIdentity.right_id]
· simp [foldl_assoc, *]
theorem foldr_eq_apply_foldl {xs : List α} {f : α → α → α}
[Std.Associative f] [Std.LawfulLeftIdentity f init] :
xs.foldr f x = f (xs.foldl f init) x := by
have : Std.Associative (fun x y => f y x) := by simp [Std.Associative.assoc]
have : Std.RightIdentity (fun x y => f y x) init :=
have : Std.LawfulRightIdentity (fun x y => f y x) init := by simp [Std.LawfulLeftIdentity.left_id]
rw [← List.reverse_reverse (as := xs), foldr_reverse, foldl_eq_apply_foldr, foldl_reverse]
theorem foldr_eq_foldl {xs : List α} {f : α → α → α}
[Std.Associative f] [Std.LawfulIdentity f init] :
xs.foldr f init = xs.foldl f init := by
simp [foldl_eq_apply_foldr, Std.LawfulLeftIdentity.left_id]
theorem sum_eq_foldl [Zero α] [Add α] [Std.Associative (α := α) (· + ·)]
[Std.LawfulIdentity (· + ·) (0 : α)] {xs : List α} :
xs.sum = xs.foldl (init := 0) (· + ·) := by
simp [sum_eq_foldr, foldl_eq_apply_foldr, Std.LawfulLeftIdentity.left_id]
-- The argument `f : α₁ → α₂` is intentionally explicit, as it is sometimes not found by unification.
theorem foldl_hom (f : α₁ α₂) {g₁ : α₁ β α₁} {g₂ : α₂ β α₂} {l : List β} {init : α₁}
(H : x y, g₂ (f x) y = f (g₁ x y)) : l.foldl g₂ (f init) = f (l.foldl g₁ init) := by
@@ -3124,7 +3154,7 @@ theorem dropLast_concat_getLast : ∀ {l : List α} (h : l ≠ []), dropLast l +
| [], h => absurd rfl h
| [_], _ => rfl
| _ :: b :: l, _ => by
rw [dropLast_cons, cons_append, getLast_cons (cons_ne_nil _ _)]
rw [dropLast_cons_cons, cons_append, getLast_cons (cons_ne_nil _ _)]
congr
exact dropLast_concat_getLast (cons_ne_nil b l)
@@ -3744,4 +3774,28 @@ theorem get_mem : ∀ (l : List α) n, get l n ∈ l
theorem mem_iff_get {a} {l : List α} : a l n, get l n = a :=
get_of_mem, fun _, e => e get_mem ..
/-! ### `intercalate` -/
@[simp]
theorem intercalate_nil {ys : List α} : ys.intercalate [] = [] := rfl
@[simp]
theorem intercalate_singleton {ys xs : List α} : ys.intercalate [xs] = xs := by
simp [intercalate]
@[simp]
theorem intercalate_cons_cons {ys l l' : List α} {zs : List (List α)} :
ys.intercalate (l :: l' :: zs) = l ++ ys ++ ys.intercalate (l' :: zs) := by
simp [intercalate]
@[simp]
theorem intercalate_cons_cons_left {ys l : List α} {x : α} {zs : List (List α)} :
ys.intercalate ((x :: l) :: zs) = x :: ys.intercalate (l :: zs) := by
cases zs <;> simp
theorem intercalate_cons_of_ne_nil {ys l : List α} {zs : List (List α)} (h : zs ≠ []) :
ys.intercalate (l :: zs) = l ++ ys ++ ys.intercalate zs :=
match zs, h with
| l'::zs, _ => by simp
end List

View File

@@ -42,7 +42,7 @@ theorem beq_eq_isEqv [BEq α] {as bs : List α} : as.beq bs = isEqv as bs (· ==
cases bs with
| nil => simp
| cons b bs =>
simp only [beq_cons, ih, isEqv_eq_decide, length_cons, Nat.add_right_cancel_iff,
simp only [beq_cons_cons, ih, isEqv_eq_decide, length_cons, Nat.add_right_cancel_iff,
Nat.forall_lt_succ_left', getElem_cons_zero, getElem_cons_succ, Bool.decide_and,
Bool.decide_eq_true]
split <;> simp

View File

@@ -106,7 +106,7 @@ theorem Sublist.le_countP (s : l₁ <+ l₂) (p) : countP p l₂ - (l₂.length
have := s.le_countP p
have := s.length_le
split <;> omega
| .cons a s =>
| .cons_cons a s =>
rename_i l₁ l₂
simp only [countP_cons, length_cons]
have := s.le_countP p

View File

@@ -38,7 +38,7 @@ theorem map_getElem_sublist {l : List α} {is : List (Fin l.length)} (h : is.Pai
simp only [Fin.getElem_fin, map_cons]
have := IH h.of_cons (hd+1) (pairwise_cons.mp h).1
specialize his hd (.head _)
have := (drop_eq_getElem_cons ..).symm this.cons (get l hd)
have := (drop_eq_getElem_cons ..).symm this.cons_cons (get l hd)
have := Sublist.append (nil_sublist (take hd l |>.drop j)) this
rwa [nil_append, (drop_append_of_le_length ?_), take_append_drop] at this
simp [Nat.min_eq_left (Nat.le_of_lt hd.isLt), his]
@@ -55,7 +55,7 @@ theorem sublist_eq_map_getElem {l l' : List α} (h : l' <+ l) : ∃ is : List (F
refine is.map (·.succ), ?_
set_option backward.isDefEq.respectTransparency false in
simpa [Function.comp_def, pairwise_map]
| cons _ _ IH =>
| cons_cons _ _ IH =>
rcases IH with is,IH
refine 0, by simp [Nat.zero_lt_succ] :: is.map (·.succ), ?_
set_option backward.isDefEq.respectTransparency false in

View File

@@ -207,7 +207,7 @@ theorem take_eq_dropLast {l : List α} {i : Nat} (h : i + 1 = l.length) :
· cases as with
| nil => simp_all
| cons b bs =>
simp only [take_succ_cons, dropLast_cons]
simp only [take_succ_cons, dropLast_cons_cons]
rw [ih]
simpa using h

View File

@@ -33,7 +33,7 @@ open Nat
@[grind →] theorem Pairwise.sublist : l₁ <+ l₂ → l₂.Pairwise R → l₁.Pairwise R
| .slnil, h => h
| .cons _ s, .cons _ h₂ => h₂.sublist s
| .cons _ s, .cons h₁ h₂ => (h₂.sublist s).cons fun _ h => h₁ _ (s.subset h)
| .cons_cons _ s, .cons h₁ h₂ => (h₂.sublist s).cons fun _ h => h₁ _ (s.subset h)
theorem Pairwise.imp {α R S} (H : {a b}, R a b S a b) :
{l : List α}, l.Pairwise R l.Pairwise S
@@ -226,7 +226,7 @@ theorem pairwise_iff_forall_sublist : l.Pairwise R ↔ (∀ {a b}, [a,b] <+ l
constructor <;> intro h
· intro
| a, b, .cons _ hab => exact IH.mp h.2 hab
| _, b, .cons _ hab => refine h.1 _ (hab.subset ?_); simp
| _, b, .cons_cons _ hab => refine h.1 _ (hab.subset ?_); simp
· constructor
· intro x hx
apply h

View File

@@ -252,13 +252,13 @@ theorem exists_perm_sublist {l₁ l₂ l₂' : List α} (s : l₁ <+ l₂) (p :
| cons x _ IH =>
match s with
| .cons _ s => let l₁', p', s' := IH s; exact l₁', p', s'.cons _
| .cons _ s => let l₁', p', s' := IH s; exact x :: l₁', p'.cons x, s'.cons _
| .cons_cons _ s => let l₁', p', s' := IH s; exact x :: l₁', p'.cons x, s'.cons_cons _
| swap x y l' =>
match s with
| .cons _ (.cons _ s) => exact _, .rfl, (s.cons _).cons _
| .cons _ (.cons _ s) => exact x :: _, .rfl, (s.cons _).cons _
| .cons _ (.cons _ s) => exact y :: _, .rfl, (s.cons _).cons _
| .cons _ (.cons _ s) => exact x :: y :: _, .swap .., (s.cons _).cons _
| .cons _ (.cons_cons _ s) => exact x :: _, .rfl, (s.cons _).cons_cons _
| .cons_cons _ (.cons _ s) => exact y :: _, .rfl, (s.cons_cons _).cons _
| .cons_cons _ (.cons_cons _ s) => exact x :: y :: _, .swap .., (s.cons_cons _).cons_cons _
| trans _ _ IH₁ IH₂ =>
let _, pm, sm := IH₁ s
let r₁, pr, sr := IH₂ sm
@@ -277,7 +277,7 @@ theorem Sublist.exists_perm_append {l₁ l₂ : List α} : l₁ <+ l₂ → ∃
| Sublist.cons a s =>
let l, p := Sublist.exists_perm_append s
a :: l, (p.cons a).trans perm_middle.symm
| Sublist.cons a s =>
| Sublist.cons_cons a s =>
let l, p := Sublist.exists_perm_append s
l, p.cons a

View File

@@ -452,7 +452,7 @@ theorem sublist_mergeSort
have h' := sublist_mergeSort trans total hc h
rw [h₂] at h'
exact h'.middle a
| _, _, @Sublist.cons _ l₁ l₂ a h => by
| _, _, @Sublist.cons_cons _ l₁ l₂ a h => by
rename_i hc
obtain l₃, l₄, h₁, h₂, h₃ := mergeSort_cons trans total a l₂
rw [h₁]
@@ -460,7 +460,7 @@ theorem sublist_mergeSort
rw [h₂] at h'
simp only [Bool.not_eq_true', tail_cons] at h₃ h'
exact
sublist_append_of_sublist_right (Sublist.cons a
sublist_append_of_sublist_right (Sublist.cons_cons a
((fun w => Sublist.of_sublist_append_right w h') fun b m₁ m₃ =>
(Bool.eq_not_self true).mp ((rel_of_pairwise_cons hc m₁).symm.trans (h₃ b m₃))))

View File

@@ -0,0 +1,10 @@
/-
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Markus Himmel
-/
module
prelude
public import Init.Data.List.SplitOn.Basic
public import Init.Data.List.SplitOn.Lemmas

View File

@@ -0,0 +1,70 @@
/-
Copyright (c) 2016 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Init.Data.List.Basic
public import Init.NotationExtra
import Init.Data.Array.Bootstrap
import Init.Data.List.Lemmas
public section
set_option doc.verso true
namespace List
/--
Split a list at every element satisfying a predicate, and then prepend {lean}`acc.reverse` to the
first element of the result.
* {lean}`[1, 1, 2, 3, 2, 4, 4].splitOnPPrepend (· == 2) [0, 5] = [[5, 0, 1, 1], [3], [4, 4]]`
-/
-- NOTE(review): non-ASCII glyphs (e.g. `→`) appear stripped by extraction in this view;
-- `α Bool` is presumably `α → Bool` — confirm against the repository file.
noncomputable def splitOnPPrepend (p : α Bool) : (l : List α) (acc : List α) List (List α)
-- End of input: flush the accumulator. It is reversed because elements are pushed front-first.
| [], acc => [acc.reverse]
-- Separator hit: emit the finished chunk and restart with an empty accumulator;
-- otherwise push the element onto the accumulator and keep scanning.
| a :: t, acc => if p a then acc.reverse :: splitOnPPrepend p t [] else splitOnPPrepend p t (a::acc)
/--
Split a list at every element satisfying a predicate. The separators are not in the result.
Examples:
* {lean}`[1, 1, 2, 3, 2, 4, 4].splitOnP (· == 2) = [[1, 1], [3], [4, 4]]`
-/
-- Thin wrapper: delegates to `splitOnPPrepend` with an empty accumulator.
noncomputable def splitOnP (p : α Bool) (l : List α) : List (List α) :=
splitOnPPrepend p l []
-- Deprecated compatibility alias: the former worker `splitOnP.go` is now the
-- free-standing `splitOnPPrepend`; this forwards verbatim for old callers.
@[deprecated splitOnPPrepend (since := "2026-02-26")]
noncomputable def splitOnP.go (p : α Bool) (l acc : List α) : List (List α) :=
splitOnPPrepend p l acc
/-- Tail recursive version of {name}`splitOnP`. -/
@[inline]
def splitOnPTR (p : α Bool) (l : List α) : List (List α) := go l #[] #[] where
-- `acc` accumulates the current chunk, `r` accumulates finished chunks;
-- arrays avoid the per-step `reverse` of the reference implementation.
@[specialize] go : List α Array α Array (List α) List (List α)
| [], acc, r => r.toListAppend [acc.toList]
| a :: t, acc, r => bif p a then go t #[] (r.push acc.toList) else go t (acc.push a) r
-- `@[csimp]` replaces `splitOnP` by the tail-recursive `splitOnPTR` in compiled code.
-- The proof generalizes to an invariant relating `splitOnPTR.go`'s array state to
-- `splitOnPPrepend`'s list accumulator, then inducts on the input list.
@[csimp] theorem splitOnP_eq_splitOnPTR : @splitOnP = @splitOnPTR := by
funext α P l
simp only [splitOnPTR]
suffices xs acc r,
splitOnPTR.go P xs acc r = r.toList ++ splitOnPPrepend P xs acc.toList.reverse from
(this l #[] #[]).symm
intro xs acc r
induction xs generalizing acc r with
| nil => simp [splitOnPPrepend, splitOnPTR.go]
| cons x xs IH => cases h : P x <;> simp [splitOnPPrepend, splitOnPTR.go, *]
/--
Split a list at every occurrence of a separator element. The separators are not in the result.
Examples:
* {lean}`[1, 1, 2, 3, 2, 4, 4].splitOn 2 = [[1, 1], [3], [4, 4]]`
-/
-- Special case of `splitOnP` where the predicate is `==`-equality with `a`.
@[inline] def splitOn [BEq α] (a : α) (as : List α) : List (List α) :=
as.splitOnP (· == a)
end List

View File

@@ -0,0 +1,208 @@
/-
Copyright (c) 2014 Parikshit Khanna. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Parikshit Khanna, Jeremy Avigad, Leonardo de Moura, Floris van Doorn, Mario Carneiro, Markus Himmel
-/
module
prelude
public import Init.Data.List.SplitOn.Basic
import all Init.Data.List.SplitOn.Basic
import Init.Data.List.Nat.Modify
import Init.ByCases
public section
namespace List
variable {p : α Bool} {xs : List α} {ls : List (List α)}
-- Behavior on the empty list: both splitters return one empty chunk, `[[]]`, not `[]`.
@[simp]
theorem splitOn_nil [BEq α] (a : α) : [].splitOn a = [[]] :=
(rfl)
@[simp]
theorem splitOnP_nil : [].splitOnP p = [[]] :=
(rfl)
-- The result is never empty: there is always at least the (possibly empty) final chunk.
-- NOTE(review): the `≠` glyph appears stripped by extraction in the statements below.
@[simp]
theorem splitOnPPrepend_ne_nil (p : α Bool) (xs acc : List α) : splitOnPPrepend p xs acc [] := by
fun_induction splitOnPPrepend <;> simp_all
@[deprecated splitOnPPrepend_ne_nil (since := "2026-02-26")]
theorem splitOnP.go_ne_nil (p : α Bool) (xs acc : List α) : splitOnPPrepend p xs acc [] :=
splitOnPPrepend_ne_nil p xs acc
-- Definitional unfoldings relating `splitOnPPrepend` and `splitOnP` (all hold by `rfl`).
@[simp] theorem splitOnPPrepend_nil {acc : List α} : splitOnPPrepend p [] acc = [acc.reverse] := (rfl)
@[simp] theorem splitOnPPrepend_nil_right : splitOnPPrepend p xs [] = splitOnP p xs := (rfl)
theorem splitOnP_eq_splitOnPPrepend : splitOnP p xs = splitOnPPrepend p xs [] := (rfl)
-- One-step unfolding of `splitOnPPrepend` on a cons, as an `if`.
theorem splitOnPPrepend_cons_eq_if {x : α} {xs acc : List α} :
splitOnPPrepend p (x :: xs) acc =
if p x then acc.reverse :: splitOnP p xs else splitOnPPrepend p xs (x :: acc) := by
simp [splitOnPPrepend]
-- The two branches of the `if` above, split out as separate rewrite lemmas.
theorem splitOnPPrepend_cons_pos {p : α Bool} {a : α} {l acc : List α} (h : p a) :
splitOnPPrepend p (a :: l) acc = acc.reverse :: splitOnP p l := by
simp [splitOnPPrepend, h]
theorem splitOnPPrepend_cons_neg {p : α Bool} {a : α} {l acc : List α} (h : p a = false) :
splitOnPPrepend p (a :: l) acc = splitOnPPrepend p l (a :: acc) := by
simp [splitOnPPrepend, h]
-- `splitOnP` on a cons, still phrased via `splitOnPPrepend` in the negative branch.
theorem splitOnP_cons_eq_if_splitOnPPrepend {x : α} {xs : List α} :
splitOnP p (x :: xs) = if p x then [] :: splitOnP p xs else splitOnPPrepend p xs [x] := by
simp [splitOnPPrepend_cons_eq_if, splitOnPPrepend_nil_right]
-- The accumulator only ever affects the head chunk: `splitOnPPrepend` is `splitOnP`
-- with `acc.reverse` prepended to the first chunk via `modifyHead`.
theorem splitOnPPrepend_eq_modifyHead {xs acc : List α} :
splitOnPPrepend p xs acc = modifyHead (acc.reverse ++ ·) (splitOnP p xs) := by
induction xs generalizing acc with
| nil => simp
| cons hd tl ih =>
simp [splitOnPPrepend_cons_eq_if, splitOnP_cons_eq_if_splitOnPPrepend, ih]
split <;> simp <;> congr
@[deprecated splitOnPPrepend_eq_modifyHead (since := "2026-02-26")]
theorem splitOnP.go_acc {xs acc : List α} :
splitOnPPrepend p xs acc = modifyHead (acc.reverse ++ ·) (splitOnP p xs) :=
splitOnPPrepend_eq_modifyHead
-- `splitOnP` inherits non-emptiness from `splitOnPPrepend`.
@[simp]
theorem splitOnP_ne_nil (p : α Bool) (xs : List α) : xs.splitOnP p [] :=
splitOnPPrepend_ne_nil p xs []
-- Self-contained cons unfolding of `splitOnP`, with `modifyHead` in the negative branch.
theorem splitOnP_cons_eq_if_modifyHead (x : α) (xs : List α) :
(x :: xs).splitOnP p =
if p x then [] :: xs.splitOnP p else (xs.splitOnP p).modifyHead (cons x) := by
simp [splitOnP_cons_eq_if_splitOnPPrepend, splitOnPPrepend_eq_modifyHead]
@[deprecated splitOnP_cons_eq_if_modifyHead (since := "2026-02-26")]
theorem splitOnP_cons (x : α) (xs : List α) :
(x :: xs).splitOnP p =
if p x then [] :: xs.splitOnP p else (xs.splitOnP p).modifyHead (cons x) :=
splitOnP_cons_eq_if_modifyHead x xs
/-- The original list `L` can be recovered by flattening the lists produced by `splitOnP p L`,
interspersed with the elements `L.filter p`. -/
theorem splitOnP_spec (as : List α) :
flatten (zipWith (· ++ ·) (splitOnP p as) (((as.filter p).map fun x => [x]) ++ [[]])) = as := by
induction as with
| nil => simp
| cons a as' ih =>
rw [splitOnP_cons_eq_if_modifyHead]
split <;> simp [*, flatten_zipWith, splitOnP_ne_nil]
where
-- Auxiliary: prepending `a` to the head chunk commutes with the zip/flatten,
-- provided both lists are nonempty so that `zipWith` does not drop the head.
flatten_zipWith {xs ys : List (List α)} {a : α} (hxs : xs []) (hys : ys []) :
flatten (zipWith (fun x x_1 => x ++ x_1) (modifyHead (cons a) xs) ys) =
a :: flatten (zipWith (fun x x_1 => x ++ x_1) xs ys) := by
cases xs <;> cases ys <;> simp_all
/-- If no element satisfies `p` in the list `xs`, then `xs.splitOnP p = [xs]` -/
theorem splitOnP_eq_singleton (h : x xs, p x = false) : xs.splitOnP p = [xs] := by
induction xs with
| nil => simp
| cons hd tl ih =>
simp only [mem_cons, forall_eq_or_imp] at h
simp [splitOnP_cons_eq_if_modifyHead, h.1, ih h.2]
@[deprecated splitOnP_eq_singleton (since := "2026-02-26")]
theorem splitOnP_eq_single (h : x xs, p x = false) : xs.splitOnP p = [xs] :=
splitOnP_eq_singleton h
/-- When a list of the form `[...xs, sep, ...as]` is split at the `sep` element satisfying `p`,
the result is the concatenation of `splitOnP` called on `xs` and `as` -/
theorem splitOnP_append_cons (xs as : List α) {sep : α} (hsep : p sep) :
(xs ++ sep :: as).splitOnP p = List.splitOnP p xs ++ List.splitOnP p as := by
induction xs with
| nil => simp [splitOnP_cons_eq_if_modifyHead, hsep]
| cons hd tl ih =>
-- Expose a head chunk of `tl.splitOnP p` (nonempty by `splitOnP_ne_nil`) so that
-- `modifyHead` in the cons unfolding can be discharged by `simp`.
obtain hd1, tl1, h1' := List.exists_cons_of_ne_nil (List.splitOnP_ne_nil (p := p) (xs := tl))
by_cases hPh : p hd <;> simp [splitOnP_cons_eq_if_modifyHead, *]
/-- When a list of the form `[...xs, sep, ...as]` is split on `p`, the first element is `xs`,
assuming no element in `xs` satisfies `p` but `sep` does satisfy `p` -/
theorem splitOnP_append_cons_of_forall_mem (h : x xs, p x = false) (sep : α)
(hsep : p sep = true) (as : List α) : (xs ++ sep :: as).splitOnP p = xs :: as.splitOnP p := by
rw [splitOnP_append_cons xs as hsep, splitOnP_eq_singleton h, singleton_append]
@[deprecated splitOnP_append_cons_of_forall_mem (since := "2026-02-26")]
theorem splitOnP_first (h : x xs, p x = false) (sep : α)
(hsep : p sep = true) (as : List α) : (xs ++ sep :: as).splitOnP p = xs :: as.splitOnP p :=
splitOnP_append_cons_of_forall_mem h sep hsep as
-- `splitOn` lemmas: each is the `(· == a)` instance of the corresponding `splitOnP` lemma.
theorem splitOn_eq_splitOnP [BEq α] {x : α} {xs : List α} : xs.splitOn x = xs.splitOnP (· == x) :=
(rfl)
@[simp]
theorem splitOn_ne_nil [BEq α] (a : α) (xs : List α) : xs.splitOn a [] := by
simp [splitOn_eq_splitOnP]
theorem splitOn_cons_eq_if_modifyHead [BEq α] {a : α} (x : α) (xs : List α) :
(x :: xs).splitOn a =
if x == a then [] :: xs.splitOn a else (xs.splitOn a).modifyHead (cons x) := by
simpa [splitOn_eq_splitOnP] using splitOnP_cons_eq_if_modifyHead ..
/-- If no element of `xs` is `==`-equal to `a`, then `xs.splitOn a = [xs]` -/
theorem splitOn_eq_singleton_of_beq_eq_false [BEq α] {a : α} (h : x xs, (x == a) = false) :
xs.splitOn a = [xs] := by
simpa [splitOn_eq_splitOnP] using splitOnP_eq_singleton h
-- With `LawfulBEq`, the pointwise `== a = false` hypothesis reduces to non-membership.
theorem splitOn_eq_singleton [BEq α] [LawfulBEq α] {a : α} (h : a xs) :
xs.splitOn a = [xs] :=
splitOn_eq_singleton_of_beq_eq_false
(fun _ hb => beq_eq_false_iff_ne.2 (fun hab => absurd hb (hab h)))
/-- When a list of the form `[...xs, sep, ...as]` is split at the `sep` element equal to `a`,
the result is the concatenation of `splitOn` called on `xs` and `as` -/
theorem splitOn_append_cons_of_beq [BEq α] {a : α} (xs as : List α) {sep : α} (hsep : sep == a) :
(xs ++ sep :: as).splitOn a = List.splitOn a xs ++ List.splitOn a as := by
simpa [splitOn_eq_splitOnP] using splitOnP_append_cons (p := (· == a)) _ _ hsep
/-- When a list of the form `[...xs, a, ...as]` is split at `a`,
the result is the concatenation of `splitOn` called on `xs` and `as` -/
theorem splitOn_append_cons_self [BEq α] [ReflBEq α] {a : α} (xs as : List α) :
(xs ++ a :: as).splitOn a = List.splitOn a xs ++ List.splitOn a as :=
splitOn_append_cons_of_beq _ _ (BEq.refl _)
/-- When a list of the form `[...xs, sep, ...as]` is split at `a`, the first element is `xs`,
assuming no element in `xs` is equal to `a` but `sep` is equal to `a`. -/
theorem splitOn_append_cons_of_forall_mem_beq_eq_false [BEq α] {a : α}
(h : x xs, (x == a) = false) (sep : α)
(hsep : sep == a) (as : List α) : (xs ++ sep :: as).splitOn a = xs :: as.splitOn a := by
simpa [splitOn_eq_splitOnP] using splitOnP_append_cons_of_forall_mem h _ hsep _
/-- When a list of the form `[...xs, a, ...as]` is split at `a`, the first element is `xs`,
assuming no element in `xs` is equal to `a`. -/
theorem splitOn_append_cons_self_of_not_mem [BEq α] [LawfulBEq α] {a : α}
(h : a xs) (as : List α) : (xs ++ a :: as).splitOn a = xs :: as.splitOn a :=
splitOn_append_cons_of_forall_mem_beq_eq_false
(fun b hb => beq_eq_false_iff_ne.2 fun hab => absurd hb (hab h)) _ (by simp) _
/-- `intercalate [x]` is the left inverse of `splitOn x` -/
@[simp]
theorem intercalate_splitOn [BEq α] [LawfulBEq α] (x : α) : [x].intercalate (xs.splitOn x) = xs := by
induction xs with
| nil => simp
| cons hd tl ih =>
simp only [splitOn_cons_eq_if_modifyHead, beq_iff_eq]
split
· simp_all [intercalate_cons_of_ne_nil, splitOn_ne_nil]
-- Negative branch: destructure the (nonempty) `splitOn x tl` so that `modifyHead`
-- and `intercalate` on a cons can be unfolded by `simp_all`.
· have hsp := splitOn_ne_nil x tl
generalize splitOn x tl = ls at *
cases ls <;> simp_all
/-- `splitOn x` is the left inverse of `intercalate [x]`, on the domain
consisting of each nonempty list of lists `ls` whose elements do not contain `x` -/
theorem splitOn_intercalate [BEq α] [LawfulBEq α] (x : α) (hx : l ls, x l) (hls : ls []) :
([x].intercalate ls).splitOn x = ls := by
induction ls with
| nil => simp at hls
| cons hd tl ih =>
simp only [mem_cons, forall_eq_or_imp] at hx
-- Case on whether `hd` is the last chunk; otherwise peel one `x` separator off
-- via `splitOn_append_cons_self_of_not_mem` and recurse.
match tl with
| [] => simpa using splitOn_eq_singleton hx.1
| t::tl =>
simp only [intercalate_cons_cons, append_assoc, cons_append, nil_append]
rw [splitOn_append_cons_self_of_not_mem hx.1, ih hx.2 (by simp)]
end List

View File

@@ -32,8 +32,12 @@ open Nat
section isPrefixOf
variable [BEq α]
@[simp, grind =] theorem isPrefixOf_cons_self [LawfulBEq α] {a : α} :
isPrefixOf (a::as) (a::bs) = isPrefixOf as bs := by simp [isPrefixOf_cons]
@[simp, grind =] theorem isPrefixOf_cons_cons_self [LawfulBEq α] {a : α} :
isPrefixOf (a::as) (a::bs) = isPrefixOf as bs := by simp [isPrefixOf_cons_cons]
@[deprecated isPrefixOf_cons_cons_self (since := "2026-02-26")]
theorem isPrefixOf_cons₂_self [LawfulBEq α] {a : α} :
isPrefixOf (a::as) (a::bs) = isPrefixOf as bs := isPrefixOf_cons_cons_self
@[simp] theorem isPrefixOf_length_pos_nil {l : List α} (h : 0 < l.length) : isPrefixOf l [] = false := by
cases l <;> simp_all [isPrefixOf]
@@ -45,7 +49,7 @@ variable [BEq α]
| cons _ _ ih =>
cases n
· simp
· simp [replicate_succ, isPrefixOf_cons, ih, Nat.succ_le_succ_iff, Bool.and_left_comm]
· simp [replicate_succ, isPrefixOf_cons_cons, ih, Nat.succ_le_succ_iff, Bool.and_left_comm]
end isPrefixOf
@@ -169,18 +173,18 @@ theorem subset_replicate {n : Nat} {a : α} {l : List α} (h : n ≠ 0) : l ⊆
@[simp, grind ] theorem Sublist.refl : l : List α, l <+ l
| [] => .slnil
| a :: l => (Sublist.refl l).cons a
| a :: l => (Sublist.refl l).cons_cons a
theorem Sublist.trans {l₁ l₂ l₃ : List α} (h₁ : l₁ <+ l₂) (h₂ : l₂ <+ l₃) : l₁ <+ l₃ := by
induction h₂ generalizing l₁ with
| slnil => exact h₁
| cons _ _ IH => exact (IH h₁).cons _
| @cons l₂ _ a _ IH =>
| @cons_cons l₂ _ a _ IH =>
generalize e : a :: l₂ = l₂' at h₁
match h₁ with
| .slnil => apply nil_sublist
| .cons a' h₁' => cases e; apply (IH h₁').cons
| .cons a' h₁' => cases e; apply (IH h₁').cons
| .cons_cons a' h₁' => cases e; apply (IH h₁').cons_cons
instance : Trans (@Sublist α) Sublist Sublist := Sublist.trans
@@ -193,23 +197,23 @@ theorem sublist_of_cons_sublist : a :: l₁ <+ l₂ → l₁ <+ l₂ :=
@[simp, grind =]
theorem cons_sublist_cons : a :: l₁ <+ a :: l₂ l₁ <+ l₂ :=
fun | .cons _ s => sublist_of_cons_sublist s | .cons _ s => s, .cons _
fun | .cons _ s => sublist_of_cons_sublist s | .cons_cons _ s => s, .cons_cons _
theorem sublist_or_mem_of_sublist (h : l <+ l₁ ++ a :: l₂) : l <+ l₁ ++ l₂ a l := by
induction l₁ generalizing l with
| nil => match h with
| .cons _ h => exact .inl h
| .cons _ h => exact .inr (.head ..)
| .cons_cons _ h => exact .inr (.head ..)
| cons b l₁ IH =>
match h with
| .cons _ h => exact (IH h).imp_left (Sublist.cons _)
| .cons _ h => exact (IH h).imp (Sublist.cons _) (.tail _)
| .cons_cons _ h => exact (IH h).imp (Sublist.cons_cons _) (.tail _)
@[grind ] theorem Sublist.subset : l₁ <+ l₂ l₁ l₂
| .slnil, _, h => h
| .cons _ s, _, h => .tail _ (s.subset h)
| .cons .., _, .head .. => .head ..
| .cons _ s, _, .tail _ h => .tail _ (s.subset h)
| .cons_cons .., _, .head .. => .head ..
| .cons_cons _ s, _, .tail _ h => .tail _ (s.subset h)
protected theorem Sublist.mem (hx : a l₁) (hl : l₁ <+ l₂) : a l₂ :=
hl.subset hx
@@ -245,7 +249,7 @@ theorem eq_nil_of_sublist_nil {l : List α} (s : l <+ []) : l = [] :=
theorem Sublist.length_le : l₁ <+ l₂ length l₁ length l₂
| .slnil => Nat.le_refl 0
| .cons _l s => le_succ_of_le (length_le s)
| .cons _ s => succ_le_succ (length_le s)
| .cons_cons _ s => succ_le_succ (length_le s)
grind_pattern Sublist.length_le => l₁ <+ l₂, length l₁
grind_pattern Sublist.length_le => l₁ <+ l₂, length l₂
@@ -253,7 +257,7 @@ grind_pattern Sublist.length_le => l₁ <+ l₂, length l₂
theorem Sublist.eq_of_length : l₁ <+ l₂ length l₁ = length l₂ l₁ = l₂
| .slnil, _ => rfl
| .cons a s, h => nomatch Nat.not_lt.2 s.length_le (h lt_succ_self _)
| .cons a s, h => by rw [s.eq_of_length (succ.inj h)]
| .cons_cons a s, h => by rw [s.eq_of_length (succ.inj h)]
theorem Sublist.eq_of_length_le (s : l₁ <+ l₂) (h : length l₂ length l₁) : l₁ = l₂ :=
s.eq_of_length <| Nat.le_antisymm s.length_le h
@@ -275,7 +279,7 @@ grind_pattern tail_sublist => tail l <+ _
protected theorem Sublist.tail : {l₁ l₂ : List α}, l₁ <+ l₂ tail l₁ <+ tail l₂
| _, _, slnil => .slnil
| _, _, Sublist.cons _ h => (tail_sublist _).trans h
| _, _, Sublist.cons _ h => h
| _, _, Sublist.cons_cons _ h => h
@[grind ]
theorem Sublist.of_cons_cons {l₁ l₂ : List α} {a b : α} (h : a :: l₁ <+ b :: l₂) : l₁ <+ l₂ :=
@@ -287,8 +291,8 @@ protected theorem Sublist.map (f : α → β) {l₁ l₂} (s : l₁ <+ l₂) : m
| slnil => simp
| cons a s ih =>
simpa using cons (f a) ih
| cons a s ih =>
simpa using cons (f a) ih
| cons_cons a s ih =>
simpa using cons_cons (f a) ih
grind_pattern Sublist.map => l₁ <+ l₂, map f l₁
grind_pattern Sublist.map => l₁ <+ l₂, map f l₂
@@ -338,7 +342,7 @@ theorem sublist_filterMap_iff {l₁ : List β} {f : α → Option β} :
cases h with
| cons _ h =>
exact l', h, rfl
| cons _ h =>
| cons_cons _ h =>
rename_i l'
exact l', h, by simp_all
· constructor
@@ -347,10 +351,10 @@ theorem sublist_filterMap_iff {l₁ : List β} {f : α → Option β} :
| cons _ h =>
obtain l', s, rfl := ih.1 h
exact l', Sublist.cons a s, rfl
| cons _ h =>
| cons_cons _ h =>
rename_i l'
obtain l', s, rfl := ih.1 h
refine a :: l', Sublist.cons a s, ?_
refine a :: l', Sublist.cons_cons a s, ?_
rwa [filterMap_cons_some]
· rintro l', h, rfl
replace h := h.filterMap f
@@ -369,7 +373,7 @@ theorem sublist_filter_iff {l₁ : List α} {p : α → Bool} :
theorem sublist_append_left : l₁ l₂ : List α, l₁ <+ l₁ ++ l₂
| [], _ => nil_sublist _
| _ :: l₁, l₂ => (sublist_append_left l₁ l₂).cons _
| _ :: l₁, l₂ => (sublist_append_left l₁ l₂).cons_cons _
grind_pattern sublist_append_left => Sublist, l₁ ++ l₂
@@ -382,7 +386,7 @@ grind_pattern sublist_append_right => Sublist, l₁ ++ l₂
@[simp, grind =] theorem singleton_sublist {a : α} {l} : [a] <+ l a l := by
refine fun h => h.subset (mem_singleton_self _), fun h => ?_
obtain _, _, rfl := append_of_mem h
exact ((nil_sublist _).cons _).trans (sublist_append_right ..)
exact ((nil_sublist _).cons_cons _).trans (sublist_append_right ..)
@[simp] theorem sublist_append_of_sublist_left (s : l <+ l₁) : l <+ l₁ ++ l₂ :=
s.trans <| sublist_append_left ..
@@ -404,7 +408,7 @@ theorem Sublist.append_left : l₁ <+ l₂ → ∀ l, l ++ l₁ <+ l ++ l₂ :=
theorem Sublist.append_right : l₁ <+ l₂ l, l₁ ++ l <+ l₂ ++ l
| .slnil, _ => Sublist.refl _
| .cons _ h, _ => (h.append_right _).cons _
| .cons _ h, _ => (h.append_right _).cons _
| .cons_cons _ h, _ => (h.append_right _).cons_cons _
theorem Sublist.append (hl : l₁ <+ l₂) (hr : r₁ <+ r₂) : l₁ ++ r₁ <+ l₂ ++ r₂ :=
(hl.append_right _).trans ((append_sublist_append_left _).2 hr)
@@ -418,10 +422,10 @@ theorem sublist_cons_iff {a : α} {l l'} :
· intro h
cases h with
| cons _ h => exact Or.inl h
| cons _ h => exact Or.inr _, rfl, h
| cons_cons _ h => exact Or.inr _, rfl, h
· rintro (h | r, rfl, h)
· exact h.cons _
· exact h.cons _
· exact h.cons_cons _
@[grind =]
theorem cons_sublist_iff {a : α} {l l'} :
@@ -435,7 +439,7 @@ theorem cons_sublist_iff {a : α} {l l'} :
| cons _ w =>
obtain r₁, r₂, rfl, h₁, h₂ := ih.1 w
exact a' :: r₁, r₂, by simp, mem_cons_of_mem a' h₁, h₂
| cons _ w =>
| cons_cons _ w =>
exact [a], l', by simp, mem_singleton_self _, w
· rintro r₁, r₂, w, h₁, h₂
rw [w, singleton_append]
@@ -458,7 +462,7 @@ theorem sublist_append_iff {l : List α} :
| cons _ w =>
obtain l₁, l₂, rfl, w₁, w₂ := ih.1 w
exact l₁, l₂, rfl, Sublist.cons r w₁, w₂
| cons _ w =>
| cons_cons _ w =>
rename_i l
obtain l₁, l₂, rfl, w₁, w₂ := ih.1 w
refine r :: l₁, l₂, by simp, cons_sublist_cons.mpr w₁, w₂
@@ -466,9 +470,9 @@ theorem sublist_append_iff {l : List α} :
cases w₁ with
| cons _ w₁ =>
exact Sublist.cons _ (Sublist.append w₁ w₂)
| cons _ w₁ =>
| cons_cons _ w₁ =>
rename_i l
exact Sublist.cons _ (Sublist.append w₁ w₂)
exact Sublist.cons_cons _ (Sublist.append w₁ w₂)
theorem append_sublist_iff {l₁ l₂ : List α} :
l₁ ++ l₂ <+ r r₁ r₂, r = r₁ ++ r₂ l₁ <+ r₁ l₂ <+ r₂ := by
@@ -516,7 +520,7 @@ theorem Sublist.middle {l : List α} (h : l <+ l₁ ++ l₂) (a : α) : l <+ l
theorem Sublist.reverse : l₁ <+ l₂ l₁.reverse <+ l₂.reverse
| .slnil => Sublist.refl _
| .cons _ h => by rw [reverse_cons]; exact sublist_append_of_sublist_left h.reverse
| .cons _ h => by rw [reverse_cons, reverse_cons]; exact h.reverse.append_right _
| .cons_cons _ h => by rw [reverse_cons, reverse_cons]; exact h.reverse.append_right _
@[simp, grind =] theorem reverse_sublist : l₁.reverse <+ l₂.reverse l₁ <+ l₂ :=
fun h => l₁.reverse_reverse l₂.reverse_reverse h.reverse, Sublist.reverse
@@ -558,7 +562,7 @@ theorem sublist_replicate_iff : l <+ replicate m a ↔ ∃ n, n ≤ m ∧ l = re
obtain n, le, rfl := ih.1 (sublist_of_cons_sublist w)
obtain rfl := (mem_replicate.1 (mem_of_cons_sublist w)).2
exact n+1, Nat.add_le_add_right le 1, rfl
| cons _ w =>
| cons_cons _ w =>
obtain n, le, rfl := ih.1 w
refine n+1, Nat.add_le_add_right le 1, by simp [replicate_succ]
· rintro n, le, w
@@ -644,7 +648,7 @@ theorem flatten_sublist_iff {L : List (List α)} {l} :
cases h_sub
case cons h_sub =>
exact isSublist_iff_sublist.mpr h_sub
case cons =>
case cons_cons =>
contradiction
instance [DecidableEq α] (l₁ l₂ : List α) : Decidable (l₁ <+ l₂) :=

View File

@@ -393,7 +393,7 @@ theorem isPrefixOfAux_toArray_zero [BEq α] (l₁ l₂ : List α) (hle : l₁.le
| [], _ => rw [dif_neg] <;> simp
| _::_, [] => simp at hle
| a::l₁, b::l₂ =>
simp [isPrefixOf_cons, isPrefixOfAux_toArray_succ', isPrefixOfAux_toArray_zero]
simp [isPrefixOf_cons_cons, isPrefixOfAux_toArray_succ', isPrefixOfAux_toArray_zero]
@[simp, grind =] theorem isPrefixOf_toArray [BEq α] (l₁ l₂ : List α) :
l₁.toArray.isPrefixOf l₂.toArray = l₁.isPrefixOf l₂ := by
@@ -407,7 +407,7 @@ theorem isPrefixOfAux_toArray_zero [BEq α] (l₁ l₂ : List α) (hle : l₁.le
cases l₂ with
| nil => simp
| cons b l₂ =>
simp only [isPrefixOf_cons, Bool.and_eq_false_imp]
simp only [isPrefixOf_cons_cons, Bool.and_eq_false_imp]
intro w
rw [ih]
simp_all

View File

@@ -369,6 +369,12 @@ theorem String.ofList_toList {s : String} : String.ofList s.toList = s := by
theorem String.asString_data {b : String} : String.ofList b.toList = b :=
String.ofList_toList
@[simp]
theorem String.ofList_comp_toList : String.ofList String.toList = id := by ext; simp
@[simp]
theorem String.toList_comp_ofList : String.toList String.ofList = id := by ext; simp
theorem String.ofList_injective {l₁ l₂ : List Char} (h : String.ofList l₁ = String.ofList l₂) : l₁ = l₂ := by
simpa using congrArg String.toList h
@@ -1525,6 +1531,11 @@ def Slice.Pos.toReplaceEnd {s : Slice} (p₀ : s.Pos) (pos : s.Pos) (h : pos ≤
theorem Slice.Pos.offset_sliceTo {s : Slice} {p₀ : s.Pos} {pos : s.Pos} {h : pos p₀} :
(sliceTo p₀ pos h).offset = pos.offset := (rfl)
@[simp]
theorem Slice.Pos.sliceTo_inj {s : Slice} {p₀ : s.Pos} {pos pos' : s.Pos} {h h'} :
p₀.sliceTo pos h = p₀.sliceTo pos' h' pos = pos' := by
simp [Pos.ext_iff]
@[simp]
theorem Slice.Pos.ofSliceTo_startPos {s : Slice} {pos : s.Pos} :
ofSliceTo (s.sliceTo pos).startPos = s.startPos := by
@@ -2347,6 +2358,16 @@ theorem Slice.Pos.ofSliceTo_le {s : Slice} {p₀ : s.Pos} {pos : (s.sliceTo p₀
ofSliceTo pos p₀ := by
simpa [Pos.le_iff, Pos.Raw.le_iff] using pos.isValidForSlice.le_utf8ByteSize
@[simp]
theorem Pos.ofSliceTo_lt_ofSliceTo_iff {s : String} {p : s.Pos}
{q r : (s.sliceTo p).Pos} : Pos.ofSliceTo q < Pos.ofSliceTo r q < r := by
simp [Pos.lt_iff, Slice.Pos.lt_iff, Pos.Raw.lt_iff]
@[simp]
theorem Pos.ofSliceTo_le_ofSliceTo_iff {s : String} {p : s.Pos}
{q r : (s.sliceTo p).Pos} : Pos.ofSliceTo q Pos.ofSliceTo r q r := by
simp [Pos.le_iff, Slice.Pos.le_iff, Pos.Raw.le_iff]
/-- Given a position in `s` that is at most `p₀`, obtain the corresponding position in `s.sliceTo p₀`. -/
@[inline]
def Pos.sliceTo {s : String} (p₀ : s.Pos) (pos : s.Pos) (h : pos p₀) :
@@ -2363,6 +2384,11 @@ def Pos.toReplaceEnd {s : String} (p₀ : s.Pos) (pos : s.Pos) (h : pos ≤ p₀
theorem Pos.offset_sliceTo {s : String} {p₀ : s.Pos} {pos : s.Pos} {h : pos p₀} :
(sliceTo p₀ pos h).offset = pos.offset := (rfl)
@[simp]
theorem Pos.sliceTo_inj {s : String} {p₀ : s.Pos} {pos pos' : s.Pos} {h h'} :
p₀.sliceTo pos h = p₀.sliceTo pos' h' pos = pos' := by
simp [Pos.ext_iff, Slice.Pos.ext_iff]
@[simp]
theorem Slice.Pos.ofSliceTo_sliceTo {s : Slice} {p₀ p : s.Pos} {h : p p₀} :
Slice.Pos.ofSliceTo (p₀.sliceTo p h) = p := by
@@ -2431,6 +2457,27 @@ theorem Slice.Pos.ofSlice_inj {s : Slice} {p₀ p₁ : s.Pos} {h} (pos₁ pos₂
ofSlice pos₁ = ofSlice pos₂ pos₁ = pos₂ := by
simp [Pos.ext_iff, Pos.Raw.ext_iff]
@[simp]
theorem Slice.Pos.le_ofSlice {s : Slice} {p₀ p₁ : s.Pos} {h}
{pos : (s.slice p₀ p₁ h).Pos} : p₀ ofSlice pos := by
simp [Pos.le_iff, Pos.Raw.le_iff]
@[simp]
theorem Slice.Pos.ofSlice_le {s : Slice} {p₀ p₁ : s.Pos} {h}
{pos : (s.slice p₀ p₁ h).Pos} : ofSlice pos p₁ := by
have := (Pos.Raw.isValidForSlice_slice _).1 pos.isValidForSlice |>.1
simpa [Pos.le_iff, Pos.Raw.le_iff]
@[simp]
theorem Slice.Pos.ofSlice_lt_ofSlice_iff {s : Slice} {p₀ p₁ : s.Pos} {h}
{q r : (s.slice p₀ p₁ h).Pos} : Slice.Pos.ofSlice q < Slice.Pos.ofSlice r q < r := by
simp [Slice.Pos.lt_iff, Pos.Raw.lt_iff]
@[simp]
theorem Slice.Pos.ofSlice_le_ofSlice_iff {s : Slice} {p₀ p₁ : s.Pos} {h}
{q r : (s.slice p₀ p₁ h).Pos} : Slice.Pos.ofSlice q Slice.Pos.ofSlice r q r := by
simp [Slice.Pos.le_iff, Pos.Raw.le_iff]
/-- Given a position in `s.slice p₀ p₁ h`, obtain the corresponding position in `s`. -/
@[inline]
def Pos.ofSlice {s : String} {p₀ p₁ : s.Pos} {h} (pos : (s.slice p₀ p₁ h).Pos) : s.Pos :=
@@ -2461,6 +2508,27 @@ theorem Pos.ofSlice_inj {s : String} {p₀ p₁ : s.Pos} {h} (pos₁ pos₂ : (s
ofSlice pos₁ = ofSlice pos₂ pos₁ = pos₂ := by
simp [Pos.ext_iff, Pos.Raw.ext_iff, Slice.Pos.ext_iff]
@[simp]
theorem Pos.le_ofSlice {s : String} {p₀ p₁ : s.Pos} {h}
{pos : (s.slice p₀ p₁ h).Pos} : p₀ ofSlice pos := by
simp [Pos.le_iff, Pos.Raw.le_iff]
@[simp]
theorem Pos.ofSlice_le {s : String} {p₀ p₁ : s.Pos} {h}
{pos : (s.slice p₀ p₁ h).Pos} : ofSlice pos p₁ := by
have := (Pos.Raw.isValidForSlice_slice _).1 pos.isValidForSlice |>.1
simpa [Pos.le_iff, Pos.Raw.le_iff]
@[simp]
theorem Pos.ofSlice_lt_ofSlice_iff {s : String} {p₀ p₁ : s.Pos} {h}
{q r : (s.slice p₀ p₁ h).Pos} : Pos.ofSlice q < Pos.ofSlice r q < r := by
simp [Pos.lt_iff, Slice.Pos.lt_iff, Pos.Raw.lt_iff]
@[simp]
theorem Pos.ofSlice_le_ofSlice_iff {s : String} {p₀ p₁ : s.Pos} {h}
{q r : (s.slice p₀ p₁ h).Pos} : Pos.ofSlice q Pos.ofSlice r q r := by
simp [Pos.le_iff, Slice.Pos.le_iff, Pos.Raw.le_iff]
theorem Slice.Pos.le_trans {s : Slice} {p q r : s.Pos} : p q q r p r := by
simpa [Pos.le_iff, Pos.Raw.le_iff] using Nat.le_trans
@@ -2484,6 +2552,48 @@ def Pos.slice {s : String} (pos : s.Pos) (p₀ p₁ : s.Pos) (h₁ : p₀ ≤ po
theorem Pos.offset_slice {s : String} {p₀ p₁ pos : s.Pos} {h₁ : p₀ pos} {h₂ : pos p₁} :
(pos.slice p₀ p₁ h₁ h₂).offset = pos.offset.unoffsetBy p₀.offset := (rfl)
@[simp]
theorem Slice.Pos.offset_slice {s : Slice} {p₀ p₁ pos : s.Pos} {h₁ : p₀ pos} {h₂ : pos p₁} :
(pos.slice p₀ p₁ h₁ h₂).offset = pos.offset.unoffsetBy p₀.offset := (rfl)
@[simp]
theorem Slice.Pos.ofSlice_slice {s : Slice} {p₀ p₁ pos : s.Pos}
{h₁ : p₀ pos} {h₂ : pos p₁} :
Slice.Pos.ofSlice (pos.slice p₀ p₁ h₁ h₂) = pos := by
simpa [Pos.ext_iff] using Pos.Raw.offsetBy_unoffsetBy_of_le h₁
@[simp]
theorem Slice.Pos.slice_ofSlice {s : Slice} {p₀ p₁ : s.Pos} {h}
{pos : (s.slice p₀ p₁ h).Pos} :
(Slice.Pos.ofSlice pos).slice p₀ p₁ Slice.Pos.le_ofSlice Slice.Pos.ofSlice_le = pos := by
simp [ Slice.Pos.ofSlice_inj]
@[simp]
theorem Pos.ofSlice_slice {s : String} {p₀ p₁ pos : s.Pos}
{h₁ : p₀ pos} {h₂ : pos p₁} :
Pos.ofSlice (pos.slice p₀ p₁ h₁ h₂) = pos := by
simpa [Pos.ext_iff] using Pos.Raw.offsetBy_unoffsetBy_of_le h₁
@[simp]
theorem Pos.slice_ofSlice {s : String} {p₀ p₁ : s.Pos} {h}
{pos : (s.slice p₀ p₁ h).Pos} :
(Pos.ofSlice pos).slice p₀ p₁ Pos.le_ofSlice Pos.ofSlice_le = pos := by
simp [ Pos.ofSlice_inj]
@[simp]
theorem Slice.Pos.slice_inj {s : Slice} {p₀ p₁ : s.Pos} {pos pos' : s.Pos}
{h₁ h₁' h₂ h₂'} :
pos.slice p₀ p₁ h₁ h₂ = pos'.slice p₀ p₁ h₁' h₂' pos = pos' := by
simp [Pos.ext_iff, Pos.Raw.ext_iff, Pos.le_iff, Pos.Raw.le_iff] at h₁ h₁'
omega
@[simp]
theorem Pos.slice_inj {s : String} {p₀ p₁ : s.Pos} {pos pos' : s.Pos}
{h₁ h₁' h₂ h₂'} :
pos.slice p₀ p₁ h₁ h₂ = pos'.slice p₀ p₁ h₁' h₂' pos = pos' := by
simp [Pos.ext_iff, Pos.Raw.ext_iff, Slice.Pos.ext_iff, Pos.le_iff, Pos.Raw.le_iff] at h₁ h₁'
omega
/--
Given a position in `s`, obtain the corresponding position in `s.slice p₀ p₁ h`, or panic if `pos`
is not between `p₀` and `p₁`.

View File

@@ -403,7 +403,6 @@ achieved by tracking the bounds by hand, the slice API is much more convenient.
`String.Slice` bundles proofs to ensure that the start and end positions always delineate a valid
string. For this reason, it should be preferred over `Substring.Raw`.
-/
@[ext]
structure Slice where
/-- The underlying strings. -/
str : String

View File

@@ -16,6 +16,7 @@ public import Init.Data.String.Lemmas.IsEmpty
public import Init.Data.String.Lemmas.Pattern
public import Init.Data.String.Lemmas.Slice
public import Init.Data.String.Lemmas.Iterate
public import Init.Data.String.Lemmas.Intercalate
import Init.Data.Order.Lemmas
public import Init.Data.String.Basic
import Init.Data.Char.Lemmas

View File

@@ -99,6 +99,15 @@ theorem Slice.utf8ByteSize_eq_size_toByteArray_copy {s : Slice} :
s.utf8ByteSize = s.copy.toByteArray.size := by
simp [utf8ByteSize_eq]
@[ext (iff := false)]
theorem Slice.ext {s t : Slice} (h : s.str = t.str)
(hsi : s.startInclusive.cast h = t.startInclusive)
(hee : s.endExclusive.cast h = t.endExclusive) : s = t := by
rcases s with s, s₁, e₁, h₁
rcases t with t, s₂, e₂, h₂
cases h
simp_all
section Iterate
/-
@@ -106,32 +115,71 @@ These lemmas are slightly evil because they are non-definitional equalities betw
are useful and they are at least equalities between slices with definitionally equal underlying
strings, so it should be fine.
-/
set_option backward.isDefEq.respectTransparency false in
@[simp]
theorem Slice.sliceTo_sliceFrom {s : Slice} {pos pos'} :
(s.sliceFrom pos).sliceTo pos' =
s.slice pos (Slice.Pos.ofSliceFrom pos') Slice.Pos.le_ofSliceFrom := by
ext <;> simp [String.Pos.ext_iff, Pos.Raw.offsetBy_assoc]
ext <;> simp [Pos.Raw.offsetBy_assoc]
set_option backward.isDefEq.respectTransparency false in
@[simp]
theorem Slice.sliceFrom_sliceTo {s : Slice} {pos pos'} :
(s.sliceTo pos).sliceFrom pos' =
s.slice (Slice.Pos.ofSliceTo pos') pos Slice.Pos.ofSliceTo_le := by
ext <;> simp [String.Pos.ext_iff]
ext <;> simp
set_option backward.isDefEq.respectTransparency false in
@[simp]
theorem Slice.sliceFrom_sliceFrom {s : Slice} {pos pos'} :
(s.sliceFrom pos).sliceFrom pos' =
s.sliceFrom (Slice.Pos.ofSliceFrom pos') := by
ext <;> simp [String.Pos.ext_iff, Pos.Raw.offsetBy_assoc]
ext <;> simp [Pos.Raw.offsetBy_assoc]
set_option backward.isDefEq.respectTransparency false in
@[simp]
theorem Slice.sliceTo_sliceTo {s : Slice} {pos pos'} :
(s.sliceTo pos).sliceTo pos' = s.sliceTo (Slice.Pos.ofSliceTo pos') := by
ext <;> simp [String.Pos.ext_iff]
ext <;> simp
@[simp]
theorem Slice.sliceFrom_slice {s : Slice} {p₁ p₂ h p} :
(s.slice p₁ p₂ h).sliceFrom p = s.slice (Pos.ofSlice p) p₂ Pos.ofSlice_le := by
ext <;> simp [Nat.add_assoc]
@[simp]
theorem Slice.sliceTo_slice {s : Slice} {p₁ p₂ h p} :
(s.slice p₁ p₂ h).sliceTo p = s.slice p₁ (Pos.ofSlice p) Pos.le_ofSlice := by
ext <;> simp [Nat.add_assoc]
@[simp]
theorem sliceTo_sliceFrom {s : String} {pos pos'} :
(s.sliceFrom pos).sliceTo pos' =
s.slice pos (Pos.ofSliceFrom pos') Pos.le_ofSliceFrom := by
ext <;> simp
@[simp]
theorem sliceFrom_sliceTo {s : String} {pos pos'} :
(s.sliceTo pos).sliceFrom pos' =
s.slice (Pos.ofSliceTo pos') pos Pos.ofSliceTo_le := by
ext <;> simp
@[simp]
theorem sliceFrom_sliceFrom {s : String} {pos pos'} :
(s.sliceFrom pos).sliceFrom pos' =
s.sliceFrom (Pos.ofSliceFrom pos') := by
ext <;> simp
@[simp]
theorem sliceTo_sliceTo {s : String} {pos pos'} :
(s.sliceTo pos).sliceTo pos' = s.sliceTo (Pos.ofSliceTo pos') := by
ext <;> simp
@[simp]
theorem sliceFrom_slice {s : String} {p₁ p₂ h p} :
(s.slice p₁ p₂ h).sliceFrom p = s.slice (Pos.ofSlice p) p₂ Pos.ofSlice_le := by
ext <;> simp
@[simp]
theorem sliceTo_slice {s : String} {p₁ p₂ h p} :
(s.slice p₁ p₂ h).sliceTo p = s.slice p₁ (Pos.ofSlice p) Pos.le_ofSlice := by
ext <;> simp
end Iterate
@@ -176,4 +224,7 @@ theorem Pos.get_ofToSlice {s : String} {p : (s.toSlice).Pos} {h} :
(ofToSlice p).get h = p.get (by simpa [ ofToSlice_inj]) := by
simp [get_eq_get_toSlice]
@[simp]
theorem push_empty {c : Char} : "".push c = singleton c := rfl
end String

View File

@@ -0,0 +1,70 @@
/-
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Markus Himmel
-/
module
prelude
public import Init.Data.String.Defs
import all Init.Data.String.Defs
public import Init.Data.String.Slice
import all Init.Data.String.Slice
public section
namespace String
@[simp]
theorem intercalate_nil {s : String} : s.intercalate [] = "" := by
simp [intercalate]
@[simp]
theorem intercalate_singleton {s t : String} : s.intercalate [t] = t := by
simp [intercalate, intercalate.go]
private theorem intercalateGo_append {s t u : String} {l : List String} :
intercalate.go (s ++ t) u l = s ++ intercalate.go t u l := by
induction l generalizing t <;> simp [intercalate.go, String.append_assoc, *]
@[simp]
theorem intercalate_cons_cons {s t u : String} {l : List String} :
s.intercalate (t :: u :: l) = t ++ s ++ s.intercalate (u :: l) := by
simp [intercalate, intercalate.go, intercalateGo_append]
@[simp]
theorem intercalate_cons_append {s t u : String} {l : List String} :
s.intercalate ((t ++ u) :: l) = t ++ s.intercalate (u :: l) := by
cases l <;> simp [String.append_assoc]
theorem intercalate_cons_of_ne_nil {s t : String} {l : List String} (h : l []) :
s.intercalate (t :: l) = t ++ s ++ s.intercalate l :=
match l, h with
| u::l, _ => by simp
@[simp]
theorem toList_intercalate {s : String} {l : List String} :
(s.intercalate l).toList = s.toList.intercalate (l.map String.toList) := by
induction l with
| nil => simp
| cons hd tl ih => cases tl <;> simp_all
namespace Slice
@[simp]
theorem _root_.String.appendSlice_eq {s : String} {t : Slice} : s ++ t = s ++ t.copy := rfl
private theorem intercalateGo_append {s t : String} {u : Slice} {l : List Slice} :
intercalate.go (s ++ t) u l = s ++ intercalate.go t u l := by
induction l generalizing t <;> simp [intercalate.go, String.append_assoc, *]
@[simp]
theorem intercalate_eq {s : Slice} {l : List Slice} :
s.intercalate l = s.copy.intercalate (l.map Slice.copy) := by
induction l with
| nil => simp [intercalate]
| cons hd tl ih => cases tl <;> simp_all [intercalate, intercalate.go, intercalateGo_append]
end Slice
end String

View File

@@ -87,6 +87,10 @@ theorem isEmpty_iff_utf8ByteSize_eq_zero {s : String} : s.isEmpty ↔ s.utf8Byte
theorem isEmpty_iff {s : String} : s.isEmpty s = "" := by
simp [isEmpty_iff_utf8ByteSize_eq_zero]
@[simp]
theorem isEmpty_eq_false_iff {s : String} : s.isEmpty = false s "" := by
simp [ isEmpty_iff]
theorem startPos_ne_endPos_iff {s : String} : s.startPos s.endPos s "" := by
simp
@@ -175,4 +179,34 @@ theorem Slice.toByteArray_copy_ne_empty_iff {s : Slice} :
s.copy.toByteArray ByteArray.empty s.isEmpty = false := by
simp
section CopyEqEmpty
-- Yes, `simp` can prove these, but we still need to mark them as simp lemmas.
@[simp]
theorem copy_slice_self {s : String} {p : s.Pos} : (s.slice p p (Pos.le_refl _)).copy = "" := by
simp
@[simp]
theorem copy_sliceTo_startPos {s : String} : (s.sliceTo s.startPos).copy = "" := by
simp
@[simp]
theorem copy_sliceFrom_startPos {s : String} : (s.sliceFrom s.endPos).copy = "" := by
simp
@[simp]
theorem Slice.copy_slice_self {s : Slice} {p : s.Pos} : (s.slice p p (Pos.le_refl _)).copy = "" := by
simp
@[simp]
theorem Slice.copy_sliceTo_startPos {s : Slice} : (s.sliceTo s.startPos).copy = "" := by
simp
@[simp]
theorem Slice.copy_sliceFrom_startPos {s : Slice} : (s.sliceFrom s.endPos).copy = "" := by
simp
end CopyEqEmpty
end String

View File

@@ -57,6 +57,14 @@ theorem Slice.Pos.endPos_le {s : Slice} (p : s.Pos) : s.endPos ≤ p ↔ p = s.e
theorem Slice.Pos.lt_endPos_iff {s : Slice} (p : s.Pos) : p < s.endPos p s.endPos := by
simp [ endPos_le, Std.not_le]
@[simp]
theorem Pos.endPos_le {s : String} (p : s.Pos) : s.endPos p p = s.endPos :=
fun h => Std.le_antisymm (le_endPos _) h, by simp +contextual
@[simp]
theorem Pos.lt_endPos_iff {s : String} (p : s.Pos) : p < s.endPos p s.endPos := by
simp [ endPos_le, Std.not_le]
@[simp]
theorem Pos.le_startPos {s : String} (p : s.Pos) : p s.startPos p = s.startPos :=
fun h => Std.le_antisymm h (startPos_le _), by simp +contextual
@@ -65,10 +73,6 @@ theorem Pos.le_startPos {s : String} (p : s.Pos) : p ≤ s.startPos ↔ p = s.st
theorem Pos.startPos_lt_iff {s : String} {p : s.Pos} : s.startPos < p p s.startPos := by
simp [ le_startPos, Std.not_le]
@[simp]
theorem Pos.endPos_le {s : String} (p : s.Pos) : s.endPos p p = s.endPos :=
fun h => Std.le_antisymm (le_endPos _) h, by simp +contextual [Std.le_refl]
@[simp]
theorem Slice.Pos.not_lt_startPos {s : Slice} {p : s.Pos} : ¬ p < s.startPos :=
fun h => Std.lt_irrefl (Std.lt_of_lt_of_le h (Slice.Pos.startPos_le _))
@@ -101,19 +105,57 @@ theorem Slice.Pos.le_next {s : Slice} {p : s.Pos} {h} : p ≤ p.next h :=
theorem Pos.le_next {s : String} {p : s.Pos} {h} : p p.next h :=
Std.le_of_lt (by simp)
@[simp]
theorem Slice.Pos.ne_next {s : Slice} {p : s.Pos} {h} : p p.next h :=
Std.ne_of_lt (by simp)
@[simp]
theorem Pos.ne_next {s : String} {p : s.Pos} {h} : p p.next h :=
Std.ne_of_lt (by simp)
@[simp]
theorem Slice.Pos.next_ne {s : Slice} {p : s.Pos} {h} : p.next h p :=
Ne.symm (by simp)
@[simp]
theorem Pos.next_ne {s : String} {p : s.Pos} {h} : p.next h p :=
Ne.symm (by simp)
@[simp]
theorem Slice.Pos.next_ne_startPos {s : Slice} {p : s.Pos} {h} :
p.next h s.startPos :=
ne_startPos_of_lt lt_next
@[simp]
theorem Slice.Pos.ofSliceTo_lt_ofSliceTo_iff {s : Slice} {p : s.Pos}
{q r : (s.sliceTo p).Pos} : Slice.Pos.ofSliceTo q < Slice.Pos.ofSliceTo r q < r := by
simp [Slice.Pos.lt_iff, Pos.Raw.lt_iff]
@[simp]
theorem Slice.Pos.ofSliceTo_le_ofSliceTo_iff {s : Slice} {p : s.Pos}
{q r : (s.sliceTo p).Pos} : Slice.Pos.ofSliceTo q Slice.Pos.ofSliceTo r q r := by
simp [Slice.Pos.le_iff, Pos.Raw.le_iff]
@[simp]
theorem Slice.Pos.sliceTo_lt_sliceTo_iff {s : Slice} {p₀ : s.Pos} {q r : s.Pos} {h₁ h₂} :
Pos.sliceTo p₀ q h₁ < Pos.sliceTo p₀ r h₂ q < r := by
simp [Slice.Pos.lt_iff, Pos.Raw.lt_iff]
@[simp]
theorem Slice.Pos.sliceTo_le_sliceTo_iff {s : Slice} {p₀ : s.Pos} {q r : s.Pos} {h₁ h₂} :
Pos.sliceTo p₀ q h₁ Pos.sliceTo p₀ r h₂ q r := by
simp [Slice.Pos.le_iff, Pos.Raw.le_iff]
@[simp]
theorem Pos.sliceTo_lt_sliceTo_iff {s : String} {p₀ : s.Pos} {q r : s.Pos} {h₁ h₂} :
Pos.sliceTo p₀ q h₁ < Pos.sliceTo p₀ r h₂ q < r := by
simp [Slice.Pos.lt_iff, Pos.lt_iff, Pos.Raw.lt_iff]
@[simp]
theorem Pos.sliceTo_le_sliceTo_iff {s : String} {p₀ : s.Pos} {q r : s.Pos} {h₁ h₂} :
Pos.sliceTo p₀ q h₁ Pos.sliceTo p₀ r h₂ q r := by
simp [Slice.Pos.le_iff, Pos.le_iff, Pos.Raw.le_iff]
@[simp]
theorem Slice.Pos.sliceFrom_lt_sliceFrom_iff {s : Slice} {p₀ : s.Pos} {q r : s.Pos} {h₁ h₂} :
Pos.sliceFrom p₀ q h₁ < Pos.sliceFrom p₀ r h₂ q < r := by
@@ -200,6 +242,116 @@ theorem Pos.ofSliceFrom_next {s : String} {p₀ : s.Pos} {p : (s.sliceFrom p₀)
Slice.Pos.next_le_iff_lt, true_and]
simp [Pos.ofSliceFrom_lt_iff]
theorem Slice.Pos.le_ofSliceTo_iff {s : Slice} {p₀ : s.Pos} {p : (s.sliceTo p₀).Pos} {q : s.Pos} :
q Pos.ofSliceTo p h, Slice.Pos.sliceTo p₀ q h p := by
refine fun h => Slice.Pos.le_trans h Pos.ofSliceTo_le, ?_, fun h, h' => ?_
· simp +singlePass only [ Pos.sliceTo_ofSliceTo (p := p)]
rwa [Pos.sliceTo_le_sliceTo_iff]
· simp +singlePass only [ Pos.ofSliceTo_sliceTo (h := h)]
rwa [Pos.ofSliceTo_le_ofSliceTo_iff]
theorem Slice.Pos.ofSliceTo_lt_iff {s : Slice} {p₀ : s.Pos} {p : (s.sliceTo p₀).Pos} {q : s.Pos} :
Pos.ofSliceTo p < q h, p < Slice.Pos.sliceTo p₀ q h := by
simp [ Std.not_le, Slice.Pos.le_ofSliceTo_iff]
theorem Slice.Pos.lt_ofSliceTo_iff {s : Slice} {p₀ : s.Pos} {p : (s.sliceTo p₀).Pos} {q : s.Pos} :
q < Pos.ofSliceTo p h, Slice.Pos.sliceTo p₀ q h < p := by
refine fun h => Std.le_of_lt (Std.lt_of_le_of_lt (Std.le_refl q) (Std.lt_of_lt_of_le h Pos.ofSliceTo_le)), ?_, fun h, h' => ?_
· simp +singlePass only [ Pos.sliceTo_ofSliceTo (p := p)]
rwa [Pos.sliceTo_lt_sliceTo_iff]
· simp +singlePass only [ Pos.ofSliceTo_sliceTo (h := h)]
rwa [Pos.ofSliceTo_lt_ofSliceTo_iff]
theorem Slice.Pos.ofSliceTo_le_iff {s : Slice} {p₀ : s.Pos} {p : (s.sliceTo p₀).Pos} {q : s.Pos} :
Pos.ofSliceTo p q h, p Slice.Pos.sliceTo p₀ q h := by
simp [ Std.not_lt, Slice.Pos.lt_ofSliceTo_iff]
theorem Pos.le_ofSliceTo_iff {s : String} {p₀ : s.Pos} {p : (s.sliceTo p₀).Pos} {q : s.Pos} :
q Pos.ofSliceTo p h, Pos.sliceTo p₀ q h p := by
refine fun h => Pos.le_trans h Pos.ofSliceTo_le, ?_, fun h, h' => ?_
· simp +singlePass only [ Pos.sliceTo_ofSliceTo (p := p)]
rwa [Pos.sliceTo_le_sliceTo_iff]
· simp +singlePass only [ Pos.ofSliceTo_sliceTo (h := h)]
rwa [Pos.ofSliceTo_le_ofSliceTo_iff]
theorem Pos.ofSliceTo_lt_iff {s : String} {p₀ : s.Pos} {p : (s.sliceTo p₀).Pos} {q : s.Pos} :
Pos.ofSliceTo p < q h, p < Pos.sliceTo p₀ q h := by
simp [ Std.not_le, Pos.le_ofSliceTo_iff]
theorem Pos.lt_ofSliceTo_iff {s : String} {p₀ : s.Pos} {p : (s.sliceTo p₀).Pos} {q : s.Pos} :
q < Pos.ofSliceTo p h, Pos.sliceTo p₀ q h < p := by
refine fun h => Pos.le_of_lt (Pos.lt_of_lt_of_le h Pos.ofSliceTo_le), ?_, fun h, h' => ?_
· simp +singlePass only [ Pos.sliceTo_ofSliceTo (p := p)]
rwa [Pos.sliceTo_lt_sliceTo_iff]
· simp +singlePass only [ Pos.ofSliceTo_sliceTo (h := h)]
rwa [Pos.ofSliceTo_lt_ofSliceTo_iff]
theorem Pos.ofSliceTo_le_iff {s : String} {p₀ : s.Pos} {p : (s.sliceTo p₀).Pos} {q : s.Pos} :
Pos.ofSliceTo p q h, p Pos.sliceTo p₀ q h := by
simp [ Std.not_lt, Pos.lt_ofSliceTo_iff]
theorem Slice.Pos.ofSliceTo_ne_endPos {s : Slice} {p₀ : s.Pos} {p : (s.sliceTo p₀).Pos}
(h : p (s.sliceTo p₀).endPos) : Pos.ofSliceTo p s.endPos := by
refine (lt_endPos_iff _).1 (Std.lt_of_lt_of_le ?_ (le_endPos p₀))
simpa [ lt_endPos_iff, ofSliceTo_lt_ofSliceTo_iff] using h
theorem Pos.ofSliceTo_ne_endPos {s : String} {p₀ : s.Pos} {p : (s.sliceTo p₀).Pos}
(h : p (s.sliceTo p₀).endPos) : Pos.ofSliceTo p s.endPos := by
refine (lt_endPos_iff _).1 (Std.lt_of_lt_of_le ?_ (le_endPos p₀))
simpa [ Slice.Pos.lt_endPos_iff, ofSliceTo_lt_ofSliceTo_iff] using h
theorem Slice.Pos.ofSliceTo_next {s : Slice} {p₀ : s.Pos} {p : (s.sliceTo p₀).Pos} {h} :
Pos.ofSliceTo (p.next h) = (Pos.ofSliceTo p).next (ofSliceTo_ne_endPos h) := by
rw [eq_comm, Pos.next_eq_iff]
simp only [Pos.ofSliceTo_lt_ofSliceTo_iff, Pos.lt_next, Pos.ofSliceTo_le_iff,
Pos.next_le_iff_lt, true_and]
simp [Pos.ofSliceTo_lt_iff]
theorem Pos.ofSliceTo_next {s : String} {p₀ : s.Pos} {p : (s.sliceTo p₀).Pos} {h} :
Pos.ofSliceTo (p.next h) = (Pos.ofSliceTo p).next (ofSliceTo_ne_endPos h) := by
rw [eq_comm, Pos.next_eq_iff]
simp only [Pos.ofSliceTo_lt_ofSliceTo_iff, Slice.Pos.lt_next, Pos.ofSliceTo_le_iff,
Slice.Pos.next_le_iff_lt, true_and]
simp [Pos.ofSliceTo_lt_iff]
@[simp]
theorem Slice.Pos.slice_lt_slice_iff {s : Slice} {p₀ p₁ : s.Pos} {q r : s.Pos}
{h₁ h₁' h₂ h₂'} :
q.slice p₀ p₁ h₁ h₂ < r.slice p₀ p₁ h₁' h₂' q < r := by
simp [Slice.Pos.lt_iff, Pos.Raw.lt_iff, Slice.Pos.le_iff, Pos.Raw.le_iff] at h₁ h₁'
omega
@[simp]
theorem Slice.Pos.slice_le_slice_iff {s : Slice} {p₀ p₁ : s.Pos} {q r : s.Pos}
{h₁ h₁' h₂ h₂'} :
q.slice p₀ p₁ h₁ h₂ r.slice p₀ p₁ h₁' h₂' q r := by
simp [Slice.Pos.le_iff, Pos.Raw.le_iff] at h₁ h₁'
omega
@[simp]
theorem Pos.slice_lt_slice_iff {s : String} {p₀ p₁ : s.Pos} {q r : s.Pos}
{h₁ h₁' h₂ h₂'} :
q.slice p₀ p₁ h₁ h₂ < r.slice p₀ p₁ h₁' h₂' q < r := by
simp [Slice.Pos.lt_iff, Pos.lt_iff, Pos.Raw.lt_iff, Pos.le_iff, Pos.Raw.le_iff] at h₁ h₁'
omega
@[simp]
theorem Pos.slice_le_slice_iff {s : String} {p₀ p₁ : s.Pos} {q r : s.Pos}
{h₁ h₁' h₂ h₂'} :
q.slice p₀ p₁ h₁ h₂ r.slice p₀ p₁ h₁' h₂' q r := by
simp [Slice.Pos.le_iff, Pos.le_iff, Pos.Raw.le_iff] at h₁ h₁'
omega
theorem Slice.Pos.ofSlice_ne_endPos {s : Slice} {p₀ p₁ : s.Pos} {h} {p : (s.slice p₀ p₁ h).Pos}
(h : p (s.slice p₀ p₁ h).endPos) : Pos.ofSlice p s.endPos := by
refine (lt_endPos_iff _).1 (Std.lt_of_lt_of_le ?_ (le_endPos p₁))
simpa [ lt_endPos_iff, ofSlice_lt_ofSlice_iff] using h
theorem Pos.ofSlice_ne_endPos {s : String} {p₀ p₁ : s.Pos} {h} {p : (s.slice p₀ p₁ h).Pos}
(h : p (s.slice p₀ p₁ h).endPos) : Pos.ofSlice p s.endPos := by
refine (lt_endPos_iff _).1 (Std.lt_of_lt_of_le ?_ (le_endPos p₁))
simpa [ Slice.Pos.lt_endPos_iff, ofSlice_lt_ofSlice_iff] using h
@[simp]
theorem Slice.Pos.offset_le_rawEndPos {s : Slice} {p : s.Pos} :
p.offset s.rawEndPos :=
@@ -248,4 +400,38 @@ theorem Pos.isUTF8FirstByte_getUTF8Byte_offset {s : String} {p : s.Pos} {h} :
(s.getUTF8Byte p.offset h).IsUTF8FirstByte := by
simpa [getUTF8Byte_offset] using isUTF8FirstByte_byte
theorem Slice.Pos.get_eq_get_ofSliceTo {s : Slice} {p₀ : s.Pos} {pos : (s.sliceTo p₀).Pos} {h} :
pos.get h = (ofSliceTo pos).get (ofSliceTo_ne_endPos h) := by
simp [Slice.Pos.get]
theorem Pos.get_eq_get_ofSliceTo {s : String} {p₀ : s.Pos}
{pos : (s.sliceTo p₀).Pos} {h} :
pos.get h = (ofSliceTo pos).get (ofSliceTo_ne_endPos h) := by
simp [Pos.get, Slice.Pos.get]
theorem Slice.Pos.get_eq_get_ofSlice {s : Slice} {p₀ p₁ : s.Pos} {h}
{pos : (s.slice p₀ p₁ h).Pos} {h'} :
pos.get h' = (ofSlice pos).get (ofSlice_ne_endPos h') := by
simp [Slice.Pos.get, Nat.add_assoc]
theorem Pos.get_eq_get_ofSlice {s : String} {p₀ p₁ : s.Pos} {h}
{pos : (s.slice p₀ p₁ h).Pos} {h'} :
pos.get h' = (ofSlice pos).get (ofSlice_ne_endPos h') := by
simp [Pos.get, Slice.Pos.get]
theorem Slice.Pos.ofSlice_next {s : Slice} {p₀ p₁ : s.Pos} {h}
{p : (s.slice p₀ p₁ h).Pos} {h'} :
Pos.ofSlice (p.next h') = (Pos.ofSlice p).next (ofSlice_ne_endPos h') := by
simp only [Slice.Pos.ext_iff, Pos.Raw.ext_iff, Slice.Pos.offset_next, Slice.Pos.offset_ofSlice]
rw [Slice.Pos.get_eq_get_ofSlice (h' := h')]
simp [Pos.Raw.offsetBy, Nat.add_assoc]
theorem Pos.ofSlice_next {s : String} {p₀ p₁ : s.Pos} {h}
{p : (s.slice p₀ p₁ h).Pos} {h'} :
Pos.ofSlice (p.next h') = (Pos.ofSlice p).next (ofSlice_ne_endPos h') := by
simp only [Pos.ext_iff, Pos.Raw.ext_iff, Slice.Pos.offset_next, Pos.offset_next,
Pos.offset_ofSlice]
rw [Pos.get_eq_get_ofSlice (h' := h')]
simp [Pos.Raw.offsetBy, Nat.add_assoc]
end String

View File

@@ -46,6 +46,10 @@ theorem isLongestMatchAt_iff {c : Char} {s : Slice} {pos pos' : s.Pos} :
simp +contextual [Model.isLongestMatchAt_iff, isLongestMatch_iff, Pos.ofSliceFrom_inj,
Pos.get_eq_get_ofSliceFrom, Pos.ofSliceFrom_next]
theorem isLongestMatchAt_of_get_eq {c : Char} {s : Slice} {pos : s.Pos} {h : pos s.endPos}
(hc : pos.get h = c) : IsLongestMatchAt c pos (pos.next h) :=
isLongestMatchAt_iff.2 h, by simp [hc]
instance {c : Char} : LawfulForwardPatternModel c where
dropPrefix?_eq_some_iff {s} pos := by
simp [isLongestMatch_iff, ForwardPattern.dropPrefix?, and_comm, eq_comm (b := pos)]
@@ -57,6 +61,10 @@ theorem matchesAt_iff {c : Char} {s : Slice} {pos : s.Pos} :
MatchesAt c pos (h : pos s.endPos), pos.get h = c := by
simp [matchesAt_iff_exists_isLongestMatchAt, isLongestMatchAt_iff, exists_comm]
theorem not_matchesAt_of_get_ne {c : Char} {s : Slice} {pos : s.Pos} {h : pos s.endPos}
(hc : pos.get h c) : ¬ MatchesAt c pos := by
simp [matchesAt_iff, hc]
theorem matchAt?_eq {s : Slice} {pos : s.Pos} {c : Char} :
matchAt? c pos =
if h₀ : (h : pos s.endPos), pos.get h = c then some (pos.next h₀.1) else none := by

View File

@@ -47,6 +47,10 @@ theorem isLongestMatchAt_iff {p : Char → Bool} {s : Slice} {pos pos' : s.Pos}
simp +contextual [Model.isLongestMatchAt_iff, isLongestMatch_iff, Pos.ofSliceFrom_inj,
Pos.get_eq_get_ofSliceFrom, Pos.ofSliceFrom_next]
theorem isLongestMatchAt_of_get {p : Char Bool} {s : Slice} {pos : s.Pos} {h : pos s.endPos}
(hc : p (pos.get h)) : IsLongestMatchAt p pos (pos.next h) :=
isLongestMatchAt_iff.2 h, by simp [hc]
instance {p : Char Bool} : LawfulForwardPatternModel p where
dropPrefix?_eq_some_iff {s} pos := by
simp [isLongestMatch_iff, ForwardPattern.dropPrefix?, and_comm, eq_comm (b := pos)]
@@ -58,6 +62,10 @@ theorem matchesAt_iff {p : Char → Bool} {s : Slice} {pos : s.Pos} :
MatchesAt p pos (h : pos s.endPos), p (pos.get h) := by
simp [matchesAt_iff_exists_isLongestMatchAt, isLongestMatchAt_iff, exists_comm]
theorem not_matchesAt_of_get {p : Char Bool} {s : Slice} {pos : s.Pos} {h : pos s.endPos}
(hc : p (pos.get h) = false) : ¬ MatchesAt p pos := by
simp [matchesAt_iff, hc]
theorem matchAt?_eq {s : Slice} {pos : s.Pos} {p : Char Bool} :
matchAt? p pos =
if h₀ : (h : pos s.endPos), p (pos.get h) then some (pos.next h₀.1) else none := by
@@ -100,6 +108,10 @@ theorem isLongestMatchAt_iff {p : Char → Prop} [DecidablePred p] {s : Slice}
IsLongestMatchAt p pos pos' h, pos' = pos.next h p (pos.get h) := by
simp [isLongestMatchAt_iff_isLongestMatchAt_decide, CharPred.isLongestMatchAt_iff]
theorem isLongestMatchAt_of_get {p : Char Prop} [DecidablePred p] {s : Slice} {pos : s.Pos}
{h : pos s.endPos} (hc : p (pos.get h)) : IsLongestMatchAt p pos (pos.next h) :=
isLongestMatchAt_iff.2 h, by simp [hc]
theorem dropPrefix?_eq_dropPrefix?_decide {p : Char Prop} [DecidablePred p] :
ForwardPattern.dropPrefix? p = ForwardPattern.dropPrefix? (decide <| p ·) := rfl
@@ -115,6 +127,10 @@ theorem matchesAt_iff {p : Char → Prop} [DecidablePred p] {s : Slice} {pos : s
MatchesAt p pos (h : pos s.endPos), p (pos.get h) := by
simp [matchesAt_iff_exists_isLongestMatchAt, isLongestMatchAt_iff, exists_comm]
theorem not_matchesAt_of_get {p : Char Prop} [DecidablePred p] {s : Slice} {pos : s.Pos}
{h : pos s.endPos} (hc : ¬ p (pos.get h)) : ¬ MatchesAt p pos := by
simp [matchesAt_iff, hc]
theorem matchAt?_eq {s : Slice} {pos : s.Pos} {p : Char Prop} [DecidablePred p] :
matchAt? p pos =
if h₀ : (h : pos s.endPos), p (pos.get h) then some (pos.next h₀.1) else none := by

View File

@@ -7,3 +7,5 @@ module
prelude
public import Init.Data.String.Lemmas.Pattern.Split.Basic
public import Init.Data.String.Lemmas.Pattern.Split.Char
public import Init.Data.String.Lemmas.Pattern.Split.Pred

View File

@@ -8,7 +8,9 @@ module
prelude
public import Init.Data.String.Lemmas.Pattern.Basic
public import Init.Data.String.Slice
public import Init.Data.String.Search
import all Init.Data.String.Slice
import all Init.Data.String.Search
import Init.Data.Option.Lemmas
import Init.Data.String.Termination
import Init.Data.String.Lemmas.Order
@@ -17,6 +19,8 @@ import Init.Data.Order.Lemmas
import Init.Data.String.OrderInstances
import Init.Data.Iterators.Lemmas.Basic
import Init.Data.Iterators.Lemmas.Consumers.Collect
import Init.Data.Iterators.Lemmas.Combinators.FilterMap
import Init.Data.String.Lemmas.IsEmpty
set_option doc.verso true
@@ -31,111 +35,38 @@ This gives a low-level correctness proof from which higher-level API lemmas can
namespace String.Slice.Pattern.Model
/--
Represents a list of subslices of a slice {name}`s`, the first of which starts at the given
position {name}`startPos`. This is a natural type for a split routine to return.
-/
@[ext]
public structure SlicesFrom {s : Slice} (startPos : s.Pos) : Type where
l : List s.Subslice
any_head? : l.head?.any (·.startInclusive = startPos)
namespace SlicesFrom
/--
A {name}`SlicesFrom` consisting of a single empty subslice at the position {name}`pos`.
-/
public def «at» {s : Slice} (pos : s.Pos) : SlicesFrom pos where
l := [s.subslice pos pos (Slice.Pos.le_refl _)]
any_head? := by simp
@[simp]
public theorem l_at {s : Slice} (pos : s.Pos) :
(SlicesFrom.at pos).l = [s.subslice pos pos (Slice.Pos.le_refl _)] := (rfl)
/--
Concatenating two {name}`SlicesFrom` yields a {name}`SlicesFrom` from the first position.
-/
public def append {s : Slice} {p₁ p₂ : s.Pos} (l₁ : SlicesFrom p₁) (l₂ : SlicesFrom p₂) :
SlicesFrom p₁ where
l := l₁.l ++ l₂.l
any_head? := by simpa using Option.any_or_of_any_left l₁.any_head?
@[simp]
public theorem l_append {s : Slice} {p₁ p₂ : s.Pos} {l₁ : SlicesFrom p₁} {l₂ : SlicesFrom p₂} :
(l₁.append l₂).l = l₁.l ++ l₂.l :=
(rfl)
/--
Given a {lean}`SlicesFrom p₂` and a position {name}`p₁` such that {lean}`p₁ ≤ p₂`, obtain a
{lean}`SlicesFrom p₁` by extending the left end of the first subslice to from {name}`p₂` to
{name}`p₁`.
-/
public def extend {s : Slice} (p₁ : s.Pos) {p₂ : s.Pos} (h : p₁ p₂) (l : SlicesFrom p₂) :
SlicesFrom p₁ where
l :=
match l.l, l.any_head? with
| st :: sts, h => st.extendLeft p₁ (by simp_all) :: sts
any_head? := by split; simp
@[simp]
public theorem l_extend {s : Slice} {p₁ p₂ : s.Pos} (h : p₁ p₂) {l : SlicesFrom p₂} :
(l.extend p₁ h).l =
match l.l, l.any_head? with
| st :: sts, h => st.extendLeft p₁ (by simp_all) :: sts :=
(rfl)
@[simp]
public theorem extend_self {s : Slice} {p₁ : s.Pos} (l : SlicesFrom p₁) :
l.extend p₁ (Slice.Pos.le_refl _) = l := by
rcases l with l, h
match l, h with
| st :: sts, h =>
simp at h
simp [SlicesFrom.extend, h]
@[simp]
public theorem extend_extend {s : Slice} {p₀ p₁ p₂ : s.Pos} {h : p₀ p₁} {h' : p₁ p₂}
{l : SlicesFrom p₂} : (l.extend p₁ h').extend p₀ h = l.extend p₀ (Slice.Pos.le_trans h h') := by
rcases l with l, h
match l, h with
| st :: sts, h => simp [SlicesFrom.extend]
end SlicesFrom
/--
Noncomputable model implementation of {name}`String.Slice.splitToSubslice` based on
{name}`ForwardPatternModel`. This is supposed to be simple enough to allow deriving higher-level
API lemmas about the public splitting functions.
-/
public protected noncomputable def split {ρ : Type} (pat : ρ) [ForwardPatternModel pat] {s : Slice}
(start : s.Pos) : SlicesFrom start :=
if h : start = s.endPos then
.at start
(firstRejected curr : s.Pos) (hle : firstRejected curr) : List s.Subslice :=
if h : curr = s.endPos then
[s.subslice _ _ hle]
else
match hd : matchAt? pat start with
match hd : matchAt? pat curr with
| some pos =>
have : start < pos := (matchAt?_eq_some_iff.1 hd).lt
(SlicesFrom.at start).append (Model.split pat pos)
| none => (Model.split pat (start.next h)).extend start (by simp)
termination_by start
have : curr < pos := (matchAt?_eq_some_iff.1 hd).lt
s.subslice _ _ hle :: Model.split pat pos pos (Std.le_refl _)
| none => Model.split pat firstRejected (curr.next h) (Std.le_trans hle (by simp))
termination_by curr
@[simp]
public theorem split_endPos {ρ : Type} {pat : ρ} [ForwardPatternModel pat] {s : Slice} :
Model.split pat s.endPos = SlicesFrom.at s.endPos := by
public theorem split_endPos {ρ : Type} {pat : ρ} [ForwardPatternModel pat] {s : Slice}
{firstRejected : s.Pos} :
Model.split (s := s) pat firstRejected s.endPos (by simp) = [s.subslice firstRejected s.endPos (by simp)] := by
simp [Model.split]
public theorem split_eq_of_isLongestMatchAt {ρ : Type} {pat : ρ} [ForwardPatternModel pat]
{s : Slice} {start stop : s.Pos} (h : IsLongestMatchAt pat start stop) :
Model.split pat start = (SlicesFrom.at start).append (Model.split pat stop) := by
{s : Slice} {firstRejected start stop : s.Pos} {hle} (h : IsLongestMatchAt pat start stop) :
Model.split pat firstRejected start hle =
s.subslice _ _ hle :: Model.split pat stop stop (by exact Std.le_refl _) := by
rw [Model.split, dif_neg (Slice.Pos.ne_endPos_of_lt h.lt)]
split
· congr <;> exact (matchAt?_eq_some_iff.1 _).eq h
· simp [matchAt?_eq_some_iff.2 _] at *
public theorem split_eq_of_not_matchesAt {ρ : Type} {pat : ρ} [ForwardPatternModel pat] {s : Slice}
{start stop : s.Pos} (h₀ : start stop) (h : p, start p p < stop ¬ MatchesAt pat p) :
Model.split pat start = (SlicesFrom.extend start h₀ (Model.split pat stop)) := by
public theorem split_eq_of_not_matchesAt {ρ : Type} {pat : ρ} [ForwardPatternModel pat]
{s : Slice} {firstRejected start} (stop : s.Pos) (h : start stop) {hle}
(h : p, start p p < stop ¬ MatchesAt pat p) :
Model.split pat firstRejected start hle =
Model.split pat firstRejected stop (by exact Std.le_trans hle h₀) := by
induction start using WellFounded.induction Slice.Pos.wellFounded_gt with | h start ih
by_cases h' : start < stop
· rw [Model.split, dif_neg (Slice.Pos.ne_endPos_of_lt h')]
@@ -143,13 +74,19 @@ public theorem split_eq_of_not_matchesAt {ρ : Type} {pat : ρ} [ForwardPatternM
split
· rename_i heq
simp [matchAt?_eq_none_iff.2 _] at heq
· rw [ih, SlicesFrom.extend_extend]
· simp
· simp [h']
· refine fun p hp₁ hp₂ => h p (Std.le_of_lt (by simpa using hp₁)) hp₂
· rw [ih _ (by simp) (by simpa)]
exact fun p hp₁ hp₂ => h p (Std.le_of_lt (by simpa using hp₁)) hp₂
· obtain rfl : start = stop := Std.le_antisymm h₀ (Std.not_lt.1 h')
simp
public theorem split_eq_next_of_not_matchesAt {ρ : Type} {pat : ρ} [ForwardPatternModel pat]
{s : Slice} {firstRejected start} {hle} (hs : start s.endPos) (h : ¬ MatchesAt pat start) :
Model.split pat firstRejected start hle =
Model.split pat firstRejected (start.next hs) (by exact Std.le_trans hle (by simp)) := by
refine split_eq_of_not_matchesAt _ (by simp) (fun p hp₁ hp₂ => ?_)
obtain rfl : start = p := Std.le_antisymm hp₁ (by simpa using hp₂)
exact h
/--
Splits a slice {name}`s` into subslices from a list of {lean}`SearchStep s`.
@@ -168,30 +105,18 @@ theorem IsValidSearchFrom.splitFromSteps_eq_extend_split {ρ : Type} (pat : ρ)
[ForwardPatternModel pat] (l : List (SearchStep s)) (pos pos' : s.Pos) (h₀ : pos pos')
(h' : p, pos p p < pos' ¬ MatchesAt pat p)
(h : IsValidSearchFrom pat pos' l) :
splitFromSteps pos l = ((Model.split pat pos').extend pos h₀).l := by
splitFromSteps pos l = Model.split pat pos pos' h₀ := by
induction h generalizing pos with
| endPos =>
simp only [splitFromSteps, Model.split, reduceDIte, SlicesFrom.l_extend, List.head?_cons,
Option.any_some]
split
simp_all only [SlicesFrom.l_at, List.cons.injEq, List.nil_eq, List.head?_cons, Option.any_some,
decide_eq_true_eq, heq_eq_eq, and_true]
rename_i h
simp only [ h.1]
ext <;> simp
simp [splitFromSteps]
| matched h valid ih =>
simp only [splitFromSteps]
rw [subslice!_eq_subslice h₀, split_eq_of_isLongestMatchAt h]
simp only [SlicesFrom.append, SlicesFrom.at, List.cons_append, List.nil_append,
SlicesFrom.l_extend, List.cons.injEq]
refine ?_, ?_
· ext <;> simp
· rw [ih _ (Slice.Pos.le_refl _), SlicesFrom.extend_self]
exact fun p hp₁ hp₂ => False.elim (Std.lt_irrefl (Std.lt_of_le_of_lt hp₁ hp₂))
rw [subslice!_eq_subslice h₀, split_eq_of_isLongestMatchAt h, ih]
simp +contextual [ Std.not_lt]
| mismatched h rej valid ih =>
simp only [splitFromSteps]
rename_i l startPos endPos
rw [split_eq_of_not_matchesAt (Std.le_of_lt h) rej, SlicesFrom.extend_extend, ih]
rw [split_eq_of_not_matchesAt _ (Std.le_of_lt h) rej, ih]
intro p hp₁ hp₂
by_cases hp : p < startPos
· exact h' p hp₁ hp
@@ -231,10 +156,52 @@ open Model
public theorem toList_splitToSubslice_eq_modelSplit {ρ : Type} (pat : ρ) [ForwardPatternModel pat]
{σ : Slice Type} [ToForwardSearcher pat σ] [ s, Std.Iterator (σ s) Id (SearchStep s)]
[ s, Std.Iterators.Finite (σ s) Id] [LawfulToForwardSearcherModel pat] (s : Slice) :
(s.splitToSubslice pat).toList = (Model.split pat s.startPos).l := by
(s.splitToSubslice pat).toList = Model.split pat s.startPos s.startPos (by exact Std.le_refl _) := by
rw [toList_splitToSubslice_eq_splitFromSteps, IsValidSearchFrom.splitFromSteps_eq_extend_split pat _
s.startPos s.startPos (Std.le_refl _) _ (LawfulToForwardSearcherModel.isValidSearchFrom_toList _),
SlicesFrom.extend_self]
s.startPos s.startPos (Std.le_refl _) _ (LawfulToForwardSearcherModel.isValidSearchFrom_toList _)]
simp
end String.Slice.Pattern
end Pattern
open Pattern
public theorem toList_splitToSubslice_of_isEmpty {ρ : Type} (pat : ρ)
[Model.ForwardPatternModel pat] {σ : Slice Type}
[ToForwardSearcher pat σ] [ s, Std.Iterator (σ s) Id (SearchStep s)]
[ s, Std.Iterators.Finite (σ s) Id] [Model.LawfulToForwardSearcherModel pat] {s : Slice}
(h : s.isEmpty = true) :
(s.splitToSubslice pat).toList = [s.subsliceFrom s.endPos] := by
simp [toList_splitToSubslice_eq_modelSplit, Slice.startPos_eq_endPos_iff.2 h]
public theorem toList_split_eq_splitToSubslice {ρ : Type} (pat : ρ) {σ : Slice Type}
[ToForwardSearcher pat σ] [ s, Std.Iterator (σ s) Id (SearchStep s)]
[ s, Std.Iterators.Finite (σ s) Id] {s : Slice} :
(s.split pat).toList = (s.splitToSubslice pat).toList.map Subslice.toSlice := by
simp [split, Std.Iter.toList_map]
public theorem toList_split_of_isEmpty {ρ : Type} (pat : ρ)
[Model.ForwardPatternModel pat] {σ : Slice Type}
[ToForwardSearcher pat σ] [ s, Std.Iterator (σ s) Id (SearchStep s)]
[ s, Std.Iterators.Finite (σ s) Id] [Model.LawfulToForwardSearcherModel pat] {s : Slice}
(h : s.isEmpty = true) :
(s.split pat).toList.map Slice.copy = [""] := by
rw [toList_split_eq_splitToSubslice, toList_splitToSubslice_of_isEmpty _ h]
simp
end Slice
open Slice.Pattern
public theorem split_eq_split_toSlice {ρ : Type} {pat : ρ} {σ : Slice Type}
[ToForwardSearcher pat σ] [ s, Std.Iterator (σ s) Id (SearchStep s)] {s : String} :
s.split pat = s.toSlice.split pat := (rfl)
@[simp]
public theorem toList_split_empty {ρ : Type} (pat : ρ)
[Model.ForwardPatternModel pat] {σ : Slice Type}
[ToForwardSearcher pat σ] [ s, Std.Iterator (σ s) Id (SearchStep s)]
[ s, Std.Iterators.Finite (σ s) Id] [Model.LawfulToForwardSearcherModel pat] :
("".split pat).toList.map Slice.copy = [""] := by
rw [split_eq_split_toSlice, Slice.toList_split_of_isEmpty _ (by simp)]
end String

View File

@@ -0,0 +1,78 @@
/-
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Markus Himmel
-/
module
prelude
public import Init.Data.String.Slice
public import Init.Data.String.Search
public import Init.Data.List.SplitOn.Basic
import Init.Data.String.Termination
import Init.Data.Order.Lemmas
import Init.Data.Iterators.Lemmas.Combinators.FilterMap
import Init.Data.String.Lemmas.Pattern.Split.Basic
import Init.Data.String.Lemmas.Pattern.Char
import Init.ByCases
import Init.Data.String.OrderInstances
import Init.Data.String.Lemmas.Order
import Init.Data.String.Lemmas.Intercalate
import Init.Data.List.SplitOn.Lemmas
public section
namespace String.Slice
open Pattern.Model Pattern.Model.Char
theorem toList_splitToSubslice_char {s : Slice} {c : Char} :
(s.splitToSubslice c).toList.map (Slice.copy Subslice.toSlice) =
(s.copy.toList.splitOn c).map String.ofList := by
simp only [Pattern.toList_splitToSubslice_eq_modelSplit]
suffices (f p : s.Pos) (hle : f p) (t₁ t₂ : String),
p.Splits t₁ t₂ (Pattern.Model.split c f p hle).map (copy Subslice.toSlice) =
(t₂.toList.splitOnPPrepend (· == c) (s.subslice f p hle).copy.toList.reverse).map String.ofList by
simpa [List.splitOn_eq_splitOnP] using this s.startPos s.startPos (Std.le_refl _) "" s.copy
intro f p hle t₁ t₂ hp
induction p using Pos.next_induction generalizing f t₁ t₂ with
| next p h ih =>
obtain t₂, rfl := hp.exists_eq_singleton_append h
by_cases hpc : p.get h = c
· simp [split_eq_of_isLongestMatchAt (isLongestMatchAt_of_get_eq hpc),
ih _ (Std.le_refl _) _ _ hp.next,
List.splitOnPPrepend_cons_pos (p := (· == c)) (beq_iff_eq.2 hpc)]
· rw [split_eq_next_of_not_matchesAt h (not_matchesAt_of_get_ne hpc)]
simp only [toList_append, toList_singleton, List.cons_append, List.nil_append, Subslice.copy_eq]
rw [ih _ _ _ _ hp.next, List.splitOnPPrepend_cons_neg (by simpa)]
have := (splits_slice (Std.le_trans hle (by simp)) (p.slice f (p.next h) hle (by simp))).eq_append
simp_all
| endPos => simp_all
theorem toList_split_char {s : Slice} {c : Char} :
(s.split c).toList.map Slice.copy = (s.copy.toList.splitOn c).map String.ofList := by
simp [toList_split_eq_splitToSubslice, toList_splitToSubslice_char]
end Slice
theorem toList_split_char {s : String} {c : Char} :
(s.split c).toList.map Slice.copy = (s.toList.splitOn c).map String.ofList := by
simp [split_eq_split_toSlice, Slice.toList_split_char]
theorem Slice.toList_split_intercalate {c : Char} {l : List Slice} (hl : s l, c s.copy.toList) :
((Slice.intercalate (String.singleton c) l).split c).toList.map Slice.copy =
if l = [] then [""] else l.map Slice.copy := by
simp [String.toList_split_char]
split
· simp_all
· rw [List.splitOn_intercalate] <;> simp_all
theorem toList_split_intercalate {c : Char} {l : List String} (hl : s l, c s.toList) :
((String.intercalate (String.singleton c) l).split c).toList.map (·.copy) =
if l = [] then [""] else l := by
simp only [toList_split_char, toList_intercalate, toList_singleton]
split
· simp_all
· rw [List.splitOn_intercalate] <;> simp_all
end String

View File

@@ -0,0 +1,103 @@
/-
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Markus Himmel
-/
module
prelude
public import Init.Data.String.Slice
public import Init.Data.String.Search
public import Init.Data.List.SplitOn.Basic
import Init.Data.String.Termination
import Init.Data.Order.Lemmas
import Init.Data.Iterators.Lemmas.Combinators.FilterMap
import Init.Data.String.Lemmas.Pattern.Split.Basic
import Init.Data.String.Lemmas.Pattern.Pred
import Init.ByCases
import Init.Data.String.OrderInstances
import Init.Data.List.SplitOn.Lemmas
import Init.Data.String.Lemmas.Order
public section
namespace String.Slice
section
open Pattern.Model Pattern.Model.CharPred
theorem toList_splitToSubslice_bool {s : Slice} {p : Char Bool} :
(s.splitToSubslice p).toList.map (Slice.copy Subslice.toSlice) =
(s.copy.toList.splitOnP p).map String.ofList := by
simp only [Pattern.toList_splitToSubslice_eq_modelSplit]
suffices (f pos : s.Pos) (hle : f pos) (t₁ t₂ : String),
pos.Splits t₁ t₂ (Pattern.Model.split p f pos hle).map (copy Subslice.toSlice) =
(t₂.toList.splitOnPPrepend p (s.subslice f pos hle).copy.toList.reverse).map String.ofList by
simpa using this s.startPos s.startPos (Std.le_refl _) "" s.copy
intro f pos hle t₁ t₂ hp
induction pos using Pos.next_induction generalizing f t₁ t₂ with
| next pos h ih =>
obtain t₂, rfl := hp.exists_eq_singleton_append h
by_cases hpc : p (pos.get h)
· simp [split_eq_of_isLongestMatchAt (isLongestMatchAt_of_get hpc),
ih _ (Std.le_refl _) _ _ hp.next,
List.splitOnPPrepend_cons_pos (p := p) hpc]
· rw [Bool.not_eq_true] at hpc
rw [split_eq_next_of_not_matchesAt h (not_matchesAt_of_get hpc)]
simp only [toList_append, toList_singleton, List.cons_append, List.nil_append, Subslice.copy_eq]
rw [ih _ _ _ _ hp.next, List.splitOnPPrepend_cons_neg (by simpa)]
have := (splits_slice (Std.le_trans hle (by simp)) (pos.slice f (pos.next h) hle (by simp))).eq_append
simp_all
| endPos => simp_all
theorem toList_split_bool {s : Slice} {p : Char Bool} :
(s.split p).toList.map Slice.copy = (s.copy.toList.splitOnP p).map String.ofList := by
simp [toList_split_eq_splitToSubslice, toList_splitToSubslice_bool]
end
section
open Pattern.Model Pattern.Model.CharPred.Decidable
theorem toList_splitToSubslice_prop {s : Slice} {p : Char Prop} [DecidablePred p] :
(s.splitToSubslice p).toList.map (Slice.copy Subslice.toSlice) =
(s.copy.toList.splitOnP p).map String.ofList := by
simp only [Pattern.toList_splitToSubslice_eq_modelSplit]
suffices (f pos : s.Pos) (hle : f pos) (t₁ t₂ : String),
pos.Splits t₁ t₂ (Pattern.Model.split p f pos hle).map (copy Subslice.toSlice) =
(t₂.toList.splitOnPPrepend p (s.subslice f pos hle).copy.toList.reverse).map String.ofList by
simpa using this s.startPos s.startPos (Std.le_refl _) "" s.copy
intro f pos hle t₁ t₂ hp
induction pos using Pos.next_induction generalizing f t₁ t₂ with
| next pos h ih =>
obtain t₂, rfl := hp.exists_eq_singleton_append h
by_cases hpc : p (pos.get h)
· simp [split_eq_of_isLongestMatchAt (isLongestMatchAt_of_get hpc),
ih _ (Std.le_refl _) _ _ hp.next,
List.splitOnPPrepend_cons_pos (p := (decide <| p ·)) (by simpa using hpc)]
· rw [split_eq_next_of_not_matchesAt h (not_matchesAt_of_get hpc)]
simp only [toList_append, toList_singleton, List.cons_append, List.nil_append, Subslice.copy_eq]
rw [ih _ _ _ _ hp.next, List.splitOnPPrepend_cons_neg (by simpa)]
have := (splits_slice (Std.le_trans hle (by simp)) (pos.slice f (pos.next h) hle (by simp))).eq_append
simp_all
| endPos => simp_all
theorem toList_split_prop {s : Slice} {p : Char Prop} [DecidablePred p] :
(s.split p).toList.map Slice.copy = (s.copy.toList.splitOnP p).map String.ofList := by
simp [toList_split_eq_splitToSubslice, toList_splitToSubslice_prop]
end
end Slice
theorem toList_split_bool {s : String} {p : Char Bool} :
(s.split p).toList.map Slice.copy = (s.toList.splitOnP p).map String.ofList := by
simp [split_eq_split_toSlice, Slice.toList_split_bool]
theorem toList_split_prop {s : String} {p : Char Prop} [DecidablePred p] :
(s.split p).toList.map Slice.copy = (s.toList.splitOnP p).map String.ofList := by
simp [split_eq_split_toSlice, Slice.toList_split_prop]
end String

View File

@@ -416,14 +416,6 @@ theorem splits_singleton_iff {s : String} {p : s.Pos} {c : Char} {t : String} :
rw [ Pos.splits_toSlice_iff, Slice.splits_singleton_iff]
simp [ Pos.ofToSlice_inj]
@[simp]
theorem Slice.copy_sliceTo_startPos {s : Slice} : (s.sliceTo s.startPos).copy = "" :=
s.startPos.splits.eq_left s.splits_startPos
@[simp]
theorem copy_sliceTo_startPos {s : String} : (s.sliceTo s.startPos).copy = "" :=
s.startPos.splits.eq_left s.splits_startPos
theorem Slice.splits_next_startPos {s : Slice} {h : s.startPos s.endPos} :
(s.startPos.next h).Splits
(singleton (s.startPos.get h)) (s.sliceFrom (s.startPos.next h)).copy := by
@@ -597,4 +589,40 @@ theorem Slice.Pos.Splits.copy_sliceFrom_eq {s : Slice} {p : s.Pos} (h : p.Splits
(s.sliceFrom p).copy = t₂ :=
p.splits.eq_right h
theorem copy_slice_eq_append_of_lt {s : String} {p q : s.Pos} (h : p < q) :
(s.slice p q (by exact Std.le_of_lt h)).copy =
String.singleton (p.get (by exact Pos.ne_endPos_of_lt h)) ++
(s.slice (p.next (by exact Pos.ne_endPos_of_lt h)) q (by simpa)).copy := by
have hsp := (s.slice p q (Std.le_of_lt h)).splits_startPos
obtain t₂, ht := hsp.exists_eq_singleton_append (by simpa [ Pos.ofSlice_inj] using Std.ne_of_lt h)
have := (ht hsp).next.eq_right (Slice.Pos.splits _)
simpa [Pos.ofSlice_next, this, Pos.get_eq_get_ofSlice] using ht
@[simp]
theorem copy_slice_next {s : String} {p : s.Pos} {h} :
(s.slice p (p.next h) (by simp)).copy = String.singleton (p.get h) := by
rw [copy_slice_eq_append_of_lt (by simp), copy_slice_self, String.append_empty]
theorem splits_slice {s : String} {p₀ p₁ : s.Pos} (h) (p : (s.slice p₀ p₁ h).Pos) :
p.Splits (s.slice p₀ (Pos.ofSlice p) Pos.le_ofSlice).copy (s.slice (Pos.ofSlice p) p₁ Pos.ofSlice_le).copy := by
simpa using p.splits
theorem Slice.copy_slice_eq_append_of_lt {s : Slice} {p q : s.Pos} (h : p < q) :
(s.slice p q (by exact Std.le_of_lt h)).copy =
String.singleton (p.get (by exact Pos.ne_endPos_of_lt h)) ++
(s.slice (p.next (Pos.ne_endPos_of_lt h)) q (by simpa)).copy := by
have hsp := (s.slice p q (Std.le_of_lt h)).splits_startPos
obtain t₂, ht := hsp.exists_eq_singleton_append (by simpa [ Pos.ofSlice_inj] using Std.ne_of_lt h)
have := (ht hsp).next.eq_right (Slice.Pos.splits _)
simpa [Pos.ofSlice_next, this, Pos.get_eq_get_ofSlice] using ht
@[simp]
theorem Slice.copy_slice_next {s : Slice} {p : s.Pos} {h} :
(s.slice p (p.next h) (by simp)).copy = String.singleton (p.get h) := by
rw [copy_slice_eq_append_of_lt (by simp), copy_slice_self, String.append_empty]
theorem Slice.splits_slice {s : Slice} {p₀ p₁ : s.Pos} (h) (p : (s.slice p₀ p₁ h).Pos) :
p.Splits (s.slice p₀ (Pos.ofSlice p) Pos.le_ofSlice).copy (s.slice (Pos.ofSlice p) p₁ Pos.ofSlice_le).copy := by
simpa using p.splits
end String

View File

@@ -74,9 +74,15 @@ instance : BEq Slice where
def toString (s : Slice) : String :=
s.copy
@[simp]
theorem toString_eq : toString = copy := (rfl)
instance : ToString String.Slice where
toString := toString
@[simp]
theorem toStringToString_eq : ToString.toString = String.Slice.copy := (rfl)
@[extern "lean_slice_hash"]
opaque hash (s : @& Slice) : UInt64

View File

@@ -7,6 +7,8 @@ module
prelude
public import Init.Data.String.Basic
import Init.Data.String.Lemmas.IsEmpty
import Init.Data.String.Lemmas.Basic
set_option doc.verso true
@@ -59,6 +61,11 @@ theorem startInclusive_toSlice {s : Slice} {sl : s.Subslice} :
theorem endExclusive_toSlice {s : Slice} {sl : s.Subslice} :
sl.toSlice.endExclusive = sl.endExclusive.str := rfl
@[simp]
theorem isEmpty_toSlice_iff {s : Slice} {sl : s.Subslice} :
sl.toSlice.isEmpty sl.startInclusive = sl.endExclusive := by
simp [toSlice]
instance {s : Slice} : CoeOut s.Subslice Slice where
coe := Subslice.toSlice
@@ -76,6 +83,16 @@ def toString {s : Slice} (sl : s.Subslice) : String :=
instance {s : Slice} : ToString s.Subslice where
toString
@[simp]
theorem copy_eq {s : Slice} : copy (s := s) = Slice.copy toSlice := (rfl)
@[simp]
theorem toString_eq {s : Slice} : toString (s := s) = Slice.copy toSlice := (rfl)
@[simp]
theorem toStringToString_eq {s : Slice} :
ToString.toString (α := s.Subslice) = Slice.copy toSlice := (rfl)
end Subslice
/--
@@ -130,6 +147,15 @@ theorem startInclusive_subsliceFrom {s : Slice} {newStart : s.Pos} :
theorem endExclusive_subsliceFrom {s : Slice} {newStart : s.Pos} :
(s.subsliceFrom newStart).endExclusive = s.endPos := (rfl)
@[simp]
theorem subslice_endPos {s : Slice} {newStart : s.Pos} :
s.subslice newStart s.endPos (Slice.Pos.le_endPos _) = s.subsliceFrom newStart := (rfl)
@[simp]
theorem toSlice_subsliceFrom {s : Slice} {newStart : s.Pos} :
(s.subsliceFrom newStart).toSlice = s.sliceFrom newStart := by
ext1 <;> simp
/-- The entire slice, as a subslice of itself. -/
@[inline]
def toSubslice (s : Slice) : s.Subslice :=

View File

@@ -101,6 +101,17 @@ theorem toArray_mk {xs : Array α} (h : xs.size = n) : (Vector.mk xs h).toArray
@[simp] theorem foldr_mk {f : α β β} {b : β} {xs : Array α} (h : xs.size = n) :
(Vector.mk xs h).foldr f b = xs.foldr f b := rfl
@[simp, grind =] theorem foldlM_toArray [Monad m]
{f : β α m β} {init : β} {xs : Vector α n} :
xs.toArray.foldlM f init = xs.foldlM f init := rfl
@[simp, grind =] theorem foldrM_toArray [Monad m]
{f : α β m β} {init : β} {xs : Vector α n} :
xs.toArray.foldrM f init = xs.foldrM f init := rfl
@[simp, grind =] theorem foldl_toArray (f : β α β) {init : β} {xs : Vector α n} :
xs.toArray.foldl f init = xs.foldl f init := rfl
@[simp] theorem drop_mk {xs : Array α} {h : xs.size = n} {i} :
(Vector.mk xs h).drop i = Vector.mk (xs.extract i xs.size) (by simp [h]) := rfl
@@ -514,17 +525,32 @@ protected theorem ext {xs ys : Vector α n} (h : (i : Nat) → (_ : i < n) → x
@[grind =_] theorem toList_toArray {xs : Vector α n} : xs.toArray.toList = xs.toList := rfl
theorem toArray_toList {xs : Vector α n} : xs.toList.toArray = xs.toArray := rfl
@[simp, grind =] theorem foldlM_toList [Monad m]
{f : β α m β} {init : β} {xs : Vector α n} :
xs.toList.foldlM f init = xs.foldlM f init := by
rw [ foldlM_toArray, toArray_toList, List.foldlM_toArray]
@[simp, grind =] theorem foldl_toList (f : β α β) {init : β} {xs : Vector α n} :
xs.toList.foldl f init = xs.foldl f init :=
List.foldl_eq_foldlM .. foldlM_toList ..
@[simp, grind =] theorem foldrM_toList [Monad m]
{f : α β m β} {init : β} {xs : Vector α n} :
xs.toList.foldrM f init = xs.foldrM f init := by
rw [ foldrM_toArray, toArray_toList, List.foldrM_toArray]
@[simp, grind =] theorem foldr_toList (f : α β β) {init : β} {xs : Vector α n} :
xs.toList.foldr f init = xs.foldr f init :=
List.foldr_eq_foldrM .. foldrM_toList ..
@[simp, grind =] theorem toList_mk : (Vector.mk xs h).toList = xs.toList := rfl
@[simp, grind =] theorem sum_toList [Add α] [Zero α] {xs : Vector α n} :
xs.toList.sum = xs.sum := by
rw [ toList_toArray, Array.sum_toList, sum_toArray]
@[simp, grind =]
theorem toList_zip {as : Vector α n} {bs : Vector β n} :
(Vector.zip as bs).toList = List.zip as.toList bs.toList := by
rw [mk_zip_mk, toList_mk, Array.toList_zip, toList_toArray, toList_toArray]
@[simp] theorem getElem_toList {xs : Vector α n} {i : Nat} (h : i < xs.toList.length) :
xs.toList[i] = xs[i]'(by simpa using h) := by
cases xs
@@ -609,6 +635,11 @@ theorem toList_swap {xs : Vector α n} {i j} (hi hj) :
@[simp] theorem toList_take {xs : Vector α n} {i} : (xs.take i).toList = xs.toList.take i := by
simp [toList]
@[simp, grind =]
theorem toList_zip {as : Vector α n} {bs : Vector β n} :
(Vector.zip as bs).toList = List.zip as.toList bs.toList := by
rw [mk_zip_mk, toList_mk, Array.toList_zip, toList_toArray, toList_toArray]
@[simp] theorem toList_zipWith {f : α β γ} {as : Vector α n} {bs : Vector β n} :
(Vector.zipWith f as bs).toList = List.zipWith f as.toList bs.toList := by
rcases as with as, rfl
@@ -703,6 +734,9 @@ protected theorem eq_empty {xs : Vector α 0} : xs = #v[] := by
/-! ### size -/
theorem size_singleton {x : α} : #v[x].size = 1 := by
simp
theorem eq_empty_of_size_eq_zero {xs : Vector α n} (h : n = 0) : xs = #v[].cast h.symm := by
rcases xs with xs, rfl
apply toArray_inj.1
@@ -2448,6 +2482,21 @@ theorem foldl_eq_foldr_reverse {xs : Vector α n} {f : β → α → β} {b} :
theorem foldr_eq_foldl_reverse {xs : Vector α n} {f : α β β} {b} :
xs.foldr f b = xs.reverse.foldl (fun x y => f y x) b := by simp
theorem foldl_eq_apply_foldr {xs : Vector α n} {f : α α α}
[Std.Associative f] [Std.LawfulRightIdentity f init] :
xs.foldl f x = f x (xs.foldr f init) := by
simp [ foldl_toList, foldr_toList, List.foldl_eq_apply_foldr]
theorem foldr_eq_apply_foldl {xs : Vector α n} {f : α α α}
[Std.Associative f] [Std.LawfulLeftIdentity f init] :
xs.foldr f x = f (xs.foldl f init) x := by
simp [ foldl_toList, foldr_toList, List.foldr_eq_apply_foldl]
theorem foldr_eq_foldl {xs : Vector α n} {f : α α α}
[Std.Associative f] [Std.LawfulIdentity f init] :
xs.foldr f init = xs.foldl f init := by
simp [foldl_eq_apply_foldr, Std.LawfulLeftIdentity.left_id]
theorem foldl_assoc {op : α α α} [ha : Std.Associative op] {xs : Vector α n} {a₁ a₂} :
xs.foldl op (op a₁ a₂) = op a₁ (xs.foldl op a₂) := by
rcases xs with xs, rfl
@@ -3064,8 +3113,25 @@ theorem sum_append [Zero α] [Add α] [Std.Associative (α := α) (· + ·)]
{as₁ as₂ : Vector α n} : (as₁ ++ as₂).sum = as₁.sum + as₂.sum := by
simp [ sum_toList, List.sum_append]
@[simp, grind =]
theorem sum_singleton [Add α] [Zero α] [Std.LawfulRightIdentity (· + ·) (0 : α)] {x : α} :
#v[x].sum = x := by
simp [ sum_toList, Std.LawfulRightIdentity.right_id x]
@[simp, grind =]
theorem sum_push [Add α] [Zero α] [Std.Associative (α := α) (· + ·)]
[Std.LawfulIdentity (· + ·) (0 : α)] {xs : Vector α n} {x : α} :
(xs.push x).sum = xs.sum + x := by
simp [ sum_toArray]
@[simp, grind =]
theorem sum_reverse [Zero α] [Add α] [Std.Associative (α := α) (· + ·)]
[Std.Commutative (α := α) (· + ·)]
[Std.LawfulLeftIdentity (α := α) (· + ·) 0] (xs : Vector α n) : xs.reverse.sum = xs.sum := by
simp [ sum_toList, List.sum_reverse]
theorem sum_eq_foldl [Zero α] [Add α]
[Std.Associative (α := α) (· + ·)] [Std.LawfulIdentity (· + ·) (0 : α)]
{xs : Vector α n} :
xs.sum = xs.foldl (b := 0) (· + ·) := by
simp [ sum_toList, List.sum_eq_foldl]

View File

@@ -910,6 +910,8 @@ When messages contain autogenerated names (e.g., metavariables like `?m.47`), th
differ between runs or Lean versions. Use `set_option pp.mvars.anonymous false` to replace
anonymous metavariables with `?_` while preserving user-named metavariables like `?a`.
Alternatively, `set_option pp.mvars false` replaces all metavariables with `?_`.
Similarly, `set_option pp.fvars.anonymous false` replaces loose free variable names like
`_fvar.22` with `_fvar._`.
For example, `#guard_msgs (error, drop all) in cmd` means to check errors and drop
everything else.

View File

@@ -144,7 +144,10 @@ theorem mul_def (xs ys : IntList) : xs * ys = List.zipWith (· * ·) xs ys :=
@[simp] theorem mul_nil_left : ([] : IntList) * ys = [] := rfl
@[simp] theorem mul_nil_right : xs * ([] : IntList) = [] := List.zipWith_nil_right
@[simp] theorem mul_cons : (x::xs : IntList) * (y::ys) = (x * y) :: (xs * ys) := rfl
@[simp] theorem mul_cons_cons : (x::xs : IntList) * (y::ys) = (x * y) :: (xs * ys) := rfl
@[deprecated mul_cons_cons (since := "2026-02-26")]
theorem mul_cons₂ : (x::xs : IntList) * (y::ys) = (x * y) :: (xs * ys) := mul_cons_cons
/-- Implementation of negation on `IntList`. -/
def neg (xs : IntList) : IntList := xs.map fun x => -x
@@ -278,7 +281,10 @@ example : IntList.dot [a, b, c] [x, y, z] = IntList.dot [a, b, c] [x, y, z, w] :
@[local simp] theorem dot_nil_left : dot ([] : IntList) ys = 0 := rfl
@[simp] theorem dot_nil_right : dot xs ([] : IntList) = 0 := by simp [dot]
@[simp] theorem dot_cons : dot (x::xs) (y::ys) = x * y + dot xs ys := rfl
@[simp] theorem dot_cons_cons : dot (x::xs) (y::ys) = x * y + dot xs ys := rfl
@[deprecated dot_cons_cons (since := "2026-02-26")]
theorem dot_cons₂ : dot (x::xs) (y::ys) = x * y + dot xs ys := dot_cons_cons
-- theorem dot_comm (xs ys : IntList) : dot xs ys = dot ys xs := by
-- rw [dot, dot, mul_comm]
@@ -296,7 +302,7 @@ example : IntList.dot [a, b, c] [x, y, z] = IntList.dot [a, b, c] [x, y, z, w] :
cases ys with
| nil => simp
| cons y ys =>
simp only [set_cons_zero, dot_cons, get_cons_zero, Int.sub_mul]
simp only [set_cons_zero, dot_cons_cons, get_cons_zero, Int.sub_mul]
rw [Int.add_right_comm, Int.add_comm (x * y), Int.sub_add_cancel]
| succ i =>
cases ys with
@@ -319,7 +325,7 @@ theorem dot_of_left_zero (w : ∀ x, x ∈ xs → x = 0) : dot xs ys = 0 := by
cases ys with
| nil => simp
| cons y ys =>
rw [dot_cons, w x (by simp [List.mem_cons_self]), ih]
rw [dot_cons_cons, w x (by simp [List.mem_cons_self]), ih]
· simp
· intro x m
apply w
@@ -400,7 +406,7 @@ attribute [simp] Int.zero_dvd
cases ys with
| nil => simp
| cons y ys =>
rw [dot_cons, Int.add_emod,
rw [dot_cons_cons, Int.add_emod,
Int.emod_emod_of_dvd (x * y) (gcd_cons_div_left),
Int.emod_emod_of_dvd (dot xs ys) (Int.ofNat_dvd.mpr gcd_cons_div_right)]
simp_all
@@ -415,7 +421,7 @@ theorem dot_eq_zero_of_left_eq_zero {xs ys : IntList} (h : ∀ x, x ∈ xs → x
cases ys with
| nil => rfl
| cons y ys =>
rw [dot_cons, h x List.mem_cons_self, ih (fun x m => h x (List.mem_cons_of_mem _ m)),
rw [dot_cons_cons, h x List.mem_cons_self, ih (fun x m => h x (List.mem_cons_of_mem _ m)),
Int.zero_mul, Int.add_zero]
@[simp] theorem nil_dot (xs : IntList) : dot [] xs = 0 := rfl
@@ -456,7 +462,7 @@ theorem dvd_bmod_dot_sub_dot_bmod (m : Nat) (xs ys : IntList) :
cases ys with
| nil => simp
| cons y ys =>
simp only [IntList.dot_cons, List.map_cons]
simp only [IntList.dot_cons_cons, List.map_cons]
specialize ih ys
rw [Int.sub_emod, Int.bmod_emod] at ih
rw [Int.sub_emod, Int.bmod_emod, Int.add_emod, Int.add_emod (Int.bmod x m * y),

View File

@@ -32,6 +32,89 @@ unsafe axiom lcAny : Type
/-- Internal representation of `Void` in the compiler. -/
unsafe axiom lcVoid : Type
set_option bootstrap.inductiveCheckResultingUniverse false in
/--
The canonical universe-polymorphic type with just one element.
It should be used in contexts that require a type to be universe polymorphic, thus disallowing
`Unit`.
-/
inductive PUnit : Sort u where
/-- The only element of the universe-polymorphic unit type. -/
| unit : PUnit
/--
The equality relation. It has one introduction rule, `Eq.refl`.
We use `a = b` as notation for `Eq a b`.
A fundamental property of equality is that it is an equivalence relation.
```
variable (α : Type) (a b c d : α)
variable (hab : a = b) (hcb : c = b) (hcd : c = d)
example : a = d :=
Eq.trans (Eq.trans hab (Eq.symm hcb)) hcd
```
Equality is much more than an equivalence relation, however. It has the important property that every assertion
respects the equivalence, in the sense that we can substitute equal expressions without changing the truth value.
That is, given `h1 : a = b` and `h2 : p a`, we can construct a proof for `p b` using substitution: `Eq.subst h1 h2`.
Example:
```
example (α : Type) (a b : α) (p : α → Prop)
(h1 : a = b) (h2 : p a) : p b :=
Eq.subst h1 h2
example (α : Type) (a b : α) (p : α → Prop)
(h1 : a = b) (h2 : p a) : p b :=
h1 ▸ h2
```
The triangle in the second presentation is a macro built on top of `Eq.subst` and `Eq.symm`, and you can enter it by typing `\t`.
For more information: [Equality](https://lean-lang.org/theorem_proving_in_lean4/quantifiers_and_equality.html#equality)
-/
inductive Eq : α α Prop where
/-- `Eq.refl a : a = a` is reflexivity, the unique constructor of the
equality type. See also `rfl`, which is usually used instead. -/
| refl (a : α) : Eq a a
/-- Non-dependent recursor for the equality type. -/
@[simp] abbrev Eq.ndrec.{u1, u2} {α : Sort u2} {a : α} {motive : α Sort u1} (m : motive a) {b : α} (h : Eq a b) : motive b :=
h.rec m
/--
Heterogeneous equality. `a ≍ b` asserts that `a` and `b` have the same
type, and casting `a` across the equality yields `b`, and vice versa.
You should avoid using this type if you can. Heterogeneous equality does not
have all the same properties as `Eq`, because the assumption that the types of
`a` and `b` are equal is often too weak to prove theorems of interest. One
important non-theorem is the analogue of `congr`: If `f ≍ g` and `x ≍ y`
and `f x` and `g y` are well typed it does not follow that `f x ≍ g y`.
(This does follow if you have `f = g` instead.) However if `a` and `b` have
the same type then `a = b` and `a ≍ b` are equivalent.
-/
inductive HEq : {α : Sort u} α {β : Sort u} β Prop where
/-- Reflexivity of heterogeneous equality. -/
| refl (a : α) : HEq a a
/--
The Boolean values, `true` and `false`.
Logically speaking, this is equivalent to `Prop` (the type of propositions). The distinction is
important for programming: both propositions and their proofs are erased in the code generator,
while `Bool` corresponds to the Boolean type in most programming languages and carries precisely one
bit of run-time information.
-/
inductive Bool : Type where
/-- The Boolean value `false`, not to be confused with the proposition `False`. -/
| false : Bool
/-- The Boolean value `true`, not to be confused with the proposition `True`. -/
| true : Bool
export Bool (false true)
/-- Compute whether `x` is a tagged pointer or not. -/
@[extern "lean_is_scalar"]
unsafe axiom isScalarObj {α : Type u} (x : α) : Bool
/--
The identity function. `id` takes an implicit argument `α : Sort u`
@@ -115,16 +198,7 @@ does.) Example:
-/
abbrev inferInstanceAs (α : Sort u) [i : α] : α := i
set_option bootstrap.inductiveCheckResultingUniverse false in
/--
The canonical universe-polymorphic type with just one element.
It should be used in contexts that require a type to be universe polymorphic, thus disallowing
`Unit`.
-/
inductive PUnit : Sort u where
/-- The only element of the universe-polymorphic unit type. -/
| unit : PUnit
/--
The canonical type with one element. This element is written `()`.
@@ -245,42 +319,6 @@ For more information: [Propositional Logic](https://lean-lang.org/theorem_provin
@[macro_inline] def absurd {a : Prop} {b : Sort v} (h₁ : a) (h₂ : Not a) : b :=
(h₂ h₁).rec
/--
The equality relation. It has one introduction rule, `Eq.refl`.
We use `a = b` as notation for `Eq a b`.
A fundamental property of equality is that it is an equivalence relation.
```
variable (α : Type) (a b c d : α)
variable (hab : a = b) (hcb : c = b) (hcd : c = d)
example : a = d :=
Eq.trans (Eq.trans hab (Eq.symm hcb)) hcd
```
Equality is much more than an equivalence relation, however. It has the important property that every assertion
respects the equivalence, in the sense that we can substitute equal expressions without changing the truth value.
That is, given `h1 : a = b` and `h2 : p a`, we can construct a proof for `p b` using substitution: `Eq.subst h1 h2`.
Example:
```
example (α : Type) (a b : α) (p : α → Prop)
(h1 : a = b) (h2 : p a) : p b :=
Eq.subst h1 h2
example (α : Type) (a b : α) (p : α → Prop)
(h1 : a = b) (h2 : p a) : p b :=
h1 ▸ h2
```
The triangle in the second presentation is a macro built on top of `Eq.subst` and `Eq.symm`, and you can enter it by typing `\t`.
For more information: [Equality](https://lean-lang.org/theorem_proving_in_lean4/quantifiers_and_equality.html#equality)
-/
inductive Eq : α α Prop where
/-- `Eq.refl a : a = a` is reflexivity, the unique constructor of the
equality type. See also `rfl`, which is usually used instead. -/
| refl (a : α) : Eq a a
/-- Non-dependent recursor for the equality type. -/
@[simp] abbrev Eq.ndrec.{u1, u2} {α : Sort u2} {a : α} {motive : α Sort u1} (m : motive a) {b : α} (h : Eq a b) : motive b :=
h.rec m
/--
`rfl : a = a` is the unique constructor of the equality type. This is the
same as `Eq.refl` except that it takes `a` implicitly instead of explicitly.
@@ -477,21 +515,6 @@ Unsafe auxiliary constant used by the compiler to erase `Quot.lift`.
-/
unsafe axiom Quot.lcInv {α : Sort u} {r : α α Prop} (q : Quot r) : α
/--
Heterogeneous equality. `a ≍ b` asserts that `a` and `b` have the same
type, and casting `a` across the equality yields `b`, and vice versa.
You should avoid using this type if you can. Heterogeneous equality does not
have all the same properties as `Eq`, because the assumption that the types of
`a` and `b` are equal is often too weak to prove theorems of interest. One
important non-theorem is the analogue of `congr`: If `f ≍ g` and `x ≍ y`
and `f x` and `g y` are well typed it does not follow that `f x ≍ g y`.
(This does follow if you have `f = g` instead.) However if `a` and `b` have
the same type then `a = b` and `a ≍ b` are equivalent.
-/
inductive HEq : {α : Sort u} α {β : Sort u} β Prop where
/-- Reflexivity of heterogeneous equality. -/
| refl (a : α) : HEq a a
/-- A version of `HEq.refl` with an implicit argument. -/
@[match_pattern] protected def HEq.rfl {α : Sort u} {a : α} : HEq a a :=
@@ -599,23 +622,6 @@ theorem Or.resolve_left (h: Or a b) (na : Not a) : b := h.elim (absurd · na) i
theorem Or.resolve_right (h: Or a b) (nb : Not b) : a := h.elim id (absurd · nb)
theorem Or.neg_resolve_left (h : Or (Not a) b) (ha : a) : b := h.elim (absurd ha) id
theorem Or.neg_resolve_right (h : Or a (Not b)) (nb : b) : a := h.elim id (absurd nb)
/--
The Boolean values, `true` and `false`.
Logically speaking, this is equivalent to `Prop` (the type of propositions). The distinction is
important for programming: both propositions and their proofs are erased in the code generator,
while `Bool` corresponds to the Boolean type in most programming languages and carries precisely one
bit of run-time information.
-/
inductive Bool : Type where
/-- The Boolean value `false`, not to be confused with the proposition `False`. -/
| false : Bool
/-- The Boolean value `true`, not to be confused with the proposition `True`. -/
| true : Bool
export Bool (false true)
/--
All the elements of a type that satisfy a predicate.

View File

@@ -30,6 +30,7 @@ variable {α : Sort _} {β : α → Sort _} {γ : (a : α) → β a → Sort _}
set_option doc.verso true
namespace WellFounded
open Relation
/--
The function implemented as the loop {lean}`opaqueFix R F a = F a (fun a _ => opaqueFix R F a)`.
@@ -85,6 +86,23 @@ public theorem extrinsicFix_eq_apply [∀ a, Nonempty (C a)] {R : αα
simp only [extrinsicFix, dif_pos h]
rw [WellFounded.fix_eq]
public theorem extrinsicFix_invImage {α' : Sort _} [ a, Nonempty (C a)] (R : α α Prop) (f : α' α)
(F : a, ( a', R a' a C a') C a) (F' : a, ( a', R (f a') (f a) C (f a')) C (f a))
(h : a r, F (f a) r = F' a fun a' hR => r (f a') hR) (a : α') (h : WellFounded R) :
extrinsicFix (C := (C <| f ·)) (InvImage R f) F' a = extrinsicFix (C := C) R F (f a) := by
have h' := h
rcases h with h
specialize h (f a)
have : Acc (InvImage R f) a := InvImage.accessible _ h
clear h
induction this
rename_i ih
rw [extrinsicFix_eq_apply, extrinsicFix_eq_apply, h]
· congr; ext a x
rw [ih _ x]
· assumption
· exact InvImage.wf _ _
/--
A fixpoint combinator that allows for deferred proofs of termination.
@@ -242,4 +260,273 @@ nontrivial properties about it.
-/
add_decl_doc extrinsicFix₃
/--
A fixpoint combinator that can be used to construct recursive functions with an
*extrinsic, partial* proof of termination.
Given a relation {name}`R` and a fixpoint functional {name}`F` which must be decreasing with respect
to {name}`R`, {lean}`partialExtrinsicFix R F` is the recursive function obtained by having {name}`F` call
itself recursively.
For each input {given}`a`, {lean}`partialExtrinsicFix R F a` can be verified given a *partial* termination
proof. The precise semantics are as follows.
If {lean}`Acc R a` does not hold, {lean}`partialExtrinsicFix R F a` might run forever. In this case,
nothing interesting can be proved about the recursive function; it is opaque and behaves like a
recursive function with the `partial` modifier.
If {lean}`Acc R a` _does_ hold, {lean}`partialExtrinsicFix R F a` is equivalent to
{lean}`F a (fun a' _ => partialExtrinsicFix R F a')`, both logically and regarding its termination behavior.
In particular, if {name}`R` is well-founded, {lean}`partialExtrinsicFix R F a` is equivalent to
{lean}`WellFounded.fix _ F`.
-/
@[inline]
public def partialExtrinsicFix [ a, Nonempty (C a)] (R : α α Prop)
(F : a, ( a', R a' a C a') C a) (a : α) : C a :=
extrinsicFix (α := { a' : α // a' = a TransGen R a' a }) (C := (C ·.1))
(fun p q => R p.1 q.1)
(fun a recur => F a.1 fun a' hR => recur a', by
rcases a.property with ha | ha
· exact Or.inr (TransGen.single (ha hR))
· apply Or.inr
apply TransGen.trans ?_ _
apply TransGen.single
assumption _) a, Or.inl rfl
public theorem partialExtrinsicFix_eq_apply_of_acc [ a, Nonempty (C a)] {R : α α Prop}
{F : a, ( a', R a' a C a') C a} {a : α} (h : Acc R a) :
partialExtrinsicFix R F a = F a (fun a' _ => partialExtrinsicFix R F a') := by
simp only [partialExtrinsicFix]
rw [extrinsicFix_eq_apply]
congr; ext a' hR
let f (x : { x : α // x = a' TransGen R x a' }) : { x : α // x = a TransGen R x a } :=
x.val, by
cases x.property
· rename_i h
rw [h]
exact Or.inr (.single hR)
· rename_i h
apply Or.inr
refine TransGen.trans h ?_
exact .single hR
have := extrinsicFix_invImage (C := (C ·.val)) (R := (R ·.1 ·.1)) (f := f)
(F := fun a r => F a.1 fun a' hR => r a', Or.inr (by rcases a.2 with ha | ha; exact .single (ha hR); exact .trans (.single hR) _) hR)
(F' := fun a r => F a.1 fun a' hR => r a', by rcases a.2 with ha | ha; exact .inr (.single (ha hR)); exact .inr (.trans (.single hR) _) hR)
unfold InvImage at this
rw [this]
· simp +zetaDelta
· constructor
intro x
refine InvImage.accessible _ ?_
cases x.2 <;> rename_i hx
· rwa [hx]
· exact h.inv_of_transGen hx
· constructor
intro x
refine InvImage.accessible _ ?_
cases x.2 <;> rename_i hx
· rwa [hx]
· exact h.inv_of_transGen hx
public theorem partialExtrinsicFix_eq_apply [ a, Nonempty (C a)] {R : α α Prop}
{F : a, ( a', R a' a C a') C a} {a : α} (wf : WellFounded R) :
partialExtrinsicFix R F a = F a (fun a' _ => partialExtrinsicFix R F a') :=
partialExtrinsicFix_eq_apply_of_acc (wf.apply _)
public theorem partialExtrinsicFix_eq_fix [ a, Nonempty (C a)] {R : α α Prop}
{F : a, ( a', R a' a C a') C a}
(wf : WellFounded R) {a : α} :
partialExtrinsicFix R F a = wf.fix F a := by
have h := wf.apply a
induction h with | intro a' h ih
rw [partialExtrinsicFix_eq_apply_of_acc (Acc.intro _ h), WellFounded.fix_eq]
congr 1; ext a'' hR
exact ih _ hR
/--
A 2-ary fixpoint combinator that can be used to construct recursive functions with an
*extrinsic, partial* proof of termination.
Given a relation {name}`R` and a fixpoint functional {name}`F` which must be decreasing with respect
to {name}`R`, {lean}`partialExtrinsicFix₂ R F` is the recursive function obtained by having {name}`F` call
itself recursively.
For each pair of inputs {given}`a` and {given}`b`, {lean}`partialExtrinsicFix₂ R F a b` can be verified
given a *partial* termination proof. The precise semantics are as follows.
If {lean}`Acc R ⟨a, b⟩ ` does not hold, {lean}`partialExtrinsicFix₂ R F a b` might run forever. In this
case, nothing interesting can be proved about the recursive function; it is opaque and behaves like
a recursive function with the `partial` modifier.
If {lean}`Acc R ⟨a, b⟩` _does_ hold, {lean}`partialExtrinsicFix₂ R F a b` is equivalent to
{lean}`F a b (fun a' b' _ => partialExtrinsicFix₂ R F a' b')`, both logically and regarding its
termination behavior.
In particular, if {name}`R` is well-founded, {lean}`partialExtrinsicFix₂ R F a b` is equivalent to
a well-founded fixpoint.
-/
@[inline]
public def partialExtrinsicFix₂ [ a b, Nonempty (C₂ a b)]
(R : (a : α) ×' β a (a : α) ×' β a Prop)
(F : (a : α) (b : β a) ((a' : α) (b' : β a') R a', b' a, b C₂ a' b') C₂ a b)
(a : α) (b : β a) :
C₂ a b :=
extrinsicFix₂ (α := α) (β := fun a' => { b' : β a' // (PSigma.mk a' b') = (PSigma.mk a b) TransGen R a', b' a, b })
(C₂ := (C₂ · ·.1))
(fun p q => R p.1, p.2.1 q.1, q.2.1)
(fun a b recur => F a b.1 fun a' b' hR => recur a' b', Or.inr (by
rcases b.property with hb | hb
· exact .single (hb hR)
· apply TransGen.trans ?_ _
apply TransGen.single
assumption) _) a b, Or.inl rfl
/--
Assuming the input pair is accessible, the 2-ary combinator agrees with the unary
`partialExtrinsicFix` applied to the arguments bundled as a `PSigma`.
-/
public theorem partialExtrinsicFix₂_eq_partialExtrinsicFix [ a b, Nonempty (C₂ a b)]
{R : (a : α) ×' β a (a : α) ×' β a Prop}
{F : (a : α) (b : β a) ((a' : α) (b' : β a') R a', b' a, b C₂ a' b') C₂ a b}
{a : α} {b : β a} (h : Acc R a, b) :
partialExtrinsicFix₂ R F a b = partialExtrinsicFix (α := PSigma β) (C := fun a => C₂ a.1 a.2) R (fun p r => F p.1 p.2 fun a' b' hR => r a', b' hR) a, b := by
simp only [partialExtrinsicFix, partialExtrinsicFix₂, extrinsicFix₂]
-- Mediating map between the two subtype encodings of "equal to or below the input".
let f (x : ((a' : α) ×' { b' // PSigma.mk a' b' = a, b TransGen R a', b' a, b })) : { a' // a' = a, b TransGen R a' a, b } :=
x.1, x.2.1, x.2.2
-- Transport along `f` using the `InvImage` lemma for `extrinsicFix`; the named
-- `?refine_*` goals are the invariant obligations discharged below.
have := extrinsicFix_invImage (C := fun a => C₂ a.1.1 a.1.2) (f := f) (R := (R ·.1 ·.1))
(F := fun a r => F a.1.1 a.1.2 fun a' b' hR => r a', b', ?refine_a hR)
(F' := fun a r => F a.1 a.2.1 fun a' b' hR => r a', b', ?refine_b hR)
(a := a, b, ?refine_c); rotate_left
· cases a.2 <;> rename_i heq
· rw [heq] at hR
exact .inr (.single hR)
· exact .inr (.trans (.single hR) heq)
· cases a.2.2 <;> rename_i heq
· rw [heq] at hR
exact .inr (.single hR)
· exact .inr (.trans (.single hR) heq)
· exact .inl rfl
unfold InvImage f at this
simp at this
rw [this]
constructor
intro x
apply InvImage.accessible
-- Accessibility transfers from `h`: directly for the input itself, and via
-- `inv_of_transGen` for anything strictly below it.
cases x.2 <;> rename_i heq
· rwa [heq]
· exact h.inv_of_transGen heq
/--
Under accessibility of the input pair, `partialExtrinsicFix₂` satisfies its
defining one-step unfolding equation.
-/
public theorem partialExtrinsicFix₂_eq_apply_of_acc [ a b, Nonempty (C₂ a b)]
{R : (a : α) ×' β a (a : α) ×' β a Prop}
{F : (a : α) (b : β a) ((a' : α) (b' : β a') R a', b' a, b C₂ a' b') C₂ a b}
{a : α} {b : β a} (wf : Acc R a, b) :
partialExtrinsicFix₂ R F a b = F a b (fun a' b' _ => partialExtrinsicFix₂ R F a' b') := by
-- Reduce to the unary combinator, unfold it once, then translate back.
rw [partialExtrinsicFix₂_eq_partialExtrinsicFix wf, partialExtrinsicFix_eq_apply_of_acc wf]
congr 1; ext a' b' hR
rw [partialExtrinsicFix₂_eq_partialExtrinsicFix (wf.inv hR)]
/--
If `R` is well-founded, the one-step unfolding equation for `partialExtrinsicFix₂`
holds at every input (accessibility comes for free from `wf`).
-/
public theorem partialExtrinsicFix₂_eq_apply [ a b, Nonempty (C₂ a b)]
{R : (a : α) ×' β a (a : α) ×' β a Prop}
{F : (a : α) (b : β a) ((a' : α) (b' : β a') R a', b' a, b C₂ a' b') C₂ a b}
{a : α} {b : β a} (wf : WellFounded R) :
partialExtrinsicFix₂ R F a b = F a b (fun a' b' _ => partialExtrinsicFix₂ R F a' b') :=
partialExtrinsicFix₂_eq_apply_of_acc (wf.apply _)
/--
If `R` is well-founded, `partialExtrinsicFix₂` agrees with the well-founded
fixpoint `wf.fix` over the `PSigma`-bundled argument pair.
-/
public theorem partialExtrinsicFix₂_eq_fix [ a b, Nonempty (C₂ a b)]
{R : (a : α) ×' β a (a : α) ×' β a Prop}
{F : a b, ( a' b', R a', b' a, b C₂ a' b') C₂ a b}
(wf : WellFounded R) {a b} :
partialExtrinsicFix₂ R F a b = wf.fix (fun x G => F x.1 x.2 (fun a b h => G a, b h)) a, b := by
rw [partialExtrinsicFix₂_eq_partialExtrinsicFix (wf.apply _), partialExtrinsicFix_eq_fix wf]
/--
A 3-ary fixpoint combinator that can be used to construct recursive functions with an
*extrinsic, partial* proof of termination.
Given a relation {name}`R` and a fixpoint functional {name}`F` which must be decreasing with respect
to {name}`R`, {lean}`partialExtrinsicFix₃ R F` is the recursive function obtained by having {name}`F` call
itself recursively.
For each triple of inputs {given}`a`, {given}`b` and {given}`c`, {lean}`partialExtrinsicFix₃ R F a b c` can be
verified given a *partial* termination proof. The precise semantics are as follows.
If {lean}`Acc R ⟨a, b, c⟩` does not hold, {lean}`partialExtrinsicFix₃ R F a b c` might run forever. In this
case, nothing interesting can be proved about the recursive function; it is opaque and behaves like
a recursive function with the `partial` modifier.
If {lean}`Acc R ⟨a, b, c⟩` _does_ hold, {lean}`partialExtrinsicFix₃ R F a b c` is equivalent to
{lean}`F a b c (fun a' b' c' _ => partialExtrinsicFix₃ R F a' b' c')`, both logically and regarding its
termination behavior.
In particular, if {name}`R` is well-founded, {lean}`partialExtrinsicFix₃ R F a b c` is equivalent to
a well-founded fixpoint.
-/
@[inline]
public def partialExtrinsicFix₃ [ a b c, Nonempty (C₃ a b c)]
(R : (a : α) ×' (b : β a) ×' γ a b (a : α) ×' (b : β a) ×' γ a b Prop)
(F : (a : α) (b : β a) (c : γ a b) ((a' : α) (b' : β a') (c' : γ a' b') R a', b', c' a, b, c C₃ a' b' c') C₃ a b c)
(a : α) (b : β a) (c : γ a b) :
C₃ a b c :=
-- Same construction as `partialExtrinsicFix₂`, one arity up: refine `γ` to a
-- subtype recording that the current triple either equals the original input or
-- lies below it in the transitive closure `TransGen R`.
extrinsicFix₃ (α := α) (β := β) (γ := fun a' b' => { c' : γ a' b' // (a', b', c' : (a : α) ×' (b : β a) ×' γ a b) = a, b, c TransGen R a', b', c' a, b, c })
(C₃ := (C₃ · · ·.1))
(fun p q => R p.1, p.2.1, p.2.2.1 q.1, q.2.1, q.2.2.1)
-- Each recursive call extends the recorded invariant by one `R`-step.
(fun a b c recur => F a b c.1 fun a' b' c' hR => recur a' b' c', Or.inr (by
rcases c.property with hb | hb
· exact .single (hb hR)
· apply TransGen.trans ?_ _
apply TransGen.single
assumption) _) a b c, Or.inl rfl
/--
Assuming the input triple is accessible, the 3-ary combinator agrees with the unary
`partialExtrinsicFix` applied to the arguments bundled as nested `PSigma`s.
-/
public theorem partialExtrinsicFix₃_eq_partialExtrinsicFix [ a b c, Nonempty (C₃ a b c)]
{R : (a : α) ×' (b : β a) ×' γ a b (a : α) ×' (b : β a) ×' γ a b Prop}
{F : (a : α) (b : β a) (c : γ a b) ((a' : α) (b' : β a') (c' : γ a' b') R a', b', c' a, b, c C₃ a' b' c') C₃ a b c}
{a : α} {b : β a} {c : γ a b} (h : Acc R a, b, c) :
partialExtrinsicFix₃ R F a b c = partialExtrinsicFix (α := (a : α) ×' (b : β a) ×' γ a b) (C := fun a => C₃ a.1 a.2.1 a.2.2) R (fun p r => F p.1 p.2.1 p.2.2 fun a' b' c' hR => r a', b', c' hR) a, b, c := by
simp only [partialExtrinsicFix, partialExtrinsicFix₃, extrinsicFix₃]
-- Mediating map between the two subtype encodings of "equal to or below the input".
let f (x : ((a' : α) ×' (b' : β a') ×' { c' // (a', b', c' : (a : α) ×' (b : β a) ×' γ a b) = a, b, c TransGen R a', b', c' a, b, c })) : { a' // a' = a, b, c TransGen R a' a, b, c } :=
x.1, x.2.1, x.2.2.1, x.2.2.2
-- Transport along `f` using the `InvImage` lemma for `extrinsicFix`; the named
-- `?refine_*` goals are the invariant obligations discharged below.
have := extrinsicFix_invImage (C := fun a => C₃ a.1.1 a.1.2.1 a.1.2.2) (f := f) (R := (R ·.1 ·.1))
(F := fun a r => F a.1.1 a.1.2.1 a.1.2.2 fun a' b' c' hR => r a', b', c', ?refine_a hR)
(F' := fun a r => F a.1 a.2.1 a.2.2.1 fun a' b' c' hR => r a', b', c', ?refine_b hR)
(a := a, b, c, ?refine_c); rotate_left
· cases a.2 <;> rename_i heq
· rw [heq] at hR
exact .inr (.single hR)
· exact .inr (.trans (.single hR) heq)
· cases a.2.2.2 <;> rename_i heq
· rw [heq] at hR
exact .inr (.single hR)
· exact .inr (.trans (.single hR) heq)
· exact .inl rfl
unfold InvImage f at this
simp at this
rw [this]
constructor
intro x
apply InvImage.accessible
-- Accessibility transfers from `h`: directly for the input itself, and via
-- `inv_of_transGen` for anything strictly below it.
cases x.2 <;> rename_i heq
· rwa [heq]
· exact h.inv_of_transGen heq
/--
Under accessibility of the input triple, `partialExtrinsicFix₃` satisfies its
defining one-step unfolding equation.
-/
public theorem partialExtrinsicFix₃_eq_apply_of_acc [ a b c, Nonempty (C₃ a b c)]
{R : (a : α) ×' (b : β a) ×' γ a b (a : α) ×' (b : β a) ×' γ a b Prop}
{F : (a b c), ( (a' b' c'), R a', b', c' a, b, c C₃ a' b' c') C₃ a b c}
{a : α} {b : β a} {c : γ a b} (wf : Acc R a, b, c) :
partialExtrinsicFix₃ R F a b c = F a b c (fun a b c _ => partialExtrinsicFix₃ R F a b c) := by
-- Reduce to the unary combinator, unfold it once, then translate back.
rw [partialExtrinsicFix₃_eq_partialExtrinsicFix wf, partialExtrinsicFix_eq_apply_of_acc wf]
congr 1; ext a' b' c' hR
rw [partialExtrinsicFix₃_eq_partialExtrinsicFix (wf.inv hR)]
/--
If `R` is well-founded, the one-step unfolding equation for `partialExtrinsicFix₃`
holds at every input (accessibility comes for free from `wf`).
-/
public theorem partialExtrinsicFix₃_eq_apply [ a b c, Nonempty (C₃ a b c)]
{R : (a : α) ×' (b : β a) ×' γ a b (a : α) ×' (b : β a) ×' γ a b Prop}
{F : (a b c), ( (a' b' c'), R a', b', c' a, b, c C₃ a' b' c') C₃ a b c}
{a : α} {b : β a} {c : γ a b} (wf : WellFounded R) :
partialExtrinsicFix₃ R F a b c = F a b c (fun a b c _ => partialExtrinsicFix₃ R F a b c) :=
partialExtrinsicFix₃_eq_apply_of_acc (wf.apply _)
/--
If `R` is well-founded, `partialExtrinsicFix₃` agrees with the well-founded
fixpoint `wf.fix` over the `PSigma`-bundled argument triple.
-/
public theorem partialExtrinsicFix₃_eq_fix [ a b c, Nonempty (C₃ a b c)]
{R : (a : α) ×' (b : β a) ×' γ a b (a : α) ×' (b : β a) ×' γ a b Prop}
{F : a b c, ( a' b' c', R a', b', c' a, b, c C₃ a' b' c') C₃ a b c}
(wf : WellFounded R) {a b c} :
partialExtrinsicFix₃ R F a b c = wf.fix (fun x G => F x.1 x.2.1 x.2.2 (fun a b c h => G a, b, c h)) a, b, c := by
rw [partialExtrinsicFix₃_eq_partialExtrinsicFix (wf.apply _), partialExtrinsicFix_eq_fix wf]
end WellFounded

View File

@@ -10,18 +10,14 @@ public import Lean.Compiler.IR.AddExtern
public import Lean.Compiler.IR.Basic
public import Lean.Compiler.IR.Format
public import Lean.Compiler.IR.CompilerM
public import Lean.Compiler.IR.PushProj
public import Lean.Compiler.IR.NormIds
public import Lean.Compiler.IR.Checker
public import Lean.Compiler.IR.ExpandResetReuse
public import Lean.Compiler.IR.UnboxResult
public import Lean.Compiler.IR.EmitC
public import Lean.Compiler.IR.Sorry
public import Lean.Compiler.IR.ToIR
public import Lean.Compiler.IR.ToIRType
public import Lean.Compiler.IR.Meta
public import Lean.Compiler.IR.SimpleGroundExpr
public import Lean.Compiler.IR.ElimDeadVars
-- The following imports are not required by the compiler. They are here to ensure that there
-- are no orphaned modules.
@@ -36,15 +32,9 @@ def compile (decls : Array Decl) : CompilerM (Array Decl) := do
logDecls `init decls
checkDecls decls
let mut decls := decls
if Compiler.LCNF.compiler.reuse.get ( getOptions) then
decls := decls.map Decl.expandResetReuse
logDecls `expand_reset_reuse decls
decls := decls.map Decl.pushProj
logDecls `push_proj decls
decls updateSorryDep decls
logDecls `result decls
checkDecls decls
decls.forM Decl.detectSimpleGround
addDecls decls
inferMeta decls
return decls

View File

@@ -1,72 +0,0 @@
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Compiler.IR.FreeVars
public section
namespace Lean.IR
/--
This function implements a simple heuristic for let values that we know may be dropped because they
are pure.
-/
private def safeToElim (e : Expr) : Bool :=
match e with
| .ctor .. | .reset .. | .reuse .. | .proj .. | .uproj .. | .sproj .. | .box .. | .unbox ..
| .lit .. | .isShared .. | .pap .. => true
-- 0-ary full applications are considered constants
| .fap _ args => args.isEmpty
| .ap .. => false
partial def reshapeWithoutDead (bs : Array FnBody) (term : FnBody) : FnBody :=
let rec reshape (bs : Array FnBody) (b : FnBody) (used : IndexSet) :=
if bs.isEmpty then b
else
let curr := bs.back!
let bs := bs.pop
let keep (_ : Unit) :=
let used := curr.collectFreeIndices used
let b := curr.setBody b
reshape bs b used
let keepIfUsedJp (vidx : Index) :=
if used.contains vidx then
keep ()
else
reshape bs b used
let keepIfUsedLet (vidx : Index) (val : Expr) :=
if used.contains vidx || !safeToElim val then
keep ()
else
reshape bs b used
match curr with
| FnBody.vdecl x _ e _ => keepIfUsedLet x.idx e
-- TODO: we should keep all struct/union projections because they are used to ensure struct/union values are fully consumed.
| FnBody.jdecl j _ _ _ => keepIfUsedJp j.idx
| _ => keep ()
reshape bs term term.freeIndices
partial def FnBody.elimDead (b : FnBody) : FnBody :=
let (bs, term) := b.flatten
let bs := modifyJPs bs elimDead
let term := match term with
| FnBody.case tid x xType alts =>
let alts := alts.map fun alt => alt.modifyBody elimDead
FnBody.case tid x xType alts
| other => other
reshapeWithoutDead bs term
/-- Eliminate dead let-declarations and join points -/
def Decl.elimDead (d : Decl) : Decl :=
match d with
| .fdecl (body := b) .. => d.updateBody! b.elimDead
| other => other
builtin_initialize registerTraceClass `compiler.ir.elim_dead (inherited := true)
end Lean.IR

View File

@@ -13,7 +13,7 @@ public import Lean.Compiler.IR.SimpCase
public import Lean.Compiler.ModPkgExt
import Lean.Compiler.LCNF.Types
import Lean.Compiler.ClosedTermCache
import Lean.Compiler.IR.SimpleGroundExpr
import Lean.Compiler.LCNF.SimpleGroundExpr
import Init.Omega
import Init.While
import Init.Data.Range.Polymorphic.Iterators
@@ -22,7 +22,9 @@ import Lean.Runtime
public section
namespace Lean.IR.EmitC
open Lean.Compiler.LCNF (isBoxedName)
open Lean.Compiler.LCNF (isBoxedName isSimpleGroundDecl getSimpleGroundExpr
getSimpleGroundExprWithResolvedRefs uint64ToByteArrayLE SimpleGroundExpr SimpleGroundArg
addSimpleGroundDecl)
def leanMainFn := "_lean_main"
@@ -39,9 +41,9 @@ abbrev M := ReaderT Context (EStateM String String)
@[inline] def getModName : M Name := Context.modName <$> read
@[inline] def getModInitFn : M String := do
@[inline] def getModInitFn (phases : IRPhases) : M String := do
let pkg? := ( getEnv).getModulePackage?
return mkModuleInitializationFunctionName ( getModName) pkg?
return mkModuleInitializationFunctionName (phases := phases) ( getModName) pkg?
def getDecl (n : Name) : M Decl := do
let env getEnv
@@ -174,6 +176,23 @@ where
| .nameMkStr args =>
let obj groundNameMkStrToCLit args
mkValueCLit "lean_ctor_object" obj
| .array elems =>
let leanArrayTag := 246
let header := mkHeader s!"sizeof(lean_array_object) + sizeof(void*)*{elems.size}" 0 leanArrayTag
let elemLits elems.mapM groundArgToCLit
let dataArray := String.intercalate "," elemLits.toList
mkValueCLit
"lean_array_object"
s!"\{.m_header = {header}, .m_size = {elems.size}, .m_capacity = {elems.size}, .m_data = \{{dataArray}}}"
| .byteArray data =>
let leanScalarArrayTag := 248
let elemSize : Nat := 1
let header := mkHeader s!"sizeof(lean_sarray_object) + {data.size}" elemSize leanScalarArrayTag
let dataLits := data.map toString
let dataArray := String.intercalate "," dataLits.toList
mkValueCLit
"lean_sarray_object"
s!"\{.m_header = {header}, .m_size = {data.size}, .m_capacity = {data.size}, .m_data = \{{dataArray}}}"
| .reference refDecl => findValueDecl refDecl
mkValueName (name : String) : String :=
@@ -222,7 +241,7 @@ where
break
return mkValueName ( toCName decl)
compileCtor (cidx : Nat) (objArgs : Array SimpleGroundArg) (usizeArgs : Array USize)
compileCtor (cidx : Nat) (objArgs : Array SimpleGroundArg) (usizeArgs : Array UInt64)
(scalarArgs : Array UInt8) : GroundM String := do
let header := mkCtorHeader objArgs.size usizeArgs.size scalarArgs.size cidx
let objArgs objArgs.mapM groundArgToCLit
@@ -343,7 +362,7 @@ def emitMainFn : M Unit := do
/- We disable panic messages because they do not mesh well with extracted closed terms.
See issue #534. We can remove this workaround after we implement issue #467. -/
emitLn "lean_set_panic_messages(false);"
emitLn s!"res = {← getModInitFn}(1 /* builtin */);"
emitLn s!"res = {← getModInitFn (phases := if env.header.isModule then .runtime else .all)}(1 /* builtin */);"
emitLn "lean_set_panic_messages(true);"
emitLns ["lean_io_mark_end_initialization();",
"if (lean_io_result_is_ok(res)) {",
@@ -470,7 +489,7 @@ def emitDec (x : VarId) (n : Nat) (checkRef : Bool) : M Unit := do
emitLn ");"
def emitDel (x : VarId) : M Unit := do
emit "lean_free_object("; emit x; emitLn ");"
emit "lean_del_object("; emit x; emitLn ");"
def emitSetTag (x : VarId) (i : Nat) : M Unit := do
emit "lean_ctor_set_tag("; emit x; emit ", "; emit i; emitLn ");"
@@ -887,24 +906,21 @@ def emitMarkPersistent (d : Decl) (n : Name) : M Unit := do
emitCName n
emitLn ");"
def emitDeclInit (d : Decl) : M Unit := do
def withErrRet (emitIORes : M Unit) : M Unit := do
emit "res = "; emitIORes; emitLn ";"
emitLn "if (lean_io_result_is_error(res)) return res;"
def emitDeclInit (d : Decl) (isBuiltin : Bool) : M Unit := do
let env getEnv
let n := d.name
if isIOUnitInitFn env n then
if isIOUnitBuiltinInitFn env n then
emit "if (builtin) {"
emit "res = "; emitCName n; emitLn "();"
emitLn "if (lean_io_result_is_error(res)) return res;"
if (isBuiltin && isIOUnitBuiltinInitFn env n) || isIOUnitInitFn env n then
withErrRet do
emitCName n; emitLn "()"
emitLn "lean_dec_ref(res);"
if isIOUnitBuiltinInitFn env n then
emit "}"
else if d.params.size == 0 then
match getInitFnNameFor? env d.name with
| some initFn =>
if getBuiltinInitFnNameFor? env d.name |>.isSome then
emit "if (builtin) {"
emit "res = "; emitCName initFn; emitLn "();"
emitLn "if (lean_io_result_is_error(res)) return res;"
if let some initFn := (guard isBuiltin *> getBuiltinInitFnNameFor? env d.name) <|> getInitFnNameFor? env d.name then
withErrRet do
emitCName initFn; emitLn "()"
emitCName n
if d.resultType.isScalar then
emitLn (" = " ++ getUnboxOpName d.resultType ++ "(lean_io_result_get_value(res));")
@@ -912,41 +928,78 @@ def emitDeclInit (d : Decl) : M Unit := do
emitLn " = lean_io_result_get_value(res);"
emitMarkPersistent d n
emitLn "lean_dec_ref(res);"
if getBuiltinInitFnNameFor? env d.name |>.isSome then
emit "}"
| _ =>
if !isClosedTermName env d.name && !isSimpleGroundDecl env d.name then
emitCName n; emit " = "; emitCInitName n; emitLn "();"; emitMarkPersistent d n
else if !isClosedTermName env d.name && !isSimpleGroundDecl env d.name then
emitCName n; emit " = "; emitCInitName n; emitLn "();"; emitMarkPersistent d n
def emitInitFn : M Unit := do
def emitInitFn (phases : IRPhases) : M Unit := do
let env getEnv
let impInitFns env.imports.mapM fun imp => do
let impInitFns env.imports.filterMapM fun imp => do
if phases != .all && imp.isMeta != (phases == .comptime) then
return none
let some idx := env.getModuleIdx? imp.module
| throw "(internal) import without module index" -- should be unreachable
let pkg? := env.getModulePackageByIdx? idx
let fn := mkModuleInitializationFunctionName (phases := if phases == .all then .all else if imp.isMeta then .runtime else phases) imp.module pkg?
emitLn s!"lean_object* {fn}(uint8_t builtin);"
return some fn
let initialized := s!"_G_{mkModuleInitializationPrefix phases}initialized"
emitLns [
s!"static bool {initialized} = false;",
s!"LEAN_EXPORT lean_object* {← getModInitFn (phases := phases)}(uint8_t builtin) \{",
"lean_object * res;",
s!"if ({initialized}) return lean_io_result_mk_ok(lean_box(0));",
s!"{initialized} = true;"
]
impInitFns.forM fun fn => do
withErrRet do
emitLn s!"{fn}(builtin)"
emitLn "lean_dec_ref(res);"
let decls := getDecls env
for d in decls.reverse do
if phases == .all || (phases == .comptime) == isMarkedMeta env d.name then
emitDeclInit d (isBuiltin := phases != .comptime)
emitLns ["return lean_io_result_mk_ok(lean_box(0));", "}"]
/-- Init function used before phase split under module system, keep for compatibility. -/
def emitLegacyInitFn : M Unit := do
let env getEnv
let impInitFns env.imports.filterMapM fun imp => do
let some idx := env.getModuleIdx? imp.module
| throw "(internal) import without module index" -- should be unreachable
let pkg? := env.getModulePackageByIdx? idx
let fn := mkModuleInitializationFunctionName imp.module pkg?
emitLn s!"lean_object* {fn}(uint8_t builtin);"
return fn
return some fn
let initialized := s!"_G_initialized"
emitLns [
"static bool _G_initialized = false;",
s!"LEAN_EXPORT lean_object* {← getModInitFn}(uint8_t builtin) \{",
s!"static bool {initialized} = false;",
s!"LEAN_EXPORT lean_object* {← getModInitFn (phases := .all)}(uint8_t builtin) \{",
"lean_object * res;",
"if (_G_initialized) return lean_io_result_mk_ok(lean_box(0));",
"_G_initialized = true;"
s!"if ({initialized}) return lean_io_result_mk_ok(lean_box(0));",
s!"{initialized} = true;"
]
impInitFns.forM fun fn => emitLns [
s!"res = {fn}(builtin);",
"if (lean_io_result_is_error(res)) return res;",
"lean_dec_ref(res);"]
let decls := getDecls env
decls.reverse.forM emitDeclInit
emitLns ["return lean_io_result_mk_ok(lean_box(0));", "}"]
impInitFns.forM fun fn => do
withErrRet do
emitLn s!"{fn}(builtin)"
emitLn "lean_dec_ref(res);"
withErrRet do
emitLn s!"{← getModInitFn (phases := .runtime)}(builtin)"
emitLn "lean_dec_ref(res);"
withErrRet do
emitLn s!"{← getModInitFn (phases := .comptime)}(builtin)"
emitLn "lean_dec_ref(res);"
emitLns [s!"return {← getModInitFn (phases := .all)}(builtin);", "}"]
def main : M Unit := do
emitFileHeader
emitFnDecls
emitFns
emitInitFn
if ( getEnv).header.isModule then
emitInitFn (phases := .runtime)
emitInitFn (phases := .comptime)
emitLegacyInitFn
else
emitInitFn (phases := .all)
emitMainFnIfNeeded
emitFileFooter

View File

@@ -1081,7 +1081,7 @@ def emitSSet (builder : LLVM.Builder llvmctx) (x : VarId) (n : Nat) (offset : Na
def emitDel (builder : LLVM.Builder llvmctx) (x : VarId) : M llvmctx Unit := do
let argtys := #[ LLVM.voidPtrType llvmctx]
let retty LLVM.voidType llvmctx
let fn getOrCreateFunctionPrototype ( getLLVMModule) retty "lean_free_object" argtys
let fn getOrCreateFunctionPrototype ( getLLVMModule) retty "lean_del_object" argtys
let xv emitLhsVal builder x
let fnty LLVM.functionType retty argtys
let _ LLVM.buildCall2 builder fnty fn #[xv]

View File

@@ -21,7 +21,7 @@ def isTailCallTo (g : Name) (b : FnBody) : Bool :=
| _ => false
def usesModuleFrom (env : Environment) (modulePrefix : Name) : Bool :=
env.allImportedModuleNames.toList.any fun modName => modulePrefix.isPrefixOf modName
env.header.modules.any fun mod => mod.irPhases != .comptime && modulePrefix.isPrefixOf mod.module
namespace CollectUsedDecls

View File

@@ -1,288 +0,0 @@
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Compiler.IR.CompilerM
public import Lean.Compiler.IR.NormIds
public import Lean.Compiler.IR.FreeVars
import Init.Omega
public section
namespace Lean.IR.ExpandResetReuse
/-- Mapping from variable to projections -/
abbrev ProjMap := Std.HashMap VarId Expr
namespace CollectProjMap
abbrev Collector := ProjMap ProjMap
@[inline] def collectVDecl (x : VarId) (v : Expr) : Collector := fun m =>
match v with
| .proj .. => m.insert x v
| .sproj .. => m.insert x v
| .uproj .. => m.insert x v
| _ => m
partial def collectFnBody : FnBody Collector
| .vdecl x _ v b => collectVDecl x v collectFnBody b
| .jdecl _ _ v b => collectFnBody v collectFnBody b
| .case _ _ _ alts => fun s => alts.foldl (fun s alt => collectFnBody alt.body s) s
| e => if e.isTerminal then id else collectFnBody e.body
end CollectProjMap
/-- Create a mapping from variables to projections.
This function assumes variable ids have been normalized -/
def mkProjMap (d : Decl) : ProjMap :=
match d with
| .fdecl (body := b) .. => CollectProjMap.collectFnBody b {}
| _ => {}
structure Context where
projMap : ProjMap
/-- Return true iff `x` is consumed in all branches of the current block.
Here consumption means the block contains a `dec x` or `reuse x ...`. -/
partial def consumed (x : VarId) : FnBody Bool
| .vdecl _ _ v b =>
match v with
| Expr.reuse y _ _ _ => x == y || consumed x b
| _ => consumed x b
| .dec y _ _ _ b => x == y || consumed x b
| .case _ _ _ alts => alts.all fun alt => consumed x alt.body
| e => !e.isTerminal && consumed x e.body
abbrev Mask := Array (Option VarId)
/-- Auxiliary function for eraseProjIncFor -/
partial def eraseProjIncForAux (y : VarId) (bs : Array FnBody) (mask : Mask) (keep : Array FnBody) : Array FnBody × Mask :=
let done (_ : Unit) := (bs ++ keep.reverse, mask)
let keepInstr (b : FnBody) := eraseProjIncForAux y bs.pop mask (keep.push b)
if h : bs.size < 2 then done ()
else
let b := bs.back!
match b with
| .vdecl _ _ (.sproj _ _ _) _ => keepInstr b
| .vdecl _ _ (.uproj _ _) _ => keepInstr b
| .inc z n c p _ =>
if n == 0 then done () else
let b' := bs[bs.size - 2]
match b' with
| .vdecl w _ (.proj i x) _ =>
if w == z && y == x then
/- Found
```
let z := proj[i] y
inc z n c
```
We keep `proj`, and `inc` when `n > 1`
-/
let bs := bs.pop.pop
let mask := mask.set! i (some z)
let keep := keep.push b'
let keep := if n == 1 then keep else keep.push (FnBody.inc z (n-1) c p FnBody.nil)
eraseProjIncForAux y bs mask keep
else done ()
| _ => done ()
| _ => done ()
/-- Try to erase `inc` instructions on projections of `y` occurring in the tail of `bs`.
Return the updated `bs` and a bit mask specifying which `inc`s have been removed. -/
def eraseProjIncFor (n : Nat) (y : VarId) (bs : Array FnBody) : Array FnBody × Mask :=
eraseProjIncForAux y bs (.replicate n none) #[]
/-- Replace `reuse x ctor ...` with `ctor ...`, and remove `dec x` -/
partial def reuseToCtor (x : VarId) : FnBody FnBody
| FnBody.dec y n c p b =>
if x == y then b -- n must be 1 since `x := reset ...`
else FnBody.dec y n c p (reuseToCtor x b)
| FnBody.vdecl z t v b =>
match v with
| Expr.reuse y c _ xs =>
if x == y then FnBody.vdecl z t (Expr.ctor c xs) b
else FnBody.vdecl z t v (reuseToCtor x b)
| _ =>
FnBody.vdecl z t v (reuseToCtor x b)
| FnBody.case tid y yType alts =>
let alts := alts.map fun alt => alt.modifyBody (reuseToCtor x)
FnBody.case tid y yType alts
| e =>
if e.isTerminal then
e
else
let (instr, b) := e.split
let b := reuseToCtor x b
instr.setBody b
/--
replace
```
x := reset y; b
```
with
```
inc z_1; ...; inc z_i; dec y; b'
```
where `z_i`'s are the variables in `mask`,
and `b'` is `b` where we removed `dec x` and replaced `reuse x ctor_i ...` with `ctor_i ...`.
-/
def mkSlowPath (x y : VarId) (mask : Mask) (b : FnBody) : FnBody :=
let b := reuseToCtor x b
let b := FnBody.dec y 1 true false b
mask.foldl (init := b) fun b m => match m with
| some z => FnBody.inc z 1 true false b
| none => b
abbrev M := ReaderT Context (StateM Nat)
def mkFresh : M VarId :=
modifyGet fun n => ({ idx := n }, n + 1)
def releaseUnreadFields (y : VarId) (mask : Mask) (b : FnBody) : M FnBody :=
mask.size.foldM (init := b) fun i _ b =>
match mask[i] with
| some _ => pure b -- code took ownership of this field
| none => do
let fld mkFresh
pure (FnBody.vdecl fld .tobject (Expr.proj i y) (FnBody.dec fld 1 true false b))
def setFields (y : VarId) (zs : Array Arg) (b : FnBody) : FnBody :=
zs.size.fold (init := b) fun i _ b => FnBody.set y i zs[i] b
/-- Given `set x[i] := y`, return true iff `y := proj[i] x` -/
def isSelfSet (ctx : Context) (x : VarId) (i : Nat) (y : Arg) : Bool :=
match y with
| .var y =>
match ctx.projMap[y]? with
| some (Expr.proj j w) => j == i && w == x
| _ => false
| .erased => false
/-- Given `uset x[i] := y`, return true iff `y := uproj[i] x` -/
def isSelfUSet (ctx : Context) (x : VarId) (i : Nat) (y : VarId) : Bool :=
match ctx.projMap[y]? with
| some (Expr.uproj j w) => j == i && w == x
| _ => false
/-- Given `sset x[n, i] := y`, return true iff `y := sproj[n, i] x` -/
def isSelfSSet (ctx : Context) (x : VarId) (n : Nat) (i : Nat) (y : VarId) : Bool :=
match ctx.projMap[y]? with
| some (Expr.sproj m j w) => n == m && j == i && w == x
| _ => false
/-- Remove unnecessary `set/uset/sset` operations -/
partial def removeSelfSet (ctx : Context) : FnBody FnBody
| FnBody.set x i y b =>
if isSelfSet ctx x i y then removeSelfSet ctx b
else FnBody.set x i y (removeSelfSet ctx b)
| FnBody.uset x i y b =>
if isSelfUSet ctx x i y then removeSelfSet ctx b
else FnBody.uset x i y (removeSelfSet ctx b)
| FnBody.sset x n i y t b =>
if isSelfSSet ctx x n i y then removeSelfSet ctx b
else FnBody.sset x n i y t (removeSelfSet ctx b)
| FnBody.case tid y yType alts =>
let alts := alts.map fun alt => alt.modifyBody (removeSelfSet ctx)
FnBody.case tid y yType alts
| e =>
if e.isTerminal then e
else
let (instr, b) := e.split
let b := removeSelfSet ctx b
instr.setBody b
partial def reuseToSet (ctx : Context) (x y : VarId) : FnBody FnBody
| FnBody.dec z n c p b =>
if x == z then FnBody.del y b
else FnBody.dec z n c p (reuseToSet ctx x y b)
| FnBody.vdecl z t v b =>
match v with
| Expr.reuse w c u zs =>
if x == w then
let b := setFields y zs (b.replaceVar z y)
let b := if u then FnBody.setTag y c.cidx b else b
removeSelfSet ctx b
else FnBody.vdecl z t v (reuseToSet ctx x y b)
| _ => FnBody.vdecl z t v (reuseToSet ctx x y b)
| FnBody.case tid z zType alts =>
let alts := alts.map fun alt => alt.modifyBody (reuseToSet ctx x y)
FnBody.case tid z zType alts
| e =>
if e.isTerminal then e
else
let (instr, b) := e.split
let b := reuseToSet ctx x y b
instr.setBody b
/--
replace
```
x := reset y; b
```
with
```
let f_i_1 := proj[i_1] y;
...
let f_i_k := proj[i_k] y;
b'
```
where `i_j`s are the field indexes
that the code did not touch immediately before the reset.
That is `mask[j] == none`.
`b'` is `b` where `y` `dec x` is replaced with `del y`,
and `z := reuse x ctor_i ws; F` is replaced with
`set x i ws[i]` operations, and we replace `z` with `x` in `F`
-/
def mkFastPath (x y : VarId) (mask : Mask) (b : FnBody) : M FnBody := do
let ctx read
let b := reuseToSet ctx x y b
releaseUnreadFields y mask b
-- Expand `bs; x := reset[n] y; b`
partial def expand (mainFn : FnBody Array FnBody M FnBody)
(bs : Array FnBody) (x : VarId) (n : Nat) (y : VarId) (b : FnBody) : M FnBody := do
let (bs, mask) := eraseProjIncFor n y bs
/- Remark: we may be duplicating variable/JP indices. That is, `bSlow` and `bFast` may
have duplicate indices. We run `normalizeIds` to fix the ids after we have expand them. -/
let bSlow := mkSlowPath x y mask b
let bFast mkFastPath x y mask b
/- We only optimize recursively the fast. -/
let bFast mainFn bFast #[]
let c mkFresh
let b := FnBody.vdecl c IRType.uint8 (Expr.isShared y) (mkIf c bSlow bFast)
return reshape bs b
partial def searchAndExpand : FnBody Array FnBody M FnBody
| d@(FnBody.vdecl x _ (Expr.reset n y) b), bs =>
if consumed x b then do
expand searchAndExpand bs x n y b
else
searchAndExpand b (push bs d)
| FnBody.jdecl j xs v b, bs => do
let v searchAndExpand v #[]
searchAndExpand b (push bs (FnBody.jdecl j xs v FnBody.nil))
| FnBody.case tid x xType alts, bs => do
let alts alts.mapM fun alt => alt.modifyBodyM fun b => searchAndExpand b #[]
return reshape bs (FnBody.case tid x xType alts)
| b, bs =>
if b.isTerminal then return reshape bs b
else searchAndExpand b.body (push bs b)
def main (d : Decl) : Decl :=
match d with
| .fdecl (body := b) .. =>
let m := mkProjMap d
let nextIdx := d.maxIndex + 1
let bNew := (searchAndExpand b #[] { projMap := m }).run' nextIdx
d.updateBody! bNew
| d => d
end ExpandResetReuse
/-- (Try to) expand `reset` and `reuse` instructions. -/
def Decl.expandResetReuse (d : Decl) : Decl :=
(ExpandResetReuse.main d).normalizeIds
builtin_initialize registerTraceClass `compiler.ir.expand_reset_reuse (inherited := true)
end Lean.IR

View File

@@ -1,245 +0,0 @@
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Compiler.IR.Basic
public section
namespace Lean.IR
namespace MaxIndex
/-! Compute the maximum index `M` used in a declaration.
We `M` to initialize the fresh index generator used to create fresh
variable and join point names.
Recall that we variable and join points share the same namespace in
our implementation.
-/
structure State where
currentMax : Nat := 0
abbrev M := StateM State
private def visitIndex (x : Index) : M Unit := do
modify fun s => { s with currentMax := s.currentMax.max x }
private def visitVar (x : VarId) : M Unit :=
visitIndex x.idx
private def visitJP (j : JoinPointId) : M Unit :=
visitIndex j.idx
private def visitArg (arg : Arg) : M Unit :=
match arg with
| .var x => visitVar x
| .erased => pure ()
private def visitParam (p : Param) : M Unit :=
visitVar p.x
private def visitExpr (e : Expr) : M Unit := do
match e with
| .proj _ x | .uproj _ x | .sproj _ _ x | .box _ x | .unbox x | .reset _ x | .isShared x =>
visitVar x
| .ctor _ ys | .fap _ ys | .pap _ ys =>
ys.forM visitArg
| .ap x ys | .reuse x _ _ ys =>
visitVar x
ys.forM visitArg
| .lit _ => pure ()
partial def visitFnBody (fnBody : FnBody) : M Unit := do
match fnBody with
| .vdecl x _ v b =>
visitVar x
visitExpr v
visitFnBody b
| .jdecl j ys v b =>
visitJP j
visitFnBody v
ys.forM visitParam
visitFnBody b
| .set x _ y b =>
visitVar x
visitArg y
visitFnBody b
| .uset x _ y b | .sset x _ _ y _ b =>
visitVar x
visitVar y
visitFnBody b
| .setTag x _ b | .inc x _ _ _ b | .dec x _ _ _ b | .del x b =>
visitVar x
visitFnBody b
| .case _ x _ alts =>
visitVar x
alts.forM (visitFnBody ·.body)
| .jmp j ys =>
visitJP j
ys.forM visitArg
| .ret x =>
visitArg x
| .unreachable => pure ()
private def visitDecl (decl : Decl) : M Unit := do
match decl with
| .fdecl (xs := xs) (body := b) .. =>
xs.forM visitParam
visitFnBody b
| .extern (xs := xs) .. =>
xs.forM visitParam
end MaxIndex
def FnBody.maxIndex (b : FnBody) : Index := Id.run do
let _, { currentMax } := MaxIndex.visitFnBody b |>.run {}
return currentMax
def Decl.maxIndex (d : Decl) : Index := Id.run do
let _, { currentMax } := MaxIndex.visitDecl d |>.run {}
return currentMax
namespace FreeIndices
/-! We say a variable (join point) index (aka name) is free in a function body
if there isn't a `FnBody.vdecl` (`Fnbody.jdecl`) binding it. -/
structure State where
freeIndices : IndexSet := {}
abbrev M := StateM State
private def visitIndex (x : Index) : M Unit := do
modify fun s => { s with freeIndices := s.freeIndices.insert x }
private def visitVar (x : VarId) : M Unit :=
visitIndex x.idx
private def visitJP (j : JoinPointId) : M Unit :=
visitIndex j.idx
/-- Visit an argument: a variable argument contributes its index,
an erased argument contributes nothing. -/
private def visitArg : Arg → M Unit
  | .var x => visitVar x
  | .erased => pure ()
/-- Visit a parameter: the index of the variable it binds is recorded. -/
private def visitParam (p : Param) : M Unit :=
  visitIndex p.x.idx
/-- Collect every variable index occurring in the IR expression `e`. -/
private def visitExpr (e : Expr) : M Unit := do
  match e with
  | .proj _ x | .uproj _ x | .sproj _ _ x | .box _ x | .unbox x | .reset _ x | .isShared x =>
    -- Unary operations on a single variable.
    visitVar x
  | .ctor _ ys | .fap _ ys | .pap _ ys =>
    -- Applications whose head is a constant: only the arguments carry indices.
    ys.forM visitArg
  | .ap x ys | .reuse x _ _ ys =>
    -- Both the head variable and the arguments are visited.
    visitVar x
    ys.forM visitArg
  | .lit _ => pure ()
/-- Collect the indices occurring in `fnBody` into the state.
NOTE(review): binder indices are inserted as well (e.g. `x` in `.vdecl`
and `j` in `.jdecl` below), so the result over-approximates the truly
*free* indices described in the namespace docstring — confirm whether
callers rely on binders being included. -/
partial def visitFnBody (fnBody : FnBody) : M Unit := do
  match fnBody with
  | .vdecl x _ v b =>
    visitVar x
    visitExpr v
    visitFnBody b
  | .jdecl j ys v b =>
    visitJP j
    visitFnBody v
    ys.forM visitParam
    visitFnBody b
  | .set x _ y b =>
    visitVar x
    visitArg y
    visitFnBody b
  | .uset x _ y b | .sset x _ _ y _ b =>
    visitVar x
    visitVar y
    visitFnBody b
  | .setTag x _ b | .inc x _ _ _ b | .dec x _ _ _ b | .del x b =>
    -- Instructions operating on a single variable plus a continuation.
    visitVar x
    visitFnBody b
  | .case _ x _ alts =>
    visitVar x
    alts.forM (visitFnBody ·.body)
  | .jmp j ys =>
    visitJP j
    ys.forM visitArg
  | .ret x =>
    visitArg x
  | .unreachable => pure ()
/-- Collect the indices of a declaration: its parameters, and for
`fdecl` additionally its body (`extern` declarations have no body). -/
private def visitDecl (decl : Decl) : M Unit := do
  match decl with
  | .fdecl (xs := xs) (body := b) .. =>
    xs.forM visitParam
    visitFnBody b
  | .extern (xs := xs) .. =>
    xs.forM visitParam
end FreeIndices
/-- Collect the indices occurring in `b`, accumulating on top of `init`. -/
def FnBody.collectFreeIndices (b : FnBody) (init : IndexSet) : IndexSet := Id.run do
  -- Destructure the `(value, state)` pair returned by `StateM.run`.
  let (_, s) := FreeIndices.visitFnBody b |>.run { freeIndices := init }
  return s.freeIndices
/-- The set of indices occurring in `b`; see `FnBody.collectFreeIndices`. -/
def FnBody.freeIndices (b : FnBody) : IndexSet :=
  b.collectFreeIndices {}
namespace HasIndex
/-! In principle, we can check whether a function body `b` contains an index `i` using
`b.freeIndices.contains i`, but it is more efficient to avoid constructing
the set of free indices and instead search directly for an occurrence of `i` in `b`.
-/
/-- Return `true` if the variable `x` has index `w`. -/
def visitVar (w : Index) (x : VarId) : Bool := w == x.idx
/-- Return `true` if the join point `x` has index `w`. -/
def visitJP (w : Index) (x : JoinPointId) : Bool := w == x.idx
/-- Return `true` if the argument is a variable with index `w`. -/
def visitArg (w : Index) : Arg → Bool
  | .var x => visitVar w x
  | .erased => false
/-- Return `true` if any argument in `xs` is a variable with index `w`. -/
def visitArgs (w : Index) (xs : Array Arg) : Bool :=
  xs.any (visitArg w)
/-- Return `true` if any parameter in `ps` binds a variable with index `w`. -/
def visitParams (w : Index) (ps : Array Param) : Bool :=
  ps.any (fun p => w == p.x.idx)
/-- Return `true` if index `w` occurs in the IR expression. -/
def visitExpr (w : Index) : Expr → Bool
  | .proj _ x | .uproj _ x | .sproj _ _ x | .box _ x | .unbox x | .reset _ x | .isShared x =>
    visitVar w x
  | .ctor _ ys | .fap _ ys | .pap _ ys =>
    visitArgs w ys
  | .ap x ys | .reuse x _ _ ys =>
    visitVar w x || visitArgs w ys
  | .lit _ => false
/-- Return `true` if index `w` occurs in the function body.
Binders in `.vdecl`/`.jdecl` are ignored, so only uses of `w` count. -/
partial def visitFnBody (w : Index) : FnBody → Bool
  | .vdecl _ _ v b =>
    visitExpr w v || visitFnBody w b
  | .jdecl _ _ v b =>
    visitFnBody w v || visitFnBody w b
  | .set x _ y b =>
    visitVar w x || visitArg w y || visitFnBody w b
  | .uset x _ y b | .sset x _ _ y _ b =>
    visitVar w x || visitVar w y || visitFnBody w b
  | .setTag x _ b | .inc x _ _ _ b | .dec x _ _ _ b | .del x b =>
    visitVar w x || visitFnBody w b
  | .case _ x _ alts =>
    visitVar w x || alts.any (fun alt => visitFnBody w alt.body)
  | .jmp j ys =>
    visitJP w j || visitArgs w ys
  | .ret x =>
    visitArg w x
  | .unreachable => false
end HasIndex
/-- Return `true` if the variable `x` occurs in `arg`. -/
def Arg.hasFreeVar (arg : Arg) (x : VarId) : Bool := HasIndex.visitArg x.idx arg
/-- Return `true` if the variable `x` occurs in `e`. -/
def Expr.hasFreeVar (e : Expr) (x : VarId) : Bool := HasIndex.visitExpr x.idx e
/-- Return `true` if the variable `x` occurs (free) in `b`. -/
def FnBody.hasFreeVar (b : FnBody) (x : VarId) : Bool := HasIndex.visitFnBody x.idx b
end Lean.IR

View File

@@ -55,22 +55,8 @@ errors from the interpreter itself as those depend on whether we are running in
-/
@[export lean_eval_check_meta]
private partial def evalCheckMeta (env : Environment) (declName : Name) : Except String Unit := do
if !env.header.isModule then
return
go declName |>.run' {}
where go (ref : Name) : StateT NameSet (Except String) Unit := do
if ( get).contains ref then
return
modify (·.insert ref)
if let some localDecl := declMapExt.getState env |>.find? ref then
for ref in collectUsedFDecls localDecl do
go ref
else
-- NOTE: We do not use `getIRPhases` here as it's intended for env decls, nor IR decls. We do
-- not set `includeServer` as we want this check to be independent of server mode. Server-only
-- users disable this check instead.
if findEnvDecl env ref |>.isNone then
throw s!"Cannot evaluate constant `{declName}` as it uses `{ref}` which is neither marked nor imported as `meta`"
if getIRPhases env declName == .runtime then
throw s!"Cannot evaluate constant `{declName}` as it is neither marked nor imported as `meta`"
builtin_initialize
registerTraceClass `compiler.ir.inferMeta

View File

@@ -1,62 +0,0 @@
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Compiler.IR.FreeVars
public import Lean.Compiler.IR.NormIds
public section
namespace Lean.IR
partial def pushProjs (bs : Array FnBody) (alts : Array Alt) (altsF : Array IndexSet) (ctx : Array FnBody) (ctxF : IndexSet) : Array FnBody × Array Alt :=
if bs.isEmpty then (ctx.reverse, alts)
else
let b := bs.back!
let bs := bs.pop
let done (_ : Unit) := (bs.push b ++ ctx.reverse, alts)
let skip (_ : Unit) := pushProjs bs alts altsF (ctx.push b) (b.collectFreeIndices ctxF)
let push (x : VarId) :=
if !ctxF.contains x.idx then
let alts := alts.mapIdx fun i alt => alt.modifyBody fun b' =>
if altsF[i]!.contains x.idx then b.setBody b'
else b'
let altsF := altsF.map fun s => if s.contains x.idx then b.collectFreeIndices s else s
pushProjs bs alts altsF ctx ctxF
else
skip ()
match b with
| FnBody.vdecl x _ v _ =>
match v with
| Expr.proj _ _ => push x
| Expr.uproj _ _ => push x
| Expr.sproj _ _ _ => push x
| Expr.isShared _ => skip ()
| _ => done ()
| _ => done ()
partial def FnBody.pushProj (b : FnBody) : FnBody :=
let (bs, term) := b.flatten
let bs := modifyJPs bs pushProj
match term with
| .case tid x xType alts =>
let altsF := alts.map fun alt => alt.body.freeIndices
let (bs, alts) := pushProjs bs alts altsF #[] (mkIndexSet x.idx)
let alts := alts.map fun alt => alt.modifyBody pushProj
let term := FnBody.case tid x xType alts
reshape bs term
| _ => reshape bs term
/-- Push projections inside `case` branches. -/
def Decl.pushProj (d : Decl) : Decl :=
match d with
| .fdecl (body := b) .. => d.updateBody! b.pushProj |>.normalizeIds
| other => other
builtin_initialize registerTraceClass `compiler.ir.push_proj (inherited := true)
end Lean.IR

View File

@@ -101,6 +101,10 @@ partial def lowerCode (c : LCNF.Code .impure) : M FnBody := do
let ret getFVarValue fvarId
return .ret ret
| .unreach .. => return .unreachable
| .oset fvarId i y k _ =>
let y lowerArg y
let .var fvarId getFVarValue fvarId | unreachable!
return .set fvarId i y ( lowerCode k)
| .sset fvarId i offset y type k _ =>
let .var y getFVarValue y | unreachable!
let .var fvarId getFVarValue fvarId | unreachable!
@@ -109,12 +113,18 @@ partial def lowerCode (c : LCNF.Code .impure) : M FnBody := do
let .var y getFVarValue y | unreachable!
let .var fvarId getFVarValue fvarId | unreachable!
return .uset fvarId i y ( lowerCode k)
| .setTag fvarId cidx k _ =>
let .var var getFVarValue fvarId | unreachable!
return .setTag var cidx ( lowerCode k)
| .inc fvarId n check persistent k _ =>
let .var var getFVarValue fvarId | unreachable!
return .inc var n check persistent ( lowerCode k)
| .dec fvarId n check persistent k _ =>
let .var var getFVarValue fvarId | unreachable!
return .dec var n check persistent ( lowerCode k)
| .del fvarId k _ =>
let .var var getFVarValue fvarId | unreachable!
return .del var ( lowerCode k)
| .fun .. => panic! "all local functions should be λ-lifted"
partial def lowerLet (decl : LCNF.LetDecl .impure) (k : LCNF.Code .impure) : M FnBody := do
@@ -155,6 +165,9 @@ partial def lowerLet (decl : LCNF.LetDecl .impure) (k : LCNF.Code .impure) : M F
| .unbox var =>
withGetFVarValue var fun var => do
continueLet (.unbox var)
| .isShared var =>
withGetFVarValue var fun var => do
continueLet (.isShared var)
| .erased => mkErased ()
where
mkErased (_ : Unit) : M FnBody := do

View File

@@ -37,8 +37,8 @@ Run the initializer of the given module (without `builtin_initialize` commands).
Return `false` if the initializer is not available as native code.
Initializers do not have corresponding Lean definitions, so they cannot be interpreted in this case.
-/
@[inline] private unsafe def runModInit (mod : Name) (pkg? : Option String) : IO Bool :=
runModInitCore (mkModuleInitializationFunctionName mod pkg?)
@[inline] private unsafe def runModInit (mod : Name) (pkg? : Option String) (phases : IRPhases) : IO Bool :=
runModInitCore (mkModuleInitializationFunctionName mod pkg? phases)
/-- Run the initializer for `decl` and store its value for global access. Should only be used while importing. -/
@[extern "lean_run_init"]
@@ -160,36 +160,46 @@ def declareBuiltin (forDecl : Name) (value : Expr) : CoreM Unit :=
@[export lean_run_init_attrs]
private unsafe def runInitAttrs (env : Environment) (opts : Options) : IO Unit := do
if ( isInitializerExecutionEnabled) then
-- **Note**: `ModuleIdx` is not an abbreviation, and we don't have instances for it.
-- Thus, we use `(modIdx : Nat)`
for mod in env.header.moduleNames, (modIdx : Nat) in 0...* do
-- any native Lean code reachable by the interpreter (i.e. from shared
-- libraries with their corresponding module in the Environment) must
-- first be initialized
let pkg? := env.getModulePackageByIdx? modIdx
if ( runModInit mod pkg?) then
if !( isInitializerExecutionEnabled) then
throw <| IO.userError "`enableInitializerExecution` must be run before calling `importModules (loadExts := true)`"
-- **Note**: `ModuleIdx` is not an abbreviation, and we don't have instances for it.
-- Thus, we use `(modIdx : Nat)`
for mod in env.header.modules, (modIdx : Nat) in 0...* do
let initRuntime := Elab.inServer.get opts || mod.irPhases != .runtime
-- any native Lean code reachable by the interpreter (i.e. from shared
-- libraries with their corresponding module in the Environment) must
-- first be initialized
let pkg? := env.getModulePackageByIdx? modIdx
if env.header.isModule && /- TODO: remove after reboostrap -/ false then
let initializedRuntime pure initRuntime <&&> runModInit (phases := .runtime) mod.module pkg?
let initializedComptime runModInit (phases := .comptime) mod.module pkg?
if initializedRuntime || initializedComptime then
continue
-- As `[init]` decls can have global side effects, ensure we run them at most once,
-- just like the compiled code does.
if ( interpretedModInits.get).contains mod then
else
if ( runModInit (phases := .all) mod.module pkg?) then
continue
interpretedModInits.modify (·.insert mod)
let modEntries := regularInitAttr.ext.getModuleEntries env modIdx
-- `getModuleIREntries` is identical to `getModuleEntries` if we loaded only one of
-- .olean (from `meta initialize`)/.ir (`initialize` via transitive `meta import`)
-- so deduplicate (these lists should be very short).
-- If we have both, we should not need to worry about their relative ordering as `meta` and
-- non-`meta` initialize should not have interdependencies.
let modEntries := modEntries ++ (regularInitAttr.ext.getModuleIREntries env modIdx).filter (!modEntries.contains ·)
for (decl, initDecl) in modEntries do
-- Skip initializers we do not have IR for; they should not be reachable by interpretation.
if !Elab.inServer.get opts && getIRPhases env decl == .runtime then
continue
if initDecl.isAnonymous then
let initFn IO.ofExcept <| env.evalConst (IO Unit) opts decl
initFn
else
runInit env opts decl initDecl
-- As `[init]` decls can have global side effects, ensure we run them at most once,
-- just like the compiled code does.
if ( interpretedModInits.get).contains mod.module then
continue
interpretedModInits.modify (·.insert mod.module)
let modEntries := regularInitAttr.ext.getModuleEntries env modIdx
-- `getModuleIREntries` is identical to `getModuleEntries` if we loaded only one of
-- .olean (from `meta initialize`)/.ir (`initialize` via transitive `meta import`)
-- so deduplicate (these lists should be very short).
-- If we have both, we should not need to worry about their relative ordering as `meta` and
-- non-`meta` initialize should not have interdependencies.
let modEntries := modEntries ++ (regularInitAttr.ext.getModuleIREntries env modIdx).filter (!modEntries.contains ·)
for (decl, initDecl) in modEntries do
if !initRuntime && getIRPhases env decl == .runtime then
continue
if initDecl.isAnonymous then
-- Don't check `meta` again as it would not respect `Elab.inServer`
let initFn IO.ofExcept <| env.evalConst (checkMeta := false) (IO Unit) opts decl
initFn
else
runInit env opts decl initDecl
end Lean

View File

@@ -75,6 +75,7 @@ def eqvLetValue (e₁ e₂ : LetValue pu) : EqvM Bool := do
pure (i₁ == i₂ && u₁ == u₂) <&&> eqvFVar v₁ v₂ <&&> eqvArgs as₁ as₂
| .box ty₁ v₁ _, .box ty₂ v₂ _ => eqvType ty₁ ty₂ <&&> eqvFVar v₁ v₂
| .unbox v₁ _, .unbox v₂ _ => eqvFVar v₁ v₂
| .isShared v₁ _, .isShared v₂ _ => eqvFVar v₁ v₂
| _, _ => return false
@[inline] def withFVar (fvarId₁ fvarId₂ : FVarId) (x : EqvM α) : EqvM α :=
@@ -143,6 +144,11 @@ partial def eqv (code₁ code₂ : Code pu) : EqvM Bool := do
eqvFVar c₁.discr c₂.discr <&&>
eqvType c₁.resultType c₂.resultType <&&>
eqvAlts c₁.alts c₂.alts
| .oset fvarId₁ i₁ y₁ k₁ _, .oset fvarId₂ i₂ y₂ k₂ _ =>
pure (i₁ == i₂) <&&>
eqvFVar fvarId₁ fvarId₂ <&&>
eqvArg y₁ y₂ <&&>
eqv k₁ k₂
| .sset fvarId₁ i₁ offset₁ y₁ ty₁ k₁ _, .sset fvarId₂ i₂ offset₂ y₂ ty₂ k₂ _ =>
pure (i₁ == i₂) <&&>
pure (offset₁ == offset₂) <&&>
@@ -155,6 +161,10 @@ partial def eqv (code₁ code₂ : Code pu) : EqvM Bool := do
eqvFVar fvarId₁ fvarId₂ <&&>
eqvFVar y₁ y₂ <&&>
eqv k₁ k₂
| .setTag fvarId₁ c₁ k₁ _, .setTag fvarId₂ c₂ k₂ _ =>
pure (c₁ == c₂) <&&>
eqvFVar fvarId₁ fvarId₂ <&&>
eqv k₁ k₂
| .inc fvarId₁ n₁ c₁ p₁ k₁ _, .inc fvarId₂ n₂ c₂ p₂ k₂ _ =>
pure (n₁ == n₂) <&&>
pure (c₁ == c₂) <&&>
@@ -167,6 +177,9 @@ partial def eqv (code₁ code₂ : Code pu) : EqvM Bool := do
pure (p₁ == p₂) <&&>
eqvFVar fvarId₁ fvarId₂ <&&>
eqv k₁ k₂
| .del fvarId₁ k₁ _, .del fvarId₂ k₂ _ =>
eqvFVar fvarId₁ fvarId₂ <&&>
eqv k₁ k₂
| _, _ => return false
end

View File

@@ -219,6 +219,10 @@ inductive LetValue (pu : Purity) where
| box (ty : Expr) (fvarId : FVarId) (h : pu = .impure := by purity_tac)
/-- Given `fvarId : [t]object`, obtain the underlying scalar value. -/
| unbox (fvarId : FVarId) (h : pu = .impure := by purity_tac)
/--
Return whether the object stored behind `fvarId` is shared or not. The return type is a `UInt8`.
-/
| isShared (fvarId : FVarId) (h : pu = .impure := by purity_tac)
deriving Inhabited, BEq, Hashable
def Arg.toLetValue (arg : Arg pu) : LetValue pu :=
@@ -298,7 +302,12 @@ private unsafe def LetValue.updateUnboxImp (e : LetValue pu) (fvarId' : FVarId)
@[implemented_by LetValue.updateUnboxImp] opaque LetValue.updateUnbox! (e : LetValue pu) (fvarId' : FVarId) : LetValue pu
private unsafe def LetValue.updateIsSharedImp (e : LetValue pu) (fvarId' : FVarId) : LetValue pu :=
match e with
| .isShared fvarId _ => if fvarId == fvarId' then e else .isShared fvarId'
| _ => unreachable!
@[implemented_by LetValue.updateIsSharedImp] opaque LetValue.updateIsShared! (e : LetValue pu) (fvarId' : FVarId) : LetValue pu
private unsafe def LetValue.updateArgsImp (e : LetValue pu) (args' : Array (Arg pu)) : LetValue pu :=
match e with
@@ -331,6 +340,7 @@ def LetValue.toExpr (e : LetValue pu) : Expr :=
#[.fvar var, .const i.name [], ToExpr.toExpr updateHeader] ++ (args.map Arg.toExpr)
| .box ty var _ => mkApp2 (.const `box []) ty (.fvar var)
| .unbox var _ => mkApp (.const `unbox []) (.fvar var)
| .isShared fvarId _ => mkApp (.const `isShared []) (.fvar fvarId)
structure LetDecl (pu : Purity) where
fvarId : FVarId
@@ -361,10 +371,13 @@ inductive Code (pu : Purity) where
| cases (cases : Cases pu)
| return (fvarId : FVarId)
| unreach (type : Expr)
| oset (fvarId : FVarId) (i : Nat) (y : Arg pu) (k : Code pu) (h : pu = .impure := by purity_tac)
| uset (fvarId : FVarId) (i : Nat) (y : FVarId) (k : Code pu) (h : pu = .impure := by purity_tac)
| sset (fvarId : FVarId) (i : Nat) (offset : Nat) (y : FVarId) (ty : Expr) (k : Code pu) (h : pu = .impure := by purity_tac)
| setTag (fvarId : FVarId) (cidx : Nat) (k : Code pu) (h : pu = .impure := by purity_tac)
| inc (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (k : Code pu) (h : pu = .impure := by purity_tac)
| dec (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (k : Code pu) (h : pu = .impure := by purity_tac)
| del (fvarId : FVarId) (k : Code pu) (h : pu = .impure := by purity_tac)
deriving Inhabited
end
@@ -440,25 +453,32 @@ inductive CodeDecl (pu : Purity) where
| let (decl : LetDecl pu)
| fun (decl : FunDecl pu) (h : pu = .pure := by purity_tac)
| jp (decl : FunDecl pu)
| oset (fvarId : FVarId) (i : Nat) (y : Arg pu) (h : pu = .impure := by purity_tac)
| uset (fvarId : FVarId) (i : Nat) (y : FVarId) (h : pu = .impure := by purity_tac)
| sset (fvarId : FVarId) (i : Nat) (offset : Nat) (y : FVarId) (ty : Expr) (h : pu = .impure := by purity_tac)
| setTag (fvarId : FVarId) (cidx : Nat) (h : pu = .impure := by purity_tac)
| inc (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (h : pu = .impure := by purity_tac)
| dec (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (h : pu = .impure := by purity_tac)
| del (fvarId : FVarId) (h : pu = .impure := by purity_tac)
deriving Inhabited
def CodeDecl.fvarId : CodeDecl pu FVarId
| .let decl | .fun decl _ | .jp decl => decl.fvarId
| .uset fvarId .. | .sset fvarId .. | .inc fvarId .. | .dec fvarId .. => fvarId
| .uset fvarId .. | .sset fvarId .. | .inc fvarId .. | .dec fvarId .. | .del fvarId ..
| .oset fvarId .. | .setTag fvarId .. => fvarId
def Code.toCodeDecl! : Code pu CodeDecl pu
| .let decl _ => .let decl
| .fun decl _ _ => .fun decl
| .jp decl _ => .jp decl
| .uset fvarId i y _ _ => .uset fvarId i y
| .sset fvarId i offset ty y _ _ => .sset fvarId i offset ty y
| .inc fvarId n check persistent _ _ => .inc fvarId n check persistent
| .dec fvarId n check persistent _ _ => .dec fvarId n check persistent
| _ => unreachable!
| .let decl _ => .let decl
| .fun decl _ _ => .fun decl
| .jp decl _ => .jp decl
| .oset fvarId i y _ _ => .oset fvarId i y
| .uset fvarId i y _ _ => .uset fvarId i y
| .sset fvarId i offset ty y _ _ => .sset fvarId i offset ty y
| .setTag fvarId cidx _ _ => .setTag fvarId cidx
| .inc fvarId n check persistent _ _ => .inc fvarId n check persistent
| .dec fvarId n check persistent _ _ => .dec fvarId n check persistent
| .del fvarId _ _ => .del fvarId
| _ => unreachable!
def attachCodeDecls (decls : Array (CodeDecl pu)) (code : Code pu) : Code pu :=
go decls.size code
@@ -469,10 +489,13 @@ where
| .let decl => go (i-1) (.let decl code)
| .fun decl _ => go (i-1) (.fun decl code)
| .jp decl => go (i-1) (.jp decl code)
| .oset fvarId idx y _ => go (i-1) (.oset fvarId idx y code)
| .uset fvarId idx y _ => go (i-1) (.uset fvarId idx y code)
| .sset fvarId idx offset y ty _ => go (i-1) (.sset fvarId idx offset y ty code)
| .setTag fvarId cidx _ => go (i-1) (.setTag fvarId cidx code)
| .inc fvarId n check persistent _ => go (i-1) (.inc fvarId n check persistent code)
| .dec fvarId n check persistent _ => go (i-1) (.dec fvarId n check persistent code)
| .del fvarId _ => go (i-1) (.del fvarId code)
else
code
@@ -488,14 +511,20 @@ mutual
| .jmp j₁ as₁, .jmp j₂ as₂ => j₁ == j₂ && as₁ == as₂
| .return r₁, .return r₂ => r₁ == r₂
| .unreach t₁, .unreach t₂ => t₁ == t₂
| .oset v₁ i₁ y₁ k₁ _, .oset v₂ i₂ y₂ k₂ _ =>
v₁ == v₂ && i₁ == i₂ && y₁ == y₂ && eqImp k₁ k₂
| .uset v₁ i₁ y₁ k₁ _, .uset v₂ i₂ y₂ k₂ _ =>
v₁ == v₂ && i₁ == i₂ && y₁ == y₂ && eqImp k₁ k₂
| .sset v₁ i₁ o₁ y₁ ty₁ k₁ _, .sset v₂ i₂ o₂ y₂ ty₂ k₂ _ =>
v₁ == v₂ && i₁ == i₂ && o₁ == o₂ && y₁ == y₂ && ty₁ == ty₂ && eqImp k₁ k₂
| .setTag v₁ c₁ k₁ _, .setTag v₂ c₂ k₂ _ =>
v₁ == v₂ && c₁ == c₂ && eqImp k₁ k₂
| .inc v₁ n₁ c₁ p₁ k₁ _, .inc v₂ n₂ c₂ p₂ k₂ _ =>
v₁ == v₂ && n₁ == n₂ && c₁ == c₂ && p₁ == p₂ && eqImp k₁ k₂
| .dec v₁ n₁ c₁ p₁ k₁ _, .dec v₂ n₂ c₂ p₂ k₂ _ =>
v₁ == v₂ && n₁ == n₂ && c₁ == c₂ && p₁ == p₂ && eqImp k₁ k₂
| .del v₁ k₁ _, .del v₂ k₂ _ =>
v₁ == v₂ && eqImp k₁ k₂
| _, _ => false
private unsafe def eqFunDecl (d₁ d₂ : FunDecl pu) : Bool :=
@@ -588,10 +617,13 @@ private unsafe def updateAltImp (alt : Alt pu) (ps' : Array (Param pu)) (k' : Co
| .let decl k => if ptrEq k k' then c else .let decl k'
| .fun decl k _ => if ptrEq k k' then c else .fun decl k'
| .jp decl k => if ptrEq k k' then c else .jp decl k'
| .oset fvarId offset y k _ => if ptrEq k k' then c else .oset fvarId offset y k'
| .sset fvarId i offset y ty k _ => if ptrEq k k' then c else .sset fvarId i offset y ty k'
| .uset fvarId offset y k _ => if ptrEq k k' then c else .uset fvarId offset y k'
| .setTag fvarId cidx k _ => if ptrEq k k' then c else .setTag fvarId cidx k'
| .inc fvarId n check persistent k _ => if ptrEq k k' then c else .inc fvarId n check persistent k'
| .dec fvarId n check persistent k _ => if ptrEq k k' then c else .dec fvarId n check persistent k'
| .del fvarId k _ => if ptrEq k k' then c else .del fvarId k'
| _ => unreachable!
@[implemented_by updateContImp] opaque Code.updateCont! (c : Code pu) (k' : Code pu) : Code pu
@@ -635,6 +667,19 @@ private unsafe def updateAltImp (alt : Alt pu) (ps' : Array (Param pu)) (k' : Co
.sset fvarId' i' offset' y' ty' k'
| _ => unreachable!
@[inline] private unsafe def updateOsetImp (c : Code pu) (fvarId' : FVarId)
(i' : Nat) (y' : Arg pu) (k' : Code pu) : Code pu :=
match c with
| .oset fvarId i y k _ =>
if ptrEq fvarId fvarId' && i == i' && ptrEq y y' && ptrEq k k' then
c
else
.oset fvarId' i' y' k'
| _ => unreachable!
@[implemented_by updateOsetImp] opaque Code.updateOset! (c : Code pu) (fvarId' : FVarId)
(i' : Nat) (y' : Arg pu) (k' : Code pu) : Code pu
@[implemented_by updateSsetImp] opaque Code.updateSset! (c : Code pu) (fvarId' : FVarId) (i' : Nat)
(offset' : Nat) (y' : FVarId) (ty' : Expr) (k' : Code pu) : Code pu
@@ -651,6 +696,19 @@ private unsafe def updateAltImp (alt : Alt pu) (ps' : Array (Param pu)) (k' : Co
@[implemented_by updateUsetImp] opaque Code.updateUset! (c : Code pu) (fvarId' : FVarId)
(i' : Nat) (y' : FVarId) (k' : Code pu) : Code pu
@[inline] private unsafe def updateSetTagImp (c : Code pu) (fvarId' : FVarId) (cidx' : Nat)
(k' : Code pu) : Code pu :=
match c with
| .setTag fvarId cidx k _ =>
if ptrEq fvarId fvarId' && cidx == cidx' && ptrEq k k' then
c
else
.setTag fvarId' cidx' k'
| _ => unreachable!
@[implemented_by updateSetTagImp] opaque Code.updateSetTag! (c : Code pu) (fvarId' : FVarId)
(cidx' : Nat) (k' : Code pu) : Code pu
@[inline] private unsafe def updateIncImp (c : Code pu) (fvarId' : FVarId) (n' : Nat)
(check' : Bool) (persistent' : Bool) (k' : Code pu) : Code pu :=
match c with
@@ -685,6 +743,19 @@ private unsafe def updateAltImp (alt : Alt pu) (ps' : Array (Param pu)) (k' : Co
@[implemented_by updateDecImp] opaque Code.updateDec! (c : Code pu) (fvarId' : FVarId) (n' : Nat)
(check' : Bool) (persistent' : Bool) (k' : Code pu) : Code pu
@[inline] private unsafe def updateDelImp (c : Code pu) (fvarId' : FVarId) (k' : Code pu) :
Code pu :=
match c with
| .del fvarId k _ =>
if ptrEq fvarId fvarId' && ptrEq k k' then
c
else
.del fvarId' k'
| _ => unreachable!
@[implemented_by updateDelImp] opaque Code.updateDel! (c : Code pu) (fvarId' : FVarId)
(k' : Code pu) : Code pu
private unsafe def updateParamCoreImp (p : Param pu) (type : Expr) : Param pu :=
if ptrEq type p.type then
p
@@ -753,8 +824,8 @@ partial def Code.size (c : Code pu) : Nat :=
where
go (c : Code pu) (n : Nat) : Nat :=
match c with
| .let (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) ..
| .dec (k := k) .. => go k (n + 1)
| .let (k := k) .. | .oset (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) ..
| .dec (k := k) .. | .setTag (k := k) .. | .del (k := k) .. => go k (n + 1)
| .jp decl k | .fun decl k _ => go k <| go decl.value n
| .cases c => c.alts.foldl (init := n+1) fun n alt => go alt.getCode (n+1)
| .jmp .. => n+1
@@ -772,8 +843,8 @@ where
go (c : Code pu) : EStateM Unit Nat Unit := do
match c with
| .let (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) ..
| .dec (k := k) .. => inc; go k
| .let (k := k) .. | .oset (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) ..
| .dec (k := k) .. | .setTag (k := k) .. | .del (k := k) .. => inc; go k
| .jp decl k | .fun decl k _ => inc; go decl.value; go k
| .cases c => inc; c.alts.forM fun alt => go alt.getCode
| .jmp .. => inc
@@ -785,8 +856,8 @@ where
go (c : Code pu) : m Unit := do
f c
match c with
| .let (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) ..
| .dec (k := k) .. => go k
| .let (k := k) .. | .oset (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) ..
| .dec (k := k) .. | .setTag (k := k) .. | .del (k := k) .. => go k
| .fun decl k _ | .jp decl k => go decl.value; go k
| .cases c => c.alts.forM fun alt => go alt.getCode
| .unreach .. | .return .. | .jmp .. => return ()
@@ -1053,7 +1124,7 @@ private def collectLetValue (e : LetValue pu) (s : FVarIdHashSet) : FVarIdHashSe
| .fvar fvarId args => collectArgs args <| s.insert fvarId
| .const _ _ args _ | .pap _ args _ | .fap _ args _ | .ctor _ args _ => collectArgs args s
| .proj _ _ fvarId _ | .sproj _ _ fvarId _ | .uproj _ fvarId _ | .oproj _ fvarId _
| .reset _ fvarId _ | .box _ fvarId _ | .unbox fvarId _ => s.insert fvarId
| .reset _ fvarId _ | .box _ fvarId _ | .unbox fvarId _ | .isShared fvarId _ => s.insert fvarId
| .lit .. | .erased => s
| .reuse fvarId _ _ args _ => collectArgs args <| s.insert fvarId
@@ -1082,7 +1153,12 @@ partial def Code.collectUsed (code : Code pu) (s : FVarIdHashSet := {}) : FVarId
let s := s.insert fvarId
let s := s.insert y
k.collectUsed s
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. =>
| .oset fvarId _ y k _ =>
let s := s.insert fvarId
let s := if let .fvar y := y then s.insert y else s
k.collectUsed s
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) ..
| .del (fvarId := fvarId) (k := k) .. | .setTag (fvarId := fvarId) (k := k) .. =>
k.collectUsed <| s.insert fvarId
end
@@ -1095,7 +1171,11 @@ def CodeDecl.collectUsed (codeDecl : CodeDecl pu) (s : FVarIdHashSet := ∅) : F
| .jp decl | .fun decl _ => decl.collectUsed s
| .sset (fvarId := fvarId) (y := y) .. | .uset (fvarId := fvarId) (y := y) .. =>
s.insert fvarId |>.insert y
| .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. => s.insert fvarId
| .oset (fvarId := fvarId) (y := y) .. =>
let s := s.insert fvarId
if let .fvar y := y then s.insert y else s
| .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. | .setTag (fvarId := fvarId) ..
| .del (fvarId := fvarId) .. => s.insert fvarId
/--
Traverse the given block of potentially mutually recursive functions
@@ -1125,7 +1205,8 @@ where
modify fun s => s.insert declName
| _ => pure ()
visit k
| .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => visit k
| .oset (k := k) .. | .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) ..
| .dec (k := k) .. | .del (k := k) .. | .setTag (k := k) .. => visit k
go : StateM NameSet Unit :=
decls.forM (·.value.forCodeM visit)

View File

@@ -68,7 +68,8 @@ where
eraseCode k
eraseParam auxParam
return .unreach typeNew
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. =>
| .oset (k := k) ..| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) ..
| .del (k := k) .. | .setTag (k := k) .. =>
return c.updateCont! ( go k)
instance : MonadCodeBind CompilerM where

View File

@@ -149,7 +149,7 @@ def eraseCodeDecl (decl : CodeDecl pu) : CompilerM Unit := do
match decl with
| .let decl => eraseLetDecl decl
| .jp decl | .fun decl _ => eraseFunDecl decl
| .sset .. | .uset .. | .inc .. | .dec .. => return ()
| .sset .. | .uset .. | .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => return ()
/--
Erase all free variables occurring in `decls` from the local context.
@@ -300,6 +300,10 @@ private partial def normLetValueImp (s : FVarSubst pu) (e : LetValue pu) (transl
match normFVarImp s fvarId translator with
| .fvar fvarId' => e.updateUnbox! fvarId'
| .erased => .erased
| .isShared fvarId _ =>
match normFVarImp s fvarId translator with
| .fvar fvarId' => e.updateIsShared! fvarId'
| .erased => .erased
/--
Interface for monads that have a free substitutions.
@@ -497,16 +501,26 @@ mutual
withNormFVarResult ( normFVar fvarId) fun fvarId => do
withNormFVarResult ( normFVar y) fun y => do
return code.updateSset! fvarId i offset y ( normExpr ty) ( normCodeImp k)
| .oset fvarId offset y k _ =>
withNormFVarResult ( normFVar fvarId) fun fvarId => do
let y normArg y
return code.updateOset! fvarId offset y ( normCodeImp k)
| .uset fvarId offset y k _ =>
withNormFVarResult ( normFVar fvarId) fun fvarId => do
withNormFVarResult ( normFVar y) fun y => do
return code.updateUset! fvarId offset y ( normCodeImp k)
| .setTag fvarId cidx k _ =>
withNormFVarResult ( normFVar fvarId) fun fvarId => do
return code.updateSetTag! fvarId cidx ( normCodeImp k)
| .inc fvarId n check persistent k _ =>
withNormFVarResult ( normFVar fvarId) fun fvarId => do
return code.updateInc! fvarId n check persistent ( normCodeImp k)
| .dec fvarId n check persistent k _ =>
withNormFVarResult ( normFVar fvarId) fun fvarId => do
return code.updateDec! fvarId n check persistent ( normCodeImp k)
| .del fvarId k _ =>
withNormFVarResult ( normFVar fvarId) fun fvarId => do
return code.updateDel! fvarId ( normCodeImp k)
end
@[inline] def normFunDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (decl : FunDecl pu) : m (FunDecl pu) := do

View File

@@ -39,12 +39,18 @@ partial def hashCode (code : Code pu) : UInt64 :=
| .cases c => mixHash (mixHash (hash c.discr) (hash c.resultType)) (hashAlts c.alts)
| .sset fvarId i offset y ty k _ =>
mixHash (mixHash (hash fvarId) (hash i)) (mixHash (mixHash (hash offset) (hash y)) (mixHash (hash ty) (hashCode k)))
| .oset fvarId offset y k _ =>
mixHash (mixHash (hash fvarId) (hash offset)) (mixHash (hash y) (hashCode k))
| .uset fvarId offset y k _ =>
mixHash (mixHash (hash fvarId) (hash offset)) (mixHash (hash y) (hashCode k))
| .setTag fvarId cidx k _ =>
mixHash (hash fvarId) (mixHash (hash cidx) (hashCode k))
| .inc fvarId n check persistent k _ =>
mixHash (mixHash (hash fvarId) (hash n)) (mixHash (mixHash (hash persistent) (hash check)) (hashCode k))
| .dec fvarId n check persistent k _ =>
mixHash (mixHash (hash fvarId) (hash n)) (mixHash (mixHash (hash persistent) (hash check)) (hashCode k))
| .del fvarId k _ =>
mixHash (hash fvarId) (hashCode k)
end

View File

@@ -31,7 +31,7 @@ private def letValueDepOn (e : LetValue pu) : M Bool :=
match e with
| .erased | .lit .. => return false
| .proj _ _ fvarId _ | .oproj _ fvarId _ | .uproj _ fvarId _ | .sproj _ _ fvarId _
| .reset _ fvarId _ | .box _ fvarId _ | .unbox fvarId _ => fvarDepOn fvarId
| .reset _ fvarId _ | .box _ fvarId _ | .unbox fvarId _ | .isShared fvarId _ => fvarDepOn fvarId
| .fvar fvarId args | .reuse fvarId _ _ args _ => fvarDepOn fvarId <||> args.anyM argDepOn
| .const _ _ args _ | .ctor _ args _ | .fap _ args _ | .pap _ args _ => args.anyM argDepOn
@@ -46,8 +46,12 @@ private partial def depOn (c : Code pu) : M Bool :=
| .jmp fvarId args => fvarDepOn fvarId <||> args.anyM argDepOn
| .return fvarId => fvarDepOn fvarId
| .unreach _ => return false
| .sset fv1 _ _ fv2 _ k _ | .uset fv1 _ fv2 k _ => fvarDepOn fv1 <||> fvarDepOn fv2 <||> depOn k
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. =>
| .oset fv1 _ fv2 k _ =>
fvarDepOn fv1 <||> argDepOn fv2 <||> depOn k
| .sset fv1 _ _ fv2 _ k _ | .uset fv1 _ fv2 k _ =>
fvarDepOn fv1 <||> fvarDepOn fv2 <||> depOn k
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) ..
| .del (fvarId := fvarId) (k := k) .. | .setTag (fvarId := fvarId) (k := k) .. =>
fvarDepOn fvarId <||> depOn k
@[inline] def Arg.dependsOn (arg : Arg pu) (s : FVarIdSet) : Bool :=
@@ -66,9 +70,14 @@ def CodeDecl.dependsOn (decl : CodeDecl pu) (s : FVarIdSet) : Bool :=
match decl with
| .let decl => decl.dependsOn s
| .jp decl | .fun decl _ => decl.dependsOn s
| .uset (fvarId := fvarId) (y := y) .. | .sset (fvarId := fvarId) (y := y) .. =>
| .oset (fvarId := fvarId) (y := y) .. =>
s.contains fvarId || y.dependsOn s
| .uset (fvarId := fvarId) (y := y) ..
| .sset (fvarId := fvarId) (y := y) .. =>
s.contains fvarId || s.contains y
| .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. => s.contains fvarId
| .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. | .del (fvarId := fvarId) ..
| .setTag (fvarId := fvarId) .. =>
s.contains fvarId
/--
Return `true` if `c` depends on a free variable in `s`.

View File

@@ -35,7 +35,7 @@ def collectLocalDeclsLetValue (s : UsedLocalDecls) (e : LetValue pu) : UsedLocal
match e with
| .erased | .lit .. => s
| .proj _ _ fvarId _ | .reset _ fvarId _ | .sproj _ _ fvarId _ | .uproj _ fvarId _
| .oproj _ fvarId _ | .box _ fvarId _ | .unbox fvarId _ => s.insert fvarId
| .oproj _ fvarId _ | .box _ fvarId _ | .unbox fvarId _ | .isShared fvarId _ => s.insert fvarId
| .const _ _ args _ => collectLocalDeclsArgs s args
| .fvar fvarId args | .reuse fvarId _ _ args _ => collectLocalDeclsArgs (s.insert fvarId) args
| .fap _ args _ | .pap _ args _ | .ctor _ args _ => collectLocalDeclsArgs s args
@@ -56,9 +56,8 @@ def LetValue.safeToElim (val : LetValue pu) : Bool :=
| .pure => true
| .impure =>
match val with
-- TODO | .isShared ..
| .ctor .. | .reset .. | .reuse .. | .oproj .. | .uproj .. | .sproj .. | .lit .. | .pap ..
| .box .. | .unbox .. | .erased .. => true
| .box .. | .unbox .. | .erased .. | .isShared .. => true
-- 0-ary full applications are considered constants
| .fap _ args => args.isEmpty
| .fvar .. => false
@@ -95,6 +94,13 @@ partial def Code.elimDead (code : Code pu) : M (Code pu) := do
| .return fvarId => collectFVarM fvarId; return code
| .jmp fvarId args => collectFVarM fvarId; args.forM collectArgM; return code
| .unreach .. => return code
| .oset fvarId _ y k _ =>
let k k.elimDead
if ( get).contains fvarId then
collectArgM y
return code.updateCont! k
else
return k
| .uset fvarId _ y k _ | .sset fvarId _ _ y _ k _ =>
let k k.elimDead
if ( get).contains fvarId then
@@ -102,7 +108,8 @@ partial def Code.elimDead (code : Code pu) : M (Code pu) := do
return code.updateCont! k
else
return k
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. =>
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) ..
| .setTag (fvarId := fvarId) (k := k) .. | .del (fvarId := fvarId) (k := k) .. =>
let k k.elimDead
collectFVarM fvarId
return code.updateCont! k

View File

@@ -0,0 +1,369 @@
/-
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Henrik Böving
-/
module
prelude
public import Lean.Compiler.LCNF.PassManager
import Init.While
/-!
This pass expands pairs of reset-reuse instructions into explicit hot and cold paths. We do this on
the LCNF level rather than letting the backend do it because we can apply domain specific
optimizations at this point in time.
Whenever we encounter a `let token := reset nfields orig; k`, we create code of the shape (not
showing reference counting instructions):
```
jp resetjp token isShared :=
k
cases isShared orig with
| false -> jmp resetjp orig false
| true -> jmp resetjp box(0) true
```
Then within the join point body `k` we turn `dec` instructions on `token` into `del` and expand
`let final := reuse token arg; k'` into another join point:
```
jp reusejp final :=
k'
cases isShared with
| false -> jmp reusejp token
| true ->
let x := alloc args
jmp reusejp x
```
In addition to this we perform optimizations specific to the hot path for both the `resetjp` and
`reusejp`. For the former, we will frequently encounter the pattern:
```
let x_0 = proj[0] orig
inc x_0
...
let x_i = proj[i] orig
inc x_i
let token := reset nfields orig
```
On the hot path we do not free `orig`, thus there is no need to increment the reference counts of
the projections because the reference coming from `orig` will keep all of the projections alive
naturally (a form of "dynamic derived borrows" if you wish). On the cold path the reference counts
still have to happen though.
For `reusejp` we frequently encounter the pattern:
```
let final := reuse token args
set final[0] := x0
...
set final[i] := xi
```
On the hot path we know that `token` and `orig` refer to the same value. Thus, if we can detect that
one of the `xi` is of the shape `let xi := proj[i] orig`, we can omit the store on the hot path.
Just like with `reusejp` on the cold path we have to perform all the stores.
-/
namespace Lean.Compiler.LCNF
open ImpureType
/-- Per-field mask: `some fvarId` marks a field whose `inc` was erased by `eraseProjIncFor`. -/
abbrev Mask := Array (Option FVarId)
/--
Try to erase `inc` instructions on projections of `targetId` occurring in the tail of `ds`.
Return the updated `ds` and a mask containing the `FVarId`s whose `inc` was removed.
-/
partial def eraseProjIncFor (nFields : Nat) (targetId : FVarId) (ds : Array (CodeDecl .impure)) :
CompilerM (Array (CodeDecl .impure) × Mask) := do
let mut ds := ds
-- Instructions popped from the tail of `ds` that must be retained (collected in reverse order).
let mut keep := #[]
let mut mask := Array.replicate nFields none
while ds.size 2 do
let d := ds.back!
match d with
| .let { value := .sproj .., .. } | .let { value := .uproj .., .. } =>
-- Scalar/unboxed projections carry no `inc` of interest; skip past them and keep scanning.
ds := ds.pop
keep := keep.push d
| .inc z n c p =>
assert! n > 0 -- 0 incs should not be happening
-- Look one instruction further back for the object projection matching this `inc`.
let d' := ds[ds.size - 2]!
let .let { fvarId := w, value := .oproj i x _, .. } := d'
| break
if !(w == z && targetId == x) then
break
/-
Found
```
let z := proj[i] targetId
inc z n c
```
We keep `proj`, and `inc` when `n > 1`
-/
ds := ds.pop.pop
mask := mask.set! i (some z)
keep := keep.push d'
-- One reference is covered by `targetId` staying alive on the hot path; keep the other `n - 1`.
keep := if n == 1 then keep else keep.push (.inc z (n - 1) c p)
| _ => break
return (ds ++ keep.reverse, mask)
/--
Build a two-way `cases` on the `Bool`-like `discr` (whose type is `discrType`): the `false`
alternative runs `e`, the `true` alternative runs `t`. Note the alternative array lists `false`
first, matching the constructor indices.
-/
def mkIf (discr : FVarId) (discrType : Expr) (resultType : Expr) (t e : Code .impure) :
CompilerM (Code .impure) := do
return .cases <| .mk discrType.getAppFn.constName! resultType discr #[
.ctorAlt { name := ``Bool.false, cidx := 0, size := 0, usize := 0, ssize := 0 } e,
.ctorAlt { name := ``Bool.true, cidx := 1, size := 0, usize := 0, ssize := 0 } t,
]
/--
Rewrite each set instruction in `sets` so that it targets `targetId` instead of its original
object. Only `oset`/`sset`/`uset` instructions may appear in `sets`.
-/
def remapSets (targetId : FVarId) (sets : Array (CodeDecl .impure)) :
CompilerM (Array (CodeDecl .impure)) :=
return sets.map fun
| .oset fvarId i y => .oset targetId i y
| .sset fvarId i offset y ty => .sset targetId i offset y ty
| .uset fvarId i y => .uset targetId i y
| _ => unreachable!
/--
Return `true` if storing `y` into object field `i` of `fvarId` is redundant because `y` was itself
obtained via `oproj i fvarId`, i.e. the field already holds exactly that value.
-/
def isSelfOset (fvarId : FVarId) (i : Nat) (y : Arg .impure) : CompilerM Bool := do
match y with
| .fvar y =>
let some value findLetValue? (pu := .impure) y | return false
let .oproj i' fvarId' := value | return false
return i == i' && fvarId == fvarId'
| .erased => return false
/-- Analogue of `isSelfOset` for unboxed (`uset`/`uproj`) fields. -/
def isSelfUset (fvarId : FVarId) (i : Nat) (y : FVarId) : CompilerM Bool := do
let some value findLetValue? (pu := .impure) y | return false
let .uproj i' fvarId' := value | return false
return i == i' && fvarId == fvarId'
/-- Analogue of `isSelfOset` for scalar (`sset`/`sproj`) fields, which also carry a byte `offset`. -/
def isSelfSset (fvarId : FVarId) (i : Nat) (offset : Nat) (y : FVarId) :
CompilerM Bool := do
let some value findLetValue? (pu := .impure) y | return false
let .sproj i' offset' fvarId' := value | return false
return i == i' && offset == offset' && fvarId == fvarId'
/--
Partition the set instructions in `sets` into a pair `(selfSets, necessarySets)` where `selfSets`
contains instructions that store a value projected from `selfId` back into its own field (and are
thus redundant on the reuse fast path) and `necessarySets` contains all others.
-/
def partitionSelfSets (selfId : FVarId) (sets : Array (CodeDecl .impure)) :
CompilerM (Array (CodeDecl .impure) × Array (CodeDecl .impure)) := do
let mut necessarySets := #[]
let mut selfSets := #[]
for set in sets do
let isSelfSet :=
match set with
| .oset _ i y => isSelfOset selfId i y
| .uset _ i y => isSelfUset selfId i y
| .sset _ i offset y _ => isSelfSset selfId i offset y
| _ => unreachable!
if isSelfSet then
selfSets := selfSets.push set
else
necessarySets := necessarySets.push set
return (selfSets, necessarySets)
/--
Collect the maximal prefix of set instructions (`oset`/`sset`/`uset`) on `target` from the head of
`k`. Return the collected sets together with the remaining continuation.
-/
def collectSucceedingSets (target : FVarId) (k : Code .impure) :
CompilerM (Array (CodeDecl .impure) × Code .impure) := do
let mut sets := #[]
let mut k := k
while true do
match k with
| .oset (fvarId := fvarId) (k := k') .. | .sset (fvarId := fvarId) (k := k') ..
| .uset (fvarId := fvarId) (k := k') .. =>
-- Stop as soon as a set targets a different object.
if target != fvarId then
break
sets := sets.push k.toCodeDecl!
k := k'
| _ => break
return (sets, k)
mutual
/--
Expand the matching `reuse`/`dec` for the allocation in `origAllocId` whose `reset` token is in
`resetTokenId`.
-/
partial def processResetCont (resetTokenId : FVarId) (code : Code .impure) (origAllocId : FVarId)
(isSharedId : FVarId) (currentRetType : Expr) : CompilerM (Code .impure) := do
match code with
| .dec y n _ _ k =>
if resetTokenId == y then
assert! n == 1 -- n must be one since `resetToken := reset ...`
-- A `dec` of the token means the reuse opportunity was not taken on this path; replace it
-- with an explicit `del` of the token.
return .del resetTokenId k
else
let k processResetCont resetTokenId k origAllocId isSharedId currentRetType
return code.updateCont! k
| .let decl k =>
match decl.value with
| .reuse y c u xs =>
if resetTokenId != y then
let k processResetCont resetTokenId k origAllocId isSharedId currentRetType
return code.updateCont! k
-- Found the matching `reuse`: split into a fast (reuse) and slow (fresh allocation) path,
-- both jumping to a shared continuation join point bound to the reuse result.
let (succeedingSets, k) collectSucceedingSets decl.fvarId k
let (selfSets, necessarySets) partitionSelfSets origAllocId succeedingSets
let k := attachCodeDecls necessarySets k
let param := {
fvarId := decl.fvarId,
binderName := decl.binderName,
type := decl.type,
borrow := false
}
let contJp mkFunDecl ( mkFreshBinderName `reusejp) currentRetType #[param] k
let slowPath mkSlowPath decl c xs contJp.fvarId selfSets
let fastPath mkFastPath resetTokenId c u xs contJp.fvarId origAllocId
eraseLetDecl decl
let reuse mkIf isSharedId uint8 currentRetType slowPath fastPath
return .jp contJp reuse
| _ =>
let k processResetCont resetTokenId k origAllocId isSharedId currentRetType
return code.updateCont! k
| .cases cs =>
return code.updateAlts! ( cs.alts.mapMonoM (·.mapCodeM (processResetCont resetTokenId · origAllocId isSharedId cs.resultType)))
| .jp decl k =>
let decl decl.updateValue ( processResetCont resetTokenId decl.value origAllocId isSharedId decl.type)
let k processResetCont resetTokenId k origAllocId isSharedId currentRetType
return code.updateFun! decl k
| .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .setTag (k := k) ..
| .del (k := k) .. | .oset (k := k) .. =>
let k processResetCont resetTokenId k origAllocId isSharedId currentRetType
return code.updateCont! k
| .jmp .. | .return .. | .unreach .. => return code
where
/--
On the slow path we have to:
1. Make a fresh allocation
2. Apply all the self sets as the fresh allocation is of course not in sync with the original one.
3. Pass the fresh allocation to the join point.
-/
mkSlowPath (decl : LetDecl .impure) (info : CtorInfo) (args : Array (Arg .impure))
(contJpId : FVarId) (selfSets : Array (CodeDecl .impure)) : CompilerM (Code .impure) := do
let allocDecl mkLetDecl ( mkFreshBinderName `reuseFailAlloc) decl.type (.ctor info args)
let mut code := .jmp contJpId #[.fvar allocDecl.fvarId]
-- Self sets were dropped from the continuation, so they must be replayed on the fresh object.
code := attachCodeDecls ( remapSets allocDecl.fvarId selfSets) code
code := .let allocDecl code
return code
/--
On the fast path we have to:
1. Make all non-self object sets to "simulate" the allocation (the remaining necessary sets will
be made in the continuation)
2. Pass the reused allocation to the join point.
-/
mkFastPath (resetTokenId : FVarId) (info : CtorInfo) (update : Bool) (args : Array (Arg .impure))
(contJpId : FVarId) (origAllocId : FVarId) : CompilerM (Code .impure) := do
let mut code := .jmp contJpId #[.fvar resetTokenId]
for h : idx in 0...args.size do
-- Fields that already hold the projected value need no store on the reuse path.
if !( isSelfOset origAllocId idx args[idx]) then
code := .oset resetTokenId idx args[idx] code
if update then
-- The constructor index changed relative to the original allocation; update the header tag.
code := .setTag resetTokenId info.cidx code
return code
/--
Traverse `code` looking for reset-reuse pairs to expand while `ds` holds the instructions up to the
last branching point.
-/
partial def Code.expandResetReuse (code : Code .impure) (ds : Array (CodeDecl .impure))
(currentRetType : Expr) : CompilerM (Code .impure) := do
-- Push a straight-line instruction onto `ds` and continue in its continuation `k`.
let collectAndGo (code : Code .impure) (ds : Array (CodeDecl .impure)) (k : Code .impure) :=
let d := code.toCodeDecl!
k.expandResetReuse (ds.push d) currentRetType
match code with
| .let decl k =>
match decl.value with
| .reset nFields origAllocId => expand ds decl nFields origAllocId k
| _ => collectAndGo code ds k
| .jp decl k =>
-- A join point body starts a fresh straight-line context, hence the empty `#[]`.
let value decl.value.expandResetReuse #[] decl.type
let decl decl.updateValue value
k.expandResetReuse (ds.push (.jp decl)) currentRetType
| .cases cs =>
let alts cs.alts.mapMonoM (·.mapCodeM (·.expandResetReuse #[] cs.resultType))
let code := code.updateAlts! alts
return attachCodeDecls ds code
| .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .setTag (k := k) ..
| .dec (k := k) .. | .del (k := k) .. | .oset (k := k) .. =>
collectAndGo code ds k
| .jmp .. | .return .. | .unreach .. =>
return attachCodeDecls ds code
where
/--
Expand the reset in `decl` together with its matching `reuse`/`dec`s in its continuation `k`.
-/
expand (ds : Array (CodeDecl .impure)) (decl : LetDecl .impure) (nFields : Nat)
(origAllocId : FVarId) (k : Code .impure) : CompilerM (Code .impure) := do
-- Drop `inc`s on projections of the original allocation; they only need to run on the slow path.
let (ds, mask) eraseProjIncFor nFields origAllocId ds
let isSharedParam mkParam ( mkFreshBinderName `isShared) uint8 false
let k processResetCont decl.fvarId k origAllocId isSharedParam.fvarId currentRetType
-- Keep expanding any further reset-reuse pairs inside the rewritten continuation.
let k k.expandResetReuse #[] currentRetType
let allocParam := {
fvarId := decl.fvarId,
binderName := decl.binderName,
type := tobject,
borrow := false
}
let resetJp mkFunDecl ( mkFreshBinderName `resetjp) currentRetType #[allocParam, isSharedParam] k
let isSharedDecl mkLetDecl ( mkFreshBinderName `isSharedCheck) uint8 (.isShared origAllocId)
let slowPath mkSlowPath origAllocId mask resetJp.fvarId isSharedDecl.fvarId
let fastPath mkFastPath origAllocId mask resetJp.fvarId isSharedDecl.fvarId
let mut reset mkIf isSharedDecl.fvarId uint8 currentRetType slowPath fastPath
reset := .let isSharedDecl reset
eraseLetDecl decl
return attachCodeDecls ds (.jp resetJp reset)
/--
On the slow path we cannot reuse the allocation, this means we have to:
1. Increment all variables projected from `origAllocId` that have not been incremented yet by
the shared prologue. On the fast path they are kept alive naturally by the original allocation
but here that is not necessarily the case.
2. Decrement the value being reset (the natural behavior of a failed reset)
3. Pass box(0) as a reuse value into the continuation join point
-/
mkSlowPath (origAllocId : FVarId) (mask : Mask) (resetJpId : FVarId) (isSharedId : FVarId) :
CompilerM (Code .impure) := do
-- `.erased` stands in for box(0): the join point receives no reusable allocation here.
let mut code := .jmp resetJpId #[.erased, .fvar isSharedId]
code := .dec origAllocId 1 true false code
for fvarId? in mask do
let some fvarId := fvarId? | continue
code := .inc fvarId 1 true false code
return code
/--
On the fast path we can reuse the allocation, this means we have to:
1. Decrement all unread fields as their parent allocation would usually be dropped at this point
and we want to be garbage free.
2. Pass the original allocation as a reuse value into the continuation join point
-/
mkFastPath (origAllocId : FVarId) (mask : Mask) (resetJpId : FVarId) (isSharedId : FVarId) :
CompilerM (Code .impure) := do
let mut code := .jmp resetJpId #[.fvar origAllocId, .fvar isSharedId]
for h : idx in 0...mask.size do
-- Fields in the mask are still read later (their `inc` was erased); only drop unread fields.
if mask[idx].isSome then
continue
let fieldDecl mkLetDecl ( mkFreshBinderName `unused) tobject (.oproj idx origAllocId)
code := .let fieldDecl (.dec fieldDecl.fvarId 1 true false code)
return code
end
/--
Run reset-reuse expansion over the body of `decl` when the `resetReuse` compiler configuration
flag is enabled; otherwise return the declaration unchanged.
-/
def Decl.expandResetReuse (decl : Decl .impure) : CompilerM (Decl .impure) := do
if ( getConfig).resetReuse then
let value decl.value.mapCodeM (·.expandResetReuse #[] decl.type)
let decl := { decl with value }
return decl
else
return decl
/-- Per-declaration compiler pass wrapping `Decl.expandResetReuse` for impure LCNF. -/
public def expandResetReuse : Pass :=
Pass.mkPerDeclaration `expandResetReuse .impure Decl.expandResetReuse
-- Trace class used to inspect the output of this pass (`set_option trace.Compiler.expandResetReuse true`).
builtin_initialize
registerTraceClass `Compiler.expandResetReuse (inherited := true)
end Lean.Compiler.LCNF

View File

@@ -284,7 +284,7 @@ partial def Code.explicitBoxing (code : Code .impure) : BoxM (Code .impure) := d
let some jpDecl findFunDecl? fvarId | unreachable!
castArgsIfNeeded args jpDecl.params fun args => return code.updateJmp! fvarId args
| .unreach .. => return code.updateUnreach! ( getResultType)
| .inc .. | .dec .. => unreachable!
| .inc .. | .dec .. | .oset .. | .setTag .. | .del .. => unreachable!
where
/--
Up to this point the type system of IR is quite loose so we can for example encounter situations
@@ -313,7 +313,7 @@ where
| .ctor i _ => return i.type
| .fvar .. | .lit .. | .sproj .. | .oproj .. | .reset .. | .reuse .. =>
return currentType
| .box .. | .unbox .. => unreachable!
| .box .. | .unbox .. | .isShared .. => unreachable!
visitLet (code : Code .impure) (decl : LetDecl .impure) (k : Code .impure) : BoxM (Code .impure) := do
let type tryCorrectLetDeclType decl.type decl.value
@@ -350,7 +350,7 @@ where
| .erased | .reset .. | .sproj .. | .uproj .. | .oproj .. | .lit .. =>
let decl decl.update type decl.value
return code.updateLet! decl k
| .box .. | .unbox .. => unreachable!
| .box .. | .unbox .. | .isShared .. => unreachable!
def run (decls : Array (Decl .impure)) : CompilerM (Array (Decl .impure)) := do
let decls decls.foldlM (init := #[]) fun newDecls decl => do

View File

@@ -117,7 +117,7 @@ partial def collectCode (code : Code .impure) : M Unit := do
| .cases cases => cases.alts.forM (·.forCodeM collectCode)
| .sset (k := k) .. | .uset (k := k) .. => collectCode k
| .return .. | .jmp .. | .unreach .. => return ()
| .inc .. | .dec .. => unreachable!
| .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable!
/--
Collect the derived value tree as well as the set of parameters that take objects and are borrowed.
@@ -334,6 +334,7 @@ def useLetValue (value : LetValue .impure) : RcM Unit := do
useVar fvarId
useArgs args
| .lit .. | .erased => return ()
| .isShared .. => unreachable!
@[inline]
def bindVar (fvarId : FVarId) : RcM Unit :=
@@ -547,6 +548,7 @@ def LetDecl.explicitRc (code : Code .impure) (decl : LetDecl .impure) (k : Code
addIncBeforeConsumeAll allArgs (code.updateLet! decl k)
| .lit .. | .box .. | .reset .. | .erased .. =>
pure <| code.updateLet! decl k
| .isShared .. => unreachable!
useLetValue decl.value
bindVar decl.fvarId
return k
@@ -622,7 +624,7 @@ partial def Code.explicitRc (code : Code .impure) : RcM (Code .impure) := do
| .unreach .. =>
setRetLiveVars
return code
| .inc .. | .dec .. => unreachable!
| .inc .. | .dec .. | .setTag .. | .del .. | .oset .. => unreachable!
def Decl.explicitRc (decl : Decl .impure) :
CompilerM (Decl .impure) := do

View File

@@ -10,6 +10,9 @@ public import Lean.Compiler.ClosedTermCache
public import Lean.Compiler.NeverExtractAttr
public import Lean.Compiler.LCNF.Internalize
public import Lean.Compiler.LCNF.ToExpr
import Lean.Compiler.LCNF.ElimDead
import Lean.Compiler.LCNF.DependsOn
meta import Init.Data.FloatArray.Basic
public section
@@ -111,33 +114,94 @@ where
end
/--
Check if `let decl; k` forms an `Array`, `ByteArray`, or `FloatArray` literal. They consist of some
initial allocation (`Array.mkEmpty` or `Array.emptyWithCapacity`) followed by a sequence of
`Array.push` and for the scalar variants finally `ByteArray.mk` or `FloatArray.mk`.
We identify these literals by matching this pattern and ensuring that only the last `push`/`mk` from
the sequence is used in the final continuation. If that is the case, we can pull out the entire
literal as one closed declaration. This avoids the quadratic overhead of repeated `Array.push` calls
on persistent `Array` objects during initialization.
-/
def searchArrayLiteral (decl : LetDecl .pure) (k : Code .pure) :
M (Option (LetDecl .pure × Code .pure)) := do
let .const ``Array.push _ #[_, .fvar parentId, _] := decl.value | return none
let some parentDecl findLetDecl? (pu := .pure) parentId | return none
match parentDecl.value with
| .const ``Array.mkEmpty _ #[_, .fvar sizeFVar]
| .const ``Array.emptyWithCapacity _ #[_, .fvar sizeFVar] =>
let some (.lit (.nat size)) findLetValue? (pu := .pure) sizeFVar | return none
identifyChain parentDecl.fvarId decl k {} size
| _ => return none
where
identifyChain (prevArrayId : FVarId) (decl : LetDecl .pure) (k : Code .pure)
(illegalSet : FVarIdSet) (size : Nat) : M (Option (LetDecl .pure × Code .pure)) := do
match size with
| 0 => return none
| nextSize + 1 =>
let .const ``Array.push _ #[_, .fvar arrId, elemArg] := decl.value | return none
if arrId != prevArrayId then return none
if !( shouldExtractArg elemArg) then return none
if nextSize != 0 then
let illegalSet := illegalSet.insert decl.fvarId
let .let nextDecl nextK := k | return none
identifyChain decl.fvarId nextDecl nextK illegalSet nextSize
else
let occursCheck (decl : LetDecl .pure) (k : Code .pure) (illegalSet : FVarIdSet) := do
if k.dependsOn illegalSet then return none
return some (decl, k)
-- At this point we can be at the end of an `Array` literal or right before the end of a
-- `ByteArray` or `FloatArray` literal, let's check.
match k with
| .let nextDecl@{ value := .const ``ByteArray.mk _ #[.fvar arrayId] _, .. } nextK
| .let nextDecl@{ value := .const ``FloatArray.mk _ #[.fvar arrayId] _, .. } nextK =>
if arrayId != decl.fvarId then
occursCheck decl k illegalSet
else
let illegalSet := illegalSet.insert decl.fvarId
occursCheck nextDecl nextK illegalSet
| _ => occursCheck decl k illegalSet
mutual
partial def visitCode (code : Code .pure) : M (Code .pure) := do
match code with
| .let decl k =>
if ( shouldExtractLetValue true decl.value) then
let _, decls extractLetValue decl.value |>.run {}
let decls := decls.reverse.push (.let decl)
let decls decls.mapM Internalize.internalizeCodeDecl |>.run' {}
let closedCode := attachCodeDecls decls (.return decls.back!.fvarId)
let closedExpr := closedCode.toExpr
let env getEnv
let name if let some closedTermName := getClosedTermName? env closedExpr then
eraseCode closedCode
pure closedTermName
let visitLetDefault := do
if let some (decl, k) searchArrayLiteral decl k then
let name performExtraction decl
let decl decl.updateValue (.const name [] #[])
return code.updateLet! decl ( visitCode k)
else if ( shouldExtractLetValue true decl.value) then
let name performExtraction decl
let decl decl.updateValue (.const name [] #[])
return code.updateLet! decl ( visitCode k)
else
let name := ( read).baseName ++ (`_closed).appendIndexAfter ( get).decls.size
cacheClosedTermName env closedExpr name |> setEnv
let decl := { name, levelParams := [], type := decl.type, params := #[],
value := .code closedCode, inlineAttr? := some .noinline }
decl.saveMono
modify fun s => { s with decls := s.decls.push decl }
pure name
let decl decl.updateValue (.const name [] #[])
return code.updateLet! decl ( visitCode k)
else
return code.updateLet! decl ( visitCode k)
return code.updateLet! decl ( visitCode k)
match decl.value with
| .const ``Array.mkEmpty _ #[_, .fvar sizeId]
| .const ``Array.emptyWithCapacity _ #[_, .fvar sizeId] =>
if let some (.lit (.nat n)) findLetValue? (pu := .pure) sizeId then
if n == 0 then
let name performExtraction decl
let decl decl.updateValue (.const name [] #[])
return code.updateLet! decl ( visitCode k)
else
/-
Extracting non-empty `Array` initializers on their own often isn't helpful because they
will almost always be used later on by other declarations. This most frequently happens in
one of two ways:
1. They get mutated by some ordinary function in which case they will be copied from the
persistent storage anyways.
2. They get used by an `Array` literal which builds up an `Array.push` chain that we
specifically pattern match on starting with an empty initializer of appropriate size.
-/
return code.updateLet! decl ( visitCode k)
else
visitLetDefault
| _ => visitLetDefault
| .fun decl k =>
let decl decl.updateValue ( visitCode decl.value)
return code.updateFun! decl ( visitCode k)
@@ -148,6 +212,26 @@ partial def visitCode (code : Code .pure) : M (Code .pure) := do
let alts cases.alts.mapMonoM (fun alt => do return alt.updateCode ( visitCode alt.getCode))
return code.updateAlts! alts
| .jmp .. | .return _ | .unreach .. => return code
where
performExtraction (decl : LetDecl .pure) : M Name := do
let _, decls extractLetValue decl.value |>.run {}
let decls := decls.reverse.push (.let decl)
let decls decls.mapM Internalize.internalizeCodeDecl |>.run' {}
let closedCode := attachCodeDecls decls (.return decls.back!.fvarId)
let closedExpr := closedCode.toExpr
let env getEnv
if let some closedTermName := getClosedTermName? env closedExpr then
eraseCode closedCode
return closedTermName
else
let name := ( read).baseName ++ (`_closed).appendIndexAfter ( get).decls.size
cacheClosedTermName env closedExpr name |> setEnv
let decl := { name, levelParams := [], type := decl.type, params := #[],
value := .code closedCode, inlineAttr? := some .noinline }
decl.saveMono
modify fun s => { s with decls := s.decls.push decl }
return name
end
@@ -159,7 +243,10 @@ end ExtractClosed
partial def Decl.extractClosed (decl : Decl .pure) (sccDecls : Array (Decl .pure)) :
CompilerM (Array (Decl .pure)) := do
let decl, s ExtractClosed.visitDecl decl |>.run { baseName := decl.name, sccDecls } |>.run {}
let mut decl, s ExtractClosed.visitDecl decl |>.run { baseName := decl.name, sccDecls } |>.run {}
if !s.decls.isEmpty then
-- Closed term extraction might have left behind dead values.
decl decl.elimDeadVars
return s.decls.push decl
def extractClosed : Pass where

View File

@@ -83,12 +83,13 @@ def LetValue.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarI
return e.updateReuse! ( f fvarId) i updateHeader ( args.mapM (TraverseFVar.mapFVarM f))
| .box ty fvarId _ => return e.updateBox! ty ( f fvarId)
| .unbox fvarId _ => return e.updateUnbox! ( f fvarId)
| .isShared fvarId _ => return e.updateIsShared! ( f fvarId)
def LetValue.forFVarM [Monad m] (f : FVarId m Unit) (e : LetValue pu) : m Unit := do
match e with
| .lit .. | .erased => return ()
| .proj _ _ fvarId _ | .oproj _ fvarId _ | .sproj _ _ fvarId _ | .uproj _ fvarId _
| .reset _ fvarId _ | .box _ fvarId _ | .unbox fvarId _ => f fvarId
| .reset _ fvarId _ | .box _ fvarId _ | .unbox fvarId _ | .isShared fvarId _ => f fvarId
| .const _ _ args _ | .pap _ args _ | .fap _ args _ | .ctor _ args _ =>
args.forM (TraverseFVar.forFVarM f)
| .fvar fvarId args | .reuse fvarId _ _ args _ => f fvarId; args.forM (TraverseFVar.forFVarM f)
@@ -139,14 +140,20 @@ partial def Code.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m F
return Code.updateReturn! c ( f var)
| .unreach typ =>
return Code.updateUnreach! c ( Expr.mapFVarM f typ)
| .oset fvarId offset y k _ =>
return Code.updateOset! c ( f fvarId) offset ( y.mapFVarM f) ( mapFVarM f k)
| .sset fvarId i offset y ty k _ =>
return Code.updateSset! c ( f fvarId) i offset ( f y) ( Expr.mapFVarM f ty) ( mapFVarM f k)
| .uset fvarId offset y k _ =>
return Code.updateUset! c ( f fvarId) offset ( f y) ( mapFVarM f k)
| .setTag fvarId cidx k _ =>
return Code.updateSetTag! c ( f fvarId) cidx ( mapFVarM f k)
| .inc fvarId n check persistent k _ =>
return Code.updateInc! c ( f fvarId) n check persistent ( mapFVarM f k)
| .dec fvarId n check persistent k _ =>
return Code.updateDec! c ( f fvarId) n check persistent ( mapFVarM f k)
| .del fvarId k _ =>
return Code.updateDel! c ( f fvarId) ( mapFVarM f k)
partial def Code.forFVarM [Monad m] (f : FVarId m Unit) (c : Code pu) : m Unit := do
match c with
@@ -182,7 +189,12 @@ partial def Code.forFVarM [Monad m] (f : FVarId → m Unit) (c : Code pu) : m Un
f fvarId
f y
forFVarM f k
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. =>
| .oset fvarId _ y k _ =>
f fvarId
y.forFVarM f
forFVarM f k
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) ..
| .del (fvarId := fvarId) (k := k) .. | .setTag (fvarId := fvarId) (k := k) .. =>
f fvarId
forFVarM f k
@@ -210,17 +222,22 @@ instance : TraverseFVar (CodeDecl pu) where
| .jp decl => return .jp ( mapFVarM f decl)
| .let decl => return .let ( mapFVarM f decl)
| .uset fvarId i y _ => return .uset ( f fvarId) i ( f y)
| .oset fvarId i y _ => return .oset ( f fvarId) i ( y.mapFVarM f)
| .sset fvarId i offset y ty _ => return .sset ( f fvarId) i offset ( f y) ( mapFVarM f ty)
| .setTag fvarId cidx _ => return .setTag ( f fvarId) cidx
| .inc fvarId n check persistent _ => return .inc ( f fvarId) n check persistent
| .dec fvarId n check persistent _ => return .dec ( f fvarId) n check persistent
| .del fvarId _ => return .del ( f fvarId)
forFVarM f decl :=
match decl with
| .fun decl _ => forFVarM f decl
| .jp decl => forFVarM f decl
| .let decl => forFVarM f decl
| .uset fvarId i y _ => do f fvarId; f y
| .oset fvarId i y _ => do f fvarId; y.forFVarM f
| .sset fvarId i offset y ty _ => do f fvarId; f y; forFVarM f ty
| .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. => f fvarId
| .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. | .del (fvarId := fvarId) ..
| .setTag (fvarId := fvarId) .. => f fvarId
instance : TraverseFVar (Alt pu) where
mapFVarM f alt := do

View File

@@ -91,7 +91,7 @@ where
| .cases cs => cs.alts.forM (·.forCodeM (goCode declName))
| .let _ k | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => goCode declName k
| .return .. | .jmp .. | .unreach .. => return ()
| .inc .. | .dec .. => unreachable!
| .inc .. | .dec .. | .setTag .. | .del .. | .oset .. => unreachable!
/--
Apply the inferred borrow annotations from `map` to a SCC.
@@ -121,7 +121,7 @@ where
| .cases cs => return code.updateAlts! <| cs.alts.mapMonoM (·.mapCodeM (go declName))
| .let _ k | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => return code.updateCont! ( go declName k)
| .return .. | .jmp .. | .unreach .. => return code
| .inc .. | .dec .. => unreachable!
| .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable!
structure Ctx where
/--
@@ -300,7 +300,7 @@ where
| .cases cs => cs.alts.forM (·.forCodeM collectCode)
| .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => collectCode k
| .return .. | .unreach .. => return ()
| .inc .. | .dec .. => unreachable!
| .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable!
public def inferBorrow : Pass where

View File

@@ -120,6 +120,10 @@ private partial def internalizeLetValue (e : LetValue pu) : InternalizeM pu (Let
match ( normFVar fvarId) with
| .fvar fvarId' => return e.updateBox! ty fvarId'
| .erased => return .erased
| .isShared fvarId _ =>
match ( normFVar fvarId) with
| .fvar fvarId' => return e.updateIsShared! fvarId'
| .erased => return .erased
def internalizeLetDecl (decl : LetDecl pu) : InternalizeM pu (LetDecl pu) := do
let binderName refreshBinderName decl.binderName
@@ -166,12 +170,22 @@ partial def internalizeCode (code : Code pu) : InternalizeM pu (Code pu) := do
withNormFVarResult ( normFVar fvarId) fun fvarId => do
withNormFVarResult ( normFVar y) fun y => do
return .uset fvarId offset y ( internalizeCode k)
| .oset fvarId offset y k _ =>
withNormFVarResult ( normFVar fvarId) fun fvarId => do
let y normArg y
return .oset fvarId offset y ( internalizeCode k)
| .setTag fvarId cidx k _ =>
withNormFVarResult ( normFVar fvarId) fun fvarId => do
return .setTag fvarId cidx ( internalizeCode k)
| .inc fvarId n check persistent k _ =>
withNormFVarResult ( normFVar fvarId) fun fvarId => do
return .inc fvarId n check persistent ( internalizeCode k)
| .dec fvarId n check persistent k _ =>
withNormFVarResult ( normFVar fvarId) fun fvarId => do
return .dec fvarId n check persistent ( internalizeCode k)
| .del fvarId k _ =>
withNormFVarResult ( normFVar fvarId) fun fvarId => do
return .del fvarId ( internalizeCode k)
end
@@ -180,8 +194,12 @@ partial def internalizeCodeDecl (decl : CodeDecl pu) : InternalizeM pu (CodeDecl
| .let decl => return .let ( internalizeLetDecl decl)
| .fun decl _ => return .fun ( internalizeFunDecl decl)
| .jp decl => return .jp ( internalizeFunDecl decl)
| .uset fvarId i y _ =>
| .oset fvarId i y _ =>
-- Something weird should be happening if these become erased...
let .fvar fvarId normFVar fvarId | unreachable!
let y normArg y
return .oset fvarId i y
| .uset fvarId i y _ =>
let .fvar fvarId normFVar fvarId | unreachable!
let .fvar y normFVar y | unreachable!
return .uset fvarId i y
@@ -190,12 +208,18 @@ partial def internalizeCodeDecl (decl : CodeDecl pu) : InternalizeM pu (CodeDecl
let .fvar y normFVar y | unreachable!
let ty normExpr ty
return .sset fvarId i offset y ty
| .setTag fvarId cidx _ =>
let .fvar fvarId normFVar fvarId | unreachable!
return .setTag fvarId cidx
| .inc fvarId n check offset _ =>
let .fvar fvarId normFVar fvarId | unreachable!
return .inc fvarId n check offset
| .dec fvarId n check offset _ =>
let .fvar fvarId normFVar fvarId | unreachable!
return .dec fvarId n check offset
| .del fvarId _ =>
let .fvar fvarId normFVar fvarId | unreachable!
return .del fvarId
end Internalize

View File

@@ -77,7 +77,8 @@ mutual
| .let decl k => eraseCode k <| lctx.eraseLetDecl decl
| .jp decl k | .fun decl k _ => eraseCode k <| eraseFunDecl lctx decl
| .cases c => eraseAlts c.alts lctx
| .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. =>
| .oset (k := k) .. | .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) ..
| .dec (k := k) .. | .del (k := k) .. | .setTag (k := k) .. =>
eraseCode k lctx
| .return .. | .jmp .. | .unreach .. => lctx
end

View File

@@ -65,6 +65,8 @@ where
| .jp decl k => go decl.value <||> (do markJpVisited decl.fvarId; go k)
| .uset fvarId _ y k _ | .sset fvarId _ _ y _ k _ =>
visitVar fvarId <||> visitVar y <||> go k
| .oset fvarId _ y k _ =>
visitVar fvarId <||> pure (y.dependsOn ( read).targetSet) <||> go k
| .cases c => visitVar c.discr <||> c.alts.anyM (go ·.getCode)
| .jmp fvarId args =>
(pure <| args.any (·.dependsOn ( read).targetSet)) <||> do
@@ -76,7 +78,8 @@ where
go decl.value
| .return var => visitVar var
| .unreach .. => return false
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. =>
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) ..
| .setTag (fvarId := fvarId) (k := k) .. | .del (fvarId := fvarId) (k := k) =>
visitVar fvarId <||> go k
@[inline]

View File

@@ -27,6 +27,8 @@ public import Lean.Compiler.LCNF.InferBorrow
public import Lean.Compiler.LCNF.ExplicitBoxing
public import Lean.Compiler.LCNF.ExplicitRC
public import Lean.Compiler.LCNF.Toposort
public import Lean.Compiler.LCNF.ExpandResetReuse
public import Lean.Compiler.LCNF.SimpleGroundExpr
public section
@@ -154,6 +156,9 @@ def builtinPassManager : PassManager := {
inferBorrow,
explicitBoxing,
explicitRc,
expandResetReuse,
pushProj (occurrence := 1),
detectSimpleGround,
inferVisibility (phase := .impure),
saveImpure, -- End of impure phase
toposortPass,

View File

@@ -83,10 +83,10 @@ def ppLetValue (e : LetValue pu) : M Format := do
| .proj _ i fvarId _ => return f!"{← ppFVar fvarId} # {i}"
| .fvar fvarId args => return f!"{← ppFVar fvarId}{← ppArgs args}"
| .const declName us args _ => return f!"{← ppExpr (.const declName us)}{← ppArgs args}"
| .ctor i args _ => return f!"{i} {← ppArgs args}"
| .ctor i args _ => return f!"{i}{← ppArgs args}"
| .fap declName args _ => return f!"{declName}{← ppArgs args}"
| .pap declName args _ => return f!"pap {declName}{← ppArgs args}"
| .oproj i fvarId _ => return f!"proj[{i}] {← ppFVar fvarId}"
| .oproj i fvarId _ => return f!"oproj[{i}] {← ppFVar fvarId}"
| .uproj i fvarId _ => return f!"uproj[{i}] {← ppFVar fvarId}"
| .sproj i offset fvarId _ => return f!"sproj[{i}, {offset}] {← ppFVar fvarId}"
| .reset n fvarId _ => return f!"reset[{n}] {← ppFVar fvarId}"
@@ -94,6 +94,7 @@ def ppLetValue (e : LetValue pu) : M Format := do
return f!"reuse" ++ (if updateHeader then f!"!" else f!"") ++ f!" {← ppFVar fvarId} in {info}{← ppArgs args}"
| .box _ fvarId _ => return f!"box {← ppFVar fvarId}"
| .unbox fvarId _ => return f!"unbox {← ppFVar fvarId}"
| .isShared fvarId _ => return f!"isShared {← ppFVar fvarId}"
def ppParam (param : Param pu) : M Format := do
let borrow := if param.borrow then "@&" else ""
@@ -144,11 +145,15 @@ mutual
return ""
| .sset fvarId i offset y ty k _ =>
if pp.letVarTypes.get ( getOptions) then
return f!"sset {← ppFVar fvarId} [{i}, {offset}] : {← ppExpr ty} := {← ppFVar y} " ++ ";" ++ .line ++ ( ppCode k)
return f!"sset {← ppFVar fvarId}[{i}, {offset}] : {← ppExpr ty} := {← ppFVar y};" ++ .line ++ ( ppCode k)
else
return f!"sset {← ppFVar fvarId} [{i}, {offset}] := {← ppFVar y} " ++ ";" ++ .line ++ ( ppCode k)
return f!"sset {← ppFVar fvarId}[{i}, {offset}] := {← ppFVar y};" ++ .line ++ ( ppCode k)
| .uset fvarId i y k _ =>
return f!"uset {← ppFVar fvarId} [{i}] := {← ppFVar y} " ++ ";" ++ .line ++ ( ppCode k)
return f!"uset {← ppFVar fvarId}[{i}] := {← ppFVar y};" ++ .line ++ ( ppCode k)
| .oset fvarId i y k _ =>
return f!"oset {← ppFVar fvarId} [{i}] := {← ppArg y};" ++ .line ++ ( ppCode k)
| .setTag fvarId cidx k _ =>
return f!"setTag {← ppFVar fvarId} := {cidx};" ++ .line ++ ( ppCode k)
| .inc fvarId n _ _ k _ =>
if n != 1 then
return f!"inc[{n}] {← ppFVar fvarId};" ++ .line ++ ( ppCode k)
@@ -159,6 +164,8 @@ mutual
return f!"dec[{n}] {← ppFVar fvarId};" ++ .line ++ ( ppCode k)
else
return f!"dec {← ppFVar fvarId};" ++ .line ++ ( ppCode k)
| .del fvarId k _ =>
return f!"del {← ppFVar fvarId};" ++ .line ++ ( ppCode k)
partial def ppDeclValue (b : DeclValue pu) : M Format := do

View File

@@ -58,7 +58,8 @@ where
go k
| .cases cs => cs.alts.forM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return ()
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) ..
| .setTag (k := k) .. | .oset (k := k) .. => go k
start (decls : Array (Decl pu)) : StateRefT (Array (LetValue pu)) CompilerM Unit :=
decls.forM (·.value.forCodeM go)
@@ -73,7 +74,8 @@ where
| .jp decl k => modify (·.push decl); go decl.value; go k
| .cases cs => cs.alts.forM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return ()
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) ..
| .setTag (k := k) .. | .oset (k := k) .. => go k
start (decls : Array (Decl pu)) : StateRefT (Array (FunDecl pu)) CompilerM Unit :=
decls.forM (·.value.forCodeM go)
@@ -86,7 +88,8 @@ where
| .fun decl k _ | .jp decl k => go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) ..
| .setTag (k := k) .. | .oset (k := k) .. => go k
partial def filterByFun (pu : Purity) (f : FunDecl pu CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
filter (·.value.isCodeAndM go)
@@ -96,7 +99,8 @@ where
| .fun decl k _ => do if ( f decl) then return true else go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) ..
| .setTag (k := k) .. | .oset (k := k) .. => go k
partial def filterByJp (pu : Purity) (f : FunDecl pu CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
filter (·.value.isCodeAndM go)
@@ -107,7 +111,8 @@ where
| .jp decl k => do if ( f decl) then return true else go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) ..
| .setTag (k := k) .. | .oset (k := k) .. => go k
partial def filterByFunDecl (pu : Purity) (f : FunDecl pu CompilerM Bool) :
Probe (Decl pu) (Decl pu):=
@@ -118,7 +123,8 @@ where
| .fun decl k _ | .jp decl k => do if ( f decl) then return true else go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) ..
| .setTag (k := k) .. | .oset (k := k) .. => go k
partial def filterByCases (pu : Purity) (f : Cases pu CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
filter (·.value.isCodeAndM go)
@@ -128,7 +134,8 @@ where
| .fun decl k _ | .jp decl k => go decl.value <||> go k
| .cases cs => do if ( f cs) then return true else cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) ..
| .setTag (k := k) .. | .oset (k := k) .. => go k
partial def filterByJmp (pu : Purity) (f : FVarId Array (Arg pu) CompilerM Bool) :
Probe (Decl pu) (Decl pu) :=
@@ -140,7 +147,8 @@ where
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp fn var => f fn var
| .return .. | .unreach .. => return false
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) ..
| .setTag (k := k) .. | .oset (k := k) .. => go k
partial def filterByReturn (pu : Purity) (f : FVarId CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
filter (·.value.isCodeAndM go)
@@ -151,7 +159,8 @@ where
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .unreach .. => return false
| .return var => f var
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) ..
| .setTag (k := k) .. | .oset (k := k) .. => go k
partial def filterByUnreach (pu : Purity) (f : Expr CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
filter (·.value.isCodeAndM go)
@@ -162,7 +171,8 @@ where
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. => return false
| .unreach typ => f typ
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) ..
| .setTag (k := k) .. | .oset (k := k) .. => go k
@[inline]
def declNames (pu : Purity) : Probe (Decl pu) Name :=

View File

@@ -133,6 +133,8 @@ where
| .jp decl k =>
let decl decl.updateValue ( decl.value.pushProj)
go k (decls.push (.jp decl))
| .oset fvarId i y k _ =>
go k (decls.push (.oset fvarId i y))
| .uset fvarId i y k _ =>
go k (decls.push (.uset fvarId i y))
| .sset fvarId i offset y ty k _ =>
@@ -141,6 +143,10 @@ where
go k (decls.push (.inc fvarId n check persistent))
| .dec fvarId n check persistent k _ =>
go k (decls.push (.dec fvarId n check persistent))
| .del fvarId k _ =>
go k (decls.push (.del fvarId))
| .setTag fvarId cidx k _ =>
go k (decls.push (.setTag fvarId cidx))
| .cases c => c.pushProjs decls
| .jmp .. | .return .. | .unreach .. =>
return attachCodeDecls decls c

View File

@@ -53,7 +53,8 @@ partial def Code.applyRenaming (code : Code pu) (r : Renaming) : CompilerM (Code
| .ctorAlt _ k _ => return alt.updateCode ( k.applyRenaming r)
return code.updateAlts! alts
| .jmp .. | .unreach .. | .return .. => return code
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. =>
| .oset (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) ..
| .del (k := k) .. | .setTag (k := k) .. =>
return code.updateCont! ( k.applyRenaming r)
end

View File

@@ -120,7 +120,7 @@ where
| .return .. | .jmp .. | .unreach .. => return (c, false)
| .sset _ _ _ _ _ k _ | .uset _ _ _ k _ | .let _ k =>
goK k
| .inc .. | .dec .. => unreachable!
| .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable!
def isCtorUsing (instr : CodeDecl .impure) (x : FVarId) : Bool :=
match instr with
@@ -242,7 +242,7 @@ where
return (c.updateCont! k, false)
| .return .. | .jmp .. | .unreach .. =>
return (c, c.isFVarLiveIn x)
| .inc .. | .dec .. => unreachable!
| .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable!
end
@@ -275,7 +275,7 @@ partial def Code.insertResetReuse (c : Code .impure) : ReuseM (Code .impure) :=
| .let _ k | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ =>
return c.updateCont! ( k.insertResetReuse)
| .return .. | .jmp .. | .unreach .. => return c
| .inc .. | .dec .. => unreachable!
| .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable!
partial def Decl.insertResetReuseCore (decl : Decl .impure) : ReuseM (Decl .impure) := do
let value decl.value.mapCodeM fun code => do
@@ -298,7 +298,7 @@ where
| .jp decl k => collectResets decl.value; collectResets k
| .cases c => c.alts.forM (collectResets ·.getCode)
| .unreach .. | .return .. | .jmp .. => return ()
| .inc .. | .dec .. => unreachable!
| .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable!
def Decl.insertResetReuse (decl : Decl .impure) : CompilerM (Decl .impure) := do

View File

@@ -107,7 +107,8 @@ partial def Code.simpCase (code : Code .impure) : CompilerM (Code .impure) := do
let decl decl.updateValue ( decl.value.simpCase)
return code.updateFun! decl ( k.simpCase)
| .return .. | .jmp .. | .unreach .. => return code
| .let _ k | .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. =>
| .let _ k | .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .dec (k := k) ..
| .setTag (k := k) .. | .del (k := k) .. | .oset (k := k) .. =>
return code.updateCont! ( k.simpCase)
def Decl.simpCase (decl : Decl .impure) : CompilerM (Decl .impure) := do

View File

@@ -6,7 +6,8 @@ Authors: Henrik Böving
module
prelude
public import Lean.Compiler.IR.CompilerM
public import Lean.Compiler.LCNF.CompilerM
public import Lean.Compiler.LCNF.PassManager
import Init.While
/-!
@@ -18,9 +19,9 @@ step can reference this environment extension to generate static initializers fo
declaration.
-/
namespace Lean
namespace Lean.Compiler.LCNF
namespace IR
open ImpureType
/--
An argument to a `SimpleGroundExpr`. They get compiled to `lean_object*` in various ways.
@@ -49,7 +50,7 @@ public inductive SimpleGroundExpr where
Represents a `lean_ctor_object`. Crucially the `scalarArgs` array must have a size that is a
multiple of 8.
-/
| ctor (cidx : Nat) (objArgs : Array SimpleGroundArg) (usizeArgs : Array USize) (scalarArgs : Array UInt8)
| ctor (cidx : Nat) (objArgs : Array SimpleGroundArg) (usizeArgs : Array UInt64) (scalarArgs : Array UInt8)
/--
A string literal, represented by a `lean_string_object`.
-/
@@ -57,7 +58,7 @@ public inductive SimpleGroundExpr where
/--
A partial application, represented by a `lean_closure_object`.
-/
| pap (func : FunId) (args : Array SimpleGroundArg)
| pap (func : Name) (args : Array SimpleGroundArg)
/--
An application of `Lean.Name.mkStrX`. This expression is represented separately to ensure that
long name literals get extracted into statically initializable constants. The arguments contain
@@ -70,6 +71,14 @@ public inductive SimpleGroundExpr where
compiled to a reference to the mangled version of the name.
-/
| reference (n : Name)
/--
An array of `lean_object*` elements, represented by a `lean_array_object`.
-/
| array (elems : Array SimpleGroundArg)
/--
A byte array (scalar array with elem_size=1), represented by a `lean_sarray_object`.
-/
| byteArray (data : Array UInt8)
deriving Inhabited
public structure SimpleGroundExtState where
@@ -139,16 +148,21 @@ inductive SimpleGroundValue where
| uint16 (val : UInt16)
| uint32 (val : UInt32)
| uint64 (val : UInt64)
| usize (val : USize)
| usize (val : UInt64)
/--
Contains the elements of the array in a list in reverse order to enable sharing as we traverse the
expression instead.
-/
| arrayBuilder (elems : List SimpleGroundArg) (remainingCapacity : Nat)
deriving Inhabited
structure State where
groundMap : Std.HashMap VarId SimpleGroundValue := {}
structure DetectState where
groundMap : Std.HashMap FVarId SimpleGroundValue := {}
abbrev M := StateRefT State $ OptionT CompilerM
abbrev DetectM := StateRefT DetectState $ OptionT CompilerM
/--
Attempt to compile `b` into a `SimpleGroundExpr`. If `b` is not compileable return `none`.
Attempt to compile `code` into a `SimpleGroundExpr`. If `code` is not compileable return `none`.
The compiler currently supports the following patterns:
- String literals
@@ -156,47 +170,49 @@ The compiler currently supports the following patterns:
- Constructor calls with other simple expressions
- `Name.mkStrX`, `Name.str._override`, and `Name.num._override`
- references to other declarations marked as simple ground expressions
- Array literals (`Array.mkEmpty` + `Array.push` chains)
- ByteArray literals (`Array.mkEmpty` + `Array.push` chains + `ByteArray.mk`)
-/
partial def compileToSimpleGroundExpr (b : FnBody) : CompilerM (Option SimpleGroundExpr) :=
compileFnBody b |>.run' {} |>.run
partial def compileToSimpleGroundExpr (code : Code .impure) : CompilerM (Option SimpleGroundExpr) :=
go code |>.run' {} |>.run
where
compileFnBody (b : FnBody) : M SimpleGroundExpr := do
match b with
| .vdecl id _ expr (.ret (.var id')) =>
guard <| id == id'
compileFinalExpr expr
| .vdecl id ty expr b => compileNonFinalExpr id ty expr b
go (code : Code .impure) : DetectM SimpleGroundExpr := do
match code with
| .let decl (.return fvarId) =>
guard <| decl.fvarId == fvarId
compileFinalLet decl.value
| .let decl k => compileNonFinalLet decl k
| _ => failure
@[inline]
record (id : VarId) (val : SimpleGroundValue) : M Unit :=
record (id : FVarId) (val : SimpleGroundValue) : DetectM Unit :=
modify fun s => { s with groundMap := s.groundMap.insert id val }
compileNonFinalExpr (id : VarId) (ty : IRType) (expr : Expr) (b : FnBody) : M SimpleGroundExpr := do
match expr with
compileNonFinalLet (decl : LetDecl .impure) (k : Code .impure) : DetectM SimpleGroundExpr := do
match decl.value with
| .fap c #[] =>
guard <| isSimpleGroundDecl ( getEnv) c
record id (.arg (.reference c))
compileFnBody b
record decl.fvarId (.arg (.reference c))
go k
| .lit v =>
match v with
| .num v =>
match ty with
| .tagged =>
| .nat v =>
match decl.type with
| ImpureType.tagged =>
guard <| v < 2^31
record id (.arg (.tagged v))
| .uint8 => record id (.uint8 (.ofNat v))
| .uint16 => record id (.uint16 (.ofNat v))
| .uint32 => record id (.uint32 (.ofNat v))
| .uint64 => record id (.uint64 (.ofNat v))
| .usize => record id (.usize (.ofNat v))
record decl.fvarId (.arg (.tagged v))
| _ => failure
compileFnBody b
| .uint8 v => record decl.fvarId (.uint8 v)
| .uint16 v => record decl.fvarId (.uint16 v)
| .uint32 v => record decl.fvarId (.uint32 v)
| .uint64 v => record decl.fvarId (.uint64 v)
| .usize v => record decl.fvarId (.usize v)
| .str .. => failure
go k
| .ctor i objArgs =>
if i.isScalar then
record id (.arg (.tagged i.cidx))
compileFnBody b
record decl.fvarId (.arg (.tagged i.cidx))
go k
else
let objArgs compileArgs objArgs
let usizeArgs := Array.replicate i.usize 0
@@ -205,16 +221,37 @@ where
(v / a) * a + a * (if v % a != 0 then 1 else 0)
let alignedSsize := align i.ssize 8
let ssizeArgs := Array.replicate alignedSsize 0
compileSetChain id i objArgs usizeArgs ssizeArgs b
compileSetChain decl.fvarId i objArgs usizeArgs ssizeArgs k
| .box _ fvarId =>
match ( get).groundMap[fvarId]! with
| .uint8 v =>
record decl.fvarId (.arg (.tagged v.toNat))
go k
| .uint16 v =>
record decl.fvarId (.arg (.tagged v.toNat))
go k
-- boxed uint32/uint64 get extracted into separate closed terms automatically
| _ => failure
| .fap ``Array.mkEmpty #[.erased, .fvar sizeId]
| .fap ``Array.emptyWithCapacity #[.erased, .fvar sizeId] =>
let .arg (.tagged size) := ( get).groundMap[sizeId]! | failure
record decl.fvarId (.arrayBuilder [] size)
go k
| .fap ``Array.push #[.erased, .fvar arrId, .fvar elemId] =>
let .arrayBuilder elems remainingCapacity := ( get).groundMap[arrId]! | failure
let .arg elemArg := ( get).groundMap[elemId]! | failure
record decl.fvarId (.arrayBuilder (elemArg :: elems) (remainingCapacity - 1))
go k
| _ => failure
compileSetChain (id : VarId) (info : CtorInfo) (objArgs : Array SimpleGroundArg) (usizeArgs : Array USize)
(scalarArgs : Array UInt8) (b : FnBody) : M SimpleGroundExpr := do
match b with
| .ret (.var id') =>
guard <| id == id'
compileSetChain (id : FVarId) (info : CtorInfo) (objArgs : Array SimpleGroundArg)
(usizeArgs : Array UInt64) (scalarArgs : Array UInt8) (code : Code .impure) :
DetectM SimpleGroundExpr := do
match code with
| .return fvarId =>
guard <| id == fvarId
return .ctor info.cidx objArgs usizeArgs scalarArgs
| .sset id' i offset y _ b =>
| .sset id' i offset y _ k =>
guard <| id == id'
let i := i - objArgs.size - usizeArgs.size
let offset := i * 8 + offset
@@ -244,21 +281,21 @@ where
let scalarArgs := scalarArgs.set! (offset + 7) (v >>> 0x38).toUInt8
pure scalarArgs
| _ => failure
compileSetChain id info objArgs usizeArgs scalarArgs b
| .uset id' i y b =>
compileSetChain id info objArgs usizeArgs scalarArgs k
| .uset id' i y k =>
guard <| id == id'
let i := i - objArgs.size
let .usize v := ( get).groundMap[y]! | failure
let usizeArgs := usizeArgs.set! i v
compileSetChain id info objArgs usizeArgs scalarArgs b
compileSetChain id info objArgs usizeArgs scalarArgs k
| _ => failure
compileFinalExpr (e : Expr) : M SimpleGroundExpr := do
compileFinalLet (e : LetValue .impure) : DetectM SimpleGroundExpr := do
match e with
| .lit v =>
match v with
| .str v => return .string v
| .num .. => failure
| _ => failure
| .ctor i args =>
guard <| i.usize == 0 && i.ssize == 0 && !args.isEmpty
return .ctor i.cidx ( compileArgs args) #[] #[]
@@ -289,34 +326,64 @@ where
nameAcc := .str nameAcc str
processedArgs := processedArgs.push (ref, nameAcc.hash)
return .nameMkStr processedArgs
| .pap c ys => return .pap c ( compileArgs ys)
| .fap ``Array.mkEmpty #[.erased, .fvar sizeId]
| .fap ``Array.emptyWithCapacity #[.erased, .fvar sizeId] =>
let .arg (.tagged 0) := ( get).groundMap[sizeId]! | failure
return .array #[]
| .fap ``ByteArray.mk #[.fvar argId] =>
match ( get).groundMap[argId]! with
| .arrayBuilder elems 0 =>
let bytes elems.mapM fun elem => do
let .tagged v := elem | failure
return v.toUInt8
return .byteArray bytes.toArray.reverse
| .arg (.reference ref) =>
let some (.array elems) := getSimpleGroundExprWithResolvedRefs ( getEnv) ref | failure
let bytes elems.mapM fun elem => do
let .tagged v := elem | failure
return v.toUInt8
return .byteArray bytes
| _ => failure
| .fap ``Array.push #[.erased, .fvar arrId, .fvar elemId] =>
let .arrayBuilder elems remainingCapacity := ( get).groundMap[arrId]! | failure
if remainingCapacity > 1 then failure
let .arg elemArg := ( get).groundMap[elemId]! | failure
return .array (elemArg :: elems).toArray.reverse
| .fap c #[] =>
guard <| isSimpleGroundDecl ( getEnv) c
return .reference c
| .box _ fvarId =>
match ( get).groundMap[fvarId]! with
| .uint32 _ => failure -- TODO: figure out how to do this properly with 32/64bit restrictions
| .uint64 v => return .ctor 0 #[] #[] (uint64ToByteArrayLE v)
| .usize v => return .ctor 0 #[] #[v] #[]
| .uint8 _ | .uint16 _ -- boxed uint8/uint16 should never be final expressions
| _ => failure
| .pap c ys => return .pap c ( compileArgs ys)
| _ => failure
compileArg (arg : Arg) : M SimpleGroundArg := do
compileArg (arg : Arg .impure) : DetectM SimpleGroundArg := do
match arg with
| .var var =>
let .arg arg := ( get).groundMap[var]! | failure
| .fvar fvarId =>
let .arg arg := ( get).groundMap[fvarId]! | failure
return arg
| .erased => return .tagged 0
compileArgs (args : Array Arg) : M (Array SimpleGroundArg) := do
compileArgs (args : Array (Arg .impure)) : DetectM (Array SimpleGroundArg) := do
args.mapM compileArg
compileStrArg (arg : Arg) : M (Name × String) := do
let .var var := arg | failure
let (.arg (.reference ref)) := ( get).groundMap[var]! | failure
compileStrArg (arg : Arg .impure) : DetectM (Name × String) := do
let .fvar fvarId := arg | failure
let (.arg (.reference ref)) := ( get).groundMap[fvarId]! | failure
let some (.string val) := getSimpleGroundExprWithResolvedRefs ( getEnv) ref | failure
return (ref, val)
interpStringLiteral (arg : SimpleGroundArg) : M String := do
interpStringLiteral (arg : SimpleGroundArg) : DetectM String := do
let .reference ref := arg | failure
let some (.string val) := getSimpleGroundExprWithResolvedRefs ( getEnv) ref | failure
return val
interpNameLiteral (arg : SimpleGroundArg) : M Name := do
interpNameLiteral (arg : SimpleGroundArg) : DetectM Name := do
match arg with
| .tagged 0 => return .anonymous
| .reference ref =>
@@ -340,15 +407,20 @@ where
Detect whether `d` can be compiled to a `SimpleGroundExpr`. If it can record the associated
`SimpleGroundExpr` into the environment for later processing by code emission.
-/
public def Decl.detectSimpleGround (d : Decl) : CompilerM Unit := do
let .fdecl (body := body) (xs := params) (type := type) .. := d | return ()
if type.isPossibleRef && params.isEmpty then
if let some groundExpr compileToSimpleGroundExpr body then
trace[compiler.ir.simple_ground] m!"Marked {d.name} as simple ground expr"
def Decl.detectSimpleGround (d : Decl .impure) : CompilerM Unit := do
let .code code := d.value | return ()
if d.type.isPossibleRef && d.params.isEmpty then
if let some groundExpr compileToSimpleGroundExpr code then
trace[Compiler.simpleGround] m!"Marked {d.name} as simple ground expr"
modifyEnv fun env => addSimpleGroundDecl env d.name groundExpr
builtin_initialize registerTraceClass `compiler.ir.simple_ground (inherited := true)
public def detectSimpleGround : Pass where
phase := .impure
name := `detectSimpleGround
run := fun decls => do
decls.forM Decl.detectSimpleGround
return decls
end IR
builtin_initialize registerTraceClass `Compiler.simpleGround (inherited := true)
end Lean
end Lean.Compiler.LCNF

View File

@@ -116,10 +116,18 @@ partial def Code.toExprM (code : Code pu) : ToExprM Expr := do
let value := mkApp5 (mkConst `sset) (.fvar fvarId) (toExpr i) (toExpr offset) (.fvar y) ty
let body withFVar fvarId k.toExprM
return .letE `dummy (mkConst ``Unit) value body true
| .oset fvarId offset y k _ =>
let value := mkApp3 (mkConst `oset) (.fvar fvarId) (toExpr offset) ( y.toExprM)
let body withFVar fvarId k.toExprM
return .letE `dummy (mkConst ``Unit) value body true
| .uset fvarId offset y k _ =>
let value := mkApp3 (mkConst `uset) (.fvar fvarId) (toExpr offset) (.fvar y)
let body withFVar fvarId k.toExprM
return .letE `dummy (mkConst ``Unit) value body true
| .setTag fvarId cidx k _ =>
let body withFVar fvarId k.toExprM
let value := mkApp2 (mkConst `setTag) (.fvar fvarId) (toExpr cidx)
return .letE `dummy (mkConst ``Unit) value body true
| .inc fvarId n check persistent k _ =>
let value := mkApp4 (mkConst `inc) (.fvar fvarId) (toExpr n) (toExpr check) (toExpr persistent)
let body withFVar fvarId k.toExprM
@@ -128,6 +136,10 @@ partial def Code.toExprM (code : Code pu) : ToExprM Expr := do
let body withFVar fvarId k.toExprM
let value := mkApp4 (mkConst `dec) (.fvar fvarId) (toExpr n) (toExpr check) (toExpr persistent)
return .letE `dummy (mkConst ``Unit) value body true
| .del fvarId k _ =>
let body withFVar fvarId k.toExprM
let value := mkApp (mkConst `del) (.fvar fvarId)
return .letE `dummy (mkConst ``Unit) value body true
end
public def Code.toExpr (code : Code pu) (xs : Array FVarId := #[]) : Expr :=

View File

@@ -348,15 +348,15 @@ def mkParam (binderName : Name) (type : Expr) : M (Param .pure) := do
modify fun s => { s with lctx := s.lctx.mkLocalDecl param.fvarId binderName type .default }
return param
def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (type' : Expr) (arg : Arg .pure)
(nondep : Bool) : M (LetDecl .pure) := do
def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (type' : Expr) (arg : Arg .pure) :
M (LetDecl .pure) := do
let binderName cleanupBinderName binderName
let value' match arg with
| .fvar fvarId => pure <| .fvar fvarId #[]
| .erased | .type .. => pure .erased
let letDecl LCNF.mkLetDecl binderName type' value'
modify fun s => { s with
lctx := s.lctx.mkLetDecl letDecl.fvarId binderName type value nondep
lctx := s.lctx.mkLetDecl letDecl.fvarId binderName type value false
seq := s.seq.push <| .let letDecl
}
return letDecl
@@ -385,38 +385,6 @@ where
else
return (ps, e.instantiateRev xs)
/--
Given `e` and `args` where `mkAppN e (args.map (·.toExpr))` is not necessarily well-typed
(because of dependent typing), returns `e.beta args'` where `args'` are new local declarations each
assigned to a value in `args` with adjusted type (such that the resulting expression is well-typed).
-/
def mkTypeCorrectApp (e : Expr) (args : Array (Arg .pure)) : M Expr := do
if args.isEmpty then
return e
let type liftMetaM <| do
let type Meta.inferType e
if type.getNumHeadForalls < args.size then
-- expose foralls
Meta.forallBoundedTelescope type args.size Meta.mkForallFVars
else
return type
go type 0 #[]
where
go (type : Expr) (i : Nat) (xs : Array Expr) : M Expr := do
if h : i < args.size then
match type with
| .forallE nm t b bi =>
let t := t.instantiateRev xs
let arg := args[i]
if liftMetaM <| Meta.isProp t then
go b (i + 1) (xs.push (mkLcProof t))
else
let decl mkLetDecl nm t arg.toExpr ( arg.inferType) arg (nondep := true)
go b (i + 1) (xs.push (.fvar decl.fvarId))
| _ => liftMetaM <| Meta.throwFunctionExpected (mkAppN e xs)
else
return e.beta xs
def mustEtaExpand (env : Environment) (e : Expr) : Bool :=
if let .const declName _ := e.getAppFn then
match env.find? declName with
@@ -558,7 +526,7 @@ where
k args[arity...*]
```
-/
mkOverApplication (app : Arg .pure) (args : Array Expr) (arity : Nat) : M (Arg .pure) := do
mkOverApplication (app : (Arg .pure)) (args : Array Expr) (arity : Nat) : M (Arg .pure) := do
if args.size == arity then
return app
else
@@ -573,14 +541,11 @@ where
/--
Visit a `matcher`/`casesOn` alternative.
-/
visitAlt (casesAltInfo : CasesAltInfo) (e : Expr) (overArgs : Array (Arg .pure)) :
M (Expr × (Alt .pure)) := do
visitAlt (casesAltInfo : CasesAltInfo) (e : Expr) : M (Expr × (Alt .pure)) := do
withNewScope do
match casesAltInfo with
| .default numHyps =>
let e := mkAppN e (Array.replicate numHyps erasedExpr)
let e mkTypeCorrectApp e overArgs
let c toCode ( visit e)
let c toCode ( visit (mkAppN e (Array.replicate numHyps erasedExpr)))
let altType c.inferType
return (altType, .default c)
| .ctor ctorName numParams =>
@@ -590,7 +555,6 @@ where
let (ps', e') ToLCNF.visitLambda e
ps := ps ++ ps'
e := e'
e mkTypeCorrectApp e overArgs
/-
Insert the free variable ids of fields that are type formers into `toAny`.
Recall that we do not want to have "data" occurring in types.
@@ -615,8 +579,7 @@ where
visitCases (casesInfo : CasesInfo) (e : Expr) : M (Arg .pure) :=
etaIfUnderApplied e casesInfo.arity do
let args := e.getAppArgs
let overArgs (args.drop casesInfo.arity).mapM visitAppArg
let mut resultType toLCNFType ( liftMetaM do Meta.inferType (mkAppN e.getAppFn args))
let mut resultType toLCNFType ( liftMetaM do Meta.inferType (mkAppN e.getAppFn args[*...casesInfo.arity]))
let typeName := casesInfo.indName
let .inductInfo indVal getConstInfo typeName | unreachable!
if casesInfo.numAlts == 0 then
@@ -646,7 +609,8 @@ where
fieldArgs := fieldArgs.push fieldArg
return fieldArgs
let f := args[casesInfo.altsRange.lower]!
visit (mkAppN (mkAppN f fieldArgs) (overArgs.map (·.toExpr)))
let result visit (mkAppN f fieldArgs)
mkOverApplication result args casesInfo.arity
else
let mut alts := #[]
let discr visitAppArg args[casesInfo.discrPos]!
@@ -654,13 +618,14 @@ where
| .fvar discrFVarId => pure discrFVarId
| .erased | .type .. => mkAuxLetDecl .erased
for i in casesInfo.altsRange, numParams in casesInfo.altNumParams do
let (altType, alt) visitAlt numParams args[i]! overArgs
let (altType, alt) visitAlt numParams args[i]!
resultType := joinTypes altType resultType
alts := alts.push alt
let cases := typeName, resultType, discrFVarId, alts
let auxDecl mkAuxParam resultType
pushElement (.cases auxDecl cases)
return .fvar auxDecl.fvarId
let result := .fvar auxDecl.fvarId
mkOverApplication result args casesInfo.arity
visitCtor (arity : Nat) (e : Expr) : M (Arg .pure) :=
etaIfUnderApplied e arity do
@@ -878,14 +843,14 @@ where
visitLet (e : Expr) (xs : Array Expr) : M (Arg .pure) := do
match e with
| .letE binderName type value body nondep =>
| .letE binderName type value body _ =>
let type := type.instantiateRev xs
let value := value.instantiateRev xs
if ( (liftMetaM <| Meta.isProp type) <||> isTypeFormerType type) then
visitLet body (xs.push value)
else
let type' toLCNFType type
let letDecl mkLetDecl binderName type value type' ( visit value) nondep
let letDecl mkLetDecl binderName type value type' ( visit value)
visitLet body (xs.push (.fvar letDecl.fvarId))
| _ =>
let e := e.instantiateRev xs

Some files were not shown because too many files have changed in this diff Show More