Compare commits

...

7 Commits

Author SHA1 Message Date
Leonardo de Moura
71e1701df2 chore: fix test
Some goals in the test are now closed by `grind`.
2024-12-27 12:56:12 -08:00
Leonardo de Moura
74bbd88d7a feat: propagate contructor equalities 2024-12-27 12:54:00 -08:00
Leonardo de Moura
ec587c550b feat: add getForallArity 2024-12-27 12:33:21 -08:00
Leonardo de Moura
88b60d7cb5 feat: close goal when merging equivalence classes containing distinct constructors 2024-12-27 11:38:51 -08:00
Leonardo de Moura
32d9045783 chore: use closeGoal at propagateNotDown 2024-12-27 11:02:48 -08:00
Leonardo de Moura
1e856b998c refactor: add closeGoal and support for closing current goal at addEqStep 2024-12-27 10:59:58 -08:00
Leonardo de Moura
ffadc88438 test: vector examples for grind 2024-12-27 10:04:40 -08:00
12 changed files with 183 additions and 66 deletions

View File

@@ -42,7 +42,7 @@ theorem not_eq_of_eq_false {a : Prop} (h : a = False) : (Not a) = True := by sim
theorem eq_false_of_not_eq_true {a : Prop} (h : (Not a) = True) : a = False := by simp_all
theorem eq_true_of_not_eq_false {a : Prop} (h : (Not a) = False) : a = True := by simp_all
theorem true_eq_false_of_not_eq_self {a : Prop} (h : (Not a) = a) : True = False := by
theorem false_of_not_eq_self {a : Prop} (h : (Not a) = a) : False := by
by_cases a <;> simp_all
/-! Eq -/

View File

@@ -1648,6 +1648,23 @@ def isFalse (e : Expr) : Bool :=
def isTrue (e : Expr) : Bool :=
e.cleanupAnnotations.isConstOf ``True
/--
`getForallArity type` returns the arity of a `forall`-type. This function consumes nested annotations,
and performs pending beta reductions. It does **not** use whnf.
Examples:
- If `a` is `Nat`, `getForallArity a` returns `0`
- If `a` is `Nat → Bool`, `getForallArity a` returns `1`
-/
partial def getForallArity : Expr Nat
| .mdata _ b => getForallArity b
| .forallE _ _ b _ => getForallArity b + 1
| e =>
if e.isHeadBetaTarget then
getForallArity e.headBeta
else
let e' := e.cleanupAnnotations
if e != e' then getForallArity e' else 0
/--
Checks if an expression is a "natural number numeral in normal form",
i.e. of type `Nat`, and explicitly of the form `OfNat.ofNat n`

View File

@@ -19,6 +19,7 @@ import Lean.Meta.Tactic.Grind.Proof
import Lean.Meta.Tactic.Grind.Propagate
import Lean.Meta.Tactic.Grind.PP
import Lean.Meta.Tactic.Grind.Simp
import Lean.Meta.Tactic.Grind.Ctor
namespace Lean

View File

@@ -9,6 +9,7 @@ import Lean.Meta.LitValues
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Inv
import Lean.Meta.Tactic.Grind.PP
import Lean.Meta.Tactic.Grind.Ctor
namespace Lean.Meta.Grind
@@ -79,9 +80,6 @@ where
proof? := proofNew?
}
private def markAsInconsistent : GoalM Unit := do
modify fun s => { s with inconsistent := true }
/--
Remove `root` parents from the congruence table.
This is an auxiliary function performed while merging equivalence classes.
@@ -100,6 +98,14 @@ private def reinsertParents (parents : ParentSet) : GoalM Unit := do
for parent in parents do
addCongrTable parent
/-- Closes the goal when `True` and `False` are in the same equivalence class. -/
private def closeGoalWithTrueEqFalse : GoalM Unit := do
let mvarId := ( get).mvarId
unless ( mvarId.isAssigned) do
let trueEqFalse mkEqFalseProof ( getTrueExpr)
let falseProof mkEqMP trueEqFalse (mkConst ``True.intro)
closeGoal falseProof
private partial def addEqStep (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit := do
trace[grind.eq] "{lhs} {if isHEq then "" else "="} {rhs}"
let lhsNode getENode lhs
@@ -111,9 +117,11 @@ private partial def addEqStep (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit
let lhsRoot getENode lhsNode.root
let rhsRoot getENode rhsNode.root
let mut valueInconsistency := false
let mut trueEqFalse := false
if lhsRoot.interpreted && rhsRoot.interpreted then
markAsInconsistent
if lhsNode.root.isTrue || rhsNode.root.isTrue then
markAsInconsistent
trueEqFalse := true
else
valueInconsistency := true
if (lhsRoot.interpreted && !rhsRoot.interpreted)
@@ -122,6 +130,11 @@ private partial def addEqStep (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit
go rhs lhs rhsNode lhsNode rhsRoot lhsRoot true
else
go lhs rhs lhsNode rhsNode lhsRoot rhsRoot false
if trueEqFalse then
closeGoalWithTrueEqFalse
unless ( isInconsistent) do
if lhsRoot.ctor && rhsRoot.ctor then
propagateCtor lhsRoot.self rhsRoot.self
-- TODO: propagate value inconsistency
trace[grind.debug] "after addEqStep, {← ppState}"
checkInvariants

View File

@@ -0,0 +1,50 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Grind.Types
namespace Lean.Meta.Grind
private partial def propagateInjEqs (eqs : Expr) (proof : Expr) : GoalM Unit := do
match_expr eqs with
| And left right =>
propagateInjEqs left (.proj ``And 0 proof)
propagateInjEqs right (.proj ``And 1 proof)
| Eq _ lhs rhs => pushEq lhs rhs proof
| HEq _ lhs _ rhs => pushHEq lhs rhs proof
| _ =>
trace[grind.issues] "unexpected injectivity theorem result type{indentExpr eqs}"
return ()
/--
Given constructors `a` and `b`, propagate equalities if they are the same,
and close goal if they are different.
-/
def propagateCtor (a b : Expr) : GoalM Unit := do
let aType whnfD ( inferType a)
let bType whnfD ( inferType b)
unless ( withDefault <| isDefEq aType bType) do
return ()
let ctor₁ := a.getAppFn
let ctor₂ := b.getAppFn
if ctor₁ == ctor₂ then
let .const ctorName _ := a.getAppFn | return ()
let injDeclName := Name.mkStr ctorName "inj"
unless ( getEnv).contains injDeclName do return ()
let info getConstInfo injDeclName
let n := info.type.getForallArity
let mask : Array (Option Expr) := mkArray n none
let mask := mask.set! (n-1) (some ( mkEqProof a b))
let injLemma mkAppOptM injDeclName mask
propagateInjEqs ( inferType injLemma) injLemma
else
let .const declName _ := aType.getAppFn | return ()
let noConfusionDeclName := Name.mkStr declName "noConfusion"
unless ( getEnv).contains noConfusionDeclName do return ()
let target ( get).mvarId.getType
closeGoal ( mkNoConfusion target ( mkEqProof a b))
end Lean.Meta.Grind

View File

@@ -98,6 +98,8 @@ def applyInjection? (goal : Goal) (fvarId : FVarId) : MetaM (Option Goal) := do
return none
partial def loop (goal : Goal) : PreM Unit := do
if goal.inconsistent then
return ()
match ( introNext goal) with
| .done =>
if let some mvarId goal.mvarId.byContra? then
@@ -162,9 +164,7 @@ def preprocess (mvarId : MVarId) (mainDeclName : Name) : MetaM Preprocessor.Stat
def main (mvarId : MVarId) (mainDeclName : Name) : MetaM (List MVarId) := do
let go : GrindM (List MVarId) := do
let s Preprocessor.preprocess mvarId |>.run
let goals s.goals.toList.filterM fun goal => do
let (done, _) GoalM.run goal closeIfInconsistent
return !done
let goals := s.goals.toList.filter fun goal => !goal.inconsistent
return goals.map (·.mvarId)
go.run mainDeclName

View File

@@ -214,7 +214,8 @@ end
Returns a proof that `a = b` (or `HEq a b`).
It assumes `a` and `b` are in the same equivalence class.
-/
def mkEqProof (a b : Expr) : GoalM Expr := do
@[export lean_grind_mk_eq_proof]
def mkEqProofImpl (a b : Expr) : GoalM Expr := do
let p go
trace[grind.proof.detail] "{p}"
return p
@@ -235,29 +236,4 @@ where
def mkHEqProof (a b : Expr) : GoalM Expr :=
mkEqProofCore a b (heq := true)
/--
Returns a proof that `a = True`.
It assumes `a` and `True` are in the same equivalence class.
-/
def mkEqTrueProof (a : Expr) : GoalM Expr := do
mkEqProof a ( getTrueExpr)
/--
Returns a proof that `a = False`.
It assumes `a` and `False` are in the same equivalence class.
-/
def mkEqFalseProof (a : Expr) : GoalM Expr := do
mkEqProof a ( getFalseExpr)
def closeIfInconsistent : GoalM Bool := do
if ( isInconsistent) then
let mvarId := ( get).mvarId
unless ( mvarId.isAssigned) do
let trueEqFalse mkEqFalseProof ( getTrueExpr)
let falseProof mkEqMP trueEqFalse (mkConst ``True.intro)
mvarId.assign falseProof
return true
else
return false
end Lean.Meta.Grind

View File

@@ -111,7 +111,7 @@ builtin_grind_propagator propagateNotDown ↓Not := fun e => do
else if ( isEqTrue e) then
pushEqFalse a <| mkApp2 (mkConst ``Lean.Grind.eq_false_of_not_eq_true) a ( mkEqTrueProof e)
else if ( isEqv e a) then
pushEqFalse ( getTrueExpr) <| mkApp2 (mkConst ``Lean.Grind.true_eq_false_of_not_eq_self) a ( mkEqProof e a)
closeGoal <| mkApp2 (mkConst ``Lean.Grind.false_of_not_eq_self) a ( mkEqProof e a)
/-- Propagates `Eq` upwards -/
builtin_grind_propagator propagateEqUp Eq := fun e => do

View File

@@ -275,10 +275,6 @@ abbrev GoalM := StateRefT Goal GrindM
abbrev Propagator := Expr GoalM Unit
/-- Return `true` if the goal is inconsistent. -/
def isInconsistent : GoalM Bool :=
return ( get).inconsistent
/-- Returns `true` if `e` is the internalized `True` expression. -/
def isTrueExpr (e : Expr) : GrindM Bool :=
return isSameExpr e ( getTrueExpr)
@@ -444,6 +440,46 @@ def mkENode (e : Expr) (generation : Nat) : GoalM Unit := do
let interpreted isInterpreted e
mkENodeCore e interpreted ctor generation
/-- Return `true` if the goal is inconsistent. -/
def isInconsistent : GoalM Bool :=
return ( get).inconsistent
/--
Returns a proof that `a = b` (or `HEq a b`).
It assumes `a` and `b` are in the same equivalence class.
-/
-- Forward definition
@[extern "lean_grind_mk_eq_proof"]
opaque mkEqProof (a b : Expr) : GoalM Expr
/--
Returns a proof that `a = True`.
It assumes `a` and `True` are in the same equivalence class.
-/
def mkEqTrueProof (a : Expr) : GoalM Expr := do
mkEqProof a ( getTrueExpr)
/--
Returns a proof that `a = False`.
It assumes `a` and `False` are in the same equivalence class.
-/
def mkEqFalseProof (a : Expr) : GoalM Expr := do
mkEqProof a ( getFalseExpr)
/-- Marks current goal as inconsistent without assigning `mvarId`. -/
def markAsInconsistent : GoalM Unit := do
modify fun s => { s with inconsistent := true }
/--
Closes the current goal using the given proof of `False` and
marks it as inconsistent if it is not already marked so.
-/
def closeGoal (falseProof : Expr) : GoalM Unit := do
markAsInconsistent
let mvarId := ( get).mvarId
unless ( mvarId.isAssigned) do
mvarId.assign falseProof
/-- Returns all enodes in the goal -/
def getENodes : GoalM (Array ENode) := do
-- We must sort because we are using pointer addresses as keys in `enodes`

View File

@@ -83,3 +83,18 @@ theorem ex1 (f : {α : Type} → α → Nat → Bool → Nat) (a b c : Nat) : f
grind
#print ex1
example (n1 n2 n3 : Nat) (v1 w1 : Vector Nat n1) (w1' : Vector Nat n3) (v2 w2 : Vector Nat n2) :
n1 = n3 v1 = w1 HEq w1 w1' v2 = w2 HEq (v1 ++ v2) (w1' ++ w2) := by
grind
example (n1 n2 n3 : Nat) (v1 w1 : Vector Nat n1) (w1' : Vector Nat n3) (v2 w2 : Vector Nat n2) :
HEq n1 n3 v1 = w1 HEq w1 w1' HEq v2 w2 HEq (v1 ++ v2) (w1' ++ w2) := by
grind
theorem ex2 (n1 n2 n3 : Nat) (v1 w1 v : Vector Nat n1) (w1' : Vector Nat n3) (v2 w2 w : Vector Nat n2) :
HEq n1 n3 v1 = w1 HEq w1 w1' HEq v2 w2 HEq (w1' ++ w2) (v ++ w) HEq (v1 ++ v2) (v ++ w) := by
grind
#print ex2

View File

@@ -21,24 +21,6 @@ theorem ex (h : (f a && (b || f (f c))) = true) (h' : p ∧ q) : b && a := by
open Lean.Grind.Eager in
/--
error: `grind` failed
a b c : Bool
p q : Prop
left✝ : a = true
h✝ : b = true
left : p
right : q
h : b = false
⊢ False
a b c : Bool
p q : Prop
left✝ : a = true
h✝ : b = true
left : p
right : q
h : a = false
⊢ False
a b c : Bool
p q : Prop
left✝ : a = true
@@ -47,15 +29,6 @@ left : p
right : q
h : b = false
⊢ False
a b c : Bool
p q : Prop
left✝ : a = true
h✝ : c = true
left : p
right : q
h : a = false
⊢ False
-/
#guard_msgs (error) in
theorem ex2 (h : (f a && (b || f (f c))) = true) (h' : p q) : b && a := by

View File

@@ -0,0 +1,36 @@
example (a b : List Nat) : a = [] b = [2] a = b False := by
grind
example (a b : List Nat) : a = b a = [] b = [2] False := by
grind
example (a b : Bool) : a = true b = false a = b False := by
grind
example (a b : Sum Nat Bool) : a = .inl c b = .inr true a = b False := by
grind
example (a b : Sum Nat Bool) : a = b a = .inl c b = .inr true a = b False := by
grind
inductive Foo (α : Type) : Nat Type where
| a (v : α) : Foo α 0
| b (n : α) (m : Nat) (v : Vector Nat m) : Foo α (2*m)
example (h₁ : Foo.b x 2 v = f₁) (h₂ : Foo.b y 2 w = f₂) : f₁ = f₂ x = y := by
grind
example (h₁ : Foo.a x = f₁) (h₂ : Foo.a y = f₂) : f₁ = f₂ x = y := by
grind
example (h₁ : a :: b = x) (h₂ : c :: d = y) : x = y a = c := by
grind
example (h : x = y) (h₁ : a :: b = x) (h₂ : c :: d = y) : a = c := by
grind
example (h : x = y) (h₁ : a :: b = x) (h₂ : c :: d = y) : b = d := by
grind
example (a b : Sum Nat Bool) : a = .inl x b = .inl y x y a = b False := by
grind