K-means clustering
K-means is an unsupervised hard clustering algorithm. Given a data set, the standard version partitions the data into K clusters, where each sample is assigned to exactly one cluster. This differs from soft clustering algorithms such as a Gaussian Mixture Model (GMM), where each sample has a probability describing how strongly it belongs to each cluster. In a later post we will use K-means to initialise the GMM parameters, since K-means is much cheaper to train than a soft clustering model.
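To make the hard/soft distinction concrete, here is a small sketch comparing the two kinds of assignment for a single sample. The soft version assumes, purely for illustration, a GMM with equal mixing weights and a shared isotropic variance `var`. In that special case the responsibilities reduce to a softmax over negative squared distances; a real GMM would also learn the weights and covariances. The function names are hypothetical.

```python
import math

def sq_dist(a, b):
    """Squared Euclidean distance between two points."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def hard_assign(x, centroids):
    """K-means style: index of the single nearest centroid."""
    return min(range(len(centroids)), key=lambda k: sq_dist(x, centroids[k]))

def soft_assign(x, centroids, var=1.0):
    """GMM-style responsibilities, ASSUMING equal weights and a shared isotropic variance."""
    logits = [-sq_dist(x, c) / (2 * var) for c in centroids]
    m = max(logits)  # subtract the max before exponentiating for numerical stability
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

centroids = [(0.0, 0.0), (4.0, 0.0)]
x = (1.0, 0.0)
print(hard_assign(x, centroids))                          # 0
print([round(p, 3) for p in soft_assign(x, centroids)])   # [0.982, 0.018]
```

The hard assignment throws away how close the second centroid was, while the soft version keeps a small probability for it.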
It's fine to specify a fixed number of clusters for problems where that number is known. For instance, we know that 10 clusters are needed for the MNIST data set (one per digit). However, this is not always the case, so more sophisticated K-means variants start with, e.g., a single cluster and then check whether splitting a cluster results in a lower total sum of squared errors (SSE). This can also increase the likelihood of finding the global minimum.
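The post does not say how the split check is done, so the following is only a minimal sketch under our own assumptions. A cluster is tentatively split in two with a few 2-means iterations (seeded, hypothetically, with its first point and the point farthest from it), and the split is kept only if the SSE drops by at least a hypothetical relative threshold `min_gain`. Some threshold or penalty is needed because splitting a cluster into two groups around their own means never increases the SSE, so a plain "is it lower?" test would almost always accept. Published variants such as X-means use a model-selection criterion (BIC) instead.

```python
def mean(points):
    """Component-wise mean of a non-empty list of points."""
    n = len(points)
    return tuple(sum(p[d] for p in points) / n for d in range(len(points[0])))

def sse(points, centroid):
    """Sum of squared Euclidean distances from the points to one centroid."""
    return sum(sum((x - c) ** 2 for x, c in zip(p, centroid)) for p in points)

def two_means(points, iters=10):
    """Split points into two groups with a few Lloyd (2-means) iterations."""
    # Hypothetical seeding: the first point and the point farthest from it.
    a = points[0]
    b = max(points, key=lambda p: sse([p], a))
    centroids = [a, b]
    groups = ([], [])
    for _ in range(iters):
        groups = ([], [])
        for p in points:
            k = 0 if sse([p], centroids[0]) <= sse([p], centroids[1]) else 1
            groups[k].append(p)
        if not groups[0] or not groups[1]:
            break  # degenerate split (e.g. identical points)
        centroids = [mean(groups[0]), mean(groups[1])]
    return groups

def try_split(points, min_gain=0.1):
    """Return two sub-clusters if splitting cuts SSE by at least min_gain (relative), else None."""
    parent = sse(points, mean(points))
    g0, g1 = two_means(points)
    if parent == 0 or not g0 or not g1:
        return None
    child = sse(g0, mean(g0)) + sse(g1, mean(g1))
    return (g0, g1) if (parent - child) / parent >= min_gain else None

data = [(0.0, 0.0), (0.5, 0.0), (10.0, 0.0), (10.5, 0.0)]
print(try_split(data))  # ([(0.0, 0.0), (0.5, 0.0)], [(10.0, 0.0), (10.5, 0.0)])
```

A relative threshold alone will still split fairly uniform clusters; the sketch only illustrates the SSE comparison, not a complete stopping rule.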
We use the Expectation-Maximisation (EM) algorithm to iteratively find the best centroids and assignments for the clusters. We initialise the cluster centroids randomly and then repeat the following EM steps until the model has converged, i.e. the assignments no longer change:
- E: Assign each sample to the cluster with the nearest centroid (here using Euclidean distance)
- M: Recompute each cluster centroid as the mean of the samples currently assigned to it
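The two steps above can be sketched in plain Python as follows. This is a minimal illustration, not the module linked from this post. It assumes that the random initialisation picks K distinct samples as starting centroids, that the model has converged once the assignments stop changing, and that a centroid stays in place if its cluster becomes empty. `kmeans` and `sq_dist` are hypothetical names.

```python
import random

def sq_dist(a, b):
    """Squared Euclidean distance between two points."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def kmeans(data, k, max_iter=100, seed=0):
    """Lloyd-style K-means; returns (centroids, assignments)."""
    rng = random.Random(seed)
    # Random initialisation (assumption): k distinct samples become the starting centroids.
    centroids = [tuple(p) for p in rng.sample(data, k)]
    assignments = None
    for _ in range(max_iter):
        # E-step: assign each sample to its nearest centroid.
        new_assignments = [min(range(k), key=lambda j: sq_dist(p, centroids[j])) for p in data]
        if new_assignments == assignments:
            break  # converged: no sample changed cluster
        assignments = new_assignments
        # M-step: move each centroid to the mean of the samples assigned to it.
        for j in range(k):
            members = [p for p, a in zip(data, assignments) if a == j]
            if members:  # keep the previous centroid if the cluster is empty
                centroids[j] = tuple(sum(c) / len(members) for c in zip(*members))
    return centroids, assignments

data = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (9.0, 9.0), (9.0, 10.0), (10.0, 9.0)]
centroids, labels = kmeans(data, k=2)
print(labels)
print(sorted(centroids))
```

The E-step compares squared distances: they pick the same nearest centroid as Euclidean distance while avoiding a square root per comparison.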
The images below show the first random initialisation and the converged solution after 6 iterations. The code for the simple clustering module and corresponding tests can be found at: