diff --git a/src/ops/rnn.rs b/src/ops/rnn.rs index f5a8e4e3..d9371b7c 100644 --- a/src/ops/rnn.rs +++ b/src/ops/rnn.rs @@ -9,8 +9,8 @@ use crate::ops::{ add_in_place, mul_in_place, sigmoid, tanh, InputList, IntoOpResult, OpError, Operator, OutputList, }; +use crate::static_dims; use crate::tensor_pool::{AutoReturn, TensorPool}; -use crate::{check_dims, static_dims}; /// Direction that an RNN operator will traverse the input sequence in. #[derive(Copy, Clone, Debug)] @@ -144,11 +144,16 @@ pub fn gru( let input = static_dims!(input, 3, "seq, batch, input")?; let weights = static_dims!(weights, 3, "dir, hidden x 3, input")?; let recurrent_weights = static_dims!(recurrent_weights, 3)?; + let bias = bias + .map(|bias| static_dims!(bias, 2, "dir, hidden x 6")) + .transpose()?; let [seq_len, batch, _input_size] = input.shape(); let [_directions, hidden_x3, _input_size] = weights.shape(); - let initial_hidden = initial_hidden.map(|ih| static_dims!(ih, 3)).transpose()?; + let initial_hidden = initial_hidden + .map(|initial_hidden| static_dims!(initial_hidden, 3)) + .transpose()?; let num_directions = direction.num_directions(); let hidden_size = hidden_x3 / 3; @@ -164,12 +169,12 @@ pub fn gru( const HIDDEN_GATE: usize = 2; let n_gates = 3; - let mut gates = Tensor::zeros_in(pool, &[batch, n_gates * hidden_size]).auto_return(pool); + let mut gates = NdTensor::zeros_in(pool, [batch, n_gates * hidden_size]).auto_return(pool); let gate_range = |gate| (gate * hidden_size)..((gate + 1) * hidden_size); // Scratch space for output of `hidden_state @ hidden_weights` matmul. let mut hidden_scratch = - Tensor::zeros_in(pool, &[batch, n_gates * hidden_size]).auto_return(pool); + NdTensor::zeros_in(pool, [batch, n_gates * hidden_size]).auto_return(pool); let gemm = GemmExecutor::new(); for dir in 0..num_directions { @@ -193,10 +198,10 @@ pub fn gru( let input_bias = bias .as_ref() - .map(|b| b.slice::<1, _>((dir, ..(n_gates * hidden_size)))); + .map(|b| b.slice_with((dir, ..(n_gates * hidden_size)))); let hidden_bias = bias .as_ref() - .map(|b| b.slice::<1, _>((dir, (n_gates * hidden_size)..))); + .map(|b| b.slice_with((dir, (n_gates * hidden_size)..))); for seq in sequence_for_dir(direction, dir, seq_len) { let in_item = input.slice_with([seq]); @@ -238,7 +243,7 @@ pub fn gru( 0., /* beta */ ); if let Some(input_bias) = input_bias { - add_in_place(gates.view_mut(), input_bias.as_dyn()); + add_in_place(gates.as_dyn_mut(), input_bias.as_dyn()); } // Compute `hidden @ hidden_weights + hidden_bias` for all gates. @@ -252,15 +257,15 @@ pub fn gru( 0., /* beta */ ); if let Some(hidden_bias) = hidden_bias { - add_in_place(hidden_scratch.view_mut(), hidden_bias.as_dyn()); + add_in_place(hidden_scratch.as_dyn_mut(), hidden_bias.as_dyn()); } // Combine inputs for reset and update gates and apply activation. - let mut update_reset_gates = gates.slice_mut::<2, _>(( + let mut update_reset_gates = gates.slice_with_mut(( .., gate_range(UPDATE_GATE).start..gate_range(RESET_GATE).end, )); - let hidden_scratch_reset_update_gates = hidden_scratch.slice::<2, _>(( + let hidden_scratch_reset_update_gates = hidden_scratch.slice_with(( .., gate_range(UPDATE_GATE).start..gate_range(RESET_GATE).end, )); @@ -278,22 +283,23 @@ pub fn gru( // gates are in the same positions in the `update_reset_gates` slice // as `gates`. let update_reset_gates = sigmoid(pool, update_reset_gates.as_dyn()).auto_return(pool); - let update_gate = update_reset_gates.slice::<2, _>((.., gate_range(UPDATE_GATE))); - let reset_gate = update_reset_gates.slice::<2, _>((.., gate_range(RESET_GATE))); + let update_reset_gates = update_reset_gates.nd_view::<2>(); + let update_gate = update_reset_gates.slice_with((.., gate_range(UPDATE_GATE))); + let reset_gate = update_reset_gates.slice_with((.., gate_range(RESET_GATE))); // Combine inputs for hidden gate and apply activation. let mut hidden_gate_recurrent = - hidden_scratch.slice_mut::<2, _>((.., gate_range(HIDDEN_GATE))); + hidden_scratch.slice_with_mut((.., gate_range(HIDDEN_GATE))); mul_in_place(hidden_gate_recurrent.as_dyn_mut(), reset_gate.as_dyn()); - let mut hidden_gate = gates.slice_mut::<2, _>((.., gate_range(HIDDEN_GATE))); + let mut hidden_gate = gates.slice_with_mut((.., gate_range(HIDDEN_GATE))); add_in_place(hidden_gate.as_dyn_mut(), hidden_gate_recurrent.as_dyn()); // See note above about `sigmoid_in_place`. let hidden_gate = tanh(pool, hidden_gate.as_dyn()).auto_return(pool); // Compute next hidden state - let mut hidden_item = hidden.slice_mut::<2, _>([dir]); + let mut hidden_item = hidden.slice_with_mut([dir]); for (hidden, update, hidden_gate) in zip3( hidden_item.iter_mut(), @@ -304,7 +310,7 @@ pub fn gru( } hidden_seq - .slice_mut::<2, _>([seq, dir]) + .slice_with_mut([seq, dir]) .copy_from(&hidden_item); } } @@ -375,9 +381,13 @@ pub fn lstm( initial_cell: Option, ) -> Result, OpError> { // TODO - Add validation of the sizes of individual dimensions in the inputs. - let [seq_len, batch, _input_size] = check_dims!(input, 3, "seq, batch, input"); - let [_directions, hidden_x4, _input_size] = check_dims!(weights, 3, "dir, hidden x 4, input"); - check_dims!(recurrent_weights, 3); + let input = static_dims!(input, 3, "seq, batch, input")?; + let [seq_len, batch, _input_size] = input.shape(); + + let weights = static_dims!(weights, 3, "dir, hidden x 4, input")?; + let [_directions, hidden_x4, _input_size] = weights.shape(); + + let recurrent_weights = static_dims!(recurrent_weights, 3, "dir, hidden x 4, hidden")?; let num_directions = direction.num_directions(); let hidden_size = hidden_x4 / 4; @@ -387,14 +397,20 @@ pub fn lstm( "weights dim 1 must be 4 * hidden_size", )); } + + let bias = bias.map(|bias| static_dims!(bias, 2)).transpose()?; if let Some(bias) = bias.as_ref() { - check_dims!(bias, 2); if bias.size(1) % 8 != 0 { return Err(OpError::InvalidValue("bias dim 1 must be 8 * hidden_size")); } } - check_dims!(initial_hidden?, 3); - check_dims!(initial_cell?, 3); + + let initial_hidden = initial_hidden + .map(|initial_hidden| static_dims!(initial_hidden, 3)) + .transpose()?; + let initial_cell = initial_cell + .map(|initial_cell| static_dims!(initial_cell, 3)) + .transpose()?; // Contiguous input and bias needed to allow reshaping below. let input = input.to_contiguous_in(pool).auto_return(pool); @@ -407,17 +423,17 @@ pub fn lstm( const CELL_GATE: usize = 3; let n_gates = 4; - let mut gates = Tensor::zeros_in(pool, &[batch, n_gates * hidden_size]); + let mut gates = NdTensor::zeros_in(pool, [batch, n_gates * hidden_size]); let mut cell = initial_cell .map(|t| t.to_tensor_in(pool)) - .unwrap_or_else(|| Tensor::zeros_in(pool, &[num_directions, batch, hidden_size])); + .unwrap_or_else(|| NdTensor::zeros_in(pool, [num_directions, batch, hidden_size])); let mut hidden = initial_hidden .map(|t| t.to_tensor_in(pool)) - .unwrap_or_else(|| Tensor::zeros_in(pool, &[num_directions, batch, hidden_size])); + .unwrap_or_else(|| NdTensor::zeros_in(pool, [num_directions, batch, hidden_size])); let mut hidden_seq = - Tensor::::zeros_in(pool, &[seq_len, num_directions, batch, hidden_size]); + NdTensor::::zeros_in(pool, [seq_len, num_directions, batch, hidden_size]); let gemm = GemmExecutor::new(); @@ -426,7 +442,7 @@ pub fn lstm( for dir in 0..num_directions { let prepack = seq_len >= PREPACK_MIN_SEQ_LEN; - let input_weights = weights.slice::<2, _>(dir).transposed(); + let input_weights = weights.slice_with(dir).transposed(); let packed_input_weights = prepack.then(|| gemm.prepack_b_in(pool, input_weights).auto_return(pool)); let input_weights = packed_input_weights @@ -434,7 +450,7 @@ pub fn lstm( .map(|packed| GemmInputB::Packed(packed)) .unwrap_or(GemmInputB::Unpacked(input_weights)); - let hidden_weights = recurrent_weights.slice::<2, _>(dir).transposed(); + let hidden_weights = recurrent_weights.slice_with(dir).transposed(); let packed_hidden_weights = prepack.then(|| gemm.prepack_b_in(pool, hidden_weights).auto_return(pool)); let hidden_weights = packed_hidden_weights @@ -444,10 +460,10 @@ pub fn lstm( let input_bias = bias .as_ref() - .map(|b| b.slice::<1, _>((dir, ..(n_gates * hidden_size)))); + .map(|b| b.slice_with((dir, ..(n_gates * hidden_size)))); let hidden_bias = bias .as_ref() - .map(|b| b.slice::<1, _>((dir, (n_gates * hidden_size)..))); + .map(|b| b.slice_with((dir, (n_gates * hidden_size)..))); for seq in sequence_for_dir(direction, dir, seq_len) { // From the ONNX spec, the intermediate values are computed as: @@ -469,8 +485,8 @@ pub fn lstm( // supported. // - `f`, `g` and `h` are activations. `f`=sigmoid, `g` and `h` // are tanh. - let in_item = input.slice::<2, _>([seq]); - let hidden_item = hidden.slice::<2, _>([dir]); + let in_item = input.slice_with([seq]); + let hidden_item = hidden.slice_with([dir]); // Update input, output, forget and cell gates. let gates_row_stride = gates.stride(gates.ndim() - 2); @@ -483,7 +499,7 @@ pub fn lstm( 0., /* beta */ ); if let Some(input_bias) = input_bias { - add_in_place(gates.view_mut(), input_bias.as_dyn()); + add_in_place(gates.as_dyn_mut(), input_bias.as_dyn()); } gemm.gemm( @@ -495,25 +511,27 @@ pub fn lstm( 1., /* beta */ ); if let Some(hidden_bias) = hidden_bias { - add_in_place(gates.view_mut(), hidden_bias.as_dyn()); + add_in_place(gates.as_dyn_mut(), hidden_bias.as_dyn()); } // Copy gates to work around `tanh_in_place` and `sigmoid_in_place` // being slow for non-contiguous inputs. See notes in GRU op. - let iof_gates = gates.slice::<2, _>(( + let iof_gates = gates.slice_with(( .., gate_range(INPUT_GATE).start..gate_range(FORGET_GATE).end, )); let iof_gates = sigmoid(pool, iof_gates.as_dyn()).auto_return(pool); - let input_gate = iof_gates.slice::<2, _>((.., gate_range(INPUT_GATE))); - let out_gate = iof_gates.slice::<2, _>((.., gate_range(OUTPUT_GATE))); - let forget_gate = iof_gates.slice::<2, _>((.., gate_range(FORGET_GATE))); + let iof_gates = iof_gates.nd_view::<2>(); + + let input_gate = iof_gates.slice_with((.., gate_range(INPUT_GATE))); + let out_gate = iof_gates.slice_with((.., gate_range(OUTPUT_GATE))); + let forget_gate = iof_gates.slice_with((.., gate_range(FORGET_GATE))); - let cell_gate = gates.slice::<2, _>((.., gate_range(CELL_GATE))); + let cell_gate = gates.slice_with((.., gate_range(CELL_GATE))); let cell_gate = tanh(pool, cell_gate.as_dyn()).auto_return(pool); // Update cell and hidden state - let mut cell_item = cell.slice_mut::<2, _>([dir]); + let mut cell_item = cell.slice_with_mut([dir]); for (cell, forget_gate, input_gate, cell_gate) in zip4( cell_item.iter_mut(), @@ -524,7 +542,7 @@ pub fn lstm( *cell = forget_gate * *cell + input_gate * cell_gate; } - let mut hidden_item = hidden.slice_mut::<2, _>([dir]); + let mut hidden_item = hidden.slice_with_mut([dir]); for (hidden, out_gate, cell) in zip3(hidden_item.iter_mut(), out_gate.iter(), cell_item.iter()) { @@ -532,12 +550,12 @@ pub fn lstm( } hidden_seq - .slice_mut::<2, _>([seq, dir]) + .slice_with_mut([seq, dir]) .copy_from(&hidden_item); } } - Ok([hidden_seq, hidden, cell].into()) + Ok([hidden_seq.into_dyn(), hidden.into_dyn(), cell.into_dyn()].into()) } impl Operator for LSTM {