feat: ac normalization in grind (#10146)

This PR implements the basic infrastructure for the new procedure
handling AC operators in grind. It already supports normalizing
disequalities. Future PRs will add support for simplification using
equalities, and computing critical pairs. Examples:
```lean
example {α : Sort u} (op : α → α → α) [Std.Associative op] (a b c : α)
    : op a (op b c) = op (op a b) c := by
  grind only

example {α : Sort u} (op : α → α → α) (u : α) [Std.Associative op] [Std.LawfulIdentity op u] (a b c : α)
    : op a (op b c) = op (op a b) (op c u) := by
  grind only

example {α : Type u} (op : α → α → α) (u : α) [Std.Associative op] [Std.Commutative op] 
    [Std.IdempotentOp op] [Std.LawfulIdentity op u] (a b c : α)
    : op (op a a) (op b c) = op (op (op b a) (op (op u b) b)) c := by
  grind only

example {α} (as bs cs : List α) : as ++ (bs ++ cs) = ((as ++ []) ++ bs) ++ (cs ++ []) := by
  grind only

example (a b c : Nat) : max a (max b c) = max (max b 0) (max a c) ∧ min a b = min b a := by
  grind only [cases Or]
```
This commit is contained in:
Leonardo de Moura
2025-08-26 20:28:30 -07:00
committed by GitHub
parent db3fb47109
commit aaec0f584c
28 changed files with 729 additions and 98 deletions

View File

@@ -15,19 +15,22 @@ public import Init.Data.Bool
namespace Lean.Grind.AC
abbrev Var := Nat
structure Context (α : Type u) where
vars : RArray α
structure Context (α : Sort u) where
vars : RArray (PLift α)
op : α α α
inductive Expr where
| var (x : Nat)
| var (x : Var)
| op (lhs rhs : Expr)
deriving Inhabited, Repr, BEq
noncomputable def Expr.denote {α} (ctx : Context α) (e : Expr) : α :=
Expr.rec (fun x => ctx.vars.get x) (fun _ _ ih₁ ih₂ => ctx.op ih₁ ih₂) e
noncomputable def Var.denote {α : Sort u} (ctx : Context α) (x : Var) : α :=
PLift.rec (fun x => x) (ctx.vars.get x)
theorem Expr.denote_var {α} (ctx : Context α) (x : Var) : (Expr.var x).denote ctx = ctx.vars.get x := rfl
noncomputable def Expr.denote {α} (ctx : Context α) (e : Expr) : α :=
Expr.rec (fun x => x.denote ctx) (fun _ _ ih₁ ih₂ => ctx.op ih₁ ih₂) e
theorem Expr.denote_var {α} (ctx : Context α) (x : Var) : (Expr.var x).denote ctx = x.denote ctx := rfl
theorem Expr.denote_op {α} (ctx : Context α) (a b : Expr) : (Expr.op a b).denote ctx = ctx.op (a.denote ctx) (b.denote ctx) := rfl
attribute [local simp] Expr.denote_var Expr.denote_op
@@ -59,10 +62,10 @@ instance : LawfulBEq Seq where
rfl := by intro a; induction a <;> simp! [BEq.beq]; assumption
noncomputable def Seq.denote {α} (ctx : Context α) (s : Seq) : α :=
Seq.rec (fun x => ctx.vars.get x) (fun x _ ih => ctx.op (ctx.vars.get x) ih) s
Seq.rec (fun x => x.denote ctx) (fun x _ ih => ctx.op (x.denote ctx) ih) s
theorem Seq.denote_var {α} (ctx : Context α) (x : Var) : (Seq.var x).denote ctx = ctx.vars.get x := rfl
theorem Seq.denote_op {α} (ctx : Context α) (x : Var) (s : Seq) : (Seq.cons x s).denote ctx = ctx.op (ctx.vars.get x) (s.denote ctx) := rfl
theorem Seq.denote_var {α} (ctx : Context α) (x : Var) : (Seq.var x).denote ctx = x.denote ctx := rfl
theorem Seq.denote_op {α} (ctx : Context α) (x : Var) (s : Seq) : (Seq.cons x s).denote ctx = ctx.op (x.denote ctx) (s.denote ctx) := rfl
attribute [local simp] Seq.denote_var Seq.denote_op
@@ -152,7 +155,7 @@ theorem Seq.erase0_k_eq_erase0 (s : Seq) : s.erase0_k = s.erase0 := by
attribute [local simp] Seq.erase0_k_eq_erase0
theorem Seq.denote_erase0 {α} (ctx : Context α) {inst : Std.LawfulIdentity ctx.op (ctx.vars.get 0)} (s : Seq)
theorem Seq.denote_erase0 {α} (ctx : Context α) {inst : Std.LawfulIdentity ctx.op (Var.denote ctx 0)} (s : Seq)
: s.erase0.denote ctx = s.denote ctx := by
fun_induction erase0 s <;> simp_all +zetaDelta
next => rw [Std.LawfulLeftIdentity.left_id (self := inst.toLawfulLeftIdentity)]
@@ -179,12 +182,12 @@ theorem Seq.insert_k_eq_insert (x : Var) (s : Seq) : insert_k x s = insert x s :
attribute [local simp] Seq.insert_k_eq_insert
theorem Seq.denote_insert {α} (ctx : Context α) {inst₁ : Std.Associative ctx.op} {inst₂ : Std.Commutative ctx.op} (x : Var) (s : Seq)
: (s.insert x).denote ctx = ctx.op (ctx.vars.get x) (s.denote ctx) := by
: (s.insert x).denote ctx = ctx.op (x.denote ctx) (s.denote ctx) := by
fun_induction insert x s <;> simp
next => rw [Std.Commutative.comm (self := inst₂)]
next y s h ih =>
simp [ih, Std.Associative.assoc (self := inst₁)]
rw [Std.Commutative.comm (self := inst₂) (ctx.vars.get x)]
rw [Std.Commutative.comm (self := inst₂) (x.denote ctx)]
attribute [local simp] Seq.denote_insert
@@ -208,7 +211,7 @@ theorem Seq.denote_sort' {α} (ctx : Context α) {inst₁ : Std.Associative ctx.
fun_induction sort' s acc <;> simp
next x s ih =>
simp [ih, Std.Associative.assoc (self := inst₁)]
rw [Std.Commutative.comm (self := inst₂) (ctx.vars.get x) (s.denote ctx)]
rw [Std.Commutative.comm (self := inst₂) (x.denote ctx) (s.denote ctx)]
attribute [local simp] Seq.denote_sort'
@@ -387,11 +390,11 @@ theorem Seq.denote_combineFuel {α} (ctx : Context α) {inst₁ : Std.Associativ
next ih => simp [ih, Std.Associative.assoc (self := inst₁)]
next x₁ s₁ x₂ s₂ h ih =>
simp [ih]
rw [ Std.Associative.assoc (self := inst₁), Std.Associative.assoc (self := inst₁), Std.Commutative.comm (self := inst₂) (ctx.vars.get x)]
rw [Std.Associative.assoc (self := inst₁), Std.Associative.assoc (self := inst₁), Std.Associative.assoc (self := inst₁) (ctx.vars.get x)]
apply congrArg (ctx.op (ctx.vars.get x))
rw [ Std.Associative.assoc (self := inst₁), Std.Associative.assoc (self := inst₁), Std.Commutative.comm (self := inst₂) (x₂.denote ctx)]
rw [Std.Associative.assoc (self := inst₁), Std.Associative.assoc (self := inst₁), Std.Associative.assoc (self := inst₁) (x₁.denote ctx)]
apply congrArg (ctx.op (x₁.denote ctx))
rw [ Std.Associative.assoc (self := inst₁), Std.Associative.assoc (self := inst₁) (s₁.denote ctx)]
rw [Std.Commutative.comm (self := inst₂) (ctx.vars.get x)]
rw [Std.Commutative.comm (self := inst₂) (x₂.denote ctx)]
attribute [local simp] Seq.denote_combineFuel
@@ -446,54 +449,65 @@ theorem superpose_ac {α} (ctx : Context α) {inst₁ : Std.Associative ctx.op}
apply congrArg (ctx.op (c.denote ctx))
rw [Std.Commutative.comm (self := inst₂) (b.denote ctx)]
noncomputable def norm_a_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool :=
noncomputable def eq_norm_a_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool :=
lhs.toSeq.beq' lhs' |>.and' (rhs.toSeq.beq' rhs')
theorem norm_a {α} (ctx : Context α) {_ : Std.Associative ctx.op} (lhs rhs : Expr) (lhs' rhs' : Seq)
: norm_a_cert lhs rhs lhs' rhs' lhs.denote ctx = rhs.denote ctx lhs'.denote ctx = rhs'.denote ctx := by
simp [norm_a_cert]; intro _ _; subst lhs' rhs'; simp
theorem eq_norm_a {α} (ctx : Context α) {_ : Std.Associative ctx.op} (lhs rhs : Expr) (lhs' rhs' : Seq)
: eq_norm_a_cert lhs rhs lhs' rhs' lhs.denote ctx = rhs.denote ctx lhs'.denote ctx = rhs'.denote ctx := by
simp [eq_norm_a_cert]; intro _ _; subst lhs' rhs'; simp
noncomputable def norm_ac_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool :=
noncomputable def eq_norm_ac_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool :=
lhs.toSeq.sort.beq' lhs' |>.and' (rhs.toSeq.sort.beq' rhs')
theorem norm_ac {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.Commutative ctx.op} (lhs rhs : Expr) (lhs' rhs' : Seq)
: norm_ac_cert lhs rhs lhs' rhs' lhs.denote ctx = rhs.denote ctx lhs'.denote ctx = rhs'.denote ctx := by
simp [norm_ac_cert]; intro _ _; subst lhs' rhs'; simp
theorem eq_norm_ac {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.Commutative ctx.op} (lhs rhs : Expr) (lhs' rhs' : Seq)
: eq_norm_ac_cert lhs rhs lhs' rhs' lhs.denote ctx = rhs.denote ctx lhs'.denote ctx = rhs'.denote ctx := by
simp [eq_norm_ac_cert]; intro _ _; subst lhs' rhs'; simp
noncomputable def norm_aci_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool :=
lhs.toSeq.erase0.sort.beq' lhs' |>.and' (rhs.toSeq.erase0.sort.beq' rhs')
theorem norm_aci {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.Commutative ctx.op} {_ : Std.LawfulIdentity ctx.op (ctx.vars.get 0)}
(lhs rhs : Expr) (lhs' rhs' : Seq) : norm_aci_cert lhs rhs lhs' rhs' lhs.denote ctx = rhs.denote ctx lhs'.denote ctx = rhs'.denote ctx := by
simp [norm_aci_cert]; intro _ _; subst lhs' rhs'; simp
noncomputable def norm_ai_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool :=
noncomputable def eq_norm_ai_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool :=
lhs.toSeq.erase0.beq' lhs' |>.and' (rhs.toSeq.erase0.beq' rhs')
theorem norm_ai {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.LawfulIdentity ctx.op (ctx.vars.get 0)}
(lhs rhs : Expr) (lhs' rhs' : Seq) : norm_ai_cert lhs rhs lhs' rhs' lhs.denote ctx = rhs.denote ctx lhs'.denote ctx = rhs'.denote ctx := by
simp [norm_ai_cert]; intro _ _; subst lhs' rhs'; simp
theorem eq_norm_ai {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.LawfulIdentity ctx.op (Var.denote ctx 0)}
(lhs rhs : Expr) (lhs' rhs' : Seq) : eq_norm_ai_cert lhs rhs lhs' rhs' lhs.denote ctx = rhs.denote ctx lhs'.denote ctx = rhs'.denote ctx := by
simp [eq_norm_ai_cert]; intro _ _; subst lhs' rhs'; simp
noncomputable def norm_acip_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool :=
lhs.toSeq.erase0.sort.eraseDup.beq' lhs' |>.and' (rhs.toSeq.erase0.sort.eraseDup.beq' rhs')
noncomputable def eq_norm_aci_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool :=
lhs.toSeq.erase0.sort.beq' lhs' |>.and' (rhs.toSeq.erase0.sort.beq' rhs')
theorem norm_acip {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.Commutative ctx.op}
{_ : Std.LawfulIdentity ctx.op (ctx.vars.get 0)} {_ : Std.IdempotentOp ctx.op}
(lhs rhs : Expr) (lhs' rhs' : Seq) : norm_acip_cert lhs rhs lhs' rhs' lhs.denote ctx = rhs.denote ctx lhs'.denote ctx = rhs'.denote ctx := by
simp [norm_acip_cert]; intro _ _; subst lhs' rhs'; simp
theorem eq_norm_aci {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.Commutative ctx.op} {_ : Std.LawfulIdentity ctx.op (Var.denote ctx 0)}
(lhs rhs : Expr) (lhs' rhs' : Seq) : eq_norm_aci_cert lhs rhs lhs' rhs' lhs.denote ctx = rhs.denote ctx lhs'.denote ctx = rhs'.denote ctx := by
simp [eq_norm_aci_cert]; intro _ _; subst lhs' rhs'; simp
noncomputable def norm_acp_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool :=
lhs.toSeq.sort.eraseDup.beq' lhs' |>.and' (rhs.toSeq.sort.eraseDup.beq' rhs')
theorem norm_acp {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.Commutative ctx.op} {_ : Std.IdempotentOp ctx.op}
(lhs rhs : Expr) (lhs' rhs' : Seq) : norm_acp_cert lhs rhs lhs' rhs' lhs.denote ctx = rhs.denote ctx lhs'.denote ctx = rhs'.denote ctx := by
simp [norm_acp_cert]; intro _ _; subst lhs' rhs'; simp
noncomputable def norm_dup_cert (lhs rhs lhs' rhs' : Seq) : Bool :=
noncomputable def eq_erase_dup_cert (lhs rhs lhs' rhs' : Seq) : Bool :=
lhs.eraseDup.beq' lhs' |>.and' (rhs.eraseDup.beq' rhs')
theorem norm_dup (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.IdempotentOp ctx.op}
(lhs rhs lhs' rhs' : Seq) : norm_dup_cert lhs rhs lhs' rhs' lhs.denote ctx = rhs.denote ctx lhs'.denote ctx = rhs'.denote ctx := by
simp [norm_dup_cert]; intro _ _; subst lhs' rhs'; simp
theorem eq_erase_dup {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.IdempotentOp ctx.op}
(lhs rhs lhs' rhs' : Seq) : eq_erase_dup_cert lhs rhs lhs' rhs' lhs.denote ctx = rhs.denote ctx lhs'.denote ctx = rhs'.denote ctx := by
simp [eq_erase_dup_cert]; intro _ _; subst lhs' rhs'; simp
theorem diseq_norm_a {α} (ctx : Context α) {_ : Std.Associative ctx.op} (lhs rhs : Expr) (lhs' rhs' : Seq)
: eq_norm_a_cert lhs rhs lhs' rhs' lhs.denote ctx rhs.denote ctx lhs'.denote ctx rhs'.denote ctx := by
simp [eq_norm_a_cert]; intro _ _; subst lhs' rhs'; simp
theorem diseq_norm_ac {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.Commutative ctx.op} (lhs rhs : Expr) (lhs' rhs' : Seq)
: eq_norm_ac_cert lhs rhs lhs' rhs' lhs.denote ctx rhs.denote ctx lhs'.denote ctx rhs'.denote ctx := by
simp [eq_norm_ac_cert]; intro _ _; subst lhs' rhs'; simp
theorem diseq_norm_ai {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.LawfulIdentity ctx.op (Var.denote ctx 0)}
(lhs rhs : Expr) (lhs' rhs' : Seq) : eq_norm_ai_cert lhs rhs lhs' rhs' lhs.denote ctx rhs.denote ctx lhs'.denote ctx rhs'.denote ctx := by
simp [eq_norm_ai_cert]; intro _ _; subst lhs' rhs'; simp
theorem diseq_norm_aci {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.Commutative ctx.op} {_ : Std.LawfulIdentity ctx.op (Var.denote ctx 0)}
(lhs rhs : Expr) (lhs' rhs' : Seq) : eq_norm_aci_cert lhs rhs lhs' rhs' lhs.denote ctx rhs.denote ctx lhs'.denote ctx rhs'.denote ctx := by
simp [eq_norm_aci_cert]; intro _ _; subst lhs' rhs'; simp
theorem diseq_erase_dup {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.IdempotentOp ctx.op}
(lhs rhs lhs' rhs' : Seq) : eq_erase_dup_cert lhs rhs lhs' rhs' lhs.denote ctx rhs.denote ctx lhs'.denote ctx rhs'.denote ctx := by
simp [eq_erase_dup_cert]; intro _ _; subst lhs' rhs'; simp
noncomputable def diseq_unsat_cert (lhs rhs : Seq) : Bool :=
lhs.beq' rhs
theorem diseq_unsat {α} (ctx : Context α) (lhs rhs : Seq) : diseq_unsat_cert lhs rhs lhs.denote ctx rhs.denote ctx False := by
simp [diseq_unsat_cert]; intro; subst lhs; simp
end Lean.Grind.AC

View File

@@ -37,6 +37,8 @@ public import Lean.Meta.Tactic.Grind.LawfulEqCmp
public import Lean.Meta.Tactic.Grind.ReflCmp
public import Lean.Meta.Tactic.Grind.SynthInstance
public import Lean.Meta.Tactic.Grind.AC
public import Lean.Meta.Tactic.Grind.VarRename
public import Lean.Meta.Tactic.Grind.ProofUtil
public section

View File

@@ -9,6 +9,12 @@ public import Lean.Meta.Tactic.Grind.AC.Types
public import Lean.Meta.Tactic.Grind.AC.Util
public import Lean.Meta.Tactic.Grind.AC.Var
public import Lean.Meta.Tactic.Grind.AC.Internalize
public import Lean.Meta.Tactic.Grind.AC.Eq
public import Lean.Meta.Tactic.Grind.AC.Seq
public import Lean.Meta.Tactic.Grind.AC.Proof
public import Lean.Meta.Tactic.Grind.AC.DenoteExpr
public import Lean.Meta.Tactic.Grind.AC.ToExpr
public import Lean.Meta.Tactic.Grind.AC.VarRename
public section
namespace Lean
builtin_initialize registerTraceClass `grind.ac

View File

@@ -0,0 +1,33 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Tactic.Grind.AC.Util
import Lean.Meta.AppBuilder
public section
namespace Lean.Meta.Grind.AC
open Lean.Grind
variable [Monad M] [MonadGetStruct M] [MonadError M]
def _root_.Lean.Grind.AC.Seq.denoteExpr (s : AC.Seq) : M Expr := do
match s with
| .var x => return ( getStruct).vars[x]!
| .cons x s => return mkApp2 ( getStruct).op ( getStruct).vars[x]! ( denoteExpr s)
def _root_.Lean.Grind.AC.Expr.denoteExpr (e : AC.Expr) : M Expr := do
match e with
| .var x => return ( getStruct).vars[x]!
| .op lhs rhs => return mkApp2 ( getStruct).op ( denoteExpr lhs) ( denoteExpr rhs)
def EqCnstr.denoteExpr (c : EqCnstr) : M Expr := do
let s getStruct
return mkApp3 (mkConst ``Eq [s.u]) s.type ( c.lhs.denoteExpr) ( c.rhs.denoteExpr)
def DiseqCnstr.denoteExpr (c : DiseqCnstr) : M Expr := do
let s getStruct
return mkApp3 (mkConst ``Ne [s.u]) s.type ( c.lhs.denoteExpr) ( c.rhs.denoteExpr)
end Lean.Meta.Grind.AC

View File

@@ -0,0 +1,79 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Tactic.Grind.AC.Util
import Lean.Meta.Tactic.Grind.AC.DenoteExpr
import Lean.Meta.Tactic.Grind.AC.Proof
public section
namespace Lean.Meta.Grind.AC
open Lean.Grind
/-- For each structure `s` s.t. `a` and `b` are elements of, execute `k` -/
@[specialize] def withExprs (a b : Expr) (k : ACM Unit) : GoalM Unit := do
let ids₁ getTermOpIds a
if ids₁.isEmpty then return ()
let ids₂ getTermOpIds b
go ids₁ ids₂
where
go : List Nat List Nat GoalM Unit
| [], _ => return ()
| _, [] => return ()
| id₁::ids₁, id₂::ids₂ => do
if id₁ == id₂ then
ACM.run id₁ k; go ids₁ ids₂
else if id₁ < id₂ then
go ids₁ (id₂::ids₂)
else
go (id₁::ids₁) ids₂
def asACExpr (e : Expr) : ACM AC.Expr := do
if let some e' := ( getStruct).denote.find? { expr := e } then
return e'
else
return .var (( getStruct).varMap.find? { expr := e }).get!
def norm (e : AC.Expr) : ACM AC.Seq := do
match ( isCommutative), ( hasNeutral) with
| true, true => return e.toSeq.erase0.sort
| true, false => return e.toSeq.sort
| false, true => return e.toSeq.erase0
| false, false => return e.toSeq
def saveDiseq (c : DiseqCnstr) : ACM Unit := do
modifyStruct fun s => { s with diseqs := s.diseqs.push c }
def DiseqCnstr.eraseDup (c : DiseqCnstr) : ACM DiseqCnstr := do
unless ( isIdempotent) do return c
let lhs := c.lhs.eraseDup
let rhs := c.rhs.eraseDup
if c.lhs == lhs && c.rhs == rhs then
return c
else
return { lhs, rhs, h := .erase_dup c }
def DiseqCnstr.assert (c : DiseqCnstr) : ACM Unit := do
let c c.eraseDup
-- TODO: simplify and check conflict
trace[grind.ac.assert] "{← c.denoteExpr}"
if c.lhs == c.rhs then
c.setUnsat
else
saveDiseq c
@[export lean_process_ac_eq]
def processNewEqImpl (a b : Expr) : GoalM Unit := withExprs a b do
trace[grind.ac.assert] "{a} = {b}"
-- TODO
@[export lean_process_ac_diseq]
def processNewDiseqImpl (a b : Expr) : GoalM Unit := withExprs a b do
let ea asACExpr a
let lhs norm ea
let eb asACExpr b
let rhs norm eb
{ lhs, rhs, h := .core a b ea eb : DiseqCnstr }.assert
end Lean.Meta.Grind.AC

View File

@@ -7,6 +7,7 @@ module
prelude
public import Lean.Meta.Tactic.Grind.Types
public import Lean.Meta.Tactic.Grind.AC.Util
import Lean.Meta.Tactic.Grind.AC.DenoteExpr
public section
namespace Lean.Meta.Grind.AC
@@ -15,6 +16,12 @@ private def isParentSameOpApp (parent? : Option Expr) (op : Expr) : GoalM Bool :
unless e.isApp && e.appFn!.isApp do return false
return isSameExpr e.appFn!.appFn! op
partial def reify (e : Expr) : ACM Grind.AC.Expr := do
if let some (a, b) isOp? e then
return .op ( reify a) ( reify b)
else
return .var ( mkVar e)
@[export lean_grind_ac_internalize]
def internalizeImpl (e : Expr) (parent? : Option Expr) : GoalM Unit := do
unless ( getConfig).ac do return ()
@@ -22,7 +29,12 @@ def internalizeImpl (e : Expr) (parent? : Option Expr) : GoalM Unit := do
let op := e.appFn!.appFn!
let some id getOpId? op | return ()
if ( isParentSameOpApp parent? op) then return ()
trace[grind.ac.internalize] "[{id}] {e}"
-- TODO: internalize `e`
ACM.run id do
if ( getStruct).denote.contains { expr := e } then return ()
let e' reify e
modifyStruct fun s => { s with denote := s.denote.insert { expr := e } e' }
trace[grind.ac.internalize] "[{id}] {← e'.denoteExpr}"
addTermOpId e
markAsACTerm e
end Lean.Meta.Grind.AC

View File

@@ -0,0 +1,142 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Tactic.Grind.AC.Util
import Lean.Meta.Tactic.Grind.Diseq
import Lean.Meta.Tactic.Grind.ProofUtil
import Lean.Meta.Tactic.Grind.AC.ToExpr
import Lean.Meta.Tactic.Grind.AC.VarRename
public section
namespace Lean.Meta.Grind.AC
open Lean.Grind
structure ProofM.State where
cache : Std.HashMap UInt64 Expr := {}
exprDecls : Std.HashMap AC.Expr Expr := {}
seqDecls : Std.HashMap AC.Seq Expr := {}
structure ProofM.Context where
ctx : Expr
abbrev ProofM := ReaderT ProofM.Context (StateRefT ProofM.State ACM)
/-- Returns a Lean expression representing the variable context used to construct `AC` proof steps. -/
private def getContext : ProofM Expr := do
return ( read).ctx
private abbrev caching (c : α) (k : ProofM Expr) : ProofM Expr := do
let addr := unsafe (ptrAddrUnsafe c).toUInt64 >>> 2
if let some h := ( get).cache[addr]? then
return h
else
let h k
modify fun s => { s with cache := s.cache.insert addr h }
return h
local macro "declare! " decls:ident a:ident : term =>
`(do if let some x := ( get).$decls[$a]? then
return x
let x := mkFVar ( mkFreshFVarId);
modify fun s => { s with $decls:ident := (s.$decls).insert $a x };
return x)
private def mkSeqDecl (s : AC.Seq) : ProofM Expr := do
declare! seqDecls s
private def mkExprDecl (e : AC.Expr) : ProofM Expr := do
declare! exprDecls e
private def getCommInst : ACM Expr := do
let some inst := ( getStruct).commInst? | throwError "`grind` internal error, `{(← getStruct).op}` is not commutative"
return inst
private def getIdempotentInst : ACM Expr := do
let some inst := ( getStruct).idempotentInst? | throwError "`grind` internal error, `{(← getStruct).op}` is not idempotent"
return inst
private def getNeutralInst : ACM Expr := do
let some inst := ( getStruct).neutralInst? | throwError "`grind` internal error, `{(← getStruct).op}` does not have a neutral element"
return inst
private def mkPrefix (declName : Name) : ProofM Expr := do
let s getStruct
return mkApp2 (mkConst declName [s.u]) s.type ( getContext)
private def mkAPrefix (declName : Name) : ProofM Expr := do
let s getStruct
return mkApp3 (mkConst declName [s.u]) s.type ( getContext) s.assocInst
private def mkACPrefix (declName : Name) : ProofM Expr := do
let s getStruct
return mkApp4 (mkConst declName [s.u]) s.type ( getContext) s.assocInst ( getCommInst)
private def mkAIPrefix (declName : Name) : ProofM Expr := do
let s getStruct
return mkApp4 (mkConst declName [s.u]) s.type ( getContext) s.assocInst ( getNeutralInst)
private def mkACIPrefix (declName : Name) : ProofM Expr := do
let s getStruct
return mkApp5 (mkConst declName [s.u]) s.type ( getContext) s.assocInst ( getCommInst) ( getNeutralInst)
private def mkDupPrefix (declName : Name) : ProofM Expr := do
let s getStruct
return mkApp4 (mkConst declName [s.u]) s.type ( getContext) s.assocInst ( getIdempotentInst)
private def toContextExpr (vars : Array Expr) : ACM Expr := do
let s getStruct
if h : 0 < vars.size then
RArray.toExpr (mkApp (mkConst ``PLift [s.u]) s.type) id (RArray.ofFn (vars[·]) h)
else unreachable!
private def mkContext (h : Expr) : ProofM Expr := do
let s getStruct
let mut usedVars :=
collectMapVars ( get).seqDecls (·.collectVars) >>
collectMapVars ( get).exprDecls (·.collectVars) >>
(if ( hasNeutral) then (collectVar 0) else id) <| {}
let vars' := usedVars.toArray
let varRename := mkVarRename vars'
let vars := ( getStruct).vars
let up := mkApp (mkConst ``PLift.up [s.u]) s.type
let vars := vars'.map fun x => mkApp up vars[x]!
let h := mkLetOfMap ( get).seqDecls h `p (mkConst ``Grind.AC.Seq) fun p => toExpr <| p.renameVars varRename
let h := mkLetOfMap ( get).exprDecls h `e (mkConst ``Grind.AC.Expr) fun e => toExpr <| e.renameVars varRename
let h := h.abstract #[( read).ctx]
if h.hasLooseBVars then
let ctxType := mkApp (mkConst ``AC.Context [s.u]) s.type
let ctxVal := mkApp3 (mkConst ``AC.Context.mk [s.u]) s.type ( toContextExpr vars) s.op
return .letE `ctx ctxType ctxVal h (nondep := false)
else
return h
private abbrev withProofContext (x : ProofM Expr) : ACM Expr := do
let ctx := mkFVar ( mkFreshFVarId)
go { ctx } |>.run' {}
where
go : ProofM Expr := do
mkContext ( x)
partial def DiseqCnstr.toExprProof (c : DiseqCnstr) : ProofM Expr := do caching c do
match c.h with
| .core a b lhs rhs =>
let h match ( isCommutative), ( hasNeutral) with
| false, false => mkAPrefix ``AC.diseq_norm_a
| false, true => mkAIPrefix ``AC.diseq_norm_ai
| true, false => mkACPrefix ``AC.diseq_norm_ac
| true, true => mkACIPrefix ``AC.diseq_norm_aci
return mkApp6 h ( mkExprDecl lhs) ( mkExprDecl rhs) ( mkSeqDecl c.lhs) ( mkSeqDecl c.rhs) eagerReflBoolTrue ( mkDiseqProof a b)
| .erase_dup c₁ =>
let h mkDupPrefix ``AC.diseq_erase_dup
return mkApp6 h ( mkSeqDecl c₁.lhs) ( mkSeqDecl c₁.rhs) ( mkSeqDecl c.lhs) ( mkSeqDecl c.rhs) eagerReflBoolTrue ( c₁.toExprProof)
| _ => throwError "NIY"
def DiseqCnstr.setUnsat (c : DiseqCnstr) : ACM Unit := do
let h withProofContext do
return mkApp4 ( mkPrefix ``AC.diseq_unsat) ( mkSeqDecl c.lhs) ( mkSeqDecl c.rhs) eagerReflBoolTrue ( c.toExprProof)
closeGoal h
end Lean.Meta.Grind.AC

View File

@@ -0,0 +1,17 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Init.Core
public import Init.Grind.AC
public section
namespace Lean.Grind.AC
def Seq.length : Seq Nat
| .var _ => 1
| .cons _ s => s.length + 1
end Lean.Grind.AC

View File

@@ -0,0 +1,34 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Init.Grind.AC
public import Lean.ToExpr
public section
namespace Lean.Meta.Grind.AC
open Lean.Grind
/-!
`ToExpr` instances for `AC` types.
-/
def ofSeq (m : AC.Seq) : Expr :=
match m with
| .var x => mkApp (mkConst ``AC.Seq.var) (toExpr x)
| .cons x s => mkApp2 (mkConst ``AC.Seq.cons) (toExpr x) (ofSeq s)
instance : ToExpr AC.Seq where
toExpr := ofSeq
toTypeExpr := mkConst ``AC.Seq
def ofExpr (m : AC.Expr) : Expr :=
match m with
| .var x => mkApp (mkConst ``AC.Expr.var) (toExpr x)
| .op lhs rhs => mkApp2 (mkConst ``AC.Expr.op) (ofExpr lhs) (ofExpr rhs)
instance : ToExpr AC.Expr where
toExpr := ofExpr
toTypeExpr := mkConst ``AC.Expr
end Lean.Meta.Grind.AC

View File

@@ -11,10 +11,57 @@ public import Std.Data.HashMap
public import Lean.Expr
public import Lean.Data.PersistentArray
public import Lean.Meta.Tactic.Grind.ExprPtr
import Lean.Meta.Tactic.Grind.AC.Seq
public section
namespace Lean.Meta.Grind.AC
open Lean.Grind.AC
open Lean.Grind
deriving instance Hashable for AC.Expr, AC.Seq
mutual
structure EqCnstr where
lhs : AC.Seq
rhs : AC.Seq
h : EqCnstrProof
id : Nat
inductive EqCnstrProof where
| core (a b : Expr) (ea eb : AC.Expr)
| erase_dup (c : EqCnstr)
| simp_ac (lhs : Bool) (s : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr)
| superpose_ac (s : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr)
| simp_suffix (lhs : Bool) (s : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr)
| simp_prefix (lhs : Bool) (s : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr)
| simp_middle (lhs : Bool) (s₁ s₂ : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr)
| superpose_prefix (s₁ s₂ : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr)
end
instance : Inhabited EqCnstrProof where
default := .core default default default default
instance : Inhabited EqCnstr where
default := { lhs := default, rhs := default, h := default, id := 0 }
protected def EqCnstr.compare (c₁ c₂ : EqCnstr) : Ordering :=
(compare c₁.lhs.length c₂.lhs.length) |>.then
(compare c₁.id c₂.id)
abbrev Queue : Type := Std.TreeSet EqCnstr EqCnstr.compare
mutual
structure DiseqCnstr where
lhs : AC.Seq
rhs : AC.Seq
h : DiseqCnstrProof
inductive DiseqCnstrProof where
| core (a b : Expr) (ea eb : AC.Expr)
| erase_dup (c : DiseqCnstr)
| simp_ac (lhs : Bool) (s : AC.Seq) (c₁ : DiseqCnstr) (c₂ : EqCnstr)
| simp_suffix (lhs : Bool) (s : AC.Seq) (c₁ : DiseqCnstr) (c₂ : EqCnstr)
| simp_prefix (lhs : Bool) (s : AC.Seq) (c₁ : DiseqCnstr) (c₂ : EqCnstr)
| simp_middle (lhs : Bool) (s₁ s₂ : AC.Seq) (c₁ : DiseqCnstr) (c₂ : EqCnstr)
end
structure Struct where
id : Nat
@@ -27,13 +74,23 @@ structure Struct where
idempotentInst? : Option Expr
commInst? : Option Expr
neutralInst? : Option Expr
/-- Next unique id for `EqCnstr`s. -/
nextId : Nat := 0
/--
Mapping from variables to their denotations.
Remark each variable can be in only one ring.
-/
vars : PArray Expr := {}
/-- Mapping from `Expr` to a variable representing it. -/
varMap : PHashMap ExprPtr Var := {}
varMap : PHashMap ExprPtr AC.Var := {}
/-- Mapping from Lean expressions to their representations as `AC.Expr` -/
denote : PHashMap ExprPtr AC.Expr := {}
/-- Equations to process. -/
queue : Queue := {}
/-- Processed equations. -/
basis : List EqCnstr := {}
/-- Disequalities. -/
diseqs : PArray DiseqCnstr := {}
deriving Inhabited
/-- State for all associative operators detected by `grind`. -/
@@ -47,8 +104,10 @@ structure State where
Mapping from operators to its "operator id". We cache failures using `none`.
`opIdOf[op]` is `some id`, then `id < structs.size`. -/
opIdOf : PHashMap ExprPtr (Option Nat) := {}
-- Remark: a term may be argument of different associative operators.
-- TODO: add mappings
/--
Mapping from expressions/terms to their structure ids.
Recall that term may be the argument of different operators. -/
exprToOpIds : PHashMap ExprPtr (List Nat) := {}
deriving Inhabited
end Lean.Meta.Grind.AC

View File

@@ -10,8 +10,8 @@ public import Lean.Meta.Tactic.Grind.ProveEq
public import Lean.Meta.Tactic.Grind.SynthInstance
public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingId
public section
namespace Lean.Meta.Grind.AC
open Lean.Grind
def get' : GoalM State := do
return ( get).ac
@@ -50,6 +50,10 @@ protected def ACM.getStruct : ACM Struct := do
instance : MonadGetStruct ACM where
getStruct := ACM.getStruct
def modifyStruct (f : Struct Struct) : ACM Unit := do
let opId getOpId
modify' fun s => { s with structs := s.structs.modify opId f }
def getOp : ACM Expr :=
return ( getStruct).op
@@ -71,6 +75,42 @@ private def isArithOpInOtherModules (op : Expr) (f : Expr) : GoalM Bool := do
if ( Arith.CommRing.getSemiringId? α).isSome then return true
return false
def getTermOpIds (e : Expr) : GoalM (List Nat) := do
return ( get').exprToOpIds.find? { expr := e } |>.getD []
private def insertOpId (m : PHashMap ExprPtr (List Nat)) (e : Expr) (opId : Nat) : PHashMap ExprPtr (List Nat) :=
let ids := if let some ids := m.find? { expr := e } then
go ids
else
[opId]
m.insert { expr := e } ids
where
go : List Nat List Nat
| [] => [opId]
| id::ids => if opId < id then
opId :: id :: ids
else if opId == id then
opId :: ids
else
id :: go ids
def addTermOpId (e : Expr) : ACM Unit := do
let opId getOpId
modify' fun s => { s with exprToOpIds := insertOpId s.exprToOpIds e opId }
def mkVar (e : Expr) : ACM AC.Var := do
let s getStruct
if let some var := s.varMap.find? { expr := e } then
return var
let var : AC.Var := s.vars.size
modifyStruct fun s => { s with
vars := s.vars.push e
varMap := s.varMap.insert { expr := e } var
}
addTermOpId e
markAsACTerm e
return var
def getOpId? (op : Expr) : GoalM (Option Nat) := do
if let some id? := ( get').opIdOf.find? { expr := op } then
return id?
@@ -98,7 +138,7 @@ where
let idempotentInst? synthInstance? idempotentType
let (neutralInst?, neutral?) do
let neutral mkFreshExprMVar α
let identityType := mkApp3 (mkConst ``Std.Identity [u]) α op neutral
let identityType := mkApp3 (mkConst ``Std.LawfulIdentity [u]) α op neutral
if let some identityInst synthInstance? identityType then
let neutral instantiateExprMVars neutral
let neutral preprocessLight neutral
@@ -112,8 +152,24 @@ where
id, type := α, u, op, neutral?, assocInst, commInst?,
idempotentInst?, neutralInst?
}}
-- TODO: neutral element must be variable 0
trace[grind.debug.ac.op] "{op}, comm: {commInst?.isSome}, idempotent: {idempotentInst?.isSome}, neutral?: {neutral?}"
if let some neutral := neutral? then ACM.run id do
-- Create neutral variable to ensure it is variable 0
discard <| mkVar neutral
return some id
def isOp? (e : Expr) : ACM (Option (Expr × Expr)) := do
unless e.isApp && e.appFn!.isApp do return none
unless isSameExpr e.appFn!.appFn! ( getOp) do return none
return some (e.appFn!.appArg!, e.appArg!)
def isCommutative : ACM Bool :=
return ( getStruct).commInst?.isSome
def hasNeutral : ACM Bool :=
return ( getStruct).neutralInst?.isSome
def isIdempotent : ACM Bool :=
return ( getStruct).idempotentInst?.isSome
end Lean.Meta.Grind.AC

View File

@@ -0,0 +1,33 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Init.Grind.AC
public import Lean.Meta.Tactic.Grind.VarRename
namespace Lean.Grind.AC
open Lean.Meta.Grind
public def Seq.renameVars (s : Seq) (f : VarRename) : Seq :=
match s with
| .var x => .var (f x)
| .cons x s => .cons (f x) (renameVars s f)
public def Expr.renameVars (e : Expr) (f : VarRename) : Expr :=
match e with
| .var x => .var (f x)
| .op a b => .op (renameVars a f) (renameVars b f)
public def Seq.collectVars (s : Seq) : VarCollector :=
match s with
| .var x => collectVar x
| .cons x s => collectVar x >> s.collectVars
public def Expr.collectVars (e : Expr) : VarCollector :=
match e with
| .var x => collectVar x
| .op a b => a.collectVars >> b.collectVars
end Lean.Grind.AC

View File

@@ -7,7 +7,6 @@ module
prelude
public import Lean.Meta.Tactic.Grind.Arith.Util
public import Lean.Meta.Tactic.Grind.Arith.ProofUtil
public import Lean.Meta.Tactic.Grind.Arith.Types
public import Lean.Meta.Tactic.Grind.Arith.Main
public import Lean.Meta.Tactic.Grind.Arith.Offset
@@ -15,7 +14,6 @@ public import Lean.Meta.Tactic.Grind.Arith.Cutsat
public import Lean.Meta.Tactic.Grind.Arith.CommRing
public import Lean.Meta.Tactic.Grind.Arith.Linear
public import Lean.Meta.Tactic.Grind.Arith.Simproc
public import Lean.Meta.Tactic.Grind.Arith.VarRename
public import Lean.Meta.Tactic.Grind.Arith.Internalize
public section

View File

@@ -4,7 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Util.Trace
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Poly

View File

@@ -4,22 +4,19 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Init.Grind.Ring.OfSemiring
public import Lean.Meta.Tactic.Grind.Diseq
public import Lean.Meta.Tactic.Grind.Arith.ProofUtil
public import Lean.Meta.Tactic.Grind.ProofUtil
public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingId
public import Lean.Meta.Tactic.Grind.Arith.CommRing.DenoteExpr
public import Lean.Meta.Tactic.Grind.Arith.CommRing.SafePoly
public import Lean.Meta.Tactic.Grind.Arith.CommRing.ToExpr
public import Lean.Meta.Tactic.Grind.Arith.CommRing.SemiringM
public import Lean.Meta.Tactic.Grind.Arith.VarRename
public import Lean.Meta.Tactic.Grind.VarRename
import Lean.Meta.Tactic.Grind.Arith.CommRing.VarRename
import Lean.Meta.Tactic.Grind.Arith.CommRing.Functions
public section
namespace Lean.Meta.Grind.Arith.CommRing
/--
Returns a context of type `RArray α` containing the variables `vars` where

View File

@@ -4,14 +4,11 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Init.Grind.Ring.Poly
public import Init.Grind.Ring.OfSemiring
public import Lean.ToExpr
public section
namespace Lean.Meta.Grind.Arith.CommRing
open Grind.CommRing
/-!

View File

@@ -4,14 +4,12 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Init.Grind.Ring.Poly
public import Init.Grind.Ring.OfSemiring
public import Lean.Meta.Tactic.Grind.Arith.VarRename
public import Lean.Meta.Tactic.Grind.VarRename
namespace Lean.Grind.CommRing
open Lean.Meta.Grind.Arith
open Lean.Meta.Grind
public def Power.renameVars (pw : Power) (f : VarRename) : Power :=
{ pw with x := (f pw.x) }
@@ -59,7 +57,7 @@ public def Expr.collectVars (e : Expr) : VarCollector :=
end Lean.Grind.CommRing
namespace Lean.Grind.Ring.OfSemiring
open Lean.Meta.Grind.Arith
open Lean.Meta.Grind
public def Expr.renameVars (e : Expr) (f : VarRename) : Expr :=
match e with

View File

@@ -8,8 +8,8 @@ prelude
public import Init.Grind.Ring.Poly
public import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Diseq
import Lean.Meta.Tactic.Grind.Arith.ProofUtil
import Lean.Meta.Tactic.Grind.Arith.VarRename
import Lean.Meta.Tactic.Grind.ProofUtil
import Lean.Meta.Tactic.Grind.VarRename
import Lean.Meta.Tactic.Grind.Arith.Cutsat.CommRing
import Lean.Meta.Tactic.Grind.Arith.Cutsat.Util
import Lean.Meta.Tactic.Grind.Arith.Cutsat.Nat

View File

@@ -4,13 +4,11 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Init.Data.Int.Linear
public import Lean.Meta.Tactic.Grind.Arith.VarRename
public import Lean.Meta.Tactic.Grind.VarRename
namespace Int.Linear
open Lean.Meta.Grind.Arith
open Lean.Meta.Grind
public def Poly.renameVars (p : Poly) (f : VarRename) : Poly :=
match p with

View File

@@ -9,11 +9,11 @@ public import Lean.Meta.Tactic.Grind.Arith.Linear.Util
import Lean.Meta.Tactic.Grind.Arith.Util
import Lean.Meta.Tactic.Grind.Arith.Linear.ToExpr
import Lean.Meta.Tactic.Grind.Arith.Linear.DenoteExpr
import Lean.Meta.Tactic.Grind.Arith.VarRename
import Lean.Meta.Tactic.Grind.VarRename
import Lean.Meta.Tactic.Grind.Diseq
import Lean.Meta.Tactic.Grind.ProofUtil
import Lean.Meta.Tactic.Grind.Arith.CommRing.VarRename
import Lean.Meta.Tactic.Grind.Arith.CommRing.ToExpr
import Lean.Meta.Tactic.Grind.Arith.ProofUtil
import Lean.Meta.Tactic.Grind.Arith.Linear.VarRename
import Lean.Meta.Tactic.Grind.Arith.CommRing.VarRename

View File

@@ -7,10 +7,10 @@ module
prelude
public import Init.Grind.Ordered.Linarith
public import Lean.Meta.Tactic.Grind.Arith.VarRename
public import Lean.Meta.Tactic.Grind.VarRename
namespace Lean.Grind.Linarith
open Lean.Meta.Grind.Arith
open Lean.Meta.Grind
public def Poly.renameVars (p : Poly) (f : VarRename) : Poly :=
match p with

View File

@@ -215,6 +215,31 @@ def propagateLinarith : PendingTheoryPropagation → GoalM Unit
| .diseqs ps => propagateLinarithDiseqs ps
| _ => return ()
/--
Helper function for combining `ENode.ac?` fields and detecting what needs to be
propagated to the ac module.
-/
private def checkACEq (rhsRoot lhsRoot : ENode) : GoalM PendingTheoryPropagation := do
match lhsRoot.ac? with
| some lhs =>
if let some rhs := rhsRoot.ac? then
return .eq lhs rhs
else
-- We have to retrieve the node because other fields have been updated
let rhsRoot getENode rhsRoot.self
setENode rhsRoot.self { rhsRoot with ac? := lhs }
return .diseqs ( getParents rhsRoot.self)
| none =>
if rhsRoot.ac?.isSome then
return .diseqs ( getParents lhsRoot.self)
else
return .none
def propagateAC : PendingTheoryPropagation GoalM Unit
| .eq lhs rhs => AC.processNewEq lhs rhs
| .diseqs ps => propagateACDiseqs ps
| _ => return ()
/--
Tries to apply beta-reduction using the parent applications of the functions in `fns` with
the lambda expressions in `lams`.
@@ -349,6 +374,7 @@ where
let cutsatTodo checkCutsatEq rhsRoot lhsRoot
let ringTodo checkCommRingEq rhsRoot lhsRoot
let linarithTodo checkLinarithEq rhsRoot lhsRoot
let ACTodo checkACEq rhsRoot lhsRoot
resetParentsOf lhsRoot.self
copyParentsTo parents rhsNode.root
unless ( isInconsistent) do
@@ -363,6 +389,7 @@ where
propagateCutsat cutsatTodo
propagateCommRing ringTodo
propagateLinarith linarithTodo
propagateAC ACTodo
updateRoots (lhs : Expr) (rootNew : Expr) : GoalM Unit := do
let isFalseRoot isFalseExpr rootNew
traverseEqc lhs fun n => do

View File

@@ -4,13 +4,10 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Tactic.Grind.Types
public section
namespace Lean.Meta.Grind.Arith
namespace Lean.Meta.Grind
def mkLetOfMap {_ : Hashable α} {_ : BEq α} (m : Std.HashMap α Expr) (e : Expr)
(varPrefix : Name) (varType : Expr) (toExpr : α Expr) : Expr := Id.run do
@@ -25,4 +22,4 @@ def mkLetOfMap {_ : Hashable α} {_ : BEq α} (m : Std.HashMap α Expr) (e : Exp
i := i - 1
return e
end Lean.Meta.Grind.Arith
end Lean.Meta.Grind

View File

@@ -186,6 +186,7 @@ builtin_grind_propagator propagateEqDown ↓Eq := fun e => do
propagateCutsatDiseq lhs rhs
propagateCommRingDiseq lhs rhs
propagateLinarithDiseq lhs rhs
propagateACDiseq lhs rhs
let thms getExtTheorems α
if !thms.isEmpty then
/-

View File

@@ -440,7 +440,12 @@ structure ENode where
to the linarith module. Its implementation is similar to the `offset?` field.
-/
linarith? : Option Expr := none
-- Remark: we expect to have builtin support for offset constraints, cutsat, comm ring, and linarith.
/--
The `ac?` field is used to propagate equalities from the `grind` congruence closure module
to the ac module. Its implementation is similar to the `offset?` field.
-/
ac? : Option Expr := none
-- Remark: we expect to have builtin support for offset constraints, cutsat, comm ring, linarith, and ac.
-- If the number of satellite solvers increases, we may add support for an arbitrary solvers like done in Z3.
deriving Inhabited, Repr
@@ -1283,6 +1288,53 @@ def markAsLinarithTerm (e : Expr) : GoalM Unit := do
setENode root.self { root with linarith? := some e }
propagateLinarithDiseqs ( getParents root.self)
/--
Notifies the ac module that `a = b` where
`a` and `b` are terms that have been internalized by this module.
-/
@[extern "lean_process_ac_eq"] -- forward definition
opaque AC.processNewEq (a b : Expr) : GoalM Unit
/--
Notifies the ac module that `a ≠ b` where
`a` and `b` are terms that have been internalized by this module.
-/
@[extern "lean_process_ac_diseq"] -- forward definition
opaque AC.processNewDiseq (a b : Expr) : GoalM Unit
/--
Given `lhs` and `rhs` that are known to be disequal, checks whether
`lhs` and `rhs` have ac terms `e₁` and `e₂` attached to them,
and invokes process `AC.processNewDiseq e₁ e₂`
-/
def propagateACDiseq (lhs rhs : Expr) : GoalM Unit := do
let some lhs get? lhs | return ()
let some rhs get? rhs | return ()
AC.processNewDiseq lhs rhs
where
get? (a : Expr) : GoalM (Option Expr) := do
return ( getRootENode a).ac?
/--
Traverses disequalities in `parents`, and propagate the ones relevant to the
ac module.
-/
def propagateACDiseqs (parents : ParentSet) : GoalM Unit := do
forEachDiseq parents propagateACDiseq
/--
Marks `e` as a term of interest to the ac module.
If the root of `e`s equivalence class has already a term of interest,
a new equality is propagated to the ac module.
-/
def markAsACTerm (e : Expr) : GoalM Unit := do
let root getRootENode e
if let some e' := root.ac? then
AC.processNewEq e e'
else
setENode root.self { root with ac? := some e }
propagateACDiseqs ( getParents root.self)
/-- Returns `true` is `e` is the root of its congruence class. -/
def isCongrRoot (e : Expr) : GoalM Bool := do
return ( getENode e).isCongrRoot

View File

@@ -4,13 +4,12 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Init.Prelude
public import Init.Data.Array.QSort
public import Std.Data.HashSet
public section
namespace Lean.Meta.Grind.Arith
namespace Lean.Meta.Grind
abbrev Var : Type := Nat
abbrev FoundVars := Std.HashSet Nat
@@ -45,4 +44,4 @@ def mkVarRename (new2old : Array Var) : VarRename := Id.run do
new := new + 1
{ map := old2new }
end Lean.Meta.Grind.Arith
end Lean.Meta.Grind

0
src/out Normal file
View File

View File

@@ -0,0 +1,81 @@
open Lean Grind AC
example {α : Type u} (op : α α α) [Std.Associative op] (a b c : α)
: op a (op b c) = op (op a b) c := by
grind only
example {α : Sort u} (op : α α α) [Std.Associative op] (a b c : α)
: op a (op b c) = op (op a b) c := by
grind only
example {α : Sort u} (op : α α α) (u : α) [Std.Associative op] [Std.LawfulIdentity op u] (a b c : α)
: op a (op b c) = op (op a b) (op c u) := by
grind only
example {α : Sort u} (op : α α α) [Std.Associative op] [Std.Commutative op] (a b c : α)
: op c (op b a) = op (op b a) c := by
grind only
example {α : Sort u} (op : α α α) (u : α) [Std.Associative op] [Std.Commutative op] [Std.LawfulIdentity op u] (a b c : α)
: op a (op b c) = op (op b a) c := by
grind only
example {α : Sort u} (op : α α α) (u : α) [Std.Associative op] [Std.Commutative op] [Std.LawfulIdentity op u] (a b c : α)
: op a (op b (op u c)) = op (op b a) (op u c) := by
grind only
example {α : Sort u} (op : α α α) [Std.Associative op] [Std.IdempotentOp op] (a b c : α)
: op (op a a) (op b c) = op (op a (op b b)) c := by
grind only
example {α : Sort u} (op : α α α) [Std.Associative op] [Std.Commutative op] [Std.IdempotentOp op] (a b c : α)
: op (op a a) (op b c) = op (op (op b a) (op b b)) c := by
grind only
example {α : Sort u} (op : α α α) (u : α) [Std.Associative op] [Std.Commutative op] [Std.IdempotentOp op] [Std.LawfulIdentity op u] (a b c : α)
: op (op a a) (op b c) = op (op (op b a) (op (op u b) b)) c := by
grind only
example {α : Type u} (op : α α α) [Std.Associative op] [Std.Commutative op] [Std.IdempotentOp op] (a b c : α)
: op (op a a) (op b c) = op (op (op b a) (op b b)) c := by
grind only
example {α : Type u} (op : α α α) (u : α) [Std.Associative op] [Std.Commutative op] [Std.IdempotentOp op] [Std.LawfulIdentity op u] (a b c : α)
: op (op a a) (op b c) = op (op (op b a) (op (op u b) b)) c := by
grind only
example {α} (as bs cs : List α) : as ++ (bs ++ cs) = ((as ++ []) ++ bs) ++ (cs ++ []) := by
grind only
example (a b c : Nat) : max a (max b c) = max (max b 0) (max a c) := by
grind only
/--
trace: [grind.debug.proof] Classical.byContradiction
(intro_with_eq (¬(max a (max b c) = max (max b 0) (max a c) ∧ min a b = min b a))
(¬max a (max b c) = max (max b 0) (max a c) ¬min a b = min b a) False
(Grind.not_and (max a (max b c) = max (max b 0) (max a c)) (min a b = min b a)) fun h =>
Or.casesOn h
(fun h_1 =>
let ctx :=
Context.mk
(RArray.branch 2 (RArray.branch 1 (RArray.leaf (PLift.up 0)) (RArray.leaf (PLift.up a)))
(RArray.branch 3 (RArray.leaf (PLift.up b)) (RArray.leaf (PLift.up c))))
max;
let e_1 := ((Expr.var 2).op (Expr.var 0)).op ((Expr.var 1).op (Expr.var 3));
let e_2 := (Expr.var 1).op ((Expr.var 2).op (Expr.var 3));
let p_1 := Seq.cons 1 (Seq.cons 2 (Seq.var 3));
diseq_unsat ctx p_1 p_1 (eagerReduce (Eq.refl true))
(diseq_norm_aci ctx e_2 e_1 p_1 p_1 (eagerReduce (Eq.refl true)) h_1))
fun h_1 =>
let ctx := Context.mk (RArray.branch 1 (RArray.leaf (PLift.up a)) (RArray.leaf (PLift.up b))) min;
let e_1 := (Expr.var 1).op (Expr.var 0);
let e_2 := (Expr.var 0).op (Expr.var 1);
let p_1 := Seq.cons 0 (Seq.var 1);
diseq_unsat ctx p_1 p_1 (eagerReduce (Eq.refl true))
(diseq_norm_ac ctx e_2 e_1 p_1 p_1 (eagerReduce (Eq.refl true)) h_1))
-/
#guard_msgs in
set_option pp.structureInstances false in
set_option trace.grind.debug.proof true in
example (a b c : Nat) : max a (max b c) = max (max b 0) (max a c) min a b = min b a := by
grind only [cases Or]