mirror of
https://github.com/leanprover/lean4.git
synced 2026-04-21 19:44:07 +00:00
Compare commits
11 Commits
sg/control
...
sg/partial
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7f3f10040f | ||
|
|
cdcb7a4b7f | ||
|
|
f371d95150 | ||
|
|
b6d223f083 | ||
|
|
8621ede3d8 | ||
|
|
07b04adc4a | ||
|
|
ae2568a32c | ||
|
|
55d8907f84 | ||
|
|
b39d61533d | ||
|
|
18434c18a5 | ||
|
|
937166e16d |
6
.github/workflows/ci.yml
vendored
6
.github/workflows/ci.yml
vendored
@@ -279,8 +279,7 @@ jobs:
|
||||
"os": large ? "nscloud-ubuntu-24.04-amd64-8x16-with-cache" : "ubuntu-latest",
|
||||
"enabled": true,
|
||||
"check-rebootstrap": level >= 1,
|
||||
// Done as part of test-bench
|
||||
//"check-stage3": level >= 2,
|
||||
"check-stage3": level >= 2,
|
||||
"test": true,
|
||||
// NOTE: `test-bench` currently seems to be broken on `ubuntu-latest`
|
||||
"test-bench": large && level >= 2,
|
||||
@@ -292,8 +291,7 @@ jobs:
|
||||
"os": large ? "nscloud-ubuntu-24.04-amd64-8x16-with-cache" : "ubuntu-latest",
|
||||
"enabled": true,
|
||||
"check-rebootstrap": level >= 1,
|
||||
// Done as part of test-bench
|
||||
//"check-stage3": level >= 2,
|
||||
"check-stage3": level >= 2,
|
||||
"test": true,
|
||||
"secondary": true,
|
||||
// NOTE: `test-bench` currently seems to be broken on `ubuntu-latest`
|
||||
|
||||
@@ -59,9 +59,9 @@ Examples:
|
||||
* `Nat.repeat f 3 a = f <| f <| f <| a`
|
||||
* `Nat.repeat (· ++ "!") 4 "Hello" = "Hello!!!!"`
|
||||
-/
|
||||
@[specialize, expose] def «repeat» {α : Type u} (f : α → α) : (n : Nat) → (a : α) → α
|
||||
@[specialize, expose] def repeat {α : Type u} (f : α → α) : (n : Nat) → (a : α) → α
|
||||
| 0, a => a
|
||||
| succ n, a => f («repeat» f n a)
|
||||
| succ n, a => f (repeat f n a)
|
||||
|
||||
/--
|
||||
Applies a function to a starting value the specified number of times.
|
||||
@@ -1221,9 +1221,9 @@ theorem not_lt_eq (a b : Nat) : (¬ (a < b)) = (b ≤ a) :=
|
||||
theorem not_gt_eq (a b : Nat) : (¬ (a > b)) = (a ≤ b) :=
|
||||
not_lt_eq b a
|
||||
|
||||
@[csimp] theorem repeat_eq_repeatTR : @«repeat» = @repeatTR :=
|
||||
@[csimp] theorem repeat_eq_repeatTR : @repeat = @repeatTR :=
|
||||
funext fun α => funext fun f => funext fun n => funext fun init =>
|
||||
let rec go : ∀ m n, repeatTR.loop f m («repeat» f n init) = «repeat» f (m + n) init
|
||||
let rec go : ∀ m n, repeatTR.loop f m (repeat f n init) = repeat f (m + n) init
|
||||
| 0, n => by simp [repeatTR.loop]
|
||||
| succ m, n => by rw [repeatTR.loop, succ_add]; exact go m (succ n)
|
||||
(go n 0).symm
|
||||
|
||||
@@ -87,7 +87,7 @@ public theorem IsLinearOrder.of_ord {α : Type u} [LE α] [Ord α] [LawfulOrderO
|
||||
/--
|
||||
This lemma derives a `LawfulOrderLT α` instance from a property involving an `Ord α` instance.
|
||||
-/
|
||||
public theorem LawfulOrderLT.of_ord (α : Type u) [Ord α] [LT α] [LE α] [LawfulOrderOrd α]
|
||||
public instance LawfulOrderLT.of_ord (α : Type u) [Ord α] [LT α] [LE α] [LawfulOrderOrd α]
|
||||
(lt_iff_compare_eq_lt : ∀ a b : α, a < b ↔ compare a b = .lt) :
|
||||
LawfulOrderLT α where
|
||||
lt_iff a b := by
|
||||
@@ -96,7 +96,7 @@ public theorem LawfulOrderLT.of_ord (α : Type u) [Ord α] [LT α] [LE α] [Lawf
|
||||
/--
|
||||
This lemma derives a `LawfulOrderBEq α` instance from a property involving an `Ord α` instance.
|
||||
-/
|
||||
public theorem LawfulOrderBEq.of_ord (α : Type u) [Ord α] [BEq α] [LE α] [LawfulOrderOrd α]
|
||||
public instance LawfulOrderBEq.of_ord (α : Type u) [Ord α] [BEq α] [LE α] [LawfulOrderOrd α]
|
||||
(beq_iff_compare_eq_eq : ∀ a b : α, a == b ↔ compare a b = .eq) :
|
||||
LawfulOrderBEq α where
|
||||
beq_iff_le_and_ge := by
|
||||
@@ -105,7 +105,7 @@ public theorem LawfulOrderBEq.of_ord (α : Type u) [Ord α] [BEq α] [LE α] [La
|
||||
/--
|
||||
This lemma derives a `LawfulOrderInf α` instance from a property involving an `Ord α` instance.
|
||||
-/
|
||||
public theorem LawfulOrderInf.of_ord (α : Type u) [Ord α] [Min α] [LE α] [LawfulOrderOrd α]
|
||||
public instance LawfulOrderInf.of_ord (α : Type u) [Ord α] [Min α] [LE α] [LawfulOrderOrd α]
|
||||
(compare_min_isLE_iff : ∀ a b c : α,
|
||||
(compare a (min b c)).isLE ↔ (compare a b).isLE ∧ (compare a c).isLE) :
|
||||
LawfulOrderInf α where
|
||||
@@ -114,7 +114,7 @@ public theorem LawfulOrderInf.of_ord (α : Type u) [Ord α] [Min α] [LE α] [La
|
||||
/--
|
||||
This lemma derives a `LawfulOrderMin α` instance from a property involving an `Ord α` instance.
|
||||
-/
|
||||
public theorem LawfulOrderMin.of_ord (α : Type u) [Ord α] [Min α] [LE α] [LawfulOrderOrd α]
|
||||
public instance LawfulOrderMin.of_ord (α : Type u) [Ord α] [Min α] [LE α] [LawfulOrderOrd α]
|
||||
(compare_min_isLE_iff : ∀ a b c : α,
|
||||
(compare a (min b c)).isLE ↔ (compare a b).isLE ∧ (compare a c).isLE)
|
||||
(min_eq_or : ∀ a b : α, min a b = a ∨ min a b = b) :
|
||||
@@ -125,7 +125,7 @@ public theorem LawfulOrderMin.of_ord (α : Type u) [Ord α] [Min α] [LE α] [La
|
||||
/--
|
||||
This lemma derives a `LawfulOrderSup α` instance from a property involving an `Ord α` instance.
|
||||
-/
|
||||
public theorem LawfulOrderSup.of_ord (α : Type u) [Ord α] [Max α] [LE α] [LawfulOrderOrd α]
|
||||
public instance LawfulOrderSup.of_ord (α : Type u) [Ord α] [Max α] [LE α] [LawfulOrderOrd α]
|
||||
(compare_max_isLE_iff : ∀ a b c : α,
|
||||
(compare (max a b) c).isLE ↔ (compare a c).isLE ∧ (compare b c).isLE) :
|
||||
LawfulOrderSup α where
|
||||
@@ -134,7 +134,7 @@ public theorem LawfulOrderSup.of_ord (α : Type u) [Ord α] [Max α] [LE α] [La
|
||||
/--
|
||||
This lemma derives a `LawfulOrderMax α` instance from a property involving an `Ord α` instance.
|
||||
-/
|
||||
public theorem LawfulOrderMax.of_ord (α : Type u) [Ord α] [Max α] [LE α] [LawfulOrderOrd α]
|
||||
public instance LawfulOrderMax.of_ord (α : Type u) [Ord α] [Max α] [LE α] [LawfulOrderOrd α]
|
||||
(compare_max_isLE_iff : ∀ a b c : α,
|
||||
(compare (max a b) c).isLE ↔ (compare a c).isLE ∧ (compare b c).isLE)
|
||||
(max_eq_or : ∀ a b : α, max a b = a ∨ max a b = b) :
|
||||
|
||||
@@ -75,9 +75,6 @@ theorem nat_eq (a b : Nat) (x y : Int) : NatCast.natCast a = x → NatCast.natCa
|
||||
theorem of_nat_eq (a b : Nat) (x y : Int) : NatCast.natCast a = x → NatCast.natCast b = y → a = b → x = y := by
|
||||
intro _ _; subst x y; intro; simp [*]
|
||||
|
||||
theorem of_natCast_eq {α : Type u} [NatCast α] (a b : Nat) (x y : α) : NatCast.natCast a = x → NatCast.natCast b = y → a = b → x = y := by
|
||||
intro h₁ h₂ h; subst h; exact h₁.symm.trans h₂
|
||||
|
||||
theorem le_of_not_le {α} [LE α] [Std.IsLinearPreorder α]
|
||||
{a b : α} : ¬ a ≤ b → b ≤ a := by
|
||||
intro h
|
||||
|
||||
@@ -7,5 +7,7 @@ module
|
||||
|
||||
prelude
|
||||
public import Init.Internal.Order.Basic
|
||||
public import Init.Internal.Order.ExtrinsicFix
|
||||
public import Init.Internal.Order.Lemmas
|
||||
public import Init.Internal.Order.MonadTail
|
||||
public import Init.Internal.Order.Tactic
|
||||
|
||||
117
src/Init/Internal/Order/ExtrinsicFix.lean
Normal file
117
src/Init/Internal/Order/ExtrinsicFix.lean
Normal file
@@ -0,0 +1,117 @@
|
||||
/-
|
||||
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Sebastian Graf, Robin Arnez
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Init.Internal.Order.Basic
|
||||
public import Init.Internal.Order.MonadTail
|
||||
import Init.Classical
|
||||
|
||||
set_option linter.missingDocs true
|
||||
|
||||
/-!
|
||||
This module provides a fixpoint combinator that combines advantages of `partial` and
|
||||
`partial_fixpoint` recursion.
|
||||
|
||||
The combinator is similar to {lean}`CCPO.fix`, but does not require a CCPO instance or
|
||||
monotonicity proof at the definition site. Therefore, it can be used in situations in which
|
||||
these constraints are unavailable (e.g., when the monad does not have a `MonadTail` instance).
|
||||
Given a CCPO and monotonicity proof, there are theorems guaranteeing that it equals
|
||||
{lean}`CCPO.fix` and satisfies fixpoint induction.
|
||||
-/
|
||||
|
||||
public section
|
||||
|
||||
namespace Lean.Order
|
||||
|
||||
/--
|
||||
The function implemented as the loop {lean}`opaqueFix f = f (opaqueFix f)`.
|
||||
{lean}`opaqueFix f` is the fixpoint of {name}`f`, as long as `f` is monotone with respect to
|
||||
some CCPO on `α`.
|
||||
|
||||
The loop might run forever depending on {name}`f`. It is opaque, i.e., it is impossible to prove
|
||||
nontrivial properties about it.
|
||||
-/
|
||||
@[specialize]
|
||||
partial def opaqueFix {α : Sort u} [Nonempty α] (f : α → α) : α :=
|
||||
f (opaqueFix f)
|
||||
|
||||
/-
|
||||
SAFE assuming that the code generated by iteration over `f` is equivalent to the CCPO
|
||||
fixpoint of `f` if there exists a CCPO making `f` monotone.
|
||||
-/
|
||||
open _root_.Classical in
|
||||
/--
|
||||
A fixpoint combinator that can be used to construct recursive definitions with an *extrinsic*
|
||||
proof of monotonicity.
|
||||
|
||||
Given a fixpoint functional {name}`f`, {lean}`extrinsicFix f` is the recursive function obtained
|
||||
by having {name}`f` call itself recursively.
|
||||
|
||||
If there is no CCPO on `α` making {name}`f` monotone, {lean}`extrinsicFix f` might run forever.
|
||||
In this case, nothing interesting can be proved about the result; it is opaque.
|
||||
|
||||
If there _is_ a CCPO on `α` making {name}`f` monotone, {lean}`extrinsicFix f` is equivalent to
|
||||
{lean}`CCPO.fix f`, logically and regarding its termination behavior.
|
||||
-/
|
||||
@[cbv_opaque, implemented_by opaqueFix]
|
||||
def extrinsicFix {α : Sort u} [Nonempty α] (f : α → α) : α :=
|
||||
if h : ∃ x, x = f x then
|
||||
h.choose
|
||||
else
|
||||
-- Return `opaqueFix f` so that `implemented_by opaqueFix` is sound.
|
||||
-- In effect, `extrinsicFix` is opaque if no fixpoint exists.
|
||||
opaqueFix f
|
||||
|
||||
/--
|
||||
A fixpoint combinator that allows for deferred proofs of monotonicity.
|
||||
|
||||
{lean}`extrinsicFix f` is a function implemented as the loop
|
||||
{lean}`extrinsicFix f = f (extrinsicFix f)`.
|
||||
|
||||
If there is a CCPO making `f` monotone, {name}`extrinsicFix_eq` proves that it satisfies the
|
||||
fixpoint equation, and {name}`extrinsicFix_induct` enables fixpoint induction.
|
||||
Otherwise, {lean}`extrinsicFix f` is opaque, i.e., it is impossible to prove nontrivial
|
||||
properties about it.
|
||||
-/
|
||||
add_decl_doc extrinsicFix
|
||||
|
||||
/-- Every CCPO has at least one element (the bottom element). -/
|
||||
noncomputable local instance CCPO.instNonempty [CCPO α] : Nonempty α := ⟨bot⟩
|
||||
|
||||
/--
|
||||
The fixpoint equation for `extrinsicFix`: given a proof that the fixpoint exists, unfold
|
||||
`extrinsicFix`.
|
||||
-/
|
||||
theorem extrinsicFix_eq {f : α → α}
|
||||
(x : α) (h : x = f x) :
|
||||
letI : Nonempty α := ⟨x⟩; extrinsicFix f = f (extrinsicFix f) := by
|
||||
letI : Nonempty α := ⟨x⟩
|
||||
have h : ∃ x, x = f x := ⟨x, h⟩
|
||||
simp only [extrinsicFix, dif_pos h]
|
||||
exact h.choose_spec
|
||||
|
||||
/--
|
||||
The fixpoint equation for `extrinsicFix`: given a CCPO instance and monotonicity of `f`,
|
||||
{lean}`extrinsicFix f = f (extrinsicFix f)`.
|
||||
-/
|
||||
theorem extrinsicFix_eq_mono [inst : CCPO α] {f : α → α}
|
||||
(hf : monotone f) :
|
||||
extrinsicFix f = f (extrinsicFix f) :=
|
||||
extrinsicFix_eq (fix f hf) (fix_eq hf)
|
||||
|
||||
abbrev discardR {C : α → Sort _} {R : α → α → Prop}
|
||||
(f : ∀ a, (∀ a', R a' a → C a') → C a) : ((∀ a, C a) → (∀ a, C a)) :=
|
||||
fun rec a => f a (fun a _ => rec a)
|
||||
|
||||
theorem extrinsicFix_eq_wf {C : α → Sort _} [∀ a, Nonempty (C a)] {R : α → α → Prop}
|
||||
{f : ∀ a, (∀ a', R a' a → C a') → C a} (h : WellFounded R) {a : α} :
|
||||
extrinsicFix (discardR f) a = f a (fun a _ => extrinsicFix (discardR f) a) := by
|
||||
suffices extrinsicFix (discardR f) = fun a => f a (fun a _ => extrinsicFix (discardR f) a) by
|
||||
conv => lhs; rw [this]
|
||||
apply extrinsicFix_eq (fun a => WellFounded.fix h f a) (funext fun a => (WellFounded.fix_eq h f a))
|
||||
|
||||
end Lean.Order
|
||||
140
src/Init/Internal/Order/MonadTail.lean
Normal file
140
src/Init/Internal/Order/MonadTail.lean
Normal file
@@ -0,0 +1,140 @@
|
||||
/-
|
||||
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Sebastian Graf
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Init.Internal.Order.Basic
|
||||
import all Init.System.ST -- for EST.bind in MonadTail instance
|
||||
|
||||
set_option linter.missingDocs true
|
||||
|
||||
public section
|
||||
|
||||
namespace Lean.Order
|
||||
|
||||
/--
|
||||
A *tail monad* is a monad whose bind operation preserves any suitable ordering of the continuation.
|
||||
Specifically, `MonadTail m` asserts that every `m β` carries a chain-complete partial order (CCPO)
|
||||
and that `>>=` is monotone in its second (continuation) argument with respect to that order.
|
||||
|
||||
This is a weaker requirement than `MonoBind`, which requires monotonicity in both arguments.
|
||||
`MonadTail` is sufficient for `partial_fixpoint`-based recursive definitions where the
|
||||
recursive call only appears in the continuation (second argument) of `>>=`.
|
||||
-/
|
||||
class MonadTail (m : Type u → Type v) [Bind m] where
|
||||
/-- Every `m β` with `Nonempty β` has a chain-complete partial order. -/
|
||||
instCCPO β [Nonempty β] : CCPO (m β)
|
||||
/-- Bind is monotone in the second (continuation) argument. -/
|
||||
bind_mono_right {a : m α} {f₁ f₂ : α → m β} [Nonempty β] (h : ∀ x, f₁ x ⊑ f₂ x) :
|
||||
a >>= f₁ ⊑ a >>= f₂
|
||||
|
||||
attribute [implicit_reducible] MonadTail.instCCPO
|
||||
attribute [instance] MonadTail.instCCPO
|
||||
|
||||
@[scoped partial_fixpoint_monotone]
|
||||
theorem MonadTail.monotone_bind_right
|
||||
(m : Type u → Type v) [Monad m] [MonadTail m]
|
||||
{α β : Type u} [Nonempty β]
|
||||
{γ : Sort w} [PartialOrder γ]
|
||||
(f : m α) (g : γ → α → m β)
|
||||
(hmono : monotone g) :
|
||||
monotone (fun (x : γ) => f >>= g x) :=
|
||||
fun _ _ h => MonadTail.bind_mono_right (hmono _ _ h)
|
||||
|
||||
instance : MonadTail Id where
|
||||
instCCPO _ := inferInstanceAs (CCPO (FlatOrder (b := Classical.ofNonempty)))
|
||||
bind_mono_right h := h _
|
||||
|
||||
instance {σ : Type u} {m : Type u → Type v} [Monad m] [MonadTail m] [Nonempty σ] :
|
||||
MonadTail (StateT σ m) where
|
||||
instCCPO α := inferInstanceAs (CCPO (σ → m (α × σ)))
|
||||
bind_mono_right h := by
|
||||
intro s
|
||||
show StateT.bind _ _ s ⊑ StateT.bind _ _ s
|
||||
simp only [StateT.bind]
|
||||
apply MonadTail.bind_mono_right (m := m)
|
||||
intro ⟨x, s'⟩
|
||||
exact h x s'
|
||||
|
||||
instance {ε : Type u} {m : Type u → Type v} [Monad m] [MonadTail m] :
|
||||
MonadTail (ExceptT ε m) where
|
||||
instCCPO β := MonadTail.instCCPO (Except ε β)
|
||||
bind_mono_right h := by
|
||||
apply MonadTail.bind_mono_right (m := m)
|
||||
intro x
|
||||
cases x with
|
||||
| error => exact PartialOrder.rel_refl
|
||||
| ok a => exact h a
|
||||
|
||||
instance : MonadTail (Except ε) where
|
||||
instCCPO β := inferInstanceAs (CCPO (FlatOrder (b := Classical.ofNonempty)))
|
||||
bind_mono_right h := by
|
||||
cases ‹Except _ _› with
|
||||
| error => exact FlatOrder.rel.refl
|
||||
| ok a => exact h a
|
||||
|
||||
instance {m : Type u → Type v} [Monad m] [MonadTail m] :
|
||||
MonadTail (OptionT m) where
|
||||
instCCPO β := MonadTail.instCCPO (Option β)
|
||||
bind_mono_right h := by
|
||||
apply MonadTail.bind_mono_right (m := m)
|
||||
intro x
|
||||
cases x with
|
||||
| none => exact PartialOrder.rel_refl
|
||||
| some a => exact h a
|
||||
|
||||
instance : MonadTail Option where
|
||||
instCCPO _ := inferInstance
|
||||
bind_mono_right h := MonoBind.bind_mono_right h
|
||||
|
||||
instance {ρ : Type u} {m : Type u → Type v} [Monad m] [MonadTail m] :
|
||||
MonadTail (ReaderT ρ m) where
|
||||
instCCPO α := inferInstanceAs (CCPO (ρ → m α))
|
||||
bind_mono_right h := by
|
||||
intro r
|
||||
show ReaderT.bind _ _ r ⊑ ReaderT.bind _ _ r
|
||||
simp only [ReaderT.bind]
|
||||
apply MonadTail.bind_mono_right (m := m)
|
||||
intro x
|
||||
exact h x r
|
||||
|
||||
set_option linter.missingDocs false in
|
||||
noncomputable def ST.bot' [Nonempty α] (s : Void σ) : @FlatOrder (ST.Out σ α) (.mk Classical.ofNonempty (Classical.choice ⟨s⟩)) :=
|
||||
.mk Classical.ofNonempty (Classical.choice ⟨s⟩)
|
||||
|
||||
instance [Nonempty α] : CCPO (ST σ α) where
|
||||
rel := PartialOrder.rel (α := ∀ s, FlatOrder (ST.bot' s))
|
||||
rel_refl := PartialOrder.rel_refl
|
||||
rel_antisymm := PartialOrder.rel_antisymm
|
||||
rel_trans := PartialOrder.rel_trans
|
||||
has_csup hchain := CCPO.has_csup (α := ∀ s, FlatOrder (ST.bot' s)) hchain
|
||||
|
||||
instance : MonadTail (ST σ) where
|
||||
instCCPO _ := inferInstance
|
||||
bind_mono_right {_ _ a f₁ f₂} _ h := by
|
||||
intro w
|
||||
change FlatOrder.rel (ST.bind a f₁ w) (ST.bind a f₂ w)
|
||||
simp only [ST.bind]
|
||||
apply h
|
||||
|
||||
instance : MonadTail BaseIO :=
|
||||
inferInstanceAs (MonadTail (ST IO.RealWorld))
|
||||
|
||||
instance [Nonempty ε] : MonadTail (EST ε σ) where
|
||||
instCCPO _ := inferInstance
|
||||
bind_mono_right h := MonoBind.bind_mono_right h
|
||||
|
||||
instance [Nonempty ε] : MonadTail (EIO ε) :=
|
||||
inferInstanceAs (MonadTail (EST ε IO.RealWorld))
|
||||
|
||||
instance : MonadTail IO :=
|
||||
inferInstanceAs (MonadTail (EIO IO.Error))
|
||||
|
||||
instance {ω : Type} {σ : Type} {m : Type → Type} [Monad m] [MonadTail m] :
|
||||
MonadTail (StateRefT' ω σ m) :=
|
||||
inferInstanceAs (MonadTail (ReaderT (ST.Ref ω σ) m))
|
||||
|
||||
end Lean.Order
|
||||
@@ -7,31 +7,70 @@ module
|
||||
|
||||
prelude
|
||||
public import Init.Core
|
||||
public import Init.Internal.Order.ExtrinsicFix
|
||||
|
||||
public section
|
||||
|
||||
/-!
|
||||
# Notation for `while` and `repeat` loops.
|
||||
-/
|
||||
|
||||
namespace Lean
|
||||
|
||||
/-!
|
||||
# `Loop` type backing `repeat`/`while`/`repeat ... until`
|
||||
|
||||
The parsers and elaborators for `repeat`, `while`, and `repeat ... until` live in
|
||||
`Lean.Parser.Do` and `Lean.Elab.BuiltinDo.Repeat`. This module only provides the
|
||||
`Loop` type (and `ForIn` instance) that those elaborators expand to.
|
||||
-/
|
||||
/-! # `repeat` and `while` notation -/
|
||||
|
||||
inductive Loop where
|
||||
| mk
|
||||
|
||||
open Lean.Order in
|
||||
@[inline]
|
||||
partial def Loop.forIn {β : Type u} {m : Type u → Type v} [Monad m] (_ : Loop) (init : β) (f : Unit → β → m (ForInStep β)) : m β :=
|
||||
let rec @[specialize] loop (b : β) : m β := do
|
||||
def Loop.forIn {β : Type u} {m : Type u → Type v} [Monad m]
|
||||
(_ : Loop) (init : β) (f : Unit → β → m (ForInStep β)) : m β :=
|
||||
haveI : Nonempty (β → m β) := ⟨fun b => pure b⟩
|
||||
Lean.Order.extrinsicFix (fun (cont : β → m β) (b : β) => do
|
||||
match ← f () b with
|
||||
| ForInStep.done b => pure b
|
||||
| ForInStep.yield b => loop b
|
||||
loop init
|
||||
| .done val => pure val
|
||||
| .yield val => cont val) init
|
||||
|
||||
instance [Monad m] : ForIn m Loop Unit where
|
||||
forIn := Loop.forIn
|
||||
|
||||
open Lean.Order in
|
||||
theorem Loop.forIn_eq [Monad m] [MonadTail m]
|
||||
{l : Loop} {b : β} {f : Unit → β → m (ForInStep β)} :
|
||||
Loop.forIn l b f = (do
|
||||
match ← f () b with
|
||||
| .done val => pure val
|
||||
| .yield val => Loop.forIn l val f) := by
|
||||
haveI : Nonempty β := ⟨b⟩
|
||||
simp only [Loop.forIn]
|
||||
apply congrFun
|
||||
apply extrinsicFix_eq
|
||||
intro cont₁ cont₂ h b'
|
||||
apply MonadTail.bind_mono_right
|
||||
intro r
|
||||
cases r with
|
||||
| done => exact PartialOrder.rel_refl
|
||||
| yield val => exact h val
|
||||
|
||||
syntax (name := doRepeat) "repeat " doSeq : doElem
|
||||
|
||||
macro_rules
|
||||
| `(doElem| repeat $seq) => `(doElem| for _ in Loop.mk do $seq)
|
||||
|
||||
syntax "while " ident " : " termBeforeDo " do " doSeq : doElem
|
||||
|
||||
macro_rules
|
||||
| `(doElem| while $h : $cond do $seq) => `(doElem| repeat if $h:ident : $cond then $seq else break)
|
||||
|
||||
syntax "while " termBeforeDo " do " doSeq : doElem
|
||||
|
||||
macro_rules
|
||||
| `(doElem| while $cond do $seq) => `(doElem| repeat if $cond then $seq else break)
|
||||
|
||||
syntax "repeat " doSeq ppDedent(ppLine) "until " term : doElem
|
||||
|
||||
macro_rules
|
||||
| `(doElem| repeat $seq until $cond) => `(doElem| repeat do $seq:doSeq; if $cond then break)
|
||||
|
||||
end Lean
|
||||
|
||||
@@ -207,7 +207,7 @@ def emitLns [EmitToString α] (as : List α) : EmitM Unit :=
|
||||
emitLn "}"
|
||||
return ret
|
||||
|
||||
def toDigit (c : Nat) : String :=
|
||||
def toHexDigit (c : Nat) : String :=
|
||||
String.singleton c.digitChar
|
||||
|
||||
def quoteString (s : String) : String :=
|
||||
@@ -221,11 +221,7 @@ def quoteString (s : String) : String :=
|
||||
else if c == '\"' then "\\\""
|
||||
else if c == '?' then "\\?" -- avoid trigraphs
|
||||
else if c.toNat <= 31 then
|
||||
-- Use octal escapes instead of hex escapes because C hex escapes are
|
||||
-- greedy: "\x01abc" would be parsed as the single escape \x01abc rather
|
||||
-- than \x01 followed by "abc". Octal escapes consume at most 3 digits.
|
||||
let n := c.toNat
|
||||
"\\" ++ toDigit (n / 64) ++ toDigit ((n / 8) % 8) ++ toDigit (n % 8)
|
||||
"\\x" ++ toHexDigit (c.toNat / 16) ++ toHexDigit (c.toNat % 16)
|
||||
-- TODO(Leo): we should use `\unnnn` for escaping unicode characters.
|
||||
else String.singleton c)
|
||||
q;
|
||||
|
||||
@@ -90,22 +90,6 @@ partial def eraseProjIncFor (nFields : Nat) (targetId : FVarId) (ds : Array (Cod
|
||||
| break
|
||||
if !(w == z && targetId == x) then
|
||||
break
|
||||
if mask[i]!.isSome then
|
||||
/-
|
||||
Suppose we encounter a situation like
|
||||
```
|
||||
let x.1 := proj[0] y
|
||||
inc x.1
|
||||
let x.2 := proj[0] y
|
||||
inc x.2
|
||||
```
|
||||
The `inc x.2` will already have been removed. If we don't perform this check we will also
|
||||
remove `inc x.1` and have effectively removed two refcounts while only one was legal.
|
||||
-/
|
||||
keep := keep.push d
|
||||
keep := keep.push d'
|
||||
ds := ds.pop.pop
|
||||
continue
|
||||
/-
|
||||
Found
|
||||
```
|
||||
|
||||
@@ -64,12 +64,6 @@ structure WorkspaceClientCapabilities where
|
||||
deriving ToJson, FromJson
|
||||
|
||||
structure LeanClientCapabilities where
|
||||
/--
|
||||
Whether the client supports incremental `textDocument/publishDiagnostics` updates.
|
||||
If `none` or `false`, the server will never set `PublishDiagnosticsParams.isIncremental?`
|
||||
and always report full diagnostic updates that replace the previous one.
|
||||
-/
|
||||
incrementalDiagnosticSupport? : Option Bool := none
|
||||
/--
|
||||
Whether the client supports `DiagnosticWith.isSilent = true`.
|
||||
If `none` or `false`, silent diagnostics will not be served to the client.
|
||||
@@ -90,13 +84,6 @@ structure ClientCapabilities where
|
||||
lean? : Option LeanClientCapabilities := none
|
||||
deriving ToJson, FromJson
|
||||
|
||||
def ClientCapabilities.incrementalDiagnosticSupport (c : ClientCapabilities) : Bool := Id.run do
|
||||
let some lean := c.lean?
|
||||
| return false
|
||||
let some incrementalDiagnosticSupport := lean.incrementalDiagnosticSupport?
|
||||
| return false
|
||||
return incrementalDiagnosticSupport
|
||||
|
||||
def ClientCapabilities.silentDiagnosticSupport (c : ClientCapabilities) : Bool := Id.run do
|
||||
let some lean := c.lean?
|
||||
| return false
|
||||
|
||||
@@ -159,14 +159,6 @@ abbrev Diagnostic := DiagnosticWith String
|
||||
structure PublishDiagnosticsParams where
|
||||
uri : DocumentUri
|
||||
version? : Option Int := none
|
||||
/--
|
||||
Whether the client should append this set of diagnostics to the previous set
|
||||
rather than replacing the previous set by this one (the default LSP behavior).
|
||||
`false` means the client should replace.
|
||||
`none` is equivalent to `false`.
|
||||
This is a Lean-specific extension (see `LeanClientCapabilities`).
|
||||
-/
|
||||
isIncremental? : Option Bool := none
|
||||
diagnostics : Array Diagnostic
|
||||
deriving Inhabited, BEq, ToJson, FromJson
|
||||
|
||||
|
||||
@@ -102,32 +102,9 @@ def normalizePublishDiagnosticsParams (p : PublishDiagnosticsParams) :
|
||||
sorted.toArray
|
||||
}
|
||||
|
||||
/--
|
||||
Merges a new `textDocument/publishDiagnostics` notification into a previously accumulated one.
|
||||
|
||||
- If there is no previous notification, the new one is used as-is.
|
||||
- If `isIncremental?` is `true`, the new diagnostics are appended.
|
||||
- Otherwise the new notification replaces the previous one.
|
||||
|
||||
The returned params always have `isIncremental? := some false` since they represent the full
|
||||
accumulated set.
|
||||
-/
|
||||
def mergePublishDiagnosticsParams (prev? : Option PublishDiagnosticsParams)
|
||||
(next : PublishDiagnosticsParams) : PublishDiagnosticsParams := Id.run do
|
||||
let replace := { next with isIncremental? := some false }
|
||||
let some prev := prev?
|
||||
| return replace
|
||||
if next.isIncremental?.getD false then
|
||||
return { next with
|
||||
diagnostics := prev.diagnostics ++ next.diagnostics
|
||||
isIncremental? := some false }
|
||||
return replace
|
||||
|
||||
/--
|
||||
Waits for the worker to emit all diagnostic notifications for the current document version and
|
||||
returns the accumulated diagnostics, if any.
|
||||
|
||||
Incoming notifications are merged using `mergePublishDiagnosticsParams`.
|
||||
returns the last notification, if any.
|
||||
|
||||
We used to return all notifications but with debouncing in the server, this would not be
|
||||
deterministic anymore as what messages are dropped depends on wall-clock timing.
|
||||
@@ -135,25 +112,22 @@ deterministic anymore as what messages are dropped depends on wall-clock timing.
|
||||
partial def collectDiagnostics (waitForDiagnosticsId : RequestID := 0) (target : DocumentUri) (version : Nat)
|
||||
: IpcM (Option (Notification PublishDiagnosticsParams)) := do
|
||||
writeRequest ⟨waitForDiagnosticsId, "textDocument/waitForDiagnostics", WaitForDiagnosticsParams.mk target version⟩
|
||||
loop none
|
||||
loop
|
||||
where
|
||||
loop (accumulated? : Option PublishDiagnosticsParams) := do
|
||||
loop := do
|
||||
match (←readMessage) with
|
||||
| Message.response id _ =>
|
||||
if id == waitForDiagnosticsId then
|
||||
return accumulated?.map fun p =>
|
||||
⟨"textDocument/publishDiagnostics", normalizePublishDiagnosticsParams p⟩
|
||||
else loop accumulated?
|
||||
| Message.responseError id _ msg _ =>
|
||||
if id == waitForDiagnosticsId then return none
|
||||
else loop
|
||||
| Message.responseError id _ msg _ =>
|
||||
if id == waitForDiagnosticsId then
|
||||
throw $ userError s!"Waiting for diagnostics failed: {msg}"
|
||||
else loop accumulated?
|
||||
else loop
|
||||
| Message.notification "textDocument/publishDiagnostics" (some param) =>
|
||||
match fromJson? (toJson param) with
|
||||
| Except.ok (diagnosticParam : PublishDiagnosticsParams) =>
|
||||
loop (some (mergePublishDiagnosticsParams accumulated? diagnosticParam))
|
||||
| Except.ok diagnosticParam => return (← loop).getD ⟨"textDocument/publishDiagnostics", normalizePublishDiagnosticsParams diagnosticParam⟩
|
||||
| Except.error inner => throw $ userError s!"Cannot decode publishDiagnostics parameters\n{inner}"
|
||||
| _ => loop accumulated?
|
||||
| _ => loop
|
||||
|
||||
partial def waitForILeans (waitForILeansId : RequestID := 0) (target : DocumentUri) (version : Nat) : IpcM Unit := do
|
||||
writeRequest ⟨waitForILeansId, "$/lean/waitForILeans", WaitForILeansParams.mk target version⟩
|
||||
|
||||
@@ -289,11 +289,9 @@ instance : ToMarkdown VersoModuleDocs.Snippet where
|
||||
|
||||
structure VersoModuleDocs where
|
||||
snippets : PersistentArray VersoModuleDocs.Snippet := {}
|
||||
terminalNesting : Option Nat := snippets.findSomeRev? (·.terminalNesting)
|
||||
deriving Inhabited
|
||||
|
||||
def VersoModuleDocs.terminalNesting : VersoModuleDocs → Option Nat
|
||||
| VersoModuleDocs.mk snippets => snippets.findSomeRev? (·.terminalNesting)
|
||||
|
||||
instance : Repr VersoModuleDocs where
|
||||
reprPrec v _ :=
|
||||
.group <| .nest 2 <|
|
||||
@@ -316,7 +314,10 @@ def add (docs : VersoModuleDocs) (snippet : Snippet) : Except String VersoModule
|
||||
unless docs.canAdd snippet do
|
||||
throw "Can't nest this snippet here"
|
||||
|
||||
return { docs with snippets := docs.snippets.push snippet }
|
||||
return { docs with
|
||||
snippets := docs.snippets.push snippet,
|
||||
terminalNesting := snippet.terminalNesting
|
||||
}
|
||||
|
||||
def add! (docs : VersoModuleDocs) (snippet : Snippet) : VersoModuleDocs :=
|
||||
let ok :=
|
||||
@@ -326,7 +327,10 @@ def add! (docs : VersoModuleDocs) (snippet : Snippet) : VersoModuleDocs :=
|
||||
if not ok then
|
||||
panic! "Can't nest this snippet here"
|
||||
else
|
||||
{ docs with snippets := docs.snippets.push snippet }
|
||||
{ docs with
|
||||
snippets := docs.snippets.push snippet,
|
||||
terminalNesting := snippet.terminalNesting
|
||||
}
|
||||
|
||||
|
||||
private structure DocFrame where
|
||||
|
||||
@@ -15,4 +15,3 @@ public import Lean.Elab.BuiltinDo.Jump
|
||||
public import Lean.Elab.BuiltinDo.Misc
|
||||
public import Lean.Elab.BuiltinDo.For
|
||||
public import Lean.Elab.BuiltinDo.TryCatch
|
||||
public import Lean.Elab.BuiltinDo.Repeat
|
||||
|
||||
@@ -21,8 +21,7 @@ def elabDoIdDecl (x : Ident) (xType? : Option Term) (rhs : TSyntax `doElem) (k :
|
||||
let xType ← Term.elabType (xType?.getD (mkHole x))
|
||||
let lctx ← getLCtx
|
||||
let ctx ← read
|
||||
let ref ← getRef -- store the surrounding reference for error messages in `k`
|
||||
elabDoElem rhs <| .mk (kind := kind) x.getId xType do withRef ref do
|
||||
elabDoElem rhs <| .mk (kind := kind) x.getId xType do
|
||||
withLCtxKeepingMutVarDefs lctx ctx x.getId do
|
||||
Term.addLocalVarInfo x (← getFVarFromUserName x.getId)
|
||||
k
|
||||
|
||||
@@ -23,7 +23,7 @@ open Lean.Meta
|
||||
| `(doFor| for $[$_ : ]? $_:ident in $_ do $_) =>
|
||||
-- This is the target form of the expander, handled by `elabDoFor` below.
|
||||
Macro.throwUnsupported
|
||||
| `(doFor| for%$tk $decls:doForDecl,* do $body) =>
|
||||
| `(doFor| for $decls:doForDecl,* do $body) =>
|
||||
let decls := decls.getElems
|
||||
let `(doForDecl| $[$h? : ]? $pattern in $xs) := decls[0]! | Macro.throwUnsupported
|
||||
let mut doElems := #[]
|
||||
@@ -74,13 +74,12 @@ open Lean.Meta
|
||||
| some ($y, s') =>
|
||||
$s:ident := s'
|
||||
do $body)
|
||||
doElems := doElems.push (← `(doSeqItem| for%$tk $[$h? : ]? $x:ident in $xs do $body))
|
||||
doElems := doElems.push (← `(doSeqItem| for $[$h? : ]? $x:ident in $xs do $body))
|
||||
`(doElem| do $doElems*)
|
||||
| _ => Macro.throwUnsupported
|
||||
|
||||
@[builtin_doElem_elab Lean.Parser.Term.doFor] def elabDoFor : DoElab := fun stx dec => do
|
||||
let `(doFor| for%$tk $[$h? : ]? $x:ident in $xs do $body) := stx | throwUnsupportedSyntax
|
||||
let dec ← dec.ensureUnitAt tk
|
||||
let `(doFor| for $[$h? : ]? $x:ident in $xs do $body) := stx | throwUnsupportedSyntax
|
||||
checkMutVarsForShadowing #[x]
|
||||
let uα ← mkFreshLevelMVar
|
||||
let uρ ← mkFreshLevelMVar
|
||||
@@ -125,6 +124,9 @@ open Lean.Meta
|
||||
defs := defs.push (mkConst ``Unit.unit)
|
||||
return defs
|
||||
|
||||
unless ← isDefEq dec.resultType (← mkPUnit) do
|
||||
logError m!"Type mismatch. `for` loops have result type {← mkPUnit}, but the rest of the `do` sequence expected {dec.resultType}."
|
||||
|
||||
let (preS, σ) ← mkProdMkN (← useLoopMutVars none) mi.u
|
||||
|
||||
let (app, p?) ← match h? with
|
||||
|
||||
@@ -17,7 +17,6 @@ namespace Lean.Elab.Do
|
||||
open Lean.Parser.Term
|
||||
open Lean.Meta
|
||||
|
||||
open InternalSyntax in
|
||||
/--
|
||||
If the given syntax is a `doIf`, return an equivalent `doIf` that has an `else` but no `else if`s or
|
||||
`if let`s.
|
||||
@@ -26,8 +25,8 @@ If the given syntax is a `doIf`, return an equivalent `doIf` that has an `else`
|
||||
match stx with
|
||||
| `(doElem|if $_:doIfProp then $_ else $_) =>
|
||||
Macro.throwUnsupported
|
||||
| `(doElem|if%$tk $cond:doIfCond then $t $[else if%$tks $conds:doIfCond then $ts]* $[else $e?]?) => do
|
||||
let mut e : Syntax ← e?.getDM `(doSeq| skip%$tk)
|
||||
| `(doElem|if $cond:doIfCond then $t $[else if $conds:doIfCond then $ts]* $[else $e?]?) => do
|
||||
let mut e : Syntax ← e?.getDM `(doSeq|pure PUnit.unit)
|
||||
let mut eIsSeq := true
|
||||
for (cond, t) in Array.zip (conds.reverse.push cond) (ts.reverse.push t) do
|
||||
e ← if eIsSeq then pure e else `(doSeq|$(⟨e⟩):doElem)
|
||||
|
||||
@@ -88,18 +88,17 @@ private def checkLetConfigInDo (config : Term.LetConfig) : DoElabM Unit := do
|
||||
throwError "`+generalize` is not supported in `do` blocks"
|
||||
|
||||
partial def elabDoLetOrReassign (config : Term.LetConfig) (letOrReassign : LetOrReassign) (decl : TSyntax ``letDecl)
|
||||
(tk : Syntax) (dec : DoElemCont) : DoElabM Expr := do
|
||||
(dec : DoElemCont) : DoElabM Expr := do
|
||||
checkLetConfigInDo config
|
||||
let vars ← getLetDeclVars decl
|
||||
letOrReassign.checkMutVars vars
|
||||
let dec ← dec.ensureUnitAt tk
|
||||
-- Some decl preprocessing on the patterns and expected types:
|
||||
let decl ← pushTypeIntoReassignment letOrReassign decl
|
||||
let mγ ← mkMonadicType (← read).doBlockResultType
|
||||
match decl with
|
||||
| `(letDecl| $decl:letEqnsDecl) =>
|
||||
let declNew ← `(letDecl| $(⟨← liftMacroM <| Term.expandLetEqnsDecl decl⟩):letIdDecl)
|
||||
return ← Term.withMacroExpansion decl declNew <| elabDoLetOrReassign config letOrReassign declNew tk dec
|
||||
return ← Term.withMacroExpansion decl declNew <| elabDoLetOrReassign config letOrReassign declNew dec
|
||||
| `(letDecl| $pattern:term $[: $xType?]? := $rhs) =>
|
||||
let rhs ← match xType? with | some xType => `(($rhs : $xType)) | none => pure rhs
|
||||
let contElab : DoElabM Expr := elabWithReassignments letOrReassign vars dec.continueWithUnit
|
||||
@@ -163,11 +162,10 @@ partial def elabDoLetOrReassign (config : Term.LetConfig) (letOrReassign : LetOr
|
||||
mkLetFVars #[x, h'] body (usedLetOnly := config.usedOnly) (generalizeNondepLet := false)
|
||||
| _ => throwUnsupportedSyntax
|
||||
|
||||
def elabDoArrow (letOrReassign : LetOrReassign) (stx : TSyntax [``doIdDecl, ``doPatDecl]) (tk : Syntax) (dec : DoElemCont) : DoElabM Expr := do
|
||||
def elabDoArrow (letOrReassign : LetOrReassign) (stx : TSyntax [``doIdDecl, ``doPatDecl]) (dec : DoElemCont) : DoElabM Expr := do
|
||||
match stx with
|
||||
| `(doIdDecl| $x:ident $[: $xType?]? ← $rhs) =>
|
||||
letOrReassign.checkMutVars #[x]
|
||||
let dec ← dec.ensureUnitAt tk
|
||||
-- For plain variable reassignment, we know the expected type of the reassigned variable and
|
||||
-- propagate it eagerly via type ascription if the user hasn't provided one themselves:
|
||||
let xType? ← match letOrReassign, xType? with
|
||||
@@ -179,7 +177,6 @@ def elabDoArrow (letOrReassign : LetOrReassign) (stx : TSyntax [``doIdDecl, ``do
|
||||
(kind := dec.kind)
|
||||
| `(doPatDecl| _%$pattern $[: $patType?]? ← $rhs) =>
|
||||
let x := mkIdentFrom pattern (← mkFreshUserName `__x)
|
||||
let dec ← dec.ensureUnitAt tk
|
||||
elabDoIdDecl x patType? rhs dec.continueWithUnit (kind := dec.kind)
|
||||
| `(doPatDecl| $pattern:term $[: $patType?]? ← $rhs $[| $otherwise? $(rest?)?]?) =>
|
||||
let rest? := rest?.join
|
||||
@@ -208,18 +205,17 @@ private def getLetConfigAndCheckMut (letConfigStx : TSyntax ``Parser.Term.letCon
|
||||
Term.mkLetConfig letConfigStx initConfig
|
||||
|
||||
@[builtin_doElem_elab Lean.Parser.Term.doLet] def elabDoLet : DoElab := fun stx dec => do
|
||||
let `(doLet| let%$tk $[mut%$mutTk?]? $config:letConfig $decl:letDecl) := stx | throwUnsupportedSyntax
|
||||
let `(doLet| let $[mut%$mutTk?]? $config:letConfig $decl:letDecl) := stx | throwUnsupportedSyntax
|
||||
let config ← getLetConfigAndCheckMut config mutTk?
|
||||
elabDoLetOrReassign config (.let mutTk?) decl tk dec
|
||||
elabDoLetOrReassign config (.let mutTk?) decl dec
|
||||
|
||||
@[builtin_doElem_elab Lean.Parser.Term.doHave] def elabDoHave : DoElab := fun stx dec => do
|
||||
let `(doHave| have%$tk $config:letConfig $decl:letDecl) := stx | throwUnsupportedSyntax
|
||||
let `(doHave| have $config:letConfig $decl:letDecl) := stx | throwUnsupportedSyntax
|
||||
let config ← Term.mkLetConfig config { nondep := true }
|
||||
elabDoLetOrReassign config .have decl tk dec
|
||||
elabDoLetOrReassign config .have decl dec
|
||||
|
||||
@[builtin_doElem_elab Lean.Parser.Term.doLetRec] def elabDoLetRec : DoElab := fun stx dec => do
|
||||
let `(doLetRec| let%$tk rec $decls:letRecDecls) := stx | throwUnsupportedSyntax
|
||||
let dec ← dec.ensureUnitAt tk
|
||||
let `(doLetRec| let rec $decls:letRecDecls) := stx | throwUnsupportedSyntax
|
||||
let vars ← getLetRecDeclsVars decls
|
||||
let mγ ← mkMonadicType (← read).doBlockResultType
|
||||
doElabToSyntax m!"let rec body of group {vars}" dec.continueWithUnit fun body => do
|
||||
@@ -231,13 +227,13 @@ private def getLetConfigAndCheckMut (letConfigStx : TSyntax ``Parser.Term.letCon
|
||||
@[builtin_doElem_elab Lean.Parser.Term.doReassign] def elabDoReassign : DoElab := fun stx dec => do
|
||||
-- def doReassign := letIdDeclNoBinders <|> letPatDecl
|
||||
match stx with
|
||||
| `(doReassign| $x:ident $[: $xType?]? :=%$tk $rhs) =>
|
||||
| `(doReassign| $x:ident $[: $xType?]? := $rhs) =>
|
||||
let decl : TSyntax ``letIdDecl ← `(letIdDecl| $x:ident $[: $xType?]? := $rhs)
|
||||
let decl : TSyntax ``letDecl := ⟨mkNode ``letDecl #[decl]⟩
|
||||
elabDoLetOrReassign {} .reassign decl tk dec
|
||||
elabDoLetOrReassign {} .reassign decl dec
|
||||
| `(doReassign| $decl:letPatDecl) =>
|
||||
let decl : TSyntax ``letDecl := ⟨mkNode ``letDecl #[decl]⟩
|
||||
elabDoLetOrReassign {} .reassign decl decl dec
|
||||
elabDoLetOrReassign {} .reassign decl dec
|
||||
| _ => throwUnsupportedSyntax
|
||||
|
||||
@[builtin_doElem_elab Lean.Parser.Term.doLetElse] def elabDoLetElse : DoElab := fun stx dec => do
|
||||
@@ -259,17 +255,17 @@ private def getLetConfigAndCheckMut (letConfigStx : TSyntax ``Parser.Term.letCon
|
||||
elabDoElem (← `(doElem| match $rhs:term with | $pattern => $body:doSeqIndent | _ => $otherwise:doSeqIndent)) dec
|
||||
|
||||
@[builtin_doElem_elab Lean.Parser.Term.doLetArrow] def elabDoLetArrow : DoElab := fun stx dec => do
|
||||
let `(doLetArrow| let%$tk $[mut%$mutTk?]? $cfg:letConfig $decl) := stx | throwUnsupportedSyntax
|
||||
let `(doLetArrow| let $[mut%$mutTk?]? $cfg:letConfig $decl) := stx | throwUnsupportedSyntax
|
||||
let config ← getLetConfigAndCheckMut cfg mutTk?
|
||||
checkLetConfigInDo config
|
||||
if config.nondep || config.usedOnly || config.zeta || config.eq?.isSome then
|
||||
throwErrorAt cfg "configuration options are not supported with `←`"
|
||||
elabDoArrow (.let mutTk?) decl tk dec
|
||||
elabDoArrow (.let mutTk?) decl dec
|
||||
|
||||
@[builtin_doElem_elab Lean.Parser.Term.doReassignArrow] def elabDoReassignArrow : DoElab := fun stx dec => do
|
||||
match stx with
|
||||
| `(doReassignArrow| $decl:doIdDecl) =>
|
||||
elabDoArrow .reassign decl decl dec
|
||||
elabDoArrow .reassign decl dec
|
||||
| `(doReassignArrow| $decl:doPatDecl) =>
|
||||
elabDoArrow .reassign decl decl dec
|
||||
elabDoArrow .reassign decl dec
|
||||
| _ => throwUnsupportedSyntax
|
||||
|
||||
@@ -16,12 +16,6 @@ namespace Lean.Elab.Do
|
||||
open Lean.Parser.Term
|
||||
open Lean.Meta
|
||||
|
||||
open InternalSyntax in
|
||||
@[builtin_doElem_elab Lean.Parser.Term.InternalSyntax.doSkip] def elabDoSkip : DoElab := fun stx dec => do
|
||||
let `(doSkip| skip%$tk) := stx | throwUnsupportedSyntax
|
||||
let dec ← dec.ensureUnitAt tk
|
||||
dec.continueWithUnit
|
||||
|
||||
@[builtin_doElem_elab Lean.Parser.Term.doExpr] def elabDoExpr : DoElab := fun stx dec => do
|
||||
let `(doExpr| $e:term) := stx | throwUnsupportedSyntax
|
||||
let mα ← mkMonadicType dec.resultType
|
||||
@@ -32,28 +26,24 @@ open InternalSyntax in
|
||||
let `(doNested| do $doSeq) := stx | throwUnsupportedSyntax
|
||||
elabDoSeq ⟨doSeq.raw⟩ dec
|
||||
|
||||
open InternalSyntax in
|
||||
@[builtin_doElem_elab Lean.Parser.Term.doUnless] def elabDoUnless : DoElab := fun stx dec => do
|
||||
let `(doUnless| unless%$tk $cond do $body) := stx | throwUnsupportedSyntax
|
||||
elabDoElem (← `(doElem| if $cond then skip%$tk else $body)) dec
|
||||
let `(doUnless| unless $cond do $body) := stx | throwUnsupportedSyntax
|
||||
elabDoElem (← `(doElem| if $cond then pure PUnit.unit else $body)) dec
|
||||
|
||||
@[builtin_doElem_elab Lean.Parser.Term.doDbgTrace] def elabDoDbgTrace : DoElab := fun stx dec => do
|
||||
let `(doDbgTrace| dbg_trace%$tk $msg:term) := stx | throwUnsupportedSyntax
|
||||
let `(doDbgTrace| dbg_trace $msg:term) := stx | throwUnsupportedSyntax
|
||||
let mγ ← mkMonadicType (← read).doBlockResultType
|
||||
let dec ← dec.ensureUnitAt tk
|
||||
doElabToSyntax "dbg_trace body" dec.continueWithUnit fun body => do
|
||||
Term.elabTerm (← `(dbg_trace $msg; $body)) mγ
|
||||
|
||||
@[builtin_doElem_elab Lean.Parser.Term.doAssert] def elabDoAssert : DoElab := fun stx dec => do
|
||||
let `(doAssert| assert!%$tk $cond) := stx | throwUnsupportedSyntax
|
||||
let `(doAssert| assert! $cond) := stx | throwUnsupportedSyntax
|
||||
let mγ ← mkMonadicType (← read).doBlockResultType
|
||||
let dec ← dec.ensureUnitAt tk
|
||||
doElabToSyntax "assert! body" dec.continueWithUnit fun body => do
|
||||
Term.elabTerm (← `(assert! $cond; $body)) mγ
|
||||
|
||||
@[builtin_doElem_elab Lean.Parser.Term.doDebugAssert] def elabDoDebugAssert : DoElab := fun stx dec => do
|
||||
let `(doDebugAssert| debug_assert!%$tk $cond) := stx | throwUnsupportedSyntax
|
||||
let `(doDebugAssert| debug_assert! $cond) := stx | throwUnsupportedSyntax
|
||||
let mγ ← mkMonadicType (← read).doBlockResultType
|
||||
let dec ← dec.ensureUnitAt tk
|
||||
doElabToSyntax "debug_assert! body" dec.continueWithUnit fun body => do
|
||||
Term.elabTerm (← `(debug_assert! $cond; $body)) mγ
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
/-
|
||||
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Sebastian Graf
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Lean.Elab.BuiltinDo.Basic
|
||||
meta import Lean.Parser.Do
|
||||
import Lean.Elab.BuiltinDo.For
|
||||
|
||||
public section
|
||||
|
||||
namespace Lean.Elab.Do
|
||||
|
||||
open Lean.Parser.Term
|
||||
|
||||
/--
|
||||
Builtin do-element elaborator for `repeat` (syntax kind `Lean.Parser.Term.doRepeat`).
|
||||
|
||||
Expands to `for _ in Loop.mk do ...`. A follow-up change will extend this
|
||||
elaborator to choose between `Loop.mk` and a well-founded `Repeat.mk` based on a
|
||||
configuration option.
|
||||
-/
|
||||
@[builtin_doElem_elab Lean.Parser.Term.doRepeat] def elabDoRepeat : DoElab := fun stx dec => do
|
||||
let `(doElem| repeat%$tk $seq) := stx | throwUnsupportedSyntax
|
||||
let expanded ← `(doElem| for%$tk _ in Loop.mk do $seq)
|
||||
Term.withMacroExpansion stx expanded <|
|
||||
withRef expanded <| elabDoElem ⟨expanded⟩ dec
|
||||
|
||||
@[builtin_macro Lean.Parser.Term.doWhileH] def expandDoWhileH : Macro
|
||||
| `(doElem| while%$tk $h : $cond do $seq) => `(doElem| repeat%$tk if $h:ident : $cond then $seq else break)
|
||||
| _ => Macro.throwUnsupported
|
||||
|
||||
@[builtin_macro Lean.Parser.Term.doWhile] def expandDoWhile : Macro
|
||||
| `(doElem| while%$tk $cond do $seq) => `(doElem| repeat%$tk if $cond then $seq else break)
|
||||
| _ => Macro.throwUnsupported
|
||||
|
||||
@[builtin_macro Lean.Parser.Term.doRepeatUntil] def expandDoRepeatUntil : Macro
|
||||
| `(doElem| repeat%$tk $seq until $cond) => `(doElem| repeat%$tk do $seq:doSeq; if $cond then break)
|
||||
| _ => Macro.throwUnsupported
|
||||
|
||||
end Lean.Elab.Do
|
||||
@@ -220,7 +220,7 @@ def processDefDeriving (view : DerivingClassView) (decl : Expr) (isNoncomputable
|
||||
instName ← liftMacroM <| mkUnusedBaseName instName
|
||||
if isPrivateName declName then
|
||||
instName := mkPrivateName env instName
|
||||
let isMeta := (← read).isMetaSection || isMarkedMeta (← getEnv) declName
|
||||
let isMeta := (← read).declName?.any (isMarkedMeta (← getEnv))
|
||||
let inst ← if backward.inferInstanceAs.wrap.get (← getOptions) then
|
||||
withDeclNameForAuxNaming instName <| withNewMCtxDepth <|
|
||||
wrapInstance result.instVal result.instType
|
||||
@@ -255,12 +255,11 @@ def processDefDeriving (view : DerivingClassView) (decl : Expr) (isNoncomputable
|
||||
logInfoAt cmdRef m!"Try this: {newText}"
|
||||
throwError "failed to derive instance because it depends on \
|
||||
`{.ofConstName noncompRef}`, which is noncomputable"
|
||||
let isMeta := (← read).isMetaSection || isMarkedMeta (← getEnv) declName
|
||||
if isNoncomputable || (← read).isNoncomputableSection then
|
||||
addDecl <| Declaration.defnDecl decl
|
||||
modifyEnv (addNoncomputable · instName)
|
||||
else
|
||||
addAndCompile (Declaration.defnDecl decl) (markMeta := isMeta)
|
||||
addAndCompile <| Declaration.defnDecl decl
|
||||
trace[Elab.Deriving] "Derived instance `{.ofConstName instName}`"
|
||||
-- For Prop-typed instances (theorems), skip `implicit_reducible` since reducibility hints are
|
||||
-- irrelevant for theorems. This matches the behavior of the handwritten `instance` command
|
||||
|
||||
@@ -374,60 +374,14 @@ def withLCtxKeepingMutVarDefs (oldLCtx : LocalContext) (oldCtx : Context) (resul
|
||||
mutVarDefs := oldMutVarDefs
|
||||
}) k
|
||||
|
||||
def mkMonadicResultTypeMismatchError (contType : Expr) (elementType : Expr) : MessageData :=
|
||||
m!"Type mismatch. The `do` element has monadic result type{indentExpr elementType}\n\
|
||||
but the rest of the `do` block has monadic result type{indentExpr contType}"
|
||||
|
||||
/--
|
||||
Given a continuation `dec`, a reference `ref`, and an element result type `elementType`, returns a
|
||||
continuation derived from `dec` with result type `elementType`.
|
||||
If `dec` already has result type `elementType`, simply returns `dec`.
|
||||
Otherwise, an error is logged and a new continuation is returned that calls `dec` with `sorry` as a
|
||||
result. The error is reported at `ref`.
|
||||
-/
|
||||
def DoElemCont.ensureHasTypeAt (dec : DoElemCont) (ref : Syntax) (elementType : Expr) : DoElabM DoElemCont := do
|
||||
if ← isDefEqGuarded dec.resultType elementType then
|
||||
return dec
|
||||
let errMessage := mkMonadicResultTypeMismatchError dec.resultType elementType
|
||||
unless (← readThe Term.Context).errToSorry do
|
||||
throwErrorAt ref errMessage
|
||||
logErrorAt ref errMessage
|
||||
return {
|
||||
resultName := ← mkFreshUserName `__r
|
||||
resultType := elementType
|
||||
k := do
|
||||
mapLetDecl dec.resultName dec.resultType (← mkSorry dec.resultType true)
|
||||
(nondep := true) (kind := .implDetail) fun _ => dec.k
|
||||
kind := dec.kind
|
||||
}
|
||||
|
||||
/--
|
||||
Given a continuation `dec` and a reference `ref`, returns a continuation derived from `dec` with result type `PUnit`.
|
||||
If `dec` already has result type `PUnit`, simply returns `dec`. Otherwise, an error is logged and a
|
||||
new continuation is returned that calls `dec` with `sorry` as a result. The error is reported at `ref`.
|
||||
-/
|
||||
def DoElemCont.ensureUnitAt (dec : DoElemCont) (ref : Syntax) : DoElabM DoElemCont := do
|
||||
dec.ensureHasTypeAt ref (← mkPUnit)
|
||||
|
||||
/--
|
||||
Given a continuation `dec`, returns a continuation derived from `dec` with result type `PUnit`.
|
||||
If `dec` already has result type `PUnit`, simply returns `dec`. Otherwise, an error is logged and a
|
||||
new continuation is returned that calls `dec` with `sorry` as a result.
|
||||
-/
|
||||
def DoElemCont.ensureUnit (dec : DoElemCont) : DoElabM DoElemCont := do
|
||||
dec.ensureUnitAt (← getRef)
|
||||
|
||||
/--
|
||||
Return `$e >>= fun ($dec.resultName : $dec.resultType) => $(← dec.k)`, cancelling
|
||||
the bind if `$(← dec.k)` is `pure $dec.resultName` or `e` is some `pure` computation.
|
||||
-/
|
||||
def DoElemCont.mkBindUnlessPure (dec : DoElemCont) (e : Expr) : DoElabM Expr := do
|
||||
-- let eResultTy ← mkFreshResultType
|
||||
-- let e ← Term.ensureHasType (← mkMonadicType eResultTy) e
|
||||
-- let dec ← dec.ensureHasType eResultTy
|
||||
let x := dec.resultName
|
||||
let k := dec.k
|
||||
let eResultTy := dec.resultType
|
||||
let k := dec.k
|
||||
-- The .ofBinderName below is mainly to interpret `__do_lift` binders as implementation details.
|
||||
let declKind := .ofBinderName x
|
||||
let kResultTy ← mkFreshResultType `kResultTy
|
||||
@@ -467,8 +421,9 @@ Return `let $k.resultName : PUnit := PUnit.unit; $(← k.k)`, ensuring that the
|
||||
is `PUnit` and then immediately zeta-reduce the `let`.
|
||||
-/
|
||||
def DoElemCont.continueWithUnit (dec : DoElemCont) : DoElabM Expr := do
|
||||
let dec ← dec.ensureUnit
|
||||
mapLetDeclZeta dec.resultName (← mkPUnit) (← mkPUnitUnit) (nondep := true) (kind := .ofBinderName dec.resultName) fun _ =>
|
||||
let unit ← mkPUnitUnit
|
||||
discard <| Term.ensureHasType dec.resultType unit
|
||||
mapLetDeclZeta dec.resultName (← mkPUnit) unit (nondep := true) (kind := .ofBinderName dec.resultName) fun _ =>
|
||||
dec.k
|
||||
|
||||
/-- Elaborate the `DoElemCont` with the `deadCode` flag set to `deadSyntactically` to emit warnings. -/
|
||||
@@ -587,7 +542,7 @@ def DoElemCont.withDuplicableCont (nondupDec : DoElemCont) (callerInfo : Control
|
||||
withLocalDeclD nondupDec.resultName nondupDec.resultType fun r => do
|
||||
withLocalDeclsDND (mutDecls.map fun (d : LocalDecl) => (d.userName, d.type)) fun muts => do
|
||||
for (x, newX) in mutVars.zip muts do Term.addTermInfo' x newX
|
||||
withDeadCode (if callerInfo.deadCode then .deadSemantically else .alive) do
|
||||
withDeadCode (if callerInfo.numRegularExits > 0 then .alive else .deadSemantically) do
|
||||
let e ← nondupDec.k
|
||||
mkLambdaFVars (#[r] ++ muts) e
|
||||
unless ← joinRhsMVar.mvarId!.checkedAssign joinRhs do
|
||||
@@ -649,7 +604,6 @@ def enterFinally (resultType : Expr) (k : DoElabM Expr) : DoElabM Expr := do
|
||||
/-- Extracts `MonadInfo` and monadic result type `α` from the expected type of a `do` block `m α`. -/
|
||||
private partial def extractMonadInfo (expectedType? : Option Expr) : Term.TermElabM (MonadInfo × Expr) := do
|
||||
let some expectedType := expectedType? | mkUnknownMonadResult
|
||||
let expectedType ← instantiateMVars expectedType
|
||||
let extractStep? (type : Expr) : Term.TermElabM (Option (MonadInfo × Expr)) := do
|
||||
let .app m resultType := type.consumeMData | return none
|
||||
unless ← isType resultType do return none
|
||||
|
||||
@@ -232,8 +232,9 @@ def ControlLifter.ofCont (info : ControlInfo) (dec : DoElemCont) : DoElabM Contr
|
||||
breakBase?,
|
||||
continueBase?,
|
||||
pureBase := controlStack,
|
||||
-- The success continuation `origCont` is dead code iff the `ControlInfo` says so semantically.
|
||||
pureDeadCode := if info.deadCode then .deadSemantically else .alive,
|
||||
-- The success continuation `origCont` is dead code iff the `ControlInfo` says that there is no
|
||||
-- regular exit.
|
||||
pureDeadCode := if info.numRegularExits > 0 then .alive else .deadSemantically,
|
||||
liftedDoBlockResultType := (← controlStack.stM dec.resultType),
|
||||
}
|
||||
|
||||
|
||||
@@ -16,77 +16,46 @@ namespace Lean.Elab.Do
|
||||
|
||||
open Lean Meta Parser.Term
|
||||
|
||||
/--
|
||||
Represents information about what control effects a `do` block has.
|
||||
|
||||
The fields split by flavor:
|
||||
|
||||
* `breaks`, `continues`, `returnsEarly`, and `reassigns` are **syntactic**: `true`/non-empty iff
|
||||
the corresponding construct appears anywhere in the source text of the block, independent of
|
||||
whether it is semantically reachable. Downstream elaborators must assume every such syntactic
|
||||
effect may occur, because the elaborator visits every doElem (only top-level
|
||||
`return`/`break`/`continue` short-circuit via `elabAsSyntacticallyDeadCode`).
|
||||
* `numRegularExits` is also **syntactic**: the number of times the block wires the enclosing
|
||||
continuation into its elaborated expression. `withDuplicableCont` reads it as a join-point
|
||||
duplication trigger (`> 1`).
|
||||
* `deadCode` is **semantic**: a conservative over-approximation of "every path through the block
|
||||
fails to reach the end normally". It drives the dead-code warning.
|
||||
|
||||
Invariant: `numRegularExits = 0 → deadCode = true`. The converse does not hold — for example a
|
||||
`repeat` with no `break` has `numRegularExits = 1` (the loop elaborator wires its continuation
|
||||
once for the normal-exit path) but `deadCode = true` (the loop never terminates normally).
|
||||
-/
|
||||
/-- Represents information about what control effects a `do` block has. -/
|
||||
structure ControlInfo where
|
||||
/-- The `do` block syntactically contains a `break`. -/
|
||||
/-- The `do` block may `break`. -/
|
||||
breaks : Bool := false
|
||||
/-- The `do` block syntactically contains a `continue`. -/
|
||||
/-- The `do` block may `continue`. -/
|
||||
continues : Bool := false
|
||||
/-- The `do` block syntactically contains an early `return`. -/
|
||||
/-- The `do` block may `return` early. -/
|
||||
returnsEarly : Bool := false
|
||||
/--
|
||||
The number of times the block wires the enclosing continuation into its elaborated expression.
|
||||
Consumed by `withDuplicableCont` to decide whether to introduce a join point (`> 1`).
|
||||
The number of regular exit paths the `do` block has.
|
||||
Corresponds to the number of jumps to an ambient join point.
|
||||
-/
|
||||
numRegularExits : Nat := 1
|
||||
/--
|
||||
Conservative semantic flag: `true` iff every path through the block provably fails to reach the
|
||||
end normally. Implied by `numRegularExits = 0`, but not equivalent (e.g. a `repeat` without
|
||||
`break` has `numRegularExits = 1` yet is dead).
|
||||
-/
|
||||
deadCode : Bool := false
|
||||
/-- The variables that are syntactically reassigned somewhere in the `do` block. -/
|
||||
/-- The variables that are reassigned in the `do` block. -/
|
||||
reassigns : NameSet := {}
|
||||
deriving Inhabited
|
||||
|
||||
def ControlInfo.pure : ControlInfo := {}
|
||||
|
||||
def ControlInfo.sequence (a b : ControlInfo) : ControlInfo := {
|
||||
-- Syntactic fields aggregate unconditionally; the elaborator keeps visiting `b` unless `a` is
|
||||
-- a syntactically-terminal element (only top-level `return`/`break`/`continue` are, via
|
||||
-- `elabAsSyntacticallyDeadCode`).
|
||||
def ControlInfo.sequence (a b : ControlInfo) : ControlInfo :=
|
||||
if a.numRegularExits == 0 then a else {
|
||||
breaks := a.breaks || b.breaks,
|
||||
continues := a.continues || b.continues,
|
||||
returnsEarly := a.returnsEarly || b.returnsEarly,
|
||||
reassigns := a.reassigns ++ b.reassigns,
|
||||
numRegularExits := b.numRegularExits,
|
||||
-- Semantic: the sequence is dead if either part is dead.
|
||||
deadCode := a.deadCode || b.deadCode,
|
||||
reassigns := a.reassigns ++ b.reassigns,
|
||||
}
|
||||
|
||||
def ControlInfo.alternative (a b : ControlInfo) : ControlInfo := {
|
||||
breaks := a.breaks || b.breaks,
|
||||
continues := a.continues || b.continues,
|
||||
returnsEarly := a.returnsEarly || b.returnsEarly,
|
||||
reassigns := a.reassigns ++ b.reassigns,
|
||||
numRegularExits := a.numRegularExits + b.numRegularExits,
|
||||
-- Semantic: alternatives are dead only if all branches are dead.
|
||||
deadCode := a.deadCode && b.deadCode,
|
||||
reassigns := a.reassigns ++ b.reassigns,
|
||||
}
|
||||
|
||||
instance : ToMessageData ControlInfo where
|
||||
toMessageData info := m!"breaks: {info.breaks}, continues: {info.continues},
|
||||
returnsEarly: {info.returnsEarly}, numRegularExits: {info.numRegularExits},
|
||||
deadCode: {info.deadCode}, reassigns: {info.reassigns.toList}"
|
||||
returnsEarly: {info.returnsEarly}, exitsRegularly: {info.numRegularExits},
|
||||
reassigns: {info.reassigns.toList}"
|
||||
|
||||
/-- A handler for inferring `ControlInfo` from a `doElem` syntax. Register with `@[doElem_control_info parserName]`. -/
|
||||
abbrev ControlInfoHandler := TSyntax `doElem → TermElabM ControlInfo
|
||||
@@ -110,7 +79,6 @@ builtin_initialize controlInfoElemAttribute : KeyedDeclsAttribute ControlInfoHan
|
||||
|
||||
namespace InferControlInfo
|
||||
|
||||
open InternalSyntax in
|
||||
mutual
|
||||
|
||||
partial def ofElem (stx : TSyntax `doElem) : TermElabM ControlInfo := do
|
||||
@@ -120,9 +88,9 @@ partial def ofElem (stx : TSyntax `doElem) : TermElabM ControlInfo := do
|
||||
return ← ofElem ⟨stxNew⟩
|
||||
|
||||
match stx with
|
||||
| `(doElem| break) => return { breaks := true, numRegularExits := 0, deadCode := true }
|
||||
| `(doElem| continue) => return { continues := true, numRegularExits := 0, deadCode := true }
|
||||
| `(doElem| return $[$_]?) => return { returnsEarly := true, numRegularExits := 0, deadCode := true }
|
||||
| `(doElem| break) => return { breaks := true, numRegularExits := 0 }
|
||||
| `(doElem| continue) => return { continues := true, numRegularExits := 0 }
|
||||
| `(doElem| return $[$_]?) => return { returnsEarly := true, numRegularExits := 0 }
|
||||
| `(doExpr| $_:term) => return { numRegularExits := 1 }
|
||||
| `(doElem| do $doSeq) => ofSeq doSeq
|
||||
-- Let
|
||||
@@ -160,24 +128,12 @@ partial def ofElem (stx : TSyntax `doElem) : TermElabM ControlInfo := do
|
||||
return thenInfo.alternative info
|
||||
| `(doElem| unless $_ do $elseSeq) =>
|
||||
ControlInfo.alternative {} <$> ofSeq elseSeq
|
||||
-- For/Repeat
|
||||
| `(doElem| for $[$[$_ :]? $_ in $_],* do $bodySeq) =>
|
||||
let info ← ofSeq bodySeq
|
||||
return { info with -- keep only reassigns and earlyReturn
|
||||
numRegularExits := 1,
|
||||
continues := false,
|
||||
breaks := false,
|
||||
deadCode := false,
|
||||
}
|
||||
| `(doRepeat| repeat $bodySeq) =>
|
||||
let info ← ofSeq bodySeq
|
||||
return { info with
|
||||
-- Syntactically the loop elaborator wires the continuation once (for the break path).
|
||||
numRegularExits := 1,
|
||||
continues := false,
|
||||
breaks := false,
|
||||
-- Semantically the loop never terminates normally unless the body may `break`.
|
||||
deadCode := !info.breaks,
|
||||
breaks := false
|
||||
}
|
||||
-- Try
|
||||
| `(doElem| try $trySeq:doSeq $[$catches]* $[finally $finSeq?]?) =>
|
||||
@@ -196,7 +152,6 @@ partial def ofElem (stx : TSyntax `doElem) : TermElabM ControlInfo := do
|
||||
let finInfo ← ofOptionSeq finSeq?
|
||||
return info.sequence finInfo
|
||||
-- Misc
|
||||
| `(doElem| skip) => return .pure
|
||||
| `(doElem| dbg_trace $_) => return .pure
|
||||
| `(doElem| assert! $_) => return .pure
|
||||
| `(doElem| debug_assert! $_) => return .pure
|
||||
@@ -247,7 +202,17 @@ partial def ofLetOrReassign (reassigned : Array Ident) (rhs? : Option (TSyntax `
|
||||
partial def ofSeq (stx : TSyntax ``doSeq) : TermElabM ControlInfo := do
|
||||
let mut info : ControlInfo := {}
|
||||
for elem in getDoElems stx do
|
||||
info := info.sequence (← ofElem elem)
|
||||
if info.numRegularExits == 0 then
|
||||
break
|
||||
let elemInfo ← ofElem elem
|
||||
info := {
|
||||
info with
|
||||
breaks := info.breaks || elemInfo.breaks
|
||||
continues := info.continues || elemInfo.continues
|
||||
returnsEarly := info.returnsEarly || elemInfo.returnsEarly
|
||||
numRegularExits := elemInfo.numRegularExits
|
||||
reassigns := info.reassigns ++ elemInfo.reassigns
|
||||
}
|
||||
return info
|
||||
|
||||
partial def ofOptionSeq (stx? : Option (TSyntax ``doSeq)) : TermElabM ControlInfo := do
|
||||
|
||||
@@ -1782,7 +1782,7 @@ mutual
|
||||
doIfToCode doElem doElems
|
||||
else if k == ``Parser.Term.doUnless then
|
||||
doUnlessToCode doElem doElems
|
||||
else if k == ``Parser.Term.doRepeat then
|
||||
else if k == `Lean.doRepeat then
|
||||
let seq := doElem[1]
|
||||
let expanded ← `(doElem| for _ in Loop.mk do $seq)
|
||||
doSeqToCode (expanded :: doElems)
|
||||
@@ -1819,13 +1819,6 @@ mutual
|
||||
return mkTerminalAction term
|
||||
else
|
||||
return mkSeq term (← doSeqToCode doElems)
|
||||
else if k == ``Parser.Term.InternalSyntax.doSkip then
|
||||
-- In the legacy elaborator, `skip` is treated as `pure PUnit.unit`.
|
||||
let term ← withRef doElem `(pure PUnit.unit)
|
||||
if doElems.isEmpty then
|
||||
return mkTerminalAction term
|
||||
else
|
||||
return mkSeq term (← doSeqToCode doElems)
|
||||
else
|
||||
throwError "unexpected do-element of kind {doElem.getKind}:\n{doElem}"
|
||||
end
|
||||
|
||||
@@ -364,9 +364,8 @@ def elabIdbgTerm : TermElab := fun stx expectedType? => do
|
||||
|
||||
@[builtin_doElem_elab Lean.Parser.Term.doIdbg]
|
||||
def elabDoIdbg : DoElab := fun stx dec => do
|
||||
let `(Lean.Parser.Term.doIdbg| idbg%$tk $e) := stx | throwUnsupportedSyntax
|
||||
let `(Lean.Parser.Term.doIdbg| idbg $e) := stx | throwUnsupportedSyntax
|
||||
let mγ ← mkMonadicType (← read).doBlockResultType
|
||||
let dec ← dec.ensureUnitAt tk
|
||||
doElabToSyntax "idbg body" dec.continueWithUnit fun body => do
|
||||
elabIdbgCore (e := e) (body := body) (ref := stx) mγ
|
||||
|
||||
|
||||
@@ -73,8 +73,6 @@ private def inductiveSyntaxToView (modifiers : Modifiers) (decl : Syntax) (isCoi
|
||||
throwError "Constructor cannot be `protected` because it is in a `private` inductive datatype"
|
||||
checkValidCtorModifier ctorModifiers
|
||||
let ctorName := ctor.getIdAt 3
|
||||
if ctorName.hasMacroScopes && isCoinductive then
|
||||
throwError "Coinductive predicates are not allowed inside of macro scopes"
|
||||
let ctorName := declName ++ ctorName
|
||||
let ctorName ← withRef ctor[3] <| applyVisibility ctorModifiers ctorName
|
||||
let (binders, type?) := expandOptDeclSig ctor[4]
|
||||
|
||||
@@ -222,8 +222,8 @@ private def addNonRecAux (docCtx : LocalContext × LocalInstances) (preDef : Pre
|
||||
if compile && shouldGenCodeFor preDef then
|
||||
compileDecl decl
|
||||
if applyAttrAfterCompilation then
|
||||
saveEqnAffectingOptions preDef.declName
|
||||
enableRealizationsForConst preDef.declName
|
||||
generateEagerEqns preDef.declName
|
||||
addPreDefDocs docCtx preDef
|
||||
if applyAttrAfterCompilation then
|
||||
applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation
|
||||
|
||||
@@ -28,7 +28,7 @@ def getConstUnfoldEqnFor? (declName : Name) : MetaM (Option Name) := do
|
||||
trace[ReservedNameAction] "getConstUnfoldEqnFor? {declName} failed, no unfold theorem available"
|
||||
return none
|
||||
let name := mkEqLikeNameFor (← getEnv) declName eqUnfoldThmSuffix
|
||||
realizeConst declName name <| withEqnOptions declName do
|
||||
realizeConst declName name do
|
||||
-- we have to call `getUnfoldEqnFor?` again to make `unfoldEqnName` available in this context
|
||||
let some unfoldEqnName ← getUnfoldEqnFor? (nonRec := true) declName | unreachable!
|
||||
let info ← getConstInfo unfoldEqnName
|
||||
|
||||
@@ -367,7 +367,7 @@ def mkEqns (declName : Name) (declNames : Array Name) : MetaM (Array Name) := do
|
||||
thmNames := thmNames.push name
|
||||
-- determinism: `type` should be independent of the environment changes since `baseName` was
|
||||
-- added
|
||||
realizeConst declName name (withEqnOptions declName (doRealize name info type))
|
||||
realizeConst declName name (doRealize name info type)
|
||||
return thmNames
|
||||
where
|
||||
doRealize name info type := withOptions (tactic.hygienic.set · false) do
|
||||
|
||||
@@ -69,10 +69,8 @@ def addPreDefAttributes (preDefs : Array PreDefinition) : TermElabM Unit := do
|
||||
a.name = `instance_reducible || a.name = `implicit_reducible do
|
||||
setIrreducibleAttribute preDef.declName
|
||||
|
||||
for preDef in preDefs do
|
||||
saveEqnAffectingOptions preDef.declName
|
||||
|
||||
/-
|
||||
`enableRealizationsForConst` must happen before `generateEagerEqns`
|
||||
It must happen in reverse order so that constants realized as part of the first decl
|
||||
have realizations for the other ones enabled
|
||||
-/
|
||||
@@ -80,6 +78,7 @@ def addPreDefAttributes (preDefs : Array PreDefinition) : TermElabM Unit := do
|
||||
enableRealizationsForConst preDef.declName
|
||||
|
||||
for preDef in preDefs do
|
||||
generateEagerEqns preDef.declName
|
||||
applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation
|
||||
|
||||
end Lean.Elab.Mutual
|
||||
|
||||
@@ -163,7 +163,7 @@ public def registerEqnsInfo (preDef : PreDefinition) (declNames : Array Name) (r
|
||||
/-- Generate the "unfold" lemma for `declName`. -/
|
||||
def mkUnfoldEq (declName : Name) (info : EqnInfo) : MetaM Name := do
|
||||
let name := mkEqLikeNameFor (← getEnv) info.declName unfoldThmSuffix
|
||||
realizeConst info.declNames[0]! name (withEqnOptions declName (doRealize name))
|
||||
realizeConst info.declNames[0]! name (doRealize name)
|
||||
return name
|
||||
where
|
||||
doRealize name := withOptions (tactic.hygienic.set · false) do
|
||||
|
||||
@@ -208,11 +208,11 @@ def structuralRecursion
|
||||
-/
|
||||
registerEqnsInfo preDef (preDefs.map (·.declName)) recArgPos fixedParamPerms
|
||||
addSmartUnfoldingDef docCtx preDef recArgPos
|
||||
for preDef in preDefs do
|
||||
saveEqnAffectingOptions preDef.declName
|
||||
for preDef in preDefs do
|
||||
-- must happen in separate loop so realizations can see eqnInfos of all other preDefs
|
||||
enableRealizationsForConst preDef.declName
|
||||
-- must happen after `enableRealizationsForConst`
|
||||
generateEagerEqns preDef.declName
|
||||
applyAttributesOf preDefsNonRec AttributeApplicationTime.afterCompilation
|
||||
|
||||
|
||||
|
||||
@@ -497,21 +497,14 @@ def forEachVar (hs : Array Syntax) (tac : MVarId → FVarId → MetaM MVarId) :
|
||||
/--
|
||||
Searches for a metavariable `g` s.t. `tag` is its exact name.
|
||||
If none then searches for a metavariable `g` s.t. `tag` is a suffix of its name.
|
||||
If none, then it searches for a metavariable `g` s.t. `tag` is a prefix of its name.
|
||||
|
||||
We erase macro scopes from the metavariable's user name before comparing, so that
|
||||
user-written tags match even when a previous tactic left hygienic macro scopes at
|
||||
the end of the tag (e.g. `e_a.yield._@._internal._hyg.0`, where `yield` is not the
|
||||
literal last component of the name). Case tags written by the user are never
|
||||
macro-scoped, so erasing scopes on the mvar side is sufficient.
|
||||
-/
|
||||
If none, then it searches for a metavariable `g` s.t. `tag` is a prefix of its name. -/
|
||||
private def findTag? (mvarIds : List MVarId) (tag : Name) : TacticM (Option MVarId) := do
|
||||
match (← mvarIds.findM? fun mvarId => return tag == (← mvarId.getDecl).userName.eraseMacroScopes) with
|
||||
match (← mvarIds.findM? fun mvarId => return tag == (← mvarId.getDecl).userName) with
|
||||
| some mvarId => return mvarId
|
||||
| none =>
|
||||
match (← mvarIds.findM? fun mvarId => return tag.isSuffixOf (← mvarId.getDecl).userName.eraseMacroScopes) with
|
||||
match (← mvarIds.findM? fun mvarId => return tag.isSuffixOf (← mvarId.getDecl).userName) with
|
||||
| some mvarId => return mvarId
|
||||
| none => mvarIds.findM? fun mvarId => return tag.isPrefixOf (← mvarId.getDecl).userName.eraseMacroScopes
|
||||
| none => mvarIds.findM? fun mvarId => return tag.isPrefixOf (← mvarId.getDecl).userName
|
||||
|
||||
private def getCaseGoals (tag : TSyntax ``binderIdent) : TacticM (MVarId × List MVarId) := do
|
||||
let gs ← getUnsolvedGoals
|
||||
|
||||
@@ -68,10 +68,7 @@ def setGoals (goals : List Goal) : GrindTacticM Unit :=
|
||||
|
||||
def pruneSolvedGoals : GrindTacticM Unit := do
|
||||
let gs ← getGoals
|
||||
let gs ← gs.filterM fun g => do
|
||||
if g.inconsistent then return false
|
||||
-- The metavariable may have been assigned by `isDefEq`
|
||||
return !(← g.mvarId.isAssigned)
|
||||
let gs := gs.filter fun g => !g.inconsistent
|
||||
setGoals gs
|
||||
|
||||
def getUnsolvedGoals : GrindTacticM (List Goal) := do
|
||||
@@ -332,19 +329,13 @@ def liftGoalM (k : GoalM α) : GrindTacticM α := do
|
||||
replaceMainGoal [goal]
|
||||
return a
|
||||
|
||||
inductive LiftActionCoreResult where
|
||||
| closed | subgoals
|
||||
|
||||
def liftActionCore (a : Action) : GrindTacticM LiftActionCoreResult := do
|
||||
def liftAction (a : Action) : GrindTacticM Unit := do
|
||||
let goal ← getMainGoal
|
||||
let ka := fun _ => throwError "tactic is not applicable"
|
||||
let kp := fun goal => return .stuck [goal]
|
||||
match (← liftGrindM <| a goal ka kp) with
|
||||
| .closed _ => replaceMainGoal []; return .closed
|
||||
| .stuck gs => replaceMainGoal gs; return .subgoals
|
||||
|
||||
def liftAction (a : Action) : GrindTacticM Unit := do
|
||||
discard <| liftActionCore a
|
||||
| .closed _ => replaceMainGoal []
|
||||
| .stuck gs => replaceMainGoal gs
|
||||
|
||||
def done : GrindTacticM Unit := do
|
||||
pruneSolvedGoals
|
||||
|
||||
@@ -111,9 +111,7 @@ def evalCheck (tacticName : Name) (k : GoalM Bool)
|
||||
This matches the behavior of these tactics in default tactic mode
|
||||
where `lia` can close `x > 1 → x + y + z > 0` directly. -/
|
||||
if (← read).sym then
|
||||
match (← liftActionCore <| Action.intros 0 >> Action.assertAll) with
|
||||
| .closed => return () -- closed the goal
|
||||
| .subgoals => pure () -- continue
|
||||
liftAction <| Action.intros 0 >> Action.assertAll
|
||||
let recover := (← read).recover
|
||||
liftGoalM do
|
||||
let progress ← k
|
||||
|
||||
@@ -175,7 +175,6 @@ where
|
||||
return !(← allChildrenLt a b)
|
||||
|
||||
lpo (a b : Expr) : MetaM Bool := do
|
||||
checkSystem "Lean.Meta.acLt"
|
||||
-- Case 1: `a < b` if for some child `b_i` of `b`, we have `b_i >= a`
|
||||
if (← someChildGe b a) then
|
||||
return true
|
||||
|
||||
@@ -37,17 +37,12 @@ register_builtin_option backward.eqns.deepRecursiveSplit : Bool := {
|
||||
These options affect the generation of equational theorems in a significant way. For these, their
|
||||
value at definition time, not realization time, should matter.
|
||||
|
||||
This is implemented by storing their values at definition time (when non-default) in an environment
|
||||
extension, and restoring them when the equations are lazily realized.
|
||||
This is implemented by
|
||||
* eagerly realizing the equations when they are set to a non-default value
|
||||
* when realizing them lazily, reset the options to their default
|
||||
-/
|
||||
def eqnAffectingOptions : Array (Lean.Option Bool) := #[backward.eqns.nonrecursive, backward.eqns.deepRecursiveSplit]
|
||||
|
||||
/-- Environment extension that stores the values of `eqnAffectingOptions` at definition time,
|
||||
keyed by declaration name. Only populated when at least one option has a non-default value.
|
||||
Stores an association list of (option name, value) pairs for options that differ from defaults. -/
|
||||
builtin_initialize eqnOptionsExt : MapDeclarationExtension (Array (Name × DataValue)) ←
|
||||
mkMapDeclarationExtension (asyncMode := .local)
|
||||
|
||||
def eqnThmSuffixBase := "eq"
|
||||
def eqnThmSuffixBasePrefix := eqnThmSuffixBase ++ "_"
|
||||
def eqn1ThmSuffix := eqnThmSuffixBasePrefix ++ "1"
|
||||
@@ -158,30 +153,12 @@ structure EqnsExtState where
|
||||
builtin_initialize eqnsExt : EnvExtension EqnsExtState ←
|
||||
registerEnvExtension (pure {}) (asyncMode := .local)
|
||||
|
||||
/--
|
||||
Runs `act` with the equation-affecting options restored to the values stored for `declName`
|
||||
at definition time (or reset to their defaults if none were stored). Use this inside
|
||||
`realizeConst` callbacks, which otherwise see the caller-independent `ctx.opts` rather than
|
||||
the outer `getEqnsFor?` context. -/
|
||||
def withEqnOptions (declName : Name) (act : MetaM α) : MetaM α := do
|
||||
let env ← getEnv
|
||||
let setOpts : Options → Options :=
|
||||
if let some values := eqnOptionsExt.find? env declName then
|
||||
fun os => Id.run do
|
||||
let mut os := eqnAffectingOptions.foldl (fun os o => o.set os o.defValue) os
|
||||
for (name, v) in values do
|
||||
os := os.insert name v
|
||||
return os
|
||||
else
|
||||
fun os => eqnAffectingOptions.foldl (fun os o => o.set os o.defValue) os
|
||||
withOptions setOpts act
|
||||
|
||||
/--
|
||||
Simple equation theorem for nonrecursive definitions.
|
||||
-/
|
||||
def mkSimpleEqThm (declName : Name) (name : Name) : MetaM (Option Name) := do
|
||||
if let some (.defnInfo info) := (← getEnv).find? declName then
|
||||
realizeConst declName name (withEqnOptions declName (doRealize name info))
|
||||
realizeConst declName name (doRealize name info)
|
||||
return some name
|
||||
else
|
||||
return none
|
||||
@@ -252,22 +229,19 @@ private def getEqnsFor?Core (declName : Name) : MetaM (Option (Array Name)) := w
|
||||
Returns equation theorems for the given declaration.
|
||||
-/
|
||||
def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := withLCtx {} {} do
|
||||
withEqnOptions declName do
|
||||
-- This is the entry point for lazy equation generation. Ignore the current value
|
||||
-- of the options, and revert to the default.
|
||||
withOptions (eqnAffectingOptions.foldl fun os o => o.set os o.defValue) do
|
||||
getEqnsFor?Core declName
|
||||
|
||||
/--
|
||||
If any equation theorem affecting option is not the default value, store the option values
|
||||
for later use during lazy equation generation.
|
||||
If any equation theorem affecting option is not the default value, create the equations now.
|
||||
-/
|
||||
def saveEqnAffectingOptions (declName : Name) : MetaM Unit := do
|
||||
def generateEagerEqns (declName : Name) : MetaM Unit := do
|
||||
let opts ← getOptions
|
||||
let mut nonDefaults : Array (Name × DataValue) := #[]
|
||||
for o in eqnAffectingOptions do
|
||||
if o.get opts != o.defValue then
|
||||
nonDefaults := nonDefaults.push (o.name, KVMap.Value.toDataValue (o.get opts))
|
||||
unless nonDefaults.isEmpty do
|
||||
trace[Elab.definition.eqns] "saving equation-affecting options for {declName}"
|
||||
modifyEnv (eqnOptionsExt.insert · declName nonDefaults)
|
||||
if eqnAffectingOptions.any fun o => o.get opts != o.defValue then
|
||||
trace[Elab.definition.eqns] "generating eager equations for {declName}"
|
||||
let _ ← getEqnsFor?Core declName
|
||||
|
||||
@[expose] def GetUnfoldEqnFn := Name → MetaM (Option Name)
|
||||
|
||||
|
||||
@@ -229,33 +229,8 @@ private partial def computeSynthOrder (inst : Expr) (projInfo? : Option Projecti
|
||||
|
||||
return synthed
|
||||
|
||||
def checkImpossibleInstance (c : Expr) : MetaM Unit := do
|
||||
let cTy ← inferType c
|
||||
forallTelescopeReducing cTy fun args ty => do
|
||||
let argTys ← args.mapM inferType
|
||||
let impossibleArgs ← args.zipIdx.filterMapM fun (arg, i) => do
|
||||
let fv := arg.fvarId!
|
||||
if (← fv.getDecl).binderInfo.isInstImplicit then return none
|
||||
if ty.containsFVar fv then return none
|
||||
if argTys[i+1:].any (·.containsFVar fv) then return none
|
||||
return some m!"{arg} : {← inferType arg}"
|
||||
if impossibleArgs.isEmpty then return
|
||||
let impossibleArgs := MessageData.joinSep impossibleArgs.toList ", "
|
||||
throwError m!"Instance {c} has arguments "
|
||||
++ impossibleArgs
|
||||
++ " that are impossible to infer. Those arguments are not instance-implicit and do not appear in another instance-implicit argument or the return type."
|
||||
|
||||
def checkNonClassInstance (declName : Name) (c : Expr) : MetaM Unit := do
|
||||
let type ← inferType c
|
||||
forallTelescopeReducing type fun _ target => do
|
||||
unless (← isClass? target).isSome do
|
||||
unless target.isSorry do
|
||||
throwError m!"instance `{declName}` target `{target}` is not a type class."
|
||||
|
||||
def addInstance (declName : Name) (attrKind : AttributeKind) (prio : Nat) : MetaM Unit := do
|
||||
let c ← mkConstWithLevelParams declName
|
||||
checkImpossibleInstance c
|
||||
checkNonClassInstance declName c
|
||||
let keys ← mkInstanceKey c
|
||||
let status ← getReducibilityStatus declName
|
||||
unless status matches .reducible | .implicitReducible do
|
||||
|
||||
@@ -38,9 +38,7 @@ and assigning `?m := max ?n v`
|
||||
private def solveSelfMax (mvarId : LMVarId) (v : Level) : MetaM Unit := do
|
||||
assert! v.isMax
|
||||
let n ← mkFreshLevelMVar
|
||||
let v' := mkMaxArgsDiff mvarId v n
|
||||
trace[Meta.isLevelDefEq.step] "solveSelfMax: {mkLevelMVar mvarId} := {v'}"
|
||||
assignLevelMVar mvarId v'
|
||||
assignLevelMVar mvarId <| mkMaxArgsDiff mvarId v n
|
||||
|
||||
/--
|
||||
Returns true if `v` is `max u ?m` (or variant). That is, we solve `u =?= max u ?m` as `?m := u`.
|
||||
@@ -55,7 +53,6 @@ private def tryApproxSelfMax (u v : Level) : MetaM Bool := do
|
||||
where
|
||||
solve (v' : Level) (mvarId : LMVarId) : MetaM Bool := do
|
||||
if u == v' then
|
||||
trace[Meta.isLevelDefEq.step] "tryApproxSelfMax {mkLevelMVar mvarId} := {u}"
|
||||
assignLevelMVar mvarId u
|
||||
return true
|
||||
else
|
||||
@@ -74,14 +71,8 @@ private def tryApproxMaxMax (u v : Level) : MetaM Bool := do
|
||||
| _, _ => return false
|
||||
where
|
||||
solve (u₁ u₂ v' : Level) (mvarId : LMVarId) : MetaM Bool := do
|
||||
if u₁ == v' then
|
||||
trace[Meta.isLevelDefEq.step] "tryApproxMaxMax {mkLevelMVar mvarId} := {u₂}"
|
||||
assignLevelMVar mvarId u₂
|
||||
return true
|
||||
else if u₂ == v' then
|
||||
trace[Meta.isLevelDefEq.step] "tryApproxMaxMax {mkLevelMVar mvarId} := {u₁}"
|
||||
assignLevelMVar mvarId u₁
|
||||
return true
|
||||
if u₁ == v' then assignLevelMVar mvarId u₂; return true
|
||||
else if u₂ == v' then assignLevelMVar mvarId u₁; return true
|
||||
else return false
|
||||
|
||||
private def postponeIsLevelDefEq (lhs : Level) (rhs : Level) : MetaM Unit := do
|
||||
@@ -106,11 +97,9 @@ mutual
|
||||
else if (← isMVarWithGreaterDepth v mvarId) then
|
||||
-- If both `u` and `v` are both metavariables, but depth of v is greater, then we assign `v := u`.
|
||||
-- This can only happen when levelAssignDepth is set to a smaller value than depth (e.g. during TC synthesis)
|
||||
trace[Meta.isLevelDefEq.step] "{v} := {u}"
|
||||
assignLevelMVar v.mvarId! u
|
||||
return LBool.true
|
||||
else if !u.occurs v then
|
||||
trace[Meta.isLevelDefEq.step] "{u} := {v}"
|
||||
assignLevelMVar u.mvarId! v
|
||||
return LBool.true
|
||||
else if v.isMax && !strictOccursMax u v then
|
||||
@@ -144,9 +133,8 @@ mutual
|
||||
@[export lean_is_level_def_eq]
|
||||
partial def isLevelDefEqAuxImpl : Level → Level → MetaM Bool
|
||||
| Level.succ lhs, Level.succ rhs => isLevelDefEqAux lhs rhs
|
||||
| lhs, rhs => do
|
||||
withTraceNodeBefore `Meta.isLevelDefEq (fun _ =>
|
||||
withOptions (·.set `pp.instantiateMVars false) do addMessageContext m!"{lhs} =?= {rhs}") do
|
||||
| lhs, rhs =>
|
||||
withTraceNode `Meta.isLevelDefEq (fun _ => return m!"{lhs} =?= {rhs}") do
|
||||
if lhs.getLevelOffset == rhs.getLevelOffset then
|
||||
return lhs.getOffset == rhs.getOffset
|
||||
else
|
||||
|
||||
@@ -9,37 +9,19 @@ public import Lean.Meta.Sym.ExprPtr
|
||||
public import Lean.Meta.Basic
|
||||
import Lean.Meta.Transform
|
||||
namespace Lean.Meta.Sym
|
||||
|
||||
/--
|
||||
Returns `true` if `e` contains a loose bound variable with index in `[0, n)`
|
||||
This function assumes `n` is small. If this becomes a bottleneck, we should
|
||||
implement a version of `lean_expr_has_loose_bvar` that checks the range in one traversal.
|
||||
-/
|
||||
def hasLooseBVarsInRange (e : Expr) (n : Nat) : Bool :=
|
||||
e.hasLooseBVars && go n
|
||||
where
|
||||
go : Nat → Bool
|
||||
| 0 => false
|
||||
| i+1 => e.hasLooseBVar i || go i
|
||||
|
||||
/--
|
||||
Checks if `body` is eta-expanded with `n` applications: `f (.bvar (n-1)) ... (.bvar 0)`.
|
||||
Returns `f` if so and `f` has no loose bvars with indices in the range `[0, n)`; otherwise returns `default`.
|
||||
Returns `f` if so and `f` has no loose bvars; otherwise returns `default`.
|
||||
- `n`: number of remaining applications to check
|
||||
- `i`: expected bvar index (starts at 0, increments with each application)
|
||||
- `default`: returned when not eta-reducible (enables pointer equality check)
|
||||
-/
|
||||
def etaReduceAux (body : Expr) (n : Nat) (i : Nat) (default : Expr) : Expr :=
|
||||
go body n i
|
||||
where
|
||||
go (body : Expr) (m : Nat) (i : Nat) : Expr := Id.run do
|
||||
match m with
|
||||
| 0 =>
|
||||
if hasLooseBVarsInRange body n then default
|
||||
else body.lowerLooseBVars n n
|
||||
| m+1 =>
|
||||
let .app f (.bvar j) := body | default
|
||||
if j == i then go f m (i+1) else default
|
||||
def etaReduceAux (body : Expr) (n : Nat) (i : Nat) (default : Expr) : Expr := Id.run do
|
||||
match n with
|
||||
| 0 => if body.hasLooseBVars then default else body
|
||||
| n+1 =>
|
||||
let .app f (.bvar j) := body | default
|
||||
if j == i then etaReduceAux f n (i+1) default else default
|
||||
|
||||
/--
|
||||
If `e` is of the form `(fun x₁ ... xₙ => f x₁ ... xₙ)` and `f` does not contain `x₁`, ..., `xₙ`,
|
||||
|
||||
@@ -48,8 +48,6 @@ def introCore (mvarId : MVarId) (max : Nat) (names : Array Name) : SymM (Array F
|
||||
assignDelayedMVar auxMVar.mvarId! fvars mvarId'
|
||||
mvarId.assign val
|
||||
let finalize (lctx : LocalContext) (localInsts : LocalInstances) (fvars : Array Expr) (type : Expr) : SymM (Array Expr × MVarId) := do
|
||||
if fvars.isEmpty then
|
||||
return (#[], mvarId)
|
||||
let type ← instantiateRevS type fvars
|
||||
let mvar' ← mkFreshExprMVarAt lctx localInsts type .syntheticOpaque mvarDecl.userName
|
||||
let mvarId' := mvar'.mvarId!
|
||||
|
||||
@@ -128,6 +128,7 @@ def postprocessAppMVars (tacticName : Name) (mvarId : MVarId) (newMVars : Array
|
||||
(synthAssignedInstances := true) (allowSynthFailures := false) : MetaM Unit := do
|
||||
synthAppInstances tacticName mvarId newMVars binderInfos synthAssignedInstances allowSynthFailures
|
||||
-- TODO: default and auto params
|
||||
appendParentTag mvarId newMVars binderInfos
|
||||
|
||||
private def dependsOnOthers (mvar : Expr) (otherMVars : Array Expr) : MetaM Bool :=
|
||||
otherMVars.anyM fun otherMVar => do
|
||||
@@ -222,7 +223,6 @@ def _root_.Lean.MVarId.apply (mvarId : MVarId) (e : Expr) (cfg : ApplyConfig :=
|
||||
let e ← instantiateMVars e
|
||||
mvarId.assign (mkAppN e newMVars)
|
||||
let newMVars ← newMVars.filterM fun mvar => not <$> mvar.mvarId!.isAssigned
|
||||
appendParentTag mvarId newMVars binderInfos
|
||||
let otherMVarIds ← getMVarsNoDelayed e
|
||||
let newMVarIds ← reorderGoals newMVars cfg.newGoals
|
||||
let otherMVarIds := otherMVarIds.filter fun mvarId => !newMVarIds.contains mvarId
|
||||
|
||||
@@ -148,14 +148,11 @@ def propagatePending : OrderM Unit := do
|
||||
- `h₁ : ↑ue' = ue`
|
||||
- `h₂ : ↑ve' = ve`
|
||||
- `h : ue = ve`
|
||||
**Note**: We currently only support `Nat` originals. Thus `↑a` is actually
|
||||
`NatCast.natCast a`. The lemma `nat_eq` is specialized to `Int`, so we
|
||||
only invoke it when the cast destination is `Int`. For other types (e.g.
|
||||
`Rat`), `pushEq ue ve h` above is sufficient and `grind` core can derive
|
||||
the `Nat` equality via `norm_cast`/cast injectivity if needed.
|
||||
**Note**: We currently only support `Nat`. Thus `↑a` is actually
|
||||
`NatCast.natCast a`. If we decide to support arbitrary semirings
|
||||
in this module, we must adjust this code.
|
||||
-/
|
||||
if (← inferType ue) == Int.mkType then
|
||||
pushEq ue' ve' <| mkApp7 (mkConst ``Grind.Order.nat_eq) ue' ve' ue ve h₁ h₂ h
|
||||
pushEq ue' ve' <| mkApp7 (mkConst ``Grind.Order.nat_eq) ue' ve' ue ve h₁ h₂ h
|
||||
where
|
||||
/--
|
||||
If `e` is an auxiliary term used to represent some term `a`, returns
|
||||
@@ -346,7 +343,7 @@ def getStructIdOf? (e : Expr) : GoalM (Option Nat) := do
|
||||
return (← get').exprToStructId.find? { expr := e }
|
||||
|
||||
def propagateIneq (e : Expr) : GoalM Unit := do
|
||||
if let some { e := e', h := he, .. } := (← get').termMap.find? { expr := e } then
|
||||
if let some (e', he) := (← get').termMap.find? { expr := e } then
|
||||
go e' (some he)
|
||||
else
|
||||
go e none
|
||||
@@ -372,27 +369,20 @@ builtin_grind_propagator propagateLT ↓LT.lt := propagateIneq
|
||||
public def processNewEq (a b : Expr) : GoalM Unit := do
|
||||
unless isSameExpr a b do
|
||||
let h ← mkEqProof a b
|
||||
if let some { e := a', h := h₁, α } ← getAuxTerm? a then
|
||||
let some { e := b', h := h₂, .. } ← getAuxTerm? b | return ()
|
||||
if let some (a', h₁) ← getAuxTerm? a then
|
||||
let some (b', h₂) ← getAuxTerm? b | return ()
|
||||
/-
|
||||
We have
|
||||
- `h : a = b`
|
||||
- `h₁ : ↑a = a'`
|
||||
- `h₂ : ↑b = b'`
|
||||
where `a'` and `b'` are `NatCast.natCast α inst _` for some type `α`.
|
||||
-/
|
||||
if α == Int.mkType then
|
||||
let h := mkApp7 (mkConst ``Grind.Order.of_nat_eq) a b a' b' h₁ h₂ h
|
||||
go a' b' h
|
||||
else
|
||||
let u ← getDecLevel α
|
||||
let inst ← synthInstance (mkApp (mkConst ``NatCast [u]) α)
|
||||
let h := mkApp9 (mkConst ``Grind.Order.of_natCast_eq [u]) α inst a b a' b' h₁ h₂ h
|
||||
go a' b' h
|
||||
let h := mkApp7 (mkConst ``Grind.Order.of_nat_eq) a b a' b' h₁ h₂ h
|
||||
go a' b' h
|
||||
else
|
||||
go a b h
|
||||
where
|
||||
getAuxTerm? (e : Expr) : GoalM (Option TermMapEntry) := do
|
||||
getAuxTerm? (e : Expr) : GoalM (Option (Expr × Expr)) := do
|
||||
return (← get').termMap.find? { expr := e }
|
||||
|
||||
go (a b h : Expr) : GoalM Unit := do
|
||||
|
||||
@@ -166,9 +166,9 @@ def setStructId (e : Expr) : OrderM Unit := do
|
||||
exprToStructId := s.exprToStructId.insert { expr := e } structId
|
||||
}
|
||||
|
||||
def updateTermMap (e eNew h α : Expr) : GoalM Unit := do
|
||||
def updateTermMap (e eNew h : Expr) : GoalM Unit := do
|
||||
modify' fun s => { s with
|
||||
termMap := s.termMap.insert { expr := e } { e := eNew, h, α }
|
||||
termMap := s.termMap.insert { expr := e } (eNew, h)
|
||||
termMapInv := s.termMapInv.insert { expr := eNew } (e, h)
|
||||
}
|
||||
|
||||
@@ -198,9 +198,9 @@ where
|
||||
getOriginal? (e : Expr) : GoalM (Option Expr) := do
|
||||
if let some (e', _) := (← get').termMapInv.find? { expr := e } then
|
||||
return some e'
|
||||
let_expr NatCast.natCast α _ a := e | return none
|
||||
let_expr NatCast.natCast _ _ a := e | return none
|
||||
if (← alreadyInternalized a) then
|
||||
updateTermMap a e (← mkEqRefl e) α
|
||||
updateTermMap a e (← mkEqRefl e)
|
||||
return some a
|
||||
else
|
||||
return none
|
||||
@@ -290,7 +290,7 @@ def internalizeTerm (e : Expr) : OrderM Unit := do
|
||||
|
||||
open Arith.Cutsat in
|
||||
def adaptNat (e : Expr) : GoalM Expr := do
|
||||
if let some { e := eNew, .. } := (← get').termMap.find? { expr := e } then
|
||||
if let some (eNew, _) := (← get').termMap.find? { expr := e } then
|
||||
return eNew
|
||||
else match_expr e with
|
||||
| LE.le _ _ lhs rhs => adaptCnstr lhs rhs (isLT := false)
|
||||
@@ -307,12 +307,12 @@ where
|
||||
let h := mkApp6
|
||||
(mkConst (if isLT then ``Nat.ToInt.lt_eq else ``Nat.ToInt.le_eq))
|
||||
lhs rhs lhs' rhs' h₁ h₂
|
||||
updateTermMap e eNew h (← getIntExpr)
|
||||
updateTermMap e eNew h
|
||||
return eNew
|
||||
|
||||
adaptTerm : GoalM Expr := do
|
||||
let (eNew, h) ← natToInt e
|
||||
updateTermMap e eNew h (← getIntExpr)
|
||||
updateTermMap e eNew h
|
||||
return eNew
|
||||
|
||||
def adapt (α : Expr) (e : Expr) : GoalM (Expr × Expr) := do
|
||||
|
||||
@@ -128,13 +128,6 @@ structure Struct where
|
||||
propagate : List ToPropagate := []
|
||||
deriving Inhabited
|
||||
|
||||
/-- Entry/Value for the map `termMap` in `State` -/
|
||||
structure TermMapEntry where
|
||||
e : Expr
|
||||
h : Expr
|
||||
α : Expr
|
||||
deriving Inhabited
|
||||
|
||||
/-- State for all order types detected by `grind`. -/
|
||||
structure State where
|
||||
/-- Order structures detected. -/
|
||||
@@ -150,7 +143,7 @@ structure State where
|
||||
Example: given `x y : Nat`, `x ≤ y + 1` is mapped to `Int.ofNat x ≤ Int.ofNat y + 1`, and proof
|
||||
of equivalence.
|
||||
-/
|
||||
termMap : PHashMap ExprPtr TermMapEntry := {}
|
||||
termMap : PHashMap ExprPtr (Expr × Expr) := {}
|
||||
/-- `termMap` inverse -/
|
||||
termMapInv : PHashMap ExprPtr (Expr × Expr) := {}
|
||||
deriving Inhabited
|
||||
|
||||
@@ -82,7 +82,6 @@ def _root_.Lean.MVarId.rewrite (mvarId : MVarId) (e : Expr) (heq : Expr)
|
||||
postprocessAppMVars `rewrite mvarId newMVars binderInfos
|
||||
(synthAssignedInstances := !tactic.skipAssignedInstances.get (← getOptions))
|
||||
let newMVarIds ← newMVars.map Expr.mvarId! |>.filterM fun mvarId => not <$> mvarId.isAssigned
|
||||
appendParentTag mvarId newMVars binderInfos
|
||||
let otherMVarIds ← getMVarsNoDelayed heqIn
|
||||
let otherMVarIds := otherMVarIds.filter (!newMVarIds.contains ·)
|
||||
let newMVarIds := newMVarIds ++ otherMVarIds
|
||||
|
||||
@@ -51,7 +51,7 @@ register_builtin_option debug.tactic.simp.checkDefEqAttr : Bool := {
|
||||
}
|
||||
|
||||
register_builtin_option warning.simp.varHead : Bool := {
|
||||
defValue := true
|
||||
defValue := false
|
||||
descr := "If true, warns when the head symbol of the left-hand side of a `@[simp]` theorem \
|
||||
is a variable. Such lemmas are tried on every simp step, which can be slow."
|
||||
}
|
||||
|
||||
@@ -145,6 +145,7 @@ public partial def wrapInstance (inst expectedType : Expr) (compile : Bool := tr
|
||||
else
|
||||
let name ← mkAuxDeclName
|
||||
let wrapped ← mkAuxDefinition name expectedType inst (compile := false)
|
||||
setReducibilityStatus name .implicitReducible
|
||||
if isMeta then modifyEnv (markMeta · name)
|
||||
if compile then
|
||||
compileDecls (logErrors := logCompileErrors) #[name]
|
||||
|
||||
@@ -273,15 +273,6 @@ with debug assertions enabled (see the `debugAssertions` option).
|
||||
@[builtin_doElem_parser] def doDebugAssert := leading_parser:leadPrec
|
||||
"debug_assert! " >> termParser
|
||||
|
||||
@[builtin_doElem_parser] def doRepeat := leading_parser
|
||||
"repeat " >> doSeq
|
||||
@[builtin_doElem_parser] def doWhileH := leading_parser
|
||||
"while " >> ident >> " : " >> withForbidden "do" termParser >> " do " >> doSeq
|
||||
@[builtin_doElem_parser] def doWhile := leading_parser
|
||||
"while " >> withForbidden "do" termParser >> " do " >> doSeq
|
||||
@[builtin_doElem_parser] def doRepeatUntil := leading_parser
|
||||
"repeat " >> doSeq >> ppDedent ppLine >> "until " >> termParser
|
||||
|
||||
/-
|
||||
We use `notFollowedBy` to avoid counterintuitive behavior.
|
||||
|
||||
|
||||
@@ -484,7 +484,6 @@ open SubExpr (Pos PosMap)
|
||||
open Delaborator (OptionsPerPos topDownAnalyze DelabM getPPOption)
|
||||
|
||||
def delabLevel (l : Level) (prec : Nat) : MetaM Syntax.Level := do
|
||||
let l ← if getPPInstantiateMVars (← getOptions) then instantiateLevelMVars l else pure l
|
||||
let mvars := getPPMVarsLevels (← getOptions)
|
||||
return Level.quote l prec (mvars := mvars) (lIndex? := (← getMCtx).findLevelIndex?)
|
||||
|
||||
|
||||
@@ -77,6 +77,8 @@ def OutputMessage.ofMsg (msg : JsonRpc.Message) : OutputMessage where
|
||||
msg? := msg
|
||||
serialized := toJson msg |>.compress
|
||||
|
||||
open Widget in
|
||||
|
||||
structure WorkerContext where
|
||||
/-- Synchronized output channel for LSP messages. Notifications for outdated versions are
|
||||
discarded on read. -/
|
||||
@@ -87,6 +89,10 @@ structure WorkerContext where
|
||||
-/
|
||||
maxDocVersionRef : IO.Ref Int
|
||||
freshRequestIdRef : IO.Ref Int
|
||||
/--
|
||||
Diagnostics that are included in every single `textDocument/publishDiagnostics` notification.
|
||||
-/
|
||||
stickyDiagnosticsRef : IO.Ref (Array InteractiveDiagnostic)
|
||||
partialHandlersRef : IO.Ref (Std.TreeMap String PartialHandlerInfo)
|
||||
pendingServerRequestsRef : IO.Ref (Std.TreeMap RequestID (IO.Promise (ServerRequestResponse Json)))
|
||||
hLog : FS.Stream
|
||||
@@ -202,11 +208,19 @@ This option can only be set on the command line, not in the lakefile or via `set
|
||||
diags : Array Widget.InteractiveDiagnostic
|
||||
deriving TypeName
|
||||
|
||||
/-- Sends a `textDocument/publishDiagnostics` notification to the client. -/
|
||||
/--
|
||||
Sends a `textDocument/publishDiagnostics` notification to the client that contains the diagnostics
|
||||
in `ctx.stickyDiagnosticsRef` and `doc.diagnosticsRef`.
|
||||
-/
|
||||
private def publishDiagnostics (ctx : WorkerContext) (doc : EditableDocumentCore)
|
||||
: BaseIO Unit := do
|
||||
let supportsIncremental := ctx.initParams.capabilities.incrementalDiagnosticSupport
|
||||
doc.publishDiagnostics supportsIncremental fun notif => ctx.chanOut.sync.send <| .ofMsg notif
|
||||
let stickyInteractiveDiagnostics ← ctx.stickyDiagnosticsRef.get
|
||||
let docInteractiveDiagnostics ← doc.diagnosticsRef.get
|
||||
let diagnostics :=
|
||||
stickyInteractiveDiagnostics ++ docInteractiveDiagnostics
|
||||
|>.map (·.toDiagnostic)
|
||||
let notification := mkPublishDiagnosticsNotification doc.meta diagnostics
|
||||
ctx.chanOut.sync.send <| .ofMsg notification
|
||||
|
||||
open Language in
|
||||
/--
|
||||
@@ -307,7 +321,7 @@ This option can only be set on the command line, not in the lakefile or via `set
|
||||
if let some cacheRef := node.element.diagnostics.interactiveDiagsRef? then
|
||||
cacheRef.set <| some <| .mk { diags : MemorizedInteractiveDiagnostics }
|
||||
pure diags
|
||||
doc.appendDiagnostics diags
|
||||
doc.diagnosticsRef.modify (· ++ diags)
|
||||
if (← get).hasBlocked then
|
||||
publishDiagnostics ctx doc
|
||||
|
||||
@@ -449,7 +463,7 @@ section Initialization
|
||||
let clientHasWidgets := initParams.initializationOptions?.bind (·.hasWidgets?) |>.getD false
|
||||
let maxDocVersionRef ← IO.mkRef 0
|
||||
let freshRequestIdRef ← IO.mkRef (0 : Int)
|
||||
let stickyDiagsRef ← IO.mkRef {}
|
||||
let stickyDiagnosticsRef ← IO.mkRef ∅
|
||||
let pendingServerRequestsRef ← IO.mkRef ∅
|
||||
let chanOut ← mkLspOutputChannel maxDocVersionRef
|
||||
let timestamp ← IO.monoMsNow
|
||||
@@ -479,10 +493,11 @@ section Initialization
|
||||
maxDocVersionRef
|
||||
freshRequestIdRef
|
||||
cmdlineOpts := opts
|
||||
stickyDiagnosticsRef
|
||||
}
|
||||
let diagnosticsMutex ← Std.Mutex.new { stickyDiagsRef }
|
||||
let doc : EditableDocumentCore := {
|
||||
«meta» := doc, initSnap, diagnosticsMutex
|
||||
«meta» := doc, initSnap
|
||||
diagnosticsRef := (← IO.mkRef ∅)
|
||||
}
|
||||
let reporterCancelTk ← CancelToken.new
|
||||
let reporter ← reportSnapshots ctx doc reporterCancelTk
|
||||
@@ -563,11 +578,14 @@ section Updates
|
||||
modify fun st => { st with pendingRequests := map st.pendingRequests }
|
||||
|
||||
/-- Given the new document, updates editable doc state. -/
|
||||
def updateDocument («meta» : DocumentMeta) : WorkerM Unit := do
|
||||
def updateDocument (doc : DocumentMeta) : WorkerM Unit := do
|
||||
(← get).reporterCancelTk.set
|
||||
let ctx ← read
|
||||
let initSnap ← ctx.processor «meta».mkInputContext
|
||||
let doc ← (← get).doc.update «meta» initSnap
|
||||
let initSnap ← ctx.processor doc.mkInputContext
|
||||
let doc : EditableDocumentCore := {
|
||||
«meta» := doc, initSnap
|
||||
diagnosticsRef := (← IO.mkRef ∅)
|
||||
}
|
||||
let reporterCancelTk ← CancelToken.new
|
||||
let reporter ← reportSnapshots ctx doc reporterCancelTk
|
||||
modify fun st => { st with doc := { doc with reporter }, reporterCancelTk }
|
||||
@@ -619,16 +637,18 @@ section NotificationHandling
|
||||
let ctx ← read
|
||||
let s ← get
|
||||
let text := s.doc.meta.text
|
||||
let importOutOfDateMessage :=
|
||||
.text s!"Imports are out of date and should be rebuilt; \
|
||||
use the \"Restart File\" command in your editor."
|
||||
let importOutOfDataMessage := .text s!"Imports are out of date and should be rebuilt; \
|
||||
use the \"Restart File\" command in your editor."
|
||||
let diagnostic := {
|
||||
range := ⟨⟨0, 0⟩, ⟨1, 0⟩⟩
|
||||
fullRange? := some ⟨⟨0, 0⟩, text.utf8PosToLspPos text.source.rawEndPos⟩
|
||||
severity? := DiagnosticSeverity.information
|
||||
message := importOutOfDateMessage
|
||||
message := importOutOfDataMessage
|
||||
}
|
||||
s.doc.appendStickyDiagnostic diagnostic
|
||||
ctx.stickyDiagnosticsRef.modify fun stickyDiagnostics =>
|
||||
let stickyDiagnostics := stickyDiagnostics.filter
|
||||
(·.message.stripTags != importOutOfDataMessage.stripTags)
|
||||
stickyDiagnostics.push diagnostic
|
||||
publishDiagnostics ctx s.doc.toEditableDocumentCore
|
||||
|
||||
def handleRpcRelease (p : Lsp.RpcReleaseParams) : WorkerM Unit := do
|
||||
@@ -739,17 +759,19 @@ section MessageHandling
|
||||
|
||||
open Widget RequestM Language in
|
||||
def handleGetInteractiveDiagnosticsRequest
|
||||
(doc : EditableDocument)
|
||||
(ctx : WorkerContext)
|
||||
(params : GetInteractiveDiagnosticsParams)
|
||||
: RequestM (Array InteractiveDiagnostic) := do
|
||||
let doc ← readDoc
|
||||
-- NOTE: always uses latest document (which is the only one we can retrieve diagnostics for);
|
||||
-- any race should be temporary as the client should re-request interactive diagnostics when
|
||||
-- they receive the non-interactive diagnostics for the new document
|
||||
let allDiags ← doc.collectCurrentDiagnostics
|
||||
let stickyDiags ← ctx.stickyDiagnosticsRef.get
|
||||
let diags ← doc.diagnosticsRef.get
|
||||
-- NOTE: does not wait for `lineRange?` to be fully elaborated, which would be problematic with
|
||||
-- fine-grained incremental reporting anyway; instead, the client is obligated to resend the
|
||||
-- request when the non-interactive diagnostics of this range have changed
|
||||
return PersistentArray.toArray <| allDiags.filter fun diag =>
|
||||
return (stickyDiags ++ diags).filter fun diag =>
|
||||
let r := diag.fullRange
|
||||
let diagStartLine := r.start.line
|
||||
let diagEndLine :=
|
||||
@@ -762,7 +784,7 @@ section MessageHandling
|
||||
s ≤ diagStartLine ∧ diagStartLine < e ∨
|
||||
diagStartLine ≤ s ∧ s < diagEndLine
|
||||
|
||||
def handlePreRequestSpecialCases? (st : WorkerState)
|
||||
def handlePreRequestSpecialCases? (ctx : WorkerContext) (st : WorkerState)
|
||||
(id : RequestID) (method : String) (params : Json)
|
||||
: RequestM (Option (RequestTask SerializedLspResponse)) := do
|
||||
match method with
|
||||
@@ -773,7 +795,7 @@ section MessageHandling
|
||||
let some seshRef := st.rpcSessions.get? params.sessionId
|
||||
| throw RequestError.rpcNeedsReconnect
|
||||
let params ← RequestM.parseRequestParams Widget.GetInteractiveDiagnosticsParams params.params
|
||||
let resp ← handleGetInteractiveDiagnosticsRequest st.doc params
|
||||
let resp ← handleGetInteractiveDiagnosticsRequest ctx params
|
||||
let resp ← seshRef.modifyGet fun st =>
|
||||
rpcEncode resp st.objects |>.map (·) ({st with objects := ·})
|
||||
return some <| .pure { response? := resp, serialized := resp.compress, isComplete := true }
|
||||
@@ -903,7 +925,7 @@ section MessageHandling
|
||||
serverRequestEmitter := sendUntypedServerRequest ctx
|
||||
}
|
||||
let requestTask? ← EIO.toIO' <| RequestM.run (rc := rc) do
|
||||
if let some response ← handlePreRequestSpecialCases? st id method params then
|
||||
if let some response ← handlePreRequestSpecialCases? ctx st id method params then
|
||||
return response
|
||||
let task ← handleLspRequest method params
|
||||
let task ← handlePostRequestSpecialCases id method params task
|
||||
|
||||
@@ -10,7 +10,6 @@ prelude
|
||||
public import Lean.Language.Lean.Types
|
||||
public import Lean.Server.Snapshots
|
||||
public import Lean.Server.AsyncList
|
||||
public import Std.Sync.Mutex
|
||||
|
||||
public section
|
||||
|
||||
@@ -40,26 +39,6 @@ where
|
||||
| some next => .delayed <| next.task.asServerTask.bindCheap go
|
||||
| none => .nil)
|
||||
|
||||
/--
|
||||
Tracks diagnostics and incremental diagnostic reporting state for a single document version.
|
||||
|
||||
The sticky diagnostics are shared across all document versions via an `IO.Ref`, while per-version
|
||||
diagnostics are stored directly. The whole state is wrapped in a `Std.Mutex` on
|
||||
`EditableDocumentCore` to ensure atomic updates.
|
||||
-/
|
||||
structure DiagnosticsState where
|
||||
/--
|
||||
Diagnostics that persist across document versions (e.g. stale dependency warnings).
|
||||
Shared across all versions via an `IO.Ref`.
|
||||
-/
|
||||
stickyDiagsRef : IO.Ref (PersistentArray Widget.InteractiveDiagnostic)
|
||||
/-- Diagnostics accumulated during snapshot reporting. -/
|
||||
diags : PersistentArray Widget.InteractiveDiagnostic := {}
|
||||
/-- Whether the next `publishDiagnostics` call should be incremental. -/
|
||||
isIncremental : Bool := false
|
||||
/-- Amount of diagnostics reported in `publishDiagnostics` so far. -/
|
||||
publishedDiagsAmount : Nat := 0
|
||||
|
||||
/--
|
||||
A document bundled with processing information. Turned into `EditableDocument` as soon as the
|
||||
reporter task has been started.
|
||||
@@ -71,94 +50,11 @@ structure EditableDocumentCore where
|
||||
initSnap : Language.Lean.InitialSnapshot
|
||||
/-- Old representation for backward compatibility. -/
|
||||
cmdSnaps : AsyncList IO.Error Snapshot := private_decl% mkCmdSnaps initSnap
|
||||
/-- Per-version diagnostics state, protected by a mutex. -/
|
||||
diagnosticsMutex : Std.Mutex DiagnosticsState
|
||||
|
||||
namespace EditableDocumentCore
|
||||
open Widget
|
||||
|
||||
/-- Appends new non-sticky diagnostics. -/
|
||||
def appendDiagnostics (doc : EditableDocumentCore) (diags : Array InteractiveDiagnostic) :
|
||||
BaseIO Unit :=
|
||||
doc.diagnosticsMutex.atomically do
|
||||
modify fun ds => { ds with diags := diags.foldl (init := ds.diags) fun acc d => acc.push d }
|
||||
|
||||
/--
|
||||
Appends a sticky diagnostic and marks the next publish as non-incremental.
|
||||
Removes any existing sticky diagnostic whose `message.stripTags` matches the new one.
|
||||
-/
|
||||
def appendStickyDiagnostic (doc : EditableDocumentCore) (diagnostic : InteractiveDiagnostic) :
|
||||
BaseIO Unit :=
|
||||
doc.diagnosticsMutex.atomically do
|
||||
let ds ← get
|
||||
ds.stickyDiagsRef.modify fun stickyDiags =>
|
||||
let stickyDiags := stickyDiags.filter
|
||||
(·.message.stripTags != diagnostic.message.stripTags)
|
||||
stickyDiags.push diagnostic
|
||||
set { ds with isIncremental := false }
|
||||
|
||||
/-- Returns all current diagnostics (sticky ++ doc). -/
|
||||
def collectCurrentDiagnostics (doc : EditableDocumentCore) :
|
||||
BaseIO (PersistentArray InteractiveDiagnostic) :=
|
||||
doc.diagnosticsMutex.atomically do
|
||||
let ds ← get
|
||||
let stickyDiags ← ds.stickyDiagsRef.get
|
||||
return stickyDiags ++ ds.diags
|
||||
|
||||
/--
|
||||
Creates a new `EditableDocumentCore` for a new document version, sharing the same sticky
|
||||
diagnostics with the previous version.
|
||||
-/
|
||||
def update (doc : EditableDocumentCore) (newMeta : DocumentMeta)
|
||||
(newInitSnap : Language.Lean.InitialSnapshot) : BaseIO EditableDocumentCore := do
|
||||
let stickyDiagsRef ← doc.diagnosticsMutex.atomically do
|
||||
return (← get).stickyDiagsRef
|
||||
let diagnosticsMutex ← Std.Mutex.new { stickyDiagsRef }
|
||||
return { «meta» := newMeta, initSnap := newInitSnap, diagnosticsMutex }
|
||||
|
||||
/--
|
||||
Collects diagnostics for a `textDocument/publishDiagnostics` notification, updates
|
||||
the incremental tracking fields and writes the notification to the client.
|
||||
|
||||
When `incrementalDiagnosticSupport` is `true` and the state allows it, sends only
|
||||
the newly added diagnostics with `isIncremental? := some true`. Otherwise, sends
|
||||
all sticky and non-sticky diagnostics non-incrementally.
|
||||
|
||||
The state update and the write are performed atomically under the diagnostics mutex
|
||||
to prevent reordering between concurrent publishers (the reporter task and the main thread).
|
||||
-/
|
||||
def publishDiagnostics (doc : EditableDocumentCore) (incrementalDiagnosticSupport : Bool)
|
||||
(writeDiagnostics : JsonRpc.Notification Lsp.PublishDiagnosticsParams → BaseIO Unit) :
|
||||
BaseIO Unit := do
|
||||
-- The mutex must be held across both the state update and the write to ensure that concurrent
|
||||
-- publishers (e.g. the reporter task and the main thread) cannot interleave their state reads
|
||||
-- and writes, which would reorder incremental/non-incremental messages and corrupt client state.
|
||||
doc.diagnosticsMutex.atomically do
|
||||
let ds ← get
|
||||
let useIncremental := incrementalDiagnosticSupport && ds.isIncremental
|
||||
let stickyDiags ← ds.stickyDiagsRef.get
|
||||
let diags := ds.diags
|
||||
let publishedDiagsAmount := ds.publishedDiagsAmount
|
||||
set <| { ds with publishedDiagsAmount := diags.size, isIncremental := true }
|
||||
let (diagsToSend, isIncremental) :=
|
||||
if useIncremental then
|
||||
let newDiags := diags.foldl (init := #[]) (start := publishedDiagsAmount) fun acc d =>
|
||||
acc.push d.toDiagnostic
|
||||
(newDiags, true)
|
||||
else
|
||||
let allDiags := stickyDiags.foldl (init := #[]) fun acc d =>
|
||||
acc.push d.toDiagnostic
|
||||
let allDiags := diags.foldl (init := allDiags) fun acc d =>
|
||||
acc.push d.toDiagnostic
|
||||
(allDiags, false)
|
||||
let isIncremental? :=
|
||||
if incrementalDiagnosticSupport then
|
||||
some isIncremental
|
||||
else
|
||||
none
|
||||
writeDiagnostics <| mkPublishDiagnosticsNotification doc.meta diagsToSend isIncremental?
|
||||
|
||||
end EditableDocumentCore
|
||||
/--
|
||||
Interactive versions of diagnostics reported so far. Filled by `reportSnapshots` and read by
|
||||
`handleGetInteractiveDiagnosticsRequest`.
|
||||
-/
|
||||
diagnosticsRef : IO.Ref (Array Widget.InteractiveDiagnostic)
|
||||
|
||||
/-- `EditableDocumentCore` with reporter task. -/
|
||||
structure EditableDocument extends EditableDocumentCore where
|
||||
|
||||
@@ -152,9 +152,9 @@ def protocolOverview : Array MessageOverview := #[
|
||||
.notification {
|
||||
method := "textDocument/publishDiagnostics"
|
||||
direction := .serverToClient
|
||||
kinds := #[.extendedParameterType #[``PublishDiagnosticsParams.isIncremental?, ``PublishDiagnosticsParams.diagnostics, ``DiagnosticWith.fullRange?, ``DiagnosticWith.isSilent?, ``DiagnosticWith.leanTags?]]
|
||||
kinds := #[.extendedParameterType #[``PublishDiagnosticsParams.diagnostics, ``DiagnosticWith.fullRange?, ``DiagnosticWith.isSilent?, ``DiagnosticWith.leanTags?]]
|
||||
parameterType := PublishDiagnosticsParams
|
||||
description := "Emitted by the language server whenever a new set of diagnostics becomes available for a file. Unlike most language servers, the Lean language server emits this notification incrementally while processing the file, not only when the full file has been processed. If the client sets `LeanClientCapabilities.incrementalDiagnosticSupport` and `isIncremental` is `true`, the diagnostics in the notification should be appended to the existing diagnostics for the same document version rather than replacing them."
|
||||
description := "Emitted by the language server whenever a new set of diagnostics becomes available for a file. Unlike most language servers, the Lean language server emits this notification incrementally while processing the file, not only when the full file has been processed."
|
||||
},
|
||||
.notification {
|
||||
method := "$/lean/fileProgress"
|
||||
|
||||
@@ -723,7 +723,6 @@ partial def main (args : List String) : IO Unit := do
|
||||
}
|
||||
}
|
||||
lean? := some {
|
||||
incrementalDiagnosticSupport? := some true
|
||||
silentDiagnosticSupport? := some true
|
||||
rpcWireFormat? := some .v1
|
||||
}
|
||||
|
||||
@@ -133,14 +133,12 @@ def foldDocumentChanges (changes : Array Lsp.TextDocumentContentChangeEvent) (ol
|
||||
changes.foldl applyDocumentChange oldText
|
||||
|
||||
/-- Constructs a `textDocument/publishDiagnostics` notification. -/
|
||||
def mkPublishDiagnosticsNotification (m : DocumentMeta) (diagnostics : Array Lsp.Diagnostic)
|
||||
(isIncremental : Option Bool := none) :
|
||||
def mkPublishDiagnosticsNotification (m : DocumentMeta) (diagnostics : Array Lsp.Diagnostic) :
|
||||
JsonRpc.Notification Lsp.PublishDiagnosticsParams where
|
||||
method := "textDocument/publishDiagnostics"
|
||||
param := {
|
||||
uri := m.uri
|
||||
version? := some m.version
|
||||
isIncremental? := isIncremental
|
||||
diagnostics := diagnostics
|
||||
}
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ open Std.DHashMap.Internal
|
||||
|
||||
namespace Std.DHashMap.Raw
|
||||
|
||||
def instDecidableEquiv {α : Type u} {β : α → Type v} [BEq α] [LawfulBEq α] [Hashable α] [∀ k, BEq (β k)] [∀ k, LawfulBEq (β k)] {m₁ m₂ : Raw α β} (h₁ : m₁.WF) (h₂ : m₂.WF) : Decidable (m₁ ~m m₂) :=
|
||||
instance instDecidableEquiv {α : Type u} {β : α → Type v} [BEq α] [LawfulBEq α] [Hashable α] [∀ k, BEq (β k)] [∀ k, LawfulBEq (β k)] {m₁ m₂ : Raw α β} (h₁ : m₁.WF) (h₂ : m₂.WF) : Decidable (m₁ ~m m₂) :=
|
||||
Raw₀.decidableEquiv ⟨m₁, h₁.size_buckets_pos⟩ ⟨m₂, h₂.size_buckets_pos⟩ h₁ h₂
|
||||
|
||||
end Std.DHashMap.Raw
|
||||
|
||||
@@ -19,7 +19,7 @@ open Std.DTreeMap.Internal.Impl
|
||||
|
||||
namespace Std.DTreeMap.Raw
|
||||
|
||||
def instDecidableEquiv {α : Type u} {β : α → Type v} {cmp : α → α → Ordering} [TransCmp cmp] [LawfulEqCmp cmp] [∀ k, BEq (β k)] [∀ k, LawfulBEq (β k)] {t₁ t₂ : Raw α β cmp} (h₁ : t₁.WF) (h₂ : t₂.WF) : Decidable (t₁ ~m t₂) :=
|
||||
instance instDecidableEquiv {α : Type u} {β : α → Type v} {cmp : α → α → Ordering} [TransCmp cmp] [LawfulEqCmp cmp] [∀ k, BEq (β k)] [∀ k, LawfulBEq (β k)] {t₁ t₂ : Raw α β cmp} (h₁ : t₁.WF) (h₂ : t₂.WF) : Decidable (t₁ ~m t₂) :=
|
||||
let : Ord α := ⟨cmp⟩;
|
||||
let : Decidable (t₁.inner ~m t₂.inner) := decidableEquiv t₁.1 t₂.1 h₁ h₂;
|
||||
decidable_of_iff _ ⟨fun h => ⟨h⟩, fun h => h.1⟩
|
||||
|
||||
@@ -19,7 +19,7 @@ open Std.DHashMap.Raw
|
||||
|
||||
namespace Std.HashMap.Raw
|
||||
|
||||
def instDecidableEquiv {α : Type u} {β : Type v} [BEq α] [LawfulBEq α] [Hashable α] [BEq β] [LawfulBEq β] {m₁ m₂ : Raw α β} (h₁ : m₁.WF) (h₂ : m₂.WF) : Decidable (m₁ ~m m₂) :=
|
||||
instance instDecidableEquiv {α : Type u} {β : Type v} [BEq α] [LawfulBEq α] [Hashable α] [BEq β] [LawfulBEq β] {m₁ m₂ : Raw α β} (h₁ : m₁.WF) (h₂ : m₂.WF) : Decidable (m₁ ~m m₂) :=
|
||||
let : Decidable (m₁.1 ~m m₂.1) := DHashMap.Raw.instDecidableEquiv h₁.out h₂.out;
|
||||
decidable_of_iff _ ⟨fun h => ⟨h⟩, fun h => h.1⟩
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ open Std.HashMap.Raw
|
||||
|
||||
namespace Std.HashSet.Raw
|
||||
|
||||
def instDecidableEquiv {α : Type u} [BEq α] [LawfulBEq α] [Hashable α] {m₁ m₂ : Raw α} (h₁ : m₁.WF) (h₂ : m₂.WF) : Decidable (m₁ ~m m₂) :=
|
||||
instance instDecidableEquiv {α : Type u} [BEq α] [LawfulBEq α] [Hashable α] {m₁ m₂ : Raw α} (h₁ : m₁.WF) (h₂ : m₂.WF) : Decidable (m₁ ~m m₂) :=
|
||||
let : Decidable (m₁.1 ~m m₂.1) := HashMap.Raw.instDecidableEquiv h₁.out h₂.out;
|
||||
decidable_of_iff _ ⟨fun h => ⟨h⟩, fun h => h.1⟩
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ open Std.DTreeMap.Raw
|
||||
|
||||
namespace Std.TreeMap.Raw
|
||||
|
||||
def instDecidableEquiv {α : Type u} {β : Type v} {cmp : α → α → Ordering} [TransCmp cmp] [LawfulEqCmp cmp] [BEq β] [LawfulBEq β] {t₁ t₂ : Raw α β cmp} (h₁ : t₁.WF) (h₂ : t₂.WF) : Decidable (t₁ ~m t₂) :=
|
||||
instance instDecidableEquiv {α : Type u} {β : Type v} {cmp : α → α → Ordering} [TransCmp cmp] [LawfulEqCmp cmp] [BEq β] [LawfulBEq β] {t₁ t₂ : Raw α β cmp} (h₁ : t₁.WF) (h₂ : t₂.WF) : Decidable (t₁ ~m t₂) :=
|
||||
let : Ord α := ⟨cmp⟩;
|
||||
let : Decidable (t₁.inner ~m t₂.inner) := DTreeMap.Raw.instDecidableEquiv h₁ h₂;
|
||||
decidable_of_iff _ ⟨fun h => ⟨h⟩, fun h => h.1⟩
|
||||
|
||||
@@ -20,7 +20,7 @@ open Std.TreeMap.Raw
|
||||
|
||||
namespace Std.TreeSet.Raw
|
||||
|
||||
def instDecidableEquiv {α : Type u} {cmp : α → α → Ordering} [TransCmp cmp] [LawfulEqCmp cmp] {t₁ t₂ : Raw α cmp} (h₁ : t₁.WF) (h₂ : t₂.WF) : Decidable (t₁ ~m t₂) :=
|
||||
instance instDecidableEquiv {α : Type u} {cmp : α → α → Ordering} [TransCmp cmp] [LawfulEqCmp cmp] {t₁ t₂ : Raw α cmp} (h₁ : t₁.WF) (h₂ : t₂.WF) : Decidable (t₁ ~m t₂) :=
|
||||
let : Ord α := ⟨cmp⟩;
|
||||
let : Decidable (t₁.inner ~m t₂.inner) := TreeMap.Raw.instDecidableEquiv h₁ h₂;
|
||||
decidable_of_iff _ ⟨fun h => ⟨h⟩, fun h => h.1⟩
|
||||
|
||||
@@ -178,6 +178,48 @@ theorem entails_pure_elim_cons {σ : Type u} [Inhabited σ] (P Q : Prop) : entai
|
||||
@[simp] theorem entails_4 {P Q : SPred [σ₁, σ₂, σ₃, σ₄]} : SPred.entails P Q ↔ (∀ s₁ s₂ s₃ s₄, (P s₁ s₂ s₃ s₄).down → (Q s₁ s₂ s₃ s₄).down) := iff_of_eq rfl
|
||||
@[simp] theorem entails_5 {P Q : SPred [σ₁, σ₂, σ₃, σ₄, σ₅]} : SPred.entails P Q ↔ (∀ s₁ s₂ s₃ s₄ s₅, (P s₁ s₂ s₃ s₄ s₅).down → (Q s₁ s₂ s₃ s₄ s₅).down) := iff_of_eq rfl
|
||||
|
||||
/-!
|
||||
# `SPred.evalsTo`
|
||||
|
||||
Relates a stateful value `SVal σs α` to a pure value `a : α`, lifting equality through the state.
|
||||
-/
|
||||
|
||||
/-- Relates a stateful value to a pure value, lifting equality through the state layers. -/
|
||||
def evalsTo {α : Type u} {σs : List (Type u)} (f : SVal σs α) (a : α) : SPred σs :=
|
||||
match σs with
|
||||
| [] => ⌜a = f⌝
|
||||
| _ :: _ => fun s => evalsTo (f s) a
|
||||
|
||||
@[simp, grind =] theorem evalsTo_nil {f : SVal [] α} {a : α} :
|
||||
evalsTo f a = ⌜a = f⌝ := rfl
|
||||
|
||||
theorem evalsTo_cons {f : σ → SVal σs α} {a : α} {s : σ} :
|
||||
evalsTo (σs := σ::σs) f a s = evalsTo (f s) a := rfl
|
||||
|
||||
@[simp, grind =] theorem evalsTo_1 {f : SVal [σ] α} {a : α} {s : σ} :
|
||||
evalsTo f a s = evalsTo (f s) a := rfl
|
||||
|
||||
@[simp, grind =] theorem evalsTo_2 {f : SVal [σ₁, σ₂] α} {a : α} {s₁ : σ₁} {s₂ : σ₂} :
|
||||
evalsTo f a s₁ s₂ = evalsTo (f s₁ s₂) a := rfl
|
||||
|
||||
@[simp, grind =] theorem evalsTo_3 {f : SVal [σ₁, σ₂, σ₃] α} {a : α}
|
||||
{s₁ : σ₁} {s₂ : σ₂} {s₃ : σ₃} :
|
||||
evalsTo f a s₁ s₂ s₃ = evalsTo (f s₁ s₂ s₃) a := rfl
|
||||
|
||||
@[simp, grind =] theorem evalsTo_4 {f : SVal [σ₁, σ₂, σ₃, σ₄] α} {a : α}
|
||||
{s₁ : σ₁} {s₂ : σ₂} {s₃ : σ₃} {s₄ : σ₄} :
|
||||
evalsTo f a s₁ s₂ s₃ s₄ = evalsTo (f s₁ s₂ s₃ s₄) a := rfl
|
||||
|
||||
@[simp, grind =] theorem evalsTo_5 {f : SVal [σ₁, σ₂, σ₃, σ₄, σ₅] α} {a : α}
|
||||
{s₁ : σ₁} {s₂ : σ₂} {s₃ : σ₃} {s₄ : σ₄} {s₅ : σ₅} :
|
||||
evalsTo f a s₁ s₂ s₃ s₄ s₅ = evalsTo (f s₁ s₂ s₃ s₄ s₅) a := rfl
|
||||
|
||||
theorem evalsTo_total {P : SPred σs} (f : SVal σs α) :
|
||||
P ⊢ₛ ∃ m, evalsTo f m := by
|
||||
induction σs with
|
||||
| nil => simp
|
||||
| cons _ _ ih => intro s; apply ih
|
||||
|
||||
/-! # Tactic support -/
|
||||
|
||||
namespace Tactic
|
||||
@@ -338,3 +380,4 @@ theorem Frame.frame {P Q T : SPred σs} {φ : Prop} [HasFrame P Q φ]
|
||||
· exact HasFrame.reassoc.mp.trans SPred.and_elim_r
|
||||
· intro hp
|
||||
exact HasFrame.reassoc.mp.trans (SPred.and_elim_l' (h hp))
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ module
|
||||
|
||||
prelude
|
||||
public import Std.Do.SPred.Notation
|
||||
import Init.PropLemmas
|
||||
|
||||
@[expose] public section
|
||||
|
||||
@@ -157,3 +158,17 @@ theorem imp_curry {P Q : SVal.StateTuple σs → Prop} : (SVal.curry (fun t =>
|
||||
induction σs
|
||||
case nil => rfl
|
||||
case cons σ σs ih => intro s; simp only [imp_cons, SVal.curry_cons]; exact ih
|
||||
|
||||
/-! # Prop-indexed quantifiers -/
|
||||
|
||||
/-- Simplifies an existential over a true proposition. -/
|
||||
theorem exists_prop_of_true {p : Prop} (h : p) {P : p → SPred σs} : spred(∃ (h : p), P h) = P h := by
|
||||
induction σs with
|
||||
| nil => ext; exact _root_.exists_prop_of_true h
|
||||
| cons σ σs ih => ext; exact ih
|
||||
|
||||
/-- Simplifies a universal over a true proposition. -/
|
||||
theorem forall_prop_of_true {p : Prop} (h : p) {P : p → SPred σs} : spred(∀ (h : p), P h) = P h := by
|
||||
induction σs with
|
||||
| nil => ext; exact _root_.forall_prop_of_true h
|
||||
| cons σ σs ih => ext; exact ih
|
||||
|
||||
94
src/Std/Do/Triple/RepeatSpec.lean
Normal file
94
src/Std/Do/Triple/RepeatSpec.lean
Normal file
@@ -0,0 +1,94 @@
|
||||
/-
|
||||
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Sebastian Graf
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Std.Do.Triple.SpecLemmas
|
||||
import Std.Tactic.Do.Syntax
|
||||
|
||||
set_option linter.missingDocs true
|
||||
|
||||
@[expose] public section
|
||||
|
||||
namespace Std.Do
|
||||
|
||||
/-!
|
||||
# Specification theorem for `Loop`-based `repeat`/`while` loops
|
||||
|
||||
This file contains the `@[spec]` theorem for `forIn` over `Lean.Loop`, which enables
|
||||
verified reasoning about `repeat`/`while` loops using `mvcgen`.
|
||||
-/
|
||||
|
||||
set_option mvcgen.warning false
|
||||
|
||||
/-- A variant (termination measure) for a `repeat`/`while` loop. -/
|
||||
@[spec_invariant_type]
|
||||
abbrev RepeatVariant (β : Type u) (ps : PostShape.{u}) := β → SVal ps.args (ULift Nat)
|
||||
|
||||
set_option linter.missingDocs false in
|
||||
abbrev RepeatVariant.eval {β ps} (variant : RepeatVariant β ps) (b : β) (n : Nat) :=
|
||||
SPred.evalsTo (variant b) ⟨n⟩
|
||||
|
||||
/-- An invariant for a `repeat`/`while` loop. -/
|
||||
@[spec_invariant_type]
|
||||
abbrev RepeatInvariant β ps := PostCond (Bool × β) ps
|
||||
|
||||
section
|
||||
|
||||
variable {β : Type u} {m : Type u → Type v} {ps : PostShape.{u}}
|
||||
|
||||
private theorem RepeatVariant.eval_total {P : SPred ps.args} (variant : RepeatVariant β ps) (b : β) :
|
||||
P ⊢ₛ ∃ m, RepeatVariant.eval variant b m := by
|
||||
mintro _
|
||||
mhave h2 := SPred.evalsTo_total (variant b)
|
||||
mcases h2 with ⟨m, h2⟩
|
||||
mexists m.down
|
||||
|
||||
private theorem RepeatVariant.add_eval {P Q : SPred ps.args} (variant : RepeatVariant β ps) (b : β)
|
||||
(h : spred(∃ m, RepeatVariant.eval variant b m ∧ P) ⊢ₛ Q) : P ⊢ₛ Q := by
|
||||
apply SPred.entails.trans _ h
|
||||
mintro _
|
||||
mhave h2 := RepeatVariant.eval_total variant b
|
||||
mcases h2 with ⟨m, h2⟩
|
||||
mexists m
|
||||
mconstructor <;> massumption
|
||||
|
||||
end
|
||||
|
||||
section
|
||||
|
||||
variable {β : Type u} {m : Type u → Type v} {ps : PostShape.{u}}
|
||||
variable [Monad m] [Lean.Order.MonadTail m] [WPMonad m ps]
|
||||
|
||||
@[spec]
|
||||
theorem Spec.forIn_loop
|
||||
{l : _root_.Lean.Loop} {init : β} {f : Unit → β → m (ForInStep β)}
|
||||
(measure : RepeatVariant β ps)
|
||||
(inv : RepeatInvariant β ps)
|
||||
(step : ∀ b mb,
|
||||
Triple (f () b)
|
||||
spred(RepeatVariant.eval measure b mb ∧ inv.1 (false, b))
|
||||
(fun r => match r with
|
||||
| .yield b' => spred(∃ mb', RepeatVariant.eval measure b' mb' ∧ ⌜mb' < mb⌝ ∧ inv.1 (false, b'))
|
||||
| .done b' => inv.1 (true, b'), inv.2)) :
|
||||
Triple (forIn l init f) spred(inv.1 (false, init)) (fun b => inv.1 (true, b), inv.2) := by
|
||||
haveI : Nonempty β := ⟨init⟩
|
||||
simp only [forIn]
|
||||
apply RepeatVariant.add_eval measure init
|
||||
apply SPred.exists_elim
|
||||
intro minit
|
||||
induction minit using Nat.strongRecOn generalizing init with
|
||||
| _ minit ih =>
|
||||
rw [_root_.Lean.Loop.forIn_eq]
|
||||
mvcgen [step, ih] with
|
||||
| vc2 =>
|
||||
mrename_i h
|
||||
mcases h with ⟨mb', ⟨hmeasure, ⌜hmb'⌝, h⟩⟩
|
||||
mspec Triple.of_entails_wp (ih mb' hmb')
|
||||
|
||||
end
|
||||
|
||||
end Std.Do
|
||||
@@ -10,6 +10,7 @@ public import Std.Do.Triple.Basic
|
||||
public import Init.Data.Range.Polymorphic.Iterators
|
||||
import Init.Data.Range.Polymorphic
|
||||
public import Init.Data.Slice.Array
|
||||
public import Init.While
|
||||
|
||||
-- This public import is a workaround for #10652.
|
||||
-- Without it, adding the `spec` attribute for `instMonadLiftTOfMonadLift` will fail.
|
||||
|
||||
@@ -132,8 +132,6 @@ partial def Selectable.one (selectables : Array (Selectable α)) : Async α := d
|
||||
let gen := mkStdGen seed
|
||||
let selectables := shuffleIt selectables gen
|
||||
|
||||
let gate ← IO.Promise.new
|
||||
|
||||
for selectable in selectables do
|
||||
if let some val ← selectable.selector.tryFn then
|
||||
let result ← selectable.cont val
|
||||
@@ -143,14 +141,11 @@ partial def Selectable.one (selectables : Array (Selectable α)) : Async α := d
|
||||
let promise ← IO.Promise.new
|
||||
|
||||
for selectable in selectables do
|
||||
if ← finished.get then
|
||||
break
|
||||
|
||||
let waiterPromise ← IO.Promise.new
|
||||
let waiter := Waiter.mk finished waiterPromise
|
||||
selectable.selector.registerFn waiter
|
||||
|
||||
discard <| IO.bindTask (t := waiterPromise.result?) (sync := true) fun res? => do
|
||||
discard <| IO.bindTask (t := waiterPromise.result?) fun res? => do
|
||||
match res? with
|
||||
| none =>
|
||||
/-
|
||||
@@ -162,20 +157,18 @@ partial def Selectable.one (selectables : Array (Selectable α)) : Async α := d
|
||||
let async : Async _ :=
|
||||
try
|
||||
let res ← IO.ofExcept res
|
||||
discard <| await gate.result?
|
||||
|
||||
for selectable in selectables do
|
||||
selectable.selector.unregisterFn
|
||||
|
||||
promise.resolve (.ok (← selectable.cont res))
|
||||
let contRes ← selectable.cont res
|
||||
promise.resolve (.ok contRes)
|
||||
catch e =>
|
||||
promise.resolve (.error e)
|
||||
|
||||
async.toBaseIO
|
||||
|
||||
gate.resolve ()
|
||||
let result ← Async.ofPromise (pure promise)
|
||||
return result
|
||||
Async.ofPromise (pure promise)
|
||||
|
||||
/--
|
||||
Performs fair and data-loss free non-blocking multiplexing on the `Selectable`s in `selectables`.
|
||||
@@ -231,8 +224,6 @@ def Selectable.combine (selectables : Array (Selectable α)) : IO (Selector α)
|
||||
let derivedWaiter := Waiter.mk waiter.finished waiterPromise
|
||||
selectable.selector.registerFn derivedWaiter
|
||||
|
||||
let barrier ← IO.Promise.new
|
||||
|
||||
discard <| IO.bindTask (t := waiterPromise.result?) fun res? => do
|
||||
match res? with
|
||||
| none => return (Task.pure (.ok ()))
|
||||
@@ -240,7 +231,6 @@ def Selectable.combine (selectables : Array (Selectable α)) : IO (Selector α)
|
||||
let async : Async _ := do
|
||||
let mainPromise := waiter.promise
|
||||
|
||||
await barrier
|
||||
for selectable in selectables do
|
||||
selectable.selector.unregisterFn
|
||||
|
||||
|
||||
@@ -6,189 +6,5 @@ Authors: Sofia Rodrigues
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Std.Internal.Http.Server
|
||||
public import Std.Internal.Http.Test.Helpers
|
||||
|
||||
public section
|
||||
|
||||
/-!
|
||||
# HTTP Library
|
||||
|
||||
A low-level HTTP/1.1 server implementation for Lean. This library provides a pure,
|
||||
sans-I/O protocol implementation that can be used with the `Async` library or with
|
||||
custom connection handlers.
|
||||
|
||||
## Overview
|
||||
|
||||
This module provides a complete HTTP/1.1 server implementation with support for:
|
||||
|
||||
- Request/response handling with directional streaming bodies
|
||||
- Keep-alive connections
|
||||
- Chunked transfer encoding
|
||||
- Header validation and management
|
||||
- Configurable timeouts and limits
|
||||
|
||||
**Sans I/O Architecture**: The core protocol logic doesn't perform any actual I/O itself -
|
||||
it just defines how data should be processed. This separation allows the protocol implementation
|
||||
to remain pure and testable, while different transports (TCP sockets, mock clients) handle
|
||||
the actual reading and writing of bytes.
|
||||
|
||||
## Quick Start
|
||||
|
||||
The main entry point is `Server.serve`, which starts an HTTP/1.1 server. Implement the
|
||||
`Server.Handler` type class to define how the server handles requests, errors, and
|
||||
`Expect: 100-continue` headers:
|
||||
|
||||
```lean
|
||||
import Std.Internal.Http
|
||||
|
||||
open Std Internal IO Async
|
||||
open Std Http Server
|
||||
|
||||
structure MyHandler
|
||||
|
||||
instance : Handler MyHandler where
|
||||
onRequest _ req := do
|
||||
Response.ok |>.text "Hello, World!"
|
||||
|
||||
def main : IO Unit := Async.block do
|
||||
let addr : Net.SocketAddress := .v4 ⟨.ofParts 127 0 0 1, 8080⟩
|
||||
let server ← Server.serve addr MyHandler.mk
|
||||
server.waitShutdown
|
||||
```
|
||||
|
||||
## Working with Requests
|
||||
|
||||
Incoming requests are represented by `Request Body.Stream`, which bundles the request
|
||||
line, parsed headers, and a lazily-consumed body. Headers are available immediately,
|
||||
while the body can be streamed or collected on demand, allowing handlers to process both
|
||||
small and large payloads efficiently.
|
||||
|
||||
### Reading Headers
|
||||
|
||||
```lean
|
||||
def handler (req : Request Body.Stream) : ContextAsync (Response Body.Stream) := do
|
||||
-- Access request method and URI
|
||||
let method := req.head.method -- Method.get, Method.post, etc.
|
||||
let uri := req.head.uri -- RequestTarget
|
||||
|
||||
-- Read a specific header
|
||||
if let some contentType := req.head.headers.get? (.mk "content-type") then
|
||||
IO.println s!"Content-Type: {contentType}"
|
||||
|
||||
Response.ok |>.text "OK"
|
||||
```
|
||||
|
||||
### URI Query Semantics
|
||||
|
||||
`RequestTarget.query` is parsed using form-style key/value conventions (`k=v&...`), and `+` is decoded as a
|
||||
space in query components. If you need RFC 3986 opaque query handling, use the raw request target string
|
||||
(`toString req.head.uri`) and parse it with custom logic.
|
||||
|
||||
### Reading the Request Body
|
||||
|
||||
The request body is exposed as `Body.Stream`, which can be consumed incrementally or
|
||||
collected into memory. The `readAll` method reads the entire body, with an optional size
|
||||
limit to protect against unbounded payloads.
|
||||
|
||||
```lean
|
||||
def handler (req : Request Body.Stream) : ContextAsync (Response Body.Stream) := do
|
||||
-- Collect entire body as a String
|
||||
let bodyStr : String ← req.body.readAll
|
||||
|
||||
-- Or with a maximum size limit
|
||||
let bodyStr : String ← req.body.readAll (maximumSize := some 1024)
|
||||
|
||||
Response.ok |>.text s!"Received: {bodyStr}"
|
||||
```
|
||||
|
||||
## Building Responses
|
||||
|
||||
Responses are constructed using a builder API that starts from a status code and adds
|
||||
headers and a body. Common helpers exist for text, HTML, JSON, and binary responses, while
|
||||
still allowing full control over status codes and header values.
|
||||
|
||||
Response builders produce `Async (Response Body.Stream)`.
|
||||
|
||||
```lean
|
||||
-- Text response
|
||||
Response.ok |>.text "Hello!"
|
||||
|
||||
-- HTML response
|
||||
Response.ok |>.html "<h1>Hello!</h1>"
|
||||
|
||||
-- JSON response
|
||||
Response.ok |>.json "{\"key\": \"value\"}"
|
||||
|
||||
-- Binary response
|
||||
Response.ok |>.bytes someByteArray
|
||||
|
||||
-- Custom status
|
||||
Response.new |>.status .created |>.text "Resource created"
|
||||
|
||||
-- With custom headers
|
||||
Response.ok
|
||||
|>.header! "X-Custom-Header" "value"
|
||||
|>.header! "Cache-Control" "no-cache"
|
||||
|>.text "Response with headers"
|
||||
```
|
||||
|
||||
### Streaming Responses
|
||||
|
||||
For large responses or server-sent events, use streaming:
|
||||
|
||||
```lean
|
||||
def handler (req : Request Body.Stream) : ContextAsync (Response Body.Stream) := do
|
||||
Response.ok
|
||||
|>.header! "Content-Type" "text/plain"
|
||||
|>.stream fun stream => do
|
||||
for i in [0:10] do
|
||||
stream.send { data := s!"chunk {i}\n".toUTF8 }
|
||||
Async.sleep 1000
|
||||
stream.close
|
||||
```
|
||||
|
||||
## Server Configuration
|
||||
|
||||
Configure server behavior with `Config`:
|
||||
|
||||
```lean
|
||||
def config : Config := {
|
||||
maxRequests := 10000000,
|
||||
lingeringTimeout := 5000,
|
||||
}
|
||||
|
||||
let server ← Server.serve addr MyHandler.mk config
|
||||
```
|
||||
|
||||
## Handler Type Class
|
||||
|
||||
Implement `Server.Handler` to define how the server processes events. The class has three
|
||||
methods, all with default implementations:
|
||||
|
||||
- `onRequest` — called for each incoming request; returns a response inside `ContextAsync`
|
||||
- `onFailure` — called when an error occurs while processing a request
|
||||
- `onContinue` — called when a request includes an `Expect: 100-continue` header; return
|
||||
`true` to accept the body or `false` to reject it
|
||||
|
||||
```lean
|
||||
structure MyHandler where
|
||||
greeting : String
|
||||
|
||||
instance : Handler MyHandler where
|
||||
onRequest self req := do
|
||||
Response.ok |>.text self.greeting
|
||||
|
||||
onFailure self err := do
|
||||
IO.eprintln s!"Error: {err}"
|
||||
```
|
||||
|
||||
The handler methods operate in the following monads:
|
||||
|
||||
- `onRequest` uses `ContextAsync` — an asynchronous monad (`ReaderT CancellationContext Async`) that provides:
|
||||
- Full access to `Async` operations (spawning tasks, sleeping, concurrent I/O)
|
||||
- A `CancellationContext` tied to the client connection — when the client disconnects, the
|
||||
context is cancelled, allowing your handler to detect this and stop work early
|
||||
- `onFailure` uses `Async`
|
||||
- `onContinue` uses `Async`
|
||||
-/
|
||||
public import Std.Internal.Http.Data
|
||||
public import Std.Internal.Http.Protocol.H1
|
||||
|
||||
@@ -48,12 +48,6 @@ structure Any where
|
||||
-/
|
||||
recvSelector : Selector (Option Chunk)
|
||||
|
||||
/--
|
||||
Non-blocking receive attempt. Returns `none` if no chunk is immediately available,
|
||||
`some (some chunk)` when a chunk is ready, or `some none` at end-of-stream.
|
||||
-/
|
||||
tryRecv : Async (Option (Option Chunk))
|
||||
|
||||
/--
|
||||
Returns the declared size.
|
||||
-/
|
||||
@@ -73,7 +67,6 @@ def ofBody [Http.Body α] (body : α) : Any where
|
||||
close := Http.Body.close body
|
||||
isClosed := Http.Body.isClosed body
|
||||
recvSelector := Http.Body.recvSelector body
|
||||
tryRecv := Http.Body.tryRecv body
|
||||
getKnownSize := Http.Body.getKnownSize body
|
||||
setKnownSize := Http.Body.setKnownSize body
|
||||
|
||||
@@ -84,7 +77,6 @@ instance : Http.Body Any where
|
||||
close := Any.close
|
||||
isClosed := Any.isClosed
|
||||
recvSelector := Any.recvSelector
|
||||
tryRecv := Any.tryRecv
|
||||
getKnownSize := Any.getKnownSize
|
||||
setKnownSize := Any.setKnownSize
|
||||
|
||||
|
||||
@@ -50,12 +50,6 @@ class Body (α : Type) where
|
||||
-/
|
||||
recvSelector : α → Selector (Option Chunk)
|
||||
|
||||
/--
|
||||
Non-blocking receive attempt. Returns `none` if no chunk is immediately available,
|
||||
`some (some chunk)` when a chunk is ready, or `some none` at end-of-stream.
|
||||
-/
|
||||
tryRecv (body : α) : Async (Option (Option Chunk))
|
||||
|
||||
/--
|
||||
Gets the declared size of the body.
|
||||
-/
|
||||
|
||||
@@ -52,13 +52,6 @@ Empty bodies are always closed for reading.
|
||||
def isClosed (_ : Empty) : Async Bool :=
|
||||
pure true
|
||||
|
||||
/--
|
||||
Non-blocking receive. Empty bodies are always at EOF.
|
||||
-/
|
||||
@[inline]
|
||||
def tryRecv (_ : Empty) : Async (Option (Option Chunk)) :=
|
||||
pure (some none)
|
||||
|
||||
/--
|
||||
Selector that immediately resolves with end-of-stream for an empty body.
|
||||
-/
|
||||
@@ -79,7 +72,6 @@ instance : Http.Body Empty where
|
||||
close := Empty.close
|
||||
isClosed := Empty.isClosed
|
||||
recvSelector := Empty.recvSelector
|
||||
tryRecv := Empty.tryRecv
|
||||
getKnownSize _ := pure (some <| .fixed 0)
|
||||
setKnownSize _ _ := pure ()
|
||||
|
||||
|
||||
@@ -100,14 +100,6 @@ def getKnownSize (full : Full) : Async (Option Body.Length) :=
|
||||
| none => pure (some (.fixed 0))
|
||||
| some data => pure (some (.fixed data.size))
|
||||
|
||||
/--
|
||||
Non-blocking receive. `Full` bodies are always in-memory, so data is always
|
||||
immediately available. Returns `some chunk` on first call, `some none` (EOF)
|
||||
once consumed or closed.
|
||||
-/
|
||||
def tryRecv (full : Full) : Async (Option (Option Chunk)) := do
|
||||
return some (← full.state.atomically takeChunk)
|
||||
|
||||
/--
|
||||
Selector that immediately resolves to the remaining chunk (or EOF).
|
||||
-/
|
||||
@@ -136,7 +128,6 @@ instance : Http.Body Full where
|
||||
close := Full.close
|
||||
isClosed := Full.isClosed
|
||||
recvSelector := Full.recvSelector
|
||||
tryRecv := Full.tryRecv
|
||||
getKnownSize := Full.getKnownSize
|
||||
setKnownSize _ _ := pure ()
|
||||
|
||||
|
||||
@@ -227,19 +227,6 @@ def tryRecv (stream : Stream) : Async (Option Chunk) :=
|
||||
Channel.pruneFinishedWaiters
|
||||
Channel.tryRecv'
|
||||
|
||||
/--
|
||||
Non-blocking receive for the `Body` typeclass. Returns `none` when no producer is
|
||||
waiting and the channel is still open, `some (some chunk)` when data is ready,
|
||||
or `some none` at end-of-stream (channel closed with no pending producer).
|
||||
-/
|
||||
def tryRecvBody (stream : Stream) : Async (Option (Option Chunk)) :=
|
||||
stream.state.atomically do
|
||||
Channel.pruneFinishedWaiters
|
||||
if ← Channel.recvReady' then
|
||||
return some (← Channel.tryRecv')
|
||||
else
|
||||
return none
|
||||
|
||||
private def recv' (stream : Stream) : BaseIO (AsyncTask (Option Chunk)) := do
|
||||
stream.state.atomically do
|
||||
Channel.pruneFinishedWaiters
|
||||
@@ -611,7 +598,6 @@ instance : Http.Body Stream where
|
||||
close := Stream.close
|
||||
isClosed := Stream.isClosed
|
||||
recvSelector := Stream.recvSelector
|
||||
tryRecv := Stream.tryRecvBody
|
||||
getKnownSize := Stream.getKnownSize
|
||||
setKnownSize := Stream.setKnownSize
|
||||
|
||||
|
||||
@@ -156,17 +156,6 @@ end Chunk.ExtensionValue
|
||||
/--
|
||||
Represents a chunk of data with optional extensions (key-value pairs).
|
||||
|
||||
The interpretation of a chunk depends on the protocol layer consuming it:
|
||||
|
||||
- HTTP/1.1: The zero-size wire encoding (`0\r\n\r\n`) is reserved
|
||||
exclusively as the `last-chunk` terminator. The HTTP/1.1 writer silently discards
|
||||
any empty chunk (including its extensions) rather than emitting a premature
|
||||
end-of-body signal. `Encode.encode` on a `Chunk.empty` value does produce
|
||||
`"0\r\n\r\n"`, but that path bypasses the writer's framing logic.
|
||||
|
||||
- HTTP/2 (not yet implemented): Chunked transfer encoding does not exist; HTTP/2 uses DATA
|
||||
frames instead. This type is specific to the HTTP/1.1 wire format.
|
||||
|
||||
Reference: https://httpwg.org/specs/rfc9112.html#rfc.section.7.1
|
||||
-/
|
||||
structure Chunk where
|
||||
@@ -212,7 +201,7 @@ def toString? (chunk : Chunk) : Option String :=
|
||||
instance : Encode .v11 Chunk where
|
||||
encode buffer chunk :=
|
||||
let chunkLen := chunk.data.size
|
||||
let exts := chunk.extensions.foldl (fun acc (name, value) =>
|
||||
let exts := chunk.extensions.foldl (fun acc (name, value) =>
|
||||
acc ++ ";" ++ name.value ++ (value.elim "" (fun x => "=" ++ x.quote))) ""
|
||||
let size := Nat.toDigits 16 chunkLen |>.toArray |>.map Char.toUInt8 |> ByteArray.mk
|
||||
buffer.append #[size, exts.toUTF8, "\r\n".toUTF8, chunk.data, "\r\n".toUTF8]
|
||||
|
||||
@@ -78,9 +78,7 @@ namespace ContentLength
|
||||
Parses a content length header value.
|
||||
-/
|
||||
def parse (v : Value) : Option ContentLength :=
|
||||
let s := v.value
|
||||
if s.isEmpty || !s.all Char.isDigit then none
|
||||
else s.toNat?.map (.mk)
|
||||
v.value.toNat?.map (.mk)
|
||||
|
||||
/--
|
||||
Serializes a content length header back to a name-value pair.
|
||||
|
||||
@@ -703,37 +703,22 @@ private def writeHead (messageHead : Message.Head dir.swap) (machine : Machine d
|
||||
let machine := machine.updateKeepAlive shouldKeepAlive
|
||||
let size := Writer.determineTransferMode machine.writer
|
||||
|
||||
-- RFC 7230 §3.3.1: HTTP/1.0 does not support Transfer-Encoding. When the
|
||||
-- response body length is unknown (chunked mode), fall back to connection-close
|
||||
-- framing: disable keep-alive and write raw bytes (no chunk encoding, no TE header).
|
||||
let machine :=
|
||||
match dir, machine.reader.messageHead.version, size with
|
||||
| .receiving, Version.v10, .chunked => machine.disableKeepAlive
|
||||
| _, _, _ => machine
|
||||
|
||||
let headers := messageHead.headers
|
||||
|
||||
-- Add identity header based on direction. handler wins if it already set one.
|
||||
-- Add identity header based on direction
|
||||
let headers :=
|
||||
let identityOpt := machine.config.agentName
|
||||
match dir, identityOpt with
|
||||
| .receiving, some server =>
|
||||
if headers.contains Header.Name.server then headers
|
||||
else headers.insert Header.Name.server server
|
||||
| .sending, some userAgent =>
|
||||
if headers.contains Header.Name.userAgent then headers
|
||||
else headers.insert Header.Name.userAgent userAgent
|
||||
| .receiving, some server => headers.insert Header.Name.server server
|
||||
| .sending, some userAgent => headers.insert Header.Name.userAgent userAgent
|
||||
| _, none => headers
|
||||
|
||||
-- Add Connection header based on keep-alive state and protocol version.
|
||||
-- Erase any handler-supplied value first to avoid duplicate or conflicting
|
||||
-- Connection headers on the wire.
|
||||
let headers := headers.erase Header.Name.connection
|
||||
|
||||
-- Add Connection header based on keep-alive state and protocol version
|
||||
let headers :=
|
||||
if !machine.keepAlive then
|
||||
if !machine.keepAlive ∧ !headers.hasEntry Header.Name.connection (.mk "close") then
|
||||
headers.insert Header.Name.connection (.mk "close")
|
||||
else if machine.reader.messageHead.version == .v10 then
|
||||
else if machine.keepAlive ∧ machine.reader.messageHead.version == .v10
|
||||
∧ !headers.hasEntry Header.Name.connection (.mk "keep-alive") then
|
||||
-- RFC 2616 §19.7.1: HTTP/1.0 keep-alive responses must echo Connection: keep-alive
|
||||
headers.insert Header.Name.connection (.mk "keep-alive")
|
||||
else
|
||||
@@ -744,29 +729,18 @@ private def writeHead (messageHead : Message.Head dir.swap) (machine : Machine d
|
||||
let headers :=
|
||||
match dir, messageHead with
|
||||
| .receiving, messageHead =>
|
||||
if responseForbidsFramingHeaders messageHead.status ∨ messageHead.status == .notModified then
|
||||
headers
|
||||
|>.erase Header.Name.contentLength
|
||||
|>.erase Header.Name.transferEncoding
|
||||
else if machine.reader.messageHead.version == .v10 && size == .chunked then
|
||||
-- RFC 7230 §3.3.1: connection-close framing for HTTP/1.0 — strip all framing
|
||||
-- headers so neither Content-Length nor Transfer-Encoding appears on the wire.
|
||||
headers
|
||||
|>.erase Header.Name.contentLength
|
||||
|>.erase Header.Name.transferEncoding
|
||||
if responseForbidsFramingHeaders messageHead.status then
|
||||
headers.erase Header.Name.contentLength |>.erase Header.Name.transferEncoding
|
||||
else if messageHead.status == .notModified then
|
||||
-- `304` carries no body; keep explicit Content-Length metadata if the
|
||||
-- user supplied it, but never keep Transfer-Encoding.
|
||||
headers.erase Header.Name.transferEncoding
|
||||
else
|
||||
normalizeFramingHeaders headers size
|
||||
| .sending, _ =>
|
||||
normalizeFramingHeaders headers size
|
||||
|
||||
let state : Writer.State :=
|
||||
match size with
|
||||
| .fixed n => .writingBodyFixed n
|
||||
| .chunked =>
|
||||
-- RFC 7230 §3.3.1: HTTP/1.0 server-side uses connection-close framing (no chunk framing).
|
||||
match dir, machine.reader.messageHead.version with
|
||||
| .receiving, .v10 => .writingBodyClosingFrame
|
||||
| _, _ => .writingBodyChunked
|
||||
let state := Writer.State.writingBody size
|
||||
|
||||
machine.modifyWriter (fun writer => {
|
||||
writer with
|
||||
@@ -917,13 +891,6 @@ def send (machine : Machine dir) (message : Message.Head dir.swap) : Machine dir
|
||||
| .receiving => message.status.isInformational
|
||||
| .sending => false
|
||||
if isInterim then
|
||||
-- RFC 9110 §15.2: 1xx responses MUST NOT carry a body, so framing headers
|
||||
-- are meaningless and must not be forwarded even if the handler set them.
|
||||
let message := Message.Head.setHeaders message
|
||||
<| message.headers
|
||||
|>.erase Header.Name.contentLength
|
||||
|>.erase Header.Name.transferEncoding
|
||||
|
||||
machine.modifyWriter (fun w => {
|
||||
w with outputData := Encode.encode (v := .v11) w.outputData message
|
||||
})
|
||||
@@ -1124,11 +1091,11 @@ private partial def processFixedBufferedBody (machine : Machine dir) (n : Nat) :
|
||||
if writer.userClosedBody then
|
||||
completeWriterMessage machine
|
||||
else
|
||||
machine.setWriterState (.writingBodyFixed 0)
|
||||
machine.setWriterState (.writingBody (.fixed 0))
|
||||
else
|
||||
closeOnBadMessage machine
|
||||
else
|
||||
machine.setWriterState (.writingBodyFixed remaining)
|
||||
machine.setWriterState (.writingBody (.fixed remaining))
|
||||
|
||||
/--
|
||||
Handles fixed-length writer state when no user bytes are currently buffered.
|
||||
@@ -1160,28 +1127,20 @@ private partial def processFixedBody (machine : Machine dir) (n : Nat) : Machine
|
||||
processFixedIdleBody machine
|
||||
|
||||
/--
|
||||
Processes chunked transfer-encoding output (HTTP/1.1).
|
||||
Processes chunked transfer-encoding output.
|
||||
|
||||
Writes buffered chunks when available, writes terminal `0\\r\\n\\r\\n` on
|
||||
producer close, and supports omitted-body completion.
|
||||
-/
|
||||
private partial def processChunkedBody (machine : Machine dir) : Machine dir :=
|
||||
if machine.writer.omitBody then
|
||||
completeOmittedBody machine
|
||||
else if machine.writer.userClosedBody then
|
||||
machine.modifyWriter Writer.writeFinalChunk |> completeWriterMessage
|
||||
machine.modifyWriter Writer.writeFinalChunk
|
||||
|> completeWriterMessage
|
||||
else if machine.writer.userData.size > 0 then
|
||||
machine.modifyWriter Writer.writeChunkedBody |> processWrite
|
||||
else
|
||||
machine
|
||||
|
||||
/--
|
||||
Processes connection-close body output (HTTP/1.0 server, unknown body length).
|
||||
-/
|
||||
private partial def processClosingFrameBody (machine : Machine dir) : Machine dir :=
|
||||
if machine.writer.omitBody then
|
||||
completeOmittedBody machine
|
||||
else if machine.writer.userClosedBody then
|
||||
machine.modifyWriter Writer.writeRawBody |> completeWriterMessage
|
||||
else if machine.writer.userData.size > 0 then
|
||||
machine.modifyWriter Writer.writeRawBody |> processWrite
|
||||
machine.modifyWriter Writer.writeChunkedBody
|
||||
|> processWrite
|
||||
else
|
||||
machine
|
||||
|
||||
@@ -1215,12 +1174,10 @@ partial def processWrite (machine : Machine dir) : Machine dir :=
|
||||
|> processWrite
|
||||
else
|
||||
machine
|
||||
| .writingBodyFixed n =>
|
||||
| .writingBody (.fixed n) =>
|
||||
processFixedBody machine n
|
||||
| .writingBodyChunked =>
|
||||
| .writingBody .chunked =>
|
||||
processChunkedBody machine
|
||||
| .writingBodyClosingFrame =>
|
||||
processClosingFrameBody machine
|
||||
| .complete =>
|
||||
processCompleteStep machine
|
||||
| .closed =>
|
||||
|
||||
@@ -65,14 +65,6 @@ def Message.Head.headers (m : Message.Head dir) : Headers :=
|
||||
| .receiving => Request.Head.headers m
|
||||
| .sending => Response.Head.headers m
|
||||
|
||||
/--
|
||||
Returns a copy of the message head with the headers replaced.
|
||||
-/
|
||||
def Message.Head.setHeaders (m : Message.Head dir) (headers : Headers) : Message.Head dir :=
|
||||
match dir with
|
||||
| .receiving => { (m : Request.Head) with headers }
|
||||
| .sending => { (m : Response.Head) with headers }
|
||||
|
||||
/--
|
||||
Gets the version of a `Message`.
|
||||
-/
|
||||
@@ -90,7 +82,7 @@ def Message.Head.getSize (message : Message.Head dir) (allowEOFBody : Bool) : Op
|
||||
match message.headers.getAll? .transferEncoding with
|
||||
| none =>
|
||||
match contentLength with
|
||||
| some #[cl] => .fixed <$> (Header.ContentLength.parse cl |>.map (·.length))
|
||||
| some #[cl] => .fixed <$> cl.value.toNat?
|
||||
| some _ => none -- To avoid request smuggling with malformed/multiple content-length headers.
|
||||
| none => if allowEOFBody then some (.fixed 0) else none
|
||||
|
||||
|
||||
@@ -51,20 +51,9 @@ inductive Writer.State
|
||||
| waitingForFlush
|
||||
|
||||
/--
|
||||
Writing a fixed-length body; `n` is the number of bytes still to be sent.
|
||||
Writing the body output (either fixed-length or chunked).
|
||||
-/
|
||||
| writingBodyFixed (n : Nat)
|
||||
|
||||
/--
|
||||
Writing a chunked transfer-encoding body (HTTP/1.1).
|
||||
-/
|
||||
| writingBodyChunked
|
||||
|
||||
/--
|
||||
Writing a connection-close body (HTTP/1.0 server, unknown length).
|
||||
Raw bytes are written without chunk framing; the peer reads until the connection closes.
|
||||
-/
|
||||
| writingBodyClosingFrame
|
||||
| writingBody (mode : Body.Length)
|
||||
|
||||
/--
|
||||
Completed writing a single message and ready to begin the next one.
|
||||
@@ -173,9 +162,7 @@ def canAcceptData (writer : Writer dir) : Bool :=
|
||||
match writer.state with
|
||||
| .waitingHeaders => true
|
||||
| .waitingForFlush => true
|
||||
| .writingBodyFixed _
|
||||
| .writingBodyChunked
|
||||
| .writingBodyClosingFrame => !writer.userClosedBody
|
||||
| .writingBody _ => !writer.userClosedBody
|
||||
| _ => false
|
||||
|
||||
/--
|
||||
@@ -198,9 +185,6 @@ def determineTransferMode (writer : Writer dir) : Body.Length :=
|
||||
|
||||
/--
|
||||
Adds user data chunks to the writer's buffer if the writer can accept data.
|
||||
|
||||
Empty chunks (zero bytes of data) are accepted here but will be silently dropped
|
||||
during the chunked-encoding write step — see `writeChunkedBody`.
|
||||
-/
|
||||
@[inline]
|
||||
def addUserData (data : Array Chunk) (writer : Writer dir) : Writer dir :=
|
||||
@@ -239,14 +223,12 @@ def writeFixedBody (writer : Writer dir) (limitSize : Nat) : Writer dir × Nat :
|
||||
|
||||
/--
|
||||
Writes accumulated user data to output using chunked transfer encoding.
|
||||
|
||||
Empty chunks are silently discarded. See `Chunk.empty` for the protocol-level rationale.
|
||||
-/
|
||||
def writeChunkedBody (writer : Writer dir) : Writer dir :=
|
||||
if writer.userData.size = 0 then
|
||||
writer
|
||||
else
|
||||
let data := writer.userData.filter (fun c => !c.data.isEmpty)
|
||||
let data := writer.userData
|
||||
{ writer with userData := #[], userDataBytes := 0, outputData := data.foldl (Encode.encode .v11) writer.outputData }
|
||||
|
||||
/--
|
||||
@@ -259,15 +241,6 @@ def writeFinalChunk (writer : Writer dir) : Writer dir :=
|
||||
state := .complete
|
||||
}
|
||||
|
||||
/--
|
||||
Writes accumulated user data to output as raw bytes (HTTP/1.0 connection-close framing).
|
||||
No chunk framing is added; the peer reads until the connection closes.
|
||||
-/
|
||||
def writeRawBody (writer : Writer dir) : Writer dir :=
|
||||
{ writer with
|
||||
outputData := writer.userData.foldl (fun buf c => buf.write c.data) writer.outputData,
|
||||
userData := #[], userDataBytes := 0 }
|
||||
|
||||
/--
|
||||
Extracts all accumulated output data and returns it with a cleared output buffer.
|
||||
-/
|
||||
|
||||
@@ -1,188 +0,0 @@
|
||||
/-
|
||||
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Sofia Rodrigues
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Std.Internal.Async
|
||||
public import Std.Internal.Async.TCP
|
||||
public import Std.Sync.CancellationToken
|
||||
public import Std.Sync.Semaphore
|
||||
public import Std.Internal.Http.Server.Config
|
||||
public import Std.Internal.Http.Server.Handler
|
||||
public import Std.Internal.Http.Server.Connection
|
||||
|
||||
public section
|
||||
|
||||
/-!
|
||||
# HTTP Server
|
||||
|
||||
This module defines a simple, asynchronous HTTP/1.1 server implementation.
|
||||
|
||||
It provides the `Std.Http.Server` structure, which encapsulates all server state, and functions for
|
||||
starting, managing, and gracefully shutting down the server.
|
||||
|
||||
The server runs entirely using `Async` and uses a shared `CancellationContext` to signal shutdowns.
|
||||
Each active client connection is tracked, and the server automatically resolves its shutdown
|
||||
promise once all connections have closed.
|
||||
-/
|
||||
|
||||
namespace Std.Http
|
||||
open Std.Internal.IO.Async TCP
|
||||
|
||||
set_option linter.all true
|
||||
|
||||
/--
|
||||
The `Server` structure holds all state required to manage the lifecycle of an HTTP server, including
|
||||
connection tracking and shutdown coordination.
|
||||
-/
|
||||
structure Server where
|
||||
|
||||
/--
|
||||
The context used for shutting down all connections and the server.
|
||||
-/
|
||||
context : Std.CancellationContext
|
||||
|
||||
/--
|
||||
Active HTTP connections
|
||||
-/
|
||||
activeConnections : Std.Mutex Nat
|
||||
|
||||
/--
|
||||
Semaphore used to enforce the maximum number of simultaneous active connections.
|
||||
`none` means no connection limit.
|
||||
-/
|
||||
connectionLimit : Option Std.Semaphore
|
||||
|
||||
/--
|
||||
Indicates when the server has successfully shut down.
|
||||
-/
|
||||
shutdownPromise : Std.Channel Unit
|
||||
|
||||
/--
|
||||
Configuration of the server
|
||||
-/
|
||||
config : Std.Http.Config
|
||||
|
||||
namespace Server
|
||||
|
||||
/--
|
||||
Create a new `Server` structure with an optional configuration.
|
||||
-/
|
||||
def new (config : Std.Http.Config := {}) : IO Server := do
|
||||
let context ← Std.CancellationContext.new
|
||||
let activeConnections ← Std.Mutex.new 0
|
||||
let connectionLimit ←
|
||||
if config.maxConnections = 0 then
|
||||
pure none
|
||||
else
|
||||
some <$> Std.Semaphore.new config.maxConnections
|
||||
let shutdownPromise ← Std.Channel.new
|
||||
|
||||
return { context, activeConnections, connectionLimit, shutdownPromise, config }
|
||||
|
||||
/--
|
||||
Triggers cancellation of all requests and the accept loop in the server. This function should be used
|
||||
in conjunction with `waitShutdown` to properly coordinate the shutdown sequence.
|
||||
-/
|
||||
@[inline]
|
||||
def shutdown (s : Server) : Async Unit :=
|
||||
s.context.cancel .shutdown
|
||||
|
||||
/--
|
||||
Waits for the server to shut down. Blocks until another task or async operation calls the `shutdown` function.
|
||||
-/
|
||||
@[inline]
|
||||
def waitShutdown (s : Server) : Async Unit := do
|
||||
Async.ofAsyncTask ((← s.shutdownPromise.recv).map Except.ok)
|
||||
|
||||
/--
|
||||
Returns a `Selector` that waits for the server to shut down.
|
||||
-/
|
||||
@[inline]
|
||||
def waitShutdownSelector (s : Server) : Selector Unit :=
|
||||
s.shutdownPromise.recvSelector
|
||||
|
||||
/--
|
||||
Triggers cancellation of all requests and the accept loop, then waits for the server to fully shut down.
|
||||
This is a convenience function combining `shutdown` and then `waitShutdown`.
|
||||
-/
|
||||
@[inline]
|
||||
def shutdownAndWait (s : Server) : Async Unit := do
|
||||
s.context.cancel .shutdown
|
||||
s.waitShutdown
|
||||
|
||||
@[inline]
|
||||
private def frameCancellation (s : Server) (releaseConnectionPermit : Bool := false)
|
||||
(action : ContextAsync α) : ContextAsync α := do
|
||||
s.activeConnections.atomically (modify (· + 1))
|
||||
try
|
||||
action
|
||||
finally
|
||||
if releaseConnectionPermit then
|
||||
if let some limit := s.connectionLimit then
|
||||
limit.release
|
||||
|
||||
s.activeConnections.atomically do
|
||||
modify (· - 1)
|
||||
|
||||
if (← get) = 0 ∧ (← s.context.isCancelled) then
|
||||
discard <| s.shutdownPromise.send ()
|
||||
|
||||
/--
|
||||
Start a new HTTP/1.1 server on the given socket address. This function uses `Async` to handle tasks
|
||||
and TCP connections, and returns a `Server` structure that can be used to cancel the server.
|
||||
-/
|
||||
def serve {σ : Type} [Handler σ]
|
||||
(addr : Net.SocketAddress)
|
||||
(handler : σ)
|
||||
(config : Config := {}) (backlog : UInt32 := 1024) : Async Server := do
|
||||
|
||||
let httpServer ← Server.new config
|
||||
|
||||
let server ← Socket.Server.mk
|
||||
server.bind addr
|
||||
server.listen backlog
|
||||
server.noDelay
|
||||
|
||||
let runServer := do
|
||||
frameCancellation httpServer (action := do
|
||||
while true do
|
||||
let permitAcquired ←
|
||||
if let some limit := httpServer.connectionLimit then
|
||||
let permit ← limit.acquire
|
||||
await permit
|
||||
pure true
|
||||
else
|
||||
pure false
|
||||
|
||||
let result ← Selectable.one #[
|
||||
.case (server.acceptSelector) (fun x => pure <| some x),
|
||||
.case (← ContextAsync.doneSelector) (fun _ => pure none)
|
||||
]
|
||||
|
||||
match result with
|
||||
| some client =>
|
||||
let extensions ← do
|
||||
match (← EIO.toBaseIO client.getPeerName) with
|
||||
| .ok addr => pure <| Extensions.empty.insert (Server.RemoteAddr.mk addr)
|
||||
| .error _ => pure Extensions.empty
|
||||
|
||||
ContextAsync.background
|
||||
(frameCancellation httpServer (releaseConnectionPermit := permitAcquired)
|
||||
(action := do
|
||||
serveConnection client handler config extensions))
|
||||
| none =>
|
||||
if permitAcquired then
|
||||
if let some limit := httpServer.connectionLimit then
|
||||
limit.release
|
||||
break
|
||||
)
|
||||
|
||||
background (runServer httpServer.context)
|
||||
|
||||
return httpServer
|
||||
|
||||
end Std.Http.Server
|
||||
@@ -1,196 +0,0 @@
|
||||
/-
|
||||
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Sofia Rodrigues
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Std.Time
|
||||
public import Std.Internal.Http.Protocol.H1
|
||||
|
||||
public section
|
||||
|
||||
/-!
|
||||
# Config
|
||||
|
||||
This module exposes the `Config`, a structure that describes timeout, request and headers
|
||||
configuration of an HTTP Server.
|
||||
-/
|
||||
|
||||
namespace Std.Http
|
||||
|
||||
set_option linter.all true
|
||||
|
||||
/--
|
||||
Connection limits configuration with validation.
|
||||
-/
|
||||
structure Config where
|
||||
/--
|
||||
Maximum number of simultaneous active connections (default: 1024).
|
||||
Setting this to `0` disables the limit entirely: the server will accept any number of
|
||||
concurrent connections and no semaphore-based cap is enforced. Use with care — an
|
||||
unconstrained server may exhaust file descriptors or memory under adversarial load.
|
||||
-/
|
||||
maxConnections : Nat := 1024
|
||||
|
||||
/--
|
||||
Maximum number of requests per connection.
|
||||
-/
|
||||
maxRequests : Nat := 100
|
||||
|
||||
/--
|
||||
Maximum number of headers allowed per request.
|
||||
-/
|
||||
maxHeaders : Nat := 50
|
||||
|
||||
/--
|
||||
Maximum aggregate byte size of all header field lines in a single message
|
||||
(name + value bytes plus 4 bytes per line for `: ` and `\r\n`). Default: 64 KiB.
|
||||
-/
|
||||
maxHeaderBytes : Nat := 65536
|
||||
|
||||
/--
|
||||
Timeout (in milliseconds) for receiving additional data while a request is actively being
|
||||
processed (e.g. reading the request body). Applies after the request headers have been parsed
|
||||
and replaces the keep-alive timeout for the duration of the request.
|
||||
-/
|
||||
lingeringTimeout : Time.Millisecond.Offset := 10000
|
||||
|
||||
/--
|
||||
Timeout for keep-alive connections
|
||||
-/
|
||||
keepAliveTimeout : { x : Time.Millisecond.Offset // x > 0 } := ⟨12000, by decide⟩
|
||||
|
||||
/--
|
||||
Maximum time (in milliseconds) allowed to receive the complete request headers after the first
|
||||
byte of a new request arrives. This prevents Slowloris-style attacks where a client sends bytes
|
||||
at a slow rate to hold a connection slot open without completing a request. Once a request starts,
|
||||
each individual read must complete within this window. Default: 5 seconds.
|
||||
-/
|
||||
headerTimeout : Time.Millisecond.Offset := 5000
|
||||
|
||||
/--
|
||||
Whether to enable keep-alive connections by default.
|
||||
-/
|
||||
enableKeepAlive : Bool := true
|
||||
|
||||
/--
|
||||
The maximum size that the connection can receive in a single recv call.
|
||||
-/
|
||||
maximumRecvSize : Nat := 8192
|
||||
|
||||
/--
|
||||
Default buffer size for the connection
|
||||
-/
|
||||
defaultPayloadBytes : Nat := 8192
|
||||
|
||||
/--
|
||||
Whether to automatically generate the `Date` header in responses.
|
||||
-/
|
||||
generateDate : Bool := true
|
||||
|
||||
/--
|
||||
The `Server` header value injected into outgoing responses.
|
||||
`none` suppresses the header entirely.
|
||||
-/
|
||||
serverName : Option Header.Value := some (.mk "LeanHTTP/1.1")
|
||||
|
||||
/--
|
||||
Maximum length of request URI (default: 8192 bytes)
|
||||
-/
|
||||
maxUriLength : Nat := 8192
|
||||
|
||||
/--
|
||||
Maximum number of bytes consumed while parsing request start-lines (default: 8192 bytes).
|
||||
-/
|
||||
maxStartLineLength : Nat := 8192
|
||||
|
||||
/--
|
||||
Maximum length of header field name (default: 256 bytes)
|
||||
-/
|
||||
maxHeaderNameLength : Nat := 256
|
||||
|
||||
/--
|
||||
Maximum length of header field value (default: 8192 bytes)
|
||||
-/
|
||||
maxHeaderValueLength : Nat := 8192
|
||||
|
||||
/--
|
||||
Maximum number of spaces in delimiter sequences (default: 16)
|
||||
-/
|
||||
maxSpaceSequence : Nat := 16
|
||||
|
||||
/--
|
||||
Maximum number of leading empty lines (bare CRLF) to skip before a request-line
|
||||
(RFC 9112 §2.2 robustness). Default: 8.
|
||||
-/
|
||||
maxLeadingEmptyLines : Nat := 8
|
||||
|
||||
/--
|
||||
Maximum length of chunk extension name (default: 256 bytes)
|
||||
-/
|
||||
maxChunkExtNameLength : Nat := 256
|
||||
|
||||
/--
|
||||
Maximum length of chunk extension value (default: 256 bytes)
|
||||
-/
|
||||
maxChunkExtValueLength : Nat := 256
|
||||
|
||||
/--
|
||||
Maximum number of bytes consumed while parsing one chunk-size line with extensions (default: 8192 bytes).
|
||||
-/
|
||||
maxChunkLineLength : Nat := 8192
|
||||
|
||||
/--
|
||||
Maximum allowed chunk payload size in bytes (default: 8 MiB).
|
||||
-/
|
||||
maxChunkSize : Nat := 8 * 1024 * 1024
|
||||
|
||||
/--
|
||||
Maximum allowed total body size per request in bytes (default: 64 MiB).
|
||||
-/
|
||||
maxBodySize : Nat := 64 * 1024 * 1024
|
||||
|
||||
/--
|
||||
Maximum length of reason phrase (default: 512 bytes)
|
||||
-/
|
||||
maxReasonPhraseLength : Nat := 512
|
||||
|
||||
/--
|
||||
Maximum number of trailer headers (default: 20)
|
||||
-/
|
||||
maxTrailerHeaders : Nat := 20
|
||||
|
||||
/--
|
||||
Maximum number of extensions on a single chunk-size line (default: 16).
|
||||
-/
|
||||
maxChunkExtensions : Nat := 16
|
||||
|
||||
namespace Config
|
||||
|
||||
/--
|
||||
Converts to HTTP/1.1 config.
|
||||
-/
|
||||
def toH1Config (config : Config) : Protocol.H1.Config where
|
||||
maxMessages := config.maxRequests
|
||||
maxHeaders := config.maxHeaders
|
||||
maxHeaderBytes := config.maxHeaderBytes
|
||||
enableKeepAlive := config.enableKeepAlive
|
||||
agentName := config.serverName
|
||||
maxUriLength := config.maxUriLength
|
||||
maxStartLineLength := config.maxStartLineLength
|
||||
maxHeaderNameLength := config.maxHeaderNameLength
|
||||
maxHeaderValueLength := config.maxHeaderValueLength
|
||||
maxSpaceSequence := config.maxSpaceSequence
|
||||
maxLeadingEmptyLines := config.maxLeadingEmptyLines
|
||||
maxChunkExtensions := config.maxChunkExtensions
|
||||
maxChunkExtNameLength := config.maxChunkExtNameLength
|
||||
maxChunkExtValueLength := config.maxChunkExtValueLength
|
||||
maxChunkLineLength := config.maxChunkLineLength
|
||||
maxChunkSize := config.maxChunkSize
|
||||
maxBodySize := config.maxBodySize
|
||||
maxReasonPhraseLength := config.maxReasonPhraseLength
|
||||
maxTrailerHeaders := config.maxTrailerHeaders
|
||||
|
||||
end Std.Http.Config
|
||||
@@ -1,560 +0,0 @@
|
||||
/-
|
||||
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Sofia Rodrigues
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Std.Internal.Async.TCP
|
||||
public import Std.Internal.Async.ContextAsync
|
||||
public import Std.Internal.Http.Transport
|
||||
public import Std.Internal.Http.Protocol.H1
|
||||
public import Std.Internal.Http.Server.Config
|
||||
public import Std.Internal.Http.Server.Handler
|
||||
|
||||
public section
|
||||
|
||||
namespace Std
|
||||
namespace Http
|
||||
namespace Server
|
||||
|
||||
open Std Internal IO Async TCP Protocol
|
||||
open Time
|
||||
|
||||
/-!
|
||||
# Connection
|
||||
|
||||
This module defines `Server.Connection`, a structure used to handle a single HTTP connection with
|
||||
possibly multiple requests.
|
||||
-/
|
||||
|
||||
set_option linter.all true
|
||||
|
||||
/--
|
||||
Represents the remote address of a client connection.
|
||||
-/
|
||||
structure RemoteAddr where
|
||||
/--
|
||||
The socket address of the remote client.
|
||||
-/
|
||||
addr : Net.SocketAddress
|
||||
deriving TypeName
|
||||
|
||||
instance : ToString RemoteAddr where
|
||||
toString addr := toString addr.addr
|
||||
|
||||
/--
|
||||
A single HTTP connection.
|
||||
-/
|
||||
structure Connection (α : Type) where
|
||||
/--
|
||||
The client connection.
|
||||
-/
|
||||
socket : α
|
||||
|
||||
/--
|
||||
The processing machine for HTTP/1.1.
|
||||
-/
|
||||
machine : H1.Machine .receiving
|
||||
|
||||
/--
|
||||
Extensions to attach to each request (e.g., remote address).
|
||||
-/
|
||||
extensions : Extensions := .empty
|
||||
|
||||
namespace Connection
|
||||
|
||||
/--
|
||||
Events produced by the async select loop in `receiveWithTimeout`.
|
||||
Each variant corresponds to one possible outcome of waiting for I/O.
|
||||
-/
|
||||
private inductive Recv (β : Type)
|
||||
| bytes (x : Option ByteArray)
|
||||
| responseBody (x : Option Chunk)
|
||||
| bodyInterest (x : Bool)
|
||||
| response (x : (Except Error (Response β)))
|
||||
| timeout
|
||||
| shutdown
|
||||
| close
|
||||
|
||||
/--
|
||||
The set of I/O sources to wait on during a single poll iteration.
|
||||
Each `Option` field is `none` when that source is not currently active.
|
||||
-/
|
||||
private structure PollSources (α β : Type) where
|
||||
socket : Option α
|
||||
expect : Option Nat
|
||||
response : Option (Std.Channel (Except Error (Response β)))
|
||||
responseBody : Option β
|
||||
requestBody : Option Body.Stream
|
||||
timeout : Millisecond.Offset
|
||||
keepAliveTimeout : Option Millisecond.Offset
|
||||
headerTimeout : Option Timestamp
|
||||
connectionContext : CancellationContext
|
||||
|
||||
/--
|
||||
Waits for the next I/O event across all active sources described by `sources`.
|
||||
Computes the socket recv size from `config`, then races all active selectables.
|
||||
Calls `Handler.onFailure` and returns `.close` on transport errors.
|
||||
-/
|
||||
private def pollNextEvent
|
||||
{σ β : Type} [Transport α] [Handler σ] [Body β]
|
||||
(config : Config) (handler : σ) (sources : PollSources α β)
|
||||
: Async (Recv β) := do
|
||||
let expectedBytes := sources.expect
|
||||
|>.getD config.defaultPayloadBytes
|
||||
|>.min config.maximumRecvSize
|
||||
|>.toUInt64
|
||||
|
||||
let mut selectables : Array (Selectable (Recv β)) := #[
|
||||
.case sources.connectionContext.doneSelector (fun _ => do
|
||||
let reason ← sources.connectionContext.getCancellationReason
|
||||
match reason with
|
||||
| some .deadline => pure .timeout
|
||||
| _ => pure .shutdown)
|
||||
]
|
||||
|
||||
if let some socket := sources.socket then
|
||||
selectables := selectables.push (.case (Transport.recvSelector socket expectedBytes) (Recv.bytes · |> pure))
|
||||
|
||||
|
||||
if sources.keepAliveTimeout.isNone then
|
||||
if let some timeout := sources.headerTimeout then
|
||||
selectables := selectables.push (.case (← Selector.sleep (timeout - (← Timestamp.now)).toMilliseconds) (fun _ => pure .timeout))
|
||||
else
|
||||
selectables := selectables.push (.case (← Selector.sleep sources.timeout) (fun _ => pure .timeout))
|
||||
|
||||
if let some responseBody := sources.responseBody then
|
||||
selectables := selectables.push (.case (Body.recvSelector responseBody) (Recv.responseBody · |> pure))
|
||||
|
||||
if let some requestBody := sources.requestBody then
|
||||
selectables := selectables.push (.case (requestBody.interestSelector) (Recv.bodyInterest · |> pure))
|
||||
|
||||
if let some response := sources.response then
|
||||
selectables := selectables.push (.case response.recvSelector (Recv.response · |> pure))
|
||||
|
||||
try Selectable.one selectables
|
||||
catch e =>
|
||||
Handler.onFailure handler e
|
||||
pure .close
|
||||
|
||||
/--
|
||||
Handles the `Expect: 100-continue` protocol for a pending request head.
|
||||
Races between the handler's decision (`Handler.onContinue`), the connection being
|
||||
cancelled, and a lingering timeout. Returns the updated machine and whether
|
||||
`pendingHead` should be cleared (i.e. when the request is rejected).
|
||||
-/
|
||||
private def handleContinueEvent
|
||||
{σ : Type} [Handler σ]
|
||||
(handler : σ) (machine : H1.Machine .receiving) (head : Request.Head)
|
||||
(config : Config) (connectionContext : CancellationContext)
|
||||
: Async (H1.Machine .receiving × Bool) := do
|
||||
|
||||
let continueChannel : Std.Channel Bool ← Std.Channel.new
|
||||
let continueTask ← Handler.onContinue handler head |>.asTask
|
||||
|
||||
BaseIO.chainTask continueTask fun
|
||||
| .ok v => discard <| continueChannel.send v
|
||||
| .error _ => discard <| continueChannel.send false
|
||||
|
||||
let canContinue ← Selectable.one #[
|
||||
.case continueChannel.recvSelector pure,
|
||||
.case connectionContext.doneSelector (fun _ => pure false),
|
||||
.case (← Selector.sleep config.lingeringTimeout) (fun _ => pure false)
|
||||
]
|
||||
|
||||
let status := if canContinue then Status.«continue» else Status.expectationFailed
|
||||
return (machine.canContinue status, !canContinue)
|
||||
|
||||
/--
|
||||
Injects a `Date` header into a response head if `Config.generateDate` is set
|
||||
and the response does not already include one.
|
||||
-/
|
||||
private def prepareResponseHead (config : Config) (head : Response.Head) : Async Response.Head := do
|
||||
if config.generateDate ∧ ¬head.headers.contains Header.Name.date then
|
||||
let now ← Std.Time.DateTime.now (tz := .UTC)
|
||||
return { head with headers := head.headers.insert Header.Name.date (Header.Value.ofString! now.toRFC822String) }
|
||||
else
|
||||
return head
|
||||
|
||||
/--
|
||||
Applies a successful handler response to the machine.
|
||||
Optionally injects a `Date` header, records the known body size, and sends the
|
||||
response head. Returns the updated machine and the body stream to drain, or `none`
|
||||
when the body should be omitted (e.g., for HEAD requests).
|
||||
-/
|
||||
private def applyResponse
|
||||
{β : Type} [Body β]
|
||||
(config : Config) (machine : H1.Machine .receiving) (res : Response β)
|
||||
: Async (H1.Machine .receiving × Option β) := do
|
||||
let size ← Body.getKnownSize res.body
|
||||
|
||||
let machineSized :=
|
||||
if let some knownSize := size
|
||||
then machine.setKnownSize knownSize
|
||||
else machine
|
||||
|
||||
let responseHead ← prepareResponseHead config res.line
|
||||
let machineWithHead := machineSized.send responseHead
|
||||
if machineWithHead.writer.omitBody then
|
||||
if ¬(← Body.isClosed res.body) then
|
||||
Body.close res.body
|
||||
return (machineWithHead, none)
|
||||
else
|
||||
return (machineWithHead, some res.body)
|
||||
|
||||
/--
|
||||
All mutable state carried through the connection processing loop.
|
||||
Bundled into a struct so it can be passed to and returned from helper functions.
|
||||
-/
|
||||
private structure ConnectionState (β : Type) where
|
||||
machine : H1.Machine .receiving
|
||||
requestStream : Body.Stream
|
||||
keepAliveTimeout : Option Millisecond.Offset
|
||||
currentTimeout : Millisecond.Offset
|
||||
headerTimeout : Option Timestamp
|
||||
response : Std.Channel (Except Error (Response β))
|
||||
respStream : Option β
|
||||
requiresData : Bool
|
||||
expectData : Option Nat
|
||||
handlerDispatched : Bool
|
||||
pendingHead : Option Request.Head
|
||||
|
||||
/--
|
||||
Processes all H1 events from a single machine step, updating the connection state.
|
||||
Handles keep-alive resets, body-size tracking, `Expect: 100-continue`, and parse errors.
|
||||
Returns the updated state; stops early on `.failed`.
|
||||
-/
|
||||
private def processH1Events
|
||||
{σ β : Type} [Handler σ] [Body β]
|
||||
(handler : σ) (config : Config) (connectionContext : CancellationContext)
|
||||
(events : Array (H1.Event .receiving))
|
||||
(state : ConnectionState β)
|
||||
: Async (ConnectionState β) := do
|
||||
|
||||
let mut st := state
|
||||
|
||||
for event in events do
|
||||
match event with
|
||||
| .needMoreData expect =>
|
||||
st := { st with requiresData := true, expectData := expect }
|
||||
|
||||
| .needAnswer => pure ()
|
||||
|
||||
| .endHeaders head =>
|
||||
|
||||
-- Sets the pending head and removes the KeepAlive or Header timeout.
|
||||
st := { st with
|
||||
currentTimeout := config.lingeringTimeout
|
||||
keepAliveTimeout := none
|
||||
headerTimeout := none
|
||||
pendingHead := some head
|
||||
}
|
||||
|
||||
if let some length := head.getSize true then
|
||||
-- Sets the size of the body that is going out of the connection.
|
||||
Body.setKnownSize st.requestStream (some length)
|
||||
|
||||
| .«continue» =>
|
||||
if let some head := st.pendingHead then
|
||||
let (newMachine, clearPending) ← handleContinueEvent handler st.machine head config connectionContext
|
||||
st := { st with machine := newMachine }
|
||||
if clearPending then
|
||||
st := { st with pendingHead := none }
|
||||
|
||||
| .next =>
|
||||
-- Reset all per-request state for the next pipelined request.
|
||||
if ¬(← Body.isClosed st.requestStream) then
|
||||
Body.close st.requestStream
|
||||
|
||||
if let some res := st.respStream then
|
||||
if ¬(← Body.isClosed res) then
|
||||
Body.close res
|
||||
|
||||
let newStream ← Body.mkStream
|
||||
|
||||
st := { st with
|
||||
requestStream := newStream
|
||||
response := ← Std.Channel.new
|
||||
respStream := none
|
||||
keepAliveTimeout := some config.keepAliveTimeout.val
|
||||
currentTimeout := config.keepAliveTimeout.val
|
||||
headerTimeout := none
|
||||
handlerDispatched := false
|
||||
}
|
||||
|
||||
| .failed err =>
|
||||
Handler.onFailure handler (toString err)
|
||||
|
||||
if ¬(← Body.isClosed st.requestStream) then
|
||||
Body.close st.requestStream
|
||||
|
||||
st := { st with requiresData := false, pendingHead := none }
|
||||
break
|
||||
|
||||
| .closeBody =>
|
||||
if ¬(← Body.isClosed st.requestStream) then
|
||||
Body.close st.requestStream
|
||||
|
||||
| .close => pure ()
|
||||
|
||||
return st
|
||||
|
||||
/--
|
||||
Dispatches a pending request head to the handler if one is waiting.
|
||||
Spawns the handler as an async task and routes its result back through `state.response`.
|
||||
Returns the updated state with `pendingHead` cleared and `handlerDispatched` set.
|
||||
-/
|
||||
private def dispatchPendingRequest
|
||||
{σ : Type} [Handler σ]
|
||||
(handler : σ) (extensions : Extensions) (connectionContext : CancellationContext)
|
||||
(state : ConnectionState (Handler.ResponseBody σ))
|
||||
: Async (ConnectionState (Handler.ResponseBody σ)) := do
|
||||
if let some line := state.pendingHead then
|
||||
|
||||
let task ← Handler.onRequest handler { line, body := state.requestStream, extensions } connectionContext
|
||||
|>.asTask
|
||||
|
||||
BaseIO.chainTask task (discard ∘ state.response.send)
|
||||
return { state with pendingHead := none, handlerDispatched := true }
|
||||
else
|
||||
return state
|
||||
|
||||
/--
|
||||
Attempts a single non-blocking receive from the body and feeds any available chunk
|
||||
into the machine, without going through the `Selectable.one` scheduler.
|
||||
|
||||
For fully-buffered bodies (e.g. `Body.Full`, `Body.Buffered`) this avoids one
|
||||
`Selectable.one` round-trip when the chunk is already in memory. Streaming bodies
|
||||
that have no producer waiting return `none` and fall through to the normal poll loop
|
||||
unchanged.
|
||||
|
||||
Only one chunk is consumed here. Looping would introduce yield points between
|
||||
`Body.tryRecv` calls, allowing a background producer to race ahead and close the
|
||||
stream before `writeHead` runs — turning a streaming response with unknown size
|
||||
into a fixed-length one.
|
||||
-/
|
||||
private def tryDrainBody [Body β]
|
||||
(machine : H1.Machine .receiving) (body : β)
|
||||
: Async (H1.Machine .receiving × Option β) := do
|
||||
match ← Body.tryRecv body with
|
||||
| none => pure (machine, some body)
|
||||
| some (some chunk) => pure (machine.sendData #[chunk], some body)
|
||||
| some none =>
|
||||
if !(← Body.isClosed body) then Body.close body
|
||||
pure (machine.userClosedBody, none)
|
||||
|
||||
/--
|
||||
Processes a single async I/O event and updates the connection state.
|
||||
Returns the updated state and `true` if the connection should be closed immediately.
|
||||
-/
|
||||
private def handleRecvEvent
|
||||
{σ β : Type} [Handler σ] [Body β]
|
||||
(handler : σ) (config : Config)
|
||||
(event : Recv β) (state : ConnectionState β)
|
||||
: Async (ConnectionState β × Bool) := do
|
||||
|
||||
match event with
|
||||
| .bytes (some bs) =>
|
||||
|
||||
let mut st := state
|
||||
|
||||
-- After the first byte after idle we switch from keep-alive timeout to per-request header timeout.
|
||||
if st.keepAliveTimeout.isSome then
|
||||
st := { st with
|
||||
keepAliveTimeout := none
|
||||
headerTimeout := some <| (← Timestamp.now) + config.headerTimeout
|
||||
}
|
||||
|
||||
return ({ st with machine := st.machine.feed bs }, false)
|
||||
|
||||
| .bytes none =>
|
||||
return ({ state with machine := state.machine.noMoreInput }, false)
|
||||
|
||||
| .responseBody (some chunk) =>
|
||||
return ({ state with machine := state.machine.sendData #[chunk] }, false)
|
||||
|
||||
| .responseBody none =>
|
||||
if let some res := state.respStream then
|
||||
if ¬(← Body.isClosed res) then Body.close res
|
||||
return ({ state with machine := state.machine.userClosedBody, respStream := none }, false)
|
||||
|
||||
| .bodyInterest interested =>
|
||||
if interested then
|
||||
let (newMachine, pulledChunk) := state.machine.pullBody
|
||||
let mut st := { state with machine := newMachine }
|
||||
|
||||
if let some pulled := pulledChunk then
|
||||
try st.requestStream.send pulled.chunk pulled.incomplete
|
||||
catch e => Handler.onFailure handler e
|
||||
if pulled.final then
|
||||
if ¬(← Body.isClosed st.requestStream) then
|
||||
Body.close st.requestStream
|
||||
|
||||
return (st, false)
|
||||
else
|
||||
return (state, false)
|
||||
|
||||
| .close => return (state, true)
|
||||
|
||||
| .timeout =>
|
||||
Handler.onFailure handler "request header timeout"
|
||||
return ({ state with machine := state.machine.closeWithError .requestTimeout, handlerDispatched := false }, false)
|
||||
|
||||
| .shutdown =>
|
||||
return ({ state with machine := state.machine.closeWithError .serviceUnavailable, handlerDispatched := false }, false)
|
||||
|
||||
| .response (.error err) =>
|
||||
Handler.onFailure handler err
|
||||
return ({ state with machine := state.machine.closeWithError .internalServerError, handlerDispatched := false }, false)
|
||||
|
||||
| .response (.ok res) =>
|
||||
if state.machine.failed then
|
||||
if ¬(← Body.isClosed res.body) then Body.close res.body
|
||||
return ({ state with handlerDispatched := false }, false)
|
||||
else
|
||||
let (newMachine, newRespStream) ← applyResponse config state.machine res
|
||||
|
||||
-- Eagerly consume one chunk if immediately available (avoids a Selectable.one round-trip).
|
||||
let (drainedMachine, drainedRespStream) ←
|
||||
match newRespStream with
|
||||
| none => pure (newMachine, none)
|
||||
| some body => tryDrainBody newMachine body
|
||||
|
||||
return ({ state with machine := drainedMachine, handlerDispatched := false, respStream := drainedRespStream }, false)
|
||||
|
||||
/--
|
||||
Computes the active `PollSources` for the current connection state.
|
||||
Determines which IO sources need attention and whether to include the socket.
|
||||
-/
|
||||
private def buildPollSources
|
||||
{α β : Type} [Transport α]
|
||||
(socket : α) (connectionContext : CancellationContext) (state : ConnectionState β)
|
||||
: Async (PollSources α β) := do
|
||||
let requestBodyOpen ←
|
||||
if state.machine.canPullBody then pure !(← Body.isClosed state.requestStream)
|
||||
else pure false
|
||||
|
||||
let requestBodyInterested ←
|
||||
if state.machine.canPullBody ∧ requestBodyOpen then state.requestStream.hasInterest
|
||||
else pure false
|
||||
|
||||
let requestBody ←
|
||||
if state.machine.canPullBodyNow ∧ requestBodyOpen then pure (some state.requestStream)
|
||||
else pure none
|
||||
|
||||
-- Include the socket only when there is more to do than waiting for the handler alone.
|
||||
let pollSocket :=
|
||||
state.requiresData ∨ !state.handlerDispatched ∨ state.respStream.isSome ∨
|
||||
state.machine.writer.sentMessage ∨ (state.machine.canPullBody ∧ requestBodyInterested)
|
||||
|
||||
return {
|
||||
socket := if pollSocket then some socket else none
|
||||
expect := state.expectData
|
||||
response := if state.handlerDispatched then some state.response else none
|
||||
responseBody := state.respStream
|
||||
requestBody := requestBody
|
||||
timeout := state.currentTimeout
|
||||
keepAliveTimeout := state.keepAliveTimeout
|
||||
headerTimeout := state.headerTimeout
|
||||
connectionContext := connectionContext
|
||||
}
|
||||
|
||||
/--
|
||||
Runs the main request/response processing loop for a single connection.
|
||||
Drives the HTTP/1.1 state machine through four phases each iteration:
|
||||
send buffered output, process H1 events, dispatch pending requests, poll for I/O.
|
||||
-/
|
||||
private def handle
|
||||
{σ : Type} [Transport α] [h : Handler σ]
|
||||
(connection : Connection α)
|
||||
(config : Config)
|
||||
(connectionContext : CancellationContext)
|
||||
(handler : σ) : Async Unit := do
|
||||
|
||||
let _ : Body (Handler.ResponseBody σ) := Handler.responseBodyInstance
|
||||
|
||||
let socket := connection.socket
|
||||
let initStream ← Body.mkStream
|
||||
|
||||
let mut state : ConnectionState (Handler.ResponseBody σ) := {
|
||||
machine := connection.machine
|
||||
requestStream := initStream
|
||||
keepAliveTimeout := some config.keepAliveTimeout.val
|
||||
currentTimeout := config.keepAliveTimeout.val
|
||||
headerTimeout := none
|
||||
response := ← Std.Channel.new
|
||||
respStream := none
|
||||
requiresData := false
|
||||
expectData := none
|
||||
handlerDispatched := false
|
||||
pendingHead := none
|
||||
}
|
||||
|
||||
while ¬state.machine.halted do
|
||||
|
||||
-- Phase 1: advance the state machine and flush any output.
|
||||
let (newMachine, step) := state.machine.step
|
||||
state := { state with machine := newMachine }
|
||||
|
||||
if step.output.size > 0 then
|
||||
try Transport.sendAll socket step.output.data
|
||||
catch e =>
|
||||
Handler.onFailure handler e
|
||||
break
|
||||
|
||||
-- Phase 2: process all events emitted by this step.
|
||||
state ← processH1Events handler config connectionContext step.events state
|
||||
|
||||
-- Phase 3: dispatch any newly parsed request to the handler.
|
||||
state ← dispatchPendingRequest handler connection.extensions connectionContext state
|
||||
|
||||
-- Phase 4: wait for the next IO event when any source needs attention.
|
||||
if state.requiresData ∨ state.handlerDispatched ∨ state.respStream.isSome ∨ state.machine.canPullBody then
|
||||
state := { state with requiresData := false }
|
||||
let sources ← buildPollSources socket connectionContext state
|
||||
let event ← pollNextEvent config handler sources
|
||||
let (newState, shouldClose) ← handleRecvEvent handler config event state
|
||||
state := newState
|
||||
if shouldClose then break
|
||||
|
||||
-- Clean up: close all open streams and the socket.
|
||||
if ¬(← Body.isClosed state.requestStream) then
|
||||
Body.close state.requestStream
|
||||
|
||||
if let some res := state.respStream then
|
||||
if ¬(← Body.isClosed res) then Body.close res
|
||||
|
||||
Transport.close socket
|
||||
|
||||
end Connection
|
||||
|
||||
/--
|
||||
Handles request/response processing for a single connection using an `Async` handler.
|
||||
The library-level entry point for running a server is `Server.serve`.
|
||||
This function can be used with a `TCP.Socket` or any other type that implements
|
||||
`Transport` to build custom server loops.
|
||||
|
||||
# Example
|
||||
|
||||
```lean
|
||||
-- Create a TCP socket server instance
|
||||
let server ← Socket.Server.mk
|
||||
server.bind addr
|
||||
server.listen backlog
|
||||
|
||||
-- Enter an infinite loop to handle incoming client connections
|
||||
while true do
|
||||
let client ← server.accept
|
||||
background (serveConnection client handler config)
|
||||
```
|
||||
-/
|
||||
def serveConnection
|
||||
{σ : Type} [Transport t] [Handler σ]
|
||||
(client : t) (handler : σ)
|
||||
(config : Config) (extensions : Extensions := .empty) : ContextAsync Unit := do
|
||||
(Connection.mk client { config := config.toH1Config } extensions)
|
||||
|>.handle config (← ContextAsync.getContext) handler
|
||||
|
||||
end Std.Http.Server
|
||||
@@ -1,126 +0,0 @@
|
||||
/-
|
||||
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Sofia Rodrigues
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Std.Internal.Async
|
||||
public import Std.Internal.Http.Data
|
||||
public import Std.Internal.Async.ContextAsync
|
||||
|
||||
public section
|
||||
|
||||
namespace Std.Http.Server
|
||||
|
||||
open Std.Internal.IO.Async
|
||||
|
||||
set_option linter.all true
|
||||
|
||||
/--
|
||||
A type class for handling HTTP server requests. Implement this class to define how the server
|
||||
responds to incoming requests, failures, and `Expect: 100-continue` headers.
|
||||
-/
|
||||
class Handler (σ : Type) where
|
||||
/--
|
||||
Concrete body type produced by `onRequest`.
|
||||
Defaults to `Body.Any`, but handlers may override it with any reader/writer-compatible body.
|
||||
-/
|
||||
ResponseBody : Type := Body.Any
|
||||
|
||||
/--
|
||||
Body instance required by the connection loop for receiving response chunks.
|
||||
-/
|
||||
[responseBodyInstance : Body ResponseBody]
|
||||
|
||||
/--
|
||||
Called for each incoming HTTP request.
|
||||
-/
|
||||
onRequest (self : σ) (request : Request Body.Stream) : ContextAsync (Response ResponseBody)
|
||||
|
||||
/--
|
||||
Called when an I/O or transport error occurs while processing a request (e.g. broken socket,
|
||||
handler exception). This is a **notification only**: the connection will close regardless of
|
||||
the handler's response. Use this for logging and metrics. The default implementation does nothing.
|
||||
-/
|
||||
onFailure (self : σ) (error : IO.Error) : Async Unit :=
|
||||
pure ()
|
||||
|
||||
/--
|
||||
Called when a request includes an `Expect: 100-continue` header. Return `true` to send a
|
||||
`100 Continue` response and accept the body. If `false` is returned the server sends
|
||||
`417 Expectation Failed`, disables keep-alive, and closes the request body reader.
|
||||
This function is guarded by `Config.lingeringTimeout` and may be cancelled on server shutdown.
|
||||
The default implementation always returns `true`.
|
||||
-/
|
||||
onContinue (self : σ) (request : Request.Head) : Async Bool :=
|
||||
pure true
|
||||
|
||||
/--
|
||||
A stateless HTTP handler.
|
||||
-/
|
||||
structure StatelessHandler where
|
||||
/--
|
||||
Function called for each incoming request.
|
||||
-/
|
||||
onRequest : Request Body.Stream → ContextAsync (Response Body.Any)
|
||||
|
||||
/--
|
||||
Function called when an I/O or transport error occurs. The default does nothing.
|
||||
-/
|
||||
onFailure : IO.Error → Async Unit := fun _ => pure ()
|
||||
|
||||
/--
|
||||
Function called when a request includes `Expect: 100-continue`. Return `true` to accept
|
||||
the body or `false` to reject it with `417 Expectation Failed`. The default always accepts.
|
||||
-/
|
||||
onContinue : Request.Head → Async Bool := fun _ => pure true
|
||||
|
||||
instance : Handler StatelessHandler where
|
||||
onRequest self request := self.onRequest request
|
||||
onFailure self error := self.onFailure error
|
||||
onContinue self request := self.onContinue request
|
||||
|
||||
namespace Handler
|
||||
|
||||
/--
|
||||
Builds a `StatelessHandler` from a request-handling function.
|
||||
-/
|
||||
def ofFn
|
||||
(f : Request Body.Stream → ContextAsync (Response Body.Any)) :
|
||||
StatelessHandler :=
|
||||
{ onRequest := f }
|
||||
|
||||
/--
|
||||
Builds a `StatelessHandler` from all three callback functions.
|
||||
-/
|
||||
def ofFns
|
||||
(onRequest : Request Body.Stream → ContextAsync (Response Body.Any))
|
||||
(onFailure : IO.Error → Async Unit := fun _ => pure ())
|
||||
(onContinue : Request.Head → Async Bool := fun _ => pure true) :
|
||||
StatelessHandler :=
|
||||
{ onRequest, onFailure, onContinue }
|
||||
|
||||
/--
|
||||
Builds a `StatelessHandler` from a request function and a failure callback. Useful for
|
||||
attaching error logging to a handler.
|
||||
-/
|
||||
def withFailure
|
||||
(handler : StatelessHandler)
|
||||
(onFailure : IO.Error → Async Unit) :
|
||||
StatelessHandler :=
|
||||
{ handler with onFailure }
|
||||
|
||||
/--
|
||||
Builds a `StatelessHandler` from a request function and a continue callback
|
||||
-/
|
||||
def withContinue
|
||||
(handler : StatelessHandler)
|
||||
(onContinue : Request.Head → Async Bool) :
|
||||
StatelessHandler :=
|
||||
{ handler with onContinue }
|
||||
|
||||
end Handler
|
||||
|
||||
end Std.Http.Server
|
||||
@@ -1,243 +0,0 @@
|
||||
/-
|
||||
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Sofia Rodrigues
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Std.Internal.Http.Server
|
||||
public import Std.Internal.Async
|
||||
public import Std.Internal.Async.Timer
|
||||
import Init.Data.String.Legacy
|
||||
|
||||
public section
|
||||
|
||||
open Std.Internal.IO Async
|
||||
open Std Http
|
||||
|
||||
namespace Std.Http.Internal.Test
|
||||
|
||||
abbrev TestHandler := Request Body.Stream → ContextAsync (Response Body.Any)
|
||||
|
||||
instance : Std.Http.Server.Handler TestHandler where
|
||||
onRequest handler request := handler request
|
||||
|
||||
/--
|
||||
Default config for server tests. Short lingering timeout, no Date header.
|
||||
-/
|
||||
def defaultConfig : Config :=
|
||||
{ lingeringTimeout := 1000, generateDate := false }
|
||||
|
||||
private def sendRaw
|
||||
(client : Mock.Client) (server : Mock.Server) (raw : ByteArray)
|
||||
(handler : TestHandler) (config : Config) : IO ByteArray :=
|
||||
Async.block do
|
||||
client.send raw
|
||||
Std.Http.Server.serveConnection server handler config |>.run
|
||||
let res ← client.recv?
|
||||
pure (res.getD .empty)
|
||||
|
||||
private def sendClose
|
||||
(client : Mock.Client) (server : Mock.Server) (raw : ByteArray)
|
||||
(handler : TestHandler) (config : Config) : IO ByteArray :=
|
||||
Async.block do
|
||||
client.send raw
|
||||
client.getSendChan.close
|
||||
Std.Http.Server.serveConnection server handler config |>.run
|
||||
let res ← client.recv?
|
||||
pure (res.getD .empty)
|
||||
|
||||
-- Timeout wrapper
|
||||
|
||||
private def withTimeout {α : Type} (name : String) (ms : Nat) (action : IO α) : IO α := do
|
||||
let task ← IO.asTask action
|
||||
let ticks := (ms + 9) / 10
|
||||
let rec loop : Nat → IO α
|
||||
| 0 => do IO.cancel task; throw <| IO.userError s!"'{name}' timed out after {ms}ms"
|
||||
| n + 1 => do
|
||||
if (← IO.getTaskState task) == .finished then
|
||||
match ← IO.wait task with
|
||||
| .ok x => pure x
|
||||
| .error e => throw e
|
||||
else IO.sleep 10; loop n
|
||||
loop ticks
|
||||
|
||||
-- Test grouping
|
||||
|
||||
/--
|
||||
Run `tests` and wrap any failure message with the group name.
|
||||
Use as `#eval runGroup "Topic" do ...`.
|
||||
-/
|
||||
def runGroup (name : String) (tests : IO Unit) : IO Unit :=
|
||||
try tests
|
||||
catch e => throw (IO.userError s!"[{name}]\n{e}")
|
||||
|
||||
-- Per-test runners
|
||||
|
||||
/--
|
||||
Create a fresh mock connection, send `raw`, and run assertions.
|
||||
-/
|
||||
def check
|
||||
(name : String)
|
||||
(raw : String)
|
||||
(handler : TestHandler)
|
||||
(expect : ByteArray → IO Unit)
|
||||
(config : Config := defaultConfig) : IO Unit := do
|
||||
let (client, server) ← Mock.new
|
||||
let response ← sendRaw client server raw.toUTF8 handler config
|
||||
try expect response
|
||||
catch e => throw (IO.userError s!"[{name}] {e}")
|
||||
|
||||
/--
|
||||
Like `check` but closes the client channel before running the server.
|
||||
Use for tests involving truncated input or silent-close (EOF-triggered behavior).
|
||||
-/
|
||||
def checkClose
|
||||
(name : String)
|
||||
(raw : String)
|
||||
(handler : TestHandler)
|
||||
(expect : ByteArray → IO Unit)
|
||||
(config : Config := defaultConfig) : IO Unit := do
|
||||
let (client, server) ← Mock.new
|
||||
let response ← sendClose client server raw.toUTF8 handler config
|
||||
try expect response
|
||||
catch e => throw (IO.userError s!"[{name}] {e}")
|
||||
|
||||
/--
|
||||
Like `check` wrapped in a wall-clock timeout.
|
||||
Required when the test involves streaming, async timers, or keep-alive behavior.
|
||||
-/
|
||||
def checkTimed
|
||||
(name : String)
|
||||
(ms : Nat := 2000)
|
||||
(raw : String)
|
||||
(handler : TestHandler)
|
||||
(expect : ByteArray → IO Unit)
|
||||
(config : Config := defaultConfig) : IO Unit :=
|
||||
withTimeout name ms <| check name raw handler expect config
|
||||
|
||||
-- Assertion helpers
|
||||
|
||||
/--
|
||||
Assert the response starts with `prefix_` (e.g. `"HTTP/1.1 200"`).
|
||||
-/
|
||||
def assertStatus (response : ByteArray) (prefix_ : String) : IO Unit := do
|
||||
let text := String.fromUTF8! response
|
||||
unless text.startsWith prefix_ do
|
||||
throw <| IO.userError s!"expected status {prefix_.quote}, got:\n{text.quote}"
|
||||
|
||||
/--
|
||||
Assert the response is byte-for-byte equal to `expected`.
|
||||
Use sparingly — prefer `assertStatus` + `assertContains` for 200 responses.
|
||||
-/
|
||||
def assertExact (response : ByteArray) (expected : String) : IO Unit := do
|
||||
let text := String.fromUTF8! response
|
||||
unless text == expected do
|
||||
throw <| IO.userError s!"expected:\n{expected.quote}\ngot:\n{text.quote}"
|
||||
|
||||
/--
|
||||
Assert `needle` appears anywhere in the response.
|
||||
-/
|
||||
def assertContains (response : ByteArray) (needle : String) : IO Unit := do
|
||||
let text := String.fromUTF8! response
|
||||
unless text.contains needle do
|
||||
throw <| IO.userError s!"expected to contain {needle.quote}, got:\n{text.quote}"
|
||||
|
||||
/--
|
||||
Assert `needle` does NOT appear in the response.
|
||||
-/
|
||||
def assertAbsent (response : ByteArray) (needle : String) : IO Unit := do
|
||||
let text := String.fromUTF8! response
|
||||
if text.contains needle then
|
||||
throw <| IO.userError s!"expected NOT to contain {needle.quote}, got:\n{text.quote}"
|
||||
|
||||
/--
|
||||
Assert the response contains exactly `n` occurrences of `"HTTP/1.1 "`.
|
||||
-/
|
||||
def assertResponseCount (response : ByteArray) (n : Nat) : IO Unit := do
|
||||
let text := String.fromUTF8! response
|
||||
let got := (text.splitOn "HTTP/1.1 ").length - 1
|
||||
unless got == n do
|
||||
throw <| IO.userError s!"expected {n} HTTP/1.1 responses, got {got}:\n{text.quote}"
|
||||
|
||||
-- Common fixed response strings
|
||||
|
||||
def r400 : String :=
|
||||
"HTTP/1.1 400 Bad Request\x0d\nServer: LeanHTTP/1.1\x0d\nConnection: close\x0d\nContent-Length: 0\x0d\n\x0d\n"
|
||||
|
||||
def r408 : String :=
|
||||
"HTTP/1.1 408 Request Timeout\x0d\nServer: LeanHTTP/1.1\x0d\nConnection: close\x0d\nContent-Length: 0\x0d\n\x0d\n"
|
||||
|
||||
def r413 : String :=
|
||||
"HTTP/1.1 413 Content Too Large\x0d\nServer: LeanHTTP/1.1\x0d\nConnection: close\x0d\nContent-Length: 0\x0d\n\x0d\n"
|
||||
|
||||
def r417 : String :=
|
||||
"HTTP/1.1 417 Expectation Failed\x0d\nServer: LeanHTTP/1.1\x0d\nConnection: close\x0d\nContent-Length: 0\x0d\n\x0d\n"
|
||||
|
||||
def r431 : String :=
|
||||
"HTTP/1.1 431 Request Header Fields Too Large\x0d\nServer: LeanHTTP/1.1\x0d\nConnection: close\x0d\nContent-Length: 0\x0d\n\x0d\n"
|
||||
|
||||
def r505 : String :=
|
||||
"HTTP/1.1 505 HTTP Version Not Supported\x0d\nServer: LeanHTTP/1.1\x0d\nConnection: close\x0d\nContent-Length: 0\x0d\n\x0d\n"
|
||||
|
||||
-- Standard handlers
|
||||
|
||||
/--
|
||||
Always respond 200 "ok" without reading the request body.
|
||||
-/
|
||||
def okHandler : TestHandler := fun _ => Response.ok |>.text "ok"
|
||||
|
||||
/--
|
||||
Read the full request body and echo it back as text/plain.
|
||||
-/
|
||||
def echoHandler : TestHandler := fun req => do
|
||||
Response.ok |>.text (← req.body.readAll)
|
||||
|
||||
/--
|
||||
Respond 200 with the request URI as the body.
|
||||
-/
|
||||
def uriHandler : TestHandler := fun req =>
|
||||
Response.ok |>.text (toString req.line.uri)
|
||||
|
||||
-- Request builder helpers
|
||||
|
||||
/--
|
||||
Minimal GET request. `extra` is appended as raw header lines (each ending with `\x0d\n`)
|
||||
before the blank line.
|
||||
-/
|
||||
def mkGet (path : String := "/") (extra : String := "") : String :=
|
||||
s!"GET {path} HTTP/1.1\x0d\nHost: example.com\x0d\n{extra}\x0d\n"
|
||||
|
||||
/--
|
||||
GET with `Connection: close`.
|
||||
-/
|
||||
def mkGetClose (path : String := "/") : String :=
|
||||
mkGet path "Connection: close\x0d\n"
|
||||
|
||||
/--
|
||||
POST with a fixed Content-Length body. `extra` is appended before the blank line.
|
||||
-/
|
||||
def mkPost (path : String) (body : String) (extra : String := "") : String :=
|
||||
s!"POST {path} HTTP/1.1\x0d\nHost: example.com\x0d\nContent-Length: {body.toUTF8.size}\x0d\n{extra}\x0d\n{body}"
|
||||
|
||||
/--
|
||||
POST with Transfer-Encoding: chunked. `chunkedBody` is the pre-formatted body
|
||||
(use `chunk` + `chunkEnd` to build it).
|
||||
-/
|
||||
def mkChunked (path : String) (chunkedBody : String) (extra : String := "") : String :=
|
||||
s!"POST {path} HTTP/1.1\x0d\nHost: example.com\x0d\nTransfer-Encoding: chunked\x0d\n{extra}\x0d\n{chunkedBody}"
|
||||
|
||||
/--
|
||||
Format a single chunk: `<hex-size>\x0d\n<data>\x0d\n`.
|
||||
-/
|
||||
def chunk (data : String) : String :=
|
||||
let hexSize := Nat.toDigits 16 data.toUTF8.size |> String.ofList
|
||||
s!"{hexSize}\x0d\n{data}\x0d\n"
|
||||
|
||||
/--
|
||||
The terminal zero-chunk that ends a chunked body.
|
||||
-/
|
||||
def chunkEnd : String := "0\x0d\n\x0d\n"
|
||||
|
||||
end Std.Http.Internal.Test
|
||||
@@ -1,253 +0,0 @@
|
||||
/-
|
||||
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Sofia Rodrigues
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Std.Internal.Http.Protocol.H1
|
||||
|
||||
public section
|
||||
|
||||
/-!
|
||||
# Transport
|
||||
|
||||
This module exposes a `Transport` type class that is used to represent different transport mechanisms
|
||||
that can be used with an HTTP connection.
|
||||
-/
|
||||
|
||||
namespace Std.Http
|
||||
open Std Internal IO Async TCP
|
||||
|
||||
set_option linter.all true
|
||||
|
||||
/--
|
||||
Generic HTTP interface that abstracts over different transport mechanisms.
|
||||
-/
|
||||
class Transport (α : Type) where
|
||||
/--
|
||||
Receive data from the client connection, up to the expected size.
|
||||
Returns None if the connection is closed or no data is available.
|
||||
-/
|
||||
recv : α → UInt64 → Async (Option ByteArray)
|
||||
|
||||
/--
|
||||
Send all data through the client connection.
|
||||
-/
|
||||
sendAll : α → Array ByteArray → Async Unit
|
||||
|
||||
/--
|
||||
Get a selector for receiving data asynchronously.
|
||||
-/
|
||||
recvSelector : α → UInt64 → Selector (Option ByteArray)
|
||||
|
||||
/--
|
||||
Close the transport connection.
|
||||
The default implementation is a no-op; override this for transports that require explicit teardown.
|
||||
For `Socket.Client`, the runtime closes the file descriptor when the object is finalized.
|
||||
-/
|
||||
close : α → IO Unit := fun _ => pure ()
|
||||
|
||||
instance : Transport Socket.Client where
|
||||
recv client expect := client.recv? expect
|
||||
sendAll client data := client.sendAll data
|
||||
recvSelector client expect := client.recvSelector expect
|
||||
|
||||
namespace Internal
|
||||
|
||||
open Internal.IO.Async in
|
||||
|
||||
/--
|
||||
Shared state for a bidirectional mock connection.
|
||||
-/
|
||||
private structure Mock.SharedState where
|
||||
/--
|
||||
Client to server direction.
|
||||
-/
|
||||
clientToServer : Std.CloseableChannel ByteArray
|
||||
|
||||
/--
|
||||
Server to client direction.
|
||||
-/
|
||||
serverToClient : Std.CloseableChannel ByteArray
|
||||
|
||||
/--
|
||||
Mock client endpoint for testing.
|
||||
-/
|
||||
structure Mock.Client where
|
||||
private shared : Mock.SharedState
|
||||
|
||||
/--
|
||||
Mock server endpoint for testing.
|
||||
-/
|
||||
structure Mock.Server where
|
||||
private shared : Mock.SharedState
|
||||
|
||||
namespace Mock
|
||||
|
||||
/--
|
||||
Creates a mock server and client that are connected to each other and share the
|
||||
same underlying state, enabling bidirectional communication.
|
||||
-/
|
||||
def new : BaseIO (Mock.Client × Mock.Server) := do
|
||||
let first ← Std.CloseableChannel.new
|
||||
let second ← Std.CloseableChannel.new
|
||||
|
||||
return (⟨⟨first, second⟩⟩, ⟨⟨first, second⟩⟩)
|
||||
|
||||
/--
|
||||
Receives data from a channel, joining all available data up to the expected size. First does a
|
||||
blocking recv, then greedily consumes available data with tryRecv until `expect` bytes are reached.
|
||||
-/
|
||||
def recvJoined (recvChan : Std.CloseableChannel ByteArray) (expect : Option UInt64) : Async (Option ByteArray) := do
|
||||
match ← await (← recvChan.recv) with
|
||||
| none => return none
|
||||
| some first =>
|
||||
let mut result := first
|
||||
repeat
|
||||
if let some expect := expect then
|
||||
if result.size.toUInt64 ≥ expect then break
|
||||
|
||||
match ← recvChan.tryRecv with
|
||||
| none => break
|
||||
| some chunk => result := result ++ chunk
|
||||
return some result
|
||||
|
||||
/--
|
||||
Sends a single ByteArray through a channel.
|
||||
-/
|
||||
def send (sendChan : Std.CloseableChannel ByteArray) (data : ByteArray) : Async Unit := do
|
||||
Async.ofAsyncTask ((← sendChan.send data) |>.map (Except.mapError (IO.userError ∘ toString)))
|
||||
|
||||
/--
|
||||
Sends ByteArrays through a channel.
|
||||
-/
|
||||
def sendAll (sendChan : Std.CloseableChannel ByteArray) (data : Array ByteArray) : Async Unit := do
|
||||
for chunk in data do
|
||||
send sendChan chunk
|
||||
|
||||
/--
|
||||
Creates a selector for receiving from a channel.
|
||||
-/
|
||||
def recvSelector (recvChan : Std.CloseableChannel ByteArray) : Selector (Option ByteArray) :=
|
||||
recvChan.recvSelector
|
||||
|
||||
end Mock
|
||||
|
||||
namespace Mock.Client
|
||||
|
||||
/--
|
||||
Gets the receive channel for a client (server to client direction).
|
||||
-/
|
||||
def getRecvChan (client : Mock.Client) : Std.CloseableChannel ByteArray :=
|
||||
client.shared.serverToClient
|
||||
|
||||
/--
|
||||
Gets the send channel for a client (client to server direction).
|
||||
-/
|
||||
def getSendChan (client : Mock.Client) : Std.CloseableChannel ByteArray :=
|
||||
client.shared.clientToServer
|
||||
|
||||
/--
|
||||
Sends a single ByteArray.
|
||||
-/
|
||||
def send (client : Mock.Client) (data : ByteArray) : Async Unit :=
|
||||
Mock.send (getSendChan client) data
|
||||
|
||||
/--
|
||||
Receives data, joining all available chunks.
|
||||
-/
|
||||
def recv? (client : Mock.Client) (expect : Option UInt64 := none) : Async (Option ByteArray) :=
|
||||
Mock.recvJoined (getRecvChan client) expect
|
||||
|
||||
/--
|
||||
Tries to receive data without blocking, joining all immediately available chunks.
|
||||
Returns `none` if no data is available.
|
||||
-/
|
||||
def tryRecv? (client : Mock.Client) (_expect : UInt64 := 0) : BaseIO (Option ByteArray) := do
|
||||
match ← (getRecvChan client).tryRecv with
|
||||
| none => return none
|
||||
| some first =>
|
||||
let mut result := first
|
||||
repeat
|
||||
match ← (getRecvChan client).tryRecv with
|
||||
| none => break
|
||||
| some chunk => result := result ++ chunk
|
||||
return some result
|
||||
|
||||
/--
|
||||
Closes the mock server and client.
|
||||
-/
|
||||
def close (client : Mock.Client) : IO Unit := do
|
||||
if !(← client.shared.clientToServer.isClosed) then client.shared.clientToServer.close
|
||||
if !(← client.shared.serverToClient.isClosed) then client.shared.serverToClient.close
|
||||
|
||||
end Mock.Client
|
||||
|
||||
namespace Mock.Server
|
||||
|
||||
/--
|
||||
Gets the receive channel for a server (client to server direction).
|
||||
-/
|
||||
def getRecvChan (server : Mock.Server) : Std.CloseableChannel ByteArray :=
|
||||
server.shared.clientToServer
|
||||
|
||||
/--
|
||||
Gets the send channel for a server (server to client direction).
|
||||
-/
|
||||
def getSendChan (server : Mock.Server) : Std.CloseableChannel ByteArray :=
|
||||
server.shared.serverToClient
|
||||
|
||||
/--
|
||||
Sends a single ByteArray.
|
||||
-/
|
||||
def send (server : Mock.Server) (data : ByteArray) : Async Unit :=
|
||||
Mock.send (getSendChan server) data
|
||||
|
||||
/--
|
||||
Receives data, joining all available chunks.
|
||||
-/
|
||||
def recv? (server : Mock.Server) (expect : Option UInt64 := none) : Async (Option ByteArray) :=
|
||||
Mock.recvJoined (getRecvChan server) expect
|
||||
|
||||
/--
|
||||
Tries to receive data without blocking, joining all immediately available chunks. Returns `none` if no
|
||||
data is available.
|
||||
-/
|
||||
def tryRecv? (server : Mock.Server) (_expect : UInt64 := 0) : BaseIO (Option ByteArray) := do
|
||||
match ← (getRecvChan server).tryRecv with
|
||||
| none => return none
|
||||
| some first =>
|
||||
let mut result := first
|
||||
repeat
|
||||
match ← (getRecvChan server).tryRecv with
|
||||
| none => break
|
||||
| some chunk => result := result ++ chunk
|
||||
return some result
|
||||
|
||||
/--
|
||||
Closes the mock server and client.
|
||||
-/
|
||||
def close (server : Mock.Server) : IO Unit := do
|
||||
if !(← server.shared.clientToServer.isClosed) then server.shared.clientToServer.close
|
||||
if !(← server.shared.serverToClient.isClosed) then server.shared.serverToClient.close
|
||||
|
||||
|
||||
end Mock.Server
|
||||
|
||||
instance : Transport Mock.Client where
|
||||
recv client expect := Mock.recvJoined (Mock.Client.getRecvChan client) (some expect)
|
||||
sendAll client data := Mock.sendAll (Mock.Client.getSendChan client) data
|
||||
recvSelector client _ := Mock.recvSelector (Mock.Client.getRecvChan client)
|
||||
close client := client.close
|
||||
|
||||
instance : Transport Mock.Server where
|
||||
recv server expect := Mock.recvJoined (Mock.Server.getRecvChan server) (some expect)
|
||||
sendAll server data := Mock.sendAll (Mock.Server.getSendChan server) data
|
||||
recvSelector server _ := Mock.recvSelector (Mock.Server.getRecvChan server)
|
||||
close server := server.close
|
||||
|
||||
end Internal
|
||||
|
||||
end Std.Http
|
||||
@@ -124,9 +124,6 @@ end IPv4Addr
|
||||
|
||||
namespace SocketAddressV4
|
||||
|
||||
instance : ToString SocketAddressV4 where
|
||||
toString sa := toString sa.addr ++ ":" ++ toString sa.port
|
||||
|
||||
instance : Coe SocketAddressV4 SocketAddress where
|
||||
coe addr := .v4 addr
|
||||
|
||||
@@ -164,9 +161,6 @@ end IPv6Addr
|
||||
|
||||
namespace SocketAddressV6
|
||||
|
||||
instance : ToString SocketAddressV6 where
|
||||
toString sa := "[" ++ toString sa.addr ++ "]:" ++ toString sa.port
|
||||
|
||||
instance : Coe SocketAddressV6 SocketAddress where
|
||||
coe addr := .v6 addr
|
||||
|
||||
@@ -192,11 +186,6 @@ end IPAddr
|
||||
|
||||
namespace SocketAddress
|
||||
|
||||
instance : ToString SocketAddress where
|
||||
toString
|
||||
| .v4 sa => toString sa
|
||||
| .v6 sa => toString sa
|
||||
|
||||
/--
|
||||
Obtain the `AddressFamily` associated with a `SocketAddress`.
|
||||
-/
|
||||
|
||||
@@ -11,7 +11,6 @@ public import Std.Sync.Channel
|
||||
public import Std.Sync.Mutex
|
||||
public import Std.Sync.RecursiveMutex
|
||||
public import Std.Sync.Barrier
|
||||
public import Std.Sync.Semaphore
|
||||
public import Std.Sync.SharedMutex
|
||||
public import Std.Sync.Notify
|
||||
public import Std.Sync.Broadcast
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
/-
|
||||
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Sofia Rodrigues
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Init.Data.Queue
|
||||
public import Init.System.Promise
|
||||
public import Std.Sync.Mutex
|
||||
|
||||
public section
|
||||
|
||||
namespace Std
|
||||
|
||||
private structure SemaphoreState where
|
||||
permits : Nat
|
||||
waiters : Std.Queue (IO.Promise Unit) := ∅
|
||||
deriving Nonempty
|
||||
|
||||
/--
|
||||
Counting semaphore.
|
||||
|
||||
`Semaphore.acquire` returns a promise that is resolved once a permit is available.
|
||||
If a permit is currently available, the returned promise is already resolved.
|
||||
`Semaphore.release` either resolves one waiting promise or increments the available permits.
|
||||
-/
|
||||
structure Semaphore where private mk ::
|
||||
private lock : Mutex SemaphoreState
|
||||
|
||||
/--
|
||||
Creates a resolved promise.
|
||||
-/
|
||||
private def mkResolvedPromise [Nonempty α] (a : α) : BaseIO (IO.Promise α) := do
|
||||
let promise ← IO.Promise.new
|
||||
promise.resolve a
|
||||
return promise
|
||||
|
||||
/--
|
||||
Creates a new semaphore with `permits` initially available permits.
|
||||
-/
|
||||
def Semaphore.new (permits : Nat) : BaseIO Semaphore := do
|
||||
return { lock := ← Mutex.new { permits } }
|
||||
|
||||
/--
|
||||
Requests one permit.
|
||||
Returns a promise that resolves once the permit is acquired.
|
||||
-/
|
||||
def Semaphore.acquire (sem : Semaphore) : BaseIO (IO.Promise Unit) := do
|
||||
sem.lock.atomically do
|
||||
let st ← get
|
||||
if st.permits > 0 then
|
||||
set { st with permits := st.permits - 1 }
|
||||
mkResolvedPromise ()
|
||||
else
|
||||
let promise ← IO.Promise.new
|
||||
set { st with waiters := st.waiters.enqueue promise }
|
||||
return promise
|
||||
|
||||
/--
|
||||
Tries to acquire a permit without blocking. Returns `true` on success.
|
||||
-/
|
||||
def Semaphore.tryAcquire (sem : Semaphore) : BaseIO Bool := do
|
||||
sem.lock.atomically do
|
||||
let st ← get
|
||||
if st.permits > 0 then
|
||||
set { st with permits := st.permits - 1 }
|
||||
return true
|
||||
else
|
||||
return false
|
||||
|
||||
/--
|
||||
Releases one permit and resolves one waiting acquirer, if any.
|
||||
-/
|
||||
def Semaphore.release (sem : Semaphore) : BaseIO Unit := do
|
||||
let waiter? ← sem.lock.atomically do
|
||||
let st ← get
|
||||
match st.waiters.dequeue? with
|
||||
| some (waiter, waiters) =>
|
||||
set { st with waiters }
|
||||
return some waiter
|
||||
| none =>
|
||||
set { st with permits := st.permits + 1 }
|
||||
return none
|
||||
if let some waiter := waiter? then
|
||||
waiter.resolve ()
|
||||
|
||||
/--
|
||||
Returns the number of currently available permits.
|
||||
-/
|
||||
def Semaphore.availablePermits (sem : Semaphore) : BaseIO Nat :=
|
||||
sem.lock.atomically do
|
||||
return (← get).permits
|
||||
|
||||
end Std
|
||||
@@ -8,6 +8,7 @@ module
|
||||
prelude
|
||||
public import Std.Tactic.Do.ProofMode
|
||||
public import Std.Tactic.Do.Syntax
|
||||
public import Std.Do.Triple.RepeatSpec
|
||||
|
||||
@[expose] public section
|
||||
|
||||
|
||||
@@ -235,7 +235,7 @@ public def checkHashUpToDate
|
||||
: JobM Bool := (·.isUpToDate) <$> checkHashUpToDate' info depTrace depHash oldTrace
|
||||
|
||||
/--
|
||||
**For internal use only.**
|
||||
**Ror internal use only.**
|
||||
Checks whether `info` is up-to-date with the trace.
|
||||
If so, replays the log of the trace if available.
|
||||
-/
|
||||
@@ -271,24 +271,20 @@ Returns `true` if the saved trace exists and its hash matches `inputHash`.
|
||||
|
||||
If up-to-date, replays the saved log from the trace and sets the current
|
||||
build action to `replay`. Otherwise, if the log is empty and trace is synthetic,
|
||||
or if the trace is not up-to-date, the build action will be set to `reuse`.
|
||||
or if the trace is not up-to-date, the build action will be set to `fetch`.
|
||||
-/
|
||||
public def SavedTrace.replayCachedIfUpToDate (inputHash : Hash) (self : SavedTrace) : JobM Bool := do
|
||||
public def SavedTrace.replayOrFetchIfUpToDate (inputHash : Hash) (self : SavedTrace) : JobM Bool := do
|
||||
if let .ok data := self then
|
||||
if data.depHash == inputHash then
|
||||
if data.synthetic && data.log.isEmpty then
|
||||
updateAction .reuse
|
||||
updateAction .fetch
|
||||
else
|
||||
updateAction .replay
|
||||
data.log.replay
|
||||
return true
|
||||
updateAction .reuse
|
||||
updateAction .fetch
|
||||
return false
|
||||
|
||||
@[deprecated replayCachedIfUpToDate (since := "2026-04-15")]
|
||||
public abbrev SavedTrace.replayOrFetchIfUpToDate (inputHash : Hash) (self : SavedTrace) : JobM Bool := do
|
||||
self.replayCachedIfUpToDate inputHash
|
||||
|
||||
/-- **For internal use only.** -/
|
||||
public class ToOutputJson (α : Type u) where
|
||||
toOutputJson (arts : α) : Json
|
||||
@@ -688,7 +684,7 @@ public def buildArtifactUnlessUpToDate
|
||||
let fetchArt? restore := do
|
||||
let some (art : XArtifact exe) ← getArtifacts? inputHash savedTrace pkg
|
||||
| return none
|
||||
unless (← savedTrace.replayCachedIfUpToDate inputHash) do
|
||||
unless (← savedTrace.replayOrFetchIfUpToDate inputHash) do
|
||||
removeFileIfExists file
|
||||
writeFetchTrace traceFile inputHash (toJson art.descr)
|
||||
if restore then
|
||||
|
||||
@@ -29,18 +29,11 @@ namespace Lake
|
||||
public inductive JobAction
|
||||
/-- No information about this job's action is available. -/
|
||||
| unknown
|
||||
/-- Tried to reuse a cached build (e.g., can be set by `replayCachedIfUpToDate`). -/
|
||||
| reuse
|
||||
/-- Tried to replay a completed build action (e.g., can be set by `replayIfUpToDate`). -/
|
||||
/-- Tried to replay a cached build action (set by `buildFileUnlessUpToDate`) -/
|
||||
| replay
|
||||
/-- Tried to unpack a build from an archive (e.g., unpacking a module `ltar`). -/
|
||||
| unpack
|
||||
/--
|
||||
Tried to fetch a build from a remote store (e.g., set when downloading an artifact
|
||||
on-demand from a cache service in `buildArtifactUnlessUpToDate`).
|
||||
-/
|
||||
/-- Tried to fetch a build from a store (can be set by `buildUnlessUpToDate?`) -/
|
||||
| fetch
|
||||
/-- Tried to perform a build action (e.g., set by `buildAction`). -/
|
||||
/-- Tried to perform a build action (set by `buildUnlessUpToDate?`) -/
|
||||
| build
|
||||
deriving Inhabited, Repr, DecidableEq, Ord
|
||||
|
||||
@@ -52,13 +45,11 @@ public instance : Min JobAction := minOfLe
|
||||
public instance : Max JobAction := maxOfLe
|
||||
|
||||
public def merge (a b : JobAction) : JobAction :=
|
||||
max a b -- inlines `max`
|
||||
max a b
|
||||
|
||||
public def verb (failed : Bool) : (self : JobAction) → String
|
||||
public def verb (failed : Bool) : JobAction → String
|
||||
| .unknown => if failed then "Running" else "Ran"
|
||||
| .reuse => if failed then "Reusing" else "Reused"
|
||||
| .replay => if failed then "Replaying" else "Replayed"
|
||||
| .unpack => if failed then "Unpacking" else "Unpacked"
|
||||
| .fetch => if failed then "Fetching" else "Fetched"
|
||||
| .build => if failed then "Building" else "Built"
|
||||
|
||||
|
||||
@@ -900,9 +900,8 @@ where
|
||||
let inputHash := (← getTrace).hash
|
||||
let some ltarOrArts ← getArtifacts? inputHash savedTrace mod.pkg
|
||||
| return .inr savedTrace
|
||||
match (ltarOrArts : ModuleOutputs) with
|
||||
match (ltarOrArts : ModuleOutputs) with
|
||||
| .ltar ltar =>
|
||||
updateAction .unpack
|
||||
mod.clearOutputArtifacts
|
||||
mod.unpackLtar ltar.path
|
||||
-- Note: This branch implies that only the ltar output is (validly) cached.
|
||||
@@ -920,7 +919,7 @@ where
|
||||
else
|
||||
return .inr savedTrace
|
||||
| .arts arts =>
|
||||
unless (← savedTrace.replayCachedIfUpToDate inputHash) do
|
||||
unless (← savedTrace.replayOrFetchIfUpToDate inputHash) do
|
||||
mod.clearOutputArtifacts
|
||||
writeFetchTrace mod.traceFile inputHash (toJson arts.descrs)
|
||||
let arts ←
|
||||
|
||||
@@ -25,8 +25,12 @@ namespace Lake
|
||||
open Lean (Name)
|
||||
|
||||
/-- Fetch the package's direct dependencies. -/
|
||||
def Package.recFetchDeps (self : Package) : FetchM (Job (Array Package)) := do
|
||||
return Job.pure self.depPkgs
|
||||
def Package.recFetchDeps (self : Package) : FetchM (Job (Array Package)) := ensureJob do
|
||||
(pure ·) <$> self.depConfigs.mapM fun cfg => do
|
||||
let some dep ← findPackageByName? cfg.name
|
||||
| error s!"{self.prettyName}: package not found for dependency '{cfg.name}' \
|
||||
(this is likely a bug in Lake)"
|
||||
return dep
|
||||
|
||||
/-- The `PackageFacetConfig` for the builtin `depsFacet`. -/
|
||||
public def Package.depsFacetConfig : PackageFacetConfig depsFacet :=
|
||||
@@ -34,7 +38,10 @@ public def Package.depsFacetConfig : PackageFacetConfig depsFacet :=
|
||||
|
||||
/-- Compute a topological ordering of the package's transitive dependencies. -/
|
||||
def Package.recComputeTransDeps (self : Package) : FetchM (Job (Array Package)) := ensureJob do
|
||||
(pure ·.toArray) <$> self.depPkgs.foldlM (init := OrdPackageSet.empty) fun deps dep => do
|
||||
(pure ·.toArray) <$> self.depConfigs.foldlM (init := OrdPackageSet.empty) fun deps cfg => do
|
||||
let some dep ← findPackageByName? cfg.name
|
||||
| error s!"{self.prettyName}: package not found for dependency '{cfg.name}' \
|
||||
(this is likely a bug in Lake)"
|
||||
let depDeps ← (← fetch <| dep.transDeps).await
|
||||
return depDeps.foldl (·.insert ·) deps |>.insert dep
|
||||
|
||||
@@ -146,7 +153,7 @@ def Package.fetchBuildArchive
|
||||
let upToDate ← buildUnlessUpToDate? (action := .fetch) archiveFile depTrace traceFile do
|
||||
download url archiveFile headers
|
||||
unless upToDate && (← self.buildDir.pathExists) do
|
||||
updateAction .unpack
|
||||
updateAction .fetch
|
||||
untar archiveFile self.buildDir
|
||||
|
||||
@[inline]
|
||||
|
||||
@@ -210,7 +210,7 @@ def mkMonitorContext (cfg : BuildConfig) (jobs : JobQueue) : BaseIO MonitorConte
|
||||
let failLv := cfg.failLv
|
||||
let isVerbose := cfg.verbosity = .verbose
|
||||
let showProgress := cfg.showProgress
|
||||
let minAction := if isVerbose then .unknown else .unpack
|
||||
let minAction := if isVerbose then .unknown else .fetch
|
||||
let showOptional := isVerbose
|
||||
let showTime := isVerbose || !useAnsi
|
||||
let updateFrequency := 100
|
||||
|
||||
@@ -9,7 +9,7 @@ prelude
|
||||
public import Lean.Data.Json
|
||||
import Init.Data.Nat.Fold
|
||||
meta import Init.Data.Nat.Fold
|
||||
public import Lake.Util.String
|
||||
import Lake.Util.String
|
||||
public import Init.Data.String.Search
|
||||
public import Init.Data.String.Extra
|
||||
import Init.Data.Option.Coe
|
||||
@@ -141,8 +141,8 @@ public def ofHex? (s : String) : Option Hash :=
|
||||
if s.utf8ByteSize = 16 && isHex s then ofHex s else none
|
||||
|
||||
/-- Returns the hash as 16-digit lowercase hex string. -/
|
||||
@[inline] public def hex (self : Hash) : String :=
|
||||
lowerHexUInt64 self.val
|
||||
public def hex (self : Hash) : String :=
|
||||
lpad (String.ofList <| Nat.toDigits 16 self.val.toNat) '0' 16
|
||||
|
||||
/-- Parse a hash from a string of decimal digits. -/
|
||||
public def ofDecimal? (s : String) : Option Hash :=
|
||||
|
||||
@@ -69,7 +69,7 @@ public structure LakeOptions where
|
||||
scope? : Option CacheServiceScope := none
|
||||
platform? : Option CachePlatform := none
|
||||
toolchain? : Option CacheToolchain := none
|
||||
rev? : Option GitRev := none
|
||||
rev? : Option String := none
|
||||
maxRevs : Nat := 100
|
||||
shake : Shake.Args := {}
|
||||
|
||||
@@ -563,7 +563,7 @@ private def computePackageRev (pkgDir : FilePath) : CliStateM String := do
|
||||
repo.getHeadRevision
|
||||
|
||||
private def putCore
|
||||
(rev : GitRev) (outputs : FilePath) (artDir : FilePath)
|
||||
(rev : String) (outputs : FilePath) (artDir : FilePath)
|
||||
(service : CacheService) (scope : CacheServiceScope)
|
||||
(platform := CachePlatform.none) (toolchain := CacheToolchain.none)
|
||||
: LoggerIO Unit := do
|
||||
|
||||
@@ -7,7 +7,6 @@ module
|
||||
|
||||
prelude
|
||||
import Init.Control.Do
|
||||
public import Lake.Util.Git
|
||||
public import Lake.Util.Log
|
||||
public import Lake.Util.Version
|
||||
public import Lake.Config.Artifact
|
||||
@@ -470,7 +469,7 @@ public def readOutputs? (cache : Cache) (scope : String) (inputHash : Hash) : Lo
|
||||
cache.dir / "revisions"
|
||||
|
||||
/-- Returns path to the input-to-output mappings of a downloaded package revision. -/
|
||||
@[inline] public def revisionPath (cache : Cache) (scope : String) (rev : GitRev) : FilePath :=
|
||||
@[inline] public def revisionPath (cache : Cache) (scope : String) (rev : String) : FilePath :=
|
||||
cache.revisionDir / scope / s!"{rev}.jsonl"
|
||||
|
||||
end Cache
|
||||
@@ -943,7 +942,7 @@ public def uploadArtifacts
|
||||
public def mapContentType : String := "application/vnd.reservoir.outputs+json-lines"
|
||||
|
||||
def s3RevisionUrl
|
||||
(rev : GitRev) (service : CacheService) (scope : CacheServiceScope)
|
||||
(rev : String) (service : CacheService) (scope : CacheServiceScope)
|
||||
(platform := CachePlatform.none) (toolchain := CacheToolchain.none)
|
||||
: String :=
|
||||
match scope.impl with
|
||||
@@ -957,7 +956,7 @@ def s3RevisionUrl
|
||||
return s!"{url}/{rev}.jsonl"
|
||||
|
||||
public def revisionUrl
|
||||
(rev : GitRev) (service : CacheService) (scope : CacheServiceScope)
|
||||
(rev : String) (service : CacheService) (scope : CacheServiceScope)
|
||||
(platform := CachePlatform.none) (toolchain := CacheToolchain.none)
|
||||
: String :=
|
||||
if service.isReservoir then Id.run do
|
||||
@@ -975,7 +974,7 @@ public def revisionUrl
|
||||
service.s3RevisionUrl rev scope platform toolchain
|
||||
|
||||
public def downloadRevisionOutputs?
|
||||
(rev : GitRev) (cache : Cache) (service : CacheService)
|
||||
(rev : String) (cache : Cache) (service : CacheService)
|
||||
(localScope : String) (remoteScope : CacheServiceScope)
|
||||
(platform := CachePlatform.none) (toolchain := CacheToolchain.none) (force := false)
|
||||
: LoggerIO (Option CacheMap) := do
|
||||
@@ -999,7 +998,7 @@ public def downloadRevisionOutputs?
|
||||
CacheMap.load path platform.isNone
|
||||
|
||||
public def uploadRevisionOutputs
|
||||
(rev : GitRev) (outputs : FilePath) (service : CacheService) (scope : CacheServiceScope)
|
||||
(rev : String) (outputs : FilePath) (service : CacheService) (scope : CacheServiceScope)
|
||||
(platform := CachePlatform.none) (toolchain := CacheToolchain.none)
|
||||
: LoggerIO Unit := do
|
||||
let url := service.s3RevisionUrl rev scope platform toolchain
|
||||
|
||||
@@ -9,7 +9,6 @@ prelude
|
||||
public import Init.Dynamic
|
||||
public import Init.System.FilePath
|
||||
public import Lean.Data.NameMap.Basic
|
||||
public import Lake.Util.Git
|
||||
import Init.Data.ToString.Name
|
||||
import Init.Data.ToString.Macro
|
||||
|
||||
@@ -31,7 +30,7 @@ public inductive DependencySrc where
|
||||
/- A package located at a fixed path relative to the dependent package's directory. -/
|
||||
| path (dir : FilePath)
|
||||
/- A package cloned from a Git repository available at a fixed Git `url`. -/
|
||||
| git (url : String) (rev : Option GitRev) (subDir : Option FilePath)
|
||||
| git (url : String) (rev : Option String) (subDir : Option FilePath)
|
||||
deriving Inhabited, Repr
|
||||
|
||||
/--
|
||||
|
||||
@@ -52,8 +52,6 @@ public structure Package where
|
||||
remoteUrl : String
|
||||
/-- Dependency configurations for the package. -/
|
||||
depConfigs : Array Dependency := #[]
|
||||
/-- **For internal use only.** Resolved direct dependences of the package. -/
|
||||
depPkgs : Array Package := #[]
|
||||
/-- Target configurations in the order declared by the package. -/
|
||||
targetDecls : Array (PConfigDecl keyName) := #[]
|
||||
/-- Name-declaration map of target configurations in the package. -/
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user