Compare commits

...

1 Commits

Author SHA1 Message Date
Leonardo de Moura
aaeb1bbf61 feat: simplify equations in grind AC module
This PR adds support for equality simplification helper functions to
the `grind` AC module.
2025-08-27 20:28:25 -07:00
6 changed files with 237 additions and 2 deletions

View File

@@ -114,6 +114,7 @@ structure Config where
When `true` (default: `true`), uses procedure for handling associative (and commutative) operators.
-/
ac := true
acSteps := 1000
/--
Maximum exponent eagerly evaluated while computing bounds for `ToInt` and
the characteristic of a ring.

View File

@@ -22,4 +22,6 @@ builtin_initialize registerTraceClass `grind.ac.assert
builtin_initialize registerTraceClass `grind.ac.internalize
builtin_initialize registerTraceClass `grind.debug.ac.op
builtin_initialize registerTraceClass `grind.debug.ac.basis
builtin_initialize registerTraceClass `grind.debug.ac.simp
end Lean

View File

@@ -8,9 +8,11 @@ prelude
public import Lean.Meta.Tactic.Grind.AC.Util
import Lean.Meta.Tactic.Grind.AC.DenoteExpr
import Lean.Meta.Tactic.Grind.AC.Proof
import Lean.Meta.Tactic.Grind.AC.Seq
public section
namespace Lean.Meta.Grind.AC
open Lean.Grind
/-- For each structure `s` s.t. `a` and `b` are elements of, execute `k` -/
@[specialize] def withExprs (a b : Expr) (k : ACM Unit) : GoalM Unit := do
let ids₁ getTermOpIds a
@@ -63,10 +65,205 @@ def DiseqCnstr.assert (c : DiseqCnstr) : ACM Unit := do
else
saveDiseq c
def mkEqCnstr (lhs rhs : AC.Seq) (h : EqCnstrProof) : ACM EqCnstr := do
let id := ( getStruct).nextId
modifyStruct fun s => { s with nextId := s.nextId + 1 }
return { lhs, rhs, h, id }
def EqCnstr.eraseDup (c : EqCnstr) : ACM EqCnstr := do
unless ( isIdempotent) do return c
let lhs := c.lhs.eraseDup
let rhs := c.rhs.eraseDup
if c.lhs == lhs && c.rhs == rhs then
return c
else
return { c with lhs, rhs, h := .erase_dup c }
def EqCnstr.orient (c : EqCnstr) : EqCnstr :=
if compare c.rhs c.lhs == .gt then
{ c with lhs := c.rhs, rhs := c.lhs, h := .swap c }
else
c
def EqCnstr.superposeWith (c : EqCnstr) : ACM Unit := do
trace[Meta.debug] "superpose {← c.denoteExpr}"
return () -- TODO
/--
Returns `some (c, r)`, where `c` is an equation from the basis whose LHS simplifies `s` when
`(← isCommutative)` is `false`
-/
private def _root_.Lean.Grind.AC.Seq.findSimpA? (s : AC.Seq) : ACM (Option (EqCnstr × AC.SubseqResult)) := do
for c in ( getStruct).basis do
let r := c.lhs.subseq s
unless r matches .false do return some (c, r)
return none
/--
Returns `some (c, r)`, where `c` is an equation from the basis whose LHS simplifies `s` when
`(← isCommutative)` is `true`
-/
private def _root_.Lean.Grind.AC.Seq.findSimpAC? (s : AC.Seq) : ACM (Option (EqCnstr × AC.SubsetResult)) := do
for c in ( getStruct).basis do
let r := c.lhs.subset s
unless r matches .false do return some (c, r)
return none
private def simplifyLhsWithA (c : EqCnstr) (c' : EqCnstr) (r : AC.SubseqResult) : EqCnstr :=
match r with
| .exact => { c with lhs := c'.rhs, h := .simp_exact (lhs := true) c c' }
| .prefix s => { c with lhs := c'.rhs ++ s, h := .simp_prefix (lhs := true) s c c' }
| .suffix s => { c with lhs := s ++ c'.rhs, h := .simp_suffix (lhs := true) s c c' }
| .middle p s => { c with lhs := p ++ c'.rhs ++ s, h := .simp_middle (lhs := true) p s c c' }
| .false => c
private def simplifyRhsWithA (c : EqCnstr) (c' : EqCnstr) (r : AC.SubseqResult) : EqCnstr :=
match r with
| .exact => { c with rhs := c'.rhs, h := .simp_exact (lhs := false) c c' }
| .prefix s => { c with rhs := c'.rhs ++ s, h := .simp_prefix (lhs := false) s c c' }
| .suffix s => { c with rhs := s ++ c'.rhs, h := .simp_suffix (lhs := false) s c c' }
| .middle p s => { c with rhs := p ++ c'.rhs ++ s, h := .simp_middle (lhs := false) p s c c' }
| .false => c
/-- Simplifies `c` using the basis when `(← isCommutative)` is `false` -/
private def EqCnstr.simplifyA (c : EqCnstr) : ACM EqCnstr := do
let mut c c.eraseDup
repeat
incSteps
if ( checkMaxSteps) then return c
if let some (c', r) c.lhs.findSimpA? then
c := simplifyLhsWithA c c' r
c c.eraseDup
else if let some (c', r) c.rhs.findSimpA? then
c := simplifyRhsWithA c c' r
c c.eraseDup
else
trace[grind.debug.ac.simplify] "{← c.denoteExpr}"
return c
return c
/--
Simplifies `c` (lhs and rhs) using `c'`, returns `some c` if simplified.
Case `(← isCommutative) == false`
-/
private def simplifyWithA' (c : EqCnstr) (c' : EqCnstr) : Option EqCnstr := do
let r₁ := c'.lhs.subseq c.lhs
let c := simplifyLhsWithA c c' r₁
let r₂ := c'.lhs.subseq c.rhs
let c := simplifyRhsWithA c c' r₂
if r₁ matches .false && r₂ matches .false then none else some c
private def simplifyLhsWithAC (c : EqCnstr) (c' : EqCnstr) (r : AC.SubsetResult) : EqCnstr :=
match r with
| .exact => { c with lhs := c'.rhs, h := .simp_exact (lhs := true) c c' }
| .strict s => { c with lhs := c'.rhs.union s, h := .simp_ac (lhs := true) s c c' }
| .false => c
private def simplifyRhsWithAC (c : EqCnstr) (c' : EqCnstr) (r : AC.SubsetResult) : EqCnstr :=
match r with
| .exact => { c with rhs := c'.rhs, h := .simp_exact (lhs := false) c c' }
| .strict s => { c with rhs := c'.rhs.union s, h := .simp_ac (lhs := false) s c c' }
| .false => c
/--
Simplifies `c` (lhs and rhs) using `c'`, returns `some c` if simplified.
Case `(← isCommutative) == true`
-/
private def simplifyWithAC' (c : EqCnstr) (c' : EqCnstr) : Option EqCnstr := do
let r₁ := c'.lhs.subset c.lhs
let c := simplifyLhsWithAC c c' r₁
let r₂ := c'.lhs.subset c.rhs
let c := simplifyRhsWithAC c c' r₂
if r₁ matches .false && r₂ matches .false then none else some c
/-- Simplify `c` using the basis when `(← isCommutative)` is `true` -/
private def EqCnstr.simplifyAC (c : EqCnstr) : ACM EqCnstr := do
let mut c c.eraseDup
repeat
incSteps
if ( checkMaxSteps) then return c
if let some (c', r) c.lhs.findSimpAC? then
c := simplifyLhsWithAC c c' r
c c.eraseDup
else if let some (c', r) c.rhs.findSimpAC? then
c := simplifyRhsWithAC c c' r
c c.eraseDup
else
trace[grind.debug.ac.simplify] "{← c.denoteExpr}"
return c
return c
/--
Simplifies `c` (lhs and rhs) using `c'`, returns `some c` if simplified.
-/
private def EqCnstr.simplifyWith (c : EqCnstr) (c' : EqCnstr) : ACM (Option EqCnstr) := do
incSteps
if ( isCommutative) then
return simplifyWithAC' c c'
else
return simplifyWithA' c c'
/-- Simplify `c` using the basis -/
private def EqCnstr.simplify (c : EqCnstr) : ACM EqCnstr := do
if ( isCommutative) then c.simplifyAC else c.simplifyA
def EqCnstr.addToQueue (c : EqCnstr) : ACM Unit := do
modifyStruct fun s => { s with queue := s.queue.insert c }
def EqCnstr.simplifyBasis (c : EqCnstr) : ACM Unit := do
let rec go (basis : List EqCnstr) (acc : List EqCnstr) : ACM (List EqCnstr) := do
match basis with
| [] => return acc.reverse
| c' :: basis =>
if let some c' c'.simplifyWith c then
c'.addToQueue
go basis acc
else
go basis (c' :: acc)
let basis go ( getStruct).basis []
modifyStruct fun s => { s with basis }
private def addSorted (c : EqCnstr) : List EqCnstr List EqCnstr
| [] => [c]
| c' :: cs =>
if c.lhs.length c'.lhs.length then
c :: c' :: cs
else
c' :: addSorted c cs
def EqCnstr.addToBasisCore (c : EqCnstr) : ACM Unit := do
trace[grind.debug.ac.basis] "{← c.denoteExpr}"
modifyStruct fun s => { s with
basis := addSorted c s.basis
recheck := true
}
def EqCnstr.addToBasisAfterSimp (c : EqCnstr) : ACM Unit := do
c.simplifyBasis
c.superposeWith
trace_goal[grind.ac.assert.basis] "{← c.denoteExpr}"
addToBasisCore c
def EqCnstr.assert (c : EqCnstr) : ACM Unit := do
let c c.simplify
if c.lhs == c.rhs then
return ()
else
let c := c.orient
trace[grind.ac.assert] "{← c.denoteExpr}"
if c.lhs.isVar then
c.addToBasisAfterSimp
else
c.addToQueue
@[export lean_process_ac_eq]
def processNewEqImpl (a b : Expr) : GoalM Unit := withExprs a b do
trace[grind.ac.assert] "{a} = {b}"
-- TODO
let ea asACExpr a
let lhs norm ea
let eb asACExpr b
let rhs norm eb
let c mkEqCnstr lhs rhs (.core a b ea eb)
c.assert
@[export lean_process_ac_diseq]
def processNewDiseqImpl (a b : Expr) : GoalM Unit := withExprs a b do

View File

@@ -7,6 +7,7 @@ module
prelude
public import Init.Core
public import Init.Grind.AC
public import Init.Data.Ord
public section
namespace Lean.Grind.AC
@@ -29,6 +30,26 @@ where
| .var x, acc => .cons x acc
| .cons x s, acc => go s (.cons x acc)
protected def Seq.compare (s₁ s₂ : Seq) : Ordering :=
let len₁ := s₁.length
let len₂ := s₂.length
if len₁ < len₂ then
.lt
else if len₁ > len₂ then
.gt
else
lex s₁ s₂
where
lex (s₁ s₂ : Seq) : Ordering :=
match s₁, s₂ with
| .var x, .var y => compare x y
| .cons .., .var _ => .gt
| .var _, .cons .. => .lt
| .cons x s₁, .cons y s₂ => compare x y |>.then (lex s₁ s₂)
instance : Ord Seq where
compare := Seq.compare
instance : Append Seq where
append := Seq.concat

View File

@@ -28,6 +28,8 @@ structure EqCnstr where
inductive EqCnstrProof where
| core (a b : Expr) (ea eb : AC.Expr)
| erase_dup (c : EqCnstr)
| swap (c : EqCnstr)
| simp_exact (lhs : Bool) (c₁ : EqCnstr) (c₂ : EqCnstr)
| simp_ac (lhs : Bool) (s : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr)
| superpose_ac (s : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr)
| simp_suffix (lhs : Bool) (s : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr)
@@ -91,6 +93,11 @@ structure Struct where
basis : List EqCnstr := {}
/-- Disequalities. -/
diseqs : PArray DiseqCnstr := {}
/--
If `recheck` is `true`, then new equalities have been added to the basis since we checked
disequalities and implied equalities.
-/
recheck : Bool := false
deriving Inhabited
/-- State for all associative operators detected by `grind`. -/
@@ -108,6 +115,7 @@ structure State where
Mapping from expressions/terms to their structure ids.
Recall that term may be the argument of different operators. -/
exprToOpIds : PHashMap ExprPtr (List Nat) := {}
steps := 0
deriving Inhabited
end Lean.Meta.Grind.AC

View File

@@ -19,6 +19,12 @@ def get' : GoalM State := do
@[inline] def modify' (f : State State) : GoalM Unit := do
modify fun s => { s with ac := f s.ac }
def checkMaxSteps : GoalM Bool := do
return ( get').steps >= ( getConfig).acSteps
def incSteps : GoalM Unit := do
modify' fun s => { s with steps := s.steps + 1 }
structure ACM.Context where
opId : Nat