Compare commits

...

5 Commits

Author SHA1 Message Date
Leonardo de Moura
7ecd1c51be chore: cleanup 2025-02-11 15:00:38 -08:00
Leonardo de Moura
bd2ff7b11a feat: sortExprs in the Nat case 2025-02-11 14:54:08 -08:00
Leonardo de Moura
faabc99be8 chore: fix test 2025-02-11 14:26:56 -08:00
Leonardo de Moura
2b5fc22c23 feat: sort arith atoms
This PR ensures that terms such as `f (2*x + y)` and `f (y + x + x)`
have the same normal form.
2025-02-11 14:25:24 -08:00
Leonardo de Moura
cdedf36844 feat: add sortExprs 2025-02-11 14:08:32 -08:00
9 changed files with 148 additions and 33 deletions

View File

@@ -35,11 +35,11 @@ inductive Expr where
deriving Inhabited
def Expr.denote (ctx : Context) : Expr Nat
| Expr.add a b => Nat.add (denote ctx a) (denote ctx b)
| Expr.num k => k
| Expr.var v => v.denote ctx
| Expr.mulL k e => Nat.mul k (denote ctx e)
| Expr.mulR e k => Nat.mul (denote ctx e) k
| .add a b => Nat.add (denote ctx a) (denote ctx b)
| .num k => k
| .var v => v.denote ctx
| .mulL k e => Nat.mul k (denote ctx e)
| .mulR e k => Nat.mul (denote ctx e) k
abbrev Poly := List (Nat × Var)
@@ -146,17 +146,17 @@ where
-- Implementation note: This assembles the result using difference lists
-- to avoid `++` on lists.
go (coeff : Nat) : Expr (Poly Poly)
| Expr.num k => bif k == 0 then id else ((coeff * k, fixedVar) :: ·)
| Expr.var i => ((coeff, i) :: ·)
| Expr.add a b => go coeff a go coeff b
| Expr.mulL k a
| Expr.mulR a k => bif k == 0 then id else go (coeff * k) a
| .num k => bif k == 0 then id else ((coeff * k, fixedVar) :: ·)
| .var i => ((coeff, i) :: ·)
| .add a b => go coeff a go coeff b
| .mulL k a
| .mulR a k => bif k == 0 then id else go (coeff * k) a
def Expr.toNormPoly (e : Expr) : Poly :=
e.toPoly.norm
def Expr.inc (e : Expr) : Expr :=
Expr.add e (Expr.num 1)
.add e (.num 1)
structure PolyCnstr where
eq : Bool
@@ -244,21 +244,21 @@ def Certificate.denote (ctx : Context) (c : Certificate) : Prop :=
def monomialToExpr (k : Nat) (v : Var) : Expr :=
bif v == fixedVar then
Expr.num k
.num k
else bif k == 1 then
Expr.var v
.var v
else
Expr.mulL k (Expr.var v)
.mulL k (.var v)
def Poly.toExpr (p : Poly) : Expr :=
match p with
| [] => Expr.num 0
| [] => .num 0
| (k, v) :: p => go (monomialToExpr k v) p
where
go (e : Expr) (p : Poly) : Expr :=
match p with
| [] => e
| (k, v) :: p => go (Expr.add e (monomialToExpr k v)) p
| (k, v) :: p => go (.add e (monomialToExpr k v)) p
def PolyCnstr.toExpr (c : PolyCnstr) : ExprCnstr :=
{ c with lhs := c.lhs.toExpr, rhs := c.rhs.toExpr }

View File

@@ -5,6 +5,7 @@ Authors: Leonardo de Moura
-/
prelude
import Init.Data.Int.Linear
import Lean.Util.SortExprs
import Lean.Meta.Check
import Lean.Meta.Offset
import Lean.Meta.IntInstTesters
@@ -31,6 +32,24 @@ def PolyCnstr.toExprCnstr : PolyCnstr → ExprCnstr
| .eq p => .eq p.toExpr (.num 0)
| .le p => .le p.toExpr (.num 0)
/-- Applies the given variable permutation to `e` -/
def Expr.applyPerm (perm : Lean.Perm) (e : Expr) : Expr :=
go e
where
go : Expr Expr
| .num v => .num v
| .var i => .var (perm[(i : Nat)]?.getD i)
| .neg a => .neg (go a)
| .add a b => .add (go a) (go b)
| .sub a b => .sub (go a) (go b)
| .mulL k a => .mulL k (go a)
| .mulR a k => .mulR (go a) k
/-- Applies the given variable permutation to the given expression constraint. -/
def ExprCnstr.applyPerm (perm : Lean.Perm) : ExprCnstr ExprCnstr
| .eq a b => .eq (a.applyPerm perm) (b.applyPerm perm)
| .le a b => .le (a.applyPerm perm) (b.applyPerm perm)
end Int.Linear
namespace Lean.Meta.Linear.Int
@@ -187,7 +206,24 @@ def run (x : M α) : MetaM (α × Array Expr) := do
end ToLinear
export ToLinear (toLinearCnstr? toLinearExpr)
def toLinearExpr (e : Expr) : MetaM (LinearExpr × Array Expr) := do
let (e, atoms) ToLinear.run (ToLinear.toLinearExpr e)
if atoms.size == 1 then
return (e, atoms)
else
let (atoms, perm) := sortExprs atoms
let e := e.applyPerm perm
return (e, atoms)
def toLinearCnstr? (e : Expr) : MetaM (Option (LinearCnstr × Array Expr)) := do
let (some c, atoms) ToLinear.run (ToLinear.toLinearCnstr? e)
| return none
if atoms.size <= 1 then
return some (c, atoms)
else
let (atoms, perm) := sortExprs atoms
let c := c.applyPerm perm
return some (c, atoms)
def toContextExpr (ctx : Array Expr) : Expr :=
if h : 0 < ctx.size then

View File

@@ -44,7 +44,7 @@ def Int.Linear.PolyCnstr.getConst : PolyCnstr → Int
namespace Lean.Meta.Linear.Int
def simpCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do
let (some c, atoms) ToLinear.run (ToLinear.toLinearCnstr? e) | return none
let some (c, atoms) toLinearCnstr? e | return none
withAbstractAtoms atoms ``Int fun atoms => do
let lhs c.toArith atoms
let p := c.toPoly
@@ -127,13 +127,13 @@ def simpCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
simpCnstrPos? e
def simpExpr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
let (e, ctx) ToLinear.run (ToLinear.toLinearExpr e)
let (e, atoms) toLinearExpr e
let p := e.toPoly
let e' := p.toExpr
if e != e' then
-- We only return some if monomials were fused
let p := mkApp4 (mkConst ``Int.Linear.Expr.eq_of_toPoly_eq) (toContextExpr ctx) (toExpr e) (toExpr e') reflBoolTrue
let r LinearExpr.toArith ctx e'
let p := mkApp4 (mkConst ``Int.Linear.Expr.eq_of_toPoly_eq) (toContextExpr atoms) (toExpr e) (toExpr e') reflBoolTrue
let r LinearExpr.toArith atoms e'
return some (r, p)
else
return none

View File

@@ -4,12 +4,32 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Util.SortExprs
import Lean.Meta.Check
import Lean.Meta.Offset
import Lean.Meta.AppBuilder
import Lean.Meta.KExprMap
import Lean.Data.RArray
namespace Nat.Linear
/-- Applies the given variable permutation to `e` -/
def Expr.applyPerm (perm : Lean.Perm) (e : Expr) : Expr :=
go e
where
go : Expr Expr
| .num v => .num v
| .var i => .var (perm[(i : Nat)]?.getD i)
| .add a b => .add (go a) (go b)
| .mulL k a => .mulL k (go a)
| .mulR a k => .mulR (go a) k
/-- Applies the given variable permutation to the given expression constraint. -/
def ExprCnstr.applyPerm (perm : Lean.Perm) : ExprCnstr ExprCnstr
| { eq, lhs, rhs } => { eq, lhs := lhs.applyPerm perm, rhs := rhs.applyPerm perm }
end Nat.Linear
namespace Lean.Meta.Linear.Nat
deriving instance Repr for Nat.Linear.Expr
@@ -140,7 +160,24 @@ def run (x : M α) : MetaM (α × Array Expr) := do
end ToLinear
export ToLinear (toLinearCnstr? toLinearExpr)
def toLinearExpr (e : Expr) : MetaM (LinearExpr × Array Expr) := do
let (e, atoms) ToLinear.run (ToLinear.toLinearExpr e)
if atoms.size == 1 then
return (e, atoms)
else
let (atoms, perm) := sortExprs atoms
let e := e.applyPerm perm
return (e, atoms)
def toLinearCnstr? (e : Expr) : MetaM (Option (LinearCnstr × Array Expr)) := do
let (some c, atoms) ToLinear.run (ToLinear.toLinearCnstr? e)
| return none
if atoms.size <= 1 then
return some (c, atoms)
else
let (atoms, perm) := sortExprs atoms
let c := c.applyPerm perm
return some (c, atoms)
def toContextExpr (ctx : Array Expr) : Expr :=
if h : 0 < ctx.size then
@@ -148,4 +185,4 @@ def toContextExpr (ctx : Array Expr) : Expr :=
else
RArray.toExpr (mkConst ``Nat) id (RArray.leaf (mkNatLit 0))
end Lean.Meta.Linear.Nat
namespace Lean.Meta.Linear.Nat

View File

@@ -10,7 +10,8 @@ import Lean.Meta.Tactic.LinearArith.Nat.Basic
namespace Lean.Meta.Linear.Nat
def simpCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do
let (some c, atoms) ToLinear.run (ToLinear.toLinearCnstr? e) | return none
let some (c, atoms) toLinearCnstr? e
| return none
withAbstractAtoms atoms ``Nat fun atoms => do
let lhs c.toArith atoms
let c₁ := c.toPoly
@@ -67,7 +68,7 @@ def simpCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
simpCnstrPos? e
def simpExpr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
let (e, ctx) ToLinear.run (ToLinear.toLinearExpr e)
let (e, ctx) toLinearExpr e
let p := e.toPoly
let p' := p.norm
if p'.length < p.length then

View File

@@ -35,3 +35,4 @@ import Lean.Util.SafeExponentiation
import Lean.Util.NumObjs
import Lean.Util.NumApps
import Lean.Util.FVarSubset
import Lean.Util.SortExprs

View File

@@ -0,0 +1,23 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Expr
namespace Lean
abbrev Perm := Std.HashMap Nat Nat
/--
Sorts the given expressions using `Expr.lt`, and creates a "permutation map" storing the new position of each expression.
-/
def sortExprs (es : Array Expr) : Array Expr × Perm :=
let es := es.mapIdx fun i e => (e, i)
let es := es.qsort fun (e₁, _) (e₂, _) => e₁.lt e₂
let (_, perm) := es.foldl (init := (0, Std.HashMap.empty)) fun (i, perm) (_, j) => (i+1, perm.insert j i)
let es := es.map (·.1)
(es, perm)
end Lean

View File

@@ -171,18 +171,18 @@ fun x y z f =>
id
(Int.Linear.ExprCnstr.eq_true_of_isValid
(Lean.RArray.branch 1 (Lean.RArray.leaf x)
(Lean.RArray.branch 2 (Lean.RArray.leaf x_1) (Lean.RArray.leaf z)))
(Lean.RArray.branch 2 (Lean.RArray.leaf z) (Lean.RArray.leaf x_1)))
(Int.Linear.ExprCnstr.le
((((((Int.Linear.Expr.var 0).add (Int.Linear.Expr.var 1)).add (Int.Linear.Expr.num 2)).add
(Int.Linear.Expr.var 1)).add
(Int.Linear.Expr.var 2)).add
(Int.Linear.Expr.var 2))
(((((((Int.Linear.Expr.var 1).add (Int.Linear.Expr.mulL 3 (Int.Linear.Expr.var 2))).add
((((((Int.Linear.Expr.var 0).add (Int.Linear.Expr.var 2)).add (Int.Linear.Expr.num 2)).add
(Int.Linear.Expr.var 2)).add
(Int.Linear.Expr.var 1)).add
(Int.Linear.Expr.var 1))
(((((((Int.Linear.Expr.var 2).add (Int.Linear.Expr.mulL 3 (Int.Linear.Expr.var 1))).add
(Int.Linear.Expr.num 1)).add
(Int.Linear.Expr.num 1)).add
(Int.Linear.Expr.var 0)).add
(Int.Linear.Expr.var 1)).sub
(Int.Linear.Expr.var 2)))
(Int.Linear.Expr.var 2)).sub
(Int.Linear.Expr.var 1)))
(Eq.refl true)))
(f y))
-/
@@ -256,3 +256,12 @@ example (x : Int) : (11*x ≤ 10) ↔ (x ≤ 0) := by
example (x : Int) : (11*x > 10) (x 1) := by
simp +arith only
example (x y : Int) : (2*x + y + y = 4) (y + x = 2) := by
simp +arith
example (x y : Int) : (2*x + y + y 3) (y + x 1) := by
simp +arith
example (f : Int Int) (x y : Int) : f (2*x + y) = f (y + x + x) := by
simp +arith

View File

@@ -0,0 +1,8 @@
example (x y : Nat) : (2*x + y = 4) (y + x + x = 4) := by
simp +arith
example (x y : Nat) : (2*x + y 3) (y + x + x 3) := by
simp +arith
example (f : Nat Nat) (x y : Nat) : f (2*x + y) = f (y + x + x) := by
simp +arith