Skip to content

Commit

Permalink
[master] get the current device of inputs, resolves #68
Browse files Browse the repository at this point in the history
  • Loading branch information
vacancy committed Mar 30, 2022
1 parent 61b1b72 commit cf10401
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions pytorch/prroi_pool/src/prroi_pooling_gpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ at::Tensor prroi_pooling_forward_cuda(const at::Tensor &features, const at::Tens
return output;
}

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
cudaStream_t stream = at::cuda::getCurrentCUDAStream(features.device().index());
PrRoIPoolingForwardGpu(
stream, features.data<float>(), rois.data<float>(), output.data<float>(),
nr_channels, height, width, pooled_height, pooled_width, spatial_scale,
Expand Down Expand Up @@ -62,7 +62,7 @@ at::Tensor prroi_pooling_backward_cuda(
return features_diff;
}

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
cudaStream_t stream = at::cuda::getCurrentCUDAStream(features.device().index());
PrRoIPoolingBackwardGpu(
stream,
features.data<float>(), rois.data<float>(), output.data<float>(), output_diff.data<float>(),
Expand Down Expand Up @@ -93,7 +93,7 @@ at::Tensor prroi_pooling_coor_backward_cuda(
return coor_diff;
}

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
cudaStream_t stream = at::cuda::getCurrentCUDAStream(features.device().index());
PrRoIPoolingCoorBackwardGpu(
stream,
features.data<float>(), rois.data<float>(), output.data<float>(), output_diff.data<float>(),
Expand Down

0 comments on commit cf10401

Please sign in to comment.