Compare commits

...

3 Commits

Author SHA1 Message Date
Leonardo de Moura
3160fa4f55 perf: consider at most one answer for type class resolution subgoals not containing metavariables
closes #3996
2024-04-27 11:45:06 -07:00
Leonardo de Moura
5dbdf2cd8b chore: code convention 2024-04-27 10:49:36 -07:00
Leonardo de Moura
7f4d609d56 perf: linearity issue 2024-04-27 09:56:07 -07:00
3 changed files with 90 additions and 28 deletions

View File

@@ -114,7 +114,7 @@ For example:
(The type of `inst` must not contain mvars.)
-/
partial def computeSynthOrder (inst : Expr) : MetaM (Array Nat) :=
private partial def computeSynthOrder (inst : Expr) : MetaM (Array Nat) :=
withReducible do
let instTy inferType inst

View File

@@ -41,6 +41,14 @@ structure GeneratorNode where
mctx : MetavarContext
instances : Array Instance
currInstanceIdx : Nat
/--
`typeHasMVars := true` if type of `mvar` contains metavariables.
We store this information to implement an optimization that relies on the fact
that instances are "morally canonical."
That is, we need to find at most one answer for this generator node if the type
does not have metavariables.
-/
typeHasMVars : Bool
deriving Inhabited
structure ConsumerNode where
@@ -56,8 +64,8 @@ inductive Waiter where
| root : Waiter
def Waiter.isRoot : Waiter Bool
| Waiter.consumerNode _ => false
| Waiter.root => true
| .consumerNode _ => false
| .root => true
/-!
In tabled resolution, we creating a mapping from goals (e.g., `Coe Nat ?x`) to
@@ -98,10 +106,10 @@ partial def normLevel (u : Level) : M Level := do
if !u.hasMVar then
return u
else match u with
| Level.succ v => return u.updateSucc! ( normLevel v)
| Level.max v w => return u.updateMax! ( normLevel v) ( normLevel w)
| Level.imax v w => return u.updateIMax! ( normLevel v) ( normLevel w)
| Level.mvar mvarId =>
| .succ v => return u.updateSucc! ( normLevel v)
| .max v w => return u.updateMax! ( normLevel v) ( normLevel w)
| .imax v w => return u.updateIMax! ( normLevel v) ( normLevel w)
| .mvar mvarId =>
if ( getMCtx).getLevelDepth mvarId != ( getMCtx).depth then
return u
else
@@ -118,15 +126,15 @@ partial def normExpr (e : Expr) : M Expr := do
if !e.hasMVar then
pure e
else match e with
| Expr.const _ us => return e.updateConst! ( us.mapM normLevel)
| Expr.sort u => return e.updateSort! ( normLevel u)
| Expr.app f a => return e.updateApp! ( normExpr f) ( normExpr a)
| Expr.letE _ t v b _ => return e.updateLet! ( normExpr t) ( normExpr v) ( normExpr b)
| Expr.forallE _ d b _ => return e.updateForallE! ( normExpr d) ( normExpr b)
| Expr.lam _ d b _ => return e.updateLambdaE! ( normExpr d) ( normExpr b)
| Expr.mdata _ b => return e.updateMData! ( normExpr b)
| Expr.proj _ _ b => return e.updateProj! ( normExpr b)
| Expr.mvar mvarId =>
| .const _ us => return e.updateConst! ( us.mapM normLevel)
| .sort u => return e.updateSort! ( normLevel u)
| .app f a => return e.updateApp! ( normExpr f) ( normExpr a)
| .letE _ t v b _ => return e.updateLet! ( normExpr t) ( normExpr v) ( normExpr b)
| .forallE _ d b _ => return e.updateForallE! ( normExpr d) ( normExpr b)
| .lam _ d b _ => return e.updateLambdaE! ( normExpr d) ( normExpr b)
| .mdata _ b => return e.updateMData! ( normExpr b)
| .proj _ _ b => return e.updateProj! ( normExpr b)
| .mvar mvarId =>
if !( mvarId.isAssignable) then
return e
else
@@ -202,7 +210,7 @@ def getInstances (type : Expr) : MetaM (Array Instance) := do
let result := result.insertionSort fun e₁ e₂ => e₁.priority < e₂.priority
let erasedInstances getErasedInstances
let mut result result.filterMapM fun e => match e.val with
| Expr.const constName us =>
| .const constName us =>
if erasedInstances.contains constName then
return none
else
@@ -234,6 +242,7 @@ def mkGeneratorNode? (key mvar : Expr) : MetaM (Option GeneratorNode) := do
let mctx getMCtx
return some {
mvar, key, mctx, instances
typeHasMVars := mvarType.hasMVar
currInstanceIdx := instances.size
}
@@ -351,7 +360,7 @@ def tryResolve (mvar : Expr) (inst : Instance) : MetaM (Option (MetavarContext
let lctx getLCtx
let localInsts getLocalInstances
forallTelescopeReducing mvarType fun xs mvarTypeBody => do
let subgoals, instVal, instTypeBody getSubgoals lctx localInsts xs inst
let { subgoals, instVal, instTypeBody } getSubgoals lctx localInsts xs inst
withTraceNode `Meta.synthInstance.tryResolve (withMCtx ( getMCtx) do
return m!"{exceptOptionEmoji ·} {← instantiateMVars mvarTypeBody} ≟ {← instantiateMVars instTypeBody}") do
if ( isDefEq mvarTypeBody instTypeBody) then
@@ -373,7 +382,7 @@ def tryAnswer (mctx : MetavarContext) (mvar : Expr) (answer : Answer) : SynthM (
/-- Move waiters that are waiting for the given answer to the resume stack. -/
def wakeUp (answer : Answer) : Waiter SynthM Unit
| Waiter.root => do
| .root => do
/- Recall that we now use `ignoreLevelMVarDepth := true`. Thus, we should allow solutions
containing universe metavariables, and not check `answer.result.paramNames.isEmpty`.
We use `openAbstractMVarsResult` to construct the universe metavariables
@@ -383,7 +392,7 @@ def wakeUp (answer : Answer) : Waiter → SynthM Unit
else
let (_, _, answerExpr) openAbstractMVarsResult answer.result
trace[Meta.synthInstance] "skip answer containing metavariables {answerExpr}"
| Waiter.consumerNode cNode =>
| .consumerNode cNode =>
modify fun s => { s with resumeStack := s.resumeStack.push (cNode, answer) }
def isNewAnswer (oldAnswers : Array Answer) (answer : Answer) : Bool :=
@@ -414,11 +423,11 @@ def addAnswer (cNode : ConsumerNode) : SynthM Unit := do
let answer mkAnswer cNode
-- Remark: `answer` does not contain assignable or assigned metavariables.
let key := cNode.key
let entry getEntry key
if isNewAnswer entry.answers answer then
let newEntry := { entry with answers := entry.answers.push answer }
let { waiters, answers } getEntry key
if isNewAnswer answers answer then
let newEntry := { waiters, answers := answers.push answer }
modify fun s => { s with tableEntries := s.tableEntries.insert key newEntry }
entry.waiters.forM (wakeUp answer)
waiters.forM (wakeUp answer)
/--
Return `true` if a type of the form `(a_1 : A_1) → ... → (a_n : A_n) → B` has an unused argument `a_i`.
@@ -426,7 +435,7 @@ def addAnswer (cNode : ConsumerNode) : SynthM Unit := do
Remark: This is syntactic check and no reduction is performed.
-/
private def hasUnusedArguments : Expr Bool
| Expr.forallE _ _ b _ => !b.hasLooseBVar 0 || hasUnusedArguments b
| .forallE _ _ b _ => !b.hasLooseBVar 0 || hasUnusedArguments b
| _ => false
/--
@@ -539,6 +548,17 @@ def generate : SynthM Unit := do
let inst := gNode.instances.get! idx
let mctx := gNode.mctx
let mvar := gNode.mvar
/- See comment at `typeHasMVars` -/
unless gNode.typeHasMVars do
if let some entry := ( get).tableEntries.find? key then
unless entry.answers.isEmpty do
/-
We already have an answer for this node, and since its type does not have metavariables,
we can skip other solutions because we assume instances are "morally canonical".
We have added this optimization to address issue #3996.
-/
modify fun s => { s with generatorStack := s.generatorStack.pop }
return
discard do withMCtx mctx do
withTraceNode `Meta.synthInstance
(return m!"{exceptOptionEmoji ·} apply {inst.val} to {← instantiateMVars (← inferType mvar)}") do
@@ -667,7 +687,7 @@ private partial def preprocessArgs (type : Expr) (i : Nat) (args : Array Expr) (
private def preprocessOutParam (type : Expr) : MetaM Expr :=
forallTelescope type fun xs typeBody => do
match typeBody.getAppFn with
| c@(Expr.const declName _) =>
| c@(.const declName _) =>
let env getEnv
if let some outParamsPos := getOutParamPositions? env declName then
unless outParamsPos.isEmpty do
@@ -775,8 +795,7 @@ def synthInstance (type : Expr) (maxResultSize? : Option Nat := none) : MetaM Ex
private def synthPendingImp (mvarId : MVarId) : MetaM Bool := withIncRecDepth <| mvarId.withContext do
let mvarDecl mvarId.getDecl
match mvarDecl.kind with
| MetavarKind.syntheticOpaque =>
return false
| .syntheticOpaque => return false
| _ =>
/- Check whether the type of the given metavariable is a class or not. If yes, then try to synthesize
it using type class resolution. We only do it for `synthetic` and `natural` metavariables. -/

43
tests/lean/run/3996.lean Normal file
View File

@@ -0,0 +1,43 @@
namespace Ex1
class A where
class B (n : Nat) where
class C where
instance test [B 10000] [C] : A where
instance Bsucc {n : Nat} [B n] : B n.succ where
instance instB0 : B 0 where
instance instB10000 : B 10000 where
/--
error: failed to synthesize
A
-/
#guard_msgs in
#synth A -- should fail quickly
end Ex1
namespace Ex2
class A where
class B (n : Nat) where
class C where
instance test' [B 10] : A where
instance test [B 0] [C] : A where
instance foo {n : Nat} [B n.succ] : B n where
instance instB (n : Nat) : B n where
/--
info: test'
-/
#guard_msgs in
#synth A
end Ex2