Compare commits

...

1 Commits

Author SHA1 Message Date
Kim Morrison
308b246462 chore: move Aray.qsort to Basic file 2025-04-30 15:00:46 +02:00
3 changed files with 85 additions and 65 deletions

View File

@@ -1,70 +1,9 @@
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Copyright (c) 2024 Lean FRO. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
Authors: Kim Morrison
-/
module
prelude
import Init.Data.Vector.Basic
import Init.Data.Ord
set_option linter.listVariables true -- Enforce naming conventions for `List`/`Array`/`Vector` variables.
-- We do not enable `linter.indexVariables` because it is helpful to name index variables `lo`, `mid`, `hi`, etc.
namespace Array
private def qpartition {n} (as : Vector α n) (lt : α α Bool) (lo hi : Nat)
(hlo : lo < n := by omega) (hhi : hi < n := by omega) : {n : Nat // lo n} × Vector α n :=
let mid := (lo + hi) / 2
let as := if lt as[mid] as[lo] then as.swap lo mid else as
let as := if lt as[hi] as[lo] then as.swap lo hi else as
let as := if lt as[mid] as[hi] then as.swap mid hi else as
let pivot := as[hi]
let rec loop (as : Vector α n) (i j : Nat)
(ilo : lo i := by omega) (jh : j < n := by omega) (w : i j := by omega) :=
if h : j < hi then
if lt as[j] pivot then
loop (as.swap i j) (i+1) (j+1)
else
loop as i (j+1)
else
(i, ilo, as.swap i hi)
loop as lo lo
/--
Sorts an array using the Quicksort algorithm.
The optional parameter `lt` specifies an ordering predicate. It defaults to `LT.lt`, which must be
decidable to be used for sorting. Use `Array.qsortOrd` to sort the array according to the `Ord α`
instance.
The optional parameters `low` and `high` delimit the region of the array that is sorted. Both are
inclusive, and default to sorting the entire array.
-/
@[inline] def qsort (as : Array α) (lt : α α Bool := by exact (· < ·))
(low := 0) (high := as.size - 1) : Array α :=
let rec @[specialize] sort {n} (as : Vector α n) (lo hi : Nat)
(hlo : lo < n := by omega) (hhi : hi < n := by omega) :=
if h₁ : lo < hi then
let mid, hmid, as := qpartition as lt lo hi
if h₂ : mid hi then
as
else
sort (sort as lo mid) (mid+1) hi
else as
if h : as.size = 0 then
as
else
let low := min low (as.size - 1)
let high := min high (as.size - 1)
sort as, rfl low high |>.toArray
set_option linter.unusedVariables.funArgs false in
/--
Sorts an array using the Quicksort algorithm, using `Ord.compare` to compare elements.
-/
def qsortOrd [ord : Ord α] (xs : Array α) : Array α :=
xs.qsort fun x y => compare x y |>.isLT
end Array
import Init.Data.Array.QSort.Basic

View File

@@ -0,0 +1,81 @@
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
import Init.Data.Vector.Basic
import Init.Data.Ord
set_option linter.listVariables true -- Enforce naming conventions for `List`/`Array`/`Vector` variables.
-- We do not enable `linter.indexVariables` because it is helpful to name index variables `lo`, `mid`, `hi`, etc.
namespace Array
/--
Internal implementation of `Array.qsort`.
`qpartition as lt lo hi hlo hhi` returns a pair `(⟨m, h₁, h₂⟩, as')` where
`as'` is a permutation of `as` and `m` is a number such that:
- `lo ≤ m`
- `m < n`
- `∀ i, lo ≤ i → i < m → lt as[i] as[m]`
- `∀ j, m < j → j < hi → !lt as[j] as[m]`
It does so by first swapping the elements at indices `lo`, `mid := (lo + hi) / 2`, and `hi`
if necessary so that the middle (pivot) element is at index `hi`.
We then iterate from `j = lo` to `j = hi`, with a pointer `i` starting at `lo`, and
swapping each element which is less than the pivot to position `i`, and then incrementing `i`.
-/
def qpartition {n} (as : Vector α n) (lt : α α Bool) (lo hi : Nat)
(hlo : lo < n := by omega) (hhi : hi < n := by omega) : {m : Nat // lo m m < n} × Vector α n :=
let mid := (lo + hi) / 2
let as := if lt as[mid] as[lo] then as.swap lo mid else as
let as := if lt as[hi] as[lo] then as.swap lo hi else as
let as := if lt as[mid] as[hi] then as.swap mid hi else as
let pivot := as[hi]
let rec loop (as : Vector α n) (i j : Nat)
(ilo : lo i := by omega) (jh : j < n := by omega) (w : i j := by omega) :=
if h : j < hi then
if lt as[j] pivot then
loop (as.swap i j) (i+1) (j+1)
else
loop as i (j+1)
else
(i, ilo, by omega, as.swap i hi)
loop as lo lo
/--
In-place quicksort.
`qsort as lt low high` sorts the subarray `as[low:high+1]` in-place using `lt` to compare elements.
-/
@[inline] def qsort (as : Array α) (lt : α α Bool := by exact (· < ·))
(low := 0) (high := as.size - 1) : Array α :=
let rec @[specialize] sort {n} (as : Vector α n) (lo hi : Nat)
(hlo : lo < n := by omega) (hhi : hi < n := by omega) :=
if h₁ : lo < hi then
let mid, hmid, as := qpartition as lt lo hi
if h₂ : mid hi then
as
else
sort (sort as lo mid) (mid+1) hi
else as
if h : as.size = 0 then
as
else
let low := min low (as.size - 1)
let high := min high (as.size - 1)
sort as.toVector low high |>.toArray
set_option linter.unusedVariables.funArgs false in
/--
Sort an array using `compare` to compare elements.
-/
def qsortOrd [ord : Ord α] (xs : Array α) : Array α :=
xs.qsort fun x y => compare x y |>.isLT
end Array

View File

@@ -3040,7 +3040,7 @@ theorem getElem_swap {xs : Vector α n} {i j : Nat} (hi hj) {k : Nat} (hk : k <
(xs.swap i j hi hj)[i]'(by simpa using hi) = xs[j] := by
simp [getElem_swap]
@[simp] theorem getElem_swap_of_ne {xs : Vector α n} {i j : Nat} (hi hj) (hk : k < n)
@[simp] theorem getElem_swap_of_ne {xs : Vector α n} {i j : Nat} {hi hj} {hk : k < n}
(hi' : k i) (hj' : k j) : (xs.swap i j hi hj)[k] = xs[k] := by
simp_all [getElem_swap]