Skip to content

Commit

Permalink
Merge pull request #62 from stackhpc/missing
Browse files Browse the repository at this point in the history
Add support for missing data
  • Loading branch information
markgoddard authored Jul 27, 2023
2 parents 691476e + 465920b commit eccd84b
Show file tree
Hide file tree
Showing 11 changed files with 985 additions and 30 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,18 @@ with a JSON payload of the form:
// List of algorithms used to filter the data
// - optional, defaults to no filters
"filters": [{"id": "shuffle", "element_size": 4}],
// Missing data description
// - optional, defaults to no missing data
// - exactly one of the keys below should be specified
// - the values should match the data type (dtype)
"missing": {
"missing_value": 42,
"missing_values": [42, -42],
"valid_min": 42,
"valid_max": 42,
"valid_range": [-42, 42],
}
}
```

Expand Down
27 changes: 27 additions & 0 deletions scripts/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,23 @@ def get_args() -> argparse.Namespace:
parser.add_argument("--selection", type=str)
parser.add_argument("--compression", type=str)
parser.add_argument("--shuffle", action=argparse.BooleanOptionalAction)
missing = parser.add_mutually_exclusive_group()
missing.add_argument("--missing-value", type=str)
missing.add_argument("--missing-values", type=str)
missing.add_argument("--valid-min", type=str)
missing.add_argument("--valid-max", type=str)
missing.add_argument("--valid-range", type=str)
parser.add_argument("--verbose", action=argparse.BooleanOptionalAction)
return parser.parse_args()


def parse_number(s: str):
try:
return int(s)
except ValueError:
return float(s)


def build_request_data(args: argparse.Namespace) -> dict:
request_data = {
'source': args.source,
Expand All @@ -65,6 +78,20 @@ def build_request_data(args: argparse.Namespace) -> dict:
filters.append({"id": "shuffle", "element_size": element_size})
if filters:
request_data["filters"] = filters
missing = None
if args.missing_value:
missing = {"missing_value": parse_number(args.missing_value)}
if args.missing_values:
missing = {"missing_values": [parse_number(n) for n in args.missing_values.split(",")]}
if args.valid_min:
missing = {"valid_min": parse_number(args.valid_min)}
if args.valid_max:
missing = {"valid_max": parse_number(args.valid_max)}
if args.valid_range:
min, max = args.valid_range.split(",")
missing = {"valid_range": [parse_number(min), parse_number(max)]}
if missing:
request_data["missing"] = missing
return {k: v for k, v in request_data.items() if v is not None}


Expand Down
15 changes: 15 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ use std::error::Error;
use thiserror::Error;
use tracing::{event, Level};

use crate::types::DValue;

/// Active Storage server error type
///
/// This type encapsulates the various errors that may occur.
Expand All @@ -34,6 +36,9 @@ pub enum ActiveStorageError {
#[error("failed to convert from bytes to {type_name}")]
FromBytes { type_name: &'static str },

#[error("Incompatible value {0} for missing")]
IncompatibleMissing(DValue),

/// Error deserialising request data into RequestData
#[error("request data is not valid")]
RequestDataJsonRejection(#[from] JsonRejection),
Expand Down Expand Up @@ -184,6 +189,7 @@ impl From<ActiveStorageError> for ErrorResponse {
// Bad request
ActiveStorageError::Decompression(_)
| ActiveStorageError::EmptyArray { operation: _ }
| ActiveStorageError::IncompatibleMissing(_)
| ActiveStorageError::RequestDataJsonRejection(_)
| ActiveStorageError::RequestDataValidationSingle(_)
| ActiveStorageError::RequestDataValidation(_)
Expand Down Expand Up @@ -345,6 +351,15 @@ mod tests {
.await;
}

#[tokio::test]
async fn incompatible_missing() {
let value = 32.into();
let error = ActiveStorageError::IncompatibleMissing(value);
let message = "Incompatible value 32 for missing";
let caused_by = None;
test_active_storage_error(error, StatusCode::BAD_REQUEST, message, caused_by).await;
}

#[tokio::test]
async fn request_data_validation_single() {
let validation_error = validator::ValidationError::new("foo");
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,5 @@ pub mod server;
#[cfg(test)]
pub mod test_utils;
pub mod tracing;
pub mod types;
pub mod validated_json;
120 changes: 116 additions & 4 deletions src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use strum_macros::Display;
use url::Url;
use validator::{Validate, ValidationError};

use crate::types::{DValue, Missing};

/// Supported numerical data types
#[derive(Clone, Copy, Debug, Deserialize, Display, PartialEq)]
#[serde(rename_all = "lowercase")]
Expand All @@ -14,11 +16,11 @@ pub enum DType {
Int32,
/// [i64]
Int64,
/// [u64]
/// [u32]
Uint32,
/// [u64]
Uint64,
/// [f64]
/// [f32]
Float32,
/// [f64]
Float64,
Expand Down Expand Up @@ -142,6 +144,8 @@ pub struct RequestData {
pub compression: Option<Compression>,
/// List of filter algorithms
pub filters: Option<Vec<Filter>>,
/// Missing data
pub missing: Option<Missing<DValue>>,
}

/// Validate an array shape
Expand Down Expand Up @@ -230,6 +234,9 @@ fn validate_request_data(request_data: &RequestData) -> Result<(), ValidationErr
}
_ => (),
};
if let Some(missing) = &request_data.missing {
missing.validate(request_data.dtype)?;
};
Ok(())
}

Expand Down Expand Up @@ -359,6 +366,11 @@ mod tests {
Token::U32(4),
Token::MapEnd,
Token::SeqEnd,
Token::Str("missing"),
Token::Some,
Token::Enum { name: "Missing" },
Token::Str("missing_value"),
Token::I32(42),
Token::StructEnd,
],
);
Expand Down Expand Up @@ -664,14 +676,40 @@ mod tests {
)
}

#[test]
fn test_invalid_missing() {
assert_de_tokens_error::<RequestData>(
&[
Token::Struct {
name: "RequestData",
len: 2,
},
Token::Str("missing"),
Token::Some,
Token::Enum { name: "Missing" },
Token::Str("foo"),
Token::StructEnd
],
"unknown variant `foo`, expected one of `missing_value`, `missing_values`, `valid_min`, `valid_max`, `valid_range`",
)
}

#[test]
#[should_panic(expected = "Incompatible value 9223372036854775807 for missing")]
fn test_missing_invalid_value_for_dtype() {
let mut request_data = test_utils::get_test_request_data();
request_data.missing = Some(Missing::MissingValue(i64::max_value().into()));
request_data.validate().unwrap()
}

#[test]
fn test_unknown_field() {
assert_de_tokens_error::<RequestData>(&[
Token::Struct { name: "RequestData", len: 2 },
Token::Str("foo"),
Token::StructEnd
],
"unknown field `foo`, expected one of `source`, `bucket`, `object`, `dtype`, `offset`, `size`, `shape`, `order`, `selection`, `compression`, `filters`"
"unknown field `foo`, expected one of `source`, `bucket`, `object`, `dtype`, `offset`, `size`, `shape`, `order`, `selection`, `compression`, `filters`, `missing`"
)
}

Expand All @@ -686,8 +724,82 @@ mod tests {

#[test]
fn test_json_optional_fields() {
let json = r#"{"source": "http://example.com", "bucket": "bar", "object": "baz", "dtype": "int32", "offset": 4, "size": 8, "shape": [2, 5], "order": "C", "selection": [[1, 2, 3], [4, 5, 6]], "compression": {"id": "gzip"}, "filters": [{"id": "shuffle", "element_size": 4}]}"#;
let json = r#"{
"source": "http://example.com",
"bucket": "bar",
"object": "baz",
"dtype": "int32",
"offset": 4,
"size": 8,
"shape": [2, 5],
"order": "C",
"selection": [[1, 2, 3], [4, 5, 6]],
"compression": {"id": "gzip"},
"filters": [{"id": "shuffle", "element_size": 4}],
"missing": {"missing_value": 42}
}"#;
let request_data = serde_json::from_str::<RequestData>(json).unwrap();
assert_eq!(request_data, test_utils::get_test_request_data_optional());
}

#[test]
fn test_json_optional_fields2() {
let json = r#"{
"source": "http://example.com",
"bucket": "bar",
"object": "baz",
"dtype": "float64",
"offset": 4,
"size": 8,
"shape": [2, 5, 10],
"order": "F",
"selection": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
"compression": {"id": "zlib"},
"filters": [{"id": "shuffle", "element_size": 8}],
"missing": {"valid_range": [-1.0, 999.0]}
}"#;
let request_data = serde_json::from_str::<RequestData>(json).unwrap();
let mut expected = test_utils::get_test_request_data_optional();
expected.dtype = DType::Float64;
expected.shape = Some(vec![2, 5, 10]);
expected.order = Some(Order::F);
expected.selection = Some(vec![
Slice::new(1, 2, 3),
Slice::new(4, 5, 6),
Slice::new(7, 8, 9),
]);
expected.compression = Some(Compression::Zlib);
expected.filters = Some(vec![Filter::Shuffle { element_size: 8 }]);
expected.missing = Some(Missing::ValidRange(
DValue::from_f64(-1.0).unwrap(),
DValue::from_f64(999.0).unwrap(),
));
assert_eq!(request_data, expected);
}

#[test]
fn test_json_optional_fields3() {
let json = format!(
r#"{{
"source": "http://example.com",
"bucket": "bar",
"object": "baz",
"dtype": "int32",
"missing": {{"missing_values": [{}, -1, 0, 1, {}]}}
}}"#,
i64::min_value(),
u64::max_value()
);
let request_data = serde_json::from_str::<RequestData>(&json).unwrap();
let mut expected = test_utils::get_test_request_data();
expected.dtype = DType::Int32;
expected.missing = Some(Missing::MissingValues(vec![
i64::min_value().into(),
(-1).into(),
0.into(),
1.into(),
u64::max_value().into(),
]));
assert_eq!(request_data, expected);
}
}
7 changes: 7 additions & 0 deletions src/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use crate::error::ActiveStorageError;
use crate::models;
use crate::types::dvalue::TryFromDValue;

use axum::body::Bytes;

Expand All @@ -12,9 +13,12 @@ pub trait Element:
+ PartialOrd
+ num_traits::FromPrimitive
+ num_traits::Zero
+ std::convert::From<u16>
+ std::fmt::Debug
+ std::iter::Sum
+ std::ops::Add<Output = Self>
+ std::ops::Div<Output = Self>
+ TryFromDValue
+ zerocopy::AsBytes
+ zerocopy::FromBytes
{
Expand All @@ -28,9 +32,12 @@ impl<T> Element for T where
+ num_traits::FromPrimitive
+ num_traits::One
+ num_traits::Zero
+ std::convert::From<u16>
+ std::fmt::Debug
+ std::iter::Sum
+ std::ops::Add<Output = Self>
+ std::ops::Div<Output = Self>
+ TryFromDValue
+ zerocopy::AsBytes
+ zerocopy::FromBytes
{
Expand Down
Loading

0 comments on commit eccd84b

Please sign in to comment.