Split a Data Frame into Testing and Training Sets in R

I recently analyzed some data trying to find a model that would explain body fat distribution as predicted by several blood biomarkers. I had more predictors than samples (p>n), and I didn't have a clue which variables, interactions, or quadratic terms made biological sense to put into a model.

I then turned to a few data mining procedures that I learned about during grad school but never really used (LASSO, Random Forest, support vector machines, etc.). So far, Random Forest is working unbelievably well. The bootstrapping and aggregation ("bagging," i.e. the random component of Random Forest) avoids overfitting so well that I'm able to explain about 80% of the variation in an unseen sample using a model derived from only 30 training samples. (This paper offers the best explanation of Random Forest I've come across.)

While doing this I needed to write an R function to split up a dataset into training and testing sets so I could train models on one half and test them on unseen data. I'm sure a function already exists to do something similar, but it was trivial enough to write a function to do it myself.

This function takes a data frame and returns two data frames (as a list), one called trainset and one called testset, each holding a random half of the rows. Passing a seed makes the split reproducible.

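splitdf <- function(dataframe, seed=NULL) {
    # Optionally fix the random seed so the split is reproducible
    if (!is.null(seed)) set.seed(seed)
    # seq_len() is safe even when the data frame has zero rows
    index <- seq_len(nrow(dataframe))
    # Randomly pick half of the rows (rounded down) for the training set
    trainindex <- sample(index, trunc(length(index)/2))
    trainset <- dataframe[trainindex, ]
    # Every remaining row goes to the test set; setdiff() also works when
    # trainindex is empty, where negative indexing would drop all rows
    testset <- dataframe[setdiff(index, trainindex), ]
    list(trainset=trainset,testset=testset)
}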
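
As a quick check of how the returned list is used, here is a minimal sketch on the built-in iris data. A compact version of splitdf is included so the snippet runs on its own; the seed value 42 and the variable names splits, train, and test are arbitrary choices for illustration.

```r
# Compact version of splitdf, included so this snippet is self-contained
splitdf <- function(dataframe, seed=NULL) {
    if (!is.null(seed)) set.seed(seed)
    index <- seq_len(nrow(dataframe))
    trainindex <- sample(index, trunc(length(index)/2))
    list(trainset=dataframe[trainindex, ],
         testset=dataframe[setdiff(index, trainindex), ])
}

# iris has 150 rows, so each half should hold 75 of them
splits <- splitdf(iris, seed=42)
train <- splits$trainset
test <- splits$testset

nrow(train)
nrow(test)

# Row names survive subsetting, so they can be used to confirm
# that no row ended up in both sets
length(intersect(rownames(train), rownames(test)))
```

Calling splitdf again with the same seed returns the same split, which is handy when comparing several models on identical training and testing sets.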

In R, you can generally fit a model doing something like this:

mymodel <- method(y~x, data=mydata)

...and then predict the outcome for new data using the generic predict function:

predvals <- predict(mymodel, newdataframe)
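
For a concrete instance of this pattern, the sketch below uses lm() from base R on the built-in cars data. The choice of lm, the dataset, and the new speed values are illustrative assumptions only; any modeling function with a predict method follows the same shape.

```r
# Fit a linear model of stopping distance on speed (built-in cars data)
mymodel <- lm(dist ~ speed, data=cars)

# New observations need the same predictor column names as the training data
newdataframe <- data.frame(speed=c(10, 15, 20))

# predict() dispatches to predict.lm here and returns one value per new row
predvals <- predict(mymodel, newdataframe)
predvals
```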

Here's some R code that uses the built-in iris data, splits the dataset into training and testing sets, and develops a model to predict sepal length based on every other variable in the dataset using Random Forest.
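
A minimal sketch of that workflow follows. It assumes the randomForest package from CRAN is installed; the seed, the variable names, and the use of squared correlation as the accuracy measure are illustrative choices rather than a definitive implementation.

```r
# Assumption: the randomForest package is installed
# (install.packages("randomForest") if it is not)
library(randomForest)

# Compact version of splitdf, included so this snippet is self-contained
splitdf <- function(dataframe, seed=NULL) {
    if (!is.null(seed)) set.seed(seed)
    index <- seq_len(nrow(dataframe))
    trainindex <- sample(index, trunc(length(index)/2))
    list(trainset=dataframe[trainindex, ],
         testset=dataframe[setdiff(index, trainindex), ])
}

# Split iris into two random halves (the seed is arbitrary)
splits <- splitdf(iris, seed=808)
training <- splits$trainset
testing <- splits$testset

# Model sepal length as a function of every other column
model <- randomForest(Sepal.Length ~ ., data=training, importance=TRUE)

# Predict sepal length for the rows the model has never seen
predicted <- predict(model, newdata=testing)

# Squared correlation between observed and predicted values
# in the held-out half, as a rough measure of fit
rsq <- cor(testing$Sepal.Length, predicted)^2
rsq
```

Because the model is judged only on rows it never saw during fitting, the resulting value is a fairer estimate of predictive performance than a fit statistic computed on the training half.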

*Edit 2011-02-25* Thanks for all the comments. Clearly the split() function does something very similar to this, and the createDataPartition() function in the caret package does this too.
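
For comparison, a similar half-and-half split can be sketched with base R's split() by grouping rows on a randomly shuffled label vector. The labels "train" and "test", the seed, and the variable names are arbitrary choices for illustration.

```r
# Arbitrary seed so the shuffle is reproducible
set.seed(1)

# Build one label per row, alternating train/test, then shuffle the labels
n <- nrow(iris)
labels <- sample(rep(c("train", "test"), length.out=n))

# split() returns a list of data frames, one per distinct label
parts <- split(iris, labels)

nrow(parts$train)
nrow(parts$test)
```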