Files
lean4/tests/compile_bench/unionfind.lean
Garmelon 08eb78a5b2 chore: switch to new test/bench suite (#12590)
This PR sets up the new integrated test/bench suite. It then migrates
all benchmarks and some related tests to the new suite. There's also
some documentation and some linting.

For now, a lot of the old tests are left alone so this PR doesn't become
even larger than it already is. Eventually, all tests should be migrated
to the new suite though so there isn't a confusing mix of two systems.
2026-02-25 13:51:53 +00:00

127 lines
4.3 KiB
Lean4
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/-- Minimal state transformer used by this benchmark: a computation takes the
current state `σ` and produces, inside `m`, a result paired with the next state. -/
def StateT' (m : Type → Type) (σ : Type) (α : Type) := σ → m (α × σ)

namespace StateT'
variable {m : Type → Type} [Monad m] {σ : Type} {α β : Type}

/-- Yield `a` and pass the incoming state through unchanged. -/
@[inline] protected def pure (a : α) : StateT' m σ α :=
  fun s => pure (a, s)

/-- Run `x` on the incoming state, then run `f` on its result using the state `x` produced. -/
@[inline] protected def bind (x : StateT' m σ α) (f : α → StateT' m σ β) : StateT' m σ β :=
  fun s => do
    let (a, s') ← x s
    f a s'

/-- Return the current state as the result, leaving it unchanged. -/
@[inline] def read : StateT' m σ σ :=
  fun s => pure (s, s)

/-- Replace the state with `s'`; the previous state is ignored. -/
@[inline] def write (s' : σ) : StateT' m σ Unit :=
  fun _ => pure ((), s')

/-- Replace the state with `f` applied to the current state. -/
@[inline] def updt (f : σ → σ) : StateT' m σ Unit :=
  fun s => pure ((), f s)

-- Only `pure` and `bind` are given explicitly; every other `Monad` field
-- falls back to its default implementation.
instance : Monad (StateT' m σ) :=
  { pure := @StateT'.pure _ _ _, bind := @StateT'.bind _ _ _ }

end StateT'
/-- Minimal exception transformer used by this benchmark: an `m`-computation
whose result is either an error `ε` or a value `α`. -/
def ExceptT' (m : Type → Type) (ε : Type) (α : Type) := m (Except ε α)

namespace ExceptT'
variable {m : Type → Type} [Monad m] {ε : Type} {α β : Type}

/-- Succeed with `a`. -/
@[inline] protected def pure (a : α) : ExceptT' m ε α :=
  (pure (Except.ok a) : m (Except ε α))

/-- Run `x`; on success continue with `f`, on error propagate the error without running `f`. -/
@[inline] protected def bind (x : ExceptT' m ε α) (f : α → ExceptT' m ε β) : ExceptT' m ε β :=
  -- The ascription makes the `do` block run in the underlying monad `m`.
  (do { let v ← x;
        match v with
        | Except.error e => pure (Except.error e)
        | Except.ok a => f a } : m (Except ε β))

/-- Fail with error `e`. -/
@[inline] def error (e : ε) : ExceptT' m ε α :=
  (pure (Except.error e) : m (Except ε α))

/-- Run the underlying computation `x` and wrap its result as a success. -/
@[inline] def lift (x : m α) : ExceptT' m ε α :=
  (do { let a ← x; pure (Except.ok a) } : m (Except ε α))

-- Only `pure` and `bind` are given explicitly; every other `Monad` field
-- falls back to its default implementation.
instance : Monad (ExceptT' m ε) :=
  { pure := @ExceptT'.pure _ _ _, bind := @ExceptT'.bind _ _ _ }

end ExceptT'
/-- A node is referred to by a natural number; the rest of the file uses it as
an index into the `ufData` array. -/
abbrev Node := Nat

/-- Per-node record. `find` is the node this entry points to; `rank` defaults
to `0` when it is not given explicitly. -/
structure nodeData where
  find : Node
  rank : Nat := 0

/-- The union-find state: an array of `nodeData` records. -/
abbrev ufData := Array nodeData
/-- The benchmark monad: `ufData` state threaded through `Id`, with `String` errors on top. -/
abbrev M (α : Type) := ExceptT' (StateT' Id ufData) String α

/-- Return the current node array. -/
@[inline] def read : M ufData := ExceptT'.lift StateT'.read

/-- Replace the node array with `s`. -/
@[inline] def write (s : ufData) : M Unit := ExceptT'.lift (StateT'.write s)

/-- Replace the node array with `f` applied to it. -/
@[inline] def updt (f : ufData → ufData) : M Unit := ExceptT'.lift (StateT'.updt f)

/-- Abort the computation with the error message `e`. -/
@[inline] def error {α : Type} (e : String) : M α := ExceptT'.error e
/-- Execute `x` starting from state `s` (an empty node array unless one is
supplied) and return the outcome together with the final state. -/
def run {α : Type} (x : M α) (s : ufData := #[]) : Except String α × ufData :=
  x s
/-- Number of entries currently stored in the node array. -/
def capacity : M Nat := do
  let d ← read
  pure d.size
/-- Follow `find` links starting at node `n` until reaching an entry that points
to itself, and return that entry.

The first argument is fuel: each link followed consumes one unit, and running
out yields the error `"out of fuel"`. A node index outside the array yields
`"invalid Node"`. On the way back, the entry stored at each visited node is
overwritten with the entry that was found. -/
def findEntryAux : Nat → Node → M nodeData
  | 0, _ => error "out of fuel"
  | i+1, n => do
    let s ← read
    -- `h` is the bounds proof that justifies the `s[n]` access below.
    if h : n < s.size then
      let e := s[n]
      if e.find = n then
        pure e
      else
        let e₁ ← findEntryAux i e.find
        -- Overwrite node `n`'s entry with the entry the recursive call returned.
        updt (fun s => s.set! n e₁)
        pure e₁
    else
      error "invalid Node"
/-- Look up the self-pointing entry reachable from `n`, using the current array
size as the fuel for `findEntryAux`. -/
def findEntry (n : Node) : M nodeData := do
  let c ← capacity
  findEntryAux c n
/-- Return the `find` field of the entry that `findEntry` yields for `n`. -/
def find (n : Node) : M Node := do
  let e ← findEntry n
  pure e.find
/-- Append a fresh node that points to itself with rank `1`, and return its index. -/
def mk : M Node := do
  -- The new node's index is the array size before the push.
  let n ← capacity
  updt fun s => s.push { find := n, rank := 1 }
  pure n
/-- Merge the sets containing `n₁` and `n₂`.

Both entries are looked up with `findEntry`. If they already agree on `find`
nothing changes. Otherwise the entry with the smaller rank is re-pointed at the
other one; on a tie, `r₁`'s node is re-pointed at `r₂`'s node and `r₂`'s rank is
incremented. Re-pointed entries are built as `{ find := … }`, so their rank
takes the field's default value. -/
def union (n₁ n₂ : Node) : M Unit := do
  let r₁ ← findEntry n₁
  let r₂ ← findEntry n₂
  if r₁.find = r₂.find then
    pure ()
  else
    updt fun s =>
      if r₁.rank < r₂.rank then
        s.set! r₁.find { find := r₂.find }
      else if r₁.rank = r₂.rank then
        let s₁ := s.set! r₁.find { find := r₂.find }
        s₁.set! r₂.find { r₂ with rank := r₂.rank + 1 }
      else
        s.set! r₂.find { find := r₁.find }
/-- Call `mk` the given number of times, discarding the returned indices. -/
def mkNodes : Nat → M Unit
  | 0 => pure ()
  | n+1 => do
    let _ ← mk
    mkNodes n
/-- Fail with `"nodes are not equal"` unless `find` returns the same node for `n₁` and `n₂`. -/
def checkEq (n₁ n₂ : Node) : M Unit := do
  let r₁ ← find n₁
  let r₂ ← find n₂
  unless (r₁ = r₂) do
    error "nodes are not equal"
/-- Starting at node `n`, union each node with the node `d` positions after it,
advancing `n` by one per step.

The first argument is fuel. The loop stops when the fuel is exhausted or as
soon as `n + d` is no longer a valid index. -/
def mergePackAux : Nat → Nat → Nat → M Unit
  | 0, _, _ => pure ()
  | i+1, n, d => do
    let c ← capacity
    if n + d < c then
      union n (n + d)
      mergePackAux i (n + 1) d
    else
      pure ()
/-- Run `mergePackAux` from node `0` with stride `d`, using the current array size as fuel. -/
def mergePack (d : Nat) : M Unit := do
  let c ← capacity
  mergePackAux c 0 d
/-- Walk the nodes from `n` upward and count those for which `find` returns a
node other than themselves.

The first argument is fuel and `r` is the running count. The walk stops when
the fuel is exhausted or `n` reaches the array size, returning the count. -/
def numEqsAux : Nat → Node → Nat → M Nat
  | 0, _, r => pure r
  | i+1, n, r => do
    let c ← capacity
    if n < c then
      let n₁ ← find n
      numEqsAux i (n + 1) (if n = n₁ then r else r + 1)
    else
      pure r
/-- Run `numEqsAux` over the whole array: start at node `0` with a zero count
and the current array size as fuel. -/
def numEqs : M Nat := do
  let c ← capacity
  numEqsAux c 0 0
/-- Benchmark body: create `n` nodes, merge them with strides 50000, 10000,
5000 and 1000 in that order, then return the result of `numEqs`.
Inputs below `2` are rejected with an error. -/
def test (n : Nat) : M Nat := do
  if n < 2 then
    error "input must be greater than 1"
  else
    mkNodes n
    mergePack 50000
    mergePack 10000
    mergePack 5000
    mergePack 1000
    numEqs
/-- Entry point. Reads the node count from the first command-line argument,
runs `test`, and prints `ok <result>` (exit code 0) or `Error : <message>`
(exit code 1).

A missing or non-numeric first argument is reported as an error with exit
code 1 instead of panicking. -/
def main (xs : List String) : IO UInt32 :=
  match xs.head?.bind String.toNat? with
  | none =>
    IO.println "Error : expected a natural number as the first argument" *> pure 1
  | some n =>
    -- The final state is not needed once the result has been computed.
    match run (test n) with
    | (Except.ok v, _) => IO.println ("ok " ++ toString v) *> pure 0
    | (Except.error e, _) => IO.println ("Error : " ++ e) *> pure 1