mirror of
https://github.com/leanprover/lean4.git
synced 2026-03-17 18:34:06 +00:00
This PR sets up the new integrated test/bench suite. It then migrates all benchmarks and some related tests to the new suite. There's also some documentation and some linting. For now, a lot of the old tests are left alone so this PR doesn't become even larger than it already is. Eventually, all tests should be migrated to the new suite though so there isn't a confusing mix of two systems.
127 lines
4.3 KiB
Lean4
Executable File
127 lines
4.3 KiB
Lean4
Executable File
/-- A hand-rolled state transformer: given the current state `σ`, produce in the
monad `m` a result of type `α` paired with the next state. -/
def StateT' (m : Type → Type) (σ : Type) (α : Type) : Type :=
  σ → m (α × σ)
|
||
namespace StateT'

-- Shared binders for every definition in this namespace.
variable {m : Type → Type} [Monad m] {σ : Type} {α β : Type}

/-- Yield `a` and pass the incoming state through untouched. -/
@[inline] protected def pure (a : α) : StateT' m σ α :=
  fun st => pure (a, st)

/-- Run `x` on the incoming state, then run `f` on its result with the state `x` produced. -/
@[inline] protected def bind (x : StateT' m σ α) (f : α → StateT' m σ β) : StateT' m σ β :=
  fun st => do
    let (res, st') ← x st
    f res st'

/-- Return the current state as the result, leaving it unchanged. -/
@[inline] def read : StateT' m σ σ :=
  fun st => pure (st, st)

/-- Discard the incoming state and install `s'` instead. -/
@[inline] def write (s' : σ) : StateT' m σ Unit :=
  fun _ => pure ((), s')

/-- Replace the state by `f` applied to it. -/
@[inline] def updt (f : σ → σ) : StateT' m σ Unit :=
  fun st => pure ((), f st)

/-- `Monad` structure built from `StateT'.pure` and `StateT'.bind`. -/
instance : Monad (StateT' m σ) where
  pure := StateT'.pure
  bind := StateT'.bind

end StateT'
|
||
|
||
/-- A hand-rolled exception transformer: a computation in `m` whose result is
either an error of type `ε` or a value of type `α`. -/
def ExceptT' (m : Type → Type) (ε : Type) (α : Type) : Type :=
  m (Except ε α)
|
||
namespace ExceptT'

-- Shared binders for every definition in this namespace.
variable {m : Type → Type} [Monad m] {ε : Type} {α β : Type}

/-- Succeed with `a`. -/
@[inline] protected def pure (a : α) : ExceptT' m ε α :=
  (pure (Except.ok a) : m (Except ε α))

/-- Run `x`; on `Except.ok` continue with `f`, on `Except.error` propagate the
error without running `f`. -/
@[inline] protected def bind (x : ExceptT' m ε α) (f : α → ExceptT' m ε β) : ExceptT' m ε β :=
  -- The annotation makes the `do` block elaborate in the underlying monad `m`.
  let step : m (Except ε β) := do
    let outcome ← x
    match outcome with
    | Except.error err => pure (Except.error err)
    | Except.ok val => f val
  step

/-- Fail with the error `e`. -/
@[inline] def error (e : ε) : ExceptT' m ε α :=
  (pure (Except.error e) : m (Except ε α))

/-- Run the underlying computation `x` and wrap its result in `Except.ok`. -/
@[inline] def lift (x : m α) : ExceptT' m ε α :=
  let lifted : m (Except ε α) := do
    let val ← x
    pure (Except.ok val)
  lifted

/-- `Monad` structure built from `ExceptT'.pure` and `ExceptT'.bind`. -/
instance : Monad (ExceptT' m ε) where
  pure := ExceptT'.pure
  bind := ExceptT'.bind

end ExceptT'
|
||
|
||
/-- Nodes are identified by natural numbers. -/
abbrev Node : Type := Nat
|
||
|
||
/-- Per-node record stored in the union-find table.

Declared with `structure … where`; the older `structure … :=` form is deprecated. -/
structure nodeData where
  /-- The node this entry points to; `findEntryAux` treats an entry whose `find`
  equals its own index as the end of the chain. -/
  find : Node
  /-- Value compared by `union` to decide which entry gets redirected; defaults to `0`. -/
  rank : Nat := 0
|
||
|
||
/-- The union-find table: one `nodeData` entry per node, indexed by node number. -/
abbrev ufData : Type := Array nodeData
|
||
|
||
/-- The monad used by the benchmark: `String` errors layered over `ufData` state over `Id`. -/
abbrev M (α : Type) : Type :=
  ExceptT' (StateT' Id ufData) String α
|
||
/-- Fetch the current union-find table. -/
@[inline] def read : M ufData :=
  ExceptT'.lift StateT'.read
|
||
/-- Replace the union-find table with `s`. -/
@[inline] def write (s : ufData) : M Unit :=
  ExceptT'.lift <| StateT'.write s
|
||
/-- Apply `f` to the union-find table. -/
@[inline] def updt (f : ufData → ufData) : M Unit :=
  ExceptT'.lift <| StateT'.updt f
|
||
/-- Abort the computation with the message `e`. -/
@[inline] def error {α : Type} (e : String) : M α :=
  ExceptT'.error e
|
||
/-- Execute `x` starting from the table `s` (empty by default), returning the
outcome together with the final table. -/
def run {α : Type} (x : M α) (s : ufData := ∅) : Except String α × ufData :=
  -- `M α` unfolds to a function from the state, so running is plain application.
  x s
|
||
|
||
/-- Number of entries currently in the union-find table. -/
def capacity : M Nat := do
  let tbl ← read
  pure tbl.size
|
||
|
||
/-- Follow `find` links starting at node `n`, using the first argument as fuel.

* Fails with `"out of fuel"` when the fuel reaches `0`.
* Fails with `"invalid Node"` when `n` is not an index of the table.
* Otherwise returns the entry whose `find` equals its own index. Each node
  visited on the way has its slot overwritten with that entry. -/
def findEntryAux : Nat → Node → M nodeData
  | 0, _ => error "out of fuel"
  | fuel+1, n => do
    let tbl ← read
    if h : n < tbl.size then
      -- `h` supplies the bounds proof for the indexing below.
      let entry := tbl[n]
      if entry.find = n then
        pure entry
      else
        let root ← findEntryAux fuel entry.find
        updt fun cur => cur.set! n root
        pure root
    else
      error "invalid Node"
|
||
|
||
/-- Look up the entry reached from `n`, using the table size as fuel for `findEntryAux`. -/
def findEntry (n : Node) : M nodeData := do
  let fuel ← capacity
  findEntryAux fuel n
|
||
|
||
/-- The `find` field of the entry returned by `findEntry n`. -/
def find (n : Node) : M Node := do
  return (← findEntry n).find
|
||
|
||
/-- Append a new node that points to itself, with rank `1`, and return its index. -/
def mk : M Node := do
  -- The new node's index is the table size before the push.
  let fresh ← capacity
  updt fun tbl => tbl.push { find := fresh, rank := 1 }
  pure fresh
|
||
|
||
/-- Merge the classes of `n₁` and `n₂`.

Both entries are looked up with `findEntry`. If they already share the same
`find`, nothing happens. Otherwise the entry with the smaller rank is redirected
to the other. On equal ranks `r₁` is redirected to `r₂` and the rank stored for
`r₂` is incremented. -/
def union (n₁ n₂ : Node) : M Unit := do
  let r₁ ← findEntry n₁
  let r₂ ← findEntry n₂
  unless r₁.find = r₂.find do
    updt fun tbl =>
      if r₁.rank < r₂.rank then
        tbl.set! r₁.find { find := r₂.find }
      else if r₁.rank = r₂.rank then
        (tbl.set! r₁.find { find := r₂.find }).set! r₂.find { r₂ with rank := r₂.rank + 1 }
      else
        tbl.set! r₂.find { find := r₁.find }
|
||
|
||
|
||
/-- Create the given number of nodes by calling `mk` repeatedly, discarding the indices. -/
def mkNodes : Nat → M Unit
  | 0 => pure ()
  | k+1 => do
    let _ ← mk
    mkNodes k
|
||
|
||
/-- Fail with `"nodes are not equal"` unless `find` returns the same node for `n₁` and `n₂`. -/
def checkEq (n₁ n₂ : Node) : M Unit := do
  let rep₁ ← find n₁
  let rep₂ ← find n₂
  if rep₁ = rep₂ then
    pure ()
  else
    error "nodes are not equal"
|
||
|
||
/-- Fuel-bounded loop: starting at node `n`, union each node with the node `d`
positions after it, stopping when `n + d` reaches the table size or the fuel
(first argument) runs out. -/
def mergePackAux : Nat → Nat → Nat → M Unit
  | 0, _, _ => pure ()
  | fuel+1, n, d => do
    let size ← capacity
    if n + d < size then
      union n (n + d)
      mergePackAux fuel (n + 1) d
|
||
|
||
/-- Run `mergePackAux` from node `0` with stride `d`, using the table size as fuel. -/
def mergePack (d : Nat) : M Unit := do
  let fuel ← capacity
  mergePackAux fuel 0 d
|
||
|
||
/-- Fuel-bounded scan from node `n` upward. The third argument is an accumulator
that is incremented for every node whose `find` result differs from the node
itself. Returns the accumulator when the fuel runs out or `n` reaches the table size. -/
def numEqsAux : Nat → Node → Nat → M Nat
  | 0, _, acc => pure acc
  | fuel+1, n, acc => do
    let size ← capacity
    if n < size then
      let rep ← find n
      numEqsAux fuel (n + 1) (if n = rep then acc else acc + 1)
    else
      pure acc
|
||
|
||
/-- Count, over the whole table, the nodes whose `find` result differs from the node itself. -/
def numEqs : M Nat := do
  let fuel ← capacity
  numEqsAux fuel 0 0
|
||
|
||
/-- Benchmark body: create `n` nodes, merge them with strides 50000, 10000, 5000
and 1000 in that order, then return `numEqs`.
Fails with `"input must be greater than 1"` when `n < 2`. -/
def test (n : Nat) : M Nat := do
  if n < 2 then
    error "input must be greater than 1"
  else
    mkNodes n
    mergePack 50000
    mergePack 10000
    mergePack 5000
    mergePack 1000
    numEqs
|
||
|
||
/-- Entry point. The first command-line argument is the number of nodes.

Prints `ok <result>` and returns `0` on success. Prints `Error : <message>` and
returns `1` when the benchmark fails, when no argument is given, or when the
argument is not a natural number. The original `head!` / `toNat!` calls panicked
in the last two cases. -/
def main (xs : List String) : IO UInt32 :=
  match xs with
  | [] => IO.println "Error : missing input argument" *> pure 1
  | arg :: _ =>
    match arg.toNat? with
    | none => IO.println ("Error : input is not a natural number: " ++ arg) *> pure 1
    | some n =>
      -- The final table is not needed, only the outcome.
      match run (test n) with
      | (Except.ok v, _) => IO.println ("ok " ++ toString v) *> pure 0
      | (Except.error e, _) => IO.println ("Error : " ++ e) *> pure 1
|