Cancer RNA-Seq Data Clustering#
In this notebook, we will learn to cluster data, i.e., to divide the data points into distinct groups, so that there is relatively small variation within a group and larger variation between groups. This task belongs to “unsupervised learning”, because the training data is not labeled. Compared to classification where our goal is to match the known answers, here we try to find patterns in the data without additional information.
The dataset that we will use as our example is the gene expression cancer RNA-Seq dataset. It contains the expression levels of 20531 genes from 801 patients with different types of tumor. Our goal is to analyze the data and cluster them into groups, in the hope that these groups correspond to different tumor types. The dataset in fact comes with labels: the patients were diagnosed with 5 types of tumor, namely BRCA, KIRC, COAD, LUAD and PRAD. But when we analyze the data, we will pretend that the diagnoses are not known (or perhaps not all correct) and see how well we can figure them out.
Load data#
The dataset can be loaded and printed using the pandas package, a popular package for data analysis and manipulation. Data loaded with pandas come as DataFrames, which can be operated on by numpy just like normal arrays (in most cases).
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
data = pd.read_csv('data/cancer-data.csv', header=0, index_col=0) # load data, using row-0 as column names and column-0 as row names
labels = pd.read_csv('data/cancer-labels.csv', header=0, index_col=0) # load labels, will use later to check results
data # print data
| | gene_0 | gene_1 | gene_2 | gene_3 | gene_4 | gene_5 | gene_6 | gene_7 | gene_8 | gene_9 | ... | gene_20521 | gene_20522 | gene_20523 | gene_20524 | gene_20525 | gene_20526 | gene_20527 | gene_20528 | gene_20529 | gene_20530 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
sample_0 | 0.0 | 2.017209 | 3.265527 | 5.478487 | 10.431999 | 0.0 | 7.175175 | 0.591871 | 0.0 | 0.0 | ... | 4.926711 | 8.210257 | 9.723516 | 7.220030 | 9.119813 | 12.003135 | 9.650743 | 8.921326 | 5.286759 | 0.000000 |
sample_1 | 0.0 | 0.592732 | 1.588421 | 7.586157 | 9.623011 | 0.0 | 6.816049 | 0.000000 | 0.0 | 0.0 | ... | 4.593372 | 7.323865 | 9.740931 | 6.256586 | 8.381612 | 12.674552 | 10.517059 | 9.397854 | 2.094168 | 0.000000 |
sample_2 | 0.0 | 3.511759 | 4.327199 | 6.881787 | 9.870730 | 0.0 | 6.972130 | 0.452595 | 0.0 | 0.0 | ... | 5.125213 | 8.127123 | 10.908640 | 5.401607 | 9.911597 | 9.045255 | 9.788359 | 10.090470 | 1.683023 | 0.000000 |
sample_3 | 0.0 | 3.663618 | 4.507649 | 6.659068 | 10.196184 | 0.0 | 7.843375 | 0.434882 | 0.0 | 0.0 | ... | 6.076566 | 8.792959 | 10.141520 | 8.942805 | 9.601208 | 11.392682 | 9.694814 | 9.684365 | 3.292001 | 0.000000 |
sample_4 | 0.0 | 2.655741 | 2.821547 | 6.539454 | 9.738265 | 0.0 | 6.566967 | 0.360982 | 0.0 | 0.0 | ... | 5.996032 | 8.891425 | 10.373790 | 7.181162 | 9.846910 | 11.922439 | 9.217749 | 9.461191 | 5.110372 | 0.000000 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
sample_796 | 0.0 | 1.865642 | 2.718197 | 7.350099 | 10.006003 | 0.0 | 6.764792 | 0.496922 | 0.0 | 0.0 | ... | 6.088133 | 9.118313 | 10.004852 | 4.484415 | 9.614701 | 12.031267 | 9.813063 | 10.092770 | 8.819269 | 0.000000 |
sample_797 | 0.0 | 3.942955 | 4.453807 | 6.346597 | 10.056868 | 0.0 | 7.320331 | 0.000000 | 0.0 | 0.0 | ... | 6.371876 | 9.623335 | 9.823921 | 6.555327 | 9.064002 | 11.633422 | 10.317266 | 8.745983 | 9.659081 | 0.000000 |
sample_798 | 0.0 | 3.249582 | 3.707492 | 8.185901 | 9.504082 | 0.0 | 7.536589 | 1.811101 | 0.0 | 0.0 | ... | 5.719386 | 8.610704 | 10.485517 | 3.589763 | 9.350636 | 12.180944 | 10.681194 | 9.466711 | 4.677458 | 0.586693 |
sample_799 | 0.0 | 2.590339 | 2.787976 | 7.318624 | 9.987136 | 0.0 | 9.213464 | 0.000000 | 0.0 | 0.0 | ... | 5.785237 | 8.605387 | 11.004677 | 4.745888 | 9.626383 | 11.198279 | 10.335513 | 10.400581 | 5.718751 | 0.000000 |
sample_800 | 0.0 | 2.325242 | 3.805932 | 6.530246 | 9.560367 | 0.0 | 7.957027 | 0.000000 | 0.0 | 0.0 | ... | 6.403075 | 8.594354 | 10.243079 | 9.139459 | 10.102934 | 11.641081 | 10.607358 | 9.844794 | 4.550716 | 0.000000 |
801 rows × 20531 columns
labels # print labels
| | Class |
|---|---|
sample_0 | PRAD |
sample_1 | LUAD |
sample_2 | PRAD |
sample_3 | PRAD |
sample_4 | BRCA |
... | ... |
sample_796 | BRCA |
sample_797 | LUAD |
sample_798 | COAD |
sample_799 | PRAD |
sample_800 | PRAD |
801 rows × 1 columns
Our data is an (801, 20531) array, where each row corresponds to one patient and each column represents one gene. The entries are non-negative numbers that are readings from the RNA-Seq measurements. We can think of each data point as a 20531-dimensional vector representing a particular patient, who belongs to one of the tumor types. Thus, our goal is to separate the 801 data points into different categories. This may sound like the classification problem we did before, except that here we do not know what the categories are, or even how many there are.
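As a quick sanity check (an illustrative snippet using the data and labels variables loaded above), we can confirm these dimensions directly:
print(data.shape)   # (801, 20531): one row per patient, one column per gene
print(labels.shape) # (801, 1): one diagnosis per patient
print(data.to_numpy().min()) # the expression levels are non-negative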
Dimensionality reduction#
When dealing with such high-dimensional data, it is a good idea to first reduce the dimensionality, for example by using Principal Component Analysis (PCA). The goal is to find a small number of principal components that capture most of the variation among the data points. Imagine that there is one gene that is not expressed in any patient; then this gene (or the dimension it represents) is not useful at all for distinguishing the data points. Conversely, we are interested in finding the directions in the data space along which the data points vary the most.
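To illustrate this point, here is a small check (illustrative only; the variable name gene_var is our own) that counts how many genes have essentially zero variance across all patients, i.e., dimensions that cannot help to distinguish the data points:
gene_var = data.to_numpy().var(axis=0) # variance of each gene across the 801 patients
print(np.sum(np.isclose(gene_var, 0))) # number of genes with (essentially) no variation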
Recall that, to perform PCA, we first calculate the covariance matrix of the data. Instead of calculating this by hand, we can use the function numpy.cov. Since it expects a 2-d array in which each row is a variable and each column is a data point, we have to transpose our data.
cov = np.cov(data.T) # calculate the covariance matrix
Then we need to find its eigenvectors that correspond to the largest eigenvalues. We can calculate the eigenvalues and eigenvectors using the numpy.linalg.eigh function as before. However, this function calculates all eigenvalues and eigenvectors, and there are too many (20531) of them here. Since we are interested in only the largest few eigenvalues, we can instead use the scipy.sparse.linalg.eigs function (or the eigsh function for symmetric matrices) as follows.
import scipy.sparse.linalg as spla
w, v = spla.eigsh(cov, k=50) # calculate the largest k eigenvalues and their eigenvectors
Here we calculated the largest 50 eigenvalues and their eigenvectors. Recall that each eigenvector represents a principal component (PC), and the corresponding eigenvalue represents the variance along that direction. Let us find the directions with the largest variance and see how many we need to capture most of the total variance. (Note that we need to normalize the eigenvalues by the total variance, which is the sum of all eigenvalues. But because here we did not calculate all eigenvalues, we cannot use the sum of only these eigenvalues. Instead, the total variance can be calculated as the trace of the covariance matrix.)
order = np.argsort(w)[::-1] # argsort gives ascending order, [::-1] reverses the order
w = w[order] # order the eigenvalues
v = v[:,order] # each column is an eigenvector, so we order the columns
wnorm = w / np.trace(cov) # normalized eigenvalue = fraction of total variance captured
wsum = np.cumsum(wnorm) # cumulative sum of variance captured
plt.figure()
plt.plot(wsum, '.-')
plt.ylim(0, 1)
plt.xlabel('# principal components', fontsize=24)
plt.ylabel('variance captured', fontsize=24)
plt.show()
We see that the captured fraction of variance increases rapidly with the first few eigenvalues, then slows down and almost plateaus (it will slowly increase to 1 when all 20531 eigenvalues are included). Let us pick, say, \(K = 10\) principal components. We will project the original data onto these 10 components as follows.
K = 10 # number of principal components to use
projected = np.dot(data, v[:,:K]) # project data points onto principal components
We can try to visualize these principal components, although it is hard to make plots in more than 2 or 3 dimensions. Since the first few components capture the most variance, we may hope that plotting 2 or 3 of them would be enough. It turns out that the first and third PCs give a pretty good impression of the data.
plt.figure(figsize=(6,6))
plt.scatter(projected[:,0], projected[:,2], s=5)
plt.xlabel('PC1')
plt.ylabel('PC3')
plt.show()
We see that the data points are naturally separated into 5 clusters. To improve visualization, we can color the data points according to their labels (which we have pretended not to know). For that we need a bit of processing of the labels since they were given as strings but we want indices.
types = set(labels['Class']) # collect all tumor types
nt = len(types) # number of types
print(types)
{'BRCA', 'COAD', 'KIRC', 'LUAD', 'PRAD'}
These abbreviations stand for breast cancer (BRCA), kidney cancer (KIRC), colon cancer (COAD), lung cancer (LUAD) and prostate cancer (PRAD).
type_to_index = dict(zip(types, range(nt))) # map string labels to indices
indices = labels['Class'].map(type_to_index).to_numpy() # convert labels to indices for all data points
We can now plot the data points in 3D with colors.
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import animation as anim
plt.rcParams["animation.html"] = "jshtml"
cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']
colors = np.asarray(cycle)[indices]
fig = plt.figure(figsize=(8,6))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
ax = fig.add_subplot(projection="3d")
ax.scatter(projected[:,0], projected[:,1], projected[:,2], s=5, c=colors)
ax.set_xlabel(r'PC1')
ax.set_ylabel(r'PC2')
ax.set_zlabel(r'PC3')
fig.set_facecolor('w')
def animate(i):
    ax.view_init(elev=10., azim=i*10)
    return fig,
mov = anim.FuncAnimation(fig, animate, frames=36)
# mov.save('source/cancer.gif', fps=5)
plt.close()
mov
We see that the clusters match very well with the colors. It shows that the first few principal components successfully capture the main features of the dataset. Therefore, we will use these principal components to cluster the data points, again pretending that we do not know the labels (colors).
Clustering#
Our goal is to separate the data points into multiple clusters according to their relative positions in the data space (or the reduced low-dimensional space). We would like the data points within the same cluster to be close together, and different clusters to be relatively far apart. The number of clusters needed to separate the data points is often not known beforehand, so finding the appropriate number of clusters is part of the task. In simple clustering algorithms, such as k-means that we will use, the number of clusters is chosen by hand and given to the algorithm. More sophisticated algorithms may select this number automatically according to certain criteria.
The k-means algorithm aims at minimizing the squared distances between all pairs of data points within the same cluster (normalized by the cluster sizes), i.e.,

\[ \mathcal{L} = \sum_{i=1}^{k} \frac{1}{2 N_i} \sum_{\mathbf{X}_n, \mathbf{X}_m \in S_i} \left| \mathbf{X}_n - \mathbf{X}_m \right|^2 , \]

where \(S_i\) is the set of data points belonging to the \(i\)-th cluster for \(i = 1, \cdots, k\), and \(N_i\) is the number of data points in \(S_i\). Each data point \(\mathbf{X}_n\) is a vector in the data space, and \(|\mathbf{X}_n - \mathbf{X}_m|\) is the Euclidean distance between two such points. This cost function can be equivalently expressed as the sum of squared distances from every data point to the center of its own cluster, i.e.,

\[ \mathcal{L} = \sum_{i=1}^{k} \sum_{\mathbf{X}_n \in S_i} \left| \mathbf{X}_n - \boldsymbol{\mu}_i \right|^2 , \qquad \boldsymbol{\mu}_i = \frac{1}{N_i} \sum_{\mathbf{X}_n \in S_i} \mathbf{X}_n , \]

where \(\boldsymbol{\mu}_i\) denotes the center (mean position) of the \(i\)-th cluster.
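As a quick numerical sanity check of this equivalence (illustrative only, using random points as a single cluster), we can compare the two expressions directly:
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3)) # 100 random points treated as one cluster
N = len(X)
pairwise = np.sum((X[:, None, :] - X[None, :, :])**2) / (2 * N) # pairwise form, summed over ordered pairs
centered = np.sum((X - X.mean(axis=0))**2) # sum of squared distances to the cluster center
print(np.isclose(pairwise, centered)) # True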
This minimization problem is not easy in the sense that the global minimum is hard to find. However, the k-means algorithm quickly converges to a local minimum, which may be good enough.
The algorithm requires the number of clusters, \(k\), as an input. One may also provide an initial guess for the center position of every cluster. If not given, the algorithm will use random initial positions for the cluster centers. Because of this randomness and the fact that it only finds the local minimum, running the algorithm twice may give different results. So you may have to try a few times and check if the clustering result is satisfactory.
The heuristic idea behind the algorithm is to iterate over two simple steps. The first step starts with the current guess for the center position of every cluster, and reassigns all data points to their nearest center. The second step then collects the data points now assigned to each cluster, and recalculates their center position. The algorithm alternates between these two steps until no reassignment happens, which means the result has converged.
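For concreteness, here is a minimal sketch of these two steps in plain numpy (illustrative only; the function simple_kmeans and its arguments are our own choices, and empty clusters are not handled). Below we will use the library routine instead.
def simple_kmeans(X, centers, max_iter=100):
    # X: (n_points, n_dims) data; centers: (k, n_dims) initial guess for the cluster centers
    for _ in range(max_iter):
        # step 1: assign every data point to its nearest center
        dists = np.sum((X[:, None, :] - centers[None, :, :])**2, axis=2) # (n_points, k) squared distances
        assign = np.argmin(dists, axis=1)
        # step 2: recompute each center as the mean of the points assigned to it
        new_centers = np.array([X[assign == i].mean(axis=0) for i in range(len(centers))])
        if np.allclose(new_centers, centers): # centers no longer move, so assignments have converged
            break
        centers = new_centers
    return centers, assign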
Luckily, we don’t have to program this algorithm ourselves, since it is already provided by the Python function scipy.cluster.vq.kmeans2. In the following, we will use this function to cluster the dimensionally reduced data that we got from PCA above. Also, since we saw that the data points seem to separate nicely into 5 clusters (in agreement with the 5 diagnosed tumor types, which we pretend not to know), we will ask the algorithm to find \(k = 5\) clusters.
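As a rough check on this choice of k (a sketch, not part of the original analysis; the variable names ks and wss are our own), one can run the algorithm for several values of k and plot the within-cluster sum of squares. The “elbow” where the curve stops decreasing rapidly suggests a reasonable number of clusters.
import scipy.cluster.vq as vq
ks = range(2, 10)
wss = []
for kk in ks:
    cents, labs = vq.kmeans2(projected, kk, minit='++', seed=0) # '++' initialization; the seed argument requires a recent scipy
    wss.append(np.sum((projected - cents[labs])**2)) # within-cluster sum of squares
plt.figure()
plt.plot(list(ks), wss, '.-')
plt.xlabel('number of clusters k')
plt.ylabel('within-cluster sum of squares')
plt.show()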
import scipy.cluster.vq as vq
k = 5 # choose number of clusters
centroid, lab = vq.kmeans2(projected, k) # perform k-means clustering
The two outputs are: centroid, a 2-d array in which each row is the center position of a cluster, and lab, a 1-d array in which each entry is an index telling us which cluster the corresponding data point belongs to. We can use the lab indices to color our data points and plot the clustering result. (Note that the clusters are ordered arbitrarily by kmeans2, so below we reorder them to match the “true labels” we saw above.)
colormap = []
for cent in centroid:
    med = np.argmin(np.sum((projected - cent)**2, axis=1)) # find a point close to each center
    colormap.append(indices[med]) # find the true label for that point and use its index for the cluster
newindices = np.asarray(colormap)[lab] # convert to new indices for all points
newcolors = np.asarray(cycle)[newindices] # reorder labels
plt.figure(figsize=(6,6))
plt.scatter(projected[:,0], projected[:,2], s=3, c=newcolors)
plt.xlabel('PC1')
plt.ylabel('PC3')
plt.show()
We see that the colors match very well with the clusters. You may notice some points that seem to wander into another cluster — that is because we are only plotting two dimensions, and the points that look overlapping may very well be separated in some other dimensions. Try plotting some other principal components yourself.
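For example, here is the same coloring in the PC2 vs PC3 plane (any other pair of components can be substituted):
plt.figure(figsize=(6,6))
plt.scatter(projected[:,1], projected[:,2], s=3, c=newcolors)
plt.xlabel('PC2')
plt.ylabel('PC3')
plt.show()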
Finally, let us see how well we have clustered the data compared to the true labels. Let us highlight the points that are wrongly labeled.
wrong = (newindices != indices)
plt.figure(figsize=(6,6))
plt.scatter(projected[:,0], projected[:,2], s=3, c=newcolors)
plt.scatter(projected[wrong,0], projected[wrong,2], s=100, c='k', marker='X')
plt.xlabel('PC1')
plt.ylabel('PC3')
plt.show()
It turns out that we only made 4 mistakes out of 801 data points. Also, it can be seen that the wrongly labeled data points lie near the boundary of the clusters, which makes the mistakes quite understandable.
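We can also count the mismatches directly (the exact number may vary slightly between runs, because of the random initialization of k-means):
print(np.sum(wrong), 'mismatches out of', len(wrong), 'data points')
print('agreement with the diagnoses: {:.3f}'.format(1 - np.sum(wrong) / len(wrong)))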