Compare commits

...

1 Commits

Author SHA1 Message Date
Scott Morrison
d8abaa5f34 feat: omega handles shift operators, and normalises ground term exponentials 2024-02-21 22:40:20 +11:00
4 changed files with 92 additions and 11 deletions

View File

@@ -5,6 +5,8 @@ Authors: Scott Morrison
-/
prelude
import Init.Data.Int.Order
import Init.Data.Int.DivModLemmas
import Init.Data.Nat.Lemmas
/-!
# Lemmas about `Nat` and `Int` needed internally by `omega`.
@@ -43,6 +45,12 @@ theorem ofNat_lt_of_lt {x y : Nat} (h : x < y) : (x : Int) < (y : Int) :=
theorem ofNat_le_of_le {x y : Nat} (h : x y) : (x : Int) (y : Int) :=
Int.ofNat_le.mpr h
theorem ofNat_shiftLeft_eq {x y : Nat} : (x <<< y : Int) = (x : Int) * (2 ^ y : Nat) := by
simp [Nat.shiftLeft_eq]
theorem ofNat_shiftRight_eq_div_pow {x y : Nat} : (x >>> y : Int) = (x : Int) / (2 ^ y : Nat) := by
simp [Nat.shiftRight_eq_div_pow]
-- FIXME these are insane:
theorem lt_of_not_ge {x y : Int} (h : ¬ (x y)) : y < x := Int.not_le.mp h
theorem lt_of_not_le {x y : Int} (h : ¬ (x y)) : y < x := Int.not_le.mp h

View File

@@ -24,6 +24,24 @@ Allow elaboration of `OmegaConfig` arguments to tactics.
declare_config_elab elabOmegaConfig Lean.Meta.Omega.OmegaConfig
/--
The current `ToExpr` instance for `Int` is bad,
so we roll our own here.
-/
def mkInt (i : Int) : Expr :=
if 0 i then
mkNat i.toNat
else
mkApp3 (.const ``Neg.neg [0]) (.const ``Int []) (mkNat (-i).toNat)
(.const ``Int.instNegInt [])
where
mkNat (n : Nat) : Expr :=
let r := mkRawNatLit n
mkApp3 (.const ``OfNat.ofNat [0]) (.const ``Int []) r
(.app (.const ``instOfNat []) r)
/--
A partially processed `omega` context.
@@ -121,7 +139,7 @@ We also transform the expression as we descend into it:
-/
partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × HashSet Expr) := do
trace[omega] "processing {e}"
match e.int? with
match groundInt? e with
| some i =>
let lc := {const := i}
return lc, mkEvalRflProof e lc,
@@ -184,17 +202,20 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
| some r => pure r
| none => mkAtomLinearCombo e
| (``HMod.hMod, #[_, _, _, _, n, k]) =>
match natCast? k with
| some _ => rewrite e (mkApp2 (.const ``Int.emod_def []) n k)
match groundNat? k with
| some k' => do
let k' := mkInt k'
rewrite ( mkAppM ``HMod.hMod #[n, k']) (mkApp2 (.const ``Int.emod_def []) n k')
| none => mkAtomLinearCombo e
| (``HDiv.hDiv, #[_, _, _, _, x, z]) =>
match intCast? z with
match groundInt? z with
| some 0 => rewrite e (mkApp (.const ``Int.ediv_zero []) x)
| some i =>
| some i => do
let e' mkAppM ``HDiv.hDiv #[x, mkInt i]
if i < 0 then
rewrite e (mkApp2 (.const ``Int.ediv_neg []) x (toExpr (-i)))
rewrite e' (mkApp2 (.const ``Int.ediv_neg []) x (mkInt (-i)))
else
mkAtomLinearCombo e
mkAtomLinearCombo e'
| _ => mkAtomLinearCombo e
| (``Min.min, #[_, _, a, b]) =>
if ( cfg).splitMinMax then
@@ -223,6 +244,9 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
| (``HMod.hMod, #[_, _, _, _, a, b]) => rewrite e (mkApp2 (.const ``Int.ofNat_emod []) a b)
| (``HSub.hSub, #[_, _, _, _, mkAppN (.const ``HSub.hSub _) #[_, _, _, _, a, b], c]) =>
rewrite e (mkApp3 (.const ``Int.ofNat_sub_sub []) a b c)
| (``HPow.hPow, #[_, _, _, _, a, b]) => match groundNat? a, groundNat? b with
| some _, some _ => rewrite e (mkApp2 (.const ``Int.ofNat_pow []) a b)
| _, _ => mkAtomLinearCombo e
| (``Prod.fst, #[_, β, p]) => match p with
| .app (.app (.app (.app (.const ``Prod.mk [0, v]) _) _) x) y =>
rewrite e (mkApp3 (.const ``Int.ofNat_fst_mk [v]) β x y)
@@ -233,6 +257,10 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
| _ => mkAtomLinearCombo e
| (``Min.min, #[_, _, a, b]) => rewrite e (mkApp2 (.const ``Int.ofNat_min []) a b)
| (``Max.max, #[_, _, a, b]) => rewrite e (mkApp2 (.const ``Int.ofNat_max []) a b)
| (``HShiftLeft.hShiftLeft, #[_, _, _, _, a, b]) =>
rewrite e (mkApp2 (.const ``Int.ofNat_shiftLeft_eq []) a b)
| (``HShiftRight.hShiftRight, #[_, _, _, _, a, b]) =>
rewrite e (mkApp2 (.const ``Int.ofNat_shiftRight_eq_div_pow []) a b)
| (``Int.natAbs, #[n]) =>
if ( cfg).splitNatAbs then
rewrite e (mkApp (.const ``Int.ofNat_natAbs []) n)

View File

@@ -108,6 +108,45 @@ def intCast? (n : Expr) : Option Int :=
| (``Nat.cast, #[_, _, n]) => n.nat?
| _ => n.int?
/--
If `groundNat? e = some n`, then `e` is definitionally equal to `OfNat.ofNat n`.
-/
-- We may want to replace this with an implementation using
-- the internals of `simp (config := {ground := true})`
partial def groundNat? (e : Expr) : Option Nat :=
match e.getAppFnArgs with
| (``Nat.cast, #[_, _, n]) => groundNat? n
| (``HAdd.hAdd, #[_, _, _, _, x, y]) => op (· + ·) x y
| (``HMul.hMul, #[_, _, _, _, x, y]) => op (· * ·) x y
| (``HSub.hSub, #[_, _, _, _, x, y]) => op (· - ·) x y
| (``HDiv.hDiv, #[_, _, _, _, x, y]) => op (· / ·) x y
| (``HPow.hPow, #[_, _, _, _, x, y]) => op (· ^ ·) x y
| _ => e.nat?
where op (f : Nat Nat Nat) (x y : Expr) : Option Nat :=
match groundNat? x, groundNat? y with
| some x', some y' => some (f x' y')
| _, _ => none
/--
If `groundInt? e = some i`,
then `e` is definitionally equal to the standard expression for `i`.
-/
partial def groundInt? (e : Expr) : Option Int :=
match e.getAppFnArgs with
| (``Nat.cast, #[_, _, n]) => groundNat? n
| (``HAdd.hAdd, #[_, _, _, _, x, y]) => op (· + ·) x y
| (``HMul.hMul, #[_, _, _, _, x, y]) => op (· * ·) x y
| (``HSub.hSub, #[_, _, _, _, x, y]) => op (· - ·) x y
| (``HDiv.hDiv, #[_, _, _, _, x, y]) => op (· / ·) x y
| (``HPow.hPow, #[_, _, _, _, x, y]) => match groundInt? x, groundNat? y with
| some x', some y' => some (x' ^ y')
| _, _ => none
| _ => e.int?
where op (f : Int Int Int) (x y : Expr) : Option Int :=
match groundNat? x, groundNat? y with
| some x', some y' => some (f x' y')
| _, _ => none
/-- Construct the term with type hint `(Eq.refl a : a = b)`-/
def mkEqReflWithExpectedType (a b : Expr) : MetaM Expr := do
mkExpectedTypeHint ( mkEqRefl a) ( mkEq a b)

View File

@@ -381,6 +381,11 @@ example (i : Fin 7) : (i : Nat) < 8 := by omega
example (x y z i : Nat) (hz : z 1) : x % 2 ^ i + y % 2 ^ i + z < 2 * 2^ i := by omega
/-! ### Ground terms -/
example : 2^7 < 165 := by omega
example (_ : x % 2^7 < 3) : x % 128 < 5 := by omega
/-! ### BitVec -/
-- Currently these tests require calling `simp` with many lemmas,
-- and sometimes adding `toNat_lt` as a hypothesis.
@@ -392,15 +397,16 @@ example (x y : BitVec 8) (hx : x < 16) (hy : y < 16) : x + y < 31 := by
simp [BitVec.lt_def] at *
omega
example (x y z : BitVec 8) (hx : x >>> 1 < 16) (hy : y < 16) (hz : z = x + 2 * y) : z 64 := by
simp [BitVec.lt_def, BitVec.le_def, BitVec.toNat_eq, Nat.shiftRight_eq_div_pow, BitVec.toNat_mul] at *
example (x y z : BitVec 8)
(hx : x >>> 1 < 16) (hy : y < 16) (hz : z = x + 2 * y) : z 64 := by
simp [BitVec.lt_def, BitVec.le_def, BitVec.toNat_eq, BitVec.toNat_mul] at *
omega
example (x : BitVec 8) (hx : (x + 1) <<< 1 = 3) : False := by
simp [BitVec.toNat_eq, Nat.shiftLeft_eq] at *
simp [BitVec.toNat_eq] at *
omega
example (x : BitVec 8) (hx : (x + 1) <<< 1 = 4) : x = 1 x = 129 := by
have := toNat_lt x
simp [BitVec.toNat_eq, Nat.shiftLeft_eq, BitVec.lt_def] at *
simp [BitVec.toNat_eq, BitVec.lt_def] at *
omega