diff --git a/cpp/src/prims/detail/per_v_transform_reduce_e.cuh b/cpp/src/prims/detail/per_v_transform_reduce_e.cuh index 4914b7cbcbd..4e1f220f863 100644 --- a/cpp/src/prims/detail/per_v_transform_reduce_e.cuh +++ b/cpp/src/prims/detail/per_v_transform_reduce_e.cuh @@ -65,6 +65,11 @@ #include #include +#include +#include + +using namespace cuda::experimental::stf; + namespace cugraph { namespace detail { @@ -1151,6 +1156,15 @@ void per_v_transform_reduce_e_edge_partition( std::optional> key_segment_offsets, std::optional> const& edge_partition_stream_pool_indices) { + async_resources_handle& cudastf_handle = *raft::resource::get_custom_resource(handle); + stream_ctx cudastf_ctx(handle.get_stream(), cudastf_handle); + + logical_data output_tokens[4]; + for (size_t i = 0; i < 4; i++) + { + output_tokens[i] = cudastf_ctx.logical_token(); + } + constexpr bool use_input_key = !std::is_same_v; using vertex_t = typename GraphViewType::vertex_type; @@ -1174,10 +1188,13 @@ void per_v_transform_reduce_e_edge_partition( if constexpr (update_major && !use_input_key) { // this is necessary as we don't visit // every vertex in the hypersparse segment - thrust::fill(rmm::exec_policy_nosync(exec_stream), - output_buffer + (*key_segment_offsets)[3], - output_buffer + (*key_segment_offsets)[4], - major_init); + // TODO task write output_token[3] + cudastf_ctx.task(output_tokens[3].write())->*[=](cudaStream_t stream) { + thrust::fill(rmm::exec_policy_nosync(stream), + output_buffer + (*key_segment_offsets)[3], + output_buffer + (*key_segment_offsets)[4], + major_init); + }; } auto segment_size = use_input_key @@ -1187,8 +1204,9 @@ void per_v_transform_reduce_e_edge_partition( raft::grid_1d_thread_t update_grid(segment_size, detail::per_v_transform_reduce_e_kernel_block_size, handle.get_device_properties().maxGridSize[0]); + size_t token_idx = 0; auto segment_output_buffer = output_buffer; - if constexpr (update_major) { segment_output_buffer += (*key_segment_offsets)[3]; } + if constexpr (update_major) { segment_output_buffer += (*key_segment_offsets)[3]; token_idx +=3; } auto segment_key_first = edge_partition_key_first; auto segment_key_last = edge_partition_key_last; if constexpr (use_input_key) { @@ -1199,20 +1217,22 @@ void per_v_transform_reduce_e_edge_partition( assert(segment_key_first == nullptr); assert(segment_key_last == nullptr); } - detail::per_v_transform_reduce_e_hypersparse - <<>>( - edge_partition, - segment_key_first, - segment_key_last, - edge_partition_src_value_input, - edge_partition_dst_value_input, - edge_partition_e_value_input, - edge_partition_e_mask, - segment_output_buffer, - e_op, - major_init, - reduce_op, - pred_op); + cudastf_ctx.task(output_tokens[token_idx].rw())->*[=](cudaStream_t stream) { + detail::per_v_transform_reduce_e_hypersparse + <<>>( + edge_partition, + segment_key_first, + segment_key_last, + edge_partition_src_value_input, + edge_partition_dst_value_input, + edge_partition_e_value_input, + edge_partition_e_mask, + segment_output_buffer, + e_op, + major_init, + reduce_op, + pred_op); + }; } } if ((*key_segment_offsets)[3] - (*key_segment_offsets)[2]) { @@ -1223,8 +1243,9 @@ void per_v_transform_reduce_e_edge_partition( raft::grid_1d_thread_t update_grid((*key_segment_offsets)[3] - (*key_segment_offsets)[2], detail::per_v_transform_reduce_e_kernel_block_size, handle.get_device_properties().maxGridSize[0]); + size_t token_idx = 0; auto segment_output_buffer = output_buffer; - if constexpr (update_major) { segment_output_buffer += (*key_segment_offsets)[2]; } + if constexpr (update_major) { segment_output_buffer += (*key_segment_offsets)[2]; token_idx += 2; } std::optional segment_key_first{}; // std::optional as thrust::transform_iterator's default constructor // is a deleted function, segment_key_first should always have a value @@ -1234,8 +1255,10 @@ void per_v_transform_reduce_e_edge_partition( segment_key_first = thrust::make_counting_iterator(edge_partition.major_range_first()); } *segment_key_first += (*key_segment_offsets)[2]; + + cudastf_ctx.task(output_tokens[token_idx].rw())->*[=](cudaStream_t stream) { detail::per_v_transform_reduce_e_low_degree - <<>>( + <<>>( edge_partition, *segment_key_first, *segment_key_first + ((*key_segment_offsets)[3] - (*key_segment_offsets)[2]), @@ -1248,6 +1271,7 @@ void per_v_transform_reduce_e_edge_partition( major_init, reduce_op, pred_op); + }; } if ((*key_segment_offsets)[2] - (*key_segment_offsets)[1] > 0) { auto exec_stream = edge_partition_stream_pool_indices @@ -1257,8 +1281,9 @@ void per_v_transform_reduce_e_edge_partition( raft::grid_1d_warp_t update_grid((*key_segment_offsets)[2] - (*key_segment_offsets)[1], detail::per_v_transform_reduce_e_kernel_block_size, handle.get_device_properties().maxGridSize[0]); + size_t token_idx = 0; auto segment_output_buffer = output_buffer; - if constexpr (update_major) { segment_output_buffer += (*key_segment_offsets)[1]; } + if constexpr (update_major) { segment_output_buffer += (*key_segment_offsets)[1]; token_idx += 1;} std::optional segment_key_first{}; // std::optional as thrust::transform_iterator's default constructor // is a deleted function, segment_key_first should always have a value @@ -1268,8 +1293,10 @@ void per_v_transform_reduce_e_edge_partition( segment_key_first = thrust::make_counting_iterator(edge_partition.major_range_first()); } *segment_key_first += (*key_segment_offsets)[1]; + + cudastf_ctx.task(output_tokens[token_idx].rw())->*[=](cudaStream_t stream) { detail::per_v_transform_reduce_e_mid_degree - <<>>( + <<>>( edge_partition, *segment_key_first, *segment_key_first + ((*key_segment_offsets)[2] - (*key_segment_offsets)[1]), @@ -1283,6 +1310,7 @@ void per_v_transform_reduce_e_edge_partition( major_identity_element, reduce_op, pred_op); + }; } if ((*key_segment_offsets)[1] > 0) { auto exec_stream = edge_partition_stream_pool_indices @@ -1303,8 +1331,9 @@ void per_v_transform_reduce_e_edge_partition( } else { segment_key_first = thrust::make_counting_iterator(edge_partition.major_range_first()); } + cudastf_ctx.task(output_tokens[0].rw())->*[=](cudaStream_t stream) { detail::per_v_transform_reduce_e_high_degree - <<>>( + <<>>( edge_partition, *segment_key_first, *segment_key_first + (*key_segment_offsets)[1], @@ -1318,6 +1347,7 @@ void per_v_transform_reduce_e_edge_partition( major_identity_element, reduce_op, pred_op); + }; } } else { auto exec_stream = edge_partition_stream_pool_indices @@ -1361,6 +1391,8 @@ void per_v_transform_reduce_e_edge_partition( pred_op); } } + + cudastf_ctx.finalize(); } template