Compare commits

...

2 Commits

Author SHA1 Message Date
Leonardo de Moura
c2cfd286c5 feat: fold numeric literals 2024-03-03 18:31:16 -08:00
Leonardo de Moura
98112b73da refactor: use match_expr at evalNat and toLinearExpr
This commit also adds support for `/`, `%`, and `^` at `evalNat`
2024-03-03 18:19:14 -08:00
7 changed files with 129 additions and 57 deletions

View File

@@ -18,9 +18,10 @@ private abbrev withInstantiatedMVars (e : Expr) (k : Expr → OptionT MetaM α)
k eNew
def isNatProjInst (declName : Name) (numArgs : Nat) : Bool :=
(numArgs == 4 && (declName == ``Add.add || declName == ``Sub.sub || declName == ``Mul.mul))
|| (numArgs == 6 && (declName == ``HAdd.hAdd || declName == ``HSub.hSub || declName == ``HMul.hMul))
|| (numArgs == 3 && declName == ``OfNat.ofNat)
(numArgs == 4 && (declName == ``Add.add || declName == ``Sub.sub || declName == ``Mul.mul || declName == ``Div.div || declName == ``Mod.mod || declName == ``NatPow.pow))
|| (numArgs == 5 && (declName == ``Pow.pow))
|| (numArgs == 6 && (declName == ``HAdd.hAdd || declName == ``HSub.hSub || declName == ``HMul.hMul || declName == ``HDiv.hDiv || declName == ``HMod.hMod || declName == ``HPow.hPow))
|| (numArgs == 3 && declName == ``OfNat.ofNat)
/--
Evaluate simple `Nat` expressions.
@@ -35,31 +36,21 @@ partial def evalNat (e : Expr) : OptionT MetaM Nat := do
| _ => failure
where
visit e := do
let f := e.getAppFn
match f with
| .mvar .. => withInstantiatedMVars e evalNat
| .const c _ =>
let nargs := e.getAppNumArgs
if c == ``Nat.succ && nargs == 1 then
let v evalNat (e.getArg! 0)
return v+1
else if c == ``Nat.add && nargs == 2 then
let v₁ evalNat (e.getArg! 0)
let v₂ evalNat (e.getArg! 1)
return v₁ + v₂
else if c == ``Nat.sub && nargs == 2 then
let v₁ evalNat (e.getArg! 0)
let v₂ evalNat (e.getArg! 1)
return v₁ - v₂
else if c == ``Nat.mul && nargs == 2 then
let v₁ evalNat (e.getArg! 0)
let v₂ evalNat (e.getArg! 1)
return v₁ * v₂
else if isNatProjInst c nargs then
match_expr e with
| Nat.succ a => return ( evalNat a) + 1
| Nat.add a b => return ( evalNat a) + ( evalNat b)
| Nat.sub a b => return ( evalNat a) - ( evalNat b)
| Nat.mul a b => return ( evalNat a) * ( evalNat b)
| Nat.div a b => return ( evalNat a) / ( evalNat b)
| Nat.mod a b => return ( evalNat a) % ( evalNat b)
| Nat.pow a b => return ( evalNat a) ^ ( evalNat b)
| _ =>
let e instantiateMVarsIfMVarApp e
let f := e.getAppFn
if f.isConst && isNatProjInst f.constName! e.getAppNumArgs then
evalNat ( unfoldProjInst? e)
else
failure
| _ => failure
mutual

View File

@@ -74,42 +74,31 @@ def addAsVar (e : Expr) : M LinearExpr := do
partial def toLinearExpr (e : Expr) : M LinearExpr := do
match e with
| Expr.lit (Literal.natVal n) => return num n
| Expr.mdata _ e => toLinearExpr e
| Expr.const ``Nat.zero .. => return num 0
| Expr.app .. => visit e
| Expr.mvar .. => visit e
| _ => addAsVar e
| .lit (.natVal n) => return num n
| .mdata _ e => toLinearExpr e
| .const ``Nat.zero .. => return num 0
| .app .. => visit e
| .mvar .. => visit e
| _ => addAsVar e
where
visit (e : Expr) : M LinearExpr := do
let f := e.getAppFn
match f with
| Expr.mvar .. =>
let eNew instantiateMVars e
if eNew != e then
toLinearExpr eNew
match_expr e with
| Nat.succ a => return inc ( toLinearExpr a)
| Nat.add a b => return add ( toLinearExpr a) ( toLinearExpr b)
| Nat.mul a b =>
match ( evalNat a |>.run) with
| some k => return mulL k ( toLinearExpr b)
| none => match ( evalNat b |>.run) with
| some k => return mulR ( toLinearExpr a) k
| none => addAsVar e
| _ =>
let e instantiateMVarsIfMVarApp e
let f := e.getAppFn
if f.isConst && isNatProjInst f.constName! e.getAppNumArgs then
let some e unfoldProjInst? e | addAsVar e
toLinearExpr e
else
addAsVar e
| Expr.const declName .. =>
let numArgs := e.getAppNumArgs
if declName == ``Nat.succ && numArgs == 1 then
return inc ( toLinearExpr e.appArg!)
else if declName == ``Nat.add && numArgs == 2 then
return add ( toLinearExpr (e.getArg! 0)) ( toLinearExpr (e.getArg! 1))
else if declName == ``Nat.mul && numArgs == 2 then
match ( evalNat (e.getArg! 0) |>.run) with
| some k => return mulL k ( toLinearExpr (e.getArg! 1))
| none => match ( evalNat (e.getArg! 1) |>.run) with
| some k => return mulR ( toLinearExpr (e.getArg! 0)) k
| none => addAsVar e
else if isNatProjInst declName numArgs then
if let some e unfoldProjInst? e then
toLinearExpr e
else
addAsVar e
else
addAsVar e
| _ => addAsVar e
partial def toLinearCnstr? (e : Expr) : M (Option LinearCnstr) := do
let f := e.getAppFn

View File

@@ -268,4 +268,15 @@ builtin_simproc [simp, seval] reduceAllOnes (allOnes _) := fun e => do
let some n Nat.fromExpr? n | return .continue
return .done { expr := toExpr (allOnes n) }
builtin_simproc [simp, seval] reduceBitVecOfFin (BitVec.ofFin _) := fun e => do
let_expr BitVec.ofFin w v e | return .continue
let some w evalNat w |>.run | return .continue
let some _, v getFinValue? v | return .continue
return .done { expr := toExpr (BitVec.ofNat w v.val) }
builtin_simproc [simp, seval] reduceBitVecToFin (BitVec.toFin _) := fun e => do
let_expr BitVec.toFin _ v e | return .continue
let some _, v getBitVecValue? v | return .continue
return .done { expr := toExpr v.toFin }
end BitVec

View File

@@ -71,4 +71,13 @@ builtin_simproc [simp, seval] isValue ((OfNat.ofNat _ : Fin _)) := fun e => do
return .done { expr := e }
return .done { expr := toExpr v }
builtin_simproc [simp, seval] reduceFinMk (Fin.mk _ _) := fun e => do
let_expr Fin.mk n v _ e | return .continue
let some n evalNat n |>.run | return .continue
let some v getNatValue? v | return .continue
if h : n > 0 then
return .done { expr := toExpr (Fin.ofNat' v h) }
else
return .continue
end Fin

View File

@@ -89,4 +89,14 @@ builtin_simproc [simp, seval] reduceBNe (( _ : Int) != _) := reduceBoolPred ``
builtin_simproc [simp, seval] reduceAbs (natAbs _) := reduceNatCore ``natAbs natAbs
builtin_simproc [simp, seval] reduceToNat (Int.toNat _) := reduceNatCore ``Int.toNat Int.toNat
builtin_simproc [simp, seval] reduceNegSucc (Int.negSucc _) := fun e => do
let_expr Int.negSucc a e | return .continue
let some a getNatValue? a | return .continue
return .done { expr := toExpr (-(Int.ofNat a + 1)) }
builtin_simproc [simp, seval] reduceOfNat (Int.ofNat _) := fun e => do
let_expr Int.ofNat a e | return .continue
let some a getNatValue? a | return .continue
return .done { expr := toExpr (Int.ofNat a) }
end Int

View File

@@ -60,6 +60,12 @@ builtin_simproc [simp, seval] $(mkIdent `reduceOfNatCore):ident ($ofNatCore _ _)
let value := $(mkIdent ofNat) value
return .done { expr := toExpr value }
builtin_simproc [simp, seval] $(mkIdent `reduceOfNat):ident ($(mkIdent ofNat) _) := fun e => do
unless e.isAppOfArity $(quote ofNat) 1 do return .continue
let some value Nat.fromExpr? e.appArg! | return .continue
let value := $(mkIdent ofNat) value
return .done { expr := toExpr value }
builtin_simproc [simp, seval] $(mkIdent `reduceToNat):ident ($toNat _) := fun e => do
unless e.isAppOfArity $(quote toNat.getId) 1 do return .continue
let some v ($fromExpr e.appArg!) | return .continue

View File

@@ -0,0 +1,56 @@
open BitVec
example : (Fin.mk 5 (by decide) : Fin 10) + 2 = x := by
simp
guard_target = 7 = x
sorry
example : (Fin.mk 5 (by decide) : Fin 10) + 2 = x := by
simp (config := { ground := true }) only
guard_target = 7 = x
sorry
example : (BitVec.ofFin (Fin.mk 2 (by decide)) : BitVec 32) + 2 = x := by
simp
guard_target = 4#32 = x
sorry
example : (BitVec.ofFin (Fin.mk 2 (by decide)) : BitVec 32) + 2 = x := by
simp (config := { ground := true }) only
guard_target = 4#32 = x
sorry
example : (BitVec.ofFin 2 : BitVec 32) + 2 = x := by
simp
guard_target = 4#32 = x
sorry
example (h : -2 = x) : Int.negSucc 3 + 2 = x := by
simp
guard_target = -2 = x
assumption
example (h : -2 = x) : Int.negSucc 3 + 2 = x := by
simp (config := { ground := true }) only
guard_target = -2 = x
assumption
example : Int.ofNat 3 + 2 = x := by
simp
guard_target = 5 = x
sorry
example : Int.ofNat 3 + 2 = x := by
simp (config := { ground := true }) only
guard_target = 5 = x
sorry
example (h : 5 = x) : UInt32.ofNat 2 + 3 = x := by
simp
guard_target = 5 = x
assumption
example (h : 5 = x) : UInt32.ofNat 2 + 3 = x := by
simp (config := { ground := true }) only
guard_target = 5 = x
assumption