From ca694f6fb99dd0aa0c468dadb022cdee06c7cdda Mon Sep 17 00:00:00 2001 From: Nic Crane Date: Tue, 17 Mar 2026 10:46:48 +0000 Subject: [PATCH] implement when all/any and add tests and build docs --- r/R/dplyr-funcs-conditional.R | 22 +++++++ r/R/dplyr-funcs-doc.R | 4 +- r/man/acero.Rd | 4 +- r/man/read_json_arrow.Rd | 2 +- r/man/schema.Rd | 2 +- .../testthat/test-dplyr-funcs-conditional.R | 66 +++++++++++++++++++ 6 files changed, 96 insertions(+), 4 deletions(-) diff --git a/r/R/dplyr-funcs-conditional.R b/r/R/dplyr-funcs-conditional.R index 25d7fbc668cf..ca7744e70f4c 100644 --- a/r/R/dplyr-funcs-conditional.R +++ b/r/R/dplyr-funcs-conditional.R @@ -99,6 +99,28 @@ register_bindings_conditional <- function() { out }) + register_binding("dplyr::when_any", function(..., na_rm = FALSE, size = NULL) { + if (!is.null(size)) { + arrow_not_supported("`when_any()` with `size` specified") + } + args <- list2(...) + if (na_rm) { + args <- lapply(args, function(x) call_binding("coalesce", x, FALSE)) + } + Reduce("|", args) + }) + + register_binding("dplyr::when_all", function(..., na_rm = FALSE, size = NULL) { + if (!is.null(size)) { + arrow_not_supported("`when_all()` with `size` specified") + } + args <- list2(...) + if (na_rm) { + args <- lapply(args, function(x) call_binding("coalesce", x, TRUE)) + } + Reduce("&", args) + }) + register_binding( "dplyr::case_when", function(..., .default = NULL, .ptype = NULL, .size = NULL) { diff --git a/r/R/dplyr-funcs-doc.R b/r/R/dplyr-funcs-doc.R index 9293d14c94c0..e0b3dd095c9f 100644 --- a/r/R/dplyr-funcs-doc.R +++ b/r/R/dplyr-funcs-doc.R @@ -21,7 +21,7 @@ #' #' The `arrow` package contains methods for 38 `dplyr` table functions, many of #' which are "verbs" that do transformations to one or more tables. -#' The package also has mappings of 224 R functions to the corresponding +#' The package also has mappings of 226 R functions to the corresponding #' functions in the Arrow compute library. These allow you to write code inside #' of `dplyr` methods that call R functions, including many in packages like #' `stringr` and `lubridate`, and they will get translated to Arrow and run @@ -214,6 +214,8 @@ #' * [`if_else()`][dplyr::if_else()] #' * [`n()`][dplyr::n()] #' * [`n_distinct()`][dplyr::n_distinct()] +#' * [`when_all()`][dplyr::when_all()] +#' * [`when_any()`][dplyr::when_any()] #' #' ## hms #' diff --git a/r/man/acero.Rd b/r/man/acero.Rd index ee156cc9129b..a43617493a33 100644 --- a/r/man/acero.Rd +++ b/r/man/acero.Rd @@ -9,7 +9,7 @@ \description{ The \code{arrow} package contains methods for 38 \code{dplyr} table functions, many of which are "verbs" that do transformations to one or more tables. -The package also has mappings of 224 R functions to the corresponding +The package also has mappings of 226 R functions to the corresponding functions in the Arrow compute library. These allow you to write code inside of \code{dplyr} methods that call R functions, including many in packages like \code{stringr} and \code{lubridate}, and they will get translated to Arrow and run @@ -207,6 +207,8 @@ Valid values are "s", "ms" (default), "us", "ns". \item \code{\link[dplyr:if_else]{if_else()}} \item \code{\link[dplyr:context]{n()}} \item \code{\link[dplyr:n_distinct]{n_distinct()}} +\item \code{\link[dplyr:when-any-all]{when_all()}} +\item \code{\link[dplyr:when-any-all]{when_any()}} } } diff --git a/r/man/read_json_arrow.Rd b/r/man/read_json_arrow.Rd index b809a63bcc6f..abf6b8fc44a8 100644 --- a/r/man/read_json_arrow.Rd +++ b/r/man/read_json_arrow.Rd @@ -54,7 +54,7 @@ If \code{schema} is not provided, Arrow data types are inferred from the data: \item JSON numbers convert to \code{\link[=int64]{int64()}}, falling back to \code{\link[=float64]{float64()}} if a non-integer is encountered. \item JSON strings of the kind "YYYY-MM-DD" and "YYYY-MM-DD hh:mm:ss" convert to \code{\link[=timestamp]{timestamp(unit = "s")}}, falling back to \code{\link[=utf8]{utf8()}} if a conversion error occurs. -\item JSON arrays convert to a \code{\link[=list_of]{list_of()}} type, and inference proceeds recursively on the JSON arrays' values. +\item JSON arrays convert to a \code{\link[vctrs:list_of]{vctrs::list_of()}} type, and inference proceeds recursively on the JSON arrays' values. \item Nested JSON objects convert to a \code{\link[=struct]{struct()}} type, and inference proceeds recursively on the JSON objects' values. } diff --git a/r/man/schema.Rd b/r/man/schema.Rd index 65ab2eea0d27..ff77a05d84aa 100644 --- a/r/man/schema.Rd +++ b/r/man/schema.Rd @@ -7,7 +7,7 @@ schema(...) } \arguments{ -\item{...}{\link[=field]{fields}, field name/\link[=data-type]{data type} pairs (or a list of), or object from which to extract +\item{...}{\link[vctrs:fields]{fields}, field name/\link[=data-type]{data type} pairs (or a list of), or object from which to extract a schema} } \description{ diff --git a/r/tests/testthat/test-dplyr-funcs-conditional.R b/r/tests/testthat/test-dplyr-funcs-conditional.R index 58373db253fd..c966973835ad 100644 --- a/r/tests/testthat/test-dplyr-funcs-conditional.R +++ b/r/tests/testthat/test-dplyr-funcs-conditional.R @@ -517,3 +517,69 @@ test_that("external objects are found when they're not in the global environment tibble(x = c("a", "b"), x2 = c("foo", NA)) ) }) + +test_that("when_any()", { + # combines with OR + compare_dplyr_binding( + .input |> + mutate(result = when_any(lgl, false)) |> + collect(), + tbl + ) + + # na_rm=TRUE treats NA as FALSE + compare_dplyr_binding( + .input |> + mutate(result = when_any(lgl, false, na_rm = TRUE)) |> + collect(), + tbl + ) + + # works in filter() + compare_dplyr_binding( + .input |> + filter(when_any(int > 5, dbl > 3)) |> + collect(), + tbl + ) + + # size not supported + expect_arrow_eval_error( + when_any(lgl, false, size = 10), + "`when_any\\(\\)` with `size` specified not supported in Arrow", + class = "arrow_not_supported" + ) +}) + +test_that("when_all()", { + # combines with AND + compare_dplyr_binding( + .input |> + mutate(result = when_all(lgl, false)) |> + collect(), + tbl + ) + + # na_rm=TRUE treats NA as TRUE + compare_dplyr_binding( + .input |> + mutate(result = when_all(lgl, false, na_rm = TRUE)) |> + collect(), + tbl + ) + + # works in filter() + compare_dplyr_binding( + .input |> + filter(when_all(int > 5, dbl > 3)) |> + collect(), + tbl + ) + + # size not supported + expect_arrow_eval_error( + when_all(lgl, false, size = 10), + "`when_all\\(\\)` with `size` specified not supported in Arrow", + class = "arrow_not_supported" + ) +})