Compare commits

...

1 Commits

Author SHA1 Message Date
Leonardo de Moura
204d6d446b perf: Expr.toPoly in grind
This PR adds a version of `CommRing.Expr.toPoly` optimized for kernel
reduction. We use this function not only to implement `grind ring`,
but also to interface the ring module with `grind cutsat`.
2025-08-04 08:11:01 -07:00
2 changed files with 55 additions and 6 deletions

View File

@@ -636,9 +636,57 @@ def Expr.toPoly : Expr → Poly
.num 1
else match a with
| .num n => .num (n^k)
| .intCast n => .num (n^k)
| .natCast n => .num (n^k)
| .var x => Poly.ofMon (.mult {x, k} .unit)
| _ => a.toPoly.pow k
@[expose] noncomputable def Expr.toPoly_k (e : Expr) : Poly :=
Expr.rec
(fun k => .num k) (fun k => .num k) (fun k => .num k)
(fun x => .ofVar x)
(fun _ ih => ih.mulConst_k (-1))
(fun _ _ ih₁ ih₂ => ih₁.combine_k ih₂)
(fun _ _ ih₁ ih₂ => ih₁.combine_k (ih₂.mulConst_k (-1)))
(fun _ _ ih₁ ih₂ => ih₁.mul ih₂)
(fun a k ih => Bool.rec
(Expr.rec (fun n => .num (n^k)) (fun n => .num (n^k)) (fun n => (.num (n^k)))
(fun x => .ofMon (.mult {x, k} .unit)) (fun _ _ => ih.pow k)
(fun _ _ _ _ => ih.pow k)
(fun _ _ _ _ => ih.pow k)
(fun _ _ _ _ => ih.pow k)
(fun _ _ _ => ih.pow k)
a)
(.num 1)
(k.beq 0))
e
@[simp] theorem Expr.toPoly_k_eq_toPoly (e : Expr) : e.toPoly_k = e.toPoly := by
induction e <;> simp only [toPoly, toPoly_k]
next a ih => rw [Poly.mulConst_k_eq_mulConst]; congr
case add => rw [ Poly.combine_k_eq_combine]; congr
case sub => rw [ Poly.combine_k_eq_combine, Poly.mulConst_k_eq_mulConst]; congr
case mul => congr
case pow a k ih =>
rw [cond_eq_if]; split
next h => rw [Nat.beq_eq_true_eq, Nat.beq_eq] at h; rw [h]
next h =>
rw [Nat.beq_eq_true_eq, Nat.beq_eq, Bool.not_eq_true] at h; rw [h]; dsimp only
show
(Expr.rec (fun n => .num (n^k)) (fun n => .num (n^k)) (fun n => (.num (n^k)))
(fun x => .ofMon (.mult {x, k} .unit)) (fun _ _ => a.toPoly_k.pow k)
(fun _ _ _ _ => a.toPoly_k.pow k)
(fun _ _ _ _ => a.toPoly_k.pow k)
(fun _ _ _ _ => a.toPoly_k.pow k)
(fun _ _ _ => a.toPoly_k.pow k)
a) = match a with
| num n => Poly.num (n ^ k)
| .intCast n => .num (n^k)
| .natCast n => .num (n^k)
| var x => Poly.ofMon (Mon.mult { x := x, k := k } Mon.unit)
| x => a.toPoly.pow k
cases a <;> try simp [*]
def Poly.normEq0 (p : Poly) (c : Nat) : Poly :=
match p with
| .num a =>
@@ -1050,6 +1098,7 @@ theorem Expr.denote_toPoly {α} [CommRing α] (ctx : Context α) (e : Expr)
neg_mul, one_mul, sub_eq_add_neg, denoteInt_eq, *]
next => rw [Ring.intCast_natCast]
next a k h => simp at h; simp [h, Semiring.pow_zero]
next => rw [Ring.intCast_natCast]
next => simp [Poly.denote_ofMon, Mon.denote, Power.denote_eq, mul_one]
theorem Expr.eq_of_toPoly_eq {α} [CommRing α] (ctx : Context α) (a b : Expr) (h : a.toPoly == b.toPoly) : a.denote ctx = b.denote ctx := by
@@ -1320,7 +1369,7 @@ Theorems for stepwise proof-term construction
-/
@[expose]
noncomputable def core_cert (lhs rhs : Expr) (p : Poly) : Bool :=
(lhs.sub rhs).toPoly.beq' p
(lhs.sub rhs).toPoly_k.beq' p
theorem core {α} [CommRing α] (ctx : Context α) (lhs rhs : Expr) (p : Poly)
: core_cert lhs rhs p lhs.denote ctx = rhs.denote ctx p.denote ctx = 0 := by
@@ -1403,7 +1452,7 @@ theorem d_stepk {α} [CommRing α] (ctx : Context α) (k₁ : Int) (k : Int) (in
@[expose]
noncomputable def imp_1eq_cert (lhs rhs : Expr) (p₁ p₂ : Poly) : Bool :=
(lhs.sub rhs).toPoly.beq' p₁ |>.and' (p₂.beq' (.num 0))
(lhs.sub rhs).toPoly_k.beq' p₁ |>.and' (p₂.beq' (.num 0))
theorem imp_1eq {α} [CommRing α] (ctx : Context α) (lhs rhs : Expr) (p₁ p₂ : Poly)
: imp_1eq_cert lhs rhs p₁ p₂ (1:Int) * p₁.denote ctx = p₂.denote ctx lhs.denote ctx = rhs.denote ctx := by
@@ -1412,7 +1461,7 @@ theorem imp_1eq {α} [CommRing α] (ctx : Context α) (lhs rhs : Expr) (p₁ p
@[expose]
noncomputable def imp_keq_cert (lhs rhs : Expr) (k : Int) (p₁ p₂ : Poly) : Bool :=
!Int.beq' k 0 |>.and' ((lhs.sub rhs).toPoly.beq' p₁ |>.and' (p₂.beq' (.num 0)))
!Int.beq' k 0 |>.and' ((lhs.sub rhs).toPoly_k.beq' p₁ |>.and' (p₂.beq' (.num 0)))
theorem imp_keq {α} [CommRing α] (ctx : Context α) [NoNatZeroDivisors α] (k : Int) (lhs rhs : Expr) (p₁ p₂ : Poly)
: imp_keq_cert lhs rhs k p₁ p₂ k * p₁.denote ctx = p₂.denote ctx lhs.denote ctx = rhs.denote ctx := by
@@ -1725,7 +1774,7 @@ theorem d_normEq0 {α} [CommRing α] (ctx : Context α) (k : Int) (c : Nat) (ini
intro h; rw [p₁.normEq0_eq] <;> assumption
@[expose] noncomputable def norm_int_cert (e : Expr) (p : Poly) : Bool :=
e.toPoly.beq' p
e.toPoly_k.beq' p
theorem norm_int (ctx : Context Int) (e : Expr) (p : Poly) : norm_int_cert e p e.denote ctx = p.denote' ctx := by
simp [norm_int_cert, Poly.denote'_eq_denote]; intro; subst p; simp [Expr.denote_toPoly]

View File

@@ -1,5 +1,5 @@
-- Comparisons against `omega`:
-- set_option diagnostics true
-- This one is much slower (~10s in the kernel) than omega (~2s in the kernel).
example {a b c d e f a' b' c' d' e' f' : Int}
(h₁ : c = a + 3 * b) (h₂ : c' = a' + b') (h₃ : d = 2 * a + 3 * b) (h₄ : d' = 2 * a' + b') (h₅ : e = a + b)
@@ -52,4 +52,4 @@ example {a b c d e f a' b' c' d' e' f' : Int}
f = 2 f' = 1
f = -1 f' = -2
f = -2 f' = -1 f = 1 f' = 3 f = 3 f' = 1 f = -1 f' = -3 f = -3 f' = -1) :
a = 0 b = 0 := by grind (splits := 50)
a = 0 b = 0 := by grind (splits := 40)