diff --git a/src/Init/Grind/AC.lean b/src/Init/Grind/AC.lean index d348ed0e39..d177b2433e 100644 --- a/src/Init/Grind/AC.lean +++ b/src/Init/Grind/AC.lean @@ -15,19 +15,22 @@ public import Init.Data.Bool namespace Lean.Grind.AC abbrev Var := Nat -structure Context (α : Type u) where - vars : RArray α +structure Context (α : Sort u) where + vars : RArray (PLift α) op : α → α → α inductive Expr where - | var (x : Nat) + | var (x : Var) | op (lhs rhs : Expr) deriving Inhabited, Repr, BEq -noncomputable def Expr.denote {α} (ctx : Context α) (e : Expr) : α := - Expr.rec (fun x => ctx.vars.get x) (fun _ _ ih₁ ih₂ => ctx.op ih₁ ih₂) e +noncomputable def Var.denote {α : Sort u} (ctx : Context α) (x : Var) : α := + PLift.rec (fun x => x) (ctx.vars.get x) -theorem Expr.denote_var {α} (ctx : Context α) (x : Var) : (Expr.var x).denote ctx = ctx.vars.get x := rfl +noncomputable def Expr.denote {α} (ctx : Context α) (e : Expr) : α := + Expr.rec (fun x => x.denote ctx) (fun _ _ ih₁ ih₂ => ctx.op ih₁ ih₂) e + +theorem Expr.denote_var {α} (ctx : Context α) (x : Var) : (Expr.var x).denote ctx = x.denote ctx := rfl theorem Expr.denote_op {α} (ctx : Context α) (a b : Expr) : (Expr.op a b).denote ctx = ctx.op (a.denote ctx) (b.denote ctx) := rfl attribute [local simp] Expr.denote_var Expr.denote_op @@ -59,10 +62,10 @@ instance : LawfulBEq Seq where rfl := by intro a; induction a <;> simp! [BEq.beq]; assumption noncomputable def Seq.denote {α} (ctx : Context α) (s : Seq) : α := - Seq.rec (fun x => ctx.vars.get x) (fun x _ ih => ctx.op (ctx.vars.get x) ih) s + Seq.rec (fun x => x.denote ctx) (fun x _ ih => ctx.op (x.denote ctx) ih) s -theorem Seq.denote_var {α} (ctx : Context α) (x : Var) : (Seq.var x).denote ctx = ctx.vars.get x := rfl -theorem Seq.denote_op {α} (ctx : Context α) (x : Var) (s : Seq) : (Seq.cons x s).denote ctx = ctx.op (ctx.vars.get x) (s.denote ctx) := rfl +theorem Seq.denote_var {α} (ctx : Context α) (x : Var) : (Seq.var x).denote ctx = x.denote ctx := rfl +theorem Seq.denote_op {α} (ctx : Context α) (x : Var) (s : Seq) : (Seq.cons x s).denote ctx = ctx.op (x.denote ctx) (s.denote ctx) := rfl attribute [local simp] Seq.denote_var Seq.denote_op @@ -152,7 +155,7 @@ theorem Seq.erase0_k_eq_erase0 (s : Seq) : s.erase0_k = s.erase0 := by attribute [local simp] Seq.erase0_k_eq_erase0 -theorem Seq.denote_erase0 {α} (ctx : Context α) {inst : Std.LawfulIdentity ctx.op (ctx.vars.get 0)} (s : Seq) +theorem Seq.denote_erase0 {α} (ctx : Context α) {inst : Std.LawfulIdentity ctx.op (Var.denote ctx 0)} (s : Seq) : s.erase0.denote ctx = s.denote ctx := by fun_induction erase0 s <;> simp_all +zetaDelta next => rw [Std.LawfulLeftIdentity.left_id (self := inst.toLawfulLeftIdentity)] @@ -179,12 +182,12 @@ theorem Seq.insert_k_eq_insert (x : Var) (s : Seq) : insert_k x s = insert x s : attribute [local simp] Seq.insert_k_eq_insert theorem Seq.denote_insert {α} (ctx : Context α) {inst₁ : Std.Associative ctx.op} {inst₂ : Std.Commutative ctx.op} (x : Var) (s : Seq) - : (s.insert x).denote ctx = ctx.op (ctx.vars.get x) (s.denote ctx) := by + : (s.insert x).denote ctx = ctx.op (x.denote ctx) (s.denote ctx) := by fun_induction insert x s <;> simp next => rw [Std.Commutative.comm (self := inst₂)] next y s h ih => simp [ih, ← Std.Associative.assoc (self := inst₁)] - rw [Std.Commutative.comm (self := inst₂) (ctx.vars.get x)] + rw [Std.Commutative.comm (self := inst₂) (x.denote ctx)] attribute [local simp] Seq.denote_insert @@ -208,7 +211,7 @@ theorem Seq.denote_sort' {α} (ctx : Context α) {inst₁ : Std.Associative ctx. fun_induction sort' s acc <;> simp next x s ih => simp [ih, ← Std.Associative.assoc (self := inst₁)] - rw [Std.Commutative.comm (self := inst₂) (ctx.vars.get x) (s.denote ctx)] + rw [Std.Commutative.comm (self := inst₂) (x.denote ctx) (s.denote ctx)] attribute [local simp] Seq.denote_sort' @@ -387,11 +390,11 @@ theorem Seq.denote_combineFuel {α} (ctx : Context α) {inst₁ : Std.Associativ next ih => simp [ih, Std.Associative.assoc (self := inst₁)] next x₁ s₁ x₂ s₂ h ih => simp [ih] - rw [← Std.Associative.assoc (self := inst₁), ← Std.Associative.assoc (self := inst₁), Std.Commutative.comm (self := inst₂) (ctx.vars.get x₂)] - rw [Std.Associative.assoc (self := inst₁), Std.Associative.assoc (self := inst₁), Std.Associative.assoc (self := inst₁) (ctx.vars.get x₁)] - apply congrArg (ctx.op (ctx.vars.get x₁)) + rw [← Std.Associative.assoc (self := inst₁), ← Std.Associative.assoc (self := inst₁), Std.Commutative.comm (self := inst₂) (x₂.denote ctx)] + rw [Std.Associative.assoc (self := inst₁), Std.Associative.assoc (self := inst₁), Std.Associative.assoc (self := inst₁) (x₁.denote ctx)] + apply congrArg (ctx.op (x₁.denote ctx)) rw [← Std.Associative.assoc (self := inst₁), ← Std.Associative.assoc (self := inst₁) (s₁.denote ctx)] - rw [Std.Commutative.comm (self := inst₂) (ctx.vars.get x₂)] + rw [Std.Commutative.comm (self := inst₂) (x₂.denote ctx)] attribute [local simp] Seq.denote_combineFuel @@ -446,54 +449,65 @@ theorem superpose_ac {α} (ctx : Context α) {inst₁ : Std.Associative ctx.op} apply congrArg (ctx.op (c.denote ctx)) rw [Std.Commutative.comm (self := inst₂) (b.denote ctx)] -noncomputable def norm_a_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool := +noncomputable def eq_norm_a_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool := lhs.toSeq.beq' lhs' |>.and' (rhs.toSeq.beq' rhs') -theorem norm_a {α} (ctx : Context α) {_ : Std.Associative ctx.op} (lhs rhs : Expr) (lhs' rhs' : Seq) - : norm_a_cert lhs rhs lhs' rhs' → lhs.denote ctx = rhs.denote ctx → lhs'.denote ctx = rhs'.denote ctx := by - simp [norm_a_cert]; intro _ _; subst lhs' rhs'; simp +theorem eq_norm_a {α} (ctx : Context α) {_ : Std.Associative ctx.op} (lhs rhs : Expr) (lhs' rhs' : Seq) + : eq_norm_a_cert lhs rhs lhs' rhs' → lhs.denote ctx = rhs.denote ctx → lhs'.denote ctx = rhs'.denote ctx := by + simp [eq_norm_a_cert]; intro _ _; subst lhs' rhs'; simp -noncomputable def norm_ac_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool := +noncomputable def eq_norm_ac_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool := lhs.toSeq.sort.beq' lhs' |>.and' (rhs.toSeq.sort.beq' rhs') -theorem norm_ac {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.Commutative ctx.op} (lhs rhs : Expr) (lhs' rhs' : Seq) - : norm_ac_cert lhs rhs lhs' rhs' → lhs.denote ctx = rhs.denote ctx → lhs'.denote ctx = rhs'.denote ctx := by - simp [norm_ac_cert]; intro _ _; subst lhs' rhs'; simp +theorem eq_norm_ac {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.Commutative ctx.op} (lhs rhs : Expr) (lhs' rhs' : Seq) + : eq_norm_ac_cert lhs rhs lhs' rhs' → lhs.denote ctx = rhs.denote ctx → lhs'.denote ctx = rhs'.denote ctx := by + simp [eq_norm_ac_cert]; intro _ _; subst lhs' rhs'; simp -noncomputable def norm_aci_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool := - lhs.toSeq.erase0.sort.beq' lhs' |>.and' (rhs.toSeq.erase0.sort.beq' rhs') - -theorem norm_aci {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.Commutative ctx.op} {_ : Std.LawfulIdentity ctx.op (ctx.vars.get 0)} - (lhs rhs : Expr) (lhs' rhs' : Seq) : norm_aci_cert lhs rhs lhs' rhs' → lhs.denote ctx = rhs.denote ctx → lhs'.denote ctx = rhs'.denote ctx := by - simp [norm_aci_cert]; intro _ _; subst lhs' rhs'; simp - -noncomputable def norm_ai_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool := +noncomputable def eq_norm_ai_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool := lhs.toSeq.erase0.beq' lhs' |>.and' (rhs.toSeq.erase0.beq' rhs') -theorem norm_ai {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.LawfulIdentity ctx.op (ctx.vars.get 0)} - (lhs rhs : Expr) (lhs' rhs' : Seq) : norm_ai_cert lhs rhs lhs' rhs' → lhs.denote ctx = rhs.denote ctx → lhs'.denote ctx = rhs'.denote ctx := by - simp [norm_ai_cert]; intro _ _; subst lhs' rhs'; simp +theorem eq_norm_ai {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.LawfulIdentity ctx.op (Var.denote ctx 0)} + (lhs rhs : Expr) (lhs' rhs' : Seq) : eq_norm_ai_cert lhs rhs lhs' rhs' → lhs.denote ctx = rhs.denote ctx → lhs'.denote ctx = rhs'.denote ctx := by + simp [eq_norm_ai_cert]; intro _ _; subst lhs' rhs'; simp -noncomputable def norm_acip_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool := - lhs.toSeq.erase0.sort.eraseDup.beq' lhs' |>.and' (rhs.toSeq.erase0.sort.eraseDup.beq' rhs') +noncomputable def eq_norm_aci_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool := + lhs.toSeq.erase0.sort.beq' lhs' |>.and' (rhs.toSeq.erase0.sort.beq' rhs') -theorem norm_acip {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.Commutative ctx.op} - {_ : Std.LawfulIdentity ctx.op (ctx.vars.get 0)} {_ : Std.IdempotentOp ctx.op} - (lhs rhs : Expr) (lhs' rhs' : Seq) : norm_acip_cert lhs rhs lhs' rhs' → lhs.denote ctx = rhs.denote ctx → lhs'.denote ctx = rhs'.denote ctx := by - simp [norm_acip_cert]; intro _ _; subst lhs' rhs'; simp +theorem eq_norm_aci {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.Commutative ctx.op} {_ : Std.LawfulIdentity ctx.op (Var.denote ctx 0)} + (lhs rhs : Expr) (lhs' rhs' : Seq) : eq_norm_aci_cert lhs rhs lhs' rhs' → lhs.denote ctx = rhs.denote ctx → lhs'.denote ctx = rhs'.denote ctx := by + simp [eq_norm_aci_cert]; intro _ _; subst lhs' rhs'; simp -noncomputable def norm_acp_cert (lhs rhs : Expr) (lhs' rhs' : Seq) : Bool := - lhs.toSeq.sort.eraseDup.beq' lhs' |>.and' (rhs.toSeq.sort.eraseDup.beq' rhs') - -theorem norm_acp {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.Commutative ctx.op} {_ : Std.IdempotentOp ctx.op} - (lhs rhs : Expr) (lhs' rhs' : Seq) : norm_acp_cert lhs rhs lhs' rhs' → lhs.denote ctx = rhs.denote ctx → lhs'.denote ctx = rhs'.denote ctx := by - simp [norm_acp_cert]; intro _ _; subst lhs' rhs'; simp - -noncomputable def norm_dup_cert (lhs rhs lhs' rhs' : Seq) : Bool := +noncomputable def eq_erase_dup_cert (lhs rhs lhs' rhs' : Seq) : Bool := lhs.eraseDup.beq' lhs' |>.and' (rhs.eraseDup.beq' rhs') -theorem norm_dup (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.IdempotentOp ctx.op} - (lhs rhs lhs' rhs' : Seq) : norm_dup_cert lhs rhs lhs' rhs' → lhs.denote ctx = rhs.denote ctx → lhs'.denote ctx = rhs'.denote ctx := by - simp [norm_dup_cert]; intro _ _; subst lhs' rhs'; simp +theorem eq_erase_dup {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.IdempotentOp ctx.op} + (lhs rhs lhs' rhs' : Seq) : eq_erase_dup_cert lhs rhs lhs' rhs' → lhs.denote ctx = rhs.denote ctx → lhs'.denote ctx = rhs'.denote ctx := by + simp [eq_erase_dup_cert]; intro _ _; subst lhs' rhs'; simp + +theorem diseq_norm_a {α} (ctx : Context α) {_ : Std.Associative ctx.op} (lhs rhs : Expr) (lhs' rhs' : Seq) + : eq_norm_a_cert lhs rhs lhs' rhs' → lhs.denote ctx ≠ rhs.denote ctx → lhs'.denote ctx ≠ rhs'.denote ctx := by + simp [eq_norm_a_cert]; intro _ _; subst lhs' rhs'; simp + +theorem diseq_norm_ac {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.Commutative ctx.op} (lhs rhs : Expr) (lhs' rhs' : Seq) + : eq_norm_ac_cert lhs rhs lhs' rhs' → lhs.denote ctx ≠ rhs.denote ctx → lhs'.denote ctx ≠ rhs'.denote ctx := by + simp [eq_norm_ac_cert]; intro _ _; subst lhs' rhs'; simp + +theorem diseq_norm_ai {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.LawfulIdentity ctx.op (Var.denote ctx 0)} + (lhs rhs : Expr) (lhs' rhs' : Seq) : eq_norm_ai_cert lhs rhs lhs' rhs' → lhs.denote ctx ≠ rhs.denote ctx → lhs'.denote ctx ≠ rhs'.denote ctx := by + simp [eq_norm_ai_cert]; intro _ _; subst lhs' rhs'; simp + +theorem diseq_norm_aci {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.Commutative ctx.op} {_ : Std.LawfulIdentity ctx.op (Var.denote ctx 0)} + (lhs rhs : Expr) (lhs' rhs' : Seq) : eq_norm_aci_cert lhs rhs lhs' rhs' → lhs.denote ctx ≠ rhs.denote ctx → lhs'.denote ctx ≠ rhs'.denote ctx := by + simp [eq_norm_aci_cert]; intro _ _; subst lhs' rhs'; simp + +theorem diseq_erase_dup {α} (ctx : Context α) {_ : Std.Associative ctx.op} {_ : Std.IdempotentOp ctx.op} + (lhs rhs lhs' rhs' : Seq) : eq_erase_dup_cert lhs rhs lhs' rhs' → lhs.denote ctx ≠ rhs.denote ctx → lhs'.denote ctx ≠ rhs'.denote ctx := by + simp [eq_erase_dup_cert]; intro _ _; subst lhs' rhs'; simp + +noncomputable def diseq_unsat_cert (lhs rhs : Seq) : Bool := + lhs.beq' rhs + +theorem diseq_unsat {α} (ctx : Context α) (lhs rhs : Seq) : diseq_unsat_cert lhs rhs → lhs.denote ctx ≠ rhs.denote ctx → False := by + simp [diseq_unsat_cert]; intro; subst lhs; simp end Lean.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind.lean b/src/Lean/Meta/Tactic/Grind.lean index da96eb091d..2a4d3cd49e 100644 --- a/src/Lean/Meta/Tactic/Grind.lean +++ b/src/Lean/Meta/Tactic/Grind.lean @@ -37,6 +37,8 @@ public import Lean.Meta.Tactic.Grind.LawfulEqCmp public import Lean.Meta.Tactic.Grind.ReflCmp public import Lean.Meta.Tactic.Grind.SynthInstance public import Lean.Meta.Tactic.Grind.AC +public import Lean.Meta.Tactic.Grind.VarRename +public import Lean.Meta.Tactic.Grind.ProofUtil public section diff --git a/src/Lean/Meta/Tactic/Grind/AC.lean b/src/Lean/Meta/Tactic/Grind/AC.lean index 6f8183b879..85f49e2d58 100644 --- a/src/Lean/Meta/Tactic/Grind/AC.lean +++ b/src/Lean/Meta/Tactic/Grind/AC.lean @@ -9,6 +9,12 @@ public import Lean.Meta.Tactic.Grind.AC.Types public import Lean.Meta.Tactic.Grind.AC.Util public import Lean.Meta.Tactic.Grind.AC.Var public import Lean.Meta.Tactic.Grind.AC.Internalize +public import Lean.Meta.Tactic.Grind.AC.Eq +public import Lean.Meta.Tactic.Grind.AC.Seq +public import Lean.Meta.Tactic.Grind.AC.Proof +public import Lean.Meta.Tactic.Grind.AC.DenoteExpr +public import Lean.Meta.Tactic.Grind.AC.ToExpr +public import Lean.Meta.Tactic.Grind.AC.VarRename public section namespace Lean builtin_initialize registerTraceClass `grind.ac diff --git a/src/Lean/Meta/Tactic/Grind/AC/DenoteExpr.lean b/src/Lean/Meta/Tactic/Grind/AC/DenoteExpr.lean new file mode 100644 index 0000000000..8efd379f68 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/AC/DenoteExpr.lean @@ -0,0 +1,33 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Lean.Meta.Tactic.Grind.AC.Util +import Lean.Meta.AppBuilder +public section +namespace Lean.Meta.Grind.AC +open Lean.Grind +variable [Monad M] [MonadGetStruct M] [MonadError M] + +def _root_.Lean.Grind.AC.Seq.denoteExpr (s : AC.Seq) : M Expr := do + match s with + | .var x => return (← getStruct).vars[x]! + | .cons x s => return mkApp2 (← getStruct).op (← getStruct).vars[x]! (← denoteExpr s) + +def _root_.Lean.Grind.AC.Expr.denoteExpr (e : AC.Expr) : M Expr := do + match e with + | .var x => return (← getStruct).vars[x]! + | .op lhs rhs => return mkApp2 (← getStruct).op (← denoteExpr lhs) (← denoteExpr rhs) + +def EqCnstr.denoteExpr (c : EqCnstr) : M Expr := do + let s ← getStruct + return mkApp3 (mkConst ``Eq [s.u]) s.type (← c.lhs.denoteExpr) (← c.rhs.denoteExpr) + +def DiseqCnstr.denoteExpr (c : DiseqCnstr) : M Expr := do + let s ← getStruct + return mkApp3 (mkConst ``Ne [s.u]) s.type (← c.lhs.denoteExpr) (← c.rhs.denoteExpr) + +end Lean.Meta.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind/AC/Eq.lean b/src/Lean/Meta/Tactic/Grind/AC/Eq.lean new file mode 100644 index 0000000000..e902cb8373 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/AC/Eq.lean @@ -0,0 +1,79 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Lean.Meta.Tactic.Grind.AC.Util +import Lean.Meta.Tactic.Grind.AC.DenoteExpr +import Lean.Meta.Tactic.Grind.AC.Proof +public section +namespace Lean.Meta.Grind.AC +open Lean.Grind +/-- For each structure `s` s.t. `a` and `b` are elements of, execute `k` -/ +@[specialize] def withExprs (a b : Expr) (k : ACM Unit) : GoalM Unit := do + let ids₁ ← getTermOpIds a + if ids₁.isEmpty then return () + let ids₂ ← getTermOpIds b + go ids₁ ids₂ +where + go : List Nat → List Nat → GoalM Unit + | [], _ => return () + | _, [] => return () + | id₁::ids₁, id₂::ids₂ => do + if id₁ == id₂ then + ACM.run id₁ k; go ids₁ ids₂ + else if id₁ < id₂ then + go ids₁ (id₂::ids₂) + else + go (id₁::ids₁) ids₂ + +def asACExpr (e : Expr) : ACM AC.Expr := do + if let some e' := (← getStruct).denote.find? { expr := e } then + return e' + else + return .var ((← getStruct).varMap.find? { expr := e }).get! + +def norm (e : AC.Expr) : ACM AC.Seq := do + match (← isCommutative), (← hasNeutral) with + | true, true => return e.toSeq.erase0.sort + | true, false => return e.toSeq.sort + | false, true => return e.toSeq.erase0 + | false, false => return e.toSeq + +def saveDiseq (c : DiseqCnstr) : ACM Unit := do + modifyStruct fun s => { s with diseqs := s.diseqs.push c } + +def DiseqCnstr.eraseDup (c : DiseqCnstr) : ACM DiseqCnstr := do + unless (← isIdempotent) do return c + let lhs := c.lhs.eraseDup + let rhs := c.rhs.eraseDup + if c.lhs == lhs && c.rhs == rhs then + return c + else + return { lhs, rhs, h := .erase_dup c } + +def DiseqCnstr.assert (c : DiseqCnstr) : ACM Unit := do + let c ← c.eraseDup + -- TODO: simplify and check conflict + trace[grind.ac.assert] "{← c.denoteExpr}" + if c.lhs == c.rhs then + c.setUnsat + else + saveDiseq c + +@[export lean_process_ac_eq] +def processNewEqImpl (a b : Expr) : GoalM Unit := withExprs a b do + trace[grind.ac.assert] "{a} = {b}" + -- TODO + +@[export lean_process_ac_diseq] +def processNewDiseqImpl (a b : Expr) : GoalM Unit := withExprs a b do + let ea ← asACExpr a + let lhs ← norm ea + let eb ← asACExpr b + let rhs ← norm eb + { lhs, rhs, h := .core a b ea eb : DiseqCnstr }.assert + +end Lean.Meta.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind/AC/Internalize.lean b/src/Lean/Meta/Tactic/Grind/AC/Internalize.lean index d3f0bf85f3..b9eafcf615 100644 --- a/src/Lean/Meta/Tactic/Grind/AC/Internalize.lean +++ b/src/Lean/Meta/Tactic/Grind/AC/Internalize.lean @@ -7,6 +7,7 @@ module prelude public import Lean.Meta.Tactic.Grind.Types public import Lean.Meta.Tactic.Grind.AC.Util +import Lean.Meta.Tactic.Grind.AC.DenoteExpr public section namespace Lean.Meta.Grind.AC @@ -15,6 +16,12 @@ private def isParentSameOpApp (parent? : Option Expr) (op : Expr) : GoalM Bool : unless e.isApp && e.appFn!.isApp do return false return isSameExpr e.appFn!.appFn! op +partial def reify (e : Expr) : ACM Grind.AC.Expr := do + if let some (a, b) ← isOp? e then + return .op (← reify a) (← reify b) + else + return .var (← mkVar e) + @[export lean_grind_ac_internalize] def internalizeImpl (e : Expr) (parent? : Option Expr) : GoalM Unit := do unless (← getConfig).ac do return () @@ -22,7 +29,12 @@ def internalizeImpl (e : Expr) (parent? : Option Expr) : GoalM Unit := do let op := e.appFn!.appFn! let some id ← getOpId? op | return () if (← isParentSameOpApp parent? op) then return () - trace[grind.ac.internalize] "[{id}] {e}" - -- TODO: internalize `e` + ACM.run id do + if (← getStruct).denote.contains { expr := e } then return () + let e' ← reify e + modifyStruct fun s => { s with denote := s.denote.insert { expr := e } e' } + trace[grind.ac.internalize] "[{id}] {← e'.denoteExpr}" + addTermOpId e + markAsACTerm e end Lean.Meta.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind/AC/Proof.lean b/src/Lean/Meta/Tactic/Grind/AC/Proof.lean new file mode 100644 index 0000000000..d2962d88f2 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/AC/Proof.lean @@ -0,0 +1,142 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Lean.Meta.Tactic.Grind.AC.Util +import Lean.Meta.Tactic.Grind.Diseq +import Lean.Meta.Tactic.Grind.ProofUtil +import Lean.Meta.Tactic.Grind.AC.ToExpr +import Lean.Meta.Tactic.Grind.AC.VarRename +public section +namespace Lean.Meta.Grind.AC +open Lean.Grind + +structure ProofM.State where + cache : Std.HashMap UInt64 Expr := {} + exprDecls : Std.HashMap AC.Expr Expr := {} + seqDecls : Std.HashMap AC.Seq Expr := {} + +structure ProofM.Context where + ctx : Expr + +abbrev ProofM := ReaderT ProofM.Context (StateRefT ProofM.State ACM) + +/-- Returns a Lean expression representing the variable context used to construct `AC` proof steps. -/ +private def getContext : ProofM Expr := do + return (← read).ctx + +private abbrev caching (c : α) (k : ProofM Expr) : ProofM Expr := do + let addr := unsafe (ptrAddrUnsafe c).toUInt64 >>> 2 + if let some h := (← get).cache[addr]? then + return h + else + let h ← k + modify fun s => { s with cache := s.cache.insert addr h } + return h + +local macro "declare! " decls:ident a:ident : term => + `(do if let some x := (← get).$decls[$a]? then + return x + let x := mkFVar (← mkFreshFVarId); + modify fun s => { s with $decls:ident := (s.$decls).insert $a x }; + return x) + +private def mkSeqDecl (s : AC.Seq) : ProofM Expr := do + declare! seqDecls s + +private def mkExprDecl (e : AC.Expr) : ProofM Expr := do + declare! exprDecls e + +private def getCommInst : ACM Expr := do + let some inst := (← getStruct).commInst? | throwError "`grind` internal error, `{(← getStruct).op}` is not commutative" + return inst + +private def getIdempotentInst : ACM Expr := do + let some inst := (← getStruct).idempotentInst? | throwError "`grind` internal error, `{(← getStruct).op}` is not idempotent" + return inst + +private def getNeutralInst : ACM Expr := do + let some inst := (← getStruct).neutralInst? | throwError "`grind` internal error, `{(← getStruct).op}` does not have a neutral element" + return inst + +private def mkPrefix (declName : Name) : ProofM Expr := do + let s ← getStruct + return mkApp2 (mkConst declName [s.u]) s.type (← getContext) + +private def mkAPrefix (declName : Name) : ProofM Expr := do + let s ← getStruct + return mkApp3 (mkConst declName [s.u]) s.type (← getContext) s.assocInst + +private def mkACPrefix (declName : Name) : ProofM Expr := do + let s ← getStruct + return mkApp4 (mkConst declName [s.u]) s.type (← getContext) s.assocInst (← getCommInst) + +private def mkAIPrefix (declName : Name) : ProofM Expr := do + let s ← getStruct + return mkApp4 (mkConst declName [s.u]) s.type (← getContext) s.assocInst (← getNeutralInst) + +private def mkACIPrefix (declName : Name) : ProofM Expr := do + let s ← getStruct + return mkApp5 (mkConst declName [s.u]) s.type (← getContext) s.assocInst (← getCommInst) (← getNeutralInst) + +private def mkDupPrefix (declName : Name) : ProofM Expr := do + let s ← getStruct + return mkApp4 (mkConst declName [s.u]) s.type (← getContext) s.assocInst (← getIdempotentInst) + +private def toContextExpr (vars : Array Expr) : ACM Expr := do + let s ← getStruct + if h : 0 < vars.size then + RArray.toExpr (mkApp (mkConst ``PLift [s.u]) s.type) id (RArray.ofFn (vars[·]) h) + else unreachable! + +private def mkContext (h : Expr) : ProofM Expr := do + let s ← getStruct + let mut usedVars := + collectMapVars (← get).seqDecls (·.collectVars) >> + collectMapVars (← get).exprDecls (·.collectVars) >> + (if (← hasNeutral) then (collectVar 0) else id) <| {} + let vars' := usedVars.toArray + let varRename := mkVarRename vars' + let vars := (← getStruct).vars + let up := mkApp (mkConst ``PLift.up [s.u]) s.type + let vars := vars'.map fun x => mkApp up vars[x]! + let h := mkLetOfMap (← get).seqDecls h `p (mkConst ``Grind.AC.Seq) fun p => toExpr <| p.renameVars varRename + let h := mkLetOfMap (← get).exprDecls h `e (mkConst ``Grind.AC.Expr) fun e => toExpr <| e.renameVars varRename + let h := h.abstract #[(← read).ctx] + if h.hasLooseBVars then + let ctxType := mkApp (mkConst ``AC.Context [s.u]) s.type + let ctxVal := mkApp3 (mkConst ``AC.Context.mk [s.u]) s.type (← toContextExpr vars) s.op + return .letE `ctx ctxType ctxVal h (nondep := false) + else + return h + +private abbrev withProofContext (x : ProofM Expr) : ACM Expr := do + let ctx := mkFVar (← mkFreshFVarId) + go { ctx } |>.run' {} +where + go : ProofM Expr := do + mkContext (← x) + +partial def DiseqCnstr.toExprProof (c : DiseqCnstr) : ProofM Expr := do caching c do + match c.h with + | .core a b lhs rhs => + let h ← match (← isCommutative), (← hasNeutral) with + | false, false => mkAPrefix ``AC.diseq_norm_a + | false, true => mkAIPrefix ``AC.diseq_norm_ai + | true, false => mkACPrefix ``AC.diseq_norm_ac + | true, true => mkACIPrefix ``AC.diseq_norm_aci + return mkApp6 h (← mkExprDecl lhs) (← mkExprDecl rhs) (← mkSeqDecl c.lhs) (← mkSeqDecl c.rhs) eagerReflBoolTrue (← mkDiseqProof a b) + | .erase_dup c₁ => + let h ← mkDupPrefix ``AC.diseq_erase_dup + return mkApp6 h (← mkSeqDecl c₁.lhs) (← mkSeqDecl c₁.rhs) (← mkSeqDecl c.lhs) (← mkSeqDecl c.rhs) eagerReflBoolTrue (← c₁.toExprProof) + | _ => throwError "NIY" + +def DiseqCnstr.setUnsat (c : DiseqCnstr) : ACM Unit := do + let h ← withProofContext do + return mkApp4 (← mkPrefix ``AC.diseq_unsat) (← mkSeqDecl c.lhs) (← mkSeqDecl c.rhs) eagerReflBoolTrue (← c.toExprProof) + closeGoal h + +end Lean.Meta.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind/AC/Seq.lean b/src/Lean/Meta/Tactic/Grind/AC/Seq.lean new file mode 100644 index 0000000000..316bde1097 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/AC/Seq.lean @@ -0,0 +1,17 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Init.Core +public import Init.Grind.AC +public section +namespace Lean.Grind.AC + +def Seq.length : Seq → Nat + | .var _ => 1 + | .cons _ s => s.length + 1 + +end Lean.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind/AC/ToExpr.lean b/src/Lean/Meta/Tactic/Grind/AC/ToExpr.lean new file mode 100644 index 0000000000..97c49687cd --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/AC/ToExpr.lean @@ -0,0 +1,34 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Init.Grind.AC +public import Lean.ToExpr +public section +namespace Lean.Meta.Grind.AC +open Lean.Grind +/-! +`ToExpr` instances for `AC` types. +-/ +def ofSeq (m : AC.Seq) : Expr := + match m with + | .var x => mkApp (mkConst ``AC.Seq.var) (toExpr x) + | .cons x s => mkApp2 (mkConst ``AC.Seq.cons) (toExpr x) (ofSeq s) + +instance : ToExpr AC.Seq where + toExpr := ofSeq + toTypeExpr := mkConst ``AC.Seq + +def ofExpr (m : AC.Expr) : Expr := + match m with + | .var x => mkApp (mkConst ``AC.Expr.var) (toExpr x) + | .op lhs rhs => mkApp2 (mkConst ``AC.Expr.op) (ofExpr lhs) (ofExpr rhs) + +instance : ToExpr AC.Expr where + toExpr := ofExpr + toTypeExpr := mkConst ``AC.Expr + +end Lean.Meta.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind/AC/Types.lean b/src/Lean/Meta/Tactic/Grind/AC/Types.lean index e111f6fee2..ff19fda126 100644 --- a/src/Lean/Meta/Tactic/Grind/AC/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/AC/Types.lean @@ -11,10 +11,57 @@ public import Std.Data.HashMap public import Lean.Expr public import Lean.Data.PersistentArray public import Lean.Meta.Tactic.Grind.ExprPtr +import Lean.Meta.Tactic.Grind.AC.Seq public section - namespace Lean.Meta.Grind.AC -open Lean.Grind.AC +open Lean.Grind + +deriving instance Hashable for AC.Expr, AC.Seq + +mutual +structure EqCnstr where + lhs : AC.Seq + rhs : AC.Seq + h : EqCnstrProof + id : Nat + +inductive EqCnstrProof where + | core (a b : Expr) (ea eb : AC.Expr) + | erase_dup (c : EqCnstr) + | simp_ac (lhs : Bool) (s : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr) + | superpose_ac (s : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr) + | simp_suffix (lhs : Bool) (s : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr) + | simp_prefix (lhs : Bool) (s : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr) + | simp_middle (lhs : Bool) (s₁ s₂ : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr) + | superpose_prefix (s₁ s₂ : AC.Seq) (c₁ : EqCnstr) (c₂ : EqCnstr) +end + +instance : Inhabited EqCnstrProof where + default := .core default default default default + +instance : Inhabited EqCnstr where + default := { lhs := default, rhs := default, h := default, id := 0 } + +protected def EqCnstr.compare (c₁ c₂ : EqCnstr) : Ordering := + (compare c₁.lhs.length c₂.lhs.length) |>.then + (compare c₁.id c₂.id) + +abbrev Queue : Type := Std.TreeSet EqCnstr EqCnstr.compare + +mutual +structure DiseqCnstr where + lhs : AC.Seq + rhs : AC.Seq + h : DiseqCnstrProof + +inductive DiseqCnstrProof where + | core (a b : Expr) (ea eb : AC.Expr) + | erase_dup (c : DiseqCnstr) + | simp_ac (lhs : Bool) (s : AC.Seq) (c₁ : DiseqCnstr) (c₂ : EqCnstr) + | simp_suffix (lhs : Bool) (s : AC.Seq) (c₁ : DiseqCnstr) (c₂ : EqCnstr) + | simp_prefix (lhs : Bool) (s : AC.Seq) (c₁ : DiseqCnstr) (c₂ : EqCnstr) + | simp_middle (lhs : Bool) (s₁ s₂ : AC.Seq) (c₁ : DiseqCnstr) (c₂ : EqCnstr) +end structure Struct where id : Nat @@ -27,13 +74,23 @@ structure Struct where idempotentInst? : Option Expr commInst? : Option Expr neutralInst? : Option Expr + /-- Next unique id for `EqCnstr`s. -/ + nextId : Nat := 0 /-- Mapping from variables to their denotations. Remark each variable can be in only one ring. -/ vars : PArray Expr := {} /-- Mapping from `Expr` to a variable representing it. -/ - varMap : PHashMap ExprPtr Var := {} + varMap : PHashMap ExprPtr AC.Var := {} + /-- Mapping from Lean expressions to their representations as `AC.Expr` -/ + denote : PHashMap ExprPtr AC.Expr := {} + /-- Equations to process. -/ + queue : Queue := {} + /-- Processed equations. -/ + basis : List EqCnstr := {} + /-- Disequalities. -/ + diseqs : PArray DiseqCnstr := {} deriving Inhabited /-- State for all associative operators detected by `grind`. -/ @@ -47,8 +104,10 @@ structure State where Mapping from operators to its "operator id". We cache failures using `none`. `opIdOf[op]` is `some id`, then `id < structs.size`. -/ opIdOf : PHashMap ExprPtr (Option Nat) := {} - -- Remark: a term may be argument of different associative operators. - -- TODO: add mappings + /-- + Mapping from expressions/terms to their structure ids. + Recall that term may be the argument of different operators. -/ + exprToOpIds : PHashMap ExprPtr (List Nat) := {} deriving Inhabited end Lean.Meta.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind/AC/Util.lean b/src/Lean/Meta/Tactic/Grind/AC/Util.lean index 8cc982012c..d83c994f49 100644 --- a/src/Lean/Meta/Tactic/Grind/AC/Util.lean +++ b/src/Lean/Meta/Tactic/Grind/AC/Util.lean @@ -10,8 +10,8 @@ public import Lean.Meta.Tactic.Grind.ProveEq public import Lean.Meta.Tactic.Grind.SynthInstance public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingId public section - namespace Lean.Meta.Grind.AC +open Lean.Grind def get' : GoalM State := do return (← get).ac @@ -50,6 +50,10 @@ protected def ACM.getStruct : ACM Struct := do instance : MonadGetStruct ACM where getStruct := ACM.getStruct +def modifyStruct (f : Struct → Struct) : ACM Unit := do + let opId ← getOpId + modify' fun s => { s with structs := s.structs.modify opId f } + def getOp : ACM Expr := return (← getStruct).op @@ -71,6 +75,42 @@ private def isArithOpInOtherModules (op : Expr) (f : Expr) : GoalM Bool := do if (← Arith.CommRing.getSemiringId? α).isSome then return true return false +def getTermOpIds (e : Expr) : GoalM (List Nat) := do + return (← get').exprToOpIds.find? { expr := e } |>.getD [] + +private def insertOpId (m : PHashMap ExprPtr (List Nat)) (e : Expr) (opId : Nat) : PHashMap ExprPtr (List Nat) := + let ids := if let some ids := m.find? { expr := e } then + go ids + else + [opId] + m.insert { expr := e } ids +where + go : List Nat → List Nat + | [] => [opId] + | id::ids => if opId < id then + opId :: id :: ids + else if opId == id then + opId :: ids + else + id :: go ids + +def addTermOpId (e : Expr) : ACM Unit := do + let opId ← getOpId + modify' fun s => { s with exprToOpIds := insertOpId s.exprToOpIds e opId } + +def mkVar (e : Expr) : ACM AC.Var := do + let s ← getStruct + if let some var := s.varMap.find? { expr := e } then + return var + let var : AC.Var := s.vars.size + modifyStruct fun s => { s with + vars := s.vars.push e + varMap := s.varMap.insert { expr := e } var + } + addTermOpId e + markAsACTerm e + return var + def getOpId? (op : Expr) : GoalM (Option Nat) := do if let some id? := (← get').opIdOf.find? { expr := op } then return id? @@ -98,7 +138,7 @@ where let idempotentInst? ← synthInstance? idempotentType let (neutralInst?, neutral?) ← do let neutral ← mkFreshExprMVar α - let identityType := mkApp3 (mkConst ``Std.Identity [u]) α op neutral + let identityType := mkApp3 (mkConst ``Std.LawfulIdentity [u]) α op neutral if let some identityInst ← synthInstance? identityType then let neutral ← instantiateExprMVars neutral let neutral ← preprocessLight neutral @@ -112,8 +152,24 @@ where id, type := α, u, op, neutral?, assocInst, commInst?, idempotentInst?, neutralInst? }} - -- TODO: neutral element must be variable 0 trace[grind.debug.ac.op] "{op}, comm: {commInst?.isSome}, idempotent: {idempotentInst?.isSome}, neutral?: {neutral?}" + if let some neutral := neutral? then ACM.run id do + -- Create neutral variable to ensure it is variable 0 + discard <| mkVar neutral return some id +def isOp? (e : Expr) : ACM (Option (Expr × Expr)) := do + unless e.isApp && e.appFn!.isApp do return none + unless isSameExpr e.appFn!.appFn! (← getOp) do return none + return some (e.appFn!.appArg!, e.appArg!) + +def isCommutative : ACM Bool := + return (← getStruct).commInst?.isSome + +def hasNeutral : ACM Bool := + return (← getStruct).neutralInst?.isSome + +def isIdempotent : ACM Bool := + return (← getStruct).idempotentInst?.isSome + end Lean.Meta.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind/AC/VarRename.lean b/src/Lean/Meta/Tactic/Grind/AC/VarRename.lean new file mode 100644 index 0000000000..da16f07af4 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/AC/VarRename.lean @@ -0,0 +1,33 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Init.Grind.AC +public import Lean.Meta.Tactic.Grind.VarRename +namespace Lean.Grind.AC +open Lean.Meta.Grind + +public def Seq.renameVars (s : Seq) (f : VarRename) : Seq := + match s with + | .var x => .var (f x) + | .cons x s => .cons (f x) (renameVars s f) + +public def Expr.renameVars (e : Expr) (f : VarRename) : Expr := + match e with + | .var x => .var (f x) + | .op a b => .op (renameVars a f) (renameVars b f) + +public def Seq.collectVars (s : Seq) : VarCollector := + match s with + | .var x => collectVar x + | .cons x s => collectVar x >> s.collectVars + +public def Expr.collectVars (e : Expr) : VarCollector := + match e with + | .var x => collectVar x + | .op a b => a.collectVars >> b.collectVars + +end Lean.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind/Arith.lean b/src/Lean/Meta/Tactic/Grind/Arith.lean index cd53e9321b..265141f7a0 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith.lean @@ -7,7 +7,6 @@ module prelude public import Lean.Meta.Tactic.Grind.Arith.Util -public import Lean.Meta.Tactic.Grind.Arith.ProofUtil public import Lean.Meta.Tactic.Grind.Arith.Types public import Lean.Meta.Tactic.Grind.Arith.Main public import Lean.Meta.Tactic.Grind.Arith.Offset @@ -15,7 +14,6 @@ public import Lean.Meta.Tactic.Grind.Arith.Cutsat public import Lean.Meta.Tactic.Grind.Arith.CommRing public import Lean.Meta.Tactic.Grind.Arith.Linear public import Lean.Meta.Tactic.Grind.Arith.Simproc -public import Lean.Meta.Tactic.Grind.Arith.VarRename public import Lean.Meta.Tactic.Grind.Arith.Internalize public section diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing.lean index 86ceaf42fe..ec878454e6 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/CommRing.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing.lean @@ -4,7 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ module - prelude public import Lean.Util.Trace public import Lean.Meta.Tactic.Grind.Arith.CommRing.Poly diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Proof.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Proof.lean index aff198b2ba..a7747f0fb4 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Proof.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/Proof.lean @@ -4,22 +4,19 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ module - prelude public import Init.Grind.Ring.OfSemiring public import Lean.Meta.Tactic.Grind.Diseq -public import Lean.Meta.Tactic.Grind.Arith.ProofUtil +public import Lean.Meta.Tactic.Grind.ProofUtil public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingId public import Lean.Meta.Tactic.Grind.Arith.CommRing.DenoteExpr public import Lean.Meta.Tactic.Grind.Arith.CommRing.SafePoly public import Lean.Meta.Tactic.Grind.Arith.CommRing.ToExpr public import Lean.Meta.Tactic.Grind.Arith.CommRing.SemiringM -public import Lean.Meta.Tactic.Grind.Arith.VarRename +public import Lean.Meta.Tactic.Grind.VarRename import Lean.Meta.Tactic.Grind.Arith.CommRing.VarRename import Lean.Meta.Tactic.Grind.Arith.CommRing.Functions - public section - namespace Lean.Meta.Grind.Arith.CommRing /-- Returns a context of type `RArray α` containing the variables `vars` where diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/ToExpr.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/ToExpr.lean index 23bbb9f7a1..ad918d490e 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/ToExpr.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/ToExpr.lean @@ -4,14 +4,11 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ module - prelude public import Init.Grind.Ring.Poly public import Init.Grind.Ring.OfSemiring public import Lean.ToExpr - public section - namespace Lean.Meta.Grind.Arith.CommRing open Grind.CommRing /-! diff --git a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/VarRename.lean b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/VarRename.lean index 7f15cc7e86..d1b7e3b0b1 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/CommRing/VarRename.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/CommRing/VarRename.lean @@ -4,14 +4,12 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ module - prelude public import Init.Grind.Ring.Poly public import Init.Grind.Ring.OfSemiring -public import Lean.Meta.Tactic.Grind.Arith.VarRename - +public import Lean.Meta.Tactic.Grind.VarRename namespace Lean.Grind.CommRing -open Lean.Meta.Grind.Arith +open Lean.Meta.Grind public def Power.renameVars (pw : Power) (f : VarRename) : Power := { pw with x := (f pw.x) } @@ -59,7 +57,7 @@ public def Expr.collectVars (e : Expr) : VarCollector := end Lean.Grind.CommRing namespace Lean.Grind.Ring.OfSemiring -open Lean.Meta.Grind.Arith +open Lean.Meta.Grind public def Expr.renameVars (e : Expr) (f : VarRename) : Expr := match e with diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean index 2cb8ee42e3..c03ef11ae3 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean @@ -8,8 +8,8 @@ prelude public import Init.Grind.Ring.Poly public import Lean.Meta.Tactic.Grind.Types import Lean.Meta.Tactic.Grind.Diseq -import Lean.Meta.Tactic.Grind.Arith.ProofUtil -import Lean.Meta.Tactic.Grind.Arith.VarRename +import Lean.Meta.Tactic.Grind.ProofUtil +import Lean.Meta.Tactic.Grind.VarRename import Lean.Meta.Tactic.Grind.Arith.Cutsat.CommRing import Lean.Meta.Tactic.Grind.Arith.Cutsat.Util import Lean.Meta.Tactic.Grind.Arith.Cutsat.Nat diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/VarRename.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/VarRename.lean index f65b4018b3..4b70194792 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/VarRename.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/VarRename.lean @@ -4,13 +4,11 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ module - prelude public import Init.Data.Int.Linear -public import Lean.Meta.Tactic.Grind.Arith.VarRename - +public import Lean.Meta.Tactic.Grind.VarRename namespace Int.Linear -open Lean.Meta.Grind.Arith +open Lean.Meta.Grind public def Poly.renameVars (p : Poly) (f : VarRename) : Poly := match p with diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Linear/Proof.lean b/src/Lean/Meta/Tactic/Grind/Arith/Linear/Proof.lean index c0c617ae9a..40bd165cf7 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Linear/Proof.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Linear/Proof.lean @@ -9,11 +9,11 @@ public import Lean.Meta.Tactic.Grind.Arith.Linear.Util import Lean.Meta.Tactic.Grind.Arith.Util import Lean.Meta.Tactic.Grind.Arith.Linear.ToExpr import Lean.Meta.Tactic.Grind.Arith.Linear.DenoteExpr -import Lean.Meta.Tactic.Grind.Arith.VarRename +import Lean.Meta.Tactic.Grind.VarRename import Lean.Meta.Tactic.Grind.Diseq +import Lean.Meta.Tactic.Grind.ProofUtil import Lean.Meta.Tactic.Grind.Arith.CommRing.VarRename import Lean.Meta.Tactic.Grind.Arith.CommRing.ToExpr -import Lean.Meta.Tactic.Grind.Arith.ProofUtil import Lean.Meta.Tactic.Grind.Arith.Linear.VarRename import Lean.Meta.Tactic.Grind.Arith.CommRing.VarRename diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Linear/VarRename.lean b/src/Lean/Meta/Tactic/Grind/Arith/Linear/VarRename.lean index 6504e20ebd..d264927894 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Linear/VarRename.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Linear/VarRename.lean @@ -7,10 +7,10 @@ module prelude public import Init.Grind.Ordered.Linarith -public import Lean.Meta.Tactic.Grind.Arith.VarRename +public import Lean.Meta.Tactic.Grind.VarRename namespace Lean.Grind.Linarith -open Lean.Meta.Grind.Arith +open Lean.Meta.Grind public def Poly.renameVars (p : Poly) (f : VarRename) : Poly := match p with diff --git a/src/Lean/Meta/Tactic/Grind/Core.lean b/src/Lean/Meta/Tactic/Grind/Core.lean index 7fb0326d66..284aaefb08 100644 --- a/src/Lean/Meta/Tactic/Grind/Core.lean +++ b/src/Lean/Meta/Tactic/Grind/Core.lean @@ -215,6 +215,31 @@ def propagateLinarith : PendingTheoryPropagation → GoalM Unit | .diseqs ps => propagateLinarithDiseqs ps | _ => return () +/-- +Helper function for combining `ENode.ac?` fields and detecting what needs to be +propagated to the ac module. +-/ +private def checkACEq (rhsRoot lhsRoot : ENode) : GoalM PendingTheoryPropagation := do + match lhsRoot.ac? with + | some lhs => + if let some rhs := rhsRoot.ac? then + return .eq lhs rhs + else + -- We have to retrieve the node because other fields have been updated + let rhsRoot ← getENode rhsRoot.self + setENode rhsRoot.self { rhsRoot with ac? := lhs } + return .diseqs (← getParents rhsRoot.self) + | none => + if rhsRoot.ac?.isSome then + return .diseqs (← getParents lhsRoot.self) + else + return .none + +def propagateAC : PendingTheoryPropagation → GoalM Unit + | .eq lhs rhs => AC.processNewEq lhs rhs + | .diseqs ps => propagateACDiseqs ps + | _ => return () + /-- Tries to apply beta-reduction using the parent applications of the functions in `fns` with the lambda expressions in `lams`. @@ -349,6 +374,7 @@ where let cutsatTodo ← checkCutsatEq rhsRoot lhsRoot let ringTodo ← checkCommRingEq rhsRoot lhsRoot let linarithTodo ← checkLinarithEq rhsRoot lhsRoot + let ACTodo ← checkACEq rhsRoot lhsRoot resetParentsOf lhsRoot.self copyParentsTo parents rhsNode.root unless (← isInconsistent) do @@ -363,6 +389,7 @@ where propagateCutsat cutsatTodo propagateCommRing ringTodo propagateLinarith linarithTodo + propagateAC ACTodo updateRoots (lhs : Expr) (rootNew : Expr) : GoalM Unit := do let isFalseRoot ← isFalseExpr rootNew traverseEqc lhs fun n => do diff --git a/src/Lean/Meta/Tactic/Grind/Arith/ProofUtil.lean b/src/Lean/Meta/Tactic/Grind/ProofUtil.lean similarity index 91% rename from src/Lean/Meta/Tactic/Grind/Arith/ProofUtil.lean rename to src/Lean/Meta/Tactic/Grind/ProofUtil.lean index b07a792038..96fe0a92e9 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/ProofUtil.lean +++ b/src/Lean/Meta/Tactic/Grind/ProofUtil.lean @@ -4,13 +4,10 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ module - prelude public import Lean.Meta.Tactic.Grind.Types - public section - -namespace Lean.Meta.Grind.Arith +namespace Lean.Meta.Grind def mkLetOfMap {_ : Hashable α} {_ : BEq α} (m : Std.HashMap α Expr) (e : Expr) (varPrefix : Name) (varType : Expr) (toExpr : α → Expr) : Expr := Id.run do @@ -25,4 +22,4 @@ def mkLetOfMap {_ : Hashable α} {_ : BEq α} (m : Std.HashMap α Expr) (e : Exp i := i - 1 return e -end Lean.Meta.Grind.Arith +end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Propagate.lean b/src/Lean/Meta/Tactic/Grind/Propagate.lean index f1030983eb..35d3a0b7fc 100644 --- a/src/Lean/Meta/Tactic/Grind/Propagate.lean +++ b/src/Lean/Meta/Tactic/Grind/Propagate.lean @@ -186,6 +186,7 @@ builtin_grind_propagator propagateEqDown ↓Eq := fun e => do propagateCutsatDiseq lhs rhs propagateCommRingDiseq lhs rhs propagateLinarithDiseq lhs rhs + propagateACDiseq lhs rhs let thms ← getExtTheorems α if !thms.isEmpty then /- diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index 7c99f2ef94..af19e68dc2 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -440,7 +440,12 @@ structure ENode where to the linarith module. Its implementation is similar to the `offset?` field. -/ linarith? : Option Expr := none - -- Remark: we expect to have builtin support for offset constraints, cutsat, comm ring, and linarith. + /-- + The `ac?` field is used to propagate equalities from the `grind` congruence closure module + to the ac module. Its implementation is similar to the `offset?` field. + -/ + ac? : Option Expr := none + -- Remark: we expect to have builtin support for offset constraints, cutsat, comm ring, linarith, and ac. -- If the number of satellite solvers increases, we may add support for an arbitrary solvers like done in Z3. deriving Inhabited, Repr @@ -1283,6 +1288,53 @@ def markAsLinarithTerm (e : Expr) : GoalM Unit := do setENode root.self { root with linarith? := some e } propagateLinarithDiseqs (← getParents root.self) +/-- +Notifies the ac module that `a = b` where +`a` and `b` are terms that have been internalized by this module. +-/ +@[extern "lean_process_ac_eq"] -- forward definition +opaque AC.processNewEq (a b : Expr) : GoalM Unit + +/-- +Notifies the ac module that `a ≠ b` where +`a` and `b` are terms that have been internalized by this module. +-/ +@[extern "lean_process_ac_diseq"] -- forward definition +opaque AC.processNewDiseq (a b : Expr) : GoalM Unit + +/-- +Given `lhs` and `rhs` that are known to be disequal, checks whether +`lhs` and `rhs` have ac terms `e₁` and `e₂` attached to them, +and invokes process `AC.processNewDiseq e₁ e₂` +-/ +def propagateACDiseq (lhs rhs : Expr) : GoalM Unit := do + let some lhs ← get? lhs | return () + let some rhs ← get? rhs | return () + AC.processNewDiseq lhs rhs +where + get? (a : Expr) : GoalM (Option Expr) := do + return (← getRootENode a).ac? + +/-- +Traverses disequalities in `parents`, and propagate the ones relevant to the +ac module. +-/ +def propagateACDiseqs (parents : ParentSet) : GoalM Unit := do + forEachDiseq parents propagateACDiseq + +/-- +Marks `e` as a term of interest to the ac module. +If the root of `e`s equivalence class has already a term of interest, +a new equality is propagated to the ac module. +-/ +def markAsACTerm (e : Expr) : GoalM Unit := do + let root ← getRootENode e + if let some e' := root.ac? then + AC.processNewEq e e' + else + setENode root.self { root with ac? := some e } + propagateACDiseqs (← getParents root.self) + /-- Returns `true` is `e` is the root of its congruence class. -/ def isCongrRoot (e : Expr) : GoalM Bool := do return (← getENode e).isCongrRoot diff --git a/src/Lean/Meta/Tactic/Grind/Arith/VarRename.lean b/src/Lean/Meta/Tactic/Grind/VarRename.lean similarity index 95% rename from src/Lean/Meta/Tactic/Grind/Arith/VarRename.lean rename to src/Lean/Meta/Tactic/Grind/VarRename.lean index 5dbd3d7336..342f6be835 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/VarRename.lean +++ b/src/Lean/Meta/Tactic/Grind/VarRename.lean @@ -4,13 +4,12 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ module - prelude public import Init.Prelude public import Init.Data.Array.QSort public import Std.Data.HashSet public section -namespace Lean.Meta.Grind.Arith +namespace Lean.Meta.Grind abbrev Var : Type := Nat abbrev FoundVars := Std.HashSet Nat @@ -45,4 +44,4 @@ def mkVarRename (new2old : Array Var) : VarRename := Id.run do new := new + 1 { map := old2new } -end Lean.Meta.Grind.Arith +end Lean.Meta.Grind diff --git a/src/out b/src/out new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/lean/run/grind_ac_1.lean b/tests/lean/run/grind_ac_1.lean new file mode 100644 index 0000000000..68f0cb9e9b --- /dev/null +++ b/tests/lean/run/grind_ac_1.lean @@ -0,0 +1,81 @@ +open Lean Grind AC +example {α : Type u} (op : α → α → α) [Std.Associative op] (a b c : α) + : op a (op b c) = op (op a b) c := by + grind only + +example {α : Sort u} (op : α → α → α) [Std.Associative op] (a b c : α) + : op a (op b c) = op (op a b) c := by + grind only + +example {α : Sort u} (op : α → α → α) (u : α) [Std.Associative op] [Std.LawfulIdentity op u] (a b c : α) + : op a (op b c) = op (op a b) (op c u) := by + grind only + +example {α : Sort u} (op : α → α → α) [Std.Associative op] [Std.Commutative op] (a b c : α) + : op c (op b a) = op (op b a) c := by + grind only + +example {α : Sort u} (op : α → α → α) (u : α) [Std.Associative op] [Std.Commutative op] [Std.LawfulIdentity op u] (a b c : α) + : op a (op b c) = op (op b a) c := by + grind only + +example {α : Sort u} (op : α → α → α) (u : α) [Std.Associative op] [Std.Commutative op] [Std.LawfulIdentity op u] (a b c : α) + : op a (op b (op u c)) = op (op b a) (op u c) := by + grind only + +example {α : Sort u} (op : α → α → α) [Std.Associative op] [Std.IdempotentOp op] (a b c : α) + : op (op a a) (op b c) = op (op a (op b b)) c := by + grind only + +example {α : Sort u} (op : α → α → α) [Std.Associative op] [Std.Commutative op] [Std.IdempotentOp op] (a b c : α) + : op (op a a) (op b c) = op (op (op b a) (op b b)) c := by + grind only + +example {α : Sort u} (op : α → α → α) (u : α) [Std.Associative op] [Std.Commutative op] [Std.IdempotentOp op] [Std.LawfulIdentity op u] (a b c : α) + : op (op a a) (op b c) = op (op (op b a) (op (op u b) b)) c := by + grind only + +example {α : Type u} (op : α → α → α) [Std.Associative op] [Std.Commutative op] [Std.IdempotentOp op] (a b c : α) + : op (op a a) (op b c) = op (op (op b a) (op b b)) c := by + grind only + +example {α : Type u} (op : α → α → α) (u : α) [Std.Associative op] [Std.Commutative op] [Std.IdempotentOp op] [Std.LawfulIdentity op u] (a b c : α) + : op (op a a) (op b c) = op (op (op b a) (op (op u b) b)) c := by + grind only + +example {α} (as bs cs : List α) : as ++ (bs ++ cs) = ((as ++ []) ++ bs) ++ (cs ++ []) := by + grind only + +example (a b c : Nat) : max a (max b c) = max (max b 0) (max a c) := by + grind only + +/-- +trace: [grind.debug.proof] Classical.byContradiction + (intro_with_eq (¬(max a (max b c) = max (max b 0) (max a c) ∧ min a b = min b a)) + (¬max a (max b c) = max (max b 0) (max a c) ∨ ¬min a b = min b a) False + (Grind.not_and (max a (max b c) = max (max b 0) (max a c)) (min a b = min b a)) fun h => + Or.casesOn h + (fun h_1 => + let ctx := + Context.mk + (RArray.branch 2 (RArray.branch 1 (RArray.leaf (PLift.up 0)) (RArray.leaf (PLift.up a))) + (RArray.branch 3 (RArray.leaf (PLift.up b)) (RArray.leaf (PLift.up c)))) + max; + let e_1 := ((Expr.var 2).op (Expr.var 0)).op ((Expr.var 1).op (Expr.var 3)); + let e_2 := (Expr.var 1).op ((Expr.var 2).op (Expr.var 3)); + let p_1 := Seq.cons 1 (Seq.cons 2 (Seq.var 3)); + diseq_unsat ctx p_1 p_1 (eagerReduce (Eq.refl true)) + (diseq_norm_aci ctx e_2 e_1 p_1 p_1 (eagerReduce (Eq.refl true)) h_1)) + fun h_1 => + let ctx := Context.mk (RArray.branch 1 (RArray.leaf (PLift.up a)) (RArray.leaf (PLift.up b))) min; + let e_1 := (Expr.var 1).op (Expr.var 0); + let e_2 := (Expr.var 0).op (Expr.var 1); + let p_1 := Seq.cons 0 (Seq.var 1); + diseq_unsat ctx p_1 p_1 (eagerReduce (Eq.refl true)) + (diseq_norm_ac ctx e_2 e_1 p_1 p_1 (eagerReduce (Eq.refl true)) h_1)) +-/ +#guard_msgs in +set_option pp.structureInstances false in +set_option trace.grind.debug.proof true in +example (a b c : Nat) : max a (max b c) = max (max b 0) (max a c) ∧ min a b = min b a := by + grind only [cases Or]