Compare commits

...

2 Commits

Author SHA1 Message Date
Leonardo de Moura
f0a19b2425 feat: abstract some proofs 2025-08-11 16:54:04 -07:00
Leonardo de Moura
3bf6e6f93c feat: infrastructure for proof abstraction 2025-08-11 16:01:38 -07:00
2 changed files with 142 additions and 21 deletions

View File

@@ -28,8 +28,14 @@ structure ProofM.State where
exprMap : Std.HashMap Int.Linear.Expr Expr := {}
ringPolyMap : Std.HashMap CommRing.Poly Expr := {}
ringExprMap : Std.HashMap CommRing.RingExpr Expr := {}
paramNames : Array Name
paramFVars : Array Expr
paramTypes : Array Expr
args : Array Expr
structure ProofM.Context where
vars : Array Expr
vars' : Array Expr
ctx : Expr
/-- Variables before reordering -/
ctx' : Expr
@@ -52,7 +58,7 @@ Execute `k` with `unordered := true`, and the unordered variable context.
We use this combinator to process `.reorder c` justifications.
-/
private abbrev withUnordered (k : ProofM α) : ProofM α := do
withReader (fun c => { c with ctx := c.ctx', unordered := true }) k
withReader (fun c => { c with ctx := c.ctx', vars := c.vars', unordered := true }) k
private abbrev getVarMap : ProofM (PHashMap ExprPtr Var) := do
if ( read).unordered then
@@ -97,17 +103,11 @@ private def mkRingExprDecl (e : CommRing.RingExpr) : ProofM Expr := do
modify fun s => { s with ringExprMap := s.ringExprMap.insert e x }
return x
private def toContextExprCore (vars : PArray Expr) (type : Expr) : MetaM Expr :=
private def toContextExpr (vars : Array Expr) : MetaM Expr :=
if h : 0 < vars.size then
RArray.toExpr type id (RArray.ofFn (vars[·]) h)
RArray.toExpr Int.mkType id (RArray.ofFn (vars[·]) h)
else
RArray.toExpr type id (RArray.leaf (mkIntLit 0))
private def toContextExpr : GoalM Expr := do
toContextExprCore ( getVars) Int.mkType
private def toContextExpr' : GoalM Expr := do
toContextExprCore ( get').vars' Int.mkType
RArray.toExpr Int.mkType id (RArray.leaf (mkIntLit 0))
private def toRingContextExpr : GoalM Expr := do
if ( get').usedCommRing then
@@ -115,11 +115,45 @@ private def toRingContextExpr : GoalM Expr := do
return ( CommRing.RingM.run ringId do CommRing.toContextExpr)
RArray.toExpr Int.mkType id (RArray.leaf (mkIntLit 0))
private def initStateAndVars : GoalM (ProofM.State × Array Expr × Array Expr) := do
let mut vars := #[]
let mut vars' := #[]
let mut paramNames := #[]
let mut paramFVars := #[]
let mut paramTypes := #[]
let mut args := #[]
let mut i := 0
for expr in ( get').vars do
i := i + 1
if expr.isFVar then
vars := vars.push expr
else
let x := mkFVar ( mkFreshFVarId)
vars := vars.push x
paramNames := paramNames.push <| (`x).appendIndexAfter i
paramFVars := paramFVars.push x
paramTypes := paramTypes.push Int.mkType
args := args.push expr
i := 0
for expr in ( get').vars' do
i := i + 1
if expr.isFVar then
vars' := vars'.push expr
else
let x := mkFVar ( mkFreshFVarId)
vars' := vars'.push x
paramNames := paramNames.push <| (`y).appendIndexAfter i
paramFVars := paramFVars.push x
paramTypes := paramTypes.push Int.mkType
args := args.push expr
return ({ paramNames, paramTypes, paramFVars, args }, vars, vars')
private abbrev withProofContext (x : ProofM Expr) : GoalM Expr := do
withLetDecl `ctx (mkApp (mkConst ``RArray [levelZero]) Int.mkType) ( toContextExpr) fun ctx => do
withLetDecl `ctx' (mkApp (mkConst ``RArray [levelZero]) Int.mkType) ( toContextExpr') fun ctx' => do
let (s, vars, vars') initStateAndVars
withLetDecl `ctx (mkApp (mkConst ``RArray [levelZero]) Int.mkType) ( toContextExpr vars) fun ctx => do
withLetDecl `ctx' (mkApp (mkConst ``RArray [levelZero]) Int.mkType) ( toContextExpr vars') fun ctx' => do
withLetDecl `rctx (mkApp (mkConst ``RArray [levelZero]) Int.mkType) ( toRingContextExpr) fun ringCtx => do
go { ctx, ctx', ringCtx } |>.run' {}
go { vars, vars', ctx, ctx', ringCtx } |>.run' s
where
go : ProofM Expr := do
let h x
@@ -127,7 +161,80 @@ where
let h mkLetOfMap ( get).exprMap h `e (mkConst ``Int.Linear.Expr) toExpr
let h mkLetOfMap ( get).ringPolyMap h `rp (mkConst ``Grind.CommRing.Poly) toExpr
let h mkLetOfMap ( get).ringExprMap h `re (mkConst ``Grind.CommRing.Expr) toExpr
mkLetFVars #[( getContext), ( read).ctx', ( read).ringCtx ] h
let h mkLetFVars #[( getContext), ( read).ctx', ( read).ringCtx ] h
let h := mkLambdaN ( get).paramNames ( get).paramFVars ( get).paramTypes h
return mkAppN h ( get).args
/-- Returns an expression representing polynomial `p` using abstract variables in `ProofM` -/
private def denotePoly (p : Poly) : ProofM Expr := do
let vars := ( read).vars
p.denoteExpr (vars[·]!)
/-- Returns a Lean expression representing the linear expression `e` using abstract variables in `ProofM` -/
private def denoteLinExpr (e : Int.Linear.Expr) : ProofM Expr := do
let vars := ( read).vars
e.denoteExpr (vars[·]!)
/-- Returns an expression representing polynomial `p ≤ 0` using abstract variables in `ProofM` -/
private def denotePolyLE (p : Poly) : ProofM Expr :=
return mkIntLE ( denotePoly p) (mkIntLit 0)
private def denoteLinExprLE (lhs rhs : Int.Linear.Expr) : ProofM Expr :=
return mkIntLE ( denoteLinExpr lhs) ( denoteLinExpr rhs)
private def denoteVarPolyEq (x : Var) (p : Poly) : ProofM Expr :=
return mkIntEq ( read).vars[x]! ( denotePoly p)
private def denotePolyEq (lhs rhs : Poly) : ProofM Expr :=
return mkIntEq ( denotePoly lhs) ( denotePoly rhs)
private def denoteLinExprEq (lhs rhs : Int.Linear.Expr) : ProofM Expr :=
return mkIntEq ( denoteLinExpr lhs) ( denoteLinExpr rhs)
private def denoteLinExprNE (lhs rhs : Int.Linear.Expr) : ProofM Expr :=
return mkNot ( denoteLinExprEq lhs rhs)
/--
Given an "external" proof `h`, and its abstract type `abstType` using the abstract variables in `ProofM`,
creates a new "abstracted" hypothesis with type `abstType` and records `h` as the argument to instantiate it.
-/
private def mkHyp (h : Expr) (abstType : Expr) : ProofM Expr := do
let h' := mkFVar ( mkFreshFVarId)
modify fun s => { s with
paramNames := s.paramNames.push <| (`h).appendIndexAfter (s.paramNames.size + 1)
paramFVars := s.paramFVars.push h'
paramTypes := s.paramTypes.push abstType
args := s.args.push h
}
return h'
/--
Given an "external" inequality proof `h : p ≤ 0`, create an abstract hypothesis for it using the abstract
variables in `ProofM`, and records `h` as the argument to instantiate it.
-/
private def mkHypPolyLE (h : Expr) (p : Poly) : ProofM Expr := do
mkHyp h ( denotePolyLE p)
/-- Similar to `mkHypPolyLE` -/
private def mkHypLinExprLE (h : Expr) (lhs rhs : Int.Linear.Expr) : ProofM Expr := do
mkHyp h ( denoteLinExprLE lhs rhs)
/-- Similar to `mkHypPolyLE` -/
private def mkHypPolyEq (h : Expr) (lhs rhs : Poly) : ProofM Expr := do
mkHyp h ( denotePolyEq lhs rhs)
/-- Similar to `mkHypPolyLE` -/
private def mkHypVarPolyEq (h : Expr) (x : Var) (p : Poly) : ProofM Expr := do
mkHyp h ( denoteVarPolyEq x p)
/-- Similar to `mkHypPolyLE` -/
private def mkHypLinExprEq (h : Expr) (lhs rhs : Int.Linear.Expr) : ProofM Expr := do
mkHyp h ( denoteLinExprEq lhs rhs)
/-- Similar to `mkHypPolyLE` -/
private def mkHypLinExprNE (h : Expr) (lhs rhs : Int.Linear.Expr) : ProofM Expr := do
mkHyp h ( denoteLinExprNE lhs rhs)
/--
Returns a Lean expression representing the auxiliary `CommRing` variable context needed for normalizing
@@ -148,14 +255,15 @@ partial def EqCnstr.toExprProof (c' : EqCnstr) : ProofM Expr := caching c' do
| .core0 a zero =>
mkEqProof a zero
| .core a b p₁ p₂ =>
let h mkEqProof a b
let h mkHypPolyEq ( mkEqProof a b) p₁ p₂
return mkApp6 (mkConst ``Int.Linear.eq_of_core) ( getContext) ( mkPolyDecl p₁) ( mkPolyDecl p₂) ( mkPolyDecl c'.p) reflBoolTrue h
| .coreToInt a b toIntThm lhs rhs =>
let h := mkApp toIntThm ( mkEqProof a b)
let h mkHypLinExprEq (mkApp toIntThm ( mkEqProof a b)) lhs rhs
return mkApp6 (mkConst ``Int.Linear.eq_norm_expr) ( getContext) ( mkExprDecl lhs) ( mkExprDecl rhs) ( mkPolyDecl c'.p) reflBoolTrue h
| .defn e p =>
let some x := ( getVarMap).find? { expr := e } | throwError "`grind` internal error, missing cutsat variable{indentExpr e}"
return mkApp6 (mkConst ``Int.Linear.eq_def) ( getContext) (toExpr x) ( mkPolyDecl p) ( mkPolyDecl c'.p) reflBoolTrue ( mkEqRefl e)
let h mkHypVarPolyEq ( mkEqRefl e) x p
return mkApp6 (mkConst ``Int.Linear.eq_def) ( getContext) (toExpr x) ( mkPolyDecl p) ( mkPolyDecl c'.p) reflBoolTrue h
| .defnNat h x e =>
return mkApp6 (mkConst ``Int.Linear.eq_def') ( getContext) (toExpr x) ( mkExprDecl e) ( mkPolyDecl c'.p) reflBoolTrue h
| .norm c =>
@@ -256,16 +364,19 @@ partial def LeCnstr.toExprProof (c' : LeCnstr) : ProofM Expr := caching c' do
trace[grind.debug.cutsat.proof] "{← c'.pp}"
match c'.h with
| .core e =>
mkOfEqTrue ( mkEqTrueProof e)
let h mkHypPolyLE (mkOfEqTrueCore e ( mkEqTrueProof e)) c'.p
return h
| .coreNeg e p =>
let h mkOfEqFalse ( mkEqFalseProof e)
let h mkHypPolyLE (mkOfEqFalseCore e ( mkEqFalseProof e)) p
return mkApp5 (mkConst ``Int.Linear.le_neg) ( getContext) ( mkPolyDecl p) ( mkPolyDecl c'.p) reflBoolTrue h
| .coreToInt e pos toIntThm lhs rhs =>
let h if pos then pure <| mkOfEqTrueCore e ( mkEqTrueProof e) else pure <| mkOfEqFalseCore e ( mkEqFalseProof e)
let h := mkApp toIntThm h
let h mkHypLinExprLE h lhs rhs
return mkApp6 (mkConst ``Int.Linear.le_norm_expr) ( getContext) ( mkExprDecl lhs) ( mkExprDecl rhs) ( mkPolyDecl c'.p) reflBoolTrue h
| .ofNatNonneg a =>
return mkApp (mkConst ``Nat.ToInt.toNat_nonneg) a
let h mkHypPolyLE (mkApp (mkConst ``Nat.ToInt.toNat_nonneg) a) c'.p
return h
| .bound h => return h
| .dec h =>
return mkFVar h
@@ -339,7 +450,7 @@ partial def DiseqCnstr.toExprProof (c' : DiseqCnstr) : ProofM Expr := caching c'
let h mkDiseqProof a b
return mkApp6 (mkConst ``Int.Linear.diseq_of_core) ( getContext) ( mkPolyDecl p₁) ( mkPolyDecl p₂) ( mkPolyDecl c'.p) reflBoolTrue h
| .coreToInt a b toIntThm lhs rhs =>
let h := mkApp toIntThm ( mkDiseqProof a b)
let h mkHypLinExprNE (mkApp toIntThm ( mkDiseqProof a b)) lhs rhs
return mkApp6 (mkConst ``Int.Linear.not_eq_norm_expr) ( getContext) ( mkExprDecl lhs) ( mkExprDecl rhs) ( mkPolyDecl c'.p) reflBoolTrue h
| .norm c =>
return mkApp5 (mkConst ``Int.Linear.diseq_norm) ( getContext) ( mkPolyDecl c.p) ( mkPolyDecl c'.p) reflBoolTrue ( c.toExprProof)

View File

@@ -25,4 +25,14 @@ def mkLetOfMap {_ : Hashable α} {_ : BEq α} (m : Std.HashMap α Expr) (e : Exp
i := i - 1
return e
def mkLambdaN (ns : Array Name) (xs : Array Expr) (xsTypes : Array Expr) (b : Expr) : Expr :=
if _ : xs.size xsTypes.size xs.size ns.size then unreachable! else
let b := b.abstract xs
xs.size.foldRev (init := b) fun i _ b =>
let n := ns[i]
let x := xs[i]
let xType := xsTypes[i]
let xType := xType.abstractRange i xs
mkLambda n .default xType b
end Lean.Meta.Grind.Arith