diff --git a/rust/geoarrow/src/algorithm/native/take.rs b/rust/geoarrow/src/algorithm/native/take.rs index da4e8995..9e427dd2 100644 --- a/rust/geoarrow/src/algorithm/native/take.rs +++ b/rust/geoarrow/src/algorithm/native/take.rs @@ -55,6 +55,58 @@ impl Take for PointArray { } } +// Note that GeometryArray's builder parameters differ from other native array types which means +// it cannot use the macro to build a Take impl +impl Take for GeometryArray { + type Output = Result; + + fn take(&self, indices: &UInt32Array) -> Self::Output { + let mut capacity = GeometryCapacity::new_empty(DEFAULT_PREFER_MULTI); + + for index in indices.iter().flatten() { + capacity.add_geometry(self.get(index.as_usize()).as_ref())?; + } + + let mut builder = GeometryBuilder::with_capacity_and_options( + capacity, + self.coord_type(), + self.metadata(), + DEFAULT_PREFER_MULTI, + ); + + for index in indices.iter() { + if let Some(index) = index { + builder.push_geometry(self.get(index.as_usize()).as_ref())?; + } else { + builder.push_null(); + } + } + + Ok(builder.finish()) + } + + fn take_range(&self, range: &Range) -> Self::Output { + let mut capacity = GeometryCapacity::new_empty(DEFAULT_PREFER_MULTI); + + for i in range.start..range.end { + capacity.add_geometry(self.get(i).as_ref())?; + } + + let mut builder = GeometryBuilder::with_capacity_and_options( + capacity, + self.coord_type(), + self.metadata(), + DEFAULT_PREFER_MULTI, + ); + + for i in range.start..range.end { + builder.push_geometry(self.get(i).as_ref())?; + } + + Ok(builder.finish()) + } +} + // TODO: parameterize over input and output separately macro_rules! take_impl { @@ -232,6 +284,7 @@ impl Take for &dyn NativeArray { MultiPoint(_, XY) => Arc::new(self.as_multi_point().take(indices)?), MultiLineString(_, XY) => Arc::new(self.as_multi_line_string().take(indices)?), MultiPolygon(_, XY) => Arc::new(self.as_multi_polygon().take(indices)?), + Geometry(_) => Arc::new(self.as_geometry().take(indices)?), GeometryCollection(_, XY) => Arc::new(self.as_geometry_collection().take(indices)?), _ => return Err(GeoArrowError::IncorrectType("".into())), }; @@ -249,6 +302,7 @@ impl Take for &dyn NativeArray { MultiPoint(_, XY) => Arc::new(self.as_multi_point().take_range(range)?), MultiLineString(_, XY) => Arc::new(self.as_multi_line_string().take_range(range)?), MultiPolygon(_, XY) => Arc::new(self.as_multi_polygon().take_range(range)?), + Geometry(_) => Arc::new(self.as_geometry().take_range(range)?), GeometryCollection(_, XY) => Arc::new(self.as_geometry_collection().take_range(range)?), _ => return Err(GeoArrowError::IncorrectType("".into())), }; @@ -312,3 +366,56 @@ chunked_impl!(ChunkedGeometryArray); chunked_impl!(ChunkedGeometryArray); chunked_impl!(ChunkedGeometryArray); chunked_impl!(ChunkedGeometryArray); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn geometry_take_impl() -> Result<()> { + let indices: UInt32Array = vec![0, 2].into(); + let point = geo::point!(x: 0., y: 0.); + let ls: geo::LineString = vec![(1., 1.), (2., 2.)].into(); + + let mut geo_array = GeometryBuilder::new(); + geo_array.push_geometry(Some(&point))?; + geo_array.push_geometry(Some(&geo::point!(x: 1., y: 1.)))?; + geo_array.push_geometry(Some(&ls))?; + let geo_array = geo_array.finish(); + + let take_array = geo_array.take(&indices)?; + assert_eq!( + 2, + take_array.len(), + "take resulted in an unexpected number of items" + ); + assert_eq!(take_array.value(0), point); + assert_eq!(take_array.value(1), ls); + + Ok(()) + } + + #[test] + fn geometry_take_range_impl() -> Result<()> { + let point = geo::point!(x: 0., y: 0.); + let ls: geo::LineString = vec![(1., 1.), (2., 2.)].into(); + + let mut geo_array = GeometryBuilder::new(); + geo_array.push_geometry(Some(&geo::point!(x: 1., y: 1.)))?; + geo_array.push_geometry(Some(&point))?; + geo_array.push_geometry(Some(&ls))?; + let geo_array = geo_array.finish(); + + let range = 1..geo_array.len(); + let take_array = geo_array.take_range(&range)?; + assert_eq!( + 2, + take_array.len(), + "take range resulted in an unexpected number of items" + ); + assert_eq!(take_array.value(0), point); + assert_eq!(take_array.value(1), ls); + + Ok(()) + } +}