From 05a7e32c4b0cec9f2bfaee04714c407aa5223607 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 24 Mar 2023 01:53:34 +0800 Subject: [PATCH] Restore. --- src/data/data.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/data/data.cc b/src/data/data.cc index c8e3019ce506..0485c5895c4c 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -478,6 +478,11 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) { return; } else if (key == "label") { CopyTensorInfoImpl(ctx, arr, &this->labels); + if (this->num_row_ != 0 && this->labels.Shape(0) != this->num_row_) { + CHECK_EQ(this->labels.Size() % this->num_row_, 0) << "Incorrect size for labels."; + size_t n_targets = this->labels.Size() / this->num_row_; + this->labels.Reshape(this->num_row_, n_targets); + } auto const& h_labels = labels.Data()->ConstHostVector(); auto valid = std::none_of(h_labels.cbegin(), h_labels.cend(), data::LabelsCheck{}); CHECK(valid) << "Label contains NaN, infinity or a value too large."; @@ -740,6 +745,7 @@ void MetaInfo::Validate(std::int32_t device) const { return; } if (labels.Size() != 0) { + std::cout << labels.Shape(0) << " nr:" << num_row_ << std::endl; CHECK_EQ(labels.Shape(0), num_row_) << "Size of labels must equal to number of rows."; CheckDevice(device, labels); return;