Full code
Demo (requires WebGPU support)

Atwood’s Law states:

Any application that can be written in JavaScript, will eventually be written in JavaScript.

Once I had an assignment to implement the k-means clustering algorithm with a visualization in any programming language. Considering that I had experience making pretty visualizations with JavaScript, that’s what I chose and it turned out great. Naturally, someone told me that JS is slow compared to C++, Python or whatever (Rust wasn’t a big thing then). But now, years later, having heard of the trendy and efficient WebGPU, I decided to check it out by implementing k-means on the GPU with JavaScript. I am especially excited about compute shaders.

Setting up

We start by creating the index.html file with a 512x512 canvas and linking a CSS file styles.css and JS script index.js.

<!DOCTYPE html>
<html>
  <head>
    <meta charset='UTF-8'>
    <title>K-Means with WebGPU</title>
    <link rel='stylesheet' href='styles.css'>
  </head>
  <body>
    <canvas width='512' height='512'></canvas>
    <script type='module' src='index.js'></script>
  </body>
</html>

In styles.css we set a black background and position the canvas in the center.

body {
    padding: 0;
    margin: 0;
    background-color: #000;
}

canvas {
    position: absolute;
    top: 50%;
    left: 50%;
    transform: translate(-50%, -50%);
    border: 2px solid purple;
}

In the index.js file we check for WebGPU support (browser compatibility can be found here).

const adapter = await navigator.gpu?.requestAdapter();
const device = await adapter?.requestDevice();
if (!device) {
    window.alert('WebGPU not supported');
    throw new Error('WebGPU not supported');
}

const canvas = document.querySelector('canvas');
const context = canvas.getContext('webgpu');
const canvasFormat = navigator.gpu.getPreferredCanvasFormat();
context.configure({ device, format: canvasFormat });

Let’s assume we are in 2D, so we can visualize it easily. Suppose we want to cluster N points into K clusters. We store the following constants in config.js. RESOLUTION is the size of the canvas and WORKGROUP_SIZE is a parameter used for compute shaders.

const N = 1000;
const K = 5;
const RESOLUTION = 512;
const WORKGROUP_SIZE = 4;

export { N, K, RESOLUTION, WORKGROUP_SIZE };

The data we are working with are the following:

  • points: array of floats of size 2N (two coordinates for each point)
  • centroids: array of floats (coordinates) of size 2K
  • clusters: array of integers of size N, indicating the id of the assigned cluster for each point

We observe that during running k-means, points remains constant, while centroids and clusters change with each iteration (unless they converge). Our app will be split into two parts: computing and rendering. The solution I came up with is shown in the following diagram.

scheme

Render

First, we will implement rendering of the data points as squares and cluster centroids as triangles.

!
Beware: boilerplate ahead
keep junior web developers away from the screen

Vertex Buffers

To render data points, we first create a vertex buffer with a square (a square is 2 triangles = 6 vertices).

const pointVertices = new Float32Array([
    -1, -1,
    1, -1,
    1, 1,
    -1, -1,
    1, 1,
    -1, 1
]);
const pointVertexBuffer = device.createBuffer({
    size: pointVertices.byteLength,
    usage: GPUBufferUsage.VERTEX | GPUBufferUsage.COPY_DST
});
device.queue.writeBuffer(pointVertexBuffer, 0, pointVertices);

In a similar way we define the vertex buffer for centroids, which are rendered as bigger triangles.

const centroidVertices = new Float32Array([
    -2, -2,
    2, -2,
    0, 2
]);
const centroidVertexBuffer = device.createBuffer({
    size: centroidVertices.byteLength,
    usage: GPUBufferUsage.VERTEX | GPUBufferUsage.COPY_DST
});
device.queue.writeBuffer(centroidVertexBuffer, 0, centroidVertices);

Next, the layout of vertex buffers is specified: the array contains pairs of coordinates, represented by 4-byte (32 bit) floats.

const vertexBufferLayout = {
    arrayStride: 2 * 4, // 2 x 4 bytes
    attributes: [{ format: 'float32x2', offset: 0, shaderLocation: 0 }]
};

Storage Buffers

Data point coordinates are generated randomly as 2N floats in the range (0, 1). Storage buffers are used rather than uniform buffers as they can hold more data. Uniform buffers have a minimum maximum (yeah, right) size of 64k. Since we copy data into these buffers, we add the COPY_DST usage flag.

const points = new Float32Array(2 * N);
for (let i = 0; i < 2 * N; i++) {
    points[i] = Math.random();
}
const pointsBuffer = device.createBuffer({
    size: points.byteLength,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST
});
device.queue.writeBuffer(pointsBuffer, 0, points);

Similarly, 2K floats are generated for coordinates of K centroids.

const centroids = new Float32Array(2 * K);
for (let i = 0; i < N; i++) {
    centroids[i] = Math.random();
}
const centroidsBuffer = device.createBuffer({
    size: centroids.byteLength,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST
});
device.queue.writeBuffer(centroidsBuffer, 0, centroids);

To store assigned clusters to points, we create a buffer, but keep it empty.

const clustersBuffer = device.createBuffer({
    size: 4 * N, // 4 bytes x N points
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST
});

Then we create a bind group and the rendering pipeline. But first we define the layouts, specifying that these buffers are visible to the vertex and compute shaders and read-only storage buffers.

const renderBindGroupLayout = device.createBindGroupLayout({
    entries: [{
        binding: 0,
        visibility: GPUShaderStage.VERTEX | GPUShaderStage.COMPUTE,
        buffer: { type: 'read-only-storage' }
    }, {
        binding: 1,
        visibility: GPUShaderStage.VERTEX | GPUShaderStage.COMPUTE,
        buffer: { type: 'read-only-storage' }
    },{
        binding: 2,
        visibility: GPUShaderStage.VERTEX | GPUShaderStage.COMPUTE,
        buffer: { type: 'read-only-storage' }
    }]
});

const renderBindGroup = device.createBindGroup({
    layout: renderBindGroupLayout,
    entries: [
        { binding: 0, resource: { buffer: pointsBuffer } },
        { binding: 1, resource: { buffer: centroidsBuffer } },
        { binding: 2, resource: { buffer: clustersBuffer } }
    ]
});

const renderPipelineLayout = device.createPipelineLayout({
    bindGroupLayouts: [renderBindGroupLayout]
});

Shaders

The vertex shader for points runs once for every vertex in the vertex buffer for each data point. It returns a struct containing position (like any vertex shader) and the cluster index of a point, passing it to the fragment shader. By using this design, we can use the same fragment shader for both entities. To use storage buffers, we specify their bind group index via @group() and binding indices via @binding(), which were defined in the bind group layout. Inside the main vs function we do a simple transformation, mapping each point position from (0, 1) into the clip space (-1, 1) and stretching the square so it’s 4px big.

struct VertexOutput {
    @builtin(position) position: vec4f,
    @location(0) @interpolate(flat) cluster: u32
};

@group(0) @binding(0) var<storage> points: array<f32>;
@group(0) @binding(2) var<storage> clusters: array<u32>;

@vertex
fn vs(
    @location(0) vertex: vec2f, 
    @builtin(instance_index) i: u32
) -> VertexOutput {
    let position = 2 * vec2f(points[2*i], points[2*i + 1]) - 1 + 4 * vertex/${RESOLUTION};

    var output: VertexOutput;
    output.position = vec4f(position, 0, 1);
    output.cluster = clusters[i];

    return output;
}

How did we use RESOLUTION? Oh, we can’t work with JS without resorting to ugly solutions. The shader is stored in a JS file pointsShader.js, which exports it as a string. The constants are imported from config.js.

A similar vertex shader is used for rendering centroids.

struct VertexOutput {
    @builtin(position) position: vec4f,
    @location(0) @interpolate(flat) cluster: u32
};

@group(0) @binding(1) var<storage> centroids: array<f32>;

@vertex
fn vs(
    @location(0) vertex: vec2f,
    @builtin(instance_index) i: u32
) -> VertexOutput {
    let position = 2 * vec2f(centroids[2*i], centroids[2*i + 1]) - 1 + 6*vertex/${RESOLUTION};

    var output: VertexOutput;
    output.position = vec4f(position, 0, 1);
    output.cluster = i;

    return output;
}

Why didn’t we use the same vertex shader since they are so similar? We use different sources to look up the cluster id: the clusters array for points and the instance index for centroids. I thought about solutions to reuse the same vertex shader: either by using separate bind groups or separate fragment shaders, but both alternatives make things uglier.

Next, the fragment shader, which colours everything. The colour of our points is defined by the cluster passed from the vertex shaders. I came up with this function that maps each integer from 0 to K-1 to a unique and hopefully distinct colour.

@fragment
fn fs(@location(0) @interpolate(flat) cluster: u32) -> @location(0) vec4f {
    let normalizedColorId = f32(cluster) / ${K};

    return vec4f(
        fract(1 - pow(normalizedColorId, 2)), 
        fract(normalizedColorId + 0.31), 
        fract(normalizedColorId + 0.82), 
        1
    );
}

Pipelines

Let’s define a pipeline for rendering points and centroids using the shaders defined above. They share the same fragment shader.

const fragmentModule = device.createShaderModule({ code: fragmentShader });
const pointsModule = device.createShaderModule({ code: pointsShader });
const pointsPipeline = device.createRenderPipeline({
    layout: renderPipelineLayout,
    vertex: {
        module: pointsModule,
        buffers: [vertexBufferLayout]
    },
    fragment: {
        module: fragmentModule,
        targets: [{ format: canvasFormat }]
    }
});

const centroidsModule = device.createShaderModule({ code: centroidsShader });
const centroidsPipeline = device.createRenderPipeline({
    layout: renderPipelineLayout,
    vertex: {
        module: centroidsModule,
        buffers: [vertexBufferLayout]
    },
    fragment: {
        module: fragmentModule,
        targets: [{ format: canvasFormat }]
    }
});

Almost there! Now, we create a function for running the pipelines, which is called every 500ms. It begins with the computations (TODO), then clearing the view with black. The points pipeline is ran for pointVertices.length / 2 vertices in N instances: one instance per point. And similarly for the centroids pipeline.

const update = () => {
    const encoder = device.createCommandEncoder();
    // ... compute pass ...
    const pass = encoder.beginRenderPass({
        colorAttachments: [{
            view: context.getCurrentTexture().createView(),
            loadOp: 'clear',
            clearValue: { r: 0, g: 0, b: 0, a: 1.0 },
            storeOp: 'store',
        }]
    });

    pass.setPipeline(pointsPipeline);
    pass.setBindGroup(0, renderBindGroup);
    pass.setVertexBuffer(0, pointVertexBuffer);
    pass.draw(pointVertices.length / 2, N);

    pass.setPipeline(centroidsPipeline);
    pass.setBindGroup(0, renderBindGroup);
    pass.setVertexBuffer(0, centroidVertexBuffer);
    pass.draw(centroidVertices.length / 2, K);

    pass.end();
    device.queue.submit([encoder.finish()]);
}

setInterval(update, 500);

After all this work, all we get is this. The points are all the same colour, because we didn’t compute the assigned clusters yet.

intermediate result

Compute

The computational phase consists of two parts:

  1. assign: assigning clusters to points, outputting clusters
  2. update: updating the centroids positions, outputting centroids

Buffers and Bind Groups

Since storage buffers used in vertex shaders have to be read-only, we create separate read-write storage buffers for the compute phase in a separate bind group. After the computations are done, the results are copied to the read-only buffers and rendered.

const centroidsBufferComp = device.createBuffer({
    size: centroids.byteLength,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC
});
const clustersBufferComp = device.createBuffer({
    size: 4 * N,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC
});

const computeBindGroupLayout = device.createBindGroupLayout({
    entries: [{
        binding: 0,
        visibility: GPUShaderStage.COMPUTE,
        buffer: { type: 'storage' }
    }, {
        binding: 1,
        visibility: GPUShaderStage.COMPUTE,
        buffer: { type: 'storage' }
    }]
});

const computeBindGroup = device.createBindGroup({
    layout: computeBindGroupLayout,
    entries: [
        { binding: 0, resource: { buffer: centroidsBufferComp } },
        { binding: 1, resource: { buffer: clustersBufferComp } }
    ]
});

Since we have to access the data point coordinates in compute shaders (stored in renderBindGroup), both bind groups are passed to the pipeline.

const computePipelineLayout = device.createPipelineLayout({
    bindGroupLayouts: [renderBindGroupLayout, computeBindGroupLayout]
});

Compute Shaders

The assignment phase is run for each point, indexed by the invocation id. We iterate over the centroids and find the closest one, updating the value in clusters. Read only buffers from @group(0) are used where possible.

@group(0) @binding(0) var<storage> points: array<f32>;
@group(0) @binding(1) var<storage> centroids: array<f32>;
@group(1) @binding(1) var<storage, read_write> clusters: array<u32>;

fn dist(a: vec2f, b: vec2f) -> f32 {
    return pow((a.x - b.x), 2) + pow((a.y - b.y), 2);
}

@compute @workgroup_size(${WORKGROUP_SIZE})
fn cs(@builtin(global_invocation_id) id: vec3u) {
    let pointId = id.x;
    let pos = vec2f(points[2*pointId], points[2*pointId + 1]);
    var min_dist = -1.;
    var closest = 0;

    for (var i = 0; i < ${K}; i++) {
        let centroid = vec2f(centroids[2*i], centroids[2*i + 1]);
        let d = dist(pos, centroid);
        if (min_dist == -1 || d < min_dist){
            closest = i;
            min_dist = d;
        }
    }

    clusters[pointId] = u32(closest);
}

The update phase is run for each cluster/centroid. We compute the mean position of points belonging to the current cluster.

@group(0) @binding(0) var<storage> points: array<f32>;
@group(0) @binding(2) var<storage> clusters: array<u32>;
@group(1) @binding(0) var<storage, read_write> centroids: array<f32>;

@compute @workgroup_size(${WORKGROUP_SIZE})
fn cs(@builtin(global_invocation_id) id: vec3u) {
    let centroid = id.x;
    var sum = vec2f(0);
    var num = 0;

    for (var i = 0; i < ${N}; i++) {
        if (clusters[i] == centroid) {
            sum += vec2f(points[2*i], points[2*i + 1]);
            num += 1;
        }
    }

    centroids[2*centroid] = sum.x / f32(num);
    centroids[2*centroid + 1] = sum.y / f32(num);
}

Pipelines

Compute pipelines are created for each step.

const assignModule = device.createShaderModule({ code: assignClusters });
const assignPipeline = device.createComputePipeline({
    layout: computePipelineLayout,
    compute: {
        module: assignModule
    }
});

const updateModule = device.createShaderModule({ code: updateCentroids });
const updatePipeline = device.createComputePipeline({
    layout: computePipelineLayout,
    compute: {
        module: updateModule
    }
});

First, the assignment pipeline is run. The compute shader is run for each point, hence we need Math.ceil(N / WORKGROUP_SIZE) work groups. After this step, the assigned clusters are copied into the read-only buffer. Then, the centroid positions are updated and copied into the corresponding read-only buffer. In both steps we attach both bind groups. After the computations are done, the render pipeline is run.

const encoder = device.createCommandEncoder();

const assignPass = encoder.beginComputePass();
assignPass.setPipeline(assignPipeline);
assignPass.setBindGroup(0, renderBindGroup);
assignPass.setBindGroup(1, computeBindGroup);
assignPass.dispatchWorkgroups(Math.ceil(N / WORKGROUP_SIZE));
assignPass.end();

encoder.copyBufferToBuffer(clustersBufferComp, 0, clustersBuffer, 0, 4 * N);

const updatePass = encoder.beginComputePass();
updatePass.setPipeline(updatePipeline);
updatePass.setBindGroup(0, renderBindGroup);
updatePass.setBindGroup(1, computeBindGroup);
updatePass.dispatchWorkgroups(Math.ceil(K / WORKGROUP_SIZE));
updatePass.end();

encoder.copyBufferToBuffer(centroidsBufferComp, 0, centroidsBuffer, 0, centroids.byteLength);

// ... render pass ...

And… it works! You can see the clusters being updated until convergence. Efficiency of the code wasn’t investigated in this post, but it’s a question worth looking into because that’s the whole point of performing computations on the GPU.

final result