Add support for cross-validation using query ID (#4474)
* adding support for matrix slicing with query ID for cross-validation

* hail mary test of unrar installation for windows tests

* trying to modify tests to run in Github CI

* Remove dependency on wget and unrar

* Save error log from R test

* Relax assertion in test_training

* Use int instead of bool in C function interface

* Revise R interface

* Add XGDMatrixSliceDMatrixEx and keep old XGDMatrixSliceDMatrix for API compatibility
bryan-woods authored and hcho3 committed May 23, 2019
1 parent 5a567ec commit 278562d
Showing 9 changed files with 223 additions and 18 deletions.
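For context, here is a minimal sketch of what this change enables from the Python package. The data, group sizes, and parameter values below are illustrative only and are not taken from this commit: once a DMatrix carries query groups, xgboost.cv builds folds that keep whole query groups together instead of splitting rows arbitrarily.

    import numpy as np
    import xgboost

    # Synthetic ranking data: 20 rows split across four query groups of five rows.
    X = np.random.rand(20, 4)
    y = np.random.randint(0, 3, size=20)
    dtrain = xgboost.DMatrix(X, label=y)
    dtrain.set_group([5, 5, 5, 5])

    params = {'objective': 'rank:pairwise', 'eval_metric': 'ndcg'}
    # With this commit, cv() notices the group_ptr on the DMatrix and splits by group.
    results = xgboost.cv(params, dtrain, num_boost_round=10, nfold=2, as_pandas=False)
    print(sorted(results.keys()))  # train/test ndcg mean and std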
2 changes: 2 additions & 0 deletions CONTRIBUTORS.md
@@ -89,3 +89,5 @@ List of Contributors
* [Sam Wilkinson](https://samwilkinson.io)
* [Matthew Jones](https://github.com/mt-jones)
* [Jiaxiang Li](https://github.com/JiaxiangBU)
* [Bryan Woods](https://github.com/bryan-woods)
- Bryan added support for cross-validation for the ranking objective
2 changes: 2 additions & 0 deletions Jenkinsfile
@@ -340,6 +340,8 @@ def TestR(args) {
sh """
${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/build_test_rpkg.sh
"""
// Save error log, if any
archiveArtifacts artifacts: "xgboost.Rcheck/00install.out", allowEmptyArchive: true
deleteDir()
}
}
7 changes: 4 additions & 3 deletions R-package/src/xgboost_R.cc
@@ -136,9 +136,10 @@ SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) {
idxvec[i] = INTEGER(idxset)[i] - 1;
}
DMatrixHandle res;
CHECK_CALL(XGDMatrixSliceDMatrix(R_ExternalPtrAddr(handle),
BeginPtr(idxvec), len,
&res));
CHECK_CALL(XGDMatrixSliceDMatrixEx(R_ExternalPtrAddr(handle),
BeginPtr(idxvec), len,
&res,
0));
ret = PROTECT(R_MakeExternalPtr(res, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
R_API_END();
14 changes: 14 additions & 0 deletions include/xgboost/c_api.h
@@ -221,6 +221,20 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
const int *idxset,
bst_ulong len,
DMatrixHandle *out);
/*!
* \brief create a new dmatrix from sliced content of existing matrix
* \param handle instance of data matrix to be sliced
* \param idxset index set
* \param len length of index set
* \param out a sliced new matrix
* \param allow_groups allow slicing of an array with groups
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle,
const int *idxset,
bst_ulong len,
DMatrixHandle *out,
int allow_groups);
/*!
* \brief free space in data matrix
* \return 0 when success, -1 when failure happens
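As a rough illustration of the new C entry point, the sketch below drives XGDMatrixSliceDMatrixEx through ctypes using private helpers from xgboost.core; it is illustrative only, and the DMatrix.slice wrapper changed below is the supported path.

    import ctypes
    import numpy as np
    import xgboost
    from xgboost.core import _LIB, _check_call, c_array, c_bst_ulong

    # Small DMatrix with two query groups of three rows each.
    dmat = xgboost.DMatrix(np.random.rand(6, 3), label=np.zeros(6))
    dmat.set_group([3, 3])

    # Slice out the first group. Passing allow_groups=1 lets the slice proceed
    # even though group_ptr is set; the caller then calls set_group on the result.
    idx = [0, 1, 2]
    out = ctypes.c_void_p()
    _check_call(_LIB.XGDMatrixSliceDMatrixEx(
        dmat.handle,
        c_array(ctypes.c_int, idx),
        c_bst_ulong(len(idx)),
        ctypes.byref(out),
        ctypes.c_int(1)))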
3 changes: 2 additions & 1 deletion jvm-packages/xgboost4j/src/native/xgboost4j.cpp
@@ -279,7 +279,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSliceDMat
jint* indexset = jenv->GetIntArrayElements(jindexset, 0);
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jindexset);

jint ret = (jint) XGDMatrixSliceDMatrix(handle, (int const *)indexset, len, &result);
// default to not allowing slicing with group ID specified -- feel free to add if necessary
jint ret = (jint) XGDMatrixSliceDMatrixEx(handle, (int const *)indexset, len, &result, 0);
setHandle(jenv, jout, result);
//release
jenv->ReleaseIntArrayElements(jindexset, indexset, 0);
13 changes: 8 additions & 5 deletions python-package/xgboost/core.py
@@ -795,13 +795,15 @@ def num_col(self):
ctypes.byref(ret)))
return ret.value

def slice(self, rindex):
def slice(self, rindex, allow_groups=False):
"""Slice the DMatrix and return a new DMatrix that only contains `rindex`.
Parameters
----------
rindex : list
List of indices to be selected.
allow_groups : boolean
Allow slicing of a matrix with a groups attribute
Returns
-------
@@ -811,10 +813,11 @@ def slice(self, rindex):
res = DMatrix(None, feature_names=self.feature_names,
feature_types=self.feature_types)
res.handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixSliceDMatrix(self.handle,
c_array(ctypes.c_int, rindex),
c_bst_ulong(len(rindex)),
ctypes.byref(res.handle)))
_check_call(_LIB.XGDMatrixSliceDMatrixEx(self.handle,
c_array(ctypes.c_int, rindex),
c_bst_ulong(len(rindex)),
ctypes.byref(res.handle),
ctypes.c_int(1 if allow_groups else 0)))
return res

@property
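A short sketch of the new keyword at the Python level (synthetic data, illustrative only): slicing a DMatrix that has groups now succeeds when allow_groups=True, and the caller re-attaches group sizes for the selected rows, as mkgroupfold does below.

    import numpy as np
    import xgboost

    dmat = xgboost.DMatrix(np.random.rand(6, 3), label=np.arange(6))
    dmat.set_group([3, 3])                    # two query groups

    # With allow_groups=False this raises, because XGDMatrixSliceDMatrixEx
    # rejects matrices with group structure unless allow_groups is set.
    first_group = dmat.slice([0, 1, 2], allow_groups=True)
    first_group.set_group([3])                # re-attach the group size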
62 changes: 57 additions & 5 deletions python-package/xgboost/training.py
@@ -234,6 +234,56 @@ def eval(self, iteration, feval):
return self.bst.eval_set(self.watchlist, iteration, feval)


def groups_to_rows(groups, boundaries):
"""
Given group row boundaries, convert group indexes to row indexes
:param groups: list of groups for testing
:param boundaries: row index limits of each group
:return: row indexes covering the given groups
"""
return np.concatenate([np.arange(boundaries[g], boundaries[g+1]) for g in groups])


def mkgroupfold(dall, nfold, param, evals=(), fpreproc=None, shuffle=True):
"""
Make n folds for cross-validation maintaining groups
:return: cross-validation folds
"""
# we have groups for pairwise ranking... get a list of the group indexes
group_boundaries = dall.get_uint_info('group_ptr')
group_sizes = np.diff(group_boundaries)

if shuffle is True:
idx = np.random.permutation(len(group_sizes))
else:
idx = np.arange(len(group_sizes))
# list by fold of test group indexes
out_group_idset = np.array_split(idx, nfold)
# list by fold of train group indexes
in_group_idset = [np.concatenate([out_group_idset[i] for i in range(nfold) if k != i])
for k in range(nfold)]
# from the group indexes, convert them to row indexes
in_idset = [groups_to_rows(in_groups, group_boundaries) for in_groups in in_group_idset]
out_idset = [groups_to_rows(out_groups, group_boundaries) for out_groups in out_group_idset]

# build the folds by taking the appropriate slices
ret = []
for k in range(nfold):
# perform the slicing using the indexes determined by the above methods
dtrain = dall.slice(in_idset[k], allow_groups=True)
dtrain.set_group(group_sizes[in_group_idset[k]])
dtest = dall.slice(out_idset[k], allow_groups=True)
dtest.set_group(group_sizes[out_group_idset[k]])
# run preprocessing on the data set if needed
if fpreproc is not None:
dtrain, dtest, tparam = fpreproc(dtrain, dtest, param.copy())
else:
tparam = param
plst = list(tparam.items()) + [('eval_metric', itm) for itm in evals]
ret.append(CVPack(dtrain, dtest, plst))
return ret


def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
folds=None, shuffle=True):
"""
@@ -243,16 +293,17 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
np.random.seed(seed)

if stratified is False and folds is None:
# Do standard k-fold cross validation
# Do standard k-fold cross validation. Automatically determine the folds.
if len(dall.get_uint_info('group_ptr')) > 1:
return mkgroupfold(dall, nfold, param, evals=evals, fpreproc=fpreproc, shuffle=shuffle)

if shuffle is True:
idx = np.random.permutation(dall.num_row())
else:
idx = np.arange(dall.num_row())
out_idset = np.array_split(idx, nfold)
in_idset = [
np.concatenate([out_idset[i] for i in range(nfold) if k != i])
for k in range(nfold)
]
in_idset = [np.concatenate([out_idset[i] for i in range(nfold) if k != i])
for k in range(nfold)]
elif folds is not None:
# Use user specified custom split using indices
try:
@@ -274,6 +325,7 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,

ret = []
for k in range(nfold):
# perform the slicing using the indexes determined by the above methods
dtrain = dall.slice(in_idset[k])
dtest = dall.slice(out_idset[k])
# run preprocessing on the data set if needed
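A small worked example of the indexing logic above, re-implemented standalone for illustration: group_ptr-style boundaries are expanded into per-fold row indexes, mirroring what groups_to_rows and mkgroupfold do.

    import numpy as np

    def groups_to_rows(groups, boundaries):
        # Expand each selected group into its contiguous range of row indexes.
        return np.concatenate([np.arange(boundaries[g], boundaries[g + 1]) for g in groups])

    # group_ptr-style boundaries for three groups of sizes 2, 3 and 1
    boundaries = np.array([0, 2, 5, 6])
    group_sizes = np.diff(boundaries)                # -> [2, 3, 1]

    # put groups 0 and 2 in the held-out fold; group 1 stays in the training fold
    test_rows = groups_to_rows([0, 2], boundaries)   # -> [0, 1, 5]
    train_rows = groups_to_rows([1], boundaries)     # -> [2, 3, 4]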
21 changes: 17 additions & 4 deletions src/c_api/c_api.cc
@@ -674,6 +674,14 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
const int* idxset,
xgboost::bst_ulong len,
DMatrixHandle* out) {
return XGDMatrixSliceDMatrixEx(handle, idxset, len, out, 0);
}

XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle,
const int* idxset,
xgboost::bst_ulong len,
DMatrixHandle* out,
int allow_groups) {
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());

API_BEGIN();
@@ -682,8 +690,10 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
src.CopyFrom(static_cast<std::shared_ptr<DMatrix>*>(handle)->get());
data::SimpleCSRSource& ret = *source;

CHECK_EQ(src.info.group_ptr_.size(), 0U)
if (!allow_groups) {
CHECK_EQ(src.info.group_ptr_.size(), 0U)
<< "slice does not support group structure";
}

ret.Clear();
ret.info.num_row_ = len;
@@ -814,11 +824,14 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
const std::vector<unsigned>* vec = nullptr;
if (!std::strcmp(field, "root_index")) {
vec = &info.root_index_;
*out_len = static_cast<xgboost::bst_ulong>(vec->size());
*out_dptr = dmlc::BeginPtr(*vec);
} else if (!std::strcmp(field, "group_ptr")) {
vec = &info.group_ptr_;
} else {
LOG(FATAL) << "Unknown uint field name " << field;
LOG(FATAL) << "Unknown comp uint field name " << field
<< " with comparison " << std::strcmp(field, "group_ptr");
}
*out_len = static_cast<xgboost::bst_ulong>(vec->size());
*out_dptr = dmlc::BeginPtr(*vec);
API_END();
}

117 changes: 117 additions & 0 deletions tests/python/test_ranking.py
@@ -1,6 +1,16 @@
import numpy as np
from scipy.sparse import csr_matrix
import xgboost
import sys
import os
from sklearn.datasets import load_svmlight_files
import unittest
import itertools
import glob
import shutil
import urllib.request
import zipfile


def test_ranking_with_unweighted_data():
Xrow = np.array([1, 2, 6, 8, 11, 14, 16, 17])
Expand Down Expand Up @@ -63,3 +73,110 @@ def test_ranking_with_weighted_data():
# the ranking predictor will first try to correctly sort the last query group
# before correctly sorting other groups.
assert all(p <= q for p, q in zip(is_sorted, is_sorted[1:]))


class TestRanking(unittest.TestCase):

@classmethod
def setUpClass(cls):
"""
Download and setup the test fixtures
"""
# download the test data
cls.dpath = 'demo/rank/'
src = 'https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip'
target = cls.dpath + '/MQ2008.zip'
urllib.request.urlretrieve(url=src, filename=target)

with zipfile.ZipFile(target, 'r') as f:
f.extractall(path=cls.dpath)

x_train, y_train, qid_train, x_test, y_test, qid_test, x_valid, y_valid, qid_valid = load_svmlight_files(
(cls.dpath + "MQ2008/Fold1/train.txt",
cls.dpath + "MQ2008/Fold1/test.txt",
cls.dpath + "MQ2008/Fold1/vali.txt"),
query_id=True, zero_based=False)
# instantiate the matrices
cls.dtrain = xgboost.DMatrix(x_train, y_train)
cls.dvalid = xgboost.DMatrix(x_valid, y_valid)
cls.dtest = xgboost.DMatrix(x_test, y_test)
# set the group counts from the query IDs
cls.dtrain.set_group([len(list(items))
for _key, items in itertools.groupby(qid_train)])
cls.dtest.set_group([len(list(items))
for _key, items in itertools.groupby(qid_test)])
cls.dvalid.set_group([len(list(items))
for _key, items in itertools.groupby(qid_valid)])
# save the query IDs for testing
cls.qid_train = qid_train
cls.qid_test = qid_test
cls.qid_valid = qid_valid

# model training parameters
cls.params = {'objective': 'rank:pairwise',
'booster': 'gbtree',
'silent': 0,
'eval_metric': ['ndcg']
}

@classmethod
def tearDownClass(cls):
"""
Cleanup test artifacts from download and unpacking
:return:
"""
os.remove(cls.dpath + "MQ2008.zip")
shutil.rmtree(cls.dpath + "MQ2008")

def test_training(self):
"""
Train an XGBoost ranking model
"""
# specify validations set to watch performance
watchlist = [(self.dtest, 'eval'), (self.dtrain, 'train')]
bst = xgboost.train(self.params, self.dtrain, num_boost_round=2500,
early_stopping_rounds=10, evals=watchlist)
assert bst.best_score > 0.98

def test_cv(self):
"""
Test cross-validation with a group specified
"""
cv = xgboost.cv(self.params, self.dtrain, num_boost_round=2500,
early_stopping_rounds=10, nfold=10, as_pandas=False)
assert isinstance(cv, dict)
self.assertSetEqual(set(cv.keys()), {'test-ndcg-mean', 'train-ndcg-mean', 'test-ndcg-std', 'train-ndcg-std'},
"CV results dict key mismatch")

def test_cv_no_shuffle(self):
"""
Test cross-validation with a group specified, without shuffling the groups
"""
cv = xgboost.cv(self.params, self.dtrain, num_boost_round=2500,
early_stopping_rounds=10, shuffle=False, nfold=10, as_pandas=False)
assert isinstance(cv, dict)
assert len(cv) == 4

def test_get_group(self):
"""
Retrieve the group number from the dmatrix
"""
# control that should work
self.dtrain.get_uint_info('root_index')
# test the new getter
self.dtrain.get_uint_info('group_ptr')

for d, qid in [(self.dtrain, self.qid_train),
(self.dvalid, self.qid_valid),
(self.dtest, self.qid_test)]:
# size of each group
group_sizes = np.array([len(list(items))
for _key, items in itertools.groupby(qid)])
# indexes of group boundaries
group_limits = d.get_uint_info('group_ptr')
assert len(group_limits) == len(group_sizes)+1
assert np.array_equal(np.diff(group_limits), group_sizes)
assert np.array_equal(group_sizes, np.diff(d.get_uint_info('group_ptr')))
assert np.array_equal(group_limits, d.get_uint_info('group_ptr'))
