Skip to content

Commit

Permalink
Merge pull request #337 from robertknight/pad-dim-fix
Browse files Browse the repository at this point in the history
Fix incorrect reflect padding in tensors with 3+ dims
  • Loading branch information
robertknight authored Aug 29, 2024
2 parents 2a325d2 + 6aa16df commit 2fb2fd3
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions src/ops/pad.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,12 @@ pub fn pad<T: Copy>(

let pad_dims = input.ndim() - batch_dims;
let (pad_top, pad_left) = if pad_dims == 1 {
(0, padding[[0]] as usize)
(0, padding[[batch_dims]] as usize)
} else {
(padding[[0]] as usize, padding[[1]] as usize)
(
padding[[batch_dims]] as usize,
padding[[batch_dims + 1]] as usize,
)
};

let mut input = input.view();
Expand Down Expand Up @@ -338,12 +341,24 @@ mod tests {
pads: NdTensor::from([]),
expected: Ok(Tensor::from(2.)),
},
// Pad start columns of a 3D tensor.
Case {
input: [[[1., 2., 3.]]].into(),
pads: [0, 0, 2, 0, 0, 0].into(),
expected: Ok(Tensor::from([[[3., 2., 1., 2., 3.]]])),
},
// Pad end columns of a 3D tensor.
Case {
input: [[[1., 2., 3.]]].into(),
pads: [0, 0, 0, 0, 0, 2].into(),
expected: Ok(Tensor::from([[[1., 2., 3., 2., 1.]]])),
},
// Pad start rows of a 3D tensor.
Case {
input: [[[1.], [2.], [3.]]].into(),
pads: [0, 2, 0, 0, 0, 0].into(),
expected: Ok(Tensor::from([[[3.], [2.], [1.], [2.], [3.]]])),
},
// Pad channel dimension of a 3D tensor.
Case {
input: [[[1., 2., 3.]]].into(),
Expand Down

0 comments on commit 2fb2fd3

Please sign in to comment.