-
Notifications
You must be signed in to change notification settings - Fork 3.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[R package] Accept data frames as inputs #4207
Changes from all commits
e398823
7bd4ed6
f90304a
a911c37
d6671df
08d8a34
9fb2ffa
cb55287
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -35,6 +35,10 @@ Dataset <- R6::R6Class( | |||||||||||||||||||||||||||||||||
free_raw_data = TRUE, | ||||||||||||||||||||||||||||||||||
used_indices = NULL, | ||||||||||||||||||||||||||||||||||
info = list(), | ||||||||||||||||||||||||||||||||||
label = NULL, | ||||||||||||||||||||||||||||||||||
weight = NULL, | ||||||||||||||||||||||||||||||||||
init_score = NULL, | ||||||||||||||||||||||||||||||||||
group = NULL, | ||||||||||||||||||||||||||||||||||
Comment on lines
+38
to
+41
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this change and the corresponding changes below necessary to support the feature "Accept data frames as inputs"? Or do you just have a personal preference for explicit keyword arguments instead of the If it's not strictly required to support this feature, please revert this change, make this PR work with the existing interface of That will reduce the size of the diff here, which should allow us to provide a higher-quality review and bring this PR to a resolution more quickly. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's necessary, as otherwise it wouldn't be possible to pass them as column names. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Ok, please revert any changes related to being able to pass these characteristics as column names. I'm happy to have a discussion about that in a separate feature request issue, but I don't think that change is strictly required to accomplish the behavior "accept data frames as inputs". We have a strong preference in this project for incremental progress through as-small-as-possible pull requests focused on a single thing. That enables the team of maintainers here (mostly volunteers) to provide higher-quality reviews and prevents the situation where pull requests touching a lot of code drag on for a long time and block or complicate other development. |
||||||||||||||||||||||||||||||||||
...) { | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# validate inputs early to avoid unnecessary computation | ||||||||||||||||||||||||||||||||||
|
@@ -45,28 +49,92 @@ Dataset <- R6::R6Class( | |||||||||||||||||||||||||||||||||
stop("lgb.Dataset: If provided, predictor must be a ", sQuote("lgb.Predictor")) | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Create known attributes list | ||||||||||||||||||||||||||||||||||
if (!is.null(label)) info[["label"]] <- label | ||||||||||||||||||||||||||||||||||
if (!is.null(weight)) info[["weight"]] <- weight | ||||||||||||||||||||||||||||||||||
if (!is.null(init_score)) info[["init_score"]] <- init_score | ||||||||||||||||||||||||||||||||||
if (!is.null(group)) info[["group"]] <- group | ||||||||||||||||||||||||||||||||||
Comment on lines
+53
to
+56
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Following the style in the rest of the R package, please use |
||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Check for additional parameters | ||||||||||||||||||||||||||||||||||
additional_params <- list(...) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Create known attributes list | ||||||||||||||||||||||||||||||||||
INFO_KEYS <- c("label", "weight", "init_score", "group") | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Check if attribute key is in the known attribute list | ||||||||||||||||||||||||||||||||||
for (key in names(additional_params)) { | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Key existing | ||||||||||||||||||||||||||||||||||
if (key %in% INFO_KEYS) { | ||||||||||||||||||||||||||||||||||
# Store as param | ||||||||||||||||||||||||||||||||||
params[[key]] <- additional_params[[key]] | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Store as info | ||||||||||||||||||||||||||||||||||
info[[key]] <- additional_params[[key]] | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||||||
# If it's a data.frame, will keep track of the categorical encodings | ||||||||||||||||||||||||||||||||||
if (inherits(data, "data.frame")) { | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
if (!nrow(data) || !ncol(data)) | ||||||||||||||||||||||||||||||||||
stop("'data' is empty.") | ||||||||||||||||||||||||||||||||||
Comment on lines
+72
to
+73
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
if (is.null(reference)) { | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Factors are taken directly in data frames, so should not be supplied | ||||||||||||||||||||||||||||||||||
if (!is.null(categorical_feature)) | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't agree with this decision. If my data are already in integer or numeric format, this would mean that to tell I'm supportive of the "automatically treat factors as categorical features" proposal, but I'd prefer to have behavior like "if the input is a |
||||||||||||||||||||||||||||||||||
stop("Cannot pass 'categorical_feature' for data.frame. Categorical features should be factor columns.") | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Column names will also be taken directly | ||||||||||||||||||||||||||||||||||
if (!is.null(colnames)) | ||||||||||||||||||||||||||||||||||
stop("Cannot pass 'colnames' for data.frame. Column names will be taken from it directly.") | ||||||||||||||||||||||||||||||||||
Comment on lines
+82
to
+83
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||
colnames <- names(data) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# First check if the column types are all numeric or categorical | ||||||||||||||||||||||||||||||||||
supported_coltypes <- c("numeric", "integer", "logical", "character", "factor", "POSIXct", "Date") | ||||||||||||||||||||||||||||||||||
coltype_is_unsupported <- sapply(data, function(x) !inherits(x, supported_coltypes)) | ||||||||||||||||||||||||||||||||||
if (any(coltype_is_unsupported)) | ||||||||||||||||||||||||||||||||||
stop("'data' contains unsupported column types.") | ||||||||||||||||||||||||||||||||||
Comment on lines
+89
to
+90
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Ordered factors are not supported, so it will warn if there's any | ||||||||||||||||||||||||||||||||||
has_ordered_factor <- sapply(data, is.ordered) | ||||||||||||||||||||||||||||||||||
if (any(has_ordered_factor)) | ||||||||||||||||||||||||||||||||||
warning("Warning: ordered factors are not supported, will interpret them as unordered.") | ||||||||||||||||||||||||||||||||||
Comment on lines
+94
to
+95
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Store as param | ||||||||||||||||||||||||||||||||||
params[[key]] <- additional_params[[key]] | ||||||||||||||||||||||||||||||||||
# For faster conversions between types | ||||||||||||||||||||||||||||||||||
data <- data.table::as.data.table(data) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Now see if there are any categorical columns that will be encoded | ||||||||||||||||||||||||||||||||||
cols_char <- sapply(data, is.character) | ||||||||||||||||||||||||||||||||||
if (any(cols_char)) { | ||||||||||||||||||||||||||||||||||
names_cols_char <- names(data)[cols_char] | ||||||||||||||||||||||||||||||||||
data[, (names_cols_char) := lapply(.SD, factor), .SDcols = names_cols_char] | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
cols_factor <- sapply(data, is.factor) | ||||||||||||||||||||||||||||||||||
if (any(cols_factor)) { | ||||||||||||||||||||||||||||||||||
categorical_feature <- names(data)[cols_factor] | ||||||||||||||||||||||||||||||||||
data[, (categorical_feature) := lapply(.SD, factor), .SDcols = categorical_feature] | ||||||||||||||||||||||||||||||||||
private$factor_levels <- lapply(data[, categorical_feature, with = FALSE], levels) | ||||||||||||||||||||||||||||||||||
encode_categ <- function(x) { | ||||||||||||||||||||||||||||||||||
x <- as.numeric(x) | ||||||||||||||||||||||||||||||||||
x[is.na(x)] <- 0.0 | ||||||||||||||||||||||||||||||||||
x <- x - 1.0 | ||||||||||||||||||||||||||||||||||
return(x) | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
data[ | ||||||||||||||||||||||||||||||||||
, (categorical_feature) := lapply(.SD, encode_categ) | ||||||||||||||||||||||||||||||||||
, .SDcols = categorical_feature | ||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Finally, convert all columns to numeric and turn it into a matrix | ||||||||||||||||||||||||||||||||||
data <- as.matrix(data[, lapply(.SD, as.numeric)]) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# When passing a reference, will take the columns and categorical encodings from it instead | ||||||||||||||||||||||||||||||||||
data <- self$process_data_frame_columns( | ||||||||||||||||||||||||||||||||||
data, | ||||||||||||||||||||||||||||||||||
reference$get_colnames(), | ||||||||||||||||||||||||||||||||||
reference$get_categorical_feature(), | ||||||||||||||||||||||||||||||||||
reference$get_factor_levels() | ||||||||||||||||||||||||||||||||||
Comment on lines
+130
to
+133
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
private$is_from_data_frame <- TRUE | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Check for matrix format | ||||||||||||||||||||||||||||||||||
|
@@ -419,6 +487,49 @@ Dataset <- R6::R6Class( | |||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Get levels used to encode factor variables in data frames | ||||||||||||||||||||||||||||||||||
get_factor_levels = function() { | ||||||||||||||||||||||||||||||||||
return(private$factor_levels) | ||||||||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
get_categorical_feature = function() { | ||||||||||||||||||||||||||||||||||
return(private$categorical_feature) | ||||||||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
get_is_from_data_frame = function() { | ||||||||||||||||||||||||||||||||||
return(private$is_from_data_frame) | ||||||||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
process_data_frame_columns = function(data, colnames, categorical_feature, factor_levels) { | ||||||||||||||||||||||||||||||||||
data <- as.data.table(data) | ||||||||||||||||||||||||||||||||||
if (!is.null(colnames)) | ||||||||||||||||||||||||||||||||||
data <- data[, colnames, with = FALSE] | ||||||||||||||||||||||||||||||||||
Comment on lines
+505
to
+506
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||
if (!is.null(factor_levels)) { | ||||||||||||||||||||||||||||||||||
data[ | ||||||||||||||||||||||||||||||||||
, (categorical_feature) | ||||||||||||||||||||||||||||||||||
:= mapply( | ||||||||||||||||||||||||||||||||||
function(col, levs) factor(col, levs), | ||||||||||||||||||||||||||||||||||
.SD, factor_levels, SIMPLIFY = FALSE | ||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||
, .SDcols = categorical_feature | ||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||
encode_categ <- function(x) { | ||||||||||||||||||||||||||||||||||
x <- as.numeric(x) | ||||||||||||||||||||||||||||||||||
x[is.na(x)] <- 0.0 | ||||||||||||||||||||||||||||||||||
x <- x - 1.0 | ||||||||||||||||||||||||||||||||||
return(x) | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
data[ | ||||||||||||||||||||||||||||||||||
, (categorical_feature) := lapply(.SD, encode_categ) | ||||||||||||||||||||||||||||||||||
, .SDcols = categorical_feature | ||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||||||
if (any(sapply(data, function(x) is.character(x) || is.factor(x)))) | ||||||||||||||||||||||||||||||||||
stop("'data' contains categorical columns, but 'reference' did not have encodings for them.") | ||||||||||||||||||||||||||||||||||
Comment on lines
+527
to
+528
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
return(as.matrix(data[, lapply(.SD, as.numeric)])) | ||||||||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Get information | ||||||||||||||||||||||||||||||||||
getinfo = function(name) { | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
|
@@ -665,6 +776,8 @@ Dataset <- R6::R6Class( | |||||||||||||||||||||||||||||||||
reference = NULL, | ||||||||||||||||||||||||||||||||||
colnames = NULL, | ||||||||||||||||||||||||||||||||||
categorical_feature = NULL, | ||||||||||||||||||||||||||||||||||
factor_levels = NULL, | ||||||||||||||||||||||||||||||||||
is_from_data_frame = FALSE, | ||||||||||||||||||||||||||||||||||
predictor = NULL, | ||||||||||||||||||||||||||||||||||
free_raw_data = TRUE, | ||||||||||||||||||||||||||||||||||
used_indices = NULL, | ||||||||||||||||||||||||||||||||||
|
@@ -712,6 +825,48 @@ Dataset <- R6::R6Class( | |||||||||||||||||||||||||||||||||
self$finalize() | ||||||||||||||||||||||||||||||||||
return(invisible(self)) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
substitute_from_df_cols = function(data, label, weight, init_score, | ||||||||||||||||||||||||||||||||||
label_name, weight_name, init_score_name, | ||||||||||||||||||||||||||||||||||
env_where_to_substitute) { | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
check_is_df_col <- function(var, var_name, data) { | ||||||||||||||||||||||||||||||||||
var_name <- head(as.character(var_name), 1L) | ||||||||||||||||||||||||||||||||||
if (inherits(data, "data.frame") && NROW(var_name) && var_name != "NULL") { | ||||||||||||||||||||||||||||||||||
if (var_name %in% names(data)) { | ||||||||||||||||||||||||||||||||||
var <- data[[var_name]] | ||||||||||||||||||||||||||||||||||
data <- as.data.table(data)[, setdiff(names(data), var_name), with = FALSE] | ||||||||||||||||||||||||||||||||||
} else if (is.character(var) && NROW(var) == 1L && var %in% names(data)) { | ||||||||||||||||||||||||||||||||||
var <- data[[var]] | ||||||||||||||||||||||||||||||||||
data <- as.data.table(data)[, setdiff(names(data), var), with = FALSE] | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
return(list(var, data)) | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
label_name <- head(as.character(label_name), 1L) | ||||||||||||||||||||||||||||||||||
weight_name <- head(as.character(weight_name), 1L) | ||||||||||||||||||||||||||||||||||
init_score_name <- head(as.character(init_score_name), 1L) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
temp <- check_is_df_col(label, label_name, data) | ||||||||||||||||||||||||||||||||||
label <- temp[[1L]] | ||||||||||||||||||||||||||||||||||
data <- temp[[2L]] | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
temp <- check_is_df_col(weight, weight_name, data) | ||||||||||||||||||||||||||||||||||
weight <- temp[[1L]] | ||||||||||||||||||||||||||||||||||
data <- temp[[2L]] | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
temp <- check_is_df_col(init_score, init_score_name, data) | ||||||||||||||||||||||||||||||||||
init_score <- temp[[1L]] | ||||||||||||||||||||||||||||||||||
data <- temp[[2L]] | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
env_where_to_substitute$data <- data | ||||||||||||||||||||||||||||||||||
env_where_to_substitute$label <- label | ||||||||||||||||||||||||||||||||||
env_where_to_substitute$weight <- weight | ||||||||||||||||||||||||||||||||||
env_where_to_substitute$init_score <- init_score | ||||||||||||||||||||||||||||||||||
return(NULL) | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||
|
@@ -720,14 +875,22 @@ Dataset <- R6::R6Class( | |||||||||||||||||||||||||||||||||
#' @title Construct \code{lgb.Dataset} object | ||||||||||||||||||||||||||||||||||
#' @description Construct \code{lgb.Dataset} object from dense matrix, sparse matrix | ||||||||||||||||||||||||||||||||||
#' or local file (that was created previously by saving an \code{lgb.Dataset}). | ||||||||||||||||||||||||||||||||||
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename | ||||||||||||||||||||||||||||||||||
#' @param data a \code{matrix} object, a \code{data.frame} object, a \code{dgCMatrix} object, | ||||||||||||||||||||||||||||||||||
#' or a character representing a filename. | ||||||||||||||||||||||||||||||||||
#' | ||||||||||||||||||||||||||||||||||
#' If passing a `data.frame`, will assume that columns are numeric if they are of types | ||||||||||||||||||||||||||||||||||
#' numeric, integer, logical, Date, or POSIXct; and will assume they are categorical if | ||||||||||||||||||||||||||||||||||
#' they are of types factor or character (ordered factors are taken as unordered). | ||||||||||||||||||||||||||||||||||
#' Other column types are not supported. | ||||||||||||||||||||||||||||||||||
#' @param params a list of parameters. See | ||||||||||||||||||||||||||||||||||
#' \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#dataset-parameters}{ | ||||||||||||||||||||||||||||||||||
#' The "Dataset Parameters" section of the documentation} for a list of parameters | ||||||||||||||||||||||||||||||||||
#' and valid values. | ||||||||||||||||||||||||||||||||||
#' @param reference reference dataset. When LightGBM creates a Dataset, it does some preprocessing like binning | ||||||||||||||||||||||||||||||||||
#' continuous features into histograms. If you want to apply the same bin boundaries from an existing | ||||||||||||||||||||||||||||||||||
#' dataset to new \code{data}, pass that existing Dataset to this argument. | ||||||||||||||||||||||||||||||||||
#' dataset to new \code{data}, pass that existing Dataset to this argument. If the reference passed | ||||||||||||||||||||||||||||||||||
#' was constructed from a `data.frame`, will also take its column names, column order, column types, | ||||||||||||||||||||||||||||||||||
#' and levels of factor columns. | ||||||||||||||||||||||||||||||||||
#' @param colnames names of columns | ||||||||||||||||||||||||||||||||||
#' @param categorical_feature categorical features. This can either be a character vector of feature | ||||||||||||||||||||||||||||||||||
#' names or an integer vector with the indices of the features (e.g. | ||||||||||||||||||||||||||||||||||
|
@@ -738,6 +901,20 @@ Dataset <- R6::R6Class( | |||||||||||||||||||||||||||||||||
#' cannot be changed after it has been constructed. If you'd prefer to be able to | ||||||||||||||||||||||||||||||||||
#' change the Dataset object after construction, set \code{free_raw_data = FALSE}. | ||||||||||||||||||||||||||||||||||
#' @param info a list of information of the \code{lgb.Dataset} object | ||||||||||||||||||||||||||||||||||
#' @param label Label of the data (target variable). Should be a numeric vector. | ||||||||||||||||||||||||||||||||||
#' If `data` is a `data.frame`, can also specify it as a column name, passed either as a character | ||||||||||||||||||||||||||||||||||
#' variable or as a name. | ||||||||||||||||||||||||||||||||||
#' @param weight Weight for each instance/observation. Should be a numeric vector. | ||||||||||||||||||||||||||||||||||
#' If `data` is a `data.frame`, can also specify it as a column name, passed either as a character | ||||||||||||||||||||||||||||||||||
#' variable or as a name. | ||||||||||||||||||||||||||||||||||
#' @param init_score Init score for Dataset. Should be a numeric vector. | ||||||||||||||||||||||||||||||||||
#' If `data` is a `data.frame`, can also specify it as a column name, passed either as a character | ||||||||||||||||||||||||||||||||||
#' variable or as a name. | ||||||||||||||||||||||||||||||||||
#' @param group Group/query data, as integer vector. Only used in the learning-to-rank task. | ||||||||||||||||||||||||||||||||||
#' sum(group) = nrow(data). | ||||||||||||||||||||||||||||||||||
#' For example, if you have a 100-document dataset with `group = c(10, 20, 40, 10, 10, 10)`, | ||||||||||||||||||||||||||||||||||
#' that means that you have 6 groups, where the first 10 records are in the first group, | ||||||||||||||||||||||||||||||||||
#' records 11-30 are in the second group, records 31-70 are in the third group, etc. | ||||||||||||||||||||||||||||||||||
#' @param ... other information to pass to \code{info} or parameters pass to \code{params} | ||||||||||||||||||||||||||||||||||
#' | ||||||||||||||||||||||||||||||||||
#' @return constructed dataset | ||||||||||||||||||||||||||||||||||
|
@@ -760,7 +937,19 @@ lgb.Dataset <- function(data, | |||||||||||||||||||||||||||||||||
categorical_feature = NULL, | ||||||||||||||||||||||||||||||||||
free_raw_data = TRUE, | ||||||||||||||||||||||||||||||||||
info = list(), | ||||||||||||||||||||||||||||||||||
label = NULL, | ||||||||||||||||||||||||||||||||||
weight = NULL, | ||||||||||||||||||||||||||||||||||
init_score = NULL, | ||||||||||||||||||||||||||||||||||
group = NULL, | ||||||||||||||||||||||||||||||||||
...) { | ||||||||||||||||||||||||||||||||||
# Take variables from column names if appropriate | ||||||||||||||||||||||||||||||||||
if (is.data.frame(data)) { | ||||||||||||||||||||||||||||||||||
Dataset$private_methods$substitute_from_df_cols( | ||||||||||||||||||||||||||||||||||
data, label, weight, init_score, | ||||||||||||||||||||||||||||||||||
substitute(label), substitute(weight), substitute(init_score), | ||||||||||||||||||||||||||||||||||
environment() | ||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Create new dataset | ||||||||||||||||||||||||||||||||||
return( | ||||||||||||||||||||||||||||||||||
|
@@ -774,6 +963,10 @@ lgb.Dataset <- function(data, | |||||||||||||||||||||||||||||||||
, free_raw_data = free_raw_data | ||||||||||||||||||||||||||||||||||
, used_indices = NULL | ||||||||||||||||||||||||||||||||||
, info = info | ||||||||||||||||||||||||||||||||||
, label = label | ||||||||||||||||||||||||||||||||||
, weight = weight | ||||||||||||||||||||||||||||||||||
, init_score = init_score | ||||||||||||||||||||||||||||||||||
, group = group | ||||||||||||||||||||||||||||||||||
, ... | ||||||||||||||||||||||||||||||||||
)) | ||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -23,19 +23,26 @@ CVBooster <- R6::R6Class( | |||||||||||||||||||||||||||||||
#' @description Cross validation logic used by LightGBM | ||||||||||||||||||||||||||||||||
#' @inheritParams lgb_shared_params | ||||||||||||||||||||||||||||||||
#' @param nfold the original dataset is randomly partitioned into \code{nfold} equal size subsamples. | ||||||||||||||||||||||||||||||||
#' @param label Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}} | ||||||||||||||||||||||||||||||||
#' @param weight vector of response values. If not NULL, will set to dataset | ||||||||||||||||||||||||||||||||
#' @param label Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}}. | ||||||||||||||||||||||||||||||||
#' If \code{data} is a `data.frame`, can also specify it as a column name, passed either as a character | ||||||||||||||||||||||||||||||||
#' variable or as a name. | ||||||||||||||||||||||||||||||||
#' @param weight vector of response values. If not NULL, will set to dataset. | ||||||||||||||||||||||||||||||||
#' If \code{data} is a `data.frame`, can also specify it as a column name, passed either as a character | ||||||||||||||||||||||||||||||||
#' variable or as a name. | ||||||||||||||||||||||||||||||||
#' @param record Boolean, TRUE will record iteration message to \code{booster$record_evals} | ||||||||||||||||||||||||||||||||
#' @param showsd \code{boolean}, whether to show standard deviation of cross validation | ||||||||||||||||||||||||||||||||
#' @param stratified a \code{boolean} indicating whether sampling of folds should be stratified | ||||||||||||||||||||||||||||||||
#' by the values of outcome labels. | ||||||||||||||||||||||||||||||||
#' @param folds \code{list} provides a possibility to use a list of pre-defined CV folds | ||||||||||||||||||||||||||||||||
#' (each element must be a vector of test fold's indices). When folds are supplied, | ||||||||||||||||||||||||||||||||
#' the \code{nfold} and \code{stratified} parameters are ignored. | ||||||||||||||||||||||||||||||||
#' @param colnames feature names, if not null, will use this to overwrite the names in dataset | ||||||||||||||||||||||||||||||||
#' @param colnames feature names, if not null, will use this to overwrite the names in dataset. | ||||||||||||||||||||||||||||||||
#' Not supported for `data.frame` inputs. | ||||||||||||||||||||||||||||||||
#' @param categorical_feature categorical features. This can either be a character vector of feature | ||||||||||||||||||||||||||||||||
#' names or an integer vector with the indices of the features (e.g. | ||||||||||||||||||||||||||||||||
#' \code{c(1L, 10L)} to say "the first and tenth columns"). | ||||||||||||||||||||||||||||||||
#' Not supported for `data.frame` inputs as for them it will determine this automatically | ||||||||||||||||||||||||||||||||
#' according to the column type (see the documentation of \link{lgb.Dataset} for details). | ||||||||||||||||||||||||||||||||
#' @param callbacks List of callback functions that are applied at each iteration. | ||||||||||||||||||||||||||||||||
#' @param reset_data Boolean, setting it to TRUE (not the default value) will transform the booster model | ||||||||||||||||||||||||||||||||
#' into a predictor model which frees up memory and the original datasets | ||||||||||||||||||||||||||||||||
|
@@ -99,6 +106,13 @@ lgb.cv <- function(params = list() | |||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
# If 'data' is not an lgb.Dataset, try to construct one using 'label' | ||||||||||||||||||||||||||||||||
if (!lgb.is.Dataset(x = data)) { | ||||||||||||||||||||||||||||||||
if (inherits(data, "data.frame")) { | ||||||||||||||||||||||||||||||||
Dataset$private_methods$substitute_from_df_cols( | ||||||||||||||||||||||||||||||||
data, label, weight, NULL, | ||||||||||||||||||||||||||||||||
substitute(label), substitute(weight), NULL, | ||||||||||||||||||||||||||||||||
environment() | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
Comment on lines
+110
to
+114
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
if (is.null(label)) { | ||||||||||||||||||||||||||||||||
stop("'label' must be provided for lgb.cv if 'data' is not an 'lgb.Dataset'") | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please match the style in used in the rest of the R package, which is comma-first and closing parentheses vertically aligned with the beginning of the line.
And please use keyword arguments for any function calls to internal methods with more than one argument. This makes the code a bit easier to read and prevents bugs caused by mistakes in the order that arguments are given.