K-means clustering visualization using D3.js
In this tutorial we will be implementing K-means clustering using JavaScript and D3.js.
Our project will have two files: index.html for markup and index.js for the script.
Let’s start by creating the HTML file. In the head tag we import the D3.js library:
<head>
<script src='https://d3js.org/d3.v6.js'></script>
</head>
And in the body tag:
- Add an empty div container, where all the action will happen. We assign it an id so we can select it in our JS code.
- Include our script, which will make the action happen.
<body>
<div id='container'></div>
<script src='index.js'></script>
</body>
Moving on to the index.js file, let’s start by specifying dimensions and margins:
const margin = { top: 10, right: 60, bottom: 20, left: 20 };
const viewBox = { x: 0, y: 0, w: 1000, h: 600 };
const width = viewBox.w - margin.left - margin.right;
const height = viewBox.h - margin.top - margin.bottom;
Next, let’s fire up D3.js by selecting our div container by its id (#container) and adding an SVG with a group (g tag) that will take up all the space except the margins.
const svg = d3.select('#container')
  .append('svg')
  .attr('viewBox', `${viewBox.x} ${viewBox.y} ${viewBox.w} ${viewBox.h}`)
  .attr('width', window.innerWidth - margin.left - margin.right)
  .attr('height', window.innerHeight - margin.top - margin.bottom)
  .append('g')
  .attr('transform', `translate(${margin.left}, ${margin.top})`) // mind the margins
  .attr('color', '#e6e8ea') // font color
  .attr('font-weight', 'bold') // we are bold enough to do this
  .attr('stroke-width', 2); // and even this
Let’s add a resize listener, so the sneaky testers won’t complain.
window.addEventListener('resize', function (event) { // testers hate this one simple function
  d3.select('svg')
    .attr('viewBox', `${viewBox.x} ${viewBox.y} ${viewBox.w} ${viewBox.h}`)
    .attr('width', window.innerWidth - margin.left - margin.right)
    .attr('height', window.innerHeight - margin.top - margin.bottom);
});
Do you like dark themed pages? I do, so let’s make the background dark, but not too dark. How about #1e1e1e?
document.body.style.background = '#1e1e1e';
D3.js makes it easy to add axes. First we initialize the domains and the scaling functions. I usually grade arbitrary data points on a scale from 0 to 10, so let that be our domain:
const xrange = [0, 10];
const x = d3.scaleLinear()
  .domain(xrange) // values from our domain (0 to 10)
  .range([0, width]); // will be assigned a valid x coordinate
const yrange = [0, 10];
const y = d3.scaleLinear()
  .domain(yrange) // remember that in SVG the y axis points downwards
  .range([height, 0]); // but we want our axis pointing upwards, like a normal damn axis
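To make the mapping concrete, here is a minimal D3-free sketch of what a linear scale computes. The linearScale helper and the xSketch and ySketch names are hypothetical and exist only for illustration (they are not part of D3), and the 920 and 570 pixel values assume the margins and viewBox defined earlier.

```javascript
// Hypothetical helper for illustration only; d3.scaleLinear offers much more (ticks, clamping, inversion).
const linearScale = (domain, range) => (value) => {
  const t = (value - domain[0]) / (domain[1] - domain[0]); // relative position inside the domain
  return range[0] + t * (range[1] - range[0]);             // same relative position inside the range
};

const xSketch = linearScale([0, 10], [0, 920]); // 920 = 1000 - 20 - 60
const ySketch = linearScale([0, 10], [570, 0]); // 570 = 600 - 10 - 20, with the range flipped

console.log(xSketch(5));  // 460
console.log(ySketch(0));  // 570 (bottom of the plot)
console.log(ySketch(10)); // 0 (top of the plot)
```

Flipping the range is all it takes to make larger values appear higher on the screen.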
We also initialize a scaling function for coloring points according to their cluster:
const color = d3.scaleOrdinal(d3.schemeCategory10); // 10 different colors for 10 different numbers
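When no domain is given, an ordinal scale hands out range values in the order it first sees each input. The sketch below imitates that behavior in plain JavaScript. The ordinalScale helper and the three-color palette are hypothetical, and comparing inputs by their string form is a simplifying assumption rather than a description of D3 internals.

```javascript
// Hypothetical stand-in for an ordinal scale with an implicit domain.
const ordinalScale = (palette) => {
  const seen = new Map(); // input (as a string) -> order of first appearance
  return (value) => {
    const key = String(value);
    if (!seen.has(key)) seen.set(key, seen.size); // new input: claim the next slot
    return palette[seen.get(key) % palette.length]; // wrap around when the palette runs out
  };
};

const colorSketch = ordinalScale(['red', 'green', 'blue']); // placeholder palette

console.log(colorSketch(null)); // red   (first input seen)
console.log(colorSketch(0));    // green (second input seen)
console.log(colorSketch(null)); // red   (same input, same color)
console.log(colorSketch(1));    // blue
console.log(colorSketch(2));    // red   (palette wrapped around)
```

One consequence under this assumption: if null is the first value passed to the scale, it claims a palette entry of its own, so 5 clusters plus the "unknown" state would use 6 of the 10 colors.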
Then we add the axes themselves to the SVG:
svg.append('g')
  .attr('transform', `translate(0, ${height})`) // placed at the bottom
  .call(d3.axisBottom(x));
svg.append('g')
  .call(d3.axisLeft(y));
Now that we have the axes, let’s generate some random points. We assign them random coordinates from our domains. The cluster is initially unknown, so let’s set it to null.
const getRandomPoint = () => {
  const point = {
    x: Math.random() * xrange[1],
    y: Math.random() * yrange[1],
    cluster: null
  };
  return point;
}
const generatePoints = (n) => { // generate an array of n random points
  return Array.from(Array(n)).map(_ => getRandomPoint());
}
Suppose we have 1000 points and 5 clusters. We can use the same generatePoints function to initialize both the data points and the cluster centroids:
const numPoints = 1000;
const numClusters = 5;
const points = generatePoints(numPoints);
const centroids = generatePoints(numClusters);
Now that we have some random data, let’s visualize it using D3.js:
const pointsSvg = svg.append('g') // place them in a group, so they don't run away
  .attr('id', 'points-svg') // assign them an id, taking away their individuality
  .selectAll('circle') // no circles exist yet, so the selection starts empty
  .data(points) // loop over our data
  .join('circle') // add a circle
  .attr('cx', d => x(d.x)) // position
  .attr('cy', d => y(d.y))
  .attr('r', 4) // radius
  .style('fill', d => color(d.cluster)); // color according to the cluster
And almost the same thing for centroids:
const centroidsSvg = svg.append('g')
  .attr('id', 'centroids-svg')
  .selectAll('circle') // again an empty selection to start from
  .data(centroids)
  .join('circle')
  .attr('cx', d => x(d.x))
  .attr('cy', d => y(d.y))
  .attr('r', 5) // a bit bigger than data points
  .style('fill', '#e6e8ea') // greyish fill
  .attr('stroke', (d, i) => color(i)) // and a thick colorful outline
  .attr('stroke-width', 2);
Remember we set the initial cluster to null? Let’s fix it by writing a function to find the index of the closest centroid to a given point:
const distance = (a, b) => { // Euclidean distance in 2D
  return Math.sqrt((a.x - b.x) ** 2 + (a.y - b.y) ** 2);
}
const closestCentroid = (point) => {
  const distances = centroids.map(centroid => distance(point, centroid)); // distance to each centroid
  const i = distances.indexOf(Math.min(...distances)); // index of the closest centroid (first one wins on ties)
  return i;
}
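As a quick sanity check, here is a small standalone example of the nearest-centroid lookup on hard-coded data. The names dist, nearestIndex and sampleCentroids, as well as the coordinates, are made up for illustration, and the centroids are passed in as an argument so the snippet runs without the rest of the script.

```javascript
// Standalone sketch with made-up data; the names are hypothetical.
const dist = (a, b) => Math.sqrt((a.x - b.x) ** 2 + (a.y - b.y) ** 2);

const nearestIndex = (point, centroidList) => {
  const distances = centroidList.map((c) => dist(point, c));
  return distances.indexOf(Math.min(...distances)); // first match wins on ties
};

const sampleCentroids = [{ x: 1, y: 1 }, { x: 5, y: 5 }, { x: 9, y: 1 }];

console.log(nearestIndex({ x: 4, y: 6 }, sampleCentroids)); // 1
console.log(nearestIndex({ x: 8, y: 2 }, sampleCentroids)); // 2
console.log(nearestIndex({ x: 3, y: 3 }, sampleCentroids)); // 0 (equidistant from centroids 0 and 1)
```

The last call shows the tie-breaking rule: a point exactly between two centroids goes to the one with the lower index.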
Using the closestCentroid function we can identify which cluster a point belongs to and color it accordingly. Let’s make the recoloring process smooth by adding a 500ms transition:
const updatePoints = () => {
  points.forEach(point => {
    point.cluster = closestCentroid(point);
  });
  pointsSvg.transition()
    .duration(500)
    .style('fill', d => color(d.cluster));
}
Updating the cluster centroids consists of calculating the average position of all data points belonging to that cluster and then updating the centroid position. Also smooth:
const avg = (arr) => arr.reduce((p, c) => p + c, 0) / arr.length; // average of a numeric array
const updateCentroids = () => {
  centroids.forEach((centroid, i) => {
    const cluster = points.filter(point => point.cluster === i); // all points in the cluster
    if (cluster.length > 0) {
      centroid.x = avg(cluster.map(point => point.x)); // calculate average position
      centroid.y = avg(cluster.map(point => point.y));
    }
  });
  centroidsSvg.transition()
    .duration(500)
    .attr('cx', d => x(d.x))
    .attr('cy', d => y(d.y)); // update centroid position
}
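The length check matters because averaging an empty array divides 0 by 0 and yields NaN. Here is a small D3-free sketch of the same update step on made-up data, where the third centroid owns no points. The names mean, samplePoints and movedCentroids are hypothetical.

```javascript
// Standalone sketch with made-up data; no D3 or DOM needed.
const mean = (values) => values.reduce((sum, v) => sum + v, 0) / values.length;

const samplePoints = [
  { x: 1, y: 1, cluster: 0 },
  { x: 3, y: 1, cluster: 0 },
  { x: 8, y: 8, cluster: 1 },
];
const movedCentroids = [{ x: 0, y: 0 }, { x: 9, y: 9 }, { x: 5, y: 5 }]; // the third one has no points

movedCentroids.forEach((centroid, i) => {
  const members = samplePoints.filter((p) => p.cluster === i);
  if (members.length === 0) return; // empty cluster: keep the old position
  centroid.x = mean(members.map((p) => p.x));
  centroid.y = mean(members.map((p) => p.y));
});

console.log(movedCentroids);
// [ { x: 2, y: 1 }, { x: 8, y: 8 }, { x: 5, y: 5 } ]
console.log(Number.isNaN(mean([]))); // true, which is what the guard avoids
```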
That’s the whole logic. Let’s start it up and update every second:
updatePoints(); // assign the initial cluster
setInterval(() => {
  updateCentroids();
  updatePoints();
}, 1000); // 1000ms = 1s
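The interval above keeps running even after the clusters have settled. One possible refinement, which is not part of the original script, is to stop once no point changes its cluster. The sketch below shows the idea as a plain synchronous loop on hard-coded data, without D3 or timers. All names (pts, cents, assign, recenter) are hypothetical.

```javascript
// Standalone sketch: iterate until assignments stop changing. Data is hard-coded so the run is deterministic.
const pts = [
  { x: 1, y: 1, cluster: null }, { x: 2, y: 1, cluster: null },
  { x: 8, y: 9, cluster: null }, { x: 9, y: 8, cluster: null },
];
const cents = [{ x: 0, y: 0 }, { x: 10, y: 10 }];

const sqDist = (a, b) => (a.x - b.x) ** 2 + (a.y - b.y) ** 2; // squared distance is enough for comparisons

// Reassign every point to its nearest centroid; return how many points changed cluster.
const assign = () => {
  let changed = 0;
  pts.forEach((p) => {
    let best = 0;
    cents.forEach((c, i) => { if (sqDist(p, c) < sqDist(p, cents[best])) best = i; });
    if (best !== p.cluster) { p.cluster = best; changed += 1; }
  });
  return changed;
};

// Move every non-empty centroid to the average position of its points.
const recenter = () => {
  cents.forEach((c, i) => {
    const members = pts.filter((p) => p.cluster === i);
    if (members.length === 0) return;
    c.x = members.reduce((s, p) => s + p.x, 0) / members.length;
    c.y = members.reduce((s, p) => s + p.y, 0) / members.length;
  });
};

let iterations = 0;
let changed = assign(); // initial assignment
do {
  recenter();
  changed = assign();
  iterations += 1;
} while (changed > 0 && iterations < 100); // the cap of 100 is an arbitrary safety net

console.log(iterations, cents);
// 1 [ { x: 1.5, y: 1 }, { x: 8.5, y: 8.5 } ]
```

In the timer-based version, the same check could be done by keeping the id returned by setInterval and calling clearInterval on it once an update reports zero changed points.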
It works! Looks smooth. I also added controls, so you can enter any number of data points and clusters.
Check it out: https://ivanludvig.github.io/kmeans/
The conclusion that I want to make today is the following:
Slow is smooth, smooth is fast