Skip to content

Commit

Permalink
SparseArray TakeFn returns results in the requested order (#212)
Browse files Browse the repository at this point in the history
  • Loading branch information
robert3005 authored Apr 5, 2024
1 parent a3a4e03 commit 79a0630
Showing 1 changed file with 61 additions and 22 deletions.
83 changes: 61 additions & 22 deletions vortex-array/src/array/sparse/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,17 +129,16 @@ impl TakeFn for SparseArray {
fn take(&self, indices: &dyn Array) -> VortexResult<ArrayRef> {
let flat_indices = flatten_primitive(indices)?;
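// For large takes (more than 512 indices) build a hash map over the sparse indices once via
// `take_map`; for smaller takes resolve the indices with `take_search_sorted` instead.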
let exact_taken_indices = if indices.len() > 512 {
let (positions, physical_take_indices) = if indices.len() > 512 {
take_map(self, flat_indices)?
} else {
take_search_sorted(self, flat_indices)?
};

let taken_values = take(self.values(), &exact_taken_indices)?;
let taken_values = take(self.values(), &physical_take_indices)?;

Ok(SparseArray::new(
PrimitiveArray::from((0u64..exact_taken_indices.len() as u64).collect::<Vec<_>>())
.into_array(),
positions.into_array(),
taken_values,
indices.len(),
self.fill_value().clone(),
Expand All @@ -148,27 +147,34 @@ impl TakeFn for SparseArray {
}
}

// NOTE(review): this span is a rendered diff without +/- markers, so the pre-change and
// post-change lines of `take_map` appear interleaved (the old one-line signature comes first,
// followed by the new multi-line signature; the same holds for the body and the return).
//
// Hash-map based lookup of the requested `indices` in a sparse array.
//
// Post-change contract: returns a pair `(positions, patch_indices)` where, for every requested
// index that is present in the map built from `array.resolved_indices()`:
//   * `positions` holds the position of that index within the request (`indices`), and
//   * `patch_indices` holds the matching entry's position within `resolved_indices()`.
// Requested indices that are absent from the map contribute nothing to either array.
// Pre-change contract: returned only the `patch_indices` array.
fn take_map(array: &SparseArray, indices: PrimitiveArray) -> VortexResult<PrimitiveArray> {
fn take_map(
array: &SparseArray,
indices: PrimitiveArray,
) -> VortexResult<(PrimitiveArray, PrimitiveArray)> {
// Map each resolved index (widened to u64) to its ordinal position in `resolved_indices()`.
let indices_map: HashMap<u64, u64> = array
.resolved_indices()
.iter()
.enumerate()
.map(|(i, r)| (*r as u64, i as u64))
.collect();
// The macro dispatches on the integer ptype of `indices` so `typed_data` can be read as `$P`.
let patch_indices: Vec<u64> = match_each_integer_ptype!(indices.ptype(), |$P| {
let (positions, patch_indices): (Vec<u64>, Vec<u64>) = match_each_integer_ptype!(indices.ptype(), |$P| {
indices.typed_data::<$P>()
.iter()
.map(|i| *i as u64)
.filter_map(|pi| indices_map.get(&pi).copied())
.collect::<Vec<_>>()
.map(|pi| *pi as u64)
// `enumerate` runs before the filter, so `i` is the position in the original request,
// not the position among the surviving (matched) entries.
.enumerate()
.filter_map(|(i, pi)| indices_map.get(&pi).copied().map(|phy_idx| (i as u64, phy_idx)))
// Split the (request position, mapped position) pairs into two parallel vectors.
.unzip()
});
Ok(PrimitiveArray::from(patch_indices))
Ok((
PrimitiveArray::from(positions),
PrimitiveArray::from(patch_indices),
))
}

fn take_search_sorted(
array: &SparseArray,
indices: PrimitiveArray,
) -> VortexResult<PrimitiveArray> {
) -> VortexResult<(PrimitiveArray, PrimitiveArray)> {
let adjusted_indices = match_each_integer_ptype!(indices.ptype(), |$P| {
indices.typed_data::<$P>()
.iter()
Expand All @@ -184,17 +190,22 @@ fn take_search_sorted(
.collect::<VortexResult<Vec<_>>>()?,
);
let taken_indices = flatten_primitive(&take(array.indices(), &physical_indices)?)?;
match_each_integer_ptype!(taken_indices.ptype(), |$P| {
Ok(PrimitiveArray::from(taken_indices
.typed_data::<$P>()
.iter()
.copied()
.zip_eq(adjusted_indices)
.zip_eq(physical_indices.typed_data::<u64>())
.filter(|((taken_idx, orig_idx), _)| *taken_idx as usize == *orig_idx)
.map(|(_, physical_idx)| *physical_idx)
.collect::<Vec<_>>()))
})
let (positions, patch_indices): (Vec<u64>, Vec<u64>) = match_each_integer_ptype!(taken_indices.ptype(), |$P| {
taken_indices
.typed_data::<$P>()
.iter()
.copied()
.enumerate()
.zip_eq(adjusted_indices)
.zip_eq(physical_indices.typed_data::<u64>())
.filter(|(((_, taken_idx), orig_idx), _)| *taken_idx as usize == *orig_idx)
.map(|(((i, _), _), physical_idx)| (i as u64, *physical_idx))
.unzip()
});
Ok((
PrimitiveArray::from(positions),
PrimitiveArray::from(patch_indices),
))
}

#[cfg(test)]
Expand Down Expand Up @@ -261,4 +272,32 @@ mod test {
[]
);
}

// Taking an unpatched index followed by a patched one must report the patch at its position
// in the request (1) rather than compacting it to the front of the result.
#[test]
fn ordered_take() {
    // Four patched slots in a logical array of length 100; the fill value is a nullable f64 null.
    let patch_indices = PrimitiveArray::from(vec![0u64, 37, 47, 99]).into_array();
    let patch_values = PrimitiveArray::from(vec![1.23f64, 0.47, 9.99, 3.5]).into_array();
    let sparse = SparseArray::new(
        patch_indices,
        patch_values,
        100,
        Scalar::null(&DType::Float(FloatWidth::_64, Nullability::Nullable)),
    );

    // 69 is not among the patch indices, 37 is; request them in that order.
    let requested = PrimitiveArray::from(vec![69, 37]);
    let taken = take(&sparse, &requested).unwrap();
    let result = taken.as_sparse();

    // The single surviving patch sits at request position 1 ...
    assert_eq!(result.indices().as_primitive().typed_data::<u64>(), [1]);
    // ... and carries the value that was stored for index 37.
    assert_eq!(result.values().as_primitive().typed_data::<f64>(), [0.47f64]);
    // The result is as long as the request, not as long as the number of matched patches.
    assert_eq!(taken.len(), 2);
}
}

0 comments on commit 79a0630

Please sign in to comment.