Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R] Move all DMatrix fields to function arguments #9862

Merged
merged 2 commits into from
Dec 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 111 additions & 39 deletions R-package/R/xgb.DMatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,24 @@
#' a \code{dgRMatrix} object,
#' a \code{dsparseVector} object (only when making predictions from a fitted model, will be
#' interpreted as a row vector), or a character string representing a filename.
#' @param info a named list of additional information to store in the \code{xgb.DMatrix} object.
#' See \code{\link{setinfo}} for the specific allowed kinds of
#' @param label Label of the training data.
#' @param weight Weight for each instance.
#'
#' Note that, for ranking task, weights are per-group. In ranking task, one weight
#' is assigned to each group (not each data point). This is because we
#' only care about the relative ordering of data points within each group,
#' so it doesn't make sense to assign weights to individual data points.
#' @param base_margin Base margin used for boosting from existing model.
#' @param missing a float value to represents missing values in data (used only when input is a dense matrix).
#' It is useful when a 0 or some other extreme value represents missing values in data.
#' @param silent whether to suppress printing an informational message after loading from a file.
#' @param feature_names Set names for features.
#' @param nthread Number of threads used for creating DMatrix.
#' @param ... the \code{info} data could be passed directly as parameters, without creating an \code{info} list.
#' @param group Group size for all ranking group.
#' @param qid Query ID for data samples, used for ranking.
#' @param label_lower_bound Lower bound for survival training.
#' @param label_upper_bound Upper bound for survival training.
#' @param feature_weights Set feature weights for column sampling.
#'
#' @details
#' Note that DMatrix objects are not serializable through R functions such as \code{saveRDS} or \code{save}.
Expand All @@ -34,8 +45,24 @@
#' dtrain <- xgb.DMatrix('xgb.DMatrix.data')
#' if (file.exists('xgb.DMatrix.data')) file.remove('xgb.DMatrix.data')
#' @export
xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthread = NULL, ...) {
cnames <- NULL
xgb.DMatrix <- function(
data,
label = NULL,
weight = NULL,
base_margin = NULL,
missing = NA,
silent = FALSE,
feature_names = colnames(data),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, didn't know R supports this type of self-referencing parameter.

nthread = NULL,
group = NULL,
qid = NULL,
label_lower_bound = NULL,
label_upper_bound = NULL,
feature_weights = NULL
) {
if (!is.null(group) && !is.null(qid)) {
stop("Either one of 'group' or 'qid' should be NULL")
}
if (typeof(data) == "character") {
if (length(data) > 1)
stop("'data' has class 'character' and length ", length(data),
Expand All @@ -44,7 +71,6 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
handle <- .Call(XGDMatrixCreateFromFile_R, data, as.integer(silent))
} else if (is.matrix(data)) {
handle <- .Call(XGDMatrixCreateFromMat_R, data, missing, as.integer(NVL(nthread, -1)))
cnames <- colnames(data)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will add this back in #9828 .

} else if (inherits(data, "dgCMatrix")) {
handle <- .Call(
XGDMatrixCreateFromCSC_R,
Expand All @@ -55,7 +81,6 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
missing,
as.integer(NVL(nthread, -1))
)
cnames <- colnames(data)
} else if (inherits(data, "dgRMatrix")) {
handle <- .Call(
XGDMatrixCreateFromCSR_R,
Expand All @@ -66,7 +91,6 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
missing,
as.integer(NVL(nthread, -1))
)
cnames <- colnames(data)
} else if (inherits(data, "dsparseVector")) {
indptr <- c(0L, as.integer(length(data@i)))
ind <- as.integer(data@i) - 1L
Expand All @@ -82,17 +106,38 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
} else {
stop("xgb.DMatrix does not support construction from ", typeof(data))
}

dmat <- handle
attributes(dmat) <- list(class = "xgb.DMatrix")
if (!is.null(cnames)) {
setinfo(dmat, "feature_name", cnames)
}

info <- append(info, list(...))
for (i in seq_along(info)) {
p <- info[i]
setinfo(dmat, names(p), p[[1]])
if (!is.null(label)) {
setinfo(dmat, "label", label)
}
if (!is.null(weight)) {
setinfo(dmat, "weight", weight)
}
if (!is.null(base_margin)) {
setinfo(dmat, "base_margin", base_margin)
}
if (!is.null(feature_names)) {
setinfo(dmat, "feature_name", feature_names)
}
if (!is.null(group)) {
setinfo(dmat, "group", group)
}
if (!is.null(qid)) {
setinfo(dmat, "qid", qid)
}
if (!is.null(label_lower_bound)) {
setinfo(dmat, "label_lower_bound", label_lower_bound)
}
if (!is.null(label_upper_bound)) {
setinfo(dmat, "label_upper_bound", label_upper_bound)
}
if (!is.null(feature_weights)) {
setinfo(dmat, "feature_weights", feature_weights)
}

return(dmat)
}

Expand Down Expand Up @@ -211,14 +256,20 @@ dimnames.xgb.DMatrix <- function(x) {
#' The \code{name} field can be one of the following:
#'
#' \itemize{
#' \item \code{label}: label XGBoost learn from ;
#' \item \code{weight}: to do a weight rescale ;
#' \item \code{base_margin}: base margin is the base prediction XGBoost will boost from ;
#' \item \code{nrow}: number of rows of the \code{xgb.DMatrix}.
#'
#' \item \code{label}
#' \item \code{weight}
#' \item \code{base_margin}
#' \item \code{label_lower_bound}
#' \item \code{label_upper_bound}
#' \item \code{group}
#' \item \code{feature_type}
#' \item \code{feature_name}
#' \item \code{nrow}
#' }
#' See the documentation for \link{xgb.DMatrix} for more information about these fields.
#'
#' \code{group} can be setup by \code{setinfo} but can't be retrieved by \code{getinfo}.
#' Note that, while 'qid' cannot be retrieved, it's possible to get the equivalent 'group'
#' for a DMatrix that had 'qid' assigned.
#'
#' @examples
#' data(agaricus.train, package='xgboost')
Expand All @@ -236,24 +287,37 @@ getinfo <- function(object, ...) UseMethod("getinfo")
#' @rdname getinfo
#' @export
getinfo.xgb.DMatrix <- function(object, name, ...) {
allowed_int_fields <- 'group'
allowed_float_fields <- c(
'label', 'weight', 'base_margin',
'label_lower_bound', 'label_upper_bound'
)
allowed_str_fields <- c("feature_type", "feature_name")
allowed_fields <- c(allowed_float_fields, allowed_int_fields, allowed_str_fields, 'nrow')

if (typeof(name) != "character" ||
length(name) != 1 ||
!name %in% c('label', 'weight', 'base_margin', 'nrow',
'label_lower_bound', 'label_upper_bound', "feature_type", "feature_name")) {
stop(
"getinfo: name must be one of the following\n",
" 'label', 'weight', 'base_margin', 'nrow', 'label_lower_bound', 'label_upper_bound', 'feature_type', 'feature_name'"
)
!name %in% allowed_fields) {
stop("getinfo: name must be one of the following\n",
paste(paste0("'", allowed_fields, "'"), collapse = ", "))
}
if (name == "feature_name" || name == "feature_type") {
if (name == "nrow") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this useful in practice? We have dim.xgb.DMatrix exported.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then I guess it's not really a must-have.

ret <- nrow(object)
} else if (name %in% allowed_str_fields) {
ret <- .Call(XGDMatrixGetStrFeatureInfo_R, object, name)
} else if (name != "nrow") {
ret <- .Call(XGDMatrixGetInfo_R, object, name)
} else if (name %in% allowed_float_fields) {
ret <- .Call(XGDMatrixGetFloatInfo_R, object, name)
if (length(ret) > nrow(object)) {
ret <- matrix(ret, nrow = nrow(object), byrow = TRUE)
}
} else if (name %in% allowed_int_fields) {
if (name == "group") {
name <- "group_ptr"
}
ret <- .Call(XGDMatrixGetUIntInfo_R, object, name)
if (length(ret) > nrow(object)) {
ret <- matrix(ret, nrow = nrow(object), byrow = TRUE)
}
} else {
ret <- nrow(object)
}
if (length(ret) == 0) return(NULL)
return(ret)
Expand All @@ -270,13 +334,15 @@ getinfo.xgb.DMatrix <- function(object, name, ...) {
#' @param ... other parameters
#'
#' @details
#' The \code{name} field can be one of the following:
#'
#' \itemize{
#' \item \code{label}: label XGBoost learn from ;
#' \item \code{weight}: to do a weight rescale ;
#' \item \code{base_margin}: base margin is the base prediction XGBoost will boost from ;
#' \item \code{group}: number of rows in each group (to use with \code{rank:pairwise} objective).
#' See the documentation for \link{xgb.DMatrix} for possible fields that can be set
#' (which correspond to arguments in that function).
#'
#' Note that the following fields are allowed in the construction of an \code{xgb.DMatrix}
#' but \bold{aren't} allowed here:\itemize{
#' \item data
#' \item missing
#' \item silent
#' \item nthread
#' }
#'
#' @examples
Expand Down Expand Up @@ -328,6 +394,12 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
.Call(XGDMatrixSetInfo_R, object, name, as.integer(info))
return(TRUE)
}
if (name == "qid") {
if (NROW(info) != nrow(object))
stop("The length of qid assignments must equal to the number of rows in the input data")
.Call(XGDMatrixSetInfo_R, object, name, as.integer(info))
return(TRUE)
}
if (name == "feature_weights") {
if (length(info) != ncol(object)) {
stop("The number of feature weights must equal to the number of columns in the input data")
Expand Down
18 changes: 12 additions & 6 deletions R-package/man/getinfo.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 8 additions & 6 deletions R-package/man/setinfo.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

35 changes: 30 additions & 5 deletions R-package/man/xgb.DMatrix.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions R-package/src/init.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGDMatrixCreateFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP);
extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP, SEXP);
extern SEXP XGDMatrixGetInfo_R(SEXP, SEXP);
extern SEXP XGDMatrixGetFloatInfo_R(SEXP, SEXP);
extern SEXP XGDMatrixGetUIntInfo_R(SEXP, SEXP);
extern SEXP XGDMatrixGetStrFeatureInfo_R(SEXP, SEXP);
extern SEXP XGDMatrixNumCol_R(SEXP);
extern SEXP XGDMatrixNumRow_R(SEXP);
Expand Down Expand Up @@ -76,7 +77,8 @@ static const R_CallMethodDef CallEntries[] = {
{"XGDMatrixCreateFromCSR_R", (DL_FUNC) &XGDMatrixCreateFromCSR_R, 6},
{"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2},
{"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 3},
{"XGDMatrixGetInfo_R", (DL_FUNC) &XGDMatrixGetInfo_R, 2},
{"XGDMatrixGetFloatInfo_R", (DL_FUNC) &XGDMatrixGetFloatInfo_R, 2},
{"XGDMatrixGetUIntInfo_R", (DL_FUNC) &XGDMatrixGetUIntInfo_R, 2},
{"XGDMatrixGetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixGetStrFeatureInfo_R, 2},
{"XGDMatrixNumCol_R", (DL_FUNC) &XGDMatrixNumCol_R, 1},
{"XGDMatrixNumRow_R", (DL_FUNC) &XGDMatrixNumRow_R, 1},
Expand Down
Loading
Loading