Skip to content

Commit

Permalink
improve documentation, some code cleaning and organization
Browse files Browse the repository at this point in the history
  • Loading branch information
donishadsmith committed Jun 19, 2024
1 parent 548e3c6 commit 7a64bee
Show file tree
Hide file tree
Showing 15 changed files with 521 additions and 330 deletions.
8 changes: 7 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ All notable future changes to vswift will be documented in this file.

**As this package is still in the version 0.x.x series, aspects of the package may change rapidly to improve convenience and ease of use.**

**Additionally, as of version 0.1.1, versioning for the 0.x.x series for this package will work as:**
**Additionally, beyond version 0.1.1, versioning for the 0.x.x series for this package will work as:**

`0.minor.patch`

Expand All @@ -15,8 +15,14 @@ noted in the changelog (i.e new functions or parameters, changes in parameter de
- *.patch* : Contains no new features, simply fixes any identified bugs.

## [0.1.1] - 2024-06-19
### ♻ Changed
- Changed order of parameters for ``classCV()`` function.

### 🐛 Fixes
- Standardizes validation data using the mean and standard deviation of the training set.

### 💻 Metadata
- Improved documentation.

## [0.1.0] - 2024-05-13
- First release of vswift package
244 changes: 153 additions & 91 deletions pkg/vswift/R/classCV.R

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pkg/vswift/R/create_dictionary.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
warning(sprintf("classes are now encoded: %s", paste(new_classes, collapse = ", ")))
}
return(classCV_output)
}
}
2 changes: 0 additions & 2 deletions pkg/vswift/R/expand_dataframe.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,3 @@
return(classCV_output)
}
}


35 changes: 20 additions & 15 deletions pkg/vswift/R/genFolds.R
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
#' Create split datasets and/or folds with optional stratification
#'
#' @name genFolds
#' @description Standalone function generates train-test split datasets and/or k-fold cross-validation folds, with the option to perform stratified sampling based on class distribution.
#' @description Standalone function that generates train-test split datasets and/or k-fold cross-validation folds, with the
#' option to perform stratified sampling based on class distribution.
#'
#' @param data A data frame.
#' @param target A numerical index or character name for the target variable. Only needs to be specified if stratified = TRUE. Default = NULL.
#' @param split A numerical value between 0.5 to 0.9 indicating the proportion of data to use for the training set, leaving the rest for the test set. If not specified, train-test splitting will not be done.
#' @param n_folds A numerical value between 3-30 indicating the number of k-folds. If left empty, k-fold cross validation will not be performed.
#' @param stratified A logical value indicating if stratified sampling should be used. Default = FALSE.
#' @param random_seed A numerical value for the random seed to be used. Default = NULL.
#' @param create_data A logical value indicating whether to create all training and test/validation data frames. Default = FALSE.
#' @return A list containing the indices for train-test splitting and/or k-fold cross-validation, with information on the class distribution in the training, test sets, and folds (if applicable)
#' as well as the generated split datasets and folds based on the indices.
#' @param target A numerical index or character name for the target variable. Only needs to be specified if
#' \code{stratified = TRUE}. Default = \code{NULL}.
#' @param split A numerical value between 0.5 to 0.9 indicating the proportion of data to use for the training set,
#' leaving the rest for the test set. If not specified, train-test splitting will not be done.
#' Default = \code{NULL}.
#' @param n_folds A numerical value between 3-30 indicating the number of k-folds. If left empty, k-fold cross
#' validation will not be performed. Default = \code{NULL}.
#' @param stratified A logical value indicating if stratified sampling should be used. Default = \code{FALSE}.
#' @param random_seed A numerical value for the random seed to be used. Default = \code{NULL}.
#' @param create_data A logical value indicating whether to create all training and test/validation data frames.
#' Default = \code{FALSE}.
#' @return A list containing the indices for train-test splitting and/or k-fold cross-validation, with information on
#' the class distribution in the training, test sets, and folds (if applicable) as well as the generated split
#' datasets and folds based on the indices.
#' @examples
#' # Load example dataset
#'
Expand All @@ -24,9 +31,11 @@
#' @author Donisha Smith
#' @export

genFolds <- function(data, target = NULL, split = NULL, n_folds = NULL, stratified = FALSE, random_seed = NULL, create_data = FALSE){
genFolds <- function(data, target = NULL, split = NULL, n_folds = NULL, stratified = FALSE, random_seed = NULL,
create_data = FALSE){
# Check input
.error_handling(data = data, target = target, n_folds = n_folds, split = split, stratified = stratified, random_seed = random_seed, call = "stratified_split")
.error_handling(data = data, target = target, n_folds = n_folds, split = split, stratified = stratified,
random_seed = random_seed, call = "stratified_split")
# Set seed
if(!is.null(random_seed)){
set.seed(random_seed)
Expand Down Expand Up @@ -172,7 +181,3 @@ genFolds <- function(data, target = NULL, split = NULL, n_folds = NULL, stratif
}
return(output)
}




63 changes: 37 additions & 26 deletions pkg/vswift/R/plot.vswift.R
Original file line number Diff line number Diff line change
@@ -1,25 +1,32 @@
#' Plot model evaluation metrics
#'
#' @name plot
#' @description Plots model evaluation metrics (classification accuracy and precision, recall, and f-score for each class) from a vswift x.
#' @description Plots model evaluation metrics (classification accuracy and precision, recall, and f-score for each
#' class) from a vswift object.
#'
#' @param x An x of class vswift.
#' @param split A logical value indicating whether to plot metrics for train-test splitting results. Default = TRUE.
#' @param cv A logical value indicating whether to plot metrics for k-fold cross-validation results. Note: Solid red line represents the mean
#' and dashed blue line represents the standard deviation. Default = TRUE.
#' @param metrics A vector consisting of which metrics to plot. Available metrics includes, "accuracy", "precision", "recall", "f1".
#' Default = c("accuracy","precision", "recall", "f1").
#' @param class_names A vector consisting of class names to plot. If NULL, plots are generated for each class.Defaeult = NULL
#' @param save_plots A logical value to save all plots as separate png files. Plot will not be displayed if set to TRUE. Default = FALSE.
#' @param path A character representing the file location, with trailing slash, to save to. If not specified, the plots will be saved to the current
#' working directory.
#' @param model_type A character or vector of the model metrics to be printed. If not specified, all model metrics will be printed. Available options:
#' "lda" (Linear Discriminant Analysis), "qda" (Quadratic Discriminant Analysis),
#' "logistic" (Logistic Regression), "svm" (Support Vector Machines), "naivebayes" (Naive Bayes),
#' "ann" (Artificial Neural Network), "knn" (K-Nearest Neighbors), "decisiontree" (Decision Tree),
#' "randomforest" (Random Forest), "multinom" (Multinomial Logistic Regression), "gbm" (Gradient Boosting Machine).
#' @param ... Additional arguments that can be passed to the `png()` function.
#'
#' @param x A vswift object.
#' @param split A logical value indicating whether to plot metrics for train-test splitting results.
#' Default = \code{TRUE}.
#' @param cv A logical value indicating whether to plot metrics for k-fold cross-validation results.
#' Note: Solid red line represents the mean and dashed blue line represents the standard deviation.
#' Default = \code{TRUE}.
#' @param metrics A vector consisting of which metrics to plot. Available metrics include \code{"accuracy"},
#' \code{"precision"}, \code{"recall"}, \code{"f1"}.
#' Default = \code{c("accuracy","precision", "recall", "f1")}.
#' @param class_names A vector consisting of class names to plot. If \code{NULL}, plots are generated for each class.
#' Default = \code{NULL}.
#' @param save_plots A logical value to save all plots as separate png files. Plots will not be displayed if set to
#' Default = \code{FALSE}.
#' @param path A character representing the file location, with trailing slash, to save to. If not specified, the plots
#' will be saved to the current working directory. Default = \code{NULL}.
#' @param model_type A character or vector of the model metrics to be printed. If \code{NULL}, all model metrics will
#' be printed. Available options: \code{"lda"} (Linear Discriminant Analysis), \code{"qda"}
#'                   (Quadratic Discriminant Analysis), \code{"logistic"} (Logistic Regression), \code{"svm"}
#' (Support Vector Machines), \code{"naivebayes"} (Naive Bayes), \code{"ann"}
#' (Artificial Neural Network), \code{"knn"} (K-Nearest Neighbors), \code{"decisiontree"}
#' (Decision Tree), \code{"randomforest"} (Random Forest), \code{"multinom"}
#' (Multinomial Logistic Regression), \code{"gbm"} (Gradient Boosting Machine). Default = \code{NULL}.
#' @param ... Additional arguments that can be passed to the \code{png()} function.
#'
#' @return Plots representing evaluation metrics.
#' @examples
Expand All @@ -41,14 +48,16 @@
#' @importFrom graphics abline axis
#' @export

"plot.vswift" <- function(x, ..., split = TRUE, cv = TRUE, metrics = c("accuracy","precision", "recall", "f1"), class_names = NULL, save_plots = FALSE, path = NULL, model_type = NULL){
"plot.vswift" <- function(x, ..., split = TRUE, cv = TRUE, metrics = c("accuracy","precision", "recall", "f1"),
class_names = NULL, save_plots = FALSE, path = NULL, model_type = NULL){

if(inherits(x, "vswift")){
# Create list
model_list = list("lda" = "Linear Discriminant Analysis", "qda" = "Quadratic Discriminant Analysis", "svm" = "Support Vector Machines",
"ann" = "Neural Network", "decisiontree" = "Decision Tree", "randomforest" = "Random Forest", "gbm" = "Gradient Boosted Machine",
"multinom" = "Multinomial Logistic Regression", "logistic" = "Logistic Regression", "knn" = "K-Nearest Neighbors",
"naivebayes" = "Naive Bayes")
model_list = list("lda" = "Linear Discriminant Analysis", "qda" = "Quadratic Discriminant Analysis",
"svm" = "Support Vector Machines", "ann" = "Neural Network", "decisiontree" = "Decision Tree",
"randomforest" = "Random Forest", "gbm" = "Gradient Boosted Machine",
"multinom" = "Multinomial Logistic Regression", "logistic" = "Logistic Regression",
"knn" = "K-Nearest Neighbors","naivebayes" = "Naive Bayes")

# Lowercase and intersect common names
metrics <- intersect(unlist(lapply(metrics, function(x) tolower(x))), c("accuracy","precision", "recall", "f1"))
Expand Down Expand Up @@ -83,12 +92,14 @@
# Iterate over models
for(model in models){
if(save_plots == FALSE){
.visible_plots(x = x, split = split, cv = cv, metrics = metrics, class_names = class_names, model_name = model, model_list = model_list)
.visible_plots(x = x, split = split, cv = cv, metrics = metrics, class_names = class_names, model_name = model,
model_list = model_list)
} else {
.save_plots(x = x, split = split, cv = cv, metrics = metrics, class_names = class_names, path = path, model_name = model, model_list = model_list,...)
.save_plots(x = x, split = split, cv = cv, metrics = metrics, class_names = class_names, path = path,
model_name = model, model_list = model_list,...)
}
}
} else {
stop("x must be of class 'vswift'")
}
}
}
47 changes: 30 additions & 17 deletions pkg/vswift/R/preprocess.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
#' @importFrom parallel detectCores
#' @noRd
#' @export
.error_handling <- function(formula = NULL, data = NULL, target = NULL, predictors = NULL, split = NULL, n_folds = NULL, model_type = NULL, threshold = NULL, stratified = NULL, random_seed = NULL,
impute_method = NULL, impute_args = NULL, mod_args = NULL, n_cores = NULL, standardize = NULL, call = NULL, ...){
.error_handling <- function(formula = NULL, data = NULL, target = NULL, predictors = NULL, split = NULL, n_folds = NULL,
model_type = NULL, threshold = NULL, stratified = NULL, random_seed = NULL,
impute_method = NULL, impute_args = NULL, mod_args = NULL, n_cores = NULL,
standardize = NULL, call = NULL, ...){

# List of valid inputs
valid_inputs <- list(valid_models = c("lda","qda","logistic","svm","naivebayes","ann","knn","decisiontree",
Expand All @@ -13,8 +15,8 @@

# Check standardize
if(!is.null(standardize)){
if(!any(standardize == TRUE, standardize == FALSE, is.numeric(standardize), is.integer(standardize), is.character(standardize))){
stop("`standardize` must either be TRUE, FALSE, or a numeric vector")
if(!inherits(standardize, c("logical", "numeric", "integer", "character"))){
stop("`standardize` must either be TRUE, FALSE, a numeric vector, or a character vector")
}
}
# Check if impute method is valid
Expand Down Expand Up @@ -80,7 +82,8 @@
# Get target and predictors if formula specified
if(!is.null(formula)){
if(any(!is.null(formula) & !is.null(target) || !is.null(predictors))){
warning("`formula` specified with `target` and/or `predictors`, `formula` will overwrite the specified `target` and `predictors`")
warning("`formula` specified with `target` and/or `predictors`, `formula` will overwrite the specified `target`
and `predictors`")
}
get_features_target <- .get_features_target(formula = formula, data = data)
target <- get_features_target[["target"]]
Expand Down Expand Up @@ -170,7 +173,8 @@
stop("number of cores must be a numeric value")
}
if(n_cores > detectCores()){
stop(sprintf("more cores specified than available; only %s cores available but %s cores specified", detectCores(), n_cores))
stop(sprintf("more cores specified than available; only %s cores available but %s cores specified",
detectCores(), n_cores))
}
}
}
Expand Down Expand Up @@ -205,7 +209,8 @@
# Helper function for classCV to check if additional arguments are valid
#' @noRd
#' @export
.check_additional_arguments <- function(model_type = NULL, impute_method = NULL, impute_args = NULL, mod_args = NULL, call = NULL, ...){
.check_additional_arguments <- function(model_type = NULL, impute_method = NULL, impute_args = NULL, mod_args = NULL,
call = NULL, ...){

# Helper function to generate error message
error_message <- function(method_name, invalid_args) {
Expand Down Expand Up @@ -342,7 +347,8 @@
#' @importFrom recipes step_impute_knn recipe all_predictors step_impute_bag prep bake
#' @noRd
#' @export
.imputation <- function(preprocessed_data, target, predictors, formula, imputation_method ,impute_args, classCV_output, iteration, parallel = TRUE, final = FALSE, random_seed = NULL){
.imputation <- function(preprocessed_data, target, predictors, formula, imputation_method ,impute_args, classCV_output,
iteration, parallel = TRUE, final = FALSE, random_seed = NULL){
# Set seed
if(!is.null(random_seed)){
set.seed(random_seed)
Expand Down Expand Up @@ -371,7 +377,8 @@
} else {
formula <- formula
}
rec <- step_impute_knn(recipe = recipe(formula = formula, data = training_data), neighbors = impute_args[["neighbors"]], all_predictors())
rec <- step_impute_knn(recipe = recipe(formula = formula, data = training_data),
neighbors = impute_args[["neighbors"]], all_predictors())
} else {
rec <- step_impute_knn(recipe = recipe(formula = formula, data = training_data),all_predictors())
}
Expand All @@ -382,7 +389,8 @@
} else {
formula <- formula
}
rec <- step_impute_bag(recipe = recipe(formula = formula, data = training_data), trees = impute_args[["trees"]], all_predictors())
rec <- step_impute_bag(recipe = recipe(formula = formula, data = training_data),
trees = impute_args[["trees"]], all_predictors())
} else {
rec <- step_impute_bag(recipe = recipe(formula = formula, data = training_data), all_predictors())
}
Expand Down Expand Up @@ -412,7 +420,8 @@
processed_data <- processed_data[sorted_rows,]

# Create imputation_information list to store information
imputation_information <- .get_missing_info(training_data = training_data, validation_data = validation_data, iteration = iteration, imputation_method = imputation_method)
imputation_information <- .get_missing_info(training_data = training_data, validation_data = validation_data,
iteration = iteration, imputation_method = imputation_method)

if(iteration == "Training"){
imputation_information[["split"]][["prep"]] <- prep
Expand All @@ -439,7 +448,8 @@

} else{
# Get missing information
imputation_information <- .get_missing_info(preprocessed_data = preprocessed_data, imputation_method = imputation_method)
imputation_information <- .get_missing_info(preprocessed_data = preprocessed_data,
imputation_method = imputation_method)
# Impute data
if(imputation_method == "knn_impute"){
if(!is.null(impute_args)){
Expand All @@ -448,7 +458,8 @@
} else {
formula <- formula
}
rec <- step_impute_knn(recipe = recipe(formula = formula, data = preprocessed_data), neighbors = impute_args[["neighbors"]], all_predictors())
rec <- step_impute_knn(recipe = recipe(formula = formula, data = preprocessed_data),
neighbors = impute_args[["neighbors"]], all_predictors())
} else {
rec <- step_impute_knn(recipe = recipe(formula = formula, data = preprocessed_data), all_predictors())
}
Expand All @@ -459,7 +470,8 @@
} else {
formula <- formula
}
rec <- step_impute_bag(recipe = recipe(formula = formula, data = preprocessed_data), trees = impute_args[["trees"]], all_predictors())
rec <- step_impute_bag(recipe = recipe(formula = formula, data = preprocessed_data),
trees = impute_args[["trees"]], all_predictors())
} else {
rec <- step_impute_bag(recipe = recipe(formula = formula, data = preprocessed_data), all_predictors())
}
Expand All @@ -481,7 +493,8 @@
# Assist function for .imputation to get number of missing data for each column
#' @noRd
#' @export
.get_missing_info <- function(preprocessed_data = NULL, training_data = NULL, validation_data = NULL, iteration, imputation_method){
.get_missing_info <- function(preprocessed_data = NULL, training_data = NULL, validation_data = NULL,
iteration, imputation_method){
# Create imputation list
imputation_information <- list()

Expand Down Expand Up @@ -554,10 +567,10 @@
}
}
} else{
warning("no standardization has been done; standardization specified but column indices are outside possible range or column names don't exist")
warning("no standardization has been done; standardization specified but column indices are outside possible range
or column names don't exist")
}

standardize_list <- list("training_data" = training_data, "validation_data" = validation_data)
return(standardize_list)
}

Loading

0 comments on commit 7a64bee

Please sign in to comment.