Compare commits

...

4 Commits

Author SHA1 Message Date
Leonardo de Moura
f34cb4a684 chore: add isDefEqD 2025-08-25 20:56:32 -07:00
Leonardo de Moura
20f67bbf03 refactor: use pure functions 2025-08-25 20:42:18 -07:00
Leonardo de Moura
033c584cee chore: break big mutual declaration using foward declaration 2025-08-25 20:19:07 -07:00
Leonardo de Moura
985d78085b chore: avoid abbrev 2025-08-25 19:35:48 -07:00
14 changed files with 172 additions and 161 deletions

View File

@@ -26,7 +26,7 @@ def mkPowFn (u : Level) (type : Expr) (semiringInst : Expr) : m Expr := do
canonExpr <| mkApp4 (mkConst ``HPow.hPow [u, 0, u]) type Nat.mkType type inst
where
checkInst (inst inst' : Expr) : MetaM Unit := do
unless ( withDefault <| isDefEq inst inst') do
unless ( isDefEqD inst inst') do
throwError "instance for power operator{indentExpr inst}\nis not definitionally equal to the `Grind.Semiring` one{indentExpr inst'}"
def mkNatCastFn (u : Level) (type : Expr) (semiringInst : Expr) : m Expr := do
@@ -44,7 +44,7 @@ def mkNatCastFn (u : Level) (type : Expr) (semiringInst : Expr) : m Expr := do
canonExpr <| mkApp2 (mkConst ``NatCast.natCast [u]) type inst
where
checkInst (inst inst' : Expr) : MetaM Unit := do
unless ( withDefault <| isDefEq inst inst') do
unless ( isDefEqD inst inst') do
throwError "instance for natCast{indentExpr inst}\nis not definitionally equal to the `Grind.Semiring` one{indentExpr inst'}"
variable [MonadLiftT MetaM m] [MonadError m] [Monad m] [MonadRing m]
@@ -112,7 +112,7 @@ def getIntCastFn : m Expr := do
return intCastFn
where
checkInst (inst inst' : Expr) : MetaM Unit := do
unless ( withDefault <| isDefEq inst inst') do
unless ( isDefEqD inst inst') do
throwError "instance for intCast{indentExpr inst}\nis not definitionally equal to the `Grind.Ring` one{indentExpr inst'}"
def getNatCastFn : m Expr := do

View File

@@ -44,7 +44,7 @@ are not used.
Remark: recall that the `.reorder` proof objects are delimiters for indicating whether regular variables and
declarations or the prime ones should be used.
-/
structure ProofM.State where
private structure ProofM.State where
/-- Cache for visited cutsat proof terms. The key is the pointer address. -/
cache : Std.HashMap UInt64 Expr := {}
/-- Map from used variables to (temporary) free variable. -/
@@ -64,7 +64,7 @@ structure ProofM.State where
/-- Map from used ring expressions to free variable. -/
ringExprDecls : Std.HashMap CommRing.RingExpr Expr := {}
structure ProofM.Context where
private structure ProofM.Context where
ctx : Expr
/-- Variables before reordering -/
ctx' : Expr
@@ -76,7 +76,7 @@ structure ProofM.Context where
unordered : Bool := false
/-- Auxiliary monad for constructing cutsat proofs. -/
abbrev ProofM := ReaderT ProofM.Context (StateRefT ProofM.State GoalM)
private abbrev ProofM := ReaderT ProofM.Context (StateRefT ProofM.State GoalM)
/-- Returns a Lean expression representing the variable context used to construct cutsat proofs. -/
private def getContext : ProofM Expr := do
@@ -86,7 +86,7 @@ private def getContext : ProofM Expr := do
Execute `k` with `unordered := true`, and the unordered variable context.
We use this combinator to process `.reorder c` justifications.
-/
private abbrev withUnordered (k : ProofM α) : ProofM α := do
private def withUnordered (k : ProofM α) : ProofM α := do
withReader (fun c => { c with ctx := c.ctx', unordered := true }) k
/--
@@ -103,7 +103,7 @@ private def getVarOf (e : Expr) : ProofM Var := do
let some x := ( getVarMap).find? { expr := e } | throwError "`grind` internal error, missing cutsat variable{indentExpr e}"
return x
private abbrev caching (c : α) (k : ProofM Expr) : ProofM Expr := do
private def caching (c : α) (k : ProofM Expr) : ProofM Expr := do
let addr := unsafe (ptrAddrUnsafe c).toUInt64 >>> 2
if let some h := ( get).cache[addr]? then
return h
@@ -189,7 +189,7 @@ private def mkRingContext (h : Expr) : ProofM Expr := do
else
return h
private abbrev withProofContext (x : ProofM Expr) : GoalM Expr := do
private def withProofContext (x : ProofM Expr) : GoalM Expr := do
let ctx := mkFVar ( mkFreshFVarId)
let ctx' := mkFVar ( mkFreshFVarId)
let ringCtx := mkFVar ( mkFreshFVarId)
@@ -205,7 +205,7 @@ where
Returns a Lean expression representing the auxiliary `CommRing` variable context needed for normalizing
nonlinear polynomials.
-/
private abbrev getRingContext : ProofM Expr := do
private def getRingContext : ProofM Expr := do
return ( read).ringCtx
private def DvdCnstr.get_d_a (c : DvdCnstr) : GoalM (Int × Int) := do
@@ -224,13 +224,109 @@ private def _root_.Int.Linear.Poly.denoteExprUsingCurrVars (p : Poly) : ProofM E
let vars getCurrVars
return ( p.denoteExpr (vars[·]!))
inductive MulEqProof where
private inductive MulEqProof where
| const (k : Int) (h : Expr)
| mulVar (k : Int) (a : Expr) (h : Expr)
| none
@[extern "lean_cutsat_eq_cnstr_to_proof"] -- forward definition
private opaque EqCnstr.toExprProof (c' : EqCnstr) : ProofM Expr
private partial def mkMulEqProof (x : Var) (a? : Option Expr) (cs : Array (Expr × Int × EqCnstr)) (c' : EqCnstr) : ProofM Expr := do
let h go ( getCurrVars)[x]!
match h with
| .const k h =>
return mkApp6 (mkConst ``Int.Linear.of_var_eq) ( getContext) ( mkVarDecl x) (toExpr k) ( mkPolyDecl c'.p) eagerReflBoolTrue h
| .mulVar k a h =>
assert! a? == some a
let y getVarOf a
return mkApp7 (mkConst ``Int.Linear.of_var_eq_mul) ( getContext) ( mkVarDecl x) (toExpr k) ( mkVarDecl y) ( mkPolyDecl c'.p) eagerReflBoolTrue h
| .none =>
throwError "`grind` internal error, cutsat failed to construct proof for multiplication equality"
where
goVar (e : Expr) : ProofM MulEqProof := do
if some e == a? then
return .mulVar 1 e (mkApp (mkConst ``Int.Linear.eq_one_mul) e)
else
let some (_, k, c) := cs.find? fun (e', _, _) => e' == e | return .none
let x getVarOf e
let h := mkApp6 (mkConst ``Int.Linear.var_eq) ( getContext) ( mkVarDecl x) (toExpr k) ( mkPolyDecl c.p) eagerReflBoolTrue ( c.toExprProof)
return .const k h
go (e : Expr) : ProofM MulEqProof := do
let_expr HMul.hMul _ _ _ i a b := e | goVar e
if !( isInstHMulInt i) then goVar e else
let ha go a
if let .const 0 h := ha then
return .const 0 (mkApp3 (mkConst ``Int.Linear.mul_eq_zero_left) a b h)
let hb go b
if let .const 0 h := hb then
return .const 0 (mkApp3 (mkConst ``Int.Linear.mul_eq_zero_right) a b h)
match ha, hb with
| .const k₁ h₁, .const k₂ h₂ =>
let k := k₁*k₂
let h := mkApp8 (mkConst ``Int.Linear.mul_eq_kk) a b (toExpr k₁) (toExpr k₂) (toExpr k) h₁ h₂ eagerReflBoolTrue
return .const k h
| .const k₁ h₁, .mulVar k₂ c h₂ =>
let k := k₁*k₂
let h := mkApp9 (mkConst ``Int.Linear.mul_eq_kkx) a b (toExpr k₁) (toExpr k₂) c (toExpr k) h₁ h₂ eagerReflBoolTrue
return .mulVar k c h
| .mulVar k₁ c h₁, .const k₂ h₂ =>
let k := k₁*k₂
let h := mkApp9 (mkConst ``Int.Linear.mul_eq_kxk) a b (toExpr k₁) c (toExpr k₂) (toExpr k) h₁ h₂ eagerReflBoolTrue
return .mulVar k c h
| _, _ => return .none
private def mkDivEqProof (k : Int) (y? : Option Var) (c : EqCnstr) (c' : EqCnstr) : ProofM Expr := do
let .add _ x _ := c'.p | c'.throwUnexpected
let_expr HDiv.hDiv _ _ _ _ a b := ( getCurrVars)[x]! | c'.throwUnexpected
let bVar getVarOf b
let h := mkApp6 (mkConst ``Int.Linear.var_eq) ( getContext) ( mkVarDecl bVar) (toExpr k) ( mkPolyDecl c.p) eagerReflBoolTrue ( c.toExprProof)
if let some y := y? then
let h := mkApp4 (mkConst ``Int.Linear.div_eq) a b (toExpr k) h
return mkApp6 (mkConst ``Int.Linear.of_var_eq_var) ( getContext) ( mkVarDecl x) ( mkVarDecl y) ( mkPolyDecl c'.p) eagerReflBoolTrue h
else
let b' := k
let some aVal getIntValue? a | unreachable!
let k := aVal / b'
let h := mkApp6 (mkConst ``Int.Linear.div_eq') a b (toExpr b') (toExpr k) h eagerReflBoolTrue
return mkApp6 (mkConst ``Int.Linear.of_var_eq) ( getContext) ( mkVarDecl x) (toExpr k) ( mkPolyDecl c'.p) eagerReflBoolTrue h
private def mkModEqProof (k : Int) (y? : Option Var) (c : EqCnstr) (c' : EqCnstr) : ProofM Expr := do
let .add _ x _ := c'.p | c'.throwUnexpected
let_expr HMod.hMod _ _ _ _ a b := ( getCurrVars)[x]! | c'.throwUnexpected
let bVar getVarOf b
let h := mkApp6 (mkConst ``Int.Linear.var_eq) ( getContext) ( mkVarDecl bVar) (toExpr k) ( mkPolyDecl c.p) eagerReflBoolTrue ( c.toExprProof)
if let some y := y? then
let h := mkApp4 (mkConst ``Int.Linear.mod_eq) a b (toExpr k) h
return mkApp6 (mkConst ``Int.Linear.of_var_eq_var) ( getContext) ( mkVarDecl x) ( mkVarDecl y) ( mkPolyDecl c'.p) eagerReflBoolTrue h
else
let b' := k
let some aVal getIntValue? a | unreachable!
let k := aVal % b'
let h := mkApp6 (mkConst ``Int.Linear.mod_eq') a b (toExpr b') (toExpr k) h eagerReflBoolTrue
return mkApp6 (mkConst ``Int.Linear.of_var_eq) ( getContext) ( mkVarDecl x) (toExpr k) ( mkPolyDecl c'.p) eagerReflBoolTrue h
private def mkPowEqProof (ka : Int) (ca? : Option EqCnstr) (kb : Nat) (cb? : Option EqCnstr) (c' : EqCnstr) : ProofM Expr := do
let .add _ x _ := c'.p | c'.throwUnexpected
let_expr HPow.hPow _ _ _ _ a b := ( getCurrVars)[x]! | c'.throwUnexpected
let h₁ if let some ca := ca? then
pure <| mkApp6 (mkConst ``Int.Linear.var_eq) ( getContext) ( mkVarDecl ( getVarOf a)) (toExpr ka) ( mkPolyDecl ca.p) eagerReflBoolTrue ( ca.toExprProof)
else
pure <| mkApp2 (mkConst ``Eq.refl [1]) Int.mkType (mkIntLit ka)
let kbInt := Int.ofNat kb
let h₂ if let some cb := cb? then
let (b', _) mkNatVar b
pure <| mkApp6 (mkConst ``Int.Linear.var_eq) ( getContext) ( mkVarDecl ( getVarOf b')) (toExpr kbInt) ( mkPolyDecl cb.p) eagerReflBoolTrue ( cb.toExprProof)
else
pure <| mkApp2 (mkConst ``Eq.refl [1]) Int.mkType (mkIntLit kb)
let k := ka^kb
let h := mkApp8 (mkConst ``Int.Linear.pow_eq) a b (toExpr ka) (toExpr kbInt) (toExpr k) h₁ h₂ eagerReflBoolTrue
return mkApp6 (mkConst ``Int.Linear.of_var_eq) ( getContext) ( mkVarDecl x) (toExpr k) ( mkPolyDecl c'.p) eagerReflBoolTrue h
mutual
partial def EqCnstr.toExprProof (c' : EqCnstr) : ProofM Expr := caching c' do
@[export lean_cutsat_eq_cnstr_to_proof]
private partial def EqCnstr.toExprProofImpl (c' : EqCnstr) : ProofM Expr := caching c' do
trace[grind.debug.cutsat.proof] "{← c'.pp}"
match c'.h with
| .core0 a zero =>
@@ -278,97 +374,11 @@ partial def EqCnstr.toExprProof (c' : EqCnstr) : ProofM Expr := caching c' do
| .mul a? cs =>
let .add _ x _ := c'.p | c'.throwUnexpected
mkMulEqProof x a? cs c'
| .div k y? c =>
let .add _ x _ := c'.p | c'.throwUnexpected
let_expr HDiv.hDiv _ _ _ _ a b := ( getCurrVars)[x]! | c'.throwUnexpected
let bVar getVarOf b
let h := mkApp6 (mkConst ``Int.Linear.var_eq) ( getContext) ( mkVarDecl bVar) (toExpr k) ( mkPolyDecl c.p) eagerReflBoolTrue ( c.toExprProof)
if let some y := y? then
let h := mkApp4 (mkConst ``Int.Linear.div_eq) a b (toExpr k) h
return mkApp6 (mkConst ``Int.Linear.of_var_eq_var) ( getContext) ( mkVarDecl x) ( mkVarDecl y) ( mkPolyDecl c'.p) eagerReflBoolTrue h
else
let b' := k
let some aVal getIntValue? a | unreachable!
let k := aVal / b'
let h := mkApp6 (mkConst ``Int.Linear.div_eq') a b (toExpr b') (toExpr k) h eagerReflBoolTrue
return mkApp6 (mkConst ``Int.Linear.of_var_eq) ( getContext) ( mkVarDecl x) (toExpr k) ( mkPolyDecl c'.p) eagerReflBoolTrue h
| .mod k y? c =>
let .add _ x _ := c'.p | c'.throwUnexpected
let_expr HMod.hMod _ _ _ _ a b := ( getCurrVars)[x]! | c'.throwUnexpected
let bVar getVarOf b
let h := mkApp6 (mkConst ``Int.Linear.var_eq) ( getContext) ( mkVarDecl bVar) (toExpr k) ( mkPolyDecl c.p) eagerReflBoolTrue ( c.toExprProof)
if let some y := y? then
let h := mkApp4 (mkConst ``Int.Linear.mod_eq) a b (toExpr k) h
return mkApp6 (mkConst ``Int.Linear.of_var_eq_var) ( getContext) ( mkVarDecl x) ( mkVarDecl y) ( mkPolyDecl c'.p) eagerReflBoolTrue h
else
let b' := k
let some aVal getIntValue? a | unreachable!
let k := aVal % b'
let h := mkApp6 (mkConst ``Int.Linear.mod_eq') a b (toExpr b') (toExpr k) h eagerReflBoolTrue
return mkApp6 (mkConst ``Int.Linear.of_var_eq) ( getContext) ( mkVarDecl x) (toExpr k) ( mkPolyDecl c'.p) eagerReflBoolTrue h
| .pow ka ca? kb cb? =>
let .add _ x _ := c'.p | c'.throwUnexpected
let_expr HPow.hPow _ _ _ _ a b := ( getCurrVars)[x]! | c'.throwUnexpected
let h₁ if let some ca := ca? then
pure <| mkApp6 (mkConst ``Int.Linear.var_eq) ( getContext) ( mkVarDecl ( getVarOf a)) (toExpr ka) ( mkPolyDecl ca.p) eagerReflBoolTrue ( ca.toExprProof)
else
pure <| mkApp2 (mkConst ``Eq.refl [1]) Int.mkType (mkIntLit ka)
let kbInt := Int.ofNat kb
let h₂ if let some cb := cb? then
let (b', _) mkNatVar b
pure <| mkApp6 (mkConst ``Int.Linear.var_eq) ( getContext) ( mkVarDecl ( getVarOf b')) (toExpr kbInt) ( mkPolyDecl cb.p) eagerReflBoolTrue ( cb.toExprProof)
else
pure <| mkApp2 (mkConst ``Eq.refl [1]) Int.mkType (mkIntLit kb)
let k := ka^kb
let h := mkApp8 (mkConst ``Int.Linear.pow_eq) a b (toExpr ka) (toExpr kbInt) (toExpr k) h₁ h₂ eagerReflBoolTrue
return mkApp6 (mkConst ``Int.Linear.of_var_eq) ( getContext) ( mkVarDecl x) (toExpr k) ( mkPolyDecl c'.p) eagerReflBoolTrue h
| .div k y? c => mkDivEqProof k y? c c'
| .mod k y? c => mkModEqProof k y? c c'
| .pow ka ca? kb cb? => mkPowEqProof ka ca? kb cb? c'
partial def mkMulEqProof (x : Var) (a? : Option Expr) (cs : Array (Expr × Int × EqCnstr)) (c' : EqCnstr) : ProofM Expr := do
let h go ( getCurrVars)[x]!
match h with
| .const k h =>
return mkApp6 (mkConst ``Int.Linear.of_var_eq) ( getContext) ( mkVarDecl x) (toExpr k) ( mkPolyDecl c'.p) eagerReflBoolTrue h
| .mulVar k a h =>
assert! a? == some a
let y getVarOf a
return mkApp7 (mkConst ``Int.Linear.of_var_eq_mul) ( getContext) ( mkVarDecl x) (toExpr k) ( mkVarDecl y) ( mkPolyDecl c'.p) eagerReflBoolTrue h
| .none =>
throwError "`grind` internal error, cutsat failed to construct proof for multiplication equality"
where
goVar (e : Expr) : ProofM MulEqProof := do
if some e == a? then
return .mulVar 1 e (mkApp (mkConst ``Int.Linear.eq_one_mul) e)
else
let some (_, k, c) := cs.find? fun (e', _, _) => e' == e | return .none
let x getVarOf e
let h := mkApp6 (mkConst ``Int.Linear.var_eq) ( getContext) ( mkVarDecl x) (toExpr k) ( mkPolyDecl c.p) eagerReflBoolTrue ( c.toExprProof)
return .const k h
go (e : Expr) : ProofM MulEqProof := do
let_expr HMul.hMul _ _ _ i a b := e | goVar e
if !( isInstHMulInt i) then goVar e else
let ha go a
if let .const 0 h := ha then
return .const 0 (mkApp3 (mkConst ``Int.Linear.mul_eq_zero_left) a b h)
let hb go b
if let .const 0 h := hb then
return .const 0 (mkApp3 (mkConst ``Int.Linear.mul_eq_zero_right) a b h)
match ha, hb with
| .const k₁ h₁, .const k₂ h₂ =>
let k := k₁*k₂
let h := mkApp8 (mkConst ``Int.Linear.mul_eq_kk) a b (toExpr k₁) (toExpr k₂) (toExpr k) h₁ h₂ eagerReflBoolTrue
return .const k h
| .const k₁ h₁, .mulVar k₂ c h₂ =>
let k := k₁*k₂
let h := mkApp9 (mkConst ``Int.Linear.mul_eq_kkx) a b (toExpr k₁) (toExpr k₂) c (toExpr k) h₁ h₂ eagerReflBoolTrue
return .mulVar k c h
| .mulVar k₁ c h₁, .const k₂ h₂ =>
let k := k₁*k₂
let h := mkApp9 (mkConst ``Int.Linear.mul_eq_kxk) a b (toExpr k₁) c (toExpr k₂) (toExpr k) h₁ h₂ eagerReflBoolTrue
return .mulVar k c h
| _, _ => return .none
partial def DvdCnstr.toExprProof (c' : DvdCnstr) : ProofM Expr := caching c' do
private partial def DvdCnstr.toExprProof (c' : DvdCnstr) : ProofM Expr := caching c' do
trace[grind.debug.cutsat.proof] "{← c'.pp}"
match c'.h with
| .core e =>
@@ -432,7 +442,7 @@ partial def DvdCnstr.toExprProof (c' : DvdCnstr) : ProofM Expr := caching c' do
let h := mkApp4 (mkConst ``Grind.CommRing.norm_int) ( getRingContext) ( mkRingExprDecl e) ( mkRingPolyDecl p) eagerReflBoolTrue
return mkApp6 (mkConst ``Int.Linear.dvd_norm_poly) ( getContext) (toExpr c.d) ( mkPolyDecl c.p) ( mkPolyDecl c'.p) h ( c.toExprProof)
partial def LeCnstr.toExprProof (c' : LeCnstr) : ProofM Expr := caching c' do
private partial def LeCnstr.toExprProof (c' : LeCnstr) : ProofM Expr := caching c' do
trace[grind.debug.cutsat.proof] "{← c'.pp}"
match c'.h with
| .core e =>
@@ -511,7 +521,7 @@ partial def LeCnstr.toExprProof (c' : LeCnstr) : ProofM Expr := caching c' do
let h := mkApp4 (mkConst ``Grind.CommRing.norm_int) ( getRingContext) ( mkRingExprDecl e) ( mkRingPolyDecl p) eagerReflBoolTrue
return mkApp5 (mkConst ``Int.Linear.le_norm_poly) ( getContext) ( mkPolyDecl c.p) ( mkPolyDecl c'.p) h ( c.toExprProof)
partial def DiseqCnstr.toExprProof (c' : DiseqCnstr) : ProofM Expr := caching c' do
private partial def DiseqCnstr.toExprProof (c' : DiseqCnstr) : ProofM Expr := caching c' do
match c'.h with
| .core0 a zero =>
mkDiseqProof a zero
@@ -537,7 +547,7 @@ partial def DiseqCnstr.toExprProof (c' : DiseqCnstr) : ProofM Expr := caching c'
let h := mkApp4 (mkConst ``Grind.CommRing.norm_int) ( getRingContext) ( mkRingExprDecl e) ( mkRingPolyDecl p) eagerReflBoolTrue
return mkApp5 (mkConst ``Int.Linear.diseq_norm_poly) ( getContext) ( mkPolyDecl c.p) ( mkPolyDecl c'.p) h ( c.toExprProof)
partial def CooperSplit.toExprProof (s : CooperSplit) : ProofM Expr := caching s do
private partial def CooperSplit.toExprProof (s : CooperSplit) : ProofM Expr := caching s do
match s.h with
| .dec h => return mkFVar h
| .last hs _ =>
@@ -576,7 +586,7 @@ partial def CooperSplit.toExprProof (s : CooperSplit) : ProofM Expr := caching s
-- `result` is now a proof of `p 0`
return result
partial def UnsatProof.toExprProofCore (h : UnsatProof) : ProofM Expr := do
private partial def UnsatProof.toExprProofCore (h : UnsatProof) : ProofM Expr := do
match h with
| .le c =>
return mkApp4 (mkConst ``Int.Linear.le_unsat) ( getContext) ( mkPolyDecl c.p) eagerReflBoolTrue ( c.toExprProof)

View File

@@ -77,7 +77,7 @@ where
else
return some ( asVar e)
isOfNatZero (e : Expr) : LinearM Bool := do
withDefault <| isDefEq e ( getStruct).ofNatZero
isDefEqD e ( getStruct).ofNatZero
processHMul (i a b : Expr) : LinearM (Option LinExpr) := do
if isHMulIntInst ( getStruct) i then
let some k getIntValue? a | return none

View File

@@ -41,7 +41,7 @@ private def mkExpectedDefEqMsg (a b : Expr) : MetaM MessageData :=
return m!"`grind linarith` expected{indentExpr a}\nto be definitionally equal to{indentExpr b}"
private def ensureDefEq (a b : Expr) : MetaM Unit := do
unless ( withDefault <| isDefEq a b) do
unless ( isDefEqD a b) do
throwError ( mkExpectedDefEqMsg a b)
private def addZeroLtOne (one : Var) : LinearM Unit := do
@@ -94,7 +94,7 @@ private def mkOne? (u : Level) (type : Expr) : GoalM (Option Expr) := do
let some oneInst synthInstance? (mkApp (mkConst ``One [u]) type) | return none
let one internalizeConst <| mkApp2 (mkConst ``One.one [u]) type oneInst
let one' mkNumeral type 1
unless ( withDefault <| isDefEq one one') do reportIssue! ( mkExpectedDefEqMsg one one')
unless ( isDefEqD one one') do reportIssue! ( mkExpectedDefEqMsg one one')
return some one
private def mkPreorderInst? (u : Level) (type : Expr) (leInst? ltInst? : Option Expr) : GoalM (Option Expr) := do
@@ -168,7 +168,7 @@ where
let some parentInst := parentInst? | return none
let some childInst := childInst? | return none
let toField := mkApp4 (mkConst toFieldName [u]) type leInst ltInst childInst
unless ( withDefault <| isDefEq parentInst toField) do
unless ( isDefEqD parentInst toField) do
reportIssue! ( mkExpectedDefEqMsg parentInst toField)
return none
return some childInst

View File

@@ -246,7 +246,7 @@ def checkInvariants : GoalM Unit := do
let p mkProofForPath u v
trace[grind.debug.offset.proof] "{p} : {← inferType p}"
check p
unless ( withDefault <| isDefEq ( inferType p) ( Cnstr.toExpr c)) do
unless ( isDefEqD ( inferType p) ( Cnstr.toExpr c)) do
throwError "`grind` internal error in the offset constraint module, constraint{indentExpr (← Cnstr.toExpr c)}\nis not definitionally equal to type of its proof{indentExpr (← inferType p)}"
/--

View File

@@ -62,7 +62,7 @@ private def isDefEqBounded (a b : Expr) (parent : Expr) : GoalM Bool := do
let curr := ( getConfig).canonHeartbeats
tryCatchRuntimeEx
(withTheReader Core.Context (fun ctx => { ctx with maxHeartbeats := curr*1000 }) do
withDefault <| isDefEq a b)
isDefEqD a b)
fun ex => do
if ex.isRuntime then
reportIssue! "failed to show that{indentExpr a}\nis definitionally equal to{indentExpr b}\nwhile canonicalizing{indentExpr parent}\nusing `{curr}*1000` heartbeats, `(canonHeartbeats := {curr})`"
@@ -91,7 +91,7 @@ private def canonElemCore (parent : Expr) (f : Expr) (i : Nat) (e : Expr) (useIs
```
where `grind` unfolds the definition of `DHashMap.insert` and `TreeMap.insert`.
-/
if ( withDefault <| isDefEq eType cType) then
if ( isDefEqD eType cType) then
if ( isDefEq e c) then
-- We used to check `c.fvarsSubset e` because it is not
-- in general safe to replace `e` with `c` if `c` has more free variables than `e`.

View File

@@ -35,7 +35,7 @@ and close goal if they are different.
def propagateCtor (a b : Expr) : GoalM Unit := do
let aType whnfD ( inferType a)
let bType whnfD ( inferType b)
unless ( withDefault <| isDefEq aType bType) do
unless ( isDefEqD aType bType) do
return ()
let ctor₁ := a.getAppFn
let ctor₂ := b.getAppFn

View File

@@ -4,16 +4,13 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Tactic.Grind.Types
public import Lean.Meta.Tactic.Grind.Intro
public import Lean.Meta.Tactic.Grind.MatchDiscrOnly
public import Lean.Meta.Tactic.Grind.MatchCond
public import Lean.Meta.Tactic.Grind.Core
import Lean.Meta.Tactic.Grind.Intro
import Lean.Meta.Tactic.Grind.MatchDiscrOnly
import Lean.Meta.Tactic.Grind.MatchCond
import Lean.Meta.Tactic.Grind.Core
public section
namespace Lean.Meta.Grind
namespace EMatch
/-! This module implements a simple E-matching procedure as a backtracking search. -/
@@ -355,7 +352,7 @@ private def addNewInstance (thm : EMatchTheorem) (proof : Expr) (generation : Na
-- We must add a hint because `annotateEqnTypeConds` introduces `Grind.PreMatchCond`
-- which is not reducible.
proof := mkExpectedPropHint proof prop
trace_goal[grind.ematch.instance] "{thm.origin.pp}: {prop}"
trace_goal[grind.ematch.instance] "{thm.origin.pp}: {prop}"
addTheoremInstance thm proof prop (generation+1)
private def synthesizeInsts (mvars : Array Expr) (bis : Array BinderInfo) : OptionT M Unit := do
@@ -364,7 +361,7 @@ private def synthesizeInsts (mvars : Array Expr) (bis : Array BinderInfo) : Opti
if bi.isInstImplicit && !( mvar.mvarId!.isAssigned) then
let type inferType mvar
unless ( synthInstanceAndAssign mvar type) do
reportIssue! "failed to synthesize instance when instantiating {thm.origin.pp}{indentExpr type}"
reportIssue! "failed to synthesize instance when instantiating {thm.origin.pp}{indentExpr type}"
failure
private def preprocessGeneralizedPatternRHS (lhs : Expr) (rhs : Expr) (origin : Origin) (expectedType : Expr) : OptionT (StateT Choice M) Expr := do
@@ -376,12 +373,12 @@ private def preprocessGeneralizedPatternRHS (lhs : Expr) (rhs : Expr) (origin :
if ( isEqv lhs rhs) then
return rhs
else
reportIssue! "invalid generalized pattern at `{origin.pp}`\nwhen processing argument with type{indentExpr expectedType}\nfailed to prove{indentExpr lhs}\nis equal to{indentExpr rhs}"
reportIssue! "invalid generalized pattern at `{origin.pp}`\nwhen processing argument with type{indentExpr expectedType}\nfailed to prove{indentExpr lhs}\nis equal to{indentExpr rhs}"
failure
private def assignGeneralizedPatternProof (mvarId : MVarId) (eqProof : Expr) (origin : Origin) : OptionT (StateT Choice M) Unit := do
unless ( mvarId.checkedAssign eqProof) do
reportIssue! "invalid generalized pattern at `{origin.pp}`\nfailed to assign {mkMVar mvarId}\nwith{indentExpr eqProof}"
reportIssue! "invalid generalized pattern at `{origin.pp}`\nfailed to assign {mkMVar mvarId}\nwith{indentExpr eqProof}"
failure
/-- Helper function for `applyAssignment. -/
@@ -397,7 +394,7 @@ private def processDelayed (mvars : Array Expr) (i : Nat) (h : i < mvars.size) :
let rhs preprocessGeneralizedPatternRHS lhs rhs thm.origin mvarIdType
assignGeneralizedPatternProof mvarId ( mkHEqProof lhs rhs) thm.origin
| _ =>
reportIssue! "invalid generalized pattern at `{thm.origin.pp}`\nequality type expected{indentExpr mvarIdType}"
reportIssue! "invalid generalized pattern at `{thm.origin.pp}`\nequality type expected{indentExpr mvarIdType}"
failure
/-- Helper function for `applyAssignment. -/
@@ -430,9 +427,9 @@ private def processUnassigned (mvars : Array Expr) (i : Nat) (v : Expr) (h : i <
if ( isProp vType) then
modify (unassign · bidx)
else
reportIssue! "type error constructing proof for {thm.origin.pp}\nwhen assigning metavariable {mvars[i]} with {indentExpr v}\n{← mkHasTypeButIsExpectedMsg vType mvarIdType}"
reportIssue! "type error constructing proof for {thm.origin.pp}\nwhen assigning metavariable {mvars[i]} with {indentExpr v}\n{← mkHasTypeButIsExpectedMsg vType mvarIdType}"
failure
if ( withDefault <| isDefEq mvarIdType vType) then
if ( isDefEqD mvarIdType vType) then
unless ( mvarId.checkedAssign v) do unassignOrFail
else
if let some heq withoutReportingMVarIssues <| proveEq? vType mvarIdType (abstract := true) then
@@ -467,13 +464,13 @@ private partial def instantiateTheorem (c : Choice) : M Unit := withDefault do w
let thm := ( read).thm
unless ( markTheoremInstance thm.proof c.assignment) do
return ()
trace_goal[grind.ematch.instance.assignment] "{thm.origin.pp}: {assignmentToMessageData c.assignment}"
trace_goal[grind.ematch.instance.assignment] "{thm.origin.pp}: {assignmentToMessageData c.assignment}"
let proof thm.getProofWithFreshMVarLevels
let numParams := thm.numParams
assert! c.assignment.size == numParams
let (mvars, bis, _) forallMetaBoundedTelescope ( inferType proof) numParams
if mvars.size != thm.numParams then
reportIssue! "unexpected number of parameters at {thm.origin.pp}"
reportIssue! "unexpected number of parameters at {thm.origin.pp}"
return ()
let (some _, c) applyAssignment mvars |>.run c | return ()
let some _ synthesizeInsts mvars bis | return ()
@@ -483,7 +480,7 @@ private partial def instantiateTheorem (c : Choice) : M Unit := withDefault do w
else
let mvars mvars.filterM fun mvar => return !( mvar.mvarId!.isAssigned)
if let some mvarBad mvars.findM? fun mvar => return !( isProof mvar) then
reportIssue! "failed to instantiate {thm.origin.pp}, failed to instantiate non propositional argument with type{indentExpr (← inferType mvarBad)}"
reportIssue! "failed to instantiate {thm.origin.pp}, failed to instantiate non propositional argument with type{indentExpr (← inferType mvarBad)}"
let proof mkLambdaFVars (binderInfoForMVars := .default) mvars ( instantiateMVars proof)
addNewInstance thm proof c.gen

View File

@@ -310,12 +310,12 @@ def Origin.key : Origin → Name
| .stx id _ => id
| .local id => id
def Origin.pp [Monad m] [MonadEnv m] [MonadError m] (o : Origin) : m MessageData := do
def Origin.pp (o : Origin) : MessageData :=
match o with
| .decl declName => return MessageData.ofConstName declName
| .fvar fvarId => return mkFVar fvarId
| .stx _ ref => return ref
| .local id => return id
| .decl declName => MessageData.ofConstName declName
| .fvar fvarId => mkFVar fvarId
| .stx _ ref => ref
| .local id => id
instance : BEq Origin where
beq a b := a.key == b.key
@@ -872,7 +872,7 @@ private def ppParamsAt (proof : Expr) (numParams : Nat) (paramPos : List Nat) :
private def logPatternWhen (showInfo : Bool) (origin : Origin) (patterns : List Expr) : MetaM Unit := do
if showInfo then
logInfo m!"{origin.pp}: {patterns.map ppPattern}"
logInfo m!"{origin.pp}: {patterns.map ppPattern}"
/--
Creates an E-matching theorem for a theorem with proof `proof`, `numParams` parameters, and the given set of patterns.
@@ -883,11 +883,11 @@ def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams
-- the patterns have already been selected, there is no point in using priorities here
let (patterns, symbols, bvarFound) NormalizePattern.main patterns ( getGlobalSymbolPriorities) (minPrio := 1)
if symbols.isEmpty then
throwError "invalid pattern for `{origin.pp}`{indentD (patterns.map ppPattern)}\nthe pattern does not contain constant symbols for indexing"
trace[grind.ematch.pattern] "{origin.pp}: {patterns.map ppPattern}"
throwError "invalid pattern for `{origin.pp}`{indentD (patterns.map ppPattern)}\nthe pattern does not contain constant symbols for indexing"
trace[grind.ematch.pattern] "{origin.pp}: {patterns.map ppPattern}"
if let .missing pos checkCoverage proof numParams bvarFound then
let pats : MessageData := m!"{patterns.map ppPattern}"
throwError "invalid pattern(s) for `{origin.pp}`{indentD pats}\nthe following theorem parameters cannot be instantiated:{indentD (← ppParamsAt proof numParams pos)}"
throwError "invalid pattern(s) for `{origin.pp}`{indentD pats}\nthe following theorem parameters cannot be instantiated:{indentD (← ppParamsAt proof numParams pos)}"
logPatternWhen showInfo origin patterns
return {
proof, patterns, numParams, symbols
@@ -924,7 +924,7 @@ def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof :
| HEq _ lhs _ rhs => pure (lhs, rhs)
| _ => throwError "invalid E-matching equality theorem, conclusion must be an equality{indentExpr type}"
let pat := if useLhs then lhs else rhs
trace[grind.debug.ematch.pattern] "mkEMatchEqTheoremCore: origin: {origin.pp}, pat: {pat}, useLhs: {useLhs}"
trace[grind.debug.ematch.pattern] "mkEMatchEqTheoremCore: origin: {origin.pp}, pat: {pat}, useLhs: {useLhs}"
let pat preprocessPattern pat normalizePattern
trace[grind.debug.ematch.pattern] "mkEMatchEqTheoremCore: after preprocessing: {pat}, {← normalize pat normConfig}"
let pats := splitWhileForbidden (pat.abstract xs)
@@ -1257,7 +1257,7 @@ def mkEMatchTheoremWithKind?
| .fwd =>
let ps getPropTypes xs
if ps.isEmpty then
throwError "invalid `grind` forward theorem, theorem `{origin.pp}` does not have propositional hypotheses"
throwError "invalid `grind` forward theorem, theorem `{origin.pp}` does not have propositional hypotheses"
pure ps
| .bwd _ => pure #[type]
| .leftRight => pure <| ( getPropTypes xs).push type
@@ -1279,7 +1279,7 @@ where
go (xs : Array Expr) (searchPlaces : Array Expr) : MetaM (Option EMatchTheorem) := do
let some (patterns, symbols) collect xs searchPlaces | return none
let numParams := xs.size
trace[grind.ematch.pattern] "{origin.pp}: {patterns.map ppPattern}"
trace[grind.ematch.pattern] "{origin.pp}: {patterns.map ppPattern}"
logPatternWhen showInfo origin patterns
return some {
proof, patterns, numParams, symbols

View File

@@ -219,7 +219,7 @@ def activateTheorem (thm : EMatchTheorem) (generation : Nat) : GoalM Unit := do
-- We don't want to use structural equality when comparing keys.
let proof shareCommon thm.proof
let thm := { thm with proof, patterns := ( thm.patterns.mapM (internalizePattern · generation thm.origin)) }
trace_goal[grind.ematch] "activated `{thm.origin.pp}`, {thm.patterns.map ppPattern}"
trace_goal[grind.ematch] "activated `{thm.origin.pp}`, {thm.patterns.map ppPattern}"
modify fun s => { s with ematch.newThms := s.ematch.newThms.push thm }
/--
@@ -331,7 +331,7 @@ where
trace_goal[grind.debug.ext] "{f}, {i}, {arg}"
let others := ( get).split.argsAt.find? (f, i) |>.getD []
for other in others do
if ( withDefault <| isDefEq type other.type) then
if ( isDefEqD type other.type) then
let eq := mkApp3 (mkConst ``Eq [ getLevel type]) type arg other.arg
let eq shareCommon eq
internalize eq generation

View File

@@ -143,7 +143,7 @@ private def splitDiagInfoToMessageData (ss : Array SplitDiagInfo) : MetaM Messag
let data ss.mapM fun { c, lctx, numCases, gen, splitSource } => do
let header := m!"{c}"
return MessageData.withContext { env, mctx, lctx, opts } <| .trace { cls } header #[
.trace { cls } m!"source: {splitSource.toMessageData}" #[],
.trace { cls } m!"source: {splitSource.toMessageData}" #[],
.trace { cls } m!"generation: {gen}" #[],
.trace { cls } m!"# cases: {numCases}" #[]
]

View File

@@ -363,7 +363,7 @@ partial def tryToProveFalse (e : Expr) : GoalM Unit := do
| return none
let lhs' go lhs
trace[grind.debug.matchCond.proveFalse] "{lhs'} =?= {rhs}"
unless ( withDefault <| isDefEq lhs' rhs) do
unless ( isDefEqD lhs' rhs) do
return none
let isHEq := α?.isSome
let some lhsEqLhs' if isHEq then proveHEq? lhs lhs' else proveEq? lhs lhs'

View File

@@ -112,7 +112,7 @@ private def ppEqcs : M Unit := do
pushMsg <| .trace { cls := `eqc } "Equivalence classes" otherEqcs
private def ppEMatchTheorem (thm : EMatchTheorem) : MetaM MessageData := do
let m := m!"{thm.origin.pp}: {thm.patterns.map ppPattern}"
let m := m!"{thm.origin.pp}: {thm.patterns.map ppPattern}"
return .trace { cls := `thm } m #[]
private def ppActiveTheoremPatterns : M Unit := do
@@ -185,7 +185,7 @@ private def ppCasesTrace : M Unit := do
let mut msgs := #[]
for { expr, i , num, source } in goal.split.trace.reverse do
msgs := msgs.push <| .trace { cls := `cases } m!"[{i+1}/{num}]: {expr}" #[
.trace { cls := `cases } m!"source: {source.toMessageData}" #[]
.trace { cls := `cases } m!"source: {source.toMessageData}" #[]
]
pushMsg <| .trace { cls := `cases } "Case analyses" msgs

View File

@@ -35,6 +35,10 @@ namespace Lean.Meta.Grind
/-- We use this auxiliary constant to mark delayed congruence proofs. -/
def congrPlaceholderProof := mkConst (Name.mkSimple "[congruence]")
/-- Similar to `isDefEq`, but ensures default transparency is used. -/
def isDefEqD (t s : Expr) : MetaM Bool :=
withDefault <| isDefEq t s
/--
Returns `true` if `e` is `True`, `False`, or a literal value.
See `Lean.Meta.LitValues` for supported literals.
@@ -82,14 +86,14 @@ inductive SplitSource where
input
deriving Inhabited
def SplitSource.toMessageData : SplitSource MetaM MessageData
| .ematch origin => return m!"E-matching {origin.pp}"
| .ext declName => return m!"Extensionality {declName}"
| .mbtc a b i => return m!"Model-based theory combination at argument #{i} of{indentExpr a}\nand{indentExpr b}"
| .beta e => return m!"Beta-reduction of{indentExpr e}"
| .forallProp e => return m!"Forall propagation at{indentExpr e}"
| .existsProp e => return m!"Exists propagation at{indentExpr e}"
| .input => return m!"Initial goal"
def SplitSource.toMessageData : SplitSource MessageData
| .ematch origin => m!"E-matching {origin.pp}"
| .ext declName => m!"Extensionality {declName}"
| .mbtc a b i => m!"Model-based theory combination at argument #{i} of{indentExpr a}\nand{indentExpr b}"
| .beta e => m!"Beta-reduction of{indentExpr e}"
| .forallProp e => m!"Forall propagation at{indentExpr e}"
| .existsProp e => m!"Exists propagation at{indentExpr e}"
| .input => "Initial goal"
/-- Context for `GrindM` monad. -/
structure Context where
@@ -980,8 +984,8 @@ def pushEqCore (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit := do
modify fun s => { s with newFacts := s.newFacts.push <| .eq lhs rhs proof isHEq }
/-- Return `true` if `a` and `b` have the same type. -/
def hasSameType (a b : Expr) : MetaM Bool :=
withDefault do isDefEq ( inferType a) ( inferType b)
def hasSameType (a b : Expr) : MetaM Bool := do
isDefEqD ( inferType a) ( inferType b)
@[inline] def pushEqHEq (lhs rhs proof : Expr) : GoalM Unit := do
if ( hasSameType lhs rhs) then
@@ -1140,8 +1144,8 @@ def isNum (e : Expr) : Bool :=
/--
Returns `true` if type of `t` is definitionally equal to `α`
-/
def hasType (t α : Expr) : MetaM Bool :=
withDefault do isDefEq ( inferType t) α
def hasType (t α : Expr) : MetaM Bool := do
isDefEqD ( inferType t) α
/--
For each equality `b = c` in `parents`, executes `k b c` IF