Layer Normalization as fast as possible


LayerNorm (and its close sibling RMSNorm) has superseded batch normalization as the go-to normalization technique for deep learning. Some form of normalization is essential for stabilizing the inputs to each layer, ensuring the model can learn efficiently. LLaMA, Whisper and other recent transformer architectures all use (Layer|RMS)Norm. As a result, if you want to run these big models in the browser on the GPU (like me), you need to implement a fast LayerNorm kernel yourself.

Of all the kernels I wrote, I think LayerNorm is the most insightful for less seasoned GPU programmers. In this post we will explore how we can iteratively arrive at quite a competitive WebGPU LayerNorm kernel, whilst covering some of the cornerstones of GPU programming.

The problem

On the surface, LayerNorm seems like a simple operation; however, it is not naively parallelizable. Let's see how PyTorch defines LayerNorm in its documentation:

$$ y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta $$

$x$ is the input tensor, $\gamma$ and $\beta$ are learnable parameters, and $\epsilon$ is a small constant to avoid division by zero. Looking at this formula, the first thing to note as a GPU programmer is that it requires 2 group statistics: the mean and the variance. This means that we can't immediately parallelize the computation of each output element. First, we need to compute the mean and variance along the dimension we intend to normalize, and only then can we use them to compute each output element. Computing the mean is straightforward, but the variance is a bit more complicated.
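For reference, the formula can be written as a plain-Python sketch for a single row (float64, no GPU involved). It uses the biased (population) variance, which is what PyTorch's LayerNorm uses; layer_norm_row is a hypothetical helper name introduced only for illustration.

```python
import math


def layer_norm_row(x, gamma, beta, eps=1e-5):
    """Normalize one row x (a list of floats), then scale and shift elementwise."""
    n = len(x)
    mean = sum(x) / n                            # E[x]
    var = sum((v - mean) ** 2 for v in x) / n    # biased Var[x], as in PyTorch
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv_std * g + b for v, g, b in zip(x, gamma, beta)]


y = layer_norm_row([1.0, 2.0, 3.0, 4.0], gamma=[1.0] * 4, beta=[0.0] * 4)
print([round(v, 4) for v in y])
```

Note that every output element depends on mean and var, which is exactly why the two group statistics must be known before any element can be written.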

Reductions

Whenever you need to compute some kind of total or group statistic, your first thought should be to use a reduction. Reductions are a core primitive of GPU programming, and it is essential that you build an understanding of them. Fortunately, Mark Harris of NVIDIA has written a timeless resource on the topic, which I highly recommend reading before proceeding. Without reductions, you'll find yourself unable to parallelize many common operations.

However, if you are short on time, here is a quick visual that should convey the core idea of a reduction:

[Figure: Reduction]

Above you can see a summation of 16 elements using 8 threads. In each iteration, every active thread adds 2 elements that are offset by the number of active invocations; that offset halves every iteration until a single value remains.
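The visual can be emulated on the CPU with a short Python sketch (tree_reduce is a hypothetical name). It assumes a power-of-two input length and the scheme described above: in each iteration, thread i adds the element stride positions away, and the stride halves until one value remains.

```python
def tree_reduce(values):
    """Emulate a shared-memory tree reduction over a power-of-two-length list."""
    smem = list(values)            # stand-in for workgroup (shared) memory
    stride = len(smem) // 2        # number of active "threads" this iteration
    while stride > 0:
        # Active threads write slots [0, stride) and read slots [stride, 2*stride),
        # so running them one after another gives the same result as in parallel.
        # On a GPU, a barrier would separate each iteration from the next.
        for i in range(stride):
            smem[i] += smem[i + stride]
        stride //= 2
    return smem[0]


print(tree_reduce(list(range(16))))  # 16 elements, 8 threads in the first step
```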

One pass algorithm

Now that we have a basic understanding of reductions, let's see how we can use them! The most obvious way to compute the mean and variance for our LayerNorm is to use a one-pass algorithm. In the most common case, we will be normalizing the last dimension of a tensor, where each workgroup computes the mean and variance of a single row.

This method uses shared memory (workgroup memory in WebGPU), which all threads in a workgroup can read and write, synchronizing their execution with barriers. Each thread accumulates 2 partial results: the sum and the sum of squares. These are stored in shared memory and reduced to a single pair of values, from which the mean and variance are computed.
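The structure of the one-pass approach can be mimicked on the CPU. This Python sketch is an illustration, not the kernel: it uses float64, a tiny hypothetical BLOCK_SIZE of 4 instead of 128, and a plain sum() in place of the tree reduction. Each simulated thread strides through the row accumulating a sum and a sum of squares, and the variance is then taken as E[x^2] - E[x]^2.

```python
import statistics

BLOCK_SIZE = 4  # hypothetical tiny "workgroup" for illustration


def one_pass_stats(row):
    n = len(row)
    sums = [0.0] * BLOCK_SIZE      # per-thread partial sums
    sq_sums = [0.0] * BLOCK_SIZE   # per-thread partial sums of squares
    # Thread t visits indices t, t + BLOCK_SIZE, t + 2 * BLOCK_SIZE, ...
    for t in range(BLOCK_SIZE):
        for i in range(t, n, BLOCK_SIZE):
            sums[t] += row[i]
            sq_sums[t] += row[i] * row[i]
    # A plain sum stands in for the block reduction of the partials.
    s, q = sum(sums), sum(sq_sums)
    mean = s / n
    var = q / n - mean * mean      # E[x^2] - E[x]^2
    return mean, var


row = [0.5, -1.0, 2.0, 3.5, 0.0, 1.5, -2.5, 4.0, 1.0, 0.25]
mean, var = one_pass_stats(row)
print(mean, var, statistics.pvariance(row))
```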

The kernel looks something like this (Full Code):

const BLOCK_SIZE: u32 = 128u;
//Workgroup memory is zero-initialized, so the += below starts from 0
var<workgroup> sum: array<f32, BLOCK_SIZE>;
var<workgroup> sq_sum: array<f32, BLOCK_SIZE>;
var<workgroup> mean: f32;
var<workgroup> sigma: f32;

//Threads with index < stride fold in the partials stride slots away
fn block_reduce(index: u32, stride: u32) {
    if index < stride {
        sum[index] += sum[index + stride];
        sq_sum[index] += sq_sum[index + stride];
    }
    workgroupBarrier();
}

//Bindings for X and metadata, and the row offset anchor, are elided
@compute @workgroup_size(BLOCK_SIZE)
fn main(@builtin(local_invocation_id) local_id: vec3<u32>) {
    //Each thread computes the partial sum and squared sum
    for (var i = local_id.x; i < metadata.N; i += BLOCK_SIZE) {
        let val = X[anchor + i];
        sum[local_id.x] += val;
        sq_sum[local_id.x] += val * val;
    }
    workgroupBarrier(); //Wait for all threads to finish

    block_reduce(local_id.x, 64u);
    //... manual unroll ...
    block_reduce(local_id.x, 1u);

    if local_id.x == 0u {
        mean = sum[0] / f32(metadata.N);
        //Var[x] = E[x^2] - E[x]^2
        //We use inverseSqrt for performance, so technically this isn't sigma
        sigma = inverseSqrt(sq_sum[0] / f32(metadata.N) - mean * mean + metadata.eps);
    }
    workgroupBarrier();
    //...compute LN...
}

Each of the 128 threads strides through the row accumulating its partials, and a manually unrolled reduction then collapses them into our 2 final values. Let's see how it performs:

One-Pass
time: [124979.2722 ns 126546.2184 ns 128406.2134 ns]
thrpt: [30.4210 GiB/s 30.8682 GiB/s 31.2552 GiB/s]

2 samples not close - AVGE=0.0000016181261 MAE=0.000022888184 at [0, 137, 615]

Whilst the performance is good, we have a problem. The output isn't matching PyTorch! What's going on here?

Unfortunately, we have bumped into a classic floating-point pitfall that is easy to hit in f32 GPU code: catastrophic cancellation. On the line of the code sample above that computes sigma, we do the following (we use inverseSqrt for performance, but it's the same idea):

$$ \sigma = \sqrt{\frac{\sum x^2}{N} - \mu^2 + \epsilon} $$

If the mean of the squares is close to the square of the mean, computing their difference leads to a significant loss of precision. Whenever we subtract 2 good approximations of values of similar magnitude, the leading digits cancel and what remains is dominated by rounding error, so the result can be a very poor approximation of the true difference. For a formal analysis of this phenomenon, I recommend exploring the Wikipedia article on catastrophic cancellation.
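The effect is easy to reproduce. This Python sketch emulates f32 arithmetic by rounding every intermediate result to float32 with the standard struct module. It accumulates sequentially, whereas the kernel sums in a tree across threads, so the exact numbers on a GPU will differ. For four values near 10,000 the true variance is 1.25, yet E[x^2] and E[x]^2 both round to 100030000 in float32, and the sketch returns a variance of exactly 0.0.

```python
import struct


def f32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def one_pass_var_f32(xs):
    """Variance as E[x^2] - E[x]^2, rounding every operation to float32."""
    n = f32(float(len(xs)))
    s = q = 0.0
    for x in xs:
        x = f32(x)
        s = f32(s + x)              # running sum
        q = f32(q + f32(x * x))     # running sum of squares
    mean = f32(s / n)
    # Both terms are ~1e8 here, so their difference loses almost all precision.
    return f32(f32(q / n) - f32(mean * mean))


xs = [10000.0, 10001.0, 10002.0, 10003.0]  # true (population) variance: 1.25
print(one_pass_var_f32(xs))
```

A variance of 0 would then be fed to inverseSqrt(0 + eps), which with eps = 1e-5 gives roughly 316 instead of the correct 1/sqrt(1.25), about 0.894.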

Whilst catastrophic cancellation can be mitigated by using higher-precision types or something like Kahan summation, both have performance implications. Is there a way we can avoid this problem without sacrificing performance?

Two-pass algorithm

Next, let's try a two-pass algorithm to avoid the troublesome subtraction in the one-pass algorithm. By first computing the mean, and subtracting it from each element as we compute the variance, we can avoid catastrophic cancellation. The kernel looks something like this (Full Code):

//block_reduce here is a single-array variant that reduces smem
var<workgroup> smem: array<f32, BLOCK_SIZE>;

fn mu(local_id: vec3<u32>, anchor: u32) -> f32 {
    var threadSum = 0f;
    for (var i: u32 = local_id.x; i < metadata.N; i += BLOCK_SIZE) {
        threadSum += X[anchor + i];
    }
    smem[local_id.x] = threadSum;
    workgroupBarrier();

    block_reduce(local_id.x, 64u);
    //... manual unroll ...
    block_reduce(local_id.x, 1u);

    let mean = smem[0] / f32(metadata.N);
    //Every thread must read smem[0] before sigma() overwrites smem
    workgroupBarrier();
    return mean;
}

//Despite its name, this returns Var[x], not the standard deviation
fn sigma(local_id: vec3<u32>, anchor: u32, mu: f32) -> f32 {
    var threadSum = 0f;
    for (var i: u32 = local_id.x; i < metadata.N; i += BLOCK_SIZE) {
        let val = X[anchor + i] - mu; //Subtract the mean before squaring
        threadSum += (val * val);
    }
    smem[local_id.x] = threadSum;
    workgroupBarrier();

    block_reduce(local_id.x, 64u);
    //... manual unroll ...
    block_reduce(local_id.x, 1u);

    return smem[0] / f32(metadata.N);
}

This is very similar to the one-pass algorithm, but will obviously be much slower as we need to perform two passes over the data. Let's see how it performs:
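The same float32 emulation applied to the two-pass approach recovers the exact answer on the data that broke the one-pass formula. This is again a sequential Python sketch rather than the kernel, with the rounding helper repeated so the block stands alone. Because the mean is subtracted from each element before squaring, only small deviations are squared and summed, and nothing large is left to cancel.

```python
import struct


def f32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def two_pass_var_f32(xs):
    """Pass 1 computes the mean; pass 2 sums squared deviations from it."""
    n = f32(float(len(xs)))
    s = 0.0
    for x in xs:                     # pass 1: the mean
        s = f32(s + f32(x))
    mean = f32(s / n)
    acc = 0.0
    for x in xs:                     # pass 2: sum of squared deviations
        d = f32(f32(x) - mean)       # small values, e.g. -1.5, -0.5, 0.5, 1.5
        acc = f32(acc + f32(d * d))
    return f32(acc / n)


xs = [10000.0, 10001.0, 10002.0, 10003.0]
print(two_pass_var_f32(xs))
```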

Two-Pass
time: [158307.3426 ns 159382.1927 ns 160646.1034 ns]
thrpt: [24.3159 GiB/s 24.5087 GiB/s 24.6751 GiB/s]

All close - AVGE=0.0000005285047 MAE=0.0000076293945 at [0, 18, 370]

Great, we match PyTorch's output! But we've taken a significant performance hit. Is there a way we can get the best of both worlds?

Welford's algorithm

Fortunately, there is a well-known one-pass algorithm that gives us the best of both worlds: high performance and numerical stability. This algorithm is known as Welford's algorithm. It is an online algorithm, updating its statistics one element at a time, which allows us to build up the complete solution from partial solutions. To do this, we need to form a recurrence relation for each component.

Let's look at the relation for the mean:

$$ \mu_{k} = \mu_{k-1} + \frac{x_k - \mu_{k-1}}{k} $$

We can update our mean by adding the difference between the current value and the previous mean, divided by the new number of elements (our new value's "contribution" to the mean).
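A quick numerical check of the mean recurrence, as a minimal Python sketch (running_mean is a hypothetical name): updating one element at a time reproduces the ordinary average without ever forming a large running sum.

```python
def running_mean(xs):
    """mu_k = mu_{k-1} + (x_k - mu_{k-1}) / k, starting from mu_0 = 0."""
    mu = 0.0
    for k, x in enumerate(xs, start=1):
        mu += (x - mu) / k
    return mu


xs = [2.0, 4.0, 9.0, 1.0]
print(running_mean(xs), sum(xs) / len(xs))  # intermediate means: 2, 3, 5, 4
```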

Let's do the same for variance:

$$ \sigma_{k}^2 = \sigma_{k-1}^2 + \frac{(x_k - \mu_{k-1})(x_k - \mu_{k}) - \sigma_{k-1}^2}{k} $$

This is a bit more complicated, but it's the same idea: we can update our variance using only the previous variance, the previous mean and the new value (from which the new mean $\mu_k$ follows). In practice, the above formula for variance is not used; the algorithm originally stated by Welford instead updates the sum of squares of differences from the current mean. This quantity is known as $M_{2,k}$, and is defined as:

$$ \begin{aligned} M_{2,k} &= M_{2,k-1} + (x_k - \mu_{k-1})(x_k - \mu_{k}) \\ \sigma_{k}^2 &= \frac{M_{2,k}}{k} \end{aligned} $$
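Putting the mean and $M_{2,k}$ recurrences together gives the per-element update. This Python sketch applies it to the same four values near 10,000 that broke the one-pass formula, again emulating float32 by rounding every step with struct; it follows the same update as the welford_combine function shown later in this post. Every subtraction involves values close to the running mean, so the sketch returns the exact variance of 1.25.

```python
import struct


def f32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def welford_var_f32(xs):
    """Welford's update of (count, mean, M2), rounding every step to float32."""
    count = mean = m2 = 0.0
    for x in xs:
        count = f32(count + 1.0)
        delta1 = f32(x - mean)                   # x_k - mu_{k-1}
        mean = f32(mean + f32(delta1 / count))   # mu_k
        delta2 = f32(x - mean)                   # x_k - mu_k
        m2 = f32(m2 + f32(delta1 * delta2))      # M_{2,k}
    return mean, f32(m2 / count)                 # sigma_k^2 = M_{2,k} / k


print(welford_var_f32([10000.0, 10001.0, 10002.0, 10003.0]))
```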

With these recurrence relations, each thread can compute a "partial mean" and "partial variance" (via $M_2$) over its share of the row, and a reduction can then combine them into the full solution.
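Combining partial results requires merging two (count, mean, M2) triples whose counts can both exceed 1. The merge step is not spelled out here, so this Python sketch assumes the standard pairwise formula commonly attributed to Chan et al., which adds a correction term for the gap between the two partial means. welford_merge and partial are hypothetical names, and the empty-input guard is an assumption for threads that received no elements.

```python
import statistics


def welford_merge(a, b):
    """Merge two partial results (count, mean, M2) using the pairwise formula."""
    n_a, mean_a, m2_a = a
    n_b, mean_b, m2_b = b
    n = n_a + n_b
    if n == 0:
        return (0, 0.0, 0.0)             # both partials empty
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return (n, mean, m2)


def partial(xs):
    """(count, mean, M2) of one chunk, standing in for one thread's result."""
    mean = sum(xs) / len(xs)
    return (len(xs), mean, sum((x - mean) ** 2 for x in xs))


data = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]
n, mean, m2 = welford_merge(partial(data[:3]), partial(data[3:]))
print(mean, m2 / n, statistics.pvariance(data))
```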

Implementing Welford's Algorithm

To achieve peak performance with Welford's algorithm, we are going to make use of the bleeding-edge WebGPU Subgroups Proposal.

For those unfamiliar with the term, a subgroup is a generalization of the intrinsic thread grouping of the underlying GPU hardware. In CUDA it's known as a warp, in DirectX a wave, and in Metal a simdgroup. The threads in a subgroup execute the same instruction in parallel and can exchange data directly between their registers. This avoids the round trip through shared memory for communication within the subgroup, and can lead to significant speedups. Designing your algorithms with this thread grouping in mind is essential for achieving peak performance.

Let's take a look at the core components of Welford's algorithm:

fn welford_combine(val: f32, mean: ptr<function, f32>, m2: ptr<function, f32>, count: ptr<function, f32>) {
    *count += 1.0;
    let delta1 = val - *mean;
    *mean += delta1 / *count;
    let delta2 = val - *mean;
    *m2 += delta1 * delta2;
}

This is the heart of the algorithm, which updates the running mean and $M_2$ defined above. Once each thread has computed its partial mean and $M_2$, we can combine them together to form the full solution. This is where the subgroup extension comes in:

//subgrp_size is the subgroup size (e.g. from @builtin(subgroup_size))
fn welford_warp_reduce(thread_mean: f32, thread_m2: f32, thread_count: f32, mean: ptr<function, f32>, m2: ptr<function, f32>, count: ptr<function, f32>) {
    *mean = thread_mean;
    *m2 = thread_m2;
    *count = thread_count;
    //Subgroup reduce into first lane
    for (var offset = subgrp_size >> 1u; offset > 0u; offset >>= 1u) {
        //Fetch the partial result held by lane + offset
        let b_mean = subgroupShuffleDown(*mean, offset);
        let b_m2 = subgroupShuffleDown(*m2, offset);
        let b_count = subgroupShuffleDown(*count, offset);
        //Merge it into ours (a welford_combine variant for counts > 1)
        block_welford_combine(b_mean, b_m2, b_count, mean, m2, count);
    }

    //Broadcast result to all threads
    *mean = subgroupBroadcast(*mean, 0u);
    *m2 = subgroupBroadcast(*m2, 0u);
    *count = subgroupBroadcast(*count, 0u);
}

This uses the new subgroupShuffleDown function, which passes the value held by the active invocation at lane + offset to the invocation at lane. We then merge that partial result into our own using block_welford_combine, a version of welford_combine that handles merging two partial results whose counts can both exceed 1. Once the subgroup has reduced the value into the first lane, we broadcast it to all threads using subgroupBroadcast.
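The subgroup reduction can also be emulated on the CPU. In this Python sketch, a list of eight per-lane partials stands in for a subgroup (real subgroup sizes are hardware-dependent), shuffle_down mimics subgroupShuffleDown by handing lane i the value from lane i + offset, and merge assumes the same Chan et al. pairwise formula that block_welford_combine is presumed to use. All function names are hypothetical. Lanes whose source is out of range receive None here; in WGSL that result is indeterminate, but the value that ends up in lane 0 never depends on it.

```python
def shuffle_down(values, offset):
    """Emulate subgroupShuffleDown: lane i receives the value held by lane i + offset."""
    n = len(values)
    return [values[i + offset] if i + offset < n else None for i in range(n)]


def merge(a, b):
    """Pairwise (count, mean, M2) merge; a missing or empty partner is a no-op."""
    if b is None or b[0] == 0:
        return a
    if a[0] == 0:
        return b
    n = a[0] + b[0]
    delta = b[1] - a[1]
    return (n, a[1] + delta * b[0] / n, a[2] + b[2] + delta * delta * a[0] * b[0] / n)


def subgroup_welford_reduce(lanes):
    """Tree-reduce per-lane partials into lane 0, then broadcast to every lane."""
    lanes = list(lanes)
    offset = len(lanes) // 2
    while offset > 0:
        received = shuffle_down(lanes, offset)   # all lanes exchange "at once"
        lanes = [merge(a, b) for a, b in zip(lanes, received)]
        offset //= 2
    return [lanes[0]] * len(lanes)               # emulates subgroupBroadcast(x, 0u)


# Eight lanes, each holding the partial result for a single value.
values = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]
result = subgroup_welford_reduce([(1, v, 0.0) for v in values])
count, mean, m2 = result[0]
print(count, mean, m2 / count)
```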

Welford Vectorized
time: [111251.7095 ns 113935.1043 ns 117158.9148 ns]
thrpt: [33.3415 GiB/s 34.2849 GiB/s 35.1118 GiB/s]

All close - AVGE=0.0000014362298 MAE=0.000017166138 at [0, 1688, 413]

As we can see, our performance is now slightly better than the one-pass algorithm (about 34.3 vs 30.9 GiB/s), and we no longer suffer from numerical instability!

Conclusion

We've seen how we can arrive at a fast, numerically stable LayerNorm kernel by using the proposed subgroup extension, which will be a welcome addition to the WebGPU ecosystem. There is still more to explore here; for more information, I recommend reading this excellent post by OneFlow, on which this post is based.

Check out the benchmarking repo for more, and thanks to sekstini for their insightful comments!


Christopher Fleetwood - 2024