mirror of
https://github.com/leanprover/lean4.git
synced 2026-03-19 19:34:13 +00:00
Compare commits
49 Commits
no_dsimp_i
...
indexmap_u
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7a2e87fabe | ||
|
|
ec60620534 | ||
|
|
4606c35c40 | ||
|
|
9f4c81342e | ||
|
|
89c01c9e7e | ||
|
|
ce980895b2 | ||
|
|
6c5de545f9 | ||
|
|
21a281b496 | ||
|
|
7cd6b78a9c | ||
|
|
b64e5dec1e | ||
|
|
d1514f3cec | ||
|
|
a972c4f50d | ||
|
|
7416309805 | ||
|
|
bb68f31527 | ||
|
|
85341d02ac | ||
|
|
59f3abd0bd | ||
|
|
2f3912df74 | ||
|
|
5ce756f350 | ||
|
|
6d370ec3c2 | ||
|
|
332c1ec46a | ||
|
|
4c5e3d73af | ||
|
|
2b2b72d113 | ||
|
|
5b0b365406 | ||
|
|
892cbe22f8 | ||
|
|
3883f0f669 | ||
|
|
30c8b39b23 | ||
|
|
e7b6bd6734 | ||
|
|
16919852d9 | ||
|
|
29545dcf10 | ||
|
|
b772852522 | ||
|
|
ebec1b3a16 | ||
|
|
00c8431cf8 | ||
|
|
b919cfff30 | ||
|
|
e441ed8e46 | ||
|
|
119533d602 | ||
|
|
9b9ce0c2ac | ||
|
|
3f0acbbb48 | ||
|
|
6dcd6c8f08 | ||
|
|
71be4901c3 | ||
|
|
5e13e71a84 | ||
|
|
08ee91a433 | ||
|
|
f790ff1961 | ||
|
|
c4aac5d7c5 | ||
|
|
316761c202 | ||
|
|
b248b13ac2 | ||
|
|
08f43acefb | ||
|
|
9a37dba765 | ||
|
|
a47eb31076 | ||
|
|
819fb6a6a8 |
@@ -46,6 +46,21 @@ This PR adds a `num?` parameter to `mkPatternFromTheorem` to control how many
|
||||
leading quantifiers are stripped when creating a pattern.
|
||||
```
|
||||
|
||||
**Changelog labels:** Add one `changelog-*` label to categorize the PR for release notes:
|
||||
- `changelog-language` - Language features and metaprograms
|
||||
- `changelog-tactics` - User facing tactics
|
||||
- `changelog-server` - Language server, widgets, and IDE extensions
|
||||
- `changelog-pp` - Pretty printing
|
||||
- `changelog-library` - Library
|
||||
- `changelog-compiler` - Compiler, runtime, and FFI
|
||||
- `changelog-lake` - Lake
|
||||
- `changelog-doc` - Documentation
|
||||
- `changelog-ffi` - FFI changes
|
||||
- `changelog-other` - Other changes
|
||||
- `changelog-no` - Do not include this PR in the release changelog
|
||||
|
||||
If you're unsure which label applies, it's fine to omit the label and let reviewers add it.
|
||||
|
||||
## CI Log Retrieval
|
||||
|
||||
When CI jobs fail, investigate immediately - don't wait for other jobs to complete. Individual job logs are often available even while other jobs are still running. Try `gh run view <run-id> --log` or `gh run view <run-id> --log-failed`, or use `gh run view <run-id> --job=<job-id>` to target the specific failed job. Sleeping is fine when asked to monitor CI and no failures exist yet, but once any job fails, investigate that failure immediately.
|
||||
|
||||
25
.github/workflows/pr-release.yml
vendored
25
.github/workflows/pr-release.yml
vendored
@@ -43,6 +43,19 @@ jobs:
|
||||
name: build-.*
|
||||
name_is_regexp: true
|
||||
|
||||
# Verify artifacts were downloaded before any side effects (tag creation, release deletion).
|
||||
- name: Verify release artifacts exist
|
||||
if: ${{ steps.workflow-info.outputs.pullRequestNumber != '' }}
|
||||
run: |
|
||||
shopt -s nullglob
|
||||
files=(artifacts/*/*)
|
||||
if [ ${#files[@]} -eq 0 ]; then
|
||||
echo "::error::No artifacts found matching artifacts/*/*"
|
||||
exit 1
|
||||
fi
|
||||
echo "Found ${#files[@]} artifacts to upload:"
|
||||
printf '%s\n' "${files[@]}"
|
||||
|
||||
- name: Push tag
|
||||
if: ${{ steps.workflow-info.outputs.pullRequestNumber != '' }}
|
||||
run: |
|
||||
@@ -74,18 +87,6 @@ jobs:
|
||||
gh release delete --repo ${{ github.repository_owner }}/lean4-pr-releases pr-release-${{ steps.workflow-info.outputs.pullRequestNumber }}-${{ env.SHORT_SHA }} -y || true
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.PR_RELEASES_TOKEN }}
|
||||
# Verify artifacts were downloaded (equivalent to fail_on_unmatched_files in the old action).
|
||||
- name: Verify release artifacts exist
|
||||
if: ${{ steps.workflow-info.outputs.pullRequestNumber != '' }}
|
||||
run: |
|
||||
shopt -s nullglob
|
||||
files=(artifacts/*/*)
|
||||
if [ ${#files[@]} -eq 0 ]; then
|
||||
echo "::error::No artifacts found matching artifacts/*/*"
|
||||
exit 1
|
||||
fi
|
||||
echo "Found ${#files[@]} artifacts to upload:"
|
||||
printf '%s\n' "${files[@]}"
|
||||
# We use `gh release create` instead of `softprops/action-gh-release` because
|
||||
# the latter enumerates all releases to check for existing ones, which fails
|
||||
# when the repository has more than 10000 releases (GitHub API pagination limit).
|
||||
|
||||
104
CMakeLists.txt
104
CMakeLists.txt
@@ -10,22 +10,22 @@ option(USE_MIMALLOC "use mimalloc" ON)
|
||||
get_cmake_property(vars CACHE_VARIABLES)
|
||||
foreach(var ${vars})
|
||||
get_property(currentHelpString CACHE "${var}" PROPERTY HELPSTRING)
|
||||
if("${var}" MATCHES "STAGE0_(.*)")
|
||||
if(var MATCHES "STAGE0_(.*)")
|
||||
list(APPEND STAGE0_ARGS "-D${CMAKE_MATCH_1}=${${var}}")
|
||||
elseif("${var}" MATCHES "STAGE1_(.*)")
|
||||
elseif(var MATCHES "STAGE1_(.*)")
|
||||
list(APPEND STAGE1_ARGS "-D${CMAKE_MATCH_1}=${${var}}")
|
||||
elseif("${currentHelpString}" MATCHES "No help, variable specified on the command line." OR "${currentHelpString}" STREQUAL "")
|
||||
elseif(currentHelpString MATCHES "No help, variable specified on the command line." OR currentHelpString STREQUAL "")
|
||||
list(APPEND CL_ARGS "-D${var}=${${var}}")
|
||||
if("${var}" MATCHES "USE_GMP|CHECK_OLEAN_VERSION|LEAN_VERSION_.*|LEAN_SPECIAL_VERSION_DESC")
|
||||
if(var MATCHES "USE_GMP|CHECK_OLEAN_VERSION|LEAN_VERSION_.*|LEAN_SPECIAL_VERSION_DESC")
|
||||
# must forward options that generate incompatible .olean format
|
||||
list(APPEND STAGE0_ARGS "-D${var}=${${var}}")
|
||||
elseif("${var}" MATCHES "LLVM*|PKG_CONFIG|USE_LAKE|USE_MIMALLOC")
|
||||
elseif(var MATCHES "LLVM*|PKG_CONFIG|USE_LAKE|USE_MIMALLOC")
|
||||
list(APPEND STAGE0_ARGS "-D${var}=${${var}}")
|
||||
endif()
|
||||
elseif("${var}" MATCHES "USE_MIMALLOC")
|
||||
elseif(var MATCHES "USE_MIMALLOC")
|
||||
list(APPEND CL_ARGS "-D${var}=${${var}}")
|
||||
list(APPEND STAGE0_ARGS "-D${var}=${${var}}")
|
||||
elseif(("${var}" MATCHES "CMAKE_.*") AND NOT ("${var}" MATCHES "CMAKE_BUILD_TYPE") AND NOT ("${var}" MATCHES "CMAKE_HOME_DIRECTORY"))
|
||||
elseif((var MATCHES "CMAKE_.*") AND NOT (var MATCHES "CMAKE_BUILD_TYPE") AND NOT (var MATCHES "CMAKE_HOME_DIRECTORY"))
|
||||
list(APPEND PLATFORM_ARGS "-D${var}=${${var}}")
|
||||
endif()
|
||||
endforeach()
|
||||
@@ -34,15 +34,15 @@ include(ExternalProject)
|
||||
project(LEAN CXX C)
|
||||
|
||||
if(NOT (DEFINED STAGE0_CMAKE_EXECUTABLE_SUFFIX))
|
||||
set(STAGE0_CMAKE_EXECUTABLE_SUFFIX "${CMAKE_EXECUTABLE_SUFFIX}")
|
||||
set(STAGE0_CMAKE_EXECUTABLE_SUFFIX "${CMAKE_EXECUTABLE_SUFFIX}")
|
||||
endif()
|
||||
|
||||
# Don't do anything with cadical on wasm
|
||||
if (NOT ${CMAKE_SYSTEM_NAME} MATCHES "Emscripten")
|
||||
if(NOT CMAKE_SYSTEM_NAME MATCHES "Emscripten")
|
||||
find_program(CADICAL cadical)
|
||||
if(NOT CADICAL)
|
||||
set(CADICAL_CXX c++)
|
||||
if (CADICAL_USE_CUSTOM_CXX)
|
||||
if(CADICAL_USE_CUSTOM_CXX)
|
||||
set(CADICAL_CXX ${CMAKE_CXX_COMPILER})
|
||||
# Use same platform flags as for Lean executables, in particular from `prepare-llvm-linux.sh`,
|
||||
# but not Lean-specific `LEAN_EXTRA_CXX_FLAGS` such as fsanitize.
|
||||
@@ -54,42 +54,51 @@ if (NOT ${CMAKE_SYSTEM_NAME} MATCHES "Emscripten")
|
||||
set(CADICAL_CXX "${CCACHE} ${CADICAL_CXX}")
|
||||
endif()
|
||||
# missing stdio locking API on Windows
|
||||
if(${CMAKE_SYSTEM_NAME} MATCHES "Windows")
|
||||
if(CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||
string(APPEND CADICAL_CXXFLAGS " -DNUNLOCKED")
|
||||
endif()
|
||||
string(APPEND CADICAL_CXXFLAGS " -DNCLOSEFROM")
|
||||
ExternalProject_add(cadical
|
||||
ExternalProject_Add(
|
||||
cadical
|
||||
PREFIX cadical
|
||||
GIT_REPOSITORY https://github.com/arminbiere/cadical
|
||||
GIT_TAG rel-2.1.2
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND $(MAKE) -f ${CMAKE_SOURCE_DIR}/src/cadical.mk
|
||||
CMAKE_EXECUTABLE_SUFFIX=${CMAKE_EXECUTABLE_SUFFIX}
|
||||
CXX=${CADICAL_CXX}
|
||||
CXXFLAGS=${CADICAL_CXXFLAGS}
|
||||
LDFLAGS=${CADICAL_LDFLAGS}
|
||||
BUILD_COMMAND
|
||||
$(MAKE) -f ${CMAKE_SOURCE_DIR}/src/cadical.mk CMAKE_EXECUTABLE_SUFFIX=${CMAKE_EXECUTABLE_SUFFIX}
|
||||
CXX=${CADICAL_CXX} CXXFLAGS=${CADICAL_CXXFLAGS} LDFLAGS=${CADICAL_LDFLAGS}
|
||||
BUILD_IN_SOURCE ON
|
||||
INSTALL_COMMAND "")
|
||||
set(CADICAL ${CMAKE_BINARY_DIR}/cadical/cadical${CMAKE_EXECUTABLE_SUFFIX} CACHE FILEPATH "path to cadical binary" FORCE)
|
||||
INSTALL_COMMAND ""
|
||||
)
|
||||
set(
|
||||
CADICAL
|
||||
${CMAKE_BINARY_DIR}/cadical/cadical${CMAKE_EXECUTABLE_SUFFIX}
|
||||
CACHE FILEPATH
|
||||
"path to cadical binary"
|
||||
FORCE
|
||||
)
|
||||
list(APPEND EXTRA_DEPENDS cadical)
|
||||
endif()
|
||||
list(APPEND CL_ARGS -DCADICAL=${CADICAL})
|
||||
endif()
|
||||
|
||||
if (USE_MIMALLOC)
|
||||
ExternalProject_add(mimalloc
|
||||
if(USE_MIMALLOC)
|
||||
ExternalProject_Add(
|
||||
mimalloc
|
||||
PREFIX mimalloc
|
||||
GIT_REPOSITORY https://github.com/microsoft/mimalloc
|
||||
GIT_TAG v2.2.3
|
||||
# just download, we compile it as part of each stage as it is small
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
INSTALL_COMMAND "")
|
||||
INSTALL_COMMAND ""
|
||||
)
|
||||
list(APPEND EXTRA_DEPENDS mimalloc)
|
||||
endif()
|
||||
|
||||
if (NOT STAGE1_PREV_STAGE)
|
||||
ExternalProject_add(stage0
|
||||
if(NOT STAGE1_PREV_STAGE)
|
||||
ExternalProject_Add(
|
||||
stage0
|
||||
SOURCE_DIR "${LEAN_SOURCE_DIR}/stage0"
|
||||
SOURCE_SUBDIR src
|
||||
BINARY_DIR stage0
|
||||
@@ -97,38 +106,49 @@ if (NOT STAGE1_PREV_STAGE)
|
||||
# (however, CI will override this as we need to embed the githash into the stage 1 library built
|
||||
# by stage 0)
|
||||
CMAKE_ARGS -DSTAGE=0 -DUSE_GITHASH=OFF ${PLATFORM_ARGS} ${STAGE0_ARGS}
|
||||
BUILD_ALWAYS ON # cmake doesn't auto-detect changes without a download method
|
||||
INSTALL_COMMAND "" # skip install
|
||||
BUILD_ALWAYS
|
||||
ON # cmake doesn't auto-detect changes without a download method
|
||||
INSTALL_COMMAND
|
||||
"" # skip install
|
||||
DEPENDS ${EXTRA_DEPENDS}
|
||||
)
|
||||
list(APPEND EXTRA_DEPENDS stage0)
|
||||
endif()
|
||||
ExternalProject_add(stage1
|
||||
ExternalProject_Add(
|
||||
stage1
|
||||
SOURCE_DIR "${LEAN_SOURCE_DIR}"
|
||||
SOURCE_SUBDIR src
|
||||
BINARY_DIR stage1
|
||||
CMAKE_ARGS -DSTAGE=1 -DPREV_STAGE=${CMAKE_BINARY_DIR}/stage0 -DPREV_STAGE_CMAKE_EXECUTABLE_SUFFIX=${STAGE0_CMAKE_EXECUTABLE_SUFFIX} ${CL_ARGS} ${STAGE1_ARGS}
|
||||
CMAKE_ARGS
|
||||
-DSTAGE=1 -DPREV_STAGE=${CMAKE_BINARY_DIR}/stage0
|
||||
-DPREV_STAGE_CMAKE_EXECUTABLE_SUFFIX=${STAGE0_CMAKE_EXECUTABLE_SUFFIX} ${CL_ARGS} ${STAGE1_ARGS}
|
||||
BUILD_ALWAYS ON
|
||||
INSTALL_COMMAND ""
|
||||
DEPENDS ${EXTRA_DEPENDS}
|
||||
STEP_TARGETS configure
|
||||
)
|
||||
ExternalProject_add(stage2
|
||||
ExternalProject_Add(
|
||||
stage2
|
||||
SOURCE_DIR "${LEAN_SOURCE_DIR}"
|
||||
SOURCE_SUBDIR src
|
||||
BINARY_DIR stage2
|
||||
CMAKE_ARGS -DSTAGE=2 -DPREV_STAGE=${CMAKE_BINARY_DIR}/stage1 -DPREV_STAGE_CMAKE_EXECUTABLE_SUFFIX=${CMAKE_EXECUTABLE_SUFFIX} ${CL_ARGS}
|
||||
CMAKE_ARGS
|
||||
-DSTAGE=2 -DPREV_STAGE=${CMAKE_BINARY_DIR}/stage1 -DPREV_STAGE_CMAKE_EXECUTABLE_SUFFIX=${CMAKE_EXECUTABLE_SUFFIX}
|
||||
${CL_ARGS}
|
||||
BUILD_ALWAYS ON
|
||||
INSTALL_COMMAND ""
|
||||
DEPENDS stage1
|
||||
EXCLUDE_FROM_ALL ON
|
||||
STEP_TARGETS configure
|
||||
)
|
||||
ExternalProject_add(stage3
|
||||
ExternalProject_Add(
|
||||
stage3
|
||||
SOURCE_DIR "${LEAN_SOURCE_DIR}"
|
||||
SOURCE_SUBDIR src
|
||||
BINARY_DIR stage3
|
||||
CMAKE_ARGS -DSTAGE=3 -DPREV_STAGE=${CMAKE_BINARY_DIR}/stage2 -DPREV_STAGE_CMAKE_EXECUTABLE_SUFFIX=${CMAKE_EXECUTABLE_SUFFIX} ${CL_ARGS}
|
||||
CMAKE_ARGS
|
||||
-DSTAGE=3 -DPREV_STAGE=${CMAKE_BINARY_DIR}/stage2 -DPREV_STAGE_CMAKE_EXECUTABLE_SUFFIX=${CMAKE_EXECUTABLE_SUFFIX}
|
||||
${CL_ARGS}
|
||||
BUILD_ALWAYS ON
|
||||
INSTALL_COMMAND ""
|
||||
DEPENDS stage2
|
||||
@@ -137,24 +157,14 @@ ExternalProject_add(stage3
|
||||
|
||||
# targets forwarded to appropriate stages
|
||||
|
||||
add_custom_target(update-stage0
|
||||
COMMAND $(MAKE) -C stage1 update-stage0
|
||||
DEPENDS stage1)
|
||||
add_custom_target(update-stage0 COMMAND $(MAKE) -C stage1 update-stage0 DEPENDS stage1)
|
||||
|
||||
add_custom_target(update-stage0-commit
|
||||
COMMAND $(MAKE) -C stage1 update-stage0-commit
|
||||
DEPENDS stage1)
|
||||
add_custom_target(update-stage0-commit COMMAND $(MAKE) -C stage1 update-stage0-commit DEPENDS stage1)
|
||||
|
||||
add_custom_target(test
|
||||
COMMAND $(MAKE) -C stage1 test
|
||||
DEPENDS stage1)
|
||||
add_custom_target(test COMMAND $(MAKE) -C stage1 test DEPENDS stage1)
|
||||
|
||||
add_custom_target(clean-stdlib
|
||||
COMMAND $(MAKE) -C stage1 clean-stdlib
|
||||
DEPENDS stage1)
|
||||
add_custom_target(clean-stdlib COMMAND $(MAKE) -C stage1 clean-stdlib DEPENDS stage1)
|
||||
|
||||
install(CODE "execute_process(COMMAND make -C stage1 install)")
|
||||
|
||||
add_custom_target(check-stage3
|
||||
COMMAND diff "stage2/bin/lean" "stage3/bin/lean"
|
||||
DEPENDS stage3)
|
||||
add_custom_target(check-stage3 COMMAND diff "stage2/bin/lean" "stage3/bin/lean" DEPENDS stage3)
|
||||
|
||||
13
script/fmt
Executable file
13
script/fmt
Executable file
@@ -0,0 +1,13 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# This script expects to be run from the repo root.
|
||||
|
||||
# Format cmake files
|
||||
find -regex '.*/CMakeLists\.txt\(\.in\)?\|.*\.cmake\(\.in\)?' \
|
||||
! -path './build/*' \
|
||||
! -path "./stage0/*" \
|
||||
-exec \
|
||||
uvx gersemi --in-place --line-length 120 --indent 2 \
|
||||
--definitions src/cmake/Modules/ src/CMakeLists.txt \
|
||||
-- {} +
|
||||
File diff suppressed because it is too large
Load Diff
@@ -469,13 +469,13 @@ namespace EStateM
|
||||
|
||||
instance : LawfulMonad (EStateM ε σ) := .mk'
|
||||
(id_map := fun x => funext <| fun s => by
|
||||
dsimp only [EStateM.instMonad, EStateM.map]
|
||||
simp only [Functor.map, EStateM.map]
|
||||
match x s with
|
||||
| .ok _ _ => rfl
|
||||
| .error _ _ => rfl)
|
||||
(pure_bind := fun _ _ => by rfl)
|
||||
(bind_assoc := fun x _ _ => funext <| fun s => by
|
||||
dsimp only [EStateM.instMonad, EStateM.bind]
|
||||
simp only [bind, EStateM.bind]
|
||||
match x s with
|
||||
| .ok _ _ => rfl
|
||||
| .error _ _ => rfl)
|
||||
|
||||
@@ -932,6 +932,14 @@ noncomputable def HEq.ndrec.{u1, u2} {α : Sort u2} {a : α} {motive : {β : Sor
|
||||
noncomputable def HEq.ndrecOn.{u1, u2} {α : Sort u2} {a : α} {motive : {β : Sort u2} → β → Sort u1} {β : Sort u2} {b : β} (h : a ≍ b) (m : motive a) : motive b :=
|
||||
h.rec m
|
||||
|
||||
/-- `HEq.ndrec` specialized to homogeneous heterogeneous equality -/
|
||||
noncomputable def HEq.homo_ndrec.{u1, u2} {α : Sort u2} {a : α} {motive : α → Sort u1} (m : motive a) {b : α} (h : a ≍ b) : motive b :=
|
||||
(eq_of_heq h).ndrec m
|
||||
|
||||
/-- `HEq.ndrec` specialized to homogeneous heterogeneous equality, symmetric variant -/
|
||||
noncomputable def HEq.homo_ndrec_symm.{u1, u2} {α : Sort u2} {a : α} {motive : α → Sort u1} (m : motive a) {b : α} (h : b ≍ a) : motive b :=
|
||||
(eq_of_heq h).ndrec_symm m
|
||||
|
||||
/-- `HEq.ndrec` variant -/
|
||||
noncomputable def HEq.elim {α : Sort u} {a : α} {p : α → Sort v} {b : α} (h₁ : a ≍ b) (h₂ : p a) : p b :=
|
||||
eq_of_heq h₁ ▸ h₂
|
||||
@@ -1478,6 +1486,29 @@ def Prod.map {α₁ : Type u₁} {α₂ : Type u₂} {β₁ : Type v₁} {β₂
|
||||
|
||||
/-! # Dependent products -/
|
||||
|
||||
instance {α : Type u} {β : α → Type v} [h₁ : DecidableEq α] [h₂ : ∀ a, DecidableEq (β a)] :
|
||||
DecidableEq (Sigma β)
|
||||
| ⟨a₁, b₁⟩, ⟨a₂, b₂⟩ =>
|
||||
match a₁, b₁, a₂, b₂, h₁ a₁ a₂ with
|
||||
| _, b₁, _, b₂, isTrue (Eq.refl _) =>
|
||||
match b₁, b₂, h₂ _ b₁ b₂ with
|
||||
| _, _, isTrue (Eq.refl _) => isTrue rfl
|
||||
| _, _, isFalse n => isFalse fun h ↦
|
||||
Sigma.noConfusion rfl .rfl (heq_of_eq h) fun _ e₂ ↦ n (eq_of_heq e₂)
|
||||
| _, _, _, _, isFalse n => isFalse fun h ↦
|
||||
Sigma.noConfusion rfl .rfl (heq_of_eq h) fun e₁ _ ↦ n (eq_of_heq e₁)
|
||||
|
||||
instance {α : Sort u} {β : α → Sort v} [h₁ : DecidableEq α] [h₂ : ∀ a, DecidableEq (β a)] : DecidableEq (PSigma β)
|
||||
| ⟨a₁, b₁⟩, ⟨a₂, b₂⟩ =>
|
||||
match a₁, b₁, a₂, b₂, h₁ a₁ a₂ with
|
||||
| _, b₁, _, b₂, isTrue (Eq.refl _) =>
|
||||
match b₁, b₂, h₂ _ b₁ b₂ with
|
||||
| _, _, isTrue (Eq.refl _) => isTrue rfl
|
||||
| _, _, isFalse n => isFalse fun h ↦
|
||||
PSigma.noConfusion rfl .rfl (heq_of_eq h) fun _ e₂ ↦ n (eq_of_heq e₂)
|
||||
| _, _, _, _, isFalse n => isFalse fun h ↦
|
||||
PSigma.noConfusion rfl .rfl (heq_of_eq h) fun e₁ _ ↦ n (eq_of_heq e₁)
|
||||
|
||||
theorem Exists.of_psigma_prop {α : Sort u} {p : α → Prop} : (PSigma (fun x => p x)) → Exists (fun x => p x)
|
||||
| ⟨x, hx⟩ => ⟨x, hx⟩
|
||||
|
||||
|
||||
@@ -30,3 +30,4 @@ public import Init.Data.Array.Erase
|
||||
public import Init.Data.Array.Zip
|
||||
public import Init.Data.Array.InsertIdx
|
||||
public import Init.Data.Array.Extract
|
||||
public import Init.Data.Array.MinMax
|
||||
|
||||
@@ -3065,6 +3065,18 @@ theorem foldl_eq_foldlM {f : β → α → β} {b} {xs : Array α} {start stop :
|
||||
theorem foldr_eq_foldrM {f : α → β → β} {b} {xs : Array α} {start stop : Nat} :
|
||||
xs.foldr f b start stop = (xs.foldrM (m := Id) (pure <| f · ·) b start stop).run := rfl
|
||||
|
||||
public theorem foldl_eq_foldl_extract {xs : Array α} {f : β → α → β} {init : β} :
|
||||
xs.foldl (init := init) (start := start) (stop := stop) f =
|
||||
(xs.extract start stop).foldl (init := init) f := by
|
||||
simp only [foldl_eq_foldlM]
|
||||
rw [foldlM_start_stop]
|
||||
|
||||
public theorem foldr_eq_foldr_extract {xs : Array α} {f : α → β → β} {init : β} :
|
||||
xs.foldr (init := init) (start := start) (stop := stop) f =
|
||||
(xs.extract stop start).foldr (init := init) f := by
|
||||
simp only [foldr_eq_foldrM]
|
||||
rw [foldrM_start_stop]
|
||||
|
||||
@[simp] theorem id_run_foldlM {f : β → α → Id β} {b} {xs : Array α} {start stop : Nat} :
|
||||
Id.run (xs.foldlM f b start stop) = xs.foldl (f · · |>.run) b start stop := rfl
|
||||
|
||||
|
||||
401
src/Init/Data/Array/MinMax.lean
Normal file
401
src/Init/Data/Array/MinMax.lean
Normal file
@@ -0,0 +1,401 @@
|
||||
/-
|
||||
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Paul Reichert
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Init.Data.Array.Bootstrap
|
||||
public import Init.Data.Array.Lemmas
|
||||
public import Init.Data.Array.DecidableEq
|
||||
import Init.Data.List.MinMax
|
||||
import Init.Data.List.ToArray
|
||||
|
||||
namespace Array
|
||||
|
||||
/-! ## Minima and maxima -/
|
||||
|
||||
/-! ### min -/
|
||||
|
||||
/--
|
||||
Returns the smallest element of a non-empty array.
|
||||
|
||||
Examples:
|
||||
* `#[4].min (by decide) = 4`
|
||||
* `#[1, 4, 2, 10, 6].min (by decide) = 1`
|
||||
-/
|
||||
public protected def min [Min α] (arr : Array α) (h : arr ≠ #[]) : α :=
|
||||
haveI : arr.size > 0 := by simp [Array.size_pos_iff, h]
|
||||
arr.foldl min arr[0] (start := 1)
|
||||
|
||||
/-! ### min? -/
|
||||
|
||||
/--
|
||||
Returns the smallest element of the array if it is not empty, or `none` if it is empty.
|
||||
|
||||
Examples:
|
||||
* `#[].min? = none`
|
||||
* `#[4].min? = some 4`
|
||||
* `#[1, 4, 2, 10, 6].min? = some 1`
|
||||
-/
|
||||
public protected def min? [Min α] (arr : Array α) : Option α :=
|
||||
if h : arr ≠ #[] then
|
||||
some (arr.min h)
|
||||
else
|
||||
none
|
||||
|
||||
/-! ### max -/
|
||||
|
||||
/--
|
||||
Returns the largest element of a non-empty array.
|
||||
|
||||
Examples:
|
||||
* `#[4].max (by decide) = 4`
|
||||
* `#[1, 4, 2, 10, 6].max (by decide) = 10`
|
||||
-/
|
||||
public protected def max [Max α] (arr : Array α) (h : arr ≠ #[]) : α :=
|
||||
haveI : arr.size > 0 := by simp [Array.size_pos_iff, h]
|
||||
arr.foldl max arr[0] (start := 1)
|
||||
|
||||
/-! ### max? -/
|
||||
|
||||
/--
|
||||
Returns the largest element of the array if it is not empty, or `none` if it is empty.
|
||||
|
||||
Examples:
|
||||
* `#[].max? = none`
|
||||
* `#[4].max? = some 4`
|
||||
* `#[1, 4, 2, 10, 6].max? = some 10`
|
||||
-/
|
||||
public protected def max? [Max α] (arr : Array α) : Option α :=
|
||||
if h : arr ≠ #[] then
|
||||
some (arr.max h)
|
||||
else
|
||||
none
|
||||
|
||||
/-! ### Compatibility with `List` -/
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem _root_.List.min_toArray [Min α] {l : List α} {h} :
|
||||
l.toArray.min h = l.min (by simpa [List.ne_nil_iff_length_pos] using h) := by
|
||||
let h' : l ≠ [] := by simpa [List.ne_nil_iff_length_pos] using h
|
||||
change l.toArray.min h = l.min h'
|
||||
rw [Array.min]
|
||||
· induction l
|
||||
· contradiction
|
||||
· rename_i x xs
|
||||
simp only [List.getElem_toArray, List.getElem_cons_zero, List.size_toArray, List.length_cons]
|
||||
rw [List.toArray_cons, foldl_eq_foldl_extract]
|
||||
rw [← Array.foldl_toList, Array.toList_extract, List.extract_eq_drop_take]
|
||||
simp [List.min]
|
||||
|
||||
public theorem _root_.List.min_eq_min_toArray [Min α] {l : List α} {h} :
|
||||
l.min h = l.toArray.min (by simpa [List.ne_nil_iff_length_pos] using h) := by
|
||||
simp
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem min_toList [Min α] {xs : Array α} {h} :
|
||||
xs.toList.min h = xs.min (by simpa [List.ne_nil_iff_length_pos] using h) := by
|
||||
cases xs; simp
|
||||
|
||||
public theorem min_eq_min_toList [Min α] {xs : Array α} {h} :
|
||||
xs.min h = xs.toList.min (by simpa [List.ne_nil_iff_length_pos] using h) := by
|
||||
simp
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem _root_.List.min?_toArray [Min α] {l : List α} :
|
||||
l.toArray.min? = l.min? := by
|
||||
rw [Array.min?]
|
||||
split
|
||||
· simp [List.min_toArray, List.min_eq_get_min?, - List.get_min?]
|
||||
· simp_all
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem min?_toList [Min α] {xs : Array α} :
|
||||
xs.toList.min? = xs.min? := by
|
||||
cases xs; simp
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem _root_.List.max_toArray [Max α] {l : List α} {h} :
|
||||
l.toArray.max h = l.max (by simpa [List.ne_nil_iff_length_pos] using h) := by
|
||||
let h' : l ≠ [] := by simpa [List.ne_nil_iff_length_pos] using h
|
||||
change l.toArray.max h = l.max h'
|
||||
rw [Array.max]
|
||||
· induction l
|
||||
· contradiction
|
||||
· rename_i x xs
|
||||
simp only [List.getElem_toArray, List.getElem_cons_zero, List.size_toArray, List.length_cons]
|
||||
rw [List.toArray_cons, foldl_eq_foldl_extract]
|
||||
rw [← Array.foldl_toList, Array.toList_extract, List.extract_eq_drop_take]
|
||||
simp [List.max]
|
||||
|
||||
public theorem _root_.List.max_eq_max_toArray [Max α] {l : List α} {h} :
|
||||
l.max h = l.toArray.max (by simpa [List.ne_nil_iff_length_pos] using h) := by
|
||||
simp
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem max_toList [Max α] {xs : Array α} {h} :
|
||||
xs.toList.max h = xs.max (by simpa [List.ne_nil_iff_length_pos] using h) := by
|
||||
cases xs; simp
|
||||
|
||||
public theorem max_eq_max_toList [Max α] {xs : Array α} {h} :
|
||||
xs.max h = xs.toList.max (by simpa [List.ne_nil_iff_length_pos] using h) := by
|
||||
simp
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem _root_.List.max?_toArray [Max α] {l : List α} :
|
||||
l.toArray.max? = l.max? := by
|
||||
rw [Array.max?]
|
||||
split
|
||||
· simp [List.max_toArray, List.max_eq_get_max?, - List.get_max?]
|
||||
· simp_all
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem max?_toList [Max α] {xs : Array α} :
|
||||
xs.toList.max? = xs.max? := by
|
||||
cases xs; simp
|
||||
|
||||
/-! ### Lemmas about `min?` -/
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem min?_empty [Min α] : (#[] : Array α).min? = none :=
|
||||
(rfl)
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem min?_singleton [Min α] {x : α} : #[x].min? = some x :=
|
||||
(rfl)
|
||||
|
||||
-- We don't put `@[simp]` on `min?_singleton_append'`,
|
||||
-- because the definition in terms of `foldl` is not useful for proofs.
|
||||
public theorem min?_singleton_append' [Min α] {xs : Array α} :
|
||||
(#[x] ++ xs).min? = some (xs.foldl min x) := by
|
||||
simp [← min?_toList, toList_append, List.min?]
|
||||
|
||||
@[simp]
|
||||
public theorem min?_singleton_append [Min α] [Std.Associative (min : α → α → α)] {xs : Array α} :
|
||||
(#[x] ++ xs).min? = some (xs.min?.elim x (min x)) := by
|
||||
simp [← min?_toList, toList_append, List.min?_cons]
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem min?_eq_none_iff {xs : Array α} [Min α] : xs.min? = none ↔ xs = #[] := by
|
||||
rcases xs with ⟨l⟩
|
||||
simp
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem isSome_min?_iff {xs : Array α} [Min α] : xs.min?.isSome ↔ xs ≠ #[] := by
|
||||
rcases xs with ⟨l⟩
|
||||
simp
|
||||
|
||||
@[grind .]
|
||||
public theorem isSome_min?_of_mem {xs : Array α} [Min α] {a : α} (h : a ∈ xs) :
|
||||
xs.min?.isSome := by
|
||||
rw [← min?_toList]
|
||||
apply List.isSome_min?_of_mem (a := a)
|
||||
simpa
|
||||
|
||||
public theorem isSome_min?_of_ne_empty [Min α] (xs : Array α) (h : xs ≠ #[]) : xs.min?.isSome := by
|
||||
rw [← min?_toList]
|
||||
apply List.isSome_min?_of_ne_nil
|
||||
simpa
|
||||
|
||||
public theorem min?_mem [Min α] [Std.MinEqOr α] (xs : Array α) (h : xs.min? = some a) : a ∈ xs := by
|
||||
rw [← min?_toList] at h
|
||||
simpa using List.min?_mem h
|
||||
|
||||
public theorem le_min?_iff [Min α] [LE α] [Std.LawfulOrderInf α] :
|
||||
{xs : Array α} → xs.min? = some a → ∀ {x}, x ≤ a ↔ ∀ b, b ∈ xs → x ≤ b := by
|
||||
intro xs h x
|
||||
simp only [← min?_toList] at h
|
||||
simpa using List.le_min?_iff h
|
||||
|
||||
public theorem min?_eq_some_iff [Min α] [LE α] {xs : Array α} [Std.IsLinearOrder α]
|
||||
[Std.LawfulOrderMin α] : xs.min? = some a ↔ a ∈ xs ∧ ∀ b, b ∈ xs → a ≤ b := by
|
||||
rcases xs with ⟨l⟩
|
||||
simpa using List.min?_eq_some_iff
|
||||
|
||||
public theorem min?_replicate [Min α] [Std.IdempotentOp (min : α → α → α)] {n : Nat} {a : α} :
|
||||
(replicate n a).min? = if n = 0 then none else some a := by
|
||||
rw [← List.toArray_replicate, List.min?_toArray, List.min?_replicate]
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem min?_replicate_of_pos [Min α] [Std.MinEqOr α] {n : Nat} {a : α} (h : 0 < n) :
|
||||
(replicate n a).min? = some a := by
|
||||
simp [min?_replicate, Nat.ne_of_gt h]
|
||||
|
||||
public theorem foldl_min [Min α] [Std.IdempotentOp (min : α → α → α)]
|
||||
[Std.Associative (min : α → α → α)] {xs : Array α} {a : α} :
|
||||
xs.foldl (init := a) min = min a (xs.min?.getD a) := by
|
||||
rcases xs with ⟨l⟩
|
||||
simp [List.foldl_min]
|
||||
|
||||
/-! ### Lemmas about `max?` -/
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem max?_empty [Max α] : (#[] : Array α).max? = none :=
|
||||
(rfl)
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem max?_singleton [Max α] {x : α} : #[x].max? = some x :=
|
||||
(rfl)
|
||||
|
||||
-- We don't put `@[simp]` on `max?_singleton_append'`,
|
||||
-- because the definition in terms of `foldl` is not useful for proofs.
|
||||
public theorem max?_singleton_append' [Max α] {xs : Array α} : (#[x] ++ xs).max? = some (xs.foldl max x) := by
|
||||
simp [← max?_toList, toList_append, List.max?]
|
||||
|
||||
@[simp]
|
||||
public theorem max?_singleton_append [Max α] [Std.Associative (max : α → α → α)] {xs : Array α} :
|
||||
(#[x] ++ xs).max? = some (xs.max?.elim x (max x)) := by
|
||||
simp [← max?_toList, toList_append, List.max?_cons]
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem max?_eq_none_iff {xs : Array α} [Max α] : xs.max? = none ↔ xs = #[] := by
|
||||
rcases xs with ⟨l⟩
|
||||
simp
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem isSome_max?_iff {xs : Array α} [Max α] : xs.max?.isSome ↔ xs ≠ #[] := by
|
||||
rcases xs with ⟨l⟩
|
||||
simp
|
||||
|
||||
@[grind .]
|
||||
public theorem isSome_max?_of_mem {xs : Array α} [Max α] {a : α} (h : a ∈ xs) :
|
||||
xs.max?.isSome := by
|
||||
rw [← max?_toList]
|
||||
apply List.isSome_max?_of_mem (a := a)
|
||||
simpa
|
||||
|
||||
public theorem isSome_max?_of_ne_empty [Max α] (xs : Array α) (h : xs ≠ #[]) : xs.max?.isSome := by
|
||||
rw [← max?_toList]
|
||||
apply List.isSome_max?_of_ne_nil
|
||||
simpa
|
||||
|
||||
public theorem max?_mem [Max α] [Std.MaxEqOr α] (xs : Array α) (h : xs.max? = some a) : a ∈ xs := by
|
||||
rw [← max?_toList] at h
|
||||
simpa using List.max?_mem h
|
||||
|
||||
public theorem max?_le_iff [Max α] [LE α] [Std.LawfulOrderSup α] :
|
||||
{xs : Array α} → xs.max? = some a → ∀ {x}, a ≤ x ↔ ∀ b, b ∈ xs → b ≤ x := by
|
||||
intro xs h x
|
||||
simp only [← max?_toList] at h
|
||||
simpa using List.max?_le_iff h
|
||||
|
||||
public theorem max?_eq_some_iff [Max α] [LE α] {xs : Array α} [Std.IsLinearOrder α]
|
||||
[Std.LawfulOrderMax α] : xs.max? = some a ↔ a ∈ xs ∧ ∀ b, b ∈ xs → b ≤ a := by
|
||||
rcases xs with ⟨l⟩
|
||||
simpa using List.max?_eq_some_iff
|
||||
|
||||
public theorem max?_replicate [Max α] [Std.IdempotentOp (max : α → α → α)] {n : Nat} {a : α} :
|
||||
(replicate n a).max? = if n = 0 then none else some a := by
|
||||
rw [← List.toArray_replicate, List.max?_toArray, List.max?_replicate]
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem max?_replicate_of_pos [Max α] [Std.MaxEqOr α] {n : Nat} {a : α} (h : 0 < n) :
|
||||
(replicate n a).max? = some a := by
|
||||
simp [max?_replicate, Nat.ne_of_gt h]
|
||||
|
||||
public theorem foldl_max [Max α] [Std.IdempotentOp (max : α → α → α)] [Std.Associative (max : α → α → α)]
|
||||
{xs : Array α} {a : α} : xs.foldl (init := a) max = max a (xs.max?.getD a) := by
|
||||
rcases xs with ⟨l⟩
|
||||
simp [List.foldl_max]
|
||||
|
||||
/-! ### Lemmas about `min` -/
|
||||
|
||||
@[simp, grind =]
|
||||
theorem min_singleton [Min α] {x : α} :
|
||||
#[x].min (ne_empty_of_size_eq_add_one rfl) = x := by
|
||||
(rfl)
|
||||
|
||||
public theorem min?_eq_some_min [Min α] : {xs : Array α} → (h : xs ≠ #[]) →
|
||||
xs.min? = some (xs.min h)
|
||||
| ⟨a::as⟩, _ => by simp [Array.min, Array.min?]
|
||||
|
||||
public theorem min_eq_get_min? [Min α] : (xs : Array α) → (h : xs ≠ #[]) →
|
||||
xs.min h = xs.min?.get (xs.isSome_min?_of_ne_empty h)
|
||||
| ⟨a::as⟩, _ => by simp [Array.min, Array.min?]
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem get_min? [Min α] {xs : Array α} {h : xs.min?.isSome} :
|
||||
xs.min?.get h = xs.min (isSome_min?_iff.mp h) := by
|
||||
simp [min?_eq_some_min (isSome_min?_iff.mp h)]
|
||||
|
||||
@[grind .]
|
||||
public theorem min_mem [Min α] [Std.MinEqOr α] {xs : Array α} (h : xs ≠ #[]) : xs.min h ∈ xs :=
|
||||
xs.min?_mem (min?_eq_some_min h)
|
||||
|
||||
@[grind .]
|
||||
public theorem min_le_of_mem [Min α] [LE α] [Std.IsLinearOrder α] [Std.LawfulOrderMin α]
|
||||
{xs : Array α} {a : α} (ha : a ∈ xs) :
|
||||
xs.min (ne_empty_of_mem ha) ≤ a :=
|
||||
(Array.min?_eq_some_iff.mp (min?_eq_some_min (ne_empty_of_mem ha))).right a ha
|
||||
|
||||
public protected theorem le_min_iff [Min α] [LE α] [Std.LawfulOrderInf α]
|
||||
{xs : Array α} (h : xs ≠ #[]) : ∀ {x}, x ≤ xs.min h ↔ ∀ b, b ∈ xs → x ≤ b :=
|
||||
le_min?_iff (min?_eq_some_min h)
|
||||
|
||||
public theorem min_eq_iff [Min α] [LE α] {xs : Array α} [Std.IsLinearOrder α] [Std.LawfulOrderMin α]
|
||||
(h : xs ≠ #[]) : xs.min h = a ↔ a ∈ xs ∧ ∀ b, b ∈ xs → a ≤ b := by
|
||||
simpa [min?_eq_some_min h] using (min?_eq_some_iff (xs := xs))
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem min_replicate [Min α] [Std.MinEqOr α] {n : Nat} {a : α} (h : (replicate n a) ≠ #[]) :
|
||||
(replicate n a).min h = a := by
|
||||
have n_pos : 0 < n := by simpa [Nat.ne_zero_iff_zero_lt] using h
|
||||
simpa [min?_eq_some_min h] using (min?_replicate_of_pos (a := a) n_pos)
|
||||
|
||||
public theorem foldl_min_eq_min [Min α] [Std.IdempotentOp (min : α → α → α)]
|
||||
[Std.Associative (min : α → α → α)] {xs : Array α} (h : xs ≠ #[]) {a : α} :
|
||||
xs.foldl min a = min a (xs.min h) := by
|
||||
simpa [min?_eq_some_min h] using foldl_min (xs := xs)
|
||||
|
||||
/-! ### Lemmas about `max` -/
|
||||
|
||||
@[simp, grind =]
|
||||
theorem max_singleton [Max α] {x : α} :
|
||||
#[x].max (ne_empty_of_size_eq_add_one rfl) = x := by
|
||||
(rfl)
|
||||
|
||||
public theorem max?_eq_some_max [Max α] : {xs : Array α} → (h : xs ≠ #[]) →
|
||||
xs.max? = some (xs.max h)
|
||||
| ⟨a::as⟩, _ => by simp [Array.max, Array.max?]
|
||||
|
||||
public theorem max_eq_get_max? [Max α] : (xs : Array α) → (h : xs ≠ #[]) →
|
||||
xs.max h = xs.max?.get (xs.isSome_max?_of_ne_empty h)
|
||||
| ⟨a::as⟩, _ => by simp [Array.max, Array.max?]
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem get_max? [Max α] {xs : Array α} {h : xs.max?.isSome} :
|
||||
xs.max?.get h = xs.max (isSome_max?_iff.mp h) := by
|
||||
simp [max?_eq_some_max (isSome_max?_iff.mp h)]
|
||||
|
||||
@[grind .]
|
||||
public theorem max_mem [Max α] [Std.MaxEqOr α] {xs : Array α} (h : xs ≠ #[]) : xs.max h ∈ xs :=
|
||||
xs.max?_mem (max?_eq_some_max h)
|
||||
|
||||
public protected theorem max_le_iff [Max α] [LE α] [Std.LawfulOrderSup α]
|
||||
{xs : Array α} (h : xs ≠ #[]) : ∀ {x}, xs.max h ≤ x ↔ ∀ b, b ∈ xs → b ≤ x :=
|
||||
max?_le_iff (max?_eq_some_max h)
|
||||
|
||||
public theorem max_eq_iff [Max α] [LE α] {xs : Array α} [Std.IsLinearOrder α] [Std.LawfulOrderMax α]
|
||||
(h : xs ≠ #[]) : xs.max h = a ↔ a ∈ xs ∧ ∀ b, b ∈ xs → b ≤ a := by
|
||||
simpa [max?_eq_some_max h] using (max?_eq_some_iff (xs := xs))
|
||||
|
||||
@[grind .]
|
||||
public theorem le_max_of_mem [Max α] [LE α] [Std.IsLinearOrder α] [Std.LawfulOrderMax α]
|
||||
{xs : Array α} {a : α} (ha : a ∈ xs) :
|
||||
a ≤ xs.max (ne_empty_of_mem ha) :=
|
||||
(Array.max?_eq_some_iff.mp (max?_eq_some_max (ne_empty_of_mem ha))).right a ha
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem max_replicate [Max α] [Std.MaxEqOr α] {n : Nat} {a : α} (h : (replicate n a) ≠ #[]) :
|
||||
(replicate n a).max h = a := by
|
||||
have n_pos : 0 < n := by simpa [Nat.ne_zero_iff_zero_lt] using h
|
||||
simpa [max?_eq_some_max h] using (max?_replicate_of_pos (a := a) n_pos)
|
||||
|
||||
public theorem foldl_max_eq_max [Max α] [Std.IdempotentOp (max : α → α → α)]
|
||||
[Std.Associative (max : α → α → α)] {xs : Array α} (h : xs ≠ #[]) {a : α} :
|
||||
xs.foldl max a = max a (xs.max h) := by
|
||||
simpa [max?_eq_some_max h] using foldl_max (xs := xs)
|
||||
|
||||
end Array
|
||||
@@ -11,6 +11,8 @@ public import Init.Grind.Ordered.Ring
|
||||
|
||||
/-! # Internal `grind` algebra instances for `Dyadic`. -/
|
||||
|
||||
@[expose] public section
|
||||
|
||||
open Lean.Grind
|
||||
|
||||
namespace Dyadic
|
||||
|
||||
@@ -4,7 +4,9 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Kim Morrison
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Init.Data.Dyadic.Basic
|
||||
import Init.Data.Dyadic.Round
|
||||
import Init.Grind.Ordered.Ring
|
||||
|
||||
@@ -12,6 +14,8 @@ import Init.Grind.Ordered.Ring
|
||||
# Inversion for dyadic numbers
|
||||
-/
|
||||
|
||||
@[expose] public section
|
||||
|
||||
namespace Dyadic
|
||||
|
||||
/--
|
||||
|
||||
@@ -7,7 +7,7 @@ module
|
||||
|
||||
prelude
|
||||
public import Init.Data.Dyadic.Basic
|
||||
import all Init.Data.Dyadic.Instances
|
||||
import Init.Data.Dyadic.Instances
|
||||
import Init.Grind.Ordered.Rat
|
||||
import Init.Grind.Ordered.Field
|
||||
|
||||
|
||||
@@ -1153,6 +1153,15 @@ theorem ediv_le_iff_le_mul {k x y : Int} (h : 0 < k) : x / k ≤ y ↔ x < y * k
|
||||
rw [Int.le_iff_lt_add_one, Int.ediv_lt_iff_lt_mul h, Int.add_mul]
|
||||
omega
|
||||
|
||||
theorem le_mul_iff_le_left {x y z : Int} (hz : 0 < z) :
|
||||
x ≤ y * z ↔ (x + z - 1) / z ≤ y := by
|
||||
rw [Int.ediv_le_iff_le_mul hz]
|
||||
omega
|
||||
|
||||
theorem le_mul_iff_le_right {x y z : Int} (hy : 0 < y) :
|
||||
x ≤ y * z ↔ (x + y - 1) / y ≤ z := by
|
||||
rw [← le_mul_iff_le_left hy, Int.mul_comm]
|
||||
|
||||
protected theorem le_mul_of_ediv_le {a b c : Int} (H1 : 0 ≤ b) (H2 : b ∣ a) (H3 : a / b ≤ c) :
|
||||
a ≤ c * b := by
|
||||
rw [← Int.ediv_mul_cancel H2]; exact Int.mul_le_mul_of_nonneg_right H3 H1
|
||||
@@ -1206,6 +1215,11 @@ theorem add_ediv {a b c : Int} (h : c ≠ 0) :
|
||||
protected theorem ediv_le_ediv {a b c : Int} (H : 0 < c) (H' : a ≤ b) : a / c ≤ b / c :=
|
||||
Int.le_ediv_of_mul_le H (Int.le_trans (Int.ediv_mul_le _ (Int.ne_of_gt H)) H')
|
||||
|
||||
theorem ediv_add_ediv_le_add_ediv {x y z : Int} (hz : 0 < z) :
|
||||
x / z + y / z ≤ (x + y) / z := by
|
||||
rw [Int.le_ediv_iff_mul_le hz, Int.add_mul]
|
||||
apply Int.add_le_add <;> apply Int.ediv_mul_le <;> omega
|
||||
|
||||
/-- If `n > 0` then `m` is not divisible by `n` iff it is between `n * k` and `n * (k + 1)`
|
||||
for some `k`. -/
|
||||
theorem not_dvd_iff_lt_mul_succ (m : Int) (hn : 0 < n) :
|
||||
@@ -1783,12 +1797,12 @@ theorem ediv_lt_ediv_iff_of_dvd_of_neg_of_neg {a b c d : Int} (hb : b < 0) (hd :
|
||||
|
||||
theorem ediv_lt_ediv_of_lt {a b c : Int} (h : a < b) (hcb : c ∣ b) (hc : 0 < c) :
|
||||
a / c < b / c :=
|
||||
Int.lt_ediv_of_mul_lt (Int.le_of_lt hc) hcb
|
||||
Int.lt_ediv_of_mul_lt (Int.le_of_lt hc) hcb
|
||||
(Int.lt_of_le_of_lt (Int.ediv_mul_le _ (Int.ne_of_gt hc)) h)
|
||||
|
||||
|
||||
theorem ediv_lt_ediv_of_lt_of_neg {a b c : Int} (h : b < a) (hca : c ∣ a) (hc : c < 0) :
|
||||
a / c < b / c :=
|
||||
(Int.ediv_lt_iff_of_dvd_of_neg hc hca).2
|
||||
(Int.ediv_lt_iff_of_dvd_of_neg hc hca).2
|
||||
(Int.lt_of_le_of_lt (Int.mul_ediv_self_le (Int.ne_of_lt hc)) h)
|
||||
|
||||
/-! ### `tdiv` and ordering -/
|
||||
|
||||
@@ -94,8 +94,9 @@ By convention, the monadic iterator associated with an object can be obtained vi
|
||||
For example, `List.iterM IO` creates an iterator over a list in the monad `IO`.
|
||||
|
||||
See `Init.Data.Iterators.Consumers` for ways to use an iterator. For example, `it.toList` will
|
||||
convert a provably finite iterator `it` into a list and `it.allowNontermination.toList` will
|
||||
do so even if finiteness cannot be proved. It is also always possible to manually iterate using
|
||||
convert an iterator `it` into a list and `it.ensureTermination.toList` guarantees that this
|
||||
operation will terminate, given a proof that the iterator is finite.
|
||||
It is also always possible to manually iterate using
|
||||
`it.step`, relying on the termination measures `it.finitelyManySteps` and `it.finitelyManySkips`.
|
||||
|
||||
See `Iter` for a more convenient interface in case that no monadic effects are needed (`m = Id`).
|
||||
@@ -139,8 +140,9 @@ By convention, the monadic iterator associated with an object can be obtained vi
|
||||
For example, `List.iterM IO` creates an iterator over a list in the monad `IO`.
|
||||
|
||||
See `Init.Data.Iterators.Consumers` for ways to use an iterator. For example, `it.toList` will
|
||||
convert a provably finite iterator `it` into a list and `it.allowNontermination.toList` will
|
||||
do so even if finiteness cannot be proved. It is also always possible to manually iterate using
|
||||
convert an iterator `it` into a list and `it.ensureTermination.toList` guarantees that this
|
||||
operation will terminate, given a proof that the iterator is finite.
|
||||
It is also always possible to manually iterate using
|
||||
`it.step`, relying on the termination measures `it.finitelyManySteps` and `it.finitelyManySkips`.
|
||||
|
||||
See `IterM` for iterators that operate in a monad.
|
||||
@@ -754,8 +756,8 @@ def IterM.finitelyManySteps {α : Type w} {m : Type w → Type w'} {β : Type w}
|
||||
⟨it⟩
|
||||
|
||||
/--
|
||||
Termination measure to be used in well-founded recursive functions recursing over a finite iterator
|
||||
(see also `Finite`).
|
||||
Termination measure to be used in recursive functions built with `WellFounded.extrinsicFix`
|
||||
recursing over a finite iterator without requiring a proof of finiteness (see also `Finite`).
|
||||
-/
|
||||
@[expose]
|
||||
def IterM.finitelyManySteps! {α : Type w} {m : Type w → Type w'} {β : Type w} [Iterator α m β]
|
||||
@@ -796,6 +798,11 @@ def Iter.finitelyManySteps {α : Type w} {β : Type w} [Iterator α Id β] [Iter
|
||||
(it : Iter (α := α) β) : IterM.TerminationMeasures.Finite α Id :=
|
||||
it.toIterM.finitelyManySteps
|
||||
|
||||
@[inherit_doc IterM.finitelyManySteps!, expose]
|
||||
def Iter.finitelyManySteps! {α : Type w} {β : Type w} [Iterator α Id β]
|
||||
(it : Iter (α := α) β) : IterM.TerminationMeasures.Finite α Id :=
|
||||
it.toIterM.finitelyManySteps!
|
||||
|
||||
/--
|
||||
This theorem is used by a `decreasing_trivial` extension. It powers automatic termination proofs
|
||||
with `IterM.finitelyManySteps`.
|
||||
@@ -902,6 +909,16 @@ def IterM.finitelyManySkips {α : Type w} {m : Type w → Type w'} {β : Type w}
|
||||
[Iterators.Productive α m] (it : IterM (α := α) m β) : IterM.TerminationMeasures.Productive α m :=
|
||||
⟨it⟩
|
||||
|
||||
/--
|
||||
Termination measure to be used in recursive functions built with `WellFounded.extrinsicFix`
|
||||
recursing over a productive iterator without requiring a proof of productiveness
|
||||
(see also `Productive`).
|
||||
-/
|
||||
@[expose]
|
||||
def IterM.finitelyManySkips! {α : Type w} {m : Type w → Type w'} {β : Type w} [Iterator α m β]
|
||||
(it : IterM (α := α) m β) : IterM.TerminationMeasures.Productive α m :=
|
||||
⟨it⟩
|
||||
|
||||
/--
|
||||
This theorem is used by a `decreasing_trivial` extension. It powers automatic termination proofs
|
||||
with `IterM.finitelyManySkips`.
|
||||
@@ -922,6 +939,11 @@ def Iter.finitelyManySkips {α : Type w} {β : Type w} [Iterator α Id β] [Iter
|
||||
(it : Iter (α := α) β) : IterM.TerminationMeasures.Productive α Id :=
|
||||
it.toIterM.finitelyManySkips
|
||||
|
||||
@[inherit_doc IterM.finitelyManySkips!, expose]
|
||||
def Iter.finitelyManySkips! {α : Type w} {β : Type w} [Iterator α Id β]
|
||||
(it : Iter (α := α) β) : IterM.TerminationMeasures.Productive α Id :=
|
||||
it.toIterM.finitelyManySkips!
|
||||
|
||||
/--
|
||||
This theorem is used by a `decreasing_trivial` extension. It powers automatic termination proofs
|
||||
with `Iter.finitelyManySkips`.
|
||||
|
||||
@@ -21,21 +21,70 @@ If possible, takes `n` steps with the iterator `it` and
|
||||
returns the `n`-th emitted value, or `none` if `it` finished
|
||||
before emitting `n` values.
|
||||
|
||||
This function requires a `Productive` instance proving that the iterator will always emit a value
|
||||
after a finite number of skips. If the iterator is not productive or such an instance is not
|
||||
available, consider using `it.allowNontermination.atIdxSlow?` instead of `it.atIdxSlow?`. However,
|
||||
it is not possible to formally verify the behavior of the partial variant.
|
||||
If the iterator is not productive, this function might run forever in an endless loop of iterator
|
||||
steps. The variant `it.ensureTermination.atIdxSlow?` is guaranteed to terminate after finitely many
|
||||
steps.
|
||||
-/
|
||||
@[specialize]
|
||||
def Iter.atIdxSlow? {α β} [Iterator α Id β] [Productive α Id]
|
||||
def Iter.atIdxSlow? {α β} [Iterator α Id β]
|
||||
(n : Nat) (it : Iter (α := α) β) : Option β :=
|
||||
match it.step with
|
||||
| .yield it' out _ =>
|
||||
match n with
|
||||
| 0 => some out
|
||||
| k + 1 => it'.atIdxSlow? k
|
||||
| .skip it' _ => it'.atIdxSlow? n
|
||||
| .done _ => none
|
||||
WellFounded.extrinsicFix₂ (C₂ := fun _ _ => Option β) (α := Iter (α := α) β) (β := fun _ => Nat)
|
||||
(InvImage
|
||||
(Prod.Lex WellFoundedRelation.rel IterM.TerminationMeasures.Productive.Rel)
|
||||
(fun p => (p.2, p.1.finitelyManySkips!)))
|
||||
(fun it n recur =>
|
||||
match it.step with
|
||||
| .yield it' out _ =>
|
||||
match n with
|
||||
| 0 => some out
|
||||
| k + 1 => recur it' k (by decreasing_tactic)
|
||||
| .skip it' _ => recur it' n (by decreasing_tactic)
|
||||
| .done _ => none) it n
|
||||
|
||||
-- We provide the functional induction principle by hand because `atIdxSlow?` is implemented using
|
||||
-- `extrinsicFix₂` and not using well-founded recursion.
|
||||
/-
|
||||
An induction principle for `Iter.atIdxSlow?`.
|
||||
|
||||
This lemma provides a functional induction principle for reasoning about `Iter.atIdxSlow? n it`.
|
||||
|
||||
The induction follows the structure of iterator steps.
|
||||
- base case: when we reach the desired index (`n = 0`) and get a `.yield` step
|
||||
- inductive case: when we have a `.yield` step but need to continue (`n > 0`)
|
||||
- skip case: when we encounter a `.skip` step and continue with the same index
|
||||
- done case: when the iterator is exhausted and we return `none`
|
||||
-/
|
||||
theorem Iter.atIdxSlow?.induct_unfolding {α β : Type u} [Iterator α Id β] [Productive α Id]
|
||||
(motive : Nat → Iter β → Option β → Prop)
|
||||
-- Base case: we have reached index 0 and found a value
|
||||
(yield_zero : ∀ (it it' : Iter (α := α) β) (out : β) (property : it.IsPlausibleStep (IterStep.yield it' out)),
|
||||
it.step = ⟨IterStep.yield it' out, property⟩ → motive 0 it (some out))
|
||||
-- Inductive case: we have a yield but need to continue to a higher index
|
||||
(yield_succ : ∀ (it it' : Iter (α := α) β) (out : β) (property : it.IsPlausibleStep (IterStep.yield it' out)),
|
||||
it.step = ⟨IterStep.yield it' out, property⟩ →
|
||||
∀ (k : Nat), motive k it' (Iter.atIdxSlow? k it') → motive k.succ it (Iter.atIdxSlow? k it'))
|
||||
-- Skip case: we encounter a skip and continue with the same index
|
||||
(skip_case : ∀ (n : Nat) (it it' : Iter β) (property : it.IsPlausibleStep (IterStep.skip it')),
|
||||
it.step = ⟨IterStep.skip it', property⟩ →
|
||||
motive n it' (Iter.atIdxSlow? n it') → motive n it (Iter.atIdxSlow? n it'))
|
||||
-- Done case: the iterator is exhausted, return none
|
||||
(done_case : ∀ (n : Nat) (it : Iter β) (property : it.IsPlausibleStep IterStep.done),
|
||||
it.step = ⟨IterStep.done, property⟩ → motive n it none)
|
||||
-- The conclusion: the property holds for all indices and iterators
|
||||
(n : Nat) (it : Iter β) : motive n it (Iter.atIdxSlow? n it) := by
|
||||
simp only [atIdxSlow?] at *
|
||||
rw [WellFounded.extrinsicFix₂_eq_apply]
|
||||
· split
|
||||
· split
|
||||
· apply yield_zero <;> assumption
|
||||
· apply yield_succ
|
||||
all_goals try assumption
|
||||
apply Iter.atIdxSlow?.induct_unfolding <;> assumption
|
||||
· apply skip_case
|
||||
all_goals try assumption
|
||||
apply Iter.atIdxSlow?.induct_unfolding <;> assumption
|
||||
· apply done_case <;> assumption
|
||||
· exact InvImage.wf _ WellFoundedRelation.wf
|
||||
termination_by (n, it.finitelyManySkips)
|
||||
|
||||
/--
|
||||
@@ -43,22 +92,21 @@ If possible, takes `n` steps with the iterator `it` and
|
||||
returns the `n`-th emitted value, or `none` if `it` finished
|
||||
before emitting `n` values.
|
||||
|
||||
This is a partial, potentially nonterminating, function. It is not possible to formally verify
|
||||
its behavior. If the iterator has a `Productive` instance, consider using `Iter.atIdxSlow?` instead.
|
||||
This variant terminates after finitely many steps and requires a proof that the iterator is
|
||||
productive. If such a proof is not available, consider using `Iter.toArray`.
|
||||
-/
|
||||
@[specialize]
|
||||
partial def Iter.Partial.atIdxSlow? {α β} [Iterator α Id β] [Monad Id]
|
||||
(n : Nat) (it : Iter.Partial (α := α) β) : Option β := do
|
||||
match it.it.step with
|
||||
| .yield it' out _ =>
|
||||
match n with
|
||||
| 0 => some out
|
||||
| k + 1 => (⟨it'⟩ : Iter.Partial (α := α) β).atIdxSlow? k
|
||||
| .skip it' _ => (⟨it'⟩ : Iter.Partial (α := α) β).atIdxSlow? n
|
||||
| .done _ => none
|
||||
@[inline]
|
||||
def Iter.Total.atIdxSlow? {α β} [Iterator α Id β] [Productive α Id]
|
||||
(n : Nat) (it : Iter.Total (α := α) β) : Option β :=
|
||||
it.it.atIdxSlow? n
|
||||
|
||||
@[inline, inherit_doc Iter.atIdxSlow?, deprecated Iter.atIdxSlow? (since := "2026-01-28")]
|
||||
def Iter.Partial.atIdxSlow? {α β} [Iterator α Id β]
|
||||
(n : Nat) (it : Iter.Partial (α := α) β) : Option β :=
|
||||
it.it.atIdxSlow? n
|
||||
|
||||
@[always_inline, inline, inherit_doc IterM.atIdx?]
|
||||
def Iter.atIdx? {α β} [Iterator α Id β] [Productive α Id] [IteratorAccess α Id]
|
||||
def Iter.atIdx? {α β} [Iterator α Id β] [IteratorAccess α Id]
|
||||
(n : Nat) (it : Iter (α := α) β) : Option β :=
|
||||
match (IteratorAccess.nextAtIdx? it.toIterM n).run.val with
|
||||
| .yield _ out => some out
|
||||
|
||||
@@ -667,6 +667,42 @@ def Iter.Total.first? {α β : Type w} [Iterator α Id β] [IteratorLoop α Id I
|
||||
(it : Iter.Total (α := α) β) : Option β :=
|
||||
it.it.first?
|
||||
|
||||
/--
|
||||
Returns `true` if the iterator yields no values.
|
||||
|
||||
`O(|it|)` since the iterator may skip an unknown number of times before returning a result.
|
||||
Short-circuits upon encountering the first result. Only the first element of `it` is examined.
|
||||
|
||||
If the iterator is not productive, this function might run forever. The variant
|
||||
`it.ensureTermination.isEmpty` always terminates after finitely many steps.
|
||||
|
||||
Examples:
|
||||
* `[].iter.isEmpty = true`
|
||||
* `[1].iter.isEmpty = false`
|
||||
-/
|
||||
@[inline]
|
||||
def Iter.isEmpty {α β : Type w} [Iterator α Id β] [IteratorLoop α Id Id]
|
||||
(it : Iter (α := α) β) : Bool :=
|
||||
it.toIterM.isEmpty.run.down
|
||||
|
||||
/--
|
||||
Returns `true` if the iterator yields no values.
|
||||
|
||||
`O(|it|)` since the iterator may skip an unknown number of times before returning a result.
|
||||
Short-circuits upon encountering the first result. Only the first element of `it` is examined.
|
||||
|
||||
This variant terminates after finitely many steps and requires a proof that the iterator is
|
||||
productive. If such a proof is not available, consider using `Iter.isEmpty`.
|
||||
|
||||
Examples:
|
||||
* `[].iter.isEmpty = true`
|
||||
* `[1].iter.isEmpty = false`
|
||||
-/
|
||||
@[inline]
|
||||
def Iter.Total.isEmpty {α β : Type w} [Iterator α Id β] [IteratorLoop α Id Id] [Productive α Id]
|
||||
(it : Iter.Total (α := α) β) : Bool :=
|
||||
it.it.isEmpty
|
||||
|
||||
/--
|
||||
Steps through the whole iterator, counting the number of outputs emitted.
|
||||
|
||||
@@ -675,9 +711,15 @@ Steps through the whole iterator, counting the number of outputs emitted.
|
||||
This function's runtime is linear in the number of steps taken by the iterator.
|
||||
-/
|
||||
@[always_inline, inline, expose]
|
||||
def Iter.count {α : Type w} {β : Type w} [Iterator α Id β] [IteratorLoop α Id Id]
|
||||
def Iter.length {α : Type w} {β : Type w} [Iterator α Id β] [IteratorLoop α Id Id]
|
||||
(it : Iter (α := α) β) : Nat :=
|
||||
it.toIterM.count.run.down
|
||||
it.toIterM.length.run.down
|
||||
|
||||
@[inline, inherit_doc Iter.length, deprecated Iter.length (since := "2026-01-28"), expose]
|
||||
def Iter.count := @Iter.length
|
||||
|
||||
@[inline, inherit_doc Iter.length, deprecated Iter.length (since := "2025-10-29"), expose]
|
||||
def Iter.size := @Iter.length
|
||||
|
||||
/--
|
||||
Steps through the whole iterator, counting the number of outputs emitted.
|
||||
@@ -686,22 +728,10 @@ Steps through the whole iterator, counting the number of outputs emitted.
|
||||
|
||||
This function's runtime is linear in the number of steps taken by the iterator.
|
||||
-/
|
||||
@[always_inline, inline, expose, deprecated Iter.count (since := "2025-10-29")]
|
||||
def Iter.size {α : Type w} {β : Type w} [Iterator α Id β] [IteratorLoop α Id Id]
|
||||
(it : Iter (α := α) β) : Nat :=
|
||||
it.count
|
||||
|
||||
/--
|
||||
Steps through the whole iterator, counting the number of outputs emitted.
|
||||
|
||||
**Performance**:
|
||||
|
||||
This function's runtime is linear in the number of steps taken by the iterator.
|
||||
-/
|
||||
@[always_inline, inline, expose, deprecated Iter.count (since := "2025-12-04")]
|
||||
@[always_inline, inline, expose, deprecated Iter.length (since := "2025-12-04")]
|
||||
def Iter.Partial.count {α : Type w} {β : Type w} [Iterator α Id β] [IteratorLoop α Id Id]
|
||||
(it : Iter.Partial (α := α) β) : Nat :=
|
||||
it.it.toIterM.count.run.down
|
||||
it.it.toIterM.length.run.down
|
||||
|
||||
/--
|
||||
Steps through the whole iterator, counting the number of outputs emitted.
|
||||
@@ -710,9 +740,9 @@ Steps through the whole iterator, counting the number of outputs emitted.
|
||||
|
||||
This function's runtime is linear in the number of steps taken by the iterator.
|
||||
-/
|
||||
@[always_inline, inline, expose, deprecated Iter.count (since := "2025-10-29")]
|
||||
@[always_inline, inline, expose, deprecated Iter.length (since := "2025-10-29")]
|
||||
def Iter.Partial.size {α : Type w} {β : Type w} [Iterator α Id β] [IteratorLoop α Id Id]
|
||||
(it : Iter.Partial (α := α) β) : Nat :=
|
||||
it.it.count
|
||||
it.it.length
|
||||
|
||||
end Std
|
||||
|
||||
@@ -950,6 +950,38 @@ def IterM.Total.first? {α β : Type w} {m : Type w → Type w'} [Monad m] [Iter
|
||||
[IteratorLoop α m m] [Productive α m] (it : IterM.Total (α := α) m β) : m (Option β) :=
|
||||
it.it.first?
|
||||
|
||||
set_option doc.verso true in
|
||||
/--
|
||||
Returns {lean}`ULift.up true` if the iterator {name}`it` yields no values.
|
||||
|
||||
{lit}`O(|it|)` since the iterator may skip an unknown number of times before returning a result.
|
||||
Short-circuits upon encountering the first result. Only the first element of {name}`it` is examined.
|
||||
|
||||
If the iterator is not productive, this function might run forever. The variant
|
||||
{lit}`it.ensureTermination.isEmpty` always terminates after finitely many steps.
|
||||
-/
|
||||
@[always_inline]
|
||||
def IterM.isEmpty {α β : Type w} {m : Type w → Type w'} [Monad m] [Iterator α m β]
|
||||
[IteratorLoop α m m] (it : IterM (α := α) m β) : m (ULift Bool) :=
|
||||
IteratorLoop.forIn (fun _ _ => flip Bind.bind) _ (fun _ _ s => s = ForInStep.done (.up false)) it
|
||||
(.up true) (fun _ _ _ => pure ⟨ForInStep.done (.up false), rfl⟩)
|
||||
|
||||
set_option doc.verso true in
|
||||
/--
|
||||
Returns {lean}`ULift.up true` if the iterator {name}`it` yields no values.
|
||||
|
||||
{lit}`O(|it|)` since the iterator may skip an unknown number of times before returning a result.
|
||||
Short-circuits upon encountering the first result. Only the first element of {name}`it` is examined.
|
||||
|
||||
This variant terminates after finitely many steps and requires a proof that the iterator is
|
||||
finite. If such a proof is not available, consider using {name}`IterM.isEmpty`.
|
||||
-/
|
||||
@[always_inline, inline]
|
||||
def IterM.Total.isEmpty {α β : Type w} {m : Type w → Type w'} [Monad m]
|
||||
[Iterator α m β] [IteratorLoop α m m] [Productive α m] (it : IterM.Total (α := α) m β) :
|
||||
m (ULift Bool) :=
|
||||
it.it.isEmpty
|
||||
|
||||
section Count
|
||||
|
||||
/--
|
||||
@@ -960,21 +992,15 @@ Steps through the whole iterator, counting the number of outputs emitted.
|
||||
This function's runtime is linear in the number of steps taken by the iterator.
|
||||
-/
|
||||
@[always_inline, inline]
|
||||
def IterM.count {α : Type w} {m : Type w → Type w'} {β : Type w} [Iterator α m β]
|
||||
def IterM.length {α : Type w} {m : Type w → Type w'} {β : Type w} [Iterator α m β]
|
||||
[IteratorLoop α m m] [Monad m] (it : IterM (α := α) m β) : m (ULift Nat) :=
|
||||
it.fold (init := .up 0) fun acc _ => .up (acc.down + 1)
|
||||
|
||||
/--
|
||||
Steps through the whole iterator, counting the number of outputs emitted.
|
||||
@[inline, inherit_doc IterM.length, deprecated IterM.length (since := "2026-01-28"), expose]
|
||||
def IterM.count := @IterM.length
|
||||
|
||||
**Performance**:
|
||||
|
||||
This function's runtime is linear in the number of steps taken by the iterator.
|
||||
-/
|
||||
@[always_inline, inline, deprecated IterM.count (since := "2025-10-29")]
|
||||
def IterM.size {α : Type w} {m : Type w → Type w'} {β : Type w} [Iterator α m β]
|
||||
[IteratorLoop α m m] [Monad m] (it : IterM (α := α) m β) : m (ULift Nat) :=
|
||||
it.count
|
||||
@[inline, inherit_doc IterM.length, deprecated IterM.length (since := "2025-10-29"), expose]
|
||||
def IterM.size := @IterM.length
|
||||
|
||||
/--
|
||||
Steps through the whole iterator, counting the number of outputs emitted.
|
||||
@@ -983,7 +1009,7 @@ Steps through the whole iterator, counting the number of outputs emitted.
|
||||
|
||||
This function's runtime is linear in the number of steps taken by the iterator.
|
||||
-/
|
||||
@[always_inline, inline, deprecated IterM.count (since := "2025-12-04")]
|
||||
@[always_inline, inline, deprecated IterM.length (since := "2025-12-04")]
|
||||
def IterM.Partial.count {α : Type w} {m : Type w → Type w'} {β : Type w} [Iterator α m β]
|
||||
[IteratorLoop α m m] [Monad m] (it : IterM.Partial (α := α) m β) : m (ULift Nat) :=
|
||||
it.it.fold (init := .up 0) fun acc _ => .up (acc.down + 1)
|
||||
@@ -995,10 +1021,10 @@ Steps through the whole iterator, counting the number of outputs emitted.
|
||||
|
||||
This function's runtime is linear in the number of steps taken by the iterator.
|
||||
-/
|
||||
@[always_inline, inline, deprecated IterM.Partial.count (since := "2025-10-29")]
|
||||
@[always_inline, inline, deprecated IterM.length (since := "2025-10-29")]
|
||||
def IterM.Partial.size {α : Type w} {m : Type w → Type w'} {β : Type w} [Iterator α m β]
|
||||
[IteratorLoop α m m] [Monad m] (it : IterM.Partial (α := α) m β) : m (ULift Nat) :=
|
||||
it.it.count
|
||||
it.it.length
|
||||
|
||||
end Count
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ consumers such as `toList`. They can be used without any proof of termination su
|
||||
or `Productive`, but as they are implemented with the `partial` declaration modifier, they are
|
||||
opaque for the kernel and it is impossible to prove anything about them.
|
||||
-/
|
||||
@[always_inline, inline]
|
||||
@[always_inline, inline, deprecated "The consumers on iterators do not require proofs of termination anymore. For example, use `it.toList` instead of `it.allowNontermination.toList`." (since := "2026-01-28")]
|
||||
def IterM.allowNontermination {α : Type w} {m : Type w → Type w'} {β : Type w}
|
||||
(it : IterM (α := α) m β) : IterM.Partial (α := α) m β :=
|
||||
⟨it⟩
|
||||
|
||||
@@ -29,7 +29,7 @@ consumers such as `toList`. They can be used without any proof of termination su
|
||||
or `Productive`, but as they are implemented with the `partial` declaration modifier, they are
|
||||
opaque for the kernel and it is impossible to prove anything about them.
|
||||
-/
|
||||
@[always_inline, inline]
|
||||
@[always_inline, inline, deprecated "The consumers on iterators do not require proofs of termination anymore. For example, use `it.toList` instead of `it.allowNontermination.toList`." (since := "2026-01-28")]
|
||||
def Iter.allowNontermination {α : Type w} {β : Type w}
|
||||
(it : Iter (α := α) β) : Iter.Partial (α := α) β :=
|
||||
⟨it⟩
|
||||
|
||||
@@ -79,12 +79,15 @@ theorem Iter.toArray_attachWith [Iterator α Id β]
|
||||
simp [Iter.toList_toArray]
|
||||
|
||||
@[simp]
|
||||
theorem Iter.count_attachWith [Iterator α Id β]
|
||||
theorem Iter.length_attachWith [Iterator α Id β]
|
||||
{it : Iter (α := α) β} {hP}
|
||||
[Finite α Id] [IteratorLoop α Id Id]
|
||||
[LawfulIteratorLoop α Id Id] :
|
||||
(it.attachWith P hP).count = it.count := by
|
||||
rw [← Iter.length_toList_eq_count, toList_attachWith]
|
||||
(it.attachWith P hP).length = it.length := by
|
||||
rw [← Iter.length_toList_eq_length, toList_attachWith]
|
||||
simp
|
||||
|
||||
@[deprecated Iter.length_attachWith (since := "2026-01-28")]
|
||||
def Iter.count_attachWith := @Iter.length_attachWith
|
||||
|
||||
end Std
|
||||
|
||||
@@ -722,11 +722,14 @@ end Fold
|
||||
section Count
|
||||
|
||||
@[simp]
|
||||
theorem Iter.count_map {α β β' : Type w} [Iterator α Id β]
|
||||
theorem Iter.length_map {α β β' : Type w} [Iterator α Id β]
|
||||
[IteratorLoop α Id Id] [Finite α Id] [LawfulIteratorLoop α Id Id]
|
||||
{it : Iter (α := α) β} {f : β → β'} :
|
||||
(it.map f).count = it.count := by
|
||||
simp [map_eq_toIter_map_toIterM, count_eq_count_toIterM]
|
||||
(it.map f).length = it.length := by
|
||||
simp [map_eq_toIter_map_toIterM, length_eq_length_toIterM]
|
||||
|
||||
@[deprecated Iter.length_map (since := "2026-01-28")]
|
||||
def Iter.count_map := @Iter.length_map
|
||||
|
||||
end Count
|
||||
|
||||
|
||||
@@ -60,12 +60,15 @@ theorem IterM.map_unattach_toArray_attachWith [Iterator α m β] [Monad m] [Mona
|
||||
simp [-map_unattach_toList_attachWith, -IterM.toArray_toList]
|
||||
|
||||
@[simp]
|
||||
theorem IterM.count_attachWith [Iterator α m β] [Monad m] [Monad n]
|
||||
theorem IterM.length_attachWith [Iterator α m β] [Monad m] [Monad n]
|
||||
{it : IterM (α := α) m β} {hP}
|
||||
[Finite α m] [IteratorLoop α m m] [LawfulMonad m] [LawfulIteratorLoop α m m] :
|
||||
(it.attachWith P hP).count = it.count := by
|
||||
rw [← up_length_toList_eq_count, ← up_length_toList_eq_count,
|
||||
(it.attachWith P hP).length = it.length := by
|
||||
rw [← up_length_toList_eq_length, ← up_length_toList_eq_length,
|
||||
← map_unattach_toList_attachWith (it := it) (P := P) (hP := hP)]
|
||||
simp only [Functor.map_map, List.length_unattach]
|
||||
|
||||
@[deprecated IterM.length_attachWith (since := "2026-01-28")]
|
||||
def IterM.count_attachWith := @IterM.length_attachWith
|
||||
|
||||
end Std
|
||||
|
||||
@@ -1620,18 +1620,21 @@ end Fold
|
||||
section Count
|
||||
|
||||
@[simp]
|
||||
theorem IterM.count_map {α β β' : Type w} {m : Type w → Type w'} [Iterator α m β] [Monad m]
|
||||
theorem IterM.length_map {α β β' : Type w} {m : Type w → Type w'} [Iterator α m β] [Monad m]
|
||||
[IteratorLoop α m m] [Finite α m] [LawfulMonad m] [LawfulIteratorLoop α m m]
|
||||
{it : IterM (α := α) m β} {f : β → β'} :
|
||||
(it.map f).count = it.count := by
|
||||
(it.map f).length = it.length := by
|
||||
induction it using IterM.inductSteps with | step it ihy ihs
|
||||
rw [count_eq_match_step, count_eq_match_step, step_map, bind_assoc]
|
||||
rw [length_eq_match_step, length_eq_match_step, step_map, bind_assoc]
|
||||
apply bind_congr; intro step
|
||||
cases step.inflate using PlausibleIterStep.casesOn
|
||||
· simp [ihy ‹_›]
|
||||
· simp [ihs ‹_›]
|
||||
· simp
|
||||
|
||||
@[deprecated IterM.length_map (since := "2026-01-28")]
|
||||
def IterM.count_map := @IterM.length_map
|
||||
|
||||
end Count
|
||||
|
||||
section AnyAll
|
||||
|
||||
@@ -66,14 +66,14 @@ theorem IterM.toArray_uLift [Iterator α m β] [Monad m] [Monad n] {it : IterM (
|
||||
simp
|
||||
|
||||
@[simp]
|
||||
theorem IterM.count_uLift [Iterator α m β] [Monad m] [Monad n] {it : IterM (α := α) m β}
|
||||
theorem IterM.length_uLift [Iterator α m β] [Monad m] [Monad n] {it : IterM (α := α) m β}
|
||||
[MonadLiftT m (ULiftT n)] [Finite α m] [IteratorLoop α m m]
|
||||
[LawfulMonad m] [LawfulMonad n] [LawfulIteratorLoop α m m]
|
||||
[LawfulMonadLiftT m (ULiftT n)] :
|
||||
(it.uLift n).count =
|
||||
(.up ·.down.down) <$> (monadLift (n := ULiftT n) it.count).run := by
|
||||
(it.uLift n).length =
|
||||
(.up ·.down.down) <$> (monadLift (n := ULiftT n) it.length).run := by
|
||||
induction it using IterM.inductSteps with | step it ihy ihs
|
||||
rw [count_eq_match_step, count_eq_match_step, monadLift_bind, map_eq_pure_bind, step_uLift]
|
||||
rw [length_eq_match_step, length_eq_match_step, monadLift_bind, map_eq_pure_bind, step_uLift]
|
||||
simp only [bind_assoc, ULiftT.run_bind]
|
||||
apply bind_congr; intro step
|
||||
cases step.down.inflate using PlausibleIterStep.casesOn
|
||||
@@ -81,4 +81,7 @@ theorem IterM.count_uLift [Iterator α m β] [Monad m] [Monad n] {it : IterM (α
|
||||
· simp [ihs ‹_›]
|
||||
· simp
|
||||
|
||||
@[deprecated IterM.length_uLift (since := "2026-01-28")]
|
||||
def IterM.count_uLift := @IterM.length_uLift
|
||||
|
||||
end Std
|
||||
|
||||
@@ -47,18 +47,18 @@ theorem Iter.atIdxSlow?_take {α β}
|
||||
[Iterator α Id β] [Productive α Id] {k l : Nat}
|
||||
{it : Iter (α := α) β} :
|
||||
(it.take k).atIdxSlow? l = if l < k then it.atIdxSlow? l else none := by
|
||||
fun_induction it.atIdxSlow? l generalizing k
|
||||
case case1 it it' out h h' =>
|
||||
simp only [atIdxSlow?.eq_def (it := it.take k), step_take, h']
|
||||
induction l, it using Iter.atIdxSlow?.induct_unfolding generalizing k
|
||||
case yield_zero it it' out h h' =>
|
||||
simp only [atIdxSlow?_eq_match (it := it.take k), step_take, h']
|
||||
cases k <;> simp
|
||||
case case2 it it' out h h' l ih =>
|
||||
simp only [Nat.succ_eq_add_one, atIdxSlow?.eq_def (it := it.take k), step_take, h']
|
||||
case yield_succ it it' out h h' l ih =>
|
||||
simp only [Nat.succ_eq_add_one, atIdxSlow?_eq_match (it := it.take k), step_take, h']
|
||||
cases k <;> cases l <;> simp [ih]
|
||||
case case3 l it it' h h' ih =>
|
||||
simp only [atIdxSlow?.eq_def (it := it.take k), step_take, h']
|
||||
case skip_case l it it' h h' ih =>
|
||||
simp only [atIdxSlow?_eq_match (it := it.take k), step_take, h']
|
||||
cases k <;> cases l <;> simp [ih]
|
||||
case case4 l it h h' =>
|
||||
simp only [atIdxSlow?.eq_def (it := it.take k), step_take, h']
|
||||
case done_case l it h h' =>
|
||||
simp only [atIdxSlow?_eq_match (it := it.take k), step_take, h']
|
||||
cases k <;> cases l <;> simp
|
||||
|
||||
@[simp]
|
||||
|
||||
@@ -57,11 +57,14 @@ theorem Iter.toArray_uLift [Iterator α Id β] {it : Iter (α := α) β}
|
||||
simp [-toArray_toList]
|
||||
|
||||
@[simp]
|
||||
theorem Iter.count_uLift [Iterator α Id β] {it : Iter (α := α) β}
|
||||
theorem Iter.length_uLift [Iterator α Id β] {it : Iter (α := α) β}
|
||||
[Finite α Id] [IteratorLoop α Id Id] [LawfulIteratorLoop α Id Id] :
|
||||
it.uLift.count = it.count := by
|
||||
simp only [monadLift, uLift_eq_toIter_uLift_toIterM, count_eq_count_toIterM, toIterM_toIter]
|
||||
rw [IterM.count_uLift]
|
||||
it.uLift.length = it.length := by
|
||||
simp only [monadLift, uLift_eq_toIter_uLift_toIterM, length_eq_length_toIterM, toIterM_toIter]
|
||||
rw [IterM.length_uLift]
|
||||
simp [monadLift]
|
||||
|
||||
@[deprecated Iter.length_uLift (since := "2026-01-28")]
|
||||
def Iter.count_uLift := @Iter.length_uLift
|
||||
|
||||
end Std
|
||||
|
||||
@@ -7,6 +7,7 @@ module
|
||||
|
||||
prelude
|
||||
public import Init.Data.Iterators.Consumers.Access
|
||||
import Init.Data.Iterators.Lemmas.Basic
|
||||
|
||||
namespace Std.Iter
|
||||
open Std.Iterators
|
||||
@@ -21,6 +22,6 @@ public theorem atIdxSlow?_eq_match [Iterator α Id β] [Productive α Id]
|
||||
| n + 1 => it'.atIdxSlow? n
|
||||
| .skip it' => it'.atIdxSlow? n
|
||||
| .done => none) := by
|
||||
fun_induction it.atIdxSlow? n <;> simp_all
|
||||
induction n, it using Iter.atIdxSlow?.induct_unfolding <;> simp_all
|
||||
|
||||
end Std.Iter
|
||||
|
||||
@@ -163,12 +163,14 @@ theorem Iter.getElem?_toList_eq_atIdxSlow? {α β}
|
||||
{it : Iter (α := α) β} {k : Nat} :
|
||||
it.toList[k]? = it.atIdxSlow? k := by
|
||||
induction it using Iter.inductSteps generalizing k with | step it ihy ihs
|
||||
rw [toList_eq_match_step, atIdxSlow?]
|
||||
obtain ⟨step, h⟩ := it.step
|
||||
cases step
|
||||
· cases k <;> simp [ihy h]
|
||||
· simp [ihs h]
|
||||
· simp
|
||||
rw [toList_eq_match_step, atIdxSlow?, WellFounded.extrinsicFix₂_eq_apply]
|
||||
· obtain ⟨step, h⟩ := it.step
|
||||
cases step
|
||||
· cases k <;> simp [ihy h, atIdxSlow?]
|
||||
· simp [ihs h, atIdxSlow?]
|
||||
· simp
|
||||
· apply InvImage.wf
|
||||
exact WellFoundedRelation.wf
|
||||
|
||||
theorem Iter.toList_eq_of_atIdxSlow?_eq {α₁ α₂ β}
|
||||
[Iterator α₁ Id β] [Finite α₁ Id]
|
||||
|
||||
@@ -460,69 +460,90 @@ theorem Iter.foldl_toArray {α β : Type w} {γ : Type x} [Iterator α Id β] [F
|
||||
it.toArray.foldl (init := init) f = it.fold (init := init) f := by
|
||||
rw [fold_eq_foldM, Array.foldl_eq_foldlM, ← Iter.foldlM_toArray]
|
||||
|
||||
theorem Iter.count_eq_count_toIterM {α β : Type w} [Iterator α Id β]
|
||||
theorem Iter.length_eq_length_toIterM {α β : Type w} [Iterator α Id β]
|
||||
[Finite α Id] [IteratorLoop α Id Id.{w}] {it : Iter (α := α) β} :
|
||||
it.count = it.toIterM.count.run.down :=
|
||||
it.length = it.toIterM.length.run.down :=
|
||||
(rfl)
|
||||
|
||||
theorem Iter.count_eq_fold {α β : Type w} [Iterator α Id β]
|
||||
@[deprecated Iter.length_eq_length_toIterM (since := "2026-01-28")]
|
||||
def Iter.count_eq_count_toIterM := @Iter.length_eq_length_toIterM
|
||||
|
||||
theorem Iter.length_eq_fold {α β : Type w} [Iterator α Id β]
|
||||
[Finite α Id] [IteratorLoop α Id Id.{w}] [LawfulIteratorLoop α Id Id.{w}]
|
||||
[IteratorLoop α Id Id.{0}] [LawfulIteratorLoop α Id Id.{0}]
|
||||
{it : Iter (α := α) β} :
|
||||
it.count = it.fold (γ := Nat) (init := 0) (fun acc _ => acc + 1) := by
|
||||
rw [count_eq_count_toIterM, IterM.count_eq_fold, ← fold_eq_fold_toIterM]
|
||||
it.length = it.fold (γ := Nat) (init := 0) (fun acc _ => acc + 1) := by
|
||||
rw [length_eq_length_toIterM, IterM.length_eq_fold, ← fold_eq_fold_toIterM]
|
||||
rw [← fold_hom (f := ULift.down)]
|
||||
simp
|
||||
|
||||
theorem Iter.count_eq_forIn {α β : Type w} [Iterator α Id β]
|
||||
@[deprecated Iter.length_eq_fold (since := "2026-01-28")]
|
||||
def Iter.count_eq_fold := @Iter.length_eq_fold
|
||||
|
||||
theorem Iter.length_eq_forIn {α β : Type w} [Iterator α Id β]
|
||||
[Finite α Id] [IteratorLoop α Id Id.{w}] [LawfulIteratorLoop α Id Id.{w}]
|
||||
[IteratorLoop α Id Id.{0}] [LawfulIteratorLoop α Id Id.{0}]
|
||||
{it : Iter (α := α) β} :
|
||||
it.count = (ForIn.forIn (m := Id) it 0 (fun _ acc => return .yield (acc + 1))).run := by
|
||||
rw [count_eq_fold, forIn_pure_yield_eq_fold, Id.run_pure]
|
||||
it.length = (ForIn.forIn (m := Id) it 0 (fun _ acc => return .yield (acc + 1))).run := by
|
||||
rw [length_eq_fold, forIn_pure_yield_eq_fold, Id.run_pure]
|
||||
|
||||
theorem Iter.count_eq_match_step {α β : Type w} [Iterator α Id β]
|
||||
@[deprecated Iter.length_eq_forIn (since := "2026-01-28")]
|
||||
def Iter.count_eq_forIn := @Iter.length_eq_forIn
|
||||
|
||||
theorem Iter.length_eq_match_step {α β : Type w} [Iterator α Id β]
|
||||
[Finite α Id] [IteratorLoop α Id Id] [LawfulIteratorLoop α Id Id]
|
||||
{it : Iter (α := α) β} :
|
||||
it.count = (match it.step.val with
|
||||
| .yield it' _ => it'.count + 1
|
||||
| .skip it' => it'.count
|
||||
it.length = (match it.step.val with
|
||||
| .yield it' _ => it'.length + 1
|
||||
| .skip it' => it'.length
|
||||
| .done => 0) := by
|
||||
simp only [count_eq_count_toIterM]
|
||||
rw [IterM.count_eq_match_step]
|
||||
simp only [length_eq_length_toIterM]
|
||||
rw [IterM.length_eq_match_step]
|
||||
simp only [bind_pure_comp, id_map', Id.run_bind, Iter.step]
|
||||
cases it.toIterM.step.run.inflate using PlausibleIterStep.casesOn <;> simp
|
||||
|
||||
@[simp]
|
||||
theorem Iter.size_toArray_eq_count {α β : Type w} [Iterator α Id β] [Finite α Id]
|
||||
[IteratorLoop α Id Id] [LawfulIteratorLoop α Id Id]
|
||||
{it : Iter (α := α) β} :
|
||||
it.toArray.size = it.count := by
|
||||
simp only [toArray_eq_toArray_toIterM, count_eq_count_toIterM, Id.run_map,
|
||||
← IterM.up_size_toArray_eq_count]
|
||||
|
||||
@[deprecated Iter.size_toArray_eq_count (since := "2025-10-29")]
|
||||
def Iter.size_toArray_eq_size := @size_toArray_eq_count
|
||||
@[deprecated Iter.length_eq_match_step (since := "2026-01-28")]
|
||||
def Iter.count_eq_match_step := @Iter.length_eq_match_step
|
||||
|
||||
@[simp]
|
||||
theorem Iter.length_toList_eq_count {α β : Type w} [Iterator α Id β] [Finite α Id]
|
||||
theorem Iter.size_toArray_eq_length {α β : Type w} [Iterator α Id β] [Finite α Id]
|
||||
[IteratorLoop α Id Id] [LawfulIteratorLoop α Id Id]
|
||||
{it : Iter (α := α) β} :
|
||||
it.toList.length = it.count := by
|
||||
rw [← toList_toArray, Array.length_toList, size_toArray_eq_count]
|
||||
it.toArray.size = it.length := by
|
||||
simp only [toArray_eq_toArray_toIterM, length_eq_length_toIterM, Id.run_map,
|
||||
← IterM.up_size_toArray_eq_length]
|
||||
|
||||
@[deprecated Iter.length_toList_eq_count (since := "2025-10-29")]
|
||||
def Iter.length_toList_eq_size := @length_toList_eq_count
|
||||
@[deprecated Iter.size_toArray_eq_length (since := "2025-10-29")]
|
||||
def Iter.size_toArray_eq_size := @size_toArray_eq_length
|
||||
|
||||
@[deprecated Iter.size_toArray_eq_length (since := "2026-01-28")]
|
||||
def Iter.size_toArray_eq_count := @size_toArray_eq_length
|
||||
|
||||
@[simp]
|
||||
theorem Iter.length_toListRev_eq_count {α β : Type w} [Iterator α Id β] [Finite α Id]
|
||||
theorem Iter.length_toList_eq_length {α β : Type w} [Iterator α Id β] [Finite α Id]
|
||||
[IteratorLoop α Id Id] [LawfulIteratorLoop α Id Id]
|
||||
{it : Iter (α := α) β} :
|
||||
it.toListRev.length = it.count := by
|
||||
rw [toListRev_eq, List.length_reverse, length_toList_eq_count]
|
||||
it.toList.length = it.length := by
|
||||
rw [← toList_toArray, Array.length_toList, size_toArray_eq_length]
|
||||
|
||||
@[deprecated Iter.length_toListRev_eq_count (since := "2025-10-29")]
|
||||
def Iter.length_toListRev_eq_size := @length_toListRev_eq_count
|
||||
@[deprecated Iter.length_toList_eq_length (since := "2025-10-29")]
|
||||
def Iter.length_toList_eq_size := @length_toList_eq_length
|
||||
|
||||
@[deprecated Iter.length_toList_eq_length (since := "2026-01-28")]
|
||||
def Iter.length_toList_eq_count := @length_toList_eq_length
|
||||
|
||||
@[simp]
|
||||
theorem Iter.length_toListRev_eq_length {α β : Type w} [Iterator α Id β] [Finite α Id]
|
||||
[IteratorLoop α Id Id] [LawfulIteratorLoop α Id Id]
|
||||
{it : Iter (α := α) β} :
|
||||
it.toListRev.length = it.length := by
|
||||
rw [toListRev_eq, List.length_reverse, length_toList_eq_length]
|
||||
|
||||
@[deprecated Iter.length_toListRev_eq_length (since := "2025-10-29")]
|
||||
def Iter.length_toListRev_eq_size := @length_toListRev_eq_length
|
||||
|
||||
@[deprecated Iter.length_toListRev_eq_length (since := "2026-01-28")]
|
||||
def Iter.length_toListRev_eq_count := @length_toListRev_eq_length
|
||||
|
||||
theorem Iter.anyM_eq_forIn {α β : Type w} {m : Type → Type w'} [Iterator α Id β]
|
||||
[Finite α Id] [Monad m] [LawfulMonad m] [IteratorLoop α Id m] [LawfulIteratorLoop α Id m]
|
||||
@@ -930,11 +951,35 @@ theorem Iter.first?_eq_match_step {α β : Type w} [Iterator α Id β] [Iterator
|
||||
generalize it.toIterM.step.run.inflate = s
|
||||
rcases s with ⟨_|_|_, _⟩ <;> simp [Iter.first?_eq_first?_toIterM]
|
||||
|
||||
theorem Iter.first?_eq_head?_toList {α β : Type w} [Iterator α Id β] [IteratorLoop α Id Id]
|
||||
@[simp, grind =]
|
||||
theorem Iter.head?_toList {α β : Type w} [Iterator α Id β] [IteratorLoop α Id Id]
|
||||
[Finite α Id] [LawfulIteratorLoop α Id Id] {it : Iter (α := α) β} :
|
||||
it.first? = it.toList.head? := by
|
||||
it.toList.head? = it.first? := by
|
||||
induction it using Iter.inductSteps with | step it ihy ihs
|
||||
rw [first?_eq_match_step, toList_eq_match_step]
|
||||
cases it.step using PlausibleIterStep.casesOn <;> simp [*]
|
||||
|
||||
theorem Iter.isEmpty_eq_isEmpty_toIterM {α β : Type w} [Iterator α Id β] [IteratorLoop α Id Id]
|
||||
{it : Iter (α := α) β} :
|
||||
it.isEmpty = it.toIterM.isEmpty.run.down := (rfl)
|
||||
|
||||
theorem Iter.isEmpty_eq_match_step {α β : Type w} [Iterator α Id β] [IteratorLoop α Id Id]
|
||||
[Productive α Id] [LawfulIteratorLoop α Id Id] {it : Iter (α := α) β} :
|
||||
it.isEmpty = match it.step.val with
|
||||
| .yield _ _ => false
|
||||
| .skip it' => it'.isEmpty
|
||||
| .done => true := by
|
||||
rw [Iter.isEmpty_eq_isEmpty_toIterM, IterM.isEmpty_eq_match_step]
|
||||
simp only [Id.run_bind, step]
|
||||
generalize it.toIterM.step.run.inflate = s
|
||||
rcases s with ⟨_|_|_, _⟩ <;> simp [Iter.isEmpty_eq_isEmpty_toIterM]
|
||||
|
||||
@[simp, grind =]
|
||||
theorem Iter.isEmpty_toList {α β : Type w} [Iterator α Id β] [IteratorLoop α Id Id]
|
||||
[Finite α Id] [LawfulIteratorLoop α Id Id] {it : Iter (α := α) β} :
|
||||
it.toList.isEmpty = it.isEmpty := by
|
||||
induction it using Iter.inductSteps with | step it ihy ihs
|
||||
rw [isEmpty_eq_match_step, toList_eq_match_step]
|
||||
cases it.step using PlausibleIterStep.casesOn <;> simp [*]
|
||||
|
||||
end Std
|
||||
|
||||
@@ -476,27 +476,33 @@ theorem IterM.drain_eq_map_toArray {α β : Type w} {m : Type w → Type w'} [It
|
||||
it.drain = (fun _ => .unit) <$> it.toList := by
|
||||
simp [IterM.drain_eq_map_toList]
|
||||
|
||||
theorem IterM.count_eq_fold {α β : Type w} {m : Type w → Type w'} [Iterator α m β]
|
||||
theorem IterM.length_eq_fold {α β : Type w} {m : Type w → Type w'} [Iterator α m β]
|
||||
[Finite α m] [Monad m] [LawfulMonad m] [IteratorLoop α m m]
|
||||
{it : IterM (α := α) m β} :
|
||||
it.count = it.fold (init := .up 0) (fun acc _ => .up <| acc.down + 1) :=
|
||||
it.length = it.fold (init := .up 0) (fun acc _ => .up <| acc.down + 1) :=
|
||||
(rfl)
|
||||
|
||||
theorem IterM.count_eq_forIn {α β : Type w} {m : Type w → Type w'} [Iterator α m β]
|
||||
@[deprecated IterM.length_eq_fold (since := "2026-01-28")]
|
||||
def IterM.count_eq_fold := @IterM.length_eq_fold
|
||||
|
||||
theorem IterM.length_eq_forIn {α β : Type w} {m : Type w → Type w'} [Iterator α m β]
|
||||
[Finite α m] [Monad m] [LawfulMonad m] [IteratorLoop α m m]
|
||||
{it : IterM (α := α) m β} :
|
||||
it.count = ForIn.forIn it (.up 0) (fun _ acc => return .yield (.up (acc.down + 1))) :=
|
||||
it.length = ForIn.forIn it (.up 0) (fun _ acc => return .yield (.up (acc.down + 1))) :=
|
||||
(rfl)
|
||||
|
||||
theorem IterM.count_eq_match_step {α β : Type w} {m : Type w → Type w'} [Iterator α m β]
|
||||
@[deprecated IterM.length_eq_forIn (since := "2026-01-28")]
|
||||
def IterM.count_eq_forIn := @IterM.length_eq_forIn
|
||||
|
||||
theorem IterM.length_eq_match_step {α β : Type w} {m : Type w → Type w'} [Iterator α m β]
|
||||
[Finite α m] [Monad m] [LawfulMonad m] [IteratorLoop α m m] [LawfulIteratorLoop α m m]
|
||||
{it : IterM (α := α) m β} :
|
||||
it.count = (do
|
||||
it.length = (do
|
||||
match (← it.step).inflate.val with
|
||||
| .yield it' _ => return .up ((← it'.count).down + 1)
|
||||
| .skip it' => return .up (← it'.count).down
|
||||
| .yield it' _ => return .up ((← it'.length).down + 1)
|
||||
| .skip it' => return .up (← it'.length).down
|
||||
| .done => return .up 0) := by
|
||||
simp only [count_eq_fold]
|
||||
simp only [length_eq_fold]
|
||||
have (acc : Nat) (it' : IterM (α := α) m β) :
|
||||
it'.fold (init := ULift.up acc) (fun acc _ => .up (acc.down + 1)) =
|
||||
(ULift.up <| ·.down + acc) <$>
|
||||
@@ -512,33 +518,45 @@ theorem IterM.count_eq_match_step {α β : Type w} {m : Type w → Type w'} [Ite
|
||||
· simp
|
||||
· simp
|
||||
|
||||
@[deprecated IterM.length_eq_match_step (since := "2026-01-28")]
|
||||
def IterM.count_eq_match_step := @IterM.length_eq_match_step
|
||||
|
||||
@[simp]
|
||||
theorem IterM.up_size_toArray_eq_count {α β : Type w} [Iterator α m β] [Finite α m]
|
||||
theorem IterM.up_size_toArray_eq_length {α β : Type w} [Iterator α m β] [Finite α m]
|
||||
[Monad m] [LawfulMonad m]
|
||||
[IteratorLoop α m m] [LawfulIteratorLoop α m m]
|
||||
{it : IterM (α := α) m β} :
|
||||
(.up <| ·.size) <$> it.toArray = it.count := by
|
||||
rw [toArray_eq_fold, count_eq_fold, ← fold_hom]
|
||||
(.up <| ·.size) <$> it.toArray = it.length := by
|
||||
rw [toArray_eq_fold, length_eq_fold, ← fold_hom]
|
||||
· simp only [List.size_toArray, List.length_nil]; rfl
|
||||
· simp
|
||||
|
||||
@[deprecated IterM.up_size_toArray_eq_length (since := "2026-01-28")]
|
||||
def IterM.up_size_toArray_eq_count := @IterM.up_size_toArray_eq_length
|
||||
|
||||
@[simp]
|
||||
theorem IterM.up_length_toList_eq_count {α β : Type w} [Iterator α m β] [Finite α m]
|
||||
theorem IterM.up_length_toList_eq_length {α β : Type w} [Iterator α m β] [Finite α m]
|
||||
[Monad m] [LawfulMonad m]
|
||||
[IteratorLoop α m m] [LawfulIteratorLoop α m m]
|
||||
{it : IterM (α := α) m β} :
|
||||
(.up <| ·.length) <$> it.toList = it.count := by
|
||||
rw [toList_eq_fold, count_eq_fold, ← fold_hom]
|
||||
(.up <| ·.length) <$> it.toList = it.length := by
|
||||
rw [toList_eq_fold, length_eq_fold, ← fold_hom]
|
||||
· simp only [List.length_nil]; rfl
|
||||
· simp
|
||||
|
||||
@[deprecated IterM.up_length_toList_eq_length (since := "2026-01-28")]
|
||||
def IterM.up_length_toList_eq_count := @IterM.up_length_toList_eq_length
|
||||
|
||||
@[simp]
|
||||
theorem IterM.up_length_toListRev_eq_count {α β : Type w} [Iterator α m β] [Finite α m]
|
||||
theorem IterM.up_length_toListRev_eq_length {α β : Type w} [Iterator α m β] [Finite α m]
|
||||
[Monad m] [LawfulMonad m]
|
||||
[IteratorLoop α m m] [LawfulIteratorLoop α m m]
|
||||
{it : IterM (α := α) m β} :
|
||||
(.up <| ·.length) <$> it.toListRev = it.count := by
|
||||
simp only [toListRev_eq, Functor.map_map, List.length_reverse, up_length_toList_eq_count]
|
||||
(.up <| ·.length) <$> it.toListRev = it.length := by
|
||||
simp only [toListRev_eq, Functor.map_map, List.length_reverse, up_length_toList_eq_length]
|
||||
|
||||
@[deprecated IterM.up_length_toListRev_eq_length (since := "2026-01-28")]
|
||||
def IterM.up_length_toListRev_eq_count := @IterM.up_length_toListRev_eq_length
|
||||
|
||||
theorem IterM.anyM_eq_forIn {α β : Type w} {m : Type w → Type w'} [Iterator α m β]
|
||||
[Finite α m] [Monad m] [LawfulMonad m] [IteratorLoop α m m] [LawfulIteratorLoop α m m]
|
||||
@@ -861,4 +879,24 @@ theorem IterM.first?_eq_match_step {α β : Type w} {m : Type w → Type w'} [Mo
|
||||
simp only [DefaultConsumers.forIn_eq, *]
|
||||
exact IterM.DefaultConsumers.forIn'_eq_forIn' _ this (by simp)
|
||||
|
||||
theorem IterM.isEmpty_eq_match_step {α β : Type w} {m : Type w → Type w'} [Monad m]
|
||||
[Iterator α m β] [IteratorLoop α m m] [LawfulMonad m] [Productive α m]
|
||||
[LawfulIteratorLoop α m m] {it : IterM (α := α) m β} :
|
||||
it.isEmpty = (do
|
||||
match (← it.step).inflate.val with
|
||||
| .yield _ _ => return .up false
|
||||
| .skip it' => it'.isEmpty
|
||||
| .done => return .up true) := by
|
||||
simp only [isEmpty]
|
||||
have := IteratorLoop.wellFounded_of_productive (α := α) (β := β) (m := m)
|
||||
(P := fun _ _ s => s = ForInStep.done (ULift.up false)) (by simp)
|
||||
simp only [LawfulIteratorLoop.lawful _ _ _ _ _ this]
|
||||
rw [IterM.DefaultConsumers.forIn_eq, IterM.DefaultConsumers.forIn'_eq_match_step _ this]
|
||||
simp only [flip, pure_bind]
|
||||
congr
|
||||
ext s
|
||||
split <;> try (simp [*]; done)
|
||||
simp only [DefaultConsumers.forIn_eq, *]
|
||||
exact IterM.DefaultConsumers.forIn'_eq_forIn' _ this (by simp)
|
||||
|
||||
end Std
|
||||
|
||||
@@ -16,6 +16,8 @@ public import Init.Data.List.Find
|
||||
public import Init.Data.List.Impl
|
||||
public import Init.Data.List.Lemmas
|
||||
public import Init.Data.List.MinMax
|
||||
public import Init.Data.List.MinMaxIdx
|
||||
public import Init.Data.List.MinMaxOn
|
||||
public import Init.Data.List.Monadic
|
||||
public import Init.Data.List.Nat
|
||||
public import Init.Data.List.Notation
|
||||
|
||||
@@ -85,7 +85,7 @@ theorem cons_lex_cons_iff : Lex r (a :: l₁) (b :: l₂) ↔ r a b ∨ a = b
|
||||
|
||||
theorem cons_lt_cons_iff [LT α] {a b} {l₁ l₂ : List α} :
|
||||
(a :: l₁) < (b :: l₂) ↔ a < b ∨ a = b ∧ l₁ < l₂ := by
|
||||
dsimp only [instLT, List.lt]
|
||||
simp only [LT.lt, List.lt]
|
||||
simp [cons_lex_cons_iff]
|
||||
|
||||
@[simp] theorem cons_lt_cons_self [LT α] [i₀ : Std.Irrefl (· < · : α → α → Prop)] {l₁ l₂ : List α} :
|
||||
@@ -101,7 +101,7 @@ theorem cons_le_cons_iff [LT α]
|
||||
[i₂ : Std.Trichotomous (· < · : α → α → Prop)]
|
||||
{a b} {l₁ l₂ : List α} :
|
||||
(a :: l₁) ≤ (b :: l₂) ↔ a < b ∨ a = b ∧ l₁ ≤ l₂ := by
|
||||
dsimp only [instLE, instLT, List.le, List.lt]
|
||||
simp only [LE.le, LT.lt, List.le, List.lt]
|
||||
open Classical in
|
||||
simp only [not_cons_lex_cons_iff, ne_eq]
|
||||
constructor
|
||||
|
||||
@@ -29,7 +29,11 @@ open Nat
|
||||
|
||||
/-! ### min? -/
|
||||
|
||||
@[simp] theorem min?_nil [Min α] : ([] : List α).min? = none := rfl
|
||||
@[simp, grind =] theorem min?_nil [Min α] : ([] : List α).min? = none := rfl
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem min?_singleton [Min α] {x : α} : [x].min? = some x :=
|
||||
(rfl)
|
||||
|
||||
-- We don't put `@[simp]` on `min?_cons'`,
|
||||
-- because the definition in terms of `foldl` is not useful for proofs.
|
||||
@@ -39,9 +43,14 @@ theorem min?_cons' [Min α] {xs : List α} : (x :: xs).min? = some (foldl min x
|
||||
(x :: xs).min? = some (xs.min?.elim x (min x)) := by
|
||||
cases xs <;> simp [min?_cons', foldl_assoc]
|
||||
|
||||
@[simp] theorem min?_eq_none_iff {xs : List α} [Min α] : xs.min? = none ↔ xs = [] := by
|
||||
@[simp, grind =] theorem min?_eq_none_iff {xs : List α} [Min α] : xs.min? = none ↔ xs = [] := by
|
||||
cases xs <;> simp [min?]
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem isSome_min?_iff {xs : List α} [Min α] : xs.min?.isSome ↔ xs ≠ [] := by
|
||||
cases xs <;> simp [min?]
|
||||
|
||||
@[grind .]
|
||||
theorem isSome_min?_of_mem {l : List α} [Min α] {a : α} (h : a ∈ l) :
|
||||
l.min?.isSome := by
|
||||
cases l <;> simp_all [min?_cons']
|
||||
@@ -143,7 +152,8 @@ theorem min?_replicate [Min α] [Std.IdempotentOp (min : α → α → α)] {n :
|
||||
| zero => rfl
|
||||
| succ n ih => cases n <;> simp_all [replicate_succ, min?_cons', Std.IdempotentOp.idempotent]
|
||||
|
||||
@[simp] theorem min?_replicate_of_pos [Min α] [MinEqOr α] {n : Nat} {a : α} (h : 0 < n) :
|
||||
@[simp, grind =]
|
||||
theorem min?_replicate_of_pos [Min α] [MinEqOr α] {n : Nat} {a : α} (h : 0 < n) :
|
||||
(replicate n a).min? = some a := by
|
||||
simp [min?_replicate, Nat.ne_of_gt h]
|
||||
|
||||
@@ -160,6 +170,11 @@ theorem foldl_min [Min α] [Std.IdempotentOp (min : α → α → α)] [Std.Asso
|
||||
|
||||
/-! ### min -/
|
||||
|
||||
@[simp, grind =]
|
||||
theorem min_singleton [Min α] {x : α} :
|
||||
[x].min (cons_ne_nil _ _) = x := by
|
||||
(rfl)
|
||||
|
||||
theorem min?_eq_some_min [Min α] : {l : List α} → (hl : l ≠ []) →
|
||||
l.min? = some (l.min hl)
|
||||
| a::as, _ => by simp [List.min, List.min?_cons']
|
||||
@@ -168,15 +183,22 @@ theorem min_eq_get_min? [Min α] : (l : List α) → (hl : l ≠ []) →
|
||||
l.min hl = l.min?.get (isSome_min?_of_ne_nil hl)
|
||||
| a::as, _ => by simp [List.min, List.min?_cons']
|
||||
|
||||
@[simp, grind =]
|
||||
theorem get_min? [Min α] {l : List α} {h : l.min?.isSome} :
|
||||
l.min?.get h = l.min (isSome_min?_iff.mp h) := by
|
||||
simp [min?_eq_some_min (isSome_min?_iff.mp h)]
|
||||
|
||||
theorem min_eq_head {α : Type u} [Min α] {l : List α} (hl : l ≠ [])
|
||||
(h : l.Pairwise (fun a b => min a b = a)) : l.min hl = l.head hl := by
|
||||
apply Option.some.inj
|
||||
rw [← min?_eq_some_min, ← head?_eq_some_head]
|
||||
exact min?_eq_head? h
|
||||
|
||||
@[grind .]
|
||||
theorem min_mem [Min α] [MinEqOr α] {l : List α} (hl : l ≠ []) : l.min hl ∈ l :=
|
||||
min?_mem (min?_eq_some_min hl)
|
||||
|
||||
@[grind .]
|
||||
theorem min_le_of_mem [Min α] [LE α] [Std.IsLinearOrder α] [Std.LawfulOrderMin α]
|
||||
{l : List α} {a : α} (ha : a ∈ l) :
|
||||
l.min (ne_nil_of_mem ha) ≤ a :=
|
||||
@@ -190,7 +212,7 @@ theorem min_eq_iff [Min α] [LE α] {l : List α} [IsLinearOrder α] [LawfulOrde
|
||||
l.min hl = a ↔ a ∈ l ∧ ∀ b, b ∈ l → a ≤ b := by
|
||||
simpa [min?_eq_some_min hl] using (min?_eq_some_iff (xs := l))
|
||||
|
||||
@[simp] theorem min_replicate [Min α] [MinEqOr α] {n : Nat} {a : α} (h : replicate n a ≠ []) :
|
||||
@[simp, grind =] theorem min_replicate [Min α] [MinEqOr α] {n : Nat} {a : α} (h : replicate n a ≠ []) :
|
||||
(replicate n a).min h = a := by
|
||||
have n_pos : 0 < n := Nat.pos_of_ne_zero (fun hn => by simp [hn] at h)
|
||||
simpa [min?_eq_some_min h] using (min?_replicate_of_pos (a := a) n_pos)
|
||||
@@ -202,7 +224,11 @@ theorem foldl_min_eq_min [Min α] [Std.IdempotentOp (min : α → α → α)] [S
|
||||
|
||||
/-! ### max? -/
|
||||
|
||||
@[simp] theorem max?_nil [Max α] : ([] : List α).max? = none := rfl
|
||||
@[simp, grind =] theorem max?_nil [Max α] : ([] : List α).max? = none := rfl
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem max?_singleton [Max α] {x : α} : [x].max? = some x :=
|
||||
(rfl)
|
||||
|
||||
-- We don't put `@[simp]` on `max?_cons'`,
|
||||
-- because the definition in terms of `foldl` is not useful for proofs.
|
||||
@@ -212,9 +238,14 @@ theorem max?_cons' [Max α] {xs : List α} : (x :: xs).max? = some (foldl max x
|
||||
(x :: xs).max? = some (xs.max?.elim x (max x)) := by
|
||||
cases xs <;> simp [max?_cons', foldl_assoc]
|
||||
|
||||
@[simp] theorem max?_eq_none_iff {xs : List α} [Max α] : xs.max? = none ↔ xs = [] := by
|
||||
@[simp, grind =] theorem max?_eq_none_iff {xs : List α} [Max α] : xs.max? = none ↔ xs = [] := by
|
||||
cases xs <;> simp [max?]
|
||||
|
||||
@[simp, grind =]
|
||||
public theorem isSome_max?_iff {xs : List α} [Max α] : xs.max?.isSome ↔ xs ≠ [] := by
|
||||
cases xs <;> simp [max?]
|
||||
|
||||
@[grind .]
|
||||
theorem isSome_max?_of_mem {l : List α} [Max α] {a : α} (h : a ∈ l) :
|
||||
l.max?.isSome := by
|
||||
cases l <;> simp_all [max?_cons']
|
||||
@@ -329,7 +360,8 @@ theorem max?_replicate [Max α] [Std.IdempotentOp (max : α → α → α)] {n :
|
||||
| zero => rfl
|
||||
| succ n ih => cases n <;> simp_all [replicate_succ, max?_cons', Std.IdempotentOp.idempotent]
|
||||
|
||||
@[simp] theorem max?_replicate_of_pos [Max α] [MaxEqOr α] {n : Nat} {a : α} (h : 0 < n) :
|
||||
@[simp, grind =]
|
||||
theorem max?_replicate_of_pos [Max α] [MaxEqOr α] {n : Nat} {a : α} (h : 0 < n) :
|
||||
(replicate n a).max? = some a := by
|
||||
simp [max?_replicate, Nat.ne_of_gt h]
|
||||
|
||||
@@ -346,6 +378,11 @@ theorem foldl_max [Max α] [Std.IdempotentOp (max : α → α → α)] [Std.Asso
|
||||
|
||||
/-! ### max -/
|
||||
|
||||
@[simp, grind =]
|
||||
theorem max_singleton [Max α] {x : α} :
|
||||
[x].max (cons_ne_nil _ _) = x := by
|
||||
(rfl)
|
||||
|
||||
theorem max?_eq_some_max [Max α] : {l : List α} → (hl : l ≠ []) →
|
||||
l.max? = some (l.max hl)
|
||||
| a::as, _ => by simp [List.max, List.max?_cons']
|
||||
@@ -354,12 +391,18 @@ theorem max_eq_get_max? [Max α] : (l : List α) → (hl : l ≠ []) →
|
||||
l.max hl = l.max?.get (isSome_max?_of_ne_nil hl)
|
||||
| a::as, _ => by simp [List.max, List.max?_cons']
|
||||
|
||||
@[simp, grind =]
|
||||
theorem get_max? [Max α] {l : List α} {h : l.max?.isSome} :
|
||||
l.max?.get h = l.max (isSome_max?_iff.mp h) := by
|
||||
simp [max?_eq_some_max (isSome_max?_iff.mp h)]
|
||||
|
||||
theorem max_eq_head {α : Type u} [Max α] {l : List α} (hl : l ≠ [])
|
||||
(h : l.Pairwise (fun a b => max a b = a)) : l.max hl = l.head hl := by
|
||||
apply Option.some.inj
|
||||
rw [← max?_eq_some_max, ← head?_eq_some_head]
|
||||
exact max?_eq_head? h
|
||||
|
||||
@[grind .]
|
||||
theorem max_mem [Max α] [MaxEqOr α] {l : List α} (hl : l ≠ []) : l.max hl ∈ l :=
|
||||
max?_mem (max?_eq_some_max hl)
|
||||
|
||||
@@ -371,12 +414,13 @@ theorem max_eq_iff [Max α] [LE α] {l : List α} [IsLinearOrder α] [LawfulOrde
|
||||
l.max hl = a ↔ a ∈ l ∧ ∀ b, b ∈ l → b ≤ a := by
|
||||
simpa [max?_eq_some_max hl] using (max?_eq_some_iff (xs := l))
|
||||
|
||||
@[grind .]
|
||||
theorem le_max_of_mem [Max α] [LE α] [Std.IsLinearOrder α] [Std.LawfulOrderMax α]
|
||||
{l : List α} {a : α} (ha : a ∈ l) :
|
||||
a ≤ l.max (List.ne_nil_of_mem ha) :=
|
||||
(max?_eq_some_iff.mp (max?_eq_some_max (List.ne_nil_of_mem ha))).right a ha
|
||||
|
||||
@[simp] theorem max_replicate [Max α] [MaxEqOr α] {n : Nat} {a : α} (h : replicate n a ≠ []) :
|
||||
@[simp, grind =] theorem max_replicate [Max α] [MaxEqOr α] {n : Nat} {a : α} (h : replicate n a ≠ []) :
|
||||
(replicate n a).max h = a := by
|
||||
have n_pos : 0 < n := Nat.pos_of_ne_zero (fun hn => by simp [hn] at h)
|
||||
simpa [max?_eq_some_max h] using (max?_replicate_of_pos (a := a) n_pos)
|
||||
|
||||
830
src/Init/Data/List/MinMaxIdx.lean
Normal file
830
src/Init/Data/List/MinMaxIdx.lean
Normal file
@@ -0,0 +1,830 @@
|
||||
/-
|
||||
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Paul Reichert
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Init.Data.List.MinMaxOn
|
||||
import Init.Data.List.MinMaxOn
|
||||
public import Init.Data.List.Pairwise
|
||||
public import Init.Data.Subtype.Order
|
||||
import Init.Data.Order.Lemmas
|
||||
import Init.Data.List.Nat.TakeDrop
|
||||
import Init.Data.Order.Opposite
|
||||
import Init.Data.Nat.Order
|
||||
|
||||
public section
|
||||
|
||||
open Std
|
||||
open scoped OppositeOrderInstances
|
||||
|
||||
set_option doc.verso true
|
||||
set_option linter.missingDocs true
|
||||
set_option linter.listVariables true -- Enforce naming conventions for `List`/`Array`/`Vector` variables.
|
||||
set_option linter.indexVariables true -- Enforce naming conventions for index variables.
|
||||
|
||||
namespace List
|
||||
|
||||
/--
|
||||
Returns the index of an element of the non-empty list {name}`xs` that minimizes {name}`f`.
|
||||
If {given}`x, y` are such that {lean}`f x = f y`, it returns the index of whichever comes first
|
||||
in the list.
|
||||
|
||||
The correctness of this function assumes {name}`β` to be linearly pre-ordered.
|
||||
-/
|
||||
@[inline]
|
||||
def minIdxOn [LE β] [DecidableLE β] (f : α → β) (xs : List α) (h : xs ≠ []) : Nat :=
|
||||
match xs with
|
||||
| y :: ys => go y 0 1 ys
|
||||
where
|
||||
@[specialize]
|
||||
go (x : α) (i : Nat) (j : Nat) (xs : List α) :=
|
||||
match xs with
|
||||
| [] => i
|
||||
| y :: ys =>
|
||||
if f x ≤ f y then
|
||||
go x i (j + 1) ys
|
||||
else
|
||||
go y j (j + 1) ys
|
||||
|
||||
/--
|
||||
Returns the index of an element of {name}`xs` that minimizes {name}`f`. If {given}`x, y`
|
||||
are such that {lean}`f x = f y`, it returns the index of whichever comes first in the list.
|
||||
Returns {name}`none` if the list is empty.
|
||||
|
||||
The correctness of this function assumes {name}`β` to be linearly pre-ordered.
|
||||
-/
|
||||
@[inline]
|
||||
def minIdxOn? [LE β] [DecidableLE β] (f : α → β) (xs : List α) : Option Nat :=
|
||||
match xs with
|
||||
| [] => none
|
||||
| y :: ys => some ((y :: ys).minIdxOn f (nomatch ·))
|
||||
|
||||
/--
|
||||
Returns the index of an element of the non-empty list {name}`xs` that maximizes {name}`f`.
|
||||
If {given}`x, y` are such that {lean}`f x = f y`, it returns the index of whichever comes first
|
||||
in the list.
|
||||
|
||||
The correctness of this function assumes {name}`β` to be linearly pre-ordered.
|
||||
-/
|
||||
@[inline]
|
||||
def maxIdxOn [LE β] [DecidableLE β] (f : α → β) (xs : List α) (h : xs ≠ []) : Nat :=
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
xs.minIdxOn f h
|
||||
|
||||
/--
|
||||
Returns the index of an element of {name}`xs` that maximizes {name}`f`. If {given}`x, y`
|
||||
are such that {lean}`f x = f y`, it returns the index of whichever comes first in the list.
|
||||
Returns {name}`none` if the list is empty.
|
||||
|
||||
The correctness of this function assumes {name}`β` to be linearly pre-ordered.
|
||||
-/
|
||||
@[inline]
|
||||
def maxIdxOn? [LE β] [DecidableLE β] (f : α → β) (xs : List α) : Option Nat :=
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
xs.minIdxOn? f
|
||||
|
||||
protected theorem maxIdxOn_eq_minIdxOn {le : LE β} {_ : DecidableLE β} {f : α → β}
|
||||
{xs : List α} {h} :
|
||||
xs.maxIdxOn f h = (letI := le.opposite; xs.minIdxOn f h) :=
|
||||
(rfl)
|
||||
|
||||
private theorem minIdxOn.go_lt_length_add [LE β] [DecidableLE β] {f : α → β} {x : α}
|
||||
{i j : Nat} {xs : List α} (h : i < j) :
|
||||
List.minIdxOn.go f x i j xs < xs.length + j := by
|
||||
induction xs generalizing x i j
|
||||
· simp [go, h]
|
||||
· rename_i y ys ih
|
||||
simp only [go, length_cons, Nat.add_assoc, Nat.add_comm 1]
|
||||
split
|
||||
· exact ih (Nat.lt_succ_of_lt ‹i < j›)
|
||||
· exact ih (Nat.lt_succ_self j)
|
||||
|
||||
private theorem minIdxOn.go_eq_of_forall_le [LE β] [DecidableLE β] {f : α → β}
|
||||
{x : α} {i j : Nat} {xs : List α} (h : ∀ y ∈ xs, f x ≤ f y) :
|
||||
List.minIdxOn.go f x i j xs = i := by
|
||||
induction xs generalizing x i j
|
||||
· simp [go]
|
||||
· rename_i y ys ih
|
||||
simp only [go]
|
||||
split
|
||||
· apply ih
|
||||
simp_all
|
||||
· simp_all
|
||||
|
||||
private theorem exists_getElem_eq_of_drop_eq_cons {xs : List α} {k : Nat} {y : α} {ys : List α}
|
||||
(h : xs.drop k = y :: ys) : ∃ hlt : k < xs.length, xs[k] = y := by
|
||||
have hlt : k < xs.length := by
|
||||
false_or_by_contra
|
||||
have : drop k xs = [] := drop_of_length_le (by omega)
|
||||
simp [this] at h
|
||||
refine ⟨hlt, ?_⟩
|
||||
have := take_append_drop k xs
|
||||
rw [h] at this
|
||||
simp +singlePass only [← this]
|
||||
rw [getElem_append_right (length_take_le _ _)]
|
||||
simp [length_take_of_le (Nat.le_of_lt hlt)]
|
||||
|
||||
private theorem take_succ_eq_append_of_drop_eq_cons {xs : List α} {k : Nat} {y : α}
|
||||
{ys : List α} (h : xs.drop k = y :: ys) : xs.take (k + 1) = xs.take k ++ [y] := by
|
||||
obtain ⟨hlt, rfl⟩ := exists_getElem_eq_of_drop_eq_cons h
|
||||
rw [take_succ_eq_append_getElem hlt]
|
||||
|
||||
private theorem minIdxOn_eq_go_drop [LE β] [DecidableLE β] [IsLinearPreorder β] {f : α → β}
|
||||
{xs : List α} (h : xs ≠ []) {k : Nat} :
|
||||
∃ (i : Nat) (hlt : i < xs.length), i ≤ k ∧ xs[i] = (xs.take (k + 1)).minOn f (by simpa) ∧
|
||||
xs.minIdxOn f h = List.minIdxOn.go f ((xs.take (k + 1)).minOn f (by cases xs <;> simp_all)) i (k + 1) (xs.drop (k + 1)) := by
|
||||
match xs with
|
||||
| y :: ys =>
|
||||
simp only [drop_succ_cons]
|
||||
induction k
|
||||
· simp [minIdxOn]
|
||||
· rename_i k ih
|
||||
specialize ih
|
||||
obtain ⟨i, hlt, hi, ih⟩ := ih
|
||||
simp only [ih, ← drop_drop]
|
||||
simp only [length_cons] at hlt
|
||||
match h : drop k ys with
|
||||
| [] =>
|
||||
have : ys.length ≤ k := by simp_all
|
||||
simp [drop_nil, minIdxOn.go, take_of_length_le, hi, ih, hlt, this, Nat.le_succ_of_le]
|
||||
| z :: zs =>
|
||||
simp only [minIdxOn.go]
|
||||
have : take (k + 1 + 1) (y :: ys) = take (k + 1) (y :: ys) ++ [z] := by apply take_succ_eq_append_of_drop_eq_cons ‹_›
|
||||
simp only [this, List.minOn_append (xs := take (k + 1) (y :: ys)) (by simp) (cons_ne_nil _ _)]
|
||||
simp only [take_succ_cons] at this
|
||||
split
|
||||
· simp only [List.minOn_singleton, minOn_eq_left, length_cons, *]
|
||||
exact ⟨i, by omega, Nat.le_succ_of_le ‹i ≤ k›, by simp [ih], rfl⟩
|
||||
· simp only [List.minOn_singleton, not_false_eq_true, minOn_eq_right, length_cons, *]
|
||||
obtain ⟨hlt, rfl⟩ := exists_getElem_eq_of_drop_eq_cons h
|
||||
exact ⟨k + 1, by omega, Nat.le_refl _, by simp, rfl⟩
|
||||
|
||||
@[simp]
|
||||
protected theorem minIdxOn_nil_eq_iff_true [LE β] [DecidableLE β] {f : α → β} {x : Nat}
|
||||
(h : [] ≠ []) : ([] : List α).minIdxOn f h = x ↔ True :=
|
||||
nomatch h
|
||||
|
||||
protected theorem minIdxOn_nil_eq_iff_false [LE β] [DecidableLE β] {f : α → β} {x : Nat}
|
||||
(h : [] ≠ []) : ([] : List α).minIdxOn f h = x ↔ False :=
|
||||
nomatch h
|
||||
|
||||
@[simp]
|
||||
protected theorem minIdxOn_singleton [LE β] [DecidableLE β] {x : α} {f : α → β} :
|
||||
[x].minIdxOn f (of_decide_eq_false rfl) = 0 := by
|
||||
rw [minIdxOn, minIdxOn.go]
|
||||
|
||||
@[simp]
|
||||
protected theorem minIdxOn_lt_length [LE β] [DecidableLE β] {f : α → β} {xs : List α}
|
||||
(h : xs ≠ []) : xs.minIdxOn f h < xs.length := by
|
||||
rw [minIdxOn.eq_def]
|
||||
split
|
||||
simp [minIdxOn.go_lt_length_add]
|
||||
|
||||
protected theorem minIdxOn_le_of_apply_getElem_le_apply_minOn [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{f : α → β} {xs : List α} (h : xs ≠ [])
|
||||
{k : Nat} (hi : k < xs.length) (hle : f xs[k] ≤ f (xs.minOn f h)) :
|
||||
xs.minIdxOn f h ≤ k := by
|
||||
obtain ⟨i, _, hi, _, h'⟩ := minIdxOn_eq_go_drop (f := f) h (k := k)
|
||||
rw [h']
|
||||
refine Nat.le_trans ?_ hi
|
||||
apply Nat.le_of_eq
|
||||
apply minIdxOn.go_eq_of_forall_le
|
||||
intro y hy
|
||||
refine le_trans (List.apply_minOn_le_of_mem (y := xs[k]) (by rw [mem_take_iff_getElem]; exact ⟨k, by omega, rfl⟩)) ?_
|
||||
refine le_trans hle ?_
|
||||
apply List.apply_minOn_le_of_mem
|
||||
apply mem_of_mem_drop
|
||||
exact hy
|
||||
|
||||
protected theorem apply_minOn_lt_apply_getElem_of_lt_minIdxOn [LE β] [DecidableLE β] [LT β] [IsLinearPreorder β]
|
||||
[LawfulOrderLT β]
|
||||
{f : α → β} {xs : List α} (h : xs ≠ [])
|
||||
{k : Nat} (hk : k < xs.minIdxOn f h) :
|
||||
f (xs.minOn f h) < f (xs[k]'(by haveI := List.minIdxOn_lt_length (f := f) h; omega)) := by
|
||||
simp only [← not_le] at hk ⊢
|
||||
apply hk.imp
|
||||
apply List.minIdxOn_le_of_apply_getElem_le_apply_minOn
|
||||
|
||||
@[simp]
|
||||
protected theorem getElem_minIdxOn [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{f : α → β} {xs : List α} (h : xs ≠ []) :
|
||||
xs[xs.minIdxOn f h] = xs.minOn f h := by
|
||||
obtain ⟨i, hlt, hi, heq, h'⟩ := minIdxOn_eq_go_drop (f := f) h (k := xs.length)
|
||||
simp only [drop_eq_nil_of_le (as := xs) (i := xs.length + 1) (by omega), minIdxOn.go] at h'
|
||||
simp [h', heq, take_of_length_le (l := xs) (i := xs.length + 1) (by omega)]
|
||||
|
||||
protected theorem le_minIdxOn_of_apply_getElem_lt_apply_getElem [LE β] [DecidableLE β] [LT β] [IsLinearPreorder β]
|
||||
[LawfulOrderLT β] {f : α → β} {xs : List α} (h : xs ≠ []) {i : Nat} (hi : i < xs.length)
|
||||
(hi' : ∀ j, (_ : j < i) → f xs[i] < f xs[j]) :
|
||||
i ≤ xs.minIdxOn f h := by
|
||||
false_or_by_contra; rename_i hgt
|
||||
simp only [not_le] at hgt
|
||||
specialize hi' _ hgt
|
||||
simp only [List.getElem_minIdxOn] at hi'
|
||||
apply (not_le.mpr hi').elim
|
||||
apply List.apply_minOn_le_of_mem
|
||||
simp
|
||||
|
||||
protected theorem minIdxOn_le_of_apply_getElem_le_apply_getElem [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{f : α → β} {xs : List α} (h : xs ≠ []) {i : Nat} (hi : i < xs.length)
|
||||
(hi' : ∀ j, (_ : j < xs.length) → f xs[i] ≤ f xs[j]) :
|
||||
xs.minIdxOn f h ≤ i := by
|
||||
apply List.minIdxOn_le_of_apply_getElem_le_apply_minOn h hi
|
||||
simp only [List.le_apply_minOn_iff, List.mem_iff_getElem]
|
||||
rintro _ ⟨j, hj, rfl⟩
|
||||
exact hi' _ hj
|
||||
|
||||
protected theorem minIdxOn_eq_iff [LE β] [DecidableLE β] [LT β] [IsLinearPreorder β]
|
||||
[LawfulOrderLT β]
|
||||
{f : α → β} {xs : List α} (h : xs ≠ []) {i : Nat} :
|
||||
xs.minIdxOn f h = i ↔ ∃ (h : i < xs.length),
|
||||
(∀ j, (_ : j < xs.length) → f xs[i] ≤ f xs[j]) ∧
|
||||
(∀ j, (_ : j < i) → f xs[i] < f xs[j]) := by
|
||||
apply Iff.intro
|
||||
· rintro rfl
|
||||
simp only [List.getElem_minIdxOn]
|
||||
refine ⟨List.minIdxOn_lt_length h, ?_, ?_⟩
|
||||
· simp [List.apply_minOn_le_of_mem]
|
||||
· exact fun j hj => List.apply_minOn_lt_apply_getElem_of_lt_minIdxOn h hj
|
||||
· rintro ⟨hi, h₁, h₂⟩
|
||||
apply le_antisymm
|
||||
· apply List.minIdxOn_le_of_apply_getElem_le_apply_getElem h hi h₁
|
||||
· apply List.le_minIdxOn_of_apply_getElem_lt_apply_getElem h hi h₂
|
||||
|
||||
protected theorem minIdxOn_eq_iff_eq_minOn [LE β] [DecidableLE β] [LT β] [IsLinearPreorder β]
|
||||
[LawfulOrderLT β] {f : α → β} {xs : List α} (h : xs ≠ []) {i : Nat} :
|
||||
xs.minIdxOn f h = i ↔ ∃ hi : i < xs.length, xs[i] = xs.minOn f h ∧
|
||||
∀ (j : Nat) (hj : j < i), f (xs.minOn f h) < f xs[j] := by
|
||||
apply Iff.intro
|
||||
· rintro rfl
|
||||
refine ⟨List.minIdxOn_lt_length h, List.getElem_minIdxOn h, ?_⟩
|
||||
intro j hj
|
||||
exact List.apply_minOn_lt_apply_getElem_of_lt_minIdxOn h hj
|
||||
· rintro ⟨hlt, heq, h'⟩
|
||||
specialize h' (xs.minIdxOn f h)
|
||||
simp only [List.getElem_minIdxOn] at h'
|
||||
apply le_antisymm
|
||||
· apply List.minIdxOn_le_of_apply_getElem_le_apply_minOn h hlt
|
||||
simp [heq, le_refl]
|
||||
· simpa [lt_irrefl] using h'
|
||||
|
||||
private theorem minIdxOn.go_eq
|
||||
[LE β] [DecidableLE β] [IsLinearPreorder β] {x : α} {xs : List α} {f : α → β} :
|
||||
List.minIdxOn.go f x i j xs =
|
||||
if h : xs = [] then i
|
||||
else if f x ≤ f (xs.minOn f h) then i
|
||||
else (xs.minIdxOn f h) + j := by
|
||||
open scoped Classical.Order in
|
||||
induction xs generalizing x i j
|
||||
· simp [go]
|
||||
· rename_i y ys ih
|
||||
simp only [go, reduceCtorEq, ↓reduceDIte]
|
||||
split
|
||||
· rw [ih]
|
||||
split
|
||||
· simp [*]
|
||||
· simp only [List.minOn_cons, ↓reduceDIte, le_apply_minOn_iff, true_and, *]
|
||||
split
|
||||
· rfl
|
||||
· rename_i hlt
|
||||
simp only [minIdxOn]
|
||||
split
|
||||
simp only [ih, reduceCtorEq, ↓reduceDIte]
|
||||
rw [if_neg]
|
||||
· simp [minIdxOn, Nat.add_assoc, Nat.add_comm 1]
|
||||
· simp only [not_le] at hlt ⊢
|
||||
exact lt_of_lt_of_le hlt ‹_›
|
||||
· rename_i hlt
|
||||
rw [if_neg]
|
||||
· rw [minIdxOn, ih]
|
||||
split
|
||||
· simp [*, go]
|
||||
· simp only [↓reduceDIte, *]
|
||||
split
|
||||
· simp
|
||||
· simp only [Nat.add_assoc, Nat.add_comm 1]
|
||||
· simp only [not_le] at hlt ⊢
|
||||
exact lt_of_le_of_lt (List.apply_minOn_le_of_mem mem_cons_self) hlt
|
||||
|
||||
protected theorem minIdxOn_cons
|
||||
[LE β] [DecidableLE β] [IsLinearPreorder β] {x : α} {xs : List α} {f : α → β} :
|
||||
(x :: xs).minIdxOn f (by exact of_decide_eq_false rfl) =
|
||||
if h : xs = [] then 0
|
||||
else if f x ≤ f (xs.minOn f h) then 0
|
||||
else (xs.minIdxOn f h) + 1 := by
|
||||
simpa [List.minIdxOn] using minIdxOn.go_eq
|
||||
|
||||
protected theorem minIdxOn_eq_zero_iff [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{xs : List α} {f : α → β} (h : xs ≠ []) :
|
||||
xs.minIdxOn f h = 0 ↔ ∀ x ∈ xs, f (xs.head h) ≤ f x := by
|
||||
rw [minIdxOn.eq_def]
|
||||
split
|
||||
rename_i y ys _
|
||||
simp only [mem_cons, head_cons, forall_eq_or_imp, le_refl, true_and]
|
||||
apply Iff.intro
|
||||
· intro h
|
||||
cases ys
|
||||
· simp
|
||||
· intro a ha
|
||||
refine le_trans ?_ (List.apply_minOn_le_of_mem ha)
|
||||
simpa [minIdxOn.go_eq] using h
|
||||
· intro h
|
||||
cases ys
|
||||
· simp [minIdxOn.go]
|
||||
· simpa [minIdxOn.go_eq, List.le_apply_minOn_iff] using h
|
||||
|
||||
section Append
|
||||
|
||||
/-!
|
||||
The proof of {name}`List.minOn_append` uses associativity of {name}`minOn` and applies {name}`foldl_assoc`.
|
||||
The proof of {name (scope := "Init.Data.List.MinMaxIdx")}`minIdxOn_append` is analogous, but the
|
||||
aggregation operation, {name (scope := "Init.Data.List.MinMaxIdx")}`combineMinIdxOn`, depends on
|
||||
the length of the lists to combine. After proving associativity of the aggregation operation,
|
||||
the proof closely follows the proof of {name}`foldl_assoc`.
|
||||
-/
|
||||
|
||||
private def combineMinIdxOn [LE β] [DecidableLE β]
|
||||
(f : α → β) {xs ys : List α} (i j : Nat) (hi : i < xs.length) (hj : j < ys.length) : Nat :=
|
||||
if f xs[i] ≤ f ys[j] then
|
||||
i
|
||||
else
|
||||
xs.length + j
|
||||
|
||||
private theorem combineMinIdxOn_lt [LE β] [DecidableLE β]
|
||||
(f : α → β) {xs ys : List α} {i j : Nat} (hi : i < xs.length) (hj : j < ys.length) :
|
||||
combineMinIdxOn f i j hi hj < (xs ++ ys).length := by
|
||||
simp only [combineMinIdxOn]
|
||||
split <;> (simp; omega)
|
||||
|
||||
private theorem combineMinIdxOn_assoc [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{xs ys zs : List α} {i j k : Nat} {f : α → β} (hi : i < xs.length) (hj : j < ys.length)
|
||||
(hk : k < zs.length) :
|
||||
combineMinIdxOn f (combineMinIdxOn f i j _ _) k
|
||||
(combineMinIdxOn_lt f hi hj) hk = combineMinIdxOn f i (combineMinIdxOn f j k _ _) hi (combineMinIdxOn_lt f hj hk) := by
|
||||
open scoped Classical.Order in
|
||||
simp only [combineMinIdxOn]
|
||||
split
|
||||
· rw [getElem_append_left (by omega)]
|
||||
split
|
||||
· split
|
||||
· rw [getElem_append_left (by omega)]
|
||||
simp [*]
|
||||
· rw [getElem_append_right (by omega)]
|
||||
simp [*]
|
||||
· split
|
||||
· have := le_trans ‹f xs[i] ≤ f ys[j]› ‹f ys[j] ≤ f zs[k]›
|
||||
contradiction
|
||||
· rw [getElem_append_right (by omega)]
|
||||
simp [*, Nat.add_assoc]
|
||||
· rw [getElem_append_right (by omega)]
|
||||
simp only [Nat.add_sub_cancel_left]
|
||||
split
|
||||
· rw [getElem_append_left (by omega), if_neg ‹_›]
|
||||
· rename_i h₁ h₂
|
||||
rw [not_le] at h₁ h₂
|
||||
rw [getElem_append_right (by omega)]
|
||||
simp only [Nat.add_sub_cancel_left]
|
||||
have := not_le.mpr <| lt_trans h₂ h₁
|
||||
simp [*, Nat.add_assoc]
|
||||
|
||||
private theorem minIdxOn_cons_aux [LE β] [DecidableLE β]
|
||||
[IsLinearPreorder β] {x : α} {xs : List α} {f : α → β} (hxs : xs ≠ []) :
|
||||
(x :: xs).minIdxOn f (by simp) =
|
||||
combineMinIdxOn f _ _
|
||||
(List.minIdxOn_lt_length (f := f) (cons_ne_nil x []))
|
||||
(List.minIdxOn_lt_length (f := f) hxs) := by
|
||||
rw [minIdxOn, combineMinIdxOn]
|
||||
simp [minIdxOn.go_eq, hxs, List.getElem_minIdxOn, Nat.add_comm 1]
|
||||
|
||||
private theorem minIdxOn_append_aux [LE β] [DecidableLE β]
|
||||
[IsLinearPreorder β] {xs ys : List α} {f : α → β} (hxs : xs ≠ []) (hys : ys ≠ []) :
|
||||
(xs ++ ys).minIdxOn f (by simp [hxs]) =
|
||||
combineMinIdxOn f _ _
|
||||
(List.minIdxOn_lt_length (f := f) hxs)
|
||||
(List.minIdxOn_lt_length (f := f) hys) := by
|
||||
induction xs
|
||||
· contradiction
|
||||
· rename_i x xs ih
|
||||
match xs with
|
||||
| [] => simp [minIdxOn_cons_aux (xs := ys) ‹_›]
|
||||
| z :: zs =>
|
||||
simp +singlePass only [cons_append]
|
||||
simp only [minIdxOn_cons_aux (xs := z :: zs ++ ys) (by simp), ih (by simp),
|
||||
minIdxOn_cons_aux (xs := z :: zs) (by simp), combineMinIdxOn_assoc]
|
||||
|
||||
protected theorem minIdxOn_append [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{xs ys : List α} {f : α → β} (hxs : xs ≠ []) (hys : ys ≠ []) :
|
||||
(xs ++ ys).minIdxOn f (by simp [hxs]) =
|
||||
if f (xs.minOn f hxs) ≤ f (ys.minOn f hys) then
|
||||
xs.minIdxOn f hxs
|
||||
else
|
||||
xs.length + ys.minIdxOn f hys := by
|
||||
simp [minIdxOn_append_aux hxs hys, combineMinIdxOn, List.getElem_minIdxOn]
|
||||
|
||||
end Append
|
||||
|
||||
protected theorem left_le_minIdxOn_append [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{xs ys : List α} {f : α → β} (h : xs ≠ []) :
|
||||
xs.minIdxOn f h ≤ (xs ++ ys).minIdxOn f (by simp [h]) := by
|
||||
by_cases hys : ys = []
|
||||
· simp [hys]
|
||||
· rw [List.minIdxOn_append h hys]
|
||||
split
|
||||
· apply Nat.le_refl
|
||||
· have := List.minIdxOn_lt_length (f := f) h
|
||||
omega
|
||||
|
||||
protected theorem minIdxOn_take_le [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{xs : List α} {f : α → β} {i : Nat} (h : xs.take i ≠ []) :
|
||||
(xs.take i).minIdxOn f h ≤ xs.minIdxOn f (List.ne_nil_of_take_ne_nil h) := by
|
||||
have := take_append_drop i xs
|
||||
conv => rhs; simp +singlePass only [← this]
|
||||
apply List.left_le_minIdxOn_append
|
||||
|
||||
@[simp]
|
||||
protected theorem minIdxOn_replicate [LE β] [DecidableLE β] [Refl (α := β) (· ≤ ·)]
|
||||
{n : Nat} {a : α} {f : α → β} (h : replicate n a ≠ []) :
|
||||
(replicate n a).minIdxOn f h = 0 := by
|
||||
match n with
|
||||
| 0 => simp at h
|
||||
| n + 1 =>
|
||||
simp only [minIdxOn, replicate_succ]
|
||||
generalize 1 = j
|
||||
induction n generalizing j
|
||||
· simp [minIdxOn.go]
|
||||
· simp only [replicate_succ, minIdxOn.go] at *
|
||||
split
|
||||
· simp [*]
|
||||
· have := le_refl (f a)
|
||||
contradiction
|
||||
|
||||
@[simp]
|
||||
protected theorem maxIdxOn_nil_eq_iff_true [LE β] [DecidableLE β] {f : α → β} {x : Nat}
|
||||
(h : [] ≠ []) : ([] : List α).maxIdxOn f h = x ↔ True :=
|
||||
nomatch h
|
||||
|
||||
protected theorem maxIdxOn_nil_eq_iff_false [LE β] [DecidableLE β] {f : α → β} {x : Nat}
|
||||
(h : [] ≠ []) : ([] : List α).maxIdxOn f h = x ↔ False :=
|
||||
nomatch h
|
||||
|
||||
@[simp]
|
||||
protected theorem maxIdxOn_singleton [LE β] [DecidableLE β] {x : α} {f : α → β} :
|
||||
[x].maxIdxOn f (of_decide_eq_false rfl) = 0 :=
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
List.minIdxOn_singleton
|
||||
|
||||
@[simp]
|
||||
protected theorem maxIdxOn_lt_length [LE β] [DecidableLE β] {f : α → β} {xs : List α}
|
||||
(h : xs ≠ []) : xs.maxIdxOn f h < xs.length :=
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
List.minIdxOn_lt_length h
|
||||
|
||||
protected theorem maxIdxOn_le_of_apply_getElem_le_apply_maxOn [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{f : α → β} {xs : List α} (h : xs ≠ [])
|
||||
{k : Nat} (hi : k < xs.length) (hle : f (xs.maxOn f h) ≤ f xs[k]) :
|
||||
xs.maxIdxOn f h ≤ k := by
|
||||
simp only [List.maxIdxOn_eq_minIdxOn, List.maxOn_eq_minOn] at hle ⊢
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
exact List.minIdxOn_le_of_apply_getElem_le_apply_minOn h hi (by simpa [LE.le_opposite_iff] using hle)
|
||||
|
||||
protected theorem apply_maxOn_lt_apply_getElem_of_lt_maxIdxOn [LE β] [DecidableLE β] [LT β] [IsLinearPreorder β]
|
||||
[LawfulOrderLT β]
|
||||
{f : α → β} {xs : List α} (h : xs ≠ [])
|
||||
{k : Nat} (hk : k < xs.maxIdxOn f h) :
|
||||
f (xs[k]'(by haveI := List.maxIdxOn_lt_length (f := f) h; omega)) < f (xs.maxOn f h) := by
|
||||
simp only [List.maxIdxOn_eq_minIdxOn, List.maxOn_eq_minOn] at hk ⊢
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
letI : LT β := LT.opposite inferInstance
|
||||
simpa [LT.lt_opposite_iff] using List.apply_minOn_lt_apply_getElem_of_lt_minIdxOn (f := f) h hk
|
||||
|
||||
@[simp]
|
||||
protected theorem getElem_maxIdxOn [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{f : α → β} {xs : List α} (h : xs ≠ []) :
|
||||
xs[xs.maxIdxOn f h] = xs.maxOn f h := by
|
||||
simp only [List.maxIdxOn_eq_minIdxOn, List.maxOn_eq_minOn]
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
exact List.getElem_minIdxOn h
|
||||
|
||||
protected theorem le_maxIdxOn_of_apply_getElem_lt_apply_getElem [LE β] [DecidableLE β] [LT β]
|
||||
[IsLinearPreorder β] [LawfulOrderLT β] {f : α → β} {xs : List α} (h : xs ≠ []) {i : Nat}
|
||||
(hi : i < xs.length) (hi' : ∀ j, (_ : j < i) → f xs[j] < f xs[i]) :
|
||||
i ≤ xs.maxIdxOn f h := by
|
||||
simp only [List.maxIdxOn_eq_minIdxOn]
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
letI : LT β := LT.opposite inferInstance
|
||||
simpa [LE.le_opposite_iff] using List.le_minIdxOn_of_apply_getElem_lt_apply_getElem h hi
|
||||
(by simpa [LT.lt_opposite_iff] using hi')
|
||||
|
||||
protected theorem maxIdxOn_le_of_apply_getElem_le_apply_getElem [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{f : α → β} {xs : List α} (h : xs ≠ []) {i : Nat} (hi : i < xs.length)
|
||||
(hi' : ∀ j, (_ : j < xs.length) → f xs[j] ≤ f xs[i]) :
|
||||
xs.maxIdxOn f h ≤ i := by
|
||||
simp only [List.maxIdxOn_eq_minIdxOn]
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
simpa [LE.le_opposite_iff] using List.minIdxOn_le_of_apply_getElem_le_apply_getElem (f := f) h hi
|
||||
(by simpa [LE.le_opposite_iff] using hi')
|
||||
|
||||
protected theorem maxIdxOn_eq_iff [LE β] [DecidableLE β] [LT β] [IsLinearPreorder β]
|
||||
[LawfulOrderLT β]
|
||||
{f : α → β} {xs : List α} (h : xs ≠ []) {i : Nat} :
|
||||
xs.maxIdxOn f h = i ↔ ∃ (h : i < xs.length),
|
||||
(∀ j, (_ : j < xs.length) → f xs[j] ≤ f xs[i]) ∧
|
||||
(∀ j, (_ : j < i) → f xs[j] < f xs[i]) := by
|
||||
simp only [List.maxIdxOn_eq_minIdxOn]
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
letI : LT β := LT.opposite inferInstance
|
||||
simpa [LE.le_opposite_iff, LT.lt_opposite_iff] using List.minIdxOn_eq_iff (f := f) h
|
||||
|
||||
protected theorem maxIdxOn_eq_iff_eq_maxOn [LE β] [DecidableLE β] [LT β] [IsLinearPreorder β]
|
||||
[LawfulOrderLT β] {f : α → β} {xs : List α} (h : xs ≠ []) {i : Nat} :
|
||||
xs.maxIdxOn f h = i ↔ ∃ hi : i < xs.length, xs[i] = xs.maxOn f h ∧
|
||||
∀ (j : Nat) (hj : j < i), f xs[j] < f (xs.maxOn f h) := by
|
||||
simp only [List.maxIdxOn_eq_minIdxOn, List.maxOn_eq_minOn]
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
letI : LT β := LT.opposite inferInstance
|
||||
simpa [LT.lt_opposite_iff] using List.minIdxOn_eq_iff_eq_minOn (f := f) h
|
||||
|
||||
protected theorem maxIdxOn_cons
|
||||
[LE β] [DecidableLE β] [IsLinearPreorder β] {x : α} {xs : List α} {f : α → β} :
|
||||
(x :: xs).maxIdxOn f (by exact of_decide_eq_false rfl) =
|
||||
if h : xs = [] then 0
|
||||
else if f (xs.maxOn f h) ≤ f x then 0
|
||||
else (xs.maxIdxOn f h) + 1 := by
|
||||
simp only [List.maxIdxOn_eq_minIdxOn, List.maxOn_eq_minOn]
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
simpa [LE.le_opposite_iff] using List.minIdxOn_cons (f := f)
|
||||
|
||||
protected theorem maxIdxOn_eq_zero_iff [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{xs : List α} {f : α → β} (h : xs ≠ []) :
|
||||
xs.maxIdxOn f h = 0 ↔ ∀ x ∈ xs, f x ≤ f (xs.head h) := by
|
||||
simp only [List.maxIdxOn_eq_minIdxOn]
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
simpa [LE.le_opposite_iff] using List.minIdxOn_eq_zero_iff h (f := f)
|
||||
|
||||
protected theorem maxIdxOn_append [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{xs ys : List α} {f : α → β} (hxs : xs ≠ []) (hys : ys ≠ []) :
|
||||
(xs ++ ys).maxIdxOn f (by simp [hxs]) =
|
||||
if f (ys.maxOn f hys) ≤ f (xs.maxOn f hxs) then
|
||||
xs.maxIdxOn f hxs
|
||||
else
|
||||
xs.length + ys.maxIdxOn f hys := by
|
||||
simp only [List.maxIdxOn_eq_minIdxOn, List.maxOn_eq_minOn]
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
simpa [LE.le_opposite_iff] using List.minIdxOn_append hxs hys (f := f)
|
||||
|
||||
protected theorem left_le_maxIdxOn_append [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{xs ys : List α} {f : α → β} (h : xs ≠ []) :
|
||||
xs.maxIdxOn f h ≤ (xs ++ ys).maxIdxOn f (by simp [h]) :=
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
List.left_le_minIdxOn_append h
|
||||
|
||||
protected theorem maxIdxOn_take_le [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{xs : List α} {f : α → β} {i : Nat} (h : xs.take i ≠ []) :
|
||||
(xs.take i).maxIdxOn f h ≤ xs.maxIdxOn f (List.ne_nil_of_take_ne_nil h) :=
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
List.minIdxOn_take_le h
|
||||
|
||||
@[simp]
|
||||
protected theorem maxIdxOn_replicate [LE β] [DecidableLE β] [Refl (α := β) (· ≤ ·)]
|
||||
{n : Nat} {a : α} {f : α → β} (h : replicate n a ≠ []) :
|
||||
(replicate n a).maxIdxOn f h = 0 :=
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
List.minIdxOn_replicate h
|
||||
|
||||
@[simp]
|
||||
protected theorem minIdxOn?_nil [LE β] [DecidableLE β] {f : α → β} :
|
||||
([] : List α).minIdxOn? f = none :=
|
||||
(rfl)
|
||||
|
||||
@[simp]
|
||||
protected theorem minIdxOn?_singleton [LE β] [DecidableLE β] {x : α} {f : α → β} :
|
||||
[x].minIdxOn? f = some 0 :=
|
||||
(rfl)
|
||||
|
||||
@[simp]
|
||||
protected theorem isSome_minIdxOn?_iff [LE β] [DecidableLE β] {f : α → β} {xs : List α} :
|
||||
(xs.minIdxOn? f).isSome ↔ xs ≠ [] := by
|
||||
cases xs <;> simp [minIdxOn?]
|
||||
|
||||
protected theorem minIdxOn_eq_get_minIdxOn? [LE β] [DecidableLE β] {f : α → β} {xs : List α}
|
||||
(h : xs ≠ []) : xs.minIdxOn f h = (xs.minIdxOn? f).get (List.isSome_minIdxOn?_iff.mpr h) := by
|
||||
match xs with
|
||||
| [] => contradiction
|
||||
| _ :: _ => simp [minIdxOn?]
|
||||
|
||||
@[simp]
|
||||
protected theorem get_minIdxOn?_eq_minIdxOn [LE β] [DecidableLE β] {f : α → β} {xs : List α}
|
||||
(h : (xs.minIdxOn? f).isSome) :
|
||||
(xs.minIdxOn? f).get h = xs.minIdxOn f (List.isSome_minIdxOn?_iff.mp h) := by
|
||||
rw [List.minIdxOn_eq_get_minIdxOn?]
|
||||
|
||||
protected theorem minIdxOn?_eq_some_minIdxOn [LE β] [DecidableLE β] {f : α → β} {xs : List α}
|
||||
(h : xs ≠ []) : xs.minIdxOn? f = some (xs.minIdxOn f h) := by
|
||||
match xs with
|
||||
| [] => contradiction
|
||||
| _ :: _ => simp [minIdxOn?]
|
||||
|
||||
protected theorem minIdxOn_eq_of_minIdxOn?_eq_some
|
||||
[LE β] [DecidableLE β] {f : α → β} {xs : List α} {i : Nat} (h : xs.minIdxOn? f = some i) :
|
||||
xs.minIdxOn f (List.isSome_minIdxOn?_iff.mp (Option.isSome_of_eq_some h)) = i := by
|
||||
have h' := List.isSome_minIdxOn?_iff.mp (Option.isSome_of_eq_some h)
|
||||
rwa [List.minIdxOn?_eq_some_minIdxOn h', Option.some.injEq] at h
|
||||
|
||||
protected theorem isSome_minIdxOn?_of_mem
|
||||
[LE β] [DecidableLE β] {f : α → β} {xs : List α} {x : α} (h : x ∈ xs) :
|
||||
(xs.minIdxOn? f).isSome := by
|
||||
apply List.isSome_minIdxOn?_iff.mpr
|
||||
exact ne_nil_of_mem h
|
||||
|
||||
protected theorem minIdxOn?_cons_eq_some_minIdxOn
|
||||
[LE β] [DecidableLE β] {f : α → β} {x : α} {xs : List α} :
|
||||
(x :: xs).minIdxOn? f = some ((x :: xs).minIdxOn f (nomatch ·)) := by
|
||||
simp [List.minIdxOn?_eq_some_minIdxOn]
|
||||
|
||||
protected theorem minIdxOn?_eq_if
|
||||
[LE β] [DecidableLE β] {f : α → β} {xs : List α} :
|
||||
xs.minIdxOn? f =
|
||||
if h : xs ≠ [] then
|
||||
some (xs.minIdxOn f h)
|
||||
else
|
||||
none := by
|
||||
cases xs <;> simp [List.minIdxOn?_cons_eq_some_minIdxOn]
|
||||
|
||||
protected theorem minIdxOn?_cons
|
||||
[LE β] [DecidableLE β] [IsLinearPreorder β] {f : α → β} {x : α} {xs : List α} :
|
||||
(x :: xs).minIdxOn? f = some
|
||||
(if h : xs = [] then 0
|
||||
else if f x ≤ f (xs.minOn f h) then 0
|
||||
else (xs.minIdxOn f h) + 1) := by
|
||||
simp [List.minIdxOn?_eq_some_minIdxOn, List.minIdxOn_cons]
|
||||
|
||||
protected theorem ne_nil_of_minIdxOn?_eq_some
|
||||
[LE β] [DecidableLE β] {f : α → β} {k : Nat} {xs : List α} (h : xs.minIdxOn? f = some k) :
|
||||
xs ≠ [] := by
|
||||
rintro rfl
|
||||
simp at h
|
||||
|
||||
protected theorem lt_length_of_minIdxOn?_eq_some [LE β] [DecidableLE β] {f : α → β}
|
||||
{xs : List α} (h : xs.minIdxOn? f = some i) : i < xs.length := by
|
||||
have hne : xs ≠ [] := List.ne_nil_of_minIdxOn?_eq_some h
|
||||
rw [List.minIdxOn?_eq_some_minIdxOn hne] at h
|
||||
have := List.minIdxOn_lt_length (f := f) hne
|
||||
simp_all
|
||||
|
||||
@[simp]
|
||||
protected theorem get_minIdxOn?_lt_length [LE β] [DecidableLE β] {f : α → β} {xs : List α}
|
||||
(h : (xs.minIdxOn? f).isSome) : (xs.minIdxOn? f).get h < xs.length := by
|
||||
rw [List.get_minIdxOn?_eq_minIdxOn]
|
||||
apply List.minIdxOn_lt_length
|
||||
|
||||
@[simp]
|
||||
protected theorem getElem_get_minIdxOn? [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{f : α → β} {xs : List α} (h : (xs.minIdxOn? f).isSome) :
|
||||
xs[(xs.minIdxOn? f).get h] = xs.minOn f (List.isSome_minIdxOn?_iff.mp h) := by
|
||||
rw [getElem_congr rfl (List.get_minIdxOn?_eq_minIdxOn _), List.getElem_minIdxOn]
|
||||
|
||||
protected theorem minIdxOn?_eq_some_zero_iff [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{xs : List α} {f : α → β} :
|
||||
xs.minIdxOn? f = some 0 ↔ ∃ h : xs ≠ [], ∀ x ∈ xs, f (xs.head h) ≤ f x := by
|
||||
simp [Option.eq_some_iff_get_eq, List.minIdxOn_eq_zero_iff]
|
||||
|
||||
protected theorem minIdxOn?_replicate [LE β] [DecidableLE β] [Refl (α := β) (· ≤ ·)]
|
||||
{n : Nat} {a : α} {f : α → β} :
|
||||
(replicate n a).minIdxOn? f = if n = 0 then none else some 0 := by
|
||||
simp [List.minIdxOn?_eq_if]
|
||||
|
||||
@[simp]
|
||||
protected theorem minIdxOn?_replicate_of_pos [LE β] [DecidableLE β] [Refl (α := β) (· ≤ ·)]
|
||||
{n : Nat} {a : α} {f : α → β} (h : 0 < n) :
|
||||
(replicate n a).minIdxOn? f = some 0 := by
|
||||
simp [List.minIdxOn?_replicate, Nat.ne_zero_of_lt h]
|
||||
|
||||
/-! ### maxIdxOn? -/
|
||||
|
||||
protected theorem maxIdxOn?_eq_minIdxOn? {le : LE β} {_ : DecidableLE β} {f : α → β}
|
||||
{xs : List α} :
|
||||
xs.maxIdxOn? f = (letI := le.opposite; xs.minIdxOn? f) :=
|
||||
(rfl)
|
||||
|
||||
@[simp]
|
||||
protected theorem maxIdxOn?_nil [LE β] [DecidableLE β] {f : α → β} :
|
||||
([] : List α).maxIdxOn? f = none :=
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
List.minIdxOn?_nil
|
||||
|
||||
@[simp]
|
||||
protected theorem maxIdxOn?_singleton [LE β] [DecidableLE β] {x : α} {f : α → β} :
|
||||
[x].maxIdxOn? f = some 0 :=
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
List.minIdxOn?_singleton
|
||||
|
||||
@[simp]
|
||||
protected theorem isSome_maxIdxOn?_iff [LE β] [DecidableLE β] {f : α → β} {xs : List α} :
|
||||
(xs.maxIdxOn? f).isSome ↔ xs ≠ [] := by
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
exact List.isSome_minIdxOn?_iff
|
||||
|
||||
protected theorem maxIdxOn_eq_get_maxIdxOn? [LE β] [DecidableLE β] {f : α → β} {xs : List α}
|
||||
(h : xs ≠ []) : xs.maxIdxOn f h = (xs.maxIdxOn? f).get (List.isSome_maxIdxOn?_iff.mpr h) := by
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
exact List.minIdxOn_eq_get_minIdxOn? h
|
||||
|
||||
@[simp]
|
||||
protected theorem get_maxIdxOn?_eq_maxIdxOn [LE β] [DecidableLE β] {f : α → β} {xs : List α}
|
||||
(h : (xs.maxIdxOn? f).isSome) :
|
||||
(xs.maxIdxOn? f).get h = xs.maxIdxOn f (List.isSome_maxIdxOn?_iff.mp h) := by
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
exact List.get_minIdxOn?_eq_minIdxOn h
|
||||
|
||||
protected theorem maxIdxOn?_eq_some_maxIdxOn [LE β] [DecidableLE β] {f : α → β} {xs : List α}
|
||||
(h : xs ≠ []) : xs.maxIdxOn? f = some (xs.maxIdxOn f h) := by
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
exact List.minIdxOn?_eq_some_minIdxOn h
|
||||
|
||||
protected theorem maxIdxOn_eq_of_maxIdxOn?_eq_some
|
||||
[LE β] [DecidableLE β] {f : α → β} {xs : List α} {i : Nat} (h : xs.maxIdxOn? f = some i) :
|
||||
xs.maxIdxOn f (List.isSome_maxIdxOn?_iff.mp (Option.isSome_of_eq_some h)) = i := by
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
exact List.minIdxOn_eq_of_minIdxOn?_eq_some h
|
||||
|
||||
protected theorem isSome_maxIdxOn?_of_mem
|
||||
[LE β] [DecidableLE β] {f : α → β} {xs : List α} {x : α} (h : x ∈ xs) :
|
||||
(xs.maxIdxOn? f).isSome := by
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
exact List.isSome_minIdxOn?_of_mem h
|
||||
|
||||
protected theorem maxIdxOn?_cons_eq_some_maxIdxOn
|
||||
[LE β] [DecidableLE β] {f : α → β} {x : α} {xs : List α} :
|
||||
(x :: xs).maxIdxOn? f = some ((x :: xs).maxIdxOn f (nomatch ·)) := by
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
exact List.minIdxOn?_cons_eq_some_minIdxOn
|
||||
|
||||
protected theorem maxIdxOn?_eq_if
|
||||
[LE β] [DecidableLE β] {f : α → β} {xs : List α} :
|
||||
xs.maxIdxOn? f =
|
||||
if h : xs ≠ [] then
|
||||
some (xs.maxIdxOn f h)
|
||||
else
|
||||
none := by
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
exact List.minIdxOn?_eq_if
|
||||
|
||||
protected theorem maxIdxOn?_cons
|
||||
[LE β] [DecidableLE β] [IsLinearPreorder β] {f : α → β} {x : α} {xs : List α} :
|
||||
(x :: xs).maxIdxOn? f = some
|
||||
(if h : xs = [] then 0
|
||||
else if f (xs.maxOn f h) ≤ f x then 0
|
||||
else (xs.maxIdxOn f h) + 1) := by
|
||||
simp only [List.maxIdxOn_eq_minIdxOn, List.maxOn_eq_minOn]
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
simpa [LE.le_opposite_iff] using List.minIdxOn?_cons (f := f)
|
||||
|
||||
protected theorem ne_nil_of_maxIdxOn?_eq_some
|
||||
[LE β] [DecidableLE β] {f : α → β} {k : Nat} {xs : List α} (h : xs.maxIdxOn? f = some k) :
|
||||
xs ≠ [] := by
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
exact List.ne_nil_of_minIdxOn?_eq_some (by simpa only [List.maxIdxOn?_eq_minIdxOn?] using h)
|
||||
|
||||
protected theorem lt_length_of_maxIdxOn?_eq_some [LE β] [DecidableLE β] {f : α → β}
|
||||
{xs : List α} (h : xs.maxIdxOn? f = some i) : i < xs.length := by
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
exact List.lt_length_of_minIdxOn?_eq_some h
|
||||
|
||||
@[simp]
|
||||
protected theorem get_maxIdxOn?_lt_length [LE β] [DecidableLE β] {f : α → β} {xs : List α}
|
||||
(h : (xs.maxIdxOn? f).isSome) : (xs.maxIdxOn? f).get h < xs.length := by
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
exact List.get_minIdxOn?_lt_length h
|
||||
|
||||
@[simp]
|
||||
protected theorem getElem_get_maxIdxOn? [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{f : α → β} {xs : List α} (h : (xs.maxIdxOn? f).isSome) :
|
||||
xs[(xs.maxIdxOn? f).get h] = xs.maxOn f (List.isSome_maxIdxOn?_iff.mp h) := by
|
||||
simp only [List.maxIdxOn?_eq_minIdxOn?, List.maxOn_eq_minOn]
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
exact List.getElem_get_minIdxOn? h
|
||||
|
||||
protected theorem maxIdxOn?_eq_some_zero_iff [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{xs : List α} {f : α → β} :
|
||||
xs.maxIdxOn? f = some 0 ↔ ∃ h : xs ≠ [], ∀ x ∈ xs, f x ≤ f (xs.head h) := by
|
||||
simp only [List.maxIdxOn?_eq_minIdxOn?]
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
simpa [LE.le_opposite_iff] using List.minIdxOn?_eq_some_zero_iff (f := f)
|
||||
|
||||
protected theorem maxIdxOn?_replicate [LE β] [DecidableLE β] [Refl (α := β) (· ≤ ·)]
|
||||
{n : Nat} {a : α} {f : α → β} :
|
||||
(replicate n a).maxIdxOn? f = if n = 0 then none else some 0 := by
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
exact List.minIdxOn?_replicate
|
||||
|
||||
@[simp]
|
||||
protected theorem maxIdxOn?_replicate_of_pos [LE β] [DecidableLE β] [Refl (α := β) (· ≤ ·)]
|
||||
{n : Nat} {a : α} {f : α → β} (h : 0 < n) :
|
||||
(replicate n a).maxIdxOn? f = some 0 := by
|
||||
letI : LE β := LE.opposite inferInstance
|
||||
exact List.minIdxOn?_replicate_of_pos h
|
||||
|
||||
end List
|
||||
623
src/Init/Data/List/MinMaxOn.lean
Normal file
623
src/Init/Data/List/MinMaxOn.lean
Normal file
@@ -0,0 +1,623 @@
|
||||
/-
|
||||
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Paul Reichert
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Init.Data.Order.MinMaxOn
|
||||
public import Init.Data.Int.OfNat
|
||||
public import Init.Data.List.Lemmas
|
||||
public import Init.Data.List.TakeDrop
|
||||
import Init.Data.Order.Lemmas
|
||||
import Init.Data.List.Sublist
|
||||
import Init.Data.List.MinMax
|
||||
import Init.Data.Order.Opposite
|
||||
|
||||
set_option doc.verso true
|
||||
set_option linter.missingDocs true
|
||||
set_option linter.listVariables true -- Enforce naming conventions for `List`/`Array`/`Vector` variables.
|
||||
set_option linter.indexVariables true -- Enforce naming conventions for index variables.
|
||||
|
||||
public section
|
||||
|
||||
open Std
|
||||
open scoped OppositeOrderInstances
|
||||
|
||||
namespace List
|
||||
|
||||
/--
|
||||
Returns an element of the non-empty list {name}`l` that minimizes {name}`f`. If {given}`x, y` are
|
||||
such that {lean}`f x = f y`, it returns whichever comes first in the list.
|
||||
|
||||
The correctness of this function assumes {name}`β` to be linearly pre-ordered.
|
||||
The property that {name}`List.minOn` is the first minimizer in the list is guaranteed by the lemma
|
||||
{name (scope := "Init.Data.List.MinMaxIdx")}`List.getElem_minIdxOn`.
|
||||
-/
|
||||
@[inline, suggest_for List.argmin]
|
||||
protected def minOn [LE β] [DecidableLE β] (f : α → β) (l : List α) (h : l ≠ []) : α :=
|
||||
match l with
|
||||
| x :: xs => xs.foldl (init := x) (minOn f)
|
||||
|
||||
/--
|
||||
Returns an element of the non-empty list {name}`l` that maximizes {name}`f`. If {given}`x, y` are
|
||||
such that {lean}`f x = f y`, it returns whichever comes first in the list.
|
||||
|
||||
The correctness of this function assumes {name}`β` to be linearly pre-ordered.
|
||||
The property that {name}`List.maxOn` is the first maximizer in the list is guaranteed by the lemma
|
||||
{name (scope := "Init.Data.List.MinMaxIdx")}`List.getElem_maxIdxOn`.
|
||||
-/
|
||||
@[inline, suggest_for List.argmax]
|
||||
protected def maxOn [i : LE β] [DecidableLE β] (f : α → β) (l : List α) (h : l ≠ []) : α :=
|
||||
letI : LE β := i.opposite
|
||||
l.minOn f h
|
||||
|
||||
/--
|
||||
Returns an element of {name}`l` that minimizes {name}`f`. If {given}`x, y` are such that
|
||||
{lean}`f x = f y`, it returns whichever comes first in the list. Returns {name}`none` if the list is
|
||||
empty.
|
||||
|
||||
The correctness of this function assumes {name}`β` to be linearly pre-ordered.
|
||||
The property that {name}`List.minOn?` is the first minimizer in the list is guaranteed by the lemma
|
||||
{name (scope := "Init.Data.List.MinMaxIdx")}`List.getElem_get_minIdxOn?`
|
||||
-/
|
||||
@[inline, suggest_for List.argmin? List.argmin] -- Mathlib's `List.argmin` returns an `Option α`
|
||||
protected def minOn? [LE β] [DecidableLE β] (f : α → β) (l : List α) : Option α :=
|
||||
match l with
|
||||
| [] => none
|
||||
| x :: xs => some (xs.foldl (init := x) (minOn f))
|
||||
|
||||
/--
|
||||
Returns an element of {name}`l` that maximizes {name}`f`. If {given}`x, y` are such that
|
||||
{lean}`f x = f y`, it returns whichever comes first in the list. Returns {name}`none` if the list is
|
||||
empty.
|
||||
|
||||
The correctness of this function assumes {name}`β` to be linearly pre-ordered.
|
||||
The property that {name}`List.maxOn?` is the first minimizer in the list is guaranteed by the lemma
|
||||
{name (scope := "Init.Data.List.MinMaxIdx")}`List.getElem_get_maxIdxOn?`.
|
||||
-/
|
||||
@[inline, suggest_for List.argmax? List.argmax] -- Mathlib's `List.argmax` returns an `Option α`
|
||||
protected def maxOn? [i : LE β] [DecidableLE β] (f : α → β) (l : List α) : Option α :=
|
||||
letI : LE β := i.opposite
|
||||
l.minOn? f
|
||||
|
||||
/-! ### minOn -/
|
||||
|
||||
@[simp]
|
||||
protected theorem minOn_singleton [LE β] [DecidableLE β] {x : α} {f : α → β} :
|
||||
[x].minOn f (of_decide_eq_false rfl) = x := by
|
||||
simp [List.minOn]
|
||||
|
||||
protected theorem minOn_cons
|
||||
[LE β] [DecidableLE β] [IsLinearPreorder β] {x : α} {xs : List α} {f : α → β} :
|
||||
(x :: xs).minOn f (by exact of_decide_eq_false rfl) =
|
||||
if h : xs = [] then x else minOn f x (xs.minOn f h) := by
|
||||
simp only [List.minOn]
|
||||
match xs with
|
||||
| [] => simp
|
||||
| y :: xs => simp [foldl_assoc]
|
||||
|
||||
@[simp]
|
||||
protected theorem minOn_id [Min α] [LE α] [DecidableLE α] [LawfulOrderLeftLeaningMin α]
|
||||
{xs : List α} (h : xs ≠ []) :
|
||||
xs.minOn id h = xs.min h := by
|
||||
have : minOn (α := α) id = min := by ext; apply minOn_id
|
||||
simp only [List.minOn, List.min, this]
|
||||
match xs with
|
||||
| _ :: _ => simp
|
||||
|
||||
@[simp]
|
||||
protected theorem minOn_mem [LE β] [DecidableLE β] {xs : List α}
|
||||
{f : α → β} {h : xs ≠ []} : xs.minOn f h ∈ xs := by
|
||||
simp only [List.minOn]
|
||||
match xs with
|
||||
| x :: xs =>
|
||||
fun_induction xs.foldl (init := x) (_root_.minOn f)
|
||||
· simp
|
||||
· rename_i x y _ ih
|
||||
simp only [ne_eq, reduceCtorEq, not_false_eq_true, mem_cons, forall_const, foldl_cons] at ih ⊢
|
||||
cases ih <;> rename_i heq
|
||||
· cases minOn_eq_or (f := f) (x := x) (y := y) <;> simp_all
|
||||
· simp [*]
|
||||
|
||||
protected theorem apply_minOn_le_of_mem [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{xs : List α} {f : α → β} {y : α} (hx : y ∈ xs) :
|
||||
f (xs.minOn f (List.ne_nil_of_mem hx)) ≤ f y := by
|
||||
have h : xs ≠ [] := List.ne_nil_of_mem hx
|
||||
simp only [List.minOn]
|
||||
match xs with
|
||||
| x :: xs =>
|
||||
fun_induction xs.foldl (init := x) (_root_.minOn f) generalizing y
|
||||
· simp only [mem_cons] at hx
|
||||
simp_all [le_refl _]
|
||||
· rename_i x y _ ih
|
||||
simp at ih ⊢
|
||||
rcases mem_cons.mp hx with rfl | hx
|
||||
· exact le_trans ih.1 apply_minOn_le_left
|
||||
· rcases mem_cons.mp hx with rfl | hx
|
||||
· exact le_trans ih.1 apply_minOn_le_right
|
||||
· apply ih.2
|
||||
assumption
|
||||
|
||||
protected theorem le_apply_minOn_iff [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{xs : List α} {f : α → β} (h : xs ≠ []) {b : β} :
|
||||
b ≤ f (xs.minOn f h) ↔ ∀ x ∈ xs, b ≤ f x := by
|
||||
match xs with
|
||||
| x :: xs =>
|
||||
rw [List.minOn]
|
||||
induction xs generalizing x
|
||||
· simp
|
||||
· rw [foldl_cons, foldl_assoc, le_apply_minOn_iff]
|
||||
simp_all
|
||||
|
||||
protected theorem apply_minOn_le_iff [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{xs : List α} {f : α → β} (h : xs ≠ []) {b : β} :
|
||||
f (xs.minOn f h) ≤ b ↔ ∃ x ∈ xs, f x ≤ b := by
|
||||
apply Iff.intro
|
||||
· intro h
|
||||
match xs with
|
||||
| x :: xs =>
|
||||
rw [List.minOn] at h
|
||||
induction xs generalizing x
|
||||
· simpa using h
|
||||
· rename_i y ys ih _
|
||||
rw [foldl_cons] at h
|
||||
specialize ih (minOn f x y) (by simp) h
|
||||
obtain ⟨z, hm, hle⟩ := ih
|
||||
rcases mem_cons.mp hm with rfl | hm
|
||||
· cases minOn_eq_or (f := f) (x := x) (y := y)
|
||||
· exact ⟨x, by simp_all⟩
|
||||
· exact ⟨y, by simp_all⟩
|
||||
· exact ⟨z, by simp_all⟩
|
||||
· rintro ⟨x, hm, hx⟩
|
||||
exact le_trans (List.apply_minOn_le_of_mem hm) hx
|
||||
|
||||
protected theorem lt_apply_minOn_iff
|
||||
[LE β] [DecidableLE β] [LT β] [IsLinearPreorder β] [LawfulOrderLT β]
|
||||
{xs : List α} {f : α → β} (h : xs ≠ []) {b : β} :
|
||||
b < f (xs.minOn f h) ↔ ∀ x ∈ xs, b < f x := by
|
||||
simpa [not_le] using not_congr <| xs.apply_minOn_le_iff (f := f) h (b := b)
|
||||
|
||||
protected theorem apply_minOn_lt_iff
|
||||
[LE β] [DecidableLE β] [LT β] [IsLinearPreorder β] [LawfulOrderLT β]
|
||||
{xs : List α} {f : α → β} (h : xs ≠ []) {b : β} :
|
||||
f (xs.minOn f h) < b ↔ ∃ x ∈ xs, f x < b := by
|
||||
simpa [not_le] using not_congr <| xs.le_apply_minOn_iff (f := f) h (b := b)
|
||||
|
||||
protected theorem apply_minOn_le_apply_minOn_of_subset [LE β] [DecidableLE β]
|
||||
[IsLinearPreorder β] {xs ys : List α} {f : α → β} (hxs : ys ⊆ xs) (hys : ys ≠ []) :
|
||||
haveI : xs ≠ [] := by intro h; rw [h] at hxs; simp_all [subset_nil]
|
||||
f (xs.minOn f this) ≤ f (ys.minOn f hys) := by
|
||||
rw [List.le_apply_minOn_iff]
|
||||
intro x hx
|
||||
exact List.apply_minOn_le_of_mem (hxs hx)
|
||||
|
||||
protected theorem le_apply_minOn_take [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{xs : List α} {f : α → β} {i : Nat} (h : xs.take i ≠ []) :
|
||||
f (xs.minOn f (List.ne_nil_of_take_ne_nil h)) ≤ f ((xs.take i).minOn f h) := by
|
||||
apply List.apply_minOn_le_apply_minOn_of_subset
|
||||
apply take_subset
|
||||
|
||||
protected theorem apply_minOn_append_le_left [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{xs ys : List α} {f : α → β} (h : xs ≠ []) :
|
||||
f ((xs ++ ys).minOn f (append_ne_nil_of_left_ne_nil h ys)) ≤
|
||||
f (xs.minOn f h) := by
|
||||
apply List.apply_minOn_le_apply_minOn_of_subset
|
||||
apply subset_append_left
|
||||
|
||||
protected theorem apply_minOn_append_le_right [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{xs ys : List α} {f : α → β} (h : ys ≠ []) :
|
||||
f ((xs ++ ys).minOn f (append_ne_nil_of_right_ne_nil xs h)) ≤
|
||||
f (ys.minOn f h) := by
|
||||
apply List.apply_minOn_le_apply_minOn_of_subset
|
||||
apply subset_append_right
|
||||
|
||||
@[simp]
|
||||
protected theorem minOn_append [LE β] [DecidableLE β] [IsLinearPreorder β] {xs ys : List α}
|
||||
{f : α → β} (hxs : xs ≠ []) (hys : ys ≠ []) :
|
||||
(xs ++ ys).minOn f (by simp [hxs]) = minOn f (xs.minOn f hxs) (ys.minOn f hys) := by
|
||||
match xs, ys with
|
||||
| x :: xs, y :: ys => simp [List.minOn, foldl_assoc]
|
||||
|
||||
protected theorem minOn_eq_head [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{xs : List α} {f : α → β} (h : xs ≠ []) (h' : ∀ x ∈ xs, f (xs.head h) ≤ f x) :
|
||||
xs.minOn f h = xs.head h := by
|
||||
match xs with
|
||||
| x :: xs =>
|
||||
simp only [List.minOn]
|
||||
induction xs
|
||||
· simp
|
||||
· simp only [foldl_cons, head_cons]
|
||||
rw [minOn_eq_left] <;> simp_all
|
||||
|
||||
protected theorem min_map
|
||||
[LE β] [DecidableLE β] [Min β] [IsLinearPreorder β] [LawfulOrderLeftLeaningMin β] {xs : List α}
|
||||
{f : α → β} (h : xs ≠ []) :
|
||||
(xs.map f).min (by simpa) = f (xs.minOn f h) := by
|
||||
match xs with
|
||||
| x :: xs =>
|
||||
simp only [List.minOn, map_cons, List.min, foldl_map]
|
||||
rw [foldl_hom]
|
||||
simp [min_apply]
|
||||
|
||||
@[simp]
|
||||
protected theorem minOn_replicate [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{n : Nat} {a : α} {f : α → β} (h : replicate n a ≠ []) :
|
||||
(replicate n a).minOn f h = a := by
|
||||
induction n
|
||||
· simp at h
|
||||
· rename_i n ih
|
||||
simp only [ne_eq, replicate_eq_nil_iff] at ih
|
||||
simp +contextual [List.replicate, List.minOn_cons, ih]
|
||||
|
||||
/-! ### maxOn -/
|
||||
|
||||
protected theorem maxOn_eq_minOn {le : LE β} {dle : DecidableLE β} {xs : List α} {f : α → β} {h} :
|
||||
xs.maxOn f h = (letI := le.opposite; xs.minOn f h) :=
|
||||
(rfl)
|
||||
|
||||
@[simp]
|
||||
protected theorem maxOn_singleton [LE β] [DecidableLE β] {x : α} {f : α → β} :
|
||||
[x].maxOn f (of_decide_eq_false rfl) = x := by
|
||||
simp [List.maxOn]
|
||||
|
||||
protected theorem maxOn_cons
|
||||
[LE β] [DecidableLE β] [IsLinearPreorder β] {x : α} {xs : List α} {f : α → β} :
|
||||
(x :: xs).maxOn f (by exact of_decide_eq_false rfl) =
|
||||
if h : xs = [] then x else maxOn f x (xs.maxOn f h) := by
|
||||
simp only [maxOn_eq_minOn]
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
exact List.minOn_cons (f := f)
|
||||
|
||||
protected theorem min_eq_max {min : Min α} {xs : List α} {h} :
|
||||
xs.min h = (letI := min.oppositeMax; xs.max h) := by
|
||||
simp only [List.min, List.max]
|
||||
rw [Min.oppositeMax_def]
|
||||
simp
|
||||
|
||||
protected theorem max_eq_min {max : Max α} {xs : List α} {h} :
|
||||
xs.max h = (letI := max.oppositeMin; xs.min h) := by
|
||||
simp only [List.min, List.max]
|
||||
rw [Max.oppositeMin_def]
|
||||
simp
|
||||
|
||||
protected theorem max?_eq_min? {max : Max α} {xs : List α} :
|
||||
xs.max? = (letI := max.oppositeMin; xs.min?) := by
|
||||
simp only [List.min?, List.max?]
|
||||
rw [Max.oppositeMin_def]
|
||||
simp
|
||||
|
||||
@[simp]
|
||||
protected theorem maxOn_id [Max α] [LE α] [DecidableLE α] [LawfulOrderLeftLeaningMax α]
|
||||
{xs : List α} (h : xs ≠ []) :
|
||||
xs.maxOn id h = xs.max h := by
|
||||
simp only [List.maxOn_eq_minOn]
|
||||
letI : LE α := (inferInstanceAs (LE α)).opposite
|
||||
letI : Min α := (inferInstanceAs (Max α)).oppositeMin
|
||||
simpa only [List.max_eq_min] using List.minOn_id h
|
||||
|
||||
@[simp]
|
||||
protected theorem maxOn_mem [LE β] [DecidableLE β] {xs : List α}
|
||||
{f : α → β} {h : xs ≠ []} : xs.maxOn f h ∈ xs :=
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
List.minOn_mem (f := f)
|
||||
|
||||
protected theorem le_apply_maxOn_of_mem [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{xs : List α} {f : α → β} {y : α} (hx : y ∈ xs) :
|
||||
f y ≤ f (xs.maxOn f (List.ne_nil_of_mem hx)) := by
|
||||
rw [List.maxOn_eq_minOn]
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
simpa [LE.le_opposite_iff] using List.apply_minOn_le_of_mem (f := f) hx
|
||||
|
||||
protected theorem apply_maxOn_le_iff [LE β] [DecidableLE β] [IsLinearPreorder β] {xs : List α}
|
||||
{f : α → β} (h : xs ≠ []) {b : β} :
|
||||
f (xs.maxOn f h) ≤ b ↔ ∀ x ∈ xs, f x ≤ b := by
|
||||
rw [List.maxOn_eq_minOn]
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
simpa [LE.le_opposite_iff] using List.le_apply_minOn_iff (f := f) h
|
||||
|
||||
protected theorem le_apply_maxOn_iff [LE β] [DecidableLE β] [IsLinearPreorder β] {xs : List α}
|
||||
{f : α → β} (h : xs ≠ []) {b : β} :
|
||||
b ≤ f (xs.maxOn f h) ↔ ∃ x ∈ xs, b ≤ f x := by
|
||||
rw [List.maxOn_eq_minOn]
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
simpa [LE.le_opposite_iff] using List.apply_minOn_le_iff (f := f) h
|
||||
|
||||
protected theorem apply_maxOn_lt_iff
|
||||
[LE β] [DecidableLE β] [LT β] [IsLinearPreorder β] [LawfulOrderLT β]
|
||||
{xs : List α} {f : α → β} (h : xs ≠ []) {b : β} :
|
||||
f (xs.maxOn f h) < b ↔ ∀ x ∈ xs, f x < b := by
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
letI : LT β := (inferInstanceAs (LT β)).opposite
|
||||
simpa [LT.lt_opposite_iff] using List.lt_apply_minOn_iff (f := f) h
|
||||
|
||||
protected theorem lt_apply_maxOn_iff
|
||||
[LE β] [DecidableLE β] [LT β] [IsLinearPreorder β] [LawfulOrderLT β]
|
||||
{xs : List α} {f : α → β} (h : xs ≠ []) {b : β} :
|
||||
b < f (xs.maxOn f h) ↔ ∃ x ∈ xs, b < f x := by
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
letI : LT β := (inferInstanceAs (LT β)).opposite
|
||||
simpa [LT.lt_opposite_iff] using List.apply_minOn_lt_iff (f := f) h
|
||||
|
||||
protected theorem apply_maxOn_le_apply_maxOn_of_subset [LE β] [DecidableLE β]
|
||||
[IsLinearPreorder β] {xs ys : List α} {f : α → β} (hxs : ys ⊆ xs) (hys : ys ≠ []) :
|
||||
haveI : xs ≠ [] := by intro h; rw [h] at hxs; simp_all [subset_nil]
|
||||
f (ys.maxOn f hys) ≤ f (xs.maxOn f this) := by
|
||||
rw [List.maxOn_eq_minOn]
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
simpa [LE.le_opposite_iff] using List.apply_minOn_le_apply_minOn_of_subset (f := f) hxs hys
|
||||
|
||||
protected theorem apply_maxOn_take_le [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{xs : List α} {f : α → β} {i : Nat} (h : xs.take i ≠ []) :
|
||||
f ((xs.take i).maxOn f h) ≤ f (xs.maxOn f (List.ne_nil_of_take_ne_nil h)) := by
|
||||
rw [List.maxOn_eq_minOn]
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
simpa [LE.le_opposite_iff] using List.le_apply_minOn_take (f := f) h
|
||||
|
||||
protected theorem le_apply_maxOn_append_left [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{xs ys : List α} {f : α → β} (h : xs ≠ []) :
|
||||
f (xs.maxOn f h) ≤
|
||||
f ((xs ++ ys).maxOn f (append_ne_nil_of_left_ne_nil h ys)) := by
|
||||
rw [List.maxOn_eq_minOn]
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
simpa [LE.le_opposite_iff] using List.apply_minOn_append_le_left (f := f) h
|
||||
|
||||
protected theorem le_apply_maxOn_append_right [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{xs ys : List α} {f : α → β} (h : ys ≠ []) :
|
||||
f (ys.maxOn f h) ≤
|
||||
f ((xs ++ ys).maxOn f (append_ne_nil_of_right_ne_nil xs h)) := by
|
||||
rw [List.maxOn_eq_minOn]
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
simpa [LE.le_opposite_iff] using List.apply_minOn_append_le_right (f := f) h
|
||||
|
||||
@[simp]
|
||||
protected theorem maxOn_append [LE β] [DecidableLE β] [IsLinearPreorder β] {xs ys : List α}
|
||||
{f : α → β} (hxs : xs ≠ []) (hys : ys ≠ []) :
|
||||
(xs ++ ys).maxOn f (by simp [hxs]) = maxOn f (xs.maxOn f hxs) (ys.maxOn f hys) := by
|
||||
simp only [List.maxOn_eq_minOn, maxOn_eq_minOn]
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
simpa [LE.le_opposite_iff] using List.minOn_append (f := f) hxs hys
|
||||
|
||||
protected theorem maxOn_eq_head [LE β] [DecidableLE β] [IsLinearPreorder β] {xs : List α}
|
||||
{f : α → β} (h : xs ≠ []) (h' : ∀ x ∈ xs, f x ≤ f (xs.head h)) :
|
||||
xs.maxOn f h = xs.head h := by
|
||||
rw [List.maxOn_eq_minOn]
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
simpa [LE.le_opposite_iff] using List.minOn_eq_head (f := f) h (by simpa [LE.le_opposite_iff] using h')
|
||||
|
||||
protected theorem max_map
|
||||
[LE β] [DecidableLE β] [Max β] [IsLinearPreorder β] [LawfulOrderLeftLeaningMax β] {xs : List α}
|
||||
{f : α → β} (h : xs ≠ []) : (xs.map f).max (by simpa) = f (xs.maxOn f h) := by
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
letI : Min β := (inferInstanceAs (Max β)).oppositeMin
|
||||
simpa [List.max_eq_min] using List.min_map (f := f) h
|
||||
|
||||
@[simp]
|
||||
protected theorem maxOn_replicate [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{n : Nat} {a : α} {f : α → β} (h : replicate n a ≠ []) :
|
||||
(replicate n a).maxOn f h = a :=
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
List.minOn_replicate (f := f) h
|
||||
|
||||
/-! ### minOn? -/
|
||||
|
||||
/-- {lit}`List.minOn?` returns {name}`none` when applied to an empty list. -/
|
||||
@[simp]
|
||||
protected theorem minOn?_nil [LE β] [DecidableLE β] {f : α → β} :
|
||||
([] : List α).minOn? f = none := by
|
||||
simp [List.minOn?]
|
||||
|
||||
protected theorem minOn?_cons_eq_some_minOn
|
||||
[LE β] [DecidableLE β] {f : α → β} {x : α} {xs : List α} :
|
||||
(x :: xs).minOn? f = some ((x :: xs).minOn f (fun h => nomatch h)) := by
|
||||
simp [List.minOn?, List.minOn]
|
||||
|
||||
protected theorem minOn?_cons
|
||||
[LE β] [DecidableLE β] [IsLinearPreorder β] {f : α → β} {x : α} {xs : List α} :
|
||||
(x :: xs).minOn? f = some ((xs.minOn? f).elim x (minOn f x)) := by
|
||||
simp only [List.minOn?]
|
||||
split <;> simp [foldl_assoc]
|
||||
|
||||
@[simp]
|
||||
protected theorem minOn?_singleton [LE β] [DecidableLE β] {x : α} {f : α → β} :
|
||||
[x].minOn? f = some x := by
|
||||
simp [List.minOn?_cons_eq_some_minOn]
|
||||
|
||||
@[simp]
|
||||
protected theorem minOn?_id [Min α] [LE α] [DecidableLE α] [LawfulOrderLeftLeaningMin α]
|
||||
{xs : List α} : xs.minOn? id = xs.min? := by
|
||||
cases xs
|
||||
· simp
|
||||
· simp only [List.minOn?_cons_eq_some_minOn, List.minOn_id, List.min?_eq_some_min (List.cons_ne_nil _ _)]
|
||||
|
||||
protected theorem minOn?_eq_if
|
||||
[LE β] [DecidableLE β] [IsLinearPreorder β] {f : α → β} {xs : List α} :
|
||||
xs.minOn? f =
|
||||
if h : xs ≠ [] then
|
||||
some (xs.minOn f h)
|
||||
else
|
||||
none := by
|
||||
fun_cases xs.minOn? f <;> simp [List.minOn]
|
||||
|
||||
@[simp]
|
||||
protected theorem isSome_minOn?_iff [LE β] [DecidableLE β] {f : α → β} {xs : List α} :
|
||||
(xs.minOn? f).isSome ↔ xs ≠ [] := by
|
||||
fun_cases xs.minOn? f <;> simp
|
||||
|
||||
protected theorem minOn_eq_get_minOn? [LE β] [DecidableLE β] {f : α → β} {xs : List α}
|
||||
(h : xs ≠ []) : xs.minOn f h = (xs.minOn? f).get (List.isSome_minOn?_iff.mpr h) := by
|
||||
fun_cases xs.minOn? f
|
||||
· contradiction
|
||||
· simp [List.minOn?, List.minOn]
|
||||
|
||||
protected theorem minOn?_eq_some_minOn [LE β] [DecidableLE β] {f : α → β} {xs : List α}
|
||||
(h : xs ≠ []) : xs.minOn? f = some (xs.minOn f h) := by
|
||||
simp [List.minOn_eq_get_minOn? h]
|
||||
|
||||
@[simp]
|
||||
protected theorem get_minOn? [LE β] [DecidableLE β] {f : α → β} {xs : List α}
|
||||
(h : xs ≠ []) : (xs.minOn? f).get (List.isSome_minOn?_iff.mpr h) = xs.minOn f h := by
|
||||
rw [List.minOn_eq_get_minOn?]
|
||||
|
||||
protected theorem minOn_eq_of_minOn?_eq_some
|
||||
[LE β] [DecidableLE β] {f : α → β} {xs : List α} {x : α} (h : xs.minOn? f = some x) :
|
||||
xs.minOn f (List.isSome_minOn?_iff.mp (Option.isSome_of_eq_some h)) = x := by
|
||||
have h' := List.isSome_minOn?_iff.mp (Option.isSome_of_eq_some h)
|
||||
rwa [List.minOn?_eq_some_minOn h', Option.some.injEq] at h
|
||||
|
||||
protected theorem isSome_minOn?_of_mem
|
||||
[LE β] [DecidableLE β] {f : α → β} {xs : List α} {x : α} (h : x ∈ xs) :
|
||||
(xs.minOn? f).isSome := by
|
||||
apply List.isSome_minOn?_iff.mpr
|
||||
exact ne_nil_of_mem h
|
||||
|
||||
protected theorem apply_get_minOn?_le_of_mem
|
||||
[LE β] [DecidableLE β] [IsLinearPreorder β] {f : α → β} {xs : List α} {x : α} (h : x ∈ xs) :
|
||||
f ((xs.minOn? f).get (List.isSome_minOn?_of_mem h)) ≤ f x := by
|
||||
rw [List.get_minOn? (ne_nil_of_mem h)]
|
||||
apply List.apply_minOn_le_of_mem h
|
||||
|
||||
protected theorem minOn?_mem [LE β] [DecidableLE β] {xs : List α}
|
||||
{f : α → β} (h : xs.minOn? f = some a) : a ∈ xs := by
|
||||
rw [← List.minOn_eq_of_minOn?_eq_some h]
|
||||
apply List.minOn_mem
|
||||
|
||||
protected theorem minOn?_replicate [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{n : Nat} {a : α} {f : α → β} :
|
||||
(replicate n a).minOn? f = if n = 0 then none else some a := by
|
||||
split
|
||||
· simp [*]
|
||||
· rw [List.minOn?_eq_some_minOn, List.minOn_replicate]
|
||||
simp [*]
|
||||
|
||||
@[simp]
|
||||
protected theorem minOn?_replicate_of_pos [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{n : Nat} {a : α} {f : α → β} (h : 0 < n) :
|
||||
(replicate n a).minOn? f = some a := by
|
||||
simp [List.minOn?_replicate, show n ≠ 0 from Nat.ne_zero_of_lt h]
|
||||
|
||||
@[simp]
|
||||
protected theorem minOn?_append [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
(xs ys : List α) (f : α → β) :
|
||||
(xs ++ ys).minOn? f =
|
||||
(xs.minOn? f).merge (_root_.minOn f) (ys.minOn? f) := by
|
||||
by_cases xs = [] <;> by_cases ys = [] <;> simp [*, List.minOn?_eq_if, List.minOn_append]
|
||||
|
||||
/-! ### maxOn? -/
|
||||
|
||||
protected theorem maxOn?_eq_minOn? {le : LE β} {dle : DecidableLE β} {xs : List α} {f : α → β} :
|
||||
xs.maxOn? f = (letI := le.opposite; xs.minOn? f) :=
|
||||
(rfl)
|
||||
|
||||
@[simp]
|
||||
protected theorem maxOn?_nil [LE β] [DecidableLE β] {f : α → β} :
|
||||
([] : List α).maxOn? f = none :=
|
||||
List.minOn?_nil (f := f)
|
||||
|
||||
protected theorem maxOn?_cons_eq_some_maxOn
|
||||
[LE β] [DecidableLE β] {f : α → β} {x : α} {xs : List α} :
|
||||
(x :: xs).maxOn? f = some ((x :: xs).maxOn f (fun h => nomatch h)) :=
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
List.minOn?_cons_eq_some_minOn
|
||||
|
||||
protected theorem maxOn?_cons
|
||||
[LE β] [DecidableLE β] [IsLinearPreorder β] {f : α → β} {x : α} {xs : List α} :
|
||||
(x :: xs).maxOn? f = some ((xs.maxOn? f).elim x (maxOn f x)) := by
|
||||
have : maxOn f x = (letI : LE β := LE.opposite inferInstance; minOn f x) := by
|
||||
ext; simp only [maxOn_eq_minOn]
|
||||
simp only [List.maxOn?_eq_minOn?, this]
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
exact List.minOn?_cons
|
||||
|
||||
@[simp]
|
||||
protected theorem maxOn?_singleton [LE β] [DecidableLE β] {x : α} {f : α → β} :
|
||||
[x].maxOn? f = some x :=
|
||||
List.minOn?_singleton (f := f)
|
||||
|
||||
@[simp]
|
||||
protected theorem maxOn?_id [Max α] [LE α] [DecidableLE α] [LawfulOrderLeftLeaningMax α]
|
||||
{xs : List α} : xs.maxOn? id = xs.max? := by
|
||||
letI : LE α := (inferInstanceAs (LE α)).opposite
|
||||
letI : Min α := (inferInstanceAs (Max α)).oppositeMin
|
||||
simpa only [List.maxOn?_eq_minOn?, List.max?_eq_min?] using List.minOn?_id (α := α)
|
||||
|
||||
protected theorem maxOn?_eq_if
|
||||
[LE β] [DecidableLE β] [IsLinearPreorder β] {f : α → β} {xs : List α} :
|
||||
xs.maxOn? f =
|
||||
if h : xs ≠ [] then
|
||||
some (xs.maxOn f h)
|
||||
else
|
||||
none :=
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
List.minOn?_eq_if
|
||||
|
||||
@[simp]
|
||||
protected theorem isSome_maxOn?_iff [LE β] [DecidableLE β] {f : α → β} {xs : List α} :
|
||||
(xs.maxOn? f).isSome ↔ xs ≠ [] := by
|
||||
fun_cases xs.maxOn? f <;> simp
|
||||
|
||||
protected theorem maxOn_eq_get_maxOn? [LE β] [DecidableLE β] {f : α → β} {xs : List α}
|
||||
(h : xs ≠ []) : xs.maxOn f h = (xs.maxOn? f).get (List.isSome_maxOn?_iff.mpr h) :=
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
List.minOn_eq_get_minOn? (f := f) h
|
||||
|
||||
protected theorem maxOn?_eq_some_maxOn [LE β] [DecidableLE β] {f : α → β} {xs : List α}
|
||||
(h : xs ≠ []) : xs.maxOn? f = some (xs.maxOn f h) :=
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
List.minOn?_eq_some_minOn (f := f) h
|
||||
|
||||
@[simp]
|
||||
protected theorem get_maxOn? [LE β] [DecidableLE β] {f : α → β} {xs : List α}
|
||||
(h : xs ≠ []) : (xs.maxOn? f).get (List.isSome_maxOn?_iff.mpr h) = xs.maxOn f h :=
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
List.get_minOn? (f := f) h
|
||||
|
||||
protected theorem maxOn_eq_of_maxOn?_eq_some
|
||||
[LE β] [DecidableLE β] {f : α → β} {xs : List α} {x : α} (h : xs.maxOn? f = some x) :
|
||||
xs.maxOn f (List.isSome_maxOn?_iff.mp (Option.isSome_of_eq_some h)) = x :=
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
List.minOn_eq_of_minOn?_eq_some (f := f) h
|
||||
|
||||
protected theorem isSome_maxOn?_of_mem
|
||||
[LE β] [DecidableLE β] {f : α → β} {xs : List α} {x : α} (h : x ∈ xs) :
|
||||
(xs.maxOn? f).isSome :=
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
List.isSome_minOn?_of_mem (f := f) h
|
||||
|
||||
protected theorem le_apply_get_maxOn?_of_mem
|
||||
[LE β] [DecidableLE β] [IsLinearPreorder β] {f : α → β} {xs : List α} {x : α} (h : x ∈ xs) :
|
||||
f x ≤ f ((xs.maxOn? f).get (List.isSome_maxOn?_of_mem h)) := by
|
||||
simp only [List.maxOn?_eq_minOn?]
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
simpa [LE.le_opposite_iff] using List.apply_get_minOn?_le_of_mem (f := f) h
|
||||
|
||||
protected theorem maxOn?_mem [LE β] [DecidableLE β] {xs : List α}
|
||||
{f : α → β} (h : xs.maxOn? f = some a) : a ∈ xs :=
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
List.minOn?_mem (f := f) h
|
||||
|
||||
protected theorem maxOn?_replicate [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{n : Nat} {a : α} {f : α → β} :
|
||||
(replicate n a).maxOn? f = if n = 0 then none else some a :=
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
List.minOn?_replicate
|
||||
|
||||
@[simp]
|
||||
protected theorem maxOn?_replicate_of_pos [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
{n : Nat} {a : α} {f : α → β} (h : 0 < n) :
|
||||
(replicate n a).maxOn? f = some a :=
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
List.minOn?_replicate_of_pos (f := f) h
|
||||
|
||||
@[simp]
|
||||
protected theorem maxOn?_append [LE β] [DecidableLE β] [IsLinearPreorder β]
|
||||
(xs ys : List α) (f : α → β) : (xs ++ ys).maxOn? f =
|
||||
(xs.maxOn? f).merge (_root_.maxOn f) (ys.maxOn? f) := by
|
||||
have : maxOn f = (letI : LE β := LE.opposite inferInstance; minOn f) := by
|
||||
ext; simp only [maxOn_eq_minOn]
|
||||
simp only [List.maxOn?_eq_minOn?, this]
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
exact List.minOn?_append xs ys f
|
||||
|
||||
end List
|
||||
@@ -141,6 +141,10 @@ theorem take_append_of_le_length {l₁ l₂ : List α} {i : Nat} (h : i ≤ l₁
|
||||
(l₁ ++ l₂).take i = l₁.take i := by
|
||||
simp [take_append, Nat.sub_eq_zero_of_le h]
|
||||
|
||||
@[grind =]
|
||||
theorem take_append_length {l₁ l₂ : List α} : (l₁ ++ l₂).take l₁.length = l₁ := by
|
||||
simp
|
||||
|
||||
/-- Taking the first `l₁.length + i` elements in `l₁ ++ l₂` is the same as appending the first
|
||||
`i` elements of `l₂` to `l₁`. -/
|
||||
theorem take_length_add_append {l₁ l₂ : List α} (i : Nat) :
|
||||
@@ -304,7 +308,6 @@ theorem drop_length_cons {l : List α} (h : l ≠ []) (a : α) :
|
||||
|
||||
/-- Dropping the elements up to `i` in `l₁ ++ l₂` is the same as dropping the elements up to `i`
|
||||
in `l₁`, dropping the elements up to `i - l₁.length` in `l₂`, and appending them. -/
|
||||
@[grind =]
|
||||
theorem drop_append {l₁ l₂ : List α} {i : Nat} :
|
||||
drop i (l₁ ++ l₂) = drop i l₁ ++ drop (i - l₁.length) l₂ := by
|
||||
induction l₁ generalizing i
|
||||
@@ -315,10 +318,15 @@ theorem drop_append {l₁ l₂ : List α} {i : Nat} :
|
||||
congr 1
|
||||
omega
|
||||
|
||||
@[grind =]
|
||||
theorem drop_append_of_le_length {l₁ l₂ : List α} {i : Nat} (h : i ≤ l₁.length) :
|
||||
(l₁ ++ l₂).drop i = l₁.drop i ++ l₂ := by
|
||||
simp [drop_append, Nat.sub_eq_zero_of_le h]
|
||||
|
||||
@[grind =]
|
||||
theorem drop_append_length {l₁ l₂ : List α} : (l₁ ++ l₂).drop l₁.length = l₂ := by
|
||||
simp [List.drop_append_of_le_length (Nat.le_refl _)]
|
||||
|
||||
/-- Dropping the elements up to `l₁.length + i` in `l₁ + l₂` is the same as dropping the elements
|
||||
up to `i` in `l₂`. -/
|
||||
@[simp]
|
||||
|
||||
@@ -54,6 +54,15 @@ theorem div_le_iff_le_mul (h : 0 < k) : x / k ≤ y ↔ x ≤ y * k + k - 1 := b
|
||||
rw [le_iff_lt_add_one, Nat.div_lt_iff_lt_mul h, Nat.add_one_mul]
|
||||
omega
|
||||
|
||||
theorem le_mul_iff_le_left (hz : 0 < z) :
|
||||
x ≤ y * z ↔ (x + z - 1) / z ≤ y := by
|
||||
rw [Nat.div_le_iff_le_mul hz]
|
||||
omega
|
||||
|
||||
theorem le_mul_iff_le_right (hy : 0 < y) :
|
||||
x ≤ y * z ↔ (x + y - 1) / y ≤ z := by
|
||||
rw [← le_mul_iff_le_left hy, Nat.mul_comm]
|
||||
|
||||
-- TODO: reprove `div_eq_of_lt_le` in terms of this:
|
||||
protected theorem div_eq_iff (h : 0 < k) : x / k = y ↔ y * k ≤ x ∧ x ≤ y * k + k - 1 := by
|
||||
rw [Nat.eq_iff_le_and_ge, and_comm, le_div_iff_mul_le h, Nat.div_le_iff_le_mul h]
|
||||
@@ -95,6 +104,12 @@ theorem div_add_le_right {z : Nat} (h : 0 < z) (x y : Nat) :
|
||||
x / (y + z) ≤ x / z :=
|
||||
div_le_div_left (Nat.le_add_left z y) h
|
||||
|
||||
theorem div_add_div_le_add_div {x y z : Nat} : x / z + y / z ≤ (x + y) / z := by
|
||||
by_cases hc : z > 0
|
||||
· rw [Nat.le_div_iff_mul_le hc, Nat.add_mul]
|
||||
apply Nat.add_le_add <;> apply Nat.div_mul_le_self
|
||||
· simp_all
|
||||
|
||||
theorem succ_div_of_dvd {a b : Nat} (h : b ∣ a + 1) :
|
||||
(a + 1) / b = a / b + 1 := by
|
||||
replace h := mod_eq_zero_of_dvd h
|
||||
|
||||
@@ -13,4 +13,6 @@ public import Init.Data.Order.Lemmas
|
||||
public import Init.Data.Order.LemmasExtra
|
||||
public import Init.Data.Order.Factories
|
||||
public import Init.Data.Order.FactoriesExtra
|
||||
public import Init.Data.Order.MinMaxOn
|
||||
public import Init.Data.Order.Opposite
|
||||
public import Init.Data.Order.PackageFactories
|
||||
|
||||
@@ -142,6 +142,10 @@ public theorem not_gt_of_lt {α : Type u} [LT α] [i : Std.Asymm (α := α) (·
|
||||
(h : a < b) : ¬ b < a :=
|
||||
i.asymm a b h
|
||||
|
||||
public theorem lt_irrefl {α : Type u} [LT α] [i : Std.Irrefl (α := α) (· < ·)] {a : α} :
|
||||
¬ a < a :=
|
||||
i.irrefl a
|
||||
|
||||
public theorem le_of_lt {α : Type u} [LT α] [LE α] [LawfulOrderLT α] {a b : α} (h : a < b) :
|
||||
a ≤ b := (lt_iff_le_and_not_ge.1 h).1
|
||||
|
||||
|
||||
198
src/Init/Data/Order/MinMaxOn.lean
Normal file
198
src/Init/Data/Order/MinMaxOn.lean
Normal file
@@ -0,0 +1,198 @@
|
||||
/-
|
||||
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Paul Reichert
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Init.NotationExtra
|
||||
public import Init.Data.Order.Lemmas
|
||||
public import Init.Data.Order.Opposite
|
||||
|
||||
open Std
|
||||
open scoped OppositeOrderInstances
|
||||
|
||||
/-! ## Definitions -/
|
||||
|
||||
/--
|
||||
Returns either `x` or `y`, the one with the smaller value under `f`.
|
||||
|
||||
If `f x ≤ f y`, it returns `x`, and otherwise returns `y`.
|
||||
-/
|
||||
public def minOn [LE β] [DecidableLE β] (f : α → β) (x y : α) :=
|
||||
if f x ≤ f y then x else y
|
||||
|
||||
/--
|
||||
Returns either `x` or `y`, the one with the greater value under `f`.
|
||||
|
||||
If `f y ≤ f x`, it returns `x`, and otherwise returns `y`.
|
||||
-/
|
||||
public def maxOn [i : LE β] [DecidableLE β] (f : α → β) (x y : α) :=
|
||||
letI := i.opposite
|
||||
minOn f x y
|
||||
|
||||
/-! ## `minOn` Lemmas -/
|
||||
|
||||
public theorem minOn_id [Min α] [LE α] [DecidableLE α] [LawfulOrderLeftLeaningMin α] {x y : α} :
|
||||
minOn id x y = min x y := by
|
||||
simp [minOn, min_eq_if]
|
||||
|
||||
public theorem maxOn_id [Max α] [LE α] [DecidableLE α] [LawfulOrderLeftLeaningMax α] {x y : α} :
|
||||
maxOn id x y = max x y := by
|
||||
letI : LE α := (inferInstanceAs (LE α)).opposite
|
||||
letI : Min α := (inferInstanceAs (Max α)).oppositeMin
|
||||
simp [maxOn, minOn_id, Max.min_oppositeMin, this]
|
||||
|
||||
public theorem minOn_eq_or [LE β] [DecidableLE β] {f : α → β} {x y : α} :
|
||||
minOn f x y = x ∨ minOn f x y = y := by
|
||||
rw [minOn]
|
||||
split
|
||||
· exact Or.inl rfl
|
||||
· exact Or.inr rfl
|
||||
|
||||
@[simp]
|
||||
public theorem minOn_self [LE β] [DecidableLE β] {f : α → β} {x : α} :
|
||||
minOn f x x = x := by
|
||||
cases minOn_eq_or (f := f) (x := x) (y := x) <;> assumption
|
||||
|
||||
public theorem minOn_eq_left [LE β] [DecidableLE β] {f : α → β} {x y : α} (h : f x ≤ f y) :
|
||||
minOn f x y = x := by
|
||||
simp [minOn, h]
|
||||
|
||||
public theorem minOn_eq_right [LE β] [DecidableLE β] {f : α → β} {x y : α} (h : ¬ f x ≤ f y) :
|
||||
minOn f x y = y := by
|
||||
simp [minOn, h]
|
||||
|
||||
public theorem minOn_eq_right_of_lt
|
||||
[LE β] [DecidableLE β] [LT β] [Total (α := β) (· ≤ ·)] [LawfulOrderLT β]
|
||||
{f : α → β} {x y : α} (h : f y < f x) :
|
||||
minOn f x y = y := by
|
||||
apply minOn_eq_right
|
||||
simpa [not_le] using h
|
||||
|
||||
public theorem apply_minOn_le_left [LE β] [DecidableLE β] [IsLinearPreorder β] {f : α → β}
|
||||
{x y : α} : f (minOn f x y) ≤ f x := by
|
||||
rw [minOn]
|
||||
split
|
||||
· apply le_refl
|
||||
· exact le_of_not_ge ‹_›
|
||||
|
||||
public theorem apply_minOn_le_right [LE β] [DecidableLE β] [IsLinearPreorder β] {f : α → β}
|
||||
{x y : α} : f (minOn f x y) ≤ f y := by
|
||||
rw [minOn]
|
||||
split
|
||||
· assumption
|
||||
· apply le_refl
|
||||
|
||||
public theorem le_apply_minOn_iff [LE β] [DecidableLE β] [IsLinearPreorder β] {f : α → β}
|
||||
{x y : α} {b : β} :
|
||||
b ≤ f (minOn f x y) ↔ b ≤ f x ∧ b ≤ f y := by
|
||||
apply Iff.intro
|
||||
· intro h
|
||||
exact ⟨le_trans h apply_minOn_le_left, le_trans h apply_minOn_le_right⟩
|
||||
· intro h
|
||||
cases minOn_eq_or (f := f) (x := x) (y := y) <;> simp_all
|
||||
|
||||
public theorem minOn_assoc [LE β] [DecidableLE β] [IsLinearPreorder β] {f : α → β}
|
||||
{x y z : α} : minOn f (minOn f x y) z = minOn f x (minOn f y z) := by
|
||||
open scoped Classical.Order in
|
||||
simp only [minOn]
|
||||
split
|
||||
· split
|
||||
· split
|
||||
· rfl
|
||||
· rfl
|
||||
· split
|
||||
· have : ¬ f x ≤ f z := by assumption
|
||||
have : f x ≤ f z := le_trans ‹f x ≤ f y› ‹f y ≤ f z›
|
||||
contradiction
|
||||
· rfl
|
||||
· split
|
||||
· rfl
|
||||
· have : f z < f y := not_le.mp ‹¬ f y ≤ f z›
|
||||
have : f y < f x := not_le.mp ‹¬ f x ≤ f y›
|
||||
have : f z < f x := lt_trans ‹_› ‹_›
|
||||
rw [if_neg]
|
||||
exact not_le.mpr ‹_›
|
||||
|
||||
public instance [LE β] [DecidableLE β] [IsLinearPreorder β] {f : α → β} :
|
||||
Associative (minOn f) where
|
||||
assoc := by apply minOn_assoc
|
||||
|
||||
public theorem min_apply [LE β] [DecidableLE β] [Min β] [LawfulOrderLeftLeaningMin β]
|
||||
{f : α → β} {x y : α} : min (f x) (f y) = f (minOn f x y) := by
|
||||
rw [min_eq_if, minOn]
|
||||
split <;> rfl
|
||||
|
||||
/-! ## `maxOn` Lemmas -/
|
||||
|
||||
public theorem maxOn_eq_minOn [le : LE β] [DecidableLE β] {f : α → β} {x y : α} :
|
||||
maxOn f x y = (letI := le.opposite; minOn f x y) :=
|
||||
(rfl)
|
||||
|
||||
public theorem maxOn_eq_or [LE β] [DecidableLE β] {f : α → β} {x y : α} :
|
||||
maxOn f x y = x ∨ maxOn f x y = y :=
|
||||
@minOn_eq_or ..
|
||||
|
||||
@[simp]
|
||||
public theorem maxOn_self [LE β] [DecidableLE β] {f : α → β} {x : α} :
|
||||
maxOn f x x = x :=
|
||||
@minOn_self ..
|
||||
|
||||
public theorem maxOn_eq_left [le : LE β] [DecidableLE β] {f : α → β} {x y : α} (h : f y ≤ f x) :
|
||||
maxOn f x y = x := by
|
||||
simp only [maxOn_eq_minOn]
|
||||
exact @minOn_eq_left (h := by simpa [LE.opposite_def]) ..
|
||||
|
||||
public theorem maxOn_eq_right [LE β] [DecidableLE β] {f : α → β} {x y : α} (h : ¬ f y ≤ f x) :
|
||||
maxOn f x y = y := by
|
||||
simp only [maxOn_eq_minOn]
|
||||
exact @minOn_eq_right (h := by simpa [LE.opposite_def]) ..
|
||||
|
||||
public theorem maxOn_eq_right_of_lt
|
||||
[LE β] [DecidableLE β] [LT β] [Total (α := β) (· ≤ ·)] [LawfulOrderLT β]
|
||||
{f : α → β} {x y : α} (h : f x < f y) :
|
||||
maxOn f x y = y :=
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
letI : LT β := (inferInstanceAs (LT β)).opposite
|
||||
minOn_eq_right_of_lt (h := by simpa [LT.lt_opposite_iff] using h) ..
|
||||
|
||||
public theorem left_le_apply_maxOn [le : LE β] [DecidableLE β] [IsLinearPreorder β] {f : α → β}
|
||||
{x y : α} : f x ≤ f (maxOn f x y) := by
|
||||
rw [maxOn_eq_minOn]
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
simpa only [LE.le_opposite_iff] using apply_minOn_le_left (f := f) ..
|
||||
|
||||
public theorem right_le_apply_maxOn [LE β] [DecidableLE β] [IsLinearPreorder β] {f : α → β}
|
||||
{x y : α} : f y ≤ f (maxOn f x y) := by
|
||||
rw [maxOn_eq_minOn]
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
simpa only [LE.le_opposite_iff] using apply_minOn_le_right (f := f)
|
||||
|
||||
public theorem apply_maxOn_le_iff [LE β] [DecidableLE β] [IsLinearPreorder β] {f : α → β}
|
||||
{x y : α} {b : β} :
|
||||
f (maxOn f x y) ≤ b ↔ f x ≤ b ∧ f y ≤ b := by
|
||||
rw [maxOn_eq_minOn]
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
simpa only [LE.le_opposite_iff] using le_apply_minOn_iff (f := f)
|
||||
|
||||
public theorem maxOn_assoc [LE β] [DecidableLE β] [IsLinearPreorder β] {f : α → β}
|
||||
{x y z : α} : maxOn f (maxOn f x y) z = maxOn f x (maxOn f y z) :=
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
minOn_assoc (f := f)
|
||||
|
||||
public instance [LE β] [DecidableLE β] [IsLinearPreorder β] {f : α → β} :
|
||||
Associative (maxOn f) where
|
||||
assoc := by
|
||||
apply maxOn_assoc
|
||||
|
||||
public theorem max_apply [LE β] [DecidableLE β] [Max β] [LawfulOrderLeftLeaningMax β]
|
||||
{f : α → β} {x y : α} : max (f x) (f y) = f (maxOn f x y) := by
|
||||
letI : LE β := (inferInstanceAs (LE β)).opposite
|
||||
letI : Min β := (inferInstanceAs (Max β)).oppositeMin
|
||||
simpa [Max.min_oppositeMin] using min_apply (f := f)
|
||||
|
||||
public theorem apply_maxOn [LE β] [DecidableLE β] [Max β] [LawfulOrderLeftLeaningMax β]
|
||||
{f : α → β} {x y : α} : f (maxOn f x y) = max (f x) (f y) :=
|
||||
max_apply.symm
|
||||
407
src/Init/Data/Order/Opposite.lean
Normal file
407
src/Init/Data/Order/Opposite.lean
Normal file
@@ -0,0 +1,407 @@
|
||||
/-
|
||||
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Paul Reichert
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Init.Data.Order.ClassesExtra
|
||||
public import Init.Data.Order.LemmasExtra
|
||||
|
||||
public section
|
||||
|
||||
open Std
|
||||
|
||||
set_option linter.missingDocs true
|
||||
set_option linter.listVariables true -- Enforce naming conventions for `List`/`Array`/`Vector` variables.
|
||||
set_option linter.indexVariables true -- Enforce naming conventions for index variables.
|
||||
|
||||
/-
|
||||
Note: We're having verso docstrings disabled because the examples depend on instances that
|
||||
are provided later in the module. They will be converted into verso docstrings at the end
|
||||
of the module.
|
||||
-/
|
||||
|
||||
/--
|
||||
Inverts an {name}`LE` instance.
|
||||
|
||||
The result is an {lean}`LE α` instance where {lit}`a ≤ b` holds when {name}`le` would have
|
||||
{lit}`b ≤ a` hold.
|
||||
|
||||
If {name}`le` obeys laws, then {lean}`le.opposite` obeys the opposite laws. For example, if
|
||||
{name}`le` encodes a linear order, then {lean}`le.opposite` also encodes a linear order.
|
||||
To automatically derive these laws, use {lit}`open Std.OppositeOrderInstances`.
|
||||
|
||||
For example, {name}`LE.opposite` can be used to derive maximum operations from minimum operations,
|
||||
since finding the minimum in the opposite order is the same as finding the maximum in the original order:
|
||||
|
||||
```lean +warning
|
||||
def min' [LE α] [DecidableLE α] (a b : α) : α :=
|
||||
if a ≤ b then a else b
|
||||
|
||||
open scoped Std.OppositeOrderInstances in
|
||||
def max' [LE α] [DecidableLE α] (a b : α) : α :=
|
||||
letI : LE α := (inferInstanceAs (LE α)).opposite
|
||||
-- `DecidableLE` for the opposite order is derived automatically via `OppositeOrderInstances`
|
||||
min' a b
|
||||
```
|
||||
|
||||
Without the `open scoped` command, Lean would not find the required {lit}`DecidableLE α`
|
||||
instance for the opposite order.
|
||||
-/
|
||||
def LE.opposite (le : LE α) : LE α where
|
||||
le a b := b ≤ a
|
||||
|
||||
theorem LE.opposite_def {le : LE α} :
|
||||
le.opposite = ⟨fun a b => b ≤ a⟩ :=
|
||||
(rfl)
|
||||
|
||||
theorem LE.le_opposite_iff {le : LE α} {a b : α} :
|
||||
(haveI := le.opposite; a ≤ b) ↔ b ≤ a := by
|
||||
exact Iff.rfl
|
||||
|
||||
/--
|
||||
Inverts an {name}`LT` instance.
|
||||
|
||||
The result is an {lean}`LT α` instance where {lit}`a < b` holds when {name}`lt` would have
|
||||
{lit}`b < a` hold.
|
||||
|
||||
If {name}`lt` obeys laws, then {lean}`lt.opposite` obeys the opposite laws.
|
||||
To automatically derive these laws, use {lit}`open scoped Std.OppositeOrderInstances`.
|
||||
|
||||
For example, one can use the derived instances to prove properties about the opposite {name}`LT`
|
||||
instance:
|
||||
|
||||
```lean
|
||||
open scoped Std.OppositeOrderInstances in
|
||||
example [LE α] [LT α] [Std.LawfulOrderLT α] [Std.IsLinearOrder α] {x y : α} :
|
||||
letI : LE α := LE.opposite inferInstance
|
||||
letI : LT α := LT.opposite inferInstance
|
||||
¬ y ≤ x ↔ x < y :=
|
||||
letI : LE α := LE.opposite inferInstance
|
||||
letI : LT α := LT.opposite inferInstance
|
||||
Std.not_le
|
||||
```
|
||||
|
||||
Without the `open scoped` command, Lean would not find the {lit}`LawfulOrderLT α`
|
||||
and {lit}`IsLinearOrder α` instances for the opposite order that are required by {name}`not_le`.
|
||||
-/
|
||||
def LT.opposite (lt : LT α) : LT α where
|
||||
lt a b := b < a
|
||||
|
||||
theorem LT.opposite_def {lt : LT α} :
|
||||
lt.opposite = ⟨fun a b => b < a⟩ :=
|
||||
(rfl)
|
||||
|
||||
theorem LT.lt_opposite_iff {lt : LT α} {a b : α} :
|
||||
(haveI := lt.opposite; a < b) ↔ b < a := by
|
||||
exact Iff.rfl
|
||||
|
||||
/--
|
||||
Creates a {name}`Max` instance from a {name}`Min` instance.
|
||||
|
||||
The result is a {lean}`Max α` instance that uses {lean}`min.min` as its {name}`max` operation.
|
||||
|
||||
If {name}`min` obeys laws, then {lean}`min.oppositeMax` obeys the corresponding laws for {name}`Max`.
|
||||
To automatically derive these laws, use {lit}`open scoped Std.OppositeOrderInstances`.
|
||||
|
||||
For example, one can use the derived instances to prove properties about the opposite {name}`Max`
|
||||
instance:
|
||||
|
||||
```lean
|
||||
open scoped Std.OppositeOrderInstances in
|
||||
example [LE α] [DecidableLE α] [Min α] [Std.LawfulOrderLeftLeaningMin α] {a b : α} :
|
||||
letI : LE α := LE.opposite inferInstance
|
||||
letI : Max α := (inferInstance : Min α).oppositeMax
|
||||
max a b = if b ≤ a then a else b :=
|
||||
letI : LE α := LE.opposite inferInstance
|
||||
letI : Max α := (inferInstance : Min α).oppositeMax
|
||||
Std.max_eq_if
|
||||
```
|
||||
|
||||
Without the `open scoped` command, Lean would not find the {lit}`LawfulOrderLeftLeaningMax α`
|
||||
instance for the opposite order that is required by {name}`max_eq_if`.
|
||||
-/
|
||||
def Min.oppositeMax (min : Min α) : Max α where
|
||||
max a b := Min.min a b
|
||||
|
||||
theorem Min.oppositeMax_def {min : Min α} :
|
||||
min.oppositeMax = ⟨Min.min⟩ :=
|
||||
(rfl)
|
||||
|
||||
theorem Min.max_oppositeMax {min : Min α} {a b : α} :
|
||||
(haveI := min.oppositeMax; Max.max a b) = Min.min a b :=
|
||||
(rfl)
|
||||
|
||||
/--
|
||||
Creates a {name}`Min` instance from a {name}`Max` instance.
|
||||
|
||||
The result is a {lean}`Min α` instance that uses {lean}`max.max` as its {name}`min` operation.
|
||||
|
||||
If {name}`max` obeys laws, then {lean}`max.oppositeMin` obeys the corresponding laws for {name}`Min`.
|
||||
To automatically derive these laws, use {lit}`open scoped Std.OppositeOrderInstances`.
|
||||
|
||||
For example, one can use the derived instances to prove properties about the opposite {name}`Min`
|
||||
instance:
|
||||
|
||||
```lean
|
||||
open scoped Std.OppositeOrderInstances in
|
||||
example [LE α] [DecidableLE α] [Max α] [Std.LawfulOrderLeftLeaningMax α] {a b : α} :
|
||||
letI : LE α := LE.opposite inferInstance
|
||||
letI : Min α := (inferInstance : Max α).oppositeMin
|
||||
min a b = if a ≤ b then a else b :=
|
||||
letI : LE α := LE.opposite inferInstance
|
||||
letI : Min α := (inferInstance : Max α).oppositeMin
|
||||
Std.min_eq_if
|
||||
```
|
||||
|
||||
Without the `open scoped` command, Lean would not find the {lit}`LawfulOrderLeftLeaningMin α`
|
||||
instance for the opposite order that is required by {name}`min_eq_if`.
|
||||
-/
|
||||
def Max.oppositeMin (max : Max α) : Min α where
|
||||
min a b := Max.max a b
|
||||
|
||||
theorem Max.oppositeMin_def {min : Max α} :
|
||||
min.oppositeMin = ⟨Max.max⟩ :=
|
||||
(rfl)
|
||||
|
||||
theorem Max.min_oppositeMin {max : Max α} {a b : α} :
|
||||
(haveI := max.oppositeMin; Min.min a b) = Max.max a b :=
|
||||
(rfl)
|
||||
|
||||
namespace Std.OppositeOrderInstances
|
||||
|
||||
@[no_expose]
|
||||
scoped instance (priority := low) instDecidableLEOpposite {i : LE α} [id : DecidableLE α] :
|
||||
haveI := i.opposite
|
||||
DecidableLE α :=
|
||||
fun a b => id b a
|
||||
|
||||
@[no_expose]
|
||||
scoped instance (priority := low) instDecidableLTOpposite {i : LT α} [id : DecidableLT α] :
|
||||
haveI := i.opposite
|
||||
DecidableLT α :=
|
||||
fun a b => id b a
|
||||
|
||||
scoped instance (priority := low) instLEReflOpposite {i : LE α} [Refl (α := α) (· ≤ ·)] :
|
||||
haveI := i.opposite
|
||||
Refl (α := α) (· ≤ ·) :=
|
||||
letI := i.opposite
|
||||
{ refl a := letI := i; le_refl a }
|
||||
|
||||
scoped instance (priority := low) instLESymmOpposite {i : LE α} [Symm (α := α) (· ≤ ·)] :
|
||||
haveI := i.opposite
|
||||
Symm (α := α) (· ≤ ·) :=
|
||||
letI := i.opposite
|
||||
{ symm a b hab := by
|
||||
simp only [LE.opposite] at *
|
||||
letI := i
|
||||
exact Symm.symm b a hab }
|
||||
|
||||
scoped instance (priority := low) instLEAntisymmOpposite {i : LE α} [Antisymm (α := α) (· ≤ ·)] :
|
||||
haveI := i.opposite
|
||||
Antisymm (α := α) (· ≤ ·) :=
|
||||
letI := i.opposite
|
||||
{ antisymm a b hab hba := by
|
||||
simp only [LE.opposite] at *
|
||||
letI := i
|
||||
exact le_antisymm hba hab }
|
||||
|
||||
scoped instance (priority := low) instLEAsymmOpposite {i : LE α} [Asymm (α := α) (· ≤ ·)] :
|
||||
haveI := i.opposite
|
||||
Asymm (α := α) (· ≤ ·) :=
|
||||
letI := i.opposite
|
||||
{ asymm a b hab := by
|
||||
simp only [LE.opposite] at *
|
||||
letI := i
|
||||
exact Asymm.asymm b a hab }
|
||||
|
||||
scoped instance (priority := low) instLETransOpposite {i : LE α}
|
||||
[Trans (· ≤ ·) (· ≤ ·) (· ≤ · : α → α → Prop)] :
|
||||
haveI := i.opposite
|
||||
Trans (· ≤ ·) (· ≤ ·) (· ≤ · : α → α → Prop) :=
|
||||
letI := i.opposite
|
||||
{ trans hab hbc := by
|
||||
simp only [LE.opposite] at *
|
||||
letI := i
|
||||
exact Trans.trans hbc hab }
|
||||
|
||||
scoped instance (priority := low) instLETotalOpposite {i : LE α} [Total (α := α) (· ≤ ·)] :
|
||||
haveI := i.opposite
|
||||
Total (α := α) (· ≤ ·) :=
|
||||
letI := i.opposite
|
||||
{ total a b := letI := i; le_total (a := b) (b := a) }
|
||||
|
||||
scoped instance (priority := low) instLEIrreflOpposite {i : LE α} [Irrefl (α := α) (· ≤ ·)] :
|
||||
haveI := i.opposite
|
||||
Irrefl (α := α) (· ≤ ·) :=
|
||||
letI := i.opposite
|
||||
{ irrefl a := letI := i; Irrefl.irrefl (r := (· ≤ ·)) a }
|
||||
|
||||
scoped instance (priority := low) instIsPreorderOpposite {i : LE α} [IsPreorder α] :
|
||||
haveI := i.opposite
|
||||
IsPreorder α :=
|
||||
letI := i.opposite
|
||||
{ le_refl a := le_refl a
|
||||
le_trans _ _ _ := le_trans }
|
||||
|
||||
scoped instance (priority := low) instIsPartialOrderOpposite {i : LE α} [IsPartialOrder α] :
|
||||
haveI := i.opposite
|
||||
IsPartialOrder α :=
|
||||
letI := i.opposite
|
||||
{ le_antisymm _ _ := le_antisymm }
|
||||
|
||||
scoped instance (priority := low) instIsLinearPreorderOpposite {i : LE α} [IsLinearPreorder α] :
|
||||
haveI := i.opposite
|
||||
IsLinearPreorder α :=
|
||||
letI := i.opposite
|
||||
{ le_total _ _ := le_total }
|
||||
|
||||
scoped instance (priority := low) instIsLinearOrderOpposite {i : LE α} [IsLinearOrder α] :
|
||||
haveI := i.opposite
|
||||
IsLinearOrder α :=
|
||||
letI := i.opposite; {}
|
||||
|
||||
scoped instance (priority := low) instLawfulOrderOrdOpposite {il : LE α} {io : Ord α}
|
||||
[LawfulOrderOrd α] :
|
||||
haveI := il.opposite
|
||||
haveI := io.opposite
|
||||
LawfulOrderOrd α :=
|
||||
letI := il.opposite
|
||||
letI := io.opposite
|
||||
{ isLE_compare a b := by
|
||||
simp only [LE.opposite, Ord.opposite]
|
||||
letI := il; letI := io
|
||||
apply isLE_compare
|
||||
isGE_compare a b := by
|
||||
simp only [LE.opposite, Ord.opposite]
|
||||
letI := il; letI := io
|
||||
apply isGE_compare }
|
||||
|
||||
scoped instance (priority := low) instLawfulOrderLTOpposite {il : LE α} {it : LT α}
|
||||
[LawfulOrderLT α] :
|
||||
haveI := il.opposite
|
||||
haveI := it.opposite
|
||||
LawfulOrderLT α :=
|
||||
letI := il.opposite
|
||||
letI := it.opposite
|
||||
{ lt_iff a b := by
|
||||
simp only [LE.opposite, LT.opposite]
|
||||
letI := il; letI := it
|
||||
exact LawfulOrderLT.lt_iff b a }
|
||||
|
||||
scoped instance (priority := low) instLawfulOrderBEqOpposite {il : LE α} {ib : BEq α}
|
||||
[LawfulOrderBEq α] :
|
||||
haveI := il.opposite
|
||||
LawfulOrderBEq α :=
|
||||
letI := il.opposite
|
||||
{ beq_iff_le_and_ge a b := by
|
||||
simp only [LE.opposite]
|
||||
letI := il; letI := ib
|
||||
rw [LawfulOrderBEq.beq_iff_le_and_ge]
|
||||
exact and_comm }
|
||||
|
||||
scoped instance (priority := low) instLawfulOrderInfOpposite {il : LE α} {im : Min α}
|
||||
[LawfulOrderInf α] :
|
||||
haveI := il.opposite
|
||||
haveI := im.oppositeMax
|
||||
LawfulOrderSup α :=
|
||||
letI := il.opposite
|
||||
letI := im.oppositeMax
|
||||
{ max_le_iff a b c := by
|
||||
simp only [LE.opposite, Min.oppositeMax]
|
||||
letI := il; letI := im
|
||||
exact LawfulOrderInf.le_min_iff c a b }
|
||||
|
||||
scoped instance (priority := low) instLawfulOrderMinOpposite {il : LE α} {im : Min α}
|
||||
[LawfulOrderMin α] :
|
||||
haveI := il.opposite
|
||||
haveI := im.oppositeMax
|
||||
LawfulOrderMax α :=
|
||||
letI := il.opposite
|
||||
letI := im.oppositeMax
|
||||
{ max_eq_or a b := by
|
||||
simp only [Min.oppositeMax]
|
||||
letI := il; letI := im
|
||||
exact MinEqOr.min_eq_or a b
|
||||
max_le_iff a b c := by
|
||||
simp only [LE.opposite, Min.oppositeMax]
|
||||
letI := il; letI := im
|
||||
exact LawfulOrderInf.le_min_iff c a b }
|
||||
|
||||
scoped instance (priority := low) instLawfulOrderSupOpposite {il : LE α} {im : Max α}
|
||||
[LawfulOrderSup α] :
|
||||
haveI := il.opposite
|
||||
haveI := im.oppositeMin
|
||||
LawfulOrderInf α :=
|
||||
letI := il.opposite
|
||||
letI := im.oppositeMin
|
||||
{ le_min_iff a b c := by
|
||||
simp only [LE.opposite, Max.oppositeMin]
|
||||
letI := il; letI := im
|
||||
exact LawfulOrderSup.max_le_iff b c a }
|
||||
|
||||
scoped instance (priority := low) instLawfulOrderMaxOpposite {il : LE α} {im : Max α}
|
||||
[LawfulOrderMax α] :
|
||||
haveI := il.opposite
|
||||
haveI := im.oppositeMin
|
||||
LawfulOrderMin α :=
|
||||
letI := il.opposite
|
||||
letI := im.oppositeMin
|
||||
{ min_eq_or a b := by
|
||||
simp only [Max.oppositeMin]
|
||||
letI := il; letI := im
|
||||
exact MaxEqOr.max_eq_or a b
|
||||
le_min_iff a b c := by
|
||||
simp only [LE.opposite, Max.oppositeMin]
|
||||
letI := il; letI := im
|
||||
exact LawfulOrderSup.max_le_iff b c a }
|
||||
|
||||
scoped instance (priority := low) instLawfulOrderLeftLeaningMinOpposite {il : LE α} {im : Min α}
|
||||
[LawfulOrderLeftLeaningMin α] :
|
||||
haveI := il.opposite
|
||||
haveI := im.oppositeMax
|
||||
LawfulOrderLeftLeaningMax α :=
|
||||
letI := il.opposite
|
||||
letI := im.oppositeMax
|
||||
{ max_eq_left a b hab := by
|
||||
simp only [Min.oppositeMax]
|
||||
letI := il; letI := im
|
||||
exact LawfulOrderLeftLeaningMin.min_eq_left a b hab
|
||||
max_eq_right a b hab := by
|
||||
simp only [Min.oppositeMax]
|
||||
letI := il; letI := im
|
||||
exact LawfulOrderLeftLeaningMin.min_eq_right a b hab }
|
||||
|
||||
scoped instance (priority := low) instLawfulOrderLeftLeaningMaxOpposite {il : LE α} {im : Max α}
|
||||
[LawfulOrderLeftLeaningMax α] :
|
||||
haveI := il.opposite
|
||||
haveI := im.oppositeMin
|
||||
LawfulOrderLeftLeaningMin α :=
|
||||
letI := il.opposite
|
||||
letI := im.oppositeMin
|
||||
{ min_eq_left a b hab := by
|
||||
simp only [Max.oppositeMin]
|
||||
letI := il; letI := im
|
||||
exact LawfulOrderLeftLeaningMax.max_eq_left a b hab
|
||||
min_eq_right a b hab := by
|
||||
simp only [Max.oppositeMin]
|
||||
letI := il; letI := im
|
||||
exact LawfulOrderLeftLeaningMax.max_eq_right a b hab }
|
||||
|
||||
end OppositeOrderInstances
|
||||
|
||||
-- When imported from a non-module, these instances are exposed, and reducing them during
|
||||
-- type class resolution is too inefficient.
|
||||
attribute [irreducible] LE.opposite LT.opposite Min.oppositeMax Max.oppositeMin
|
||||
|
||||
section DocsToVerso
|
||||
|
||||
set_option linter.unusedVariables false -- Otherwise, we get warnings about Verso code blocks.
|
||||
docs_to_verso LE.opposite
|
||||
docs_to_verso LT.opposite
|
||||
docs_to_verso Min.oppositeMax
|
||||
docs_to_verso Max.oppositeMin
|
||||
|
||||
end DocsToVerso
|
||||
@@ -85,9 +85,12 @@ theorem toList_eq {α : Type u} {it : Iter (α := SubarrayIterator α) α} :
|
||||
· rw [dif_neg]; rotate_left; exact h
|
||||
simp_all [it.internalState.xs.stop_le_array_size]
|
||||
|
||||
theorem count_eq {α : Type u} {it : Iter (α := SubarrayIterator α) α} :
|
||||
it.count = it.internalState.xs.stop - it.internalState.xs.start := by
|
||||
simp [← Iter.length_toList_eq_count, toList_eq, it.internalState.xs.stop_le_array_size]
|
||||
theorem length_eq {α : Type u} {it : Iter (α := SubarrayIterator α) α} :
|
||||
it.length = it.internalState.xs.stop - it.internalState.xs.start := by
|
||||
simp [← Iter.length_toList_eq_length, toList_eq, it.internalState.xs.stop_le_array_size]
|
||||
|
||||
@[deprecated length_eq (since := "2026-01-28")]
|
||||
def count_eq := @length_eq
|
||||
|
||||
end SubarrayIterator
|
||||
|
||||
@@ -105,7 +108,7 @@ theorem toList_internalIter {α : Type u} {s : Subarray α} :
|
||||
public instance : LawfulSliceSize (Internal.SubarrayData α) where
|
||||
lawful s := by
|
||||
simp [SliceSize.size, ToIterator.iter_eq, Iter.toIter_toIterM,
|
||||
← Iter.length_toList_eq_count, SubarrayIterator.toList_eq,
|
||||
← Iter.length_toList_eq_length, SubarrayIterator.toList_eq,
|
||||
s.internalRepresentation.stop_le_array_size, start, stop, array]
|
||||
|
||||
public theorem toArray_eq_sliceToArray {α : Type u} {s : Subarray α} :
|
||||
|
||||
@@ -60,12 +60,15 @@ public theorem forIn_toArray {γ : Type u} {β : Type v}
|
||||
ForIn.forIn s.toArray init f = ForIn.forIn s init f := by
|
||||
rw [← forIn_internalIter, ← Iter.forIn_toArray, Slice.toArray]
|
||||
|
||||
theorem Internal.size_eq_count_iter [ToIterator (Slice γ) Id α β]
|
||||
theorem Internal.size_eq_length_iter [ToIterator (Slice γ) Id α β]
|
||||
[Iterator α Id β] [Finite α Id]
|
||||
[IteratorLoop α Id Id] [LawfulIteratorLoop α Id Id]
|
||||
{s : Slice γ} [SliceSize γ] [LawfulSliceSize γ] :
|
||||
s.size = (Internal.iter s).count := by
|
||||
simp only [Slice.size, iter, LawfulSliceSize.lawful, ← Iter.length_toList_eq_count]
|
||||
s.size = (Internal.iter s).length := by
|
||||
simp only [Slice.size, iter, LawfulSliceSize.lawful, ← Iter.length_toList_eq_length]
|
||||
|
||||
@[deprecated Internal.size_eq_length_iter (since := "2026-01-28")]
|
||||
def Internal.size_eq_count_iter := @Internal.size_eq_length_iter
|
||||
|
||||
theorem Internal.toArray_eq_toArray_iter {s : Slice γ} [ToIterator (Slice γ) Id α β]
|
||||
[Iterator α Id β]
|
||||
@@ -91,7 +94,7 @@ theorem size_toArray_eq_size [ToIterator (Slice γ) Id α β]
|
||||
{s : Slice γ} :
|
||||
s.toArray.size = s.size := by
|
||||
letI : IteratorLoop α Id Id := .defaultImplementation
|
||||
rw [Internal.size_eq_count_iter, Internal.toArray_eq_toArray_iter, Iter.size_toArray_eq_count]
|
||||
rw [Internal.size_eq_length_iter, Internal.toArray_eq_toArray_iter, Iter.size_toArray_eq_length]
|
||||
|
||||
@[simp]
|
||||
theorem length_toList_eq_size [ToIterator (Slice γ) Id α β]
|
||||
@@ -100,7 +103,7 @@ theorem length_toList_eq_size [ToIterator (Slice γ) Id α β]
|
||||
[Finite α Id] :
|
||||
s.toList.length = s.size := by
|
||||
letI : IteratorLoop α Id Id := .defaultImplementation
|
||||
rw [Internal.size_eq_count_iter, Internal.toList_eq_toList_iter, Iter.length_toList_eq_count]
|
||||
rw [Internal.size_eq_length_iter, Internal.toList_eq_toList_iter, Iter.length_toList_eq_length]
|
||||
|
||||
@[simp]
|
||||
theorem length_toListRev_eq_size [ToIterator (Slice γ) Id α β]
|
||||
@@ -109,7 +112,7 @@ theorem length_toListRev_eq_size [ToIterator (Slice γ) Id α β]
|
||||
[Finite α Id]
|
||||
[LawfulIteratorLoop α Id Id] :
|
||||
s.toListRev.length = s.size := by
|
||||
rw [Internal.size_eq_count_iter, Internal.toListRev_eq_toListRev_iter,
|
||||
Iter.length_toListRev_eq_count]
|
||||
rw [Internal.size_eq_length_iter, Internal.toListRev_eq_toListRev_iter,
|
||||
Iter.length_toListRev_eq_length]
|
||||
|
||||
end Std.Slice
|
||||
|
||||
@@ -34,7 +34,7 @@ attribute [instance] ListSlice.instToIterator
|
||||
universe v w
|
||||
|
||||
instance : SliceSize (Internal.ListSliceData α) where
|
||||
size s := (Internal.iter s).count
|
||||
size s := (Internal.iter s).length
|
||||
|
||||
@[no_expose]
|
||||
instance {α : Type u} {m : Type v → Type w} [Monad m] :
|
||||
|
||||
@@ -60,7 +60,7 @@ public theorem toList_toArray {xs : ListSlice α} :
|
||||
@[simp, grind =]
|
||||
public theorem length_toList {xs : ListSlice α} :
|
||||
xs.toList.length = xs.size := by
|
||||
simp [ListSlice.toList_eq, Std.Slice.size, Std.Slice.SliceSize.size, ← Iter.length_toList_eq_count,
|
||||
simp [ListSlice.toList_eq, Std.Slice.size, Std.Slice.SliceSize.size, ← Iter.length_toList_eq_length,
|
||||
toList_internalIter]; rfl
|
||||
|
||||
@[grind =]
|
||||
|
||||
@@ -45,7 +45,7 @@ class LawfulSliceSize (γ : Type u) [SliceSize γ] [ToIterator (Slice γ) Id α
|
||||
/-- The iterator of a slice `s` of type `Slice γ` emits exactly `SliceSize.size s` elements. -/
|
||||
lawful :
|
||||
letI : IteratorLoop α Id Id := .defaultImplementation
|
||||
∀ s : Slice γ, SliceSize.size s = (ToIterator.iter (γ := Slice γ) s).count
|
||||
∀ s : Slice γ, SliceSize.size s = (ToIterator.iter (γ := Slice γ) s).length
|
||||
|
||||
/--
|
||||
Returns the number of elements with distinct indices in the given slice.
|
||||
|
||||
@@ -905,9 +905,9 @@ Examples:
|
||||
def chars (s : Slice) :=
|
||||
Std.Iter.map (fun ⟨pos, h⟩ => pos.get h) (positions s)
|
||||
|
||||
@[deprecated "There is no constant-time length function on slices. Use `s.positions.count` instead, or `isEmpty` if you only need to know whether the slice is empty." (since := "2025-11-20")]
|
||||
@[deprecated "There is no constant-time length function on slices. Use `s.positions.length` instead, or `isEmpty` if you only need to know whether the slice is empty." (since := "2025-11-20")]
|
||||
def length (s : Slice) : Nat :=
|
||||
s.positions.count
|
||||
s.positions.length
|
||||
|
||||
structure RevPosIterator (s : Slice) where
|
||||
currPos : s.Pos
|
||||
|
||||
@@ -137,6 +137,11 @@ structure Config where
|
||||
For local theorems, use `+suggestions` instead.
|
||||
-/
|
||||
locals : Bool := false
|
||||
/--
|
||||
If `instances` is `true`, `dsimp` will visit instance arguments.
|
||||
If option `backward.dsimp.instances` is `true`, it overrides this field.
|
||||
-/
|
||||
instances : Bool := false
|
||||
deriving Inhabited, BEq
|
||||
|
||||
end DSimp
|
||||
@@ -308,6 +313,11 @@ structure Config where
|
||||
For local theorems, use `+suggestions` instead.
|
||||
-/
|
||||
locals : Bool := false
|
||||
/--
|
||||
If `instances` is `true`, `dsimp` will visit instance arguments.
|
||||
If option `backward.dsimp.instances` is `true`, it overrides this field.
|
||||
-/
|
||||
instances : Bool := false
|
||||
deriving Inhabited, BEq
|
||||
|
||||
-- Configuration object for `simp_all`
|
||||
@@ -374,7 +384,7 @@ structure ExtractLetsConfig where
|
||||
/-- If true (default: false), eliminate unused lets rather than extract them. -/
|
||||
usedOnly : Bool := false
|
||||
/-- If true (default: true), reuse local declarations that have syntactically equal values.
|
||||
Note that even when false, the caching strategy for `extract_let`s may result in fewer extracted let bindings than expected. -/
|
||||
Note that even when false, the caching strategy for `extract_lets` may result in fewer extracted let bindings than expected. -/
|
||||
merge : Bool := true
|
||||
/-- When merging is enabled, if true (default: true), make use of pre-existing local definitions in the local context. -/
|
||||
useContext : Bool := true
|
||||
|
||||
@@ -872,6 +872,12 @@ Substring matching:
|
||||
(after whitespace normalization). This is useful when you only care about part of the message.
|
||||
- `substring := false` (the default) requires exact matching (modulo whitespace normalization).
|
||||
|
||||
Stabilizing output:
|
||||
When messages contain autogenerated names (e.g., metavariables like `?m.47`), the output may
|
||||
differ between runs or Lean versions. Use `set_option pp.mvars.anonymous false` to replace
|
||||
anonymous metavariables with `?_` while preserving user-named metavariables like `?a`.
|
||||
Alternatively, `set_option pp.mvars false` replaces all metavariables with `?_`.
|
||||
|
||||
For example, `#guard_msgs (error, drop all) in cmd` means to check errors and drop
|
||||
everything else.
|
||||
|
||||
|
||||
@@ -322,6 +322,10 @@ For more information: [Equality](https://lean-lang.org/theorem_proving_in_lean4/
|
||||
@[symm] theorem Eq.symm {α : Sort u} {a b : α} (h : Eq a b) : Eq b a :=
|
||||
h ▸ rfl
|
||||
|
||||
/-- Non-dependent recursor for the equality type (symmetric variant) -/
|
||||
@[simp] abbrev Eq.ndrec_symm.{u1, u2} {α : Sort u2} {a : α} {motive : α → Sort u1} (m : motive a) {b : α} (h : Eq b a) : motive b :=
|
||||
h.symm.ndrec m
|
||||
|
||||
/--
|
||||
Equality is transitive: if `a = b` and `b = c` then `a = c`.
|
||||
|
||||
|
||||
@@ -3,12 +3,9 @@ Copyright (c) 2020 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura, Mario Carneiro
|
||||
-/
|
||||
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Init.Data.Array.Set
|
||||
|
||||
public section
|
||||
|
||||
/-!
|
||||
|
||||
@@ -110,11 +110,13 @@ def fileUriToPath? (uri : String) : Option System.FilePath := Id.run do
|
||||
else
|
||||
let mut p := (unescapeUri uri).drop "file://".length |>.copy
|
||||
p := p.dropWhile (λ c => c != '/') |>.copy -- drop the hostname.
|
||||
-- On Windows, the path "/c:/temp" needs to become "C:/temp"
|
||||
if System.Platform.isWindows && p.length >= 2 &&
|
||||
p.front == '/' && (String.Pos.Raw.get p ⟨1⟩).isAlpha && String.Pos.Raw.get p ⟨2⟩ == ':' then
|
||||
-- see also `pathToUri`
|
||||
p := String.Pos.Raw.modify (p.drop 1).copy 0 .toUpper
|
||||
if System.Platform.isWindows then
|
||||
-- On Windows, the path "/c:/temp" needs to become "C:/temp"
|
||||
if p.length >= 2 &&
|
||||
p.front == '/' && (String.Pos.Raw.get p ⟨1⟩).isAlpha && String.Pos.Raw.get p ⟨2⟩ == ':' then
|
||||
-- see also `pathToUri`
|
||||
p := String.Pos.Raw.modify (p.drop 1).copy 0 .toUpper
|
||||
p := p.map (fun c => if c == '/' then '\\' else c)
|
||||
some p
|
||||
|
||||
end Uri
|
||||
|
||||
@@ -1093,8 +1093,6 @@ See also:
|
||||
* `first | tac1 | tac2` implements the backtracking used by `repeat`
|
||||
-/
|
||||
syntax "repeat " tacticSeq : tactic
|
||||
macro_rules
|
||||
| `(tactic| repeat $seq) => `(tactic| first | ($seq); repeat $seq | skip)
|
||||
|
||||
/--
|
||||
`repeat' tac` recursively applies `tac` on all of the goals so long as it succeeds.
|
||||
|
||||
@@ -270,10 +270,11 @@ def registerParametricAttribute (impl : ParametricAttributeImpl α) : IO (Parame
|
||||
let mut r := if impl.preserveOrder then
|
||||
decls.toArray.reverse.filterMap (fun n => return (n, ← m.find? n))
|
||||
else
|
||||
m.foldl (fun a n p => a.push (n, p)) #[]
|
||||
let r := m.foldl (fun a n p => a.push (n, p)) #[]
|
||||
r.qsort (fun a b => Name.quickLt a.1 b.1)
|
||||
if lvl != .private then
|
||||
r := r.filter (fun ⟨n, a⟩ => impl.filterExport env n a)
|
||||
r.qsort (fun a b => Name.quickLt a.1 b.1)
|
||||
r
|
||||
statsFn := fun (_, m) => "parametric attribute" ++ Format.line ++ "number of local entries: " ++ format m.size
|
||||
}
|
||||
let attrImpl : AttributeImpl := {
|
||||
|
||||
@@ -33,6 +33,7 @@ def isAuxRecursor (env : Environment) (declName : Name) : Bool :=
|
||||
-- TODO: use `markAuxRecursor` when they are defined
|
||||
-- An attribute is not a good solution since we don't want users to control what is tagged as an auxiliary recursor.
|
||||
|| declName == ``Eq.ndrec
|
||||
|| declName == ``Eq.ndrec_symm
|
||||
|| declName == ``Eq.ndrecOn
|
||||
|
||||
def isAuxRecursorWithSuffix (env : Environment) (declName : Name) (suffix : String) : Bool :=
|
||||
|
||||
@@ -115,10 +115,10 @@ private def exportIREntries (env : Environment) : Array (Name × Array EnvExtens
|
||||
-- safety: cast to erased type
|
||||
let irEntries : Array EnvExtensionEntry := unsafe unsafeCast <| sortDecls irDecls
|
||||
|
||||
-- see `regularInitAttr.filterExport`
|
||||
let initDecls : Array (Name × Name) := regularInitAttr.ext.getState env
|
||||
|>.2.foldl (fun a n p => a.push (n, p)) #[]
|
||||
|>.qsort (fun a b => Name.quickLt a.1 b.1)
|
||||
-- save all initializers independent of meta/private. Non-meta initializers will only be used when
|
||||
-- .ir is actually loaded, and private ones iff visible.
|
||||
let initDecls : Array (Name × Name) :=
|
||||
regularInitAttr.ext.exportEntriesFn env (regularInitAttr.ext.getState env) .private
|
||||
-- safety: cast to erased type
|
||||
let initDecls : Array EnvExtensionEntry := unsafe unsafeCast initDecls
|
||||
|
||||
|
||||
@@ -40,14 +40,14 @@ structure BuilderState where
|
||||
For this reason we carry around these kinds of bindings in this substitution and apply it whenever
|
||||
we access an fvar in the conversion.
|
||||
-/
|
||||
subst : LCNF.FVarSubst := {}
|
||||
subst : LCNF.FVarSubst .pure := {}
|
||||
|
||||
abbrev M := StateRefT BuilderState CoreM
|
||||
|
||||
instance : LCNF.MonadFVarSubst M false where
|
||||
instance : LCNF.MonadFVarSubst M .pure false where
|
||||
getSubst := return (← get).subst
|
||||
|
||||
instance : LCNF.MonadFVarSubstState M where
|
||||
instance : LCNF.MonadFVarSubstState M .pure where
|
||||
modifySubst f := modify fun s => { s with subst := f s.subst }
|
||||
|
||||
def M.run (x : M α) : CoreM α := do
|
||||
@@ -102,7 +102,7 @@ def lowerLitValue (v : LCNF.LitValue) : LitVal × IRType :=
|
||||
| .uint64 v => ⟨.num (UInt64.toNat v), .uint64⟩
|
||||
| .usize v => ⟨.num (UInt64.toNat v), .usize⟩
|
||||
|
||||
def lowerArg (a : LCNF.Arg) : M Arg := do
|
||||
def lowerArg (a : LCNF.Arg .pure) : M Arg := do
|
||||
match a with
|
||||
| .fvar fvarId => getFVarValue fvarId
|
||||
| .erased | .type .. => return .erased
|
||||
@@ -121,15 +121,15 @@ def lowerProj (base : VarId) (ctorInfo : CtorInfo) (field : CtorFieldInfo)
|
||||
| .erased => ⟨.erased, .erased⟩
|
||||
| .void => ⟨.erased, .void⟩
|
||||
|
||||
def lowerParam (p : LCNF.Param) : M Param := do
|
||||
def lowerParam (p : LCNF.Param .pure) : M Param := do
|
||||
let x ← bindVar p.fvarId
|
||||
let ty ← toIRType p.type
|
||||
if ty.isVoid || ty.isErased then
|
||||
Compiler.LCNF.addSubst p.fvarId .erased
|
||||
Compiler.LCNF.addSubst p.fvarId (.erased : LCNF.Arg .pure)
|
||||
return { x, borrow := p.borrow, ty }
|
||||
|
||||
mutual
|
||||
partial def lowerCode (c : LCNF.Code) : M FnBody := do
|
||||
partial def lowerCode (c : LCNF.Code .pure) : M FnBody := do
|
||||
match c with
|
||||
| .let decl k => lowerLet decl k
|
||||
| .jp decl k =>
|
||||
@@ -149,7 +149,7 @@ partial def lowerCode (c : LCNF.Code) : M FnBody := do
|
||||
for idx in 0...ps.size do
|
||||
let p := ps[idx]!
|
||||
if idx == info.fieldIdx then
|
||||
LCNF.addSubst p.fvarId (.fvar cases.discr)
|
||||
LCNF.addSubst p.fvarId (.fvar cases.discr : LCNF.Arg .pure)
|
||||
else
|
||||
bindErased p.fvarId
|
||||
lowerCode k
|
||||
@@ -165,7 +165,7 @@ partial def lowerCode (c : LCNF.Code) : M FnBody := do
|
||||
| .unreach .. => return .unreachable
|
||||
| .fun .. => panic! "all local functions should be λ-lifted"
|
||||
|
||||
partial def lowerLet (decl : LCNF.LetDecl) (k : LCNF.Code) : M FnBody := do
|
||||
partial def lowerLet (decl : LCNF.LetDecl .pure) (k : LCNF.Code .pure) : M FnBody := do
|
||||
let value ← LCNF.normLetValue decl.value
|
||||
match value with
|
||||
| .lit litValue =>
|
||||
@@ -175,7 +175,7 @@ partial def lowerLet (decl : LCNF.LetDecl) (k : LCNF.Code) : M FnBody := do
|
||||
| .proj typeName i fvarId =>
|
||||
if let some info ← hasTrivialStructure? typeName then
|
||||
if info.fieldIdx == i then
|
||||
LCNF.addSubst decl.fvarId (.fvar fvarId)
|
||||
LCNF.addSubst decl.fvarId (.fvar fvarId : LCNF.Arg .pure)
|
||||
else
|
||||
bindErased decl.fvarId
|
||||
lowerCode k
|
||||
@@ -250,7 +250,8 @@ partial def lowerLet (decl : LCNF.LetDecl) (k : LCNF.Code) : M FnBody := do
|
||||
| some (.defnInfo ..) | some (.opaqueInfo ..) =>
|
||||
mkFap name irArgs
|
||||
| some (.axiomInfo ..) | .some (.quotInfo ..) | .some (.inductInfo ..) | .some (.thmInfo ..) =>
|
||||
throwNamedError lean.dependsOnNoncomputable f!"`{name}` not supported by code generator; consider marking definition as `noncomputable`"
|
||||
-- Should have been caught by `ToLCNF`
|
||||
throwError f!"ToIR: unexpected use of noncomputable declaration `{name}`; please report this issue"
|
||||
| some (.recInfo ..) =>
|
||||
throwError f!"code generator does not support recursor `{name}` yet, consider using 'match ... with' and/or structural recursion"
|
||||
| none => panic! "reference to unbound name"
|
||||
@@ -302,11 +303,11 @@ where
|
||||
else
|
||||
mkOverApplication name numParams args
|
||||
|
||||
partial def lowerAlt (discr : VarId) (a : LCNF.Alt) : M Alt := do
|
||||
partial def lowerAlt (discr : VarId) (a : LCNF.Alt .pure) : M Alt := do
|
||||
match a with
|
||||
| .alt ctorName params code =>
|
||||
let ⟨ctorInfo, fields⟩ ← getCtorLayout ctorName
|
||||
let lowerParams (params : Array LCNF.Param) (fields : Array CtorFieldInfo) : M FnBody := do
|
||||
let lowerParams (params : Array (LCNF.Param .pure)) (fields : Array CtorFieldInfo) : M FnBody := do
|
||||
let rec loop (i : Nat) : M FnBody := do
|
||||
match params[i]?, fields[i]? with
|
||||
| some param, some field =>
|
||||
@@ -340,7 +341,7 @@ where resultTypeForArity (type : Lean.Expr) (arity : Nat) : Lean.Expr :=
|
||||
| .const ``lcErased _ => mkConst ``lcErased
|
||||
| _ => panic! "invalid arity"
|
||||
|
||||
def lowerDecl (d : LCNF.Decl) : M (Option Decl) := do
|
||||
def lowerDecl (d : LCNF.Decl .pure) : M (Option Decl) := do
|
||||
let params ← d.params.mapM lowerParam
|
||||
let mut resultType ← lowerResultType d.type d.params.size
|
||||
let taggedReturn := taggedReturnAttr.hasTag (← getEnv) d.name
|
||||
@@ -366,7 +367,7 @@ def lowerDecl (d : LCNF.Decl) : M (Option Decl) := do
|
||||
|
||||
end ToIR
|
||||
|
||||
def toIR (decls: Array LCNF.Decl) : CoreM (Array Decl) := do
|
||||
def toIR (decls: Array (LCNF.Decl .pure)) : CoreM (Array Decl) := do
|
||||
let mut irDecls := #[]
|
||||
for decl in decls do
|
||||
if let some irDecl ← ToIR.lowerDecl decl |>.run then
|
||||
|
||||
@@ -174,8 +174,11 @@ private unsafe def runInitAttrs (env : Environment) (opts : Options) : IO Unit :
|
||||
continue
|
||||
interpretedModInits.modify (·.insert mod)
|
||||
let modEntries := regularInitAttr.ext.getModuleEntries env modIdx
|
||||
-- `getModuleIREntries` is identical to `getModuleEntries` if we loaded only one of .olean/.ir
|
||||
-- so deduplicate (these lists should be very short)
|
||||
-- `getModuleIREntries` is identical to `getModuleEntries` if we loaded only one of
|
||||
-- .olean (from `meta initialize`)/.ir (`initialize` via transitive `meta import`)
|
||||
-- so deduplicate (these lists should be very short).
|
||||
-- If we have both, we should not need to worry about their relative ordering as `meta` and
|
||||
-- non-`meta` initialize should not have interdependencies.
|
||||
let modEntries := modEntries ++ (regularInitAttr.ext.getModuleIREntries env modIdx).filter (!modEntries.contains ·)
|
||||
for (decl, initDecl) in modEntries do
|
||||
-- Skip initializers we do not have IR for; they should not be reachable by interpretation.
|
||||
|
||||
@@ -30,7 +30,6 @@ public import Lean.Compiler.LCNF.ReduceJpArity
|
||||
public import Lean.Compiler.LCNF.Simp
|
||||
public import Lean.Compiler.LCNF.Specialize
|
||||
public import Lean.Compiler.LCNF.SpecInfo
|
||||
public import Lean.Compiler.LCNF.Testing
|
||||
public import Lean.Compiler.LCNF.ToDecl
|
||||
public import Lean.Compiler.LCNF.ToExpr
|
||||
public import Lean.Compiler.LCNF.ToLCNF
|
||||
|
||||
@@ -40,14 +40,14 @@ def eqvTypes (es₁ es₂ : Array Expr) : EqvM Bool := do
|
||||
else
|
||||
return false
|
||||
|
||||
def eqvArg (a₁ a₂ : Arg) : EqvM Bool := do
|
||||
def eqvArg (a₁ a₂ : Arg pu) : EqvM Bool := do
|
||||
match a₁, a₂ with
|
||||
| .type e₁, .type e₂ => eqvType e₁ e₂
|
||||
| .type e₁ _, .type e₂ _ => eqvType e₁ e₂
|
||||
| .fvar x₁, .fvar x₂ => eqvFVar x₁ x₂
|
||||
| .erased, .erased => return true
|
||||
| _, _ => return false
|
||||
|
||||
def eqvArgs (as₁ as₂ : Array Arg) : EqvM Bool := do
|
||||
def eqvArgs (as₁ as₂ : Array (Arg pu)) : EqvM Bool := do
|
||||
if as₁.size = as₂.size then
|
||||
for a₁ in as₁, a₂ in as₂ do
|
||||
unless (← eqvArg a₁ a₂) do
|
||||
@@ -56,19 +56,19 @@ def eqvArgs (as₁ as₂ : Array Arg) : EqvM Bool := do
|
||||
else
|
||||
return false
|
||||
|
||||
def eqvLetValue (e₁ e₂ : LetValue) : EqvM Bool := do
|
||||
def eqvLetValue (e₁ e₂ : LetValue pu) : EqvM Bool := do
|
||||
match e₁, e₂ with
|
||||
| .lit v₁, .lit v₂ => return v₁ == v₂
|
||||
| .erased, .erased => return true
|
||||
| .proj s₁ i₁ x₁, .proj s₂ i₂ x₂ => pure (s₁ == s₂ && i₁ == i₂) <&&> eqvFVar x₁ x₂
|
||||
| .const n₁ us₁ as₁, .const n₂ us₂ as₂ => pure (n₁ == n₂ && us₁ == us₂) <&&> eqvArgs as₁ as₂
|
||||
| .proj s₁ i₁ x₁ _, .proj s₂ i₂ x₂ _ => pure (s₁ == s₂ && i₁ == i₂) <&&> eqvFVar x₁ x₂
|
||||
| .const n₁ us₁ as₁ _, .const n₂ us₂ as₂ _ => pure (n₁ == n₂ && us₁ == us₂) <&&> eqvArgs as₁ as₂
|
||||
| .fvar f₁ as₁, .fvar f₂ as₂ => eqvFVar f₁ f₂ <&&> eqvArgs as₁ as₂
|
||||
| _, _ => return false
|
||||
|
||||
@[inline] def withFVar (fvarId₁ fvarId₂ : FVarId) (x : EqvM α) : EqvM α :=
|
||||
withReader (·.insert fvarId₂ fvarId₁) x
|
||||
|
||||
@[inline] def withParams (params₁ params₂ : Array Param) (x : EqvM Bool) : EqvM Bool := do
|
||||
@[inline] def withParams (params₁ params₂ : Array (Param pu)) (x : EqvM Bool) : EqvM Bool := do
|
||||
if h : params₂.size = params₁.size then
|
||||
let rec @[specialize] go (i : Nat) : EqvM Bool := do
|
||||
if h : i < params₁.size then
|
||||
@@ -85,7 +85,7 @@ def eqvLetValue (e₁ e₂ : LetValue) : EqvM Bool := do
|
||||
else
|
||||
return false
|
||||
|
||||
def sortAlts (alts : Array Alt) : Array Alt :=
|
||||
def sortAlts (alts : Array (Alt pu)) : Array (Alt pu) :=
|
||||
alts.qsort fun
|
||||
| .alt .., .default .. => true
|
||||
| .alt ctorName₁ .., .alt ctorName₂ .. => Name.lt ctorName₁ ctorName₂
|
||||
@@ -93,13 +93,13 @@ def sortAlts (alts : Array Alt) : Array Alt :=
|
||||
|
||||
mutual
|
||||
|
||||
partial def eqvAlts (alts₁ alts₂ : Array Alt) : EqvM Bool := do
|
||||
partial def eqvAlts (alts₁ alts₂ : Array (Alt pu)) : EqvM Bool := do
|
||||
if alts₁.size = alts₂.size then
|
||||
let alts₁ := sortAlts alts₁
|
||||
let alts₂ := sortAlts alts₂
|
||||
for alt₁ in alts₁, alt₂ in alts₂ do
|
||||
match alt₁, alt₂ with
|
||||
| .alt ctorName₁ ps₁ k₁, .alt ctorName₂ ps₂ k₂ =>
|
||||
| .alt ctorName₁ ps₁ k₁ _, .alt ctorName₂ ps₂ k₂ _ =>
|
||||
unless ctorName₁ == ctorName₂ do return false
|
||||
unless (← withParams ps₁ ps₂ (eqv k₁ k₂)) do return false
|
||||
| .default k₁, .default k₂ => unless (← eqv k₁ k₂) do return false
|
||||
@@ -108,13 +108,13 @@ partial def eqvAlts (alts₁ alts₂ : Array Alt) : EqvM Bool := do
|
||||
else
|
||||
return false
|
||||
|
||||
partial def eqv (code₁ code₂ : Code) : EqvM Bool := do
|
||||
partial def eqv (code₁ code₂ : Code pu) : EqvM Bool := do
|
||||
match code₁, code₂ with
|
||||
| .let decl₁ k₁, .let decl₂ k₂ =>
|
||||
eqvType decl₁.type decl₂.type <&&>
|
||||
eqvLetValue decl₁.value decl₂.value <&&>
|
||||
withFVar decl₁.fvarId decl₂.fvarId (eqv k₁ k₂)
|
||||
| .fun decl₁ k₁, .fun decl₂ k₂
|
||||
| .fun decl₁ k₁ _, .fun decl₂ k₂ _
|
||||
| .jp decl₁ k₁, .jp decl₂ k₂ =>
|
||||
eqvType decl₁.type decl₂.type <&&>
|
||||
withParams decl₁.params decl₂.params (eqv decl₁.value decl₂.value) <&&>
|
||||
@@ -135,7 +135,7 @@ end AlphaEqv
|
||||
/--
|
||||
Return `true` if `c₁` and `c₂` are alpha equivalent.
|
||||
-/
|
||||
def Code.alphaEqv (c₁ c₂ : Code) : Bool :=
|
||||
def Code.alphaEqv (c₁ c₂ : Code pu) : Bool :=
|
||||
AlphaEqv.eqv c₁ c₂ |>.run {}
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
||||
@@ -13,15 +13,21 @@ public section
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
|
||||
builtin_initialize auxDeclCacheExt : CacheExtension Decl Name ← CacheExtension.register
|
||||
structure AuxDeclCacheKey where
|
||||
pu : Purity
|
||||
decl : Decl pu
|
||||
deriving BEq, Hashable
|
||||
|
||||
builtin_initialize auxDeclCacheExt : CacheExtension AuxDeclCacheKey Name ← CacheExtension.register
|
||||
|
||||
inductive CacheAuxDeclResult where
|
||||
| new
|
||||
| alreadyCached (declName : Name)
|
||||
|
||||
def cacheAuxDecl (decl : Decl) : CompilerM CacheAuxDeclResult := do
|
||||
def cacheAuxDecl (decl : Decl pu) : CompilerM CacheAuxDeclResult := do
|
||||
let key := { decl with name := .anonymous }
|
||||
let key ← normalizeFVarIds key
|
||||
let key := ⟨pu, key⟩
|
||||
match (← auxDeclCacheExt.find? key) with
|
||||
| some declName =>
|
||||
return .alreadyCached declName
|
||||
|
||||
@@ -24,14 +24,50 @@ and the approach described in the paper
|
||||
|
||||
-/
|
||||
|
||||
structure Param where
|
||||
/--
|
||||
This type is used to index the fundamental LCNF IR data structures. Depending on its value different
|
||||
constructors are available for the different semantic phases of LCNF.
|
||||
|
||||
Notably in order to save memory we never index the IR types over `Purity`. Instead the type is
|
||||
parametrized by the phase and the individual constructors might carry a proof (that will be erased)
|
||||
that they are only allowed in a certain phase.
|
||||
-/
|
||||
inductive Purity where
|
||||
/--
|
||||
The code we are acting on is still pure, things like reordering up to value dependencies are
|
||||
acceptable.
|
||||
-/
|
||||
| pure
|
||||
/--
|
||||
The code we are acting on is to be considered generally impure, doing reorderings is potentially
|
||||
no longer legal.
|
||||
-/
|
||||
| impure
|
||||
deriving Inhabited, DecidableEq, Hashable
|
||||
|
||||
instance : ToString Purity where
|
||||
toString
|
||||
| .pure => "pure"
|
||||
| .impure => "impure"
|
||||
|
||||
@[inline]
|
||||
def Purity.withAssertPurity [Inhabited α] (is : Purity) (should : Purity)
|
||||
(k : (is = should) → α) : α :=
|
||||
if h : is = should then
|
||||
k h
|
||||
else
|
||||
panic! s!"Purity should be {should} but is {is}, this is a bug"
|
||||
|
||||
scoped macro "purity_tac" : tactic => `(tactic| first | with_reducible rfl | assumption)
|
||||
|
||||
structure Param (pu : Purity) where
|
||||
fvarId : FVarId
|
||||
binderName : Name
|
||||
type : Expr
|
||||
borrow : Bool
|
||||
deriving Inhabited, BEq
|
||||
|
||||
def Param.toExpr (p : Param) : Expr :=
|
||||
def Param.toExpr (p : Param pu) : Expr :=
|
||||
.fvar p.fvarId
|
||||
|
||||
inductive LitValue where
|
||||
@@ -55,111 +91,111 @@ def LitValue.toExpr : LitValue → Expr
|
||||
| .uint64 v => .app (.const ``UInt64.ofNat []) (.lit (.natVal (UInt64.toNat v)))
|
||||
| .usize v => .app (.const ``USize.ofNat []) (.lit (.natVal (UInt64.toNat v)))
|
||||
|
||||
inductive Arg where
|
||||
inductive Arg (pu : Purity) where
|
||||
| erased
|
||||
| fvar (fvarId : FVarId)
|
||||
| type (expr : Expr)
|
||||
| type (expr : Expr) (h : pu = .pure := by purity_tac)
|
||||
deriving Inhabited, BEq, Hashable
|
||||
|
||||
def Param.toArg (p : Param) : Arg :=
|
||||
def Param.toArg (p : Param pu) : Arg pu :=
|
||||
.fvar p.fvarId
|
||||
|
||||
def Arg.toExpr (arg : Arg) : Expr :=
|
||||
def Arg.toExpr (arg : Arg pu) : Expr :=
|
||||
match arg with
|
||||
| .erased => erasedExpr
|
||||
| .fvar fvarId => .fvar fvarId
|
||||
| .type e => e
|
||||
| .type e _ => e
|
||||
|
||||
private unsafe def Arg.updateTypeImp (arg : Arg) (type' : Expr) : Arg :=
|
||||
private unsafe def Arg.updateTypeImp (arg : Arg pu) (type' : Expr) : Arg pu :=
|
||||
match arg with
|
||||
| .type ty => if ptrEq ty type' then arg else .type type'
|
||||
| .type ty _ => if ptrEq ty type' then arg else .type type'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by Arg.updateTypeImp] opaque Arg.updateType! (arg : Arg) (type : Expr) : Arg
|
||||
@[implemented_by Arg.updateTypeImp] opaque Arg.updateType! (arg : Arg pu) (type : Expr) : Arg pu
|
||||
|
||||
private unsafe def Arg.updateFVarImp (arg : Arg) (fvarId' : FVarId) : Arg :=
|
||||
private unsafe def Arg.updateFVarImp (arg : Arg pu) (fvarId' : FVarId) : Arg pu :=
|
||||
match arg with
|
||||
| .fvar fvarId => if fvarId' == fvarId then arg else .fvar fvarId'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by Arg.updateFVarImp] opaque Arg.updateFVar! (arg : Arg) (fvarId' : FVarId) : Arg
|
||||
@[implemented_by Arg.updateFVarImp] opaque Arg.updateFVar! (arg : Arg pu) (fvarId' : FVarId) : Arg pu
|
||||
|
||||
inductive LetValue where
|
||||
inductive LetValue (pu : Purity) where
|
||||
| lit (value : LitValue)
|
||||
| erased
|
||||
| proj (typeName : Name) (idx : Nat) (struct : FVarId)
|
||||
| const (declName : Name) (us : List Level) (args : Array Arg)
|
||||
| fvar (fvarId : FVarId) (args : Array Arg)
|
||||
| proj (typeName : Name) (idx : Nat) (struct : FVarId) (h : pu = .pure := by purity_tac)
|
||||
| const (declName : Name) (us : List Level) (args : Array (Arg pu)) (h : pu = .pure := by purity_tac)
|
||||
| fvar (fvarId : FVarId) (args : Array (Arg pu))
|
||||
deriving Inhabited, BEq, Hashable
|
||||
|
||||
def Arg.toLetValue (arg : Arg) : LetValue :=
|
||||
def Arg.toLetValue (arg : Arg pu) : LetValue pu :=
|
||||
match arg with
|
||||
| .fvar fvarId => .fvar fvarId #[]
|
||||
| .erased | .type .. => .erased
|
||||
|
||||
private unsafe def LetValue.updateProjImp (e : LetValue) (fvarId' : FVarId) : LetValue :=
|
||||
private unsafe def LetValue.updateProjImp (e : LetValue pu) (fvarId' : FVarId) : LetValue pu :=
|
||||
match e with
|
||||
| .proj s i fvarId => if fvarId == fvarId' then e else .proj s i fvarId'
|
||||
| .proj s i fvarId _ => if fvarId == fvarId' then e else .proj s i fvarId'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by LetValue.updateProjImp] opaque LetValue.updateProj! (e : LetValue) (fvarId' : FVarId) : LetValue
|
||||
@[implemented_by LetValue.updateProjImp] opaque LetValue.updateProj! (e : LetValue pu) (fvarId' : FVarId) : LetValue pu
|
||||
|
||||
private unsafe def LetValue.updateConstImp (e : LetValue) (declName' : Name) (us' : List Level) (args' : Array Arg) : LetValue :=
|
||||
private unsafe def LetValue.updateConstImp (e : LetValue pu) (declName' : Name) (us' : List Level) (args' : Array (Arg pu)) : LetValue pu :=
|
||||
match e with
|
||||
| .const declName us args => if declName == declName' && ptrEq us us' && ptrEq args args' then e else .const declName' us' args'
|
||||
| .const declName us args _ => if declName == declName' && ptrEq us us' && ptrEq args args' then e else .const declName' us' args'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by LetValue.updateConstImp] opaque LetValue.updateConst! (e : LetValue) (declName' : Name) (us' : List Level) (args' : Array Arg) : LetValue
|
||||
@[implemented_by LetValue.updateConstImp] opaque LetValue.updateConst! (e : LetValue pu) (declName' : Name) (us' : List Level) (args' : Array (Arg pu)) : LetValue pu
|
||||
|
||||
private unsafe def LetValue.updateFVarImp (e : LetValue) (fvarId' : FVarId) (args' : Array Arg) : LetValue :=
|
||||
private unsafe def LetValue.updateFVarImp (e : LetValue pu) (fvarId' : FVarId) (args' : Array (Arg pu)) : LetValue pu :=
|
||||
match e with
|
||||
| .fvar fvarId args => if fvarId == fvarId' && ptrEq args args' then e else .fvar fvarId' args'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by LetValue.updateFVarImp] opaque LetValue.updateFVar! (e : LetValue) (fvarId' : FVarId) (args' : Array Arg) : LetValue
|
||||
@[implemented_by LetValue.updateFVarImp] opaque LetValue.updateFVar! (e : LetValue pu) (fvarId' : FVarId) (args' : Array (Arg pu)) : LetValue pu
|
||||
|
||||
private unsafe def LetValue.updateArgsImp (e : LetValue) (args' : Array Arg) : LetValue :=
|
||||
private unsafe def LetValue.updateArgsImp (e : LetValue pu) (args' : Array (Arg pu)) : LetValue pu :=
|
||||
match e with
|
||||
| .const declName us args => if ptrEq args args' then e else .const declName us args'
|
||||
| .const declName us args h => if ptrEq args args' then e else .const declName us args'
|
||||
| .fvar fvarId args => if ptrEq args args' then e else .fvar fvarId args'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by LetValue.updateArgsImp] opaque LetValue.updateArgs! (e : LetValue) (args' : Array Arg) : LetValue
|
||||
@[implemented_by LetValue.updateArgsImp] opaque LetValue.updateArgs! (e : LetValue pu) (args' : Array (Arg pu)) : LetValue pu
|
||||
|
||||
def LetValue.toExpr (e : LetValue) : Expr :=
|
||||
def LetValue.toExpr (e : LetValue pu) : Expr :=
|
||||
match e with
|
||||
| .lit v => v.toExpr
|
||||
| .erased => erasedExpr
|
||||
| .proj n i s => .proj n i (.fvar s)
|
||||
| .const n us as => mkAppN (.const n us) (as.map Arg.toExpr)
|
||||
| .proj n i s _ => .proj n i (.fvar s)
|
||||
| .const n us as _ => mkAppN (.const n us) (as.map Arg.toExpr)
|
||||
| .fvar fvarId as => mkAppN (.fvar fvarId) (as.map Arg.toExpr)
|
||||
|
||||
structure LetDecl where
|
||||
structure LetDecl (pu : Purity) where
|
||||
fvarId : FVarId
|
||||
binderName : Name
|
||||
type : Expr
|
||||
value : LetValue
|
||||
value : LetValue pu
|
||||
deriving Inhabited, BEq
|
||||
|
||||
mutual
|
||||
|
||||
inductive Alt where
|
||||
| alt (ctorName : Name) (params : Array Param) (code : Code)
|
||||
| default (code : Code)
|
||||
inductive Alt (pu : Purity) where
|
||||
| alt (ctorName : Name) (params : Array (Param pu)) (code : Code pu) (h : pu = .pure := by purity_tac)
|
||||
| default (code : Code pu)
|
||||
|
||||
inductive FunDecl where
|
||||
| mk (fvarId : FVarId) (binderName : Name) (params : Array Param) (type : Expr) (value : Code)
|
||||
inductive FunDecl (pu : Purity) where
|
||||
| mk (fvarId : FVarId) (binderName : Name) (params : Array (Param pu)) (type : Expr) (value : Code pu)
|
||||
|
||||
inductive Cases where
|
||||
| mk (typeName : Name) (resultType : Expr) (discr : FVarId) (alts : Array Alt)
|
||||
inductive Cases (pu : Purity) where
|
||||
| mk (typeName : Name) (resultType : Expr) (discr : FVarId) (alts : Array (Alt pu))
|
||||
deriving Inhabited
|
||||
|
||||
inductive Code where
|
||||
| let (decl : LetDecl) (k : Code)
|
||||
| fun (decl : FunDecl) (k : Code)
|
||||
| jp (decl : FunDecl) (k : Code)
|
||||
| jmp (fvarId : FVarId) (args : Array Arg)
|
||||
| cases (cases : Cases)
|
||||
inductive Code (pu : Purity) where
|
||||
| let (decl : LetDecl pu) (k : Code pu)
|
||||
| fun (decl : FunDecl pu) (k : Code pu) (h : pu = .pure := by purity_tac)
|
||||
| jp (decl : FunDecl pu) (k : Code pu)
|
||||
| jmp (fvarId : FVarId) (args : Array (Arg pu))
|
||||
| cases (cases : Cases pu)
|
||||
| return (fvarId : FVarId)
|
||||
| unreach (type : Expr)
|
||||
deriving Inhabited
|
||||
@@ -167,99 +203,99 @@ inductive Code where
|
||||
end
|
||||
|
||||
@[inline]
|
||||
def FunDecl.fvarId : FunDecl → FVarId
|
||||
def FunDecl.fvarId : FunDecl pu → FVarId
|
||||
| .mk (fvarId := fvarId) .. => fvarId
|
||||
|
||||
@[inline]
|
||||
def FunDecl.binderName : FunDecl → Name
|
||||
def FunDecl.binderName : FunDecl pu → Name
|
||||
| .mk (binderName := binderName) .. => binderName
|
||||
|
||||
@[inline]
|
||||
def FunDecl.params : FunDecl → Array Param
|
||||
def FunDecl.params : FunDecl pu → Array (Param pu)
|
||||
| .mk (params := params) .. => params
|
||||
|
||||
@[inline]
|
||||
def FunDecl.type : FunDecl → Expr
|
||||
def FunDecl.type : FunDecl pu → Expr
|
||||
| .mk (type := type) .. => type
|
||||
|
||||
@[inline]
|
||||
def FunDecl.value : FunDecl → Code
|
||||
def FunDecl.value : FunDecl pu → Code pu
|
||||
| .mk (value := value) .. => value
|
||||
|
||||
@[inline]
|
||||
def FunDecl.updateBinderName : FunDecl → Name → FunDecl
|
||||
def FunDecl.updateBinderName : FunDecl pu → Name → FunDecl pu
|
||||
| .mk fvarId _ params type value, new =>
|
||||
.mk fvarId new params type value
|
||||
|
||||
@[inline]
|
||||
def FunDecl.toParam (decl : FunDecl) (borrow : Bool) : Param :=
|
||||
def FunDecl.toParam (decl : FunDecl pu) (borrow : Bool) : Param pu :=
|
||||
match decl with
|
||||
| .mk fvarId binderName _ type .. => ⟨fvarId, binderName, type, borrow⟩
|
||||
|
||||
@[inline]
|
||||
def Cases.typeName : Cases → Name
|
||||
def Cases.typeName : Cases pu → Name
|
||||
| .mk (typeName := typeName) .. => typeName
|
||||
|
||||
@[inline]
|
||||
def Cases.resultType : Cases → Expr
|
||||
def Cases.resultType : Cases pu → Expr
|
||||
| .mk (resultType := resultType) .. => resultType
|
||||
|
||||
@[inline]
|
||||
def Cases.discr : Cases → FVarId
|
||||
def Cases.discr : Cases pu → FVarId
|
||||
| .mk (discr := discr) .. => discr
|
||||
|
||||
@[inline]
|
||||
def Cases.alts : Cases → Array Alt
|
||||
def Cases.alts : Cases pu → Array (Alt pu)
|
||||
| .mk (alts := alts) .. => alts
|
||||
|
||||
@[inline]
|
||||
def Cases.updateAlts : Cases → Array Alt → Cases
|
||||
def Cases.updateAlts : Cases pu → Array (Alt pu) → Cases pu
|
||||
| .mk typeName resultType discr _, new =>
|
||||
.mk typeName resultType discr new
|
||||
|
||||
deriving instance Inhabited for Alt
|
||||
deriving instance Inhabited for FunDecl
|
||||
|
||||
def FunDecl.getArity (decl : FunDecl) : Nat :=
|
||||
def FunDecl.getArity (decl : FunDecl pu) : Nat :=
|
||||
decl.params.size
|
||||
|
||||
/--
|
||||
Return the constructor names that have an explicit (non-default) alternative.
|
||||
-/
|
||||
def Cases.getCtorNames (c : Cases) : NameSet :=
|
||||
def Cases.getCtorNames (c : Cases pu) : NameSet :=
|
||||
c.alts.foldl (init := {}) fun ctorNames alt =>
|
||||
match alt with
|
||||
| .default _ => ctorNames
|
||||
| .alt ctorName .. => ctorNames.insert ctorName
|
||||
|
||||
inductive CodeDecl where
|
||||
| let (decl : LetDecl)
|
||||
| fun (decl : FunDecl)
|
||||
| jp (decl : FunDecl)
|
||||
inductive CodeDecl (pu : Purity) where
|
||||
| let (decl : LetDecl pu)
|
||||
| fun (decl : FunDecl pu) (h : pu = .pure := by purity_tac)
|
||||
| jp (decl : FunDecl pu)
|
||||
deriving Inhabited
|
||||
|
||||
def CodeDecl.fvarId : CodeDecl → FVarId
|
||||
| .let decl | .fun decl | .jp decl => decl.fvarId
|
||||
def CodeDecl.fvarId : CodeDecl pu → FVarId
|
||||
| .let decl | .fun decl _ | .jp decl => decl.fvarId
|
||||
|
||||
def attachCodeDecls (decls : Array CodeDecl) (code : Code) : Code :=
|
||||
def attachCodeDecls (decls : Array (CodeDecl pu)) (code : Code pu) : Code pu :=
|
||||
go decls.size code
|
||||
where
|
||||
go (i : Nat) (code : Code) : Code :=
|
||||
go (i : Nat) (code : Code pu) : Code pu :=
|
||||
if i > 0 then
|
||||
match decls[i-1]! with
|
||||
| .let decl => go (i-1) (.let decl code)
|
||||
| .fun decl => go (i-1) (.fun decl code)
|
||||
| .fun decl _ => go (i-1) (.fun decl code)
|
||||
| .jp decl => go (i-1) (.jp decl code)
|
||||
else
|
||||
code
|
||||
|
||||
mutual
|
||||
private unsafe def eqImp (c₁ c₂ : Code) : Bool :=
|
||||
private unsafe def eqImp (c₁ c₂ : Code pu) : Bool :=
|
||||
if ptrEq c₁ c₂ then
|
||||
true
|
||||
else match c₁, c₂ with
|
||||
| .let d₁ k₁, .let d₂ k₂ => d₁ == d₂ && eqImp k₁ k₂
|
||||
| .fun d₁ k₁, .fun d₂ k₂
|
||||
| .fun d₁ k₁ _, .fun d₂ k₂ _
|
||||
| .jp d₁ k₁, .jp d₂ k₂ => eqFunDecl d₁ d₂ && eqImp k₁ k₂
|
||||
| .cases c₁, .cases c₂ => eqCases c₁ c₂
|
||||
| .jmp j₁ as₁, .jmp j₂ as₂ => j₁ == j₂ && as₁ == as₂
|
||||
@@ -267,7 +303,7 @@ mutual
|
||||
| .unreach t₁, .unreach t₂ => t₁ == t₂
|
||||
| _, _ => false
|
||||
|
||||
private unsafe def eqFunDecl (d₁ d₂ : FunDecl) : Bool :=
|
||||
private unsafe def eqFunDecl (d₁ d₂ : FunDecl pu) : Bool :=
|
||||
if ptrEq d₁ d₂ then
|
||||
true
|
||||
else
|
||||
@@ -275,62 +311,62 @@ mutual
|
||||
d₁.params == d₂.params && d₁.type == d₂.type &&
|
||||
eqImp d₁.value d₂.value
|
||||
|
||||
private unsafe def eqCases (c₁ c₂ : Cases) : Bool :=
|
||||
private unsafe def eqCases (c₁ c₂ : Cases pu) : Bool :=
|
||||
c₁.resultType == c₂.resultType && c₁.discr == c₂.discr &&
|
||||
c₁.typeName == c₂.typeName && c₁.alts.isEqv c₂.alts eqAlt
|
||||
|
||||
private unsafe def eqAlt (a₁ a₂ : Alt) : Bool :=
|
||||
private unsafe def eqAlt (a₁ a₂ : Alt pu) : Bool :=
|
||||
match a₁, a₂ with
|
||||
| .default k₁, .default k₂ => eqImp k₁ k₂
|
||||
| .alt c₁ ps₁ k₁, .alt c₂ ps₂ k₂ => c₁ == c₂ && ps₁ == ps₂ && eqImp k₁ k₂
|
||||
| .alt c₁ ps₁ k₁ _, .alt c₂ ps₂ k₂ _ => c₁ == c₂ && ps₁ == ps₂ && eqImp k₁ k₂
|
||||
| _, _ => false
|
||||
end
|
||||
|
||||
@[implemented_by eqImp] protected opaque Code.beq : Code → Code → Bool
|
||||
@[implemented_by eqImp] protected opaque Code.beq : Code pu → Code pu → Bool
|
||||
|
||||
instance : BEq Code where
|
||||
instance : BEq (Code pu) where
|
||||
beq := Code.beq
|
||||
|
||||
@[implemented_by eqFunDecl] protected opaque FunDecl.beq : FunDecl → FunDecl → Bool
|
||||
@[implemented_by eqFunDecl] protected opaque FunDecl.beq : FunDecl pu → FunDecl pu → Bool
|
||||
|
||||
instance : BEq FunDecl where
|
||||
instance : BEq (FunDecl pu) where
|
||||
beq := FunDecl.beq
|
||||
|
||||
def Alt.getCode : Alt → Code
|
||||
def Alt.getCode : Alt pu → Code pu
|
||||
| .default k => k
|
||||
| .alt _ _ k => k
|
||||
| .alt _ _ k _ => k
|
||||
|
||||
def Alt.getParams : Alt → Array Param
|
||||
def Alt.getParams : Alt pu → Array (Param pu)
|
||||
| .default _ => #[]
|
||||
| .alt _ ps _ => ps
|
||||
| .alt _ ps _ _ => ps
|
||||
|
||||
def Alt.forCodeM [Monad m] (alt : Alt) (f : Code → m Unit) : m Unit := do
|
||||
def Alt.forCodeM [Monad m] (alt : Alt pu) (f : Code pu → m Unit) : m Unit := do
|
||||
match alt with
|
||||
| .default k => f k
|
||||
| .alt _ _ k => f k
|
||||
| .alt _ _ k _ => f k
|
||||
|
||||
private unsafe def updateAltCodeImp (alt : Alt) (k' : Code) : Alt :=
|
||||
private unsafe def updateAltCodeImp (alt : Alt pu) (k' : Code pu) : Alt pu :=
|
||||
match alt with
|
||||
| .default k => if ptrEq k k' then alt else .default k'
|
||||
| .alt ctorName ps k => if ptrEq k k' then alt else .alt ctorName ps k'
|
||||
| .alt ctorName ps k _ => if ptrEq k k' then alt else .alt ctorName ps k'
|
||||
|
||||
@[implemented_by updateAltCodeImp] opaque Alt.updateCode (alt : Alt) (c : Code) : Alt
|
||||
@[implemented_by updateAltCodeImp] opaque Alt.updateCode (alt : Alt pu) (c : Code pu) : Alt pu
|
||||
|
||||
private unsafe def updateAltImp (alt : Alt) (ps' : Array Param) (k' : Code) : Alt :=
|
||||
private unsafe def updateAltImp (alt : Alt pu) (ps' : Array (Param pu)) (k' : Code pu) : Alt pu :=
|
||||
match alt with
|
||||
| .alt ctorName ps k => if ptrEq k k' && ptrEq ps ps' then alt else .alt ctorName ps' k'
|
||||
| .alt ctorName ps k _ => if ptrEq k k' && ptrEq ps ps' then alt else .alt ctorName ps' k'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by updateAltImp] opaque Alt.updateAlt! (alt : Alt) (ps' : Array Param) (k' : Code) : Alt
|
||||
@[implemented_by updateAltImp] opaque Alt.updateAlt! (alt : Alt pu) (ps' : Array (Param pu)) (k' : Code pu) : Alt pu
|
||||
|
||||
@[inline] private unsafe def updateAltsImp (c : Code) (alts : Array Alt) : Code :=
|
||||
@[inline] private unsafe def updateAltsImp (c : Code pu) (alts : Array (Alt pu)) : Code pu :=
|
||||
match c with
|
||||
| .cases cs => if ptrEq cs.alts alts then c else .cases <| cs.updateAlts alts
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by updateAltsImp] opaque Code.updateAlts! (c : Code) (alts : Array Alt) : Code
|
||||
@[implemented_by updateAltsImp] opaque Code.updateAlts! (c : Code pu) (alts : Array (Alt pu)) : Code pu
|
||||
|
||||
@[inline] private unsafe def updateCasesImp (c : Code) (resultType : Expr) (discr : FVarId) (alts : Array Alt) : Code :=
|
||||
@[inline] private unsafe def updateCasesImp (c : Code pu) (resultType : Expr) (discr : FVarId) (alts : Array (Alt pu)) : Code pu :=
|
||||
match c with
|
||||
| .cases cs =>
|
||||
if ptrEq cs.alts alts && ptrEq cs.resultType resultType && cs.discr == discr then
|
||||
@@ -339,54 +375,54 @@ private unsafe def updateAltImp (alt : Alt) (ps' : Array Param) (k' : Code) : Al
|
||||
.cases <| ⟨cs.typeName, resultType, discr, alts⟩
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by updateCasesImp] opaque Code.updateCases! (c : Code) (resultType : Expr) (discr : FVarId) (alts : Array Alt) : Code
|
||||
@[implemented_by updateCasesImp] opaque Code.updateCases! (c : Code pu) (resultType : Expr) (discr : FVarId) (alts : Array (Alt pu)) : Code pu
|
||||
|
||||
@[inline] private unsafe def updateLetImp (c : Code) (decl' : LetDecl) (k' : Code) : Code :=
|
||||
@[inline] private unsafe def updateLetImp (c : Code pu) (decl' : LetDecl pu) (k' : Code pu) : Code pu :=
|
||||
match c with
|
||||
| .let decl k => if ptrEq k k' && ptrEq decl decl' then c else .let decl' k'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by updateLetImp] opaque Code.updateLet! (c : Code) (decl' : LetDecl) (k' : Code) : Code
|
||||
@[implemented_by updateLetImp] opaque Code.updateLet! (c : Code pu) (decl' : LetDecl pu) (k' : Code pu) : Code pu
|
||||
|
||||
@[inline] private unsafe def updateContImp (c : Code) (k' : Code) : Code :=
|
||||
@[inline] private unsafe def updateContImp (c : Code pu) (k' : Code pu) : Code pu :=
|
||||
match c with
|
||||
| .let decl k => if ptrEq k k' then c else .let decl k'
|
||||
| .fun decl k => if ptrEq k k' then c else .fun decl k'
|
||||
| .fun decl k _ => if ptrEq k k' then c else .fun decl k'
|
||||
| .jp decl k => if ptrEq k k' then c else .jp decl k'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by updateContImp] opaque Code.updateCont! (c : Code) (k' : Code) : Code
|
||||
@[implemented_by updateContImp] opaque Code.updateCont! (c : Code pu) (k' : Code pu) : Code pu
|
||||
|
||||
@[inline] private unsafe def updateFunImp (c : Code) (decl' : FunDecl) (k' : Code) : Code :=
|
||||
@[inline] private unsafe def updateFunImp (c : Code pu) (decl' : FunDecl pu) (k' : Code pu) : Code pu :=
|
||||
match c with
|
||||
| .fun decl k => if ptrEq k k' && ptrEq decl decl' then c else .fun decl' k'
|
||||
| .fun decl k _ => if ptrEq k k' && ptrEq decl decl' then c else .fun decl' k'
|
||||
| .jp decl k => if ptrEq k k' && ptrEq decl decl' then c else .jp decl' k'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by updateFunImp] opaque Code.updateFun! (c : Code) (decl' : FunDecl) (k' : Code) : Code
|
||||
@[implemented_by updateFunImp] opaque Code.updateFun! (c : Code pu) (decl' : FunDecl pu) (k' : Code pu) : Code pu
|
||||
|
||||
@[inline] private unsafe def updateReturnImp (c : Code) (fvarId' : FVarId) : Code :=
|
||||
@[inline] private unsafe def updateReturnImp (c : Code pu) (fvarId' : FVarId) : Code pu :=
|
||||
match c with
|
||||
| .return fvarId => if fvarId == fvarId' then c else .return fvarId'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by updateReturnImp] opaque Code.updateReturn! (c : Code) (fvarId' : FVarId) : Code
|
||||
@[implemented_by updateReturnImp] opaque Code.updateReturn! (c : Code pu) (fvarId' : FVarId) : Code pu
|
||||
|
||||
@[inline] private unsafe def updateJmpImp (c : Code) (fvarId' : FVarId) (args' : Array Arg) : Code :=
|
||||
@[inline] private unsafe def updateJmpImp (c : Code pu) (fvarId' : FVarId) (args' : Array (Arg pu)) : Code pu :=
|
||||
match c with
|
||||
| .jmp fvarId args => if fvarId == fvarId' && ptrEq args args' then c else .jmp fvarId' args'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by updateJmpImp] opaque Code.updateJmp! (c : Code) (fvarId' : FVarId) (args' : Array Arg) : Code
|
||||
@[implemented_by updateJmpImp] opaque Code.updateJmp! (c : Code pu) (fvarId' : FVarId) (args' : Array (Arg pu)) : Code pu
|
||||
|
||||
@[inline] private unsafe def updateUnreachImp (c : Code) (type' : Expr) : Code :=
|
||||
@[inline] private unsafe def updateUnreachImp (c : Code pu) (type' : Expr) : Code pu :=
|
||||
match c with
|
||||
| .unreach type => if ptrEq type type' then c else .unreach type'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by updateUnreachImp] opaque Code.updateUnreach! (c : Code) (type' : Expr) : Code
|
||||
@[implemented_by updateUnreachImp] opaque Code.updateUnreach! (c : Code pu) (type' : Expr) : Code pu
|
||||
|
||||
private unsafe def updateParamCoreImp (p : Param) (type : Expr) : Param :=
|
||||
private unsafe def updateParamCoreImp (p : Param pu) (type : Expr) : Param pu :=
|
||||
if ptrEq type p.type then
|
||||
p
|
||||
else
|
||||
@@ -397,9 +433,9 @@ Low-level update `Param` function. It does not update the local context.
|
||||
Consider using `Param.update : Param → Expr → CompilerM Param` if you want the local context
|
||||
to be updated.
|
||||
-/
|
||||
@[implemented_by updateParamCoreImp] opaque Param.updateCore (p : Param) (type : Expr) : Param
|
||||
@[implemented_by updateParamCoreImp] opaque Param.updateCore (p : Param pu) (type : Expr) : Param pu
|
||||
|
||||
private unsafe def updateLetDeclCoreImp (decl : LetDecl) (type : Expr) (value : LetValue) : LetDecl :=
|
||||
private unsafe def updateLetDeclCoreImp (decl : LetDecl pu) (type : Expr) (value : LetValue pu) : LetDecl pu :=
|
||||
if ptrEq type decl.type && ptrEq value decl.value then
|
||||
decl
|
||||
else
|
||||
@@ -410,9 +446,9 @@ Low-level update `LetDecl` function. It does not update the local context.
|
||||
Consider using `LetDecl.update : LetDecl → Expr → Expr → CompilerM LetDecl` if you want the local context
|
||||
to be updated.
|
||||
-/
|
||||
@[implemented_by updateLetDeclCoreImp] opaque LetDecl.updateCore (decl : LetDecl) (type : Expr) (value : LetValue) : LetDecl
|
||||
@[implemented_by updateLetDeclCoreImp] opaque LetDecl.updateCore (decl : LetDecl pu) (type : Expr) (value : LetValue pu) : LetDecl pu
|
||||
|
||||
private unsafe def updateFunDeclCoreImp (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : FunDecl :=
|
||||
private unsafe def updateFunDeclCoreImp (decl: FunDecl pu) (type : Expr) (params : Array (Param pu)) (value : Code pu) : FunDecl pu :=
|
||||
if ptrEq type decl.type && ptrEq params decl.params && ptrEq value decl.value then
|
||||
decl
|
||||
else
|
||||
@@ -423,9 +459,9 @@ Low-level update `FunDecl` function. It does not update the local context.
|
||||
Consider using `FunDecl.update : LetDecl → Expr → Array Param → Code → CompilerM FunDecl` if you want the local context
|
||||
to be updated.
|
||||
-/
|
||||
@[implemented_by updateFunDeclCoreImp] opaque FunDecl.updateCore (decl : FunDecl) (type : Expr) (params : Array Param) (value : Code) : FunDecl
|
||||
@[implemented_by updateFunDeclCoreImp] opaque FunDecl.updateCore (decl : FunDecl pu) (type : Expr) (params : Array (Param pu)) (value : Code pu) : FunDecl pu
|
||||
|
||||
def Cases.extractAlt! (cases : Cases) (ctorName : Name) : Alt × Cases :=
|
||||
def Cases.extractAlt! (cases : Cases pu) (ctorName : Name) : Alt pu × Cases pu :=
|
||||
let found i := (cases.alts[i]!, cases.updateAlts (cases.alts.eraseIdx! i))
|
||||
if let some i := cases.alts.findFinIdx? fun | .alt ctorName' .. => ctorName == ctorName' | _ => false then
|
||||
found i
|
||||
@@ -434,34 +470,34 @@ def Cases.extractAlt! (cases : Cases) (ctorName : Name) : Alt × Cases :=
|
||||
else
|
||||
unreachable!
|
||||
|
||||
def Alt.mapCodeM [Monad m] (alt : Alt) (f : Code → m Code) : m Alt := do
|
||||
def Alt.mapCodeM [Monad m] (alt : Alt pu) (f : Code pu → m (Code pu)) : m (Alt pu) := do
|
||||
return alt.updateCode (← f alt.getCode)
|
||||
|
||||
def Code.isDecl : Code → Bool
|
||||
def Code.isDecl : Code pu → Bool
|
||||
| .let .. | .fun .. | .jp .. => true
|
||||
| _ => false
|
||||
|
||||
def Code.isFun : Code → Bool
|
||||
def Code.isFun : Code pu → Bool
|
||||
| .fun .. => true
|
||||
| _ => false
|
||||
|
||||
def Code.isReturnOf : Code → FVarId → Bool
|
||||
def Code.isReturnOf : Code pu → FVarId → Bool
|
||||
| .return fvarId, fvarId' => fvarId == fvarId'
|
||||
| _, _ => false
|
||||
|
||||
partial def Code.size (c : Code) : Nat :=
|
||||
partial def Code.size (c : Code pu) : Nat :=
|
||||
go c 0
|
||||
where
|
||||
go (c : Code) (n : Nat) : Nat :=
|
||||
go (c : Code pu) (n : Nat) : Nat :=
|
||||
match c with
|
||||
| .let _ k => go k (n+1)
|
||||
| .jp decl k | .fun decl k => go k <| go decl.value n
|
||||
| .jp decl k | .fun decl k _ => go k <| go decl.value n
|
||||
| .cases c => c.alts.foldl (init := n+1) fun n alt => go alt.getCode (n+1)
|
||||
| .jmp .. => n+1
|
||||
| .return .. | unreach .. => n -- `return` & `unreach` have weight zero
|
||||
|
||||
/-- Return true iff `c.size ≤ n` -/
|
||||
partial def Code.sizeLe (c : Code) (n : Nat) : Bool :=
|
||||
partial def Code.sizeLe (c : Code pu) (n : Nat) : Bool :=
|
||||
match go c |>.run 0 with
|
||||
| .ok .. => true
|
||||
| .error .. => false
|
||||
@@ -470,26 +506,26 @@ where
|
||||
modify (·+1)
|
||||
unless (← get) <= n do throw ()
|
||||
|
||||
go (c : Code) : EStateM Unit Nat Unit := do
|
||||
go (c : Code pu) : EStateM Unit Nat Unit := do
|
||||
match c with
|
||||
| .let _ k => inc; go k
|
||||
| .jp decl k | .fun decl k => inc; go decl.value; go k
|
||||
| .jp decl k | .fun decl k _ => inc; go decl.value; go k
|
||||
| .cases c => inc; c.alts.forM fun alt => go alt.getCode
|
||||
| .jmp .. => inc
|
||||
| .return .. | unreach .. => return ()
|
||||
|
||||
partial def Code.forM [Monad m] (c : Code) (f : Code → m Unit) : m Unit :=
|
||||
partial def Code.forM [Monad m] (c : Code pu) (f : Code pu → m Unit) : m Unit :=
|
||||
go c
|
||||
where
|
||||
go (c : Code) : m Unit := do
|
||||
go (c : Code pu) : m Unit := do
|
||||
f c
|
||||
match c with
|
||||
| .let _ k => go k
|
||||
| .fun decl k | .jp decl k => go decl.value; go k
|
||||
| .fun decl k _ | .jp decl k => go decl.value; go k
|
||||
| .cases c => c.alts.forM fun alt => go alt.getCode
|
||||
| .unreach .. | .return .. | .jmp .. => return ()
|
||||
|
||||
partial def Code.instantiateValueLevelParams (code : Code) (levelParams : List Name) (us : List Level) : Code :=
|
||||
partial def Code.instantiateValueLevelParams (code : Code .pure) (levelParams : List Name) (us : List Level) : Code .pure :=
|
||||
instCode code
|
||||
where
|
||||
instLevel (u : Level) :=
|
||||
@@ -498,67 +534,67 @@ where
|
||||
instExpr (e : Expr) :=
|
||||
e.instantiateLevelParamsNoCache levelParams us
|
||||
|
||||
instParams (ps : Array Param) :=
|
||||
instParams (ps : Array (Param .pure)) :=
|
||||
ps.mapMono fun p => p.updateCore (instExpr p.type)
|
||||
|
||||
instAlt (alt : Alt) :=
|
||||
instAlt (alt : Alt .pure) :=
|
||||
match alt with
|
||||
| .default k => alt.updateCode (instCode k)
|
||||
| .alt _ ps k => alt.updateAlt! (instParams ps) (instCode k)
|
||||
| .alt _ ps k _ => alt.updateAlt! (instParams ps) (instCode k)
|
||||
|
||||
instArg (arg : Arg) : Arg :=
|
||||
instArg (arg : Arg .pure) : Arg .pure :=
|
||||
match arg with
|
||||
| .type e => arg.updateType! (instExpr e)
|
||||
| .type e _ => arg.updateType! (instExpr e)
|
||||
| .fvar .. | .erased => arg
|
||||
|
||||
instLetValue (e : LetValue) : LetValue :=
|
||||
instLetValue (e : LetValue .pure) : LetValue .pure :=
|
||||
match e with
|
||||
| .const declName vs args => e.updateConst! declName (vs.mapMono instLevel) (args.mapMono instArg)
|
||||
| .const declName vs args _ => e.updateConst! declName (vs.mapMono instLevel) (args.mapMono instArg)
|
||||
| .fvar fvarId args => e.updateFVar! fvarId (args.mapMono instArg)
|
||||
| .proj .. | .lit .. | .erased => e
|
||||
|
||||
instLetDecl (decl : LetDecl) :=
|
||||
instLetDecl (decl : LetDecl .pure) :=
|
||||
decl.updateCore (instExpr decl.type) (instLetValue decl.value)
|
||||
|
||||
instFunDecl (decl : FunDecl) :=
|
||||
instFunDecl (decl : FunDecl .pure) :=
|
||||
decl.updateCore (instExpr decl.type) (instParams decl.params) (instCode decl.value)
|
||||
|
||||
instCode (code : Code) :=
|
||||
instCode (code : Code .pure) :=
|
||||
match code with
|
||||
| .let decl k => code.updateLet! (instLetDecl decl) (instCode k)
|
||||
| .jp decl k | .fun decl k => code.updateFun! (instFunDecl decl) (instCode k)
|
||||
| .jp decl k | .fun decl k _ => code.updateFun! (instFunDecl decl) (instCode k)
|
||||
| .cases c => code.updateCases! (instExpr c.resultType) c.discr (c.alts.mapMono instAlt)
|
||||
| .jmp fvarId args => code.updateJmp! fvarId (args.mapMono instArg)
|
||||
| .return .. => code
|
||||
| .unreach type => code.updateUnreach! (instExpr type)
|
||||
|
||||
inductive DeclValue where
|
||||
| code (code : Code)
|
||||
inductive DeclValue (pu : Purity) where
|
||||
| code (code : Code pu)
|
||||
| extern (externAttrData : ExternAttrData)
|
||||
deriving Inhabited, BEq
|
||||
|
||||
partial def DeclValue.size : DeclValue → Nat
|
||||
partial def DeclValue.size : DeclValue pu → Nat
|
||||
| .code c => c.size
|
||||
| .extern .. => 0
|
||||
|
||||
def DeclValue.mapCode (f : Code → Code) : DeclValue → DeclValue :=
|
||||
def DeclValue.mapCode (f : Code pu → Code pu) : DeclValue pu → DeclValue pu :=
|
||||
fun
|
||||
| .code c => .code (f c)
|
||||
| .extern e => .extern e
|
||||
|
||||
def DeclValue.mapCodeM [Monad m] (f : Code → m Code) : DeclValue → m DeclValue :=
|
||||
def DeclValue.mapCodeM [Monad m] (f : Code pu → m (Code pu)) : DeclValue pu → m (DeclValue pu) :=
|
||||
fun v => do
|
||||
match v with
|
||||
| .code c => return .code (← f c)
|
||||
| .extern .. => return v
|
||||
|
||||
def DeclValue.forCodeM [Monad m] (f : Code → m Unit) : DeclValue → m Unit :=
|
||||
def DeclValue.forCodeM [Monad m] (f : Code pu → m Unit) : DeclValue pu → m Unit :=
|
||||
fun v => do
|
||||
match v with
|
||||
| .code c => f c
|
||||
| .extern .. => return ()
|
||||
|
||||
def DeclValue.isCodeAndM [Monad m] (v : DeclValue) (f : Code → m Bool) : m Bool :=
|
||||
def DeclValue.isCodeAndM [Monad m] (v : DeclValue pu) (f : Code pu → m Bool) : m Bool :=
|
||||
match v with
|
||||
| .code c => f c
|
||||
| .extern .. => pure false
|
||||
@@ -566,7 +602,7 @@ def DeclValue.isCodeAndM [Monad m] (v : DeclValue) (f : Code → m Bool) : m Boo
|
||||
/--
|
||||
Declaration being processed by the Lean to Lean compiler passes.
|
||||
-/
|
||||
structure Decl where
|
||||
structure Decl (pu : Purity) where
|
||||
/--
|
||||
The name of the declaration from the `Environment` it came from
|
||||
-/
|
||||
@@ -584,12 +620,12 @@ structure Decl where
|
||||
/--
|
||||
Parameters.
|
||||
-/
|
||||
params : Array Param
|
||||
params : Array (Param pu)
|
||||
/--
|
||||
The body of the declaration, usually changes as it progresses
|
||||
through compiler passes.
|
||||
-/
|
||||
value : DeclValue
|
||||
value : DeclValue pu
|
||||
/--
|
||||
We set this flag to true during LCNF conversion. When we receive
|
||||
a block of functions to be compiled, we set this flag to `true`
|
||||
@@ -631,31 +667,37 @@ structure Decl where
|
||||
inlineAttr? : Option InlineAttributeKind
|
||||
deriving Inhabited, BEq
|
||||
|
||||
def Decl.size (decl : Decl) : Nat :=
|
||||
def Decl.size (decl : Decl pu) : Nat :=
|
||||
decl.value.size
|
||||
|
||||
def Decl.getArity (decl : Decl) : Nat :=
|
||||
def Decl.getArity (decl : Decl pu) : Nat :=
|
||||
decl.params.size
|
||||
|
||||
def Decl.inlineAttr (decl : Decl) : Bool :=
|
||||
def Decl.inlineAttr (decl : Decl pu) : Bool :=
|
||||
decl.inlineAttr? matches some .inline
|
||||
|
||||
def Decl.noinlineAttr (decl : Decl) : Bool :=
|
||||
def Decl.noinlineAttr (decl : Decl pu) : Bool :=
|
||||
decl.inlineAttr? matches some .noinline
|
||||
|
||||
def Decl.inlineIfReduceAttr (decl : Decl) : Bool :=
|
||||
def Decl.inlineIfReduceAttr (decl : Decl pu) : Bool :=
|
||||
decl.inlineAttr? matches some .inlineIfReduce
|
||||
|
||||
def Decl.alwaysInlineAttr (decl : Decl) : Bool :=
|
||||
def Decl.alwaysInlineAttr (decl : Decl pu) : Bool :=
|
||||
decl.inlineAttr? matches some .alwaysInline
|
||||
|
||||
/-- Return `true` if the given declaration has been annotated with `[inline]`, `[inline_if_reduce]`, `[macro_inline]`, or `[always_inline]` -/
|
||||
def Decl.inlineable (decl : Decl) : Bool :=
|
||||
def Decl.inlineable (decl : Decl pu) : Bool :=
|
||||
match decl.inlineAttr? with
|
||||
| some .noinline => false
|
||||
| some _ => true
|
||||
| none => false
|
||||
|
||||
def Decl.castPurity! (decl : Decl pu1) (pu2 : Purity) : Decl pu2 :=
|
||||
if h : pu1 = pu2 then
|
||||
h ▸ decl
|
||||
else
|
||||
panic! s!"Purity {pu1} does not match {pu2}, this is a bug"
|
||||
|
||||
/--
|
||||
Return `some i` if `decl` is of the form
|
||||
```
|
||||
@@ -669,21 +711,21 @@ That is, `f` is a sequence of declarations followed by a `cases` on the paramete
|
||||
We use this function to decide whether we should inline a declaration tagged with
|
||||
`[inline_if_reduce]` or not.
|
||||
-/
|
||||
def Decl.isCasesOnParam? (decl : Decl) : Option Nat :=
|
||||
def Decl.isCasesOnParam? (decl : Decl pu) : Option Nat :=
|
||||
match decl.value with
|
||||
| .code c => go c
|
||||
| .extern .. => none
|
||||
where
|
||||
go (code : Code) : Option Nat :=
|
||||
go {pu : Purity} (code : Code pu) : Option Nat :=
|
||||
match code with
|
||||
| .let _ k | .jp _ k | .fun _ k => go k
|
||||
| .let _ k | .jp _ k | .fun _ k _ => go k
|
||||
| .cases c => decl.params.findIdx? fun param => param.fvarId == c.discr
|
||||
| _ => none
|
||||
|
||||
def Decl.instantiateTypeLevelParams (decl : Decl) (us : List Level) : Expr :=
|
||||
def Decl.instantiateTypeLevelParams (decl : Decl pu) (us : List Level) : Expr :=
|
||||
decl.type.instantiateLevelParamsNoCache decl.levelParams us
|
||||
|
||||
def Decl.instantiateParamsLevelParams (decl : Decl) (us : List Level) : Array Param :=
|
||||
def Decl.instantiateParamsLevelParams (decl : Decl pu) (us : List Level) : Array (Param pu) :=
|
||||
decl.params.mapMono fun param => param.updateCore (param.type.instantiateLevelParamsNoCache decl.levelParams us)
|
||||
|
||||
/--
|
||||
@@ -700,11 +742,11 @@ def hasLocalInst (type : Expr) : CoreM Bool := do
|
||||
/--
|
||||
Return `true` if `decl` is supposed to be inlined/specialized.
|
||||
-/
|
||||
def Decl.isTemplateLike (decl : Decl) : CoreM Bool := do
|
||||
def Decl.isTemplateLike (decl : Decl pu) : CoreM Bool := do
|
||||
let env ← getEnv
|
||||
if ← hasLocalInst decl.type then
|
||||
return true -- `decl` applications will be specialized
|
||||
else if Meta.isInstanceCore env decl.name then
|
||||
else if (← isInstanceReducible decl.name) then
|
||||
return true -- `decl` is "fuel" for code specialization
|
||||
else if decl.inlineable || hasSpecializeAttribute env decl.name then
|
||||
return true -- `decl` is going to be inlined or specialized
|
||||
@@ -721,40 +763,40 @@ private partial def collectType (e : Expr) : FVarIdHashSet → FVarIdHashSet :=
|
||||
| .proj .. | .letE .. => unreachable!
|
||||
| _ => id
|
||||
|
||||
private def collectArg (arg : Arg) (s : FVarIdHashSet) : FVarIdHashSet :=
|
||||
private def collectArg (arg : Arg pu) (s : FVarIdHashSet) : FVarIdHashSet :=
|
||||
match arg with
|
||||
| .erased => s
|
||||
| .fvar fvarId => s.insert fvarId
|
||||
| .type e => collectType e s
|
||||
| .type e _ => collectType e s
|
||||
|
||||
private def collectArgs (args : Array Arg) (s : FVarIdHashSet) : FVarIdHashSet :=
|
||||
private def collectArgs (args : Array (Arg pu)) (s : FVarIdHashSet) : FVarIdHashSet :=
|
||||
args.foldl (init := s) fun s arg => collectArg arg s
|
||||
|
||||
private def collectLetValue (e : LetValue) (s : FVarIdHashSet) : FVarIdHashSet :=
|
||||
private def collectLetValue (e : LetValue pu) (s : FVarIdHashSet) : FVarIdHashSet :=
|
||||
match e with
|
||||
| .fvar fvarId args => collectArgs args <| s.insert fvarId
|
||||
| .const _ _ args => collectArgs args s
|
||||
| .proj _ _ fvarId => s.insert fvarId
|
||||
| .const _ _ args _ => collectArgs args s
|
||||
| .proj _ _ fvarId _ => s.insert fvarId
|
||||
| .lit .. | .erased => s
|
||||
|
||||
private partial def collectParams (ps : Array Param) (s : FVarIdHashSet) : FVarIdHashSet :=
|
||||
private partial def collectParams (ps : Array (Param pu)) (s : FVarIdHashSet) : FVarIdHashSet :=
|
||||
ps.foldl (init := s) fun s p => collectType p.type s
|
||||
|
||||
mutual
|
||||
partial def FunDecl.collectUsed (decl : FunDecl) (s : FVarIdHashSet := {}) : FVarIdHashSet :=
|
||||
partial def FunDecl.collectUsed (decl : FunDecl pu) (s : FVarIdHashSet := {}) : FVarIdHashSet :=
|
||||
decl.value.collectUsed <| collectParams decl.params <| collectType decl.type s
|
||||
|
||||
partial def Code.collectUsed (code : Code) (s : FVarIdHashSet := {}) : FVarIdHashSet :=
|
||||
partial def Code.collectUsed (code : Code pu) (s : FVarIdHashSet := {}) : FVarIdHashSet :=
|
||||
match code with
|
||||
| .let decl k => k.collectUsed <| collectLetValue decl.value <| collectType decl.type s
|
||||
| .jp decl k | .fun decl k => k.collectUsed <| decl.collectUsed s
|
||||
| .jp decl k | .fun decl k _ => k.collectUsed <| decl.collectUsed s
|
||||
| .cases c =>
|
||||
let s := s.insert c.discr
|
||||
let s := collectType c.resultType s
|
||||
c.alts.foldl (init := s) fun s alt =>
|
||||
match alt with
|
||||
| .default k => k.collectUsed s
|
||||
| .alt _ ps k => k.collectUsed <| collectParams ps s
|
||||
| .alt _ ps k _ => k.collectUsed <| collectParams ps s
|
||||
| .return fvarId => s.insert fvarId
|
||||
| .unreach type => collectType type s
|
||||
| .jmp fvarId args => collectArgs args <| s.insert fvarId
|
||||
@@ -771,7 +813,7 @@ This is an overapproximation, and relies on the fact that our frontend
|
||||
computes strongly connected components.
|
||||
See comment at `recursive` field.
|
||||
-/
|
||||
partial def markRecDecls (decls : Array Decl) : Array Decl :=
|
||||
partial def markRecDecls (decls : Array (Decl pu)) : Array (Decl pu) :=
|
||||
let (_, isRec) := go |>.run {}
|
||||
decls.map fun decl =>
|
||||
if isRec.contains decl.name then
|
||||
@@ -779,13 +821,13 @@ partial def markRecDecls (decls : Array Decl) : Array Decl :=
|
||||
else
|
||||
decl
|
||||
where
|
||||
visit (code : Code) : StateM NameSet Unit := do
|
||||
visit {pu : Purity} (code : Code pu) : StateM NameSet Unit := do
|
||||
match code with
|
||||
| .jp decl k | .fun decl k => visit decl.value; visit k
|
||||
| .jp decl k | .fun decl k _ => visit decl.value; visit k
|
||||
| .cases c => c.alts.forM fun alt => visit alt.getCode
|
||||
| .unreach .. | .jmp .. | .return .. => return ()
|
||||
| .let decl k =>
|
||||
if let .const declName _ _ := decl.value then
|
||||
if let .const declName _ _ _ := decl.value then
|
||||
if decls.any (·.name == declName) then
|
||||
modify fun s => s.insert declName
|
||||
visit k
|
||||
@@ -793,13 +835,13 @@ where
|
||||
go : StateM NameSet Unit :=
|
||||
decls.forM (·.value.forCodeM visit)
|
||||
|
||||
def instantiateRangeArgs (e : Expr) (beginIdx endIdx : Nat) (args : Array Arg) : Expr :=
|
||||
def instantiateRangeArgs (e : Expr) (beginIdx endIdx : Nat) (args : Array (Arg pu)) : Expr :=
|
||||
if !e.hasLooseBVars then
|
||||
e
|
||||
else
|
||||
e.instantiateRange beginIdx endIdx (args.map (·.toExpr))
|
||||
|
||||
def instantiateRevRangeArgs (e : Expr) (beginIdx endIdx : Nat) (args : Array Arg) : Expr :=
|
||||
def instantiateRevRangeArgs (e : Expr) (beginIdx endIdx : Nat) (args : Array (Arg pu)) : Expr :=
|
||||
if !e.hasLooseBVars then
|
||||
e
|
||||
else
|
||||
|
||||
@@ -14,7 +14,7 @@ namespace Lean.Compiler.LCNF
|
||||
|
||||
/-- Helper class for lifting `CompilerM.codeBind` -/
|
||||
class MonadCodeBind (m : Type → Type) where
|
||||
codeBind : (c : Code) → (f : FVarId → m Code) → m Code
|
||||
codeBind : {pu : Purity} → (c : Code pu) → (f : FVarId → m (Code pu)) → m (Code pu)
|
||||
|
||||
/--
|
||||
Return code that is equivalent to `c >>= f`. That is, executes `c`, and then `f x`, where
|
||||
@@ -25,16 +25,17 @@ an invalid block would be generated. It would be invalid because `f` would not
|
||||
be applied to `jp_i`. Note that, we could have decided to create a copy of `jp_i` where we apply `f` to it,
|
||||
by we decided to not do it to avoid code duplication.
|
||||
-/
|
||||
abbrev Code.bind [MonadCodeBind m] (c : Code) (f : FVarId → m Code) : m Code :=
|
||||
abbrev Code.bind [MonadCodeBind m] (c : Code pu) (f : FVarId → m (Code pu)) : m (Code pu) :=
|
||||
MonadCodeBind.codeBind c f
|
||||
|
||||
partial def CompilerM.codeBind (c : Code) (f : FVarId → CompilerM Code) : CompilerM Code := do
|
||||
partial def CompilerM.codeBind (c : Code pu) (f : FVarId → CompilerM (Code pu)) :
|
||||
CompilerM (Code pu) := do
|
||||
go c |>.run {}
|
||||
where
|
||||
go (c : Code) : ReaderT FVarIdSet CompilerM Code := do
|
||||
go (c : Code pu) : ReaderT FVarIdSet CompilerM (Code pu) := do
|
||||
match c with
|
||||
| .let decl k => return .let decl (← go k)
|
||||
| .fun decl k => return .fun decl (← go k)
|
||||
| .fun decl k _ => return .fun decl (← go k)
|
||||
| .jp decl k =>
|
||||
let value ← go decl.value
|
||||
let type ← value.inferParamType decl.params
|
||||
@@ -43,7 +44,7 @@ where
|
||||
return .jp decl (← go k)
|
||||
| .cases c =>
|
||||
let alts ← c.alts.mapM fun
|
||||
| .alt ctorName params k => return .alt ctorName params (← go k)
|
||||
| .alt ctorName params k _ => return .alt ctorName params (← go k)
|
||||
| .default k => return .default (← go k)
|
||||
if alts.isEmpty then
|
||||
throwError "`Code.bind` failed, empty `cases` found"
|
||||
@@ -60,7 +61,7 @@ where
|
||||
This code is not very efficient, we could ask caller to provide the type of `c >>= f`,
|
||||
but this is more convenient, and this case is seldom reached.
|
||||
-/
|
||||
let auxParam ← mkAuxParam type
|
||||
let auxParam ← mkAuxParam (pu := pu) type
|
||||
let k ← f auxParam.fvarId
|
||||
let typeNew ← k.inferType
|
||||
eraseCode k
|
||||
@@ -81,10 +82,10 @@ Create new parameters for the given arrow type.
|
||||
Example: if `type` is `Nat → Bool → Int`, the result is
|
||||
an array containing two new parameters with types `Nat` and `Bool`.
|
||||
-/
|
||||
partial def mkNewParams (type : Expr) : CompilerM (Array Param) :=
|
||||
partial def mkNewParams (type : Expr) : CompilerM (Array (Param pu)) :=
|
||||
go type #[] #[]
|
||||
where
|
||||
go (type : Expr) (xs : Array Expr) (ps : Array Param) : CompilerM (Array Param) := do
|
||||
go (type : Expr) (xs : Array Expr) (ps : Array (Param pu)) : CompilerM (Array (Param pu)) := do
|
||||
match type with
|
||||
| .forallE _ d b _ =>
|
||||
let d := d.instantiateRev xs
|
||||
@@ -98,15 +99,16 @@ where
|
||||
else
|
||||
return ps
|
||||
|
||||
def isEtaExpandCandidateCore (type : Expr) (params : Array Param) : Bool :=
|
||||
def isEtaExpandCandidateCore (type : Expr) (params : Array (Param .pure)) : Bool :=
|
||||
let typeArity := getArrowArity type
|
||||
let valueArity := params.size
|
||||
typeArity > valueArity
|
||||
|
||||
abbrev FunDecl.isEtaExpandCandidate (decl : FunDecl) : Bool :=
|
||||
abbrev FunDecl.isEtaExpandCandidate (decl : FunDecl .pure) : Bool :=
|
||||
isEtaExpandCandidateCore decl.type decl.params
|
||||
|
||||
def etaExpandCore (type : Expr) (params : Array Param) (value : Code) : CompilerM (Array Param × Code) := do
|
||||
def etaExpandCore (type : Expr) (params : Array (Param .pure)) (value : Code .pure) :
|
||||
CompilerM (Array (Param .pure) × Code .pure) := do
|
||||
let valueType ← instantiateForall type (params.map (mkFVar ·.fvarId))
|
||||
let psNew ← mkNewParams valueType
|
||||
let params := params ++ psNew
|
||||
@@ -116,17 +118,17 @@ def etaExpandCore (type : Expr) (params : Array Param) (value : Code) : Compiler
|
||||
return .let auxDecl (.return auxDecl.fvarId)
|
||||
return (params, value)
|
||||
|
||||
def etaExpandCore? (type : Expr) (params : Array Param) (value : Code) : CompilerM (Option (Array Param × Code)) := do
|
||||
def etaExpandCore? (type : Expr) (params : Array (Param .pure)) (value : Code .pure) : CompilerM (Option (Array (Param .pure) × Code .pure)) := do
|
||||
if isEtaExpandCandidateCore type params then
|
||||
etaExpandCore type params value
|
||||
else
|
||||
return none
|
||||
|
||||
def FunDecl.etaExpand (decl : FunDecl) : CompilerM FunDecl := do
|
||||
def FunDecl.etaExpand (decl : FunDecl .pure) : CompilerM (FunDecl .pure) := do
|
||||
let some (params, value) ← etaExpandCore? decl.type decl.params decl.value | return decl
|
||||
decl.update decl.type params value
|
||||
|
||||
def Decl.etaExpand (decl : Decl) : CompilerM Decl := do
|
||||
def Decl.etaExpand (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
match decl.value with
|
||||
| .code code =>
|
||||
let some (params, newCode) ← etaExpandCore? decl.type decl.params code | return decl
|
||||
|
||||
@@ -20,17 +20,17 @@ namespace CSE
|
||||
|
||||
structure State where
|
||||
map : PHashMap Expr FVarId := {}
|
||||
subst : FVarSubst := {}
|
||||
subst : FVarSubst .pure := {}
|
||||
|
||||
abbrev M := StateRefT State CompilerM
|
||||
|
||||
instance : MonadFVarSubst M false where
|
||||
instance : MonadFVarSubst M .pure false where
|
||||
getSubst := return (← get).subst
|
||||
|
||||
instance : MonadFVarSubstState M where
|
||||
instance : MonadFVarSubstState M .pure where
|
||||
modifySubst f := modify fun s => { s with subst := f s.subst }
|
||||
|
||||
@[inline] def getSubst : M FVarSubst :=
|
||||
@[inline] def getSubst : M (FVarSubst .pure) :=
|
||||
return (← get).subst
|
||||
|
||||
@[inline] def addEntry (value : Expr) (fvarId : FVarId) : M Unit :=
|
||||
@@ -40,31 +40,32 @@ instance : MonadFVarSubstState M where
|
||||
let map := (← get).map
|
||||
try x finally modify fun s => { s with map }
|
||||
|
||||
def replaceLet (decl : LetDecl) (fvarId : FVarId) : M Unit := do
|
||||
def replaceLet (decl : LetDecl .pure) (fvarId : FVarId) : M Unit := do
|
||||
eraseLetDecl decl
|
||||
addFVarSubst decl.fvarId fvarId
|
||||
|
||||
def replaceFun (decl : FunDecl) (fvarId : FVarId) : M Unit := do
|
||||
def replaceFun (decl : FunDecl .pure) (fvarId : FVarId) : M Unit := do
|
||||
eraseFunDecl decl
|
||||
addFVarSubst decl.fvarId fvarId
|
||||
|
||||
def hasNeverExtract (v : LetValue) : CompilerM Bool :=
|
||||
def hasNeverExtract (v : LetValue .pure) : CompilerM Bool :=
|
||||
match v with
|
||||
| .const declName .. =>
|
||||
return hasNeverExtractAttribute (← getEnv) declName
|
||||
| .lit _ | .erased | .proj .. | .fvar .. =>
|
||||
return false
|
||||
|
||||
partial def _root_.Lean.Compiler.LCNF.Code.cse (shouldElimFunDecls : Bool) (code : Code) : CompilerM Code :=
|
||||
partial def _root_.Lean.Compiler.LCNF.Code.cse (shouldElimFunDecls : Bool) (code : Code .pure) :
|
||||
CompilerM (Code .pure) :=
|
||||
go code |>.run' {}
|
||||
where
|
||||
goFunDecl (decl : FunDecl) : M FunDecl := do
|
||||
goFunDecl (decl : FunDecl .pure) : M (FunDecl .pure) := do
|
||||
let type ← normExpr decl.type
|
||||
let params ← normParams decl.params
|
||||
let value ← withNewScope do go decl.value
|
||||
decl.update type params value
|
||||
|
||||
go (code : Code) : M Code := do
|
||||
go (code : Code .pure) : M (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
let decl ← normLetDecl decl
|
||||
@@ -118,12 +119,13 @@ end CSE
|
||||
/--
|
||||
Common sub-expression elimination
|
||||
-/
|
||||
def Decl.cse (shouldElimFunDecls : Bool) (decl : Decl) : CompilerM Decl := do
|
||||
def Decl.cse (shouldElimFunDecls : Bool) (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
let value ← decl.value.mapCodeM (·.cse shouldElimFunDecls)
|
||||
return { decl with value }
|
||||
|
||||
def cse (phase : Phase := .base) (shouldElimFunDecls := false) (occurrence := 0) : Pass :=
|
||||
.mkPerDeclaration `cse (Decl.cse shouldElimFunDecls) phase occurrence
|
||||
phase.withPurityCheck .pure fun h =>
|
||||
.mkPerDeclaration `cse phase (h ▸ Decl.cse shouldElimFunDecls) occurrence
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.cse (inherited := true)
|
||||
|
||||
@@ -79,7 +79,8 @@ the subtype relation in sanity checks and add the necessary casts.
|
||||
-/
|
||||
|
||||
namespace Check
|
||||
open InferType
|
||||
namespace Pure
|
||||
open InferType InferType.Pure
|
||||
|
||||
/-
|
||||
Type and structural properties checker for LCNF expressions.
|
||||
@@ -110,7 +111,7 @@ def isCtorParam (f : Expr) (i : Nat) : CoreM Bool := do
|
||||
let .ctorInfo info ← getConstInfo declName | return false
|
||||
return i < info.numParams
|
||||
|
||||
def checkAppArgs (f : Expr) (args : Array Arg) : CheckM Unit := do
|
||||
def checkAppArgs (f : Expr) (args : Array (Arg .pure)) : CheckM Unit := do
|
||||
let mut fType ← inferType f
|
||||
let mut j := 0
|
||||
for h : i in *...args.size do
|
||||
@@ -129,11 +130,11 @@ def checkAppArgs (f : Expr) (args : Array Arg) : CheckM Unit := do
|
||||
let expectedType := instantiateRevRangeArgs d j i args
|
||||
if (← checkTypes) then
|
||||
let argType ← arg.inferType
|
||||
unless (← InferType.compatibleTypes argType expectedType) do
|
||||
unless (← compatibleTypes argType expectedType) do
|
||||
throwError "type mismatch at LCNF application{indentExpr (mkAppN f (args.map Arg.toExpr))}\nargument {arg.toExpr} has type{indentExpr argType}\nbut is expected to have type{indentExpr expectedType}"
|
||||
fType := b
|
||||
|
||||
def checkLetValue (e : LetValue) : CheckM Unit := do
|
||||
def checkLetValue (e : LetValue .pure) : CheckM Unit := do
|
||||
match e with
|
||||
| .lit .. | .erased => pure ()
|
||||
| .const declName us args => checkAppArgs (mkConst declName us) args
|
||||
@@ -154,18 +155,18 @@ def checkJpInScope (jp : FVarId) : CheckM Unit := do
|
||||
-/
|
||||
throwError "invalid jump to out of scope join point `{mkFVar jp}`"
|
||||
|
||||
def checkParam (param : Param) : CheckM Unit := do
|
||||
def checkParam (param : Param .pure) : CheckM Unit := do
|
||||
unless param == (← getParam param.fvarId) do
|
||||
throwError "LCNF parameter mismatch at `{param.binderName}`, does not value in local context"
|
||||
|
||||
def checkParams (params : Array Param) : CheckM Unit :=
|
||||
def checkParams (params : Array (Param .pure)) : CheckM Unit :=
|
||||
params.forM checkParam
|
||||
|
||||
def checkLetDecl (letDecl : LetDecl) : CheckM Unit := do
|
||||
def checkLetDecl (letDecl : LetDecl .pure) : CheckM Unit := do
|
||||
checkLetValue letDecl.value
|
||||
if (← checkTypes) then
|
||||
let valueType ← letDecl.value.inferType
|
||||
unless (← InferType.compatibleTypes letDecl.type valueType) do
|
||||
unless (← compatibleTypes letDecl.type valueType) do
|
||||
throwError "type mismatch at `{letDecl.binderName}`, value has type{indentExpr valueType}\nbut is expected to have type{indentExpr letDecl.type}"
|
||||
unless letDecl == (← getLetDecl letDecl.fvarId) do
|
||||
throwError "LCNF let declaration mismatch at `{letDecl.binderName}`, does not match value in local context"
|
||||
@@ -183,7 +184,7 @@ def addFVarId (fvarId : FVarId) : CheckM Unit := do
|
||||
addFVarId fvarId
|
||||
withReader (fun ctx => { ctx with jps := ctx.jps.insert fvarId }) x
|
||||
|
||||
@[inline] def withParams (params : Array Param) (x : CheckM α) : CheckM α := do
|
||||
@[inline] def withParams (params : Array (Param .pure)) (x : CheckM α) : CheckM α := do
|
||||
params.forM (addFVarId ·.fvarId)
|
||||
withReader (fun ctx => { ctx with vars := params.foldl (init := ctx.vars) fun vars p => vars.insert p.fvarId })
|
||||
x
|
||||
@@ -192,18 +193,18 @@ mutual
|
||||
|
||||
set_option linter.all false
|
||||
|
||||
partial def checkFunDeclCore (declName : Name) (params : Array Param) (type : Expr) (value : Code) : CheckM Unit := do
|
||||
partial def checkFunDeclCore (declName : Name) (params : Array (Param .pure)) (type : Expr) (value : Code .pure) : CheckM Unit := do
|
||||
checkParams params
|
||||
withParams params do
|
||||
discard <| check value
|
||||
if (← checkTypes) then
|
||||
let valueType ← mkForallParams params (← value.inferType)
|
||||
unless (← InferType.compatibleTypes type valueType) do
|
||||
unless (← compatibleTypes type valueType) do
|
||||
throwError "type mismatch at `{.ofConstName declName}`, value has type{indentExpr valueType}\nbut is expected to have type{indentExpr type}"
|
||||
|
||||
partial def checkFunDecl (funDecl : FunDecl) : CheckM Unit := do
|
||||
partial def checkFunDecl (funDecl : FunDecl .pure) : CheckM Unit := do
|
||||
checkFunDeclCore funDecl.binderName funDecl.params funDecl.type funDecl.value
|
||||
let decl ← getFunDecl funDecl.fvarId
|
||||
let decl ← getFunDecl (pu := .pure) funDecl.fvarId
|
||||
unless decl.binderName == funDecl.binderName do
|
||||
throwError "LCNF local function declaration mismatch at `{funDecl.binderName}`, binder name in local context `{decl.binderName}`"
|
||||
unless decl.type == funDecl.type do
|
||||
@@ -211,7 +212,7 @@ partial def checkFunDecl (funDecl : FunDecl) : CheckM Unit := do
|
||||
unless (← getFunDecl funDecl.fvarId) == funDecl do
|
||||
throwError "LCNF local function declaration mismatch at `{funDecl.binderName}`, declaration in local context does match"
|
||||
|
||||
partial def checkCases (c : Cases) : CheckM Unit := do
|
||||
partial def checkCases (c : Cases .pure) : CheckM Unit := do
|
||||
let mut ctorNames : NameSet := {}
|
||||
let mut hasDefault := false
|
||||
checkFVar c.discr
|
||||
@@ -230,7 +231,7 @@ partial def checkCases (c : Cases) : CheckM Unit := do
|
||||
throwError "invalid LCNF `cases`, `{ctorName}` has # {val.numFields} fields, but alternative has # {params.size} alternatives"
|
||||
withParams params do check k
|
||||
|
||||
partial def check (code : Code) : CheckM Unit := do
|
||||
partial def check (code : Code .pure) : CheckM Unit := do
|
||||
match code with
|
||||
| .let decl k => checkLetDecl decl; withFVarId decl.fvarId do check k
|
||||
| .fun decl k =>
|
||||
@@ -241,7 +242,7 @@ partial def check (code : Code) : CheckM Unit := do
|
||||
| .cases c => checkCases c
|
||||
| .jmp fvarId args =>
|
||||
checkJpInScope fvarId
|
||||
let decl ← getFunDecl fvarId
|
||||
let decl ← getFunDecl (pu := .pure) fvarId
|
||||
unless decl.getArity == args.size do
|
||||
throwError "invalid LCNF `goto`, join point {decl.binderName} has #{decl.getArity} parameters, but #{args.size} were provided"
|
||||
checkAppArgs (.fvar fvarId) args
|
||||
@@ -253,9 +254,12 @@ end
|
||||
def run (x : CheckM α) : CompilerM α :=
|
||||
x |>.run {} |>.run' {} |>.run {}
|
||||
|
||||
end Pure
|
||||
end Check
|
||||
|
||||
def Decl.check (decl : Decl) : CompilerM Unit := do
|
||||
Check.run do decl.value.forCodeM (Check.checkFunDeclCore decl.name decl.params decl.type)
|
||||
def Decl.check (decl : Decl pu) : CompilerM Unit := do
|
||||
match pu with
|
||||
| .pure => Check.Pure.run do decl.value.forCodeM (Check.Pure.checkFunDeclCore decl.name decl.params decl.type)
|
||||
| .impure => panic! "Check for impure unimplemented" -- TODO
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
||||
@@ -33,10 +33,6 @@ structure Context where
|
||||
Remark: the lambda lifting pass abstracts all `let`/`fun`-declarations.
|
||||
-/
|
||||
abstract : FVarId → Bool
|
||||
/--
|
||||
Indicates whether we are processing terms beneath a binder.
|
||||
-/
|
||||
isUnderBinder : Bool
|
||||
|
||||
/--
|
||||
State for the `ClosureM` monad.
|
||||
@@ -49,7 +45,7 @@ structure State where
|
||||
/--
|
||||
Free variables that must become new parameters of the code being specialized.
|
||||
-/
|
||||
params : Array Param := #[]
|
||||
params : Array (Param .pure) := #[]
|
||||
/--
|
||||
Let-declarations and local function declarations that are going to be "copied" to the code
|
||||
being processed. For example, when this module is used in the code specializer, the let-declarations
|
||||
@@ -60,7 +56,7 @@ structure State where
|
||||
All customers of this module try to avoid work duplication. If a let-declaration is a ground value,
|
||||
it most likely will be computed during compilation time, and work duplication is not an issue.
|
||||
-/
|
||||
decls : Array CodeDecl := #[]
|
||||
decls : Array (CodeDecl .pure) := #[]
|
||||
|
||||
/--
|
||||
Monad for implementing the dependency collector.
|
||||
@@ -79,16 +75,16 @@ mutual
|
||||
Collect dependencies in parameters. We need this because parameters may
|
||||
contain other type parameters.
|
||||
-/
|
||||
partial def collectParams (params : Array Param) : ClosureM Unit :=
|
||||
partial def collectParams (params : Array (Param .pure)) : ClosureM Unit :=
|
||||
params.forM (collectType ·.type)
|
||||
|
||||
partial def collectArg (arg : Arg) : ClosureM Unit :=
|
||||
partial def collectArg (arg : Arg .pure) : ClosureM Unit :=
|
||||
match arg with
|
||||
| .erased => return ()
|
||||
| .type e => collectType e
|
||||
| .fvar fvarId => collectFVar fvarId
|
||||
|
||||
partial def collectLetValue (e : LetValue) : ClosureM Unit := do
|
||||
partial def collectLetValue (e : LetValue .pure) : ClosureM Unit := do
|
||||
match e with
|
||||
| .erased | .lit .. => return ()
|
||||
| .proj _ _ fvarId => collectFVar fvarId
|
||||
@@ -99,12 +95,11 @@ mutual
|
||||
Collect dependencies in the given code. We need this function to be able
|
||||
to collect dependencies in a local function declaration.
|
||||
-/
|
||||
partial def collectCode (c : Code) : ClosureM Unit := do
|
||||
partial def collectCode (c : Code .pure) : ClosureM Unit := do
|
||||
match c with
|
||||
| .let decl k =>
|
||||
collectType decl.type
|
||||
withReader (fun ctx => { ctx with isUnderBinder := ctx.isUnderBinder || decl.type.isForall })
|
||||
do collectLetValue decl.value
|
||||
collectLetValue decl.value
|
||||
collectCode k
|
||||
| .fun decl k | .jp decl k => collectFunDecl decl; collectCode k
|
||||
| .cases c =>
|
||||
@@ -119,11 +114,10 @@ mutual
|
||||
| .return fvarId => collectFVar fvarId
|
||||
|
||||
/-- Collect dependencies of a local function declaration. -/
|
||||
partial def collectFunDecl (decl : FunDecl) : ClosureM Unit := do
|
||||
partial def collectFunDecl (decl : FunDecl .pure) : ClosureM Unit := do
|
||||
collectType decl.type
|
||||
collectParams decl.params
|
||||
withReader (fun ctx => { ctx with isUnderBinder := true }) do
|
||||
collectCode decl.value
|
||||
collectCode decl.value
|
||||
|
||||
/--
|
||||
Process the given free variable.
|
||||
@@ -146,7 +140,7 @@ mutual
|
||||
modify fun s => { s with params := s.params.push param }
|
||||
else if let some letDecl ← findLetDecl? fvarId then
|
||||
collectType letDecl.type
|
||||
if ctx.isUnderBinder || ctx.abstract letDecl.fvarId then
|
||||
if ctx.abstract letDecl.fvarId then
|
||||
modify fun s => { s with params := s.params.push <| { letDecl with borrow := false } }
|
||||
else
|
||||
collectLetValue letDecl.value
|
||||
@@ -161,8 +155,9 @@ mutual
|
||||
|
||||
end
|
||||
|
||||
def run (x : ClosureM α) (inScope : FVarId → Bool) (abstract : FVarId → Bool := fun _ => true) : CompilerM (α × Array Param × Array CodeDecl) := do
|
||||
let (a, s) ← x { inScope, abstract, isUnderBinder := false } |>.run {}
|
||||
def run (x : ClosureM α) (inScope : FVarId → Bool) (abstract : FVarId → Bool := fun _ => true) :
|
||||
CompilerM (α × Array (Param .pure) × Array (CodeDecl .pure)) := do
|
||||
let (a, s) ← x { inScope, abstract } |>.run {}
|
||||
-- If we've abstracted an fvar into a param, exclude its definition. Note that this still allows
|
||||
-- for other decls the removed decl depends upon to be included, but they will be removed later
|
||||
-- for having no users.
|
||||
|
||||
@@ -72,10 +72,13 @@ partial def compatibleTypesQuick (a b : Expr) : Bool :=
|
||||
| .const n us, .const m vs => n == m && List.isEqv us vs Level.isEquiv
|
||||
| _, _ => false
|
||||
|
||||
namespace InferType
|
||||
namespace Pure
|
||||
|
||||
/--
|
||||
Complete check for `compatibleTypes`. It eta-expands type formers. See comment at `compatibleTypes`.
|
||||
-/
|
||||
partial def InferType.compatibleTypesFull (a b : Expr) : InferTypeM Bool := do
|
||||
partial def compatibleTypesFull (a b : Expr) : InferTypeM Bool := do
|
||||
if a.isErased || b.isErased then
|
||||
return true
|
||||
else
|
||||
@@ -141,10 +144,13 @@ This is a simplification. We used to use `isErasedCompatible`, but this only add
|
||||
For item 2, we would have to modify the `toLCNFType` function and make sure a type former is erased if the expected
|
||||
type is not always a type former (see `S.mk` type and example in the note above).
|
||||
-/
|
||||
def InferType.compatibleTypes (a b : Expr) : InferTypeM Bool := do
|
||||
def compatibleTypes (a b : Expr) : InferTypeM Bool := do
|
||||
if compatibleTypesQuick a b then
|
||||
return true
|
||||
else
|
||||
compatibleTypesFull a b
|
||||
|
||||
end Pure
|
||||
end InferType
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
||||
@@ -21,7 +21,12 @@ inductive Phase where
|
||||
| base
|
||||
/-- In this phase polymorphism has been eliminated. -/
|
||||
| mono
|
||||
deriving Inhabited, BEq
|
||||
| impure
|
||||
deriving Inhabited, DecidableEq
|
||||
|
||||
@[expose, reducible] def Phase.toPurity : Phase → Purity
|
||||
| .base | .mono => .pure
|
||||
| .impure => .impure
|
||||
|
||||
/--
|
||||
The state managed by the `CompilerM` `Monad`.
|
||||
@@ -52,48 +57,53 @@ instance : Monad CompilerM := let i := inferInstanceAs (Monad CompilerM); { pure
|
||||
def getPhase : CompilerM Phase :=
|
||||
return (← read).phase
|
||||
|
||||
def getPurity : CompilerM Purity :=
|
||||
return (← getPhase).toPurity
|
||||
|
||||
def inBasePhase : CompilerM Bool :=
|
||||
return (← getPhase) matches .base
|
||||
|
||||
instance : AddMessageContext CompilerM where
|
||||
addMessageContext msgData := do
|
||||
let env ← getEnv
|
||||
let lctx := (← get).lctx.toLocalContext
|
||||
let lctx := (← get).lctx.toLocalContext (← getPurity)
|
||||
let opts ← getOptions
|
||||
return MessageData.withContext { env, lctx, opts, mctx := {} } msgData
|
||||
|
||||
def getType (fvarId : FVarId) : CompilerM Expr := do
|
||||
let lctx := (← get).lctx
|
||||
if let some decl := lctx.letDecls[fvarId]? then
|
||||
let pu ← getPurity
|
||||
if let some decl := (lctx.letDecls pu)[fvarId]? then
|
||||
return decl.type
|
||||
else if let some decl := lctx.params[fvarId]? then
|
||||
else if let some decl := (lctx.params pu)[fvarId]? then
|
||||
return decl.type
|
||||
else if let some decl := lctx.funDecls[fvarId]? then
|
||||
else if let some decl := (lctx.funDecls pu)[fvarId]? then
|
||||
return decl.type
|
||||
else
|
||||
throwError "unknown free variable {fvarId.name}"
|
||||
|
||||
def getBinderName (fvarId : FVarId) : CompilerM Name := do
|
||||
let lctx := (← get).lctx
|
||||
if let some decl := lctx.letDecls[fvarId]? then
|
||||
let pu ← getPurity
|
||||
if let some decl := (lctx.letDecls pu)[fvarId]? then
|
||||
return decl.binderName
|
||||
else if let some decl := lctx.params[fvarId]? then
|
||||
else if let some decl := (lctx.params pu)[fvarId]? then
|
||||
return decl.binderName
|
||||
else if let some decl := lctx.funDecls[fvarId]? then
|
||||
else if let some decl := (lctx.funDecls pu)[fvarId]? then
|
||||
return decl.binderName
|
||||
else
|
||||
throwError "unknown free variable {fvarId.name}"
|
||||
|
||||
def findParam? (fvarId : FVarId) : CompilerM (Option Param) :=
|
||||
return (← get).lctx.params[fvarId]?
|
||||
def findParam? (fvarId : FVarId) : CompilerM (Option (Param pu)) := do
|
||||
return ((← get).lctx.params pu)[fvarId]?
|
||||
|
||||
def findLetDecl? (fvarId : FVarId) : CompilerM (Option LetDecl) :=
|
||||
return (← get).lctx.letDecls[fvarId]?
|
||||
def findLetDecl? (fvarId : FVarId) : CompilerM (Option (LetDecl pu)) := do
|
||||
return ((← get).lctx.letDecls pu)[fvarId]?
|
||||
|
||||
def findFunDecl? (fvarId : FVarId) : CompilerM (Option FunDecl) :=
|
||||
return (← get).lctx.funDecls[fvarId]?
|
||||
def findFunDecl? (fvarId : FVarId) : CompilerM (Option (FunDecl pu)) := do
|
||||
return ((← get).lctx.funDecls pu)[fvarId]?
|
||||
|
||||
def findLetValue? (fvarId : FVarId) : CompilerM (Option LetValue) := do
|
||||
def findLetValue? (fvarId : FVarId) : CompilerM (Option (LetValue pu)) := do
|
||||
let some { value, .. } ← findLetDecl? fvarId | return none
|
||||
return some value
|
||||
|
||||
@@ -101,56 +111,56 @@ def isConstructorApp (fvarId : FVarId) : CompilerM Bool := do
|
||||
let some (.const declName _ _) ← findLetValue? fvarId | return false
|
||||
return (← getEnv).find? declName matches some (.ctorInfo ..)
|
||||
|
||||
def Arg.isConstructorApp (arg : Arg) : CompilerM Bool := do
|
||||
def Arg.isConstructorApp (arg : Arg pu) : CompilerM Bool := do
|
||||
let .fvar fvarId := arg | return false
|
||||
LCNF.isConstructorApp fvarId
|
||||
|
||||
def getParam (fvarId : FVarId) : CompilerM Param := do
|
||||
def getParam (fvarId : FVarId) : CompilerM (Param pu) := do
|
||||
let some param ← findParam? fvarId | throwError "unknown parameter {fvarId.name}"
|
||||
return param
|
||||
|
||||
def getLetDecl (fvarId : FVarId) : CompilerM LetDecl := do
|
||||
def getLetDecl (fvarId : FVarId) : CompilerM (LetDecl pu) := do
|
||||
let some decl ← findLetDecl? fvarId | throwError "unknown let-declaration {fvarId.name}"
|
||||
return decl
|
||||
|
||||
def getFunDecl (fvarId : FVarId) : CompilerM FunDecl := do
|
||||
def getFunDecl (fvarId : FVarId) : CompilerM (FunDecl pu) := do
|
||||
let some decl ← findFunDecl? fvarId | throwError "unknown local function {fvarId.name}"
|
||||
return decl
|
||||
|
||||
@[inline] def modifyLCtx (f : LCtx → LCtx) : CompilerM Unit := do
|
||||
modify fun s => { s with lctx := f s.lctx }
|
||||
|
||||
def eraseLetDecl (decl : LetDecl) : CompilerM Unit := do
|
||||
def eraseLetDecl (decl : LetDecl pu) : CompilerM Unit := do
|
||||
modifyLCtx fun lctx => lctx.eraseLetDecl decl
|
||||
|
||||
def eraseFunDecl (decl : FunDecl) (recursive := true) : CompilerM Unit := do
|
||||
def eraseFunDecl (decl : FunDecl pu) (recursive := true) : CompilerM Unit := do
|
||||
modifyLCtx fun lctx => lctx.eraseFunDecl decl recursive
|
||||
|
||||
def eraseCode (code : Code) : CompilerM Unit := do
|
||||
def eraseCode (code : Code pu) : CompilerM Unit := do
|
||||
modifyLCtx fun lctx => lctx.eraseCode code
|
||||
|
||||
def eraseParam (param : Param) : CompilerM Unit :=
|
||||
def eraseParam (param : Param pu) : CompilerM Unit :=
|
||||
modifyLCtx fun lctx => lctx.eraseParam param
|
||||
|
||||
def eraseParams (params : Array Param) : CompilerM Unit :=
|
||||
def eraseParams (params : Array (Param pu)) : CompilerM Unit :=
|
||||
modifyLCtx fun lctx => lctx.eraseParams params
|
||||
|
||||
def eraseCodeDecl (decl : CodeDecl) : CompilerM Unit := do
|
||||
def eraseCodeDecl (decl : CodeDecl pu) : CompilerM Unit := do
|
||||
match decl with
|
||||
| .let decl => eraseLetDecl decl
|
||||
| .jp decl | .fun decl => eraseFunDecl decl
|
||||
| .jp decl | .fun decl _ => eraseFunDecl decl
|
||||
|
||||
/--
|
||||
Erase all free variables occurring in `decls` from the local context.
|
||||
-/
|
||||
def eraseCodeDecls (decls : Array CodeDecl) : CompilerM Unit := do
|
||||
def eraseCodeDecls (decls : Array (CodeDecl pu)) : CompilerM Unit := do
|
||||
decls.forM fun decl => eraseCodeDecl decl
|
||||
|
||||
def eraseDecl (decl : Decl) : CompilerM Unit := do
|
||||
def eraseDecl (decl : Decl pu) : CompilerM Unit := do
|
||||
eraseParams decl.params
|
||||
decl.value.forCodeM eraseCode
|
||||
|
||||
abbrev Decl.erase (decl : Decl) : CompilerM Unit :=
|
||||
abbrev Decl.erase (decl : Decl pu) : CompilerM Unit :=
|
||||
eraseDecl decl
|
||||
|
||||
/--
|
||||
@@ -166,7 +176,7 @@ it is a free variable, a type (or type former), or `lcErased`.
|
||||
|
||||
`Check.lean` contains a substitution validator.
|
||||
-/
|
||||
abbrev FVarSubst := Std.HashMap FVarId Arg
|
||||
abbrev FVarSubst (pu : Purity) := Std.HashMap FVarId (Arg pu)
|
||||
|
||||
/--
|
||||
Replace the free variables in `e` using the given substitution.
|
||||
@@ -179,7 +189,7 @@ If `translator = false`, we assume the substitution contains free variable repla
|
||||
and given entries such as `x₁ ↦ x₂`, `x₂ ↦ x₃`, ..., `xₙ₋₁ ↦ xₙ`, and the expression `f x₁ x₂`, we want the resulting
|
||||
expression to be `f xₙ xₙ`. We use this setting, for example, in the simplifier.
|
||||
-/
|
||||
private partial def normExprImp (s : FVarSubst) (e : Expr) (translator : Bool) : Expr :=
|
||||
private partial def normExprImp (s : FVarSubst pu) (e : Expr) (translator : Bool) : Expr :=
|
||||
go e
|
||||
where
|
||||
goApp (e : Expr) : Expr :=
|
||||
@@ -192,7 +202,7 @@ where
|
||||
match e with
|
||||
| .fvar fvarId => match s[fvarId]? with
|
||||
| some (.fvar fvarId') => if translator then .fvar fvarId' else go (.fvar fvarId')
|
||||
| some (.type e) => if translator then e else go e
|
||||
| some (.type e _) => if translator then e else go e
|
||||
| some .erased => erasedExpr
|
||||
| none => e
|
||||
| .lit .. | .const .. | .sort .. | .mvar .. | .bvar .. => e
|
||||
@@ -225,7 +235,7 @@ This function panics if the substitution is mapping `fvarId` to an expression th
|
||||
That is, it is not a type (or type former), nor `lcErased`. Recall that a valid `FVarSubst` contains only
|
||||
expressions that are free variables, `lcErased`, or type formers.
|
||||
-/
|
||||
partial def normFVarImp (s : FVarSubst) (fvarId : FVarId) (translator : Bool) : NormFVarResult :=
|
||||
partial def normFVarImp (s : FVarSubst pu) (fvarId : FVarId) (translator : Bool) : NormFVarResult :=
|
||||
match s[fvarId]? with
|
||||
| some (.fvar fvarId') =>
|
||||
if translator then
|
||||
@@ -234,7 +244,7 @@ partial def normFVarImp (s : FVarSubst) (fvarId : FVarId) (translator : Bool) :
|
||||
normFVarImp s fvarId' translator
|
||||
-- Types and type formers are only preserved as hints and
|
||||
-- are erased in computationally relevant contexts.
|
||||
| some .erased | some (.type _) => .erased
|
||||
| some .erased | some (.type _ _) => .erased
|
||||
| none => .fvar fvarId
|
||||
|
||||
/--
|
||||
@@ -242,18 +252,18 @@ Replace the free variables in `arg` using the given substitution.
|
||||
|
||||
See `normExprImp`
|
||||
-/
|
||||
private partial def normArgImp (s : FVarSubst) (arg : Arg) (translator : Bool) : Arg :=
|
||||
private partial def normArgImp (s : FVarSubst pu) (arg : Arg pu) (translator : Bool) : Arg pu :=
|
||||
match arg with
|
||||
| .erased => arg
|
||||
| .fvar fvarId =>
|
||||
match s[fvarId]? with
|
||||
| some (arg'@(.fvar _)) =>
|
||||
if translator then arg' else normArgImp s arg' translator
|
||||
| some (arg'@.erased) | some (arg'@(.type _)) => arg'
|
||||
| some (arg'@.erased) | some (arg'@(.type _ _)) => arg'
|
||||
| none => arg
|
||||
| .type e => arg.updateType! (normExprImp s e translator)
|
||||
| .type e _ => arg.updateType! (normExprImp s e translator)
|
||||
|
||||
private def normArgsImp (s : FVarSubst) (args : Array Arg) (translator : Bool) : Array Arg :=
|
||||
private def normArgsImp (s : FVarSubst pu) (args : Array (Arg pu)) (translator : Bool) : Array (Arg pu) :=
|
||||
args.mapMono (normArgImp s · translator)
|
||||
|
||||
/--
|
||||
@@ -261,13 +271,13 @@ Replace the free variables in `e` using the given substitution.
|
||||
|
||||
See `normExprImp`
|
||||
-/
|
||||
private partial def normLetValueImp (s : FVarSubst) (e : LetValue) (translator : Bool) : LetValue :=
|
||||
private partial def normLetValueImp (s : FVarSubst pu) (e : LetValue pu) (translator : Bool) : LetValue pu :=
|
||||
match e with
|
||||
| .erased | .lit .. => e
|
||||
| .proj _ _ fvarId => match normFVarImp s fvarId translator with
|
||||
| .proj _ _ fvarId _ => match normFVarImp s fvarId translator with
|
||||
| .fvar fvarId' => e.updateProj! fvarId'
|
||||
| .erased => .erased
|
||||
| .const _ _ args => e.updateArgs! (normArgsImp s args translator)
|
||||
| .const _ _ args _ => e.updateArgs! (normArgsImp s args translator)
|
||||
| .fvar fvarId args => match normFVarImp s fvarId translator with
|
||||
| .fvar fvarId' => e.updateFVar! fvarId' (normArgsImp s args translator)
|
||||
| .erased => .erased
|
||||
@@ -275,20 +285,20 @@ private partial def normLetValueImp (s : FVarSubst) (e : LetValue) (translator :
|
||||
/--
|
||||
Interface for monads that have a free substitutions.
|
||||
-/
|
||||
class MonadFVarSubst (m : Type → Type) (translator : outParam Bool) where
|
||||
getSubst : m FVarSubst
|
||||
class MonadFVarSubst (m : Type → Type) (pu : outParam Purity) (translator : outParam Bool) where
|
||||
getSubst : m (FVarSubst pu)
|
||||
|
||||
export MonadFVarSubst (getSubst)
|
||||
|
||||
instance (m n) [MonadLift m n] [MonadFVarSubst m t] : MonadFVarSubst n t where
|
||||
instance (m n) [MonadLift m n] [MonadFVarSubst m pu t] : MonadFVarSubst n pu t where
|
||||
getSubst := liftM (getSubst : m _)
|
||||
|
||||
class MonadFVarSubstState (m : Type → Type) where
|
||||
modifySubst : (FVarSubst → FVarSubst) → m Unit
|
||||
class MonadFVarSubstState (m : Type → Type) (pu : outParam Purity) where
|
||||
modifySubst : (FVarSubst pu → FVarSubst pu) → m Unit
|
||||
|
||||
export MonadFVarSubstState (modifySubst)
|
||||
|
||||
instance (m n) [MonadLift m n] [MonadFVarSubstState m] : MonadFVarSubstState n where
|
||||
instance (m n) [MonadLift m n] [MonadFVarSubstState m pu] : MonadFVarSubstState n pu where
|
||||
modifySubst f := liftM (modifySubst f : m _)
|
||||
|
||||
/--
|
||||
@@ -296,35 +306,35 @@ Add the substitution `fvarId ↦ e`, `e` must be a valid LCNF `Arg`.
|
||||
|
||||
See `Check.lean` for the free variable substitution checker.
|
||||
-/
|
||||
@[inline] def addSubst [MonadFVarSubstState m] (fvarId : FVarId) (arg : Arg) : m Unit :=
|
||||
@[inline] def addSubst [MonadFVarSubstState m pu] (fvarId : FVarId) (arg : Arg pu) : m Unit :=
|
||||
modifySubst fun s => s.insert fvarId arg
|
||||
|
||||
/--
|
||||
Add the entry `fvarId ↦ fvarId'` to the free variable substitution.
|
||||
-/
|
||||
@[inline] def addFVarSubst [MonadFVarSubstState m] (fvarId : FVarId) (fvarId' : FVarId) : m Unit :=
|
||||
@[inline] def addFVarSubst [MonadFVarSubstState m ph] (fvarId : FVarId) (fvarId' : FVarId) : m Unit :=
|
||||
modifySubst fun s => s.insert fvarId (.fvar fvarId')
|
||||
|
||||
@[inline, inherit_doc normFVarImp] def normFVar [MonadFVarSubst m t] [Monad m] (fvarId : FVarId) : m NormFVarResult :=
|
||||
@[inline, inherit_doc normFVarImp] def normFVar [MonadFVarSubst m pu t] [Monad m] (fvarId : FVarId) : m NormFVarResult :=
|
||||
return normFVarImp (← getSubst) fvarId t
|
||||
|
||||
@[inline, inherit_doc normExprImp] def normExpr [MonadFVarSubst m t] [Monad m] (e : Expr) : m Expr :=
|
||||
@[inline, inherit_doc normExprImp] def normExpr [MonadFVarSubst m pu t] [Monad m] (e : Expr) : m Expr :=
|
||||
return normExprImp (← getSubst) e t
|
||||
|
||||
@[inline, inherit_doc normArgImp] def normArg [MonadFVarSubst m t] [Monad m] (arg : Arg) : m Arg :=
|
||||
@[inline, inherit_doc normArgImp] def normArg [MonadFVarSubst m pu t] [Monad m] (arg : Arg pu) : m (Arg pu) :=
|
||||
return normArgImp (← getSubst) arg t
|
||||
|
||||
@[inline, inherit_doc normLetValueImp] def normLetValue [MonadFVarSubst m t] [Monad m] (e : LetValue) : m LetValue :=
|
||||
@[inline, inherit_doc normLetValueImp] def normLetValue [MonadFVarSubst m pu t] [Monad m] (e : LetValue pu) : m (LetValue pu) :=
|
||||
return normLetValueImp (← getSubst) e t
|
||||
|
||||
@[inherit_doc normExprImp, inline]
|
||||
def normExprCore (s : FVarSubst) (e : Expr) (translator : Bool) : Expr :=
|
||||
def normExprCore (s : FVarSubst pu) (e : Expr) (translator : Bool) : Expr :=
|
||||
normExprImp s e translator
|
||||
|
||||
/--
|
||||
Normalize the given arguments using the current substitution.
|
||||
-/
|
||||
def normArgs [MonadFVarSubst m t] [Monad m] (args : Array Arg) : m (Array Arg) :=
|
||||
def normArgs [MonadFVarSubst m pu t] [Monad m] (args : Array (Arg pu)) : m (Array (Arg pu)) :=
|
||||
return normArgsImp (← getSubst) args t
|
||||
|
||||
def mkFreshBinderName (binderName := `_x): CompilerM Name := do
|
||||
@@ -342,35 +352,35 @@ def ensureNotAnonymous (binderName : Name) (baseName : Name) : CompilerM Name :=
|
||||
Helper functions for creating LCNF local declarations.
|
||||
-/
|
||||
|
||||
def mkParam (binderName : Name) (type : Expr) (borrow : Bool) : CompilerM Param := do
|
||||
def mkParam (binderName : Name) (type : Expr) (borrow : Bool) : CompilerM (Param pu) := do
|
||||
let fvarId ← mkFreshFVarId
|
||||
let binderName ← ensureNotAnonymous binderName `_y
|
||||
let param := { fvarId, binderName, type, borrow }
|
||||
modifyLCtx fun lctx => lctx.addParam param
|
||||
return param
|
||||
|
||||
def mkLetDecl (binderName : Name) (type : Expr) (value : LetValue) : CompilerM LetDecl := do
|
||||
def mkLetDecl (binderName : Name) (type : Expr) (value : LetValue pu) : CompilerM (LetDecl pu) := do
|
||||
let fvarId ← mkFreshFVarId
|
||||
let binderName ← ensureNotAnonymous binderName `_x
|
||||
let decl := { fvarId, binderName, type, value }
|
||||
modifyLCtx fun lctx => lctx.addLetDecl decl
|
||||
return decl
|
||||
|
||||
def mkFunDecl (binderName : Name) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do
|
||||
def mkFunDecl (binderName : Name) (type : Expr) (params : Array (Param pu)) (value : Code pu) : CompilerM (FunDecl pu) := do
|
||||
let fvarId ← mkFreshFVarId
|
||||
let binderName ← ensureNotAnonymous binderName `_f
|
||||
let funDecl := ⟨fvarId, binderName, params, type, value⟩
|
||||
modifyLCtx fun lctx => lctx.addFunDecl funDecl
|
||||
return funDecl
|
||||
|
||||
def mkLetDeclErased : CompilerM LetDecl := do
|
||||
def mkLetDeclErased : CompilerM (LetDecl pu) := do
|
||||
mkLetDecl (← mkFreshBinderName `_x) erasedExpr .erased
|
||||
|
||||
def mkReturnErased : CompilerM Code := do
|
||||
def mkReturnErased : CompilerM (Code pu) := do
|
||||
let auxDecl ← mkLetDeclErased
|
||||
return .let auxDecl (.return auxDecl.fvarId)
|
||||
|
||||
private unsafe def updateParamImp (p : Param) (type : Expr) : CompilerM Param := do
|
||||
private unsafe def updateParamImp (p : Param pu) (type : Expr) : CompilerM (Param pu) := do
|
||||
if ptrEq type p.type then
|
||||
return p
|
||||
else
|
||||
@@ -378,9 +388,9 @@ private unsafe def updateParamImp (p : Param) (type : Expr) : CompilerM Param :=
|
||||
modifyLCtx fun lctx => lctx.addParam p
|
||||
return p
|
||||
|
||||
@[implemented_by updateParamImp] opaque Param.update (p : Param) (type : Expr) : CompilerM Param
|
||||
@[implemented_by updateParamImp] opaque Param.update (p : Param pu) (type : Expr) : CompilerM (Param pu)
|
||||
|
||||
private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : LetValue) : CompilerM LetDecl := do
|
||||
private unsafe def updateLetDeclImp (decl : LetDecl pu) (type : Expr) (value : LetValue pu) : CompilerM (LetDecl pu) := do
|
||||
if ptrEq type decl.type && ptrEq value decl.value then
|
||||
return decl
|
||||
else
|
||||
@@ -388,12 +398,12 @@ private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : LetV
|
||||
modifyLCtx fun lctx => lctx.addLetDecl decl
|
||||
return decl
|
||||
|
||||
@[implemented_by updateLetDeclImp] opaque LetDecl.update (decl : LetDecl) (type : Expr) (value : LetValue) : CompilerM LetDecl
|
||||
@[implemented_by updateLetDeclImp] opaque LetDecl.update (decl : LetDecl pu) (type : Expr) (value : LetValue pu) : CompilerM (LetDecl pu)
|
||||
|
||||
def LetDecl.updateValue (decl : LetDecl) (value : LetValue) : CompilerM LetDecl :=
|
||||
def LetDecl.updateValue (decl : LetDecl pu) (value : LetValue pu) : CompilerM (LetDecl pu) :=
|
||||
decl.update decl.type value
|
||||
|
||||
private unsafe def updateFunDeclImp (decl : FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do
|
||||
private unsafe def updateFunDeclImp (decl : FunDecl pu) (type : Expr) (params : Array (Param pu)) (value : Code pu) : CompilerM (FunDecl pu) := do
|
||||
if ptrEq type decl.type && ptrEq params decl.params && ptrEq value decl.value then
|
||||
return decl
|
||||
else
|
||||
@@ -401,48 +411,48 @@ private unsafe def updateFunDeclImp (decl : FunDecl) (type : Expr) (params : Arr
|
||||
modifyLCtx fun lctx => lctx.addFunDecl decl
|
||||
return decl
|
||||
|
||||
@[implemented_by updateFunDeclImp] opaque FunDecl.update (decl : FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl
|
||||
@[implemented_by updateFunDeclImp] opaque FunDecl.update (decl : FunDecl pu) (type : Expr) (params : Array (Param pu)) (value : Code pu) : CompilerM (FunDecl pu)
|
||||
|
||||
abbrev FunDecl.update' (decl : FunDecl) (type : Expr) (value : Code) : CompilerM FunDecl :=
|
||||
abbrev FunDecl.update' (decl : FunDecl pu) (type : Expr) (value : Code pu) : CompilerM (FunDecl pu) :=
|
||||
decl.update type decl.params value
|
||||
|
||||
abbrev FunDecl.updateValue (decl : FunDecl) (value : Code) : CompilerM FunDecl :=
|
||||
abbrev FunDecl.updateValue (decl : FunDecl pu) (value : Code pu) : CompilerM (FunDecl pu) :=
|
||||
decl.update decl.type decl.params value
|
||||
|
||||
@[inline] def normParam [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (p : Param) : m Param := do
|
||||
@[inline] def normParam [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (p : Param pu) : m (Param pu) := do
|
||||
p.update (← normExpr p.type)
|
||||
|
||||
def normParams [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (ps : Array Param) : m (Array Param) :=
|
||||
def normParams [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (ps : Array (Param pu)) : m (Array (Param pu)) :=
|
||||
ps.mapMonoM normParam
|
||||
|
||||
def normLetDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (decl : LetDecl) : m LetDecl := do
|
||||
def normLetDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (decl : LetDecl pu) : m (LetDecl pu) := do
|
||||
decl.update (← normExpr decl.type) (← normLetValue decl.value)
|
||||
|
||||
abbrev NormalizerM (_translator : Bool) := ReaderT FVarSubst CompilerM
|
||||
abbrev NormalizerM (pu : Purity) (_translator : Bool) := ReaderT (FVarSubst pu) CompilerM
|
||||
|
||||
instance : MonadFVarSubst (NormalizerM t) t where
|
||||
instance : MonadFVarSubst (NormalizerM pu t) pu t where
|
||||
getSubst := read
|
||||
|
||||
/--
|
||||
If `result` is `.fvar fvarId`, then return `x fvarId`. Otherwise, it is `.erased`,
|
||||
and method returns `let _x.i := .erased; return _x.i`.
|
||||
-/
|
||||
@[inline] def withNormFVarResult [MonadLiftT CompilerM m] [Monad m] (result : NormFVarResult) (x : FVarId → m Code) : m Code := do
|
||||
@[inline] def withNormFVarResult [MonadLiftT CompilerM m] [Monad m] (result : NormFVarResult) (x : FVarId → m (Code pu)) : m (Code pu) := do
|
||||
match result with
|
||||
| .fvar fvarId => x fvarId
|
||||
| .erased => mkReturnErased
|
||||
|
||||
mutual
|
||||
partial def normFunDeclImp (decl : FunDecl) : NormalizerM t FunDecl := do
|
||||
partial def normFunDeclImp (decl : FunDecl pu) : NormalizerM pu t (FunDecl pu) := do
|
||||
let type ← normExpr decl.type
|
||||
let params ← normParams decl.params
|
||||
let value ← normCodeImp decl.value
|
||||
decl.update type params value
|
||||
|
||||
partial def normCodeImp (code : Code) : NormalizerM t Code := do
|
||||
partial def normCodeImp (code : Code pu) : NormalizerM pu t (Code pu) := do
|
||||
match code with
|
||||
| .let decl k => return code.updateLet! (← normLetDecl decl) (← normCodeImp k)
|
||||
| .fun decl k | .jp decl k => return code.updateFun! (← normFunDeclImp decl) (← normCodeImp k)
|
||||
| .fun decl k _ | .jp decl k => return code.updateFun! (← normFunDeclImp decl) (← normCodeImp k)
|
||||
| .return fvarId => withNormFVarResult (← normFVar fvarId) fun fvarId => return code.updateReturn! fvarId
|
||||
| .jmp fvarId args => withNormFVarResult (← normFVar fvarId) fun fvarId => return code.updateJmp! fvarId (← normArgs args)
|
||||
| .unreach type => return code.updateUnreach! (← normExpr type)
|
||||
@@ -451,28 +461,28 @@ mutual
|
||||
withNormFVarResult (← normFVar c.discr) fun discr => do
|
||||
let alts ← c.alts.mapMonoM fun alt =>
|
||||
match alt with
|
||||
| .alt _ params k => return alt.updateAlt! (← normParams params) (← normCodeImp k)
|
||||
| .alt _ params k _ => return alt.updateAlt! (← normParams params) (← normCodeImp k)
|
||||
| .default k => return alt.updateCode (← normCodeImp k)
|
||||
return code.updateCases! resultType discr alts
|
||||
end
|
||||
|
||||
@[inline] def normFunDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (decl : FunDecl) : m FunDecl := do
|
||||
@[inline] def normFunDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (decl : FunDecl pu) : m (FunDecl pu) := do
|
||||
normFunDeclImp (t := t) decl (← getSubst)
|
||||
|
||||
/-- Similar to `internalize`, but does not refresh `FVarId`s. -/
|
||||
@[inline] def normCode [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (code : Code) : m Code := do
|
||||
@[inline] def normCode [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (code : Code pu) : m (Code pu) := do
|
||||
normCodeImp (t := t) code (← getSubst)
|
||||
|
||||
def replaceExprFVars (e : Expr) (s : FVarSubst) (translator : Bool) : CompilerM Expr :=
|
||||
(normExpr e : NormalizerM translator Expr).run s
|
||||
def replaceExprFVars (e : Expr) (s : FVarSubst pu) (translator : Bool) : CompilerM Expr :=
|
||||
(normExpr e : NormalizerM pu translator Expr).run s
|
||||
|
||||
def replaceFVars (code : Code) (s : FVarSubst) (translator : Bool) : CompilerM Code :=
|
||||
(normCode code : NormalizerM translator Code).run s
|
||||
def replaceFVars (code : Code pu) (s : FVarSubst pu) (translator : Bool) : CompilerM (Code pu) :=
|
||||
(normCode code : NormalizerM pu translator (Code pu)).run s
|
||||
|
||||
def mkFreshJpName : CompilerM Name := do
|
||||
mkFreshBinderName `_jp
|
||||
|
||||
def mkAuxParam (type : Expr) (borrow := false) : CompilerM Param := do
|
||||
def mkAuxParam (type : Expr) (borrow := false) : CompilerM (Param pu) := do
|
||||
mkParam (← mkFreshBinderName `_y) type borrow
|
||||
|
||||
def getConfig : CompilerM ConfigOptions :=
|
||||
|
||||
@@ -12,25 +12,25 @@ public section
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
|
||||
instance : Hashable Param where
|
||||
instance : Hashable (Param pu) where
|
||||
hash p := mixHash (hash p.fvarId) (hash p.type)
|
||||
|
||||
def hashParams (ps : Array Param) : UInt64 :=
|
||||
def hashParams (ps : Array (Param pu)) : UInt64 :=
|
||||
hash ps
|
||||
|
||||
mutual
|
||||
partial def hashAlt (alt : Alt) : UInt64 :=
|
||||
partial def hashAlt (alt : Alt pu) : UInt64 :=
|
||||
match alt with
|
||||
| .alt ctorName ps k => mixHash (mixHash (hash ctorName) (hash ps)) (hashCode k)
|
||||
| .alt ctorName ps k _ => mixHash (mixHash (hash ctorName) (hash ps)) (hashCode k)
|
||||
| .default k => hashCode k
|
||||
|
||||
partial def hashAlts (alts : Array Alt) : UInt64 :=
|
||||
partial def hashAlts (alts : Array (Alt pu)) : UInt64 :=
|
||||
alts.foldl (fun r a => mixHash r (hashAlt a)) 7
|
||||
|
||||
partial def hashCode (code : Code) : UInt64 :=
|
||||
partial def hashCode (code : Code pu) : UInt64 :=
|
||||
match code with
|
||||
| .let decl k => mixHash (mixHash (hash decl.fvarId) (hash decl.type)) (mixHash (hash decl.value) (hashCode k))
|
||||
| .fun decl k | .jp decl k =>
|
||||
| .fun decl k _ | .jp decl k =>
|
||||
mixHash (mixHash (mixHash (hash decl.fvarId) (hash decl.type)) (mixHash (hashCode decl.value) (hashCode k))) (hash decl.params)
|
||||
| .return fvarId => hash fvarId
|
||||
| .unreach type => hash type
|
||||
@@ -39,7 +39,7 @@ partial def hashCode (code : Code) : UInt64 :=
|
||||
|
||||
end
|
||||
|
||||
instance : Hashable Code where
|
||||
instance : Hashable (Code pu) where
|
||||
hash c := hashCode c
|
||||
|
||||
deriving instance Hashable for DeclValue
|
||||
|
||||
@@ -21,46 +21,46 @@ private def typeDepOn (e : Expr) : M Bool := do
|
||||
let s ← read
|
||||
return e.hasAnyFVar fun fvarId => s.contains fvarId
|
||||
|
||||
private def argDepOn (a : Arg) : M Bool := do
|
||||
private def argDepOn (a : Arg pu) : M Bool := do
|
||||
match a with
|
||||
| .erased => return false
|
||||
| .fvar fvarId => fvarDepOn fvarId
|
||||
| .type e => typeDepOn e
|
||||
| .type e _ => typeDepOn e
|
||||
|
||||
private def letValueDepOn (e : LetValue) : M Bool :=
|
||||
private def letValueDepOn (e : LetValue pu) : M Bool :=
|
||||
match e with
|
||||
| .erased | .lit .. => return false
|
||||
| .proj _ _ fvarId => fvarDepOn fvarId
|
||||
| .proj _ _ fvarId _ => fvarDepOn fvarId
|
||||
| .fvar fvarId args => fvarDepOn fvarId <||> args.anyM argDepOn
|
||||
| .const _ _ args => args.anyM argDepOn
|
||||
| .const _ _ args _ => args.anyM argDepOn
|
||||
|
||||
private def LetDecl.depOn (decl : LetDecl) : M Bool :=
|
||||
private def LetDecl.depOn (decl : LetDecl pu) : M Bool :=
|
||||
typeDepOn decl.type <||> letValueDepOn decl.value
|
||||
|
||||
private partial def depOn (c : Code) : M Bool :=
|
||||
private partial def depOn (c : Code pu) : M Bool :=
|
||||
match c with
|
||||
| .let decl k => decl.depOn <||> depOn k
|
||||
| .jp decl k | .fun decl k => typeDepOn decl.type <||> depOn decl.value <||> depOn k
|
||||
| .jp decl k | .fun decl k _ => typeDepOn decl.type <||> depOn decl.value <||> depOn k
|
||||
| .cases c => typeDepOn c.resultType <||> fvarDepOn c.discr <||> c.alts.anyM fun alt => depOn alt.getCode
|
||||
| .jmp fvarId args => fvarDepOn fvarId <||> args.anyM argDepOn
|
||||
| .return fvarId => fvarDepOn fvarId
|
||||
| .unreach _ => return false
|
||||
|
||||
@[inline] def LetDecl.dependsOn (decl : LetDecl) (s : FVarIdSet) : Bool :=
|
||||
@[inline] def LetDecl.dependsOn (decl : LetDecl pu) (s : FVarIdSet) : Bool :=
|
||||
decl.depOn s
|
||||
|
||||
@[inline] def FunDecl.dependsOn (decl : FunDecl) (s : FVarIdSet) : Bool :=
|
||||
@[inline] def FunDecl.dependsOn (decl : FunDecl pu) (s : FVarIdSet) : Bool :=
|
||||
typeDepOn decl.type s || depOn decl.value s
|
||||
|
||||
def CodeDecl.dependsOn (decl : CodeDecl) (s : FVarIdSet) : Bool :=
|
||||
def CodeDecl.dependsOn (decl : CodeDecl pu) (s : FVarIdSet) : Bool :=
|
||||
match decl with
|
||||
| .let decl => decl.dependsOn s
|
||||
| .jp decl | .fun decl => decl.dependsOn s
|
||||
| .jp decl | .fun decl _ => decl.dependsOn s
|
||||
|
||||
/--
|
||||
Return `true` is `c` depends on a free variable in `s`.
|
||||
-/
|
||||
def Code.dependsOn (c : Code) (s : FVarIdSet) : Bool :=
|
||||
def Code.dependsOn (c : Code pu) (s : FVarIdSet) : Bool :=
|
||||
depOn c s
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
||||
@@ -19,16 +19,16 @@ Collect set of (let) free variables in a LCNF value.
|
||||
This code exploits the LCNF property that local declarations do not occur in types.
|
||||
-/
|
||||
|
||||
def collectLocalDeclsArg (s : UsedLocalDecls) (arg : Arg) : UsedLocalDecls :=
|
||||
def collectLocalDeclsArg (s : UsedLocalDecls) (arg : Arg .pure) : UsedLocalDecls :=
|
||||
match arg with
|
||||
| .fvar fvarId => s.insert fvarId
|
||||
-- Locally declared variables do not occur in types.
|
||||
| .type _ | .erased => s
|
||||
|
||||
def collectLocalDeclsArgs (s : UsedLocalDecls) (args : Array Arg) : UsedLocalDecls :=
|
||||
def collectLocalDeclsArgs (s : UsedLocalDecls) (args : Array (Arg .pure)) : UsedLocalDecls :=
|
||||
args.foldl (init := s) collectLocalDeclsArg
|
||||
|
||||
def collectLocalDeclsLetValue (s : UsedLocalDecls) (e : LetValue) : UsedLocalDecls :=
|
||||
def collectLocalDeclsLetValue (s : UsedLocalDecls) (e : LetValue .pure) : UsedLocalDecls :=
|
||||
match e with
|
||||
| .erased | .lit .. => s
|
||||
| .proj _ _ fvarId => s.insert fvarId
|
||||
@@ -39,21 +39,22 @@ namespace ElimDead
|
||||
|
||||
abbrev M := StateRefT UsedLocalDecls CompilerM
|
||||
|
||||
private abbrev collectArgM (arg : Arg) : M Unit :=
|
||||
private abbrev collectArgM (arg : Arg .pure) : M Unit :=
|
||||
modify (collectLocalDeclsArg · arg)
|
||||
|
||||
private abbrev collectLetValueM (e : LetValue) : M Unit :=
|
||||
private abbrev collectLetValueM (e : LetValue .pure) : M Unit :=
|
||||
modify (collectLocalDeclsLetValue · e)
|
||||
|
||||
private abbrev collectFVarM (fvarId : FVarId) : M Unit :=
|
||||
modify (·.insert fvarId)
|
||||
|
||||
mutual
|
||||
partial def visitFunDecl (funDecl : FunDecl) : M FunDecl := do
|
||||
|
||||
partial def visitFunDecl (funDecl : FunDecl .pure) : M (FunDecl .pure) := do
|
||||
let value ← elimDead funDecl.value
|
||||
funDecl.updateValue value
|
||||
|
||||
partial def elimDead (code : Code) : M Code := do
|
||||
partial def elimDead (code : Code .pure) : M (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
let k ← elimDead k
|
||||
@@ -84,10 +85,11 @@ end
|
||||
|
||||
end ElimDead
|
||||
|
||||
def Code.elimDead (code : Code) : CompilerM Code :=
|
||||
-- TODO: Generalize this to arbitrary phases, keep in mind that in impure elim dead is not as easy though
|
||||
def Code.elimDead (code : Code .pure) : CompilerM (Code .pure) :=
|
||||
ElimDead.elimDead code |>.run' {}
|
||||
|
||||
def Decl.elimDead (decl : Decl) : CompilerM Decl := do
|
||||
def Decl.elimDead (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
return { decl with value := (← decl.value.mapCodeM Code.elimDead) }
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
||||
@@ -239,14 +239,14 @@ Attempt to turn a `Value` that is representing a literal into a set of
|
||||
auxiliary declarations + the final `FVarId` of the declaration that
|
||||
contains the actual literal. If it is not a literal return none.
|
||||
-/
|
||||
partial def getLiteral (v : Value) : CompilerM (Option ((Array CodeDecl) × FVarId)) := do
|
||||
partial def getLiteral (v : Value) : CompilerM (Option ((Array (CodeDecl .pure)) × FVarId)) := do
|
||||
if isLiteral v then
|
||||
let literal ← go v
|
||||
return some literal
|
||||
else
|
||||
return none
|
||||
where
|
||||
go : Value → CompilerM ((Array CodeDecl) × FVarId)
|
||||
go : Value → CompilerM ((Array (CodeDecl .pure)) × FVarId)
|
||||
| .ctor ``Nat.zero #[] .. => do
|
||||
let decl ← mkAuxLetDecl <| .lit <| .nat <| 0
|
||||
return (#[.let decl], decl.fvarId)
|
||||
@@ -260,7 +260,7 @@ where
|
||||
let flatten acc := fun (decls, var) => (acc.fst ++ decls, acc.snd.push <| .fvar var)
|
||||
let (decls, args) :=
|
||||
fields.foldl (init := (#[], Array.replicate ctorInfo.numParams .erased)) flatten
|
||||
let letVal : LetValue := .const ctorName [] args
|
||||
let letVal : LetValue .pure := .const ctorName [] args
|
||||
let letDecl ← mkAuxLetDecl letVal
|
||||
return (decls.push <| .let letDecl, letDecl.fvarId)
|
||||
| _ => unreachable!
|
||||
@@ -328,7 +328,7 @@ structure InterpContext where
|
||||
a single declaration or a mutual block of declarations where their
|
||||
analysis might influence each other as we approach the fixpoint.
|
||||
-/
|
||||
decls : Array Decl
|
||||
decls : Array (Decl .pure)
|
||||
/--
|
||||
The index of the function we are currently operating on in `decls.`
|
||||
-/
|
||||
@@ -386,7 +386,7 @@ def findVarValue (var : FVarId) : InterpM Value := do
|
||||
/--
|
||||
Find the value of `arg` using the logic of `findVarValue`.
|
||||
-/
|
||||
def findArgValue (arg : Arg) : InterpM Value := do
|
||||
def findArgValue (arg : Arg .pure) : InterpM Value := do
|
||||
match arg with
|
||||
| .fvar fvarId => findVarValue fvarId
|
||||
| _ => return .top
|
||||
@@ -421,7 +421,8 @@ Furthermore if we see that `params.size != args.size` we know that this is
|
||||
a partial application and set the values of the remaining parameters to
|
||||
`top` since it is impossible to track what will happen with them from here on.
|
||||
-/
|
||||
def updateFunDeclParamsAssignment (params : Array Param) (args : Array Arg) : InterpM Bool := do
|
||||
def updateFunDeclParamsAssignment (params : Array (Param .pure)) (args : Array (Arg .pure)) :
|
||||
InterpM Bool := do
|
||||
let mut ret := false
|
||||
let env ← getEnv
|
||||
for param in params, arg in args do
|
||||
@@ -443,7 +444,7 @@ def updateFunDeclParamsAssignment (params : Array Param) (args : Array Arg) : In
|
||||
updateVarAssignment param.fvarId .top
|
||||
return ret
|
||||
|
||||
def updateFunDeclParamsTop (params : Array Param) : InterpM Bool := do
|
||||
def updateFunDeclParamsTop (params : Array (Param .pure)) : InterpM Bool := do
|
||||
let mut ret := false
|
||||
for param in params do
|
||||
let paramVal ← findVarValue param.fvarId
|
||||
@@ -453,7 +454,7 @@ def updateFunDeclParamsTop (params : Array Param) : InterpM Bool := do
|
||||
ret := true
|
||||
return ret
|
||||
|
||||
private partial def resetNestedFunDeclParams : Code → InterpM Unit
|
||||
private partial def resetNestedFunDeclParams : Code .pure → InterpM Unit
|
||||
| .let _ k => resetNestedFunDeclParams k
|
||||
| .jp decl k | .fun decl k => do
|
||||
decl.params.forM (resetVarAssignment ·.fvarId)
|
||||
@@ -467,7 +468,7 @@ private partial def resetNestedFunDeclParams : Code → InterpM Unit
|
||||
/--
|
||||
The actual abstract interpreter on a block of `Code`.
|
||||
-/
|
||||
partial def interpCode : Code → InterpM Unit
|
||||
partial def interpCode : Code .pure → InterpM Unit
|
||||
| .let decl k => do
|
||||
let val ← interpLetValue decl.value
|
||||
updateVarAssignment decl.fvarId val
|
||||
@@ -503,7 +504,7 @@ where
|
||||
/--
|
||||
The abstract interpreter on a `LetValue`.
|
||||
-/
|
||||
interpLetValue (letVal : LetValue) : InterpM Value := do
|
||||
interpLetValue (letVal : LetValue .pure) : InterpM Value := do
|
||||
match letVal with
|
||||
| .lit val => return .ofLCNFLit val
|
||||
| .proj _ idx struct =>
|
||||
@@ -513,7 +514,7 @@ where
|
||||
let env ← getEnv
|
||||
args.forM handleFunArg
|
||||
match (← getDecl? declName) with
|
||||
| some decl =>
|
||||
| some ⟨_, decl⟩ =>
|
||||
if decl.getArity == args.size then
|
||||
match getFunctionSummary? env declName with
|
||||
| some v => return v
|
||||
@@ -538,7 +539,7 @@ where
|
||||
return .top
|
||||
| .erased => return .top
|
||||
|
||||
handleFunArg (arg : Arg) : InterpM Unit := do
|
||||
handleFunArg (arg : Arg .pure) : InterpM Unit := do
|
||||
if let .fvar fvarId := arg then
|
||||
handleFunVar fvarId
|
||||
|
||||
@@ -557,7 +558,7 @@ where
|
||||
resetNestedFunDeclParams funDecl.value
|
||||
interpCode funDecl.value
|
||||
|
||||
interpFunCall (funDecl : FunDecl) (args : Array Arg) : InterpM Unit := do
|
||||
interpFunCall (funDecl : FunDecl .pure) (args : Array (Arg .pure)) : InterpM Unit := do
|
||||
let updated ← updateFunDeclParamsAssignment funDecl.params args
|
||||
if updated then
|
||||
/- We must reset the value of nested function declaration
|
||||
@@ -608,11 +609,11 @@ Use the information produced by the abstract interpreter to:
|
||||
- Eliminate branches that we know cannot be hit
|
||||
- Eliminate values that we know have to be constants.
|
||||
-/
|
||||
partial def elimDead (assignment : Assignment) (decl : Decl) : CompilerM Decl := do
|
||||
partial def elimDead (assignment : Assignment) (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
trace[Compiler.elimDeadBranches] s!"Eliminating {decl.name} with {repr (← assignment.toArray |>.mapM (fun (name, val) => do return (toString (← getBinderName name), val)))}"
|
||||
return { decl with value := (← decl.value.mapCodeM go) }
|
||||
where
|
||||
go (code : Code) : CompilerM Code := do
|
||||
go (code : Code .pure) : CompilerM (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
return code.updateLet! decl (← go k)
|
||||
@@ -624,16 +625,14 @@ where
|
||||
match alt with
|
||||
| .alt ctor args body =>
|
||||
if discrVal.containsCtor ctor then
|
||||
let filter param := do
|
||||
let constantInfos ← args.filterMapM fun param => do
|
||||
if let some val := assignment[param.fvarId]? then
|
||||
if let some literal ← val.getLiteral then
|
||||
return some (param, literal)
|
||||
return none
|
||||
let constantInfos ← args.filterMapM filter
|
||||
if constantInfos.size != 0 then
|
||||
let folder := fun (body, subst) (param, decls, var) => do
|
||||
let (body, subst) ← constantInfos.foldlM (init := (← go body, {})) fun (body, subst) (param, decls, var) => do
|
||||
return (attachCodeDecls decls body, subst.insert param.fvarId (.fvar var))
|
||||
let (body, subst) ← constantInfos.foldlM (init := (← go body, {})) folder
|
||||
let body ← replaceFVars body subst false
|
||||
return alt.updateCode body
|
||||
else
|
||||
@@ -649,7 +648,7 @@ where
|
||||
end UnreachableBranches
|
||||
|
||||
open UnreachableBranches in
|
||||
def Decl.elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do
|
||||
def Decl.elimDeadBranches (decls : Array (Decl .pure)) : CompilerM (Array (Decl .pure)) := do
|
||||
/-
|
||||
We sort declarations by size here to ensure that when we restart in inferStep it will mostly be
|
||||
small declarations that get re-analyzed.
|
||||
|
||||
@@ -16,11 +16,11 @@ public section
|
||||
namespace Lean.Compiler.LCNF
|
||||
namespace ExtractClosed
|
||||
|
||||
abbrev ExtractM := StateRefT (Array CodeDecl) CompilerM
|
||||
abbrev ExtractM := StateRefT (Array (CodeDecl .pure)) CompilerM
|
||||
|
||||
mutual
|
||||
|
||||
partial def extractLetValue (v : LetValue) : ExtractM Unit := do
|
||||
partial def extractLetValue (v : LetValue .pure) : ExtractM Unit := do
|
||||
match v with
|
||||
| .const _ _ args => args.forM extractArg
|
||||
| .fvar fnVar args =>
|
||||
@@ -29,7 +29,7 @@ partial def extractLetValue (v : LetValue) : ExtractM Unit := do
|
||||
| .proj _ _ baseVar => extractFVar baseVar
|
||||
| .lit _ | .erased => return ()
|
||||
|
||||
partial def extractArg (arg : Arg) : ExtractM Unit := do
|
||||
partial def extractArg (arg : Arg .pure) : ExtractM Unit := do
|
||||
match arg with
|
||||
| .fvar fvarId => extractFVar fvarId
|
||||
| .type _ | .erased => return ()
|
||||
@@ -41,17 +41,17 @@ partial def extractFVar (fvarId : FVarId) : ExtractM Unit := do
|
||||
|
||||
end
|
||||
|
||||
def isIrrelevantArg (arg : Arg) : Bool :=
|
||||
def isIrrelevantArg (arg : Arg .pure) : Bool :=
|
||||
match arg with
|
||||
| .erased | .type _ => true
|
||||
| .fvar _ => false
|
||||
|
||||
structure Context where
|
||||
baseName : Name
|
||||
sccDecls : Array Decl
|
||||
sccDecls : Array (Decl .pure)
|
||||
|
||||
structure State where
|
||||
decls : Array Decl := {}
|
||||
decls : Array (Decl .pure) := {}
|
||||
/--
|
||||
Cache for `shouldExtractFVar` in order to avoid superlinear behavior.
|
||||
-/
|
||||
@@ -61,7 +61,7 @@ abbrev M := ReaderT Context $ StateRefT State CompilerM
|
||||
|
||||
mutual
|
||||
|
||||
partial def shouldExtractLetValue (isRoot : Bool) (v : LetValue) : M Bool := do
|
||||
partial def shouldExtractLetValue (isRoot : Bool) (v : LetValue .pure) : M Bool := do
|
||||
match v with
|
||||
| .lit (.str _) => return true
|
||||
| .lit (.nat v) =>
|
||||
@@ -90,7 +90,7 @@ partial def shouldExtractLetValue (isRoot : Bool) (v : LetValue) : M Bool := do
|
||||
| .fvar fnVar args => return (← shouldExtractFVar fnVar) && (← args.allM shouldExtractArg)
|
||||
| .proj _ _ baseVar => shouldExtractFVar baseVar
|
||||
|
||||
partial def shouldExtractArg (arg : Arg) : M Bool := do
|
||||
partial def shouldExtractArg (arg : Arg .pure) : M Bool := do
|
||||
match arg with
|
||||
| .fvar fvarId => shouldExtractFVar fvarId
|
||||
| .type _ | .erased => return true
|
||||
@@ -113,7 +113,7 @@ end
|
||||
|
||||
mutual
|
||||
|
||||
partial def visitCode (code : Code) : M Code := do
|
||||
partial def visitCode (code : Code .pure) : M (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
if (← shouldExtractLetValue true decl.value) then
|
||||
@@ -151,13 +151,14 @@ partial def visitCode (code : Code) : M Code := do
|
||||
|
||||
end
|
||||
|
||||
def visitDecl (decl : Decl) : M Decl := do
|
||||
def visitDecl (decl : Decl .pure) : M (Decl .pure) := do
|
||||
let value ← decl.value.mapCodeM visitCode
|
||||
return { decl with value }
|
||||
|
||||
end ExtractClosed
|
||||
|
||||
partial def Decl.extractClosed (decl : Decl) (sccDecls : Array Decl) : CompilerM (Array Decl) := do
|
||||
partial def Decl.extractClosed (decl : Decl .pure) (sccDecls : Array (Decl .pure)) :
|
||||
CompilerM (Array (Decl .pure)) := do
|
||||
let ⟨decl, s⟩ ← ExtractClosed.visitDecl decl |>.run { baseName := decl.name, sccDecls } |>.run {}
|
||||
return s.decls.push decl
|
||||
|
||||
|
||||
@@ -48,67 +48,67 @@ instance : TraverseFVar Expr where
|
||||
mapFVarM := Expr.mapFVarM
|
||||
forFVarM := Expr.forFVarM
|
||||
|
||||
def Arg.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (arg : Arg) : m Arg := do
|
||||
def Arg.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (arg : Arg pu) : m (Arg pu) := do
|
||||
match arg with
|
||||
| .erased => return .erased
|
||||
| .type e => return arg.updateType! (← TraverseFVar.mapFVarM f e)
|
||||
| .type e _ => return arg.updateType! (← TraverseFVar.mapFVarM f e)
|
||||
| .fvar fvarId => return arg.updateFVar! (← f fvarId)
|
||||
|
||||
def Arg.forFVarM [Monad m] (f : FVarId → m Unit) (arg : Arg) : m Unit := do
|
||||
def Arg.forFVarM [Monad m] (f : FVarId → m Unit) (arg : Arg pu) : m Unit := do
|
||||
match arg with
|
||||
| .erased => return ()
|
||||
| .type e => TraverseFVar.forFVarM f e
|
||||
| .type e _ => TraverseFVar.forFVarM f e
|
||||
| .fvar fvarId => f fvarId
|
||||
|
||||
instance : TraverseFVar Arg where
|
||||
instance : TraverseFVar (Arg pu) where
|
||||
mapFVarM := Arg.mapFVarM
|
||||
forFVarM := Arg.forFVarM
|
||||
|
||||
def LetValue.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (e : LetValue) : m LetValue := do
|
||||
def LetValue.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (e : LetValue pu) : m (LetValue pu) := do
|
||||
match e with
|
||||
| .lit .. | .erased => return e
|
||||
| .proj _ _ fvarId => return e.updateProj! (← f fvarId)
|
||||
| .const _ _ args => return e.updateArgs! (← args.mapM (TraverseFVar.mapFVarM f))
|
||||
| .proj _ _ fvarId _ => return e.updateProj! (← f fvarId)
|
||||
| .const _ _ args _ => return e.updateArgs! (← args.mapM (TraverseFVar.mapFVarM f))
|
||||
| .fvar fvarId args => return e.updateFVar! (← f fvarId) (← args.mapM (TraverseFVar.mapFVarM f))
|
||||
|
||||
def LetValue.forFVarM [Monad m] (f : FVarId → m Unit) (e : LetValue) : m Unit := do
|
||||
def LetValue.forFVarM [Monad m] (f : FVarId → m Unit) (e : LetValue pu) : m Unit := do
|
||||
match e with
|
||||
| .lit .. | .erased => return ()
|
||||
| .proj _ _ fvarId => f fvarId
|
||||
| .const _ _ args => args.forM (TraverseFVar.forFVarM f)
|
||||
| .proj _ _ fvarId _ => f fvarId
|
||||
| .const _ _ args _ => args.forM (TraverseFVar.forFVarM f)
|
||||
| .fvar fvarId args => f fvarId; args.forM (TraverseFVar.forFVarM f)
|
||||
|
||||
instance : TraverseFVar LetValue where
|
||||
instance : TraverseFVar (LetValue pu) where
|
||||
mapFVarM := LetValue.mapFVarM
|
||||
forFVarM := LetValue.forFVarM
|
||||
|
||||
partial def LetDecl.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (decl : LetDecl) : m LetDecl := do
|
||||
partial def LetDecl.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (decl : LetDecl pu) : m (LetDecl pu) := do
|
||||
decl.update (← Expr.mapFVarM f decl.type) (← LetValue.mapFVarM f decl.value)
|
||||
|
||||
partial def LetDecl.forFVarM [Monad m] (f : FVarId → m Unit) (decl : LetDecl) : m Unit := do
|
||||
partial def LetDecl.forFVarM [Monad m] (f : FVarId → m Unit) (decl : LetDecl pu) : m Unit := do
|
||||
Expr.forFVarM f decl.type
|
||||
LetValue.forFVarM f decl.value
|
||||
|
||||
instance : TraverseFVar LetDecl where
|
||||
instance : TraverseFVar (LetDecl pu) where
|
||||
mapFVarM := LetDecl.mapFVarM
|
||||
forFVarM := LetDecl.forFVarM
|
||||
|
||||
partial def Param.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (param : Param) : m Param := do
|
||||
partial def Param.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (param : Param pu) : m (Param pu) := do
|
||||
param.update (← Expr.mapFVarM f param.type)
|
||||
|
||||
partial def Param.forFVarM [Monad m] (f : FVarId → m Unit) (param : Param) : m Unit := do
|
||||
partial def Param.forFVarM [Monad m] (f : FVarId → m Unit) (param : Param pu) : m Unit := do
|
||||
Expr.forFVarM f param.type
|
||||
|
||||
instance : TraverseFVar Param where
|
||||
instance : TraverseFVar (Param pu) where
|
||||
mapFVarM := Param.mapFVarM
|
||||
forFVarM := Param.forFVarM
|
||||
|
||||
partial def Code.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (c : Code) : m Code := do
|
||||
partial def Code.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (c : Code pu) : m (Code pu) := do
|
||||
match c with
|
||||
| .let decl k =>
|
||||
let decl ← LetDecl.mapFVarM f decl
|
||||
return Code.updateLet! c decl (← mapFVarM f k)
|
||||
| .fun decl k =>
|
||||
| .fun decl k _ =>
|
||||
let params ← decl.params.mapM (Param.mapFVarM f)
|
||||
let decl ← decl.update (← Expr.mapFVarM f decl.type) params (← mapFVarM f decl.value)
|
||||
return Code.updateFun! c decl (← mapFVarM f k)
|
||||
@@ -125,12 +125,12 @@ partial def Code.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m F
|
||||
| .unreach typ =>
|
||||
return Code.updateUnreach! c (← Expr.mapFVarM f typ)
|
||||
|
||||
partial def Code.forFVarM [Monad m] (f : FVarId → m Unit) (c : Code) : m Unit := do
|
||||
partial def Code.forFVarM [Monad m] (f : FVarId → m Unit) (c : Code pu) : m Unit := do
|
||||
match c with
|
||||
| .let decl k =>
|
||||
LetDecl.forFVarM f decl
|
||||
forFVarM f k
|
||||
| .fun decl k =>
|
||||
| .fun decl k _ =>
|
||||
decl.params.forM (Param.forFVarM f)
|
||||
Expr.forFVarM f decl.type
|
||||
forFVarM f decl.value
|
||||
@@ -151,45 +151,45 @@ partial def Code.forFVarM [Monad m] (f : FVarId → m Unit) (c : Code) : m Unit
|
||||
| .unreach typ =>
|
||||
Expr.forFVarM f typ
|
||||
|
||||
instance : TraverseFVar Code where
|
||||
instance : TraverseFVar (Code pu) where
|
||||
mapFVarM := Code.mapFVarM
|
||||
forFVarM := Code.forFVarM
|
||||
|
||||
def FunDecl.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (decl : FunDecl) : m FunDecl := do
|
||||
def FunDecl.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (decl : FunDecl pu) : m (FunDecl pu) := do
|
||||
let params ← decl.params.mapM (Param.mapFVarM f)
|
||||
decl.update (← Expr.mapFVarM f decl.type) params (← Code.mapFVarM f decl.value)
|
||||
|
||||
def FunDecl.forFVarM [Monad m] (f : FVarId → m Unit) (decl : FunDecl) : m Unit := do
|
||||
def FunDecl.forFVarM [Monad m] (f : FVarId → m Unit) (decl : FunDecl pu) : m Unit := do
|
||||
decl.params.forM (Param.forFVarM f)
|
||||
Expr.forFVarM f decl.type
|
||||
Code.forFVarM f decl.value
|
||||
|
||||
instance : TraverseFVar FunDecl where
|
||||
instance : TraverseFVar (FunDecl pu) where
|
||||
mapFVarM := FunDecl.mapFVarM
|
||||
forFVarM := FunDecl.forFVarM
|
||||
|
||||
instance : TraverseFVar CodeDecl where
|
||||
instance : TraverseFVar (CodeDecl pu) where
|
||||
mapFVarM f decl := do
|
||||
match decl with
|
||||
| .fun decl => return .fun (← mapFVarM f decl)
|
||||
| .fun decl _ => return .fun (← mapFVarM f decl)
|
||||
| .jp decl => return .jp (← mapFVarM f decl)
|
||||
| .let decl => return .let (← mapFVarM f decl)
|
||||
forFVarM f decl :=
|
||||
match decl with
|
||||
| .fun decl => forFVarM f decl
|
||||
| .fun decl _ => forFVarM f decl
|
||||
| .jp decl => forFVarM f decl
|
||||
| .let decl => forFVarM f decl
|
||||
|
||||
instance : TraverseFVar Alt where
|
||||
instance : TraverseFVar (Alt pu) where
|
||||
mapFVarM f alt := do
|
||||
match alt with
|
||||
| .alt ctor params c =>
|
||||
| .alt ctor params c _ =>
|
||||
let params ← params.mapM (Param.mapFVarM f)
|
||||
return .alt ctor params (← Code.mapFVarM f c)
|
||||
| .default c => return .default (← Code.mapFVarM f c)
|
||||
forFVarM f alt := do
|
||||
match alt with
|
||||
| .alt _ params c =>
|
||||
| .alt _ params c _ =>
|
||||
params.forM (Param.forFVarM f)
|
||||
Code.forFVarM f c
|
||||
| .default c => Code.forFVarM f c
|
||||
|
||||
@@ -46,12 +46,12 @@ inductive AbsValue where
|
||||
|
||||
structure Context where
|
||||
/-- Declaration in the same mutual block. -/
|
||||
decls : Array Decl
|
||||
decls : Array (Decl .pure)
|
||||
/--
|
||||
Function being analyzed. We check every recursive call to this function.
|
||||
Remark: `main` is in `decls`.
|
||||
-/
|
||||
main : Decl
|
||||
main : Decl .pure
|
||||
/--
|
||||
The assignment maps free variable ids in the current code being analyzed to abstract values.
|
||||
We only track the abstract value assigned to parameters.
|
||||
@@ -84,17 +84,17 @@ def evalFVar (fvarId : FVarId) : FixParamM AbsValue := do
|
||||
let some val := (← read).assignment.get? fvarId | return .top
|
||||
return val
|
||||
|
||||
def evalArg (arg : Arg) : FixParamM AbsValue := do
|
||||
def evalArg (arg : Arg .pure) : FixParamM AbsValue := do
|
||||
match arg with
|
||||
| .erased => return .erased
|
||||
| .type (.fvar fvarId) => evalFVar fvarId
|
||||
| .type _ => return .top
|
||||
| .type (.fvar fvarId) _ => evalFVar fvarId
|
||||
| .type _ _ => return .top
|
||||
| .fvar fvarId => evalFVar fvarId
|
||||
|
||||
def inMutualBlock (declName : Name) : FixParamM Bool :=
|
||||
return (← read).decls.any (·.name == declName)
|
||||
|
||||
def mkAssignment (decl : Decl) (values : Array AbsValue) : FVarIdMap AbsValue := Id.run do
|
||||
def mkAssignment (decl : Decl .pure) (values : Array AbsValue) : FVarIdMap AbsValue := Id.run do
|
||||
let mut assignment := {}
|
||||
for param in decl.params, value in values do
|
||||
assignment := assignment.insert param.fvarId value
|
||||
@@ -102,12 +102,12 @@ def mkAssignment (decl : Decl) (values : Array AbsValue) : FVarIdMap AbsValue :=
|
||||
|
||||
mutual
|
||||
|
||||
partial def evalLetValue (e : LetValue) : FixParamM Unit := do
|
||||
partial def evalLetValue (e : LetValue .pure) : FixParamM Unit := do
|
||||
match e with
|
||||
| .const declName _ args => evalApp declName args
|
||||
| .const declName _ args _ => evalApp declName args
|
||||
| _ => return ()
|
||||
|
||||
partial def isEquivalentFunDecl? (decl : FunDecl) : FixParamM (Option Nat) := do
|
||||
partial def isEquivalentFunDecl? (decl : FunDecl .pure) : FixParamM (Option Nat) := do
|
||||
let .let { fvarId, value := (.fvar funFvarId args), .. } k := decl.value | return none
|
||||
if args.size != decl.params.size then return none
|
||||
let .return retFVarId := k | return none
|
||||
@@ -120,10 +120,10 @@ partial def isEquivalentFunDecl? (decl : FunDecl) : FixParamM (Option Nat) := do
|
||||
if arg != .fvar param.fvarId && arg != .erased then return none
|
||||
return some funIdx
|
||||
|
||||
partial def evalCode (code : Code) : FixParamM Unit := do
|
||||
partial def evalCode (code : Code .pure) : FixParamM Unit := do
|
||||
match code with
|
||||
| .let decl k => evalLetValue decl.value; evalCode k
|
||||
| .fun decl k =>
|
||||
| .fun decl k _ =>
|
||||
if let some paramIdx ← isEquivalentFunDecl? decl then
|
||||
withReader (fun ctx =>
|
||||
{ ctx with assignment := ctx.assignment.insert decl.fvarId (.val paramIdx) })
|
||||
@@ -135,7 +135,7 @@ partial def evalCode (code : Code) : FixParamM Unit := do
|
||||
| .cases c => c.alts.forM fun alt => evalCode alt.getCode
|
||||
| .unreach .. | .jmp .. | .return .. => return ()
|
||||
|
||||
partial def evalApp (declName : Name) (args : Array Arg) : FixParamM Unit := do
|
||||
partial def evalApp (declName : Name) (args : Array (Arg .pure)) : FixParamM Unit := do
|
||||
let main := (← read).main
|
||||
if declName == main.name then
|
||||
-- Recursive call to the function being analyzed
|
||||
@@ -180,6 +180,9 @@ def mkInitialValues (numParams : Nat) : Array AbsValue := Id.run do
|
||||
end FixedParams
|
||||
open FixedParams
|
||||
|
||||
-- TODO: consider making it phase polymorphic, this requires detecting in place mutations of
|
||||
-- variables etc in addition to just graph theory
|
||||
|
||||
/--
|
||||
Given the (potentially mutually) recursive declarations `decls`,
|
||||
return a map from declaration name `decl.name` to a bit-mask `m` where `m[i]` is true
|
||||
@@ -188,7 +191,7 @@ applications.
|
||||
The function assumes that if a function `f` was declared in a mutual block, then `decls`
|
||||
contains all (computationally relevant) functions in the mutual block.
|
||||
-/
|
||||
def mkFixedParamsMap (decls : Array Decl) : NameMap (Array Bool) := Id.run do
|
||||
def mkFixedParamsMap (decls : Array (Decl .pure)) : NameMap (Array Bool) := Id.run do
|
||||
let mut result := {}
|
||||
for decl in decls do
|
||||
let values := mkInitialValues decl.params.size
|
||||
|
||||
@@ -38,7 +38,7 @@ inductive Decision where
|
||||
| unknown
|
||||
deriving Hashable, BEq, Inhabited, Repr
|
||||
|
||||
def Decision.ofAlt : Alt → Decision
|
||||
def Decision.ofAlt : Alt .pure → Decision
|
||||
| .alt name _ _ => .arm name
|
||||
| .default _ => .default
|
||||
|
||||
@@ -50,7 +50,7 @@ structure BaseFloatContext where
|
||||
All the declarations that were collected in the current LCNF basic
|
||||
block up to the current statement (in reverse order for efficiency).
|
||||
-/
|
||||
decls : List CodeDecl := []
|
||||
decls : List (CodeDecl .pure) := []
|
||||
|
||||
/--
|
||||
The state for `FloatM`
|
||||
@@ -67,7 +67,7 @@ structure FloatState where
|
||||
- Which declarations do we move into a certain arm
|
||||
- Which declarations do we move into the default arm
|
||||
-/
|
||||
newArms : Std.HashMap Decision (List CodeDecl)
|
||||
newArms : Std.HashMap Decision (List (CodeDecl .pure))
|
||||
|
||||
/--
|
||||
Use to collect relevant declarations for the floating mechanism.
|
||||
@@ -82,7 +82,7 @@ abbrev FloatM := StateRefT FloatState BaseFloatM
|
||||
/--
|
||||
Add `decl` to the list of declarations and run `x` with that updated context.
|
||||
-/
|
||||
def withNewCandidate (decl : CodeDecl) (x : BaseFloatM α) : BaseFloatM α :=
|
||||
def withNewCandidate (decl : CodeDecl .pure) (x : BaseFloatM α) : BaseFloatM α :=
|
||||
withReader (fun r => { r with decls := decl :: r.decls }) do
|
||||
x
|
||||
|
||||
@@ -98,7 +98,7 @@ Whether to ignore `decl` for the floating mechanism. We want to do this if:
|
||||
- `decl`' is storing a typeclass instance
|
||||
- `decl` is a projection from a variable that is storing a typeclass instance
|
||||
-/
|
||||
def ignore? (decl : LetDecl) : BaseFloatM Bool := do
|
||||
def ignore? (decl : LetDecl .pure) : BaseFloatM Bool := do
|
||||
if (← isArrowClass? decl.type).isSome then
|
||||
return true
|
||||
else if let .proj _ _ fvarId := decl.value then
|
||||
@@ -117,7 +117,7 @@ up to this point, with respect to `cs`. The initial decisions are:
|
||||
- `arm` or `default` if we see the declaration only being used in exactly one cases arm
|
||||
- `unknown` otherwise
|
||||
-/
|
||||
def initialDecisions (cs : Cases) : BaseFloatM (Std.HashMap FVarId Decision) := do
|
||||
def initialDecisions (cs : Cases .pure) : BaseFloatM (Std.HashMap FVarId Decision) := do
|
||||
let mut map := Std.HashMap.emptyWithCapacity (← read).decls.length
|
||||
let owned : Std.HashSet FVarId := ∅
|
||||
(map, _) ← (← read).decls.foldlM (init := (map, owned)) fun (acc, owned) val => do
|
||||
@@ -135,12 +135,12 @@ def initialDecisions (cs : Cases) : BaseFloatM (Std.HashMap FVarId Decision) :=
|
||||
(_, map) ← goCases cs |>.run map
|
||||
return map
|
||||
where
|
||||
visitDecl (env : Environment) (value : CodeDecl) : StateM (Std.HashSet FVarId) Bool := do
|
||||
visitDecl (env : Environment) (value : CodeDecl .pure) : StateM (Std.HashSet FVarId) Bool := do
|
||||
match value with
|
||||
| .let decl => visitLetValue env decl.value
|
||||
| _ => return false -- will need to investigate whether that can be a problem
|
||||
|
||||
visitLetValue (env : Environment) (value : LetValue) : StateM (Std.HashSet FVarId) Bool := do
|
||||
visitLetValue (env : Environment) (value : LetValue .pure) : StateM (Std.HashSet FVarId) Bool := do
|
||||
match value with
|
||||
| .proj _ _ x => visitArg (.fvar x) true
|
||||
| .const nm _ args =>
|
||||
@@ -158,7 +158,7 @@ where
|
||||
(← visitArg (.fvar x) false)
|
||||
| .erased | .lit _ => return false
|
||||
|
||||
visitArg (var : Arg) (borrowed : Bool) : StateM (Std.HashSet FVarId) Bool := do
|
||||
visitArg (var : Arg .pure) (borrowed : Bool) : StateM (Std.HashSet FVarId) Bool := do
|
||||
let .fvar v := var | return false
|
||||
let res := (← get).contains v
|
||||
unless borrowed do
|
||||
@@ -173,16 +173,16 @@ where
|
||||
modify fun s => s.insert var .dont
|
||||
-- otherwise we already have the proper decision
|
||||
|
||||
goAlt (alt : Alt) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit :=
|
||||
goAlt (alt : Alt .pure) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit :=
|
||||
forFVarM (goFVar (.ofAlt alt)) alt
|
||||
goCases (cs : Cases) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit :=
|
||||
goCases (cs : Cases .pure) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit :=
|
||||
cs.alts.forM goAlt
|
||||
|
||||
/--
|
||||
Compute the initial new arms. This will just set up a map from all arms of
|
||||
`cs` to empty `Array`s, plus one additional entry for `dont`.
|
||||
-/
|
||||
def initialNewArms (cs : Cases) : Std.HashMap Decision (List CodeDecl) := Id.run do
|
||||
def initialNewArms (cs : Cases .pure) : Std.HashMap Decision (List (CodeDecl .pure)) := Id.run do
|
||||
let mut map := Std.HashMap.emptyWithCapacity (cs.alts.size + 1)
|
||||
map := map.insert .dont []
|
||||
cs.alts.foldr (init := map) fun val acc => acc.insert (.ofAlt val) []
|
||||
@@ -203,7 +203,7 @@ cases z with
|
||||
Here `x` and `y` are originally marked as getting floated into `n` and `m`
|
||||
respectively but since `z` can't be moved we don't want that to move `x` and `y`.
|
||||
-/
|
||||
def dontFloat (decl : CodeDecl) : FloatM Unit := do
|
||||
def dontFloat (decl : CodeDecl .pure) : FloatM Unit := do
|
||||
forFVarM goFVar decl
|
||||
modify fun s => { s with newArms := s.newArms.insert .dont (decl :: s.newArms[Decision.dont]!) }
|
||||
where
|
||||
@@ -257,7 +257,7 @@ Will:
|
||||
```
|
||||
If we are at `y` `x` is still marked to be moved but we don't want that.
|
||||
-/
|
||||
def float (decl : CodeDecl) : FloatM Unit := do
|
||||
def float (decl : CodeDecl .pure) : FloatM Unit := do
|
||||
let arm := (← get).decision[decl.fvarId]!
|
||||
forFVarM (goFVar · arm) decl
|
||||
modify fun s => { s with newArms := s.newArms.insert arm (decl :: s.newArms[arm]!) }
|
||||
@@ -273,7 +273,7 @@ where
|
||||
Iterate through `decl`, pushing local declarations that are only used in one
|
||||
control flow arm into said arm in order to avoid useless computations.
|
||||
-/
|
||||
partial def floatLetIn (decl : Decl) : CompilerM Decl := do
|
||||
partial def floatLetIn (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
let newValue ← decl.value.mapCodeM go |>.run {}
|
||||
return { decl with value := newValue }
|
||||
where
|
||||
@@ -296,7 +296,7 @@ where
|
||||
else
|
||||
float decl
|
||||
|
||||
go (code : Code) : BaseFloatM Code := do
|
||||
go (code : Code .pure) : BaseFloatM (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
withNewCandidate (.let decl) do
|
||||
@@ -334,11 +334,12 @@ where
|
||||
|
||||
end FloatLetIn
|
||||
|
||||
def Decl.floatLetIn (decl : Decl) : CompilerM Decl := do
|
||||
def Decl.floatLetIn (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
FloatLetIn.floatLetIn decl
|
||||
|
||||
def floatLetIn (phase := Phase.base) (occurrence := 0) : Pass :=
|
||||
.mkPerDeclaration `floatLetIn Decl.floatLetIn phase occurrence
|
||||
phase.withPurityCheck .pure fun h =>
|
||||
.mkPerDeclaration `floatLetIn phase (h ▸ Decl.floatLetIn) occurrence
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.floatLetIn (inherited := true)
|
||||
|
||||
@@ -14,6 +14,10 @@ public section
|
||||
namespace Lean.Compiler.LCNF
|
||||
/-! # Type inference for LCNF -/
|
||||
|
||||
namespace InferType
|
||||
|
||||
namespace Pure
|
||||
|
||||
/-
|
||||
Note about **erasure confusion**.
|
||||
|
||||
@@ -53,10 +57,9 @@ but the expected type is `S Nat Type (fun x => Nat)`. `fun x => Nat` is not eras
|
||||
here because it is a type former.
|
||||
-/
|
||||
|
||||
namespace InferType
|
||||
|
||||
/-
|
||||
Type inference algorithm for LCNF. Invoked by the LCNF type checker
|
||||
Type inference algorithm for pure LCNF. Invoked by the LCNF type checker
|
||||
to check correctness of LCNF IR.
|
||||
-/
|
||||
|
||||
@@ -80,12 +83,12 @@ def mkForallFVars (xs : Array Expr) (type : Expr) : InferTypeM Expr :=
|
||||
let b := type.abstract xs
|
||||
xs.size.foldRevM (init := b) fun i _ b => do
|
||||
let x := xs[i]
|
||||
let n ← InferType.getBinderName x.fvarId!
|
||||
let ty ← InferType.getType x.fvarId!
|
||||
let n ← getBinderName x.fvarId!
|
||||
let ty ← getType x.fvarId!
|
||||
let ty := ty.abstractRange i xs;
|
||||
return .forallE n ty b .default
|
||||
|
||||
def mkForallParams (params : Array Param) (type : Expr) : InferTypeM Expr :=
|
||||
def mkForallParams (params : Array (Param .pure)) (type : Expr) : InferTypeM Expr :=
|
||||
let xs := params.map fun p => .fvar p.fvarId
|
||||
mkForallFVars xs type |>.run {}
|
||||
|
||||
@@ -97,7 +100,7 @@ def mkForallParams (params : Array Param) (type : Expr) : InferTypeM Expr :=
|
||||
def inferConstType (declName : Name) (us : List Level) : CompilerM Expr := do
|
||||
if declName == ``lcErased then
|
||||
return erasedExpr
|
||||
else if let some decl ← getDecl? declName then
|
||||
else if let some ⟨_, decl⟩ ← getDecl? declName then
|
||||
return decl.instantiateTypeLevelParams us
|
||||
else
|
||||
/- Declaration does not have code associated with it: constructor, inductive type, foreign function -/
|
||||
@@ -114,7 +117,7 @@ def inferLitValueType (value : LitValue) : Expr :=
|
||||
| .usize .. => mkConst ``USize
|
||||
|
||||
mutual
|
||||
partial def inferArgType (arg : Arg) : InferTypeM Expr :=
|
||||
partial def inferArgType (arg : Arg .pure) : InferTypeM Expr :=
|
||||
match arg with
|
||||
| .erased => return erasedExpr
|
||||
| .type e => inferType e
|
||||
@@ -124,13 +127,13 @@ mutual
|
||||
match e with
|
||||
| .const c us => inferConstType c us
|
||||
| .app .. => inferAppType e
|
||||
| .fvar fvarId => InferType.getType fvarId
|
||||
| .fvar fvarId => getType fvarId
|
||||
| .sort lvl => return .sort (mkLevelSucc lvl)
|
||||
| .forallE .. => inferForallType e
|
||||
| .lam .. => inferLambdaType e
|
||||
| .letE .. | .mvar .. | .mdata .. | .lit .. | .bvar .. | .proj .. => unreachable!
|
||||
|
||||
partial def inferLetValueType (e : LetValue) : InferTypeM Expr := do
|
||||
partial def inferLetValueType (e : LetValue .pure) : InferTypeM Expr := do
|
||||
match e with
|
||||
| .erased => return erasedExpr
|
||||
| .lit v => return inferLitValueType v
|
||||
@@ -138,7 +141,7 @@ mutual
|
||||
| .const declName us args => inferAppTypeCore (← inferConstType declName us) args
|
||||
| .fvar fvarId args => inferAppTypeCore (← getType fvarId) args
|
||||
|
||||
partial def inferAppTypeCore (fType : Expr) (args : Array Arg) : InferTypeM Expr := do
|
||||
partial def inferAppTypeCore (fType : Expr) (args : Array (Arg .pure)) : InferTypeM Expr := do
|
||||
let mut j := 0
|
||||
let mut fType := fType
|
||||
for i in *...args.size do
|
||||
@@ -237,60 +240,79 @@ mutual
|
||||
mkForallFVars fvars type
|
||||
|
||||
end
|
||||
end Pure
|
||||
|
||||
namespace Impure
|
||||
end Impure
|
||||
|
||||
end InferType
|
||||
|
||||
-- TODO
|
||||
def inferType (e : Expr) : CompilerM Expr :=
|
||||
InferType.inferType e |>.run {}
|
||||
InferType.Pure.inferType e |>.run {}
|
||||
|
||||
def inferAppType (fnType : Expr) (args : Array Arg) : CompilerM Expr :=
|
||||
InferType.inferAppTypeCore fnType args |>.run {}
|
||||
def inferAppType (fnType : Expr) (args : Array (Arg pu)) : CompilerM Expr :=
|
||||
match pu with
|
||||
| .pure => InferType.Pure.inferAppTypeCore fnType args |>.run {}
|
||||
| .impure => panic! "Infer type for impure unimplemented" -- TODO
|
||||
|
||||
def getLevel (type : Expr) : CompilerM Level := do
|
||||
match (← inferType type) with
|
||||
| .sort u => return u
|
||||
| e => if e.isErased then return levelOne else throwError "type expected{indentExpr type}"
|
||||
def Arg.inferType (arg : Arg pu) : CompilerM Expr :=
|
||||
match pu with
|
||||
| .pure => InferType.Pure.inferArgType arg |>.run {}
|
||||
| .impure => panic! "Infer type for impure unimplemented" -- TODO
|
||||
|
||||
def Arg.inferType (arg : Arg) : CompilerM Expr :=
|
||||
InferType.inferArgType arg |>.run {}
|
||||
def LetValue.inferType (e : LetValue pu) : CompilerM Expr :=
|
||||
match pu with
|
||||
| .pure => InferType.Pure.inferLetValueType e |>.run {}
|
||||
| .impure => panic! "Infer type for impure unimplemented" -- TODO
|
||||
|
||||
def LetValue.inferType (e : LetValue) : CompilerM Expr :=
|
||||
InferType.inferLetValueType e |>.run {}
|
||||
def Code.inferType (code : Code pu) : CompilerM Expr := do
|
||||
match pu with
|
||||
| .pure =>
|
||||
match code with
|
||||
| .let _ k | .fun _ k _ | .jp _ k => k.inferType
|
||||
| .return fvarId => getType fvarId
|
||||
| .jmp fvarId args => InferType.Pure.inferAppTypeCore (← getType fvarId) args |>.run {}
|
||||
| .unreach type => return type
|
||||
| .cases c => return c.resultType
|
||||
| .impure => panic! "Infer type for impure unimplemented" -- TODO
|
||||
|
||||
def Code.inferType (code : Code) : CompilerM Expr := do
|
||||
match code with
|
||||
| .let _ k | .fun _ k | .jp _ k => k.inferType
|
||||
| .return fvarId => getType fvarId
|
||||
| .jmp fvarId args => InferType.inferAppTypeCore (← getType fvarId) args |>.run {}
|
||||
| .unreach type => return type
|
||||
| .cases c => return c.resultType
|
||||
|
||||
def Code.inferParamType (params : Array Param) (code : Code) : CompilerM Expr := do
|
||||
def Code.inferParamType (params : Array (Param pu)) (code : Code pu) : CompilerM Expr := do
|
||||
let type ← code.inferType
|
||||
let xs := params.map fun p => .fvar p.fvarId
|
||||
InferType.mkForallFVars xs type |>.run {}
|
||||
InferType.Pure.mkForallFVars xs type |>.run {}
|
||||
|
||||
def Alt.inferType (alt : Alt) : CompilerM Expr :=
|
||||
def Alt.inferType (alt : Alt pu) : CompilerM Expr :=
|
||||
alt.getCode.inferType
|
||||
|
||||
def mkAuxLetDecl (e : LetValue) (prefixName := `_x) : CompilerM LetDecl := do
|
||||
def mkAuxLetDecl (e : LetValue pu) (prefixName := `_x) : CompilerM (LetDecl pu) := do
|
||||
mkLetDecl (← mkFreshBinderName prefixName) (← e.inferType) e
|
||||
|
||||
def mkForallParams (params : Array Param) (type : Expr) : CompilerM Expr :=
|
||||
InferType.mkForallParams params type |>.run {}
|
||||
def mkForallParams (params : Array (Param pu)) (type : Expr) : CompilerM Expr :=
|
||||
match pu with
|
||||
| .pure => InferType.Pure.mkForallParams params type |>.run {}
|
||||
| .impure => panic! "Infer type for impure unimplemented" -- TODO
|
||||
|
||||
def mkAuxFunDecl (params : Array Param) (code : Code) (prefixName := `_f) : CompilerM FunDecl := do
|
||||
private def mkAuxFunDeclAux (params : Array (Param pu)) (code : Code pu) (prefixName : Name) :
|
||||
CompilerM (FunDecl pu) := do
|
||||
let type ← mkForallParams params (← code.inferType)
|
||||
let binderName ← mkFreshBinderName prefixName
|
||||
mkFunDecl binderName type params code
|
||||
|
||||
def mkAuxJpDecl (params : Array Param) (code : Code) (prefixName := `_jp) : CompilerM FunDecl := do
|
||||
mkAuxFunDecl params code prefixName
|
||||
def mkAuxFunDecl (params : Array (Param .pure)) (code : Code .pure) (prefixName := `_f) :
|
||||
CompilerM (FunDecl .pure) := do
|
||||
mkAuxFunDeclAux params code prefixName
|
||||
|
||||
def mkAuxJpDecl' (param : Param) (code : Code) (prefixName := `_jp) : CompilerM FunDecl := do
|
||||
def mkAuxJpDecl (params : Array (Param pu)) (code : Code pu) (prefixName := `_jp) :
|
||||
CompilerM (FunDecl pu) := do
|
||||
mkAuxFunDeclAux params code prefixName
|
||||
|
||||
def mkAuxJpDecl' (param : Param pu) (code : Code pu) (prefixName := `_jp) :
|
||||
CompilerM (FunDecl pu) := do
|
||||
let params := #[param]
|
||||
mkAuxFunDecl params code prefixName
|
||||
mkAuxFunDeclAux params code prefixName
|
||||
|
||||
def mkCasesResultType (alts : Array Alt) : CompilerM Expr := do
|
||||
def mkCasesResultType (alts : Array (Alt pu)) : CompilerM Expr := do
|
||||
if alts.isEmpty then
|
||||
throwError "`Code.bind` failed, empty `cases` found"
|
||||
let mut resultType ← alts[0]!.inferType
|
||||
|
||||
@@ -22,44 +22,45 @@ private def refreshBinderName (binderName : Name) : CompilerM Name := do
|
||||
|
||||
namespace Internalize
|
||||
|
||||
abbrev InternalizeM := StateRefT FVarSubst CompilerM
|
||||
abbrev InternalizeM (pu : Purity) := StateRefT (FVarSubst pu) CompilerM
|
||||
|
||||
/--
|
||||
The `InternalizeM` monad is a translator. It "translates" the free variables
|
||||
in the input expressions and `Code`, into new fresh free variables in the
|
||||
local context.
|
||||
-/
|
||||
instance : MonadFVarSubst InternalizeM true where
|
||||
instance : MonadFVarSubst (InternalizeM pu) pu true where
|
||||
getSubst := get
|
||||
|
||||
instance : MonadFVarSubstState InternalizeM where
|
||||
instance : MonadFVarSubstState (InternalizeM pu) pu where
|
||||
modifySubst := modify
|
||||
|
||||
private def mkNewFVarId (fvarId : FVarId) : InternalizeM FVarId := do
|
||||
private def mkNewFVarId (fvarId : FVarId) : InternalizeM pu FVarId := do
|
||||
let fvarId' ← Lean.mkFreshFVarId
|
||||
addFVarSubst fvarId fvarId'
|
||||
return fvarId'
|
||||
|
||||
private partial def internalizeExpr (e : Expr) : InternalizeM Expr :=
|
||||
private partial def internalizeExpr (e : Expr) : InternalizeM pu Expr :=
|
||||
go e
|
||||
where
|
||||
goApp (e : Expr) : InternalizeM Expr := do
|
||||
goApp (e : Expr) : InternalizeM pu Expr := do
|
||||
match e with
|
||||
| .app f a => return e.updateApp! (← goApp f) (← go a)
|
||||
| _ => go e
|
||||
|
||||
go (e : Expr) : InternalizeM Expr := do
|
||||
go (e : Expr) : InternalizeM pu Expr := do
|
||||
if e.hasFVar then
|
||||
match e with
|
||||
| .fvar fvarId => match (← get)[fvarId]? with
|
||||
| .fvar fvarId =>
|
||||
match (← get)[fvarId]? with
|
||||
| some (.fvar fvarId') =>
|
||||
-- In LCNF, types can't depend on let-bound fvars.
|
||||
if (← findParam? fvarId').isSome then
|
||||
if (← findParam? (pu := pu) fvarId').isSome then
|
||||
return .fvar fvarId'
|
||||
else
|
||||
return anyExpr
|
||||
| some .erased => return erasedExpr
|
||||
| some (.type e) | none => return e
|
||||
| some (.type e _) | none => return e
|
||||
| .lit .. | .const .. | .sort .. | .mvar .. | .bvar .. => return e
|
||||
| .app f a => return e.updateApp! (← goApp f) (← go a) |>.headBeta
|
||||
| .mdata _ b => return e.updateMData! (← go b)
|
||||
@@ -70,7 +71,7 @@ where
|
||||
else
|
||||
return e
|
||||
|
||||
def internalizeParam (p : Param) : InternalizeM Param := do
|
||||
def internalizeParam (p : Param pu) : InternalizeM pu (Param pu) := do
|
||||
let binderName ← refreshBinderName p.binderName
|
||||
let type ← internalizeExpr p.type
|
||||
let fvarId ← mkNewFVarId p.fvarId
|
||||
@@ -78,31 +79,31 @@ def internalizeParam (p : Param) : InternalizeM Param := do
|
||||
modifyLCtx fun lctx => lctx.addParam p
|
||||
return p
|
||||
|
||||
def internalizeArg (arg : Arg) : InternalizeM Arg := do
|
||||
def internalizeArg (arg : Arg pu) : InternalizeM pu (Arg pu) := do
|
||||
match arg with
|
||||
| .fvar fvarId =>
|
||||
match (← get)[fvarId]? with
|
||||
| some arg'@(.fvar _) => return arg'
|
||||
| some arg'@.erased | some arg'@(.type _) => return arg'
|
||||
| some arg'@.erased | some arg'@(.type _ _) => return arg'
|
||||
| none => return arg
|
||||
| .type e => return arg.updateType! (← internalizeExpr e)
|
||||
| .type e _ => return arg.updateType! (← internalizeExpr e)
|
||||
| .erased => return arg
|
||||
|
||||
def internalizeArgs (args : Array Arg) : InternalizeM (Array Arg) :=
|
||||
def internalizeArgs (args : Array (Arg pu)) : InternalizeM pu (Array (Arg pu)) :=
|
||||
args.mapM internalizeArg
|
||||
|
||||
private partial def internalizeLetValue (e : LetValue) : InternalizeM LetValue := do
|
||||
private partial def internalizeLetValue (e : LetValue pu) : InternalizeM pu (LetValue pu) := do
|
||||
match e with
|
||||
| .erased | .lit .. => return e
|
||||
| .proj _ _ fvarId => match (← normFVar fvarId) with
|
||||
| .proj _ _ fvarId _ => match (← normFVar fvarId) with
|
||||
| .fvar fvarId' => return e.updateProj! fvarId'
|
||||
| .erased => return .erased
|
||||
| .const _ _ args => return e.updateArgs! (← internalizeArgs args)
|
||||
| .const _ _ args _ => return e.updateArgs! (← internalizeArgs args)
|
||||
| .fvar fvarId args => match (← normFVar fvarId) with
|
||||
| .fvar fvarId' => return e.updateFVar! fvarId' (← internalizeArgs args)
|
||||
| .erased => return .erased
|
||||
|
||||
def internalizeLetDecl (decl : LetDecl) : InternalizeM LetDecl := do
|
||||
def internalizeLetDecl (decl : LetDecl pu) : InternalizeM pu (LetDecl pu) := do
|
||||
let binderName ← refreshBinderName decl.binderName
|
||||
let type ← internalizeExpr decl.type
|
||||
let value ← internalizeLetValue decl.value
|
||||
@@ -113,7 +114,7 @@ def internalizeLetDecl (decl : LetDecl) : InternalizeM LetDecl := do
|
||||
|
||||
mutual
|
||||
|
||||
partial def internalizeFunDecl (decl : FunDecl) : InternalizeM FunDecl := do
|
||||
partial def internalizeFunDecl (decl : FunDecl pu) : InternalizeM pu (FunDecl pu) := do
|
||||
let type ← internalizeExpr decl.type
|
||||
let binderName ← refreshBinderName decl.binderName
|
||||
let params ← decl.params.mapM internalizeParam
|
||||
@@ -123,10 +124,10 @@ partial def internalizeFunDecl (decl : FunDecl) : InternalizeM FunDecl := do
|
||||
modifyLCtx fun lctx => lctx.addFunDecl decl
|
||||
return decl
|
||||
|
||||
partial def internalizeCode (code : Code) : InternalizeM Code := do
|
||||
partial def internalizeCode (code : Code pu) : InternalizeM pu (Code pu) := do
|
||||
match code with
|
||||
| .let decl k => return .let (← internalizeLetDecl decl) (← internalizeCode k)
|
||||
| .fun decl k => return .fun (← internalizeFunDecl decl) (← internalizeCode k)
|
||||
| .fun decl k _ => return .fun (← internalizeFunDecl decl) (← internalizeCode k)
|
||||
| .jp decl k => return .jp (← internalizeFunDecl decl) (← internalizeCode k)
|
||||
| .return fvarId => withNormFVarResult (← normFVar fvarId) fun fvarId => return .return fvarId
|
||||
| .jmp fvarId args => withNormFVarResult (← normFVar fvarId) fun fvarId => return .jmp fvarId (← internalizeArgs args)
|
||||
@@ -134,19 +135,19 @@ partial def internalizeCode (code : Code) : InternalizeM Code := do
|
||||
| .cases c =>
|
||||
withNormFVarResult (← normFVar c.discr) fun discr => do
|
||||
let resultType ← internalizeExpr c.resultType
|
||||
let internalizeAltCode (k : Code) : InternalizeM Code :=
|
||||
let internalizeAltCode (k : Code pu) : InternalizeM pu (Code pu) :=
|
||||
internalizeCode k
|
||||
let alts ← c.alts.mapM fun
|
||||
| .alt ctorName params k => return .alt ctorName (← params.mapM internalizeParam) (← internalizeAltCode k)
|
||||
| .alt ctorName params k _ => return .alt ctorName (← params.mapM internalizeParam) (← internalizeAltCode k)
|
||||
| .default k => return .default (← internalizeAltCode k)
|
||||
return .cases ⟨c.typeName, resultType, discr, alts⟩
|
||||
|
||||
end
|
||||
|
||||
partial def internalizeCodeDecl (decl : CodeDecl) : InternalizeM CodeDecl := do
|
||||
partial def internalizeCodeDecl (decl : CodeDecl pu) : InternalizeM pu (CodeDecl pu) := do
|
||||
match decl with
|
||||
| .let decl => return .let (← internalizeLetDecl decl)
|
||||
| .fun decl => return .fun (← internalizeFunDecl decl)
|
||||
| .fun decl _ => return .fun (← internalizeFunDecl decl)
|
||||
| .jp decl => return .jp (← internalizeFunDecl decl)
|
||||
|
||||
end Internalize
|
||||
@@ -154,14 +155,14 @@ end Internalize
|
||||
/--
|
||||
Refresh free variables ids in `code`, and store their declarations in the local context.
|
||||
-/
|
||||
partial def Code.internalize (code : Code) (s : FVarSubst := {}) : CompilerM Code :=
|
||||
partial def Code.internalize (code : Code pu) (s : FVarSubst pu := {}) : CompilerM (Code pu) :=
|
||||
Internalize.internalizeCode code |>.run' s
|
||||
|
||||
open Internalize in
|
||||
def Decl.internalize (decl : Decl) (s : FVarSubst := {}): CompilerM Decl :=
|
||||
def Decl.internalize (decl : Decl pu) (s : FVarSubst pu := {}): CompilerM (Decl pu) :=
|
||||
go decl |>.run' s
|
||||
where
|
||||
go (decl : Decl) : InternalizeM Decl := do
|
||||
go (decl : Decl pu) : InternalizeM pu (Decl pu) := do
|
||||
let type ← internalizeExpr decl.type
|
||||
let params ← decl.params.mapM internalizeParam
|
||||
let value ← decl.value.mapCodeM internalizeCode
|
||||
@@ -170,13 +171,13 @@ where
|
||||
/--
|
||||
Create a fresh local context and internalize the given decls.
|
||||
-/
|
||||
def cleanup (decl : Array Decl) : CompilerM (Array Decl) := do
|
||||
def cleanup (decl : Array (Decl pu)) : CompilerM (Array (Decl pu)) := do
|
||||
modify fun _ => {}
|
||||
decl.mapM fun decl => do
|
||||
modify fun s => { s with nextIdx := 1 }
|
||||
decl.internalize
|
||||
|
||||
def normalizeFVarIds (decl : Decl) : CoreM Decl := do
|
||||
def normalizeFVarIds (decl : Decl pu) : CoreM (Decl pu) := do
|
||||
let ngenSaved ← getNGen
|
||||
setNGen {}
|
||||
try
|
||||
|
||||
@@ -92,13 +92,13 @@ private partial def eraseCandidate (fvarId : FVarId) : FindM Unit := do
|
||||
/--
|
||||
Remove all join point candidates contained in `a`.
|
||||
-/
|
||||
private partial def removeCandidatesInArg (a : Arg) : FindM Unit := do
|
||||
private partial def removeCandidatesInArg (a : Arg .pure) : FindM Unit := do
|
||||
forFVarM eraseCandidate a
|
||||
|
||||
/--
|
||||
Remove all join point candidates contained in `a`.
|
||||
-/
|
||||
private partial def removeCandidatesInLetValue (e : LetValue) : FindM Unit := do
|
||||
private partial def removeCandidatesInLetValue (e : LetValue .pure) : FindM Unit := do
|
||||
forFVarM eraseCandidate e
|
||||
|
||||
/--
|
||||
@@ -117,7 +117,7 @@ private def addDependency (src : FVarId) (target : FVarId) : FindM Unit := do
|
||||
{ targetInfo with associated := targetInfo.associated.insert src }
|
||||
|
||||
@[inline]
|
||||
private def withFnBody (decl : FunDecl) (x : FindM α) : FindM α :=
|
||||
private def withFnBody (decl : FunDecl .pure) (x : FindM α) : FindM α :=
|
||||
withReader (fun ctx => {
|
||||
ctx with
|
||||
definitionDepth := ctx.definitionDepth + 1,
|
||||
@@ -125,7 +125,7 @@ private def withFnBody (decl : FunDecl) (x : FindM α) : FindM α :=
|
||||
x
|
||||
|
||||
@[inline]
|
||||
private def withFnDefined (decl : FunDecl) (x : FindM α) : FindM α :=
|
||||
private def withFnDefined (decl : FunDecl .pure) (x : FindM α) : FindM α :=
|
||||
withReader (fun ctx => {
|
||||
ctx with
|
||||
scope := ctx.scope.insert decl.fvarId ctx.definitionDepth }) do
|
||||
@@ -163,11 +163,11 @@ def test (b : Bool) (x y : Nat) : Nat :=
|
||||
this. This is because otherwise the calls to `myjp` in `f` and `g` would
|
||||
produce out of scope join point jumps.
|
||||
-/
|
||||
partial def find (decl : Decl) : CompilerM FindState := do
|
||||
partial def find (decl : Decl .pure) : CompilerM FindState := do
|
||||
let (_, candidates) ← decl.value.forCodeM go |>.run {} |>.run {}
|
||||
return candidates
|
||||
where
|
||||
go : Code → FindM Unit
|
||||
go : Code .pure → FindM Unit
|
||||
| .let decl k => do
|
||||
match k, decl.value with
|
||||
| .return valId, .fvar fvarId args =>
|
||||
@@ -207,13 +207,13 @@ where
|
||||
Replace all join point candidate `fun` declarations with `jp` ones
|
||||
and all calls to them with `jmp`s.
|
||||
-/
|
||||
partial def replace (decl : Decl) (state : FindState) : CompilerM Decl := do
|
||||
partial def replace (decl : Decl .pure) (state : FindState) : CompilerM (Decl .pure) := do
|
||||
let mapper := fun acc cname _ => do return acc.insert cname (← mkFreshJpName)
|
||||
let replaceCtx : ReplaceCtx ← state.candidates.foldM (init := ∅) mapper
|
||||
let newValue ← decl.value.mapCodeM go |>.run replaceCtx
|
||||
return { decl with value := newValue }
|
||||
where
|
||||
go (code : Code) : ReplaceM Code := do
|
||||
go (code : Code .pure) : ReplaceM (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
match k, decl.value with
|
||||
@@ -274,7 +274,7 @@ structure ExtendState where
|
||||
to `Param`s. The free variables in this map are the once that the context
|
||||
of said join point will be extended by passing in the respective parameter.
|
||||
-/
|
||||
fvarMap : Std.HashMap FVarId (Std.HashMap FVarId Param) := {}
|
||||
fvarMap : Std.HashMap FVarId (Std.HashMap FVarId (Param .pure)) := {}
|
||||
|
||||
/--
|
||||
The monad for the `extendJoinPointContext` pass.
|
||||
@@ -388,7 +388,7 @@ the join point. This is so in the case of nested join points that refer
|
||||
to parameters of the current one we extend the context of the nested
|
||||
join points by said parameters.
|
||||
-/
|
||||
def withNewJpScope (decl : FunDecl) (x : ExtendM α): ExtendM α := do
|
||||
def withNewJpScope (decl : FunDecl .pure) (x : ExtendM α): ExtendM α := do
|
||||
withReader (fun ctx => { ctx with currentJp? := some decl.fvarId }) do
|
||||
modify fun s => { s with fvarMap := s.fvarMap.insert decl.fvarId {} }
|
||||
withNewScope do
|
||||
@@ -401,7 +401,7 @@ It will back up the current scope (since we are doing a case split
|
||||
and want to continue with other arms afterwards) and add all of the
|
||||
parameters of the match arm to the list of candidates.
|
||||
-/
|
||||
def withNewAltScope (alt : Alt) (x : ExtendM α) : ExtendM α := do
|
||||
def withNewAltScope (alt : Alt .pure) (x : ExtendM α) : ExtendM α := do
|
||||
withBackTrackingScope do
|
||||
withNewCandidates (alt.getParams.map (·.fvarId)) do
|
||||
x
|
||||
@@ -418,7 +418,7 @@ All of this is done to eliminate dependencies of join points onto their
|
||||
position within the code so we can pull them out as far as possible, hopefully
|
||||
enabling new inlining possibilities in the next simplifier run.
|
||||
-/
|
||||
partial def extend (decl : Decl) : CompilerM Decl := do
|
||||
partial def extend (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
let newValue ← decl.value.mapCodeM go |>.run {} |>.run' {} |>.run' {}
|
||||
let decl := { decl with value := newValue }
|
||||
decl.pullFunDecls
|
||||
@@ -426,7 +426,7 @@ where
|
||||
goFVar (fvar : FVarId) : ExtendM FVarId := do
|
||||
extendByIfNecessary fvar
|
||||
replaceFVar fvar
|
||||
go (code : Code) : ExtendM Code := do
|
||||
go (code : Code .pure) : ExtendM (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
let decl ← decl.updateValue (← mapFVarM goFVar decl.value)
|
||||
@@ -491,7 +491,7 @@ structure AnalysisState where
|
||||
A map, that for each join point id contains a map from all (so far)
|
||||
duplicated argument ids to the respective duplicate value
|
||||
-/
|
||||
jpJmpArgs : FVarIdMap FVarSubst := {}
|
||||
jpJmpArgs : FVarIdMap (FVarSubst .pure) := {}
|
||||
|
||||
abbrev ReduceAnalysisM := ReaderT AnalysisCtx StateRefT AnalysisState ScopeM
|
||||
abbrev ReduceActionM := ReaderT AnalysisState CompilerM
|
||||
@@ -539,17 +539,17 @@ After we have performed all of these optimizations we can take away the
|
||||
(remaining) common arguments and end up with nicely floated and optimized
|
||||
code that has as little arguments as possible in the join points.
|
||||
-/
|
||||
partial def reduce (decl : Decl) : CompilerM Decl := do
|
||||
partial def reduce (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
let (_, analysis) ← decl.value.forCodeM goAnalyze |>.run {} |>.run {} |>.run' {}
|
||||
let newValue ← decl.value.mapCodeM goReduce |>.run analysis
|
||||
return { decl with value := newValue }
|
||||
where
|
||||
goAnalyzeFunDecl (fn : FunDecl) : ReduceAnalysisM Unit := do
|
||||
goAnalyzeFunDecl (fn : FunDecl .pure) : ReduceAnalysisM Unit := do
|
||||
withNewScope do
|
||||
fn.params.forM (addToScope ·.fvarId)
|
||||
goAnalyze fn.value
|
||||
|
||||
goAnalyze (code : Code) : ReduceAnalysisM Unit := do
|
||||
goAnalyze (code : Code .pure) : ReduceAnalysisM Unit := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
addToScope decl.fvarId
|
||||
@@ -571,7 +571,7 @@ where
|
||||
goAnalyze alt.getCode
|
||||
cs.alts.forM visitor
|
||||
| .jmp fn args =>
|
||||
let decl ← getFunDecl fn
|
||||
let decl ← getFunDecl (pu := .pure) fn
|
||||
if let some knownArgs := (← get).jpJmpArgs.get? fn then
|
||||
let mut newArgs := knownArgs
|
||||
for (param, arg) in decl.params.zip args do
|
||||
@@ -589,7 +589,7 @@ where
|
||||
modify fun s => { s with jpJmpArgs := s.jpJmpArgs.insert fn interestingArgs }
|
||||
| .return .. | .unreach .. => return ()
|
||||
|
||||
goReduce (code : Code) : ReduceActionM Code := do
|
||||
goReduce (code : Code .pure) : ReduceActionM (Code .pure) := do
|
||||
match code with
|
||||
| .jp decl k =>
|
||||
if let some reducibleArgs := (← read).jpJmpArgs.get? decl.fvarId then
|
||||
@@ -613,7 +613,7 @@ where
|
||||
return Code.updateFun! code decl (← goReduce k)
|
||||
| .jmp fn args =>
|
||||
let reducibleArgs := (← read).jpJmpArgs.get! fn
|
||||
let decl ← getFunDecl fn
|
||||
let decl ← getFunDecl (pu := .pure) fn
|
||||
let newParams := decl.params.zip args
|
||||
|>.filter (!reducibleArgs.contains ·.fst.fvarId)
|
||||
|>.map Prod.snd
|
||||
@@ -630,7 +630,7 @@ where
|
||||
|
||||
end JoinPointCommonArgs
|
||||
|
||||
def Decl.findJoinPoints? (decl : Decl) : CompilerM (Option Decl) := do
|
||||
def Decl.findJoinPoints? (decl : Decl .pure) : CompilerM (Option (Decl .pure)) := do
|
||||
let findResult ← JoinPointFinder.find decl
|
||||
trace[Compiler.findJoinPoints] "Found {findResult.candidates.size} jp candidates for {decl.name}"
|
||||
if findResult.candidates.isEmpty then
|
||||
@@ -642,29 +642,32 @@ def Decl.findJoinPoints? (decl : Decl) : CompilerM (Option Decl) := do
|
||||
Find all `fun` declarations in `decl` that qualify as join points then replace
|
||||
their definitions and call sites with `jp`/`jmp`.
|
||||
-/
|
||||
def Decl.findJoinPoints (decl : Decl) : CompilerM Decl := do
|
||||
def Decl.findJoinPoints (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
return (← Decl.findJoinPoints? decl).getD decl
|
||||
|
||||
def findJoinPoints (occurrence : Nat := 0) : Pass :=
|
||||
.mkPerDeclaration `findJoinPoints Decl.findJoinPoints .base (occurrence := occurrence)
|
||||
.mkPerDeclaration `findJoinPoints .base Decl.findJoinPoints (occurrence := occurrence)
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.findJoinPoints (inherited := true)
|
||||
|
||||
def Decl.extendJoinPointContext (decl : Decl) : CompilerM Decl := do
|
||||
def Decl.extendJoinPointContext (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
JoinPointContextExtender.extend decl
|
||||
|
||||
-- TODO: It might make sense to extend this to impure one day
|
||||
def extendJoinPointContext (occurrence : Nat := 0) (phase := Phase.mono) (_h : phase ≠ .base := by simp): Pass :=
|
||||
.mkPerDeclaration `extendJoinPointContext Decl.extendJoinPointContext phase (occurrence := occurrence)
|
||||
phase.withPurityCheck .pure fun h =>
|
||||
.mkPerDeclaration `extendJoinPointContext phase (h ▸ Decl.extendJoinPointContext) (occurrence := occurrence)
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.extendJoinPointContext (inherited := true)
|
||||
|
||||
def Decl.commonJoinPointArgs (decl : Decl) : CompilerM Decl := do
|
||||
def Decl.commonJoinPointArgs (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
JoinPointCommonArgs.reduce decl
|
||||
|
||||
-- TODO: It might make sense to extend this to impure one day
|
||||
def commonJoinPointArgs : Pass :=
|
||||
.mkPerDeclaration `commonJoinPointArgs Decl.commonJoinPointArgs .mono
|
||||
.mkPerDeclaration `commonJoinPointArgs .mono Decl.commonJoinPointArgs
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.commonJoinPointArgs (inherited := true)
|
||||
|
||||
@@ -16,61 +16,97 @@ namespace Lean.Compiler.LCNF
|
||||
LCNF local context.
|
||||
-/
|
||||
structure LCtx where
|
||||
params : Std.HashMap FVarId Param := {}
|
||||
letDecls : Std.HashMap FVarId LetDecl := {}
|
||||
funDecls : Std.HashMap FVarId FunDecl := {}
|
||||
paramsPure : Std.HashMap FVarId (Param .pure) := {}
|
||||
paramsImpure : Std.HashMap FVarId (Param .impure) := {}
|
||||
letDeclsPure : Std.HashMap FVarId (LetDecl .pure) := {}
|
||||
letDeclsImpure : Std.HashMap FVarId (LetDecl .impure) := {}
|
||||
funDeclsPure : Std.HashMap FVarId (FunDecl .pure) := {}
|
||||
funDeclsImpure : Std.HashMap FVarId (FunDecl .impure) := {}
|
||||
deriving Inhabited
|
||||
|
||||
def LCtx.addParam (lctx : LCtx) (param : Param) : LCtx :=
|
||||
{ lctx with params := lctx.params.insert param.fvarId param }
|
||||
def LCtx.addParam (lctx : LCtx) (param : Param pu) : LCtx :=
|
||||
match pu with
|
||||
| .pure => { lctx with paramsPure := lctx.paramsPure.insert param.fvarId param }
|
||||
| .impure => { lctx with paramsImpure := lctx.paramsImpure.insert param.fvarId param }
|
||||
|
||||
def LCtx.addLetDecl (lctx : LCtx) (letDecl : LetDecl) : LCtx :=
|
||||
{ lctx with letDecls := lctx.letDecls.insert letDecl.fvarId letDecl }
|
||||
def LCtx.addLetDecl (lctx : LCtx) (letDecl : LetDecl pu) : LCtx :=
|
||||
match pu with
|
||||
| .pure => { lctx with letDeclsPure := lctx.letDeclsPure.insert letDecl.fvarId letDecl }
|
||||
| .impure => { lctx with letDeclsImpure := lctx.letDeclsImpure.insert letDecl.fvarId letDecl }
|
||||
|
||||
def LCtx.addFunDecl (lctx : LCtx) (funDecl : FunDecl) : LCtx :=
|
||||
{ lctx with funDecls := lctx.funDecls.insert funDecl.fvarId funDecl }
|
||||
def LCtx.addFunDecl (lctx : LCtx) (funDecl : FunDecl pu) : LCtx :=
|
||||
match pu with
|
||||
| .pure => { lctx with funDeclsPure := lctx.funDeclsPure.insert funDecl.fvarId funDecl }
|
||||
| .impure => { lctx with funDeclsImpure := lctx.funDeclsImpure.insert funDecl.fvarId funDecl }
|
||||
|
||||
def LCtx.eraseParam (lctx : LCtx) (param : Param) : LCtx :=
|
||||
{ lctx with params := lctx.params.erase param.fvarId }
|
||||
def LCtx.eraseParam (lctx : LCtx) (param : Param pu) : LCtx :=
|
||||
match pu with
|
||||
| .pure => { lctx with paramsPure := lctx.paramsPure.erase param.fvarId }
|
||||
| .impure => { lctx with paramsImpure := lctx.paramsImpure.erase param.fvarId }
|
||||
|
||||
def LCtx.eraseParams (lctx : LCtx) (ps : Array Param) : LCtx :=
|
||||
{ lctx with params := ps.foldl (init := lctx.params) fun params p => params.erase p.fvarId }
|
||||
def LCtx.eraseParams (lctx : LCtx) (ps : Array (Param pu)) : LCtx :=
|
||||
match pu with
|
||||
| .pure => { lctx with paramsPure := ps.foldl (init := lctx.paramsPure) fun params p => params.erase p.fvarId }
|
||||
| .impure => { lctx with paramsImpure := ps.foldl (init := lctx.paramsImpure) fun params p => params.erase p.fvarId }
|
||||
|
||||
def LCtx.eraseLetDecl (lctx : LCtx) (decl : LetDecl) : LCtx :=
|
||||
{ lctx with letDecls := lctx.letDecls.erase decl.fvarId }
|
||||
def LCtx.eraseLetDecl (lctx : LCtx) (decl : LetDecl pu) : LCtx :=
|
||||
match pu with
|
||||
| .pure => { lctx with letDeclsPure := lctx.letDeclsPure.erase decl.fvarId }
|
||||
| .impure => { lctx with letDeclsImpure := lctx.letDeclsImpure.erase decl.fvarId }
|
||||
|
||||
mutual
|
||||
partial def LCtx.eraseFunDecl (lctx : LCtx) (decl : FunDecl) (recursive := true) : LCtx :=
|
||||
let lctx := { lctx with funDecls := lctx.funDecls.erase decl.fvarId }
|
||||
partial def LCtx.eraseFunDecl (lctx : LCtx) (decl : FunDecl pu) (recursive := true) : LCtx :=
|
||||
let lctx :=
|
||||
match pu with
|
||||
| .pure => { lctx with funDeclsPure := lctx.funDeclsPure.erase decl.fvarId }
|
||||
| .impure => { lctx with funDeclsImpure := lctx.funDeclsImpure.erase decl.fvarId }
|
||||
if recursive then
|
||||
eraseCode decl.value <| eraseParams lctx decl.params
|
||||
else
|
||||
lctx
|
||||
|
||||
partial def LCtx.eraseAlts (alts : Array Alt) (lctx : LCtx) : LCtx :=
|
||||
partial def LCtx.eraseAlts (alts : Array (Alt pu)) (lctx : LCtx) : LCtx :=
|
||||
alts.foldl (init := lctx) fun lctx alt =>
|
||||
match alt with
|
||||
| .default k => eraseCode k lctx
|
||||
| .alt _ ps k => eraseCode k <| eraseParams lctx ps
|
||||
| .alt _ ps k _ => eraseCode k <| eraseParams lctx ps
|
||||
|
||||
partial def LCtx.eraseCode (code : Code) (lctx : LCtx) : LCtx :=
|
||||
partial def LCtx.eraseCode (code : Code pu) (lctx : LCtx) : LCtx :=
|
||||
match code with
|
||||
| .let decl k => eraseCode k <| lctx.eraseLetDecl decl
|
||||
| .jp decl k | .fun decl k => eraseCode k <| eraseFunDecl lctx decl
|
||||
| .jp decl k | .fun decl k _ => eraseCode k <| eraseFunDecl lctx decl
|
||||
| .cases c => eraseAlts c.alts lctx
|
||||
| _ => lctx
|
||||
end
|
||||
|
||||
@[inline]
|
||||
def LCtx.params (lctx : LCtx) (pu : Purity) : Std.HashMap FVarId (Param pu) :=
|
||||
match pu with
|
||||
| .pure => lctx.paramsPure
|
||||
| .impure => lctx.paramsImpure
|
||||
|
||||
@[inline]
|
||||
def LCtx.letDecls (lctx : LCtx) (pu : Purity) : Std.HashMap FVarId (LetDecl pu) :=
|
||||
match pu with
|
||||
| .pure => lctx.letDeclsPure
|
||||
| .impure => lctx.letDeclsImpure
|
||||
|
||||
@[inline]
|
||||
def LCtx.funDecls (lctx : LCtx) (pu : Purity) : Std.HashMap FVarId (FunDecl pu) :=
|
||||
match pu with
|
||||
| .pure => lctx.funDeclsPure
|
||||
| .impure => lctx.funDeclsImpure
|
||||
|
||||
/--
|
||||
Convert a LCNF local context into a regular Lean local context.
|
||||
-/
|
||||
def LCtx.toLocalContext (lctx : LCtx) : LocalContext := Id.run do
|
||||
def LCtx.toLocalContext (lctx : LCtx) (pu : Purity) : LocalContext := Id.run do
|
||||
let mut result := {}
|
||||
for (_, param) in lctx.params.toArray do
|
||||
for (_, param) in lctx.params pu do
|
||||
result := result.addDecl (.cdecl 0 param.fvarId param.binderName param.type .default .default)
|
||||
for (_, decl) in lctx.letDecls.toArray do
|
||||
for (_, decl) in lctx.letDecls pu do
|
||||
result := result.addDecl (.ldecl 0 decl.fvarId decl.binderName decl.type decl.value.toExpr true .default)
|
||||
for (_, decl) in lctx.funDecls.toArray do
|
||||
for (_, decl) in lctx.funDecls pu do
|
||||
result := result.addDecl (.cdecl 0 decl.fvarId decl.binderName decl.type .default .default)
|
||||
return result
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ structure Context where
|
||||
Declaration where lambda lifting is being applied.
|
||||
We use it to provide the "base name" for auxiliary declarations and the flag `safe`.
|
||||
-/
|
||||
mainDecl : Decl
|
||||
mainDecl : Decl .pure
|
||||
/--
|
||||
If true, the lambda-lifted functions inherit the inline attribute from `mainDecl`.
|
||||
We use this feature to implement `@[inline] instance ...` and `@[always_inline] instance ...`
|
||||
@@ -51,7 +51,7 @@ structure State where
|
||||
/--
|
||||
New auxiliary declarations
|
||||
-/
|
||||
decls : Array Decl := #[]
|
||||
decls : Array (Decl .pure) := #[]
|
||||
/--
|
||||
Next index for generating auxiliary declaration name.
|
||||
-/
|
||||
@@ -64,13 +64,13 @@ abbrev LiftM := ReaderT Context (StateRefT State (ScopeT CompilerM))
|
||||
Return `true` if the given declaration takes a local instance as a parameter.
|
||||
We lambda lift this kind of local function declaration before specialization.
|
||||
-/
|
||||
def hasInstParam (decl : FunDecl) : CompilerM Bool :=
|
||||
def hasInstParam (decl : FunDecl .pure) : CompilerM Bool :=
|
||||
decl.params.anyM fun param => return (← isArrowClass? param.type).isSome
|
||||
|
||||
/--
|
||||
Return `true` if the given declaration should be lambda lifted.
|
||||
-/
|
||||
def shouldLift (decl : FunDecl) : LiftM Bool := do
|
||||
def shouldLift (decl : FunDecl .pure) : LiftM Bool := do
|
||||
let minSize := (← read).minSize
|
||||
if decl.value.size < minSize then
|
||||
return false
|
||||
@@ -85,7 +85,7 @@ partial def mkAuxDeclName : LiftM Name := do
|
||||
if (← getDecl? nameNew).isNone then return nameNew
|
||||
mkAuxDeclName
|
||||
|
||||
def replaceFunDecl (decl : FunDecl) (value : LetValue) : LiftM LetDecl := do
|
||||
def replaceFunDecl (decl : FunDecl .pure) (value : LetValue .pure) : LiftM (LetDecl .pure) := do
|
||||
/- We reuse `decl`s `fvarId` to avoid substitution -/
|
||||
let declNew := { fvarId := decl.fvarId, binderName := decl.binderName, type := decl.type, value }
|
||||
modifyLCtx fun lctx => lctx.addLetDecl declNew
|
||||
@@ -97,7 +97,7 @@ open Internalize in
|
||||
Create a new auxiliary declaration. The array `closure` contains all free variables
|
||||
occurring in `decl`.
|
||||
-/
|
||||
def mkAuxDecl (closure : Array Param) (decl : FunDecl) : LiftM LetDecl := do
|
||||
def mkAuxDecl (closure : Array (Param .pure)) (decl : FunDecl .pure) : LiftM (LetDecl .pure) := do
|
||||
let nameNew ← mkAuxDeclName
|
||||
let inlineAttr? ← if (← read).inheritInlineAttrs then pure (← read).mainDecl.inlineAttr? else pure none
|
||||
let auxDecl ← go nameNew (← read).mainDecl.safe inlineAttr? |>.run' {}
|
||||
@@ -113,16 +113,16 @@ def mkAuxDecl (closure : Array Param) (decl : FunDecl) : LiftM LetDecl := do
|
||||
let value := .const auxDeclName us (closure.map (.fvar ·.fvarId))
|
||||
replaceFunDecl decl value
|
||||
where
|
||||
go (nameNew : Name) (safe : Bool) (inlineAttr? : Option InlineAttributeKind) : InternalizeM Decl := do
|
||||
go (nameNew : Name) (safe : Bool) (inlineAttr? : Option InlineAttributeKind) : InternalizeM .pure (Decl .pure):= do
|
||||
let params := (← closure.mapM internalizeParam) ++ (← decl.params.mapM internalizeParam)
|
||||
let code ← internalizeCode decl.value
|
||||
let type ← code.inferType
|
||||
let type ← mkForallParams params type
|
||||
let value := .code code
|
||||
let decl := { name := nameNew, levelParams := [], params, type, value, safe, inlineAttr?, recursive := false : Decl }
|
||||
let decl := { name := nameNew, levelParams := [], params, type, value, safe, inlineAttr?, recursive := false : Decl .pure }
|
||||
return decl.setLevelParams
|
||||
|
||||
def etaContractibleDecl? (decl : FunDecl) : LiftM (Option LetDecl) := do
|
||||
def etaContractibleDecl? (decl : FunDecl .pure) : LiftM (Option (LetDecl .pure)) := do
|
||||
if !(← read).allowEtaContraction then return none
|
||||
let .let { fvarId := letVar, value := .const declName us args, .. } (.return retVar) := decl.value
|
||||
| return none
|
||||
@@ -137,11 +137,11 @@ def etaContractibleDecl? (decl : FunDecl) : LiftM (Option LetDecl) := do
|
||||
replaceFunDecl decl value
|
||||
|
||||
mutual
|
||||
partial def visitFunDecl (funDecl : FunDecl) : LiftM FunDecl := do
|
||||
partial def visitFunDecl (funDecl : FunDecl .pure) : LiftM (FunDecl .pure) := do
|
||||
let value ← withParams funDecl.params <| visitCode funDecl.value
|
||||
funDecl.update' funDecl.type value
|
||||
|
||||
partial def visitCode (code : Code) : LiftM Code := do
|
||||
partial def visitCode (code : Code .pure) : LiftM (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
let k ← withFVar decl.fvarId <| visitCode k
|
||||
@@ -174,14 +174,14 @@ mutual
|
||||
| .unreach .. | .jmp .. | .return .. => return code
|
||||
end
|
||||
|
||||
def main (decl : Decl) : LiftM Decl := do
|
||||
def main (decl : Decl .pure) : LiftM (Decl .pure) := do
|
||||
let value ← withParams decl.params <| decl.value.mapCodeM visitCode
|
||||
return { decl with value }
|
||||
|
||||
end LambdaLifting
|
||||
|
||||
partial def Decl.lambdaLifting (decl : Decl) (liftInstParamOnly : Bool) (allowEtaContraction : Bool)
|
||||
(suffix : Name) (inheritInlineAttrs := false) (minSize := 0) : CompilerM (Array Decl) := do
|
||||
partial def Decl.lambdaLifting (decl : Decl .pure) (liftInstParamOnly : Bool) (allowEtaContraction : Bool)
|
||||
(suffix : Name) (inheritInlineAttrs := false) (minSize := 0) : CompilerM (Array (Decl .pure)) := do
|
||||
let ctx := {
|
||||
mainDecl := decl,
|
||||
liftInstParamOnly,
|
||||
@@ -214,7 +214,7 @@ def eagerLambdaLifting : Pass where
|
||||
name := `eagerLambdaLifting
|
||||
run := fun decls => do
|
||||
decls.foldlM (init := #[]) fun decls decl => do
|
||||
if decl.inlineable || (← Meta.isInstance decl.name) then
|
||||
if decl.inlineable || (← isInstanceReducible decl.name) then
|
||||
return decls.push decl
|
||||
else
|
||||
return decls ++ (← decl.lambdaLifting (liftInstParamOnly := true) (allowEtaContraction := false) (suffix := `_elam))
|
||||
|
||||
@@ -105,45 +105,45 @@ open Lean.CollectLevelParams
|
||||
abbrev visitType (type : Expr) : Visitor :=
|
||||
visitExpr type
|
||||
|
||||
def visitArg (arg : Arg) : Visitor :=
|
||||
def visitArg (arg : Arg .pure) : Visitor :=
|
||||
match arg with
|
||||
| .erased | .fvar .. => id
|
||||
| .type e => visitType e
|
||||
| .type e _ => visitType e
|
||||
|
||||
def visitArgs (args : Array Arg) : Visitor :=
|
||||
def visitArgs (args : Array (Arg .pure)) : Visitor :=
|
||||
fun s => args.foldl (init := s) fun s arg => visitArg arg s
|
||||
|
||||
def visitLetValue (e : LetValue) : Visitor :=
|
||||
def visitLetValue (e : LetValue .pure) : Visitor :=
|
||||
match e with
|
||||
| .erased | .lit .. | .proj .. => id
|
||||
| .const _ us args => visitLevels us ∘ visitArgs args
|
||||
| .const _ us args _ => visitLevels us ∘ visitArgs args
|
||||
| .fvar _ args => visitArgs args
|
||||
|
||||
def visitParam (p : Param) : Visitor :=
|
||||
def visitParam (p : Param .pure) : Visitor :=
|
||||
visitType p.type
|
||||
|
||||
def visitParams (ps : Array Param) : Visitor :=
|
||||
def visitParams (ps : Array (Param .pure)) : Visitor :=
|
||||
fun s => ps.foldl (init := s) fun s p => visitParam p s
|
||||
|
||||
mutual
|
||||
partial def visitAlt (alt : Alt) : Visitor :=
|
||||
partial def visitAlt (alt : Alt .pure) : Visitor :=
|
||||
match alt with
|
||||
| .default k => visitCode k
|
||||
| .alt _ ps k => visitCode k ∘ visitParams ps
|
||||
| .alt _ ps k _ => visitCode k ∘ visitParams ps
|
||||
|
||||
partial def visitAlts (alts : Array Alt) : Visitor :=
|
||||
partial def visitAlts (alts : Array (Alt .pure)) : Visitor :=
|
||||
fun s => alts.foldl (init := s) fun s alt => visitAlt alt s
|
||||
|
||||
partial def visitCode : Code → Visitor
|
||||
partial def visitCode : Code .pure → Visitor
|
||||
| .let decl k => visitCode k ∘ visitLetValue decl.value ∘ visitType decl.type
|
||||
| .fun decl k | .jp decl k => visitCode k ∘ visitCode decl.value ∘ visitParams decl.params ∘ visitType decl.type
|
||||
| .fun decl k _ | .jp decl k => visitCode k ∘ visitCode decl.value ∘ visitParams decl.params ∘ visitType decl.type
|
||||
| .cases c => visitAlts c.alts ∘ visitType c.resultType
|
||||
| .unreach type => visitType type
|
||||
| .return _ => id
|
||||
| .jmp _ args => visitArgs args
|
||||
end
|
||||
|
||||
def visitDeclValue : DeclValue → Visitor
|
||||
def visitDeclValue : DeclValue .pure → Visitor
|
||||
| .code c => visitCode c
|
||||
| .extern .. => id
|
||||
|
||||
@@ -156,7 +156,7 @@ open CollectLevelParams
|
||||
Collect universe level parameters collecting in the type, parameters, and value, and then
|
||||
set `decl.levelParams` with the resulting value.
|
||||
-/
|
||||
def Decl.setLevelParams (decl : Decl) : Decl :=
|
||||
def Decl.setLevelParams (decl : Decl .pure) : Decl .pure :=
|
||||
let levelParams := (visitDeclValue decl.value ∘ visitParams decl.params ∘ visitType decl.type) {} |>.params.toList
|
||||
{ decl with levelParams }
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import Lean.Meta.Match.MatcherInfo
|
||||
import Lean.Compiler.LCNF.SplitSCC
|
||||
public import Lean.Compiler.IR.Basic
|
||||
public import Lean.Compiler.LCNF.CompilerM
|
||||
|
||||
public section
|
||||
namespace Lean.Compiler.LCNF
|
||||
/--
|
||||
@@ -50,7 +51,7 @@ A checkpoint in code generation to print all declarations in between
|
||||
compiler passes in order to ease debugging.
|
||||
The trace can be viewed with `set_option trace.Compiler.step true`.
|
||||
-/
|
||||
def checkpoint (stepName : Name) (decls : Array Decl) (shouldCheck : Bool) : CompilerM Unit := do
|
||||
def checkpoint (stepName : Name) (decls : Array (Decl pu)) (shouldCheck : Bool) : CompilerM Unit := do
|
||||
for decl in decls do
|
||||
trace[Compiler.stat] "{decl.name} : {decl.size}"
|
||||
withOptions (fun opts => opts.set `pp.motives.pi false) do
|
||||
@@ -101,12 +102,12 @@ def run (declNames : Array Name) : CompilerM (Array (Array IR.Decl)) := withAtLe
|
||||
let decls := markRecDecls decls
|
||||
let manager ← getPassManager
|
||||
let isCheckEnabled := compiler.check.get (← getOptions)
|
||||
let decls ← runPassManagerPart "compilation (LCNF base)" manager.basePasses decls isCheckEnabled
|
||||
let decls ← runPassManagerPart "compilation (LCNF mono)" manager.monoPasses decls isCheckEnabled
|
||||
let decls ← runPassManagerPart .pure .pure "compilation (LCNF base)" manager.basePasses decls isCheckEnabled
|
||||
let decls ← runPassManagerPart .pure .pure "compilation (LCNF mono)" manager.monoPasses decls isCheckEnabled
|
||||
let sccs ← withTraceNode `Compiler.splitSCC (fun _ => return m!"Splitting up SCC") do
|
||||
splitScc decls
|
||||
sccs.mapM fun decls => do
|
||||
let decls ← runPassManagerPart "compilation (LCNF mono)" manager.monoPassesNoLambda decls isCheckEnabled
|
||||
let decls ← runPassManagerPart .pure .pure "compilation (LCNF mono)" manager.monoPassesNoLambda decls isCheckEnabled
|
||||
if (← Lean.isTracingEnabledFor `Compiler.result) then
|
||||
for decl in decls do
|
||||
let decl ← normalizeFVarIds decl
|
||||
@@ -115,14 +116,19 @@ def run (declNames : Array Name) : CompilerM (Array (Array IR.Decl)) := withAtLe
|
||||
let irDecls ← IR.toIR decls
|
||||
IR.compile irDecls
|
||||
where
|
||||
runPassManagerPart (profilerName : String) (passes : Array Pass) (decls : Array Decl)
|
||||
(isCheckEnabled : Bool) : CompilerM (Array Decl) := do
|
||||
runPassManagerPart (inPhase outPhase : Purity) (profilerName : String)
|
||||
(passes : Array Pass) (decls : Array (Decl inPhase)) (isCheckEnabled : Bool) :
|
||||
CompilerM (Array (Decl outPhase)) := do
|
||||
profileitM Exception profilerName (← getOptions) do
|
||||
let mut decls := decls
|
||||
let mut state : (pu : Purity) × Array (Decl pu) := ⟨inPhase, decls⟩
|
||||
for pass in passes do
|
||||
decls ← withTraceNode `Compiler (fun _ => return m!"compiler phase: {pass.phase}, pass: {pass.name}") do
|
||||
withPhase pass.phase <| pass.run decls
|
||||
withPhase pass.phaseOut <| checkpoint pass.name decls (isCheckEnabled || pass.shouldAlwaysRunCheck)
|
||||
state ← withTraceNode `Compiler (fun _ => return m!"compiler phase: {pass.phase}, pass: {pass.name}") do
|
||||
let decls ← withPhase pass.phase do
|
||||
state.fst.withAssertPurity pass.phase.toPurity fun h => do
|
||||
pass.run (h ▸ state.snd)
|
||||
pure ⟨_, decls⟩
|
||||
withPhase pass.phaseOut <| checkpoint pass.name state.snd (isCheckEnabled || pass.shouldAlwaysRunCheck)
|
||||
let decls := state.fst.withAssertPurity outPhase fun h => h ▸ state.snd
|
||||
return decls
|
||||
|
||||
end PassManager
|
||||
|
||||
@@ -33,7 +33,7 @@ instance (m n) [MonadLift m n] [MonadFunctor m n] [MonadScope m] : MonadScope n
|
||||
def inScope [MonadScope m] [Monad m] (fvarId : FVarId) : m Bool :=
|
||||
return (← getScope).contains fvarId
|
||||
|
||||
@[inline] def withParams [MonadScope m] [Monad m] (ps : Array Param) (x : m α) : m α :=
|
||||
@[inline] def withParams [MonadScope m] [Monad m] (ps : Array (Param pu)) (x : m α) : m α :=
|
||||
withScope (fun s => ps.foldl (init := s) fun s p => s.insert p.fvarId) x
|
||||
|
||||
@[inline] def withFVar [MonadScope m] [Monad m] (fvarId : FVarId) (x : m α) : m α :=
|
||||
|
||||
@@ -99,4 +99,7 @@ def getOtherDeclMonoType (declName : Name) : CoreM Expr := do
|
||||
monoTypeExt.insert declName type
|
||||
return type
|
||||
|
||||
def getOtherDeclImpureType (_declName : Name) : CoreM Expr := do
|
||||
panic! "Other decl impure type unimplemented" -- TODO
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
||||
@@ -19,5 +19,6 @@ def getOtherDeclType (declName : Name) (us : List Level := []) : CompilerM Expr
|
||||
match (← getPhase) with
|
||||
| .base => getOtherDeclBaseType declName us
|
||||
| .mono => getOtherDeclMonoType declName
|
||||
| .impure => getOtherDeclImpureType declName
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
||||
@@ -15,6 +15,20 @@ namespace Lean.Compiler.LCNF
|
||||
@[expose] def Phase.toNat : Phase → Nat
|
||||
| .base => 0
|
||||
| .mono => 1
|
||||
| .impure => 2
|
||||
|
||||
instance : ToString Phase where
|
||||
toString
|
||||
| .base => "base"
|
||||
| .mono => "mono"
|
||||
| .impure => "impure"
|
||||
|
||||
def Phase.withPurityCheck [Inhabited α] (pp : Phase) (ip : Purity)
|
||||
(x : pp.toPurity = ip → α) : α :=
|
||||
if h : pp.toPurity = ip then
|
||||
x h
|
||||
else
|
||||
panic! s!"Compiler error: {pp} is not equivalent to IR phase {ip}, this is a bug"
|
||||
|
||||
instance : LT Phase where
|
||||
lt l r := l.toNat < r.toNat
|
||||
@@ -60,7 +74,7 @@ structure Pass where
|
||||
/--
|
||||
The actual pass function, operating on the `Decl`s.
|
||||
-/
|
||||
run : Array Decl → CompilerM (Array Decl)
|
||||
run : Array (Decl phase.toPurity) → CompilerM (Array (Decl phase.toPurity))
|
||||
|
||||
instance : Inhabited Pass where
|
||||
default := { phase := .base, name := default, run := fun decls => return decls }
|
||||
@@ -90,14 +104,10 @@ structure PassManager where
|
||||
monoPassesNoLambda : Array Pass
|
||||
deriving Inhabited
|
||||
|
||||
instance : ToString Phase where
|
||||
toString
|
||||
| .base => "base"
|
||||
| .mono => "mono"
|
||||
|
||||
namespace Pass
|
||||
|
||||
def mkPerDeclaration (name : Name) (run : Decl → CompilerM Decl) (phase : Phase) (occurrence : Nat := 0) : Pass where
|
||||
def mkPerDeclaration (name : Name) (phase : Phase)
|
||||
(run : Decl phase.toPurity → CompilerM (Decl phase.toPurity)) (occurrence : Nat := 0) : Pass where
|
||||
occurrence := occurrence
|
||||
phase := phase
|
||||
name := name
|
||||
@@ -190,6 +200,7 @@ def run (manager : PassManager) (installer : PassInstaller) : CoreM PassManager
|
||||
return { manager with basePasses := (← installer.install manager.basePasses) }
|
||||
| .mono =>
|
||||
return { manager with monoPasses := (← installer.install manager.monoPasses) }
|
||||
| .impure => panic! "Pass manager support for impure unimplemented" -- TODO
|
||||
|
||||
private unsafe def getPassInstallerUnsafe (declName : Name) : CoreM PassInstaller := do
|
||||
ofExcept <| (← getEnv).evalConstCheck PassInstaller (← getOptions) ``PassInstaller declName
|
||||
|
||||
@@ -45,10 +45,15 @@ private builtin_initialize baseTransparentDeclsExt : EnvExtension (List Name ×
|
||||
Set of public declarations whose mono bodies should be exported to other modules
|
||||
-/
|
||||
private builtin_initialize monoTransparentDeclsExt : EnvExtension (List Name × NameSet) ← mkDeclSetExt
|
||||
/--
|
||||
Set of public declarations whose impure bodies should be exported to other modules
|
||||
-/
|
||||
private builtin_initialize impureTransparentDeclsExt : EnvExtension (List Name × NameSet) ← mkDeclSetExt
|
||||
|
||||
private def getTransparencyExt : Phase → EnvExtension (List Name × NameSet)
|
||||
| .base => baseTransparentDeclsExt
|
||||
| .mono => monoTransparentDeclsExt
|
||||
| .impure => impureTransparentDeclsExt
|
||||
|
||||
def isDeclPublic (env : Environment) (declName : Name) : Bool := Id.run do
|
||||
if !env.header.isModule then
|
||||
@@ -81,26 +86,28 @@ def setDeclTransparent (env : Environment) (phase : Phase) (declName : Name) : E
|
||||
getTransparencyExt phase |>.modifyState env fun s =>
|
||||
(declName :: s.1, s.2.insert declName)
|
||||
|
||||
abbrev DeclExtState := PHashMap Name Decl
|
||||
abbrev DeclExtState (pu : Purity) := PHashMap Name (Decl pu)
|
||||
|
||||
private abbrev declLt (a b : Decl) :=
|
||||
private abbrev declLt (a b : Decl pu) :=
|
||||
Name.quickLt a.name b.name
|
||||
|
||||
private def sortedDecls (s : DeclExtState) : Array Decl :=
|
||||
private def sortedDecls (s : DeclExtState pu) : Array (Decl pu) :=
|
||||
let decls := s.foldl (init := #[]) fun ps _ v => ps.push v
|
||||
decls.qsort declLt
|
||||
|
||||
private abbrev findAtSorted? (decls : Array Decl) (declName : Name) : Option Decl :=
|
||||
let tmpDecl : Decl := default
|
||||
private abbrev findAtSorted? (decls : Array (Decl pu)) (declName : Name) : Option (Decl pu) :=
|
||||
let tmpDecl : Decl pu := default
|
||||
let tmpDecl := { tmpDecl with name := declName }
|
||||
decls.binSearch tmpDecl declLt
|
||||
|
||||
@[expose] def DeclExt := PersistentEnvExtension Decl Decl DeclExtState
|
||||
@[expose] def DeclExt (pu : Purity) :=
|
||||
PersistentEnvExtension (Decl pu) (Decl pu) (DeclExtState pu)
|
||||
|
||||
instance : Inhabited DeclExt :=
|
||||
inferInstanceAs (Inhabited (PersistentEnvExtension Decl Decl DeclExtState))
|
||||
instance : Inhabited (DeclExt pu) :=
|
||||
inferInstanceAs (Inhabited (PersistentEnvExtension (Decl pu) (Decl pu) (DeclExtState pu)))
|
||||
|
||||
def mkDeclExt (phase : Phase) (name : Name := by exact decl_name%) : IO DeclExt :=
|
||||
def mkDeclExt (phase : Phase) (name : Name := by exact decl_name%) :
|
||||
IO (DeclExt phase.toPurity) :=
|
||||
registerPersistentEnvExtension {
|
||||
name,
|
||||
mkInitial := pure {},
|
||||
@@ -128,74 +135,77 @@ def mkDeclExt (phase : Phase) (name : Name := by exact decl_name%) : IO DeclExt
|
||||
otherState.insert k v
|
||||
}
|
||||
|
||||
builtin_initialize baseExt : DeclExt ← mkDeclExt .base
|
||||
builtin_initialize monoExt : DeclExt ← mkDeclExt .mono
|
||||
builtin_initialize baseExt : DeclExt .pure ← mkDeclExt .base
|
||||
builtin_initialize monoExt : DeclExt .pure ← mkDeclExt .mono
|
||||
builtin_initialize impureExt : DeclExt .impure ← mkDeclExt .impure
|
||||
|
||||
def getDeclCore? (env : Environment) (ext : DeclExt) (declName : Name) : Option Decl :=
|
||||
def getDeclCore? (env : Environment) (ext : DeclExt pu) (declName : Name) : Option (Decl pu) :=
|
||||
match env.getModuleIdxFor? declName with
|
||||
| some modIdx => findAtSorted? (ext.getModuleEntries env modIdx) declName
|
||||
| none => ext.getState env |>.find? declName
|
||||
|
||||
def getBaseDecl? (declName : Name) : CoreM (Option Decl) := do
|
||||
def getBaseDecl? (declName : Name) : CoreM (Option (Decl .pure)) := do
|
||||
return getDeclCore? (← getEnv) baseExt declName
|
||||
|
||||
def getMonoDecl? (declName : Name) : CoreM (Option Decl) := do
|
||||
def getMonoDecl? (declName : Name) : CoreM (Option (Decl .pure)) := do
|
||||
return getDeclCore? (← getEnv) monoExt declName
|
||||
|
||||
def saveBaseDeclCore (env : Environment) (decl : Decl) : Environment :=
|
||||
def getImpureDecl? (declName : Name) : CoreM (Option (Decl .impure)) := do
|
||||
return getDeclCore? (← getEnv) impureExt declName
|
||||
|
||||
def saveBaseDeclCore (env : Environment) (decl : Decl .pure) : Environment :=
|
||||
baseExt.addEntry env decl
|
||||
|
||||
def saveMonoDeclCore (env : Environment) (decl : Decl) : Environment :=
|
||||
def saveMonoDeclCore (env : Environment) (decl : Decl .pure) : Environment :=
|
||||
monoExt.addEntry env decl
|
||||
|
||||
def Decl.saveBase (decl : Decl) : CoreM Unit :=
|
||||
def saveImpureDeclCore (env : Environment) (decl : Decl .impure) : Environment :=
|
||||
impureExt.addEntry env decl
|
||||
|
||||
def Decl.saveBase (decl : Decl .pure) : CoreM Unit :=
|
||||
modifyEnv (saveBaseDeclCore · decl)
|
||||
|
||||
def Decl.saveMono (decl : Decl) : CoreM Unit :=
|
||||
def Decl.saveMono (decl : Decl .pure) : CoreM Unit :=
|
||||
modifyEnv (saveMonoDeclCore · decl)
|
||||
|
||||
def Decl.save (decl : Decl) : CompilerM Unit := do
|
||||
match (← getPhase) with
|
||||
| .base => decl.saveBase
|
||||
| .mono => decl.saveMono
|
||||
def Decl.saveImpure (decl : Decl .impure) : CoreM Unit :=
|
||||
modifyEnv (saveImpureDeclCore · decl)
|
||||
|
||||
def getDeclAt? (declName : Name) (phase : Phase) : CoreM (Option Decl) :=
|
||||
def Decl.save (decl : Decl pu) : CompilerM Unit := do
|
||||
match (← getPhase) with
|
||||
| .base => Phase.withPurityCheck .base pu fun h =>
|
||||
(h.symm ▸ decl).saveBase
|
||||
| .mono => Phase.withPurityCheck .mono pu fun h =>
|
||||
(h.symm ▸ decl).saveMono
|
||||
| .impure => Phase.withPurityCheck .impure pu fun h =>
|
||||
(h.symm ▸ decl).saveImpure
|
||||
|
||||
def getDeclAt? (declName : Name) (phase : Phase) : CoreM (Option (Decl phase.toPurity)) :=
|
||||
match phase with
|
||||
| .base => getBaseDecl? declName
|
||||
| .mono => getMonoDecl? declName
|
||||
| .impure => getImpureDecl? declName
|
||||
|
||||
def getDecl? (declName : Name) : CompilerM (Option Decl) := do
|
||||
getDeclAt? declName (← getPhase)
|
||||
@[inline]
|
||||
def getDecl? (declName : Name) : CompilerM (Option ((pu : Purity) × Decl pu)) := do
|
||||
let some decl ← getDeclAt? declName (← getPhase) | return none
|
||||
return some ⟨_, decl⟩
|
||||
|
||||
def getLocalDeclAt? (declName : Name) (phase : Phase) : CompilerM (Option Decl) := do
|
||||
def getLocalDeclAt? (declName : Name) (phase : Phase) : CompilerM (Option (Decl phase.toPurity)) := do
|
||||
match phase with
|
||||
| .base => return baseExt.getState (← getEnv) |>.find? declName
|
||||
| .mono => return monoExt.getState (← getEnv) |>.find? declName
|
||||
| .impure => return impureExt.getState (← getEnv) |>.find? declName
|
||||
|
||||
def getLocalDecl? (declName : Name) : CompilerM (Option Decl) := do
|
||||
getLocalDeclAt? declName (← getPhase)
|
||||
@[inline]
|
||||
def getLocalDecl? (declName : Name) : CompilerM (Option ((pu : Purity) × Decl pu)) := do
|
||||
let some decl ← getLocalDeclAt? declName (← getPhase) | return none
|
||||
return some ⟨_, decl⟩
|
||||
|
||||
def getExt (phase : Phase) : DeclExt :=
|
||||
def getExt (phase : Phase) : DeclExt phase.toPurity :=
|
||||
match phase with
|
||||
| .base => baseExt
|
||||
| .mono => monoExt
|
||||
|
||||
def forEachDecl (f : Decl → CoreM Unit) (phase := Phase.base) : CoreM Unit := do
|
||||
let ext := getExt phase
|
||||
let env ← getEnv
|
||||
for modIdx in *...env.allImportedModuleNames.size do
|
||||
for decl in ext.getModuleEntries env modIdx do
|
||||
f decl
|
||||
ext.getState env |>.forM fun _ decl => f decl
|
||||
|
||||
def forEachModuleDecl (moduleName : Name) (f : Decl → CoreM Unit) (phase := Phase.base) : CoreM Unit := do
|
||||
let ext := getExt phase
|
||||
let env ← getEnv
|
||||
let some modIdx := env.getModuleIdx? moduleName | throwError "module `{moduleName}` not found"
|
||||
for decl in ext.getModuleEntries env modIdx do
|
||||
f decl
|
||||
|
||||
def forEachMainModuleDecl (f : Decl → CoreM Unit) (phase := Phase.base) : CoreM Unit := do
|
||||
(getExt phase).getState (← getEnv) |>.forM fun _ decl => f decl
|
||||
| .impure => impureExt
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
||||
@@ -43,11 +43,11 @@ def ppFVar (fvarId : FVarId) : M Format :=
|
||||
def ppExpr (e : Expr) : M Format := do
|
||||
Meta.ppExpr e |>.run' { lctx := (← read) }
|
||||
|
||||
def ppArg (e : Arg) : M Format := do
|
||||
def ppArg (e : Arg pu) : M Format := do
|
||||
match e with
|
||||
| .erased => return "◾"
|
||||
| .fvar fvarId => ppFVar fvarId
|
||||
| .type e =>
|
||||
| .type e _ =>
|
||||
if pp.explicit.get (← getOptions) then
|
||||
if e.isConst || e.isProp || e.isType0 || e.isFVar then
|
||||
ppExpr e
|
||||
@@ -56,7 +56,7 @@ def ppArg (e : Arg) : M Format := do
|
||||
else
|
||||
return "_"
|
||||
|
||||
def ppArgs (args : Array Arg) : M Format := do
|
||||
def ppArgs (args : Array (Arg pu)) : M Format := do
|
||||
prefixJoin " " args ppArg
|
||||
|
||||
def ppLitValue (lit : LitValue) : M Format := do
|
||||
@@ -64,49 +64,49 @@ def ppLitValue (lit : LitValue) : M Format := do
|
||||
| .nat v | .uint8 v | .uint16 v | .uint32 v | .uint64 v | .usize v => return format v
|
||||
| .str v => return format (repr v)
|
||||
|
||||
def ppLetValue (e : LetValue) : M Format := do
|
||||
def ppLetValue (e : LetValue pu) : M Format := do
|
||||
match e with
|
||||
| .erased => return "◾"
|
||||
| .lit v => ppLitValue v
|
||||
| .proj _ i fvarId => return f!"{← ppFVar fvarId} # {i}"
|
||||
| .proj _ i fvarId _ => return f!"{← ppFVar fvarId} # {i}"
|
||||
| .fvar fvarId args => return f!"{← ppFVar fvarId}{← ppArgs args}"
|
||||
| .const declName us args => return f!"{← ppExpr (.const declName us)}{← ppArgs args}"
|
||||
| .const declName us args _ => return f!"{← ppExpr (.const declName us)}{← ppArgs args}"
|
||||
|
||||
def ppParam (param : Param) : M Format := do
|
||||
def ppParam (param : Param pu) : M Format := do
|
||||
let borrow := if param.borrow then "@&" else ""
|
||||
if pp.funBinderTypes.get (← getOptions) then
|
||||
return Format.paren f!"{param.binderName} : {borrow}{← ppExpr param.type}"
|
||||
else
|
||||
return format s!"{borrow}{param.binderName}"
|
||||
|
||||
def ppParams (params : Array Param) : M Format := do
|
||||
def ppParams (params : Array (Param pu)) : M Format := do
|
||||
prefixJoin " " params ppParam
|
||||
|
||||
def ppLetDecl (letDecl : LetDecl) : M Format := do
|
||||
def ppLetDecl (letDecl : LetDecl pu) : M Format := do
|
||||
if pp.letVarTypes.get (← getOptions) then
|
||||
return f!"let {letDecl.binderName} : {← ppExpr letDecl.type} := {← ppLetValue letDecl.value}"
|
||||
else
|
||||
return f!"let {letDecl.binderName} := {← ppLetValue letDecl.value}"
|
||||
|
||||
def getFunType (ps : Array Param) (type : Expr) : CoreM Expr :=
|
||||
def getFunType (ps : Array (Param pu)) (type : Expr) : CoreM Expr :=
|
||||
if type.isErased then
|
||||
pure type
|
||||
else
|
||||
instantiateForall type (ps.map (mkFVar ·.fvarId))
|
||||
|
||||
mutual
|
||||
partial def ppFunDecl (funDecl : FunDecl) : M Format := do
|
||||
partial def ppFunDecl (funDecl : FunDecl pu) : M Format := do
|
||||
return f!"{funDecl.binderName}{← ppParams funDecl.params} : {← ppExpr (← getFunType funDecl.params funDecl.type)} :={indentD (← ppCode funDecl.value)}"
|
||||
|
||||
partial def ppAlt (alt : Alt) : M Format := do
|
||||
partial def ppAlt (alt : Alt pu) : M Format := do
|
||||
match alt with
|
||||
| .default k => return f!"| _ =>{indentD (← ppCode k)}"
|
||||
| .alt ctorName params k => return f!"| {ctorName}{← ppParams params} =>{indentD (← ppCode k)}"
|
||||
| .alt ctorName params k _ => return f!"| {ctorName}{← ppParams params} =>{indentD (← ppCode k)}"
|
||||
|
||||
partial def ppCode (c : Code) : M Format := do
|
||||
partial def ppCode (c : Code pu) : M Format := do
|
||||
match c with
|
||||
| .let decl k => return (← ppLetDecl decl) ++ ";" ++ .line ++ (← ppCode k)
|
||||
| .fun decl k => return f!"fun " ++ (← ppFunDecl decl) ++ ";" ++ .line ++ (← ppCode k)
|
||||
| .fun decl k _ => return f!"fun " ++ (← ppFunDecl decl) ++ ";" ++ .line ++ (← ppCode k)
|
||||
| .jp decl k => return f!"jp " ++ (← ppFunDecl decl) ++ ";" ++ .line ++ (← ppCode k)
|
||||
| .cases c => return f!"cases {← ppFVar c.discr} : {← ppExpr c.resultType}{← prefixJoin .line c.alts ppAlt}"
|
||||
| .return fvarId => return f!"return {← ppFVar fvarId}"
|
||||
@@ -117,7 +117,7 @@ mutual
|
||||
else
|
||||
return "⊥"
|
||||
|
||||
partial def ppDeclValue (b : DeclValue) : M Format := do
|
||||
partial def ppDeclValue (b : DeclValue pu) : M Format := do
|
||||
match b with
|
||||
| .code c => ppCode c
|
||||
| .extern .. => return "extern"
|
||||
@@ -125,21 +125,21 @@ end
|
||||
|
||||
def run (x : M α) : CompilerM α :=
|
||||
withOptions (pp.sanitizeNames.set · false) do
|
||||
x |>.run (← get).lctx.toLocalContext
|
||||
x |>.run ((← get).lctx.toLocalContext (← getPurity))
|
||||
|
||||
end PP
|
||||
|
||||
def ppCode (code : Code) : CompilerM Format :=
|
||||
def ppCode (code : Code pu) : CompilerM Format :=
|
||||
PP.run <| PP.ppCode code
|
||||
|
||||
def ppLetValue (e : LetValue) : CompilerM Format :=
|
||||
def ppLetValue (e : LetValue pu) : CompilerM Format :=
|
||||
PP.run <| PP.ppLetValue e
|
||||
|
||||
def ppDecl (decl : Decl) : CompilerM Format :=
|
||||
def ppDecl (decl : Decl pu) : CompilerM Format :=
|
||||
PP.run do
|
||||
return f!"def {decl.name}{← PP.ppParams decl.params} : {← PP.ppExpr (← PP.getFunType decl.params decl.type)} :={indentD (← PP.ppDeclValue decl.value)}"
|
||||
|
||||
def ppFunDecl (decl : FunDecl) : CompilerM Format :=
|
||||
def ppFunDecl (decl : FunDecl pu) : CompilerM Format :=
|
||||
PP.run do
|
||||
return f!"fun {← PP.ppFunDecl decl}"
|
||||
|
||||
@@ -159,7 +159,7 @@ Similar to `ppDecl`, but in `CoreM`, and it does not assume
|
||||
`decl` has already been internalized.
|
||||
This function is used for debugging purposes.
|
||||
-/
|
||||
def ppDecl' (decl : Decl) : CoreM Format := do
|
||||
def ppDecl' (decl : Decl pu) : CoreM Format := do
|
||||
runCompilerWithoutModifyingState do
|
||||
ppDecl (← decl.internalize)
|
||||
|
||||
@@ -167,7 +167,7 @@ def ppDecl' (decl : Decl) : CoreM Format := do
|
||||
Similar to `ppCode`, but in `CoreM`, and it does not assume
|
||||
`code` has already been internalized.
|
||||
-/
|
||||
def ppCode' (code : Code) : CoreM Format := do
|
||||
def ppCode' (code : Code pu) : CoreM Format := do
|
||||
runCompilerWithoutModifyingState do
|
||||
ppCode (← code.internalize)
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ def filter (f : α → CompilerM Bool) : Probe α α := fun data => data.filterM
|
||||
def sorted [Inhabited α] [LT α] [DecidableLT α] : Probe α α := fun data => return data.qsort (· < ·)
|
||||
|
||||
@[inline]
|
||||
def sortedBySize : Probe Decl (Nat × Decl) := fun decls =>
|
||||
def sortedBySize (pu : Purity) : Probe (Decl pu) (Nat × Decl pu) := fun decls =>
|
||||
let decls := decls.map fun decl => (decl.size, decl)
|
||||
return decls.qsort fun (sz₁, decl₁) (sz₂, decl₂) =>
|
||||
if sz₁ == sz₂ then Name.lt decl₁.name decl₂.name else sz₁ < sz₂
|
||||
@@ -44,116 +44,118 @@ def countUnique [ToString α] [BEq α] [Hashable α] : Probe α (α × Nat) := f
|
||||
def countUniqueSorted [ToString α] [BEq α] [Hashable α] [Inhabited α] : Probe α (α × Nat) :=
|
||||
countUnique >=> fun data => return data.qsort (fun l r => l.snd < r.snd)
|
||||
|
||||
partial def getLetValues : Probe Decl LetValue := fun decls => do
|
||||
partial def getLetValues (pu : Purity) : Probe (Decl pu) (LetValue pu) := fun decls => do
|
||||
let (_, res) ← start decls |>.run #[]
|
||||
return res
|
||||
where
|
||||
go (c : Code) : StateRefT (Array LetValue) CompilerM Unit := do
|
||||
go (c : Code pu) : StateRefT (Array (LetValue pu)) CompilerM Unit := do
|
||||
match c with
|
||||
| .let (decl : LetDecl) (k : Code) =>
|
||||
| .let decl k =>
|
||||
modify fun s => s.push decl.value
|
||||
go k
|
||||
| .fun decl k | .jp decl k =>
|
||||
| .fun decl k _ | .jp decl k =>
|
||||
go decl.value
|
||||
go k
|
||||
| .cases cs => cs.alts.forM (go ·.getCode)
|
||||
| .jmp .. | .return .. | .unreach .. => return ()
|
||||
start (decls : Array Decl) : StateRefT (Array LetValue) CompilerM Unit :=
|
||||
start (decls : Array (Decl pu)) : StateRefT (Array (LetValue pu)) CompilerM Unit :=
|
||||
decls.forM (·.value.forCodeM go)
|
||||
|
||||
partial def getJps : Probe Decl FunDecl := fun decls => do
|
||||
partial def getJps (pu : Purity) : Probe (Decl pu) (FunDecl pu) := fun decls => do
|
||||
let (_, res) ← start decls |>.run #[]
|
||||
return res
|
||||
where
|
||||
go (code : Code) : StateRefT (Array FunDecl) CompilerM Unit := do
|
||||
go (code : Code pu) : StateRefT (Array (FunDecl pu)) CompilerM Unit := do
|
||||
match code with
|
||||
| .let _ k => go k
|
||||
| .fun decl k => go decl.value; go k
|
||||
| .fun decl k _ => go decl.value; go k
|
||||
| .jp decl k => modify (·.push decl); go decl.value; go k
|
||||
| .cases cs => cs.alts.forM (go ·.getCode)
|
||||
| .jmp .. | .return .. | .unreach .. => return ()
|
||||
|
||||
start (decls : Array Decl) : StateRefT (Array FunDecl) CompilerM Unit :=
|
||||
start (decls : Array (Decl pu)) : StateRefT (Array (FunDecl pu)) CompilerM Unit :=
|
||||
decls.forM (·.value.forCodeM go)
|
||||
|
||||
partial def filterByLet (f : LetDecl → CompilerM Bool) : Probe Decl Decl :=
|
||||
partial def filterByLet (pu : Purity) (f : LetDecl pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
|
||||
filter (·.value.isCodeAndM go)
|
||||
where
|
||||
go : Code → CompilerM Bool
|
||||
go : Code pu → CompilerM Bool
|
||||
| .let decl k => do if (← f decl) then return true else go k
|
||||
| .fun decl k | .jp decl k => go decl.value <||> go k
|
||||
| .fun decl k _ | .jp decl k => go decl.value <||> go k
|
||||
| .cases cs => cs.alts.anyM (go ·.getCode)
|
||||
| .jmp .. | .return .. | .unreach .. => return false
|
||||
|
||||
partial def filterByFun (f : FunDecl → CompilerM Bool) : Probe Decl Decl :=
|
||||
partial def filterByFun (pu : Purity) (f : FunDecl pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
|
||||
filter (·.value.isCodeAndM go)
|
||||
where
|
||||
go : Code → CompilerM Bool
|
||||
go : Code pu → CompilerM Bool
|
||||
| .let _ k | .jp _ k => go k
|
||||
| .fun decl k => do if (← f decl) then return true else go decl.value <||> go k
|
||||
| .fun decl k _ => do if (← f decl) then return true else go decl.value <||> go k
|
||||
| .cases cs => cs.alts.anyM (go ·.getCode)
|
||||
| .jmp .. | .return .. | .unreach .. => return false
|
||||
|
||||
partial def filterByJp (f : FunDecl → CompilerM Bool) : Probe Decl Decl :=
|
||||
partial def filterByJp (pu : Purity) (f : FunDecl pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
|
||||
filter (·.value.isCodeAndM go)
|
||||
where
|
||||
go : Code → CompilerM Bool
|
||||
go : Code pu → CompilerM Bool
|
||||
| .let _ k => go k
|
||||
| .fun decl k => go decl.value <||> go k
|
||||
| .fun decl k _ => go decl.value <||> go k
|
||||
| .jp decl k => do if (← f decl) then return true else go decl.value <||> go k
|
||||
| .cases cs => cs.alts.anyM (go ·.getCode)
|
||||
| .jmp .. | .return .. | .unreach .. => return false
|
||||
|
||||
partial def filterByFunDecl (f : FunDecl → CompilerM Bool) : Probe Decl Decl :=
|
||||
partial def filterByFunDecl (pu : Purity) (f : FunDecl pu → CompilerM Bool) :
|
||||
Probe (Decl pu) (Decl pu):=
|
||||
filter (·.value.isCodeAndM go)
|
||||
where
|
||||
go : Code → CompilerM Bool
|
||||
go : Code pu → CompilerM Bool
|
||||
| .let _ k => go k
|
||||
| .fun decl k | .jp decl k => do if (← f decl) then return true else go decl.value <||> go k
|
||||
| .fun decl k _ | .jp decl k => do if (← f decl) then return true else go decl.value <||> go k
|
||||
| .cases cs => cs.alts.anyM (go ·.getCode)
|
||||
| .jmp .. | .return .. | .unreach .. => return false
|
||||
|
||||
partial def filterByCases (f : Cases → CompilerM Bool) : Probe Decl Decl :=
|
||||
partial def filterByCases (pu : Purity) (f : Cases pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
|
||||
filter (·.value.isCodeAndM go)
|
||||
where
|
||||
go : Code → CompilerM Bool
|
||||
go : Code pu → CompilerM Bool
|
||||
| .let _ k => go k
|
||||
| .fun decl k | .jp decl k => go decl.value <||> go k
|
||||
| .fun decl k _ | .jp decl k => go decl.value <||> go k
|
||||
| .cases cs => do if (← f cs) then return true else cs.alts.anyM (go ·.getCode)
|
||||
| .jmp .. | .return .. | .unreach .. => return false
|
||||
|
||||
partial def filterByJmp (f : FVarId → Array Arg → CompilerM Bool) : Probe Decl Decl :=
|
||||
partial def filterByJmp (pu : Purity) (f : FVarId → Array (Arg pu) → CompilerM Bool) :
|
||||
Probe (Decl pu) (Decl pu) :=
|
||||
filter (·.value.isCodeAndM go)
|
||||
where
|
||||
go : Code → CompilerM Bool
|
||||
go : Code pu → CompilerM Bool
|
||||
| .let _ k => go k
|
||||
| .fun decl k | .jp decl k => go decl.value <||> go k
|
||||
| .fun decl k _ | .jp decl k => go decl.value <||> go k
|
||||
| .cases cs => cs.alts.anyM (go ·.getCode)
|
||||
| .jmp fn var => f fn var
|
||||
| .return .. | .unreach .. => return false
|
||||
|
||||
partial def filterByReturn (f : FVarId → CompilerM Bool) : Probe Decl Decl :=
|
||||
partial def filterByReturn (pu : Purity) (f : FVarId → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
|
||||
filter (·.value.isCodeAndM go)
|
||||
where
|
||||
go : Code → CompilerM Bool
|
||||
go : Code pu → CompilerM Bool
|
||||
| .let _ k => go k
|
||||
| .fun decl k | .jp decl k => go decl.value <||> go k
|
||||
| .fun decl k _ | .jp decl k => go decl.value <||> go k
|
||||
| .cases cs => cs.alts.anyM (go ·.getCode)
|
||||
| .jmp .. | .unreach .. => return false
|
||||
| .return var => f var
|
||||
|
||||
partial def filterByUnreach (f : Expr → CompilerM Bool) : Probe Decl Decl :=
|
||||
partial def filterByUnreach (pu : Purity) (f : Expr → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
|
||||
filter (·.value.isCodeAndM go)
|
||||
where
|
||||
go : Code → CompilerM Bool
|
||||
go : Code pu → CompilerM Bool
|
||||
| .let _ k => go k
|
||||
| .fun decl k | .jp decl k => go decl.value <||> go k
|
||||
| .fun decl k _ | .jp decl k => go decl.value <||> go k
|
||||
| .cases cs => cs.alts.anyM (go ·.getCode)
|
||||
| .jmp .. | .return .. => return false
|
||||
| .unreach typ => f typ
|
||||
|
||||
@[inline]
|
||||
def declNames : Probe Decl Name :=
|
||||
def declNames (pu : Purity) : Probe (Decl pu) Name :=
|
||||
Probe.map (fun decl => return decl.name)
|
||||
|
||||
@[inline]
|
||||
@@ -172,7 +174,8 @@ def tail (n : Nat) : Probe α α := fun data => return data[(data.size - n)...*]
|
||||
@[inline]
|
||||
def head (n : Nat) : Probe α α := fun data => return data[*...n]
|
||||
|
||||
def runOnDeclsNamed (declNames : Array Name) (probe : Probe Decl β) (phase : Phase := Phase.base): CoreM (Array β) := do
|
||||
def runOnDeclsNamed (declNames : Array Name) (phase : Phase := Phase.base)
|
||||
(probe : Probe (Decl phase.toPurity) β) : CoreM (Array β) := do
|
||||
let ext := getExt phase
|
||||
let env ← getEnv
|
||||
let decls ← declNames.mapM fun name => do
|
||||
@@ -180,14 +183,15 @@ def runOnDeclsNamed (declNames : Array Name) (probe : Probe Decl β) (phase : Ph
|
||||
return decl
|
||||
probe decls |>.run (phase := phase)
|
||||
|
||||
def runOnModule (moduleName : Name) (probe : Probe Decl β) (phase : Phase := Phase.base): CoreM (Array β) := do
|
||||
def runOnModule (moduleName : Name) (phase : Phase := Phase.base)
|
||||
(probe : Probe (Decl phase.toPurity) β) : CoreM (Array β) := do
|
||||
let ext := getExt phase
|
||||
let env ← getEnv
|
||||
let some modIdx := env.getModuleIdx? moduleName | throwError "module `{moduleName}` not found"
|
||||
let decls := ext.getModuleEntries env modIdx
|
||||
probe decls |>.run (phase := phase)
|
||||
|
||||
def runGlobally (probe : Probe Decl β) (phase : Phase := Phase.base) : CoreM (Array β) := do
|
||||
def runGlobally (phase : Phase := Phase.base) (probe : Probe (Decl phase.toPurity) β) : CoreM (Array β) := do
|
||||
let ext := getExt phase
|
||||
let env ← getEnv
|
||||
let mut decls := #[]
|
||||
@@ -195,7 +199,7 @@ def runGlobally (probe : Probe Decl β) (phase : Phase := Phase.base) : CoreM (A
|
||||
decls := decls.append <| ext.getModuleEntries env modIdx
|
||||
probe decls |>.run (phase := phase)
|
||||
|
||||
def toPass [ToString β] (probe : Probe Decl β) (phase : Phase) : Pass where
|
||||
def toPass [ToString β] (phase : Phase) (probe : Probe (Decl phase.toPurity) β) : Pass where
|
||||
phase := phase
|
||||
name := `probe
|
||||
run := fun decls => do
|
||||
|
||||
@@ -19,7 +19,7 @@ Local function declaration and join point being pulled.
|
||||
-/
|
||||
structure ToPull where
|
||||
isFun : Bool
|
||||
decl : FunDecl
|
||||
decl : FunDecl .pure
|
||||
used : FVarIdHashSet
|
||||
deriving Inhabited
|
||||
|
||||
@@ -50,7 +50,8 @@ where
|
||||
else
|
||||
go as (a :: keep) dep
|
||||
|
||||
partial def findFVarDepsFixpoint (todo : List ToPull) (acc : Array ToPull := #[]) : PullM (Array ToPull) := do
|
||||
partial def findFVarDepsFixpoint (todo : List ToPull) (acc : Array ToPull := #[]) :
|
||||
PullM (Array ToPull) := do
|
||||
match todo with
|
||||
| [] => return acc
|
||||
| p :: ps =>
|
||||
@@ -65,7 +66,7 @@ partial def findFVarDeps (fvarId : FVarId) : PullM (Array ToPull) := do
|
||||
Similar to `findFVarDeps`. Extract from the state any local function declarations that depends on the given
|
||||
parameters.
|
||||
-/
|
||||
def findParamsDeps (params : Array Param) : PullM (Array ToPull) := do
|
||||
def findParamsDeps (params : Array (Param pu)) : PullM (Array ToPull) := do
|
||||
let mut acc := #[]
|
||||
for param in params do
|
||||
acc := acc ++ (← findFVarDeps param.fvarId)
|
||||
@@ -74,7 +75,7 @@ def findParamsDeps (params : Array Param) : PullM (Array ToPull) := do
|
||||
/--
|
||||
Construct the code `fun p.decl k` or `jp p.decl k`.
|
||||
-/
|
||||
def ToPull.attach (p : ToPull) (k : Code) : Code :=
|
||||
def ToPull.attach (p : ToPull) (k : Code .pure) : Code .pure :=
|
||||
if p.isFun then
|
||||
.fun p.decl k
|
||||
else
|
||||
@@ -83,19 +84,19 @@ def ToPull.attach (p : ToPull) (k : Code) : Code :=
|
||||
/--
|
||||
Attach the given array of local function declarations and join points to `k`.
|
||||
-/
|
||||
partial def attach (ps : Array ToPull) (k : Code) : Code := Id.run do
|
||||
partial def attach (ps : Array ToPull) (k : Code .pure) : Code .pure := Id.run do
|
||||
let visited := ps.map fun _ => false
|
||||
let (_, (k, _)) := go |>.run (k, visited)
|
||||
return k
|
||||
where
|
||||
go : StateM (Code × Array Bool) Unit := do
|
||||
go : StateM (Code .pure × Array Bool) Unit := do
|
||||
for i in *...ps.size do
|
||||
visit i
|
||||
|
||||
visited (i : Nat) : StateM (Code × Array Bool) Bool :=
|
||||
visited (i : Nat) : StateM (Code .pure × Array Bool) Bool :=
|
||||
return (← get).2[i]!
|
||||
|
||||
visit (i : Nat) : StateM (Code × Array Bool) Unit := do
|
||||
visit (i : Nat) : StateM (Code .pure × Array Bool) Unit := do
|
||||
unless (← visited i) do
|
||||
modify fun (k, visited) => (k, visited.set! i true)
|
||||
let pi := ps[i]!
|
||||
@@ -110,7 +111,7 @@ where
|
||||
Extract from the state any local function declarations that depends on the given
|
||||
free variable, **and** attach to code `k`.
|
||||
-/
|
||||
partial def attachFVarDeps (fvarId : FVarId) (k : Code) : PullM Code := do
|
||||
partial def attachFVarDeps (fvarId : FVarId) (k : Code .pure) : PullM (Code .pure) := do
|
||||
let ps ← findFVarDeps fvarId
|
||||
return attach ps k
|
||||
|
||||
@@ -118,11 +119,11 @@ partial def attachFVarDeps (fvarId : FVarId) (k : Code) : PullM Code := do
|
||||
Similar to `attachFVarDeps`. Extract from the state any local function declarations that depends on the given
|
||||
parameters, **and** attach to code `k`.
|
||||
-/
|
||||
def attachParamsDeps (params : Array Param) (k : Code) : PullM Code := do
|
||||
def attachParamsDeps (params : Array (Param .pure)) (k : Code .pure) : PullM (Code .pure) := do
|
||||
let ps ← findParamsDeps params
|
||||
return attach ps k
|
||||
|
||||
def attachJps (k : Code) : PullM Code := do
|
||||
def attachJps (k : Code .pure) : PullM (Code .pure) := do
|
||||
let jps := (← get).filter fun info => !info.isFun
|
||||
modify fun s => s.filter fun info => info.isFun
|
||||
let jps ← findFVarDepsFixpoint jps
|
||||
@@ -132,7 +133,7 @@ mutual
|
||||
/--
|
||||
Add local function declaration (or join point if `isFun = false`) to the state.
|
||||
-/
|
||||
partial def addToPull (isFun : Bool) (decl : FunDecl) : PullM Unit := do
|
||||
partial def addToPull (isFun : Bool) (decl : FunDecl .pure) : PullM Unit := do
|
||||
let saved ← get
|
||||
modify fun _ => []
|
||||
let mut value ← pull decl.value
|
||||
@@ -147,19 +148,19 @@ partial def addToPull (isFun : Bool) (decl : FunDecl) : PullM Unit := do
|
||||
Pull local function declarations and join points in `code`.
|
||||
The state contains the declarations being pulled.
|
||||
-/
|
||||
partial def pull (code : Code) : PullM Code := do
|
||||
partial def pull (code : Code .pure) : PullM (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
let k ← pull k
|
||||
let k ← attachFVarDeps decl.fvarId k
|
||||
return code.updateLet! decl k
|
||||
| .fun decl k => addToPull true decl; pull k
|
||||
| .fun decl k _ => addToPull true decl; pull k
|
||||
| .jp decl k => addToPull false decl; pull k
|
||||
| .cases c =>
|
||||
let alts ← c.alts.mapMonoM fun alt => do
|
||||
match alt with
|
||||
| .default k => return alt.updateCode (← pull k)
|
||||
| .alt _ ps k =>
|
||||
| .alt _ ps k _ =>
|
||||
let k ← pull k
|
||||
let k ← attachParamsDeps ps k
|
||||
return alt.updateCode k
|
||||
@@ -174,13 +175,13 @@ open PullFunDecls
|
||||
/--
|
||||
Pull local function declarations and join points in the given declaration.
|
||||
-/
|
||||
def Decl.pullFunDecls (decl : Decl) : CompilerM Decl := do
|
||||
def Decl.pullFunDecls (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
let (value, ps) ← decl.value.mapCodeM pull |>.run []
|
||||
let value := value.mapCode (attach ps.toArray)
|
||||
return { decl with value }
|
||||
|
||||
def pullFunDecls : Pass :=
|
||||
.mkPerDeclaration `pullFunDecls Decl.pullFunDecls .base
|
||||
.mkPerDeclaration `pullFunDecls .base Decl.pullFunDecls
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.pullFunDecls (inherited := true)
|
||||
|
||||
@@ -15,28 +15,28 @@ namespace Lean.Compiler.LCNF
|
||||
namespace PullLetDecls
|
||||
|
||||
structure Context where
|
||||
isCandidateFn : LetDecl → FVarIdSet → CompilerM Bool
|
||||
isCandidateFn : LetDecl .pure → FVarIdSet → CompilerM Bool
|
||||
included : FVarIdSet := {}
|
||||
|
||||
structure State where
|
||||
toPull : Array LetDecl := #[]
|
||||
toPull : Array (LetDecl .pure) := #[]
|
||||
|
||||
abbrev PullM := ReaderT Context $ StateRefT State CompilerM
|
||||
|
||||
@[inline] def withFVar (fvarId : FVarId) (x : PullM α) : PullM α :=
|
||||
withReader (fun ctx => { ctx with included := ctx.included.insert fvarId }) x
|
||||
|
||||
@[inline] def withParams (ps : Array Param) (x : PullM α) : PullM α :=
|
||||
@[inline] def withParams (ps : Array (Param .pure)) (x : PullM α) : PullM α :=
|
||||
withReader (fun ctx => { ctx with included := ps.foldl (init := ctx.included) fun s p => s.insert p.fvarId }) x
|
||||
|
||||
@[inline] def withNewScope (x : PullM α) : PullM α :=
|
||||
withReader (fun ctx => { ctx with included := {} }) x
|
||||
|
||||
partial def withCheckpoint (x : PullM Code) : PullM Code := do
|
||||
partial def withCheckpoint (x : PullM (Code .pure)) : PullM (Code .pure) := do
|
||||
let toPullSizeSaved := (← get).toPull.size
|
||||
let c ← withNewScope x
|
||||
let toPull := (← get).toPull
|
||||
let rec go (i : Nat) (included : FVarIdSet) : StateM (Array LetDecl) Code := do
|
||||
let rec go (i : Nat) (included : FVarIdSet) : StateM (Array (LetDecl .pure)) (Code .pure) := do
|
||||
if h : i < toPull.size then
|
||||
let letDecl := toPull[i]
|
||||
if letDecl.dependsOn included then
|
||||
@@ -51,11 +51,11 @@ partial def withCheckpoint (x : PullM Code) : PullM Code := do
|
||||
modify fun s => { s with toPull := s.toPull.shrink toPullSizeSaved ++ keep }
|
||||
return c
|
||||
|
||||
def attachToPull (c : Code) : PullM Code := do
|
||||
def attachToPull (c : Code .pure) : PullM (Code .pure) := do
|
||||
let toPull := (← get).toPull
|
||||
return toPull.foldr (init := c) fun decl c => .let decl c
|
||||
|
||||
def shouldPull (decl : LetDecl) : PullM Bool := do
|
||||
def shouldPull (decl : LetDecl .pure) : PullM Bool := do
|
||||
unless decl.dependsOn (← read).included do
|
||||
if (← (← read).isCandidateFn decl (← read).included) then
|
||||
modify fun s => { s with toPull := s.toPull.push decl }
|
||||
@@ -63,12 +63,12 @@ def shouldPull (decl : LetDecl) : PullM Bool := do
|
||||
return false
|
||||
|
||||
mutual
|
||||
partial def pullAlt (alt : Alt) : PullM Alt :=
|
||||
partial def pullAlt (alt : (Alt .pure)) : PullM (Alt .pure) :=
|
||||
match alt with
|
||||
| .default k => return alt.updateCode (← withNewScope <| pullDecls k)
|
||||
| .alt _ params k => return alt.updateCode (← withNewScope <| withParams params <| pullDecls k)
|
||||
|
||||
partial def pullDecls (code : Code) : PullM Code := do
|
||||
partial def pullDecls (code : Code .pure) : PullM (Code .pure) := do
|
||||
match code with
|
||||
| .cases c =>
|
||||
-- At the present time, we can't correctly enforce the dependencies required for lifting
|
||||
@@ -93,21 +93,21 @@ mutual
|
||||
|
||||
end
|
||||
|
||||
def PullM.run (x : PullM α) (isCandidateFn : LetDecl → FVarIdSet → CompilerM Bool) : CompilerM α :=
|
||||
def PullM.run (x : PullM α) (isCandidateFn : LetDecl .pure → FVarIdSet → CompilerM Bool) : CompilerM α :=
|
||||
x { isCandidateFn } |>.run' {}
|
||||
|
||||
end PullLetDecls
|
||||
|
||||
open PullLetDecls
|
||||
|
||||
def Decl.pullLetDecls (decl : Decl) (isCandidateFn : LetDecl → FVarIdSet → CompilerM Bool) : CompilerM Decl := do
|
||||
def Decl.pullLetDecls (decl : Decl .pure) (isCandidateFn : LetDecl .pure → FVarIdSet → CompilerM Bool) : CompilerM (Decl .pure) := do
|
||||
PullM.run (isCandidateFn := isCandidateFn) do
|
||||
withParams decl.params do
|
||||
let value ← decl.value.mapCodeM pullDecls
|
||||
let value ← value.mapCodeM attachToPull
|
||||
return { decl with value }
|
||||
|
||||
def Decl.pullInstances (decl : Decl) : CompilerM Decl :=
|
||||
def Decl.pullInstances (decl : Decl .pure) : CompilerM (Decl .pure) :=
|
||||
decl.pullLetDecls fun letDecl candidates => do
|
||||
-- TODO: Correctly represent these dependencies so this check isn't required.
|
||||
if let .const _ _ args := letDecl.value then
|
||||
@@ -122,7 +122,7 @@ def Decl.pullInstances (decl : Decl) : CompilerM Decl :=
|
||||
return false
|
||||
|
||||
def pullInstances : Pass :=
|
||||
.mkPerDeclaration `pullInstances Decl.pullInstances .base
|
||||
.mkPerDeclaration `pullInstances .base Decl.pullInstances
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.pullInstances (inherited := true)
|
||||
|
||||
@@ -52,7 +52,7 @@ We assume this limitation is irrelevant in practice.
|
||||
namespace FindUsed
|
||||
|
||||
structure Context where
|
||||
decl : Decl
|
||||
decl : Decl .pure
|
||||
params : FVarIdSet
|
||||
|
||||
structure State where
|
||||
@@ -64,12 +64,12 @@ def visitFVar (fvarId : FVarId) : FindUsedM Unit := do
|
||||
if (← read).params.contains fvarId then
|
||||
modify fun s => { s with used := s.used.insert fvarId }
|
||||
|
||||
def visitArg (arg : Arg) : FindUsedM Unit := do
|
||||
def visitArg (arg : Arg .pure) : FindUsedM Unit := do
|
||||
match arg with
|
||||
| .erased | .type .. => return ()
|
||||
| .fvar fvarId => visitFVar fvarId
|
||||
|
||||
def visitLetValue (e : LetValue) : FindUsedM Unit := do
|
||||
def visitLetValue (e : LetValue .pure) : FindUsedM Unit := do
|
||||
match e with
|
||||
| .erased | .lit .. => return ()
|
||||
| .proj _ _ fvarId => visitFVar fvarId
|
||||
@@ -93,7 +93,7 @@ def visitLetValue (e : LetValue) : FindUsedM Unit := do
|
||||
else
|
||||
args.forM visitArg
|
||||
|
||||
partial def visit (code : Code) : FindUsedM Unit := do
|
||||
partial def visit (code : Code .pure) : FindUsedM Unit := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
visitLetValue decl.value
|
||||
@@ -107,7 +107,7 @@ partial def visit (code : Code) : FindUsedM Unit := do
|
||||
| .return fvarId => visitFVar fvarId
|
||||
| .unreach _ => return ()
|
||||
|
||||
def collectUsedParams (decl : Decl) : CompilerM FVarIdHashSet := do
|
||||
def collectUsedParams (decl : Decl .pure) : CompilerM FVarIdHashSet := do
|
||||
let params := decl.params.foldl (init := {}) fun s p => s.insert p.fvarId
|
||||
let (_, { used, .. }) ← decl.value.forCodeM visit |>.run { decl, params } |>.run {}
|
||||
return used
|
||||
@@ -123,7 +123,7 @@ structure Context where
|
||||
|
||||
abbrev ReduceM := ReaderT Context CompilerM
|
||||
|
||||
partial def reduce (code : Code) : ReduceM Code := do
|
||||
partial def reduce (code : Code .pure) : ReduceM (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
let .const declName _ args := decl.value | do return code.updateLet! decl (← reduce k)
|
||||
@@ -148,7 +148,7 @@ end ReduceArity
|
||||
|
||||
open FindUsed ReduceArity Internalize
|
||||
|
||||
def Decl.reduceArity (decl : Decl) : CompilerM (Array Decl) := do
|
||||
def Decl.reduceArity (decl : Decl .pure) : CompilerM (Array (Decl .pure)) := do
|
||||
match decl.value with
|
||||
| .code code =>
|
||||
let used ← collectUsedParams decl
|
||||
@@ -160,7 +160,7 @@ def Decl.reduceArity (decl : Decl) : CompilerM (Array Decl) := do
|
||||
trace[Compiler.reduceArity] "{decl.name}, used params: {used.toList.map mkFVar}"
|
||||
let mask := decl.params.map fun param => used.contains param.fvarId
|
||||
let auxName := decl.name ++ `_redArg
|
||||
let mkAuxDecl : CompilerM Decl := do
|
||||
let mkAuxDecl : CompilerM (Decl .pure) := do
|
||||
let params := decl.params.filter fun param => used.contains param.fvarId
|
||||
let value ← decl.value.mapCodeM reduce |>.run { declName := decl.name, auxDeclName := auxName, paramMask := mask }
|
||||
let type ← code.inferType
|
||||
@@ -168,7 +168,7 @@ def Decl.reduceArity (decl : Decl) : CompilerM (Array Decl) := do
|
||||
let auxDecl := { decl with name := auxName, levelParams := [], type, params, value }
|
||||
auxDecl.saveMono
|
||||
return auxDecl
|
||||
let updateDecl : InternalizeM Decl := do
|
||||
let updateDecl : InternalizeM .pure (Decl .pure) := do
|
||||
let params ← decl.params.mapM internalizeParam
|
||||
let mut args := #[]
|
||||
for used in mask, param in params do
|
||||
|
||||
@@ -18,7 +18,7 @@ namespace ReduceJpArity
|
||||
|
||||
abbrev ReduceM := ReaderT (FVarIdMap (Array Bool)) CompilerM
|
||||
|
||||
partial def reduce (code : Code) : ReduceM Code := do
|
||||
partial def reduce (code : Code .pure) : ReduceM (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k => return code.updateLet! decl (← reduce k)
|
||||
| .fun decl k =>
|
||||
@@ -69,12 +69,14 @@ open ReduceJpArity
|
||||
/--
|
||||
Try to reduce arity of join points
|
||||
-/
|
||||
def Decl.reduceJpArity (decl : Decl) : CompilerM Decl := do
|
||||
def Decl.reduceJpArity (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
let value ← decl.value.mapCodeM reduce |>.run {}
|
||||
return { decl with value }
|
||||
|
||||
-- TODO: This can be made Purity generic
|
||||
def reduceJpArity (phase := Phase.base) : Pass :=
|
||||
.mkPerDeclaration `reduceJpArity Decl.reduceJpArity phase
|
||||
phase.withPurityCheck .pure fun h =>
|
||||
.mkPerDeclaration `reduceJpArity phase (h ▸ Decl.reduceJpArity)
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.reduceJpArity (inherited := true)
|
||||
|
||||
@@ -16,7 +16,7 @@ A mapping from free variable id to binder name.
|
||||
-/
|
||||
abbrev Renaming := FVarIdMap Name
|
||||
|
||||
def Param.applyRenaming (param : Param) (r : Renaming) : CompilerM Param := do
|
||||
def Param.applyRenaming (param : Param pu) (r : Renaming) : CompilerM (Param pu) := do
|
||||
if let some binderName := r.get? param.fvarId then
|
||||
let param := { param with binderName }
|
||||
modifyLCtx fun lctx => lctx.addParam param
|
||||
@@ -24,7 +24,7 @@ def Param.applyRenaming (param : Param) (r : Renaming) : CompilerM Param := do
|
||||
else
|
||||
return param
|
||||
|
||||
def LetDecl.applyRenaming (decl : LetDecl) (r : Renaming) : CompilerM LetDecl := do
|
||||
def LetDecl.applyRenaming (decl : LetDecl pu) (r : Renaming) : CompilerM (LetDecl pu) := do
|
||||
if let some binderName := r.get? decl.fvarId then
|
||||
let decl := { decl with binderName }
|
||||
modifyLCtx fun lctx => lctx.addLetDecl decl
|
||||
@@ -33,7 +33,7 @@ def LetDecl.applyRenaming (decl : LetDecl) (r : Renaming) : CompilerM LetDecl :=
|
||||
return decl
|
||||
|
||||
mutual
|
||||
partial def FunDecl.applyRenaming (decl : FunDecl) (r : Renaming) : CompilerM FunDecl := do
|
||||
partial def FunDecl.applyRenaming (decl : (FunDecl pu)) (r : Renaming) : CompilerM (FunDecl pu) := do
|
||||
if let some binderName := r.get? decl.fvarId then
|
||||
let decl := decl.updateBinderName binderName
|
||||
modifyLCtx fun lctx => lctx.addFunDecl decl
|
||||
@@ -41,20 +41,20 @@ partial def FunDecl.applyRenaming (decl : FunDecl) (r : Renaming) : CompilerM Fu
|
||||
else
|
||||
decl.updateValue (← decl.value.applyRenaming r)
|
||||
|
||||
partial def Code.applyRenaming (code : Code) (r : Renaming) : CompilerM Code := do
|
||||
partial def Code.applyRenaming (code : Code pu) (r : Renaming) : CompilerM (Code pu) := do
|
||||
match code with
|
||||
| .let decl k => return code.updateLet! (← decl.applyRenaming r) (← k.applyRenaming r)
|
||||
| .fun decl k | .jp decl k => return code.updateFun! (← decl.applyRenaming r) (← k.applyRenaming r)
|
||||
| .fun decl k _ | .jp decl k => return code.updateFun! (← decl.applyRenaming r) (← k.applyRenaming r)
|
||||
| .cases c =>
|
||||
let alts ← c.alts.mapMonoM fun alt =>
|
||||
match alt with
|
||||
| .default k => return alt.updateCode (← k.applyRenaming r)
|
||||
| .alt _ ps k => return alt.updateAlt! (← ps.mapMonoM (·.applyRenaming r)) (← k.applyRenaming r)
|
||||
| .alt _ ps k _ => return alt.updateAlt! (← ps.mapMonoM (·.applyRenaming r)) (← k.applyRenaming r)
|
||||
return code.updateAlts! alts
|
||||
| .jmp .. | .unreach .. | .return .. => return code
|
||||
end
|
||||
|
||||
def Decl.applyRenaming (decl : Decl) (r : Renaming) : CompilerM Decl := do
|
||||
def Decl.applyRenaming (decl : Decl pu) (r : Renaming) : CompilerM (Decl pu) := do
|
||||
if r.isEmpty then
|
||||
return decl
|
||||
else
|
||||
|
||||
@@ -24,7 +24,7 @@ public section
|
||||
namespace Lean.Compiler.LCNF
|
||||
open Simp
|
||||
|
||||
def Decl.simp? (decl : Decl) : SimpM (Option Decl) := do
|
||||
def Decl.simp? (decl : Decl .pure) : SimpM (Option (Decl .pure)) := do
|
||||
let .code code := decl.value | return none
|
||||
updateFunDeclInfo code
|
||||
traceM `Compiler.simp.inline.info do return m!"{decl.name}:{Format.nest 2 (← (← get).funDeclInfoMap.format)}"
|
||||
@@ -42,7 +42,7 @@ def Decl.simp? (decl : Decl) : SimpM (Option Decl) := do
|
||||
else
|
||||
return none
|
||||
|
||||
partial def Decl.simp (decl : Decl) (config : Config) : CompilerM Decl := do
|
||||
partial def Decl.simp (decl : Decl .pure) (config : Config) : CompilerM (Decl .pure) := do
|
||||
let mut config := config
|
||||
if (← isTemplateLike decl) then
|
||||
/-
|
||||
@@ -54,7 +54,7 @@ partial def Decl.simp (decl : Decl) (config : Config) : CompilerM Decl := do
|
||||
config := { config with etaPoly := false, inlinePartial := false }
|
||||
go decl config
|
||||
where
|
||||
go (decl : Decl) (config : Config) : CompilerM Decl := do
|
||||
go (decl : Decl .pure) (config : Config) : CompilerM (Decl .pure) := do
|
||||
if let some decl ← decl.simp? |>.run { config, declName := decl.name } |>.run' {} |>.run {} then
|
||||
-- TODO: bound number of steps?
|
||||
go decl config
|
||||
@@ -62,7 +62,8 @@ where
|
||||
return decl
|
||||
|
||||
def simp (config : Config := {}) (occurrence : Nat := 0) (phase := Phase.base) : Pass :=
|
||||
.mkPerDeclaration `simp (Decl.simp · config) phase (occurrence := occurrence)
|
||||
phase.withPurityCheck .pure fun h =>
|
||||
.mkPerDeclaration `simp phase (h ▸ (Decl.simp · config)) (occurrence := occurrence)
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.simp (inherited := true)
|
||||
|
||||
@@ -22,10 +22,10 @@ let _x.2 := _f.1
|
||||
```
|
||||
`findFunDecl? _x.2` returns `none`, but `findFunDecl'? _x.2` returns the declaration for `_f.1`.
|
||||
-/
|
||||
partial def findFunDecl'? (fvarId : FVarId) : CompilerM (Option FunDecl) := do
|
||||
if let some decl ← findFunDecl? fvarId then
|
||||
partial def findFunDecl'? (fvarId : FVarId) : CompilerM (Option (FunDecl pu)) := do
|
||||
if let some decl ← findFunDecl? (pu := pu) fvarId then
|
||||
return decl
|
||||
else if let some (.fvar fvarId' #[]) ← findLetValue? fvarId then
|
||||
else if let some (.fvar fvarId' #[]) ← findLetValue? (pu := pu) fvarId then
|
||||
findFunDecl'? fvarId'
|
||||
else
|
||||
return none
|
||||
|
||||
@@ -18,14 +18,14 @@ namespace ConstantFold
|
||||
A constant folding monad, the additional state stores auxiliary declarations
|
||||
required to build the new constant.
|
||||
-/
|
||||
abbrev FolderM := StateRefT (Array CodeDecl) CompilerM
|
||||
abbrev FolderM := StateRefT (Array (CodeDecl .pure)) CompilerM
|
||||
|
||||
/--
|
||||
A constant folder for a specific function, takes all the arguments of a
|
||||
certain function and produces a new `Expr` + auxiliary declarations in
|
||||
the `FolderM` monad on success. If the folding fails it returns `none`.
|
||||
-/
|
||||
abbrev Folder := Array Arg → FolderM (Option LetValue)
|
||||
abbrev Folder := Array (Arg .pure) → FolderM (Option (LetValue .pure))
|
||||
|
||||
/--
|
||||
A typeclass for detecting and producing literals of arbitrary types
|
||||
@@ -43,7 +43,7 @@ class Literal (α : Type) where
|
||||
final `Expr` putting them all together into a literal of type `α`,
|
||||
where again the idea of what a literal is depends on `α`.
|
||||
-/
|
||||
mkLit : α → FolderM LetValue
|
||||
mkLit : α → FolderM (LetValue .pure)
|
||||
|
||||
export Literal (getLit mkLit)
|
||||
|
||||
@@ -51,7 +51,7 @@ export Literal (getLit mkLit)
|
||||
A wrapper around `LCNF.mkAuxLetDecl` that will automatically store the
|
||||
`LetDecl` in the state of `FolderM`.
|
||||
-/
|
||||
def mkAuxLetDecl (e : LetValue) (prefixName := `_x) : FolderM FVarId := do
|
||||
def mkAuxLetDecl (e : LetValue .pure) (prefixName := `_x) : FolderM FVarId := do
|
||||
let decl ← LCNF.mkAuxLetDecl e prefixName
|
||||
modify fun s => s.push <| .let decl
|
||||
return decl.fvarId
|
||||
@@ -66,10 +66,10 @@ def mkAuxLit [Literal α] (x : α) (prefixName := `_x) : FolderM FVarId := do
|
||||
mkAuxLetDecl lit prefixName
|
||||
|
||||
partial def getNatLit (fvarId : FVarId) : CompilerM (Option Nat) := do
|
||||
let some (.lit (.nat n)) ← findLetValue? fvarId | return none
|
||||
let some (.lit (.nat n)) ← findLetValue? (pu := .pure) fvarId | return none
|
||||
return n
|
||||
|
||||
def mkNatLit (n : Nat) : FolderM LetValue :=
|
||||
def mkNatLit (n : Nat) : FolderM (LetValue .pure) :=
|
||||
return .lit (.nat n)
|
||||
|
||||
instance : Literal Nat where
|
||||
@@ -77,10 +77,10 @@ instance : Literal Nat where
|
||||
mkLit := mkNatLit
|
||||
|
||||
def getStringLit (fvarId : FVarId) : CompilerM (Option String) := do
|
||||
let some (.lit (.str s)) ← findLetValue? fvarId | return none
|
||||
let some (.lit (.str s)) ← findLetValue? (pu := .pure) fvarId | return none
|
||||
return s
|
||||
|
||||
def mkStringLit (n : String) : FolderM LetValue :=
|
||||
def mkStringLit (n : String) : FolderM (LetValue .pure) :=
|
||||
return .lit (.str n)
|
||||
|
||||
instance : Literal String where
|
||||
@@ -91,7 +91,7 @@ def getBoolLit (fvarId : FVarId) : CompilerM (Option Bool) := do
|
||||
let some (.const ctor [] #[]) ← findLetValue? fvarId | return none
|
||||
return ctor == ``Bool.true
|
||||
|
||||
def mkBoolLit (b : Bool) : FolderM LetValue :=
|
||||
def mkBoolLit (b : Bool) : FolderM (LetValue .pure) :=
|
||||
let ctor := if b then ``Bool.true else ``Bool.false
|
||||
return .const ctor [] #[]
|
||||
|
||||
@@ -115,7 +115,7 @@ instance : Literal Char := mkNatWrapperInstance Char.ofNat ``Char.ofNat Char.toN
|
||||
|
||||
def mkUIntInstance (matchLit : LitValue → Option α) (litValueCtor : α → LitValue) : Literal α where
|
||||
getLit fvarId := do
|
||||
let some (.lit litVal) ← findLetValue? fvarId | return none
|
||||
let some (.lit litVal) ← findLetValue? (pu := .pure) fvarId | return none
|
||||
return matchLit litVal
|
||||
mkLit x :=
|
||||
return .lit <| litValueCtor x
|
||||
@@ -162,7 +162,7 @@ let _x.26 := @Array.push _ _x.24 z
|
||||
_x.26
|
||||
```
|
||||
-/
|
||||
def mkPseudoArrayLiteral (elements : Array FVarId) (typ : Expr) (typLevel : Level) : FolderM LetValue := do
|
||||
def mkPseudoArrayLiteral (elements : Array FVarId) (typ : Expr) (typLevel : Level) : FolderM (LetValue .pure) := do
|
||||
let sizeLit ← mkAuxLit elements.size
|
||||
let mut literal ← mkAuxLetDecl <| .const ``Array.mkEmpty [typLevel] #[.type typ, .fvar sizeLit]
|
||||
for element in elements do
|
||||
@@ -335,7 +335,7 @@ def Folder.mulShift [Literal α] [BEq α] (shiftLeft : Name) (pow2 : α → α)
|
||||
-- TODO: add option for controlling the limit
|
||||
def natPowThreshold := 256
|
||||
|
||||
def foldNatPow (args : Array Arg) : FolderM (Option LetValue) := do
|
||||
def foldNatPow (args : Array (Arg .pure)) : FolderM (Option (LetValue .pure)) := do
|
||||
let #[.fvar fvarId₁, .fvar fvarId₂] := args | return none
|
||||
let some value₁ ← getNatLit fvarId₁ | return none
|
||||
let some value₂ ← getNatLit fvarId₂ | return none
|
||||
@@ -347,14 +347,14 @@ def foldNatPow (args : Array Arg) : FolderM (Option LetValue) := do
|
||||
/--
|
||||
Folder for ofNat operations on fixed-sized integer types.
|
||||
-/
|
||||
def Folder.ofNat (f : Nat → LitValue) (args : Array Arg) : FolderM (Option LetValue) := do
|
||||
def Folder.ofNat (f : Nat → LitValue) (args : Array (Arg .pure)) : FolderM (Option (LetValue .pure)) := do
|
||||
let #[.fvar fvarId] := args | return none
|
||||
let some value ← getNatLit fvarId | return none
|
||||
return some (.lit (f value))
|
||||
|
||||
def Folder.toNat (args : Array Arg) : FolderM (Option LetValue) := do
|
||||
def Folder.toNat (args : Array (Arg .pure)) : FolderM (Option (LetValue .pure)) := do
|
||||
let #[.fvar fvarId] := args | return none
|
||||
let some (.lit lit) ← findLetValue? fvarId | return none
|
||||
let some (.lit lit) ← findLetValue? (pu := .pure) fvarId | return none
|
||||
match lit with
|
||||
| .uint8 v | .uint16 v | .uint32 v | .uint64 v | .usize v => return some (.lit (.nat v.toNat))
|
||||
| .nat _ | .str _ => return none
|
||||
@@ -436,7 +436,7 @@ def stringFolders : List (Name × Folder) := [
|
||||
/--
|
||||
Apply all known folders to `decl`.
|
||||
-/
|
||||
def applyFolders (decl : LetDecl) (folders : SMap Name Folder) : CompilerM (Option (Array CodeDecl)) := do
|
||||
def applyFolders (decl : LetDecl .pure) (folders : SMap Name Folder) : CompilerM (Option (Array (CodeDecl .pure))) := do
|
||||
match decl.value with
|
||||
| .const name _ args =>
|
||||
if let some folder := folders.find? name then
|
||||
@@ -495,7 +495,7 @@ def getFolders : CoreM (SMap Name Folder) :=
|
||||
/--
|
||||
Apply a list of default folders to `decl`
|
||||
-/
|
||||
def foldConstants (decl : LetDecl) : CompilerM (Option (Array CodeDecl)) := do
|
||||
def foldConstants (decl : LetDecl .pure) : CompilerM (Option (Array (CodeDecl .pure))) := do
|
||||
applyFolders decl (← getFolders)
|
||||
|
||||
end ConstantFold
|
||||
|
||||
@@ -19,7 +19,7 @@ and the number of occurrences.
|
||||
We use this function to decide whether to create a `.default` case
|
||||
or not.
|
||||
-/
|
||||
private def getMaxOccs (alts : Array Alt) : Alt × Nat := Id.run do
|
||||
private def getMaxOccs (alts : Array (Alt .pure)) : Alt .pure × Nat := Id.run do
|
||||
let mut maxAlt := alts[0]!
|
||||
let mut max := getNumOccsOf alts 0
|
||||
for h : i in 1...alts.size do
|
||||
@@ -35,7 +35,7 @@ where
|
||||
Note that the number of occurrences can be greater than 1 only when
|
||||
the alternative does not depend on field parameters
|
||||
-/
|
||||
getNumOccsOf (alts : Array Alt) (i : Nat) : Nat := Id.run do
|
||||
getNumOccsOf (alts : Array (Alt .pure)) (i : Nat) : Nat := Id.run do
|
||||
let code := alts[i]!.getCode
|
||||
let mut n := 1
|
||||
for h : j in (i+1)...alts.size do
|
||||
@@ -47,7 +47,7 @@ where
|
||||
Add a default case to the given `cases` alternatives if there
|
||||
are alternatives with equivalent (aka alpha equivalent) right hand sides.
|
||||
-/
|
||||
def addDefaultAlt (alts : Array Alt) : SimpM (Array Alt) := do
|
||||
def addDefaultAlt (alts : Array (Alt .pure)) : SimpM (Array (Alt .pure)) := do
|
||||
if alts.size <= 1 || alts.any (· matches .default ..) then
|
||||
return alts
|
||||
else
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user