Compare commits

...

4 Commits

Author SHA1 Message Date
Leonardo de Moura
6e53203140 refactor: use macro 2025-08-30 13:16:22 -07:00
Leonardo de Moura
927a9a1ed6 use namespaces 2025-08-30 12:53:44 -07:00
Leonardo de Moura
82eee19b4f fix: use Grind.check 2025-08-30 12:49:00 -07:00
Leonardo de Moura
8bb038c9d4 feat: AC.checkStruct 2025-08-30 12:30:14 -07:00
4 changed files with 195 additions and 153 deletions

View File

@@ -24,4 +24,6 @@ builtin_initialize registerTraceClass `grind.ac.internalize
builtin_initialize registerTraceClass `grind.debug.ac.op
builtin_initialize registerTraceClass `grind.debug.ac.basis
builtin_initialize registerTraceClass `grind.debug.ac.simp
builtin_initialize registerTraceClass `grind.debug.ac.check
builtin_initialize registerTraceClass `grind.debug.ac.queue
end Lean

View File

@@ -45,63 +45,6 @@ def norm (e : AC.Expr) : ACM AC.Seq := do
| false, true => return e.toSeq.erase0
| false, false => return e.toSeq
def saveDiseq (c : DiseqCnstr) : ACM Unit := do
modifyStruct fun s => { s with diseqs := s.diseqs.push c }
def DiseqCnstr.eraseDup (c : DiseqCnstr) : ACM DiseqCnstr := do
unless ( isIdempotent) do return c
let lhs := c.lhs.eraseDup
let rhs := c.rhs.eraseDup
if c.lhs == lhs && c.rhs == rhs then
return c
else
return { lhs, rhs, h := .erase_dup c }
def DiseqCnstr.assert (c : DiseqCnstr) : ACM Unit := do
let c c.eraseDup
-- TODO: simplify and check conflict
trace[grind.ac.assert] "{← c.denoteExpr}"
if c.lhs == c.rhs then
c.setUnsat
else
saveDiseq c
def mkEqCnstr (lhs rhs : AC.Seq) (h : EqCnstrProof) : ACM EqCnstr := do
let id := ( getStruct).nextId
modifyStruct fun s => { s with nextId := s.nextId + 1 }
return { lhs, rhs, h, id }
def EqCnstr.eraseDup (c : EqCnstr) : ACM EqCnstr := do
unless ( isIdempotent) do return c
let lhs := c.lhs.eraseDup
let rhs := c.rhs.eraseDup
if c.lhs == lhs && c.rhs == rhs then
return c
else
return { c with lhs, rhs, h := .erase_dup c }
def EqCnstr.erase0 (c : EqCnstr) : ACM EqCnstr := do
unless ( hasNeutral) do return c
let lhs := c.lhs.erase0
let rhs := c.rhs.erase0
if c.lhs == lhs && c.rhs == rhs then
return c
else
return { c with lhs, rhs, h := .erase0 c }
def EqCnstr.cleanup (c : EqCnstr) : ACM EqCnstr := do
( c.eraseDup).erase0
def EqCnstr.orient (c : EqCnstr) : EqCnstr :=
if compare c.rhs c.lhs == .gt then
{ c with lhs := c.rhs, rhs := c.lhs, h := .swap c }
else
c
def EqCnstr.superposeWith (c : EqCnstr) : ACM Unit := do
trace[Meta.debug] "superpose {← c.denoteExpr}"
return () -- TODO
/--
Returns `some (c, r)`, where `c` is an equation from the basis whose LHS simplifies `s` when
`(← isCommutative)` is `false`
@@ -122,105 +65,161 @@ private def _root_.Lean.Grind.AC.Seq.findSimpAC? (s : AC.Seq) : ACM (Option (EqC
unless r matches .false do return some (c, r)
return none
private def simplifyLhsWithA (c : EqCnstr) (c' : EqCnstr) (r : AC.SubseqResult) : EqCnstr :=
match r with
| .exact => { c with lhs := c'.rhs, h := .simp_exact (lhs := true) c c' }
| .prefix s => { c with lhs := c'.rhs ++ s, h := .simp_prefix (lhs := true) s c c' }
| .suffix s => { c with lhs := s ++ c'.rhs, h := .simp_suffix (lhs := true) s c c' }
| .middle p s => { c with lhs := p ++ c'.rhs ++ s, h := .simp_middle (lhs := true) p s c c' }
| .false => c
private def simplifyRhsWithA (c : EqCnstr) (c' : EqCnstr) (r : AC.SubseqResult) : EqCnstr :=
match r with
| .exact => { c with rhs := c'.rhs, h := .simp_exact (lhs := false) c c' }
| .prefix s => { c with rhs := c'.rhs ++ s, h := .simp_prefix (lhs := false) s c c' }
| .suffix s => { c with rhs := s ++ c'.rhs, h := .simp_suffix (lhs := false) s c c' }
| .middle p s => { c with rhs := p ++ c'.rhs ++ s, h := .simp_middle (lhs := false) p s c c' }
| .false => c
/-- Simplifies `c` using the basis when `(← isCommutative)` is `false` -/
private def EqCnstr.simplifyA (c : EqCnstr) : ACM EqCnstr := do
let mut c c.cleanup
repeat
incSteps
if ( checkMaxSteps) then return c
if let some (c', r) c.lhs.findSimpA? then
c := simplifyLhsWithA c c' r
c c.cleanup
else if let some (c', r) c.rhs.findSimpA? then
c := simplifyRhsWithA c c' r
c c.cleanup
else
trace[grind.debug.ac.simplify] "{← c.denoteExpr}"
local macro "gen_cnstr_fns " cnstr:ident : command =>
let mkId (declName : Name) := mkIdent <| cnstr.getId ++ declName
`(
private def $(mkId `eraseDup) (c : $cnstr) : ACM $cnstr := do
unless ( isIdempotent) do return c
let lhs := c.lhs.eraseDup
let rhs := c.rhs.eraseDup
if c.lhs == lhs && c.rhs == rhs then
return c
return c
/--
Simplifies `c` (lhs and rhs) using `c'`, returns `some c` if simplified.
Case `(← isCommutative) == false`
-/
private def simplifyWithA' (c : EqCnstr) (c' : EqCnstr) : Option EqCnstr := do
let r₁ := c'.lhs.subseq c.lhs
let c := simplifyLhsWithA c c' r₁
let r₂ := c'.lhs.subseq c.rhs
let c := simplifyRhsWithA c c' r₂
if r₁ matches .false && r₂ matches .false then none else some c
private def simplifyLhsWithAC (c : EqCnstr) (c' : EqCnstr) (r : AC.SubsetResult) : EqCnstr :=
match r with
| .exact => { c with lhs := c'.rhs, h := .simp_exact (lhs := true) c c' }
| .strict s => { c with lhs := c'.rhs.union s, h := .simp_ac (lhs := true) s c c' }
| .false => c
private def simplifyRhsWithAC (c : EqCnstr) (c' : EqCnstr) (r : AC.SubsetResult) : EqCnstr :=
match r with
| .exact => { c with rhs := c'.rhs, h := .simp_exact (lhs := false) c c' }
| .strict s => { c with rhs := c'.rhs.union s, h := .simp_ac (lhs := false) s c c' }
| .false => c
/--
Simplifies `c` (lhs and rhs) using `c'`, returns `some c` if simplified.
Case `(← isCommutative) == true`
-/
private def simplifyWithAC' (c : EqCnstr) (c' : EqCnstr) : Option EqCnstr := do
let r₁ := c'.lhs.subset c.lhs
let c := simplifyLhsWithAC c c' r₁
let r₂ := c'.lhs.subset c.rhs
let c := simplifyRhsWithAC c c' r₂
if r₁ matches .false && r₂ matches .false then none else some c
/-- Simplify `c` using the basis when `(← isCommutative)` is `true` -/
private def EqCnstr.simplifyAC (c : EqCnstr) : ACM EqCnstr := do
let mut c c.cleanup
repeat
incSteps
if ( checkMaxSteps) then return c
if let some (c', r) c.lhs.findSimpAC? then
c := simplifyLhsWithAC c c' r
c c.cleanup
else if let some (c', r) c.rhs.findSimpAC? then
c := simplifyRhsWithAC c c' r
c c.cleanup
else
trace[grind.debug.ac.simplify] "{← c.denoteExpr}"
return c
return c
return { c with lhs, rhs, h := .erase_dup c }
/--
Simplifies `c` (lhs and rhs) using `c'`, returns `some c` if simplified.
-/
private def EqCnstr.simplifyWith (c : EqCnstr) (c' : EqCnstr) : ACM (Option EqCnstr) := do
incSteps
if ( isCommutative) then
return simplifyWithAC' c c'
private def $(mkId `erase0) (c : $cnstr) : ACM $cnstr := do
unless ( hasNeutral) do return c
let lhs := c.lhs.erase0
let rhs := c.rhs.erase0
if c.lhs == lhs && c.rhs == rhs then
return c
else
return { c with lhs, rhs, h := .erase0 c }
private def $(mkId `cleanup) (c : $cnstr) : ACM $cnstr := do
( c.eraseDup).erase0
private def $(mkId `simplifyLhsWithA) (c : $cnstr) (c' : EqCnstr) (r : AC.SubseqResult) : $cnstr :=
match r with
| .exact => { c with lhs := c'.rhs, h := .simp_exact (lhs := true) c c' }
| .prefix s => { c with lhs := c'.rhs ++ s, h := .simp_prefix (lhs := true) s c c' }
| .suffix s => { c with lhs := s ++ c'.rhs, h := .simp_suffix (lhs := true) s c c' }
| .middle p s => { c with lhs := p ++ c'.rhs ++ s, h := .simp_middle (lhs := true) p s c c' }
| .false => c
private def $(mkId `simplifyRhsWithA) (c : $cnstr) (c' : EqCnstr) (r : AC.SubseqResult) : $cnstr :=
match r with
| .exact => { c with rhs := c'.rhs, h := .simp_exact (lhs := false) c c' }
| .prefix s => { c with rhs := c'.rhs ++ s, h := .simp_prefix (lhs := false) s c c' }
| .suffix s => { c with rhs := s ++ c'.rhs, h := .simp_suffix (lhs := false) s c c' }
| .middle p s => { c with rhs := p ++ c'.rhs ++ s, h := .simp_middle (lhs := false) p s c c' }
| .false => c
/-- Simplifies `c` using the basis when `(← isCommutative)` is `false` -/
private def $(mkId `simplifyA) (c : $cnstr) : ACM $cnstr := do
let mut c c.cleanup
repeat
incSteps
if ( checkMaxSteps) then return c
if let some (c', r) c.lhs.findSimpA? then
c := c.simplifyLhsWithA c' r
c c.cleanup
else if let some (c', r) c.rhs.findSimpA? then
c := c.simplifyRhsWithA c' r
c c.cleanup
else
trace[grind.debug.ac.simplify] "{← c.denoteExpr}"
return c
return c
/--
Simplifies `c` (lhs and rhs) using `c'`, returns `some c` if simplified.
Case `(← isCommutative) == false`
-/
private def $(mkId `simplifyWithA') (c : $cnstr) (c' : EqCnstr) : Option $cnstr := do
let r₁ := c'.lhs.subseq c.lhs
let c := c.simplifyLhsWithA c' r₁
let r₂ := c'.lhs.subseq c.rhs
let c := c.simplifyRhsWithA c' r₂
if r₁ matches .false && r₂ matches .false then none else some c
private def $(mkId `simplifyLhsWithAC) (c : $cnstr) (c' : EqCnstr) (r : AC.SubsetResult) : $cnstr :=
match r with
| .exact => { c with lhs := c'.rhs, h := .simp_exact (lhs := true) c c' }
| .strict s => { c with lhs := c'.rhs.union s, h := .simp_ac (lhs := true) s c c' }
| .false => c
private def $(mkId `simplifyRhsWithAC) (c : $cnstr) (c' : EqCnstr) (r : AC.SubsetResult) : $cnstr :=
match r with
| .exact => { c with rhs := c'.rhs, h := .simp_exact (lhs := false) c c' }
| .strict s => { c with rhs := c'.rhs.union s, h := .simp_ac (lhs := false) s c c' }
| .false => c
/--
Simplifies `c` (lhs and rhs) using `c'`, returns `some c` if simplified.
Case `(← isCommutative) == true`
-/
private def $(mkId `simplifyWithAC') (c : $cnstr) (c' : EqCnstr) : Option $cnstr := do
let r₁ := c'.lhs.subset c.lhs
let c := c.simplifyLhsWithAC c' r₁
let r₂ := c'.lhs.subset c.rhs
let c := c.simplifyRhsWithAC c' r₂
if r₁ matches .false && r₂ matches .false then none else some c
/-- Simplify `c` using the basis when `(← isCommutative)` is `true` -/
private def $(mkId `simplifyAC) (c : $cnstr) : ACM $cnstr := do
let mut c c.cleanup
repeat
incSteps
if ( checkMaxSteps) then return c
if let some (c', r) c.lhs.findSimpAC? then
c := c.simplifyLhsWithAC c' r
c c.cleanup
else if let some (c', r) c.rhs.findSimpAC? then
c := c.simplifyRhsWithAC c' r
c c.cleanup
else
trace[grind.debug.ac.simplify] "{← c.denoteExpr}"
return c
return c
/--
Simplifies `c` (lhs and rhs) using `c'`, returns `some c` if simplified.
-/
private def $(mkId `simplifyWith) (c : $cnstr) (c' : EqCnstr) : ACM (Option $cnstr) := do
incSteps
if ( isCommutative) then
return c.simplifyWithAC' c'
else
return c.simplifyWithA' c'
/-- Simplify `c` using the basis -/
private def $(mkId `simplify) (c : $cnstr) : ACM $cnstr := do
if ( isCommutative) then c.simplifyAC else c.simplifyA
)
gen_cnstr_fns EqCnstr
gen_cnstr_fns DiseqCnstr
def saveDiseq (c : DiseqCnstr) : ACM Unit := do
modifyStruct fun s => { s with diseqs := s.diseqs.push c }
def DiseqCnstr.assert (c : DiseqCnstr) : ACM Unit := do
let c c.eraseDup
-- let c ← c.simplify -- TODO: uncomment after implementing proof generation
trace[grind.ac.assert] "{← c.denoteExpr}"
if c.lhs == c.rhs then
c.setUnsat
else
return simplifyWithA' c c'
saveDiseq c
/-- Simplify `c` using the basis -/
private def EqCnstr.simplify (c : EqCnstr) : ACM EqCnstr := do
if ( isCommutative) then c.simplifyAC else c.simplifyA
def mkEqCnstr (lhs rhs : AC.Seq) (h : EqCnstrProof) : ACM EqCnstr := do
let id := ( getStruct).nextId
modifyStruct fun s => { s with nextId := s.nextId + 1 }
return { lhs, rhs, h, id }
def EqCnstr.orient (c : EqCnstr) : EqCnstr :=
if compare c.rhs c.lhs == .gt then
{ c with lhs := c.rhs, rhs := c.lhs, h := .swap c }
else
c
def EqCnstr.superposeWith (c : EqCnstr) : ACM Unit := do
trace[Meta.debug] "superpose {← c.denoteExpr}"
return () -- TODO
def EqCnstr.addToQueue (c : EqCnstr) : ACM Unit := do
trace[grind.debug.ac.queue] "{← c.denoteExpr}"
modifyStruct fun s => { s with queue := s.queue.insert c }
def EqCnstr.simplifyBasis (c : EqCnstr) : ACM Unit := do
@@ -257,6 +256,10 @@ def EqCnstr.addToBasisAfterSimp (c : EqCnstr) : ACM Unit := do
trace_goal[grind.ac.assert.basis] "{← c.denoteExpr}"
addToBasisCore c
def EqCnstr.addToBasis (c : EqCnstr) : ACM Unit := do
let c c.simplify
c.addToBasisAfterSimp
def EqCnstr.assert (c : EqCnstr) : ACM Unit := do
let c c.simplify
if c.lhs == c.rhs then
@@ -286,9 +289,44 @@ def processNewDiseqImpl (a b : Expr) : GoalM Unit := withExprs a b do
let rhs norm eb
{ lhs, rhs, h := .core a b ea eb : DiseqCnstr }.assert
def checkStruct : ACM Bool := do
private def isQueueEmpty : ACM Bool :=
return ( getStruct).queue.isEmpty
/--
Returns `true` if the todo queue is not empty or the `recheck` flag is set to `true`
-/
private def needCheck : ACM Bool := do
unless ( isQueueEmpty) do return true
return ( getStruct).recheck
private def getNext? : ACM (Option EqCnstr) := do
let some c := ( getStruct).queue.min? | return none
modifyStruct fun s => { s with queue := s.queue.erase c }
incSteps
return some c
private def checkDiseqs : ACM Unit := do
-- TODO
return false
return ()
private def propagateEqs : ACM Unit := do
-- TODO
return ()
private def checkStruct : ACM Bool := do
unless ( needCheck) do return false
trace_goal[grind.debug.ac.check] "{(← getStruct).op}"
repeat
checkSystem "ac"
let some c getNext? | break
trace_goal[grind.debug.ac.check] "{← c.denoteExpr}"
c.addToBasis
if ( isInconsistent) then return true
if ( checkMaxSteps) then return true
checkDiseqs
propagateEqs
modifyStruct fun s => { s with recheck := false }
return true
def check : GoalM Bool := do profileitM Exception "grind ac" ( getOptions) do
if ( checkMaxSteps) then return false

View File

@@ -32,10 +32,10 @@ inductive EqCnstrProof where
| swap (c : EqCnstr)
| simp_exact (lhs : Bool) (c₁ : EqCnstr) (c₂ : EqCnstr)
| simp_ac (lhs : Bool) (s : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr)
| superpose_ac (s : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr)
| simp_suffix (lhs : Bool) (s : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr)
| simp_prefix (lhs : Bool) (s : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr)
| simp_middle (lhs : Bool) (s₁ s₂ : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr)
| superpose_ac (s : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr)
| superpose_prefix (s₁ s₂ : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr)
end
@@ -61,6 +61,7 @@ inductive DiseqCnstrProof where
| core (a b : Expr) (ea eb : AC.Expr)
| erase_dup (c : DiseqCnstr)
| erase0 (c : DiseqCnstr)
| simp_exact (lhs : Bool) (c₁ : DiseqCnstr) (c₂ : EqCnstr)
| simp_ac (lhs : Bool) (s : AC.Seq) (c₁ : DiseqCnstr) (c₂ : EqCnstr)
| simp_suffix (lhs : Bool) (s : AC.Seq) (c₁ : DiseqCnstr) (c₂ : EqCnstr)
| simp_prefix (lhs : Bool) (s : AC.Seq) (c₁ : DiseqCnstr) (c₂ : EqCnstr)

View File

@@ -12,6 +12,7 @@ import Lean.Meta.Tactic.Grind.EMatch
import Lean.Meta.Tactic.Grind.Arith
import Lean.Meta.Tactic.Grind.Lookahead
import Lean.Meta.Tactic.Grind.Intro
import Lean.Meta.Tactic.Grind.Check
public section
namespace Lean.Meta.Grind
def tryFallback : GoalM Bool := do
@@ -50,7 +51,7 @@ where
intros gen
else
break
if ( assertAll <||> Arith.check <||> ematch <||> lookahead <||> splitNext <||> Arith.Cutsat.mbtc
if ( assertAll <||> check <||> ematch <||> lookahead <||> splitNext <||> Arith.Cutsat.mbtc
<||> Arith.Linear.mbtc <||> tryFallback) then
continue
return some ( getGoal) -- failed