diff --git a/libs/sparse_table/src/lib.rs b/libs/sparse_table/src/lib.rs index 8f03e88a..c2acb356 100644 --- a/libs/sparse_table/src/lib.rs +++ b/libs/sparse_table/src/lib.rs @@ -1,152 +1,350 @@ -//! Sparse table です。いまのところ Argmin しかありません。 -use std::ops; +//! # Sparse Table +//! +//! * [`SparseTable`] (1-dimensional) +//! * [`SparseTable2d`] (2-dimensional) +//! +//! # [`Op`] trait +//! +//! [`Op::mul`] must be associative and idempotent. -#[derive(Debug, Clone)] -pub struct SparseTableArgmin { - table: Vec>, - seq: Vec, +use std::fmt::Debug; +use std::iter::FromIterator; +use std::ops::Index; +use std::ops::RangeBounds; + +/// A trait for the operation used in sparse tables. +pub trait Op { + /// The type of the values. + type Value; + + /// Multiplies two values: $x \cdot y$. + fn mul(lhs: &Self::Value, rhs: &Self::Value) -> Self::Value; } -impl SparseTableArgmin { - pub fn from_vec(seq: Vec) -> Self { - let n = seq.len(); - let mut table = vec![(0..n).collect::>()]; - let mut d = 1; - while 2 * d < n { - let prv = table.last().unwrap(); - let mut crr = prv.clone(); - for i in 0..n - d { - if seq[crr[i + d]] < seq[crr[i]] { - crr[i] = crr[i + d]; - } - } - table.push(crr); - d *= 2; - } - Self { table, seq } +/// A sparse table for 1-dimensional range queries. +pub struct SparseTable { + table: Vec>, +} +impl SparseTable { + /// Constructs a sparse table from a vector of values. + pub fn new(values: Vec) -> Self { + values.into() } - pub fn query(&self, range: impl ops::RangeBounds) -> Option { - let ops::Range { start, end } = convert_to_range(self.seq.len(), range); + /// Constructs a sparse table from a slice of values. + pub fn clone_from_slice(values: &[O::Value]) -> Self + where + O::Value: Clone, + { + values.into() + } + + /// Returns the value at the given index. + pub fn get(&self, index: usize) -> &O::Value { + &self.table[0][index] + } + + /// Returns $x_l \cdot x_{l+1} \cdot \ldots \cdot x_{r-1}$, or `None` if $l = r$. + pub fn fold(&self, range: impl RangeBounds) -> Option { + let (start, end) = open(range, self.table[0].len()); assert!(start <= end); - if start == end { - None - } else { - Some(if start + 1 == end { - start - } else { - let d = (end - start).next_power_of_two() / 2; - let row = &self.table[d.trailing_zeros() as usize]; - let u = row[start]; - let v = row[end - d]; - if self.seq[u] <= self.seq[v] { - u - } else { - v - } - }) + (start < end).then_some(())?; + let p = (end - start).ilog2() as usize; + let row = &self.table[p]; + Some(O::mul(&row[start], &row[end - (1 << p)])) + } + + /// Returns an iterator over the values. + pub fn iter(&self) -> impl Iterator { + self.table[0].iter() + } + + /// Returns a slice of the values. + pub fn as_slice(&self) -> &[O::Value] { + &self.table[0] + } + + /// Collects the values into a vector. + pub fn collect_vec(&self) -> Vec + where + O::Value: Clone, + { + self.table[0].clone() + } + + /// Returns the inner table. + pub fn inner(&self) -> &Vec> { + &self.table + } +} + +impl Debug for SparseTable +where + O::Value: Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SparseTable") + .field("table", &self.table) + .finish() + } +} + +impl From> for SparseTable { + fn from(values: Vec) -> Self { + let n = values.len(); + let mut table = vec![values]; + let mut i = 1; + while i * 2 <= n { + let last = table.last().unwrap(); + let current = last + .iter() + .zip(&last[i..]) + .map(|(a, b)| O::mul(a, b)) + .collect(); + table.push(current); + i *= 2; } + Self { table } + } +} + +impl<'a, O: Op> From<&'a [O::Value]> for SparseTable +where + O::Value: Clone, +{ + fn from(values: &'a [O::Value]) -> Self { + values.to_vec().into() + } +} + +impl FromIterator for SparseTable { + fn from_iter>(iter: T) -> Self { + iter.into_iter().collect::>().into() + } +} + +impl Index for SparseTable { + type Output = O::Value; + + fn index(&self, index: usize) -> &Self::Output { + &self.table[0][index] } +} + +/// A sparse table for 2-dimensional range queries. +/// +/// The operation must also be commutative. +pub struct SparseTable2d { + table: Vec>>>, +} - #[inline] - pub fn min(&self, range: impl ops::RangeBounds) -> Option<&T> { - self.query(range).map(|i| &self.seq[i]) +impl SparseTable2d { + /// Constructs a sparse table from a vector of values. + pub fn new(values: Vec>) -> Self { + values.into() } - #[inline] - pub fn get(&self, index: I) -> Option<&I::Output> + /// Constructs a sparse table from a slice of values. + pub fn clone_from_slice(values: &[Vec]) -> Self where - I: std::slice::SliceIndex<[T]>, + O::Value: Clone, { - self.seq.get(index) + values.into() + } + + /// Returns $(x_{i_0, j_0} \cdot \dots \cdot x_{i_1-1, j_0}) \cdot \dots \cdot (x_{i_0, j_1-1} \cdot \dots \cdot x_{i_1-1, j_1-1})$, or `None` if $i_0 = i_1$ or $j_0 = j_1$. + pub fn fold(&self, i: impl RangeBounds, j: impl RangeBounds) -> Option { + let (i0, mut i1) = open(i, self.table[0][0].len()); + let (j0, mut j1) = open(j, self.table[0][0].first().map_or(0, Vec::len)); + assert!(i0 <= i1); + assert!(j0 <= j1); + (i0 < i1 && j0 < j1).then_some(())?; + let p = (i1 - i0).ilog2() as usize; + let q = (j1 - j0).ilog2() as usize; + let grid = &self.table[p][q]; + i1 -= 1 << p; + j1 -= 1 << q; + Some(O::mul( + &O::mul(&grid[i0][j0], &grid[i1][j0]), + &O::mul(&grid[i0][j1], &grid[i1][j1]), + )) + } + + /// Returns that yields the row of the table. + pub fn iter(&self) -> impl Iterator { + self.table[0][0].iter().map(Vec::as_slice) + } + + /// Returns a slice of the values. + pub fn as_slice(&self) -> &[Vec] { + &self.table[0][0] + } + + /// Collects the values into a vector of vectors. + pub fn collect_vec(&self) -> Vec> + where + O::Value: Clone, + { + self.table[0][0].clone() + } + + /// Returns the inner table. + pub fn inner(&self) -> &Vec>>> { + &self.table } } -impl ops::Index for SparseTableArgmin +impl Debug for SparseTable2d where - I: std::slice::SliceIndex<[T]>, - T: Clone, + O::Value: Debug, { - type Output = I::Output; + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SparseTable2d") + .field("table", &self.table) + .finish() + } +} - fn index(&self, index: I) -> &I::Output { - self.seq.index(index) +impl From>> for SparseTable2d { + fn from(values: Vec>) -> Self { + let h = values.len(); + let w = values.first().map_or(0, Vec::len); + let mut table = vec![vec![values]]; + let mut j = 1; + while j * 2 <= w { + let last = table[0].last().unwrap(); + let current = last + .iter() + .map(|row| { + row.iter() + .zip(&row[j..]) + .map(|(a, b)| O::mul(a, b)) + .collect::>() + }) + .collect(); + table[0].push(current); + j *= 2; + } + let mut i = 1; + while i * 2 <= h { + let last = table.last().unwrap(); + let current = last + .iter() + .map(|grid| { + grid.iter() + .zip(&grid[i..]) + .map(|(a, b)| { + a.iter() + .zip(b) + .map(|(a, b)| O::mul(a, b)) + .collect::>() + }) + .collect::>() + }) + .collect(); + table.push(current); + i *= 2; + } + Self { table } } } -fn convert_to_range(len: usize, range_bound: T) -> ops::Range +impl<'a, O: Op> From<&'a [Vec]> for SparseTable2d where - T: ops::RangeBounds, + O::Value: Clone, { - use ops::Bound::Excluded; - use ops::Bound::Included; - use ops::Bound::Unbounded; - ops::Range { - start: match range_bound.start_bound() { - Excluded(&x) => x + 1, - Included(&x) => x, - Unbounded => 0, - }, - end: match range_bound.end_bound() { - Excluded(&x) => x, - Included(&x) => x + 1, - Unbounded => len, - }, + fn from(values: &'a [Vec]) -> Self { + values.to_vec().into() } } +fn open>(bounds: B, n: usize) -> (usize, usize) { + use std::ops::Bound; + let start = match bounds.start_bound() { + Bound::Unbounded => 0, + Bound::Included(&x) => x, + Bound::Excluded(&x) => x + 1, + }; + let end = match bounds.end_bound() { + Bound::Unbounded => n, + Bound::Included(&x) => x + 1, + Bound::Excluded(&x) => x, + }; + (start, end) +} + #[cfg(test)] mod tests { use super::*; + use rand::rngs::StdRng; + use rand::Rng; + use rand::SeedableRng; + use std::ops::Range; + + enum O {} + impl Op for O { + type Value = u64; + + fn mul(lhs: &Self::Value, rhs: &Self::Value) -> Self::Value { + (*lhs).max(*rhs) + } + } #[test] - fn test_hand() { - let a = vec![4, 3, 5, 1, 3, 2]; - let spt = SparseTableArgmin::from_vec(a); - assert_eq!(spt.query(3..3), None); - assert_eq!(spt.query(5..6), Some(5)); - assert_eq!(spt.query(1..3), Some(1)); - assert_eq!(spt.query(1..5), Some(3)); - assert_eq!(spt.query(0..6), Some(3)); - assert_eq!(spt.query(0..=3), Some(3)); - assert_eq!(spt.query(0..=2), Some(1)); - assert_eq!(spt.query(..), Some(3)); + fn test_sparse_table() { + let mut rng = StdRng::seed_from_u64(42); + for _ in 0..100 { + let n = rng.gen_range(1..=100); + let q = rng.gen_range(1..=100); + let vec = (0..n) + .map(|_| rng.gen_range(0..u64::MAX)) + .collect::>(); + let st = SparseTable::::clone_from_slice(&vec); + for _ in 0..q { + let range = random_range(&mut rng, n); + let expected = vec[range.clone()].iter().copied().max(); + let actual = st.fold(range.clone()); + assert_eq!(expected, actual); + } + } } #[test] - fn test_random() { - const LEN_MAX: usize = 40; - const VALUE_MAX: u32 = 16; - const NUMBER_OF_INSTANCE: usize = 80; - const NUMBER_OF_QUERIES: usize = 80; - - for _ in 0..NUMBER_OF_INSTANCE { - let a = std::iter::repeat_with(|| rand::random::() % VALUE_MAX) - .take(LEN_MAX) + fn test_sparse_table_2d() { + let mut rng = StdRng::seed_from_u64(42); + for _ in 0..100 { + let h = rng.gen_range(1..=10); + let w = rng.gen_range(1..=10); + let q = rng.gen_range(1..=100); + let vec = (0..h) + .map(|_| { + (0..w) + .map(|_| rng.gen_range(0..u64::MAX)) + .collect::>() + }) .collect::>(); - let spt = SparseTableArgmin::from_vec(a.clone()); - - for _ in 0..NUMBER_OF_QUERIES { - let (l, r) = { - let mut l = rand::random::() % LEN_MAX; - let mut r = rand::random::() % LEN_MAX; - if l > r { - std::mem::swap(&mut l, &mut r); - } - r += 1; - (l, r) - }; - let expected = a + let st = SparseTable2d::::clone_from_slice(&vec); + for _ in 0..q { + let i = random_range(&mut rng, h); + let j = random_range(&mut rng, w); + let expected = vec[i.clone()] .iter() - .enumerate() - .skip(l) - .take(r - l) - .map(|(i, &x)| (x, i)) - .min() - .map(|(_, i)| i); - let result = spt.query(l..r); - assert_eq!(expected, result); + .flat_map(|row| &row[j.clone()]) + .max() + .copied(); + let actual = st.fold(i.clone(), j.clone()); + assert_eq!(expected, actual); } } } + + fn random_range(rng: &mut StdRng, n: usize) -> Range { + let start = rng.gen_range(0..=n + 1); + let end = rng.gen_range(0..=n); + if start <= end { + start..end + } else { + end..start - 1 + } + } }