Lec 31: Hierarchical Clustering
# Everyone's favorite standard imports
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import time
I’ve made us some data. Easy case first: I drew this data set from three clusters, so you can probably look at the picture and figure out what they should be.
X = np.loadtxt('../../DataSets/Clustering-ToyData.csv')
plt.scatter(X[:,0],X[:,1])
plt.show()
But notice that if I color the points by the order they appear in the matrix, that order has nothing to do with which cluster they belong to.
plt.scatter(X[:,0],X[:,1], c = list(range(X.shape[0])))
plt.colorbar()
plt.title('Colored by order of points in the matrix')
plt.show()
The next thing we can do is look at the distance matrix of the points, which is a square matrix whose entry D[i,j] is the distance between points i and j.
from scipy.spatial import distance_matrix
D = distance_matrix(X,X)
plt.matshow(D)
plt.colorbar()
plt.title('Distance matrix')
plt.show()
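If you’re wondering what distance_matrix actually computes, here’s a minimal sketch that rebuilds a Euclidean distance matrix by hand with NumPy broadcasting and checks it against scipy. It uses a small random stand-in array, X_demo (made up so the sketch doesn’t depend on the CSV file), rather than our data.

```python
import numpy as np
from scipy.spatial import distance_matrix

# Hypothetical stand-in data: 10 random points in the plane
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(10, 2))

# diffs[i, j, :] = X_demo[i] - X_demo[j], built by broadcasting (10 x 10 x 2)
diffs = X_demo[:, np.newaxis, :] - X_demo[np.newaxis, :, :]

# The Euclidean length of each difference vector is D[i, j]
D_manual = np.sqrt((diffs ** 2).sum(axis=-1))

D_scipy = distance_matrix(X_demo, X_demo)
print(np.allclose(D_manual, D_scipy))      # True
print(np.allclose(D_manual, D_manual.T))   # symmetric: True
print(np.allclose(np.diag(D_manual), 0))   # zero diagonal: True
```

The symmetry and zero diagonal are why the picture above looks mirrored across its main diagonal.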
There are several options in that function, including the choice of distance between the points, but for now we’ll stick with good ol’ fashioned Euclidean distance.
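For the curious, here’s a small sketch of swapping the distance: distance_matrix takes a p argument that selects the Minkowski p-norm, with p=2 (the default) being Euclidean and p=1 being Manhattan, or city-block, distance. The two points are made up so the answers are easy to check by hand.

```python
import numpy as np
from scipy.spatial import distance_matrix

# Two made-up points: 3 units apart horizontally, 4 units vertically
pts = np.array([[0.0, 0.0],
                [3.0, 4.0]])

D_euclid = distance_matrix(pts, pts)          # default p=2: straight-line distance
D_manhattan = distance_matrix(pts, pts, p=1)  # p=1: sum of coordinate differences

print(D_euclid[0, 1])     # 5.0
print(D_manhattan[0, 1])  # 7.0
```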
Before we go on, here’s a trick for figuring out where a particular data point is.
# I want to find point 53
colors = np.zeros(X.shape[0])
colors[53] = -1 #<--- Make the color of point 53 different from the others
plt.scatter(X[:,0],X[:,1], c = colors)
plt.colorbar()
plt.title('Point 53 highlighted')
print('Distance:', D[18,7])
plt.show()
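The same trick works for several points at once. Here’s a sketch of a hypothetical helper, highlight_points (my own name, not a library function), that gives each requested index its own color and leaves every other point at 0; the indices in the example are arbitrary.

```python
import numpy as np

def highlight_points(n_points, indices):
    """Hypothetical helper: a color array that singles out the given indices.

    Every point gets color 0 except indices[k], which gets k + 1,
    so each highlighted point shows up in its own color in plt.scatter.
    """
    colors = np.zeros(n_points)
    for k, idx in enumerate(indices):
        colors[idx] = k + 1
    return colors

# Example: single out points 3 and 53 among 100 points
colors = highlight_points(100, [3, 53])
print(colors[3], colors[53], colors.sum())  # 1.0 2.0 3.0

# In the notebook you would pass it straight to scatter, e.g.
# plt.scatter(X[:,0], X[:,1], c = highlight_points(X.shape[0], [3, 53]))
```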
✅ Answer these:
Where are the 7th and 18th (according to python numbering) points in the scatter plot above?
Are they in the same cluster (based on eyeballing it) or different clusters?
What is the distance between the 7th and 18th points? Use the distance matrix to figure it out.
# Your code here #
I can also look at the connections like we did in class, where we connect up points within distance \(r\) of each other. Mess around with the \(r\) value below to see what changes. Be careful: large \(r\) values will make for slow drawing.
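# Draw a black segment between every pair of points within distance r
plt.scatter(X[:,0],X[:,1])
r = 0.3
for i in range(X.shape[0]):
    for j in range(i+1, X.shape[0]):  # j > i: each pair drawn once, no self-loops
        if D[i,j] <= r:
            p = X[(i,j),:]  # 2 x 2 array holding the two endpoints
            plt.plot(p[:,0],p[:,1],c = 'black')
plt.show()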
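There’s a nice connection hiding in this picture: for single linkage, cutting the dendrogram at height r gives exactly the connected components of this r-graph, since two points end up in the same cluster precisely when there’s a chain of hops between them, each of length at most r. Here’s a sketch that checks this on made-up random data (X_demo and r_demo are assumptions), using scipy.sparse.csgraph.connected_components for the graph side and hierarchy.fcluster, which shows up later in this notebook, for the dendrogram side. The same_partition helper is a hypothetical name for comparing two labelings up to renaming.

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform
from scipy.sparse.csgraph import connected_components
from scipy.cluster import hierarchy

rng = np.random.default_rng(1)
X_demo = rng.uniform(size=(60, 2))  # hypothetical data in the unit square
r_demo = 0.1

# Adjacency matrix of the r-graph: connect i and j when D[i,j] <= r
D_demo = squareform(pdist(X_demo))
A = D_demo <= r_demo
np.fill_diagonal(A, False)  # no self-loops

n_comp, comp_labels = connected_components(A, directed=False)

# Single linkage clusters from cutting the dendrogram at height r
Z_demo = hierarchy.linkage(pdist(X_demo), 'single')
link_labels = hierarchy.fcluster(Z_demo, r_demo, criterion='distance')

def same_partition(a, b):
    """Hypothetical helper: True if labelings a and b group points identically."""
    pairs = set(zip(a, b))
    return len(pairs) == len(set(a)) == len(set(b))

print(n_comp, len(np.unique(link_labels)))    # same number of groups
print(same_partition(comp_labels, link_labels))  # True
```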
This matrix is very helpful for understanding distances; however, my code in just a moment is going to need the condensed distance matrix. It turns out that it has the same information as the distance matrix above, just flattened out. For our purposes you won’t need to fully understand this, but the curious can go check out the documentation.
from scipy.spatial.distance import pdist
P = pdist(X)
P.shape
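To make "flattened out" concrete: since D is symmetric with zeros on the diagonal, pdist only keeps the entries above the diagonal, read off row by row, so n points give n(n-1)/2 numbers. Here’s a sketch using scipy.spatial.distance.squareform to rebuild the square matrix, plus the index arithmetic for finding a pair (i, j) in the condensed vector; condensed_index is a hypothetical helper name and X_demo is made-up data.

```python
import numpy as np
from scipy.spatial import distance_matrix
from scipy.spatial.distance import pdist, squareform

rng = np.random.default_rng(2)
X_demo = rng.normal(size=(6, 2))  # hypothetical stand-in data
n = X_demo.shape[0]

P_demo = pdist(X_demo)
print(P_demo.shape)  # (15,) since 6*5/2 = 15

# squareform rebuilds the full symmetric matrix from the condensed vector
D_demo = squareform(P_demo)
print(np.allclose(D_demo, distance_matrix(X_demo, X_demo)))  # True

def condensed_index(i, j, n):
    """Hypothetical helper: position of pair (i, j), i != j, in the condensed vector."""
    if i > j:
        i, j = j, i
    # Skip the full rows 0..i-1 above the diagonal, then step along row i
    return n * i - i * (i + 1) // 2 + (j - i - 1)

print(P_demo[condensed_index(4, 1, n)] == D_demo[4, 1])  # True
```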
And with that, we can get scipy to compute our dendrogram for us!
from scipy.cluster import hierarchy
Z = hierarchy.linkage(pdist(X), 'single')
plt.figure()
dn = hierarchy.dendrogram(Z)
plt.show()
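Under the hood, Z is scipy’s linkage matrix: each row records one merge as [cluster index, cluster index, merge height, number of original points in the new cluster], and the cluster created in row k gets the index n + k. Here’s a sketch on four made-up points on a line where you can check every merge by hand.

```python
import numpy as np
from scipy.cluster import hierarchy
from scipy.spatial.distance import pdist

# Four made-up points on a line: two tight pairs, far apart from each other
pts = np.array([[0.0], [1.0], [5.0], [6.5]])

Z_demo = hierarchy.linkage(pdist(pts), 'single')
print(Z_demo)
# Row 0: points 0 and 1 merge at height 1.0 -> new cluster 4 (2 points)
# Row 1: points 2 and 3 merge at height 1.5 -> new cluster 5 (2 points)
# Row 2: clusters 4 and 5 merge at height 4.0, the gap between 1.0 and 5.0 (4 points)
```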
First thing to notice: it automatically colors the dendrogram with its own guess at a threshold for clustering. Hunting around in the documentation, it turns out that it chooses this threshold from the merge heights stored in the linkage matrix Z…
h = 0.7*max(Z[:,2]) #<---- this equation is what it uses to pick the cutoff
print('Chosen threshold:', h)
…but we can also control the cutoff it uses to color the branches.
thresh_height = 0.5
dn = hierarchy.dendrogram(Z, color_threshold = thresh_height)
plt.axhline(y = thresh_height, color = 'r', linestyle = '-')
plt.show()
✅ Q: How many clusters are there if we have a cutoff of 2.0? What about for 0.5?
# Your code here #
Now for the magic: we can extract the cluster assignments for the chosen threshold using fcluster.
labels = hierarchy.fcluster(Z,h,criterion = 'distance')
plt.scatter(X[:,0],X[:,1],c = labels)
plt.show()
np.unique(labels)
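Since single linkage merge heights only go up as you read down Z, you can also count the clusters for a cutoff t without drawing anything: each merge at height at most t glues two clusters together, so the count is n minus the number of such merges. Here’s a sketch on made-up data that checks this against fcluster. It also shows fcluster’s criterion='maxclust' option, which asks for at most a given number of clusters instead of a height.

```python
import numpy as np
from scipy.cluster import hierarchy
from scipy.spatial.distance import pdist

rng = np.random.default_rng(3)
X_demo = rng.normal(size=(40, 2))  # hypothetical stand-in data
n = X_demo.shape[0]
Z_demo = hierarchy.linkage(pdist(X_demo), 'single')

t = 0.5
# Each merge with height <= t reduces the number of clusters by one
n_by_count = n - np.sum(Z_demo[:, 2] <= t)
labels_t = hierarchy.fcluster(Z_demo, t, criterion='distance')
print(n_by_count, len(np.unique(labels_t)))  # the two counts agree

# Ask for (at most) 3 clusters directly instead of choosing a height
labels_k = hierarchy.fcluster(Z_demo, 3, criterion='maxclust')
print(np.unique(labels_k))
```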
✅ Q: Where are the clusters associated to a threshold of 2.0? What about 0.5?
# Your code here #
# Your code here #
A nastier example
I made another data set that’s a bit more ambiguous. How many clusters are there? 2? 3? Ok, ok, let’s not fight about it. I actually made it by drawing from three normal distributions with different centers, but one could argue about whether there are two cleanly split clusters, or whether, based on density, there are 3.
X = np.loadtxt('../../DataSets/Clustering-ToyData2.csv')
plt.scatter(X[:,0],X[:,1])
plt.show()
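If you’d like to cook up your own ambiguous data set to experiment with, here’s a sketch of drawing points from three normal distributions, with two centers close together and one farther away. The centers, spread, and sample size are all made-up assumptions and aren’t meant to reproduce Clustering-ToyData2.csv.

```python
import numpy as np

rng = np.random.default_rng(4)

# Made-up centers: two close together, one farther away
centers = np.array([[0.0, 0.0],
                    [1.5, 0.0],
                    [5.0, 5.0]])
spread = 0.6           # hypothetical standard deviation for every blob
points_per_blob = 50   # hypothetical sample size per blob

# Draw each blob around its center, then stack them into one data matrix
X_blobs = np.vstack([
    rng.normal(loc=c, scale=spread, size=(points_per_blob, 2))
    for c in centers
])
print(X_blobs.shape)  # (150, 2)

# To look at it in the notebook: plt.scatter(X_blobs[:,0], X_blobs[:,1]); plt.show()
```

Shrinking the gap between the first two centers, or increasing spread, makes the "2 or 3?" question harder.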
✅ Q: Draw the single linkage dendrogram as above.
Is there a reasonable choice of threshold?
Using this threshold, how many clusters are there?
Draw the cluster assignments on the points using some choice of threshold, reasonable or not.
# Your code here #
Now, we can mess with the choice of linkage to see how this changes the clustering assignment.
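The linkage choice is really a rule for measuring the distance between two clusters A and B. Single linkage uses the closest pair, the minimum of D[a,b] over a in A and b in B, while complete linkage uses the farthest pair, the maximum. Here’s a sketch on made-up points that computes both from a distance matrix with np.ix_; the index lists A and B stand in for two hypothetical clusters.

```python
import numpy as np
from scipy.spatial import distance_matrix

# Made-up points on a line: cluster A = {0, 1}, cluster B = {2, 3}
pts = np.array([[0.0], [1.0], [3.0], [7.0]])
D_demo = distance_matrix(pts, pts)
A, B = [0, 1], [2, 3]

block = D_demo[np.ix_(A, B)]  # every distance from a point in A to a point in B
single = block.min()    # closest pair: points 1 and 2
complete = block.max()  # farthest pair: points 0 and 3
print(single, complete)  # 2.0 7.0
```

Because complete linkage looks at the farthest pair, it tends to resist chaining long, stringy clusters together, which single linkage happily does.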
✅ Q: Modify your linkage to use 'complete' instead of 'single' and draw the resulting dendrogram.
Is there a reasonable choice of threshold?
Using this threshold, how many clusters are there?
Draw the cluster assignments on the points using some choice of threshold, reasonable or not.
# Your code here #
Congratulations, we’re done!
Written by Dr. Liz Munch, Michigan State University
This work is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.