diff --git a/include/xgboost/cache.h b/include/xgboost/cache.h index 05610f82135b..7fa72b89cae2 100644 --- a/include/xgboost/cache.h +++ b/include/xgboost/cache.h @@ -147,6 +147,25 @@ class DMatrixCache { } return container_.at(key).value; } + /** + * \brief Re-initialize the item in cache. + * + * Since the shared_ptr is used to hold the item, any reference that lives outside of + * the cache can no-longer be reached from the cache. + * + * We use reset instead of erase to avoid walking through the whole cache for renewing + * a single item. (the cache is FIFO, needs to maintain the order). + */ + template + void ResetItem(std::shared_ptr m, Args const&... args) { + std::lock_guard guard{lock_}; + CheckConsistent(); + auto key = Key{m.get(), std::this_thread::get_id()}; + auto it = container_.find(key); + CHECK(it != container_.cend()); + it->second = {m, std::make_shared(args...)}; + CheckConsistent(); + } /** * \brief Get a const reference to the underlying hash map. Clear expired caches before * returning. diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index fc6099b1b8da..c7563e92234f 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -411,9 +411,9 @@ class EvalRankWithCache : public Metric { double Evaluate(HostDeviceVector const& preds, std::shared_ptr p_fmat) override { auto const& info = p_fmat->Info(); - auto& p_cache = cache_.CacheItem(p_fmat, ctx_, info, param_); + auto p_cache = cache_.CacheItem(p_fmat, ctx_, info, param_); if (p_cache->Param() != param_) { - p_cache = std::make_shared(ctx_, info, param_); + cache_.ResetItem(p_fmat, ctx_, info, param_); } CHECK(p_cache->Param() == param_); CHECK_EQ(preds.Size(), info.labels.Size());