Demo
Full code on GitHub

In this tutorial we will be implementing K-means clustering using JavaScript and D3.js.

Our project will have two files: index.html for the markup and index.js for the script. Let’s start by creating the HTML file. In the head tag we import the D3.js library:

<head>
    <script src='https://d3js.org/d3.v6.js'></script>
</head>

And in the body tag:

  • Add an empty div container, where all the action will happen. We assign it an id so we can select it in our JS code.
  • Include our script, which will make the action happen.
<body>
    <div id='container'></div>
    <script src='index.js'></script>
</body>

Moving on to the index.js file, let’s start with specifying dimensions and margins:

const margin = { top: 10, right: 60, bottom: 20, left: 20 };
const viewBox = { x: 0, y: 0, w: 1000, h: 600 };
const width = viewBox.w - margin.left - margin.right;
const height = viewBox.h - margin.top - margin.bottom;

Next, let’s fire up D3.js by selecting our div container by its id (#container) and adding an SVG with a group (g tag) that takes up all the space except the margins.

const svg = d3.select('#container')
    .append('svg')
    .attr('viewBox', `${viewBox.x} ${viewBox.y} ${viewBox.w} ${viewBox.h}`)
    .attr('width', window.innerWidth - margin.left - margin.right)
    .attr('height', window.innerHeight - margin.top - margin.bottom)
    .append('g')
    .attr('transform', `translate(${margin.left}, ${margin.top})`)     // mind the margins
    .attr('color', '#e6e8ea')                                          // font color
    .attr('font-weight', 'bold')                                       // we are bold enough to do this
    .attr('stroke-width', 2);                                          // and even this

Let’s add a resize listener, so the sneaky testers won’t complain.

window.addEventListener('resize', () => {               // testers hate this one simple function
    d3.select('svg')
        .attr('viewBox', `${viewBox.x} ${viewBox.y} ${viewBox.w} ${viewBox.h}`)
        .attr('width', window.innerWidth - margin.left - margin.right)
        .attr('height', window.innerHeight - margin.top - margin.bottom);
});

Do you like dark themed pages? I do, so let’s make the background dark, but not too dark. How about #1e1e1e?

document.body.style.background = '#1e1e1e';

D3.js makes it easy to add axes. First we initialize the domains and the scaling functions. I usually grade arbitrary data points from 0 to 10, so let that be our domain on both axes:

const xrange = [0, 10];
const x = d3.scaleLinear()
    .domain(xrange)         // values from our domain (0 to 10)
    .range([0, width]);     // will be assigned a valid x coordinate

const yrange = [0, 10];
const y = d3.scaleLinear()
    .domain(yrange)         // remember that in SVG the y axis points downwards
    .range([height, 0]);    // but we want our axis pointing upwards, like a normal damn axis
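
If the scaling functions feel like magic, here is a rough sketch of what a linear scale does, written in plain JavaScript so it runs without D3. The makeLinearScale helper and the demo names are made up for this illustration and are not D3’s actual implementation; the numbers assume the viewBox and margins defined earlier.

```javascript
// Hand-rolled stand-in for a linear scale (hypothetical helper, not D3's code).
const makeLinearScale = (domain, range) => (value) => {
    const t = (value - domain[0]) / (domain[1] - domain[0]); // relative position inside the domain
    return range[0] + t * (range[1] - range[0]);             // same relative position inside the range
};

const demoWidth = 920;   // 1000 - 20 - 60, from viewBox.w and the left/right margins
const demoHeight = 570;  // 600 - 10 - 20, from viewBox.h and the top/bottom margins

const demoX = makeLinearScale([0, 10], [0, demoWidth]);
const demoY = makeLinearScale([0, 10], [demoHeight, 0]);     // flipped range, so bigger values sit higher

console.log(demoX(5));   // 460, the middle of the plot
console.log(demoY(0));   // 570, the bottom of the plot
console.log(demoY(10));  // 0, the top of the plot
```

Flipping the range rather than the data is what lets every later call simply write y(d.y) without thinking about SVG’s downward-pointing axis.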

We also initialize a scaling function for coloring points according to their cluster:

const color = d3.scaleOrdinal(d3.schemeCategory10); // 10 different colors for 10 different numbers

Then we add the axes themselves to the SVG:

svg.append('g')
    .attr('transform', `translate(0, ${height})`)  // placed at the bottom
    .call(d3.axisBottom(x));

svg.append('g')
    .call(d3.axisLeft(y));

Now that we have the axes, let’s generate some random points. We assign them random coordinates from our domains. The cluster is initially unknown, so let’s set it to null.

const getRandomPoint = () => {
    const point = {
        x: Math.random() * xrange[1],
        y: Math.random() * yrange[1],
        cluster: null
    };

    return point;
}

const generatePoints = (n) => {     // generate an array of n random points
    return Array.from(Array(n)).map(_ => getRandomPoint());
}
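
One thing to keep in mind: multiplying Math.random() by the upper bound only works because our domain starts at 0. If you ever change the domain, a variant along these lines might be handy. It is just a sketch, and randomBetween and randomPointIn are names invented for this example.

```javascript
// Random number in [min, max), for a domain that does not have to start at 0.
const randomBetween = ([min, max]) => min + Math.random() * (max - min);

// Same shape as the points above: coordinates plus a not-yet-known cluster.
const randomPointIn = (xDomain, yDomain) => ({
    x: randomBetween(xDomain),
    y: randomBetween(yDomain),
    cluster: null
});

// Five points with x in [2, 4) and y in [-1, 1)
const samplePoints = Array.from({ length: 5 }, () => randomPointIn([2, 4], [-1, 1]));
console.log(samplePoints);
```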

Suppose we have 1000 points and 5 clusters. We can use the same generatePoints function to initialize both the data points and the cluster centroids:

const numPoints = 1000;
const numClusters = 5;
const points = generatePoints(numPoints);
const centroids = generatePoints(numClusters);

Now that we have some random data, let’s visualize it using D3.js:

const pointsSvg = svg.append('g')          // place them in a group, so they don't run away
    .attr('id', 'points-svg')              // assign them an id, taking away their individuality
    .selectAll('circle')                   // nothing there yet, so every data point gets a new circle
    .data(points)                          // loop over our data
    .join('circle')                        // add a circle
    .attr('cx', d => x(d.x))               // position
    .attr('cy', d => y(d.y))
    .attr('r', 4)                          // radius
    .style('fill', d => color(d.cluster)); // color according to the cluster

And almost the same thing for centroids:

const centroidsSvg = svg.append('g')
    .attr('id', 'centroids-svg')
    .selectAll('circle')
    .data(centroids)
    .join('circle')
    .attr('cx', d => x(d.x))
    .attr('cy', d => y(d.y))
    .attr('r', 5)                       // a bit bigger than data points
    .style('fill', '#e6e8ea')           // greyish fill
    .attr('stroke', (d, i) => color(i)) // and a thick colorful outline
    .attr('stroke-width', 2);

Remember we set the initial cluster to null? Let’s fix that by writing a function that finds the index of the centroid closest to a given point:

const distance = (a, b) => {    // Euclidean distance in 2D
    return Math.sqrt((a.x - b.x) ** 2 + (a.y - b.y) ** 2);
}

const closestCentroid = (point) => {
    const distances = centroids.map(centroid => distance(point, centroid));   // distance to each centroid
    return distances.indexOf(Math.min(...distances));                         // index of the closest centroid
}
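
To convince yourself the assignment step does what you expect, here is a tiny standalone example with three hand-picked centroids. It is only a sketch: dist, nearestIndex and demoCentroids are names made up for this example, and the helper takes the centroids as an argument so it can run on its own in Node.

```javascript
// Euclidean distance in 2D
const dist = (a, b) => Math.hypot(a.x - b.x, a.y - b.y);

// Index of the centroid closest to the point (the first one wins on ties)
const nearestIndex = (point, centers) => {
    let best = 0;
    for (let i = 1; i < centers.length; i++) {
        if (dist(point, centers[i]) < dist(point, centers[best])) best = i;
    }
    return best;
};

const demoCentroids = [{ x: 1, y: 1 }, { x: 9, y: 9 }, { x: 1, y: 9 }];

console.log(nearestIndex({ x: 2, y: 2 }, demoCentroids)); // 0
console.log(nearestIndex({ x: 8, y: 7 }, demoCentroids)); // 1
console.log(nearestIndex({ x: 3, y: 8 }, demoCentroids)); // 2
```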

Using the closestCentroid function we can identify which cluster a point belongs to and color it accordingly. Let’s make the recoloring process smooth by adding a 500ms transition:

const updatePoints = () => {
    points.forEach(point => {
        point.cluster = closestCentroid(point);
    });
    pointsSvg.transition()
        .duration(500)
        .style('fill', d => color(d.cluster));
}

Updating the cluster centroids consists of calculating the average position of all data points belonging to that cluster and then updating the centroid position. Also smooth:

const avg = (arr) => arr.reduce((p, c) => p + c, 0) / arr.length;      // average of a numeric array
const updateCentroids = () => {
    centroids.forEach((centroid, i) => {
        const cluster = points.filter(point => point.cluster === i);   // all points in the cluster
        if (cluster.length > 0) {
            centroid.x = avg(cluster.map(point => point.x));           // calculate average position
            centroid.y = avg(cluster.map(point => point.y));
        }
    });
    centroidsSvg.transition()
        .duration(500)
        .attr('cx', d => x(d.x))
        .attr('cy', d => y(d.y));                                      // update centroid position
}
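
Here is the same update step on three points, without any D3, so you can check the arithmetic by hand. Treat it as a sketch; mean, demoPoints and demoCenters are names invented for this example. Note the third centroid: no point belongs to it, and without the empty-cluster check its average would be 0 / 0, which is NaN.

```javascript
// Average of a numeric array
const mean = (values) => values.reduce((sum, v) => sum + v, 0) / values.length;

const demoPoints = [
    { x: 1, y: 2, cluster: 0 },
    { x: 3, y: 4, cluster: 0 },
    { x: 8, y: 9, cluster: 1 },
];
const demoCenters = [{ x: 0, y: 0 }, { x: 10, y: 10 }, { x: 5, y: 5 }];

demoCenters.forEach((center, i) => {
    const members = demoPoints.filter(p => p.cluster === i);  // all points in cluster i
    if (members.length === 0) return;                         // empty cluster: keep the old position
    center.x = mean(members.map(p => p.x));
    center.y = mean(members.map(p => p.y));
});

console.log(demoCenters);
// [ { x: 2, y: 3 }, { x: 8, y: 9 }, { x: 5, y: 5 } ]
```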

That’s the whole logic. Let’s start it up and update every second:

updatePoints();          // assign the initial cluster
setInterval(() => {
    updateCentroids();
    updatePoints();
}, 1000);               // 1000ms = 1s
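
The interval above keeps ticking forever, even after the clusters have settled. One possible refinement, sketched here without D3 so it runs on its own, is to stop as soon as an assignment pass changes nothing. All names in this block (sqDist, kmeansStep, demoPts, demoCs) are made up for the example, and the cap of 100 iterations is an arbitrary safety net.

```javascript
// Squared distance is enough for comparisons, so we skip the square root.
const sqDist = (a, b) => (a.x - b.x) ** 2 + (a.y - b.y) ** 2;

// One k-means step: reassign points, then move centroids.
// Returns how many points changed cluster.
const kmeansStep = (pts, centers) => {
    let changed = 0;
    pts.forEach(p => {
        let best = 0;
        centers.forEach((c, i) => {
            if (sqDist(p, c) < sqDist(p, centers[best])) best = i;
        });
        if (p.cluster !== best) { p.cluster = best; changed++; }
    });
    centers.forEach((c, i) => {
        const members = pts.filter(p => p.cluster === i);
        if (members.length === 0) return;                     // empty cluster stays put
        c.x = members.reduce((s, p) => s + p.x, 0) / members.length;
        c.y = members.reduce((s, p) => s + p.y, 0) / members.length;
    });
    return changed;
};

// Two obvious blobs and two starting centroids
const demoPts = [[1, 1], [1, 2], [2, 1], [8, 8], [9, 8], [8, 9]]
    .map(([x, y]) => ({ x, y, cluster: null }));
const demoCs = [{ x: 0, y: 0 }, { x: 10, y: 10 }];

let iterations = 0;
while (kmeansStep(demoPts, demoCs) > 0 && iterations < 100) iterations++;

console.log(iterations);                    // 1
console.log(demoPts.map(p => p.cluster));   // [ 0, 0, 0, 1, 1, 1 ]
```

In the page itself, the same idea would mean keeping the id returned by setInterval and calling clearInterval with it once a pass changes no assignments.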

It works! Looks smooth. I also added controls, so you can enter any number of data points and clusters.
Check it out: https://ivanludvig.github.io/kmeans/

The conclusion that I want to make today is the following:

Slow is smooth, smooth is fast