Compare commits

...

3 Commits

Author SHA1 Message Date
Scott Morrison
0faf31008f cleanup 2024-03-06 09:17:05 +11:00
Scott Morrison
1f9f87d2f5 better unrolling 2024-03-06 09:09:06 +11:00
Scott Morrison
ae4123d9c0 feat: refactor of BitVec/Bitblast.lean 2024-03-05 22:37:10 +11:00
2 changed files with 188 additions and 0 deletions

View File

@@ -98,6 +98,128 @@ theorem carry_succ (i : Nat) (x y : BitVec w) (c : Bool) :
exact mod_two_pow_add_mod_two_pow_add_bool_lt_two_pow_succ ..
cases x.toNat.testBit i <;> cases y.toNat.testBit i <;> (simp; omega)
/--
Does the addition of two `BitVec`s overflow?
The nice feature of this definition is that
it can be unfolded recursively to a circuit:
```
example (x y : BitVec 4) :
addOverflow x y =
atLeastTwo (x.getLsb 3) (y.getLsb 3) (atLeastTwo (x.getLsb 2) (y.getLsb 2)
(atLeastTwo (x.getLsb 1) (y.getLsb 1) (x.getLsb 0 && y.getLsb 0))) := by
simp [addOverflow, msb_truncate, BitVec.msb, getMsb]
```
-/
def addOverflow (x y : BitVec w) (c : Bool := false) : Bool :=
match w with
| 0 => c
| (w + 1) => atLeastTwo x.msb y.msb (addOverflow (x.truncate w) (y.truncate w) c)
@[simp] theorem addOverflow_length_zero {x y : BitVec 0} : addOverflow x y c = c := rfl
theorem addOverflow_length_succ {x y : BitVec (w+1)} :
addOverflow x y c = atLeastTwo x.msb y.msb (addOverflow (x.truncate w) (y.truncate w) c) :=
rfl
@[simp] theorem addOverflow_zero_left_succ :
addOverflow 0#(w+1) y c = (y.msb && addOverflow 0#w (y.truncate w) c) := by
simp [addOverflow]
@[simp] theorem addOverflow_zero_right_succ {x : BitVec (w+1)} :
addOverflow x 0#(w+1) c = (x.msb && addOverflow (x.truncate w) 0#w c) := by
simp [addOverflow]
@[simp] theorem addOverflow_zero_zero :
addOverflow 0#i 0#i c = (decide (i = 0) && c) := by
cases i <;> simp
theorem carry_eq_addOverflow (i) (x y : BitVec w) (c) :
carry i x y c = addOverflow (x.truncate i) (y.truncate i) c := by
match i with
| 0 => simp
| (i + 1) =>
rw [addOverflow_length_succ, carry_succ, carry_eq_addOverflow]
simp [msb_zeroExtend, Nat.le_succ]
theorem addOverflow_eq_carry {x y : BitVec w} :
addOverflow x y c = carry w x y c := by
have := carry_eq_addOverflow w x y c
simpa using this.symm
theorem addOverflow_cons_cons :
addOverflow (cons a x) (cons b y) = atLeastTwo a b (addOverflow x y) := by
simp [addOverflow]
theorem add_cons_cons (w) (x y : BitVec w) :
(cons a x) + (cons b y) = cons (Bool.xor a (Bool.xor b (addOverflow x y))) (x + y) := by
have pos : 0 < 2^w := Nat.pow_pos Nat.zero_lt_two
apply eq_of_toNat_eq
simp only [toNat_add, toNat_cons']
rw [addOverflow_eq_carry, carry]
simp [Nat.mod_pow_succ]
by_cases h : 2 ^ w x.toNat + y.toNat
· simp [h]
have p : (x.toNat + y.toNat) / 2 ^ w = 1 := by
apply Nat.div_eq_of_lt_le <;> omega
cases a <;> cases b
<;> simp [Nat.one_shiftLeft, Nat.add_left_comm x.toNat, Nat.add_assoc, p, pos]
<;> simp [Nat.add_comm]
· simp [h]
have p : (x.toNat + y.toNat) / 2 ^ w = 0 := by
apply Nat.div_eq_of_lt_le <;> omega
cases a <;> cases b
<;> simp [Nat.one_shiftLeft, Nat.add_left_comm x.toNat, Nat.add_assoc, p, pos]
<;> simp [Nat.add_comm]
theorem msb_add (x y : BitVec w) :
(x + y).msb =
Bool.xor x.msb (Bool.xor y.msb (addOverflow (x.truncate (w-1)) (y.truncate (w-1)))) := by
cases w with
| zero => simp
| succ w =>
conv =>
lhs
rw [eq_msb_cons_truncate x, eq_msb_cons_truncate y, add_cons_cons]
simp [succ_eq_add_one, Nat.add_one_sub_one]
/--
Variant of `getLsb_add` in terms of `addOverflow` rather than `carry`.
-/
theorem getLsb_add' (i : Nat) (x y : BitVec w) :
getLsb (x + y) i = (decide (i < w) && Bool.xor (x.getLsb i)
(Bool.xor (y.getLsb i) (addOverflow (x.truncate i) (y.truncate i)))) := by
by_cases h : i < w
· rw [ msb_truncate (x + y), truncate_add, msb_add, msb_truncate, msb_truncate]
rw [Nat.add_one_sub_one, truncate_truncate_of_le, truncate_truncate_of_le]
simp [h]
all_goals omega
· simp [h]
simp at h
simp [h]
theorem addOverflow_eq_false_of_and_eq_zero {x y : BitVec w} (h : x &&& y = 0#w) :
addOverflow x y = false := by
induction w with
| zero => rfl
| succ w ih =>
have h₁ := congrArg BitVec.msb h
have h₂ := congrArg (·.truncate w) h
simp at h₁ h₂
simp_all [addOverflow_length_succ]
theorem or_eq_add_of_and_eq_zero (x y : BitVec w) (h : x &&& y = 0) :
x ||| y = x + y := by
ext i
have h₁ := congrArg (getLsb · i) h
have h₂ := congrArg (truncate i) h
simp at h₁ h₂
simp only [getLsb_add', getLsb_or]
rw [addOverflow_eq_false_of_and_eq_zero h₂]
-- sat
revert h₁
cases x.getLsb i <;> cases y.getLsb i <;> simp
/-- Carry function for bitwise addition. -/
def adcb (x y c : Bool) : Bool × Bool := (atLeastTwo x y c, Bool.xor x (Bool.xor y c))

View File

@@ -0,0 +1,66 @@
open BitVec
/-!
This is not how you should implement a `bitblast` tactic!
Relying on the simplifier to unroll the bitwise quantifier is not efficient.
A proper bitblaster is in the works.
Nevertheless this is a simple test bed for BitVec lemmas.
-/
theorem Fin.forall_eq_forall_lt (p : Fin n Prop) [DecidablePred p] :
( (x : Fin n), p x) ( (x : Fin n), x < n p x) := by
simp
theorem Fin.forall_lt_succ (p : Fin n Prop) [DecidablePred p] (k : Nat) :
( (x : Fin n), x < (k + 1) p x)
if h : k < n then
(p k, h (x : Fin n), x < k p x)
else
(x : Fin n), x < k p x := by
constructor
· intro w
split <;> rename_i h
· constructor
· exact w k, h (by dsimp; omega)
· intro x q
exact w x (by omega)
· intro x q
exact w _ (by omega)
· intro w x q
split at w <;> rename_i h
· by_cases r : x = k
· subst r
apply w.1
· apply w.2
omega
· exact w _ (by omega)
theorem Fin.forall_lt_zero (p : Fin n Prop) [DecidablePred p] :
( (x : Fin n), x < (0 : Nat) p x) True :=
fun _ => trivial, nofun
macro "bitblast" : tactic => `(tactic|
( apply eq_of_getLsb_eq
rw [Fin.forall_eq_forall_lt]
repeat rw [Fin.forall_lt_succ, dif_pos (by decide)]
rw [Fin.forall_lt_zero]
simp [getLsb_add', addOverflow, msb_eq_getLsb_last]))
-- Examples not involving addition:
example (x : BitVec 64) :
(x <<< 32 >>> 32) = (x.truncate 32).zeroExtend 64 := by
bitblast
example (x : BitVec 64) : (x <<< 32) &&& (x >>> 32) = 0 := by
bitblast
-- Examples involving addition:
-- (Notice here we are limited to small widths, because of the bad implementation.)
example (x y : BitVec 32) : (x + y) <<< 1 = (x <<< 1) + (y <<< 1) := by
bitblast
example (x y : BitVec 32) :
(x + y) &&& 255#32 = (x.truncate 8 + y.truncate 8).zeroExtend 32 := by
bitblast