Compare commits

...

1 Commits

Author SHA1 Message Date
Kim Morrison
c5facc7547 feat: align List/Array/Vector.flatMap 2025-01-16 13:16:17 +11:00
6 changed files with 193 additions and 15 deletions

View File

@@ -1560,6 +1560,11 @@ theorem filterMap_eq_push_iff {f : α → Option β} {l : Array α} {l' : Array
cases bs
simp
theorem toArray_append {xs : List α} {ys : Array α} :
xs.toArray ++ ys = (xs ++ ys.toList).toArray := by
rcases ys with ys
simp
@[simp] theorem toArray_eq_append_iff {xs : List α} {as bs : Array α} :
xs.toArray = as ++ bs xs = as.toList ++ bs.toList := by
cases as
@@ -1871,6 +1876,11 @@ theorem append_eq_map_iff {f : α → β} :
rw [ flatten_map_toArray]
simp
theorem flatten_toArray (l : List (Array α)) :
l.toArray.flatten = (l.map Array.toList).flatten.toArray := by
apply ext'
simp
@[simp] theorem size_flatten (L : Array (Array α)) : L.flatten.size = (L.map size).sum := by
cases L using array₂_induction
simp [Function.comp_def]
@@ -1886,14 +1896,14 @@ theorem mem_flatten : ∀ {L : Array (Array α)}, a ∈ L.flatten ↔ ∃ l, l
· rintro s, h₁, h₂
refine s.toList, s, h₁, rfl, h₂
@[simp] theorem flatten_eq_nil_iff {L : Array (Array α)} : L.flatten = #[] l L, l = #[] := by
@[simp] theorem flatten_eq_empty_iff {L : Array (Array α)} : L.flatten = #[] l L, l = #[] := by
induction L using array₂_induction
simp
@[simp] theorem nil_eq_flatten_iff {L : Array (Array α)} : #[] = L.flatten l L, l = #[] := by
rw [eq_comm, flatten_eq_nil_iff]
@[simp] theorem empty_eq_flatten_iff {L : Array (Array α)} : #[] = L.flatten l L, l = #[] := by
rw [eq_comm, flatten_eq_empty_iff]
theorem flatten_ne_nil_iff {xs : Array (Array α)} : xs.flatten #[] x, x xs x #[] := by
theorem flatten_ne_empty_iff {xs : Array (Array α)} : xs.flatten #[] x, x xs x #[] := by
simp
theorem exists_of_mem_flatten : a flatten L l, l L a l := mem_flatten.1
@@ -2029,6 +2039,102 @@ theorem eq_iff_flatten_eq {L L' : Array (Array α)} :
rw [List.map_inj_right]
simp +contextual
/-! ### flatMap -/
theorem flatMap_def (l : Array α) (f : α Array β) : l.flatMap f = flatten (map f l) := by
rcases l with l
simp [flatten_toArray, Function.comp_def, List.flatMap_def]
theorem flatMap_toList (l : Array α) (f : α List β) :
l.toList.flatMap f = (l.flatMap (fun a => (f a).toArray)).toList := by
rcases l with l
simp
@[simp] theorem flatMap_id (l : Array (Array α)) : l.flatMap id = l.flatten := by simp [flatMap_def]
@[simp] theorem flatMap_id' (l : Array (Array α)) : l.flatMap (fun a => a) = l.flatten := by simp [flatMap_def]
@[simp]
theorem size_flatMap (l : Array α) (f : α Array β) :
(l.flatMap f).size = sum (map (fun a => (f a).size) l) := by
rcases l with l
simp [Function.comp_def]
@[simp] theorem mem_flatMap {f : α Array β} {b} {l : Array α} : b l.flatMap f a, a l b f a := by
simp [flatMap_def, mem_flatten]
exact fun _, a, h₁, rfl, h₂ => a, h₁, h₂, fun a, h₁, h₂ => _, a, h₁, rfl, h₂
theorem exists_of_mem_flatMap {b : β} {l : Array α} {f : α Array β} :
b l.flatMap f a, a l b f a := mem_flatMap.1
theorem mem_flatMap_of_mem {b : β} {l : Array α} {f : α Array β} {a} (al : a l) (h : b f a) :
b l.flatMap f := mem_flatMap.2 a, al, h
@[simp]
theorem flatMap_eq_empty_iff {l : Array α} {f : α Array β} : l.flatMap f = #[] x l, f x = #[] := by
rw [flatMap_def, flatten_eq_empty_iff]
simp
theorem forall_mem_flatMap {p : β Prop} {l : Array α} {f : α Array β} :
( (x) (_ : x l.flatMap f), p x) (a) (_ : a l) (b) (_ : b f a), p b := by
simp only [mem_flatMap, forall_exists_index, and_imp]
constructor <;> (intros; solve_by_elim)
theorem flatMap_singleton (f : α Array β) (x : α) : #[x].flatMap f = f x := by
simp
@[simp] theorem flatMap_singleton' (l : Array α) : (l.flatMap fun x => #[x]) = l := by
rcases l with l
simp
@[simp] theorem flatMap_append (xs ys : Array α) (f : α Array β) :
(xs ++ ys).flatMap f = xs.flatMap f ++ ys.flatMap f := by
rcases xs with xs
rcases ys with ys
simp
theorem flatMap_assoc {α β} (l : Array α) (f : α Array β) (g : β Array γ) :
(l.flatMap f).flatMap g = l.flatMap fun x => (f x).flatMap g := by
rcases l with l
simp [List.flatMap_assoc, flatMap_toList]
theorem map_flatMap (f : β γ) (g : α Array β) (l : Array α) :
(l.flatMap g).map f = l.flatMap fun a => (g a).map f := by
rcases l with l
simp [List.map_flatMap]
theorem flatMap_map (f : α β) (g : β Array γ) (l : Array α) :
(map f l).flatMap g = l.flatMap (fun a => g (f a)) := by
rcases l with l
simp [List.flatMap_map]
theorem map_eq_flatMap {α β} (f : α β) (l : Array α) : map f l = l.flatMap fun x => #[f x] := by
simp only [ map_singleton]
rw [ flatMap_singleton' l, map_flatMap, flatMap_singleton']
theorem filterMap_flatMap {β γ} (l : Array α) (g : α Array β) (f : β Option γ) :
(l.flatMap g).filterMap f = l.flatMap fun a => (g a).filterMap f := by
rcases l with l
simp [List.filterMap_flatMap]
theorem filter_flatMap (l : Array α) (g : α Array β) (f : β Bool) :
(l.flatMap g).filter f = l.flatMap fun a => (g a).filter f := by
rcases l with l
simp [List.filter_flatMap]
theorem flatMap_eq_foldl (f : α Array β) (l : Array α) :
l.flatMap f = l.foldl (fun acc a => acc ++ f a) #[] := by
rcases l with l
simp only [List.flatMap_toArray, List.flatMap_eq_foldl, size_toArray, List.foldl_toArray']
suffices l', (List.foldl (fun acc a => acc ++ (f a).toList) l' l).toArray =
List.foldl (fun acc a => acc ++ f a) l'.toArray l by
simpa using this []
induction l with
| nil => simp
| cons a l ih =>
intro l'
simp [ih ((l' ++ (f a).toList)), toArray_append]
/-! Content below this point has not yet been aligned with `List`. -/
-- This is a duplicate of `List.toArray_toList`.

View File

@@ -606,11 +606,11 @@ set_option linter.missingDocs false in
to get a list of lists, and then concatenates them all together.
* `[2, 3, 2].bind range = [0, 1, 0, 1, 2, 0, 1]`
-/
@[inline] def flatMap {α : Type u} {β : Type v} (a : List α) (b : α List β) : List β := flatten (map b a)
@[inline] def flatMap {α : Type u} {β : Type v} (b : α List β) (a : List α) : List β := flatten (map b a)
@[simp] theorem flatMap_nil (f : α List β) : List.flatMap [] f = [] := by simp [flatten, List.flatMap]
@[simp] theorem flatMap_nil (f : α List β) : List.flatMap f [] = [] := by simp [flatten, List.flatMap]
@[simp] theorem flatMap_cons x xs (f : α List β) :
List.flatMap (x :: xs) f = f x ++ List.flatMap xs f := by simp [flatten, List.flatMap]
List.flatMap f (x :: xs) = f x ++ List.flatMap f xs := by simp [flatten, List.flatMap]
set_option linter.missingDocs false in
@[deprecated flatMap (since := "2024-10-16")] abbrev bind := @flatMap

View File

@@ -96,14 +96,14 @@ The following operations are given `@[csimp]` replacements below:
/-! ### flatMap -/
/-- Tail recursive version of `List.flatMap`. -/
@[inline] def flatMapTR (as : List α) (f : α List β) : List β := go as #[] where
@[inline] def flatMapTR (f : α List β) (as : List α) : List β := go as #[] where
/-- Auxiliary for `flatMap`: `flatMap.go f as = acc.toList ++ bind f as` -/
@[specialize] go : List α Array β List β
| [], acc => acc.toList
| x::xs, acc => go xs (acc ++ f x)
@[csimp] theorem flatMap_eq_flatMapTR : @List.flatMap = @flatMapTR := by
funext α β as f
funext α β f as
let rec go : as acc, flatMapTR.go f as acc = acc.toList ++ as.flatMap f
| [], acc => by simp [flatMapTR.go, flatMap]
| x::xs, acc => by simp [flatMapTR.go, flatMap, go xs]
@@ -112,7 +112,7 @@ The following operations are given `@[csimp]` replacements below:
/-! ### flatten -/
/-- Tail recursive version of `List.flatten`. -/
@[inline] def flattenTR (l : List (List α)) : List α := flatMapTR l id
@[inline] def flattenTR (l : List (List α)) : List α := l.flatMapTR id
@[csimp] theorem flatten_eq_flattenTR : @flatten = @flattenTR := by
funext α l; rw [ List.flatMap_id, List.flatMap_eq_flatMapTR]; rfl

View File

@@ -2070,14 +2070,14 @@ theorem eq_iff_flatten_eq : ∀ {L L' : List (List α)},
theorem flatMap_def (l : List α) (f : α List β) : l.flatMap f = flatten (map f l) := by rfl
@[simp] theorem flatMap_id (l : List (List α)) : List.flatMap l id = l.flatten := by simp [flatMap_def]
@[simp] theorem flatMap_id (l : List (List α)) : l.flatMap id = l.flatten := by simp [flatMap_def]
@[simp] theorem flatMap_id' (l : List (List α)) : List.flatMap l (fun a => a) = l.flatten := by simp [flatMap_def]
@[simp] theorem flatMap_id' (l : List (List α)) : l.flatMap (fun a => a) = l.flatten := by simp [flatMap_def]
@[simp]
theorem length_flatMap (l : List α) (f : α List β) :
length (l.flatMap f) = sum (map (length f) l) := by
rw [List.flatMap, length_flatten, map_map]
length (l.flatMap f) = sum (map (fun a => (f a).length) l) := by
rw [List.flatMap, length_flatten, map_map, Function.comp_def]
@[simp] theorem mem_flatMap {f : α List β} {b} {l : List α} : b l.flatMap f a, a l b f a := by
simp [flatMap_def, mem_flatten]
@@ -2090,7 +2090,7 @@ theorem mem_flatMap_of_mem {b : β} {l : List α} {f : α → List β} {a} (al :
b l.flatMap f := mem_flatMap.2 a, al, h
@[simp]
theorem flatMap_eq_nil_iff {l : List α} {f : α List β} : List.flatMap l f = [] x l, f x = [] :=
theorem flatMap_eq_nil_iff {l : List α} {f : α List β} : l.flatMap f = [] x l, f x = [] :=
flatten_eq_nil_iff.trans <| by
simp only [mem_map, forall_exists_index, and_imp, forall_apply_eq_imp_iff₂]

View File

@@ -174,6 +174,9 @@ result is empty. If `stop` is greater than the size of the vector, the size is u
(v.toArray.map Vector.toArray).flatten,
by rcases v; simp_all [Function.comp_def, Array.map_const']
@[inline] def flatMap (v : Vector α n) (f : α Vector β m) : Vector β (n * m) :=
v.toArray.flatMap fun a => (f a).toArray, by simp [Array.map_const']
/-- Maps corresponding elements of two vectors of equal size using the function `f`. -/
@[inline] def zipWith (a : Vector α n) (b : Vector β n) (f : α β φ) : Vector φ n :=
Array.zipWith a.toArray b.toArray f, by simp

View File

@@ -1525,6 +1525,75 @@ theorem eq_iff_flatten_eq {L L' : Vector (Vector α n) m} :
subst this
rfl
/-! ### flatMap -/
@[simp] theorem flatMap_mk (l : Array α) (h : l.size = m) (f : α Vector β n) :
(mk l h).flatMap f =
mk (l.flatMap (fun a => (f a).toArray)) (by simp [Array.map_const', h]) := by
simp [flatMap]
@[simp] theorem flatMap_toArray (l : Vector α n) (f : α Vector β m) :
l.toArray.flatMap (fun a => (f a).toArray) = (l.flatMap f).toArray := by
rcases l with l, rfl
simp
theorem flatMap_def (l : Vector α n) (f : α Vector β m) : l.flatMap f = flatten (map f l) := by
rcases l with l, rfl
simp [Array.flatMap_def, Function.comp_def]
@[simp] theorem flatMap_id (l : Vector (Vector α m) n) : l.flatMap id = l.flatten := by simp [flatMap_def]
@[simp] theorem flatMap_id' (l : Vector (Vector α m) n) : l.flatMap (fun a => a) = l.flatten := by simp [flatMap_def]
@[simp] theorem mem_flatMap {f : α Vector β m} {b} {l : Vector α n} : b l.flatMap f a, a l b f a := by
simp [flatMap_def, mem_flatten]
exact fun _, a, h₁, rfl, h₂ => a, h₁, h₂, fun a, h₁, h₂ => _, a, h₁, rfl, h₂
theorem exists_of_mem_flatMap {b : β} {l : Vector α n} {f : α Vector β m} :
b l.flatMap f a, a l b f a := mem_flatMap.1
theorem mem_flatMap_of_mem {b : β} {l : Vector α n} {f : α Vector β m} {a} (al : a l) (h : b f a) :
b l.flatMap f := mem_flatMap.2 a, al, h
theorem forall_mem_flatMap {p : β Prop} {l : Vector α n} {f : α Vector β m} :
( (x) (_ : x l.flatMap f), p x) (a) (_ : a l) (b) (_ : b f a), p b := by
simp only [mem_flatMap, forall_exists_index, and_imp]
constructor <;> (intros; solve_by_elim)
theorem flatMap_singleton (f : α Vector β m) (x : α) : #v[x].flatMap f = (f x).cast (by simp) := by
simp [flatMap_def]
@[simp] theorem flatMap_singleton' (l : Vector α n) : (l.flatMap fun x => #v[x]) = l.cast (by simp) := by
rcases l with l, rfl
simp
@[simp] theorem flatMap_append (xs ys : Vector α n) (f : α Vector β m) :
(xs ++ ys).flatMap f = (xs.flatMap f ++ ys.flatMap f).cast (by simp [Nat.add_mul]) := by
rcases xs with xs
rcases ys with ys
simp [flatMap_def, flatten_append]
theorem flatMap_assoc {α β} (l : Vector α n) (f : α Vector β m) (g : β Vector γ k) :
(l.flatMap f).flatMap g = (l.flatMap fun x => (f x).flatMap g).cast (by simp [Nat.mul_assoc]) := by
rcases l with l, rfl
simp [Array.flatMap_assoc]
theorem map_flatMap (f : β γ) (g : α Vector β m) (l : Vector α n) :
(l.flatMap g).map f = l.flatMap fun a => (g a).map f := by
rcases l with l, rfl
simp [Array.map_flatMap]
theorem flatMap_map (f : α β) (g : β Vector γ k) (l : Vector α n) :
(map f l).flatMap g = l.flatMap (fun a => g (f a)) := by
rcases l with l, rfl
simp [Array.flatMap_map]
theorem map_eq_flatMap {α β} (f : α β) (l : Vector α n) :
map f l = (l.flatMap fun x => #v[f x]).cast (by simp) := by
rcases l with l, rfl
simp [Array.map_eq_flatMap]
/-! Content below this point has not yet been aligned with `List` and `Array`. -/
@[simp] theorem getElem_ofFn {α n} (f : Fin n α) (i : Nat) (h : i < n) :