Compare commits

...

2 Commits

Author SHA1 Message Date
Leonardo de Moura
4a0995c027 fix: mdata handling at LitValues.lean 2024-02-25 09:25:24 -08:00
Leonardo de Moura
a6c130b6ae fix: match patterns containing int values and constructors 2024-02-25 06:32:54 -08:00
3 changed files with 63 additions and 12 deletions

View File

@@ -1037,6 +1037,14 @@ def getAppFn : Expr → Expr
| app f _ => getAppFn f
| e => e
/--
Similar to `getAppFn`, but skips `mdata`
-/
def getAppFn' : Expr Expr
| app f _ => getAppFn' f
| mdata _ a => getAppFn' a
| e => e
/-- Given `f a₀ a₁ ... aₙ`, returns true if `f` is a constant with name `n`. -/
def isAppOf (e : Expr) (n : Name) : Bool :=
match e.getAppFn with
@@ -1207,10 +1215,21 @@ def getRevArg! : Expr → Nat → Expr
| app f _, i+1 => getRevArg! f i
| _, _ => panic! "invalid index"
/-- Similar to `getRevArg!` but skips `mdata` -/
def getRevArg!' : Expr Nat Expr
| mdata _ a, i => getRevArg!' a i
| app _ a, 0 => a
| app f _, i+1 => getRevArg!' f i
| _, _ => panic! "invalid index"
/-- Given `f a₀ a₁ ... aₙ`, returns the `i`th argument or panics if out of bounds. -/
@[inline] def getArg! (e : Expr) (i : Nat) (n := e.getAppNumArgs) : Expr :=
getRevArg! e (n - i - 1)
/-- Similar to `getArg!`, but skips mdata -/
@[inline] def getArg!' (e : Expr) (i : Nat) (n := e.getAppNumArgs) : Expr :=
getRevArg!' e (n - i - 1)
/-- Given `f a₀ a₁ ... aₙ`, returns the `i`th argument or returns `v₀` if out of bounds. -/
@[inline] def getArgD (e : Expr) (i : Nat) (v₀ : Expr) (n := e.getAppNumArgs) : Expr :=
getRevArgD e (n - i - 1) v₀

View File

@@ -21,20 +21,22 @@ It also provides support for the following exceptional cases.
/-- Returns `some n` if `e` is a raw natural number, i.e., it is of the form `.lit (.natVal n)`. -/
def getRawNatValue? (e : Expr) : Option Nat :=
match e with
match e.consumeMData with
| .lit (.natVal n) => some n
| _ => none
/-- Return `some (n, type)` if `e` is an `OfNat.ofNat`-application encoding `n` for a type with name `typeDeclName`. -/
def getOfNatValue? (e : Expr) (typeDeclName : Name) : MetaM (Option (Nat × Expr)) := OptionT.run do
let e := e.consumeMData
guard <| e.isAppOfArity' ``OfNat.ofNat 3
let type whnfD e.appFn!.appFn!.appArg!
let type whnfD (e.getArg!' 0)
guard <| type.getAppFn.isConstOf typeDeclName
let .lit (.natVal n) := e.appFn!.appArg! | failure
let .lit (.natVal n) := (e.getArg!' 1).consumeMData | failure
return (n, type)
/-- Return `some n` if `e` is a raw natural number or an `OfNat.ofNat`-application encoding `n`. -/
def getNatValue? (e : Expr) : MetaM (Option Nat) := do
let e := e.consumeMData
if let some n := getRawNatValue? e then
return some n
let some (n, _) getOfNatValue? e ``Nat | return none
@@ -45,14 +47,14 @@ def getIntValue? (e : Expr) : MetaM (Option Int) := do
if let some (n, _) getOfNatValue? e ``Int then
return some n
if e.isAppOfArity' ``Neg.neg 3 then
let some (n, _) getOfNatValue? e.appArg!.consumeMData ``Int | return none
let some (n, _) getOfNatValue? (e.getArg!' 2) ``Int | return none
return some (-n)
return none
/-- Return `some c` if `e` is a `Char.ofNat`-application encoding character `c`. -/
def getCharValue? (e : Expr) : MetaM (Option Char) := OptionT.run do
guard <| e.isAppOfArity' ``Char.ofNat 1
let n getNatValue? e.appArg!.consumeMData
let n getNatValue? (e.getArg!' 0)
return Char.ofNat n
/-- Return `some s` if `e` is of the form `.lit (.strVal s)`. -/
@@ -72,8 +74,8 @@ def getFinValue? (e : Expr) : MetaM (Option ((n : Nat) × Fin n)) := OptionT.run
/-- Return `some ⟨n, v⟩` if `e` is af `OfNat.ofNat` application encoding a `BitVec n` with value `v` -/
def getBitVecValue? (e : Expr) : MetaM (Option ((n : Nat) × BitVec n)) := OptionT.run do
if e.isAppOfArity' ``BitVec.ofNat 2 then
let n getNatValue? e.appFn!.appArg!.consumeMData
let v getNatValue? e.appArg!.consumeMData
let n getNatValue? (e.getArg!' 0)
let v getNatValue? (e.getArg!' 1)
return n, BitVec.ofNat n v
let (v, type) getOfNatValue? e ``BitVec
IO.println v

View File

@@ -101,6 +101,12 @@ private def hasNatValPattern (p : Problem) : MetaM Bool :=
| .val v :: _ => return ( getNatValue? v).isSome
| _ => return false
private def hasIntValPattern (p : Problem) : MetaM Bool :=
p.alts.anyM fun alt => do
match alt.patterns with
| .val v :: _ => return ( getIntValue? v).isSome
| _ => return false
private def hasVarPattern (p : Problem) : Bool :=
p.alts.any fun alt => match alt.patterns with
| .var _ :: _ => true
@@ -148,13 +154,20 @@ private def isArrayLitTransition (p : Problem) : Bool :=
| .var _ :: _ => true
| _ => false
private def hasCtorOrInaccessible (p : Problem) : Bool :=
!isNextVar p ||
p.alts.any fun alt => match alt.patterns with
| .ctor .. :: _ => true
| .inaccessible _ :: _ => true
| _ => false
private def isNatValueTransition (p : Problem) : MetaM Bool := do
unless ( hasNatValPattern p) do return false
return !isNextVar p ||
p.alts.any fun alt => match alt.patterns with
| .ctor .. :: _ => true
| .inaccessible _ :: _ => true
| _ => false
return hasCtorOrInaccessible p
private def isIntValueTransition (p : Problem) : MetaM Bool := do
unless ( hasIntValPattern p) do return false
return hasCtorOrInaccessible p
private def processSkipInaccessible (p : Problem) : Problem := Id.run do
let x :: xs := p.vars | unreachable!
@@ -606,6 +619,20 @@ private def expandNatValuePattern (p : Problem) : MetaM Problem := do
| _ => return alt
return { p with alts := alts }
private def expandIntValuePattern (p : Problem) : MetaM Problem := do
let alts p.alts.mapM fun alt => do
match alt.patterns with
| .val n :: ps =>
match ( getIntValue? n) with
| some i =>
if i >= 0 then
return { alt with patterns := .ctor ``Int.ofNat [] [] [.val (toExpr i.toNat)] :: ps }
else
return { alt with patterns := .ctor ``Int.negSucc [] [] [.val (toExpr (-(i + 1)).toNat)] :: ps }
| _ => return alt
| _ => return alt
return { p with alts := alts }
private def expandFinValuePattern (p : Problem) : MetaM Problem := do
let alts p.alts.mapM fun alt => do
match alt.patterns with
@@ -665,6 +692,9 @@ private partial def process (p : Problem) : StateRefT State MetaM Unit := do
else if ( isNatValueTransition p) then
traceStep ("nat value to constructor")
process ( expandNatValuePattern p)
else if ( isIntValueTransition p) then
traceStep ("int value to constructor")
process ( expandIntValuePattern p)
else if ( isFinValueTransition p) then
traceStep ("fin value to constructor")
process ( expandFinValuePattern p)