{ "cells": [ { "cell_type": "markdown", "id": "improved-explorer", "metadata": { "slideshow": { "slide_type": "slide" }, "tags": [] }, "source": [ "# Cancer RNA-Seq Data Clustering" ] }, { "cell_type": "markdown", "id": "christian-experience", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "In this notebook, we will learn to cluster data, i.e., to divide the data points into distinct groups so that there is relatively small variation within each group and larger variation between groups. This task is an example of \"unsupervised learning\", because the training data are not labeled. Unlike classification, where the goal is to reproduce known answers, here we try to find patterns in the data without any additional information." ] }, { "cell_type": "markdown", "id": "conventional-parliament", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "The dataset that we will use as our example is the [gene expression cancer RNA-Seq dataset](https://archive.ics.uci.edu/ml/datasets/gene+expression+cancer+RNA-Seq). It contains the expression levels of 20531 genes from 801 patients with different types of tumor. Our goal is to analyze the data and cluster them into groups that may correspond to different tumor types. The dataset does in fact come with labels: the patients were diagnosed with 5 types of tumor, namely BRCA, KIRC, COAD, LUAD, and PRAD. When we analyze the data, however, we will pretend that the diagnoses are unknown (or possibly not all correct) and see how well we can recover them." ] }, { "cell_type": "markdown", "id": "golden-combine", "metadata": { "slideshow": { "slide_type": "slide" }, "tags": [] }, "source": [ "## Load data" ] }, { "cell_type": "markdown", "id": "delayed-porter", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "The dataset can be loaded and displayed with the `pandas` package, a popular library for data analysis and manipulation. `pandas` stores the loaded data as `DataFrame`s, which in most cases can be passed to `numpy` functions just like ordinary arrays." ] }, { "cell_type": "code", "execution_count": 1, "id": "solved-nothing", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 2, "id": "lined-resistance", "metadata": {}, "outputs": [], "source": [ "data = pd.read_csv('data/cancer-data.csv', header=0, index_col=0) # load data, using row 0 as column names and column 0 as row names\n", "labels = pd.read_csv('data/cancer-labels.csv', header=0, index_col=0) # load labels, used later to check the results" ] }, { "cell_type": "code", "execution_count": 3, "id": "flush-demand", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " gene_0 gene_1 gene_2 gene_3 gene_4 gene_5 gene_6 \\\n", "sample_0 0.0 2.017209 3.265527 5.478487 10.431999 0.0 7.175175 \n", "sample_1 0.0 0.592732 1.588421 7.586157 9.623011 0.0 6.816049 \n", "sample_2 0.0 3.511759 4.327199 6.881787 9.870730 0.0 6.972130 \n", "sample_3 0.0 3.663618 4.507649 6.659068 10.196184 0.0 7.843375 \n", "sample_4 0.0 2.655741 2.821547 6.539454 9.738265 0.0 6.566967 \n", "... ... ... ... ... ... ... ... \n", "sample_796 0.0 1.865642 2.718197 7.350099 10.006003 0.0 6.764792 \n", "sample_797 0.0 3.942955 4.453807 6.346597 10.056868 0.0 7.320331 \n", "sample_798 0.0 3.249582 3.707492 8.185901 9.504082 0.0 7.536589 \n", "sample_799 0.0 2.590339 2.787976 7.318624 9.987136 0.0 9.213464 \n", "sample_800 0.0 2.325242 3.805932 6.530246 9.560367 0.0 7.957027 \n", "\n", " gene_7 gene_8 gene_9 ... gene_20521 gene_20522 gene_20523 \\\n", "sample_0 0.591871 0.0 0.0 ... 4.926711 8.210257 9.723516 \n", "sample_1 0.000000 0.0 0.0 ... 4.593372 7.323865 9.740931 \n", "sample_2 0.452595 0.0 0.0 ... 5.125213 8.127123 10.908640 \n", "sample_3 0.434882 0.0 0.0 ... 6.076566 8.792959 10.141520 \n", "sample_4 0.360982 0.0 0.0 ... 5.996032 8.891425 10.373790 \n", "... ... ... ... ... ... ... ... \n", "sample_796 0.496922 0.0 0.0 ... 6.088133 9.118313 10.004852 \n", "sample_797 0.000000 0.0 0.0 ... 6.371876 9.623335 9.823921 \n", "sample_798 1.811101 0.0 0.0 ... 5.719386 8.610704 10.485517 \n", "sample_799 0.000000 0.0 0.0 ... 5.785237 8.605387 11.004677 \n", "sample_800 0.000000 0.0 0.0 ... 6.403075 8.594354 10.243079 \n", "\n", " gene_20524 gene_20525 gene_20526 gene_20527 gene_20528 \\\n", "sample_0 7.220030 9.119813 12.003135 9.650743 8.921326 \n", "sample_1 6.256586 8.381612 12.674552 10.517059 9.397854 \n", "sample_2 5.401607 9.911597 9.045255 9.788359 10.090470 \n", "sample_3 8.942805 9.601208 11.392682 9.694814 9.684365 \n", "sample_4 7.181162 9.846910 11.922439 9.217749 9.461191 \n", "... ... ... ... ... ... 
\n", "sample_796 4.484415 9.614701 12.031267 9.813063 10.092770 \n", "sample_797 6.555327 9.064002 11.633422 10.317266 8.745983 \n", "sample_798 3.589763 9.350636 12.180944 10.681194 9.466711 \n", "sample_799 4.745888 9.626383 11.198279 10.335513 10.400581 \n", "sample_800 9.139459 10.102934 11.641081 10.607358 9.844794 \n", "\n", " gene_20529 gene_20530 \n", "sample_0 5.286759 0.000000 \n", "sample_1 2.094168 0.000000 \n", "sample_2 1.683023 0.000000 \n", "sample_3 3.292001 0.000000 \n", "sample_4 5.110372 0.000000 \n", "... ... ... \n", "sample_796 8.819269 0.000000 \n", "sample_797 9.659081 0.000000 \n", "sample_798 4.677458 0.586693 \n", "sample_799 5.718751 0.000000 \n", "sample_800 4.550716 0.000000 \n", "\n", "[801 rows x 20531 columns]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data # print data" ] }, { "cell_type": "code", "execution_count": 4, "id": "grave-judgment", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Class\n", "sample_0 PRAD\n", "sample_1 LUAD\n", "sample_2 PRAD\n", "sample_3 PRAD\n", "sample_4 BRCA\n", "... ...\n", "sample_796 BRCA\n", "sample_797 LUAD\n", "sample_798 COAD\n", "sample_799 PRAD\n", "sample_800 PRAD\n", "\n", "[801 rows x 1 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "labels # print labels" ] }, { "cell_type": "markdown", "id": "scheduled-immigration", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "Our data form an (801, 20531) array, where each row corresponds to one patient and each column represents one gene. The entries are non-negative numbers obtained from the RNA-Seq measurements. We can think of each data point as a 20531-dimensional vector representing a particular patient with one of the tumor types. Our goal is thus to separate the 801 data points into different categories. This may sound like the classification problem we did before, except that here we know neither what the categories are nor how many there are." ] }, { "cell_type": "markdown", "id": "spanish-capital", "metadata": { "slideshow": { "slide_type": "slide" }, "tags": [] }, "source": [ "## Dimensionality reduction" ] }, { "cell_type": "markdown", "id": "historical-filename", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "When dealing with such high-dimensional data, it is a good idea to first reduce the dimensionality, for example using **Principal Component Analysis (PCA)**. The goal is to find a small number of principal components that capture most of the variation among the data points. For instance, if a gene is not expressed in any patient, then this gene (or the dimension it represents) is of no use at all for distinguishing the data points. More generally, we are interested in the directions in the data space along which the data points vary the most."
] }, { "cell_type": "markdown", "id": "imposed-blogger", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "Recall that, to perform PCA, we first calculate the covariance matrix of the data. Instead of calculating it by hand, we can use the function `numpy.cov`. By default, `numpy.cov` treats each row of its 2-d input as a variable and each column as an observation (a data point), so we have to transpose our data; passing `rowvar=False` would have the same effect. Note that the result is a 20531 × 20531 matrix, which occupies roughly 3.4 GB of memory in double precision." ] }, { "cell_type": "code", "execution_count": 5, "id": "indian-shipping", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "outputs": [], "source": [ "cov = np.cov(data.T) # calculate the covariance matrix (genes x genes)" ] }, { "cell_type": "markdown", "id": "3702a127-45d2-4bfe-aa53-aca8617a0773", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "Next, we need the eigenvectors that correspond to the largest eigenvalues. We could calculate the eigenvalues and eigenvectors with the `numpy.linalg.eigh` function as before. However, this function calculates *all* of them, and there are 20531 here. Since we are interested only in the largest few eigenvalues, we can instead use the `scipy.sparse.linalg.eigs` function, or rather its variant `eigsh` for symmetric matrices (a covariance matrix is always symmetric), as follows." ] }, { "cell_type": "code", "execution_count": 6, "id": "composite-parish", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "outputs": [], "source": [ "import scipy.sparse.linalg as spla\n", "\n", "w, v = spla.eigsh(cov, k=50) # the k largest-magnitude eigenvalues (= largest, since cov is positive semi-definite) and their eigenvectors" ] }, { "cell_type": "markdown", "id": "royal-shower", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "Here we calculated the 50 largest eigenvalues and their eigenvectors. Recall that each eigenvector represents a principal component (PC), and the corresponding eigenvalue is the variance along that direction. Let us sort the components by variance and see how many of them we need to capture most of the total variance. (To turn each eigenvalue into a fraction of the total variance, we divide it by the total variance, which is the sum of all eigenvalues. Since we did not calculate all eigenvalues, the sum of the 50 we have would underestimate it. Instead, we use the fact that the sum of all eigenvalues equals the trace of the covariance matrix.)" ] }, { "cell_type": "code", "execution_count": 7, "id": "baking-helicopter", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "outputs": [], "source": [ "order = np.argsort(w)[::-1] # argsort gives ascending order, [::-1] reverses the order\n", "w = w[order] # order the eigenvalues\n", "v = v[:,order] # each column is an eigenvector, so we order the columns\n", "\n", "wnorm = w / np.trace(cov) # normalized eigenvalue = fraction of total variance captured\n", "wsum = np.cumsum(wnorm) # cumulative sum of variance captured" ] }, { "cell_type": "code", "execution_count": 8, "id": "handy-sapphire", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure()\n", "plt.plot(np.arange(1, len(wsum) + 1), wsum, '.-') # x-axis: number of components included\n", "plt.ylim(0, 1)\n", "plt.xlabel('# principal components', fontsize=24)\n", "plt.ylabel('variance captured', fontsize=24)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "legitimate-ballet", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "We see that the captured fraction of variance increases rapidly over the first few eigenvalues, then slows down and almost plateaus. It approaches 1 only slowly as further eigenvalues are included. Since there are only 801 data points, at most 800 eigenvalues are nonzero, and the curve reaches 1 once all of them are included. Let us pick, say, $K = 10$ principal components and project the original data onto these 10 components as follows. (Strictly speaking, PCA projects the mean-centered data. Skipping the centering here shifts all projected points by the same constant vector, which does not affect the clustering below.)" ] }, { "cell_type": "code", "execution_count": 9, "id": "interior-quantum", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "outputs": [], "source": [ "K = 10 # number of principal components to use\n", "\n", "projected = np.dot(data, v[:,:K]) # project data points onto the first K principal components" ] }, { "cell_type": "markdown", "id": "outstanding-season", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "We can try to visualize the projected data, although it is hard to make plots in more than 2 or 3 dimensions. Since the first few components capture the most variance, we may hope that plotting 2 or 3 of them is enough. It turns out that the first and third PCs already give a good impression of the data." ] }, { "cell_type": "code", "execution_count": 10, "id": "a2815355-4354-40c4-a9a0-c930bd203969", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(6,6))\n", "plt.scatter(projected[:,0], projected[:,2], s=5)\n", "plt.xlabel('PC1')\n", "plt.ylabel('PC3')\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "5f434906-77aa-4eff-b1bd-315b6b2b68b7", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "We see that the data points fall naturally into 5 clusters. To aid visualization, we can color the data points according to their labels (which we are pretending not to know). For that, the labels need a bit of processing: they are given as strings, but we need integer indices." ] }, { "cell_type": "code", "execution_count": 11, "id": "molecular-strap", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'BRCA', 'COAD', 'KIRC', 'LUAD', 'PRAD'}\n" ] } ], "source": [ "types = set(labels['Class']) # collect all tumor types\n", "nt = len(types) # number of types\n", "print(types)" ] }, { "cell_type": "markdown", "id": "19dce532-118d-4879-a618-10e67379cb63", "metadata": {}, "source": [ "These are breast cancer (BRCA), colon cancer (COAD), kidney cancer (KIRC), lung cancer (LUAD), and prostate cancer (PRAD)." ] }, { "cell_type": "code", "execution_count": 12, "id": "precious-connecticut", "metadata": {}, "outputs": [], "source": [ "type_to_index = dict(zip(types, range(nt))) # map string labels to indices\n", "indices = labels['Class'].map(type_to_index).to_numpy() # convert labels to indices for all data points" ] }, { "cell_type": "markdown", "id": "generic-graduation", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "We can now plot the first three PCs in 3D, coloring each data point by its tumor type."
] }, { "cell_type": "code", "execution_count": 13, "id": "greenhouse-factory", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "outputs": [], "source": [ "from mpl_toolkits.mplot3d import Axes3D\n", "from matplotlib import animation as anim\n", "plt.rcParams[\"animation.html\"] = \"jshtml\"\n", "cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']\n", "colors = np.asarray(cycle)[indices]\n", "\n", "fig = plt.figure(figsize=(8,6))\n", "fig.subplots_adjust(left=0, right=1, bottom=0, top=1)\n", "ax = fig.add_subplot(projection=\"3d\")\n", "ax.scatter(projected[:,0], projected[:,1], projected[:,2], s=5, c=colors)\n", "ax.set_xlabel(r'PC1')\n", "ax.set_ylabel(r'PC2')\n", "ax.set_zlabel(r'PC3')\n", "fig.set_facecolor('w')\n", "\n", "def animate(i):\n", " ax.view_init(elev=10., azim=i*10)\n", " return fig,\n", "\n", "mov = anim.FuncAnimation(fig, animate, frames=36)\n", "# mov.save('source/cancer.gif', fps=5)\n", "plt.close()" ] }, { "cell_type": "code", "execution_count": 14, "id": "218c8581-e71f-43ad-a4e3-e5bc57ff0192", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "\n", "
\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mov" ] }, { "cell_type": "markdown", "id": "backed-patient", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "We see that the clusters match the colors very well. This shows that the first few principal components successfully capture the main features of the dataset. Therefore, we will use these principal components to cluster the data points, again pretending that we do not know the labels (colors)." ] }, { "cell_type": "markdown", "id": "ancient-ontario", "metadata": { "slideshow": { "slide_type": "slide" }, "tags": [] }, "source": [ "## Clustering" ] }, { "cell_type": "markdown", "id": "a4173d12-4f8a-4ead-bc88-94f041c8d54b", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "Our goal is to separate the data points into multiple clusters according to their relative positions in the data space (or in the reduced low-dimensional space). We would like data points within the same cluster to be close together, and different clusters to be relatively far apart. The number of clusters needed is often not known beforehand, so finding an appropriate number of clusters is part of the task. In simple clustering algorithms such as **k-means**, which we will use here, the number of clusters is chosen by hand and given to the algorithm. More sophisticated algorithms may select this number automatically according to certain criteria."
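] }, { "cell_type": "markdown", "id": "sketch-within-between", "metadata": {}, "source": [ "The two goals stated above, small spread within clusters and large spread between them, are two sides of the same coin. For any assignment of points to clusters, the total sum of squared distances from the overall mean splits exactly into two parts. The within-cluster part measures the distances from each point to its own cluster center. The between-cluster part measures the distances from each cluster center to the overall mean, weighted by cluster size. A minimal sketch checking this numerically on random points with an arbitrary labeling follows. The helper `sums_of_squares` is hypothetical and not part of any library.

```python
import numpy as np

def sums_of_squares(X, labels):
    # Hypothetical helper: total, within-cluster and between-cluster sums of squares
    overall_mean = X.mean(axis=0)
    total = np.sum((X - overall_mean) ** 2)
    within = 0.0
    between = 0.0
    for c in np.unique(labels):
        members = X[labels == c]        # points assigned to cluster c
        mu = members.mean(axis=0)       # center of cluster c
        within += np.sum((members - mu) ** 2)
        between += len(members) * np.sum((mu - overall_mean) ** 2)
    return total, within, between

rng = np.random.default_rng(1)
X = rng.normal(size=(100, 4))            # 100 random points in 4 dimensions
labels = rng.integers(0, 3, size=100)    # an arbitrary assignment to 3 clusters
total, within, between = sums_of_squares(X, labels)
print(total, within + between)           # equal up to rounding error
```

Because the total is fixed by the data, making the within-cluster part smaller automatically makes the between-cluster part larger. The within-cluster part is exactly the k-means cost defined next."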
] }, { "cell_type": "markdown", "id": "beautiful-annotation", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "The k-means algorithm aims to minimize the sum of squared distances between all pairs of data points within the same cluster, divided by twice the cluster size, i.e.,\n", "\\begin{equation}\n", "L = \\sum_{i=1}^k \\frac{1}{2 N_i} \\sum_{\\mathbf{X}_n, \\mathbf{X}_m \\in S_i} \\big| \\mathbf{X}_n - \\mathbf{X}_m \\big|^2\n", "\\end{equation}\n", "where $S_i$ is the set of data points belonging to the $i$-th cluster for $i = 1, \\cdots, k$, and $N_i$ is the number of data points in $S_i$. Each data point $\\mathbf{X}_n$ is a vector in the data space, and $|\\mathbf{X}_n - \\mathbf{X}_m|$ is the Euclidean distance between two such points. The inner sum runs over all ordered pairs, so each pair of distinct points is counted twice, which the factor $1/2$ compensates for. This cost function can be equivalently expressed as the sum of squared distances from every data point to the center of its own cluster, i.e.,\n", "\\begin{equation}\n", "L = \\sum_{i=1}^k \\sum_{\\mathbf{X}_n \\in S_i} \\big| \\mathbf{X}_n - \\boldsymbol{\\mu}_i \\big|^2, \\qquad \\textsf{where} \\quad \\boldsymbol{\\mu}_i = \\frac{1}{N_i} \\sum_{\\mathbf{X}_n \\in S_i} \\mathbf{X}_n\n", "\\end{equation}\n", "Finding the global minimum of this cost function is hard (NP-hard in general). However, the k-means algorithm quickly converges to a local minimum, which is often good enough." ] }, { "cell_type": "markdown", "id": "79c17159-36b8-4a71-b3d2-241eb72b2904", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "The algorithm requires the number of clusters, $k$, as an input. One may also provide an initial guess for the center of every cluster. If none is given, the algorithm uses random initial positions for the cluster centers. Because of this randomness, and because it only finds a local minimum, running the algorithm twice may give different results. So you may have to try a few times and check whether the clustering result is satisfactory."
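] }, { "cell_type": "markdown", "id": "sketch-cost-equivalence", "metadata": {}, "source": [ "The equivalence of the two forms of $L$ follows from expanding the square. For one cluster, the double sum over ordered pairs equals $2 N_i$ times the sum of squared distances to the cluster mean. A minimal numerical check on random data is sketched below. The two helper functions are hypothetical and implement each form of $L$ literally.

```python
import numpy as np

def cost_pairwise(X, labels):
    # First form: per cluster, sum of squared distances over all ordered pairs, divided by 2 N_i
    L = 0.0
    for c in np.unique(labels):
        S = X[labels == c]
        diffs = S[:, None, :] - S[None, :, :]   # all pairwise differences within the cluster
        L += np.sum(diffs ** 2) / (2 * len(S))
    return L

def cost_centroid(X, labels):
    # Second form: squared distance from each point to the mean of its own cluster
    L = 0.0
    for c in np.unique(labels):
        S = X[labels == c]
        L += np.sum((S - S.mean(axis=0)) ** 2)
    return L

rng = np.random.default_rng(2)
X = rng.normal(size=(60, 3))              # 60 random points in 3 dimensions
labels = rng.integers(0, 4, size=60)      # an arbitrary assignment to 4 clusters
print(cost_pairwise(X, labels), cost_centroid(X, labels))   # the two values agree
```

The pairwise form needs on the order of $N_i^2$ distance computations per cluster, whereas the centroid form needs only $N_i$. This is one reason the algorithm works with cluster centers."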
] }, { "cell_type": "markdown", "id": "38160443-05b0-4b6f-903d-c497df620cf1", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "source": [ "The heuristic idea behind the algorithm is to alternate between two simple steps. The first step takes the current guess for the center of every cluster and assigns each data point to its nearest center. The second step collects the data points now assigned to each cluster and recalculates the cluster center as their mean. The algorithm repeats these two steps until no data point changes its assignment, which means the result has converged. Convergence is guaranteed because neither step can increase the cost $L$ and there are only finitely many ways to assign the points to clusters." ] }, { "cell_type": "markdown", "id": "instructional-inquiry", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "Luckily, we don't have to program this algorithm ourselves, since it is already provided by the Python function `scipy.cluster.vq.kmeans2`. In the following, we will use this function to cluster the dimensionally reduced data that we obtained from PCA above. Since we saw that the data points seem to separate nicely into 5 clusters (in agreement with the 5 diagnosed tumor types, which we pretend not to know), we will ask the algorithm to find $k = 5$ clusters."
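] }, { "cell_type": "markdown", "id": "sketch-simple-kmeans", "metadata": {}, "source": [ "Although we will rely on `kmeans2`, it may help to see that the two alternating steps described above fit in a few lines of NumPy. This is only an illustration, not the implementation inside `scipy.cluster.vq.kmeans2`. In particular, the sketch does not use random initial centers. Instead it uses farthest-point initialization: start from the first data point, then repeatedly add the point farthest from all centers chosen so far. This is our own choice, made to keep the example reproducible. The function `simple_kmeans` and the synthetic blobs are hypothetical.

```python
import numpy as np

def simple_kmeans(X, k, max_iter=100):
    # Farthest-point initialization: an illustrative choice, not what kmeans2 does by default
    centers = [X[0]]
    for _ in range(k - 1):
        d2 = np.min([np.sum((X - c) ** 2, axis=1) for c in centers], axis=0)
        centers.append(X[np.argmax(d2)])    # the point farthest from all chosen centers
    centers = np.array(centers, dtype=float)

    labels = np.full(len(X), -1)
    for _ in range(max_iter):
        # Step 1: assign every data point to its nearest center
        d2 = np.sum((X[:, None, :] - centers[None, :, :]) ** 2, axis=2)
        new_labels = np.argmin(d2, axis=1)
        if np.array_equal(new_labels, labels):
            break                           # no reassignment happened: converged
        labels = new_labels
        # Step 2: move each center to the mean of the points assigned to it
        for i in range(k):
            if np.any(labels == i):         # keep the old center if a cluster is empty
                centers[i] = X[labels == i].mean(axis=0)
    return centers, labels

# Three well-separated synthetic blobs of 50 points each in 2-d
rng = np.random.default_rng(3)
true_centers = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])
X = np.vstack([c + rng.normal(size=(50, 2)) for c in true_centers])
truth = np.repeat([0, 1, 2], 50)

centers, labels = simple_kmeans(X, 3)
print(np.round(centers, 1))
```

On the notebook's data, the analogous call would be `simple_kmeans(projected, 5)`. Farthest-point initialization is deterministic but sensitive to outliers. Randomized seeding schemes such as k-means++, available in `kmeans2` through `minit='++'`, are a common compromise."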
] }, { "cell_type": "code", "execution_count": 15, "id": "2284786a-2c4d-4c4d-8846-5244384979a3", "metadata": {}, "outputs": [], "source": [ "import scipy.cluster.vq as vq" ] }, { "cell_type": "code", "execution_count": 16, "id": "incorrect-wages", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "outputs": [], "source": [ "k = 5 # choose number of clusters\n", "centroid, lab = vq.kmeans2(projected, k) # perform k-means clustering (random initial centers)" ] }, { "cell_type": "markdown", "id": "sporting-prediction", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "source": [ "The function returns two outputs. `centroid` is a 2-d array where each row is the center of a cluster. `lab` is a 1-d array whose entries tell us which cluster each data point belongs to. We can use the `lab` indices to color our data points and plot the clustering result. (Note that k-means numbers the clusters in an arbitrary order. To compare with the \"true labels\" we saw above, we relabel each cluster with the true label of the data point closest to its center.)" ] }, { "cell_type": "code", "execution_count": 17, "id": "looking-hours", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [], "source": [ "colormap = []\n", "for cent in centroid:\n", "    med = np.argmin(np.sum((projected - cent)**2, axis=1)) # index of the data point closest to this center\n", "    colormap.append(indices[med]) # use that point's true label index for the whole cluster\n", "newindices = np.asarray(colormap)[lab] # relabeled cluster index for every data point\n", "newcolors = np.asarray(cycle)[newindices] # colors according to the relabeled clusters" ] }, { "cell_type": "code", "execution_count": 18, "id": "74995f48-0bcf-427f-a964-165dd8221f34", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(6,6))\n", "plt.scatter(projected[:,0], projected[:,2], s=3, c=newcolors)\n", "plt.xlabel('PC1')\n", "plt.ylabel('PC3')\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "administrative-morning", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "source": [ "We see that the colors match the clusters very well. You may notice some points that seem to wander into another cluster. That is because we are only plotting two dimensions: points that look overlapping may well be separated along other dimensions. Try plotting some other principal components yourself." ] }, { "cell_type": "markdown", "id": "million-protein", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "Finally, let us compare our clustering with the true labels by highlighting the points that were assigned to the wrong cluster." ] }, { "cell_type": "code", "execution_count": 19, "id": "4d6155aa-93cb-43a1-afe7-529dc9f73ae1", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "wrong = (newindices != indices) # True where the cluster disagrees with the true label\n", "\n", "plt.figure(figsize=(6,6))\n", "plt.scatter(projected[:,0], projected[:,2], s=3, c=newcolors)\n", "plt.scatter(projected[wrong,0], projected[wrong,2], s=100, c='k', marker='X')\n", "plt.xlabel('PC1')\n", "plt.ylabel('PC3')\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "metropolitan-recorder", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "source": [ "In this run, only 4 of the 801 data points were assigned to the wrong cluster. Because k-means starts from random centers, the exact number may differ between runs. Moreover, the misassigned points lie near the boundaries of the clusters, which makes these mistakes quite understandable." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 5 }