Compare commits

...

10 Commits

Author SHA1 Message Date
Leonardo de Moura
2fd8e72625 test: grind funCC attribute 2025-11-22 20:46:18 -08:00
Leonardo de Moura
ce8afcfb4a feat: grind funCC attribute 2025-11-22 20:46:02 -08:00
Leonardo de Moura
86d8d9fb6d test: grind +funCC 2025-11-22 20:18:23 -08:00
Leonardo de Moura
4aed9f26be refactor: fo ==> funCC 2025-11-22 20:16:21 -08:00
Leonardo de Moura
4a1bab9f04 chore: add funCC option 2025-11-22 20:00:11 -08:00
Leonardo de Moura
499626c8c6 chore: fix test 2025-11-22 18:48:27 -08:00
Leonardo de Moura
273e7f910e fix: isPropagateBetaTarget 2025-11-22 18:48:15 -08:00
Leonardo de Moura
43333dd777 chore: add note 2025-11-22 18:47:52 -08:00
Leonardo de Moura
a61fe3ba7a chore: checkParents 2025-11-22 18:45:06 -08:00
Leonardo de Moura
7bd6c9daad feat: improve higher-order function support in grind 2025-11-22 18:22:19 -08:00
17 changed files with 414 additions and 108 deletions

View File

@@ -188,6 +188,16 @@ where the term `f` contains at least one constant symbol.
-/
syntax grindInj := &"inj"
/--
The `funCC` modifier marks global functions that support **function-valued congruence closure**.
Given an application `f a₁ a₂ … aₙ`, when `funCC := true`,
`grind` generates and tracks equalities for all partial applications:
- `f a₁`
- `f a₁ a₂`
- `…`
- `f a₁ a₂ … aₙ`
-/
syntax grindFunCC := &"funCC"
/--
`symbol <prio>` sets the priority of a constant for `grind`s pattern-selection
procedure. `grind` prefers patterns that contain higher-priority symbols.
Example:
@@ -214,7 +224,7 @@ syntax grindMod :=
grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd
<|> grindFwd <|> grindRL <|> grindLR <|> grindUsr <|> grindCasesEager
<|> grindCases <|> grindIntro <|> grindExt <|> grindGen <|> grindSym <|> grindInj
<|> grindDef
<|> grindFunCC <|> grindDef
/--
Marks a theorem or definition for use by the `grind` tactic.

View File

@@ -172,6 +172,36 @@ structure Config where
and then reintroduces them while simplifying and applying eager `cases`.
-/
revert := false
/--
When `true`, it enables **function-valued congruence closure**.
`grind` treats equalities of partially applied functions as first-class equalities
and propagates them through further applications.
Given an application `f a₁ a₂ … aₙ`, when `funCC := true` *and* function equality is enabled for `f`,
`grind` generates and tracks equalities for all partial applications:
- `f a₁`
- `f a₁ a₂`
- `…`
- `f a₁ a₂ … aₙ`
This allows equalities such as `f a₁ = g` to propagate to
`f a₁ a₂ = g a₂`.
**When is function equality enabled for a symbol?**
Function equality is automatically enabled in the following cases:
1. **`f` is not a constant.** (For example, a lambda expression, a local variable, or a function parameter.)
2. **`f` is a structure field projection**, *provided the structure is not a `class`.*
3. **`f` is a constant marked with the attribute:** `@[grind funCC]`
If none of the above conditions apply, function equality is disabled for `f`, and congruence
closure behaves almost like it does in SMT solvers for first-order logic.
Here is an example, `grind` can solve when `funCC := true`
```
example (a b : Nat) (g : Nat → Nat) (f : Nat → Nat → Nat) (h : f a = g) :
f a b = g b := by
grind
```
-/
funCC := true
deriving Inhabited, BEq
/--

View File

@@ -239,7 +239,7 @@ where
elabEMatchTheorem declName (.default false) minIndexable
else
return thms.toArray
| .cases _ | .intro | .inj | .ext | .symbol _ =>
| .cases _ | .intro | .inj | .ext | .symbol _ | .funCC =>
throwError "invalid modifier"
def logAnchor (c : SplitInfo) : TermElabM Unit := do

View File

@@ -161,7 +161,8 @@ def mkGrindParams
this is not very effective. We now use anchors to restrict the set of case-splits.
-/
let casesTypes Grind.getCasesTypes
let params := { params with ematch, casesTypes, inj }
let funCCs Grind.getFunCCSet
let params := { params with ematch, casesTypes, inj, funCCs }
let suggestions if config.suggestions then
LibrarySuggestions.select mvarId
else

View File

@@ -141,6 +141,8 @@ def processParam (params : Grind.Params)
| .symbol prio =>
ensureNoMinIndexable minIndexable
params := { params with symPrios := params.symPrios.insert declName prio }
| .funCC =>
params := { params with funCCs := params.funCCs.insert declName }
return params
def processAnchor (params : Grind.Params) (val : TSyntax `hexnum) : CoreM Grind.Params := do
@@ -161,7 +163,7 @@ def processTermParam (params : Grind.Params)
checkNoRevert params
let kind if let some mod := mod? then Grind.getAttrKindCore mod else pure .infer
let kind match kind with
| .ematch .user | .cases _ | .intro | .inj | .ext | .symbol _ =>
| .ematch .user | .cases _ | .intro | .inj | .ext | .symbol _ | .funCC =>
throwError "invalid `grind` parameter, only global declarations are allowed with this kind of modifier"
| .ematch kind => pure kind
| .infer => pure <| .default false

View File

@@ -47,6 +47,7 @@ public import Lean.Meta.Tactic.Grind.EMatchAction
public import Lean.Meta.Tactic.Grind.Filter
public import Lean.Meta.Tactic.Grind.CollectParams
public import Lean.Meta.Tactic.Grind.Finish
public import Lean.Meta.Tactic.Grind.FunCC
public section
namespace Lean

View File

@@ -8,8 +8,8 @@ prelude
public import Lean.Meta.Tactic.Grind.Injective
public import Lean.Meta.Tactic.Grind.Cases
public import Lean.Meta.Tactic.Grind.ExtAttr
public import Lean.Meta.Tactic.Grind.FunCC
import Lean.ExtraModUses
public section
namespace Lean.Meta.Grind
@@ -21,6 +21,7 @@ inductive AttrKind where
| ext
| symbol (prio : Nat)
| inj
| funCC
/-- Return theorem kind for `stx` of the form `Attr.grindThmMod` -/
def getAttrKindCore (stx : Syntax) : CoreM AttrKind := do
@@ -46,6 +47,7 @@ def getAttrKindCore (stx : Syntax) : CoreM AttrKind := do
| `(Parser.Attr.grindMod|intro) => return .intro
| `(Parser.Attr.grindMod|ext) => return .ext
| `(Parser.Attr.grindMod|inj) => return .inj
| `(Parser.Attr.grindMod|funCC) => return .funCC
| `(Parser.Attr.grindMod|symbol $prio:prio) =>
let some prio := prio.raw.isNatLit? | throwErrorAt prio "priority expected"
return .symbol prio
@@ -126,6 +128,7 @@ private def registerGrindAttr (minIndexable : Bool) (showInfo : Bool) : IO Unit
addEMatchAttrAndSuggest stx declName attrKind ( getGlobalSymbolPriorities) (minIndexable := minIndexable) (showInfo := showInfo)
| .symbol prio => addSymbolPriorityAttr declName attrKind prio
| .inj => addInjectiveAttr declName attrKind
| .funCC => addFunCCAttr declName attrKind
erase := fun declName => MetaM.run' do
if showInfo then
throwError "`[grind?]` is a helper attribute for displaying inferred patterns, if you want to remove the attribute, consider using `[grind]` instead"
@@ -135,6 +138,8 @@ private def registerGrindAttr (minIndexable : Bool) (showInfo : Bool) : IO Unit
eraseExtAttr declName
else if ( isInjectiveTheorem declName) then
eraseInjectiveAttr declName
else if ( hasFunCCAttr declName) then
eraseFunCCAttr declName
else
eraseEMatchAttr declName
}

View File

@@ -54,7 +54,8 @@ private def isPropagateBetaTarget (e : Expr) : GoalM Bool := do
where
go (f : Expr) : GoalM Bool := do
if let some root getRootENode? f then
return root.hasLambdas
if root.hasLambdas then
return true
let .app f _ := f | return false
go f

View File

@@ -114,7 +114,11 @@ def propagateBeta (lams : Array Expr) (fns : Array Expr) : GoalM Unit := do
propagateBetaEqs lams curr args.reverse
let .app f arg := curr
| break
-- Remark: recall that we do not eagerly internalize partial applications.
/-
**Note**: Recall that we do not eagerly internalize all partial applications.
We can add a small optimization here. If `useFO parent` is `false`, then
we know `curr` has been internalized
-/
internalize curr ( getGeneration parent)
args := args.push arg
curr := f

View File

@@ -0,0 +1,34 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.ScopedEnvExtension
public section
namespace Lean.Meta.Grind
private builtin_initialize funCCExt : SimpleScopedEnvExtension Name NameSet
registerSimpleScopedEnvExtension {
initial := {}
addEntry := fun s declName => s.insert declName
}
def getFunCCSet : CoreM NameSet :=
return funCCExt.getState ( getEnv)
def hasFunCCAttr (declName : Name) : CoreM Bool := do
return ( getFunCCSet).contains declName
def addFunCCAttr (declName : Name) (attrKind : AttributeKind) : CoreM Unit := do
funCCExt.add declName attrKind
def eraseFunCCAttr (declName : Name) : CoreM Unit := do
let s getFunCCSet
unless s.contains declName do
throwError "`{.ofConstName declName}` is not marked with the `[grind]` attribute"
let s := s.erase declName
modifyEnv fun env => funCCExt.modifyState env fun _ => s
end Lean.Meta.Grind

View File

@@ -16,21 +16,42 @@ import Lean.Meta.Tactic.Grind.Simp
import Lean.Meta.Tactic.Grind.Proof
import Lean.Meta.Tactic.Grind.MarkNestedSubsingletons
import Lean.Meta.Tactic.Grind.PropagateInj
import Lean.Meta.Tactic.Grind.FunCC
public section
namespace Lean.Meta.Grind
/--
Returns `true` if we can generate a congruence proof for `e₁ = e₂`.
See paper: Congruence Closure in Intensional Type Theory for additional details.
-/
private def isCongruentCheck (e₁ e₂ : Expr) : GoalM Bool := do
if ( useFunCC e₁) then
go e₁ e₂
else
/- Using first-order approximation. -/
let f := e₁.getAppFn
let g := e₂.getAppFn
if isSameExpr f g then return true
hasSameType f g
where
go (e₁ e₂ : Expr) : GoalM Bool := do
let .app f _ := e₁ | return false
let .app g _ := e₂ | return false
if isSameExpr f g then return true
if ( hasSameType f g) then return true
go f g
/-- Adds `e` to congruence table. -/
def addCongrTable (e : Expr) : GoalM Unit := do
if let some { e := e' } := ( get).congrTable.find? { e } then
-- `f` and `g` must have the same type.
-- See paper: Congruence Closure in Intensional Type Theory
/-
See paper: Congruence Closure in Intensional Type Theory
**Note**: We do **not** implement the expensive quadratic case used in the paper.
-/
if e.isApp then
let f := e.getAppFn
let g := e'.getAppFn
unless isSameExpr f g do
unless ( hasSameType f g) do
reportIssue! "found congruence between{indentExpr e}\nand{indentExpr e'}\nbut functions have different types"
return ()
unless ( isCongruentCheck e e') do
reportIssue! "found congruence between{indentExpr e}\nand{indentExpr e'}\nbut functions have different types"
return ()
trace_goal[grind.debug.congr] "{e} = {e'}"
if ( isEqCongrSymm e e') then
-- **Note**: See comment at `eqCongrSymmPlaceholderProof`
@@ -154,8 +175,8 @@ private def pushCastHEqs (e : Expr) : GoalM Unit := do
| f@Eq.recOn α a motive b h v => pushHEq e v (mkApp6 (mkConst ``Grind.eqRecOn_heq f.constLevels!) α a motive b h v)
| _ => return ()
private def mkENode' (e : Expr) (generation : Nat) : GoalM Unit :=
mkENodeCore e (ctor := false) (interpreted := false) (generation := generation)
private def mkENode' (e : Expr) (generation : Nat) (funCC := false) : GoalM Unit :=
mkENodeCore e (ctor := false) (interpreted := false) (generation := generation) (funCC := funCC)
/-- Internalizes the nested ground terms in the given pattern. -/
private partial def internalizePattern (pattern : Expr) (generation : Nat) (origin : Origin) : GoalM Expr := do
@@ -189,7 +210,7 @@ where
/-- Internalizes the `MatchCond` gadget. -/
private def internalizeMatchCond (matchCond : Expr) (generation : Nat) : GoalM Unit := do
mkENode' matchCond generation
mkENode' matchCond generation (funCC := false)
let (lhss, e') collectMatchCondLhssAndAbstract matchCond
lhss.forM fun lhs => do internalize lhs generation; registerParent matchCond lhs
propagateUp matchCond
@@ -413,6 +434,35 @@ private def tryEta (e : Expr) (generation : Nat) : GoalM Unit := do
internalize e' generation
pushEq e e' ( mkEqRefl e)
/--
Returns `true` if we should use `funCC` for applications of the given constant symbol.
-/
private def useFunCongrAtDecl (declName : Name) : GrindM Bool := do
if ( readThe Grind.Context).funCCs.contains declName then
return true
if ( isInstance declName) then
/- **Note**: Instances are support elements. No `funCC` -/
return false
if let some projInfo getProjectionFnInfo? declName then
if projInfo.fromClass then
/- **Note**: Field of a class are treated as support elements. No `funCC`. -/
return false
/- **Note**: Check the type of the field. If it is a function type, use `funCC` -/
let declInfo getConstInfo declName
let isFn forallBoundedTelescope declInfo.type (some (projInfo.numParams + 1)) fun _ type => do
let type whnf type
return type.isForall
return isFn
return false
/--
Returns `true` if we should use `funCC` for `f`-applications.
-/
private def useFunCongrAtFn (f : Expr) : GrindM Bool := do
unless ( getConfig).funCC do return false
let .const declName _ := f | return true
useFunCongrAtDecl declName
@[export lean_grind_internalize]
private partial def internalizeImpl (e : Expr) (generation : Nat) (parent? : Option Expr := none) : GoalM Unit := withIncRecDepth do
if ( alreadyInternalized e) then
@@ -484,7 +534,8 @@ where
else if e.isAppOfArity ``Grind.MatchCond 1 then
internalizeMatchCond e generation
else e.withApp fun f args => do
mkENode e generation
let funCC useFunCongrAtFn f
mkENode e generation (funCC := funCC)
updateAppMap e
checkAndAddSplitCandidate e
pushCastHEqs e
@@ -509,13 +560,30 @@ where
else
if let .const fName _ := f then
activateTheorems fName generation
if funCC then
internalizeImpl f generation e
else
internalizeImpl f generation e
registerParent e f
for h : i in *...args.size do
let arg := args[i]
internalize arg generation e
registerParent e arg
if funCC then
let rec traverse (curr : Expr) : GoalM Unit := do
let .app f a := curr | return ()
mkENode curr generation (funCC := true)
internalizeImpl a generation e
traverse f
registerParent curr a
registerParent curr f
addCongrTable curr
let .app curr a := e | unreachable!
internalizeImpl a generation e
traverse curr
registerParent e a
registerParent e curr
else
for h : i in *...args.size do
let arg := args[i]
internalizeImpl arg generation e
registerParent e arg
addCongrTable e
Solvers.internalize e parent?
propagateUp e

View File

@@ -65,33 +65,35 @@ def checkMatchCondParent (e : Expr) (parent : Expr) : GoalM Bool := do
return false
def checkParents (e : Expr) : GoalM Unit := do
if ( isRoot e) then
for parent in ( getParents e).elems do
if isMatchCond parent then
unless ( checkMatchCondParent e parent) do
throwError "e: {e}, parent: {parent}"
assert! ( checkMatchCondParent e parent)
else
let mut found := false
-- There is an argument `arg` s.t. root of `arg` is `e`.
for arg in parent.getAppArgs do
if ( checkChild e arg) then
found := true
break
-- Recall that we have support for `Expr.forallE` propagation. See `ForallProp.lean`.
if let .forallE _ d b _ := parent then
if ( checkChild e d) then
found := true
unless b.hasLooseBVars do
if ( checkChild e b) then
found := true
unless found do
unless ( checkChild e parent.getAppFn) do
-- **Note**: We currently do not check the `funCC` case
unless ( useFunCC e) do
if ( isRoot e) then
for parent in ( getParents e).elems do
if isMatchCond parent then
unless ( checkMatchCondParent e parent) do
throwError "e: {e}, parent: {parent}"
assert! ( checkChild e parent.getAppFn)
else
-- All the parents are stored in the root of the equivalence class.
assert! ( getParents e).isEmpty
assert! ( checkMatchCondParent e parent)
else
let mut found := false
-- There is an argument `arg` s.t. root of `arg` is `e`.
for arg in parent.getAppArgs do
if ( checkChild e arg) then
found := true
break
-- Recall that we have support for `Expr.forallE` propagation. See `ForallProp.lean`.
if let .forallE _ d b _ := parent then
if ( checkChild e d) then
found := true
unless b.hasLooseBVars do
if ( checkChild e b) then
found := true
unless found do
unless ( checkChild e parent.getAppFn) do
throwError "e: {e}, parent: {parent}"
assert! ( checkChild e parent.getAppFn)
else
-- All the parents are stored in the root of the equivalence class.
assert! ( getParents e).isEmpty
def checkPtrEqImpliesStructEq : GoalM Unit := do
let exprs getExprs

View File

@@ -39,6 +39,7 @@ structure Params where
casesTypes : CasesTypes := {}
extra : PArray EMatchTheorem := {}
extraInj : PArray InjectiveTheorem := {}
funCCs : NameSet := {}
norm : Simp.Context
normProcs : Array Simprocs
anchorRefs? : Option (Array AnchorRef) := none
@@ -96,7 +97,8 @@ def GrindM.run (x : GrindM α) (params : Params) (evalTactic? : Option EvalTacti
let config := params.config
let symPrios := params.symPrios
let anchorRefs? := params.anchorRefs?
x ( mkMethods evalTactic?).toMethodsRef { config, anchorRefs?, simpMethods, simp, trueExpr, falseExpr, natZExpr, btrueExpr, bfalseExpr, ordEqExpr, intExpr, symPrios }
let funCCs := params.funCCs
x ( mkMethods evalTactic?).toMethodsRef { config, anchorRefs?, simpMethods, simp, trueExpr, falseExpr, natZExpr, btrueExpr, bfalseExpr, ordEqExpr, intExpr, symPrios, funCCs }
|>.run' { scState }
private def mkCleanState (mvarId : MVarId) (params : Params) : MetaM Clean.State := mvarId.withContext do
@@ -119,12 +121,12 @@ private def mkGoal (mvarId : MVarId) (params : Params) : GrindM Goal := do
let clean mkCleanState mvarId params
let sstates Solvers.mkInitialStates
GoalM.run' { mvarId, ematch.thmMap := thmMap, inj.thms := params.inj, split.casesTypes := casesTypes, clean, sstates } do
mkENodeCore falseExpr (interpreted := true) (ctor := false) (generation := 0)
mkENodeCore trueExpr (interpreted := true) (ctor := false) (generation := 0)
mkENodeCore btrueExpr (interpreted := false) (ctor := true) (generation := 0)
mkENodeCore bfalseExpr (interpreted := false) (ctor := true) (generation := 0)
mkENodeCore natZeroExpr (interpreted := true) (ctor := false) (generation := 0)
mkENodeCore ordEqExpr (interpreted := false) (ctor := true) (generation := 0)
mkENodeCore falseExpr (interpreted := true) (ctor := false) (generation := 0) (funCC := false)
mkENodeCore trueExpr (interpreted := true) (ctor := false) (generation := 0) (funCC := false)
mkENodeCore btrueExpr (interpreted := false) (ctor := true) (generation := 0) (funCC := false)
mkENodeCore bfalseExpr (interpreted := false) (ctor := true) (generation := 0) (funCC := false)
mkENodeCore natZeroExpr (interpreted := true) (ctor := false) (generation := 0) (funCC := false)
mkENodeCore ordEqExpr (interpreted := false) (ctor := true) (generation := 0) (funCC := false)
for thm in params.extra do
activateTheorem thm 0

View File

@@ -138,6 +138,26 @@ mutual
let h := mkApp5 (mkConst ``Lean.Grind.nestedDecidable_congr) p q ( mkEqProofCore p q false) hp hq
mkEqOfHEqIfNeeded h heq
partial def mkEqCongrProof (lhs rhs : Expr) : GoalM Expr := withIncRecDepth do
let_expr f@Eq α₁ a₁ b₁ := lhs | unreachable!
let_expr Eq α₂ a₂ b₂ := rhs | unreachable!
assert! ( get).hasSameRoot a₁ a₂ && ( get).hasSameRoot b₁ b₂
let us := f.constLevels!
if !isSameExpr α₁ α₂ then
return mkApp8 (mkConst ``Grind.heq_congr us) α₁ α₂ a₁ b₁ a₂ b₂ ( mkEqProofCore a₁ a₂ true) ( mkEqProofCore b₁ b₂ true)
else
return mkApp7 (mkConst ``Grind.eq_congr us) α₁ a₁ b₁ a₂ b₂ ( mkEqProofCore a₁ a₂ false) ( mkEqProofCore b₁ b₂ false)
partial def mkEqCongrSymmProof (lhs rhs : Expr) : GoalM Expr := withIncRecDepth do
let_expr f@Eq α₁ a₁ b₁ := lhs | unreachable!
let_expr Eq α₂ a₂ b₂ := rhs | unreachable!
assert! ( get).hasSameRoot a₁ b₂ && ( get).hasSameRoot b₁ a₂
let us := f.constLevels!
if !isSameExpr α₁ α₂ then
return mkApp8 (mkConst ``Grind.heq_congr' us) α₁ α₂ a₁ b₁ a₂ b₂ ( mkEqProofCore a₁ b₂ true) ( mkEqProofCore b₁ a₂ true)
else
return mkApp7 (mkConst ``Grind.eq_congr' us) α₁ a₁ b₁ a₂ b₂ ( mkEqProofCore a₁ b₂ false) ( mkEqProofCore b₁ a₂ false)
/--
Constructs a congruence proof for `lhs` and `rhs` using `congr`, `congrFun`, and `congrArg`.
This function assumes `isCongrDefaultProofTarget` returned `true`.
@@ -172,6 +192,25 @@ mutual
else
return thm.proof
private partial def mkHCongrProof' (f g : Expr) (numArgs : Nat) (lhs rhs : Expr) (heq : Bool) : GoalM Expr := do
let thm mkHCongrWithArity f numArgs
assert! thm.argKinds.size == numArgs
let proof mkHCongrProofHelper thm lhs rhs numArgs
if isSameExpr f g then
mkEqOfHEqIfNeeded proof heq
else
/-
`lhs` is of the form `f a_1 ... a_n`
`rhs` is of the form `g b_1 ... b_n`
`proof : f a_1 ... a_n ≍ f b_1 ... b_n`
We construct a proof for `f a_1 ... a_n ≍ g b_1 ... b_n` using `Eq.ndrec`
-/
let motive withLocalDeclD ( mkFreshUserName `x) ( inferType f) fun x => do
mkLambdaFVars #[x] ( mkHEq lhs (mkAppN x (rhs.getAppArgsN numArgs)))
let fEq mkEqProofCore f g false
let proof mkEqNDRec motive proof fEq
mkEqOfHEqIfNeeded proof heq
private partial def mkHCongrProof (lhs rhs : Expr) (heq : Bool) : GoalM Expr := do
let f := lhs.getAppFn
let g := rhs.getAppFn
@@ -189,43 +228,20 @@ mutual
let proof mkHCongrProofHelper thm lhs rhs numArgs
mkEqOfHEqIfNeeded proof heq
else
let thm mkHCongrWithArity f numArgs
assert! thm.argKinds.size == numArgs
let proof mkHCongrProofHelper thm lhs rhs numArgs
mkHCongrProof' f g numArgs lhs rhs heq
private partial def mkCongrProofFunCC (lhs rhs : Expr) (heq : Bool) : GoalM Expr := do
let rec go (e₁ e₂ : Expr) (numArgs : Nat) : GoalM Expr := do
let .app f _ := e₁ | unreachable!
let .app g _ := e₂ | unreachable!
let numArgs := numArgs + 1
if isSameExpr f g then
mkEqOfHEqIfNeeded proof heq
mkHCongrProof' f g numArgs lhs rhs heq
else if ( hasSameType f g) then
mkHCongrProof' f g numArgs lhs rhs heq
else
/-
`lhs` is of the form `f a_1 ... a_n`
`rhs` is of the form `g b_1 ... b_n`
`proof : f a_1 ... a_n ≍ f b_1 ... b_n`
We construct a proof for `f a_1 ... a_n ≍ g b_1 ... b_n` using `Eq.ndrec`
-/
let motive withLocalDeclD ( mkFreshUserName `x) ( inferType f) fun x => do
mkLambdaFVars #[x] ( mkHEq lhs (mkAppN x rhs.getAppArgs))
let fEq mkEqProofCore f g false
let proof mkEqNDRec motive proof fEq
mkEqOfHEqIfNeeded proof heq
partial def mkEqCongrProof (lhs rhs : Expr) : GoalM Expr := withIncRecDepth do
let_expr f@Eq α₁ a₁ b₁ := lhs | unreachable!
let_expr Eq α₂ a₂ b₂ := rhs | unreachable!
assert! ( get).hasSameRoot a₁ a₂ && ( get).hasSameRoot b₁ b₂
let us := f.constLevels!
if !isSameExpr α₁ α₂ then
return mkApp8 (mkConst ``Grind.heq_congr us) α₁ α₂ a₁ b₁ a₂ b₂ ( mkEqProofCore a₁ a₂ true) ( mkEqProofCore b₁ b₂ true)
else
return mkApp7 (mkConst ``Grind.eq_congr us) α₁ a₁ b₁ a₂ b₂ ( mkEqProofCore a₁ a₂ false) ( mkEqProofCore b₁ b₂ false)
partial def mkEqCongrSymmProof (lhs rhs : Expr) : GoalM Expr := withIncRecDepth do
let_expr f@Eq α₁ a₁ b₁ := lhs | unreachable!
let_expr Eq α₂ a₂ b₂ := rhs | unreachable!
assert! ( get).hasSameRoot a₁ b₂ && ( get).hasSameRoot b₁ a₂
let us := f.constLevels!
if !isSameExpr α₁ α₂ then
return mkApp8 (mkConst ``Grind.heq_congr' us) α₁ α₂ a₁ b₁ a₂ b₂ ( mkEqProofCore a₁ b₂ true) ( mkEqProofCore b₁ a₂ true)
else
return mkApp7 (mkConst ``Grind.eq_congr' us) α₁ a₁ b₁ a₂ b₂ ( mkEqProofCore a₁ b₂ false) ( mkEqProofCore b₁ a₂ false)
go f g numArgs
go lhs rhs 0
/-- Constructs a congruence proof for `lhs` and `rhs`. -/
private partial def mkCongrProof (lhs rhs : Expr) (heq : Bool) : GoalM Expr := do
@@ -234,6 +250,8 @@ mutual
let u withDefault <| getLevel p₁
let v withDefault <| getLevel q₁
return mkApp6 (mkConst ``implies_congr [u, v]) p₁ p₂ q₁ q₂ ( mkEqProofCore p₁ p₂ false) ( mkEqProofCore q₁ q₂ false)
else if ( useFunCC lhs) then
mkCongrProofFunCC lhs rhs heq
else
let f := lhs.getAppFn
let g := rhs.getAppFn

View File

@@ -163,6 +163,8 @@ structure Context where
splitSource : SplitSource := .input
/-- Symbol priorities for inferring E-matching patterns -/
symPrios : SymbolPriorities
/-- Global declarations marked with `@[grind funCC]` -/
funCCs : NameSet
trueExpr : Expr
falseExpr : Expr
natZExpr : Expr
@@ -508,6 +510,12 @@ structure ENode where
mt : Nat := 0
/-- Solver terms attached to this E-node. -/
sTerms : SolverTerms := .nil
/--
If `funCC := true`, then the expression associated with this entry is an application, and
function congruence closure is enabled for it.
See `Grind.Config.funCC` for additional details.
-/
funCC : Bool := true
deriving Inhabited, Repr
def ENode.isRoot (n : ENode) :=
@@ -552,6 +560,12 @@ private def hasSameRoot (enodes : ENodeMap) (a b : Expr) : Bool := Id.run do
let some n2 := enodes.find? { expr := b } | return false
isSameExpr n1.root n2.root
private def useFunCC' (enodes : ENodeMap) (e : Expr) : Bool :=
if let some n := enodes.find? { expr := e } then
n.funCC
else
false
private def congrHash (enodes : ENodeMap) (e : Expr) : UInt64 :=
if let .forallE _ d b _ := e then
mixHash (hashRoot enodes d) (hashRoot enodes b)
@@ -559,7 +573,14 @@ private def congrHash (enodes : ENodeMap) (e : Expr) : UInt64 :=
| Grind.nestedProof p _ => hashRoot enodes p
| Grind.nestedDecidable p _ => mixHash 13 (hashRoot enodes p)
| Eq _ lhs rhs => goEq lhs rhs
| _ => go e 17
| _ =>
match e with
| .app f a =>
if useFunCC' enodes e then
mixHash (hashRoot enodes f) (hashRoot enodes a)
else
go f (hashRoot enodes a)
| _ => hashRoot enodes e
where
goEq (lhs rhs : Expr) : UInt64 :=
let h₁ := hashRoot enodes lhs
@@ -570,28 +591,34 @@ where
| .app f a => go f (mixHash r (hashRoot enodes a))
| _ => mixHash r (hashRoot enodes e)
/-- Returns `true` if `a` and `b` are congruent modulo the equivalence classes in `enodes`. -/
private partial def isCongruent (enodes : ENodeMap) (a b : Expr) : Bool :=
if let .forallE _ d₁ b₁ _ := a then
if let .forallE _ d₂ b₂ _ := b then
/-- Returns `true` if `e₁` and `e₂` are congruent modulo the equivalence classes in `enodes`. -/
private partial def isCongruent (enodes : ENodeMap) (e₁ e₂ : Expr) : Bool :=
if let .forallE _ d₁ b₁ _ := e₁ then
if let .forallE _ d₂ b₂ _ := e₂ then
hasSameRoot enodes d₁ d₂ && hasSameRoot enodes b₁ b₂
else
false
else match_expr a with
else match_expr e₁ with
| Grind.nestedProof p₁ _ =>
let_expr Grind.nestedProof p₂ _ := b | false
let_expr Grind.nestedProof p₂ _ := e₂ | false
hasSameRoot enodes p₁ p₂
| Grind.nestedDecidable p₁ _ =>
let_expr Grind.nestedDecidable p₂ _ := b | false
let_expr Grind.nestedDecidable p₂ _ := e₂ | false
hasSameRoot enodes p₁ p₂
| Eq _ lhs₁ rhs₁ =>
let_expr Eq _ lhs₂ rhs₂ := b | false
let_expr Eq _ lhs₂ rhs₂ := e₂ | false
goEq lhs₁ rhs₁ lhs₂ rhs₂
| _ =>
if a.isApp && b.isApp then
go a b
| _ => Id.run do
let .app f a := e₁ | return false
let .app g b := e₂ | return false
if useFunCC' enodes e₁ then
/-
**Note**: We are not in `MetaM` here. Thus, we cannot check whether `f` and `g` have the same type.
So, we approximate and try to handle this issue when generating the proof term.
-/
hasSameRoot enodes a b && hasSameRoot enodes f g
else
false
hasSameRoot enodes a b && go f g
where
goEq (lhs₁ rhs₁ lhs₂ rhs₂ : Expr) : Bool :=
(hasSameRoot enodes lhs₁ lhs₂ && hasSameRoot enodes rhs₁ rhs₂)
@@ -1078,6 +1105,12 @@ def getRootENode (e : Expr) : GoalM ENode := do
def getRootENode? (e : Expr) : GoalM (Option ENode) := do
let some n getENode? e | return none
getENode? n.root
/--
Returns `true` if the ENode associate with `e` has support for function equality
congruence closure. See `Grind.Config.funCC` for additional details.
-/
def useFunCC (e : Expr) : GoalM Bool :=
return ( getENode e).funCC
/--
Returns the next element in the equivalence class of `e`
@@ -1193,7 +1226,7 @@ def copyParentsTo (parents : ParentSet) (root : Expr) : GoalM Unit := do
curr := curr.insert parent
modify fun s => { s with parents := s.parents.insert { expr := root } curr }
def mkENodeCore (e : Expr) (interpreted ctor : Bool) (generation : Nat) : GoalM Unit := do
def mkENodeCore (e : Expr) (interpreted ctor : Bool) (generation : Nat) (funCC : Bool) : GoalM Unit := do
let n := {
self := e, next := e, root := e, congr := e, size := 1
flipped := false
@@ -1201,7 +1234,7 @@ def mkENodeCore (e : Expr) (interpreted ctor : Bool) (generation : Nat) : GoalM
hasLambdas := e.isLambda
mt := ( get).ematch.gmt
idx := ( get).nextIdx
interpreted, ctor, generation
interpreted, ctor, generation, funCC
}
modify fun s => { s with
enodeMap := s.enodeMap.insert { expr := e } n
@@ -1214,11 +1247,11 @@ def mkENodeCore (e : Expr) (interpreted ctor : Bool) (generation : Nat) : GoalM
Creates an `ENode` for `e` if one does not already exist.
This method assumes `e` has been hash-consed.
-/
def mkENode (e : Expr) (generation : Nat) : GoalM Unit := do
def mkENode (e : Expr) (generation : Nat) (funCC : Bool := false) : GoalM Unit := do
if ( alreadyInternalized e) then return ()
let ctor := ( isConstructorAppCore? e).isSome
let interpreted isInterpreted e
mkENodeCore e interpreted ctor generation
mkENodeCore e interpreted ctor generation funCC
def setENode (e : Expr) (n : ENode) : GoalM Unit :=
modify fun s => { s with

View File

@@ -59,6 +59,7 @@ h_1 : ¬op (op a b) (op b c) = op (op c d) c
[eqc] False propositions
[prop] op (op a b) (op b c) = op (op c d) c
[eqc] Equivalence classes
[eqc] {op (op a b), op (op c d)}
[eqc] {op a b, op c d}
[assoc] Operator `op`
[basis] Basis

View File

@@ -0,0 +1,94 @@
import Lean
example (m : Nat) (a b : Nat Nat) (h : b = a) :
b m = a m := by
grind
example (m n : Nat) (a b : Nat Nat Nat) : b = a m = n i = j b m i = a n j := by
grind
example (m : Nat) (a : Nat Nat) (f : (Nat Nat) (Nat Nat)) (h : f a = a) :
f a m = a m := by
grind
example (m : Nat) (a : Nat Nat) (f : (Nat Nat) (Nat Nat)) (h : f a = a) :
f a m = a m := by
fail_if_success grind -funCC
grind
example (a b : Nat) (g : Nat Nat) (f : Nat Nat Nat) (h : f a = g) :
f a b = g b := by
grind
example (a b : Nat) (g : Nat Nat) (f : Nat Nat Nat) (h : f a = g) :
f a b = g b := by
fail_if_success grind -funCC
grind
namespace WithStructure
structure Test where
apply: Unit Prop
def test : Test := {
apply := fun () => True
}
example : test.apply () := by
grind [test]
end WithStructure
-- grind succeeds without the thunk
namespace WithoutThunk
structure Test where
apply: Prop
def test : Test := {
apply := True
}
example : test.apply := by
grind [test]
end WithoutThunk
-- grind succeeds without structure
namespace WithoutStructure
def Test := Unit Prop
def test : Test := fun () => True
example : test () := by
grind [test]
end WithoutStructure
namespace Ex
opaque f : Nat Nat Nat
opaque g : Nat Nat
example (a b c : Nat) : f a = g b = c f a b = g c := by
fail_if_success grind
simp_all
example (a b c : Nat) : f a = g b = c f a b = g c := by
grind [funCC f, funCC g]
example (a b c : Nat) : f a = g b = c f a b = g c := by
fail_if_success grind
simp_all
attribute [grind funCC] f g
example (a b c : Nat) : f a = g b = c f a b = g c := by
grind
example (a b c : Nat) : f a = g b = c f a b = g c := by
fail_if_success grind -funCC
grind
end Ex