Compare commits

...

1 Commits

Author SHA1 Message Date
Leonardo de Moura
6f6bbfc717 feat: outParam and coercions experiment 2022-07-05 21:09:53 -07:00
6 changed files with 127 additions and 24 deletions

View File

@@ -300,7 +300,10 @@ private def propagateExpectedType (arg : Arg) : M Unit := do
unless fTypeBody.hasLooseBVars do
unless ( hasOptAutoParams fTypeBody) do
trace[Elab.app.propagateExpectedType] "{expectedType} =?= {fTypeBody}"
if ( isDefEq expectedType fTypeBody) then
if ( isSyntheticMVar fTypeBody) then
trace[Elab.app.propagateExpectedType] "{fTypeBody} is sythethic, skipping propagation"
modify fun s => { s with propagateExpected := false }
else if ( isDefEq expectedType fTypeBody) then
/- Note that we only set `propagateExpected := false` when propagation has succeeded. -/
modify fun s => { s with propagateExpected := false }
@@ -326,9 +329,10 @@ private def finalize : M Expr := do
match s.expectedType? with
| none => pure ()
| some expectedType =>
trace[Elab.app.finalize] "expected type: {expectedType}"
-- Try to propagate expected type. Ignore if types are not definitionally equal, caller must handle it.
discard <| isDefEq expectedType eType
unless ( isSyntheticMVar eType) do
-- Try to propagate expected type. Ignore if types are not definitionally equal, caller must handle it.
trace[Elab.app.finalize] "expected type: {expectedType}"
discard <| isDefEq expectedType eType
synthesizeAppInstMVars
return e
@@ -353,12 +357,68 @@ private def hasArgsToProcess : M Bool := do
let s get
return !s.args.isEmpty || !s.namedArgs.isEmpty
/- Return true if the next argument at `args` is of the form `_` -/
/- Return `true` if the next argument at `args` is of the form `_` -/
private def isNextArgHole : M Bool := do
match ( get).args with
| Arg.stx (Syntax.node _ ``Lean.Parser.Term.hole _) :: _ => pure true
| _ => pure false
/--
Return `true` if the next argument to be processed is the outparam of a local instance.
For example, suppose we have the class
```lean
class Get (Cont : Type u) (Idx : Type v) (Elem : outParam (Type w)) where
get (xs : Cont) (i : Idx) : Elem
```
And the current value of `fType` is
```
{Cont : Type u_1} → {Idx : Type u_2} → {Elem : Type u_3} → [self : Get Cont Idx Elem] → Cont → Idx → Elem
```
then the result returned by this method is `false` since `Cont` is not the output param of any local instance.
Now assume `fType` is
```
{Elem : Type u_3} → [self : Get Cont Idx Elem] → Cont → Idx → Elem
```
then, the method returns `true` because `Elem` is an output parameter for the local instance `[self : Get Cont Idx Elem]`.
-/
private partial def isNextOutParamOfLocalInstance : M Bool := do
let type := ( get).fType.bindingBody!
if ( hasLocalInstaceWithOutParams type) then
let x := mkFVar ( mkFreshFVarId)
isOutParamOfLocalInstance x (type.instantiate1 x)
else
return false
where
/- (quick filter) Return true if `type` constains a binder `[C ...]` where `C` is a class containing outparams. -/
hasLocalInstaceWithOutParams (type : Expr) : CoreM Bool := do
let .forallE _ d b c := type | return false
if c.binderInfo.isInstImplicit then
if let .const declName .. := d.getAppFn then
if hasOutParams ( getEnv) declName then
return true
hasLocalInstaceWithOutParams b
isOutParamOfLocalInstance (x : Expr) (type : Expr) : MetaM Bool := do
let .forallE _ d b c := type | return false
if c.binderInfo.isInstImplicit then
if let .const declName .. := d.getAppFn then
if hasOutParams ( getEnv) declName then
let cType inferType d.getAppFn
if ( isOutParamOf x 0 d.getAppArgs cType) then
return true
isOutParamOfLocalInstance x b
isOutParamOf (x : Expr) (i : Nat) (args : Array Expr) (cType : Expr) : MetaM Bool := do
if h : i < args.size then
match ( whnf cType) with
| .forallE _ d b _ =>
let arg := args.get i, h
if arg == x && d.isOutParam then
return true
isOutParamOf x (i+1) args b
| _ => return false
else
return false
mutual
/-
@@ -374,7 +434,10 @@ mutual
private partial def addImplicitArg (argName : Name) : M Expr := do
let argType getArgExpectedType
let arg mkFreshExprMVar argType
let arg if ( isNextOutParamOfLocalInstance) then
mkFreshExprMVar argType .synthetic
else
mkFreshExprMVar argType
modify fun s => { s with toSetErrorCtx := s.toSetErrorCtx.push arg.mvarId! }
addNewArg argName arg
main

View File

@@ -292,11 +292,11 @@ def elabBinOp : TermElab := fun stx expectedType? => do
trace[Elab.binop] "hasUncomparable: {r.hasUncomparable}, maxType: {r.max?}"
if r.hasUncomparable || r.max?.isNone then
let result toExpr tree
ensureHasType expectedType? result
ensureHasType expectedType? result (coeAtSyntheticMVar := false)
else
let result toExpr ( applyCoe tree r.max?.get!)
trace[Elab.binop] "result: {result}"
ensureHasType expectedType? result
ensureHasType expectedType? result (coeAtSyntheticMVar := false)
@[builtinTermElab binop_lazy]
def elabBinOpLazy : TermElab := elabBinOp
@@ -324,7 +324,7 @@ def elabBinRelCore (noProp : Bool) (stx : Syntax) (expectedType? : Option Expr)
let lhs toBoolIfNecessary lhs
let rhs toBoolIfNecessary rhs
let lhsType inferType lhs
let rhs ensureHasType lhsType rhs
let rhs ensureHasType lhsType rhs (coeAtSyntheticMVar := false)
elabAppArgs f #[] #[Arg.expr lhs, Arg.expr rhs] expectedType? (explicit := false) (ellipsis := false)
else
let mut maxType := r.max?.get!
@@ -342,7 +342,7 @@ where
if noProp then
-- We use `withNewMCtxDepth` to make sure metavariables are not assigned
if ( withNewMCtxDepth <| isDefEq ( inferType e) (mkSort levelZero)) then
return ( ensureHasType (Lean.mkConst ``Bool) e)
return ( ensureHasType (Lean.mkConst ``Bool) e (coeAtSyntheticMVar := false))
return e
@[builtinTermElab binrel] def elabBinRel : TermElab := elabBinRelCore false
@@ -404,7 +404,7 @@ def elabBinCalc : TermElab := fun stx expectedType? => do
throwErrorAt stepStxs[i]! "invalid 'calc' step, step result is not a relation{indentExpr resultType}"
| _ => throwErrorAt stepStxs[i]! "invalid 'calc' step, failed to synthesize `Trans` instance{indentExpr selfType}"
pure ()
ensureHasType expectedType? result
ensureHasType expectedType? result (coeAtSyntheticMVar := false)
@[builtinTermElab defaultOrOfNonempty]
def elabDefaultOrNonempty : TermElab := fun stx expectedType? => do

View File

@@ -71,10 +71,14 @@ private def synthesizePendingInstMVar (instMVar : MVarId) : TermElabM Bool :=
If `mvar` can be synthesized, then assign `auxMVarId := (expandCoe eNew)`.
-/
private def synthesizePendingCoeInstMVar
(auxMVarId : MVarId) (errorMsgHeader? : Option String) (eNew : Expr) (expectedType : Expr) (eType : Expr) (e : Expr) (f? : Option Expr) : TermElabM Bool := do
(auxMVarId : MVarId) (errorMsgHeader? : Option String) (eNew : Expr) (expectedType : Expr) (eType : Expr) (e : Expr) (f? : Option Expr) (postponeOnSyntheticCoe : Bool) : TermElabM Bool := do
let instMVarId := eNew.appArg!.mvarId!
withMVarContext instMVarId do
if ( isDefEq expectedType eType) then
let eType instantiateMVars eType
if postponeOnSyntheticCoe then
if ( isSyntheticMVar eType) then
return false
if ( withDefault (isDefEq expectedType eType)) then
/- This case may seem counterintuitive since we created the coercion
because the `isDefEq expectedType eType` test failed before.
However, it may succeed here because we have more information, for example, metavariables
@@ -312,11 +316,11 @@ mutual
reportUnsolvedGoals remainingGoals
/-- Try to synthesize the given pending synthetic metavariable. -/
private partial def synthesizeSyntheticMVar (mvarSyntheticDecl : SyntheticMVarDecl) (postponeOnError : Bool) (runTactics : Bool) : TermElabM Bool :=
private partial def synthesizeSyntheticMVar (mvarSyntheticDecl : SyntheticMVarDecl) (postponeOnError : Bool) (runTactics : Bool) (postponeOnSyntheticCoe : Bool) : TermElabM Bool :=
withRef mvarSyntheticDecl.stx do
match mvarSyntheticDecl.kind with
| SyntheticMVarKind.typeClass => synthesizePendingInstMVar mvarSyntheticDecl.mvarId
| SyntheticMVarKind.coe header? eNew expectedType eType e f? => synthesizePendingCoeInstMVar mvarSyntheticDecl.mvarId header? eNew expectedType eType e f?
| SyntheticMVarKind.coe header? eNew expectedType eType e f? => synthesizePendingCoeInstMVar mvarSyntheticDecl.mvarId header? eNew expectedType eType e f? (postponeOnSyntheticCoe := postponeOnSyntheticCoe)
-- NOTE: actual processing at `synthesizeSyntheticMVarsAux`
| SyntheticMVarKind.postponed savedContext => resumePostponed savedContext mvarSyntheticDecl.stx mvarSyntheticDecl.mvarId postponeOnError
| SyntheticMVarKind.tactic tacticCode savedContext =>
@@ -329,7 +333,7 @@ mutual
/--
Try to synthesize the current list of pending synthetic metavariables.
Return `true` if at least one of them was synthesized. -/
private partial def synthesizeSyntheticMVarsStep (postponeOnError : Bool) (runTactics : Bool) : TermElabM Bool := do
private partial def synthesizeSyntheticMVarsStep (postponeOnError : Bool) (runTactics : Bool) (postponeOnSyntheticCoe : Bool) : TermElabM Bool := do
let ctx read
traceAtCmdPos `Elab.resuming fun _ =>
m!"resuming synthetic metavariables, mayPostpone: {ctx.mayPostpone}, postponeOnError: {postponeOnError}"
@@ -343,7 +347,7 @@ mutual
let remainingSyntheticMVars syntheticMVars.filterRevM fun mvarDecl => do
-- We use `traceM` because we want to make sure the metavar local context is used to trace the message
traceM `Elab.postpone (withMVarContext mvarDecl.mvarId do addMessageContext m!"resuming {mkMVar mvarDecl.mvarId}")
let succeeded synthesizeSyntheticMVar mvarDecl postponeOnError runTactics
let succeeded synthesizeSyntheticMVar mvarDecl postponeOnError runTactics postponeOnSyntheticCoe
trace[Elab.postpone] if succeeded then format "succeeded" else format "not ready yet"
pure !succeeded
-- Merge new synthetic metavariables with `remainingSyntheticMVars`, i.e., metavariables that still couldn't be synthesized
@@ -367,7 +371,7 @@ mutual
let rec loop (_ : Unit) : TermElabM Unit := do
withRef ( getSomeSynthethicMVarsRef) <| withIncRecDepth do
unless ( get).syntheticMVars.isEmpty do
if synthesizeSyntheticMVarsStep (postponeOnError := false) (runTactics := false) then
if synthesizeSyntheticMVarsStep (postponeOnError := false) (runTactics := false) (postponeOnSyntheticCoe := true) then
loop ()
else if !mayPostpone then
/- Resume pending metavariables with "elaboration postponement" disabled.
@@ -386,13 +390,13 @@ mutual
We the type of `x` may learn later its type and it may contain implicit and/or auto arguments.
By disabling postponement, we are essentially giving up the opportunity of learning `x`s type
and assume it does not have implict and/or auto arguments. -/
if withoutPostponing <| synthesizeSyntheticMVarsStep (postponeOnError := true) (runTactics := false) then
if withoutPostponing <| synthesizeSyntheticMVarsStep (postponeOnError := true) (runTactics := false) (postponeOnSyntheticCoe := true) then
loop ()
else if synthesizeUsingDefault then
loop ()
else if withoutPostponing <| synthesizeSyntheticMVarsStep (postponeOnError := false) (runTactics := false) then
else if withoutPostponing <| synthesizeSyntheticMVarsStep (postponeOnError := false) (runTactics := false) (postponeOnSyntheticCoe := false) then
loop ()
else if synthesizeSyntheticMVarsStep (postponeOnError := false) (runTactics := true) then
else if synthesizeSyntheticMVarsStep (postponeOnError := false) (runTactics := true) (postponeOnSyntheticCoe := true) then
loop ()
else
reportStuckSyntheticMVars ignoreStuckTC

View File

@@ -949,10 +949,14 @@ private def tryLiftAndCoe (errorMsgHeader? : Option String) (expectedType : Expr
Argument `f?` is used only for generating error messages. -/
def ensureHasTypeAux (expectedType? : Option Expr) (eType : Expr) (e : Expr)
(f? : Option Expr := none) (errorMsgHeader? : Option String := none) : TermElabM Expr := do
(f? : Option Expr := none) (errorMsgHeader? : Option String := none) (coeAtSyntheticMVar : Bool := true) : TermElabM Expr := do
match expectedType? with
| none => return e
| some expectedType =>
if coeAtSyntheticMVar then
if ( isSyntheticMVar eType) && !( read).inPattern && ( getEnv).contains ``Lean.Internal.coeM then
if !e.isAppOf ``OfNat.ofNat then
return ( mkCoe expectedType eType e f? errorMsgHeader?)
if ( isDefEq eType expectedType) then
return e
else
@@ -961,12 +965,12 @@ def ensureHasTypeAux (expectedType? : Option Expr) (eType : Expr) (e : Expr)
/--
If `expectedType?` is `some t`, then ensure `t` and type of `e` are definitionally equal.
If they are not, then try coercions. -/
def ensureHasType (expectedType? : Option Expr) (e : Expr) (errorMsgHeader? : Option String := none) : TermElabM Expr :=
def ensureHasType (expectedType? : Option Expr) (e : Expr) (errorMsgHeader? : Option String := none) (coeAtSyntheticMVar : Bool := true) : TermElabM Expr :=
match expectedType? with
| none => return e
| _ => do
let eType inferType e
ensureHasTypeAux expectedType? eType e none errorMsgHeader?
ensureHasTypeAux expectedType? eType e none errorMsgHeader? coeAtSyntheticMVar
/--
Create a synthetic sorry for the given expected type. If `expectedType? = none`, then a fresh

View File

@@ -496,6 +496,16 @@ def getMVarDecl (mvarId : MVarId) : MetaM MetavarDecl := do
| some d => pure d
| none => throwError "unknown metavariable '?{mvarId.name}'"
def getMVarDeclKind (mvarId : MVarId) : MetaM MetavarKind :=
return ( getMVarDecl mvarId).kind
/-- Reture `true` if `e` is a synthetic (or synthetic opaque) metavariable -/
def isSyntheticMVar (e : Expr) : MetaM Bool := do
if e.isMVar then
return ( getMVarDeclKind e.mvarId!) matches .synthetic | .syntheticOpaque
else
return false
def setMVarKind (mvarId : MVarId) (kind : MetavarKind) : MetaM Unit :=
modifyMCtx fun mctx => mctx.setMVarKind mvarId kind

View File

@@ -0,0 +1,22 @@
namespace Ex
class Get (Cont : Type u) (Idx : Type v) (Elem : outParam (Type w)) where
get (xs : Cont) (i : Idx) : Elem
export Get (get)
instance [Inhabited α] : Get (Array α) Nat α where
get xs i := xs.get! i
instance : Coe Bool Nat where
coe b := if b then 1 else 0
def g (as : Array (Array Bool)) : Nat :=
let bs := get as 0
get bs 1
def h (as : Array (Array Bool)) (i : Nat) : Nat :=
let bs := get as i
f (get bs i)
end Ex