Compare commits

...

4 Commits

Author SHA1 Message Date
Kim Morrison
d8178ad761 wip 2025-01-29 23:02:57 +11:00
Kim Morrison
4023dc8a7b . 2025-01-29 21:38:17 +11:00
Kim Morrison
b82e5aa400 test file 2025-01-29 21:24:34 +11:00
Kim Morrison
68eda08aae feat: missing monadic functions on List/Array/Vector 2025-01-29 21:24:26 +11:00
9 changed files with 288 additions and 33 deletions

View File

@@ -452,7 +452,7 @@ def mapM {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (f : α
@[deprecated mapM (since := "2024-11-11")] abbrev sequenceMap := @mapM
/-- Variant of `mapIdxM` which receives the index as a `Fin as.size`. -/
/-- Variant of `mapIdxM` which receives the index `i` along with the bound `i < as.size`. -/
@[inline]
def mapFinIdxM {α : Type u} {β : Type v} {m : Type v Type w} [Monad m]
(as : Array α) (f : (i : Nat) α (h : i < as.size) m β) : m (Array β) :=
@@ -464,13 +464,25 @@ def mapFinIdxM {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m]
rw [ inv, Nat.add_assoc, Nat.add_comm 1 j, Nat.add_comm]
apply Nat.le_add_right
have : i + (j + 1) = as.size := by rw [ inv, Nat.add_comm j 1, Nat.add_assoc]
map i (j+1) this (bs.push ( f j (as.get j j_lt) j_lt))
map i (j+1) this (bs.push ( f j as[j] j_lt))
map as.size 0 rfl (mkEmpty as.size)
@[inline]
def mapIdxM {α : Type u} {β : Type v} {m : Type v Type w} [Monad m] (f : Nat α m β) (as : Array α) : m (Array β) :=
as.mapFinIdxM fun i a _ => f i a
@[inline]
def firstM {α : Type u} {m : Type v Type w} [Alternative m] (f : α m β) (as : Array α) : m β :=
go 0
where
go (i : Nat) : m β :=
if hlt : i < as.size then
f as[i] <|> go (i+1)
else
failure
termination_by as.size - i
decreasing_by exact Nat.sub_succ_lt_self as.size i hlt
@[inline]
def findSomeM? {α : Type u} {β : Type v} {m : Type v Type w} [Monad m] (f : α m (Option β)) (as : Array α) : m (Option β) := do
for a in as do
@@ -564,6 +576,9 @@ def findRevM? {α : Type} {m : Type → Type w} [Monad m] (p : α → m Bool) (a
def forM {α : Type u} {m : Type v Type w} [Monad m] (f : α m PUnit) (as : Array α) (start := 0) (stop := as.size) : m PUnit :=
as.foldlM (fun _ => f) start stop
instance : ForM m (Array α) α where
forM xs f := forM f xs
@[inline]
def forRevM {α : Type u} {m : Type v Type w} [Monad m] (f : α m PUnit) (as : Array α) (start := as.size) (stop := 0) : m PUnit :=
as.foldrM (fun a _ => f a) start stop
@@ -595,6 +610,9 @@ def count {α : Type u} [BEq α] (a : α) (as : Array α) : Nat :=
def map {α : Type u} {β : Type v} (f : α β) (as : Array α) : Array β :=
Id.run <| as.mapM f
instance : Functor Array where
map := map
/-- Variant of `mapIdx` which receives the index as a `Fin as.size`. -/
@[inline]
def mapFinIdx {α : Type u} {β : Type v} (as : Array α) (f : (i : Nat) α (h : i < as.size) β) : Array β :=
@@ -732,6 +750,24 @@ def flatMap (f : α → Array β) (as : Array α) : Array β :=
@[inline] def flatten (as : Array (Array α)) : Array α :=
as.foldl (init := empty) fun r a => r ++ a
def reverse (as : Array α) : Array α :=
if h : as.size 1 then
as
else
loop as 0 as.size - 1, Nat.pred_lt (mt (fun h : as.size = 0 => h by decide) h)
where
termination {i j : Nat} (h : i < j) : j - 1 - (i + 1) < j - i := by
rw [Nat.sub_sub, Nat.add_comm]
exact Nat.lt_of_le_of_lt (Nat.pred_le _) (Nat.sub_succ_lt_self _ _ h)
loop (as : Array α) (i : Nat) (j : Fin as.size) :=
if h : i < j then
have := termination h
let as := as.swap i j (Nat.lt_trans h j.2)
have : j-1 < as.size := by rw [size_swap]; exact Nat.lt_of_le_of_lt (Nat.pred_le _) j.2
loop as (i+1) j-1, this
else
as
@[inline]
def filter (p : α Bool) (as : Array α) (start := 0) (stop := as.size) : Array α :=
as.foldl (init := #[]) (start := start) (stop := stop) fun r a =>
@@ -742,6 +778,11 @@ def filterM {α : Type} [Monad m] (p : α → m Bool) (as : Array α) (start :=
as.foldlM (init := #[]) (start := start) (stop := stop) fun r a => do
if ( p a) then return r.push a else return r
@[inline]
def filterRevM {α : Type} [Monad m] (p : α m Bool) (as : Array α) (start := as.size) (stop := 0) : m (Array α) :=
reverse <$> as.foldrM (init := #[]) (start := start) (stop := stop) fun a r => do
if ( p a) then return r.push a else return r
@[specialize]
def filterMapM [Monad m] (f : α m (Option β)) (as : Array α) (start := 0) (stop := as.size) : m (Array β) :=
as.foldlM (init := #[]) (start := start) (stop := stop) fun bs a => do
@@ -773,24 +814,6 @@ def partition (p : α → Bool) (as : Array α) : Array α × Array α := Id.run
cs := cs.push a
return (bs, cs)
def reverse (as : Array α) : Array α :=
if h : as.size 1 then
as
else
loop as 0 as.size - 1, Nat.pred_lt (mt (fun h : as.size = 0 => h by decide) h)
where
termination {i j : Nat} (h : i < j) : j - 1 - (i + 1) < j - i := by
rw [Nat.sub_sub, Nat.add_comm]
exact Nat.lt_of_le_of_lt (Nat.pred_le _) (Nat.sub_succ_lt_self _ _ h)
loop (as : Array α) (i : Nat) (j : Fin as.size) :=
if h : i < j then
have := termination h
let as := as.swap i j (Nat.lt_trans h j.2)
have : j-1 < as.size := by rw [size_swap]; exact Nat.lt_of_le_of_lt (Nat.pred_le _) j.2
loop as (i+1) j-1, this
else
as
@[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
def popWhile (p : α Bool) (as : Array α) : Array α :=
if h : as.size > 0 then

View File

@@ -3759,6 +3759,27 @@ namespace List
as.toArray.unzip = Prod.map List.toArray List.toArray as.unzip := by
ext1 <;> simp
@[simp] theorem firstM_toArray [Alternative m] (as : List α) (f : α m β) :
as.toArray.firstM f = as.firstM f := by
unfold Array.firstM
suffices i, i as.length firstM.go f as.toArray (as.length - i) = firstM f (as.drop (as.length - i)) by
specialize this as.length
simpa
intro i
induction i with
| zero => simp [firstM.go]
| succ i ih =>
unfold firstM.go
split <;> rename_i h
· rw [drop_eq_getElem_cons h]
intro h'
specialize ih (by omega)
have : as.length - (i + 1) + 1 = as.length - i := by omega
simp_all [ih]
· simp only [size_toArray, Nat.not_lt] at h
have : as.length = 0 := by omega
simp_all
end List
namespace Array

View File

@@ -83,7 +83,7 @@ theorem foldrM_filter [Monad m] [LawfulMonad m] (p : α → Bool) (g : α → β
@[congr] theorem forM_congr [Monad m] {as bs : Array α} (w : as = bs)
{f : α m PUnit} :
forM f as = forM f bs := by
as.forM f = bs.forM f := by
cases as <;> cases bs
simp_all

View File

@@ -98,6 +98,7 @@ def forA {m : Type u → Type v} [Applicative m] {α : Type w} (as : List α) (f
| [] => pure
| a :: as => f a *> forA as f
@[specialize]
def filterAuxM {m : Type Type v} [Monad m] {α : Type} (f : α m Bool) : List α List α m (List α)
| [], acc => pure acc
@@ -136,6 +137,19 @@ def filterMapM {m : Type u → Type v} [Monad m] {α β : Type u} (f : α → m
| some b => loop as (b::bs)
loop as []
/--
Applies the monadic function `f` on every element `x` in the list, left-to-right, and returns the
concatenation of the results.
-/
@[inline]
def flatMapM {m : Type u Type v} [Monad m] {α : Type w} {β : Type u} (f : α m (List β)) (as : List α) : m (List β) :=
let rec @[specialize] loop
| [], bs => pure bs.reverse
| a :: as, bs => do
let bs' f a
loop as (bs' ++ bs)
loop as []
/--
Folds a monadic function over a list from left to right:
```

View File

@@ -15,17 +15,15 @@ namespace List
/-! ## Operations using indexes -/
/-! ### mapIdx -/
/--
Given a list `as = [a₀, a₁, ...]` function `f : Fin as.lengthα → β`, returns the list
`[f 0 a₀, f 1 a₁, ...]`.
Given a list `as = [a₀, a₁, ...]` and a function `f : (i : Nat) → α → (h : i < as.length) → β`, returns the list
`[f 0 a₀, f 1 a₁, ...]`.
-/
@[inline] def mapFinIdx (as : List α) (f : (i : Nat) α (h : i < as.length) β) : List β :=
go as #[] (by simp)
where
/-- Auxiliary for `mapFinIdx`:
`mapFinIdx.go [a₀, a₁, ...] acc = acc.toList ++ [f 0 a₀, f 1 a₁, ...]` -/
`mapFinIdx.go [a₀, a₁, ...] acc = acc.toList ++ [f 0 a₀, f 1 a₁, ...]` -/
@[specialize] go : (bs : List α) (acc : Array β) bs.length + acc.size = as.length List β
| [], acc, h => acc.toList
| a :: as, acc, h =>
@@ -42,6 +40,31 @@ Given a function `f : Nat → α → β` and `as : List α`, `as = [a₀, a₁,
| [], acc => acc.toList
| a :: as, acc => go as (acc.push (f acc.size a))
/--
Given a list `as = [a₀, a₁, ...]` and a monadic function `f : (i : Nat) → α → (h : i < as.length) → m β`,
returns the list `[f 0 a₀ ⋯, f 1 a₁ ⋯, ...]`.
-/
@[inline] def mapFinIdxM [Monad m] (as : List α) (f : (i : Nat) α (h : i < as.length) m β) : m (List β) :=
go as #[] (by simp)
where
/-- Auxiliary for `mapFinIdxM`:
`mapFinIdxM.go [a₀, a₁, ...] acc = acc.toList ++ [f 0 a₀ ⋯, f 1 a₁ ⋯, ...]` -/
@[specialize] go : (bs : List α) (acc : Array β) bs.length + acc.size = as.length m (List β)
| [], acc, h => pure acc.toList
| a :: as, acc, h => do
go as (acc.push ( f acc.size a (by simp at h; omega))) (by simp at h ; omega)
/--
Given a monadic function `f : Nat → α → m β` and `as : List α`, `as = [a₀, a₁, ...]`,
returns the list `[f 0 a₀, f 1 a₁, ...]`.
-/
@[inline] def mapIdxM [Monad m] (f : Nat α m β) (as : List α) : m (List β) := go as #[] where
/-- Auxiliary for `mapIdxM`:
`mapIdxM.go [a₀, a₁, ...] acc = acc.toList ++ [f acc.size a₀, f (acc.size + 1) a₁, ...]` -/
@[specialize] go : List α Array β m (List β)
| [], acc => pure acc.toList
| a :: as, acc => do go as (acc.push ( f acc.size a))
/-! ### mapFinIdx -/
@[congr] theorem mapFinIdx_congr {xs ys : List α} (w : xs = ys)

View File

@@ -28,7 +28,11 @@ attribute [simp] mapA forA filterAuxM firstM anyM allM findM? findSomeM?
/-! ### mapM -/
/-- Alternate (non-tail-recursive) form of mapM for proofs. -/
/-- Alternate (non-tail-recursive) form of mapM for proofs.
Note that we can not have this as the main definition and replace it using a `@[csimp]` lemma,
because they are only equal when `m` is a `LawfulMonad`.
-/
def mapM' [Monad m] (f : α m β) : List α m (List β)
| [] => pure []
| a :: l => return ( f a) :: ( l.mapM' f)

View File

@@ -180,7 +180,7 @@ which also receives the index of the element, and the fact that the index is les
v.toArray.mapFinIdx (fun i a h => f i a (by simpa [v.size_toArray] using h)), by simp
/-- Map a monadic function over a vector. -/
def mapM [Monad m] (f : α m β) (v : Vector α n) : m (Vector β n) := do
@[inline] def mapM [Monad m] (f : α m β) (v : Vector α n) : m (Vector β n) := do
go 0 (Nat.zero_le n) #v[]
where
go (i : Nat) (h : i n) (r : Vector β i) : m (Vector β n) := do
@@ -189,6 +189,40 @@ where
else
return r.cast (by omega)
@[inline] def forM [Monad m] (v : Vector α n) (f : α m PUnit) : m PUnit :=
v.toArray.forM f
@[inline] def flatMapM [Monad m] (v : Vector α n) (f : α m (Vector β k)) : m (Vector β (n * k)) := do
go 0 (Nat.zero_le n) (#v[].cast (by omega))
where
go (i : Nat) (h : i n) (r : Vector β (i * k)) : m (Vector β (n * k)) := do
if h' : i < n then
go (i+1) (by omega) ((r ++ ( f v[i])).cast (Nat.succ_mul i k).symm)
else
return r.cast (by congr; omega)
/-- Variant of `mapIdxM` which receives the index `i` along with the bound `i < n. -/
@[inline]
def mapFinIdxM {α : Type u} {β : Type v} {m : Type v Type w} [Monad m]
(as : Vector α n) (f : (i : Nat) α (h : i < n) m β) : m (Vector β n) :=
let rec @[specialize] map (i : Nat) (j : Nat) (inv : i + j = n) (bs : Vector β (n - i)) : m (Vector β n) := do
match i, inv with
| 0, _ => pure bs
| i+1, inv =>
have j_lt : j < n := by
rw [ inv, Nat.add_assoc, Nat.add_comm 1 j, Nat.add_comm]
apply Nat.le_add_right
have : i + (j + 1) = n := by rw [ inv, Nat.add_comm j 1, Nat.add_assoc]
map i (j+1) this ((bs.push ( f j as[j] j_lt)).cast (by omega))
map n 0 rfl (#v[].cast (by simp))
@[inline]
def mapIdxM {α : Type u} {β : Type v} {m : Type v Type w} [Monad m] (f : Nat α m β) (as : Vector α n) : m (Vector β n) :=
as.mapFinIdxM fun i a _ => f i a
@[inline] def firstM {α : Type u} {m : Type v Type w} [Alternative m] (f : α m β) (as : Vector α n) : m β :=
as.toArray.firstM f
@[inline] def flatten (v : Vector (Vector α n) m) : Vector α (m * n) :=
(v.toArray.map Vector.toArray).flatten,
by rcases v; simp_all [Function.comp_def, Array.map_const']
@@ -309,6 +343,16 @@ no element of the index matches the given value.
@[inline] def indexOf? [BEq α] (v : Vector α n) (x : α) : Option (Fin n) :=
(v.toArray.indexOf? x).map (Fin.cast v.size_toArray)
/--
Note that the universe level is contrained to `Type` here,
to avoid having to have the predicate live in `p : α → m (ULift Bool)`.
-/
@[inline] def findM? {α : Type} {m : Type Type} [Monad m] (f : α m Bool) (as : Vector α n) : m (Option α) :=
as.toArray.findM? f
@[inline] def findSomeM? [Monad m] (f : α m (Option β)) (as : Vector α n) : m (Option β) :=
as.toArray.findSomeM? f
/-- Returns `true` when `v` is a prefix of the vector `w`. -/
@[inline] def isPrefixOf [BEq α] (v : Vector α m) (w : Vector α n) : Bool :=
v.toArray.isPrefixOf w.toArray
@@ -345,6 +389,10 @@ no element of the index matches the given value.
instance : ForIn' m (Vector α n) α inferInstance where
forIn' v b f := Array.forIn' v.toArray b (fun a h b => f a (by simpa using h) b)
/-! ### ForM instance -/
instance : ForM m (Vector α n) α where
forM := forM
/-! ### ToStream instance -/
instance : ToStream (Vector α n) (Subarray α) where

View File

@@ -104,6 +104,12 @@ theorem toArray_mk (a : Array α) (h : a.size = n) : (Vector.mk a h).toArray = a
@[simp] theorem indexOf?_mk [BEq α] (a : Array α) (h : a.size = n) (x : α) :
(Vector.mk a h).indexOf? x = (a.indexOf? x).map (Fin.cast h) := rfl
@[simp] theorem findM?_mk [Monad m] (a : Array α) (h : a.size = n) (f : α m Bool) :
(Vector.mk a h).findM? f = a.findM? f := rfl
@[simp] theorem findSomeM?_mk [Monad m] (a : Array α) (h : a.size = n) (f : α m (Option β)) :
(Vector.mk a h).findSomeM? f = a.findSomeM? f := rfl
@[simp] theorem mk_isEqv_mk (r : α α Bool) (a b : Array α) (ha : a.size = n) (hb : b.size = n) :
Vector.isEqv (Vector.mk a ha) (Vector.mk b hb) r = Array.isEqv a b r := by
simp [Vector.isEqv, Array.isEqv, ha, hb]
@@ -121,6 +127,16 @@ theorem toArray_mk (a : Array α) (h : a.size = n) : (Vector.mk a h).toArray = a
(Vector.mk a h).mapFinIdx f =
Vector.mk (a.mapFinIdx fun i a h' => f i a (by simpa [h] using h')) (by simp [h]) := rfl
@[simp] theorem forM_mk [Monad m] (f : α m PUnit) (a : Array α) (h : a.size = n) :
(Vector.mk a h).forM f = a.forM f := rfl
@[simp] theorem flatMap_mk (f : α Vector β m) (a : Array α) (h : a.size = n) :
(Vector.mk a h).flatMap f =
Vector.mk (a.flatMap (fun a => (f a).toArray)) (by simp [h, Array.map_const']) := rfl
@[simp] theorem firstM_mk [Alternative m] (f : α m β) (a : Array α) (h : a.size = n) :
(Vector.mk a h).firstM f = a.firstM f := rfl
@[simp] theorem reverse_mk (a : Array α) (h : a.size = n) :
(Vector.mk a h).reverse = Vector.mk a.reverse (by simp [h]) := rfl
@@ -1656,11 +1672,6 @@ theorem eq_iff_flatten_eq {L L' : Vector (Vector α n) m} :
/-! ### flatMap -/
@[simp] theorem flatMap_mk (l : Array α) (h : l.size = m) (f : α Vector β n) :
(mk l h).flatMap f =
mk (l.flatMap (fun a => (f a).toArray)) (by simp [Array.map_const', h]) := by
simp [flatMap]
@[simp] theorem flatMap_toArray (l : Vector α n) (f : α Vector β m) :
l.toArray.flatMap (fun a => (f a).toArray) = (l.flatMap f).toArray := by
rcases l with l, rfl

View File

@@ -0,0 +1,111 @@
-- This files tracks the implementation of monadic functions for lists, arrays, and vectors.
-- This is just about the definitions, not the theorems.
#check List.mapM
#check Array.mapM
#check Vector.mapM
#check List.flatMapM
#check Array.flatMapM
#check Vector.flatMapM
#check List.mapFinIdxM
#check Array.mapFinIdxM
#check Vector.mapFinIdxM
#check List.mapIdxM
#check Array.mapIdxM
#check Vector.mapIdxM
#check List.firstM
#check Array.firstM
#check Vector.firstM
#check List.forM
#check Array.forM
#check Vector.forM
#check List.filterM
#check Array.filterM
#check List.filterRevM
#check Array.filterRevM
#check List.filterMapM
#check Array.filterMapM
#check List.foldlM
#check Array.foldlM
#check Vector.foldlM
#check List.foldrM
#check Array.foldrM
#check Vector.foldrM
#check List.findM?
#check Array.findM?
#check Vector.findM?
#check List.findSomeM?
#check Array.findSomeM?
#check Vector.findSomeM?
#check List.anyM
#check Array.anyM
#check Vector.anyM
#check List.allM
#check Array.allM
#check Vector.allM
variable {m : Type v Type w} [Monad m] {α : Type} {n : Nat}
#synth ForIn' m (List α) α inferInstance
#synth ForIn' m (Array α) α inferInstance
#synth ForIn' m (Vector α n) α inferInstance
#check List.forM
#check Array.forM
#check Vector.forM
#synth ForM m (List α) α
#synth ForM m (Array α) α
#synth ForM m (Vector α n) α
#synth Functor List
#synth Functor Array
-- These operations still have discrepancies.
-- #check List.modifyM
#check Array.modifyM
-- #check Vector.modifyM
-- #check List.forRevM
#check Array.forRevM
-- #check Vector.forRevM
-- #check List.findRevM?
#check Array.findRevM?
-- #check Vector.findRevM?
-- #check List.findSomeRevM?
#check Array.findSomeRevM?
-- #check Vector.findSomeRevM?
-- #check List.findIdxM?
#check Array.findIdxM?
-- #check Vector.findIdxM?
-- The following have not been implemented for any of the containers.
-- #check List.foldlIdxM
-- #check Array.foldlIdxM
-- #check Vector.foldlIdxM
-- #check List.foldrIdxM
-- #check Array.foldrIdxM
-- #check Vector.foldrIdxM
-- #check List.ofFnM
-- #check Array.ofFnM
-- #check Vector.ofFnM