Compare commits

...

30 Commits

Author SHA1 Message Date
Leonardo de Moura
c32d827a47 chore: fix tests 2025-02-15 09:02:27 -08:00
Leonardo de Moura
79da5a11fc refactor: cleanup names and add doc strings 2025-02-15 08:57:20 -08:00
Leonardo de Moura
f4afcfc923 feat: divisibility constraint normalizer (#7092)
This PR implements divisibility constraint normalization in `simp
+arith`.
2025-02-15 04:20:40 +00:00
jrr6
9cce0ce8d9 fix: ensure get_elem_tactic works in absence of goals (#7088)
This PR fixes the behavior of the indexed-access notation `xs[i]` in
cases where the proof of `i`'s validity is filled in during unification.

Closes #6999.
2025-02-15 03:00:36 +00:00
Leonardo de Moura
57aadf8af9 feat: add helper theorems for normalizing divisibility constraints (#7091)
This PR adds helper theorems for normalizing divisibility constraints.
They are going to be used to implement the cutsat procedure in the
`grind` tactic.
2025-02-15 02:44:49 +00:00
Kyle Miller
1babe9fc67 feat: make binders in #check be hoverable (#7074)
This PR modifies the signature pretty printer to add hover information
for parameters in binders. This makes the binders be consistent with the
hovers in pi types.

Suggested by @david-christiansen
2025-02-14 17:28:54 +00:00
Markus Himmel
dd1a4188a0 feat: Fin.toNat (#7079)
This PR introduces `Fin.toNat` as an alias for `Fin.val`. We add this
function for discoverability and consistency reasons. The normal form
for proofs remains `Fin.val`, and there is a `simp` lemma rewriting
`Fin.toNat` to `Fin.val`.
2025-02-14 11:59:44 +00:00
Markus Himmel
ed42d068d4 feat: UIntX.ofNatTruncate (#7080)
This PR adds the functions `UIntX.ofNatTruncate` (the version for
`UInt32` already exists).
2025-02-14 11:59:41 +00:00
Markus Himmel
784444c7a9 feat: IntX.minValue, IntX.maxValue, IntX.ofIntLE, IntX.ofIntTruncate (#7081)
This PR adds functions `IntX.ofIntLE`, `IntX.ofIntTruncate`, which are
analogous to the unsigned counterparts `UIntX.ofNatLT` and
`UInt.ofNatTruncate`.
2025-02-14 11:59:37 +00:00
Marc Huisinga
05fb67af90 feat: request cancellation (#7054)
This PR adds language server support for request cancellation to the
following expensive requests: Code actions, auto-completion, document
symbols, folding ranges and semantic highlighting. This means that when
the client informs the language server that a request is stale (e.g.
because it belongs to a previous state of the document), the language
server will now prematurely cancel the computation of the response in
order to reduce the CPU load for requests that will be discarded by the
client anyways.
2025-02-14 11:55:43 +00:00
Marc Huisinga
22d1d04059 fix: incremental goal state requests select incomplete snapshot (#6887)
This PR fixes a bug where the goal state selection would sometimes
select incomplete incremental snapshots on whitespace, leading to an
incorrect "no goals" response. Fixes #6594, a regression that was
originally introduced in 4.11.0 by #4727.

The fundamental cause of #6594 was that the snapshot selection would
always select the first snapshot with a range that contains the cursor
position. For tactics, whitespace had to be included in this range.
However, in the test case of #6594, this meant that the snapshot
selection would also sometimes pick a snapshot before the cursor that
still contains the cursor in its whitespace, but which also does not
necessarily contain all the information needed to produce a correct goal
state. Specifically, at the `InfoTree`-level, when the cursor is in
whitespace, we distinguish competing goal states by their level of
indentation. The snapshot selection did not have access to this
information, so it necessarily had to do the wrong thing in some cases.

This PR fixes the issue by adjusting the snapshot selection for goals to
explicitly account for whitespace and indentation, and refactoring the
language processor architecture to thread enough information through to
the snapshot selection so that it can decide which snapshots to use
without having to force too many tasks, which would destroy
incrementality in goal state requests.

Specifically, this PR makes the following adjustments:
- Refactor `SnapshotTask` to contain both a `Syntax` and a `Range`.
Before, `SnapshotTask`s had a single range that was used both for
displaying file progress information and for selecting snapshots in
server requests. For most snapshots, this range did not include
whitespace, though for tactics it did. Now, the `reportingRange` field
of `SnapshotTask` is intended exclusively for reporting file progress
information, and the `Syntax` is used for selecting snapshots in server
requests. Importantly, the `Syntax` contains the full range information
of the snapshot, i.e. its regular range and its range including
whitespace.
- Adjust all call-sites of `SnapshotTask` to produce a reasonable
`Syntax`.
- Adjust the goal snapshot selection to account for whitespace and
indentation, as the `InfoTree` goal selection does.
- Fix a bug in the snapshot tree tracing that would cause it to render
the `Info` of a snapshot at the wrong location when `trace.Elab.info`
was also set.

This PR is based on #6329.
2025-02-14 11:53:24 +00:00
Paul Reichert
36ac6eb912 feat: insertMany, ofList, ofArray, foldr, foldM functions for the tree map (#7051)
This PR implements the methods `insertMany`, `ofList`, `ofArray`,
`foldr` and `foldrM` on the tree map.

---------

Co-authored-by: Paul Reichert <6992158+datokrat@users.noreply.github.com>
2025-02-14 08:24:33 +00:00
Markus Himmel
47548aa171 chore: rename UIntX.ofNatCore, UIntX.ofNat' -> UIntX.ofNatLT (#7071)
This PR unifies the existing functions `UIntX.ofNatCore` and
`UIntX.ofNat'` under a new name, `UIntX.ofNatLT`.
2025-02-14 06:58:15 +00:00
Leonardo de Moura
b26b781992 feat: simprocs for Int and Nat divides predicates (#7078)
This PR implements simprocs for `Int` and `Nat` divides predicates.
2025-02-14 05:43:38 +00:00
Mac Malone
c9c3366521 feat: lake: support plugins (#7001)
This PR adds support for plugins to Lake. Precompiled modules are now
loaded as plugins rather than via `--load-dynlib`.

Additional plugins can be added through an experimental `plugins`
configuration option. The syntax for specifying this is not yet
convenient, and will be improved in future changes. A parallel `dynlibs`
configuration option has been added for specifying additional dynamic
libraries to build and pass to `--load-dynlib`.

This PR also changes the default directory for `.olean`, `.ilean`, and
module dynamic libraries (i.e., `leanLibDir`) to `lib/lean` instead of
the previous default of `lib`. This avoids potential name clashes
between single module shared libraries and the shared libraries of a
full `lean_lib`.

On non-Windows systems, module dynamic libraries are no longer linked to
their imports or external symbols. Symbols from those libraries are left
unresolved until load time. This avoids nesting these dependencies
within the shared library and means Lake no longer needs to augment the
shared library path to allow Lean to resolve such nested dependencies on
load.
2025-02-14 04:57:31 +00:00
Leonardo de Moura
2c2a3a65b2 feat: support theorems for cutsat Div-Solve rule (#7077)
This PR proves the helper theorems for justifying the "Div-Solve" rule
in the cutsat procedure.
2025-02-14 04:55:58 +00:00
Kim Morrison
8cefb2cf65 feat: premise selection API (#7061)
This PR provides a basic API for a premise selection tool, which can be
provided in downstream libraries. It does not implement premise
selection itself!
2025-02-14 04:08:18 +00:00
Lean stage0 autoupdater
80c8837f49 chore: update stage0 2025-02-13 16:00:29 +00:00
Markus Himmel
40c6dfa3ae chore: dsimproc for UIntX.ofNatLT (#7068)
This PR is a follow-up to #7057 and adds a builtin dsimproc for
`UIntX.ofNatLT` which it turns out we need in stage0 before we can get
the deprecation of `UIntX.ofNatCore` in favor of `UIntX.ofNatLT` off the
ground.
2025-02-13 14:51:42 +00:00
Bulhwi Cha
cc76c46244 doc: fix typo (#7067) 2025-02-13 13:21:18 +00:00
Markus Himmel
b38da34db2 chore: rename BitVec.ofNatLt -> BitVec.ofNatLT (#7064)
This PR renames `BitVec.ofNatLt` to `BitVec.ofNatLT` and sets up
deprecations for the old name.
2025-02-13 12:52:31 +00:00
Markus Himmel
4a900cc65c chore: rename IntX.toNat -> IntX.toNatClampNeg (#7066)
This PR renames `IntX.toNat` to `IntX.toNatClampNeg` (to reduce
surprises) and sets up a deprecation.
2025-02-13 12:14:28 +00:00
Markus Himmel
a3fd2eb0fe chore: make IntX constructor private, provide UIntX.toIntX (#7062)
This PR introduces the functions `UIntX.toIntX` as the public API to
obtain the `IntX` that is 2's complement equivalent to a given `UIntX`.
2025-02-13 11:29:31 +00:00
Paul Reichert
6ac530aa1a feat: deprecated find, fold, foldM, mergeBy functions for the tree map (#7036)
This PR adds some deprecated function aliases to the tree map in order
to ease the transition from the `RBMap` to the tree map.

---------

Co-authored-by: Paul Reichert <6992158+datokrat@users.noreply.github.com>
2025-02-13 11:12:22 +00:00
Markus Himmel
04fe72fee0 feat: missing conversion functions for ISize (#7063)
This PR adds `ISize.toInt8`, `ISize.toInt16`, `Int8.toISize`,
`Int16.toISize`.
2025-02-13 11:02:00 +00:00
Joachim Breitner
a833afa935 feat: binderNameHint in congr (#7053)
This PR makes `simp` heed the `binderNameHint` also in the assumptions
of congruence rules. Fixes #7052.
2025-02-13 09:38:42 +00:00
Markus Himmel
7c9454edd2 feat: UIntX.ofFin (#7056)
This PR adds the `UIntX.ofFin` conversion functions.
2025-02-13 08:45:01 +00:00
Markus Himmel
1ecb4a43ae chore: rename UIntX.val -> UIntX.toFin (#7050)
This PR renames the functions `UIntX.val` to `UIntX.toFin`.
2025-02-13 07:50:47 +00:00
Kim Morrison
ae9d12aeaa chore: upstream an Int lemma (#7060) 2025-02-13 03:19:02 +00:00
Leonardo de Moura
e617ce7e4f refactor: move grind offset constraint module to Grind/Arith/Offset (#7058)
This PR moves the `grind` offset constraint module to the
`Grind/Arith/Offset` subdirectory in preparation to the full linear
integer arithmetic module.
2025-02-12 23:16:07 +00:00
151 changed files with 3294 additions and 1274 deletions

View File

@@ -31,8 +31,12 @@ example (names : List String) : names.all (fun name => "Waldo".isPrefixOf name)
If `binder` is not a binder, then the name of `v` attains a macro scope. This only matters when the
resulting term is used in a non-hygienic way, e.g. in termination proofs for well-founded recursion.
This gadget is supported by `simp`, `dsimp` and `rw` in the right-hand-side of an equation, but not
in hypotheses or by other tactics.
This gadget is supported by
* `simp`, `dsimp` and `rw` in the right-hand-side of an equation
* `simp` in the assumptions of congruence rules
It is ineffective in other positions (hyptheses of rewrite rules) or when used by other tactics
(e.g. `apply`).
-/
@[simp ]
def binderNameHint {α : Sort u} {β : Sort v} {γ : Sort w} (v : α) (binder : β) (e : γ) : γ := e

View File

@@ -195,7 +195,7 @@ end Classical
/- Export for Mathlib compat. -/
export Classical (imp_iff_right_iff imp_and_neg_imp_iff and_or_imp not_imp)
/-- Extract an element from a existential statement, using `Classical.choose`. -/
/-- Extract an element from an existential statement, using `Classical.choose`. -/
-- This enables projection notation.
@[reducible] noncomputable def Exists.choose {p : α Prop} (P : a, p a) : α := Classical.choose P

View File

@@ -25,6 +25,10 @@ set_option linter.missingDocs true
namespace BitVec
@[inline, deprecated BitVec.ofNatLT (since := "2025-02-13"), inherit_doc BitVec.ofNatLT]
protected def ofNatLt {n : Nat} (i : Nat) (p : i < 2 ^ n) : BitVec n :=
BitVec.ofNatLT i p
section Nat
instance natCastInst : NatCast (BitVec w) := BitVec.ofNat w
@@ -55,12 +59,12 @@ end subsingleton
section zero_allOnes
/-- Return a bitvector `0` of size `n`. This is the bitvector with all zero bits. -/
protected def zero (n : Nat) : BitVec n := .ofNatLt 0 (Nat.two_pow_pos n)
protected def zero (n : Nat) : BitVec n := .ofNatLT 0 (Nat.two_pow_pos n)
instance : Inhabited (BitVec n) where default := .zero n
/-- Bit vector of size `n` where all bits are `1`s -/
def allOnes (n : Nat) : BitVec n :=
.ofNatLt (2^n - 1) (Nat.le_of_eq (Nat.sub_add_cancel (Nat.two_pow_pos n)))
.ofNatLT (2^n - 1) (Nat.le_of_eq (Nat.sub_add_cancel (Nat.two_pow_pos n)))
end zero_allOnes
@@ -138,7 +142,7 @@ protected def toInt (x : BitVec n) : Int :=
(x.toNat : Int) - (2^n : Nat)
/-- The `BitVec` with value `(2^n + (i mod 2^n)) mod 2^n`. -/
protected def ofInt (n : Nat) (i : Int) : BitVec n := .ofNatLt (i % (Int.ofNat (2^n))).toNat (by
protected def ofInt (n : Nat) (i : Int) : BitVec n := .ofNatLT (i % (Int.ofNat (2^n))).toNat (by
apply (Int.toNat_lt _).mpr
· apply Int.emod_lt_of_pos
exact Int.ofNat_pos.mpr (Nat.two_pow_pos _)
@@ -167,12 +171,12 @@ recommended_spelling "one" for "1#n" in [BitVec.ofNat, «term__#__»]
| `($(_) $n $i:num) => `($i:num#$n)
| _ => throw ()
/-- Notation for bit vector literals without truncation. `i#'lt` is a shorthand for `BitVec.ofNatLt i lt`. -/
/-- Notation for bit vector literals without truncation. `i#'lt` is a shorthand for `BitVec.ofNatLT i lt`. -/
scoped syntax:max term:max noWs "#'" noWs term:max : term
macro_rules | `($i#'$p) => `(BitVec.ofNatLt $i $p)
macro_rules | `($i#'$p) => `(BitVec.ofNatLT $i $p)
/-- Unexpander for bit vector literals without truncation. -/
@[app_unexpander BitVec.ofNatLt] def unexpandBitVecOfNatLt : Lean.PrettyPrinter.Unexpander
@[app_unexpander BitVec.ofNatLT] def unexpandBitVecOfNatLt : Lean.PrettyPrinter.Unexpander
| `($(_) $i $p) => `($i#'$p)
| _ => throw ()
@@ -356,7 +360,7 @@ end relations
section cast
/-- `cast eq x` embeds `x` into an equal `BitVec` type. -/
@[inline] protected def cast (eq : n = m) (x : BitVec n) : BitVec m := .ofNatLt x.toNat (eq x.isLt)
@[inline] protected def cast (eq : n = m) (x : BitVec n) : BitVec m := .ofNatLT x.toNat (eq x.isLt)
@[simp] theorem cast_ofNat {n m : Nat} (h : n = m) (x : Nat) :
(BitVec.ofNat n x).cast h = BitVec.ofNat m x := by

View File

@@ -274,16 +274,27 @@ theorem ofBool_eq_iff_eq : ∀ {b b' : Bool}, BitVec.ofBool b = BitVec.ofBool b'
@[simp, bitvec_to_nat] theorem toNat_ofFin (x : Fin (2^n)) : (BitVec.ofFin x).toNat = x.val := rfl
@[simp] theorem toNat_ofNatLt (x : Nat) (p : x < 2^w) : (x#'p).toNat = x := rfl
@[simp] theorem toNat_ofNatLT (x : Nat) (p : x < 2^w) : (x#'p).toNat = x := rfl
@[simp] theorem getLsbD_ofNatLt {n : Nat} (x : Nat) (lt : x < 2^n) (i : Nat) :
@[deprecated toNat_ofNatLT (since := "2025-02-13")]
theorem toNat_ofNatLt (x : Nat) (p : x < 2^w) : (x#'p).toNat = x := rfl
@[simp] theorem getLsbD_ofNatLT {n : Nat} (x : Nat) (lt : x < 2^n) (i : Nat) :
getLsbD (x#'lt) i = x.testBit i := by
simp [getLsbD, BitVec.ofNatLt]
simp [getLsbD, BitVec.ofNatLT]
@[simp] theorem getMsbD_ofNatLt {n x i : Nat} (h : x < 2^n) :
@[deprecated getLsbD_ofNatLT (since := "2025-02-13")]
theorem getLsbD_ofNatLt {n : Nat} (x : Nat) (lt : x < 2^n) (i : Nat) :
getLsbD (x#'lt) i = x.testBit i := getLsbD_ofNatLT x lt i
@[simp] theorem getMsbD_ofNatLT {n x i : Nat} (h : x < 2^n) :
getMsbD (x#'h) i = (decide (i < n) && x.testBit (n - 1 - i)) := by
simp [getMsbD, getLsbD]
@[deprecated getMsbD_ofNatLT (since := "2025-02-13")]
theorem getMsbD_ofNatLt {n x i : Nat} (h : x < 2^n) :
getMsbD (x#'h) i = (decide (i < n) && x.testBit (n - 1 - i)) := getMsbD_ofNatLT h
@[simp, bitvec_to_nat] theorem toNat_ofNat (x w : Nat) : (BitVec.ofNat w x).toNat = x % 2^w := by
simp [BitVec.toNat, BitVec.ofNat, Fin.ofNat']
@@ -1217,7 +1228,7 @@ theorem not_def {x : BitVec v} : ~~~x = allOnes v ^^^ x := rfl
@[simp] theorem ofInt_negSucc_eq_not_ofNat {w n : Nat} :
BitVec.ofInt w (Int.negSucc n) = ~~~.ofNat w n := by
simp only [BitVec.ofInt, Int.toNat, Int.ofNat_eq_coe, toNat_eq, toNat_ofNatLt, toNat_not,
simp only [BitVec.ofInt, Int.toNat, Int.ofNat_eq_coe, toNat_eq, toNat_ofNatLT, toNat_not,
toNat_ofNat]
cases h : Int.negSucc n % ((2 ^ w : Nat) : Int)
case ofNat =>
@@ -1418,7 +1429,7 @@ theorem shiftLeftZeroExtend_eq {x : BitVec w} :
apply eq_of_toNat_eq
rw [shiftLeftZeroExtend, setWidth]
split
· simp only [toNat_ofNatLt, toNat_shiftLeft, toNat_setWidth']
· simp only [toNat_ofNatLT, toNat_shiftLeft, toNat_setWidth']
rw [Nat.mod_eq_of_lt]
rw [Nat.shiftLeft_eq, Nat.pow_add]
exact Nat.mul_lt_mul_of_pos_right x.isLt (Nat.two_pow_pos _)
@@ -2901,7 +2912,7 @@ protected theorem ne_of_lt {x y : BitVec n} : x < y → x ≠ y := by
apply Nat.ne_of_lt
protected theorem umod_lt (x : BitVec n) {y : BitVec n} : 0 < y x % y < y := by
simp only [ofNat_eq_ofNat, lt_def, toNat_ofNat, Nat.zero_mod, umod, toNat_ofNatLt]
simp only [ofNat_eq_ofNat, lt_def, toNat_ofNat, Nat.zero_mod, umod, toNat_ofNatLT]
apply Nat.mod_lt
theorem not_lt_iff_le {x y : BitVec w} : (¬ x < y) y x := by
@@ -3243,7 +3254,7 @@ theorem toNat_smod {x y : BitVec w} : (x.smod y).toNat =
by_cases h : x.msb <;> by_cases h' : y.msb
<;> by_cases h'' : (-x).umod y = 0#w <;> by_cases h''' : x.umod (-y) = 0#w
<;> simp only [h, h', h'', h''']
<;> simp only [umod, toNat_eq, toNat_ofNatLt, toNat_ofNat, Nat.zero_mod] at h'' h'''
<;> simp only [umod, toNat_eq, toNat_ofNatLT, toNat_ofNat, Nat.zero_mod] at h'' h'''
<;> simp [h'', h''']
@[simp]

View File

@@ -56,7 +56,7 @@ def get : (a : @& ByteArray) → (i : @& Nat) → (h : i < a.size := by get_elem
instance : GetElem ByteArray Nat UInt8 fun xs i => i < xs.size where
getElem xs i h := xs.get i
instance : GetElem ByteArray USize UInt8 fun xs i => i.val < xs.size where
instance : GetElem ByteArray USize UInt8 fun xs i => i.toFin < xs.size where
getElem xs i h := xs.uget i h
@[extern "lean_byte_array_set"]

View File

@@ -40,12 +40,12 @@ theorem isValidUInt32 (n : Nat) (h : isValidCharNat n) : n < UInt32.size := by
apply Nat.lt_trans h₂
decide
theorem isValidChar_of_isValidCharNat (n : Nat) (h : isValidCharNat n) : isValidChar (UInt32.ofNat' n (isValidUInt32 n h)) :=
theorem isValidChar_of_isValidCharNat (n : Nat) (h : isValidCharNat n) : isValidChar (UInt32.ofNatLT n (isValidUInt32 n h)) :=
match h with
| Or.inl h =>
Or.inl (UInt32.ofNat'_lt_of_lt _ (by decide) h)
Or.inl (UInt32.ofNatLT_lt_of_lt _ (by decide) h)
| Or.inr h₁, h₂ =>
Or.inr UInt32.lt_ofNat'_of_lt _ (by decide) h₁, UInt32.ofNat'_lt_of_lt _ (by decide) h₂
Or.inr UInt32.lt_ofNatLT_of_lt _ (by decide) h₁, UInt32.ofNatLT_lt_of_lt _ (by decide) h₂
theorem isValidChar_zero : isValidChar 0 :=
Or.inl (by decide)

View File

@@ -51,6 +51,14 @@ Returns `a` modulo `n + 1` as a `Fin n.succ`.
protected def ofNat {n : Nat} (a : Nat) : Fin (n + 1) :=
a % (n+1), Nat.mod_lt _ (Nat.zero_lt_succ _)
-- We provide this because other similar types have a `toNat` function, but `simp` rewrites
-- `i.toNat` to `i.val`.
@[inline, inherit_doc val]
protected def toNat (i : Fin n) : Nat :=
i.val
@[simp] theorem toNat_eq_val {i : Fin n} : i.toNat = i.val := rfl
private theorem mlt {b : Nat} : {a : Nat} a < n b % n < n
| 0, h => Nat.mod_lt _ h
| _+1, h =>

View File

@@ -62,7 +62,7 @@ def get? (ds : FloatArray) (i : Nat) : Option Float :=
instance : GetElem FloatArray Nat Float fun xs i => i < xs.size where
getElem xs i h := xs.get i h
instance : GetElem FloatArray USize Float fun xs i => i.val < xs.size where
instance : GetElem FloatArray USize Float fun xs i => i.toNat < xs.size where
getElem xs i h := xs.uget i h
@[extern "lean_float_array_uset"]

View File

@@ -15,3 +15,4 @@ import Init.Data.Int.Order
import Init.Data.Int.Pow
import Init.Data.Int.Cooper
import Init.Data.Int.Linear
import Init.Data.Int.Cutsat

View File

@@ -0,0 +1,68 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Data.AC
import Init.Data.Int.Gcd
namespace Int.Linear
/-!
Helper theorems for solving divisibility constraints.
The two theorems are used to justify the `Div-Solve` rule
in the section "Strong Conflict Resolution" in the paper
"Cutting to the Chase: Solving Linear Integer Arithmetic".
-/
theorem dvd_solve_1 {x : Int} {d₁ a₁ p₁ : Int} {d₂ a₂ p₂ : Int} {α β d : Int}
(h : α*a₁*d₂ + β*a₂*d₁ = d)
(h₁ : d₁ a₁*x + p₁)
(h₂ : d₂ a₂*x + p₂)
: d₁*d₂ d*x + α*d₂*p₁ + β*d₁*p₂ := by
rcases h₁ with k₁, h₁
replace h₁ : α*a₁*d₂*x + α*d₂*p₁ = d₁*d₂*(α*k₁) := by
have ac₁ : d₁*d₂*(α*k₁) = α*d₂*(d₁*k₁) := by ac_rfl
have ac₂ : α * a₁ * d₂ * x = α * d₂ * (a₁ * x) := by ac_rfl
rw [ac₁, h₁, Int.mul_add, ac₂]
rcases h₂ with k₂, h₂
replace h₂ : β*a₂*d₁*x + β*d₁*p₂ = d₁*d₂*(β*k₂) := by
have ac₁ : d₁*d₂*(β*k₂) = β*d₁*(d₂*k₂) := by ac_rfl
have ac₂ : β * a₂ * d₁ * x = β * d₁ * (a₂ * x) := by ac_rfl
rw [ac₁, h₂, Int.mul_add, ac₂]
replace h₁ : d₁*d₂ α*a₁*d₂*x + α*d₂*p₁ := α*k₁, h₁
replace h₂ : d₁*d₂ β*a₂*d₁*x + β*d₁*p₂ := β*k₂, h₂
have h' := Int.dvd_add h₁ h₂; clear h₁ h₂ k₁ k₂
replace h : d*x = α*a₁*d₂*x + β*a₂*d₁*x := by
rw [h, Int.add_mul]
have ac :
α * a₁ * d₂ * x + α * d₂ * p₁ + (β * a₂ * d₁ * x + β * d₁ * p₂)
=
α * a₁ * d₂ * x + β * a₂ * d₁ * x + α * d₂ * p₁ + β * d₁ * p₂ := by ac_rfl
rw [h, ac]
assumption
theorem dvd_solve_2 {x : Int} {d₁ a₁ p₁ : Int} {d₂ a₂ p₂ : Int} {d : Int}
(h : d = Int.gcd (a₁*d₂) (a₂*d₁))
(h₁ : d₁ a₁*x + p₁)
(h₂ : d₂ a₂*x + p₂)
: d a₂*p₁ - a₁*p₂ := by
rcases h₁ with k₁, h₁
rcases h₂ with k₂, h₂
have h₃ : d a₁*d₂ := by
rw [h]; apply Int.gcd_dvd_left
have h₄ : d a₂*d₁ := by
rw [h]; apply Int.gcd_dvd_right
rcases h₃ with k₃, h₃
rcases h₄ with k₄, h₄
have : a₂*p₁ - a₁*p₂ = a₂*d₁*k₁ - a₁*d₂*k₂ := by
have ac₁ : a₂*d₁*k₁ = a₂*(d₁*k₁) := by ac_rfl
have ac₂ : a₁*d₂*k₂ = a₁*(d₂*k₂) := by ac_rfl
have ac₃ : a₁*(a₂*x) = a₂*(a₁*x) := by ac_rfl
rw [ac₁, ac₂, h₁, h₂, Int.mul_add, Int.mul_add, ac₃, Int.sub_sub, Int.add_comm, Int.add_sub_assoc]
simp
rw [h₃, h₄, Int.mul_assoc, Int.mul_assoc, Int.mul_sub] at this
exact k₄ * k₁ - k₃ * k₂, this
end Int.Linear

View File

@@ -22,11 +22,11 @@ namespace Int
protected theorem dvd_def (a b : Int) : (a b) = Exists (fun c => b = a * c) := rfl
protected theorem dvd_zero (n : Int) : n 0 := 0, (Int.mul_zero _).symm
@[simp] protected theorem dvd_zero (n : Int) : n 0 := 0, (Int.mul_zero _).symm
protected theorem dvd_refl (n : Int) : n n := 1, (Int.mul_one _).symm
@[simp] protected theorem dvd_refl (n : Int) : n n := 1, (Int.mul_one _).symm
protected theorem one_dvd (n : Int) : 1 n := n, (Int.one_mul n).symm
@[simp] protected theorem one_dvd (n : Int) : 1 n := n, (Int.one_mul n).symm
protected theorem dvd_trans : {a b c : Int}, a b b c a c
| _, _, _, d, rfl, e, rfl => Exists.intro (d * e) (by rw [Int.mul_assoc])
@@ -1347,3 +1347,14 @@ theorem bmod_natAbs_plus_one (x : Int) (w : 1 < x.natAbs) : bmod x (x.natAbs + 1
theorem bmod_neg_bmod : bmod (-(bmod x n)) n = bmod (-x) n := by
apply (bmod_add_cancel_right x).mp
rw [Int.add_left_neg, add_bmod_bmod, Int.add_left_neg]
/-! Helper theorems for `dvd` simproc -/
protected theorem dvd_eq_true_of_mod_eq_zero {a b : Int} (h : b % a == 0) : (a b) = True := by
simp [Int.dvd_of_emod_eq_zero, eq_of_beq h]
protected theorem dvd_eq_false_of_mod_ne_zero {a b : Int} (h : b % a != 0) : (a b) = False := by
simp [eq_of_beq] at h
simp [Int.dvd_iff_emod_eq_zero, h]
end Int

View File

@@ -326,6 +326,10 @@ theorem toNat_sub (m n : Nat) : toNat (m - n) = m - n := by
· exact (Nat.add_sub_cancel_left ..).symm
· dsimp; rw [Nat.add_assoc, Nat.sub_eq_zero_of_le (Nat.le_add_right ..)]; rfl
theorem toNat_of_nonpos : {z : Int}, z 0 z.toNat = 0
| 0, _ => rfl
| -[_+1], _ => rfl
/- ## add/sub injectivity -/
protected theorem add_left_inj {i j : Int} (k : Int) : (i + k = j + k) i = j := by

View File

@@ -9,6 +9,7 @@ import Init.Data.Prod
import Init.Data.Int.Lemmas
import Init.Data.Int.LemmasAux
import Init.Data.Int.DivModLemmas
import Init.Data.Int.Gcd
import Init.Data.RArray
namespace Int.Linear
@@ -69,12 +70,14 @@ def Poly.insert (k : Int) (v : Var) (p : Poly) : Poly :=
else
.add k' v' (insert k v p)
/-- Normalizes the given polynomial by fusing monomial and constants. -/
def Poly.norm (p : Poly) : Poly :=
match p with
| .num k => .num k
| .add k v p => (norm p).insert k v
def Expr.toPoly' (e : Expr) :=
/-- Converts the given expression into a polynomial. -/
def Expr.toPoly' (e : Expr) : Poly :=
go 1 e (.num 0)
where
go (coeff : Int) : Expr (Poly Poly)
@@ -86,21 +89,42 @@ where
| .mulR a k => bif k == 0 then id else go (Int.mul coeff k) a
| .neg a => go (-coeff) a
/-- Converts the given expression into a polynomial, and then normalizes it. -/
def Expr.toPoly (e : Expr) : Poly :=
e.toPoly'.norm
inductive PolyCnstr where
| eq (p : Poly)
| le (p : Poly)
/-- Relational contraints: equality and inequality. -/
inductive RelCnstr where
| /-- `p = 0` constraint. -/
eq (p : Poly)
| /-- `p ≤ 0` contraint. -/
le (p : Poly)
deriving BEq
def PolyCnstr.denote (ctx : Context) : PolyCnstr Prop
def RelCnstr.denote (ctx : Context) : RelCnstr Prop
| .eq p => p.denote ctx = 0
| .le p => p.denote ctx 0
/--
Returns the ceiling of the division `a / b`. That is, the result is equivalent to `⌈a / b⌉`.
Examples:
- `cdiv 7 3` returns `3`
- `cdiv (-7) 3` returns `-2`.
-/
def cdiv (a b : Int) : Int :=
-((-a)/b)
/--
Returns the ceiling-compatible remainder of the division `a / b`.
This function ensures that the remainder is consistent with `cdiv`, meaning:
```
a = b * cdiv a b + cmod a b
```
See theorem `cdiv_add_cmod`. We also have
```
-b < cmod a b ≤ 0
```
-/
def cmod (a b : Int) : Int :=
-((-a)%b)
@@ -126,7 +150,11 @@ theorem cmod_eq_zero_iff_emod_eq_zero (a b : Int) : cmod a b = 0 ↔ a%b = 0 :=
simp at this
simp [Int.neg_emod, this, Eq.comm]
theorem cdiv_eq_div_of_divides {a b : Int} (h : (a/b)*b = a) : a/b = cdiv a b := by
private abbrev div_mul_cancel_of_mod_zero :=
@Int.ediv_mul_cancel_of_emod_eq_zero
theorem cdiv_eq_div_of_divides {a b : Int} (h : a % b = 0) : a/b = cdiv a b := by
replace h := div_mul_cancel_of_mod_zero h
have hz : a % b = 0 := by
have := Int.ediv_add_emod a b
conv at this => rhs; rw [ Int.add_zero a]
@@ -143,60 +171,85 @@ theorem cdiv_eq_div_of_divides {a b : Int} (h : (a/b)*b = a) : a/b = cdiv a b :=
next => simp[cdiv, h]
next => rw [Int.mul_eq_mul_right_iff h] at this; assumption
def Poly.div (k : Int) : Poly Poly
| .num k' => .num (cdiv k' k)
| .add k' x p => .add (k'/k) x (div k p)
def Poly.divAll (k : Int) : Poly Bool
| .num k' => (k'/k)*k == k'
| .add k' _ p => (k'/k)*k == k' && divAll k p
def Poly.divCoeffs (k : Int) : Poly Bool
| .num _ => true
| .add k' _ p => (k'/k)*k == k' && divCoeffs k p
/-- Returns the constant of the given linear polynomial. -/
def Poly.getConst : Poly Int
| .num k => k
| .add _ _ p => getConst p
def PolyCnstr.norm : PolyCnstr PolyCnstr
/--
`p.div k` divides all coefficients of the polynomial `p` by `k`, but
rounds up the constant using `cdiv`.
Notes:
- We only use this function with `k`s that divides all coefficients.
- We use `cdiv` for the constant to implement the inequality tightening rule.
-/
def Poly.div (k : Int) : Poly Poly
| .num k' => .num (cdiv k' k)
| .add k' x p => .add (k'/k) x (div k p)
/--
Returns `true` if `k` divides all coefficients and the constant of the given
linear polynomial.
-/
def Poly.divAll (k : Int) : Poly Bool
| .num k' => k' % k == 0
| .add k' _ p => k' % k == 0 && divAll k p
/--
Returns `true` if `k` divides all coefficients of the given linear polynomial.
-/
def Poly.divCoeffs (k : Int) : Poly Bool
| .num _ => true
| .add k' _ p => k' % k == 0 && divCoeffs k p
/-- Normalizes the polynomial of the given relational constraint. -/
def RelCnstr.norm : RelCnstr RelCnstr
| .eq p => .eq p.norm
| .le p => .le p.norm
def PolyCnstr.divAll (k : Int) : PolyCnstr Bool
/-- Returns `true` if `k` divides all coefficients and constant of the given relational constraint. -/
def RelCnstr.divAll (k : Int) : RelCnstr Bool
| .eq p | .le p => p.divAll k
def PolyCnstr.divCoeffs (k : Int) : PolyCnstr Bool
/-- Returns `true` if `k` divides all coefficients of the given relational constraint. -/
def RelCnstr.divCoeffs (k : Int) : RelCnstr Bool
| .eq p | .le p => p.divCoeffs k
def PolyCnstr.isLe : PolyCnstr Bool
/-- Returns `true` if the given relational constraint is an inequality constraint of the form `p ≤ 0`. -/
def RelCnstr.isLe : RelCnstr Bool
| .eq _ => false
| .le _ => true
def PolyCnstr.div (k : Int) : PolyCnstr PolyCnstr
/--
Divides all coefficients and constants in the linear polynomial of the given constraint by `k`.
We rounds up the constant using `cdiv`.
-/
def RelCnstr.div (k : Int) : RelCnstr RelCnstr
| .eq p => .eq <| p.div k
| .le p => .le <| p.div k
inductive ExprCnstr where
/-- Raw relational constraint. They are later converted into `RelCnstr`. -/
inductive RawRelCnstr where
| eq (p₁ p₂ : Expr)
| le (p₁ p₂ : Expr)
deriving Inhabited, BEq
def ExprCnstr.isLe : ExprCnstr Bool
/-- Returns `true` if the given relational constraint is an inequality constraint of the form `e₁ ≤ e₂`. -/
def RawRelCnstr.isLe : RawRelCnstr Bool
| .eq .. => false
| .le .. => true
def ExprCnstr.denote (ctx : Context) : ExprCnstr Prop
def RawRelCnstr.denote (ctx : Context) : RawRelCnstr Prop
| .eq e₁ e₂ => e₁.denote ctx = e₂.denote ctx
| .le e₁ e₂ => e₁.denote ctx e₂.denote ctx
def ExprCnstr.toPoly : ExprCnstr PolyCnstr
def RawRelCnstr.norm : RawRelCnstr RelCnstr
| .eq e₁ e₂ => .eq (e₁.sub e₂).toPoly.norm
| .le e₁ e₂ => .le (e₁.sub e₂).toPoly.norm
-- Certificate for normalizing the coefficients of a constraint
def divBy (e e' : ExprCnstr) (k : Int) : Bool :=
k > 0 && e.toPoly.divAll k && e'.toPoly == e.toPoly.div k
/-- A certificate for normalizing the coefficients of a raw relational constraint. -/
def divBy (e e' : RawRelCnstr) (k : Int) : Bool :=
k > 0 && e.norm.divAll k && e'.norm == e.norm.div k
attribute [local simp] Int.add_comm Int.add_assoc Int.add_left_comm Int.add_mul Int.mul_add
attribute [local simp] Poly.insert Poly.denote Poly.norm Poly.addConst
@@ -228,13 +281,14 @@ private theorem neg_fold (a : Int) : a.neg = -a := rfl
attribute [local simp] sub_fold neg_fold
attribute [local simp] Poly.div Poly.divAll PolyCnstr.denote
attribute [local simp] Poly.div Poly.divAll RelCnstr.denote
theorem Poly.denote_div_eq_of_divAll (ctx : Context) (p : Poly) (k : Int) : p.divAll k (p.div k).denote ctx * k = p.denote ctx := by
induction p with
| num _ => simp; intro h; rw [ cdiv_eq_div_of_divides h]; assumption
| num _ => simp; intro h; rw [ cdiv_eq_div_of_divides h]; exact div_mul_cancel_of_mod_zero h
| add k' v p ih =>
simp; intro h₁ h₂
replace h₁ := div_mul_cancel_of_mod_zero h₁
have ih := ih h₂
simp [ih]
apply congrArg (denote ctx p + ·)
@@ -247,10 +301,11 @@ theorem Poly.denote_div_eq_of_divCoeffs (ctx : Context) (p : Poly) (k : Int) : p
| num k' => simp; rw [Int.mul_comm, cdiv_add_cmod]
| add k' v p ih =>
simp; intro h₁ h₂
replace h₁ := div_mul_cancel_of_mod_zero h₁
rw [ ih h₂]
rw [Int.mul_right_comm, h₁, Int.add_assoc]
attribute [local simp] ExprCnstr.denote ExprCnstr.toPoly Expr.denote
attribute [local simp] RawRelCnstr.denote RawRelCnstr.norm Expr.denote
theorem Expr.denote_toPoly'_go (ctx : Context) (e : Expr) :
(toPoly'.go k e p).denote ctx = k * e.denote ctx + p.denote ctx := by
@@ -279,9 +334,9 @@ theorem Expr.denote_toPoly'_go (ctx : Context) (e : Expr) :
theorem Expr.denote_toPoly (ctx : Context) (e : Expr) : e.toPoly.denote ctx = e.denote ctx := by
simp [toPoly, toPoly', Expr.denote_toPoly'_go]
attribute [local simp] Expr.denote_toPoly PolyCnstr.denote
attribute [local simp] Expr.denote_toPoly RelCnstr.denote
theorem ExprCnstr.denote_toPoly (ctx : Context) (c : ExprCnstr) : c.toPoly.denote ctx = c.denote ctx := by
theorem RawRelCnstr.denote_norm (ctx : Context) (c : RawRelCnstr) : c.norm.denote ctx = c.denote ctx := by
cases c <;> simp
· rw [Int.sub_eq_zero]
· constructor
@@ -300,7 +355,7 @@ instance : LawfulBEq Poly where
· rename_i k v p ih
exact ih
instance : LawfulBEq PolyCnstr where
instance : LawfulBEq RelCnstr where
eq_of_beq {a b} := by
cases a <;> cases b <;> rename_i p₁ p₂ <;> simp_all! [BEq.beq]
· show (p₁ == p₂) = true _
@@ -316,22 +371,22 @@ theorem Expr.eq_of_toPoly_eq (ctx : Context) (e e' : Expr) (h : e.toPoly == e'.t
simp [Poly.norm] at h
assumption
theorem ExprCnstr.eq_of_toPoly_eq (ctx : Context) (c c' : ExprCnstr) (h : c.toPoly == c'.toPoly) : c.denote ctx = c'.denote ctx := by
have h := congrArg (PolyCnstr.denote ctx) (eq_of_beq h)
rw [denote_toPoly, denote_toPoly] at h
theorem RawRelCnstr.eq_of_norm_eq (ctx : Context) (c c' : RawRelCnstr) (h : c.norm == c'.norm) : c.denote ctx = c'.denote ctx := by
have h := congrArg (RelCnstr.denote ctx) (eq_of_beq h)
rw [denote_norm, denote_norm] at h
assumption
theorem ExprCnstr.eq_of_toPoly_eq_var (ctx : Context) (x y : Var) (c : ExprCnstr) (h : c.toPoly == .eq (.add 1 x (.add (-1) y (.num 0))))
theorem RawRelCnstr.eq_of_norm_eq_var (ctx : Context) (x y : Var) (c : RawRelCnstr) (h : c.norm == .eq (.add 1 x (.add (-1) y (.num 0))))
: c.denote ctx = (x.denote ctx = y.denote ctx) := by
have h := congrArg (PolyCnstr.denote ctx) (eq_of_beq h)
rw [denote_toPoly] at h
have h := congrArg (RelCnstr.denote ctx) (eq_of_beq h)
rw [denote_norm] at h
rw [h]; simp
rw [ Int.sub_eq_add_neg, Int.sub_eq_zero]
theorem ExprCnstr.eq_of_toPoly_eq_const (ctx : Context) (x : Var) (k : Int) (c : ExprCnstr) (h : c.toPoly == .eq (.add 1 x (.num (-k))))
theorem RawRelCnstr.eq_of_norm_eq_const (ctx : Context) (x : Var) (k : Int) (c : RawRelCnstr) (h : c.norm == .eq (.add 1 x (.num (-k))))
: c.denote ctx = (x.denote ctx = k) := by
have h := congrArg (PolyCnstr.denote ctx) (eq_of_beq h)
rw [denote_toPoly] at h
have h := congrArg (RelCnstr.denote ctx) (eq_of_beq h)
rw [denote_norm] at h
rw [h]; simp
rw [Int.add_comm, Int.sub_eq_add_neg, Int.sub_eq_zero]
@@ -356,39 +411,39 @@ private theorem eq_mul_le_zero {a b : Int} : 0 < b → (a ≤ 0 ↔ a * b ≤ 0)
rw [this] at h'
exact Int.le_of_mul_le_mul_right h' h
attribute [local simp] PolyCnstr.divAll PolyCnstr.div
attribute [local simp] RelCnstr.divAll RelCnstr.div
theorem ExprCnstr.eq_of_toPoly_eq_of_divBy' (ctx : Context) (e e' : ExprCnstr) (p : PolyCnstr) (k : Int) : k > 0 p.divAll k e.toPoly = p e'.toPoly = p.div k e.denote ctx = e'.denote ctx := by
theorem RawRelCnstr.eq_of_norm_eq_of_divBy' (ctx : Context) (c c' : RawRelCnstr) (p : RelCnstr) (k : Int)
: k > 0 p.divAll k c.norm = p c'.norm = p.div k c.denote ctx = c'.denote ctx := by
intro h₀ h₁ h₂ h₃
have hz : k 0 := Int.ne_of_gt h₀
cases p <;> simp at h₁
next p =>
replace h₁ := Poly.denote_div_eq_of_divAll ctx p k h₁
replace h₂ := congrArg (PolyCnstr.denote ctx) h₂
simp only [PolyCnstr.denote.eq_1, h₁] at h₂
replace h₃ := congrArg (PolyCnstr.denote ctx) h₃
simp only [PolyCnstr.denote.eq_1, PolyCnstr.div] at h₃
replace h₂ := congrArg (RelCnstr.denote ctx) h₂
simp only [RelCnstr.denote.eq_1, h₁] at h₂
replace h₃ := congrArg (RelCnstr.denote ctx) h₃
simp only [RelCnstr.denote.eq_1, RelCnstr.div] at h₃
rw [mul_eq_zero_iff_eq_zero _ _ hz] at h₂
have := Eq.trans h₂ h₃.symm
rw [denote_toPoly, denote_toPoly] at this
rw [denote_norm, denote_norm] at this
exact this
next p =>
-- TODO: this is correct but we can simplify `p ≤ 0` if `p.divCoeffs k` and `p.getConst % k > 0`. Here, we are simplifying only the case `p.getConst % k = 0`
replace h₁ := Poly.denote_div_eq_of_divAll ctx p k h₁
replace h₂ := congrArg (PolyCnstr.denote ctx) h₂
simp only [PolyCnstr.denote.eq_2, h₁] at h₂
replace h₃ := congrArg (PolyCnstr.denote ctx) h₃
simp only [PolyCnstr.denote.eq_2, PolyCnstr.div] at h₃
replace h₂ := congrArg (RelCnstr.denote ctx) h₂
simp only [RelCnstr.denote.eq_2, h₁] at h₂
replace h₃ := congrArg (RelCnstr.denote ctx) h₃
simp only [RelCnstr.denote.eq_2, RelCnstr.div] at h₃
rw [eq_mul_le_zero h₀] at h₃
have := Eq.trans h₂ h₃.symm
rw [denote_toPoly, denote_toPoly] at this
rw [denote_norm, denote_norm] at this
exact this
theorem ExprCnstr.eq_of_divBy (ctx : Context) (e e' : ExprCnstr) (k : Int) : divBy e e' k e.denote ctx = e'.denote ctx := by
theorem RawRelCnstr.eq_of_divBy (ctx : Context) (e e' : RawRelCnstr) (k : Int) : divBy e e' k e.denote ctx = e'.denote ctx := by
intro h
simp only [divBy, Bool.and_eq_true, bne_iff_ne, ne_eq, beq_iff_eq, decide_eq_true_eq] at h
have h₁, h₂, h₃ := h
exact ExprCnstr.eq_of_toPoly_eq_of_divBy' ctx e e' e.toPoly k h₁ h₂ rfl h₃
exact eq_of_norm_eq_of_divBy' ctx e e' e.norm k h₁ h₂ rfl h₃
private theorem mul_add_cmod_le_iff {a k b : Int} (h : k > 0) : a*k + cmod b k 0 a 0 := by
constructor
@@ -414,53 +469,54 @@ private theorem mul_add_cmod_le_iff {a k b : Int} (h : k > 0) : a*k + cmod b k
simp at this
assumption
theorem ExprCnstr.eq_of_toPoly_eq_of_divCoeffs (ctx : Context) (e e' : ExprCnstr) (p : PolyCnstr) (k : Int) : k > 0 p.divCoeffs k p.isLe e.toPoly = p e'.toPoly = p.div k e.denote ctx = e'.denote ctx := by
theorem RawRelCnstr.eq_of_norm_eq_of_divCoeffs (ctx : Context) (c c' : RawRelCnstr) (p : RelCnstr) (k : Int)
: k > 0 p.divCoeffs k p.isLe c.norm = p c'.norm = p.div k c.denote ctx = c'.denote ctx := by
intro h₀ h₁ h₂ h₃ h₄
have hz : k 0 := Int.ne_of_gt h₀
cases p <;> simp [PolyCnstr.isLe] at h₂
cases p <;> simp [RelCnstr.isLe] at h₂
clear h₂
next p =>
simp [PolyCnstr.divCoeffs] at h₁
simp [RelCnstr.divCoeffs] at h₁
replace h₁ := Poly.denote_div_eq_of_divCoeffs ctx p k h₁
replace h₃ := congrArg (PolyCnstr.denote ctx) h₃
simp only [PolyCnstr.denote.eq_2, h₁] at h₃
replace h₄ := congrArg (PolyCnstr.denote ctx) h₄
simp only [PolyCnstr.denote.eq_2, PolyCnstr.div] at h₄
rw [denote_toPoly] at h₃ h₄
replace h₃ := congrArg (RelCnstr.denote ctx) h₃
simp only [RelCnstr.denote.eq_2, h₁] at h₃
replace h₄ := congrArg (RelCnstr.denote ctx) h₄
simp only [RelCnstr.denote.eq_2, RelCnstr.div] at h₄
rw [denote_norm] at h₃ h₄
rw [h₃, h₄]
apply propext
apply mul_add_cmod_le_iff
exact h₀
-- Certificate for normalizing the coefficients of inequality constraint with bound tightening
def divByLe (e e' : ExprCnstr) (k : Int) : Bool :=
k > 0 && e.isLe && e.toPoly.divCoeffs k && e'.toPoly == e.toPoly.div k
/-- Certificate for normalizing the coefficients of inequality constraint with bound tightening. -/
def divByLe (c c' : RawRelCnstr) (k : Int) : Bool :=
k > 0 && c.isLe && c.norm.divCoeffs k && c'.norm == c.norm.div k
theorem ExprCnstr.eq_of_divByLe (ctx : Context) (e e' : ExprCnstr) (k : Int) : divByLe e e' k e.denote ctx = e'.denote ctx := by
theorem RawRelCnstr.eq_of_divByLe (ctx : Context) (c c' : RawRelCnstr) (k : Int) : divByLe c c' k c.denote ctx = c'.denote ctx := by
intro h
simp only [divByLe, Bool.and_eq_true, bne_iff_ne, ne_eq, beq_iff_eq, decide_eq_true_eq] at h
have h₀, h₁, h₂, h₃ := h
have hle : e.toPoly.isLe := by
cases e <;> simp [ExprCnstr.isLe] at h₁
simp [PolyCnstr.isLe]
apply ExprCnstr.eq_of_toPoly_eq_of_divCoeffs ctx e e' e.toPoly k h₀ h₂ hle rfl h₃
have hle : c.norm.isLe := by
cases c <;> simp [RawRelCnstr.isLe] at h₁
simp [RelCnstr.isLe]
apply eq_of_norm_eq_of_divCoeffs ctx c c' c.norm k h₀ h₂ hle rfl h₃
def PolyCnstr.isUnsat : PolyCnstr Bool
def RelCnstr.isUnsat : RelCnstr Bool
| .eq (.num k) => k != 0
| .eq _ => false
| .le (.num k) => k > 0
| .le _ => false
theorem PolyCnstr.eq_false_of_isUnsat (ctx : Context) (p : PolyCnstr) : p.isUnsat p.denote ctx = False := by
theorem RelCnstr.eq_false_of_isUnsat (ctx : Context) (c : RelCnstr) : c.isUnsat c.denote ctx = False := by
unfold isUnsat <;> split <;> simp <;> try contradiction
apply Int.not_le_of_gt
theorem ExprCnstr.eq_false_of_isUnsat (ctx : Context) (c : ExprCnstr) (h : c.toPoly.isUnsat) : c.denote ctx = False := by
have := PolyCnstr.eq_false_of_isUnsat ctx (c.toPoly) h
rw [ExprCnstr.denote_toPoly] at this
theorem RawRelCnstr.eq_false_of_isUnsat (ctx : Context) (c : RawRelCnstr) (h : c.norm.isUnsat) : c.denote ctx = False := by
have := RelCnstr.eq_false_of_isUnsat ctx c.norm h
rw [RawRelCnstr.denote_norm] at this
assumption
def PolyCnstr.isUnsatCoeff (k : Int) : PolyCnstr Bool
def RelCnstr.isUnsatCoeff (k : Int) : RelCnstr Bool
| .eq p => p.divCoeffs k && k > 0 && cmod p.getConst k < 0
| .le _ => false
@@ -491,7 +547,7 @@ private theorem contra {a b k : Int} (h₀ : 0 < k) (h₁ : -k < b) (h₂ : b <
have : (1 : Int) < 1 := Int.lt_of_le_of_lt h₂ h₁
contradiction
private theorem PolyCnstr.eq_false (ctx : Context) (p : Poly) (k : Int) : p.divCoeffs k k > 0 cmod p.getConst k < 0 (PolyCnstr.eq p).denote ctx = False := by
private theorem RelCnstr.eq_false (ctx : Context) (p : Poly) (k : Int) : p.divCoeffs k k > 0 cmod p.getConst k < 0 (RelCnstr.eq p).denote ctx = False := by
simp
intro h₁ h₂ h₃ h
have hnz : k 0 := by intro h; rw [h] at h₂; contradiction
@@ -501,31 +557,140 @@ private theorem PolyCnstr.eq_false (ctx : Context) (p : Poly) (k : Int) : p.divC
have high := h₃
exact contra h₂ low high this
theorem ExprCnstr.eq_false_of_isUnsat_coeff (ctx : Context) (c : ExprCnstr) (k : Int) : c.toPoly.isUnsatCoeff k c.denote ctx = False := by
theorem RawRelCnstr.eq_false_of_isUnsat_coeff (ctx : Context) (c : RawRelCnstr) (k : Int) : c.norm.isUnsatCoeff k c.denote ctx = False := by
intro h
cases c <;> simp [toPoly, PolyCnstr.isUnsatCoeff] at h
cases c <;> simp [norm, RelCnstr.isUnsatCoeff] at h
next e₁ e₂ =>
have h₁, h₂, h₃ := h
have := PolyCnstr.eq_false ctx _ _ h₁ h₂ h₃
have := RelCnstr.eq_false ctx _ _ h₁ h₂ h₃
simp at this
simp
intro he
simp [he] at this
def PolyCnstr.isValid : PolyCnstr Bool
def RelCnstr.isValid : RelCnstr Bool
| .eq (.num k) => k == 0
| .eq _ => false
| .le (.num k) => k 0
| .le _ => false
theorem PolyCnstr.eq_true_of_isValid (ctx : Context) (p : PolyCnstr) : p.isValid p.denote ctx = True := by
theorem RelCnstr.eq_true_of_isValid (ctx : Context) (c : RelCnstr) : c.isValid c.denote ctx = True := by
unfold isValid <;> split <;> simp
theorem ExprCnstr.eq_true_of_isValid (ctx : Context) (c : ExprCnstr) (h : c.toPoly.isValid) : c.denote ctx = True := by
have := PolyCnstr.eq_true_of_isValid ctx (c.toPoly) h
rw [ExprCnstr.denote_toPoly] at this
theorem RawRelCnstr.eq_true_of_isValid (ctx : Context) (c : RawRelCnstr) (h : c.norm.isValid) : c.denote ctx = True := by
have := RelCnstr.eq_true_of_isValid ctx c.norm h
rw [RawRelCnstr.denote_norm] at this
assumption
private def gcd (a b : Int) : Int :=
Int.ofNat <| Int.gcd a b
private theorem gcd_dvd_left (a b : Int) : gcd a b a := by
simp [gcd, Int.gcd_dvd_left]
private theorem gcd_dvd_right (a b : Int) : gcd a b b := by
simp [gcd, Int.gcd_dvd_right]
private theorem gcd_dvd_step {k a b x : Int} (h : k a*x + b) : gcd a k b := by
have h₁ : gcd a k a*x + b := Int.dvd_trans (gcd_dvd_right a k) h
have h₂ : gcd a k a*x := Int.dvd_trans (gcd_dvd_left a k) (Int.dvd_mul_right a x)
exact Int.dvd_iff_dvd_of_dvd_add h₁ |>.mp h₂
def Poly.gcdCoeffs : Poly Int Int
| .num _, k => k
| .add k' _ p, k => gcdCoeffs p (gcd k' k)
theorem Poly.gcd_dvd_const {ctx : Context} {p : Poly} {k : Int} (h : k p.denote ctx) : p.gcdCoeffs k p.getConst := by
induction p generalizing k <;> simp_all [gcdCoeffs]
next k' x p ih =>
rw [Int.add_comm] at h
exact ih (gcd_dvd_step h)
def Poly.mul (p : Poly) (k : Int) : Poly :=
match p with
| .num k' => .num (k*k')
| .add k' v p => .add (k*k') v (mul p k)
@[simp] theorem Poly.denote_mul (ctx : Context) (p : Poly) (k : Int) : (p.mul k).denote ctx = k * p.denote ctx := by
induction p <;> simp [mul, *]
rw [Int.mul_assoc]
/-- Divibility constraint of the form `k p`. -/
structure DvdCnstr where
k : Int
p : Poly
def DvdCnstr.denote (ctx : Context) (c : DvdCnstr) : Prop :=
c.k c.p.denote ctx
def DvdCnstr.isUnsat (c : DvdCnstr) : Bool :=
c.p.getConst % c.p.gcdCoeffs c.k != 0
def DvdCnstr.isEqv (c₁ c₂ : DvdCnstr) (k : Int) : Bool :=
k != 0 && c₁.k == k*c₂.k && c₁.p == c₂.p.mul k
def DvdCnstr.div (k' : Int) : DvdCnstr DvdCnstr
| { k, p } => { k := k / k', p := p.div k' }
private theorem not_dvd_of_not_mod_zero {a b : Int} (h : ¬ b % a = 0) : ¬ a b := by
intro h; have := Int.emod_eq_zero_of_dvd h; contradiction
def DvdCnstr.eq_false_of_isUnsat (ctx : Context) (c : DvdCnstr) : c.isUnsat c.denote ctx = False := by
rcases c with a, p
simp [isUnsat, denote]
intro h₁ h₂
have := Poly.gcd_dvd_const h₂
have := not_dvd_of_not_mod_zero h₁
contradiction
@[local simp] private theorem mul_dvd_mul_eq {a b c : Int} (hnz : a 0) : a * b a * c b c := by
constructor
· intro h
rcases h with k, h
rw [Int.mul_assoc a] at h
replace h := Int.eq_of_mul_eq_mul_left hnz h
exists k
· intro h
rcases h with k, h
exists k
rw [h, Int.mul_assoc]
@[local simp] theorem DvdCnstr.eq_of_isEqv (ctx : Context) (c₁ c₂ : DvdCnstr) (k : Int) (h : isEqv c₁ c₂ k) : c₁.denote ctx = c₂.denote ctx := by
rcases c₁ with a₁, e₁
rcases c₂ with a₂, e₂
simp [isEqv] at h
rcases h with h₁, h₂, h₃
replace h₃ := congrArg (Poly.denote ctx) h₃
simp at h₃
simp [denote, *]
/-- Raw divisibility constraint of the form `k e`. -/
structure RawDvdCnstr where
k : Int
e : Expr
deriving BEq
def RawDvdCnstr.denote (ctx : Context) (c : RawDvdCnstr) : Prop :=
c.k c.e.denote ctx
def RawDvdCnstr.norm (c : RawDvdCnstr) : DvdCnstr :=
{ k := c.k, p := c.e.toPoly }
@[simp] theorem RawDvdCnstr.denote_norm_eq (ctx : Context) (c : RawDvdCnstr) : c.denote ctx = c.norm.denote ctx := by
simp [norm, denote, DvdCnstr.denote]
def RawDvdCnstr.isEqv (c₁ c₂ : RawDvdCnstr) (k : Int) : Bool :=
c₁.norm.isEqv c₂.norm k
def RawDvdCnstr.isUnsat (c : RawDvdCnstr) : Bool :=
c.norm.isUnsat
theorem RawDvdCnstr.eq_of_isEqv (ctx : Context) (c₁ c₂ : RawDvdCnstr) (k : Int) (h : isEqv c₁ c₂ k) : c₁.denote ctx = c₂.denote ctx := by
simp [DvdCnstr.eq_of_isEqv ctx c₁.norm c₂.norm k h]
theorem RawDvdCnstr.eq_false_of_isUnsat (ctx : Context) (c : RawDvdCnstr) (h : c.isUnsat) : c.denote ctx = False := by
simp [DvdCnstr.eq_false_of_isUnsat ctx c.norm h]
end Int.Linear
theorem Int.not_le_eq (a b : Int) : (¬a b) = (b + 1 a) := by

View File

@@ -9,9 +9,9 @@ import Init.Meta
namespace Nat
protected theorem dvd_refl (a : Nat) : a a := 1, by simp
@[simp] protected theorem dvd_refl (a : Nat) : a a := 1, by simp
protected theorem dvd_zero (a : Nat) : a 0 := 0, by simp
@[simp] protected theorem dvd_zero (a : Nat) : a 0 := 0, by simp
protected theorem dvd_mul_left (a b : Nat) : a b * a := b, Nat.mul_comm b a
protected theorem dvd_mul_right (a b : Nat) : a a * b := b, rfl
@@ -129,4 +129,13 @@ protected theorem mul_div_assoc (m : Nat) (H : k n) : m * n / k = m * (n / k
have h1 : m * n / k = m * (n / k * k) / k := by rw [Nat.div_mul_cancel H]
rw [h1, Nat.mul_assoc, Nat.mul_div_cancel _ hpos]
/-! Helper theorems for `dvd` simproc -/
protected theorem dvd_eq_true_of_mod_eq_zero {m n : Nat} (h : n % m == 0) : (m n) = True := by
simp [Nat.dvd_of_mod_eq_zero, eq_of_beq h]
protected theorem dvd_eq_false_of_mod_ne_zero {m n : Nat} (h : n % m != 0) : (m n) = False := by
simp [eq_of_beq] at h
simp [dvd_iff_mod_eq_zero, h]
end Nat

View File

@@ -163,7 +163,7 @@ private def reprArray : Array String := Id.run do
private def reprFast (n : Nat) : String :=
if h : n < 128 then Nat.reprArray.get n h else
if h : n < USize.size then (USize.ofNatCore n h).repr
if h : n < USize.size then (USize.ofNatLT n h).repr
else (toDigits 10 n).asString
@[implemented_by reprFast]

View File

@@ -17,6 +17,7 @@ The type of signed 8-bit integers. This type has special support in the
compiler to make it actually 8 bits rather than wrapping a `Nat`.
-/
structure Int8 where
private ofUInt8 ::
/--
Obtain the `UInt8` that is 2's complement equivalent to the `Int8`.
-/
@@ -27,6 +28,7 @@ The type of signed 16-bit integers. This type has special support in the
compiler to make it actually 16 bits rather than wrapping a `Nat`.
-/
structure Int16 where
private ofUInt16 ::
/--
Obtain the `UInt16` that is 2's complement equivalent to the `Int16`.
-/
@@ -37,6 +39,7 @@ The type of signed 32-bit integers. This type has special support in the
compiler to make it actually 32 bits rather than wrapping a `Nat`.
-/
structure Int32 where
private ofUInt32 ::
/--
Obtain the `UInt32` that is 2's complement equivalent to the `Int32`.
-/
@@ -47,6 +50,7 @@ The type of signed 64-bit integers. This type has special support in the
compiler to make it actually 64 bits rather than wrapping a `Nat`.
-/
structure Int64 where
private ofUInt64 ::
/--
Obtain the `UInt64` that is 2's complement equivalent to the `Int64`.
-/
@@ -59,6 +63,7 @@ For example, if running on a 32-bit machine, ISize is equivalent to `Int32`.
Or on a 64-bit machine, `Int64`.
-/
structure ISize where
private ofUSize ::
/--
Obtain the `USize` that is 2's complement equivalent to the `ISize`.
-/
@@ -72,6 +77,10 @@ Obtain the `BitVec` that contains the 2's complement representation of the `Int8
-/
@[inline] def Int8.toBitVec (x : Int8) : BitVec 8 := x.toUInt8.toBitVec
/-- Obtains the `Int8` that is 2's complement equivalent to the `UInt8`. -/
@[inline] def UInt8.toInt8 (i : UInt8) : Int8 := Int8.ofUInt8 i
@[inline, deprecated UInt8.toInt8 (since := "2025-02-13"), inherit_doc UInt8.toInt8]
def Int8.mk (i : UInt8) : Int8 := UInt8.toInt8 i
@[extern "lean_int8_of_int"]
def Int8.ofInt (i : @& Int) : Int8 := BitVec.ofInt 8 i
@[extern "lean_int8_of_nat"]
@@ -84,7 +93,9 @@ def Int8.toInt (i : Int8) : Int := i.toBitVec.toInt
This function has the same behavior as `Int.toNat` for negative numbers.
If you want to obtain the 2's complement representation use `toBitVec`.
-/
@[inline] def Int8.toNat (i : Int8) : Nat := i.toInt.toNat
@[inline] def Int8.toNatClampNeg (i : Int8) : Nat := i.toInt.toNat
@[inline, deprecated Int8.toNatClampNeg (since := "2025-02-13"), inherit_doc Int8.toNatClampNeg]
def Int8.toNat (i : Int8) : Nat := i.toInt.toNat
/-- Obtains the `Int8` whose 2's complement representation is the given `BitVec 8`. -/
@[inline] def Int8.ofBitVec (b : BitVec 8) : Int8 := b
@[extern "lean_int8_neg"]
@@ -97,6 +108,24 @@ instance : OfNat Int8 n := ⟨Int8.ofNat n⟩
instance : Neg Int8 where
neg := Int8.neg
/-- The maximum value an `Int8` may attain, that is, `2^7 - 1 = 127`. -/
abbrev Int8.maxValue : Int8 := 127
/-- The minimum value an `Int8` may attain, that is, `-2^7 = -128`. -/
abbrev Int8.minValue : Int8 := -128
/-- Constructs an `Int8` from an `Int` which is known to be in bounds. -/
@[inline]
def Int8.ofIntLE (i : Int) (_hl : Int8.minValue.toInt i) (_hr : i Int8.maxValue.toInt) : Int8 :=
Int8.ofInt i
/-- Constructs an `Int8` from an `Int`, clamping if the value is too small or too large. -/
def Int8.ofIntTruncate (i : Int) : Int8 :=
if hl : Int8.minValue.toInt i then
if hr : i Int8.maxValue.toInt then
Int8.ofIntLE i hl hr
else
Int8.minValue
else
Int8.minValue
@[extern "lean_int8_add"]
def Int8.add (a b : Int8) : Int8 := a.toBitVec + b.toBitVec
@[extern "lean_int8_sub"]
@@ -174,6 +203,10 @@ Obtain the `BitVec` that contains the 2's complement representation of the `Int1
-/
@[inline] def Int16.toBitVec (x : Int16) : BitVec 16 := x.toUInt16.toBitVec
/-- Obtains the `Int16` that is 2's complement equivalent to the `UInt16`. -/
@[inline] def UInt16.toInt16 (i : UInt16) : Int16 := Int16.ofUInt16 i
@[inline, deprecated UInt16.toInt16 (since := "2025-02-13"), inherit_doc UInt16.toInt16]
def Int16.mk (i : UInt16) : Int16 := UInt16.toInt16 i
@[extern "lean_int16_of_int"]
def Int16.ofInt (i : @& Int) : Int16 := BitVec.ofInt 16 i
@[extern "lean_int16_of_nat"]
@@ -186,7 +219,9 @@ def Int16.toInt (i : Int16) : Int := i.toBitVec.toInt
This function has the same behavior as `Int.toNat` for negative numbers.
If you want to obtain the 2's complement representation use `toBitVec`.
-/
@[inline] def Int16.toNat (i : Int16) : Nat := i.toInt.toNat
@[inline] def Int16.toNatClampNeg (i : Int16) : Nat := i.toInt.toNat
@[inline, deprecated Int16.toNatClampNeg (since := "2025-02-13"), inherit_doc Int16.toNatClampNeg]
def Int16.toNat (i : Int16) : Nat := i.toInt.toNat
/-- Obtains the `Int16` whose 2's complement representation is the given `BitVec 16`. -/
@[inline] def Int16.ofBitVec (b : BitVec 16) : Int16 := b
@[extern "lean_int16_to_int8"]
@@ -203,6 +238,24 @@ instance : OfNat Int16 n := ⟨Int16.ofNat n⟩
instance : Neg Int16 where
neg := Int16.neg
/-- The maximum value an `Int16` may attain, that is, `2^15 - 1 = 32767`. -/
abbrev Int16.maxValue : Int16 := 32767
/-- The minimum value an `Int16` may attain, that is, `-2^15 = -32768`. -/
abbrev Int16.minValue : Int16 := -32768
/-- Constructs an `Int16` from an `Int` which is known to be in bounds. -/
@[inline]
def Int16.ofIntLE (i : Int) (_hl : Int16.minValue.toInt i) (_hr : i Int16.maxValue.toInt) : Int16 :=
Int16.ofInt i
/-- Constructs an `Int16` from an `Int`, clamping if the value is too small or too large. -/
def Int16.ofIntTruncate (i : Int) : Int16 :=
if hl : Int16.minValue.toInt i then
if hr : i Int16.maxValue.toInt then
Int16.ofIntLE i hl hr
else
Int16.minValue
else
Int16.minValue
@[extern "lean_int16_add"]
def Int16.add (a b : Int16) : Int16 := a.toBitVec + b.toBitVec
@[extern "lean_int16_sub"]
@@ -280,6 +333,10 @@ Obtain the `BitVec` that contains the 2's complement representation of the `Int3
-/
@[inline] def Int32.toBitVec (x : Int32) : BitVec 32 := x.toUInt32.toBitVec
/-- Obtains the `Int32` that is 2's complement equivalent to the `UInt32`. -/
@[inline] def UInt32.toInt32 (i : UInt32) : Int32 := Int32.ofUInt32 i
@[inline, deprecated UInt32.toInt32 (since := "2025-02-13"), inherit_doc UInt32.toInt32]
def Int32.mk (i : UInt32) : Int32 := UInt32.toInt32 i
@[extern "lean_int32_of_int"]
def Int32.ofInt (i : @& Int) : Int32 := BitVec.ofInt 32 i
@[extern "lean_int32_of_nat"]
@@ -292,7 +349,9 @@ def Int32.toInt (i : Int32) : Int := i.toBitVec.toInt
This function has the same behavior as `Int.toNat` for negative numbers.
If you want to obtain the 2's complement representation use `toBitVec`.
-/
@[inline] def Int32.toNat (i : Int32) : Nat := i.toInt.toNat
@[inline] def Int32.toNatClampNeg (i : Int32) : Nat := i.toInt.toNat
@[inline, deprecated Int32.toNatClampNeg (since := "2025-02-13"), inherit_doc Int32.toNatClampNeg]
def Int32.toNat (i : Int32) : Nat := i.toInt.toNat
/-- Obtains the `Int32` whose 2's complement representation is the given `BitVec 32`. -/
@[inline] def Int32.ofBitVec (b : BitVec 32) : Int32 := b
@[extern "lean_int32_to_int8"]
@@ -313,6 +372,24 @@ instance : OfNat Int32 n := ⟨Int32.ofNat n⟩
instance : Neg Int32 where
neg := Int32.neg
/-- The maximum value an `Int32` may attain, that is, `2^31 - 1 = 2147483647`. -/
abbrev Int32.maxValue : Int32 := 2147483647
/-- The minimum value an `Int32` may attain, that is, `-2^31 = -2147483648`. -/
abbrev Int32.minValue : Int32 := -2147483648
/-- Constructs an `Int32` from an `Int` which is known to be in bounds. -/
@[inline]
def Int32.ofIntLE (i : Int) (_hl : Int32.minValue.toInt i) (_hr : i Int32.maxValue.toInt) : Int32 :=
Int32.ofInt i
/-- Constructs an `Int32` from an `Int`, clamping if the value is too small or too large. -/
def Int32.ofIntTruncate (i : Int) : Int32 :=
if hl : Int32.minValue.toInt i then
if hr : i Int32.maxValue.toInt then
Int32.ofIntLE i hl hr
else
Int32.minValue
else
Int32.minValue
@[extern "lean_int32_add"]
def Int32.add (a b : Int32) : Int32 := a.toBitVec + b.toBitVec
@[extern "lean_int32_sub"]
@@ -390,6 +467,10 @@ Obtain the `BitVec` that contains the 2's complement representation of the `Int6
-/
@[inline] def Int64.toBitVec (x : Int64) : BitVec 64 := x.toUInt64.toBitVec
/-- Obtains the `Int64` that is 2's complement equivalent to the `UInt64`. -/
@[inline] def UInt64.toInt64 (i : UInt64) : Int64 := Int64.ofUInt64 i
@[inline, deprecated UInt64.toInt64 (since := "2025-02-13"), inherit_doc UInt64.toInt64]
def Int64.mk (i : UInt64) : Int64 := UInt64.toInt64 i
@[extern "lean_int64_of_int"]
def Int64.ofInt (i : @& Int) : Int64 := BitVec.ofInt 64 i
@[extern "lean_int64_of_nat"]
@@ -402,7 +483,9 @@ def Int64.toInt (i : Int64) : Int := i.toBitVec.toInt
This function has the same behavior as `Int.toNat` for negative numbers.
If you want to obtain the 2's complement representation use `toBitVec`.
-/
@[inline] def Int64.toNat (i : Int64) : Nat := i.toInt.toNat
@[inline] def Int64.toNatClampNeg (i : Int64) : Nat := i.toInt.toNat
@[inline, deprecated Int64.toNatClampNeg (since := "2025-02-13"), inherit_doc Int64.toNatClampNeg]
def Int64.toNat (i : Int64) : Nat := i.toInt.toNat
/-- Obtains the `Int64` whose 2's complement representation is the given `BitVec 64`. -/
@[inline] def Int64.ofBitVec (b : BitVec 64) : Int64 := b
@[extern "lean_int64_to_int8"]
@@ -427,6 +510,24 @@ instance : OfNat Int64 n := ⟨Int64.ofNat n⟩
instance : Neg Int64 where
neg := Int64.neg
/-- The maximum value an `Int64` may attain, that is, `2^63 - 1 = 9223372036854775807`. -/
abbrev Int64.maxValue : Int64 := 9223372036854775807
/-- The minimum value an `Int64` may attain, that is, `-2^63 = -9223372036854775808`. -/
abbrev Int64.minValue : Int64 := -9223372036854775808
/-- Constructs an `Int64` from an `Int` which is known to be in bounds. -/
@[inline]
def Int64.ofIntLE (i : Int) (_hl : Int64.minValue.toInt i) (_hr : i Int64.maxValue.toInt) : Int64 :=
Int64.ofInt i
/-- Constructs an `Int64` from an `Int`, clamping if the value is too small or too large. -/
def Int64.ofIntTruncate (i : Int) : Int64 :=
if hl : Int64.minValue.toInt i then
if hr : i Int64.maxValue.toInt then
Int64.ofIntLE i hl hr
else
Int64.minValue
else
Int64.minValue
@[extern "lean_int64_add"]
def Int64.add (a b : Int64) : Int64 := a.toBitVec + b.toBitVec
@[extern "lean_int64_sub"]
@@ -504,6 +605,10 @@ Obtain the `BitVec` that contains the 2's complement representation of the `ISiz
-/
@[inline] def ISize.toBitVec (x : ISize) : BitVec System.Platform.numBits := x.toUSize.toBitVec
/-- Obtains the `ISize` that is 2's complement equivalent to the `USize`. -/
@[inline] def USize.toISize (i : USize) : ISize := ISize.ofUSize i
@[inline, deprecated USize.toISize (since := "2025-02-13"), inherit_doc USize.toISize]
def ISize.mk (i : USize) : ISize := USize.toISize i
@[extern "lean_isize_of_int"]
def ISize.ofInt (i : @& Int) : ISize := BitVec.ofInt System.Platform.numBits i
@[extern "lean_isize_of_nat"]
@@ -516,18 +621,28 @@ def ISize.toInt (i : ISize) : Int := i.toBitVec.toInt
This function has the same behavior as `Int.toNat` for negative numbers.
If you want to obtain the 2's complement representation use `toBitVec`.
-/
@[inline] def ISize.toNat (i : ISize) : Nat := i.toInt.toNat
@[inline] def ISize.toNatClampNeg (i : ISize) : Nat := i.toInt.toNat
@[inline, deprecated ISize.toNatClampNeg (since := "2025-02-13"), inherit_doc ISize.toNatClampNeg]
def ISize.toNat (i : ISize) : Nat := i.toInt.toNat
/-- Obtains the `ISize` whose 2's complement representation is the given `BitVec`. -/
@[inline] def ISize.ofBitVec (b : BitVec System.Platform.numBits) : ISize := b
@[extern "lean_isize_to_int8"]
def ISize.toInt8 (a : ISize) : Int8 := a.toBitVec.signExtend 8
@[extern "lean_isize_to_int16"]
def ISize.toInt16 (a : ISize) : Int16 := a.toBitVec.signExtend 16
@[extern "lean_isize_to_int32"]
def ISize.toInt32 (a : ISize) : Int32 := a.toBitVec.signExtend 32
/--
Upcast `ISize` to `Int64`. This function is losless as `ISize` is either `Int32` or `Int64`.
Upcasts `ISize` to `Int64`. This function is lossless as `ISize` is either `Int32` or `Int64`.
-/
@[extern "lean_isize_to_int64"]
def ISize.toInt64 (a : ISize) : Int64 := a.toBitVec.signExtend 64
@[extern "lean_int8_to_isize"]
def Int8.toISize (a : Int8) : ISize := a.toBitVec.signExtend System.Platform.numBits
@[extern "lean_int16_to_isize"]
def Int16.toISize (a : Int16) : ISize := a.toBitVec.signExtend System.Platform.numBits
/--
Upcast `Int32` to `ISize`. This function is losless as `ISize` is either `Int32` or `Int64`.
Upcasts `Int32` to `ISize`. This function is lossless as `ISize` is either `Int32` or `Int64`.
-/
@[extern "lean_int32_to_isize"]
def Int32.toISize (a : Int32) : ISize := a.toBitVec.signExtend System.Platform.numBits
@@ -543,6 +658,26 @@ instance : OfNat ISize n := ⟨ISize.ofNat n⟩
instance : Neg ISize where
neg := ISize.neg
/-- The maximum value an `ISize` may attain, that is, `2^(System.Platform.numBits - 1) - 1`. -/
abbrev ISize.maxValue : ISize := .ofInt (2 ^ (System.Platform.numBits - 1) - 1)
-- 9223372036854775807
/-- The minimum value an `ISize` may attain, that is, `-2^(System.Platform.numBits - 1)`. -/
abbrev ISize.minValue : ISize := .ofInt (2 ^ (System.Platform.numBits - 1))
/-- Constructs an `ISize` from an `Int` which is known to be in bounds. -/
@[inline]
def ISize.ofIntLE (i : Int) (_hl : ISize.minValue.toInt i) (_hr : i ISize.maxValue.toInt) : ISize :=
ISize.ofInt i
/-- Constructs an `ISize` from an `Int`, clamping if the value is too small or too large. -/
def ISize.ofIntTruncate (i : Int) : ISize :=
if hl : ISize.minValue.toInt i then
if hr : i ISize.maxValue.toInt then
ISize.ofIntLE i hl hr
else
ISize.minValue
else
ISize.minValue
@[extern "lean_isize_add"]
def ISize.add (a b : ISize) : ISize := a.toBitVec + b.toBitVec
@[extern "lean_isize_sub"]

View File

@@ -9,9 +9,14 @@ import Init.Data.BitVec.Basic
open Nat
/-- Converts a `Fin UInt8.size` into the corresponding `UInt8`. -/
@[inline] def UInt8.ofFin (a : Fin UInt8.size) : UInt8 := a
@[deprecated UInt8.ofBitVec (since := "2025-02-12"), inherit_doc UInt8.ofBitVec]
def UInt8.mk (bitVec : BitVec 8) : UInt8 :=
UInt8.ofBitVec bitVec
@[inline, deprecated UInt8.ofNatLT (since := "2025-02-13"), inherit_doc UInt8.ofNatLT]
def UInt8.ofNatCore (n : Nat) (h : n < UInt8.size) : UInt8 :=
UInt8.ofNatLT n h
@[extern "lean_uint8_add"]
def UInt8.add (a b : UInt8) : UInt8 := a.toBitVec + b.toBitVec
@@ -24,7 +29,7 @@ def UInt8.div (a b : UInt8) : UInt8 := ⟨BitVec.udiv a.toBitVec b.toBitVec⟩
@[extern "lean_uint8_mod"]
def UInt8.mod (a b : UInt8) : UInt8 := BitVec.umod a.toBitVec b.toBitVec
@[deprecated UInt8.mod (since := "2024-09-23")]
def UInt8.modn (a : UInt8) (n : Nat) : UInt8 := Fin.modn a.val n
def UInt8.modn (a : UInt8) (n : Nat) : UInt8 := Fin.modn a.toFin n
@[extern "lean_uint8_land"]
def UInt8.land (a b : UInt8) : UInt8 := a.toBitVec &&& b.toBitVec
@[extern "lean_uint8_lor"]
@@ -76,9 +81,14 @@ instance (a b : UInt8) : Decidable (a ≤ b) := UInt8.decLe a b
instance : Max UInt8 := maxOfLe
instance : Min UInt8 := minOfLe
/-- Converts a `Fin UInt16.size` into the corresponding `UInt16`. -/
@[inline] def UInt16.ofFin (a : Fin UInt16.size) : UInt16 := a
@[deprecated UInt16.ofBitVec (since := "2025-02-12"), inherit_doc UInt16.ofBitVec]
def UInt16.mk (bitVec : BitVec 16) : UInt16 :=
UInt16.ofBitVec bitVec
@[inline, deprecated UInt16.ofNatLT (since := "2025-02-13"), inherit_doc UInt16.ofNatLT]
def UInt16.ofNatCore (n : Nat) (h : n < UInt16.size) : UInt16 :=
UInt16.ofNatLT n h
@[extern "lean_uint16_add"]
def UInt16.add (a b : UInt16) : UInt16 := a.toBitVec + b.toBitVec
@@ -91,7 +101,7 @@ def UInt16.div (a b : UInt16) : UInt16 := ⟨BitVec.udiv a.toBitVec b.toBitVec
@[extern "lean_uint16_mod"]
def UInt16.mod (a b : UInt16) : UInt16 := BitVec.umod a.toBitVec b.toBitVec
@[deprecated UInt16.mod (since := "2024-09-23")]
def UInt16.modn (a : UInt16) (n : Nat) : UInt16 := Fin.modn a.val n
def UInt16.modn (a : UInt16) (n : Nat) : UInt16 := Fin.modn a.toFin n
@[extern "lean_uint16_land"]
def UInt16.land (a b : UInt16) : UInt16 := a.toBitVec &&& b.toBitVec
@[extern "lean_uint16_lor"]
@@ -145,9 +155,14 @@ instance (a b : UInt16) : Decidable (a ≤ b) := UInt16.decLe a b
instance : Max UInt16 := maxOfLe
instance : Min UInt16 := minOfLe
/-- Converts a `Fin UInt32.size` into the corresponding `UInt32`. -/
@[inline] def UInt32.ofFin (a : Fin UInt32.size) : UInt32 := a
@[deprecated UInt32.ofBitVec (since := "2025-02-12"), inherit_doc UInt32.ofBitVec]
def UInt32.mk (bitVec : BitVec 32) : UInt32 :=
UInt32.ofBitVec bitVec
@[inline, deprecated UInt32.ofNatLT (since := "2025-02-13"), inherit_doc UInt32.ofNatLT]
def UInt32.ofNatCore (n : Nat) (h : n < UInt32.size) : UInt32 :=
UInt32.ofNatLT n h
@[extern "lean_uint32_add"]
def UInt32.add (a b : UInt32) : UInt32 := a.toBitVec + b.toBitVec
@@ -160,7 +175,7 @@ def UInt32.div (a b : UInt32) : UInt32 := ⟨BitVec.udiv a.toBitVec b.toBitVec
@[extern "lean_uint32_mod"]
def UInt32.mod (a b : UInt32) : UInt32 := BitVec.umod a.toBitVec b.toBitVec
@[deprecated UInt32.mod (since := "2024-09-23")]
def UInt32.modn (a : UInt32) (n : Nat) : UInt32 := Fin.modn a.val n
def UInt32.modn (a : UInt32) (n : Nat) : UInt32 := Fin.modn a.toFin n
@[extern "lean_uint32_land"]
def UInt32.land (a b : UInt32) : UInt32 := a.toBitVec &&& b.toBitVec
@[extern "lean_uint32_lor"]
@@ -199,9 +214,14 @@ instance : ShiftRight UInt32 := ⟨UInt32.shiftRight⟩
@[extern "lean_bool_to_uint32"]
def Bool.toUInt32 (b : Bool) : UInt32 := if b then 1 else 0
/-- Converts a `Fin UInt64.size` into the corresponding `UInt64`. -/
@[inline] def UInt64.ofFin (a : Fin UInt64.size) : UInt64 := a
@[deprecated UInt64.ofBitVec (since := "2025-02-12"), inherit_doc UInt64.ofBitVec]
def UInt64.mk (bitVec : BitVec 64) : UInt64 :=
UInt64.ofBitVec bitVec
@[inline, deprecated UInt64.ofNatLT (since := "2025-02-13"), inherit_doc UInt64.ofNatLT]
def UInt64.ofNatCore (n : Nat) (h : n < UInt64.size) : UInt64 :=
UInt64.ofNatLT n h
@[extern "lean_uint64_add"]
def UInt64.add (a b : UInt64) : UInt64 := a.toBitVec + b.toBitVec
@@ -214,7 +234,7 @@ def UInt64.div (a b : UInt64) : UInt64 := ⟨BitVec.udiv a.toBitVec b.toBitVec
@[extern "lean_uint64_mod"]
def UInt64.mod (a b : UInt64) : UInt64 := BitVec.umod a.toBitVec b.toBitVec
@[deprecated UInt64.mod (since := "2024-09-23")]
def UInt64.modn (a : UInt64) (n : Nat) : UInt64 := Fin.modn a.val n
def UInt64.modn (a : UInt64) (n : Nat) : UInt64 := Fin.modn a.toFin n
@[extern "lean_uint64_land"]
def UInt64.land (a b : UInt64) : UInt64 := a.toBitVec &&& b.toBitVec
@[extern "lean_uint64_lor"]
@@ -266,9 +286,14 @@ instance (a b : UInt64) : Decidable (a ≤ b) := UInt64.decLe a b
instance : Max UInt64 := maxOfLe
instance : Min UInt64 := minOfLe
/-- Converts a `Fin USize.size` into the corresponding `USize`. -/
@[inline] def USize.ofFin (a : Fin USize.size) : USize := a
@[deprecated USize.ofBitVec (since := "2025-02-12"), inherit_doc USize.ofBitVec]
def USize.mk (bitVec : BitVec System.Platform.numBits) : USize :=
USize.ofBitVec bitVec
@[inline, deprecated USize.ofNatLT (since := "2025-02-13"), inherit_doc USize.ofNatLT]
def USize.ofNatCore (n : Nat) (h : n < USize.size) : USize :=
USize.ofNatLT n h
theorem usize_size_le : USize.size 18446744073709551616 := by
cases usize_size_eq <;> next h => rw [h]; decide
@@ -283,7 +308,7 @@ def USize.div (a b : USize) : USize := ⟨a.toBitVec / b.toBitVec⟩
@[extern "lean_usize_mod"]
def USize.mod (a b : USize) : USize := a.toBitVec % b.toBitVec
@[deprecated USize.mod (since := "2024-09-23")]
def USize.modn (a : USize) (n : Nat) : USize := Fin.modn a.val n
def USize.modn (a : USize) (n : Nat) : USize := Fin.modn a.toFin n
@[extern "lean_usize_land"]
def USize.land (a b : USize) : USize := a.toBitVec &&& b.toBitVec
@[extern "lean_usize_lor"]
@@ -301,7 +326,7 @@ This function is overridden with a native implementation.
-/
@[extern "lean_usize_of_nat"]
def USize.ofNat32 (n : @& Nat) (h : n < 4294967296) : USize :=
USize.ofNatCore n (Nat.lt_of_lt_of_le h le_usize_size)
USize.ofNatLT n (Nat.lt_of_lt_of_le h le_usize_size)
@[extern "lean_uint8_to_usize"]
def UInt8.toUSize (a : UInt8) : USize :=
USize.ofNat32 a.toBitVec.toNat (Nat.lt_trans a.toBitVec.isLt (by decide))
@@ -326,7 +351,7 @@ This function is overridden with a native implementation.
-/
@[extern "lean_usize_to_uint64"]
def USize.toUInt64 (a : USize) : UInt64 :=
UInt64.ofNatCore a.toBitVec.toNat (Nat.lt_of_lt_of_le a.toBitVec.isLt usize_size_le)
UInt64.ofNatLT a.toBitVec.toNat (Nat.lt_of_lt_of_le a.toBitVec.isLt usize_size_le)
instance : Mul USize := USize.mul
instance : Mod USize := USize.mod

View File

@@ -16,18 +16,40 @@ This file thus breaks the import cycle that would be created by this dependency.
open Nat
def UInt8.val (x : UInt8) : Fin UInt8.size := x.toBitVec.toFin
/-- Converts a `UInt8` into the corresponding `Fin UInt8.size`. -/
def UInt8.toFin (x : UInt8) : Fin UInt8.size := x.toBitVec.toFin
@[deprecated UInt8.toFin (since := "2025-02-12"), inherit_doc UInt8.toFin]
def UInt8.val (x : UInt8) : Fin UInt8.size := x.toFin
@[extern "lean_uint8_of_nat"]
def UInt8.ofNat (n : @& Nat) : UInt8 := BitVec.ofNat 8 n
/--
Converts the given natural number to `UInt8`, but returns `2^8 - 1` for natural numbers `>= 2^8`.
-/
def UInt8.ofNatTruncate (n : Nat) : UInt8 :=
if h : n < UInt8.size then
UInt8.ofNatLT n h
else
UInt8.ofNatLT (UInt8.size - 1) (by decide)
abbrev Nat.toUInt8 := UInt8.ofNat
@[extern "lean_uint8_to_nat"]
def UInt8.toNat (n : UInt8) : Nat := n.toBitVec.toNat
instance UInt8.instOfNat : OfNat UInt8 n := UInt8.ofNat n
def UInt16.val (x : UInt16) : Fin UInt16.size := x.toBitVec.toFin
/-- Converts a `UInt16` into the corresponding `Fin UInt16.size`. -/
def UInt16.toFin (x : UInt16) : Fin UInt16.size := x.toBitVec.toFin
@[deprecated UInt16.toFin (since := "2025-02-12"), inherit_doc UInt16.toFin]
def UInt16.val (x : UInt16) : Fin UInt16.size := x.toFin
@[extern "lean_uint16_of_nat"]
def UInt16.ofNat (n : @& Nat) : UInt16 := BitVec.ofNat 16 n
/--
Converts the given natural number to `UInt16`, but returns `2^16 - 1` for natural numbers `>= 2^16`.
-/
def UInt16.ofNatTruncate (n : Nat) : UInt16 :=
if h : n < UInt16.size then
UInt16.ofNatLT n h
else
UInt16.ofNatLT (UInt16.size - 1) (by decide)
abbrev Nat.toUInt16 := UInt16.ofNat
@[extern "lean_uint16_to_nat"]
def UInt16.toNat (n : UInt16) : Nat := n.toBitVec.toNat
@@ -38,19 +60,22 @@ def UInt8.toUInt16 (a : UInt8) : UInt16 := ⟨⟨a.toNat, Nat.lt_trans a.toBitVe
instance UInt16.instOfNat : OfNat UInt16 n := UInt16.ofNat n
def UInt32.val (x : UInt32) : Fin UInt32.size := x.toBitVec.toFin
/-- Converts a `UInt32` into the corresponding `Fin UInt32.size`. -/
def UInt32.toFin (x : UInt32) : Fin UInt32.size := x.toBitVec.toFin
@[deprecated UInt32.toFin (since := "2025-02-12"), inherit_doc UInt32.toFin]
def UInt32.val (x : UInt32) : Fin UInt32.size := x.toFin
@[extern "lean_uint32_of_nat"]
def UInt32.ofNat (n : @& Nat) : UInt32 := BitVec.ofNat 32 n
@[extern "lean_uint32_of_nat"]
def UInt32.ofNat' (n : Nat) (h : n < UInt32.size) : UInt32 := BitVec.ofNatLt n h
@[inline, deprecated UInt32.ofNatLT (since := "2025-02-13"), inherit_doc UInt32.ofNatLT]
def UInt32.ofNat' (n : Nat) (h : n < UInt32.size) : UInt32 := UInt32.ofNatLT n h
/--
Converts the given natural number to `UInt32`, but returns `2^32 - 1` for natural numbers `>= 2^32`.
-/
def UInt32.ofNatTruncate (n : Nat) : UInt32 :=
if h : n < UInt32.size then
UInt32.ofNat' n h
UInt32.ofNatLT n h
else
UInt32.ofNat' (UInt32.size - 1) (by decide)
UInt32.ofNatLT (UInt32.size - 1) (by decide)
abbrev Nat.toUInt32 := UInt32.ofNat
@[extern "lean_uint32_to_uint8"]
def UInt32.toUInt8 (a : UInt32) : UInt8 := a.toNat.toUInt8
@@ -63,19 +88,38 @@ def UInt16.toUInt32 (a : UInt16) : UInt32 := ⟨⟨a.toNat, Nat.lt_trans a.toBit
instance UInt32.instOfNat : OfNat UInt32 n := UInt32.ofNat n
theorem UInt32.ofNatLT_lt_of_lt {n m : Nat} (h1 : n < UInt32.size) (h2 : m < UInt32.size) :
n < m UInt32.ofNatLT n h1 < UInt32.ofNat m := by
simp only [(· < ·), BitVec.toNat, ofNatLT, BitVec.ofNatLT, ofNat, BitVec.ofNat, Fin.ofNat',
Nat.mod_eq_of_lt h2, imp_self]
@[deprecated UInt32.ofNatLT_lt_of_lt (since := "2025-02-13")]
theorem UInt32.ofNat'_lt_of_lt {n m : Nat} (h1 : n < UInt32.size) (h2 : m < UInt32.size) :
n < m UInt32.ofNat' n h1 < UInt32.ofNat m := by
simp only [(· < ·), BitVec.toNat, ofNat', BitVec.ofNatLt, ofNat, BitVec.ofNat, Fin.ofNat',
n < m UInt32.ofNatLT n h1 < UInt32.ofNat m := UInt32.ofNatLT_lt_of_lt h1 h2
theorem UInt32.lt_ofNatLT_of_lt {n m : Nat} (h1 : n < UInt32.size) (h2 : m < UInt32.size) :
m < n UInt32.ofNat m < UInt32.ofNatLT n h1 := by
simp only [(· < ·), BitVec.toNat, ofNatLT, BitVec.ofNatLT, ofNat, BitVec.ofNat, Fin.ofNat',
Nat.mod_eq_of_lt h2, imp_self]
@[deprecated UInt32.lt_ofNatLT_of_lt (since := "2025-02-13")]
theorem UInt32.lt_ofNat'_of_lt {n m : Nat} (h1 : n < UInt32.size) (h2 : m < UInt32.size) :
m < n UInt32.ofNat m < UInt32.ofNat' n h1 := by
simp only [(· < ·), BitVec.toNat, ofNat', BitVec.ofNatLt, ofNat, BitVec.ofNat, Fin.ofNat',
Nat.mod_eq_of_lt h2, imp_self]
m < n UInt32.ofNat m < UInt32.ofNatLT n h1 := UInt32.lt_ofNatLT_of_lt h1 h2
def UInt64.val (x : UInt64) : Fin UInt64.size := x.toBitVec.toFin
/-- Converts a `UInt64` into the corresponding `Fin UInt64.size`. -/
def UInt64.toFin (x : UInt64) : Fin UInt64.size := x.toBitVec.toFin
@[deprecated UInt64.toFin (since := "2025-02-12"), inherit_doc UInt64.toFin]
def UInt64.val (x : UInt64) : Fin UInt64.size := x.toFin
@[extern "lean_uint64_of_nat"]
def UInt64.ofNat (n : @& Nat) : UInt64 := BitVec.ofNat 64 n
/--
Converts the given natural number to `UInt64`, but returns `2^64 - 1` for natural numbers `>= 2^64`.
-/
def UInt64.ofNatTruncate (n : Nat) : UInt64 :=
if h : n < UInt64.size then
UInt64.ofNatLT n h
else
UInt64.ofNatLT (UInt64.size - 1) (by decide)
abbrev Nat.toUInt64 := UInt64.ofNat
@[extern "lean_uint64_to_nat"]
def UInt64.toNat (n : UInt64) : Nat := n.toBitVec.toNat
@@ -97,9 +141,21 @@ instance UInt64.instOfNat : OfNat UInt64 n := ⟨UInt64.ofNat n⟩
@[deprecated usize_size_pos (since := "2024-11-24")] theorem usize_size_gt_zero : USize.size > 0 :=
usize_size_pos
def USize.val (x : USize) : Fin USize.size := x.toBitVec.toFin
/-- Converts a `USize` into the corresponding `Fin USize.size`. -/
def USize.toFin (x : USize) : Fin USize.size := x.toBitVec.toFin
@[deprecated USize.toFin (since := "2025-02-12"), inherit_doc USize.toFin]
def USize.val (x : USize) : Fin USize.size := x.toFin
@[extern "lean_usize_of_nat"]
def USize.ofNat (n : @& Nat) : USize := BitVec.ofNat _ n
/--
Converts the given natural number to `USize`, but returns `USize.size - 1` (i.e., `2^64 - 1` or
`2^32 - 1` depending on the platform) for natural numbers `>= USize.size`.
-/
def USize.ofNatTruncate (n : Nat) : USize :=
if h : n < USize.size then
USize.ofNatLT n h
else
USize.ofNatLT (USize.size - 1) (Nat.pred_lt (Nat.ne_zero_of_lt usize_size_pos))
abbrev Nat.toUSize := USize.ofNat
@[extern "lean_usize_to_nat"]
def USize.toNat (n : USize) : Nat := n.toBitVec.toNat

View File

@@ -29,9 +29,14 @@ macro "declare_uint_theorems" typeName:ident bits:term:arg : command => do
@[simp] theorem toNat_ofNat {n : Nat} : (ofNat n).toNat = n % 2 ^ $bits := BitVec.toNat_ofNat ..
@[simp] theorem toNat_ofNatCore {n : Nat} {h : n < size} : (ofNatCore n h).toNat = n := BitVec.toNat_ofNatLt ..
@[simp] theorem toNat_ofNatLT {n : Nat} {h : n < size} : (ofNatLT n h).toNat = n := BitVec.toNat_ofNatLT ..
@[simp] theorem val_val_eq_toNat (x : $typeName) : x.val.val = x.toNat := rfl
@[deprecated toNat_ofNatLT (since := "2025-02-13")]
theorem toNat_ofNatCore {n : Nat} {h : n < size} : (ofNatLT n h).toNat = n := BitVec.toNat_ofNatLT ..
@[simp] theorem toFin_val_eq_toNat (x : $typeName) : x.toFin.val = x.toNat := rfl
@[deprecated toFin_val_eq_toNat (since := "2025-02-12")]
theorem val_val_eq_toNat (x : $typeName) : x.toFin.val = x.toNat := rfl
theorem toNat_toBitVec_eq_toNat (x : $typeName) : x.toBitVec.toNat = x.toNat := rfl
@@ -86,13 +91,21 @@ macro "declare_uint_theorems" typeName:ident bits:term:arg : command => do
protected theorem eq_iff_toBitVec_eq {a b : $typeName} : a = b a.toBitVec = b.toBitVec :=
Iff.intro toBitVec_eq_of_eq eq_of_toBitVec_eq
open $typeName (eq_of_toBitVec_eq) in
protected theorem eq_of_val_eq {a b : $typeName} (h : a.val = b.val) : a = b := by
rcases a with _; rcases b with _; simp_all [val]
open $typeName (eq_of_toBitVec_eq toFin) in
protected theorem eq_of_toFin_eq {a b : $typeName} (h : a.toFin = b.toFin) : a = b := by
rcases a with _; rcases b with _; simp_all [toFin]
open $typeName (eq_of_toFin_eq) in
@[deprecated eq_of_toFin_eq (since := "2025-02-12")]
protected theorem eq_of_val_eq {a b : $typeName} (h : a.toFin = b.toFin) : a = b :=
eq_of_toFin_eq h
open $typeName (eq_of_val_eq) in
protected theorem val_inj {a b : $typeName} : a.val = b.val a = b :=
Iff.intro eq_of_val_eq (congrArg val)
open $typeName (eq_of_toFin_eq) in
protected theorem toFin_inj {a b : $typeName} : a.toFin = b.toFin a = b :=
Iff.intro eq_of_toFin_eq (congrArg toFin)
open $typeName (toFin_inj) in
@[deprecated toFin_inj (since := "2025-02-12")]
protected theorem val_inj {a b : $typeName} : a.toFin = b.toFin a = b :=
toFin_inj
open $typeName (eq_of_toBitVec_eq) in
protected theorem toBitVec_ne_of_ne {a b : $typeName} (h : a b) : a.toBitVec b.toBitVec :=
@@ -178,7 +191,9 @@ macro "declare_uint_theorems" typeName:ident bits:term:arg : command => do
simp [Nat.mod_eq_of_lt x.toNat_lt_size]
@[simp]
theorem val_ofNat (n : Nat) : val (no_index (OfNat.ofNat n)) = OfNat.ofNat n := rfl
theorem toFin_ofNat (n : Nat) : toFin (no_index (OfNat.ofNat n)) = OfNat.ofNat n := rfl
@[deprecated toFin_ofNat (since := "2025-02-12")]
theorem val_ofNat (n : Nat) : toFin (no_index (OfNat.ofNat n)) = OfNat.ofNat n := rfl
@[simp, int_toBitVec]
theorem toBitVec_ofNat (n : Nat) : toBitVec (no_index (OfNat.ofNat n)) = BitVec.ofNat _ n := rfl

View File

@@ -7,16 +7,16 @@ prelude
import Init.Data.Fin.Log2
@[extern "lean_uint8_log2"]
def UInt8.log2 (a : UInt8) : UInt8 := Fin.log2 a.val
def UInt8.log2 (a : UInt8) : UInt8 := Fin.log2 a.toFin
@[extern "lean_uint16_log2"]
def UInt16.log2 (a : UInt16) : UInt16 := Fin.log2 a.val
def UInt16.log2 (a : UInt16) : UInt16 := Fin.log2 a.toFin
@[extern "lean_uint32_log2"]
def UInt32.log2 (a : UInt32) : UInt32 := Fin.log2 a.val
def UInt32.log2 (a : UInt32) : UInt32 := Fin.log2 a.toFin
@[extern "lean_uint64_log2"]
def UInt64.log2 (a : UInt64) : UInt64 := Fin.log2 a.val
def UInt64.log2 (a : UInt64) : UInt64 := Fin.log2 a.toFin
@[extern "lean_usize_log2"]
def USize.log2 (a : USize) : USize := Fin.log2 a.val
def USize.log2 (a : USize) : USize := Fin.log2 a.toFin

View File

@@ -756,6 +756,13 @@ This is mostly useful for debugging info trees.
syntax (name := infoTreesCmd)
"#info_trees" " in" ppLine command : command
/--
Specify a premise selection engine.
Note that Lean does not ship a default premise selection engine,
so this is only useful in conjunction with a downstream package which provides one.
-/
syntax (name := setPremiseSelectorCmd)
"set_premise_selector" term : command
namespace Parser

View File

@@ -303,7 +303,7 @@ theorem dvd_gcd (xs : IntList) (c : Nat) (w : ∀ {a : Int}, a ∈ xs → (c : I
c xs.gcd := by
simp only [Int.ofNat_dvd_left] at w
induction xs with
| nil => have := Nat.dvd_zero c; simp at this; exact this
| nil => have := Nat.dvd_zero c; simp
| cons x xs ih =>
simp
apply Nat.dvd_gcd

View File

@@ -1904,7 +1904,7 @@ instance : DecidableEq (BitVec n) := BitVec.decEq
/-- The `BitVec` with value `i`, given a proof that `i < 2^n`. -/
@[match_pattern]
protected def BitVec.ofNatLt {n : Nat} (i : Nat) (p : LT.lt i (hPow 2 n)) : BitVec n where
protected def BitVec.ofNatLT {n : Nat} (i : Nat) (p : LT.lt i (hPow 2 n)) : BitVec n where
toFin := i, p
/-- Given a bitvector `x`, return the underlying `Nat`. This is O(1) because `BitVec` is a
@@ -1939,21 +1939,13 @@ structure UInt8 where
attribute [extern "lean_uint8_of_nat_mk"] UInt8.ofBitVec
attribute [extern "lean_uint8_to_nat"] UInt8.toBitVec
/--
Pack a `Nat` less than `2^8` into a `UInt8`.
This function is overridden with a native implementation.
-/
@[extern "lean_uint8_of_nat"]
def UInt8.ofNatCore (n : @& Nat) (h : LT.lt n UInt8.size) : UInt8 where
toBitVec := BitVec.ofNatLt n h
/--
Pack a `Nat` less than `2^8` into a `UInt8`.
This function is overridden with a native implementation.
-/
@[extern "lean_uint8_of_nat"]
def UInt8.ofNatLT (n : @& Nat) (h : LT.lt n UInt8.size) : UInt8 where
toBitVec := BitVec.ofNatLt n h
toBitVec := BitVec.ofNatLT n h
set_option bootstrap.genMatcherCode false in
/--
@@ -1971,7 +1963,7 @@ def UInt8.decEq (a b : UInt8) : Decidable (Eq a b) :=
instance : DecidableEq UInt8 := UInt8.decEq
instance : Inhabited UInt8 where
default := UInt8.ofNatCore 0 (of_decide_eq_true rfl)
default := UInt8.ofNatLT 0 (of_decide_eq_true rfl)
/-- The size of type `UInt16`, that is, `2^16 = 65536`. -/
abbrev UInt16.size : Nat := 65536
@@ -1993,21 +1985,13 @@ structure UInt16 where
attribute [extern "lean_uint16_of_nat_mk"] UInt16.ofBitVec
attribute [extern "lean_uint16_to_nat"] UInt16.toBitVec
/--
Pack a `Nat` less than `2^16` into a `UInt16`.
This function is overridden with a native implementation.
-/
@[extern "lean_uint16_of_nat"]
def UInt16.ofNatCore (n : @& Nat) (h : LT.lt n UInt16.size) : UInt16 where
toBitVec := BitVec.ofNatLt n h
/--
Pack a `Nat` less than `2^16` into a `UInt16`.
This function is overridden with a native implementation.
-/
@[extern "lean_uint16_of_nat"]
def UInt16.ofNatLT (n : @& Nat) (h : LT.lt n UInt16.size) : UInt16 where
toBitVec := BitVec.ofNatLt n h
toBitVec := BitVec.ofNatLT n h
set_option bootstrap.genMatcherCode false in
/--
@@ -2025,7 +2009,7 @@ def UInt16.decEq (a b : UInt16) : Decidable (Eq a b) :=
instance : DecidableEq UInt16 := UInt16.decEq
instance : Inhabited UInt16 where
default := UInt16.ofNatCore 0 (of_decide_eq_true rfl)
default := UInt16.ofNatLT 0 (of_decide_eq_true rfl)
/-- The size of type `UInt32`, that is, `2^32 = 4294967296`. -/
abbrev UInt32.size : Nat := 4294967296
@@ -2047,21 +2031,13 @@ structure UInt32 where
attribute [extern "lean_uint32_of_nat_mk"] UInt32.ofBitVec
attribute [extern "lean_uint32_to_nat"] UInt32.toBitVec
/--
Pack a `Nat` less than `2^32` into a `UInt32`.
This function is overridden with a native implementation.
-/
@[extern "lean_uint32_of_nat"]
def UInt32.ofNatCore (n : @& Nat) (h : LT.lt n UInt32.size) : UInt32 where
toBitVec := BitVec.ofNatLt n h
/--
Pack a `Nat` less than `2^32` into a `UInt32`.
This function is overridden with a native implementation.
-/
@[extern "lean_uint32_of_nat"]
def UInt32.ofNatLT (n : @& Nat) (h : LT.lt n UInt32.size) : UInt32 where
toBitVec := BitVec.ofNatLt n h
toBitVec := BitVec.ofNatLT n h
/--
Unpack a `UInt32` as a `Nat`.
@@ -2084,7 +2060,7 @@ def UInt32.decEq (a b : UInt32) : Decidable (Eq a b) :=
instance : DecidableEq UInt32 := UInt32.decEq
instance : Inhabited UInt32 where
default := UInt32.ofNatCore 0 (of_decide_eq_true rfl)
default := UInt32.ofNatLT 0 (of_decide_eq_true rfl)
instance : LT UInt32 where
lt a b := LT.lt a.toBitVec b.toBitVec
@@ -2132,21 +2108,13 @@ structure UInt64 where
attribute [extern "lean_uint64_of_nat_mk"] UInt64.ofBitVec
attribute [extern "lean_uint64_to_nat"] UInt64.toBitVec
/--
Pack a `Nat` less than `2^64` into a `UInt64`.
This function is overridden with a native implementation.
-/
@[extern "lean_uint64_of_nat"]
def UInt64.ofNatCore (n : @& Nat) (h : LT.lt n UInt64.size) : UInt64 where
toBitVec := BitVec.ofNatLt n h
/--
Pack a `Nat` less than `2^64` into a `UInt64`.
This function is overridden with a native implementation.
-/
@[extern "lean_uint64_of_nat"]
def UInt64.ofNatLT (n : @& Nat) (h : LT.lt n UInt64.size) : UInt64 where
toBitVec := BitVec.ofNatLt n h
toBitVec := BitVec.ofNatLT n h
set_option bootstrap.genMatcherCode false in
/--
@@ -2164,7 +2132,7 @@ def UInt64.decEq (a b : UInt64) : Decidable (Eq a b) :=
instance : DecidableEq UInt64 := UInt64.decEq
instance : Inhabited UInt64 where
default := UInt64.ofNatCore 0 (of_decide_eq_true rfl)
default := UInt64.ofNatLT 0 (of_decide_eq_true rfl)
/-- The size of type `USize`, that is, `2^System.Platform.numBits`. -/
abbrev USize.size : Nat := (hPow 2 System.Platform.numBits)
@@ -2202,21 +2170,13 @@ structure USize where
attribute [extern "lean_usize_of_nat_mk"] USize.ofBitVec
attribute [extern "lean_usize_to_nat"] USize.toBitVec
/--
Pack a `Nat` less than `USize.size` into a `USize`.
This function is overridden with a native implementation.
-/
@[extern "lean_usize_of_nat"]
def USize.ofNatCore (n : @& Nat) (h : LT.lt n USize.size) : USize where
toBitVec := BitVec.ofNatLt n h
/--
Pack a `Nat` less than `USize.size` into a `USize`.
This function is overridden with a native implementation.
-/
@[extern "lean_usize_of_nat"]
def USize.ofNatLT (n : @& Nat) (h : LT.lt n USize.size) : USize where
toBitVec := BitVec.ofNatLt n h
toBitVec := BitVec.ofNatLT n h
set_option bootstrap.genMatcherCode false in
/--
@@ -2234,7 +2194,7 @@ def USize.decEq (a b : USize) : Decidable (Eq a b) :=
instance : DecidableEq USize := USize.decEq
instance : Inhabited USize where
default := USize.ofNatCore 0 usize_size_pos
default := USize.ofNatLT 0 usize_size_pos
/--
A `Nat` denotes a valid unicode codepoint if it is less than `0x110000`, and
@@ -2269,7 +2229,7 @@ This function is overridden with a native implementation.
-/
@[extern "lean_uint32_of_nat"]
def Char.ofNatAux (n : @& Nat) (h : n.isValidChar) : Char :=
{ val := BitVec.ofNatLt n (isValidChar_UInt32 h), valid := h }
{ val := BitVec.ofNatLT n (isValidChar_UInt32 h), valid := h }
/--
Convert a `Nat` into a `Char`. If the `Nat` does not encode a valid unicode scalar value,
@@ -2279,7 +2239,7 @@ Convert a `Nat` into a `Char`. If the `Nat` does not encode a valid unicode scal
def Char.ofNat (n : Nat) : Char :=
dite (n.isValidChar)
(fun h => Char.ofNatAux n h)
(fun _ => { val := BitVec.ofNatLt 0 (of_decide_eq_true rfl), valid := Or.inl (of_decide_eq_true rfl) })
(fun _ => { val := BitVec.ofNatLT 0 (of_decide_eq_true rfl), valid := Or.inl (of_decide_eq_true rfl) })
theorem Char.eq_of_val_eq : {c d : Char}, Eq c.val d.val Eq c d
| _, _, _, _, rfl => rfl
@@ -2302,9 +2262,9 @@ instance : DecidableEq Char :=
/-- Returns the number of bytes required to encode this `Char` in UTF-8. -/
def Char.utf8Size (c : Char) : Nat :=
let v := c.val
ite (LE.le v (UInt32.ofNatCore 0x7F (of_decide_eq_true rfl))) 1
(ite (LE.le v (UInt32.ofNatCore 0x7FF (of_decide_eq_true rfl))) 2
(ite (LE.le v (UInt32.ofNatCore 0xFFFF (of_decide_eq_true rfl))) 3 4))
ite (LE.le v (UInt32.ofNatLT 0x7F (of_decide_eq_true rfl))) 1
(ite (LE.le v (UInt32.ofNatLT 0x7FF (of_decide_eq_true rfl))) 2
(ite (LE.le v (UInt32.ofNatLT 0xFFFF (of_decide_eq_true rfl))) 3 4))
/--
`Option α` is the type of values which are either `some a` for some `a : α`,
@@ -3569,9 +3529,9 @@ with
/-- A hash function for names, which is stored inside the name itself as a
computed field. -/
@[computed_field] hash : Name UInt64
| .anonymous => .ofNatCore 1723 (of_decide_eq_true rfl)
| .anonymous => .ofNatLT 1723 (of_decide_eq_true rfl)
| .str p s => mixHash p.hash s.hash
| .num p v => mixHash p.hash (dite (LT.lt v UInt64.size) (fun h => UInt64.ofNatCore v h) (fun _ => UInt64.ofNatCore 17 (of_decide_eq_true rfl)))
| .num p v => mixHash p.hash (dite (LT.lt v UInt64.size) (fun h => UInt64.ofNatLT v h) (fun _ => UInt64.ofNatLT 17 (of_decide_eq_true rfl)))
instance : Inhabited Name where
default := Name.anonymous

View File

@@ -1589,6 +1589,13 @@ as well as tactics such as `next`, `case`, and `rename_i`.
-/
syntax (name := exposeNames) "expose_names" : tactic
/--
`#suggest_premises` will suggest premises for the current goal, using the currently registered premise selector.
The suggestions are printed in the order of their confidence, from highest to lowest.
-/
syntax (name := suggestPremises) "suggest_premises" : tactic
/--
Close fixed-width `BitVec` and `Bool` goals by obtaining a proof from an external SAT solver and
verifying it inside Lean. The solvable goals are currently limited to
@@ -1791,8 +1798,10 @@ users are encouraged to extend `get_elem_tactic_trivial` instead of this tactic.
macro "get_elem_tactic" : tactic =>
`(tactic| first
/-
Recall that `macro_rules` are tried in reverse order.
We want `assumption` to be tried first.
Recall that `macro_rules` (namely, for `get_elem_tactic_trivial`) are tried in reverse order.
We first, however, try `done`, since the necessary proof may already have been
found during unification, in which case there is no goal to solve (see #6999).
If a goal is present, we want `assumption` to be tried first.
This is important for theorems such as
```
[simp] theorem getElem_pop (a : Array α) (i : Nat) (hi : i < a.pop.size) :
@@ -1805,8 +1814,10 @@ macro "get_elem_tactic" : tactic =>
they add new `macro_rules` for `get_elem_tactic_trivial`.
TODO: Implement priorities for `macro_rules`.
TODO: Ensure we have a **high-priority** macro_rules for `get_elem_tactic_trivial` which is just `assumption`.
TODO: Ensure we have **high-priority** macro_rules for `get_elem_tactic_trivial` which are
just `done` and `assumption`.
-/
| done
| assumption
| get_elem_tactic_trivial
| fail "failed to prove index is valid, possible solutions:

View File

@@ -38,3 +38,4 @@ import Lean.LabelAttribute
import Lean.AddDecl
import Lean.Replay
import Lean.PrivateName
import Lean.PremiseSelection

View File

@@ -82,7 +82,7 @@ def addDecl (decl : Declaration) : CoreM Unit := do
async.commitCheckEnv ( getEnv)
let t BaseIO.mapTask (fun _ => checkAct) env.checked
let endRange? := ( getRef).getTailPos?.map fun pos => pos, pos
Core.logSnapshotTask { range? := endRange?, task := t }
Core.logSnapshotTask { stx? := none, reportingRange? := endRange?, task := t }
where doAdd := do
profileitM Exception "type checking" ( getOptions) do
withTraceNode `Kernel (fun _ => return m!"typechecking declarations {decl.getNames}") do

View File

@@ -537,7 +537,7 @@ partial def compileDecls (decls : List Name) (ref? : Option Declaration := none)
res.commitChecked ( getEnv)
let t BaseIO.mapTask (fun _ => checkAct) env.checked
let endRange? := ( getRef).getTailPos?.map fun pos => pos, pos
Core.logSnapshotTask { range? := endRange?, task := t }
Core.logSnapshotTask { stx? := none, reportingRange? := endRange?, task := t }
where doCompile := do
-- don't compile if kernel errored; should be converted into a task dependency when compilation
-- is made async as well

View File

@@ -358,7 +358,7 @@ def runLintersAsync (stx : Syntax) : CommandElabM Unit := do
-- We only start one task for all linters for now as most linters are fast and we simply want
-- to unblock elaboration of the next command
let lintAct wrapAsyncAsSnapshot fun _ => runLinters stx
logSnapshotTask { range? := none, task := ( BaseIO.asTask lintAct) }
logSnapshotTask { stx? := none, task := ( BaseIO.asTask lintAct) }
protected def getCurrMacroScope : CommandElabM Nat := do pure ( read).currMacroScope
protected def getMainModule : CommandElabM Name := do pure ( getEnv).mainModule
@@ -496,7 +496,7 @@ partial def elabCommand (stx : Syntax) : CommandElabM Unit := do
newNextMacroScope := nextMacroScope
hasTraces
next := Array.zipWith (fun cmdPromise cmd =>
{ range? := cmd.getRange?, task := cmdPromise.resultD default }) cmdPromises cmds
{ stx? := some cmd, task := cmdPromise.resultD default }) cmdPromises cmds
: MacroExpandedSnapshot
}
-- After the first command whose syntax tree changed, we must disable

View File

@@ -215,14 +215,17 @@ private def elabHeaders (views : Array DefView)
return newHeader
if let some snap := view.headerSnap? then
let (tacStx?, newTacTask?) mkTacTask view.value tacPromise
let bodySnap :=
-- Only use first line of body as range when we have incremental tactics as otherwise we
-- would cover their progress
{ range? := if newTacTask?.isSome then
let bodySnap := {
stx? := view.value
reportingRange? :=
if newTacTask?.isSome then
-- Only use first line of body as range when we have incremental tactics as otherwise we
-- would cover their progress
view.ref.getPos?.map fun pos => pos, pos
else
getBodyTerm? view.value |>.getD view.value |>.getRange?
task := bodyPromise.resultD default }
task := bodyPromise.resultD default
}
snap.new.resolve <| some {
diagnostics :=
( Language.Snapshot.Diagnostics.ofMessageLog ( Core.getAndEmptyMessageLog))
@@ -263,7 +266,7 @@ where
:= do
if let some e := getBodyTerm? body then
if let `(by $tacs*) := e then
return (e, some { range? := mkNullNode tacs |>.getRange?, task := tacPromise.resultD default })
return (e, some { stx? := mkNullNode tacs, task := tacPromise.resultD default })
tacPromise.resolve default
return (none, none)
@@ -1093,7 +1096,7 @@ def elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do
} }
defs := defs.push {
fullHeaderRef
headerProcessedSnap := { range? := d.getRange?, task := headerPromise.resultD default }
headerProcessedSnap := { stx? := d, task := headerPromise.resultD default }
}
reusedAllHeaders := reusedAllHeaders && view.headerSnap?.any (·.old?.isSome)
views := views.push view

View File

@@ -230,11 +230,11 @@ where
stx := stx'
diagnostics := .empty
inner? := none
finished := .pure {
finished := .finished stx' {
diagnostics := .empty
state? := ( Tactic.saveState)
}
next := #[{ range? := stx'.getRange?, task := promise.resultD default }]
next := #[{ stx? := stx', task := promise.resultD default }]
}
-- Update `tacSnap?` to old unfolding
withTheReader Term.Context ({ · with tacSnap? := some {

View File

@@ -75,17 +75,6 @@ where
-- only allow `next` reuse in this case
oldNext? := oldParsed.next.get? 0 |>.map (old.stx, ·)
-- For `tac`'s snapshot task range, disregard synthetic info as otherwise
-- `SnapshotTree.findInfoTreeAtPos` might choose the wrong snapshot: for example, when
-- hovering over a `show` tactic, we should choose the info tree in `finished` over that in
-- `inner`, which points to execution of the synthesized `refine` step and does not contain
-- the full info. In most other places, siblings in the snapshot tree have disjoint ranges and
-- so this issue does not occur.
let mut range? := tac.getRange? (canonicalOnly := true)
-- Include trailing whitespace in the range so that `goalsAs?` does not have to wait for more
-- snapshots than necessary.
if let some range := range? then
range? := some { range with stop := range.stop.byteIdx + tac.getTrailingSize }
let next IO.Promise.new
let finished IO.Promise.new
let inner IO.Promise.new
@@ -93,9 +82,9 @@ where
desc := tac.getKind.toString
diagnostics := .empty
stx := tac
inner? := some { range?, task := inner.resultD default }
finished := { range?, task := finished.resultD default }
next := #[{ range? := stxs.getRange?, task := next.resultD default }]
inner? := some { stx? := tac, task := inner.resultD default }
finished := { stx? := tac, task := finished.resultD default }
next := #[{ stx? := stxs, task := next.resultD default }]
}
-- Run `tac` in a fresh info tree state and store resulting state in snapshot for
-- incremental reporting, then add back saved trees. Here we rely on `evalTactic`

View File

@@ -285,9 +285,9 @@ where
stx := mkNullNode altStxs
diagnostics := .empty
inner? := none
finished := { range? := none, task := finished.resultD default }
finished := { stx? := mkNullNode altStxs, reportingRange? := none, task := finished.resultD default }
next := Array.zipWith
(fun stx prom => { range? := stx.getRange?, task := prom.resultD default })
(fun stx prom => { stx? := some stx, task := prom.resultD default })
altStxs altPromises
}
goWithIncremental <| altPromises.mapIdx fun i prom => {

View File

@@ -2245,20 +2245,28 @@ def mkIntMul (a b : Expr) : Expr :=
private def intLEPred : Expr :=
mkApp2 (mkConst ``LE.le [0]) Int.mkType Int.mkInstLE
/-- Given `a b : Int`, return `a ≤ b` -/
/-- Given `a b : Int`, returns `a ≤ b` -/
def mkIntLE (a b : Expr) : Expr :=
mkApp2 intLEPred a b
private def intEqPred : Expr :=
mkApp (mkConst ``Eq [1]) Int.mkType
/-- Given `a b : Int`, return `a = b` -/
/-- Given `a b : Int`, returns `a = b` -/
def mkIntEq (a b : Expr) : Expr :=
mkApp2 intEqPred a b
def mkIntLit (n : Nat) : Expr :=
let r := mkRawNatLit n
mkApp3 (mkConst ``OfNat.ofNat [levelZero]) Int.mkType r (mkApp (mkConst ``instOfNat) r)
/-- Given `a b : Int`, returns `a b` -/
def mkIntDvd (a b : Expr) : Expr :=
mkApp4 (mkConst ``Dvd.dvd [0]) Int.mkType (mkConst ``Int.instDvd) a b
def mkIntLit (n : Int) : Expr :=
let r := mkRawNatLit n.natAbs
let r := mkApp3 (mkConst ``OfNat.ofNat [levelZero]) Int.mkType r (mkApp (mkConst ``instOfNat) r)
if n < 0 then
mkIntNeg r
else
r
def reflBoolTrue : Expr :=
mkApp2 (mkConst ``Eq.refl [levelOne]) (mkConst ``Bool) (mkConst ``Bool.true)

View File

@@ -66,31 +66,48 @@ structure Snapshot where
isFatal := false
deriving Inhabited
/--
Yields the default reporting range of a `Syntax`, which is just the `canonicalOnly` range
of the syntax.
-/
def SnapshotTask.defaultReportingRange? (stx? : Option Syntax) : Option String.Range :=
stx?.bind (·.getRange? (canonicalOnly := true))
/-- A task producing some snapshot type (usually a subclass of `Snapshot`). -/
-- Longer-term TODO: Give the server more control over the priority of tasks, depending on e.g. the
-- cursor position. This may require starting the tasks suspended (e.g. in `Thunk`). The server may
-- also need more dependency information for this in order to avoid priority inversion.
structure SnapshotTask (α : Type) where
/--
`Syntax` processed by this `SnapshotTask`.
The `Syntax` is used by the language server to determine whether to force this `SnapshotTask`
when a request is made.
-/
stx? : Option Syntax
/--
Range that is marked as being processed by the server while the task is running. If `none`,
the range of the outer task if some or else the entire file is reported.
-/
range? : Option String.Range
reportingRange? : Option String.Range := SnapshotTask.defaultReportingRange? stx?
/-- Underlying task producing the snapshot. -/
task : Task α
deriving Nonempty, Inhabited
/-- Creates a snapshot task from a reporting range and a `BaseIO` action. -/
def SnapshotTask.ofIO (range? : Option String.Range) (act : BaseIO α) : BaseIO (SnapshotTask α) := do
/-- Creates a snapshot task from the syntax processed by the task and a `BaseIO` action. -/
def SnapshotTask.ofIO (stx? : Option Syntax)
(reportingRange? : Option String.Range := defaultReportingRange? stx?) (act : BaseIO α) :
BaseIO (SnapshotTask α) := do
return {
range?
stx?
reportingRange?
task := ( BaseIO.asTask act)
}
/-- Creates a finished snapshot task. -/
def SnapshotTask.pure (a : α) : SnapshotTask α where
def SnapshotTask.finished (stx? : Option Syntax) (a : α) : SnapshotTask α where
stx?
-- irrelevant when already finished
range? := none
reportingRange? := none
task := .pure a
/--
@@ -99,25 +116,30 @@ def SnapshotTask.pure (a : α) : SnapshotTask α where
def SnapshotTask.cancel (t : SnapshotTask α) : BaseIO Unit :=
IO.cancel t.task
/-- Transforms a task's output without changing the reporting range. -/
def SnapshotTask.map (t : SnapshotTask α) (f : α β) (range? : Option String.Range := t.range?)
(sync := false) : SnapshotTask β :=
{ range?, task := t.task.map (sync := sync) f }
/-- Transforms a task's output without changing the processed syntax. -/
def SnapshotTask.map (t : SnapshotTask α) (f : α β) (stx? : Option Syntax := t.stx?)
(reportingRange? : Option String.Range := t.reportingRange?) (sync := false) : SnapshotTask β :=
{ stx?, reportingRange?, task := t.task.map (sync := sync) f }
/--
Chains two snapshot tasks. The range is taken from the first task if not specified; the range of
the second task is discarded. -/
Chains two snapshot tasks. The processed syntax and the reporting range are taken from the first
task if not specified; the processed syntax and the reporting range of the second task are
discarded. -/
def SnapshotTask.bind (t : SnapshotTask α) (act : α SnapshotTask β)
(range? : Option String.Range := t.range?) (sync := false) : SnapshotTask β :=
{ range?, task := t.task.bind (sync := sync) (act · |>.task) }
(stx? : Option Syntax := t.stx?) (reportingRange? : Option String.Range := t.reportingRange?)
(sync := false) : SnapshotTask β :=
{ stx?, reportingRange?, task := t.task.bind (sync := sync) (act · |>.task) }
/--
Chains two snapshot tasks. The range is taken from the first task if not specified; the range of
the second task is discarded. -/
Chains two snapshot tasks. The processed syntax and the reporting range are taken from the first
task if not specified; the processed syntax and the reporting range of the second task are
discarded. -/
def SnapshotTask.bindIO (t : SnapshotTask α) (act : α BaseIO (SnapshotTask β))
(range? : Option String.Range := t.range?) (sync := false) : BaseIO (SnapshotTask β) :=
(stx? : Option Syntax := t.stx?) (reportingRange? : Option String.Range := t.reportingRange?)
(sync := false) : BaseIO (SnapshotTask β) :=
return {
range?
stx?
reportingRange?
task := ( BaseIO.bindTask (sync := sync) t.task fun a => (·.task) <$> (act a))
}

View File

@@ -347,21 +347,22 @@ where
cancelTk? := ctx.newCancelTk
result? := some {
parserState := newParserState
processedSnap := ( oldSuccess.processedSnap.bindIO (range? := progressRange?)
(sync := true) fun oldProcessed => do
processedSnap := ( oldSuccess.processedSnap.bindIO (stx? := newStx)
(reportingRange? := progressRange?) (sync := true) fun oldProcessed => do
if let some oldProcSuccess := oldProcessed.result? then
-- also wait on old command parse snapshot as parsing is cheap and may allow for
-- elaboration reuse
oldProcSuccess.firstCmdSnap.bindIO (sync := true) (range? := progressRange?) fun oldCmd => do
oldProcSuccess.firstCmdSnap.bindIO (sync := true) (stx? := newStx)
(reportingRange? := progressRange?) fun oldCmd => do
let prom IO.Promise.new
parseCmd oldCmd newParserState oldProcSuccess.cmdState prom (sync := true) ctx
return .pure {
return .finished newStx {
diagnostics := oldProcessed.diagnostics
result? := some {
cmdState := oldProcSuccess.cmdState
firstCmdSnap := { range? := none, task := prom.result! } } }
firstCmdSnap := { stx? := none, task := prom.result! } } }
else
return .pure oldProcessed) } }
return .finished newStx oldProcessed) } }
else return old
-- fast path: if we have parsed the header successfully...
@@ -416,7 +417,7 @@ where
processHeader (stx : Syntax) (parserState : Parser.ModuleParserState) :
LeanProcessingM (SnapshotTask HeaderProcessedSnapshot) := do
let ctx read
SnapshotTask.ofIO (some 0, ctx.input.endPos) <|
SnapshotTask.ofIO stx (some 0, ctx.input.endPos) <|
ReaderT.run (r := ctx) <| -- re-enter reader in new task
withHeaderExceptions (α := HeaderProcessedSnapshot) ({ · with result? := none }) do
let setup match ( setupImports stx) with
@@ -472,7 +473,7 @@ where
infoTree? := cmdState.infoState.trees[0]!
result? := some {
cmdState
firstCmdSnap := { range? := none, task := prom.result! }
firstCmdSnap := { stx? := none, task := prom.result! }
}
}
@@ -491,13 +492,13 @@ where
let progressRange? := some newParserState.pos, ctx.input.endPos
let newProm IO.Promise.new
-- can reuse range, syntax unchanged
let _ old.finishedSnap.bindIO (sync := true) (range? := progressRange?) fun oldFinished =>
let _ old.finishedSnap.bindIO (sync := true) (reportingRange? := progressRange?) fun oldFinished =>
-- also wait on old command parse snapshot as parsing is cheap and may allow for
-- elaboration reuse
oldNext.bindIO (sync := true) (range? := progressRange?) fun oldNext => do
oldNext.bindIO (sync := true) (reportingRange? := progressRange?) fun oldNext => do
parseCmd oldNext newParserState oldFinished.cmdState newProm sync ctx
return .pure ()
prom.resolve <| { old with nextCmdSnap? := some { range? := none, task := newProm.result! } }
return .finished none ()
prom.resolve <| { old with nextCmdSnap? := some { stx? := none, task := newProm.result! } }
else prom.resolve old -- terminal command, we're done!
-- fast path, do not even start new task for this snapshot (see [Incremental Parsing])
@@ -540,7 +541,7 @@ where
prom.resolve <| {
diagnostics := .empty, stx := .missing, parserState
elabSnap := default
finishedSnap := .pure { diagnostics := .empty, cmdState }
finishedSnap := .finished none { diagnostics := .empty, cmdState }
reportSnap := default
nextCmdSnap? := none
}
@@ -552,27 +553,34 @@ where
let elabPromise IO.Promise.new
let finishedPromise IO.Promise.new
let reportPromise IO.Promise.new
-- report terminal tasks on first line of decl such as not to hide incremental tactics'
-- progress
let initRange? := getNiceCommandStartPos? stx |>.map fun pos => pos, pos
let finishedSnap := { range? := initRange?, task := finishedPromise.result! }
let minimalSnapshots := internal.cmdlineSnapshots.get cmdState.scopes.head!.opts
let next? if Parser.isTerminalCommand stx then pure none
-- for now, wait on "command finished" snapshot before parsing next command
else some <$> IO.Promise.new
let nextCmdSnap? := next?.map
({ range? := some parserState.pos, ctx.input.endPos, task := ·.result! })
let diagnostics Snapshot.Diagnostics.ofMessageLog msgLog
let (stx', parserState') := if minimalSnapshots && !Parser.isTerminalCommand stx then
(default, default)
else
(stx, parserState)
-- report terminal tasks on first line of decl such as not to hide incremental tactics'
-- progress
let initRange? := getNiceCommandStartPos? stx |>.map fun pos => pos, pos
let finishedSnap := {
stx? := stx'
reportingRange? := initRange?
task := finishedPromise.result!
}
let next? if Parser.isTerminalCommand stx then pure none
-- for now, wait on "command finished" snapshot before parsing next command
else some <$> IO.Promise.new
let nextCmdSnap? := next?.map ({
stx? := none
reportingRange? := some parserState.pos, ctx.input.endPos
task := ·.result!
})
let diagnostics Snapshot.Diagnostics.ofMessageLog msgLog
prom.resolve {
diagnostics, finishedSnap, nextCmdSnap?
stx := stx', parserState := parserState'
elabSnap := { range? := stx.getRange?, task := elabPromise.result! }
reportSnap := { range? := initRange?, task := reportPromise.result! }
elabSnap := { stx? := stx', task := elabPromise.result! }
reportSnap := { stx? := none, reportingRange? := initRange?, task := reportPromise.result! }
}
let cmdState doElab stx cmdState beginPos
{ old? := old?.map fun old => old.stx, old.elabSnap, new := elabPromise }
@@ -582,8 +590,8 @@ where
-- We want to trace all of `CommandParsedSnapshot` but `traceTask` is part of it, so let's
-- create a temporary snapshot tree containing all tasks but it
let snaps := #[
{ range? := none, task := elabPromise.result!.map (sync := true) toSnapshotTree },
{ range? := none, task := finishedPromise.result!.map (sync := true) toSnapshotTree }] ++
{ stx? := stx', task := elabPromise.result!.map (sync := true) toSnapshotTree },
{ stx? := stx', task := finishedPromise.result!.map (sync := true) toSnapshotTree }] ++
cmdState.snapshotTasks
let tree := SnapshotTree.mk { diagnostics := .empty } snaps
BaseIO.bindTask ( tree.waitAll) fun _ => do
@@ -603,7 +611,11 @@ where
pure <| .pure <| .mk { diagnostics := .empty } #[]
reportPromise.resolve <|
.mk { diagnostics := .empty } <|
cmdState.snapshotTasks.push { range? := initRange?, task := traceTask }
cmdState.snapshotTasks.push {
stx? := none
reportingRange? := initRange?
task := traceTask
}
if let some next := next? then
-- We're definitely off the fast-forwarding path now
parseCmd none parserState cmdState next (sync := false) ctx

View File

@@ -101,7 +101,7 @@ instance : ToSnapshotTree HeaderParsedSnapshot where
/-- Shortcut accessor to the final header state, if successful. -/
def HeaderParsedSnapshot.processedResult (snap : HeaderParsedSnapshot) :
SnapshotTask (Option HeaderProcessedState) :=
snap.result?.bind (·.processedSnap.map (sync := true) (·.result?)) |>.getD (.pure none)
snap.result?.bind (·.processedSnap.map (sync := true) (·.result?)) |>.getD (.finished none none)
/-- Initial snapshot of the Lean language processor: a "header parsed" snapshot. -/
abbrev InitialSnapshot := HeaderParsedSnapshot

View File

@@ -23,7 +23,7 @@ where go range? s := do
if let some range := range? then
desc := desc ++ f!"{file.toPosition range.start}-{file.toPosition range.stop} "
desc := desc ++ .prefixJoin "\n" ( s.element.diagnostics.msgLog.toList.mapM (·.toString))
if let some t := s.element.infoTree? then
trace[Elab.info] ( t.format)
withTraceNode `Elab.snapshotTree (fun _ => pure desc) do
s.children.toList.forM fun c => go c.range? c.get
s.children.toList.forM fun c => go c.reportingRange? c.get
if let some t := s.element.infoTree? then
trace[Elab.info] ( t.format)

View File

@@ -52,7 +52,7 @@ and the innermost binder is at the end. We update the binder names therein when
-/
go (e : Expr) : MonadCacheT ExprStructEq Expr (StateT (Array Name) CoreM) Expr := do
checkCache { val := e : ExprStructEq } fun _ => do
if e.isAppOfArity `binderNameHint 6 then
if e.isAppOfArity ``binderNameHint 6 then
let v := e.appFn!.appFn!.appArg!
let b := e.appFn!.appArg!
let e := e.appArg!

View File

@@ -32,6 +32,9 @@ def isInstDivInt (e : Expr) : MetaM Bool := do
def isInstModInt (e : Expr) : MetaM Bool := do
let_expr Int.instMod e | return false
return true
def isInstDvdInt (e : Expr) : MetaM Bool := do
let_expr Int.instDvd e | return false
return true
def isInstHAddInt (e : Expr) : MetaM Bool := do
let_expr instHAdd _ i e | return false
isInstAddInt i

View File

@@ -74,7 +74,7 @@ def getFinValue? (e : Expr) : MetaM (Option ((n : Nat) × Fin n)) := OptionT.run
Return `some ⟨n, v⟩` if `e` is:
- an `OfNat.ofNat` application
- a `BitVec.ofNat` application
- a `BitVec.ofNatLt` application
- a `BitVec.ofNatLT` application
that encode a `BitVec n` with value `v`.
-/
def getBitVecValue? (e : Expr) : MetaM (Option ((n : Nat) × BitVec n)) := OptionT.run do
@@ -83,7 +83,7 @@ def getBitVecValue? (e : Expr) : MetaM (Option ((n : Nat) × BitVec n)) := Optio
let n getNatValue? nExpr
let v getNatValue? vExpr
return n, BitVec.ofNat n v
| BitVec.ofNatLt nExpr vExpr _ =>
| BitVec.ofNatLT nExpr vExpr _ =>
let n getNatValue? nExpr
let v getNatValue? vExpr
return n, BitVec.ofNat n v

View File

@@ -4,76 +4,4 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Basic
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Util
namespace Lean.Meta.Grind.Arith.Offset
/-- Construct a model that statisfies all offset constraints -/
def mkModel (goal : Goal) : MetaM (Array (Expr × Nat)) := do
let s := goal.arith.offset
let dbg := grind.debug.get ( getOptions)
let nodes := s.nodes
let isInterpreted (u : Nat) : Bool := isNatNum s.nodes[u]!
let mut pre : Array (Option Int) := mkArray nodes.size none
/-
`needAdjust[u]` is true if `u` assignment is not connected to an interpreted value in the graph.
That is, its assignment may be negative.
-/
let mut needAdjust : Array Bool := mkArray nodes.size true
-- Initialize `needAdjust`
for u in [: nodes.size] do
if isInterpreted u then
-- Interpreted values have a fixed value.
needAdjust := needAdjust.set! u false
else if s.sources[u]!.any fun v _ => isInterpreted v then
needAdjust := needAdjust.set! u false
else if s.targets[u]!.any fun v _ => isInterpreted v then
needAdjust := needAdjust.set! u false
-- Set interpreted values
for h : u in [:nodes.size] do
let e := nodes[u]
if let some v getNatValue? e then
pre := pre.set! u (Int.ofNat v)
-- Set remaining values
for u in [:nodes.size] do
let lower? := s.sources[u]!.foldl (init := none) fun val? v k => Id.run do
let some va := pre[v]! | return val?
let val' := va - k
let some val := val? | return val'
if val' > val then return val' else val?
let upper? := s.targets[u]!.foldl (init := none) fun val? v k => Id.run do
let some va := pre[v]! | return val?
let val' := va + k
let some val := val? | return val'
if val' < val then return val' else val?
if dbg then
let some upper := upper? | pure ()
let some lower := lower? | pure ()
assert! lower upper
let some val := pre[u]! | pure ()
assert! lower val
assert! val upper
unless pre[u]!.isSome do
let val := lower?.getD (upper?.getD 0)
pre := pre.set! u (some val)
let min := pre.foldl (init := 0) fun min val? => Id.run do
let some val := val? | return min
if val < min then val else min
let mut r := {}
for u in [:nodes.size] do
let some val := pre[u]! | unreachable!
let val := if needAdjust[u]! then (val - min).toNat else val.toNat
let e := nodes[u]!
/-
We should not include the assignment for auxiliary offset terms since
they do not provide any additional information.
That said, the information is relevant for debugging `grind`.
-/
if (!( isLitValue e) && (isNatOffset? e).isNone && isNatNum? e != some 0) || grind.debug.get ( getOptions) then
r := r.push (e, val)
return r
end Lean.Meta.Grind.Arith.Offset
import Lean.Meta.Tactic.Grind.Arith.Offset.Model

View File

@@ -4,369 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Grind.Offset
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Arith.ProofUtil
namespace Lean.Meta.Grind.Arith.Offset
/-!
This module implements a decision procedure for offset constraints of the form:
```
x + k ≤ y
x ≤ y + k
```
where `k` is a numeral.
Each constraint is represented as an edge in a weighted graph.
The constraint `x + k ≤ y` is represented as a negative edge.
The shortest path between two nodes in the graph corresponds to an implied inequality.
When adding a new edge, the state is considered unsatisfiable if the new edge creates a negative cycle.
An incremental Floyd-Warshall algorithm is used to find the shortest paths between all nodes.
This module can also handle offset equalities of the form `x + k = y` by representing them with two edges:
```
x + k ≤ y
y ≤ x + k
```
The main advantage of this module over a full linear integer arithmetic procedure is
its ability to efficiently detect all implied equalities and inequalities.
-/
def get' : GoalM State := do
return ( get).arith.offset
@[inline] def modify' (f : State State) : GoalM Unit := do
modify fun s => { s with arith.offset := f s.arith.offset }
def mkNode (expr : Expr) : GoalM NodeId := do
if let some nodeId := ( get').nodeMap.find? { expr } then
return nodeId
let nodeId : NodeId := ( get').nodes.size
trace[grind.offset.internalize.term] "{expr} ↦ #{nodeId}"
modify' fun s => { s with
nodes := s.nodes.push expr
nodeMap := s.nodeMap.insert { expr } nodeId
sources := s.sources.push {}
targets := s.targets.push {}
proofs := s.proofs.push {}
}
markAsOffsetTerm expr
return nodeId
private def getExpr (u : NodeId) : GoalM Expr := do
return ( get').nodes[u]!
private def getDist? (u v : NodeId) : GoalM (Option Int) := do
return ( get').targets[u]!.find? v
private def getProof? (u v : NodeId) : GoalM (Option ProofInfo) := do
return ( get').proofs[u]!.find? v
private def getNodeId (e : Expr) : GoalM NodeId := do
let some nodeId := ( get').nodeMap.find? { expr := e }
| throwError "internal `grind` error, term has not been internalized by offset module{indentExpr e}"
return nodeId
private def getProof (u v : NodeId) : GoalM ProofInfo := do
let some p getProof? u v
| throwError "internal `grind` error, failed to construct proof for{indentExpr (← getExpr u)}\nand{indentExpr (← getExpr v)}"
return p
/--
Returns a proof for `u + k ≤ v` (or `u ≤ v + k`) where `k` is the
shortest path between `u` and `v`.
-/
private partial def mkProofForPath (u v : NodeId) : GoalM Expr := do
go ( getProof u v)
where
go (p : ProofInfo) : GoalM Expr := do
if u == p.w then
return p.proof
else
let p' getProof u p.w
go (mkTrans ( get').nodes p' p v)
/--
Given a new edge edge `u --(kuv)--> v` justified by proof `huv` s.t.
it creates a negative cycle with the existing path `v --{kvu}-->* u`, i.e., `kuv + kvu < 0`,
this function closes the current goal by constructing a proof of `False`.
-/
private def setUnsat (u v : NodeId) (kuv : Int) (huv : Expr) (kvu : Int) : GoalM Unit := do
assert! kuv + kvu < 0
let hvu mkProofForPath v u
let u getExpr u
let v getExpr v
closeGoal (mkUnsatProof u v kuv huv kvu hvu)
/-- Sets the new shortest distance `k` between nodes `u` and `v`. -/
private def setDist (u v : NodeId) (k : Int) : GoalM Unit := do
trace[grind.offset.dist] "{({ u, v, k : Cnstr NodeId})}"
modify' fun s => { s with
targets := s.targets.modify u fun es => es.insert v k
sources := s.sources.modify v fun es => es.insert u k
}
private def setProof (u v : NodeId) (p : ProofInfo) : GoalM Unit := do
modify' fun s => { s with
proofs := s.proofs.modify u fun es => es.insert v p
}
@[inline]
private def forEachSourceOf (u : NodeId) (f : NodeId Int GoalM Unit) : GoalM Unit := do
( get').sources[u]!.forM f
@[inline]
private def forEachTargetOf (u : NodeId) (f : NodeId Int GoalM Unit) : GoalM Unit := do
( get').targets[u]!.forM f
/-- Returns `true` if `k` is smaller than the shortest distance between `u` and `v` -/
private def isShorter (u v : NodeId) (k : Int) : GoalM Bool := do
if let some k' getDist? u v then
return k < k'
else
return true
/-- Adds `p` to the list of things to be propagated. -/
private def pushToPropagate (p : ToPropagate) : GoalM Unit :=
modify' fun s => { s with propagate := p :: s.propagate }
private def propagateEqTrue (e : Expr) (u v : NodeId) (k k' : Int) : GoalM Unit := do
let kuv mkProofForPath u v
let u getExpr u
let v getExpr v
pushEqTrue e <| mkPropagateEqTrueProof u v k kuv k'
private def propagateEqFalse (e : Expr) (u v : NodeId) (k k' : Int) : GoalM Unit := do
let kuv mkProofForPath u v
let u getExpr u
let v getExpr v
pushEqFalse e <| mkPropagateEqFalseProof u v k kuv k'
/-- Propagates all pending contraints and equalities and resets to "to do" list. -/
private def propagatePending : GoalM Unit := do
let todo modifyGet fun s => (s.arith.offset.propagate, { s with arith.offset.propagate := [] })
for p in todo do
match p with
| .eqTrue e u v k k' => propagateEqTrue e u v k k'
| .eqFalse e u v k k' => propagateEqFalse e u v k k'
| .eq u v =>
let ue getExpr u
let ve getExpr v
unless ( isEqv ue ve) do
let huv mkProofForPath u v
let hvu mkProofForPath v u
pushEq ue ve <| mkApp4 (mkConst ``Grind.Nat.eq_of_le_of_le) ue ve huv hvu
/--
Given `e` represented by constraint `c` (from `u` to `v`).
Checks whether `e = True` can be propagated using the path `u --(k)--> v`.
If it can, adds a new entry to propagation list.
-/
private def checkEqTrue (u v : NodeId) (k : Int) (c : Cnstr NodeId) (e : Expr) : GoalM Bool := do
if k c.k then
pushToPropagate <| .eqTrue e u v k c.k
return true
return false
/--
Given `e` represented by constraint `c` (from `v` to `u`).
Checks whether `e = False` can be propagated using the path `u --(k)--> v`.
If it can, adds a new entry to propagation list.
-/
private def checkEqFalse (u v : NodeId) (k : Int) (c : Cnstr NodeId) (e : Expr) : GoalM Bool := do
if k + c.k < 0 then
pushToPropagate <| .eqFalse e u v k c.k
return true
return false
/-- Equality propagation. -/
private def checkEq (u v : NodeId) (k : Int) : GoalM Unit := do
if k != 0 then return ()
let some k' getDist? v u | return ()
if k' != 0 then return ()
let ue getExpr u
let ve getExpr v
if ( isEqv ue ve) then return ()
pushToPropagate <| .eq u v
/--
Auxiliary function for implementing `propagateAll`.
Traverses the constraints `c` (representing an expression `e`) s.t.
`c.u = u` and `c.v = v`, it removes `c` from the list of constraints
associated with `(u, v)` IF
- `e` is already assigned, or
- `f c e` returns true
-/
@[inline]
private def updateCnstrsOf (u v : NodeId) (f : Cnstr NodeId Expr GoalM Bool) : GoalM Unit := do
if let some cs := ( get').cnstrsOf.find? (u, v) then
let cs' cs.filterM fun (c, e) => do
if ( isEqTrue e <||> isEqFalse e) then
return false -- constraint was already assigned
else
return !( f c e)
modify' fun s => { s with cnstrsOf := s.cnstrsOf.insert (u, v) cs' }
/-- Finds constrains and equalities to be propagated. -/
private def checkToPropagate (u v : NodeId) (k : Int) : GoalM Unit := do
updateCnstrsOf u v fun c e => return !( checkEqTrue u v k c e)
updateCnstrsOf v u fun c e => return !( checkEqFalse u v k c e)
checkEq u v k
/--
If `isShorter u v k`, updates the shortest distance between `u` and `v`.
`w` is a node in the path from `u` to `v` such that `(← getProof? w v)` is `some`
-/
private def updateIfShorter (u v : NodeId) (k : Int) (w : NodeId) : GoalM Unit := do
if ( isShorter u v k) then
setDist u v k
setProof u v ( getProof w v)
checkToPropagate u v k
def Cnstr.toExpr (c : Cnstr NodeId) : GoalM Expr := do
let u := ( get').nodes[c.u]!
let v := ( get').nodes[c.v]!
if c.k == 0 then
return mkNatLE u v
else if c.k < 0 then
return mkNatLE (mkNatAdd u (Lean.toExpr ((-c.k).toNat))) v
else
return mkNatLE u (mkNatAdd v (Lean.toExpr c.k.toNat))
def checkInvariants : GoalM Unit := do
let s get'
for u in [:s.targets.size], es in s.targets.toArray do
for (v, k) in es do
let c : Cnstr NodeId := { u, v, k }
trace[grind.debug.offset] "{c}"
let p mkProofForPath u v
trace[grind.debug.offset.proof] "{p} : {← inferType p}"
check p
unless ( withDefault <| isDefEq ( inferType p) ( Cnstr.toExpr c)) do
trace[grind.debug.offset.proof] "failed: {← inferType p} =?= {← Cnstr.toExpr c}"
unreachable!
/--
Adds an edge `u --(k) --> v` justified by the proof term `p`, and then
if no negative cycle was created, updates the shortest distance of affected
node pairs.
-/
def addEdge (u : NodeId) (v : NodeId) (k : Int) (p : Expr) : GoalM Unit := do
if ( isInconsistent) then return ()
if let some k' getDist? v u then
if k'+k < 0 then
setUnsat u v k p k'
return ()
if ( isShorter u v k) then
setDist u v k
setProof u v { w := u, k, proof := p }
checkToPropagate u v k
update
propagatePending
where
update : GoalM Unit := do
forEachTargetOf v fun j k₂ => do
/- Check whether new path: `u -(k)-> v -(k₂)-> j` is shorter -/
updateIfShorter u j (k+k₂) v
forEachSourceOf u fun i k₁ => do
/- Check whether new path: `i -(k₁)-> u -(k)-> v` is shorter -/
updateIfShorter i v (k₁+k) u
forEachTargetOf v fun j k₂ => do
/- Check whether new path: `i -(k₁)-> u -(k)-> v -(k₂) -> j` is shorter -/
updateIfShorter i j (k₁+k+k₂) v
private def internalizeCnstr (e : Expr) (c : Cnstr Expr) : GoalM Unit := do
let u mkNode c.u
let v mkNode c.v
let c := { c with u, v }
if let some k getDist? u v then
if k c.k then
propagateEqTrue e u v k c.k
return ()
if let some k getDist? v u then
if k + c.k < 0 then
propagateEqFalse e v u k c.k
return ()
trace[grind.offset.internalize] "{e} ↦ {c}"
modify' fun s => { s with
cnstrs := s.cnstrs.insert { expr := e } c
cnstrsOf :=
let cs := if let some cs := s.cnstrsOf.find? (u, v) then (c, e) :: cs else [(c, e)]
s.cnstrsOf.insert (u, v) cs
}
private def getZeroNode : GoalM NodeId := do
mkNode ( getNatZeroExpr)
/-- Internalize `e` of the form `b + k` -/
private def internalizeTerm (e : Expr) (b : Expr) (k : Nat) : GoalM Unit := do
-- `e` is of the form `b + k`
let u mkNode e
let v mkNode b
-- `u = v + k`. So, we add edges for `u ≤ v + k` and `v + k ≤ u`.
let h := mkApp (mkConst ``Nat.le_refl) e
addEdge u v k h
addEdge v u (-k) h
-- `0 + k ≤ u`
let z getZeroNode
addEdge z u (-k) <| mkApp2 (mkConst ``Grind.Nat.le_offset) b (toExpr k)
/--
Returns `true`, if `parent?` is relevant for internalization.
For example, we do not want to internalize an offset term that
is the child of an addition. This kind of term will be processed by the
more general linear arithmetic module.
-/
private def isRelevantParent (parent? : Option Expr) : GoalM Bool := do
let some parent := parent? | return false
let z getNatZeroExpr
return !isNatAdd parent && (isNatOffsetCnstr? parent z).isNone
private def isEqParent (parent? : Option Expr) : Bool := Id.run do
let some parent := parent? | return false
return parent.isEq
private def alreadyInternalized (e : Expr) : GoalM Bool := do
let s get'
return s.cnstrs.contains { expr := e } || s.nodeMap.contains { expr := e }
def internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do
if ( alreadyInternalized e) then
return ()
let z getNatZeroExpr
if let some c := isNatOffsetCnstr? e z then
internalizeCnstr e c
else if ( isRelevantParent parent?) then
if let some (b, k) := isNatOffset? e then
internalizeTerm e b k
else if let some k := isNatNum? e then
-- core module has support for detecting equality between literals
unless isEqParent parent? do
internalizeTerm e z k
@[export lean_process_new_offset_eq]
def processNewOffsetEqImpl (a b : Expr) : GoalM Unit := do
unless isSameExpr a b do
trace[grind.offset.eq.to] "{a}, {b}"
let u getNodeId a
let v getNodeId b
let h mkEqProof a b
addEdge u v 0 <| mkApp3 (mkConst ``Grind.Nat.le_of_eq_1) a b h
addEdge v u 0 <| mkApp3 (mkConst ``Grind.Nat.le_of_eq_2) a b h
@[export lean_process_new_offset_eq_lit]
def processNewOffsetEqLitImpl (a b : Expr) : GoalM Unit := do
unless isSameExpr a b do
trace[grind.offset.eq.to] "{a}, {b}"
let some k := isNatNum? b | unreachable!
let u getNodeId a
let z mkNode ( getNatZeroExpr)
let h mkEqProof a b
addEdge u z k <| mkApp3 (mkConst ``Grind.Nat.le_of_eq_1) a b h
addEdge z u (-k) <| mkApp3 (mkConst ``Grind.Nat.le_of_eq_2) a b h
def traceDists : GoalM Unit := do
let s get'
for u in [:s.targets.size], es in s.targets.toArray do
for (v, k) in es do
trace[grind.offset.dist] "#{u} -({k})-> #{v}"
end Lean.Meta.Grind.Arith.Offset
import Lean.Meta.Tactic.Grind.Arith.Offset.Main
import Lean.Meta.Tactic.Grind.Arith.Offset.Proof
import Lean.Meta.Tactic.Grind.Arith.Offset.Util
import Lean.Meta.Tactic.Grind.Arith.Offset.Types

View File

@@ -0,0 +1,373 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Grind.Offset
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Arith.Offset.Proof
import Lean.Meta.Tactic.Grind.Arith.Offset.Util
namespace Lean.Meta.Grind.Arith.Offset
/-!
This module implements a decision procedure for offset constraints of the form:
```
x + k ≤ y
x ≤ y + k
```
where `k` is a numeral.
Each constraint is represented as an edge in a weighted graph.
The constraint `x + k ≤ y` is represented as a negative edge.
The shortest path between two nodes in the graph corresponds to an implied inequality.
When adding a new edge, the state is considered unsatisfiable if the new edge creates a negative cycle.
An incremental Floyd-Warshall algorithm is used to find the shortest paths between all nodes.
This module can also handle offset equalities of the form `x + k = y` by representing them with two edges:
```
x + k ≤ y
y ≤ x + k
```
The main advantage of this module over a full linear integer arithmetic procedure is
its ability to efficiently detect all implied equalities and inequalities.
-/
def get' : GoalM State := do
return ( get).arith.offset
@[inline] def modify' (f : State State) : GoalM Unit := do
modify fun s => { s with arith.offset := f s.arith.offset }
def mkNode (expr : Expr) : GoalM NodeId := do
if let some nodeId := ( get').nodeMap.find? { expr } then
return nodeId
let nodeId : NodeId := ( get').nodes.size
trace[grind.offset.internalize.term] "{expr} ↦ #{nodeId}"
modify' fun s => { s with
nodes := s.nodes.push expr
nodeMap := s.nodeMap.insert { expr } nodeId
sources := s.sources.push {}
targets := s.targets.push {}
proofs := s.proofs.push {}
}
markAsOffsetTerm expr
return nodeId
private def getExpr (u : NodeId) : GoalM Expr := do
return ( get').nodes[u]!
private def getDist? (u v : NodeId) : GoalM (Option Int) := do
return ( get').targets[u]!.find? v
private def getProof? (u v : NodeId) : GoalM (Option ProofInfo) := do
return ( get').proofs[u]!.find? v
private def getNodeId (e : Expr) : GoalM NodeId := do
let some nodeId := ( get').nodeMap.find? { expr := e }
| throwError "internal `grind` error, term has not been internalized by offset module{indentExpr e}"
return nodeId
private def getProof (u v : NodeId) : GoalM ProofInfo := do
let some p getProof? u v
| throwError "internal `grind` error, failed to construct proof for{indentExpr (← getExpr u)}\nand{indentExpr (← getExpr v)}"
return p
/--
Returns a proof for `u + k ≤ v` (or `u ≤ v + k`) where `k` is the
shortest path between `u` and `v`.
-/
private partial def mkProofForPath (u v : NodeId) : GoalM Expr := do
go ( getProof u v)
where
go (p : ProofInfo) : GoalM Expr := do
if u == p.w then
return p.proof
else
let p' getProof u p.w
go (mkTrans ( get').nodes p' p v)
/--
Given a new edge edge `u --(kuv)--> v` justified by proof `huv` s.t.
it creates a negative cycle with the existing path `v --{kvu}-->* u`, i.e., `kuv + kvu < 0`,
this function closes the current goal by constructing a proof of `False`.
-/
private def setUnsat (u v : NodeId) (kuv : Int) (huv : Expr) (kvu : Int) : GoalM Unit := do
assert! kuv + kvu < 0
let hvu mkProofForPath v u
let u getExpr u
let v getExpr v
closeGoal (mkUnsatProof u v kuv huv kvu hvu)
/-- Sets the new shortest distance `k` between nodes `u` and `v`. -/
private def setDist (u v : NodeId) (k : Int) : GoalM Unit := do
trace[grind.offset.dist] "{({ u, v, k : Cnstr NodeId})}"
modify' fun s => { s with
targets := s.targets.modify u fun es => es.insert v k
sources := s.sources.modify v fun es => es.insert u k
}
private def setProof (u v : NodeId) (p : ProofInfo) : GoalM Unit := do
modify' fun s => { s with
proofs := s.proofs.modify u fun es => es.insert v p
}
@[inline]
private def forEachSourceOf (u : NodeId) (f : NodeId Int GoalM Unit) : GoalM Unit := do
( get').sources[u]!.forM f
@[inline]
private def forEachTargetOf (u : NodeId) (f : NodeId Int GoalM Unit) : GoalM Unit := do
( get').targets[u]!.forM f
/-- Returns `true` if `k` is smaller than the shortest distance between `u` and `v` -/
private def isShorter (u v : NodeId) (k : Int) : GoalM Bool := do
if let some k' getDist? u v then
return k < k'
else
return true
/-- Adds `p` to the list of things to be propagated. -/
private def pushToPropagate (p : ToPropagate) : GoalM Unit :=
modify' fun s => { s with propagate := p :: s.propagate }
private def propagateEqTrue (e : Expr) (u v : NodeId) (k k' : Int) : GoalM Unit := do
let kuv mkProofForPath u v
let u getExpr u
let v getExpr v
pushEqTrue e <| mkPropagateEqTrueProof u v k kuv k'
private def propagateEqFalse (e : Expr) (u v : NodeId) (k k' : Int) : GoalM Unit := do
let kuv mkProofForPath u v
let u getExpr u
let v getExpr v
pushEqFalse e <| mkPropagateEqFalseProof u v k kuv k'
/-- Propagates all pending contraints and equalities and resets to "to do" list. -/
private def propagatePending : GoalM Unit := do
let todo modifyGet fun s => (s.arith.offset.propagate, { s with arith.offset.propagate := [] })
for p in todo do
match p with
| .eqTrue e u v k k' => propagateEqTrue e u v k k'
| .eqFalse e u v k k' => propagateEqFalse e u v k k'
| .eq u v =>
let ue getExpr u
let ve getExpr v
unless ( isEqv ue ve) do
let huv mkProofForPath u v
let hvu mkProofForPath v u
pushEq ue ve <| mkApp4 (mkConst ``Grind.Nat.eq_of_le_of_le) ue ve huv hvu
/--
Given `e` represented by constraint `c` (from `u` to `v`).
Checks whether `e = True` can be propagated using the path `u --(k)--> v`.
If it can, adds a new entry to propagation list.
-/
private def checkEqTrue (u v : NodeId) (k : Int) (c : Cnstr NodeId) (e : Expr) : GoalM Bool := do
if k c.k then
pushToPropagate <| .eqTrue e u v k c.k
return true
return false
/--
Given `e` represented by constraint `c` (from `v` to `u`).
Checks whether `e = False` can be propagated using the path `u --(k)--> v`.
If it can, adds a new entry to propagation list.
-/
private def checkEqFalse (u v : NodeId) (k : Int) (c : Cnstr NodeId) (e : Expr) : GoalM Bool := do
if k + c.k < 0 then
pushToPropagate <| .eqFalse e u v k c.k
return true
return false
/-- Equality propagation. -/
private def checkEq (u v : NodeId) (k : Int) : GoalM Unit := do
if k != 0 then return ()
let some k' getDist? v u | return ()
if k' != 0 then return ()
let ue getExpr u
let ve getExpr v
if ( isEqv ue ve) then return ()
pushToPropagate <| .eq u v
/--
Auxiliary function for implementing `propagateAll`.
Traverses the constraints `c` (representing an expression `e`) s.t.
`c.u = u` and `c.v = v`, it removes `c` from the list of constraints
associated with `(u, v)` IF
- `e` is already assigned, or
- `f c e` returns true
-/
@[inline]
private def updateCnstrsOf (u v : NodeId) (f : Cnstr NodeId Expr GoalM Bool) : GoalM Unit := do
if let some cs := ( get').cnstrsOf.find? (u, v) then
let cs' cs.filterM fun (c, e) => do
if ( isEqTrue e <||> isEqFalse e) then
return false -- constraint was already assigned
else
return !( f c e)
modify' fun s => { s with cnstrsOf := s.cnstrsOf.insert (u, v) cs' }
/-- Finds constrains and equalities to be propagated. -/
private def checkToPropagate (u v : NodeId) (k : Int) : GoalM Unit := do
updateCnstrsOf u v fun c e => return !( checkEqTrue u v k c e)
updateCnstrsOf v u fun c e => return !( checkEqFalse u v k c e)
checkEq u v k
/--
If `isShorter u v k`, updates the shortest distance between `u` and `v`.
`w` is a node in the path from `u` to `v` such that `(← getProof? w v)` is `some`
-/
private def updateIfShorter (u v : NodeId) (k : Int) (w : NodeId) : GoalM Unit := do
if ( isShorter u v k) then
setDist u v k
setProof u v ( getProof w v)
checkToPropagate u v k
def Cnstr.toExpr (c : Cnstr NodeId) : GoalM Expr := do
let u := ( get').nodes[c.u]!
let v := ( get').nodes[c.v]!
if c.k == 0 then
return mkNatLE u v
else if c.k < 0 then
return mkNatLE (mkNatAdd u (Lean.toExpr ((-c.k).toNat))) v
else
return mkNatLE u (mkNatAdd v (Lean.toExpr c.k.toNat))
def checkInvariants : GoalM Unit := do
let s get'
for u in [:s.targets.size], es in s.targets.toArray do
for (v, k) in es do
let c : Cnstr NodeId := { u, v, k }
trace[grind.debug.offset] "{c}"
let p mkProofForPath u v
trace[grind.debug.offset.proof] "{p} : {← inferType p}"
check p
unless ( withDefault <| isDefEq ( inferType p) ( Cnstr.toExpr c)) do
trace[grind.debug.offset.proof] "failed: {← inferType p} =?= {← Cnstr.toExpr c}"
unreachable!
/--
Adds an edge `u --(k) --> v` justified by the proof term `p`, and then
if no negative cycle was created, updates the shortest distance of affected
node pairs.
-/
def addEdge (u : NodeId) (v : NodeId) (k : Int) (p : Expr) : GoalM Unit := do
if ( isInconsistent) then return ()
if let some k' getDist? v u then
if k'+k < 0 then
setUnsat u v k p k'
return ()
if ( isShorter u v k) then
setDist u v k
setProof u v { w := u, k, proof := p }
checkToPropagate u v k
update
propagatePending
where
update : GoalM Unit := do
forEachTargetOf v fun j k₂ => do
/- Check whether new path: `u -(k)-> v -(k₂)-> j` is shorter -/
updateIfShorter u j (k+k₂) v
forEachSourceOf u fun i k₁ => do
/- Check whether new path: `i -(k₁)-> u -(k)-> v` is shorter -/
updateIfShorter i v (k₁+k) u
forEachTargetOf v fun j k₂ => do
/- Check whether new path: `i -(k₁)-> u -(k)-> v -(k₂) -> j` is shorter -/
updateIfShorter i j (k₁+k+k₂) v
private def internalizeCnstr (e : Expr) (c : Cnstr Expr) : GoalM Unit := do
let u mkNode c.u
let v mkNode c.v
let c := { c with u, v }
if let some k getDist? u v then
if k c.k then
propagateEqTrue e u v k c.k
return ()
if let some k getDist? v u then
if k + c.k < 0 then
propagateEqFalse e v u k c.k
return ()
trace[grind.offset.internalize] "{e} ↦ {c}"
modify' fun s => { s with
cnstrs := s.cnstrs.insert { expr := e } c
cnstrsOf :=
let cs := if let some cs := s.cnstrsOf.find? (u, v) then (c, e) :: cs else [(c, e)]
s.cnstrsOf.insert (u, v) cs
}
private def getZeroNode : GoalM NodeId := do
mkNode ( getNatZeroExpr)
/-- Internalize `e` of the form `b + k` -/
private def internalizeTerm (e : Expr) (b : Expr) (k : Nat) : GoalM Unit := do
-- `e` is of the form `b + k`
let u mkNode e
let v mkNode b
-- `u = v + k`. So, we add edges for `u ≤ v + k` and `v + k ≤ u`.
let h := mkApp (mkConst ``Nat.le_refl) e
addEdge u v k h
addEdge v u (-k) h
-- `0 + k ≤ u`
let z getZeroNode
addEdge z u (-k) <| mkApp2 (mkConst ``Grind.Nat.le_offset) b (toExpr k)
/--
Returns `true`, if `parent?` is relevant for internalization.
For example, we do not want to internalize an offset term that
is the child of an addition. This kind of term will be processed by the
more general linear arithmetic module.
-/
private def isRelevantParent (parent? : Option Expr) : GoalM Bool := do
let some parent := parent? | return false
let z getNatZeroExpr
return !isNatAdd parent && (isNatOffsetCnstr? parent z).isNone
private def isEqParent (parent? : Option Expr) : Bool := Id.run do
let some parent := parent? | return false
return parent.isEq
private def alreadyInternalized (e : Expr) : GoalM Bool := do
let s get'
return s.cnstrs.contains { expr := e } || s.nodeMap.contains { expr := e }
def internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do
if ( alreadyInternalized e) then
return ()
let z getNatZeroExpr
if let some c := isNatOffsetCnstr? e z then
internalizeCnstr e c
else if ( isRelevantParent parent?) then
if let some (b, k) := isNatOffset? e then
internalizeTerm e b k
else if let some k := isNatNum? e then
-- core module has support for detecting equality between literals
unless isEqParent parent? do
internalizeTerm e z k
@[export lean_process_new_offset_eq]
def processNewOffsetEqImpl (a b : Expr) : GoalM Unit := do
unless isSameExpr a b do
trace[grind.offset.eq.to] "{a}, {b}"
let u getNodeId a
let v getNodeId b
let h mkEqProof a b
addEdge u v 0 <| mkApp3 (mkConst ``Grind.Nat.le_of_eq_1) a b h
addEdge v u 0 <| mkApp3 (mkConst ``Grind.Nat.le_of_eq_2) a b h
@[export lean_process_new_offset_eq_lit]
def processNewOffsetEqLitImpl (a b : Expr) : GoalM Unit := do
unless isSameExpr a b do
trace[grind.offset.eq.to] "{a}, {b}"
let some k := isNatNum? b | unreachable!
let u getNodeId a
let z mkNode ( getNatZeroExpr)
let h mkEqProof a b
addEdge u z k <| mkApp3 (mkConst ``Grind.Nat.le_of_eq_1) a b h
addEdge z u (-k) <| mkApp3 (mkConst ``Grind.Nat.le_of_eq_2) a b h
def traceDists : GoalM Unit := do
let s get'
for u in [:s.targets.size], es in s.targets.toArray do
for (v, k) in es do
trace[grind.offset.dist] "#{u} -({k})-> #{v}"
end Lean.Meta.Grind.Arith.Offset

View File

@@ -0,0 +1,78 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Basic
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Util
namespace Lean.Meta.Grind.Arith.Offset
/-- Construct a model that statisfies all offset constraints -/
def mkModel (goal : Goal) : MetaM (Array (Expr × Nat)) := do
let s := goal.arith.offset
let dbg := grind.debug.get ( getOptions)
let nodes := s.nodes
let isInterpreted (u : Nat) : Bool := isNatNum s.nodes[u]!
let mut pre : Array (Option Int) := mkArray nodes.size none
/-
`needAdjust[u]` is true if `u` assignment is not connected to an interpreted value in the graph.
That is, its assignment may be negative.
-/
let mut needAdjust : Array Bool := mkArray nodes.size true
-- Initialize `needAdjust`
for u in [: nodes.size] do
if isInterpreted u then
-- Interpreted values have a fixed value.
needAdjust := needAdjust.set! u false
else if s.sources[u]!.any fun v _ => isInterpreted v then
needAdjust := needAdjust.set! u false
else if s.targets[u]!.any fun v _ => isInterpreted v then
needAdjust := needAdjust.set! u false
-- Set interpreted values
for h : u in [:nodes.size] do
let e := nodes[u]
if let some v getNatValue? e then
pre := pre.set! u (Int.ofNat v)
-- Set remaining values
for u in [:nodes.size] do
let lower? := s.sources[u]!.foldl (init := none) fun val? v k => Id.run do
let some va := pre[v]! | return val?
let val' := va - k
let some val := val? | return val'
if val' > val then return val' else val?
let upper? := s.targets[u]!.foldl (init := none) fun val? v k => Id.run do
let some va := pre[v]! | return val?
let val' := va + k
let some val := val? | return val'
if val' < val then return val' else val?
if dbg then
let some upper := upper? | pure ()
let some lower := lower? | pure ()
assert! lower upper
let some val := pre[u]! | pure ()
assert! lower val
assert! val upper
unless pre[u]!.isSome do
let val := lower?.getD (upper?.getD 0)
pre := pre.set! u (some val)
let min := pre.foldl (init := 0) fun min val? => Id.run do
let some val := val? | return min
if val < min then val else min
let mut r := {}
for u in [:nodes.size] do
let some val := pre[u]! | unreachable!
let val := if needAdjust[u]! then (val - min).toNat else val.toNat
let e := nodes[u]!
/-
We should not include the assignment for auxiliary offset terms since
they do not provide any additional information.
That said, the information is relevant for debugging `grind`.
-/
if (!( isLitValue e) && (isNatOffset? e).isNone && isNatNum? e != some 0) || grind.debug.get ( getOptions) then
r := r.push (e, val)
return r
end Lean.Meta.Grind.Arith.Offset

View File

@@ -8,14 +8,11 @@ import Init.Grind.Offset
import Init.Grind.Lemmas
import Lean.Meta.Tactic.Grind.Types
namespace Lean.Meta.Grind.Arith
namespace Lean.Meta.Grind.Arith.Offset
/-!
Helper functions for constructing proof terms in the arithmetic procedures.
Helper functions for constructing proof terms in the offset contraint procedure.
-/
namespace Offset
/-- Returns a proof for `true = true` -/
def rfl_true : Expr := mkConst ``Grind.rfl_true
@@ -163,6 +160,4 @@ def mkPropagateEqFalseProof (u v : Expr) (k : Int) (huv : Expr) (k' : Int) : Exp
let k' := -k'
mkApp6 (mkConst ``Grind.Nat.lo_eq_false_of_ro) u v (toExprN k) (toExprN k') rfl_true huv
end Offset
end Lean.Meta.Grind.Arith
end Lean.Meta.Grind.Arith.Offset

View File

@@ -0,0 +1,74 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Data.PersistentArray
import Lean.Meta.Tactic.Grind.ENodeKey
import Lean.Meta.Tactic.Grind.Arith.Util
import Lean.Meta.Tactic.Grind.Arith.Offset.Util
namespace Lean.Meta.Grind.Arith.Offset
abbrev NodeId := Nat
instance : ToMessageData (Offset.Cnstr NodeId) where
toMessageData c := Offset.toMessageData (α := NodeId) (inst := { toMessageData n := m!"#{n}" }) c
/-- Auxiliary structure used for proof extraction. -/
structure ProofInfo where
w : NodeId
k : Int
proof : Expr
deriving Inhabited
/--
Auxiliary inductive type for representing contraints and equalities
that should be propagated to core.
Recall that we cannot compute proofs until the short-distance
data-structures have been fully updated when a new edge is inserted.
Thus, we store the information to be propagated into a list.
See field `propagate` in `State`.
-/
inductive ToPropagate where
| eqTrue (e : Expr) (u v : NodeId) (k k' : Int)
| eqFalse (e : Expr) (u v : NodeId) (k k' : Int)
| eq (u v : NodeId)
deriving Inhabited
/-- State of the constraint offset procedure. -/
structure State where
/-- Mapping from `NodeId` to the `Expr` represented by the node. -/
nodes : PArray Expr := {}
/-- Mapping from `Expr` to a node representing it. -/
nodeMap : PHashMap ENodeKey NodeId := {}
/-- Mapping from `Expr` representing inequalites to constraints. -/
cnstrs : PHashMap ENodeKey (Cnstr NodeId) := {}
/--
Mapping from pairs `(u, v)` to a list of offset constraints on `u` and `v`.
We use this mapping to implement exhaustive constraint propagation.
-/
cnstrsOf : PHashMap (NodeId × NodeId) (List (Cnstr NodeId × Expr)) := {}
/--
For each node with id `u`, `sources[u]` contains
pairs `(v, k)` s.t. there is a path from `v` to `u` with weight `k`.
-/
sources : PArray (AssocList NodeId Int) := {}
/--
For each node with id `u`, `targets[u]` contains
pairs `(v, k)` s.t. there is a path from `u` to `v` with weight `k`.
-/
targets : PArray (AssocList NodeId Int) := {}
/--
Proof reconstruction information. For each node with id `u`, `proofs[u]` contains
pairs `(v, { w, proof })` s.t. there is a path from `u` to `v`, and
`w` is the penultimate node in the path, and `proof` is the justification for
the last edge.
-/
proofs : PArray (AssocList NodeId ProofInfo) := {}
/-- Truth values and equalities to propagate to core. -/
propagate : List ToPropagate := []
deriving Inhabited
end Lean.Meta.Grind.Arith.Offset

View File

@@ -0,0 +1,63 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Expr
import Lean.Message
import Lean.Meta.Tactic.Grind.Arith.Util
namespace Lean.Meta.Grind.Arith.Offset
/-- Returns `some (a, k)` if `e` is of the form `a + k`. -/
def isNatOffset? (e : Expr) : Option (Expr × Nat) := Id.run do
let some (a, b) := isNatAdd? e | none
let some k := isNatNum? b | none
some (a, k)
/-- An offset constraint. -/
structure Cnstr (α : Type) where
u : α
v : α
k : Int := 0
deriving Inhabited
def Cnstr.neg : Cnstr α Cnstr α
| { u, v, k } => { u := v, v := u, k := -k - 1 }
example (c : Cnstr α) : c.neg.neg = c := by
cases c; simp [Cnstr.neg]; omega
def toMessageData [inst : ToMessageData α] (c : Cnstr α) : MessageData :=
match c.k with
| .ofNat 0 => m!"{c.u} ≤ {c.v}"
| .ofNat k => m!"{c.u} ≤ {c.v} + {k}"
| .negSucc k => m!"{c.u} + {k + 1} ≤ {c.v}"
instance : ToMessageData (Cnstr Expr) where
toMessageData c := Offset.toMessageData c
/--
Returns `some cnstr` if `e` is offset constraint.
Remark: `z` is `0` numeral. It is an extra argument because we
want to be able to provide the one that has already been internalized.
-/
def isNatOffsetCnstr? (e : Expr) (z : Expr) : Option (Cnstr Expr) :=
match_expr e with
| LE.le _ inst a b => if isInstLENat inst then go a b else none
| _ => none
where
go (u v : Expr) :=
if let some (u, k) := isNatOffset? u then
some { u, k := - k, v }
else if let some (v, k) := isNatOffset? v then
some { u, v, k }
else if let some k := isNatNum? u then
some { u := z, v, k := - k }
else if let some k := isNatNum? v then
some { u, v := z, k }
else
some { u, v }
end Lean.Meta.Grind.Arith.Offset

View File

@@ -4,76 +4,10 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Data.PersistentArray
import Lean.Meta.Tactic.Grind.ENodeKey
import Lean.Meta.Tactic.Grind.Arith.Util
import Lean.Meta.Tactic.Grind.Arith.Offset.Types
namespace Lean.Meta.Grind.Arith
namespace Offset
abbrev NodeId := Nat
instance : ToMessageData (Offset.Cnstr NodeId) where
toMessageData c := Offset.toMessageData (α := NodeId) (inst := { toMessageData n := m!"#{n}" }) c
/-- Auxiliary structure used for proof extraction. -/
structure ProofInfo where
w : NodeId
k : Int
proof : Expr
deriving Inhabited
/--
Auxiliary inductive type for representing contraints and equalities
that should be propagated to core.
Recall that we cannot compute proofs until the short-distance
data-structures have been fully updated when a new edge is inserted.
Thus, we store the information to be propagated into a list.
See field `propagate` in `State`.
-/
inductive ToPropagate where
| eqTrue (e : Expr) (u v : NodeId) (k k' : Int)
| eqFalse (e : Expr) (u v : NodeId) (k k' : Int)
| eq (u v : NodeId)
deriving Inhabited
/-- State of the constraint offset procedure. -/
structure State where
/-- Mapping from `NodeId` to the `Expr` represented by the node. -/
nodes : PArray Expr := {}
/-- Mapping from `Expr` to a node representing it. -/
nodeMap : PHashMap ENodeKey NodeId := {}
/-- Mapping from `Expr` representing inequalites to constraints. -/
cnstrs : PHashMap ENodeKey (Cnstr NodeId) := {}
/--
Mapping from pairs `(u, v)` to a list of offset constraints on `u` and `v`.
We use this mapping to implement exhaustive constraint propagation.
-/
cnstrsOf : PHashMap (NodeId × NodeId) (List (Cnstr NodeId × Expr)) := {}
/--
For each node with id `u`, `sources[u]` contains
pairs `(v, k)` s.t. there is a path from `v` to `u` with weight `k`.
-/
sources : PArray (AssocList NodeId Int) := {}
/--
For each node with id `u`, `targets[u]` contains
pairs `(v, k)` s.t. there is a path from `u` to `v` with weight `k`.
-/
targets : PArray (AssocList NodeId Int) := {}
/--
Proof reconstruction information. For each node with id `u`, `proofs[u]` contains
pairs `(v, { w, proof })` s.t. there is a path from `u` to `v`, and
`w` is the penultimate node in the path, and `proof` is the justification for
the last edge.
-/
proofs : PArray (AssocList NodeId ProofInfo) := {}
/-- Truth values and equalities to propagate to core. -/
propagate : List ToPropagate := []
deriving Inhabited
end Offset
/-- State for the arithmetic procedures. -/
structure State where
offset : Offset.State := {}

View File

@@ -49,54 +49,5 @@ def isNatNum? (e : Expr) : Option Nat := Id.run do
let .lit (.natVal k) := k | none
some k
/-- Returns `some (a, k)` if `e` is of the form `a + k`. -/
def isNatOffset? (e : Expr) : Option (Expr × Nat) := Id.run do
let some (a, b) := isNatAdd? e | none
let some k := isNatNum? b | none
some (a, k)
/-- An offset constraint. -/
structure Offset.Cnstr (α : Type) where
u : α
v : α
k : Int := 0
deriving Inhabited
def Offset.Cnstr.neg : Cnstr α Cnstr α
| { u, v, k } => { u := v, v := u, k := -k - 1 }
example (c : Offset.Cnstr α) : c.neg.neg = c := by
cases c; simp [Offset.Cnstr.neg]; omega
def Offset.toMessageData [inst : ToMessageData α] (c : Offset.Cnstr α) : MessageData :=
match c.k with
| .ofNat 0 => m!"{c.u} ≤ {c.v}"
| .ofNat k => m!"{c.u} ≤ {c.v} + {k}"
| .negSucc k => m!"{c.u} + {k + 1} ≤ {c.v}"
instance : ToMessageData (Offset.Cnstr Expr) where
toMessageData c := Offset.toMessageData c
/--
Returns `some cnstr` if `e` is offset constraint.
Remark: `z` is `0` numeral. It is an extra argument because we
want to be able to provide the one that has already been internalized.
-/
def isNatOffsetCnstr? (e : Expr) (z : Expr) : Option (Offset.Cnstr Expr) :=
match_expr e with
| LE.le _ inst a b => if isInstLENat inst then go a b else none
| _ => none
where
go (u v : Expr) :=
if let some (u, k) := isNatOffset? u then
some { u, k := - k, v }
else if let some (v, k) := isNatOffset? v then
some { u, v, k }
else if let some k := isNatNum? u then
some { u := z, v, k := - k }
else if let some k := isNatNum? v then
some { u, v := z, k }
else
some { u, v }
end Lean.Meta.Grind.Arith

View File

@@ -57,4 +57,7 @@ partial def isLinearCnstr (e : Expr) : Bool :=
else
false
def isDvdCnstr (e : Expr) : Bool :=
e.isAppOfArity ``Dvd.dvd 4
end Lean.Meta.Linear

View File

@@ -28,10 +28,13 @@ where
| some e, .add 1 x p => go (some (.add e (.var x))) p
| some e, .add k x p => go (some (.add e (.mulL k (.var x)))) p
def PolyCnstr.toExprCnstr : PolyCnstr ExprCnstr
def RelCnstr.toRaw : RelCnstr RawRelCnstr
| .eq p => .eq p.toExpr (.num 0)
| .le p => .le p.toExpr (.num 0)
def DvdCnstr.toRaw : DvdCnstr RawDvdCnstr
| { k, p } => { k, e := p.toExpr }
/-- Applies the given variable permutation to `e` -/
def Expr.applyPerm (perm : Lean.Perm) (e : Expr) : Expr :=
go e
@@ -45,64 +48,74 @@ where
| .mulL k a => .mulL k (go a)
| .mulR a k => .mulR (go a) k
/-- Applies the given variable permutation to the given expression constraint. -/
def ExprCnstr.applyPerm (perm : Lean.Perm) : ExprCnstr ExprCnstr
/-- Applies the given variable permutation to the given raw relational constraint. -/
def RawRelCnstr.applyPerm (perm : Lean.Perm) : RawRelCnstr RawRelCnstr
| .eq a b => .eq (a.applyPerm perm) (b.applyPerm perm)
| .le a b => .le (a.applyPerm perm) (b.applyPerm perm)
/-- Applies the given variable permutation to the given raw divisibility constraint. -/
def RawDvdCnstr.applyPerm (perm : Lean.Perm) : RawDvdCnstr RawDvdCnstr
| { k, e } => { k, e := e.applyPerm perm }
deriving instance Repr for Poly
deriving instance Repr for Expr
deriving instance Repr for RelCnstr
deriving instance Repr for RawRelCnstr
deriving instance Repr for DvdCnstr
deriving instance Repr for RawDvdCnstr
end Int.Linear
namespace Lean.Meta.Linear.Int
deriving instance Repr for Int.Linear.Poly
deriving instance Repr for Int.Linear.Expr
deriving instance Repr for Int.Linear.ExprCnstr
deriving instance Repr for Int.Linear.PolyCnstr
abbrev LinearExpr := Int.Linear.Expr
abbrev LinearCnstr := Int.Linear.ExprCnstr
abbrev PolyExpr := Int.Linear.Poly
def LinearExpr.toExpr (e : LinearExpr) : Expr :=
def ofLinearExpr (e : Int.Linear.Expr) : Expr :=
open Int.Linear.Expr in
match e with
| .num v => mkApp (mkConst ``num) (Lean.toExpr v)
| .var i => mkApp (mkConst ``var) (mkNatLit i)
| .neg a => mkApp (mkConst ``neg) (toExpr a)
| .add a b => mkApp2 (mkConst ``add) (toExpr a) (toExpr b)
| .sub a b => mkApp2 (mkConst ``sub) (toExpr a) (toExpr b)
| .mulL k a => mkApp2 (mkConst ``mulL) (Lean.toExpr k) (toExpr a)
| .mulR a k => mkApp2 (mkConst ``mulR) (toExpr a) (Lean.toExpr k)
| .neg a => mkApp (mkConst ``neg) (ofLinearExpr a)
| .add a b => mkApp2 (mkConst ``add) (ofLinearExpr a) (ofLinearExpr b)
| .sub a b => mkApp2 (mkConst ``sub) (ofLinearExpr a) (ofLinearExpr b)
| .mulL k a => mkApp2 (mkConst ``mulL) (toExpr k) (ofLinearExpr a)
| .mulR a k => mkApp2 (mkConst ``mulR) (ofLinearExpr a) (toExpr k)
instance : ToExpr LinearExpr where
toExpr a := a.toExpr
instance : ToExpr Int.Linear.Expr where
toExpr a := ofLinearExpr a
toTypeExpr := mkConst ``Int.Linear.Expr
protected def LinearCnstr.toExpr (c : LinearCnstr) : Expr :=
open Int.Linear.ExprCnstr in
def ofRawRelCnstr (c : Int.Linear.RawRelCnstr) : Expr :=
match c with
| .eq e₁ e₂ => mkApp2 (mkConst ``eq) (toExpr e₁) (toExpr e₂)
| .le e₁ e₂ => mkApp2 (mkConst ``le) (toExpr e₁) (toExpr e₂)
| .eq e₁ e₂ => mkApp2 (mkConst ``Int.Linear.RawRelCnstr.eq) (toExpr e₁) (toExpr e₂)
| .le e₁ e₂ => mkApp2 (mkConst ``Int.Linear.RawRelCnstr.le) (toExpr e₁) (toExpr e₂)
instance : ToExpr LinearCnstr where
toExpr a := a.toExpr
toTypeExpr := mkConst ``Int.Linear.ExprCnstr
instance : ToExpr Int.Linear.RawRelCnstr where
toExpr a := ofRawRelCnstr a
toTypeExpr := mkConst ``Int.Linear.RawRelCnstr
open Int.Linear.Expr in
def LinearExpr.toArith (ctx : Array Expr) (e : LinearExpr) : MetaM Expr := do
def ofRawDvdCnstr (c : Int.Linear.RawDvdCnstr) : Expr :=
mkApp2 (mkConst ``Int.Linear.RawDvdCnstr.mk) (toExpr c.k) (toExpr c.e)
instance : ToExpr Int.Linear.RawDvdCnstr where
toExpr a := ofRawDvdCnstr a
toTypeExpr := mkConst ``Int.Linear.RawDvdCnstr
def _root_.Int.Linear.Expr.denoteExpr (ctx : Array Expr) (e : Int.Linear.Expr) : MetaM Expr := do
match e with
| .num v => return Lean.toExpr v
| .var i => return ctx[i]!
| .neg a => return mkIntNeg ( toArith ctx a)
| .add a b => return mkIntAdd ( toArith ctx a) ( toArith ctx b)
| .sub a b => return mkIntSub ( toArith ctx a) ( toArith ctx b)
| .mulL k a => return mkIntMul (Lean.toExpr k) ( toArith ctx a)
| .mulR a k => return mkIntMul ( toArith ctx a) (Lean.toExpr k)
| .neg a => return mkIntNeg ( denoteExpr ctx a)
| .add a b => return mkIntAdd ( denoteExpr ctx a) ( denoteExpr ctx b)
| .sub a b => return mkIntSub ( denoteExpr ctx a) ( denoteExpr ctx b)
| .mulL k a => return mkIntMul (toExpr k) ( denoteExpr ctx a)
| .mulR a k => return mkIntMul ( denoteExpr ctx a) (toExpr k)
def LinearCnstr.toArith (ctx : Array Expr) (c : LinearCnstr) : MetaM Expr := do
def _root_.Int.Linear.RawRelCnstr.denoteExpr (ctx : Array Expr) (c : Int.Linear.RawRelCnstr) : MetaM Expr := do
match c with
| .eq e₁ e₂ => return mkIntEq ( LinearExpr.toArith ctx e₁) ( LinearExpr.toArith ctx e₂)
| .le e₁ e₂ => return mkIntLE ( LinearExpr.toArith ctx e₁) ( LinearExpr.toArith ctx e₂)
| .eq e₁ e₂ => return mkIntEq ( e₁.denoteExpr ctx) ( e₂.denoteExpr ctx)
| .le e₁ e₂ => return mkIntLE ( e₁.denoteExpr ctx) ( e₂.denoteExpr ctx)
def _root_.Int.Linear.RawDvdCnstr.denoteExpr (ctx : Array Expr) (c : Int.Linear.RawDvdCnstr) : MetaM Expr := do
return mkIntDvd (mkIntLit c.k) ( c.e.denoteExpr ctx)
namespace ToLinear
@@ -114,7 +127,7 @@ abbrev M := StateRefT State MetaM
open Int.Linear.Expr
def addAsVar (e : Expr) : M LinearExpr := do
def addAsVar (e : Expr) : M Int.Linear.Expr := do
if let some x ( get).varMap.find? e then
return var x
else
@@ -123,14 +136,14 @@ def addAsVar (e : Expr) : M LinearExpr := do
set { varMap := ( s.varMap.insert e x), vars := s.vars.push e : State }
return var x
partial def toLinearExpr (e : Expr) : M LinearExpr := do
partial def toLinearExpr (e : Expr) : M Int.Linear.Expr := do
match e with
| .mdata _ e => toLinearExpr e
| .app .. => visit e
| .mvar .. => visit e
| _ => addAsVar e
where
visit (e : Expr) : M LinearExpr := do
visit (e : Expr) : M Int.Linear.Expr := do
let mul (a b : Expr) := do
match ( getIntValue? a) with
| some k => return .mulL k ( toLinearExpr b)
@@ -168,7 +181,7 @@ where
else addAsVar e
| _ => addAsVar e
partial def toLinearCnstr? (e : Expr) : M (Option LinearCnstr) := OptionT.run do
partial def toRawRelCnstr? (e : Expr) : M (Option Int.Linear.RawRelCnstr) := OptionT.run do
match_expr e with
| Eq α a b =>
let_expr Int α | failure
@@ -200,13 +213,19 @@ partial def toLinearCnstr? (e : Expr) : M (Option LinearCnstr) := OptionT.run do
return .le (.add ( toLinearExpr b) (.num 1)) ( toLinearExpr a)
| _ => failure
partial def toRawDvdCnstr? (e : Expr) : M (Option Int.Linear.RawDvdCnstr) := OptionT.run do
let_expr Dvd.dvd _ inst k b e | failure
guard ( isInstDvdInt inst)
let some k getIntValue? k | failure
return { k, e := ( toLinearExpr b) }
def run (x : M α) : MetaM (α × Array Expr) := do
let (a, s) x.run {}
return (a, s.vars)
end ToLinear
def toLinearExpr (e : Expr) : MetaM (LinearExpr × Array Expr) := do
def toLinearExpr (e : Expr) : MetaM (Int.Linear.Expr × Array Expr) := do
let (e, atoms) ToLinear.run (ToLinear.toLinearExpr e)
if atoms.size == 1 then
return (e, atoms)
@@ -215,8 +234,18 @@ def toLinearExpr (e : Expr) : MetaM (LinearExpr × Array Expr) := do
let e := e.applyPerm perm
return (e, atoms)
def toLinearCnstr? (e : Expr) : MetaM (Option (LinearCnstr × Array Expr)) := do
let (some c, atoms) ToLinear.run (ToLinear.toLinearCnstr? e)
def toRawRelCnstr? (e : Expr) : MetaM (Option (Int.Linear.RawRelCnstr × Array Expr)) := do
let (some c, atoms) ToLinear.run (ToLinear.toRawRelCnstr? e)
| return none
if atoms.size <= 1 then
return some (c, atoms)
else
let (atoms, perm) := sortExprs atoms
let c := c.applyPerm perm
return some (c, atoms)
def toRawDvdCnstr? (e : Expr) : MetaM (Option (Int.Linear.RawDvdCnstr × Array Expr)) := do
let (some c, atoms) ToLinear.run (ToLinear.toRawDvdCnstr? e)
| return none
if atoms.size <= 1 then
return some (c, atoms)

View File

@@ -17,11 +17,11 @@ where
| .num k' => Nat.gcd k k'.natAbs
| .add k' _ p => go (Nat.gcd k k'.natAbs) p
def Int.Linear.PolyCnstr.gcdAll : PolyCnstr Nat
def Int.Linear.PolyCnstr.gcdAll : RelCnstr Nat
| .eq p => p.gcdAll
| .le p => p.gcdAll
def Int.Linear.Poly.gcdCoeffs : Poly Nat
def Int.Linear.Poly.gcdCoeffs' : Poly Nat
| .num _ => 1
| .add k _ p => go k.natAbs p
where
@@ -31,68 +31,68 @@ where
| .num _ => k
| .add k' _ p => go (Nat.gcd k k'.natAbs) p
def Int.Linear.PolyCnstr.gcdCoeffs : PolyCnstr Nat
| .eq p | .le p => p.gcdCoeffs
def Int.Linear.RelCnstr.gcdCoeffs : RelCnstr Nat
| .eq p | .le p => p.gcdCoeffs'
def Int.Linear.PolyCnstr.isEq : PolyCnstr Bool
def Int.Linear.RelCnstr.isEq : RelCnstr Bool
| .eq _ => true
| .le _ => false
def Int.Linear.PolyCnstr.getConst : PolyCnstr Int
def Int.Linear.RelCnstr.getConst : RelCnstr Int
| .eq p | .le p => p.getConst
namespace Lean.Meta.Linear.Int
def simpCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do
let some (c, atoms) toLinearCnstr? e | return none
def simpRelCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do
let some (c, atoms) toRawRelCnstr? e | return none
withAbstractAtoms atoms ``Int fun atoms => do
let lhs c.toArith atoms
let p := c.toPoly
let lhs c.denoteExpr atoms
let p := c.norm
if p.isUnsat then
let r := mkConst ``False
let h := mkApp3 (mkConst ``Int.Linear.ExprCnstr.eq_false_of_isUnsat) (toContextExpr atoms) (toExpr c) reflBoolTrue
let h := mkApp3 (mkConst ``Int.Linear.RawRelCnstr.eq_false_of_isUnsat) (toContextExpr atoms) (toExpr c) reflBoolTrue
return some (r, mkExpectedTypeHint h ( mkEq lhs r))
else if p.isValid then
let r := mkConst ``True
let h := mkApp3 (mkConst ``Int.Linear.ExprCnstr.eq_true_of_isValid) (toContextExpr atoms) (toExpr c) reflBoolTrue
let h := mkApp3 (mkConst ``Int.Linear.RawRelCnstr.eq_true_of_isValid) (toContextExpr atoms) (toExpr c) reflBoolTrue
return some (r, mkExpectedTypeHint h ( mkEq lhs r))
else
let c' : LinearCnstr := p.toExprCnstr
let c' := p.toRaw
if c != c' then
match p with
| .eq (.add 1 x (.add (-1) y (.num 0))) =>
let r := mkIntEq atoms[x]! atoms[y]!
let h := mkApp5 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq_var) (toContextExpr atoms) (toExpr x) (toExpr y) (toExpr c) reflBoolTrue
let h := mkApp5 (mkConst ``Int.Linear.RawRelCnstr.eq_of_norm_eq_var) (toContextExpr atoms) (toExpr x) (toExpr y) (toExpr c) reflBoolTrue
return some (r, mkExpectedTypeHint h ( mkEq lhs r))
| .eq (.add 1 x (.num k)) =>
let r := mkIntEq atoms[x]! (toExpr (-k))
let h := mkApp5 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq_const) (toContextExpr atoms) (toExpr x) (toExpr (-k)) (toExpr c) reflBoolTrue
let h := mkApp5 (mkConst ``Int.Linear.RawRelCnstr.eq_of_norm_eq_const) (toContextExpr atoms) (toExpr x) (toExpr (-k)) (toExpr c) reflBoolTrue
return some (r, mkExpectedTypeHint h ( mkEq lhs r))
| _ =>
let k := p.gcdCoeffs
if k == 1 then
let r c'.toArith atoms
let h := mkApp4 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq) (toContextExpr atoms) (toExpr c) (toExpr c') reflBoolTrue
let r c'.denoteExpr atoms
let h := mkApp4 (mkConst ``Int.Linear.RawRelCnstr.eq_of_norm_eq) (toContextExpr atoms) (toExpr c) (toExpr c') reflBoolTrue
return some (r, mkExpectedTypeHint h ( mkEq lhs r))
else if p.getConst % k == 0 then
let c' : LinearCnstr := (p.div k).toExprCnstr
let r c'.toArith atoms
let h := mkApp5 (mkConst ``Int.Linear.ExprCnstr.eq_of_divBy) (toContextExpr atoms) (toExpr c) (toExpr c') (toExpr (Int.ofNat k)) reflBoolTrue
let c' := (p.div k).toRaw
let r c'.denoteExpr atoms
let h := mkApp5 (mkConst ``Int.Linear.RawRelCnstr.eq_of_divBy) (toContextExpr atoms) (toExpr c) (toExpr c') (toExpr (Int.ofNat k)) reflBoolTrue
return some (r, mkExpectedTypeHint h ( mkEq lhs r))
else if p.isEq then
let r := mkConst ``False
let h := mkApp4 (mkConst ``Int.Linear.ExprCnstr.eq_false_of_isUnsat_coeff) (toContextExpr atoms) (toExpr c) (toExpr (Int.ofNat k)) reflBoolTrue
let h := mkApp4 (mkConst ``Int.Linear.RawRelCnstr.eq_false_of_isUnsat_coeff) (toContextExpr atoms) (toExpr c) (toExpr (Int.ofNat k)) reflBoolTrue
return some (r, mkExpectedTypeHint h ( mkEq lhs r))
else
-- `p.isLe`: tighten the bound
let c' : LinearCnstr := (p.div k).toExprCnstr
let r c'.toArith atoms
let h := mkApp5 (mkConst ``Int.Linear.ExprCnstr.eq_of_divByLe) (toContextExpr atoms) (toExpr c) (toExpr c') (toExpr (Int.ofNat k)) reflBoolTrue
let c' := (p.div k).toRaw
let r c'.denoteExpr atoms
let h := mkApp5 (mkConst ``Int.Linear.RawRelCnstr.eq_of_divByLe) (toContextExpr atoms) (toExpr c) (toExpr c') (toExpr (Int.ofNat k)) reflBoolTrue
return some (r, mkExpectedTypeHint h ( mkEq lhs r))
else
return none
def simpCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
def simpRelCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
if let some arg := e.not? then
let mut eNew? := none
let mut thmName := Name.anonymous
@@ -116,7 +116,7 @@ def simpCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
| _ => pure ()
if let some eNew := eNew? then
let h₁ := mkApp2 (mkConst thmName) (arg.getArg! 2) (arg.getArg! 3)
if let some (eNew', h₂) simpCnstrPos? eNew then
if let some (eNew', h₂) simpRelCnstrPos? eNew then
let h := mkApp6 (mkConst ``Eq.trans [levelOne]) (mkSort levelZero) e eNew eNew' h₁ h₂
return some (eNew', h)
else
@@ -124,7 +124,27 @@ def simpCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
else
return none
else
simpCnstrPos? e
simpRelCnstrPos? e
def simpDvdCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
let some (c, atoms) toRawDvdCnstr? e | return none
if c.k == 0 then return none
withAbstractAtoms atoms ``Int fun atoms => do
let lhs c.denoteExpr atoms
let c' := c.norm
let k := c'.p.gcdCoeffs c'.k
if c'.p.getConst % k == 0 then
let c' := c'.div k
let c' := c'.toRaw
if c == c' then
return none
let r c'.denoteExpr atoms
let h := mkApp5 (mkConst ``Int.Linear.RawDvdCnstr.eq_of_isEqv) (toContextExpr atoms) (toExpr c) (toExpr c') (toExpr k) reflBoolTrue
return some (r, mkExpectedTypeHint h ( mkEq lhs r))
else
let r := mkConst ``False
let h := mkApp3 (mkConst ``Int.Linear.RawDvdCnstr.eq_false_of_isUnsat) (toContextExpr atoms) (toExpr c) reflBoolTrue
return some (r, mkExpectedTypeHint h ( mkEq lhs r))
def simpExpr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
let (e, atoms) toLinearExpr e
@@ -133,7 +153,7 @@ def simpExpr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
if e != e' then
-- We only return some if monomials were fused
let p := mkApp4 (mkConst ``Int.Linear.Expr.eq_of_toPoly_eq) (toContextExpr atoms) (toExpr e) (toExpr e') reflBoolTrue
let r LinearExpr.toArith atoms e'
let r e'.denoteExpr atoms
return some (r, p)
else
return none

View File

@@ -13,7 +13,7 @@ namespace Lean.Meta.Linear
def parentIsTarget (parent? : Option Expr) : Bool :=
match parent? with
| none => false
| some parent => isLinearTerm parent || isLinearCnstr parent
| some parent => isLinearTerm parent || isLinearCnstr parent || isDvdCnstr parent
def simp? (e : Expr) (parent? : Option Expr) : MetaM (Option (Expr × Expr)) := do
-- TODO: add support for `Int` and arbitrary ordered comm rings

View File

@@ -108,4 +108,14 @@ builtin_dsimproc [simp, seval] reduceOfNat (Int.ofNat _) := fun e => do
let some a getNatValue? a | return .continue
return .done <| toExpr (Int.ofNat a)
builtin_simproc [simp, seval] reduceDvd ((_ : Int) _) := fun e => do
let_expr Dvd.dvd _ i a b e | return .continue
unless matchesInstance i (mkConst ``instDvd) do return .continue
let some va fromExpr? a | return .continue
let some vb fromExpr? b | return .continue
if vb % va == 0 then
return .done { expr := mkConst ``True, proof? := mkApp3 (mkConst ``Int.dvd_eq_true_of_mod_eq_zero) a b reflBoolTrue}
else
return .done { expr := mkConst ``False, proof? := mkApp3 (mkConst ``Int.dvd_eq_false_of_mod_ne_zero) a b reflBoolTrue}
end Int

View File

@@ -345,4 +345,14 @@ builtin_simproc [simp, seval] reduceSubDiff ((_ - _ : Nat)) := fun e => do
let geProof mkOfDecideEqTrue (mkGENat po no)
applySimprocConst finExpr ``Nat.Simproc.add_sub_add_ge #[pb, nb, po, no, geProof]
builtin_simproc [simp, seval] reduceDvd ((_ : Nat) _) := fun e => do
let_expr Dvd.dvd _ i a b e | return .continue
unless matchesInstance i (mkConst ``instDvd) do return .continue
let some va fromExpr? a | return .continue
let some vb fromExpr? b | return .continue
if vb % va == 0 then
return .done { expr := mkConst ``True, proof? := mkApp3 (mkConst ``Nat.dvd_eq_true_of_mod_eq_zero) a b reflBoolTrue}
else
return .done { expr := mkConst ``False, proof? := mkApp3 (mkConst ``Nat.dvd_eq_false_of_mod_ne_zero) a b reflBoolTrue}
end Nat

View File

@@ -11,7 +11,7 @@ open Lean Meta Simp
macro "declare_uint_simprocs" typeName:ident : command =>
let ofNat := typeName.getId ++ `ofNat
let ofNatCore := mkIdent (typeName.getId ++ `ofNatCore)
let ofNatLT := mkIdent (typeName.getId ++ `ofNatLT)
let toNat := mkIdent (typeName.getId ++ `toNat)
let fromExpr := mkIdent `fromExpr
`(
@@ -54,8 +54,8 @@ builtin_simproc [simp, seval] reduceNe (( _ : $typeName) ≠ _) := reduceBinPr
builtin_dsimproc [simp, seval] reduceBEq (( _ : $typeName) == _) := reduceBoolPred ``BEq.beq 4 (. == .)
builtin_dsimproc [simp, seval] reduceBNe (( _ : $typeName) != _) := reduceBoolPred ``bne 4 (. != .)
builtin_dsimproc [simp, seval] $(mkIdent `reduceOfNatCore):ident ($ofNatCore _ _) := fun e => do
unless e.isAppOfArity $(quote ofNatCore.getId) 2 do return .continue
builtin_dsimproc [simp, seval] $(mkIdent `reduceOfNatLT):ident ($ofNatLT _ _) := fun e => do
unless e.isAppOfArity $(quote ofNatLT.getId) 2 do return .continue
let some value Nat.fromExpr? e.appFn!.appArg! | return .continue
let value := $(mkIdent ofNat) value
return .done <| toExpr value

View File

@@ -483,8 +483,8 @@ def congrDefault (e : Expr) : SimpM Result := do
congrArgs ( simp f) args
/-- Process the given congruence theorem hypothesis. Return true if it made "progress". -/
def processCongrHypothesis (h : Expr) : SimpM Bool := do
forallTelescopeReducing ( inferType h) fun xs hType => withNewLemmas xs do
def processCongrHypothesis (h : Expr) (hType : Expr) : SimpM Bool := do
forallTelescopeReducing hType fun xs hType => withNewLemmas xs do
let lhs instantiateMVars hType.appFn!.appArg!
let r simp lhs
let rhs := hType.appArg!
@@ -521,7 +521,9 @@ def trySimpCongrTheorem? (c : SimpCongrTheorem) (e : Expr) : SimpM (Option Resul
recordCongrTheorem c.theoremName
trace[Debug.Meta.Tactic.simp.congr] "{c.theoremName}, {e}"
let thm mkConstWithFreshMVarLevels c.theoremName
let (xs, bis, type) forallMetaTelescopeReducing ( inferType thm)
let thmType inferType thm
let thmHasBinderNameHint := thmType.hasBinderNameHint
let (xs, bis, type) forallMetaTelescopeReducing thmType
if c.hypothesesPos.any (· xs.size) then
return none
let isIff := type.isAppOf ``Iff
@@ -537,12 +539,14 @@ def trySimpCongrTheorem? (c : SimpCongrTheorem) (e : Expr) : SimpM (Option Resul
if ( withSimpMetaConfig <| isDefEq lhs e) then
let mut modified := false
for i in c.hypothesesPos do
let x := xs[i]!
let h := xs[i]!
let hType instantiateMVars ( inferType h)
let hType if thmHasBinderNameHint then hType.resolveBinderNameHint else pure hType
try
if ( processCongrHypothesis x) then
if ( processCongrHypothesis h hType) then
modified := true
catch _ =>
trace[Meta.Tactic.simp.congr] "processCongrHypothesis {c.theoremName} failed {← inferType x}"
trace[Meta.Tactic.simp.congr] "processCongrHypothesis {c.theoremName} failed {hType}"
-- Remark: we don't need to check ex.isMaxRecDepth anymore since `try .. catch ..`
-- does not catch runtime exceptions by default.
return none

View File

@@ -286,7 +286,7 @@ def simpArith (e : Expr) : SimpM Step := do
if Linear.isLinearCnstr e then
if let some (e', h) Linear.Nat.simpCnstr? e then
return .visit { expr := e', proof? := h }
else if let some (e', h) Linear.Int.simpCnstr? e then
else if let some (e', h) Linear.Int.simpRelCnstr? e then
return .visit { expr := e', proof? := h }
else
return .continue
@@ -300,6 +300,11 @@ def simpArith (e : Expr) : SimpM Step := do
return .visit { expr := e', proof? := h }
else
return .continue
else if Linear.isDvdCnstr e then
if let some (e', h) Linear.Int.simpDvdCnstr? e then
return .visit { expr := e', proof? := h }
else
return .continue
else
return .continue

View File

@@ -0,0 +1,135 @@
/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Kim Morrison
-/
prelude
import Lean.Elab.Command
import Lean.Meta.Eval
import Lean.Meta.CompletionName
import Init.Data.Random
/-!
# An API for premise selection algorithms.
This module provides a basic API for premise selection algorithms,
which are used to suggest identifiers that should be introduced in a proof.
The core interface is the `Selector` type, which is a function from a metavariable
and a configuration to a list of suggestions.
The `Selector` is registered as an environment extension, and the trivial (no suggestions) implementation
is `Lean.PremiseSelection.empty`.
Lean does not provide a default premise selector, so this module is intended to be used in conjunction
with a downstream package which registers a premise selector.
-/
namespace Lean.PremiseSelection
/--
A `Suggestion` is essentially just an identifier and a confidence score that the identifier is relevant.
If the premise selection request included information about the intended use (e.g. in the simplifier, in `grind`, etc.)
the score may be adjusted for that application.
-/
structure Suggestion where
name : Name
/--
The score of the suggestion, as a probability that this suggestion should be used.
-/
score : Float
structure Config where
/--
The maximum number of suggestions to return.
-/
maxSuggestions : Option Nat := none
/--
The tactic that is calling the premise selection, e.g. `simp`, `grind`, or `aesop`.
This may be used to adjust the score of the suggestions
-/
caller : Option Name := none
/--
A filter on suggestions; only suggestions returning `true` should be returned.
(It can be better to filter on the premise selection side, to ensure that enough suggestions are returned.)
-/
filter : Name MetaM Bool := fun _ => pure true
/--
An optional arbitrary "hint" to the premise selection algorithm.
There is no guarantee that the algorithm will make any use of the hint.
Potential use cases include a natural language comment provided by the user
(e.g. allowing use of the premise selector as a search engine)
or including context from the current proof and/or file.
We may later split these use cases into separate fields if necessary.
-/
hint : Option String := none
abbrev Selector : Type := MVarId Config MetaM (Array Suggestion)
/--
The trivial premise selector, which returns no suggestions.
-/
def empty : Selector := fun _ _ => pure #[]
/-- A random premise selection algorithm, provided solely for testing purposes. -/
def random (gen : StdGen := 37, 59) : Selector := fun _ cfg => do
IO.stdGenRef.set gen
let env getEnv
let max := cfg.maxSuggestions.getD 10
let consts := env.const2ModIdx.keysArray
let mut suggestions := #[]
while suggestions.size < max do
let i IO.rand 0 consts.size
let name := consts[i]!
if ! (`Lean).isPrefixOf name && Lean.Meta.allowCompletion env name then
suggestions := suggestions.push { name := name, score := 1.0 / consts.size.toFloat }
return suggestions
initialize premiseSelectorExt : EnvExtension (Option Selector)
registerEnvExtension (pure none)
/-- Generate premise suggestions for the given metavariable, using the currently registered premise selector. -/
def select (m : MVarId) (c : Config := {}) : MetaM (Array Suggestion) := do
let some selector := premiseSelectorExt.getState ( getEnv) |
throwError "No premise selector registered. \
(Note the Lean does not provide a default premise selector, these must be installed by a downstream library.)"
selector m c
/-!
Currently the registration mechanism is just global state.
This means that if multiple modules register premise selectors,
the behaviour will be dependent on the order of loading modules.
We should replace this with a mechanism so that
premise selectors are configured via options in the `lakefile`, and
commands are only used to override in a single declaration or file.
-/
/-- Set the current premise selector.-/
def registerPremiseSelector (selector : Selector) : CoreM Unit := do
modifyEnv fun env => premiseSelectorExt.setState env (some selector)
open Lean Elab Command in
@[builtin_command_elab setPremiseSelectorCmd, inherit_doc setPremiseSelectorCmd]
def elabSetPremiseSelector : CommandElab
| `(command| set_premise_selector $selector) => do
let selector liftTermElabM do
try
let selectorTerm Term.elabTermEnsuringType selector (some (Expr.const ``Selector []))
unsafe Meta.evalExpr Selector (Expr.const ``Selector []) selectorTerm
catch _ =>
throwError "Failed to elaborate {selector} as a `MVarId → Config → MetaM (Array Suggestion)`."
liftCoreM (registerPremiseSelector selector)
| _ => throwUnsupportedSyntax
open Lean.Elab.Tactic in
@[builtin_tactic Lean.Parser.Tactic.suggestPremises] def evalSuggestPremises : Tactic := fun _ =>
liftMetaTactic1 fun mvarId => do
let suggestions select mvarId
logInfo m!"Premise suggestions: {suggestions.map (·.name)}"
return mvarId
end Lean.PremiseSelection

View File

@@ -1386,25 +1386,38 @@ where
let e@(.forallE n d e' i) getExpr | unreachable!
let n if bindingNames.contains n then withFreshMacroScope <| MonadQuotation.addMacroScope n else pure n
let bindingNames := bindingNames.insert n
let stxN := mkIdent n
let curIds := curIds.push stxN
if shouldGroupWithNext bindingNames e e' then
withBindingBody n <| delabParamsAux bindingNames idStx groups curIds
withBindingBody' n (mkAnnotatedIdent n) fun stxN =>
delabParamsAux bindingNames idStx groups (curIds.push stxN)
else
let group withBindingDomain do
/-
`mkGroup` constructs binder syntax for the binder names `curIds : Array Ident`, which all have the same type and binder info.
This being a function is solving the following issue:
- To get the last binder name, we need to be under `withBindingBody'`, which lets us annotate the binder with its fvar.
- However, we should delaborate the binder type from outside `withBindingBody'`.
- Thus, we need to partially construct the binder syntax, waiting on the final value of `curIds`.
-/
let mkGroup : Array Ident DelabM Syntax withBindingDomain do
match i with
| .implicit => `(bracketedBinderF|{$curIds* : $( delabTy)})
| .strictImplicit => `(bracketedBinderF|$curIds* : $( delabTy))
| .instImplicit => `(bracketedBinderF|[$stxN : $( delabTy)])
| .implicit => let ty delabTy; pure fun curIds => `(bracketedBinderF|{$curIds* : $ty})
| .strictImplicit => let ty delabTy; pure fun curIds => `(bracketedBinderF|$curIds* : $ty)
| .instImplicit => let ty delabTy; pure fun curIds => `(bracketedBinderF|[$(curIds[0]!) : $ty])
| _ =>
if d.isOptParam then
`(bracketedBinderF|($curIds* : $( withAppFn <| withAppArg delabTy) := $( withAppArg delabTy)))
let ty withAppFn <| withAppArg delabTy
let val withAppArg delabTy
pure fun curIds => `(bracketedBinderF|($curIds* : $ty := $val))
else if let some (.const tacticDecl _) := d.getAutoParamTactic? then
let ty withAppFn <| withAppArg delabTy
let tacticSyntax ofExcept <| evalSyntaxConstant ( getEnv) ( getOptions) tacticDecl
`(bracketedBinderF|($curIds* : $( withAppFn <| withAppArg delabTy) := by $tacticSyntax))
pure fun curIds => `(bracketedBinderF|($curIds* : $ty := by $tacticSyntax))
else
`(bracketedBinderF|($curIds* : $( delabTy)))
withBindingBody n <| delabParams bindingNames idStx (groups.push group)
let ty delabTy
pure fun curIds => `(bracketedBinderF|($curIds* : $ty))
withBindingBody' n (mkAnnotatedIdent n) fun stxN => do
let curIds := curIds.push stxN
let group mkGroup curIds
delabParams bindingNames idStx (groups.push group)
/-
Given the forall `e` with body `e'`, determines if the binder from `e'` (if it is a forall) should be grouped with `e`'s binder.
-/

View File

@@ -125,6 +125,7 @@ def handleCodeAction (params : CodeActionParams) : RequestM (RequestTask (Array
let caps names.mapM evalCodeActionProvider
return ( builtinCodeActionProviders.get).toList.toArray ++ Array.zip names caps
caps.flatMapM fun (providerName, cap) => do
RequestM.checkCancelled
let cas cap params snap
cas.mapIdxM fun i lca => do
if lca.lazy?.isNone then return lca.eager

View File

@@ -5,6 +5,7 @@ Authors: Leonardo de Moura, Marc Huisinga
-/
prelude
import Lean.Server.Completion.CompletionCollectors
import Lean.Server.RequestCancellation
import Std.Data.HashMap
namespace Lean.Server.Completion
@@ -61,11 +62,12 @@ partial def find?
(cmdStx : Syntax)
(infoTree : InfoTree)
(caps : ClientCapabilities)
: IO CompletionList := do
: CancellableM CompletionList := do
let prioritizedPartitions := findPrioritizedCompletionPartitionsAt fileMap hoverPos cmdStx infoTree
let mut allCompletions := #[]
for partition in prioritizedPartitions do
for (i, completionInfoPos) in partition do
CancellableM.checkCancelled
let completions : Array ScoredCompletionItem
match i.info with
| .id stx id danglingDot lctx .. =>

View File

@@ -8,6 +8,7 @@ import Lean.Data.FuzzyMatching
import Lean.Elab.Tactic.Doc
import Lean.Server.Completion.CompletionResolution
import Lean.Server.Completion.EligibleHeaderDecls
import Lean.Server.RequestCancellation
namespace Lean.Server.Completion
open Elab
@@ -36,7 +37,7 @@ section Infrastructure
Monad used for completion computation that allows modifying a completion `State` and reading
`CompletionParams`.
-/
private abbrev M := ReaderT Context $ StateRefT State MetaM
private abbrev M := ReaderT Context $ StateRefT State $ CancellableT MetaM
/-- Adds a new completion item to the state in `M`. -/
private def addItem
@@ -114,10 +115,13 @@ section Infrastructure
(ctx : ContextInfo)
(lctx : LocalContext)
(x : M Unit)
: IO (Array ScoredCompletionItem) :=
ctx.runMetaM lctx do
let (_, s) x.run params, completionInfoPos |>.run {}
return s.items
: CancellableM (Array ScoredCompletionItem) := do
let tk read
let r ctx.runMetaM lctx do
x.run params, completionInfoPos |>.run {} |>.run tk
match r with
| .error _ => throw .requestCancelled
| .ok (_, s) => return s.items
end Infrastructure
@@ -161,6 +165,16 @@ section Utils
return fuzzyMatchScoreWithThreshold? s₁ s₂ |>.map (declName, · / (p₂.getNumParts + 1).toFloat)
return none
private def forEligibleDeclsWithCancellationM [Monad m] [MonadEnv m]
[MonadLiftT (ST IO.RealWorld) m] [MonadCancellable m] [MonadLiftT IO m]
(f : Name ConstantInfo m PUnit) : m PUnit := do
let _ StateT.run (s := 0) <| forEligibleDeclsM fun decl ci => do
modify (· + 1)
if ( get) >= 10000 then
RequestCancellation.check
set <| 0
f decl ci
end Utils
section IdCompletionUtils
@@ -349,7 +363,7 @@ private def idCompletionCore
addUnresolvedCompletionItem localDecl.userName (.fvar localDecl.fvarId) (kind := CompletionItemKind.variable) score
-- search for matches in the environment
let env getEnv
forEligibleDeclsM fun declName c => do
forEligibleDeclsWithCancellationM fun declName c => do
let bestMatch? (·.2) <$> StateT.run (s := none) do
let matchUsingNamespace (ns : Name) : StateT (Option (Name × Float)) M Unit := do
let some (label, score) matchDecl? ns id danglingDot declName
@@ -380,6 +394,7 @@ private def idCompletionCore
matchUsingNamespace Name.anonymous
if let some (bestLabel, bestScore) := bestMatch? then
addUnresolvedCompletionItem bestLabel (.const declName) ( getCompletionKindForDecl c) bestScore
RequestCancellation.check
let matchAlias (ns : Name) (alias : Name) : Option Float :=
-- Recall that aliases may not be atomic and include the namespace where they were created.
if ns.isPrefixOf alias then
@@ -434,7 +449,7 @@ def idCompletion
(id : Name)
(hoverInfo : HoverInfo)
(danglingDot : Bool)
: IO (Array ScoredCompletionItem) :=
: CancellableM (Array ScoredCompletionItem) :=
runM params completionInfoPos ctx lctx do
idCompletionCore ctx stx id hoverInfo danglingDot
@@ -443,7 +458,7 @@ def dotCompletion
(completionInfoPos : Nat)
(ctx : ContextInfo)
(info : TermInfo)
: IO (Array ScoredCompletionItem) :=
: CancellableM (Array ScoredCompletionItem) :=
runM params completionInfoPos ctx info.lctx do
let nameSet try
getDotCompletionTypeNames ( instantiateMVars ( inferType info.expr))
@@ -452,7 +467,7 @@ def dotCompletion
if nameSet.isEmpty then
return
forEligibleDeclsM fun declName c => do
forEligibleDeclsWithCancellationM fun declName c => do
let unnormedTypeName := declName.getPrefix
if ! nameSet.contains unnormedTypeName then
return
@@ -471,7 +486,7 @@ def dotIdCompletion
(lctx : LocalContext)
(id : Name)
(expectedType? : Option Expr)
: IO (Array ScoredCompletionItem) :=
: CancellableM (Array ScoredCompletionItem) :=
runM params completionInfoPos ctx lctx do
let some expectedType := expectedType?
| return ()
@@ -485,7 +500,7 @@ def dotIdCompletion
catch _ =>
pure RBTree.empty
forEligibleDeclsM fun declName c => do
forEligibleDeclsWithCancellationM fun declName c => do
let unnormedTypeName := declName.getPrefix
if ! nameSet.contains unnormedTypeName then
return
@@ -513,7 +528,7 @@ def fieldIdCompletion
(lctx : LocalContext)
(id : Option Name)
(structName : Name)
: IO (Array ScoredCompletionItem) :=
: CancellableM (Array ScoredCompletionItem) :=
runM params completionInfoPos ctx lctx do
let idStr := id.map (·.toString) |>.getD ""
let fieldNames := getStructureFieldsFlattened ( getEnv) structName (includeSubobjectFields := false)

View File

@@ -201,7 +201,7 @@ This option can only be set on the command line, not in the lakefile or via `set
let t BaseIO.asTask do
IO.sleep (server.reportDelayMs.get ctx.cmdlineOpts).toUInt32 -- "Debouncing 1."
BaseIO.bindTask t fun _ => do
let (_, st) handleTasks #[.pure <| toSnapshotTree doc.initSnap] |>.run {}
let (_, st) handleTasks #[.finished none <| toSnapshotTree doc.initSnap] |>.run {}
if ( cancelTk.isSet) then
return .pure ()
@@ -246,8 +246,8 @@ This option can only be set on the command line, not in the lakefile or via `set
handleNode t.task.get
-- limit children's reported range to that of the parent, if any, to avoid strange
-- non-monotonic progress updates; replace missing children's ranges with parent's
let ts := t.task.get.children.map (fun t' => { t' with range? :=
match t.range?, t'.range? with
let ts := t.task.get.children.map (fun t' => { t' with reportingRange? :=
match t.reportingRange?, t'.reportingRange? with
| some r, some r' =>
let start := max r.start r'.start
let stop := min r.stop r'.stop
@@ -288,7 +288,7 @@ This option can only be set on the command line, not in the lakefile or via `set
/-- Reports given tasks' ranges, merging overlapping ones. -/
sendFileProgress (tasks : Array (SnapshotTask SnapshotTree)) : StateT ReportSnapshotsState BaseIO Unit := do
let ranges := tasks.filterMap (·.range?)
let ranges := tasks.filterMap (·.reportingRange?)
let ranges := ranges.qsort (·.start < ·.start)
let ranges := ranges.foldl (init := #[]) fun rs r => match rs[rs.size - 1]? with
| some last =>
@@ -543,14 +543,14 @@ section NotificationHandling
let newDocText := foldDocumentChanges changes oldDoc.meta.text
updateDocument docId.uri, newVersion, newDocText, oldDoc.meta.dependencyBuildMode
for (_, r) in st.pendingRequests do
r.cancelTk.cancel .edit
r.cancelTk.cancelByEdit
def handleCancelRequest (p : CancelParams) : WorkerM Unit := do
let st get
let some r := st.pendingRequests.find? p.id
| return
r.cancelTk.cancel .cancelRequest
r.cancelTk.cancelByCancelRequest
set <| { st with pendingRequests := st.pendingRequests.erase p.id }
/--
@@ -741,6 +741,12 @@ section MessageHandling
pure <| Task.pure <| .ok ()
| Except.ok t => (IO.mapTask · t) fun
| Except.ok r => do
if cancelTk.wasCancelledByCancelRequest then
-- Try not to emit a partial response if this request was cancelled.
-- Clients usually discard responses for requests that they cancelled anyways,
-- but it's still good to send less over the wire in this case.
emitResponse ctx (isComplete := false) <| RequestError.requestCancelled.toLspResponseError id
return
emitResponse ctx (isComplete := r.isComplete) <| .response id (toJson r.response)
| Except.error e =>
emitResponse ctx (isComplete := false) <| e.toLspResponseError id

View File

@@ -121,7 +121,7 @@ def handleInlayHints (_ : InlayHintParams) (s : InlayHintState) :
| some lastEditTimestamp =>
let timeSinceLastEditMs := timestamp - lastEditTimestamp
inlayHintEditDelayMs - timeSinceLastEditMs
let (snaps, _, isComplete) ctx.doc.cmdSnaps.getFinishedPrefixWithConsistentLatency editDelayMs.toUInt32 (cancelTk? := ctx.cancelTk.truncatedTask)
let (snaps, _, isComplete) ctx.doc.cmdSnaps.getFinishedPrefixWithConsistentLatency editDelayMs.toUInt32 (cancelTk? := ctx.cancelTk.cancellationTask)
let finishedRange? : Option String.Range := do
return 0, List.max? <| snaps.map (fun s => s.endPos)
let oldInlayHints :=
@@ -143,7 +143,6 @@ def handleInlayHints (_ : InlayHintParams) (s : InlayHintState) :
let lspInlayHints inlayHints.mapM (·.toLspInlayHint srcSearchPath ctx.doc.meta.text)
let r := { response := lspInlayHints, isComplete }
let s := { s with oldInlayHints := inlayHints }
RequestM.checkCanceled
return (r, s)
def handleInlayHintsDidChange (p : DidChangeTextDocumentParams)

View File

@@ -23,10 +23,14 @@ def findCompletionCmdDataAtPos
(doc : EditableDocument)
(pos : String.Pos)
: Task (Option (Syntax × Elab.InfoTree)) :=
findCmdDataAtPos doc (pos := pos) fun s => Id.run do
let some tailPos := s.stx.getTailPos?
| return false
return pos.byteIdx <= tailPos.byteIdx + s.stx.getTrailingSize
-- `findCmdDataAtPos` may produce an incorrect snapshot when `pos` is in whitespace.
-- However, most completions don't need trailing whitespace at the term level;
-- synthetic completions are the only notions of completion that care care about whitespace.
-- Synthetic tactic completion only needs the `ContextInfo` of the command, so any snapshot
-- will do.
-- Synthetic field completion in `{ }` doesn't care about whitespace;
-- synthetic field completion in `where` only needs to gather the expected type.
findCmdDataAtPos doc pos (includeStop := true)
def handleCompletion (p : CompletionParams)
: RequestM (RequestTask CompletionList) := do
@@ -245,14 +249,54 @@ def handleDefinition (kind : GoToKind) (p : TextDocumentPositionParams)
locationLinksOfInfo kind infoWithCtx snap.infoTree
else return #[]
open Language in
def findGoalsAt? (doc : EditableDocument) (hoverPos : String.Pos) : Task (Option (List Elab.GoalsAtResult)) :=
let text := doc.meta.text
findCmdParsedSnap doc hoverPos |>.bind (sync := true) fun
| some cmdParsed =>
let t := toSnapshotTree cmdParsed |>.foldSnaps [] fun snap oldGoals => Id.run do
let some (pos, tailPos, trailingPos) := getPositions snap
| return .pure (oldGoals, .proceed (foldChildren := false))
let snapRange : String.Range := pos, trailingPos
-- When there is no trailing whitespace, we also consider snapshots directly before the
-- cursor.
let hasNoTrailingWhitespace := tailPos == trailingPos
if ! text.rangeContainsHoverPos snapRange hoverPos (includeStop := hasNoTrailingWhitespace) then
return .pure (oldGoals, .proceed (foldChildren := false))
return snap.task.map (sync := true) fun tree => Id.run do
let some infoTree := tree.element.infoTree?
| return (oldGoals, .proceed (foldChildren := true))
let goals := infoTree.goalsAt? text hoverPos
let optimalSnapRange : String.Range := pos, tailPos
let isOptimalGoalSet :=
text.rangeContainsHoverPos optimalSnapRange hoverPos
(includeStop := hasNoTrailingWhitespace)
|| goals.any fun goal => ! goal.indented
if isOptimalGoalSet then
return (goals, .done)
return (goals, .proceed (foldChildren := true))
t.map fun
| [] => none
| goals => goals
| none =>
.pure none
where
getPositions (snap : SnapshotTask SnapshotTree) : Option (String.Pos × String.Pos × String.Pos) := do
let stx snap.stx?
let pos stx.getPos? (canonicalOnly := true)
let tailPos stx.getTailPos? (canonicalOnly := true)
let trailingPos? stx.getTrailingTailPos? (canonicalOnly := true)
return (pos, tailPos, trailingPos?)
open RequestM in
def getInteractiveGoals (p : Lsp.PlainGoalParams) : RequestM (RequestTask (Option Widget.InteractiveGoals)) := do
let doc readDoc
let text := doc.meta.text
let hoverPos := text.lspPosToUtf8Pos p.position
mapTask (findInfoTreeAtPosWithTrailingWhitespace doc hoverPos) <| Option.bindM fun infoTree => do
let rs@(_ :: _) := infoTree.goalsAt? doc.meta.text hoverPos
| return none
mapTask (findGoalsAt? doc hoverPos) <| Option.mapM fun rs => do
let goals : List Widget.InteractiveGoals rs.mapM fun { ctxInfo := ci, tacticInfo := ti, useAfter := useAfter, .. } => do
let ciAfter := { ci with mctx := ti.mctxAfter }
let ci := if useAfter then ciAfter else { ci with mctx := ti.mctxBefore }
@@ -270,7 +314,7 @@ def getInteractiveGoals (p : Lsp.PlainGoalParams) : RequestM (RequestTask (Optio
-- fail silently, since this is just a bonus feature
return goals
)
return some <| goals.foldl (· ++ ·)
return goals.foldl (· ++ ·)
open Elab in
def handlePlainGoal (p : PlainGoalParams)
@@ -292,7 +336,7 @@ def getInteractiveTermGoal (p : Lsp.PlainTermGoalParams)
let doc readDoc
let text := doc.meta.text
let hoverPos := text.lspPosToUtf8Pos p.position
mapTask (findInfoTreeAtPosWithTrailingWhitespace doc hoverPos) <| Option.bindM fun infoTree => do
mapTask (findInfoTreeAtPos doc hoverPos (includeStop := true)) <| Option.bindM fun infoTree => do
let some {ctx := ci, info := i@(Elab.Info.ofTermInfo ti), ..} := infoTree.termGoalAt? hoverPos
| return none
let ty ci.runMetaM i.lctx do
@@ -382,13 +426,14 @@ partial def handleDocumentSymbol (_ : DocumentSymbolParams)
let t := doc.cmdSnaps.waitAll
mapTask t fun (snaps, _) => do
let mut stxs := snaps.map (·.stx)
return { syms := toDocumentSymbols doc.meta.text stxs #[] [] }
return { syms := toDocumentSymbols doc.meta.text stxs #[] [] }
where
toDocumentSymbols (text : FileMap) (stxs : List Syntax)
(syms : Array DocumentSymbol) (stack : List NamespaceEntry) :
Array DocumentSymbol :=
RequestM (Array DocumentSymbol) := do
RequestM.checkCancelled
match stxs with
| [] => stack.foldl (fun syms entry => entry.finish text syms none) syms
| [] => return stack.foldl (fun syms entry => entry.finish text syms none) syms
| stx::stxs => match stx with
| `(namespace $id) =>
let entry := { name := id.getId.componentsRev, stx, selection := id, prevSiblings := syms }
@@ -411,9 +456,9 @@ where
let syms := entry.finish text syms stx
popStack (n - entry.name.length) syms stack
popStack (id.map (·.getId.getNumParts) |>.getD 1) syms stack
| _ => Id.run do
| _ => do
unless stx.isOfKind ``Lean.Parser.Command.declaration do
return toDocumentSymbols text stxs syms stack
return toDocumentSymbols text stxs syms stack
if let some stxRange := stx.getRange? then
let (name, selection) := match stx with
| `($_:declModifiers $_:attrKind instance $[$np:namedPrio]? $[$id$[.{$ls,*}]?]? $sig:declSig $_) =>
@@ -431,7 +476,7 @@ where
range := stxRange.toLspRange text
selectionRange := selRange.toLspRange text
}
return toDocumentSymbols text stxs (syms.push sym) stack
return toDocumentSymbols text stxs (syms.push sym) stack
toDocumentSymbols text stxs syms stack
partial def handleFoldingRange (_ : FoldingRangeParams)
@@ -450,7 +495,9 @@ partial def handleFoldingRange (_ : FoldingRangeParams)
if let (_, start)::rest := sections then
addRange text FoldingRangeKind.region start text.source.endPos
addRanges text rest []
| stx::stxs => match stx with
| stx::stxs => do
RequestM.checkCancelled
match stx with
| `(namespace $id) =>
addRanges text ((id.getId.getNumParts, stx.getPos?)::sections) stxs
| `(section $(id)?) =>

View File

@@ -147,13 +147,12 @@ def handleSemanticTokens (beginPos : String.Pos) (endPos? : Option String.Pos)
-- for the full file before sending a response. This means that the response will be incomplete,
-- which we mitigate by regularly sending `workspace/semanticTokens/refresh` requests in the
-- `FileWorker` to tell the client to re-compute the semantic tokens.
let (snaps, _, isComplete) doc.cmdSnaps.getFinishedPrefixWithTimeout 3000 (cancelTk? := ctx.cancelTk.truncatedTask)
let (snaps, _, isComplete) doc.cmdSnaps.getFinishedPrefixWithTimeout 3000 (cancelTk? := ctx.cancelTk.cancellationTask)
asTask <| do
return { response := run doc snaps, isComplete }
| some endPos =>
let t := doc.cmdSnaps.waitUntil (·.endPos >= endPos)
mapTask t fun (snaps, _) => do
RequestM.checkCanceled
return { response := run doc snaps, isComplete := true }
where
run doc snaps : RequestM SemanticTokens := do
@@ -164,8 +163,11 @@ where
let syntaxBasedSemanticTokens := collectSyntaxBasedSemanticTokens s.stx
let infoBasedSemanticTokens := collectInfoBasedSemanticTokens s.infoTree
leanSemanticTokens := leanSemanticTokens ++ syntaxBasedSemanticTokens ++ infoBasedSemanticTokens
RequestM.checkCancelled
let absoluteLspSemanticTokens := computeAbsoluteLspSemanticTokens doc.meta.text beginPos endPos? leanSemanticTokens
RequestM.checkCancelled
let absoluteLspSemanticTokens := filterDuplicateSemanticTokens absoluteLspSemanticTokens
RequestM.checkCancelled
let semanticTokens := computeDeltaLspSemanticTokens absoluteLspSemanticTokens
return semanticTokens

View File

@@ -0,0 +1,77 @@
/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Marc Huisinga
-/
prelude
import Init.System.Promise
namespace Lean.Server
structure RequestCancellationToken where
cancelledByCancelRequest : IO.Ref Bool
cancelledByEdit : IO.Ref Bool
cancellationPromise : IO.Promise Unit
namespace RequestCancellationToken
def new : IO RequestCancellationToken := do
return {
cancelledByCancelRequest := IO.mkRef false
cancelledByEdit := IO.mkRef false
cancellationPromise := IO.Promise.new
}
def cancelByCancelRequest (tk : RequestCancellationToken) : IO Unit := do
tk.cancelledByCancelRequest.set true
tk.cancellationPromise.resolve ()
def cancelByEdit (tk : RequestCancellationToken) : IO Unit := do
tk.cancelledByEdit.set true
tk.cancellationPromise.resolve ()
def cancellationTask (tk : RequestCancellationToken) : Task Unit :=
tk.cancellationPromise.result!
def wasCancelledByCancelRequest (tk : RequestCancellationToken) : IO Bool :=
tk.cancelledByCancelRequest.get
def wasCancelledByEdit (tk : RequestCancellationToken) : IO Bool := do
tk.cancelledByEdit.get
end RequestCancellationToken
structure RequestCancellation where
def RequestCancellation.requestCancelled : RequestCancellation := {}
abbrev CancellableT m := ReaderT RequestCancellationToken (ExceptT RequestCancellation m)
abbrev CancellableM := CancellableT IO
def CancellableT.run (tk : RequestCancellationToken) (x : CancellableT m α) :
m (Except RequestCancellation α) :=
x tk
def CancellableM.run (tk : RequestCancellationToken) (x : CancellableM α) :
IO (Except RequestCancellation α) :=
CancellableT.run tk x
def CancellableT.checkCancelled [Monad m] [MonadLiftT IO m] : CancellableT m Unit := do
let tk read
if tk.wasCancelledByCancelRequest then
throw .requestCancelled
def CancellableM.checkCancelled : CancellableM Unit :=
CancellableT.checkCancelled
class MonadCancellable (m : Type Type v) where
checkCancelled : m PUnit
instance (m n) [MonadLift m n] [MonadCancellable m] : MonadCancellable n where
checkCancelled := liftM (MonadCancellable.checkCancelled : m PUnit)
instance [Monad m] [MonadLiftT IO m] : MonadCancellable (CancellableT m) where
checkCancelled := CancellableT.checkCancelled
def RequestCancellation.check [MonadCancellable m] : m Unit :=
MonadCancellable.checkCancelled

View File

@@ -11,6 +11,8 @@ import Lean.Data.Json
import Lean.Data.Lsp
import Lean.Elab.Command
import Lean.Server.RequestCancellation
import Lean.Server.FileSource
import Lean.Server.FileWorker.Utils
@@ -18,29 +20,72 @@ import Lean.Server.Rpc.Basic
import Std.Sync.Mutex
import Std.Sync.Mutex
/-- Checks whether `r` contains `hoverPos`, taking into account EOF according to `text`. -/
def Lean.FileMap.rangeContainsHoverPos (text : Lean.FileMap) (r : String.Range)
(hoverPos : String.Pos) (includeStop := false) : Bool :=
-- When `hoverPos` is at the very end of the file, it is *after* the last position in `text`.
-- However, for `includeStop = false`, all ranges stop at the last position in `text`,
-- which always excludes a `hoverPos` at the very end of the file.
-- For the purposes of the language server, we generally assume that ranges that extend to
-- the end of the file also include a `hoverPos` at the very end of the file.
let isRangeAtEOF := r.stop == text.source.endPos
r.contains hoverPos (includeStop := includeStop || isRangeAtEOF)
namespace Lean.Language
/--
Finds the first (in pre-order) snapshot task in `tree` whose `range?` contains `pos` and which
contains an info tree, and then returns that info tree, waiting for any snapshot tasks on the way.
Subtrees that do not contain the position are skipped without forcing their tasks.
-/
partial def SnapshotTree.findInfoTreeAtPos (tree : SnapshotTree) (pos : String.Pos) :
Task (Option Elab.InfoTree) :=
goSeq tree.children.toList
inductive SnapshotTree.foldSnaps.Control where
| done
| proceed (foldChildren : Bool)
partial def SnapshotTree.foldSnaps (tree : SnapshotTree) (init : α)
(f : SnapshotTask SnapshotTree α Task (α × foldSnaps.Control)) : Task α :=
let t := traverseTree init tree
t.map (sync := true) (·.1)
where
goSeq
| [] => .pure none
| t::ts =>
if t.range?.any (·.contains pos) then
t.task.bind (sync := true) fun tree => Id.run do
if let some infoTree := tree.element.infoTree? then
return .pure infoTree
tree.findInfoTreeAtPos pos |>.bind (sync := true) fun
| some infoTree => .pure (some infoTree)
| none => goSeq ts
else
goSeq ts
traverseTree (acc : α) (tree : SnapshotTree) : Task (α × Bool) :=
traverseChildren acc tree.children.toList
traverseChildren (acc : α) : List (SnapshotTask SnapshotTree) Task (α × Bool)
| [] => .pure (acc, false)
| child::otherChildren =>
f child acc |>.bind (sync := true) fun (acc, control) => Id.run do
let .proceed foldChildrenOfChild := control
| return .pure (acc, true)
if ! foldChildrenOfChild then
return traverseChildren acc otherChildren
let subtreeTask := child.task.bind (sync := true) fun tree =>
traverseTree acc tree
return subtreeTask.bind (sync := true) fun (acc, done) => Id.run do
if done then
return .pure (acc, done)
return traverseChildren acc otherChildren
/--
Finds the first (in pre-order) snapshot task in `tree` that contains `hoverPos`
(including whitespace) and which contains an info tree, and then returns that info tree,
waiting for any snapshot tasks on the way.
Subtrees that do not contain the position are skipped without forcing their tasks.
If the caller of this function needs the correct snapshot when the cursor is on whitespace,
then this function is likely the wrong one to call, as it simply yields the first snapshot
that contains `hoverPos` in its whitespace, which is not necessarily the correct one
(e.g. it may be indentation-sensitive).
-/
partial def SnapshotTree.findInfoTreeAtPos (text : FileMap) (tree : SnapshotTree)
(hoverPos : String.Pos) (includeStop : Bool) : Task (Option Elab.InfoTree) :=
tree.foldSnaps (init := none) fun snap _ => Id.run do
let skipChild := .pure (none, .proceed (foldChildren := false))
let some stx := snap.stx?
| return skipChild
let some range := stx.getRangeWithTrailing? (canonicalOnly := true)
| return skipChild
if ! text.rangeContainsHoverPos range hoverPos includeStop then
return skipChild
return snap.task.map (sync := true) fun tree => Id.run do
let some infoTree := tree.element.infoTree?
| return (none, .proceed (foldChildren := true))
return (infoTree, .done)
end Lean.Language
@@ -84,47 +129,6 @@ def toLspResponseError (id : RequestID) (e : RequestError) : ResponseError Unit
end RequestError
inductive RequestCancellationCause where
| cancelRequest
| edit
deriving Inhabited, BEq
structure RequestCancellationToken where
promise : IO.Promise RequestCancellationCause
namespace RequestCancellationToken
def new : IO RequestCancellationToken := do
return { promise := IO.Promise.new }
def cancel (tk : RequestCancellationToken) (cause : RequestCancellationCause) : IO Unit :=
tk.promise.resolve cause
def task (tk : RequestCancellationToken) : Task RequestCancellationCause :=
tk.promise.result!
def truncatedTask (tk : RequestCancellationToken) : Task Unit :=
tk.task.map (sync := true) fun _ => ()
def cancelled? (tk : RequestCancellationToken) : IO (Option RequestCancellationCause) := do
let t := tk.task
if IO.hasFinished t then
return some t.get
else
return none
def wasCancelledByCancelRequest (tk : RequestCancellationToken) : IO Bool := do
let some c tk.cancelled?
| return false
return c matches .cancelRequest
def wasCancelledByEdit (tk : RequestCancellationToken) : IO Bool := do
let some c tk.cancelled?
| return false
return c matches .edit
end RequestCancellationToken
def parseRequestParams (paramType : Type) [FromJson paramType] (params : Json)
: Except RequestError paramType :=
fromJson? params |>.mapError fun inner =>
@@ -158,6 +162,14 @@ instance : MonadLift (EIO Exception) RequestM where
| .error e => throw <| RequestError.ofException e
| .ok v => return v
instance : MonadLift CancellableM RequestM where
monadLift x := do
let ctx read
let r x.run ctx.cancelTk
match r with
| .error _ => throw RequestError.requestCancelled
| .ok v => return v
namespace RequestM
open FileWorker
open Snapshots
@@ -181,7 +193,7 @@ def bindTask (t : Task α) (f : α → RequestM (RequestTask β)) : RequestM (Re
let rc readThe RequestContext
EIO.bindTask t (f · rc)
def checkCanceled : RequestM Unit := do
def checkCancelled : RequestM Unit := do
let rc readThe RequestContext
if rc.cancelTk.wasCancelledByCancelRequest then
throw .requestCancelled
@@ -226,9 +238,9 @@ def withWaitFindSnapAtPos
(x := f)
open Language.Lean in
/-- Finds the first `CommandParsedSnapshot` fulfilling `p`, asynchronously. -/
partial def findCmdParsedSnap (doc : EditableDocument) (p : CommandParsedSnapshot Bool) :
Task (Option CommandParsedSnapshot) := Id.run do
/-- Finds the first `CommandParsedSnapshot` containing `hoverPos`, asynchronously. -/
partial def findCmdParsedSnap (doc : EditableDocument) (hoverPos : String.Pos)
: Task (Option CommandParsedSnapshot) := Id.run do
let some headerParsed := doc.initSnap.result?
| .pure none
headerParsed.processedSnap.task.bind (sync := true) fun headerProcessed => Id.run do
@@ -236,50 +248,43 @@ partial def findCmdParsedSnap (doc : EditableDocument) (p : CommandParsedSnapsho
| return .pure none
headerSuccess.firstCmdSnap.task.bind (sync := true) go
where
go cmdParsed :=
if p cmdParsed then
.pure (some cmdParsed)
else
match cmdParsed.nextCmdSnap? with
| some next => next.task.bind (sync := true) go
| none => .pure none
open Language in
/--
Finds the info tree of the first snapshot task matching `isMatchingSnapshot` and containing `pos`,
asynchronously. The info tree may be from a nested snapshot, such as a single tactic.
See `SnapshotTree.findInfoTreeAtPos` for details on how the search is done.
-/
partial def findInfoTreeAtPos
(doc : EditableDocument)
(isMatchingSnapshot : Lean.CommandParsedSnapshot Bool)
(pos : String.Pos)
: Task (Option Elab.InfoTree) :=
findCmdParsedSnap doc (isMatchingSnapshot ·) |>.bind (sync := true) fun
| some cmdParsed => toSnapshotTree cmdParsed |>.findInfoTreeAtPos pos |>.bind (sync := true) fun
| some infoTree => .pure <| some infoTree
| none => cmdParsed.finishedSnap.task.map (sync := true) fun s =>
-- the parser returns exactly one command per snapshot, and the elaborator creates exactly one node per command
assert! s.cmdState.infoState.trees.size == 1
some s.cmdState.infoState.trees[0]!
go (cmdParsed : CommandParsedSnapshot) : Task (Option CommandParsedSnapshot) := Id.run do
if containsHoverPos cmdParsed then
return .pure (some cmdParsed)
if isAfterHoverPos cmdParsed then
-- This should never happen in principle
-- (commands + trailing ws are consecutive and there is no unassigned space between them),
-- but it's always good to eliminate one additional assumption.
return .pure none
match cmdParsed.nextCmdSnap? with
| some next => next.task.bind (sync := true) go
| none => .pure none
containsHoverPos (cmdParsed : CommandParsedSnapshot) : Bool := Id.run do
let some range := cmdParsed.stx.getRangeWithTrailing? (canonicalOnly := true)
| return false
return doc.meta.text.rangeContainsHoverPos range hoverPos (includeStop := false)
isAfterHoverPos (cmdParsed : CommandParsedSnapshot) : Bool := Id.run do
let some startPos := cmdParsed.stx.getPos? (canonicalOnly := true)
| return false
return hoverPos < startPos
open Language in
/--
Finds the command syntax and info tree of the first snapshot task matching `isMatchingSnapshot` and
containing `pos`, asynchronously. The info tree may be from a nested snapshot,
such as a single tactic.
Finds the command syntax and info tree of the first snapshot task containing `pos`, asynchronously.
The info tree may be from a nested snapshot, such as a single tactic.
See `SnapshotTree.findInfoTreeAtPos` for details on how the search is done.
-/
def findCmdDataAtPos
(doc : EditableDocument)
(isMatchingSnapshot : Lean.CommandParsedSnapshot Bool)
(pos : String.Pos)
(hoverPos : String.Pos)
(includeStop : Bool)
: Task (Option (Syntax × Elab.InfoTree)) :=
findCmdParsedSnap doc (isMatchingSnapshot ·) |>.bind (sync := true) fun
| some cmdParsed => toSnapshotTree cmdParsed |>.findInfoTreeAtPos pos |>.bind (sync := true) fun
findCmdParsedSnap doc hoverPos |>.bind (sync := true) fun
| some cmdParsed => toSnapshotTree cmdParsed |>.findInfoTreeAtPos doc.meta.text hoverPos includeStop |>.bind (sync := true) fun
| some infoTree => .pure <| some (cmdParsed.stx, infoTree)
| none => cmdParsed.finishedSnap.task.map (sync := true) fun s =>
-- the parser returns exactly one command per snapshot, and the elaborator creates exactly one node per command
@@ -287,19 +292,19 @@ def findCmdDataAtPos
some (cmdParsed.stx, s.cmdState.infoState.trees[0]!)
| none => .pure none
open Language in
/--
Finds the info tree of the first snapshot task containing `pos` (including trailing whitespace),
asynchronously. The info tree may be from a nested snapshot, such as a single tactic.
Finds the info tree of the first snapshot task containing `pos`, asynchronously.
The info tree may be from a nested snapshot, such as a single tactic.
See `SnapshotTree.findInfoTreeAtPos` for details on how the search is done.
-/
def findInfoTreeAtPosWithTrailingWhitespace
partial def findInfoTreeAtPos
(doc : EditableDocument)
(pos : String.Pos)
(hoverPos : String.Pos)
(includeStop : Bool)
: Task (Option Elab.InfoTree) :=
-- NOTE: use `>=` since the cursor can be *after* the input (and there is no interesting info on
-- the first character of the subsequent command if any)
findInfoTreeAtPos doc (·.parserState.pos pos) pos
findCmdDataAtPos doc hoverPos includeStop |>.map (sync := true) (·.map (·.2))
open Elab.Command in
def runCommandElabM (snap : Snapshot) (c : RequestT CommandElabM α) : RequestM α := do

View File

@@ -41,6 +41,9 @@ def SourceInfo.updateTrailing (trailing : Substring) : SourceInfo → SourceInfo
def SourceInfo.getRange? (canonicalOnly := false) (info : SourceInfo) : Option String.Range :=
return ( info.getPos? canonicalOnly), ( info.getTailPos? canonicalOnly)
def SourceInfo.getRangeWithTrailing? (canonicalOnly := false) (info : SourceInfo) : Option String.Range :=
return info.getPos? canonicalOnly, info.getTrailingTailPos? canonicalOnly
/--
Converts an `original` or `synthetic (canonical := true)` `SourceInfo` to a
`synthetic (canonical := false)` `SourceInfo`.
@@ -388,6 +391,9 @@ def getRange? (stx : Syntax) (canonicalOnly := false) : Option String.Range :=
| some start, some stop => some { start, stop }
| _, _ => none
def getRangeWithTrailing? (stx : Syntax) (canonicalOnly := false) : Option String.Range :=
return stx.getPos? canonicalOnly, stx.getTrailingTailPos? canonicalOnly
/-- Returns a synthetic Syntax which has the specified `String.Range`. -/
def ofRange (range : String.Range) (canonical := true) : Lean.Syntax :=
.atom (.synthetic range.start range.stop canonical) ""

View File

@@ -397,10 +397,10 @@ structure GetWidgetsResponse where
open Lean Server RequestM in
/-- Get the panel widgets present around a particular position. -/
def getWidgets (pos : Lean.Lsp.Position) : RequestM (RequestTask (GetWidgetsResponse)) := do
def getWidgets (pos : Lean.Lsp.Position) : RequestM (RequestTask GetWidgetsResponse) := do
let doc readDoc
let filemap := doc.meta.text
mapTask (findInfoTreeAtPosWithTrailingWhitespace doc <| filemap.lspPosToUtf8Pos pos) fun
mapTask (findInfoTreeAtPos doc (filemap.lspPosToUtf8Pos pos) (includeStop := true)) fun
| some infoTree@(.context (.commandCtx cc) _) =>
ContextInfo.runMetaM { cc with } {} do
let env getEnv

View File

@@ -179,6 +179,10 @@ Uses the `LawfulEqCmp` instance to cast the retrieved value to the correct type.
def get? [LawfulEqCmp cmp] (t : DTreeMap α β cmp) (a : α) : Option (β a) :=
letI : Ord α := cmp; t.inner.get? a
@[inline, inherit_doc get?, deprecated get? (since := "2025-02-12")]
def find? [LawfulEqCmp cmp] (t : DTreeMap α β cmp) (a : α) : Option (β a) :=
t.get? a
/--
Given a proof that a mapping for the given key is present, retrieves the mapping for the given key.
@@ -197,6 +201,10 @@ Uses the `LawfulEqCmp` instance to cast the retrieved value to the correct type.
def get! [LawfulEqCmp cmp] (t : DTreeMap α β cmp) (a : α) [Inhabited (β a)] : β a :=
letI : Ord α := cmp; t.inner.get! a
@[inline, inherit_doc get!, deprecated get! (since := "2025-02-12")]
def find! [LawfulEqCmp cmp] (t : DTreeMap α β cmp) (a : α) [Inhabited (β a)] : β a :=
t.get! a
/--
Tries to retrieve the mapping for the given key, returning `fallback` if no such mapping is present.
@@ -206,6 +214,10 @@ Uses the `LawfulEqCmp` instance to cast the retrieved value to the correct type.
def getD [LawfulEqCmp cmp] (t : DTreeMap α β cmp) (a : α) (fallback : β a) : β a :=
letI : Ord α := cmp; t.inner.getD a fallback
@[inline, inherit_doc getD, deprecated getD (since := "2025-02-12")]
def findD [LawfulEqCmp cmp] (t : DTreeMap α β cmp) (a : α) (fallback : β a) : β a :=
t.getD a fallback
namespace Const
open Internal (Impl)
@@ -218,6 +230,10 @@ Tries to retrieve the mapping for the given key, returning `none` if no such map
def get? (t : DTreeMap α β cmp) (a : α) : Option β :=
letI : Ord α := cmp; Impl.Const.get? a t.inner
@[inline, inherit_doc get?, deprecated get? (since := "2025-02-12")]
def find? (t : DTreeMap α β cmp) (a : α) : Option β :=
get? t a
/--
Given a proof that a mapping for the given key is present, retrieves the mapping for the given key.
-/
@@ -232,6 +248,10 @@ Tries to retrieve the mapping for the given key, panicking if no such mapping is
def get! (t : DTreeMap α β cmp) (a : α) [Inhabited β] : β :=
letI : Ord α := cmp; Impl.Const.get! a t.inner
@[inline, inherit_doc get!, deprecated get! (since := "2025-02-12")]
def find! (t : DTreeMap α β cmp) (a : α) [Inhabited β] : β :=
get! t a
/--
Tries to retrieve the mapping for the given key, returning `fallback` if no such mapping is present.
-/
@@ -239,6 +259,10 @@ Tries to retrieve the mapping for the given key, returning `fallback` if no such
def getD (t : DTreeMap α β cmp) (a : α) (fallback : β) : β :=
letI : Ord α := cmp; Impl.Const.getD a t.inner fallback
@[inline, inherit_doc getD, deprecated getD (since := "2025-02-12")]
def findD (t : DTreeMap α β cmp) (a : α) (fallback : β) : β :=
getD t a fallback
end Const
variable {δ : Type w} {m : Type w Type w₂} [Monad m]
@@ -248,20 +272,38 @@ variable {δ : Type w} {m : Type w → Type w₂} [Monad m]
def filter (f : (a : α) β a Bool) (t : DTreeMap α β cmp) : DTreeMap α β cmp :=
letI : Ord α := cmp; t.inner.filter f t.wf.balanced |>.impl, t.wf.filter
/--
Folds the given monadic function over the mappings in the map in ascending order.
-/
/-- Folds the given monadic function over the mappings in the map in ascending order. -/
@[inline]
def foldlM (f : δ (a : α) β a m δ) (init : δ) (t : DTreeMap α β cmp) : m δ :=
t.inner.foldlM f init
/--
Folds the given function over the mappings in the map in ascending order.
-/
@[inline, inherit_doc foldlM, deprecated foldlM (since := "2025-02-12")]
def foldM (f : δ (a : α) β a m δ) (init : δ) (t : DTreeMap α β cmp) : m δ :=
t.foldlM f init
/-- Folds the given function over the mappings in the map in ascending order. -/
@[inline]
def foldl (f : δ (a : α) β a δ) (init : δ) (t : DTreeMap α β cmp) : δ :=
t.inner.foldl f init
@[inline, inherit_doc foldl, deprecated foldl (since := "2025-02-12")]
def fold (f : δ (a : α) β a δ) (init : δ) (t : DTreeMap α β cmp) : δ :=
t.foldl f init
/-- Folds the given monadic function over the mappings in the map in descending order. -/
@[inline]
def foldrM (f : δ (a : α) β a m δ) (init : δ) (t : DTreeMap α β cmp) : m δ :=
t.inner.foldrM f init
/-- Folds the given function over the mappings in the map in descending order. -/
@[inline]
def foldr (f : δ (a : α) β a δ) (init : δ) (t : DTreeMap α β cmp) : δ :=
t.inner.foldr f init
@[inline, inherit_doc foldr, deprecated foldr (since := "2025-02-12")]
def revFold (f : δ (a : α) β a δ) (init : δ) (t : DTreeMap α β cmp) : δ :=
foldr f init t
/-- Carries out a monadic action on each mapping in the tree map in ascending order. -/
@[inline]
def forM (f : (a : α) β a m PUnit) (t : DTreeMap α β cmp) : m PUnit :=
@@ -307,11 +349,31 @@ def keysArray (t : DTreeMap α β cmp) : Array α :=
def toList (t : DTreeMap α β cmp) : List ((a : α) × β a) :=
t.inner.toList
/-- Transforms a list of mappings into a tree map. -/
@[inline]
def ofList (l : List ((a : α) × β a)) (cmp : α α Ordering := by exact compare) :
DTreeMap α β cmp :=
letI : Ord α := cmp; Impl.ofList l, Impl.WF.empty.insertMany
@[inline, inherit_doc ofList, deprecated ofList (since := "2025-02-12")]
def fromList (l : List ((a : α) × β a)) (cmp : α α Ordering) : DTreeMap α β cmp :=
ofList l cmp
/-- Transforms the tree map into a list of mappings in ascending order. -/
@[inline]
def toArray (t : DTreeMap α β cmp) : Array ((a : α) × β a) :=
t.inner.toArray
/-- Transforms an array of mappings into a tree map. -/
@[inline]
def ofArray (a : Array ((a : α) × β a)) (cmp : α α Ordering := by exact compare) :
DTreeMap α β cmp :=
letI : Ord α := cmp; Impl.ofArray a, Impl.WF.empty.insertMany
@[inline, inherit_doc ofArray, deprecated ofArray (since := "2025-02-12")]
def fromArray (a : Array ((a : α) × β a)) (cmp : α α Ordering) : DTreeMap α β cmp :=
ofArray a cmp
/--
Returns a map that contains all mappings of `t₁` and `t₂`. In case that both maps contain the
same key `k` with respect to `cmp`, the provided function is used to determine the new value from
@@ -332,6 +394,11 @@ def mergeWith [LawfulEqCmp cmp] (mergeFn : (a : α) → β a → β a → β a)
DTreeMap α β cmp :=
letI : Ord α := cmp; t₁.inner.mergeWith mergeFn t₂.inner t₁.wf.balanced |>.impl, t₁.wf.mergeWith
@[inline, inherit_doc mergeWith, deprecated mergeWith (since := "2025-02-12")]
def mergeBy [LawfulEqCmp cmp] (mergeFn : (a : α) β a β a β a) (t₁ t₂ : DTreeMap α β cmp) :
DTreeMap α β cmp :=
mergeWith mergeFn t₁ t₂
namespace Const
variable {β : Type v}
@@ -340,17 +407,55 @@ variable {β : Type v}
def toList (t : DTreeMap α β cmp) : List (α × β) :=
Impl.Const.toList t.inner
@[inline, inherit_doc DTreeMap.ofList]
def ofList (l : List (α × β)) (cmp : α α Ordering := by exact compare) : DTreeMap α β cmp :=
letI : Ord α := cmp
Impl.Const.ofList l, Impl.WF.empty.constInsertMany
@[inline, inherit_doc DTreeMap.toArray]
def toArray (t : DTreeMap α β cmp) : Array (α × β) :=
t.foldl (init := ) fun acc k v => acc.push k,v
@[inline, inherit_doc DTreeMap.ofList]
def ofArray (a : Array (α × β)) (cmp : α α Ordering := by exact compare) : DTreeMap α β cmp :=
letI : Ord α := cmp
Impl.Const.ofArray a, Impl.WF.empty.constInsertMany
/-- Transforms a list of keys into a tree map. -/
@[inline]
def unitOfList (l : List α) (cmp : α α Ordering := by exact compare) : DTreeMap α Unit cmp :=
letI : Ord α := cmp
Impl.Const.unitOfList l, Impl.WF.empty.constInsertManyIfNewUnit
/-- Transforms an array of keys into a tree map. -/
@[inline]
def unitOfArray (a : Array α) (cmp : α α Ordering := by exact compare) : DTreeMap α Unit cmp :=
letI : Ord α := cmp
Impl.Const.unitOfArray a, Impl.WF.empty.constInsertManyIfNewUnit
@[inline, inherit_doc DTreeMap.mergeWith]
def mergeWith (mergeFn : α β β β) (t₁ t₂ : DTreeMap α β cmp) : DTreeMap α β cmp :=
letI : Ord α := cmp;
Impl.Const.mergeWith mergeFn t₁.inner t₂.inner t₁.wf.balanced |>.impl, t₁.wf.constMergeBy
Impl.Const.mergeWith mergeFn t₁.inner t₂.inner t₁.wf.balanced |>.impl, t₁.wf.constMergeWith
@[inline, inherit_doc mergeWith, deprecated mergeWith (since := "2025-02-12")]
def mergeBy (mergeFn : α β β β) (t₁ t₂ : DTreeMap α β cmp) : DTreeMap α β cmp :=
mergeWith mergeFn t₁ t₂
end Const
/--
Inserts multiple mappings into the tree map by iterating over the given collection and calling
`insert`. If the same key appears multiple times, the last occurrence takes precedence.
Note: this precedence behavior is true for `TreeMap`, `DTreeMap`, `TreeMap.Raw` and `DTreeMap.Raw`.
The `insertMany` function on `TreeSet` and `TreeSet.Raw` behaves differently: it will prefer the first
appearance.
-/
@[inline]
def insertMany {ρ} [ForIn Id ρ ((a : α) × β a)] (t : DTreeMap α β cmp) (l : ρ) : DTreeMap α β cmp :=
letI : Ord α := cmp; t.inner.insertMany l t.wf.balanced, t.wf.insertMany
/--
Erases multiple mappings from the tree map by iterating over the given collection and calling
`erase`.
@@ -359,6 +464,25 @@ Erases multiple mappings from the tree map by iterating over the given collectio
def eraseMany {ρ} [ForIn Id ρ α] (t : DTreeMap α β cmp) (l : ρ) : DTreeMap α β cmp :=
letI : Ord α := cmp; t.inner.eraseMany l t.wf.balanced, t.wf.eraseMany
namespace Const
variable {β : Type v}
@[inline, inherit_doc DTreeMap.insertMany]
def insertMany {ρ} [ForIn Id ρ (α × β)] (t : DTreeMap α β cmp) (l : ρ) : DTreeMap α β cmp :=
letI : Ord α := cmp; Impl.Const.insertMany t.inner l t.wf.balanced, t.wf.constInsertMany
/--
Inserts multiple elements into the tree map by iterating over the given collection and calling
`insertIfNew`. If the same key appears multiple times, the first occurrence takes precedence.
-/
@[inline]
def insertManyIfNewUnit {ρ} [ForIn Id ρ α] (t : DTreeMap α Unit cmp) (l : ρ) : DTreeMap α Unit cmp :=
letI : Ord α := cmp;
Impl.Const.insertManyIfNewUnit t.inner l t.wf.balanced, t.wf.constInsertManyIfNewUnit
end Const
instance [Repr α] [(a : α) Repr (β a)] : Repr (DTreeMap α β cmp) where
reprPrec m prec := Repr.addAppParen ("DTreeMap.ofList " ++ repr m.toList) prec

View File

@@ -456,6 +456,104 @@ def eraseMany! [Ord α] {ρ : Type w} [ForIn Id ρ α] (t : Impl α β) (l : ρ)
r := r.val.erase! a, fun h₀ h₁ => h₁ _ _ (r.2 h₀ h₁)
return r
/-- A tree map obtained by inserting elements into `t`, bundled with an inductive principle. -/
abbrev IteratedInsertionInto [Ord α] (t) :=
{ t' // {P : Impl α β Prop}, P t ( t'' a b h, P t'' P (t''.insert a b h).impl) P t' }
/-- Iterate over `l` and insert all of its elements into `t`. -/
@[inline]
def insertMany [Ord α] {ρ : Type w} [ForIn Id ρ ((a : α) × β a)] (t : Impl α β) (l : ρ) (h : t.Balanced) :
IteratedInsertionInto t := Id.run do
let mut r := t, fun h _ => h
for a, b in l do
let hr := r.2 h (fun t'' a b h _ => (t''.insert a b h).balanced_impl)
r := r.val.insert a b hr |>.impl, fun h₀ h₁ => h₁ _ _ _ _ (r.2 h₀ h₁)
return r
/-- A tree map obtained by inserting elements into `t`, bundled with an inductive principle. -/
abbrev IteratedSlowInsertionInto [Ord α] (t) :=
{ t' // {P : Impl α β Prop}, P t ( t'' a b, P t'' P (t''.insert! a b)) P t' }
/--
Slower version of `insertMany` which can be used in absence of balance information but still
assumes the preconditions of `insertMany`, otherwise might panic.
-/
@[inline]
def insertMany! [Ord α] {ρ : Type w} [ForIn Id ρ ((a : α) × β a)] (t : Impl α β) (l : ρ) :
IteratedSlowInsertionInto t := Id.run do
let mut r := t, fun h _ => h
for a, b in l do
r := r.val.insert! a b, fun h₀ h₁ => h₁ _ _ _ (r.2 h₀ h₁)
return r
namespace Const
variable {β : Type v}
/-- A tree map obtained by inserting elements into `t`, bundled with an inductive principle. -/
abbrev IteratedInsertionInto [Ord α] (t) :=
{ t' // {P : Impl α (fun _ => β) Prop}, P t ( t'' a b h, P t'' P (t''.insert a b h).impl) P t' }
/-- Iterate over `l` and insert all of its elements into `t`. -/
@[inline]
def insertMany [Ord α] {ρ : Type w} [ForIn Id ρ (α × β)] (t : Impl α (fun _ => β)) (l : ρ) (h : t.Balanced) :
IteratedInsertionInto t := Id.run do
let mut r := t, fun h _ => h
for a, b in l do
let hr := r.2 h (fun t'' a b h _ => (t''.insert a b h).balanced_impl)
r := r.val.insert a b hr |>.impl, fun h₀ h₁ => h₁ _ _ _ _ (r.2 h₀ h₁)
return r
/-- A tree map obtained by inserting elements into `t`, bundled with an inductive principle. -/
abbrev IteratedSlowInsertionInto [Ord α] (t) :=
{ t' // {P : Impl α (fun _ => β) Prop}, P t ( t'' a b, P t'' P (t''.insert! a b)) P t' }
/--
Slower version of `insertMany` which can be used in absence of balance information but still
assumes the preconditions of `insertMany`, otherwise might panic.
-/
@[inline]
def insertMany! [Ord α] {ρ : Type w} [ForIn Id ρ (α × β)] (t : Impl α (fun _ => β)) (l : ρ) :
IteratedSlowInsertionInto t := Id.run do
let mut r := t, fun h _ => h
for a, b in l do
r := r.val.insert! a b, fun h₀ h₁ => h₁ _ _ _ (r.2 h₀ h₁)
return r
/-- A tree map obtained by inserting elements into `t`, bundled with an inductive principle. -/
abbrev IteratedUnitInsertionInto [Ord α] (t) :=
{ t' // {P : Impl α (fun _ => Unit) Prop}, P t
( t'' a h, P t'' P (t''.insertIfNew a () h).impl) P t' }
/-- Iterate over `l` and insert all of its elements into `t`. -/
@[inline]
def insertManyIfNewUnit [Ord α] {ρ : Type w} [ForIn Id ρ α] (t : Impl α (fun _ => Unit)) (l : ρ) (h : t.Balanced) :
IteratedUnitInsertionInto t := Id.run do
let mut r := t, fun h _ => h
for a in l do
let hr := r.2 h (fun t'' a h _ => (t''.insertIfNew a () h).balanced_impl)
r := r.val.insertIfNew a () hr |>.impl, fun h₀ h₁ => h₁ _ _ _ (r.2 h₀ h₁)
return r
/-- A tree map obtained by inserting elements into `t`, bundled with an inductive principle. -/
abbrev IteratedSlowUnitInsertionInto [Ord α] (t) :=
{ t' // {P : Impl α (fun _ => Unit) Prop}, P t
( t'' a, P t'' P (t''.insertIfNew! a ())) P t' }
/--
Slower version of `insertManyIfNewUnit` which can be used in absence of balance information but still
assumes the preconditions of `insertManyIfNewUnit`, otherwise might panic.
-/
@[inline]
def insertManyIfNewUnit! [Ord α] {ρ : Type w} [ForIn Id ρ α] (t : Impl α (fun _ => Unit)) (l : ρ) :
IteratedSlowUnitInsertionInto t := Id.run do
let mut r := t, fun h _ => h
for a in l do
r := r.val.insertIfNew! a (), fun h₀ h₁ => h₁ _ _ (r.2 h₀ h₁)
return r
end Const
variable (α β) in
/-- A balanced tree. -/
structure BalancedTree where
@@ -470,6 +568,38 @@ attribute [Std.Internal.tree_tac] BalancedTree.balanced_impl
def SizedBalancedTree.toBalancedTree {lb ub} (t : SizedBalancedTree α β lb ub) : BalancedTree α β :=
t.impl, t.balanced_impl
/-- Transforms an array of mappings into a tree map. -/
@[inline]
def ofArray [Ord α] (a : Array ((a : α) × β a)) : Impl α β :=
empty.insertMany a balanced_empty |>.val
/-- Transforms a list of mappings into a tree map. -/
@[inline]
def ofList [Ord α] (l : List ((a : α) × β a)) : Impl α β :=
empty.insertMany l balanced_empty |>.val
namespace Const
variable {β : Type v}
/-- Transforms a list of mappings into a tree map. -/
@[inline] def ofArray [Ord α] (a : Array (α × β)) : Impl α (fun _ => β) :=
insertMany empty a balanced_empty |>.val
/-- Transforms an array of mappings into a tree map. -/
@[inline] def ofList [Ord α] (l : List (α × β)) : Impl α (fun _ => β) :=
insertMany empty l balanced_empty |>.val
/-- Transforms a list of mappings into a tree map. -/
@[inline] def unitOfArray [Ord α] (a : Array α) : Impl α (fun _ => Unit) :=
insertManyIfNewUnit empty a balanced_empty |>.val
/-- Transforms an array of mappings into a tree map. -/
@[inline] def unitOfList [Ord α] (l : List α) : Impl α (fun _ => Unit) :=
insertManyIfNewUnit empty l balanced_empty |>.val
end Const
/--
Returns the tree consisting of the mappings `(k, (f k v).get)` where `(k, v)` was a mapping in
the original tree and `(f k v).isSome`.

View File

@@ -59,7 +59,7 @@ inductive WF [Ord α] : {β : α → Type v} → Impl α β → Prop where
/-- `mergeWith` preserves well-formedness. Later shown to be subsumed by `.wf`. -/
| mergeWith {t₁ t₂ f h} [LawfulEqOrd α] : WF t₁ WF (t₁.mergeWith f t₂ h).impl
/-- `mergeWith` preserves well-formedness. Later shown to be subsumed by `.wf`. -/
| constMergeBy {t₁ t₂ f h} : WF t₁ WF (Impl.Const.mergeWith f t₁ t₂ h).impl
| constMergeWith {t₁ t₂ f h} : WF t₁ WF (Impl.Const.mergeWith f t₁ t₂ h).impl
/--
A well-formed tree is balanced. This is needed here already because we need to know that the
@@ -74,6 +74,18 @@ theorem WF.eraseMany [Ord α] {ρ} [ForIn Id ρ α] {t : Impl α β} {l : ρ} {h
WF (t.eraseMany l h).val :=
(t.eraseMany l h).2 hwf fun _ _ _ hwf' => hwf'.erase
theorem WF.insertMany [Ord α] {ρ} [ForIn Id ρ ((a : α) × β a)] {t : Impl α β} {l : ρ} {h} (hwf : WF t) :
WF (t.insertMany l h).val :=
(t.insertMany l h).2 hwf fun _ _ _ _ hwf' => hwf'.insert
theorem WF.constInsertMany [Ord α] {β : Type v} {ρ} [ForIn Id ρ (α × β)] {t : Impl α (fun _ => β)}
{l : ρ} {h} (hwf : WF t) : WF (Impl.Const.insertMany t l h).val :=
(Impl.Const.insertMany t l h).2 hwf fun _ _ _ _ hwf' => hwf'.insert
theorem WF.constInsertManyIfNewUnit [Ord α] {ρ} [ForIn Id ρ α] {t : Impl α (fun _ => Unit)} {l : ρ}
{h} (hwf : WF t) : WF (Impl.Const.insertManyIfNewUnit t l h).val :=
(Impl.Const.insertManyIfNewUnit t l h).2 hwf fun _ _ _ hwf' => hwf'.insertIfNew
end Impl
end Std.DTreeMap.Internal

View File

@@ -67,6 +67,7 @@ structure Raw (α : Type u) (β : α → Type v) (_cmp : αα → Ordering
inner : Internal.Impl α β
namespace Raw
open Internal (Impl)
/--
Well-formedness predicate for tree maps. Users of `DTreeMap` will not need to interact with
@@ -150,6 +151,10 @@ def erase (t : Raw α β cmp) (a : α) : Raw α β cmp :=
def get? [LawfulEqCmp cmp] (t : Raw α β cmp) (a : α) : Option (β a) :=
letI : Ord α := cmp; t.inner.get? a
@[inline, inherit_doc get?, deprecated get? (since := "2025-02-12")]
def find? [LawfulEqCmp cmp] (t : Raw α β cmp) (a : α) : Option (β a) :=
t.get? a
@[inline, inherit_doc DTreeMap.get]
def get [LawfulEqCmp cmp] (t : Raw α β cmp) (a : α) (h : a t) : β a :=
letI : Ord α := cmp; t.inner.get a h
@@ -158,18 +163,30 @@ def get [LawfulEqCmp cmp] (t : Raw α β cmp) (a : α) (h : a ∈ t) : β a :=
def get! [LawfulEqCmp cmp] (t : Raw α β cmp) (a : α) [Inhabited (β a)] : β a :=
letI : Ord α := cmp; t.inner.get! a
@[inline, inherit_doc get!, deprecated get! (since := "2025-02-12")]
def find! [LawfulEqCmp cmp] (t : Raw α β cmp) (a : α) [Inhabited (β a)] : β a :=
t.get! a
@[inline, inherit_doc DTreeMap.getD]
def getD [LawfulEqCmp cmp] (t : Raw α β cmp) (a : α) (fallback : β a) : β a :=
letI : Ord α := cmp; t.inner.getD a fallback
@[inline, inherit_doc getD, deprecated getD (since := "2025-02-12")]
def findD [LawfulEqCmp cmp] (t : Raw α β cmp) (a : α) (fallback : β a) : β a :=
t.getD a fallback
namespace Const
open Internal (Impl)
variable {β : Type v}
@[inline, inherit_doc DTreeMap.get?] def get? (t : Raw α β cmp) (a : α) : Option β :=
@[inline, inherit_doc DTreeMap.get?]
def get? (t : Raw α β cmp) (a : α) : Option β :=
letI : Ord α := cmp; Impl.Const.get? a t.inner
@[inline, inherit_doc get?, deprecated get? (since := "2025-02-12")]
def find? (t : Raw α β cmp) (a : α) : Option β :=
get? t a
@[inline, inherit_doc DTreeMap.get]
def get (t : Raw α β cmp) (a : α) (h : a t) : β :=
letI : Ord α := cmp; Impl.Const.get a t.inner h
@@ -178,10 +195,18 @@ def get (t : Raw α β cmp) (a : α) (h : a ∈ t) : β :=
def get! (t : Raw α β cmp) (a : α) [Inhabited β] : β :=
letI : Ord α := cmp; Impl.Const.get! a t.inner
@[inline, inherit_doc get!, deprecated get! (since := "2025-02-12")]
def find! (t : Raw α β cmp) (a : α) [Inhabited β] : β :=
get! t a
@[inline, inherit_doc DTreeMap.getD]
def getD (t : Raw α β cmp) (a : α) (fallback : β) : β :=
letI : Ord α := cmp; Impl.Const.getD a t.inner fallback
@[inline, inherit_doc getD, deprecated getD (since := "2025-02-12")]
def findD (t : Raw α β cmp) (a : α) (fallback : β) : β :=
getD t a fallback
end Const
variable {δ : Type w} {m : Type w Type w₂} [Monad m]
@@ -194,10 +219,30 @@ def filter (f : (a : α) → β a → Bool) (t : Raw α β cmp) : Raw α β cmp
def foldlM (f : δ (a : α) β a m δ) (init : δ) (t : Raw α β cmp) : m δ :=
t.inner.foldlM f init
@[inline, inherit_doc foldlM, deprecated foldlM (since := "2025-02-12")]
def foldM (f : δ (a : α) β a m δ) (init : δ) (t : Raw α β cmp) : m δ :=
t.foldlM f init
@[inline, inherit_doc DTreeMap.foldl]
def foldl (f : δ (a : α) β a δ) (init : δ) (t : Raw α β cmp) : δ :=
t.inner.foldl f init
@[inline, inherit_doc foldl, deprecated foldl (since := "2025-02-12")]
def fold (f : δ (a : α) β a δ) (init : δ) (t : Raw α β cmp) : δ :=
t.foldl f init
@[inline, inherit_doc DTreeMap.foldrM]
def foldrM (f : δ (a : α) β a m δ) (init : δ) (t : Raw α β cmp) : m δ :=
t.inner.foldrM f init
@[inline, inherit_doc DTreeMap.foldr]
def foldr (f : δ (a : α) β a δ) (init : δ) (t : Raw α β cmp) : δ :=
t.inner.foldr f init
@[inline, inherit_doc foldr, deprecated foldr (since := "2025-02-12")]
def revFold (f : δ (a : α) β a δ) (init : δ) (t : Raw α β cmp) : δ :=
foldr f init t
@[inline, inherit_doc DTreeMap.forM]
def forM (f : (a : α) β a m PUnit) (t : Raw α β cmp) : m PUnit :=
t.inner.forM f
@@ -236,37 +281,100 @@ def keysArray (t : Raw α β cmp) : Array α :=
def toList (t : Raw α β cmp) : List ((a : α) × β a) :=
t.inner.toList
/-- Transforms a list of mappings into a tree map. -/
@[inline]
def ofList (l : List ((a : α) × β a)) (cmp : α α Ordering := by exact compare) : Raw α β cmp :=
letI : Ord α := cmp
Impl.ofList l
@[inline, inherit_doc ofList, deprecated ofList (since := "2025-02-12")]
def fromList (l : List ((a : α) × β a)) (cmp : α α Ordering) : Raw α β cmp :=
ofList l cmp
@[inline, inherit_doc DTreeMap.toArray]
def toArray (t : Raw α β cmp) : Array ((a : α) × β a) :=
t.inner.toArray
/-- Transforms an array of mappings into a tree map. -/
@[inline]
def ofArray (a : Array ((a : α) × β a)) (cmp : α α Ordering := by exact compare) : Raw α β cmp :=
letI : Ord α := cmp
Impl.ofArray a
@[inline, inherit_doc ofArray, deprecated ofArray (since := "2025-02-12")]
def fromArray (a : Array ((a : α) × β a)) (cmp : α α Ordering) : Raw α β cmp :=
ofArray a cmp
@[inline, inherit_doc DTreeMap.mergeWith]
def mergeWith [LawfulEqCmp cmp] (mergeFn : (a : α) β a β a β a) (t₁ t₂ : Raw α β cmp) : Raw α β cmp :=
letI : Ord α := cmp; t₁.inner.mergeWith! mergeFn t₂.inner
@[inline, inherit_doc mergeWith, deprecated mergeWith (since := "2025-02-12")]
def mergeBy [LawfulEqCmp cmp] (mergeFn : (a : α) β a β a β a) (t₁ t₂ : Raw α β cmp) :
Raw α β cmp :=
mergeWith mergeFn t₁ t₂
namespace Const
open Internal (Impl)
variable {β : Type v}
@[inline, inherit_doc Raw.toList]
@[inline, inherit_doc DTreeMap.Const.toList]
def toList (t : Raw α β cmp) : List (α × β) :=
Impl.Const.toList t.inner
@[inline, inherit_doc Raw.toArray]
@[inline, inherit_doc DTreeMap.Const.ofList]
def ofList (l : List (α × β)) (cmp : α α Ordering := by exact compare) : Raw α β cmp :=
letI : Ord α := cmp; Impl.Const.ofList l
@[inline, inherit_doc DTreeMap.Const.unitOfList]
def unitOfList (l : List α) (cmp : α α Ordering := by exact compare) : Raw α Unit cmp :=
letI : Ord α := cmp; Impl.Const.unitOfList l
@[inline, inherit_doc DTreeMap.Const.toArray]
def toArray (t : Raw α β cmp) : Array (α × β) :=
Impl.Const.toArray t.inner
@[inline, inherit_doc Raw.mergeWith]
@[inline, inherit_doc DTreeMap.Const.ofArray]
def ofArray (a : Array (α × β)) (cmp : α α Ordering := by exact compare) : Raw α β cmp :=
letI : Ord α := cmp; Impl.Const.ofArray a
@[inline, inherit_doc DTreeMap.Const.ofArray]
def unitOfArray (a : Array α) (cmp : α α Ordering := by exact compare) : Raw α Unit cmp :=
letI : Ord α := cmp; Impl.Const.unitOfArray a
@[inline, inherit_doc DTreeMap.Const.mergeWith]
def mergeWith (mergeFn : α β β β) (t₁ t₂ : Raw α β cmp) : Raw α β cmp :=
letI : Ord α := cmp; Impl.Const.mergeWith! mergeFn t₁.inner t₂.inner
@[inline, inherit_doc mergeWith, deprecated mergeWith (since := "2025-02-12")]
def mergeBy (mergeFn : α β β β) (t₁ t₂ : Raw α β cmp) : Raw α β cmp :=
mergeWith mergeFn t₁ t₂
end Const
@[inline, inherit_doc DTreeMap.insertMany]
def insertMany {ρ} [ForIn Id ρ ((a : α) × β a)] (t : Raw α β cmp) (l : ρ) : Raw α β cmp :=
letI : Ord α := cmp; t.inner.insertMany! l
@[inline, inherit_doc DTreeMap.eraseMany]
def eraseMany {ρ} [ForIn Id ρ α] (t : Raw α β cmp) (l : ρ) : Raw α β cmp :=
letI : Ord α := cmp; t.inner.eraseMany! l
namespace Const
variable {β : Type v}
@[inline, inherit_doc DTreeMap.Const.insertMany]
def insertMany {ρ} [ForIn Id ρ (α × β)] (t : Raw α β cmp) (l : ρ) : Raw α β cmp :=
letI : Ord α := cmp; Impl.Const.insertMany! t.inner l
@[inline, inherit_doc DTreeMap.Const.insertManyIfNewUnit]
def insertManyIfNewUnit {ρ} [ForIn Id ρ α] (t : Raw α Unit cmp) (l : ρ) : Raw α Unit cmp :=
letI : Ord α := cmp; Impl.Const.insertManyIfNewUnit! t.inner l
end Const
instance [Repr α] [(a : α) Repr (β a)] : Repr (Raw α β cmp) where
reprPrec m prec := Repr.addAppParen ("DTreeMap.Raw.ofList " ++ repr m.toList) prec

View File

@@ -129,6 +129,10 @@ def erase (t : TreeMap α β cmp) (a : α) : TreeMap α β cmp :=
def get? (t : TreeMap α β cmp) (a : α) : Option β :=
DTreeMap.Const.get? t.inner a
@[inline, inherit_doc get?, deprecated get? (since := "2025-02-12")]
def find? (t : TreeMap α β cmp) (a : α) : Option β :=
get? t a
@[inline, inherit_doc DTreeMap.get]
def get (t : TreeMap α β cmp) (a : α) (h : a t) : β :=
DTreeMap.Const.get t.inner a h
@@ -137,10 +141,18 @@ def get (t : TreeMap α β cmp) (a : α) (h : a ∈ t) : β :=
def get! (t : TreeMap α β cmp) (a : α) [Inhabited β] : β :=
DTreeMap.Const.get! t.inner a
@[inline, inherit_doc get!, deprecated get! (since := "2025-02-12")]
def find! (t : TreeMap α β cmp) (a : α) [Inhabited β] : β :=
get! t a
@[inline, inherit_doc DTreeMap.getD]
def getD (t : TreeMap α β cmp) (a : α) (fallback : β) : β :=
DTreeMap.Const.getD t.inner a fallback
@[inline, inherit_doc getD, deprecated getD (since := "2025-02-12")]
def findD (t : TreeMap α β cmp) (a : α) (fallback : β) : β :=
getD t a fallback
instance : GetElem? (TreeMap α β cmp) α β (fun m a => a m) where
getElem m a h := m.get a h
getElem? m a := m.get? a
@@ -156,10 +168,30 @@ def filter (f : α → β → Bool) (m : TreeMap α β cmp) : TreeMap α β cmp
def foldlM (f : δ (a : α) β m δ) (init : δ) (t : TreeMap α β cmp) : m δ :=
t.inner.foldlM f init
@[inline, inherit_doc foldlM, deprecated foldlM (since := "2025-02-12")]
def foldM (f : δ (a : α) β m δ) (init : δ) (t : TreeMap α β cmp) : m δ :=
t.foldlM f init
@[inline, inherit_doc DTreeMap.foldl]
def foldl (f : δ (a : α) β δ) (init : δ) (t : TreeMap α β cmp) : δ :=
t.inner.foldl f init
@[inline, inherit_doc foldl, deprecated foldl (since := "2025-02-12")]
def fold (f : δ (a : α) β δ) (init : δ) (t : TreeMap α β cmp) : δ :=
t.foldl f init
@[inline, inherit_doc DTreeMap.foldrM]
def foldrM (f : δ (a : α) β m δ) (init : δ) (t : TreeMap α β cmp) : m δ :=
t.inner.foldrM f init
@[inline, inherit_doc DTreeMap.foldr]
def foldr (f : δ (a : α) β δ) (init : δ) (t : TreeMap α β cmp) : δ :=
t.inner.foldr f init
@[inline, inherit_doc foldr, deprecated foldr (since := "2025-02-12")]
def revFold (f : δ (a : α) β δ) (init : δ) (t : TreeMap α β cmp) : δ :=
foldr f init t
@[inline, inherit_doc DTreeMap.forM]
def forM (f : α β m PUnit) (t : TreeMap α β cmp) : m PUnit :=
t.inner.forM f
@@ -194,14 +226,50 @@ def keysArray (t : TreeMap α β cmp) : Array α :=
def toList (t : TreeMap α β cmp) : List (α × β) :=
DTreeMap.Const.toList t.inner
@[inline, inherit_doc DTreeMap.Const.ofList]
def ofList (l : List (α × β)) (cmp : α α Ordering := by exact compare) : TreeMap α β cmp :=
DTreeMap.Const.ofList l cmp
@[inline, inherit_doc ofList, deprecated ofList (since := "2025-02-12")]
def fromList (l : List (α × β)) (cmp : α α Ordering) : TreeMap α β cmp :=
ofList l cmp
@[inline, inherit_doc DTreeMap.Const.unitOfList]
def unitOfList (l : List α) (cmp : α α Ordering := by exact compare) : TreeMap α Unit cmp :=
DTreeMap.Const.unitOfList l cmp
@[inline, inherit_doc DTreeMap.Const.toArray]
def toArray (t : TreeMap α β cmp) : Array (α × β) :=
DTreeMap.Const.toArray t.inner
@[inline, inherit_doc DTreeMap.mergeWith]
@[inline, inherit_doc DTreeMap.Const.ofArray]
def ofArray (a : Array (α × β)) (cmp : α α Ordering := by exact compare) : TreeMap α β cmp :=
DTreeMap.Const.ofArray a cmp
@[inline, inherit_doc ofArray, deprecated ofArray (since := "2025-02-12")]
def fromArray (a : Array (α × β)) (cmp : α α Ordering) : TreeMap α β cmp :=
ofArray a cmp
@[inline, inherit_doc DTreeMap.Const.unitOfArray]
def unitOfArray (a : Array α) (cmp : α α Ordering := by exact compare) : TreeMap α Unit cmp :=
DTreeMap.Const.unitOfArray a cmp
@[inline, inherit_doc DTreeMap.Const.mergeWith]
def mergeWith (mergeFn : α β β β) (t₁ t₂ : TreeMap α β cmp) : TreeMap α β cmp :=
DTreeMap.Const.mergeWith mergeFn t₁.inner t₂.inner
@[inline, inherit_doc mergeWith, deprecated mergeWith (since := "2025-02-12")]
def mergeBy (mergeFn : α β β β) (t₁ t₂ : TreeMap α β cmp) : TreeMap α β cmp :=
mergeWith mergeFn t₁ t₂
@[inline, inherit_doc DTreeMap.Const.insertMany]
def insertMany {ρ} [ForIn Id ρ (α × β)] (t : TreeMap α β cmp) (l : ρ) : TreeMap α β cmp :=
DTreeMap.Const.insertMany t.inner l
@[inline, inherit_doc DTreeMap.Const.insertManyIfNewUnit]
def insertManyIfNewUnit {ρ} [ForIn Id ρ α] (t : TreeMap α Unit cmp) (l : ρ) : TreeMap α Unit cmp :=
DTreeMap.Const.insertManyIfNewUnit t.inner l
@[inline, inherit_doc DTreeMap.eraseMany]
def eraseMany {ρ} [ForIn Id ρ α] (t : TreeMap α β cmp) (l : ρ) : TreeMap α β cmp :=
t.inner.eraseMany l

View File

@@ -147,6 +147,10 @@ def erase (t : Raw α β cmp) (a : α) : Raw α β cmp :=
def get? (t : Raw α β cmp) (a : α) : Option β :=
DTreeMap.Raw.Const.get? t.inner a
@[inline, inherit_doc get?, deprecated get? (since := "2025-02-12")]
def find? (t : Raw α β cmp) (a : α) : Option β :=
get? t a
@[inline, inherit_doc DTreeMap.Raw.Const.get]
def get (t : Raw α β cmp) (a : α) (h : a t) : β :=
DTreeMap.Raw.Const.get t.inner a h
@@ -155,10 +159,18 @@ def get (t : Raw α β cmp) (a : α) (h : a ∈ t) : β :=
def get! (t : Raw α β cmp) (a : α) [Inhabited β] : β :=
DTreeMap.Raw.Const.get! t.inner a
@[inline, inherit_doc get!, deprecated get! (since := "2025-02-12")]
def find! (t : Raw α β cmp) (a : α) [Inhabited β] : β :=
get! t a
@[inline, inherit_doc DTreeMap.Raw.Const.getD]
def getD (t : Raw α β cmp) (a : α) (fallback : β) : β :=
DTreeMap.Raw.Const.getD t.inner a fallback
@[inline, inherit_doc getD, deprecated getD (since := "2025-02-12")]
def findD (t : Raw α β cmp) (a : α) (fallback : β) : β :=
getD t a fallback
instance : GetElem? (Raw α β cmp) α β (fun m a => a m) where
getElem m a h := m.get a h
getElem? m a := m.get? a
@@ -174,10 +186,30 @@ def filter (f : α → β → Bool) (t : Raw α β cmp) : Raw α β cmp :=
def foldlM (f : δ (a : α) β m δ) (init : δ) (t : Raw α β cmp) : m δ :=
t.inner.foldlM f init
@[inline, inherit_doc foldlM, deprecated foldlM (since := "2025-02-12")]
def foldM (f : δ (a : α) β m δ) (init : δ) (t : Raw α β cmp) : m δ :=
t.foldlM f init
@[inline, inherit_doc DTreeMap.Raw.foldl]
def foldl (f : δ (a : α) β δ) (init : δ) (t : Raw α β cmp) : δ :=
t.inner.foldl f init
@[inline, inherit_doc foldl, deprecated foldl (since := "2025-02-12")]
def fold (f : δ (a : α) β δ) (init : δ) (t : Raw α β cmp) : δ :=
t.foldl f init
@[inline, inherit_doc DTreeMap.Raw.foldrM]
def foldrM (f : δ (a : α) β m δ) (init : δ) (t : Raw α β cmp) : m δ :=
t.inner.foldrM f init
@[inline, inherit_doc DTreeMap.Raw.foldr]
def foldr (f : δ (a : α) β δ) (init : δ) (t : Raw α β cmp) : δ :=
t.inner.foldr f init
@[inline, inherit_doc foldr, deprecated foldr (since := "2025-02-12")]
def revFold (f : δ (a : α) β δ) (init : δ) (t : Raw α β cmp) : δ :=
foldr f init t
@[inline, inherit_doc DTreeMap.Raw.forM]
def forM (f : α β m PUnit) (t : Raw α β cmp) : m PUnit :=
t.inner.forM f
@@ -208,18 +240,54 @@ def keys (t : Raw α β cmp) : List α :=
def keysArray (t : Raw α β cmp) : Array α :=
t.inner.keysArray
@[inline, inherit_doc DTreeMap.Raw.toList]
@[inline, inherit_doc DTreeMap.Raw.Const.toList]
def toList (t : Raw α β cmp) : List (α × β) :=
DTreeMap.Raw.Const.toList t.inner
@[inline, inherit_doc DTreeMap.Raw.toArray]
@[inline, inherit_doc DTreeMap.Raw.Const.ofList]
def ofList (l : List (α × β)) (cmp : α α Ordering := by exact compare) : Raw α β cmp :=
DTreeMap.Raw.Const.ofList l cmp
@[inline, inherit_doc ofList, deprecated ofList (since := "2025-02-12")]
def fromList (l : List (α × β)) (cmp : α α Ordering) : Raw α β cmp :=
ofList l cmp
@[inline, inherit_doc DTreeMap.Const.unitOfList]
def unitOfList (l : List α) (cmp : α α Ordering := by exact compare) : Raw α Unit cmp :=
DTreeMap.Raw.Const.unitOfList l cmp
@[inline, inherit_doc DTreeMap.Raw.Const.toArray]
def toArray (t : Raw α β cmp) : Array (α × β) :=
DTreeMap.Raw.Const.toArray t.inner
@[inline, inherit_doc DTreeMap.Raw.Const.ofArray]
def ofArray (a : Array (α × β)) (cmp : α α Ordering := by exact compare) : Raw α β cmp :=
DTreeMap.Raw.Const.ofArray a cmp
@[inline, inherit_doc ofArray, deprecated ofArray (since := "2025-02-12")]
def fromArray (a : Array (α × β)) (cmp : α α Ordering) : Raw α β cmp :=
ofArray a cmp
@[inline, inherit_doc DTreeMap.Const.unitOfArray]
def unitOfArray (a : Array α) (cmp : α α Ordering := by exact compare) : Raw α Unit cmp :=
DTreeMap.Raw.Const.unitOfArray a cmp
@[inline, inherit_doc DTreeMap.Raw.mergeWith]
def mergeWith (mergeFn : α β β β) (t₁ t₂ : Raw α β cmp) : Raw α β cmp :=
DTreeMap.Raw.Const.mergeWith mergeFn t₁.inner t₂.inner
@[inline, inherit_doc mergeWith, deprecated mergeWith (since := "2025-02-12")]
def mergeBy (mergeFn : α β β β) (t₁ t₂ : Raw α β cmp) : Raw α β cmp :=
mergeWith mergeFn t₁ t₂
@[inline, inherit_doc DTreeMap.Raw.Const.insertMany]
def insertMany {ρ} [ForIn Id ρ (α × β)] (t : Raw α β cmp) (l : ρ) : Raw α β cmp :=
DTreeMap.Raw.Const.insertMany t.inner l
@[inline, inherit_doc DTreeMap.Raw.Const.insertManyIfNewUnit]
def insertManyIfNewUnit {ρ} [ForIn Id ρ α] (t : Raw α Unit cmp) (l : ρ) : Raw α Unit cmp :=
DTreeMap.Raw.Const.insertManyIfNewUnit t.inner l
@[inline, inherit_doc DTreeMap.Raw.eraseMany]
def eraseMany {ρ} [ForIn Id ρ α] (t : Raw α β cmp) (l : ρ) : Raw α β cmp :=
t.inner.eraseMany l

View File

@@ -162,11 +162,36 @@ ascending order.
def foldlM {m δ} [Monad m] (f : δ (a : α) m δ) (init : δ) (t : TreeSet α cmp) : m δ :=
t.inner.foldlM (fun c a _ => f c a) init
@[inline, inherit_doc foldlM, deprecated foldlM (since := "2025-02-12")]
def foldM (f : δ (a : α) m δ) (init : δ) (t : TreeSet α cmp) : m δ :=
t.foldlM f init
/-- Folds the given function over the elements of the tree set in ascending order. -/
@[inline]
def foldl (f : δ (a : α) δ) (init : δ) (t : TreeSet α cmp) : δ :=
t.inner.foldl (fun c a _ => f c a) init
@[inline, inherit_doc foldl, deprecated foldl (since := "2025-02-12")]
def fold (f : δ (a : α) δ) (init : δ) (t : TreeSet α cmp) : δ :=
t.foldl f init
/--
Monadically computes a value by folding the given function over the elements in the tree set in
descending order.
-/
@[inline]
def foldrM {m δ} [Monad m] (f : δ (a : α) m δ) (init : δ) (t : TreeSet α cmp) : m δ :=
t.inner.foldrM (fun c a _ => f c a) init
/-- Folds the given function over the elements of the tree set in descending order. -/
@[inline]
def foldr (f : δ (a : α) δ) (init : δ) (t : TreeSet α cmp) : δ :=
t.inner.foldr (fun c a _ => f c a) init
@[inline, inherit_doc foldr, deprecated foldr (since := "2025-02-12")]
def revFold (f : δ (a : α) δ) (init : δ) (t : TreeSet α cmp) : δ :=
foldr f init t
/-- Carries out a monadic action on each element in the tree set in ascending order. -/
@[inline]
def forM (f : α m PUnit) (t : TreeSet α cmp) : m PUnit :=
@@ -201,11 +226,27 @@ def all (t : TreeSet α cmp) (p : α → Bool) : Bool :=
def toList (t : TreeSet α cmp) : List α :=
t.inner.inner.inner.foldr (fun l a _ => a :: l)
/-- Transforms a list into a tree set. -/
def ofList (l : List α) (cmp : α α Ordering := by exact compare) : TreeSet α cmp :=
TreeMap.unitOfList l cmp
@[inline, inherit_doc ofList, deprecated ofList (since := "2025-02-12")]
def fromList (l : List α) (cmp : α α Ordering) : TreeSet α cmp :=
ofList l cmp
/-- Transforms the tree set into an array of elements in ascending order. -/
@[inline]
def toArray (t : TreeSet α cmp) : Array α :=
t.foldl (init := ) fun acc k => acc.push k
/-- Transforms an array into a tree set. -/
def ofArray (a : Array α) (cmp : α α Ordering := by exact compare) : TreeSet α cmp :=
TreeMap.unitOfArray a cmp
@[inline, inherit_doc ofArray, deprecated ofArray (since := "2025-02-12")]
def fromArray (a : Array α) (cmp : α α Ordering) : TreeSet α cmp :=
ofArray a cmp
/--
Returns a set that contains all mappings of `t₁` and `t₂.
@@ -220,6 +261,19 @@ size of `t₂` as long as `t₁` is unshared.
def merge (t₁ t₂ : TreeSet α cmp) : TreeSet α cmp :=
TreeMap.mergeWith (fun _ _ _ => ()) t₁.inner t₂.inner
/--
Inserts multiple elements into the tree set by iterating over the given collection and calling
`insert`. If the same element (with respect to `cmp`) appears multiple times, the first occurrence
takes precedence.
Note: this precedence behavior is true for `TreeSet` and `TreeSet.Raw`. The `insertMany` function on
`TreeMap`, `DTreeMap`, `TreeMap.Raw` and `DTreeMap.Raw` behaves differently: it will prefer the last
appearance.
-/
@[inline]
def insertMany {ρ} [ForIn Id ρ α] (t : TreeSet α cmp) (l : ρ) : TreeSet α cmp :=
TreeMap.insertManyIfNewUnit t.inner l
/--
Erases multiple items from the tree set by iterating over the given collection and calling erase.
-/

View File

@@ -143,10 +143,30 @@ def filter (f : α → Bool) (t : Raw α cmp) : Raw α cmp :=
def foldlM (f : δ (a : α) m δ) (init : δ) (t : Raw α cmp) : m δ :=
t.inner.foldlM (fun c a _ => f c a) init
@[inline, inherit_doc foldlM, deprecated foldlM (since := "2025-02-12")]
def foldM (f : δ (a : α) m δ) (init : δ) (t : Raw α cmp) : m δ :=
t.foldlM f init
@[inline, inherit_doc TreeSet.empty]
def foldl (f : δ (a : α) δ) (init : δ) (t : Raw α cmp) : δ :=
t.inner.foldl (fun c a _ => f c a) init
@[inline, inherit_doc foldl, deprecated foldl (since := "2025-02-12")]
def fold (f : δ (a : α) δ) (init : δ) (t : Raw α cmp) : δ :=
t.foldl f init
@[inline, inherit_doc TreeSet.empty]
def foldrM (f : δ (a : α) m δ) (init : δ) (t : Raw α cmp) : m δ :=
t.inner.foldrM (fun c a _ => f c a) init
@[inline, inherit_doc TreeSet.empty]
def foldr (f : δ (a : α) δ) (init : δ) (t : Raw α cmp) : δ :=
t.inner.foldr (fun c a _ => f c a) init
@[inline, inherit_doc foldr, deprecated foldr (since := "2025-02-12")]
def revFold (f : δ (a : α) δ) (init : δ) (t : Raw α cmp) : δ :=
foldr f init t
@[inline, inherit_doc TreeSet.empty]
def forM (f : α m PUnit) (t : Raw α cmp) : m PUnit :=
t.inner.forM (fun a _ => f a)
@@ -173,14 +193,34 @@ def all (t : Raw α cmp) (p : α → Bool) : Bool :=
def toList (t : Raw α cmp) : List α :=
t.inner.inner.inner.foldr (fun l a _ => a :: l)
@[inline, inherit_doc TreeSet.ofList]
def ofList (l : List α) (cmp : α α Ordering := by exact compare) : Raw α cmp :=
TreeMap.Raw.unitOfList l cmp
@[inline, inherit_doc ofList, deprecated ofList (since := "2025-02-12")]
def fromList (l : List α) (cmp : α α Ordering) : Raw α cmp :=
ofList l cmp
@[inline, inherit_doc TreeSet.empty]
def toArray (t : Raw α cmp) : Array α :=
t.foldl (init := #[]) fun acc k => acc.push k
@[inline, inherit_doc TreeSet.ofArray]
def ofArray (a : Array α) (cmp : α α Ordering := by exact compare) : Raw α cmp :=
TreeMap.Raw.unitOfArray a cmp
@[inline, inherit_doc ofArray, deprecated ofArray (since := "2025-02-12")]
def fromArray (a : Array α) (cmp : α α Ordering) : Raw α cmp :=
ofArray a cmp
@[inline, inherit_doc TreeSet.empty]
def merge (t₁ t₂ : Raw α cmp) : Raw α cmp :=
TreeMap.Raw.mergeWith (fun _ _ _ => ()) t₁.inner t₂.inner
@[inline, inherit_doc TreeSet.insertMany]
def insertMany {ρ} [ForIn Id ρ α] (t : Raw α cmp) (l : ρ) : Raw α cmp :=
TreeMap.Raw.insertManyIfNewUnit t.inner l
@[inline, inherit_doc TreeSet.empty]
def eraseMany {ρ} [ForIn Id ρ α] (t : Raw α cmp) (l : ρ) : Raw α cmp :=
t.inner.eraseMany l

View File

@@ -57,8 +57,8 @@ theorem BitVec.sle_eq_ult (x y : BitVec w) :
attribute [bv_normalize] BitVec.ofNat_eq_ofNat
@[bv_normalize]
theorem BitVec.ofNatLt_reduce (n : Nat) (h) : BitVec.ofNatLt n h = BitVec.ofNat w n := by
simp [BitVec.ofNatLt, BitVec.ofNat, Fin.ofNat', Nat.mod_eq_of_lt h]
theorem BitVec.ofNatLT_reduce (n : Nat) (h) : BitVec.ofNatLT n h = BitVec.ofNat w n := by
simp [BitVec.ofNatLT, BitVec.ofNat, Fin.ofNat', Nat.mod_eq_of_lt h]
@[bv_normalize]
theorem BitVec.ofBool_eq_if (b : Bool) : BitVec.ofBool b = bif b then 1#1 else 0#1 := by

View File

@@ -2015,6 +2015,7 @@ static inline uint8_t lean_int8_dec_le(uint8_t a1, uint8_t a2) {
static inline uint16_t lean_int8_to_int16(uint8_t a) { return (uint16_t)(int16_t)(int8_t)a; }
static inline uint32_t lean_int8_to_int32(uint8_t a) { return (uint32_t)(int32_t)(int8_t)a; }
static inline uint64_t lean_int8_to_int64(uint8_t a) { return (uint64_t)(int64_t)(int8_t)a; }
static inline size_t lean_int8_to_isize(uint8_t a) { return (size_t)(ptrdiff_t)(int8_t)a; }
/* Int16 */
@@ -2155,6 +2156,7 @@ static inline uint8_t lean_int16_dec_le(uint16_t a1, uint16_t a2) {
static inline uint8_t lean_int16_to_int8(uint16_t a) { return (uint8_t)(int8_t)(int16_t)a; }
static inline uint32_t lean_int16_to_int32(uint16_t a) { return (uint32_t)(int32_t)(int16_t)a; }
static inline uint64_t lean_int16_to_int64(uint16_t a) { return (uint64_t)(int64_t)(int16_t)a; }
static inline size_t lean_int16_to_isize(uint16_t a) { return (size_t)(ptrdiff_t)(int16_t)a; }
/* Int32 */
LEAN_EXPORT int32_t lean_int32_of_big_int(b_lean_obj_arg a);
@@ -2573,6 +2575,8 @@ static inline uint8_t lean_isize_dec_le(size_t a1, size_t a2) {
}
/* ISize -> other */
static inline uint8_t lean_isize_to_int8(size_t a) { return (uint8_t)(int8_t)(ptrdiff_t)a; }
static inline uint16_t lean_isize_to_int16(size_t a) { return (uint16_t)(int16_t)(ptrdiff_t)a; }
static inline uint32_t lean_isize_to_int32(size_t a) { return (uint32_t)(int32_t)(ptrdiff_t)a; }
static inline uint64_t lean_isize_to_int64(size_t a) { return (uint64_t)(int64_t)(ptrdiff_t)a; }

View File

@@ -21,7 +21,7 @@ def compileLeanModule
(leanFile : FilePath)
(oleanFile? ileanFile? cFile? bcFile?: Option FilePath)
(leanPath : SearchPath := []) (rootDir : FilePath := ".")
(dynlibs : Array FilePath := #[]) (dynlibPath : SearchPath := {})
(dynlibs : Array FilePath := #[]) (plugins : Array FilePath := #[])
(leanArgs : Array String := #[]) (lean : FilePath := "lean")
: LogIO Unit := do
let mut args := leanArgs ++
@@ -39,15 +39,16 @@ def compileLeanModule
createParentDirs bcFile
args := args ++ #["-b", bcFile.toString]
for dynlib in dynlibs do
args := args.push s!"--load-dynlib={dynlib}"
args := args ++ #["--load-dynlib", dynlib.toString]
for plugin in plugins do
args := args ++ #["--plugin", plugin.toString]
args := args.push "--json"
withLogErrorPos do
let out rawProc {
args
cmd := lean.toString
env := #[
("LEAN_PATH", leanPath.toString),
(sharedLibPathEnvVar, ( getSearchPath sharedLibPathEnvVar) ++ dynlibPath |>.toString)
("LEAN_PATH", leanPath.toString)
]
}
unless out.stdout.isEmpty do

View File

@@ -6,7 +6,7 @@ Authors: Mac Malone
prelude
import Lake.Build.Data
import Lake.Build.Job.Basic
import Lake.Config.OutFormat
import Lake.Config.Dynlib
/-!
# Simple Builtin Facet Declarations
@@ -22,19 +22,10 @@ open Lean hiding SearchPath
namespace Lake
/-- A dynamic/shared library for linking. -/
structure Dynlib where
/-- Library file path. -/
path : FilePath
/-- Library name without platform-specific prefix/suffix (for `-l`). -/
name : String
/-- Optional library directory (for `-L`). -/
def Dynlib.dir? (self : Dynlib) : Option FilePath :=
self.path.parent
instance : ToText Dynlib := (·.path.toString)
instance : ToJson Dynlib := (·.path.toString)
structure ModuleDeps where
dynlibs : Array FilePath := #[]
plugins : Array FilePath := #[]
deriving Inhabited, Repr
/-! ## Module Facets -/
@@ -58,7 +49,7 @@ The facet which builds all of a module's dependencies
Returns the list of shared libraries to load along with their search path.
-/
abbrev Module.depsFacet := `deps
module_data deps : SearchPath × Array FilePath
module_data deps : ModuleDeps
/--
The core build facet of a Lean file.

View File

@@ -17,26 +17,31 @@ namespace Lake
Builds an `Array` of module imports for a Lean file.
Used by `lake setup-file` to build modules for the Lean server and
by `lake lean` to build the imports of a file.
Returns the set of module dynlibs built (so they can be loaded by Lean).
Returns the dynlibs and plugins built (so they can be loaded by Lean).
-/
def buildImportsAndDeps (leanFile : FilePath) (imports : Array Module) : FetchM (Job (Array FilePath)) := do
withRegisterJob s!"imports ({leanFile})" do
def buildImportsAndDeps
(leanFile : FilePath) (imports : Array Module)
: FetchM (Job ModuleDeps) := do
withRegisterJob s!"setup ({leanFile})" do
if imports.isEmpty then
-- build the package's (and its dependencies') `extraDepTarget`
( getRootPackage).extraDep.fetch <&> (·.map fun _ => #[])
( getRootPackage).extraDep.fetch <&> (·.map fun _ => {})
else
-- build local imports from list
let modJob := Job.mixArray <| imports.mapM (·.olean.fetch)
let precompileImports ( computePrecompileImportsAux leanFile imports).await
let pkgs := precompileImports.foldl (·.insert ·.pkg) OrdPackageSet.empty |>.toArray
let externLibJob := Job.collectArray <|
pkgs.flatMapM (·.externLibs.mapM (·.dynlib.fetch))
let precompileJob := Job.collectArray <|
let externLibsJob fetchExternLibs pkgs
let modLibsJob Job.collectArray <$>
precompileImports.mapM (·.dynlib.fetch)
let job
modJob.bindM fun _ =>
precompileJob.bindM fun modLibs =>
externLibJob.mapM fun externLibs => do
-- NOTE: Lean wants the external library symbols before module symbols
return (externLibs ++ modLibs).map (·.path)
return job
let dynlibsJob ( getRootPackage).dynlibs.fetch
let pluginsJob ( getRootPackage).plugins.fetch
modJob.bindM fun _ =>
modLibsJob.bindM fun modLibs =>
dynlibsJob.bindM fun dynlibs =>
pluginsJob.bindM fun plugins =>
externLibsJob.mapM fun externLibs => do
-- NOTE: Lean wants the external library symbols before module symbols
let dynlibs := (externLibs ++ dynlibs).map (·.path)
let plugins := (modLibs ++ plugins).map (·.path)
return {dynlibs, plugins}

View File

@@ -8,6 +8,7 @@ import Lake.Util.OrdHashSet
import Lake.Util.List
import Lean.Elab.ParseImportsFast
import Lake.Build.Common
import Lake.Build.Target
/-! # Module Facet Builds
Build function definitions for a module's builtin facets.
@@ -99,13 +100,17 @@ def Module.recComputePrecompileImports (mod : Module) : FetchM (Job (Array Modul
def Module.precompileImportsFacetConfig : ModuleFacetConfig precompileImportsFacet :=
mkFacetJobConfig recComputePrecompileImports (buildable := false)
/-- Fetch dynlibs of the packages' external libraries. **For internal use.** -/
def fetchExternLibs (pkgs : Array Package) : FetchM (Job (Array Dynlib)) :=
Job.collectArray <$> pkgs.flatMapM (·.externLibs.mapM (·.dynlib.fetch))
/--
Recursively build a module's dependencies, including:
* Transitive local imports
* Shared libraries (e.g., `extern_lib` targets or precompiled modules)
* `extraDepTargets` of its library
-/
def Module.recBuildDeps (mod : Module) : FetchM (Job (SearchPath × Array FilePath)) := ensureJob do
def Module.recBuildDeps (mod : Module) : FetchM (Job ModuleDeps) := ensureJob do
let extraDepJob mod.lib.extraDep.fetch
/-
@@ -122,17 +127,20 @@ def Module.recBuildDeps (mod : Module) : FetchM (Job (SearchPath × Array FilePa
mod.transImports.fetch else mod.precompileImports.fetch
let precompileImports precompileImports.await
let modLibJobs precompileImports.mapM (·.dynlib.fetch)
let modLibsJob := Job.collectArray modLibJobs
let pkgs := precompileImports.foldl (·.insert ·.pkg) OrdPackageSet.empty
let pkgs := if mod.shouldPrecompile then pkgs.insert mod.pkg else pkgs
let (externJobs, libDirs) recBuildExternDynlibs pkgs.toArray
let externDynlibsJob := Job.collectArray externJobs
let modDynlibsJob := Job.collectArray modLibJobs
let externLibsJob fetchExternLibs pkgs.toArray
let dynlibsJob mod.dynlibs.fetch
let pluginsJob mod.plugins.fetch
extraDepJob.bindM fun _ => do
importJob.bindM fun _ => do
let depTrace takeTrace
modDynlibsJob.bindM fun modDynlibs => do
externDynlibsJob.mapM fun externDynlibs => do
modLibsJob.bindM fun modLibs => do
externLibsJob.bindM fun externLibs => do
dynlibsJob.bindM fun dynlibs => do
pluginsJob.mapM fun plugins => do
match mod.platformIndependent with
| none => addTrace depTrace
| some false => addTrace depTrace; addPlatformTrace
@@ -145,9 +153,9 @@ def Module.recBuildDeps (mod : Module) : FetchM (Job (SearchPath × Array FilePa
Everything else loads fine with just the augmented library path.
* Linux needs the augmented path to resolve nested dependencies in dynlibs.
-/
let dynlibPath := libDirs ++ externDynlibs.filterMap (·.dir?) |>.toList
let dynlibs := externDynlibs.map (·.path) ++ modDynlibs.map (·.path)
return (dynlibPath, dynlibs)
let dynlibs := externLibs.map (·.path) ++ dynlibs.map (·.path)
let plugins := modLibs.map (·.path) ++ plugins.map (·.path)
return {dynlibs, plugins}
/-- The `ModuleFacetConfig` for the builtin `depsFacet`. -/
def Module.depsFacetConfig : ModuleFacetConfig depsFacet :=
@@ -176,14 +184,15 @@ all possible artifacts (i.e., `.olean`, `ilean`, `.c`, and `.bc`).
-/
def Module.recBuildLean (mod : Module) : FetchM (Job Unit) := do
withRegisterJob mod.name.toString do
( mod.deps.fetch).mapM fun (dynlibPath, dynlibs) => do
( mod.deps.fetch).mapM fun {dynlibs, plugins} => do
addLeanTrace
addPureTrace mod.leanArgs
let srcTrace computeTrace (TextFilePath.mk mod.leanFile)
addTrace srcTrace
let upToDate buildUnlessUpToDate? (oldTrace := srcTrace.mtime) mod ( getTrace) mod.traceFile do
compileLeanModule mod.leanFile mod.oleanFile mod.ileanFile mod.cFile mod.bcFile?
( getLeanPath) mod.rootDir dynlibs dynlibPath (mod.weakLeanArgs ++ mod.leanArgs) ( getLean)
( getLeanPath) mod.rootDir dynlibs plugins
(mod.weakLeanArgs ++ mod.leanArgs) ( getLean)
mod.clearOutputHashes
unless upToDate && ( getTrustHash) do
mod.cacheOutputHashes
@@ -297,38 +306,47 @@ def Module.oFacetConfig : ModuleFacetConfig oFacet :=
| .default | .c => mod.co.fetch
| .llvm => mod.bco.fetch
-- TODO: Return `Job OrdModuleSet × OrdPackageSet` or `OrdRBSet Dynlib`
/-- Recursively build the shared library of a module (e.g., for `--load-dynlib`). -/
/--
Recursively build the shared library of a module
(e.g., for `--load-dynlib` or `--plugin`).
-/
-- TODO: Return `Job OrdModuleSet × OrdPackageSet` or `OrdRBSet Dynlib`?
def Module.recBuildDynlib (mod : Module) : FetchM (Job Dynlib) :=
withRegisterJob s!"{mod.name}:dynlib" do
-- Compute dependencies
let transImports ( mod.transImports.fetch).await
let modJobs transImports.mapM (·.dynlib.fetch)
let pkgs := transImports.foldl (·.insert ·.pkg)
OrdPackageSet.empty |>.insert mod.pkg |>.toArray
let (externJobs, pkgLibDirs) recBuildExternDynlibs pkgs
-- Fetch object files
let linkJobs mod.nativeFacets true |>.mapM (fetch <| mod.facet ·.name)
-- Collect Jobs
let linksJob := Job.collectArray linkJobs
let modDynlibsJob := Job.collectArray modJobs
let externDynlibsJob := Job.collectArray externJobs
-- Fetch dependencies' dynlibs
-- for platforms that must link to them (e.g., Windows)
let (modLibsJob, externLibsJob) id do
if Platform.isWindows then
let transImports ( mod.transImports.fetch).await
let modLibsJob Job.collectArray <$> transImports.mapM (·.dynlib.fetch)
let pkgs := transImports.foldl (·.insert ·.pkg)
OrdPackageSet.empty |>.insert mod.pkg |>.toArray
let externLibsJob fetchExternLibs pkgs
return (modLibsJob, externLibsJob)
else
return (Job.pure #[], Job.pure #[])
-- Build dynlib
linksJob.bindM fun links => do
modDynlibsJob.bindM fun modDynlibs => do
externDynlibsJob.mapM fun externDynlibs => do
modLibsJob.bindM fun modLibs => do
externLibsJob.mapM fun externLibs => do
addLeanTrace
addPlatformTrace -- shared libraries are platform-dependent artifacts
addPureTrace mod.linkArgs
buildFileUnlessUpToDate' mod.dynlibFile do
let lean getLeanInstall
let libDirs := pkgLibDirs ++ externDynlibs.filterMap (·.dir?)
let libNames := modDynlibs.map (·.name) ++ externDynlibs.map (·.name)
let args :=
links.map toString ++
libDirs.map (s!"-L{·}") ++ libNames.map (s!"-l{·}") ++
let args := links.map toString
let args :=
if Platform.isWindows then
args ++ (modLibs ++ externLibs).map (·.path.toString)
else
args
let args := args ++
mod.weakLinkArgs ++ mod.linkArgs ++ lean.ccLinkSharedFlags
compileSharedLib mod.dynlibFile args lean.cc
return mod.dynlibFile, mod.dynlibName

View File

@@ -0,0 +1,8 @@
/-
Copyright (c) 2025 Mac Malone. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Mac Malone
-/
prelude
import Lake.Build.Target.Basic
import Lake.Build.Target.Fetch

View File

@@ -0,0 +1,37 @@
/-
Copyright (c) 2025 Mac Malone. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Mac Malone
-/
prelude
import Lake.Build.Data
/-!
# Lake Targets
This module contains the declarative definition of a `Target`.
-/
open Std
namespace Lake
/-- A Lake target that is known to produce an output of a specific type. -/
structure Target (α : Type) where
key : BuildKey
[data_def : FamilyDef BuildData key α]
deriving Repr
protected def Target.repr (x : Target α) (prec : Nat) : Format :=
let indent := if prec >= max_prec then 1 else 2
let ctor := "Lake.Target.mk" ++ Format.line ++ reprArg x.key
Repr.addAppParen (.group (.nest indent ctor)) prec
instance : Repr (Target α) := Target.repr
instance : ToString (Target α) := (·.key.toString)
/--
Shorthand for `Array (Target α)` that supports
dot notation for Lake-specific operations (e.g., `fetch`).
-/
abbrev TargetArray α := Array (Target α)

View File

@@ -0,0 +1,40 @@
/-
Copyright (c) 2025 Mac Malone. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Mac Malone
-/
prelude
import Lake.Build.Job
import Lake.Config.Monad
namespace Lake
protected def BuildKey.fetch (self : BuildKey) [h : FamilyOut BuildData self α] : FetchM (Job α) := do
match self_eq:self with
| moduleFacet modName facetName =>
let some mod findModule? modName
| error s!"invalid target '{self}': module '{modName}' not found in workspace"
have : FamilyOut BuildData (mod.facet facetName).key α :=
by simpa [self_eq] using h.family_key_eq_type
fetch <| mod.facet facetName
| packageFacet pkgName facetName =>
let some pkg findPackage? pkgName
| error s!"invalid target '{self}': package '{pkgName}' not found in workspace"
have : FamilyOut BuildData (pkg.facet facetName).key α :=
by simpa [self_eq] using h.family_key_eq_type
fetch <| pkg.facet facetName
| targetFacet pkgName targetName facetName =>
-- TODO: Support this
error s!"unsupported target {self}: fetching builtin targets by key is not currently supported"
| customTarget pkgName targetName =>
let some pkg findPackage? pkgName
| error s!"invalid target '{self}': package '{pkgName}' not found in workspace"
have : FamilyOut BuildData (pkg.target targetName).key α :=
by simpa [self_eq] using h.family_key_eq_type
fetch <| pkg.target targetName
@[inline] protected def Target.fetch (self : Target α) : FetchM (Job α) := do
have := self.data_def; self.key.fetch
protected def TargetArray.fetch (self : TargetArray α) : FetchM (Job (Array α)) := do
Job.collectArray <$> self.mapM (·.fetch)

View File

@@ -517,10 +517,13 @@ protected def lean : CliM PUnit := do
let ws loadWorkspace ( mkLoadConfig opts)
let imports Lean.parseImports' ( IO.FS.readFile leanFile) leanFile
let imports := imports.filterMap (ws.findModule? ·.module)
let dynlibs ws.runBuild (buildImportsAndDeps leanFile imports) (mkBuildConfig opts)
let {dynlibs, plugins}
ws.runBuild (buildImportsAndDeps leanFile imports) (mkBuildConfig opts)
let spawnArgs := {
args :=
#[leanFile] ++ dynlibs.map (s!"--load-dynlib={·}") ++
#[leanFile] ++
dynlibs.map (s!"--load-dynlib={·}") ++
plugins.map (s!"--plugin={·}") ++
ws.root.moreLeanArgs ++ opts.subArgs
cmd := ws.lakeEnv.lean.lean.toString
env := ws.augmentedEnvVars

View File

@@ -43,14 +43,14 @@ def setupFile
loadWorkspace loadConfig
let imports := imports.foldl (init := #[]) fun imps imp =>
if let some mod := ws.findModule? imp.toName then imps.push mod else imps
let dynlibs MainM.runLogIO (minLv := outLv) (ansiMode := .noAnsi) do
ws.runBuild (buildImportsAndDeps path imports) buildConfig
let {dynlibs, plugins}
MainM.runLogIO (minLv := outLv) (ansiMode := .noAnsi) do
ws.runBuild (buildImportsAndDeps path imports) buildConfig
let paths : LeanPaths := {
oleanPath := ws.leanPath
srcPath := ws.leanSrcPath
loadDynlibPaths := dynlibs
pluginPaths := #[]
: LeanPaths
pluginPaths := plugins
}
let setupOptions : LeanOptions do
let some moduleName searchModuleNameOfFileName path ws.leanSrcPath

View File

@@ -36,7 +36,7 @@ def defaultManifestFile : FilePath := "lake-manifest.json"
def defaultBuildDir : FilePath := defaultLakeDir / "build"
/-- The default Lean library directory for packages. -/
def defaultLeanLibDir : FilePath := "lib"
def defaultLeanLibDir : FilePath := "lib" / "lean"
/-- The default native library directory for packages. -/
def defaultNativeLibDir : FilePath := "lib"

View File

@@ -0,0 +1,25 @@
/-
Copyright (c) 2022 Mac Malone. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Mac Malone
-/
prelude
import Lake.Config.OutFormat
open System Lean
namespace Lake
/-- A dynamic/shared library for linking. -/
structure Dynlib where
/-- Library file path. -/
path : FilePath
/-- Library name without platform-specific prefix/suffix (for `-l`). -/
name : String
/-- Optional library directory (for `-L`). -/
def Dynlib.dir? (self : Dynlib) : Option FilePath :=
self.path.parent
instance : ToText Dynlib := (·.path.toString)
instance : ToJson Dynlib := (·.path.toString)

View File

@@ -5,6 +5,10 @@ Authors: Mac Malone
-/
prelude
import Lean.Util.LeanOptions
import Lake.Build.Target.Basic
import Lake.Config.Dynlib
open System
namespace Lake
@@ -218,4 +222,14 @@ structure LeanConfig where
and Lake will not catch it. Defaults to `none`.
-/
platformIndependent : Option Bool := none
/-
An array of dynamic library targets to load during the elaboration
of a module (via `lean --load-dynlib`).
-/
dynlibs : TargetArray Dynlib := #[]
/-
An array of Lean plugin targets to load during the elaboration
of a module (via `lean --plugin`).
-/
plugins : TargetArray Dynlib := #[]
deriving Inhabited, Repr

View File

@@ -126,6 +126,20 @@ then the default (which is C for now).
@[inline] def backend (self : LeanLib) : Backend :=
Backend.orPreferLeft self.config.backend self.pkg.backend
/--
The dynamic libraries to load for modules of this library.
The targets of the package plus the targets of the library (in that order).
-/
@[inline] def dynlibs (self : LeanLib) : TargetArray Dynlib :=
self.pkg.dynlibs ++ self.config.dynlibs
/--
The Lean plugins for modules of this library.
The targets of the package plus the targets of the library (in that order).
-/
@[inline] def plugins (self : LeanLib) : TargetArray Dynlib :=
self.pkg.plugins ++ self.config.plugins
/--
The arguments to pass to `lean` when compiling the library's Lean files.
`leanArgs` is the accumulation of:

View File

@@ -8,6 +8,7 @@ import Lake.Build.Trace
import Lake.Config.LeanLib
import Lake.Config.OutFormat
import Lake.Util.OrdHashSet
import Lean.Compiler.NameMangling
namespace Lake
open Lean System
@@ -113,11 +114,16 @@ def bcFile? (self : Module) : Option FilePath :=
def dynlibSuffix := "-1"
@[inline] def dynlibName (self : Module) : String :=
-- NOTE: file name MUST be unique on Windows
self.name.toStringWithSep "-" (escape := true) ++ dynlibSuffix
/-
* File name MUST be unique on Windows
* Uses the mangled module name so the library name matches the
name used for the module's initialization function, thus enabling it
to be loaded as a plugin.
-/
self.name.mangle ""
@[inline] def dynlibFile (self : Module) : FilePath :=
self.pkg.nativeLibDir / nameToSharedLib self.dynlibName
self.pkg.leanLibDir / s!"{self.dynlibName}.{sharedLibExt}"
@[inline] def serverOptions (self : Module) : Array LeanOption :=
self.lib.serverOptions
@@ -128,6 +134,12 @@ def dynlibSuffix := "-1"
@[inline] def backend (self : Module) : Backend :=
self.lib.backend
@[inline] def dynlibs (self : Module) : TargetArray Dynlib :=
self.lib.dynlibs
@[inline] def plugins (self : Module) : TargetArray Dynlib :=
self.lib.plugins
@[inline] def leanArgs (self : Module) : Array String :=
self.lib.leanArgs

View File

@@ -5,7 +5,6 @@ Authors: Mac Malone
-/
prelude
import Lean.Data.Json
import Lake.Build.Job.Basic
open Lean

View File

@@ -596,6 +596,14 @@ namespace Package
@[inline] def backend (self : Package) : Backend :=
self.config.backend
/-- The package's `dynlibs` configuration. -/
@[inline] def dynlibs (self : Package) : TargetArray Dynlib :=
self.config.dynlibs
/-- The package's `plugins` configuration. -/
@[inline] def plugins (self : Package) : TargetArray Dynlib :=
self.config.plugins
/-- The package's `leanOptions` configuration. -/
@[inline] def leanOptions (self : Package) : Array LeanOption :=
self.config.leanOptions

View File

@@ -148,7 +148,11 @@ attribute [simp] FamilyOut.family_key_eq_type
instance [FamilyDef Fam a β] : FamilyOut Fam a β where
family_key_eq_type := FamilyDef.family_key_eq_type
/-- The constant type family -/
/-- The identity relation. -/
@[default_instance 0] instance : FamilyDef Fam a (Fam a) where
family_key_eq_type := rfl
/-- The constant type family. -/
instance : FamilyDef (fun _ => β) a β where
family_key_eq_type := rfl

View File

@@ -5,10 +5,10 @@ LAKE=${LAKE:-../../.lake/build/bin/lake}
./clean.sh
# Tests that a non-precmpiled build does not load anything as a dynlib
# Tests that a non-precompiled build does not load anything as a dynlib/plugin
# https://github.com/leanprover/lean4/issues/4565
$LAKE -d app build -v | (grep --color load-dynlib && exit 1 || true)
$LAKE -d lib build -v | (grep --color load-dynlib && exit 1 || true)
$LAKE -d app build -v | (grep --color -E 'load-dynlib|plugin' && exit 1 || true)
$LAKE -d lib build -v | (grep --color -E 'load-dynlib|plugin' && exit 1 || true)
./app/.lake/build/bin/app
./lib/.lake/build/bin/test

View File

@@ -50,13 +50,13 @@ $LAKE build Foo.Bar:print_src | grep --color Bar.lean
# Test the module `deps` facet
$LAKE build +Foo:deps
test -f ./.lake/build/lib/Foo/Bar.olean
test ! -f ./.lake/build/lib/Foo.olean
test -f ./.lake/build/lib/lean/Foo/Bar.olean
test ! -f ./.lake/build/lib/lean/Foo.olean
# Test the module specifier
test ! -f ./.lake/build/lib/Foo/Baz.olean
test ! -f ./.lake/build/lib/lean/Foo/Baz.olean
$LAKE build +Foo.Baz
test -f ./.lake/build/lib/Foo/Baz.olean
test -f ./.lake/build/lib/lean/Foo/Baz.olean
# Test an object file specifier
test ! -f ./.lake/build/ir/Bar.c.o.export
@@ -65,12 +65,12 @@ test -f ./.lake/build/ir/Bar.c.o.export
# Test default targets
test ! -f ./.lake/build/bin/c
test ! -f ./.lake/build/lib/Foo.olean
test ! -f ./.lake/build/lib/lean/Foo.olean
test ! -f ./.lake/build/lib/${LIB_PREFIX}Foo.a
test ! -f ./.lake/build/meow.txt
$LAKE build targets/
./.lake/build/bin/c
test -f ./.lake/build/lib/Foo.olean
test -f ./.lake/build/lib/lean/Foo.olean
test -f ./.lake/build/lib/${LIB_PREFIX}Foo.a
cat ./.lake/build/meow.txt | grep Meow!
@@ -82,15 +82,15 @@ test -f ./.lake/build/lib/${LIB_PREFIX}Foo.$SHARED_LIB_EXT
test -f ./.lake/build/lib/${LIB_PREFIX}Bar.$SHARED_LIB_EXT
# Test dynlib facet
test ! -f ./.lake/build/lib/${LIB_PREFIX}Foo-1.$SHARED_LIB_EXT
test ! -f ./.lake/build/lib/lean/Foo.$SHARED_LIB_EXT
$LAKE build +Foo:dynlib
test -f ./.lake/build/lib/${LIB_PREFIX}Foo-1.$SHARED_LIB_EXT
test -f ./.lake/build/lib/lean/Foo.$SHARED_LIB_EXT
# Test library `extraDepTargets`
test ! -f ./.lake/build/caw.txt
test ! -f ./.lake/build/lib/Baz.olean
test ! -f ./.lake/build/lib/lean/Baz.olean
$LAKE build Baz
test -f ./.lake/build/lib/Baz.olean
test -f ./.lake/build/lib/lean/Baz.olean
cat ./.lake/build/caw.txt | grep Caw!
# Test executable build

Some files were not shown because too many files have changed in this diff Show More