Skip to content

Commit

Permalink
Merge pull request #383 from robertknight/generalize-output-conversions
Browse files Browse the repository at this point in the history
Generalize `Output` => `Tensor`/`TensorView` convenience methods
  • Loading branch information
robertknight authored Oct 15, 2024
2 parents d7a3490 + 6217458 commit ba13b33
Show file tree
Hide file tree
Showing 14 changed files with 90 additions and 73 deletions.
4 changes: 3 additions & 1 deletion rten-examples/src/jina_similarity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ fn embed_sentence_batch(

let output_id = model.node_id("last_hidden_state")?;
let [last_hidden_state] = model.run_n(inputs, [output_id], None)?;
let last_hidden_state = last_hidden_state.into_float().ok_or("wrong output type")?;
let last_hidden_state = last_hidden_state
.into_tensor::<f32>()
.ok_or("wrong output type")?;

// Mean pool each item in the batch. We process each batch item separately
// since they can have different lengths.
Expand Down
43 changes: 27 additions & 16 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1673,7 +1673,12 @@ mod tests {
],
);
assert_eq!(results.len(), 1);
expect_equal_with_tolerance(results[0].as_float_ref().unwrap(), &expected, 1e-4, 0.)?;
expect_equal_with_tolerance(
&results[0].as_tensor_view().unwrap(),
&expected.view(),
1e-4,
0.,
)?;

Ok(())
}
Expand Down Expand Up @@ -1793,13 +1798,13 @@ mod tests {
.run(vec![(input_id, input.view().into())], &[op_c_out], None)
.unwrap();
let expected = Tensor::from_data(&[2], vec![2., 3.]);
expect_equal(results[0].as_float_ref().unwrap(), &expected)?;
expect_equal(&results[0].as_tensor_view().unwrap(), &expected.view())?;

let results = g
.run(vec![(input_id, input.into())], &[op_d_out], None)
.unwrap();
let expected = Tensor::from_data(&[2], vec![3., 2.]);
expect_equal(results[0].as_float_ref().unwrap(), &expected)?;
expect_equal(&results[0].as_tensor_view().unwrap(), &expected.view())?;

Ok(())
}
Expand All @@ -1818,8 +1823,14 @@ mod tests {
let results = g
.run(vec![(input_id, input.into())], &[op_a_out, op_b_out], None)
.unwrap();
assert_eq!(results[0].as_float_ref().unwrap(), &Tensor::from(1.));
assert_eq!(results[1].as_float_ref().unwrap(), &Tensor::from(2.));
assert_eq!(
&results[0].as_tensor_view().unwrap(),
&Tensor::from(1.).view()
);
assert_eq!(
&results[1].as_tensor_view().unwrap(),
&Tensor::from(2.).view()
);
}

#[test]
Expand All @@ -1846,7 +1857,7 @@ mod tests {
.unwrap();

let expected = Tensor::from_data(&[5], vec![101., 102., 103., 104., 105.]);
expect_equal(results[0].as_float_ref().unwrap(), &expected)?;
expect_equal(&results[0].as_tensor_view().unwrap(), &expected.view())?;

Ok(())
}
Expand All @@ -1862,7 +1873,7 @@ mod tests {
.run(vec![(input_id, input.view().into())], &[input_id], None)
.unwrap();

expect_equal(results[0].as_float_ref().unwrap(), &input)?;
expect_equal(&results[0].as_tensor_view().unwrap(), &input.view())?;

Ok(())
}
Expand All @@ -1876,7 +1887,7 @@ mod tests {

let results = g.run(vec![], &[const_id], None).unwrap();

expect_equal(results[0].as_float_ref().unwrap(), &value)?;
expect_equal(&results[0].as_tensor_view().unwrap(), &value.view())?;

Ok(())
}
Expand Down Expand Up @@ -2031,7 +2042,7 @@ mod tests {
input: Output,
_other: InputList,
) -> Result<Output, OpError> {
let mut output = input.into_float().unwrap();
let mut output = input.into_tensor::<f32>().unwrap();
for x in output.iter_mut() {
*x = *x + 1.0;
}
Expand All @@ -2055,14 +2066,14 @@ mod tests {
let results = g
.run(vec![(input_id, input.view().into())], &[op1_out], None)
.unwrap();
assert_eq!(results[0].as_float_ref().unwrap()[[0, 0]], 0.0);
assert_eq!(results[0].as_tensor_view::<f32>().unwrap()[[0, 0]], 0.0);

// Second operator should be run in-place, as it meets all the
// requirements for this optimization.
let results = g
.run(vec![(input_id, input.view().into())], &[op2_out], None)
.unwrap();
assert_eq!(results[0].as_float_ref().unwrap()[[0, 0]], 1.0);
assert_eq!(results[0].as_tensor_view::<f32>().unwrap()[[0, 0]], 1.0);

// Third op should not be run in place, because its input is re-used
// for fourth op. Fourth op can run in place as by then, it is the
Expand All @@ -2074,8 +2085,8 @@ mod tests {
None,
)
.unwrap();
assert_eq!(results[0].as_float_ref().unwrap()[[0, 0]], 1.0);
assert_eq!(results[1].as_float_ref().unwrap()[[0, 0]], 2.0);
assert_eq!(results[0].as_tensor_view::<f32>().unwrap()[[0, 0]], 1.0);
assert_eq!(results[1].as_tensor_view::<f32>().unwrap()[[0, 0]], 2.0);
}

// Test that the graph executor will swap inputs to commutative ops if
Expand Down Expand Up @@ -2117,7 +2128,7 @@ mod tests {
// Bias value should be added twice to every input.
assert_eq!(
results[0]
.as_float_ref()
.as_tensor_view::<f32>()
.unwrap()
.iter()
.copied()
Expand Down Expand Up @@ -2199,8 +2210,8 @@ mod tests {
assert_eq!(*run_count.lock().unwrap(), 1);

assert_eq!(results.len(), 2);
let left_split = results.remove(0).into_float().unwrap();
let right_split = results.remove(0).into_float().unwrap();
let left_split = results.remove(0).into_tensor::<f32>().unwrap();
let right_split = results.remove(0).into_tensor::<f32>().unwrap();
assert_eq!(left_split.to_vec(), &[1.0, 2.0]);
assert_eq!(right_split.to_vec(), &[3.0, 4.0, 5.0]);
}
Expand Down
2 changes: 1 addition & 1 deletion src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -897,7 +897,7 @@ mod tests {
fn check_output(mut result: Vec<Output>) -> Tensor<f32> {
assert_eq!(result.len(), 1);

let tensor: Tensor<f32> = result.remove(0).into_float().unwrap();
let tensor: Tensor<f32> = result.remove(0).into_tensor::<f32>().unwrap();
assert_eq!(tensor.shape(), &[2, 2, 2]);
assert_eq!(tensor.to_vec(), &[0.5, 0., 0.1, 0., 1., 2., 0., 0.]);

Expand Down
8 changes: 5 additions & 3 deletions src/ops/binary_elementwise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,9 @@ impl Operator for Pow {
input: Output,
other: InputList,
) -> Result<Output, OpError> {
let mut a = input.into_float().ok_or(OpError::IncorrectInputType)?;
let mut a = input
.into_tensor::<f32>()
.ok_or(OpError::IncorrectInputType)?;
let b = other.require_as(0)?;

if can_run_binary_op_in_place(&a, &b) {
Expand Down Expand Up @@ -1151,7 +1153,7 @@ mod tests {
let result = op
.run_in_place(&pool, Output::FloatTensor(a_copy), (&b).into())
.unwrap();
expect_equal(result.as_float_ref().unwrap(), &expected)?;
expect_equal(&result.as_tensor_view().unwrap(), &expected.view())?;

// Run `Add` operator in-place with inputs that don't support in-place
// addition. The operator should fall back to creating a new output tensor.
Expand All @@ -1160,7 +1162,7 @@ mod tests {
let result = op
.run_in_place(&pool, Output::FloatTensor(scalar), (&b).into())
.unwrap();
expect_equal(result.as_float_ref().unwrap(), &expected)?;
expect_equal(&result.as_tensor_view().unwrap(), &expected.view())?;

// In-place addition where the second input must be broadcast to the
// shape of the first.
Expand Down
2 changes: 1 addition & 1 deletion src/ops/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@ mod tests {
.run(&pool, (&input, &kernel).into())
.unwrap()
.remove(0)
.into_float()
.into_tensor::<f32>()
.unwrap();
let reference_result = reference_conv(
input.view(),
Expand Down
12 changes: 6 additions & 6 deletions src/ops/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ mod tests {
.run(&pool, (&int_input).into())
.unwrap()
.remove(0)
.into_int()
.into_tensor::<i32>()
.unwrap();

// Flooring cast from float => int32
Expand All @@ -89,7 +89,7 @@ mod tests {
.run(&pool, (&float_input).into())
.unwrap()
.remove(0)
.into_int()
.into_tensor::<i32>()
.unwrap();
assert_eq!(&result, &int_input);

Expand All @@ -101,7 +101,7 @@ mod tests {
.run(&pool, (&float_input).into())
.unwrap()
.remove(0)
.into_float()
.into_tensor::<f32>()
.unwrap();
expect_equal(&result, &float_input)?;

Expand All @@ -110,7 +110,7 @@ mod tests {
.run(&pool, (&int_input).into())
.unwrap()
.remove(0)
.into_float()
.into_tensor::<f32>()
.unwrap();
expect_equal(&result, &float_input)?;

Expand All @@ -131,7 +131,7 @@ mod tests {
.run(&pool, (&int_input).into())
.unwrap()
.remove(0)
.into_float()
.into_tensor::<f32>()
.unwrap();
expect_equal(&result, &Tensor::from([-2147483600.0, 2147483600.0]))?;

Expand All @@ -144,7 +144,7 @@ mod tests {
.run(&pool, (&float_input).into())
.unwrap()
.remove(0)
.into_int()
.into_tensor::<i32>()
.unwrap();
assert_eq!(&result, &Tensor::from([i32::MIN, i32::MAX]));

Expand Down
2 changes: 1 addition & 1 deletion src/ops/generate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ mod tests {
.run(&pool, (&shape).into())
.unwrap()
.remove(0)
.into_int()
.into_tensor::<i32>()
.unwrap();

assert_eq!(result.shape(), &[1, 5, 10]);
Expand Down
4 changes: 2 additions & 2 deletions src/ops/identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ mod tests {
.run(&pool, (&int_input).into())
.unwrap()
.remove(0)
.into_int()
.into_tensor::<i32>()
.unwrap();
assert_eq!(result, int_input);

Expand All @@ -69,7 +69,7 @@ mod tests {
.run(&pool, (&float_input).into())
.unwrap()
.remove(0)
.into_float()
.into_tensor::<f32>()
.unwrap();
expect_equal(&result, &float_input)?;

Expand Down
8 changes: 4 additions & 4 deletions src/ops/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,7 @@ mod tests {
.run(&pool, (&input, &shape).into())
.unwrap()
.remove(0)
.into_float()
.into_tensor::<f32>()
.unwrap();

expect_equal(&result, &expected)?;
Expand All @@ -889,7 +889,7 @@ mod tests {
.run(&pool, (&input).into())
.unwrap()
.remove(0)
.into_int()
.into_tensor::<i32>()
.unwrap();
assert_eq!(result.shape(), &[4]);
assert_eq!(result.to_vec(), &[1, 1, 2, 2]);
Expand All @@ -900,7 +900,7 @@ mod tests {
.run(&pool, (&input).into())
.unwrap()
.remove(0)
.into_int()
.into_tensor::<i32>()
.unwrap();
assert_eq!(result.shape(), &[4]);
assert_eq!(result.to_vec(), &[1, 1, 2, 2]);
Expand All @@ -915,7 +915,7 @@ mod tests {
.run(&pool, (&input).into())
.unwrap()
.remove(0)
.into_int()
.into_tensor::<i32>()
.unwrap();
assert_eq!(result.ndim(), 0);
assert_eq!(result.item(), Some(&4));
Expand Down
45 changes: 17 additions & 28 deletions src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,36 +340,25 @@ impl Output {
};
}

pub fn into_int(self) -> Option<Tensor<i32>> {
if let Output::Int32Tensor(t) = self {
Some(t)
} else {
None
}
}

pub fn as_int_ref(&self) -> Option<&Tensor<i32>> {
if let Output::Int32Tensor(t) = self {
Some(t)
} else {
None
}
}

pub fn into_float(self) -> Option<Tensor<f32>> {
if let Output::FloatTensor(t) = self {
Some(t)
} else {
None
}
/// Convert this output into a tensor with a given element type.
///
/// Returns `None` if the element type does not match `T`.
pub fn into_tensor<T>(self) -> Option<Tensor<T>>
where
Tensor<T>: TryFrom<Self>,
{
self.try_into().ok()
}

pub fn as_float_ref(&self) -> Option<&Tensor<f32>> {
if let Output::FloatTensor(t) = self {
Some(t)
} else {
None
}
/// Convert a reference to this output into a tensor view with a given
/// element type.
///
/// Returns `None` if the element type does not match `T`.
pub fn as_tensor_view<'a, T>(&'a self) -> Option<TensorView<'a, T>>
where
TensorView<'a, T>: TryFrom<&'a Self>,
{
self.try_into().ok()
}

fn layout(&self) -> &DynLayout {
Expand Down
Loading

0 comments on commit ba13b33

Please sign in to comment.