From 42e4be2b704edebc27c68c399671e62949cde49f Mon Sep 17 00:00:00 2001 From: Laurae Date: Sun, 5 Mar 2017 13:29:09 +0100 Subject: [PATCH 1/6] Create branch --- R-package/R/lgb.train.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R-package/R/lgb.train.R b/R-package/R/lgb.train.R index 8b08851c8c2e..2fbcc60628e3 100644 --- a/R-package/R/lgb.train.R +++ b/R-package/R/lgb.train.R @@ -1,4 +1,4 @@ -#' Main training logic for LightGBM +#' Main training logic for LightGBM #' #' @param params List of parameters #' @param data a \code{lgb.Dataset} object, used for training From a558b1e0f873a174952b26fc8f4c322ebf41679d Mon Sep 17 00:00:00 2001 From: Laurae2 Date: Wed, 15 Mar 2017 13:18:02 +0100 Subject: [PATCH 2/6] Attempt to add raw + saveRDS --- R-package/NAMESPACE | 1 + R-package/R/lgb.Booster.R | 5 +++- R-package/R/saveRDS.lgb.Booster.R | 36 +++++++++++++++++++++++ R-package/man/saveRDS.lgb.Booster.Rd | 43 ++++++++++++++++++++++++++++ 4 files changed, 84 insertions(+), 1 deletion(-) create mode 100644 R-package/R/saveRDS.lgb.Booster.R create mode 100644 R-package/man/saveRDS.lgb.Booster.Rd diff --git a/R-package/NAMESPACE b/R-package/NAMESPACE index 4ba52903ccb7..e4daa91e9473 100755 --- a/R-package/NAMESPACE +++ b/R-package/NAMESPACE @@ -26,6 +26,7 @@ export(lgb.plot.interpretation) export(lgb.save) export(lgb.train) export(lightgbm) +export(saveRDS.lgb.Booster) export(setinfo) export(slice) import(methods) diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R index 305243b9b77a..075b9dd00edb 100644 --- a/R-package/R/lgb.Booster.R +++ b/R-package/R/lgb.Booster.R @@ -195,7 +195,10 @@ Booster <- R6Class( predictor <- Predictor$new(private$handle) predictor$predict(data, num_iteration, rawscore, predleaf, header, reshape) }, - to_predictor = function() { Predictor$new(private$handle) } + to_predictor = function() { Predictor$new(private$handle) }, + raw = function() { + raw <<- self$dump_model() + } ), private = list( handle = NULL, diff --git a/R-package/R/saveRDS.lgb.Booster.R b/R-package/R/saveRDS.lgb.Booster.R new file mode 100644 index 000000000000..1ce833690ce6 --- /dev/null +++ b/R-package/R/saveRDS.lgb.Booster.R @@ -0,0 +1,36 @@ +#' saveRDS for lgb.Booster models +#' +#' Attemps to save a model using RDS. +#' +#' @param object R object to serialize. +#' @param file a connection or the name of the file where the R object is saved to or read from. +#' @param ascii a logical. If TRUE or NA, an ASCII representation is written; otherwise (default), a binary one is used. See the comments in the help for save. +#' @param version the workspace format version to use. NULL specifies the current default version (2). Versions prior to 2 are not supported, so this will only be relevant when there are later versions. +#' @param compress a logical specifying whether saving to a named file is to use "gzip" compression, or one of "gzip", "bzip2" or "xz" to indicate the type of compression to be used. Ignored if file is a connection. +#' @param refhook a hook function for handling reference objects. +#' +#' @return NULL invisibly. +#' +#' @examples +#' \dontrun{ +#' library(lightgbm) +#' data(agaricus.train, package='lightgbm') +#' train <- agaricus.train +#' dtrain <- lgb.Dataset(train$data, label=train$label) +#' data(agaricus.test, package='lightgbm') +#' test <- agaricus.test +#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label) +#' params <- list(objective="regression", metric="l2") +#' valids <- list(test=dtest) +#' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10) +#' } +#' @export + +saveRDS.lgb.Booster <- function(object, file = "", ascii = FALSE, version = NULL, compress = TRUE, refhook = NULL) { + + if (class(object$raw) == "function") { + object$raw() + } + saveRDS(object, file = "", ascii = FALSE, version = NULL, compress = TRUE, refhook = NULL) + +} \ No newline at end of file diff --git a/R-package/man/saveRDS.lgb.Booster.Rd b/R-package/man/saveRDS.lgb.Booster.Rd new file mode 100644 index 000000000000..2bb232f12059 --- /dev/null +++ b/R-package/man/saveRDS.lgb.Booster.Rd @@ -0,0 +1,43 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/saveRDS.lgb.Booster.R +\name{saveRDS.lgb.Booster} +\alias{saveRDS.lgb.Booster} +\title{saveRDS for lgb.Booster models} +\usage{ +saveRDS.lgb.Booster(object, file = "", ascii = FALSE, version = NULL, + compress = TRUE, refhook = NULL) +} +\arguments{ +\item{object}{R object to serialize.} + +\item{file}{a connection or the name of the file where the R object is saved to or read from.} + +\item{ascii}{a logical. If TRUE or NA, an ASCII representation is written; otherwise (default), a binary one is used. See the comments in the help for save.} + +\item{version}{the workspace format version to use. NULL specifies the current default version (2). Versions prior to 2 are not supported, so this will only be relevant when there are later versions.} + +\item{compress}{a logical specifying whether saving to a named file is to use "gzip" compression, or one of "gzip", "bzip2" or "xz" to indicate the type of compression to be used. Ignored if file is a connection.} + +\item{refhook}{a hook function for handling reference objects.} +} +\value{ +NULL invisibly. +} +\description{ +Attemps to save a model using RDS. +} +\examples{ +\dontrun{ + library(lightgbm) + data(agaricus.train, package='lightgbm') + train <- agaricus.train + dtrain <- lgb.Dataset(train$data, label=train$label) + data(agaricus.test, package='lightgbm') + test <- agaricus.test + dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label) + params <- list(objective="regression", metric="l2") + valids <- list(test=dtest) + model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10) +} +} + From 37cbfcbd0a4ea914069a6aeb1fcc4bfda4d94b40 Mon Sep 17 00:00:00 2001 From: Laurae2 Date: Wed, 15 Mar 2017 13:41:03 +0100 Subject: [PATCH 3/6] Attempt to switch to save -> raw system (lock) --- R-package/R/lgb.Booster.R | 5 +++-- R-package/R/saveRDS.lgb.Booster.R | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R index 075b9dd00edb..378c8c9e5be5 100644 --- a/R-package/R/lgb.Booster.R +++ b/R-package/R/lgb.Booster.R @@ -196,8 +196,9 @@ Booster <- R6Class( predictor$predict(data, num_iteration, rawscore, predleaf, header, reshape) }, to_predictor = function() { Predictor$new(private$handle) }, - raw = function() { - raw <<- self$dump_model() + raw = NA, + save = function() { + raw <- self$dump_model() } ), private = list( diff --git a/R-package/R/saveRDS.lgb.Booster.R b/R-package/R/saveRDS.lgb.Booster.R index 1ce833690ce6..4606be537e12 100644 --- a/R-package/R/saveRDS.lgb.Booster.R +++ b/R-package/R/saveRDS.lgb.Booster.R @@ -28,8 +28,8 @@ saveRDS.lgb.Booster <- function(object, file = "", ascii = FALSE, version = NULL, compress = TRUE, refhook = NULL) { - if (class(object$raw) == "function") { - object$raw() + if (is.na(raw)) { + object$save() } saveRDS(object, file = "", ascii = FALSE, version = NULL, compress = TRUE, refhook = NULL) From e8758dc32224108644129588f8accf509d525c9c Mon Sep 17 00:00:00 2001 From: Laurae2 Date: Wed, 15 Mar 2017 13:49:47 +0100 Subject: [PATCH 4/6] Switch to self$raw. --- R-package/R/lgb.Booster.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R index 378c8c9e5be5..9766534789a8 100644 --- a/R-package/R/lgb.Booster.R +++ b/R-package/R/lgb.Booster.R @@ -198,7 +198,7 @@ Booster <- R6Class( to_predictor = function() { Predictor$new(private$handle) }, raw = NA, save = function() { - raw <- self$dump_model() + self$raw <- self$dump_model() } ), private = list( From a5397d4db90d62374e9b10da54315f3c5808f729 Mon Sep 17 00:00:00 2001 From: Laurae2 Date: Wed, 15 Mar 2017 13:58:15 +0100 Subject: [PATCH 5/6] Switch to object$raw and add raw option. --- R-package/R/saveRDS.lgb.Booster.R | 14 ++++++++------ R-package/man/saveRDS.lgb.Booster.Rd | 11 +++++++---- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/R-package/R/saveRDS.lgb.Booster.R b/R-package/R/saveRDS.lgb.Booster.R index 4606be537e12..41cd9f05070d 100644 --- a/R-package/R/saveRDS.lgb.Booster.R +++ b/R-package/R/saveRDS.lgb.Booster.R @@ -1,13 +1,14 @@ #' saveRDS for lgb.Booster models #' -#' Attemps to save a model using RDS. +#' Attemps to save a model using RDS. Has an additional parameter (\code{raw}) which decides whether to save the raw model or not. #' #' @param object R object to serialize. #' @param file a connection or the name of the file where the R object is saved to or read from. #' @param ascii a logical. If TRUE or NA, an ASCII representation is written; otherwise (default), a binary one is used. See the comments in the help for save. -#' @param version the workspace format version to use. NULL specifies the current default version (2). Versions prior to 2 are not supported, so this will only be relevant when there are later versions. -#' @param compress a logical specifying whether saving to a named file is to use "gzip" compression, or one of "gzip", "bzip2" or "xz" to indicate the type of compression to be used. Ignored if file is a connection. +#' @param version the workspace format version to use. \code{NULL} specifies the current default version (2). Versions prior to 2 are not supported, so this will only be relevant when there are later versions. +#' @param compress a logical specifying whether saving to a named file is to use "gzip" compression, or one of \code{"gzip"}, \code{"bzip2"} or \code{"xz"} to indicate the type of compression to be used. Ignored if file is a connection. #' @param refhook a hook function for handling reference objects. +#' @param raw whether to save the model in a raw variable or not, recommended to leave it to \code{TRUE}. #' #' @return NULL invisibly. #' @@ -23,14 +24,15 @@ #' params <- list(objective="regression", metric="l2") #' valids <- list(test=dtest) #' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10) +#' saveRDS(model, "model.rds") #' } #' @export -saveRDS.lgb.Booster <- function(object, file = "", ascii = FALSE, version = NULL, compress = TRUE, refhook = NULL) { +saveRDS.lgb.Booster <- function(object, file = "", ascii = FALSE, version = NULL, compress = TRUE, refhook = NULL, raw = TRUE) { - if (is.na(raw)) { + if (is.na(object$raw) & (raw)) { object$save() } saveRDS(object, file = "", ascii = FALSE, version = NULL, compress = TRUE, refhook = NULL) -} \ No newline at end of file +} diff --git a/R-package/man/saveRDS.lgb.Booster.Rd b/R-package/man/saveRDS.lgb.Booster.Rd index 2bb232f12059..8894fa079102 100644 --- a/R-package/man/saveRDS.lgb.Booster.Rd +++ b/R-package/man/saveRDS.lgb.Booster.Rd @@ -5,7 +5,7 @@ \title{saveRDS for lgb.Booster models} \usage{ saveRDS.lgb.Booster(object, file = "", ascii = FALSE, version = NULL, - compress = TRUE, refhook = NULL) + compress = TRUE, refhook = NULL, raw = TRUE) } \arguments{ \item{object}{R object to serialize.} @@ -14,17 +14,19 @@ saveRDS.lgb.Booster(object, file = "", ascii = FALSE, version = NULL, \item{ascii}{a logical. If TRUE or NA, an ASCII representation is written; otherwise (default), a binary one is used. See the comments in the help for save.} -\item{version}{the workspace format version to use. NULL specifies the current default version (2). Versions prior to 2 are not supported, so this will only be relevant when there are later versions.} +\item{version}{the workspace format version to use. \code{NULL} specifies the current default version (2). Versions prior to 2 are not supported, so this will only be relevant when there are later versions.} -\item{compress}{a logical specifying whether saving to a named file is to use "gzip" compression, or one of "gzip", "bzip2" or "xz" to indicate the type of compression to be used. Ignored if file is a connection.} +\item{compress}{a logical specifying whether saving to a named file is to use "gzip" compression, or one of \code{"gzip"}, \code{"bzip2"} or \code{"xz"} to indicate the type of compression to be used. Ignored if file is a connection.} \item{refhook}{a hook function for handling reference objects.} + +\item{raw}{whether to save the model in a raw variable or not, recommended to leave it to \code{TRUE}.} } \value{ NULL invisibly. } \description{ -Attemps to save a model using RDS. +Attemps to save a model using RDS. Has an additional parameter (\code{raw}) which decides whether to save the raw model or not. } \examples{ \dontrun{ @@ -38,6 +40,7 @@ Attemps to save a model using RDS. params <- list(objective="regression", metric="l2") valids <- list(test=dtest) model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10) + saveRDS(model, "model.rds") } } From a391eb3b67feff75e9d40e30eea3408dd848c4af Mon Sep 17 00:00:00 2001 From: Laurae2 Date: Wed, 15 Mar 2017 14:04:01 +0100 Subject: [PATCH 6/6] Can't overload RDS. Have user manually overload. --- R-package/.Rhistory | 512 +++++++++++++++++++++++++++ R-package/NAMESPACE | 1 - R-package/R/saveRDS.lgb.Booster.R | 38 -- R-package/man/saveRDS.lgb.Booster.Rd | 46 --- 4 files changed, 512 insertions(+), 85 deletions(-) create mode 100644 R-package/.Rhistory delete mode 100644 R-package/R/saveRDS.lgb.Booster.R delete mode 100644 R-package/man/saveRDS.lgb.Booster.Rd diff --git a/R-package/.Rhistory b/R-package/.Rhistory new file mode 100644 index 000000000000..d67b85b7ca3e --- /dev/null +++ b/R-package/.Rhistory @@ -0,0 +1,512 @@ +backgroundSize = "100% 90%", +backgroundRepeat = "no-repeat", +backgroundPosition = "center") %>% +formatStyle("Number of Projects", +background = styleColorBar(c(0, 7), color = "lightblue"), +backgroundSize = "100% 90%", +backgroundRepeat = "no-repeat", +backgroundPosition = "center") %>% +formatRound("Average Weekly Hours", digits = 2) %>% +formatStyle("Average Weekly Hours", +background = styleColorBar(c(0, 72), color = "lightblue"), +backgroundSize = "100% 90%", +backgroundRepeat = "no-repeat", +backgroundPosition = "center") %>% +formatStyle("Years in Company", +background = styleColorBar(c(0, 10), color = "lightblue"), +backgroundSize = "100% 90%", +backgroundRepeat = "no-repeat", +backgroundPosition = "center") %>% +formatStyle(c("Work Accident"), +backgroundColor = styleEqual(c("No Accident", "Accident"), c("lightgrey", "pink"))) %>% +formatStyle("Has Left", +backgroundColor = styleEqual(c("Not Left", "Left"), c("lightgrey", "pink"))) %>% +formatStyle("Promotion in Last 5 Years", +backgroundColor = styleEqual(c("No Promotion", "Promotion"), c("lightgrey", "orange"))) %>% +formatStyle("Salary", +backgroundColor = styleEqual(c("Low", "Medium", "High"), c("orange", "yellow", "lightblue"))) +}) +# Add Correlation plot +output$corrplot <- renderPlotly({ +plot_data <- dist_data() +plot_data[upper.tri(plot_data)] <- NA +plot_data <- melt(plot_data, na.rm = TRUE) +plot_data$value <- plot_data$value / nrow(better_data()) +plot_data$value <- -(plot_data$value - 0.5) * 2 +colnames(plot_data) <- c("Variable_1", "Variable_2", "Agreement") +if (input$check_corr == TRUE) { +plot_data$Text <- sprintf("%0.2f", round(plot_data$Agreement, digits = 2)) +return(ggplotly(ggplot(data = plot_data, aes_string(x = "Variable_1", y = "Variable_2", fill = "Agreement")) + geom_tile(color = "white") + geom_text(aes_string(x = "Variable_1", y = "Variable_2", label = "Text")) + scale_fill_gradient2(low = "red", high = "green", mid = "white", midpoint = 0, limit = c(-1, 1), space = "Lab", name = "Agreement
Strength") + theme_bw() + theme(axis.text.x = element_text(angle = 45)) + labs(x = "Variable 1", y = "Variable 2"), autosize = TRUE, margin = list(l = 20, r = 20, b = 250, t = 20, p = 4))) +} else { +return(ggplotly(ggplot(data = plot_data, aes_string(x = "Variable_1", y = "Variable_2", fill = "Agreement")) + geom_tile(color = "white") + scale_fill_gradient2(low = "green", high = "red", mid = "white", midpoint = 0, limit = c(-1, 1), space = "Lab", name = "Agreement
Strength") + theme_bw() + theme(axis.text.x = element_text(angle = 45)) + labs(x = "Variable 1", y = "Variable 2"), autosize = TRUE, margin = list(l = 20, r = 20, b = 250, t = 20, p = 4))) +} +}) +# Add Graph plot +output$graphplot <- renderPlot({ +plot_data <- dist_data() +plot_data <- plot_data / nrow(better_data()) +plot_data <- -(plot_data - 0.5) * 2 +features_name <- c("Satisfaction Level", "Last Evaluation Score", "Number of Projects", "Average Weekly Hours", "Years in Company", "Work Accident", "Has Left", "Promotion in Last 5 Years", "Department: Accounting", "Department: Human Resources", "Department: IT", "Department: Management", "Department: Marketing", "Department: Product Management", "Department: Research and Development", "Department: Sales", "Department: Support", "Department: Technical", "Salary: Low", "Salary: Medium", "Salary: High") +features_selected <- which(features_name %in% input$feat_corr) +colnames(plot_data) <- c("A1", "A2", "B", "C", "D", "E", "F", "G", "H01", "H02", "H03", "H04", "H05", "H06", "H07", "H08", "H09", "H10", "I1", "I2", "I3")[features_selected] +features_name <- c("A1: Satisfaction Level", "A2: Last Evaluation Score", "B: Number of Projects", "C: Average Weekly Hours", "D: Years in Company", "E: Work Accident", "F: Has Left", "G: Promotion in Last 5 Years", "H1 Department: Accounting", "H2 Department: Human Resources", "H3 Department: IT", "H4 Department: Management", "H5 Department: Marketing", "H6 Department: Product Management", "H7 Department: Research and Development", "H8 Department: Sales", "H9 Department: Support", "H10 Department: Technical", "I1 Salary: Low", "I2 Salary: Medium", "I3 Salary: High") +qgraph(plot_data, layout = "spring", groups = features_name[features_selected], palette = "pastel", theme = "classic", shape = "ellipse", borders = FALSE, vTrans = 100, vsize = 12, title = paste0("Agreement: [", paste(sprintf("%.03f", range(plot_data)), collapse = ", "), "]"), edge.labels = TRUE, XKCD = TRUE) +}) +# Plot tree +output$tree <- renderPlot({ +tree_data <- copy(better_data()) +levels(tree_data$`Department`) <- c("Accounting", "HR", "IT", "Mgmt", "Marketing", "Product Mgmt", "R&D", "Sales", "Support", "Tech") +tree_label <- copy(tree_data[[input$label]]) +#tree_data[[input$label]] <- NULL +tree_data <- tree_data[, unique(c(input$ban, input$label)), with = FALSE] +#tree_data <- tree_data[, input$ban[which(!input$ban %in% input$label)], with = FALSE] +formula <- reformulate(termlabels = paste0("`", input$ban[which(!input$ban %in% input$label)], "`"), response = input$label) +temp_model <- rpart(formula = formula, +data = tree_data, +method = ifelse(input$label %in% c("Satisfaction Level", "Last Evaluation Score", "Average Weekly Hours"), "anova", ifelse(input$label %in% c("Number of Projects", "Years in Company"), "poisson", "class")), +control = rpart.control(minsplit = input$min_split, +minbucket = input$min_bucket, +cp = input$min_improve, +maxcompete = 0, +maxsurrogate = input$surrogate_search, +usesurrogate = input$surrogate_type, +xval = 3, +surrogatestyle = input$surrogate_style, +maxdepth = input$max_depth)) +# temp_model <- Laurae::FeatureLookup(data = tree_data, +# label = tree_label, +# ban = NULL, +# antiban = FALSE, +# type = ifelse(input$label %in% c("Satisfaction Level", "Last Evaluation Score", "Average Weekly Hours"), "anova", ifelse(input$label %in% c("Number of Projects", "Years in Company"), "poisson", "class")), +# split = "information", +# folds = 3, +# seed = input$seed, +# verbose = FALSE, +# plots = FALSE, +# max_depth = input$max_depth, +# min_split = input$min_split, +# min_bucket = input$min_bucket, +# min_improve = input$min_improve, +# competing_splits = 0, +# surrogate_search = input$surrogate_search, +# surrogate_type = input$surrogate_type, +# surrogate_style = input$surrogate_style) +rpart.plot(temp_model, main = "Decision Tree", tweak = input$size/100) +}) +# Need to stop using a button? +observeEvent(input$done, { +stopApp(TRUE) +}) +} +runGadget(shinyApp(ui, server), viewer = paneViewer()) +library(lightgbm) +data(agaricus.train, package='lightgbm') +train <- agaricus.train +dtrain <- lgb.Dataset(train$data, label=train$label) +data(agaricus.test, package='lightgbm') +test <- agaricus.test +dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label) +params <- list(objective="regression", metric="l2") +valids <- list(test=dtest) +model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10) +getwd() +saveRDS("D:/model.rds") +saveRDS(model, "D:/model.rds") +model$raw +model$raw() +unlockBinding("raw", model) +model$raw() +unlockBinding(model) +bindingIsLocked("raw", model) +model$whatever <- "yes" +?makeActiveBinding +f <- local( { +x <- 1 +function(v) { +if (missing(v)) +cat("get\n") +else { +cat("set\n") +x <<- v +} +x +} +}) +makeActiveBinding("fred", f, .GlobalEnv) +bindingIsActive("fred", .GlobalEnv) +fred +fred <- 2 +fred +makeActiveBinding("model", whatever, .GlobalEnv) +makeActiveBinding("whatever", model, .GlobalEnv) +makeActiveBinding("raw", model, .GlobalEnv) +makeActiveBinding("save", model, .GlobalEnv) +makeActiveBinding("save", raw, model) +with(model, save <- raw) +model[["save"]] <- model$raw +model["save"] <- model$raw +?lockEnvironment +environmentIsLocked(model) +model$best_iter +model$best_iter <- 1 +model$best_iter +install_github("Laurae2/LightGBM/R-package@patch-10") +devtools::install_github("Laurae2/LightGBM/R-package@patch-10") +?lightgbm::lgb.train +library(lightgbm) +data(agaricus.train, package='lightgbm') +train <- agaricus.train +dtrain <- lgb.Dataset(train$data, label=train$label) +data(agaricus.test, package='lightgbm') +test <- agaricus.test +dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label) +params <- list(objective="regression", metric="l2") +valids <- list(test=dtest) +model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10) +saveRDS(model, "D:/model.rds") +model$raw +model$save() +model$raw +Booster <- R6Class( +"lgb.Booster", +cloneable = FALSE, +public = list( +best_iter = -1, +record_evals = list(), +finalize = function() { +if (!lgb.is.null.handle(private$handle)) { +lgb.call("LGBM_BoosterFree_R", ret = NULL, private$handle) +private$handle <- NULL +} +}, +initialize = function(params = list(), +train_set = NULL, +modelfile = NULL, +...) { +params <- append(params, list(...)) +params_str <- lgb.params2str(params) +handle <- lgb.new.handle() +if (!is.null(train_set)) { +if (!lgb.check.r6.class(train_set, "lgb.Dataset")) { +stop("lgb.Booster: Can only use lgb.Dataset as training data") +} +handle <- +lgb.call("LGBM_BoosterCreate_R", ret = handle, train_set$.__enclos_env__$private$get_handle(), params_str) +private$train_set <- train_set +private$num_dataset <- 1 +private$init_predictor <- train_set$.__enclos_env__$private$predictor +if (!is.null(private$init_predictor)) { +lgb.call("LGBM_BoosterMerge_R", ret = NULL, +handle, +private$init_predictor$.__enclos_env__$private$handle) +} +private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE) +} else if (!is.null(modelfile)) { +if (!is.character(modelfile)) { +stop("lgb.Booster: Can only use a string as model file path") +} +handle <- +lgb.call("LGBM_BoosterCreateFromModelfile_R", +ret = handle, +lgb.c_str(modelfile)) +} else { +stop( +"lgb.Booster: Need at least either training dataset or model file to create booster instance" +) +} +class(handle) <- "lgb.Booster.handle" +private$handle <- handle +private$num_class <- 1L +private$num_class <- +lgb.call("LGBM_BoosterGetNumClasses_R", ret = private$num_class, private$handle) +}, +set_train_data_name = function(name) { +private$name_train_set <- name +self +}, +add_valid = function(data, name) { +if (!lgb.check.r6.class(data, "lgb.Dataset")) { +stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data") +} +if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) { +stop( +"lgb.Booster.add_valid: Failed to add validation data; you should use the same predictor for these data" +) +} +if (!is.character(name)) { +stop("lgb.Booster.add_valid: Can only use characters as data name") +} +lgb.call("LGBM_BoosterAddValidData_R", ret = NULL, private$handle, data$.__enclos_env__$private$get_handle()) +private$valid_sets <- c(private$valid_sets, data) +private$name_valid_sets <- c(private$name_valid_sets, name) +private$num_dataset <- private$num_dataset + 1 +private$is_predicted_cur_iter <- +c(private$is_predicted_cur_iter, FALSE) +self +}, +reset_parameter = function(params, ...) { +params <- append(params, list(...)) +params_str <- algb.params2str(params) +lgb.call("LGBM_BoosterResetParameter_R", ret = NULL, +private$handle, +params_str) +self +}, +update = function(train_set = NULL, fobj = NULL) { +if (!is.null(train_set)) { +if (!lgb.check.r6.class(train_set, "lgb.Dataset")) { +stop("lgb.Booster.update: Only can use lgb.Dataset as training data") +} +if (!identical(train_set$predictor, private$init_predictor)) { +stop( +"lgb.Booster.update: Change train_set failed, you should use the same predictor for these data" +) +} +lgb.call("LGBM_BoosterResetTrainingData_R", ret = NULL, +private$handle, +train_set$.__enclos_env__$private$get_handle()) +private$train_set = train_set +} +if (is.null(fobj)) { +ret <- lgb.call("LGBM_BoosterUpdateOneIter_R", ret = NULL, private$handle) +} else { +if (!is.function(fobj)) { stop("lgb.Booster.update: fobj should be a function") } +gpair <- fobj(private$inner_predict(1), private$train_set) +if(is.null(gpair$grad) | is.null(gpair$hess)){ +stop("lgb.Booster.update: custom objective should +return a list with attributes (hess, grad)") +} +ret <- lgb.call( +"LGBM_BoosterUpdateOneIterCustom_R", ret = NULL, +private$handle, +gpair$grad, +gpair$hess, +length(gpair$grad) +) +} +for (i in seq_along(private$is_predicted_cur_iter)) { +private$is_predicted_cur_iter[[i]] <- FALSE +} +ret +}, +rollback_one_iter = function() { +lgb.call("LGBM_BoosterRollbackOneIter_R", ret = NULL, private$handle) +for (i in seq_along(private$is_predicted_cur_iter)) { +private$is_predicted_cur_iter[[i]] <- FALSE +} +self +}, +current_iter = function() { +cur_iter <- 0L +lgb.call("LGBM_BoosterGetCurrentIteration_R", ret = cur_iter, private$handle) +}, +eval = function(data, name, feval = NULL) { +if (!lgb.check.r6.class(data, "lgb.Dataset")) { +stop("lgb.Booster.eval: Can only use lgb.Dataset to eval") +} +data_idx <- 0 +if (identical(data, private$train_set)) { data_idx <- 1 } else { +if (length(private$valid_sets) > 0) { +for (i in seq_along(private$valid_sets)) { +if (identical(data, private$valid_sets[[i]])) { +data_idx <- i + 1 +break +} +} +} +} +if (data_idx == 0) { +self$add_valid(data, name) +data_idx <- private$num_dataset +} +private$inner_eval(name, data_idx, feval) +}, +eval_train = function(feval = NULL) { +private$inner_eval(private$name_train_set, 1, feval) +}, +eval_valid = function(feval = NULL) { +ret = list() +if (length(private$valid_sets) <= 0) { return(ret) } +for (i in seq_along(private$valid_sets)) { +ret <- append(ret, private$inner_eval(private$name_valid_sets[[i]], i + 1, feval)) +} +ret +}, +save_model = function(filename, num_iteration = NULL) { +if (is.null(num_iteration)) { num_iteration <- self$best_iter } +lgb.call( +"LGBM_BoosterSaveModel_R", +ret = NULL, +private$handle, +as.integer(num_iteration), +lgb.c_str(filename) +) +self +}, +dump_model = function(num_iteration = NULL) { +if (is.null(num_iteration)) { num_iteration <- self$best_iter } +lgb.call.return.str( +"LGBM_BoosterDumpModel_R", +private$handle, +as.integer(num_iteration) +) +}, +predict = function(data, +num_iteration = NULL, +rawscore = FALSE, +predleaf = FALSE, +header = FALSE, +reshape = FALSE) { +if (is.null(num_iteration)) { num_iteration <- self$best_iter } +predictor <- Predictor$new(private$handle) +predictor$predict(data, num_iteration, rawscore, predleaf, header, reshape) +}, +to_predictor = function() { Predictor$new(private$handle) }, +raw = NA, +save = function() { +self$raw <- self$dump_model() +} +), +private = list( +handle = NULL, +train_set = NULL, +name_train_set = "training", +valid_sets = list(), +name_valid_sets = list(), +predict_buffer = list(), +is_predicted_cur_iter = list(), +num_class = 1, +num_dataset = 0, +init_predictor = NULL, +eval_names = NULL, +higher_better_inner_eval = NULL, +inner_predict = function(idx) { +data_name <- private$name_train_set +if (idx > 1) { data_name <- private$name_valid_sets[[idx - 1]] } +if (idx > private$num_dataset) { +stop("data_idx should not be greater than num_dataset") +} +if (is.null(private$predict_buffer[[data_name]])) { +npred <- 0L +npred <- lgb.call("LGBM_BoosterGetNumPredict_R", +ret = npred, +private$handle, +as.integer(idx - 1)) +private$predict_buffer[[data_name]] <- rep(0.0, npred) +} +if (!private$is_predicted_cur_iter[[idx]]) { +private$predict_buffer[[data_name]] <- lgb.call( +"LGBM_BoosterGetPredict_R", +ret = private$predict_buffer[[data_name]], +private$handle, +as.integer(idx - 1) +) +private$is_predicted_cur_iter[[idx]] <- TRUE +} +private$predict_buffer[[data_name]] +}, +get_eval_info = function() { +if (is.null(private$eval_names)) { +names <- lgb.call.return.str("LGBM_BoosterGetEvalNames_R", private$handle) +if (nchar(names) > 0) { +names <- strsplit(names, "\t")[[1]] +private$eval_names <- names +private$higher_better_inner_eval <- rep(FALSE, length(names)) +for (i in seq_along(names)) { +if ((names[i] == "auc") | grepl("^ndcg", names[i])) { +private$higher_better_inner_eval[i] <- TRUE +} +} +} +} +private$eval_names +}, +inner_eval = function(data_name, data_idx, feval = NULL) { +if (data_idx > private$num_dataset) { +stop("data_idx should not be greater than num_dataset") +} +private$get_eval_info() +ret <- list() +if (length(private$eval_names) > 0) { +tmp_vals <- rep(0.0, length(private$eval_names)) +tmp_vals <- lgb.call("LGBM_BoosterGetEval_R", ret = tmp_vals, +private$handle, +as.integer(data_idx - 1)) +for (i in seq_along(private$eval_names)) { +res <- list() +res$data_name <- data_name +res$name <- private$eval_names[i] +res$value <- tmp_vals[i] +res$higher_better <- private$higher_better_inner_eval[i] +ret <- append(ret, list(res)) +} +} +if (!is.null(feval)) { +if (!is.function(feval)) { +stop("lgb.Booster.eval: feval should be a function") +} +data <- private$train_set +if (data_idx > 1) { data <- private$valid_sets[[data_idx - 1]] } +res <- feval(private$inner_predict(data_idx), data) +if(is.null(res$name) | is.null(res$value) | +is.null(res$higher_better)) { +stop("lgb.Booster.eval: custom eval function should return a +list with attribute (name, value, higher_better)"); +} +res$data_name <- data_name +ret <- append(ret, list(res)) +} +ret +} +) +) +library(lightgbm) +data(agaricus.train, package='lightgbm') +train <- agaricus.train +dtrain <- lgb.Dataset(train$data, label=train$label) +data(agaricus.test, package='lightgbm') +test <- agaricus.test +dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label) +params <- list(objective="regression", metric="l2") +valids <- list(test=dtest) +model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10) +model$save() +model$raw +my_model <- readRDS("D:/model.rds") +predict(my_model, agaricus.test$data) +predict(agaricus.test$data, my_model) +predict(model, agaricus.test$data) +predict(test$data, agaricus.test$data) +predict(train$data, agaricus.test$data) +head(predict(model, train$data)) +head(model$train$data) +head(model$predict(train$data)) +head(model$predict(dtrain)) +head(model$predict(dtrain, 1, FALSE, FALSE, FALSE, FALSE)) +model$predict +predict(model, test$data) +devtools::install_github("Laurae2/LightGBM/R-package@patch-10") +library(lightgbm) +data(agaricus.train, package='lightgbm') +train <- agaricus.train +dtrain <- lgb.Dataset(train$data, label=train$label) +data(agaricus.test, package='lightgbm') +test <- agaricus.test +dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label) +params <- list(objective="regression", metric="l2") +valids <- list(test=dtest) +model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10) +preds <- predict(model, test$data) +saveRDS(model, "D:/model.rds") +preds <- predict(model, test$data) +new_model <- readRDS("D:/model.rds") +preds <- predict(new_model, test$data) +preds <- predict(model, test$data) +new_m$odel$raw +new_model$raw +model$save() +model$raw +?readRDS +setwd("D:/Data Science/LightGBM_GitHub/LightGBM/R-package") +devtools::document() +devtools::document() diff --git a/R-package/NAMESPACE b/R-package/NAMESPACE index e4daa91e9473..4ba52903ccb7 100755 --- a/R-package/NAMESPACE +++ b/R-package/NAMESPACE @@ -26,7 +26,6 @@ export(lgb.plot.interpretation) export(lgb.save) export(lgb.train) export(lightgbm) -export(saveRDS.lgb.Booster) export(setinfo) export(slice) import(methods) diff --git a/R-package/R/saveRDS.lgb.Booster.R b/R-package/R/saveRDS.lgb.Booster.R deleted file mode 100644 index 41cd9f05070d..000000000000 --- a/R-package/R/saveRDS.lgb.Booster.R +++ /dev/null @@ -1,38 +0,0 @@ -#' saveRDS for lgb.Booster models -#' -#' Attemps to save a model using RDS. Has an additional parameter (\code{raw}) which decides whether to save the raw model or not. -#' -#' @param object R object to serialize. -#' @param file a connection or the name of the file where the R object is saved to or read from. -#' @param ascii a logical. If TRUE or NA, an ASCII representation is written; otherwise (default), a binary one is used. See the comments in the help for save. -#' @param version the workspace format version to use. \code{NULL} specifies the current default version (2). Versions prior to 2 are not supported, so this will only be relevant when there are later versions. -#' @param compress a logical specifying whether saving to a named file is to use "gzip" compression, or one of \code{"gzip"}, \code{"bzip2"} or \code{"xz"} to indicate the type of compression to be used. Ignored if file is a connection. -#' @param refhook a hook function for handling reference objects. -#' @param raw whether to save the model in a raw variable or not, recommended to leave it to \code{TRUE}. -#' -#' @return NULL invisibly. -#' -#' @examples -#' \dontrun{ -#' library(lightgbm) -#' data(agaricus.train, package='lightgbm') -#' train <- agaricus.train -#' dtrain <- lgb.Dataset(train$data, label=train$label) -#' data(agaricus.test, package='lightgbm') -#' test <- agaricus.test -#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label) -#' params <- list(objective="regression", metric="l2") -#' valids <- list(test=dtest) -#' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10) -#' saveRDS(model, "model.rds") -#' } -#' @export - -saveRDS.lgb.Booster <- function(object, file = "", ascii = FALSE, version = NULL, compress = TRUE, refhook = NULL, raw = TRUE) { - - if (is.na(object$raw) & (raw)) { - object$save() - } - saveRDS(object, file = "", ascii = FALSE, version = NULL, compress = TRUE, refhook = NULL) - -} diff --git a/R-package/man/saveRDS.lgb.Booster.Rd b/R-package/man/saveRDS.lgb.Booster.Rd deleted file mode 100644 index 8894fa079102..000000000000 --- a/R-package/man/saveRDS.lgb.Booster.Rd +++ /dev/null @@ -1,46 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/saveRDS.lgb.Booster.R -\name{saveRDS.lgb.Booster} -\alias{saveRDS.lgb.Booster} -\title{saveRDS for lgb.Booster models} -\usage{ -saveRDS.lgb.Booster(object, file = "", ascii = FALSE, version = NULL, - compress = TRUE, refhook = NULL, raw = TRUE) -} -\arguments{ -\item{object}{R object to serialize.} - -\item{file}{a connection or the name of the file where the R object is saved to or read from.} - -\item{ascii}{a logical. If TRUE or NA, an ASCII representation is written; otherwise (default), a binary one is used. See the comments in the help for save.} - -\item{version}{the workspace format version to use. \code{NULL} specifies the current default version (2). Versions prior to 2 are not supported, so this will only be relevant when there are later versions.} - -\item{compress}{a logical specifying whether saving to a named file is to use "gzip" compression, or one of \code{"gzip"}, \code{"bzip2"} or \code{"xz"} to indicate the type of compression to be used. Ignored if file is a connection.} - -\item{refhook}{a hook function for handling reference objects.} - -\item{raw}{whether to save the model in a raw variable or not, recommended to leave it to \code{TRUE}.} -} -\value{ -NULL invisibly. -} -\description{ -Attemps to save a model using RDS. Has an additional parameter (\code{raw}) which decides whether to save the raw model or not. -} -\examples{ -\dontrun{ - library(lightgbm) - data(agaricus.train, package='lightgbm') - train <- agaricus.train - dtrain <- lgb.Dataset(train$data, label=train$label) - data(agaricus.test, package='lightgbm') - test <- agaricus.test - dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label) - params <- list(objective="regression", metric="l2") - valids <- list(test=dtest) - model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10) - saveRDS(model, "model.rds") -} -} -