diff --git a/R-package/DESCRIPTION b/R-package/DESCRIPTION index b03bc56d5298..ded208356c61 100755 --- a/R-package/DESCRIPTION +++ b/R-package/DESCRIPTION @@ -58,4 +58,4 @@ Imports: utils SystemRequirements: C++11 -RoxygenNote: 7.1.1 +RoxygenNote: 7.1.2 diff --git a/R-package/NAMESPACE b/R-package/NAMESPACE index b356ba927177..8df060d28605 100644 --- a/R-package/NAMESPACE +++ b/R-package/NAMESPACE @@ -3,10 +3,13 @@ S3method("dimnames<-",lgb.Dataset) S3method(dim,lgb.Dataset) S3method(dimnames,lgb.Dataset) +S3method(get_field,lgb.Dataset) S3method(getinfo,lgb.Dataset) S3method(predict,lgb.Booster) +S3method(set_field,lgb.Dataset) S3method(setinfo,lgb.Dataset) S3method(slice,lgb.Dataset) +export(get_field) export(getinfo) export(lgb.Dataset) export(lgb.Dataset.construct) @@ -30,6 +33,7 @@ export(lgb.unloader) export(lightgbm) export(readRDS.lgb.Booster) export(saveRDS.lgb.Booster) +export(set_field) export(setinfo) export(slice) import(methods) diff --git a/R-package/R/lgb.Dataset.R b/R-package/R/lgb.Dataset.R index 429eb1f91275..91f729188a1f 100644 --- a/R-package/R/lgb.Dataset.R +++ b/R-package/R/lgb.Dataset.R @@ -335,14 +335,17 @@ Dataset <- R6::R6Class( for (i in seq_along(private$info)) { p <- private$info[i] - self$setinfo(name = names(p), info = p[[1L]]) + self$set_field( + field_name = names(p) + , data = p[[1L]] + ) } } # Get label information existence - if (is.null(self$getinfo(name = "label"))) { + if (is.null(self$get_field(field_name = "label"))) { stop("lgb.Dataset.construct: label should be set") } @@ -452,19 +455,33 @@ Dataset <- R6::R6Class( }, - # Get information getinfo = function(name) { + warning(paste0( + "Dataset$getinfo() is deprecated and will be removed in a future release. " + , "Use Dataset$get_field() instead." + )) + return( + self$get_field( + field_name = name + ) + ) + }, + + get_field = function(field_name) { # Check if attribute key is in the known attribute list - if (!is.character(name) || length(name) != 1L || !name %in% .INFO_KEYS()) { - stop("getinfo: name must one of the following: ", paste0(sQuote(.INFO_KEYS()), collapse = ", ")) + if (!is.character(field_name) || length(field_name) != 1L || !field_name %in% .INFO_KEYS()) { + stop( + "Dataset$get_field(): field_name must one of the following: " + , paste0(sQuote(.INFO_KEYS()), collapse = ", ") + ) } # Check for info name and handle - if (is.null(private$info[[name]])) { + if (is.null(private$info[[field_name]])) { if (lgb.is.null.handle(x = private$handle)) { - stop("Cannot perform getinfo before constructing Dataset.") + stop("Cannot perform Dataset$get_field() before constructing Dataset.") } # Get field size of info @@ -472,7 +489,7 @@ Dataset <- R6::R6Class( .Call( LGBM_DatasetGetFieldSize_R , private$handle - , name + , field_name , info_len ) @@ -481,7 +498,7 @@ Dataset <- R6::R6Class( # Get back fields ret <- NULL - ret <- if (name == "group") { + ret <- if (field_name == "group") { integer(info_len) # Integer } else { numeric(info_len) # Numeric @@ -490,47 +507,62 @@ Dataset <- R6::R6Class( .Call( LGBM_DatasetGetField_R , private$handle - , name + , field_name , ret ) - private$info[[name]] <- ret + private$info[[field_name]] <- ret } } - return(private$info[[name]]) + return(private$info[[field_name]]) }, - # Set information setinfo = function(name, info) { + warning(paste0( + "Dataset$setinfo() is deprecated and will be removed in a future release. " + , "Use Dataset$set_field() instead." + )) + return( + self$set_field( + field_name = name + , data = info + ) + ) + }, + + set_field = function(field_name, data) { # Check if attribute key is in the known attribute list - if (!is.character(name) || length(name) != 1L || !name %in% .INFO_KEYS()) { - stop("setinfo: name must one of the following: ", paste0(sQuote(.INFO_KEYS()), collapse = ", ")) + if (!is.character(field_name) || length(field_name) != 1L || !field_name %in% .INFO_KEYS()) { + stop( + "Dataset$set_field(): field_name must one of the following: " + , paste0(sQuote(.INFO_KEYS()), collapse = ", ") + ) } # Check for type of information - info <- if (name == "group") { - as.integer(info) # Integer + data <- if (field_name == "group") { + as.integer(data) # Integer } else { - as.numeric(info) # Numeric + as.numeric(data) # Numeric } # Store information privately - private$info[[name]] <- info + private$info[[field_name]] <- data - if (!lgb.is.null.handle(x = private$handle) && !is.null(info)) { + if (!lgb.is.null.handle(x = private$handle) && !is.null(data)) { - if (length(info) > 0L) { + if (length(data) > 0L) { .Call( LGBM_DatasetSetField_R , private$handle - , name - , info - , length(info) + , field_name + , data + , length(data) ) private$version <- private$version + 1L @@ -554,7 +586,7 @@ Dataset <- R6::R6Class( , paste(names(additional_keyword_args), collapse = ", ") , ". These are ignored and should be removed. " , "To change the parameters of a Dataset produced by Dataset$slice(), use Dataset$set_params(). " - , "To modify attributes like 'init_score', use Dataset$setinfo(). " + , "To modify attributes like 'init_score', use Dataset$set_field(). " , "In future releases of lightgbm, this warning will become an error." )) } @@ -1110,7 +1142,7 @@ dimnames.lgb.Dataset <- function(x) { #' #' dsub <- lightgbm::slice(dtrain, seq_len(42L)) #' lgb.Dataset.construct(dsub) -#' labels <- lightgbm::getinfo(dsub, "label") +#' labels <- lightgbm::get_field(dsub, "label") #' } #' @export slice <- function(dataset, ...) { @@ -1173,6 +1205,8 @@ getinfo <- function(dataset, ...) { #' @export getinfo.lgb.Dataset <- function(dataset, name, ...) { + warning("Calling getinfo() on a lgb.Dataset is deprecated. Use get_field() instead.") + additional_args <- list(...) if (length(additional_args) > 0L) { warning(paste0( @@ -1187,7 +1221,7 @@ getinfo.lgb.Dataset <- function(dataset, name, ...) { stop("getinfo.lgb.Dataset: input dataset should be an lgb.Dataset object") } - return(dataset$getinfo(name = name)) + return(dataset$get_field(field_name = name)) } @@ -1236,6 +1270,8 @@ setinfo <- function(dataset, ...) { #' @export setinfo.lgb.Dataset <- function(dataset, name, info, ...) { + warning("Calling setinfo() on a lgb.Dataset is deprecated. Use set_field() instead.") + additional_args <- list(...) if (length(additional_args) > 0L) { warning(paste0( @@ -1250,7 +1286,102 @@ setinfo.lgb.Dataset <- function(dataset, name, info, ...) { stop("setinfo.lgb.Dataset: input dataset should be an lgb.Dataset object") } - return(invisible(dataset$setinfo(name = name, info = info))) + return(invisible(dataset$set_field(field_name = name, data = info))) +} + +#' @name get_field +#' @title Get one attribute of a \code{lgb.Dataset} +#' @description Get one attribute of a \code{lgb.Dataset} +#' @param dataset Object of class \code{lgb.Dataset} +#' @param field_name String with the name of the attribute to get. One of the following. +#' \itemize{ +#' \item \code{label}: label lightgbm learns from ; +#' \item \code{weight}: to do a weight rescale ; +#' \item{\code{group}: used for learning-to-rank tasks. An integer vector describing how to +#' group rows together as ordered results from the same set of candidate results to be ranked. +#' For example, if you have a 100-document dataset with \code{group = c(10, 20, 40, 10, 10, 10)}, +#' that means that you have 6 groups, where the first 10 records are in the first group, +#' records 11-30 are in the second group, etc.} +#' \item \code{init_score}: initial score is the base prediction lightgbm will boost from. +#' } +#' @return requested attribute +#' +#' @examples +#' \donttest{ +#' data(agaricus.train, package = "lightgbm") +#' train <- agaricus.train +#' dtrain <- lgb.Dataset(train$data, label = train$label) +#' lgb.Dataset.construct(dtrain) +#' +#' labels <- lightgbm::get_field(dtrain, "label") +#' lightgbm::set_field(dtrain, "label", 1 - labels) +#' +#' labels2 <- lightgbm::get_field(dtrain, "label") +#' stopifnot(all(labels2 == 1 - labels)) +#' } +#' @export +get_field <- function(dataset, field_name) { + UseMethod("get_field") +} + +#' @rdname get_field +#' @export +get_field.lgb.Dataset <- function(dataset, field_name) { + + # Check if dataset is not a dataset + if (!lgb.is.Dataset(x = dataset)) { + stop("get_field.lgb.Dataset(): input dataset should be an lgb.Dataset object") + } + + return(dataset$get_field(field_name = field_name)) + +} + +#' @name set_field +#' @title Set one attribute of a \code{lgb.Dataset} object +#' @description Set one attribute of a \code{lgb.Dataset} +#' @param dataset Object of class \code{lgb.Dataset} +#' @param field_name String with the name of the attribute to set. One of the following. +#' \itemize{ +#' \item \code{label}: label lightgbm learns from ; +#' \item \code{weight}: to do a weight rescale ; +#' \item{\code{group}: used for learning-to-rank tasks. An integer vector describing how to +#' group rows together as ordered results from the same set of candidate results to be ranked. +#' For example, if you have a 100-document dataset with \code{group = c(10, 20, 40, 10, 10, 10)}, +#' that means that you have 6 groups, where the first 10 records are in the first group, +#' records 11-30 are in the second group, etc.} +#' \item \code{init_score}: initial score is the base prediction lightgbm will boost from. +#' } +#' @param data The data for the field. See examples. +#' @return The \code{lgb.Dataset} you passed in. +#' +#' @examples +#' \donttest{ +#' data(agaricus.train, package = "lightgbm") +#' train <- agaricus.train +#' dtrain <- lgb.Dataset(train$data, label = train$label) +#' lgb.Dataset.construct(dtrain) +#' +#' labels <- lightgbm::get_field(dtrain, "label") +#' lightgbm::set_field(dtrain, "label", 1 - labels) +#' +#' labels2 <- lightgbm::get_field(dtrain, "label") +#' stopifnot(all.equal(labels2, 1 - labels)) +#' } +#' @export +set_field <- function(dataset, field_name, data) { + UseMethod("set_field") +} + +#' @rdname set_field +#' @export +set_field.lgb.Dataset <- function(dataset, field_name, data) { + + if (!lgb.is.Dataset(x = dataset)) { + stop("set_field.lgb.Dataset: input dataset should be an lgb.Dataset object") + } + + return(invisible(dataset$set_field(field_name = field_name, data = data))) } #' @name lgb.Dataset.set.categorical diff --git a/R-package/R/lgb.cv.R b/R-package/R/lgb.cv.R index 81890a6a90c7..039411e4cf40 100644 --- a/R-package/R/lgb.cv.R +++ b/R-package/R/lgb.cv.R @@ -206,7 +206,7 @@ lgb.cv <- function(params = list() ) if (!is.null(weight)) { - data$setinfo(name = "weight", info = weight) + data$set_field(field_name = "weight", data = weight) } # Update parameters with parsed parameters @@ -245,8 +245,8 @@ lgb.cv <- function(params = list() nfold = nfold , nrows = nrow(data) , stratified = stratified - , label = getinfo(dataset = data, name = "label") - , group = getinfo(dataset = data, name = "group") + , label = get_field(dataset = data, field_name = "label") + , group = get_field(dataset = data, field_name = "group") , params = params ) @@ -320,8 +320,8 @@ lgb.cv <- function(params = list() if (folds_have_group) { test_indices <- folds[[k]]$fold test_group_indices <- folds[[k]]$group - test_groups <- getinfo(dataset = data, name = "group")[test_group_indices] - train_groups <- getinfo(dataset = data, name = "group")[-test_group_indices] + test_groups <- get_field(dataset = data, field_name = "group")[test_group_indices] + train_groups <- get_field(dataset = data, field_name = "group")[-test_group_indices] } else { test_indices <- folds[[k]] } @@ -330,28 +330,28 @@ lgb.cv <- function(params = list() # set up test set indexDT <- data.table::data.table( indices = test_indices - , weight = getinfo(dataset = data, name = "weight")[test_indices] - , init_score = getinfo(dataset = data, name = "init_score")[test_indices] + , weight = get_field(dataset = data, field_name = "weight")[test_indices] + , init_score = get_field(dataset = data, field_name = "init_score")[test_indices] ) data.table::setorderv(x = indexDT, cols = "indices", order = 1L) dtest <- slice(data, indexDT$indices) - setinfo(dataset = dtest, name = "weight", info = indexDT$weight) - setinfo(dataset = dtest, name = "init_score", info = indexDT$init_score) + set_field(dataset = dtest, field_name = "weight", data = indexDT$weight) + set_field(dataset = dtest, field_name = "init_score", data = indexDT$init_score) # set up training set indexDT <- data.table::data.table( indices = train_indices - , weight = getinfo(dataset = data, name = "weight")[train_indices] - , init_score = getinfo(dataset = data, name = "init_score")[train_indices] + , weight = get_field(dataset = data, field_name = "weight")[train_indices] + , init_score = get_field(dataset = data, field_name = "init_score")[train_indices] ) data.table::setorderv(x = indexDT, cols = "indices", order = 1L) dtrain <- slice(data, indexDT$indices) - setinfo(dataset = dtrain, name = "weight", info = indexDT$weight) - setinfo(dataset = dtrain, name = "init_score", info = indexDT$init_score) + set_field(dataset = dtrain, field_name = "weight", data = indexDT$weight) + set_field(dataset = dtrain, field_name = "init_score", data = indexDT$init_score) if (folds_have_group) { - setinfo(dataset = dtest, name = "group", info = test_groups) - setinfo(dataset = dtrain, name = "group", info = train_groups) + set_field(dataset = dtest, field_name = "group", data = test_groups) + set_field(dataset = dtrain, field_name = "group", data = train_groups) } booster <- Booster$new(params = params, train_set = dtrain) diff --git a/R-package/R/lgb.interprete.R b/R-package/R/lgb.interprete.R index 940613d21225..70aac8760485 100644 --- a/R-package/R/lgb.interprete.R +++ b/R-package/R/lgb.interprete.R @@ -21,7 +21,11 @@ #' data(agaricus.train, package = "lightgbm") #' train <- agaricus.train #' dtrain <- lgb.Dataset(train$data, label = train$label) -#' setinfo(dtrain, "init_score", rep(Logit(mean(train$label)), length(train$label))) +#' set_field( +#' dataset = dtrain +#' , field_name = "init_score" +#' , data = rep(Logit(mean(train$label)), length(train$label)) +#' ) #' data(agaricus.test, package = "lightgbm") #' test <- agaricus.test #' diff --git a/R-package/R/lgb.plot.interpretation.R b/R-package/R/lgb.plot.interpretation.R index 1bc2fa0f9582..aa3cbae05625 100644 --- a/R-package/R/lgb.plot.interpretation.R +++ b/R-package/R/lgb.plot.interpretation.R @@ -25,7 +25,11 @@ #' agaricus.train$data #' , label = labels #' ) -#' setinfo(dtrain, "init_score", rep(Logit(mean(labels)), length(labels))) +#' set_field( +#' dataset = dtrain +#' , field_name = "init_score" +#' , data = rep(Logit(mean(labels)), length(labels)) +#' ) #' #' data(agaricus.test, package = "lightgbm") #' diff --git a/R-package/demo/basic_walkthrough.R b/R-package/demo/basic_walkthrough.R index 3dc672e11d73..b5bfed26d935 100644 --- a/R-package/demo/basic_walkthrough.R +++ b/R-package/demo/basic_walkthrough.R @@ -147,8 +147,8 @@ bst <- lgb.train( , valids = valids ) -# information can be extracted from lgb.Dataset using getinfo -label <- getinfo(dtest, "label") +# information can be extracted from lgb.Dataset using get_field() +label <- get_field(dtest, "label") pred <- predict(bst, test$data) err <- as.numeric(sum(as.integer(pred > 0.5) != label)) / length(label) print(paste("test-error=", err)) diff --git a/R-package/demo/boost_from_prediction.R b/R-package/demo/boost_from_prediction.R index 457561cd5f70..b6b3f1ceba7b 100644 --- a/R-package/demo/boost_from_prediction.R +++ b/R-package/demo/boost_from_prediction.R @@ -27,8 +27,8 @@ ptest <- predict(bst, agaricus.test$data, rawscore = TRUE) # set the init_score property of dtrain and dtest # base margin is the base prediction we will boost from -setinfo(dtrain, "init_score", ptrain) -setinfo(dtest, "init_score", ptest) +set_field(dtrain, "init_score", ptrain) +set_field(dtest, "init_score", ptest) print("This is result of boost from initial prediction") bst <- lgb.train( diff --git a/R-package/demo/cross_validation.R b/R-package/demo/cross_validation.R index f685b520822d..0324f83f2da9 100644 --- a/R-package/demo/cross_validation.R +++ b/R-package/demo/cross_validation.R @@ -42,7 +42,7 @@ lgb.cv( print("Running cross validation, with cutomsized loss function") logregobj <- function(preds, dtrain) { - labels <- getinfo(dtrain, "label") + labels <- get_field(dtrain, "label") preds <- 1.0 / (1.0 + exp(-preds)) grad <- preds - labels hess <- preds * (1.0 - preds) @@ -55,7 +55,7 @@ logregobj <- function(preds, dtrain) { # For example, we are doing logistic loss, the prediction is score before logistic transformation # Keep this in mind when you use the customization, and maybe you need write customized evaluation function evalerror <- function(preds, dtrain) { - labels <- getinfo(dtrain, "label") + labels <- get_field(dtrain, "label") preds <- 1.0 / (1.0 + exp(-preds)) err <- as.numeric(sum(labels != (preds > 0.5))) / length(labels) return(list(name = "error", value = err, higher_better = FALSE)) diff --git a/R-package/demo/early_stopping.R b/R-package/demo/early_stopping.R index fa8abce38b08..5179195f02ad 100644 --- a/R-package/demo/early_stopping.R +++ b/R-package/demo/early_stopping.R @@ -21,7 +21,7 @@ num_round <- 20L # User define objective function, given prediction, return gradient and second order gradient # This is loglikelihood loss logregobj <- function(preds, dtrain) { - labels <- getinfo(dtrain, "label") + labels <- get_field(dtrain, "label") preds <- 1.0 / (1.0 + exp(-preds)) grad <- preds - labels hess <- preds * (1.0 - preds) @@ -35,7 +35,7 @@ logregobj <- function(preds, dtrain) { # The built-in evaluation error assumes input is after logistic transformation # Keep this in mind when you use the customization, and maybe you need write customized evaluation function evalerror <- function(preds, dtrain) { - labels <- getinfo(dtrain, "label") + labels <- get_field(dtrain, "label") err <- as.numeric(sum(labels != (preds > 0.5))) / length(labels) return(list(name = "error", value = err, higher_better = FALSE)) } diff --git a/R-package/demo/multiclass_custom_objective.R b/R-package/demo/multiclass_custom_objective.R index 70d5c6ce3f90..a1e8edc958aa 100644 --- a/R-package/demo/multiclass_custom_objective.R +++ b/R-package/demo/multiclass_custom_objective.R @@ -43,7 +43,7 @@ probs_builtin <- exp(preds_builtin) / rowSums(exp(preds_builtin)) # User defined objective function, given prediction, return gradient and second order gradient custom_multiclass_obj <- function(preds, dtrain) { - labels <- getinfo(dtrain, "label") + labels <- get_field(dtrain, "label") # preds is a matrix with rows corresponding to samples and columns corresponding to choices preds <- matrix(preds, nrow = length(labels)) @@ -73,7 +73,7 @@ custom_multiclass_obj <- function(preds, dtrain) { # define custom metric custom_multiclass_metric <- function(preds, dtrain) { - labels <- getinfo(dtrain, "label") + labels <- get_field(dtrain, "label") preds <- matrix(preds, nrow = length(labels)) preds <- preds - apply(preds, 1L, max) prob <- exp(preds) / rowSums(exp(preds)) diff --git a/R-package/man/get_field.Rd b/R-package/man/get_field.Rd new file mode 100644 index 000000000000..1b6692fcf807 --- /dev/null +++ b/R-package/man/get_field.Rd @@ -0,0 +1,46 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/lgb.Dataset.R +\name{get_field} +\alias{get_field} +\alias{get_field.lgb.Dataset} +\title{Get one attribute of a \code{lgb.Dataset}} +\usage{ +get_field(dataset, field_name) + +\method{get_field}{lgb.Dataset}(dataset, field_name) +} +\arguments{ +\item{dataset}{Object of class \code{lgb.Dataset}} + +\item{field_name}{String with the name of the attribute to get. One of the following. +\itemize{ + \item \code{label}: label lightgbm learns from ; + \item \code{weight}: to do a weight rescale ; + \item{\code{group}: used for learning-to-rank tasks. An integer vector describing how to + group rows together as ordered results from the same set of candidate results to be ranked. + For example, if you have a 100-document dataset with \code{group = c(10, 20, 40, 10, 10, 10)}, + that means that you have 6 groups, where the first 10 records are in the first group, + records 11-30 are in the second group, etc.} + \item \code{init_score}: initial score is the base prediction lightgbm will boost from. +}} +} +\value{ +requested attribute +} +\description{ +Get one attribute of a \code{lgb.Dataset} +} +\examples{ +\donttest{ +data(agaricus.train, package = "lightgbm") +train <- agaricus.train +dtrain <- lgb.Dataset(train$data, label = train$label) +lgb.Dataset.construct(dtrain) + +labels <- lightgbm::get_field(dtrain, "label") +lightgbm::set_field(dtrain, "label", 1 - labels) + +labels2 <- lightgbm::get_field(dtrain, "label") +stopifnot(all(labels2 == 1 - labels)) +} +} diff --git a/R-package/man/lgb.interprete.Rd b/R-package/man/lgb.interprete.Rd index c1905282623d..6431a5011f48 100644 --- a/R-package/man/lgb.interprete.Rd +++ b/R-package/man/lgb.interprete.Rd @@ -34,7 +34,11 @@ Logit <- function(x) log(x / (1.0 - x)) data(agaricus.train, package = "lightgbm") train <- agaricus.train dtrain <- lgb.Dataset(train$data, label = train$label) -setinfo(dtrain, "init_score", rep(Logit(mean(train$label)), length(train$label))) +set_field( + dataset = dtrain + , field_name = "init_score" + , data = rep(Logit(mean(train$label)), length(train$label)) +) data(agaricus.test, package = "lightgbm") test <- agaricus.test diff --git a/R-package/man/lgb.plot.interpretation.Rd b/R-package/man/lgb.plot.interpretation.Rd index f8266308552d..2d7416561f23 100644 --- a/R-package/man/lgb.plot.interpretation.Rd +++ b/R-package/man/lgb.plot.interpretation.Rd @@ -44,7 +44,11 @@ dtrain <- lgb.Dataset( agaricus.train$data , label = labels ) -setinfo(dtrain, "init_score", rep(Logit(mean(labels)), length(labels))) +set_field( + dataset = dtrain + , field_name = "init_score" + , data = rep(Logit(mean(labels)), length(labels)) +) data(agaricus.test, package = "lightgbm") diff --git a/R-package/man/set_field.Rd b/R-package/man/set_field.Rd new file mode 100644 index 000000000000..f9901e27eefd --- /dev/null +++ b/R-package/man/set_field.Rd @@ -0,0 +1,48 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/lgb.Dataset.R +\name{set_field} +\alias{set_field} +\alias{set_field.lgb.Dataset} +\title{Set one attribute of a \code{lgb.Dataset} object} +\usage{ +set_field(dataset, field_name, data) + +\method{set_field}{lgb.Dataset}(dataset, field_name, data) +} +\arguments{ +\item{dataset}{Object of class \code{lgb.Dataset}} + +\item{field_name}{String with the name of the attribute to set. One of the following. +\itemize{ + \item \code{label}: label lightgbm learns from ; + \item \code{weight}: to do a weight rescale ; + \item{\code{group}: used for learning-to-rank tasks. An integer vector describing how to + group rows together as ordered results from the same set of candidate results to be ranked. + For example, if you have a 100-document dataset with \code{group = c(10, 20, 40, 10, 10, 10)}, + that means that you have 6 groups, where the first 10 records are in the first group, + records 11-30 are in the second group, etc.} + \item \code{init_score}: initial score is the base prediction lightgbm will boost from. +}} + +\item{data}{The data for the field. See examples.} +} +\value{ +The \code{lgb.Dataset} you passed in. +} +\description{ +Set one attribute of a \code{lgb.Dataset} +} +\examples{ +\donttest{ +data(agaricus.train, package = "lightgbm") +train <- agaricus.train +dtrain <- lgb.Dataset(train$data, label = train$label) +lgb.Dataset.construct(dtrain) + +labels <- lightgbm::get_field(dtrain, "label") +lightgbm::set_field(dtrain, "label", 1 - labels) + +labels2 <- lightgbm::get_field(dtrain, "label") +stopifnot(all.equal(labels2, 1 - labels)) +} +} diff --git a/R-package/man/slice.Rd b/R-package/man/slice.Rd index 0e2c73108a93..988987f80b2d 100644 --- a/R-package/man/slice.Rd +++ b/R-package/man/slice.Rd @@ -31,6 +31,6 @@ dtrain <- lgb.Dataset(train$data, label = train$label) dsub <- lightgbm::slice(dtrain, seq_len(42L)) lgb.Dataset.construct(dsub) -labels <- lightgbm::getinfo(dsub, "label") +labels <- lightgbm::get_field(dsub, "label") } } diff --git a/R-package/pkgdown/_pkgdown.yml b/R-package/pkgdown/_pkgdown.yml index 60ed132dd1ab..c1e591475892 100644 --- a/R-package/pkgdown/_pkgdown.yml +++ b/R-package/pkgdown/_pkgdown.yml @@ -56,8 +56,8 @@ reference: contents: - '`dim.lgb.Dataset`' - '`dimnames.lgb.Dataset`' - - '`getinfo`' - - '`setinfo`' + - '`get_field`' + - '`set_field`' - '`slice`' - '`lgb.Dataset`' - '`lgb.Dataset.construct`' diff --git a/R-package/tests/testthat/test_custom_objective.R b/R-package/tests/testthat/test_custom_objective.R index f3224e9ebf0e..54f5c300907a 100644 --- a/R-package/tests/testthat/test_custom_objective.R +++ b/R-package/tests/testthat/test_custom_objective.R @@ -9,7 +9,7 @@ watchlist <- list(eval = dtest, train = dtrain) TOLERANCE <- 1e-6 logregobj <- function(preds, dtrain) { - labels <- getinfo(dtrain, "label") + labels <- get_field(dtrain, "label") preds <- 1.0 / (1.0 + exp(-preds)) grad <- preds - labels hess <- preds * (1.0 - preds) @@ -21,7 +21,7 @@ logregobj <- function(preds, dtrain) { # This may make built-in evalution metric calculate wrong results # Keep this in mind when you use the customization, and maybe you need write customized evaluation function evalerror <- function(preds, dtrain) { - labels <- getinfo(dtrain, "label") + labels <- get_field(dtrain, "label") preds <- 1.0 / (1.0 + exp(-preds)) err <- as.numeric(sum(labels != (preds > 0.5))) / length(labels) return(list( diff --git a/R-package/tests/testthat/test_dataset.R b/R-package/tests/testthat/test_dataset.R index 04ca2e3d1cd3..ffcb99b35520 100644 --- a/R-package/tests/testthat/test_dataset.R +++ b/R-package/tests/testthat/test_dataset.R @@ -14,6 +14,7 @@ test_that("lgb.Dataset: basic construction, saving, loading", { # from dense matrix dtest2 <- lgb.Dataset(as.matrix(test_data), label = test_label) expect_equal(getinfo(dtest1, "label"), getinfo(dtest2, "label")) + expect_equal(get_field(dtest1, "label"), get_field(dtest2, "label")) # save to a local file tmp_file <- tempfile("lgb.Dataset_") @@ -23,6 +24,7 @@ test_that("lgb.Dataset: basic construction, saving, loading", { lgb.Dataset.construct(dtest3) unlink(tmp_file) expect_equal(getinfo(dtest1, "label"), getinfo(dtest3, "label")) + expect_equal(get_field(dtest1, "label"), get_field(dtest3, "label")) }) test_that("lgb.Dataset: getinfo & setinfo", { @@ -40,6 +42,21 @@ test_that("lgb.Dataset: getinfo & setinfo", { expect_error(setinfo(dtest, "asdf", test_label)) }) +test_that("lgb.Dataset: get_field & set_field", { + dtest <- lgb.Dataset(test_data) + dtest$construct() + + set_field(dtest, "label", test_label) + labels <- get_field(dtest, "label") + expect_equal(test_label, get_field(dtest, "label")) + + expect_true(length(get_field(dtest, "weight")) == 0L) + expect_true(length(get_field(dtest, "init_score")) == 0L) + + # any other label should error + expect_error(set_field(dtest, "asdf", test_label)) +}) + test_that("lgb.Dataset: slice, dim", { dtest <- lgb.Dataset(test_data, label = test_label) lgb.Dataset.construct(dtest) @@ -255,6 +272,19 @@ test_that("lgb.Dataset$setinfo() should convert 'group' to integer", { expect_identical(ds$getinfo("group"), as.integer(group_as_numeric)) }) +test_that("lgb.Dataset$set_field() should convert 'group' to integer", { + ds <- lgb.Dataset( + data = matrix(rnorm(100L), nrow = 50L, ncol = 2L) + , label = sample(c(0L, 1L), size = 50L, replace = TRUE) + ) + ds$construct() + current_group <- ds$get_field("group") + expect_null(current_group) + group_as_numeric <- rep(25.0, 2L) + ds$set_field("group", group_as_numeric) + expect_identical(ds$get_field("group"), as.integer(group_as_numeric)) +}) + test_that("lgb.Dataset should throw an error if 'reference' is provided but of the wrong format", { data(agaricus.test, package = "lightgbm") test_data <- agaricus.test$data[1L:100L, ] diff --git a/R-package/tests/testthat/test_lgb.interprete.R b/R-package/tests/testthat/test_lgb.interprete.R index bf113db43b97..86a968c7c7ed 100644 --- a/R-package/tests/testthat/test_lgb.interprete.R +++ b/R-package/tests/testthat/test_lgb.interprete.R @@ -11,10 +11,10 @@ test_that("lgb.intereprete works as expected for binary classification", { data(agaricus.train, package = "lightgbm") train <- agaricus.train dtrain <- lgb.Dataset(train$data, label = train$label) - setinfo( + set_field( dataset = dtrain - , "init_score" - , rep( + , field_name = "init_score" + , data = rep( .logit(mean(train$label)) , length(train$label) ) diff --git a/R-package/tests/testthat/test_lgb.plot.interpretation.R b/R-package/tests/testthat/test_lgb.plot.interpretation.R index be51e2fc965e..374de2231f23 100644 --- a/R-package/tests/testthat/test_lgb.plot.interpretation.R +++ b/R-package/tests/testthat/test_lgb.plot.interpretation.R @@ -11,10 +11,10 @@ test_that("lgb.plot.interepretation works as expected for binary classification" data(agaricus.train, package = "lightgbm") train <- agaricus.train dtrain <- lgb.Dataset(train$data, label = train$label) - setinfo( + set_field( dataset = dtrain - , "init_score" - , rep( + , field_name = "init_score" + , data = rep( .logit(mean(train$label)) , length(train$label) )