Compare commits

...

1 Commits

Author SHA1 Message Date
Leonardo de Moura
746f7c2c5e feat: compute PreNullCert 2025-04-23 12:05:43 -07:00
3 changed files with 80 additions and 20 deletions

View File

@@ -45,25 +45,30 @@ structure PreNullCert where
Thus, we need to track a denominator to justify the proof step `div`.
-/
d : Int := 1
deriving Inhabited
def PreNullCert.unit (i : Nat) (n : Nat) : PreNullCert :=
let qs := Array.replicate n (.num 0)
let qs := qs.set! i (.num 1)
{ qs }
def PreNullCert.mul (c : PreNullCert) (k : Int) (char? : Option Nat) : PreNullCert :=
if k == 1 then c
def PreNullCert.div (c : PreNullCert) (k : Int) : RingM PreNullCert := do
return { c with d := c.d * k }
def PreNullCert.mul (c : PreNullCert) (k : Int) : RingM PreNullCert := do
if k == 1 then
return c
else
let g := Int.gcd k c.d
let k := k / g
let d := c.d / g
if k == 1 then
{ c with d }
return { c with d }
else
let qs := c.qs.map fun q => if q.isZero then q else q.mulConst' k char?
{ qs, d }
let qs c.qs.mapM fun q => q.mulConstM k
return { qs, d }
def PreNullCert.combine (k₁ : Int) (m₁ : Mon) (c₁ : PreNullCert) (k₂ : Int) (m₂ : Mon) (c₂ : PreNullCert) (char? : Option Nat) : PreNullCert := Id.run do
def PreNullCert.combine (k₁ : Int) (m₁ : Mon) (c₁ : PreNullCert) (k₂ : Int) (m₂ : Mon) (c₂ : PreNullCert) : RingM PreNullCert := do
let d₁ := c₁.d
let d₂ := c₂.d
let k₁_d₂ := k₁*d₂
@@ -79,17 +84,17 @@ def PreNullCert.combine (k₁ : Int) (m₁ : Mon) (c₁ : PreNullCert) (k₂ : I
let mut qs : Vector Poly n := Vector.replicate n (.num 0)
for h : i in [:n] do
if h₁ : i < qs₁.size then
let q₁ := qs₁[i].mulMon' k₁ m₁ char?
let q₁ qs₁[i].mulMonM k₁ m₁
if h₂ : i < qs₂.size then
let q₂ := qs₂[i].mulMon' k₂ m₂ char?
qs := qs.set i (q₁.combine' q₂ char?)
let q₂ qs₂[i].mulMonM k₂ m₂
qs := qs.set i ( q₁.combineM q₂)
else
qs := qs.set i q₁
else
have : i < n := h.upper
have : qs₁.size = n qs₂.size = n := by simp +zetaDelta [Nat.max_def]; split <;> simp [*]
have : i < qs₂.size := by omega
let q₂ := qs₂[i].mulMon' k₂ m₂ char?
let q₂ qs₂[i].mulMonM k₂ m₂
qs := qs.set i q₂
return { qs := qs.toArray, d }
@@ -101,9 +106,57 @@ structure NullCertHypothesis where
structure ProofM.State where
/-- Mapping from `EqCnstr` to `PreNullCert` -/
cache : Std.HashMap UInt64 PreNullCert := {}
hypToId : Std.HashMap UInt64 Nat := {}
hyps : Array NullCertHypothesis := #[]
abbrev ProofM := StateRefT ProofM.State RingM
private abbrev caching (c : α) (k : ProofM PreNullCert) : ProofM PreNullCert := do
let addr := unsafe (ptrAddrUnsafe c).toUInt64 >>> 2
if let some h := ( get).cache[addr]? then
return h
else
let h k
modify fun s => { s with cache := s.cache.insert addr h }
return h
partial def EqCnstr.toPreNullCert (c : EqCnstr) : ProofM PreNullCert := caching c do
match c.h with
| .core a b lhs rhs =>
let i := ( get).hyps.size
let h mkEqProof a b
modify fun s => { s with hyps := s.hyps.push { h, lhs, rhs } }
return PreNullCert.unit i (i+1)
| .superpose c₁ c₂ k₁ k₂ m₁ m₂ => ( c₁.toPreNullCert).combine k₁ m₁ k₂ m₂ ( c₂.toPreNullCert)
| .simp c₁ c₂ k₁ k₂ m => ( c₁.toPreNullCert).combine k₁ m k₂ .unit ( c₂.toPreNullCert)
| .mul k c => ( c.toPreNullCert).mul k
| .div k c => ( c.toPreNullCert).div k
structure NullCertExt where
d : Int
qhs : Array (Poly × NullCertHypothesis)
def EqCnstr.mkNullCertExt (c : EqCnstr) : RingM NullCertExt := do
let (nc, s) c.toPreNullCert.run {}
return { d := nc.d, qhs := nc.qs.zip s.hyps }
def NullCertExt.toPoly (nc : NullCertExt) : RingM Poly := do
let mut p : Poly := .num 0
for (q, h) in nc.qhs do
p p.combineM ( q.mulM ( (h.lhs.sub h.rhs).toPolyM))
return p
def NullCertExt.check (c : EqCnstr) (nc : NullCertExt) : RingM Bool := do
let p₁ := c.p.mulConst' nc.d ( nonzeroChar?)
let p₂ nc.toPoly
return p₁ == p₂
def setInconsistent (c : EqCnstr) : RingM Unit := do
trace_goal[grind.ring.assert.unsat] "{← c.denoteExpr}"
let nc c.mkNullCertExt
trace_goal[grind.ring.assert.unsat] "{nc.d}*({← c.p.denoteExpr}), {← (← nc.toPoly).denoteExpr}"
trace_goal[grind.ring.assert.unsat] "{nc.d}*({← c.p.denoteExpr}), {← nc.qhs.mapM fun (p, h) => return (← p.denoteExpr, ← h.lhs.denoteExpr, ← h.rhs.denoteExpr) }"
-- TODO
private def mkLemmaPrefix (declName declNameC : Name) : RingM Expr := do
let ring getRing
let ctx toContextExpr
@@ -123,8 +176,5 @@ def setEqUnsat (k : Int) (a b : Expr) (ra rb : RingExpr) : RingM Unit := do
h := mkApp h charInst
closeGoal <| mkApp5 h (toExpr ra) (toExpr rb) (toExpr k) reflBoolTrue ( mkEqProof a b)
def setInconsistent (c : EqCnstr) : RingM Unit := do
trace_goal[grind.ring.assert.unsat] "{← c.denoteExpr}"
-- TODO
end Lean.Meta.Grind.Arith.CommRing

View File

@@ -26,7 +26,7 @@ structure EqCnstr where
inductive EqCnstrProof where
| core (a b : Expr) (ra rb : RingExpr)
| superpose (c₁ c₂ : EqCnstr)
| superpose (c₁ c₂ : EqCnstr) (k₁ k₂ : Int) (m₁ m₂ : Mon)
| simp (c₁ c₂ : EqCnstr) (k₁ k₂ : Int) (m : Mon)
| mul (k : Int) (e : EqCnstr)
| div (k : Int) (e : EqCnstr)
@@ -104,7 +104,7 @@ inductive SimpChain where
```
If we have a commutative ring where
```
∀ (k : Int) (a b : α), k ≠ 0 → (intCast k) * a = 0 → a = 0
∀ (k : Int) (a : α), k ≠ 0 → (intCast k) * a = 0 → a = 0
```
grind can deduce that `x+y+z = 0`
-/

View File

@@ -5,6 +5,7 @@ Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Arith.CommRing.Poly
namespace Lean.Meta.Grind.Arith.CommRing
@@ -77,9 +78,18 @@ Converts the given ring expression into a multivariate polynomial.
If the ring has a nonzero characteristic, it is used during normalization.
-/
def _root_.Lean.Grind.CommRing.Expr.toPolyM (e : RingExpr) : RingM Poly := do
if let some c nonzeroChar? then
return e.toPolyC c
else
return e.toPoly
if let some c nonzeroChar? then return e.toPolyC c else return e.toPoly
def _root_.Lean.Grind.CommRing.Poly.mulConstM (p : Poly) (k : Int) : RingM Poly :=
return p.mulConst' k ( nonzeroChar?)
def _root_.Lean.Grind.CommRing.Poly.mulMonM (p : Poly) (k : Int) (m : Mon) : RingM Poly :=
return p.mulMon' k m ( nonzeroChar?)
def _root_.Lean.Grind.CommRing.Poly.mulM (p₁ p₂ : Poly) : RingM Poly := do
if let some c nonzeroChar? then return p₁.mulC p₂ c else return p₁.mul p₂
def _root_.Lean.Grind.CommRing.Poly.combineM (p₁ p₂ : Poly) : RingM Poly :=
return p₁.combine' p₂ ( nonzeroChar?)
end Lean.Meta.Grind.Arith.CommRing