
This R package is a simple, user-friendly tool for train-test splitting and k-fold cross-validation of classification data, using classification algorithms from popular R packages. The functions used for each classification algorithm are:

  • lda() from MASS package for Linear Discriminant Analysis
  • qda() from MASS package for Quadratic Discriminant Analysis
  • glm() from stats package with family = "binomial" for Logistic Regression
  • svm() from e1071 package for Support Vector Machines
  • naive_bayes() from naivebayes package for Naive Bayes
  • nnet() from nnet package for Artificial Neural Network
  • train.kknn() from kknn package for K-Nearest Neighbors
  • rpart() from rpart package for Decision Trees
  • randomForest() from randomForest package for Random Forest
  • multinom() from nnet package for Multinomial Regression
  • xgb.train() from xgboost package for Gradient Boosting Machines

This package was initially inspired by topepo's caret package.


  • Versatile Data Splitting: Perform train-test splits or k-fold cross-validation on your classification data.
  • Support for Popular Algorithms: Choose from a wide range of classification algorithms such as Linear Discriminant Analysis, Quadratic Discriminant Analysis, Logistic Regression, Support Vector Machines, Naive Bayes, Artificial Neural Networks, K-Nearest Neighbors, Decision Trees, Random Forest, Multinomial Logistic Regression, and Gradient Boosting Machines. Additionally, multiple algorithms can be specified in a single function call.
  • Stratified Sampling Option: Ensure representative class distribution using stratified sampling based on class proportions.
  • Handling Unseen Categorical Levels: Optionally exclude observations from the validation/test set whose categories were not seen during model training (remove_obs in train_params). This is helpful for algorithms that may throw errors when predicting on unseen levels.
  • Model Saving Capabilities: Save all models utilized for training and testing.
  • Dataset Saving Options: Preserve split datasets and folds.
  • Model Creation: Easily create and save final models.
  • Missing Data Imputation: Choose from two imputation methods: Bagged Tree Imputation and KNN Imputation, which use the step_bag_impute() and step_knn_impute() functions from the recipes package, respectively. To prevent data leakage, the imputation model is fit on the training data only and is then used to impute missing predictor values in both the training and validation data. Rows with a missing target value are removed, and the target is excluded as a predictor during imputation.
  • Performance Metrics: View performance metrics in the console and generate or save plots of key metrics, including overall classification accuracy and per-class F-score, precision, and recall, for both the train-test split and k-fold cross-validation.
  • Automatic Numerical Encoding: Classes within the target variable are automatically encoded as integers for algorithms, such as Logistic Regression and Gradient Boosting Machines, that require a numeric target variable.
  • Parallel Processing: Set n_cores and future.seed in parallel_configs to process multiple folds simultaneously on the specified number of cores. Only available when cross-validation is specified.
  • Minimal Code Requirement: Access desired information quickly and efficiently with just a few lines of code.
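The leakage-prevention rule in the imputation bullet (fit the imputation model on the training data only, then apply it to both the training and validation data) can be illustrated without recipes. The base-R sketch below swaps in simple median imputation for the bagged-tree and KNN methods; `impute_median_fit()` and `impute_median_apply()` are hypothetical helpers, not part of vswift.

```r
# Hypothetical helpers illustrating leakage-free imputation.
# Median imputation stands in for recipes' bagged-tree / KNN imputation.

# Learn one median per numeric predictor from the training data only
impute_median_fit <- function(train, target) {
  predictors <- setdiff(names(train), target)   # the target is never a predictor
  numeric_cols <- predictors[vapply(train[predictors], is.numeric, logical(1))]
  vapply(train[numeric_cols], median, numeric(1), na.rm = TRUE)
}

# Fill missing predictor values using medians learned from the training data
impute_median_apply <- function(data, medians) {
  for (col in names(medians)) {
    missing <- is.na(data[[col]])
    data[[col]][missing] <- medians[[col]]
  }
  data
}

set.seed(1)
df <- iris
df$Sepal.Length[sample(nrow(df), 10)] <- NA   # inject missing predictor values
df$Species[3] <- NA                           # and one missing target value

df <- df[!is.na(df$Species), ]                # drop rows with a missing target
train_idx <- sample(nrow(df), floor(0.8 * nrow(df)))
train <- df[train_idx, ]
test  <- df[-train_idx, ]

medians <- impute_median_fit(train, target = "Species")  # fit on training data only
train <- impute_median_apply(train, medians)
test  <- impute_median_apply(test, medians)              # reuse the training medians
```

Because the medians never see the test rows, the test set cannot influence how its own missing values are filled.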


From the "main" branch:

# Install 'remotes' to install packages from GitHub
install.packages("remotes")

# Install 'vswift' package
remotes::install_github("donishadsmith/vswift/pkg/vswift", ref="main")
# Display documentation for the 'vswift' package
help(package = "vswift")

GitHub release:

# Install 'remotes' to install packages from GitHub
install.packages("remotes")

# Install 'vswift' package

# Display documentation for the 'vswift' package
help(package = "vswift")


The type of classification algorithm is specified using the models parameter in the classCV() function.

Acceptable inputs for the models parameter include:

  • "lda" for Linear Discriminant Analysis
  • "qda" for Quadratic Discriminant Analysis
  • "logistic" for Logistic Regression
  • "svm" for Support Vector Machines
  • "naivebayes" for Naive Bayes
  • "ann" for Artificial Neural Network
  • "knn" for K-Nearest Neighbors
  • "decisiontree" for Decision Trees
  • "randomforest" for Random Forest
  • "multinom" for Multinomial Regression
  • "gbm" for Gradient Boosting Machines
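Because a multi-model run can take several minutes, it can help to check the models vector against the accepted names before calling classCV(). A minimal base-R sketch; `check_models()` is a hypothetical helper, and the accepted names are copied from the list above.

```r
# Accepted values for the `models` parameter, as listed above
valid_models <- c("lda", "qda", "logistic", "svm", "naivebayes", "ann",
                  "knn", "decisiontree", "randomforest", "multinom", "gbm")

# Hypothetical helper: fail early on unknown model names
check_models <- function(models) {
  unknown <- setdiff(models, valid_models)
  if (length(unknown) > 0) {
    stop("Unknown model(s): ", paste(unknown, collapse = ", "), call. = FALSE)
  }
  invisible(models)
}

check_models(c("knn", "svm", "decisiontree", "gbm"))  # passes silently
try(check_models(c("knn", "xgboost")))                # signals an error naming "xgboost"
```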

Using a single model:

# Load the package
library(vswift)

# Perform train-test split and k-fold cross-validation with stratified sampling
results <- classCV(data = iris,
                   target = "Species",
                   models = "lda",
                   train_params = list(split = 0.8, n_folds = 5, stratified = TRUE, random_seed = 50))
# Also valid; the target variable can refer to the column index

results <- classCV(data = iris,
                   target = 5,
                   models = "lda",
                   train_params = list(split = 0.8, n_folds = 5, stratified = TRUE, random_seed = 50))

# Using the formula interface is also valid

results <- classCV(formula = Species ~ .,
                   data = iris,
                   models = "lda",
                   train_params = list(split = 0.8, n_folds = 5, stratified = TRUE, random_seed = 50))
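With stratified = TRUE, the train-test split and the folds keep the class proportions of the full data. A base-R sketch of the idea, assuming per-class sampling; `stratified_split()` and `stratified_folds()` are hypothetical helpers, and vswift's own sampling may differ in details such as rounding.

```r
# Hypothetical helper: sample the same fraction of rows from every class
stratified_split <- function(y, prop = 0.8, seed = 50) {
  set.seed(seed)
  idx_by_class <- split(seq_along(y), y)
  train_idx <- unlist(lapply(idx_by_class, function(idx) {
    idx[sample.int(length(idx), round(prop * length(idx)))]
  }), use.names = FALSE)
  sort(train_idx)
}

# Hypothetical helper: deal each class's shuffled rows round-robin into k folds
stratified_folds <- function(y, k = 5, seed = 50) {
  set.seed(seed)
  fold <- integer(length(y))
  for (idx in split(seq_along(y), y)) {
    shuffled <- idx[sample.int(length(idx))]
    fold[shuffled] <- rep_len(seq_len(k), length(shuffled))
  }
  fold
}

train_idx <- stratified_split(iris$Species, prop = 0.8)
folds <- stratified_folds(iris$Species, k = 5)

table(iris$Species[train_idx])  # 40 rows per class
table(folds, iris$Species)      # 10 rows per class in each fold
```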

classCV() returns a vswift object, which can be passed to print() and plot() for custom printing and plotting of performance metrics.
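This works through R's S3 dispatch: print() and plot() are generic functions that look at an object's class attribute and call the matching method, so a "vswift" object gets vswift-specific output. A toy illustration with a hypothetical class "toyfit" (not vswift's actual print method):

```r
# A print method for a hypothetical S3 class "toyfit"
print.toyfit <- function(x, ...) {
  cat("Model:", x$model, "\n")
  cat("Classification Accuracy:", format(x$accuracy, nsmall = 2), "\n")
  invisible(x)
}

obj <- structure(list(model = "lda", accuracy = 0.98), class = "toyfit")
print(obj)       # dispatches to print.toyfit()

class(obj) <- "list"
print(obj)       # no toyfit class anymore, so the default list printing is used
```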

# Check the class of the returned object
class(results)
[1] "vswift"
# Print parameter information and model evaluation metrics
print(results, parameters = TRUE, metrics = TRUE)


- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 

Model: Linear Discriminant Analysis 

Formula: Species ~ .

Number of Features: 4

Classes: setosa, versicolor, virginica

Training Parameters: list(split = 0.8, n_folds = 5, stratified = TRUE, random_seed = 50, standardize = FALSE, remove_obs = FALSE)

Model Parameters: list(map_args = NULL, final_model = FALSE)

Missing Data: 0

Effective Sample Size: 150

Imputation Parameters: list(method = NULL, args = NULL)

Parallel Configs: list(n_cores = NULL, future.seed = NULL)

 Training 
_ _ _ _ _ _ _ _ 

Classification Accuracy:  0.98 

Class:           Precision:  Recall:  F-Score:

setosa                1.00     1.00      1.00 
versicolor            1.00     0.95      0.97 
virginica             0.95     1.00      0.98 

 Test 
_ _ _ _ 

Classification Accuracy:  0.97 

Class:           Precision:  Recall:  F-Score:

setosa                1.00     1.00      1.00 
versicolor            0.91     1.00      0.95 
virginica             1.00     0.90      0.95 

 K-fold CV 
_ _ _ _ _ _ _ _ _ 

Average Classification Accuracy:  0.98 (0.04) 

Class:           Average Precision:  Average Recall:  Average F-score:

setosa               1.00 (0.00)       1.00 (0.00)       1.00 (0.00) 
versicolor           0.98 (0.05)       0.96 (0.09)       0.97 (0.07) 
virginica            0.96 (0.08)       0.98 (0.04)       0.97 (0.06) 

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
# Plot model evaluation metrics
plot(results, split = TRUE, cv = TRUE, save_plots = TRUE, path = getwd())
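The per-class precision, recall, and F-score reported above follow the standard one-vs-rest definitions. A base-R sketch, assuming a confusion matrix with predicted classes in rows; `class_metrics()` is a hypothetical helper, not vswift's internal function.

```r
# Hypothetical helper: accuracy plus per-class precision, recall and F-score
class_metrics <- function(actual, predicted) {
  classes <- levels(actual)
  predicted <- factor(predicted, levels = classes)
  cm <- table(predicted, actual)            # rows = predicted, columns = actual
  tp <- diag(cm)
  precision <- tp / rowSums(cm)             # TP / (TP + FP)
  recall    <- tp / colSums(cm)             # TP / (TP + FN)
  f_score   <- 2 * precision * recall / (precision + recall)
  list(accuracy = sum(tp) / sum(cm),
       by_class = data.frame(class = classes, precision, recall, f_score,
                             row.names = NULL))
}

# Toy example: 10 observations, one versicolor misclassified as virginica
actual <- factor(c(rep("setosa", 3), rep("versicolor", 4), rep("virginica", 3)))
predicted <- actual
predicted[4] <- "virginica"

m <- class_metrics(actual, predicted)
m$accuracy   # 0.9
m$by_class
```

A class that is never predicted has an empty row in the confusion matrix, so its precision comes out as NaN in this sketch.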

The predictors used can be selected with the predictors or formula parameters:

# Use knn on the iris dataset with the first, third, and fourth columns as predictors. The additional argument `ks = 5` is passed to train.kknn() from the kknn package

results <- classCV(data = iris,
                   target = "Species",
                   predictors = c("Sepal.Length","Petal.Length","Petal.Width"),
                   models = "knn",
                   train_params = list(split = 0.8, n_folds = 5, stratified = TRUE, random_seed = 50),
                   ks = 5)

# All configurations below are valid and will produce the same output

args <- list(knn = list(ks = 5))
results <- classCV(data = iris,
                   target = 5,
                   predictors = c(1,3,4),
                   models = "knn",
                   train_params = list(split = 0.8, n_folds = 5, stratified = TRUE, random_seed = 50),
                   model_params = list(map_args = args))

results <- classCV(formula = Species ~ Sepal.Length + Petal.Length + Petal.Width,
                   data = iris,
                   models = "knn",
                   train_params = list(split = 0.8, n_folds = 5, stratified = TRUE, random_seed = 50),
                   ks = 5)
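
# Print parameter information and model evaluation metrics
print(results, parameters = TRUE, metrics = TRUE)
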
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 

Model: K-Nearest Neighbors 

Formula: Species ~ Sepal.Length + Petal.Length + Petal.Width

Number of Features: 3

Classes: setosa, versicolor, virginica

Training Parameters: list(split = 0.8, n_folds = 5, stratified = TRUE, random_seed = 50, standardize = FALSE, remove_obs = FALSE)

Model Parameters: list(map_args = list(knn = list(ks = 5)), final_model = FALSE)

Missing Data: 0

Effective Sample Size: 150

Imputation Parameters: list(method = NULL, args = NULL)

Parallel Configs: list(n_cores = NULL, future.seed = NULL)

 Training 
_ _ _ _ _ _ _ _ 

Classification Accuracy:  0.97 

Class:           Precision:  Recall:  F-Score:

setosa                1.00     1.00      1.00 
versicolor            0.95     0.95      0.95 
virginica             0.95     0.95      0.95 

 Test 
_ _ _ _ 

Classification Accuracy:  0.97 

Class:           Precision:  Recall:  F-Score:

setosa                1.00     1.00      1.00 
versicolor            0.91     1.00      0.95 
virginica             1.00     0.90      0.95 

 K-fold CV 
_ _ _ _ _ _ _ _ _ 

Average Classification Accuracy:  0.96 (0.05) 

Class:           Average Precision:  Average Recall:  Average F-score:

setosa               1.00 (0.00)       1.00 (0.00)       1.00 (0.00) 
versicolor           0.92 (0.08)       0.96 (0.09)       0.94 (0.08) 
virginica            0.96 (0.09)       0.92 (0.08)       0.94 (0.08) 

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
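Selecting predictors by name, by column index, or through a formula describes the same model. Base R's reformulate() turns the predictor names into the formula reported in the output above; this only illustrates the equivalence and is not necessarily how vswift builds formulas internally.

```r
# Build the formula from a character vector of predictor names
predictors <- c("Sepal.Length", "Petal.Length", "Petal.Width")
f <- reformulate(predictors, response = "Species")
f  # Species ~ Sepal.Length + Petal.Length + Petal.Width

# Column indices, as in `predictors = c(1, 3, 4)`, refer to the same columns
identical(names(iris)[c(1, 3, 4)], predictors)  # TRUE
```

The resulting formula could then be passed to classCV() through its formula parameter.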

To display the contents of the vswift object, convert its class to a list and print it with base R's print() function:

class(results) <- "list"
print(results)
Species ~ Sepal.Length + Petal.Length + Petal.Width

[1] 3

[1] "knn"

[1] 5



[1] 0.8

[1] 5

[1] TRUE

[1] 50



[1] 0

[1] 150







[1] "setosa"     "versicolor" "virginica" 

    setosa versicolor  virginica 
 0.3333333  0.3333333  0.3333333 

       Set Classification Accuracy Class: setosa Precision Class: setosa Recall Class: setosa F-Score Class: versicolor Precision Class: versicolor Recall Class: versicolor F-Score
1 Training               0.9666667                       1                    1                     1                   0.9500000                     0.95                  0.950000
2     Test               0.9666667                       1                    1                     1                   0.9090909                     1.00                  0.952381
  Class: virginica Precision Class: virginica Recall Class: virginica F-Score
1                       0.95                    0.95                0.9500000
2                       1.00                    0.90                0.9473684

                    Fold Classification Accuracy Class: setosa Precision Class: setosa Recall Class: setosa F-Score Class: versicolor Precision Class: versicolor Recall
1                 Fold 1              0.86666667                       1                    1                     1                  0.80000000               0.80000000
2                 Fold 2              0.96666667                       1                    1                     1                  0.90909091               1.00000000
3                 Fold 3              1.00000000                       1                    1                     1                  1.00000000               1.00000000
4                 Fold 4              1.00000000                       1                    1                     1                  1.00000000               1.00000000
5                 Fold 5              0.96666667                       1                    1                     1                  0.90909091               1.00000000
6               Mean CV:              0.96000000                       1                    1                     1                  0.92363636               0.96000000
7 Standard Deviation CV:              0.05477226                       0                    0                     0                  0.08272228               0.08944272
8     Standard Error CV:              0.02449490                       0                    0                     0                  0.03699453               0.04000000
  Class: versicolor F-Score Class: virginica Precision Class: virginica Recall Class: virginica F-Score
1                0.80000000                 0.80000000              0.80000000               0.80000000
2                0.95238095                 1.00000000              0.90000000               0.94736842
3                1.00000000                 1.00000000              1.00000000               1.00000000
4                1.00000000                 1.00000000              1.00000000               1.00000000
5                0.95238095                 1.00000000              0.90000000               0.94736842
6                0.94095238                 0.96000000              0.92000000               0.93894737
7                0.08231349                 0.08944272              0.08366600               0.08201074
8                0.03681171                 0.04000000              0.03741657               0.03667632
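The last three rows of the fold table summarise the five folds: the mean, the sample standard deviation (n - 1 denominator), and the standard error (SD divided by the square root of the number of folds). Recomputing them from the Classification Accuracy column reproduces the printed values; the value in parentheses in the print() summary above (0.96 (0.05)) matches the standard deviation.

```r
# Classification accuracy of each fold, copied from the table above
fold_accuracy <- c(0.86666667, 0.96666667, 1.00000000, 1.00000000, 0.96666667)

k <- length(fold_accuracy)
mean_cv <- mean(fold_accuracy)   # Mean CV
sd_cv   <- sd(fold_accuracy)     # Standard Deviation CV (n - 1 denominator)
se_cv   <- sd_cv / sqrt(k)       # Standard Error CV

# Matches the table rows: 0.96, 0.05477226, 0.02449490
round(c(mean = mean_cv, sd = sd_cv, se = se_cv), 8)
```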

Using multiple models with parallel processing

Note: This example uses the internet advertisement data from the UCI Machine Learning Repository.
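Cross-validation folds are independent of each other, which is what allows them to be evaluated simultaneously. The sketch illustrates fold-level parallelism with base R's parallel package and a hypothetical nearest-centroid `evaluate_fold()` task; it is not vswift's implementation, which is configured through parallel_configs.

```r
library(parallel)

# Hypothetical per-fold task: nearest-centroid classifier on iris
evaluate_fold <- function(k, data, folds) {
  train <- data[folds != k, ]
  test  <- data[folds == k, ]
  # Class centroids computed from the training folds only
  centroids <- aggregate(. ~ Species, data = train, FUN = mean)
  x_test <- as.matrix(test[, 1:4])
  c_mat  <- as.matrix(centroids[, -1])
  # Squared Euclidean distance from each test row (rows) to each centroid (columns)
  d <- sapply(seq_len(nrow(c_mat)), function(j) colSums((t(x_test) - c_mat[j, ])^2))
  predicted <- centroids$Species[max.col(-d, ties.method = "first")]
  mean(as.character(predicted) == as.character(test$Species))
}

set.seed(50)
folds <- sample(rep_len(1:5, nrow(iris)))   # random 5-fold assignment

cl <- makeCluster(2)                        # 2 worker processes
clusterSetRNGStream(cl, iseed = 100)        # reproducible RNG streams on the workers
fold_accuracy <- parLapply(cl, 1:5, evaluate_fold, data = iris, folds = folds)
stopCluster(cl)

unlist(fold_accuracy)
```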

# Set URL for the internet advertisement data from the UCI Machine Learning Repository. This data has 3,278 instances and 1,558 attributes.

url <- ""

# Set file destination

dest_file <- file.path(getwd(),"")

# Download zip file
download.file(url = url, destfile = dest_file, mode = "wb")


# Unzip file

unzip(zipfile = dest_file , files = "")

# Read data

ad_data <- read.csv("")

# Load in vswift
library(vswift)


# Create arguments variable to tune parameters for multiple models
args <- list("knn" = list(ks = 5), 
             "gbm" = list(params = list(booster = "gbtree", objective = "multi:softmax",
                                        lambda = 0.0003, alpha = 0.0003, num_class = 2, eta = 0.8,
                                        max_depth = 6), nrounds = 10))

print("Without Parallel Processing:")

# Obtain start time

start <- proc.time()

# Run the models without parallel processing

results <- classCV(data = ad_data,
                   target = "ad.",
                   models = c("knn","svm","decisiontree","gbm"),
                   train_params = list(split = 0.8, n_folds = 5, random_seed = 50),
                   model_params = list(map_args = args))

# Get end time 
end <- proc.time() - start

# Print time
print(end)

print("Parallel Processing:")

# Adjust maximum object size that can be passed to workers during parallel processing; ~1.2 GB
options(future.globals.maxSize = 1200 * 1024^2)

# Obtain start time
start_par <- proc.time()

# Run model using parallel processing with 4 cores
results <- classCV(data = ad_data,
                   target = "ad.",
                   models = c("knn","svm","decisiontree","gbm"),
                   train_params = list(split = 0.8, n_folds = 5, random_seed = 50),
                   model_params = list(map_args = args),
                   parallel_configs = list(n_cores = 4, future.seed = 100))

# Obtain end time

end_par <- proc.time() - start_par

# Print time
print(end_par)


[1] "Without Parallel Processing:"

Warning message:
In .create_dictionary(preprocessed_data = preprocessed_data,  :
  classes are now encoded: ad. = 0, nonad. = 1

   user  system elapsed 
 336.93    1.97  344.62 

[1] "Parallel Processing:"

Warning message:
In .create_dictionary(preprocessed_data = preprocessed_data,  :
  classes are now encoded: ad. = 0, nonad. = 1

   user  system elapsed 
   1.48    9.16  188.28
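The warning shows the classes being mapped to integers starting at 0 (ad. = 0, nonad. = 1). xgboost's multi:softmax objective expects labels from 0 to num_class - 1, which is why such an encoding is needed for "gbm". A reversible base-R sketch, assuming classes are ordered by factor level (alphabetical by default); `encode_classes()` is a hypothetical helper.

```r
# Hypothetical helper: map class labels to 0-based integer codes
encode_classes <- function(y) {
  y <- factor(y)
  dictionary <- setNames(seq_along(levels(y)) - 1L, levels(y))
  list(codes = unname(dictionary[as.character(y)]), dictionary = dictionary)
}

enc <- encode_classes(c("nonad.", "ad.", "nonad.", "nonad."))
enc$dictionary   # ad. = 0, nonad. = 1
enc$codes        # 1 0 1 1

# Decode numeric predictions back to class labels
names(enc$dictionary)[enc$codes + 1L]
```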
# Print parameter information and model evaluation metrics; if the number of features exceeds 20, the target is shown instead of the formula
print(results, models = c("gbm", "knn"))


- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 

Model: Gradient Boosted Machine 

Target: ad.

Number of Features: 1558

Classes: ad., nonad.

Training Parameters: list(split = 0.8, n_folds = 5, random_seed = 50, stratified = FALSE, standardize = FALSE, remove_obs = FALSE)

Model Parameters: list(map_args = list(gbm = list(params = list(booster = "gbtree", objective = "multi:softmax", lambda = 3e-04, alpha = 3e-04, num_class = 2, eta = 0.8, max_depth = 6), nrounds = 10)), final_model = FALSE)

Missing Data: 0

Effective Sample Size: 3278

Imputation Parameters: list(method = NULL, args = NULL)

Parallel Configs: list(n_cores = 4, future.seed = 100)

 Training 
_ _ _ _ _ _ _ _ 

Classification Accuracy:  0.99 

Class:       Precision:  Recall:  F-Score:

ad.               0.99     0.96      0.97 
nonad.            0.99     1.00      1.00 

 Test 
_ _ _ _ 

Classification Accuracy:  0.98 

Class:       Precision:  Recall:  F-Score:

ad.               0.94     0.89      0.92 
nonad.            0.98     0.99      0.99 

 K-fold CV 
_ _ _ _ _ _ _ _ _ 

Average Classification Accuracy:  0.98 (0.01) 

Class:       Average Precision:  Average Recall:  Average F-score:

ad.              0.95 (0.02)       0.88 (0.04)       0.91 (0.02) 
nonad.           0.98 (0.01)       0.99 (0.00)       0.99 (0.00) 

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 

Model: K-Nearest Neighbors 

Target: ad.

Number of Features: 1558

Classes: ad., nonad.

Training Parameters: list(split = 0.8, n_folds = 5, random_seed = 50, stratified = FALSE, standardize = FALSE, remove_obs = FALSE)

Model Parameters: list(map_args = list(knn = list(ks = 5)), final_model = FALSE)

Missing Data: 0

Effective Sample Size: 3278

Imputation Parameters: list(method = NULL, args = NULL)

Parallel Configs: list(n_cores = 4, future.seed = 100)

 Training 
_ _ _ _ _ _ _ _ 

Classification Accuracy:  1.00 

Class:       Precision:  Recall:  F-Score:

ad.               1.00     0.99      1.00 
nonad.            1.00     1.00      1.00 

 Test 
_ _ _ _ 

Classification Accuracy:  0.96 

Class:       Precision:  Recall:  F-Score:

ad.               0.89     0.80      0.84 
nonad.            0.97     0.98      0.98 

 K-fold CV 
_ _ _ _ _ _ _ _ _ 

Average Classification Accuracy:  0.93 (0.01) 

Class:       Average Precision:  Average Recall:  Average F-score:

ad.              0.71 (0.07)       0.82 (0.01)       0.76 (0.04) 
nonad.           0.97 (0.00)       0.95 (0.02)       0.96 (0.01) 

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 

# Plot results

plot(results, models = "gbm", save_plots = TRUE,
     class_names = "ad.", metrics = c("precision", "recall"))

