16. K-Means Clustering: An Animated Walkthrough

This notebook animates k-means step by step on the Iris dataset (sepal length vs. sepal width), so you can watch each phase of the algorithm unfold:

  1. Initialize — place K centroids randomly in feature space

  2. Assign — color each point by its nearest centroid

  3. Update — move each centroid to the mean of its assigned points

  4. Repeat until convergence

Voronoi decision boundaries are drawn so you can see why each point is assigned to its cluster.
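
To make the four steps concrete, here is a minimal sketch that runs them on six made-up points with K = 2. The data, the starting centroids, and the helper names `assign` and `update` are invented for this illustration; the implementation that records frames for the animation comes in Section 16.2.

```python
import numpy as np

# Toy data (invented for illustration): two well-separated groups in 2-D.
points = np.array([[1.0, 1.0], [1.5, 2.0], [2.0, 1.0],
                   [8.0, 8.0], [8.5, 9.0], [9.0, 8.0]])

# Step 1, Initialize: arbitrary starting centroids.
centroids = np.array([[0.0, 0.0], [10.0, 10.0]])

def assign(points, centroids):
    """Step 2, Assign: label each point with the index of its nearest centroid."""
    dists = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
    return dists.argmin(axis=1)

def update(points, labels, k):
    """Step 3, Update: move each centroid to the mean of its assigned points."""
    return np.array([points[labels == j].mean(axis=0) for j in range(k)])

# Step 4, Repeat until the centroids stop moving.
for step in range(10):
    labels = assign(points, centroids)
    new_centroids = update(points, labels, k=2)
    if np.allclose(new_centroids, centroids):
        break
    centroids = new_centroids

print(labels)      # [0 0 0 1 1 1]
print(centroids)   # rows are approximately (1.5, 1.33) and (8.5, 8.33)
```

On this toy data the first pass already finds the final clusters, and the second pass only confirms that nothing moves.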

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.animation import FuncAnimation
from matplotlib.colors import ListedColormap
from IPython.display import HTML

from sklearn.datasets import load_iris
from sklearn.metrics import silhouette_score

# ── Reproducibility ──────────────────────────────────────────────────────────
RNG = np.random.default_rng(7)

16.1. The Dataset: Iris Sepal Length vs. Sepal Width

We use two features from the Iris dataset — sepal length (cm) and sepal width (cm). These features produce overlapping clusters, so the Voronoi boundaries and their meeting point (the triple-point) fall well within the data range — making the decision regions easy to see.

Note that we are hiding the true labels from the algorithm; k-means only sees the raw coordinates.

iris = load_iris()
# Features: sepal length (col 0) and sepal width (col 1)
X = iris.data[:, 0:2]
true_labels = iris.target
species = iris.target_names

fig, ax = plt.subplots(figsize=(7, 5))
scatter = ax.scatter(X[:, 0], X[:, 1], c=true_labels, cmap='Set1', alpha=0.7, edgecolors='k', linewidths=0.4)
ax.set_xlabel("Sepal Length (cm)")
ax.set_ylabel("Sepal Width (cm)")
ax.set_title("Iris Dataset — True Labels (hidden from k-means)")
handles = [mpatches.Patch(color=scatter.cmap(scatter.norm(i)), label=s) for i, s in enumerate(species)]
ax.legend(handles=handles, title="True Species")
plt.tight_layout()
plt.show()
[Figure: Iris sepal width (cm) against sepal length (cm), points colored by true species.]

16.2. K-Means From Scratch

We implement k-means manually so we can record every intermediate state (every assign step and every update step), not just the final result.

def run_kmeans(X, K, rng, max_iter=15):
    """
    Run k-means and record every intermediate state.

    Returns a list of frames. Each frame is a dict:
        phase      : 'init' | 'assign' | 'update' | 'converged'
        iteration  : int
        centroids  : (K, 2) array — centroid positions THIS frame
        labels     : (N,)   array — cluster assignment of each point
        description: human-readable caption
    """
    n, d = X.shape

    # ── Initialization: pick K random points as starting centroids ────────────
    idx = rng.choice(n, size=K, replace=False)
    centroids = X[idx].copy()
    labels = np.full(n, -1, dtype=int)  # -1 = unassigned

    frames = []
    frames.append(dict(
        phase="init", iteration=0,
        centroids=centroids.copy(), labels=labels.copy(),
        description="Step 0 — Initialize: place K centroids at random data points"
    ))

    for it in range(1, max_iter + 1):
        # ── Assignment step ───────────────────────────────────────────────────
        # For each point, compute distance to every centroid; assign closest.
        dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)  # (N, K)
        new_labels = np.argmin(dists, axis=1)

        frames.append(dict(
            phase="assign", iteration=it,
            centroids=centroids.copy(), labels=new_labels.copy(),
            description=f"Iteration {it} — Assign: color each point by its nearest centroid"
        ))

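        # ── Update step ───────────────────────────────────────────────────────
        # Move each centroid to the mean of its points. A cluster can end up
        # empty (for example, if two starting centroids coincide); in that case
        # keep its old position instead of taking the mean of nothing (NaN).
        new_centroids = np.array([
            X[new_labels == k].mean(axis=0) if np.any(new_labels == k) else centroids[k]
            for k in range(K)
        ])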

        frames.append(dict(
            phase="update", iteration=it,
            centroids=new_centroids.copy(), labels=new_labels.copy(),
            description=f"Iteration {it} — Update: move each centroid to its cluster mean"
        ))

        # ── Convergence check ─────────────────────────────────────────────────
        if np.allclose(centroids, new_centroids):
            frames.append(dict(
                phase="converged", iteration=it,
                centroids=new_centroids.copy(), labels=new_labels.copy(),
                description=f"Converged after {it} iteration(s) — no point changed cluster"
            ))
            break

        centroids = new_centroids
        labels = new_labels

    return frames


K = 3
frames = run_kmeans(X, K=K, rng=RNG)
print(f"Recorded {len(frames)} animation frames over {frames[-1]['iteration']} iteration(s).")
print("Phases:", [f['phase'] for f in frames])
Recorded 31 animation frames over 15 iteration(s).
Phases: ['init', 'assign', 'update', 'assign', 'update', 'assign', 'update', 'assign', 'update', 'assign', 'update', 'assign', 'update', 'assign', 'update', 'assign', 'update', 'assign', 'update', 'assign', 'update', 'assign', 'update', 'assign', 'update', 'assign', 'update', 'assign', 'update', 'assign', 'update']
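
The assignment step in `run_kmeans` computes every point-to-centroid distance in one broadcasted expression. The sketch below uses four made-up points and two made-up centroids (the names `X_demo` and `C_demo` are illustrative only) to show the intermediate shapes. It also shows that `argmin` resolves an exact tie in favor of the lower cluster index.

```python
import numpy as np

X_demo = np.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0], [1.0, 1.0]])  # (N=4, d=2)
C_demo = np.array([[0.0, 0.0], [6.0, 8.0]])                          # (K=2, d=2)

# (4, 1, 2) - (1, 2, 2) broadcasts to (4, 2, 2): one difference vector
# for every (point, centroid) pair.
diff = X_demo[:, None, :] - C_demo[None, :, :]

# Collapsing the coordinate axis leaves an (N, K) distance matrix.
dists = np.linalg.norm(diff, axis=2)

# Row-wise argmin picks the nearest centroid for each point.
nearest = dists.argmin(axis=1)

print(diff.shape, dists.shape)  # (4, 2, 2) (4, 2)
print(dists[1])                 # [5. 5.]  -> the point (3, 4) is equidistant
print(nearest)                  # [0 0 1 0] -> the tie goes to cluster 0
```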

16.3. The Animation

Each frame shows:

  • Colored points — cluster assignment (gray = unassigned, during init)

  • Voronoi regions — background shading showing which centroid “owns” each area of feature space

  • Centroid markers — diamonds (◆) that move during the Update step

  • Caption — describes what just happened

  • Silhouette score panel (right) — plots the silhouette score after each assign step, revealing how cluster quality evolves. Dashed reference lines mark 0.50 (good) and 0.75 (strong) confidence thresholds. The current score is highlighted in red.
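
The score in the right panel comes from scikit-learn's `silhouette_score`. As a rough sketch of what that number measures, the helper below (the name `silhouette_by_hand` and the toy data are assumptions made for this example) computes the mean silhouette coefficient directly from its definition. For each point, a is the mean distance to the other points in its own cluster, b is the smallest mean distance to the points of any other cluster, and the coefficient is (b - a) / max(a, b). The sketch assumes Euclidean distance and at least two clusters.

```python
import numpy as np

def silhouette_by_hand(X, labels):
    """Mean silhouette coefficient, Euclidean distance (illustrative helper)."""
    X = np.asarray(X, dtype=float)
    labels = np.asarray(labels)
    D = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=2)  # pairwise (N, N)
    clusters = np.unique(labels)
    scores = np.zeros(len(X))
    for i in range(len(X)):
        own = labels == labels[i]
        n_own = own.sum()
        if n_own == 1:
            continue  # a singleton cluster scores 0 by convention
        # D[i, i] is 0, so dividing by n_own - 1 averages over the *other* members.
        a = D[i, own].sum() / (n_own - 1)
        # Mean distance to the closest neighboring cluster.
        b = min(D[i, labels == c].mean() for c in clusters if c != labels[i])
        scores[i] = (b - a) / max(a, b)
    return float(scores.mean())

# Four made-up points on a line: two tight pairs far apart.
X_toy = np.array([[0.0, 0.0], [1.0, 0.0], [10.0, 0.0], [11.0, 0.0]])

good = silhouette_by_hand(X_toy, [0, 0, 1, 1])  # pairs kept together
bad = silhouette_by_hand(X_toy, [0, 1, 0, 1])   # pairs split apart

print(round(good, 4))  # 0.8997
print(round(bad, 4))   # -0.45
```

The second labeling shows that the score can be negative when points sit closer to a neighboring cluster than to their own, so the score's full range is -1 to 1.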

def make_animation(X, frames, K, interval_ms=900):
    """
    Build and return a FuncAnimation that steps through each k-means frame.
    Right panel tracks silhouette score (one point per assign step).
    """
    # ── Color palette ─────────────────────────────────────────────────────────
    CLUSTER_COLORS = ['#E63946', '#2A9D8F', '#F4A261']
    UNASSIGNED_COLOR = '#BBBBBB'

    # ── Voronoi grid ──────────────────────────────────────────────────────────
    pad = 0.3
    x_min, x_max = X[:, 0].min() - pad, X[:, 0].max() + pad
    y_min, y_max = X[:, 1].min() - pad, X[:, 1].max() + pad
    grid_res = 300

    gx = np.linspace(x_min, x_max, grid_res)
    gy = np.linspace(y_min, y_max, grid_res)
    gxx, gyy = np.meshgrid(gx, gy)                  # each (grid_res, grid_res)
    grid_pts = np.c_[gxx.ravel(), gyy.ravel()]      # (grid_res², 2)

    bg_colors = [plt.cm.colors.to_rgba(c, alpha=0.18) for c in CLUSTER_COLORS]
    bg_cmap = ListedColormap(bg_colors)

    # ── Precompute silhouette score at each assign step ───────────────────────
    # One value per iteration (assign and update share the same labels, so
    # computing once per assign frame is sufficient).
    history = []   # list of (frame_idx, iteration, score)
    for fi, frame in enumerate(frames):
        if frame['phase'] == 'assign':
            score = silhouette_score(X, frame['labels'])
            history.append((fi, frame['iteration'], score))

    max_iter = frames[-1]['iteration']
    all_scores = [s for _, _, s in history]
    # Silhouette scores lie in [-1, 1], so clamp the lower limit at -1 rather than 0
    sil_lo = max(-1.0, min(all_scores) - 0.05)
    sil_hi = min(1.0, max(all_scores) + 0.05)

    # ── Figure: main Voronoi plot + silhouette score panel ────────────────────
    fig, (ax, ax_sil) = plt.subplots(
        1, 2, figsize=(11, 5),
        gridspec_kw={'width_ratios': [3, 1]}
    )
    fig.patch.set_facecolor('white')

    # ── Main axes ─────────────────────────────────────────────────────────────
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_xlabel("Sepal Length (cm)", fontsize=11)
    ax.set_ylabel("Sepal Width (cm)", fontsize=11)

    bg_img = ax.imshow(
        np.zeros((grid_res, grid_res), dtype=int),
        extent=[x_min, x_max, y_min, y_max],
        origin='lower', aspect='auto',
        cmap=bg_cmap, vmin=0, vmax=K - 1,
        interpolation='nearest', zorder=0
    )

    scatter = ax.scatter(
        X[:, 0], X[:, 1],
        c=[UNASSIGNED_COLOR] * len(X),
        edgecolors='k', linewidths=0.4, s=45, zorder=2
    )

    centroid_scatters = []
    for k in range(K):
        cs = ax.plot([], [], marker='D', markersize=10,
                     color=CLUSTER_COLORS[k], markeredgecolor='black',
                     markeredgewidth=1.0, linestyle='none', zorder=4)[0]
        centroid_scatters.append(cs)

    caption = ax.text(
        0.5, 1.03, '', transform=ax.transAxes,
        ha='center', va='bottom', fontsize=10,
        bbox=dict(boxstyle='round,pad=0.3', facecolor='#F0F0F0', edgecolor='gray')
    )

    legend_handles = [
        mpatches.Patch(color=CLUSTER_COLORS[k], label=f'Cluster {k + 1}')
        for k in range(K)
    ]
    ax.legend(handles=legend_handles, loc='upper left', fontsize=9)

    # ── Silhouette panel ──────────────────────────────────────────────────────
    ax_sil.set_xlim(0.5, max_iter + 0.5)
    ax_sil.set_ylim(sil_lo, sil_hi)
    ax_sil.set_xlabel("Iteration", fontsize=10)
    ax_sil.set_ylabel("Silhouette Score", fontsize=10)
    ax_sil.set_title("Silhouette Score", fontsize=11)
    ax_sil.set_xticks(range(1, max_iter + 1))
    ax_sil.tick_params(axis='both', labelsize=9)
    # Reference lines for score quality thresholds
    ax_sil.axhline(0.50, color='#999999', linestyle='--', linewidth=0.8, label='0.50')
    ax_sil.axhline(0.75, color='#555555', linestyle='--', linewidth=0.8, label='0.75')
    ax_sil.legend(fontsize=8, loc='lower right')

    hist_line,  = ax_sil.plot([], [], color='steelblue', linewidth=1.5, zorder=2)
    hist_dots   = ax_sil.scatter([], [], color='steelblue', s=40, zorder=3)
    curr_dot    = ax_sil.scatter([], [], color='crimson', s=70, zorder=5)
    curr_text   = ax_sil.text(0, 0, '', fontsize=9, color='crimson',
                               ha='left', va='bottom')

    plt.tight_layout()

    # ── Frame update ──────────────────────────────────────────────────────────
    def update(frame_idx):
        frame = frames[frame_idx]
        centroids = frame['centroids']
        labels    = frame['labels']

        # Voronoi background
        grid_dists  = np.linalg.norm(
            grid_pts[:, None, :] - centroids[None, :, :], axis=2
        )
        grid_labels = np.argmin(grid_dists, axis=1).reshape(grid_res, grid_res)
        bg_img.set_data(grid_labels)

        # Point colors
        if frame['phase'] == 'init':
            colors = [UNASSIGNED_COLOR] * len(X)
        else:
            colors = [CLUSTER_COLORS[l] for l in labels]
        scatter.set_facecolors(colors)

        # Centroid markers
        for k in range(K):
            centroid_scatters[k].set_data([centroids[k, 0]], [centroids[k, 1]])

        # Caption
        caption.set_text(frame['description'])

        # Silhouette panel — reveal scores up to the current frame
        visible = [(it, s) for fi, it, s in history if fi <= frame_idx]
        if visible:
            its, scs = zip(*visible)
            hist_line.set_data(its, scs)
            hist_dots.set_offsets(np.c_[its, scs])
            cx, cy = its[-1], scs[-1]
            curr_dot.set_offsets([[cx, cy]])
            # nudge label so it doesn't overlap the dot
            label_x = cx + 0.08 if cx < max_iter - 0.5 else cx - 0.08
            curr_text.set_position((label_x, cy))
            curr_text.set_ha('left' if cx < max_iter - 0.5 else 'right')
            curr_text.set_text(f'{cy:.3f}')
        else:
            hist_line.set_data([], [])
            hist_dots.set_offsets(np.empty((0, 2)))
            curr_dot.set_offsets(np.empty((0, 2)))
            curr_text.set_text('')

        return ([bg_img, scatter, caption,
                 hist_line, hist_dots, curr_dot, curr_text]
                + centroid_scatters)

    anim = FuncAnimation(
        fig, update,
        frames=len(frames),
        interval=interval_ms,
        blit=False,
        repeat=True
    )

    plt.close(fig)
    return anim


anim = make_animation(X, frames, K=K, interval_ms=1000)
HTML(anim.to_jshtml())
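
The background shading in `make_animation` labels every grid pixel by its nearest centroid, which is exactly what defines a Voronoi region. The sketch below checks that idea in isolation on a coarse 5×5 grid with made-up centroids (the names `centroids_demo` and `res` are illustrative). The detail to watch is the shape of the grid points: the distance computation expects one (x, y) row per pixel, so the meshgrid arrays are flattened before being paired up.

```python
import numpy as np

centroids_demo = np.array([[0.0, 0.0], [4.0, 0.0], [2.0, 3.0]])  # made-up centroids
res = 5                                                          # coarse grid, easy to inspect

gx = np.linspace(-1.0, 5.0, res)
gy = np.linspace(-1.0, 4.0, res)
gxx, gyy = np.meshgrid(gx, gy)   # each (res, res); rows follow y, columns follow x

# One (x, y) row per pixel: shape (res * res, 2).
grid_pts = np.column_stack([gxx.ravel(), gyy.ravel()])

# Joining the 2-D arrays directly would glue them side by side instead.
side_by_side = np.c_[gxx, gyy]   # shape (res, 2 * res), not usable as points

# Nearest-centroid label per pixel, reshaped back to the image layout.
d = np.linalg.norm(grid_pts[:, None, :] - centroids_demo[None, :, :], axis=2)
regions = d.argmin(axis=1).reshape(res, res)

print(grid_pts.shape)       # (25, 2)
print(side_by_side.shape)   # (5, 10)
print(regions[::-1])        # flipped so the top row is the largest y, as in the plot
```

Because the rows of `regions` follow y from low to high, the array lines up with `imshow(..., origin='lower')` without any transpose.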

16.4. Try Different Initializations

A key weakness of k-means is sensitivity to initialization — different starting centroids can lead to different (sometimes worse) final clusters. Run the cell below a few times with different seeds to see this in action.

import random

# A fresh seed is drawn on every run; set SEED to a fixed integer to reproduce a run
SEED = random.randint(0, 999)
print(f"Using seed: {SEED}")

frames_alt = run_kmeans(X, K=K, rng=np.random.default_rng(SEED))
# run_kmeans appends a 'converged' frame only if it stops before max_iter
status = "Converged" if frames_alt[-1]['phase'] == 'converged' else "Stopped at max_iter"
print(f"{status} after {frames_alt[-1]['iteration']} iteration(s), {len(frames_alt)} frames.")

anim_alt = make_animation(X, frames_alt, K=K, interval_ms=1000)
HTML(anim_alt.to_jshtml())
Using seed: 128
Stopped at max_iter after 15 iteration(s), 31 frames.
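
Initialization sensitivity is easiest to see on a dataset small enough to check by hand. The sketch below is not part of the notebook's pipeline. The helper name `lloyd`, the four corner points, and the use of inertia (the within-cluster sum of squared distances) as a quality measure are all assumptions made for this example. It runs plain k-means on the corners of a 4×1 rectangle from two different pairs of starting points.

```python
import numpy as np

def lloyd(X, init_idx, max_iter=50):
    """Plain k-means started from the given rows of X; returns (labels, inertia)."""
    centroids = X[list(init_idx)].astype(float)
    for _ in range(max_iter):
        d = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        labels = d.argmin(axis=1)
        new = np.array([X[labels == j].mean(axis=0) for j in range(len(centroids))])
        if np.allclose(new, centroids):
            break
        centroids = new
    # Inertia: total squared distance from each point to its own centroid.
    inertia = float(((X - centroids[labels]) ** 2).sum())
    return labels, inertia

# Corners of a wide rectangle (made-up data): the natural split is left vs. right.
corners = np.array([[0.0, 0.0], [0.0, 1.0], [4.0, 0.0], [4.0, 1.0]])

labels_a, inertia_a = lloyd(corners, init_idx=(0, 2))  # one start on each side
labels_b, inertia_b = lloyd(corners, init_idx=(0, 1))  # both starts on the left side

print(labels_a, inertia_a)  # [0 0 1 1] 1.0
print(labels_b, inertia_b)  # [0 1 0 1] 16.0
```

Both runs converge, but the second one settles on a bottom-versus-top split whose inertia is sixteen times larger. Each point is still nearest to its own centroid in that solution, so the algorithm has no reason to leave it. This is the usual motivation for trying several seeds and keeping the best result.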