Skip to content

Commit

Permalink
Merge pull request #324 from robertknight/1d-conv-non-contiguous
Browse files Browse the repository at this point in the history
Fix 1D conv failing with non-contiguous inputs
  • Loading branch information
robertknight authored Aug 24, 2024
2 parents 764aa87 + 9d7e242 commit 5410adb
Showing 1 changed file with 45 additions and 19 deletions.
64 changes: 45 additions & 19 deletions src/ops/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,14 @@ pub fn conv(
) -> Result<Tensor, OpError> {
// Handle 1D convolution by expanding to 2D and then removing the extra
// dimension from the result.
if let &[n, c, w] = input.shape() {
let [out_c, k_in_c, k_w] = check_dims!(kernel, 3, "OCW");
if let &[_n, _c, _w] = input.shape() {
let [_out_c, _k_in_c, _k_w] = check_dims!(kernel, 3, "OCW");

let mut input_2d = input.clone();
input_2d.reshape(&[n, c, 1, w]);
input_2d.insert_axis(2);

let mut kernel_2d = kernel.clone();
kernel_2d.reshape(&[out_c, k_in_c, 1, k_w]);
kernel_2d.insert_axis(2);

let padding_2d = padding.expand_1d_to_2d()?;

Expand Down Expand Up @@ -228,6 +228,7 @@ pub fn conv(
let in_group = input.slice::<4, _>((.., in_chan_start..in_chan_end));
let mut out_group = output.slice_mut::<3, _>((.., out_chans.clone()));

let kernel = kernel.to_contiguous_in(pool);
let kernel_mat = kernel
.slice::<4, _>([out_chans.clone()])
.reshaped([out_channels_per_group, in_channels_per_group * k_h * k_w]);
Expand Down Expand Up @@ -1133,23 +1134,48 @@ mod tests {
fn test_conv_1d() {
let mut rng = XorShiftRng::new(1234);
let [n, in_c, out_c, in_w, k_w] = [1, 5, 10, 20, 3];
let input = Tensor::rand(&[n, in_c, in_w], &mut rng);
let kernel = Tensor::rand(&[out_c, in_c, k_w], &mut rng);

let pool = new_pool();
let result = conv(
&pool,
input.view(),
kernel.view(),
None,
Padding::Same,
1, /* groups */
&[1], /* stride */
&[1], /* dilation */
)
.unwrap();
struct Case {
input: Tensor,
kernel: Tensor,
}

assert_eq!(result.shape(), &[n, out_c, in_w]);
let cases = [
Case {
input: Tensor::rand(&[n, in_c, in_w], &mut rng),
kernel: Tensor::rand(&[out_c, in_c, k_w], &mut rng),
},
// Non-contiguous inputs
Case {
input: {
let mut input = Tensor::rand(&[n, in_w, in_c], &mut rng);
input.permute(&[0, 2, 1]);
input
},
kernel: {
let mut kernel = Tensor::rand(&[out_c, k_w, in_c], &mut rng);
kernel.permute(&[0, 2, 1]);
kernel
},
},
];

for Case { input, kernel } in cases {
let pool = new_pool();
let result = conv(
&pool,
input.view(),
kernel.view(),
None,
Padding::Same,
1, /* groups */
&[1], /* stride */
&[1], /* dilation */
)
.unwrap();

assert_eq!(result.shape(), &[n, out_c, in_w]);
}
}

#[test]
Expand Down

0 comments on commit 5410adb

Please sign in to comment.