Normalization-free transformers are subcritical, Part 2.
Overview
In my previous blog post, I demonstrated empirically that normalization-free (DyT/Derf) transformers [1, 2] have worse gradient propagation properties than the standard pre-LN transformer – namely, they exhibit much stronger gradient amplification (approximately stretched-exponential growth, as opposed to the power-law growth of the pre-LN baseline). Although the theoretical analysis correctly characterized the gap between the models at a qualitative level, it did not account for the attention mechanism.
In this blog post, I modify the theoretical argument by reintroducing attention, using the theoretical framework developed in [3], which restricts the initial token configurations to permutation-invariant ones. We generalize the analysis in [3] to normalization-free transformers by replacing the LayerNorms with pointwise activation functions. We show that attention does not change the mechanism that makes gradient propagation in normalization-free transformers inferior to that in pre-LN transformers. However, we can now demonstrate not only qualitative agreement between theoretical and empirical activation norms, gradients, and Jacobians, but also perfect quantitative agreement.
How to read this post
This post contains the following sections:
1. Mean-field framework
- 1.1 Introduction
- 1.2 Setup
- 1.3 Forward signal propagation
- 1.4 Backward gradient propagation
2. LayerNorm vs. Derf/DyT
- 2.1 Theory
- 2.2 Experiments
3. References
Readers interested primarily in why normalization-free Transformers have worse gradient propagation may want to proceed to Section 2.1, which relies on results from Sections 1.3 and 1.4.
Section 1.1 overviews mean-field theory at initialization, which tracks the dynamics of a pair of activation vectors. Section 1.2 introduces the notation for the Transformer. Sections 1.3 and 1.4 extend the mean-field recursion relations of [3] to normalization-free transformers. Section 2.1 uses these relations to compare gradient propagation for LayerNorm versus DyT/Derf-like pointwise normalizations, and Section 2.2 validates the theory against measurements in a ViT.
Mean-field framework
Introduction
For a general introduction to the theory of signal propagation and the mean-field formalism in the large-width limit at initialization, I refer the reader to my previous blog post.
[3] observed that, for permutation-equivariant transformers (i.e., with bidirectional attention and no positional encoding), the mean-field theory at initialization effectively reduces to the layer-to-layer evolution of just two degrees of freedom, provided the initial token configuration is permutation-invariant: the component-wise variance of the activation vector at a given position, and the covariance between components of activation vectors at different positions. Geometrically, the former is the squared norm of the activation vector at a given position, normalized by the hidden dimension, while the latter is the normalized dot product between activation vectors at different positions; consequently, their ratio is the cosine similarity between activation vectors at different positions.
The two main ingredients in the signal-propagation calculation at initialization are (co)variance propagation through pointwise nonlinearities and through linear layers. For the former, let us illustrate the calculation using a pointwise nonlinearity applied to two activation vectors at different positions. The component-wise covariance of the two vectors is specified as follows: components with the same index share a common variance and covariance, and components with different indices are uncorrelated. We then compute the (non-centered) covariance after applying the nonlinearity as a two-dimensional Gaussian expectation of the product of the transformed components. In this expectation, the integration variables are scalar dummy variables rather than multi-component activation vectors – we have simply suppressed the component index. Even if the original activations have zero mean, passing through some nonlinearities (e.g., ReLU) can introduce a non-zero mean; however, it is eliminated by the subsequent linear transformation.
A subsequent linear transformation with zero-mean weights (assumed Gaussian for simplicity) multiplies this covariance by the total weight variance per output component, i.e., the component-wise weight variance times the input dimension. Thus, combining the nonlinearity with the linear transformation rescales the post-nonlinearity Gaussian expectation by this factor.
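As a sanity check of this rule, here is a minimal Monte Carlo sketch. The names q, c, and sigma2 (the total weight variance per output component, i.e., the component-wise variance times the fan-in) are illustrative choices rather than the notation of this post.

```python
# Monte Carlo sketch of (co)variance propagation through a pointwise
# nonlinearity followed by a zero-mean Gaussian linear layer.
# q = per-component variance, c = per-component covariance between positions,
# sigma2 = total weight variance per output component (illustrative names).
import numpy as np

rng = np.random.default_rng(0)

def relu(t):
    return np.maximum(t, 0.0)

def propagate(q, c, phi, sigma2, n_samples=1_000_000):
    """Return sigma2 * E[phi(x) phi(y)] for jointly Gaussian (x, y)
    with Var[x] = Var[y] = q and Cov[x, y] = c."""
    cov = np.array([[q, c], [c, q]])
    x, y = rng.multivariate_normal([0.0, 0.0], cov, size=n_samples).T
    return sigma2 * np.mean(phi(x) * phi(y))

print(propagate(q=1.0, c=0.5, phi=relu, sigma2=2.0))  # covariance between positions
print(propagate(q=1.0, c=1.0, phi=relu, sigma2=2.0))  # variance (the c -> q limit)
```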
Setup
Assume a Transformer with a given context size and number of layers, alternating (bidirectional) self-attention layers and position-wise MLP layers with ReLU activation, each wrapped in a residual connection. The input to each residual branch is normalized – either with LayerNorm or with a pointwise transform such as DyT/Derf. For simplicity, we assume single-head attention; in the case of multi-head attention, the signal-propagation equations remain exactly the same. The dynamics of the activation vectors are given by the following equations:
Note that the layer index here counts attention and MLP layers separately, rather than Transformer blocks. The normalized activation vectors entering the residual branches are:
Here the normalization may be LayerNorm or a pointwise transform; for example, Derf applies an erf nonlinearity whose argument is scaled by a parameter. The attention scores between each query and each key are computed in the standard way:
All weights are initialized from zero-mean Gaussian distributions, with variances that are shared across Transformer blocks: in the attention layer, the query, key, value, and output projections each have their own component-wise variance; in the MLP layer, the two linear transformations likewise have their own component-wise variances. In each case, the component-wise variance is normalized by the input dimension of the layer.
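For concreteness, here is a minimal PyTorch sketch of one attention layer followed by one MLP layer of such a normalization-free Transformer (single head, bidirectional attention, no positional encoding). The class and argument names are illustrative; for brevity a single weight variance is shared across the attention projections and another across the MLP projections, and the 1/fan-in scaling is my reading of the initialization described above.

```python
# Minimal sketch: one attention layer + one MLP layer, each with a residual
# connection whose branch input is passed through a pointwise normalization.
import math
import torch
import torch.nn as nn

class NormFreeBlock(nn.Module):
    def __init__(self, dim, norm_fn, sigma2_attn=1.0, sigma2_mlp=1.0, mlp_ratio=4):
        super().__init__()
        self.norm_fn = norm_fn  # e.g. lambda x: torch.erf(alpha * x) for a Derf-like transform
        self.q = nn.Linear(dim, dim, bias=False)
        self.k = nn.Linear(dim, dim, bias=False)
        self.v = nn.Linear(dim, dim, bias=False)
        self.o = nn.Linear(dim, dim, bias=False)
        self.fc1 = nn.Linear(dim, mlp_ratio * dim, bias=False)
        self.fc2 = nn.Linear(mlp_ratio * dim, dim, bias=False)
        for lin, s2 in [(self.q, sigma2_attn), (self.k, sigma2_attn),
                        (self.v, sigma2_attn), (self.o, sigma2_attn),
                        (self.fc1, sigma2_mlp), (self.fc2, sigma2_mlp)]:
            # zero-mean Gaussian init, component-wise variance s2 / fan_in
            nn.init.normal_(lin.weight, std=math.sqrt(s2 / lin.in_features))

    def forward(self, x):                      # x: (tokens, dim), no positional encoding
        h = self.norm_fn(x)                    # normalize the attention-branch input
        scores = (self.q(h) @ self.k(h).T) / math.sqrt(h.shape[-1])
        x = x + self.o(torch.softmax(scores, dim=-1) @ self.v(h))  # attention layer
        h = self.norm_fn(x)                    # normalize the MLP-branch input
        x = x + self.fc2(torch.relu(self.fc1(h)))                  # MLP layer
        return x

block = NormFreeBlock(dim=64, norm_fn=lambda x: torch.erf(x))
out = block(torch.randn(16, 64))               # 16 tokens, hidden dimension 64
```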
Forward signal propagation
With a number of simplifying assumptions about the statistics of attention scores (see Assumption 2 in [3]), one can solve for the dynamics of the variance and the covariance (Eqs. (6) and (7)):
For brevity, these equations are written in terms of shorthand combinations of the weight variances introduced above. We recall that the context size is the number of tokens/patches. Finally, the remaining quantities appearing in Eqs. (6) and (7) are the variance and covariance after propagation through the normalization, computed as Gaussian expectations in the same way as in Section 1.1.
The coefficients multiplying these quantities in the MLP terms of Eqs. (6) and (7) arise from covariance propagation through the ReLU nonlinearity; the angle entering those coefficients is determined by the post-normalization cosine similarity between positions.
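For reference, covariance propagation through ReLU for jointly Gaussian inputs has a standard closed form (the degree-one arc-cosine kernel). The sketch below, with illustrative names q, c, and theta, checks that closed form against a Monte Carlo estimate; it is a sanity check, not the post's exact Eq. (6)/(7) notation.

```python
# ReLU covariance map for jointly Gaussian inputs with variance q and
# covariance c: E[relu(x) relu(y)] = (q / (2*pi)) * (sin(theta) + (pi - theta) * cos(theta)),
# where cos(theta) = c / q. Checked against Monte Carlo below.
import numpy as np

def relu_cov_closed_form(q, c):
    theta = np.arccos(np.clip(c / q, -1.0, 1.0))
    return (q / (2 * np.pi)) * (np.sin(theta) + (np.pi - theta) * np.cos(theta))

def relu_cov_monte_carlo(q, c, n=2_000_000, seed=0):
    rng = np.random.default_rng(seed)
    x, y = rng.multivariate_normal([0.0, 0.0], [[q, c], [c, q]], size=n).T
    return np.mean(np.maximum(x, 0.0) * np.maximum(y, 0.0))

print(relu_cov_closed_form(2.0, 1.0), relu_cov_monte_carlo(2.0, 1.0))
print(relu_cov_closed_form(2.0, 2.0))  # c -> q limit reproduces E[relu(x)^2] = q / 2
```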
Backward gradient propagation
To characterize backward gradient propagation from an earlier layer to a later layer, we use the Frobenius norm of the Jacobian of the later layer's activations with respect to the earlier layer's activations, averaged over weight initializations – the APJN (averaged partial Jacobian norm) [4]:
In the large-width limit, the APJN satisfies the recursion relation [4]
Equivalently, the APJN between two layers factorizes into a product of single-layer APJNs. In our setup, the single-layer factors take the following form:
Here the new quantity is the variance obtained by propagating through the derivative of the normalization function:
This expression is somewhat ambiguous for LayerNorm, so in that case we simply define the corresponding quantity explicitly to avoid confusion. In both cases – the pointwise transform and LayerNorm – it arises from differentiating the normalization.
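To illustrate how such derivative-propagated quantities can be evaluated numerically, here is a small Monte Carlo sketch. It assumes, for illustration only, that Derf acts as erf of the input scaled by a parameter alpha; the exact parameterization used in this post may differ.

```python
# Monte Carlo sketch of the (co)variance obtained by propagating through the
# derivative of a pointwise normalization. The Derf form erf(alpha * x),
# with derivative (2*alpha/sqrt(pi)) * exp(-(alpha*x)**2), is assumed here.
import numpy as np

def deriv_cov(q, c, alpha=1.0, n=2_000_000, seed=0):
    rng = np.random.default_rng(seed)
    x, y = rng.multivariate_normal([0.0, 0.0], [[q, c], [c, q]], size=n).T
    dphi = lambda t: (2 * alpha / np.sqrt(np.pi)) * np.exp(-(alpha * t) ** 2)
    return np.mean(dphi(x) * dphi(y))

print(deriv_cov(q=1.0, c=1.0))  # "variance" case (same position)
print(deriv_cov(q=1.0, c=0.3))  # covariance case (different positions)
```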
In fact, for LayerNorm and Derf, all of these normalization-related quantities can be computed analytically. We provide the expressions here for completeness.
LayerNorm:
Derf (as a function of its parameter):
LayerNorm vs. Derf/DyT
Theory
We now have all the components (Eqs. (6), (7), and (12)) to show that, for DyT/Derf-like normalization functions, the APJN grows approximately as a stretched exponential (an exponential of a fractional power of the depth), whereas in the standard pre-LN setup it grows approximately as a power law. The general argument is identical to that in my previous blog post, so we omit the details here. The key idea is that in both cases the activation variance grows linearly with depth, as follows from Eq. (6); however, in Eq. (12), the per-layer amplification factor approaches one much more slowly with depth for DyT/Derf-like normalization functions (roughly as an inverse square root of the depth) than for LayerNorm (roughly as the inverse of the depth). This implies the stated behavior of the APJN, which is given by a product of these factors.
This conclusion remains valid in the presence of attention. In the forward pass, the linear growth of the activation variance persists because the attention contribution is bounded. In the backward pass, since the covariance cannot exceed the variance, the attention denominator in Eq. (12) can suppress the per-layer factor only by a bounded amount. And even if it could do more, the MLP contribution remains the same as without attention, producing stretched-exponential growth for DyT/Derf-like normalization functions and power-law growth for LayerNorm.
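A small numerical sketch of this product-of-factors mechanism: per-layer factors that approach one like a/l give power-law growth of the product, while factors that approach one like a/sqrt(l) give stretched-exponential growth. The specific factors below are illustrative stand-ins, not the exact expressions from Eq. (12).

```python
# Product-of-factors illustration for the APJN growth argument.
import numpy as np

a = 1.0
depths = np.arange(1, 5001)
log_prod_powerlaw = np.cumsum(np.log1p(a / depths))            # sum a/l      ~ a * log(L)
log_prod_stretched = np.cumsum(np.log1p(a / np.sqrt(depths)))  # sum a/sqrt(l) ~ 2 * a * sqrt(L)

print(log_prod_powerlaw[-1] / np.log(depths[-1]))    # approaches a:  product ~ L**a (power law)
print(log_prod_stretched[-1] / np.sqrt(depths[-1]))  # approaches 2a: product ~ exp(2a*sqrt(L))
```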
Experiments
We compute the layer-wise variance and covariance, as well as the APJN, both from the mean-field analysis and by estimating them from the ViT model, for the pre-LN baseline and for Derf with various values of its parameter.
Fig. 1 compares the variance and covariance computed from the mean-field analysis (left) with those estimated from a ViT forward pass (right), where the input to the first Transformer block is a generated permutation-invariant token configuration. Both use the same initial (co)variance values, number of layers, and context size. The ViT model is initialized as in vit_large_patch16_224, with the corresponding hidden dimension and component-wise weight standard deviations; the mean-field weight variances are set to match.
Fig. 2 (left) compares the APJN computed from the mean-field analysis with that estimated from the ViT model via Hutchinson’s method [4]. Fig. 2 (right) shows gradient amplification coefficients estimated from the ViT backward pass on a batch of permutation-invariant token configurations. The observed gradient amplification is slightly larger than the APJN, likely because the gradients lie in a subspace corresponding to larger-than-average Jacobian eigenvalues.
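For readers who want to reproduce estimates like those in the left panel, here is a minimal sketch of a Hutchinson-style estimate of the squared Frobenius norm of a partial Jacobian using vector-Jacobian products from autograd. The `blocks` list, layer indices, and the width normalization are placeholders, and the normalization convention may differ from that of [4].

```python
# Hutchinson-style estimate of ||d h_{l1} / d h_{l0}||_F^2 via autograd.
import torch

def apjn_estimate(blocks, x0, l0, l1, n_probes=16):
    """Estimate the squared Frobenius norm of the partial Jacobian between
    layers l0 and l1 (normalized by the layer-l0 size, as a placeholder)."""
    h = x0
    for blk in blocks[:l0]:
        h = blk(h)
    h_l0 = h.detach().requires_grad_(True)   # treat layer-l0 activations as the input
    h = h_l0
    for blk in blocks[l0:l1]:
        h = blk(h)
    total = 0.0
    for _ in range(n_probes):
        v = torch.randn_like(h)              # E[||J^T v||^2] = ||J||_F^2 for Gaussian probes
        (g,) = torch.autograd.grad(h, h_l0, grad_outputs=v, retain_graph=True)
        total += g.pow(2).sum().item()
    return total / n_probes / h_l0.numel()
```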
Overall, matching the pre-LN gradient amplification behavior requires choosing a smaller value of the Derf parameter. However, smaller parameter values also yield smaller updates to the residual stream. If instead we try to align pre-LN and Derf by matching the magnitude of the residual-stream update, then the gradient amplification in Derf becomes much larger. Concretely, choosing the parameter so that the corresponding curves for pre-LN and Derf are as close as possible leads to large gradient amplification.
Figure 1. (a) The component-wise variance of the activation vector at a given position, (b) the covariance between components of activation vectors at different positions, and (c) their ratio, the cosine similarity. Left: mean-field analysis; right: values estimated from a ViT forward pass on a batch of permutation-invariant token configurations with the same initial (co)variance. The black solid line indicates the pre-LN baseline. Colored lines show Derf variants with varying values of the parameter.
Figure 2. Left: APJN. Solid lines indicate values computed from mean-field theory (MFT). Crosses indicate values obtained from a ViT via Hutchinson’s method [4]. Right: gradient amplification coefficients estimated from a ViT backward pass on a batch of permutation-invariant token configurations.