Compare commits

...

2 Commits

Author SHA1 Message Date
Kim Morrison
6af7849911 . 2025-07-25 17:54:50 +10:00
Kim Morrison
31aa2c00a9 wip 2025-07-25 17:54:01 +10:00

View File

@@ -17,6 +17,8 @@ set_option linter.listVariables true -- Enforce naming conventions for `List`/`A
namespace Array
-- cf https://en.wikipedia.org/wiki/Quicksort#Repeated_elements
/--
Internal implementation of `Array.qsort`.
@@ -32,46 +34,39 @@ if necessary so that the middle (pivot) element is at index `hi`.
We then iterate from `k = lo` to `k = hi`, with a pointer `i` starting at `lo`, and
swapping each element which is less than the pivot to position `i`, and then incrementing `i`.
-/
def qpartition {n} (as : Vector α n) (lt : α α Bool) (lo hi : Nat) (w : lo hi := by omega)
(hlo : lo < n := by omega) (hhi : hi < n := by omega) : {m : Nat // lo m m hi} × Vector α n :=
let mid := (lo + hi) / 2
let as := if lt as[mid] as[lo] then as.swap lo mid else as
let as := if lt as[hi] as[lo] then as.swap lo hi else as
let as := if lt as[mid] as[hi] then as.swap mid hi else as
let pivot := as[hi]
-- During this loop, elements below in `[lo, i)` are less than `pivot`,
-- elements in `[i, k)` are greater than or equal to `pivot`,
-- elements in `[k, hi)` are unexamined,
-- while `as[hi]` is (by definition) the pivot.
let rec loop (as : Vector α n) (i k : Nat)
(ilo : lo i := by omega) (ik : i k := by omega) (w : k hi := by omega) :=
if h : k < hi then
if lt as[k] pivot then
loop (as.swap i k) (i+1) (k+1)
else
loop as i (k+1)
def qpartition {n} (as : Vector α n) (cmp : α α Ordering) (lo hi : Nat) (w : lo hi := by omega)
(hlo : lo < n := by omega) (hhi : hi < n := by omega) : { p : Nat × Nat // lo p.1 p.1 p.2 p.2 hi } × Vector α n :=
let pivot := as[(lo + hi) / 2]
-- During this loop, elements in `[lo, i)` are less than `pivot`,
-- elements in `[i, j)` are equal to `pivot`,
-- elements in `[j, k]` are unexamined,
-- elements in `(k, hi]` are greater than `pivot`,
let rec loop (as : Vector α n) (i j k : Nat)
(ilo : lo i := by omega) (ij : i j := by omega) (jk : j k := by omega) (w : k hi := by omega) :=
if h : j < k then
match cmp as[j] pivot with
| .lt => loop (as.swap i j) (i+1) (j+1) k
| .gt => loop (as.swap j k) i j (k-1)
| .eq => loop as i (j+1) k
else
(i, ilo, by omega, as.swap i hi)
loop as lo lo
match cmp as[j] pivot with
| .lt => i+1, k+1, by sorry, as.swap i j
| .gt => i, k, by omega, as.swap j k
| .eq => i, k+1, by sorry, as
loop as lo lo hi
/--
In-place quicksort.
`qsort as lt lo hi` sorts the subarray `as[lo...=hi]` in-place using `lt` to compare elements.
`qsort as lt lo hi` sorts the subarray `as[lo...=hi]` in-place using `cmp` to compare elements.
-/
@[inline] def qsort (as : Array α) (lt : α α Bool := by exact (· < ·))
@[inline] def qsort (as : Array α) (cmp : α α Ordering := by exact (compareOfLessAndEq · ·))
(lo := 0) (hi := as.size - 1) : Array α :=
let rec @[specialize] sort {n} (as : Vector α n) (lo hi : Nat) (w : lo hi := by omega)
(hlo : lo < n := by omega) (hhi : hi < n := by omega) :=
if h₁ : lo < hi then
let mid, hmid, as := qpartition as lt lo hi
if h₂ : mid hi then
-- This only occurs when `hi ≤ lo`,
-- and thus `as[lo...(hi+1)]` is trivially already sorted.
as
else
-- Otherwise, we recursively sort the two subarrays.
sort (sort as lo mid) (mid+1) hi
let i, k, h₁, h₂, h₃, as := qpartition as cmp lo hi
sort (sort as lo (i-1)) k hi
else as
if h : as.size = 0 then
as
@@ -85,6 +80,6 @@ set_option linter.unusedVariables.funArgs false in
Sort an array using `compare` to compare elements.
-/
def qsortOrd [ord : Ord α] (xs : Array α) : Array α :=
xs.qsort fun x y => compare x y |>.isLT
xs.qsort compare
end Array