HDF5 device #11

Draft · wants to merge 5 commits into base: master
6 changes: 5 additions & 1 deletion Cargo.toml
@@ -9,6 +9,7 @@ members = [
"rstsr-blas-traits",
"rstsr-test-manifest",
"rstsr-linalg-traits",
"rstsr-hdf5",
]

[workspace.package]
@@ -30,6 +31,7 @@ rstsr-test-manifest = { path = "./rstsr-test-manifest", default-features = false
rstsr-openblas = { path = "./rstsr-openblas", default-features = false, version = "0.2.0" }
rstsr-blas-traits = { path = "./rstsr-blas-traits", default-features = false, version = "0.2.0" }
rstsr-linalg-traits = { path = "./rstsr-linalg-traits", default-features = false, version = "0.2.0" }
rstsr-hdf5 = { path = "./rstsr-hdf5", default-features = false, version = "0.2.0" }
# basic dependencies
num = { version = "0.4" }
thiserror = { version = "1.0" }
@@ -41,12 +43,14 @@ rayon = { version = "1.10" }
faer = { version = "0.19" }
faer-ext = { version = "0.3" }
faer-entity = { version = "0.19" }
hdf5-metno = { version = "0.10" }
ndarray = { version = "0.15" }
tokio = { version = "1.44", features = ["rt-multi-thread", "macros"] }
# dev dependencies
npyz = { version = "0.8", features = ["complex"] }
anyhow = { version = "1.0" }
rand = { version = "0.8" }
approx = { version = "0.5" }
ndarray = { version = "0.15" }
criterion = { version = "0.5" }
cpu-time = { version = "1.0" }

33 changes: 0 additions & 33 deletions rstsr-core/src/device_cpu_serial/device.rs
@@ -44,39 +44,6 @@ impl<T> DeviceStorageAPI<T> for DeviceCpuSerial {
let (raw, _) = storage.into_raw_parts();
Ok(raw.into_owned().into_raw())
}

#[inline]
fn get_index<R>(storage: &Storage<R, T, Self>, index: usize) -> T
where
T: Clone,
R: DataAPI<Data = Self::Raw>,
{
storage.raw()[index].clone()
}

#[inline]
fn get_index_ptr<R>(storage: &Storage<R, T, Self>, index: usize) -> *const T
where
R: DataAPI<Data = Self::Raw>,
{
storage.raw().get(index).unwrap() as *const T
}

#[inline]
fn get_index_mut_ptr<R>(storage: &mut Storage<R, T, Self>, index: usize) -> *mut T
where
R: DataMutAPI<Data = Self::Raw>,
{
storage.raw_mut().get_mut(index).unwrap() as *mut T
}

#[inline]
fn set_index<R>(storage: &mut Storage<R, T, Self>, index: usize, value: T)
where
R: DataMutAPI<Data = Self::Raw>,
{
storage.raw_mut()[index] = value;
}
}

impl<T> DeviceAPI<T> for DeviceCpuSerial {}
33 changes: 0 additions & 33 deletions rstsr-core/src/device_faer/device.rs
@@ -76,39 +76,6 @@ impl<T> DeviceStorageAPI<T> for DeviceFaer {
let (raw, _) = storage.into_raw_parts();
Ok(raw.into_owned().into_raw())
}

#[inline]
fn get_index<R>(storage: &Storage<R, T, Self>, index: usize) -> T
where
T: Clone,
R: DataAPI<Data = Self::Raw>,
{
storage.raw()[index].clone()
}

#[inline]
fn get_index_ptr<R>(storage: &Storage<R, T, Self>, index: usize) -> *const T
where
R: DataAPI<Data = Self::Raw>,
{
&storage.raw()[index] as *const T
}

#[inline]
fn get_index_mut_ptr<R>(storage: &mut Storage<R, T, Self>, index: usize) -> *mut T
where
R: DataMutAPI<Data = Self::Raw>,
{
storage.raw_mut().get_mut(index).unwrap() as *mut T
}

#[inline]
fn set_index<R>(storage: &mut Storage<R, T, Self>, index: usize, value: T)
where
R: DataMutAPI<Data = Self::Raw>,
{
storage.raw_mut()[index] = value;
}
}

impl<T> DeviceAPI<T> for DeviceFaer {}
4 changes: 3 additions & 1 deletion rstsr-core/src/layout/broadcast.rs
@@ -239,8 +239,10 @@ where
/// Check whether current layout has been broadcasted.
///
/// This check is done by checking whether any stride of axis is zero.
/// Additionally, if a dimension has (shape, stride) = (1, 0), then that
/// dimension is not considered broadcasted.
pub fn is_broadcasted(&self) -> bool {
self.stride().as_ref().contains(&0)
self.stride().as_ref().iter().zip(self.shape().as_ref()).any(|(&s, &d)| s == 0 && d != 1)
}
}

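For reference, a minimal sketch of the new `is_broadcasted` semantics, written against names that appear in this diff (`Layout::new(shape, stride, offset)` and `is_broadcasted`); the import path is an assumption and may differ for external users:

```rust
use rstsr_core::prelude_dev::*;

fn broadcast_check_sketch() -> Result<()> {
    // (shape, stride) = (1, 0): no longer reported as broadcasted after this change.
    let trivial = Layout::new(vec![3, 1], vec![1, 0], 0)?;
    assert!(!trivial.is_broadcasted());

    // shape > 1 with stride 0: a genuine broadcast, still detected.
    let broadcasted = Layout::new(vec![3, 4], vec![1, 0], 0)?;
    assert!(broadcasted.is_broadcasted());
    Ok(())
}
```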
32 changes: 31 additions & 1 deletion rstsr-core/src/layout/indexer.rs
@@ -13,6 +13,11 @@ pub enum Indexer {
Insert,
/// Expand dimensions.
Ellipsis,
/// Broadcast dimensions.
///
/// This option is not designed to be used by general users. `Broadcast(1)`
/// is equivalent to `Insert`.
Broadcast(usize),
}

pub use Indexer::Ellipsis;
@@ -273,6 +278,10 @@ pub trait IndexerLargerOneAPI {
/// Insert dimension after, with shape 1. Number of dimension will increase
/// by 1.
fn dim_insert(&self, axis: isize) -> Result<Layout<Self::DOut>>;

/// Insert dimension after, with the new axis broadcasted (stride 0, extent
/// `num`). The number of dimensions will increase by 1.
fn dim_broadcast(&self, axis: isize, num: usize) -> Result<Layout<Self::DOut>>;
}

impl<D> IndexerLargerOneAPI for Layout<D>
@@ -313,6 +322,24 @@ where
let layout = Layout::new(shape, stride, offset)?;
return layout.into_dim();
}

fn dim_broadcast(&self, axis: isize, num: usize) -> Result<Layout<Self::DOut>> {
// dimension check
let axis = if axis < 0 { self.ndim() as isize + axis + 1 } else { axis };
rstsr_pattern!(axis, 0..(self.ndim() + 1) as isize, ValueOutOfRange)?;
let axis = axis as usize;

// get essential information
let mut shape = self.shape().as_ref().to_vec();
let mut stride = self.stride().as_ref().to_vec();
let offset = self.offset();

shape.insert(axis, num);
stride.insert(axis, 0);

let layout = Layout::new(shape, stride, offset)?;
return layout.into_dim();
}
}

pub trait IndexerDynamicAPI: IndexerPreserveAPI {
@@ -350,7 +377,7 @@
Some(_) => rstsr_raise!(InvalidValue, "Only one ellipsis indexer allowed.")?,
None => idx_ellipsis = Some(n),
},
_ => {},
Indexer::Insert | Indexer::Broadcast(_) => {},
}
}

@@ -392,6 +419,9 @@
Indexer::Insert => {
layout = layout.dim_insert(cur_dim)?;
},
Indexer::Broadcast(num) => {
layout = layout.dim_broadcast(cur_dim, *num)?;
},
_ => rstsr_raise!(InvalidValue, "Invalid indexer found : {:?}", indexer)?,
}
}
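A minimal sketch of what the new `dim_broadcast` does, following the implementation above: it inserts an axis of extent `num` with stride 0 at the given position. The constructor call and import path are assumptions based on this diff:

```rust
use rstsr_core::prelude_dev::*;

fn dim_broadcast_sketch() -> Result<()> {
    // A 2 x 3 C-contiguous layout: shape [2, 3], stride [3, 1].
    let layout = Layout::new(vec![2, 3], vec![3, 1], 0)?;

    // Broadcast 4 copies along a new leading axis.
    let broadcasted = layout.dim_broadcast(0, 4)?;
    assert_eq!(broadcasted.shape().as_ref().to_vec(), vec![4, 2, 3]);
    assert_eq!(broadcasted.stride().as_ref().to_vec(), vec![0, 3, 1]);
    Ok(())
}
```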
2 changes: 2 additions & 0 deletions rstsr-core/src/layout/mod.rs
@@ -5,6 +5,7 @@ pub mod iterator;
pub mod layoutbase;
pub mod matmul;
pub mod rearrangement;
pub mod reflection;
pub mod reshape;
pub mod shape;
pub mod slice;
@@ -17,6 +18,7 @@ pub use iterator::*;
pub use layoutbase::*;
pub use matmul::*;
pub use rearrangement::*;
pub use reflection::*;
pub use reshape::*;
pub use shape::*;
pub use slice::*;
120 changes: 120 additions & 0 deletions rstsr-core/src/layout/reflection.rs
@@ -0,0 +1,120 @@
//! Layout/Indexing/Slicing reflection (inverse mapping) utilities.
//!
//! Note that this name does not mean reflection (反射) as in program reflection;
//! it just means the inverse (逆 ~ 反) mapping (射).

use crate::prelude_dev::*;

/// Deduce the n-dimensional index from the offset in the layout.
///
/// Tensor indexing passes an index to the layout and returns the offset (a
/// memory shift from some anchor).
///
/// If the offset is obtained by indexing the layout,
/// `offset = layout.index(index)`,
/// then this function tries to recover the inverse mapping `index` from
/// `layout` and `offset`.
///
/// # Note
///
/// The inverse mapping may not be unique.
///
/// Restrictions:
/// - Layout should not be broadcasted.
///
/// This function should not be used in computationally intensive code paths.
pub fn layour_reflect_index(layout: &Layout<IxD>, offset: usize) -> Result<IxD> {
layout.check_strides()?;
layout.bounds_index()?;
if layout.is_broadcasted() {
rstsr_raise!(InvalidLayout, "Layout should not be broadcasted in `layour_reflect_index`.")?;
}

// 1. Prepare the result
let mut index: Vec<usize> = vec![0; layout.ndim()];

// 2. Collect (axis, stride) pairs sorted by absolute stride, largest first.
// Only axes with shape != 1 are kept; the others are ignored.
let arg_stride = layout
.stride()
.iter()
.enumerate()
.filter(|a| layout.shape()[a.0] != 1)
.sorted_by(|a, b| b.1.abs().cmp(&a.1.abs()))
.collect_vec();

// 3. Calculate the index
let mut inner_offset = offset as isize - layout.offset() as isize;
for (n, &(i, &s)) in arg_stride.iter().enumerate() {
let q = inner_offset.unsigned_abs() / s.unsigned_abs();
let r = inner_offset.unsigned_abs() % s.unsigned_abs();
// We cannot tolerate `q != 0` when the sign of `inner_offset` differs from the
// sign of stride `s`; that would give a negative index, which is not allowed.
if q != 0 && inner_offset * s < 0 {
rstsr_raise!(InvalidValue, "Negative index occured.")?;
}
// If the offset is divisible by the stride, the remaining indices should all
// be zero, so break early.
if r == 0 {
index[i] = q;
break;
}
// If the remainder is still nonzero at the last (smallest) stride,
// the provided offset is invalid.
if n == arg_stride.len() - 1 {
rstsr_raise!(InvalidValue, "Offset is not divisible by the smallest stride.")?;
}
// Generate the next inner_offset.
// The next inner_offset must have the same sign as the next stride,
// so the `q` given by integer division may have to be increased by 1.
inner_offset -= s * q as isize;
index[i] = q;
if inner_offset.is_negative() != arg_stride[n + 1].1.is_negative() {
inner_offset -= s;
index[i] += 1;
}
}

// 4. Before returning the index, check that it actually recovers the input
// offset.
let offset_recover = layout.index_f(&index.iter().map(|&i| i as isize).collect_vec())?;
if offset_recover != offset {
rstsr_raise!(
RuntimeError,
"The given offset can not be recovered by the index, may be an internal bug."
)?;
}

return Ok(index);
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_layour_reflect_index() {
let layout = vec![10, 12, 1, 15].f().dim_narrow(1, slice!(10, 2, -1)).unwrap();
println!("{:?}", layout);

// test usual case
let offset = layout.index(&[2, 3, 0, 5]);
let index = layour_reflect_index(&layout, offset).unwrap();
assert_eq!(index, [2, 3, 0, 5]);

// test another usual case
let offset = layout.index(&[2, 1, 0, 0]);
let index = layour_reflect_index(&layout, offset).unwrap();
assert_eq!(index, [2, 1, 0, 0]);

// test early stop case
let offset = layout.index(&[0, 0, 0, 3]);
let index = layour_reflect_index(&layout, offset).unwrap();
assert_eq!(index, [0, 0, 0, 3]);

// test case that should fail (index along dim 1 out of bounds)
let offset = [10, 12, 1, 15].f().index(&[2, 11, 0, 3]);
let index = layour_reflect_index(&layout, offset);
assert!(index.is_err());
}
}
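A round-trip sketch mirroring the tests above (forward `index`, then the inverse `layour_reflect_index`); the `.f()` constructor and `index` call are taken from the tests, everything else is illustrative:

```rust
use rstsr_core::prelude_dev::*;

fn reflect_round_trip_sketch() {
    // Build a column-major (Fortran-order) dynamic layout, as in the tests above.
    let layout = vec![4, 5, 6].f();

    // Forward mapping: n-dimensional index -> memory offset.
    let offset = layout.index(&[1, 2, 3]);

    // Inverse mapping: memory offset -> n-dimensional index (errors on invalid offsets).
    let recovered = layour_reflect_index(&layout, offset).unwrap();
    assert_eq!(recovered, [1, 2, 3]);
}
```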
42 changes: 0 additions & 42 deletions rstsr-core/src/storage/device.rs
@@ -36,19 +36,6 @@ pub trait DeviceStorageAPI<T>: DeviceRawAPI<T> {
where
Self::Raw: Clone,
R: DataCloneAPI<Data = Self::Raw>;
fn get_index<R>(storage: &Storage<R, T, Self>, index: usize) -> T
where
T: Clone,
R: DataAPI<Data = Self::Raw>;
fn get_index_ptr<R>(storage: &Storage<R, T, Self>, index: usize) -> *const T
where
R: DataAPI<Data = Self::Raw>;
fn get_index_mut_ptr<R>(storage: &mut Storage<R, T, Self>, index: usize) -> *mut T
where
R: DataMutAPI<Data = Self::Raw>;
fn set_index<R>(storage: &mut Storage<R, T, Self>, index: usize, value: T)
where
R: DataMutAPI<Data = Self::Raw>;
}

impl<R, T, B> Storage<R, T, B>
@@ -99,35 +86,6 @@ where
{
B::into_cpu_vec(self)
}

#[inline]
pub fn get_index(&self, index: usize) -> T
where
T: Clone,
{
B::get_index(self, index)
}

#[inline]
pub fn get_index_ptr(&self, index: usize) -> *const T {
B::get_index_ptr(self, index)
}

#[inline]
pub fn get_index_mut_ptr(&mut self, index: usize) -> *mut T
where
R: DataMutAPI<Data = B::Raw>,
{
B::get_index_mut_ptr(self, index)
}

#[inline]
pub fn set_index(&mut self, index: usize, value: T)
where
R: DataMutAPI<Data = B::Raw>,
{
B::set_index(self, index, value)
}
}

impl<R, T, B> Storage<R, T, B>
14 changes: 14 additions & 0 deletions rstsr-hdf5/Cargo.toml
@@ -0,0 +1,14 @@
[package]
name = "rstsr-hdf5"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true

[dependencies]
rstsr-core = { workspace = true, default-features = false }
hdf5-metno = { workspace = true }
ndarray = { workspace = true }