mirror of
https://github.com/leanprover/lean4.git
synced 2026-03-17 18:34:06 +00:00
Compare commits
13 Commits
57df23f27e
...
paul/array
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3be124112c | ||
|
|
c74a93b72c | ||
|
|
bba4444539 | ||
|
|
1611ac1067 | ||
|
|
8c6f1d825b | ||
|
|
05b773ff0c | ||
|
|
f81b24c9f0 | ||
|
|
a253f86d88 | ||
|
|
2379950f38 | ||
|
|
3a253294e0 | ||
|
|
c76c0f9c42 | ||
|
|
0a5285bb4e | ||
|
|
87f9c0e808 |
@@ -34,3 +34,4 @@ public import Init.Data.Array.MinMax
|
||||
public import Init.Data.Array.Nat
|
||||
public import Init.Data.Array.Int
|
||||
public import Init.Data.Array.Count
|
||||
public import Init.Data.Array.Sort
|
||||
|
||||
10
src/Init/Data/Array/Sort.lean
Normal file
10
src/Init/Data/Array/Sort.lean
Normal file
@@ -0,0 +1,10 @@
|
||||
/-
|
||||
Copyright (c) 2026 Lean FRO. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Paul Reichert
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Init.Data.Array.Sort.Basic
|
||||
public import Init.Data.Array.Sort.Lemmas
|
||||
55
src/Init/Data/Array/Sort/Basic.lean
Normal file
55
src/Init/Data/Array/Sort/Basic.lean
Normal file
@@ -0,0 +1,55 @@
|
||||
/-
|
||||
Copyright (c) 2026 Lean FRO. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Paul Reichert
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Init.Data.Array.Subarray.Split
|
||||
public import Init.Data.Slice.Array
|
||||
import Init.Omega
|
||||
|
||||
public section
|
||||
|
||||
private def Array.MergeSort.Internal.merge (xs ys : Array α) (le : α → α → Bool := by exact (· ≤ ·)) :
|
||||
Array α :=
|
||||
if hxs : 0 < xs.size then
|
||||
if hys : 0 < ys.size then
|
||||
go xs[*...*] ys[*...*] (by simp only [Array.size_mkSlice_rii]; omega) (by simp only [Array.size_mkSlice_rii]; omega) (Array.emptyWithCapacity (xs.size + ys.size))
|
||||
else
|
||||
xs
|
||||
else
|
||||
ys
|
||||
where
|
||||
go (xs ys : Subarray α) (hxs : 0 < xs.size) (hys : 0 < ys.size) (acc : Array α) : Array α :=
|
||||
let x := xs[0]
|
||||
let y := ys[0]
|
||||
if le x y then
|
||||
if hi : 1 < xs.size then
|
||||
go (xs.drop 1) ys (by simp only [Subarray.size_drop]; omega) hys (acc.push x)
|
||||
else
|
||||
ys.foldl (init := acc.push x) (fun acc y => acc.push y)
|
||||
else
|
||||
if hj : 1 < ys.size then
|
||||
go xs (ys.drop 1) hxs (by simp only [Subarray.size_drop]; omega) (acc.push y)
|
||||
else
|
||||
xs.foldl (init := acc.push y) (fun acc x => acc.push x)
|
||||
termination_by xs.size + ys.size
|
||||
|
||||
def Subarray.mergeSort (xs : Subarray α) (le : α → α → Bool := by exact (· ≤ ·)) : Array α :=
|
||||
if h : 1 < xs.size then
|
||||
let splitIdx := (xs.size + 1) / 2 -- We follow the same splitting convention as `List.mergeSort`
|
||||
let left := xs[*...splitIdx]
|
||||
let right := xs[splitIdx...*]
|
||||
Array.MergeSort.Internal.merge (mergeSort left le) (mergeSort right le) le
|
||||
else
|
||||
xs.toArray
|
||||
termination_by xs.size
|
||||
decreasing_by
|
||||
· simp only [Subarray.size_mkSlice_rio]; omega
|
||||
· simp only [Subarray.size_mkSlice_rci]; omega
|
||||
|
||||
@[inline]
|
||||
def Array.mergeSort (xs : Array α) (le : α → α → Bool := by exact (· ≤ ·)) : Array α :=
|
||||
xs[*...*].mergeSort le
|
||||
240
src/Init/Data/Array/Sort/Lemmas.lean
Normal file
240
src/Init/Data/Array/Sort/Lemmas.lean
Normal file
@@ -0,0 +1,240 @@
|
||||
/-
|
||||
Copyright (c) 2026 Lean FRO. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Paul Reichert
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Init.Data.Array.Sort.Basic
|
||||
public import Init.Data.List.Sort.Basic
|
||||
public import Init.Data.Array.Perm
|
||||
import all Init.Data.Array.Sort.Basic
|
||||
import all Init.Data.List.Sort.Basic
|
||||
import Init.Data.List.Sort.Lemmas
|
||||
import Init.Data.Slice.Array.Lemmas
|
||||
import Init.Data.Slice.List.Lemmas
|
||||
import Init.Data.Array.Bootstrap
|
||||
import Init.Data.Array.Lemmas
|
||||
import Init.Data.Array.MapIdx
|
||||
import Init.ByCases
|
||||
|
||||
public section
|
||||
|
||||
private theorem Array.MergeSort.merge.go_eq_listMerge {xs ys : Subarray α} {hxs hys le acc} :
|
||||
(Array.MergeSort.Internal.merge.go le xs ys hxs hys acc).toList = acc.toList ++ List.merge xs.toList ys.toList le := by
|
||||
fun_induction Array.MergeSort.Internal.merge.go le xs ys hxs hys acc
|
||||
· rename_i xs ys _ _ _ _ _ _ _ _
|
||||
rw [List.merge.eq_def]
|
||||
split
|
||||
· have : xs.size = 0 := by simp [← Subarray.length_toList, *]
|
||||
omega
|
||||
· have : ys.size = 0 := by simp [← Subarray.length_toList, *]
|
||||
omega
|
||||
· rename_i x' xs' y' ys' _ _
|
||||
simp +zetaDelta only at *
|
||||
have h₁ : x' = xs[0] := by simp [Subarray.getElem_eq_getElem_toList, *]
|
||||
have h₂ : y' = ys[0] := by simp [Subarray.getElem_eq_getElem_toList, *]
|
||||
cases h₁
|
||||
cases h₂
|
||||
simp [Subarray.toList_drop, *]
|
||||
· rename_i xs ys _ _ _ _ _ _ _
|
||||
rw [List.merge.eq_def]
|
||||
split
|
||||
· have : xs.size = 0 := by simp [← Subarray.length_toList, *]
|
||||
omega
|
||||
· have : ys.size = 0 := by simp [← Subarray.length_toList, *]
|
||||
omega
|
||||
· rename_i x' xs' y' ys' _ _
|
||||
simp +zetaDelta only at *
|
||||
have h₁ : x' = xs[0] := by simp [Subarray.getElem_eq_getElem_toList, *]
|
||||
have h₂ : y' = ys[0] := by simp [Subarray.getElem_eq_getElem_toList, *]
|
||||
cases h₁
|
||||
cases h₂
|
||||
simp [*]
|
||||
have : xs.size = xs'.length + 1 := by simp [← Subarray.length_toList, *]
|
||||
have : xs' = [] := List.eq_nil_of_length_eq_zero (by omega)
|
||||
simp only [this]
|
||||
rw [← Subarray.foldl_toList]
|
||||
simp [*]
|
||||
· rename_i xs ys _ _ _ _ _ _ _ _
|
||||
rw [List.merge.eq_def]
|
||||
split
|
||||
· have : xs.size = 0 := by simp [← Subarray.length_toList, *]
|
||||
omega
|
||||
· have : ys.size = 0 := by simp [← Subarray.length_toList, *]
|
||||
omega
|
||||
· rename_i x' xs' y' ys' _ _
|
||||
simp +zetaDelta only at *
|
||||
have h₁ : x' = xs[0] := by simp [Subarray.getElem_eq_getElem_toList, *]
|
||||
have h₂ : y' = ys[0] := by simp [Subarray.getElem_eq_getElem_toList, *]
|
||||
cases h₁
|
||||
cases h₂
|
||||
simp [Subarray.toList_drop, *]
|
||||
· rename_i xs ys _ _ _ _ _ _ _
|
||||
rw [List.merge.eq_def]
|
||||
split
|
||||
· have : xs.size = 0 := by simp [← Subarray.length_toList, *]
|
||||
omega
|
||||
· have : ys.size = 0 := by simp [← Subarray.length_toList, *]
|
||||
omega
|
||||
· rename_i x' xs' y' ys' _ _
|
||||
simp +zetaDelta only at *
|
||||
have h₁ : x' = xs[0] := by simp [Subarray.getElem_eq_getElem_toList, *]
|
||||
have h₂ : y' = ys[0] := by simp [Subarray.getElem_eq_getElem_toList, *]
|
||||
cases h₁
|
||||
cases h₂
|
||||
simp [*]
|
||||
have : ys.size = ys'.length + 1 := by simp [← Subarray.length_toList, *]
|
||||
have : ys' = [] := List.eq_nil_of_length_eq_zero (by omega)
|
||||
simp [this]
|
||||
rw [← Subarray.foldl_toList]
|
||||
simp [*]
|
||||
|
||||
private theorem Array.MergeSort.merge_eq_listMerge {xs ys : Array α} {le} :
|
||||
(Array.MergeSort.Internal.merge xs ys le).toList = List.merge xs.toList ys.toList le := by
|
||||
rw [Array.MergeSort.Internal.merge]
|
||||
split <;> rename_i heq₁
|
||||
· split <;> rename_i heq₂
|
||||
· simp [Array.MergeSort.merge.go_eq_listMerge]
|
||||
· have : ys.toList = [] := by simp_all
|
||||
simp [this]
|
||||
· have : xs.toList = [] := by simp_all
|
||||
simp [this]
|
||||
|
||||
private theorem List.mergeSort_eq_merge_mkSlice {xs : List α} :
|
||||
xs.mergeSort le =
|
||||
if 1 < xs.length then
|
||||
merge (xs[*...((xs.length + 1) / 2)].toList.mergeSort le) (xs[((xs.length + 1) / 2)...*].toList.mergeSort le) le
|
||||
else
|
||||
xs := by
|
||||
fun_cases xs.mergeSort le
|
||||
· simp
|
||||
· simp
|
||||
· rename_i x y ys lr hl hr
|
||||
simp [lr]
|
||||
|
||||
theorem Subarray.toList_mergeSort {xs : Subarray α} {le : α → α → Bool} :
|
||||
(xs.mergeSort le).toList = xs.toList.mergeSort le := by
|
||||
fun_induction xs.mergeSort le
|
||||
· rw [List.mergeSort_eq_merge_mkSlice]
|
||||
simp +zetaDelta [Array.MergeSort.merge_eq_listMerge, *]
|
||||
· simp [List.mergeSort_eq_merge_mkSlice, *]
|
||||
|
||||
@[simp, grind =]
|
||||
theorem Subarray.mergeSort_eq_mergeSort_toArray {xs : Subarray α} {le : α → α → Bool} :
|
||||
xs.mergeSort le = xs.toArray.mergeSort le := by
|
||||
simp [← Array.toList_inj, toList_mergeSort, Array.mergeSort]
|
||||
|
||||
theorem Subarray.mergeSort_toArray {xs : Subarray α} {le : α → α → Bool} :
|
||||
xs.toArray.mergeSort le = xs.mergeSort le := by
|
||||
simp
|
||||
|
||||
theorem Array.toList_mergeSort {xs : Array α} {le : α → α → Bool} :
|
||||
(xs.mergeSort le).toList = xs.toList.mergeSort le := by
|
||||
rw [Array.mergeSort, Subarray.toList_mergeSort, Array.toList_mkSlice_rii]
|
||||
|
||||
theorem Array.mergeSort_eq_toArray_mergeSort_toList {xs : Array α} {le : α → α → Bool} :
|
||||
xs.mergeSort le = (xs.toList.mergeSort le).toArray := by
|
||||
simp [← toList_mergeSort]
|
||||
|
||||
/-!
|
||||
# Basic properties of `Array.mergeSort`.
|
||||
|
||||
* `pairwise_mergeSort`: `mergeSort` produces a sorted array.
|
||||
* `mergeSort_perm`: `mergeSort` is a permutation of the input array.
|
||||
* `mergeSort_of_pairwise`: `mergeSort` does not change a sorted array.
|
||||
* `sublist_mergeSort`: if `c` is a sorted sublist of `l`, then `c` is still a sublist of `mergeSort le l`.
|
||||
-/
|
||||
|
||||
namespace Array
|
||||
|
||||
-- Enable this instance locally so we can write `Pairwise le` instead of `Pairwise (le · ·)` everywhere.
|
||||
attribute [local instance] boolRelToRel
|
||||
|
||||
@[simp] theorem mergeSort_empty : (#[] : Array α).mergeSort r = #[] := by
|
||||
simp [mergeSort_eq_toArray_mergeSort_toList]
|
||||
|
||||
@[simp] theorem mergeSort_singleton {a : α} : #[a].mergeSort r = #[a] := by
|
||||
simp [mergeSort_eq_toArray_mergeSort_toList]
|
||||
|
||||
theorem mergeSort_perm {xs : Array α} {le} : (xs.mergeSort le).Perm xs := by
|
||||
simpa [mergeSort_eq_toArray_mergeSort_toList, Array.perm_iff_toList_perm] using List.mergeSort_perm _ _
|
||||
|
||||
@[simp] theorem size_mergeSort {xs : Array α} : (mergeSort xs le).size = xs.size := by
|
||||
simp [mergeSort_eq_toArray_mergeSort_toList]
|
||||
|
||||
@[simp] theorem mem_mergeSort {a : α} {xs : Array α} : a ∈ mergeSort xs le ↔ a ∈ xs := by
|
||||
simp [mergeSort_eq_toArray_mergeSort_toList]
|
||||
|
||||
/--
|
||||
The result of `Array.mergeSort` is sorted,
|
||||
as long as the comparison function is transitive (`le a b → le b c → le a c`)
|
||||
and total in the sense that `le a b || le b a`.
|
||||
|
||||
The comparison function need not be irreflexive, i.e. `le a b` and `le b a` is allowed even when `a ≠ b`.
|
||||
-/
|
||||
theorem pairwise_mergeSort
|
||||
(trans : ∀ (a b c : α), le a b → le b c → le a c)
|
||||
(total : ∀ (a b : α), le a b || le b a)
|
||||
{xs : Array α} :
|
||||
(mergeSort xs le).toList.Pairwise (le · ·) := by
|
||||
simpa [mergeSort_eq_toArray_mergeSort_toList] using List.pairwise_mergeSort trans total _
|
||||
|
||||
/--
|
||||
If the input array is already sorted, then `mergeSort` does not change the array.
|
||||
-/
|
||||
theorem mergeSort_of_pairwise {le : α → α → Bool} {xs : Array α} (_ : xs.toList.Pairwise (le · ·)) :
|
||||
mergeSort xs le = xs := by
|
||||
simpa [mergeSort_eq_toArray_mergeSort_toList, List.toArray_eq_iff] using List.mergeSort_of_pairwise ‹_›
|
||||
|
||||
/--
|
||||
This merge sort algorithm is stable,
|
||||
in the sense that breaking ties in the ordering function using the position in the array
|
||||
has no effect on the output.
|
||||
|
||||
That is, elements which are equal with respect to the ordering function will remain
|
||||
in the same order in the output array as they were in the input array.
|
||||
|
||||
See also:
|
||||
* `sublist_mergeSort`: if `c <+ l` and `c.Pairwise le`, then `c <+ (mergeSort le l).toList`.
|
||||
* `pair_sublist_mergeSort`: if `[a, b] <+ l` and `le a b`, then `[a, b] <+ (mergeSort le l).toList`)
|
||||
-/
|
||||
theorem mergeSort_zipIdx {xs : Array α} :
|
||||
(mergeSort (xs.zipIdx.map fun (a, i) => (a, i)) (List.zipIdxLE le)).map (·.1) = mergeSort xs le := by
|
||||
simpa [mergeSort_eq_toArray_mergeSort_toList, Array.toList_zipIdx] using List.mergeSort_zipIdx
|
||||
|
||||
/--
|
||||
Another statement of stability of merge sort.
|
||||
If `c` is a sorted sublist of `xs.toList`,
|
||||
then `c` is still a sublist of `(mergeSort le xs).toList`.
|
||||
-/
|
||||
theorem sublist_mergeSort {le : α → α → Bool}
|
||||
(trans : ∀ (a b c : α), le a b → le b c → le a c)
|
||||
(total : ∀ (a b : α), le a b || le b a)
|
||||
{ys : List α} (_ : ys.Pairwise (le · ·)) (_ : List.Sublist ys xs.toList) :
|
||||
List.Sublist ys (mergeSort xs le).toList := by
|
||||
simpa [mergeSort_eq_toArray_mergeSort_toList, Array.toList_zipIdx] using
|
||||
List.sublist_mergeSort trans total ‹_› ‹_›
|
||||
|
||||
/--
|
||||
Another statement of stability of merge sort.
|
||||
If a pair `[a, b]` is a sublist of `xs.toList` and `le a b`,
|
||||
then `[a, b]` is still a sublist of `(mergeSort le xs).toList`.
|
||||
-/
|
||||
theorem pair_sublist_mergeSort
|
||||
(trans : ∀ (a b c : α), le a b → le b c → le a c)
|
||||
(total : ∀ (a b : α), le a b || le b a)
|
||||
(hab : le a b) (h : List.Sublist [a, b] xs.toList) :
|
||||
List.Sublist [a, b] (mergeSort xs le).toList := by
|
||||
simpa [mergeSort_eq_toArray_mergeSort_toList, Array.toList_zipIdx] using
|
||||
List.pair_sublist_mergeSort trans total ‹_› ‹_›
|
||||
|
||||
theorem map_mergeSort {r : α → α → Bool} {s : β → β → Bool} {f : α → β}
|
||||
{xs : Array α} (hxs : ∀ a ∈ xs, ∀ b ∈ xs, r a b = s (f a) (f b)) :
|
||||
(xs.mergeSort r).map f = (xs.map f).mergeSort s := by
|
||||
simp only [mergeSort_eq_toArray_mergeSort_toList, List.map_toArray, toList_map, mk.injEq]
|
||||
apply List.map_mergeSort
|
||||
simpa
|
||||
|
||||
end Array
|
||||
@@ -1,18 +1,88 @@
|
||||
/-
|
||||
Benchmark comparing `List.mergeSort` and `Array.mergeSort` performance.
|
||||
|
||||
Usage:
|
||||
./mergeSort <N>
|
||||
|
||||
where N specifies test size: N * 10^5 elements will be sorted.
|
||||
|
||||
Example:
|
||||
./mergeSort 10 # Sort 1,000,000 elements
|
||||
./mergeSort 100 # Sort 10,000,000 elements
|
||||
|
||||
The benchmark runs 4 test cases for each implementation:
|
||||
1. Reversed data (worst case for some algorithms)
|
||||
2. Already sorted data (best case)
|
||||
3. Random data
|
||||
4. Partially sorted data with duplicates
|
||||
|
||||
Results are reported per-pattern and in aggregate.
|
||||
-/
|
||||
|
||||
open List.MergeSort.Internal
|
||||
|
||||
@[noinline]
|
||||
def sortList (xs : List Nat) : IO Nat := return (mergeSortTR₂ xs).length
|
||||
|
||||
@[noinline]
|
||||
def sortArray (xs : Array Nat) : IO Nat := return xs.mergeSort.size
|
||||
|
||||
def benchOne (label : String) (listInput : List Nat) (arrayInput : Array Nat) (n : Nat) :
|
||||
IO (Nat × Nat) := do
|
||||
let start ← IO.monoMsNow
|
||||
let r1 ← sortList listInput
|
||||
let mid ← IO.monoMsNow
|
||||
let r2 ← sortArray arrayInput
|
||||
let done ← IO.monoMsNow
|
||||
if r1 != n || r2 != n then
|
||||
throw <| IO.userError s!"{label}: correctness check failed"
|
||||
let listMs := mid - start
|
||||
let arrayMs := done - mid
|
||||
let ratio := if listMs == 0 then 0.0 else arrayMs.toFloat / listMs.toFloat
|
||||
IO.println s!" {label}: List {listMs}ms, Array {arrayMs}ms, ratio {ratio}"
|
||||
return (listMs, arrayMs)
|
||||
|
||||
def main (args : List String) : IO Unit := do
|
||||
let k := 5
|
||||
let some arg := args[0]? | throw <| IO.userError s!"specify length of test data in multiples of 10^{k}"
|
||||
let some i := arg.toNat? | throw <| IO.userError s!"specify length of test data in multiples of 10^{k}"
|
||||
let some arg := args[0]? | throw <| IO.userError s!"Usage: mergeSort <N>\nSorts N * 10^{k} elements"
|
||||
let some i := arg.toNat? | throw <| IO.userError s!"Invalid argument: expected positive integer"
|
||||
let n := i * (10^k)
|
||||
let i₁ := (List.range' 1 n).reverse
|
||||
let i₂ := List.range n
|
||||
let i₃ ← (List.range n).mapM (fun _ => IO.rand 0 1000)
|
||||
let i₄ := (List.range (i * (10^(k-3)))).flatMap (fun k => (k * 1000 + 1) :: (k * 1000) :: List.range' (k * 1000 + 2) 998)
|
||||
let start ← IO.monoMsNow
|
||||
let o₁ := (mergeSortTR₂ i₁).length == n
|
||||
let o₂ := (mergeSortTR₂ i₂).length == n
|
||||
let o₃ := (mergeSortTR₂ i₃).length == n
|
||||
let o₄ := (mergeSortTR₂ i₄).length == n
|
||||
IO.println (((← IO.monoMsNow) - start)/4)
|
||||
IO.Process.exit (if o₁ && o₂ && o₃ && o₄ then 0 else 1)
|
||||
|
||||
IO.println s!"Benchmarking mergeSort with n={n} ({i} * 10^{k})"
|
||||
IO.println ""
|
||||
|
||||
-- Generate test inputs (Lists)
|
||||
let reversed := (List.range' 1 n).reverse
|
||||
let sorted := List.range n
|
||||
let random ← (List.range n).mapM (fun _ => IO.rand 0 1000)
|
||||
let partiallySorted := (List.range (i * (10^(k-3)))).flatMap (fun k =>
|
||||
(k * 1000 + 1) :: (k * 1000) :: List.range' (k * 1000 + 2) 998)
|
||||
|
||||
-- Per-pattern benchmarks
|
||||
IO.println "Per-pattern results:"
|
||||
let (lt1, at1) ← benchOne "Reversed " reversed reversed.toArray n
|
||||
let (lt2, at2) ← benchOne "Sorted " sorted sorted.toArray n
|
||||
let (lt3, at3) ← benchOne "Random " random random.toArray n
|
||||
let (lt4, at4) ← benchOne "Partially sorted" partiallySorted partiallySorted.toArray n
|
||||
|
||||
-- Aggregate
|
||||
let listTotal := lt1 + lt2 + lt3 + lt4
|
||||
let arrayTotal := at1 + at2 + at3 + at4
|
||||
IO.println ""
|
||||
IO.println s!"Aggregate (4 cases):"
|
||||
IO.println s!" List.mergeSort: {listTotal} ms total, {listTotal/4} ms average"
|
||||
IO.println s!" Array.mergeSort: {arrayTotal} ms total, {arrayTotal/4} ms average"
|
||||
IO.println ""
|
||||
|
||||
IO.println "Comparison:"
|
||||
if arrayTotal < listTotal then
|
||||
let speedup := listTotal.toFloat / arrayTotal.toFloat
|
||||
IO.println s!" Array.mergeSort is {speedup}x faster overall"
|
||||
else if listTotal < arrayTotal then
|
||||
let speedup := arrayTotal.toFloat / listTotal.toFloat
|
||||
IO.println s!" List.mergeSort is {speedup}x faster overall"
|
||||
else
|
||||
IO.println " Both implementations took the same time"
|
||||
|
||||
IO.println ""
|
||||
IO.println "(ratio > 1 means List faster, < 1 means Array faster)"
|
||||
|
||||
@@ -1,15 +1,23 @@
|
||||
# mergeSortBenchmark
|
||||
|
||||
Benchmarking `List.mergeSort`.
|
||||
Benchmarking `List.mergeSort` and `Array.mergeSort`.
|
||||
|
||||
Run `lake exe mergeSort k` to run a benchmark on lists of size `k * 10^5`.
|
||||
This reports the average time (in milliseconds) to sort:
|
||||
* an already sorted list
|
||||
* a reverse sorted list
|
||||
* an almost sorted list
|
||||
* and a random list with duplicates
|
||||
Run `lake exe mergeSort k` to run a benchmark on collections of size `k * 10^5`.
|
||||
This reports the total and average time (in milliseconds) to sort:
|
||||
* an already sorted list/array
|
||||
* a reverse sorted list/array
|
||||
* an almost sorted list/array
|
||||
* and a random list/array with duplicates
|
||||
|
||||
The benchmark also reports the comparative performance between the two implementations.
|
||||
|
||||
## Performance Characteristics
|
||||
|
||||
In many cases, `List.mergeSort` is faster. However, for large, random collections (>= 600k elements), `Array.mergeSort` scales better.
|
||||
|
||||
Run `python3 bench.py` to run this for `k = 1, .., 10`, and calculate a best fit
|
||||
of the model `A * k + B * k * log k` to the observed runtimes.
|
||||
(This isn't really what one should do:
|
||||
fitting a log to data across a single order of magnitude is not helpful.)
|
||||
|
||||
More detailed comparisons can be generated using `python3 bench2.py`.
|
||||
|
||||
@@ -1,38 +1,73 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Quick benchmark script for mergeSort comparison.
|
||||
|
||||
Runs benchmarks across different input sizes, fits n*log(n) curves
|
||||
to aggregate times, and prints results. For detailed per-pattern
|
||||
visualization, use bench2.py instead.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import re
|
||||
import numpy as np
|
||||
from scipy.optimize import curve_fit
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Function to run the command and capture the elapsed time from stdout
|
||||
# Function to run the command and capture the elapsed times from stdout
|
||||
def benchmark(i):
|
||||
result = subprocess.run([f'./.lake/build/bin/mergeSort', str(i)], capture_output=True, text=True)
|
||||
elapsed_time_ms = int(result.stdout.strip()) # Assuming the time is printed as a single integer in ms
|
||||
return elapsed_time_ms / 1e3 # Convert milliseconds to seconds
|
||||
result = subprocess.run(
|
||||
['./.lake/build/bin/mergeSort', str(i)],
|
||||
capture_output=True, text=True, check=True
|
||||
)
|
||||
list_match = re.search(r'List\.mergeSort:\s+(\d+)\s+ms total', result.stdout)
|
||||
array_match = re.search(r'Array\.mergeSort:\s+(\d+)\s+ms total', result.stdout)
|
||||
if not list_match or not array_match:
|
||||
raise ValueError(f"Failed to parse output:\n{result.stdout}")
|
||||
return int(list_match.group(1)), int(array_match.group(1))
|
||||
|
||||
# Benchmark for i = 0.1, 0.2, ..., 1.0 with 5 runs each
|
||||
i_values = []
|
||||
times = []
|
||||
list_times = []
|
||||
array_times = []
|
||||
|
||||
print("Running benchmarks...")
|
||||
for i in range(1, 11):
|
||||
run_times = sorted([benchmark(i) for _ in range(5)])
|
||||
middle_three_avg = np.mean(run_times[1:4]) # Take the average of the middle 3 times
|
||||
times.append(middle_three_avg)
|
||||
i_values.append(i / 1e1)
|
||||
print(f" Size: {i * 100_000} elements (5 runs)...", end=' ', flush=True)
|
||||
|
||||
list_runs = []
|
||||
array_runs = []
|
||||
for _ in range(5):
|
||||
lt, at = benchmark(i)
|
||||
list_runs.append(lt)
|
||||
array_runs.append(at)
|
||||
|
||||
list_avg = np.median(list_runs)
|
||||
array_avg = np.median(array_runs)
|
||||
|
||||
i_values.append(i / 10)
|
||||
list_times.append(list_avg / 1000)
|
||||
array_times.append(array_avg / 1000)
|
||||
|
||||
print(f"List: {list_avg:.0f}ms, Array: {array_avg:.0f}ms")
|
||||
|
||||
# Fit the data to A*i + B*i*log(i)
|
||||
def model(i, A, B):
|
||||
return A * i + B * i * np.log(i)
|
||||
return A * i + B * i * np.log(np.maximum(i, 1e-10))
|
||||
|
||||
popt, _ = curve_fit(model, i_values, times)
|
||||
A, B = popt
|
||||
list_popt, _ = curve_fit(model, i_values, list_times)
|
||||
array_popt, _ = curve_fit(model, i_values, array_times)
|
||||
|
||||
# Print the fit parameters
|
||||
print(f"Best fit parameters: A = {A}, B = {B}")
|
||||
print(f"\nBest fit parameters for A*i + B*i*log(i):")
|
||||
print(f" List.mergeSort: A = {list_popt[0]:.6f}, B = {list_popt[1]:.6f}")
|
||||
print(f" Array.mergeSort: A = {array_popt[0]:.6f}, B = {array_popt[1]:.6f}")
|
||||
|
||||
# Plot the results
|
||||
plt.plot(i_values, times, 'o', label='Benchmark Data (Avg of Middle 3)')
|
||||
plt.plot(i_values, model(np.array(i_values), *popt), '-', label=f'Fit: A*i + B*i*log(i)\nA={A:.3f}, B={B:.3f}')
|
||||
plt.xlabel('i')
|
||||
plt.plot(i_values, list_times, 'o', label='List.mergeSort', color='blue')
|
||||
plt.plot(i_values, array_times, 's', label='Array.mergeSort', color='red')
|
||||
plt.plot(i_values, model(np.array(i_values), *list_popt), '-', color='blue', alpha=0.5,
|
||||
label=f'List fit: A={list_popt[0]:.3f}, B={list_popt[1]:.3f}')
|
||||
plt.plot(i_values, model(np.array(i_values), *array_popt), '-', color='red', alpha=0.5,
|
||||
label=f'Array fit: A={array_popt[0]:.3f}, B={array_popt[1]:.3f}')
|
||||
plt.xlabel('Input size (millions)')
|
||||
plt.ylabel('Time (s)')
|
||||
plt.title('MergeSort Aggregate Performance')
|
||||
plt.legend()
|
||||
plt.grid(True, alpha=0.3)
|
||||
plt.show()
|
||||
|
||||
140
tests/bench/mergeSort/bench2.py
Normal file
140
tests/bench/mergeSort/bench2.py
Normal file
@@ -0,0 +1,140 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Benchmark script for comparing List.mergeSort and Array.mergeSort performance.
|
||||
|
||||
Runs benchmarks across different input sizes (100k to 1M elements),
|
||||
collects per-pattern results, and generates comparison plots.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import re
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
PATTERNS = ["Reversed", "Sorted", "Random", "Partially sorted"]
|
||||
|
||||
def benchmark(i):
|
||||
"""
|
||||
Run the benchmark for size i * 10^5 and extract per-pattern times.
|
||||
|
||||
Returns:
|
||||
dict: { pattern: (list_ms, array_ms) }
|
||||
"""
|
||||
result = subprocess.run(
|
||||
['./.lake/build/bin/mergeSort', str(i)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True
|
||||
)
|
||||
|
||||
results = {}
|
||||
for pattern in PATTERNS:
|
||||
m = re.search(
|
||||
rf'{re.escape(pattern)}\s*:\s*List\s+(\d+)ms,\s*Array\s+(\d+)ms',
|
||||
result.stdout
|
||||
)
|
||||
if not m:
|
||||
raise ValueError(f"Failed to parse '{pattern}' from:\n{result.stdout}")
|
||||
results[pattern] = (int(m.group(1)), int(m.group(2)))
|
||||
|
||||
return results
|
||||
|
||||
# Benchmark for i = 1, 2, ..., 10 (100k to 1M elements) with 3 runs each
|
||||
sizes = list(range(1, 11))
|
||||
num_runs = 3
|
||||
|
||||
# { pattern: { "list": [avg_per_size], "array": [avg_per_size] } }
|
||||
data = {p: {"list": [], "array": []} for p in PATTERNS}
|
||||
|
||||
print("Running benchmarks...")
|
||||
for i in sizes:
|
||||
n = i * 100_000
|
||||
print(f" Size: {n:>10} elements ({num_runs} runs)...", end=' ', flush=True)
|
||||
|
||||
runs = {p: {"list": [], "array": []} for p in PATTERNS}
|
||||
|
||||
for _ in range(num_runs):
|
||||
results = benchmark(i)
|
||||
for p in PATTERNS:
|
||||
lt, at = results[p]
|
||||
runs[p]["list"].append(lt)
|
||||
runs[p]["array"].append(at)
|
||||
|
||||
parts = []
|
||||
for p in PATTERNS:
|
||||
list_avg = np.median(runs[p]["list"])
|
||||
array_avg = np.median(runs[p]["array"])
|
||||
data[p]["list"].append(list_avg)
|
||||
data[p]["array"].append(array_avg)
|
||||
parts.append(f"{p}: L={list_avg:.0f} A={array_avg:.0f}")
|
||||
print(" | ".join(parts))
|
||||
|
||||
sizes_k = [i * 100 for i in sizes] # in thousands
|
||||
|
||||
# --- Plotting ---
|
||||
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
|
||||
fig.suptitle('MergeSort: List vs Array by Data Pattern', fontsize=14, fontweight='bold')
|
||||
|
||||
colors = {"list": "#2196F3", "array": "#F44336"}
|
||||
|
||||
for ax, pattern in zip(axes.flat, PATTERNS):
|
||||
list_ms = np.array(data[pattern]["list"])
|
||||
array_ms = np.array(data[pattern]["array"])
|
||||
|
||||
ax.plot(sizes_k, list_ms, 'o-', color=colors["list"], label='List.mergeSort', markersize=5)
|
||||
ax.plot(sizes_k, array_ms, 's-', color=colors["array"], label='Array.mergeSort', markersize=5)
|
||||
|
||||
ax.set_title(pattern, fontsize=12, fontweight='bold')
|
||||
ax.set_xlabel('Size (thousands)')
|
||||
ax.set_ylabel('Time (ms)')
|
||||
ax.legend(fontsize=9)
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
# Annotate winner at largest size
|
||||
if list_ms[-1] < array_ms[-1]:
|
||||
ratio = array_ms[-1] / list_ms[-1]
|
||||
ax.annotate(f'List {ratio:.1f}x faster', xy=(0.98, 0.95),
|
||||
xycoords='axes fraction', ha='right', va='top',
|
||||
fontsize=9, color=colors["list"], fontweight='bold')
|
||||
else:
|
||||
ratio = list_ms[-1] / array_ms[-1]
|
||||
ax.annotate(f'Array {ratio:.1f}x faster', xy=(0.98, 0.95),
|
||||
xycoords='axes fraction', ha='right', va='top',
|
||||
fontsize=9, color=colors["array"], fontweight='bold')
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
# --- Speedup summary plot ---
|
||||
fig2, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
|
||||
|
||||
# Left: ratio per pattern across sizes
|
||||
for pattern in PATTERNS:
|
||||
list_ms = np.array(data[pattern]["list"])
|
||||
array_ms = np.array(data[pattern]["array"])
|
||||
ratio = array_ms / np.maximum(list_ms, 1)
|
||||
ax1.plot(sizes_k, ratio, 'o-', label=pattern, markersize=5)
|
||||
|
||||
ax1.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5)
|
||||
ax1.set_xlabel('Size (thousands)')
|
||||
ax1.set_ylabel('Array time / List time')
|
||||
ax1.set_title('Ratio by Pattern (< 1 = Array faster)')
|
||||
ax1.legend(fontsize=9)
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
# Right: aggregate
|
||||
list_total = np.zeros(len(sizes))
|
||||
array_total = np.zeros(len(sizes))
|
||||
for p in PATTERNS:
|
||||
list_total += np.array(data[p]["list"])
|
||||
array_total += np.array(data[p]["array"])
|
||||
|
||||
ax2.plot(sizes_k, list_total, 'o-', color=colors["list"], label='List (aggregate)', markersize=5)
|
||||
ax2.plot(sizes_k, array_total, 's-', color=colors["array"], label='Array (aggregate)', markersize=5)
|
||||
ax2.set_xlabel('Size (thousands)')
|
||||
ax2.set_ylabel('Total time (ms, 4 patterns)')
|
||||
ax2.set_title('Aggregate Performance')
|
||||
ax2.legend(fontsize=9)
|
||||
ax2.grid(True, alpha=0.3)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
Reference in New Issue
Block a user