From aaec0f584c85c8c79fcad610394e017c4df1dd49 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 26 Aug 2025 20:28:30 -0700 Subject: [PATCH] feat: ac normalization in `grind` (#10146) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR implements the basic infrastructure for the new procedure handling AC operators in grind. It already supports normalizing disequalities. Future PRs will add support for simplification using equalities, and computing critical pairs. Examples: ```lean example {α : Sort u} (op : α → α → α) [Std.Associative op] (a b c : α) : op a (op b c) = op (op a b) c := by grind only example {α : Sort u} (op : α → α → α) (u : α) [Std.Associative op] [Std.LawfulIdentity op u] (a b c : α) : op a (op b c) = op (op a b) (op c u) := by grind only example {α : Type u} (op : α → α → α) (u : α) [Std.Associative op] [Std.Commutative op] [Std.IdempotentOp op] [Std.LawfulIdentity op u] (a b c : α) : op (op a a) (op b c) = op (op (op b a) (op (op u b) b)) c := by grind only example {α} (as bs cs : List α) : as ++ (bs ++ cs) = ((as ++ []) ++ bs) ++ (cs ++ []) := by grind only example (a b c : Nat) : max a (max b c) = max (max b 0) (max a c) ∧ min a b = min b a := by grind only [cases Or] ``` --- src/Init/Grind/AC.lean | 120 ++++++++------- src/Lean/Meta/Tactic/Grind.lean | 2 + src/Lean/Meta/Tactic/Grind/AC.lean | 6 + src/Lean/Meta/Tactic/Grind/AC/DenoteExpr.lean | 33 ++++ src/Lean/Meta/Tactic/Grind/AC/Eq.lean | 79 ++++++++++ .../Meta/Tactic/Grind/AC/Internalize.lean | 16 +- src/Lean/Meta/Tactic/Grind/AC/Proof.lean | 142 ++++++++++++++++++ src/Lean/Meta/Tactic/Grind/AC/Seq.lean | 17 +++ src/Lean/Meta/Tactic/Grind/AC/ToExpr.lean | 34 +++++ src/Lean/Meta/Tactic/Grind/AC/Types.lean | 69 ++++++++- src/Lean/Meta/Tactic/Grind/AC/Util.lean | 62 +++++++- src/Lean/Meta/Tactic/Grind/AC/VarRename.lean | 33 ++++ src/Lean/Meta/Tactic/Grind/Arith.lean | 2 - .../Meta/Tactic/Grind/Arith/CommRing.lean | 1 - .../Tactic/Grind/Arith/CommRing/Proof.lean | 7 +- .../Tactic/Grind/Arith/CommRing/ToExpr.lean | 3 - .../Grind/Arith/CommRing/VarRename.lean | 8 +- .../Meta/Tactic/Grind/Arith/Cutsat/Proof.lean | 4 +- .../Tactic/Grind/Arith/Cutsat/VarRename.lean | 6 +- .../Meta/Tactic/Grind/Arith/Linear/Proof.lean | 4 +- .../Tactic/Grind/Arith/Linear/VarRename.lean | 4 +- src/Lean/Meta/Tactic/Grind/Core.lean | 27 ++++ .../Tactic/Grind/{Arith => }/ProofUtil.lean | 7 +- src/Lean/Meta/Tactic/Grind/Propagate.lean | 1 + src/Lean/Meta/Tactic/Grind/Types.lean | 54 ++++++- .../Tactic/Grind/{Arith => }/VarRename.lean | 5 +- src/out | 0 tests/lean/run/grind_ac_1.lean | 81 ++++++++++ 28 files changed, 729 insertions(+), 98 deletions(-) create mode 100644 src/Lean/Meta/Tactic/Grind/AC/DenoteExpr.lean create mode 100644 src/Lean/Meta/Tactic/Grind/AC/Eq.lean create mode 100644 src/Lean/Meta/Tactic/Grind/AC/Proof.lean create mode 100644 src/Lean/Meta/Tactic/Grind/AC/Seq.lean create mode 100644 src/Lean/Meta/Tactic/Grind/AC/ToExpr.lean create mode 100644 src/Lean/Meta/Tactic/Grind/AC/VarRename.lean rename src/Lean/Meta/Tactic/Grind/{Arith => }/ProofUtil.lean (91%) rename src/Lean/Meta/Tactic/Grind/{Arith => }/VarRename.lean (95%) create mode 100644 src/out create mode 100644 tests/lean/run/grind_ac_1.lean diff --git a/src/Init/Grind/AC.lean b/src/Init/Grind/AC.lean index d348ed0e39..d177b2433e 100644 --- a/src/Init/Grind/AC.lean +++ b/src/Init/Grind/AC.lean @@ -15,19 +15,22 @@ public import Init.Data.Bool namespace Lean.Grind.AC abbrev Var := Nat -structure Context (α : Type u) where - vars : RArray α +structure Context (α : Sort u) where + vars : RArray (PLift α) op : α → α → α inductive Expr where - | var (x : Nat) + | var (x : Var) | op (lhs rhs : Expr) deriving Inhabited, Repr, BEq -noncomputable def Expr.denote {α} (ctx : Context α) (e : Expr) : α := - Expr.rec (fun x => ctx.vars.get x) (fun _ _ ih₁ ih₂ => ctx.op ih₁ ih₂) e +noncomputable def Var.denote {α : Sort u} (ctx : Context α) (x : Var) : α := + PLift.rec (fun x => x) (ctx.vars.get x) -theorem Expr.denote_var {α} (ctx : Context α) (x : Var) : (Expr.var x).denote ctx = ctx.vars.get x := rfl +noncomputable def Expr.denote {α} (ctx : Context α) (e : Expr) : α := + Expr.rec (fun x => x.denote ctx) (fun _ _ ih₁ ih₂ => ctx.op ih₁ ih₂) e + +theorem Expr.denote_var {α} (ctx : Context α) (x : Var) : (Expr.var x).denote ctx = x.denote ctx := rfl theorem Expr.denote_op {α} (ctx : Context α) (a b : Expr) : (Expr.op a b).denote ctx = ctx.op (a.denote ctx) (b.denote ctx) := rfl attribute [local simp] Expr.denote_var Expr.denote_op @@ -59,10 +62,10 @@ instance : LawfulBEq Seq where rfl := by intro a; induction a <;> simp! [BEq.beq]; assumption noncomputable def Seq.denote {α} (ctx : Context α) (s : Seq) : α := - Seq.rec (fun x => ctx.vars.get x) (fun x _ ih => ctx.op (ctx.vars.get x) ih) s + Seq.rec (fun x => x.denote ctx) (fun x _ ih => ctx.op (x.denote ctx) ih) s -theorem Seq.denote_var {α} (ctx : Context α) (x : Var) : (Seq.var x).denote ctx = ctx.vars.get x := rfl -theorem Seq.denote_op {α} (ctx : Context α) (x : Var) (s : Seq) : (Seq.cons x s).denote ctx = ctx.op (ctx.vars.get x) (s.denote ctx) := rfl +theorem Seq.denote_var {α} (ctx : Context α) (x : Var) : (Seq.var x).denote ctx = x.denote ctx := rfl +theorem Seq.denote_op {α} (ctx : Context α) (x : Var) (s : Seq) : (Seq.cons x s).denote ctx = ctx.op (x.denote ctx) (s.denote ctx) := rfl attribute [local simp] Seq.denote_var Seq.denote_op @@ -152,7 +155,7 @@ theorem Seq.erase0_k_eq_erase0 (s : Seq) : s.erase0_k = s.erase0 := by attribute [local simp] Seq.erase0_k_eq_erase0 -theorem Seq.denote_erase0 {α} (ctx : Context α) {inst : Std.LawfulIdentity ctx.op (ctx.vars.get 0)} (s : Seq) +theorem Seq.denote_erase0 {α} (ctx : Context α) {inst : Std.LawfulIdentity ctx.op (Var.denote ctx 0)} (s : Seq) : s.erase0.denote ctx = s.denote ctx := by fun_induction erase0 s <;> simp_all +zetaDelta next => rw [Std.LawfulLeftIdentity.left_id (self := inst.toLawfulLeftIdentity)] @@ -179,12 +182,12 @@ theorem Seq.insert_k_eq_insert (x : Var) (s : Seq) : insert_k x s = insert x s : attribute [local simp] Seq.insert_k_eq_insert theorem Seq.denote_insert {α} (ctx : Context α) {inst₁ : Std.Associative ctx.op} {inst₂ : Std.Commutative ctx.op} (x : Var) (s : Seq) - : (s.insert x).denote ctx = ctx.op (ctx.vars.get x) (s.denote ctx) := by + : (s.insert x).denote ctx = ctx.op (x.denote ctx) (s.denote ctx) := by fun_induction insert x s <;> simp next => rw [Std.Commutative.comm (self := inst₂)] next y s h ih => simp [ih, ← Std.Associative.assoc (self := inst₁)] - rw [Std.Commutative.comm (self := inst₂) (ctx.vars.get x)] + rw [Std.Commutative.comm (self := inst₂) (x.denote ctx)] attribute [local simp] Seq.denote_insert @@ -208,7 +211,7 @@ theorem Seq.denote_sort' {α} (ctx : Context α) {inst₁ : Std.Associative ctx. fun_induction sort' s acc <;> simp next x s ih => simp [ih, ← Std.Associative.assoc (self := inst₁)] - rw [Std.Commutative.comm (self := inst₂) (ctx.vars.get x) (s.denote ctx)] + rw [Std.Commutative.comm (self := inst₂) (x.denote ctx) (s.denote ctx)] attribute [local simp] Seq.denote_sort' @@ -387,11 +390,11 @@ theorem Seq.denote_combineFuel {α} (ctx : Context α) {inst₁ : Std.Associativ next ih => simp [ih, Std.Associative.assoc (self := inst₁)] next x₁ s₁ x₂ s₂ h ih => simp [ih] - rw [← Std.Associative.assoc (self := inst₁), ← Std.Associative.assoc (self := inst₁), Std.Commutative.comm (self := inst₂) (ctx.vars.get x₂)] - rw [Std.Associative.assoc (self := inst₁), Std.Associative.assoc (self := inst₁), Std.Associative.assoc (self := inst₁) (ctx.vars.get x₁)] - apply congrArg (ctx.op (ctx.vars.get x₁)) + rw [← Std.Associative.assoc (self := inst₁), ← Std.Associative.assoc (self := inst₁), Std.Commutative.comm (self := inst₂) (x₂.denote ctx)] + rw [Std.Associative.assoc (self := inst₁), Std.Associative.assoc (self := inst₁), Std.Associative.assoc (self := inst₁) (x₁.denote ctx)] + apply congrArg (ctx.op (x₁.denote ctx)) rw [← Std.Associative.assoc (self := inst₁), ← Std.Associative.assoc (self := inst₁) (s₁.denote ctx)] - rw [Std.Commutative.comm (self := inst₂) (ctx.vars.get x₂)] + rw [Std.Commutative.comm (self := inst₂) (x₂.denote ctx)] attribute [local simp] Seq.denote_combineFuel @@ -446,54 +449,65 @@ theorem superpose_ac {α} (ctx : Context α) {inst₁ : Std.Associative ctx.op} apply congrArg (ctx.op (c.denote ctx)) rw [Std.Commutative.comm (self := inst₂) (b.denote ctx)] -noncomputable def norm_a_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool := +noncomputable def eq_norm_a_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool := lhs.toSeq.beq' lhs' |>.and' (rhs.toSeq.beq' rhs') -theorem norm_a {α} (ctx : Context α) {_ : Std.Associative ctx.op} (lhs rhs : Expr) (lhs' rhs' : Seq) - : norm_a_cert lhs rhs lhs' rhs' → lhs.denote ctx = rhs.denote ctx → lhs'.denote ctx = rhs'.denote ctx := by - simp [norm_a_cert]; intro _ _; subst lhs' rhs'; simp +theorem eq_norm_a {α} (ctx : Context α) {_ : Std.Associative ctx.op} (lhs rhs : Expr) (lhs' rhs' : Seq) + : eq_norm_a_cert lhs rhs lhs' rhs' → lhs.denote ctx = rhs.denote ctx → lhs'.denote ctx = rhs'.denote ctx := by + simp [eq_norm_a_cert]; intro _ _; subst lhs' rhs'; simp -noncomputable def norm_ac_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool := +noncomputable def eq_norm_ac_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool := lhs.toSeq.sort.beq' lhs' |>.and' (rhs.toSeq.sort.beq' rhs') -theorem norm_ac {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.Commutative ctx.op} (lhs rhs : Expr) (lhs' rhs' : Seq) - : norm_ac_cert lhs rhs lhs' rhs' → lhs.denote ctx = rhs.denote ctx → lhs'.denote ctx = rhs'.denote ctx := by - simp [norm_ac_cert]; intro _ _; subst lhs' rhs'; simp +theorem eq_norm_ac {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.Commutative ctx.op} (lhs rhs : Expr) (lhs' rhs' : Seq) + : eq_norm_ac_cert lhs rhs lhs' rhs' → lhs.denote ctx = rhs.denote ctx → lhs'.denote ctx = rhs'.denote ctx := by + simp [eq_norm_ac_cert]; intro _ _; subst lhs' rhs'; simp -noncomputable def norm_aci_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool := - lhs.toSeq.erase0.sort.beq' lhs' |>.and' (rhs.toSeq.erase0.sort.beq' rhs') - -theorem norm_aci {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.Commutative ctx.op} {_ : Std.LawfulIdentity ctx.op (ctx.vars.get 0)} - (lhs rhs : Expr) (lhs' rhs' : Seq) : norm_aci_cert lhs rhs lhs' rhs' → lhs.denote ctx = rhs.denote ctx → lhs'.denote ctx = rhs'.denote ctx := by - simp [norm_aci_cert]; intro _ _; subst lhs' rhs'; simp - -noncomputable def norm_ai_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool := +noncomputable def eq_norm_ai_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool := lhs.toSeq.erase0.beq' lhs' |>.and' (rhs.toSeq.erase0.beq' rhs') -theorem norm_ai {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.LawfulIdentity ctx.op (ctx.vars.get 0)} - (lhs rhs : Expr) (lhs' rhs' : Seq) : norm_ai_cert lhs rhs lhs' rhs' → lhs.denote ctx = rhs.denote ctx → lhs'.denote ctx = rhs'.denote ctx := by - simp [norm_ai_cert]; intro _ _; subst lhs' rhs'; simp +theorem eq_norm_ai {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.LawfulIdentity ctx.op (Var.denote ctx 0)} + (lhs rhs : Expr) (lhs' rhs' : Seq) : eq_norm_ai_cert lhs rhs lhs' rhs' → lhs.denote ctx = rhs.denote ctx → lhs'.denote ctx = rhs'.denote ctx := by + simp [eq_norm_ai_cert]; intro _ _; subst lhs' rhs'; simp -noncomputable def norm_acip_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool := - lhs.toSeq.erase0.sort.eraseDup.beq' lhs' |>.and' (rhs.toSeq.erase0.sort.eraseDup.beq' rhs') +noncomputable def eq_norm_aci_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool := + lhs.toSeq.erase0.sort.beq' lhs' |>.and' (rhs.toSeq.erase0.sort.beq' rhs') -theorem norm_acip {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.Commutative ctx.op} - {_ : Std.LawfulIdentity ctx.op (ctx.vars.get 0)} {_ : Std.IdempotentOp ctx.op} - (lhs rhs : Expr) (lhs' rhs' : Seq) : norm_acip_cert lhs rhs lhs' rhs' → lhs.denote ctx = rhs.denote ctx → lhs'.denote ctx = rhs'.denote ctx := by - simp [norm_acip_cert]; intro _ _; subst lhs' rhs'; simp +theorem eq_norm_aci {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.Commutative ctx.op} {_ : Std.LawfulIdentity ctx.op (Var.denote ctx 0)} + (lhs rhs : Expr) (lhs' rhs' : Seq) : eq_norm_aci_cert lhs rhs lhs' rhs' → lhs.denote ctx = rhs.denote ctx → lhs'.denote ctx = rhs'.denote ctx := by + simp [eq_norm_aci_cert]; intro _ _; subst lhs' rhs'; simp -noncomputable def norm_acp_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool := - lhs.toSeq.sort.eraseDup.beq' lhs' |>.and' (rhs.toSeq.sort.eraseDup.beq' rhs') - -theorem norm_acp {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.Commutative ctx.op} {_ : Std.IdempotentOp ctx.op} - (lhs rhs : Expr) (lhs' rhs' : Seq) : norm_acp_cert lhs rhs lhs' rhs' → lhs.denote ctx = rhs.denote ctx → lhs'.denote ctx = rhs'.denote ctx := by - simp [norm_acp_cert]; intro _ _; subst lhs' rhs'; simp - -noncomputable def norm_dup_cert (lhs rhs lhs' rhs' : Seq) : Bool := +noncomputable def eq_erase_dup_cert (lhs rhs lhs' rhs' : Seq) : Bool := lhs.eraseDup.beq' lhs' |>.and' (rhs.eraseDup.beq' rhs') -theorem norm_dup (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.IdempotentOp ctx.op} - (lhs rhs lhs' rhs' : Seq) : norm_dup_cert lhs rhs lhs' rhs' → lhs.denote ctx = rhs.denote ctx → lhs'.denote ctx = rhs'.denote ctx := by - simp [norm_dup_cert]; intro _ _; subst lhs' rhs'; simp +theorem eq_erase_dup {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.IdempotentOp ctx.op} + (lhs rhs lhs' rhs' : Seq) : eq_erase_dup_cert lhs rhs lhs' rhs' → lhs.denote ctx = rhs.denote ctx → lhs'.denote ctx = rhs'.denote ctx := by + simp [eq_erase_dup_cert]; intro _ _; subst lhs' rhs'; simp + +theorem diseq_norm_a {α} (ctx : Context α) {_ : Std.Associative ctx.op} (lhs rhs : Expr) (lhs' rhs' : Seq) + : eq_norm_a_cert lhs rhs lhs' rhs' → lhs.denote ctx ≠ rhs.denote ctx → lhs'.denote ctx ≠ rhs'.denote ctx := by + simp [eq_norm_a_cert]; intro _ _; subst lhs' rhs'; simp + +theorem diseq_norm_ac {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.Commutative ctx.op} (lhs rhs : Expr) (lhs' rhs' : Seq) + : eq_norm_ac_cert lhs rhs lhs' rhs' → lhs.denote ctx ≠ rhs.denote ctx → lhs'.denote ctx ≠ rhs'.denote ctx := by + simp [eq_norm_ac_cert]; intro _ _; subst lhs' rhs'; simp + +theorem diseq_norm_ai {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.LawfulIdentity ctx.op (Var.denote ctx 0)} + (lhs rhs : Expr) (lhs' rhs' : Seq) : eq_norm_ai_cert lhs rhs lhs' rhs' → lhs.denote ctx ≠ rhs.denote ctx → lhs'.denote ctx ≠ rhs'.denote ctx := by + simp [eq_norm_ai_cert]; intro _ _; subst lhs' rhs'; simp + +theorem diseq_norm_aci {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.Commutative ctx.op} {_ : Std.LawfulIdentity ctx.op (Var.denote ctx 0)} + (lhs rhs : Expr) (lhs' rhs' : Seq) : eq_norm_aci_cert lhs rhs lhs' rhs' → lhs.denote ctx ≠ rhs.denote ctx → lhs'.denote ctx ≠ rhs'.denote ctx := by + simp [eq_norm_aci_cert]; intro _ _; subst lhs' rhs'; simp + +theorem diseq_erase_dup {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.IdempotentOp ctx.op} + (lhs rhs lhs' rhs' : Seq) : eq_erase_dup_cert lhs rhs lhs' rhs' → lhs.denote ctx ≠ rhs.denote ctx → lhs'.denote ctx ≠ rhs'.denote ctx := by + simp [eq_erase_dup_cert]; intro _ _; subst lhs' rhs'; simp + +noncomputable def diseq_unsat_cert (lhs rhs : Seq) : Bool := + lhs.beq' rhs + +theorem diseq_unsat {α} (ctx : Context α) (lhs rhs : Seq) : diseq_unsat_cert lhs rhs → lhs.denote ctx ≠ rhs.denote ctx → False := by + simp [diseq_unsat_cert]; intro; subst lhs; simp end Lean.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind.lean b/src/Lean/Meta/Tactic/Grind.lean index da96eb091d..2a4d3cd49e 100644 --- a/src/Lean/Meta/Tactic/Grind.lean +++ b/src/Lean/Meta/Tactic/Grind.lean @@ -37,6 +37,8 @@ public import Lean.Meta.Tactic.Grind.LawfulEqCmp public import Lean.Meta.Tactic.Grind.ReflCmp public import Lean.Meta.Tactic.Grind.SynthInstance public import Lean.Meta.Tactic.Grind.AC +public import Lean.Meta.Tactic.Grind.VarRename +public import Lean.Meta.Tactic.Grind.ProofUtil public section diff --git a/src/Lean/Meta/Tactic/Grind/AC.lean b/src/Lean/Meta/Tactic/Grind/AC.lean index 6f8183b879..85f49e2d58 100644 --- a/src/Lean/Meta/Tactic/Grind/AC.lean +++ b/src/Lean/Meta/Tactic/Grind/AC.lean @@ -9,6 +9,12 @@ public import Lean.Meta.Tactic.Grind.AC.Types public import Lean.Meta.Tactic.Grind.AC.Util public import Lean.Meta.Tactic.Grind.AC.Var public import Lean.Meta.Tactic.Grind.AC.Internalize +public import Lean.Meta.Tactic.Grind.AC.Eq +public import Lean.Meta.Tactic.Grind.AC.Seq +public import Lean.Meta.Tactic.Grind.AC.Proof +public import Lean.Meta.Tactic.Grind.AC.DenoteExpr +public import Lean.Meta.Tactic.Grind.AC.ToExpr +public import Lean.Meta.Tactic.Grind.AC.VarRename public section namespace Lean builtin_initialize registerTraceClass `grind.ac diff --git a/src/Lean/Meta/Tactic/Grind/AC/DenoteExpr.lean b/src/Lean/Meta/Tactic/Grind/AC/DenoteExpr.lean new file mode 100644 index 0000000000..8efd379f68 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/AC/DenoteExpr.lean @@ -0,0 +1,33 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Lean.Meta.Tactic.Grind.AC.Util +import Lean.Meta.AppBuilder +public section +namespace Lean.Meta.Grind.AC +open Lean.Grind +variable [Monad M] [MonadGetStruct M] [MonadError M] + +def _root_.Lean.Grind.AC.Seq.denoteExpr (s : AC.Seq) : M Expr := do + match s with + | .var x => return (← getStruct).vars[x]! + | .cons x s => return mkApp2 (← getStruct).op (← getStruct).vars[x]! (← denoteExpr s) + +def _root_.Lean.Grind.AC.Expr.denoteExpr (e : AC.Expr) : M Expr := do + match e with + | .var x => return (← getStruct).vars[x]! + | .op lhs rhs => return mkApp2 (← getStruct).op (← denoteExpr lhs) (← denoteExpr rhs) + +def EqCnstr.denoteExpr (c : EqCnstr) : M Expr := do + let s ← getStruct + return mkApp3 (mkConst ``Eq [s.u]) s.type (← c.lhs.denoteExpr) (← c.rhs.denoteExpr) + +def DiseqCnstr.denoteExpr (c : DiseqCnstr) : M Expr := do + let s ← getStruct + return mkApp3 (mkConst ``Ne [s.u]) s.type (← c.lhs.denoteExpr) (← c.rhs.denoteExpr) + +end Lean.Meta.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind/AC/Eq.lean b/src/Lean/Meta/Tactic/Grind/AC/Eq.lean new file mode 100644 index 0000000000..e902cb8373 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/AC/Eq.lean @@ -0,0 +1,79 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Lean.Meta.Tactic.Grind.AC.Util +import Lean.Meta.Tactic.Grind.AC.DenoteExpr +import Lean.Meta.Tactic.Grind.AC.Proof +public section +namespace Lean.Meta.Grind.AC +open Lean.Grind +/-- For each structure `s` s.t. `a` and `b` are elements of, execute `k` -/ +@[specialize] def withExprs (a b : Expr) (k : ACM Unit) : GoalM Unit := do + let ids₁ ← getTermOpIds a + if ids₁.isEmpty then return () + let ids₂ ← getTermOpIds b + go ids₁ ids₂ +where + go : List Nat → List Nat → GoalM Unit + | [], _ => return () + | _, [] => return () + | id₁::ids₁, id₂::ids₂ => do + if id₁ == id₂ then + ACM.run id₁ k; go ids₁ ids₂ + else if id₁ < id₂ then + go ids₁ (id₂::ids₂) + else + go (id₁::ids₁) ids₂ + +def asACExpr (e : Expr) : ACM AC.Expr := do + if let some e' := (← getStruct).denote.find? { expr := e } then + return e' + else + return .var ((← getStruct).varMap.find? { expr := e }).get! + +def norm (e : AC.Expr) : ACM AC.Seq := do + match (← isCommutative), (← hasNeutral) with + | true, true => return e.toSeq.erase0.sort + | true, false => return e.toSeq.sort + | false, true => return e.toSeq.erase0 + | false, false => return e.toSeq + +def saveDiseq (c : DiseqCnstr) : ACM Unit := do + modifyStruct fun s => { s with diseqs := s.diseqs.push c } + +def DiseqCnstr.eraseDup (c : DiseqCnstr) : ACM DiseqCnstr := do + unless (← isIdempotent) do return c + let lhs := c.lhs.eraseDup + let rhs := c.rhs.eraseDup + if c.lhs == lhs && c.rhs == rhs then + return c + else + return { lhs, rhs, h := .erase_dup c } + +def DiseqCnstr.assert (c : DiseqCnstr) : ACM Unit := do + let c ← c.eraseDup + -- TODO: simplify and check conflict + trace[grind.ac.assert] "{← c.denoteExpr}" + if c.lhs == c.rhs then + c.setUnsat + else + saveDiseq c + +@[export lean_process_ac_eq] +def processNewEqImpl (a b : Expr) : GoalM Unit := withExprs a b do + trace[grind.ac.assert] "{a} = {b}" + -- TODO + +@[export lean_process_ac_diseq] +def processNewDiseqImpl (a b : Expr) : GoalM Unit := withExprs a b do + let ea ← asACExpr a + let lhs ← norm ea + let eb ← asACExpr b + let rhs ← norm eb + { lhs, rhs, h := .core a b ea eb : DiseqCnstr }.assert + +end Lean.Meta.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind/AC/Internalize.lean b/src/Lean/Meta/Tactic/Grind/AC/Internalize.lean index d3f0bf85f3..b9eafcf615 100644 --- a/src/Lean/Meta/Tactic/Grind/AC/Internalize.lean +++ b/src/Lean/Meta/Tactic/Grind/AC/Internalize.lean @@ -7,6 +7,7 @@ module prelude public import Lean.Meta.Tactic.Grind.Types public import Lean.Meta.Tactic.Grind.AC.Util +import Lean.Meta.Tactic.Grind.AC.DenoteExpr public section namespace Lean.Meta.Grind.AC @@ -15,6 +16,12 @@ private def isParentSameOpApp (parent? : Option Expr) (op : Expr) : GoalM Bool : unless e.isApp && e.appFn!.isApp do return false return isSameExpr e.appFn!.appFn! op +partial def reify (e : Expr) : ACM Grind.AC.Expr := do + if let some (a, b) ← isOp? e then + return .op (← reify a) (← reify b) + else + return .var (← mkVar e) + @[export lean_grind_ac_internalize] def internalizeImpl (e : Expr) (parent? : Option Expr) : GoalM Unit := do unless (← getConfig).ac do return () @@ -22,7 +29,12 @@ def internalizeImpl (e : Expr) (parent? : Option Expr) : GoalM Unit := do let op := e.appFn!.appFn! let some id ← getOpId? op | return () if (← isParentSameOpApp parent? op) then return () - trace[grind.ac.internalize] "[{id}] {e}" - -- TODO: internalize `e` + ACM.run id do + if (← getStruct).denote.contains { expr := e } then return () + let e' ← reify e + modifyStruct fun s => { s with denote := s.denote.insert { expr := e } e' } + trace[grind.ac.internalize] "[{id}] {← e'.denoteExpr}" + addTermOpId e + markAsACTerm e end Lean.Meta.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind/AC/Proof.lean b/src/Lean/Meta/Tactic/Grind/AC/Proof.lean new file mode 100644 index 0000000000..d2962d88f2 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/AC/Proof.lean @@ -0,0 +1,142 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Lean.Meta.Tactic.Grind.AC.Util +import Lean.Meta.Tactic.Grind.Diseq +import Lean.Meta.Tactic.Grind.ProofUtil +import Lean.Meta.Tactic.Grind.AC.ToExpr +import Lean.Meta.Tactic.Grind.AC.VarRename +public section +namespace Lean.Meta.Grind.AC +open Lean.Grind + +structure ProofM.State where + cache : Std.HashMap UInt64 Expr := {} + exprDecls : Std.HashMap AC.Expr Expr := {} + seqDecls : Std.HashMap AC.Seq Expr := {} + +structure ProofM.Context where + ctx : Expr + +abbrev ProofM := ReaderT ProofM.Context (StateRefT ProofM.State ACM) + +/-- Returns a Lean expression representing the variable context used to construct `AC` proof steps. -/ +private def getContext : ProofM Expr := do + return (← read).ctx + +private abbrev caching (c : α) (k : ProofM Expr) : ProofM Expr := do + let addr := unsafe (ptrAddrUnsafe c).toUInt64 >>> 2 + if let some h := (← get).cache[addr]? then + return h + else + let h ← k + modify fun s => { s with cache := s.cache.insert addr h } + return h + +local macro "declare! " decls:ident a:ident : term => + `(do if let some x := (← get).$decls[$a]? then + return x + let x := mkFVar (← mkFreshFVarId); + modify fun s => { s with $decls:ident := (s.$decls).insert $a x }; + return x) + +private def mkSeqDecl (s : AC.Seq) : ProofM Expr := do + declare! seqDecls s + +private def mkExprDecl (e : AC.Expr) : ProofM Expr := do + declare! exprDecls e + +private def getCommInst : ACM Expr := do + let some inst := (← getStruct).commInst? | throwError "`grind` internal error, `{(← getStruct).op}` is not commutative" + return inst + +private def getIdempotentInst : ACM Expr := do + let some inst := (← getStruct).idempotentInst? | throwError "`grind` internal error, `{(← getStruct).op}` is not idempotent" + return inst + +private def getNeutralInst : ACM Expr := do + let some inst := (← getStruct).neutralInst? | throwError "`grind` internal error, `{(← getStruct).op}` does not have a neutral element" + return inst + +private def mkPrefix (declName : Name) : ProofM Expr := do + let s ← getStruct + return mkApp2 (mkConst declName [s.u]) s.type (← getContext) + +private def mkAPrefix (declName : Name) : ProofM Expr := do + let s ← getStruct + return mkApp3 (mkConst declName [s.u]) s.type (← getContext) s.assocInst + +private def mkACPrefix (declName : Name) : ProofM Expr := do + let s ← getStruct + return mkApp4 (mkConst declName [s.u]) s.type (← getContext) s.assocInst (← getCommInst) + +private def mkAIPrefix (declName : Name) : ProofM Expr := do + let s ← getStruct + return mkApp4 (mkConst declName [s.u]) s.type (← getContext) s.assocInst (← getNeutralInst) + +private def mkACIPrefix (declName : Name) : ProofM Expr := do + let s ← getStruct + return mkApp5 (mkConst declName [s.u]) s.type (← getContext) s.assocInst (← getCommInst) (← getNeutralInst) + +private def mkDupPrefix (declName : Name) : ProofM Expr := do + let s ← getStruct + return mkApp4 (mkConst declName [s.u]) s.type (← getContext) s.assocInst (← getIdempotentInst) + +private def toContextExpr (vars : Array Expr) : ACM Expr := do + let s ← getStruct + if h : 0 < vars.size then + RArray.toExpr (mkApp (mkConst ``PLift [s.u]) s.type) id (RArray.ofFn (vars[·]) h) + else unreachable! + +private def mkContext (h : Expr) : ProofM Expr := do + let s ← getStruct + let mut usedVars := + collectMapVars (← get).seqDecls (·.collectVars) >> + collectMapVars (← get).exprDecls (·.collectVars) >> + (if (← hasNeutral) then (collectVar 0) else id) <| {} + let vars' := usedVars.toArray + let varRename := mkVarRename vars' + let vars := (← getStruct).vars + let up := mkApp (mkConst ``PLift.up [s.u]) s.type + let vars := vars'.map fun x => mkApp up vars[x]! + let h := mkLetOfMap (← get).seqDecls h `p (mkConst ``Grind.AC.Seq) fun p => toExpr <| p.renameVars varRename + let h := mkLetOfMap (← get).exprDecls h `e (mkConst ``Grind.AC.Expr) fun e => toExpr <| e.renameVars varRename + let h := h.abstract #[(← read).ctx] + if h.hasLooseBVars then + let ctxType := mkApp (mkConst ``AC.Context [s.u]) s.type + let ctxVal := mkApp3 (mkConst ``AC.Context.mk [s.u]) s.type (← toContextExpr vars) s.op + return .letE `ctx ctxType ctxVal h (nondep := false) + else + return h + +private abbrev withProofContext (x : ProofM Expr) : ACM Expr := do + let ctx := mkFVar (← mkFreshFVarId) + go { ctx } |>.run' {} +where + go : ProofM Expr := do + mkContext (← x) + +partial def DiseqCnstr.toExprProof (c : DiseqCnstr) : ProofM Expr := do caching c do + match c.h with + | .core a b lhs rhs => + let h ← match (← isCommutative), (← hasNeutral) with + | false, false => mkAPrefix ``AC.diseq_norm_a + | false, true => mkAIPrefix ``AC.diseq_norm_ai + | true, false => mkACPrefix ``AC.diseq_norm_ac + | true, true => mkACIPrefix ``AC.diseq_norm_aci + return mkApp6 h (← mkExprDecl lhs) (← mkExprDecl rhs) (← mkSeqDecl c.lhs) (← mkSeqDecl c.rhs) eagerReflBoolTrue (← mkDiseqProof a b) + | .erase_dup c₁ => + let h ← mkDupPrefix ``AC.diseq_erase_dup + return mkApp6 h (← mkSeqDecl c₁.lhs) (← mkSeqDecl c₁.rhs) (← mkSeqDecl c.lhs) (← mkSeqDecl c.rhs) eagerReflBoolTrue (← c₁.toExprProof) + | _ => throwError "NIY" + +def DiseqCnstr.setUnsat (c : DiseqCnstr) : ACM Unit := do + let h ← withProofContext do + return mkApp4 (← mkPrefix ``AC.diseq_unsat) (← mkSeqDecl c.lhs) (← mkSeqDecl c.rhs) eagerReflBoolTrue (← c.toExprProof) + closeGoal h + +end Lean.Meta.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind/AC/Seq.lean b/src/Lean/Meta/Tactic/Grind/AC/Seq.lean new file mode 100644 index 0000000000..316bde1097 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/AC/Seq.lean @@ -0,0 +1,17 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Init.Core +public import Init.Grind.AC +public section +namespace Lean.Grind.AC + +def Seq.length : Seq → Nat + | .var _ => 1 + | .cons _ s => s.length + 1 + +end Lean.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind/AC/ToExpr.lean b/src/Lean/Meta/Tactic/Grind/AC/ToExpr.lean new file mode 100644 index 0000000000..97c49687cd --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/AC/ToExpr.lean @@ -0,0 +1,34 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Init.Grind.AC +public import Lean.ToExpr +public section +namespace Lean.Meta.Grind.AC +open Lean.Grind +/-! +`ToExpr` instances for `AC` types. +-/ +def ofSeq (m : AC.Seq) : Expr := + match m with + | .var x => mkApp (mkConst ``AC.Seq.var) (toExpr x) + | .cons x s => mkApp2 (mkConst ``AC.Seq.cons) (toExpr x) (ofSeq s) + +instance : ToExpr AC.Seq where + toExpr := ofSeq + toTypeExpr := mkConst ``AC.Seq + +def ofExpr (m : AC.Expr) : Expr := + match m with + | .var x => mkApp (mkConst ``AC.Expr.var) (toExpr x) + | .op lhs rhs => mkApp2 (mkConst ``AC.Expr.op) (ofExpr lhs) (ofExpr rhs) + +instance : ToExpr AC.Expr where + toExpr := ofExpr + toTypeExpr := mkConst ``AC.Expr + +end Lean.Meta.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind/AC/Types.lean b/src/Lean/Meta/Tactic/Grind/AC/Types.lean index e111f6fee2..ff19fda126 100644 --- a/src/Lean/Meta/Tactic/Grind/AC/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/AC/Types.lean @@ -11,10 +11,57 @@ public import Std.Data.HashMap public import Lean.Expr public import Lean.Data.PersistentArray public import Lean.Meta.Tactic.Grind.ExprPtr +import Lean.Meta.Tactic.Grind.AC.Seq public section - namespace Lean.Meta.Grind.AC -open Lean.Grind.AC +open Lean.Grind + +deriving instance Hashable for AC.Expr, AC.Seq + +mutual +structure EqCnstr where + lhs : AC.Seq + rhs : AC.Seq + h : EqCnstrProof + id : Nat + +inductive EqCnstrProof where + | core (a b : Expr) (ea eb : AC.Expr) + | erase_dup (c : EqCnstr) + | simp_ac (lhs : Bool) (s : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr) + | superpose_ac (s : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr) + | simp_suffix (lhs : Bool) (s : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr) + | simp_prefix (lhs : Bool) (s : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr) + | simp_middle (lhs : Bool) (s₁ s₂ : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr) + | superpose_prefix (s₁ s₂ : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr) +end + +instance : Inhabited EqCnstrProof where + default := .core default default default default + +instance : Inhabited EqCnstr where + default := { lhs := default, rhs := default, h := default, id := 0 } + +protected def EqCnstr.compare (c₁ c₂ : EqCnstr) : Ordering := + (compare c₁.lhs.length c₂.lhs.length) |>.then + (compare c₁.id c₂.id) + +abbrev Queue : Type := Std.TreeSet EqCnstr EqCnstr.compare + +mutual +structure DiseqCnstr where + lhs : AC.Seq + rhs : AC.Seq + h : DiseqCnstrProof + +inductive DiseqCnstrProof where + | core (a b : Expr) (ea eb : AC.Expr) + | erase_dup (c : DiseqCnstr) + | simp_ac (lhs : Bool) (s : AC.Seq) (c₁ : DiseqCnstr) (c₂ : EqCnstr) + | simp_suffix (lhs : Bool) (s : AC.Seq) (c₁ : DiseqCnstr) (c₂ : EqCnstr) + | simp_prefix (lhs : Bool) (s : AC.Seq) (c₁ : DiseqCnstr) (c₂ : EqCnstr) + | simp_middle (lhs : Bool) (s₁ s₂ : AC.Seq) (c₁ : DiseqCnstr) (c₂ : EqCnstr) +end structure Struct where id : Nat @@ -27,13 +74,23 @@ structure Struct where idempotentInst? : Option Expr commInst? : Option Expr neutralInst? : Option Expr + /-- Next unique id for `EqCnstr`s. -/ + nextId : Nat := 0 /-- Mapping from variables to their denotations. Remark each variable can be in only one ring. -/ vars : PArray Expr := {} /-- Mapping from `Expr` to a variable representing it. -/ - varMap : PHashMap ExprPtr Var := {} + varMap : PHashMap ExprPtr AC.Var := {} + /-- Mapping from Lean expressions to their representations as `AC.Expr` -/ + denote : PHashMap ExprPtr AC.Expr := {} + /-- Equations to process. -/ + queue : Queue := {} + /-- Processed equations. -/ + basis : List EqCnstr := {} + /-- Disequalities. -/ + diseqs : PArray DiseqCnstr := {} deriving Inhabited /-- State for all associative operators detected by `grind`. -/ @@ -47,8 +104,10 @@ structure State where Mapping from operators to its "operator id". We cache failures using `none`. `opIdOf[op]` is `some id`, then `id < structs.size`. -/ opIdOf : PHashMap ExprPtr (Option Nat) := {} - -- Remark: a term may be argument of different associative operators. - -- TODO: add mappings + /-- + Mapping from expressions/terms to their structure ids. + Recall that term may be the argument of different operators. -/ + exprToOpIds : PHashMap ExprPtr (List Nat) := {} deriving Inhabited end Lean.Meta.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind/AC/Util.lean b/src/Lean/Meta/Tactic/Grind/AC/Util.lean index 8cc982012c..d83c994f49 100644 --- a/src/Lean/Meta/Tactic/Grind/AC/Util.lean +++ b/src/Lean/Meta/Tactic/Grind/AC/Util.lean @@ -10,8 +10,8 @@ public import Lean.Meta.Tactic.Grind.ProveEq public import Lean.Meta.Tactic.Grind.SynthInstance public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingId public section - namespace Lean.Meta.Grind.AC +open Lean.Grind def get' : GoalM State := do return (← get).ac @@ -50,6 +50,10 @@ protected def ACM.getStruct : ACM Struct := do instance : MonadGetStruct ACM where getStruct := ACM.getStruct +def modifyStruct (f : Struct → Struct) : ACM Unit := do + let opId ← getOpId + modify' fun s => { s with structs := s.structs.modify opId f } + def getOp : ACM Expr := return (← getStruct).op @@ -71,6 +75,42 @@ private def isArithOpInOtherModules (op : Expr) (f : Expr) : GoalM Bool := do if (← Arith.CommRing.getSemiringId? α).isSome then return true return false +def getTermOpIds (e : Expr) : GoalM (List Nat) := do + return (← get').exprToOpIds.find? { expr := e } |>.getD [] + +private def insertOpId (m : PHashMap ExprPtr (List Nat)) (e : Expr) (opId : Nat) : PHashMap ExprPtr (List Nat) := + let ids := if let some ids := m.find? { expr := e } then + go ids + else + [opId] + m.insert { expr := e } ids +where + go : List Nat → List Nat + | [] => [opId] + | id::ids => if opId < id then + opId :: id :: ids + else if opId == id then + opId :: ids + else + id :: go ids + +def addTermOpId (e : Expr) : ACM Unit := do + let opId ← getOpId + modify' fun s => { s with exprToOpIds := insertOpId s.exprToOpIds e opId } + +def mkVar (e : Expr) : ACM AC.Var := do + let s ← getStruct + if let some var := s.varMap.find? { expr := e } then + return var + let var : AC.Var := s.vars.size + modifyStruct fun s => { s with + vars := s.vars.push e + varMap := s.varMap.insert { expr := e } var + } + addTermOpId e + markAsACTerm e + return var + def getOpId? (op : Expr) : GoalM (Option Nat) := do if let some id? := (← get').opIdOf.find? { expr := op } then return id? @@ -98,7 +138,7 @@ where let idempotentInst? ← synthInstance? idempotentType let (neutralInst?, neutral?) ← do let neutral ← mkFreshExprMVar α - let identityType := mkApp3 (mkConst ``Std.Identity [u]) α op neutral + let identityType := mkApp3 (mkConst ``Std.LawfulIdentity [u]) α op neutral if let some identityInst ← synthInstance? identityType then let neutral ← instantiateExprMVars neutral let neutral ← preprocessLight neutral @@ -112,8 +152,24 @@ where id, type := α, u, op, neutral?, assocInst, commInst?, idempotentInst?, neutralInst? }} - -- TODO: neutral element must be variable 0 trace[grind.debug.ac.op] "{op}, comm: {commInst?.isSome}, idempotent: {idempotentInst?.isSome}, neutral?: {neutral?}" + if let some neutral := neutral? then ACM.run id do + -- Create neutral variable to ensure it is variable 0 + discard <| mkVar neutral return some id +def isOp? (e : Expr) : ACM (Option (Expr × Expr)) := do + unless e.isApp && e.appFn!.isApp do return none + unless isSameExpr e.appFn!.appFn! (← getOp) do return none + return some (e.appFn!.appArg!, e.appArg!) + +def isCommutative : ACM Bool := + return (← getStruct).commInst?.isSome + +def hasNeutral : ACM Bool := + return (← getStruct).neutralInst?.isSome + +def isIdempotent : ACM Bool := + return (← getStruct).idempotentInst?.isSome + end Lean.Meta.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind/AC/VarRename.lean b/src/Lean/Meta/Tactic/Grind/AC/VarRename.lean new file mode 100644 index 0000000000..da16f07af4 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/AC/VarRename.lean @@ -0,0 +1,33 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Init.Grind.AC +public import Lean.Meta.Tactic.Grind.VarRename +namespace Lean.Grind.AC +open Lean.Meta.Grind + +public def Seq.renameVars (s : Seq) (f : VarRename) : Seq := + match s with + | .var x => .var (f x) + | .cons x s => .cons (f x) (renameVars s f) + +public def Expr.renameVars (e : Expr) (f : VarRename) : Expr := + match e with + | .var x => .var (f x) + | .op a b => .op (renameVars a f) (renameVars b f) + +public def Seq.collectVars (s : Seq) : VarCollector := + match s with + | .var x => collectVar x + | .cons x s => collectVar x >> s.collectVars + +public def Expr.collectVars (e : Expr) : VarCollector := + match e with + | .var x => collectVar x + | .op a b => a.collectVars >> b.collectVars + +end Lean.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind/Arith.lean b/src/Lean/Meta/Tactic/Grind/Arith.lean index cd53e9321b..265141f7a0 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith.lean @@ -7,7 +7,6 @@ module prelude public import Lean.Meta.Tactic.Grind.Arith.Util -public import Lean.Meta.Tactic.Grind.Arith.ProofUtil public import Lean.Meta.Tactic.Grind.Arith.Types public import Lean.Meta.Tactic.Grind.Arith.Main public import Lean.Meta.Tactic.Grind.Arith.Offset @@ -15,7 +14,6 @@ public import Lean.Meta.Tactic.Grind.Arith.Cutsat public import Lean.Meta.Tactic.Grind.Arith.CommRing public import Lean.Meta.Tactic.Grind.Arith.Linear public import Lean.Meta.Tactic.Grind.Arith.Simproc -public import Lean.Meta.Tactic.Grind.Arith.VarRename public import Lean.Meta.Tactic.Grind.Arith.Internalize public section diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing.lean index 86ceaf42fe..ec878454e6 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/CommRing.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing.lean @@ -4,7 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ module - prelude public import Lean.Util.Trace public import Lean.Meta.Tactic.Grind.Arith.CommRing.Poly diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Proof.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Proof.lean index aff198b2ba..a7747f0fb4 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Proof.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Proof.lean @@ -4,22 +4,19 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ module - prelude public import Init.Grind.Ring.OfSemiring public import Lean.Meta.Tactic.Grind.Diseq -public import Lean.Meta.Tactic.Grind.Arith.ProofUtil +public import Lean.Meta.Tactic.Grind.ProofUtil public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingId public import Lean.Meta.Tactic.Grind.Arith.CommRing.DenoteExpr public import Lean.Meta.Tactic.Grind.Arith.CommRing.SafePoly public import Lean.Meta.Tactic.Grind.Arith.CommRing.ToExpr public import Lean.Meta.Tactic.Grind.Arith.CommRing.SemiringM -public import Lean.Meta.Tactic.Grind.Arith.VarRename +public import Lean.Meta.Tactic.Grind.VarRename import Lean.Meta.Tactic.Grind.Arith.CommRing.VarRename import Lean.Meta.Tactic.Grind.Arith.CommRing.Functions - public section - namespace Lean.Meta.Grind.Arith.CommRing /-- Returns a context of type `RArray α` containing the variables `vars` where diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/ToExpr.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/ToExpr.lean index 23bbb9f7a1..ad918d490e 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/ToExpr.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/ToExpr.lean @@ -4,14 +4,11 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ module - prelude public import Init.Grind.Ring.Poly public import Init.Grind.Ring.OfSemiring public import Lean.ToExpr - public section - namespace Lean.Meta.Grind.Arith.CommRing open Grind.CommRing /-! diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/VarRename.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/VarRename.lean index 7f15cc7e86..d1b7e3b0b1 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/VarRename.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/VarRename.lean @@ -4,14 +4,12 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ module - prelude public import Init.Grind.Ring.Poly public import Init.Grind.Ring.OfSemiring -public import Lean.Meta.Tactic.Grind.Arith.VarRename - +public import Lean.Meta.Tactic.Grind.VarRename namespace Lean.Grind.CommRing -open Lean.Meta.Grind.Arith +open Lean.Meta.Grind public def Power.renameVars (pw : Power) (f : VarRename) : Power := { pw with x := (f pw.x) } @@ -59,7 +57,7 @@ public def Expr.collectVars (e : Expr) : VarCollector := end Lean.Grind.CommRing namespace Lean.Grind.Ring.OfSemiring -open Lean.Meta.Grind.Arith +open Lean.Meta.Grind public def Expr.renameVars (e : Expr) (f : VarRename) : Expr := match e with diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean index 2cb8ee42e3..c03ef11ae3 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean @@ -8,8 +8,8 @@ prelude public import Init.Grind.Ring.Poly public import Lean.Meta.Tactic.Grind.Types import Lean.Meta.Tactic.Grind.Diseq -import Lean.Meta.Tactic.Grind.Arith.ProofUtil -import Lean.Meta.Tactic.Grind.Arith.VarRename +import Lean.Meta.Tactic.Grind.ProofUtil +import Lean.Meta.Tactic.Grind.VarRename import Lean.Meta.Tactic.Grind.Arith.Cutsat.CommRing import Lean.Meta.Tactic.Grind.Arith.Cutsat.Util import Lean.Meta.Tactic.Grind.Arith.Cutsat.Nat diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/VarRename.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/VarRename.lean index f65b4018b3..4b70194792 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/VarRename.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/VarRename.lean @@ -4,13 +4,11 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ module - prelude public import Init.Data.Int.Linear -public import Lean.Meta.Tactic.Grind.Arith.VarRename - +public import Lean.Meta.Tactic.Grind.VarRename namespace Int.Linear -open Lean.Meta.Grind.Arith +open Lean.Meta.Grind public def Poly.renameVars (p : Poly) (f : VarRename) : Poly := match p with diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Linear/Proof.lean b/src/Lean/Meta/Tactic/Grind/Arith/Linear/Proof.lean index c0c617ae9a..40bd165cf7 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Linear/Proof.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Linear/Proof.lean @@ -9,11 +9,11 @@ public import Lean.Meta.Tactic.Grind.Arith.Linear.Util import Lean.Meta.Tactic.Grind.Arith.Util import Lean.Meta.Tactic.Grind.Arith.Linear.ToExpr import Lean.Meta.Tactic.Grind.Arith.Linear.DenoteExpr -import Lean.Meta.Tactic.Grind.Arith.VarRename +import Lean.Meta.Tactic.Grind.VarRename import Lean.Meta.Tactic.Grind.Diseq +import Lean.Meta.Tactic.Grind.ProofUtil import Lean.Meta.Tactic.Grind.Arith.CommRing.VarRename import Lean.Meta.Tactic.Grind.Arith.CommRing.ToExpr -import Lean.Meta.Tactic.Grind.Arith.ProofUtil import Lean.Meta.Tactic.Grind.Arith.Linear.VarRename import Lean.Meta.Tactic.Grind.Arith.CommRing.VarRename diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Linear/VarRename.lean b/src/Lean/Meta/Tactic/Grind/Arith/Linear/VarRename.lean index 6504e20ebd..d264927894 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Linear/VarRename.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Linear/VarRename.lean @@ -7,10 +7,10 @@ module prelude public import Init.Grind.Ordered.Linarith -public import Lean.Meta.Tactic.Grind.Arith.VarRename +public import Lean.Meta.Tactic.Grind.VarRename namespace Lean.Grind.Linarith -open Lean.Meta.Grind.Arith +open Lean.Meta.Grind public def Poly.renameVars (p : Poly) (f : VarRename) : Poly := match p with diff --git a/src/Lean/Meta/Tactic/Grind/Core.lean b/src/Lean/Meta/Tactic/Grind/Core.lean index 7fb0326d66..284aaefb08 100644 --- a/src/Lean/Meta/Tactic/Grind/Core.lean +++ b/src/Lean/Meta/Tactic/Grind/Core.lean @@ -215,6 +215,31 @@ def propagateLinarith : PendingTheoryPropagation → GoalM Unit | .diseqs ps => propagateLinarithDiseqs ps | _ => return () +/-- +Helper function for combining `ENode.ac?` fields and detecting what needs to be +propagated to the ac module. +-/ +private def checkACEq (rhsRoot lhsRoot : ENode) : GoalM PendingTheoryPropagation := do + match lhsRoot.ac? with + | some lhs => + if let some rhs := rhsRoot.ac? then + return .eq lhs rhs + else + -- We have to retrieve the node because other fields have been updated + let rhsRoot ← getENode rhsRoot.self + setENode rhsRoot.self { rhsRoot with ac? := lhs } + return .diseqs (← getParents rhsRoot.self) + | none => + if rhsRoot.ac?.isSome then + return .diseqs (← getParents lhsRoot.self) + else + return .none + +def propagateAC : PendingTheoryPropagation → GoalM Unit + | .eq lhs rhs => AC.processNewEq lhs rhs + | .diseqs ps => propagateACDiseqs ps + | _ => return () + /-- Tries to apply beta-reduction using the parent applications of the functions in `fns` with the lambda expressions in `lams`. @@ -349,6 +374,7 @@ where let cutsatTodo ← checkCutsatEq rhsRoot lhsRoot let ringTodo ← checkCommRingEq rhsRoot lhsRoot let linarithTodo ← checkLinarithEq rhsRoot lhsRoot + let ACTodo ← checkACEq rhsRoot lhsRoot resetParentsOf lhsRoot.self copyParentsTo parents rhsNode.root unless (← isInconsistent) do @@ -363,6 +389,7 @@ where propagateCutsat cutsatTodo propagateCommRing ringTodo propagateLinarith linarithTodo + propagateAC ACTodo updateRoots (lhs : Expr) (rootNew : Expr) : GoalM Unit := do let isFalseRoot ← isFalseExpr rootNew traverseEqc lhs fun n => do diff --git a/src/Lean/Meta/Tactic/Grind/Arith/ProofUtil.lean b/src/Lean/Meta/Tactic/Grind/ProofUtil.lean similarity index 91% rename from src/Lean/Meta/Tactic/Grind/Arith/ProofUtil.lean rename to src/Lean/Meta/Tactic/Grind/ProofUtil.lean index b07a792038..96fe0a92e9 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/ProofUtil.lean +++ b/src/Lean/Meta/Tactic/Grind/ProofUtil.lean @@ -4,13 +4,10 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ module - prelude public import Lean.Meta.Tactic.Grind.Types - public section - -namespace Lean.Meta.Grind.Arith +namespace Lean.Meta.Grind def mkLetOfMap {_ : Hashable α} {_ : BEq α} (m : Std.HashMap α Expr) (e : Expr) (varPrefix : Name) (varType : Expr) (toExpr : α → Expr) : Expr := Id.run do @@ -25,4 +22,4 @@ def mkLetOfMap {_ : Hashable α} {_ : BEq α} (m : Std.HashMap α Expr) (e : Exp i := i - 1 return e -end Lean.Meta.Grind.Arith +end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Propagate.lean b/src/Lean/Meta/Tactic/Grind/Propagate.lean index f1030983eb..35d3a0b7fc 100644 --- a/src/Lean/Meta/Tactic/Grind/Propagate.lean +++ b/src/Lean/Meta/Tactic/Grind/Propagate.lean @@ -186,6 +186,7 @@ builtin_grind_propagator propagateEqDown ↓Eq := fun e => do propagateCutsatDiseq lhs rhs propagateCommRingDiseq lhs rhs propagateLinarithDiseq lhs rhs + propagateACDiseq lhs rhs let thms ← getExtTheorems α if !thms.isEmpty then /- diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index 7c99f2ef94..af19e68dc2 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -440,7 +440,12 @@ structure ENode where to the linarith module. Its implementation is similar to the `offset?` field. -/ linarith? : Option Expr := none - -- Remark: we expect to have builtin support for offset constraints, cutsat, comm ring, and linarith. + /-- + The `ac?` field is used to propagate equalities from the `grind` congruence closure module + to the ac module. Its implementation is similar to the `offset?` field. + -/ + ac? : Option Expr := none + -- Remark: we expect to have builtin support for offset constraints, cutsat, comm ring, linarith, and ac. -- If the number of satellite solvers increases, we may add support for an arbitrary solvers like done in Z3. deriving Inhabited, Repr @@ -1283,6 +1288,53 @@ def markAsLinarithTerm (e : Expr) : GoalM Unit := do setENode root.self { root with linarith? := some e } propagateLinarithDiseqs (← getParents root.self) +/-- +Notifies the ac module that `a = b` where +`a` and `b` are terms that have been internalized by this module. +-/ +@[extern "lean_process_ac_eq"] -- forward definition +opaque AC.processNewEq (a b : Expr) : GoalM Unit + +/-- +Notifies the ac module that `a ≠ b` where +`a` and `b` are terms that have been internalized by this module. +-/ +@[extern "lean_process_ac_diseq"] -- forward definition +opaque AC.processNewDiseq (a b : Expr) : GoalM Unit + +/-- +Given `lhs` and `rhs` that are known to be disequal, checks whether +`lhs` and `rhs` have ac terms `e₁` and `e₂` attached to them, +and invokes process `AC.processNewDiseq e₁ e₂` +-/ +def propagateACDiseq (lhs rhs : Expr) : GoalM Unit := do + let some lhs ← get? lhs | return () + let some rhs ← get? rhs | return () + AC.processNewDiseq lhs rhs +where + get? (a : Expr) : GoalM (Option Expr) := do + return (← getRootENode a).ac? + +/-- +Traverses disequalities in `parents`, and propagate the ones relevant to the +ac module. +-/ +def propagateACDiseqs (parents : ParentSet) : GoalM Unit := do + forEachDiseq parents propagateACDiseq + +/-- +Marks `e` as a term of interest to the ac module. +If the root of `e`s equivalence class has already a term of interest, +a new equality is propagated to the ac module. +-/ +def markAsACTerm (e : Expr) : GoalM Unit := do + let root ← getRootENode e + if let some e' := root.ac? then + AC.processNewEq e e' + else + setENode root.self { root with ac? := some e } + propagateACDiseqs (← getParents root.self) + /-- Returns `true` is `e` is the root of its congruence class. -/ def isCongrRoot (e : Expr) : GoalM Bool := do return (← getENode e).isCongrRoot diff --git a/src/Lean/Meta/Tactic/Grind/Arith/VarRename.lean b/src/Lean/Meta/Tactic/Grind/VarRename.lean similarity index 95% rename from src/Lean/Meta/Tactic/Grind/Arith/VarRename.lean rename to src/Lean/Meta/Tactic/Grind/VarRename.lean index 5dbd3d7336..342f6be835 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/VarRename.lean +++ b/src/Lean/Meta/Tactic/Grind/VarRename.lean @@ -4,13 +4,12 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ module - prelude public import Init.Prelude public import Init.Data.Array.QSort public import Std.Data.HashSet public section -namespace Lean.Meta.Grind.Arith +namespace Lean.Meta.Grind abbrev Var : Type := Nat abbrev FoundVars := Std.HashSet Nat @@ -45,4 +44,4 @@ def mkVarRename (new2old : Array Var) : VarRename := Id.run do new := new + 1 { map := old2new } -end Lean.Meta.Grind.Arith +end Lean.Meta.Grind diff --git a/src/out b/src/out new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/lean/run/grind_ac_1.lean b/tests/lean/run/grind_ac_1.lean new file mode 100644 index 0000000000..68f0cb9e9b --- /dev/null +++ b/tests/lean/run/grind_ac_1.lean @@ -0,0 +1,81 @@ +open Lean Grind AC +example {α : Type u} (op : α → α → α) [Std.Associative op] (a b c : α) + : op a (op b c) = op (op a b) c := by + grind only + +example {α : Sort u} (op : α → α → α) [Std.Associative op] (a b c : α) + : op a (op b c) = op (op a b) c := by + grind only + +example {α : Sort u} (op : α → α → α) (u : α) [Std.Associative op] [Std.LawfulIdentity op u] (a b c : α) + : op a (op b c) = op (op a b) (op c u) := by + grind only + +example {α : Sort u} (op : α → α → α) [Std.Associative op] [Std.Commutative op] (a b c : α) + : op c (op b a) = op (op b a) c := by + grind only + +example {α : Sort u} (op : α → α → α) (u : α) [Std.Associative op] [Std.Commutative op] [Std.LawfulIdentity op u] (a b c : α) + : op a (op b c) = op (op b a) c := by + grind only + +example {α : Sort u} (op : α → α → α) (u : α) [Std.Associative op] [Std.Commutative op] [Std.LawfulIdentity op u] (a b c : α) + : op a (op b (op u c)) = op (op b a) (op u c) := by + grind only + +example {α : Sort u} (op : α → α → α) [Std.Associative op] [Std.IdempotentOp op] (a b c : α) + : op (op a a) (op b c) = op (op a (op b b)) c := by + grind only + +example {α : Sort u} (op : α → α → α) [Std.Associative op] [Std.Commutative op] [Std.IdempotentOp op] (a b c : α) + : op (op a a) (op b c) = op (op (op b a) (op b b)) c := by + grind only + +example {α : Sort u} (op : α → α → α) (u : α) [Std.Associative op] [Std.Commutative op] [Std.IdempotentOp op] [Std.LawfulIdentity op u] (a b c : α) + : op (op a a) (op b c) = op (op (op b a) (op (op u b) b)) c := by + grind only + +example {α : Type u} (op : α → α → α) [Std.Associative op] [Std.Commutative op] [Std.IdempotentOp op] (a b c : α) + : op (op a a) (op b c) = op (op (op b a) (op b b)) c := by + grind only + +example {α : Type u} (op : α → α → α) (u : α) [Std.Associative op] [Std.Commutative op] [Std.IdempotentOp op] [Std.LawfulIdentity op u] (a b c : α) + : op (op a a) (op b c) = op (op (op b a) (op (op u b) b)) c := by + grind only + +example {α} (as bs cs : List α) : as ++ (bs ++ cs) = ((as ++ []) ++ bs) ++ (cs ++ []) := by + grind only + +example (a b c : Nat) : max a (max b c) = max (max b 0) (max a c) := by + grind only + +/-- +trace: [grind.debug.proof] Classical.byContradiction + (intro_with_eq (¬(max a (max b c) = max (max b 0) (max a c) ∧ min a b = min b a)) + (¬max a (max b c) = max (max b 0) (max a c) ∨ ¬min a b = min b a) False + (Grind.not_and (max a (max b c) = max (max b 0) (max a c)) (min a b = min b a)) fun h => + Or.casesOn h + (fun h_1 => + let ctx := + Context.mk + (RArray.branch 2 (RArray.branch 1 (RArray.leaf (PLift.up 0)) (RArray.leaf (PLift.up a))) + (RArray.branch 3 (RArray.leaf (PLift.up b)) (RArray.leaf (PLift.up c)))) + max; + let e_1 := ((Expr.var 2).op (Expr.var 0)).op ((Expr.var 1).op (Expr.var 3)); + let e_2 := (Expr.var 1).op ((Expr.var 2).op (Expr.var 3)); + let p_1 := Seq.cons 1 (Seq.cons 2 (Seq.var 3)); + diseq_unsat ctx p_1 p_1 (eagerReduce (Eq.refl true)) + (diseq_norm_aci ctx e_2 e_1 p_1 p_1 (eagerReduce (Eq.refl true)) h_1)) + fun h_1 => + let ctx := Context.mk (RArray.branch 1 (RArray.leaf (PLift.up a)) (RArray.leaf (PLift.up b))) min; + let e_1 := (Expr.var 1).op (Expr.var 0); + let e_2 := (Expr.var 0).op (Expr.var 1); + let p_1 := Seq.cons 0 (Seq.var 1); + diseq_unsat ctx p_1 p_1 (eagerReduce (Eq.refl true)) + (diseq_norm_ac ctx e_2 e_1 p_1 p_1 (eagerReduce (Eq.refl true)) h_1)) +-/ +#guard_msgs in +set_option pp.structureInstances false in +set_option trace.grind.debug.proof true in +example (a b c : Nat) : max a (max b c) = max (max b 0) (max a c) ∧ min a b = min b a := by + grind only [cases Or]