Expression handling (#1126)
* initial version of expression-learners

* allow expressions within tuning param sets

* updated documentation of expression-related files

* fixing naming issues

* further doc fixes

* remove dict argument

* updating man-pages

* rm dict_template

* better documentation of expression-related functions

* removed ParamHelpers:: as it is not necessary

* cleanup; added test

* removed duplicated PH dep

* fixes as requested in PR#1126

* use setLearnerId in tests

* fix warnings in tests and r cmd check

* added assertion for Task

* docs, mini cleanup
kerschke authored and jakob-r committed Mar 2, 2017
1 parent c2d40a3 commit 416305d
Showing 32 changed files with 402 additions and 16 deletions.
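As a quick orientation before the per-file diffs: this commit lets hyperparameter values and tuning bounds be given as unevaluated expressions over task properties. A minimal sketch of both uses (assuming the mlr API as modified below; the chosen learner, parameter and values are illustrative only):

library(mlr)

# (a) an expression as a hyperparameter value of a learner
lrn = makeLearner("classif.rpart", minsplit = expression(round(n / 10)))

# (b) expressions as bounds of a tuning parameter set
ps = makeParamSet(
  makeNumericParam("C", lower = expression(k), upper = expression(n), trafo = function(x) 2^x)
)

# Both are resolved against a concrete task inside train() and tuneParams(),
# via the task dictionary (task, n, p, type, k) introduced in this commit.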
3 changes: 3 additions & 0 deletions NAMESPACE
@@ -15,6 +15,7 @@ S3method(downsample,Task)
S3method(estimateRelativeOverfitting,ResampleDesc)
S3method(estimateResidualVariance,Learner)
S3method(estimateResidualVariance,WrappedModel)
S3method(evaluateParamExpressions,Learner)
S3method(generateCalibrationData,BenchmarkResult)
S3method(generateCalibrationData,Prediction)
S3method(generateCalibrationData,ResampleResult)
@@ -96,6 +97,7 @@ S3method(getTaskTargetNames,TaskDescUnsupervised)
S3method(getTaskTargets,CostSensTask)
S3method(getTaskTargets,SupervisedTask)
S3method(getTaskTargets,UnsupervisedTask)
S3method(hasExpression,Learner)
S3method(impute,Task)
S3method(impute,data.frame)
S3method(isFailureModel,BaseWrapperModel)
@@ -866,6 +868,7 @@ export(getTaskClassLevels)
export(getTaskCosts)
export(getTaskData)
export(getTaskDescription)
export(getTaskDictionary)
export(getTaskFeatureNames)
export(getTaskFormula)
export(getTaskId)
8 changes: 8 additions & 0 deletions R/Learner_properties.R
@@ -78,3 +78,11 @@ listLearnerProperties = function(type = "any") {
assertSubset(type, allProps)
mlr$learner.properties[[type]]
}

#' @param obj [\code{\link{Learner}} | \code{character(1)}]\cr
#' Same as \code{learner} above.
#' @rdname LearnerProperties
#' @export
hasExpression.Learner = function(obj) {
any(hasExpression(obj$par.set)) || any(vlapply(obj$par.vals, is.expression))
}
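A brief usage sketch for the new method (a sketch only; it relies on getTaskDictionary() and evaluateParamExpressions() introduced further down, and on the iris.task example task shipped with mlr):

lrn = makeLearner("classif.rpart", minsplit = expression(n / 10))
hasExpression(lrn)  # TRUE: par.vals contains an unevaluated expression
lrn = evaluateParamExpressions(lrn, dict = getTaskDictionary(iris.task))
hasExpression(lrn)  # expected FALSE, once no unevaluated expressions remain in par.set or par.vals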
6 changes: 3 additions & 3 deletions R/RLearner_classif_randomForest.R
@@ -5,10 +5,10 @@ makeRLearner.classif.randomForest = function() {
package = "randomForest",
par.set = makeParamSet(
makeIntegerLearnerParam(id = "ntree", default = 500L, lower = 1L),
makeIntegerLearnerParam(id = "mtry", lower = 1L),
makeIntegerLearnerParam(id = "mtry", lower = 1L, default = expression(floor(sqrt(p)))),
makeLogicalLearnerParam(id = "replace", default = TRUE),
makeNumericVectorLearnerParam(id = "classwt", lower = 0),
makeNumericVectorLearnerParam(id = "cutoff", lower = 0, upper = 1),
makeNumericVectorLearnerParam(id = "classwt", lower = 0, len = expression(k)),
makeNumericVectorLearnerParam(id = "cutoff", lower = 0, upper = 1, len = expression(k)),
makeUntypedLearnerParam(id = "strata", tunable = FALSE),
makeIntegerVectorLearnerParam(id = "sampsize", lower = 1L),
makeIntegerLearnerParam(id = "nodesize", default = 1L, lower = 1L),
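For reference, with the iris task (p = 4 features, k = 3 classes) the new expression defaults above resolve as follows (a sketch using the helpers added in this commit):

task = makeClassifTask(data = iris, target = "Species")
lrn = evaluateParamExpressions(makeLearner("classif.randomForest"), dict = getTaskDictionary(task))
lrn$par.set
# mtry default:               floor(sqrt(p)) = floor(sqrt(4)) = 2
# classwt and cutoff length:  k = 3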
29 changes: 29 additions & 0 deletions R/Task_operators.R
@@ -454,3 +454,32 @@ getTaskFactorLevels = function(task) {
getTaskWeights = function(task) {
task$weights
}


#' @title Create a dictionary based on the task.
#'
#' @description Returns a dictionary, which contains the \link{Task} itself
#' (\code{task}), the number of features (\code{p}), the number of
#'   observations (\code{n}), the task type (\code{type}) and, in the case of
#' classification tasks, the number of class levels (\code{k}).
#'
#' @template arg_task
#' @return [\code{\link[base]{list}}]. Used for evaluating the expressions
#' within a parameter, parameter set or list of parameters.
#' @family task
#' @export
#' @examples
#' task = makeClassifTask(data = iris, target = "Species")
#' getTaskDictionary(task)
getTaskDictionary = function(task) {
assertClass(task, classes = "Task")
dict = list(
task = task,
p = getTaskNFeats(task),
n = getTaskSize(task),
type = getTaskType(task)
)
if (dict$type == "classif")
dict$k = length(getTaskClassLevels(task))
return(dict)
}
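For the iris example in the roxygen block above, the returned dictionary therefore holds (values follow directly from the function body):

task = makeClassifTask(data = iris, target = "Species")
dict = getTaskDictionary(task)
# dict$task : the ClassifTask itself
# dict$p    : 4          (number of features)
# dict$n    : 150        (number of observations)
# dict$type : "classif"
# dict$k    : 3          (number of class levels; classification tasks only)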
77 changes: 77 additions & 0 deletions R/evaluateParamExpressions.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#' @title Evaluates expressions within a learner or parameter set.
#'
#' @description
#' A \code{\link{Learner}} can contain unevaluated \code{\link[base]{expression}s}
#' as the value of a hyperparameter. Such expressions are used, e.g., if the default
#' value depends on the task size or an upper limit for a parameter is given by
#' the number of features in a task. \code{evaluateParamExpressions} evaluates
#' these expressions using a given dictionary, which holds the following
#' information:
#' \itemize{
#'  \item{\code{task}:} the task itself, allowing access to any of its elements.
#' \item{\code{p}:} the number of features in the task
#' \item{\code{n}:} the number of observations in the task
#'  \item{\code{type}:} the task type, i.e. "classif", "regr", "surv", "cluster", "costsens" or "multilabel"
#' \item{\code{k}:} the number of classes of the target variable (only available for classification tasks)
#' }
#' Usually the evaluation of the expression is performed automatically, e.g. in
#' \code{\link{train}} or \code{\link{tuneParams}}. Therefore calling
#' \code{evaluateParamExpressions} manually should not be necessary.
#' It is also possible to directly evaluate the expressions of a
#' \code{\link[ParamHelpers]{ParamSet}}, \code{\link[base]{list}} of
#' \code{\link[ParamHelpers]{Param}s} or a single \code{\link[ParamHelpers]{Param}}.
#' For further information on these, please refer to the documentation of the
#' \code{ParamHelpers} package.
#'
#' @param obj [\code{\link{Learner}} | \code{character(1)}]\cr
#' The learner. If you pass a string the learner will be created via
#' \code{\link{makeLearner}}. Expressions within \code{length}, \code{lower}
#' or \code{upper} boundaries, \code{default} or \code{value} will be
#' evaluated using the provided dictionary (\code{dict}).
#' @param dict [\code{environment} | \code{list} | \code{NULL}]\cr
#' Environment or list which will be used for evaluating the variables
#' of expressions within a parameter, parameter set or list of parameters.
#' The default is \code{NULL}.
#' @return [\code{\link{Learner}}].
#' @export
#' @examples
#' ## (1) evaluation of a learner's hyperparameters
#' task = makeClassifTask(data = iris, target = "Species")
#' dict = getTaskDictionary(task = task)
#' lrn1 = makeLearner("classif.rpart", minsplit = expression(k * p),
#' minbucket = expression(3L + 4L * task$task.desc$has.blocking))
#' lrn2 = evaluateParamExpressions(obj = lrn1, dict = dict)
#'
#' getHyperPars(lrn1)
#' getHyperPars(lrn2)
#'
#' ## (2) evaluation of a learner's entire parameter set
#' task = makeClassifTask(data = iris, target = "Species")
#' dict = getTaskDictionary(task = task)
#' lrn1 = makeLearner("classif.randomForest")
#' lrn2 = evaluateParamExpressions(obj = lrn1, dict = dict)
#'
#' ## Note the values for parameters 'mtry', 'classwt' and 'cutoff'
#' lrn1$par.set
#' lrn2$par.set
#'
#' ## (3) evaluation of a parameter set
#' task = makeClassifTask(data = iris, target = "Species")
#' dict = getTaskDictionary(task = task)
#' ps1 = makeParamSet(
#' makeNumericParam("C", lower = expression(k), upper = expression(n), trafo = function(x) 2^x),
#' makeDiscreteParam("sigma", values = expression(list(k, p)))
#' )
#' ps2 = evaluateParamExpressions(obj = ps1, dict = dict)
#'
#' ps1
#' ps2
evaluateParamExpressions.Learner = function(obj, dict = NULL) {
obj = checkLearner(obj)
if (hasExpression(obj)) {
assertList(dict, null.ok = TRUE)
obj$par.set = evaluateParamExpressions(obj = obj$par.set, dict = dict)
obj$par.vals = evaluateParamExpressions(obj = obj$par.vals, dict = dict)
}
return(obj)
}
2 changes: 2 additions & 0 deletions R/makeLearner.R
@@ -43,10 +43,12 @@
#' @return [\code{\link{Learner}}].
#' @family learner
#' @export
#' @note Learners can contain task-dependent expressions; see \code{\link{evaluateParamExpressions}} for more information.
#' @aliases Learner
#' @examples
#' makeLearner("classif.rpart")
#' makeLearner("classif.lda", predict.type = "prob")
#' makeLearner("classif.rpart", minsplit = expression(k))
#' lrn = makeLearner("classif.lda", method = "t", nu = 10)
#' print(lrn$par.vals)
makeLearner = function(cl, id = cl, predict.type = "response", predict.threshold = NULL,
6 changes: 4 additions & 2 deletions R/setHyperPars.R
@@ -11,14 +11,16 @@
#' @note If a named (hyper)parameter can't be found for the given learner, the 3
#' closest (hyper)parameter names will be output in case the user mistyped.
#' @export
#' @note Learners can contain task-dependent expressions; see \code{\link{evaluateParamExpressions}} for more information.
#' @family learner
#' @importFrom utils adist
#' @examples
#' cl1 = makeLearner("classif.ksvm", sigma = 1)
#' cl2 = setHyperPars(cl1, sigma = 10, par.vals = list(C = 2))
#' cl3 = setHyperPars(cl2, C = expression(round(n / p)))
#' print(cl1)
#' # note the now set and altered hyperparameters:
#' print(cl2)
#' print(cl3)
setHyperPars = function(learner, ..., par.vals = list()) {
args = list(...)
assertClass(learner, classes = "Learner")
@@ -73,7 +75,7 @@ setHyperPars2.Learner = function(learner, par.vals) {
learner$par.set$pars[[n]] = makeUntypedLearnerParam(id = n)
learner$par.vals[[n]] = p
} else {
if (on.par.out.of.bounds != "quiet" && !isFeasible(pd, p)) {
if (on.par.out.of.bounds != "quiet" && !isFeasible(pd, p) && !is.expression(p)) {
msg = sprintf("%s is not feasible for parameter '%s'!", convertToShortString(p), pd$id)
if (on.par.out.of.bounds == "stop") {
stop(msg)
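The added !is.expression(p) condition means feasibility is not checked for expression values at set time, since they can only be judged after evaluation against a task. A sketch of the effect (mirroring the new roxygen example above):

cl = makeLearner("classif.ksvm")
# Accepted without an out-of-bounds warning: the expression cannot be checked
# against the parameter's bounds until a task dictionary is available.
cl = setHyperPars(cl, C = expression(round(n / p)))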
4 changes: 4 additions & 0 deletions R/train.R
@@ -31,6 +31,10 @@
train = function(learner, task, subset, weights = NULL) {
learner = checkLearner(learner)
assertClass(task, classes = "Task")
if (hasExpression(learner)) {
dict = getTaskDictionary(task = task)
learner = evaluateParamExpressions(obj = learner, dict = dict)
}
if (missing(subset)) {
subset = seq_len(getTaskSize(task))
} else {
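In effect, train() now resolves expression-valued hyperparameters transparently before fitting (a sketch; the resolved value follows from iris, where n = 150):

lrn = makeLearner("classif.rpart", minsplit = expression(round(n / 10)))
mod = train(lrn, iris.task)  # fitted with minsplit = round(150 / 10) = 15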
22 changes: 19 additions & 3 deletions R/tuneParams.R
@@ -22,7 +22,8 @@
#' @param par.set [\code{\link[ParamHelpers]{ParamSet}}]\cr
#' Collection of parameters and their constraints for optimization.
#' Dependent parameters with a \code{requires} field must use \code{quote} and not
#' \code{expression} to define it.
#'   \code{expression} to define it. Task-dependent parameters, on the other hand,
#' need to be defined with expressions.
#' @param control [\code{\link{TuneControl}}]\cr
#' Control object for search method. Also selects the optimization algorithm for tuning.
#' @template arg_showinfo
@@ -31,6 +32,8 @@
#' @note If you would like to include results from the training data set, make
#' sure to appropriately adjust the resampling strategy and the aggregation for
#' the measure. See example code below.
#' Also note that learners and parameter sets can contain task-dependent
#' expressions; see \code{\link{evaluateParamExpressions}} for more information.
#' @export
#' @examples
#' # a grid search for an SVM (with a tiny number of points...)
@@ -50,6 +53,16 @@
#' print(head(generateHyperParsEffectData(res)))
#' print(head(generateHyperParsEffectData(res, trafo = TRUE)))
#'
#' # tuning the parameters 'C' and 'sigma' of an SVM, where the boundaries
#' # of 'sigma' depend on the number of features
#' ps = makeParamSet(
#' makeNumericLearnerParam("sigma", lower = expression(0.2 * p), upper = expression(2.5 * p)),
#' makeDiscreteLearnerParam("C", values = 2^c(-1, 1))
#' )
#' rdesc = makeResampleDesc("Subsample")
#' ctrl = makeTuneControlRandom(maxit = 2L)
#' res = tuneParams("classif.ksvm", iris.task, par.set = ps, control = ctrl, resampling = rdesc)
#'
#' \dontrun{
#' # we optimize the SVM over 3 kernels simultaneously
#' # note how we use dependent params (requires = ...) and iterated F-racing here
@@ -81,6 +94,11 @@ tuneParams = function(learner, task, resampling, measures, par.set, control, sho
assertClass(task, classes = "Task")
measures = checkMeasures(measures, learner)
assertClass(par.set, classes = "ParamSet")
if (hasExpression(learner) || hasExpression(par.set)) {
dict = getTaskDictionary(task = task)
learner = evaluateParamExpressions(obj = learner, dict = dict)
par.set = evaluateParamExpressions(obj = par.set, dict = dict)
}
assertClass(control, classes = "TuneControl")
if (!inherits(resampling, "ResampleDesc") && !inherits(resampling, "ResampleInstance"))
stop("Argument resampling must be of class ResampleDesc or ResampleInstance!")
@@ -113,5 +131,3 @@ tuneParams = function(learner, task, resampling, measures, par.set, control, sho
messagef("[Tune] Result: %s : %s", paramValueToString(par.set, or$x), perfsToString(or$y))
return(or)
}


6 changes: 6 additions & 0 deletions man/LearnerProperties.Rd

79 changes: 79 additions & 0 deletions man/evaluateParamExpressions.Learner.Rd

1 change: 1 addition & 0 deletions man/getTaskClassLevels.Rd

1 change: 1 addition & 0 deletions man/getTaskCosts.Rd

1 change: 1 addition & 0 deletions man/getTaskData.Rd
