Skip to content

Commit

Permalink
feat: fix BitVec.abs, prove toInt produces the expected value.
Browse files Browse the repository at this point in the history
The previous definition of `abs` was incorrect when only the msb was `1` and all other bits were `0`.
For example, consider bit-width 3:

```
100 -- 4#3
```

If we compute `-x`, i.e. `!x + 1`, we get:

```
 011
+001
 ---
 100
```

We recover `4#3` once again.

The semantically correct implementation can use `BitVec.slt`,
and we can prove the equivalence to the bit-fiddling hack:

```
// https://math.stackexchange.com/q/2565736/261373
int iabs(int a) {
   int t = a >> 31;
   a = (a^t) - t;
   return a;
}
```

written in lean, this is:

```
def BitVec.abs' (x : BitVec w) :
  let t := x >> (w - 1)
  (x ^^^ t) - t
```
  • Loading branch information
bollu committed Oct 21, 2024
1 parent b814be6 commit 6a110aa
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 13 deletions.
3 changes: 0 additions & 3 deletions src/Init/Data/BitVec/Bitblast.lean
Original file line number Diff line number Diff line change
Expand Up @@ -361,9 +361,6 @@ theorem slt_eq_not_carry (x y : BitVec w) :
simp only [slt_eq_ult, bne, ult_eq_not_carry]
cases x.msb == y.msb <;> simp

theorem sle_eq_not_slt (x y : BitVec w) : x.sle y = !y.slt x := by
simp only [BitVec.sle, BitVec.slt, ← decide_not, decide_eq_decide]; omega

theorem sle_eq_carry (x y : BitVec w) :
x.sle y = !((x.msb == y.msb).xor (carry w y (~~~x) true)) := by
rw [sle_eq_not_slt, slt_eq_not_carry, beq_comm]
Expand Down
112 changes: 102 additions & 10 deletions src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import Init.Data.Fin.Lemmas
import Init.Data.Nat.Lemmas
import Init.Data.Nat.Mod
import Init.Data.Int.Bitwise.Lemmas
import Init.Data.Int.Lemmas
import Init.Data.Int.Pow

set_option linter.missingDocs true
Expand Down Expand Up @@ -206,6 +207,7 @@ theorem eq_of_getMsbD_eq {x y : BitVec w}
theorem of_length_zero {x : BitVec 0} : x = 0#0 := by ext; simp

theorem toNat_zero_length (x : BitVec 0) : x.toNat = 0 := by simp [of_length_zero]
theorem toInt_length_zero (x : BitVec 0) : x.toInt = 0 := by simp [of_length_zero]
theorem getLsbD_zero_length (x : BitVec 0) : x.getLsbD i = false := by simp
theorem getMsbD_zero_length (x : BitVec 0) : x.getMsbD i = false := by simp
theorem msb_zero_length (x : BitVec 0) : x.msb = false := by simp [BitVec.msb, of_length_zero]
Expand Down Expand Up @@ -2070,16 +2072,6 @@ theorem smod_zero {x : BitVec n} : x.smod 0#n = x := by
· simp
· by_cases h : x = 0#n <;> simp [h]

/-! ### abs -/

@[simp, bv_toNat]
theorem toNat_abs {x : BitVec w} : x.abs.toNat = if x.msb then 2^w - x.toNat else x.toNat := by
simp only [BitVec.abs, neg_eq]
by_cases h : x.msb = true
· simp only [h, ↓reduceIte, toNat_neg]
have : 2 * x.toNat ≥ 2 ^ w := BitVec.msb_eq_true_iff_two_mul_ge.mp h
rw [Nat.mod_eq_of_lt (by omega)]
· simp [h]

/-! ### mul -/

Expand Down Expand Up @@ -2674,6 +2666,106 @@ theorem getLsbD_intMax (w : Nat) : (intMax w).getLsbD i = decide (i + 1 < w) :=
· rw [Nat.sub_add_cancel (Nat.two_pow_pos (w - 1)), Nat.two_pow_pred_mod_two_pow (by omega)]


/-! ### abs -/

private theorem two_pow_plus_one_div_two (w : Nat) : ((2^w + 1) / 2) = 2^(w - 1) := by
apply Nat.div_eq_of_lt_le
· rcases w with rfl | w
· decide
· simp
omega
· rcases w with rfl | w
· decide
· rw [Nat.add_one_sub_one, Nat.add_mul,
Nat.one_mul, Nat.add_lt_add_iff_right,
Nat.pow_add]
omega

theorem abs_def {x : BitVec w} : x.abs = if x.msb then .neg x else x := rfl

/-- The value of the bitvector (interpreted as an integer) is always less than 2^w -/
theorem toInt_lt (x : BitVec w) : x.toInt < 2 ^ w := by
rw [toInt_eq_msb_cond]
norm_cast
omega

/-- The negation value of the bitvector (interpreted as an integer) is always less than 2^w -/
theorem neg_toInt_lt (x : BitVec w) : - x.toInt < 2 ^ w := by
have := toInt_lt x
rw [toInt_eq_msb_cond]
split
case isTrue h =>
simp only [gt_iff_lt]
norm_cast
have := msb_eq_true_iff_two_mul_ge.mp h
omega
case isFalse h =>
norm_cast
omega

/-- The msb of `intMin w` is `true` for all `w > 0` -/
@[simp]
theorem msb_intMin : (intMin w).msb = decide (w > 0) := by
rw [intMin]
rw [msb_eq_decide]
simp
rcases w with rfl | w
· rfl
· simp
have : 0 < 2^w := Nat.pow_pos (by decide)
have : 2^w < 2^(w + 1) := by
rw [Nat.pow_succ]
omega
rw [Nat.mod_eq_of_lt (by omega)]
simp

@[simp]
theorem abs_intMin : (intMin w).abs = intMin w := by
rw [abs_def]
simp [msb_intMin]


@[simp] theorem toInt_zero (w : Nat) : (0#w).toInt = 0 := by
simp [BitVec.toInt]
omega

/-- the msb is true iff the bitvector , when interpreted as a signed 2s complement number, is less than zero -/
theorem msb_eq_decide_slt_zero (x : BitVec w) : x.msb = decide (x.slt 0#w) := by
simp only [BitVec.slt, toInt_eq_msb_cond, msb_zero, Bool.false_eq_true,
↓reduceIte, toNat_ofNat, Nat.zero_mod, Int.Nat.cast_ofNat_Int, decide_eq_true_eq]
rcases h : x.msb <;> simp [h] <;> omega


theorem sle_eq_not_slt (x y : BitVec w) : x.sle y = !y.slt x := by
simp only [BitVec.sle, BitVec.slt, ← decide_not, decide_eq_decide]; omega


/-- If x >= 0, then x.abs = x -/
theorem abs_of_sle (x : BitVec w) (hx : (0#w).sle x) : x.abs = x := by
rw [abs_def, msb_eq_decide_slt_zero]
simp only [sle_eq_not_slt, Bool.not_eq_eq_eq_not, Bool.not_true] at hx
simp [hx]

/-- If x < 0, then x.abs = -x -/
theorem abs_of_slt (x : BitVec w) (hx : x.slt 0#w) : x.abs = -x := by
rw [abs_def, msb_eq_decide_slt_zero]
simp only [hx, decide_True, ↓reduceIte, neg_eq]

/-- Either x < y or x ≥ y, i.e. y ≤ x -/
theorem slt_or_sle (x y : BitVec w) : x.slt y ∨ y.sle x := by
rw [BitVec.slt, BitVec.sle]
by_cases h : x.toInt < y.toInt <;> simp [h] <;> omega

/-- TODO: what should I name this lemmas? -/
theorem abs_cases (x : BitVec w) : x.abs =
if x = intMin w then (intMin w)
else if x.slt 0 then -x
else x := by
· rw [abs_def]
rw [msb_eq_decide_slt_zero]
by_cases hx : x.slt 0#w <;> by_cases hx' : x = intMin w <;> simp [hx, hx']


/-! ### Non-overflow theorems -/

/-- If `x.toNat * y.toNat < 2^w`, then the multiplication `(x * y)` does not overflow. -/
Expand Down
7 changes: 7 additions & 0 deletions src/Init/Data/Int/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,13 @@ instance : Min Int := minOfLe

instance : Max Int := maxOfLe

/--
Return the absolute value of an integer.
-/
def abs : Int → Int
| ofNat n => .ofNat n
| negSucc n => .ofNat n.succ

end Int

/--
Expand Down
24 changes: 24 additions & 0 deletions src/Init/Data/Int/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -531,4 +531,28 @@ theorem natCast_one : ((1 : Nat) : Int) = (1 : Int) := rfl
@[simp] theorem natCast_mul (a b : Nat) : ((a * b : Nat) : Int) = (a : Int) * (b : Int) := by
simp

/-! abs lemmas -/

@[simp]
theorem abs_eq_self {x : Int} (h : x ≥ 0) : x.abs = x := by
cases x
case ofNat h =>
rfl
case negSucc h =>
contradiction

@[simp]
theorem Int.abs_zero : Int.abs 0 = 0 := rfl

@[simp]
theorem abs_eq_neg {x : Int} (h : x < 0) : x.abs = -x := by
cases x
case ofNat h =>
contradiction
case negSucc n =>
rfl

@[simp]
theorem ofNat_abs (x : Nat) : (x : Int).abs = (x : Int) := rfl

end Int

0 comments on commit 6a110aa

Please sign in to comment.