diff --git a/Cargo.toml b/Cargo.toml index e6c4908..492c823 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ members = [ "rstsr-blas-traits", "rstsr-test-manifest", "rstsr-linalg-traits", + "rstsr-hdf5", ] [workspace.package] @@ -30,6 +31,7 @@ rstsr-test-manifest = { path = "./rstsr-test-manifest", default-features = false rstsr-openblas = { path = "./rstsr-openblas", default-features = false, version = "0.2.0" } rstsr-blas-traits = { path = "./rstsr-blas-traits", default-features = false, version = "0.2.0" } rstsr-linalg-traits = { path = "./rstsr-linalg-traits", default-features = false, version = "0.2.0" } +rstsr-hdf5 = { path = "./rstsr-hdf5", default-features = false, version = "0.2.0" } # basic dependencies num = { version = "0.4" } thiserror = { version = "1.0" } @@ -41,12 +43,14 @@ rayon = { version = "1.10" } faer = { version = "0.19" } faer-ext = { version = "0.3" } faer-entity = { version = "0.19" } +hdf5-metno = { version = "0.10" } +ndarray = { version = "0.15" } +tokio = { version = "1.44", features = ["rt-multi-thread", "macros"] } # dev dependencies npyz = { version = "0.8", features = ["complex"] } anyhow = { version = "1.0" } rand = { version = "0.8" } approx = { version = "0.5" } -ndarray = { version = "0.15" } criterion = { version = "0.5" } cpu-time = { version = "1.0" } diff --git a/rstsr-core/src/device_cpu_serial/device.rs b/rstsr-core/src/device_cpu_serial/device.rs index be6b753..39ac5d2 100644 --- a/rstsr-core/src/device_cpu_serial/device.rs +++ b/rstsr-core/src/device_cpu_serial/device.rs @@ -44,39 +44,6 @@ impl DeviceStorageAPI for DeviceCpuSerial { let (raw, _) = storage.into_raw_parts(); Ok(raw.into_owned().into_raw()) } - - #[inline] - fn get_index(storage: &Storage, index: usize) -> T - where - T: Clone, - R: DataAPI, - { - storage.raw()[index].clone() - } - - #[inline] - fn get_index_ptr(storage: &Storage, index: usize) -> *const T - where - R: DataAPI, - { - storage.raw().get(index).unwrap() as *const T - } - - #[inline] - fn get_index_mut_ptr(storage: &mut Storage, index: usize) -> *mut T - where - R: DataMutAPI, - { - storage.raw_mut().get_mut(index).unwrap() as *mut T - } - - #[inline] - fn set_index(storage: &mut Storage, index: usize, value: T) - where - R: DataMutAPI, - { - storage.raw_mut()[index] = value; - } } impl DeviceAPI for DeviceCpuSerial {} diff --git a/rstsr-core/src/device_faer/device.rs b/rstsr-core/src/device_faer/device.rs index e275f30..9ddd0b9 100644 --- a/rstsr-core/src/device_faer/device.rs +++ b/rstsr-core/src/device_faer/device.rs @@ -76,39 +76,6 @@ impl DeviceStorageAPI for DeviceFaer { let (raw, _) = storage.into_raw_parts(); Ok(raw.into_owned().into_raw()) } - - #[inline] - fn get_index(storage: &Storage, index: usize) -> T - where - T: Clone, - R: DataAPI, - { - storage.raw()[index].clone() - } - - #[inline] - fn get_index_ptr(storage: &Storage, index: usize) -> *const T - where - R: DataAPI, - { - &storage.raw()[index] as *const T - } - - #[inline] - fn get_index_mut_ptr(storage: &mut Storage, index: usize) -> *mut T - where - R: DataMutAPI, - { - storage.raw_mut().get_mut(index).unwrap() as *mut T - } - - #[inline] - fn set_index(storage: &mut Storage, index: usize, value: T) - where - R: DataMutAPI, - { - storage.raw_mut()[index] = value; - } } impl DeviceAPI for DeviceFaer {} diff --git a/rstsr-core/src/layout/broadcast.rs b/rstsr-core/src/layout/broadcast.rs index 3e021d6..b104dd1 100644 --- a/rstsr-core/src/layout/broadcast.rs +++ b/rstsr-core/src/layout/broadcast.rs @@ -239,8 +239,10 @@ where /// Check whether current layout has been broadcasted. /// /// This check is done by checking whether any stride of axis is zero. + /// Additionally, for certain dimension, if (shape, stride) = (1, 0), then + /// this dimension is not considered as broadcasted. pub fn is_broadcasted(&self) -> bool { - self.stride().as_ref().contains(&0) + self.stride().as_ref().iter().zip(self.shape().as_ref()).any(|(&s, &d)| s == 0 && d != 1) } } diff --git a/rstsr-core/src/layout/indexer.rs b/rstsr-core/src/layout/indexer.rs index 0bc34f7..f4a533a 100644 --- a/rstsr-core/src/layout/indexer.rs +++ b/rstsr-core/src/layout/indexer.rs @@ -13,6 +13,11 @@ pub enum Indexer { Insert, /// Expand dimensions. Ellipsis, + /// Broadcast dimensions. + /// + /// This option is not designed to be used by general users. `Broadcast(1)` + /// is equilvalent to `Insert`. + Broadcast(usize), } pub use Indexer::Ellipsis; @@ -273,6 +278,10 @@ pub trait IndexerLargerOneAPI { /// Insert dimension after, with shape 1. Number of dimension will increase /// by 1. fn dim_insert(&self, axis: isize) -> Result>; + + /// Insert dimension after, with shape broadcasted. Number of dimension will + /// increase by 1. + fn dim_broadcast(&self, axis: isize, num: usize) -> Result>; } impl IndexerLargerOneAPI for Layout @@ -313,6 +322,24 @@ where let layout = Layout::new(shape, stride, offset)?; return layout.into_dim(); } + + fn dim_broadcast(&self, axis: isize, num: usize) -> Result> { + // dimension check + let axis = if axis < 0 { self.ndim() as isize + axis + 1 } else { axis }; + rstsr_pattern!(axis, 0..(self.ndim() + 1) as isize, ValueOutOfRange)?; + let axis = axis as usize; + + // get essential information + let mut shape = self.shape().as_ref().to_vec(); + let mut stride = self.stride().as_ref().to_vec(); + let offset = self.offset(); + + shape.insert(axis, num); + stride.insert(axis, 0); + + let layout = Layout::new(shape, stride, offset)?; + return layout.into_dim(); + } } pub trait IndexerDynamicAPI: IndexerPreserveAPI { @@ -350,7 +377,7 @@ where Some(_) => rstsr_raise!(InvalidValue, "Only one ellipsis indexer allowed.")?, None => idx_ellipsis = Some(n), }, - _ => {}, + Indexer::Insert | Indexer::Broadcast(_) => {}, } } @@ -392,6 +419,9 @@ where Indexer::Insert => { layout = layout.dim_insert(cur_dim)?; }, + Indexer::Broadcast(num) => { + layout = layout.dim_broadcast(cur_dim, *num)?; + }, _ => rstsr_raise!(InvalidValue, "Invalid indexer found : {:?}", indexer)?, } } diff --git a/rstsr-core/src/layout/mod.rs b/rstsr-core/src/layout/mod.rs index 7bd0471..4998227 100644 --- a/rstsr-core/src/layout/mod.rs +++ b/rstsr-core/src/layout/mod.rs @@ -5,6 +5,7 @@ pub mod iterator; pub mod layoutbase; pub mod matmul; pub mod rearrangement; +pub mod reflection; pub mod reshape; pub mod shape; pub mod slice; @@ -17,6 +18,7 @@ pub use iterator::*; pub use layoutbase::*; pub use matmul::*; pub use rearrangement::*; +pub use reflection::*; pub use reshape::*; pub use shape::*; pub use slice::*; diff --git a/rstsr-core/src/layout/reflection.rs b/rstsr-core/src/layout/reflection.rs new file mode 100644 index 0000000..7aa3e7e --- /dev/null +++ b/rstsr-core/src/layout/reflection.rs @@ -0,0 +1,120 @@ +//! Layout/Indexing/Slicing reflection (inverse mapping) utilities. +//! +//! Note this name does not mean reflection (反射) of program. It just means the +//! inverse (逆 ~ 反) mapping (射). + +use crate::prelude_dev::*; + +/// Deduce the n-dimensional index from the offset in the layout. +/// +/// Tensor indexing is passing the index to layout, and returns the offset (as +/// memory shift to some anchor). +/// +/// If offset is obtained by indexing the layout: +/// `offset = layout.index(index)` +/// Then this function tries to obtain the inverse mapping `index` from +/// `layout`. +/// +/// # Note +/// +/// The inverse mapping may not be unique. +/// +/// Restrictions: +/// - Layout should not be broadcasted. +/// +/// This function should not be applied in computationally intensive part. +pub fn layour_reflect_index(layout: &Layout, offset: usize) -> Result { + layout.check_strides()?; + layout.bounds_index()?; + if layout.is_broadcasted() { + rstsr_raise!(InvalidLayout, "Layout should not be broadcasted in `layour_reflect_index`.")?; + } + + // 1. Prepare the result + let mut index: Vec = vec![0; layout.ndim()]; + + // 2. Get the location vector of strides, with the largest absolute stride to be + // the first Also, only shape != 1 leaves, other cases are ignored. + let arg_stride = layout + .stride() + .iter() + .enumerate() + .filter(|a| layout.shape()[a.0] != 1) + .sorted_by(|a, b| b.1.abs().cmp(&a.1.abs())) + .collect_vec(); + + // 3. Calculate the index + let mut inner_offset = offset as isize - layout.offset() as isize; + for (n, &(i, &s)) in arg_stride.iter().enumerate() { + let q = inner_offset.unsigned_abs() / s.unsigned_abs(); + let r = inner_offset.unsigned_abs() % s.unsigned_abs(); + // we can not tolarate if `q != 0` and sign of `inner_offset` is not the same to + // stride `s`. If this happens, index is negative, which is not allowed. + if q != 0 && inner_offset * s < 0 { + rstsr_raise!(InvalidValue, "Negative index occured.")?; + } + // offset is divisible by stride + // then the leaving index should be 0, so early break + if r == 0 { + index[i] = q; + break; + } + // if last element not divisible by stride + // then the provided offset is invalid + if n == arg_stride.len() - 1 { + rstsr_raise!(InvalidValue, "Offset is not divisible by the smallest stride.")?; + } + // generate next inner_offset + // next inner_offset must have the same sign with the next stride + // so the `q` given by modular division may have to increase by 1 + inner_offset -= s * q as isize; + index[i] = q; + if inner_offset.is_negative() != arg_stride[n + 1].1.is_negative() { + inner_offset -= s; + index[i] += 1; + } + } + + // 4. Before return the index, we should check if the given index can recover + // the input offset. + let offset_recover = layout.index_f(&index.iter().map(|&i| i as isize).collect_vec())?; + if offset_recover != offset { + rstsr_raise!( + RuntimeError, + "The given offset can not be recovered by the index, may be an internal bug." + )?; + } + + return Ok(index); +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_layour_reflect_index() { + let layout = vec![10, 12, 1, 15].f().dim_narrow(1, slice!(10, 2, -1)).unwrap(); + println!("{:?}", layout); + + // test usual case + let offset = layout.index(&[2, 3, 0, 5]); + let index = layour_reflect_index(&layout, offset).unwrap(); + assert_eq!(index, [2, 3, 0, 5]); + + // test another usual case + let offset = layout.index(&[2, 1, 0, 0]); + let index = layour_reflect_index(&layout, offset).unwrap(); + assert_eq!(index, [2, 1, 0, 0]); + + // test early stop case + let offset = layout.index(&[0, 0, 0, 3]); + let index = layour_reflect_index(&layout, offset).unwrap(); + assert_eq!(index, [0, 0, 0, 3]); + + // test case should failed (dim-1 out of bound) + let offset = [10, 12, 1, 15].f().index(&[2, 11, 0, 3]); + let index = layour_reflect_index(&layout, offset); + assert!(index.is_err()); + } +} diff --git a/rstsr-core/src/storage/device.rs b/rstsr-core/src/storage/device.rs index c2c6383..9dd2c1c 100644 --- a/rstsr-core/src/storage/device.rs +++ b/rstsr-core/src/storage/device.rs @@ -36,19 +36,6 @@ pub trait DeviceStorageAPI: DeviceRawAPI { where Self::Raw: Clone, R: DataCloneAPI; - fn get_index(storage: &Storage, index: usize) -> T - where - T: Clone, - R: DataAPI; - fn get_index_ptr(storage: &Storage, index: usize) -> *const T - where - R: DataAPI; - fn get_index_mut_ptr(storage: &mut Storage, index: usize) -> *mut T - where - R: DataMutAPI; - fn set_index(storage: &mut Storage, index: usize, value: T) - where - R: DataMutAPI; } impl Storage @@ -99,35 +86,6 @@ where { B::into_cpu_vec(self) } - - #[inline] - pub fn get_index(&self, index: usize) -> T - where - T: Clone, - { - B::get_index(self, index) - } - - #[inline] - pub fn get_index_ptr(&self, index: usize) -> *const T { - B::get_index_ptr(self, index) - } - - #[inline] - pub fn get_index_mut_ptr(&mut self, index: usize) -> *mut T - where - R: DataMutAPI, - { - B::get_index_mut_ptr(self, index) - } - - #[inline] - pub fn set_index(&mut self, index: usize, value: T) - where - R: DataMutAPI, - { - B::set_index(self, index, value) - } } impl Storage diff --git a/rstsr-hdf5/Cargo.toml b/rstsr-hdf5/Cargo.toml new file mode 100644 index 0000000..d4d793e --- /dev/null +++ b/rstsr-hdf5/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "rstsr-hdf5" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true + +[dependencies] +rstsr-core = { workspace = true, default-features = false } +hdf5-metno = { workspace = true } +ndarray = { workspace = true } diff --git a/rstsr-hdf5/readme.md b/rstsr-hdf5/readme.md new file mode 100644 index 0000000..e82b27b --- /dev/null +++ b/rstsr-hdf5/readme.md @@ -0,0 +1,3 @@ +# rstsr-hdf5 + +This crate is to make hdf5 file as raw data of tensor, making some simple I/O operations available for RSTSR. \ No newline at end of file diff --git a/rstsr-hdf5/src/asarray.rs b/rstsr-hdf5/src/asarray.rs new file mode 100644 index 0000000..b07f26a --- /dev/null +++ b/rstsr-hdf5/src/asarray.rs @@ -0,0 +1,25 @@ +use rstsr_core::storage; + +use crate::{device::DeviceHDF5, prelude_dev::*}; + +/// A dummy struct that involves `DeviceHDF5` and `T`. +/// +/// This struct is used to implement `AsArrayAPI` for `Dataset` to avoid orphan +/// rule restriction. +pub struct DeviceHDF5WithDType { + _phantom_hdf5: std::marker::PhantomData, + _phantom_t: std::marker::PhantomData, +} + +impl AsArrayAPI> for Dataset +where + T: H5Type, +{ + type Out = Tensor; + + fn asarray_f(self) -> Result { + let layout = self.shape().c(); + let storage = storage::Storage::new(self.into(), DeviceHDF5); + Ok(Tensor::new(storage, layout)) + } +} diff --git a/rstsr-hdf5/src/device.rs b/rstsr-hdf5/src/device.rs new file mode 100644 index 0000000..f74cb5b --- /dev/null +++ b/rstsr-hdf5/src/device.rs @@ -0,0 +1,44 @@ +use crate::prelude_dev::*; + +#[derive(Clone, Debug, Default)] +pub struct DeviceHDF5; + +impl DeviceBaseAPI for DeviceHDF5 { + fn same_device(&self, _other: &Self) -> bool { + true + } +} + +impl DeviceRawAPI for DeviceHDF5 { + type Raw = Dataset; +} + +impl DeviceStorageAPI for DeviceHDF5 +where + T: H5Type, +{ + fn len(storage: &Storage) -> usize + where + R: DataAPI, + { + storage.raw().size() + } + + fn to_cpu_vec(storage: &Storage) -> Result> + where + Self::Raw: Clone, + R: DataCloneAPI, + { + storage.raw().read_raw::().map_err(|e| Error::DeviceError(e.to_string())) + } + + fn into_cpu_vec(storage: Storage) -> Result> + where + Self::Raw: Clone, + R: DataCloneAPI, + { + Self::to_cpu_vec(&storage) + } +} + +impl DeviceAPI for DeviceHDF5 where T: H5Type {} diff --git a/rstsr-hdf5/src/lib.rs b/rstsr-hdf5/src/lib.rs new file mode 100644 index 0000000..775d60b --- /dev/null +++ b/rstsr-hdf5/src/lib.rs @@ -0,0 +1,8 @@ +#![allow(clippy::needless_return)] +#![allow(non_camel_case_types)] +#![doc = include_str!("../readme.md")] + +pub mod asarray; +pub mod device; +pub mod prelude_dev; +pub mod read; diff --git a/rstsr-hdf5/src/prelude_dev.rs b/rstsr-hdf5/src/prelude_dev.rs new file mode 100644 index 0000000..24c3481 --- /dev/null +++ b/rstsr-hdf5/src/prelude_dev.rs @@ -0,0 +1,7 @@ +#![allow(unused_imports)] + +pub(crate) use rstsr_core::prelude_dev::*; + +pub(crate) use hdf5_metno::{Dataset, H5Type}; + +pub use crate::device::DeviceHDF5; diff --git a/rstsr-hdf5/src/read.rs b/rstsr-hdf5/src/read.rs new file mode 100644 index 0000000..a98f878 --- /dev/null +++ b/rstsr-hdf5/src/read.rs @@ -0,0 +1,9 @@ +use crate::prelude_dev::*; + +pub fn read_slice_to_cpu(_dataset: &Dataset, _layout: &D) -> Result<(Vec, Layout)> +where + T: H5Type, + D: DimAPI, +{ + todo!() +} diff --git a/rstsr-openblas/src/device.rs b/rstsr-openblas/src/device.rs index f12f99a..08a76d9 100644 --- a/rstsr-openblas/src/device.rs +++ b/rstsr-openblas/src/device.rs @@ -69,39 +69,6 @@ impl DeviceStorageAPI for DeviceBLAS { let (raw, _) = storage.into_raw_parts(); Ok(raw.into_owned().into_raw()) } - - #[inline] - fn get_index(storage: &Storage, index: usize) -> T - where - T: Clone, - R: DataAPI, - { - storage.raw()[index].clone() - } - - #[inline] - fn get_index_ptr(storage: &Storage, index: usize) -> *const T - where - R: DataAPI, - { - &storage.raw()[index] as *const T - } - - #[inline] - fn get_index_mut_ptr(storage: &mut Storage, index: usize) -> *mut T - where - R: DataMutAPI, - { - storage.raw_mut().get_mut(index).unwrap() as *mut T - } - - #[inline] - fn set_index(storage: &mut Storage, index: usize, value: T) - where - R: DataMutAPI, - { - storage.raw_mut()[index] = value; - } } impl DeviceAPI for DeviceBLAS {}