diff --git a/HugeCTR/src/hps/embedding_cache.cpp b/HugeCTR/src/hps/embedding_cache.cpp index 962f678b6a..f366b0fc26 100644 --- a/HugeCTR/src/hps/embedding_cache.cpp +++ b/HugeCTR/src/hps/embedding_cache.cpp @@ -296,7 +296,7 @@ void EmbeddingCache::lookup(size_t const table_id, float* const d_v HCTR_LIB_THROW( cudaMemcpyAsync(d_vectors, workspace_handler.h_missing_emb_vec_[table_id], num_keys * cache_config_.embedding_vec_size_[table_id] * sizeof(float), - cudaMemcpyHostToDevice, stream)); + cudaMemcpyDefault, stream)); HCTR_LIB_THROW(cudaStreamSynchronize(stream)); parameter_server_->free_buffer(memory_block); } @@ -552,12 +552,15 @@ void EmbeddingCache::finalize() { template void EmbeddingCache::insert_stream_for_sync( std::vector lookup_streams_) { - if (lookup_streams_.size() != gpu_emb_caches_.size()) { - HCTR_OWN_THROW(Error_t::WrongInput, - "The number of lookup streams is not equal to the number of embedding tables."); - } - for (size_t idx = 0; idx < lookup_streams_.size(); ++idx) { - gpu_emb_caches_[idx]->Record(lookup_streams_[idx]); + if (cache_config_.use_gpu_embedding_cache_) { + if (lookup_streams_.size() != gpu_emb_caches_.size()) { + HCTR_OWN_THROW( + Error_t::WrongInput, + "The number of lookup streams is not equal to the number of embedding tables."); + } + for (size_t idx = 0; idx < lookup_streams_.size(); ++idx) { + gpu_emb_caches_[idx]->Record(lookup_streams_[idx]); + } } } diff --git a/test/inference/hps/lookup_session_test.py b/test/inference/hps/lookup_session_test.py index 53ff13f2da..2875557056 100644 --- a/test/inference/hps/lookup_session_test.py +++ b/test/inference/hps/lookup_session_test.py @@ -146,6 +146,13 @@ def hps_dlpack(model_name, embedding_file_list, data_file, enable_cache, cache_t True, hugectr.inference.EmbeddingCacheType_t.Dynamic, ) + h1, h2 = hps_dlpack( + model_name, + embedding_file_list, + data_file, + False, + hugectr.inference.EmbeddingCacheType_t.Dynamic, + ) u1, u2 = hps_dlpack( model_name, embedding_file_list, data_file, True, hugectr.inference.EmbeddingCacheType_t.UVM ) @@ -173,15 +180,28 @@ def hps_dlpack(model_name, embedding_file_list, data_file, enable_cache, cache_t diff = u2.reshape(1, 26 * 16) - d2.reshape(1, 26 * 16) if diff.mean() > 1e-3: raise RuntimeError( - "The lookup results of UVM cache are consistent with Dynamic cache: {}".format( + "The lookup results of UVM cache are not consistent with Dynamic cache: {}".format( + diff.mean() + ) + ) + sys.exit(1) + else: + print( + "The lookup results on UVM are consistent with Dynamic cache, mse: {}".format( + diff.mean() + ) + ) + diff = h2.reshape(1, 26 * 16) - d2.reshape(1, 26 * 16) + if diff.mean() > 1e-3: + raise RuntimeError( + "The lookup results of Database backend are not consistent with Dynamic cache: {}".format( diff.mean() ) ) sys.exit(1) else: print( - "Pytorch dlpack on cpu results are consistent with native HPS lookup api, mse: {}".format( + "The lookup results on Database backend are consistent with Dynamic cache, mse: {}".format( diff.mean() ) ) - # hps_dlpack(model_name, network_file, dense_file, embedding_file_list, data_file, False)