Skip to content

Commit

Permalink
replace redundant bits_per_element field with method
Browse files Browse the repository at this point in the history
  • Loading branch information
somethingelseentirely committed Aug 31, 2024
1 parent c7d6cd5 commit 14d887c
Showing 1 changed file with 52 additions and 48 deletions.
100 changes: 52 additions & 48 deletions src/wavelet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ use std::ops::Range;
#[derive(Clone, Debug)]
pub struct WaveletMatrix {
/// One `RsVec` per level: the query methods walk it as `self.data[level]`, so its
/// length is the number of bits per element.
data: Box<[RsVec]>,
bits_per_element: u16,
}

impl WaveletMatrix {
Expand Down Expand Up @@ -123,7 +122,6 @@ impl WaveletMatrix {

Self {
data: data.into_iter().map(BitVec::into).collect(),
bits_per_element,
}
}

Expand Down Expand Up @@ -220,7 +218,6 @@ impl WaveletMatrix {

Self {
data: data.into_iter().map(BitVec::into).collect(),
bits_per_element,
}
}

Expand Down Expand Up @@ -296,7 +293,7 @@ impl WaveletMatrix {
/// The function is used by the `get_value` and `get_u64` functions, deduplicating code.
#[inline(always)]
fn reconstruct_value_unchecked<F: FnMut(u64)>(&self, mut i: usize, mut target_func: F) {
for level in 0..self.bits_per_element as usize {
for level in 0..self.bits_per_element() {
let bit = self.data[level].get_unchecked(i);
target_func(bit);
if bit == 0 {
Expand Down Expand Up @@ -345,8 +342,8 @@ impl WaveletMatrix {
/// [`get_value`]: WaveletMatrix::get_value
#[must_use]
pub fn get_value_unchecked(&self, i: usize) -> BitVec {
let mut value = BitVec::from_zeros(self.bits_per_element as usize);
let mut level = self.bits_per_element - 1;
let mut value = BitVec::from_zeros(self.bits_per_element());
let mut level = self.bits_per_element() - 1;
self.reconstruct_value_unchecked(i, |bit| {
value.set_unchecked(level as usize, bit);
level = level.saturating_sub(1);
Expand All @@ -372,7 +369,7 @@ impl WaveletMatrix {
/// ```
#[must_use]
pub fn get_u64(&self, i: usize) -> Option<u64> {
if self.bits_per_element > 64 || self.data.is_empty() || i >= self.data[0].len() {
if self.bits_per_element() > 64 || self.data.is_empty() || i >= self.data[0].len() {
None
} else {
Some(self.get_u64_unchecked(i))
Expand Down Expand Up @@ -418,7 +415,7 @@ impl WaveletMatrix {
#[must_use]
pub fn rank_range_unchecked(&self, mut range: Range<usize>, symbol: &BitVec) -> usize {
for (level, data) in self.data.iter().enumerate() {
if symbol.get_unchecked((self.bits_per_element - 1) as usize - level) == 0 {
if symbol.get_unchecked((self.bits_per_element() - 1) - level) == 0 {
range.start = data.rank0(range.start);
range.end = data.rank0(range.end);
} else {
Expand Down Expand Up @@ -455,7 +452,7 @@ impl WaveletMatrix {
pub fn rank_range(&self, range: Range<usize>, symbol: &BitVec) -> Option<usize> {
if range.start >= self.len()
|| range.end > self.len()
|| symbol.len() != self.bits_per_element as usize
|| symbol.len() != self.bits_per_element()
{
None
} else {
Expand All @@ -482,7 +479,7 @@ impl WaveletMatrix {
#[must_use]
pub fn rank_range_u64_unchecked(&self, mut range: Range<usize>, symbol: u64) -> usize {
for (level, data) in self.data.iter().enumerate() {
if (symbol >> ((self.bits_per_element - 1) as usize - level)) & 1 == 0 {
if (symbol >> ((self.bits_per_element() - 1) - level)) & 1 == 0 {
range.start = data.rank0(range.start);
range.end = data.rank0(range.end);
} else {
Expand Down Expand Up @@ -515,7 +512,7 @@ impl WaveletMatrix {
/// ```
#[must_use]
pub fn rank_range_u64(&self, range: Range<usize>, symbol: u64) -> Option<usize> {
if range.start >= self.len() || range.end > self.len() || self.bits_per_element > 64 {
if range.start >= self.len() || range.end > self.len() || self.bits_per_element() > 64 {
None
} else {
Some(self.rank_range_u64_unchecked(range, symbol))
Expand Down Expand Up @@ -583,7 +580,7 @@ impl WaveletMatrix {
if offset > i
|| offset >= self.len()
|| i > self.len()
|| symbol.len() != self.bits_per_element as usize
|| symbol.len() != self.bits_per_element()
{
None
} else {
Expand Down Expand Up @@ -643,7 +640,7 @@ impl WaveletMatrix {
/// ```
#[must_use]
pub fn rank_offset_u64(&self, offset: usize, i: usize, symbol: u64) -> Option<usize> {
if offset > i || offset >= self.len() || i > self.len() || self.bits_per_element > 64 {
if offset > i || offset >= self.len() || i > self.len() || self.bits_per_element() > 64 {
None
} else {
Some(self.rank_offset_u64_unchecked(offset, i, symbol))
Expand Down Expand Up @@ -696,7 +693,7 @@ impl WaveletMatrix {
/// [`BitVec`]: BitVec
#[must_use]
pub fn rank(&self, i: usize, symbol: &BitVec) -> Option<usize> {
if i > self.len() || symbol.len() != self.bits_per_element as usize {
if i > self.len() || symbol.len() != self.bits_per_element() {
None
} else {
Some(self.rank_range_unchecked(0..i, symbol))
Expand Down Expand Up @@ -744,7 +741,7 @@ impl WaveletMatrix {
/// ```
#[must_use]
pub fn rank_u64(&self, i: usize, symbol: u64) -> Option<usize> {
if i > self.len() || self.bits_per_element > 64 {
if i > self.len() || self.bits_per_element() > 64 {
None
} else {
Some(self.rank_range_u64_unchecked(0..i, symbol))
Expand Down Expand Up @@ -776,7 +773,7 @@ impl WaveletMatrix {
let mut range_start = offset;

for (level, data) in self.data.iter().enumerate() {
if symbol.get_unchecked((self.bits_per_element - 1) as usize - level) == 0 {
if symbol.get_unchecked((self.bits_per_element() - 1) - level) == 0 {
range_start = data.rank0(range_start);
} else {
range_start = data.rank0 + data.rank1(range_start);
Expand All @@ -786,7 +783,7 @@ impl WaveletMatrix {
let mut range_end = range_start + rank;

for (level, data) in self.data.iter().enumerate().rev() {
if symbol.get_unchecked((self.bits_per_element - 1) as usize - level) == 0 {
if symbol.get_unchecked((self.bits_per_element() - 1) - level) == 0 {
range_end = data.select0(range_end);
} else {
range_end = data.select1(range_end - data.rank0);
Expand Down Expand Up @@ -821,7 +818,7 @@ impl WaveletMatrix {
/// [`BitVec`]: BitVec
#[must_use]
pub fn select_offset(&self, offset: usize, rank: usize, symbol: &BitVec) -> Option<usize> {
if offset >= self.len() || symbol.len() != self.bits_per_element as usize {
if offset >= self.len() || symbol.len() != self.bits_per_element() {
None
} else {
let idx = self.select_offset_unchecked(offset, rank, symbol);
Expand Down Expand Up @@ -856,7 +853,7 @@ impl WaveletMatrix {
let mut range_start = offset;

for (level, data) in self.data.iter().enumerate() {
if (symbol >> ((self.bits_per_element - 1) as usize - level)) & 1 == 0 {
if (symbol >> ((self.bits_per_element() - 1) - level)) & 1 == 0 {
range_start = data.rank0(range_start);
} else {
range_start = data.rank0 + data.rank1(range_start);
Expand All @@ -866,7 +863,7 @@ impl WaveletMatrix {
let mut range_end = range_start + rank;

for (level, data) in self.data.iter().enumerate().rev() {
if (symbol >> ((self.bits_per_element - 1) as usize - level)) & 1 == 0 {
if (symbol >> ((self.bits_per_element() - 1) - level)) & 1 == 0 {
range_end = data.select0(range_end);
} else {
range_end = data.select1(range_end - data.rank0);
Expand Down Expand Up @@ -898,7 +895,7 @@ impl WaveletMatrix {
/// ```
#[must_use]
pub fn select_offset_u64(&self, offset: usize, rank: usize, symbol: u64) -> Option<usize> {
if offset >= self.len() || self.bits_per_element > 64 {
if offset >= self.len() || self.bits_per_element() > 64 {
None
} else {
let idx = self.select_offset_u64_unchecked(offset, rank, symbol);
Expand Down Expand Up @@ -955,7 +952,7 @@ impl WaveletMatrix {
/// [`BitVec`]: BitVec
#[must_use]
pub fn select(&self, rank: usize, symbol: &BitVec) -> Option<usize> {
if symbol.len() == self.bits_per_element as usize {
if symbol.len() == self.bits_per_element() {
let idx = self.select_unchecked(rank, symbol);
if idx < self.len() {
Some(idx)
Expand Down Expand Up @@ -1007,7 +1004,7 @@ impl WaveletMatrix {
/// ```
#[must_use]
pub fn select_u64(&self, rank: usize, symbol: u64) -> Option<usize> {
if self.bits_per_element > 64 {
if self.bits_per_element() > 64 {
None
} else {
let idx = self.select_u64_unchecked(rank, symbol);
Expand Down Expand Up @@ -1035,7 +1032,7 @@ impl WaveletMatrix {
/// [`quantile`]: WaveletMatrix::quantile
#[must_use]
pub fn quantile_unchecked(&self, range: Range<usize>, k: usize) -> BitVec {
let result = BitVec::from_zeros(self.bits_per_element as usize);
let result = BitVec::from_zeros(self.bits_per_element());

self.partial_quantile_search_unchecked(range, k, 0, result)
}
Expand All @@ -1053,7 +1050,7 @@ impl WaveletMatrix {
start_level: usize,
mut prefix: BitVec,
) -> BitVec {
debug_assert!(prefix.len() == self.bits_per_element as usize);
debug_assert!(prefix.len() == self.bits_per_element());
debug_assert!(!range.is_empty());
debug_assert!(range.end <= self.len());

Expand All @@ -1069,7 +1066,7 @@ impl WaveletMatrix {
} else {
// the element is among the ones, so we set the bit to 1, and move the range
// into the 1-partition of the next level
prefix.set_unchecked((self.bits_per_element - 1) as usize - level, 1);
prefix.set_unchecked((self.bits_per_element() - 1) - level, 1);
k -= zeros;
range.start = data.rank0 + (range.start - zeros_start); // range.start - zeros_start is the rank1 of range.start
range.end = data.rank0 + (range.end - zeros_end); // same here
Expand Down Expand Up @@ -1182,7 +1179,7 @@ impl WaveletMatrix {
start_level: usize,
mut prefix: u64,
) -> u64 {
debug_assert!(self.bits_per_element <= 64);
debug_assert!(self.bits_per_element() <= 64);
debug_assert!(!range.is_empty());
debug_assert!(range.end <= self.len());

Expand Down Expand Up @@ -1229,7 +1226,7 @@ impl WaveletMatrix {
pub fn quantile_u64(&self, range: Range<usize>, k: usize) -> Option<u64> {
if range.start >= self.len()
|| range.end > self.len()
|| self.bits_per_element > 64
|| self.bits_per_element() > 64
|| k >= range.end - range.start
{
None
Expand Down Expand Up @@ -1273,7 +1270,7 @@ impl WaveletMatrix {
/// ```
#[must_use]
pub fn get_sorted_u64(&self, i: usize) -> Option<u64> {
if i >= self.len() || self.bits_per_element > 64 {
if i >= self.len() || self.bits_per_element() > 64 {
None
} else {
Some(self.get_sorted_u64_unchecked(i))
Expand Down Expand Up @@ -1546,7 +1543,7 @@ impl WaveletMatrix {
/// ```
#[must_use]
pub fn range_median_u64(&self, range: Range<usize>) -> Option<u64> {
if range.is_empty() || self.bits_per_element > 64 || range.end > self.len() {
if range.is_empty() || self.bits_per_element() > 64 || range.end > self.len() {
None
} else {
let k = (range.end - 1 - range.start) / 2;
Expand Down Expand Up @@ -1675,7 +1672,7 @@ impl WaveletMatrix {
/// [`BitVec`]: BitVec
#[must_use]
pub fn predecessor(&self, range: Range<usize>, symbol: &BitVec) -> Option<BitVec> {
if symbol.len() != self.bits_per_element as usize
if symbol.len() != self.bits_per_element()
|| range.is_empty()
|| self.is_empty()
|| range.end > self.len()
Expand All @@ -1686,10 +1683,10 @@ impl WaveletMatrix {
self.predecessor_generic_unchecked(
range,
symbol,
BitVec::from_zeros(self.bits_per_element as usize),
|level, symbol| symbol.get_unchecked((self.bits_per_element - 1) as usize - level),
BitVec::from_zeros(self.bits_per_element()),
|level, symbol| symbol.get_unchecked((self.bits_per_element() - 1) - level),
|bit, level, result| {
result.set_unchecked((self.bits_per_element - 1) as usize - level, bit);
result.set_unchecked((self.bits_per_element() - 1) - level, bit);
},
Self::partial_quantile_search_unchecked,
)
Expand Down Expand Up @@ -1719,7 +1716,7 @@ impl WaveletMatrix {
/// ```
#[must_use]
pub fn predecessor_u64(&self, range: Range<usize>, symbol: u64) -> Option<u64> {
if self.bits_per_element > 64
if self.bits_per_element() > 64
|| range.is_empty()
|| self.is_empty()
|| range.end > self.len()
Expand All @@ -1731,7 +1728,7 @@ impl WaveletMatrix {
range,
&symbol,
0,
|level, symbol| symbol >> ((self.bits_per_element - 1) as usize - level) & 1,
|level, symbol| symbol >> ((self.bits_per_element() - 1) - level) & 1,
|bit, _level, result| {
// we ignore the level here, and instead rely on the fact that the bits are set in order.
// we have to do that, because the quantile_search_u64 does the same.
Expand Down Expand Up @@ -1862,7 +1859,7 @@ impl WaveletMatrix {
/// [`BitVec`]: BitVec
#[must_use]
pub fn successor(&self, range: Range<usize>, symbol: &BitVec) -> Option<BitVec> {
if symbol.len() != self.bits_per_element as usize
if symbol.len() != self.bits_per_element()
|| range.is_empty()
|| self.is_empty()
|| range.end > self.len()
Expand All @@ -1873,10 +1870,10 @@ impl WaveletMatrix {
self.successor_generic_unchecked(
range,
symbol,
BitVec::from_zeros(self.bits_per_element as usize),
|level, symbol| symbol.get_unchecked((self.bits_per_element - 1) as usize - level),
BitVec::from_zeros(self.bits_per_element()),
|level, symbol| symbol.get_unchecked((self.bits_per_element() - 1) - level),
|bit, level, result| {
result.set_unchecked((self.bits_per_element - 1) as usize - level, bit);
result.set_unchecked((self.bits_per_element() - 1) - level, bit);
},
Self::partial_quantile_search_unchecked,
)
Expand Down Expand Up @@ -1906,7 +1903,7 @@ impl WaveletMatrix {
/// ```
#[must_use]
pub fn successor_u64(&self, range: Range<usize>, symbol: u64) -> Option<u64> {
if self.bits_per_element > 64
if self.bits_per_element() > 64
|| range.is_empty()
|| self.is_empty()
|| range.end > self.len()
Expand All @@ -1918,7 +1915,7 @@ impl WaveletMatrix {
range,
&symbol,
0,
|level, symbol| symbol >> ((self.bits_per_element - 1) as usize - level) & 1,
|level, symbol| symbol >> ((self.bits_per_element() - 1) - level) & 1,
|bit, _level, result| {
// we ignore the level here, and instead rely on the fact that the bits are set in order.
// we have to do that, because the quantile_search_u64 does the same.
Expand All @@ -1945,7 +1942,7 @@ impl WaveletMatrix {
/// ```
#[must_use]
pub fn iter_u64(&self) -> Option<WaveletNumRefIter> {
if self.bits_per_element > 64 {
if self.bits_per_element() > 64 {
None
} else {
Some(WaveletNumRefIter::new(self))
Expand All @@ -1957,7 +1954,7 @@ impl WaveletMatrix {
/// If the number of bits per element exceeds 64, `None` is returned.
#[must_use]
pub fn into_iter_u64(self) -> Option<WaveletNumIter> {
if self.bits_per_element > 64 {
if self.bits_per_element() > 64 {
None
} else {
Some(WaveletNumIter::new(self))
Expand Down Expand Up @@ -1996,7 +1993,7 @@ impl WaveletMatrix {
/// ```
#[must_use]
pub fn iter_sorted_u64(&self) -> Option<WaveletSortedNumRefIter> {
if self.bits_per_element > 64 {
if self.bits_per_element() > 64 {
None
} else {
Some(WaveletSortedNumRefIter::new(self))
Expand All @@ -2008,17 +2005,24 @@ impl WaveletMatrix {
/// If the number of bits per element exceeds 64, `None` is returned.
#[must_use]
pub fn into_iter_sorted_u64(self) -> Option<WaveletSortedNumIter> {
if self.bits_per_element > 64 {
if self.bits_per_element() > 64 {
None
} else {
Some(WaveletSortedNumIter::new(self))
}
}

/// Get the number of bits per element in the alphabet of the encoded sequence.
///
/// The value is the length of the internal `data` slice, i.e. the number of levels
/// the matrix stores, so it is derived on demand rather than kept in a separate field.
#[must_use]
#[inline(always)]
pub fn bits_per_element(&self) -> usize {
self.data.len()
}

/// Get the number of bits per element in the alphabet of the encoded sequence.
///
/// Deprecated: this reports the same count as [`bits_per_element`], but as a `u16`;
/// prefer that method, which returns the count as `usize`.
///
/// [`bits_per_element`]: WaveletMatrix::bits_per_element
#[deprecated(since="0.1.6", note="please use `bits_per_element` instead")]
pub fn bit_len(&self) -> u16 {
self.bits_per_element
// NOTE(review): `as u16` truncates silently if the count ever exceeds `u16::MAX` —
// confirm the constructors bound the number of levels.
self.bits_per_element() as u16
}

/// Get the number of elements stored in the encoded sequence.
Expand Down

0 comments on commit 14d887c

Please sign in to comment.