Compare commits

...

5 Commits

Author SHA1 Message Date
Leonardo de Moura
12850dfb3d chore: new test 2025-01-27 15:21:43 -08:00
Leonardo de Moura
bfe4f19768 chore: fix test 2025-01-27 15:17:04 -08:00
Leonardo de Moura
d8a65168af fix: add "to propagate" list to offset contraint module 2025-01-27 15:16:31 -08:00
Leonardo de Moura
17348d999d chore: use throwError 2025-01-27 13:34:25 -08:00
Leonardo de Moura
c3cc968430 chore: leftovers 2025-01-27 13:34:12 -08:00
5 changed files with 133 additions and 72 deletions

View File

@@ -65,18 +65,23 @@ private def getNodeId (e : Expr) : GoalM NodeId := do
| throwError "internal `grind` error, term has not been internalized by offset module{indentExpr e}"
return nodeId
private def getProof (u v : NodeId) : GoalM ProofInfo := do
let some p getProof? u v
| throwError "internal `grind` error, failed to construct proof for{indentExpr (← getExpr u)}\nand{indentExpr (← getExpr v)}"
return p
/--
Returns a proof for `u + k ≤ v` (or `u ≤ v + k`) where `k` is the
shortest path between `u` and `v`.
-/
private partial def mkProofForPath (u v : NodeId) : GoalM Expr := do
go ( getProof? u v).get!
go ( getProof u v)
where
go (p : ProofInfo) : GoalM Expr := do
if u == p.w then
return p.proof
else
let p' := ( getProof? u p.w).get!
let p' getProof u p.w
go (mkTrans ( get').nodes p' p v)
/--
@@ -119,33 +124,69 @@ private def isShorter (u v : NodeId) (k : Int) : GoalM Bool := do
else
return true
/-- Adds `p` to the list of things to be propagated. -/
private def pushToPropagate (p : ToPropagate) : GoalM Unit :=
modify' fun s => { s with propagate := p :: s.propagate }
private def propagateEqTrue (e : Expr) (u v : NodeId) (k k' : Int) : GoalM Unit := do
let kuv mkProofForPath u v
let u getExpr u
let v getExpr v
pushEqTrue e <| mkPropagateEqTrueProof u v k kuv k'
private def propagateEqFalse (e : Expr) (u v : NodeId) (k k' : Int) : GoalM Unit := do
let kuv mkProofForPath u v
let u getExpr u
let v getExpr v
pushEqFalse e <| mkPropagateEqFalseProof u v k kuv k'
/-- Propagates all pending contraints and equalities and resets to "to do" list. -/
private def propagatePending : GoalM Unit := do
let todo modifyGet fun s => (s.arith.offset.propagate, { s with arith.offset.propagate := [] })
for p in todo do
match p with
| .eqTrue e u v k k' => propagateEqTrue e u v k k'
| .eqFalse e u v k k' => propagateEqFalse e u v k k'
| .eq u v =>
let ue getExpr u
let ve getExpr v
unless ( isEqv ue ve) do
let huv mkProofForPath u v
let hvu mkProofForPath v u
pushEq ue ve <| mkApp4 (mkConst ``Grind.Nat.eq_of_le_of_le) ue ve huv hvu
/--
Tries to assign `e` to `True`, which is represented by constraint `c` (from `u` to `v`), using the
path `u --(k)--> v`.
Given `e` represented by constraint `c` (from `u` to `v`).
Checks whether `e = True` can be propagated using the path `u --(k)--> v`.
If it can, adds a new entry to propagation list.
-/
private def propagateTrue (u v : NodeId) (k : Int) (c : Cnstr NodeId) (e : Expr) : GoalM Bool := do
private def checkEqTrue (u v : NodeId) (k : Int) (c : Cnstr NodeId) (e : Expr) : GoalM Bool := do
if k c.k then
trace[grind.offset.propagate] "{{ u, v, k : Cnstr NodeId}} ==> {e} = True"
let kuv mkProofForPath u v
let u getExpr u
let v getExpr v
pushEqTrue e <| mkPropagateEqTrueProof u v k kuv c.k
pushToPropagate <| .eqTrue e u v k c.k
return true
return false
/--
Tries to assign `e` to `False`, which is represented by constraint `c` (from `v` to `u`), using the
path `u --(k)--> v`.
Given `e` represented by constraint `c` (from `v` to `u`).
Checks whether `e = False` can be propagated using the path `u --(k)--> v`.
If it can, adds a new entry to propagation list.
-/
private def propagateFalse (u v : NodeId) (k : Int) (c : Cnstr NodeId) (e : Expr) : GoalM Bool := do
private def checkEqFalse (u v : NodeId) (k : Int) (c : Cnstr NodeId) (e : Expr) : GoalM Bool := do
if k + c.k < 0 then
trace[grind.offset.propagate] "{{ u, v, k : Cnstr NodeId}} ==> {e} = False"
let kuv mkProofForPath u v
let u getExpr u
let v getExpr v
pushEqFalse e <| mkPropagateEqFalseProof u v k kuv c.k
pushToPropagate <| .eqFalse e u v k c.k
return true
return false
/-- Equality propagation. -/
private def checkEq (u v : NodeId) (k : Int) : GoalM Unit := do
if k != 0 then return ()
let some k' getDist? v u | return ()
if k' != 0 then return ()
let ue getExpr u
let ve getExpr v
if ( isEqv ue ve) then return ()
pushToPropagate <| .eq u v
/--
Auxiliary function for implementing `propagateAll`.
Traverses the constraints `c` (representing an expression `e`) s.t.
@@ -164,34 +205,44 @@ private def updateCnstrsOf (u v : NodeId) (f : Cnstr NodeId → Expr → GoalM B
return !( f c e)
modify' fun s => { s with cnstrsOf := s.cnstrsOf.insert (u, v) cs' }
/-- Equality propagation. -/
private def propagateEq (u v : NodeId) (k : Int) : GoalM Unit := do
if k != 0 then return ()
let some k' getDist? v u | return ()
if k' != 0 then return ()
let ue getExpr u
let ve getExpr v
if ( isEqv ue ve) then return ()
let huv mkProofForPath u v
let hvu mkProofForPath v u
trace[grind.offset.eq.from] "{ue}, {ve}"
pushEq ue ve <| mkApp4 (mkConst ``Grind.Nat.eq_of_le_of_le) ue ve huv hvu
/-- Performs constraint propagation. -/
private def propagateAll (u v : NodeId) (k : Int) : GoalM Unit := do
updateCnstrsOf u v fun c e => return !( propagateTrue u v k c e)
updateCnstrsOf v u fun c e => return !( propagateFalse u v k c e)
propagateEq u v k
/-- Finds constrains and equalities to be propagated. -/
private def checkToPropagate (u v : NodeId) (k : Int) : GoalM Unit := do
updateCnstrsOf u v fun c e => return !( checkEqTrue u v k c e)
updateCnstrsOf v u fun c e => return !( checkEqFalse u v k c e)
checkEq u v k
/--
If `isShorter u v k`, updates the shortest distance between `u` and `v`.
`w` is the penultimate node in the path from `u` to `v`.
`w` is a node in the path from `u` to `v` such that `(← getProof? w v)` is `some`
-/
private def updateIfShorter (u v : NodeId) (k : Int) (w : NodeId) : GoalM Unit := do
if ( isShorter u v k) then
setDist u v k
setProof u v ( getProof? w v).get!
propagateAll u v k
setProof u v ( getProof w v)
checkToPropagate u v k
def Cnstr.toExpr (c : Cnstr NodeId) : GoalM Expr := do
let u := ( get').nodes[c.u]!
let v := ( get').nodes[c.v]!
if c.k == 0 then
return mkNatLE u v
else if c.k < 0 then
return mkNatLE (mkNatAdd u (Lean.toExpr ((-c.k).toNat))) v
else
return mkNatLE u (mkNatAdd v (Lean.toExpr c.k.toNat))
def checkInvariants : GoalM Unit := do
let s get'
for u in [:s.targets.size], es in s.targets.toArray do
for (v, k) in es do
let c : Cnstr NodeId := { u, v, k }
trace[grind.debug.offset] "{c}"
let p mkProofForPath u v
trace[grind.debug.offset.proof] "{p} : {← inferType p}"
check p
unless ( withDefault <| isDefEq ( inferType p) ( Cnstr.toExpr c)) do
trace[grind.debug.offset.proof] "failed: {← inferType p} =?= {← Cnstr.toExpr c}"
unreachable!
/--
Adds an edge `u --(k) --> v` justified by the proof term `p`, and then
@@ -207,8 +258,9 @@ def addEdge (u : NodeId) (v : NodeId) (k : Int) (p : Expr) : GoalM Unit := do
if ( isShorter u v k) then
setDist u v k
setProof u v { w := u, k, proof := p }
propagateAll u v k
checkToPropagate u v k
update
propagatePending
where
update : GoalM Unit := do
forEachTargetOf v fun j k₂ => do
@@ -226,10 +278,12 @@ private def internalizeCnstr (e : Expr) (c : Cnstr Expr) : GoalM Unit := do
let v mkNode c.v
let c := { c with u, v }
if let some k getDist? u v then
if ( propagateTrue u v k c e) then
if k c.k then
propagateEqTrue e u v k c.k
return ()
if let some k getDist? v u then
if ( propagateFalse v u k c e) then
if k + c.k < 0 then
propagateEqFalse e v u k c.k
return ()
trace[grind.offset.internalize] "{e} ↦ {c}"
modify' fun s => { s with
@@ -315,27 +369,4 @@ def traceDists : GoalM Unit := do
for (v, k) in es do
trace[grind.offset.dist] "#{u} -({k})-> #{v}"
def Cnstr.toExpr (c : Cnstr NodeId) : GoalM Expr := do
let u := ( get').nodes[c.u]!
let v := ( get').nodes[c.v]!
if c.k == 0 then
return mkNatLE u v
else if c.k < 0 then
return mkNatLE (mkNatAdd u (Lean.toExpr ((-c.k).toNat))) v
else
return mkNatLE u (mkNatAdd v (Lean.toExpr c.k.toNat))
def checkInvariants : GoalM Unit := do
let s get'
for u in [:s.targets.size], es in s.targets.toArray do
for (v, k) in es do
let c : Cnstr NodeId := { u, v, k }
trace[grind.debug.offset] "{c}"
let p mkProofForPath u v
trace[grind.debug.offset.proof] "{p} : {← inferType p}"
check p
unless ( withDefault <| isDefEq ( inferType p) ( Cnstr.toExpr c)) do
trace[grind.debug.offset.proof] "failed: {← inferType p} =?= {← Cnstr.toExpr c}"
unreachable!
end Lean.Meta.Grind.Arith.Offset

View File

@@ -24,6 +24,20 @@ structure ProofInfo where
proof : Expr
deriving Inhabited
/--
Auxiliary inductive type for representing contraints and equalities
that should be propagated to core.
Recall that we cannot compute proofs until the short-distance
data-structures have been fully updated when a new edge is inserted.
Thus, we store the information to be propagated into a list.
See field `propagate` in `State`.
-/
inductive ToPropagate where
| eqTrue (e : Expr) (u v : NodeId) (k k' : Int)
| eqFalse (e : Expr) (u v : NodeId) (k k' : Int)
| eq (u v : NodeId)
deriving Inhabited
/-- State of the constraint offset procedure. -/
structure State where
/-- Mapping from `NodeId` to the `Expr` represented by the node. -/
@@ -53,7 +67,9 @@ structure State where
`w` is the penultimate node in the path, and `proof` is the justification for
the last edge.
-/
proofs : PArray (AssocList NodeId ProofInfo) := {}
proofs : PArray (AssocList NodeId ProofInfo) := {}
/-- Truth values and equalities to propagate to core. -/
propagate : List ToPropagate := []
deriving Inhabited
end Offset

View File

@@ -103,13 +103,10 @@ private def checkCaseSplitStatus (e : Expr) : GoalM CaseSplitStatus := do
return .ready info.ctors.length info.isRec
if e.isFVar then
let type whnfD ( inferType e)
trace[Meta.debug] "1. {e}"
let report : GoalM Unit := do
reportIssue "cannot perform case-split on {e}, unexpected type{indentExpr type}"
let .const declName _ := type.getAppFn | report; return .resolved
trace[Meta.debug] "2. {e}"
let .inductInfo info getConstInfo declName | report; return .resolved
trace[Meta.debug] "3. {e}"
return .ready info.ctors.length info.isRec
return .notReady

View File

@@ -0,0 +1,14 @@
inductive Vec (α : Type u) : Nat Type u
| nil : Vec α 0
| cons : α Vec α n Vec α (n+1)
def h (v w : Vec α n) : Nat :=
match v, w with
| _, .cons _ (.cons _ _) => 20
| .nil, _ => 30
| _, _ => 40
-- In the following example, we need to instruct `grind` to case-split on `Vec` terms because
-- we don't have a propagation rule for given `as : Vec α (n+1)`, then `∃ b bs, as = .cons b bs`
example (a b : Vec α 2) : h a b = 20 := by
grind [h.eq_def, Vec]

View File

@@ -268,9 +268,12 @@ fun {a4} p a1 a2 a3 =>
intro_with_eq (p ↔ a4 ≤ a3 + 2) (p = (a4 ≤ a3 + 2)) (a1 ≤ a4) (iff_eq p (a4 ≤ a3 + 2)) fun a_3 =>
Classical.byContradiction
(intro_with_eq (¬a1 ≤ a4) (a4 + 1 ≤ a1) False (Nat.not_le_eq a1 a4) fun x =>
Nat.unsat_lo_lo a4 a1 1 7 rfl_true x
(Nat.lo_lo a1 a2 a4 1 6 (Nat.of_le_eq_false a2 a1 (Eq.trans (Eq.symm a) (eq_false a_1)))
(Nat.lo_lo a2 a3 a4 3 3 a_2 (Nat.of_ro_eq_false a4 a3 2 (Eq.trans (Eq.symm a_3) (eq_false a_1))))))
Eq.mp
(Eq.trans (Eq.symm (eq_true x))
(Nat.lo_eq_false_of_lo a1 a4 7 1 rfl_true
(Nat.lo_lo a1 a2 a4 1 6 (Nat.of_le_eq_false a2 a1 (Eq.trans (Eq.symm a) (eq_false a_1)))
(Nat.lo_lo a2 a3 a4 3 3 a_2 (Nat.of_ro_eq_false a4 a3 2 (Eq.trans (Eq.symm a_3) (eq_false a_1)))))))
True.intro)
-/
#guard_msgs (info) in
open Lean Grind in