Compare commits

...

7 Commits

Author SHA1 Message Date
Leonardo de Moura
5586215f26 fix: improve heuristic 2025-04-13 10:30:39 -07:00
Leonardo de Moura
af406ee7d3 chore: remove annotation 2025-04-13 09:28:44 -07:00
Leonardo de Moura
f7fc4ba42a fix: missing case 2025-04-13 09:27:49 -07:00
Leonardo de Moura
d957f0b8db fix: Lookahead.lean and Split.lean 2025-04-13 09:21:41 -07:00
Leonardo de Moura
0047a3e8e7 chore: fix Internalize.lean 2025-04-13 09:12:53 -07:00
Leonardo de Moura
64f552a1f1 refactor: mbtc 2025-04-13 09:08:36 -07:00
Leonardo de Moura
4561685920 refactor: SplitInfo 2025-04-13 08:52:02 -07:00
6 changed files with 154 additions and 153 deletions

View File

@@ -62,10 +62,10 @@ private def checkAndAddSplitCandidate (e : Expr) : GoalM Unit := do
match e with
| .app .. =>
if ( getConfig).splitIte && (e.isIte || e.isDIte) then
addSplitCandidate e
addSplitCandidate (.default e)
return ()
if isMorallyIff e then
addSplitCandidate e
addSplitCandidate (.default e)
return ()
if ( getConfig).splitMatch then
if ( isMatcherApp e) then
@@ -74,7 +74,7 @@ private def checkAndAddSplitCandidate (e : Expr) : GoalM Unit := do
-- and consequently don't need to be split.
return ()
else
addSplitCandidate e
addSplitCandidate (.default e)
return ()
let .const declName _ := e.getAppFn | return ()
if forbiddenSplitTypes.contains declName then
@@ -82,21 +82,21 @@ private def checkAndAddSplitCandidate (e : Expr) : GoalM Unit := do
unless ( isInductivePredicate declName) do
return ()
if ( get).split.casesTypes.isSplit declName then
addSplitCandidate e
addSplitCandidate (.default e)
else if ( getConfig).splitIndPred then
addSplitCandidate e
addSplitCandidate (.default e)
| .fvar .. =>
let .const declName _ := ( whnfD ( inferType e)).getAppFn | return ()
if ( get).split.casesTypes.isSplit declName then
addSplitCandidate e
addSplitCandidate (.default e)
| .forallE _ d _ _ =>
if ( getConfig).splitImp then
addSplitCandidate e
addSplitCandidate (.default e)
else if Arith.isRelevantPred d then
if ( getConfig).lookahead then
addLookaheadCandidate (.imp e)
addLookaheadCandidate (.default e)
else
addSplitCandidate e
addSplitCandidate (.default e)
| _ => pure ()
/--
@@ -260,7 +260,7 @@ where
-- if (← getConfig).lookahead then
-- addLookaheadCandidate (.arg other.app parent i eq)
-- else
addSplitCandidate eq
addSplitCandidate (.arg other.app parent i eq)
modify fun s => { s with split.argsAt := s.split.argsAt.insert (f, i) ({ arg, type, app := parent } :: others) }
return ()

View File

@@ -12,39 +12,6 @@ import Lean.Meta.Tactic.Grind.EMatch
namespace Lean.Meta.Grind
inductive LookaheadStatus where
| resolved
| notReady
| ready
deriving Inhabited, BEq, Repr
private def checkLookaheadStatus (info : LookaheadInfo) : GoalM LookaheadStatus := do
match info with
| .imp e =>
unless ( isEqTrue e) do return .notReady
let .forallE _ d b _ := e | unreachable!
if ( isEqTrue d <||> isEqFalse d) then return .resolved
unless b.hasLooseBVars do
if ( isEqTrue b <||> isEqFalse b) then return .resolved
return .ready
| .arg a b _ eq =>
if ( isEqTrue eq <||> isEqFalse eq) then return .resolved
let is := ( get).split.lookaheadArgPos[(a, b)]? |>.getD []
let mut j := a.getAppNumArgs
let mut it_a := a
let mut it_b := b
repeat
unless it_a.isApp && it_b.isApp do return .ready
j := j - 1
if j is then
let arg_a := it_a.appArg!
let arg_b := it_b.appArg!
unless ( isEqv arg_a arg_b) do
return .notReady
it_a := it_a.appFn!
it_b := it_b.appFn!
return .ready
private partial def solve (generation : Nat) (goal : Goal) : GrindM Bool := do
cont ( intros generation goal)
where
@@ -110,10 +77,11 @@ def lookahead : GrindTactic := fun goal => do
for info in infos do
if ( isInconsistent) then
return true
match ( checkLookaheadStatus info) with
match ( checkSplitStatus info) with
| .resolved => progress := true
| .ready _ _ true
| .notReady => postponed := info :: postponed
| .ready =>
| .ready _ _ false =>
if ( tryLookahead info.getExpr) then
progress := true
else

View File

@@ -33,13 +33,20 @@ structure MBTC.Context where
-/
eqAssignment : Expr Expr GoalM Bool
private abbrev Map := Std.HashMap (Expr × Nat) (List Expr)
private abbrev Candidates := Std.HashSet (Expr × Expr)
private def mkCandidateKey (a b : Expr) : Expr × Expr :=
if a.lt b then
(a, b)
private structure ArgInfo where
arg : Expr
app : Expr
private abbrev Map := Std.HashMap (Expr × Nat) (List ArgInfo)
private abbrev Candidates := Std.HashSet SplitInfo
private def mkCandidate (a b : ArgInfo) (i : Nat) : GoalM SplitInfo := do
let (lhs, rhs) := if a.arg.lt b.arg then
(a.arg, b.arg)
else
(b, a)
(b.arg, a.arg)
let eq mkEq lhs rhs
let eq shareCommon ( canon eq)
return .arg a.app b.app i eq
/-- Model-based theory combination. -/
def mbtc (ctx : MBTC.Context) : GoalM Bool := do
@@ -58,40 +65,33 @@ def mbtc (ctx : MBTC.Context) : GoalM Bool := do
let some arg getRoot? arg | pure ()
if ( ctx.hasTheoryVar arg) then
trace[grind.debug.mbtc] "{arg} @ {f}:{i}"
if let some others := map[(f, i)]? then
unless others.any (isSameExpr arg ·) do
for other in others do
if ( ctx.eqAssignment arg other) then
if ( hasSameType arg other) then
let k := mkCandidateKey arg other
candidates := candidates.insert k
map := map.insert (f, i) (arg :: others)
let argInfo : ArgInfo := { arg, app := e }
if let some otherInfos := map[(f, i)]? then
unless otherInfos.any fun info => isSameExpr arg info.arg do
for otherInfo in otherInfos do
if ( ctx.eqAssignment arg otherInfo.arg) then
if ( hasSameType arg otherInfo.arg) then
candidates := candidates.insert ( mkCandidate argInfo otherInfo i)
map := map.insert (f, i) (argInfo :: otherInfos)
else
map := map.insert (f, i) [arg]
map := map.insert (f, i) [argInfo]
i := i + 1
if candidates.isEmpty then
return false
if ( get).split.num > ( getConfig).splits then
reportIssue "skipping `mbtc`, maximum number of splits has been reached `(splits := {(← getConfig).splits})`"
return false
let result := candidates.toArray.qsort fun (a₁, b₁) (a₂, b₂) =>
if isSameExpr a₁ a₂ then
b₁.lt b₂
else
a₁.lt a₂
let eqs result.filterMapM fun (a, b) => do
let eq mkEq a b
trace[grind.mbtc] "{eq}"
let eq shareCommon ( canon eq)
if ( isKnownCaseSplit eq) then
let result := candidates.toArray.qsort fun c₁ c₂ => c₁.lt c₂
let result result.filterMapM fun info => do
if ( isKnownCaseSplit info) then
return none
else
internalize eq (Nat.max ( getGeneration a) ( getGeneration b))
return some eq
if eqs.isEmpty then
let .arg a b _ eq := info | return none
internalize eq (Nat.max ( getGeneration a) ( getGeneration b))
return some info
if result.isEmpty then
return false
for eq in eqs do
addSplitCandidate eq
for info in result do
addSplitCandidate info
return true
def mbtcTac (ctx : MBTC.Context) : GrindTactic := fun goal => do

View File

@@ -11,14 +11,14 @@ import Lean.Meta.Tactic.Grind.CasesMatch
namespace Lean.Meta.Grind
inductive CaseSplitStatus where
inductive SplitStatus where
| resolved
| notReady
| ready (numCases : Nat) (isRec := false)
| ready (numCases : Nat) (isRec := false) (tryPostpone := false)
deriving Inhabited, BEq, Repr
/-- Given `c`, the condition of an `if-then-else`, check whether we need to case-split on the `if-then-else` or not -/
private def checkIteCondStatus (c : Expr) : GoalM CaseSplitStatus := do
private def checkIteCondStatus (c : Expr) : GoalM SplitStatus := do
if ( isEqTrue c <||> isEqFalse c) then
return .resolved
else
@@ -28,7 +28,7 @@ private def checkIteCondStatus (c : Expr) : GoalM CaseSplitStatus := do
Given `e` of the form `a b`, check whether we are ready to case-split on `e`.
That is, `e` is `True`, but neither `a` nor `b` is `True`."
-/
private def checkDisjunctStatus (e a b : Expr) : GoalM CaseSplitStatus := do
private def checkDisjunctStatus (e a b : Expr) : GoalM SplitStatus := do
if ( isEqTrue e) then
if ( isEqTrue a <||> isEqTrue b) then
return .resolved
@@ -43,7 +43,7 @@ private def checkDisjunctStatus (e a b : Expr) : GoalM CaseSplitStatus := do
Given `e` of the form `a ∧ b`, check whether we are ready to case-split on `e`.
That is, `e` is `False`, but neither `a` nor `b` is `False`.
-/
private def checkConjunctStatus (e a b : Expr) : GoalM CaseSplitStatus := do
private def checkConjunctStatus (e a b : Expr) : GoalM SplitStatus := do
if ( isEqTrue e) then
return .resolved
else if ( isEqFalse e) then
@@ -60,7 +60,7 @@ There are two cases:
1- `e` is `True`, but neither both `a` and `b` are `True`, nor both `a` and `b` are `False`.
2- `e` is `False`, but neither `a` is `True` and `b` is `False`, nor `a` is `False` and `b` is `True`.
-/
private def checkIffStatus (e a b : Expr) : GoalM CaseSplitStatus := do
private def checkIffStatus (e a b : Expr) : GoalM SplitStatus := do
if ( isEqTrue e) then
if ( (isEqTrue a <&&> isEqTrue b) <||> (isEqFalse a <&&> isEqFalse b)) then
return .resolved
@@ -83,7 +83,7 @@ private def isCongrToPrevSplit (c : Expr) : GoalM Bool := do
else
return c'.isApp && isCongruent ( get).enodes c c'
private def checkForallStatus (e : Expr) : GoalM CaseSplitStatus := do
private def checkForallStatus (e : Expr) : GoalM SplitStatus := do
if ( isEqTrue e) then
let .forallE _ p q _ := e | return .resolved
if ( isEqTrue p <||> isEqFalse p) then
@@ -97,7 +97,7 @@ private def checkForallStatus (e : Expr) : GoalM CaseSplitStatus := do
else
return .notReady
private def checkCaseSplitStatus (e : Expr) : GoalM CaseSplitStatus := do
private def checkDefaultSplitStatus (e : Expr) : GoalM SplitStatus := do
match_expr e with
| Or a b => checkDisjunctStatus e a b
| And a b => checkConjunctStatus e a b
@@ -133,9 +133,37 @@ private def checkCaseSplitStatus (e : Expr) : GoalM CaseSplitStatus := do
return .ready info.ctors.length info.isRec
return .notReady
def checkSplitInfoArgStatus (a b : Expr) (eq : Expr) : GoalM SplitStatus := do
if ( isEqTrue eq <||> isEqFalse eq) then return .resolved
let is := ( get).split.argPosMap[(a, b)]? |>.getD []
let mut j := a.getAppNumArgs
let mut it_a := a
let mut it_b := b
repeat
unless it_a.isApp && it_b.isApp do return .ready 2
j := j - 1
if j is then
let arg_a := it_a.appArg!
let arg_b := it_b.appArg!
unless ( isEqv arg_a arg_b) do
trace_goal[grind.split] "may be irrelevant\na: {a}\nb: {b}\neq: {eq}\narg_a: {arg_a}\narg_b: {arg_b}, gen: {← getGeneration eq}"
/-
We tried to return `.notReady` because we would not be able to derive a congruence, but
`grind_ite.lean` breaks when this heuristic is used. TODO: understand better why.
-/
return .ready 2 (tryPostpone := true)
it_a := it_a.appFn!
it_b := it_b.appFn!
return .ready 2
def checkSplitStatus (s : SplitInfo) : GoalM SplitStatus := do
match s with
| .default e => checkDefaultSplitStatus e
| .arg a b _ eq => checkSplitInfoArgStatus a b eq
private inductive SplitCandidate where
| none
| some (c : Expr) (numCases : Nat) (isRec : Bool)
| some (c : SplitInfo) (numCases : Nat) (isRec : Bool) (tryPostpone : Bool)
/-- Returns the next case-split to be performed. It uses a very simple heuristic. -/
private def selectNextSplit? : GoalM SplitCandidate := do
@@ -143,11 +171,11 @@ private def selectNextSplit? : GoalM SplitCandidate := do
if ( checkMaxCaseSplit) then return .none
go ( get).split.candidates .none []
where
go (cs : List Expr) (c? : SplitCandidate) (cs' : List Expr) : GoalM SplitCandidate := do
go (cs : List SplitInfo) (c? : SplitCandidate) (cs' : List SplitInfo) : GoalM SplitCandidate := do
match cs with
| [] =>
modify fun s => { s with split.candidates := cs'.reverse }
if let .some _ numCases isRec := c? then
if let .some _ numCases isRec _ := c? then
let numSplits := ( get).split.num
-- We only increase the number of splits if there is more than one case or it is recursive.
let numSplits := if numCases > 1 || isRec then numSplits + 1 else numSplits
@@ -156,24 +184,28 @@ where
modify fun s => { s with split.num := numSplits, ematch.num := 0 }
return c?
| c::cs =>
trace_goal[grind.debug.split] "checking: {c}"
match ( checkCaseSplitStatus c) with
trace_goal[grind.debug.split] "checking: {c.getExpr}"
match ( checkSplitStatus c) with
| .notReady => go cs c? (c::cs')
| .resolved => go cs c? cs'
| .ready numCases isRec =>
| .ready numCases isRec tryPostpone =>
if ( cheapCasesOnly) && numCases > 1 then
go cs c? (c::cs')
else match c? with
| .none => go cs (.some c numCases isRec) cs'
| .some c' numCases' _ =>
| .none => go cs (.some c numCases isRec tryPostpone) cs'
| .some c' numCases' _ tryPostpone' =>
let isBetter : GoalM Bool := do
if numCases == 1 && !isRec && numCases' > 1 then
if tryPostpone' && !tryPostpone then
return true
if ( getGeneration c) < ( getGeneration c') then
else if tryPostpone && !tryPostpone' then
return false
else if numCases == 1 && !isRec && numCases' > 1 then
return true
if ( getGeneration c.getExpr) < ( getGeneration c'.getExpr) then
return true
return numCases < numCases'
if ( isBetter) then
go cs (.some c numCases isRec) (c'::cs')
go cs (.some c numCases isRec tryPostpone) (c'::cs')
else
go cs c? (c::cs')
@@ -195,6 +227,7 @@ private def mkCasesMajor (c : Expr) : GoalM Expr := do
else
-- model-based theory combination split
return mkGrindEM c
| Not e => return mkGrindEM e
| _ =>
if let .forallE _ p _ _ := c then
return mkGrindEM p
@@ -215,8 +248,9 @@ and returns a new list of goals if successful.
-/
def splitNext : GrindTactic := fun goal => do
let (goals?, _) GoalM.run goal do
let .some c numCases isRec selectNextSplit?
let .some c numCases isRec _ selectNextSplit?
| return none
let c := c.getExpr
let gen getGeneration c
let genNew := if numCases > 1 || isRec then gen+1 else gen
markCaseSplitAsResolved c

View File

@@ -484,73 +484,72 @@ structure EMatch.State where
matchEqNames : PHashSet Name := {}
deriving Inhabited
/--
Lookahead case-split information. They are cheaper than regular case-splits.
They are created when `Grind.Config.lookahead` is `true`.
The idea is the following: `grind` asserts `¬ p`, if a contradiction is detected
it asserts `p`.
-/
inductive LookaheadInfo where
/-- Case-split information. -/
inductive SplitInfo where
| /--
Given an implication `e`, use lookahead to check whether the antecedent is
implied to be `True`. The lookahead is marked as resolved if the consequent is already
known to be `True`.
Term `e` may be an inductive predicate, `match`-expression, `if`-expression, implication, etc.
-/
imp (e : Expr)
default (e : Expr)
| /--
Given applications `a` and `b`, use lookahead to check whether the corresponding
`i`-th arguments are equal or not. The lookahead is only performed if all other
arguments are already known to be equal or are also tagged as lookahead.
`eq` is the equality between the two arguments
Given applications `a` and `b`, case-split on whether the corresponding
`i`-th arguments are equal or not. The split is only performed if all other
arguments are already known to be equal or are also tagged as split candidates.
-/
arg (a b : Expr) (i : Nat) (eq : Expr)
deriving BEq, Hashable, Inhabited
/-- Returns expression to perform a lookahead case-split. -/
def LookaheadInfo.getExpr : LookaheadInfo Expr
| .imp e => e.bindingDomain!
def SplitInfo.getExpr : SplitInfo Expr
| .default (.forallE _ d _ _) => d
| .default e => e
| .arg _ _ _ eq => eq
/-- Argument `arg : type` of an application `app` -/
structure Arg where
def SplitInfo.lt : SplitInfo SplitInfo Bool
| .default e₁, .default e₂ => e₁.lt e₂
| .arg _ _ _ e₁, .arg _ _ _ e₂ => e₁.lt e₂
| .default _, .arg .. => true
| .arg .., .default _ => false
/-- Argument `arg : type` of an application `app` in `SplitInfo`. -/
structure SplitArg where
arg : Expr
type : Expr
app : Expr
/-- Case splitting related fields for the `grind` goal. -/
structure Split.State where
/-- Inductive datatypes marked for case-splitting -/
casesTypes : CasesTypes := {}
/-- Case-split candidates. -/
candidates : List Expr := []
/-- Number of splits performed to get to this goal. -/
num : Nat := 0
num : Nat := 0
/-- Inductive datatypes marked for case-splitting -/
casesTypes : CasesTypes := {}
/-- Case-split candidates. -/
candidates : List SplitInfo := []
/-- Case-splits that have been inserted at `candidates` at some point. -/
added : PHashSet ENodeKey := {}
added : Std.HashSet SplitInfo := {}
/-- Case-splits that have already been performed, or that do not have to be performed anymore. -/
resolved : PHashSet ENodeKey := {}
resolved : PHashSet ENodeKey := {}
/--
Sequence of cases steps that generated this goal. We only use this information for diagnostics.
Remark: `casesTrace.length ≥ numSplits` because we don't increase the counter for `cases`
applications that generated only 1 subgoal.
-/
trace : List CaseTrace := []
trace : List CaseTrace := []
/-- Lookahead "case-splits". -/
lookaheads : List LookaheadInfo := []
lookaheads : List SplitInfo := []
/--
A mapping `(a, b) ↦ is` s.t. for each `LookaheadInfo.arg a b i eq`
in `lookaheads` we have `i ∈ is`.
We use this information to decide whether the lookahead is "ready"
A mapping `(a, b) ↦ is` s.t. for each `SplitInfo.arg a b i eq`
in `candidates` or `lookaheads` we have `i ∈ is`.
We use this information to decide whether the split/lookahead is "ready"
to be tried or not.
-/
lookaheadArgPos : Std.HashMap (Expr × Expr) (List Nat) := {}
argPosMap : Std.HashMap (Expr × Expr) (List Nat) := {}
/--
Mapping from pairs `(f, i)` to a list of arguments.
Each argument occurs as the `i`-th of an `f`-application.
We use this information to add case-splits for
We use this information to add splits/lookaheads for
triggering extensionality theorems and model-based theory combination.
See `addSplitCandidatesForExt`.
-/
argsAt : PHashMap (Expr × Nat) (List Arg) := {}
argsAt : PHashMap (Expr × Nat) (List SplitArg) := {}
deriving Inhabited
/-- Clean name generator. -/
@@ -1217,12 +1216,12 @@ def getEqcs : GoalM (List (List Expr)) :=
return ( get).getEqcs
/--
Returns `true` if `e` has been already added to the case-split list at one point.
Returns `true` if `s` has been already added to the case-split list at one point.
Remark: this function returns `true` even if the split has already been resolved
and is not in the list anymore.
-/
def isKnownCaseSplit (e : Expr) : GoalM Bool :=
return ( get).split.added.contains { expr := e }
def isKnownCaseSplit (s : SplitInfo) : GoalM Bool :=
return ( get).split.added.contains s
/-- Returns `true` if `e` is a case-split that does not need to be performed anymore. -/
def isResolvedCaseSplit (e : Expr) : GoalM Bool :=
@@ -1238,14 +1237,23 @@ def markCaseSplitAsResolved (e : Expr) : GoalM Unit := do
trace_goal[grind.split.resolved] "{e}"
modify fun s => { s with split.resolved := s.split.resolved.insert { expr := e } }
private def updateSplitArgPosMap (sinfo : SplitInfo) : GoalM Unit := do
let .arg a b i _ := sinfo | return ()
let key := (a, b)
let is := ( get).split.argPosMap[key]? |>.getD []
modify fun s => { s with
split.argPosMap := s.split.argPosMap.insert key (i :: is)
}
/-- Inserts `e` into the list of case-split candidates if it was not inserted before. -/
def addSplitCandidate (e : Expr) : GoalM Unit := do
unless ( isKnownCaseSplit e) do
trace_goal[grind.split.candidate] "{e}"
def addSplitCandidate (sinfo : SplitInfo) : GoalM Unit := do
unless ( isKnownCaseSplit sinfo) do
trace_goal[grind.split.candidate] "{sinfo.getExpr}"
modify fun s => { s with
split.added := s.split.added.insert { expr := e }
split.candidates := e :: s.split.candidates
split.added := s.split.added.insert sinfo
split.candidates := sinfo :: s.split.candidates
}
updateSplitArgPosMap sinfo
/--
Returns extensionality theorems for the given type if available.
@@ -1269,18 +1277,10 @@ def synthesizeInstanceAndAssign (x type : Expr) : MetaM Bool := do
isDefEq x val
/-- Add a new lookahead candidate. -/
def addLookaheadCandidate (info : LookaheadInfo) : GoalM Unit := do
trace_goal[grind.lookahead.add] "{info.getExpr}"
match info with
| .imp .. =>
modify fun s => { s with split.lookaheads := info :: s.split.lookaheads }
| .arg a b i _ =>
let key := (a, b)
let is := ( get).split.lookaheadArgPos[key]? |>.getD []
modify fun s => { s with
split.lookaheads := info :: s.split.lookaheads
split.lookaheadArgPos := s.split.lookaheadArgPos.insert key (i :: is)
}
def addLookaheadCandidate (sinfo : SplitInfo) : GoalM Unit := do
trace_goal[grind.lookahead.add] "{sinfo.getExpr}"
modify fun s => { s with split.lookaheads := sinfo :: s.split.lookaheads }
updateSplitArgPosMap sinfo
/--
Helper function for executing `x` with a fresh `newFacts` and without modifying

View File

@@ -92,8 +92,7 @@ theorem Tree.forall_insert_of_forall
theorem Tree.bst_insert_of_bst
{t : Tree β} (h : BST t) (key : Nat) (value : β)
: BST (t.insert key value) := by
-- TODO: improve `grind` `funext` support, and minimize the number of splits
induction h <;> grind (splits := 12) [BST.node, BST.leaf, ForallTree.leaf, forall_insert_of_forall]
induction h <;> grind [BST.node, BST.leaf, ForallTree.leaf, forall_insert_of_forall]
def BinTree (β : Type u) := { t : Tree β // BST t }