Compare commits

...

1 Commits

Author SHA1 Message Date
Leonardo de Moura
c5d8a378ba feat: instantiate ematch theorems using assignment
This PR implements `Grind.EMatch.instantiateTheorem` in the (WIP)
`grind` tactic.
2024-12-31 09:26:33 -08:00
4 changed files with 111 additions and 15 deletions

View File

@@ -33,6 +33,7 @@ builtin_initialize registerTraceClass `grind.internalize
builtin_initialize registerTraceClass `grind.ematch
builtin_initialize registerTraceClass `grind.ematch.pattern
builtin_initialize registerTraceClass `grind.ematch.instance
builtin_initialize registerTraceClass `grind.ematch.instance.assignment
builtin_initialize registerTraceClass `grind.issues
builtin_initialize registerTraceClass `grind.simp

View File

@@ -53,8 +53,8 @@ structure Choice where
/-- Theorem instances found so far. We only internalize them after we complete a full round of E-matching. -/
structure TheoremInstance where
prop : Expr
proof : Expr
prop : Expr
generation : Nat
deriving Inhabited
@@ -163,10 +163,75 @@ private def processContinue (c : Choice) (p : Expr) : M Unit := do
let c := { c with gen := Nat.max gen c.gen }
modify fun s => { s with choiceStack := c :: s.choiceStack }
private partial def instantiateTheorem (c : Choice) : M Unit := do
trace[grind.ematch.instance] "{(← read).thm.origin.key} : {assignmentToMessageData c.assignment}"
-- TODO
return ()
/--
Stores new theorem instance in the state.
Recall that new instances are internalized later, after a full round of ematching.
-/
private def addNewInstance (origin : Origin) (proof : Expr) (generation : Nat) : M Unit := do
let proof instantiateMVars proof
if grind.debug.proofs.get ( getOptions) then
check proof
let prop inferType proof
trace[grind.ematch.instance] "{← origin.pp}: {prop}"
modify fun s => { s with newInstances := s.newInstances.push { proof, prop, generation } }
/--
After processing a (multi-)pattern, use the choice assignment to instantiate the proof.
Missing parameters are synthesized using type inference and type class synthesis."
-/
private partial def instantiateTheorem (c : Choice) : M Unit := withDefault do
let thm := ( read).thm
trace[grind.ematch.instance.assignment] "{← thm.origin.pp}: {assignmentToMessageData c.assignment}"
let proof thm.getProofWithFreshMVarLevels
let numParams := thm.numParams
assert! c.assignment.size == numParams
let (mvars, bis, _) forallMetaBoundedTelescope ( inferType proof) numParams
if mvars.size != thm.numParams then
trace[grind.issues] "unexpected number of parameters at {← thm.origin.pp}"
return ()
-- Apply assignment
for h : i in [:mvars.size] do
let v := c.assignment[numParams - i - 1]!
unless isSameExpr v unassigned do
let mvarId := mvars[i].mvarId!
unless ( mvarId.checkedAssign v) do
trace[grind.issues] "type error constructing proof for {← thm.origin.pp}\nwhen assigning metavariable {mvars[i]} with {indentExpr v}"
return ()
-- Synthesize instances
for mvar in mvars, bi in bis do
if bi.isInstImplicit && !( mvar.mvarId!.isAssigned) then
let type inferType mvar
unless ( synthesizeInstance mvar type) do
trace[grind.issues] "failed to synthesize instance when instantiating {← thm.origin.pp}{indentExpr type}"
return ()
if ( mvars.allM (·.mvarId!.isAssigned)) then
addNewInstance thm.origin (mkAppN proof mvars) c.gen
else
-- instance has hypothesis
mkImp mvars 0 proof #[]
where
synthesizeInstance (x type : Expr) : MetaM Bool := do
let .some val trySynthInstance type | return false
isDefEq x val
mkImp (mvars : Array Expr) (i : Nat) (proof : Expr) (xs : Array Expr) : M Unit := do
if h : i < mvars.size then
let mvar := mvars[i]
if ( mvar.mvarId!.isAssigned) then
mkImp mvars (i+1) (mkApp proof mvar) xs
else
let mvarType instantiateMVars ( inferType mvar)
if mvarType.hasMVar then
let thm := ( read).thm
trace[grind.issues] "failed to create hypothesis for instance of {← thm.origin.pp} hypothesis type has metavars{indentExpr mvarType}"
return ()
withLocalDeclD ( mkFreshUserName `h) mvarType fun x => do
mkImp mvars (i+1) (mkApp proof x) (xs.push x)
else
let proof instantiateMVars proof
let proof mkLambdaFVars xs proof
let thm := ( read).thm
addNewInstance thm.origin proof c.gen
/-- Process choice stack until we don't have more choices to be processed. -/
private partial def processChoices : M Unit := do

View File

@@ -5,6 +5,7 @@ Authors: Leonardo de Moura
-/
prelude
import Lean.HeadIndex
import Lean.PrettyPrinter
import Lean.Util.FoldConsts
import Lean.Util.CollectFVars
import Lean.Meta.Basic
@@ -32,8 +33,21 @@ def Origin.key : Origin → Name
| .stx id _ => id
| .other => `other
def Origin.pp [Monad m] [MonadEnv m] [MonadError m] (o : Origin) : m MessageData := do
match o with
| .decl declName => return MessageData.ofConst ( mkConstWithLevelParams declName)
| .fvar fvarId => return mkFVar fvarId
| .stx _ ref => return ref
| .other => return "[unknown]"
/-- A theorem for heuristic instantiation based on E-matching. -/
structure EMatchTheorem where
/--
It stores universe parameter names for universe polymorphic proofs.
Recall that it is non-empty only when we elaborate an expression provided by the user.
When `proof` is just a constant, we can use the universe parameter names stored in the declaration.
-/
levelParams : Array Name
proof : Expr
numParams : Nat
patterns : List Expr
@@ -54,6 +68,20 @@ def EMatchTheorems.insert (s : EMatchTheorems) (thm : EMatchTheorem) : EMatchThe
else
return PersistentHashMap.insert s declName [thm]
def EMatchTheorem.getProofWithFreshMVarLevels (thm : EMatchTheorem) : MetaM Expr := do
if thm.proof.isConst && thm.levelParams.isEmpty then
let declName := thm.proof.constName!
let info getConstInfo declName
if info.levelParams.isEmpty then
return thm.proof
else
mkConstWithFreshMVarLevels declName
else if thm.levelParams.isEmpty then
return thm.proof
else
let us thm.levelParams.mapM fun _ => mkFreshLevelMVar
return thm.proof.instantiateLevelParamsArray thm.levelParams us
private builtin_initialize ematchTheoremsExt : SimpleScopedEnvExtension EMatchTheorem EMatchTheorems
registerSimpleScopedEnvExtension {
addEntry := EMatchTheorems.insert
@@ -316,7 +344,8 @@ def addEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr)
throwError "invalid pattern(s) for `{declName}`{indentD pats}\nthe following theorem parameters cannot be instantiated:{indentD (← ppParamsAt proof numParams pos)}"
ematchTheoremsExt.add {
proof, patterns, numParams, symbols
origin := .decl declName
levelParams := #[]
origin := .decl declName
}
def getEMatchTheorems : CoreM EMatchTheorems :=

View File

@@ -6,10 +6,11 @@ grind_pattern Array.get_set_ne => (a.set i v hi)[j]
set_option trace.grind.ematch.instance true
set_option grind.debug.proofs true
/--
info: [grind.ematch.instance] Array.get_set_eq : [α, bs, j, w, Lean.Grind.nestedProof (j < bs.toList.length) h₂]
[grind.ematch.instance] Array.get_set_eq : [α, as, i, v, Lean.Grind.nestedProof (i < as.toList.length) h₁]
[grind.ematch.instance] Array.get_set_ne : [α, bs, j, Lean.Grind.nestedProof (j < bs.toList.length) h₂, i, w, _, _]
info: [grind.ematch.instance] Array.get_set_eq: (bs.set j w ⋯)[j] = w
[grind.ematch.instance] Array.get_set_eq: (as.set i v ⋯)[i] = v
-/
#guard_msgs (info) in
example (as : Array α)
@@ -31,8 +32,8 @@ theorem Rtrans (a b c : Nat) : R a b → R b c → R a c := sorry
grind_pattern Rtrans => R a b, R b c
/--
info: [grind.ematch.instance] Rtrans : [b, c, d, _, _]
[grind.ematch.instance] Rtrans : [a, b, c, _, _]
info: [grind.ematch.instance] Rtrans: R b c → R c d → R b d
[grind.ematch.instance] Rtrans: R a b → R b c → R a c
-/
#guard_msgs (info) in
example : R a b R b c R c d False := by
@@ -41,10 +42,10 @@ example : R a b → R b c → R c d → False := by
-- In the following test we are performing one round of ematching only
/--
info: [grind.ematch.instance] Rtrans : [c, d, e, _, _]
[grind.ematch.instance] Rtrans : [c, d, n, _, _]
[grind.ematch.instance] Rtrans : [b, c, d, _, _]
[grind.ematch.instance] Rtrans : [a, b, c, _, _]
info: [grind.ematch.instance] Rtrans: R c d → R d e → R c e
[grind.ematch.instance] Rtrans: R c d → R d n → R c n
[grind.ematch.instance] Rtrans: R b c → R c d → R b d
[grind.ematch.instance] Rtrans: R a b → R b c → R a c
-/
#guard_msgs (info) in
example : R a b R b c R c d R d e R d n False := by