Skip to content

Commit

Permalink
fixed a bug in how the model weights were used to create the plots
Browse files Browse the repository at this point in the history
  • Loading branch information
bertcarnell committed Jul 31, 2024
1 parent d4423d8 commit 405ec6c
Show file tree
Hide file tree
Showing 19 changed files with 238 additions and 86 deletions.
34 changes: 30 additions & 4 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: tornado
Title: Plots for Model Sensitivity and Variable Importance
Version: 0.1.4
Version: 0.2.0
Authors@R:
person(given = "Rob",
family = "Carnell",
Expand All @@ -16,15 +16,41 @@ Suggests:
randomForest,
knitr,
rmarkdown
RoxygenNote: 7.3.0
RoxygenNote: 7.3.2
Imports:
survival,
assertthat,
ggplot2,
scales,
grid,
gridExtra,
rlang
rlang,
Hmisc
VignetteBuilder: knitr
URL: https://bertcarnell.github.io/tornado/, https://github.com/bertcarnell/tornado
BugReports: https://github.com/bertcarnell/tornado/issues
Collate:
'create_data_low_high.R'
'create_dict.R'
'create_endpoints.R'
'create_factor_plot_data.R'
'create_importance_data.R'
'create_means.R'
'create_plot_data.R'
'importance.R'
'importance_glm.R'
'importance_glmnet.R'
'importance_lm.R'
'importance_survreg.R'
'importance_train.R'
'plot_importance_plot.R'
'plot_tornado_plot.R'
'print_importance_plot.R'
'print_tornado_plot.R'
'quantile_ordered.R'
'tornado.R'
'tornado_coxph.R'
'tornado_glm.R'
'tornado_glmnet.R'
'tornado_lm.R'
'tornado_survreg.R'
'tornado_train.R'
4 changes: 3 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export(importance)
export(tornado)
import(ggplot2)
import(survival)
importFrom(assertthat,assert_that)
importFrom(Hmisc,wtd.quantile)
importFrom(grDevices,dev.cur)
importFrom(grDevices,dev.off)
importFrom(grid,gpar)
Expand All @@ -41,3 +41,5 @@ importFrom(stats,model.matrix)
importFrom(stats,predict)
importFrom(stats,quantile)
importFrom(stats,terms)
importFrom(stats,weighted.mean)
importFrom(stats,weights)
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# version 0.2.0 (2024-07-30)

- bug report from @ghobro (https://github.com/ghobro) for weighted models

# version 0.1.4 (2024-01-18)

- Update to package website. Thanks to @olivroy for the pull request.
Expand Down
11 changes: 6 additions & 5 deletions R/create_dict.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#'
#' @return the dictionary data.frame
#'
#' @importFrom assertthat assert_that
#' @noRd
.create_dict <- function(dict, used_variables)
{
Expand All @@ -21,10 +20,12 @@
}
} else
{
assertthat::assert_that(all(names(dict) == c("old", "new")),
msg = "The variable name translation dictionary must be a list or data.frame with components old and new")
assertthat::assert_that(all(used_variables %in% dict$old),
msg = "All the variables used in the model must be in dict$old")
if (!all(names(dict) == c("old", "new"))) {
stop("The variable name translation dictionary must be a list or data.frame with components old and new")
}
if (!all(used_variables %in% dict$old)) {
stop("All the variables used in the model must be in dict$old")
}
}
return(dict)
}
62 changes: 40 additions & 22 deletions R/create_endpoints.R
Original file line number Diff line number Diff line change
@@ -1,57 +1,72 @@
# Copyright 2021 Robert Carnell

.allowed_types <- c("PercentChange", "percentiles", "ranges", "StdDev")

#' Create variable endpoints for tornado plots
#'
#' @param training_data the data.frame with training data
#' @param means the data.frame with variable means
#' @param type the type of tornado plot
#' @param alpha the percentile or alpha level
#' @param wt model weights
#'
#' @importFrom stats quantile
#' @importFrom Hmisc wtd.quantile
#'
#' @return a list of the endpoints and levels
#' @noRd
.create_endpoints <- function(training_data, means, type, alpha)
.create_endpoints <- function(training_data, means, type, alpha, wt = NA)
{
which_factor <- which(sapply(training_data, is.factor))
lmeans <- length(means)
assertthat::assert_that(type %in% c("PercentChange", "percentiles", "ranges", "StdDev"),
msg = "Type must be one of PercentChange, percentiles, ranges, StdDev")

# if (type == "PercentChange" && length(which_factor) > 0)
# {
# warning("The PercentChange method will not show variation for factor variables")
# } else if (type == "percentiles" && length(which_factor) > 0)
# {
# warning("The percentiles method will not show variation for factor variables")
# }
if (length(type) != 1) {
stop(paste0("type must be a singleton and must be one of ", paste(.allowed_types, collapse = ",")))
}
if (!(type %in% .allowed_types))
{
stop(paste0("type must be one of ", paste(.allowed_types, collapse = ",")))
}

## All factors
if (lmeans == length(which_factor))
{
endpoints <- as.data.frame(matrix(NA, nrow = 2, ncol = lmeans))
names(endpoints) <- names(means)
Level = NA
## percentiles
} else if (type == "percentiles" && alpha > 0 && alpha < 0.5)
{
if (length(which_factor) > 0)
{
endpoints <- data.frame(
apply(training_data[,-which_factor], 2, stats::quantile, probs = c(alpha, 1 - alpha))
)
if (any(is.na(wt))) {
endpoints <- data.frame(
apply(training_data[,-which_factor], 2, stats::quantile, probs = c(alpha, 1 - alpha))
)
} else {
endpoints <- data.frame(
apply(training_data[,-which_factor], 2, Hmisc::wtd.quantile, weights = wt, probs = c(alpha, 1 - alpha))
)
}
names(endpoints) <- names(means)[-which_factor]
endpoints2 <- data.frame(lapply(means[,which_factor], function(z) rep(z, 2)))
names(endpoints2) <- names(means)[which_factor]
endpoints <- cbind(endpoints, endpoints2)
} else
{
endpoints <- data.frame(
apply(training_data, 2, stats::quantile, probs = c(alpha, 1 - alpha))
)
if (any(is.na(wt))) {
endpoints <- data.frame(
apply(training_data, 2, stats::quantile, probs = c(alpha, 1 - alpha))
)
} else {
endpoints <- data.frame(
apply(training_data, 2, Hmisc::wtd.quantile, weights = wt, probs = c(alpha, 1 - alpha))
)
}
names(endpoints) <- names(means)
}
Level <- c(paste0(round(alpha*100,0),"th"),
paste0(round((1 - alpha)*100,0), "th"))
## PercentChange
} else if (type == "PercentChange" && alpha > 0)
{
if (length(which_factor) > 0)
Expand All @@ -71,6 +86,7 @@
names(endpoints) <- names(means)
}
Level <- scales::percent(c(1 - alpha, 1 + alpha))
## ranges
} else if (type == "ranges")
{
if (length(which_factor) > 0)
Expand All @@ -88,24 +104,26 @@
names(endpoints) <- names(means)
}
Level <- c("Lower","Upper")
## StdDev
} else if (type == "StdDev" && alpha > 0)
{
sdf <- function(z)
sdf <- function(z, wt)
{
c(mean(z) - alpha*stats::sd(z), mean(z) + alpha*stats::sd(z))
v <- ifelse(any(is.na(wt)), stats::sd(z), sqrt(Hmisc::wtd.var(z, wt)))
c(mean(z) - alpha*v, mean(z) + alpha*v)
}
if (length(which_factor) > 0)
{
endpoints <- as.data.frame(
apply(training_data[,-which_factor], 2, sdf)
apply(training_data[,-which_factor], 2, sdf, wt = wt)
)
names(endpoints) <- names(means)[-which_factor]
endpoints2 <- data.frame(lapply(means[,which_factor], function(z) rep(z, 2)))
names(endpoints2) <- names(means)[which_factor]
endpoints <- cbind(endpoints, endpoints2)
} else
{
endpoints <- as.data.frame(apply(training_data, 2, sdf))
endpoints <- as.data.frame(apply(training_data, 2, sdf, wt = wt))
names(endpoints) <- names(means)
}
# grid graphics will not the multi-byte character encodings correctly
Expand All @@ -116,7 +134,7 @@
paste0("mean + ", alpha, "*std"))
} else
{
stop("command not recognized")
stop("type not recognized")
}

return(list(endpoints = endpoints, Level = Level))
Expand Down
6 changes: 3 additions & 3 deletions R/create_importance_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#' @param isDeviance TRUE for glms and other deviance residual based models
#'
#' @importFrom stats add1 anova formula
#' @importFrom assertthat assert_that
#'
#' @return a data.frame
#' @noRd
Expand Down Expand Up @@ -66,8 +65,9 @@
add1_contr <- model_add1$`Sum of Sq`[-1] / sum(model_anova$`Sum Sq`)
}

assertthat::assert_that(all(var_final == rownames(model_add1)[-1]),
msg = "Unexpected Internal error")
if (!all(var_final == rownames(model_add1)[-1])) {
stop("Unexpected Internal error")
}
dev_add1 <- data.frame(vars = var_final, add1 = add1_contr,
stringsAsFactors = FALSE)

Expand Down
27 changes: 22 additions & 5 deletions R/create_means.R
Original file line number Diff line number Diff line change
@@ -1,24 +1,41 @@
# Copyright 2021 Robert Carnell

#' Create training data column means for tornado plots
#' Create training data column means or weighted column means for tornado plots
#'
#' @param training_data a data.frame
#' @param wt model weights
#'
#' @importFrom stats weighted.mean
#'
#' @return a data.frame of means
#' @noRd
.create_means <- function(training_data)
.create_means <- function(training_data, wt = NA)
{
if (any(is.na(training_data))) {
stop("NA values not permitted in training_data in .create_means")
}
means <- data.frame(lapply(training_data, function(x)
{
if (is.numeric(x))
{
if (any(is.na(wt))) {
return(mean(x))
} else {
return(stats::weighted.mean(x, wt))
}
return(mean(x))
} else if (is.factor(x))
{
# pick the most frequent class
tt <- table(x)
ttmax <- names(tt[which.max(tt)])
return(factor(ttmax, levels = levels(x)))
if (any(is.na(wt))) {
tt <- table(x)
ttmax <- names(tt[which.max(tt)])
return(factor(ttmax, levels = levels(x)))
} else {
tt <- by(wt, x, sum)
ttmax <- names(tt[which.max(tt)])
return(factor(ttmax, levels = levels(x)))
}
}
}))
return(means)
Expand Down
23 changes: 15 additions & 8 deletions R/create_plot_data.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright 2019 Robert Carnell
#' @include create_endpoints.R

#' Internal Method to create Plot Data in tornado plots
#'
Expand All @@ -11,8 +12,7 @@
#'
#' @return the data to create the tornado plot
#'
#' @importFrom assertthat assert_that
#' @importFrom stats terms
#' @importFrom stats terms weights
#' @noRd
.create_plot_data <- function(model, modeldata, type="PercentChange", alpha=0.10,
dict=NA, predict_type = "response")
Expand Down Expand Up @@ -48,21 +48,28 @@
# alpha <- 0.10
# dict <- NA
# }
assertthat::assert_that(is.data.frame(modeldata),
msg = "The data must be contained in a data.frame")
assertthat::assert_that(type %in% c("PercentChange","percentiles","ranges", "StdDev"),
msg = "type must be PercentChagne, percentiles, ranges, StdDev")
if (!is.data.frame(modeldata)) {
stop("The data must be contained in a data.frame")
}
if (!(type %in% .allowed_types)) {
stop(paste0("type must be in ", paste(.allowed_types, collapse=",")))
}

used_variables <- rownames(attr(stats::terms(model), "factors"))[-1]

dict <- .create_dict(dict, used_variables)

training_data <- subset(modeldata, select = used_variables)
means <- .create_means(training_data)
if (is.null(stats::weights(model))) {
model_weights <- NA
} else {
model_weights <- stats::weights(model)
}
means <- .create_means(training_data, model_weights)
names_means <- names(means)
lmeans <- length(means)

ret <- .create_endpoints(training_data, means, type, alpha)
ret <- .create_endpoints(training_data, means, type, alpha, model_weights)
endpoints <- ret$endpoints
Level <- ret$Level
base_Level <- c("A","B")
Expand Down
6 changes: 3 additions & 3 deletions R/importance_glmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#' @seealso \code{\link{importance}}
#'
#' @importFrom stats model.matrix model.frame
#' @importFrom assertthat assert_that
#'
#' @examples
#' if (requireNamespace("glmnet", quietly = TRUE))
Expand All @@ -38,8 +37,9 @@ importance.cv.glmnet <- function(model_final, model_data, form, dict = NA, nperm
# nperm <- 100
# geom_bar_control <- list(fill = "green")

assertthat::assert_that(requireNamespace("glmnet", quietly = TRUE),
msg = "The glmnet package is required to use this method")
if (!requireNamespace("glmnet", quietly = TRUE)) {
stop("The glmnet package is required to use this method")
}

otherVariables <- list(...)
modelframe <- stats::model.frame(form, data = model_data)
Expand Down
6 changes: 3 additions & 3 deletions R/plot_importance_plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#' @importFrom gridExtra arrangeGrob
#' @importFrom scales percent
#' @importFrom grDevices dev.cur dev.off
#' @importFrom assertthat assert_that
#' @importFrom rlang .data
#'
#' @examples
Expand All @@ -40,8 +39,9 @@ plot.importance_plot <- function(x, plot = TRUE, nvar = NA,
geom_bar_control = list(fill = "#69BE28"),
...)
{
assertthat::assert_that(length(nvar) == 1,
msg = "nvar must be a length 1 integer or NA")
if (length(nvar) != 1) {
stop("nvar must be a length 1 integer or NA")
}

if (x$type %in% c("lm", "glm"))
{
Expand Down
Loading

0 comments on commit 405ec6c

Please sign in to comment.