Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion r/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ Imports:
utils,
vctrs
Roxygen: list(markdown = TRUE, r6 = FALSE, load = "source")
RoxygenNote: 7.3.3
RoxygenNote: 7.3.3.9000
Config/testthat/edition: 3
Config/build/bootstrap: TRUE
Suggests:
Expand Down
142 changes: 128 additions & 14 deletions r/R/dplyr-funcs-conditional.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.

# nolint start: cyclocomp_linter.
register_bindings_conditional <- function() {
register_binding("%in%", function(x, table) {
# We use `is_in` here, unlike with Arrays, which use `is_in_meta_binary`
Expand Down Expand Up @@ -134,21 +135,134 @@ register_bindings_conditional <- function() {
validation_error(paste0("`.default` must have size 1, not size ", length(.default), "."))
}

query[n + 1] <- TRUE
value[n + 1] <- .default
}
Expression$create(
"case_when",
args = c(
Expression$create(
"make_struct",
args = query,
options = list(field_names = as.character(seq_along(query)))
),
value
)
)
query[[n + 1]] <- TRUE
value[[n + 1]] <- .default
}
build_case_when_expr(query, value)
},
notes = "`.ptype` and `.size` arguments not supported"
)

register_binding("dplyr::replace_when", function(x, ...) {
formulas <- list2(...)
n <- length(formulas)
if (n == 0) {
return(x)
}
query <- vector("list", n + 1)
value <- vector("list", n + 1)
mask <- caller_env()
for (i in seq_len(n)) {
f <- formulas[[i]]
if (!inherits(f, "formula")) {
validation_error("Each argument to replace_when() must be a two-sided formula")
}
query[[i]] <- arrow_eval(f[[2]], mask)
value[[i]] <- arrow_eval(f[[3]], mask)
if (!call_binding("is.logical", query[[i]])) {
validation_error("Left side of each formula in replace_when() must be a logical expression")
}
}
query[[n + 1]] <- TRUE
value[[n + 1]] <- x
build_case_when_expr(query, value)
})

register_binding("dplyr::replace_values", function(x, ..., from = NULL, to = NULL) {
parsed <- parse_value_mapping(x, list2(...), from, to, caller_env(), "replace_values")
if (is.null(parsed)) {
return(x)
}
query <- parsed$query
value <- parsed$value
n <- length(query)
query[[n + 1]] <- TRUE
value[[n + 1]] <- x
build_case_when_expr(query, value)
})

register_binding(
"dplyr::recode_values",
function(x, ..., from = NULL, to = NULL, default = NULL, unmatched = "default", ptype = NULL) {
if (!is.null(ptype)) {
arrow_not_supported("`recode_values()` with `ptype` specified")
}
if (unmatched == "error") {
arrow_not_supported("`recode_values()` with `unmatched = \"error\"`")
}

parsed <- parse_value_mapping(x, list2(...), from, to, caller_env(), "recode_values")
if (is.null(parsed)) {
query <- list()
value <- list()
} else {
query <- parsed$query
value <- parsed$value
}

if (!is.null(default)) {
n <- length(query)
query[[n + 1]] <- TRUE
value[[n + 1]] <- Expression$scalar(default)
}
build_case_when_expr(query, value)
},
notes = "`ptype` argument and `unmatched = \"error\"` not supported"
)

# Create case_when Expression from query/value lists
build_case_when_expr <- function(query, value) {
Expression$create(
"case_when",
args = c(
Expression$create(
"make_struct",
args = query,
options = list(field_names = as.character(seq_along(query)))
),
value
)
)
}

# Parse value ~ replacement formulas or from/to vectors into query/value lists
# Used by replace_values and recode_values
parse_value_mapping <- function(x, formulas, from, to, mask, fn) {
if (length(formulas) > 0 && !is.null(from)) {
validation_error(paste0("Can't use both `...` and `from`/`to` in ", fn, "()"))
}

if (length(formulas) > 0) {
n <- length(formulas)
query <- vector("list", n)
value <- vector("list", n)
for (i in seq_len(n)) {
f <- formulas[[i]]
if (!inherits(f, "formula")) {
validation_error(paste0("Each argument to ", fn, "() must be a two-sided formula"))
}
lhs <- arrow_eval(f[[2]], mask)
rhs <- arrow_eval(f[[3]], mask)
query[[i]] <- x == lhs
value[[i]] <- rhs
}
list(query = query, value = value)
} else if (!is.null(from)) {
if (is.null(to)) {
validation_error("`to` must be provided when using `from`")
}
n <- length(from)
to <- vctrs::vec_recycle(to, n)
query <- vector("list", n)
value <- vector("list", n)
for (i in seq_len(n)) {
query[[i]] <- x == from[[i]]
value[[i]] <- Expression$scalar(to[[i]])
}
list(query = query, value = value)
} else {
NULL
}
}
}
# nolint end.
5 changes: 4 additions & 1 deletion r/R/dplyr-funcs-doc.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#'
#' The `arrow` package contains methods for 38 `dplyr` table functions, many of
#' which are "verbs" that do transformations to one or more tables.
#' The package also has mappings of 224 R functions to the corresponding
#' The package also has mappings of 227 R functions to the corresponding
#' functions in the Arrow compute library. These allow you to write code inside
#' of `dplyr` methods that call R functions, including many in packages like
#' `stringr` and `lubridate`, and they will get translated to Arrow and run
Expand Down Expand Up @@ -214,6 +214,9 @@
#' * [`if_else()`][dplyr::if_else()]
#' * [`n()`][dplyr::n()]
#' * [`n_distinct()`][dplyr::n_distinct()]
#' * [`recode_values()`][dplyr::recode_values()]: `ptype` argument and `unmatched = "error"` not supported
#' * [`replace_values()`][dplyr::replace_values()]
#' * [`replace_when()`][dplyr::replace_when()]
#'
#' ## hms
#'
Expand Down
5 changes: 4 additions & 1 deletion r/man/acero.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion r/man/read_json_arrow.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion r/man/schema.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

132 changes: 132 additions & 0 deletions r/tests/testthat/test-dplyr-funcs-conditional.R
Original file line number Diff line number Diff line change
Expand Up @@ -517,3 +517,135 @@ test_that("external objects are found when they're not in the global environment
tibble(x = c("a", "b"), x2 = c("foo", NA))
)
})

test_that("replace_when()", {
# replaces matching values, keeps original otherwise
compare_dplyr_binding(
.input |>
mutate(result = replace_when(int, int > 5 ~ 100L)) |>
collect(),
tbl
)

# multiple conditions
compare_dplyr_binding(
.input |>
mutate(result = replace_when(int, int > 7 ~ 100L, int < 3 ~ 0L)) |>
collect(),
tbl
)

# no formulas returns x unchanged
compare_dplyr_binding(
.input |>
mutate(result = replace_when(int)) |>
collect(),
tbl
)

# validation errors
expect_arrow_eval_error(
replace_when(int, TRUE),
"Each argument to replace_when\\(\\) must be a two-sided formula",
class = "validation_error"
)
expect_arrow_eval_error(
replace_when(int, 0L ~ 100L),
"Left side of each formula in replace_when\\(\\) must be a logical expression",
class = "validation_error"
)
})

test_that("replace_values()", {
# formula interface
compare_dplyr_binding(
.input |>
mutate(result = replace_values(chr, "a" ~ "A", "b" ~ "B")) |>
collect(),
tbl
)

# from/to interface
compare_dplyr_binding(
.input |>
mutate(result = replace_values(chr, from = c("a", "b"), to = c("A", "B"))) |>
collect(),
tbl
)

# unmatched values kept
compare_dplyr_binding(
.input |>
mutate(result = replace_values(chr, "a" ~ "A")) |>
collect(),
tbl
)

# no replacements returns x unchanged
compare_dplyr_binding(
.input |>
mutate(result = replace_values(chr)) |>
collect(),
tbl
)

# validation errors
expect_arrow_eval_error(
replace_values(chr, "a" ~ "A", from = "b"),
"Can't use both `...` and `from`/`to` in replace_values\\(\\)",
class = "validation_error"
)
expect_arrow_eval_error(
replace_values(chr, from = "a"),
"`to` must be provided when using `from`",
class = "validation_error"
)
})

test_that("recode_values()", {
# formula interface with default NA
compare_dplyr_binding(
.input |>
mutate(result = recode_values(chr, "a" ~ "A", "b" ~ "B")) |>
collect(),
tbl
)

# from/to interface
compare_dplyr_binding(
.input |>
mutate(result = recode_values(chr, from = c("a", "b"), to = c("A", "B"))) |>
collect(),
tbl
)

# custom default
compare_dplyr_binding(
.input |>
mutate(result = recode_values(chr, "a" ~ "A", default = "other")) |>
collect(),
tbl
)

# validation errors
expect_arrow_eval_error(
recode_values(chr, "a" ~ "A", from = "b"),
"Can't use both `...` and `from`/`to` in recode_values\\(\\)",
class = "validation_error"
)
expect_arrow_eval_error(
recode_values(chr, from = "a"),
"`to` must be provided when using `from`",
class = "validation_error"
)
expect_arrow_eval_error(
recode_values(chr, "a" ~ "A", ptype = character()),
"`recode_values\\(\\)` with `ptype` specified not supported in Arrow",
class = "arrow_not_supported"
)
expect_arrow_eval_error(
recode_values(chr, "a" ~ "A", unmatched = "error"),
"`recode_values\\(\\)` with `unmatched = \"error\"` not supported in Arrow",
class = "arrow_not_supported"
)
})
Loading