Skip to content

Add residualize_over_grid() #386

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 26 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
e974402
Add `residualize_over_grid()`
strengejacke Feb 9, 2025
ff97ef2
docs
strengejacke Feb 9, 2025
12d2711
plot
strengejacke Feb 9, 2025
f3b0ad0
Merge branch 'main' into residualize_over_grid
strengejacke Feb 10, 2025
542dd31
Merge branch 'main' into residualize_over_grid
strengejacke Feb 10, 2025
b88e2d2
Merge branch 'main' into residualize_over_grid
strengejacke Feb 12, 2025
38caff7
Merge branch 'main' into residualize_over_grid
strengejacke Feb 14, 2025
a211dbb
Merge branch 'main' into residualize_over_grid
strengejacke Feb 14, 2025
ff4cdd4
Merge branch 'main' into residualize_over_grid
strengejacke Feb 14, 2025
8e3a88c
Merge branch 'main' into residualize_over_grid
strengejacke Feb 18, 2025
732f756
Merge branch 'main' into residualize_over_grid
strengejacke Feb 19, 2025
f60b84b
Merge branch 'main' into residualize_over_grid
strengejacke Feb 20, 2025
29dc7fb
Merge branch 'main' into residualize_over_grid
strengejacke Feb 20, 2025
c5b5c6d
Merge branch 'main' into residualize_over_grid
strengejacke Feb 20, 2025
59fdede
Merge branch 'main' into residualize_over_grid
strengejacke Feb 20, 2025
21edeb0
Merge branch 'main' into residualize_over_grid
strengejacke Feb 20, 2025
136beb3
Merge branch 'main' into residualize_over_grid
strengejacke Feb 22, 2025
996b693
Merge branch 'main' into residualize_over_grid
strengejacke Feb 23, 2025
42f4a6e
Merge branch 'main' into residualize_over_grid
strengejacke Mar 1, 2025
1b8403b
Merge branch 'main' into residualize_over_grid
DominiqueMakowski Mar 3, 2025
0be044b
Merge branch 'main' into residualize_over_grid
strengejacke Mar 4, 2025
ffbe932
Merge branch 'main' into residualize_over_grid
strengejacke Mar 12, 2025
cb007a1
Merge branch 'main' into residualize_over_grid
strengejacke Mar 13, 2025
6c2985e
Merge branch 'main' into residualize_over_grid
strengejacke Mar 20, 2025
a640d90
Merge branch 'main' into residualize_over_grid
strengejacke Mar 26, 2025
29c4b29
Merge branch 'main' into residualize_over_grid
strengejacke Apr 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ S3method(print_md,estimate_smooth)
S3method(print_md,visualisation_matrix)
S3method(reshape_grouplevel,data.frame)
S3method(reshape_grouplevel,estimate_grouplevel)
S3method(residualize_over_grid,data.frame)
S3method(residualize_over_grid,estimate_means)
S3method(smoothing,data.frame)
S3method(smoothing,numeric)
S3method(standardize,estimate_contrasts)
Expand Down Expand Up @@ -86,6 +88,7 @@ export(pool_slopes)
export(print_html)
export(print_md)
export(reshape_grouplevel)
export(residualize_over_grid)
export(smoothing)
export(standardize)
export(unstandardize)
Expand Down
171 changes: 171 additions & 0 deletions R/residualize_over_grid.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
#' @title Compute partial residuals from a data grid
#' @name residualize_over_grid
#'
#' @description This function computes partial residuals based on a data grid,
#' where the data grid is usually a data frame from all combinations of factor
#' variables or certain values of numeric vectors. This data grid is usually used
#' as `newdata` argument in `predict()`, and can be created with
#' [`insight::get_datagrid()`].
#'
#' @param grid A data frame representing the data grid, or an object of class
#' `estimate_means` or `estimate_predicted`, as returned by the different
#' `estimate_*()` functions.
#' @param model The model for which to compute partial residuals. The data grid
#' `grid` should match to predictors in the model.
#' @param predictor_name The name of the focal predictor, for which partial residuals
#' are computed.
#' @param ... Currently not used.
#'
#' @section Partial Residuals:
#' For **generalized linear models** (glms), residualized scores are computed as
#' `inv.link(link(Y) + r)` where `Y` are the predicted values on the response
#' scale, and `r` are the *working* residuals.
#'
#' For (generalized) linear **mixed models**, the random effect are also
#' partialled out.
#'
#' @references
#' Fox J, Weisberg S. Visualizing Fit and Lack of Fit in Complex Regression
#' Models with Predictor Effect Plots and Partial Residuals. Journal of
#' Statistical Software 2018;87.
#'
#' @return A data frame with residuals for the focal predictor.
#'
#' @examplesIf requireNamespace("marginaleffects", quietly = TRUE)
#' set.seed(1234)
#' x1 <- rnorm(200)
#' x2 <- rnorm(200)
#' # quadratic relationship
#' y <- 2 * x1 + x1^2 + 4 * x2 + rnorm(200)
#'
#' d <- data.frame(x1, x2, y)
#' model <- lm(y ~ x1 + x2, data = d)
#'
#' pr <- estimate_means(model, c("x1", "x2"))
#' head(residualize_over_grid(pr, model))
#' @export
residualize_over_grid <- function(grid, model, ...) {
UseMethod("residualize_over_grid")

}


#' @rdname residualize_over_grid
#' @export
residualize_over_grid.data.frame <- function(grid, model, predictor_name, ...) {

old_d <- insight::get_predictors(model)
fun_link <- insight::link_function(model)
inv_fun <- insight::link_inverse(model)
predicted <- grid[[predictor_name]]
grid[[predictor_name]] <- NULL

is_fixed <- sapply(grid, function(x) length(unique(x))) == 1
grid <- grid[, !is_fixed, drop = FALSE]
old_d <- old_d[, colnames(grid)[colnames(grid) %in% colnames(old_d)], drop = FALSE]

if (!.is_grid(grid)) {
insight::format_error("Grid for partial residuals must be a fully crossed grid.")
}

# for each var
best_match <- NULL

for (p in colnames(old_d)) {
if (is.factor(old_d[[p]]) || is.logical(old_d[[p]]) || is.character(old_d[[p]])) {
grid[[p]] <- as.character(grid[[p]])
old_d[[p]] <- as.character(old_d[[p]])
} else {
grid[[p]] <- .validate_num(grid[[p]])
}

# if factor / logical / char in old data, find where it is equal
# if numeric in old data, find where it is closest
best_match <- .closest(old_d[[p]], grid[[p]], best_match = best_match)
}

idx <- apply(best_match, 2, which)
idx <- sapply(idx, "[", 1)

# extract working residuals
res <- .safe(stats::residuals(model, type = "working"))

# if failed, and model linear, extract response residuals
if (is.null(res)) {
minfo <- insight::model_info(model)
if (minfo$is_linear) {
res <- .safe(insight::get_residuals(model, type = "response"))
}
}

if (is.null(res)) {
insight::format_alert("Could not extract residuals.")
return(NULL)
}

my_points <- grid[idx, , drop = FALSE]
my_points[[predictor_name]] <- inv_fun(fun_link(predicted[idx]) + res) # add errors

my_points
}


#' @export
residualize_over_grid.estimate_means <- function(grid, model, ...) {
new_d <- as.data.frame(grid)

relevant_columns <- unique(c(
attributes(grid)$trend,
attributes(grid)$contrast,
attributes(grid)$focal_terms,
attributes(grid)$coef_name
))

new_d <- new_d[colnames(new_d) %in% relevant_columns]

residualize_over_grid(new_d, model, predictor_name = attributes(grid)$coef_name, ...)
}


# utilities --------------------------------------------------------------------


.is_grid <- function(df) {
unq <- lapply(df, unique)

if (prod(lengths(unq)) != nrow(df)) {
return(FALSE)
}

df2 <- do.call(expand.grid, args = unq)
df2$..1 <- 1

res <- merge(df, df2, by = colnames(df), all = TRUE)

sum(res$..1) == sum(df2$..1)
}


.closest <- function(x, target, best_match) {
if (is.numeric(x)) {
# AD <- outer(x, target, FUN = function(x, y) abs(x - y))
AD <- abs(outer(x, target, FUN = `-`))
idx <- apply(AD, 1, function(x) x == min(x))
} else {
idx <- t(outer(x, target, FUN = `==`))
}

if (is.matrix(best_match)) {
idx <- idx & best_match
}

idx
}


.validate_num <- function(x) {
if (!is.numeric(x)) {
x <- as.numeric(as.character(x))
}
x
}
2 changes: 2 additions & 0 deletions R/visualisation_recipe.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
#' will set a default value for the `modelbased_numeric_as_discrete` argument.
#' Can also be `FALSE`.
#'
#' @examplesIf all(insight::check_if_installed(c("marginaleffects", "see", "ggplot2"), quietly = TRUE)) && getRversion() >= "4.1.0"

Check warning on line 53 in R/visualisation_recipe.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/visualisation_recipe.R,line=53,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 131 characters.
#' library(ggplot2)
#' library(see)
#' # ==============================================
Expand Down Expand Up @@ -137,6 +137,7 @@
#' @export
visualisation_recipe.estimate_predicted <- function(x,
show_data = FALSE,
show_residuals = FALSE,
point = NULL,
line = NULL,
pointrange = NULL,
Expand All @@ -160,6 +161,7 @@
.visualization_recipe(
x,
show_data = show_data,
show_residuals = show_residuals,
point = point,
line = line,
pointrange = pointrange,
Expand Down
51 changes: 51 additions & 0 deletions R/visualisation_recipe_internal.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@


#' @keywords internal
.find_aes <- function(x, model_info = NULL, numeric_as_discrete = 8) {

Check warning on line 5 in R/visualisation_recipe_internal.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/visualisation_recipe_internal.R,line=5,col=1,[cyclocomp_linter] Reduce the cyclomatic complexity of this expression from 52 to at most 40.
# init basic aes
data <- as.data.frame(x)

Check warning on line 7 in R/visualisation_recipe_internal.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/visualisation_recipe_internal.R,line=7,col=3,[object_overwrite_linter] 'data' is an exported object from package 'utils'. Avoid re-using such symbols.
data$.group <- 1

att <- attributes(x)
Expand All @@ -18,11 +18,11 @@
model_response <- attributes(x)$response

# Find predictors
by <- att$focal_terms

Check warning on line 21 in R/visualisation_recipe_internal.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/visualisation_recipe_internal.R,line=21,col=3,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.

# multivariate response models? if so, we need one more stratification in "by"
if (isTRUE(model_info$is_ordinal | model_info$is_multinomial) && "Response" %in% colnames(data)) {
by <- c(by, "Response")

Check warning on line 25 in R/visualisation_recipe_internal.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/visualisation_recipe_internal.R,line=25,col=5,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.
data$Response <- factor(data$Response, levels = unique(data$Response))
}

Expand Down Expand Up @@ -57,7 +57,7 @@
# we find the (range of) numeric values for the 2nd term where the interaction
# is "significant", i.e. p < 0.05. We want to map this to a special aes
# so we can color the ribbons accordingly
by <- c(by, "p")

Check warning on line 60 in R/visualisation_recipe_internal.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/visualisation_recipe_internal.R,line=60,col=7,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.
significant <- data$p < 0.05
data$p <- "not significant"
data$p[significant] <- "significant"
Expand Down Expand Up @@ -100,7 +100,7 @@

# Assign predictors to aes
if (is.null(by)) {
by <- att$by

Check warning on line 103 in R/visualisation_recipe_internal.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/visualisation_recipe_internal.R,line=103,col=5,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.
}
if (length(by) == 0) {
insight::format_error("No `by` variable was detected, so nothing to put in the x-axis.")
Expand Down Expand Up @@ -227,8 +227,9 @@


#' @keywords internal
.visualization_recipe <- function(x,

Check warning on line 230 in R/visualisation_recipe_internal.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/visualisation_recipe_internal.R,line=230,col=1,[cyclocomp_linter] Reduce the cyclomatic complexity of this expression from 51 to at most 40.
show_data = TRUE,
show_residuals = FALSE,
point = NULL,
line = NULL,
pointrange = NULL,
Expand All @@ -243,7 +244,7 @@
model_info <- attributes(x)$model_info

aes <- .find_aes(x, model_info, numeric_as_discrete)
data <- aes$data

Check warning on line 247 in R/visualisation_recipe_internal.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/visualisation_recipe_internal.R,line=247,col=3,[object_overwrite_linter] 'data' is an exported object from package 'utils'. Avoid re-using such symbols.
aes <- aes$aes
global_aes <- list()
layers <- list()
Expand All @@ -263,7 +264,7 @@
}

# Don't plot raw data for transformed responses with no back-transformation
transform <- attributes(x)$transform

Check warning on line 267 in R/visualisation_recipe_internal.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/visualisation_recipe_internal.R,line=267,col=3,[object_overwrite_linter] 'transform' is an exported object from package 'base'. Avoid re-using such symbols.

if (isTRUE(model_info$is_linear) && !isTRUE(transform)) {
# add information about response transformation
Expand All @@ -283,6 +284,15 @@
}


# add residual data as next lowest layer
if (show_residuals) {
layers[[paste0("l", l)]] <- .visualization_recipe_residuals(x, aes)
# Update with additional args
if (!is.null(point)) layers[[paste0("l", l)]] <- utils::modifyList(layers[[paste0("l", l)]], point)
l <- l + 1
}


# intercept line for slopes ----------------------------------
if (inherits(x, "estimate_slopes")) {
layers[[paste0("l", l)]] <- insight::compact_list(list(
Expand Down Expand Up @@ -496,3 +506,44 @@

out
}


# residuals ----------------------------------------------------------------


#' @keywords internal
.visualization_recipe_residuals <- function(x, aes) {
model <- attributes(x)$model
residual_data <- residualize_over_grid(x, model)

# Default changes for binomial models
shape <- 16
stroke <- 0
if (insight::model_info(model)$is_binomial) {
shape <- "|"
stroke <- 1
}

out <- list(
geom = "point",
data = residual_data,
aes = list(
y = y,
x = aes$x,
color = aes$color,
alpha = aes$alpha
),
height = 0,
shape = shape,
stroke = stroke
)

# set default alpha, if not mapped by aes
if (is.null(aes$alpha)) {
out$alpha <- 1 / 3
} else {
out$alpha <- NULL
}

out
}
64 changes: 64 additions & 0 deletions man/residualize_over_grid.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/visualisation_recipe.estimate_predicted.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading