mirror of
https://github.com/leanprover/lean4.git
synced 2026-03-17 18:34:06 +00:00
feat: support expected type annotation in doPatDecl
This PR allows `let ⟨width, height⟩ : Nat × Nat ← action` in do-notation, propagating the expected type to the monadic action. Previously, only `let ⟨width, height⟩ : Nat × Nat := ← action` was supported, requiring the less ergonomic `:= ←` workaround. The type annotation is added as `optType` in the `doPatDecl` parser, matching `doIdDecl`'s existing support. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -145,13 +145,13 @@ def elabDoArrow (letOrReassign : LetOrReassign) (stx : TSyntax [``doIdDecl, ``do
|
||||
| _, _ => pure xType?
|
||||
elabDoIdDecl x xType? rhs (declareMutVar? letOrReassign.getLetMutTk? x <| dec.continueWithUnit)
|
||||
(kind := dec.kind)
|
||||
| `(doPatDecl| _%$pattern ← $rhs) =>
|
||||
| `(doPatDecl| _%$pattern $[: $patType?]? ← $rhs) =>
|
||||
let x := mkIdentFrom pattern (← mkFreshUserName `__x)
|
||||
elabDoIdDecl x none rhs dec.continueWithUnit (kind := dec.kind)
|
||||
| `(doPatDecl| $pattern:term ← $rhs $[| $otherwise? $(rest?)?]?) =>
|
||||
elabDoIdDecl x patType? rhs dec.continueWithUnit (kind := dec.kind)
|
||||
| `(doPatDecl| $pattern:term $[: $patType?]? ← $rhs $[| $otherwise? $(rest?)?]?) =>
|
||||
let rest? := rest?.join
|
||||
let x := mkIdentFrom pattern (← mkFreshUserName `__x)
|
||||
elabDoIdDecl x none rhs do
|
||||
elabDoIdDecl x patType? rhs do
|
||||
match letOrReassign, otherwise? with
|
||||
| .let mutTk?, some otherwise =>
|
||||
elabDoElem (← `(doElem| let $[mut%$mutTk?]? $pattern:term := $x | $otherwise $(rest?)?)) dec
|
||||
|
||||
@@ -184,7 +184,7 @@ partial def ofLetOrReassignArrow (reassignment : Bool) (decl : TSyntax [``doIdDe
|
||||
| `(doIdDecl| $x:ident $[: $_]? ← $rhs) =>
|
||||
let reassigns := if reassignment then #[x] else #[]
|
||||
ofLetOrReassign reassigns rhs none none
|
||||
| `(doPatDecl| $pattern ← $rhs $[| $otherwise? $[$body??]?]?) =>
|
||||
| `(doPatDecl| $pattern $[: $_]? ← $rhs $[| $otherwise? $[$body??]?]?) =>
|
||||
let reassigns ← if reassignment then getPatternVarsEx pattern else pure #[]
|
||||
ofLetOrReassign reassigns rhs otherwise? body??.join
|
||||
| _ => throwError "Not a let or reassignment declaration: {toString decl}"
|
||||
|
||||
@@ -722,7 +722,7 @@ def getDoLetRecVars (doLetRec : Syntax) : TermElabM (Array Var) := do
|
||||
def getDoIdDeclVar (doIdDecl : Syntax) : Var :=
|
||||
doIdDecl[0]
|
||||
|
||||
-- termParser >> leftArrow >> termParser >> optional (" | " >> termParser)
|
||||
-- termParser >> optType >> leftArrow >> termParser >> optional (" | " >> termParser)
|
||||
def getDoPatDeclVars (doPatDecl : Syntax) : TermElabM (Array Var) := do
|
||||
let pattern := doPatDecl[0]
|
||||
getPatternVarsEx pattern
|
||||
@@ -1420,7 +1420,7 @@ mutual
|
||||
where
|
||||
```
|
||||
def doIdDecl := leading_parser ident >> optType >> leftArrow >> doElemParser
|
||||
def doPatDecl := leading_parser termParser >> leftArrow >> doElemParser >> optional ((" | " >> doSeq) >> optional doSeq)
|
||||
def doPatDecl := leading_parser termParser >> optType >> leftArrow >> doElemParser >> optional ((" | " >> doSeq) >> optional doSeq)
|
||||
```
|
||||
-/
|
||||
partial def doLetArrowToCode (doLetArrow : Syntax) (doElems : List Syntax) : M CodeBlock := do
|
||||
@@ -1440,13 +1440,21 @@ mutual
|
||||
| kRef::_ => concat c kRef y k
|
||||
else if decl.getKind == ``Parser.Term.doPatDecl then
|
||||
let pattern := decl[0]
|
||||
let doElem := decl[2]
|
||||
let optElse := decl[3]
|
||||
let optType := decl[1]
|
||||
let doElem := decl[3]
|
||||
let optElse := decl[4]
|
||||
if optElse.isNone then withFreshMacroScope do
|
||||
let auxDo ← if isMutableLet doLetArrow then
|
||||
`(do let%$doLetArrow __discr ← $doElem; let%$doLetArrow mut $pattern:term := __discr)
|
||||
let auxDo ← if optType.isNone then
|
||||
if isMutableLet doLetArrow then
|
||||
`(do let%$doLetArrow __discr ← $doElem; let%$doLetArrow mut $pattern:term := __discr)
|
||||
else
|
||||
`(do let%$doLetArrow __discr ← $doElem; let%$doLetArrow $pattern:term := __discr)
|
||||
else
|
||||
`(do let%$doLetArrow __discr ← $doElem; let%$doLetArrow $pattern:term := __discr)
|
||||
let ty := optType[0][1]
|
||||
if isMutableLet doLetArrow then
|
||||
`(do let%$doLetArrow __discr : $ty ← $doElem; let%$doLetArrow mut $pattern:term := __discr)
|
||||
else
|
||||
`(do let%$doLetArrow __discr : $ty ← $doElem; let%$doLetArrow $pattern:term := __discr)
|
||||
doSeqToCode <| getDoSeqElems (getDoSeq auxDo) ++ doElems
|
||||
else
|
||||
let elseSeq := optElse[1]
|
||||
@@ -1457,7 +1465,11 @@ mutual
|
||||
else
|
||||
pure (getDoSeqElems contSeq).toArray
|
||||
let contSeq := mkDoSeq contSeq
|
||||
let auxDo ← `(do let%$doLetArrow __discr ← $doElem; match%$doLetArrow __discr with | $pattern:term => $contSeq | _ => $elseSeq)
|
||||
let auxDo ← if optType.isNone then
|
||||
`(do let%$doLetArrow __discr ← $doElem; match%$doLetArrow __discr with | $pattern:term => $contSeq | _ => $elseSeq)
|
||||
else
|
||||
let ty := optType[0][1]
|
||||
`(do let%$doLetArrow __discr : $ty ← $doElem; match%$doLetArrow __discr with | $pattern:term => $contSeq | _ => $elseSeq)
|
||||
doSeqToCode <| getDoSeqElems (getDoSeq auxDo) ++ doElems
|
||||
else
|
||||
throwError "unexpected kind of `do` declaration"
|
||||
@@ -1492,10 +1504,15 @@ mutual
|
||||
doSeqToCode <| getDoSeqElems (getDoSeq auxDo) ++ doElems
|
||||
else if decl.getKind == ``Parser.Term.doPatDecl then
|
||||
let pattern := decl[0]
|
||||
let doElem := decl[2]
|
||||
let optElse := decl[3]
|
||||
let optType := decl[1]
|
||||
let doElem := decl[3]
|
||||
let optElse := decl[4]
|
||||
if optElse.isNone then withFreshMacroScope do
|
||||
let auxDo ← `(do let __discr ← $doElem; $pattern:term := __discr)
|
||||
let auxDo ← if optType.isNone then
|
||||
`(do let __discr ← $doElem; $pattern:term := __discr)
|
||||
else
|
||||
let ty := optType[0][1]
|
||||
`(do let __discr : $ty ← $doElem; $pattern:term := __discr)
|
||||
doSeqToCode <| getDoSeqElems (getDoSeq auxDo) ++ doElems
|
||||
else
|
||||
throwError "reassignment with `|` (i.e., \"else clause\") is not currently supported"
|
||||
|
||||
@@ -86,7 +86,7 @@ def doIdDecl := leading_parser
|
||||
atomic (ident >> optType >> ppSpace >> leftArrow) >>
|
||||
doElemParser
|
||||
def doPatDecl := leading_parser
|
||||
atomic (termParser >> ppSpace >> leftArrow) >>
|
||||
atomic (termParser >> optType >> ppSpace >> leftArrow) >>
|
||||
doElemParser >> optional ((checkColGe >> " | " >> doSeqIndent) >> optional (checkColGe >> doSeqIndent))
|
||||
@[builtin_doElem_parser] def doLetArrow := leading_parser withPosition <|
|
||||
"let " >> optional "mut " >> (doIdDecl <|> doPatDecl)
|
||||
|
||||
67
tests/elab/doLetArrowPatExpectedType.lean
Normal file
67
tests/elab/doLetArrowPatExpectedType.lean
Normal file
@@ -0,0 +1,67 @@
|
||||
-- Test that `let pat : Type ← rhs` works in do-notation
|
||||
|
||||
-- Basic: anonymous constructor pattern with expected type
|
||||
def test1 : Id (Nat × String) := do
|
||||
let ⟨x, y⟩ : Nat × String ← pure ⟨1, "hello"⟩
|
||||
return (x, y)
|
||||
|
||||
example : test1 = (1, "hello") := rfl
|
||||
|
||||
-- Tuple pattern with expected type
|
||||
def test2 : Id (Nat × Nat) := do
|
||||
let (a, b) : Nat × Nat ← pure (2, 3)
|
||||
return (a + b, a * b)
|
||||
|
||||
example : test2 = (5, 6) := rfl
|
||||
|
||||
-- The expected type helps resolve what would otherwise be ambiguous
|
||||
structure Dims where
|
||||
width : Nat
|
||||
height : Nat
|
||||
|
||||
def getDims : Id Dims := pure ⟨800, 600⟩
|
||||
|
||||
def test3 : Id Nat := do
|
||||
let ⟨w, h⟩ : Dims ← getDims
|
||||
return w * h
|
||||
|
||||
example : test3 = 480000 := rfl
|
||||
|
||||
-- With else branch
|
||||
def test4 : Option Nat := do
|
||||
let .some x : Option Nat ← some (some 42)
|
||||
| none
|
||||
return x
|
||||
|
||||
example : test4 = some 42 := rfl
|
||||
|
||||
-- With else branch, taking the else
|
||||
def test5 : Option Nat := do
|
||||
let .some _x : Option Nat ← some none
|
||||
| pure 99
|
||||
return _x
|
||||
|
||||
example : test5 = some 99 := rfl
|
||||
|
||||
-- Mutable let with pattern and type
|
||||
def test6 : Id Nat := do
|
||||
let mut ⟨a, b⟩ : Nat × Nat ← pure (1, 2)
|
||||
a := a + 10
|
||||
b := b + 20
|
||||
return a + b
|
||||
|
||||
example : test6 = 33 := rfl
|
||||
|
||||
-- Without type annotation still works (regression test)
|
||||
def test7 : Id Nat := do
|
||||
let (x, y) ← pure (1, 2)
|
||||
return x + y
|
||||
|
||||
example : test7 = 3 := rfl
|
||||
|
||||
-- Wildcard pattern with type
|
||||
def test8 : Id Nat := do
|
||||
let _ : Nat ← pure 42
|
||||
return 0
|
||||
|
||||
example : test8 = 0 := rfl
|
||||
Reference in New Issue
Block a user