-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[R] Move all DMatrix fields to function arguments #9862
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,13 +8,24 @@ | |
#' a \code{dgRMatrix} object, | ||
#' a \code{dsparseVector} object (only when making predictions from a fitted model, will be | ||
#' interpreted as a row vector), or a character string representing a filename. | ||
#' @param info a named list of additional information to store in the \code{xgb.DMatrix} object. | ||
#' See \code{\link{setinfo}} for the specific allowed kinds of | ||
#' @param label Label of the training data. | ||
#' @param weight Weight for each instance. | ||
#' | ||
#' Note that, for ranking task, weights are per-group. In ranking task, one weight | ||
#' is assigned to each group (not each data point). This is because we | ||
#' only care about the relative ordering of data points within each group, | ||
#' so it doesn't make sense to assign weights to individual data points. | ||
#' @param base_margin Base margin used for boosting from existing model. | ||
#' @param missing a float value to represents missing values in data (used only when input is a dense matrix). | ||
#' It is useful when a 0 or some other extreme value represents missing values in data. | ||
#' @param silent whether to suppress printing an informational message after loading from a file. | ||
#' @param feature_names Set names for features. | ||
#' @param nthread Number of threads used for creating DMatrix. | ||
#' @param ... the \code{info} data could be passed directly as parameters, without creating an \code{info} list. | ||
#' @param group Group size for all ranking group. | ||
#' @param qid Query ID for data samples, used for ranking. | ||
#' @param label_lower_bound Lower bound for survival training. | ||
#' @param label_upper_bound Upper bound for survival training. | ||
#' @param feature_weights Set feature weights for column sampling. | ||
#' | ||
#' @details | ||
#' Note that DMatrix objects are not serializable through R functions such as \code{saveRDS} or \code{save}. | ||
|
@@ -34,8 +45,24 @@ | |
#' dtrain <- xgb.DMatrix('xgb.DMatrix.data') | ||
#' if (file.exists('xgb.DMatrix.data')) file.remove('xgb.DMatrix.data') | ||
#' @export | ||
xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthread = NULL, ...) { | ||
cnames <- NULL | ||
xgb.DMatrix <- function( | ||
data, | ||
label = NULL, | ||
weight = NULL, | ||
base_margin = NULL, | ||
missing = NA, | ||
silent = FALSE, | ||
feature_names = colnames(data), | ||
nthread = NULL, | ||
group = NULL, | ||
qid = NULL, | ||
label_lower_bound = NULL, | ||
label_upper_bound = NULL, | ||
feature_weights = NULL | ||
) { | ||
if (!is.null(group) && !is.null(qid)) { | ||
stop("Either one of 'group' or 'qid' should be NULL") | ||
} | ||
if (typeof(data) == "character") { | ||
if (length(data) > 1) | ||
stop("'data' has class 'character' and length ", length(data), | ||
|
@@ -44,7 +71,6 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre | |
handle <- .Call(XGDMatrixCreateFromFile_R, data, as.integer(silent)) | ||
} else if (is.matrix(data)) { | ||
handle <- .Call(XGDMatrixCreateFromMat_R, data, missing, as.integer(NVL(nthread, -1))) | ||
cnames <- colnames(data) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will add this back in #9828 . |
||
} else if (inherits(data, "dgCMatrix")) { | ||
handle <- .Call( | ||
XGDMatrixCreateFromCSC_R, | ||
|
@@ -55,7 +81,6 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre | |
missing, | ||
as.integer(NVL(nthread, -1)) | ||
) | ||
cnames <- colnames(data) | ||
} else if (inherits(data, "dgRMatrix")) { | ||
handle <- .Call( | ||
XGDMatrixCreateFromCSR_R, | ||
|
@@ -66,7 +91,6 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre | |
missing, | ||
as.integer(NVL(nthread, -1)) | ||
) | ||
cnames <- colnames(data) | ||
} else if (inherits(data, "dsparseVector")) { | ||
indptr <- c(0L, as.integer(length(data@i))) | ||
ind <- as.integer(data@i) - 1L | ||
|
@@ -82,17 +106,38 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre | |
} else { | ||
stop("xgb.DMatrix does not support construction from ", typeof(data)) | ||
} | ||
|
||
dmat <- handle | ||
attributes(dmat) <- list(class = "xgb.DMatrix") | ||
if (!is.null(cnames)) { | ||
setinfo(dmat, "feature_name", cnames) | ||
} | ||
|
||
info <- append(info, list(...)) | ||
for (i in seq_along(info)) { | ||
p <- info[i] | ||
setinfo(dmat, names(p), p[[1]]) | ||
if (!is.null(label)) { | ||
setinfo(dmat, "label", label) | ||
} | ||
if (!is.null(weight)) { | ||
setinfo(dmat, "weight", weight) | ||
} | ||
if (!is.null(base_margin)) { | ||
setinfo(dmat, "base_margin", base_margin) | ||
} | ||
if (!is.null(feature_names)) { | ||
setinfo(dmat, "feature_name", feature_names) | ||
} | ||
if (!is.null(group)) { | ||
setinfo(dmat, "group", group) | ||
} | ||
if (!is.null(qid)) { | ||
setinfo(dmat, "qid", qid) | ||
} | ||
if (!is.null(label_lower_bound)) { | ||
setinfo(dmat, "label_lower_bound", label_lower_bound) | ||
} | ||
if (!is.null(label_upper_bound)) { | ||
setinfo(dmat, "label_upper_bound", label_upper_bound) | ||
} | ||
if (!is.null(feature_weights)) { | ||
setinfo(dmat, "feature_weights", feature_weights) | ||
} | ||
|
||
return(dmat) | ||
} | ||
|
||
|
@@ -211,14 +256,20 @@ dimnames.xgb.DMatrix <- function(x) { | |
#' The \code{name} field can be one of the following: | ||
#' | ||
#' \itemize{ | ||
#' \item \code{label}: label XGBoost learn from ; | ||
#' \item \code{weight}: to do a weight rescale ; | ||
#' \item \code{base_margin}: base margin is the base prediction XGBoost will boost from ; | ||
#' \item \code{nrow}: number of rows of the \code{xgb.DMatrix}. | ||
#' | ||
#' \item \code{label} | ||
#' \item \code{weight} | ||
#' \item \code{base_margin} | ||
#' \item \code{label_lower_bound} | ||
#' \item \code{label_upper_bound} | ||
#' \item \code{group} | ||
#' \item \code{feature_type} | ||
#' \item \code{feature_name} | ||
#' \item \code{nrow} | ||
#' } | ||
#' See the documentation for \link{xgb.DMatrix} for more information about these fields. | ||
#' | ||
#' \code{group} can be setup by \code{setinfo} but can't be retrieved by \code{getinfo}. | ||
#' Note that, while 'qid' cannot be retrieved, it's possible to get the equivalent 'group' | ||
#' for a DMatrix that had 'qid' assigned. | ||
#' | ||
#' @examples | ||
#' data(agaricus.train, package='xgboost') | ||
|
@@ -236,24 +287,37 @@ getinfo <- function(object, ...) UseMethod("getinfo") | |
#' @rdname getinfo | ||
#' @export | ||
getinfo.xgb.DMatrix <- function(object, name, ...) { | ||
allowed_int_fields <- 'group' | ||
allowed_float_fields <- c( | ||
'label', 'weight', 'base_margin', | ||
'label_lower_bound', 'label_upper_bound' | ||
) | ||
allowed_str_fields <- c("feature_type", "feature_name") | ||
allowed_fields <- c(allowed_float_fields, allowed_int_fields, allowed_str_fields, 'nrow') | ||
|
||
if (typeof(name) != "character" || | ||
length(name) != 1 || | ||
!name %in% c('label', 'weight', 'base_margin', 'nrow', | ||
'label_lower_bound', 'label_upper_bound', "feature_type", "feature_name")) { | ||
stop( | ||
"getinfo: name must be one of the following\n", | ||
" 'label', 'weight', 'base_margin', 'nrow', 'label_lower_bound', 'label_upper_bound', 'feature_type', 'feature_name'" | ||
) | ||
!name %in% allowed_fields) { | ||
stop("getinfo: name must be one of the following\n", | ||
paste(paste0("'", allowed_fields, "'"), collapse = ", ")) | ||
} | ||
if (name == "feature_name" || name == "feature_type") { | ||
if (name == "nrow") { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this useful in practice? We have There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Then I guess it's not really a must-have. |
||
ret <- nrow(object) | ||
} else if (name %in% allowed_str_fields) { | ||
ret <- .Call(XGDMatrixGetStrFeatureInfo_R, object, name) | ||
} else if (name != "nrow") { | ||
ret <- .Call(XGDMatrixGetInfo_R, object, name) | ||
} else if (name %in% allowed_float_fields) { | ||
ret <- .Call(XGDMatrixGetFloatInfo_R, object, name) | ||
if (length(ret) > nrow(object)) { | ||
ret <- matrix(ret, nrow = nrow(object), byrow = TRUE) | ||
} | ||
} else if (name %in% allowed_int_fields) { | ||
if (name == "group") { | ||
name <- "group_ptr" | ||
} | ||
ret <- .Call(XGDMatrixGetUIntInfo_R, object, name) | ||
if (length(ret) > nrow(object)) { | ||
ret <- matrix(ret, nrow = nrow(object), byrow = TRUE) | ||
} | ||
} else { | ||
ret <- nrow(object) | ||
} | ||
if (length(ret) == 0) return(NULL) | ||
return(ret) | ||
|
@@ -270,13 +334,15 @@ getinfo.xgb.DMatrix <- function(object, name, ...) { | |
#' @param ... other parameters | ||
#' | ||
#' @details | ||
#' The \code{name} field can be one of the following: | ||
#' | ||
#' \itemize{ | ||
#' \item \code{label}: label XGBoost learn from ; | ||
#' \item \code{weight}: to do a weight rescale ; | ||
#' \item \code{base_margin}: base margin is the base prediction XGBoost will boost from ; | ||
#' \item \code{group}: number of rows in each group (to use with \code{rank:pairwise} objective). | ||
#' See the documentation for \link{xgb.DMatrix} for possible fields that can be set | ||
#' (which correspond to arguments in that function). | ||
#' | ||
#' Note that the following fields are allowed in the construction of an \code{xgb.DMatrix} | ||
#' but \bold{aren't} allowed here:\itemize{ | ||
#' \item data | ||
#' \item missing | ||
#' \item silent | ||
#' \item nthread | ||
#' } | ||
#' | ||
#' @examples | ||
|
@@ -328,6 +394,12 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) { | |
.Call(XGDMatrixSetInfo_R, object, name, as.integer(info)) | ||
return(TRUE) | ||
} | ||
if (name == "qid") { | ||
if (NROW(info) != nrow(object)) | ||
stop("The length of qid assignments must equal to the number of rows in the input data") | ||
.Call(XGDMatrixSetInfo_R, object, name, as.integer(info)) | ||
return(TRUE) | ||
} | ||
if (name == "feature_weights") { | ||
if (length(info) != ncol(object)) { | ||
stop("The number of feature weights must equal to the number of columns in the input data") | ||
|
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, didn't know R supports this type of self-referencing parameter.