Skip to content

Commit 416305d

Browse files
kerschkejakob-r
authored andcommitted
Expression handling (#1126)
* initial version of expression-learners * allow expressions within tuning param sets * updated documentation of expression-related files * fixing naming issues * further doc fixes * remove dict argument * updating man-pages * rm dict_template * better documentation of expression-related functions * removed ParamHelpers:: as it is not necessary * cleanup; added test * removed duplicated PH dep * fixes as requested in PR#1126 * use setLearnerId in tests * fix warnings in tests and r cmd check * added assertion for Task * docs, mini cleanup
1 parent c2d40a3 commit 416305d

32 files changed

+402
-16
lines changed

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ S3method(downsample,Task)
1515
S3method(estimateRelativeOverfitting,ResampleDesc)
1616
S3method(estimateResidualVariance,Learner)
1717
S3method(estimateResidualVariance,WrappedModel)
18+
S3method(evaluateParamExpressions,Learner)
1819
S3method(generateCalibrationData,BenchmarkResult)
1920
S3method(generateCalibrationData,Prediction)
2021
S3method(generateCalibrationData,ResampleResult)
@@ -96,6 +97,7 @@ S3method(getTaskTargetNames,TaskDescUnsupervised)
9697
S3method(getTaskTargets,CostSensTask)
9798
S3method(getTaskTargets,SupervisedTask)
9899
S3method(getTaskTargets,UnsupervisedTask)
100+
S3method(hasExpression,Learner)
99101
S3method(impute,Task)
100102
S3method(impute,data.frame)
101103
S3method(isFailureModel,BaseWrapperModel)
@@ -866,6 +868,7 @@ export(getTaskClassLevels)
866868
export(getTaskCosts)
867869
export(getTaskData)
868870
export(getTaskDescription)
871+
export(getTaskDictionary)
869872
export(getTaskFeatureNames)
870873
export(getTaskFormula)
871874
export(getTaskId)

R/Learner_properties.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,11 @@ listLearnerProperties = function(type = "any") {
7878
assertSubset(type, allProps)
7979
mlr$learner.properties[[type]]
8080
}
81+
82+
#' @param obj [\code{\link{Learner}} | \code{character(1)}]\cr
83+
#' Same as \code{learner} above.
84+
#' @rdname LearnerProperties
85+
#' @export
86+
hasExpression.Learner = function(obj) {
87+
any(hasExpression(obj$par.set)) || any(vlapply(obj$par.vals, is.expression))
88+
}

R/RLearner_classif_randomForest.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ makeRLearner.classif.randomForest = function() {
55
package = "randomForest",
66
par.set = makeParamSet(
77
makeIntegerLearnerParam(id = "ntree", default = 500L, lower = 1L),
8-
makeIntegerLearnerParam(id = "mtry", lower = 1L),
8+
makeIntegerLearnerParam(id = "mtry", lower = 1L, default = expression(floor(sqrt(p)))),
99
makeLogicalLearnerParam(id = "replace", default = TRUE),
10-
makeNumericVectorLearnerParam(id = "classwt", lower = 0),
11-
makeNumericVectorLearnerParam(id = "cutoff", lower = 0, upper = 1),
10+
makeNumericVectorLearnerParam(id = "classwt", lower = 0, len = expression(k)),
11+
makeNumericVectorLearnerParam(id = "cutoff", lower = 0, upper = 1, len = expression(k)),
1212
makeUntypedLearnerParam(id = "strata", tunable = FALSE),
1313
makeIntegerVectorLearnerParam(id = "sampsize", lower = 1L),
1414
makeIntegerLearnerParam(id = "nodesize", default = 1L, lower = 1L),

R/Task_operators.R

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,3 +454,32 @@ getTaskFactorLevels = function(task) {
454454
getTaskWeights = function(task) {
455455
task$weights
456456
}
457+
458+
459+
#' @title Create a dictionary based on the task.
460+
#'
461+
#' @description Returns a dictionary, which contains the \link{Task} itself
462+
#' (\code{task}), the number of features (\code{p}), the number of
463+
#' observations (\code{n}), the task type (\code{type}) and in case of
464+
#' classification tasks, the number of class levels (\code{k}).
465+
#'
466+
#' @template arg_task
467+
#' @return [\code{\link[base]{list}}]. Used for evaluating the expressions
468+
#' within a parameter, parameter set or list of parameters.
469+
#' @family task
470+
#' @export
471+
#' @examples
472+
#' task = makeClassifTask(data = iris, target = "Species")
473+
#' getTaskDictionary(task)
474+
getTaskDictionary = function(task) {
475+
assertClass(task, classes = "Task")
476+
dict = list(
477+
task = task,
478+
p = getTaskNFeats(task),
479+
n = getTaskSize(task),
480+
type = getTaskType(task)
481+
)
482+
if (dict$type == "classif")
483+
dict$k = length(getTaskClassLevels(task))
484+
return(dict)
485+
}

R/evaluateParamExpressions.R

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#' @title Evaluates expressions within a learner or parameter set.
2+
#'
3+
#' @description
4+
#' A \code{\link{Learner}} can contain unevaluated \code{\link[base]{expression}s}
5+
#' as value for a hyperparameter. E.g., these expressions are used if the default
6+
#' value depends on the task size or an upper limit for a parameter is given by
7+
#' the number of features in a task. \code{evaluateParamExpressions} allows to
8+
#' evaluate these expressions using a given dictionary, which holds the following
9+
#' information:
10+
#' \itemize{
11+
#' \item{\code{task}:} the task itself, allowing to access any of its elements.
12+
#' \item{\code{p}:} the number of features in the task
13+
#' \item{\code{n}:} the number of observations in the task
14+
#' \item{\code{type}:} the task type, i.e. "classif", "regr", "surv", "cluster", "costcens" or "multilabel"
15+
#' \item{\code{k}:} the number of classes of the target variable (only available for classification tasks)
16+
#' }
17+
#' Usually the evaluation of the expression is performed automatically, e.g. in
18+
#' \code{\link{train}} or \code{\link{tuneParams}}. Therefore calling
19+
#' \code{evaluateParamExpressions} manually should not be necessary.
20+
#' It is also possible to directly evaluate the expressions of a
21+
#' \code{\link[ParamHelpers]{ParamSet}}, \code{\link[base]{list}} of
22+
#' \code{\link[ParamHelpers]{Param}s} or single \code{\link[ParamHelpers]{Param}s}.
23+
#' For further information on these, please refer to the documentation of the
24+
#' \code{ParamHelpers} package.
25+
#'
26+
#' @param obj [\code{\link{Learner}}]\cr
27+
#' The learner. If you pass a string the learner will be created via
28+
#' \code{\link{makeLearner}}. Expressions within \code{length}, \code{lower}
29+
#' or \code{upper} boundaries, \code{default} or \code{value} will be
30+
#' evaluated using the provided dictionary (\code{dict}).
31+
#' @param dict [\code{environment} | \code{list} | \code{NULL}]\cr
32+
#' Environment or list which will be used for evaluating the variables
33+
#' of expressions within a parameter, parameter set or list of parameters.
34+
#' The default is \code{NULL}.
35+
#' @return [\code{\link{Learner}}].
36+
#' @export
37+
#' @examples
38+
#' ## (1) evaluation of a learner's hyperparameters
39+
#' task = makeClassifTask(data = iris, target = "Species")
40+
#' dict = getTaskDictionary(task = task)
41+
#' lrn1 = makeLearner("classif.rpart", minsplit = expression(k * p),
42+
#' minbucket = expression(3L + 4L * task$task.desc$has.blocking))
43+
#' lrn2 = evaluateParamExpressions(obj = lrn1, dict = dict)
44+
#'
45+
#' getHyperPars(lrn1)
46+
#' getHyperPars(lrn2)
47+
#'
48+
#' ## (2) evaluation of a learner's entire parameter set
49+
#' task = makeClassifTask(data = iris, target = "Species")
50+
#' dict = getTaskDictionary(task = task)
51+
#' lrn1 = makeLearner("classif.randomForest")
52+
#' lrn2 = evaluateParamExpressions(obj = lrn1, dict = dict)
53+
#'
54+
#' ## Note the values for parameters 'mtry', 'classwt' and 'cutoff'
55+
#' lrn1$par.set
56+
#' lrn2$par.set
57+
#'
58+
#' ## (3) evaluation of a parameter set
59+
#' task = makeClassifTask(data = iris, target = "Species")
60+
#' dict = getTaskDictionary(task = task)
61+
#' ps1 = makeParamSet(
62+
#' makeNumericParam("C", lower = expression(k), upper = expression(n), trafo = function(x) 2^x),
63+
#' makeDiscreteParam("sigma", values = expression(list(k, p)))
64+
#' )
65+
#' ps2 = evaluateParamExpressions(obj = ps1, dict = dict)
66+
#'
67+
#' ps1
68+
#' ps2
69+
evaluateParamExpressions.Learner = function(obj, dict = NULL) {
70+
obj = checkLearner(obj)
71+
if (hasExpression(obj)) {
72+
assertList(dict, null.ok = TRUE)
73+
obj$par.set = evaluateParamExpressions(obj = obj$par.set, dict = dict)
74+
obj$par.vals = evaluateParamExpressions(obj = obj$par.vals, dict = dict)
75+
}
76+
return(obj)
77+
}

R/makeLearner.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,12 @@
4343
#' @return [\code{\link{Learner}}].
4444
#' @family learner
4545
#' @export
46+
#' @note Learners can contain task dependent expressions, see \code{\link{evaluateParamExpressions}} for more information.
4647
#' @aliases Learner
4748
#' @examples
4849
#' makeLearner("classif.rpart")
4950
#' makeLearner("classif.lda", predict.type = "prob")
51+
#' makeLearner("classif.rpart", minsplit = expression(k))
5052
#' lrn = makeLearner("classif.lda", method = "t", nu = 10)
5153
#' print(lrn$par.vals)
5254
makeLearner = function(cl, id = cl, predict.type = "response", predict.threshold = NULL,

R/setHyperPars.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,16 @@
1111
#' @note If a named (hyper)parameter can't be found for the given learner, the 3
1212
#' closest (hyper)parameter names will be output in case the user mistyped.
1313
#' @export
14+
#' @note Learners can contain task dependent expressions, see \code{\link{evaluateParamExpressions}} for more information.
1415
#' @family learner
1516
#' @importFrom utils adist
1617
#' @examples
1718
#' cl1 = makeLearner("classif.ksvm", sigma = 1)
1819
#' cl2 = setHyperPars(cl1, sigma = 10, par.vals = list(C = 2))
20+
#' cl3 = setHyperPars(cl2, C = expression(round(n / p)))
1921
#' print(cl1)
20-
#' # note the now set and altered hyperparameters:
2122
#' print(cl2)
23+
#' print(cl3)
2224
setHyperPars = function(learner, ..., par.vals = list()) {
2325
args = list(...)
2426
assertClass(learner, classes = "Learner")
@@ -73,7 +75,7 @@ setHyperPars2.Learner = function(learner, par.vals) {
7375
learner$par.set$pars[[n]] = makeUntypedLearnerParam(id = n)
7476
learner$par.vals[[n]] = p
7577
} else {
76-
if (on.par.out.of.bounds != "quiet" && !isFeasible(pd, p)) {
78+
if (on.par.out.of.bounds != "quiet" && !isFeasible(pd, p) && !is.expression(p)) {
7779
msg = sprintf("%s is not feasible for parameter '%s'!", convertToShortString(p), pd$id)
7880
if (on.par.out.of.bounds == "stop") {
7981
stop(msg)

R/train.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@
3131
train = function(learner, task, subset, weights = NULL) {
3232
learner = checkLearner(learner)
3333
assertClass(task, classes = "Task")
34+
if (hasExpression(learner)) {
35+
dict = getTaskDictionary(task = task)
36+
learner = evaluateParamExpressions(obj = learner, dict = dict)
37+
}
3438
if (missing(subset)) {
3539
subset = seq_len(getTaskSize(task))
3640
} else {

R/tuneParams.R

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
#' @param par.set [\code{\link[ParamHelpers]{ParamSet}}]\cr
2323
#' Collection of parameters and their constraints for optimization.
2424
#' Dependent parameters with a \code{requires} field must use \code{quote} and not
25-
#' \code{expression} to define it.
25+
#' \code{expression} to define it. On the other hand, task dependent parameters
26+
#' need to be defined with expressions.
2627
#' @param control [\code{\link{TuneControl}}]\cr
2728
#' Control object for search method. Also selects the optimization algorithm for tuning.
2829
#' @template arg_showinfo
@@ -31,6 +32,8 @@
3132
#' @note If you would like to include results from the training data set, make
3233
#' sure to appropriately adjust the resampling strategy and the aggregation for
3334
#' the measure. See example code below.
35+
#' Also note that learners and parameter sets can contain task dependent
36+
#' expressions, see \code{\link{evaluateParamExpressions}} for more information.
3437
#' @export
3538
#' @examples
3639
#' # a grid search for an SVM (with a tiny number of points...)
@@ -50,6 +53,16 @@
5053
#' print(head(generateHyperParsEffectData(res)))
5154
#' print(head(generateHyperParsEffectData(res, trafo = TRUE)))
5255
#'
56+
#' # tuning the parameters 'C' and 'sigma' of an SVM, where the boundaries
57+
#' # of 'sigma' depend on the number of features
58+
#' ps = makeParamSet(
59+
#' makeNumericLearnerParam("sigma", lower = expression(0.2 * p), upper = expression(2.5 * p)),
60+
#' makeDiscreteLearnerParam("C", values = 2^c(-1, 1))
61+
#' )
62+
#' rdesc = makeResampleDesc("Subsample")
63+
#' ctrl = makeTuneControlRandom(maxit = 2L)
64+
#' res = tuneParams("classif.ksvm", iris.task, par.set = ps, control = ctrl, resampling = rdesc)
65+
#'
5366
#' \dontrun{
5467
#' # we optimize the SVM over 3 kernels simultanously
5568
#' # note how we use dependent params (requires = ...) and iterated F-racing here
@@ -81,6 +94,11 @@ tuneParams = function(learner, task, resampling, measures, par.set, control, sho
8194
assertClass(task, classes = "Task")
8295
measures = checkMeasures(measures, learner)
8396
assertClass(par.set, classes = "ParamSet")
97+
if (hasExpression(learner) || hasExpression(par.set)) {
98+
dict = getTaskDictionary(task = task)
99+
learner = evaluateParamExpressions(obj = learner, dict = dict)
100+
par.set = evaluateParamExpressions(obj = par.set, dict = dict)
101+
}
84102
assertClass(control, classes = "TuneControl")
85103
if (!inherits(resampling, "ResampleDesc") && !inherits(resampling, "ResampleInstance"))
86104
stop("Argument resampling must be of class ResampleDesc or ResampleInstance!")
@@ -113,5 +131,3 @@ tuneParams = function(learner, task, resampling, measures, par.set, control, sho
113131
messagef("[Tune] Result: %s : %s", paramValueToString(par.set, or$x), perfsToString(or$y))
114132
return(or)
115133
}
116-
117-

man/LearnerProperties.Rd

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)