Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Imports:
cli,
dplyr (>= 1.1.0),
generics (>= 0.1.2),
hardhat (>= 1.3.0),
hardhat (>= 1.4.2.9000),
lifecycle (>= 1.0.3),
rlang (>= 1.1.4),
tibble,
Expand Down Expand Up @@ -116,6 +116,7 @@ Collate:
'prob-roc_aunp.R'
'prob-roc_aunu.R'
'prob-roc_curve.R'
'quant-weighted_interval_score.R'
'reexports.R'
'surv-brier_survival.R'
'surv-brier_survival_integrated.R'
Expand All @@ -125,3 +126,5 @@ Collate:
'template.R'
'validation.R'
'yardstick-package.R'
Remotes:
tidymodels/hardhat
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ S3method(tidy,conf_mat)
S3method(validate_truth_estimate_types,default)
S3method(validate_truth_estimate_types,factor)
S3method(validate_truth_estimate_types,numeric)
S3method(weighted_interval_score,data.frame)
export(accuracy)
export(accuracy_vec)
export(average_precision)
Expand All @@ -150,6 +151,7 @@ export(check_linear_pred_survival_metric)
export(check_numeric_metric)
export(check_ordered_prob_metric)
export(check_prob_metric)
export(check_quantile_metric)
export(check_static_survival_metric)
export(class_metric_summarizer)
export(classification_cost)
Expand Down Expand Up @@ -213,6 +215,7 @@ export(new_linear_pred_survival_metric)
export(new_numeric_metric)
export(new_ordered_prob_metric)
export(new_prob_metric)
export(new_quantile_metric)
export(new_static_survival_metric)
export(npv)
export(npv_vec)
Expand All @@ -228,6 +231,7 @@ export(pr_curve)
export(precision)
export(precision_vec)
export(prob_metric_summarizer)
export(quantile_metric_summarizer)
export(ranked_prob_score)
export(ranked_prob_score_vec)
export(recall)
Expand Down Expand Up @@ -265,6 +269,8 @@ export(specificity_vec)
export(static_survival_metric_summarizer)
export(tidy)
export(validate_estimator)
export(weighted_interval_score)
export(weighted_interval_score_vec)
export(yardstick_any_missing)
export(yardstick_remove_missing)
import(rlang)
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@

* Added infrastructure for survival metrics on the linear predictor. (#551)

* Added infrastructure for quantile metrics. (#569)

* Added quantile metric `weighted_interval_score()`. (#569)

# yardstick 1.3.2

* All messages, warnings and errors has been translated to use {cli} package (#517, #522).
Expand Down
54 changes: 53 additions & 1 deletion R/aaa-metrics.R
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@ metric_set <- function(...) {
fn_cls %in% c("prob_metric", "class_metric", "ordered_prob_metric")
) {
make_prob_class_metric_function(fns)
} else if (fn_cls == "quantile_metric") {
make_quantile_metric_function(fns)
} else if (
fn_cls %in%
c(
Expand Down Expand Up @@ -663,6 +665,55 @@ make_survival_metric_function <- function(fns) {
metric_function
}

make_quantile_metric_function <- function(fns) {
metric_function <- function(
data,
truth,
estimate,
na_rm = TRUE,
case_weights = NULL,
...
) {
# Construct common argument set for each metric call
# Doing this dynamically inside the generated function means
# we capture the correct arguments
call_args <- quos(
data = data,
truth = !!enquo(truth),
estimate = !!enquo(estimate),
na_rm = na_rm,
case_weights = !!enquo(case_weights),
... = ...
)

# Construct calls from the functions + arguments
calls <- lapply(fns, call2, !!!call_args)

calls <- mapply(call_remove_static_arguments, calls, fns)

# Evaluate
metric_list <- mapply(
FUN = eval_safely,
calls, # .x
names(calls), # .y
SIMPLIFY = FALSE,
USE.NAMES = FALSE
)

dplyr::bind_rows(metric_list)
}

class(metric_function) <- c(
"quantile_metric_set",
"metric_set",
class(metric_function)
)

attr(metric_function, "metrics") <- fns

metric_function
}

validate_not_empty <- function(x, call = caller_env()) {
if (is_empty(x)) {
cli::cli_abort(
Expand Down Expand Up @@ -705,7 +756,8 @@ validate_function_class <- function(fns) {
"dynamic_survival_metric",
"static_survival_metric",
"integrated_survival_metric",
"linear_pred_survival_metric"
"linear_pred_survival_metric",
"quantile_metric"
)

if (n_unique == 1L) {
Expand Down
7 changes: 7 additions & 0 deletions R/aaa-new.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ new_linear_pred_survival_metric <- function(fn, direction) {
new_metric(fn, direction, class = "linear_pred_survival_metric")
}

#' @rdname new-metric
#' @export
new_quantile_metric <- function(fn, direction) {
new_metric(fn, direction, class = "quantile_metric")
}

#' @include import-standalone-types-check.R
new_metric <- function(fn, direction, class = NULL, call = caller_env()) {
check_function(fn, call = call)
Expand Down Expand Up @@ -128,6 +134,7 @@ format.metric <- function(x, ...) {
"static_survival_metric" = "static survival metric",
"integrated_survival_metric" = "integrated survival metric",
"linear_pred_survival_metric" = "linear predictor survival metric",
"quantile_metric" = "quantile metric",
"metric"
)

Expand Down
14 changes: 14 additions & 0 deletions R/check-metric.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#' - For `check_ordered_prob_metric()`, an ordered factor.
#' - For `check_dynamic_survival_metric()`, a Surv object.
#' - For `check_static_survival_metric()`, a Surv object.
#' - For `check_quantile_metric()`, a numeric vector.
#'
#' @param estimate The realized `estimate` result.
#' - For `check_numeric_metric()`, a numeric vector.
Expand All @@ -25,6 +26,7 @@
#' a numeric matrix for multic-class `truth`.
#' - For `check_dynamic_survival_metric()`, list-column of data.frames.
#' - For `check_static_survival_metric()`, a numeric vector.
#' - For `check_quantile_metric()`, a `hardhat::quantile_pred` vector.
#'
#' @param case_weights The realized case weights, as a numeric vector. This must
#' be the same length as `truth`.
Expand Down Expand Up @@ -132,3 +134,15 @@ check_linear_pred_survival_metric <- function(
validate_case_weights(case_weights, size = nrow(truth), call = call)
validate_surv_truth_numeric_estimate(truth, estimate, call = call)
}

#' @rdname check_metric
#' @export
check_quantile_metric <- function(
truth,
estimate,
case_weights,
call = caller_env()
) {
validate_numeric_truth_quantile_estimate(truth, estimate, call = call)
validate_case_weights(case_weights, size = length(truth), call = call)
}
Loading