Use TensorBase::slice_with in RNN operators #361

Merged (1 commit) on Sep 17, 2024
src/ops/rnn.rs: 104 changes (61 additions, 43 deletions)
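
In short: this PR switches the GRU and LSTM operators from turbofish-style slicing (`slice::<N, _>(...)`, `slice_mut::<N, _>(...)`), where the rank of the returned view is spelled out explicitly, to `slice_with`/`slice_with_mut`, which infer the rank from the index argument. Several scratch buffers also change from dynamic-rank `Tensor` to statically-ranked `NdTensor`. A minimal sketch of the difference, assuming the `rten_tensor` API as used in this diff (the buffer name and shape here are illustrative only):

use rten_tensor::prelude::*;
use rten_tensor::NdTensor;

fn main() {
    // A [batch, n_gates * hidden] buffer like `gates` in the diff below.
    let gates = NdTensor::<f32, 2>::zeros([4, 12]);

    // Before: the output rank (2) must be written as a turbofish.
    let old_style = gates.slice::<2, _>((.., 0..4));

    // After: `slice_with` infers the output rank from the index type.
    let new_style = gates.slice_with((.., 0..4));

    assert_eq!(old_style.shape(), new_style.shape());
}
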
@@ -9,8 +9,8 @@ use crate::ops::{
     add_in_place, mul_in_place, sigmoid, tanh, InputList, IntoOpResult, OpError, Operator,
     OutputList,
 };
+use crate::static_dims;
 use crate::tensor_pool::{AutoReturn, TensorPool};
-use crate::{check_dims, static_dims};
 
 /// Direction that an RNN operator will traverse the input sequence in.
 #[derive(Copy, Clone, Debug)]
@@ -144,11 +144,16 @@ pub fn gru(
     let input = static_dims!(input, 3, "seq, batch, input")?;
     let weights = static_dims!(weights, 3, "dir, hidden x 3, input")?;
     let recurrent_weights = static_dims!(recurrent_weights, 3)?;
+    let bias = bias
+        .map(|bias| static_dims!(bias, 2, "dir, hidden x 6"))
+        .transpose()?;
 
     let [seq_len, batch, _input_size] = input.shape();
     let [_directions, hidden_x3, _input_size] = weights.shape();
 
-    let initial_hidden = initial_hidden.map(|ih| static_dims!(ih, 3)).transpose()?;
+    let initial_hidden = initial_hidden
+        .map(|initial_hidden| static_dims!(initial_hidden, 3))
+        .transpose()?;
 
     let num_directions = direction.num_directions();
     let hidden_size = hidden_x3 / 3;
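
The optional-input validation above relies on `Option::transpose`: mapping a fallible check over an `Option` yields `Option<Result<_, _>>`, and `transpose` flips it to `Result<Option<_>, _>` so `?` can propagate the error while keeping the value optional. A standalone sketch of the pattern in plain Rust (`check_rank` is a hypothetical stand-in for the `static_dims!` macro, not part of this diff):

// Hypothetical stand-in for `static_dims!`: validates a rank, may fail.
fn check_rank(ndim: usize, expected: usize) -> Result<usize, String> {
    if ndim == expected {
        Ok(ndim)
    } else {
        Err(format!("expected {expected} dims, found {ndim}"))
    }
}

fn validate(initial_hidden: Option<usize>) -> Result<Option<usize>, String> {
    // `map` produces Option<Result<usize, String>>; `transpose` turns it
    // into Result<Option<usize>, String> so `?` can be applied.
    let initial_hidden = initial_hidden.map(|n| check_rank(n, 3)).transpose()?;
    Ok(initial_hidden)
}

fn main() {
    assert_eq!(validate(None), Ok(None));
    assert_eq!(validate(Some(3)), Ok(Some(3)));
    assert!(validate(Some(2)).is_err());
}
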
@@ -164,12 +169,12 @@
     const HIDDEN_GATE: usize = 2;
 
     let n_gates = 3;
-    let mut gates = Tensor::zeros_in(pool, &[batch, n_gates * hidden_size]).auto_return(pool);
+    let mut gates = NdTensor::zeros_in(pool, [batch, n_gates * hidden_size]).auto_return(pool);
     let gate_range = |gate| (gate * hidden_size)..((gate + 1) * hidden_size);
 
     // Scratch space for output of `hidden_state @ hidden_weights` matmul.
     let mut hidden_scratch =
-        Tensor::zeros_in(pool, &[batch, n_gates * hidden_size]).auto_return(pool);
+        NdTensor::zeros_in(pool, [batch, n_gates * hidden_size]).auto_return(pool);
 
     let gemm = GemmExecutor::new();
     for dir in 0..num_directions {
@@ -193,10 +198,10 @@
 
         let input_bias = bias
             .as_ref()
-            .map(|b| b.slice::<1, _>((dir, ..(n_gates * hidden_size))));
+            .map(|b| b.slice_with((dir, ..(n_gates * hidden_size))));
         let hidden_bias = bias
             .as_ref()
-            .map(|b| b.slice::<1, _>((dir, (n_gates * hidden_size)..)));
+            .map(|b| b.slice_with((dir, (n_gates * hidden_size)..)));
 
         for seq in sequence_for_dir(direction, dir, seq_len) {
             let in_item = input.slice_with([seq]);
@@ -238,7 +243,7 @@ pub fn gru(
                 0., /* beta */
             );
             if let Some(input_bias) = input_bias {
-                add_in_place(gates.view_mut(), input_bias.as_dyn());
+                add_in_place(gates.as_dyn_mut(), input_bias.as_dyn());
             }
 
             // Compute `hidden @ hidden_weights + hidden_bias` for all gates.
@@ -252,15 +257,15 @@
                 0., /* beta */
             );
             if let Some(hidden_bias) = hidden_bias {
-                add_in_place(hidden_scratch.view_mut(), hidden_bias.as_dyn());
+                add_in_place(hidden_scratch.as_dyn_mut(), hidden_bias.as_dyn());
             }
 
             // Combine inputs for reset and update gates and apply activation.
-            let mut update_reset_gates = gates.slice_mut::<2, _>((
+            let mut update_reset_gates = gates.slice_with_mut((
                 ..,
                 gate_range(UPDATE_GATE).start..gate_range(RESET_GATE).end,
             ));
-            let hidden_scratch_reset_update_gates = hidden_scratch.slice::<2, _>((
+            let hidden_scratch_reset_update_gates = hidden_scratch.slice_with((
                 ..,
                 gate_range(UPDATE_GATE).start..gate_range(RESET_GATE).end,
             ));
@@ -278,22 +283,23 @@
             // gates are in the same positions in the `update_reset_gates` slice
             // as `gates`.
             let update_reset_gates = sigmoid(pool, update_reset_gates.as_dyn()).auto_return(pool);
-            let update_gate = update_reset_gates.slice::<2, _>((.., gate_range(UPDATE_GATE)));
-            let reset_gate = update_reset_gates.slice::<2, _>((.., gate_range(RESET_GATE)));
+            let update_reset_gates = update_reset_gates.nd_view::<2>();
+            let update_gate = update_reset_gates.slice_with((.., gate_range(UPDATE_GATE)));
+            let reset_gate = update_reset_gates.slice_with((.., gate_range(RESET_GATE)));
 
             // Combine inputs for hidden gate and apply activation.
             let mut hidden_gate_recurrent =
-                hidden_scratch.slice_mut::<2, _>((.., gate_range(HIDDEN_GATE)));
+                hidden_scratch.slice_with_mut((.., gate_range(HIDDEN_GATE)));
             mul_in_place(hidden_gate_recurrent.as_dyn_mut(), reset_gate.as_dyn());
 
-            let mut hidden_gate = gates.slice_mut::<2, _>((.., gate_range(HIDDEN_GATE)));
+            let mut hidden_gate = gates.slice_with_mut((.., gate_range(HIDDEN_GATE)));
             add_in_place(hidden_gate.as_dyn_mut(), hidden_gate_recurrent.as_dyn());
 
             // See note above about `sigmoid_in_place`.
             let hidden_gate = tanh(pool, hidden_gate.as_dyn()).auto_return(pool);
 
             // Compute next hidden state
-            let mut hidden_item = hidden.slice_mut::<2, _>([dir]);
+            let mut hidden_item = hidden.slice_with_mut([dir]);
 
             for (hidden, update, hidden_gate) in zip3(
                 hidden_item.iter_mut(),
@@ -304,7 +310,7 @@
             }
 
             hidden_seq
-                .slice_mut::<2, _>([seq, dir])
+                .slice_with_mut([seq, dir])
                 .copy_from(&hidden_item);
         }
     }
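
One detail worth noting in the GRU changes above: `sigmoid(...)` and `tanh(...)` return dynamic-rank tensors, so the new code first re-ranks the result with `nd_view::<2>()` and only then takes fixed-rank slices with `slice_with`. A minimal sketch of that re-ranking step, assuming the `rten_tensor` API used in this diff (shapes are arbitrary):

use rten_tensor::prelude::*;
use rten_tensor::Tensor;

fn main() {
    // A dynamic-rank tensor, like the output of `sigmoid`/`tanh` here.
    let activations = Tensor::<f32>::zeros(&[4, 12]);

    // Re-rank to a fixed-rank (2D) view; fails if the rank is not 2.
    let view = activations.nd_view::<2>();

    // Fixed-rank views can then be sliced without a turbofish.
    let first_gate = view.slice_with((.., 0..4));
    assert_eq!(first_gate.shape(), [4, 4]);
}
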
@@ -375,9 +381,13 @@ pub fn lstm(
     initial_cell: Option<TensorView>,
 ) -> Result<Vec<Tensor>, OpError> {
     // TODO - Add validation of the sizes of individual dimensions in the inputs.
-    let [seq_len, batch, _input_size] = check_dims!(input, 3, "seq, batch, input");
-    let [_directions, hidden_x4, _input_size] = check_dims!(weights, 3, "dir, hidden x 4, input");
-    check_dims!(recurrent_weights, 3);
+    let input = static_dims!(input, 3, "seq, batch, input")?;
+    let [seq_len, batch, _input_size] = input.shape();
+
+    let weights = static_dims!(weights, 3, "dir, hidden x 4, input")?;
+    let [_directions, hidden_x4, _input_size] = weights.shape();
+
+    let recurrent_weights = static_dims!(recurrent_weights, 3, "dir, hidden x 4, hidden")?;
 
     let num_directions = direction.num_directions();
     let hidden_size = hidden_x4 / 4;
@@ -387,14 +397,20 @@
             "weights dim 1 must be 4 * hidden_size",
         ));
     }
+
+    let bias = bias.map(|bias| static_dims!(bias, 2)).transpose()?;
     if let Some(bias) = bias.as_ref() {
-        check_dims!(bias, 2);
         if bias.size(1) % 8 != 0 {
             return Err(OpError::InvalidValue("bias dim 1 must be 8 * hidden_size"));
         }
     }
-    check_dims!(initial_hidden?, 3);
-    check_dims!(initial_cell?, 3);
+
+    let initial_hidden = initial_hidden
+        .map(|initial_hidden| static_dims!(initial_hidden, 3))
+        .transpose()?;
+    let initial_cell = initial_cell
+        .map(|initial_cell| static_dims!(initial_cell, 3))
+        .transpose()?;
 
     // Contiguous input and bias needed to allow reshaping below.
     let input = input.to_contiguous_in(pool).auto_return(pool);
@@ -407,17 +423,17 @@
     const CELL_GATE: usize = 3;
 
     let n_gates = 4;
-    let mut gates = Tensor::zeros_in(pool, &[batch, n_gates * hidden_size]);
+    let mut gates = NdTensor::zeros_in(pool, [batch, n_gates * hidden_size]);
 
     let mut cell = initial_cell
         .map(|t| t.to_tensor_in(pool))
-        .unwrap_or_else(|| Tensor::zeros_in(pool, &[num_directions, batch, hidden_size]));
+        .unwrap_or_else(|| NdTensor::zeros_in(pool, [num_directions, batch, hidden_size]));
     let mut hidden = initial_hidden
         .map(|t| t.to_tensor_in(pool))
-        .unwrap_or_else(|| Tensor::zeros_in(pool, &[num_directions, batch, hidden_size]));
+        .unwrap_or_else(|| NdTensor::zeros_in(pool, [num_directions, batch, hidden_size]));
 
     let mut hidden_seq =
-        Tensor::<f32>::zeros_in(pool, &[seq_len, num_directions, batch, hidden_size]);
+        NdTensor::<f32, 4>::zeros_in(pool, [seq_len, num_directions, batch, hidden_size]);
 
     let gemm = GemmExecutor::new();
 
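
This hunk swaps dynamic-rank `Tensor` buffers for statically-ranked `NdTensor`, moving the rank into the type: shapes become fixed-size arrays rather than slices, and they can be destructured by rank. A small sketch of the distinction, assuming the `rten_tensor` API (`into_dyn` is the conversion used at the end of `lstm` below to build the dynamic-rank outputs):

use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, Tensor};

fn main() {
    // Dynamic rank: the shape is a runtime value (&[usize]).
    let dynamic = Tensor::<f32>::zeros(&[2, 3, 4]);
    assert_eq!(dynamic.ndim(), 3);

    // Static rank: the rank is part of the type and the shape is an array,
    // so it can be destructured, as `gru`/`lstm` do with `input.shape()`.
    let static_rank = NdTensor::<f32, 3>::zeros([2, 3, 4]);
    let [_seq, _batch, _hidden] = static_rank.shape();

    // Convert back to dynamic rank for APIs that expect `Tensor`.
    let as_dyn = static_rank.into_dyn();
    assert_eq!(as_dyn.ndim(), 3);
}
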
@@ -426,15 +442,15 @@
     for dir in 0..num_directions {
         let prepack = seq_len >= PREPACK_MIN_SEQ_LEN;
 
-        let input_weights = weights.slice::<2, _>(dir).transposed();
+        let input_weights = weights.slice_with(dir).transposed();
         let packed_input_weights =
             prepack.then(|| gemm.prepack_b_in(pool, input_weights).auto_return(pool));
         let input_weights = packed_input_weights
             .as_ref()
             .map(|packed| GemmInputB::Packed(packed))
             .unwrap_or(GemmInputB::Unpacked(input_weights));
 
-        let hidden_weights = recurrent_weights.slice::<2, _>(dir).transposed();
+        let hidden_weights = recurrent_weights.slice_with(dir).transposed();
         let packed_hidden_weights =
             prepack.then(|| gemm.prepack_b_in(pool, hidden_weights).auto_return(pool));
         let hidden_weights = packed_hidden_weights
@@ -444,10 +460,10 @@
 
         let input_bias = bias
             .as_ref()
-            .map(|b| b.slice::<1, _>((dir, ..(n_gates * hidden_size))));
+            .map(|b| b.slice_with((dir, ..(n_gates * hidden_size))));
         let hidden_bias = bias
             .as_ref()
-            .map(|b| b.slice::<1, _>((dir, (n_gates * hidden_size)..)));
+            .map(|b| b.slice_with((dir, (n_gates * hidden_size)..)));
 
         for seq in sequence_for_dir(direction, dir, seq_len) {
             // From the ONNX spec, the intermediate values are computed as:
Expand All @@ -469,8 +485,8 @@ pub fn lstm(
// supported.
// - `f`, `g` and `h` are activations. `f`=sigmoid, `g` and `h`
// are tanh.
let in_item = input.slice::<2, _>([seq]);
let hidden_item = hidden.slice::<2, _>([dir]);
let in_item = input.slice_with([seq]);
let hidden_item = hidden.slice_with([dir]);

// Update input, output, forget and cell gates.
let gates_row_stride = gates.stride(gates.ndim() - 2);
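
For reference, these are the ONNX LSTM equations that the partially collapsed comment above cites, written out from the ONNX operator spec with the default gate order and peephole terms omitted; `f` is sigmoid and `g`, `h` are tanh, as the comment notes:

$$
\begin{aligned}
i_t &= f(X_t W_i^T + H_{t-1} R_i^T + W_{bi} + R_{bi}) \\
o_t &= f(X_t W_o^T + H_{t-1} R_o^T + W_{bo} + R_{bo}) \\
f_t &= f(X_t W_f^T + H_{t-1} R_f^T + W_{bf} + R_{bf}) \\
c_t &= g(X_t W_c^T + H_{t-1} R_c^T + W_{bc} + R_{bc}) \\
C_t &= f_t \odot C_{t-1} + i_t \odot c_t \\
H_t &= o_t \odot h(C_t)
\end{aligned}
$$
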
@@ -483,7 +499,7 @@
                 0., /* beta */
             );
             if let Some(input_bias) = input_bias {
-                add_in_place(gates.view_mut(), input_bias.as_dyn());
+                add_in_place(gates.as_dyn_mut(), input_bias.as_dyn());
             }
 
             gemm.gemm(
@@ -495,25 +511,27 @@
                 1., /* beta */
             );
             if let Some(hidden_bias) = hidden_bias {
-                add_in_place(gates.view_mut(), hidden_bias.as_dyn());
+                add_in_place(gates.as_dyn_mut(), hidden_bias.as_dyn());
             }
 
             // Copy gates to work around `tanh_in_place` and `sigmoid_in_place`
             // being slow for non-contiguous inputs. See notes in GRU op.
-            let iof_gates = gates.slice::<2, _>((
+            let iof_gates = gates.slice_with((
                 ..,
                 gate_range(INPUT_GATE).start..gate_range(FORGET_GATE).end,
             ));
             let iof_gates = sigmoid(pool, iof_gates.as_dyn()).auto_return(pool);
-            let input_gate = iof_gates.slice::<2, _>((.., gate_range(INPUT_GATE)));
-            let out_gate = iof_gates.slice::<2, _>((.., gate_range(OUTPUT_GATE)));
-            let forget_gate = iof_gates.slice::<2, _>((.., gate_range(FORGET_GATE)));
+            let iof_gates = iof_gates.nd_view::<2>();
+
+            let input_gate = iof_gates.slice_with((.., gate_range(INPUT_GATE)));
+            let out_gate = iof_gates.slice_with((.., gate_range(OUTPUT_GATE)));
+            let forget_gate = iof_gates.slice_with((.., gate_range(FORGET_GATE)));
 
-            let cell_gate = gates.slice::<2, _>((.., gate_range(CELL_GATE)));
+            let cell_gate = gates.slice_with((.., gate_range(CELL_GATE)));
             let cell_gate = tanh(pool, cell_gate.as_dyn()).auto_return(pool);
 
             // Update cell and hidden state
-            let mut cell_item = cell.slice_mut::<2, _>([dir]);
+            let mut cell_item = cell.slice_with_mut([dir]);
 
             for (cell, forget_gate, input_gate, cell_gate) in zip4(
                 cell_item.iter_mut(),
@@ ... @@
                 *cell = forget_gate * *cell + input_gate * cell_gate;
             }
 
-            let mut hidden_item = hidden.slice_mut::<2, _>([dir]);
+            let mut hidden_item = hidden.slice_with_mut([dir]);
             for (hidden, out_gate, cell) in
                 zip3(hidden_item.iter_mut(), out_gate.iter(), cell_item.iter())
             {
                 *hidden = out_gate * cell.tanh()
             }
 
             hidden_seq
-                .slice_mut::<2, _>([seq, dir])
+                .slice_with_mut([seq, dir])
                 .copy_from(&hidden_item);
         }
     }
 
-    Ok([hidden_seq, hidden, cell].into())
+    Ok([hidden_seq.into_dyn(), hidden.into_dyn(), cell.into_dyn()].into())
 }
 
 impl Operator for LSTM {