Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Depends:
R (>= 4.2.0)
Imports:
cli,
lifecycle,
rlang,
stats,
tidyselect,
Expand Down
42 changes: 21 additions & 21 deletions R/ipw.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#' variable.
#' @param outcome_mod A fitted, weighted outcome model of class [stats::glm()]
#' or [stats::lm()], with the outcome as the dependent variable.
#' @param .df A data frame containing the exposure, outcome, and covariates. If
#' @param .data A data frame containing the exposure, outcome, and covariates. If
#' `NULL`, `ipw()` will try to extract the data from `ps_mod` and
#' `outcome_mod`.
#' @param estimand A character string specifying the causal estimand: `ate`,
Expand Down Expand Up @@ -95,7 +95,7 @@
ipw <- function(
ps_mod,
outcome_mod,
.df = NULL,
.data = NULL,
estimand = NULL,
ps_link = NULL,
conf_level = 0.95
Expand All @@ -107,19 +107,19 @@ ipw <- function(
exposure_name <- fmla_extract_left_chr(ps_mod)
outcome_name <- fmla_extract_left_chr(outcome_mod)

if (is.null(.df)) {
if (is.null(.data)) {
exposure <- fmla_extract_left_vctr(ps_mod)
outcome <- fmla_extract_left_vctr(outcome_mod)
} else {
assert_class(exposure_name, "character", .length = 1)
assert_class(outcome_name, "character", .length = 1)
assert_columns_exist(.df, c(exposure_name, outcome_name))
assert_columns_exist(.data, c(exposure_name, outcome_name))

exposure <- .df[[exposure_name]]
outcome <- .df[[outcome_name]]
exposure <- .data[[exposure_name]]
outcome <- .data[[outcome_name]]
}

ps <- predict(ps_mod, type = "response", newdata = .df)
ps <- predict(ps_mod, type = "response", newdata = .data)

if (is.null(ps_link)) {
ps_link <- ps_mod$family$link
Expand All @@ -143,7 +143,7 @@ ipw <- function(
wts = wts,
exposure = exposure,
exposure_name = exposure_name,
.df = .df
.data = .data
)

uncorrected_lin_vars <- linearize_variables_for_wts(
Expand Down Expand Up @@ -503,13 +503,13 @@ estimate_marginal_means <- function(
wts,
exposure,
exposure_name,
.df = NULL,
.data = NULL,
call = rlang::caller_env()
) {
# todo: this could be generalized with split() and lapply()
if (is.null(.df)) {
.df <- model.frame(outcome_mod)
check_exposure(.df, exposure_name, call = call)
if (is.null(.data)) {
.data <- model.frame(outcome_mod)
check_exposure(.data, exposure_name, call = call)
}
# todo: make this more flexible for different values and model specs
# maybe can optionally accept a function for g-comp
Expand All @@ -523,16 +523,16 @@ estimate_marginal_means <- function(
))
}

.df_1 <- .df
.df_0 <- .df
.df_1[[exposure_name]] <- exposure_values[[2]]
.df_0[[exposure_name]] <- exposure_values[[1]]
.data_1 <- .data
.data_0 <- .data
.data_1[[exposure_name]] <- exposure_values[[2]]
.data_0[[exposure_name]] <- exposure_values[[1]]

n1 <- sum(wts[exposure == exposure_values[[2]]])
mu1 <- mean(predict(outcome_mod, newdata = .df_1, type = "response"))
mu1 <- mean(predict(outcome_mod, newdata = .data_1, type = "response"))

n0 <- sum(wts[exposure == exposure_values[[1]]])
mu0 <- mean(predict(outcome_mod, newdata = .df_0, type = "response"))
mu0 <- mean(predict(outcome_mod, newdata = .data_0, type = "response"))

list(
# exposure = 1
Expand Down Expand Up @@ -656,14 +656,14 @@ check_estimand <- function(wts, estimand, call = rlang::caller_env()) {
}
}

check_exposure <- function(.df, .exposure_name, call = rlang::caller_env()) {
check_exposure <- function(.data, .exposure_name, call = rlang::caller_env()) {
assert_class(.exposure_name, "character", .length = 1, call = call)
if (!(.exposure_name %in% names(.df))) {
if (!(.exposure_name %in% names(.data))) {
abort(
c(
"{.val { .exposure_name}} not found in {.code model.frame(outcome_mod)}.",
i = "The outcome model may have transformations in the formula.",
i = "Please specify {.arg .df}"
i = "Please specify {.arg .data}"
),
call = call,
error_class = "propensity_columns_exist_error"
Expand Down
61 changes: 38 additions & 23 deletions R/ps_calibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#' preserves the attributes of causal weight objects when applicable.
#'
#' @param ps Numeric vector of propensity scores between 0 and 1
#' @param treat A binary vector of treatment assignments
#' @param .exposure A binary vector of treatment assignments
#' @param method Calibration method:
#' \describe{
#' \item{`"logistic"`}{(Default) Logistic calibration (also known as Platt scaling).
Expand All @@ -24,39 +24,54 @@
#' regression (`smooth = FALSE`). When `TRUE`, uses `mgcv::gam()` with
#' spline smoothing. When `FALSE`, uses `stats::glm()`. Ignored for
#' `method = "isoreg"`.
#' @param .treated The value representing the treatment group. If not provided,
#' `ps_calibrate()` will attempt to automatically determine the treatment coding.
#' @param .untreated The value representing the control group. If not provided,
#' `ps_calibrate()` will attempt to automatically determine the control coding.
#' @param .focal_level The value representing the focal group (typically treatment).
#' If not provided, `ps_calibrate()` will attempt to automatically determine the coding.
#' @param .reference_level The value representing the reference group (typically control).
#' If not provided, `ps_calibrate()` will attempt to automatically determine the coding.
#' @param .treated `r lifecycle::badge("deprecated")` Use `.focal_level` instead.
#' @param .untreated `r lifecycle::badge("deprecated")` Use `.reference_level` instead.
#' @param estimand Character indicating the estimand type.
#'
#' @return A calibrated propensity score object (`psw`)
#'
#' @examples
#' # Generate example data
#' ps <- runif(100)
#' treat <- rbinom(100, 1, ps)
#' exposure <- rbinom(100, 1, ps)
#'
#' # Logistic calibration with smoothing (default)
#' calibrated_smooth <- ps_calibrate(ps, treat)
#' calibrated_smooth <- ps_calibrate(ps, exposure)
#'
#' # Logistic calibration without smoothing (simple logistic regression)
#' calibrated_simple <- ps_calibrate(ps, treat, smooth = FALSE)
#' calibrated_simple <- ps_calibrate(ps, exposure, smooth = FALSE)
#'
#' # Isotonic regression
#' calibrated_iso <- ps_calibrate(ps, treat, method = "isoreg")
#' calibrated_iso <- ps_calibrate(ps, exposure, method = "isoreg")
#' @importFrom stats glm fitted isoreg binomial
#' @export
ps_calibrate <- function(
ps,
treat,
.exposure,
method = c("logistic", "isoreg"),
smooth = TRUE,
.focal_level = NULL,
.reference_level = NULL,
estimand = NULL,
.treated = NULL,
.untreated = NULL,
estimand = NULL
.untreated = NULL
) {
method <- rlang::arg_match(method)

# Handle deprecation
focal_params <- handle_focal_deprecation(
.focal_level,
.reference_level,
.treated,
.untreated,
"ps_calibrate"
)
.focal_level <- focal_params$.focal_level
.reference_level <- focal_params$.reference_level
# Check that ps is numeric and in valid range
if (!is.numeric(ps)) {
abort(
Expand All @@ -83,15 +98,15 @@ ps_calibrate <- function(
}

# Transform treatment to binary if needed
treat <- transform_exposure_binary(
treat,
.treated = .treated,
.untreated = .untreated
.exposure <- transform_exposure_binary(
.exposure,
.focal_level = .focal_level,
.reference_level = .reference_level
)

if (length(ps) != length(treat)) {
if (length(ps) != length(.exposure)) {
abort(
"Propensity score vector `ps` must be the same length as `treat`.",
"Propensity score vector `ps` must be the same length as `.exposure`.",
error_class = "propensity_length_error"
)
}
Expand All @@ -114,7 +129,7 @@ ps_calibrate <- function(
}

# Handle NA values
na_idx <- is.na(ps) | is.na(treat)
na_idx <- is.na(ps) | is.na(.exposure)

# Perform calibration based on method
calib_model <- NULL
Expand All @@ -128,7 +143,7 @@ ps_calibrate <- function(

# Create data frame for GAM fitting (only non-NA values)
calib_data <- data.frame(
treat = treat[!na_idx],
treat = .exposure[!na_idx],
ps = ps[!na_idx]
)

Expand Down Expand Up @@ -156,7 +171,7 @@ ps_calibrate <- function(
# For simple logistic regression, fit on original data (not data frame)
# This handles the case where smooth was originally FALSE or was set to FALSE due to fallback
if (is.null(calib_model)) {
calib_model <- stats::glm(treat ~ ps, family = stats::binomial())
calib_model <- stats::glm(.exposure ~ ps, family = stats::binomial())
}
}

Expand Down Expand Up @@ -191,7 +206,7 @@ ps_calibrate <- function(
if (any(na_idx)) {
# Work with non-NA values only
ps_valid <- ps[!na_idx]
treat_valid <- treat[!na_idx]
treat_valid <- .exposure[!na_idx]

# Order by propensity scores for isotonic regression
ord <- order(ps_valid)
Expand All @@ -214,7 +229,7 @@ ps_calibrate <- function(
# No NAs, proceed normally
ord <- order(ps)
ps_ordered <- ps[ord]
treat_ordered <- treat[ord]
treat_ordered <- .exposure[ord]

# Fit isotonic regression
iso_fit <- stats::isoreg(ps_ordered, treat_ordered)
Expand Down
Loading
Loading