Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to log_carat/log_price in docu #154

Merged
merged 1 commit into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
### Documentation

- Add vignette for Tidymodels.
- Update "basic_use" vignette.
- Update vignettes.
- Update README.

# shapviz 0.9.4
Expand Down
23 changes: 15 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,27 @@ library(shapviz)
library(ggplot2)
library(xgboost)

set.seed(10)
set.seed(1)

# Build model
x <- c("carat", "cut", "color", "clarity")
dtrain <- xgb.DMatrix(data.matrix(diamonds[x]), label = diamonds$price)
fit <- xgb.train(params = list(learning_rate = 0.1), data = dtrain, nrounds = 65)
xvars <- c("log_carat", "cut", "color", "clarity")
X <- diamonds |>
transform(log_carat = log(carat)) |>
subset(select = xvars)

# Fit (untuned) model
fit <- xgb.train(
params = list(learning_rate = 0.1),
data = xgb.DMatrix(data.matrix(X), label = log(diamonds$price)),
nrounds = 65
)

# SHAP analysis: X can even contain factors
dia_2000 <- diamonds[sample(nrow(diamonds), 2000), x]
shp <- shapviz(fit, X_pred = data.matrix(dia_2000), X = dia_2000)
X_explain <- X[sample(nrow(X), 2000), ]
shp <- shapviz(fit, X_pred = data.matrix(X_explain), X = X_explain)

sv_importance(shp, show_numbers = TRUE)
sv_importance(shp, kind = "bee")
sv_dependence(shp, v = x) # patchwork
sv_dependence(shp, v = xvars) # patchwork
```

![](man/figures/README-imp.svg)
Expand Down
Binary file modified man/figures/README-bee.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified man/figures/README-dep.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
363 changes: 71 additions & 292 deletions man/figures/README-imp.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified man/figures/VIGNETTE-dep-ranger.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified man/figures/VIGNETTE-dep.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified man/figures/VIGNETTE-tidy-lgb-dep.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified man/figures/VIGNETTE-tidy-lgb-imp.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified man/figures/VIGNETTE-tidy-rf-dep.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified man/figures/VIGNETTE-tidy-rf-imp.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified man/figures/VIGNETTE-tidy-xgb-dep.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified man/figures/VIGNETTE-tidy-xgb-imp.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified man/figures/VIGNETTE-tidy-xgb-inter.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
27 changes: 16 additions & 11 deletions vignettes/basic_use.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -62,26 +62,31 @@ Shiny diamonds... let's use XGBoost to model their prices by the four "C" variab
library(shapviz)
library(ggplot2)
library(xgboost)
library(patchwork) # We will need its "&" operator

set.seed(1)

# Build model
x <- c("carat", "cut", "color", "clarity")
dtrain <- xgb.DMatrix(data.matrix(diamonds[x]), label = diamonds$price, nthread = 1)
xvars <- c("log_carat", "cut", "color", "clarity")
X <- diamonds |>
transform(log_carat = log(carat)) |>
subset(select = xvars)
head(X)

# Fit (untuned) model
fit <- xgb.train(
params = list(learning_rate = 0.1, nthread = 1), data = dtrain, nrounds = 65
params = list(learning_rate = 0.1, nthread = 1),
data = xgb.DMatrix(data.matrix(X), label = log(diamonds$price), nthread = 1),
nrounds = 65
)

# SHAP analysis: X can even contain factors
dia_2000 <- diamonds[sample(nrow(diamonds), 2000), x]
shp <- shapviz(fit, X_pred = data.matrix(dia_2000), X = dia_2000)
X_explain <- X[sample(nrow(X), 2000), ]
shp <- shapviz(fit, X_pred = data.matrix(X_explain), X = X_explain)

sv_importance(shp, show_numbers = TRUE)
sv_importance(shp, kind = "beeswarm") # kind = "both" combines bar and bee
sv_importance(shp, kind = "beeswarm")
```
```{r, fig.width=8.5, fig.height=5.5}
sv_dependence(shp, v = x) # patchwork object
sv_dependence(shp, v = xvars) # patchwork object
```

### Decompose single predictions
Expand Down Expand Up @@ -112,9 +117,9 @@ Note that SHAP interaction values are multiplied by two (except main effects).

```{r, fig.width=8.5, fig.height=5.5}
shp_i <- shapviz(
fit, X_pred = data.matrix(dia_2000[x]), X = dia_2000, interactions = TRUE
fit, X_pred = data.matrix(X_explain), X = X_explain, interactions = TRUE
)
sv_dependence(shp_i, v = "carat", color_var = x, interactions = TRUE)
sv_dependence(shp_i, v = "log_carat", color_var = xvars, interactions = TRUE)
sv_interaction(shp_i) +
theme(axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1))
```
Expand Down
70 changes: 33 additions & 37 deletions vignettes/geographic.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -56,38 +56,38 @@ library(xgboost)
library(ggplot2)
library(shapviz)

head(miami)
miami <- miami |>
transform(
log_living = log(TOT_LVG_AREA),
log_land = log(LND_SQFOOT),
log_price = log(SALE_PRC)
)

x_coord <- c("LATITUDE", "LONGITUDE")
x_nongeo <- c("TOT_LVG_AREA", "LND_SQFOOT", "structure_quality", "age")
x <- c(x_coord, x_nongeo)
x_nongeo <- c("log_living", "log_land", "structure_quality", "age")
xvars <- c(x_coord, x_nongeo)

# Train/valid split
# Select training data
set.seed(1)
ix <- sample(nrow(miami), 0.8 * nrow(miami))
X_train <- data.matrix(miami[ix, x])
X_valid <- data.matrix(miami[-ix, x])
y_train <- log(miami$SALE_PRC[ix])
y_valid <- log(miami$SALE_PRC[-ix])

# Fit XGBoost model with early stopping
dtrain <- xgb.DMatrix(X_train, label = y_train, nthread = 1)
dvalid <- xgb.DMatrix(X_valid, label = y_valid, nthread = 1)

params <- list(
learning_rate = 0.2, objective = "reg:squarederror", max_depth = 5, nthread = 1
)

fit <- xgb.train(params = params, data = dtrain, nrounds = 200)
train <- miami[ix, ]
X_train <- train[xvars]
y_train <- train$log_price

# Fit XGBoost model
params <- list(learning_rate = 0.2, nthread = 1)
dtrain <- xgb.DMatrix(data.matrix(X_train), label = y_train, nthread = 1)
fit <- xgb.train(params, dtrain, nrounds = 200)
```

Let's first study selected SHAP dependence plots, evaluated on the validation dataset with around 2800 observations. Note that we could as well use (a subset of) the training data for this purpose.
Let's first study selected SHAP dependence plots for an explanation dataset of size 2000.

```{r}
sv <- shapviz(fit, X_pred = X_valid)
X_explain <- X_train[1:2000, ]
sv <- shapviz(fit, X_pred = data.matrix(X_explain))
sv_dependence(
sv,
v = c("TOT_LVG_AREA", "structure_quality", "LONGITUDE", "LATITUDE"),
v = c("log_living", "structure_quality", "LONGITUDE", "LATITUDE"),
alpha = 0.2
)

Expand Down Expand Up @@ -115,34 +115,30 @@ The second step leads to a model that is additive in each non-geographic compone
```{r}
# Extend the feature set
more_geo <- c("CNTR_DIST", "OCEAN_DIST", "RAIL_DIST", "HWY_DIST")
x2 <- c(x, more_geo)
xvars <- c(xvars, more_geo)
X_train <- train[xvars]
dtrain <- xgb.DMatrix(data.matrix(X_train), label = y_train, nthread = 1)

X_train2 <- data.matrix(miami[ix, x2])
X_valid2 <- data.matrix(miami[-ix, x2])

dtrain2 <- xgb.DMatrix(X_train2, label = y_train, nthread = 1)
dvalid2 <- xgb.DMatrix(X_valid2, label = y_valid, nthread = 1)

# Build interaction constraint vector
# Build interaction constraint vector and add it to params
ic <- c(
list(which(x2 %in% c(x_coord, more_geo)) - 1),
as.list(which(x2 %in% x_nongeo) - 1)
list(which(xvars %in% c(x_coord, more_geo)) - 1),
as.list(which(xvars %in% x_nongeo) - 1)
)

# Modify parameters
params$interaction_constraints <- ic

fit2 <- xgb.train(params = params, data = dtrain2, nrounds = 200)
# Fit XGBoost model
fit <- xgb.train(params, dtrain, nrounds = 200)

# SHAP analysis
sv2 <- shapviz(fit2, X_pred = X_valid2)
# Explain the first 2000 training rows (start at row 1 so the set has 2000 rows)
X_explain <- X_train[1:2000, ]
sv <- shapviz(fit, X_pred = data.matrix(X_explain))

# Two selected features: Thanks to additivity, structure_quality can be read as
# Ceteris Paribus
sv_dependence(sv2, v = c("structure_quality", "LONGITUDE"), alpha = 0.2)
sv_dependence(sv, v = c("structure_quality", "LONGITUDE"), alpha = 0.2)

# Total geographic effect (Ceteris Paribus thanks to additivity)
sv_dependence2D(sv2, x = "LONGITUDE", y = "LATITUDE", add_vars = more_geo) +
sv_dependence2D(sv, x = "LONGITUDE", y = "LATITUDE", add_vars = more_geo) +
coord_equal()
```

Expand Down
55 changes: 35 additions & 20 deletions vignettes/tidymodels.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,16 @@ library(shapviz)

set.seed(10)

splits <- initial_split(diamonds)
splits <- diamonds |>
transform(
log_price = log(price),
log_carat = log(carat)
) |>
initial_split()
df_train <- training(splits)

dia_recipe <- df_train |>
recipe(price ~ carat + color + clarity + cut)
recipe(log_price ~ log_carat + color + clarity + cut)

rf <- rand_forest(mode = "regression") |>
set_engine("ranger")
Expand All @@ -41,10 +46,10 @@ fit <- rf_wf |>
fit(df_train)

# SHAP analysis
xvars <- c("carat", "color", "clarity", "cut")
xvars <- c("log_carat", "color", "clarity", "cut")
X_explain <- df_train[1:1000, xvars] # Use only feature columns

# 90 seconds on laptop
# 1.5 minutes on laptop
# Note: If you have more than p=8 features, use kernelshap() instead of permshap()
system.time(
shap_values <- fit |>
Expand Down Expand Up @@ -78,18 +83,23 @@ of course, you don't *have* to work with SHAP interactions, especially if your m

**Remark:** Don't use 1:m transforms such as One-Hot-Encodings. They are usually not necessary and make the workflow more complicated. If you can't avoid this, check the `collapse` argument in `shapviz()`.

```
```r
library(tidymodels)
library(shapviz)
library(patchwork)

set.seed(10)

splits <- initial_split(diamonds)
splits <- diamonds |>
transform(
log_price = log(price),
log_carat = log(carat)
) |>
initial_split()
df_train <- training(splits)

dia_recipe <- df_train |>
recipe(price ~ carat + color + clarity + cut) |>
recipe(log_price ~ log_carat + color + clarity + cut) |>
step_integer(all_ordered())

# Should be tuned in practice
Expand Down Expand Up @@ -126,22 +136,22 @@ shap_values |>
# Absolute average SHAP interactions (off-diagonals already multiplied by 2)
shap_values |>
sv_interaction(kind = "no")
# carat clarity color cut
# carat 2998.30769 591.8859 425.63902 99.11383
# clarity 591.88589 632.2544 192.14847 25.47713
# color 425.63906 192.1484 424.91991 20.15823
# cut 99.11392 25.4771 20.15823 109.26374
# log_carat clarity color cut
# log_carat 0.87400688 0.067567245 0.032599394 0.024273852
# clarity 0.06756720 0.143393109 0.028236784 0.004910905
# color 0.03259941 0.028236796 0.095656042 0.004804729
# cut 0.02427382 0.004910904 0.004804732 0.031114735

# Usual dependence plot
xvars <- c("carat", "color", "clarity", "cut")
xvars <- c("log_carat", "color", "clarity", "cut")

shap_values |>
sv_dependence(xvars) &
plot_annotation("SHAP dependence plots") # patchwork magic

# SHAP interactions for carat
shap_values |>
sv_dependence("carat", color_var = xvars, interactions = TRUE) &
sv_dependence("log_carat", color_var = xvars, interactions = TRUE) &
plot_annotation("SHAP interactions for carat")
```
![](../man/figures/VIGNETTE-tidy-xgb-imp.png)
Expand All @@ -164,11 +174,16 @@ library(shapviz)

set.seed(10)

splits <- initial_split(diamonds)
splits <- diamonds |>
transform(
log_price = log(price),
log_carat = log(carat)
) |>
initial_split()
df_train <- training(splits)

dia_recipe <- df_train |>
recipe(price ~ carat + color + clarity + cut) |>
# Response is log_price (created in the transform() step above), not raw price
recipe(log_price ~ log_carat + color + clarity + cut) |>
step_integer(color, clarity) # we keep cut a factor (for illustration only)

# Should be tuned in practice
Expand All @@ -193,9 +208,9 @@ X_pred <- bake( # Goes to lightgbm:::predict.lgb.Booster()
bonsai:::prepare_df_lgbm()

head(X_pred, 2)
# carat color clarity cut
# [1,] 1.37 5 5 3
# [2,] 0.55 2 3 4
# log_carat color clarity cut
# [1,] 0.3148107 5 5 3
# [2,] -0.5978370 2 3 4

stopifnot(colnames(X_pred) %in% colnames(df_explain))

Expand All @@ -206,7 +221,7 @@ shap_values |>
sv_importance(show_numbers = TRUE)

shap_values |>
sv_dependence(c("carat", "color", "clarity", "cut"))
sv_dependence(c("log_carat", "color", "clarity", "cut"))
```
![](../man/figures/VIGNETTE-tidy-lgb-imp.png)
![](../man/figures/VIGNETTE-tidy-lgb-dep.png)
Expand Down
Loading