From 83de3b67a20598deb20b3fd3d08c8325866b33fc Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Tue, 31 Oct 2023 08:51:49 -0500 Subject: [PATCH] Add bitwise and other helper methods to BigBitField. Also ensures that when checking if a bit is set, that the buffer is not extended unnecessarily. These changes are partially derived from #2802 / @nyaoouo - thanks! --- peewee.py | 51 ++++++++++++++++++++++++++++++++++++++++++++++++- tests/fields.py | 49 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 1 deletion(-) diff --git a/peewee.py b/peewee.py index 5f63d0b1e..97732ebba 100644 --- a/peewee.py +++ b/peewee.py @@ -5047,6 +5047,9 @@ def __init__(self, instance, name): value = bytearray(value) self._buffer = self.instance.__data__[self.name] = value + def clear(self): + self._buffer.clear() + def _ensure_length(self, idx): byte_num, byte_offset = divmod(idx, 8) cur_size = len(self._buffer) @@ -5068,9 +5071,55 @@ def toggle_bit(self, idx): return bool(self._buffer[byte_num] & (1 << byte_offset)) def is_set(self, idx): - byte_num, byte_offset = self._ensure_length(idx) + byte_num, byte_offset = divmod(idx, 8) + cur_size = len(self._buffer) + if cur_size <= byte_num: + return False return bool(self._buffer[byte_num] & (1 << byte_offset)) + __getitem__ = is_set + def __setitem__(self, item, value): + self.set_bit(item) if value else self.clear_bit(item) + __delitem__ = clear_bit + + def __len__(self): + return len(self._buffer) + + def _get_compatible_data(self, other): + if isinstance(other, BigBitFieldData): + data = other._buffer + elif isinstance(other, (bytes, bytearray, memoryview)): + data = other + else: + raise ValueError('Incompatible data-type') + diff = len(data) - len(self) + if diff > 0: self._buffer.extend(b'\x00' * diff) + return data + + def _bitwise_op(self, other, op): + if isinstance(other, BigBitFieldData): + data = other._buffer + elif isinstance(other, (bytes, bytearray, memoryview)): + data = other + else: + raise ValueError('Incompatible data-type') + buf = bytearray(b'\x00' * max(len(self), len(other))) + for i, (a, b) in enumerate(zip(self._buffer, data)): + buf[i] = op(a, b) + return buf + + def __and__(self, other): + return self._bitwise_op(other, operator.and_) + def __or__(self, other): + return self._bitwise_op(other, operator.or_) + def __xor__(self, other): + return self._bitwise_op(other, operator.xor) + + def __iter__(self): + for b in self._buffer: + for j in range(8): + yield 1 if (b & (1 << j)) else 0 + def __repr__(self): return repr(self._buffer) if sys.version_info[0] < 3: diff --git a/tests/fields.py b/tests/fields.py index 464883879..c4d1fa1d1 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -678,6 +678,35 @@ def test_bigbit_zero_idx(self): b.data.clear_bit(0) self.assertFalse(b.data.is_set(0)) + # Out-of-bounds returns False and does not extend data. + self.assertFalse(b.data.is_set(1000)) + self.assertTrue(len(b.data), 1) + + def test_bigbit_item_methods(self): + b = Bits() + idxs = [0, 1, 4, 7, 8, 15, 16, 31, 32, 63] + for i in idxs: + b.data[i] = True + for i in range(64): + self.assertEqual(b.data[i], i in idxs) + + data = list(b.data) + self.assertEqual(data, [1 if i in idxs else 0 for i in range(64)]) + + for i in range(64): + del b.data[i] + self.assertEqual(len(b.data), 8) + self.assertEqual(b.data._buffer, b'\x00' * 8) + + def test_bigbit_set_clear(self): + b = Bits() + b.data = b'\x01' + for i in range(8): + self.assertEqual(b.data[i], i == 0) + + b.data.clear() + self.assertEqual(len(b.data), 0) + def test_bigbit_field(self): b = Bits.create() b.data.set_bit(1) @@ -692,6 +721,26 @@ def test_bigbit_field(self): else: self.assertFalse(b_db.data.is_set(x)) + def test_bigbit_field_bitwise(self): + b1 = Bits(data=b'\x11') + b2 = Bits(data=b'\x12') + b3 = Bits(data=b'\x99') + self.assertEqual(b1.data & b2.data, b'\x10') + self.assertEqual(b1.data | b2.data, b'\x13') + self.assertEqual(b1.data ^ b2.data, b'\x03') + self.assertEqual(b1.data & b3.data, b'\x11') + self.assertEqual(b1.data | b3.data, b'\x99') + self.assertEqual(b1.data ^ b3.data, b'\x88') + + b1.data &= b2.data + self.assertEqual(b1.data._buffer, b'\x10') + + b1.data |= b2.data + self.assertEqual(b1.data._buffer, b'\x12') + + b1.data ^= b3.data + self.assertEqual(b1.data._buffer, b'\x8b') + def test_bigbit_field_bulk_create(self): b1, b2, b3 = Bits(), Bits(), Bits() b1.data.set_bit(1)