diff --git a/.gitignore b/.gitignore index 0f6767c2a..5187a05a0 100644 --- a/.gitignore +++ b/.gitignore @@ -17,4 +17,5 @@ CRAN-SUBMISSION paper/data .idea/ .vsc/ -paper/data \ No newline at end of file +paper/data +.vscode diff --git a/R/LearnerTorch.R b/R/LearnerTorch.R index 1b9672b89..128bb66be 100644 --- a/R/LearnerTorch.R +++ b/R/LearnerTorch.R @@ -532,7 +532,20 @@ LearnerTorch = R6Class("LearnerTorch", "worker_packages" ) args = param_vals[names(param_vals) %in% dl_args] - for(param_name in c("sampler", "batch_sampler")){ + + ok = sum(is.null(args$sampler), is.null(args$batch_sampler), is.null(args$batch_size)) == 2L + + if (!ok) { + stopf("Provide either 'sampler', 'batch_sampler', or 'batch_size'.") + } + + if (is.null(args$batch_size)) { + if (!is.null(args$shuffle) || !is.null(args$drop_last)) { + stopf("'shuffle' and 'drop_last' are only allowed when 'batch_size' is provided.") + } + } + + for (param_name in c("sampler", "batch_sampler")){ param_val <- args[[param_name]] if (!is.null(param_val)) { # instantiate these params which should be classes. @@ -542,6 +555,9 @@ LearnerTorch = R6Class("LearnerTorch", invoke(dataloader, dataset = dataset, .args = args) }, .dataloader_predict = function(dataset, param_vals) { + if (is.null(param_vals$batch_size)) { + stop("'batch_size' must be provided for prediction.") + } param_vals_test = insert_named(param_vals, list(shuffle = FALSE, drop_last = FALSE)) private$.dataloader(dataset, param_vals_test) }, diff --git a/R/paramset_torchlearner.R b/R/paramset_torchlearner.R index c7a474e6b..3c4e4fab0 100644 --- a/R/paramset_torchlearner.R +++ b/R/paramset_torchlearner.R @@ -71,10 +71,14 @@ paramset_torchlearner = function(task_type, jittable = FALSE) { patience = p_int(lower = 0L, tags = c("train", "required"), init = 0L), min_delta = p_dbl(lower = 0, tags = c("train", "required"), init = 0), # dataloader parameters - batch_size = p_int(tags = c("train", "predict", "required"), lower = 1L), + batch_size = p_int(tags = c("train", "predict"), lower = 1L), shuffle = p_lgl(tags = "train", default = FALSE, init = TRUE), - sampler = p_uty(tags = c("train", "predict")), - batch_sampler = p_uty(tags = c("train", "predict")), + sampler = p_uty(tags = c("train", "predict"), custom_check = crate(function(x) { + checkmate::check_class(x, "torch_sampler") + })), + batch_sampler = p_uty(tags = c("train", "predict"), custom_check = crate(function(x) { + checkmate::check_class(x, "torch_sampler") + })), num_workers = p_int(lower = 0, default = 0, tags = c("train", "predict")), collate_fn = p_uty(tags = c("train", "predict"), default = NULL), pin_memory = p_lgl(default = FALSE, tags = c("train", "predict")),