Compare commits

...

13 Commits

Author SHA1 Message Date
Leonardo de Moura
8578f4b091 chore: remove unnecessary annotation 2025-04-17 21:18:26 -07:00
Leonardo de Moura
117b582da5 test: IsCharP 2025-04-17 21:00:00 -07:00
Leonardo de Moura
14791b25b4 chore: use actual value 2025-04-17 20:59:46 -07:00
Leonardo de Moura
df3d362a14 fix: mark p as outParam in IsCharP 2025-04-17 20:58:55 -07:00
Leonardo de Moura
0d56846ee4 test: CommRing 2025-04-17 20:52:26 -07:00
Leonardo de Moura
c4fb6af6fd feat: add Poly.isSorted 2025-04-17 20:52:26 -07:00
Leonardo de Moura
e2509a8ddc fix: sort monomials in decreasing order 2025-04-17 20:52:26 -07:00
Leonardo de Moura
50052dc077 fix: missing case 2025-04-17 20:52:26 -07:00
Leonardo de Moura
5c9471246f chore: cleanup 2025-04-17 20:52:26 -07:00
Leonardo de Moura
253fbc897c feat: improve addConst 2025-04-17 20:52:26 -07:00
Leonardo de Moura
e13a4607c8 feat: add helper theorems 2025-04-17 20:52:26 -07:00
Leonardo de Moura
a07322aa60 feat: LawfulBEq for Power, Mon, and Poly 2025-04-17 20:52:26 -07:00
Leonardo de Moura
57038e2ae5 chore: remove workaround 2025-04-17 20:52:26 -07:00
4 changed files with 109 additions and 32 deletions

View File

@@ -217,7 +217,7 @@ end CommRing
open CommRing
class IsCharP (α : Type u) [CommRing α] (p : Nat) where
class IsCharP (α : Type u) [CommRing α] (p : outParam Nat) where
ofNat_eq_zero_iff (p) : (x : Nat), OfNat.ofNat (α := α) x = 0 x % p = 0
namespace IsCharP

View File

@@ -6,6 +6,7 @@ Authors: Leonardo de Moura
prelude
import Init.Data.Nat.Lemmas
import Init.Data.Ord
import Init.Data.RArray
import Init.Grind.CommRing.Basic
namespace Lean.Grind
@@ -23,16 +24,6 @@ inductive Expr where
| pow (a : Expr) (k : Nat)
deriving Inhabited, BEq
-- TODO: add support for universes to Lean.RArray
inductive RArray (α : Type u) : Type u where
| leaf : α RArray α
| branch : Nat RArray α RArray α RArray α
def RArray.get (a : RArray α) (n : Nat) : α :=
match a with
| .leaf x => x
| .branch p l r => if n < p then l.get n else r.get n
abbrev Context (α : Type u) := RArray α
def Var.denote {α} (ctx : Context α) (v : Var) : α :=
@@ -52,6 +43,10 @@ structure Power where
k : Nat
deriving BEq, Repr
instance : LawfulBEq Power where
eq_of_beq {a} := by cases a <;> intro b <;> cases b <;> simp_all! [BEq.beq]
rfl := by intro a; cases a <;> simp! [BEq.beq]
def Power.varLt (p₁ p₂ : Power) : Bool :=
p₁.x.blt p₂.x
@@ -67,6 +62,18 @@ inductive Mon where
| cons (p : Power) (m : Mon)
deriving BEq, Repr
instance : LawfulBEq Mon where
eq_of_beq {a} := by
induction a <;> intro b <;> cases b <;> simp_all! [BEq.beq]
next p₁ p₂ => cases p₁ <;> cases p₂ <;> simp <;> intros <;> simp [*]
next p₁ m₁ p₂ m₂ ih =>
cases p₁ <;> cases p₂ <;> simp <;> intros <;> simp [*]
next h => exact ih h
rfl := by
intro a
induction a <;> simp! [BEq.beq]
assumption
def Mon.denote {α} [CommRing α] (ctx : Context α) : Mon α
| .leaf p => p.denote ctx
| .cons p m => p.denote ctx * denote ctx m
@@ -205,6 +212,20 @@ inductive Poly where
| add (k : Int) (v : Mon) (p : Poly)
deriving BEq
instance : LawfulBEq Poly where
eq_of_beq {a} := by
induction a <;> intro b <;> cases b <;> simp_all! [BEq.beq]
intro h₁ h₂ h₃
next m₁ p₁ _ m₂ p₂ ih =>
replace h₂ : m₁ == m₂ := h₂
simp [ih h₃, eq_of_beq h₂]
rfl := by
intro a
induction a <;> simp! [BEq.beq]
next k m p ih =>
show m == m p == p
simp [ih]
def Poly.denote [CommRing α] (ctx : Context α) (p : Poly) : α :=
match p with
| .num k => Int.cast k
@@ -216,10 +237,20 @@ def Poly.ofMon (m : Mon) : Poly :=
def Poly.ofVar (x : Var) : Poly :=
ofMon (Mon.ofVar x)
def Poly.isSorted : Poly Bool
| .num _ => true
| .add _ _ (.num _) => true
| .add _ m₁ (.add k m₂ p) => m₁.grevlex m₂ == .gt && (Poly.add k m₂ p).isSorted
def Poly.addConst (p : Poly) (k : Int) : Poly :=
match p with
bif k == 0 then
p
else
go p
where
go : Poly Poly
| .num k' => .num (k' + k)
| .add k' m p => .add k' m (addConst p k)
| .add k' m p => .add k' m (go p)
def Poly.insert (k : Int) (m : Mon) (p : Poly) : Poly :=
bif k == 0 then
@@ -232,13 +263,13 @@ where
| .add k' m' p =>
match m.grevlex m' with
| .eq =>
let k'' := k + k'
bif k'' == 0 then
let k := k + k'
bif k == 0 then
p
else
.add k'' m p
| .lt => .add k m (.add k' m' p)
| .gt => .add k' m' (go p)
.add k m p
| .gt => .add k m (.add k' m' p)
| .lt => .add k' m' (go p)
def Poly.concat (p₁ p₂ : Poly) : Poly :=
match p₁ with
@@ -264,7 +295,11 @@ def Poly.mulMon (k : Int) (m : Mon) (p : Poly) : Poly :=
go p
where
go : Poly Poly
| .num k' => .add (k*k') m (.num 0)
| .num k' =>
bif k' == 0 then
.num 0
else
.add (k*k') m (.num 0)
| .add k' m' p => .add (k*k') (m.mul m') (go p)
def Poly.combine (p₁ p₂ : Poly) : Poly :=
@@ -285,8 +320,8 @@ where
go fuel p₁ p₂
else
.add k m₁ (go fuel p₁ p₂)
| .lt => .add k₁ m₁ (go fuel p₁ (.add k₂ m₂ p₂))
| .gt => .add k₂ m₂ (go fuel (.add k₁ m₁ p₁) p₂)
| .gt => .add k₁ m₁ (go fuel p₁ (.add k₂ m₂ p₂))
| .lt => .add k₂ m₂ (go fuel (.add k₁ m₁ p₁) p₂)
def Poly.mul (p₁ : Poly) (p₂ : Poly) : Poly :=
go p₁ (.num 0)
@@ -344,8 +379,8 @@ where
p
else
.add k'' m p
| .lt => .add k m (.add k' m' p)
| .gt => .add k' m' (go k p)
| .gt => .add k m (.add k' m' p)
| .lt => .add k' m' (go k p)
def Poly.mulConstC (k : Int) (p : Poly) (c : Nat) : Poly :=
let k := k % c
@@ -404,8 +439,8 @@ where
go fuel p₁ p₂
else
.add k m₁ (go fuel p₁ p₂)
| .lt => .add k₁ m₁ (go fuel p₁ (.add k₂ m₂ p₂))
| .gt => .add k₂ m₂ (go fuel (.add k₁ m₁ p₁) p₂)
| .gt => .add k₁ m₁ (go fuel p₁ (.add k₂ m₂ p₂))
| .lt => .add k₂ m₂ (go fuel (.add k₁ m₁ p₁) p₂)
def Poly.mulC (p₁ : Poly) (p₂ : Poly) (c : Nat) : Poly :=
go p₁ (.num 0)
@@ -556,9 +591,12 @@ theorem Poly.denote_ofVar {α} [CommRing α] (ctx : Context α) (x : Var)
simp [ofVar, denote_ofMon, Mon.denote_ofVar]
theorem Poly.denote_addConst {α} [CommRing α] (ctx : Context α) (p : Poly) (k : Int) : (addConst p k).denote ctx = p.denote ctx + k := by
fun_induction addConst <;> simp [addConst, denote, *]
next => rw [intCast_add]
next => simp [add_comm, add_left_comm, add_assoc]
simp [addConst, cond_eq_if]; split
next => simp [*, intCast_zero, add_zero]
next =>
fun_induction addConst.go <;> simp [addConst.go, denote, *]
next => rw [intCast_add]
next => simp [add_comm, add_left_comm, add_assoc]
theorem Poly.denote_insert {α} [CommRing α] (ctx : Context α) (k : Int) (m : Mon) (p : Poly)
: (insert k m p).denote ctx = k * m.denote ctx + p.denote ctx := by
@@ -595,6 +633,7 @@ theorem Poly.denote_mulMon {α} [CommRing α] (ctx : Context α) (k : Int) (m :
next => simp [denote, *, intCast_zero, zero_mul]
next =>
fun_induction mulMon.go <;> simp [mulMon.go, denote, *]
next h => simp +zetaDelta at h; simp [*, intCast_zero, mul_zero]
next => simp [intCast_mul, intCast_zero, add_zero, mul_comm, mul_left_comm, mul_assoc]
next => simp [Mon.denote_mul, intCast_mul, left_distrib, mul_comm, mul_left_comm, mul_assoc]
@@ -635,6 +674,11 @@ theorem Expr.denote_toPoly {α} [CommRing α] (ctx : Context α) (e : Expr)
next => rw [intCast_pow]
next => simp [Poly.denote_ofMon, Mon.denote, Power.denote_eq]
theorem Expr.eq_of_toPoly_eq {α} [CommRing α] (ctx : Context α) (a b : Expr) (h : a.toPoly == b.toPoly) : a.denote ctx = b.denote ctx := by
have h := congrArg (Poly.denote ctx) (eq_of_beq h)
simp [denote_toPoly] at h
assumption
theorem Poly.denote_addConstC {α c} [CommRing α] [IsCharP α c] (ctx : Context α) (p : Poly) (k : Int) : (addConstC p k c).denote ctx = p.denote ctx + k := by
fun_induction addConstC <;> simp [addConstC, denote, *]
next => rw [IsCharP.intCast_emod, intCast_add]
@@ -747,5 +791,11 @@ theorem Expr.denote_toPolyC {α c} [CommRing α] [IsCharP α c] (ctx : Context
next => rw [IsCharP.intCast_emod, intCast_pow]
next => simp [Poly.denote_ofMon, Mon.denote, Power.denote_eq]
theorem Expr.eq_of_toPolyC_eq {α c} [CommRing α] [IsCharP α c] (ctx : Context α) (a b : Expr)
(h : a.toPolyC c == b.toPolyC c) : a.denote ctx = b.denote ctx := by
have h := congrArg (Poly.denote ctx) (eq_of_beq h)
simp [denote_toPolyC] at h
assumption
end CommRing
end Lean.Grind

View File

@@ -71,7 +71,7 @@ instance : CommRing UInt8 where
pow_succ := UInt8.pow_succ
ofNat_succ x := UInt8.ofNat_add x 1
instance : IsCharP UInt8 (2 ^ 8) where
instance : IsCharP UInt8 256 where
ofNat_eq_zero_iff {x} := by
have : OfNat.ofNat x = UInt8.ofNat x := rfl
simp [this, UInt8.ofNat_eq_iff_mod_eq_toNat]
@@ -91,7 +91,7 @@ instance : CommRing UInt16 where
pow_succ := UInt16.pow_succ
ofNat_succ x := UInt16.ofNat_add x 1
instance : IsCharP UInt16 (2 ^ 16) where
instance : IsCharP UInt16 65536 where
ofNat_eq_zero_iff {x} := by
have : OfNat.ofNat x = UInt16.ofNat x := rfl
simp [this, UInt16.ofNat_eq_iff_mod_eq_toNat]
@@ -111,7 +111,7 @@ instance : CommRing UInt32 where
pow_succ := UInt32.pow_succ
ofNat_succ x := UInt32.ofNat_add x 1
instance : IsCharP UInt32 (2 ^ 32) where
instance : IsCharP UInt32 4294967296 where
ofNat_eq_zero_iff {x} := by
have : OfNat.ofNat x = UInt32.ofNat x := rfl
simp [this, UInt32.ofNat_eq_iff_mod_eq_toNat]
@@ -131,7 +131,7 @@ instance : CommRing UInt64 where
pow_succ := UInt64.pow_succ
ofNat_succ x := UInt64.ofNat_add x 1
instance : IsCharP UInt64 (2 ^ 64) where
instance : IsCharP UInt64 18446744073709551616 where
ofNat_eq_zero_iff {x} := by
have : OfNat.ofNat x = UInt64.ofNat x := rfl
simp [this, UInt64.ofNat_eq_iff_mod_eq_toNat]

View File

@@ -0,0 +1,27 @@
import Lean
import Init.Grind.CommRing.SOM
open Lean.Grind
open Lean.Grind.CommRing
-- Convenient RArray literals
elab tk:"#R[" ts:term,* "]" : term => do
let ts : Array Lean.Syntax := ts
let es ts.mapM fun stx => Lean.Elab.Term.elabTerm stx none
if h : 0 < es.size then
Lean.RArray.toExpr ( Lean.Meta.inferType es[0]!) id (Lean.RArray.ofArray es h)
else
throwErrorAt tk "RArray cannot be empty"
example (x y : Int) : (x + y) * (x + y + 1) = x * (1 + y + x) + (y + 1 + x) * y :=
let ctx := #R[x, y]
let lhs : Expr := .mul (.add (.var 0) (.var 1)) (.add (.add (.var 0) (.var 1)) (.num 1))
let rhs : Expr := .add (.mul (.var 0) (.add (.add (.num 1) (.var 1)) (.var 0)))
(.mul (.add (.add (.var 1) (.num 1)) (.var 0)) (.var 1))
Expr.eq_of_toPoly_eq ctx lhs rhs (Eq.refl true)
example (x y : UInt8) : (128 * x + y) * 2 = y + y :=
let ctx := #R[x, y]
let lhs : Expr := .mul (.add (.mul (.num 128) (.var 0)) (.var 1)) (.num 2)
let rhs : Expr := .add (.var 1) (.var 1)
Expr.eq_of_toPolyC_eq ctx lhs rhs (Eq.refl true)