mirror of
https://github.com/leanprover/lean4.git
synced 2026-03-17 18:34:06 +00:00
This PR sets up the new integrated test/bench suite. It then migrates all benchmarks and some related tests to the new suite. There's also some documentation and some linting. For now, a lot of the old tests are left alone so this PR doesn't become even larger than it already is. Eventually, all tests should be migrated to the new suite though so there isn't a confusing mix of two systems.
369 lines
12 KiB
Lean4
369 lines
12 KiB
Lean4
/-
|
||
Linear Diophantine equation solver
|
||
|
||
Author: Marc Huisinga
|
||
-/
|
||
import Lean.Data.AssocList
|
||
import Std.Data.HashMap
|
||
import Std.Data.Iterators.Producers.Range
|
||
import Std.Data.Iterators.Combinators.StepSize
|
||
|
||
open Lean
|
||
|
||
namespace Int
|
||
|
||
def roundedDiv (a b : Int) : Int := Id.run <| do
|
||
if b = 0 then
|
||
return 0
|
||
let mut div := a / b
|
||
let rest := a % b
|
||
-- This determines how we should adjust the divisor.
|
||
-- The extra logic is to preserve tie-breaking behavior from
|
||
-- a time when div used T-rounding
|
||
if a ≥ 0 then
|
||
if 2*rest ≥ b.natAbs then
|
||
div := div + (if b ≥ 0 then 1 else -1)
|
||
else
|
||
if 2*rest > b.natAbs then
|
||
div := div + (if b >= 0 then 1 else -1)
|
||
return div
|
||
|
||
def mod' (a b : Int) : Int :=
|
||
a - b*(a.roundedDiv b)
|
||
|
||
end Int
|
||
|
||
namespace Lean.AssocList
|
||
|
||
def map (f : α → β → δ) : AssocList α β → AssocList α δ
|
||
| AssocList.nil => AssocList.nil
|
||
| AssocList.cons k v t => AssocList.cons k (f k v) (map f t)
|
||
|
||
def filter (p : α → β → Bool) : AssocList α β → AssocList α β
|
||
| AssocList.nil => AssocList.nil
|
||
| AssocList.cons k v t =>
|
||
if p k v then
|
||
AssocList.cons k v (filter p t)
|
||
else
|
||
filter p t
|
||
|
||
end Lean.AssocList
|
||
|
||
namespace Std.HashMap
|
||
|
||
variable [BEq α] [Hashable α]
|
||
|
||
def modify! [Inhabited β] (xs : HashMap α β) (k : α) (f : β → β) : HashMap α β :=
|
||
let v := xs[k]!
|
||
xs.erase k |>.insert k (f v)
|
||
|
||
def mapValsM [Monad m] (f : β → m γ) (xs : HashMap α β) : m (HashMap α γ) :=
|
||
HashMap.emptyWithCapacity (capacity := xs.size) |> xs.foldM fun acc k v => return acc.insert k (←f v)
|
||
|
||
def mapVals (f : β → γ) (xs : HashMap α β) : HashMap α γ :=
|
||
HashMap.emptyWithCapacity (capacity := xs.size) |> xs.fold fun acc k v => acc.insert k (f v)
|
||
|
||
def fastMapVals (f : α → β → β) (xs : HashMap α β) : HashMap α β :=
|
||
xs.map f
|
||
|
||
def getAny? (x : HashMap α β) : Option (α × β) := Id.run <| do
|
||
for (k, v) in x do
|
||
return some (k, v)
|
||
return none
|
||
|
||
end Std.HashMap
|
||
|
||
open Std (HashMap)
|
||
|
||
structure Equation where
|
||
id : Nat
|
||
coeffs : HashMap Nat Int
|
||
const : Int
|
||
deriving Inhabited
|
||
|
||
def gcd (coeffs : HashMap Nat Int) : Nat :=
|
||
let coeffs := coeffs.mapVals (·.natAbs)
|
||
let coeffsContent := coeffs.toArray
|
||
match coeffsContent with
|
||
| #[] => panic! "Cannot calculate GCD of empty list of coefficients"
|
||
| #[(_, x)] => x
|
||
| coeffsContent =>
|
||
coeffsContent[0]!.2.gcd coeffsContent[1]!.2
|
||
|> coeffs.fold fun acc _ v => acc.gcd v
|
||
|
||
namespace Equation
|
||
|
||
def preprocess? (e : Equation) : Option Equation := Id.run <| do
|
||
let gcd : Int := gcd e.coeffs
|
||
if e.const % gcd ≠ 0 then
|
||
return none
|
||
return some { e with
|
||
coeffs := e.coeffs.fastMapVals fun _ coeff => coeff / gcd
|
||
const := e.const / gcd }
|
||
|
||
def subst (fromEq toEq : Equation) (varIdx : Nat) : Equation := Id.run <| do
|
||
-- varIdx ≡ k
|
||
-- fromEq ≡ sₖxₖ + ∑ i ∈ V_fromEq\{k}. aᵢxᵢ = A
|
||
-- ⇔ xₖ = sₖA - ∑ i ∈ V_fromEq\{k}. sₖaᵢxᵢ
|
||
-- toEq ≡ bₖxₖ + ∑ i ∈ V_toEq\{k}. bᵢxᵢ = B
|
||
-- ⇝ B = bₖ(sₖA - ∑ i ∈ V_fromEq\{k}. sₖaᵢxᵢ) + ∑ i ∈ V_toEq\{k}. bᵢxᵢ
|
||
-- = bₖsₖA - ∑ i ∈ V_fromEq\{k}. sₖbₖaᵢxᵢ + ∑ i ∈ V_toEq\{k}. bᵢxᵢ
|
||
-- = bₖsₖA + ∑ i ∈ V_fromEq\V_toEq. -sₖbₖaᵢxᵢ
|
||
-- + ∑ i ∈ V_toEq\{k} ∩ V_fromEq. (bᵢ - sₖbₖaᵢ)xᵢ
|
||
-- + ∑ i ∈ V_toEq\V_fromEq. bᵢxᵢ
|
||
-- ⇔ B - bₖsₖA = + ∑ i ∈ X. -sₖbₖaᵢxᵢ
|
||
-- + ∑ i ∈ Y. (bᵢ - sₖbₖaᵢ)xᵢ
|
||
-- + ∑ i ∈ Z. bᵢxᵢ
|
||
-- with X, Y, Z defined as above, X ∪ Y ∪ Z = (V_fromEq ∪ V_toEq)\{k}
|
||
-- and X, Y, Z pairwise disjoint
|
||
let A := fromEq.const
|
||
let B := toEq.const
|
||
let V_fromEq := fromEq.coeffs
|
||
let V_toEq := toEq.coeffs
|
||
let k := varIdx
|
||
let sₖ := V_fromEq[k]!
|
||
let bₖ := V_toEq[k]!
|
||
let mut V_toEq := V_toEq.fastMapVals fun i bᵢ =>
|
||
match V_fromEq[i]? with
|
||
| none =>
|
||
bᵢ
|
||
| some aᵢ =>
|
||
bᵢ - sₖ*bₖ*aᵢ
|
||
for (i, aᵢ) in V_fromEq do
|
||
if ¬V_toEq.contains i then
|
||
V_toEq := V_toEq.insert i (-sₖ*bₖ*aᵢ)
|
||
V_toEq := V_toEq.filter fun i bᵢ => i ≠ k ∧ bᵢ ≠ 0
|
||
let B' := B - bₖ*sₖ*A
|
||
{ toEq with coeffs := V_toEq, const := B' }
|
||
|
||
def normalize (e : Equation) : Equation := Id.run <| do
|
||
if e.coeffs.size ≠ 1 then
|
||
return e
|
||
let (i, c) := e.coeffs.getAny?.get!
|
||
return { e with
|
||
coeffs := e.coeffs.insert i 1
|
||
const := Int.ediv e.const c }
|
||
|
||
def invert (e : Equation) : Equation :=
|
||
{ e with
|
||
coeffs := e.coeffs.fastMapVals fun _ coeff => (-1)*coeff
|
||
const := (-1)*e.const }
|
||
|
||
def reorganizeFor (e : Equation) (varIdx : Nat) : Equation := Id.run <| do
|
||
let singletonCoeff := e.coeffs[varIdx]!
|
||
let mut e := { e with coeffs := e.coeffs.fastMapVals fun _ coeff => (-1)*coeff }
|
||
if singletonCoeff = -1 then
|
||
e := e.invert
|
||
{ e with coeffs := e.coeffs.erase varIdx }
|
||
|
||
def findSingleton? (e : Equation) : Option (Nat × Int) := Id.run <| do
|
||
for (i, coeff) in e.coeffs do
|
||
if coeff = 1 ∨ coeff = -1 then
|
||
return some (i, coeff)
|
||
return none
|
||
|
||
def findAbsMinimumCoeff? (e : Equation) : Option (Nat × Int) := Id.run <| do
|
||
let mut r? : Option (Nat × Int) := none
|
||
for (i, coeff) in e.coeffs do
|
||
match r? with
|
||
| none =>
|
||
r? := some (i, coeff)
|
||
| some (_, coeff') =>
|
||
if coeff.natAbs < coeff'.natAbs then
|
||
r? := some (i, coeff)
|
||
return r?
|
||
|
||
end Equation
|
||
|
||
structure Problem where
|
||
equations : HashMap Nat Equation
|
||
solvedEquations : HashMap Nat Equation
|
||
nEquations : Nat
|
||
nVars : Nat
|
||
deriving Inhabited
|
||
|
||
def preprocess? (eqs : HashMap Nat Equation) : Option (HashMap Nat Equation) :=
|
||
eqs.mapValsM (·.preprocess?)
|
||
|
||
def eliminateSingleton (p : Problem) (singletonEq : Equation) (varIdx : Nat) : Problem := Id.run <| do
|
||
let mut eqsWithVarIdx : Array Nat := #[]
|
||
for (id, eq) in p.equations do
|
||
if eq.coeffs.contains varIdx then
|
||
eqsWithVarIdx := eqsWithVarIdx.push id
|
||
let mut equations := p.equations
|
||
for id in eqsWithVarIdx do
|
||
if id == singletonEq.id then
|
||
continue
|
||
equations := equations.modify! id fun eq => singletonEq.subst eq varIdx |>.normalize
|
||
equations := equations.erase singletonEq.id
|
||
let solvedEquations := p.solvedEquations.insert varIdx <| singletonEq.reorganizeFor varIdx
|
||
return { p with
|
||
equations := equations
|
||
solvedEquations := solvedEquations }
|
||
|
||
partial def eliminateSingletons (p : Problem) : Problem := Id.run <| do
|
||
let mut r? : Option (Equation × Nat) := none
|
||
for (_, eq) in p.equations do
|
||
match eq.findSingleton? with
|
||
| none =>
|
||
continue
|
||
| some (varIdx, _) =>
|
||
r? := some (eq, varIdx)
|
||
match r? with
|
||
| none =>
|
||
return p
|
||
| some (eq, varIdx) =>
|
||
let p := eliminateSingleton p eq varIdx
|
||
return eliminateSingletons p
|
||
|
||
def addAuxEquation (p : Problem) : Problem := Id.run <| do
|
||
let mut E? : Option Equation := none
|
||
let mut k? : Option Nat := none
|
||
let mut aₖ? : Option Int := none
|
||
for (_, eq) in p.equations do
|
||
match eq.findAbsMinimumCoeff?, aₖ? with
|
||
| none, _ => continue
|
||
| some (k', aₖ'), none =>
|
||
E? := some eq
|
||
k? := some k'
|
||
aₖ? := some aₖ'
|
||
| some (k', aₖ'), some aₖ =>
|
||
if aₖ'.natAbs < aₖ.natAbs then
|
||
E? := some eq
|
||
k? := some k'
|
||
aₖ? := some aₖ'
|
||
let mut E := E?.get!
|
||
let k := k?.get!
|
||
let mut aₖ := aₖ?.get!
|
||
if aₖ < 0 then
|
||
aₖ := -aₖ
|
||
E := E.invert
|
||
let m := aₖ + 1
|
||
let σIdx := p.nVars
|
||
let newEqCoeffs := E.coeffs.fastMapVals (fun _ coeff => coeff.mod' m)
|
||
|>.insert σIdx (-m)
|
||
|>.filter (fun _ coeff => coeff ≠ 0)
|
||
let newEqConst := E.const.mod' m
|
||
let newEq : Equation := ⟨p.nEquations, newEqCoeffs, newEqConst⟩
|
||
let E'coeffs := E.coeffs.filter (fun i _ => i ≠ k)
|
||
|>.fastMapVals (fun _ aᵢ => aᵢ.roundedDiv m + aᵢ.mod' m)
|
||
|>.insert σIdx (-aₖ)
|
||
|>.filter (fun _ coeff => coeff ≠ 0)
|
||
let c := E.const
|
||
let E'const := c.roundedDiv m + c.mod' m
|
||
let E' := { E with coeffs := E'coeffs, const := E'const }.normalize
|
||
let equations' := p.equations.insert E'.id E' |>.insert newEq.id newEq
|
||
let p' : Problem := { p with
|
||
equations := equations'
|
||
nVars := p.nVars + 1
|
||
nEquations := p.nEquations + 1 }
|
||
return eliminateSingleton p' newEq k
|
||
|
||
inductive Solution
|
||
| unsat
|
||
| sat (assignment : Array Int)
|
||
deriving Inhabited
|
||
|
||
partial def readSolution? (p : Problem) : Option Solution := Id.run <| do
|
||
if p.equations.any (fun _ eq => eq.coeffs.size ≠ 0) then
|
||
return none
|
||
if p.equations.any (fun _ eq => eq.const ≠ 0) then
|
||
return some Solution.unsat
|
||
let mut assignment : Array (Option Int) := Array.replicate p.nVars none
|
||
for i in *...p.nVars do
|
||
assignment := readSolution i assignment
|
||
return Solution.sat <| assignment.map (·.get!)
|
||
where
|
||
readSolution (varIdx : Nat) (assignment : Array (Option Int)) : Array (Option Int) := Id.run <| do
|
||
match p.solvedEquations[varIdx]? with
|
||
| none =>
|
||
return assignment.set! varIdx (some 0)
|
||
| some eq =>
|
||
let mut assignment := assignment
|
||
let mut r := eq.const
|
||
for (i, coeff) in eq.coeffs do
|
||
if assignment[i]!.isNone then
|
||
assignment := readSolution i assignment
|
||
r := r + coeff*assignment[i]!.get!
|
||
return assignment.set! varIdx (some r)
|
||
|
||
partial def solveProblem' (p : Problem) : Solution := Id.run <| do
|
||
match readSolution? p with
|
||
| some solution => return solution
|
||
| none =>
|
||
let p := eliminateSingletons p
|
||
match readSolution? p with
|
||
| some solution => return solution
|
||
| none =>
|
||
let p := addAuxEquation p
|
||
return solveProblem' p
|
||
|
||
def isSatAssignment (p : Problem) (assignment : Array Int) : Bool :=
|
||
¬ p.equations.any fun _ (eq : Equation) => Id.run <| do
|
||
let mut r := 0
|
||
for (i, coeff) in eq.coeffs do
|
||
r := r + coeff*assignment[i]!
|
||
return r ≠ eq.const
|
||
|
||
def solveProblem (p : Problem) : Solution :=
|
||
let nVars := p.nVars
|
||
match solveProblem' p with
|
||
| Solution.unsat =>
|
||
Solution.unsat
|
||
| Solution.sat assignment =>
|
||
let assignment' := assignment.extract 0 nVars
|
||
if isSatAssignment p assignment' then
|
||
Solution.sat assignment'
|
||
else
|
||
Solution.unsat
|
||
|
||
def error (msg : String) : IO α :=
|
||
throw <| IO.userError s!"Error: {msg}."
|
||
|
||
def Array.ithVal (xs : Array String) (i : Nat) (name : String) : IO Int := do
|
||
let some unparsed := xs[i]?
|
||
| error s!"Missing {name}"
|
||
let some parsed := String.toInt? unparsed
|
||
| error s!"Invalid {name}: `{unparsed}`"
|
||
return parsed
|
||
|
||
def main (args : List String) : IO UInt32 := do
|
||
let some path := args.head?
|
||
| error "Usage: liasolver <input file>"
|
||
let lines ← IO.FS.lines path <&> Array.filter (¬·.isEmpty)
|
||
let some headerLine := lines[0]?
|
||
| error "No header line"
|
||
let header := headerLine.splitOn.toArray
|
||
let nEquations ← header.ithVal 0 "amount of equations"
|
||
let nVars ← header.ithVal 1 "amount of variables"
|
||
let mut equations : HashMap Nat Equation := ∅
|
||
for line in lines[1...*] do
|
||
let elems := line.splitOn.toArray
|
||
let nTerms ← elems.ithVal 0 "amount of equation terms"
|
||
let 0 ← elems.ithVal (elems.size - 1) "end of line symbol"
|
||
| error "Non-zero end of line symbol"
|
||
let const ← elems.ithVal (elems.size - 2) "constant value"
|
||
let mut coeffs := ∅
|
||
for i in ((1 : Nat)...(elems.size-2)).iter.stepSize 2 do
|
||
let coeff ← elems.ithVal i "coefficient"
|
||
let varIdx ← elems.ithVal (i + 1) "variable index"
|
||
if varIdx < 1 then
|
||
error "Invalid variable index"
|
||
let varIdx := varIdx.toNat - 1
|
||
if coeff ≠ 0 then
|
||
coeffs := coeffs.insert varIdx coeff
|
||
if coeffs.size ≠ 0 then
|
||
equations := equations.insert equations.size ⟨equations.size, coeffs, const⟩
|
||
match preprocess? equations with
|
||
| none =>
|
||
IO.println "UNSAT"
|
||
| some equations' =>
|
||
let problem : Problem := ⟨equations', ∅, equations'.size, nVars.natAbs⟩
|
||
match solveProblem problem with
|
||
| Solution.unsat =>
|
||
IO.println "UNSAT"
|
||
| Solution.sat assignment =>
|
||
IO.println "SAT"
|
||
IO.println <| String.intercalate " " <| assignment.toList.map toString
|
||
return 0
|