Skip to content

Commit

Permalink
fix tests to check values instead of shape
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsanbear committed Mar 12, 2024
1 parent 3e83575 commit e38c90a
Showing 1 changed file with 39 additions and 10 deletions.
49 changes: 39 additions & 10 deletions tests/einops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ use candle_einops::{einops, Backend};

#[test]
fn candle_layers() -> Result<()> {
let input = Tensor::randn(0.0, 1.0, (10, 3, 32, 32), &Device::Cpu)?;
let input = Tensor::randn(0.0f32, 1.0, (10, 3, 32, 32), &Device::Cpu)?;

let output1 = einops!("b c (h max(2)) (w max(2)) -> b c h w", &input);
let output2 = input.max_pool2d(2)?;

assert_eq!(output1.shape(), output2.shape());
assert_eq!(
output1.flatten_all()?.to_vec1::<f32>()?,
output2.flatten_all()?.to_vec1::<f32>()?
);

Ok(())
}
Expand All @@ -20,22 +23,42 @@ fn consistency_checks() -> Result<()> {

let output = einops!("a b c d e f -> a (b) (c d e) f", &input);
assert_eq!(
input.flatten(0, input.dims().len() - 1)?.shape(),
output.flatten(0, output.dims().len() - 1)?.shape()
input
.flatten(0, input.dims().len() - 1)?
.flatten_all()?
.to_vec1::<f32>()?,
output
.flatten(0, output.dims().len() - 1)?
.flatten_all()?
.to_vec1::<f32>()?,
);

let output1 = einops!("a b c d e f -> f e d c b a", &input);
let output2 = einops!("f e d c b a -> a b c d e f", &input);
assert_eq!(output1.shape(), output2.shape());
assert_eq!(output1.dims(), output2.dims());
assert_eq!(
output1.flatten_all()?.to_vec1::<f32>()?,
output2.flatten_all()?.to_vec1::<f32>()?
);

let intermediate = einops!("a b c d e f -> (f d) c (e b) a", &input);
let output = einops!("(f d:5) c (e b:2) a -> a b c d e f", &intermediate);
assert_eq!(output.shape(), input.shape());
assert_eq!(output.dims(), input.dims());
assert_eq!(
output.flatten_all()?.to_vec1::<f32>()?,
input.flatten_all()?.to_vec1::<f32>()?
);

let input = Tensor::arange(0f32, (2 * 3 * 4) as f32, &Device::Cpu)?.reshape(&[2, 3, 4]);
let output = einops!("a b c -> b c a", &input);
assert_eq!(input.i((1, 2, 3))?.shape(), output.i((2, 3, 1))?.shape());
assert_eq!(input.i((0, 1, 2))?.shape(), output.i((1, 2, 0))?.shape());
assert_eq!(
input.i((1, 2, 3))?.flatten_all()?.to_vec1::<f32>()?,
output.i((2, 3, 1))?.flatten_all()?.to_vec1::<f32>()?
);
assert_eq!(
input.i((0, 1, 2))?.flatten_all()?.to_vec1::<f32>()?,
output.i((1, 2, 0))?.flatten_all()?.to_vec1::<f32>()?
);

Ok(())
}
Expand All @@ -44,7 +67,10 @@ macro_rules! test {
($pattern1:literal, $pattern2:literal, $tensor:ident) => {
let output1 = einops!($pattern1, &$tensor);
let output2 = einops!($pattern2, &$tensor);
assert_eq!(output1.shape(), output2.shape(), "({}) & ({}) failed", $pattern1, $pattern2);
assert_eq!(
output1.flatten_all()?.to_vec1::<f32>()?,
output2.flatten_all()?.to_vec1::<f32>()?, "({}) & ({}) failed", $pattern1, $pattern2
);
};
($(($pattern1:literal, $pattern2:literal)),*, $tensor:ident) => {
$(test!($pattern1, $pattern2, $tensor);)*
Expand Down Expand Up @@ -88,7 +114,10 @@ macro_rules! seq_test {
($pattern1:literal, $pattern2:literal, $tensor:ident) => {
let intermediate = einops!($pattern1, $tensor.clone());
let output = einops!($pattern2, &intermediate.clone());
assert_eq!($tensor.clone().shape(), output.shape(), "({}) & ({}) failed", $pattern1, $pattern2);
assert_eq!(
$tensor.clone().flatten_all()?.to_vec1::<f32>()?,
output.flatten_all()?.to_vec1::<f32>()?, "({}) & ({}) failed", $pattern1, $pattern2
);
};
($(($pattern1:literal, $pattern2:literal)),*, $tensor:ident) => {
$(seq_test!($pattern1, $pattern2, $tensor);)*
Expand Down

0 comments on commit e38c90a

Please sign in to comment.