diff --git a/sparse/tpls/KokkosSparse_spgemm_symbolic_tpl_spec_decl.hpp b/sparse/tpls/KokkosSparse_spgemm_symbolic_tpl_spec_decl.hpp index 7db9ce5fc3..53d543b0d3 100644 --- a/sparse/tpls/KokkosSparse_spgemm_symbolic_tpl_spec_decl.hpp +++ b/sparse/tpls/KokkosSparse_spgemm_symbolic_tpl_spec_decl.hpp @@ -206,84 +206,118 @@ void spgemm_symbolic_cusparse(KernelHandle *handle, lno_t m, lno_t n, lno_t k, const ConstEntriesType &entriesA, const ConstRowMapType &row_mapB, const ConstEntriesType &entriesB, - const RowMapType &row_mapC, - bool /* computeRowptrs */) { - using scalar_type = typename KernelHandle::nnz_scalar_t; - using Offset = typename KernelHandle::size_type; - if (handle->is_symbolic_called() && handle->are_rowptrs_computed()) return; - handle->create_cusparse_spgemm_handle(false, false); - auto h = handle->get_cusparse_spgemm_handle(); + const RowMapType &row_mapC, bool computeRowptrs) { + using scalar_type = typename KernelHandle::nnz_scalar_t; + using ordinal_type = typename KernelHandle::nnz_lno_t; + const auto alpha = Kokkos::ArithTraits::one(); + const auto beta = Kokkos::ArithTraits::zero(); + void *dummyValues_AB = nullptr; + bool firstSymbolicCall = false; + if (!handle->is_symbolic_called()) { + handle->create_cusparse_spgemm_handle(false, false); + auto h = handle->get_cusparse_spgemm_handle(); - // Follow - // https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuSPARSE/spgemm + // Follow + // https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuSPARSE/spgemm + + // In non-reuse interface, forced to give A,B dummy values to + // cusparseSpGEMM_compute. And it actually reads them, so they must be + // allocated and of the correct type. This compute will be called again in + // numeric with the real values. + // + // The dummy values can be uninitialized. cusparseSpGEMM_compute does + // not remove numerical zeros from the sparsity pattern. + KOKKOS_IMPL_CUDA_SAFE_CALL(cudaMalloc( + &dummyValues_AB, sizeof(scalar_type) * + std::max(entriesA.extent(0), entriesB.extent(0)))); - const auto alpha = Kokkos::ArithTraits::one(); - const auto beta = Kokkos::ArithTraits::zero(); + KOKKOS_CUSPARSE_SAFE_CALL(cusparseCreateCsr( + &h->descr_A, m, n, entriesA.extent(0), (void *)row_mapA.data(), + (void *)entriesA.data(), dummyValues_AB, CUSPARSE_INDEX_32I, + CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, h->scalarType)); + + KOKKOS_CUSPARSE_SAFE_CALL(cusparseCreateCsr( + &h->descr_B, n, k, entriesB.extent(0), (void *)row_mapB.data(), + (void *)entriesB.data(), dummyValues_AB, CUSPARSE_INDEX_32I, + CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, h->scalarType)); + + KOKKOS_CUSPARSE_SAFE_CALL( + cusparseCreateCsr(&h->descr_C, m, k, 0, row_mapC.data(), nullptr, + nullptr, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, + CUSPARSE_INDEX_BASE_ZERO, h->scalarType)); + + //---------------------------------------------------------------------- + // query workEstimation buffer size, allocate, then call again with buffer. + KOKKOS_CUSPARSE_SAFE_CALL(cusparseSpGEMM_workEstimation( + h->cusparseHandle, h->opA, h->opB, &alpha, h->descr_A, h->descr_B, + &beta, h->descr_C, h->scalarType, h->alg, h->spgemmDescr, + &h->bufferSize3, nullptr)); + KOKKOS_IMPL_CUDA_SAFE_CALL( + cudaMalloc((void **)&h->buffer3, h->bufferSize3)); + KOKKOS_CUSPARSE_SAFE_CALL(cusparseSpGEMM_workEstimation( + h->cusparseHandle, h->opA, h->opB, &alpha, h->descr_A, h->descr_B, + &beta, h->descr_C, h->scalarType, h->alg, h->spgemmDescr, + &h->bufferSize3, h->buffer3)); + + //---------------------------------------------------------------------- + // query compute buffer size, allocate, then call again with buffer. - // In non-reuse interface, forced to give A,B dummy values to - // cusparseSpGEMM_compute. And it actually reads them, so they must be - // allocated and of the correct type. This compute will be called again in - // numeric with the real values. - // - // The dummy values can be uninitialized. cusparseSpGEMM_compute does - // not remove numerical zeros from the sparsity pattern. - void *dummyValues; - KOKKOS_IMPL_CUDA_SAFE_CALL(cudaMalloc( - &dummyValues, - sizeof(scalar_type) * std::max(entriesA.extent(0), entriesB.extent(0)))); - - KOKKOS_CUSPARSE_SAFE_CALL(cusparseCreateCsr( - &h->descr_A, m, n, entriesA.extent(0), (void *)row_mapA.data(), - (void *)entriesA.data(), dummyValues, CUSPARSE_INDEX_32I, - CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, h->scalarType)); - - KOKKOS_CUSPARSE_SAFE_CALL(cusparseCreateCsr( - &h->descr_B, n, k, entriesB.extent(0), (void *)row_mapB.data(), - (void *)entriesB.data(), dummyValues, CUSPARSE_INDEX_32I, - CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, h->scalarType)); - - KOKKOS_CUSPARSE_SAFE_CALL( - cusparseCreateCsr(&h->descr_C, m, k, 0, row_mapC.data(), nullptr, nullptr, - CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, - CUSPARSE_INDEX_BASE_ZERO, h->scalarType)); - - //---------------------------------------------------------------------- - // query workEstimation buffer size, allocate, then call again with buffer. - KOKKOS_CUSPARSE_SAFE_CALL(cusparseSpGEMM_workEstimation( - h->cusparseHandle, h->opA, h->opB, &alpha, h->descr_A, h->descr_B, &beta, - h->descr_C, h->scalarType, h->alg, h->spgemmDescr, &h->bufferSize3, - nullptr)); - KOKKOS_IMPL_CUDA_SAFE_CALL(cudaMalloc((void **)&h->buffer3, h->bufferSize3)); - KOKKOS_CUSPARSE_SAFE_CALL(cusparseSpGEMM_workEstimation( - h->cusparseHandle, h->opA, h->opB, &alpha, h->descr_A, h->descr_B, &beta, - h->descr_C, h->scalarType, h->alg, h->spgemmDescr, &h->bufferSize3, - h->buffer3)); - cudaFree(h->buffer3); - h->buffer3 = nullptr; - - //---------------------------------------------------------------------- - // query compute buffer size, allocate, then call again with buffer. - - KOKKOS_CUSPARSE_SAFE_CALL(cusparseSpGEMM_compute( - h->cusparseHandle, h->opA, h->opB, &alpha, h->descr_A, h->descr_B, &beta, - h->descr_C, h->scalarType, CUSPARSE_SPGEMM_DEFAULT, h->spgemmDescr, - &h->bufferSize4, nullptr)); - KOKKOS_IMPL_CUDA_SAFE_CALL(cudaMalloc((void **)&h->buffer4, h->bufferSize4)); - KOKKOS_CUSPARSE_SAFE_CALL(cusparseSpGEMM_compute( - h->cusparseHandle, h->opA, h->opB, &alpha, h->descr_A, h->descr_B, &beta, - h->descr_C, h->scalarType, CUSPARSE_SPGEMM_DEFAULT, h->spgemmDescr, - &h->bufferSize4, h->buffer4)); - KOKKOS_IMPL_CUDA_SAFE_CALL(cudaFree(dummyValues)); - - int64_t C_nrow, C_ncol, C_nnz; - KOKKOS_CUSPARSE_SAFE_CALL( - cusparseSpMatGetSize(h->descr_C, &C_nrow, &C_ncol, &C_nnz)); - if (C_nnz > std::numeric_limits::max()) { - throw std::runtime_error("nnz of C overflowed over 32-bit int\n"); + KOKKOS_CUSPARSE_SAFE_CALL(cusparseSpGEMM_compute( + h->cusparseHandle, h->opA, h->opB, &alpha, h->descr_A, h->descr_B, + &beta, h->descr_C, h->scalarType, CUSPARSE_SPGEMM_DEFAULT, + h->spgemmDescr, &h->bufferSize4, nullptr)); + KOKKOS_IMPL_CUDA_SAFE_CALL( + cudaMalloc((void **)&h->buffer4, h->bufferSize4)); + KOKKOS_CUSPARSE_SAFE_CALL(cusparseSpGEMM_compute( + h->cusparseHandle, h->opA, h->opB, &alpha, h->descr_A, h->descr_B, + &beta, h->descr_C, h->scalarType, CUSPARSE_SPGEMM_DEFAULT, + h->spgemmDescr, &h->bufferSize4, h->buffer4)); + int64_t C_nrow, C_ncol, C_nnz; + KOKKOS_CUSPARSE_SAFE_CALL( + cusparseSpMatGetSize(h->descr_C, &C_nrow, &C_ncol, &C_nnz)); + if (C_nnz > std::numeric_limits::max()) { + throw std::runtime_error("nnz of C overflowed over 32-bit int\n"); + } + handle->set_c_nnz(C_nnz); + handle->set_call_symbolic(); + firstSymbolicCall = true; } - handle->set_c_nnz(C_nnz); - handle->set_call_symbolic(); - handle->set_computed_rowptrs(); + + if (computeRowptrs && !handle->are_rowptrs_computed()) { + auto h = handle->get_cusparse_spgemm_handle(); + auto C_nnz = handle->get_c_nnz(); + if (!firstSymbolicCall) { + // This is not the first call to symbolic, so dummyValues_AB was not + // allocated above. But, descr_A and descr_B will have been saved in the + // handle, so we can reuse those. + KOKKOS_IMPL_CUDA_SAFE_CALL(cudaMalloc( + &dummyValues_AB, sizeof(scalar_type) * std::max(entriesA.extent(0), + entriesB.extent(0)))); + KOKKOS_CUSPARSE_SAFE_CALL( + cusparseCsrSetPointers(h->descr_A, (void *)row_mapA.data(), + (void *)entriesA.data(), dummyValues_AB)); + KOKKOS_CUSPARSE_SAFE_CALL( + cusparseCsrSetPointers(h->descr_B, (void *)row_mapB.data(), + (void *)entriesB.data(), dummyValues_AB)); + } + void *dummyEntries_C, *dummyValues_C; + KOKKOS_IMPL_CUDA_SAFE_CALL( + cudaMalloc(&dummyEntries_C, sizeof(ordinal_type) * C_nnz)); + KOKKOS_IMPL_CUDA_SAFE_CALL( + cudaMalloc(&dummyValues_C, sizeof(scalar_type) * C_nnz)); + KOKKOS_CUSPARSE_SAFE_CALL(cusparseCsrSetPointers( + h->descr_C, (void *)row_mapC.data(), dummyEntries_C, dummyValues_C)); + + KOKKOS_CUSPARSE_SAFE_CALL(cusparseSpGEMM_copy( + h->cusparseHandle, h->opA, h->opB, &alpha, h->descr_A, h->descr_B, + &beta, h->descr_C, h->scalarType, CUSPARSE_SPGEMM_DEFAULT, + h->spgemmDescr)); + + KOKKOS_IMPL_CUDA_SAFE_CALL(cudaFree(dummyValues_C)); + KOKKOS_IMPL_CUDA_SAFE_CALL(cudaFree(dummyEntries_C)); + handle->set_computed_rowptrs(); + } + KOKKOS_IMPL_CUDA_SAFE_CALL(cudaFree(dummyValues_AB)); } #else diff --git a/sparse/unit_test/Test_Sparse_spgemm.hpp b/sparse/unit_test/Test_Sparse_spgemm.hpp index 1e60b35b81..51fdfb7def 100644 --- a/sparse/unit_test/Test_Sparse_spgemm.hpp +++ b/sparse/unit_test/Test_Sparse_spgemm.hpp @@ -405,6 +405,70 @@ void test_spgemm(lno_t m, lno_t k, lno_t n, size_type nnz, lno_t bandwidth, // device::execution_space::finalize(); } +template +void test_spgemm_symbolic(bool callSymbolicFirst, bool testEmpty) { + using crsMat_t = CrsMatrix; + using graph_t = typename crsMat_t::StaticCrsGraphType; + using values_t = typename crsMat_t::values_type; + using entries_t = typename graph_t::entries_type; + using rowmap_t = typename graph_t::row_map_type::non_const_type; + using const_rowmap_t = typename graph_t::row_map_type; + using KernelHandle = KokkosKernels::Experimental::KokkosKernelsHandle< + size_type, lno_t, scalar_t, typename device::execution_space, + typename device::memory_space, typename device::memory_space>; + // A is m*n, B is n*k, C is m*k + int m = 100; + int n = 300; + int k = 200; + crsMat_t A, B; + // Target 1000 total nonzeros in both A and B. + if (testEmpty) { + // Create A,B with the same dimensions, but zero entries + values_t emptyValues; + entries_t emptyEntries; + // Initialize these to 0 + rowmap_t A_rowmap("A rowmap", m + 1); + rowmap_t B_rowmap("B rowmap", n + 1); + A = crsMat_t("A", m, n, 0, emptyValues, A_rowmap, emptyEntries); + B = crsMat_t("B", n, k, 0, emptyValues, B_rowmap, emptyEntries); + } else { + size_type nnz = 1000; + A = KokkosSparse::Impl::kk_generate_sparse_matrix(m, n, nnz, 10, + 50); + nnz = 1000; + B = KokkosSparse::Impl::kk_generate_sparse_matrix(n, k, nnz, 10, + 50); + KokkosSparse::sort_crs_matrix(A); + KokkosSparse::sort_crs_matrix(B); + } + // Call reference impl to get complete product + crsMat_t C_reference; + Test::run_spgemm(A, B, SPGEMM_DEBUG, C_reference, false); + // Now call just symbolic, and specifically request that rowptrs be populated + // Make sure this never depends on C_rowmap being initialized + rowmap_t C_rowmap(Kokkos::view_alloc(Kokkos::WithoutInitializing, "rowmapC"), + m + 1); + Kokkos::deep_copy(C_rowmap, size_type(123)); + KernelHandle kh; + kh.create_spgemm_handle(); + if (callSymbolicFirst) { + KokkosSparse::Experimental::spgemm_symbolic( + &kh, m, n, k, A.graph.row_map, A.graph.entries, false, B.graph.row_map, + B.graph.entries, false, C_rowmap); + } + KokkosSparse::Experimental::spgemm_symbolic( + &kh, m, n, k, A.graph.row_map, A.graph.entries, false, B.graph.row_map, + B.graph.entries, false, C_rowmap, true); + kh.destroy_spgemm_handle(); + bool isCorrect = KokkosKernels::Impl::kk_is_identical_view< + const_rowmap_t, const_rowmap_t, size_type, + typename device::execution_space>(C_rowmap, C_reference.graph.row_map, 0); + EXPECT_TRUE(isCorrect) + << " spgemm_symbolic produced incorrect rowptrs - callSymbolicFirst = " + << callSymbolicFirst << ", empty A/B = " << testEmpty; +} + template void test_issue402() { @@ -502,6 +566,10 @@ void test_issue402() { test_spgemm(10, 10, 0, 0, 10, 10, true); \ test_spgemm(10, 10, 10, 0, 0, 0, false); \ test_spgemm(10, 10, 10, 0, 0, 0, true); \ + test_spgemm_symbolic(true, true); \ + test_spgemm_symbolic(false, true); \ + test_spgemm_symbolic(true, false); \ + test_spgemm_symbolic(false, false); \ test_issue402(); \ }