Self Attention in Convolutional Neural Networks

Ramin
4 min read · Mar 9, 2021


I recently added self-attention to a network that I trained to detect walls, and it improved the Dice score for wall segmentation. I am writing this short article to summarize self-attention in CNNs. I write these notes primarily so that I can come back to them and recall what I did, but I hope you find them useful too.

Why Self-Attention

Self-attention for CNNs is described in the Self-Attention GAN (SAGAN) paper by Zhang et al. It lets every output position draw on features from all spatial positions. This increases the receptive field of the CNN without the cost of very large kernels.

How Does It Work

  • Reshape the features from the previous hidden layer such that:

$$x \in \mathbb{R}^{C \times N}$$

where C is the number of channels and N is the product of all the other dimensions (for example, N = H × W for a 2D feature map; we will see the code later).

  • Perform 1x1 convolutions on x to obtain f, g, and h. This changes the number of channels from C to C*:

$$f(x) = W_f\,x, \qquad g(x) = W_g\,x, \qquad h(x) = W_h\,x, \qquad W_f, W_g, W_h \in \mathbb{R}^{C^* \times C}$$

  • Compute a softmax-normalized weight for every pair of pixel positions in f(x) and g(x):

$$\beta_{j,i} = \frac{\exp(s_{ij})}{\sum_{i=1}^{N} \exp(s_{ij})}, \qquad s_{ij} = f(x_i)^\top g(x_j)$$

These weights are called the "attention map". The weight $\beta_{j,i}$ quantifies how much pixel i contributes to the output at pixel j. Since these weights are computed over the entire height and width of the feature map, the receptive field is no longer limited to the size of a small kernel.

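  • Compute the output of the self-attention layer as:

$$o_j = v\left(\sum_{i=1}^{N} \beta_{j,i}\, h(x_i)\right), \qquad v(x_i) = W_v\,x_i, \qquad W_v \in \mathbb{R}^{C \times C^*}$$

Here, v is yet another 1x1 convolution, which maps C* channels back to C. As a result, the output $o \in \mathbb{R}^{C \times N}$ has the same number of channels as the input features to the self-attention layer.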

Here is a figure from the paper that visualizes these operations:

[Figure: the self-attention module, from the SAGAN paper]

Typically, C* = C/8.

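  • As the last step, add the input features x to a scaled version of the output. Here gamma is a learnable scalar parameter:

$$y_i = \gamma\, o_i + x_i$$

In the paper, gamma is initialized to 0. The network therefore starts out relying on local convolutional features and gradually learns to assign weight to non-local evidence.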
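The steps above can be sketched in NumPy for a single feature map. This is a minimal sketch under stated assumptions, not the paper's code. The 1x1 convolutions are written as matrix products with biases omitted, and the weights are random rather than learned. The names `self_attention`, `W_f`, `W_g`, `W_h`, and `W_v` are hypothetical.

```python
import numpy as np

def softmax(a, axis):
    # Numerically stable softmax along the given axis.
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(features, W_f, W_g, W_h, W_v, gamma):
    """SAGAN-style self-attention for one feature map of shape (C, H, W).

    A 1x1 convolution without bias is a matrix multiply over channels applied
    at every position, so each W_* acts on x of shape (C, N) directly.
    """
    C, H, W = features.shape
    x = features.reshape(C, H * W)     # x in R^{C x N} with N = H * W
    f = W_f @ x                        # (C*, N)
    g = W_g @ x                        # (C*, N)
    h = W_h @ x                        # (C*, N)
    s = f.T @ g                        # s[i, j] = f(x_i)^T g(x_j), shape (N, N)
    beta = softmax(s, axis=0)          # beta[i, j] = beta_{j,i}, normalized over i
    o = W_v @ (h @ beta)               # column j is o_j = v(sum_i beta_{j,i} h(x_i)), shape (C, N)
    y = gamma * o + x                  # y_i = gamma * o_i + x_i
    return y.reshape(C, H, W), beta

# Usage with random weights (hypothetical sizes).
rng = np.random.default_rng(0)
C, H, W = 16, 4, 5
C_star = C // 8                        # the common choice C* = C/8
W_f = rng.standard_normal((C_star, C))
W_g = rng.standard_normal((C_star, C))
W_h = rng.standard_normal((C_star, C))
W_v = rng.standard_normal((C, C_star))
features = rng.standard_normal((C, H, W))

y, beta = self_attention(features, W_f, W_g, W_h, W_v, gamma=0.5)
print(y.shape)                              # (16, 4, 5)
print(np.allclose(beta.sum(axis=0), 1.0))   # True: each output position's weights sum to 1
```

With gamma = 0, as at initialization in the paper, the layer returns its input unchanged.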

PyTorch Implementation

The following short, efficient implementation is from fast.ai:

Line 4: define three 1x1 conv layers to create f(x), g(x), and h(x). These are typically called query, key, and value (see line 14).

Line 13: reshape x to a tensor of size C × N (plus the batch dimension).

Line 15: compute the softmax attention weights as defined above ("bmm" is batch matrix multiply in torch).

Line 17: restore the original shape of the features.
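The batched forward pass described by those line notes can be mirrored in NumPy. In this sketch, f and g map C to C/8 channels, and a single combined value projection maps C to C. `np.matmul` plays the role of `torch.bmm`. The softmax normalizes over i, which is axis 1 of the (B, N, N) score tensor. This is a hypothetical translation, not the fast.ai source. The 1x1 conv layers are plain weight matrices with no bias or normalization, and the function name is made up for illustration.

```python
import numpy as np

def softmax(a, axis):
    # Numerically stable softmax along the given axis.
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention_batched(x, W_query, W_key, W_value, gamma):
    """x: (B, C, H, W); W_query, W_key: (C // 8, C); W_value: (C, C)."""
    size = x.shape
    x = x.reshape(size[0], size[1], -1)            # (B, C, N): flatten the spatial dims
    f = W_query @ x                                 # query f(x), (B, C*, N); matmul broadcasts over B
    g = W_key @ x                                   # key g(x),   (B, C*, N)
    h = W_value @ x                                 # combined value h(x), (B, C, N)
    scores = np.matmul(f.transpose(0, 2, 1), g)     # (B, N, N), scores[b, i, j] = f(x_i)^T g(x_j)
    beta = softmax(scores, axis=1)                  # normalize over i for each output position j
    o = gamma * np.matmul(h, beta) + x              # gamma-weighted attention output plus residual
    return o.reshape(size)                          # restore (B, C, H, W)

rng = np.random.default_rng(1)
B, C, H, W = 2, 32, 6, 6
x = rng.standard_normal((B, C, H, W))
W_query = 0.1 * rng.standard_normal((C // 8, C))
W_key = 0.1 * rng.standard_normal((C // 8, C))
W_value = 0.1 * rng.standard_normal((C, C))

out = self_attention_batched(x, W_query, W_key, W_value, gamma=0.1)
print(out.shape)   # (2, 32, 6, 6)
```

Because W_value maps C channels directly to C, no separate v projection is applied after the weighted sum.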

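The fast.ai implementation differs from (but is equivalent to) the algorithm described in the paper in one respect. It combines the 1x1 convolutions v(x) and h(x) into one layer and calls it h(x), or "value". The combined 1x1 conv layer has C input channels and C output channels. This is equivalent to the paper's algorithm because the attention-weighted sum is linear, so v can be moved inside it. Two back-to-back 1x1 conv layers with no nonlinearity between them then compose into a single 1x1 conv layer of compatible size. Strictly speaking, the single C × C layer is slightly more expressive. The paper's product W_v W_h has rank at most C*, whereas the combined layer is unconstrained.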
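The equivalence argument can be checked numerically. The NumPy sketch below uses random matrices, omits biases, and uses hypothetical names. It first applies h and v separately, as in the paper. It then applies a single combined C × C value matrix W_v W_h and compares the two results.

```python
import numpy as np

rng = np.random.default_rng(2)
C, C_star, N = 16, 2, 10
x = rng.standard_normal((C, N))            # features, one column per position
W_h = rng.standard_normal((C_star, C))     # h: C -> C* channels
W_v = rng.standard_normal((C, C_star))     # v: C* -> C channels

# A random attention map whose columns sum to 1; the identity below holds for
# any beta because it only relies on associativity of matrix products.
logits = rng.standard_normal((N, N))
beta = np.exp(logits) / np.exp(logits).sum(axis=0, keepdims=True)

paper = W_v @ ((W_h @ x) @ beta)           # o_j = v(sum_i beta_{j,i} h(x_i))
W_value = W_v @ W_h                        # one C -> C 1x1 conv replacing h followed by v
combined = (W_value @ x) @ beta

print(np.allclose(paper, combined))        # True
print(np.linalg.matrix_rank(W_value))      # at most C_star; a learned C x C layer has no such limit
```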

Sample Results

I used the self-attention layer in a UNet architecture by replacing the conv layer in the UNet blocks. Introducing the self-attention layer improved the Dice score for segmenting walls. Here is an example from the "Wall Color AI" app:

Source: Wall Color AI App
