# How to use random forest as a clustering method

If you ask me what my favorite machine learning algorithm is, I would answer logistic regression (with regularization: Lasso, Ridge, and Elastic Net), followed by random forest. In fact, that's the order in which we try those methods. Deep learning can perform well for tabular data with a complicated architecture, while random forest or boosted-tree-based methods usually work well out of the box. Regression and random forest are more interpretable, too.

We all know we can use random forest for classification or regression (supervised machine learning), but did you know you can use random forest for clustering too (unsupervised machine learning)?

When you have a dataset with mixed numeric and categorical variables, where it's not straightforward to define a distance between observations, random forest can be trained in an unsupervised manner to generate a proximity matrix.
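As a quick illustration (a sketch assuming the `randomForest` package; the tidymodels route used in the rest of this post fits a supervised model instead), the forest can be grown without any labels at all, and the proximities fall out directly:

```r
library(randomForest)

set.seed(123)

# unsupervised mode: with y = NULL, randomForest contrasts the real data
# against a synthetic copy (independently permuted columns) and grows
# trees to tell the two apart; proximities come from how often two
# observations land in the same terminal node
urf <- randomForest(x = iris[, 1:4], y = NULL, proximity = TRUE)

dim(urf$proximity)  # a 150 x 150 proximity matrix, one row/column per observation
```

The same `y = NULL` trick works for data frames with factor columns, which is exactly the mixed-type situation where an off-the-shelf distance is hard to define.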

The proximity between two observations is the fraction of trees in which they appear in the same terminal (leaf) node. So the higher the value, the closer the observations.
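In other words, if you had the terminal node id of every observation in every tree, the proximity would be a simple average. A toy sketch in base R (the node matrix here is made up purely for illustration):

```r
# toy sketch: proximity = fraction of trees in which two observations
# share the same terminal (leaf) node
# terminal_nodes[i, t] = leaf id of observation i in tree t
terminal_nodes <- matrix(
  c(1, 2, 1,   # observation 1
    1, 2, 2,   # observation 2
    3, 4, 2),  # observation 3
  nrow = 3, byrow = TRUE
)

proximity <- function(i, j, nodes) {
  mean(nodes[i, ] == nodes[j, ])
}

proximity(1, 2, terminal_nodes)  # share a leaf in 2 of 3 trees -> 0.667
proximity(1, 3, terminal_nodes)  # never share a leaf -> 0
```

An observation always has proximity 1 with itself, which is why the diagonal of the real proximity matrix below is all ones.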

This is pretty cool! I first got to know it from Josh Starmer’s StatQuest!

For datasets that are entirely numeric, the random forest step is not necessary. You can use distance/similarity metrics such as Euclidean and Manhattan (see `?dist` in R), or Mahalanobis.
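For example, base R's `dist()` computes these directly (Euclidean is the default; Mahalanobis is not in `dist()` but is available via `stats::mahalanobis`, which needs a covariance matrix):

```r
# two points: (0, 0) and (3, 4)
x <- matrix(c(0, 0,
              3, 4), nrow = 2, byrow = TRUE)

dist(x, method = "euclidean")  # 5  (the 3-4-5 triangle)
dist(x, method = "manhattan")  # 7  (|3 - 0| + |4 - 0|)
```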

### Use random forest for classification

Let’s use a real example with the iris dataset.

```r
library(tidymodels)

set.seed(123)

data <- iris
head(data)
#>   Sepal.Length Sepal.Width Petal.Length Petal.Width Species
#> 1          5.1         3.5          1.4         0.2  setosa
#> 2          4.9         3.0          1.4         0.2  setosa
#> 3          4.7         3.2          1.3         0.2  setosa
#> 4          4.6         3.1          1.5         0.2  setosa
#> 5          5.0         3.6          1.4         0.2  setosa
#> 6          5.4         3.9          1.7         0.4  setosa
```
```r
table(data$Species)
#> 
#>     setosa versicolor  virginica 
#>         50         50         50
```

It is a small dataset with 3 species of flowers: setosa, versicolor, and virginica. The features are the sepal/petal length/width.

Split the dataset into training and testing sets:

```r
data_split <- initial_split(data, strata = "Species")
data_train <- training(data_split)
data_test <- testing(data_split)
```

Build a random forest model using tidymodels to classify the species:

```r
rf_recipe <- recipe(formula = Species ~ ., data = data_train) %>%
  step_zv(all_predictors())

## set feature importance to TRUE, and the proximity matrix to TRUE
rf_spec <- rand_forest() %>%
  set_engine("randomForest", importance = TRUE, proximity = TRUE) %>%
  set_mode("classification")

rf_workflow <- workflow() %>%
  add_recipe(rf_recipe) %>%
  add_model(rf_spec)

rf_fit <- fit(rf_workflow, data = data_train)
```

We can use the model to do classification on the testing set:

```r
## confusion matrix
predict(rf_fit, new_data = data_test) %>%
  bind_cols(data_test %>% select(Species)) %>%
  conf_mat(truth = Species, estimate = .pred_class)
#>             Truth
#> Prediction   setosa versicolor virginica
#>   setosa         13          0         0
#>   versicolor      0         13         1
#>   virginica       0          0        12
```

Read my previous blog post on using random forest for scRNAseq marker gene identification.

### Use the proximity matrix for clustering

The proximity matrix is hidden deep in the nested list:

```r
proximity_mat <- rf_fit$fit$fit$fit$proximity

pca_prcomp <- prcomp(proximity_mat, center = TRUE, scale. = FALSE)
plot(pca_prcomp)

pca_df <- data.frame(pca_prcomp$x[, 1:2], Species = data_train$Species)

ggplot(pca_df, aes(x = PC1, y = PC2)) +
  geom_point(aes(color = Species))
```

Of course, we can use the original matrix for PCA too, because all the variables are numeric.
```r
head(data_train)
#>   Sepal.Length Sepal.Width Petal.Length Petal.Width Species
#> 3          4.7         3.2          1.3         0.2  setosa
#> 4          4.6         3.1          1.5         0.2  setosa
#> 5          5.0         3.6          1.4         0.2  setosa
#> 7          4.6         3.4          1.4         0.3  setosa
#> 8          5.0         3.4          1.5         0.2  setosa
#> 9          4.4         2.9          1.4         0.2  setosa

pca_prcomp2 <- prcomp(as.matrix(data_train[, -5]), center = TRUE, scale. = FALSE)
plot(pca_prcomp2)

pca_df2 <- data.frame(pca_prcomp2$x[, 1:2], Species = data_train$Species)

ggplot(pca_df2, aes(x = PC1, y = PC2)) +
  geom_point(aes(color = Species))
```

However, imagine that not all the variables are numeric; then we cannot easily make a PCA plot from the raw data. Instead, we can use random forest to get a proximity matrix and use that matrix for PCA, as shown above.

### Clustering using the proximity matrix

```r
dim(proximity_mat)
#> [1] 111 111

proximity_mat[1:5, 1:5]
#>           1         2       3         4        5
#> 1 1.0000000 0.9857143 1.00000 1.0000000 1.000000
#> 2 0.9857143 1.0000000 0.96875 0.9848485 0.974026
#> 3 1.0000000 0.9687500 1.00000 1.0000000 1.000000
#> 4 1.0000000 0.9848485 1.00000 1.0000000 1.000000
#> 5 1.0000000 0.9740260 1.00000 1.0000000 1.000000

rownames(proximity_mat) <- data_train[, 5]
colnames(proximity_mat) <- data_train[, 5]

# turn the proximity (a similarity) into a distance
iris_distance <- dist(1 - proximity_mat)

# hierarchical clustering
hc <- hclust(iris_distance)

# cut the tree into 3 clusters
clusters <- cutree(hc, k = 3)

library(dendextend)
library(dendsort)
library(Polychrome)

mypal <- kelly.colors(4)[-1]
swatch(mypal)

plot_dend <- function(dend, ...) {
  dend_labels <- dend %>% labels()
  dend %>%
    color_labels(col = mypal[as.numeric(as.factor(dend_labels))]) %>%
    set("labels_cex", 0.7) %>%
    plot(...)
}
```

Plot the dendrogram:

```r
# no sorting on the dendrogram
as.dendrogram(hc) %>% plot_dend()

# sort the dendrogram
as.dendrogram(hc) %>% dendsort() %>% plot_dend()
```

Visualize the clusters in the PCA plot built from the proximity matrix. We see there are some misclassifications between versicolor and virginica.

```r
pca_df <- data.frame(pca_prcomp$x[, 1:2],
                     Species = data_train$Species,
                     clusters = as.factor(clusters))

ggplot(pca_df, aes(x = PC1, y = PC2)) +
  geom_point(aes(color = Species, shape = clusters)) +
  theme_bw(base_size = 14)
```

Visualize the clusters in the PCA plot built from the raw data:

```r
pca_df2 <- data.frame(pca_prcomp2$x[, 1:2],
                      Species = data_train$Species,
                      clusters = as.factor(clusters))

ggplot(pca_df2, aes(x = PC1, y = PC2)) +
  geom_point(aes(color = Species, shape = clusters)) +
  theme_bw(base_size = 14)
```

PS. If you want to learn more about clustering, heatmaps, and PCA, my book has full chapters devoted to those topics. Grab a copy here to become a data master!