Skip to content

Commit

Permalink
Fix static_map::rehash and add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
sleeepyjack committed Oct 6, 2023
1 parent 737134c commit 8f38846
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 5 deletions.
8 changes: 4 additions & 4 deletions include/cuco/detail/static_map/static_map.inl
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Stora
cuda_stream_ref stream)
{
auto const is_filled = static_map_ns::detail::slot_is_filled<Key, T>(this->empty_key_sentinel());
this->impl_.rehash(this, is_filled, stream);
this->impl_->rehash(*this, is_filled, stream);
}

template <class Key,
Expand All @@ -374,7 +374,7 @@ void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Stora
{
auto const is_filled = static_map_ns::detail::slot_is_filled<Key, T>(this->empty_key_sentinel());
auto const extent = make_window_extent<static_map>(capacity);
this->impl_.rehash(extent, this, is_filled, stream);
this->impl_->rehash(extent, *this, is_filled, stream);
}

template <class Key,
Expand All @@ -389,7 +389,7 @@ void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Stora
cuda_stream_ref stream)
{
auto const is_filled = static_map_ns::detail::slot_is_filled<Key, T>(this->empty_key_sentinel());
this->impl_.rehash_async(this, is_filled, stream);
this->impl_->rehash_async(*this, is_filled, stream);
}

template <class Key,
Expand All @@ -405,7 +405,7 @@ void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Stora
{
auto const is_filled = static_map_ns::detail::slot_is_filled<Key, T>(this->empty_key_sentinel());
auto const extent = make_window_extent<static_map>(capacity);
this->impl_.rehash_async(extent, this, is_filled, stream);
this->impl_->rehash_async(extent, *this, is_filled, stream);
}

template <class Key,
Expand Down
3 changes: 2 additions & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ ConfigureTest(STATIC_MAP_TEST
static_map/key_sentinel_test.cu
static_map/shared_memory_test.cu
static_map/stream_test.cu
static_map/unique_sequence_test.cu)
static_map/unique_sequence_test.cu
static_map/rehash_test.cu)

###################################################################################################
# - dynamic_map tests -----------------------------------------------------------------------------
Expand Down
55 changes: 55 additions & 0 deletions tests/static_map/rehash_test.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cuco/static_map.cuh>

#include <thrust/device_vector.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/sequence.h>
#include <thrust/tuple.h>

#include <catch2/catch_test_macros.hpp>

TEST_CASE("Rehash", "")
{
using key_type = int;
using mapped_type = long;
constexpr std::size_t num_keys{400};

cuco::experimental::static_map<key_type, mapped_type> map{
num_keys, cuco::empty_key<key_type>{-1}, cuco::empty_value<mapped_type>{-1}};

thrust::device_vector<key_type> d_keys(num_keys);
thrust::device_vector<mapped_type> d_values(num_keys);

thrust::sequence(d_keys.begin(), d_keys.end());
thrust::sequence(d_values.begin(), d_values.end());

auto pairs_begin =
thrust::make_zip_iterator(thrust::make_tuple(d_keys.begin(), d_values.begin()));

map.insert(pairs_begin, pairs_begin + num_keys);

map.rehash();
REQUIRE(map.size() == num_keys);

map.rehash(num_keys * 2);
REQUIRE(map.size() == num_keys);

// TODO erase num_erased keys
// map.rehash()
// REQUIRE(map.size() == num_keys - num_erased);
}

0 comments on commit 8f38846

Please sign in to comment.